model.py 5.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178
  1. # Copyright (c) Alibaba, Inc. and its affiliates.
  2. import os
  3. import sys
  4. import tempfile
  5. from typing import Dict, Optional, Tuple
  6. import torch
  7. import torch.nn as nn
  8. from modelscope.metainfo import Models
  9. from modelscope.models import TorchModel
  10. from modelscope.models.base import Tensor
  11. from modelscope.models.builder import MODELS
  12. from modelscope.utils.audio.audio_utils import update_conf
  13. from modelscope.utils.constant import Tasks
  14. from .cmvn import GlobalCMVN, load_kaldi_cmvn
  15. from .fsmn import FSMN
  16. @MODELS.register_module(
  17. Tasks.keyword_spotting,
  18. module_name=Models.speech_kws_fsmn_char_ctc_nearfield)
  19. class FSMNDecorator(TorchModel):
  20. r""" A decorator of FSMN for integrating into modelscope framework """
  21. def __init__(self,
  22. model_dir: str,
  23. cmvn_file: str = None,
  24. backbone: dict = None,
  25. input_dim: int = 400,
  26. output_dim: int = 2599,
  27. training: Optional[bool] = False,
  28. *args,
  29. **kwargs):
  30. """initialize the fsmn model from the `model_dir` path.
  31. Args:
  32. model_dir (str): the model path.
  33. cmvn_file (str): cmvn file
  34. backbone (dict): params related to backbone
  35. input_dim (int): input dimension of network
  36. output_dim (int): output dimension of network
  37. training (bool): training or inference mode
  38. """
  39. super().__init__(model_dir, *args, **kwargs)
  40. self.model = None
  41. self.model_cfg = None
  42. if training:
  43. self.model = self.init_model(cmvn_file, backbone, input_dim,
  44. output_dim)
  45. else:
  46. self.model_cfg = {
  47. 'model_workspace': model_dir,
  48. 'config_path': os.path.join(model_dir, 'config.yaml')
  49. }
  50. def __del__(self):
  51. if hasattr(self, 'tmp_dir'):
  52. self.tmp_dir.cleanup()
  53. def forward(self, input) -> Dict[str, Tensor]:
  54. """
  55. Args:
  56. input (torch.Tensor): Input tensor (B, T, D)
  57. """
  58. if self.model is not None and input is not None:
  59. return self.model.forward(input)
  60. else:
  61. return self.model_cfg
  62. def init_model(self, cmvn_file, backbone, input_dim, output_dim):
  63. if cmvn_file is not None:
  64. mean, istd = load_kaldi_cmvn(cmvn_file)
  65. global_cmvn = GlobalCMVN(
  66. torch.from_numpy(mean).float(),
  67. torch.from_numpy(istd).float(),
  68. )
  69. else:
  70. global_cmvn = None
  71. hidden_dim = 128
  72. preprocessing = None
  73. input_affine_dim = backbone['input_affine_dim']
  74. num_layers = backbone['num_layers']
  75. linear_dim = backbone['linear_dim']
  76. proj_dim = backbone['proj_dim']
  77. left_order = backbone['left_order']
  78. right_order = backbone['right_order']
  79. left_stride = backbone['left_stride']
  80. right_stride = backbone['right_stride']
  81. output_affine_dim = backbone['output_affine_dim']
  82. backbone = FSMN(input_dim, input_affine_dim, num_layers, linear_dim,
  83. proj_dim, left_order, right_order, left_stride,
  84. right_stride, output_affine_dim, output_dim)
  85. classifier = None
  86. activation = None
  87. kws_model = KWSModel(input_dim, output_dim, hidden_dim, global_cmvn,
  88. preprocessing, backbone, classifier, activation)
  89. return kws_model
  90. class KWSModel(nn.Module):
  91. """Our model consists of four parts:
  92. 1. global_cmvn: Optional, (idim, idim)
  93. 2. preprocessing: feature dimension projection, (idim, hdim)
  94. 3. backbone: backbone or feature extractor of the whole network, (hdim, hdim)
  95. 4. classifier: output layer or classifier of KWS model, (hdim, odim)
  96. 5. activation:
  97. nn.Sigmoid for wakeup word
  98. nn.Identity for speech command dataset
  99. """
  100. def __init__(
  101. self,
  102. idim: int,
  103. odim: int,
  104. hdim: int,
  105. global_cmvn: Optional[nn.Module],
  106. preprocessing: Optional[nn.Module],
  107. backbone: nn.Module,
  108. classifier: nn.Module,
  109. activation: nn.Module,
  110. ):
  111. """
  112. Args:
  113. idim (int): input dimension of network
  114. odim (int): output dimension of network
  115. hdim (int): hidden dimension of network
  116. global_cmvn (nn.Module): cmvn for input feature, (idim, idim)
  117. preprocessing (nn.Module): feature dimension projection, (idim, hdim)
  118. backbone (nn.Module): backbone or feature extractor of the whole network, (hdim, hdim)
  119. classifier (nn.Module): output layer or classifier of KWS model, (hdim, odim)
  120. activation (nn.Module): nn.Identity for training, nn.Sigmoid for inference
  121. """
  122. super().__init__()
  123. self.idim = idim
  124. self.odim = odim
  125. self.hdim = hdim
  126. self.global_cmvn = global_cmvn
  127. self.preprocessing = preprocessing
  128. self.backbone = backbone
  129. self.classifier = classifier
  130. self.activation = activation
  131. def to_kaldi_net(self):
  132. return self.backbone.to_kaldi_net()
  133. def to_pytorch_net(self, kaldi_file):
  134. return self.backbone.to_pytorch_net(kaldi_file)
  135. def forward(
  136. self,
  137. x: torch.Tensor,
  138. in_cache: torch.Tensor = torch.zeros(0, 0, 0, dtype=torch.float)
  139. ) -> Tuple[torch.Tensor, torch.Tensor]:
  140. if self.global_cmvn is not None:
  141. x = self.global_cmvn(x)
  142. if self.preprocessing is not None:
  143. x = self.preprocessing(x)
  144. x, out_cache = self.backbone(x, in_cache)
  145. if self.classifier is not None:
  146. x = self.classifier(x)
  147. if self.activation is not None:
  148. x = self.activation(x)
  149. return x, out_cache
  150. def fuse_modules(self):
  151. if self.preprocessing is not None:
  152. self.preprocessing.fuse_modules()
  153. self.backbone.fuse_modules()