# Copyright (c) Alibaba, Inc. and its affiliates.
""" This ECAPA-TDNN implementation is adapted from https://github.com/speechbrain/speechbrain.
    Self-Distillation Prototypes Network(SDPN) is a self-supervised learning framework in SV.
    It comprises a teacher and a student network with identical architecture
    but different parameters. Teacher/student network consists of three main modules:
    the encoder for extracting speaker embeddings, multi-layer perceptron for
    feature transformation, and prototypes for computing soft-distributions between
    global and local views. EMA denotes Exponential Moving Average.
"""
import math
import os
import warnings
from typing import Any, Dict, Union

import torch
import torch.nn as nn
import torch.nn.functional as F
import torchaudio.compliance.kaldi as Kaldi

from modelscope.metainfo import Models
from modelscope.models import MODELS, TorchModel
from modelscope.utils.constant import Tasks


def length_to_mask(length, max_len=None, dtype=None, device=None):
    """Build a 2-D mask from a 1-D tensor of lengths.

    Entry ``[i, j]`` of the result is non-zero exactly when
    ``j < length[i]``.

    Arguments
    ---------
    length : torch.Tensor
        1-D tensor of lengths (one per row of the mask).
    max_len : int, optional
        Number of mask columns. Defaults to ``length.max()``.
    dtype : torch.dtype, optional
        dtype of the returned mask. Defaults to ``length.dtype``.
    device : torch.device, optional
        Device of the returned mask. Defaults to ``length.device``.

    Returns
    -------
    torch.Tensor
        Mask of shape ``(len(length), max_len)``.
    """
    assert len(length.shape) == 1

    if max_len is None:
        max_len = length.max().long().item()

    # Compare a broadcast position index against each row's length.
    positions = torch.arange(
        max_len, device=length.device, dtype=length.dtype).expand(
            len(length), max_len)
    mask = positions < length.unsqueeze(1)

    if dtype is None:
        dtype = length.dtype

    if device is None:
        device = length.device

    mask = torch.as_tensor(mask, dtype=dtype, device=device)
    return mask


