DTDNN.py 7.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204
  1. # Copyright (c) Alibaba, Inc. and its affiliates.
  2. import os
  3. from collections import OrderedDict
  4. from typing import Any, Dict, Union
  5. import numpy as np
  6. import torch
  7. import torch.nn as nn
  8. import torch.nn.functional as F
  9. import torchaudio.compliance.kaldi as Kaldi
  10. from modelscope.metainfo import Models
  11. from modelscope.models import MODELS, TorchModel
  12. from modelscope.models.audio.sv.DTDNN_layers import (BasicResBlock,
  13. CAMDenseTDNNBlock,
  14. DenseLayer, StatsPool,
  15. TDNNLayer, TransitLayer,
  16. get_nonlinear)
  17. from modelscope.utils.constant import Tasks
  18. from modelscope.utils.device import create_device
  19. class FCM(nn.Module):
  20. def __init__(self,
  21. block=BasicResBlock,
  22. num_blocks=[2, 2],
  23. m_channels=32,
  24. feat_dim=80):
  25. super(FCM, self).__init__()
  26. self.in_planes = m_channels
  27. self.conv1 = nn.Conv2d(
  28. 1, m_channels, kernel_size=3, stride=1, padding=1, bias=False)
  29. self.bn1 = nn.BatchNorm2d(m_channels)
  30. self.layer1 = self._make_layer(
  31. block, m_channels, num_blocks[0], stride=2)
  32. self.layer2 = self._make_layer(
  33. block, m_channels, num_blocks[0], stride=2)
  34. self.conv2 = nn.Conv2d(
  35. m_channels,
  36. m_channels,
  37. kernel_size=3,
  38. stride=(2, 1),
  39. padding=1,
  40. bias=False)
  41. self.bn2 = nn.BatchNorm2d(m_channels)
  42. self.out_channels = m_channels * (feat_dim // 8)
  43. def _make_layer(self, block, planes, num_blocks, stride):
  44. strides = [stride] + [1] * (num_blocks - 1)
  45. layers = []
  46. for stride in strides:
  47. layers.append(block(self.in_planes, planes, stride))
  48. self.in_planes = planes * block.expansion
  49. return nn.Sequential(*layers)
  50. def forward(self, x):
  51. x = x.unsqueeze(1)
  52. out = F.relu(self.bn1(self.conv1(x)))
  53. out = self.layer1(out)
  54. out = self.layer2(out)
  55. out = F.relu(self.bn2(self.conv2(out)))
  56. shape = out.shape
  57. out = out.reshape(shape[0], shape[1] * shape[2], shape[3])
  58. return out
  59. class CAMPPlus(nn.Module):
  60. def __init__(self,
  61. feat_dim=80,
  62. embedding_size=512,
  63. growth_rate=32,
  64. bn_size=4,
  65. init_channels=128,
  66. config_str='batchnorm-relu',
  67. memory_efficient=True,
  68. output_level='segment'):
  69. super(CAMPPlus, self).__init__()
  70. self.head = FCM(feat_dim=feat_dim)
  71. channels = self.head.out_channels
  72. self.output_level = output_level
  73. self.xvector = nn.Sequential(
  74. OrderedDict([
  75. ('tdnn',
  76. TDNNLayer(
  77. channels,
  78. init_channels,
  79. 5,
  80. stride=2,
  81. dilation=1,
  82. padding=-1,
  83. config_str=config_str)),
  84. ]))
  85. channels = init_channels
  86. for i, (num_layers, kernel_size, dilation) in enumerate(
  87. zip((12, 24, 16), (3, 3, 3), (1, 2, 2))):
  88. block = CAMDenseTDNNBlock(
  89. num_layers=num_layers,
  90. in_channels=channels,
  91. out_channels=growth_rate,
  92. bn_channels=bn_size * growth_rate,
  93. kernel_size=kernel_size,
  94. dilation=dilation,
  95. config_str=config_str,
  96. memory_efficient=memory_efficient)
  97. self.xvector.add_module('block%d' % (i + 1), block)
  98. channels = channels + num_layers * growth_rate
  99. self.xvector.add_module(
  100. 'transit%d' % (i + 1),
  101. TransitLayer(
  102. channels, channels // 2, bias=False,
  103. config_str=config_str))
  104. channels //= 2
  105. self.xvector.add_module('out_nonlinear',
  106. get_nonlinear(config_str, channels))
  107. if self.output_level == 'segment':
  108. self.xvector.add_module('stats', StatsPool())
  109. self.xvector.add_module(
  110. 'dense',
  111. DenseLayer(
  112. channels * 2, embedding_size, config_str='batchnorm_'))
  113. else:
  114. assert self.output_level == 'frame', '`output_level` should be set to \'segment\' or \'frame\'. '
  115. for m in self.modules():
  116. if isinstance(m, (nn.Conv1d, nn.Linear)):
  117. nn.init.kaiming_normal_(m.weight.data)
  118. if m.bias is not None:
  119. nn.init.zeros_(m.bias)
  120. def forward(self, x):
  121. x = x.permute(0, 2, 1) # (B,T,F) => (B,F,T)
  122. x = self.head(x)
  123. x = self.xvector(x)
  124. if self.output_level == 'frame':
  125. x = x.transpose(1, 2)
  126. return x
  127. @MODELS.register_module(
  128. Tasks.speaker_verification, module_name=Models.campplus_sv)
  129. class SpeakerVerificationCAMPPlus(TorchModel):
  130. r"""A fast and efficient speaker embedding model, using a 2-dimensional convolution residual network as the head
  131. and a densely connected time delay neural network as the backbone.
  132. Args:
  133. model_dir: A model dir.
  134. model_config: The model config.
  135. """
  136. def __init__(self, model_dir, model_config: Dict[str, Any], *args,
  137. **kwargs):
  138. super().__init__(model_dir, model_config, *args, **kwargs)
  139. self.model_config = model_config
  140. self.other_config = kwargs
  141. self.feature_dim = self.model_config['fbank_dim']
  142. self.emb_size = self.model_config['emb_size']
  143. self.device = create_device(self.other_config['device'])
  144. self.embedding_model = CAMPPlus(self.feature_dim, self.emb_size)
  145. pretrained_model_name = kwargs['pretrained_model']
  146. self.__load_check_point(pretrained_model_name)
  147. self.embedding_model.to(self.device)
  148. self.embedding_model.eval()
  149. def forward(self, audio):
  150. if isinstance(audio, np.ndarray):
  151. audio = torch.from_numpy(audio)
  152. if len(audio.shape) == 1:
  153. audio = audio.unsqueeze(0)
  154. assert len(
  155. audio.shape
  156. ) == 2, 'modelscope error: the shape of input audio to model needs to be [N, T]'
  157. # audio shape: [N, T]
  158. feature = self.__extract_feature(audio)
  159. embedding = self.embedding_model(feature.to(self.device))
  160. return embedding.detach().cpu()
  161. def __extract_feature(self, audio):
  162. features = []
  163. for au in audio:
  164. feature = Kaldi.fbank(
  165. au.unsqueeze(0), num_mel_bins=self.feature_dim)
  166. feature = feature - feature.mean(dim=0, keepdim=True)
  167. features.append(feature.unsqueeze(0))
  168. features = torch.cat(features)
  169. return features
  170. def __load_check_point(self, pretrained_model_name):
  171. self.embedding_model.load_state_dict(
  172. torch.load(
  173. os.path.join(self.model_dir, pretrained_model_name),
  174. map_location=torch.device('cpu')),
  175. strict=True)