| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650 |
- # Copyright (c) Alibaba, Inc. and its affiliates.
- import numpy as np
- import torch
- import torch.nn as nn
- import torch.nn.functional as F
# Names exported by ``from <module> import *``: only AutoencoderKL is listed;
# the other classes and helpers in this module are not part of this list.
__all__ = ['AutoencoderKL']
def nonlinearity(x):
    """Apply the swish (SiLU) activation elementwise: ``x * sigmoid(x)``.

    Args:
        x: input tensor of any shape.

    Returns:
        Tensor of the same shape as ``x``.
    """
    gate = torch.sigmoid(x)
    return x * gate
def Normalize(in_channels, num_groups=32):
    """Build the GroupNorm layer used throughout this autoencoder.

    Args:
        in_channels: number of channels to normalise. ``torch.nn.GroupNorm``
            requires it to be divisible by ``num_groups``.
        num_groups: number of channel groups.

    Returns:
        A ``torch.nn.GroupNorm`` with ``eps=1e-6`` and learnable affine
        parameters.
    """
    norm_kwargs = dict(
        num_groups=num_groups,
        num_channels=in_channels,
        eps=1e-6,
        affine=True,
    )
    return torch.nn.GroupNorm(**norm_kwargs)
class DiagonalGaussianDistribution(object):
    """Diagonal Gaussian parameterised by a stacked (mean, logvar) tensor.

    ``parameters`` is split into two equal halves along dim 1: the first half
    is the mean, the second the log-variance. The log-variance is clamped to
    [-30, 20] so that ``exp`` stays finite.

    Args:
        parameters: tensor whose dim 1 holds mean and log-variance
            concatenated. ``kl`` reduces over dims 1, 2, 3, so a 4-D tensor
            is expected there.
        deterministic: if True, ``std``/``var`` are zero, so ``sample`` returns
            the mean and ``kl``/``nll`` return a zero tensor of shape (1,).
    """

    def __init__(self, parameters, deterministic=False):
        self.parameters = parameters
        self.mean, self.logvar = torch.chunk(parameters, 2, dim=1)
        self.logvar = torch.clamp(self.logvar, -30.0, 20.0)
        self.deterministic = deterministic
        self.std = torch.exp(0.5 * self.logvar)
        self.var = torch.exp(self.logvar)
        if self.deterministic:
            # zeros_like already matches the mean's device and dtype.
            self.var = self.std = torch.zeros_like(self.mean)

    def _zero(self):
        """Return a shape-(1,) zero on the mean's device and dtype."""
        return torch.zeros(1, device=self.mean.device, dtype=self.mean.dtype)

    def sample(self):
        """Draw ``mean + std * eps`` with ``eps ~ N(0, I)``."""
        # randn_like draws directly on the mean's device with the mean's dtype:
        # no CPU allocation + host->device copy, and no float32 promotion of
        # half-precision means.
        noise = torch.randn_like(self.mean)
        return self.mean + self.std * noise

    def kl(self, other=None):
        """KL divergence to ``other``, or to N(0, I) when ``other`` is None.

        Returns:
            Per-sample KL summed over dims 1, 2, 3, or a shape-(1,) zero on the
            distribution's device when ``deterministic`` is set.
        """
        if self.deterministic:
            return self._zero()
        if other is None:
            return 0.5 * torch.sum(
                torch.pow(self.mean, 2) + self.var - 1.0 - self.logvar,
                dim=[1, 2, 3])
        return 0.5 * torch.sum(
            torch.pow(self.mean - other.mean, 2) / other.var
            + self.var / other.var - 1.0 - self.logvar + other.logvar,
            dim=[1, 2, 3])

    def nll(self, sample, dims=(1, 2, 3)):
        """Gaussian negative log-likelihood of ``sample``, summed over ``dims``.

        Returns a shape-(1,) zero on the distribution's device when
        ``deterministic`` is set.
        """
        if self.deterministic:
            return self._zero()
        logtwopi = np.log(2.0 * np.pi)
        return 0.5 * torch.sum(
            logtwopi + self.logvar
            + torch.pow(sample - self.mean, 2) / self.var,
            dim=dims)

    def mode(self):
        """Return the distribution's mode (its mean)."""
        return self.mean
