ERes2Net.py 12 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356
  1. # Copyright (c) Alibaba, Inc. and its affiliates.
  2. """ Res2Net implementation is adapted from https://github.com/wenet-e2e/wespeaker.
  3. ERes2Net incorporates both local and global feature fusion techniques to improve the performance. The local feature
  4. fusion (LFF) fuses the features within one single residual block to extract the local signal.
  5. The global feature fusion (GFF) takes acoustic features of different scales as input to aggregate global signal.
  6. """
  7. import math
  8. import os
  9. from typing import Any, Dict, Union
  10. import numpy as np
  11. import torch
  12. import torch.nn as nn
  13. import torch.nn.functional as F
  14. import torchaudio.compliance.kaldi as Kaldi
  15. import modelscope.models.audio.sv.pooling_layers as pooling_layers
  16. from modelscope.metainfo import Models
  17. from modelscope.models import MODELS, TorchModel
  18. from modelscope.models.audio.sv.fusion import AFF
  19. from modelscope.utils.constant import Tasks
  20. from modelscope.utils.device import create_device
  21. class ReLU(nn.Hardtanh):
  22. def __init__(self, inplace=False):
  23. super(ReLU, self).__init__(0, 20, inplace)
  24. def __repr__(self):
  25. inplace_str = 'inplace' if self.inplace else ''
  26. return self.__class__.__name__ + ' (' \
  27. + inplace_str + ')'
  28. def conv1x1(in_planes, out_planes, stride=1):
  29. '1x1 convolution without padding'
  30. return nn.Conv2d(
  31. in_planes,
  32. out_planes,
  33. kernel_size=1,
  34. stride=stride,
  35. padding=0,
  36. bias=False)
  37. def conv3x3(in_planes, out_planes, stride=1):
  38. '3x3 convolution with padding'
  39. return nn.Conv2d(
  40. in_planes,
  41. out_planes,
  42. kernel_size=3,
  43. stride=stride,
  44. padding=1,
  45. bias=False)
  46. class BasicBlockERes2Net(nn.Module):
  47. expansion = 2
  48. def __init__(self, in_planes, planes, stride=1, baseWidth=32, scale=2):
  49. super(BasicBlockERes2Net, self).__init__()
  50. width = int(math.floor(planes * (baseWidth / 64.0)))
  51. self.conv1 = conv1x1(in_planes, width * scale, stride)
  52. self.bn1 = nn.BatchNorm2d(width * scale)
  53. self.nums = scale
  54. convs = []
  55. bns = []
  56. for i in range(self.nums):
  57. convs.append(conv3x3(width, width))
  58. bns.append(nn.BatchNorm2d(width))
  59. self.convs = nn.ModuleList(convs)
  60. self.bns = nn.ModuleList(bns)
  61. self.relu = ReLU(inplace=True)
  62. self.conv3 = conv1x1(width * scale, planes * self.expansion)
  63. self.bn3 = nn.BatchNorm2d(planes * self.expansion)
  64. self.shortcut = nn.Sequential()
  65. if stride != 1 or in_planes != self.expansion * planes:
  66. self.shortcut = nn.Sequential(
  67. nn.Conv2d(
  68. in_planes,
  69. self.expansion * planes,
  70. kernel_size=1,
  71. stride=stride,
  72. bias=False), nn.BatchNorm2d(self.expansion * planes))
  73. self.stride = stride
  74. self.width = width
  75. self.scale = scale
  76. def forward(self, x):
  77. residual = x
  78. out = self.conv1(x)
  79. out = self.bn1(out)
  80. out = self.relu(out)
  81. spx = torch.split(out, self.width, 1)
  82. for i in range(self.nums):
  83. if i == 0:
  84. sp = spx[i]
  85. else:
  86. sp = sp + spx[i]
  87. sp = self.convs[i](sp)
  88. sp = self.relu(self.bns[i](sp))
  89. if i == 0:
  90. out = sp
  91. else:
  92. out = torch.cat((out, sp), 1)
  93. out = self.conv3(out)
  94. out = self.bn3(out)
  95. residual = self.shortcut(x)
  96. out += residual
  97. out = self.relu(out)
  98. return out
  99. class BasicBlockERes2Net_AFF(nn.Module):
  100. expansion = 2
  101. def __init__(self, in_planes, planes, stride=1, baseWidth=32, scale=2):
  102. super(BasicBlockERes2Net_AFF, self).__init__()
  103. width = int(math.floor(planes * (baseWidth / 64.0)))
  104. self.conv1 = conv1x1(in_planes, width * scale, stride)
  105. self.bn1 = nn.BatchNorm2d(width * scale)
  106. self.nums = scale
  107. convs = []
  108. fuse_models = []
  109. bns = []
  110. for i in range(self.nums):
  111. convs.append(conv3x3(width, width))
  112. bns.append(nn.BatchNorm2d(width))
  113. for j in range(self.nums - 1):
  114. fuse_models.append(AFF(channels=width))
  115. self.convs = nn.ModuleList(convs)
  116. self.bns = nn.ModuleList(bns)
  117. self.fuse_models = nn.ModuleList(fuse_models)
  118. self.relu = ReLU(inplace=True)
  119. self.conv3 = conv1x1(width * scale, planes * self.expansion)
  120. self.bn3 = nn.BatchNorm2d(planes * self.expansion)
  121. self.shortcut = nn.Sequential()
  122. if stride != 1 or in_planes != self.expansion * planes:
  123. self.shortcut = nn.Sequential(
  124. nn.Conv2d(
  125. in_planes,
  126. self.expansion * planes,
  127. kernel_size=1,
  128. stride=stride,
  129. bias=False), nn.BatchNorm2d(self.expansion * planes))
  130. self.stride = stride
  131. self.width = width
  132. self.scale = scale
  133. def forward(self, x):
  134. residual = x
  135. out = self.conv1(x)
  136. out = self.bn1(out)
  137. out = self.relu(out)
  138. spx = torch.split(out, self.width, 1)
  139. for i in range(self.nums):
  140. if i == 0:
  141. sp = spx[i]
  142. else:
  143. sp = self.fuse_models[i - 1](sp, spx[i])
  144. sp = self.convs[i](sp)
  145. sp = self.relu(self.bns[i](sp))
  146. if i == 0:
  147. out = sp
  148. else:
  149. out = torch.cat((out, sp), 1)
  150. out = self.conv3(out)
  151. out = self.bn3(out)
  152. residual = self.shortcut(x)
  153. out += residual
  154. out = self.relu(out)
  155. return out
  156. class ERes2Net(nn.Module):
  157. def __init__(self,
  158. block=BasicBlockERes2Net,
  159. block_fuse=BasicBlockERes2Net_AFF,
  160. num_blocks=[3, 4, 6, 3],
  161. m_channels=32,
  162. feat_dim=80,
  163. embed_dim=192,
  164. pooling_func='TSTP',
  165. two_emb_layer=False):
  166. super(ERes2Net, self).__init__()
  167. self.in_planes = m_channels
  168. self.feat_dim = feat_dim
  169. self.embed_dim = embed_dim
  170. self.stats_dim = int(feat_dim / 8) * m_channels * 8
  171. self.two_emb_layer = two_emb_layer
  172. self.conv1 = nn.Conv2d(
  173. 1, m_channels, kernel_size=3, stride=1, padding=1, bias=False)
  174. self.bn1 = nn.BatchNorm2d(m_channels)
  175. self.layer1 = self._make_layer(
  176. block, m_channels, num_blocks[0], stride=1)
  177. self.layer2 = self._make_layer(
  178. block, m_channels * 2, num_blocks[1], stride=2)
  179. self.layer3 = self._make_layer(
  180. block_fuse, m_channels * 4, num_blocks[2], stride=2)
  181. self.layer4 = self._make_layer(
  182. block_fuse, m_channels * 8, num_blocks[3], stride=2)
  183. # downsampling
  184. self.layer1_downsample = nn.Conv2d(
  185. m_channels * 2,
  186. m_channels * 4,
  187. kernel_size=3,
  188. stride=2,
  189. padding=1,
  190. bias=False)
  191. self.layer2_downsample = nn.Conv2d(
  192. m_channels * 4,
  193. m_channels * 8,
  194. kernel_size=3,
  195. padding=1,
  196. stride=2,
  197. bias=False)
  198. self.layer3_downsample = nn.Conv2d(
  199. m_channels * 8,
  200. m_channels * 16,
  201. kernel_size=3,
  202. padding=1,
  203. stride=2,
  204. bias=False)
  205. # bottom-up fusion
  206. self.fuse_mode12 = AFF(channels=m_channels * 4)
  207. self.fuse_mode123 = AFF(channels=m_channels * 8)
  208. self.fuse_mode1234 = AFF(channels=m_channels * 16)
  209. self.n_stats = 1 if pooling_func == 'TAP' or pooling_func == 'TSDP' else 2
  210. self.pool = getattr(pooling_layers, pooling_func)(
  211. in_dim=self.stats_dim * block.expansion)
  212. self.seg_1 = nn.Linear(self.stats_dim * block.expansion * self.n_stats,
  213. embed_dim)
  214. if self.two_emb_layer:
  215. self.seg_bn_1 = nn.BatchNorm1d(embed_dim, affine=False)
  216. self.seg_2 = nn.Linear(embed_dim, embed_dim)
  217. else:
  218. self.seg_bn_1 = nn.Identity()
  219. self.seg_2 = nn.Identity()
  220. def _make_layer(self, block, planes, num_blocks, stride):
  221. strides = [stride] + [1] * (num_blocks - 1)
  222. layers = []
  223. for stride in strides:
  224. layers.append(block(self.in_planes, planes, stride))
  225. self.in_planes = planes * block.expansion
  226. return nn.Sequential(*layers)
  227. def forward(self, x):
  228. x = x.permute(0, 2, 1)
  229. x = x.unsqueeze_(1)
  230. out = F.relu(self.bn1(self.conv1(x)))
  231. out1 = self.layer1(out)
  232. # bottom-up fusion
  233. out2 = self.layer2(out1)
  234. out1_downsample = self.layer1_downsample(out1)
  235. fuse_out12 = self.fuse_mode12(out2, out1_downsample)
  236. out3 = self.layer3(out2)
  237. fuse_out12_downsample = self.layer2_downsample(fuse_out12)
  238. fuse_out123 = self.fuse_mode123(out3, fuse_out12_downsample)
  239. out4 = self.layer4(out3)
  240. fuse_out123_downsample = self.layer3_downsample(fuse_out123)
  241. fuse_out1234 = self.fuse_mode1234(out4, fuse_out123_downsample)
  242. stats = self.pool(fuse_out1234)
  243. embed_a = self.seg_1(stats)
  244. if self.two_emb_layer:
  245. out = F.relu(embed_a)
  246. out = self.seg_bn_1(out)
  247. embed_b = self.seg_2(out)
  248. return embed_b
  249. else:
  250. return embed_a
  251. @MODELS.register_module(
  252. Tasks.speaker_verification, module_name=Models.eres2net_sv)
  253. class SpeakerVerificationERes2Net(TorchModel):
  254. r"""Enhanced Res2Net architecture with local and global feature fusion. ERes2Net is mainly composed
  255. of LFF and GFF. The LFF extracts localization-preserved speaker features and strengthen the local information
  256. interaction. GFF fuses multi-scale feature maps in bottom-up pathway to obtain global information.
  257. Args:
  258. model_dir: A model dir.
  259. model_config: The model config.
  260. """
  261. def __init__(self, model_dir, model_config: Dict[str, Any], *args,
  262. **kwargs):
  263. super().__init__(model_dir, model_config, *args, **kwargs)
  264. self.model_config = model_config
  265. self.embed_dim = self.model_config['embed_dim']
  266. self.m_channels = self.model_config['channels']
  267. self.other_config = kwargs
  268. self.feature_dim = 80
  269. self.device = create_device(self.other_config['device'])
  270. self.embedding_model = ERes2Net(
  271. embed_dim=self.embed_dim, m_channels=self.m_channels)
  272. pretrained_model_name = kwargs['pretrained_model']
  273. self.__load_check_point(pretrained_model_name)
  274. self.embedding_model.to(self.device)
  275. self.embedding_model.eval()
  276. def forward(self, audio):
  277. if isinstance(audio, np.ndarray):
  278. audio = torch.from_numpy(audio)
  279. if len(audio.shape) == 1:
  280. audio = audio.unsqueeze(0)
  281. assert len(
  282. audio.shape
  283. ) == 2, 'modelscope error: the shape of input audio to model needs to be [N, T]'
  284. # audio shape: [N, T]
  285. feature = self.__extract_feature(audio)
  286. embedding = self.embedding_model(feature.to(self.device))
  287. return embedding.detach().cpu()
  288. def __extract_feature(self, audio):
  289. feature = Kaldi.fbank(audio, num_mel_bins=self.feature_dim)
  290. feature = feature - feature.mean(dim=0, keepdim=True)
  291. feature = feature.unsqueeze(0)
  292. return feature
  293. def __load_check_point(self, pretrained_model_name, device=None):
  294. if not device:
  295. device = torch.device('cpu')
  296. self.embedding_model.load_state_dict(
  297. torch.load(
  298. os.path.join(self.model_dir, pretrained_model_name),
  299. map_location=device),
  300. strict=True)