sdpn.py 18 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614
  1. # Copyright (c) Alibaba, Inc. and its affiliates.
  2. """ This ECAPA-TDNN implementation is adapted from https://github.com/speechbrain/speechbrain.
  3. Self-Distillation Prototypes Network(SDPN) is a self-supervised learning framework in SV.
  4. It comprises a teacher and a student network with identical architecture
  5. but different parameters. Teacher/student network consists of three main modules:
  6. the encoder for extracting speaker embeddings, multi-layer perceptron for
  7. feature transformation, and prototypes for computing soft-distributions between
  8. global and local views. EMA denotes Exponential Moving Average.
  9. """
  10. import math
  11. import os
  12. from typing import Any, Dict, Union
  13. import torch
  14. import torch.nn as nn
  15. import torch.nn.functional as F
  16. import torchaudio.compliance.kaldi as Kaldi
  17. from modelscope.metainfo import Models
  18. from modelscope.models import MODELS, TorchModel
  19. from modelscope.utils.constant import Tasks
  20. def length_to_mask(length, max_len=None, dtype=None, device=None):
  21. assert len(length.shape) == 1
  22. if max_len is None:
  23. max_len = length.max().long().item()
  24. mask = torch.arange(
  25. max_len, device=length.device, dtype=length.dtype).expand(
  26. len(length), max_len) < length.unsqueeze(1)
  27. if dtype is None:
  28. dtype = length.dtype
  29. if device is None:
  30. device = length.device
  31. mask = torch.as_tensor(mask, dtype=dtype, device=device)
  32. return mask
  33. def get_padding_elem(L_in: int, stride: int, kernel_size: int, dilation: int):
  34. if stride > 1:
  35. n_steps = math.ceil(((L_in - kernel_size * dilation) / stride) + 1)
  36. L_out = stride * (n_steps - 1) + kernel_size * dilation
  37. padding = [kernel_size // 2, kernel_size // 2]
  38. else:
  39. L_out = (L_in - dilation * (kernel_size - 1) - 1) // stride + 1
  40. padding = [(L_in - L_out) // 2, (L_in - L_out) // 2]
  41. return padding
  42. class Conv1d(nn.Module):
  43. def __init__(
  44. self,
  45. out_channels,
  46. kernel_size,
  47. in_channels,
  48. stride=1,
  49. dilation=1,
  50. padding='same',
  51. groups=1,
  52. bias=True,
  53. padding_mode='reflect',
  54. ):
  55. super().__init__()
  56. self.kernel_size = kernel_size
  57. self.stride = stride
  58. self.dilation = dilation
  59. self.padding = padding
  60. self.padding_mode = padding_mode
  61. self.conv = nn.Conv1d(
  62. in_channels,
  63. out_channels,
  64. self.kernel_size,
  65. stride=self.stride,
  66. dilation=self.dilation,
  67. padding=0,
  68. groups=groups,
  69. bias=bias,
  70. )
  71. def forward(self, x):
  72. if self.padding == 'same':
  73. x = self._manage_padding(x, self.kernel_size, self.dilation,
  74. self.stride)
  75. elif self.padding == 'causal':
  76. num_pad = (self.kernel_size - 1) * self.dilation
  77. x = F.pad(x, (num_pad, 0))
  78. elif self.padding == 'valid':
  79. pass
  80. else:
  81. raise ValueError(
  82. "Padding must be 'same', 'valid' or 'causal'. Got "
  83. + self.padding)
  84. wx = self.conv(x)
  85. return wx
  86. def _manage_padding(
  87. self,
  88. x,
  89. kernel_size: int,
  90. dilation: int,
  91. stride: int,
  92. ):
  93. L_in = x.shape[-1]
  94. padding = get_padding_elem(L_in, stride, kernel_size, dilation)
  95. x = F.pad(x, padding, mode=self.padding_mode)
  96. return x
  97. class BatchNorm1d(nn.Module):
  98. def __init__(
  99. self,
  100. input_size,
  101. eps=1e-05,
  102. momentum=0.1,
  103. ):
  104. super().__init__()
  105. self.norm = nn.BatchNorm1d(
  106. input_size,
  107. eps=eps,
  108. momentum=momentum,
  109. )
  110. def forward(self, x):
  111. return self.norm(x)
  112. class TDNNBlock(nn.Module):
  113. def __init__(
  114. self,
  115. in_channels,
  116. out_channels,
  117. kernel_size,
  118. dilation,
  119. activation=nn.ReLU,
  120. groups=1,
  121. ):
  122. super(TDNNBlock, self).__init__()
  123. self.conv = Conv1d(
  124. in_channels=in_channels,
  125. out_channels=out_channels,
  126. kernel_size=kernel_size,
  127. dilation=dilation,
  128. groups=groups,
  129. )
  130. self.activation = activation()
  131. self.norm = BatchNorm1d(input_size=out_channels)
  132. def forward(self, x):
  133. return self.norm(self.activation(self.conv(x)))
  134. class Res2NetBlock(torch.nn.Module):
  135. def __init__(self,
  136. in_channels,
  137. out_channels,
  138. scale=8,
  139. kernel_size=3,
  140. dilation=1):
  141. super(Res2NetBlock, self).__init__()
  142. assert in_channels % scale == 0
  143. assert out_channels % scale == 0
  144. in_channel = in_channels // scale
  145. hidden_channel = out_channels // scale
  146. self.blocks = nn.ModuleList([
  147. TDNNBlock(
  148. in_channel,
  149. hidden_channel,
  150. kernel_size=kernel_size,
  151. dilation=dilation,
  152. ) for i in range(scale - 1)
  153. ])
  154. self.scale = scale
  155. def forward(self, x):
  156. y = []
  157. for i, x_i in enumerate(torch.chunk(x, self.scale, dim=1)):
  158. if i == 0:
  159. y_i = x_i
  160. elif i == 1:
  161. y_i = self.blocks[i - 1](x_i)
  162. else:
  163. y_i = self.blocks[i - 1](x_i + y_i)
  164. y.append(y_i)
  165. y = torch.cat(y, dim=1)
  166. return y
  167. class SEBlock(nn.Module):
  168. def __init__(self, in_channels, se_channels, out_channels):
  169. super(SEBlock, self).__init__()
  170. self.conv1 = Conv1d(
  171. in_channels=in_channels, out_channels=se_channels, kernel_size=1)
  172. self.relu = torch.nn.ReLU(inplace=True)
  173. self.conv2 = Conv1d(
  174. in_channels=se_channels, out_channels=out_channels, kernel_size=1)
  175. self.sigmoid = torch.nn.Sigmoid()
  176. def forward(self, x, lengths=None):
  177. L = x.shape[-1]
  178. if lengths is not None:
  179. mask = length_to_mask(lengths * L, max_len=L, device=x.device)
  180. mask = mask.unsqueeze(1)
  181. total = mask.sum(dim=2, keepdim=True)
  182. s = (x * mask).sum(dim=2, keepdim=True) / total
  183. else:
  184. s = x.mean(dim=2, keepdim=True)
  185. s = self.relu(self.conv1(s))
  186. s = self.sigmoid(self.conv2(s))
  187. return s * x
  188. class AttentiveStatisticsPooling(nn.Module):
  189. def __init__(self, channels, attention_channels=128, global_context=True):
  190. super().__init__()
  191. self.eps = 1e-12
  192. self.global_context = global_context
  193. if global_context:
  194. self.tdnn = TDNNBlock(channels * 3, attention_channels, 1, 1)
  195. else:
  196. self.tdnn = TDNNBlock(channels, attention_channels, 1, 1)
  197. self.tanh = nn.Tanh()
  198. self.conv = Conv1d(
  199. in_channels=attention_channels,
  200. out_channels=channels,
  201. kernel_size=1)
  202. def forward(self, x, lengths=None):
  203. L = x.shape[-1]
  204. def _compute_statistics(x, m, dim=2, eps=self.eps):
  205. mean = (m * x).sum(dim)
  206. std = torch.sqrt(
  207. (m * (x - mean.unsqueeze(dim)).pow(2)).sum(dim).clamp(eps))
  208. return mean, std
  209. if lengths is None:
  210. lengths = torch.ones(x.shape[0], device=x.device)
  211. # Make binary mask of shape [N, 1, L]
  212. mask = length_to_mask(lengths * L, max_len=L, device=x.device)
  213. mask = mask.unsqueeze(1)
  214. # Expand the temporal context of the pooling layer by allowing the
  215. # self-attention to look at global properties of the utterance.
  216. if self.global_context:
  217. # torch.std is unstable for backward computation
  218. # https://github.com/pytorch/pytorch/issues/4320
  219. total = mask.sum(dim=2, keepdim=True).float()
  220. mean, std = _compute_statistics(x, mask / total)
  221. mean = mean.unsqueeze(2).repeat(1, 1, L)
  222. std = std.unsqueeze(2).repeat(1, 1, L)
  223. attn = torch.cat([x, mean, std], dim=1)
  224. else:
  225. attn = x
  226. # Apply layers
  227. attn = self.conv(self.tanh(self.tdnn(attn)))
  228. # Filter out zero-paddings
  229. attn = attn.masked_fill(mask == 0, float('-inf'))
  230. attn = F.softmax(attn, dim=2)
  231. mean, std = _compute_statistics(x, attn)
  232. # Append mean and std of the batch
  233. pooled_stats = torch.cat((mean, std), dim=1)
  234. pooled_stats = pooled_stats.unsqueeze(2)
  235. return pooled_stats
  236. class SERes2NetBlock(nn.Module):
  237. def __init__(
  238. self,
  239. in_channels,
  240. out_channels,
  241. res2net_scale=8,
  242. se_channels=128,
  243. kernel_size=1,
  244. dilation=1,
  245. activation=torch.nn.ReLU,
  246. groups=1,
  247. ):
  248. super().__init__()
  249. self.out_channels = out_channels
  250. self.tdnn1 = TDNNBlock(
  251. in_channels,
  252. out_channels,
  253. kernel_size=1,
  254. dilation=1,
  255. activation=activation,
  256. groups=groups,
  257. )
  258. self.res2net_block = Res2NetBlock(out_channels, out_channels,
  259. res2net_scale, kernel_size, dilation)
  260. self.tdnn2 = TDNNBlock(
  261. out_channels,
  262. out_channels,
  263. kernel_size=1,
  264. dilation=1,
  265. activation=activation,
  266. groups=groups,
  267. )
  268. self.se_block = SEBlock(out_channels, se_channels, out_channels)
  269. self.shortcut = None
  270. if in_channels != out_channels:
  271. self.shortcut = Conv1d(
  272. in_channels=in_channels,
  273. out_channels=out_channels,
  274. kernel_size=1,
  275. )
  276. def forward(self, x, lengths=None):
  277. residual = x
  278. if self.shortcut:
  279. residual = self.shortcut(x)
  280. x = self.tdnn1(x)
  281. x = self.res2net_block(x)
  282. x = self.tdnn2(x)
  283. x = self.se_block(x, lengths)
  284. return x + residual
  285. class ECAPA_TDNN(nn.Module):
  286. """An implementation of the speaker embedding model in a paper.
  287. "ECAPA-TDNN: Emphasized Channel Attention, Propagation and Aggregation in
  288. TDNN Based Speaker Verification" (https://arxiv.org/abs/2005.07143).
  289. """
  290. def __init__(
  291. self,
  292. input_size,
  293. device='cpu',
  294. lin_neurons=512,
  295. activation=torch.nn.ReLU,
  296. channels=[512, 512, 512, 512, 1536],
  297. kernel_sizes=[5, 3, 3, 3, 1],
  298. dilations=[1, 2, 3, 4, 1],
  299. attention_channels=128,
  300. res2net_scale=8,
  301. se_channels=128,
  302. global_context=True,
  303. groups=[1, 1, 1, 1, 1],
  304. ):
  305. super().__init__()
  306. assert len(channels) == len(kernel_sizes)
  307. assert len(channels) == len(dilations)
  308. self.channels = channels
  309. self.blocks = nn.ModuleList()
  310. # The initial TDNN layer
  311. self.blocks.append(
  312. TDNNBlock(
  313. input_size,
  314. channels[0],
  315. kernel_sizes[0],
  316. dilations[0],
  317. activation,
  318. groups[0],
  319. ))
  320. # SE-Res2Net layers
  321. for i in range(1, len(channels) - 1):
  322. self.blocks.append(
  323. SERes2NetBlock(
  324. channels[i - 1],
  325. channels[i],
  326. res2net_scale=res2net_scale,
  327. se_channels=se_channels,
  328. kernel_size=kernel_sizes[i],
  329. dilation=dilations[i],
  330. activation=activation,
  331. groups=groups[i],
  332. ))
  333. # Multi-layer feature aggregation
  334. self.mfa = TDNNBlock(
  335. channels[-1],
  336. channels[-1],
  337. kernel_sizes[-1],
  338. dilations[-1],
  339. activation,
  340. groups=groups[-1],
  341. )
  342. # Attentive Statistical Pooling
  343. self.asp = AttentiveStatisticsPooling(
  344. channels[-1],
  345. attention_channels=attention_channels,
  346. global_context=global_context,
  347. )
  348. self.asp_bn = BatchNorm1d(input_size=channels[-1] * 2)
  349. # Final linear transformation
  350. self.fc = Conv1d(
  351. in_channels=channels[-1] * 2,
  352. out_channels=lin_neurons,
  353. kernel_size=1,
  354. )
  355. def forward(self, x, lengths=None):
  356. """Returns the embedding vector.
  357. Arguments
  358. ---------
  359. x : torch.Tensor
  360. Tensor of shape (batch, time, channel).
  361. """
  362. x = x.transpose(1, 2)
  363. xl = []
  364. for layer in self.blocks:
  365. try:
  366. x = layer(x, lengths=lengths)
  367. except TypeError:
  368. x = layer(x)
  369. xl.append(x)
  370. # Multi-layer feature aggregation
  371. x = torch.cat(xl[1:], dim=1)
  372. x = self.mfa(x)
  373. # Attentive Statistical Pooling
  374. x = self.asp(x, lengths=lengths)
  375. x = self.asp_bn(x)
  376. # Final linear transformation
  377. x = self.fc(x)
  378. x = x.transpose(1, 2).squeeze(1)
  379. return x
  380. def _no_grad_trunc_normal_(tensor, mean, std, a, b):
  381. def norm_cdf(x):
  382. # Computes standard normal cumulative distribution function
  383. return (1. + math.erf(x / math.sqrt(2.))) / 2.
  384. if (mean < a - 2 * std) or (mean > b + 2 * std):
  385. warnings.warn(
  386. 'mean is more than 2 std from [a, b] in nn.init.trunc_normal_.'
  387. 'The distribution of values may be incorrect.',
  388. stacklevel=2)
  389. with torch.no_grad():
  390. # Values are generated by using a truncated uniform distribution and
  391. # then using the inverse CDF for the normal distribution.
  392. # Get upper and lower cdf values
  393. l_ = norm_cdf((a - mean) / std)
  394. u = norm_cdf((b - mean) / std)
  395. # Uniformly fill tensor with values from [l_, u], then translate to
  396. # [2l-1, 2u-1].
  397. tensor.uniform_(2 * l_ - 1, 2 * u - 1)
  398. # Use inverse cdf transform for normal distribution to get truncated
  399. # standard normal
  400. tensor.erfinv_()
  401. # Transform to proper mean, std
  402. tensor.mul_(std * math.sqrt(2.))
  403. tensor.add_(mean)
  404. # Clamp to ensure it's in the proper range
  405. tensor.clamp_(min=a, max=b)
  406. return tensor
  407. def trunc_normal_(tensor, mean=0., std=1., a=-2., b=2.):
  408. # type: (Tensor, float, float, float, float) -> Tensor
  409. return _no_grad_trunc_normal_(tensor, mean, std, a, b)
  410. class SDPNHead(nn.Module):
  411. def __init__(self,
  412. in_dim,
  413. use_bn=False,
  414. nlayers=3,
  415. hidden_dim=2048,
  416. bottleneck_dim=256):
  417. super().__init__()
  418. nlayers = max(nlayers, 1)
  419. if nlayers == 1:
  420. self.mlp = nn.Linear(in_dim, bottleneck_dim)
  421. else:
  422. layers = [nn.Linear(in_dim, hidden_dim)]
  423. if use_bn:
  424. layers.append(nn.BatchNorm1d(hidden_dim))
  425. layers.append(nn.GELU())
  426. for _ in range(nlayers - 2):
  427. layers.append(nn.Linear(hidden_dim, hidden_dim))
  428. if use_bn:
  429. layers.append(nn.BatchNorm1d(hidden_dim))
  430. layers.append(nn.GELU())
  431. layers.append(nn.Linear(hidden_dim, bottleneck_dim))
  432. self.mlp = nn.Sequential(*layers)
  433. self.apply(self._init_weights)
  434. def _init_weights(self, m):
  435. if isinstance(m, nn.Linear):
  436. trunc_normal_(m.weight, std=.02)
  437. if isinstance(m, nn.Linear) and m.bias is not None:
  438. nn.init.constant_(m.bias, 0)
  439. def forward(self, x):
  440. x = self.mlp(x)
  441. x = nn.functional.normalize(x, dim=-1, p=2)
  442. return x
  443. class Combiner(torch.nn.Module):
  444. """
  445. Combine backbone (ECAPA) and head (MLP)
  446. """
  447. def __init__(self, backbone, head):
  448. super(Combiner, self).__init__()
  449. self.backbone = backbone
  450. self.head = head
  451. def forward(self, x):
  452. x = self.backbone(x)
  453. output = self.head(x)
  454. return x, output
  455. @MODELS.register_module(Tasks.speaker_verification, module_name=Models.sdpn_sv)
  456. class SpeakerVerificationSDPN(TorchModel):
  457. """
  458. Self-Distillation Prototypes Network (SDPN) effectively facilitates
  459. self-supervised speaker representation learning. The specific structure can be
  460. referred to in https://arxiv.org/pdf/2308.02774.
  461. """
  462. def __init__(self, model_dir, model_config: Dict[str, Any], *args,
  463. **kwargs):
  464. super().__init__(model_dir, model_config, *args, **kwargs)
  465. self.model_config = model_config
  466. self.other_config = kwargs
  467. if self.model_config['channel'] != 1024:
  468. raise ValueError(
  469. 'modelscope error: Currently only 1024-channel ecapa tdnn is supported.'
  470. )
  471. self.feature_dim = 80
  472. channels_config = [1024, 1024, 1024, 1024, 3072]
  473. self.embedding_model = ECAPA_TDNN(
  474. self.feature_dim, channels=channels_config)
  475. self.embedding_model = Combiner(self.embedding_model,
  476. SDPNHead(512, True))
  477. pretrained_model_name = kwargs['pretrained_model']
  478. self.__load_check_point(pretrained_model_name)
  479. self.embedding_model.eval()
  480. def forward(self, audio):
  481. assert len(audio.shape) == 2 and audio.shape[
  482. 0] == 1, 'modelscope error: the shape of input audio to model needs to be [1, T]'
  483. # audio shape: [1, T]
  484. feature = self.__extract_feature(audio)
  485. embedding = self.embedding_model.backbone(feature)
  486. return embedding
  487. def __extract_feature(self, audio):
  488. feature = Kaldi.fbank(audio, num_mel_bins=self.feature_dim)
  489. feature = feature - feature.mean(dim=0, keepdim=True)
  490. feature = feature.unsqueeze(0)
  491. return feature
  492. def __load_check_point(self, pretrained_model_name, device=None):
  493. if not device:
  494. device = torch.device('cpu')
  495. state_dict = torch.load(
  496. os.path.join(self.model_dir, pretrained_model_name),
  497. map_location=device)
  498. state_dict_tea = {
  499. k.replace('module.', ''): v
  500. for k, v in state_dict['teacher'].items()
  501. }
  502. self.embedding_model.load_state_dict(state_dict_tea, strict=True)