class Downsample(nn.Module):
    """Halve the spatial resolution of an NCHW feature map.

    NOTE(review): an identical ``Downsample`` class is defined again further
    down in this file; that later definition is the one bound to the name
    after import, so this copy is shadowed.

    Args:
        in_channels: channel count (preserved by the layer).
        with_conv: if True, pad the right/bottom edges by one pixel and apply
            a learned stride-2 3x3 conv; otherwise use 2x2 average pooling.
    """

    def __init__(self, in_channels, with_conv):
        super().__init__()
        self.with_conv = with_conv
        if not with_conv:
            return
        # Conv2d padding is symmetric, so the asymmetric (0, 1, 0, 1) pad is
        # applied manually in forward().
        self.conv = torch.nn.Conv2d(
            in_channels, in_channels, kernel_size=3, stride=2, padding=0)

    def forward(self, x):
        if not self.with_conv:
            return torch.nn.functional.avg_pool2d(x, kernel_size=2, stride=2)
        padded = torch.nn.functional.pad(
            x, (0, 1, 0, 1), mode='constant', value=0)
        return self.conv(padded)
class ResnetBlock(nn.Module):
    """Residual block: (GroupNorm -> swish -> 3x3 conv) twice, plus a skip.

    An optional embedding ``temb`` is passed through swish, projected to
    ``out_channels`` and added as a per-channel offset after the first conv.
    When the channel count changes, the skip path is projected with a 3x3
    conv (``conv_shortcut=True``) or a 1x1 conv.

    Args:
        in_channels: input channel count.
        out_channels: output channel count; defaults to ``in_channels``.
        conv_shortcut: use a 3x3 instead of a 1x1 conv on the skip path.
        dropout: dropout probability applied before the second conv.
        temb_channels: embedding width; ``temb_proj`` is only created when
            this is > 0.
    """

    def __init__(self,
                 *,
                 in_channels,
                 out_channels=None,
                 conv_shortcut=False,
                 dropout,
                 temb_channels=512):
        super().__init__()
        if out_channels is None:
            out_channels = in_channels
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.use_conv_shortcut = conv_shortcut

        # Creation order is kept as-is: it determines parameter init order.
        self.norm1 = Normalize(in_channels)
        self.conv1 = torch.nn.Conv2d(
            in_channels, out_channels, kernel_size=3, stride=1, padding=1)
        if temb_channels > 0:
            self.temb_proj = torch.nn.Linear(temb_channels, out_channels)
        self.norm2 = Normalize(out_channels)
        self.dropout = torch.nn.Dropout(dropout)
        self.conv2 = torch.nn.Conv2d(
            out_channels, out_channels, kernel_size=3, stride=1, padding=1)

        if in_channels == out_channels:
            return
        # The skip path must be projected to the new channel count.
        if conv_shortcut:
            self.conv_shortcut = torch.nn.Conv2d(
                in_channels, out_channels, kernel_size=3, stride=1, padding=1)
        else:
            self.nin_shortcut = torch.nn.Conv2d(
                in_channels, out_channels, kernel_size=1, stride=1, padding=0)

    def forward(self, x, temb):
        h = self.conv1(nonlinearity(self.norm1(x)))
        if temb is not None:
            # Broadcast the (batch, channels) projection over H and W.
            h = h + self.temb_proj(nonlinearity(temb))[:, :, None, None]
        h = self.conv2(self.dropout(nonlinearity(self.norm2(h))))

        skip = x
        if self.in_channels != self.out_channels:
            project = (self.conv_shortcut
                       if self.use_conv_shortcut else self.nin_shortcut)
            skip = project(x)
        return skip + h
