| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657 |
- # Part of the implementation is borrowed and modified from latent-diffusion,
- # publicly available at https://github.com/CompVis/latent-diffusion.
- # Copyright 2021-2022 The Alibaba Fundamental Vision Team Authors. All rights reserved.
- import math
- import torch
- from modelscope.models.multi_modal.dpm_solver_pytorch import (
- DPM_Solver, NoiseScheduleVP, model_wrapper, model_wrapper_guided_diffusion)
- __all__ = ['GaussianDiffusion', 'beta_schedule']
def kl_divergence(mu1, logvar1, mu2, logvar2):
    """Element-wise KL divergence KL(N(mu1, var1) || N(mu2, var2)).

    Both Gaussians are diagonal and parameterized by mean and log-variance.
    The result is in nats and is not reduced; inputs broadcast as usual.
    """
    # log-variance ratio plus variance ratio, minus the constant one
    variance_term = -1.0 + logvar2 - logvar1 + (logvar1 - logvar2).exp()
    # squared mean difference, scaled by the precision of the second Gaussian
    mean_term = (mu1 - mu2).pow(2) * (-logvar2).exp()
    return 0.5 * (variance_term + mean_term)
def standard_normal_cdf(x):
    """Approximate the standard normal CDF element-wise with a tanh formula."""
    cubic = x + 0.044715 * x**3
    scaled = math.sqrt(2.0 / math.pi) * cubic
    return 0.5 * (1.0 + torch.tanh(scaled))
def discretized_gaussian_log_likelihood(x0, mean, log_scale):
    """Log-likelihood of ``x0`` under a Gaussian discretized into bins.

    Each value is assigned the probability mass the Gaussian
    N(mean, exp(log_scale)^2) places on the interval
    [x0 - 1/255, x0 + 1/255]. Values below -0.999 get the whole lower tail
    and values above 0.999 get the whole upper tail.
    (Presumably x0 holds 8-bit images rescaled to [-1, 1] -- confirm with
    the caller.)

    Args:
        x0: observed values.
        mean: Gaussian means, same shape as ``x0``.
        log_scale: log standard deviations, same shape as ``x0``.

    Returns:
        Tensor of log-probabilities (nats) with the shape of ``x0``.
    """
    assert x0.shape == mean.shape == log_scale.shape
    half_bin = 1.0 / 255.0
    floor = 1e-12  # keeps every log() argument strictly positive

    centered = x0 - mean
    inv_std = torch.exp(-log_scale)
    upper_cdf = standard_normal_cdf(inv_std * (centered + half_bin))
    lower_cdf = standard_normal_cdf(inv_std * (centered - half_bin))

    log_lower_tail = torch.log(upper_cdf.clamp(min=floor))
    log_upper_tail = torch.log((1.0 - lower_cdf).clamp(min=floor))
    log_interior = torch.log((upper_cdf - lower_cdf).clamp(min=floor))

    # pick the interior mass by default, then override at the two extremes
    log_probs = torch.where(x0 > 0.999, log_upper_tail, log_interior)
    log_probs = torch.where(x0 < -0.999, log_lower_tail, log_probs)
    assert log_probs.shape == x0.shape
    return log_probs
- def _i(tensor, t, x):
- tensor = tensor.to(x.device)
- shape = (x.size(0), ) + (1, ) * (x.ndim - 1)
- return tensor[t].view(shape).to(x)
def cosine_fn(u):
    """Squared-cosine curve used by the 'cosine' beta schedule.

    ``u`` is offset by 0.008 and normalized by 1.008 before being mapped to
    an angle in units of pi / 2.
    """
    phase = (u + 0.008) / 1.008
    angle = phase * math.pi / 2
    return math.cos(angle)**2
