| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169 |
- # Copyright (c) Alibaba, Inc. and its affiliates.
- import importlib
- import os
- from typing import Any, Dict
- import numpy as np
- import scipy.io.wavfile as wav
- import torch
- import yaml
- from modelscope.metainfo import Pipelines
- from modelscope.outputs import OutputKeys
- from modelscope.pipelines.base import Pipeline
- from modelscope.pipelines.builder import PIPELINES
- from modelscope.preprocessors import LinearAECAndFbank
- from modelscope.utils.constant import ModelFile, Tasks
- from modelscope.utils.logger import get_logger
- logger = get_logger()  # Module-level logger from modelscope's get_logger(); not referenced in the code visible in this file.
- FEATURE_MVN = 'feature.DEY.mvn.txt'  # File name inside the model directory; its full path is stored in config['io']['mvn'].
- CONFIG_YAML = 'dey_mini.yaml'  # YAML config file inside the model directory, parsed in LinearAECPipeline.__init__.
def initialize_config(module_cfg):
    r"""Dynamically import a module and call (or instantiate) one of its members.

    1. Import the module named by the "module" item.
    2. Look up the function (or class) named by the "main" item in that module.
    3. Call it with the "args" item expanded as keyword arguments.

    Args:
        module_cfg (dict): config items, eg:
            {
                "module": "models.model",
                "main": "Model",
                "args": {...}
            }
            "args" may be omitted or be None (which is what an empty
            ``args:`` entry in a YAML file parses to); the target is then
            called without any arguments.

    Returns:
        Whatever calling the "main" attribute returns (for a class, a new
        instance of it).

    Raises:
        KeyError: if "module" or "main" is missing from ``module_cfg``.
        ImportError: if the module cannot be imported.
        AttributeError: if the module has no attribute named by "main".
    """
    module = importlib.import_module(module_cfg['module'])
    factory = getattr(module, module_cfg['main'])
    # A missing or empty "args" entry means "no keyword arguments" rather
    # than an error, so minimal configs do not need a dummy `args: {}`.
    call_kwargs = module_cfg.get('args')
    if call_kwargs is None:
        call_kwargs = {}
    return factory(**call_kwargs)
@PIPELINES.register_module(
    Tasks.acoustic_echo_cancellation,
    module_name=Pipelines.speech_dfsmn_aec_psm_16k)
