# anydoor_pipeline.py (11 KB)
  1. # Copyright (c) Alibaba, Inc. and its affiliates.
  2. import os
  3. from typing import Any, Dict
  4. import cv2
  5. import einops
  6. import numpy as np
  7. import requests
  8. import torch
  9. from PIL import Image
  10. from modelscope.metainfo import Pipelines
  11. from modelscope.models.cv.anydoor.cldm.ddim_hacked import DDIMSampler
  12. from modelscope.models.cv.anydoor.datasets.data_utils import (
  13. box2squre, box_in_box, expand_bbox, expand_image_mask, get_bbox_from_mask,
  14. pad_to_square, sobel)
  15. from modelscope.outputs import OutputKeys
  16. from modelscope.pipelines.base import Input, Pipeline
  17. from modelscope.pipelines.builder import PIPELINES
  18. from modelscope.preprocessors.image import load_image
  19. from modelscope.utils.constant import Tasks
  20. from modelscope.utils.logger import get_logger
  21. logger = get_logger()
  22. @PIPELINES.register_module(
  23. Tasks.image_to_image_generation, module_name=Pipelines.anydoor)
  24. class AnydoorPipeline(Pipeline):
  25. r""" AnyDoor Pipeline.
  26. Examples:
  27. >>> from modelscope.pipelines import pipeline
  28. >>> from modelscope.utils.constant import Tasks
  29. >>> from PIL import Image
  30. >>> ref_image = 'data/test/images/image_anydoor_fg.png'
  31. >>> ref_mask = 'data/test/images/image_anydoor_fg_mask.png'
  32. >>> bg_image = 'data/test/images/image_anydoor_bg.png'
  33. >>> bg_mask = 'data/test/images/image_anydoor_bg_mask.png'
  34. >>> anydoor_pipeline = pipeline(Tasks.image_to_image_generation, model='damo/AnyDoor')
  35. >>> out = anydoor_pipeline((ref_image, ref_mask, bg_image, bg_mask))
  36. >>> assert isinstance(out['output_img'], Image.Image)
  37. """
  38. def __init__(self, model: str, **kwargs):
  39. """
  40. use `model` to create a action detection pipeline for prediction
  41. Args:
  42. model: model id on modelscope hub.
  43. """
  44. super().__init__(model=model, **kwargs)
  45. model_ckpt = os.path.join(self.model.model_dir,
  46. self.cfg.model.model_path)
  47. self.model.load_state_dict(
  48. self._get_state_dict(model_ckpt, location='cuda'))
  49. self.ddim_sampler = DDIMSampler(self.model)
  50. @staticmethod
  51. def _get_state_dict(ckpt_path, location='cpu'):
  52. def get_state_dict(d):
  53. return d.get('state_dict', d)
  54. _, extension = os.path.splitext(ckpt_path)
  55. if extension.lower() == '.safetensors':
  56. import safetensors.torch
  57. state_dict = safetensors.torch.load_file(
  58. ckpt_path, device=location)
  59. else:
  60. state_dict = get_state_dict(
  61. torch.load(
  62. ckpt_path,
  63. map_location=torch.device(location),
  64. weights_only=True))
  65. state_dict = get_state_dict(state_dict)
  66. print(f'Loaded state_dict from [{ckpt_path}]')
  67. return state_dict
  68. def preprocess(self, inputs: Input) -> Dict[str, Any]:
  69. ref_image, ref_mask, tar_image, tar_mask = inputs
  70. ref_image = np.asarray(load_image(ref_image).convert('RGB'))
  71. ref_mask = np.where(
  72. np.asarray(load_image(ref_mask).convert('L')) > 128, 1,
  73. 0).astype(np.uint8)
  74. tar_image = np.asarray(load_image(tar_image).convert('RGB'))
  75. tar_mask = np.where(
  76. np.asarray(load_image(tar_mask).convert('L')) > 128, 1,
  77. 0).astype(np.uint8)
  78. # ========= Reference ===========
  79. # ref expand
  80. ref_box_yyxx = get_bbox_from_mask(ref_mask)
  81. # ref filter mask
  82. ref_mask_3 = np.stack([ref_mask, ref_mask, ref_mask], -1)
  83. masked_ref_image = ref_image * ref_mask_3 + np.ones_like(
  84. ref_image) * 255 * (1 - ref_mask_3)
  85. y1, y2, x1, x2 = ref_box_yyxx
  86. masked_ref_image = masked_ref_image[y1:y2, x1:x2, :]
  87. ref_mask = ref_mask[y1:y2, x1:x2]
  88. ratio = np.random.randint(11, 15) / 10 # 11,13
  89. masked_ref_image, ref_mask = expand_image_mask(
  90. masked_ref_image, ref_mask, ratio=ratio)
  91. ref_mask_3 = np.stack([ref_mask, ref_mask, ref_mask], -1)
  92. # to square and resize
  93. masked_ref_image = pad_to_square(
  94. masked_ref_image, pad_value=255, random=False)
  95. masked_ref_image = cv2.resize(
  96. masked_ref_image.astype(np.uint8), (224, 224)).astype(np.uint8)
  97. ref_mask_3 = pad_to_square(ref_mask_3 * 255, pad_value=0, random=False)
  98. ref_mask_3 = cv2.resize(ref_mask_3.astype(np.uint8),
  99. (224, 224)).astype(np.uint8)
  100. ref_mask = ref_mask_3[:, :, 0]
  101. # collage aug
  102. masked_ref_image_compose, ref_mask_compose = masked_ref_image, ref_mask
  103. ref_mask_3 = np.stack(
  104. [ref_mask_compose, ref_mask_compose, ref_mask_compose], -1)
  105. ref_image_collage = sobel(masked_ref_image_compose,
  106. ref_mask_compose / 255)
  107. # ========= Target ===========
  108. tar_box_yyxx = get_bbox_from_mask(tar_mask)
  109. tar_box_yyxx = expand_bbox(
  110. tar_mask, tar_box_yyxx, ratio=[1.1, 1.2]) # 1.1 1.3
  111. # crop
  112. tar_box_yyxx_crop = expand_bbox(
  113. tar_image, tar_box_yyxx, ratio=[1.3, 3.0])
  114. tar_box_yyxx_crop = box2squre(tar_image, tar_box_yyxx_crop) # crop box
  115. y1, y2, x1, x2 = tar_box_yyxx_crop
  116. cropped_target_image = tar_image[y1:y2, x1:x2, :]
  117. cropped_tar_mask = tar_mask[y1:y2, x1:x2]
  118. tar_box_yyxx = box_in_box(tar_box_yyxx, tar_box_yyxx_crop)
  119. y1, y2, x1, x2 = tar_box_yyxx
  120. # collage
  121. ref_image_collage = cv2.resize(
  122. ref_image_collage.astype(np.uint8), (x2 - x1, y2 - y1))
  123. ref_mask_compose = cv2.resize(
  124. ref_mask_compose.astype(np.uint8), (x2 - x1, y2 - y1))
  125. ref_mask_compose = (ref_mask_compose > 128).astype(np.uint8)
  126. collage = cropped_target_image.copy()
  127. collage[y1:y2, x1:x2, :] = ref_image_collage
  128. collage_mask = cropped_target_image.copy() * 0.0
  129. collage_mask[y1:y2, x1:x2, :] = 1.0
  130. collage_mask = np.stack(
  131. [cropped_tar_mask, cropped_tar_mask, cropped_tar_mask], -1)
  132. # the size before pad
  133. H1, W1 = collage.shape[0], collage.shape[1]
  134. cropped_target_image = pad_to_square(
  135. cropped_target_image, pad_value=0, random=False).astype(np.uint8)
  136. collage = pad_to_square(
  137. collage, pad_value=0, random=False).astype(np.uint8)
  138. collage_mask = pad_to_square(
  139. collage_mask, pad_value=0, random=False).astype(np.uint8)
  140. # the size after pad
  141. H2, W2 = collage.shape[0], collage.shape[1]
  142. cropped_target_image = cv2.resize(
  143. cropped_target_image.astype(np.uint8),
  144. (512, 512)).astype(np.float32)
  145. collage = cv2.resize(collage.astype(np.uint8),
  146. (512, 512)).astype(np.float32)
  147. collage_mask = (cv2.resize(collage_mask.astype(
  148. np.uint8), (512, 512)).astype(np.float32) > 0.5).astype(np.float32)
  149. masked_ref_image = masked_ref_image / 255
  150. cropped_target_image = cropped_target_image / 127.5 - 1.0
  151. collage = collage / 127.5 - 1.0
  152. collage = np.concatenate([collage, collage_mask[:, :, :1]], -1)
  153. item = dict(
  154. tar_image=tar_image,
  155. ref=masked_ref_image.copy(),
  156. jpg=cropped_target_image.copy(),
  157. hint=collage.copy(),
  158. extra_sizes=np.array([H1, W1, H2, W2]),
  159. tar_box_yyxx_crop=np.array(tar_box_yyxx_crop))
  160. return item
  161. def forward(self,
  162. item: Dict[str, Any],
  163. num_samples=1,
  164. strength=1.0,
  165. ddim_steps=30,
  166. scale=3.0) -> Dict[str, Any]:
  167. tar_image = item['tar_image'].cpu().numpy()
  168. ref = item['ref']
  169. hint = item['hint']
  170. num_samples = 1
  171. control = hint.float().cuda()
  172. control = torch.stack([control for _ in range(num_samples)], dim=0)
  173. control = einops.rearrange(control, 'b h w c -> b c h w').clone()
  174. clip_input = ref.float().cuda()
  175. clip_input = torch.stack([clip_input for _ in range(num_samples)],
  176. dim=0)
  177. clip_input = einops.rearrange(clip_input, 'b h w c -> b c h w').clone()
  178. H, W = 512, 512
  179. cond = {
  180. 'c_concat': [control],
  181. 'c_crossattn': [self.model.get_learned_conditioning(clip_input)]
  182. }
  183. un_cond = {
  184. 'c_concat': [control],
  185. 'c_crossattn': [
  186. self.model.get_learned_conditioning(
  187. [torch.zeros((1, 3, 224, 224))] * num_samples)
  188. ]
  189. }
  190. shape = (4, H // 8, W // 8)
  191. self.model.control_scales = ([strength] * 13)
  192. samples, _ = self.ddim_sampler.sample(
  193. ddim_steps,
  194. num_samples,
  195. shape,
  196. cond,
  197. verbose=False,
  198. eta=0,
  199. unconditional_guidance_scale=scale,
  200. unconditional_conditioning=un_cond)
  201. x_samples = self.model.decode_first_stage(samples)
  202. x_samples = (einops.rearrange(x_samples, 'b c h w -> b h w c') * 127.5
  203. + 127.5).cpu().numpy()
  204. result = x_samples[0][:, :, ::-1]
  205. result = np.clip(result, 0, 255)
  206. pred = x_samples[0]
  207. pred = np.clip(pred, 0, 255)[1:, :, :]
  208. sizes = item['extra_sizes'].cpu().numpy()
  209. tar_box_yyxx_crop = item['tar_box_yyxx_crop'].cpu().numpy()
  210. return dict(
  211. pred=pred,
  212. tar_image=tar_image,
  213. sizes=sizes,
  214. tar_box_yyxx_crop=tar_box_yyxx_crop)
  215. def postprocess(self, inputs: Dict[str, Any]) -> Dict[str, Any]:
  216. pred = inputs['pred']
  217. tar_image = inputs['tar_image']
  218. extra_sizes = inputs['sizes']
  219. tar_box_yyxx_crop = inputs['tar_box_yyxx_crop']
  220. H1, W1, H2, W2 = extra_sizes
  221. y1, y2, x1, x2 = tar_box_yyxx_crop
  222. pred = cv2.resize(pred, (W2, H2))
  223. m = 3 # maigin_pixel
  224. if W1 == H1:
  225. tar_image[y1 + m:y2 - m, x1 + m:x2 - m, :] = pred[m:-m, m:-m]
  226. gen_image = torch.from_numpy(tar_image.copy()).permute(2, 0, 1)
  227. gen_image = gen_image.permute(1, 2, 0).numpy()
  228. gen_image = Image.fromarray(gen_image, mode='RGB')
  229. return {OutputKeys.OUTPUT_IMG: gen_image}
  230. if W1 < W2:
  231. pad1 = int((W2 - W1) / 2)
  232. pad2 = W2 - W1 - pad1
  233. pred = pred[:, pad1:-pad2, :]
  234. else:
  235. pad1 = int((H2 - H1) / 2)
  236. pad2 = H2 - H1 - pad1
  237. pred = pred[pad1:-pad2, :, :]
  238. gen_image = tar_image.copy()
  239. gen_image[y1 + m:y2 - m, x1 + m:x2 - m, :] = pred[m:-m, m:-m]
  240. gen_image = torch.from_numpy(gen_image).permute(2, 0, 1)
  241. gen_image = gen_image.permute(1, 2, 0).numpy()
  242. gen_image = Image.fromarray(gen_image, mode='RGB')
  243. return {OutputKeys.OUTPUT_IMG: gen_image}