separation_pipeline.py 2.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475
  1. # Copyright (c) Alibaba, Inc. and its affiliates.
  2. import io
  3. from typing import Any, Dict
  4. import numpy
  5. import soundfile as sf
  6. import torch
  7. from modelscope.fileio import File
  8. from modelscope.metainfo import Models, Pipelines
  9. from modelscope.models.base import Input
  10. from modelscope.outputs import OutputKeys
  11. from modelscope.pipelines import Pipeline
  12. from modelscope.pipelines.builder import PIPELINES
  13. from modelscope.utils.constant import Tasks
  14. from modelscope.utils.logger import get_logger
  logger = get_logger()


  @PIPELINES.register_module(
      Tasks.speech_separation,
      module_name=Models.speech_mossformer_separation_temporal_8k)
  @PIPELINES.register_module(
      Tasks.speech_separation,
      module_name=Models.speech_mossformer2_separation_temporal_8k)
  class SeparationPipeline(Pipeline):
      """Speech separation pipeline for the 8 kHz MossFormer/MossFormer2 models.

      Feeds a single mixture waveform to the model and returns one estimated
      source per speaker (``self.model.num_spks``), each encoded as int16 PCM
      bytes, under ``OutputKeys.OUTPUT_PCM_LIST``.
      """

      def __init__(self, model, **kwargs):
          """Create a speech separation pipeline for prediction.

          Args:
              model: model id on modelscope hub.
              **kwargs: forwarded unchanged to ``Pipeline.__init__``.
          """
          logger.info('loading model...')
          super().__init__(model=model, **kwargs)
          self.model.load_check_point(device=self.device)
          self.model.eval()

      def preprocess(self, inputs: Input, **preprocess_params) -> Dict[str, Any]:
          """Load the mixture waveform as a float32 torch tensor.

          Args:
              inputs: either a path/URL (``str``) readable by ``File.read`` that
                  points to an audio file soundfile can decode, or raw ``bytes``
                  holding headerless native-endian float32 samples. For raw
                  bytes the sample rate cannot be checked; it is assumed to be
                  8000 Hz.

          Returns:
              ``{'data': waveform}`` where ``waveform`` is a float32 tensor.

          Raises:
              ValueError: if the decoded file's sample rate is not 8000 Hz.
              TypeError: if ``inputs`` is neither ``str`` nor ``bytes``.
          """
          if isinstance(inputs, str):
              file_bytes = File.read(inputs)
              data, fs = sf.read(io.BytesIO(file_bytes), dtype='float32')
              if fs != 8000:
                  raise ValueError(
                      'modelscope error: The audio sample rate should be 8000')
              # NOTE(review): for multi-channel files sf.read yields a 2-D
              # (frames, channels) array, while forward() treats the input as a
              # single mono waveform -- confirm whether downmixing is wanted.
              data = torch.from_numpy(data)
          elif isinstance(inputs, bytes):
              # frombuffer returns a read-only view of ``inputs``; copy it so
              # torch receives a writable array (from_numpy warns otherwise).
              data = torch.from_numpy(
                  numpy.frombuffer(inputs, dtype=numpy.float32).copy())
          else:
              raise TypeError(
                  'modelscope error: unsupported input type '
                  f'{type(inputs).__name__}; expected a file path/URL (str) '
                  'or raw float32 PCM (bytes)')
          return dict(data=data)

      def postprocess(self, inputs: Dict[str, Any],
                      **post_params) -> Dict[str, Any]:
          """Return the forward() output unchanged."""
          return inputs

      def forward(self, inputs: Dict[str, Any],
                  **forward_params) -> Dict[str, Any]:
          """Separate the mixture and encode each estimated source as PCM.

          Args:
              inputs: output of ``preprocess``; ``inputs['data']`` is the mixture
                  waveform (tensor or array-like).

          Returns:
              ``{OutputKeys.OUTPUT_PCM_LIST: [pcm_bytes, ...]}`` with one entry
              per speaker. Each entry is int16 PCM peak-normalised to half of
              full scale; an all-zero estimate stays all-zero.
          """
          logger.info('Start forward...')
          # as_tensor is a no-op for tensors and also accepts numpy arrays.
          mix = torch.as_tensor(inputs['data']).to(self.device)
          # Add a leading batch dimension of size 1.
          mix = mix.unsqueeze(0)
          est_source = self.model(mix)
          result = []
          for ns in range(self.model.num_spks):
              # Batch item 0, all samples, speaker ``ns``
              # (assumes est_source is (batch, time, num_spks) -- confirm).
              signal = est_source[0, :, ns].detach()
              peak = signal.abs().max()
              # Guard against dividing by zero on a silent estimate, which
              # would yield NaNs whose int16 cast is undefined.
              if peak > 0:
                  signal = signal / peak * 0.5
              # Peak 0.5 * 32768 = 16384, so the int16 cast cannot overflow.
              pcm = (signal.cpu().numpy() * 32768).astype(numpy.int16)
              result.append(pcm.tobytes())
          logger.info('Finish forward.')
          return {OutputKeys.OUTPUT_PCM_LIST: result}