linear_aec_pipeline.py 5.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169
  1. # Copyright (c) Alibaba, Inc. and its affiliates.
  2. import importlib
  3. import os
  4. from typing import Any, Dict
  5. import numpy as np
  6. import scipy.io.wavfile as wav
  7. import torch
  8. import yaml
  9. from modelscope.metainfo import Pipelines
  10. from modelscope.outputs import OutputKeys
  11. from modelscope.pipelines.base import Pipeline
  12. from modelscope.pipelines.builder import PIPELINES
  13. from modelscope.preprocessors import LinearAECAndFbank
  14. from modelscope.utils.constant import ModelFile, Tasks
  15. from modelscope.utils.logger import get_logger
  16. logger = get_logger()
  17. FEATURE_MVN = 'feature.DEY.mvn.txt'
  18. CONFIG_YAML = 'dey_mini.yaml'
  19. def initialize_config(module_cfg):
  20. r"""According to config items, load specific module dynamically with params.
  21. 1. Load the module corresponding to the "module" param.
  22. 2. Call function (or instantiate class) corresponding to the "main" param.
  23. 3. Send the param (in "args") into the function (or class) when calling ( or instantiating).
  24. Args:
  25. module_cfg (dict): config items, eg:
  26. {
  27. "module": "models.model",
  28. "main": "Model",
  29. "args": {...}
  30. }
  31. Returns:
  32. the module loaded.
  33. """
  34. module = importlib.import_module(module_cfg['module'])
  35. return getattr(module, module_cfg['main'])(**module_cfg['args'])
  36. @PIPELINES.register_module(
  37. Tasks.acoustic_echo_cancellation,
  38. module_name=Pipelines.speech_dfsmn_aec_psm_16k)
  39. class LinearAECPipeline(Pipeline):
  40. r"""AEC Inference Pipeline only support 16000 sample rate.
  41. When invoke the class with pipeline.__call__(), you should provide two params:
  42. Dict[str, Any]
  43. the path of wav files, eg:{
  44. "nearend_mic": "/your/data/near_end_mic_audio.wav",
  45. "farend_speech": "/your/data/far_end_speech_audio.wav"}
  46. output_path (str, optional): "/your/output/audio_after_aec.wav"
  47. the file path to write generate audio.
  48. """
  49. def __init__(self, model, **kwargs):
  50. """
  51. use `model` and `preprocessor` to create a kws pipeline for prediction
  52. Args:
  53. model: model id on modelscope hub.
  54. """
  55. super().__init__(model=model, **kwargs)
  56. self.check_trust_remote_code(
  57. 'This pipeline requires `trust_remote_code=True` to load the module defined'
  58. ' in the `dey_mini.yaml`, setting this to True means you trust the code and files'
  59. ' listed in this model repo.')
  60. self.use_cuda = torch.cuda.is_available()
  61. with open(
  62. os.path.join(self.model, CONFIG_YAML), encoding='utf-8') as f:
  63. self.config = yaml.full_load(f.read())
  64. self.config['io']['mvn'] = os.path.join(self.model, FEATURE_MVN)
  65. self._init_model()
  66. self.preprocessor = LinearAECAndFbank(self.config['io'])
  67. n_fft = self.config['loss']['args']['n_fft']
  68. hop_length = self.config['loss']['args']['hop_length']
  69. winlen = n_fft
  70. window = torch.hamming_window(winlen, periodic=False)
  71. def stft(x):
  72. return torch.view_as_real(
  73. torch.stft(
  74. x,
  75. n_fft,
  76. hop_length,
  77. winlen,
  78. center=False,
  79. window=window.to(x.device),
  80. return_complex=True))
  81. def istft(x, slen):
  82. return torch.istft(
  83. torch.view_as_complex(x),
  84. n_fft,
  85. hop_length,
  86. winlen,
  87. window=window.to(x.device),
  88. center=False,
  89. length=slen)
  90. self.stft = stft
  91. self.istft = istft
  92. def _init_model(self):
  93. checkpoint = torch.load(
  94. os.path.join(self.model, ModelFile.TORCH_MODEL_BIN_FILE),
  95. map_location='cpu',
  96. weights_only=True)
  97. self.model = initialize_config(self.config['nnet'])
  98. if self.use_cuda:
  99. self.model = self.model.cuda()
  100. self.model.load_state_dict(checkpoint)
  101. def forward(self, inputs: Dict[str, Any]) -> Dict[str, Any]:
  102. r"""The AEC process.
  103. Args:
  104. inputs: dict={'feature': Tensor, 'base': Tensor}
  105. 'feature' feature of input audio.
  106. 'base' the base audio to mask.
  107. Returns:
  108. output_pcm: generated audio array
  109. """
  110. output_data = self._process(inputs['feature'], inputs['base'])
  111. output_data = output_data.astype(np.int16).tobytes()
  112. return {OutputKeys.OUTPUT_PCM: output_data}
  113. def postprocess(self, inputs: Dict[str, Any], **kwargs) -> Dict[str, Any]:
  114. r"""The post process. Will save audio to file, if the output_path is given.
  115. Args:
  116. inputs: a dict contains following keys:
  117. - output_pcm: generated audio array
  118. kwargs: accept 'output_path' which is the path to write generated audio
  119. Returns:
  120. output_pcm: generated audio array
  121. """
  122. if 'output_path' in kwargs.keys():
  123. wav.write(
  124. kwargs['output_path'], self.preprocessor.SAMPLE_RATE,
  125. np.frombuffer(inputs[OutputKeys.OUTPUT_PCM], dtype=np.int16))
  126. return inputs
  127. def _process(self, fbanks, mixture):
  128. if self.use_cuda:
  129. fbanks = fbanks.cuda()
  130. mixture = mixture.cuda()
  131. if self.model.vad:
  132. with torch.no_grad():
  133. masks, vad = self.model(fbanks.unsqueeze(0))
  134. masks = masks.permute([2, 1, 0])
  135. else:
  136. with torch.no_grad():
  137. masks = self.model(fbanks.unsqueeze(0))
  138. masks = masks.permute([2, 1, 0])
  139. spectrum = self.stft(mixture)
  140. masked_spec = spectrum * masks
  141. masked_sig = self.istft(masked_spec, len(mixture)).cpu().numpy()
  142. return masked_sig