image_detection_pipeline.py 1.8 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859
  1. # Copyright (c) Alibaba, Inc. and its affiliates.
  2. from typing import Any, Dict
  3. import numpy as np
  4. from modelscope.metainfo import Pipelines
  5. from modelscope.outputs import OutputKeys
  6. from modelscope.pipelines.base import Input, Pipeline
  7. from modelscope.pipelines.builder import PIPELINES
  8. from modelscope.preprocessors import LoadImage
  9. from modelscope.utils.constant import Tasks
  @PIPELINES.register_module(
      Tasks.human_detection, module_name=Pipelines.human_detection)
  @PIPELINES.register_module(
      Tasks.image_object_detection, module_name=Pipelines.object_detection)
  @PIPELINES.register_module(
      Tasks.image_object_detection,
      module_name=Pipelines.abnormal_object_detection)
  class ImageDetectionPipeline(Pipeline):
      """Detection pipeline registered for human, object and abnormal-object
      detection.

      Each stage is a thin wrapper that hands its data to the wrapped model's
      ``preprocess``, ``inference`` and ``postprocess`` methods in turn.
      """

      def __init__(self, model: str, **kwargs):
          """Build the pipeline around a hub model.

          Args:
              model: model id on modelscope hub.
              **kwargs: forwarded unchanged to ``Pipeline.__init__``;
                  ``auto_collate`` is always passed as ``False``.
          """
          super().__init__(model=model, auto_collate=False, **kwargs)

      def preprocess(self, input: Input) -> Dict[str, Any]:
          """Turn ``input`` into a model-ready image.

          The input is converted to an ndarray, cast to float64 and then passed
          through the model's own ``preprocess``.

          Returns:
              A dict holding the prepared image under the ``'img'`` key.
          """
          # Cast before handing off so the model always receives float64 data.
          float_img = LoadImage.convert_to_ndarray(input).astype(np.float64)
          return {'img': self.model.preprocess(float_img)}

      def forward(self, input: Dict[str, Any]) -> Dict[str, Any]:
          """Run model inference on the image produced by ``preprocess``.

          Returns:
              A dict holding the raw inference output under the ``'data'`` key.
          """
          return {'data': self.model.inference(input['img'])}

      def postprocess(self, inputs: Dict[str, Any]) -> Dict[str, Any]:
          """Decode raw inference output into scores, labels and boxes.

          The model's ``postprocess`` yields a ``(bboxes, scores, labels)``
          triple. When ``bboxes`` is ``None`` the result is treated as having
          no detections, and every output key maps to an empty list.

          Returns:
              A dict keyed by ``OutputKeys.SCORES``, ``OutputKeys.LABELS`` and
              ``OutputKeys.BOXES``.
          """
          bboxes, scores, labels = self.model.postprocess(inputs['data'])
          nothing_found = bboxes is None
          # A fresh empty list per key, so callers never share a mutable result.
          return {
              OutputKeys.SCORES: [] if nothing_found else scores,
              OutputKeys.LABELS: [] if nothing_found else labels,
              OutputKeys.BOXES: [] if nothing_found else bboxes,
          }