processing_sam2.py 22 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526
  1. # coding=utf-8
  2. # Copyright 2025 The HuggingFace Inc. team.
  3. #
  4. # Licensed under the Apache License, Version 2.0 (the "License");
  5. # you may not use this file except in compliance with the License.
  6. # You may obtain a copy of the License at
  7. #
  8. # http://www.apache.org/licenses/LICENSE-2.0
  9. #
  10. # Unless required by applicable law or agreed to in writing, software
  11. # distributed under the License is distributed on an "AS IS" BASIS,
  12. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. # See the License for the specific language governing permissions and
  14. # limitations under the License.
  15. """
  16. Processor class for SAM2.
  17. """
  18. from copy import deepcopy
  19. from typing import Optional, Union
  20. import numpy as np
  21. from ...image_utils import ImageInput
  22. from ...processing_utils import ProcessorMixin
  23. from ...tokenization_utils_base import BatchEncoding
  24. from ...utils import TensorType, is_torch_available, logging
  25. from ...utils.import_utils import requires
logger = logging.get_logger(__name__)

# `torch` is only imported when available; the processor class below is decorated with
# `@requires(backends=("torch",))` and uses `torch` throughout its methods.
if is_torch_available():
    import torch
  29. @requires(backends=("torch",))
  30. class Sam2Processor(ProcessorMixin):
  31. r"""
  32. Constructs a SAM2 processor which wraps a SAM2 image processor and an 2D points & Bounding boxes processor into a
  33. single processor.
  34. [`Sam2Processor`] offers all the functionalities of [`Sam2ImageProcessorFast`] and [`Sam2VideoProcessor`]. See the docstring of
  35. [`~Sam2ImageProcessorFast.__call__`] and [`~Sam2VideoProcessor.__call__`] for more information.
  36. Args:
  37. image_processor (`Sam2ImageProcessorFast`):
  38. An instance of [`Sam2ImageProcessorFast`].
  39. target_size (`int`, *optional*):
  40. The target size (target_size, target_size) to which the image will be resized.
  41. point_pad_value (`int`, *optional*, defaults to -10):
  42. The value used for padding input points.
  43. """
  44. attributes = ["image_processor"]
  45. image_processor_class = "Sam2ImageProcessorFast"
  46. def __init__(self, image_processor, target_size: Optional[int] = None, point_pad_value: int = -10, **kwargs):
  47. super().__init__(image_processor, **kwargs)
  48. self.point_pad_value = point_pad_value
  49. self.target_size = target_size if target_size is not None else self.image_processor.size["height"]
  50. def __call__(
  51. self,
  52. images: Optional[ImageInput] = None,
  53. segmentation_maps: Optional[ImageInput] = None,
  54. input_points: Optional[Union[list[list[list[list[float]]]], torch.Tensor]] = None,
  55. input_labels: Optional[Union[list[list[list[int]]], torch.Tensor]] = None,
  56. input_boxes: Optional[Union[list[list[list[float]]], torch.Tensor]] = None,
  57. original_sizes: Optional[Union[list[list[float]], torch.Tensor]] = None,
  58. return_tensors: Optional[Union[str, TensorType]] = None,
  59. **kwargs,
  60. ) -> BatchEncoding:
  61. r"""
  62. This method uses [`Sam2ImageProcessorFast.__call__`] method to prepare image(s) for the model. It also prepares 2D
  63. points and bounding boxes for the model if they are provided.
  64. Args:
  65. images (`ImageInput`, *optional*):
  66. The image(s) to process.
  67. segmentation_maps (`ImageInput`, *optional*):
  68. The segmentation maps to process.
  69. input_points (`list[list[list[list[float]]]]`, `torch.Tensor`, *optional*):
  70. The points to add to the frame.
  71. input_labels (`list[list[list[int]]]`, `torch.Tensor`, *optional*):
  72. The labels for the points.
  73. input_boxes (`list[list[list[float]]]`, `torch.Tensor`, *optional*):
  74. The bounding boxes to add to the frame.
  75. original_sizes (`list[list[float]]`, `torch.Tensor`, *optional*):
  76. The original sizes of the images.
  77. return_tensors (`str` or `TensorType`, *optional*):
  78. The type of tensors to return.
  79. **kwargs:
  80. Additional keyword arguments to pass to the image processor.
  81. Returns:
  82. A [`BatchEncoding`] with the following fields:
  83. - `pixel_values` (`torch.Tensor`): The processed image(s).
  84. - `original_sizes` (`list[list[float]]`): The original sizes of the images.
  85. - `reshaped_input_sizes` (`torch.Tensor`): The reshaped input sizes of the images.
  86. - `labels` (`torch.Tensor`): The processed segmentation maps (if provided).
  87. - `input_points` (`torch.Tensor`): The processed points.
  88. - `input_labels` (`torch.Tensor`): The processed labels.
  89. - `input_boxes` (`torch.Tensor`): The processed bounding boxes.
  90. """
  91. if images is not None:
  92. encoding_image_processor = self.image_processor(
  93. images,
  94. segmentation_maps=segmentation_maps,
  95. return_tensors=return_tensors,
  96. **kwargs,
  97. )
  98. elif original_sizes is not None:
  99. if isinstance(original_sizes, torch.Tensor):
  100. original_sizes = original_sizes.cpu().tolist()
  101. encoding_image_processor = BatchEncoding({"original_sizes": original_sizes}, tensor_type=return_tensors)
  102. else:
  103. raise ValueError("Either images or original_sizes must be provided")
  104. # pop arguments that are not used in the forward but used nevertheless
  105. original_sizes = encoding_image_processor["original_sizes"]
  106. # Check original_sizes is of length 1 or len(images)
  107. if images is not None and len(original_sizes) != 1 and len(original_sizes) != len(images):
  108. raise ValueError(
  109. "original_sizes must be of length 1 or len(images). If you are passing a single image, you must pass a single original_size."
  110. )
  111. # Process input points, labels, and boxes if provided
  112. if input_points is not None or input_labels is not None or input_boxes is not None:
  113. # Validate and convert inputs to standardized format
  114. processed_points = self._validate_single_input(
  115. input_points,
  116. expected_depth=4,
  117. input_name="points",
  118. expected_format="[image level, object level, point level, point coordinates]",
  119. expected_coord_size=2,
  120. )
  121. processed_labels = self._validate_single_input(
  122. input_labels,
  123. expected_depth=3,
  124. input_name="labels",
  125. expected_format="[image level, object level, point level]",
  126. )
  127. processed_boxes = self._validate_single_input(
  128. input_boxes,
  129. expected_depth=3,
  130. input_name="boxes",
  131. expected_format="[image level, box level, box coordinates]",
  132. expected_coord_size=4,
  133. )
  134. # Get padding requirements for all inputs
  135. if processed_points is not None:
  136. points_max_dims = self._get_nested_dimensions(processed_points)[:3]
  137. if processed_labels is not None:
  138. labels_max_dims = self._get_nested_dimensions(processed_labels)[:3]
  139. if processed_boxes is not None:
  140. boxes_max_dims = self._get_nested_dimensions(processed_boxes)[:2]
  141. # Ensure points and labels have consistent dimensions
  142. if processed_points is not None and processed_labels is not None:
  143. if points_max_dims != labels_max_dims:
  144. raise ValueError(
  145. "Input points and labels have inconsistent dimensions. Please ensure they have the same dimensions."
  146. )
  147. # Check that boxes don't need padding (model limitation)
  148. if processed_boxes is not None and len(processed_boxes) >= 2:
  149. if any(len(img_boxes) < boxes_max_dims[1] for img_boxes in processed_boxes):
  150. raise ValueError(
  151. "Input boxes have inconsistent dimensions that would require padding, "
  152. "but boxes cannot be padded due to model limitations. "
  153. "Please ensure all images have the same number of boxes."
  154. )
  155. # Pad and normalize all inputs to final tensor format
  156. if processed_points is not None:
  157. padded_points = self._pad_nested_list(processed_points, points_max_dims + [2])
  158. final_points = torch.tensor(padded_points, dtype=torch.float32)
  159. self._normalize_tensor_coordinates(final_points, original_sizes, preserve_padding=True)
  160. encoding_image_processor.update({"input_points": final_points})
  161. if processed_labels is not None:
  162. padded_labels = self._pad_nested_list(processed_labels, labels_max_dims)
  163. final_labels = torch.tensor(padded_labels, dtype=torch.int64)
  164. encoding_image_processor.update({"input_labels": final_labels})
  165. if processed_boxes is not None:
  166. final_boxes = torch.tensor(processed_boxes, dtype=torch.float32)
  167. self._normalize_tensor_coordinates(final_boxes, original_sizes, is_bounding_box=True)
  168. encoding_image_processor.update({"input_boxes": final_boxes})
  169. return encoding_image_processor
  170. def _normalize_coordinates(
  171. self, target_size: int, coords: "torch.Tensor", original_size, is_bounding_box=False
  172. ) -> "torch.Tensor":
  173. """
  174. Expects a numpy array of length 2 in the final dimension. Requires the original image size in (H, W) format.
  175. Args:
  176. target_size (`int`):
  177. The target size of the image.
  178. coords (`torch.Tensor`):
  179. The coordinates to be normalized.
  180. original_size (`tuple`):
  181. The original size of the image.
  182. is_bounding_box (`bool`, *optional*, defaults to `False`):
  183. Whether the coordinates are bounding boxes.
  184. """
  185. old_h, old_w = original_size
  186. new_h, new_w = target_size, target_size
  187. coords = deepcopy(coords).float()
  188. if is_bounding_box:
  189. coords = coords.reshape(-1, 2, 2)
  190. coords[..., 0] = coords[..., 0] * (new_w / old_w)
  191. coords[..., 1] = coords[..., 1] * (new_h / old_h)
  192. if is_bounding_box:
  193. coords = coords.reshape(-1, 4)
  194. return coords
  195. def _convert_to_nested_list(self, data, expected_depth, current_depth=0):
  196. """
  197. Recursively convert various input formats (tensors, numpy arrays, lists) to nested lists.
  198. Args:
  199. data: Input data in any format
  200. expected_depth: Expected nesting depth
  201. current_depth: Current depth in recursion
  202. Returns:
  203. Nested list representation of the data
  204. """
  205. if data is None:
  206. return None
  207. # Convert tensor/numpy to list if we're at a leaf level or if it's a multi-dimensional array
  208. if isinstance(data, torch.Tensor): # PyTorch tensor
  209. if current_depth == expected_depth - 2 or len(data.shape) <= 2: # At coordinate level or small tensor
  210. return data.numpy().tolist()
  211. else:
  212. return [self._convert_to_nested_list(item, expected_depth, current_depth + 1) for item in data]
  213. elif isinstance(data, np.ndarray): # NumPy array
  214. if current_depth == expected_depth - 2 or len(data.shape) <= 2: # At coordinate level or small array
  215. return data.tolist()
  216. else:
  217. return [self._convert_to_nested_list(item, expected_depth, current_depth + 1) for item in data]
  218. elif isinstance(data, list):
  219. if current_depth == expected_depth:
  220. # We've reached the expected depth, return as is
  221. return data
  222. else:
  223. # Continue recursion
  224. return [self._convert_to_nested_list(item, expected_depth, current_depth + 1) for item in data]
  225. elif isinstance(data, (int, float)):
  226. return data
  227. else:
  228. raise ValueError(f"Unsupported data type: {type(data)}")
  229. def _get_nested_dimensions(self, nested_list, max_dims=None):
  230. """
  231. Get the maximum dimensions at each level of nesting.
  232. Args:
  233. nested_list (`list`):
  234. Nested list structure.
  235. max_dims (`list`, *optional*):
  236. Current maximum dimensions (for recursion).
  237. Returns:
  238. `list`: A list of maximum dimensions for each nesting level.
  239. """
  240. if max_dims is None:
  241. max_dims = []
  242. if not isinstance(nested_list, list):
  243. return max_dims
  244. if len(max_dims) == 0:
  245. max_dims.append(len(nested_list))
  246. else:
  247. max_dims[0] = max(max_dims[0], len(nested_list))
  248. if len(nested_list) > 0:
  249. for item in nested_list:
  250. if isinstance(item, list):
  251. sub_dims = self._get_nested_dimensions(item)
  252. # Merge sub_dims into max_dims
  253. for i, dim in enumerate(sub_dims):
  254. if i + 1 >= len(max_dims):
  255. max_dims.append(dim)
  256. else:
  257. max_dims[i + 1] = max(max_dims[i + 1], dim)
  258. return max_dims
  259. def _pad_nested_list(self, nested_list, target_dims, current_level=0, pad_value=None):
  260. """
  261. Recursively pad a nested list to match target dimensions.
  262. Args:
  263. nested_list (`list`):
  264. Nested list to pad.
  265. target_dims (`list`):
  266. Target dimensions for each level.
  267. current_level (`int`, *optional*, defaults to 0):
  268. Current nesting level.
  269. pad_value (`int`, *optional*):
  270. Value to use for padding.
  271. Returns:
  272. `list`: The padded nested list.
  273. """
  274. if pad_value is None:
  275. pad_value = self.point_pad_value
  276. if current_level >= len(target_dims):
  277. return nested_list
  278. # Ensure we have a list
  279. if not isinstance(nested_list, list):
  280. nested_list = [nested_list]
  281. # Pad current level
  282. current_size = len(nested_list)
  283. target_size = target_dims[current_level]
  284. # Pad with appropriate values
  285. if current_level == len(target_dims) - 1:
  286. # At the coordinate level, pad with pad_value
  287. nested_list.extend([pad_value] * (target_size - current_size))
  288. else:
  289. # At higher levels, pad with nested structures
  290. if current_size > 0:
  291. # Create appropriately sized template
  292. if current_level < len(target_dims) - 2:
  293. # For non-coordinate levels, create empty nested structure
  294. template_dims = target_dims[current_level + 1 :]
  295. template = self._create_empty_nested_structure(template_dims, pad_value)
  296. else:
  297. # For coordinate level, create list of pad_values
  298. template = [pad_value] * target_dims[current_level + 1]
  299. nested_list.extend([deepcopy(template) for _ in range(target_size - current_size)])
  300. else:
  301. # Create from scratch
  302. template_dims = target_dims[current_level + 1 :]
  303. template = self._create_empty_nested_structure(template_dims, pad_value)
  304. nested_list.extend([deepcopy(template) for _ in range(target_size)])
  305. # Recursively pad sublists
  306. if current_level < len(target_dims) - 1:
  307. for i in range(len(nested_list)):
  308. if isinstance(nested_list[i], list):
  309. nested_list[i] = self._pad_nested_list(nested_list[i], target_dims, current_level + 1, pad_value)
  310. return nested_list
  311. def _create_empty_nested_structure(self, dims, pad_value):
  312. """
  313. Create an empty nested structure with given dimensions filled with pad_value.
  314. Args:
  315. dims (`list`):
  316. The dimensions of the nested structure.
  317. pad_value (`int`):
  318. The value to fill the structure with.
  319. """
  320. if len(dims) == 1:
  321. return [pad_value] * dims[0]
  322. else:
  323. return [self._create_empty_nested_structure(dims[1:], pad_value) for _ in range(dims[0])]
  324. def _get_nesting_level(self, input_list):
  325. """
  326. Get the nesting level of a list structure.
  327. Args:
  328. input_list (`list`):
  329. The list to get the nesting level of.
  330. """
  331. if isinstance(input_list, list):
  332. if len(input_list) == 0:
  333. return 1
  334. return 1 + self._get_nesting_level(input_list[0])
  335. elif isinstance(input_list, (np.ndarray, torch.Tensor)):
  336. # For arrays/tensors, the nesting level is the number of dimensions
  337. return len(input_list.shape)
  338. return 0
  339. def _validate_single_input(
  340. self,
  341. data: Union[torch.Tensor, np.ndarray, list],
  342. expected_depth: int,
  343. input_name: str,
  344. expected_format: str,
  345. expected_coord_size: Optional[int] = None,
  346. ) -> list:
  347. """
  348. Validate a single input by ensuring proper nesting and raising an error if the input is not valid.
  349. Args:
  350. data (`torch.Tensor`, `np.ndarray`, or `list`):
  351. Input data to process.
  352. expected_depth (`int`):
  353. Expected nesting depth.
  354. input_name (`str`):
  355. Name of the input for error messages.
  356. expected_format (`str`):
  357. The expected format of the input.
  358. expected_coord_size (`int`, *optional*):
  359. Expected coordinate size (2 for points, 4 for boxes, None for labels).
  360. .
  361. """
  362. if data is None:
  363. return None
  364. # Handle tensors and numpy arrays first
  365. if isinstance(data, (torch.Tensor, np.ndarray)):
  366. # For tensors/arrays, we can directly check the number of dimensions
  367. if data.ndim != expected_depth:
  368. raise ValueError(
  369. f"Input {input_name} must be a tensor/array with {expected_depth} dimensions. The expected nesting format is {expected_format}. Got {data.ndim} dimensions."
  370. )
  371. elif expected_coord_size is not None:
  372. if data.shape[-1] != expected_coord_size:
  373. raise ValueError(
  374. f"Input {input_name} must be a tensor/array with {expected_coord_size} as the last dimension, got {data.shape[-1]}."
  375. )
  376. return self._convert_to_nested_list(data, expected_depth)
  377. # Handle nested lists
  378. if isinstance(data, list):
  379. current_depth = self._get_nesting_level(data)
  380. if current_depth != expected_depth:
  381. raise ValueError(
  382. f"Input {input_name} must be a nested list with {expected_depth} levels. The expected nesting format is {expected_format}. Got {current_depth} levels."
  383. )
  384. return self._convert_to_nested_list(data, expected_depth)
  385. def _normalize_tensor_coordinates(self, tensor, original_sizes, is_bounding_box=False, preserve_padding=False):
  386. """
  387. Helper method to normalize coordinates in a tensor across multiple images.
  388. Args:
  389. tensor (`torch.Tensor`):
  390. Input tensor with coordinates.
  391. original_sizes (`list`):
  392. Original image sizes.
  393. is_bounding_box (`bool`, *optional*, defaults to `False`):
  394. Whether coordinates are bounding boxes.
  395. preserve_padding (`bool`, *optional*, defaults to `False`):
  396. Whether to preserve padding values (for points).
  397. """
  398. if preserve_padding:
  399. # For points: avoid normalizing pad values
  400. mask = tensor != self.point_pad_value
  401. coord_mask = mask.all(dim=-1, keepdim=True)
  402. for img_idx in range(len(original_sizes)):
  403. if img_idx < tensor.shape[0]:
  404. original_size = original_sizes[img_idx] if img_idx < len(original_sizes) else original_sizes[0]
  405. normalized_coords = self._normalize_coordinates(
  406. self.target_size, tensor[img_idx], original_size, is_bounding_box=is_bounding_box
  407. )
  408. if preserve_padding:
  409. # Only update non-padded values
  410. img_mask = coord_mask[img_idx]
  411. tensor[img_idx] = torch.where(
  412. img_mask.expand_as(tensor[img_idx]), normalized_coords, tensor[img_idx]
  413. )
  414. else:
  415. tensor[img_idx] = normalized_coords
  416. def post_process_masks(
  417. self,
  418. masks,
  419. original_sizes,
  420. mask_threshold=0.0,
  421. binarize=True,
  422. max_hole_area=0.0,
  423. max_sprinkle_area=0.0,
  424. apply_non_overlapping_constraints=False,
  425. **kwargs,
  426. ):
  427. """
  428. Remove padding and upscale masks to the original image size.
  429. Args:
  430. masks (`Union[List[torch.Tensor], List[np.ndarray]]`):
  431. Batched masks from the mask_decoder in (batch_size, num_channels, height, width) format.
  432. original_sizes (`Union[torch.Tensor, List[Tuple[int,int]]]`):
  433. The original sizes of each image before it was resized to the model's expected input shape, in (height,
  434. width) format.
  435. mask_threshold (`float`, *optional*, defaults to 0.0):
  436. Threshold for binarization and post-processing operations.
  437. binarize (`bool`, *optional*, defaults to `True`):
  438. Whether to binarize the masks.
  439. max_hole_area (`float`, *optional*, defaults to 0.0):
  440. The maximum area of a hole to fill.
  441. max_sprinkle_area (`float`, *optional*, defaults to 0.0):
  442. The maximum area of a sprinkle to fill.
  443. apply_non_overlapping_constraints (`bool`, *optional*, defaults to `False`):
  444. Whether to apply non-overlapping constraints to the masks.
  445. Returns:
  446. (`torch.Tensor`): Batched masks in batch_size, num_channels, height, width) format, where (height, width)
  447. is given by original_size.
  448. """
  449. return self.image_processor.post_process_masks(
  450. masks,
  451. original_sizes,
  452. mask_threshold,
  453. binarize,
  454. max_hole_area,
  455. max_sprinkle_area,
  456. apply_non_overlapping_constraints,
  457. **kwargs,
  458. )
  459. __all__ = ["Sam2Processor"]