class AttnBlock(nn.Module):
    """Single-head spatial self-attention with a residual connection.

    NOTE(review): an identical ``AttnBlock`` (marked ``# noqa``) is defined
    again immediately after this one; that later definition is the one bound
    to the name after import, so this copy is shadowed.

    The input is GroupNorm-ed and projected to query/key/value by 1x1 convs.
    Every spatial position attends over all positions (softmax of scaled dot
    products), and the result goes through a 1x1 output conv and is added
    back to the input.
    """

    def __init__(self, in_channels):
        super().__init__()
        self.in_channels = in_channels
        self.norm = Normalize(in_channels)

        def pointwise():
            return torch.nn.Conv2d(
                in_channels, in_channels, kernel_size=1, stride=1, padding=0)

        self.q = pointwise()
        self.k = pointwise()
        self.v = pointwise()
        self.proj_out = pointwise()

    def forward(self, x):
        normed = self.norm(x)
        query, key, value = self.q(normed), self.k(normed), self.v(normed)
        b, c, h, w = query.shape
        n = h * w
        # scores[b, i, j] = <query_i, key_j>, i.e. (b, n, c) @ (b, c, n).
        scores = torch.bmm(query.reshape(b, c, n).permute(0, 2, 1),
                           key.reshape(b, c, n))
        scores = scores * (int(c)**(-0.5))
        weights = torch.nn.functional.softmax(scores, dim=2)
        # out[b, c, i] = sum_j value[b, c, j] * weights[b, i, j]
        out = torch.bmm(value.reshape(b, c, n), weights.permute(0, 2, 1))
        out = self.proj_out(out.reshape(b, c, h, w))
        return x + out
class AttnBlock(nn.Module):  # noqa
    """Single-head spatial self-attention with a residual connection.

    NOTE(review): this re-defines an identical ``AttnBlock`` that appears just
    above it. As the later definition, this is the class the name refers to
    after import; one of the two copies could be removed.

    Pipeline: GroupNorm -> 1x1 conv q/k/v -> softmax(q k^T / sqrt(C)) over all
    H*W positions -> weighted sum of v -> 1x1 ``proj_out`` -> add input.
    """

    def __init__(self, in_channels):
        super().__init__()
        self.in_channels = in_channels
        self.norm = Normalize(in_channels)
        conv_args = dict(kernel_size=1, stride=1, padding=0)
        self.q = torch.nn.Conv2d(in_channels, in_channels, **conv_args)
        self.k = torch.nn.Conv2d(in_channels, in_channels, **conv_args)
        self.v = torch.nn.Conv2d(in_channels, in_channels, **conv_args)
        self.proj_out = torch.nn.Conv2d(in_channels, in_channels, **conv_args)

    def forward(self, x):
        feats = self.norm(x)
        q_map = self.q(feats)
        k_map = self.k(feats)
        v_map = self.v(feats)

        batch, chans, height, width = q_map.shape
        tokens = height * width
        q_seq = q_map.reshape(batch, chans, tokens).permute(0, 2, 1)
        k_seq = k_map.reshape(batch, chans, tokens)
        # attn[b, i, j]: how much query position i attends to key position j.
        attn = torch.bmm(q_seq, k_seq) * (int(chans)**(-0.5))
        attn = torch.nn.functional.softmax(attn, dim=2)

        v_seq = v_map.reshape(batch, chans, tokens)
        mixed = torch.bmm(v_seq, attn.permute(0, 2, 1))
        mixed = mixed.reshape(batch, chans, height, width)
        return x + self.proj_out(mixed)
class Upsample(nn.Module):
    """Double the spatial resolution with nearest-neighbour interpolation.

    Args:
        in_channels: channel count (preserved by the layer).
        with_conv: if True, follow the interpolation with a 3x3 conv
            (stride 1, padding 1), which keeps the upsampled spatial size.
    """

    def __init__(self, in_channels, with_conv):
        super().__init__()
        self.with_conv = with_conv
        if with_conv:
            self.conv = torch.nn.Conv2d(
                in_channels, in_channels, kernel_size=3, stride=1, padding=1)

    def forward(self, x):
        upsampled = torch.nn.functional.interpolate(
            x, scale_factor=2.0, mode='nearest')
        return self.conv(upsampled) if self.with_conv else upsampled
