card_detection_pipeline.py 4.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132
  1. # Copyright (c) Alibaba, Inc. and its affiliates.
  2. from typing import Any, Dict, List, Union
  3. from modelscope.metainfo import Pipelines
  4. from modelscope.models.base.base_model import Model
  5. from modelscope.pipelines.base import Pipeline
  6. from modelscope.pipelines.builder import PIPELINES
  7. from modelscope.utils.constant import Tasks
  8. from modelscope.utils.input_output_typing import Image
  9. from modelscope.utils.logger import get_logger
  logger = get_logger()


  @PIPELINES.register_module(
      Tasks.card_detection, module_name=Pipelines.card_detection)
  class CardDetectionPipeline(Pipeline):
      r"""Card detection pipeline.

      Runs a card-detection model on one or more images and returns, per image,
      the detected card bounding boxes, their four corner keypoints and the
      detection scores.

      Examples:
          >>> from modelscope.pipelines import pipeline
          >>> detector = pipeline('card-detection', 'damo/cv_resnet_carddetection_scrfd34gkps')
          >>> detector("http://www.modelscope.cn/api/v1/models/damo/cv_resnet_carddetection_scrfd34gkps/repo?Revision=master"
          ...          "&FilePath=description/card_detection1.jpg")
          {
              "boxes": [
                  [446.9007568359375, 36.374977111816406, 907.0919189453125, 337.439208984375],
                  [454.3310241699219, 336.08477783203125, 921.26904296875, 641.7871704101562]
              ],
              "keypoints": [
                  [457.34710693359375, 339.02044677734375, 446.72271728515625, 52.899078369140625,
                   902.8200073242188, 35.063236236572266, 908.5877685546875, 325.62030029296875],
                  [465.2864074707031, 642.8411254882812, 454.38568115234375, 357.4076232910156,
                   902.5343017578125, 334.18377685546875, 922.0982055664062, 621.0704345703125]
              ],
              "scores": [0.9296008944511414, 0.9260380268096924]
          }
      """

      def __init__(self, model: Union[str, Model], **kwargs):
          """Create a card detection pipeline for prediction.

          Args:
              model: model id on the ModelScope hub, or an already-built
                  ``ScrfdDetect`` ``Model`` instance.
              **kwargs: forwarded unchanged to the base ``Pipeline``; e.g.
                  ``preprocessor`` (expected to be a ``SCRFDPreprocessor``).

          Raises:
              AssertionError: if ``self.model`` is not a ``Model`` after the
                  base class initialisation (only when asserts are enabled).
          """
          super().__init__(model=model, **kwargs)
          # super().__init__ is expected to populate self.model from `model`.
          assert isinstance(self.model,
                            Model), 'model object is not initialized.'
          # Place the model on the pipeline's device; `forward` invokes the
          # result through `self.detector`.
          self.detector = self.model.to(self.device)

      def __call__(self, input: Union[Image, List[Image]], **kwargs):
          """Detect cards (bounding boxes and corner keypoints) in the input image(s).

          Args:
              input (`Image` or `List[Image]`):
                  The pipeline handles three types of images:

                  - A string containing an HTTP(S) link pointing to an image
                  - A string containing a local path to an image
                  - An image loaded in PIL or opencv directly

                  The pipeline accepts either a single image or a batch of
                  images. Images in a batch must all be in the same format.
              **kwargs: forwarded unchanged to ``Pipeline.__call__``.

          Return:
              A dictionary of results, or a list of such dictionaries. If the
              input is a single image, a dictionary is returned; if it is a list
              of images, a list of dictionaries is returned. Each dictionary
              contains the following keys:

              - **scores** (`List[float]`) -- The detection score for each card
                in the image.
              - **boxes** (`List[List[float]]`) -- The bounding box
                [x1, y1, x2, y2] of each detected card, in the image's original
                size.
              - **keypoints** (`List[List[float]]`, optional) -- The four corner
                keypoints [x1, y1, x2, y2, x3, y3, x4, y4] of each detected
                card, in the image's original size.
          """
          return super().__call__(input, **kwargs)

      def preprocess(self, input: Image) -> Dict[str, Any]:
          """Run the configured preprocessor on a single image.

          If the preprocessor output contains ``img_metas`` (OpenMMLab-style
          input), it is collated into a batch of one and, when the model's
          parameters live on a CUDA device, scattered onto that device.

          Args:
              input: a single image (URL, local path, or loaded image).

          Returns:
              The (possibly collated/scattered) preprocessor output.
          """
          result = self.preprocessor(input)
          # openmmlab model compatibility
          if 'img_metas' in result:
              from mmcv.parallel import collate, scatter
              result = collate([result], samples_per_gpu=1)
              # Look the first parameter up once. A model without parameters
              # gives no device to scatter to, so skip scattering rather than
              # letting a bare StopIteration escape.
              first_param = next(iter(self.model.parameters()), None)
              if first_param is not None and first_param.is_cuda:
                  # scatter to specified GPU
                  result = scatter(result, [first_param.device])[0]
          return result

      def forward(self, input: Dict[str, Any]) -> Dict[str, Any]:
          """Call the device-placed detector with the preprocessed inputs as keyword arguments."""
          return self.detector(**input)

      def postprocess(self, inputs: Dict[str, Any]) -> Dict[str, Any]:
          """Return the model outputs unchanged; no post-processing is applied here."""
          return inputs