| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475 |
- # Copyright (c) Alibaba, Inc. and its affiliates.
- import copy
- import os
- from typing import Any, Dict
- import torch
- import torch.nn as nn
- import torch.nn.functional as F
- from modelscope.metainfo import Models
- from modelscope.models import MODELS, TorchModel
- from modelscope.models.audio.separation.mossformer_block import (
- MossFormerModule, ScaledSinuEmbedding)
- from modelscope.models.audio.separation.mossformer_conv_module import (
- CumulativeLayerNorm, GlobalLayerNorm)
- from modelscope.models.base import Tensor
- from modelscope.utils.constant import Tasks
- EPS = 1e-8
@MODELS.register_module(
    Tasks.speech_separation,
    module_name=Models.speech_mossformer_separation_temporal_8k)
class MossFormer(TorchModel):
    """MossFormer speech separation model.

    Chains a convolutional ``Encoder``, a ``MossFormerMaskNet`` that
    estimates one mask per speaker, and a transposed-convolution ``Decoder``.

    Args:
        model_dir (str): the model path.
        **kwargs: configuration values. The following keys are read and must
            be present (a missing key raises ``KeyError``): ``kernel_size``,
            ``out_channels``, ``in_channels``, ``stride``, ``bias``,
            ``num_blocks``, ``d_model``, ``attn_dropout``, ``group_size``,
            ``query_key_dim``, ``expansion_factor``, ``causal``, ``norm``
            and ``num_spks``.
    """

    def __init__(self, model_dir: str, *args, **kwargs):
        super().__init__(model_dir, *args, **kwargs)
        self.encoder = Encoder(
            kernel_size=kwargs['kernel_size'],
            out_channels=kwargs['out_channels'])
        self.decoder = Decoder(
            in_channels=kwargs['in_channels'],
            out_channels=1,
            kernel_size=kwargs['kernel_size'],
            stride=kwargs['stride'],
            bias=kwargs['bias'])
        self.mask_net = MossFormerMaskNet(
            kwargs['in_channels'],
            kwargs['out_channels'],
            MossFormerM(kwargs['num_blocks'], kwargs['d_model'],
                        kwargs['attn_dropout'], kwargs['group_size'],
                        kwargs['query_key_dim'], kwargs['expansion_factor'],
                        kwargs['causal']),
            norm=kwargs['norm'],
            num_spks=kwargs['num_spks'])
        self.num_spks = kwargs['num_spks']

    def forward(self, inputs: Tensor) -> Tensor:
        """Separate a mixture into ``num_spks`` estimated sources.

        Args:
            inputs: mixture waveform; dimension 1 is treated as time
                (the encoder documents the layout as [B, L]).

        Returns:
            A tensor (not a dict) holding the estimated sources stacked on
            the last dimension, padded or cropped so that dimension 1 has
            the same length as ``inputs``. Assuming the decoder yields
            [B, T] per speaker, the result is [B, L, num_spks].
        """
        # Separation: one mask per speaker applied to the encoded mixture.
        mix_w = self.encoder(inputs)
        est_mask = self.mask_net(mix_w)
        mix_w = torch.stack([mix_w] * self.num_spks)
        sep_h = mix_w * est_mask

        # Decoding: every speaker is decoded separately, then concatenated
        # along a new trailing speaker axis.
        est_source = torch.cat(
            [
                self.decoder(sep_h[i]).unsqueeze(-1)
                for i in range(self.num_spks)
            ],
            dim=-1,
        )

        # T changed after conv1d in encoder, fix it here
        t_origin = inputs.size(1)
        t_est = est_source.size(1)
        if t_origin > t_est:
            # Pad only the end of the time axis (second-to-last dimension).
            est_source = F.pad(est_source, (0, 0, 0, t_origin - t_est))
        else:
            est_source = est_source[:, :t_origin, :]
        return est_source

    def load_check_point(self, load_path=None, device=None):
        """Load encoder, decoder and mask-net weights from ``load_path``.

        Args:
            load_path: directory containing ``encoder.bin``, ``decoder.bin``
                and ``masknet.bin``; falls back to ``self.model_dir`` when
                falsy.
            device: ``map_location`` for ``torch.load``; CPU when falsy.
        """
        if not load_path:
            load_path = self.model_dir
        if not device:
            device = torch.device('cpu')
        # NOTE(security): torch.load unpickles the file contents, so only
        # checkpoints from a trusted source should be loaded here.
        checkpoints = (
            (self.encoder, 'encoder.bin'),
            (self.decoder, 'decoder.bin'),
            (self.mask_net, 'masknet.bin'),
        )
        for module, file_name in checkpoints:
            state = torch.load(
                os.path.join(load_path, file_name), map_location=device)
            module.load_state_dict(state, strict=True)

    def as_dict(self):
        """Return the three sub-networks keyed by name."""
        return dict(
            encoder=self.encoder, decoder=self.decoder, masknet=self.mask_net)
