# File: image_paintbyexample_pipeline.py (6.0 KB)

  1. # Copyright (c) Alibaba, Inc. and its affiliates.
  2. from typing import Any, Dict
  3. import cv2
  4. import numpy as np
  5. import PIL
  6. import torch
  7. import torch.nn as nn
  8. import torchvision
  9. from einops import rearrange
  10. from PIL import Image
  11. from torch.utils.data._utils.collate import default_collate
  12. from torchvision.transforms import Resize
  13. from modelscope.metainfo import Pipelines
  14. from modelscope.models.cv.image_paintbyexample import \
  15. StablediffusionPaintbyexample
  16. from modelscope.outputs import OutputKeys
  17. from modelscope.pipelines.base import Input, Pipeline
  18. from modelscope.pipelines.builder import PIPELINES
  19. from modelscope.preprocessors.image import load_image
  20. from modelscope.utils.constant import Tasks
  21. from modelscope.utils.logger import get_logger
  22. logger = get_logger()
  @PIPELINES.register_module(
      Tasks.image_paintbyexample, module_name=Pipelines.image_paintbyexample)
  class ImagePaintbyexamplePipeline(Pipeline):
      """Exemplar-guided inpainting pipeline (paint-by-example).

      The pipeline input is a dict with three entries, each of which may be
      either a path/URL string or a ``PIL.Image.Image``:

      * ``'img'``: the picture to edit,
      * ``'mask'``: the region selector (bright pixels are blanked out of
        the conditioning image, see ``preprocess``),
      * ``'reference'``: the exemplar picture used as conditioning.

      The output dict maps ``OutputKeys.OUTPUT_IMG`` to an ``H x W x 3``
      ``uint8`` numpy array in BGR channel order.
      """

      # Resolutions the inputs are resized to before being turned into tensors.
      _IMAGE_SIZE = (512, 512)
      _REF_SIZE = (224, 224)

      # Per-channel statistics passed to torchvision's ``Normalize``.
      _IMAGE_MEAN = (0.5, 0.5, 0.5)
      _IMAGE_STD = (0.5, 0.5, 0.5)
      _CLIP_MEAN = (0.48145466, 0.4578275, 0.40821073)
      _CLIP_STD = (0.26862954, 0.26130258, 0.27577711)

      # Sampler settings (previously inlined literals in perform_inference).
      _LATENT_CHANNELS = 4
      _NUM_STEPS = 50
      _GUIDANCE_SCALE = 5
      _ETA = 0.0

      def __init__(self, model: str, **kwargs):
          """Create the pipeline and its PLMS sampler.

          Args:
              model: model id on modelscope hub.
              **kwargs: forwarded unchanged to ``Pipeline.__init__``.

          Raises:
              AssertionError: if ``model`` is not a ``str``.
          """
          # NOTE: kept as an assert (not a raise) so the exception type seen by
          # existing callers does not change.
          assert isinstance(model, str), 'model must be a single str'
          # Imported lazily so the module can be imported without paint_ldm.
          from paint_ldm.models.diffusion.plms import PLMSSampler
          super().__init__(model=model, auto_collate=False, **kwargs)
          self.sampler = PLMSSampler(self.model.model)
          # Optional initial noise handed to the sampler as ``x_T``.
          self.start_code = None

      @staticmethod
      def _build_transform(mean, std, normalize, toTensor):
          """Compose the optional ToTensor / Normalize(mean, std) steps."""
          steps = []
          if toTensor:
              steps.append(torchvision.transforms.ToTensor())
          if normalize:
              steps.append(torchvision.transforms.Normalize(mean, std))
          return torchvision.transforms.Compose(steps)

      @staticmethod
      def _as_pil(value, key):
          """Return ``value`` as a PIL image, loading it when it is a string."""
          if isinstance(value, str):
              return load_image(value)
          if isinstance(value, Image.Image):
              return value
          raise TypeError(
              f"input['{key}'] should be either str or PIL.Image, "
              f'got {type(value).__name__}')

      def get_tensor(self, normalize=True, toTensor=True):
          """Build the transform applied to the picture being edited.

          Args:
              normalize: append ``Normalize`` with mean/std of 0.5 per channel.
              toTensor: prepend ``ToTensor``.

          Returns:
              A ``torchvision.transforms.Compose`` of the selected steps.
          """
          return self._build_transform(self._IMAGE_MEAN, self._IMAGE_STD,
                                       normalize, toTensor)

      def get_tensor_clip(self, normalize=True, toTensor=True):
          """Build the transform applied to the reference (exemplar) picture.

          Same as ``get_tensor`` but normalises with the per-channel
          statistics stored in ``_CLIP_MEAN`` / ``_CLIP_STD``.
          """
          return self._build_transform(self._CLIP_MEAN, self._CLIP_STD,
                                       normalize, toTensor)

      def preprocess(self, input: Dict[str, Any]) -> Dict[str, Any]:
          """Turn the raw inputs into the tensors consumed by the sampler.

          Args:
              input: dict with ``'img'``, ``'mask'`` and ``'reference'``
                  entries; each may independently be a path/URL string or a
                  ``PIL.Image.Image``.

          Returns:
              Dict with ``'inpaint_mask'`` (1x1xHxW, values 0/1),
              ``'inpaint_image'`` (the normalised image multiplied by that
              mask) and ``'ref_tensor'``, all moved to ``self.device``.

          Raises:
              KeyError: if one of the three entries is missing.
              TypeError: if an entry is neither ``str`` nor ``PIL.Image.Image``.
          """
          img = self._as_pil(input['img'], 'img')
          mask = self._as_pil(input['mask'], 'mask')
          ref = self._as_pil(input['reference'], 'reference')

          img = img.convert('RGB').resize(self._IMAGE_SIZE)
          ref = ref.convert('RGB').resize(self._REF_SIZE)
          # The mask is resized first and only then reduced to one channel.
          mask = mask.resize(self._IMAGE_SIZE).convert('L')

          img = self.get_tensor()(img).unsqueeze(0)
          ref = self.get_tensor_clip()(ref).unsqueeze(0)

          # Invert and binarise: bright mask pixels become 0, dark ones 1, so
          # the product below blanks out the bright (selected) region.
          inverted = 1 - np.array(mask)[None, None].astype(np.float32) / 255.0
          mask = torch.from_numpy((inverted >= 0.5).astype(np.float32))
          inpaint_image = img * mask

          return {
              'inpaint_mask': mask.to(self.device),
              'inpaint_image': inpaint_image.to(self.device),
              'ref_tensor': ref.to(self.device),
          }

      def forward(self, input: Dict[str, Any]) -> Dict[str, Any]:
          """Run inference on preprocessed tensors and wrap the result."""
          result = self.perform_inference(input)
          return {OutputKeys.OUTPUT_IMG: result}

      def perform_inference(self, test_model_kwargs):
          """Sample one image conditioned on the reference and masked input.

          Args:
              test_model_kwargs: dict produced by ``preprocess``. It is not
                  modified; the latent-space versions of the image and mask
                  are written to an internal copy.

          Returns:
              ``H x W x 3`` ``uint8`` numpy array in BGR channel order.
          """
          ldm = self.model.model
          # Shallow copy: the caller's dict keeps its pixel-space tensors, so
          # the same preprocessed input can be run more than once.
          sampler_kwargs = dict(test_model_kwargs)

          with torch.no_grad(), ldm.ema_scope():
              uc = ldm.learnable_vector
              c = ldm.get_learned_conditioning(
                  sampler_kwargs['ref_tensor'].to(torch.float32))
              c = ldm.proj_out(c)

              z_inpaint = ldm.encode_first_stage(
                  sampler_kwargs['inpaint_image'])
              z_inpaint = ldm.get_first_stage_encoding(z_inpaint).detach()
              latent_hw = [z_inpaint.shape[-2], z_inpaint.shape[-1]]

              sampler_kwargs['inpaint_image'] = z_inpaint
              sampler_kwargs['inpaint_mask'] = Resize(latent_hw)(
                  sampler_kwargs['inpaint_mask'])

              # Spatial size follows the encoded latent (64x64 for the 512x512
              # input produced by preprocess) so it always matches the mask.
              shape = [self._LATENT_CHANNELS, *latent_hw]
              samples, _ = self.sampler.sample(
                  S=self._NUM_STEPS,
                  conditioning=c,
                  batch_size=1,
                  shape=shape,
                  verbose=False,
                  unconditional_guidance_scale=self._GUIDANCE_SCALE,
                  unconditional_conditioning=uc,
                  eta=self._ETA,
                  x_T=self.start_code,
                  test_model_kwargs=sampler_kwargs)

              decoded = ldm.decode_first_stage(samples)
              # Map from [-1, 1] to [0, 1], clipping anything outside.
              decoded = torch.clamp((decoded + 1.0) / 2.0, min=0.0, max=1.0)

              # First (only) sample, channels-last, scaled to 0..255.
              hwc = decoded[0].cpu().permute(1, 2, 0).numpy()
              rgb = np.ascontiguousarray((255. * hwc).astype(np.uint8))
              return cv2.cvtColor(rgb, cv2.COLOR_RGB2BGR)

      def postprocess(self, inputs: Dict[str, Any]) -> Dict[str, Any]:
          """Return the forward output unchanged."""
          return inputs