ecapa_tdnn.py 15 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515
  1. # Copyright (c) Alibaba, Inc. and its affiliates.
  2. """ This ECAPA-TDNN implementation is adapted from https://github.com/speechbrain/speechbrain.
  3. """
  4. import math
  5. import os
  6. from typing import Any, Dict, Union
  7. import numpy as np
  8. import torch
  9. import torch.nn as nn
  10. import torch.nn.functional as F
  11. import torchaudio.compliance.kaldi as Kaldi
  12. from modelscope.metainfo import Models
  13. from modelscope.models import MODELS, TorchModel
  14. from modelscope.utils.constant import Tasks
  15. from modelscope.utils.device import create_device
  16. def length_to_mask(length, max_len=None, dtype=None, device=None):
  17. assert len(length.shape) == 1
  18. if max_len is None:
  19. max_len = length.max().long().item()
  20. mask = torch.arange(
  21. max_len, device=length.device, dtype=length.dtype).expand(
  22. len(length), max_len) < length.unsqueeze(1)
  23. if dtype is None:
  24. dtype = length.dtype
  25. if device is None:
  26. device = length.device
  27. mask = torch.as_tensor(mask, dtype=dtype, device=device)
  28. return mask
  29. def get_padding_elem(L_in: int, stride: int, kernel_size: int, dilation: int):
  30. if stride > 1:
  31. n_steps = math.ceil(((L_in - kernel_size * dilation) / stride) + 1)
  32. L_out = stride * (n_steps - 1) + kernel_size * dilation
  33. padding = [kernel_size // 2, kernel_size // 2]
  34. else:
  35. L_out = (L_in - dilation * (kernel_size - 1) - 1) // stride + 1
  36. padding = [(L_in - L_out) // 2, (L_in - L_out) // 2]
  37. return padding
  38. class Conv1d(nn.Module):
  39. def __init__(
  40. self,
  41. out_channels,
  42. kernel_size,
  43. in_channels,
  44. stride=1,
  45. dilation=1,
  46. padding='same',
  47. groups=1,
  48. bias=True,
  49. padding_mode='reflect',
  50. ):
  51. super().__init__()
  52. self.kernel_size = kernel_size
  53. self.stride = stride
  54. self.dilation = dilation
  55. self.padding = padding
  56. self.padding_mode = padding_mode
  57. self.conv = nn.Conv1d(
  58. in_channels,
  59. out_channels,
  60. self.kernel_size,
  61. stride=self.stride,
  62. dilation=self.dilation,
  63. padding=0,
  64. groups=groups,
  65. bias=bias,
  66. )
  67. def forward(self, x):
  68. if self.padding == 'same':
  69. x = self._manage_padding(x, self.kernel_size, self.dilation,
  70. self.stride)
  71. elif self.padding == 'causal':
  72. num_pad = (self.kernel_size - 1) * self.dilation
  73. x = F.pad(x, (num_pad, 0))
  74. elif self.padding == 'valid':
  75. pass
  76. else:
  77. raise ValueError(
  78. "Padding must be 'same', 'valid' or 'causal'. Got "
  79. + self.padding)
  80. wx = self.conv(x)
  81. return wx
  82. def _manage_padding(
  83. self,
  84. x,
  85. kernel_size: int,
  86. dilation: int,
  87. stride: int,
  88. ):
  89. L_in = x.shape[-1]
  90. padding = get_padding_elem(L_in, stride, kernel_size, dilation)
  91. x = F.pad(x, padding, mode=self.padding_mode)
  92. return x
  93. class BatchNorm1d(nn.Module):
  94. def __init__(
  95. self,
  96. input_size,
  97. eps=1e-05,
  98. momentum=0.1,
  99. ):
  100. super().__init__()
  101. self.norm = nn.BatchNorm1d(
  102. input_size,
  103. eps=eps,
  104. momentum=momentum,
  105. )
  106. def forward(self, x):
  107. return self.norm(x)
  108. class TDNNBlock(nn.Module):
  109. def __init__(
  110. self,
  111. in_channels,
  112. out_channels,
  113. kernel_size,
  114. dilation,
  115. activation=nn.ReLU,
  116. groups=1,
  117. ):
  118. super(TDNNBlock, self).__init__()
  119. self.conv = Conv1d(
  120. in_channels=in_channels,
  121. out_channels=out_channels,
  122. kernel_size=kernel_size,
  123. dilation=dilation,
  124. groups=groups,
  125. )
  126. self.activation = activation()
  127. self.norm = BatchNorm1d(input_size=out_channels)
  128. def forward(self, x):
  129. return self.norm(self.activation(self.conv(x)))
  130. class Res2NetBlock(torch.nn.Module):
  131. def __init__(self,
  132. in_channels,
  133. out_channels,
  134. scale=8,
  135. kernel_size=3,
  136. dilation=1):
  137. super(Res2NetBlock, self).__init__()
  138. assert in_channels % scale == 0
  139. assert out_channels % scale == 0
  140. in_channel = in_channels // scale
  141. hidden_channel = out_channels // scale
  142. self.blocks = nn.ModuleList([
  143. TDNNBlock(
  144. in_channel,
  145. hidden_channel,
  146. kernel_size=kernel_size,
  147. dilation=dilation,
  148. ) for i in range(scale - 1)
  149. ])
  150. self.scale = scale
  151. def forward(self, x):
  152. y = []
  153. for i, x_i in enumerate(torch.chunk(x, self.scale, dim=1)):
  154. if i == 0:
  155. y_i = x_i
  156. elif i == 1:
  157. y_i = self.blocks[i - 1](x_i)
  158. else:
  159. y_i = self.blocks[i - 1](x_i + y_i)
  160. y.append(y_i)
  161. y = torch.cat(y, dim=1)
  162. return y
  163. class SEBlock(nn.Module):
  164. def __init__(self, in_channels, se_channels, out_channels):
  165. super(SEBlock, self).__init__()
  166. self.conv1 = Conv1d(
  167. in_channels=in_channels, out_channels=se_channels, kernel_size=1)
  168. self.relu = torch.nn.ReLU(inplace=True)
  169. self.conv2 = Conv1d(
  170. in_channels=se_channels, out_channels=out_channels, kernel_size=1)
  171. self.sigmoid = torch.nn.Sigmoid()
  172. def forward(self, x, lengths=None):
  173. L = x.shape[-1]
  174. if lengths is not None:
  175. mask = length_to_mask(lengths * L, max_len=L, device=x.device)
  176. mask = mask.unsqueeze(1)
  177. total = mask.sum(dim=2, keepdim=True)
  178. s = (x * mask).sum(dim=2, keepdim=True) / total
  179. else:
  180. s = x.mean(dim=2, keepdim=True)
  181. s = self.relu(self.conv1(s))
  182. s = self.sigmoid(self.conv2(s))
  183. return s * x
  184. class AttentiveStatisticsPooling(nn.Module):
  185. def __init__(self, channels, attention_channels=128, global_context=True):
  186. super().__init__()
  187. self.eps = 1e-12
  188. self.global_context = global_context
  189. if global_context:
  190. self.tdnn = TDNNBlock(channels * 3, attention_channels, 1, 1)
  191. else:
  192. self.tdnn = TDNNBlock(channels, attention_channels, 1, 1)
  193. self.tanh = nn.Tanh()
  194. self.conv = Conv1d(
  195. in_channels=attention_channels,
  196. out_channels=channels,
  197. kernel_size=1)
  198. def forward(self, x, lengths=None):
  199. L = x.shape[-1]
  200. def _compute_statistics(x, m, dim=2, eps=self.eps):
  201. mean = (m * x).sum(dim)
  202. std = torch.sqrt(
  203. (m * (x - mean.unsqueeze(dim)).pow(2)).sum(dim).clamp(eps))
  204. return mean, std
  205. if lengths is None:
  206. lengths = torch.ones(x.shape[0], device=x.device)
  207. # Make binary mask of shape [N, 1, L]
  208. mask = length_to_mask(lengths * L, max_len=L, device=x.device)
  209. mask = mask.unsqueeze(1)
  210. # Expand the temporal context of the pooling layer by allowing the
  211. # self-attention to look at global properties of the utterance.
  212. if self.global_context:
  213. # torch.std is unstable for backward computation
  214. # https://github.com/pytorch/pytorch/issues/4320
  215. total = mask.sum(dim=2, keepdim=True).float()
  216. mean, std = _compute_statistics(x, mask / total)
  217. mean = mean.unsqueeze(2).repeat(1, 1, L)
  218. std = std.unsqueeze(2).repeat(1, 1, L)
  219. attn = torch.cat([x, mean, std], dim=1)
  220. else:
  221. attn = x
  222. # Apply layers
  223. attn = self.conv(self.tanh(self.tdnn(attn)))
  224. # Filter out zero-paddings
  225. attn = attn.masked_fill(mask == 0, float('-inf'))
  226. attn = F.softmax(attn, dim=2)
  227. mean, std = _compute_statistics(x, attn)
  228. # Append mean and std of the batch
  229. pooled_stats = torch.cat((mean, std), dim=1)
  230. pooled_stats = pooled_stats.unsqueeze(2)
  231. return pooled_stats
  232. class SERes2NetBlock(nn.Module):
  233. def __init__(
  234. self,
  235. in_channels,
  236. out_channels,
  237. res2net_scale=8,
  238. se_channels=128,
  239. kernel_size=1,
  240. dilation=1,
  241. activation=torch.nn.ReLU,
  242. groups=1,
  243. ):
  244. super().__init__()
  245. self.out_channels = out_channels
  246. self.tdnn1 = TDNNBlock(
  247. in_channels,
  248. out_channels,
  249. kernel_size=1,
  250. dilation=1,
  251. activation=activation,
  252. groups=groups,
  253. )
  254. self.res2net_block = Res2NetBlock(out_channels, out_channels,
  255. res2net_scale, kernel_size, dilation)
  256. self.tdnn2 = TDNNBlock(
  257. out_channels,
  258. out_channels,
  259. kernel_size=1,
  260. dilation=1,
  261. activation=activation,
  262. groups=groups,
  263. )
  264. self.se_block = SEBlock(out_channels, se_channels, out_channels)
  265. self.shortcut = None
  266. if in_channels != out_channels:
  267. self.shortcut = Conv1d(
  268. in_channels=in_channels,
  269. out_channels=out_channels,
  270. kernel_size=1,
  271. )
  272. def forward(self, x, lengths=None):
  273. residual = x
  274. if self.shortcut:
  275. residual = self.shortcut(x)
  276. x = self.tdnn1(x)
  277. x = self.res2net_block(x)
  278. x = self.tdnn2(x)
  279. x = self.se_block(x, lengths)
  280. return x + residual
  281. class ECAPA_TDNN(nn.Module):
  282. """An implementation of the speaker embedding model in a paper.
  283. "ECAPA-TDNN: Emphasized Channel Attention, Propagation and Aggregation in
  284. TDNN Based Speaker Verification" (https://arxiv.org/abs/2005.07143).
  285. """
  286. def __init__(
  287. self,
  288. input_size,
  289. device='cpu',
  290. lin_neurons=192,
  291. activation=torch.nn.ReLU,
  292. channels=[512, 512, 512, 512, 1536],
  293. kernel_sizes=[5, 3, 3, 3, 1],
  294. dilations=[1, 2, 3, 4, 1],
  295. attention_channels=128,
  296. res2net_scale=8,
  297. se_channels=128,
  298. global_context=True,
  299. groups=[1, 1, 1, 1, 1],
  300. ):
  301. super().__init__()
  302. assert len(channels) == len(kernel_sizes)
  303. assert len(channels) == len(dilations)
  304. self.channels = channels
  305. self.blocks = nn.ModuleList()
  306. # The initial TDNN layer
  307. self.blocks.append(
  308. TDNNBlock(
  309. input_size,
  310. channels[0],
  311. kernel_sizes[0],
  312. dilations[0],
  313. activation,
  314. groups[0],
  315. ))
  316. # SE-Res2Net layers
  317. for i in range(1, len(channels) - 1):
  318. self.blocks.append(
  319. SERes2NetBlock(
  320. channels[i - 1],
  321. channels[i],
  322. res2net_scale=res2net_scale,
  323. se_channels=se_channels,
  324. kernel_size=kernel_sizes[i],
  325. dilation=dilations[i],
  326. activation=activation,
  327. groups=groups[i],
  328. ))
  329. # Multi-layer feature aggregation
  330. self.mfa = TDNNBlock(
  331. channels[-1],
  332. channels[-1],
  333. kernel_sizes[-1],
  334. dilations[-1],
  335. activation,
  336. groups=groups[-1],
  337. )
  338. # Attentive Statistical Pooling
  339. self.asp = AttentiveStatisticsPooling(
  340. channels[-1],
  341. attention_channels=attention_channels,
  342. global_context=global_context,
  343. )
  344. self.asp_bn = BatchNorm1d(input_size=channels[-1] * 2)
  345. # Final linear transformation
  346. self.fc = Conv1d(
  347. in_channels=channels[-1] * 2,
  348. out_channels=lin_neurons,
  349. kernel_size=1,
  350. )
  351. def forward(self, x, lengths=None):
  352. """Returns the embedding vector.
  353. Arguments
  354. ---------
  355. x : torch.Tensor
  356. Tensor of shape (batch, time, channel).
  357. """
  358. x = x.transpose(1, 2)
  359. xl = []
  360. for layer in self.blocks:
  361. try:
  362. x = layer(x, lengths=lengths)
  363. except TypeError:
  364. x = layer(x)
  365. xl.append(x)
  366. # Multi-layer feature aggregation
  367. x = torch.cat(xl[1:], dim=1)
  368. x = self.mfa(x)
  369. # Attentive Statistical Pooling
  370. x = self.asp(x, lengths=lengths)
  371. x = self.asp_bn(x)
  372. # Final linear transformation
  373. x = self.fc(x)
  374. x = x.transpose(1, 2).squeeze(1)
  375. return x
  376. @MODELS.register_module(
  377. Tasks.speaker_verification, module_name=Models.ecapa_tdnn_sv)
  378. class SpeakerVerificationECAPATDNN(TorchModel):
  379. def __init__(self, model_dir, model_config: Dict[str, Any], *args,
  380. **kwargs):
  381. super().__init__(model_dir, model_config, *args, **kwargs)
  382. self.model_config = model_config
  383. self.other_config = kwargs
  384. if self.model_config['channel'] != 1024:
  385. raise ValueError(
  386. 'modelscope error: Currently only 1024-channel ecapa tdnn is supported.'
  387. )
  388. self.feature_dim = 80
  389. channels_config = [1024, 1024, 1024, 1024, 3072]
  390. self.device = create_device(self.other_config['device'])
  391. print(self.device)
  392. self.embedding_model = ECAPA_TDNN(
  393. self.feature_dim, channels=channels_config)
  394. pretrained_model_name = kwargs['pretrained_model']
  395. self.__load_check_point(pretrained_model_name)
  396. self.embedding_model.to(self.device)
  397. self.embedding_model.eval()
  398. def forward(self, audio):
  399. if isinstance(audio, np.ndarray):
  400. audio = torch.from_numpy(audio)
  401. if len(audio.shape) == 1:
  402. audio = audio.unsqueeze(0)
  403. assert len(
  404. audio.shape
  405. ) == 2, 'modelscope error: the shape of input audio to model needs to be [N, T]'
  406. # audio shape: [N, T]
  407. feature = self.__extract_feature(audio)
  408. embedding = self.embedding_model(feature.to(self.device))
  409. return embedding.detach().cpu()
  410. def __extract_feature(self, audio):
  411. features = []
  412. for au in audio:
  413. feature = Kaldi.fbank(
  414. au.unsqueeze(0), num_mel_bins=self.feature_dim)
  415. feature = feature - feature.mean(dim=0, keepdim=True)
  416. features.append(feature.unsqueeze(0))
  417. features = torch.cat(features)
  418. return features
  419. def __load_check_point(self, pretrained_model_name):
  420. self.embedding_model.load_state_dict(
  421. torch.load(
  422. os.path.join(self.model_dir, pretrained_model_name),
  423. map_location=torch.device('cpu')),
  424. strict=True)