xvector.py 5.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153
  1. # Copyright (c) Alibaba, Inc. and its affiliates.
  2. """
  3. This TDNN implementation is adapted from https://github.com/wenet-e2e/wespeaker.
  4. TDNN replaces i-vectors for text-independent speaker verification with embeddings
  5. extracted from a feed-forward deep neural network. The architecture is described
  6. in https://www.danielpovey.com/files/2017_interspeech_embeddings.pdf.
  7. """
  8. import math
  9. import os
  10. from typing import Any, Dict, Union
  11. import numpy as np
  12. import torch
  13. import torch.nn as nn
  14. import torch.nn.functional as F
  15. import torchaudio.compliance.kaldi as Kaldi
  16. import modelscope.models.audio.sv.pooling_layers as pooling_layers
  17. from modelscope.metainfo import Models
  18. from modelscope.models import MODELS, TorchModel
  19. from modelscope.utils.constant import Tasks
  20. from modelscope.utils.device import create_device
  class TdnnLayer(nn.Module):
      """One TDNN layer: a (dilated) 1-D convolution, then ReLU, then BatchNorm.

      Args:
          in_dim (int): number of input channels.
          out_dim (int): number of output channels.
          context_size (int): temporal context, i.e. the convolution kernel size.
          dilation (int, optional): dilation of the convolution. Defaults to 1.
          padding (int, optional): zero-padding applied to both ends of the time
              axis. Defaults to 0.
      """

      def __init__(self, in_dim, out_dim, context_size, dilation=1, padding=0):
          super().__init__()
          # Keep the layer configuration available for introspection.
          self.in_dim, self.out_dim = in_dim, out_dim
          self.context_size = context_size
          self.dilation, self.padding = dilation, padding

          # `conv_1d` and `bn` are the state_dict prefixes used by checkpoints;
          # do not rename them.
          self.conv_1d = nn.Conv1d(
              in_channels=in_dim,
              out_channels=out_dim,
              kernel_size=context_size,
              dilation=dilation,
              padding=padding)
          # affine=False keeps compatibility with the original Kaldi version.
          self.bn = nn.BatchNorm1d(out_dim, affine=False)

      def forward(self, x):
          """Run conv -> ReLU -> BatchNorm on a (batch, channels, time) tensor."""
          activated = F.relu(self.conv_1d(x))
          return self.bn(activated)
  class XVEC(nn.Module):
      """Kaldi-style x-vector network.

      Follows "X-Vectors: Robust DNN Embeddings for Speaker Recognition": five
      frame-level TDNN layers, a pooling layer over time, and a segment-level
      linear projection producing the embedding.

      Args:
          feat_dim (int): dimension of the input acoustic features.
          hid_dim (int): channels of the first four frame-level layers.
          stats_dim (int): channels of the fifth frame-level layer, which is
              also the ``in_dim`` given to the pooling layer.
          embed_dim (int): output size of the final linear layer.
          pooling_func (str): name of the class looked up in ``pooling_layers``.
              'TAP' and 'TSDP' are treated as yielding one statistic per
              channel; every other name is treated as yielding two.
      """

      def __init__(self,
                   feat_dim=40,
                   hid_dim=512,
                   stats_dim=1500,
                   embed_dim=512,
                   pooling_func='TSTP'):
          super().__init__()
          self.feat_dim = feat_dim
          self.stats_dim = stats_dim
          self.embed_dim = embed_dim

          # Frame-level layers. Attribute names are checkpoint state_dict keys.
          self.frame_1 = TdnnLayer(feat_dim, hid_dim, context_size=5, dilation=1)
          self.frame_2 = TdnnLayer(hid_dim, hid_dim, context_size=3, dilation=2)
          self.frame_3 = TdnnLayer(hid_dim, hid_dim, context_size=3, dilation=3)
          self.frame_4 = TdnnLayer(hid_dim, hid_dim, context_size=1, dilation=1)
          self.frame_5 = TdnnLayer(
              hid_dim, stats_dim, context_size=1, dilation=1)

          # How many statistics per channel the pooling output carries; this
          # sizes the input of the segment-level linear layer.
          single_stat_pools = ('TAP', 'TSDP')
          self.n_stats = 1 if pooling_func in single_stat_pools else 2
          pool_cls = getattr(pooling_layers, pooling_func)
          self.pool = pool_cls(in_dim=self.stats_dim)
          self.seg_1 = nn.Linear(self.stats_dim * self.n_stats, embed_dim)

      def forward(self, x):
          """Embed a (B, T, F) feature batch; returns the output of ``seg_1``."""
          hidden = x.permute(0, 2, 1)  # (B,T,F) -> (B,F,T)
          frame_layers = (self.frame_1, self.frame_2, self.frame_3,
                          self.frame_4, self.frame_5)
          for frame_layer in frame_layers:
              hidden = frame_layer(hidden)
          pooled = self.pool(hidden)
          return self.seg_1(pooled)
  @MODELS.register_module(Tasks.speaker_verification, module_name=Models.tdnn_sv)
  class SpeakerVerificationTDNN(TorchModel):
      """Speaker-verification model mapping raw waveforms to x-vector embeddings.

      Waveforms are turned into mean-normalized Kaldi fbank features (80 mel
      bins) and fed to an ``XVEC`` network. The network's weights are loaded
      strictly from ``os.path.join(self.model_dir, kwargs['pretrained_model'])``.

      Args:
          model_dir (str): directory holding the pretrained checkpoint.
          model_config (Dict[str, Any]): model configuration, stored as-is.
          **kwargs: must contain ``device`` (passed to ``create_device``) and
              ``pretrained_model`` (checkpoint file name inside ``model_dir``).
      """

      def __init__(self, model_dir, model_config: Dict[str, Any], *args,
                   **kwargs):
          super().__init__(model_dir, model_config, *args, **kwargs)
          self.model_config = model_config
          self.other_config = kwargs
          self.feature_dim = 80
          self.embed_dim = 512
          self.device = create_device(self.other_config['device'])
          # (A leftover debug `print(self.device)` was removed here.)

          self.embedding_model = XVEC(
              feat_dim=self.feature_dim, embed_dim=self.embed_dim)
          pretrained_model_name = kwargs['pretrained_model']
          self.__load_check_point(pretrained_model_name)
          self.embedding_model.to(self.device)
          self.embedding_model.eval()

      def forward(self, audio: Union[np.ndarray, torch.Tensor]) -> torch.Tensor:
          """Compute speaker embeddings for one waveform or a batch of them.

          Args:
              audio: waveform of shape [T] or batch of shape [N, T], given as a
                  numpy array or a torch tensor.

          Returns:
              torch.Tensor: the embedding network's output, detached and moved
              to CPU (presumably shape [N, embed_dim] — depends on the pooling
              layer's output shape).
          """
          if isinstance(audio, np.ndarray):
              audio = torch.from_numpy(audio)
          if len(audio.shape) == 1:
              audio = audio.unsqueeze(0)
          assert len(
              audio.shape
          ) == 2, 'modelscope error: the shape of input audio to model needs to be [N, T]'
          # The network weights are float32. Without this cast, float64 input
          # (e.g. a default numpy array) produces float64 fbank features, and the
          # Conv1d layers then fail with a dtype mismatch.
          audio = audio.to(torch.float32)
          # Inference only: avoid building an autograd graph.
          with torch.no_grad():
              # audio shape: [N, T]
              feature = self.__extract_feature(audio)
              embedding = self.embedding_model(feature.to(self.device))
          return embedding.detach().cpu()

      def __extract_feature(self, audio):
          """Compute mean-normalized Kaldi fbank features for each waveform.

          Args:
              audio (torch.Tensor): batch of waveforms, shape [N, T].

          Returns:
              torch.Tensor: features of shape [N, num_frames, feature_dim].
          """
          features = []
          for au in audio:
              # Kaldi.fbank takes a (channel, time) waveform and returns
              # (num_frames, num_mel_bins).
              feature = Kaldi.fbank(
                  au.unsqueeze(0), num_mel_bins=self.feature_dim)
              # Subtract the per-utterance mean over frames.
              feature = feature - feature.mean(dim=0, keepdim=True)
              features.append(feature)
          # All rows share length T, so every utterance has the same frame count.
          return torch.stack(features)

      def __load_check_point(self, pretrained_model_name):
          """Strictly load XVEC weights from ``model_dir/pretrained_model_name``."""
          checkpoint_path = os.path.join(self.model_dir, pretrained_model_name)
          # NOTE(review): torch.load unpickles the file; only load trusted
          # checkpoints.
          state_dict = torch.load(
              checkpoint_path, map_location=torch.device('cpu'))
          self.embedding_model.load_state_dict(state_dict, strict=True)