def select_norm(norm, dim, shape):
    """Build the normalization layer identified by ``norm``.

    Args:
        norm: ``'gln'`` (GlobalLayerNorm), ``'cln'`` (CumulativeLayerNorm),
            ``'ln'`` (single-group ``nn.GroupNorm``); any other value
            yields ``nn.BatchNorm1d``.
        dim: number of channels handed to the layer.
        shape: forwarded to ``GlobalLayerNorm`` only.

    Returns:
        The freshly constructed normalization module.
    """
    if norm == 'gln':
        layer = GlobalLayerNorm(dim, shape, elementwise_affine=True)
    elif norm == 'cln':
        layer = CumulativeLayerNorm(dim, elementwise_affine=True)
    elif norm == 'ln':
        layer = nn.GroupNorm(1, dim, eps=1e-8)
    else:
        # Unrecognised names fall back to batch normalization.
        layer = nn.BatchNorm1d(dim)
    return layer
class Encoder(nn.Module):
    """Convolutional encoder: a bias-free Conv1d followed by ReLU.

    The convolution uses a stride of ``kernel_size // 2``.

    Args:
        kernel_size: Length of filters.
        out_channels: Number of output channels.
        in_channels: Number of input channels.

    Examples:
        >>> x = torch.randn(2, 1000)
        >>> encoder = Encoder(kernel_size=4, out_channels=64)
        >>> h = encoder(x)
        >>> h.shape # torch.Size([2, 64, 499])
    """

    def __init__(self,
                 kernel_size: int = 2,
                 out_channels: int = 64,
                 in_channels: int = 1):
        super().__init__()
        hop = kernel_size // 2
        self.conv1d = nn.Conv1d(
            in_channels=in_channels,
            out_channels=out_channels,
            kernel_size=kernel_size,
            stride=hop,
            groups=1,
            bias=False,
        )
        self.in_channels = in_channels

    def forward(self, x: torch.Tensor):
        """Encode the input signal.

        Args:
            x: Input tensor with dimensionality [B, L] when the encoder has
                a single input channel; otherwise it is passed to the
                convolution unchanged.

        Returns:
            Non-negative encoded tensor with dimensionality [B, N, T_out].
            where B = Batchsize
                  L = Number of timepoints
                  N = Number of filters
                  T_out = Number of timepoints at the output of the encoder
        """
        single_channel = self.in_channels == 1
        if single_channel:
            # B x L -> B x 1 x L
            x = x.unsqueeze(1)
        # B x 1 x L -> B x N x T_out
        return F.relu(self.conv1d(x))