def beta_schedule(schedule,
                  num_timesteps=1000,
                  init_beta=None,
                  last_beta=None):
    """Build a float64 tensor of ``num_timesteps`` betas.

    Args:
        schedule: one of 'linear', 'quadratic' or 'cosine'.
        num_timesteps: number of diffusion steps.
        init_beta: first beta; a falsy value selects the schedule default.
            Ignored by the 'cosine' schedule.
        last_beta: last beta; a falsy value selects the schedule default.
            Ignored by the 'cosine' schedule.

    Returns:
        1-D ``torch.float64`` tensor of length ``num_timesteps``.

    Raises:
        ValueError: if ``schedule`` is not a supported name.
    """
    if schedule == 'linear':
        # defaults are scaled so they match the usual 1000-step range
        scale = 1000.0 / num_timesteps
        start = init_beta or scale * 0.0001
        stop = last_beta or scale * 0.02
        return torch.linspace(start, stop, num_timesteps, dtype=torch.float64)

    if schedule == 'quadratic':
        start = init_beta or 0.0015
        stop = last_beta or 0.0195
        # interpolate linearly in sqrt(beta), then square
        roots = torch.linspace(
            start**0.5, stop**0.5, num_timesteps, dtype=torch.float64)
        return roots**2

    if schedule == 'cosine':
        # beta_t = 1 - f(t + 1) / f(t), capped at 0.999
        betas = [
            min(
                1.0 - cosine_fn((step + 1) / num_timesteps)
                / cosine_fn(step / num_timesteps), 0.999)
            for step in range(num_timesteps)
        ]
        return torch.tensor(betas, dtype=torch.float64)

    raise ValueError(f'Unsupported schedule: {schedule}')
