# audio_quantization_pipeline.py (8.3 KB)
  1. # Copyright (c) Alibaba, Inc. and its affiliates.
  2. import os
  3. import shutil
  4. from typing import Any, Dict, List, Sequence, Tuple, Union
  5. import numpy as np
  6. import yaml
  7. from modelscope.metainfo import Pipelines
  8. from modelscope.models import Model
  9. from modelscope.outputs import OutputKeys
  10. from modelscope.pipelines.base import Pipeline
  11. from modelscope.pipelines.builder import PIPELINES
  12. from modelscope.utils.audio.audio_utils import (generate_scp_from_url,
  13. update_local_model)
  14. from modelscope.utils.constant import Frameworks, Tasks
  15. from modelscope.utils.logger import get_logger
  16. logger = get_logger()
  17. __all__ = ['AudioQuantizationPipeline']
  18. @PIPELINES.register_module(
  19. Tasks.audio_quantization,
  20. module_name=Pipelines.audio_quantization_inference)
  21. class AudioQuantizationPipeline(Pipeline):
  22. """Audio Quantization Inference Pipeline
  23. use `model` to create a audio quantization pipeline.
  24. Args:
  25. model (AudioQuantizationPipeline): A model instance, or a model local dir, or a model id in the model hub.
  26. kwargs (dict, `optional`):
  27. Extra kwargs passed into the preprocessor's constructor.
  28. Examples:
  29. >>> from modelscope.pipelines import pipeline
  30. >>> from modelscope.utils.constant import Tasks
  31. >>> pipeline_aq = pipeline(
  32. >>> task=Tasks.audio_quantization,
  33. >>> model='damo/audio_codec-encodec-zh_en-general-16k-nq32ds640-pytorch'
  34. >>> )
  35. >>> audio_in='example.wav'
  36. >>> print(pipeline_aq(audio_in))
  37. """
  38. def __init__(self,
  39. model: Union[Model, str] = None,
  40. ngpu: int = 1,
  41. **kwargs):
  42. """use `model` to create an asr pipeline for prediction
  43. """
  44. super().__init__(model=model, **kwargs)
  45. self.model_cfg = self.model.forward()
  46. self.cmd = self.get_cmd(kwargs, model)
  47. from funcodec.bin import codec_inference
  48. self.funasr_infer_modelscope = codec_inference.inference_modelscope(
  49. mode=self.cmd['mode'],
  50. output_dir=self.cmd['output_dir'],
  51. batch_size=self.cmd['batch_size'],
  52. dtype=self.cmd['dtype'],
  53. ngpu=ngpu,
  54. seed=self.cmd['seed'],
  55. num_workers=self.cmd['num_workers'],
  56. log_level=self.cmd['log_level'],
  57. key_file=self.cmd['key_file'],
  58. config_file=self.cmd['config_file'],
  59. model_file=self.cmd['model_file'],
  60. model_tag=self.cmd['model_tag'],
  61. allow_variable_data_keys=self.cmd['allow_variable_data_keys'],
  62. streaming=self.cmd['streaming'],
  63. sampling_rate=self.cmd['sampling_rate'],
  64. bit_width=self.cmd['bit_width'],
  65. use_scale=self.cmd['use_scale'],
  66. param_dict=self.cmd['param_dict'],
  67. **kwargs,
  68. )
  69. def __call__(self,
  70. audio_in: Union[tuple, str, Any] = None,
  71. output_dir: str = None,
  72. param_dict: dict = None) -> Dict[str, Any]:
  73. if len(audio_in) == 0:
  74. raise ValueError('The input should not be null.')
  75. else:
  76. self.audio_in = audio_in
  77. if output_dir is not None:
  78. self.cmd['output_dir'] = output_dir
  79. self.cmd['param_dict'] = param_dict
  80. output = self.forward(self.audio_in)
  81. result = self.postprocess(output)
  82. return result
  83. def postprocess(self, inputs: list) -> Dict[str, Any]:
  84. """Postprocessing
  85. """
  86. rst = {}
  87. for i in range(len(inputs)):
  88. if len(inputs) == 1 and i == 0:
  89. recon_wav = inputs[0]['value']
  90. output_wav = recon_wav.cpu().numpy()[0]
  91. output_wav = (output_wav * (2**15)).astype(np.int16)
  92. rst[OutputKeys.OUTPUT_WAV] = output_wav
  93. else:
  94. # for multiple inputs
  95. rst[inputs[i]['key']] = inputs[i]['value']
  96. return rst
  97. def get_cmd(self, extra_args, model_path) -> Dict[str, Any]:
  98. # generate asr inference command
  99. mode = self.model_cfg['model_config']['mode']
  100. _model_path = os.path.join(
  101. self.model_cfg['model_workspace'],
  102. self.model_cfg['model_config']['model_file'])
  103. _model_config = os.path.join(
  104. self.model_cfg['model_workspace'],
  105. self.model_cfg['model_config']['config_file'])
  106. update_local_model(self.model_cfg['model_config'], model_path,
  107. extra_args)
  108. cmd = {
  109. 'mode': mode,
  110. 'output_dir': None,
  111. 'batch_size': 1,
  112. 'dtype': 'float32',
  113. 'ngpu': 1, # 0: only CPU, ngpu>=1: gpu number if cuda is available
  114. 'seed': 0,
  115. 'num_workers': 0,
  116. 'log_level': 'ERROR',
  117. 'key_file': None,
  118. 'model_file': _model_path,
  119. 'config_file': _model_config,
  120. 'model_tag': None,
  121. 'allow_variable_data_keys': True,
  122. 'streaming': False,
  123. 'sampling_rate': 16000,
  124. 'bit_width': 8000,
  125. 'use_scale': True,
  126. 'param_dict': None,
  127. }
  128. user_args_dict = [
  129. 'output_dir',
  130. 'batch_size',
  131. 'ngpu',
  132. 'log_level',
  133. 'allow_variable_data_keys',
  134. 'streaming',
  135. 'num_workers',
  136. 'sampling_rate',
  137. 'bit_width',
  138. 'use_scale',
  139. 'param_dict',
  140. ]
  141. # re-write the config with configure.json
  142. for user_args in user_args_dict:
  143. if (user_args in self.model_cfg['model_config']
  144. and self.model_cfg['model_config'][user_args] is not None):
  145. if isinstance(cmd[user_args], dict) and isinstance(
  146. self.model_cfg['model_config'][user_args], dict):
  147. cmd[user_args].update(
  148. self.model_cfg['model_config'][user_args])
  149. else:
  150. cmd[user_args] = self.model_cfg['model_config'][user_args]
  151. # rewrite the config with user args
  152. for user_args in user_args_dict:
  153. if user_args in extra_args:
  154. if extra_args.get(user_args) is not None:
  155. if isinstance(cmd[user_args], dict) and isinstance(
  156. extra_args[user_args], dict):
  157. cmd[user_args].update(extra_args[user_args])
  158. else:
  159. cmd[user_args] = extra_args[user_args]
  160. del extra_args[user_args]
  161. return cmd
  162. def forward(self, audio_in: Union[tuple, str, Any] = None) -> list:
  163. """Decoding
  164. """
  165. # log file_path/url or tuple (str, str)
  166. if isinstance(audio_in, str):
  167. logger.info(f'Audio Quantization Processing: {audio_in} ...')
  168. else:
  169. logger.info(
  170. f'Audio Quantization Processing: {str(audio_in)[:100]} ...')
  171. data_cmd, raw_inputs = None, None
  172. if isinstance(audio_in, str):
  173. # for scp inputs
  174. if len(audio_in.split(',')) == 3:
  175. data_cmd = [tuple(audio_in.split(','))]
  176. # for single-file inputs
  177. else:
  178. audio_scp, _ = generate_scp_from_url(audio_in)
  179. raw_inputs = audio_scp
  180. # for raw bytes
  181. elif isinstance(audio_in, bytes):
  182. data_cmd = (audio_in, 'speech', 'bytes')
  183. # for ndarray and tensor inputs
  184. else:
  185. import torch
  186. import numpy as np
  187. if isinstance(audio_in, torch.Tensor):
  188. raw_inputs = audio_in
  189. elif isinstance(audio_in, np.ndarray):
  190. raw_inputs = audio_in
  191. else:
  192. raise TypeError('Unsupported data type.')
  193. self.cmd['name_and_type'] = data_cmd
  194. self.cmd['raw_inputs'] = raw_inputs
  195. result = self.run_inference(self.cmd)
  196. return result
  197. def run_inference(self, cmd):
  198. if self.framework == Frameworks.torch:
  199. sv_result = self.funasr_infer_modelscope(
  200. data_path_and_name_and_type=cmd['name_and_type'],
  201. raw_inputs=cmd['raw_inputs'],
  202. output_dir_v2=cmd['output_dir'],
  203. param_dict=cmd['param_dict'])
  204. else:
  205. raise ValueError('model type is mismatching')
  206. return sv_result