denoise_net.py 2.5 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273
  1. # Copyright (c) Alibaba, Inc. and its affiliates.
  2. # Related papers:
  3. # Shengkui Zhao, Trung Hieu Nguyen, Bin Ma, “Monaural Speech Enhancement with Complex Convolutional
  4. # Block Attention Module and Joint Time Frequency Losses”, ICASSP 2021.
  5. # Shiliang Zhang, Ming Lei, Zhijie Yan, Lirong Dai, “Deep-FSMN for Large Vocabulary Continuous Speech
  6. # Recognition”, arXiv:1803.05030, 2018.
  7. from torch import nn
  8. from modelscope.metainfo import Models
  9. from modelscope.models import MODELS, TorchModel
  10. from modelscope.models.audio.ans.layers.activations import (RectifiedLinear,
  11. Sigmoid)
  12. from modelscope.models.audio.ans.layers.affine_transform import AffineTransform
  13. from modelscope.models.audio.ans.layers.uni_deep_fsmn import UniDeepFsmn
  14. from modelscope.utils.constant import Tasks
  15. @MODELS.register_module(
  16. Tasks.acoustic_noise_suppression, module_name=Models.speech_dfsmn_ans)
  17. class DfsmnAns(TorchModel):
  18. """Denoise model with DFSMN.
  19. Args:
  20. model_dir (str): the model path.
  21. fsmn_depth (int): the depth of deepfsmn
  22. lorder (int):
  23. """
  24. def __init__(self,
  25. model_dir: str,
  26. fsmn_depth=9,
  27. lorder=20,
  28. *args,
  29. **kwargs):
  30. super().__init__(model_dir, *args, **kwargs)
  31. self.lorder = lorder
  32. self.linear1 = AffineTransform(120, 256)
  33. self.relu = RectifiedLinear(256, 256)
  34. repeats = [
  35. UniDeepFsmn(256, 256, lorder, 256) for i in range(fsmn_depth)
  36. ]
  37. self.deepfsmn = nn.Sequential(*repeats)
  38. self.linear2 = AffineTransform(256, 961)
  39. self.sig = Sigmoid(961, 961)
  40. def forward(self, input):
  41. """
  42. Args:
  43. input: fbank feature [batch_size,number_of_frame,feature_dimension]
  44. Returns:
  45. mask value [batch_size, number_of_frame, FFT_size/2+1]
  46. """
  47. x1 = self.linear1(input)
  48. x2 = self.relu(x1)
  49. x3 = self.deepfsmn(x2)
  50. x4 = self.linear2(x3)
  51. x5 = self.sig(x4)
  52. return x5
  53. def to_kaldi_nnet(self):
  54. re_str = ''
  55. re_str += '<Nnet>\n'
  56. re_str += self.linear1.to_kaldi_nnet()
  57. re_str += self.relu.to_kaldi_nnet()
  58. for dfsmn in self.deepfsmn:
  59. re_str += dfsmn.to_kaldi_nnet()
  60. re_str += self.linear2.to_kaldi_nnet()
  61. re_str += self.sig.to_kaldi_nnet()
  62. re_str += '</Nnet>\n'
  63. return re_str