pipeline_template.py 2.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687
  1. # Copyright (c) Alibaba, Inc. and its affiliates.
  2. from typing import Any, Dict
  3. import numpy as np
  4. from modelscope.metainfo import Pipelines
  5. from modelscope.models.base.base_model import Model
  6. from modelscope.outputs.outputs import OutputKeys
  7. from modelscope.pipelines.base import Pipeline
  8. from modelscope.pipelines.builder import PIPELINES
  9. from modelscope.utils.constant import Tasks
  10. __all__ = ['PipelineTemplate']
  11. @PIPELINES.register_module(
  12. Tasks.task_template, module_name=Pipelines.pipeline_template)
  13. class PipelineTemplate(Pipeline):
  14. """A pipeline template explain how to define parameters and input and
  15. output information. As a rule, the first parameter is the input,
  16. followed by the request parameters. The parameter must add type
  17. hint information, and set the default value if necessary,
  18. for the convenience of use.
  19. """
  20. def __init__(self, model: Model, **kwargs):
  21. """A pipeline template to describe input and
  22. output and parameter processing
  23. Args:
  24. model: A Model instance.
  25. """
  26. # call base init.
  27. super().__init__(model=model, **kwargs)
  28. def preprocess(self,
  29. input: Any,
  30. max_length: int = 1024,
  31. top_p: float = 0.8) -> Any:
  32. """Pipeline preprocess interface.
  33. Args:
  34. input (Any): The pipeline input, ref Tasks.task_template TASK_INPUTS.
  35. max_length (int, optional): The max_length parameter. Defaults to 1024.
  36. top_p (float, optional): The top_p parameter. Defaults to 0.8.
  37. Returns:
  38. Any: Return result process by forward.
  39. """
  40. pass
  41. def forward(self,
  42. input: Any,
  43. max_length: int = 1024,
  44. top_p: float = 0.8) -> Any:
  45. """The forward interface.
  46. Args:
  47. input (Any): The output of the preprocess.
  48. max_length (int, optional): max_length. Defaults to 1024.
  49. top_p (float, optional): top_p. Defaults to 0.8.
  50. Returns:
  51. Any: Return result process by postprocess.
  52. """
  53. pass
  54. def postprocess(self,
  55. inputs: Any,
  56. postprocess_param1: str = None) -> Dict[str, Any]:
  57. """The postprocess interface.
  58. Args:
  59. input (Any): The output of the forward.
  60. max_length (int, optional): max_length. Defaults to 1024.
  61. top_p (float, optional): top_p. Defaults to 0.8.
  62. Returns:
  63. Any: Return result process by postprocess.
  64. """
  65. result = {
  66. OutputKeys.BOXES: np.zeros(4),
  67. OutputKeys.OUTPUT_IMG: np.zeros(10, 4),
  68. OutputKeys.TEXT_EMBEDDING: np.zeros(1, 1000)
  69. }
  70. return result