# Copyright (c) Alibaba, Inc. and its affiliates.

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F


class Conv1d_O(nn.Module):
    """1d convolution over the time axis with explicit padding handling.

    The wrapped ``nn.Conv1d`` is always built with ``padding=0``; the
    requested padding is applied manually in ``forward`` with ``F.pad``.

    Arguments
    ---------
    out_channels : int
        Number of output channels of the convolution.
    kernel_size : int
        Size of the convolution kernel.
    input_shape : tuple
        Shape of the expected input. Only used to infer ``in_channels``
        when ``in_channels`` is not given.
    in_channels : int
        Number of input channels. Either this or ``input_shape`` is required.
    stride : int
        Stride of the convolution.
    dilation : int
        Dilation of the convolution kernel.
    padding : str
        One of 'same', 'valid' or 'causal'. 'causal' left-pads the last
        axis with ``(kernel_size - 1) * dilation`` zeros; 'same' pads both
        sides using ``padding_mode``; 'valid' applies no padding.
    groups : int
        Forwarded to ``nn.Conv1d``.
    bias : bool
        Forwarded to ``nn.Conv1d``.
    padding_mode : str
        ``mode`` argument given to ``F.pad`` for 'same' padding.
    skip_transpose : bool
        If False, dimension 1 and the last dimension are swapped before and
        after the convolution. If True, the input is convolved as given.

    Raises
    ------
    ValueError
        If neither ``input_shape`` nor ``in_channels`` is provided.
    """

    def __init__(
        self,
        out_channels,
        kernel_size,
        input_shape=None,
        in_channels=None,
        stride=1,
        dilation=1,
        padding='same',
        groups=1,
        bias=True,
        padding_mode='reflect',
        skip_transpose=False,
    ):
        super().__init__()
        self.kernel_size = kernel_size
        self.stride = stride
        self.dilation = dilation
        self.padding = padding
        self.padding_mode = padding_mode
        # Set to True by _check_input_shape when a 2d input shape is declared.
        self.unsqueeze = False
        self.skip_transpose = skip_transpose

        if input_shape is None and in_channels is None:
            raise ValueError('Must provide one of input_shape or in_channels')

        if in_channels is None:
            in_channels = self._check_input_shape(input_shape)

        # Padding is handled manually in forward(), hence padding=0 here.
        self.conv = nn.Conv1d(
            in_channels,
            out_channels,
            self.kernel_size,
            stride=self.stride,
            dilation=self.dilation,
            padding=0,
            groups=groups,
            bias=bias,
        )

    def forward(self, x):
        """Returns the output of the convolution.

        Arguments
        ---------
        x : torch.Tensor (batch, time, channel)
            input to convolve. 2d or 3d tensors are expected. When
            ``skip_transpose`` is True the tensor is convolved as given,
            without swapping dimension 1 and the last dimension.

        Raises
        ------
        ValueError
            If ``self.padding`` is not 'same', 'valid' or 'causal'.
        """
        if not self.skip_transpose:
            x = x.transpose(1, -1)

        if self.unsqueeze:
            x = x.unsqueeze(1)

        if self.padding == 'same':
            x = self._manage_padding(x, self.kernel_size, self.dilation,
                                     self.stride)
        elif self.padding == 'causal':
            # Left-pad only, so an output frame never sees future frames.
            num_pad = (self.kernel_size - 1) * self.dilation
            x = F.pad(x, (num_pad, 0))
        elif self.padding == 'valid':
            pass
        else:
            # f-string keeps the message identical for string values while
            # still producing a ValueError for non-string padding values.
            raise ValueError(
                "Padding must be 'same', 'valid' or 'causal'. "
                f'Got {self.padding}')

        wx = self.conv(x)

        if self.unsqueeze:
            wx = wx.squeeze(1)

        if not self.skip_transpose:
            wx = wx.transpose(1, -1)

        return wx

    def _manage_padding(
        self,
        x,
        kernel_size: int,
        dilation: int,
        stride: int,
    ):
        """Pads the last axis of ``x`` for 'same' padding.

        The amounts come from ``get_padding_elem`` and the padding is applied
        with ``F.pad`` using ``self.padding_mode``.
        """
        # Detecting input shape
        L_in = x.shape[-1]

        # Time padding
        padding = get_padding_elem(L_in, stride, kernel_size, dilation)

        # Applying padding
        x = F.pad(x, padding, mode=self.padding_mode)

        return x

    def _check_input_shape(self, shape):
        """Checks the input shape and returns the number of input channels.

        A 2d shape enables ``self.unsqueeze`` and yields one input channel.
        Otherwise the channel count is ``shape[1]`` when ``skip_transpose``
        is set, or ``shape[2]`` for a 3d shape.

        Raises
        ------
        ValueError
            If the shape is not supported or ``self.kernel_size`` is even.
        """
        if len(shape) == 2:
            self.unsqueeze = True
            in_channels = 1
        elif self.skip_transpose:
            in_channels = shape[1]
        elif len(shape) == 3:
            in_channels = shape[2]
        else:
            raise ValueError('conv1d expects 2d, 3d inputs. Got '
                             + str(len(shape)))

        # Kernel size must be odd
        if self.kernel_size % 2 == 0:
            raise ValueError(
                'The field kernel size must be an odd number. Got %s.' %
                (self.kernel_size))
        return in_channels


