ssr_pipeline.py 1.5 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849
  1. # Copyright (c) Alibaba, Inc. and its affiliates.
  2. from typing import Any, Dict
  3. import numpy as np
  4. import torch
  5. from modelscope.metainfo import Pipelines
  6. from modelscope.outputs import OutputKeys
  7. from modelscope.pipelines.base import Input, Pipeline
  8. from modelscope.pipelines.builder import PIPELINES
  9. from modelscope.utils.constant import Tasks
  10. @PIPELINES.register_module(
  11. Tasks.speech_super_resolution,
  12. module_name=Pipelines.speech_super_resolution_inference)
  13. class SSRPipeline(Pipeline):
  14. r"""ANS (Acoustic Noise Suppression) Inference Pipeline .
  15. When invoke the class with pipeline.__call__(), it accept only one
  16. parameter:
  17. inputs(str): the path of wav file
  18. """
  19. SAMPLE_RATE = 48000
  20. def __init__(self, model, **kwargs):
  21. """
  22. use `model` and `preprocessor` to create a kws pipeline for prediction
  23. Args:
  24. model: model id on modelscope hub.
  25. """
  26. super().__init__(model=model, **kwargs)
  27. self.model.eval()
  28. self.stream_mode = kwargs.get('stream_mode', False)
  29. def preprocess(self, inputs: Input, **preprocess_params) -> Dict[str, Any]:
  30. return inputs
  31. def forward(self, inputs: Dict[str, Any],
  32. **forward_params) -> Dict[str, Any]:
  33. with torch.no_grad():
  34. outputs = self.model(inputs)
  35. outputs *= 32768.
  36. outputs = np.array(outputs, 'int16').tobytes()
  37. return {OutputKeys.OUTPUT_PCM: outputs}
  38. def postprocess(self, inputs: Dict[str, Any], **kwargs) -> Dict[str, Any]:
  39. return inputs