class LinearAECPipeline(Pipeline):
    r"""AEC inference pipeline; only a 16000 Hz sample rate is supported.

    When invoking the class with pipeline.__call__(), you should provide two params:
        Dict[str, Any]
            the paths of the wav files, eg:{
            "nearend_mic": "/your/data/near_end_mic_audio.wav",
            "farend_speech": "/your/data/far_end_speech_audio.wav"}
        output_path (str, optional): "/your/output/audio_after_aec.wav"
            the file path to write the generated audio to.
    """

    def __init__(self, model, **kwargs):
        """Create an AEC pipeline for prediction.

        Reads ``dey_mini.yaml`` from the model directory, builds the network
        it describes, loads the checkpoint, creates the
        ``LinearAECAndFbank`` preprocessor and prepares the STFT/ISTFT
        helpers used to apply the predicted mask.

        Args:
            model: model id on modelscope hub.
            **kwargs: forwarded unchanged to ``Pipeline.__init__``.
        """
        super().__init__(model=model, **kwargs)
        self.check_trust_remote_code(
            'This pipeline requires `trust_remote_code=True` to load the module defined'
            ' in the `dey_mini.yaml`, setting this to True means you trust the code and files'
            ' listed in this model repo.')
        self.use_cuda = torch.cuda.is_available()
        # NOTE(review): yaml.full_load is less restrictive than
        # yaml.safe_load, and the parsed config decides which module gets
        # imported in initialize_config(). This is only acceptable because
        # of the trust_remote_code check above -- confirm before relaxing it.
        with open(
                os.path.join(self.model, CONFIG_YAML), encoding='utf-8') as f:
            self.config = yaml.full_load(f.read())
        # Must run before _init_model(): self.model is still the model
        # directory path here; _init_model() rebinds it to the network.
        self.config['io']['mvn'] = os.path.join(self.model, FEATURE_MVN)
        self._init_model()
        self.preprocessor = LinearAECAndFbank(self.config['io'])

        n_fft = self.config['loss']['args']['n_fft']
        hop_length = self.config['loss']['args']['hop_length']
        winlen = n_fft
        window = torch.hamming_window(winlen, periodic=False)

        def stft(x):
            # view_as_real turns the complex spectrum into a real tensor
            # with a trailing dimension of size 2 (real, imaginary).
            return torch.view_as_real(
                torch.stft(
                    x,
                    n_fft,
                    hop_length,
                    winlen,
                    center=False,
                    window=window.to(x.device),
                    return_complex=True))

        def istft(x, slen):
            # Inverse of stft() above; `slen` is the number of samples the
            # reconstructed signal should have.
            return torch.istft(
                torch.view_as_complex(x),
                n_fft,
                hop_length,
                winlen,
                window=window.to(x.device),
                center=False,
                length=slen)

        self.stft = stft
        self.istft = istft

    def _init_model(self):
        """Build the network from ``config['nnet']`` and load its weights.

        On entry ``self.model`` is the model directory path; on exit it is
        the instantiated network (moved to CUDA when available).
        """
        # Read the checkpoint first: the next statement rebinds self.model,
        # after which it can no longer be used as a directory path.
        checkpoint = torch.load(
            os.path.join(self.model, ModelFile.TORCH_MODEL_BIN_FILE),
            map_location='cpu',
            weights_only=True)
        self.model = initialize_config(self.config['nnet'])
        if self.use_cuda:
            self.model = self.model.cuda()
        self.model.load_state_dict(checkpoint)

    def forward(self, inputs: Dict[str, Any]) -> Dict[str, Any]:
        r"""The AEC process.

        Args:
            inputs: dict={'feature': Tensor, 'base': Tensor}
                'feature' feature of input audio.
                'base' the base audio to mask.

        Returns:
            dict with key ``OutputKeys.OUTPUT_PCM`` holding the generated
            audio as raw int16 PCM bytes.
        """
        output_data = self._process(inputs['feature'], inputs['base'])
        # Saturate before the cast: converting out-of-range floats straight
        # to int16 overflows instead of clamping, which would turn loud
        # peaks into wrap-around artifacts.
        int16_range = np.iinfo(np.int16)
        output_data = np.clip(output_data, int16_range.min, int16_range.max)
        output_data = output_data.astype(np.int16).tobytes()
        return {OutputKeys.OUTPUT_PCM: output_data}

    def postprocess(self, inputs: Dict[str, Any], **kwargs) -> Dict[str, Any]:
        r"""The post process. Will save audio to file, if the output_path is given.

        Args:
            inputs: a dict contains following keys:
                - output_pcm: generated audio as int16 PCM bytes
            kwargs: accept 'output_path' which is the path to write generated audio

        Returns:
            ``inputs`` unchanged.
        """
        if 'output_path' in kwargs:
            wav.write(
                kwargs['output_path'], self.preprocessor.SAMPLE_RATE,
                np.frombuffer(inputs[OutputKeys.OUTPUT_PCM], dtype=np.int16))
        return inputs

    def _process(self, fbanks, mixture):
        """Predict a mask from ``fbanks`` and apply it to ``mixture``.

        Args:
            fbanks: feature tensor fed to the model (a batch dimension is
                added here with ``unsqueeze(0)``).
            mixture: waveform tensor whose spectrum gets masked.

        Returns:
            numpy array with the masked signal, ``len(mixture)`` samples long.
        """
        if self.use_cuda:
            fbanks = fbanks.cuda()
            mixture = mixture.cuda()
        with torch.no_grad():
            outputs = self.model(fbanks.unsqueeze(0))
            if self.model.vad:
                # With VAD enabled the model returns a (masks, vad) pair;
                # the VAD output is not used by this pipeline.
                masks, _ = outputs
            else:
                masks = outputs
            # Reverse the axis order so the mask lines up with the STFT
            # output -- presumably (batch, frames, bins) -> (bins, frames, 1),
            # broadcasting over the real/imag axis; TODO confirm shapes.
            masks = masks.permute([2, 1, 0])
        spectrum = self.stft(mixture)
        masked_spec = spectrum * masks
        masked_sig = self.istft(masked_spec, len(mixture)).cpu().numpy()
        return masked_sig
|