Res2Net.py 8.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234
# Copyright (c) Alibaba, Inc. and its affiliates.
""" Res2Net implementation is adapted from https://github.com/Res2Net/Res2Net-PretrainedModels.
Res2Net extends standard ResNets with hierarchical residual-like connections inside each
residual block, which lets a single block represent features at multiple scales. This
improves performance on computer vision tasks such as image classification and object
detection without significant computational overhead; here it is used as a speaker
embedding extractor.
Reference: https://arxiv.org/pdf/1904.01169.pdf
Some modifications from the original architecture:
1. Smaller kernel size for the input layer (3x3 instead of 7x7)
2. Smaller expansion in BasicBlockRes2Net (2 instead of 4)
"""
  12. import math
  13. import os
  14. from typing import Any, Dict, Union
  15. import numpy as np
  16. import torch
  17. import torch.nn as nn
  18. import torch.nn.functional as F
  19. import torchaudio.compliance.kaldi as Kaldi
  20. import modelscope.models.audio.sv.pooling_layers as pooling_layers
  21. from modelscope.metainfo import Models
  22. from modelscope.models import MODELS, TorchModel
  23. from modelscope.utils.constant import Tasks
  24. from modelscope.utils.device import create_device
  25. class ReLU(nn.Hardtanh):
  26. def __init__(self, inplace=False):
  27. super(ReLU, self).__init__(0, 20, inplace)
  28. def __repr__(self):
  29. inplace_str = 'inplace' if self.inplace else ''
  30. return self.__class__.__name__ + ' (' \
  31. + inplace_str + ')'
  32. class BasicBlockRes2Net(nn.Module):
  33. expansion = 2
  34. def __init__(self, in_planes, planes, stride=1, baseWidth=32, scale=2):
  35. super(BasicBlockRes2Net, self).__init__()
  36. width = int(math.floor(planes * (baseWidth / 64.0)))
  37. self.conv1 = nn.Conv2d(
  38. in_planes, width * scale, kernel_size=1, stride=stride, bias=False)
  39. self.bn1 = nn.BatchNorm2d(width * scale)
  40. self.nums = scale - 1
  41. convs = []
  42. bns = []
  43. for i in range(self.nums):
  44. convs.append(
  45. nn.Conv2d(width, width, kernel_size=3, padding=1, bias=False))
  46. bns.append(nn.BatchNorm2d(width))
  47. self.convs = nn.ModuleList(convs)
  48. self.bns = nn.ModuleList(bns)
  49. self.relu = ReLU(inplace=True)
  50. self.conv3 = nn.Conv2d(
  51. width * scale, planes * self.expansion, kernel_size=1, bias=False)
  52. self.bn3 = nn.BatchNorm2d(planes * self.expansion)
  53. self.shortcut = nn.Sequential()
  54. if stride != 1 or in_planes != self.expansion * planes:
  55. self.shortcut = nn.Sequential(
  56. nn.Conv2d(
  57. in_planes,
  58. self.expansion * planes,
  59. kernel_size=1,
  60. stride=stride,
  61. bias=False), nn.BatchNorm2d(self.expansion * planes))
  62. self.stride = stride
  63. self.width = width
  64. self.scale = scale
  65. def forward(self, x):
  66. residual = x
  67. out = self.conv1(x)
  68. out = self.bn1(out)
  69. out = self.relu(out)
  70. spx = torch.split(out, self.width, 1)
  71. for i in range(self.nums):
  72. if i == 0:
  73. sp = spx[i]
  74. else:
  75. sp = sp + spx[i]
  76. sp = self.convs[i](sp)
  77. sp = self.relu(self.bns[i](sp))
  78. if i == 0:
  79. out = sp
  80. else:
  81. out = torch.cat((out, sp), 1)
  82. out = torch.cat((out, spx[self.nums]), 1)
  83. out = self.conv3(out)
  84. out = self.bn3(out)
  85. residual = self.shortcut(x)
  86. out += residual
  87. out = self.relu(out)
  88. return out
  89. class Res2Net(nn.Module):
  90. def __init__(self,
  91. block=BasicBlockRes2Net,
  92. num_blocks=[3, 4, 6, 3],
  93. m_channels=32,
  94. feat_dim=80,
  95. embedding_size=192,
  96. pooling_func='TSTP',
  97. two_emb_layer=False):
  98. super(Res2Net, self).__init__()
  99. self.in_planes = m_channels
  100. self.feat_dim = feat_dim
  101. self.embedding_size = embedding_size
  102. self.stats_dim = int(feat_dim / 8) * m_channels * 8
  103. self.two_emb_layer = two_emb_layer
  104. self.conv1 = nn.Conv2d(
  105. 1, m_channels, kernel_size=3, stride=1, padding=1, bias=False)
  106. self.bn1 = nn.BatchNorm2d(m_channels)
  107. self.layer1 = self._make_layer(
  108. block, m_channels, num_blocks[0], stride=1)
  109. self.layer2 = self._make_layer(
  110. block, m_channels * 2, num_blocks[1], stride=2)
  111. self.layer3 = self._make_layer(
  112. block, m_channels * 4, num_blocks[2], stride=2)
  113. self.layer4 = self._make_layer(
  114. block, m_channels * 8, num_blocks[3], stride=2)
  115. self.n_stats = 1 if pooling_func == 'TAP' or pooling_func == 'TSDP' else 2
  116. self.pool = getattr(pooling_layers, pooling_func)(
  117. in_dim=self.stats_dim * block.expansion)
  118. self.seg_1 = nn.Linear(self.stats_dim * block.expansion * self.n_stats,
  119. embedding_size)
  120. if self.two_emb_layer:
  121. self.seg_bn_1 = nn.BatchNorm1d(embedding_size, affine=False)
  122. self.seg_2 = nn.Linear(embedding_size, embedding_size)
  123. else:
  124. self.seg_bn_1 = nn.Identity()
  125. self.seg_2 = nn.Identity()
  126. def _make_layer(self, block, planes, num_blocks, stride):
  127. strides = [stride] + [1] * (num_blocks - 1)
  128. layers = []
  129. for stride in strides:
  130. layers.append(block(self.in_planes, planes, stride))
  131. self.in_planes = planes * block.expansion
  132. return nn.Sequential(*layers)
  133. def forward(self, x):
  134. x = x.permute(0, 2, 1) # (B,T,F) => (B,F,T)
  135. x = x.unsqueeze_(1)
  136. out = F.relu(self.bn1(self.conv1(x)))
  137. out = self.layer1(out)
  138. out = self.layer2(out)
  139. out = self.layer3(out)
  140. out = self.layer4(out)
  141. stats = self.pool(out)
  142. embed_a = self.seg_1(stats)
  143. if self.two_emb_layer:
  144. out = F.relu(embed_a)
  145. out = self.seg_bn_1(out)
  146. embed_b = self.seg_2(out)
  147. return embed_b
  148. else:
  149. return embed_a
  150. @MODELS.register_module(
  151. Tasks.speaker_verification, module_name=Models.res2net_sv)
  152. class SpeakerVerificationResNet(TorchModel):
  153. r"""
  154. Args:
  155. model_dir: A model dir.
  156. model_config: The model config.
  157. """
  158. def __init__(self, model_dir, model_config: Dict[str, Any], *args,
  159. **kwargs):
  160. super().__init__(model_dir, model_config, *args, **kwargs)
  161. self.model_config = model_config
  162. self.embed_dim = self.model_config['embed_dim']
  163. self.m_channels = self.model_config['channels']
  164. self.other_config = kwargs
  165. self.feature_dim = 80
  166. self.device = create_device(self.other_config['device'])
  167. self.embedding_model = Res2Net(
  168. embedding_size=self.embed_dim, m_channels=self.m_channels)
  169. pretrained_model_name = kwargs['pretrained_model']
  170. self.__load_check_point(pretrained_model_name)
  171. self.embedding_model.to(self.device)
  172. self.embedding_model.eval()
  173. def forward(self, audio):
  174. if isinstance(audio, np.ndarray):
  175. audio = torch.from_numpy(audio)
  176. if len(audio.shape) == 1:
  177. audio = audio.unsqueeze(0)
  178. assert len(
  179. audio.shape
  180. ) == 2, 'modelscope error: the shape of input audio to model needs to be [N, T]'
  181. # audio shape: [N, T]
  182. feature = self.__extract_feature(audio)
  183. embedding = self.embedding_model(feature.to(self.device))
  184. return embedding.detach().cpu()
  185. def __extract_feature(self, audio):
  186. feature = Kaldi.fbank(audio, num_mel_bins=self.feature_dim)
  187. feature = feature - feature.mean(dim=0, keepdim=True)
  188. feature = feature.unsqueeze(0)
  189. return feature
  190. def __load_check_point(self, pretrained_model_name, device=None):
  191. if not device:
  192. device = torch.device('cpu')
  193. self.embedding_model.load_state_dict(
  194. torch.load(
  195. os.path.join(self.model_dir, pretrained_model_name),
  196. map_location=device),
  197. strict=True)