# ERes2Net_aug.py
  1. # Copyright (c) Alibaba, Inc. and its affiliates.
  2. """ Res2Net implementation is adapted from https://github.com/wenet-e2e/wespeaker.
  3. ERes2Net_aug incorporates both local and global feature fusion techniques
  4. to improve the performance. The training code is located on the following
  5. GitHub repository: https://github.com/alibaba-damo-academy/3D-Speaker.
  6. """
  7. import math
  8. import os
  9. from typing import Any, Dict, Union
  10. import numpy as np
  11. import torch
  12. import torch.nn as nn
  13. import torch.nn.functional as F
  14. import torchaudio.compliance.kaldi as Kaldi
  15. import modelscope.models.audio.sv.pooling_layers as pooling_layers
  16. from modelscope.metainfo import Models
  17. from modelscope.models import MODELS, TorchModel
  18. from modelscope.models.audio.sv.fusion import AFF
  19. from modelscope.utils.constant import Tasks
  20. from modelscope.utils.device import create_device
  21. class ReLU(nn.Hardtanh):
  22. def __init__(self, inplace=False):
  23. super(ReLU, self).__init__(0, 20, inplace)
  24. def __repr__(self):
  25. inplace_str = 'inplace' if self.inplace else ''
  26. return self.__class__.__name__ + ' (' \
  27. + inplace_str + ')'
  28. def conv1x1(in_planes, out_planes, stride=1):
  29. '1x1 convolution without padding'
  30. return nn.Conv2d(
  31. in_planes,
  32. out_planes,
  33. kernel_size=1,
  34. stride=stride,
  35. padding=0,
  36. bias=False)
  37. def conv3x3(in_planes, out_planes, stride=1):
  38. '3x3 convolution with padding'
  39. return nn.Conv2d(
  40. in_planes,
  41. out_planes,
  42. kernel_size=3,
  43. stride=stride,
  44. padding=1,
  45. bias=False)
  46. class BasicBlockERes2Net(nn.Module):
  47. expansion = 4
  48. def __init__(self, in_planes, planes, stride=1, baseWidth=24, scale=3):
  49. super(BasicBlockERes2Net, self).__init__()
  50. width = int(math.floor(planes * (baseWidth / 64.0)))
  51. self.conv1 = conv1x1(in_planes, width * scale, stride)
  52. self.bn1 = nn.BatchNorm2d(width * scale)
  53. self.nums = scale
  54. convs = []
  55. bns = []
  56. for i in range(self.nums):
  57. convs.append(conv3x3(width, width))
  58. bns.append(nn.BatchNorm2d(width))
  59. self.convs = nn.ModuleList(convs)
  60. self.bns = nn.ModuleList(bns)
  61. self.relu = ReLU(inplace=True)
  62. self.conv3 = conv1x1(width * scale, planes * self.expansion)
  63. self.bn3 = nn.BatchNorm2d(planes * self.expansion)
  64. self.shortcut = nn.Sequential()
  65. if stride != 1 or in_planes != self.expansion * planes:
  66. self.shortcut = nn.Sequential(
  67. nn.Conv2d(
  68. in_planes,
  69. self.expansion * planes,
  70. kernel_size=1,
  71. stride=stride,
  72. bias=False), nn.BatchNorm2d(self.expansion * planes))
  73. self.stride = stride
  74. self.width = width
  75. self.scale = scale
  76. def forward(self, x):
  77. residual = x
  78. out = self.conv1(x)
  79. out = self.bn1(out)
  80. out = self.relu(out)
  81. spx = torch.split(out, self.width, 1)
  82. for i in range(self.nums):
  83. if i == 0:
  84. sp = spx[i]
  85. else:
  86. sp = sp + spx[i]
  87. sp = self.convs[i](sp)
  88. sp = self.relu(self.bns[i](sp))
  89. if i == 0:
  90. out = sp
  91. else:
  92. out = torch.cat((out, sp), 1)
  93. out = self.conv3(out)
  94. out = self.bn3(out)
  95. residual = self.shortcut(x)
  96. out += residual
  97. out = self.relu(out)
  98. return out
  99. class BasicBlockERes2Net_diff_AFF(nn.Module):
  100. expansion = 4
  101. def __init__(self, in_planes, planes, stride=1, baseWidth=24, scale=3):
  102. super(BasicBlockERes2Net_diff_AFF, self).__init__()
  103. width = int(math.floor(planes * (baseWidth / 64.0)))
  104. self.conv1 = conv1x1(in_planes, width * scale, stride)
  105. self.bn1 = nn.BatchNorm2d(width * scale)
  106. self.nums = scale
  107. convs = []
  108. fuse_models = []
  109. bns = []
  110. for i in range(self.nums):
  111. convs.append(conv3x3(width, width))
  112. bns.append(nn.BatchNorm2d(width))
  113. # Add different fuse_model parameters
  114. for j in range(self.nums - 1):
  115. fuse_models.append(AFF(channels=width))
  116. self.convs = nn.ModuleList(convs)
  117. self.bns = nn.ModuleList(bns)
  118. self.fuse_models = nn.ModuleList(fuse_models)
  119. self.relu = ReLU(inplace=True)
  120. self.conv3 = conv1x1(width * scale, planes * self.expansion)
  121. self.bn3 = nn.BatchNorm2d(planes * self.expansion)
  122. self.shortcut = nn.Sequential()
  123. if stride != 1 or in_planes != self.expansion * planes:
  124. self.shortcut = nn.Sequential(
  125. nn.Conv2d(
  126. in_planes,
  127. self.expansion * planes,
  128. kernel_size=1,
  129. stride=stride,
  130. bias=False), nn.BatchNorm2d(self.expansion * planes))
  131. self.stride = stride
  132. self.width = width
  133. self.scale = scale
  134. def forward(self, x):
  135. residual = x
  136. out = self.conv1(x)
  137. out = self.bn1(out)
  138. out = self.relu(out)
  139. spx = torch.split(out, self.width, 1)
  140. for i in range(self.nums):
  141. if i == 0:
  142. sp = spx[i]
  143. else:
  144. sp = self.fuse_models[i - 1](sp, spx[i])
  145. sp = self.convs[i](sp)
  146. sp = self.relu(self.bns[i](sp))
  147. if i == 0:
  148. out = sp
  149. else:
  150. out = torch.cat((out, sp), 1)
  151. out = self.conv3(out)
  152. out = self.bn3(out)
  153. residual = self.shortcut(x)
  154. out += residual
  155. out = self.relu(out)
  156. return out
  157. class ERes2Net_aug(nn.Module):
  158. def __init__(self,
  159. block=BasicBlockERes2Net,
  160. block_fuse=BasicBlockERes2Net_diff_AFF,
  161. num_blocks=[3, 4, 6, 3],
  162. m_channels=64,
  163. feat_dim=80,
  164. embedding_size=192,
  165. pooling_func='TSTP',
  166. two_emb_layer=False):
  167. super(ERes2Net_aug, self).__init__()
  168. self.in_planes = m_channels
  169. self.feat_dim = feat_dim
  170. self.embedding_size = embedding_size
  171. self.stats_dim = int(feat_dim / 8) * m_channels * 8
  172. self.two_emb_layer = two_emb_layer
  173. self.conv1 = nn.Conv2d(
  174. 1, m_channels, kernel_size=3, stride=1, padding=1, bias=False)
  175. self.bn1 = nn.BatchNorm2d(m_channels)
  176. self.layer1 = self._make_layer(
  177. block, m_channels, num_blocks[0], stride=1)
  178. self.layer2 = self._make_layer(
  179. block, m_channels * 2, num_blocks[1], stride=2)
  180. self.layer3 = self._make_layer(
  181. block_fuse, m_channels * 4, num_blocks[2], stride=2)
  182. self.layer4 = self._make_layer(
  183. block_fuse, m_channels * 8, num_blocks[3], stride=2)
  184. self.layer1_downsample = nn.Conv2d(
  185. m_channels * 4,
  186. m_channels * 8,
  187. kernel_size=3,
  188. padding=1,
  189. stride=2,
  190. bias=False)
  191. self.layer2_downsample = nn.Conv2d(
  192. m_channels * 8,
  193. m_channels * 16,
  194. kernel_size=3,
  195. padding=1,
  196. stride=2,
  197. bias=False)
  198. self.layer3_downsample = nn.Conv2d(
  199. m_channels * 16,
  200. m_channels * 32,
  201. kernel_size=3,
  202. padding=1,
  203. stride=2,
  204. bias=False)
  205. self.fuse_mode12 = AFF(channels=m_channels * 8)
  206. self.fuse_mode123 = AFF(channels=m_channels * 16)
  207. self.fuse_mode1234 = AFF(channels=m_channels * 32)
  208. self.n_stats = 1 if pooling_func == 'TAP' or pooling_func == 'TSDP' else 2
  209. self.pool = getattr(pooling_layers, pooling_func)(
  210. in_dim=self.stats_dim * block.expansion)
  211. self.seg_1 = nn.Linear(self.stats_dim * block.expansion * self.n_stats,
  212. embedding_size)
  213. if self.two_emb_layer:
  214. self.seg_bn_1 = nn.BatchNorm1d(embedding_size, affine=False)
  215. self.seg_2 = nn.Linear(embedding_size, embedding_size)
  216. else:
  217. self.seg_bn_1 = nn.Identity()
  218. self.seg_2 = nn.Identity()
  219. def _make_layer(self, block, planes, num_blocks, stride):
  220. strides = [stride] + [1] * (num_blocks - 1)
  221. layers = []
  222. for stride in strides:
  223. layers.append(block(self.in_planes, planes, stride))
  224. self.in_planes = planes * block.expansion
  225. return nn.Sequential(*layers)
  226. def forward(self, x):
  227. x = x.permute(0, 2, 1) # (B,T,F) => (B,F,T)
  228. x = x.unsqueeze_(1)
  229. out = F.relu(self.bn1(self.conv1(x)))
  230. out1 = self.layer1(out)
  231. out2 = self.layer2(out1)
  232. out1_downsample = self.layer1_downsample(out1)
  233. fuse_out12 = self.fuse_mode12(out2, out1_downsample)
  234. out3 = self.layer3(out2)
  235. fuse_out12_downsample = self.layer2_downsample(fuse_out12)
  236. fuse_out123 = self.fuse_mode123(out3, fuse_out12_downsample)
  237. out4 = self.layer4(out3)
  238. fuse_out123_downsample = self.layer3_downsample(fuse_out123)
  239. fuse_out1234 = self.fuse_mode1234(out4, fuse_out123_downsample)
  240. stats = self.pool(fuse_out1234)
  241. embed_a = self.seg_1(stats)
  242. if self.two_emb_layer:
  243. out = F.relu(embed_a)
  244. out = self.seg_bn_1(out)
  245. embed_b = self.seg_2(out)
  246. return embed_b
  247. else:
  248. return embed_a
  249. @MODELS.register_module(
  250. Tasks.speaker_verification, module_name=Models.eres2net_aug_sv)
  251. class SpeakerVerificationERes2Net(TorchModel):
  252. r"""Enhanced Res2Net_aug architecture with local and global feature fusion.
  253. ERes2Net_aug is an upgraded version of ERes2Net that uses a larger number of
  254. parameters to achieve better recognition performance.
  255. Args:
  256. model_dir: A model dir.
  257. model_config: The model config.
  258. """
  259. def __init__(self, model_dir, model_config: Dict[str, Any], *args,
  260. **kwargs):
  261. super().__init__(model_dir, model_config, *args, **kwargs)
  262. self.model_config = model_config
  263. self.other_config = kwargs
  264. self.feature_dim = 80
  265. self.device = create_device(self.other_config['device'])
  266. self.embedding_model = ERes2Net_aug()
  267. pretrained_model_name = kwargs['pretrained_model']
  268. self.__load_check_point(pretrained_model_name)
  269. self.embedding_model.to(self.device)
  270. self.embedding_model.eval()
  271. def forward(self, audio):
  272. if isinstance(audio, np.ndarray):
  273. audio = torch.from_numpy(audio)
  274. if len(audio.shape) == 1:
  275. audio = audio.unsqueeze(0)
  276. assert len(
  277. audio.shape
  278. ) == 2, 'modelscope error: the shape of input audio to model needs to be [N, T]'
  279. # audio shape: [N, T]
  280. feature = self.__extract_feature(audio)
  281. embedding = self.embedding_model(feature.to(self.device))
  282. return embedding.detach().cpu()
  283. def __extract_feature(self, audio):
  284. feature = Kaldi.fbank(audio, num_mel_bins=self.feature_dim)
  285. feature = feature - feature.mean(dim=0, keepdim=True)
  286. feature = feature.unsqueeze(0)
  287. return feature
  288. def __load_check_point(self, pretrained_model_name, device=None):
  289. if not device:
  290. device = torch.device('cpu')
  291. self.embedding_model.load_state_dict(
  292. torch.load(
  293. os.path.join(self.model_dir, pretrained_model_name),
  294. map_location=device),
  295. strict=True)