image_processing_utils_fast.py 32 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833
  1. # Copyright 2024 The HuggingFace Inc. team.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. from collections.abc import Iterable
  15. from copy import deepcopy
  16. from functools import lru_cache, partial
  17. from typing import Any, Optional, TypedDict, Union
  18. import numpy as np
  19. from .image_processing_utils import BaseImageProcessor, BatchFeature, get_size_dict
  20. from .image_transforms import (
  21. convert_to_rgb,
  22. get_resize_output_image_size,
  23. get_size_with_aspect_ratio,
  24. group_images_by_shape,
  25. reorder_images,
  26. )
  27. from .image_utils import (
  28. ChannelDimension,
  29. ImageInput,
  30. ImageType,
  31. SizeDict,
  32. get_image_size,
  33. get_image_size_for_max_height_width,
  34. get_image_type,
  35. infer_channel_dimension_format,
  36. make_flat_list_of_images,
  37. validate_kwargs,
  38. validate_preprocess_arguments,
  39. )
  40. from .processing_utils import Unpack
  41. from .utils import (
  42. TensorType,
  43. auto_docstring,
  44. is_torch_available,
  45. is_torchvision_available,
  46. is_vision_available,
  47. logging,
  48. )
  49. from .utils.import_utils import is_rocm_platform
  50. if is_vision_available():
  51. from .image_utils import PILImageResampling
  52. if is_torch_available():
  53. import torch
  54. if is_torchvision_available():
  55. from torchvision.transforms.v2 import functional as F
  56. from .image_utils import pil_torch_interpolation_mapping
  57. else:
  58. pil_torch_interpolation_mapping = None
# Module-level logger; `BaseImageProcessorFast.filter_out_unused_kwargs` uses it to warn about ignored parameters.
logger = logging.get_logger(__name__)
  60. @lru_cache(maxsize=10)
  61. def validate_fast_preprocess_arguments(
  62. do_rescale: Optional[bool] = None,
  63. rescale_factor: Optional[float] = None,
  64. do_normalize: Optional[bool] = None,
  65. image_mean: Optional[Union[float, list[float]]] = None,
  66. image_std: Optional[Union[float, list[float]]] = None,
  67. do_center_crop: Optional[bool] = None,
  68. crop_size: Optional[SizeDict] = None,
  69. do_resize: Optional[bool] = None,
  70. size: Optional[SizeDict] = None,
  71. interpolation: Optional["F.InterpolationMode"] = None,
  72. return_tensors: Optional[Union[str, TensorType]] = None,
  73. data_format: ChannelDimension = ChannelDimension.FIRST,
  74. ):
  75. """
  76. Checks validity of typically used arguments in an `ImageProcessorFast` `preprocess` method.
  77. Raises `ValueError` if arguments incompatibility is caught.
  78. """
  79. validate_preprocess_arguments(
  80. do_rescale=do_rescale,
  81. rescale_factor=rescale_factor,
  82. do_normalize=do_normalize,
  83. image_mean=image_mean,
  84. image_std=image_std,
  85. do_center_crop=do_center_crop,
  86. crop_size=crop_size,
  87. do_resize=do_resize,
  88. size=size,
  89. interpolation=interpolation,
  90. )
  91. # Extra checks for ImageProcessorFast
  92. if return_tensors is not None and return_tensors != "pt":
  93. raise ValueError("Only returning PyTorch tensors is currently supported.")
  94. if data_format != ChannelDimension.FIRST:
  95. raise ValueError("Only channel first data format is currently supported.")
  96. def safe_squeeze(tensor: "torch.Tensor", axis: Optional[int] = None) -> "torch.Tensor":
  97. """
  98. Squeezes a tensor, but only if the axis specified has dim 1.
  99. """
  100. if axis is None:
  101. return tensor.squeeze()
  102. try:
  103. return tensor.squeeze(axis=axis)
  104. except ValueError:
  105. return tensor
  106. def max_across_indices(values: Iterable[Any]) -> list[Any]:
  107. """
  108. Return the maximum value across all indices of an iterable of values.
  109. """
  110. return [max(values_i) for values_i in zip(*values)]
  111. def get_max_height_width(images: list["torch.Tensor"]) -> tuple[int, ...]:
  112. """
  113. Get the maximum height and width across all images in a batch.
  114. """
  115. _, max_height, max_width = max_across_indices([img.shape for img in images])
  116. return (max_height, max_width)
  117. def divide_to_patches(
  118. image: Union[np.ndarray, "torch.Tensor"], patch_size: int
  119. ) -> list[Union[np.ndarray, "torch.Tensor"]]:
  120. """
  121. Divides an image into patches of a specified size.
  122. Args:
  123. image (`Union[np.array, "torch.Tensor"]`):
  124. The input image.
  125. patch_size (`int`):
  126. The size of each patch.
  127. Returns:
  128. list: A list of Union[np.array, "torch.Tensor"] representing the patches.
  129. """
  130. patches = []
  131. height, width = get_image_size(image, channel_dim=ChannelDimension.FIRST)
  132. for i in range(0, height, patch_size):
  133. for j in range(0, width, patch_size):
  134. patch = image[:, i : i + patch_size, j : j + patch_size]
  135. patches.append(patch)
  136. return patches
