zipenhancer.py 6.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210
  1. #!/usr/bin/env python3
  2. #
  3. # Copyright (c) Alibaba, Inc. and its affiliates.
  4. import os
  5. import random
  6. from typing import Dict
  7. import numpy as np
  8. import torch
  9. import torch.nn as nn
  10. import torch.nn.functional as F
  11. from modelscope.metainfo import Models
  12. from modelscope.models import TorchModel
  13. from modelscope.models.base import Tensor
  14. from modelscope.models.builder import MODELS
  15. from modelscope.utils.constant import ModelFile, Tasks
  16. from .zipenhancer_layers.generator import (DenseEncoder, MappingDecoder,
  17. PhaseDecoder)
  18. from .zipenhancer_layers.scaling import ScheduledFloat
  19. from .zipenhancer_layers.zipenhancer_layer import Zipformer2DualPathEncoder
  20. @MODELS.register_module(
  21. Tasks.acoustic_noise_suppression,
  22. module_name=Models.speech_zipenhancer_ans_multiloss_16k_base)
  23. class ZipenhancerDecorator(TorchModel):
  24. def __init__(self, model_dir: str, *args, **kwargs):
  25. super().__init__(model_dir, *args, **kwargs)
  26. h = dict(
  27. num_tsconformers=kwargs['num_tsconformers'],
  28. dense_channel=kwargs['dense_channel'],
  29. former_conf=kwargs['former_conf'],
  30. batch_first=kwargs['batch_first'],
  31. model_num_spks=kwargs['model_num_spks'],
  32. )
  33. # num_tsconformers, dense_channel, former_name, former_conf, batch_first, model_num_spks
  34. h = AttrDict(h)
  35. self.model = ZipEnhancer(h)
  36. model_bin_file = os.path.join(model_dir,
  37. ModelFile.TORCH_MODEL_BIN_FILE)
  38. if os.path.exists(model_bin_file):
  39. checkpoint = torch.load(
  40. model_bin_file, map_location=torch.device('cpu'))
  41. if isinstance(checkpoint, dict) and 'state_dict' in checkpoint:
  42. # the new trained model by user is based on ZipenhancerDecorator
  43. self.load_state_dict(checkpoint['state_dict'])
  44. else:
  45. # The released model on Modelscope is based on Zipenhancer
  46. # self.model.load_state_dict(checkpoint, strict=False)
  47. self.model.load_state_dict(checkpoint['generator'])
  48. # print(checkpoint['generator'].keys())
  49. def forward(self, inputs: Dict[str, Tensor]) -> Dict[str, Tensor]:
  50. n_fft = 400
  51. hop_size = 100
  52. win_size = 400
  53. noisy_wav = inputs['noisy']
  54. norm_factor = torch.sqrt(noisy_wav.shape[1]
  55. / torch.sum(noisy_wav**2.0))
  56. noisy_audio = (noisy_wav * norm_factor)
  57. mag, pha, com = mag_pha_stft(
  58. noisy_audio,
  59. n_fft,
  60. hop_size,
  61. win_size,
  62. compress_factor=0.3,
  63. center=True)
  64. amp_g, pha_g, com_g, _, others = self.model.forward(mag, pha)
  65. wav = mag_pha_istft(
  66. amp_g,
  67. pha_g,
  68. n_fft,
  69. hop_size,
  70. win_size,
  71. compress_factor=0.3,
  72. center=True)
  73. wav = wav / norm_factor
  74. output = {
  75. 'wav_l2': wav,
  76. }
  77. return output
  78. class ZipEnhancer(nn.Module):
  79. def __init__(self, h):
  80. """
  81. Initialize the ZipEnhancer module.
  82. Args:
  83. h (object): Configuration object containing various hyperparameters and settings.
  84. having num_tsconformers, former_name, former_conf, mask_decoder_type, ...
  85. """
  86. super(ZipEnhancer, self).__init__()
  87. self.h = h
  88. num_tsconformers = h.num_tsconformers
  89. self.num_tscblocks = num_tsconformers
  90. self.dense_encoder = DenseEncoder(h, in_channel=2)
  91. self.TSConformer = Zipformer2DualPathEncoder(
  92. output_downsampling_factor=1,
  93. dropout=ScheduledFloat((0.0, 0.3), (20000.0, 0.1)),
  94. **h.former_conf)
  95. self.mask_decoder = MappingDecoder(h, out_channel=h.model_num_spks)
  96. self.phase_decoder = PhaseDecoder(h, out_channel=h.model_num_spks)
  97. def forward(self, noisy_mag, noisy_pha): # [B, F, T]
  98. """
  99. Forward pass of the ZipEnhancer module.
  100. Args:
  101. noisy_mag (Tensor): Noisy magnitude input tensor of shape [B, F, T].
  102. noisy_pha (Tensor): Noisy phase input tensor of shape [B, F, T].
  103. Returns:
  104. Tuple: denoised magnitude, denoised phase, denoised complex representation,
  105. (optional) predicted noise components, and other auxiliary information.
  106. """
  107. others = dict()
  108. noisy_mag = noisy_mag.unsqueeze(-1).permute(0, 3, 2, 1) # [B, 1, T, F]
  109. noisy_pha = noisy_pha.unsqueeze(-1).permute(0, 3, 2, 1) # [B, 1, T, F]
  110. x = torch.cat((noisy_mag, noisy_pha), dim=1) # [B, 2, T, F]
  111. x = self.dense_encoder(x)
  112. # [B, C, T, F]
  113. x = self.TSConformer(x)
  114. pred_mag = self.mask_decoder(x)
  115. pred_pha = self.phase_decoder(x)
  116. # b, c, t, f -> b, 1, t, f -> b, f, t, 1 -> b, f, t
  117. denoised_mag = pred_mag[:, 0, :, :].unsqueeze(1).permute(0, 3, 2,
  118. 1).squeeze(-1)
  119. # b, t, f
  120. denoised_pha = pred_pha[:, 0, :, :].unsqueeze(1).permute(0, 3, 2,
  121. 1).squeeze(-1)
  122. # b, t, f
  123. denoised_com = torch.stack((denoised_mag * torch.cos(denoised_pha),
  124. denoised_mag * torch.sin(denoised_pha)),
  125. dim=-1)
  126. return denoised_mag, denoised_pha, denoised_com, None, others
  127. class AttrDict(dict):
  128. def __init__(self, *args, **kwargs):
  129. super(AttrDict, self).__init__(*args, **kwargs)
  130. self.__dict__ = self
  131. def mag_pha_stft(y,
  132. n_fft,
  133. hop_size,
  134. win_size,
  135. compress_factor=1.0,
  136. center=True):
  137. hann_window = torch.hann_window(win_size, device=y.device)
  138. stft_spec = torch.stft(
  139. y,
  140. n_fft,
  141. hop_length=hop_size,
  142. win_length=win_size,
  143. window=hann_window,
  144. center=center,
  145. pad_mode='reflect',
  146. normalized=False,
  147. return_complex=True)
  148. stft_spec = torch.view_as_real(stft_spec)
  149. mag = torch.sqrt(stft_spec.pow(2).sum(-1) + (1e-9))
  150. pha = torch.atan2(stft_spec[:, :, :, 1], stft_spec[:, :, :, 0] + (1e-5))
  151. # Magnitude Compression
  152. mag = torch.pow(mag, compress_factor)
  153. com = torch.stack((mag * torch.cos(pha), mag * torch.sin(pha)), dim=-1)
  154. return mag, pha, com
  155. def mag_pha_istft(mag,
  156. pha,
  157. n_fft,
  158. hop_size,
  159. win_size,
  160. compress_factor=1.0,
  161. center=True):
  162. # Magnitude Decompression
  163. mag = torch.pow(mag, (1.0 / compress_factor))
  164. com = torch.complex(mag * torch.cos(pha), mag * torch.sin(pha))
  165. hann_window = torch.hann_window(win_size, device=com.device)
  166. wav = torch.istft(
  167. com,
  168. n_fft,
  169. hop_length=hop_size,
  170. win_length=win_size,
  171. window=hann_window,
  172. center=center)
  173. return wav