# Copyright (c) Alibaba, Inc. and its affiliates.
"""ECAPA-TDNN speaker-embedding backbone with an RDINO projection head.

This ECAPA-TDNN implementation is adapted from
https://github.com/speechbrain/speechbrain.
The RDINOHead implementation is adapted from the DINO framework.
"""
import math
import os
from typing import Any, Dict, Union

import torch
import torch.nn as nn
import torch.nn.functional as F
import torchaudio.compliance.kaldi as Kaldi

from modelscope.metainfo import Models
from modelscope.models import MODELS, TorchModel
from modelscope.utils.constant import Tasks


def length_to_mask(length, max_len=None, dtype=None, device=None):
    """Build a padding mask from a 1-D tensor of lengths.

    Row ``i`` of the result is 1 at column positions ``j < length[i]`` and
    0 elsewhere.

    Args:
        length: 1-D tensor with one (possibly fractional) length per item.
        max_len: number of mask columns; defaults to ``length.max()``.
        dtype: dtype of the returned mask; defaults to ``length.dtype``.
        device: device of the returned mask; defaults to ``length.device``.

    Returns:
        Tensor of shape ``(len(length), max_len)``.
    """
    assert len(length.shape) == 1

    if max_len is None:
        max_len = length.max().long().item()
    mask = torch.arange(
        max_len, device=length.device, dtype=length.dtype).expand(
            len(length), max_len) < length.unsqueeze(1)

    if dtype is None:
        dtype = length.dtype
    if device is None:
        device = length.device

    mask = torch.as_tensor(mask, dtype=dtype, device=device)
    return mask