class Decoder(nn.ConvTranspose1d):
    """A decoder layer that consists of ConvTranspose1d.

    All constructor arguments are forwarded unchanged to
    ``nn.ConvTranspose1d``.

    Args:
        kernel_size: Length of filters.
        in_channels: Number of input channels.
        out_channels: Number of output channels.

    Example
    ---------
    >>> x = torch.randn(2, 100, 1000)
    >>> decoder = Decoder(kernel_size=4, in_channels=100, out_channels=1)
    >>> h = decoder(x)
    >>> h.shape
    torch.Size([2, 1003])
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

    def forward(self, x):
        """Return the decoded output.

        Args:
            x: Input tensor with dimensionality [B, N, L].
                where, B = Batchsize,
                       N = number of filters
                       L = time points
                A 2-D input gets a singleton dimension inserted at
                position 1 before the transposed convolution.

        Returns:
            The transposed-convolution output with singleton dimensions
            squeezed away. If squeezing everything would leave a 1-D tensor,
            only dimension 1 is squeezed so the result stays 2-D.

        Raises:
            RuntimeError: if ``x`` is neither 2-D nor 3-D.
        """
        if x.dim() not in (2, 3):
            # Instances have no ``__name__``; the class name lives on the
            # type. The message now states the ranks actually accepted.
            raise RuntimeError('{} accept 2/3D tensor as input'.format(
                type(self).__name__))
        x = super().forward(x if x.dim() == 3 else torch.unsqueeze(x, 1))

        if torch.squeeze(x).dim() == 1:
            # Keep the leading dimension: drop only the channel axis.
            x = torch.squeeze(x, dim=1)
        else:
            # NOTE: this removes every singleton dimension, which includes
            # a batch dimension of size 1 when other dimensions are larger.
            x = torch.squeeze(x)
        return x
class IdentityBlock:
    """Callable that returns its input unchanged.

    This block is used when we want to have identity transformation within
    the Dual_path block.

    Example
    -------
    >>> x = torch.randn(10, 100)
    >>> IB = IdentityBlock()
    >>> xhat = IB(x)
    """

    def __init__(self, **kwargs):
        # Keyword arguments are accepted and ignored. The method used to be
        # misspelled ``_init__``, so it was never used as the constructor and
        # passing any keyword raised TypeError.
        pass

    def __call__(self, x):
        """Return ``x`` itself (the same object, no copy)."""
        return x
class MossFormerM(nn.Module):
    """MossFormer block stack followed by a layer normalization.

    Wraps ``MossFormerModule`` and applies speechbrain's ``LayerNorm``
    (eps=1e-6) to its output.

    Args:
        num_blocks : int
            Number of mossformer blocks to include.
        d_model : int
            The dimension of the input embedding.
        attn_dropout : float
            Dropout for the self-attention (Optional).
        group_size: int
            the chunk size
        query_key_dim: int
            the attention vector dimension
        expansion_factor: int
            the expansion factor for the linear projection in conv module
        causal: bool
            true for causal / false for non causal

    Example
    -------
    >>> import torch
    >>> x = torch.rand((8, 60, 512)) #B, S, N
    >>> net = MossFormerM(num_blocks=8, d_model=512)
    >>> output = net(x)
    >>> output.shape
    torch.Size([8, 60, 512])
    """

    def __init__(self,
                 num_blocks,
                 d_model=None,
                 attn_dropout=0.1,
                 group_size=256,
                 query_key_dim=128,
                 expansion_factor=4.,
                 causal=False):
        super().__init__()
        module_config = dict(
            dim=d_model,
            depth=num_blocks,
            group_size=group_size,
            query_key_dim=query_key_dim,
            expansion_factor=expansion_factor,
            causal=causal,
            attn_dropout=attn_dropout,
        )
        self.mossformerM = MossFormerModule(**module_config)
        # speechbrain is imported lazily, only when a model is constructed.
        import speechbrain as sb
        layer_norm_cls = sb.nnet.normalization.LayerNorm
        self.norm = layer_norm_cls(d_model, eps=1e-6)

    def forward(self, src: torch.Tensor):
        """Run the MossFormer stack, then normalize.

        Args:
            src: Tensor shape [B, S, N],
                where, B = Batchsize,
                       S = time points
                       N = number of filters
                The sequence to the encoder layer (required).

        Returns:
            A single tensor: the normalized output of the block stack.
        """
        return self.norm(self.mossformerM(src))
class ComputeAttention(nn.Module):
    """Attention wrapper with optional normalization and skip connection.

    The input is permuted so that the last two axes are swapped, handed to
    ``att_mdl``, permuted back, optionally normalized and optionally added
    to the original input.

    Args:
        att_mdl : torch.nn.module
            Model to process the permuted sequence.
        out_channels : int
            Dimensionality of attention model; used as the channel count of
            the normalization layer.
        norm : str
            Normalization type (``None`` disables normalization).
        skip_connection : bool
            Skip connection around the attention module.

    Example
    ---------
    >>> att_block = MossFormerM(num_blocks=8, d_model=512)
    >>> comp_att = ComputeAttention(att_block, 512)
    >>> x = torch.randn(10, 512, 64)
    >>> x = comp_att(x)
    >>> x.shape
    torch.Size([10, 512, 64])
    """

    def __init__(
        self,
        att_mdl,
        out_channels,
        norm='ln',
        skip_connection=True,
    ):
        super().__init__()
        self.att_mdl = att_mdl
        self.skip_connection = skip_connection

        # The norm name is kept so forward() knows whether att_norm exists.
        self.norm = norm
        if norm is not None:
            self.att_norm = select_norm(norm, out_channels, 3)

    def forward(self, x: torch.Tensor):
        """Returns the output tensor.

        Args:
            x: Input tensor of dimension [B, N, S].

        Returns:
            out: Output tensor of dimension [B, N, S].
            where, B = Batchsize,
                   N = number of filters
                   S = time points
        """
        # [B, N, S] -> [B, S, N] for the attention model
        sequence_major = x.permute(0, 2, 1).contiguous()
        attended = self.att_mdl(sequence_major)

        # back to [B, N, S]
        attended = attended.permute(0, 2, 1).contiguous()
        if self.norm is not None:
            attended = self.att_norm(attended)

        # [B, N, S]
        if not self.skip_connection:
            return attended
        return attended + x
class MossFormerMaskNet(nn.Module):
    """Mask estimation network built around a MossFormer attention model.

    Pipeline: normalization -> 1x1 conv -> optional positional embedding ->
    attention block -> PReLU -> 1x1 conv expanding to one group of channels
    per speaker -> gated (tanh * sigmoid) output -> 1x1 conv back to
    ``in_channels`` -> ReLU.

    Args:
        in_channels : int
            Number of channels at the output of the encoder.
        out_channels : int
            Number of channels that would be inputted to the attention block.
        att_model : torch.nn.module
            Attention model to process the input sequence.
        norm : str
            Normalization type.
        num_spks : int
            Number of sources (speakers).
        skip_connection : bool
            Skip connection around attention module.
        use_global_pos_enc : bool
            Global positional encodings.

    Example
    ---------
    >>> mossformer_block = MossFormerM(num_blocks=8, d_model=64)
    >>> mossformer_masknet = MossFormerMaskNet(64, 64, mossformer_block, num_spks=2)
    >>> x = torch.randn(10, 64, 2000)
    >>> x = mossformer_masknet(x)
    >>> x.shape
    torch.Size([2, 10, 64, 2000])
    """

    def __init__(
        self,
        in_channels,
        out_channels,
        att_model,
        norm='ln',
        num_spks=2,
        skip_connection=True,
        use_global_pos_enc=True,
    ):
        super().__init__()
        self.num_spks = num_spks
        self.norm = select_norm(norm, in_channels, 3)
        self.conv1d_encoder = nn.Conv1d(
            in_channels, out_channels, 1, bias=False)

        self.use_global_pos_enc = use_global_pos_enc
        if use_global_pos_enc:
            self.pos_enc = ScaledSinuEmbedding(out_channels)

        # The attention wrapper is deep-copied, so this network owns its own
        # copy of ``att_model`` rather than sharing the caller's instance.
        attention_block = ComputeAttention(
            att_model,
            out_channels,
            norm,
            skip_connection=skip_connection,
        )
        self.mdl = copy.deepcopy(attention_block)

        self.conv1d_out = nn.Conv1d(
            out_channels, out_channels * num_spks, kernel_size=1)
        self.conv1_decoder = nn.Conv1d(
            out_channels, in_channels, 1, bias=False)
        self.prelu = nn.PReLU()
        self.activation = nn.ReLU()

        # gated output layer
        self.output = nn.Sequential(
            nn.Conv1d(out_channels, out_channels, 1), nn.Tanh())
        self.output_gate = nn.Sequential(
            nn.Conv1d(out_channels, out_channels, 1), nn.Sigmoid())

    def forward(self, x: torch.Tensor):
        """Returns the output tensor.

        Args:
            x: Input tensor of dimension [B, N, S].

        Returns:
            out: Output tensor of dimension [spks, B, N, S]
            where, spks = Number of speakers
                   B = Batchsize,
                   N = number of filters
                   S = the number of time frames
        """
        # [B, N, S]: normalize, then project with a 1x1 convolution
        feats = self.conv1d_encoder(self.norm(x))

        if self.use_global_pos_enc:
            # The embedding is computed from the [B, S, N] view; its first
            # and last axes are swapped before it is added to the [B, N, S]
            # features (presumably [S, N] -> [N, S]; confirm against
            # ScaledSinuEmbedding).
            pos = self.pos_enc(feats.transpose(1, -1))
            feats = feats + pos.transpose(0, -1)

        # [B, N, S]
        feats = self.prelu(self.mdl(feats))

        # [B, N*spks, S]
        expanded = self.conv1d_out(feats)
        batch, _, frames = expanded.shape

        # [B*spks, N, S]
        per_speaker = expanded.view(batch * self.num_spks, -1, frames)

        # [B*spks, N, S]: tanh branch gated by the sigmoid branch
        gated = self.output(per_speaker) * self.output_gate(per_speaker)

        # [B*spks, N, S]
        decoded = self.conv1_decoder(gated)

        # [B, spks, N, S]
        _, channels, steps = decoded.shape
        masks = decoded.view(batch, self.num_spks, channels, steps)
        masks = self.activation(masks)

        # [spks, B, N, S]
        return masks.transpose(0, 1)
|