face_detection_pipeline.py 3.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100
  1. # Copyright (c) Alibaba, Inc. and its affiliates.
  2. import os.path as osp
  3. from typing import Any, Dict, List, Union
  4. import cv2
  5. import numpy as np
  6. import PIL
  7. import torch
  8. from modelscope.metainfo import Pipelines
  9. from modelscope.models.base.base_model import Model
  10. from modelscope.models.cv.face_detection import ScrfdDetect, SCRFDPreprocessor
  11. from modelscope.outputs import OutputKeys
  12. from modelscope.pipelines.base import Input, Pipeline
  13. from modelscope.pipelines.builder import PIPELINES
  14. from modelscope.preprocessors import LoadImage
  15. from modelscope.utils.config import Config
  16. from modelscope.utils.constant import ModelFile, Tasks
  17. from modelscope.utils.input_output_typing import Image
  18. from modelscope.utils.logger import get_logger
  19. logger = get_logger()
  20. @PIPELINES.register_module(
  21. Tasks.face_detection, module_name=Pipelines.face_detection)
  22. class FaceDetectionPipeline(Pipeline):
  23. def __init__(self, model: str, **kwargs):
  24. """
  25. use `model` to create a face detection pipeline for prediction
  26. Args:
  27. model (`str` or `Model`): model_id or `ScrfdDetect` or `TinyMogDetect` model.
  28. preprocessor(`Preprocessor`, *optional*, defaults to None): `SCRFDPreprocessor`.
  29. """
  30. super().__init__(model=model, **kwargs)
  31. config_path = osp.join(model, ModelFile.CONFIGURATION)
  32. cfg = Config.from_file(config_path)
  33. cfg_model = getattr(cfg, 'model', None)
  34. if cfg_model is None:
  35. # backward compatibility
  36. detector = ScrfdDetect(model_dir=model, **kwargs)
  37. else:
  38. assert isinstance(self.model,
  39. Model), 'model object is not initialized.'
  40. detector = self.model.to(self.device)
  41. # backward compatibility
  42. if self.preprocessor is None:
  43. self.preprocessor = SCRFDPreprocessor()
  44. self.detector = detector
  45. def __call__(self, input: Union[Image, List[Image]], **kwargs):
  46. """
  47. Detect objects (bounding boxes or keypoints) in the image(s) passed as inputs.
  48. Args:
  49. input (`Image` or `List[Image]`):
  50. The pipeline handles three types of images:
  51. - A string containing an HTTP(S) link pointing to an image
  52. - A string containing a local path to an image
  53. - An image loaded in PIL directly
  54. The pipeline accepts either a single image or a batch of images. Images in a batch must all be in the
  55. same format.
  56. Return:
  57. A dictionary of result or a list of dictionary of result. If the input is an image, a dictionary
  58. is returned. If input is a list of image, a list of dictionary is returned.
  59. The dictionary contain the following keys:
  60. - **scores** (`List[float]`) -- The detection score for each card in the image.
  61. - **boxes** (`List[float]) -- The bounding boxe [x1, y1, x2, y2] of detected objects in in image's
  62. original size.
  63. - **keypoints** (`List[Dict[str, int]]`, optional) -- The corner kepoint [x1, y1, x2, y2, x3, y3, x4, y4]
  64. of detected object in image's original size.
  65. """
  66. return super().__call__(input, **kwargs)
  67. def preprocess(self, input: Image) -> Dict[str, Any]:
  68. result = self.preprocessor(input)
  69. # openmmlab model compatibility
  70. if 'img_metas' in result:
  71. from mmcv.parallel import collate, scatter
  72. result = collate([result], samples_per_gpu=1)
  73. if next(self.model.parameters()).is_cuda:
  74. # scatter to specified GPU
  75. result = scatter(result,
  76. [next(self.model.parameters()).device])[0]
  77. return result
  78. def forward(self, input: Dict[str, Any]) -> Dict[str, Any]:
  79. return self.detector(**input)
  80. def postprocess(self, inputs: Dict[str, Any]) -> Dict[str, Any]:
  81. return inputs