diffusion.py 61 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156115711581159116011611162116311641165116611671168116911701171117211731174117511761177117811791180118111821183118411851186118711881189119011911192119311941195119611971198119912001201120212031204120512061207120812091210121112121213121412151216121712181219122012211222122312241225122612271228122912301231123212331234123512361237123812391240124112421243124412451246124712481249125012511252125312541255125612571258125912601261126212631264126512661267126812691270127112721273127412751276127
7127812791280128112821283128412851286128712881289129012911292129312941295129612971298129913001301130213031304130513061307130813091310131113121313131413151316131713181319132013211322132313241325132613271328132913301331133213331334133513361337133813391340134113421343134413451346134713481349135013511352135313541355135613571358135913601361136213631364136513661367136813691370137113721373137413751376137713781379138013811382138313841385138613871388138913901391139213931394139513961397139813991400140114021403140414051406140714081409141014111412141314141415141614171418141914201421142214231424142514261427142814291430143114321433143414351436143714381439144014411442144314441445144614471448144914501451145214531454145514561457145814591460146114621463146414651466146714681469147014711472147314741475147614771478147914801481148214831484148514861487148814891490149114921493149414951496149714981499150015011502150315041505150615071508150915101511151215131514
  1. # Copyright (c) Alibaba, Inc. and its affiliates.
  2. import math
  3. import torch
  4. from .dpm_solver import (DPM_Solver, NoiseScheduleVP, model_wrapper,
  5. model_wrapper_guided_diffusion)
  6. from .ops.losses import discretized_gaussian_log_likelihood, kl_divergence
  7. __all__ = ['GaussianDiffusion', 'beta_schedule', 'GaussianDiffusion_style']
  8. def _i(tensor, t, x):
  9. r"""Index tensor using t and format the output according to x.
  10. """
  11. shape = (x.size(0), ) + (1, ) * (x.ndim - 1)
  12. if tensor.device != x.device:
  13. tensor = tensor.to(x.device)
  14. return tensor[t].view(shape).to(x)
  15. def beta_schedule(schedule,
  16. num_timesteps=1000,
  17. init_beta=None,
  18. last_beta=None):
  19. '''
  20. This code defines a function beta_schedule that generates a sequence of beta
  21. values based on the given input parameters.
  22. These beta values can be used in video diffusion processes. The function has the following parameters:
  23. schedule(str): Determines the type of beta schedule to be generated.
  24. It can be 'linear', 'linear_sd', 'quadratic', or 'cosine'.
  25. num_timesteps(int, optional): The number of timesteps for the generated beta schedule. Default is 1000.
  26. init_beta(float, optional): The initial beta value.
  27. If not provided, a default value is used based on the chosen schedule.
  28. last_beta(float, optional): The final beta value.
  29. If not provided, a default value is used based on the chosen schedule.
  30. The function returns a PyTorch tensor containing the generated beta values.
  31. The beta schedule is determined by the schedule parameter:
  32. 1.Linear: Generates a linear sequence of beta values betweeninit_betaandlast_beta.
  33. 2.Linear_sd: Generates a linear sequence of beta values between the square root of
  34. init_beta and the square root oflast_beta, and then squares the result.
  35. 3.Quadratic: Similar to the 'linear_sd' schedule, but with different default values forinit_betaandlast_beta.
  36. 4.Cosine: Generates a sequence of beta values based on a cosine function,
  37. ensuring the values are between 0 and 0.999.
  38. If an unsupported schedule is provided, a ValueError is raised with a message indicating the issue.
  39. '''
  40. if schedule == 'linear':
  41. scale = 1000.0 / num_timesteps
  42. init_beta = init_beta or scale * 0.0001
  43. last_beta = last_beta or scale * 0.02
  44. return torch.linspace(
  45. init_beta, last_beta, num_timesteps, dtype=torch.float64)
  46. elif schedule == 'linear_sd':
  47. return torch.linspace(
  48. init_beta**0.5, last_beta**0.5, num_timesteps,
  49. dtype=torch.float64)**2
  50. elif schedule == 'quadratic':
  51. init_beta = init_beta or 0.0015
  52. last_beta = last_beta or 0.0195
  53. return torch.linspace(
  54. init_beta**0.5, last_beta**0.5, num_timesteps,
  55. dtype=torch.float64)**2
  56. elif schedule == 'cosine':
  57. betas = []
  58. for step in range(num_timesteps):
  59. t1 = step / num_timesteps
  60. t2 = (step + 1) / num_timesteps
  61. fn = lambda u: math.cos( # noqa
  62. (u + 0.008) / 1.008 * math.pi / 2)**2 # noqa
  63. betas.append(min(1.0 - fn(t2) / fn(t1), 0.999))
  64. return torch.tensor(betas, dtype=torch.float64)
  65. else:
  66. raise ValueError(f'Unsupported schedule: {schedule}')
  67. def load_stable_diffusion_pretrained(state_dict, temporal_attention):
  68. import collections
  69. sd_new = collections.OrderedDict()
  70. keys = list(state_dict.keys())
  71. for k in keys:
  72. if k.find('diffusion_model') >= 0:
  73. k_new = k.split('diffusion_model.')[-1]
  74. if k_new in [
  75. 'input_blocks.3.0.op.weight', 'input_blocks.3.0.op.bias',
  76. 'input_blocks.6.0.op.weight', 'input_blocks.6.0.op.bias',
  77. 'input_blocks.9.0.op.weight', 'input_blocks.9.0.op.bias'
  78. ]:
  79. k_new = k_new.replace('0.op', 'op')
  80. if temporal_attention:
  81. if k_new.find('middle_block.2') >= 0:
  82. k_new = k_new.replace('middle_block.2', 'middle_block.3')
  83. if k_new.find('output_blocks.5.2') >= 0:
  84. k_new = k_new.replace('output_blocks.5.2',
  85. 'output_blocks.5.3')
  86. if k_new.find('output_blocks.8.2') >= 0:
  87. k_new = k_new.replace('output_blocks.8.2',
  88. 'output_blocks.8.3')
  89. sd_new[k_new] = state_dict[k]
  90. return sd_new
  91. class AddGaussianNoise(object):
  92. def __init__(self, mean=0., std=0.1):
  93. self.std = std
  94. self.mean = mean
  95. def __call__(self, img):
  96. assert isinstance(img, torch.Tensor)
  97. dtype = img.dtype
  98. if not img.is_floating_point():
  99. img = img.to(torch.float32)
  100. out = img + self.std * torch.randn_like(img) + self.mean
  101. if out.dtype != dtype:
  102. out = out.to(dtype)
  103. return out
  104. def __repr__(self):
  105. return self.__class__.__name__ + '(mean={0}, std={1})'.format(
  106. self.mean, self.std)
  107. class GaussianDiffusion(object):
  108. def __init__(self,
  109. betas,
  110. mean_type='eps',
  111. var_type='learned_range',
  112. loss_type='mse',
  113. epsilon=1e-12,
  114. rescale_timesteps=False):
  115. # check input
  116. if not isinstance(betas, torch.DoubleTensor):
  117. betas = torch.tensor(betas, dtype=torch.float64)
  118. assert min(betas) > 0 and max(betas) <= 1
  119. assert mean_type in ['x0', 'x_{t-1}', 'eps']
  120. assert var_type in [
  121. 'learned', 'learned_range', 'fixed_large', 'fixed_small'
  122. ]
  123. assert loss_type in [
  124. 'mse', 'rescaled_mse', 'kl', 'rescaled_kl', 'l1', 'rescaled_l1',
  125. 'charbonnier'
  126. ]
  127. self.betas = betas
  128. self.num_timesteps = len(betas)
  129. self.mean_type = mean_type # eps
  130. self.var_type = var_type # 'fixed_small'
  131. self.loss_type = loss_type # mse
  132. self.epsilon = epsilon # 1e-12
  133. self.rescale_timesteps = rescale_timesteps # False
  134. # alphas
  135. alphas = 1 - self.betas
  136. self.alphas_cumprod = torch.cumprod(alphas, dim=0)
  137. self.alphas_cumprod_prev = torch.cat(
  138. [alphas.new_ones([1]), self.alphas_cumprod[:-1]])
  139. self.alphas_cumprod_next = torch.cat(
  140. [self.alphas_cumprod[1:],
  141. alphas.new_zeros([1])])
  142. # q(x_t | x_{t-1})
  143. self.sqrt_alphas_cumprod = torch.sqrt(self.alphas_cumprod)
  144. self.sqrt_one_minus_alphas_cumprod = torch.sqrt(1.0
  145. - self.alphas_cumprod)
  146. self.log_one_minus_alphas_cumprod = torch.log(1.0
  147. - self.alphas_cumprod)
  148. self.sqrt_recip_alphas_cumprod = torch.sqrt(1.0 / self.alphas_cumprod)
  149. self.sqrt_recipm1_alphas_cumprod = torch.sqrt(1.0 / self.alphas_cumprod
  150. - 1)
  151. # q(x_{t-1} | x_t, x_0)
  152. self.posterior_variance = betas * (1.0 - self.alphas_cumprod_prev) / (
  153. 1.0 - self.alphas_cumprod)
  154. self.posterior_log_variance_clipped = torch.log(
  155. self.posterior_variance.clamp(1e-20))
  156. self.posterior_mean_coef1 = betas * torch.sqrt(
  157. self.alphas_cumprod_prev) / (1.0 - self.alphas_cumprod)
  158. self.posterior_mean_coef2 = (
  159. 1.0 - self.alphas_cumprod_prev) * torch.sqrt(alphas) / (
  160. 1.0 - self.alphas_cumprod)
  161. def q_sample(self, x0, t, noise=None):
  162. r"""Sample from q(x_t | x_0).
  163. """
  164. noise = torch.randn_like(x0) if noise is None else noise
  165. return _i(self.sqrt_alphas_cumprod, t, x0) * x0 + \
  166. _i(self.sqrt_one_minus_alphas_cumprod, t, x0) * noise # noqa
  167. def q_mean_variance(self, x0, t):
  168. r"""Distribution of q(x_t | x_0).
  169. """
  170. mu = _i(self.sqrt_alphas_cumprod, t, x0) * x0
  171. var = _i(1.0 - self.alphas_cumprod, t, x0)
  172. log_var = _i(self.log_one_minus_alphas_cumprod, t, x0)
  173. return mu, var, log_var
  174. def q_posterior_mean_variance(self, x0, xt, t):
  175. r"""Distribution of q(x_{t-1} | x_t, x_0).
  176. """
  177. mu = _i(self.posterior_mean_coef1, t, xt) * x0 + _i(
  178. self.posterior_mean_coef2, t, xt) * xt
  179. var = _i(self.posterior_variance, t, xt)
  180. log_var = _i(self.posterior_log_variance_clipped, t, xt)
  181. return mu, var, log_var
  182. @torch.no_grad()
  183. def p_sample(self,
  184. xt,
  185. t,
  186. model,
  187. model_kwargs={},
  188. clamp=None,
  189. percentile=None,
  190. condition_fn=None,
  191. guide_scale=None):
  192. r"""Sample from p(x_{t-1} | x_t).
  193. - condition_fn: for classifier-based guidance (guided-diffusion).
  194. - guide_scale: for classifier-free guidance (glide/dalle-2).
  195. """
  196. # predict distribution of p(x_{t-1} | x_t)
  197. mu, var, log_var, x0 = self.p_mean_variance(xt, t, model, model_kwargs,
  198. clamp, percentile,
  199. guide_scale)
  200. # random sample (with optional conditional function)
  201. noise = torch.randn_like(xt)
  202. mask = t.ne(0).float().view(
  203. -1,
  204. *((1, ) * # noqa
  205. (xt.ndim - 1)))
  206. if condition_fn is not None:
  207. grad = condition_fn(xt, self._scale_timesteps(t), **model_kwargs)
  208. mu = mu.float() + var * grad.float()
  209. xt_1 = mu + mask * torch.exp(0.5 * log_var) * noise
  210. return xt_1, x0
  211. @torch.no_grad()
  212. def p_sample_loop(self,
  213. noise,
  214. model,
  215. model_kwargs={},
  216. clamp=None,
  217. percentile=None,
  218. condition_fn=None,
  219. guide_scale=None):
  220. r"""Sample from p(x_{t-1} | x_t) p(x_{t-2} | x_{t-1}) ... p(x_0 | x_1).
  221. """
  222. # prepare input
  223. b = noise.size(0)
  224. xt = noise
  225. # diffusion process
  226. for step in torch.arange(self.num_timesteps).flip(0):
  227. t = torch.full((b, ), step, dtype=torch.long, device=xt.device)
  228. xt, _ = self.p_sample(xt, t, model, model_kwargs, clamp,
  229. percentile, condition_fn, guide_scale)
  230. return xt
  231. def p_mean_variance(self,
  232. xt,
  233. t,
  234. model,
  235. model_kwargs={},
  236. clamp=None,
  237. percentile=None,
  238. guide_scale=None):
  239. r"""Distribution of p(x_{t-1} | x_t).
  240. """
  241. # predict distribution
  242. if guide_scale is None:
  243. out = model(xt, self._scale_timesteps(t), **model_kwargs)
  244. else:
  245. # classifier-free guidance
  246. # (model_kwargs[0]: conditional kwargs; model_kwargs[1]: non-conditional kwargs)
  247. assert isinstance(model_kwargs, list) and len(model_kwargs) == 2
  248. y_out = model(xt, self._scale_timesteps(t), **model_kwargs[0])
  249. u_out = model(xt, self._scale_timesteps(t), **model_kwargs[1])
  250. dim = y_out.size(1) if self.var_type.startswith(
  251. 'fixed') else y_out.size(1) // 2
  252. out = torch.cat(
  253. [
  254. u_out[:, :dim] + guide_scale * # noqa
  255. (y_out[:, :dim] - u_out[:, :dim]),
  256. y_out[:, dim:]
  257. ],
  258. dim=1) # noqa
  259. # compute variance
  260. if self.var_type == 'learned':
  261. out, log_var = out.chunk(2, dim=1)
  262. var = torch.exp(log_var)
  263. elif self.var_type == 'learned_range':
  264. out, fraction = out.chunk(2, dim=1)
  265. min_log_var = _i(self.posterior_log_variance_clipped, t, xt)
  266. max_log_var = _i(torch.log(self.betas), t, xt)
  267. fraction = (fraction + 1) / 2.0
  268. log_var = fraction * max_log_var + (1 - fraction) * min_log_var
  269. var = torch.exp(log_var)
  270. elif self.var_type == 'fixed_large':
  271. var = _i(
  272. torch.cat([self.posterior_variance[1:2], self.betas[1:]]), t,
  273. xt)
  274. log_var = torch.log(var)
  275. elif self.var_type == 'fixed_small':
  276. var = _i(self.posterior_variance, t, xt)
  277. log_var = _i(self.posterior_log_variance_clipped, t, xt)
  278. # compute mean and x0
  279. if self.mean_type == 'x_{t-1}':
  280. mu = out # x_{t-1}
  281. x0 = _i(1.0 / self.posterior_mean_coef1, t, xt) * mu - \
  282. _i(self.posterior_mean_coef2 / self.posterior_mean_coef1, t, xt) * xt # noqa
  283. elif self.mean_type == 'x0':
  284. x0 = out
  285. mu, _, _ = self.q_posterior_mean_variance(x0, xt, t)
  286. elif self.mean_type == 'eps':
  287. x0 = _i(self.sqrt_recip_alphas_cumprod, t, xt) * xt - \
  288. _i(self.sqrt_recipm1_alphas_cumprod, t, xt) * out # noqa
  289. mu, _, _ = self.q_posterior_mean_variance(x0, xt, t)
  290. # restrict the range of x0
  291. if percentile is not None:
  292. assert percentile > 0 and percentile <= 1 # e.g., 0.995
  293. s = torch.quantile(
  294. x0.flatten(1).abs(), percentile,
  295. dim=1).clamp_(1.0).view(-1, 1, 1, 1)
  296. x0 = torch.min(s, torch.max(-s, x0)) / s
  297. elif clamp is not None:
  298. x0 = x0.clamp(-clamp, clamp)
  299. return mu, var, log_var, x0
  300. @torch.no_grad()
  301. def ddim_sample(self,
  302. xt,
  303. t,
  304. model,
  305. model_kwargs={},
  306. clamp=None,
  307. percentile=None,
  308. condition_fn=None,
  309. guide_scale=None,
  310. ddim_timesteps=20,
  311. eta=0.0):
  312. r"""Sample from p(x_{t-1} | x_t) using DDIM.
  313. - condition_fn: for classifier-based guidance (guided-diffusion).
  314. - guide_scale: for classifier-free guidance (glide/dalle-2).
  315. """
  316. stride = self.num_timesteps // ddim_timesteps
  317. # predict distribution of p(x_{t-1} | x_t)
  318. _, _, _, x0 = self.p_mean_variance(xt, t, model, model_kwargs, clamp,
  319. percentile, guide_scale)
  320. if condition_fn is not None:
  321. # x0 -> eps
  322. alpha = _i(self.alphas_cumprod, t, xt)
  323. eps = (_i(self.sqrt_recip_alphas_cumprod, t, xt) * xt - x0) / \
  324. _i(self.sqrt_recipm1_alphas_cumprod, t, xt) # noqa
  325. eps = eps - (1 - alpha).sqrt() * condition_fn(
  326. xt, self._scale_timesteps(t), **model_kwargs)
  327. # eps -> x0
  328. x0 = _i(self.sqrt_recip_alphas_cumprod, t, xt) * xt - \
  329. _i(self.sqrt_recipm1_alphas_cumprod, t, xt) * eps # noqa
  330. # derive variables
  331. eps = (_i(self.sqrt_recip_alphas_cumprod, t, xt) * xt - x0) / \
  332. _i(self.sqrt_recipm1_alphas_cumprod, t, xt) # noqa
  333. alphas = _i(self.alphas_cumprod, t, xt)
  334. alphas_prev = _i(self.alphas_cumprod, (t - stride).clamp(0), xt)
  335. sigmas = eta * torch.sqrt((1 - alphas_prev) / (1 - alphas) * # noqa
  336. (1 - alphas / alphas_prev))
  337. # random sample
  338. noise = torch.randn_like(xt)
  339. direction = torch.sqrt(1 - alphas_prev - sigmas**2) * eps
  340. mask = t.ne(0).float().view(-1, *((1, ) * (xt.ndim - 1)))
  341. xt_1 = torch.sqrt(alphas_prev) * x0 + direction + mask * sigmas * noise
  342. return xt_1, x0
  343. @torch.no_grad()
  344. def ddim_sample_loop(self,
  345. noise,
  346. model,
  347. model_kwargs={},
  348. clamp=None,
  349. percentile=None,
  350. condition_fn=None,
  351. guide_scale=None,
  352. ddim_timesteps=20,
  353. eta=0.0):
  354. # prepare input
  355. b = noise.size(0)
  356. xt = noise
  357. # diffusion process (TODO: clamp is inaccurate! Consider replacing the stride by explicit prev/next steps)
  358. steps = (1 + torch.arange(0, self.num_timesteps,
  359. self.num_timesteps // ddim_timesteps)).clamp(
  360. 0, self.num_timesteps - 1).flip(0)
  361. # import ipdb; ipdb.set_trace()
  362. for step in steps:
  363. t = torch.full((b, ), step, dtype=torch.long, device=xt.device)
  364. xt, _ = self.ddim_sample(xt, t, model, model_kwargs, clamp,
  365. percentile, condition_fn, guide_scale,
  366. ddim_timesteps, eta)
  367. return xt
  368. @torch.no_grad()
  369. def ddim_reverse_sample(self,
  370. xt,
  371. t,
  372. model,
  373. model_kwargs={},
  374. clamp=None,
  375. percentile=None,
  376. guide_scale=None,
  377. ddim_timesteps=20):
  378. r"""Sample from p(x_{t+1} | x_t) using DDIM reverse ODE (deterministic).
  379. """
  380. stride = self.num_timesteps // ddim_timesteps
  381. # predict distribution of p(x_{t-1} | x_t)
  382. _, _, _, x0 = self.p_mean_variance(xt, t, model, model_kwargs, clamp,
  383. percentile, guide_scale)
  384. # derive variables
  385. eps = (_i(self.sqrt_recip_alphas_cumprod, t, xt) * xt - x0) / \
  386. _i(self.sqrt_recipm1_alphas_cumprod, t, xt) # noqa
  387. alphas_next = _i(
  388. torch.cat(
  389. [self.alphas_cumprod,
  390. self.alphas_cumprod.new_zeros([1])]),
  391. (t + stride).clamp(0, self.num_timesteps), xt)
  392. # reverse sample
  393. mu = torch.sqrt(alphas_next) * x0 + torch.sqrt(1 - alphas_next) * eps
  394. return mu, x0
  395. @torch.no_grad()
  396. def ddim_reverse_sample_loop(self,
  397. x0,
  398. model,
  399. model_kwargs={},
  400. clamp=None,
  401. percentile=None,
  402. guide_scale=None,
  403. ddim_timesteps=20):
  404. # prepare input
  405. b = x0.size(0)
  406. xt = x0
  407. # reconstruction steps
  408. steps = torch.arange(0, self.num_timesteps,
  409. self.num_timesteps // ddim_timesteps)
  410. for step in steps:
  411. t = torch.full((b, ), step, dtype=torch.long, device=xt.device)
  412. xt, _ = self.ddim_reverse_sample(xt, t, model, model_kwargs, clamp,
  413. percentile, guide_scale,
  414. ddim_timesteps)
  415. return xt
  416. @torch.no_grad()
  417. def plms_sample(self,
  418. xt,
  419. t,
  420. model,
  421. model_kwargs={},
  422. clamp=None,
  423. percentile=None,
  424. condition_fn=None,
  425. guide_scale=None,
  426. plms_timesteps=20):
  427. r"""Sample from p(x_{t-1} | x_t) using PLMS.
  428. - condition_fn: for classifier-based guidance (guided-diffusion).
  429. - guide_scale: for classifier-free guidance (glide/dalle-2).
  430. """
  431. stride = self.num_timesteps // plms_timesteps
  432. # function for compute eps
  433. def compute_eps(xt, t):
  434. # predict distribution of p(x_{t-1} | x_t)
  435. _, _, _, x0 = self.p_mean_variance(xt, t, model, model_kwargs,
  436. clamp, percentile, guide_scale)
  437. # condition
  438. if condition_fn is not None:
  439. # x0 -> eps
  440. alpha = _i(self.alphas_cumprod, t, xt)
  441. eps = (_i(self.sqrt_recip_alphas_cumprod, t, xt) * xt - x0) / \
  442. _i(self.sqrt_recipm1_alphas_cumprod, t, xt) # noqa
  443. eps = eps - (1 - alpha).sqrt() * condition_fn(
  444. xt, self._scale_timesteps(t), **model_kwargs)
  445. # eps -> x0
  446. x0 = _i(self.sqrt_recip_alphas_cumprod, t, xt) * xt - \
  447. _i(self.sqrt_recipm1_alphas_cumprod, t, xt) * eps # noqa
  448. # derive eps
  449. eps = (_i(self.sqrt_recip_alphas_cumprod, t, xt) * xt - x0) / \
  450. _i(self.sqrt_recipm1_alphas_cumprod, t, xt) # noqa
  451. return eps
  452. # function for compute x_0 and x_{t-1}
  453. def compute_x0(eps, t):
  454. # eps -> x0
  455. x0 = _i(self.sqrt_recip_alphas_cumprod, t, xt) * xt - \
  456. _i(self.sqrt_recipm1_alphas_cumprod, t, xt) * eps # noqa
  457. # deterministic sample
  458. alphas_prev = _i(self.alphas_cumprod, (t - stride).clamp(0), xt)
  459. direction = torch.sqrt(1 - alphas_prev) * eps
  460. xt_1 = torch.sqrt(alphas_prev) * x0 + direction
  461. return xt_1, x0
  462. # PLMS sample
  463. eps = compute_eps(xt, t)
  464. if len(eps_cache) == 0:
  465. # 2nd order pseudo improved Euler
  466. xt_1, x0 = compute_x0(eps, t)
  467. eps_next = compute_eps(xt_1, (t - stride).clamp(0))
  468. eps_prime = (eps + eps_next) / 2.0
  469. elif len(eps_cache) == 1:
  470. # 2nd order pseudo linear multistep (Adams-Bashforth)
  471. eps_prime = (3 * eps - eps_cache[-1]) / 2.0
  472. elif len(eps_cache) == 2:
  473. # 3rd order pseudo linear multistep (Adams-Bashforth)
  474. eps_prime = (23 * eps - 16 * eps_cache[-1]
  475. + 5 * eps_cache[-2]) / 12.0
  476. elif len(eps_cache) >= 3:
  477. # 4nd order pseudo linear multistep (Adams-Bashforth)
  478. eps_prime = (55 * eps - 59 * eps_cache[-1] + 37 * eps_cache[-2]
  479. - 9 * eps_cache[-3]) / 24.0
  480. xt_1, x0 = compute_x0(eps_prime, t)
  481. return xt_1, x0, eps
  482. @torch.no_grad()
  483. def plms_sample_loop(self,
  484. noise,
  485. model,
  486. model_kwargs={},
  487. clamp=None,
  488. percentile=None,
  489. condition_fn=None,
  490. guide_scale=None,
  491. plms_timesteps=20):
  492. # prepare input
  493. b = noise.size(0)
  494. xt = noise
  495. # diffusion process
  496. steps = (1 + torch.arange(0, self.num_timesteps,
  497. self.num_timesteps // plms_timesteps)).clamp(
  498. 0, self.num_timesteps - 1).flip(0)
  499. eps_cache = []
  500. for step in steps:
  501. # PLMS sampling step
  502. t = torch.full((b, ), step, dtype=torch.long, device=xt.device)
  503. xt, _, eps = self.plms_sample(xt, t, model, model_kwargs, clamp,
  504. percentile, condition_fn,
  505. guide_scale, plms_timesteps,
  506. eps_cache)
  507. # update eps cache
  508. eps_cache.append(eps)
  509. if len(eps_cache) >= 4:
  510. eps_cache.pop(0)
  511. return xt
  512. def loss(self,
  513. x0,
  514. t,
  515. model,
  516. model_kwargs={},
  517. noise=None,
  518. weight=None,
  519. use_div_loss=False):
  520. noise = torch.randn_like(
  521. x0) if noise is None else noise # [80, 4, 8, 32, 32]
  522. xt = self.q_sample(x0, t, noise=noise)
  523. # compute loss
  524. if self.loss_type in ['kl', 'rescaled_kl']:
  525. loss, _ = self.variational_lower_bound(x0, xt, t, model,
  526. model_kwargs)
  527. if self.loss_type == 'rescaled_kl':
  528. loss = loss * self.num_timesteps
  529. elif self.loss_type in ['mse', 'rescaled_mse', 'l1',
  530. 'rescaled_l1']: # self.loss_type: mse
  531. out = model(xt, self._scale_timesteps(t), **model_kwargs)
  532. # VLB for variation
  533. loss_vlb = 0.0
  534. if self.var_type in ['learned', 'learned_range'
  535. ]: # self.var_type: 'fixed_small'
  536. out, var = out.chunk(2, dim=1)
  537. frozen = torch.cat([
  538. out.detach(), var
  539. ], dim=1) # learn var without affecting the prediction of mean
  540. loss_vlb, _ = self.variational_lower_bound(
  541. x0, xt, t, model=lambda *args, **kwargs: frozen)
  542. if self.loss_type.startswith('rescaled_'):
  543. loss_vlb = loss_vlb * self.num_timesteps / 1000.0
  544. # MSE/L1 for x0/eps
  545. target = {
  546. 'eps': noise,
  547. 'x0': x0,
  548. 'x_{t-1}': self.q_posterior_mean_variance(x0, xt, t)[0]
  549. }[self.mean_type]
  550. loss = (out - target).pow(1 if self.loss_type.endswith('l1') else 2
  551. ).abs().flatten(1).mean(dim=1)
  552. if weight is not None:
  553. loss = loss * weight
  554. # div loss
  555. if use_div_loss and self.mean_type == 'eps' and x0.shape[2] > 1:
  556. # derive x0
  557. x0_ = _i(self.sqrt_recip_alphas_cumprod, t, xt) * xt - \
  558. _i(self.sqrt_recipm1_alphas_cumprod, t, xt) * out
  559. # ncfhw, std on f
  560. div_loss = 0.001 / (
  561. x0_.std(dim=2).flatten(1).mean(dim=1) + 1e-4)
  562. loss = loss + div_loss
  563. # total loss
  564. loss = loss + loss_vlb
  565. elif self.loss_type in ['charbonnier']:
  566. out = model(xt, self._scale_timesteps(t), **model_kwargs)
  567. # VLB for variation
  568. loss_vlb = 0.0
  569. if self.var_type in ['learned', 'learned_range']:
  570. out, var = out.chunk(2, dim=1)
  571. frozen = torch.cat([out.detach(), var], dim=1)
  572. loss_vlb, _ = self.variational_lower_bound(
  573. x0, xt, t, model=lambda *args, **kwargs: frozen)
  574. if self.loss_type.startswith('rescaled_'):
  575. loss_vlb = loss_vlb * self.num_timesteps / 1000.0
  576. # MSE/L1 for x0/eps
  577. target = {
  578. 'eps': noise,
  579. 'x0': x0,
  580. 'x_{t-1}': self.q_posterior_mean_variance(x0, xt, t)[0]
  581. }[self.mean_type]
  582. loss = torch.sqrt((out - target)**2 + self.epsilon)
  583. if weight is not None:
  584. loss = loss * weight
  585. loss = loss.flatten(1).mean(dim=1)
  586. # total loss
  587. loss = loss + loss_vlb
  588. return loss
  589. def variational_lower_bound(self,
  590. x0,
  591. xt,
  592. t,
  593. model,
  594. model_kwargs={},
  595. clamp=None,
  596. percentile=None):
  597. # compute groundtruth and predicted distributions
  598. mu1, _, log_var1 = self.q_posterior_mean_variance(x0, xt, t)
  599. mu2, _, log_var2, x0 = self.p_mean_variance(xt, t, model, model_kwargs,
  600. clamp, percentile)
  601. # compute KL loss
  602. kl = kl_divergence(mu1, log_var1, mu2, log_var2)
  603. kl = kl.flatten(1).mean(dim=1) / math.log(2.0)
  604. # compute discretized NLL loss (for p(x0 | x1) only)
  605. nll = -discretized_gaussian_log_likelihood(
  606. x0, mean=mu2, log_scale=0.5 * log_var2)
  607. nll = nll.flatten(1).mean(dim=1) / math.log(2.0)
  608. # NLL for p(x0 | x1) and KL otherwise
  609. vlb = torch.where(t == 0, nll, kl)
  610. return vlb, x0
  611. @torch.no_grad()
  612. def variational_lower_bound_loop(self,
  613. x0,
  614. model,
  615. model_kwargs={},
  616. clamp=None,
  617. percentile=None):
  618. r"""Compute the entire variational lower bound, measured in bits-per-dim.
  619. """
  620. # prepare input and output
  621. b = x0.size(0)
  622. metrics = {'vlb': [], 'mse': [], 'x0_mse': []}
  623. # loop
  624. for step in torch.arange(self.num_timesteps).flip(0):
  625. # compute VLB
  626. t = torch.full((b, ), step, dtype=torch.long, device=x0.device)
  627. noise = torch.randn_like(x0)
  628. xt = self.q_sample(x0, t, noise)
  629. vlb, pred_x0 = self.variational_lower_bound(
  630. x0, xt, t, model, model_kwargs, clamp, percentile)
  631. # predict eps from x0
  632. eps = (_i(self.sqrt_recip_alphas_cumprod, t, xt) * xt - x0) / \
  633. _i(self.sqrt_recipm1_alphas_cumprod, t, xt) # noqa
  634. # collect metrics
  635. metrics['vlb'].append(vlb)
  636. metrics['x0_mse'].append(
  637. (pred_x0 - x0).square().flatten(1).mean(dim=1))
  638. metrics['mse'].append(
  639. (eps - noise).square().flatten(1).mean(dim=1))
  640. metrics = {k: torch.stack(v, dim=1) for k, v in metrics.items()}
  641. # compute the prior KL term for VLB, measured in bits-per-dim
  642. mu, _, log_var = self.q_mean_variance(x0, t)
  643. kl_prior = kl_divergence(mu, log_var, torch.zeros_like(mu),
  644. torch.zeros_like(log_var))
  645. kl_prior = kl_prior.flatten(1).mean(dim=1) / math.log(2.0)
  646. # update metrics
  647. metrics['prior_bits_per_dim'] = kl_prior
  648. metrics['total_bits_per_dim'] = metrics['vlb'].sum(dim=1) + kl_prior
  649. return metrics
  650. def _scale_timesteps(self, t):
  651. if self.rescale_timesteps: # noqa
  652. return t.float() * 1000.0 / self.num_timesteps
  653. return t
  654. class GaussianDiffusion_style(object):
  655. def __init__(self,
  656. betas,
  657. mean_type='eps',
  658. var_type='fixed_small',
  659. loss_type='mse',
  660. rescale_timesteps=False):
  661. # check input
  662. if not isinstance(betas, torch.DoubleTensor):
  663. betas = torch.tensor(betas, dtype=torch.float64)
  664. assert min(betas) > 0 and max(betas) <= 1
  665. assert mean_type in ['x0', 'x_{t-1}', 'eps']
  666. assert var_type in [
  667. 'learned', 'learned_range', 'fixed_large', 'fixed_small'
  668. ]
  669. assert loss_type in [
  670. 'mse', 'rescaled_mse', 'kl', 'rescaled_kl', 'l1', 'rescaled_l1'
  671. ]
  672. self.betas = betas
  673. self.num_timesteps = len(betas)
  674. self.mean_type = mean_type
  675. self.var_type = var_type
  676. self.loss_type = loss_type
  677. self.rescale_timesteps = rescale_timesteps
  678. # alphas
  679. alphas = 1 - self.betas
  680. self.alphas_cumprod = torch.cumprod(alphas, dim=0)
  681. self.alphas_cumprod_prev = torch.cat(
  682. [alphas.new_ones([1]), self.alphas_cumprod[:-1]])
  683. self.alphas_cumprod_next = torch.cat(
  684. [self.alphas_cumprod[1:],
  685. alphas.new_zeros([1])])
  686. # q(x_t | x_{t-1})
  687. self.sqrt_alphas_cumprod = torch.sqrt(self.alphas_cumprod)
  688. self.sqrt_one_minus_alphas_cumprod = torch.sqrt(1.0
  689. - self.alphas_cumprod)
  690. self.log_one_minus_alphas_cumprod = torch.log(1.0
  691. - self.alphas_cumprod)
  692. self.sqrt_recip_alphas_cumprod = torch.sqrt(1.0 / self.alphas_cumprod)
  693. self.sqrt_recipm1_alphas_cumprod = torch.sqrt(1.0 / self.alphas_cumprod
  694. - 1)
  695. # q(x_{t-1} | x_t, x_0)
  696. self.posterior_variance = betas * (1.0 - self.alphas_cumprod_prev) / (
  697. 1.0 - self.alphas_cumprod)
  698. self.posterior_log_variance_clipped = torch.log(
  699. self.posterior_variance.clamp(1e-20))
  700. self.posterior_mean_coef1 = betas * torch.sqrt(
  701. self.alphas_cumprod_prev) / (1.0 - self.alphas_cumprod)
  702. self.posterior_mean_coef2 = (
  703. 1.0 - self.alphas_cumprod_prev) * torch.sqrt(alphas) / (
  704. 1.0 - self.alphas_cumprod)
  def q_sample(self, x0, t, noise=None):
      r"""Sample from q(x_t | x_0).

      - x0: clean input.
      - t: timestep indices used to look up the schedule tables via ``_i``.
      - noise: optional noise; drawn with ``torch.randn_like(x0)`` if None.

      Returns the noised sample cast back to the type of ``x0``.
      """
      if noise is None:
          noise = torch.randn_like(x0)
      signal_scale = _i(self.sqrt_alphas_cumprod, t, x0)
      noise_scale = _i(self.sqrt_one_minus_alphas_cumprod, t, x0)
      noised = signal_scale * x0 + noise_scale * noise
      return noised.type_as(x0)
  def q_mean_variance(self, x0, t):
      r"""Distribution of q(x_t | x_0).

      Returns a ``(mean, variance, log_variance)`` triple looked up at
      timestep ``t``; the mean is ``sqrt(alphas_cumprod) * x0``.
      """
      mean = _i(self.sqrt_alphas_cumprod, t, x0) * x0
      variance = _i(1.0 - self.alphas_cumprod, t, x0)
      log_variance = _i(self.log_one_minus_alphas_cumprod, t, x0)
      return mean, variance, log_variance
  def q_posterior_mean_variance(self, x0, xt, t):
      r"""Distribution of q(x_{t-1} | x_t, x_0).

      Returns a ``(mean, variance, log_variance)`` triple. The mean is the
      ``posterior_mean_coef1`` / ``posterior_mean_coef2`` weighted sum of
      ``x0`` and ``xt``; the log-variance comes from the clipped table.
      """
      weight_x0 = _i(self.posterior_mean_coef1, t, xt)
      weight_xt = _i(self.posterior_mean_coef2, t, xt)
      mean = weight_x0 * x0 + weight_xt * xt
      variance = _i(self.posterior_variance, t, xt)
      log_variance = _i(self.posterior_log_variance_clipped, t, xt)
      return mean, variance, log_variance
  @torch.no_grad()
  def p_sample(self,
               xt,
               t,
               model,
               model_kwargs={},
               clamp=None,
               percentile=None,
               condition_fn=None,
               guide_scale=None):
      r"""Sample from p(x_{t-1} | x_t).
          - condition_fn: for classifier-based guidance (guided-diffusion).
          - guide_scale: for classifier-free guidance (glide/dalle-2).

      Returns ``(x_{t-1}, predicted x0)``, both cast to the dtype of ``xt``.
      """
      out_dtype = xt.dtype

      # model-predicted Gaussian for the reverse step
      mu, var, log_var, x0 = self.p_mean_variance(xt, t, model, model_kwargs,
                                                  clamp, percentile,
                                                  guide_scale)

      # fresh noise, masked so that samples already at t == 0 get none
      noise = torch.randn_like(xt)
      broadcast_shape = (-1, ) + (1, ) * (xt.ndim - 1)
      nonzero_t = t.ne(0).float().view(*broadcast_shape)

      # classifier guidance: shift the mean along the supplied gradient
      if condition_fn is not None:
          grad = condition_fn(xt, self._scale_timesteps(t), **model_kwargs)
          mu = mu.float() + var * grad.float()

      xt_1 = mu + nonzero_t * torch.exp(0.5 * log_var) * noise
      return xt_1.type(out_dtype), x0.type(out_dtype)
  757. @torch.no_grad()
  758. def p_sample_loop(self,
  759. noise,
  760. model,
  761. model_kwargs={},
  762. clamp=None,
  763. percentile=None,
  764. condition_fn=None,
  765. guide_scale=None):
  766. r"""Sample from p(x_{t-1} | x_t) p(x_{t-2} | x_{t-1}) ... p(x_0 | x_1).
  767. """
  768. # prepare input
  769. b = noise.size(0)
  770. xt = noise
  771. # diffusion process
  772. for step in torch.arange(self.num_timesteps).flip(0):
  773. t = torch.full((b, ), step, dtype=torch.long, device=xt.device)
  774. xt, _ = self.p_sample(xt, t, model, model_kwargs, clamp,
  775. percentile, condition_fn, guide_scale)
  776. return xt
  def p_mean_variance(self,
                      xt,
                      t,
                      model,
                      model_kwargs={},
                      clamp=None,
                      percentile=None,
                      guide_scale=None):
      r"""Distribution of p(x_{t-1} | x_t).

      - xt: current noisy sample.
      - t: timestep indices.
      - model: called as ``model(xt, t=..., **kwargs)``.
      - model_kwargs: a dict when ``guide_scale`` is None; otherwise a list
        ``[conditional_kwargs, unconditional_kwargs]``.
      - clamp: if given (and ``percentile`` is None), the predicted x0 is
        clamped to ``[-clamp, clamp]``.
      - percentile: if given, the predicted x0 is clipped per sample to the
        ``percentile`` quantile of its absolute values (at least 1.0) and
        divided by that threshold.
      - guide_scale: classifier-free guidance scale.

      Returns ``(mu, var, log_var, x0)`` where ``mu`` is the posterior mean
      recomputed from the (possibly restricted) predicted ``x0``.
      """
      # predict distribution
      if guide_scale is None:
          out = model(xt, t=self._scale_timesteps(t), **model_kwargs)
      else:
          # classifier-free guidance
          # (model_kwargs[0]: conditional kwargs;
          #  model_kwargs[1]: non-conditional kwargs)
          assert isinstance(model_kwargs, list) and len(model_kwargs) == 2
          y_out = model(xt, t=self._scale_timesteps(t), **model_kwargs[0])
          if guide_scale != 1.0:
              u_out = model(
                  xt, t=self._scale_timesteps(t), **model_kwargs[1])
              # only the mean channels are guided; any variance channels
              # are taken from the conditional output unchanged
              dim = y_out.size(1) if self.var_type.startswith(
                  'fixed') else y_out.size(1) // 2
              guided = u_out[:, :dim] + guide_scale * (
                  y_out[:, :dim] - u_out[:, :dim])
              out = torch.cat([guided, y_out[:, dim:]], dim=1)
          else:
              # a scale of exactly 1 reduces to the conditional output
              out = y_out

      # compute variance
      if self.var_type == 'learned':
          out, log_var = out.chunk(2, dim=1)
          var = torch.exp(log_var)
      elif self.var_type == 'learned_range':
          out, fraction = out.chunk(2, dim=1)
          min_log_var = _i(self.posterior_log_variance_clipped, t, xt)
          max_log_var = _i(torch.log(self.betas), t, xt)
          # map the second half of the output from [-1, 1] to [0, 1] and
          # interpolate between the two log-variances
          # (presumably the model emits values in [-1, 1] - confirm)
          fraction = (fraction + 1) / 2.0
          log_var = fraction * max_log_var + (1 - fraction) * min_log_var
          var = torch.exp(log_var)
      elif self.var_type == 'fixed_large':
          # betas, with the first entry replaced by posterior_variance[1]
          var = _i(
              torch.cat([self.posterior_variance[1:2], self.betas[1:]]), t,
              xt)
          log_var = torch.log(var)
      elif self.var_type == 'fixed_small':
          var = _i(self.posterior_variance, t, xt)
          log_var = _i(self.posterior_log_variance_clipped, t, xt)

      # compute mean and x0
      if self.mean_type == 'x_{t-1}':
          mu = out  # x_{t-1}
          x0 = _i(1.0 / self.posterior_mean_coef1, t, xt) * mu - \
              _i(self.posterior_mean_coef2 / self.posterior_mean_coef1, t, xt) * xt  # noqa
      elif self.mean_type == 'x0':
          x0 = out
      elif self.mean_type == 'eps':
          x0 = _i(self.sqrt_recip_alphas_cumprod, t, xt) * xt - \
              _i(self.sqrt_recipm1_alphas_cumprod, t, xt) * out  # noqa

      # restrict the range of x0
      if percentile is not None:
          assert 0 < percentile <= 1  # e.g., 0.995
          # per-sample threshold, reshaped to broadcast against x0 whatever
          # its rank (image, video, ...) instead of assuming 5-D input
          s = torch.quantile(
              x0.flatten(1).abs(), percentile,
              dim=1).clamp_(1.0).view(-1, *((1, ) * (x0.ndim - 1)))
          x0 = torch.min(s, torch.max(-s, x0)) / s
      elif clamp is not None:
          x0 = x0.clamp(-clamp, clamp)

      # recompute mu using the restricted x0
      mu, _, _ = self.q_posterior_mean_variance(x0, xt, t)
      return mu, var, log_var, x0
  @torch.no_grad()
  def ddim_sample(self,
                  xt,
                  t,
                  t_prev,
                  model,
                  model_kwargs={},
                  clamp=None,
                  percentile=None,
                  condition_fn=None,
                  guide_scale=None,
                  ddim_timesteps=20,
                  eta=0.0):
      r"""Sample from p(x_{t-1} | x_t) using DDIM.
          - condition_fn: for classifier-based guidance (guided-diffusion).
          - guide_scale: for classifier-free guidance (glide/dalle-2).

      ``t_prev`` is the timestep being stepped to and ``eta`` scales the
      stochastic term (0.0 gives a deterministic step). ``ddim_timesteps``
      is accepted but not used here. Returns ``(x_{t_prev}, predicted x0)``
      cast to the dtype of ``xt``.
      """
      out_dtype = xt.dtype

      # only the predicted x0 is needed from the reverse distribution
      _, _, _, x0 = self.p_mean_variance(xt, t, model, model_kwargs, clamp,
                                         percentile, guide_scale)

      # coefficients converting between x0 and eps at timestep t
      recip = _i(self.sqrt_recip_alphas_cumprod, t, xt)
      recipm1 = _i(self.sqrt_recipm1_alphas_cumprod, t, xt)

      if condition_fn is not None:
          # classifier guidance: x0 -> eps, shift eps, then eps -> x0
          alpha = _i(self.alphas_cumprod, t, xt)
          eps = (recip * xt - x0) / recipm1
          eps = eps - (1 - alpha).sqrt() * condition_fn(
              xt, self._scale_timesteps(t), **model_kwargs)
          x0 = recip * xt - recipm1 * eps

      # quantities of the DDIM update
      eps = (recip * xt - x0) / recipm1
      a_t = _i(self.alphas_cumprod, t, xt)
      a_prev = _i(self.alphas_cumprod, t_prev, xt)
      sigmas = eta * torch.sqrt((1 - a_prev) / (1 - a_t) *  # noqa
                                (1 - a_t / a_prev))

      # stochastic part is suppressed for samples already at t == 0
      noise = torch.randn_like(xt)
      direction = torch.sqrt(1 - a_prev - sigmas**2) * eps
      nonzero_t = t.ne(0).float().view(-1, *((1, ) * (xt.ndim - 1)))
      xt_1 = torch.sqrt(a_prev) * x0 + direction + nonzero_t * sigmas * noise
      return xt_1.type(out_dtype), x0.type(out_dtype)
  @torch.no_grad()
  def ddim_sample_loop(self,
                       noise,
                       model,
                       model_kwargs={},
                       clamp=None,
                       percentile=None,
                       condition_fn=None,
                       guide_scale=None,
                       ddim_timesteps=20,
                       eta=0.0):
      r"""Run DDIM sampling over a strided subset of the timesteps.

      The schedule is ``1 + arange(0, num_timesteps, stride)`` with
      ``stride = num_timesteps // ddim_timesteps``, clamped to
      ``[0, num_timesteps - 1]`` and visited from largest to smallest; the
      last step targets timestep 0. Returns the final sample.
      """
      batch = noise.size(0)
      sample = noise

      # strided schedule, descending
      stride = self.num_timesteps // ddim_timesteps
      schedule = (1 + torch.arange(0, self.num_timesteps, stride)).clamp(
          0, self.num_timesteps - 1).flip(0).tolist()
      # each step is paired with its successor; the final one lands on 0
      targets = schedule[1:] + [0]

      for step, step_prev in zip(schedule, targets):
          t = torch.full((batch, ),
                         step,
                         dtype=torch.long,
                         device=sample.device)
          t_prev = torch.full((batch, ),
                              step_prev,
                              dtype=torch.long,
                              device=sample.device)
          sample, _ = self.ddim_sample(sample, t, t_prev, model,
                                       model_kwargs, clamp, percentile,
                                       condition_fn, guide_scale,
                                       ddim_timesteps, eta)
      return sample
  @torch.no_grad()
  def ddim_reverse_sample(self,
                          xt,
                          t,
                          t_next,
                          model,
                          model_kwargs={},
                          clamp=None,
                          percentile=None,
                          guide_scale=None,
                          ddim_timesteps=20):
      r"""Sample from p(x_{t+1} | x_t) using DDIM reverse ODE (deterministic).

      ``t_next`` indexes ``alphas_cumprod`` extended by one trailing zero.
      ``ddim_timesteps`` is accepted but not used here. Returns
      ``(x_{t_next}, predicted x0)`` cast to the dtype of ``xt``.
      """
      out_dtype = xt.dtype

      # predicted x0 at the current step
      _, _, _, x0 = self.p_mean_variance(xt, t, model, model_kwargs, clamp,
                                         percentile, guide_scale)

      # eps implied by (xt, x0) at timestep t
      recip = _i(self.sqrt_recip_alphas_cumprod, t, xt)
      recipm1 = _i(self.sqrt_recipm1_alphas_cumprod, t, xt)
      eps = (recip * xt - x0) / recipm1

      # cumulative alpha at the target step
      padded = torch.cat(
          [self.alphas_cumprod,
           self.alphas_cumprod.new_zeros([1])])
      alphas_next = _i(padded, t_next, xt)

      # deterministic move towards the noisier timestep
      stepped = torch.sqrt(alphas_next) * x0 + torch.sqrt(
          1 - alphas_next) * eps
      return stepped.type(out_dtype), x0.type(out_dtype)
  @torch.no_grad()
  def ddim_reverse_sample_loop(self,
                               x0,
                               model,
                               model_kwargs={},
                               clamp=None,
                               percentile=None,
                               guide_scale=None,
                               ddim_timesteps=20):
      r"""Run the deterministic DDIM reverse process starting from ``x0``.

      Uses the same strided schedule as DDIM sampling but in ascending
      order; the first step starts from timestep 0. Returns the sample
      after the last step.
      """
      batch = x0.size(0)
      sample = x0

      # strided schedule, ascending
      stride = self.num_timesteps // ddim_timesteps
      schedule = (1 + torch.arange(0, self.num_timesteps, stride)).clamp(
          0, self.num_timesteps - 1).tolist()
      # the current step of each move is the previous target (0 at first)
      sources = [0] + schedule[:-1]

      for step_cur, step_next in zip(sources, schedule):
          t = torch.full((batch, ),
                         step_cur,
                         dtype=torch.long,
                         device=sample.device)
          t_next = torch.full((batch, ),
                              step_next,
                              dtype=torch.long,
                              device=sample.device)
          sample, _ = self.ddim_reverse_sample(sample, t, t_next, model,
                                               model_kwargs, clamp,
                                               percentile, guide_scale,
                                               ddim_timesteps)
      return sample
  @torch.no_grad()
  def plms_sample(self,
                  xt,
                  t,
                  t_prev,
                  model,
                  model_kwargs={},
                  clamp=None,
                  percentile=None,
                  condition_fn=None,
                  guide_scale=None,
                  plms_timesteps=20,
                  eps_cache=None):
      r"""Sample from p(x_{t-1} | x_t) using PLMS.
          - condition_fn: for classifier-based guidance (guided-diffusion).
          - guide_scale: for classifier-free guidance (glide/dalle-2).
          - eps_cache: eps estimates from earlier steps, oldest first. It
            is only read, never modified; None is treated as an empty
            history.

      ``t_prev`` is the timestep being stepped to. ``plms_timesteps`` is
      accepted but not used here.

      Returns ``(x_{t_prev}, predicted x0, eps)``, where ``eps`` is the raw
      estimate at ``(xt, t)`` that the caller should append to its history.
      """
      # The history used to be read from an undefined name; it is now an
      # explicit trailing parameter so existing positional calls still work.
      if eps_cache is None:
          eps_cache = []

      # function for compute eps
      def compute_eps(xt, t):
          dtype = xt.dtype
          # predict distribution of p(x_{t-1} | x_t)
          _, _, _, x0 = self.p_mean_variance(xt, t, model, model_kwargs,
                                             clamp, percentile, guide_scale)
          # condition
          if condition_fn is not None:
              # x0 -> eps
              alpha = _i(self.alphas_cumprod, t, xt)
              eps = (_i(self.sqrt_recip_alphas_cumprod, t, xt) * xt - x0) / \
                  _i(self.sqrt_recipm1_alphas_cumprod, t, xt)  # noqa
              eps = eps - (1 - alpha).sqrt() * condition_fn(
                  xt, self._scale_timesteps(t), **model_kwargs)
              # eps -> x0
              x0 = _i(self.sqrt_recip_alphas_cumprod, t, xt) * xt - \
                  _i(self.sqrt_recipm1_alphas_cumprod, t, xt) * eps  # noqa
          # derive eps
          eps = (_i(self.sqrt_recip_alphas_cumprod, t, xt) * xt - x0) / \
              _i(self.sqrt_recipm1_alphas_cumprod, t, xt)  # noqa
          return eps.type(dtype)

      # function for compute x_0 and x_{t-1}
      # (always works from the outer ``xt`` and ``t_prev``)
      def compute_x0(eps, t):
          dtype = eps.dtype
          # eps -> x0
          x0 = _i(self.sqrt_recip_alphas_cumprod, t, xt) * xt - \
              _i(self.sqrt_recipm1_alphas_cumprod, t, xt) * eps  # noqa
          # deterministic sample
          alphas_prev = _i(self.alphas_cumprod, t_prev, xt)
          direction = torch.sqrt(1 - alphas_prev) * eps
          xt_1 = torch.sqrt(alphas_prev) * x0 + direction
          return xt_1.type(dtype), x0.type(dtype)

      # PLMS sample: the order of the update grows with the history length
      eps = compute_eps(xt, t)
      history = len(eps_cache)
      if history == 0:
          # 2nd order pseudo improved Euler
          xt_1, x0 = compute_x0(eps, t)
          eps_next = compute_eps(xt_1, t_prev)
          eps_prime = (eps + eps_next) / 2.0
      elif history == 1:
          # 2nd order pseudo linear multistep (Adams-Bashforth)
          eps_prime = (3 * eps - eps_cache[-1]) / 2.0
      elif history == 2:
          # 3rd order pseudo linear multistep (Adams-Bashforth)
          eps_prime = (23 * eps - 16 * eps_cache[-1]
                       + 5 * eps_cache[-2]) / 12.0
      else:
          # 4th order pseudo linear multistep (Adams-Bashforth)
          eps_prime = (55 * eps - 59 * eps_cache[-1] + 37 * eps_cache[-2]
                       - 9 * eps_cache[-3]) / 24.0
      xt_1, x0 = compute_x0(eps_prime, t)
      return xt_1, x0, eps
  @torch.no_grad()
  def plms_sample_loop(self,
                       noise,
                       model,
                       model_kwargs={},
                       clamp=None,
                       percentile=None,
                       condition_fn=None,
                       guide_scale=None,
                       plms_timesteps=20):
      r"""Run PLMS sampling over a strided subset of the timesteps.

      The schedule matches the DDIM one (stride
      ``num_timesteps // plms_timesteps``, descending, last step targets
      timestep 0). A rolling list of at most three previous eps estimates
      is forwarded to ``plms_sample`` as its final positional argument.
      Returns the final sample.
      """
      batch = noise.size(0)
      sample = noise

      # strided schedule, descending
      stride = self.num_timesteps // plms_timesteps
      schedule = (1 + torch.arange(0, self.num_timesteps, stride)).clamp(
          0, self.num_timesteps - 1).flip(0).tolist()
      targets = schedule[1:] + [0]

      eps_cache = []
      for step, step_prev in zip(schedule, targets):
          # PLMS sampling step
          t = torch.full((batch, ),
                         step,
                         dtype=torch.long,
                         device=sample.device)
          t_prev = torch.full((batch, ),
                              step_prev,
                              dtype=torch.long,
                              device=sample.device)
          sample, _, eps = self.plms_sample(sample, t, t_prev, model,
                                            model_kwargs, clamp, percentile,
                                            condition_fn, guide_scale,
                                            plms_timesteps, eps_cache)
          # keep only the three most recent estimates
          eps_cache.append(eps)
          if len(eps_cache) >= 4:
              del eps_cache[0]
      return sample
  1082. @torch.no_grad()
  1083. def dpm_solver_sample_loop(self,
  1084. noise,
  1085. model,
  1086. model_kwargs={},
  1087. order=2,
  1088. skip_type='logSNR',
  1089. method='multistep',
  1090. clamp=None,
  1091. percentile=None,
  1092. condition_fn=None,
  1093. guide_scale=None,
  1094. dpm_solver_timesteps=20,
  1095. algorithm_type='dpmsolver++',
  1096. t_start=None,
  1097. t_end=None,
  1098. lower_order_final=True,
  1099. denoise_to_zero=False,
  1100. solver_type='dpmsolver'):
  1101. r"""Sample using DPM-Solver-based method.
  1102. - condition_fn: for classifier-based guidance (guided-diffusion).
  1103. - guide_scale: for classifier-free guidance (glide/dalle-2).
  1104. Please check all the parameters in `dpm_solver.sample` before using.
  1105. """
  1106. assert self.mean_type in ('eps', 'x0')
  1107. assert percentile in (None, 0.995)
  1108. assert clamp is None or percentile is None
  1109. noise_schedule = NoiseScheduleVP(
  1110. schedule='discrete', betas=self.betas.float())
  1111. model_fn = model_wrapper_guided_diffusion(
  1112. model=model,
  1113. noise_schedule=noise_schedule,
  1114. var_type=self.var_type,
  1115. mean_type=self.mean_type,
  1116. model_kwargs=model_kwargs,
  1117. rescale_timesteps=self.rescale_timesteps,
  1118. num_timesteps=self.num_timesteps,
  1119. guide_scale=guide_scale,
  1120. condition_fn=condition_fn)
  1121. dpm_solver = DPM_Solver(
  1122. model_fn=model_fn,
  1123. noise_schedule=noise_schedule,
  1124. algorithm_type=algorithm_type,
  1125. percentile=percentile,
  1126. clamp=clamp)
  1127. xt = dpm_solver.sample(
  1128. noise,
  1129. steps=dpm_solver_timesteps,
  1130. order=order,
  1131. skip_type=skip_type,
  1132. method=method,
  1133. solver_type=solver_type,
  1134. t_start=t_start,
  1135. t_end=t_end,
  1136. lower_order_final=lower_order_final,
  1137. denoise_to_zero=denoise_to_zero)
  1138. return xt
  @torch.no_grad()
  def inpaint_p_sample(self,
                       xt,
                       t,
                       y,
                       mask,
                       model,
                       model_kwargs={},
                       clamp=None,
                       percentile=None,
                       guide_scale=None):
      r"""DDPM sampling step for inpainting.

      Where ``mask`` is 1 the current sample is replaced by a freshly
      noised copy of ``y`` at timestep ``t`` before the model is called.
      Returns ``(x_{t-1}, predicted x0)`` cast to the incoming dtype of
      ``xt``.
      """
      out_dtype = xt.dtype

      # inject the known content, noised to the current level
      known = self.q_sample(y, t)
      xt = known * mask + xt * (1 - mask)

      # predict distribution of p(x_{t-1} | x_t) on the blended sample
      mu, var, log_var, x0 = self.p_mean_variance(xt, t, model, model_kwargs,
                                                  clamp, percentile,
                                                  guide_scale)

      # ancestral sample; no noise for entries already at t == 0
      broadcast_shape = (-1, ) + (1, ) * (xt.ndim - 1)
      nonzero_t = t.ne(0).float().view(*broadcast_shape)
      xt_1 = mu + nonzero_t * torch.exp(0.5 * log_var) * torch.randn_like(xt)
      return xt_1.type(out_dtype), x0.type(out_dtype)
  1165. @torch.no_grad()
  1166. def inpaint_p_sample_loop(self,
  1167. noise,
  1168. y,
  1169. mask,
  1170. model,
  1171. model_kwargs={},
  1172. clamp=None,
  1173. percentile=None,
  1174. guide_scale=None):
  1175. r"""DDPM sampling loop for inpainting.
  1176. """
  1177. # prepare input
  1178. b = noise.size(0)
  1179. xt = noise
  1180. # diffusion process
  1181. for step in torch.arange(self.num_timesteps).flip(0):
  1182. t = torch.full((b, ), step, dtype=torch.long, device=xt.device)
  1183. xt, _ = self.inpaint_p_sample(xt, t, y, mask, model, model_kwargs,
  1184. clamp, percentile, guide_scale)
  1185. return xt
  @torch.no_grad()
  def inpaint_mcg_p_sample(self,
                           xt,
                           t,
                           y,
                           mask,
                           model,
                           model_kwargs={},
                           clamp=None,
                           percentile=None,
                           guide_scale=None,
                           mcg_scale=1.0):
      r"""DDPM sampling step for inpainting, with Manifold Constrained Gradient (MCG) correction.

      The gradient of the masked squared error between ``y`` and the
      predicted x0 is taken with respect to ``xt`` and subtracted, scaled
      by ``mcg_scale``, from the ancestral sample. The result is then
      overwritten inside the mask with a noised copy of ``y``.

      Note that ``xt`` is switched to ``requires_grad`` in place.
      Returns ``(x_{t-1}, predicted x0)`` cast to the dtype of ``xt``.
      """
      out_dtype = xt.dtype

      # gradients are re-enabled locally despite the no_grad decorator
      with torch.enable_grad():
          xt.requires_grad_(True)
          mu, var, log_var, x0 = self.p_mean_variance(
              xt, t, model, model_kwargs, clamp, percentile, guide_scale)
          residual = y * mask - x0 * mask
          loss = residual.square().mean()
          grad = torch.autograd.grad(loss, xt)[0]

      # ancestral sample; no noise for entries already at t == 0
      broadcast_shape = (-1, ) + (1, ) * (xt.ndim - 1)
      nonzero_t = t.ne(0).float().view(*broadcast_shape)
      xt_1 = mu + nonzero_t * torch.exp(0.5 * log_var) * torch.randn_like(xt)

      # MCG correction step
      xt_1 = xt_1 - mcg_scale * grad

      # merge foreground and background
      known = self.q_sample(y, t)
      xt_1 = known * mask + xt_1 * (1 - mask)
      return xt_1.type(out_dtype), x0.type(out_dtype)
  1218. @torch.no_grad()
  1219. def inpaint_mcg_p_sample_loop(self,
  1220. noise,
  1221. y,
  1222. mask,
  1223. model,
  1224. model_kwargs={},
  1225. clamp=None,
  1226. percentile=None,
  1227. guide_scale=None,
  1228. mcg_scale=1.0):
  1229. r"""DDPM sampling loop for inpainting, with Manifold Constrained Gradient (MCG) correction.
  1230. """
  1231. # prepare input
  1232. b = noise.size(0)
  1233. xt = noise
  1234. # diffusion process
  1235. for step in torch.arange(self.num_timesteps).flip(0):
  1236. t = torch.full((b, ), step, dtype=torch.long, device=xt.device)
  1237. xt, _ = self.inpaint_mcg_p_sample(xt, t, y, mask, model,
  1238. model_kwargs, clamp, percentile,
  1239. guide_scale, mcg_scale)
  1240. return xt
  def loss(self,
           x0,
           t,
           model,
           model_kwargs={},
           noise=None,
           input_x0=None,
           reduction='mean'):
      r"""Compute the training loss selected by ``self.loss_type``.

      - x0: clean target.
      - t: timestep indices.
      - model: called as ``model(xt, t=..., **model_kwargs)``.
      - noise: optional noise; drawn with ``torch.randn_like(x0)`` if None.
      - input_x0: optional alternative clean input used only to build the
        noised model input; defaults to ``x0``.
      - reduction: 'mean' averages over all non-batch dimensions, 'none'
        keeps the per-element loss. Honoured by every loss type.

      For 'kl' / 'rescaled_kl' the loss is the variational bound (times
      ``num_timesteps`` when rescaled). Otherwise it is the L1 or squared
      error against the target implied by ``self.mean_type``, plus a
      variational term for learned variances.
      """
      assert reduction in ['mean', 'none']
      noise = torch.randn_like(x0) if noise is None else noise
      input_x0 = x0 if input_x0 is None else input_x0
      xt = self.q_sample(input_x0, t, noise=noise)

      # compute loss
      if self.loss_type in ['kl', 'rescaled_kl']:
          # forward ``reduction`` so that 'none' is not silently ignored
          loss, _ = self.variational_lower_bound(
              x0, xt, t, model, model_kwargs, reduction=reduction)
          if self.loss_type == 'rescaled_kl':
              loss = loss * self.num_timesteps
      elif self.loss_type in ['mse', 'rescaled_mse', 'l1', 'rescaled_l1']:
          out = model(xt, t=self._scale_timesteps(t), **model_kwargs)

          # VLB for variation
          loss_vlb = 0.0
          if self.var_type in ['learned', 'learned_range']:
              out, var = out.chunk(2, dim=1)
              # learn var without affecting the prediction of mean
              frozen = torch.cat([out.detach(), var], dim=1)
              loss_vlb, _ = self.variational_lower_bound(
                  x0,
                  xt,
                  t,
                  model=lambda *args, **kwargs: frozen,
                  reduction=reduction)
              if self.loss_type.startswith('rescaled_'):
                  loss_vlb = loss_vlb * self.num_timesteps / 1000.0

          # MSE/L1 for x0/eps; the posterior mean is only computed when it
          # is actually the regression target
          if self.mean_type == 'eps':
              target = noise
          elif self.mean_type == 'x0':
              target = x0
          else:  # 'x_{t-1}'
              target = self.q_posterior_mean_variance(x0, xt, t)[0]
          loss = (
              out
              - target).pow(1 if self.loss_type.endswith('l1') else 2).abs()
          if reduction == 'mean':
              loss = loss.flatten(1).mean(dim=1)

          # total loss
          loss = loss + loss_vlb
      return loss
  def variational_lower_bound(self,
                              x0,
                              xt,
                              t,
                              model,
                              model_kwargs={},
                              clamp=None,
                              percentile=None,
                              reduction='mean'):
      r"""Compute one term of the variational lower bound, in bits.

      - x0: ground-truth clean sample.
      - xt: noised sample at timestep ``t``.
      - reduction: 'mean' averages over all non-batch dimensions, 'none'
        keeps the per-element values.

      For ``t == 0`` the term is the discretized Gaussian negative
      log-likelihood of ``x0`` under the model; for every other ``t`` it
      is the KL divergence between the true posterior and the model
      distribution.

      Returns ``(vlb, pred_x0)`` where ``pred_x0`` is the model's x0
      estimate.
      """
      assert reduction in ['mean', 'none']

      # compute groundtruth and predicted distributions
      # (the prediction gets its own name so it cannot shadow the
      # ground-truth ``x0`` that the NLL below must be evaluated on)
      mu1, _, log_var1 = self.q_posterior_mean_variance(x0, xt, t)
      mu2, _, log_var2, pred_x0 = self.p_mean_variance(
          xt, t, model, model_kwargs, clamp, percentile)

      # compute KL loss
      kl = kl_divergence(mu1, log_var1, mu2, log_var2) / math.log(2.0)
      if reduction == 'mean':
          kl = kl.flatten(1).mean(dim=1)

      # compute discretized NLL loss (for p(x0 | x1) only)
      nll = -discretized_gaussian_log_likelihood(
          x0, mean=mu2, log_scale=0.5 * log_var2) / math.log(2.0)
      if reduction == 'mean':
          nll = nll.flatten(1).mean(dim=1)

      # NLL for p(x0 | x1) and KL otherwise
      t = t.view(-1, *(1, ) * (nll.ndim - 1))
      vlb = torch.where(t == 0, nll, kl)
      return vlb, pred_x0
  @torch.no_grad()
  def variational_lower_bound_loop(self,
                                   x0,
                                   model,
                                   model_kwargs={},
                                   clamp=None,
                                   percentile=None):
      r"""Compute the entire variational lower bound, measured in bits-per-dim.

      For every timestep (visited from ``num_timesteps - 1`` down to 0) a
      fresh noised sample is drawn and the VLB term, the x0 error and the
      eps error are recorded.

      Returns a dict with:
      - 'vlb', 'x0_mse', 'mse': tensors with one column per timestep, in
        the visiting order (largest timestep first).
      - 'prior_bits_per_dim': KL between q(x_T | x_0) and a standard
        normal, per sample.
      - 'total_bits_per_dim': sum of the 'vlb' columns plus the prior term.
      """
      # prepare input and output
      b = x0.size(0)
      metrics = {'vlb': [], 'mse': [], 'x0_mse': []}

      # loop
      for step in torch.arange(self.num_timesteps).flip(0):
          # compute VLB
          t = torch.full((b, ), step, dtype=torch.long, device=x0.device)
          noise = torch.randn_like(x0)
          xt = self.q_sample(x0, t, noise)
          vlb, pred_x0 = self.variational_lower_bound(
              x0, xt, t, model, model_kwargs, clamp, percentile)

          # predict eps from the model's x0 estimate (using the ground
          # truth here would just reproduce ``noise`` and make the eps
          # error trivially zero)
          eps = (_i(self.sqrt_recip_alphas_cumprod, t, xt) * xt - pred_x0) / \
              _i(self.sqrt_recipm1_alphas_cumprod, t, xt)  # noqa

          # collect metrics
          metrics['vlb'].append(vlb)
          metrics['x0_mse'].append(
              (pred_x0 - x0).square().flatten(1).mean(dim=1))
          metrics['mse'].append(
              (eps - noise).square().flatten(1).mean(dim=1))
      metrics = {k: torch.stack(v, dim=1) for k, v in metrics.items()}

      # compute the prior KL term for VLB, measured in bits-per-dim.
      # It is defined at the final timestep; the loop variable cannot be
      # reused because the loop ends at step 0.
      t_last = torch.full((b, ),
                          self.num_timesteps - 1,
                          dtype=torch.long,
                          device=x0.device)
      mu, _, log_var = self.q_mean_variance(x0, t_last)
      kl_prior = kl_divergence(mu, log_var, torch.zeros_like(mu),
                               torch.zeros_like(log_var))
      kl_prior = kl_prior.flatten(1).mean(dim=1) / math.log(2.0)

      # update metrics
      metrics['prior_bits_per_dim'] = kl_prior
      metrics['total_bits_per_dim'] = metrics['vlb'].sum(dim=1) + kl_prior
      return metrics
  1356. def _scale_timesteps(self, t):
  1357. if self.rescale_timesteps:
  1358. return t.float() * 1000.0 / self.num_timesteps
  1359. return t