| 1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162 |
- # Copyright (c) Alibaba, Inc. and its affiliates.
- import os
- from typing import Any, Dict
- import json
- from funasr import AutoModel
- from modelscope.metainfo import Models
- from modelscope.models.base import Model
- from modelscope.models.builder import MODELS
- from modelscope.utils.constant import Frameworks, Tasks
- __all__ = ['GenericFunASR']
@MODELS.register_module(
    Tasks.auto_speech_recognition, module_name=Models.funasr)
@MODELS.register_module(
    Tasks.voice_activity_detection, module_name=Models.funasr)
@MODELS.register_module(
    Tasks.language_score_prediction, module_name=Models.funasr)
@MODELS.register_module(Tasks.punctuation, module_name=Models.funasr)
@MODELS.register_module(Tasks.speaker_diarization, module_name=Models.funasr)
@MODELS.register_module(Tasks.speaker_verification, module_name=Models.funasr)
@MODELS.register_module(Tasks.speech_separation, module_name=Models.funasr)
@MODELS.register_module(Tasks.speech_timestamp, module_name=Models.funasr)
@MODELS.register_module(Tasks.emotion_recognition, module_name=Models.funasr)
class GenericFunASR(Model):
    """Model wrapper that delegates to a ``funasr.AutoModel`` instance.

    The class is registered under ``Models.funasr`` for every task listed in
    the decorator stack above, so one implementation serves all of them.
    """

    def __init__(self, model_dir, *args, **kwargs):
        """Build the underlying ``AutoModel`` from a model directory.

        ``configuration.json`` inside ``model_dir`` may name auxiliary models
        (``vad_model``, ``punc_model``, ``spk_model``). Each one is used as a
        default only when the caller did not pass that key explicitly; in that
        case the matching ``*_revision`` key is also taken from the
        configuration (``None`` when absent).

        Args:
            model_dir (str): directory that contains ``configuration.json``;
                also passed to ``AutoModel`` as ``model``.
            *args: forwarded to the base ``Model`` initializer.
            **kwargs: forwarded to the base ``Model`` initializer and, after
                the configuration defaults are merged in, to ``AutoModel``.

        Raises:
            OSError: if ``configuration.json`` cannot be opened.
            json.JSONDecodeError: if ``configuration.json`` is not valid JSON.
        """
        # The base class is initialised with the caller's kwargs as given,
        # before any defaults from configuration.json are merged in.
        super().__init__(model_dir, *args, **kwargs)

        config_path = os.path.join(model_dir, 'configuration.json')
        # Context manager closes the handle deterministically; JSON files are
        # UTF-8, so do not depend on the platform's default encoding.
        with open(config_path, encoding='utf-8') as config_file:
            model_cfg = json.load(config_file)

        for model_key in ('vad_model', 'punc_model', 'spk_model'):
            # An explicit caller choice always wins over the configuration.
            if model_key in kwargs or model_key not in model_cfg:
                continue
            revision_key = model_key + '_revision'
            kwargs[model_key] = model_cfg[model_key]
            # The revision accompanies the configured model, so it is set
            # from the configuration whenever the model itself is.
            kwargs[revision_key] = model_cfg.get(revision_key, None)

        self.model = AutoModel(model=model_dir, **kwargs)

    def forward(self, *args, **kwargs):
        """Run ``generate`` on the wrapped ``AutoModel`` and return its result.

        All positional and keyword arguments are passed through unchanged.
        """
        return self.model.generate(*args, **kwargs)
|