class Downsample(nn.Module):  # noqa
    """2x spatial downsampling by strided conv or average pooling.

    NOTE(review): verbatim duplicate of the ``Downsample`` class defined
    earlier in this file. As the later definition, this is the class the name
    refers to after import; one of the two copies could be removed.

    Args:
        in_channels: channel count (preserved by the layer).
        with_conv: if True, zero-pad right/bottom by one pixel and apply a
            learned 3x3 stride-2 conv; otherwise 2x2 average pooling.
    """

    def __init__(self, in_channels, with_conv):
        super().__init__()
        self.with_conv = with_conv
        if self.with_conv:
            # Padding is done in forward(): Conv2d only pads symmetrically,
            # while this layer pads the right/bottom edges only.
            self.conv = torch.nn.Conv2d(in_channels, in_channels,
                                        kernel_size=3, stride=2, padding=0)

    def forward(self, x):
        if self.with_conv:
            right_bottom_pad = (0, 1, 0, 1)
            return self.conv(
                F.pad(x, right_bottom_pad, mode='constant', value=0))
        return F.avg_pool2d(x, kernel_size=2, stride=2)
class Encoder(nn.Module):
    """Convolutional encoder from an image to latent moments.

    Layout: ``conv_in`` -> for each level of ``ch_mult``: ``num_res_blocks``
    ResnetBlocks (each followed by an AttnBlock when the current resolution
    is in ``attn_resolutions``) and, except at the last level, a Downsample
    -> middle (ResnetBlock, AttnBlock, ResnetBlock) -> GroupNorm -> swish ->
    ``conv_out``.

    Args:
        ch: base channel width; level ``i`` uses ``ch * ch_mult[i]`` channels.
        out_ch: accepted but not used by this class.
        ch_mult: per-level channel multipliers; its length is the number of
            resolution levels.
        num_res_blocks: ResnetBlocks per level.
        attn_resolutions: spatial sizes at which attention is inserted. The
            size is tracked from ``resolution`` and halved after each
            downsample.
        dropout: dropout probability inside the ResnetBlocks.
        resamp_with_conv: learned conv downsampling instead of avg pooling.
        in_channels: channels of the input image.
        resolution: input spatial size, used only to track ``curr_res``.
        z_channels: latent channel count.
        double_z: if True, ``conv_out`` emits ``2 * z_channels`` channels.
        use_linear_attn, attn_type, **ignore_kwargs: accepted but unused.
    """

    def __init__(self,
                 *,
                 ch,
                 out_ch,
                 ch_mult=(1, 2, 4, 8),
                 num_res_blocks,
                 attn_resolutions,
                 dropout=0.0,
                 resamp_with_conv=True,
                 in_channels,
                 resolution,
                 z_channels,
                 double_z=True,
                 use_linear_attn=False,
                 attn_type='vanilla',
                 **ignore_kwargs):
        super().__init__()
        self.ch = ch
        self.temb_ch = 0
        self.num_resolutions = len(ch_mult)
        self.num_res_blocks = num_res_blocks
        self.resolution = resolution
        self.in_channels = in_channels
        # downsampling
        self.conv_in = torch.nn.Conv2d(
            in_channels, self.ch, kernel_size=3, stride=1, padding=1)
        curr_res = resolution
        # Level i takes ch * in_ch_mult[i] channels in, i.e. the previous
        # level's output width (or ch for the first level).
        in_ch_mult = (1, ) + tuple(ch_mult)
        self.in_ch_mult = in_ch_mult
        self.down = nn.ModuleList()
        for i_level in range(self.num_resolutions):
            block = nn.ModuleList()
            attn = nn.ModuleList()
            block_in = ch * in_ch_mult[i_level]
            block_out = ch * ch_mult[i_level]
            for _ in range(self.num_res_blocks):
                block.append(
                    ResnetBlock(
                        in_channels=block_in,
                        out_channels=block_out,
                        temb_channels=self.temb_ch,
                        dropout=dropout))
                block_in = block_out
                if curr_res in attn_resolutions:
                    attn.append(AttnBlock(block_in))
            down = nn.Module()
            down.block = block
            down.attn = attn
            if i_level != self.num_resolutions - 1:
                down.downsample = Downsample(block_in, resamp_with_conv)
                curr_res = curr_res // 2
            self.down.append(down)
        # middle
        self.mid = nn.Module()
        self.mid.block_1 = ResnetBlock(
            in_channels=block_in,
            out_channels=block_in,
            temb_channels=self.temb_ch,
            dropout=dropout)
        self.mid.attn_1 = AttnBlock(block_in)
        self.mid.block_2 = ResnetBlock(
            in_channels=block_in,
            out_channels=block_in,
            temb_channels=self.temb_ch,
            dropout=dropout)
        # end
        self.norm_out = Normalize(block_in)
        self.conv_out = torch.nn.Conv2d(
            block_in,
            2 * z_channels if double_z else z_channels,
            kernel_size=3,
            stride=1,
            padding=1)

    def forward(self, x):
        """Encode ``x`` and return the ``conv_out`` feature map."""
        # No timestep conditioning in this encoder.
        temb = None
        # Only the running activation is kept. Earlier feature maps are never
        # read again (no skip connections), so storing them in a list would
        # only extend their lifetime and memory footprint.
        h = self.conv_in(x)
        for i_level in range(self.num_resolutions):
            level = self.down[i_level]
            for i_block in range(self.num_res_blocks):
                h = level.block[i_block](h, temb)
                if len(level.attn) > 0:
                    h = level.attn[i_block](h)
            if i_level != self.num_resolutions - 1:
                h = level.downsample(h)
        # middle
        h = self.mid.block_1(h, temb)
        h = self.mid.attn_1(h)
        h = self.mid.block_2(h, temb)
        # end
        h = self.norm_out(h)
        h = nonlinearity(h)
        h = self.conv_out(h)
        return h
