diffusion.py 26 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657
# Part of the implementation is borrowed and modified from latent-diffusion,
# publicly available at https://github.com/CompVis/latent-diffusion.
# Copyright 2021-2022 The Alibaba Fundamental Vision Team Authors. All rights reserved.
  4. import math
  5. import torch
  6. from modelscope.models.multi_modal.dpm_solver_pytorch import (
  7. DPM_Solver, NoiseScheduleVP, model_wrapper, model_wrapper_guided_diffusion)
  8. __all__ = ['GaussianDiffusion', 'beta_schedule']
  9. def kl_divergence(mu1, logvar1, mu2, logvar2):
  10. a = -1.0 + logvar2 - logvar1 + torch.exp(logvar1 - logvar2)
  11. b = ((mu1 - mu2)**2) * torch.exp(-logvar2)
  12. return 0.5 * (a + b)
  13. def standard_normal_cdf(x):
  14. return 0.5 * (1.0 + torch.tanh(
  15. math.sqrt(2.0 / math.pi) * (x + 0.044715 * torch.pow(x, 3))))
  16. def discretized_gaussian_log_likelihood(x0, mean, log_scale):
  17. assert x0.shape == mean.shape == log_scale.shape
  18. cx = x0 - mean
  19. inv_stdv = torch.exp(-log_scale)
  20. cdf_plus = standard_normal_cdf(inv_stdv * (cx + 1.0 / 255.0))
  21. cdf_min = standard_normal_cdf(inv_stdv * (cx - 1.0 / 255.0))
  22. log_cdf_plus = torch.log(cdf_plus.clamp(min=1e-12))
  23. log_one_minus_cdf_min = torch.log((1.0 - cdf_min).clamp(min=1e-12))
  24. cdf_delta = cdf_plus - cdf_min
  25. log_probs = torch.where(
  26. x0 < -0.999, log_cdf_plus,
  27. torch.where(x0 > 0.999, log_one_minus_cdf_min,
  28. torch.log(cdf_delta.clamp(min=1e-12))))
  29. assert log_probs.shape == x0.shape
  30. return log_probs
  31. def _i(tensor, t, x):
  32. tensor = tensor.to(x.device)
  33. shape = (x.size(0), ) + (1, ) * (x.ndim - 1)
  34. return tensor[t].view(shape).to(x)
  35. def cosine_fn(u):
  36. return math.cos((u + 0.008) / 1.008 * math.pi / 2)**2
  37. def beta_schedule(schedule,
  38. num_timesteps=1000,
  39. init_beta=None,
  40. last_beta=None):
  41. if schedule == 'linear':
  42. scale = 1000.0 / num_timesteps
  43. init_beta = init_beta or scale * 0.0001
  44. last_beta = last_beta or scale * 0.02
  45. return torch.linspace(
  46. init_beta, last_beta, num_timesteps, dtype=torch.float64)
  47. elif schedule == 'quadratic':
  48. init_beta = init_beta or 0.0015
  49. last_beta = last_beta or 0.0195
  50. return torch.linspace(
  51. init_beta**0.5, last_beta**0.5, num_timesteps,
  52. dtype=torch.float64)**2
  53. elif schedule == 'cosine':
  54. betas = []
  55. for step in range(num_timesteps):
  56. t1 = step / num_timesteps
  57. t2 = (step + 1) / num_timesteps
  58. betas.append(min(1.0 - cosine_fn(t2) / cosine_fn(t1), 0.999))
  59. return torch.tensor(betas, dtype=torch.float64)
  60. else:
  61. raise ValueError(f'Unsupported schedule: {schedule}')
  62. class GaussianDiffusion(object):
  63. def __init__(self,
  64. betas,
  65. mean_type='eps',
  66. var_type='learned_range',
  67. loss_type='mse',
  68. rescale_timesteps=False):
  69. # check input
  70. if not isinstance(betas, torch.DoubleTensor):
  71. betas = torch.tensor(betas, dtype=torch.float64)
  72. assert min(betas) > 0 and max(betas) <= 1
  73. assert mean_type in ['x0', 'x_{t-1}', 'eps']
  74. assert var_type in [
  75. 'learned', 'learned_range', 'fixed_large', 'fixed_small'
  76. ]
  77. assert loss_type in [
  78. 'mse', 'rescaled_mse', 'kl', 'rescaled_kl', 'l1', 'rescaled_l1'
  79. ]
  80. self.betas = betas
  81. self.num_timesteps = len(betas)
  82. self.mean_type = mean_type
  83. self.var_type = var_type
  84. self.loss_type = loss_type
  85. self.rescale_timesteps = rescale_timesteps
  86. # alphas
  87. alphas = 1 - self.betas
  88. self.alphas_cumprod = torch.cumprod(alphas, dim=0)
  89. self.alphas_cumprod_prev = torch.cat(
  90. [alphas.new_ones([1]), self.alphas_cumprod[:-1]])
  91. self.alphas_cumprod_next = torch.cat(
  92. [self.alphas_cumprod[1:],
  93. alphas.new_zeros([1])])
  94. # q(x_t | x_{t-1})
  95. self.sqrt_alphas_cumprod = torch.sqrt(self.alphas_cumprod)
  96. self.sqrt_one_minus_alphas_cumprod = torch.sqrt(1.0
  97. - self.alphas_cumprod)
  98. self.log_one_minus_alphas_cumprod = torch.log(1.0
  99. - self.alphas_cumprod)
  100. self.sqrt_recip_alphas_cumprod = torch.sqrt(1.0 / self.alphas_cumprod)
  101. self.sqrt_recipm1_alphas_cumprod = torch.sqrt(1.0 / self.alphas_cumprod
  102. - 1)
  103. # q(x_{t-1} | x_t, x_0)
  104. self.posterior_variance = betas * (1.0 - self.alphas_cumprod_prev) / (
  105. 1.0 - self.alphas_cumprod)
  106. self.posterior_log_variance_clipped = torch.log(
  107. self.posterior_variance.clamp(1e-20))
  108. self.posterior_mean_coef1 = betas * torch.sqrt(
  109. self.alphas_cumprod_prev) / (1.0 - self.alphas_cumprod)
  110. self.posterior_mean_coef2 = (
  111. 1.0 - self.alphas_cumprod_prev) * torch.sqrt(alphas) / (
  112. 1.0 - self.alphas_cumprod)
  113. def q_sample(self, x0, t, noise=None):
  114. noise = torch.randn_like(x0) if noise is None else noise
  115. return _i(self.sqrt_alphas_cumprod, t, x0) * x0 + _i(
  116. self.sqrt_one_minus_alphas_cumprod, t, x0) * noise
  117. def q_mean_variance(self, x0, t):
  118. mu = _i(self.sqrt_alphas_cumprod, t, x0) * x0
  119. var = _i(1.0 - self.alphas_cumprod, t, x0)
  120. log_var = _i(self.log_one_minus_alphas_cumprod, t, x0)
  121. return mu, var, log_var
  122. def q_posterior_mean_variance(self, x0, xt, t):
  123. mu = _i(self.posterior_mean_coef1, t, xt) * x0 + _i(
  124. self.posterior_mean_coef2, t, xt) * xt
  125. var = _i(self.posterior_variance, t, xt)
  126. log_var = _i(self.posterior_log_variance_clipped, t, xt)
  127. return mu, var, log_var
  128. @torch.no_grad()
  129. def p_sample(self,
  130. xt,
  131. t,
  132. model,
  133. model_kwargs={},
  134. clamp=None,
  135. percentile=None,
  136. condition_fn=None,
  137. guide_scale=None):
  138. # predict distribution of p(x_{t-1} | x_t)
  139. mu, var, log_var, x0 = self.p_mean_variance(xt, t, model, model_kwargs,
  140. clamp, percentile,
  141. guide_scale)
  142. # random sample (with optional conditional function)
  143. noise = torch.randn_like(xt)
  144. shape = (-1, ) + ((1, ) * (xt.ndim - 1))
  145. mask = t.ne(0).float().view(*shape) # no noise when t == 0
  146. if condition_fn is not None:
  147. grad = condition_fn(xt, self._scale_timesteps(t), **model_kwargs)
  148. mu = mu.float() + var * grad.float()
  149. xt_1 = mu + mask * torch.exp(0.5 * log_var) * noise
  150. return xt_1, x0
  151. @torch.no_grad()
  152. def p_sample_loop(self,
  153. noise,
  154. model,
  155. model_kwargs={},
  156. clamp=None,
  157. percentile=None,
  158. condition_fn=None,
  159. guide_scale=None):
  160. # prepare input
  161. b, c, h, w = noise.size()
  162. xt = noise
  163. # diffusion process
  164. for step in torch.arange(self.num_timesteps).flip(0):
  165. t = torch.full((b, ), step, dtype=torch.long, device=xt.device)
  166. xt, _ = self.p_sample(xt, t, model, model_kwargs, clamp,
  167. percentile, condition_fn, guide_scale)
  168. return xt
  169. def p_mean_variance(self,
  170. xt,
  171. t,
  172. model,
  173. model_kwargs={},
  174. clamp=None,
  175. percentile=None,
  176. guide_scale=None):
  177. # predict distribution
  178. if guide_scale is None:
  179. out = model(xt, self._scale_timesteps(t), **model_kwargs)
  180. else:
  181. # classifier-free guidance
  182. # (model_kwargs[0]: conditional kwargs; model_kwargs[1]: non-conditional kwargs)
  183. assert isinstance(model_kwargs, list) and len(model_kwargs) == 2
  184. assert self.mean_type == 'eps'
  185. y_out = model(xt, self._scale_timesteps(t), **model_kwargs[0])
  186. u_out = model(xt, self._scale_timesteps(t), **model_kwargs[1])
  187. a = u_out[:, :3]
  188. b = guide_scale * (y_out[:, :3] - u_out[:, :3])
  189. c = y_out[:, 3:]
  190. out = torch.cat([a + b, c], dim=1)
  191. # compute variance
  192. if self.var_type == 'learned':
  193. out, log_var = out.chunk(2, dim=1)
  194. var = torch.exp(log_var)
  195. elif self.var_type == 'learned_range':
  196. out, fraction = out.chunk(2, dim=1)
  197. min_log_var = _i(self.posterior_log_variance_clipped, t, xt)
  198. max_log_var = _i(torch.log(self.betas), t, xt)
  199. fraction = (fraction + 1) / 2.0
  200. log_var = fraction * max_log_var + (1 - fraction) * min_log_var
  201. var = torch.exp(log_var)
  202. elif self.var_type == 'fixed_large':
  203. var = _i(
  204. torch.cat([self.posterior_variance[1:2], self.betas[1:]]), t,
  205. xt)
  206. log_var = torch.log(var)
  207. elif self.var_type == 'fixed_small':
  208. var = _i(self.posterior_variance, t, xt)
  209. log_var = _i(self.posterior_log_variance_clipped, t, xt)
  210. # compute mean and x0
  211. if self.mean_type == 'x_{t-1}':
  212. mu = out # x_{t-1}
  213. x0 = _i(1.0 / self.posterior_mean_coef1, t, xt) * mu - _i(
  214. self.posterior_mean_coef2 / self.posterior_mean_coef1, t,
  215. xt) * xt
  216. elif self.mean_type == 'x0':
  217. x0 = out
  218. mu, _, _ = self.q_posterior_mean_variance(x0, xt, t)
  219. elif self.mean_type == 'eps':
  220. x0 = _i(self.sqrt_recip_alphas_cumprod, t, xt) * xt - _i(
  221. self.sqrt_recipm1_alphas_cumprod, t, xt) * out
  222. mu, _, _ = self.q_posterior_mean_variance(x0, xt, t)
  223. # restrict the range of x0
  224. if percentile is not None:
  225. assert percentile > 0 and percentile <= 1 # e.g., 0.995
  226. s = torch.quantile(
  227. x0.flatten(1).abs(), percentile,
  228. dim=1).clamp_(1.0).view(-1, 1, 1, 1)
  229. x0 = torch.min(s, torch.max(-s, x0)) / s
  230. elif clamp is not None:
  231. x0 = x0.clamp(-clamp, clamp)
  232. return mu, var, log_var, x0
  233. @torch.no_grad()
  234. def dpm_solver_sample_loop(self,
  235. noise,
  236. model,
  237. skip_type,
  238. order,
  239. method,
  240. model_kwargs={},
  241. clamp=None,
  242. percentile=None,
  243. condition_fn=None,
  244. guide_scale=None,
  245. dpm_solver_timesteps=20,
  246. t_start=None,
  247. t_end=None,
  248. lower_order_final=True,
  249. denoise_to_zero=False,
  250. solver_type='dpm_solver'):
  251. r"""Sample using DPM-Solver-based method.
  252. - condition_fn: for classifier-based guidance (guided-diffusion).
  253. - guide_scale: for classifier-free guidance (glide/dalle-2).
  254. Please check all the parameters in `dpm_solver.sample` before using.
  255. """
  256. noise_schedule = NoiseScheduleVP(
  257. schedule='discrete', betas=self.betas.float())
  258. model_fn = model_wrapper_guided_diffusion(
  259. model=model,
  260. noise_schedule=noise_schedule,
  261. var_type=self.var_type,
  262. mean_type=self.mean_type,
  263. model_kwargs=model_kwargs,
  264. clamp=clamp,
  265. percentile=percentile,
  266. rescale_timesteps=self.rescale_timesteps,
  267. num_timesteps=self.num_timesteps,
  268. guide_scale=guide_scale,
  269. condition_fn=condition_fn,
  270. )
  271. dpm_solver = DPM_Solver(
  272. model_fn=model_fn,
  273. noise_schedule=noise_schedule,
  274. )
  275. xt = dpm_solver.sample(
  276. noise,
  277. steps=dpm_solver_timesteps,
  278. order=order,
  279. skip_type=skip_type,
  280. method=method,
  281. solver_type=solver_type,
  282. t_start=t_start,
  283. t_end=t_end,
  284. lower_order_final=lower_order_final,
  285. denoise_to_zero=denoise_to_zero)
  286. return xt
  287. @torch.no_grad()
  288. def ddim_sample(self,
  289. xt,
  290. t,
  291. model,
  292. model_kwargs={},
  293. clamp=None,
  294. percentile=None,
  295. condition_fn=None,
  296. guide_scale=None,
  297. ddim_timesteps=20,
  298. eta=0.0):
  299. stride = self.num_timesteps // ddim_timesteps
  300. # predict distribution of p(x_{t-1} | x_t)
  301. _, _, _, x0 = self.p_mean_variance(xt, t, model, model_kwargs, clamp,
  302. percentile, guide_scale)
  303. if condition_fn is not None:
  304. # x0 -> eps
  305. alpha = _i(self.alphas_cumprod, t, xt)
  306. eps = (_i(self.sqrt_recip_alphas_cumprod, t, xt) * xt - x0) / _i(
  307. self.sqrt_recipm1_alphas_cumprod, t, xt)
  308. eps = eps - (1 - alpha).sqrt() * condition_fn(
  309. xt, self._scale_timesteps(t), **model_kwargs)
  310. # eps -> x0
  311. x0 = _i(self.sqrt_recip_alphas_cumprod, t, xt) * xt - _i(
  312. self.sqrt_recipm1_alphas_cumprod, t, xt) * eps
  313. # derive variables
  314. eps = (_i(self.sqrt_recip_alphas_cumprod, t, xt) * xt - x0) / _i(
  315. self.sqrt_recipm1_alphas_cumprod, t, xt)
  316. alphas = _i(self.alphas_cumprod, t, xt)
  317. alphas_prev = _i(self.alphas_cumprod, (t - stride).clamp(0), xt)
  318. a = (1 - alphas_prev) / (1 - alphas)
  319. b = (1 - alphas / alphas_prev)
  320. sigmas = eta * torch.sqrt(a * b)
  321. # random sample
  322. noise = torch.randn_like(xt)
  323. direction = torch.sqrt(1 - alphas_prev - sigmas**2) * eps
  324. mask = t.ne(0).float().view(-1, *((1, ) * (xt.ndim - 1)))
  325. xt_1 = torch.sqrt(alphas_prev) * x0 + direction + mask * sigmas * noise
  326. return xt_1, x0
  327. @torch.no_grad()
  328. def ddim_sample_loop(self,
  329. noise,
  330. model,
  331. model_kwargs={},
  332. clamp=None,
  333. percentile=None,
  334. condition_fn=None,
  335. guide_scale=None,
  336. ddim_timesteps=20,
  337. eta=0.0):
  338. # prepare input
  339. b, c, h, w = noise.size()
  340. xt = noise
  341. # diffusion process (TODO: clamp is inaccurate! Consider replacing the stride by explicit prev/next steps)
  342. steps = (1 + torch.arange(0, self.num_timesteps,
  343. self.num_timesteps // ddim_timesteps)).clamp(
  344. 0, self.num_timesteps - 1).flip(0)
  345. for step in steps:
  346. t = torch.full((b, ), step, dtype=torch.long, device=xt.device)
  347. xt, _ = self.ddim_sample(xt, t, model, model_kwargs, clamp,
  348. percentile, condition_fn, guide_scale,
  349. ddim_timesteps, eta)
  350. return xt
  351. @torch.no_grad()
  352. def ddim_reverse_sample(self,
  353. xt,
  354. t,
  355. model,
  356. model_kwargs={},
  357. clamp=None,
  358. percentile=None,
  359. guide_scale=None,
  360. ddim_timesteps=20):
  361. stride = self.num_timesteps // ddim_timesteps
  362. # predict distribution of p(x_{t-1} | x_t)
  363. _, _, _, x0 = self.p_mean_variance(xt, t, model, model_kwargs, clamp,
  364. percentile, guide_scale)
  365. # derive variables
  366. eps = (_i(self.sqrt_recip_alphas_cumprod, t, xt) * xt - x0) / _i(
  367. self.sqrt_recipm1_alphas_cumprod, t, xt)
  368. alphas_next = _i(
  369. torch.cat(
  370. [self.alphas_cumprod,
  371. self.alphas_cumprod.new_zeros([1])]),
  372. (t + stride).clamp(0, self.num_timesteps), xt)
  373. # reverse sample
  374. mu = torch.sqrt(alphas_next) * x0 + torch.sqrt(1 - alphas_next) * eps
  375. return mu, x0
  376. @torch.no_grad()
  377. def ddim_reverse_sample_loop(self,
  378. x0,
  379. model,
  380. model_kwargs={},
  381. clamp=None,
  382. percentile=None,
  383. guide_scale=None,
  384. ddim_timesteps=20):
  385. # prepare input
  386. b, c, h, w = x0.size()
  387. xt = x0
  388. # reconstruction steps
  389. steps = torch.arange(0, self.num_timesteps,
  390. self.num_timesteps // ddim_timesteps)
  391. for step in steps:
  392. t = torch.full((b, ), step, dtype=torch.long, device=xt.device)
  393. xt, _ = self.ddim_reverse_sample(xt, t, model, model_kwargs, clamp,
  394. percentile, guide_scale,
  395. ddim_timesteps)
  396. return xt
  397. @torch.no_grad()
  398. def plms_sample(self,
  399. xt,
  400. t,
  401. model,
  402. model_kwargs={},
  403. clamp=None,
  404. percentile=None,
  405. condition_fn=None,
  406. guide_scale=None,
  407. plms_timesteps=20):
  408. stride = self.num_timesteps // plms_timesteps
  409. # function for compute eps
  410. def compute_eps(xt, t):
  411. # predict distribution of p(x_{t-1} | x_t)
  412. _, _, _, x0 = self.p_mean_variance(xt, t, model, model_kwargs,
  413. clamp, percentile, guide_scale)
  414. # condition
  415. if condition_fn is not None:
  416. # x0 -> eps
  417. alpha = _i(self.alphas_cumprod, t, xt)
  418. eps = (_i(self.sqrt_recip_alphas_cumprod, t, xt) * xt
  419. - x0) / _i(self.sqrt_recipm1_alphas_cumprod, t, xt)
  420. eps = eps - (1 - alpha).sqrt() * condition_fn(
  421. xt, self._scale_timesteps(t), **model_kwargs)
  422. # eps -> x0
  423. x0 = _i(self.sqrt_recip_alphas_cumprod, t, xt) * xt - _i(
  424. self.sqrt_recipm1_alphas_cumprod, t, xt) * eps
  425. # derive eps
  426. eps = (_i(self.sqrt_recip_alphas_cumprod, t, xt) * xt - x0) / _i(
  427. self.sqrt_recipm1_alphas_cumprod, t, xt)
  428. return eps
  429. # function for compute x_0 and x_{t-1}
  430. def compute_x0(eps, t):
  431. # eps -> x0
  432. x0 = _i(self.sqrt_recip_alphas_cumprod, t, xt) * xt - _i(
  433. self.sqrt_recipm1_alphas_cumprod, t, xt) * eps
  434. # deterministic sample
  435. alphas_prev = _i(self.alphas_cumprod, (t - stride).clamp(0), xt)
  436. direction = torch.sqrt(1 - alphas_prev) * eps
  437. xt_1 = torch.sqrt(alphas_prev) * x0 + direction
  438. return xt_1, x0
  439. # PLMS sample
  440. eps = compute_eps(xt, t)
  441. if len(eps_cache) == 0:
  442. # 2nd order pseudo improved Euler
  443. xt_1, x0 = compute_x0(eps, t)
  444. eps_next = compute_eps(xt_1, (t - stride).clamp(0))
  445. eps_prime = (eps + eps_next) / 2.0
  446. elif len(eps_cache) == 1:
  447. # 2nd order pseudo linear multistep (Adams-Bashforth)
  448. eps_prime = (3 * eps - eps_cache[-1]) / 2.0
  449. elif len(eps_cache) == 2:
  450. # 3rd order pseudo linear multistep (Adams-Bashforth)
  451. eps_prime = (23 * eps - 16 * eps_cache[-1]
  452. + 5 * eps_cache[-2]) / 12.0
  453. elif len(eps_cache) >= 3:
  454. # 4nd order pseudo linear multistep (Adams-Bashforth)
  455. eps_prime = (55 * eps - 59 * eps_cache[-1] + 37 * eps_cache[-2]
  456. - 9 * eps_cache[-3]) / 24.0
  457. xt_1, x0 = compute_x0(eps_prime, t)
  458. return xt_1, x0, eps
  459. @torch.no_grad()
  460. def plms_sample_loop(self,
  461. noise,
  462. model,
  463. model_kwargs={},
  464. clamp=None,
  465. percentile=None,
  466. condition_fn=None,
  467. guide_scale=None,
  468. plms_timesteps=20):
  469. # prepare input
  470. b, c, h, w = noise.size()
  471. xt = noise
  472. # diffusion process
  473. steps = (1 + torch.arange(0, self.num_timesteps,
  474. self.num_timesteps // plms_timesteps)).clamp(
  475. 0, self.num_timesteps - 1).flip(0)
  476. eps_cache = []
  477. for step in steps:
  478. # PLMS sampling step
  479. t = torch.full((b, ), step, dtype=torch.long, device=xt.device)
  480. xt, _, eps = self.plms_sample(xt, t, model, model_kwargs, clamp,
  481. percentile, condition_fn,
  482. guide_scale, plms_timesteps,
  483. eps_cache)
  484. # update eps cache
  485. eps_cache.append(eps)
  486. if len(eps_cache) >= 4:
  487. eps_cache.pop(0)
  488. return xt
  489. def loss(self, x0, t, model, model_kwargs={}, noise=None):
  490. noise = torch.randn_like(x0) if noise is None else noise
  491. xt = self.q_sample(x0, t, noise=noise)
  492. # compute loss
  493. if self.loss_type in ['kl', 'rescaled_kl']:
  494. loss, _ = self.variational_lower_bound(x0, xt, t, model,
  495. model_kwargs)
  496. if self.loss_type == 'rescaled_kl':
  497. loss = loss * self.num_timesteps
  498. elif self.loss_type in ['mse', 'rescaled_mse', 'l1', 'rescaled_l1']:
  499. out = model(xt, self._scale_timesteps(t), **model_kwargs)
  500. # VLB for variation
  501. loss_vlb = 0.0
  502. if self.var_type in ['learned', 'learned_range']:
  503. out, var = out.chunk(2, dim=1)
  504. frozen = torch.cat([
  505. out.detach(), var
  506. ], dim=1) # learn var without affecting the prediction of mean
  507. loss_vlb, _ = self.variational_lower_bound(
  508. x0, xt, t, model=lambda *args, **kwargs: frozen)
  509. if self.loss_type.startswith('rescaled_'):
  510. loss_vlb = loss_vlb * self.num_timesteps / 1000.0
  511. # MSE/L1 for x0/eps
  512. target = {
  513. 'eps': noise,
  514. 'x0': x0,
  515. 'x_{t-1}': self.q_posterior_mean_variance(x0, xt, t)[0]
  516. }[self.mean_type]
  517. loss = (out - target).pow(1 if self.loss_type.endswith('l1') else 2
  518. ).abs().flatten(1).mean(dim=1)
  519. # total loss
  520. loss = loss + loss_vlb
  521. return loss
  522. def variational_lower_bound(self,
  523. x0,
  524. xt,
  525. t,
  526. model,
  527. model_kwargs={},
  528. clamp=None,
  529. percentile=None):
  530. # compute groundtruth and predicted distributions
  531. mu1, _, log_var1 = self.q_posterior_mean_variance(x0, xt, t)
  532. mu2, _, log_var2, x0 = self.p_mean_variance(xt, t, model, model_kwargs,
  533. clamp, percentile)
  534. # compute KL loss
  535. kl = kl_divergence(mu1, log_var1, mu2, log_var2)
  536. kl = kl.flatten(1).mean(dim=1) / math.log(2.0)
  537. # compute discretized NLL loss (for p(x0 | x1) only)
  538. nll = -discretized_gaussian_log_likelihood(
  539. x0, mean=mu2, log_scale=0.5 * log_var2)
  540. nll = nll.flatten(1).mean(dim=1) / math.log(2.0)
  541. # NLL for p(x0 | x1) and KL otherwise
  542. vlb = torch.where(t == 0, nll, kl)
  543. return vlb, x0
  544. @torch.no_grad()
  545. def variational_lower_bound_loop(self,
  546. x0,
  547. model,
  548. model_kwargs={},
  549. clamp=None,
  550. percentile=None):
  551. # prepare input and output
  552. b, c, h, w = x0.size()
  553. metrics = {'vlb': [], 'mse': [], 'x0_mse': []}
  554. # loop
  555. for step in torch.arange(self.num_timesteps).flip(0):
  556. # compute VLB
  557. t = torch.full((b, ), step, dtype=torch.long, device=x0.device)
  558. noise = torch.randn_like(x0)
  559. xt = self.q_sample(x0, t, noise)
  560. vlb, pred_x0 = self.variational_lower_bound(
  561. x0, xt, t, model, model_kwargs, clamp, percentile)
  562. # predict eps from x0
  563. eps = (_i(self.sqrt_recip_alphas_cumprod, t, xt) * xt - x0) / _i(
  564. self.sqrt_recipm1_alphas_cumprod, t, xt)
  565. # collect metrics
  566. metrics['vlb'].append(vlb)
  567. metrics['x0_mse'].append(
  568. (pred_x0 - x0).square().flatten(1).mean(dim=1))
  569. metrics['mse'].append(
  570. (eps - noise).square().flatten(1).mean(dim=1))
  571. metrics = {k: torch.stack(v, dim=1) for k, v in metrics.items()}
  572. # compute the prior KL term for VLB, measured in bits-per-dim
  573. mu, _, log_var = self.q_mean_variance(x0, t)
  574. kl_prior = kl_divergence(mu, log_var, torch.zeros_like(mu),
  575. torch.zeros_like(log_var))
  576. kl_prior = kl_prior.flatten(1).mean(dim=1) / math.log(2.0)
  577. # update metrics
  578. metrics['prior_bits_per_dim'] = kl_prior
  579. metrics['total_bits_per_dim'] = metrics['vlb'].sum(dim=1) + kl_prior
  580. return metrics
  581. def _scale_timesteps(self, t):
  582. if self.rescale_timesteps:
  583. return t.float() * 1000.0 / self.num_timesteps
  584. return t