tinynas_detection_pipeline.py 3.2 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182
  1. # Copyright (c) Alibaba, Inc. and its affiliates.
  2. from typing import Any, Dict, Optional, Union
  3. from modelscope.metainfo import Pipelines
  4. from modelscope.outputs import OutputKeys
  5. from modelscope.outputs.cv_outputs import DetectionOutput
  6. from modelscope.pipelines.base import Input, Pipeline
  7. from modelscope.pipelines.builder import PIPELINES
  8. from modelscope.preprocessors import LoadImage, Preprocessor
  9. from modelscope.utils.constant import Tasks
  10. from modelscope.utils.cv.image_utils import \
  11. show_image_object_detection_auto_result
  12. from modelscope.utils.logger import get_logger
  13. logger = get_logger()
  14. @PIPELINES.register_module(
  15. Tasks.domain_specific_object_detection,
  16. module_name=Pipelines.tinynas_detection)
  17. @PIPELINES.register_module(
  18. Tasks.image_object_detection, module_name=Pipelines.tinynas_detection)
  19. class TinynasDetectionPipeline(Pipeline):
  20. def __init__(self,
  21. model: str,
  22. preprocessor: Optional[Preprocessor] = None,
  23. **kwargs):
  24. """Object detection pipeline, currently only for the tinynas-detection model.
  25. Args:
  26. model: A str format model id or model local dir to build the model instance from.
  27. preprocessor: A preprocessor instance to preprocess the data, if None,
  28. the pipeline will try to build the preprocessor according to the configuration.json file.
  29. kwargs: The args needed by the `Pipeline` class.
  30. """
  31. super().__init__(model=model, preprocessor=preprocessor, **kwargs)
  32. def preprocess(self, input: Input) -> Dict[str, Any]:
  33. img = LoadImage.convert_to_ndarray(input)
  34. return super().preprocess(img)
  35. def forward(
  36. self, input: Dict[str,
  37. Any]) -> Union[Dict[str, Any], DetectionOutput]:
  38. """The forward method of this pipeline.
  39. Args:
  40. input: The input data output from the `preprocess` procedure.
  41. Returns:
  42. A model output, either in a dict format, or in a standard `DetectionOutput` dataclass.
  43. If outputs a dict, these keys are needed:
  44. class_ids (`Tensor`, *optional*): class id for each object.
  45. boxes (`Tensor`, *optional*): Bounding box for each detected object
  46. in [left, top, right, bottom] format.
  47. scores (`Tensor`, *optional*): Detection score for each object.
  48. """
  49. return self.model(input['img'])
  50. def postprocess(
  51. self, inputs: Union[Dict[str, Any],
  52. DetectionOutput]) -> Dict[str, Any]:
  53. bboxes, scores, labels = inputs['boxes'], inputs['scores'], inputs[
  54. 'class_ids']
  55. if bboxes is None:
  56. outputs = {
  57. OutputKeys.SCORES: [],
  58. OutputKeys.LABELS: [],
  59. OutputKeys.BOXES: []
  60. }
  61. else:
  62. outputs = {
  63. OutputKeys.SCORES: scores,
  64. OutputKeys.LABELS: labels,
  65. OutputKeys.BOXES: bboxes
  66. }
  67. return outputs
  68. def show_result(self, img_path, result, save_path=None):
  69. show_image_object_detection_auto_result(img_path, result, save_path)