# unet.py
#
# Copyright (c) Alibaba, Inc. and its affiliates.
#
# The implementation here is modified based on
# Jongho Choi(sweetcocoa@snu.ac.kr / Seoul National Univ., ESTsoft )
# and publicly available at
# https://github.com/sweetcocoa/DeepComplexUNetPyTorch
  7. import torch
  8. import torch.nn as nn
  9. from . import complex_nn
  10. from .se_module_complex import SELayer
  11. class Encoder(nn.Module):
  12. def __init__(self,
  13. in_channels,
  14. out_channels,
  15. kernel_size,
  16. stride,
  17. padding=None,
  18. complex=False,
  19. padding_mode='zeros'):
  20. super().__init__()
  21. if padding is None:
  22. padding = [(i - 1) // 2 for i in kernel_size] # 'SAME' padding
  23. if complex:
  24. conv = complex_nn.ComplexConv2d
  25. bn = complex_nn.ComplexBatchNorm2d
  26. else:
  27. conv = nn.Conv2d
  28. bn = nn.BatchNorm2d
  29. self.conv = conv(
  30. in_channels,
  31. out_channels,
  32. kernel_size=kernel_size,
  33. stride=stride,
  34. padding=padding,
  35. padding_mode=padding_mode)
  36. self.bn = bn(out_channels)
  37. self.relu = nn.LeakyReLU(inplace=True)
  38. def forward(self, x):
  39. x = self.conv(x)
  40. x = self.bn(x)
  41. x = self.relu(x)
  42. return x
  43. class Decoder(nn.Module):
  44. def __init__(self,
  45. in_channels,
  46. out_channels,
  47. kernel_size,
  48. stride,
  49. padding=(0, 0),
  50. complex=False):
  51. super().__init__()
  52. if complex:
  53. tconv = complex_nn.ComplexConvTranspose2d
  54. bn = complex_nn.ComplexBatchNorm2d
  55. else:
  56. tconv = nn.ConvTranspose2d
  57. bn = nn.BatchNorm2d
  58. self.transconv = tconv(
  59. in_channels,
  60. out_channels,
  61. kernel_size=kernel_size,
  62. stride=stride,
  63. padding=padding)
  64. self.bn = bn(out_channels)
  65. self.relu = nn.LeakyReLU(inplace=True)
  66. def forward(self, x):
  67. x = self.transconv(x)
  68. x = self.bn(x)
  69. x = self.relu(x)
  70. return x
  71. class UNet(nn.Module):
  72. def __init__(self,
  73. input_channels=1,
  74. complex=False,
  75. model_complexity=45,
  76. model_depth=20,
  77. padding_mode='zeros'):
  78. super().__init__()
  79. if complex:
  80. model_complexity = int(model_complexity // 1.414)
  81. self.set_size(
  82. model_complexity=model_complexity,
  83. input_channels=input_channels,
  84. model_depth=model_depth)
  85. self.encoders = []
  86. self.model_length = model_depth // 2
  87. self.fsmn = complex_nn.ComplexUniDeepFsmn(128, 128, 128)
  88. self.se_layers_enc = []
  89. self.fsmn_enc = []
  90. for i in range(self.model_length):
  91. fsmn_enc = complex_nn.ComplexUniDeepFsmn_L1(128, 128, 128)
  92. self.add_module('fsmn_enc{}'.format(i), fsmn_enc)
  93. self.fsmn_enc.append(fsmn_enc)
  94. module = Encoder(
  95. self.enc_channels[i],
  96. self.enc_channels[i + 1],
  97. kernel_size=self.enc_kernel_sizes[i],
  98. stride=self.enc_strides[i],
  99. padding=self.enc_paddings[i],
  100. complex=complex,
  101. padding_mode=padding_mode)
  102. self.add_module('encoder{}'.format(i), module)
  103. self.encoders.append(module)
  104. se_layer_enc = SELayer(self.enc_channels[i + 1], 8)
  105. self.add_module('se_layer_enc{}'.format(i), se_layer_enc)
  106. self.se_layers_enc.append(se_layer_enc)
  107. self.decoders = []
  108. self.fsmn_dec = []
  109. self.se_layers_dec = []
  110. for i in range(self.model_length):
  111. fsmn_dec = complex_nn.ComplexUniDeepFsmn_L1(128, 128, 128)
  112. self.add_module('fsmn_dec{}'.format(i), fsmn_dec)
  113. self.fsmn_dec.append(fsmn_dec)
  114. module = Decoder(
  115. self.dec_channels[i] * 2,
  116. self.dec_channels[i + 1],
  117. kernel_size=self.dec_kernel_sizes[i],
  118. stride=self.dec_strides[i],
  119. padding=self.dec_paddings[i],
  120. complex=complex)
  121. self.add_module('decoder{}'.format(i), module)
  122. self.decoders.append(module)
  123. if i < self.model_length - 1:
  124. se_layer_dec = SELayer(self.dec_channels[i + 1], 8)
  125. self.add_module('se_layer_dec{}'.format(i), se_layer_dec)
  126. self.se_layers_dec.append(se_layer_dec)
  127. if complex:
  128. conv = complex_nn.ComplexConv2d
  129. else:
  130. conv = nn.Conv2d
  131. linear = conv(self.dec_channels[-1], 1, 1)
  132. self.add_module('linear', linear)
  133. self.complex = complex
  134. self.padding_mode = padding_mode
  135. self.decoders = nn.ModuleList(self.decoders)
  136. self.encoders = nn.ModuleList(self.encoders)
  137. self.se_layers_enc = nn.ModuleList(self.se_layers_enc)
  138. self.se_layers_dec = nn.ModuleList(self.se_layers_dec)
  139. self.fsmn_enc = nn.ModuleList(self.fsmn_enc)
  140. self.fsmn_dec = nn.ModuleList(self.fsmn_dec)
  141. def forward(self, inputs):
  142. x = inputs
  143. # go down
  144. xs = []
  145. xs_se = []
  146. xs_se.append(x)
  147. for i, encoder in enumerate(self.encoders):
  148. xs.append(x)
  149. if i > 0:
  150. x = self.fsmn_enc[i](x)
  151. x = encoder(x)
  152. xs_se.append(self.se_layers_enc[i](x))
  153. # xs : x0=input x1 ... x9
  154. x = self.fsmn(x)
  155. p = x
  156. for i, decoder in enumerate(self.decoders):
  157. p = decoder(p)
  158. if i < self.model_length - 1:
  159. p = self.fsmn_dec[i](p)
  160. if i == self.model_length - 1:
  161. break
  162. if i < self.model_length - 2:
  163. p = self.se_layers_dec[i](p)
  164. p = torch.cat([p, xs_se[self.model_length - 1 - i]], dim=1)
  165. # cmp_spec: [12, 1, 513, 64, 2]
  166. cmp_spec = self.linear(p)
  167. return cmp_spec
  168. def set_size(self, model_complexity, model_depth=20, input_channels=1):
  169. if model_depth == 14:
  170. self.enc_channels = [
  171. input_channels, 128, 128, 128, 128, 128, 128, 128
  172. ]
  173. self.enc_kernel_sizes = [(5, 2), (5, 2), (5, 2), (5, 2), (5, 2),
  174. (5, 2), (2, 2)]
  175. self.enc_strides = [(2, 1), (2, 1), (2, 1), (2, 1), (2, 1), (2, 1),
  176. (2, 1)]
  177. self.enc_paddings = [(0, 1), (0, 1), (0, 1), (0, 1), (0, 1),
  178. (0, 1), (0, 1)]
  179. self.dec_channels = [64, 128, 128, 128, 128, 128, 128, 1]
  180. self.dec_kernel_sizes = [(2, 2), (5, 2), (5, 2), (5, 2), (6, 2),
  181. (5, 2), (5, 2)]
  182. self.dec_strides = [(2, 1), (2, 1), (2, 1), (2, 1), (2, 1), (2, 1),
  183. (2, 1)]
  184. self.dec_paddings = [(0, 1), (0, 1), (0, 1), (0, 1), (0, 1),
  185. (0, 1), (0, 1)]
  186. elif model_depth == 10:
  187. self.enc_channels = [
  188. input_channels,
  189. 16,
  190. 32,
  191. 64,
  192. 128,
  193. 256,
  194. ]
  195. self.enc_kernel_sizes = [(3, 3), (3, 3), (3, 3), (3, 3), (3, 3)]
  196. self.enc_strides = [(2, 1), (2, 1), (2, 1), (2, 1), (2, 1)]
  197. self.enc_paddings = [(0, 1), (0, 1), (0, 1), (0, 1), (0, 1)]
  198. self.dec_channels = [128, 128, 64, 32, 16, 1]
  199. self.dec_kernel_sizes = [(3, 3), (3, 3), (3, 3), (4, 3), (3, 3)]
  200. self.dec_strides = [(2, 1), (2, 1), (2, 1), (2, 1), (2, 1)]
  201. self.dec_paddings = [(0, 1), (0, 1), (0, 1), (0, 1), (0, 1)]
  202. elif model_depth == 20:
  203. self.enc_channels = [
  204. input_channels, model_complexity, model_complexity,
  205. model_complexity * 2, model_complexity * 2,
  206. model_complexity * 2, model_complexity * 2,
  207. model_complexity * 2, model_complexity * 2,
  208. model_complexity * 2, 128
  209. ]
  210. self.enc_kernel_sizes = [(7, 1), (1, 7), (6, 4), (7, 5), (5, 3),
  211. (5, 3), (5, 3), (5, 3), (5, 3), (5, 3)]
  212. self.enc_strides = [(1, 1), (1, 1), (2, 2), (2, 1), (2, 2), (2, 1),
  213. (2, 2), (2, 1), (2, 2), (2, 1)]
  214. self.enc_paddings = [
  215. (3, 0),
  216. (0, 3),
  217. None, # (0, 2),
  218. None,
  219. None, # (3,1),
  220. None, # (3,1),
  221. None, # (1,2),
  222. None,
  223. None,
  224. None
  225. ]
  226. self.dec_channels = [
  227. 0, model_complexity * 2, model_complexity * 2,
  228. model_complexity * 2, model_complexity * 2,
  229. model_complexity * 2, model_complexity * 2,
  230. model_complexity * 2, model_complexity * 2,
  231. model_complexity * 2, model_complexity * 2,
  232. model_complexity * 2
  233. ]
  234. self.dec_kernel_sizes = [(4, 3), (4, 2), (4, 3), (4, 2), (4, 3),
  235. (4, 2), (6, 3), (7, 4), (1, 7), (7, 1)]
  236. self.dec_strides = [(2, 1), (2, 2), (2, 1), (2, 2), (2, 1), (2, 2),
  237. (2, 1), (2, 2), (1, 1), (1, 1)]
  238. self.dec_paddings = [(1, 1), (1, 0), (1, 1), (1, 0), (1, 1),
  239. (1, 0), (2, 1), (2, 1), (0, 3), (3, 0)]
  240. else:
  241. raise ValueError('Unknown model depth : {}'.format(model_depth))