frcrn.py 10 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280
  1. # Copyright (c) Alibaba, Inc. and its affiliates.
  2. import os
  3. from typing import Dict
  4. import torch
  5. import torch.nn as nn
  6. import torch.nn.functional as F
  7. from modelscope.metainfo import Models
  8. from modelscope.models import TorchModel
  9. from modelscope.models.base import Tensor
  10. from modelscope.models.builder import MODELS
  11. from modelscope.utils.constant import ModelFile, Tasks
  12. from .conv_stft import ConviSTFT, ConvSTFT
  13. from .unet import UNet
  14. @MODELS.register_module(
  15. Tasks.acoustic_noise_suppression,
  16. module_name=Models.speech_frcrn_ans_cirm_16k)
  17. class FRCRNDecorator(TorchModel):
  18. r""" A decorator of FRCRN for integrating into modelscope framework """
  19. def __init__(self, model_dir: str, *args, **kwargs):
  20. """initialize the frcrn model from the `model_dir` path.
  21. Args:
  22. model_dir (str): the model path.
  23. """
  24. super().__init__(model_dir, *args, **kwargs)
  25. self.model = FRCRN(*args, **kwargs)
  26. model_bin_file = os.path.join(model_dir,
  27. ModelFile.TORCH_MODEL_BIN_FILE)
  28. if os.path.exists(model_bin_file):
  29. checkpoint = torch.load(
  30. model_bin_file, map_location=torch.device('cpu'))
  31. if isinstance(checkpoint, dict) and 'state_dict' in checkpoint:
  32. # the new trained model by user is based on FRCRNDecorator
  33. self.load_state_dict(checkpoint['state_dict'])
  34. else:
  35. # The released model on Modelscope is based on FRCRN
  36. self.model.load_state_dict(checkpoint, strict=False)
  37. def forward(self, inputs: Dict[str, Tensor]) -> Dict[str, Tensor]:
  38. result_list = self.model.forward(inputs['noisy'])
  39. output = {
  40. 'spec_l1': result_list[0],
  41. 'wav_l1': result_list[1],
  42. 'mask_l1': result_list[2],
  43. 'spec_l2': result_list[3],
  44. 'wav_l2': result_list[4],
  45. 'mask_l2': result_list[5]
  46. }
  47. if 'clean' in inputs:
  48. mix_result = self.model.loss(
  49. inputs['noisy'], inputs['clean'], result_list, mode='Mix')
  50. output.update(mix_result)
  51. sisnr_result = self.model.loss(
  52. inputs['noisy'], inputs['clean'], result_list, mode='SiSNR')
  53. output.update(sisnr_result)
  54. # logger hooker will use items under 'log_vars'
  55. output['log_vars'] = {k: mix_result[k].item() for k in mix_result}
  56. output['log_vars'].update(
  57. {k: sisnr_result[k].item()
  58. for k in sisnr_result})
  59. return output
  60. class FRCRN(nn.Module):
  61. r""" Frequency Recurrent CRN """
  62. def __init__(self,
  63. complex,
  64. model_complexity,
  65. model_depth,
  66. log_amp,
  67. padding_mode,
  68. win_len=400,
  69. win_inc=100,
  70. fft_len=512,
  71. win_type='hann',
  72. **kwargs):
  73. r"""
  74. Args:
  75. complex: Whether to use complex networks.
  76. model_complexity: define the model complexity with the number of layers
  77. model_depth: Only two options are available : 10, 20
  78. log_amp: Whether to use log amplitude to estimate signals
  79. padding_mode: Encoder's convolution filter. 'zeros', 'reflect'
  80. win_len: length of window used for defining one frame of sample points
  81. win_inc: length of window shifting (equivalent to hop_size)
  82. fft_len: number of Short Time Fourier Transform (STFT) points
  83. win_type: windowing type used in STFT, eg. 'hanning', 'hamming'
  84. """
  85. super().__init__()
  86. self.feat_dim = fft_len // 2 + 1
  87. self.win_len = win_len
  88. self.win_inc = win_inc
  89. self.fft_len = fft_len
  90. self.win_type = win_type
  91. fix = True
  92. self.stft = ConvSTFT(
  93. self.win_len,
  94. self.win_inc,
  95. self.fft_len,
  96. self.win_type,
  97. feature_type='complex',
  98. fix=fix)
  99. self.istft = ConviSTFT(
  100. self.win_len,
  101. self.win_inc,
  102. self.fft_len,
  103. self.win_type,
  104. feature_type='complex',
  105. fix=fix)
  106. self.unet = UNet(
  107. 1,
  108. complex=complex,
  109. model_complexity=model_complexity,
  110. model_depth=model_depth,
  111. padding_mode=padding_mode)
  112. self.unet2 = UNet(
  113. 1,
  114. complex=complex,
  115. model_complexity=model_complexity,
  116. model_depth=model_depth,
  117. padding_mode=padding_mode)
  118. def forward(self, inputs):
  119. out_list = []
  120. # [B, D*2, T]
  121. cmp_spec = self.stft(inputs)
  122. # [B, 1, D*2, T]
  123. cmp_spec = torch.unsqueeze(cmp_spec, 1)
  124. # to [B, 2, D, T] real_part/imag_part
  125. cmp_spec = torch.cat([
  126. cmp_spec[:, :, :self.feat_dim, :],
  127. cmp_spec[:, :, self.feat_dim:, :],
  128. ], 1)
  129. # [B, 2, D, T]
  130. cmp_spec = torch.unsqueeze(cmp_spec, 4)
  131. # [B, 1, D, T, 2]
  132. cmp_spec = torch.transpose(cmp_spec, 1, 4)
  133. unet1_out = self.unet(cmp_spec)
  134. cmp_mask1 = torch.tanh(unet1_out)
  135. unet2_out = self.unet2(unet1_out)
  136. cmp_mask2 = torch.tanh(unet2_out)
  137. est_spec, est_wav, est_mask = self.apply_mask(cmp_spec, cmp_mask1)
  138. out_list.append(est_spec)
  139. out_list.append(est_wav)
  140. out_list.append(est_mask)
  141. cmp_mask2 = cmp_mask2 + cmp_mask1
  142. est_spec, est_wav, est_mask = self.apply_mask(cmp_spec, cmp_mask2)
  143. out_list.append(est_spec)
  144. out_list.append(est_wav)
  145. out_list.append(est_mask)
  146. return out_list
  147. def apply_mask(self, cmp_spec, cmp_mask):
  148. est_spec = torch.cat([
  149. cmp_spec[:, :, :, :, 0] * cmp_mask[:, :, :, :, 0]
  150. - cmp_spec[:, :, :, :, 1] * cmp_mask[:, :, :, :, 1],
  151. cmp_spec[:, :, :, :, 0] * cmp_mask[:, :, :, :, 1]
  152. + cmp_spec[:, :, :, :, 1] * cmp_mask[:, :, :, :, 0]
  153. ], 1)
  154. est_spec = torch.cat([est_spec[:, 0, :, :], est_spec[:, 1, :, :]], 1)
  155. cmp_mask = torch.squeeze(cmp_mask, 1)
  156. cmp_mask = torch.cat([cmp_mask[:, :, :, 0], cmp_mask[:, :, :, 1]], 1)
  157. est_wav = self.istft(est_spec)
  158. est_wav = torch.squeeze(est_wav, 1)
  159. return est_spec, est_wav, cmp_mask
  160. def get_params(self, weight_decay=0.0):
  161. # add L2 penalty
  162. weights, biases = [], []
  163. for name, param in self.named_parameters():
  164. if 'bias' in name:
  165. biases += [param]
  166. else:
  167. weights += [param]
  168. params = [{
  169. 'params': weights,
  170. 'weight_decay': weight_decay,
  171. }, {
  172. 'params': biases,
  173. 'weight_decay': 0.0,
  174. }]
  175. return params
  176. def loss(self, noisy, labels, out_list, mode='Mix'):
  177. if mode == 'SiSNR':
  178. count = 0
  179. while count < len(out_list):
  180. est_spec = out_list[count]
  181. count = count + 1
  182. est_wav = out_list[count]
  183. count = count + 1
  184. est_mask = out_list[count]
  185. count = count + 1
  186. if count != 3:
  187. loss = self.loss_1layer(noisy, est_spec, est_wav, labels,
  188. est_mask, mode)
  189. return dict(sisnr=loss)
  190. elif mode == 'Mix':
  191. count = 0
  192. while count < len(out_list):
  193. est_spec = out_list[count]
  194. count = count + 1
  195. est_wav = out_list[count]
  196. count = count + 1
  197. est_mask = out_list[count]
  198. count = count + 1
  199. if count != 3:
  200. amp_loss, phase_loss, SiSNR_loss = self.loss_1layer(
  201. noisy, est_spec, est_wav, labels, est_mask, mode)
  202. loss = amp_loss + phase_loss + SiSNR_loss
  203. return dict(loss=loss, amp_loss=amp_loss, phase_loss=phase_loss)
  204. def loss_1layer(self, noisy, est, est_wav, labels, cmp_mask, mode='Mix'):
  205. r""" Compute the loss by mode
  206. mode == 'Mix'
  207. est: [B, F*2, T]
  208. labels: [B, F*2,T]
  209. mode == 'SiSNR'
  210. est: [B, T]
  211. labels: [B, T]
  212. """
  213. if mode == 'SiSNR':
  214. if labels.dim() == 3:
  215. labels = torch.squeeze(labels, 1)
  216. if est_wav.dim() == 3:
  217. est_wav = torch.squeeze(est_wav, 1)
  218. return -si_snr(est_wav, labels)
  219. elif mode == 'Mix':
  220. if labels.dim() == 3:
  221. labels = torch.squeeze(labels, 1)
  222. if est_wav.dim() == 3:
  223. est_wav = torch.squeeze(est_wav, 1)
  224. SiSNR_loss = -si_snr(est_wav, labels)
  225. b, d, t = est.size()
  226. S = self.stft(labels)
  227. Sr = S[:, :self.feat_dim, :]
  228. Si = S[:, self.feat_dim:, :]
  229. Y = self.stft(noisy)
  230. Yr = Y[:, :self.feat_dim, :]
  231. Yi = Y[:, self.feat_dim:, :]
  232. Y_pow = Yr**2 + Yi**2
  233. gth_mask = torch.cat([(Sr * Yr + Si * Yi) / (Y_pow + 1e-8),
  234. (Si * Yr - Sr * Yi) / (Y_pow + 1e-8)], 1)
  235. gth_mask[gth_mask > 2] = 1
  236. gth_mask[gth_mask < -2] = -1
  237. amp_loss = F.mse_loss(gth_mask[:, :self.feat_dim, :],
  238. cmp_mask[:, :self.feat_dim, :]) * d
  239. phase_loss = F.mse_loss(gth_mask[:, self.feat_dim:, :],
  240. cmp_mask[:, self.feat_dim:, :]) * d
  241. return amp_loss, phase_loss, SiSNR_loss
  242. def l2_norm(s1, s2):
  243. norm = torch.sum(s1 * s2, -1, keepdim=True)
  244. return norm
  245. def si_snr(s1, s2, eps=1e-8):
  246. s1_s2_norm = l2_norm(s1, s2)
  247. s2_s2_norm = l2_norm(s2, s2)
  248. s_target = s1_s2_norm / (s2_s2_norm + eps) * s2
  249. e_noise = s1 - s_target
  250. target_norm = l2_norm(s_target, s_target)
  251. noise_norm = l2_norm(e_noise, e_noise)
  252. snr = 10 * torch.log10((target_norm) / (noise_norm + eps) + eps)
  253. return torch.mean(snr)