# Copyright (c) Alibaba, Inc. and its affiliates.

import os
from typing import Any, Dict, List, Union

import yaml

from modelscope.metainfo import Preprocessors
from modelscope.models.base import Model
from modelscope.utils.constant import Fields
from .base import Preprocessor
from .builder import PREPROCESSORS

__all__ = ['WavToLists']


@PREPROCESSORS.register_module(
    Fields.audio, module_name=Preprocessors.wav_to_lists)
class WavToLists(Preprocessor):
    """Generate audio list entries from wav input.

    The preprocessor takes the dict returned by ``model.forward()``, adds the
    values read from the model's config yaml, and adds the wav lists, counts
    and thread numbers derived from ``audio_in``.
    """

    # Upper bound applied to the ``*_num_thread`` values for test sets.
    _MAX_NUM_THREAD = 128

    def __init__(self):
        pass

    def __call__(self, model: Model, audio_in: Union[List[str], str,
                                                     bytes]) -> Dict[str, Any]:
        """Call functions to load model and wav.

        Args:
            model (Model): model should be provided
            audio_in (Union[List[str], str, bytes]):
                audio_in[0] is positive wav path, audio_in[1] is negative wav path;
                audio_in (str) is positive wav path;
                audio_in (bytes) is audio pcm data;

        Returns:
            Dict[str, Any]: the dict produced by :meth:`forward`, i.e. the
                model dict extended with config values and wav lists
        """
        self.model = model
        out = self.forward(self.model.forward(), audio_in)
        return out

    def forward(self, model: Dict[str, Any],
                audio_in: Union[List[str], str, bytes]) -> Dict[str, Any]:
        """Build the preprocessing result for ``audio_in``.

        Args:
            model (Dict[str, Any]): dict containing at least ``config_path``
                (and ``model_workspace``, read later by :meth:`read_config`).
                It is shallow-copied, so the caller's dict is not modified.
            audio_in (Union[List[str], str, bytes]): audio input, classified
                by ``kws_util.common.type_checking``.

        Returns:
            Dict[str, Any]: copy of ``model`` extended with ``kws_type``, the
                input path/data keys, config values and wav lists.

        Raises:
            AssertionError: if ``config_path`` is empty or missing on disk,
                or if the detected kws type is not supported.
        """
        assert len(
            model['config_path']) > 0, 'preprocess model[config_path] is empty'
        assert os.path.exists(
            model['config_path']), 'model config.yaml is absent'

        inputs = model.copy()

        # Imported lazily, as in the rest of this class, so that importing
        # this module does not require kws_util.
        import kws_util.common
        kws_type = kws_util.common.type_checking(audio_in)
        assert kws_type in [
            'wav', 'pcm', 'pos_testsets', 'neg_testsets', 'roc'
        ], f'kws_type {kws_type} is invalid, please check audio data'

        inputs['kws_type'] = kws_type
        if kws_type == 'wav':
            inputs['pos_wav_path'] = audio_in
        elif kws_type == 'pcm':
            inputs['pos_data'] = audio_in
        # 'roc' needs both the positive and the negative path, so these two
        # checks are independent rather than chained with elif.
        if kws_type in ['pos_testsets', 'roc']:
            inputs['pos_wav_path'] = audio_in[0]
        if kws_type in ['neg_testsets', 'roc']:
            inputs['neg_wav_path'] = audio_in[1]

        out = self.read_config(inputs)
        out = self.generate_wav_lists(out)

        return out

    def read_config(self, inputs: Dict[str, Any]) -> Dict[str, Any]:
        """Read and parse config.yaml to get all model files.

        Args:
            inputs (Dict[str, Any]): dict with ``config_path`` and
                ``model_workspace``; it is updated in place.

        Returns:
            Dict[str, Any]: the same ``inputs`` dict with ``cfg_file``,
                ``cfg_file_path``, ``keyword_grammar``,
                ``keyword_grammar_path`` and ``sample_rate`` added.

        Raises:
            AssertionError: if ``config_path`` does not exist.
            KeyError: if the yaml lacks one of the expected keys.
        """
        assert os.path.exists(
            inputs['config_path']), 'model config yaml file does not exist'

        # The context manager closes the file even when parsing raises.
        # NOTE(review): yaml.full_load should not be used on untrusted
        # files; consider yaml.safe_load if the config can come from
        # outside the model package.
        with open(inputs['config_path'], encoding='utf-8') as config_file:
            root = yaml.full_load(config_file)

        workspace = inputs['model_workspace']

        inputs['cfg_file'] = root['cfg_file']
        inputs['cfg_file_path'] = os.path.join(workspace, root['cfg_file'])

        inputs['keyword_grammar'] = root['keyword_grammar']
        inputs['keyword_grammar_path'] = os.path.join(
            workspace, root['keyword_grammar'])

        inputs['sample_rate'] = root['sample_rate']

        return inputs

    def generate_wav_lists(self, inputs: Dict[str, Any]) -> Dict[str, Any]:
        """Assemble wav lists.

        Depending on ``inputs['kws_type']`` this fills ``pos_*`` and/or
        ``neg_*`` entries (``*_wav_list``, ``*_wav_count``, ``*_num_thread``).

        Args:
            inputs (Dict[str, Any]): dict with ``kws_type`` and the matching
                ``pos_wav_path`` / ``neg_wav_path`` keys; updated in place.

        Returns:
            Dict[str, Any]: the same ``inputs`` dict.
        """
        kws_type = inputs['kws_type']

        if kws_type == 'wav':
            inputs['pos_wav_list'] = [inputs['pos_wav_path']]
            inputs['pos_wav_count'] = 1
            inputs['pos_num_thread'] = 1

        if kws_type == 'pcm':
            # In-memory pcm data has no path; a placeholder name is listed.
            inputs['pos_wav_list'] = ['pcm_data']
            inputs['pos_wav_count'] = 1
            inputs['pos_num_thread'] = 1

        if kws_type in ['pos_testsets', 'roc']:
            # find all positive wave
            self._collect_testset(inputs, 'pos')

        if kws_type in ['neg_testsets', 'roc']:
            # find all negative wave
            self._collect_testset(inputs, 'neg')

        return inputs

    def _collect_testset(self, inputs: Dict[str, Any], prefix: str) -> None:
        """Fill ``{prefix}_wav_list/_wav_count/_num_thread`` from a directory.

        The directory is taken from ``inputs[f'{prefix}_wav_path']`` and
        searched with ``kws_util.common.recursion_dir_all_wav``.
        """
        import kws_util.common

        wav_dir = inputs[f'{prefix}_wav_path']
        wav_list = kws_util.common.recursion_dir_all_wav([], wav_dir)
        list_count: int = len(wav_list)

        inputs[f'{prefix}_wav_list'] = wav_list
        inputs[f'{prefix}_wav_count'] = list_count
        # One thread per wav, capped at the class-level maximum.
        inputs[f'{prefix}_num_thread'] = min(list_count,
                                             self._MAX_NUM_THREAD)