model.py 3.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108
  1. # Copyright (c) Alibaba, Inc. and its affiliates.
  2. import os
  3. import sys
  4. import tempfile
  5. from typing import Dict, Optional
  6. from modelscope.metainfo import Models
  7. from modelscope.models import TorchModel
  8. from modelscope.models.base import Tensor
  9. from modelscope.models.builder import MODELS
  10. from modelscope.utils.audio.audio_utils import update_conf
  11. from modelscope.utils.constant import Tasks
  12. from .fsmn_sele_v2 import FSMNSeleNetV2
  13. from .fsmn_sele_v3 import FSMNSeleNetV3
  14. @MODELS.register_module(
  15. Tasks.keyword_spotting, module_name=Models.speech_dfsmn_kws_char_farfield)
  16. class FSMNSeleNetV2Decorator(TorchModel):
  17. r""" A decorator of FSMNSeleNetV2 for integrating into modelscope framework """
  18. MODEL_CLASS = FSMNSeleNetV2
  19. MODEL_TXT = 'model.txt'
  20. SC_CONFIG = 'sound_connect.conf'
  21. def __init__(self,
  22. model_dir: str,
  23. training: Optional[bool] = False,
  24. *args,
  25. **kwargs):
  26. """initialize the dfsmn model from the `model_dir` path.
  27. Args:
  28. model_dir (str): the model path.
  29. """
  30. super().__init__(model_dir, *args, **kwargs)
  31. if training:
  32. self.model = self.MODEL_CLASS(*args, **kwargs)
  33. else:
  34. sc_config_file = os.path.join(model_dir, self.SC_CONFIG)
  35. model_txt_file = os.path.join(model_dir, self.MODEL_TXT)
  36. self.tmp_dir = tempfile.TemporaryDirectory()
  37. new_config_file = os.path.join(self.tmp_dir.name, self.SC_CONFIG)
  38. self._sc = None
  39. if os.path.exists(model_txt_file):
  40. conf_dict = dict(kws_model=model_txt_file)
  41. update_conf(sc_config_file, new_config_file, conf_dict)
  42. try:
  43. if sys.version_info >= (3, 11):
  44. raise ImportError('Python version needs to be <= 3.10')
  45. import py_sound_connect
  46. except ImportError:
  47. raise ImportError(
  48. 'py_sound_connect needs python<=3.10, you can install it by:'
  49. 'pip install py_sound_connect -f '
  50. 'https://modelscope.oss-cn-beijing.aliyuncs.com/releases/repo.html'
  51. )
  52. self._sc = py_sound_connect.SoundConnect(new_config_file)
  53. self.size_in = self._sc.bytesPerBlockIn()
  54. self.size_out = self._sc.bytesPerBlockOut()
  55. else:
  56. raise Exception(
  57. f'Invalid model directory! Failed to load model file:'
  58. f' {model_txt_file}.')
  59. def __del__(self):
  60. if hasattr(self, 'tmp_dir'):
  61. self.tmp_dir.cleanup()
  62. def forward(self, input: Dict[str, Tensor]) -> Dict[str, Tensor]:
  63. return self.model.forward(input)
  64. def forward_decode(self, data: bytes):
  65. result = {'pcm': self._sc.process(data, self.size_out)}
  66. state = self._sc.kwsState()
  67. if state == 2:
  68. result['kws'] = {
  69. 'keyword':
  70. self._sc.kwsKeyword(self._sc.kwsSpottedKeywordIndex()),
  71. 'offset': self._sc.kwsKeywordOffset(),
  72. 'channel': self._sc.kwsBestChannel(),
  73. 'length': self._sc.kwsKeywordLength(),
  74. 'confidence': self._sc.kwsConfidence()
  75. }
  76. return result
  77. @MODELS.register_module(
  78. Tasks.keyword_spotting,
  79. module_name=Models.speech_dfsmn_kws_char_farfield_iot)
  80. class FSMNSeleNetV3Decorator(FSMNSeleNetV2Decorator):
  81. r""" A decorator of FSMNSeleNetV3 for integrating into modelscope framework """
  82. MODEL_CLASS = FSMNSeleNetV3
  83. def __init__(self,
  84. model_dir: str,
  85. training: Optional[bool] = False,
  86. *args,
  87. **kwargs):
  88. """initialize the dfsmn model from the `model_dir` path.
  89. Args:
  90. model_dir (str): the model path.
  91. """
  92. super().__init__(model_dir, training, *args, **kwargs)