# Copyright (c) Alibaba, Inc. and its affiliates. import copy import os from typing import Any, Dict import torch import torch.nn as nn import torch.nn.functional as F from modelscope.metainfo import Models from modelscope.models import MODELS, TorchModel from modelscope.models.audio.separation.mossformer_block import ( MossFormerModule, ScaledSinuEmbedding) from modelscope.models.audio.separation.mossformer_conv_module import ( CumulativeLayerNorm, GlobalLayerNorm) from modelscope.models.base import Tensor from modelscope.utils.constant import Tasks EPS = 1e-8 @MODELS.register_module( Tasks.speech_separation, module_name=Models.speech_mossformer_separation_temporal_8k) class MossFormer(TorchModel): """Library to support MossFormer speech separation. Args: model_dir (str): the model path. """ def __init__(self, model_dir: str, *args, **kwargs): super().__init__(model_dir, *args, **kwargs) self.encoder = Encoder( kernel_size=kwargs['kernel_size'], out_channels=kwargs['out_channels']) self.decoder = Decoder( in_channels=kwargs['in_channels'], out_channels=1, kernel_size=kwargs['kernel_size'], stride=kwargs['stride'], bias=kwargs['bias']) self.mask_net = MossFormerMaskNet( kwargs['in_channels'], kwargs['out_channels'], MossFormerM(kwargs['num_blocks'], kwargs['d_model'], kwargs['attn_dropout'], kwargs['group_size'], kwargs['query_key_dim'], kwargs['expansion_factor'], kwargs['causal']), norm=kwargs['norm'], num_spks=kwargs['num_spks']) self.num_spks = kwargs['num_spks'] def forward(self, inputs: Tensor) -> Dict[str, Any]: # Separation mix_w = self.encoder(inputs) est_mask = self.mask_net(mix_w) mix_w = torch.stack([mix_w] * self.num_spks) sep_h = mix_w * est_mask # Decoding est_source = torch.cat( [ self.decoder(sep_h[i]).unsqueeze(-1) for i in range(self.num_spks) ], dim=-1, ) # T changed after conv1d in encoder, fix it here t_origin = inputs.size(1) t_est = est_source.size(1) if t_origin > t_est: est_source = F.pad(est_source, (0, 0, 0, t_origin - t_est)) else: est_source = est_source[:, :t_origin, :] return est_source def load_check_point(self, load_path=None, device=None): if not load_path: load_path = self.model_dir if not device: device = torch.device('cpu') self.encoder.load_state_dict( torch.load( os.path.join(load_path, 'encoder.bin'), map_location=device), strict=True) self.decoder.load_state_dict( torch.load( os.path.join(load_path, 'decoder.bin'), map_location=device), strict=True) self.mask_net.load_state_dict( torch.load( os.path.join(load_path, 'masknet.bin'), map_location=device), strict=True) def as_dict(self): return dict( encoder=self.encoder, decoder=self.decoder, masknet=self.mask_net) def select_norm(norm, dim, shape): """Just a wrapper to select the normalization type. """ if norm == 'gln': return GlobalLayerNorm(dim, shape, elementwise_affine=True) if norm == 'cln': return CumulativeLayerNorm(dim, elementwise_affine=True) if norm == 'ln': return nn.GroupNorm(1, dim, eps=1e-8) else: return nn.BatchNorm1d(dim) class Encoder(nn.Module): """Convolutional Encoder Layer. Args: kernel_size: Length of filters. in_channels: Number of input channels. out_channels: Number of output channels. Examples: >>> x = torch.randn(2, 1000) >>> encoder = Encoder(kernel_size=4, out_channels=64) >>> h = encoder(x) >>> h.shape # torch.Size([2, 64, 499]) """ def __init__(self, kernel_size: int = 2, out_channels: int = 64, in_channels: int = 1): super(Encoder, self).__init__() self.conv1d = nn.Conv1d( in_channels=in_channels, out_channels=out_channels, kernel_size=kernel_size, stride=kernel_size // 2, groups=1, bias=False, ) self.in_channels = in_channels def forward(self, x: torch.Tensor): """Return the encoded output. Args: x: Input tensor with dimensionality [B, L]. Returns: Encoded tensor with dimensionality [B, N, T_out]. where B = Batchsize L = Number of timepoints N = Number of filters T_out = Number of timepoints at the output of the encoder """ # B x L -> B x 1 x L if self.in_channels == 1: x = torch.unsqueeze(x, dim=1) # B x 1 x L -> B x N x T_out x = self.conv1d(x) x = F.relu(x) return x class Decoder(nn.ConvTranspose1d): """A decoder layer that consists of ConvTranspose1d. Args: kernel_size: Length of filters. in_channels: Number of input channels. out_channels: Number of output channels. Example --------- >>> x = torch.randn(2, 100, 1000) >>> decoder = Decoder(kernel_size=4, in_channels=100, out_channels=1) >>> h = decoder(x) >>> h.shape torch.Size([2, 1003]) """ def __init__(self, *args, **kwargs): super(Decoder, self).__init__(*args, **kwargs) def forward(self, x): """Return the decoded output. Args: x: Input tensor with dimensionality [B, N, L]. where, B = Batchsize, N = number of filters L = time points """ if x.dim() not in [2, 3]: raise RuntimeError('{} accept 3/4D tensor as input'.format( self.__name__)) x = super().forward(x if x.dim() == 3 else torch.unsqueeze(x, 1)) if torch.squeeze(x).dim() == 1: x = torch.squeeze(x, dim=1) else: x = torch.squeeze(x) return x class IdentityBlock: """This block is used when we want to have identity transformation within the Dual_path block. Example ------- >>> x = torch.randn(10, 100) >>> IB = IdentityBlock() >>> xhat = IB(x) """ def _init__(self, **kwargs): pass def __call__(self, x): return x class MossFormerM(nn.Module): """This class implements the transformer encoder. Args: num_blocks : int Number of mossformer blocks to include. d_model : int The dimension of the input embedding. attn_dropout : float Dropout for the self-attention (Optional). group_size: int the chunk size query_key_dim: int the attention vector dimension expansion_factor: int the expansion factor for the linear projection in conv module causal: bool true for causal / false for non causal Example ------- >>> import torch >>> x = torch.rand((8, 60, 512)) #B, S, N >>> net = MossFormerM(num_blocks=8, d_model=512) >>> output, _ = net(x) >>> output.shape torch.Size([8, 60, 512]) """ def __init__(self, num_blocks, d_model=None, attn_dropout=0.1, group_size=256, query_key_dim=128, expansion_factor=4., causal=False): super().__init__() self.mossformerM = MossFormerModule( dim=d_model, depth=num_blocks, group_size=group_size, query_key_dim=query_key_dim, expansion_factor=expansion_factor, causal=causal, attn_dropout=attn_dropout) import speechbrain as sb self.norm = sb.nnet.normalization.LayerNorm(d_model, eps=1e-6) def forward(self, src: torch.Tensor): """ Args: src: Tensor shape [B, S, N], where, B = Batchsize, S = time points N = number of filters The sequence to the encoder layer (required). """ output = self.mossformerM(src) output = self.norm(output) return output class ComputeAttention(nn.Module): """Computation block for dual-path processing. Args: att_mdl : torch.nn.module Model to process within the chunks. out_channels : int Dimensionality of attention model. norm : str Normalization type. skip_connection : bool Skip connection around the attention module. Example --------- >>> att_block = MossFormerM(num_blocks=8, d_model=512) >>> comp_att = ComputeAttention(att_block, 512) >>> x = torch.randn(10, 64, 512) >>> x = comp_att(x) >>> x.shape torch.Size([10, 64, 512]) """ def __init__( self, att_mdl, out_channels, norm='ln', skip_connection=True, ): super(ComputeAttention, self).__init__() self.att_mdl = att_mdl self.skip_connection = skip_connection # Norm self.norm = norm if norm is not None: self.att_norm = select_norm(norm, out_channels, 3) def forward(self, x: torch.Tensor): """Returns the output tensor. Args: x: Input tensor of dimension [B, S, N]. Returns: out: Output tensor of dimension [B, S, N]. where, B = Batchsize, N = number of filters S = time points """ # [B, S, N] att_out = x.permute(0, 2, 1).contiguous() att_out = self.att_mdl(att_out) # [B, N, S] att_out = att_out.permute(0, 2, 1).contiguous() if self.norm is not None: att_out = self.att_norm(att_out) # [B, N, S] if self.skip_connection: att_out = att_out + x out = att_out return out class MossFormerMaskNet(nn.Module): """The dual path model which is the basis for dualpathrnn, sepformer, dptnet. Args: in_channels : int Number of channels at the output of the encoder. out_channels : int Number of channels that would be inputted to the intra and inter blocks. att_model : torch.nn.module Attention model to process the input sequence. norm : str Normalization type. num_spks : int Number of sources (speakers). skip_connection : bool Skip connection around attention module. use_global_pos_enc : bool Global positional encodings. Example --------- >>> mossformer_block = MossFormerM(num_blocks=8, d_model=512) >>> mossformer_masknet = MossFormerMaskNet(64, 64, att_model, num_spks=2) >>> x = torch.randn(10, 64, 2000) >>> x = mossformer_masknet(x) >>> x.shape torch.Size([2, 10, 64, 2000]) """ def __init__( self, in_channels, out_channels, att_model, norm='ln', num_spks=2, skip_connection=True, use_global_pos_enc=True, ): super(MossFormerMaskNet, self).__init__() self.num_spks = num_spks self.norm = select_norm(norm, in_channels, 3) self.conv1d_encoder = nn.Conv1d( in_channels, out_channels, 1, bias=False) self.use_global_pos_enc = use_global_pos_enc if self.use_global_pos_enc: self.pos_enc = ScaledSinuEmbedding(out_channels) self.mdl = copy.deepcopy( ComputeAttention( att_model, out_channels, norm, skip_connection=skip_connection, )) self.conv1d_out = nn.Conv1d( out_channels, out_channels * num_spks, kernel_size=1) self.conv1_decoder = nn.Conv1d( out_channels, in_channels, 1, bias=False) self.prelu = nn.PReLU() self.activation = nn.ReLU() # gated output layer self.output = nn.Sequential( nn.Conv1d(out_channels, out_channels, 1), nn.Tanh()) self.output_gate = nn.Sequential( nn.Conv1d(out_channels, out_channels, 1), nn.Sigmoid()) def forward(self, x: torch.Tensor): """Returns the output tensor. Args: x: Input tensor of dimension [B, N, S]. Returns: out: Output tensor of dimension [spks, B, N, S] where, spks = Number of speakers B = Batchsize, N = number of filters S = the number of time frames """ # before each line we indicate the shape after executing the line # [B, N, L] x = self.norm(x) # [B, N, L] x = self.conv1d_encoder(x) if self.use_global_pos_enc: base = x x = x.transpose(1, -1) emb = self.pos_enc(x) emb = emb.transpose(0, -1) x = base + emb # [B, N, S] x = self.mdl(x) x = self.prelu(x) # [B, N*spks, S] x = self.conv1d_out(x) b, _, s = x.shape # [B*spks, N, S] x = x.view(b * self.num_spks, -1, s) # [B*spks, N, S] x = self.output(x) * self.output_gate(x) # [B*spks, N, S] x = self.conv1_decoder(x) # [B, spks, N, S] _, n, L = x.shape x = x.view(b, self.num_spks, n, L) x = self.activation(x) # [spks, B, N, S] x = x.transpose(0, 1) return x