class GaussianDiffusion(object):
    r"""Discrete-time Gaussian diffusion with several samplers and losses.

    The constructor precomputes, from ``betas``, the per-timestep coefficients
    of the forward process q(x_t | x_0) and of the posterior
    q(x_{t-1} | x_t, x_0). The methods then provide:

    - forward sampling (``q_sample``) and closed-form statistics;
    - reverse sampling: ancestral (``p_sample*``), DDIM (``ddim_*``),
      PLMS (``plms_*``) and DPM-Solver (``dpm_solver_sample_loop``);
    - training losses (``loss``) and variational-bound evaluation.

    In every method ``model`` is a callable invoked as
    ``model(xt, scaled_t, **model_kwargs)``. When ``guide_scale`` is given
    (classifier-free guidance), ``model_kwargs`` must be a two-element list
    ``[conditional_kwargs, unconditional_kwargs]``.
    """

    def __init__(self,
                 betas,
                 mean_type='eps',
                 var_type='learned_range',
                 loss_type='mse',
                 rescale_timesteps=False):
        """Validate the configuration and precompute schedule constants.

        Args:
            betas: per-step noise variances, each in (0, 1]. Anything that is
                not already a ``torch.DoubleTensor`` is copied into a float64
                tensor.
            mean_type: what the model predicts: 'x0', 'x_{t-1}' or 'eps'.
            var_type: 'learned', 'learned_range', 'fixed_large' or
                'fixed_small'.
            loss_type: 'mse', 'rescaled_mse', 'kl', 'rescaled_kl', 'l1' or
                'rescaled_l1'.
            rescale_timesteps: if True, timesteps passed to the model are
                rescaled to the range of a 1000-step schedule.
        """
        # check input
        if not isinstance(betas, torch.DoubleTensor):
            betas = torch.tensor(betas, dtype=torch.float64)
        assert betas.min() > 0 and betas.max() <= 1
        assert mean_type in ['x0', 'x_{t-1}', 'eps']
        assert var_type in [
            'learned', 'learned_range', 'fixed_large', 'fixed_small'
        ]
        assert loss_type in [
            'mse', 'rescaled_mse', 'kl', 'rescaled_kl', 'l1', 'rescaled_l1'
        ]
        self.betas = betas
        self.num_timesteps = len(betas)
        self.mean_type = mean_type
        self.var_type = var_type
        self.loss_type = loss_type
        self.rescale_timesteps = rescale_timesteps

        # alphas: cumulative products, shifted by one step in each direction
        alphas = 1 - self.betas
        self.alphas_cumprod = torch.cumprod(alphas, dim=0)
        self.alphas_cumprod_prev = torch.cat(
            [alphas.new_ones([1]), self.alphas_cumprod[:-1]])
        self.alphas_cumprod_next = torch.cat(
            [self.alphas_cumprod[1:],
             alphas.new_zeros([1])])

        # q(x_t | x_{t-1})
        self.sqrt_alphas_cumprod = torch.sqrt(self.alphas_cumprod)
        self.sqrt_one_minus_alphas_cumprod = torch.sqrt(
            1.0 - self.alphas_cumprod)
        self.log_one_minus_alphas_cumprod = torch.log(
            1.0 - self.alphas_cumprod)
        self.sqrt_recip_alphas_cumprod = torch.sqrt(1.0 / self.alphas_cumprod)
        self.sqrt_recipm1_alphas_cumprod = torch.sqrt(
            1.0 / self.alphas_cumprod - 1)

        # q(x_{t-1} | x_t, x_0)
        self.posterior_variance = betas * (1.0 - self.alphas_cumprod_prev) / (
            1.0 - self.alphas_cumprod)
        # the variance is exactly 0 at t == 0, so clamp before taking the log
        self.posterior_log_variance_clipped = torch.log(
            self.posterior_variance.clamp(1e-20))
        self.posterior_mean_coef1 = betas * torch.sqrt(
            self.alphas_cumprod_prev) / (1.0 - self.alphas_cumprod)
        self.posterior_mean_coef2 = (
            1.0 - self.alphas_cumprod_prev) * torch.sqrt(alphas) / (
                1.0 - self.alphas_cumprod)

    def q_sample(self, x0, t, noise=None):
        """Sample x_t ~ q(x_t | x_0); fresh Gaussian noise is drawn if
        ``noise`` is None."""
        noise = torch.randn_like(x0) if noise is None else noise
        return _i(self.sqrt_alphas_cumprod, t, x0) * x0 + _i(
            self.sqrt_one_minus_alphas_cumprod, t, x0) * noise

    def q_mean_variance(self, x0, t):
        """Return (mean, variance, log-variance) of q(x_t | x_0)."""
        mu = _i(self.sqrt_alphas_cumprod, t, x0) * x0
        var = _i(1.0 - self.alphas_cumprod, t, x0)
        log_var = _i(self.log_one_minus_alphas_cumprod, t, x0)
        return mu, var, log_var

    def q_posterior_mean_variance(self, x0, xt, t):
        """Return (mean, variance, clipped log-variance) of
        q(x_{t-1} | x_t, x_0)."""
        mu = _i(self.posterior_mean_coef1, t, xt) * x0 + _i(
            self.posterior_mean_coef2, t, xt) * xt
        var = _i(self.posterior_variance, t, xt)
        log_var = _i(self.posterior_log_variance_clipped, t, xt)
        return mu, var, log_var

    @torch.no_grad()
    def p_sample(self,
                 xt,
                 t,
                 model,
                 model_kwargs={},
                 clamp=None,
                 percentile=None,
                 condition_fn=None,
                 guide_scale=None):
        """Draw x_{t-1} ~ p(x_{t-1} | x_t) with one ancestral step.

        If ``condition_fn`` is given, it is called as
        ``condition_fn(xt, scaled_t, **model_kwargs)`` and the mean is shifted
        by ``var * condition_fn(...)``.

        Returns:
            Tuple ``(x_{t-1}, predicted x0)``.
        """
        # predict distribution of p(x_{t-1} | x_t)
        mu, var, log_var, x0 = self.p_mean_variance(xt, t, model, model_kwargs,
                                                    clamp, percentile,
                                                    guide_scale)

        # random sample (with optional conditional function)
        noise = torch.randn_like(xt)
        shape = (-1, ) + ((1, ) * (xt.ndim - 1))
        mask = t.ne(0).float().view(*shape)  # no noise when t == 0
        if condition_fn is not None:
            grad = condition_fn(xt, self._scale_timesteps(t), **model_kwargs)
            mu = mu.float() + var * grad.float()
        xt_1 = mu + mask * torch.exp(0.5 * log_var) * noise
        return xt_1, x0

    @torch.no_grad()
    def p_sample_loop(self,
                      noise,
                      model,
                      model_kwargs={},
                      clamp=None,
                      percentile=None,
                      condition_fn=None,
                      guide_scale=None):
        """Run ancestral sampling from t = T-1 down to 0, starting at
        ``noise``, and return the final sample."""
        # prepare input
        b = noise.size(0)
        xt = noise

        # diffusion process
        for step in reversed(range(self.num_timesteps)):
            t = torch.full((b, ), step, dtype=torch.long, device=xt.device)
            xt, _ = self.p_sample(xt, t, model, model_kwargs, clamp,
                                  percentile, condition_fn, guide_scale)
        return xt

    def p_mean_variance(self,
                        xt,
                        t,
                        model,
                        model_kwargs={},
                        clamp=None,
                        percentile=None,
                        guide_scale=None):
        """Evaluate the model and derive the statistics of p(x_{t-1} | x_t).

        Args:
            xt: current noisy sample.
            t: long tensor of timesteps, one per batch element.
            model: callable ``model(xt, scaled_t, **kwargs)``.
            model_kwargs: kwargs dict, or a two-element list
                ``[conditional, unconditional]`` when ``guide_scale`` is set.
            clamp: if given (and ``percentile`` is None), x0 is clamped to
                ``[-clamp, clamp]``.
            percentile: if given, x0 is thresholded per sample at this
                quantile of ``|x0|`` (at least 1.0) and divided by it.
            guide_scale: classifier-free guidance scale; requires
                ``mean_type == 'eps'``.

        Returns:
            Tuple ``(mu, var, log_var, x0)``.
        """
        # predict distribution
        if guide_scale is None:
            out = model(xt, self._scale_timesteps(t), **model_kwargs)
        else:
            # classifier-free guidance
            # (model_kwargs[0]: conditional kwargs; model_kwargs[1]: non-conditional kwargs)
            assert isinstance(model_kwargs, list) and len(model_kwargs) == 2
            assert self.mean_type == 'eps'
            y_out = model(xt, self._scale_timesteps(t), **model_kwargs[0])
            u_out = model(xt, self._scale_timesteps(t), **model_kwargs[1])
            # Guide only the mean-prediction channels. With a learned
            # variance the second half of the channels parameterizes the
            # variance and is taken from the conditional output unchanged.
            # (Previously hard-coded to 3 channels.)
            if self.var_type.startswith('fixed'):
                dim = y_out.size(1)
            else:
                dim = y_out.size(1) // 2
            a = u_out[:, :dim]
            b = guide_scale * (y_out[:, :dim] - u_out[:, :dim])
            c = y_out[:, dim:]
            out = torch.cat([a + b, c], dim=1)

        # compute variance
        if self.var_type == 'learned':
            out, log_var = out.chunk(2, dim=1)
            var = torch.exp(log_var)
        elif self.var_type == 'learned_range':
            # the model outputs a fraction in [-1, 1] that interpolates, in
            # log space, between the posterior variance and beta
            out, fraction = out.chunk(2, dim=1)
            min_log_var = _i(self.posterior_log_variance_clipped, t, xt)
            max_log_var = _i(torch.log(self.betas), t, xt)
            fraction = (fraction + 1) / 2.0
            log_var = fraction * max_log_var + (1 - fraction) * min_log_var
            var = torch.exp(log_var)
        elif self.var_type == 'fixed_large':
            # beta_t, except at t == 0 where posterior_variance[1] is used
            var = _i(
                torch.cat([self.posterior_variance[1:2], self.betas[1:]]), t,
                xt)
            log_var = torch.log(var)
        elif self.var_type == 'fixed_small':
            var = _i(self.posterior_variance, t, xt)
            log_var = _i(self.posterior_log_variance_clipped, t, xt)

        # compute mean and x0
        if self.mean_type == 'x_{t-1}':
            mu = out  # x_{t-1}
            x0 = _i(1.0 / self.posterior_mean_coef1, t, xt) * mu - _i(
                self.posterior_mean_coef2 / self.posterior_mean_coef1, t,
                xt) * xt
        elif self.mean_type == 'x0':
            x0 = out
            mu, _, _ = self.q_posterior_mean_variance(x0, xt, t)
        elif self.mean_type == 'eps':
            x0 = _i(self.sqrt_recip_alphas_cumprod, t, xt) * xt - _i(
                self.sqrt_recipm1_alphas_cumprod, t, xt) * out
            mu, _, _ = self.q_posterior_mean_variance(x0, xt, t)

        # restrict the range of x0
        # NOTE(review): mu is computed before x0 is restricted, so the
        # restriction affects the returned x0 only -- confirm this is intended.
        if percentile is not None:
            assert percentile > 0 and percentile <= 1  # e.g., 0.995
            s = torch.quantile(
                x0.flatten(1).abs(), percentile, dim=1).clamp_(1.0).view(
                    (-1, ) + (1, ) * (x0.ndim - 1))
            x0 = torch.min(s, torch.max(-s, x0)) / s
        elif clamp is not None:
            x0 = x0.clamp(-clamp, clamp)
        return mu, var, log_var, x0

    @torch.no_grad()
    def dpm_solver_sample_loop(self,
                               noise,
                               model,
                               skip_type,
                               order,
                               method,
                               model_kwargs={},
                               clamp=None,
                               percentile=None,
                               condition_fn=None,
                               guide_scale=None,
                               dpm_solver_timesteps=20,
                               t_start=None,
                               t_end=None,
                               lower_order_final=True,
                               denoise_to_zero=False,
                               solver_type='dpm_solver'):
        r"""Sample using DPM-Solver-based method.
            - condition_fn: for classifier-based guidance (guided-diffusion).
            - guide_scale: for classifier-free guidance (glide/dalle-2).

        Please check all the parameters in `dpm_solver.sample` before using.
        """
        noise_schedule = NoiseScheduleVP(
            schedule='discrete', betas=self.betas.float())
        model_fn = model_wrapper_guided_diffusion(
            model=model,
            noise_schedule=noise_schedule,
            var_type=self.var_type,
            mean_type=self.mean_type,
            model_kwargs=model_kwargs,
            clamp=clamp,
            percentile=percentile,
            rescale_timesteps=self.rescale_timesteps,
            num_timesteps=self.num_timesteps,
            guide_scale=guide_scale,
            condition_fn=condition_fn,
        )
        dpm_solver = DPM_Solver(
            model_fn=model_fn,
            noise_schedule=noise_schedule,
        )
        xt = dpm_solver.sample(
            noise,
            steps=dpm_solver_timesteps,
            order=order,
            skip_type=skip_type,
            method=method,
            solver_type=solver_type,
            t_start=t_start,
            t_end=t_end,
            lower_order_final=lower_order_final,
            denoise_to_zero=denoise_to_zero)
        return xt

    @torch.no_grad()
    def ddim_sample(self,
                    xt,
                    t,
                    model,
                    model_kwargs={},
                    clamp=None,
                    percentile=None,
                    condition_fn=None,
                    guide_scale=None,
                    ddim_timesteps=20,
                    eta=0.0):
        """Take one DDIM step from ``t`` to ``t - stride``.

        ``stride`` is ``num_timesteps // ddim_timesteps``; ``eta`` scales the
        injected noise (``eta == 0`` gives a deterministic step).

        Returns:
            Tuple ``(x at the previous DDIM step, predicted x0)``.
        """
        stride = self.num_timesteps // ddim_timesteps

        # predict distribution of p(x_{t-1} | x_t)
        _, _, _, x0 = self.p_mean_variance(xt, t, model, model_kwargs, clamp,
                                           percentile, guide_scale)
        if condition_fn is not None:
            # x0 -> eps
            alpha = _i(self.alphas_cumprod, t, xt)
            eps = (_i(self.sqrt_recip_alphas_cumprod, t, xt) * xt - x0) / _i(
                self.sqrt_recipm1_alphas_cumprod, t, xt)
            eps = eps - (1 - alpha).sqrt() * condition_fn(
                xt, self._scale_timesteps(t), **model_kwargs)

            # eps -> x0
            x0 = _i(self.sqrt_recip_alphas_cumprod, t, xt) * xt - _i(
                self.sqrt_recipm1_alphas_cumprod, t, xt) * eps

        # derive variables
        eps = (_i(self.sqrt_recip_alphas_cumprod, t, xt) * xt - x0) / _i(
            self.sqrt_recipm1_alphas_cumprod, t, xt)
        alphas = _i(self.alphas_cumprod, t, xt)
        alphas_prev = _i(self.alphas_cumprod, (t - stride).clamp(0), xt)
        a = (1 - alphas_prev) / (1 - alphas)
        b = (1 - alphas / alphas_prev)
        sigmas = eta * torch.sqrt(a * b)

        # random sample
        noise = torch.randn_like(xt)
        direction = torch.sqrt(1 - alphas_prev - sigmas**2) * eps
        mask = t.ne(0).float().view(-1, *((1, ) * (xt.ndim - 1)))
        xt_1 = torch.sqrt(alphas_prev) * x0 + direction + mask * sigmas * noise
        return xt_1, x0

    @torch.no_grad()
    def ddim_sample_loop(self,
                         noise,
                         model,
                         model_kwargs={},
                         clamp=None,
                         percentile=None,
                         condition_fn=None,
                         guide_scale=None,
                         ddim_timesteps=20,
                         eta=0.0):
        """Run DDIM sampling over a strided subset of timesteps, starting
        at ``noise``, and return the final sample."""
        # prepare input
        b = noise.size(0)
        xt = noise

        # diffusion process (TODO: clamp is inaccurate! Consider replacing the stride by explicit prev/next steps)
        steps = (1 + torch.arange(0, self.num_timesteps,
                                  self.num_timesteps // ddim_timesteps)).clamp(
                                      0, self.num_timesteps - 1).flip(0)
        for step in steps.tolist():
            t = torch.full((b, ), step, dtype=torch.long, device=xt.device)
            xt, _ = self.ddim_sample(xt, t, model, model_kwargs, clamp,
                                     percentile, condition_fn, guide_scale,
                                     ddim_timesteps, eta)
        return xt

    @torch.no_grad()
    def ddim_reverse_sample(self,
                            xt,
                            t,
                            model,
                            model_kwargs={},
                            clamp=None,
                            percentile=None,
                            guide_scale=None,
                            ddim_timesteps=20):
        """Take one deterministic DDIM step in the noising direction, from
        ``t`` to ``t + stride``.

        Returns:
            Tuple ``(x at the next DDIM step, predicted x0)``.
        """
        stride = self.num_timesteps // ddim_timesteps

        # predict distribution of p(x_{t-1} | x_t)
        _, _, _, x0 = self.p_mean_variance(xt, t, model, model_kwargs, clamp,
                                           percentile, guide_scale)

        # derive variables
        eps = (_i(self.sqrt_recip_alphas_cumprod, t, xt) * xt - x0) / _i(
            self.sqrt_recipm1_alphas_cumprod, t, xt)
        # a trailing zero lets index T (one past the schedule) be looked up
        alphas_next = _i(
            torch.cat(
                [self.alphas_cumprod,
                 self.alphas_cumprod.new_zeros([1])]),
            (t + stride).clamp(0, self.num_timesteps), xt)

        # reverse sample
        mu = torch.sqrt(alphas_next) * x0 + torch.sqrt(1 - alphas_next) * eps
        return mu, x0

    @torch.no_grad()
    def ddim_reverse_sample_loop(self,
                                 x0,
                                 model,
                                 model_kwargs={},
                                 clamp=None,
                                 percentile=None,
                                 guide_scale=None,
                                 ddim_timesteps=20):
        """Apply ``ddim_reverse_sample`` over increasing strided timesteps,
        starting from ``x0``, and return the result."""
        # prepare input
        b = x0.size(0)
        xt = x0

        # reconstruction steps
        steps = torch.arange(0, self.num_timesteps,
                             self.num_timesteps // ddim_timesteps)
        for step in steps.tolist():
            t = torch.full((b, ), step, dtype=torch.long, device=xt.device)
            xt, _ = self.ddim_reverse_sample(xt, t, model, model_kwargs, clamp,
                                             percentile, guide_scale,
                                             ddim_timesteps)
        return xt

    @torch.no_grad()
    def plms_sample(self,
                    xt,
                    t,
                    model,
                    model_kwargs={},
                    clamp=None,
                    percentile=None,
                    condition_fn=None,
                    guide_scale=None,
                    plms_timesteps=20,
                    eps_cache=None):
        """Take one PLMS (pseudo linear multistep) step.

        Args:
            eps_cache: list of eps predictions from earlier steps, oldest
                first. Its length selects the order of the multistep formula.
                ``None`` is treated as an empty history. The list is read
                only; the caller is responsible for updating it.

        Returns:
            Tuple ``(x at the previous step, predicted x0, eps at this step)``.
        """
        # eps_cache used to be read without being a parameter, which made
        # every call fail; it is now an explicit, optional argument.
        eps_cache = [] if eps_cache is None else eps_cache
        stride = self.num_timesteps // plms_timesteps

        # function for compute eps
        def compute_eps(xt, t):
            # predict distribution of p(x_{t-1} | x_t)
            _, _, _, x0 = self.p_mean_variance(xt, t, model, model_kwargs,
                                               clamp, percentile, guide_scale)

            # condition
            if condition_fn is not None:
                # x0 -> eps
                alpha = _i(self.alphas_cumprod, t, xt)
                eps = (_i(self.sqrt_recip_alphas_cumprod, t, xt) * xt
                       - x0) / _i(self.sqrt_recipm1_alphas_cumprod, t, xt)
                eps = eps - (1 - alpha).sqrt() * condition_fn(
                    xt, self._scale_timesteps(t), **model_kwargs)

                # eps -> x0
                x0 = _i(self.sqrt_recip_alphas_cumprod, t, xt) * xt - _i(
                    self.sqrt_recipm1_alphas_cumprod, t, xt) * eps

            # derive eps
            eps = (_i(self.sqrt_recip_alphas_cumprod, t, xt) * xt - x0) / _i(
                self.sqrt_recipm1_alphas_cumprod, t, xt)
            return eps

        # function for compute x_0 and x_{t-1}
        # (uses the outer ``xt``, i.e. the sample this step started from)
        def compute_x0(eps, t):
            # eps -> x0
            x0 = _i(self.sqrt_recip_alphas_cumprod, t, xt) * xt - _i(
                self.sqrt_recipm1_alphas_cumprod, t, xt) * eps

            # deterministic sample
            alphas_prev = _i(self.alphas_cumprod, (t - stride).clamp(0), xt)
            direction = torch.sqrt(1 - alphas_prev) * eps
            xt_1 = torch.sqrt(alphas_prev) * x0 + direction
            return xt_1, x0

        # PLMS sample
        eps = compute_eps(xt, t)
        if len(eps_cache) == 0:
            # 2nd order pseudo improved Euler
            xt_1, x0 = compute_x0(eps, t)
            eps_next = compute_eps(xt_1, (t - stride).clamp(0))
            eps_prime = (eps + eps_next) / 2.0
        elif len(eps_cache) == 1:
            # 2nd order pseudo linear multistep (Adams-Bashforth)
            eps_prime = (3 * eps - eps_cache[-1]) / 2.0
        elif len(eps_cache) == 2:
            # 3rd order pseudo linear multistep (Adams-Bashforth)
            eps_prime = (23 * eps - 16 * eps_cache[-1]
                         + 5 * eps_cache[-2]) / 12.0
        elif len(eps_cache) >= 3:
            # 4th order pseudo linear multistep (Adams-Bashforth)
            eps_prime = (55 * eps - 59 * eps_cache[-1] + 37 * eps_cache[-2]
                         - 9 * eps_cache[-3]) / 24.0
        xt_1, x0 = compute_x0(eps_prime, t)
        return xt_1, x0, eps

    @torch.no_grad()
    def plms_sample_loop(self,
                         noise,
                         model,
                         model_kwargs={},
                         clamp=None,
                         percentile=None,
                         condition_fn=None,
                         guide_scale=None,
                         plms_timesteps=20):
        """Run PLMS sampling over a strided subset of timesteps, starting at
        ``noise``, and return the final sample."""
        # prepare input
        b = noise.size(0)
        xt = noise

        # diffusion process
        steps = (1 + torch.arange(0, self.num_timesteps,
                                  self.num_timesteps // plms_timesteps)).clamp(
                                      0, self.num_timesteps - 1).flip(0)
        eps_cache = []
        for step in steps.tolist():
            # PLMS sampling step
            t = torch.full((b, ), step, dtype=torch.long, device=xt.device)
            xt, _, eps = self.plms_sample(xt, t, model, model_kwargs, clamp,
                                          percentile, condition_fn,
                                          guide_scale, plms_timesteps,
                                          eps_cache)

            # update eps cache (keep at most the three latest predictions)
            eps_cache.append(eps)
            if len(eps_cache) >= 4:
                eps_cache.pop(0)
        return xt

    def loss(self, x0, t, model, model_kwargs={}, noise=None):
        """Compute the per-sample training loss selected by ``loss_type``.

        ``x0`` is noised to timestep ``t`` with ``q_sample``. For the 'kl'
        variants the loss is the variational bound term. For the 'mse' and
        'l1' variants it is the regression error against the target implied
        by ``mean_type``, plus a variational term when the variance is
        learned.

        Returns:
            1-D tensor with one loss value per batch element.
        """
        noise = torch.randn_like(x0) if noise is None else noise
        xt = self.q_sample(x0, t, noise=noise)

        # compute loss
        if self.loss_type in ['kl', 'rescaled_kl']:
            loss, _ = self.variational_lower_bound(x0, xt, t, model,
                                                   model_kwargs)
            if self.loss_type == 'rescaled_kl':
                loss = loss * self.num_timesteps
        elif self.loss_type in ['mse', 'rescaled_mse', 'l1', 'rescaled_l1']:
            out = model(xt, self._scale_timesteps(t), **model_kwargs)

            # VLB for variation
            loss_vlb = 0.0
            if self.var_type in ['learned', 'learned_range']:
                out, var = out.chunk(2, dim=1)
                frozen = torch.cat([
                    out.detach(), var
                ], dim=1)  # learn var without affecting the prediction of mean
                loss_vlb, _ = self.variational_lower_bound(
                    x0, xt, t, model=lambda *args, **kwargs: frozen)
                if self.loss_type.startswith('rescaled_'):
                    loss_vlb = loss_vlb * self.num_timesteps / 1000.0

            # MSE/L1 for x0/eps
            target = {
                'eps': noise,
                'x0': x0,
                'x_{t-1}': self.q_posterior_mean_variance(x0, xt, t)[0]
            }[self.mean_type]
            loss = (out - target).pow(1 if self.loss_type.endswith('l1') else 2
                                      ).abs().flatten(1).mean(dim=1)

            # total loss
            loss = loss + loss_vlb
        return loss

    def variational_lower_bound(self,
                                x0,
                                xt,
                                t,
                                model,
                                model_kwargs={},
                                clamp=None,
                                percentile=None):
        """Compute one term of the variational bound, in bits per dimension.

        For ``t == 0`` the term is the discretized negative log-likelihood of
        the ground-truth ``x0`` under p(x_0 | x_1). Otherwise it is
        KL(q(x_{t-1} | x_t, x_0) || p(x_{t-1} | x_t)).

        Returns:
            Tuple ``(vlb, pred_x0)``: the per-sample bound term and the
            model's prediction of x0.
        """
        # compute groundtruth and predicted distributions
        mu1, _, log_var1 = self.q_posterior_mean_variance(x0, xt, t)
        # keep the prediction separate from the ground-truth x0: the NLL below
        # must score the data, not the model's own reconstruction
        mu2, _, log_var2, pred_x0 = self.p_mean_variance(
            xt, t, model, model_kwargs, clamp, percentile)

        # compute KL loss
        kl = kl_divergence(mu1, log_var1, mu2, log_var2)
        kl = kl.flatten(1).mean(dim=1) / math.log(2.0)

        # compute discretized NLL loss (for p(x0 | x1) only)
        # fixed variance types yield a (B, 1, ..., 1) log-variance, so expand
        # it to the shape the likelihood function asserts on
        nll = -discretized_gaussian_log_likelihood(
            x0, mean=mu2, log_scale=(0.5 * log_var2).expand_as(mu2))
        nll = nll.flatten(1).mean(dim=1) / math.log(2.0)

        # NLL for p(x0 | x1) and KL otherwise
        vlb = torch.where(t == 0, nll, kl)
        return vlb, pred_x0

    @torch.no_grad()
    def variational_lower_bound_loop(self,
                                     x0,
                                     model,
                                     model_kwargs={},
                                     clamp=None,
                                     percentile=None):
        """Evaluate the full variational bound and per-step error metrics.

        Returns:
            Dict with ``'vlb'``, ``'mse'`` and ``'x0_mse'`` of shape
            ``(batch, num_timesteps)`` (columns ordered from t = T-1 down to
            0), plus ``'prior_bits_per_dim'`` and ``'total_bits_per_dim'`` of
            shape ``(batch,)``.
        """
        # prepare input and output
        b = x0.size(0)
        metrics = {'vlb': [], 'mse': [], 'x0_mse': []}

        # loop
        for step in reversed(range(self.num_timesteps)):
            # compute VLB
            t = torch.full((b, ), step, dtype=torch.long, device=x0.device)
            noise = torch.randn_like(x0)
            xt = self.q_sample(x0, t, noise)
            vlb, pred_x0 = self.variational_lower_bound(
                x0, xt, t, model, model_kwargs, clamp, percentile)

            # predict eps from the *predicted* x0; using the ground-truth x0
            # would simply recover ``noise`` and make the metric meaningless
            eps = (_i(self.sqrt_recip_alphas_cumprod, t, xt) * xt
                   - pred_x0) / _i(self.sqrt_recipm1_alphas_cumprod, t, xt)

            # collect metrics
            metrics['vlb'].append(vlb)
            metrics['x0_mse'].append(
                (pred_x0 - x0).square().flatten(1).mean(dim=1))
            metrics['mse'].append(
                (eps - noise).square().flatten(1).mean(dim=1))
        metrics = {k: torch.stack(v, dim=1) for k, v in metrics.items()}

        # compute the prior KL term for VLB, measured in bits-per-dim
        # (evaluated at the last timestep, not at the t left over from the
        # loop above, which is 0)
        t = torch.full((b, ),
                       self.num_timesteps - 1,
                       dtype=torch.long,
                       device=x0.device)
        mu, _, log_var = self.q_mean_variance(x0, t)
        kl_prior = kl_divergence(mu, log_var, torch.zeros_like(mu),
                                 torch.zeros_like(log_var))
        kl_prior = kl_prior.flatten(1).mean(dim=1) / math.log(2.0)

        # update metrics
        metrics['prior_bits_per_dim'] = kl_prior
        metrics['total_bits_per_dim'] = metrics['vlb'].sum(dim=1) + kl_prior
        return metrics

    def _scale_timesteps(self, t):
        """Map ``t`` onto a 1000-step scale when ``rescale_timesteps`` is
        set; otherwise return ``t`` unchanged."""
        if self.rescale_timesteps:
            return t.float() * 1000.0 / self.num_timesteps
        return t
|