# Copyright (c) Alibaba, Inc. and its affiliates.

import io
from typing import Any, Dict

import numpy
import soundfile as sf
import torch

from modelscope.fileio import File
from modelscope.metainfo import Models, Pipelines
from modelscope.models.base import Input
from modelscope.outputs import OutputKeys
from modelscope.pipelines import Pipeline
from modelscope.pipelines.builder import PIPELINES
from modelscope.utils.constant import Tasks
from modelscope.utils.logger import get_logger

logger = get_logger()


@PIPELINES.register_module(
    Tasks.speech_separation,
    module_name=Models.speech_mossformer_separation_temporal_8k)
@PIPELINES.register_module(
    Tasks.speech_separation,
    module_name=Models.speech_mossformer2_separation_temporal_8k)
class SeparationPipeline(Pipeline):
    """Speech separation pipeline producing one PCM stream per speaker.

    Registered for ``Tasks.speech_separation`` under both the MossFormer and
    MossFormer2 8 kHz module names.
    """

    def __init__(self, model, **kwargs):
        """create a speech separation pipeline for prediction

        Args:
            model: model id on modelscope hub.
        """
        logger.info('loading model...')
        super().__init__(model=model, **kwargs)
        # Load the weights onto the pipeline device and switch to eval mode.
        self.model.load_check_point(device=self.device)
        self.model.eval()

    def preprocess(self, inputs: Input, **preprocess_params) -> Dict[str, Any]:
        """Load the mixture audio into a float32 sample buffer.

        Args:
            inputs: Either a ``str`` location readable by ``File.read`` (the
                audio is decoded with ``soundfile``), or ``bytes`` that are
                interpreted as raw float32 samples.

        Returns:
            A dict with a single key ``data`` holding the samples.

        Raises:
            ValueError: If a decoded audio file is not sampled at 8000 Hz.
            TypeError: If ``inputs`` is neither ``str`` nor ``bytes``.
        """
        if isinstance(inputs, str):
            file_bytes = File.read(inputs)
            data, fs = sf.read(io.BytesIO(file_bytes), dtype='float32')
            if fs != 8000:
                raise ValueError(
                    'modelscope error: The audio sample rate should be 8000')
        elif isinstance(inputs, bytes):
            # NOTE(review): the sample rate of raw bytes cannot be checked
            # here; presumably the caller guarantees 8 kHz - confirm.
            data = torch.from_numpy(
                numpy.frombuffer(inputs, dtype=numpy.float32))
        else:
            # Previously fell through and failed with UnboundLocalError on
            # the return statement; fail early with an explicit message.
            raise TypeError(
                'modelscope error: unsupported input type '
                f'{type(inputs).__name__}, expected str or bytes')
        return dict(data=data)

    def postprocess(self, inputs: Dict[str, Any],
                    **post_params) -> Dict[str, Any]:
        """Return the forward result unchanged."""
        return inputs

    def forward(self, inputs: Dict[str, Any],
                **forward_params) -> Dict[str, Any]:
        """Forward computations from the mixture to the separated signals.

        Args:
            inputs: Dict whose ``data`` entry is the mixture tensor.
                NOTE(review): assumed to be a 1-D tensor of samples by the
                time it gets here - confirm against the base pipeline.

        Returns:
            A dict mapping ``OutputKeys.OUTPUT_PCM_LIST`` to a list with one
            ``bytes`` object of int16 PCM per speaker
            (``self.model.num_spks`` entries).
        """
        logger.info('Start forward...')
        # Move the mixture to the right device and add a leading axis.
        mix = inputs['data'].to(self.device)
        mix = torch.unsqueeze(mix, dim=1).transpose(0, 1)
        est_source = self.model(mix)
        result = []
        for ns in range(self.model.num_spks):
            # First batch item, all entries of the middle axis, speaker `ns`.
            signal = est_source[0, :, ns]
            # Peak-normalise to 0.5 of full scale. Skip when the estimate is
            # all zeros: 0 / 0 would give NaN and the int16 cast of NaN is
            # undefined, so silence must stay silence.
            peak = signal.abs().max()
            if peak > 0:
                signal = signal / peak * 0.5
            signal = signal.unsqueeze(0).cpu()
            # convert tensor to pcm
            output = (signal.numpy() * 32768).astype(numpy.int16).tobytes()
            result.append(output)
        logger.info('Finish forward.')
        return {OutputKeys.OUTPUT_PCM_LIST: result}