rdino.py 16 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573
  1. # Copyright (c) Alibaba, Inc. and its affiliates.
  2. """ This ECAPA-TDNN implementation is adapted from https://github.com/speechbrain/speechbrain.
  3. RDINOHead implementation is adapted from DINO framework.
  4. """
  5. import math
  6. import os
  7. from typing import Any, Dict, Union
  8. import torch
  9. import torch.nn as nn
  10. import torch.nn.functional as F
  11. import torchaudio.compliance.kaldi as Kaldi
  12. from modelscope.metainfo import Models
  13. from modelscope.models import MODELS, TorchModel
  14. from modelscope.utils.constant import Tasks
  15. def length_to_mask(length, max_len=None, dtype=None, device=None):
  16. assert len(length.shape) == 1
  17. if max_len is None:
  18. max_len = length.max().long().item()
  19. mask = torch.arange(
  20. max_len, device=length.device, dtype=length.dtype).expand(
  21. len(length), max_len) < length.unsqueeze(1)
  22. if dtype is None:
  23. dtype = length.dtype
  24. if device is None:
  25. device = length.device
  26. mask = torch.as_tensor(mask, dtype=dtype, device=device)
  27. return mask
  28. def get_padding_elem(L_in: int, stride: int, kernel_size: int, dilation: int):
  29. if stride > 1:
  30. n_steps = math.ceil(((L_in - kernel_size * dilation) / stride) + 1)
  31. L_out = stride * (n_steps - 1) + kernel_size * dilation
  32. padding = [kernel_size // 2, kernel_size // 2]
  33. else:
  34. L_out = (L_in - dilation * (kernel_size - 1) - 1) // stride + 1
  35. padding = [(L_in - L_out) // 2, (L_in - L_out) // 2]
  36. return padding
  37. class Conv1d(nn.Module):
  38. def __init__(
  39. self,
  40. out_channels,
  41. kernel_size,
  42. in_channels,
  43. stride=1,
  44. dilation=1,
  45. padding='same',
  46. groups=1,
  47. bias=True,
  48. padding_mode='reflect',
  49. ):
  50. super().__init__()
  51. self.kernel_size = kernel_size
  52. self.stride = stride
  53. self.dilation = dilation
  54. self.padding = padding
  55. self.padding_mode = padding_mode
  56. self.conv = nn.Conv1d(
  57. in_channels,
  58. out_channels,
  59. self.kernel_size,
  60. stride=self.stride,
  61. dilation=self.dilation,
  62. padding=0,
  63. groups=groups,
  64. bias=bias,
  65. )
  66. def forward(self, x):
  67. if self.padding == 'same':
  68. x = self._manage_padding(x, self.kernel_size, self.dilation,
  69. self.stride)
  70. elif self.padding == 'causal':
  71. num_pad = (self.kernel_size - 1) * self.dilation
  72. x = F.pad(x, (num_pad, 0))
  73. elif self.padding == 'valid':
  74. pass
  75. else:
  76. raise ValueError(
  77. "Padding must be 'same', 'valid' or 'causal'. Got "
  78. + self.padding)
  79. wx = self.conv(x)
  80. return wx
  81. def _manage_padding(
  82. self,
  83. x,
  84. kernel_size: int,
  85. dilation: int,
  86. stride: int,
  87. ):
  88. L_in = x.shape[-1]
  89. padding = get_padding_elem(L_in, stride, kernel_size, dilation)
  90. x = F.pad(x, padding, mode=self.padding_mode)
  91. return x
  92. class BatchNorm1d(nn.Module):
  93. def __init__(
  94. self,
  95. input_size,
  96. eps=1e-05,
  97. momentum=0.1,
  98. ):
  99. super().__init__()
  100. self.norm = nn.BatchNorm1d(
  101. input_size,
  102. eps=eps,
  103. momentum=momentum,
  104. )
  105. def forward(self, x):
  106. return self.norm(x)
  107. class TDNNBlock(nn.Module):
  108. def __init__(
  109. self,
  110. in_channels,
  111. out_channels,
  112. kernel_size,
  113. dilation,
  114. activation=nn.ReLU,
  115. groups=1,
  116. ):
  117. super(TDNNBlock, self).__init__()
  118. self.conv = Conv1d(
  119. in_channels=in_channels,
  120. out_channels=out_channels,
  121. kernel_size=kernel_size,
  122. dilation=dilation,
  123. groups=groups,
  124. )
  125. self.activation = activation()
  126. self.norm = BatchNorm1d(input_size=out_channels)
  127. def forward(self, x):
  128. return self.norm(self.activation(self.conv(x)))
  129. class Res2NetBlock(torch.nn.Module):
  130. def __init__(self,
  131. in_channels,
  132. out_channels,
  133. scale=8,
  134. kernel_size=3,
  135. dilation=1):
  136. super(Res2NetBlock, self).__init__()
  137. assert in_channels % scale == 0
  138. assert out_channels % scale == 0
  139. in_channel = in_channels // scale
  140. hidden_channel = out_channels // scale
  141. self.blocks = nn.ModuleList([
  142. TDNNBlock(
  143. in_channel,
  144. hidden_channel,
  145. kernel_size=kernel_size,
  146. dilation=dilation,
  147. ) for i in range(scale - 1)
  148. ])
  149. self.scale = scale
  150. def forward(self, x):
  151. y = []
  152. for i, x_i in enumerate(torch.chunk(x, self.scale, dim=1)):
  153. if i == 0:
  154. y_i = x_i
  155. elif i == 1:
  156. y_i = self.blocks[i - 1](x_i)
  157. else:
  158. y_i = self.blocks[i - 1](x_i + y_i)
  159. y.append(y_i)
  160. y = torch.cat(y, dim=1)
  161. return y
  162. class SEBlock(nn.Module):
  163. def __init__(self, in_channels, se_channels, out_channels):
  164. super(SEBlock, self).__init__()
  165. self.conv1 = Conv1d(
  166. in_channels=in_channels, out_channels=se_channels, kernel_size=1)
  167. self.relu = torch.nn.ReLU(inplace=True)
  168. self.conv2 = Conv1d(
  169. in_channels=se_channels, out_channels=out_channels, kernel_size=1)
  170. self.sigmoid = torch.nn.Sigmoid()
  171. def forward(self, x, lengths=None):
  172. L = x.shape[-1]
  173. if lengths is not None:
  174. mask = length_to_mask(lengths * L, max_len=L, device=x.device)
  175. mask = mask.unsqueeze(1)
  176. total = mask.sum(dim=2, keepdim=True)
  177. s = (x * mask).sum(dim=2, keepdim=True) / total
  178. else:
  179. s = x.mean(dim=2, keepdim=True)
  180. s = self.relu(self.conv1(s))
  181. s = self.sigmoid(self.conv2(s))
  182. return s * x
  183. class AttentiveStatisticsPooling(nn.Module):
  184. def __init__(self, channels, attention_channels=128, global_context=True):
  185. super().__init__()
  186. self.eps = 1e-12
  187. self.global_context = global_context
  188. if global_context:
  189. self.tdnn = TDNNBlock(channels * 3, attention_channels, 1, 1)
  190. else:
  191. self.tdnn = TDNNBlock(channels, attention_channels, 1, 1)
  192. self.tanh = nn.Tanh()
  193. self.conv = Conv1d(
  194. in_channels=attention_channels,
  195. out_channels=channels,
  196. kernel_size=1)
  197. def forward(self, x, lengths=None):
  198. L = x.shape[-1]
  199. def _compute_statistics(x, m, dim=2, eps=self.eps):
  200. mean = (m * x).sum(dim)
  201. std = torch.sqrt(
  202. (m * (x - mean.unsqueeze(dim)).pow(2)).sum(dim).clamp(eps))
  203. return mean, std
  204. if lengths is None:
  205. lengths = torch.ones(x.shape[0], device=x.device)
  206. # Make binary mask of shape [N, 1, L]
  207. mask = length_to_mask(lengths * L, max_len=L, device=x.device)
  208. mask = mask.unsqueeze(1)
  209. # Expand the temporal context of the pooling layer by allowing the
  210. # self-attention to look at global properties of the utterance.
  211. if self.global_context:
  212. # torch.std is unstable for backward computation
  213. # https://github.com/pytorch/pytorch/issues/4320
  214. total = mask.sum(dim=2, keepdim=True).float()
  215. mean, std = _compute_statistics(x, mask / total)
  216. mean = mean.unsqueeze(2).repeat(1, 1, L)
  217. std = std.unsqueeze(2).repeat(1, 1, L)
  218. attn = torch.cat([x, mean, std], dim=1)
  219. else:
  220. attn = x
  221. # Apply layers
  222. attn = self.conv(self.tanh(self.tdnn(attn)))
  223. # Filter out zero-paddings
  224. attn = attn.masked_fill(mask == 0, float('-inf'))
  225. attn = F.softmax(attn, dim=2)
  226. mean, std = _compute_statistics(x, attn)
  227. # Append mean and std of the batch
  228. pooled_stats = torch.cat((mean, std), dim=1)
  229. pooled_stats = pooled_stats.unsqueeze(2)
  230. return pooled_stats
  231. class SERes2NetBlock(nn.Module):
  232. def __init__(
  233. self,
  234. in_channels,
  235. out_channels,
  236. res2net_scale=8,
  237. se_channels=128,
  238. kernel_size=1,
  239. dilation=1,
  240. activation=torch.nn.ReLU,
  241. groups=1,
  242. ):
  243. super().__init__()
  244. self.out_channels = out_channels
  245. self.tdnn1 = TDNNBlock(
  246. in_channels,
  247. out_channels,
  248. kernel_size=1,
  249. dilation=1,
  250. activation=activation,
  251. groups=groups,
  252. )
  253. self.res2net_block = Res2NetBlock(out_channels, out_channels,
  254. res2net_scale, kernel_size, dilation)
  255. self.tdnn2 = TDNNBlock(
  256. out_channels,
  257. out_channels,
  258. kernel_size=1,
  259. dilation=1,
  260. activation=activation,
  261. groups=groups,
  262. )
  263. self.se_block = SEBlock(out_channels, se_channels, out_channels)
  264. self.shortcut = None
  265. if in_channels != out_channels:
  266. self.shortcut = Conv1d(
  267. in_channels=in_channels,
  268. out_channels=out_channels,
  269. kernel_size=1,
  270. )
  271. def forward(self, x, lengths=None):
  272. residual = x
  273. if self.shortcut:
  274. residual = self.shortcut(x)
  275. x = self.tdnn1(x)
  276. x = self.res2net_block(x)
  277. x = self.tdnn2(x)
  278. x = self.se_block(x, lengths)
  279. return x + residual
  280. class ECAPA_TDNN(nn.Module):
  281. """An implementation of the speaker embedding model in a paper.
  282. "ECAPA-TDNN: Emphasized Channel Attention, Propagation and Aggregation in
  283. TDNN Based Speaker Verification" (https://arxiv.org/abs/2005.07143).
  284. """
  285. def __init__(
  286. self,
  287. input_size,
  288. device='cpu',
  289. lin_neurons=512,
  290. activation=torch.nn.ReLU,
  291. channels=[512, 512, 512, 512, 1536],
  292. kernel_sizes=[5, 3, 3, 3, 1],
  293. dilations=[1, 2, 3, 4, 1],
  294. attention_channels=128,
  295. res2net_scale=8,
  296. se_channels=128,
  297. global_context=True,
  298. groups=[1, 1, 1, 1, 1],
  299. ):
  300. super().__init__()
  301. assert len(channels) == len(kernel_sizes)
  302. assert len(channels) == len(dilations)
  303. self.channels = channels
  304. self.blocks = nn.ModuleList()
  305. # The initial TDNN layer
  306. self.blocks.append(
  307. TDNNBlock(
  308. input_size,
  309. channels[0],
  310. kernel_sizes[0],
  311. dilations[0],
  312. activation,
  313. groups[0],
  314. ))
  315. # SE-Res2Net layers
  316. for i in range(1, len(channels) - 1):
  317. self.blocks.append(
  318. SERes2NetBlock(
  319. channels[i - 1],
  320. channels[i],
  321. res2net_scale=res2net_scale,
  322. se_channels=se_channels,
  323. kernel_size=kernel_sizes[i],
  324. dilation=dilations[i],
  325. activation=activation,
  326. groups=groups[i],
  327. ))
  328. # Multi-layer feature aggregation
  329. self.mfa = TDNNBlock(
  330. channels[-1],
  331. channels[-1],
  332. kernel_sizes[-1],
  333. dilations[-1],
  334. activation,
  335. groups=groups[-1],
  336. )
  337. # Attentive Statistical Pooling
  338. self.asp = AttentiveStatisticsPooling(
  339. channels[-1],
  340. attention_channels=attention_channels,
  341. global_context=global_context,
  342. )
  343. self.asp_bn = BatchNorm1d(input_size=channels[-1] * 2)
  344. # Final linear transformation
  345. self.fc = Conv1d(
  346. in_channels=channels[-1] * 2,
  347. out_channels=lin_neurons,
  348. kernel_size=1,
  349. )
  350. def forward(self, x, lengths=None):
  351. """Returns the embedding vector.
  352. Arguments
  353. ---------
  354. x : torch.Tensor
  355. Tensor of shape (batch, time, channel).
  356. """
  357. x = x.transpose(1, 2)
  358. xl = []
  359. for layer in self.blocks:
  360. try:
  361. x = layer(x, lengths=lengths)
  362. except TypeError:
  363. x = layer(x)
  364. xl.append(x)
  365. # Multi-layer feature aggregation
  366. x = torch.cat(xl[1:], dim=1)
  367. x = self.mfa(x)
  368. # Attentive Statistical Pooling
  369. x = self.asp(x, lengths=lengths)
  370. x = self.asp_bn(x)
  371. # Final linear transformation
  372. x = self.fc(x)
  373. x = x.transpose(1, 2).squeeze(1)
  374. return x
  375. class RDINOHead(nn.Module):
  376. def __init__(self,
  377. in_dim,
  378. out_dim,
  379. use_bn=False,
  380. norm_last_layer=True,
  381. nlayers=3,
  382. hidden_dim=2048,
  383. bottleneck_dim=256,
  384. add_dim=8192):
  385. super().__init__()
  386. nlayers = max(nlayers, 1)
  387. if nlayers == 1:
  388. self.mlp = nn.Linear(in_dim, bottleneck_dim)
  389. else:
  390. layers = [nn.Linear(in_dim, hidden_dim)]
  391. if use_bn:
  392. layers.append(nn.BatchNorm1d(hidden_dim))
  393. layers.append(nn.GELU())
  394. for _ in range(nlayers - 2):
  395. layers.append(nn.Linear(hidden_dim, hidden_dim))
  396. if use_bn:
  397. layers.append(nn.BatchNorm1d(hidden_dim))
  398. layers.append(nn.GELU())
  399. layers.append(nn.Linear(hidden_dim, add_dim))
  400. self.mlp = nn.Sequential(*layers)
  401. self.add_layer = nn.Linear(add_dim, bottleneck_dim)
  402. self.apply(self._init_weights)
  403. self.last_layer = nn.utils.weight_norm(
  404. nn.Linear(bottleneck_dim, out_dim, bias=False))
  405. self.last_layer.weight_g.data.fill_(1)
  406. if norm_last_layer:
  407. self.last_layer.weight_g.requires_grad = False
  408. def _init_weights(self, m):
  409. if isinstance(m, nn.Linear):
  410. torch.nn.init.trunc_normal_(m.weight, std=.02)
  411. if isinstance(m, nn.Linear) and m.bias is not None:
  412. nn.init.constant_(m.bias, 0)
  413. def forward(self, x):
  414. vicr_out = self.mlp(x)
  415. x = self.add_layer(vicr_out)
  416. x = nn.functional.normalize(x, dim=-1, p=2)
  417. x = self.last_layer(x)
  418. return vicr_out, x
  419. class Combine(nn.Module):
  420. def __init__(self, backbone, head):
  421. super(Combine, self).__init__()
  422. self.backbone = backbone
  423. self.head = head
  424. def forward(self, x):
  425. x = self.backbone(x)
  426. output = self.head(x)
  427. return output
  428. @MODELS.register_module(
  429. Tasks.speaker_verification, module_name=Models.rdino_tdnn_sv)
  430. class SpeakerVerification_RDINO(TorchModel):
  431. def __init__(self, model_dir, model_config: Dict[str, Any], *args,
  432. **kwargs):
  433. super().__init__(model_dir, model_config, *args, **kwargs)
  434. self.model_config = model_config
  435. self.other_config = kwargs
  436. if self.model_config['channel'] != 1024:
  437. raise ValueError(
  438. 'modelscope error: Currently only 1024-channel ecapa tdnn is supported.'
  439. )
  440. self.feature_dim = 80
  441. channels_config = [1024, 1024, 1024, 1024, 3072]
  442. self.embedding_model = ECAPA_TDNN(
  443. self.feature_dim, channels=channels_config)
  444. self.embedding_model = Combine(self.embedding_model,
  445. RDINOHead(512, 65536, True))
  446. pretrained_model_name = kwargs['pretrained_model']
  447. self.__load_check_point(pretrained_model_name)
  448. self.embedding_model.eval()
  449. def forward(self, audio):
  450. assert len(audio.shape) == 2 and audio.shape[
  451. 0] == 1, 'modelscope error: the shape of input audio to model needs to be [1, T]'
  452. # audio shape: [1, T]
  453. feature = self.__extract_feature(audio)
  454. embedding = self.embedding_model.backbone(feature)
  455. return embedding
  456. def __extract_feature(self, audio):
  457. feature = Kaldi.fbank(audio, num_mel_bins=self.feature_dim)
  458. feature = feature - feature.mean(dim=0, keepdim=True)
  459. feature = feature.unsqueeze(0)
  460. return feature
  461. def __load_check_point(self, pretrained_model_name, device=None):
  462. if not device:
  463. device = torch.device('cpu')
  464. state_dict = torch.load(
  465. os.path.join(self.model_dir, pretrained_model_name),
  466. map_location=device)
  467. state_dict_tea = {
  468. k.replace('module.', ''): v
  469. for k, v in state_dict['teacher'].items()
  470. }
  471. self.embedding_model.load_state_dict(state_dict_tea, strict=True)