def get_padding_elem(L_in: int, stride: int, kernel_size: int, dilation: int):
    """Return the ``[left, right]`` padding used for 'same' convolution.

    Arguments
    ---------
    L_in : int
        Length of the input along the convolved axis.
    stride : int
        Convolution stride.
    kernel_size : int
        Convolution kernel size.
    dilation : int
        Convolution dilation.

    Returns
    -------
    list of int
        Two-element list with the left and right padding amounts.
    """
    if stride > 1:
        # With stride > 1 the padding depends only on the kernel size.
        padding = [kernel_size // 2, kernel_size // 2]
    else:
        # Pad by half of the length that an unpadded convolution would lose.
        L_out = (L_in - dilation * (kernel_size - 1) - 1) // stride + 1
        padding = [(L_in - L_out) // 2, (L_in - L_out) // 2]
    return padding


class Conv1d(nn.Module):
    """1-D convolution that applies its padding manually in ``forward``.

    The wrapped ``nn.Conv1d`` is created with ``padding=0``; the padding
    strategy is selected by ``padding``:

    * ``'same'``   - pad both sides (see ``get_padding_elem``) using
      ``padding_mode``.
    * ``'causal'`` - left-pad ``(kernel_size - 1) * dilation`` samples with
      ``F.pad``'s default mode.
    * ``'valid'``  - no padding.
    """

    def __init__(
        self,
        out_channels,
        kernel_size,
        in_channels,
        stride=1,
        dilation=1,
        padding='same',
        groups=1,
        bias=True,
        padding_mode='reflect',
    ):
        super().__init__()
        self.kernel_size = kernel_size
        self.stride = stride
        self.dilation = dilation
        self.padding = padding
        self.padding_mode = padding_mode

        self.conv = nn.Conv1d(
            in_channels,
            out_channels,
            self.kernel_size,
            stride=self.stride,
            dilation=self.dilation,
            padding=0,
            groups=groups,
            bias=bias,
        )

    def forward(self, x):
        """Pad ``x`` according to ``self.padding`` and convolve it.

        Raises
        ------
        ValueError
            If ``self.padding`` is not 'same', 'valid' or 'causal'.
        """
        if self.padding == 'same':
            x = self._manage_padding(x, self.kernel_size, self.dilation,
                                     self.stride)

        elif self.padding == 'causal':
            num_pad = (self.kernel_size - 1) * self.dilation
            x = F.pad(x, (num_pad, 0))

        elif self.padding == 'valid':
            pass

        else:
            raise ValueError(
                "Padding must be 'same', 'valid' or 'causal'. Got "
                + self.padding)

        wx = self.conv(x)

        return wx

    def _manage_padding(
        self,
        x,
        kernel_size: int,
        dilation: int,
        stride: int,
    ):
        """Pad the last axis of ``x`` for 'same' convolution."""
        L_in = x.shape[-1]
        padding = get_padding_elem(L_in, stride, kernel_size, dilation)
        x = F.pad(x, padding, mode=self.padding_mode)

        return x


class BatchNorm1d(nn.Module):
    """Thin wrapper around ``nn.BatchNorm1d``."""

    def __init__(
        self,
        input_size,
        eps=1e-05,
        momentum=0.1,
    ):
        super().__init__()
        self.norm = nn.BatchNorm1d(
            input_size,
            eps=eps,
            momentum=momentum,
        )

    def forward(self, x):
        """Apply batch normalization to ``x``."""
        return self.norm(x)


class TDNNBlock(nn.Module):
    """Convolution followed by an activation and batch normalization."""

    def __init__(
        self,
        in_channels,
        out_channels,
        kernel_size,
        dilation,
        activation=nn.ReLU,
        groups=1,
    ):
        super(TDNNBlock, self).__init__()
        self.conv = Conv1d(
            in_channels=in_channels,
            out_channels=out_channels,
            kernel_size=kernel_size,
            dilation=dilation,
            groups=groups,
        )
        self.activation = activation()
        self.norm = BatchNorm1d(input_size=out_channels)

    def forward(self, x):
        """Return ``norm(activation(conv(x)))``."""
        return self.norm(self.activation(self.conv(x)))


class Res2NetBlock(torch.nn.Module):
    """Res2Net-style block operating on ``scale`` channel groups.

    The input is split into ``scale`` chunks along the channel axis. The
    first chunk is passed through unchanged; each later chunk goes through
    its own ``TDNNBlock``, and from the third chunk on the previous chunk's
    output is added to the input first.
    """

    def __init__(self,
                 in_channels,
                 out_channels,
                 scale=8,
                 kernel_size=3,
                 dilation=1):
        super(Res2NetBlock, self).__init__()
        assert in_channels % scale == 0
        assert out_channels % scale == 0

        in_channel = in_channels // scale
        hidden_channel = out_channels // scale

        # One TDNN block per chunk except the first (identity) chunk.
        self.blocks = nn.ModuleList([
            TDNNBlock(
                in_channel,
                hidden_channel,
                kernel_size=kernel_size,
                dilation=dilation,
            ) for i in range(scale - 1)
        ])
        self.scale = scale

    def forward(self, x):
        """Process the channel chunks of ``x`` and concatenate the results."""
        y = []
        for i, x_i in enumerate(torch.chunk(x, self.scale, dim=1)):
            if i == 0:
                y_i = x_i
            elif i == 1:
                y_i = self.blocks[i - 1](x_i)
            else:
                # Hierarchical residual: feed the previous chunk's output in.
                y_i = self.blocks[i - 1](x_i + y_i)
            y.append(y_i)
        y = torch.cat(y, dim=1)
        return y


class SEBlock(nn.Module):
    """Squeeze-and-excitation block.

    The input is averaged over its last axis (optionally masked by
    ``lengths``), passed through two kernel-size-1 convolutions with ReLU
    and sigmoid, and the result is used to rescale the input.
    """

    def __init__(self, in_channels, se_channels, out_channels):
        super(SEBlock, self).__init__()

        self.conv1 = Conv1d(
            in_channels=in_channels, out_channels=se_channels, kernel_size=1)
        self.relu = torch.nn.ReLU(inplace=True)
        self.conv2 = Conv1d(
            in_channels=se_channels, out_channels=out_channels, kernel_size=1)
        self.sigmoid = torch.nn.Sigmoid()

    def forward(self, x, lengths=None):
        """Rescale ``x`` by gates computed from its mean over the last axis.

        Arguments
        ---------
        x : torch.Tensor
            Input whose last axis (index 2) is averaged over.
        lengths : torch.Tensor, optional
            1-D tensor that is multiplied by the size of the last axis to
            build the mask, i.e. it is treated as a fraction of the full
            length. When omitted, a plain mean is used.
        """
        L = x.shape[-1]
        if lengths is not None:
            mask = length_to_mask(lengths * L, max_len=L, device=x.device)
            mask = mask.unsqueeze(1)
            total = mask.sum(dim=2, keepdim=True)
            s = (x * mask).sum(dim=2, keepdim=True) / total
        else:
            s = x.mean(dim=2, keepdim=True)

        s = self.relu(self.conv1(s))
        s = self.sigmoid(self.conv2(s))

        return s * x


class AttentiveStatisticsPooling(nn.Module):
    """Attention-weighted mean and standard-deviation pooling.

    With ``global_context=True`` the attention network additionally sees
    the masked mean and standard deviation of the input, repeated along
    the last axis.
    """

    def __init__(self, channels, attention_channels=128, global_context=True):
        super().__init__()

        self.eps = 1e-12
        self.global_context = global_context
        if global_context:
            self.tdnn = TDNNBlock(channels * 3, attention_channels, 1, 1)
        else:
            self.tdnn = TDNNBlock(channels, attention_channels, 1, 1)
        self.tanh = nn.Tanh()
        self.conv = Conv1d(
            in_channels=attention_channels,
            out_channels=channels,
            kernel_size=1)

    def forward(self, x, lengths=None):
        """Pool ``x`` over its last axis.

        Arguments
        ---------
        x : torch.Tensor
            Input that is pooled over axis 2.
        lengths : torch.Tensor, optional
            1-D tensor multiplied by the size of the last axis to build the
            mask. Defaults to all ones (nothing is masked).

        Returns
        -------
        torch.Tensor
            Weighted mean and std concatenated on axis 1, with a trailing
            axis of size 1 added.
        """
        L = x.shape[-1]

        def _compute_statistics(x, m, dim=2, eps=self.eps):
            # Weighted mean/std with weights ``m``; the variance is clamped
            # at ``eps`` before the square root.
            mean = (m * x).sum(dim)
            std = torch.sqrt(
                (m * (x - mean.unsqueeze(dim)).pow(2)).sum(dim).clamp(eps))
            return mean, std

        if lengths is None:
            lengths = torch.ones(x.shape[0], device=x.device)

        # Make binary mask of shape [N, 1, L]
        mask = length_to_mask(lengths * L, max_len=L, device=x.device)
        mask = mask.unsqueeze(1)

        # Expand the temporal context of the pooling layer by allowing the
        # self-attention to look at global properties of the utterance.
        if self.global_context:
            # torch.std is unstable for backward computation
            # https://github.com/pytorch/pytorch/issues/4320
            total = mask.sum(dim=2, keepdim=True).float()
            mean, std = _compute_statistics(x, mask / total)
            mean = mean.unsqueeze(2).repeat(1, 1, L)
            std = std.unsqueeze(2).repeat(1, 1, L)
            attn = torch.cat([x, mean, std], dim=1)
        else:
            attn = x

        # Apply layers
        attn = self.conv(self.tanh(self.tdnn(attn)))

        # Filter out zero-paddings
        attn = attn.masked_fill(mask == 0, float('-inf'))

        attn = F.softmax(attn, dim=2)
        mean, std = _compute_statistics(x, attn)
        # Append mean and std of the batch
        pooled_stats = torch.cat((mean, std), dim=1)
        pooled_stats = pooled_stats.unsqueeze(2)

        return pooled_stats


class SERes2NetBlock(nn.Module):
    """TDNN -> Res2Net -> TDNN -> SE block with a residual connection.

    When ``in_channels != out_channels`` the residual branch goes through
    a kernel-size-1 convolution so the shapes match.
    """

    def __init__(
        self,
        in_channels,
        out_channels,
        res2net_scale=8,
        se_channels=128,
        kernel_size=1,
        dilation=1,
        activation=torch.nn.ReLU,
        groups=1,
    ):
        super().__init__()
        self.out_channels = out_channels
        self.tdnn1 = TDNNBlock(
            in_channels,
            out_channels,
            kernel_size=1,
            dilation=1,
            activation=activation,
            groups=groups,
        )
        self.res2net_block = Res2NetBlock(out_channels, out_channels,
                                          res2net_scale, kernel_size, dilation)
        self.tdnn2 = TDNNBlock(
            out_channels,
            out_channels,
            kernel_size=1,
            dilation=1,
            activation=activation,
            groups=groups,
        )
        self.se_block = SEBlock(out_channels, se_channels, out_channels)

        self.shortcut = None
        if in_channels != out_channels:
            self.shortcut = Conv1d(
                in_channels=in_channels,
                out_channels=out_channels,
                kernel_size=1,
            )

    def forward(self, x, lengths=None):
        """Run the block; ``lengths`` is forwarded to the SE block only."""
        residual = x
        if self.shortcut:
            residual = self.shortcut(x)

        x = self.tdnn1(x)
        x = self.res2net_block(x)
        x = self.tdnn2(x)
        x = self.se_block(x, lengths)

        return x + residual


class ECAPA_TDNN(nn.Module):
    """An implementation of the speaker embedding model in a paper.
    "ECAPA-TDNN: Emphasized Channel Attention, Propagation and Aggregation in
    TDNN Based Speaker Verification" (https://arxiv.org/abs/2005.07143).
    """

    def __init__(
        self,
        input_size,
        device='cpu',
        lin_neurons=512,
        activation=torch.nn.ReLU,
        channels=[512, 512, 512, 512, 1536],
        kernel_sizes=[5, 3, 3, 3, 1],
        dilations=[1, 2, 3, 4, 1],
        attention_channels=128,
        res2net_scale=8,
        se_channels=128,
        global_context=True,
        groups=[1, 1, 1, 1, 1],
    ):

        super().__init__()
        assert len(channels) == len(kernel_sizes)
        assert len(channels) == len(dilations)
        # Store a copy so the instance never aliases the shared default list.
        self.channels = list(channels)
        self.blocks = nn.ModuleList()

        # The initial TDNN layer
        self.blocks.append(
            TDNNBlock(
                input_size,
                channels[0],
                kernel_sizes[0],
                dilations[0],
                activation,
                groups[0],
            ))

        # SE-Res2Net layers
        for i in range(1, len(channels) - 1):
            self.blocks.append(
                SERes2NetBlock(
                    channels[i - 1],
                    channels[i],
                    res2net_scale=res2net_scale,
                    se_channels=se_channels,
                    kernel_size=kernel_sizes[i],
                    dilation=dilations[i],
                    activation=activation,
                    groups=groups[i],
                ))

        # Multi-layer feature aggregation
        self.mfa = TDNNBlock(
            channels[-1],
            channels[-1],
            kernel_sizes[-1],
            dilations[-1],
            activation,
            groups=groups[-1],
        )

        # Attentive Statistical Pooling
        self.asp = AttentiveStatisticsPooling(
            channels[-1],
            attention_channels=attention_channels,
            global_context=global_context,
        )
        self.asp_bn = BatchNorm1d(input_size=channels[-1] * 2)

        # Final linear transformation
        self.fc = Conv1d(
            in_channels=channels[-1] * 2,
            out_channels=lin_neurons,
            kernel_size=1,
        )

    def forward(self, x, lengths=None):
        """Returns the embedding vector.

        Arguments
        ---------
        x : torch.Tensor
            Tensor of shape (batch, time, channel).
        lengths : torch.Tensor, optional
            Forwarded to the SE-Res2Net blocks and the pooling layer for
            masking.
        """
        x = x.transpose(1, 2)

        xl = []
        for layer in self.blocks:
            # Only SERes2NetBlock.forward accepts ``lengths``. Dispatch
            # explicitly instead of catching TypeError, which would also hide
            # genuine TypeErrors raised inside a block and silently retry
            # without masking.
            if isinstance(layer, SERes2NetBlock):
                x = layer(x, lengths=lengths)
            else:
                x = layer(x)
            xl.append(x)

        # Multi-layer feature aggregation
        x = torch.cat(xl[1:], dim=1)
        x = self.mfa(x)

        # Attentive Statistical Pooling
        x = self.asp(x, lengths=lengths)
        x = self.asp_bn(x)

        # Final linear transformation
        x = self.fc(x)

        x = x.transpose(1, 2).squeeze(1)
        return x


def _no_grad_trunc_normal_(tensor, mean, std, a, b):
    """Fill ``tensor`` in place from a normal distribution truncated to [a, b].

    Emits a warning (via the ``warnings`` module) when ``mean`` lies more
    than two standard deviations outside ``[a, b]``. Returns ``tensor``.
    """

    def norm_cdf(x):
        # Computes standard normal cumulative distribution function
        return (1. + math.erf(x / math.sqrt(2.))) / 2.

    if (mean < a - 2 * std) or (mean > b + 2 * std):
        warnings.warn(
            'mean is more than 2 std from [a, b] in nn.init.trunc_normal_.'
            'The distribution of values may be incorrect.',
            stacklevel=2)

    with torch.no_grad():
        # Values are generated by using a truncated uniform distribution and
        # then using the inverse CDF for the normal distribution.
        # Get upper and lower cdf values
        l_ = norm_cdf((a - mean) / std)
        u = norm_cdf((b - mean) / std)

        # Uniformly fill tensor with values from [l_, u], then translate to
        # [2l-1, 2u-1].
        tensor.uniform_(2 * l_ - 1, 2 * u - 1)

        # Use inverse cdf transform for normal distribution to get truncated
        # standard normal
        tensor.erfinv_()

        # Transform to proper mean, std
        tensor.mul_(std * math.sqrt(2.))
        tensor.add_(mean)

        # Clamp to ensure it's in the proper range
        tensor.clamp_(min=a, max=b)
        return tensor


def trunc_normal_(tensor, mean=0., std=1., a=-2., b=2.):
    # type: (Tensor, float, float, float, float) -> Tensor
    """Fill ``tensor`` in place with truncated-normal values and return it."""
    return _no_grad_trunc_normal_(tensor, mean, std, a, b)


class SDPNHead(nn.Module):
    """MLP head whose output is L2-normalized along the last axis.

    With ``nlayers == 1`` the head is a single linear layer; otherwise it
    is a stack of Linear (+ optional BatchNorm1d) + GELU layers ending in
    a linear projection to ``bottleneck_dim``.
    """

    def __init__(self,
                 in_dim,
                 use_bn=False,
                 nlayers=3,
                 hidden_dim=2048,
                 bottleneck_dim=256):
        super().__init__()
        nlayers = max(nlayers, 1)
        if nlayers == 1:
            self.mlp = nn.Linear(in_dim, bottleneck_dim)
        else:
            layers = [nn.Linear(in_dim, hidden_dim)]
            if use_bn:
                layers.append(nn.BatchNorm1d(hidden_dim))
            layers.append(nn.GELU())
            for _ in range(nlayers - 2):
                layers.append(nn.Linear(hidden_dim, hidden_dim))
                if use_bn:
                    layers.append(nn.BatchNorm1d(hidden_dim))
                layers.append(nn.GELU())
            layers.append(nn.Linear(hidden_dim, bottleneck_dim))
            self.mlp = nn.Sequential(*layers)
        self.apply(self._init_weights)

    def _init_weights(self, m):
        """Initialize linear layers: truncated-normal weights, zero bias."""
        if isinstance(m, nn.Linear):
            trunc_normal_(m.weight, std=.02)
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)

    def forward(self, x):
        """Project ``x`` through the MLP and L2-normalize the result."""
        x = self.mlp(x)
        x = nn.functional.normalize(x, dim=-1, p=2)
        return x


class Combiner(torch.nn.Module):
    """
    Combine backbone (ECAPA) and head (MLP)
    """

    def __init__(self, backbone, head):
        super(Combiner, self).__init__()
        self.backbone = backbone
        self.head = head

    def forward(self, x):
        """Return ``(backbone(x), head(backbone(x)))``."""
        x = self.backbone(x)
        output = self.head(x)
        return x, output


@MODELS.register_module(Tasks.speaker_verification, module_name=Models.sdpn_sv)
class SpeakerVerificationSDPN(TorchModel):
    """
    Self-Distillation Prototypes Network (SDPN) effectively facilitates
    self-supervised speaker representation learning. The specific structure can be
    referred to in https://arxiv.org/pdf/2308.02774.
    """

    def __init__(self, model_dir, model_config: Dict[str, Any], *args,
                 **kwargs):
        """Build the model and load the checkpoint named in ``kwargs``.

        Raises
        ------
        ValueError
            If ``model_config['channel']`` is not 1024.
        KeyError
            If ``kwargs`` has no ``'pretrained_model'`` entry.
        """
        super().__init__(model_dir, model_config, *args, **kwargs)
        self.model_config = model_config
        self.other_config = kwargs
        if self.model_config['channel'] != 1024:
            raise ValueError(
                'modelscope error: Currently only 1024-channel ecapa tdnn is supported.'
            )

        self.feature_dim = 80
        channels_config = [1024, 1024, 1024, 1024, 3072]

        self.embedding_model = ECAPA_TDNN(
            self.feature_dim, channels=channels_config)
        # Wrap the backbone with the head so the checkpoint keys line up.
        self.embedding_model = Combiner(self.embedding_model,
                                        SDPNHead(512, True))

        pretrained_model_name = kwargs['pretrained_model']
        self.__load_check_point(pretrained_model_name)

        self.embedding_model.eval()

    def forward(self, audio):
        """Return the backbone embedding for a single waveform of shape [1, T]."""
        assert len(audio.shape) == 2 and audio.shape[
            0] == 1, 'modelscope error: the shape of input audio to model needs to be [1, T]'
        # audio shape: [1, T]
        feature = self.__extract_feature(audio)
        # Only the backbone is used here; the head output is not computed.
        embedding = self.embedding_model.backbone(feature)

        return embedding

    def __extract_feature(self, audio):
        """Compute mean-normalized filterbank features with a batch axis."""
        feature = Kaldi.fbank(audio, num_mel_bins=self.feature_dim)
        # Subtract the mean over axis 0 from every row.
        feature = feature - feature.mean(dim=0, keepdim=True)
        feature = feature.unsqueeze(0)
        return feature

    def __load_check_point(self, pretrained_model_name, device=None):
        """Load the 'teacher' weights from the checkpoint in ``model_dir``."""
        if not device:
            device = torch.device('cpu')
        # NOTE(security): torch.load unpickles the file; only load
        # checkpoints from a trusted source.
        state_dict = torch.load(
            os.path.join(self.model_dir, pretrained_model_name),
            map_location=device)
        # Drop every 'module.' occurrence from the parameter names.
        state_dict_tea = {
            k.replace('module.', ''): v
            for k, v in state_dict['teacher'].items()
        }
        self.embedding_model.load_state_dict(state_dict_tea, strict=True)