autoencoder.py 21 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650
  1. # Copyright (c) Alibaba, Inc. and its affiliates.
  2. import numpy as np
  3. import torch
  4. import torch.nn as nn
  5. import torch.nn.functional as F
# Public API: only the KL autoencoder is exported by ``from ... import *``.
__all__ = ['AutoencoderKL']
  7. def nonlinearity(x):
  8. # swish
  9. return x * torch.sigmoid(x)
  10. def Normalize(in_channels, num_groups=32):
  11. return torch.nn.GroupNorm(
  12. num_groups=num_groups, num_channels=in_channels, eps=1e-6, affine=True)
  13. class DiagonalGaussianDistribution(object):
  14. def __init__(self, parameters, deterministic=False):
  15. self.parameters = parameters
  16. self.mean, self.logvar = torch.chunk(parameters, 2, dim=1)
  17. self.logvar = torch.clamp(self.logvar, -30.0, 20.0)
  18. self.deterministic = deterministic
  19. self.std = torch.exp(0.5 * self.logvar)
  20. self.var = torch.exp(self.logvar)
  21. if self.deterministic:
  22. self.var = self.std = torch.zeros_like(
  23. self.mean).to(device=self.parameters.device)
  24. def sample(self):
  25. x = self.mean + self.std * torch.randn(
  26. self.mean.shape).to(device=self.parameters.device)
  27. return x
  28. def kl(self, other=None):
  29. if self.deterministic:
  30. return torch.Tensor([0.])
  31. else:
  32. if other is None:
  33. return 0.5 * torch.sum(
  34. torch.pow(self.mean, 2) + self.var - 1.0 - self.logvar,
  35. dim=[1, 2, 3])
  36. else:
  37. return 0.5 * torch.sum(
  38. torch.pow(self.mean - other.mean, 2) / other.var
  39. + self.var / other.var - 1.0 - self.logvar + other.logvar,
  40. dim=[1, 2, 3])
  41. def nll(self, sample, dims=[1, 2, 3]):
  42. if self.deterministic:
  43. return torch.Tensor([0.])
  44. logtwopi = np.log(2.0 * np.pi)
  45. return 0.5 * torch.sum(
  46. logtwopi + self.logvar
  47. + torch.pow(sample - self.mean, 2) / self.var,
  48. dim=dims)
  49. def mode(self):
  50. return self.mean
  51. class Downsample(nn.Module):
  52. def __init__(self, in_channels, with_conv):
  53. super().__init__()
  54. self.with_conv = with_conv
  55. if self.with_conv:
  56. # no asymmetric padding in torch conv, must do it ourselves
  57. self.conv = torch.nn.Conv2d(
  58. in_channels, in_channels, kernel_size=3, stride=2, padding=0)
  59. def forward(self, x):
  60. if self.with_conv:
  61. pad = (0, 1, 0, 1)
  62. x = torch.nn.functional.pad(x, pad, mode='constant', value=0)
  63. x = self.conv(x)
  64. else:
  65. x = torch.nn.functional.avg_pool2d(x, kernel_size=2, stride=2)
  66. return x
  67. class ResnetBlock(nn.Module):
  68. def __init__(self,
  69. *,
  70. in_channels,
  71. out_channels=None,
  72. conv_shortcut=False,
  73. dropout,
  74. temb_channels=512):
  75. super().__init__()
  76. self.in_channels = in_channels
  77. out_channels = in_channels if out_channels is None else out_channels
  78. self.out_channels = out_channels
  79. self.use_conv_shortcut = conv_shortcut
  80. self.norm1 = Normalize(in_channels)
  81. self.conv1 = torch.nn.Conv2d(
  82. in_channels, out_channels, kernel_size=3, stride=1, padding=1)
  83. if temb_channels > 0:
  84. self.temb_proj = torch.nn.Linear(temb_channels, out_channels)
  85. self.norm2 = Normalize(out_channels)
  86. self.dropout = torch.nn.Dropout(dropout)
  87. self.conv2 = torch.nn.Conv2d(
  88. out_channels, out_channels, kernel_size=3, stride=1, padding=1)
  89. if self.in_channels != self.out_channels:
  90. if self.use_conv_shortcut:
  91. self.conv_shortcut = torch.nn.Conv2d(
  92. in_channels,
  93. out_channels,
  94. kernel_size=3,
  95. stride=1,
  96. padding=1)
  97. else:
  98. self.nin_shortcut = torch.nn.Conv2d(
  99. in_channels,
  100. out_channels,
  101. kernel_size=1,
  102. stride=1,
  103. padding=0)
  104. def forward(self, x, temb):
  105. h = x
  106. h = self.norm1(h)
  107. h = nonlinearity(h)
  108. h = self.conv1(h)
  109. if temb is not None:
  110. h = h + self.temb_proj(nonlinearity(temb))[:, :, None, None]
  111. h = self.norm2(h)
  112. h = nonlinearity(h)
  113. h = self.dropout(h)
  114. h = self.conv2(h)
  115. if self.in_channels != self.out_channels:
  116. if self.use_conv_shortcut:
  117. x = self.conv_shortcut(x)
  118. else:
  119. x = self.nin_shortcut(x)
  120. return x + h
  121. class AttnBlock(nn.Module):
  122. def __init__(self, in_channels):
  123. super().__init__()
  124. self.in_channels = in_channels
  125. self.norm = Normalize(in_channels)
  126. self.q = torch.nn.Conv2d(
  127. in_channels, in_channels, kernel_size=1, stride=1, padding=0)
  128. self.k = torch.nn.Conv2d(
  129. in_channels, in_channels, kernel_size=1, stride=1, padding=0)
  130. self.v = torch.nn.Conv2d(
  131. in_channels, in_channels, kernel_size=1, stride=1, padding=0)
  132. self.proj_out = torch.nn.Conv2d(
  133. in_channels, in_channels, kernel_size=1, stride=1, padding=0)
  134. def forward(self, x):
  135. h_ = x
  136. h_ = self.norm(h_)
  137. q = self.q(h_)
  138. k = self.k(h_)
  139. v = self.v(h_)
  140. # compute attention
  141. b, c, h, w = q.shape
  142. q = q.reshape(b, c, h * w)
  143. q = q.permute(0, 2, 1) # b,hw,c
  144. k = k.reshape(b, c, h * w) # b,c,hw
  145. w_ = torch.bmm(q, k) # b,hw,hw w[b,i,j]=sum_c q[b,i,c]k[b,c,j]
  146. w_ = w_ * (int(c)**(-0.5))
  147. w_ = torch.nn.functional.softmax(w_, dim=2)
  148. # attend to values
  149. v = v.reshape(b, c, h * w)
  150. w_ = w_.permute(0, 2, 1) # b,hw,hw (first hw of k, second of q)
  151. h_ = torch.bmm(
  152. v, w_) # b, c,hw (hw of q) h_[b,c,j] = sum_i v[b,c,i] w_[b,i,j]
  153. h_ = h_.reshape(b, c, h, w)
  154. h_ = self.proj_out(h_)
  155. return x + h_
  156. class AttnBlock(nn.Module): # noqa
  157. def __init__(self, in_channels):
  158. super().__init__()
  159. self.in_channels = in_channels
  160. self.norm = Normalize(in_channels)
  161. self.q = torch.nn.Conv2d(
  162. in_channels, in_channels, kernel_size=1, stride=1, padding=0)
  163. self.k = torch.nn.Conv2d(
  164. in_channels, in_channels, kernel_size=1, stride=1, padding=0)
  165. self.v = torch.nn.Conv2d(
  166. in_channels, in_channels, kernel_size=1, stride=1, padding=0)
  167. self.proj_out = torch.nn.Conv2d(
  168. in_channels, in_channels, kernel_size=1, stride=1, padding=0)
  169. def forward(self, x):
  170. h_ = x
  171. h_ = self.norm(h_)
  172. q = self.q(h_)
  173. k = self.k(h_)
  174. v = self.v(h_)
  175. # compute attention
  176. b, c, h, w = q.shape
  177. q = q.reshape(b, c, h * w)
  178. q = q.permute(0, 2, 1) # b,hw,c
  179. k = k.reshape(b, c, h * w) # b,c,hw
  180. w_ = torch.bmm(q, k) # b,hw,hw w[b,i,j]=sum_c q[b,i,c]k[b,c,j]
  181. w_ = w_ * (int(c)**(-0.5))
  182. w_ = torch.nn.functional.softmax(w_, dim=2)
  183. # attend to values
  184. v = v.reshape(b, c, h * w)
  185. w_ = w_.permute(0, 2, 1) # b,hw,hw (first hw of k, second of q)
  186. h_ = torch.bmm(
  187. v, w_) # b, c,hw (hw of q) h_[b,c,j] = sum_i v[b,c,i] w_[b,i,j]
  188. h_ = h_.reshape(b, c, h, w)
  189. h_ = self.proj_out(h_)
  190. return x + h_
  191. class Upsample(nn.Module):
  192. def __init__(self, in_channels, with_conv):
  193. super().__init__()
  194. self.with_conv = with_conv
  195. if self.with_conv:
  196. self.conv = torch.nn.Conv2d(
  197. in_channels, in_channels, kernel_size=3, stride=1, padding=1)
  198. def forward(self, x):
  199. x = torch.nn.functional.interpolate(
  200. x, scale_factor=2.0, mode='nearest')
  201. if self.with_conv:
  202. x = self.conv(x)
  203. return x
  204. class Downsample(nn.Module): # noqa
  205. def __init__(self, in_channels, with_conv):
  206. super().__init__()
  207. self.with_conv = with_conv
  208. if self.with_conv:
  209. # no asymmetric padding in torch conv, must do it ourselves
  210. self.conv = torch.nn.Conv2d(
  211. in_channels, in_channels, kernel_size=3, stride=2, padding=0)
  212. def forward(self, x):
  213. if self.with_conv:
  214. pad = (0, 1, 0, 1)
  215. x = torch.nn.functional.pad(x, pad, mode='constant', value=0)
  216. x = self.conv(x)
  217. else:
  218. x = torch.nn.functional.avg_pool2d(x, kernel_size=2, stride=2)
  219. return x
  220. class Encoder(nn.Module):
  221. def __init__(self,
  222. *,
  223. ch,
  224. out_ch,
  225. ch_mult=(1, 2, 4, 8),
  226. num_res_blocks,
  227. attn_resolutions,
  228. dropout=0.0,
  229. resamp_with_conv=True,
  230. in_channels,
  231. resolution,
  232. z_channels,
  233. double_z=True,
  234. use_linear_attn=False,
  235. attn_type='vanilla',
  236. **ignore_kwargs):
  237. super().__init__()
  238. self.ch = ch
  239. self.temb_ch = 0
  240. self.num_resolutions = len(ch_mult)
  241. self.num_res_blocks = num_res_blocks
  242. self.resolution = resolution
  243. self.in_channels = in_channels
  244. # downsampling
  245. self.conv_in = torch.nn.Conv2d(
  246. in_channels, self.ch, kernel_size=3, stride=1, padding=1)
  247. curr_res = resolution
  248. in_ch_mult = (1, ) + tuple(ch_mult)
  249. self.in_ch_mult = in_ch_mult
  250. self.down = nn.ModuleList()
  251. for i_level in range(self.num_resolutions):
  252. block = nn.ModuleList()
  253. attn = nn.ModuleList()
  254. block_in = ch * in_ch_mult[i_level]
  255. block_out = ch * ch_mult[i_level]
  256. for i_block in range(self.num_res_blocks):
  257. block.append(
  258. ResnetBlock(
  259. in_channels=block_in,
  260. out_channels=block_out,
  261. temb_channels=self.temb_ch,
  262. dropout=dropout))
  263. block_in = block_out
  264. if curr_res in attn_resolutions:
  265. attn.append(AttnBlock(block_in))
  266. down = nn.Module()
  267. down.block = block
  268. down.attn = attn
  269. if i_level != self.num_resolutions - 1:
  270. down.downsample = Downsample(block_in, resamp_with_conv)
  271. curr_res = curr_res // 2
  272. self.down.append(down)
  273. # middle
  274. self.mid = nn.Module()
  275. self.mid.block_1 = ResnetBlock(
  276. in_channels=block_in,
  277. out_channels=block_in,
  278. temb_channels=self.temb_ch,
  279. dropout=dropout)
  280. self.mid.attn_1 = AttnBlock(block_in)
  281. self.mid.block_2 = ResnetBlock(
  282. in_channels=block_in,
  283. out_channels=block_in,
  284. temb_channels=self.temb_ch,
  285. dropout=dropout)
  286. # end
  287. self.norm_out = Normalize(block_in)
  288. self.conv_out = torch.nn.Conv2d(
  289. block_in,
  290. 2 * z_channels if double_z else z_channels,
  291. kernel_size=3,
  292. stride=1,
  293. padding=1)
  294. def forward(self, x):
  295. # timestep embedding
  296. temb = None
  297. # downsampling
  298. hs = [self.conv_in(x)]
  299. for i_level in range(self.num_resolutions):
  300. for i_block in range(self.num_res_blocks):
  301. h = self.down[i_level].block[i_block](hs[-1], temb)
  302. if len(self.down[i_level].attn) > 0:
  303. h = self.down[i_level].attn[i_block](h)
  304. hs.append(h)
  305. if i_level != self.num_resolutions - 1:
  306. hs.append(self.down[i_level].downsample(hs[-1]))
  307. # middle
  308. h = hs[-1]
  309. h = self.mid.block_1(h, temb)
  310. h = self.mid.attn_1(h)
  311. h = self.mid.block_2(h, temb)
  312. # end
  313. h = self.norm_out(h)
  314. h = nonlinearity(h)
  315. h = self.conv_out(h)
  316. return h
  317. class Decoder(nn.Module):
  318. def __init__(self,
  319. *,
  320. ch,
  321. out_ch,
  322. ch_mult=(1, 2, 4, 8),
  323. num_res_blocks,
  324. attn_resolutions,
  325. dropout=0.0,
  326. resamp_with_conv=True,
  327. in_channels,
  328. resolution,
  329. z_channels,
  330. give_pre_end=False,
  331. tanh_out=False,
  332. use_linear_attn=False,
  333. attn_type='vanilla',
  334. **ignorekwargs):
  335. super().__init__()
  336. self.ch = ch
  337. self.temb_ch = 0
  338. self.num_resolutions = len(ch_mult)
  339. self.num_res_blocks = num_res_blocks
  340. self.resolution = resolution
  341. self.in_channels = in_channels
  342. self.give_pre_end = give_pre_end
  343. self.tanh_out = tanh_out
  344. # compute in_ch_mult, block_in and curr_res at lowest res
  345. block_in = ch * ch_mult[self.num_resolutions - 1]
  346. curr_res = resolution // 2**(self.num_resolutions - 1)
  347. self.z_shape = (1, z_channels, curr_res, curr_res)
  348. print('Working with z of shape {} = {} dimensions.'.format(
  349. self.z_shape, np.prod(self.z_shape)))
  350. # z to block_in
  351. self.conv_in = torch.nn.Conv2d(
  352. z_channels, block_in, kernel_size=3, stride=1, padding=1)
  353. # middle
  354. self.mid = nn.Module()
  355. self.mid.block_1 = ResnetBlock(
  356. in_channels=block_in,
  357. out_channels=block_in,
  358. temb_channels=self.temb_ch,
  359. dropout=dropout)
  360. self.mid.attn_1 = AttnBlock(block_in)
  361. self.mid.block_2 = ResnetBlock(
  362. in_channels=block_in,
  363. out_channels=block_in,
  364. temb_channels=self.temb_ch,
  365. dropout=dropout)
  366. # upsampling
  367. self.up = nn.ModuleList()
  368. for i_level in reversed(range(self.num_resolutions)):
  369. block = nn.ModuleList()
  370. attn = nn.ModuleList()
  371. block_out = ch * ch_mult[i_level]
  372. for i_block in range(self.num_res_blocks + 1):
  373. block.append(
  374. ResnetBlock(
  375. in_channels=block_in,
  376. out_channels=block_out,
  377. temb_channels=self.temb_ch,
  378. dropout=dropout))
  379. block_in = block_out
  380. if curr_res in attn_resolutions:
  381. attn.append(AttnBlock(block_in))
  382. up = nn.Module()
  383. up.block = block
  384. up.attn = attn
  385. if i_level != 0:
  386. up.upsample = Upsample(block_in, resamp_with_conv)
  387. curr_res = curr_res * 2
  388. self.up.insert(0, up) # prepend to get consistent order
  389. # end
  390. self.norm_out = Normalize(block_in)
  391. self.conv_out = torch.nn.Conv2d(
  392. block_in, out_ch, kernel_size=3, stride=1, padding=1)
  393. def forward(self, z):
  394. self.last_z_shape = z.shape
  395. # timestep embedding
  396. temb = None
  397. # z to block_in
  398. h = self.conv_in(z)
  399. # middle
  400. h = self.mid.block_1(h, temb)
  401. h = self.mid.attn_1(h)
  402. h = self.mid.block_2(h, temb)
  403. # upsampling
  404. for i_level in reversed(range(self.num_resolutions)):
  405. for i_block in range(self.num_res_blocks + 1):
  406. h = self.up[i_level].block[i_block](h, temb)
  407. if len(self.up[i_level].attn) > 0:
  408. h = self.up[i_level].attn[i_block](h)
  409. if i_level != 0:
  410. h = self.up[i_level].upsample(h)
  411. # end
  412. if self.give_pre_end:
  413. return h
  414. h = self.norm_out(h)
  415. h = nonlinearity(h)
  416. h = self.conv_out(h)
  417. if self.tanh_out:
  418. h = torch.tanh(h)
  419. return h
  420. class AutoencoderKL(nn.Module):
  421. def __init__(self,
  422. ddconfig,
  423. embed_dim,
  424. ckpt_path=None,
  425. ignore_keys=[],
  426. image_key='image',
  427. colorize_nlabels=None,
  428. monitor=None,
  429. ema_decay=None,
  430. learn_logvar=False):
  431. super().__init__()
  432. self.learn_logvar = learn_logvar
  433. self.image_key = image_key
  434. self.encoder = Encoder(**ddconfig)
  435. self.decoder = Decoder(**ddconfig)
  436. assert ddconfig['double_z']
  437. self.quant_conv = torch.nn.Conv2d(2 * ddconfig['z_channels'],
  438. 2 * embed_dim, 1)
  439. self.post_quant_conv = torch.nn.Conv2d(embed_dim,
  440. ddconfig['z_channels'], 1)
  441. self.embed_dim = embed_dim
  442. if colorize_nlabels is not None:
  443. assert type(colorize_nlabels) == int
  444. self.register_buffer('colorize',
  445. torch.randn(3, colorize_nlabels, 1, 1))
  446. if monitor is not None:
  447. self.monitor = monitor
  448. self.use_ema = ema_decay is not None
  449. if ckpt_path is not None:
  450. self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys)
  451. def init_from_ckpt(self, path, ignore_keys=list()):
  452. sd = torch.load(path, map_location='cpu')['state_dict']
  453. keys = list(sd.keys())
  454. for key in keys:
  455. print(key, sd[key].shape)
  456. import collections
  457. sd_new = collections.OrderedDict()
  458. for k in keys:
  459. if k.find('first_stage_model') >= 0:
  460. k_new = k.split('first_stage_model.')[-1]
  461. sd_new[k_new] = sd[k]
  462. self.load_state_dict(sd_new, strict=True)
  463. print(f'Restored from {path}')
  464. def init_from_ckpt2(self, path, ignore_keys=list()):
  465. sd = torch.load(path, map_location='cpu')['state_dict']
  466. keys = list(sd.keys())
  467. first_stage_model
  468. for k in keys:
  469. for ik in ignore_keys:
  470. if k.startswith(ik):
  471. print('Deleting key {} from state_dict.'.format(k))
  472. del sd[k]
  473. self.load_state_dict(sd, strict=False)
  474. print(f'Restored from {path}')
  475. def on_train_batch_end(self, *args, **kwargs):
  476. if self.use_ema:
  477. self.model_ema(self)
  478. def encode(self, x):
  479. h = self.encoder(x)
  480. moments = self.quant_conv(h)
  481. posterior = DiagonalGaussianDistribution(moments)
  482. return posterior
  483. def decode(self, z):
  484. z = self.post_quant_conv(z)
  485. dec = self.decoder(z)
  486. return dec
  487. def forward(self, input, sample_posterior=True):
  488. posterior = self.encode(input)
  489. if sample_posterior:
  490. z = posterior.sample()
  491. else:
  492. z = posterior.mode()
  493. dec = self.decode(z)
  494. return dec, posterior
  495. def get_input(self, batch, k):
  496. x = batch[k]
  497. if len(x.shape) == 3:
  498. x = x[..., None]
  499. x = x.permute(0, 3, 1,
  500. 2).to(memory_format=torch.contiguous_format).float()
  501. return x
  502. def get_last_layer(self):
  503. return self.decoder.conv_out.weight
  504. @torch.no_grad()
  505. def log_images(self, batch, only_inputs=False, log_ema=False, **kwargs):
  506. log = dict()
  507. x = self.get_input(batch, self.image_key)
  508. x = x.to(self.device)
  509. if not only_inputs:
  510. xrec, posterior = self(x)
  511. if x.shape[1] > 3:
  512. # colorize with random projection
  513. assert xrec.shape[1] > 3
  514. x = self.to_rgb(x)
  515. xrec = self.to_rgb(xrec)
  516. log['samples'] = self.decode(torch.randn_like(posterior.sample()))
  517. log['reconstructions'] = xrec
  518. if log_ema or self.use_ema:
  519. with self.ema_scope():
  520. xrec_ema, posterior_ema = self(x)
  521. if x.shape[1] > 3:
  522. # colorize with random projection
  523. assert xrec_ema.shape[1] > 3
  524. xrec_ema = self.to_rgb(xrec_ema)
  525. log['samples_ema'] = self.decode(
  526. torch.randn_like(posterior_ema.sample()))
  527. log['reconstructions_ema'] = xrec_ema
  528. log['inputs'] = x
  529. return log
  530. def to_rgb(self, x):
  531. assert self.image_key == 'segmentation'
  532. if not hasattr(self, 'colorize'):
  533. self.register_buffer('colorize',
  534. torch.randn(3, x.shape[1], 1, 1).to(x))
  535. x = F.conv2d(x, weight=self.colorize)
  536. x = 2. * (x - x.min()) / (x.max() - x.min()) - 1.
  537. return x
  538. class IdentityFirstStage(torch.nn.Module):
  539. def __init__(self, *args, vq_interface=False, **kwargs):
  540. self.vq_interface = vq_interface
  541. super().__init__()
  542. def encode(self, x, *args, **kwargs):
  543. return x
  544. def decode(self, x, *args, **kwargs):
  545. return x
  546. def quantize(self, x, *args, **kwargs):
  547. if self.vq_interface:
  548. return x, None, [None, None, None]
  549. return x
  550. def forward(self, x, *args, **kwargs):
  551. return x