ans_pipeline.py 10 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247
  1. # Copyright (c) Alibaba, Inc. and its affiliates.
  2. import io
  3. from typing import Any, Dict
  4. import librosa
  5. import numpy as np
  6. import soundfile as sf
  7. import torch
  8. from modelscope.fileio import File
  9. from modelscope.metainfo import Pipelines
  10. from modelscope.outputs import OutputKeys
  11. from modelscope.pipelines.base import Input, Pipeline
  12. from modelscope.pipelines.builder import PIPELINES
  13. from modelscope.utils.audio.audio_utils import audio_norm
  14. from modelscope.utils.constant import Tasks
  15. @PIPELINES.register_module(
  16. Tasks.acoustic_noise_suppression,
  17. module_name=Pipelines.speech_frcrn_ans_cirm_16k)
  18. class ANSPipeline(Pipeline):
  19. r"""ANS (Acoustic Noise Suppression) Inference Pipeline .
  20. When invoke the class with pipeline.__call__(), it accept only one parameter:
  21. inputs(str): the path of wav file
  22. """
  23. SAMPLE_RATE = 16000
  24. def __init__(self, model, **kwargs):
  25. """
  26. use `model` and `preprocessor` to create a kws pipeline for prediction
  27. Args:
  28. model: model id on modelscope hub.
  29. """
  30. super().__init__(model=model, **kwargs)
  31. self.model.eval()
  32. self.stream_mode = kwargs.get('stream_mode', False)
  33. def preprocess(self, inputs: Input, **preprocess_params) -> Dict[str, Any]:
  34. if self.stream_mode:
  35. raise TypeError('This model does not support stream mode!')
  36. if isinstance(inputs, bytes):
  37. data1, fs = sf.read(io.BytesIO(inputs))
  38. elif isinstance(inputs, str):
  39. file_bytes = File.read(inputs)
  40. data1, fs = sf.read(io.BytesIO(file_bytes))
  41. else:
  42. raise TypeError(f'Unsupported type {type(inputs)}.')
  43. if len(data1.shape) > 1:
  44. data1 = data1[:, 0]
  45. if fs != self.SAMPLE_RATE:
  46. data1 = librosa.resample(
  47. data1, orig_sr=fs, target_sr=self.SAMPLE_RATE)
  48. data1 = audio_norm(data1)
  49. data = data1.astype(np.float32)
  50. inputs = np.reshape(data, [1, data.shape[0]])
  51. return {'ndarray': inputs, 'nsamples': data.shape[0]}
  52. def forward(self, inputs: Dict[str, Any],
  53. **forward_params) -> Dict[str, Any]:
  54. ndarray = inputs['ndarray']
  55. if isinstance(ndarray, torch.Tensor):
  56. ndarray = ndarray.cpu().numpy()
  57. nsamples = inputs['nsamples']
  58. decode_do_segement = False
  59. window = 16000
  60. stride = int(window * 0.75)
  61. print('inputs:{}'.format(ndarray.shape))
  62. b, t = ndarray.shape # size()
  63. if t > window * 120:
  64. decode_do_segement = True
  65. if t < window:
  66. ndarray = np.concatenate(
  67. [ndarray, np.zeros((ndarray.shape[0], window - t))], 1)
  68. elif t < window + stride:
  69. padding = window + stride - t
  70. print('padding: {}'.format(padding))
  71. ndarray = np.concatenate(
  72. [ndarray, np.zeros((ndarray.shape[0], padding))], 1)
  73. else:
  74. if (t - window) % stride != 0:
  75. padding = t - (t - window) // stride * stride
  76. print('padding: {}'.format(padding))
  77. ndarray = np.concatenate(
  78. [ndarray, np.zeros((ndarray.shape[0], padding))], 1)
  79. print('inputs after padding:{}'.format(ndarray.shape))
  80. with torch.no_grad():
  81. ndarray = torch.from_numpy(np.float32(ndarray)).to(self.device)
  82. b, t = ndarray.shape
  83. if decode_do_segement:
  84. outputs = np.zeros(t)
  85. give_up_length = (window - stride) // 2
  86. current_idx = 0
  87. while current_idx + window <= t:
  88. print('current_idx: {}'.format(current_idx))
  89. tmp_input = dict(noisy=ndarray[:, current_idx:current_idx
  90. + window])
  91. tmp_output = self.model(
  92. tmp_input, )['wav_l2'][0].cpu().numpy()
  93. end_index = current_idx + window - give_up_length
  94. if current_idx == 0:
  95. outputs[current_idx:
  96. end_index] = tmp_output[:-give_up_length]
  97. else:
  98. outputs[current_idx
  99. + give_up_length:end_index] = tmp_output[
  100. give_up_length:-give_up_length]
  101. current_idx += stride
  102. else:
  103. outputs = self.model(
  104. dict(noisy=ndarray))['wav_l2'][0].cpu().numpy()
  105. outputs = (outputs[:nsamples] * 32768).astype(np.int16).tobytes()
  106. return {OutputKeys.OUTPUT_PCM: outputs}
  107. def postprocess(self, inputs: Dict[str, Any], **kwargs) -> Dict[str, Any]:
  108. if 'output_path' in kwargs.keys():
  109. sf.write(
  110. kwargs['output_path'],
  111. np.frombuffer(inputs[OutputKeys.OUTPUT_PCM], dtype=np.int16),
  112. self.SAMPLE_RATE)
  113. return inputs
  114. @PIPELINES.register_module(
  115. Tasks.acoustic_noise_suppression,
  116. module_name=Pipelines.speech_zipenhancer_ans_multiloss_16k_base)
  117. class ANSZipEnhancerPipeline(Pipeline):
  118. r"""ANS (Acoustic Noise Suppression) Inference Pipeline .
  119. When invoke the class with pipeline.__call__(), it accept only one parameter:
  120. inputs(str): the path of wav file
  121. """
  122. SAMPLE_RATE = 16000
  123. def __init__(self, model, **kwargs):
  124. """
  125. use `model` and `preprocessor` to create a kws pipeline for prediction
  126. Args:
  127. model: model id on modelscope hub.
  128. """
  129. super().__init__(model=model, **kwargs)
  130. self.model.eval()
  131. self.stream_mode = kwargs.get('stream_mode', False)
  132. def preprocess(self, inputs: Input, **preprocess_params) -> Dict[str, Any]:
  133. if self.stream_mode:
  134. raise TypeError('This model does not support stream mode!')
  135. if isinstance(inputs, bytes):
  136. data1, fs = sf.read(io.BytesIO(inputs))
  137. elif isinstance(inputs, str):
  138. file_bytes = File.read(inputs)
  139. data1, fs = sf.read(io.BytesIO(file_bytes))
  140. else:
  141. raise TypeError(f'Unsupported type {type(inputs)}.')
  142. if len(data1.shape) > 1:
  143. data1 = data1[:, 0]
  144. if fs != self.SAMPLE_RATE:
  145. data1 = librosa.resample(
  146. data1, orig_sr=fs, target_sr=self.SAMPLE_RATE)
  147. data1 = audio_norm(data1)
  148. data = data1.astype(np.float32)
  149. inputs = np.reshape(data, [1, data.shape[0]])
  150. return {'ndarray': inputs, 'nsamples': data.shape[0]}
  151. def forward(self, inputs: Dict[str, Any],
  152. **forward_params) -> Dict[str, Any]:
  153. ndarray = inputs['ndarray']
  154. if isinstance(ndarray, torch.Tensor):
  155. ndarray = ndarray.cpu().numpy()
  156. nsamples = inputs['nsamples']
  157. decode_do_segement = False
  158. window = 16000 * 2 # 2s
  159. stride = int(window * 0.75)
  160. print('inputs:{}'.format(ndarray.shape))
  161. b, t = ndarray.shape # size()
  162. if t > window * 3: # 6s
  163. decode_do_segement = True
  164. print('decode_do_segement')
  165. if t < window:
  166. ndarray = np.concatenate(
  167. [ndarray, np.zeros((ndarray.shape[0], window - t))], 1)
  168. elif decode_do_segement:
  169. if t < window + stride:
  170. padding = window + stride - t
  171. print('padding: {}'.format(padding))
  172. ndarray = np.concatenate(
  173. [ndarray, np.zeros((ndarray.shape[0], padding))], 1)
  174. else:
  175. if (t - window) % stride != 0:
  176. # padding = t - (t - window) // stride * stride
  177. padding = (
  178. (t - window) // stride + 1) * stride + window - t
  179. print('padding: {}'.format(padding))
  180. ndarray = np.concatenate(
  181. [ndarray,
  182. np.zeros((ndarray.shape[0], padding))], 1)
  183. # else:
  184. # if (t - window) % stride != 0:
  185. # padding = t - (t - window) // stride * stride
  186. # print('padding: {}'.format(padding))
  187. # ndarray = np.concatenate(
  188. # [ndarray, np.zeros((ndarray.shape[0], padding))], 1)
  189. print('inputs after padding:{}'.format(ndarray.shape))
  190. with torch.no_grad():
  191. ndarray = torch.from_numpy(np.float32(ndarray)).to(self.device)
  192. b, t = ndarray.shape
  193. if decode_do_segement:
  194. outputs = np.zeros(t)
  195. give_up_length = (window - stride) // 2
  196. current_idx = 0
  197. while current_idx + window <= t:
  198. # print('current_idx: {}'.format(current_idx))
  199. print(
  200. '\rcurrent_idx: {} {:.2f}%'.format(
  201. current_idx, current_idx * 100 / t),
  202. end='')
  203. tmp_input = dict(noisy=ndarray[:, current_idx:current_idx
  204. + window])
  205. tmp_output = self.model(
  206. tmp_input, )['wav_l2'][0].cpu().numpy()
  207. end_index = current_idx + window - give_up_length
  208. if current_idx == 0:
  209. outputs[current_idx:
  210. end_index] = tmp_output[:-give_up_length]
  211. else:
  212. outputs[current_idx
  213. + give_up_length:end_index] = tmp_output[
  214. give_up_length:-give_up_length]
  215. current_idx += stride
  216. print('\rcurrent_idx: {} {:.2f}%'.format(current_idx, 100))
  217. else:
  218. outputs = self.model(
  219. dict(noisy=ndarray))['wav_l2'][0].cpu().numpy()
  220. outputs = (outputs[:nsamples] * 32768).astype(np.int16).tobytes()
  221. return {OutputKeys.OUTPUT_PCM: outputs}
  222. def postprocess(self, inputs: Dict[str, Any], **kwargs) -> Dict[str, Any]:
  223. if 'output_path' in kwargs.keys():
  224. sf.write(
  225. kwargs['output_path'],
  226. np.frombuffer(inputs[OutputKeys.OUTPUT_PCM], dtype=np.int16),
  227. self.SAMPLE_RATE)
  228. return inputs