complex_nn.py 6.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206
  1. # Copyright (c) Alibaba, Inc. and its affiliates.
  2. #
  3. # The implementation of class ComplexConv2d, ComplexConvTranspose2d and
  4. # ComplexBatchNorm2d here is modified based on Jongho Choi(sweetcocoa@snu.ac.kr
  5. # / Seoul National Univ., ESTsoft ) and publicly available at
  6. # https://github.com/sweetcocoa/DeepComplexUNetPyTorch
  7. import torch
  8. import torch.nn as nn
  9. from modelscope.models.audio.ans.layers.uni_deep_fsmn import UniDeepFsmn
  10. class ComplexUniDeepFsmn(nn.Module):
  11. def __init__(self, nIn, nHidden=128, nOut=128):
  12. super(ComplexUniDeepFsmn, self).__init__()
  13. self.fsmn_re_L1 = UniDeepFsmn(nIn, nHidden, 20, nHidden)
  14. self.fsmn_im_L1 = UniDeepFsmn(nIn, nHidden, 20, nHidden)
  15. self.fsmn_re_L2 = UniDeepFsmn(nHidden, nOut, 20, nHidden)
  16. self.fsmn_im_L2 = UniDeepFsmn(nHidden, nOut, 20, nHidden)
  17. def forward(self, x):
  18. r"""
  19. Args:
  20. x: torch with shape [batch, channel, feature, sequence, 2], eg: [6, 256, 1, 106, 2]
  21. Returns:
  22. [batch, feature, sequence, 2], eg: [6, 99, 1024, 2]
  23. """
  24. #
  25. b, c, h, T, d = x.size()
  26. x = torch.reshape(x, (b, c * h, T, d))
  27. # x: [b,h,T,2], [6, 256, 106, 2]
  28. x = torch.transpose(x, 1, 2)
  29. # x: [b,T,h,2], [6, 106, 256, 2]
  30. real_L1 = self.fsmn_re_L1(x[..., 0]) - self.fsmn_im_L1(x[..., 1])
  31. imaginary_L1 = self.fsmn_re_L1(x[..., 1]) + self.fsmn_im_L1(x[..., 0])
  32. # GRU output: [99, 6, 128]
  33. real = self.fsmn_re_L2(real_L1) - self.fsmn_im_L2(imaginary_L1)
  34. imaginary = self.fsmn_re_L2(imaginary_L1) + self.fsmn_im_L2(real_L1)
  35. # output: [b,T,h,2], [99, 6, 1024, 2]
  36. output = torch.stack((real, imaginary), dim=-1)
  37. # output: [b,h,T,2], [6, 99, 1024, 2]
  38. output = torch.transpose(output, 1, 2)
  39. output = torch.reshape(output, (b, c, h, T, d))
  40. return output
  41. class ComplexUniDeepFsmn_L1(nn.Module):
  42. def __init__(self, nIn, nHidden=128, nOut=128):
  43. super(ComplexUniDeepFsmn_L1, self).__init__()
  44. self.fsmn_re_L1 = UniDeepFsmn(nIn, nHidden, 20, nHidden)
  45. self.fsmn_im_L1 = UniDeepFsmn(nIn, nHidden, 20, nHidden)
  46. def forward(self, x):
  47. r"""
  48. Args:
  49. x: torch with shape [batch, channel, feature, sequence, 2], eg: [6, 256, 1, 106, 2]
  50. """
  51. b, c, h, T, d = x.size()
  52. # x : [b,T,h,c,2]
  53. x = torch.transpose(x, 1, 3)
  54. x = torch.reshape(x, (b * T, h, c, d))
  55. real = self.fsmn_re_L1(x[..., 0]) - self.fsmn_im_L1(x[..., 1])
  56. imaginary = self.fsmn_re_L1(x[..., 1]) + self.fsmn_im_L1(x[..., 0])
  57. # output: [b*T,h,c,2], [6*106, h, 256, 2]
  58. output = torch.stack((real, imaginary), dim=-1)
  59. output = torch.reshape(output, (b, T, h, c, d))
  60. output = torch.transpose(output, 1, 3)
  61. return output
  62. class ComplexConv2d(nn.Module):
  63. # https://github.com/litcoderr/ComplexCNN/blob/master/complexcnn/modules.py
  64. def __init__(self,
  65. in_channel,
  66. out_channel,
  67. kernel_size,
  68. stride=1,
  69. padding=0,
  70. dilation=1,
  71. groups=1,
  72. bias=True,
  73. **kwargs):
  74. super().__init__()
  75. # Model components
  76. self.conv_re = nn.Conv2d(
  77. in_channel,
  78. out_channel,
  79. kernel_size,
  80. stride=stride,
  81. padding=padding,
  82. dilation=dilation,
  83. groups=groups,
  84. bias=bias,
  85. **kwargs)
  86. self.conv_im = nn.Conv2d(
  87. in_channel,
  88. out_channel,
  89. kernel_size,
  90. stride=stride,
  91. padding=padding,
  92. dilation=dilation,
  93. groups=groups,
  94. bias=bias,
  95. **kwargs)
  96. def forward(self, x):
  97. r"""
  98. Args:
  99. x: torch with shape: [batch,channel,axis1,axis2,2]
  100. """
  101. real = self.conv_re(x[..., 0]) - self.conv_im(x[..., 1])
  102. imaginary = self.conv_re(x[..., 1]) + self.conv_im(x[..., 0])
  103. output = torch.stack((real, imaginary), dim=-1)
  104. return output
  105. class ComplexConvTranspose2d(nn.Module):
  106. def __init__(self,
  107. in_channel,
  108. out_channel,
  109. kernel_size,
  110. stride=1,
  111. padding=0,
  112. output_padding=0,
  113. dilation=1,
  114. groups=1,
  115. bias=True,
  116. **kwargs):
  117. super().__init__()
  118. # Model components
  119. self.tconv_re = nn.ConvTranspose2d(
  120. in_channel,
  121. out_channel,
  122. kernel_size=kernel_size,
  123. stride=stride,
  124. padding=padding,
  125. output_padding=output_padding,
  126. groups=groups,
  127. bias=bias,
  128. dilation=dilation,
  129. **kwargs)
  130. self.tconv_im = nn.ConvTranspose2d(
  131. in_channel,
  132. out_channel,
  133. kernel_size=kernel_size,
  134. stride=stride,
  135. padding=padding,
  136. output_padding=output_padding,
  137. groups=groups,
  138. bias=bias,
  139. dilation=dilation,
  140. **kwargs)
  141. def forward(self, x): # shape of x : [batch,channel,axis1,axis2,2]
  142. real = self.tconv_re(x[..., 0]) - self.tconv_im(x[..., 1])
  143. imaginary = self.tconv_re(x[..., 1]) + self.tconv_im(x[..., 0])
  144. output = torch.stack((real, imaginary), dim=-1)
  145. return output
  146. class ComplexBatchNorm2d(nn.Module):
  147. def __init__(self,
  148. num_features,
  149. eps=1e-5,
  150. momentum=0.1,
  151. affine=True,
  152. track_running_stats=True,
  153. **kwargs):
  154. super().__init__()
  155. self.bn_re = nn.BatchNorm2d(
  156. num_features=num_features,
  157. momentum=momentum,
  158. affine=affine,
  159. eps=eps,
  160. track_running_stats=track_running_stats,
  161. **kwargs)
  162. self.bn_im = nn.BatchNorm2d(
  163. num_features=num_features,
  164. momentum=momentum,
  165. affine=affine,
  166. eps=eps,
  167. track_running_stats=track_running_stats,
  168. **kwargs)
  169. def forward(self, x):
  170. real = self.bn_re(x[..., 0])
  171. imag = self.bn_im(x[..., 1])
  172. output = torch.stack((real, imag), dim=-1)
  173. return output