#!/usr/bin/env python3
#
# Copyright (c) Alibaba, Inc. and its affiliates.
"""ZipEnhancer speech-enhancement model and its STFT helper functions."""

import os
import random
from typing import Dict

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F

from modelscope.metainfo import Models
from modelscope.models import TorchModel
from modelscope.models.base import Tensor
from modelscope.models.builder import MODELS
from modelscope.utils.constant import ModelFile, Tasks
from .zipenhancer_layers.generator import (DenseEncoder, MappingDecoder,
                                           PhaseDecoder)
from .zipenhancer_layers.scaling import ScheduledFloat
from .zipenhancer_layers.zipenhancer_layer import Zipformer2DualPathEncoder


@MODELS.register_module(
    Tasks.acoustic_noise_suppression,
    module_name=Models.speech_zipenhancer_ans_multiloss_16k_base)
class ZipenhancerDecorator(TorchModel):
    """ModelScope wrapper that runs ZipEnhancer on a raw noisy waveform.

    The wrapper builds the ``ZipEnhancer`` network from keyword arguments,
    optionally restores its weights from ``model_dir`` and, in ``forward``,
    handles level normalisation, STFT analysis and iSTFT synthesis around
    the network.
    """

    def __init__(self, model_dir: str, *args, **kwargs):
        """Build the network and load a checkpoint if one is present.

        Args:
            model_dir: Directory that may contain the weight file named by
                ``ModelFile.TORCH_MODEL_BIN_FILE``.
            **kwargs: Must provide ``num_tsconformers``, ``dense_channel``,
                ``former_conf``, ``batch_first`` and ``model_num_spks``.

        Raises:
            KeyError: If one of the required keyword arguments is missing,
                or if the checkpoint has neither a ``'state_dict'`` nor a
                ``'generator'`` entry.
        """
        super().__init__(model_dir, *args, **kwargs)

        h = AttrDict(
            num_tsconformers=kwargs['num_tsconformers'],
            dense_channel=kwargs['dense_channel'],
            former_conf=kwargs['former_conf'],
            batch_first=kwargs['batch_first'],
            model_num_spks=kwargs['model_num_spks'],
        )
        self.model = ZipEnhancer(h)

        model_bin_file = os.path.join(model_dir,
                                      ModelFile.TORCH_MODEL_BIN_FILE)
        if os.path.exists(model_bin_file):
            # NOTE(review): depending on the installed torch version,
            # torch.load may unpickle arbitrary objects; only point
            # model_dir at trusted checkpoints.
            checkpoint = torch.load(
                model_bin_file, map_location=torch.device('cpu'))
            if isinstance(checkpoint, dict) and 'state_dict' in checkpoint:
                # A model newly trained by the user is saved from
                # ZipenhancerDecorator, so the keys match this wrapper.
                self.load_state_dict(checkpoint['state_dict'])
            else:
                # The model released on ModelScope is saved from
                # ZipEnhancer; its weights live under 'generator'.
                self.model.load_state_dict(checkpoint['generator'])

    def forward(self, inputs: Dict[str, Tensor]) -> Dict[str, Tensor]:
        """Enhance the waveform stored under ``inputs['noisy']``.

        Args:
            inputs: Dict holding the noisy waveform under ``'noisy'``.
                The code reads ``shape[1]`` as the sample count, so the
                tensor is assumed to be ``[B, T]`` -- TODO confirm.

        Returns:
            Dict with the enhanced waveform under ``'wav_l2'``, scaled back
            to the input level.
        """
        n_fft = 400
        hop_size = 100
        win_size = 400
        compress_factor = 0.3
        # Floor for the signal energy: an all-zero (silent) input would
        # otherwise give an infinite gain and 0 * inf = NaN downstream.
        energy_floor = 1e-12

        noisy_wav = inputs['noisy']
        # NOTE(review): the energy is summed over every element, i.e. one
        # gain is shared by the whole batch -- confirm this is intended
        # for B > 1.
        energy = torch.clamp(torch.sum(noisy_wav**2.0), min=energy_floor)
        norm_factor = torch.sqrt(noisy_wav.shape[1] / energy)
        noisy_audio = noisy_wav * norm_factor

        mag, pha, _ = mag_pha_stft(
            noisy_audio,
            n_fft,
            hop_size,
            win_size,
            compress_factor=compress_factor,
            center=True)
        # Call the module itself (not .forward) so registered hooks run.
        amp_g, pha_g, *_ = self.model(mag, pha)
        wav = mag_pha_istft(
            amp_g,
            pha_g,
            n_fft,
            hop_size,
            win_size,
            compress_factor=compress_factor,
            center=True)

        # Undo the level normalisation applied before analysis.
        wav = wav / norm_factor

        return {'wav_l2': wav}


