sudoku_pipeline.py 1.8 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253
  1. # Copyright (c) Alibaba, Inc. and its affiliates.
  2. from typing import Any, Dict, Optional, Union
  3. import torch
  4. from modelscope.metainfo import Pipelines
  5. from modelscope.models.multi_modal import OfaForAllTasks
  6. from modelscope.pipelines.base import Model, Pipeline
  7. from modelscope.pipelines.builder import PIPELINES
  8. from modelscope.pipelines.util import batch_process
  9. from modelscope.preprocessors import OfaPreprocessor, Preprocessor
  10. from modelscope.utils.constant import Tasks
  11. from modelscope.utils.logger import get_logger
# Module-level logger from modelscope's logging utilities.
logger = get_logger()
  13. @PIPELINES.register_module(Tasks.sudoku, module_name=Pipelines.ofa_sudoku)
  14. class SudokuPipeline(Pipeline):
  15. R"""
  16. pipeline for sudoku solving
  17. """
  18. def __init__(self,
  19. model: Union[Model, str],
  20. preprocessor: Optional[Preprocessor] = None,
  21. **kwargs):
  22. """
  23. use `model` and `preprocessor` to create a pipeline for solving sudoku
  24. Args:
  25. model: model id on modelscope hub.
  26. """
  27. super().__init__(model=model, preprocessor=preprocessor, **kwargs)
  28. self.model.eval()
  29. if preprocessor is None:
  30. if isinstance(self.model, OfaForAllTasks):
  31. self.preprocessor = OfaPreprocessor(self.model.model_dir)
  32. else:
  33. raise 'no preprocessor is provided'
  34. def _batch(self, data):
  35. if isinstance(self.model, OfaForAllTasks):
  36. return batch_process(self.model, data)
  37. else:
  38. return super(SudokuPipeline, self)._batch(data)
  39. def forward(self, inputs: Dict[str, Any],
  40. **forward_params) -> Dict[str, Any]:
  41. with torch.no_grad():
  42. return super().forward(inputs, **forward_params)
  43. def postprocess(self, inputs: Dict[str, Any]) -> Dict[str, Any]:
  44. return inputs