# kws_farfield_pipeline.py
#
# Copyright (c) Alibaba, Inc. and its affiliates.
import io
import wave
from typing import Any, Dict, Optional

import numpy
import soundfile as sf

from modelscope.fileio import File
from modelscope.metainfo import Pipelines
from modelscope.outputs import OutputKeys
from modelscope.pipelines.base import Input, Pipeline
from modelscope.pipelines.builder import PIPELINES
from modelscope.utils.constant import Tasks
  13. @PIPELINES.register_module(
  14. Tasks.keyword_spotting,
  15. module_name=Pipelines.speech_dfsmn_kws_char_farfield)
  16. class KWSFarfieldPipeline(Pipeline):
  17. r"""A Keyword Spotting Inference Pipeline .
  18. When invoke the class with pipeline.__call__(), it accept only one parameter:
  19. inputs(str): the path of wav file
  20. """
  21. SAMPLE_RATE = 16000
  22. SAMPLE_WIDTH = 2
  23. INPUT_CHANNELS = 3
  24. OUTPUT_CHANNELS = 2
  25. def __init__(self, model, **kwargs):
  26. """
  27. use `model` to create a kws far field pipeline for prediction
  28. Args:
  29. model: model id on modelscope hub.
  30. """
  31. super().__init__(model=model, **kwargs)
  32. self.model = self.model.to(self.device)
  33. self.model.eval()
  34. frame_size = self.INPUT_CHANNELS * self.SAMPLE_WIDTH
  35. self._nframe = self.model.size_in // frame_size
  36. if 'keyword_map' in kwargs:
  37. self._keyword_map = kwargs['keyword_map']
  38. else:
  39. self._keyword_map = {}
  40. def _sanitize_parameters(self, **pipeline_parameters):
  41. return pipeline_parameters, pipeline_parameters, pipeline_parameters
  42. def preprocess(self, inputs: Input, **preprocess_params) -> Dict[str, Any]:
  43. if isinstance(inputs, bytes):
  44. return dict(input_file=inputs)
  45. elif isinstance(inputs, str):
  46. return dict(input_file=inputs)
  47. elif isinstance(inputs, Dict):
  48. return inputs
  49. else:
  50. raise ValueError(f'Not supported input type: {type(inputs)}')
  51. def forward(self, inputs: Dict[str, Any],
  52. **forward_params) -> Dict[str, Any]:
  53. input_file = inputs['input_file']
  54. if isinstance(input_file, str):
  55. input_file = File.read(input_file)
  56. frames, samplerate = sf.read(io.BytesIO(input_file), dtype='int16')
  57. if len(frames.shape) == 1:
  58. frames = numpy.stack((frames, frames, numpy.zeros_like(frames)), 1)
  59. kws_list = []
  60. if 'output_file' in forward_params:
  61. with wave.open(forward_params['output_file'], 'wb') as fout:
  62. fout.setframerate(self.SAMPLE_RATE)
  63. fout.setnchannels(self.OUTPUT_CHANNELS)
  64. fout.setsampwidth(self.SAMPLE_WIDTH)
  65. self._process(frames, kws_list, fout)
  66. else:
  67. self._process(frames, kws_list)
  68. return {OutputKeys.KWS_LIST: kws_list}
  69. def _process(self,
  70. frames: numpy.ndarray,
  71. kws_list,
  72. fout: wave.Wave_write = None):
  73. for start_index in range(0, frames.shape[0], self._nframe):
  74. end_index = start_index + self._nframe
  75. if end_index > frames.shape[0]:
  76. end_index = frames.shape[0]
  77. data = frames[start_index:end_index, :].tobytes()
  78. result = self.model.forward_decode(data)
  79. if fout:
  80. fout.writeframes(result['pcm'])
  81. if 'kws' in result:
  82. result['kws']['offset'] += start_index / self.SAMPLE_RATE
  83. result['kws']['type'] = 'wakeup'
  84. keyword = result['kws']['keyword']
  85. if keyword in self._keyword_map:
  86. result['kws']['keyword'] = self._keyword_map[keyword]
  87. kws_list.append(result['kws'])
  88. def postprocess(self, inputs: Dict[str, Any], **kwargs) -> Dict[str, Any]:
  89. return inputs