class ZipEnhancer(nn.Module):
    """Dense encoder -> Zipformer dual-path encoder -> mag/phase decoders."""

    def __init__(self, h):
        """
        Initialize the ZipEnhancer module.

        Args:
            h (object): Configuration object with attribute access. This
                class reads ``num_tsconformers``, ``former_conf`` and
                ``model_num_spks``; ``h`` is also handed to the encoder and
                decoders, which presumably read further fields such as
                ``dense_channel`` (verify against those classes).
        """
        super().__init__()
        self.h = h

        num_tsconformers = h.num_tsconformers
        self.num_tscblocks = num_tsconformers

        self.dense_encoder = DenseEncoder(h, in_channel=2)

        self.TSConformer = Zipformer2DualPathEncoder(
            output_downsampling_factor=1,
            dropout=ScheduledFloat((0.0, 0.3), (20000.0, 0.1)),
            **h.former_conf)

        self.mask_decoder = MappingDecoder(h, out_channel=h.model_num_spks)
        self.phase_decoder = PhaseDecoder(h, out_channel=h.model_num_spks)

    def forward(self, noisy_mag, noisy_pha):  # [B, F, T]
        """
        Forward pass of the ZipEnhancer module.

        Args:
            noisy_mag (Tensor): Noisy magnitude input of shape [B, F, T].
            noisy_pha (Tensor): Noisy phase input of shape [B, F, T].

        Returns:
            Tuple: ``(denoised_mag, denoised_pha, denoised_com, None,
            others)``. Magnitude and phase are [B, F, T]; ``denoised_com``
            stacks the real and imaginary parts on a trailing axis of size
            2; the fourth item is always ``None`` and ``others`` is an
            empty dict.
        """
        others = dict()

        noisy_mag = noisy_mag.unsqueeze(-1).permute(0, 3, 2,
                                                    1)  # [B, 1, T, F]
        noisy_pha = noisy_pha.unsqueeze(-1).permute(0, 3, 2,
                                                    1)  # [B, 1, T, F]
        x = torch.cat((noisy_mag, noisy_pha), dim=1)  # [B, 2, T, F]
        x = self.dense_encoder(x)  # [B, C, T, F]

        x = self.TSConformer(x)

        pred_mag = self.mask_decoder(x)
        pred_pha = self.phase_decoder(x)

        # Keep channel 0 only, then restore the input layout:
        # [B, C, T, F] -> [B, 1, T, F] -> [B, F, T, 1] -> [B, F, T]
        denoised_mag = pred_mag[:, 0, :, :].unsqueeze(1).permute(
            0, 3, 2, 1).squeeze(-1)  # [B, F, T]
        denoised_pha = pred_pha[:, 0, :, :].unsqueeze(1).permute(
            0, 3, 2, 1).squeeze(-1)  # [B, F, T]

        denoised_com = torch.stack(
            (denoised_mag * torch.cos(denoised_pha),
             denoised_mag * torch.sin(denoised_pha)),
            dim=-1)

        return denoised_mag, denoised_pha, denoised_com, None, others


class AttrDict(dict):
    """Dict whose items are also reachable as attributes (``d.key``)."""

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # Share storage between items and attributes.
        self.__dict__ = self


def mag_pha_stft(y,
                 n_fft,
                 hop_size,
                 win_size,
                 compress_factor=1.0,
                 center=True):
    """Compute compressed magnitude, phase and complex STFT of a waveform.

    Args:
        y (Tensor): Waveform accepted by ``torch.stft`` (batched or not).
        n_fft (int): FFT size.
        hop_size (int): Hop length in samples.
        win_size (int): Length of the Hann analysis window.
        compress_factor (float): Exponent applied to the magnitude.
        center (bool): Forwarded to ``torch.stft``.

    Returns:
        Tuple ``(mag, pha, com)``: the compressed magnitude, the phase, and
        the compressed spectrum as stacked real/imaginary parts on a
        trailing axis of size 2.
    """
    hann_window = torch.hann_window(win_size, device=y.device)
    stft_spec = torch.stft(
        y,
        n_fft,
        hop_length=hop_size,
        win_length=win_size,
        window=hann_window,
        center=center,
        pad_mode='reflect',
        normalized=False,
        return_complex=True)
    stft_spec = torch.view_as_real(stft_spec)
    # The small constants keep sqrt and atan2 away from an exact zero.
    mag = torch.sqrt(stft_spec.pow(2).sum(-1) + (1e-9))
    # Ellipsis indexing works for batched and unbatched spectra alike.
    pha = torch.atan2(stft_spec[..., 1], stft_spec[..., 0] + (1e-5))
    # Magnitude Compression
    mag = torch.pow(mag, compress_factor)
    com = torch.stack((mag * torch.cos(pha), mag * torch.sin(pha)), dim=-1)

    return mag, pha, com


def mag_pha_istft(mag,
                  pha,
                  n_fft,
                  hop_size,
                  win_size,
                  compress_factor=1.0,
                  center=True):
    """Rebuild a waveform from a compressed magnitude and a phase.

    Args:
        mag (Tensor): Compressed magnitude spectrum.
        pha (Tensor): Phase spectrum with the same shape as ``mag``.
        n_fft (int): FFT size.
        hop_size (int): Hop length in samples.
        win_size (int): Length of the Hann synthesis window.
        compress_factor (float): Exponent that was applied to the
            magnitude; it is inverted here before synthesis.
        center (bool): Forwarded to ``torch.istft``.

    Returns:
        Tensor: The waveform returned by ``torch.istft``.
    """
    # Magnitude Decompression
    mag = torch.pow(mag, (1.0 / compress_factor))
    com = torch.complex(mag * torch.cos(pha), mag * torch.sin(pha))
    hann_window = torch.hann_window(win_size, device=com.device)
    wav = torch.istft(
        com,
        n_fft,
        hop_length=hop_size,
        win_length=win_size,
        window=hann_window,
        center=center)

    return wav