asr.py 10.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234
  1. # Copyright (c) Alibaba, Inc. and its affiliates.
  2. import os
  3. from typing import Any, Dict, List, Union
  4. from modelscope.metainfo import Preprocessors
  5. from modelscope.models.base import Model
  6. from modelscope.utils.constant import Fields, Frameworks
  7. from .base import Preprocessor
  8. from .builder import PREPROCESSORS
  9. __all__ = ['WavToScp']
  @PREPROCESSORS.register_module(
      Fields.audio, module_name=Preprocessors.wav_to_scp)
  class WavToScp(Preprocessor):
      """Generate the audio scp command dict from wave or ark input.

      Calling an instance validates its arguments, obtains the model
      description from ``model.forward()`` and returns a ``cmd`` dict that
      carries the model paths, the settings read from the model config and,
      when running on a dataset, the dataset and reference-text locations.
      """

      def __init__(self):
          # NOTE(review): the base ``Preprocessor.__init__`` is not called
          # here (unchanged from the original); confirm the base class does
          # not need initialising.
          pass

      def __call__(self,
                   model: Model = None,
                   recog_type: str = None,
                   audio_format: str = None,
                   audio_in: Union[str, bytes] = None,
                   audio_fs: int = None,
                   cmd: Dict[str, Any] = None) -> Dict[str, Any]:
          """Validate the arguments and build the command dict.

          Args:
              model: model object; its ``forward()`` result is handed to
                  :meth:`forward` as the model description dict.
              recog_type: ``'wav'`` for direct audio input; any other value
                  is treated as a dataset sub-directory name.
              audio_format: audio format name, e.g. ``'wav'``, ``'pcm'``,
                  ``'kaldi_ark'`` or ``'tfrecord'``.
              audio_in: audio/dataset path (``str``) or raw audio bytes.
              audio_fs: value stored under ``cmd['audio_fs']``; may be None.
              cmd: optional dict to fill in; a new dict is created when
                  omitted.

          Returns:
              The populated command dict.

          Raises:
              AssertionError: if ``model``, ``recog_type``, ``audio_format``
                  or ``audio_in`` is missing.
          """
          assert model is not None, 'preprocess model is empty'
          assert recog_type is not None and len(
              recog_type) > 0, 'preprocess recog_type is empty'
          assert audio_format is not None, 'preprocess audio_format is empty'
          assert audio_in is not None, 'preprocess audio_in is empty'

          self.am_model = model
          # ``cmd`` must be forwarded: ``forward`` takes it as its last
          # argument, and omitting it used to raise a TypeError on every call.
          return self.forward(self.am_model.forward(), recog_type,
                              audio_format, audio_in, audio_fs, cmd)

      def forward(self,
                  model: Dict[str, Any],
                  recog_type: str,
                  audio_format: str,
                  audio_in: Union[str, bytes],
                  audio_fs: int,
                  cmd: Dict[str, Any] = None) -> Dict[str, Any]:
          """Fill ``cmd`` from the model description and the call arguments.

          Args:
              model: dict with the keys ``am_model``, ``am_model_path``,
                  ``model_workspace`` and ``model_config``.
              recog_type: see :meth:`__call__`.
              audio_format: see :meth:`__call__`.
              audio_in: when a ``str`` it is stored as ``cmd['wav_path']``;
                  other types are not stored.
              audio_fs: stored as ``cmd['audio_fs']``.
              cmd: dict to populate in place; a new dict is created when
                  None.

          Returns:
              ``cmd`` after :meth:`config_checking` (skipped when the model
              config's ``code_base`` is ``'funasr'``) and
              :meth:`env_setting` have been applied.

          Raises:
              AssertionError: if a required value is empty or a required
                  path does not exist.
          """
          assert len(recog_type) > 0, 'preprocess recog_type is empty'
          assert len(audio_format) > 0, 'preprocess audio_format is empty'
          assert len(
              model['am_model']) > 0, 'preprocess model[am_model] is empty'
          assert len(model['am_model_path']
                     ) > 0, 'preprocess model[am_model_path] is empty'
          assert os.path.exists(
              model['am_model_path']), 'preprocess am_model_path does not exist'
          assert len(model['model_workspace']
                     ) > 0, 'preprocess model[model_workspace] is empty'
          assert os.path.exists(model['model_workspace']
                                ), 'preprocess model_workspace does not exist'
          assert len(model['model_config']
                     ) > 0, 'preprocess model[model_config] is empty'

          if cmd is None:
              cmd = {}

          cmd['model_workspace'] = model['model_workspace']
          cmd['am_model'] = model['am_model']
          cmd['am_model_path'] = model['am_model_path']
          cmd['recog_type'] = recog_type
          cmd['audio_format'] = audio_format
          cmd['model_config'] = model['model_config']
          cmd['audio_fs'] = audio_fs

          code_base = self._config_value(cmd['model_config'], 'code_base')

          if isinstance(audio_in, str):
              # wav file path or the dataset path
              cmd['wav_path'] = audio_in

          if code_base != 'funasr':
              cmd = self.config_checking(cmd)
          cmd = self.env_setting(cmd)
          return cmd

      @staticmethod
      def _config_value(config, key: str, default=None):
          """Return ``config[key]`` when ``key in config``, else ``default``."""
          # Only ``in`` and ``[]`` are used so any config object that worked
          # with the original code keeps working.
          return config[key] if key in config else default

      @staticmethod
      def _workspace_path(inputs: Dict[str, Any], key: str) -> str:
          """Join the model workspace with the config entry named ``key``."""
          return os.path.join(inputs['model_workspace'],
                              inputs['model_config'][key])

      def config_checking(self, inputs: Dict[str, Any]) -> Dict[str, Any]:
          """Check the model config and copy its settings into ``inputs``.

          Sets ``model_type``, ``code_base``, ``mode`` and ``model_lang``
          (default ``'zh-cn'``), then resolves the framework-specific files
          relative to ``inputs['model_workspace']``.

          Raises:
              AssertionError: if a required config key is missing or a
                  required file does not exist.
              ValueError: if the model type is neither torch nor tf.
          """
          model_config = inputs['model_config']
          assert 'type' in model_config, 'model type does not exist'
          inputs['model_type'] = model_config['type']

          inputs['code_base'] = self._config_value(model_config, 'code_base')
          # decoding mode
          inputs['mode'] = self._config_value(model_config, 'mode')
          inputs['model_lang'] = self._config_value(model_config, 'lang',
                                                    'zh-cn')

          if inputs['model_type'] == Frameworks.torch:
              self._check_torch_config(inputs)
          elif inputs['model_type'] == Frameworks.tf:
              self._check_tf_config(inputs)
          else:
              raise ValueError('model type is mismatched')

          return inputs

      def _check_torch_config(self, inputs: Dict[str, Any]) -> None:
          """Resolve and validate the files used by a torch model."""
          model_config = inputs['model_config']
          assert 'batch_size' in model_config, 'batch_size does not exist'

          if 'am_model_config' in model_config:
              am_model_config = self._workspace_path(inputs, 'am_model_config')
              assert os.path.exists(
                  am_model_config), 'am_model_config does not exist'
              inputs['am_model_config'] = am_model_config
          else:
              inputs['am_model_config'] = ''

          if 'asr_model_config' in model_config:
              asr_model_config = self._workspace_path(inputs,
                                                      'asr_model_config')
              assert os.path.exists(
                  asr_model_config), 'asr_model_config does not exist'
          else:
              asr_model_config = ''
          inputs['asr_model_config'] = asr_model_config

          # The wav/pcm specific config falls back to the generic ASR config.
          if 'asr_model_wav_config' in model_config:
              asr_model_wav_config = self._workspace_path(
                  inputs, 'asr_model_wav_config')
              assert os.path.exists(asr_model_wav_config
                                    ), 'asr_model_wav_config does not exist'
          else:
              asr_model_wav_config = asr_model_config

          # the lm model file path
          lm_model_path = None
          if 'lm_model_name' in model_config:
              lm_model_path = self._workspace_path(inputs, 'lm_model_name')
          # the lm config file path
          lm_model_config = None
          if 'lm_model_config' in model_config:
              lm_model_config = self._workspace_path(inputs, 'lm_model_config')

          # The language model is optional: it is used only when both files
          # are configured and present, otherwise both entries are None.
          if lm_model_path and lm_model_config and os.path.exists(
                  lm_model_path) and os.path.exists(lm_model_config):
              inputs['lm_model_path'] = lm_model_path
              inputs['lm_model_config'] = lm_model_config
          else:
              inputs['lm_model_path'] = None
              inputs['lm_model_config'] = None

          # wav/pcm input switches to the wav config; any other format keeps
          # the generic config already stored above.
          if 'audio_format' in inputs and inputs['audio_format'] in ('wav',
                                                                     'pcm'):
              inputs['asr_model_config'] = asr_model_wav_config

          if 'mvn_file' in model_config:
              mvn_file = self._workspace_path(inputs, 'mvn_file')
              assert os.path.exists(mvn_file), 'mvn_file does not exist'
              inputs['mvn_file'] = mvn_file

      def _check_tf_config(self, inputs: Dict[str, Any]) -> None:
          """Resolve and validate the files used by a tf model."""
          model_config = inputs['model_config']

          assert 'vocab_file' in model_config, 'vocab_file does not exist'
          vocab_file = self._workspace_path(inputs, 'vocab_file')
          assert os.path.exists(vocab_file), 'vocab file does not exist'
          inputs['vocab_file'] = vocab_file

          assert 'am_mvn_file' in model_config, 'am_mvn_file does not exist'
          am_mvn_file = self._workspace_path(inputs, 'am_mvn_file')
          assert os.path.exists(am_mvn_file), 'am mvn file does not exist'
          inputs['am_mvn_file'] = am_mvn_file

      def env_setting(self, inputs: Dict[str, Any]) -> Dict[str, Any]:
          """Set dataset paths and the model language in ``inputs``.

          When ``recog_type`` is not ``'wav'`` the current ``wav_path`` is
          taken as the dataset root and ``wav_path`` / ``reference_text``
          (plus ``idx_text`` for tfrecord) are derived from it.

          Raises:
              KeyError: if a dataset run is requested but ``wav_path`` was
                  never set (it is only set for ``str`` audio input).
              AssertionError: if an expected dataset file does not exist.
          """
          # run with datasets, should set datasets_path and text_path
          if inputs['recog_type'] != 'wav':
              datasets_path = inputs['wav_path']
              inputs['datasets_path'] = datasets_path
              audio_format = inputs['audio_format']

              if audio_format == 'wav':
                  # run with datasets, and audio format is waveform
                  inputs['wav_path'] = os.path.join(datasets_path, 'wav',
                                                    inputs['recog_type'])
                  inputs['reference_text'] = os.path.join(
                      datasets_path, 'transcript', 'data.text')
                  assert os.path.exists(
                      inputs['reference_text']), 'reference text does not exist'
              elif audio_format == 'kaldi_ark':
                  # run with datasets, and audio format is kaldi_ark
                  inputs['wav_path'] = os.path.join(datasets_path,
                                                    inputs['recog_type'])
                  inputs['reference_text'] = os.path.join(
                      inputs['wav_path'], 'data.text')
                  assert os.path.exists(
                      inputs['reference_text']), 'reference text does not exist'
              elif audio_format == 'tfrecord':
                  # run with datasets, and audio format is tfrecord
                  inputs['wav_path'] = os.path.join(datasets_path,
                                                    inputs['recog_type'])
                  inputs['reference_text'] = os.path.join(
                      inputs['wav_path'], 'data.txt')
                  assert os.path.exists(
                      inputs['reference_text']), 'reference text does not exist'
                  inputs['idx_text'] = os.path.join(inputs['wav_path'],
                                                    'data.idx')
                  assert os.path.exists(
                      inputs['idx_text']), 'idx text does not exist'

          # set asr model language (also needed when config_checking is
          # skipped by forward)
          inputs['model_lang'] = self._config_value(inputs['model_config'],
                                                    'lang', 'zh-cn')
          return inputs