def get_padding_elem(L_in: int, stride: int, kernel_size: int, dilation: int):
    """Return the ``[left, right]`` padding used for 'same' convolution.

    For ``stride > 1`` this is simply ``kernel_size // 2`` on each side.
    Otherwise the total padding ``dilation * (kernel_size - 1)`` is split
    evenly, which preserves the input length exactly when that total is
    even (e.g. any odd ``kernel_size``).
    """
    if stride > 1:
        padding = [kernel_size // 2, kernel_size // 2]
    else:
        L_out = (L_in - dilation * (kernel_size - 1) - 1) // stride + 1
        padding = [(L_in - L_out) // 2, (L_in - L_out) // 2]
    return padding


class Conv1d(nn.Module):
    """``nn.Conv1d`` wrapper with explicit 'same' / 'causal' / 'valid' padding.

    The wrapped convolution is built with ``padding=0``; padding is applied
    manually in ``forward`` instead: 'same' pads both sides of the last axis
    using ``padding_mode``, 'causal' zero-pads only on the left, and 'valid'
    applies no padding. Inputs use the ``nn.Conv1d`` layout
    ``(batch, channels, length)``.
    """

    def __init__(
        self,
        out_channels,
        kernel_size,
        in_channels,
        stride=1,
        dilation=1,
        padding='same',
        groups=1,
        bias=True,
        padding_mode='reflect',
    ):
        super().__init__()
        self.kernel_size = kernel_size
        self.stride = stride
        self.dilation = dilation
        self.padding = padding
        self.padding_mode = padding_mode

        self.conv = nn.Conv1d(
            in_channels,
            out_channels,
            self.kernel_size,
            stride=self.stride,
            dilation=self.dilation,
            padding=0,
            groups=groups,
            bias=bias,
        )

    def forward(self, x):
        """Pad ``x`` according to ``self.padding`` and apply the convolution.

        Raises:
            ValueError: if ``self.padding`` is not 'same', 'causal' or 'valid'.
        """
        if self.padding == 'same':
            x = self._manage_padding(x, self.kernel_size, self.dilation,
                                     self.stride)
        elif self.padding == 'causal':
            # Left-only zero padding so no output frame sees future input.
            num_pad = (self.kernel_size - 1) * self.dilation
            x = F.pad(x, (num_pad, 0))
        elif self.padding == 'valid':
            pass
        else:
            # str() so that a non-string padding value still produces the
            # intended ValueError rather than a TypeError from concatenation.
            raise ValueError(
                "Padding must be 'same', 'valid' or 'causal'. Got "
                + str(self.padding))

        wx = self.conv(x)
        return wx

    def _manage_padding(
        self,
        x,
        kernel_size: int,
        dilation: int,
        stride: int,
    ):
        """Pad the last axis of ``x`` for 'same' output using ``padding_mode``."""
        L_in = x.shape[-1]
        padding = get_padding_elem(L_in, stride, kernel_size, dilation)
        x = F.pad(x, padding, mode=self.padding_mode)
        return x


class BatchNorm1d(nn.Module):
    """Thin wrapper around ``nn.BatchNorm1d`` (kept for parameter naming)."""

    def __init__(
        self,
        input_size,
        eps=1e-05,
        momentum=0.1,
    ):
        super().__init__()
        self.norm = nn.BatchNorm1d(
            input_size,
            eps=eps,
            momentum=momentum,
        )

    def forward(self, x):
        return self.norm(x)


class TDNNBlock(nn.Module):
    """Conv1d ('same' padding) -> activation -> BatchNorm1d."""

    def __init__(
        self,
        in_channels,
        out_channels,
        kernel_size,
        dilation,
        activation=nn.ReLU,
        groups=1,
    ):
        super(TDNNBlock, self).__init__()
        self.conv = Conv1d(
            in_channels=in_channels,
            out_channels=out_channels,
            kernel_size=kernel_size,
            dilation=dilation,
            groups=groups,
        )
        self.activation = activation()
        self.norm = BatchNorm1d(input_size=out_channels)

    def forward(self, x):
        return self.norm(self.activation(self.conv(x)))


class Res2NetBlock(torch.nn.Module):
    """Res2Net multi-scale block over the channel axis.

    The input is split into ``scale`` equal channel chunks. Chunk 0 passes
    through unchanged; chunk 1 goes through ``blocks[0]``; every later chunk
    ``i`` is added to the previous chunk's output before going through
    ``blocks[i - 1]``. The outputs are concatenated back along channels.
    """

    def __init__(self,
                 in_channels,
                 out_channels,
                 scale=8,
                 kernel_size=3,
                 dilation=1):
        super(Res2NetBlock, self).__init__()
        assert in_channels % scale == 0
        assert out_channels % scale == 0

        in_channel = in_channels // scale
        hidden_channel = out_channels // scale

        self.blocks = nn.ModuleList([
            TDNNBlock(
                in_channel,
                hidden_channel,
                kernel_size=kernel_size,
                dilation=dilation,
            ) for i in range(scale - 1)
        ])
        self.scale = scale

    def forward(self, x):
        y = []
        for i, x_i in enumerate(torch.chunk(x, self.scale, dim=1)):
            if i == 0:
                y_i = x_i
            elif i == 1:
                y_i = self.blocks[i - 1](x_i)
            else:
                # y_i still holds the previous chunk's output here.
                y_i = self.blocks[i - 1](x_i + y_i)
            y.append(y_i)
        y = torch.cat(y, dim=1)
        return y


class SEBlock(nn.Module):
    """Squeeze-and-excitation channel gate.

    Averages ``x`` over its last axis (only over unpadded positions when
    ``lengths`` is given), passes the result through two 1x1 convolutions
    (ReLU, then sigmoid) and rescales ``x`` channel-wise by that gate.
    """

    def __init__(self, in_channels, se_channels, out_channels):
        super(SEBlock, self).__init__()

        self.conv1 = Conv1d(
            in_channels=in_channels, out_channels=se_channels, kernel_size=1)
        self.relu = torch.nn.ReLU(inplace=True)
        self.conv2 = Conv1d(
            in_channels=se_channels, out_channels=out_channels, kernel_size=1)
        self.sigmoid = torch.nn.Sigmoid()

    def forward(self, x, lengths=None):
        """Gate ``x``; ``lengths`` are relative (multiplied by the last-axis size)."""
        L = x.shape[-1]
        if lengths is not None:
            mask = length_to_mask(lengths * L, max_len=L, device=x.device)
            mask = mask.unsqueeze(1)
            total = mask.sum(dim=2, keepdim=True)
            s = (x * mask).sum(dim=2, keepdim=True) / total
        else:
            s = x.mean(dim=2, keepdim=True)

        s = self.relu(self.conv1(s))
        s = self.sigmoid(self.conv2(s))

        return s * x


class AttentiveStatisticsPooling(nn.Module):
    """Attention-weighted mean and standard deviation over the last axis.

    When ``global_context`` is True, the attention network also sees the
    (masked) utterance-level mean and std broadcast over time. The output is
    ``cat(mean, std)`` along channels with shape ``(batch, 2 * channels, 1)``.
    """

    def __init__(self, channels, attention_channels=128, global_context=True):
        super().__init__()

        self.eps = 1e-12
        self.global_context = global_context
        if global_context:
            self.tdnn = TDNNBlock(channels * 3, attention_channels, 1, 1)
        else:
            self.tdnn = TDNNBlock(channels, attention_channels, 1, 1)
        self.tanh = nn.Tanh()
        self.conv = Conv1d(
            in_channels=attention_channels,
            out_channels=channels,
            kernel_size=1)

    def forward(self, x, lengths=None):
        """Pool ``x``; ``lengths`` are relative (default: all ones, no padding)."""
        L = x.shape[-1]

        def _compute_statistics(x, m, dim=2, eps=self.eps):
            # ``m`` is a weight tensor summing to 1 along ``dim``.
            mean = (m * x).sum(dim)
            std = torch.sqrt(
                (m * (x - mean.unsqueeze(dim)).pow(2)).sum(dim).clamp(eps))
            return mean, std

        if lengths is None:
            lengths = torch.ones(x.shape[0], device=x.device)

        # Make binary mask of shape [N, 1, L]
        mask = length_to_mask(lengths * L, max_len=L, device=x.device)
        mask = mask.unsqueeze(1)

        # Expand the temporal context of the pooling layer by allowing the
        # self-attention to look at global properties of the utterance.
        if self.global_context:
            # torch.std is unstable for backward computation
            # https://github.com/pytorch/pytorch/issues/4320
            total = mask.sum(dim=2, keepdim=True).float()
            mean, std = _compute_statistics(x, mask / total)
            mean = mean.unsqueeze(2).repeat(1, 1, L)
            std = std.unsqueeze(2).repeat(1, 1, L)
            attn = torch.cat([x, mean, std], dim=1)
        else:
            attn = x

        # Apply layers
        attn = self.conv(self.tanh(self.tdnn(attn)))

        # Filter out zero-paddings
        attn = attn.masked_fill(mask == 0, float('-inf'))

        attn = F.softmax(attn, dim=2)
        mean, std = _compute_statistics(x, attn)
        # Append mean and std of the batch
        pooled_stats = torch.cat((mean, std), dim=1)
        pooled_stats = pooled_stats.unsqueeze(2)

        return pooled_stats


class SERes2NetBlock(nn.Module):
    """TDNN -> Res2Net -> TDNN -> SE, with a residual connection.

    A 1x1 ``Conv1d`` shortcut is used when ``in_channels != out_channels``;
    otherwise the input is added back unchanged.
    """

    def __init__(
        self,
        in_channels,
        out_channels,
        res2net_scale=8,
        se_channels=128,
        kernel_size=1,
        dilation=1,
        activation=torch.nn.ReLU,
        groups=1,
    ):
        super().__init__()
        self.out_channels = out_channels
        self.tdnn1 = TDNNBlock(
            in_channels,
            out_channels,
            kernel_size=1,
            dilation=1,
            activation=activation,
            groups=groups,
        )
        self.res2net_block = Res2NetBlock(out_channels, out_channels,
                                          res2net_scale, kernel_size, dilation)
        self.tdnn2 = TDNNBlock(
            out_channels,
            out_channels,
            kernel_size=1,
            dilation=1,
            activation=activation,
            groups=groups,
        )
        self.se_block = SEBlock(out_channels, se_channels, out_channels)

        self.shortcut = None
        if in_channels != out_channels:
            self.shortcut = Conv1d(
                in_channels=in_channels,
                out_channels=out_channels,
                kernel_size=1,
            )

    def forward(self, x, lengths=None):
        residual = x
        if self.shortcut is not None:
            residual = self.shortcut(x)

        x = self.tdnn1(x)
        x = self.res2net_block(x)
        x = self.tdnn2(x)
        x = self.se_block(x, lengths)

        return x + residual


class ECAPA_TDNN(nn.Module):
    """An implementation of the speaker embedding model in a paper.
    "ECAPA-TDNN: Emphasized Channel Attention, Propagation and Aggregation in
    TDNN Based Speaker Verification" (https://arxiv.org/abs/2005.07143).

    Layout: one TDNN block, ``len(channels) - 2`` SE-Res2Net blocks, a
    multi-layer feature aggregation (MFA) TDNN over the concatenated
    SE-Res2Net outputs, attentive statistics pooling, BatchNorm and a final
    1x1 convolution to ``lin_neurons``. The channels of the SE-Res2Net
    outputs (``channels[1:-1]``) must sum to ``channels[-1]`` for the MFA
    layer's input width to match.

    ``device`` is accepted for API compatibility but is not used here.
    """

    def __init__(
        self,
        input_size,
        device='cpu',
        lin_neurons=512,
        activation=torch.nn.ReLU,
        channels=None,
        kernel_sizes=None,
        dilations=None,
        attention_channels=128,
        res2net_scale=8,
        se_channels=128,
        global_context=True,
        groups=None,
    ):
        super().__init__()
        # None sentinels avoid shared mutable list defaults; the effective
        # defaults are the same as before.
        if channels is None:
            channels = [512, 512, 512, 512, 1536]
        if kernel_sizes is None:
            kernel_sizes = [5, 3, 3, 3, 1]
        if dilations is None:
            dilations = [1, 2, 3, 4, 1]
        if groups is None:
            groups = [1, 1, 1, 1, 1]

        assert len(channels) == len(kernel_sizes)
        assert len(channels) == len(dilations)
        self.channels = channels
        self.blocks = nn.ModuleList()

        # The initial TDNN layer
        self.blocks.append(
            TDNNBlock(
                input_size,
                channels[0],
                kernel_sizes[0],
                dilations[0],
                activation,
                groups[0],
            ))

        # SE-Res2Net layers
        for i in range(1, len(channels) - 1):
            self.blocks.append(
                SERes2NetBlock(
                    channels[i - 1],
                    channels[i],
                    res2net_scale=res2net_scale,
                    se_channels=se_channels,
                    kernel_size=kernel_sizes[i],
                    dilation=dilations[i],
                    activation=activation,
                    groups=groups[i],
                ))

        # Multi-layer feature aggregation
        self.mfa = TDNNBlock(
            channels[-1],
            channels[-1],
            kernel_sizes[-1],
            dilations[-1],
            activation,
            groups=groups[-1],
        )

        # Attentive Statistical Pooling
        self.asp = AttentiveStatisticsPooling(
            channels[-1],
            attention_channels=attention_channels,
            global_context=global_context,
        )
        self.asp_bn = BatchNorm1d(input_size=channels[-1] * 2)

        # Final linear transformation
        self.fc = Conv1d(
            in_channels=channels[-1] * 2,
            out_channels=lin_neurons,
            kernel_size=1,
        )

    def forward(self, x, lengths=None):
        """Returns the embedding vector.

        Arguments
        ---------
        x : torch.Tensor
            Tensor of shape (batch, time, channel).
        lengths : torch.Tensor, optional
            Per-utterance lengths relative to the padded time length
            (they are multiplied by the time size inside the masking code).

        Returns
        -------
        torch.Tensor of shape (batch, lin_neurons).
        """
        x = x.transpose(1, 2)

        xl = []
        for layer in self.blocks:
            # Only the SE-Res2Net blocks accept ``lengths``. Dispatching on
            # type (instead of catching TypeError) avoids silently re-running
            # a block without masking when it raises a genuine TypeError.
            if isinstance(layer, SERes2NetBlock):
                x = layer(x, lengths=lengths)
            else:
                x = layer(x)
            xl.append(x)

        # Multi-layer feature aggregation
        x = torch.cat(xl[1:], dim=1)
        x = self.mfa(x)

        # Attentive Statistical Pooling
        x = self.asp(x, lengths=lengths)
        x = self.asp_bn(x)

        # Final linear transformation
        x = self.fc(x)

        # (batch, lin_neurons, 1) -> (batch, lin_neurons)
        x = x.transpose(1, 2).squeeze(1)
        return x


class RDINOHead(nn.Module):
    """DINO-style projection head with an extra ``add_dim`` stage.

    ``mlp`` maps ``in_dim`` -> ``add_dim`` (returned as the first output),
    ``add_layer`` projects that to ``bottleneck_dim``, the result is
    L2-normalised and fed to a weight-normalised, bias-free linear layer
    producing ``out_dim`` outputs. With ``norm_last_layer`` the weight-norm
    gain is frozen at 1.
    """

    def __init__(self,
                 in_dim,
                 out_dim,
                 use_bn=False,
                 norm_last_layer=True,
                 nlayers=3,
                 hidden_dim=2048,
                 bottleneck_dim=256,
                 add_dim=8192):
        super().__init__()
        nlayers = max(nlayers, 1)
        if nlayers == 1:
            # Project straight to add_dim so add_layer (created below for
            # every nlayers) receives the width it expects. Previously this
            # branch produced bottleneck_dim and never created add_layer,
            # so forward() failed.
            self.mlp = nn.Linear(in_dim, add_dim)
        else:
            layers = [nn.Linear(in_dim, hidden_dim)]
            if use_bn:
                layers.append(nn.BatchNorm1d(hidden_dim))
            layers.append(nn.GELU())
            for _ in range(nlayers - 2):
                layers.append(nn.Linear(hidden_dim, hidden_dim))
                if use_bn:
                    layers.append(nn.BatchNorm1d(hidden_dim))
                layers.append(nn.GELU())
            layers.append(nn.Linear(hidden_dim, add_dim))
            self.mlp = nn.Sequential(*layers)
        self.add_layer = nn.Linear(add_dim, bottleneck_dim)

        self.apply(self._init_weights)
        self.last_layer = nn.utils.weight_norm(
            nn.Linear(bottleneck_dim, out_dim, bias=False))
        self.last_layer.weight_g.data.fill_(1)
        if norm_last_layer:
            self.last_layer.weight_g.requires_grad = False

    def _init_weights(self, m):
        """Truncated-normal (std 0.02) weights and zero bias for ``nn.Linear``."""
        if isinstance(m, nn.Linear):
            torch.nn.init.trunc_normal_(m.weight, std=.02)
        if isinstance(m, nn.Linear) and m.bias is not None:
            nn.init.constant_(m.bias, 0)

    def forward(self, x):
        """Return ``(mlp output, last-layer output)``."""
        vicr_out = self.mlp(x)
        x = self.add_layer(vicr_out)
        x = nn.functional.normalize(x, dim=-1, p=2)
        x = self.last_layer(x)
        return vicr_out, x


class Combine(nn.Module):
    """Chain a backbone and a head: ``head(backbone(x))``."""

    def __init__(self, backbone, head):
        super(Combine, self).__init__()
        self.backbone = backbone
        self.head = head

    def forward(self, x):
        x = self.backbone(x)
        output = self.head(x)
        return output


@MODELS.register_module(
    Tasks.speaker_verification, module_name=Models.rdino_tdnn_sv)
class SpeakerVerification_RDINO(TorchModel):
    """RDINO ECAPA-TDNN speaker-embedding model for ModelScope.

    Builds a 1024-channel ECAPA-TDNN backbone plus an RDINOHead, loads the
    ``'teacher'`` weights from ``kwargs['pretrained_model']`` (relative to
    ``model_dir``) and switches to eval mode. ``forward`` uses only the
    backbone.
    """

    def __init__(self, model_dir, model_config: Dict[str, Any], *args,
                 **kwargs):
        super().__init__(model_dir, model_config, *args, **kwargs)
        self.model_config = model_config
        self.other_config = kwargs
        if self.model_config['channel'] != 1024:
            raise ValueError(
                'modelscope error: Currently only 1024-channel ecapa tdnn is supported.'
            )

        self.feature_dim = 80
        channels_config = [1024, 1024, 1024, 1024, 3072]

        self.embedding_model = ECAPA_TDNN(
            self.feature_dim, channels=channels_config)
        # The head is presumably instantiated only so that the strict
        # checkpoint load below matches; inference uses the backbone alone.
        self.embedding_model = Combine(self.embedding_model,
                                       RDINOHead(512, 65536, True))

        pretrained_model_name = kwargs['pretrained_model']
        self.__load_check_point(pretrained_model_name)

        self.embedding_model.eval()

    def forward(self, audio):
        """Return the speaker embedding, shape (1, 512), for a [1, T] waveform."""
        assert len(audio.shape) == 2 and audio.shape[
            0] == 1, 'modelscope error: the shape of input audio to model needs to be [1, T]'
        # audio shape: [1, T]
        feature = self.__extract_feature(audio)
        embedding = self.embedding_model.backbone(feature)

        return embedding

    def __extract_feature(self, audio):
        """Kaldi fbank features, mean-normalised over frames, batch dim added."""
        feature = Kaldi.fbank(audio, num_mel_bins=self.feature_dim)
        feature = feature - feature.mean(dim=0, keepdim=True)
        feature = feature.unsqueeze(0)
        return feature

    def __load_check_point(self, pretrained_model_name, device=None):
        """Load the ``'teacher'`` state dict into ``self.embedding_model``.

        ``'module.'`` prefixes (presumably from a DistributedDataParallel
        wrapper at training time) are stripped before a strict load.
        """
        if not device:
            device = torch.device('cpu')
        # NOTE(review): torch.load unpickles the file; only load checkpoints
        # from trusted sources.
        state_dict = torch.load(
            os.path.join(self.model_dir, pretrained_model_name),
            map_location=device)
        state_dict_tea = {
            k.replace('module.', ''): v
            for k, v in state_dict['teacher'].items()
        }
        self.embedding_model.load_state_dict(state_dict_tea, strict=True)