class DefaultFastImageProcessorKwargs(TypedDict, total=False):
    """
    Keyword arguments accepted by `BaseImageProcessorFast.__init__` and `BaseImageProcessorFast.preprocess`.

    Every key is optional (`total=False`). The annotation order below is also the order in which
    `BaseImageProcessorFast.__init__` assigns them and records them as valid kwarg names.
    """

    # Whether to resize, and the target size (normalized into a `SizeDict` before preprocessing).
    do_resize: Optional[bool]
    size: Optional[dict[str, int]]
    # Forwarded to `get_size_dict` when `size` is normalized.
    default_to_square: Optional[bool]
    # PIL resampling filters (or ints) are mapped to a torchvision `InterpolationMode` and passed on as `interpolation`.
    resample: Optional[Union["PILImageResampling", "F.InterpolationMode"]]
    # Whether to center crop, and the crop size (normalized into a `SizeDict`).
    do_center_crop: Optional[bool]
    crop_size: Optional[dict[str, int]]
    # Rescaling multiplies pixel values by `rescale_factor`; it is fused into normalization when both are enabled.
    do_rescale: Optional[bool]
    rescale_factor: Optional[Union[int, float]]
    do_normalize: Optional[bool]
    image_mean: Optional[Union[float, list[float]]]
    image_std: Optional[Union[float, list[float]]]
    # Whether to pad; without `pad_size`, images are padded to the largest height/width in the batch.
    do_pad: Optional[bool]
    pad_size: Optional[dict[str, int]]
    # Applied to each input before it is converted to a tensor.
    do_convert_rgb: Optional[bool]
    # Only `None` or "pt" is accepted by the fast processors.
    return_tensors: Optional[Union[str, TensorType]]
    # Only `ChannelDimension.FIRST` is accepted by the fast processors.
    data_format: Optional[ChannelDimension]
    # Layout of the inputs; inferred per image when `None`.
    input_data_format: Optional[Union[str, ChannelDimension]]
    # Device the converted image tensors are moved to.
    device: Optional["torch.device"]
    # Forwarded to `group_images_by_shape` to control batching of same-shaped images.
    disable_grouping: Optional[bool]
  157. @auto_docstring
  158. class BaseImageProcessorFast(BaseImageProcessor):
  159. resample = None
  160. image_mean = None
  161. image_std = None
  162. size = None
  163. default_to_square = True
  164. crop_size = None
  165. do_resize = None
  166. do_center_crop = None
  167. do_pad = None
  168. pad_size = None
  169. do_rescale = None
  170. rescale_factor = 1 / 255
  171. do_normalize = None
  172. do_convert_rgb = None
  173. return_tensors = None
  174. data_format = ChannelDimension.FIRST
  175. input_data_format = None
  176. device = None
  177. model_input_names = ["pixel_values"]
  178. valid_kwargs = DefaultFastImageProcessorKwargs
  179. unused_kwargs = None
  180. def __init__(self, **kwargs: Unpack[DefaultFastImageProcessorKwargs]):
  181. super().__init__(**kwargs)
  182. kwargs = self.filter_out_unused_kwargs(kwargs)
  183. size = kwargs.pop("size", self.size)
  184. self.size = (
  185. get_size_dict(size=size, default_to_square=kwargs.pop("default_to_square", self.default_to_square))
  186. if size is not None
  187. else None
  188. )
  189. crop_size = kwargs.pop("crop_size", self.crop_size)
  190. self.crop_size = get_size_dict(crop_size, param_name="crop_size") if crop_size is not None else None
  191. pad_size = kwargs.pop("pad_size", self.pad_size)
  192. self.pad_size = get_size_dict(size=pad_size, param_name="pad_size") if pad_size is not None else None
  193. for key in self.valid_kwargs.__annotations__:
  194. kwarg = kwargs.pop(key, None)
  195. if kwarg is not None:
  196. setattr(self, key, kwarg)
  197. else:
  198. setattr(self, key, deepcopy(getattr(self, key, None)))
  199. # get valid kwargs names
  200. self._valid_kwargs_names = list(self.valid_kwargs.__annotations__.keys())
  201. @property
  202. def is_fast(self) -> bool:
  203. """
  204. `bool`: Whether or not this image processor is a fast processor (backed by PyTorch and TorchVision).
  205. """
  206. return True
  207. def pad(
  208. self,
  209. images: "torch.Tensor",
  210. pad_size: SizeDict = None,
  211. fill_value: Optional[int] = 0,
  212. padding_mode: Optional[str] = "constant",
  213. return_mask: bool = False,
  214. disable_grouping: Optional[bool] = False,
  215. **kwargs,
  216. ) -> "torch.Tensor":
  217. """
  218. Pads images to `(pad_size["height"], pad_size["width"])` or to the largest size in the batch.
  219. Args:
  220. images (`torch.Tensor`):
  221. Images to pad.
  222. pad_size (`SizeDict`, *optional*):
  223. Dictionary in the format `{"height": int, "width": int}` specifying the size of the output image.
  224. fill_value (`int`, *optional*, defaults to `0`):
  225. The constant value used to fill the padded area.
  226. padding_mode (`str`, *optional*, defaults to "constant"):
  227. The padding mode to use. Can be any of the modes supported by
  228. `torch.nn.functional.pad` (e.g. constant, reflection, replication).
  229. return_mask (`bool`, *optional*, defaults to `False`):
  230. Whether to return a pixel mask to denote padded regions.
  231. disable_grouping (`bool`, *optional*, defaults to `False`):
  232. Whether to disable grouping of images by size.
  233. Returns:
  234. `torch.Tensor`: The resized image.
  235. """
  236. if pad_size is not None:
  237. if not (pad_size.height and pad_size.width):
  238. raise ValueError(f"Pad size must contain 'height' and 'width' keys only. Got pad_size={pad_size}.")
  239. pad_size = (pad_size.height, pad_size.width)
  240. else:
  241. pad_size = get_max_height_width(images)
  242. grouped_images, grouped_images_index = group_images_by_shape(images, disable_grouping=disable_grouping)
  243. processed_images_grouped = {}
  244. processed_masks_grouped = {}
  245. for shape, stacked_images in grouped_images.items():
  246. image_size = stacked_images.shape[-2:]
  247. padding_height = pad_size[0] - image_size[0]
  248. padding_width = pad_size[1] - image_size[1]
  249. if padding_height < 0 or padding_width < 0:
  250. raise ValueError(
  251. f"Padding dimensions are negative. Please make sure that the `pad_size` is larger than the "
  252. f"image size. Got pad_size={pad_size}, image_size={image_size}."
  253. )
  254. if image_size != pad_size:
  255. padding = (0, 0, padding_width, padding_height)
  256. stacked_images = F.pad(stacked_images, padding, fill=fill_value, padding_mode=padding_mode)
  257. processed_images_grouped[shape] = stacked_images
  258. if return_mask:
  259. # keep only one from the channel dimension in pixel mask
  260. stacked_masks = torch.zeros_like(stacked_images, dtype=torch.int64)[..., 0, :, :]
  261. stacked_masks[..., : image_size[0], : image_size[1]] = 1
  262. processed_masks_grouped[shape] = stacked_masks
  263. processed_images = reorder_images(processed_images_grouped, grouped_images_index)
  264. if return_mask:
  265. processed_masks = reorder_images(processed_masks_grouped, grouped_images_index)
  266. return processed_images, processed_masks
  267. return processed_images
  268. def resize(
  269. self,
  270. image: "torch.Tensor",
  271. size: SizeDict,
  272. interpolation: Optional["F.InterpolationMode"] = None,
  273. antialias: bool = True,
  274. **kwargs,
  275. ) -> "torch.Tensor":
  276. """
  277. Resize an image to `(size["height"], size["width"])`.
  278. Args:
  279. image (`torch.Tensor`):
  280. Image to resize.
  281. size (`SizeDict`):
  282. Dictionary in the format `{"height": int, "width": int}` specifying the size of the output image.
  283. interpolation (`InterpolationMode`, *optional*, defaults to `InterpolationMode.BILINEAR`):
  284. `InterpolationMode` filter to use when resizing the image e.g. `InterpolationMode.BICUBIC`.
  285. Returns:
  286. `torch.Tensor`: The resized image.
  287. """
  288. interpolation = interpolation if interpolation is not None else F.InterpolationMode.BILINEAR
  289. if size.shortest_edge and size.longest_edge:
  290. # Resize the image so that the shortest edge or the longest edge is of the given size
  291. # while maintaining the aspect ratio of the original image.
  292. new_size = get_size_with_aspect_ratio(
  293. image.size()[-2:],
  294. size.shortest_edge,
  295. size.longest_edge,
  296. )
  297. elif size.shortest_edge:
  298. new_size = get_resize_output_image_size(
  299. image,
  300. size=size.shortest_edge,
  301. default_to_square=False,
  302. input_data_format=ChannelDimension.FIRST,
  303. )
  304. elif size.max_height and size.max_width:
  305. new_size = get_image_size_for_max_height_width(image.size()[-2:], size.max_height, size.max_width)
  306. elif size.height and size.width:
  307. new_size = (size.height, size.width)
  308. else:
  309. raise ValueError(
  310. "Size must contain 'height' and 'width' keys, or 'max_height' and 'max_width', or 'shortest_edge' key. Got"
  311. f" {size}."
  312. )
  313. # This is a workaround to avoid a bug in torch.compile when dealing with uint8 on AMD MI3XX GPUs
  314. # Tracked in PyTorch issue: https://github.com/pytorch/pytorch/issues/155209
  315. # TODO: remove this once the bug is fixed (detected with torch==2.7.0+git1fee196, torchvision==0.22.0+9eb57cd)
  316. if torch.compiler.is_compiling() and is_rocm_platform():
  317. return self.compile_friendly_resize(image, new_size, interpolation, antialias)
  318. return F.resize(image, new_size, interpolation=interpolation, antialias=antialias)
  319. @staticmethod
  320. def compile_friendly_resize(
  321. image: "torch.Tensor",
  322. new_size: tuple[int, int],
  323. interpolation: Optional["F.InterpolationMode"] = None,
  324. antialias: bool = True,
  325. ) -> "torch.Tensor":
  326. """
  327. A wrapper around `F.resize` so that it is compatible with torch.compile when the image is a uint8 tensor.
  328. """
  329. if image.dtype == torch.uint8:
  330. # 256 is used on purpose instead of 255 to avoid numerical differences
  331. # see https://github.com/huggingface/transformers/pull/38540#discussion_r2127165652
  332. image = image.float() / 256
  333. image = F.resize(image, new_size, interpolation=interpolation, antialias=antialias)
  334. image = image * 256
  335. # torch.where is used on purpose instead of torch.clamp to avoid bug in torch.compile
  336. # see https://github.com/huggingface/transformers/pull/38540#discussion_r2126888471
  337. image = torch.where(image > 255, 255, image)
  338. image = torch.where(image < 0, 0, image)
  339. image = image.round().to(torch.uint8)
  340. else:
  341. image = F.resize(image, new_size, interpolation=interpolation, antialias=antialias)
  342. return image
  343. def rescale(
  344. self,
  345. image: "torch.Tensor",
  346. scale: float,
  347. **kwargs,
  348. ) -> "torch.Tensor":
  349. """
  350. Rescale an image by a scale factor. image = image * scale.
  351. Args:
  352. image (`torch.Tensor`):
  353. Image to rescale.
  354. scale (`float`):
  355. The scaling factor to rescale pixel values by.
  356. Returns:
  357. `torch.Tensor`: The rescaled image.
  358. """
  359. return image * scale
  360. def normalize(
  361. self,
  362. image: "torch.Tensor",
  363. mean: Union[float, Iterable[float]],
  364. std: Union[float, Iterable[float]],
  365. **kwargs,
  366. ) -> "torch.Tensor":
  367. """
  368. Normalize an image. image = (image - image_mean) / image_std.
  369. Args:
  370. image (`torch.Tensor`):
  371. Image to normalize.
  372. mean (`torch.Tensor`, `float` or `Iterable[float]`):
  373. Image mean to use for normalization.
  374. std (`torch.Tensor`, `float` or `Iterable[float]`):
  375. Image standard deviation to use for normalization.
  376. Returns:
  377. `torch.Tensor`: The normalized image.
  378. """
  379. return F.normalize(image, mean, std)
  380. @lru_cache(maxsize=10)
  381. def _fuse_mean_std_and_rescale_factor(
  382. self,
  383. do_normalize: Optional[bool] = None,
  384. image_mean: Optional[Union[float, list[float]]] = None,
  385. image_std: Optional[Union[float, list[float]]] = None,
  386. do_rescale: Optional[bool] = None,
  387. rescale_factor: Optional[float] = None,
  388. device: Optional["torch.device"] = None,
  389. ) -> tuple:
  390. if do_rescale and do_normalize:
  391. # Fused rescale and normalize
  392. image_mean = torch.tensor(image_mean, device=device) * (1.0 / rescale_factor)
  393. image_std = torch.tensor(image_std, device=device) * (1.0 / rescale_factor)
  394. do_rescale = False
  395. return image_mean, image_std, do_rescale
  396. def rescale_and_normalize(
  397. self,
  398. images: "torch.Tensor",
  399. do_rescale: bool,
  400. rescale_factor: float,
  401. do_normalize: bool,
  402. image_mean: Union[float, list[float]],
  403. image_std: Union[float, list[float]],
  404. ) -> "torch.Tensor":
  405. """
  406. Rescale and normalize images.
  407. """
  408. image_mean, image_std, do_rescale = self._fuse_mean_std_and_rescale_factor(
  409. do_normalize=do_normalize,
  410. image_mean=image_mean,
  411. image_std=image_std,
  412. do_rescale=do_rescale,
  413. rescale_factor=rescale_factor,
  414. device=images.device,
  415. )
  416. # if/elif as we use fused rescale and normalize if both are set to True
  417. if do_normalize:
  418. images = self.normalize(images.to(dtype=torch.float32), image_mean, image_std)
  419. elif do_rescale:
  420. images = self.rescale(images, rescale_factor)
  421. return images
  422. def center_crop(
  423. self,
  424. image: "torch.Tensor",
  425. size: SizeDict,
  426. **kwargs,
  427. ) -> "torch.Tensor":
  428. """
  429. Note: override torchvision's center_crop to have the same behavior as the slow processor.
  430. Center crop an image to `(size["height"], size["width"])`. If the input size is smaller than `crop_size` along
  431. any edge, the image is padded with 0's and then center cropped.
  432. Args:
  433. image (`"torch.Tensor"`):
  434. Image to center crop.
  435. size (`dict[str, int]`):
  436. Size of the output image.
  437. Returns:
  438. `torch.Tensor`: The center cropped image.
  439. """
  440. if size.height is None or size.width is None:
  441. raise ValueError(f"The size dictionary must have keys 'height' and 'width'. Got {size.keys()}")
  442. image_height, image_width = image.shape[-2:]
  443. crop_height, crop_width = size.height, size.width
  444. if crop_width > image_width or crop_height > image_height:
  445. padding_ltrb = [
  446. (crop_width - image_width) // 2 if crop_width > image_width else 0,
  447. (crop_height - image_height) // 2 if crop_height > image_height else 0,
  448. (crop_width - image_width + 1) // 2 if crop_width > image_width else 0,
  449. (crop_height - image_height + 1) // 2 if crop_height > image_height else 0,
  450. ]
  451. image = F.pad(image, padding_ltrb, fill=0) # PIL uses fill value 0
  452. image_height, image_width = image.shape[-2:]
  453. if crop_width == image_width and crop_height == image_height:
  454. return image
  455. crop_top = int((image_height - crop_height) / 2.0)
  456. crop_left = int((image_width - crop_width) / 2.0)
  457. return F.crop(image, crop_top, crop_left, crop_height, crop_width)
  458. def convert_to_rgb(
  459. self,
  460. image: ImageInput,
  461. ) -> ImageInput:
  462. """
  463. Converts an image to RGB format. Only converts if the image is of type PIL.Image.Image, otherwise returns the image
  464. as is.
  465. Args:
  466. image (ImageInput):
  467. The image to convert.
  468. Returns:
  469. ImageInput: The converted image.
  470. """
  471. return convert_to_rgb(image)
  472. def filter_out_unused_kwargs(self, kwargs: dict):
  473. """
  474. Filter out the unused kwargs from the kwargs dictionary.
  475. """
  476. if self.unused_kwargs is None:
  477. return kwargs
  478. for kwarg_name in self.unused_kwargs:
  479. if kwarg_name in kwargs:
  480. logger.warning_once(f"This processor does not use the `{kwarg_name}` parameter. It will be ignored.")
  481. kwargs.pop(kwarg_name)
  482. return kwargs
  483. def _prepare_images_structure(
  484. self,
  485. images: ImageInput,
  486. expected_ndims: int = 3,
  487. ) -> ImageInput:
  488. """
  489. Prepare the images structure for processing.
  490. Args:
  491. images (`ImageInput`):
  492. The input images to process.
  493. Returns:
  494. `ImageInput`: The images with a valid nesting.
  495. """
  496. # Checks for `str` in case of URL/local path and optionally loads images
  497. images = self.fetch_images(images)
  498. return make_flat_list_of_images(images, expected_ndims=expected_ndims)
  499. def _process_image(
  500. self,
  501. image: ImageInput,
  502. do_convert_rgb: Optional[bool] = None,
  503. input_data_format: Optional[Union[str, ChannelDimension]] = None,
  504. device: Optional["torch.device"] = None,
  505. ) -> "torch.Tensor":
  506. image_type = get_image_type(image)
  507. if image_type not in [ImageType.PIL, ImageType.TORCH, ImageType.NUMPY]:
  508. raise ValueError(f"Unsupported input image type {image_type}")
  509. if do_convert_rgb:
  510. image = self.convert_to_rgb(image)
  511. if image_type == ImageType.PIL:
  512. image = F.pil_to_tensor(image)
  513. elif image_type == ImageType.NUMPY:
  514. # not using F.to_tensor as it doesn't handle (C, H, W) numpy arrays
  515. image = torch.from_numpy(image).contiguous()
  516. # If the image is 2D, we need to unsqueeze it to add a channel dimension for processing
  517. if image.ndim == 2:
  518. image = image.unsqueeze(0)
  519. # Infer the channel dimension format if not provided
  520. if input_data_format is None:
  521. input_data_format = infer_channel_dimension_format(image)
  522. if input_data_format == ChannelDimension.LAST:
  523. # We force the channel dimension to be first for torch tensors as this is what torchvision expects.
  524. image = image.permute(2, 0, 1).contiguous()
  525. # Now that we have torch tensors, we can move them to the right device
  526. if device is not None:
  527. image = image.to(device)
  528. return image
  529. def _prepare_image_like_inputs(
  530. self,
  531. images: ImageInput,
  532. do_convert_rgb: Optional[bool] = None,
  533. input_data_format: Optional[Union[str, ChannelDimension]] = None,
  534. device: Optional["torch.device"] = None,
  535. expected_ndims: int = 3,
  536. ) -> list["torch.Tensor"]:
  537. """
  538. Prepare image-like inputs for processing.
  539. Args:
  540. images (`ImageInput`):
  541. The image-like inputs to process.
  542. do_convert_rgb (`bool`, *optional*):
  543. Whether to convert the images to RGB.
  544. input_data_format (`str` or `ChannelDimension`, *optional*):
  545. The input data format of the images.
  546. device (`torch.device`, *optional*):
  547. The device to put the processed images on.
  548. expected_ndims (`int`, *optional*):
  549. The expected number of dimensions for the images. (can be 2 for segmentation maps etc.)
  550. Returns:
  551. List[`torch.Tensor`]: The processed images.
  552. """
  553. # Get structured images (potentially nested)
  554. images = self._prepare_images_structure(images, expected_ndims=expected_ndims)
  555. process_image_partial = partial(
  556. self._process_image, do_convert_rgb=do_convert_rgb, input_data_format=input_data_format, device=device
  557. )
  558. # Check if we have nested structure, assuming the nesting is consistent
  559. has_nested_structure = len(images) > 0 and isinstance(images[0], (list, tuple))
  560. if has_nested_structure:
  561. processed_images = [[process_image_partial(img) for img in nested_list] for nested_list in images]
  562. else:
  563. processed_images = [process_image_partial(img) for img in images]
  564. return processed_images
  565. def _further_process_kwargs(
  566. self,
  567. size: Optional[SizeDict] = None,
  568. crop_size: Optional[SizeDict] = None,
  569. pad_size: Optional[SizeDict] = None,
  570. default_to_square: Optional[bool] = None,
  571. image_mean: Optional[Union[float, list[float]]] = None,
  572. image_std: Optional[Union[float, list[float]]] = None,
  573. data_format: Optional[ChannelDimension] = None,
  574. **kwargs,
  575. ) -> dict:
  576. """
  577. Update kwargs that need further processing before being validated
  578. Can be overridden by subclasses to customize the processing of kwargs.
  579. """
  580. if kwargs is None:
  581. kwargs = {}
  582. if size is not None:
  583. size = SizeDict(**get_size_dict(size=size, default_to_square=default_to_square))
  584. if crop_size is not None:
  585. crop_size = SizeDict(**get_size_dict(crop_size, param_name="crop_size"))
  586. if pad_size is not None:
  587. pad_size = SizeDict(**get_size_dict(size=pad_size, param_name="pad_size"))
  588. if isinstance(image_mean, list):
  589. image_mean = tuple(image_mean)
  590. if isinstance(image_std, list):
  591. image_std = tuple(image_std)
  592. if data_format is None:
  593. data_format = ChannelDimension.FIRST
  594. kwargs["size"] = size
  595. kwargs["crop_size"] = crop_size
  596. kwargs["pad_size"] = pad_size
  597. kwargs["image_mean"] = image_mean
  598. kwargs["image_std"] = image_std
  599. kwargs["data_format"] = data_format
  600. # torch resize uses interpolation instead of resample
  601. # Check if resample is an int before checking if it's an instance of PILImageResampling
  602. # because if pillow < 9.1.0, resample is an int and PILImageResampling is a module.
  603. # Checking PILImageResampling will fail with error `TypeError: isinstance() arg 2 must be a type or tuple of types`.
  604. resample = kwargs.pop("resample")
  605. kwargs["interpolation"] = (
  606. pil_torch_interpolation_mapping[resample] if isinstance(resample, (PILImageResampling, int)) else resample
  607. )
  608. return kwargs
  609. def _validate_preprocess_kwargs(
  610. self,
  611. do_rescale: Optional[bool] = None,
  612. rescale_factor: Optional[float] = None,
  613. do_normalize: Optional[bool] = None,
  614. image_mean: Optional[Union[float, tuple[float]]] = None,
  615. image_std: Optional[Union[float, tuple[float]]] = None,
  616. do_resize: Optional[bool] = None,
  617. size: Optional[SizeDict] = None,
  618. do_center_crop: Optional[bool] = None,
  619. crop_size: Optional[SizeDict] = None,
  620. interpolation: Optional["F.InterpolationMode"] = None,
  621. return_tensors: Optional[Union[str, TensorType]] = None,
  622. data_format: Optional[ChannelDimension] = None,
  623. **kwargs,
  624. ):
  625. """
  626. validate the kwargs for the preprocess method.
  627. """
  628. validate_fast_preprocess_arguments(
  629. do_rescale=do_rescale,
  630. rescale_factor=rescale_factor,
  631. do_normalize=do_normalize,
  632. image_mean=image_mean,
  633. image_std=image_std,
  634. do_resize=do_resize,
  635. size=size,
  636. do_center_crop=do_center_crop,
  637. crop_size=crop_size,
  638. interpolation=interpolation,
  639. return_tensors=return_tensors,
  640. data_format=data_format,
  641. )
  642. def __call__(self, images: ImageInput, *args, **kwargs: Unpack[DefaultFastImageProcessorKwargs]) -> BatchFeature:
  643. return self.preprocess(images, *args, **kwargs)
  644. @auto_docstring
  645. def preprocess(self, images: ImageInput, *args, **kwargs: Unpack[DefaultFastImageProcessorKwargs]) -> BatchFeature:
  646. # args are not validated, but their order in the `preprocess` and `_preprocess` signatures must be the same
  647. validate_kwargs(captured_kwargs=kwargs.keys(), valid_processor_keys=self._valid_kwargs_names)
  648. # Set default kwargs from self. This ensures that if a kwarg is not provided
  649. # by the user, it gets its default value from the instance, or is set to None.
  650. for kwarg_name in self._valid_kwargs_names:
  651. kwargs.setdefault(kwarg_name, getattr(self, kwarg_name, None))
  652. # Extract parameters that are only used for preparing the input images
  653. do_convert_rgb = kwargs.pop("do_convert_rgb")
  654. input_data_format = kwargs.pop("input_data_format")
  655. device = kwargs.pop("device")
  656. # Update kwargs that need further processing before being validated
  657. kwargs = self._further_process_kwargs(**kwargs)
  658. # Validate kwargs
  659. self._validate_preprocess_kwargs(**kwargs)
  660. # Pop kwargs that are not needed in _preprocess
  661. kwargs.pop("data_format")
  662. return self._preprocess_image_like_inputs(
  663. images, *args, do_convert_rgb=do_convert_rgb, input_data_format=input_data_format, device=device, **kwargs
  664. )
  665. def _preprocess_image_like_inputs(
  666. self,
  667. images: ImageInput,
  668. *args,
  669. do_convert_rgb: bool,
  670. input_data_format: ChannelDimension,
  671. device: Optional[Union[str, "torch.device"]] = None,
  672. **kwargs: Unpack[DefaultFastImageProcessorKwargs],
  673. ) -> BatchFeature:
  674. """
  675. Preprocess image-like inputs.
  676. To be overridden by subclasses when image-like inputs other than images should be processed.
  677. It can be used for segmentation maps, depth maps, etc.
  678. """
  679. # Prepare input images
  680. images = self._prepare_image_like_inputs(
  681. images=images, do_convert_rgb=do_convert_rgb, input_data_format=input_data_format, device=device
  682. )
  683. return self._preprocess(images, *args, **kwargs)
  684. def _preprocess(
  685. self,
  686. images: list["torch.Tensor"],
  687. do_resize: bool,
  688. size: SizeDict,
  689. interpolation: Optional["F.InterpolationMode"],
  690. do_center_crop: bool,
  691. crop_size: SizeDict,
  692. do_rescale: bool,
  693. rescale_factor: float,
  694. do_normalize: bool,
  695. image_mean: Optional[Union[float, list[float]]],
  696. image_std: Optional[Union[float, list[float]]],
  697. do_pad: Optional[bool],
  698. pad_size: Optional[SizeDict],
  699. disable_grouping: Optional[bool],
  700. return_tensors: Optional[Union[str, TensorType]],
  701. **kwargs,
  702. ) -> BatchFeature:
  703. # Group images by size for batched resizing
  704. grouped_images, grouped_images_index = group_images_by_shape(images, disable_grouping=disable_grouping)
  705. resized_images_grouped = {}
  706. for shape, stacked_images in grouped_images.items():
  707. if do_resize:
  708. stacked_images = self.resize(image=stacked_images, size=size, interpolation=interpolation)
  709. resized_images_grouped[shape] = stacked_images
  710. resized_images = reorder_images(resized_images_grouped, grouped_images_index)
  711. # Group images by size for further processing
  712. # Needed in case do_resize is False, or resize returns images with different sizes
  713. grouped_images, grouped_images_index = group_images_by_shape(resized_images, disable_grouping=disable_grouping)
  714. processed_images_grouped = {}
  715. for shape, stacked_images in grouped_images.items():
  716. if do_center_crop:
  717. stacked_images = self.center_crop(stacked_images, crop_size)
  718. # Fused rescale and normalize
  719. stacked_images = self.rescale_and_normalize(
  720. stacked_images, do_rescale, rescale_factor, do_normalize, image_mean, image_std
  721. )
  722. processed_images_grouped[shape] = stacked_images
  723. processed_images = reorder_images(processed_images_grouped, grouped_images_index)
  724. if do_pad:
  725. processed_images = self.pad(processed_images, pad_size=pad_size, disable_grouping=disable_grouping)
  726. processed_images = torch.stack(processed_images, dim=0) if return_tensors else processed_images
  727. return BatchFeature(data={"pixel_values": processed_images}, tensor_type=return_tensors)
  728. def to_dict(self):
  729. encoder_dict = super().to_dict()
  730. encoder_dict.pop("_valid_processor_keys", None)
  731. encoder_dict.pop("_valid_kwargs_names", None)
  732. return encoder_dict