upfirdn2d.py 5.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208
  1. import os
  2. from collections import abc
  3. import torch
  4. from torch.autograd import Function
  5. from torch.nn import functional as F
  6. from torch.utils.cpp_extension import load
  7. module_path = os.path.dirname(__file__)
  8. upfirdn2d_op = load(
  9. 'upfirdn2d',
  10. sources=[
  11. os.path.join(module_path, 'upfirdn2d.cpp'),
  12. os.path.join(module_path, 'upfirdn2d_kernel.cu'),
  13. ],
  14. )
  15. class UpFirDn2dBackward(Function):
  16. @staticmethod
  17. def forward(ctx, grad_output, kernel, grad_kernel, up, down, pad, g_pad,
  18. in_size, out_size):
  19. up_x, up_y = up
  20. down_x, down_y = down
  21. g_pad_x0, g_pad_x1, g_pad_y0, g_pad_y1 = g_pad
  22. grad_output = grad_output.reshape(-1, out_size[0], out_size[1], 1)
  23. grad_input = upfirdn2d_op.upfirdn2d(
  24. grad_output,
  25. grad_kernel,
  26. down_x,
  27. down_y,
  28. up_x,
  29. up_y,
  30. g_pad_x0,
  31. g_pad_x1,
  32. g_pad_y0,
  33. g_pad_y1,
  34. )
  35. grad_input = grad_input.view(in_size[0], in_size[1], in_size[2],
  36. in_size[3])
  37. ctx.save_for_backward(kernel)
  38. pad_x0, pad_x1, pad_y0, pad_y1 = pad
  39. ctx.up_x = up_x
  40. ctx.up_y = up_y
  41. ctx.down_x = down_x
  42. ctx.down_y = down_y
  43. ctx.pad_x0 = pad_x0
  44. ctx.pad_x1 = pad_x1
  45. ctx.pad_y0 = pad_y0
  46. ctx.pad_y1 = pad_y1
  47. ctx.in_size = in_size
  48. ctx.out_size = out_size
  49. return grad_input
  50. @staticmethod
  51. def backward(ctx, gradgrad_input):
  52. kernel, = ctx.saved_tensors
  53. gradgrad_input = gradgrad_input.reshape(-1, ctx.in_size[2],
  54. ctx.in_size[3], 1)
  55. gradgrad_out = upfirdn2d_op.upfirdn2d(
  56. gradgrad_input,
  57. kernel,
  58. ctx.up_x,
  59. ctx.up_y,
  60. ctx.down_x,
  61. ctx.down_y,
  62. ctx.pad_x0,
  63. ctx.pad_x1,
  64. ctx.pad_y0,
  65. ctx.pad_y1,
  66. )
  67. # gradgrad_out = gradgrad_out.view(ctx.in_size[0], ctx.out_size[0], ctx.out_size[1], ctx.in_size[3])
  68. gradgrad_out = gradgrad_out.view(ctx.in_size[0], ctx.in_size[1],
  69. ctx.out_size[0], ctx.out_size[1])
  70. return gradgrad_out, None, None, None, None, None, None, None, None
  71. class UpFirDn2d(Function):
  72. @staticmethod
  73. def forward(ctx, input, kernel, up, down, pad):
  74. up_x, up_y = up
  75. down_x, down_y = down
  76. pad_x0, pad_x1, pad_y0, pad_y1 = pad
  77. kernel_h, kernel_w = kernel.shape
  78. batch, channel, in_h, in_w = input.shape
  79. ctx.in_size = input.shape
  80. input = input.reshape(-1, in_h, in_w, 1)
  81. ctx.save_for_backward(kernel, torch.flip(kernel, [0, 1]))
  82. out_h = (in_h * up_y + pad_y0 + pad_y1 - kernel_h + down_y) // down_y
  83. out_w = (in_w * up_x + pad_x0 + pad_x1 - kernel_w + down_x) // down_x
  84. ctx.out_size = (out_h, out_w)
  85. ctx.up = (up_x, up_y)
  86. ctx.down = (down_x, down_y)
  87. ctx.pad = (pad_x0, pad_x1, pad_y0, pad_y1)
  88. g_pad_x0 = kernel_w - pad_x0 - 1
  89. g_pad_y0 = kernel_h - pad_y0 - 1
  90. g_pad_x1 = in_w * up_x - out_w * down_x + pad_x0 - up_x + 1
  91. g_pad_y1 = in_h * up_y - out_h * down_y + pad_y0 - up_y + 1
  92. ctx.g_pad = (g_pad_x0, g_pad_x1, g_pad_y0, g_pad_y1)
  93. out = upfirdn2d_op.upfirdn2d(input, kernel, up_x, up_y, down_x, down_y,
  94. pad_x0, pad_x1, pad_y0, pad_y1)
  95. # out = out.view(major, out_h, out_w, minor)
  96. out = out.view(-1, channel, out_h, out_w)
  97. return out
  98. @staticmethod
  99. def backward(ctx, grad_output):
  100. kernel, grad_kernel = ctx.saved_tensors
  101. grad_input = None
  102. if ctx.needs_input_grad[0]:
  103. grad_input = UpFirDn2dBackward.apply(
  104. grad_output,
  105. kernel,
  106. grad_kernel,
  107. ctx.up,
  108. ctx.down,
  109. ctx.pad,
  110. ctx.g_pad,
  111. ctx.in_size,
  112. ctx.out_size,
  113. )
  114. return grad_input, None, None, None, None
  115. def upfirdn2d(input, kernel, up=1, down=1, pad=(0, 0)):
  116. if not isinstance(up, abc.Iterable):
  117. up = (up, up)
  118. if not isinstance(down, abc.Iterable):
  119. down = (down, down)
  120. if len(pad) == 2:
  121. pad = (pad[0], pad[1], pad[0], pad[1])
  122. if input.device.type == 'cpu':
  123. out = upfirdn2d_native(input, kernel, *up, *down, *pad)
  124. else:
  125. out = UpFirDn2d.apply(input, kernel, up, down, pad)
  126. return out
  127. def upfirdn2d_native(input, kernel, up_x, up_y, down_x, down_y, pad_x0, pad_x1,
  128. pad_y0, pad_y1):
  129. _, channel, in_h, in_w = input.shape
  130. input = input.reshape(-1, in_h, in_w, 1)
  131. _, in_h, in_w, minor = input.shape
  132. kernel_h, kernel_w = kernel.shape
  133. out = input.view(-1, in_h, 1, in_w, 1, minor)
  134. out = F.pad(out, [0, 0, 0, up_x - 1, 0, 0, 0, up_y - 1])
  135. out = out.view(-1, in_h * up_y, in_w * up_x, minor)
  136. out = F.pad(
  137. out,
  138. [0, 0,
  139. max(pad_x0, 0),
  140. max(pad_x1, 0),
  141. max(pad_y0, 0),
  142. max(pad_y1, 0)])
  143. out = out[:,
  144. max(-pad_y0, 0):out.shape[1] - max(-pad_y1, 0),
  145. max(-pad_x0, 0):out.shape[2] - max(-pad_x1, 0), :, ]
  146. out = out.permute(0, 3, 1, 2)
  147. out = out.reshape(
  148. [-1, 1, in_h * up_y + pad_y0 + pad_y1, in_w * up_x + pad_x0 + pad_x1])
  149. w = torch.flip(kernel, [0, 1]).view(1, 1, kernel_h, kernel_w)
  150. out = F.conv2d(out, w)
  151. out = out.reshape(
  152. -1,
  153. minor,
  154. in_h * up_y + pad_y0 + pad_y1 - kernel_h + 1,
  155. in_w * up_x + pad_x0 + pad_x1 - kernel_w + 1,
  156. )
  157. out = out.permute(0, 2, 3, 1)
  158. out = out[:, ::down_y, ::down_x, :]
  159. out_h = (in_h * up_y + pad_y0 + pad_y1 - kernel_h + down_y) // down_y
  160. out_w = (in_w * up_x + pad_x0 + pad_x1 - kernel_w + down_x) // down_x
  161. return out.view(-1, channel, out_h, out_w)