summarization_pipeline.py 2.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566
  1. # Copyright (c) Alibaba, Inc. and its affiliates.
  2. from typing import Any, Dict, Optional, Union
  3. import torch
  4. from modelscope.metainfo import Pipelines, Preprocessors
  5. from modelscope.pipelines.base import Model, Pipeline
  6. from modelscope.pipelines.builder import PIPELINES
  7. from modelscope.pipelines.util import batch_process
  8. from modelscope.preprocessors import Preprocessor
  9. from modelscope.utils.constant import Fields, Tasks
  10. from modelscope.utils.logger import get_logger
# Module-level logger obtained from modelscope's `get_logger` helper.
logger = get_logger()
  12. @PIPELINES.register_module(
  13. Tasks.text_summarization, module_name=Pipelines.text_generation)
  14. class SummarizationPipeline(Pipeline):
  15. def __init__(self,
  16. model: Union[Model, str],
  17. preprocessor: Optional[Preprocessor] = None,
  18. config_file: str = None,
  19. device: str = 'gpu',
  20. auto_collate=True,
  21. **kwargs):
  22. """Use `model` and `preprocessor` to create a Summarization pipeline for prediction.
  23. Args:
  24. model (str or Model): Supply either a local model dir which supported the summarization task,
  25. or a model id from the model hub, or a model instance.
  26. preprocessor (Preprocessor): An optional preprocessor instance.
  27. kwargs (dict, `optional`):
  28. Extra kwargs passed into the preprocessor's constructor.
  29. """
  30. super().__init__(
  31. model=model,
  32. preprocessor=preprocessor,
  33. config_file=config_file,
  34. device=device,
  35. auto_collate=auto_collate)
  36. self.model.eval()
  37. if preprocessor is None:
  38. if self.model.__class__.__name__ == 'OfaForAllTasks':
  39. self.preprocessor = Preprocessor.from_pretrained(
  40. self.model.model_dir,
  41. type=Preprocessors.ofa_tasks_preprocessor,
  42. field=Fields.multi_modal)
  43. else:
  44. self.preprocessor = Preprocessor.from_pretrained(
  45. self.model.model_dir, **kwargs)
  46. def _batch(self, data):
  47. if self.model.__class__.__name__ == 'OfaForAllTasks':
  48. return batch_process(self.model, data)
  49. else:
  50. return super(SummarizationPipeline, self)._batch(data)
  51. def forward(self, inputs: Dict[str, Any],
  52. **forward_params) -> Dict[str, Any]:
  53. with torch.no_grad():
  54. return super().forward(inputs, **forward_params)
  55. def postprocess(self, inputs: Dict[str, Any]) -> Dict[str, Any]:
  56. return inputs