# ERes2NetV2.py
# Copyright (c) Alibaba, Inc. and its affiliates.
"""
To further improve the short-duration feature extraction capability of ERes2Net,
we expand the channel dimension within each stage. However, this modification also
increases the number of model parameters and the computational complexity.
To alleviate this problem, we propose an improved ERes2NetV2 by pruning redundant structures,
ultimately reducing both the number of model parameters and the computational cost.
"""
  9. import math
  10. import os
  11. from typing import Any, Dict, Union
  12. import numpy as np
  13. import torch
  14. import torch.nn as nn
  15. import torch.nn.functional as F
  16. import torchaudio.compliance.kaldi as Kaldi
  17. import modelscope.models.audio.sv.pooling_layers as pooling_layers
  18. from modelscope.metainfo import Models
  19. from modelscope.models import MODELS, TorchModel
  20. from modelscope.models.audio.sv.fusion import AFF
  21. from modelscope.utils.constant import Tasks
  22. from modelscope.utils.device import create_device
  23. class ReLU(nn.Hardtanh):
  24. def __init__(self, inplace=False):
  25. super(ReLU, self).__init__(0, 20, inplace)
  26. def __repr__(self):
  27. inplace_str = 'inplace' if self.inplace else ''
  28. return self.__class__.__name__ + ' (' \
  29. + inplace_str + ')'
  30. class BasicBlockERes2NetV2(nn.Module):
  31. def __init__(self,
  32. in_planes,
  33. planes,
  34. stride=1,
  35. baseWidth=26,
  36. scale=2,
  37. expansion=2):
  38. super(BasicBlockERes2NetV2, self).__init__()
  39. width = int(math.floor(planes * (baseWidth / 64.0)))
  40. self.width = width
  41. self.conv1 = nn.Conv2d(
  42. in_planes, width * scale, kernel_size=1, stride=stride, bias=False)
  43. self.bn1 = nn.BatchNorm2d(width * scale)
  44. self.nums = scale
  45. self.expansion = expansion
  46. convs = []
  47. bns = []
  48. for i in range(self.nums):
  49. convs.append(
  50. nn.Conv2d(width, width, kernel_size=3, padding=1, bias=False))
  51. bns.append(nn.BatchNorm2d(width))
  52. self.convs = nn.ModuleList(convs)
  53. self.bns = nn.ModuleList(bns)
  54. self.relu = ReLU(inplace=True)
  55. self.conv3 = nn.Conv2d(
  56. width * scale, planes * self.expansion, kernel_size=1, bias=False)
  57. self.bn3 = nn.BatchNorm2d(planes * self.expansion)
  58. self.shortcut = nn.Sequential()
  59. if stride != 1 or in_planes != self.expansion * planes:
  60. self.shortcut = nn.Sequential(
  61. nn.Conv2d(
  62. in_planes,
  63. self.expansion * planes,
  64. kernel_size=1,
  65. stride=stride,
  66. bias=False), nn.BatchNorm2d(self.expansion * planes))
  67. def forward(self, x):
  68. residual = x
  69. out = self.conv1(x)
  70. out = self.bn1(out)
  71. out = self.relu(out)
  72. spx = torch.split(out, self.width, 1)
  73. for i in range(self.nums):
  74. if i == 0:
  75. sp = spx[i]
  76. else:
  77. sp = sp + spx[i]
  78. sp = self.convs[i](sp)
  79. sp = self.relu(self.bns[i](sp))
  80. if i == 0:
  81. out = sp
  82. else:
  83. out = torch.cat((out, sp), 1)
  84. out = self.conv3(out)
  85. out = self.bn3(out)
  86. residual = self.shortcut(x)
  87. out += residual
  88. out = self.relu(out)
  89. return out
  90. class BasicBlockERes2NetV2AFF(nn.Module):
  91. def __init__(self,
  92. in_planes,
  93. planes,
  94. stride=1,
  95. baseWidth=26,
  96. scale=2,
  97. expansion=2):
  98. super(BasicBlockERes2NetV2AFF, self).__init__()
  99. width = int(math.floor(planes * (baseWidth / 64.0)))
  100. self.width = width
  101. self.conv1 = nn.Conv2d(
  102. in_planes, width * scale, kernel_size=1, stride=stride, bias=False)
  103. self.bn1 = nn.BatchNorm2d(width * scale)
  104. self.nums = scale
  105. self.expansion = expansion
  106. convs = []
  107. fuse_models = []
  108. bns = []
  109. for i in range(self.nums):
  110. convs.append(
  111. nn.Conv2d(width, width, kernel_size=3, padding=1, bias=False))
  112. bns.append(nn.BatchNorm2d(width))
  113. for j in range(self.nums - 1):
  114. fuse_models.append(AFF(channels=width, r=4))
  115. self.convs = nn.ModuleList(convs)
  116. self.bns = nn.ModuleList(bns)
  117. self.fuse_models = nn.ModuleList(fuse_models)
  118. self.relu = ReLU(inplace=True)
  119. self.conv3 = nn.Conv2d(
  120. width * scale, planes * self.expansion, kernel_size=1, bias=False)
  121. self.bn3 = nn.BatchNorm2d(planes * self.expansion)
  122. self.shortcut = nn.Sequential()
  123. if stride != 1 or in_planes != self.expansion * planes:
  124. self.shortcut = nn.Sequential(
  125. nn.Conv2d(
  126. in_planes,
  127. self.expansion * planes,
  128. kernel_size=1,
  129. stride=stride,
  130. bias=False), nn.BatchNorm2d(self.expansion * planes))
  131. def forward(self, x):
  132. residual = x
  133. out = self.conv1(x)
  134. out = self.bn1(out)
  135. out = self.relu(out)
  136. spx = torch.split(out, self.width, 1)
  137. for i in range(self.nums):
  138. if i == 0:
  139. sp = spx[i]
  140. else:
  141. sp = self.fuse_models[i - 1](sp, spx[i])
  142. sp = self.convs[i](sp)
  143. sp = self.relu(self.bns[i](sp))
  144. if i == 0:
  145. out = sp
  146. else:
  147. out = torch.cat((out, sp), 1)
  148. out = self.conv3(out)
  149. out = self.bn3(out)
  150. residual = self.shortcut(x)
  151. out += residual
  152. out = self.relu(out)
  153. return out
  154. class ERes2NetV2(nn.Module):
  155. def __init__(self,
  156. block=BasicBlockERes2NetV2,
  157. block_fuse=BasicBlockERes2NetV2AFF,
  158. num_blocks=[3, 4, 6, 3],
  159. m_channels=64,
  160. feat_dim=80,
  161. embed_dim=192,
  162. baseWidth=26,
  163. scale=2,
  164. expansion=2,
  165. pooling_func='TSTP',
  166. two_emb_layer=False):
  167. super(ERes2NetV2, self).__init__()
  168. self.in_planes = m_channels
  169. self.feat_dim = feat_dim
  170. self.embed_dim = embed_dim
  171. self.stats_dim = int(feat_dim / 8) * m_channels * 8
  172. self.two_emb_layer = two_emb_layer
  173. self.baseWidth = baseWidth
  174. self.scale = scale
  175. self.expansion = expansion
  176. self.conv1 = nn.Conv2d(
  177. 1, m_channels, kernel_size=3, stride=1, padding=1, bias=False)
  178. self.bn1 = nn.BatchNorm2d(m_channels)
  179. self.layer1 = self._make_layer(
  180. block, m_channels, num_blocks[0], stride=1)
  181. self.layer2 = self._make_layer(
  182. block, m_channels * 2, num_blocks[1], stride=2)
  183. self.layer3 = self._make_layer(
  184. block_fuse, m_channels * 4, num_blocks[2], stride=2)
  185. self.layer4 = self._make_layer(
  186. block_fuse, m_channels * 8, num_blocks[3], stride=2)
  187. # Downsampling module
  188. self.layer3_ds = nn.Conv2d(
  189. m_channels * 4 * self.expansion,
  190. m_channels * 8 * self.expansion,
  191. kernel_size=3,
  192. padding=1,
  193. stride=2,
  194. bias=False)
  195. # Bottom-up fusion module
  196. self.fuse34 = AFF(channels=m_channels * 8 * self.expansion, r=4)
  197. self.n_stats = 1 if pooling_func == 'TAP' or pooling_func == 'TSDP' else 2
  198. self.pool = getattr(pooling_layers, pooling_func)(
  199. in_dim=self.stats_dim * self.expansion)
  200. self.seg_1 = nn.Linear(self.stats_dim * self.expansion * self.n_stats,
  201. embed_dim)
  202. if self.two_emb_layer:
  203. self.seg_bn_1 = nn.BatchNorm1d(embed_dim, affine=False)
  204. self.seg_2 = nn.Linear(embed_dim, embed_dim)
  205. else:
  206. self.seg_bn_1 = nn.Identity()
  207. self.seg_2 = nn.Identity()
  208. def _make_layer(self, block, planes, num_blocks, stride):
  209. strides = [stride] + [1] * (num_blocks - 1)
  210. layers = []
  211. for stride in strides:
  212. layers.append(
  213. block(
  214. self.in_planes,
  215. planes,
  216. stride,
  217. baseWidth=self.baseWidth,
  218. scale=self.scale,
  219. expansion=self.expansion))
  220. self.in_planes = planes * self.expansion
  221. return nn.Sequential(*layers)
  222. def forward(self, x):
  223. x = x.permute(0, 2, 1) # (B,T,F) => (B,F,T)
  224. x = x.unsqueeze_(1)
  225. out = F.relu(self.bn1(self.conv1(x)))
  226. out1 = self.layer1(out)
  227. out2 = self.layer2(out1)
  228. out3 = self.layer3(out2)
  229. out4 = self.layer4(out3)
  230. out3_ds = self.layer3_ds(out3)
  231. fuse_out34 = self.fuse34(out4, out3_ds)
  232. stats = self.pool(fuse_out34)
  233. embed_a = self.seg_1(stats)
  234. if self.two_emb_layer:
  235. out = F.relu(embed_a)
  236. out = self.seg_bn_1(out)
  237. embed_b = self.seg_2(out)
  238. return embed_b
  239. else:
  240. return embed_a
  241. @MODELS.register_module(
  242. Tasks.speaker_verification, module_name=Models.eres2netv2_sv)
  243. class SpeakerVerificationERes2NetV2(TorchModel):
  244. r"""ERes2NetV2 architecture with local and global feature fusion. ERes2NetV2 is mainly composed
  245. of Bottom-up Dual-stage Feature Fusion (BDFF) and Bottleneck-like Local Feature Fusion (BLFF).
  246. BDFF fuses multi-scale feature maps in bottom-up pathway to obtain global information.
  247. The BLFF extracts localization-preserved speaker features and strengthen the local information interaction.
  248. Args:
  249. model_dir: A model dir.
  250. model_config: The model config.
  251. """
  252. def __init__(self, model_dir, model_config: Dict[str, Any], *args,
  253. **kwargs):
  254. super().__init__(model_dir, model_config, *args, **kwargs)
  255. self.model_config = model_config
  256. self.embed_dim = self.model_config['embed_dim']
  257. self.baseWidth = self.model_config['baseWidth']
  258. self.scale = self.model_config['scale']
  259. self.expansion = self.model_config['expansion']
  260. self.other_config = kwargs
  261. self.feature_dim = 80
  262. self.device = create_device(self.other_config['device'])
  263. self.embedding_model = ERes2NetV2(
  264. embed_dim=self.embed_dim,
  265. baseWidth=self.baseWidth,
  266. scale=self.scale,
  267. expansion=self.expansion)
  268. pretrained_model_name = kwargs['pretrained_model']
  269. self.__load_check_point(pretrained_model_name)
  270. self.embedding_model.to(self.device)
  271. self.embedding_model.eval()
  272. def forward(self, audio):
  273. if isinstance(audio, np.ndarray):
  274. audio = torch.from_numpy(audio)
  275. if len(audio.shape) == 1:
  276. audio = audio.unsqueeze(0)
  277. assert len(
  278. audio.shape
  279. ) == 2, 'modelscope error: the shape of input audio to model needs to be [N, T]'
  280. # audio shape: [N, T]
  281. feature = self.__extract_feature(audio)
  282. embedding = self.embedding_model(feature.to(self.device))
  283. return embedding.detach().cpu()
  284. def __extract_feature(self, audio):
  285. feature = Kaldi.fbank(audio, num_mel_bins=self.feature_dim)
  286. feature = feature - feature.mean(dim=0, keepdim=True)
  287. feature = feature.unsqueeze(0)
  288. return feature
  289. def __load_check_point(self, pretrained_model_name, device=None):
  290. if not device:
  291. device = torch.device('cpu')
  292. self.embedding_model.load_state_dict(
  293. torch.load(
  294. os.path.join(self.model_dir, pretrained_model_name),
  295. map_location=device),
  296. strict=True)