model.py 2.5 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162
  1. # Copyright (c) Alibaba, Inc. and its affiliates.
  2. import os
  3. from typing import Any, Dict
  4. import json
  5. from funasr import AutoModel
  6. from modelscope.metainfo import Models
  7. from modelscope.models.base import Model
  8. from modelscope.models.builder import MODELS
  9. from modelscope.utils.constant import Frameworks, Tasks
  10. __all__ = ['GenericFunASR']
  11. @MODELS.register_module(
  12. Tasks.auto_speech_recognition, module_name=Models.funasr)
  13. @MODELS.register_module(
  14. Tasks.voice_activity_detection, module_name=Models.funasr)
  15. @MODELS.register_module(
  16. Tasks.language_score_prediction, module_name=Models.funasr)
  17. @MODELS.register_module(Tasks.punctuation, module_name=Models.funasr)
  18. @MODELS.register_module(Tasks.speaker_diarization, module_name=Models.funasr)
  19. @MODELS.register_module(Tasks.speaker_verification, module_name=Models.funasr)
  20. @MODELS.register_module(Tasks.speech_separation, module_name=Models.funasr)
  21. @MODELS.register_module(Tasks.speech_timestamp, module_name=Models.funasr)
  22. @MODELS.register_module(Tasks.emotion_recognition, module_name=Models.funasr)
  23. class GenericFunASR(Model):
  24. def __init__(self, model_dir, *args, **kwargs):
  25. """initialize the info of model.
  26. Args:
  27. model_dir (str): the model path.
  28. am_model_name (str): the am model name from configuration.json
  29. model_config (Dict[str, Any]): the detail config about model from configuration.json
  30. """
  31. super().__init__(model_dir, *args, **kwargs)
  32. model_cfg = json.loads(
  33. open(os.path.join(model_dir, 'configuration.json')).read())
  34. if 'vad_model' not in kwargs and 'vad_model' in model_cfg:
  35. kwargs['vad_model'] = model_cfg['vad_model']
  36. kwargs['vad_model_revision'] = model_cfg.get(
  37. 'vad_model_revision', None)
  38. if 'punc_model' not in kwargs and 'punc_model' in model_cfg:
  39. kwargs['punc_model'] = model_cfg['punc_model']
  40. kwargs['punc_model_revision'] = model_cfg.get(
  41. 'punc_model_revision', None)
  42. if 'spk_model' not in kwargs and 'spk_model' in model_cfg:
  43. kwargs['spk_model'] = model_cfg['spk_model']
  44. kwargs['spk_model_revision'] = model_cfg.get(
  45. 'spk_model_revision', None)
  46. self.model = AutoModel(model=model_dir, **kwargs)
  47. def forward(self, *args, **kwargs):
  48. """preload model and return the info of the model
  49. """
  50. output = self.model.generate(*args, **kwargs)
  51. return output