ecb.py 10 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272
# The implementation is adapted from ECBSR,
# made publicly available under the Apache 2.0 License at
# https://github.com/xindongzhang/ECBSR/blob/main/models/ecb.py
  4. import torch
  5. import torch.nn as nn
  6. import torch.nn.functional as F
  7. class SeqConv3x3(nn.Module):
  8. def __init__(self, seq_type, inp_planes, out_planes, depth_multiplier):
  9. super(SeqConv3x3, self).__init__()
  10. self.type = seq_type
  11. self.inp_planes = inp_planes
  12. self.out_planes = out_planes
  13. if self.type == 'conv1x1-conv3x3':
  14. self.mid_planes = int(out_planes * depth_multiplier)
  15. conv0 = torch.nn.Conv2d(
  16. self.inp_planes, self.mid_planes, kernel_size=1, padding=0)
  17. self.k0 = conv0.weight
  18. self.b0 = conv0.bias
  19. conv1 = torch.nn.Conv2d(
  20. self.mid_planes, self.out_planes, kernel_size=3)
  21. self.k1 = conv1.weight
  22. self.b1 = conv1.bias
  23. elif self.type == 'conv1x1-sobelx':
  24. conv0 = torch.nn.Conv2d(
  25. self.inp_planes, self.out_planes, kernel_size=1, padding=0)
  26. self.k0 = conv0.weight
  27. self.b0 = conv0.bias
  28. # init scale & bias
  29. scale = torch.randn(size=(self.out_planes, 1, 1, 1)) * 1e-3
  30. self.scale = nn.Parameter(scale)
  31. # bias = 0.0
  32. # bias = [bias for c in range(self.out_planes)]
  33. # bias = torch.FloatTensor(bias)
  34. bias = torch.randn(self.out_planes) * 1e-3
  35. bias = torch.reshape(bias, (self.out_planes, ))
  36. self.bias = nn.Parameter(bias)
  37. # init mask
  38. self.mask = torch.zeros((self.out_planes, 1, 3, 3),
  39. dtype=torch.float32)
  40. for i in range(self.out_planes):
  41. self.mask[i, 0, 0, 0] = 1.0
  42. self.mask[i, 0, 1, 0] = 2.0
  43. self.mask[i, 0, 2, 0] = 1.0
  44. self.mask[i, 0, 0, 2] = -1.0
  45. self.mask[i, 0, 1, 2] = -2.0
  46. self.mask[i, 0, 2, 2] = -1.0
  47. self.mask = nn.Parameter(data=self.mask, requires_grad=False)
  48. elif self.type == 'conv1x1-sobely':
  49. conv0 = torch.nn.Conv2d(
  50. self.inp_planes, self.out_planes, kernel_size=1, padding=0)
  51. self.k0 = conv0.weight
  52. self.b0 = conv0.bias
  53. # init scale & bias
  54. scale = torch.randn(size=(self.out_planes, 1, 1, 1)) * 1e-3
  55. self.scale = nn.Parameter(torch.FloatTensor(scale))
  56. # bias = 0.0
  57. # bias = [bias for c in range(self.out_planes)]
  58. # bias = torch.FloatTensor(bias)
  59. bias = torch.randn(self.out_planes) * 1e-3
  60. bias = torch.reshape(bias, (self.out_planes, ))
  61. self.bias = nn.Parameter(torch.FloatTensor(bias))
  62. # init mask
  63. self.mask = torch.zeros((self.out_planes, 1, 3, 3),
  64. dtype=torch.float32)
  65. for i in range(self.out_planes):
  66. self.mask[i, 0, 0, 0] = 1.0
  67. self.mask[i, 0, 0, 1] = 2.0
  68. self.mask[i, 0, 0, 2] = 1.0
  69. self.mask[i, 0, 2, 0] = -1.0
  70. self.mask[i, 0, 2, 1] = -2.0
  71. self.mask[i, 0, 2, 2] = -1.0
  72. self.mask = nn.Parameter(data=self.mask, requires_grad=False)
  73. elif self.type == 'conv1x1-laplacian':
  74. conv0 = torch.nn.Conv2d(
  75. self.inp_planes, self.out_planes, kernel_size=1, padding=0)
  76. self.k0 = conv0.weight
  77. self.b0 = conv0.bias
  78. # init scale & bias
  79. scale = torch.randn(size=(self.out_planes, 1, 1, 1)) * 1e-3
  80. self.scale = nn.Parameter(torch.FloatTensor(scale))
  81. # bias = 0.0
  82. # bias = [bias for c in range(self.out_planes)]
  83. # bias = torch.FloatTensor(bias)
  84. bias = torch.randn(self.out_planes) * 1e-3
  85. bias = torch.reshape(bias, (self.out_planes, ))
  86. self.bias = nn.Parameter(torch.FloatTensor(bias))
  87. # init mask
  88. self.mask = torch.zeros((self.out_planes, 1, 3, 3),
  89. dtype=torch.float32)
  90. for i in range(self.out_planes):
  91. self.mask[i, 0, 0, 1] = 1.0
  92. self.mask[i, 0, 1, 0] = 1.0
  93. self.mask[i, 0, 1, 2] = 1.0
  94. self.mask[i, 0, 2, 1] = 1.0
  95. self.mask[i, 0, 1, 1] = -4.0
  96. self.mask = nn.Parameter(data=self.mask, requires_grad=False)
  97. else:
  98. raise ValueError('the type of seqconv is not supported!')
  99. def forward(self, x):
  100. if self.type == 'conv1x1-conv3x3':
  101. # conv-1x1
  102. y0 = F.conv2d(input=x, weight=self.k0, bias=self.b0, stride=1)
  103. # explicitly padding with bias
  104. y0 = F.pad(y0, (1, 1, 1, 1), 'constant', 0)
  105. b0_pad = self.b0.view(1, -1, 1, 1)
  106. y0[:, :, 0:1, :] = b0_pad
  107. y0[:, :, -1:, :] = b0_pad
  108. y0[:, :, :, 0:1] = b0_pad
  109. y0[:, :, :, -1:] = b0_pad
  110. # conv-3x3
  111. y1 = F.conv2d(input=y0, weight=self.k1, bias=self.b1, stride=1)
  112. else:
  113. y0 = F.conv2d(input=x, weight=self.k0, bias=self.b0, stride=1)
  114. # explicitly padding with bias
  115. y0 = F.pad(y0, (1, 1, 1, 1), 'constant', 0)
  116. b0_pad = self.b0.view(1, -1, 1, 1)
  117. y0[:, :, 0:1, :] = b0_pad
  118. y0[:, :, -1:, :] = b0_pad
  119. y0[:, :, :, 0:1] = b0_pad
  120. y0[:, :, :, -1:] = b0_pad
  121. # conv-3x3
  122. y1 = F.conv2d(
  123. input=y0,
  124. weight=self.scale * self.mask,
  125. bias=self.bias,
  126. stride=1,
  127. groups=self.out_planes)
  128. return y1
  129. def rep_params(self):
  130. device = self.k0.get_device()
  131. if device < 0:
  132. device = None
  133. if self.type == 'conv1x1-conv3x3':
  134. # re-param conv kernel
  135. RK = F.conv2d(input=self.k1, weight=self.k0.permute(1, 0, 2, 3))
  136. # re-param conv bias
  137. RB = torch.ones(
  138. 1, self.mid_planes, 3, 3, device=device) * self.b0.view(
  139. 1, -1, 1, 1)
  140. RB = F.conv2d(input=RB, weight=self.k1).view(-1, ) + self.b1
  141. else:
  142. tmp = self.scale * self.mask
  143. k1 = torch.zeros((self.out_planes, self.out_planes, 3, 3),
  144. device=device)
  145. for i in range(self.out_planes):
  146. k1[i, i, :, :] = tmp[i, 0, :, :]
  147. b1 = self.bias
  148. # re-param conv kernel
  149. RK = F.conv2d(input=k1, weight=self.k0.permute(1, 0, 2, 3))
  150. # re-param conv bias
  151. RB = torch.ones(
  152. 1, self.out_planes, 3, 3, device=device) * self.b0.view(
  153. 1, -1, 1, 1)
  154. RB = F.conv2d(input=RB, weight=k1).view(-1, ) + b1
  155. return RK, RB
  156. class ECB(nn.Module):
  157. def __init__(self,
  158. inp_planes,
  159. out_planes,
  160. depth_multiplier,
  161. act_type='prelu',
  162. with_idt=False):
  163. super(ECB, self).__init__()
  164. self.depth_multiplier = depth_multiplier
  165. self.inp_planes = inp_planes
  166. self.out_planes = out_planes
  167. self.act_type = act_type
  168. if with_idt and (self.inp_planes == self.out_planes):
  169. self.with_idt = True
  170. else:
  171. self.with_idt = False
  172. self.conv3x3 = torch.nn.Conv2d(
  173. self.inp_planes, self.out_planes, kernel_size=3, padding=1)
  174. self.conv1x1_3x3 = SeqConv3x3('conv1x1-conv3x3', self.inp_planes,
  175. self.out_planes, self.depth_multiplier)
  176. self.conv1x1_sbx = SeqConv3x3('conv1x1-sobelx', self.inp_planes,
  177. self.out_planes, -1)
  178. self.conv1x1_sby = SeqConv3x3('conv1x1-sobely', self.inp_planes,
  179. self.out_planes, -1)
  180. self.conv1x1_lpl = SeqConv3x3('conv1x1-laplacian', self.inp_planes,
  181. self.out_planes, -1)
  182. if self.act_type == 'prelu':
  183. self.act = nn.PReLU(num_parameters=self.out_planes)
  184. elif self.act_type == 'relu':
  185. self.act = nn.ReLU(inplace=True)
  186. elif self.act_type == 'rrelu':
  187. self.act = nn.RReLU(lower=-0.05, upper=0.05)
  188. elif self.act_type == 'softplus':
  189. self.act = nn.Softplus()
  190. elif self.act_type == 'linear':
  191. pass
  192. else:
  193. raise ValueError('The type of activation if not support!')
  194. def forward(self, x):
  195. if self.training:
  196. y = self.conv3x3(x) + \
  197. self.conv1x1_3x3(x) + \
  198. self.conv1x1_sbx(x) + \
  199. self.conv1x1_sby(x) + \
  200. self.conv1x1_lpl(x)
  201. if self.with_idt:
  202. y += x
  203. else:
  204. RK, RB = self.rep_params()
  205. y = F.conv2d(input=x, weight=RK, bias=RB, stride=1, padding=1)
  206. if self.act_type != 'linear':
  207. y = self.act(y)
  208. return y
  209. def rep_params(self):
  210. K0, B0 = self.conv3x3.weight, self.conv3x3.bias
  211. K1, B1 = self.conv1x1_3x3.rep_params()
  212. K2, B2 = self.conv1x1_sbx.rep_params()
  213. K3, B3 = self.conv1x1_sby.rep_params()
  214. K4, B4 = self.conv1x1_lpl.rep_params()
  215. RK, RB = (K0 + K1 + K2 + K3 + K4), (B0 + B1 + B2 + B3 + B4)
  216. if self.with_idt:
  217. device = RK.get_device()
  218. if device < 0:
  219. device = None
  220. K_idt = torch.zeros(
  221. self.out_planes, self.out_planes, 3, 3, device=device)
  222. for i in range(self.out_planes):
  223. K_idt[i, i, 1, 1] = 1.0
  224. B_idt = 0.0
  225. RK, RB = RK + K_idt, RB + B_idt
  226. return RK, RB
  227. if __name__ == '__main__':
  228. # # test seq-conv
  229. x = torch.randn(1, 3, 5, 5).cuda()
  230. conv = SeqConv3x3('conv1x1-conv3x3', 3, 3, 2).cuda()
  231. y0 = conv(x)
  232. RK, RB = conv.rep_params()
  233. y1 = F.conv2d(input=x, weight=RK, bias=RB, stride=1, padding=1)
  234. print(y0 - y1)
  235. # test ecb
  236. x = torch.randn(1, 3, 5, 5).cuda() * 200
  237. ecb = ECB(3, 3, 2, act_type='linear', with_idt=True).cuda()
  238. y0 = ecb(x)
  239. RK, RB = ecb.rep_params()
  240. y1 = F.conv2d(input=x, weight=RK, bias=RB, stride=1, padding=1)
  241. print(y0 - y1)