kws.py 4.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143
  1. # Copyright (c) Alibaba, Inc. and its affiliates.
  2. import os
  3. from typing import Any, Dict, List, Union
  4. import yaml
  5. from modelscope.metainfo import Preprocessors
  6. from modelscope.models.base import Model
  7. from modelscope.utils.constant import Fields
  8. from .base import Preprocessor
  9. from .builder import PREPROCESSORS
  10. __all__ = ['WavToLists']
  11. @PREPROCESSORS.register_module(
  12. Fields.audio, module_name=Preprocessors.wav_to_lists)
  13. class WavToLists(Preprocessor):
  14. """generate audio lists file from wav
  15. """
  16. def __init__(self):
  17. pass
  18. def __call__(self, model: Model, audio_in: Union[List[str], str,
  19. bytes]) -> Dict[str, Any]:
  20. """Call functions to load model and wav.
  21. Args:
  22. model (Model): model should be provided
  23. audio_in (Union[List[str], str, bytes]):
  24. audio_in[0] is positive wav path, audio_in[1] is negative wav path;
  25. audio_in (str) is positive wav path;
  26. audio_in (bytes) is audio pcm data;
  27. Returns:
  28. Dict[str, Any]: the kws result
  29. """
  30. self.model = model
  31. out = self.forward(self.model.forward(), audio_in)
  32. return out
  33. def forward(self, model: Dict[str, Any],
  34. audio_in: Union[List[str], str, bytes]) -> Dict[str, Any]:
  35. assert len(
  36. model['config_path']) > 0, 'preprocess model[config_path] is empty'
  37. assert os.path.exists(
  38. model['config_path']), 'model config.yaml is absent'
  39. inputs = model.copy()
  40. import kws_util.common
  41. kws_type = kws_util.common.type_checking(audio_in)
  42. assert kws_type in [
  43. 'wav', 'pcm', 'pos_testsets', 'neg_testsets', 'roc'
  44. ], f'kws_type {kws_type} is invalid, please check audio data'
  45. inputs['kws_type'] = kws_type
  46. if kws_type == 'wav':
  47. inputs['pos_wav_path'] = audio_in
  48. elif kws_type == 'pcm':
  49. inputs['pos_data'] = audio_in
  50. if kws_type in ['pos_testsets', 'roc']:
  51. inputs['pos_wav_path'] = audio_in[0]
  52. if kws_type in ['neg_testsets', 'roc']:
  53. inputs['neg_wav_path'] = audio_in[1]
  54. out = self.read_config(inputs)
  55. out = self.generate_wav_lists(out)
  56. return out
  57. def read_config(self, inputs: Dict[str, Any]) -> Dict[str, Any]:
  58. """read and parse config.yaml to get all model files
  59. """
  60. assert os.path.exists(
  61. inputs['config_path']), 'model config yaml file does not exist'
  62. config_file = open(inputs['config_path'], encoding='utf-8')
  63. root = yaml.full_load(config_file)
  64. config_file.close()
  65. inputs['cfg_file'] = root['cfg_file']
  66. inputs['cfg_file_path'] = os.path.join(inputs['model_workspace'],
  67. root['cfg_file'])
  68. inputs['keyword_grammar'] = root['keyword_grammar']
  69. inputs['keyword_grammar_path'] = os.path.join(
  70. inputs['model_workspace'], root['keyword_grammar'])
  71. inputs['sample_rate'] = root['sample_rate']
  72. return inputs
  73. def generate_wav_lists(self, inputs: Dict[str, Any]) -> Dict[str, Any]:
  74. """assemble wav lists
  75. """
  76. import kws_util.common
  77. if inputs['kws_type'] == 'wav':
  78. wav_list = []
  79. wave_scp_content: str = inputs['pos_wav_path']
  80. wav_list.append(wave_scp_content)
  81. inputs['pos_wav_list'] = wav_list
  82. inputs['pos_wav_count'] = 1
  83. inputs['pos_num_thread'] = 1
  84. if inputs['kws_type'] == 'pcm':
  85. inputs['pos_wav_list'] = ['pcm_data']
  86. inputs['pos_wav_count'] = 1
  87. inputs['pos_num_thread'] = 1
  88. if inputs['kws_type'] in ['pos_testsets', 'roc']:
  89. # find all positive wave
  90. wav_list = []
  91. wav_dir = inputs['pos_wav_path']
  92. wav_list = kws_util.common.recursion_dir_all_wav(wav_list, wav_dir)
  93. inputs['pos_wav_list'] = wav_list
  94. list_count: int = len(wav_list)
  95. inputs['pos_wav_count'] = list_count
  96. if list_count <= 128:
  97. inputs['pos_num_thread'] = list_count
  98. else:
  99. inputs['pos_num_thread'] = 128
  100. if inputs['kws_type'] in ['neg_testsets', 'roc']:
  101. # find all negative wave
  102. wav_list = []
  103. wav_dir = inputs['neg_wav_path']
  104. wav_list = kws_util.common.recursion_dir_all_wav(wav_list, wav_dir)
  105. inputs['neg_wav_list'] = wav_list
  106. list_count: int = len(wav_list)
  107. inputs['neg_wav_count'] = list_count
  108. if list_count <= 128:
  109. inputs['neg_num_thread'] = list_count
  110. else:
  111. inputs['neg_num_thread'] = 128
  112. return inputs