model.py 14 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349
  1. # Copyright 2021-2022 The Alibaba Fundamental Vision Team Authors. All rights reserved.
  2. import os.path as osp
  3. from typing import Any, Dict
  4. import json
  5. import numpy as np
  6. import torch
  7. import torch.nn as nn
  8. import torch.nn.functional as F
  9. from modelscope.metainfo import Models
  10. from modelscope.models import Model
  11. from modelscope.models.builder import MODELS
  12. from modelscope.models.multi_modal.diffusion.diffusion import (
  13. GaussianDiffusion, beta_schedule)
  14. from modelscope.models.multi_modal.diffusion.structbert import (BertConfig,
  15. BertModel)
  16. from modelscope.models.multi_modal.diffusion.tokenizer import FullTokenizer
  17. from modelscope.models.multi_modal.diffusion.unet_generator import \
  18. DiffusionGenerator
  19. from modelscope.models.multi_modal.diffusion.unet_upsampler_256 import \
  20. SuperResUNet256
  21. from modelscope.models.multi_modal.diffusion.unet_upsampler_1024 import \
  22. SuperResUNet1024
  23. from modelscope.utils.constant import ModelFile, Tasks
  24. from modelscope.utils.device import create_device
  25. from modelscope.utils.logger import get_logger
# Module-level logger obtained from modelscope's `get_logger` helper.
logger = get_logger()
# Restrict `from <module> import *` to the registered model class.
__all__ = ['DiffusionForTextToImageSynthesis']
  28. def make_diffusion(schedule,
  29. num_timesteps=1000,
  30. init_beta=None,
  31. last_beta=None,
  32. var_type='fixed_small'):
  33. betas = beta_schedule(schedule, num_timesteps, init_beta, last_beta)
  34. diffusion = GaussianDiffusion(betas, var_type=var_type)
  35. return diffusion
  36. class Tokenizer(object):
  37. def __init__(self, vocab_file, seq_len=64):
  38. self.vocab_file = vocab_file
  39. self.seq_len = seq_len
  40. self.tokenizer = FullTokenizer(
  41. vocab_file=vocab_file, do_lower_case=True)
  42. def __call__(self, text):
  43. # tokenization
  44. tokens = self.tokenizer.tokenize(text)
  45. tokens = ['[CLS]'] + tokens[:self.seq_len - 2] + ['[SEP]']
  46. input_ids = self.tokenizer.convert_tokens_to_ids(tokens)
  47. input_mask = [1] * len(input_ids)
  48. segment_ids = [0] * len(input_ids)
  49. # padding
  50. input_ids += [0] * (self.seq_len - len(input_ids))
  51. input_mask += [0] * (self.seq_len - len(input_mask))
  52. segment_ids += [0] * (self.seq_len - len(segment_ids))
  53. assert len(input_ids) == len(input_mask) == len(
  54. segment_ids) == self.seq_len
  55. # convert to tensors
  56. input_ids = torch.LongTensor(input_ids)
  57. input_mask = torch.LongTensor(input_mask)
  58. segment_ids = torch.LongTensor(segment_ids)
  59. return input_ids, segment_ids, input_mask
  60. class DiffusionModel(nn.Module):
  61. def __init__(self, model_dir):
  62. super(DiffusionModel, self).__init__()
  63. # including text and generator config
  64. model_config = json.load(
  65. open('{}/model_config.json'.format(model_dir), encoding='utf-8'))
  66. # text encoder
  67. text_config = model_config['text_config']
  68. self.text_encoder = BertModel(BertConfig.from_dict(text_config))
  69. # generator (64x64)
  70. generator_config = model_config['generator_config']
  71. self.unet_generator = DiffusionGenerator(**generator_config)
  72. # upsampler (256x256)
  73. upsampler_256_config = model_config['upsampler_256_config']
  74. self.unet_upsampler_256 = SuperResUNet256(**upsampler_256_config)
  75. # upsampler (1024x1024)
  76. upsampler_1024_config = model_config['upsampler_1024_config']
  77. self.unet_upsampler_1024 = SuperResUNet1024(**upsampler_1024_config)
  78. def forward(self, noise, timesteps, input_ids, token_type_ids,
  79. attention_mask):
  80. context, y = self.text_encoder(
  81. input_ids=input_ids,
  82. token_type_ids=token_type_ids,
  83. attention_mask=attention_mask)
  84. context = context[-1]
  85. x = self.unet_generator(noise, timesteps, y, context, attention_mask)
  86. x = self.unet_upsampler_256(noise, timesteps, x,
  87. torch.zeros_like(timesteps), y, context,
  88. attention_mask)
  89. x = self.unet_upsampler_1024(x, t, x)
  90. return x
  91. @MODELS.register_module(
  92. Tasks.text_to_image_synthesis, module_name=Models.diffusion)
  93. class DiffusionForTextToImageSynthesis(Model):
  94. def __init__(self, model_dir, device='gpu', **kwargs):
  95. device = 'gpu' if torch.cuda.is_available() else 'cpu'
  96. super().__init__(model_dir=model_dir, device=device, **kwargs)
  97. diffusion_model = DiffusionModel(model_dir=model_dir)
  98. pretrained_params = torch.load(
  99. osp.join(model_dir, ModelFile.TORCH_MODEL_BIN_FILE), 'cpu')
  100. diffusion_model.load_state_dict(pretrained_params)
  101. diffusion_model.eval().to()
  102. self.device = create_device(device)
  103. diffusion_model.to(self.device)
  104. # modules
  105. self.text_encoder = diffusion_model.text_encoder
  106. self.unet_generator = diffusion_model.unet_generator
  107. self.unet_upsampler_256 = diffusion_model.unet_upsampler_256
  108. self.unet_upsampler_1024 = diffusion_model.unet_upsampler_1024
  109. # text tokenizer
  110. vocab_path = f'{model_dir}/{ModelFile.VOCAB_FILE}'
  111. self.tokenizer = Tokenizer(vocab_file=vocab_path, seq_len=64)
  112. # diffusion process
  113. diffusion_params = json.load(
  114. open(
  115. '{}/diffusion_config.json'.format(model_dir),
  116. encoding='utf-8'))
  117. self.diffusion_generator = make_diffusion(
  118. **diffusion_params['generator_config'])
  119. self.diffusion_upsampler_256 = make_diffusion(
  120. **diffusion_params['upsampler_256_config'])
  121. self.diffusion_upsampler_1024 = make_diffusion(
  122. **diffusion_params['upsampler_1024_config'])
  123. def forward(self, input: Dict[str, Any]) -> Dict[str, Any]:
  124. if not all([key in input for key in ('text', 'noise', 'timesteps')]):
  125. raise ValueError(
  126. f'input should contains "text", "noise", and "timesteps", but got {input.keys()}'
  127. )
  128. input_ids, token_type_ids, attention_mask = self.tokenizer(
  129. input['text'])
  130. input_ids = input_ids.to(self.device).unsqueeze(0)
  131. token_type_ids = token_type_ids.to(self.device).unsqueeze(0)
  132. attention_mask = attention_mask.to(self.device).unsqueeze(0)
  133. context, y = self.text_encoder(
  134. input_ids=input_ids,
  135. token_type_ids=token_type_ids,
  136. attention_mask=attention_mask)
  137. context = context[-1]
  138. x = self.unet_generator(noise, timesteps, y, context, attention_mask)
  139. x = self.unet_upsampler_256(noise, timesteps, x,
  140. torch.zeros_like(timesteps), y, context,
  141. attention_mask)
  142. x = self.unet_upsampler_1024(x, t, x)
  143. img = x.clamp(-1, 1).add(1).mul(127.5)
  144. img = img.squeeze(0).permute(1, 2, 0).cpu().numpy().astype(np.uint8)
  145. return img
  146. def postprocess(self, inputs: Dict[str, Any]) -> Dict[str, Any]:
  147. return inputs
  148. @torch.no_grad()
  149. def generate(self, input: Dict[str, Any]) -> Dict[str, Any]:
  150. if 'text' not in input:
  151. raise ValueError(
  152. f'input should contain "text", but got {input.keys()}')
  153. # encode text
  154. input_ids, token_type_ids, attention_mask = self.tokenizer(
  155. input['text'])
  156. input_ids = input_ids.to(self.device).unsqueeze(0)
  157. token_type_ids = token_type_ids.to(self.device).unsqueeze(0)
  158. attention_mask = attention_mask.to(self.device).unsqueeze(0)
  159. context, y = self.text_encoder(
  160. input_ids=input_ids,
  161. token_type_ids=token_type_ids,
  162. attention_mask=attention_mask)
  163. context = context[-1]
  164. # choose a proper solver
  165. solver = input.get('solver', 'dpm-solver')
  166. if solver == 'dpm-solver':
  167. # generation
  168. img = self.diffusion_generator.dpm_solver_sample_loop(
  169. noise=torch.randn(1, 3, 64, 64).to(self.device),
  170. model=self.unet_generator,
  171. model_kwargs=[{
  172. 'y': y,
  173. 'context': context,
  174. 'mask': attention_mask
  175. }, {
  176. 'y': torch.zeros_like(y),
  177. 'context': torch.zeros_like(context),
  178. 'mask': attention_mask
  179. }],
  180. percentile=input.get('generator_percentile', 0.995),
  181. guide_scale=input.get('generator_guide_scale', 5.0),
  182. dpm_solver_timesteps=input.get('dpm_solver_timesteps', 20),
  183. order=3,
  184. skip_type='logSNR',
  185. method='singlestep',
  186. t_start=0.9946)
  187. # upsampling (64->256)
  188. if not input.get('debug', False):
  189. img = F.interpolate(
  190. img,
  191. scale_factor=4.0,
  192. mode='bilinear',
  193. align_corners=False)
  194. img = self.diffusion_upsampler_256.dpm_solver_sample_loop(
  195. noise=torch.randn_like(img),
  196. model=self.unet_upsampler_256,
  197. model_kwargs=[{
  198. 'lx': img,
  199. 'lt': torch.zeros(1).to(self.device),
  200. 'y': y,
  201. 'context': context,
  202. 'mask': attention_mask
  203. }, {
  204. 'lx': img,
  205. 'lt': torch.zeros(1).to(self.device),
  206. 'y': torch.zeros_like(y),
  207. 'context': torch.zeros_like(context),
  208. 'mask': torch.zeros_like(attention_mask)
  209. }],
  210. percentile=input.get('upsampler_256_percentile', 0.995),
  211. guide_scale=input.get('upsampler_256_guide_scale', 5.0),
  212. dpm_solver_timesteps=input.get('dpm_solver_timesteps', 20),
  213. order=3,
  214. skip_type='logSNR',
  215. method='singlestep',
  216. t_start=0.9946)
  217. # upsampling (256->1024)
  218. if not input.get('debug', False):
  219. img = F.interpolate(
  220. img,
  221. scale_factor=4.0,
  222. mode='bilinear',
  223. align_corners=False)
  224. img = self.diffusion_upsampler_1024.dpm_solver_sample_loop(
  225. noise=torch.randn_like(img),
  226. model=self.unet_upsampler_256,
  227. model_kwargs=[{
  228. 'lx': img,
  229. 'lt': torch.zeros(1).to(self.device),
  230. 'y': y,
  231. 'context': context,
  232. 'mask': attention_mask
  233. }, {
  234. 'lx': img,
  235. 'lt': torch.zeros(1).to(self.device),
  236. 'y': torch.zeros_like(y),
  237. 'context': torch.zeros_like(context),
  238. 'mask': torch.zeros_like(attention_mask)
  239. }],
  240. percentile=input.get('upsampler_256_percentile', 0.995),
  241. guide_scale=input.get('upsampler_256_guide_scale', 5.0),
  242. dpm_solver_timesteps=input.get('dpm_solver_timesteps', 10),
  243. order=3,
  244. skip_type='logSNR',
  245. method='singlestep',
  246. t_start=None)
  247. elif solver == 'ddim':
  248. # generation
  249. img = self.diffusion_generator.ddim_sample_loop(
  250. noise=torch.randn(1, 3, 64, 64).to(self.device),
  251. model=self.unet_generator,
  252. model_kwargs=[{
  253. 'y': y,
  254. 'context': context,
  255. 'mask': attention_mask
  256. }, {
  257. 'y': torch.zeros_like(y),
  258. 'context': torch.zeros_like(context),
  259. 'mask': attention_mask
  260. }],
  261. percentile=input.get('generator_percentile', 0.995),
  262. guide_scale=input.get('generator_guide_scale', 5.0),
  263. ddim_timesteps=input.get('generator_ddim_timesteps', 250),
  264. eta=input.get('generator_ddim_eta', 0.0))
  265. # upsampling (64->256)
  266. if not input.get('debug', False):
  267. img = F.interpolate(
  268. img,
  269. scale_factor=4.0,
  270. mode='bilinear',
  271. align_corners=False)
  272. img = self.diffusion_upsampler_256.ddim_sample_loop(
  273. noise=torch.randn_like(img),
  274. model=self.unet_upsampler_256,
  275. model_kwargs=[{
  276. 'lx': img,
  277. 'lt': torch.zeros(1).to(self.device),
  278. 'y': y,
  279. 'context': context,
  280. 'mask': attention_mask
  281. }, {
  282. 'lx': img,
  283. 'lt': torch.zeros(1).to(self.device),
  284. 'y': torch.zeros_like(y),
  285. 'context': torch.zeros_like(context),
  286. 'mask': torch.zeros_like(attention_mask)
  287. }],
  288. percentile=input.get('upsampler_256_percentile', 0.995),
  289. guide_scale=input.get('upsampler_256_guide_scale', 5.0),
  290. ddim_timesteps=input.get('upsampler_256_ddim_timesteps', 50),
  291. eta=input.get('upsampler_256_ddim_eta', 0.0))
  292. # upsampling (256->1024)
  293. if not input.get('debug', False):
  294. img = F.interpolate(
  295. img,
  296. scale_factor=4.0,
  297. mode='bilinear',
  298. align_corners=False)
  299. img = self.diffusion_upsampler_1024.ddim_sample_loop(
  300. noise=torch.randn_like(img),
  301. model=self.unet_upsampler_1024,
  302. model_kwargs={'concat': img},
  303. percentile=input.get('upsampler_1024_percentile', 0.995),
  304. ddim_timesteps=input.get('upsampler_1024_ddim_timesteps', 20),
  305. eta=input.get('upsampler_1024_ddim_eta', 0.0))
  306. else:
  307. raise ValueError(
  308. 'currently only supports "ddim" and "dpm-solve" solvers')
  309. # output
  310. img = img.clamp(-1, 1).add(1).mul(127.5).squeeze(0).permute(
  311. 1, 2, 0).cpu().numpy().astype(np.uint8)
  312. return img