text2sql_pipeline.py 1.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051
  1. # Copyright (c) Alibaba, Inc. and its affiliates.
  2. from typing import Any, Dict, Optional, Union
  3. import torch
  4. from modelscope.metainfo import Pipelines
  5. from modelscope.models.multi_modal import OfaForAllTasks
  6. from modelscope.pipelines.base import Model, Pipeline
  7. from modelscope.pipelines.builder import PIPELINES
  8. from modelscope.pipelines.util import batch_process
  9. from modelscope.preprocessors import OfaPreprocessor, Preprocessor
  10. from modelscope.utils.constant import Tasks
  11. from modelscope.utils.logger import get_logger
# Module-level logger from modelscope's logging utility; note that the code in
# this file does not currently reference it.
logger = get_logger()
  13. @PIPELINES.register_module(Tasks.text2sql, module_name=Pipelines.ofa_text2sql)
  14. class TextToSqlPipeline(Pipeline):
  15. R"""
  16. pipeline for text to sql task
  17. """
  18. def __init__(self,
  19. model: Union[Model, str],
  20. preprocessor: Optional[Preprocessor] = None,
  21. **kwargs):
  22. """
  23. use `model` and `preprocessor` to create a pipeline for text2sql task
  24. Args:
  25. model: model id on modelscope hub.
  26. """
  27. super().__init__(model=model, preprocessor=preprocessor, **kwargs)
  28. self.model.eval()
  29. if preprocessor is None:
  30. if isinstance(self.model, OfaForAllTasks):
  31. self.preprocessor = OfaPreprocessor(self.model.model_dir)
  32. def _batch(self, data):
  33. if isinstance(self.model, OfaForAllTasks):
  34. return batch_process(self.model, data)
  35. else:
  36. return super(TextToSqlPipeline, self)._batch(data)
  37. def forward(self, inputs: Dict[str, Any],
  38. **forward_params) -> Dict[str, Any]:
  39. with torch.no_grad():
  40. return super().forward(inputs, **forward_params)
  41. def postprocess(self, inputs: Dict[str, Any]) -> Dict[str, Any]:
  42. return inputs