object_detection.py 8.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203
  1. from typing import TYPE_CHECKING, Any, Union, overload
  2. from ..utils import add_end_docstrings, is_torch_available, is_vision_available, logging, requires_backends
  3. from .base import Pipeline, build_pipeline_init_args
  4. if is_vision_available():
  5. from ..image_utils import load_image
  6. if is_torch_available():
  7. import torch
  8. from ..models.auto.modeling_auto import (
  9. MODEL_FOR_OBJECT_DETECTION_MAPPING_NAMES,
  10. MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING_NAMES,
  11. )
  12. if TYPE_CHECKING:
  13. from PIL import Image
# Module-level logger, obtained through the library's own `logging` utilities (not the stdlib module directly).
logger = logging.get_logger(__name__)
  15. @add_end_docstrings(build_pipeline_init_args(has_image_processor=True))
  16. class ObjectDetectionPipeline(Pipeline):
  17. """
  18. Object detection pipeline using any `AutoModelForObjectDetection`. This pipeline predicts bounding boxes of objects
  19. and their classes.
  20. Example:
  21. ```python
  22. >>> from transformers import pipeline
  23. >>> detector = pipeline(model="facebook/detr-resnet-50")
  24. >>> detector("https://huggingface.co/datasets/Narsil/image_dummy/raw/main/parrots.png")
  25. [{'score': 0.997, 'label': 'bird', 'box': {'xmin': 69, 'ymin': 171, 'xmax': 396, 'ymax': 507}}, {'score': 0.999, 'label': 'bird', 'box': {'xmin': 398, 'ymin': 105, 'xmax': 767, 'ymax': 507}}]
  26. >>> # x, y are expressed relative to the top left hand corner.
  27. ```
  28. Learn more about the basics of using a pipeline in the [pipeline tutorial](../pipeline_tutorial)
  29. This object detection pipeline can currently be loaded from [`pipeline`] using the following task identifier:
  30. `"object-detection"`.
  31. See the list of available models on [huggingface.co/models](https://huggingface.co/models?filter=object-detection).
  32. """
  33. _load_processor = False
  34. _load_image_processor = True
  35. _load_feature_extractor = False
  36. _load_tokenizer = None
  37. def __init__(self, *args, **kwargs):
  38. super().__init__(*args, **kwargs)
  39. if self.framework == "tf":
  40. raise ValueError(f"The {self.__class__} is only available in PyTorch.")
  41. requires_backends(self, "vision")
  42. mapping = MODEL_FOR_OBJECT_DETECTION_MAPPING_NAMES.copy()
  43. mapping.update(MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING_NAMES)
  44. self.check_model_type(mapping)
  45. def _sanitize_parameters(self, **kwargs):
  46. preprocess_params = {}
  47. if "timeout" in kwargs:
  48. preprocess_params["timeout"] = kwargs["timeout"]
  49. postprocess_kwargs = {}
  50. if "threshold" in kwargs:
  51. postprocess_kwargs["threshold"] = kwargs["threshold"]
  52. return preprocess_params, {}, postprocess_kwargs
  53. @overload
  54. def __call__(self, image: Union[str, "Image.Image"], *args: Any, **kwargs: Any) -> list[dict[str, Any]]: ...
  55. @overload
  56. def __call__(
  57. self, image: Union[list[str], list["Image.Image"]], *args: Any, **kwargs: Any
  58. ) -> list[list[dict[str, Any]]]: ...
  59. def __call__(self, *args, **kwargs) -> Union[list[dict[str, Any]], list[list[dict[str, Any]]]]:
  60. """
  61. Detect objects (bounding boxes & classes) in the image(s) passed as inputs.
  62. Args:
  63. inputs (`str`, `list[str]`, `PIL.Image` or `list[PIL.Image]`):
  64. The pipeline handles three types of images:
  65. - A string containing an HTTP(S) link pointing to an image
  66. - A string containing a local path to an image
  67. - An image loaded in PIL directly
  68. The pipeline accepts either a single image or a batch of images. Images in a batch must all be in the
  69. same format: all as HTTP(S) links, all as local paths, or all as PIL images.
  70. threshold (`float`, *optional*, defaults to 0.5):
  71. The probability necessary to make a prediction.
  72. timeout (`float`, *optional*, defaults to None):
  73. The maximum time in seconds to wait for fetching images from the web. If None, no timeout is set and
  74. the call may block forever.
  75. Return:
  76. A list of dictionaries or a list of list of dictionaries containing the result. If the input is a single
  77. image, will return a list of dictionaries, if the input is a list of several images, will return a list of
  78. list of dictionaries corresponding to each image.
  79. The dictionaries contain the following keys:
  80. - **label** (`str`) -- The class label identified by the model.
  81. - **score** (`float`) -- The score attributed by the model for that label.
  82. - **box** (`list[dict[str, int]]`) -- The bounding box of detected object in image's original size.
  83. """
  84. # After deprecation of this is completed, remove the default `None` value for `images`
  85. if "images" in kwargs and "inputs" not in kwargs:
  86. kwargs["inputs"] = kwargs.pop("images")
  87. return super().__call__(*args, **kwargs)
  88. def preprocess(self, image, timeout=None):
  89. image = load_image(image, timeout=timeout)
  90. target_size = torch.IntTensor([[image.height, image.width]])
  91. inputs = self.image_processor(images=[image], return_tensors="pt")
  92. if self.framework == "pt":
  93. inputs = inputs.to(self.dtype)
  94. if self.tokenizer is not None:
  95. inputs = self.tokenizer(text=inputs["words"], boxes=inputs["boxes"], return_tensors="pt")
  96. inputs["target_size"] = target_size
  97. return inputs
  98. def _forward(self, model_inputs):
  99. target_size = model_inputs.pop("target_size")
  100. outputs = self.model(**model_inputs)
  101. model_outputs = outputs.__class__({"target_size": target_size, **outputs})
  102. if self.tokenizer is not None:
  103. model_outputs["bbox"] = model_inputs["bbox"]
  104. return model_outputs
  105. def postprocess(self, model_outputs, threshold=0.5):
  106. target_size = model_outputs["target_size"]
  107. if self.tokenizer is not None:
  108. # This is a LayoutLMForTokenClassification variant.
  109. # The OCR got the boxes and the model classified the words.
  110. height, width = target_size[0].tolist()
  111. def unnormalize(bbox):
  112. return self._get_bounding_box(
  113. torch.Tensor(
  114. [
  115. (width * bbox[0] / 1000),
  116. (height * bbox[1] / 1000),
  117. (width * bbox[2] / 1000),
  118. (height * bbox[3] / 1000),
  119. ]
  120. )
  121. )
  122. scores, classes = model_outputs["logits"].squeeze(0).softmax(dim=-1).max(dim=-1)
  123. labels = [self.model.config.id2label[prediction] for prediction in classes.tolist()]
  124. boxes = [unnormalize(bbox) for bbox in model_outputs["bbox"].squeeze(0)]
  125. keys = ["score", "label", "box"]
  126. annotation = [dict(zip(keys, vals)) for vals in zip(scores.tolist(), labels, boxes) if vals[0] > threshold]
  127. else:
  128. # This is a regular ForObjectDetectionModel
  129. raw_annotations = self.image_processor.post_process_object_detection(model_outputs, threshold, target_size)
  130. raw_annotation = raw_annotations[0]
  131. scores = raw_annotation["scores"]
  132. labels = raw_annotation["labels"]
  133. boxes = raw_annotation["boxes"]
  134. raw_annotation["scores"] = scores.tolist()
  135. raw_annotation["labels"] = [self.model.config.id2label[label.item()] for label in labels]
  136. raw_annotation["boxes"] = [self._get_bounding_box(box) for box in boxes]
  137. # {"scores": [...], ...} --> [{"score":x, ...}, ...]
  138. keys = ["score", "label", "box"]
  139. annotation = [
  140. dict(zip(keys, vals))
  141. for vals in zip(raw_annotation["scores"], raw_annotation["labels"], raw_annotation["boxes"])
  142. ]
  143. return annotation
  144. def _get_bounding_box(self, box: "torch.Tensor") -> dict[str, int]:
  145. """
  146. Turns list [xmin, xmax, ymin, ymax] into dict { "xmin": xmin, ... }
  147. Args:
  148. box (`torch.Tensor`): Tensor containing the coordinates in corners format.
  149. Returns:
  150. bbox (`dict[str, int]`): Dict containing the coordinates in corners format.
  151. """
  152. if self.framework != "pt":
  153. raise ValueError("The ObjectDetectionPipeline is only available in PyTorch.")
  154. xmin, ymin, xmax, ymax = box.int().tolist()
  155. bbox = {
  156. "xmin": xmin,
  157. "ymin": ymin,
  158. "xmax": xmax,
  159. "ymax": ymax,
  160. }
  161. return bbox