| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475 |
- # Copyright (c) Alibaba, Inc. and its affiliates.
- import io
- from typing import Any, Dict
- import numpy
- import soundfile as sf
- import torch
- from modelscope.fileio import File
- from modelscope.metainfo import Models, Pipelines
- from modelscope.models.base import Input
- from modelscope.outputs import OutputKeys
- from modelscope.pipelines import Pipeline
- from modelscope.pipelines.builder import PIPELINES
- from modelscope.utils.constant import Tasks
- from modelscope.utils.logger import get_logger
# Module-level logger used by the pipeline below for progress messages.
logger = get_logger()
@PIPELINES.register_module(
    Tasks.speech_separation,
    module_name=Models.speech_mossformer_separation_temporal_8k)
@PIPELINES.register_module(
    Tasks.speech_separation,
    module_name=Models.speech_mossformer2_separation_temporal_8k)
class SeparationPipeline(Pipeline):
    """Speech separation pipeline for the MossFormer / MossFormer2 8 kHz models.

    Each call splits an input mixture into ``self.model.num_spks`` separated
    signals and returns them as int16 PCM byte strings under
    ``OutputKeys.OUTPUT_PCM_LIST``.
    """

    def __init__(self, model, **kwargs):
        """Create a speech separation pipeline for prediction.

        Loads the model checkpoint onto ``self.device`` and puts the model in
        eval mode.

        Args:
            model: model id on modelscope hub.
            **kwargs: forwarded unchanged to ``Pipeline.__init__``.
        """
        logger.info('loading model...')
        super().__init__(model=model, **kwargs)
        self.model.load_check_point(device=self.device)
        self.model.eval()

    def preprocess(self, inputs: Input, **preprocess_params) -> Dict[str, Any]:
        """Load the input mixture as a float32 tensor.

        Args:
            inputs: either a ``str`` location read with ``File.read`` and
                decoded by soundfile, or ``bytes`` holding raw float32
                samples (the sample rate of raw bytes is not checked).
            **preprocess_params: unused.

        Returns:
            ``{'data': tensor}`` holding the float32 samples.

        Raises:
            ValueError: if the decoded audio file is not sampled at 8000 Hz.
            TypeError: if ``inputs`` is neither ``str`` nor ``bytes``.
        """
        if isinstance(inputs, str):
            file_bytes = File.read(inputs)
            data, fs = sf.read(io.BytesIO(file_bytes), dtype='float32')
            if fs != 8000:
                raise ValueError(
                    'modelscope error: The audio sample rate should be 8000')
            # Return a tensor for both input kinds; forward() calls .to().
            data = torch.from_numpy(data)
        elif isinstance(inputs, bytes):
            # frombuffer returns a read-only view over ``inputs``; copy it so
            # torch.from_numpy receives a writable array (and does not warn).
            data = torch.from_numpy(
                numpy.frombuffer(inputs, dtype=numpy.float32).copy())
        else:
            # Without this branch ``data`` would be unbound below.
            raise TypeError(
                'Unsupported input type for speech separation: '
                f'{type(inputs).__name__}; expected str or bytes')
        return dict(data=data)

    def postprocess(self, inputs: Dict[str, Any],
                    **post_params) -> Dict[str, Any]:
        """Return the forward() output unchanged."""
        return inputs

    def forward(self, inputs: Dict[str, Any],
                **forward_params) -> Dict[str, Any]:
        """Separate the mixture into one PCM stream per speaker.

        Args:
            inputs: dict produced by ``preprocess``; ``inputs['data']`` is the
                mixture waveform.
            **forward_params: unused.

        Returns:
            ``{OutputKeys.OUTPUT_PCM_LIST: [bytes, ...]}`` with
            ``self.model.num_spks`` entries, each a native-endian int16 PCM
            buffer peak-normalised to half of full scale.
        """
        logger.info('Start forward...')
        mix = inputs['data'].to(self.device)
        # Add a leading batch dimension: (T,) -> (T, 1) -> (1, T).
        mix = torch.unsqueeze(mix, dim=1).transpose(0, 1)
        est_source = self.model(mix)
        # est_source is indexed as (batch, time, speaker) here.
        result = [
            self._to_pcm16(est_source[0, :, ns])
            for ns in range(self.model.num_spks)
        ]
        logger.info('Finish forward.')
        return {OutputKeys.OUTPUT_PCM_LIST: result}

    @staticmethod
    def _to_pcm16(signal: torch.Tensor) -> bytes:
        """Peak-normalise ``signal`` to 0.5 and encode it as int16 PCM bytes."""
        peak = signal.abs().max()
        # An all-zero (silent) signal would otherwise be divided by zero,
        # producing NaNs whose int16 cast is undefined.
        if peak > 0:
            signal = signal / peak * 0.5
        # detach() so .numpy() also works if gradients are being tracked.
        pcm = signal.detach().cpu().numpy() * 32768
        return pcm.astype(numpy.int16).tobytes()
|