stylegan2.py 20 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733
# The implementation is adapted from stylegan2-pytorch,
# made publicly available under the MIT License at https://github.com/rosinality/stylegan2-pytorch/blob/master/model.py
  3. import functools
  4. import math
  5. import operator
  6. import random
  7. import torch
  8. from torch import nn
  9. from torch.autograd import Function
  10. from torch.nn import functional as F
  11. from .op import FusedLeakyReLU, conv2d_gradfix, fused_leaky_relu, upfirdn2d
  12. class PixelNorm(nn.Module):
  13. def __init__(self):
  14. super().__init__()
  15. def forward(self, input):
  16. return input * torch.rsqrt(
  17. torch.mean(input**2, dim=1, keepdim=True) + 1e-8)
  18. def make_kernel(k):
  19. k = torch.tensor(k, dtype=torch.float32)
  20. if k.ndim == 1:
  21. k = k[None, :] * k[:, None]
  22. k /= k.sum()
  23. return k
  24. class Upsample(nn.Module):
  25. def __init__(self, kernel, factor=2):
  26. super().__init__()
  27. self.factor = factor
  28. kernel = make_kernel(kernel) * (factor**2)
  29. self.register_buffer('kernel', kernel)
  30. p = kernel.shape[0] - factor
  31. pad0 = (p + 1) // 2 + factor - 1
  32. pad1 = p // 2
  33. self.pad = (pad0, pad1)
  34. def forward(self, input):
  35. out = upfirdn2d(
  36. input, self.kernel, up=self.factor, down=1, pad=self.pad)
  37. return out
  38. class Downsample(nn.Module):
  39. def __init__(self, kernel, factor=2):
  40. super().__init__()
  41. self.factor = factor
  42. kernel = make_kernel(kernel)
  43. self.register_buffer('kernel', kernel)
  44. p = kernel.shape[0] - factor
  45. pad0 = (p + 1) // 2
  46. pad1 = p // 2
  47. self.pad = (pad0, pad1)
  48. def forward(self, input):
  49. out = upfirdn2d(
  50. input, self.kernel, up=1, down=self.factor, pad=self.pad)
  51. return out
  52. class Blur(nn.Module):
  53. def __init__(self, kernel, pad, upsample_factor=1):
  54. super().__init__()
  55. kernel = make_kernel(kernel)
  56. if upsample_factor > 1:
  57. kernel = kernel * (upsample_factor**2)
  58. self.register_buffer('kernel', kernel)
  59. self.pad = pad
  60. def forward(self, input):
  61. out = upfirdn2d(input, self.kernel, pad=self.pad)
  62. return out
  63. class EqualConv2d(nn.Module):
  64. def __init__(self,
  65. in_channel,
  66. out_channel,
  67. kernel_size,
  68. stride=1,
  69. padding=0,
  70. bias=True):
  71. super().__init__()
  72. self.weight = nn.Parameter(
  73. torch.randn(out_channel, in_channel, kernel_size, kernel_size))
  74. self.scale = 1 / math.sqrt(in_channel * kernel_size**2)
  75. self.stride = stride
  76. self.padding = padding
  77. if bias:
  78. self.bias = nn.Parameter(torch.zeros(out_channel))
  79. else:
  80. self.bias = None
  81. def forward(self, input):
  82. out = conv2d_gradfix.conv2d(
  83. input,
  84. self.weight * self.scale,
  85. bias=self.bias,
  86. stride=self.stride,
  87. padding=self.padding,
  88. )
  89. return out
  90. def __repr__(self):
  91. return (
  92. f'{self.__class__.__name__}({self.weight.shape[1]}, {self.weight.shape[0]},'
  93. f' {self.weight.shape[2]}, stride={self.stride}, padding={self.padding})'
  94. )
  95. class EqualLinear(nn.Module):
  96. def __init__(self,
  97. in_dim,
  98. out_dim,
  99. bias=True,
  100. bias_init=0,
  101. lr_mul=1,
  102. activation=None):
  103. super().__init__()
  104. self.weight = nn.Parameter(torch.randn(out_dim, in_dim).div_(lr_mul))
  105. if bias:
  106. self.bias = nn.Parameter(torch.zeros(out_dim).fill_(bias_init))
  107. else:
  108. self.bias = None
  109. self.activation = activation
  110. self.scale = (1 / math.sqrt(in_dim)) * lr_mul
  111. self.lr_mul = lr_mul
  112. def forward(self, input):
  113. if self.activation:
  114. out = F.linear(input, self.weight * self.scale)
  115. out = fused_leaky_relu(out, self.bias * self.lr_mul)
  116. else:
  117. out = F.linear(
  118. input, self.weight * self.scale, bias=self.bias * self.lr_mul)
  119. return out
  120. def __repr__(self):
  121. return (
  122. f'{self.__class__.__name__}({self.weight.shape[1]}, {self.weight.shape[0]})'
  123. )
  124. class ModulatedConv2d(nn.Module):
  125. def __init__(
  126. self,
  127. in_channel,
  128. out_channel,
  129. kernel_size,
  130. style_dim,
  131. demodulate=True,
  132. upsample=False,
  133. downsample=False,
  134. blur_kernel=[1, 3, 3, 1],
  135. fused=True,
  136. ):
  137. super().__init__()
  138. self.eps = 1e-8
  139. self.kernel_size = kernel_size
  140. self.in_channel = in_channel
  141. self.out_channel = out_channel
  142. self.upsample = upsample
  143. self.downsample = downsample
  144. if upsample:
  145. factor = 2
  146. p = (len(blur_kernel) - factor) - (kernel_size - 1)
  147. pad0 = (p + 1) // 2 + factor - 1
  148. pad1 = p // 2 + 1
  149. self.blur = Blur(
  150. blur_kernel, pad=(pad0, pad1), upsample_factor=factor)
  151. if downsample:
  152. factor = 2
  153. p = (len(blur_kernel) - factor) + (kernel_size - 1)
  154. pad0 = (p + 1) // 2
  155. pad1 = p // 2
  156. self.blur = Blur(blur_kernel, pad=(pad0, pad1))
  157. fan_in = in_channel * kernel_size**2
  158. self.scale = 1 / math.sqrt(fan_in)
  159. self.padding = kernel_size // 2
  160. self.weight = nn.Parameter(
  161. torch.randn(1, out_channel, in_channel, kernel_size, kernel_size))
  162. self.modulation = EqualLinear(style_dim, in_channel, bias_init=1)
  163. self.demodulate = demodulate
  164. self.fused = fused
  165. def __repr__(self):
  166. return (
  167. f'{self.__class__.__name__}({self.in_channel}, {self.out_channel}, {self.kernel_size}, '
  168. f'upsample={self.upsample}, downsample={self.downsample})')
  169. def forward(self, input, style):
  170. batch, in_channel, height, width = input.shape
  171. if not self.fused:
  172. weight = self.scale * self.weight.squeeze(0)
  173. style = self.modulation(style)
  174. if self.demodulate:
  175. w = weight.unsqueeze(0) * style.view(batch, 1, in_channel, 1,
  176. 1)
  177. dcoefs = (w.square().sum((2, 3, 4)) + 1e-8).rsqrt()
  178. input = input * style.reshape(batch, in_channel, 1, 1)
  179. if self.upsample:
  180. weight = weight.transpose(0, 1)
  181. out = conv2d_gradfix.conv_transpose2d(
  182. input, weight, padding=0, stride=2)
  183. out = self.blur(out)
  184. elif self.downsample:
  185. input = self.blur(input)
  186. out = conv2d_gradfix.conv2d(input, weight, padding=0, stride=2)
  187. else:
  188. out = conv2d_gradfix.conv2d(
  189. input, weight, padding=self.padding)
  190. if self.demodulate:
  191. out = out * dcoefs.view(batch, -1, 1, 1)
  192. return out
  193. style = self.modulation(style).view(batch, 1, in_channel, 1, 1)
  194. weight = self.scale * self.weight * style
  195. if self.demodulate:
  196. demod = torch.rsqrt(weight.pow(2).sum([2, 3, 4]) + 1e-8)
  197. weight = weight * demod.view(batch, self.out_channel, 1, 1, 1)
  198. weight = weight.view(batch * self.out_channel, in_channel,
  199. self.kernel_size, self.kernel_size)
  200. if self.upsample:
  201. input = input.view(1, batch * in_channel, height, width)
  202. weight = weight.view(batch, self.out_channel, in_channel,
  203. self.kernel_size, self.kernel_size)
  204. weight = weight.transpose(1, 2).reshape(batch * in_channel,
  205. self.out_channel,
  206. self.kernel_size,
  207. self.kernel_size)
  208. out = conv2d_gradfix.conv_transpose2d(
  209. input, weight, padding=0, stride=2, groups=batch)
  210. _, _, height, width = out.shape
  211. out = out.view(batch, self.out_channel, height, width)
  212. out = self.blur(out)
  213. elif self.downsample:
  214. input = self.blur(input)
  215. _, _, height, width = input.shape
  216. input = input.view(1, batch * in_channel, height, width)
  217. out = conv2d_gradfix.conv2d(
  218. input, weight, padding=0, stride=2, groups=batch)
  219. _, _, height, width = out.shape
  220. out = out.view(batch, self.out_channel, height, width)
  221. else:
  222. input = input.view(1, batch * in_channel, height, width)
  223. out = conv2d_gradfix.conv2d(
  224. input, weight, padding=self.padding, groups=batch)
  225. _, _, height, width = out.shape
  226. out = out.view(batch, self.out_channel, height, width)
  227. return out
  228. class NoiseInjection(nn.Module):
  229. def __init__(self):
  230. super().__init__()
  231. self.weight = nn.Parameter(torch.zeros(1))
  232. def forward(self, image, noise=None):
  233. if noise is None:
  234. batch, _, height, width = image.shape
  235. noise = image.new_empty(batch, 1, height, width).normal_()
  236. return image + self.weight * noise
  237. class ConstantInput(nn.Module):
  238. def __init__(self, channel, size=4):
  239. super().__init__()
  240. self.input = nn.Parameter(torch.randn(1, channel, size, size))
  241. def forward(self, input):
  242. batch = input.shape[0]
  243. out = self.input.repeat(batch, 1, 1, 1)
  244. return out
  245. class StyledConv(nn.Module):
  246. def __init__(
  247. self,
  248. in_channel,
  249. out_channel,
  250. kernel_size,
  251. style_dim,
  252. upsample=False,
  253. blur_kernel=[1, 3, 3, 1],
  254. demodulate=True,
  255. ):
  256. super().__init__()
  257. self.conv = ModulatedConv2d(
  258. in_channel,
  259. out_channel,
  260. kernel_size,
  261. style_dim,
  262. upsample=upsample,
  263. blur_kernel=blur_kernel,
  264. demodulate=demodulate,
  265. )
  266. self.noise = NoiseInjection()
  267. # self.bias = nn.Parameter(torch.zeros(1, out_channel, 1, 1))
  268. # self.activate = ScaledLeakyReLU(0.2)
  269. self.activate = FusedLeakyReLU(out_channel)
  270. def forward(self, input, style, noise=None):
  271. out = self.conv(input, style)
  272. out = self.noise(out, noise=noise)
  273. # out = out + self.bias
  274. out = self.activate(out)
  275. return out
  276. class ToRGB(nn.Module):
  277. def __init__(self,
  278. in_channel,
  279. style_dim,
  280. upsample=True,
  281. blur_kernel=[1, 3, 3, 1]):
  282. super().__init__()
  283. if upsample:
  284. self.upsample = Upsample(blur_kernel)
  285. self.conv = ModulatedConv2d(
  286. in_channel, 3, 1, style_dim, demodulate=False)
  287. self.bias = nn.Parameter(torch.zeros(1, 3, 1, 1))
  288. def forward(self, input, style, skip=None):
  289. out = self.conv(input, style)
  290. out = out + self.bias
  291. if skip is not None:
  292. skip = self.upsample(skip)
  293. out = out + skip
  294. return out
  295. class Generator(nn.Module):
  296. def __init__(
  297. self,
  298. size,
  299. style_dim,
  300. n_mlp,
  301. channel_multiplier=2,
  302. blur_kernel=[1, 3, 3, 1],
  303. lr_mlp=0.01,
  304. ):
  305. super().__init__()
  306. self.size = size
  307. self.style_dim = style_dim
  308. layers = [PixelNorm()]
  309. for i in range(n_mlp):
  310. layers.append(
  311. EqualLinear(
  312. style_dim,
  313. style_dim,
  314. lr_mul=lr_mlp,
  315. activation='fused_lrelu'))
  316. self.style = nn.Sequential(*layers)
  317. self.channels = {
  318. 4: 512,
  319. 8: 512,
  320. 16: 512,
  321. 32: 512,
  322. 64: 256 * channel_multiplier,
  323. 128: 128 * channel_multiplier,
  324. 256: 64 * channel_multiplier,
  325. 512: 32 * channel_multiplier,
  326. 1024: 16 * channel_multiplier,
  327. }
  328. self.input = ConstantInput(self.channels[4])
  329. self.conv1 = StyledConv(
  330. self.channels[4],
  331. self.channels[4],
  332. 3,
  333. style_dim,
  334. blur_kernel=blur_kernel)
  335. self.to_rgb1 = ToRGB(self.channels[4], style_dim, upsample=False)
  336. self.log_size = int(math.log(size, 2))
  337. self.num_layers = (self.log_size - 2) * 2 + 1
  338. self.convs = nn.ModuleList()
  339. self.upsamples = nn.ModuleList()
  340. self.to_rgbs = nn.ModuleList()
  341. self.noises = nn.Module()
  342. in_channel = self.channels[4]
  343. for layer_idx in range(self.num_layers):
  344. res = (layer_idx + 5) // 2
  345. shape = [1, 1, 2**res, 2**res]
  346. self.noises.register_buffer(f'noise_{layer_idx}',
  347. torch.randn(*shape))
  348. for i in range(3, self.log_size + 1):
  349. out_channel = self.channels[2**i]
  350. self.convs.append(
  351. StyledConv(
  352. in_channel,
  353. out_channel,
  354. 3,
  355. style_dim,
  356. upsample=True,
  357. blur_kernel=blur_kernel,
  358. ))
  359. self.convs.append(
  360. StyledConv(
  361. out_channel,
  362. out_channel,
  363. 3,
  364. style_dim,
  365. blur_kernel=blur_kernel))
  366. self.to_rgbs.append(ToRGB(out_channel, style_dim))
  367. in_channel = out_channel
  368. self.n_latent = self.log_size * 2 - 2
  369. def make_noise(self):
  370. device = self.input.input.device
  371. noises = [torch.randn(1, 1, 2**2, 2**2, device=device)]
  372. for i in range(3, self.log_size + 1):
  373. for _ in range(2):
  374. noises.append(torch.randn(1, 1, 2**i, 2**i, device=device))
  375. return noises
  376. def mean_latent(self, n_latent):
  377. latent_in = torch.randn(
  378. n_latent, self.style_dim, device=self.input.input.device)
  379. latent = self.style(latent_in).mean(0, keepdim=True)
  380. return latent
  381. def get_latent(self, input):
  382. return self.style(input)
  383. def forward(
  384. self,
  385. styles,
  386. return_latents=False,
  387. inject_index=None,
  388. truncation=1,
  389. truncation_latent=None,
  390. input_is_latent=False,
  391. noise=None,
  392. randomize_noise=True,
  393. ):
  394. if not input_is_latent:
  395. styles = [self.style(s) for s in styles]
  396. if noise is None:
  397. if randomize_noise:
  398. noise = [None] * self.num_layers
  399. else:
  400. noise = [
  401. getattr(self.noises, f'noise_{i}')
  402. for i in range(self.num_layers)
  403. ]
  404. if truncation < 1:
  405. style_t = []
  406. for style in styles:
  407. style_t.append(truncation_latent
  408. + truncation * (style - truncation_latent))
  409. styles = style_t
  410. if len(styles) < 2:
  411. inject_index = self.n_latent
  412. if styles[0].ndim < 3:
  413. latent = styles[0].unsqueeze(1).repeat(1, inject_index, 1)
  414. else:
  415. latent = styles[0]
  416. else:
  417. if inject_index is None:
  418. inject_index = random.randint(1, self.n_latent - 1)
  419. latent = styles[0].unsqueeze(1).repeat(1, inject_index, 1)
  420. latent2 = styles[1].unsqueeze(1).repeat(
  421. 1, self.n_latent - inject_index, 1)
  422. latent = torch.cat([latent, latent2], 1)
  423. out = self.input(latent)
  424. out = self.conv1(out, latent[:, 0], noise=noise[0])
  425. skip = self.to_rgb1(out, latent[:, 1])
  426. i = 1
  427. for conv1, conv2, noise1, noise2, to_rgb in zip(
  428. self.convs[::2], self.convs[1::2], noise[1::2], noise[2::2],
  429. self.to_rgbs):
  430. out = conv1(out, latent[:, i], noise=noise1)
  431. out = conv2(out, latent[:, i + 1], noise=noise2)
  432. skip = to_rgb(out, latent[:, i + 2], skip)
  433. i += 2
  434. image = skip
  435. if return_latents:
  436. return image, latent
  437. else:
  438. return image, None
  439. class ConvLayer(nn.Sequential):
  440. def __init__(
  441. self,
  442. in_channel,
  443. out_channel,
  444. kernel_size,
  445. downsample=False,
  446. blur_kernel=[1, 3, 3, 1],
  447. bias=True,
  448. activate=True,
  449. ):
  450. layers = []
  451. if downsample:
  452. factor = 2
  453. p = (len(blur_kernel) - factor) + (kernel_size - 1)
  454. pad0 = (p + 1) // 2
  455. pad1 = p // 2
  456. layers.append(Blur(blur_kernel, pad=(pad0, pad1)))
  457. stride = 2
  458. self.padding = 0
  459. else:
  460. stride = 1
  461. self.padding = kernel_size // 2
  462. layers.append(
  463. EqualConv2d(
  464. in_channel,
  465. out_channel,
  466. kernel_size,
  467. padding=self.padding,
  468. stride=stride,
  469. bias=bias and not activate,
  470. ))
  471. if activate:
  472. layers.append(FusedLeakyReLU(out_channel, bias=bias))
  473. super().__init__(*layers)
  474. class ResBlock(nn.Module):
  475. def __init__(self, in_channel, out_channel, blur_kernel=[1, 3, 3, 1]):
  476. super().__init__()
  477. self.conv1 = ConvLayer(in_channel, in_channel, 3)
  478. self.conv2 = ConvLayer(in_channel, out_channel, 3, downsample=True)
  479. self.skip = ConvLayer(
  480. in_channel,
  481. out_channel,
  482. 1,
  483. downsample=True,
  484. activate=False,
  485. bias=False)
  486. def forward(self, input):
  487. out = self.conv1(input)
  488. out = self.conv2(out)
  489. skip = self.skip(input)
  490. out = (out + skip) / math.sqrt(2)
  491. return out
  492. class Discriminator(nn.Module):
  493. def __init__(self, size, channel_multiplier=2, blur_kernel=[1, 3, 3, 1]):
  494. super().__init__()
  495. channels = {
  496. 4: 512,
  497. 8: 512,
  498. 16: 512,
  499. 32: 512,
  500. 64: 256 * channel_multiplier,
  501. 128: 128 * channel_multiplier,
  502. 256: 64 * channel_multiplier,
  503. 512: 32 * channel_multiplier,
  504. 1024: 16 * channel_multiplier,
  505. }
  506. convs = [ConvLayer(3, channels[size], 1)]
  507. log_size = int(math.log(size, 2))
  508. in_channel = channels[size]
  509. for i in range(log_size, 2, -1):
  510. out_channel = channels[2**(i - 1)]
  511. convs.append(ResBlock(in_channel, out_channel, blur_kernel))
  512. in_channel = out_channel
  513. self.convs = nn.Sequential(*convs)
  514. self.stddev_group = 4
  515. self.stddev_feat = 1
  516. self.final_conv = ConvLayer(in_channel + 1, channels[4], 3)
  517. self.final_linear = nn.Sequential(
  518. EqualLinear(
  519. channels[4] * 4 * 4, channels[4], activation='fused_lrelu'),
  520. EqualLinear(channels[4], 1),
  521. )
  522. def forward(self, input):
  523. out = self.convs(input)
  524. batch, channel, height, width = out.shape
  525. group = min(batch, self.stddev_group)
  526. stddev = out.view(group, -1, self.stddev_feat,
  527. channel // self.stddev_feat, height, width)
  528. stddev = torch.sqrt(stddev.var(0, unbiased=False) + 1e-8)
  529. stddev = stddev.mean([2, 3, 4], keepdims=True).squeeze(2)
  530. stddev = stddev.repeat(group, 1, height, width)
  531. out = torch.cat([out, stddev], 1)
  532. out = self.final_conv(out)
  533. out = out.view(batch, -1)
  534. out = self.final_linear(out)
  535. return out