# asr_pipeline.py (2.3 KB)
  1. # Copyright (c) Alibaba, Inc. and its affiliates.
  2. from typing import Any, Dict, Optional, Union
  3. import torch
  4. from modelscope.metainfo import Pipelines
  5. from modelscope.models.multi_modal import MPlugForAllTasks, OfaForAllTasks
  6. from modelscope.pipelines.base import Model, Pipeline
  7. from modelscope.pipelines.builder import PIPELINES
  8. from modelscope.pipelines.util import batch_process
  9. from modelscope.preprocessors import (MPlugPreprocessor, OfaPreprocessor,
  10. Preprocessor)
  11. from modelscope.utils.constant import Tasks
  12. from modelscope.utils.logger import get_logger
  13. logger = get_logger()
  14. @PIPELINES.register_module(
  15. Tasks.auto_speech_recognition, module_name=Pipelines.ofa_asr)
  16. class AutomaticSpeechRecognitionPipeline(Pipeline):
  17. def __init__(self,
  18. model: Union[Model, str],
  19. preprocessor: Optional[Preprocessor] = None,
  20. **kwargs):
  21. """
  22. use `model` and `preprocessor` to create an automatic speech recognition pipeline for prediction
  23. Args:
  24. model: model id on modelscope hub.
  25. """
  26. assert isinstance(model, str) or isinstance(model, Model), \
  27. 'model must be a single str or OfaForAllTasks'
  28. if isinstance(model, str):
  29. pipe_model = Model.from_pretrained(model)
  30. elif isinstance(model, Model):
  31. pipe_model = model
  32. else:
  33. raise NotImplementedError
  34. pipe_model.model.eval()
  35. if preprocessor is None:
  36. if isinstance(pipe_model, OfaForAllTasks):
  37. preprocessor = OfaPreprocessor(pipe_model.model_dir)
  38. elif isinstance(pipe_model, MPlugForAllTasks):
  39. preprocessor = MPlugPreprocessor(pipe_model.model_dir)
  40. super().__init__(model=pipe_model, preprocessor=preprocessor, **kwargs)
  41. def _batch(self, data):
  42. if isinstance(self.model, OfaForAllTasks):
  43. return batch_process(self.model, data)
  44. else:
  45. return super(AutomaticSpeechRecognitionPipeline, self)._batch(data)
  46. def forward(self, inputs: Dict[str, Any],
  47. **forward_params) -> Dict[str, Any]:
  48. with torch.no_grad():
  49. return super().forward(inputs, **forward_params)
  50. def postprocess(self, inputs: Dict[str, Any]) -> Dict[str, Any]:
  51. return inputs