class Decoder(nn.Module):
    """Convolutional decoder from a latent map back to image space.

    Layout: ``conv_in`` -> middle (ResnetBlock, AttnBlock, ResnetBlock) ->
    for each level from coarsest to finest: ``num_res_blocks + 1``
    ResnetBlocks (each followed by an AttnBlock when the current resolution
    is in ``attn_resolutions``) and, except at level 0, an Upsample ->
    GroupNorm -> swish -> ``conv_out`` (optionally followed by ``tanh``).

    Args:
        ch: base channel width; level ``i`` uses ``ch * ch_mult[i]`` channels.
        out_ch: channels produced by ``conv_out``.
        ch_mult: per-level channel multipliers.
        num_res_blocks: the number of ResnetBlocks per level is this plus one.
        attn_resolutions: spatial sizes at which attention is inserted.
        dropout: dropout probability inside the ResnetBlocks.
        resamp_with_conv: add a 3x3 conv after each nearest upsample.
        in_channels: stored as an attribute; not used to build layers.
        resolution: output spatial size; the latent size is
            ``resolution // 2**(len(ch_mult) - 1)``.
        z_channels: latent channel count consumed by ``conv_in``.
        give_pre_end: return features before the final norm/act/conv.
        tanh_out: apply ``tanh`` to the output.
        use_linear_attn, attn_type, **ignorekwargs: accepted but unused.
    """

    def __init__(self,
                 *,
                 ch,
                 out_ch,
                 ch_mult=(1, 2, 4, 8),
                 num_res_blocks,
                 attn_resolutions,
                 dropout=0.0,
                 resamp_with_conv=True,
                 in_channels,
                 resolution,
                 z_channels,
                 give_pre_end=False,
                 tanh_out=False,
                 use_linear_attn=False,
                 attn_type='vanilla',
                 **ignorekwargs):
        super().__init__()
        self.ch = ch
        self.temb_ch = 0
        self.num_resolutions = len(ch_mult)
        self.num_res_blocks = num_res_blocks
        self.resolution = resolution
        self.in_channels = in_channels
        self.give_pre_end = give_pre_end
        self.tanh_out = tanh_out

        # Channel width and spatial size at the coarsest level.
        block_in = ch * ch_mult[-1]
        curr_res = resolution // 2**(self.num_resolutions - 1)
        self.z_shape = (1, z_channels, curr_res, curr_res)
        print('Working with z of shape {} = {} dimensions.'.format(
            self.z_shape, np.prod(self.z_shape)))

        self.conv_in = torch.nn.Conv2d(
            z_channels, block_in, kernel_size=3, stride=1, padding=1)

        def make_resblock(c_in, c_out):
            return ResnetBlock(
                in_channels=c_in,
                out_channels=c_out,
                temb_channels=self.temb_ch,
                dropout=dropout)

        self.mid = nn.Module()
        self.mid.block_1 = make_resblock(block_in, block_in)
        self.mid.attn_1 = AttnBlock(block_in)
        self.mid.block_2 = make_resblock(block_in, block_in)

        # Levels are built coarsest-first (which fixes the order in which
        # parameters are created and initialised) and then stored
        # finest-first, so that self.up[i] corresponds to ch_mult[i].
        levels = []
        for i_level in reversed(range(self.num_resolutions)):
            resblocks = nn.ModuleList()
            attns = nn.ModuleList()
            block_out = ch * ch_mult[i_level]
            for _ in range(self.num_res_blocks + 1):
                resblocks.append(make_resblock(block_in, block_out))
                block_in = block_out
                if curr_res in attn_resolutions:
                    attns.append(AttnBlock(block_in))
            level = nn.Module()
            level.block = resblocks
            level.attn = attns
            if i_level != 0:
                level.upsample = Upsample(block_in, resamp_with_conv)
                curr_res *= 2
            levels.append(level)
        self.up = nn.ModuleList(reversed(levels))

        self.norm_out = Normalize(block_in)
        self.conv_out = torch.nn.Conv2d(
            block_in, out_ch, kernel_size=3, stride=1, padding=1)

    def forward(self, z):
        """Decode latent ``z``; records its shape in ``self.last_z_shape``."""
        self.last_z_shape = z.shape
        temb = None  # no timestep conditioning in this decoder

        h = self.conv_in(z)
        h = self.mid.block_1(h, temb)
        h = self.mid.attn_1(h)
        h = self.mid.block_2(h, temb)

        for i_level in reversed(range(self.num_resolutions)):
            level = self.up[i_level]
            for i_block in range(self.num_res_blocks + 1):
                h = level.block[i_block](h, temb)
                if len(level.attn) > 0:
                    h = level.attn[i_block](h)
            if i_level != 0:
                h = level.upsample(h)

        if self.give_pre_end:
            return h
        h = self.conv_out(nonlinearity(self.norm_out(h)))
        return torch.tanh(h) if self.tanh_out else h
