speech_separation_pipeline.py 8.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243
  1. # Copyright (c) Alibaba, Inc. and its affiliates.
  2. import os
  3. from typing import Any, Dict, List, Sequence, Tuple, Union
  4. import json
  5. import yaml
  6. from funasr.utils import asr_utils
  7. from modelscope.metainfo import Pipelines
  8. from modelscope.models import Model
  9. from modelscope.outputs import OutputKeys
  10. from modelscope.pipelines.base import Pipeline
  11. from modelscope.pipelines.builder import PIPELINES
  12. from modelscope.utils.audio.audio_utils import (generate_scp_from_url,
  13. update_local_model)
  14. from modelscope.utils.constant import Frameworks, ModelFile, Tasks
  15. from modelscope.utils.logger import get_logger
# Module-level logger obtained from modelscope's logging helper; used by the
# pipeline class below to report progress.
logger = get_logger()
# Names exported by `from <this module> import *`.
__all__ = ['SeparationPipeline']
  18. @PIPELINES.register_module(
  19. Tasks.speech_separation, module_name=Pipelines.funasr_speech_separation)
  20. class SeparationPipeline(Pipeline):
  21. """Speech Separation Inference Pipeline
  22. use `model` to create a speech separation pipeline for prediction.
  23. Args:
  24. model: A model instance, or a model local dir, or a model id in the model hub.
  25. kwargs (dict, `optional`):
  26. Extra kwargs passed into the preprocessor's constructor.
  27. Example:
  28. >>> from modelscope.pipelines import pipeline
  29. >>> pipeline = pipeline(
  30. >>> task=Tasks.speech_separation, model='damo/speech_separation_mossformer_8k_pytorch')
  31. >>> audio_in = 'mix_speech.wav'
  32. >>> print(pipeline(audio_in))
  33. """
  34. def __init__(self,
  35. model: Union[Model, str] = None,
  36. ngpu: int = 1,
  37. **kwargs):
  38. """use `model` to create an speech separation pipeline for prediction
  39. """
  40. super().__init__(model=model, **kwargs)
  41. config_path = os.path.join(model, ModelFile.CONFIGURATION)
  42. self.cmd = self.get_cmd(config_path, kwargs, model)
  43. from funasr.bin import ss_inference_launch
  44. self.funasr_infer_modelscope = ss_inference_launch.inference_launch(
  45. mode=self.cmd['mode'],
  46. batch_size=self.cmd['batch_size'],
  47. ngpu=ngpu,
  48. log_level=self.cmd['log_level'],
  49. ss_infer_config=self.cmd['ss_infer_config'],
  50. ss_model_file=self.cmd['ss_model_file'],
  51. output_dir=self.cmd['output_dir'],
  52. dtype=self.cmd['dtype'],
  53. seed=self.cmd['seed'],
  54. num_workers=self.cmd['num_workers'],
  55. num_spks=self.cmd['num_spks'],
  56. param_dict=self.cmd['param_dict'],
  57. **kwargs,
  58. )
  59. def __call__(self,
  60. audio_in: Union[str, bytes],
  61. audio_fs: int = None,
  62. recog_type: str = None,
  63. audio_format: str = None,
  64. output_dir: str = None,
  65. param_dict: dict = None,
  66. **kwargs) -> Dict[str, Any]:
  67. """
  68. Decoding the input audios
  69. Args:
  70. audio_in('str' or 'bytes'):
  71. - A string containing a local path to a wav file
  72. - A string containing a local path to a scp
  73. - A string containing a wav url
  74. - A bytes input
  75. audio_fs('int'):
  76. frequency of sample
  77. recog_type('str'):
  78. recog type for wav file or datasets file ('wav', 'test', 'dev', 'train')
  79. audio_format('str'):
  80. audio format ('pcm', 'scp', 'kaldi_ark', 'tfrecord')
  81. output_dir('str'):
  82. output dir
  83. param_dict('dict'):
  84. extra kwargs
  85. Return:
  86. A dictionary of result or a list of dictionary of result.
  87. The dictionary contain the following keys:
  88. - **text** ('str') --The vad result.
  89. """
  90. self.audio_in = None
  91. self.raw_inputs = None
  92. self.recog_type = recog_type
  93. self.audio_format = audio_format
  94. self.audio_fs = None
  95. checking_audio_fs = None
  96. if output_dir is not None:
  97. self.cmd['output_dir'] = output_dir
  98. if param_dict is not None:
  99. self.cmd['param_dict'] = param_dict
  100. if isinstance(audio_in, str):
  101. # for funasr code, generate wav.scp from url or local path
  102. self.audio_in, self.raw_inputs = generate_scp_from_url(audio_in)
  103. elif isinstance(audio_in, bytes):
  104. self.audio_in = audio_in
  105. self.raw_inputs = None
  106. else:
  107. import numpy
  108. import torch
  109. if isinstance(audio_in, torch.Tensor):
  110. self.audio_in = None
  111. self.raw_inputs = audio_in
  112. elif isinstance(audio_in, numpy.ndarray):
  113. self.audio_in = None
  114. self.raw_inputs = audio_in
  115. # set the sample_rate of audio_in if checking_audio_fs is valid
  116. if checking_audio_fs is not None:
  117. self.audio_fs = checking_audio_fs
  118. if recog_type is None or audio_format is None:
  119. self.recog_type, self.audio_format, self.audio_in = asr_utils.type_checking(
  120. audio_in=self.audio_in,
  121. recog_type=recog_type,
  122. audio_format=audio_format)
  123. if hasattr(asr_utils,
  124. 'sample_rate_checking') and self.audio_in is not None:
  125. checking_audio_fs = asr_utils.sample_rate_checking(
  126. self.audio_in, self.audio_format)
  127. if checking_audio_fs is not None:
  128. self.audio_fs = checking_audio_fs
  129. if audio_fs is not None:
  130. self.cmd['fs']['audio_fs'] = audio_fs
  131. else:
  132. self.cmd['fs']['audio_fs'] = self.audio_fs
  133. output = self.forward(self.audio_in, **kwargs)
  134. return output
  135. def get_cmd(self, config_path, extra_args, model_path) -> Dict[str, Any]:
  136. model_cfg = json.loads(open(config_path).read())
  137. model_dir = os.path.dirname(config_path)
  138. # generate inference command
  139. ss_model_path = os.path.join(
  140. model_dir, model_cfg['model']['model_config']['ss_model_name'])
  141. ss_model_config = os.path.join(
  142. model_dir, model_cfg['model']['model_config']['ss_model_config'])
  143. mode = model_cfg['model']['model_config']['mode']
  144. frontend_conf = None
  145. if os.path.exists(ss_model_config):
  146. config_file = open(ss_model_config, encoding='utf-8')
  147. root = yaml.full_load(config_file)
  148. config_file.close()
  149. if 'frontend_conf' in root:
  150. frontend_conf = root['frontend_conf']
  151. update_local_model(model_cfg['model']['model_config'], model_path,
  152. extra_args)
  153. cmd = {
  154. 'mode': mode,
  155. 'batch_size': 1,
  156. 'ngpu': 1, # 0: only CPU, ngpu>=1: gpu number if cuda is available
  157. 'log_level': 'ERROR',
  158. 'ss_infer_config': ss_model_config,
  159. 'ss_model_file': ss_model_path,
  160. 'output_dir': None,
  161. 'dtype': 'float32',
  162. 'seed': 0,
  163. 'num_workers': 0,
  164. 'num_spks': 2,
  165. 'param_dict': None,
  166. 'fs': {
  167. 'model_fs': None,
  168. 'audio_fs': None
  169. }
  170. }
  171. if frontend_conf is not None and 'fs' in frontend_conf:
  172. cmd['fs']['model_fs'] = frontend_conf['fs']
  173. user_args_dict = [
  174. 'output_dir', 'batch_size', 'mode', 'ngpu', 'param_dict',
  175. 'num_workers', 'fs'
  176. ]
  177. for user_args in user_args_dict:
  178. if user_args in extra_args:
  179. if extra_args.get(user_args) is not None:
  180. cmd[user_args] = extra_args[user_args]
  181. del extra_args[user_args]
  182. return cmd
  183. def postprocess(self, inputs: Dict[str, Any],
  184. **post_params) -> Dict[str, Any]:
  185. return inputs
  186. def forward(self, audio_in: Dict[str, Any], **kwargs) -> Dict[str, Any]:
  187. """Decoding
  188. """
  189. logger.info('Speech Separation Processing ...')
  190. # generate inputs
  191. data_cmd: Sequence[Tuple[str, str, str]]
  192. if isinstance(self.audio_in, bytes):
  193. data_cmd = [self.audio_in, 'speech', 'bytes']
  194. elif isinstance(self.audio_in, str):
  195. data_cmd = [self.audio_in, 'speech', 'sound']
  196. elif self.raw_inputs is not None:
  197. data_cmd = None
  198. self.cmd['name_and_type'] = data_cmd
  199. self.cmd['raw_inputs'] = self.raw_inputs
  200. self.cmd['audio_in'] = self.audio_in
  201. ss_result = self.run_inference(self.cmd, **kwargs)
  202. return ss_result
  203. def run_inference(self, cmd, **kwargs):
  204. ss_result = []
  205. if self.framework == Frameworks.torch:
  206. ss_result = self.funasr_infer_modelscope(
  207. data_path_and_name_and_type=cmd['name_and_type'],
  208. raw_inputs=cmd['raw_inputs'],
  209. output_dir_v2=cmd['output_dir'],
  210. fs=cmd['fs'],
  211. param_dict=cmd['param_dict'],
  212. **kwargs)
  213. else:
  214. raise ValueError('model type is mismatching')
  215. return ss_result