image_editing_pipeline.py 14 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365
  1. # Copyright (c) Alibaba, Inc. and its affiliates.
  2. import os.path
  3. from typing import Any, Dict, Optional, Union
  4. import numpy as np
  5. import torch
  6. import torch.nn.functional as F
  7. from diffusers import DDIMScheduler, StableDiffusionPipeline
  8. from PIL import Image
  9. from torchvision import transforms
  10. from tqdm import tqdm
  11. from modelscope.metainfo import Pipelines
  12. from modelscope.models.cv.image_editing import (
  13. MutualSelfAttentionControl, register_attention_editor_diffusers)
  14. from modelscope.outputs import OutputKeys
  15. from modelscope.pipelines.builder import PIPELINES
  16. from modelscope.pipelines.multi_modal.diffusers_wrapped.diffusers_pipeline import \
  17. DiffusersPipeline
  18. from modelscope.preprocessors import LoadImage
  19. from modelscope.utils.constant import Tasks
  20. from modelscope.utils.logger import get_logger
  # Module-level logger; ImageEditingPipeline.__init__ reports through it.
  logger = get_logger()
  # Public API of this module: only the registered pipeline class is exported.
  __all__ = ['ImageEditingPipeline']
  @PIPELINES.register_module(
      Tasks.image_editing, module_name=Pipelines.image_editing)
  class ImageEditingPipeline(DiffusersPipeline):
      """Edit an image with MasaCtrl on top of a Stable Diffusion checkpoint.

      The input image is inverted to a DDIM start code using the first
      (source) prompt, then re-sampled with all prompts while a
      ``MutualSelfAttentionControl`` editor is registered on the pipeline.
      The last image of the sampled batch is returned.
      """

      def __init__(self, model=str, preprocessor=None, **kwargs):
          """ MasaCtrl Image Editing Pipeline.

          Args:
              model: Directory that contains a ``stable-diffusion-v1-4``
                  sub-directory with the diffusers weights and scheduler.
              preprocessor: Forwarded unchanged to the parent constructor.
              **kwargs: Forwarded to the parent constructor. Two keys are
                  also read here: ``torch_dtype`` (default ``torch.float32``)
                  and ``device`` (default: CUDA when available, else CPU).

          Examples:
              >>> import cv2
              >>> from modelscope.pipelines import pipeline
              >>> from modelscope.utils.constant import Tasks
              >>> prompts = [
              >>>     "",                           # source prompt
              >>>     "a photo of a running corgi"  # target prompt
              >>> ]
              >>> output_image_path = './result.png'
              >>> img = 'https://public-vigen-video.oss-cn-shanghai.aliyuncs.com/public/ModelScope/test/images/corgi.jpg'
              >>> input = {'img': img, 'prompts': prompts}
              >>>
              >>> pipe = pipeline(
              >>>     Tasks.image_editing,
              >>>     model='damo/cv_masactrl_image-editing')
              >>>
              >>> output = pipe(input)['output_img']
              >>> cv2.imwrite(output_image_path, output)
              >>> print('pipeline: the output image path is {}'.format(output_image_path))
          """
          # NOTE(review): `model=str` looks like a typo for `model: str`; the
          # signature is left untouched so existing callers are unaffected.
          super().__init__(model=model, preprocessor=preprocessor, **kwargs)
          torch_dtype = kwargs.get('torch_dtype', torch.float32)
          # `kwargs` is a dict, so the value must be read with `.get`;
          # `getattr(kwargs, 'device', ...)` always fell through to the default.
          self._device = self._resolve_device(kwargs.get('device'))

          sd_dir = os.path.join(model, 'stable-diffusion-v1-4')
          scheduler = DDIMScheduler.from_pretrained(
              sd_dir, subfolder='scheduler')
          self.pipeline = _MasaCtrlPipeline.from_pretrained(
              sd_dir,
              scheduler=scheduler,
              torch_dtype=torch_dtype,
              use_safetensors=True).to(self._device)
          # Report success only once the weights are actually loaded.
          logger.info('load image editing pipeline done')

      @staticmethod
      def _resolve_device(requested):
          """Turn the optional ``device`` kwarg into a ``torch.device``.

          ``None`` selects CUDA when available and CPU otherwise, which is
          what the pipeline always did before. A ``torch.device`` is returned
          as is. Strings are lower-cased; a leading ``gpu`` is rewritten to
          ``cuda``, and a CUDA request on a machine without CUDA falls back
          to CPU so the default configuration keeps working everywhere.
          """
          fallback = torch.device(
              'cuda' if torch.cuda.is_available() else 'cpu')
          if requested is None:
              return fallback
          if isinstance(requested, torch.device):
              return requested
          name = str(requested).lower()
          # NOTE(review): ModelScope callers appear to use 'gpu' / 'gpu:N' as
          # device names, which torch does not understand - confirm upstream.
          if name.startswith('gpu'):
              name = 'cuda' + name[len('gpu'):]
          if name.startswith('cuda') and not torch.cuda.is_available():
              return fallback
          return torch.device(name)

      def preprocess(self, input: Dict[str, Any]) -> Dict[str, Any]:
          """Load ``input['img']`` and turn it into a normalized 512x512 batch.

          Args:
              input: Request dictionary; ``'img'`` is handed to
                  ``LoadImage.convert_to_img``.

          Returns:
              A shallow copy of ``input`` whose ``'img'`` entry is a tensor
              with a leading batch dimension of 1, values normalized with
              mean 0.5 / std 0.5, resized to 512x512 and moved to the
              pipeline device.
          """
          img = LoadImage.convert_to_img(input.get('img'))
          to_normalized_tensor = transforms.Compose([
              transforms.ToTensor(),
              transforms.Normalize([0.5], [0.5]),  # [0, 1] -> [-1, 1]
          ])
          batch = to_normalized_tensor(img).unsqueeze(0)
          batch = F.interpolate(batch, (512, 512))
          # Work on a copy so the caller's dictionary is not rewritten.
          output = dict(input)
          output['img'] = batch.to(self._device)
          return output

      def forward(self, input: Dict[str, Any]) -> Dict[str, Any]:
          """Invert the image with the source prompt, then sample the edit.

          Args:
              input: Dictionary with ``'img'`` (the tensor produced by
                  ``preprocess``), ``'prompts'`` (non-empty list of strings,
                  source prompt first) and optionally ``'guidance_scale'``
                  (default 7.5) for the sampling pass.

          Returns:
              ``{'output_tensor': tensor}`` holding the last image of the
              sampled batch, with the batch dimension kept.

          Raises:
              ValueError: If ``input`` is not a dictionary or ``'prompts'``
                  is missing, empty or not a list/tuple.
          """
          if not isinstance(input, dict):
              raise ValueError(
                  f'Expected the input to be a dictionary, but got {type(input)}'
              )
          prompts = input.get('prompts')
          if not isinstance(prompts, (list, tuple)) or len(prompts) == 0:
              raise ValueError(
                  "Expected input['prompts'] to be a non-empty list of "
                  f'strings (source prompt first), but got {prompts!r}')
          # The sampler derives its batch size from `list` prompts only.
          prompts = list(prompts)

          # Inversion always runs with guidance 7.5 and 50 steps; only the
          # sampling pass below honours a caller-supplied guidance scale.
          start_code, _ = self.pipeline.invert(
              input.get('img'),
              prompts[0],
              guidance_scale=7.5,
              num_inference_steps=50,
              return_intermediates=True)
          # Every prompt starts from the same inverted noise map.
          start_code = start_code.expand(len(prompts), -1, -1, -1)

          # Passed positionally to MutualSelfAttentionControl; presumably the
          # step and layer from which the control kicks in - TODO confirm.
          STEP, LAYER = 4, 10
          editor = MutualSelfAttentionControl(STEP, LAYER)
          register_attention_editor_diffusers(self.pipeline, editor)

          # inference the synthesized image
          output = self.pipeline(
              prompts,
              latents=start_code,
              guidance_scale=input.get('guidance_scale', 7.5),
          )[-1:]
          return {'output_tensor': output}

      def postprocess(self, input: Dict[str, Any]) -> Dict[str, Any]:
          """Convert the output tensor into an ``uint8`` H x W x C array.

          The tensor is scaled by 255, moved to the CPU, permuted from
          channel-first to channel-last and cast to ``uint8``. The channel
          axis is then reversed (presumably RGB -> BGR, since the class
          example saves the result with ``cv2.imwrite``).
          """
          output_img = (input['output_tensor'].squeeze(0) * 255).cpu().permute(
              1, 2, 0).numpy().astype('uint8')
          return {OutputKeys.OUTPUT_IMG: output_img[:, :, ::-1]}
  class _MasaCtrlPipeline(StableDiffusionPipeline):
      """Stable Diffusion pipeline with explicit DDIM sampling and inversion.

      ``__call__`` runs a deterministic DDIM sampling loop and ``invert``
      runs the same update in the opposite direction to map an image to a
      noise map. Both share the text encoding and guided noise prediction
      helpers defined below.
      """

      def _encode_text_batch(self, text):
          """Tokenize ``text`` (a string or list of strings), padded to 77
          tokens, and return the first output of the text encoder."""
          tokens = self.tokenizer(
              text, padding='max_length', max_length=77, return_tensors='pt')
          return self.text_encoder(
              tokens.input_ids.to(self._execution_device))[0]

      def _guided_noise(self, latents, t, text_embeddings, guidance_scale):
          """Predict the noise for ``latents`` at timestep ``t``.

          When ``guidance_scale > 1`` the latents are duplicated, the UNet
          output is split into an unconditional and a conditional half and
          the two are combined as ``uncond + scale * (cond - uncond)``.
          """
          use_guidance = guidance_scale > 1.
          model_inputs = torch.cat([latents] * 2) if use_guidance else latents
          noise_pred = self.unet(
              model_inputs, t, encoder_hidden_states=text_embeddings).sample
          if use_guidance:
              noise_pred_uncon, noise_pred_con = noise_pred.chunk(2, dim=0)
              noise_pred = noise_pred_uncon + guidance_scale * (
                  noise_pred_con - noise_pred_uncon)
          return noise_pred

      def next_step(
          self,
          model_output: torch.FloatTensor,
          timestep: int,
          x: torch.FloatTensor,
          eta=0,
          verbose=False,
      ):
          """
          Inverse sampling for DDIM Inversion
          x_t -> x_(t+1)

          Args:
              model_output: Predicted noise for ``x``.
              timestep: Timestep the sample is moved *to*.
              x: Current sample.
              eta: Accepted for signature symmetry; not used.
              verbose: Print the timestep when true.

          Returns:
              ``(x_next, pred_x0)``: the sample at ``timestep`` and the
              predicted clean sample.
          """
          if verbose:
              print('timestep: ', timestep)
          target_timestep = timestep
          # The sample currently sits one scheduler stride before the target.
          timestep = min(
              timestep - self.scheduler.config.num_train_timesteps
              // self.scheduler.num_inference_steps, 999)
          alpha_prod_t = self.scheduler.alphas_cumprod[
              timestep] if timestep >= 0 else self.scheduler.final_alpha_cumprod
          alpha_prod_t_next = self.scheduler.alphas_cumprod[target_timestep]
          beta_prod_t = 1 - alpha_prod_t
          pred_x0 = (x - beta_prod_t**0.5 * model_output) / alpha_prod_t**0.5
          pred_dir = (1 - alpha_prod_t_next)**0.5 * model_output
          x_next = alpha_prod_t_next**0.5 * pred_x0 + pred_dir
          return x_next, pred_x0

      def step(
          self,
          model_output: torch.FloatTensor,
          timestep: int,
          x: torch.FloatTensor,
          eta: float = 0.0,
          verbose=False,
      ):
          """
          predict the sample the next step in the denoise process.
          x_t -> x_(t-1)

          Args:
              model_output: Predicted noise for ``x``.
              timestep: Timestep of the current sample.
              x: Current sample.
              eta: Accepted for signature symmetry; not used.
              verbose: Accepted for signature symmetry; not used.

          Returns:
              ``(x_prev, pred_x0)``: the sample one scheduler stride earlier
              and the predicted clean sample.
          """
          prev_timestep = timestep - self.scheduler.config.num_train_timesteps // self.scheduler.num_inference_steps
          alpha_prod_t = self.scheduler.alphas_cumprod[timestep]
          # `>= 0` keeps timestep 0 on the cumulative-alpha table, matching
          # `next_step` above so that sampling mirrors inversion exactly.
          alpha_prod_t_prev = self.scheduler.alphas_cumprod[
              prev_timestep] if prev_timestep >= 0 else self.scheduler.final_alpha_cumprod
          beta_prod_t = 1 - alpha_prod_t
          pred_x0 = (x - beta_prod_t**0.5 * model_output) / alpha_prod_t**0.5
          pred_dir = (1 - alpha_prod_t_prev)**0.5 * model_output
          x_prev = alpha_prod_t_prev**0.5 * pred_x0 + pred_dir
          return x_prev, pred_x0

      @torch.no_grad()
      def image2latent(self, image):
          """Encode an image into scaled VAE latents.

          Args:
              image: Either a PIL image, which is converted to RGB, scaled to
                  [-1, 1] and given a batch dimension, or a value that is
                  passed to the VAE encoder unchanged.

          Returns:
              The mean of the encoder's latent distribution times 0.18215.
          """
          DEVICE = self._execution_device
          # `Image` is the PIL module; the image class is `Image.Image`.
          if isinstance(image, Image.Image):
              image = np.array(image.convert('RGB'))
              image = torch.from_numpy(image).float() / 127.5 - 1
              image = image.permute(2, 0, 1).unsqueeze(0).to(DEVICE)
          # input image density range [-1, 1]
          latents = self.vae.encode(image)['latent_dist'].mean
          latents = latents * 0.18215
          return latents

      @torch.no_grad()
      def latent2image(self, latents, return_type='pt'):
          """Decode latents with the VAE.

          Args:
              latents: Scaled latents (divided by 0.18215 before decoding).
              return_type: ``'pt'`` returns a tensor clamped to [0, 1];
                  ``'np'`` returns the first image of the batch as an
                  ``uint8`` H x W x C array. Any other value returns the raw
                  decoder output.
          """
          latents = 1 / 0.18215 * latents.detach()
          image = self.vae.decode(latents)['sample']
          if return_type in ('np', 'pt'):
              image = (image / 2 + 0.5).clamp(0, 1)
          if return_type == 'np':
              image = image.cpu().permute(0, 2, 3, 1).numpy()[0]
              image = (image * 255).astype(np.uint8)
          return image

      @torch.no_grad()
      def __call__(self,
                   prompt,
                   batch_size=1,
                   height=512,
                   width=512,
                   num_inference_steps=50,
                   guidance_scale=7.5,
                   eta=0.0,
                   latents=None,
                   unconditioning=None,
                   neg_prompt=None,
                   ref_intermediate_latents=None,
                   return_intermediates=False,
                   **kwds):
          """Sample images with a deterministic DDIM loop.

          Args:
              prompt: A string or a list of strings. A list fixes the batch
                  size to its length; a string is repeated ``batch_size``
                  times when ``batch_size > 1``.
              batch_size: Batch size for a string prompt.
              height: Output height; the latent height is ``height // 8``.
              width: Output width; the latent width is ``width // 8``.
              num_inference_steps: Number of scheduler steps.
              guidance_scale: Classifier-free guidance weight; guidance is
                  only applied when it is greater than 1.
              eta: Accepted but not used.
              latents: Optional start latents of shape
                  ``(batch, unet in_channels, height // 8, width // 8)``;
                  random noise is drawn when omitted.
              unconditioning: Optional list with one embedding per step that
                  replaces the first half of the text embeddings.
              neg_prompt: Text used for the unconditional branch (empty
                  string when falsy).
              ref_intermediate_latents: Optional sequence of latents, read
                  from the end, whose entry replaces the first half of the
                  batch at every step.
              return_intermediates: Also return decoded per-step results.
              **kwds: Accepted but not used.

          Returns:
              The decoded image tensor, or ``(image, pred_x0_list,
              latents_list)`` with decoded per-step tensors when
              ``return_intermediates`` is true.

          Raises:
              ValueError: If ``latents`` does not have the expected shape.
          """
          DEVICE = self._execution_device
          if isinstance(prompt, list):
              batch_size = len(prompt)
          elif isinstance(prompt, str) and batch_size > 1:
              prompt = [prompt] * batch_size

          # text embeddings
          text_embeddings = self._encode_text_batch(prompt)
          logger.debug('input text embeddings : %s', text_embeddings.shape)

          # define initial latents
          latents_shape = (batch_size, self.unet.config.in_channels,
                           height // 8, width // 8)
          if latents is None:
              latents = torch.randn(latents_shape, device=DEVICE)
          elif tuple(latents.shape) != latents_shape:
              # Raised explicitly: an `assert` would vanish under `python -O`.
              raise ValueError(
                  f'The shape of input latent tensor {latents.shape} should '
                  f'equal to predefined one {latents_shape}.')

          # unconditional embedding for classifier free guidance
          if guidance_scale > 1.:
              uc_text = neg_prompt if neg_prompt else ''
              unconditional_embeddings = self._encode_text_batch(
                  [uc_text] * batch_size)
              text_embeddings = torch.cat(
                  [unconditional_embeddings, text_embeddings], dim=0)
          logger.debug('latents shape: %s', latents.shape)

          # iterative sampling
          self.scheduler.set_timesteps(num_inference_steps)
          latents_list = [latents]
          pred_x0_list = [latents]
          for i, t in enumerate(
                  tqdm(self.scheduler.timesteps, desc='DDIM Sampler')):
              if ref_intermediate_latents is not None:
                  # note that the batch_size >= 2
                  latents_ref = ref_intermediate_latents[-1 - i]
                  _, latents_cur = latents.chunk(2)
                  latents = torch.cat([latents_ref, latents_cur])
              if unconditioning is not None and isinstance(
                      unconditioning, list):
                  _, text_embeddings = text_embeddings.chunk(2)
                  text_embeddings = torch.cat([
                      unconditioning[i].expand(*text_embeddings.shape),
                      text_embeddings
                  ])
              # predict the noise
              noise_pred = self._guided_noise(latents, t, text_embeddings,
                                              guidance_scale)
              # compute the previous noise sample x_t -> x_t-1
              latents, pred_x0 = self.step(noise_pred, t, latents)
              if return_intermediates:
                  # Only keep per-step tensors when the caller asked for them.
                  latents_list.append(latents)
                  pred_x0_list.append(pred_x0)

          image = self.latent2image(latents, return_type='pt')
          if return_intermediates:
              pred_x0_list = [
                  self.latent2image(img, return_type='pt')
                  for img in pred_x0_list
              ]
              latents_list = [
                  self.latent2image(img, return_type='pt')
                  for img in latents_list
              ]
              return image, pred_x0_list, latents_list
          return image

      @torch.no_grad()
      def invert(self,
                 image: torch.Tensor,
                 prompt,
                 num_inference_steps=50,
                 guidance_scale=7.5,
                 eta=0.0,
                 return_intermediates=False,
                 **kwds):
          """
          invert a real image into noise map with deterministic DDIM inversion

          Args:
              image: Image batch handed to ``image2latent``; its first
                  dimension is the batch size.
              prompt: A string or a list of strings. With a list and a single
                  image, the image is expanded to one copy per prompt; a
                  string is repeated for every image of a larger batch.
              num_inference_steps: Number of scheduler steps.
              guidance_scale: Classifier-free guidance weight; guidance is
                  only applied when it is greater than 1.
              eta: Accepted but not used.
              return_intermediates: Return the per-step latents as the second
                  element instead of the start latents.
              **kwds: Accepted but not used.

          Returns:
              ``(latents, latents_list)`` when ``return_intermediates`` is
              true, otherwise ``(latents, start_latents)``.
          """
          if isinstance(prompt, list):
              if image.shape[0] == 1:
                  image = image.expand(len(prompt), -1, -1, -1)
          elif isinstance(prompt, str):
              if image.shape[0] > 1:
                  prompt = [prompt] * image.shape[0]
          # Read the batch size after the expansion above so that the
          # unconditional embeddings match the latent batch.
          batch_size = image.shape[0]

          # text embeddings
          text_embeddings = self._encode_text_batch(prompt)
          logger.debug('input text embeddings : %s', text_embeddings.shape)

          # define initial latents
          latents = self.image2latent(image)
          start_latents = latents

          # unconditional embedding for classifier free guidance
          if guidance_scale > 1.:
              unconditional_embeddings = self._encode_text_batch(
                  [''] * batch_size)
              text_embeddings = torch.cat(
                  [unconditional_embeddings, text_embeddings], dim=0)
          logger.debug('latents shape: %s', latents.shape)

          self.scheduler.set_timesteps(num_inference_steps)
          # Inversion walks the scheduler timesteps in reverse order.
          inversion_timesteps = reversed(self.scheduler.timesteps)
          logger.debug('Valid timesteps: %s', inversion_timesteps)
          latents_list = [latents]
          for t in tqdm(inversion_timesteps, desc='DDIM Inversion'):
              # predict the noise
              noise_pred = self._guided_noise(latents, t, text_embeddings,
                                              guidance_scale)
              # compute the next noise sample x_t-1 -> x_t
              latents, _ = self.next_step(noise_pred, t, latents)
              latents_list.append(latents)

          if return_intermediates:
              return latents, latents_list
          return latents, start_latents