class AutoencoderKL(nn.Module):
    """KL-regularised convolutional autoencoder.

    ``encode`` runs ``Encoder`` then ``quant_conv`` (to ``2 * embed_dim``
    channels) and wraps the result in a ``DiagonalGaussianDistribution``.
    ``decode`` runs ``post_quant_conv`` (``embed_dim`` -> ``z_channels``)
    then ``Decoder``.

    Args:
        ddconfig: keyword arguments passed to both ``Encoder`` and
            ``Decoder``; must contain ``z_channels`` and must not set
            ``double_z`` to False.
        embed_dim: latent channel count.
        ckpt_path: optional checkpoint loaded with ``init_from_ckpt``.
        ignore_keys: key prefixes skipped when loading ``ckpt_path``.
        image_key: batch key read by ``log_images``.
        colorize_nlabels: if given, registers a random ``(3, n, 1, 1)``
            ``colorize`` buffer used by ``to_rgb``.
        monitor: stored as ``self.monitor`` when not None.
        ema_decay: sets ``self.use_ema`` when not None.
            NOTE(review): ``model_ema`` and ``ema_scope`` are not defined in
            this class; the EMA code paths raise AttributeError unless a
            subclass/mixin provides them.
        learn_logvar: stored as ``self.learn_logvar``; not otherwise used here.
    """

    def __init__(self,
                 ddconfig,
                 embed_dim,
                 ckpt_path=None,
                 ignore_keys=None,
                 image_key='image',
                 colorize_nlabels=None,
                 monitor=None,
                 ema_decay=None,
                 learn_logvar=False):
        super().__init__()
        self.learn_logvar = learn_logvar
        self.image_key = image_key
        self.encoder = Encoder(**ddconfig)
        self.decoder = Decoder(**ddconfig)
        # quant_conv expects 2 * z_channels input channels; Encoder's own
        # default for double_z is True, so a missing key is accepted.
        assert ddconfig.get('double_z', True)
        self.quant_conv = torch.nn.Conv2d(2 * ddconfig['z_channels'],
                                          2 * embed_dim, 1)
        self.post_quant_conv = torch.nn.Conv2d(embed_dim,
                                               ddconfig['z_channels'], 1)
        self.embed_dim = embed_dim
        if colorize_nlabels is not None:
            assert isinstance(colorize_nlabels, int)
            self.register_buffer('colorize',
                                 torch.randn(3, colorize_nlabels, 1, 1))
        if monitor is not None:
            self.monitor = monitor
        self.use_ema = ema_decay is not None
        if ckpt_path is not None:
            self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys)

    def init_from_ckpt(self, path, ignore_keys=None):
        """Load weights stored under a ``first_stage_model.`` prefix.

        Reads ``torch.load(path)['state_dict']`` and prints every key with its
        shape. It keeps only keys containing ``first_stage_model``, with
        everything up to the last ``first_stage_model.`` stripped. Keys whose
        raw or stripped name starts with a prefix in ``ignore_keys`` are
        skipped. The load is strict unless a key was skipped, because skipped
        keys would otherwise be reported as missing.

        Args:
            path: checkpoint file path.
            ignore_keys: iterable of key prefixes to skip (default: none).
        """
        import collections
        prefixes = tuple(ignore_keys or ())
        # NOTE(review): depending on the torch version, torch.load may
        # unpickle arbitrary objects; only load trusted checkpoints.
        sd = torch.load(path, map_location='cpu')['state_dict']
        for key in sd:
            print(key, sd[key].shape)
        sd_new = collections.OrderedDict()
        skipped = False
        for k in sd:
            if 'first_stage_model' not in k:
                continue
            k_new = k.split('first_stage_model.')[-1]
            if k.startswith(prefixes) or k_new.startswith(prefixes):
                print('Deleting key {} from state_dict.'.format(k))
                skipped = True
                continue
            sd_new[k_new] = sd[k]
        self.load_state_dict(sd_new, strict=not skipped)
        print(f'Restored from {path}')

    def init_from_ckpt2(self, path, ignore_keys=None):
        """Load a checkpoint's ``state_dict`` as-is (non-strict).

        Keys starting with any prefix in ``ignore_keys`` are dropped first.
        Each key is dropped at most once, even if several prefixes match it.

        Args:
            path: checkpoint file path.
            ignore_keys: iterable of key prefixes to skip (default: none).
        """
        prefixes = tuple(ignore_keys or ())
        # NOTE(review): see init_from_ckpt regarding torch.load and trust.
        sd = torch.load(path, map_location='cpu')['state_dict']
        for k in list(sd.keys()):
            if k.startswith(prefixes):
                print('Deleting key {} from state_dict.'.format(k))
                del sd[k]
        self.load_state_dict(sd, strict=False)
        print(f'Restored from {path}')

    def on_train_batch_end(self, *args, **kwargs):
        # NOTE(review): ``model_ema`` is never assigned in this class.
        if self.use_ema:
            self.model_ema(self)

    def encode(self, x):
        """Return the posterior ``DiagonalGaussianDistribution`` for ``x``."""
        h = self.encoder(x)
        moments = self.quant_conv(h)
        posterior = DiagonalGaussianDistribution(moments)
        return posterior

    def decode(self, z):
        """Map latent ``z`` back through ``post_quant_conv`` and the decoder."""
        z = self.post_quant_conv(z)
        dec = self.decoder(z)
        return dec

    def forward(self, input, sample_posterior=True):
        """Reconstruct ``input``; returns ``(reconstruction, posterior)``.

        The latent is a posterior sample, or the posterior mode when
        ``sample_posterior`` is False.
        """
        posterior = self.encode(input)
        if sample_posterior:
            z = posterior.sample()
        else:
            z = posterior.mode()
        dec = self.decode(z)
        return dec, posterior

    def get_input(self, batch, k):
        """Fetch ``batch[k]`` and convert a channels-last batch to NCHW float.

        A 3-D tensor gets a trailing channel axis of size 1 first.
        """
        x = batch[k]
        if len(x.shape) == 3:
            x = x[..., None]
        x = x.permute(0, 3, 1,
                      2).to(memory_format=torch.contiguous_format).float()
        return x

    def get_last_layer(self):
        """Return the weight of the decoder's final conv."""
        return self.decoder.conv_out.weight

    @torch.no_grad()
    def log_images(self, batch, only_inputs=False, log_ema=False, **kwargs):
        """Build a dict of inputs, reconstructions and random-latent samples.

        Inputs with more than 3 channels are projected to RGB for logging
        via ``to_rgb``.
        """
        log = dict()
        x = self.get_input(batch, self.image_key)
        try:
            device = self.device
        except AttributeError:
            # A plain nn.Module has no ``device`` attribute; use the params'.
            device = next(self.parameters()).device
        x = x.to(device)
        # Keep the real model input: ``x`` may be replaced by its RGB
        # projection for logging, and that must not be fed to the EMA model.
        x_model = x
        if not only_inputs:
            xrec, posterior = self(x_model)
            if x_model.shape[1] > 3:
                # colorize with random projection
                assert xrec.shape[1] > 3
                x = self.to_rgb(x_model)
                xrec = self.to_rgb(xrec)
            log['samples'] = self.decode(torch.randn_like(posterior.sample()))
            log['reconstructions'] = xrec
            if log_ema or self.use_ema:
                # NOTE(review): ``ema_scope`` is not defined in this class.
                with self.ema_scope():
                    xrec_ema, posterior_ema = self(x_model)
                    if x_model.shape[1] > 3:
                        # colorize with random projection
                        assert xrec_ema.shape[1] > 3
                        xrec_ema = self.to_rgb(xrec_ema)
                    log['samples_ema'] = self.decode(
                        torch.randn_like(posterior_ema.sample()))
                    log['reconstructions_ema'] = xrec_ema
        log['inputs'] = x
        return log

    def to_rgb(self, x):
        """Project a multi-channel map to 3 channels rescaled to [-1, 1].

        Lazily registers a random ``colorize`` projection buffer if one does
        not exist yet.
        """
        assert self.image_key == 'segmentation'
        if not hasattr(self, 'colorize'):
            self.register_buffer('colorize',
                                 torch.randn(3, x.shape[1], 1, 1).to(x))
        x = F.conv2d(x, weight=self.colorize)
        x = 2. * (x - x.min()) / (x.max() - x.min()) - 1.
        return x
class IdentityFirstStage(torch.nn.Module):
    """Pass-through first stage: encode/decode/forward return the input.

    Args:
        vq_interface: if True, ``quantize`` returns
            ``(x, None, [None, None, None])`` instead of ``x`` alone.
        *args, **kwargs: accepted and ignored.
    """

    def __init__(self, *args, vq_interface=False, **kwargs):
        # PyTorch requires nn.Module.__init__ to run before any attribute is
        # assigned on the subclass.
        super().__init__()
        self.vq_interface = vq_interface

    def encode(self, x, *args, **kwargs):
        """Return ``x`` unchanged."""
        return x

    def decode(self, x, *args, **kwargs):
        """Return ``x`` unchanged."""
        return x

    def quantize(self, x, *args, **kwargs):
        """Return ``x``; when ``vq_interface`` is set, pad it to a VQ-style tuple."""
        if self.vq_interface:
            return x, None, [None, None, None]
        return x

    def forward(self, x, *args, **kwargs):
        """Return ``x`` unchanged."""
        return x
|