ResNet.py 6.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186
  1. # Copyright (c) Alibaba, Inc. and its affiliates.
  2. """ ResNet implementation is adapted from https://github.com/wenet-e2e/wespeaker.
  3. ResNet, or Residual Neural Network, is notable for its optimization ease
  4. and depth-induced accuracy gains. It utilizes skip connections within its residual
  5. blocks to counteract the vanishing gradient problem in deep networks.
  6. Reference: Kaiming He, Xiangyu Zhang, Shaoqing Ren, Jian Sun
  7. Deep Residual Learning for Image Recognition. arXiv:1512.03385
  8. """
  9. import math
  10. import os
  11. from typing import Any, Dict, Union
  12. import numpy as np
  13. import torch
  14. import torch.nn as nn
  15. import torch.nn.functional as F
  16. import torchaudio.compliance.kaldi as Kaldi
  17. import modelscope.models.audio.sv.pooling_layers as pooling_layers
  18. from modelscope.metainfo import Models
  19. from modelscope.models import MODELS, TorchModel
  20. from modelscope.utils.constant import Tasks
  21. from modelscope.utils.device import create_device
  22. class BasicBlock(nn.Module):
  23. expansion = 1
  24. def __init__(self, in_planes, planes, stride=1):
  25. super(BasicBlock, self).__init__()
  26. self.conv1 = nn.Conv2d(
  27. in_planes,
  28. planes,
  29. kernel_size=3,
  30. stride=stride,
  31. padding=1,
  32. bias=False)
  33. self.bn1 = nn.BatchNorm2d(planes)
  34. self.conv2 = nn.Conv2d(
  35. planes, planes, kernel_size=3, stride=1, padding=1, bias=False)
  36. self.bn2 = nn.BatchNorm2d(planes)
  37. self.shortcut = nn.Sequential()
  38. if stride != 1 or in_planes != self.expansion * planes:
  39. self.shortcut = nn.Sequential(
  40. nn.Conv2d(
  41. in_planes,
  42. self.expansion * planes,
  43. kernel_size=1,
  44. stride=stride,
  45. bias=False), nn.BatchNorm2d(self.expansion * planes))
  46. def forward(self, x):
  47. out = F.relu(self.bn1(self.conv1(x)))
  48. out = self.bn2(self.conv2(out))
  49. out += self.shortcut(x)
  50. out = F.relu(out)
  51. return out
  52. class ResNet(nn.Module):
  53. def __init__(self,
  54. block=BasicBlock,
  55. num_blocks=[3, 4, 6, 3],
  56. m_channels=32,
  57. feat_dim=80,
  58. embedding_size=128,
  59. pooling_func='TSTP',
  60. two_emb_layer=True):
  61. super(ResNet, self).__init__()
  62. self.in_planes = m_channels
  63. self.feat_dim = feat_dim
  64. self.embedding_size = embedding_size
  65. self.stats_dim = int(feat_dim / 8) * m_channels * 8
  66. self.two_emb_layer = two_emb_layer
  67. self.conv1 = nn.Conv2d(
  68. 1, m_channels, kernel_size=3, stride=1, padding=1, bias=False)
  69. self.bn1 = nn.BatchNorm2d(m_channels)
  70. self.layer1 = self._make_layer(
  71. block, m_channels, num_blocks[0], stride=1)
  72. self.layer2 = self._make_layer(
  73. block, m_channels * 2, num_blocks[1], stride=2)
  74. self.layer3 = self._make_layer(
  75. block, m_channels * 4, num_blocks[2], stride=2)
  76. self.layer4 = self._make_layer(
  77. block, m_channels * 8, num_blocks[3], stride=2)
  78. self.n_stats = 1 if pooling_func == 'TAP' or pooling_func == 'TSDP' else 2
  79. self.pool = getattr(pooling_layers, pooling_func)(
  80. in_dim=self.stats_dim * block.expansion)
  81. self.seg_1 = nn.Linear(self.stats_dim * block.expansion * self.n_stats,
  82. embedding_size)
  83. if self.two_emb_layer:
  84. self.seg_bn_1 = nn.BatchNorm1d(embedding_size, affine=False)
  85. self.seg_2 = nn.Linear(embedding_size, embedding_size)
  86. else:
  87. self.seg_bn_1 = nn.Identity()
  88. self.seg_2 = nn.Identity()
  89. def _make_layer(self, block, planes, num_blocks, stride):
  90. strides = [stride] + [1] * (num_blocks - 1)
  91. layers = []
  92. for stride in strides:
  93. layers.append(block(self.in_planes, planes, stride))
  94. self.in_planes = planes * block.expansion
  95. return nn.Sequential(*layers)
  96. def forward(self, x):
  97. x = x.permute(0, 2, 1) # (B,T,F) => (B,F,T)
  98. x = x.unsqueeze_(1)
  99. out = F.relu(self.bn1(self.conv1(x)))
  100. out1 = self.layer1(out)
  101. out2 = self.layer2(out1)
  102. out3 = self.layer3(out2)
  103. out = self.layer4(out3)
  104. stats = self.pool(out)
  105. embed_a = self.seg_1(stats)
  106. if self.two_emb_layer:
  107. out = F.relu(embed_a)
  108. out = self.seg_bn_1(out)
  109. embed_b = self.seg_2(out)
  110. return embed_b
  111. else:
  112. return embed_a
  113. @MODELS.register_module(
  114. Tasks.speaker_verification, module_name=Models.resnet_sv)
  115. class SpeakerVerificationResNet(TorchModel):
  116. r"""
  117. Args:
  118. model_dir: A model dir.
  119. model_config: The model config.
  120. """
  121. def __init__(self, model_dir, model_config: Dict[str, Any], *args,
  122. **kwargs):
  123. super().__init__(model_dir, model_config, *args, **kwargs)
  124. self.model_config = model_config
  125. self.embed_dim = self.model_config['embed_dim']
  126. self.m_channels = self.model_config['channels']
  127. self.other_config = kwargs
  128. self.feature_dim = 80
  129. self.device = create_device(self.other_config['device'])
  130. self.embedding_model = ResNet(
  131. embedding_size=self.embed_dim, m_channels=self.m_channels)
  132. pretrained_model_name = kwargs['pretrained_model']
  133. self.__load_check_point(pretrained_model_name)
  134. self.embedding_model.to(self.device)
  135. self.embedding_model.eval()
  136. def forward(self, audio):
  137. if isinstance(audio, np.ndarray):
  138. audio = torch.from_numpy(audio)
  139. if len(audio.shape) == 1:
  140. audio = audio.unsqueeze(0)
  141. assert len(
  142. audio.shape
  143. ) == 2, 'modelscope error: the shape of input audio to model needs to be [N, T]'
  144. # audio shape: [N, T]
  145. feature = self.__extract_feature(audio)
  146. embedding = self.embedding_model(feature.to(self.device))
  147. return embedding.detach().cpu()
  148. def __extract_feature(self, audio):
  149. feature = Kaldi.fbank(audio, num_mel_bins=self.feature_dim)
  150. feature = feature - feature.mean(dim=0, keepdim=True)
  151. feature = feature.unsqueeze(0)
  152. return feature
  153. def __load_check_point(self, pretrained_model_name, device=None):
  154. if not device:
  155. device = torch.device('cpu')
  156. self.embedding_model.load_state_dict(
  157. torch.load(
  158. os.path.join(self.model_dir, pretrained_model_name),
  159. map_location=device),
  160. strict=True)