# Skip transpose as much as possible for efficiency
class Conv1d(Conv1d_O):
    """``Conv1d_O`` with ``skip_transpose`` forced to True."""

    def __init__(self, *args, **kwargs):
        super().__init__(*args, skip_transpose=True, **kwargs)


def get_padding_elem(L_in: int, stride: int, kernel_size: int, dilation: int):
    """This function computes the number of elements to add on each side of
    the time axis for 'same' padding.

    For ``stride > 1`` half of the kernel size is added on both sides. For
    other strides, half of the length lost by an unpadded convolution is
    added on both sides.

    Arguments
    ---------
    L_in : int
        Length of the axis to pad.
    stride: int
        Stride of the convolution.
    kernel_size : int
        Size of the convolution kernel.
    dilation : int
        Dilation of the convolution kernel.

    Returns
    -------
    list of int
        ``[left, right]`` padding amounts.
    """
    if stride > 1:
        # The strided case only depends on the kernel size.
        half_kernel = kernel_size // 2
        padding = [half_kernel, half_kernel]
    else:
        L_out = (L_in - dilation * (kernel_size - 1) - 1) // stride + 1
        half_loss = (L_in - L_out) // 2
        padding = [half_loss, half_loss]
    return padding


class BatchNorm1d_O(nn.Module):
    """Wrapper around ``nn.BatchNorm1d`` with optional transposition.

    Arguments
    ---------
    input_shape : tuple
        Shape of the expected input. Only used to infer ``input_size`` when
        ``input_size`` is not given: ``input_shape[1]`` when
        ``skip_transpose`` is True, ``input_shape[-1]`` otherwise.
    input_size : int
        Number of features given to ``nn.BatchNorm1d``.
    eps : float
        Forwarded to ``nn.BatchNorm1d``.
    momentum : float
        Forwarded to ``nn.BatchNorm1d``.
    affine : bool
        Forwarded to ``nn.BatchNorm1d``.
    track_running_stats : bool
        Forwarded to ``nn.BatchNorm1d``.
    combine_batch_time : bool
        If True, the first two dimensions of the input are merged before
        normalization and the original shape is restored afterwards.
    skip_transpose : bool
        If False (and ``combine_batch_time`` is False), dimension 1 and the
        last dimension are swapped before and after normalization.

    Raises
    ------
    ValueError
        If neither ``input_shape`` nor ``input_size`` is provided.
    """

    def __init__(
        self,
        input_shape=None,
        input_size=None,
        eps=1e-05,
        momentum=0.1,
        affine=True,
        track_running_stats=True,
        combine_batch_time=False,
        skip_transpose=False,
    ):
        super().__init__()
        self.combine_batch_time = combine_batch_time
        self.skip_transpose = skip_transpose

        if input_size is None:
            if input_shape is None:
                # Fail with a clear message instead of an opaque
                # "'NoneType' object is not subscriptable".
                raise ValueError(
                    'Must provide one of input_shape or input_size')
            input_size = input_shape[1] if skip_transpose else input_shape[-1]

        self.norm = nn.BatchNorm1d(
            input_size,
            eps=eps,
            momentum=momentum,
            affine=affine,
            track_running_stats=track_running_stats,
        )

    def forward(self, x):
        """Returns the normalized input tensor.

        Arguments
        ---------
        x : torch.Tensor (batch, time, [channels])
            input to normalize. 2d or 3d tensors are expected in input
            4d tensors can be used when combine_batch_time=True.
        """
        shape_or = x.shape
        if self.combine_batch_time:
            # Merge the first two dimensions into a single batch dimension.
            if x.ndim == 3:
                x = x.reshape(shape_or[0] * shape_or[1], shape_or[2])
            else:
                x = x.reshape(shape_or[0] * shape_or[1], shape_or[3],
                              shape_or[2])
        elif not self.skip_transpose:
            x = x.transpose(-1, 1)

        x_n = self.norm(x)

        if self.combine_batch_time:
            x_n = x_n.reshape(shape_or)
        elif not self.skip_transpose:
            x_n = x_n.transpose(1, -1)

        return x_n


