converter.py 2.9 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374
  1. # Copyright (c) Alibaba, Inc. and its affiliates.
  2. import os
  3. from typing import Dict
  4. import soundfile as sf
  5. import torch
  6. from modelscope.metainfo import Models
  7. from modelscope.models import TorchModel
  8. from modelscope.models.audio.vc.src.encoder import Encoder
  9. from modelscope.models.audio.vc.src.sv_models.DTDNN import \
  10. SpeakerVerificationCamplus
  11. from modelscope.models.audio.vc.src.vocoder import (ConditionGenerator,
  12. HiFiGANGenerator)
  13. from modelscope.models.base import Tensor
  14. from modelscope.models.builder import MODELS
  15. from modelscope.utils.constant import Tasks
  16. @MODELS.register_module(Tasks.voice_conversion, module_name=Models.unetvc_16k)
  17. class UnetVC(TorchModel):
  18. r"""A decorator of FRCRN for integrating into modelscope framework"""
  19. def __init__(self, model_dir: str, *args, **kwargs):
  20. """initialize the frcrn model from the `model_dir` path.
  21. Args:
  22. model_dir (str): the model path.
  23. """
  24. super().__init__(model_dir, *args, **kwargs)
  25. device = kwargs.get('device', 'cpu')
  26. self.device = device
  27. static_path = os.path.join(model_dir, 'static')
  28. self.encoder = Encoder(
  29. os.path.join(static_path, 'encoder_am.mvn'),
  30. os.path.join(static_path, 'encoder.onnx'))
  31. self.spk_emb = SpeakerVerificationCamplus(
  32. os.path.join(static_path, 'campplus_cn_common.bin'), device)
  33. self.converter = ConditionGenerator(
  34. unet=True, extra_info=True).to(device)
  35. G_path = os.path.join(static_path, 'converter.pth')
  36. self.converter.load_state_dict(
  37. torch.load(G_path, map_location=lambda storage, loc: storage))
  38. self.converter.eval()
  39. self.vocoder = HiFiGANGenerator().to(device)
  40. self.vocoder.load_state_dict(
  41. torch.load(
  42. os.path.join(static_path, 'vocoder.pth'),
  43. map_location=self.device)['state_dict'])
  44. self.vocoder.eval()
  45. self.vocoder.remove_weight_norm()
  46. def forward(self, inputs: Dict[str, Tensor]) -> Dict[str, Tensor]:
  47. target_wav_path = inputs['target_wav']
  48. source_wav_path = inputs['source_wav']
  49. save_wav_path = inputs['save_path']
  50. with torch.no_grad():
  51. source_enc = self.encoder.inference(source_wav_path).to(
  52. self.device)
  53. spk_emb = self.spk_emb.forward(target_wav_path).to(self.device)
  54. style_mc = self.encoder.get_feats(target_wav_path).to(self.device)
  55. coded_sp_converted_norm = self.converter(source_enc, spk_emb,
  56. style_mc)
  57. wav = self.vocoder(coded_sp_converted_norm.permute([0, 2, 1]))
  58. if os.path.exists(save_wav_path):
  59. sf.write(save_wav_path,
  60. wav.flatten().cpu().data.numpy(), 16000)
  61. return wav.flatten().cpu().data.numpy()