modular_yolos.py 8.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190
  1. from typing import Optional, Union
  2. import torch
  3. from transformers.models.detr.image_processing_detr_fast import DetrImageProcessorFast
  4. from ...image_transforms import center_to_corners_format
  5. from ...utils import (
  6. TensorType,
  7. logging,
  8. )
  9. logger = logging.get_logger(__name__)
  10. def get_size_with_aspect_ratio(
  11. image_size: tuple[int, int], size: int, max_size: Optional[int] = None, mod_size: int = 16
  12. ) -> tuple[int, int]:
  13. """
  14. Computes the output image size given the input image size and the desired output size with multiple of divisible_size.
  15. Args:
  16. image_size (`tuple[int, int]`):
  17. The input image size.
  18. size (`int`):
  19. The desired output size.
  20. max_size (`int`, *optional*):
  21. The maximum allowed output size.
  22. mod_size (`int`, *optional*):
  23. The size to make multiple of mod_size.
  24. """
  25. height, width = image_size
  26. raw_size = None
  27. if max_size is not None:
  28. min_original_size = float(min((height, width)))
  29. max_original_size = float(max((height, width)))
  30. if max_original_size / min_original_size * size > max_size:
  31. raw_size = max_size * min_original_size / max_original_size
  32. size = int(round(raw_size))
  33. if width < height:
  34. ow = size
  35. if max_size is not None and raw_size is not None:
  36. oh = int(raw_size * height / width)
  37. else:
  38. oh = int(size * height / width)
  39. elif (height <= width and height == size) or (width <= height and width == size):
  40. oh, ow = height, width
  41. else:
  42. oh = size
  43. if max_size is not None and raw_size is not None:
  44. ow = int(raw_size * width / height)
  45. else:
  46. ow = int(size * width / height)
  47. if mod_size is not None:
  48. ow_mod = torch.remainder(torch.tensor(ow), mod_size).item()
  49. oh_mod = torch.remainder(torch.tensor(oh), mod_size).item()
  50. ow = ow - ow_mod
  51. oh = oh - oh_mod
  52. return (oh, ow)
  53. class YolosImageProcessorFast(DetrImageProcessorFast):
  54. def post_process(self, outputs, target_sizes):
  55. """
  56. Converts the raw output of [`YolosForObjectDetection`] into final bounding boxes in (top_left_x,
  57. top_left_y, bottom_right_x, bottom_right_y) format. Only supports PyTorch.
  58. Args:
  59. outputs ([`YolosObjectDetectionOutput`]):
  60. Raw outputs of the model.
  61. target_sizes (`torch.Tensor` of shape `(batch_size, 2)`):
  62. Tensor containing the size (height, width) of each image of the batch. For evaluation, this must be the
  63. original image size (before any data augmentation). For visualization, this should be the image size
  64. after data augment, but before padding.
  65. Returns:
  66. `list[Dict]`: A list of dictionaries, each dictionary containing the scores, labels and boxes for an image
  67. in the batch as predicted by the model.
  68. """
  69. logger.warning_once(
  70. "`post_process` is deprecated and will be removed in v5 of Transformers, please use"
  71. " `post_process_object_detection` instead, with `threshold=0.` for equivalent results.",
  72. )
  73. out_logits, out_bbox = outputs.logits, outputs.pred_boxes
  74. if len(out_logits) != len(target_sizes):
  75. raise ValueError("Make sure that you pass in as many target sizes as the batch dimension of the logits")
  76. if target_sizes.shape[1] != 2:
  77. raise ValueError("Each element of target_sizes must contain the size (h, w) of each image of the batch")
  78. prob = out_logits.sigmoid()
  79. topk_values, topk_indexes = torch.topk(prob.view(out_logits.shape[0], -1), 100, dim=1)
  80. scores = topk_values
  81. topk_boxes = torch.div(topk_indexes, out_logits.shape[2], rounding_mode="floor")
  82. labels = topk_indexes % out_logits.shape[2]
  83. boxes = center_to_corners_format(out_bbox)
  84. boxes = torch.gather(boxes, 1, topk_boxes.unsqueeze(-1).repeat(1, 1, 4))
  85. # and from relative [0, 1] to absolute [0, height] coordinates
  86. img_h, img_w = target_sizes.unbind(1)
  87. scale_fct = torch.stack([img_w, img_h, img_w, img_h], dim=1)
  88. boxes = boxes * scale_fct[:, None, :]
  89. results = [{"scores": s, "labels": l, "boxes": b} for s, l, b in zip(scores, labels, boxes)]
  90. return results
  91. def post_process_object_detection(
  92. self, outputs, threshold: float = 0.5, target_sizes: Union[TensorType, list[tuple]] = None, top_k: int = 100
  93. ):
  94. """
  95. Converts the raw output of [`YolosForObjectDetection`] into final bounding boxes in (top_left_x,
  96. top_left_y, bottom_right_x, bottom_right_y) format. Only supports PyTorch.
  97. Args:
  98. outputs ([`YolosObjectDetectionOutput`]):
  99. Raw outputs of the model.
  100. threshold (`float`, *optional*):
  101. Score threshold to keep object detection predictions.
  102. target_sizes (`torch.Tensor` or `list[tuple[int, int]]`, *optional*):
  103. Tensor of shape `(batch_size, 2)` or list of tuples (`tuple[int, int]`) containing the target size
  104. (height, width) of each image in the batch. If left to None, predictions will not be resized.
  105. top_k (`int`, *optional*, defaults to 100):
  106. Keep only top k bounding boxes before filtering by thresholding.
  107. Returns:
  108. `list[Dict]`: A list of dictionaries, each dictionary containing the scores, labels and boxes for an image
  109. in the batch as predicted by the model.
  110. """
  111. out_logits, out_bbox = outputs.logits, outputs.pred_boxes
  112. if target_sizes is not None:
  113. if len(out_logits) != len(target_sizes):
  114. raise ValueError(
  115. "Make sure that you pass in as many target sizes as the batch dimension of the logits"
  116. )
  117. prob = out_logits.sigmoid()
  118. prob = prob.view(out_logits.shape[0], -1)
  119. k_value = min(top_k, prob.size(1))
  120. topk_values, topk_indexes = torch.topk(prob, k_value, dim=1)
  121. scores = topk_values
  122. topk_boxes = torch.div(topk_indexes, out_logits.shape[2], rounding_mode="floor")
  123. labels = topk_indexes % out_logits.shape[2]
  124. boxes = center_to_corners_format(out_bbox)
  125. boxes = torch.gather(boxes, 1, topk_boxes.unsqueeze(-1).repeat(1, 1, 4))
  126. # and from relative [0, 1] to absolute [0, height] coordinates
  127. if target_sizes is not None:
  128. if isinstance(target_sizes, list):
  129. img_h = torch.Tensor([i[0] for i in target_sizes])
  130. img_w = torch.Tensor([i[1] for i in target_sizes])
  131. else:
  132. img_h, img_w = target_sizes.unbind(1)
  133. scale_fct = torch.stack([img_w, img_h, img_w, img_h], dim=1).to(boxes.device)
  134. boxes = boxes * scale_fct[:, None, :]
  135. results = []
  136. for s, l, b in zip(scores, labels, boxes):
  137. score = s[s > threshold]
  138. label = l[s > threshold]
  139. box = b[s > threshold]
  140. results.append({"scores": score, "labels": label, "boxes": box})
  141. return results
  142. def post_process_segmentation(self):
  143. raise NotImplementedError("Segmentation post-processing is not implemented for Deformable DETR yet.")
  144. def post_process_instance(self):
  145. raise NotImplementedError("Instance post-processing is not implemented for Deformable DETR yet.")
  146. def post_process_panoptic(self):
  147. raise NotImplementedError("Panoptic post-processing is not implemented for Deformable DETR yet.")
  148. def post_process_instance_segmentation(self):
  149. raise NotImplementedError("Segmentation post-processing is not implemented for Deformable DETR yet.")
  150. def post_process_semantic_segmentation(self):
  151. raise NotImplementedError("Semantic segmentation post-processing is not implemented for Deformable DETR yet.")
  152. def post_process_panoptic_segmentation(self):
  153. raise NotImplementedError("Panoptic segmentation post-processing is not implemented for Deformable DETR yet.")
  154. __all__ = ["YolosImageProcessorFast"]