class BatchNorm1d(BatchNorm1d_O):
    """``BatchNorm1d_O`` with ``skip_transpose`` forced to True."""

    def __init__(self, *args, **kwargs):
        super().__init__(*args, skip_transpose=True, **kwargs)


class Xvector(torch.nn.Module):
    """This model stacks the TDNN blocks of an X-vector extractor for
    speaker recognition and diarization.

    Each block is a ``Conv1d``, an activation and a ``BatchNorm1d``. The
    output keeps the time axis and has ``tdnn_channels[-1]`` features per
    frame.

    Arguments
    ---------
    device : str
        Device used e.g. "cpu" or "cuda". Accepted but not used by this
        implementation.
    activation : torch class
        A class for constructing the activation layers.
    tdnn_blocks : int
        Number of time-delay neural (TDNN) layers.
    tdnn_channels : sequence of ints
        Output channels for TDNN layer.
    tdnn_kernel_sizes : sequence of ints
        List of kernel sizes for each TDNN layer.
    tdnn_dilations : sequence of ints
        List of dilations for kernels in each TDNN layer.
    lin_neurons : int
        Number of neurons in linear layers. Accepted but not used by this
        implementation.
    in_channels : int
        Number of features per frame of the input.

    Example
    -------
    >>> compute_xvect = Xvector('cpu', in_channels=40)
    >>> input_feats = torch.rand([5, 10, 40])
    >>> outputs = compute_xvect(input_feats)
    >>> outputs.shape
    torch.Size([5, 10, 1500])
    """

    def __init__(
        self,
        device='cpu',
        activation=torch.nn.LeakyReLU,
        tdnn_blocks=5,
        # Tuples instead of lists: defaults are shared across calls and must
        # not be mutable. They are only indexed, so lists still work.
        tdnn_channels=(512, 512, 512, 512, 1500),
        tdnn_kernel_sizes=(5, 3, 3, 1, 1),
        tdnn_dilations=(1, 2, 3, 1, 1),
        lin_neurons=512,
        in_channels=80,
    ):
        super().__init__()
        self.blocks = nn.ModuleList()

        # TDNN layers
        for block_index in range(tdnn_blocks):
            out_channels = tdnn_channels[block_index]
            self.blocks.extend([
                Conv1d(
                    in_channels=in_channels,
                    out_channels=out_channels,
                    kernel_size=tdnn_kernel_sizes[block_index],
                    dilation=tdnn_dilations[block_index],
                ),
                activation(),
                BatchNorm1d(input_size=out_channels),
            ])
            # The next block consumes what this block produced.
            in_channels = out_channels

    def forward(self, x, lens=None):
        """Returns the x-vectors.

        Arguments
        ---------
        x : torch.Tensor
            Input features; dimensions 1 and 2 are swapped before the
            blocks are applied and swapped back afterwards.
        lens : torch.Tensor
            Optional value offered to each layer as the ``lengths`` keyword.
        """
        x = x.transpose(1, 2)

        for layer in self.blocks:
            # Best effort: layers that do not accept ``lengths`` raise
            # TypeError and are called again with the input only.
            try:
                x = layer(x, lengths=lens)
            except TypeError:
                x = layer(x)

        x = x.transpose(1, 2)
        return x