# ssr_infer.py
# Copyright (c) Alibaba, Inc. and its affiliates.
import os
import warnings
from typing import Dict

import librosa
import numpy as np
import soundfile as sf
import torch
from torchaudio.transforms import Spectrogram

from modelscope.metainfo import Models
from modelscope.models import TorchModel
from modelscope.models.audio.ssr.models.hifigan import HiFiGANGenerator
from modelscope.models.audio.ssr.models.Unet import MaskMapping
from modelscope.models.base import Tensor
from modelscope.models.builder import MODELS
from modelscope.utils.constant import Tasks


  15. @MODELS.register_module(
  16. Tasks.speech_super_resolution, module_name=Models.hifissr)
  17. class HifiSSR(TorchModel):
  18. r"""A decorator of FRCRN for integrating into modelscope framework"""
  19. def __init__(self, model_dir: str, *args, **kwargs):
  20. """initialize the frcrn model from the `model_dir` path.
  21. Args:
  22. model_dir (str): the model path.
  23. """
  24. super().__init__(model_dir, *args, **kwargs)
  25. self.device = kwargs.get('device', 'cpu')
  26. self.front = Spectrogram(512, 512, int(48000 * 0.01)).to(self.device)
  27. self.vocoder = HiFiGANGenerator(
  28. input_channels=256,
  29. upsample_rates=[5, 4, 4, 3, 2],
  30. upsample_kernel_sizes=[10, 8, 8, 6, 4],
  31. weight_norm=False,
  32. upsample_initial_channel=1024).to(self.device)
  33. self.mapping = MaskMapping(32, 256).to(self.device)
  34. model_bin_file = os.path.join(model_dir, 'checkpoint.pt')
  35. if os.path.exists(model_bin_file):
  36. checkpoint = torch.load(model_bin_file, map_location=self.device)
  37. self.vocoder.load_state_dict(checkpoint['voc_state_dict'])
  38. self.vocoder.eval()
  39. self.mapping.load_state_dict(checkpoint['unet_state_dict'])
  40. self.mapping.eval()
  41. def forward(self, inputs: Dict[str, Tensor]) -> Dict[str, Tensor]:
  42. ref_fp = inputs['ref_wav']
  43. source_fp = inputs['source_wav']
  44. out_fp = inputs['out_wav']
  45. sr = 48000
  46. wav = librosa.load(source_fp, sr=sr)[0]
  47. source_mel = self.front(
  48. torch.FloatTensor(wav).unsqueeze(0).to(self.device))[:, :-1]
  49. source_mel = torch.log10(source_mel + 1e-6)
  50. source_mel = source_mel.unsqueeze(0)
  51. ref_wav = librosa.load(ref_fp, sr=sr)[0]
  52. ref_mel = self.front(
  53. torch.FloatTensor(ref_wav).unsqueeze(0).to(self.device))[:, :-1]
  54. ref_mel = torch.log10(ref_mel + 1e-6)
  55. with torch.no_grad():
  56. g_out = self.mapping(source_mel, ref_mel)
  57. g_out_wav = self.vocoder(g_out)
  58. g_out_wav = g_out_wav.flatten()
  59. if os.path.exists(out_fp):
  60. sf.write(out_fp, g_out_wav.cpu().data.numpy(), sr)
  61. return g_out_wav.cpu().data.numpy()