ans_dfsmn_pipeline.py 7.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190
  1. # Copyright (c) Alibaba, Inc. and its affiliates.
  2. import collections
  3. import io
  4. import os
  5. import sys
  6. from typing import Any, Dict
  7. import librosa
  8. import numpy as np
  9. import soundfile as sf
  10. import torch
  11. from modelscope.fileio import File
  12. from modelscope.metainfo import Pipelines
  13. from modelscope.outputs import OutputKeys
  14. from modelscope.pipelines.base import Input, Pipeline
  15. from modelscope.pipelines.builder import PIPELINES
  16. from modelscope.utils.constant import ModelFile, Tasks
  17. HOP_LENGTH = 960
  18. N_FFT = 1920
  19. WINDOW_NAME_HAM = 'hamming'
  20. STFT_WIN_LEN = 1920
  21. WINLEN = 3840
  22. STRIDE = 1920
  23. @PIPELINES.register_module(
  24. Tasks.acoustic_noise_suppression,
  25. module_name=Pipelines.speech_dfsmn_ans_psm_48k_causal)
  26. class ANSDFSMNPipeline(Pipeline):
  27. """ANS (Acoustic Noise Suppression) inference pipeline based on DFSMN model.
  28. Args:
  29. stream_mode: set its work mode, default False
  30. In stream model, it accepts bytes as pipeline input that should be the audio data in PCM format.
  31. In normal model, it accepts str and treat it as the path of local wav file or the http link of remote wav file.
  32. """
  33. SAMPLE_RATE = 48000
  34. def __init__(self, model, **kwargs):
  35. super().__init__(model=model, **kwargs)
  36. model_bin_file = os.path.join(self.model.model_dir,
  37. ModelFile.TORCH_MODEL_BIN_FILE)
  38. if os.path.exists(model_bin_file):
  39. checkpoint = torch.load(
  40. model_bin_file, map_location=self.device, weights_only=True)
  41. self.model.load_state_dict(checkpoint)
  42. self.model.eval()
  43. self.stream_mode = kwargs.get('stream_mode', False)
  44. if self.stream_mode:
  45. # the unit of WINLEN and STRIDE is frame, 1 frame of 16bit = 2 bytes
  46. byte_buffer_length = \
  47. (WINLEN + STRIDE * (self.model.lorder - 1)) * 2
  48. self.buffer = collections.deque(maxlen=byte_buffer_length)
  49. # padding head
  50. for i in range(STRIDE * 2):
  51. self.buffer.append(b'\0')
  52. # it processes WINLEN frames at the first time, then STRIDE frames
  53. self.byte_length_remain = (STRIDE * 2 - WINLEN) * 2
  54. self.first_forward = True
  55. self.tensor_give_up_length = (WINLEN - STRIDE) // 2
  56. window = torch.hamming_window(
  57. STFT_WIN_LEN, periodic=False, device=self.device)
  58. def stft(x):
  59. return torch.stft(
  60. x,
  61. N_FFT,
  62. HOP_LENGTH,
  63. STFT_WIN_LEN,
  64. center=False,
  65. window=window,
  66. return_complex=False)
  67. def istft(x, slen):
  68. return librosa.istft(
  69. x,
  70. hop_length=HOP_LENGTH,
  71. win_length=STFT_WIN_LEN,
  72. window=WINDOW_NAME_HAM,
  73. center=False,
  74. length=slen)
  75. self.stft = stft
  76. self.istft = istft
  77. def preprocess(self, inputs: Input, **preprocess_params) -> Dict[str, Any]:
  78. if self.stream_mode:
  79. if not isinstance(inputs, bytes):
  80. raise TypeError('Only support bytes in stream mode.')
  81. if len(inputs) > self.buffer.maxlen:
  82. raise ValueError(
  83. f'inputs length too large: {len(inputs)} > {self.buffer.maxlen}'
  84. )
  85. tensor_list = []
  86. current_index = 0
  87. while self.byte_length_remain + len(
  88. inputs) - current_index >= STRIDE * 2:
  89. byte_length_to_add = STRIDE * 2 - self.byte_length_remain
  90. for i in range(current_index,
  91. current_index + byte_length_to_add):
  92. self.buffer.append(inputs[i].to_bytes(
  93. 1, byteorder=sys.byteorder, signed=False))
  94. bytes_io = io.BytesIO()
  95. for b in self.buffer:
  96. bytes_io.write(b)
  97. data = np.frombuffer(bytes_io.getbuffer(), dtype=np.int16)
  98. data_tensor = torch.from_numpy(data).type(torch.FloatTensor)
  99. tensor_list.append(data_tensor)
  100. self.byte_length_remain = 0
  101. current_index += byte_length_to_add
  102. for i in range(current_index, len(inputs)):
  103. self.buffer.append(inputs[i].to_bytes(
  104. 1, byteorder=sys.byteorder, signed=False))
  105. self.byte_length_remain += 1
  106. return {'audio': tensor_list}
  107. else:
  108. if isinstance(inputs, str):
  109. data_bytes = File.read(inputs)
  110. elif isinstance(inputs, bytes):
  111. data_bytes = inputs
  112. else:
  113. raise TypeError(f'Unsupported type {type(inputs)}.')
  114. data_tensor = self.bytes2tensor(data_bytes)
  115. return {'audio': data_tensor}
  116. def bytes2tensor(self, file_bytes):
  117. data1, fs = sf.read(io.BytesIO(file_bytes))
  118. data1 = data1.astype(np.float32)
  119. if len(data1.shape) > 1:
  120. data1 = data1[:, 0]
  121. if fs != self.SAMPLE_RATE:
  122. data1 = librosa.resample(
  123. data1, orig_sr=fs, target_sr=self.SAMPLE_RATE)
  124. data = data1 * 32768
  125. data_tensor = torch.from_numpy(data).type(torch.FloatTensor)
  126. return data_tensor
  127. def forward(self, inputs: Dict[str, Any],
  128. **forward_params) -> Dict[str, Any]:
  129. if self.stream_mode:
  130. bytes_io = io.BytesIO()
  131. for origin_audio in inputs['audio']:
  132. masked_sig = self._forward(origin_audio)
  133. if self.first_forward:
  134. masked_sig = masked_sig[:-self.tensor_give_up_length]
  135. self.first_forward = False
  136. else:
  137. masked_sig = masked_sig[-WINLEN:]
  138. masked_sig = masked_sig[self.tensor_give_up_length:-self.
  139. tensor_give_up_length]
  140. bytes_io.write(masked_sig.astype(np.int16).tobytes())
  141. outputs = bytes_io.getvalue()
  142. else:
  143. origin_audio = inputs['audio']
  144. masked_sig = self._forward(origin_audio)
  145. outputs = masked_sig.astype(np.int16).tobytes()
  146. return {OutputKeys.OUTPUT_PCM: outputs}
  147. def _forward(self, origin_audio):
  148. with torch.no_grad():
  149. audio_in = origin_audio.unsqueeze(0)
  150. import torchaudio
  151. fbanks = torchaudio.compliance.kaldi.fbank(
  152. audio_in,
  153. dither=1.0,
  154. frame_length=40.0,
  155. frame_shift=20.0,
  156. num_mel_bins=120,
  157. sample_frequency=self.SAMPLE_RATE,
  158. window_type=WINDOW_NAME_HAM)
  159. fbanks = fbanks.unsqueeze(0)
  160. masks = self.model(fbanks)
  161. spectrum = self.stft(origin_audio)
  162. masks = masks.permute(2, 1, 0)
  163. masked_spec = (spectrum * masks).cpu()
  164. masked_spec = masked_spec.detach().numpy()
  165. masked_spec_complex = masked_spec[:, :, 0] + 1j * masked_spec[:, :, 1]
  166. masked_sig = self.istft(masked_spec_complex, len(origin_audio))
  167. return masked_sig
  168. def postprocess(self, inputs: Dict[str, Any], **kwargs) -> Dict[str, Any]:
  169. if not self.stream_mode and 'output_path' in kwargs.keys():
  170. sf.write(
  171. kwargs['output_path'],
  172. np.frombuffer(inputs[OutputKeys.OUTPUT_PCM], dtype=np.int16),
  173. self.SAMPLE_RATE)
  174. return inputs