image_utils.py 37 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862863864865866867868869870871872873874875876877878879880881882883884885886887888889890891892893894895896897898899900901902903904905906907908909910911912913914915916917918919920921922923924925926927928929930931932933934935936937938939940941942943944945946947948949950951952953954955956957958959960961962963964965966967968969
  1. # Copyright 2021 The HuggingFace Inc. team.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. import base64
  15. import os
  16. from collections.abc import Iterable
  17. from dataclasses import dataclass
  18. from io import BytesIO
  19. from typing import Optional, Union
  20. import numpy as np
  21. import requests
  22. from .utils import (
  23. ExplicitEnum,
  24. is_jax_tensor,
  25. is_numpy_array,
  26. is_tf_tensor,
  27. is_torch_available,
  28. is_torch_tensor,
  29. is_torchvision_available,
  30. is_vision_available,
  31. logging,
  32. requires_backends,
  33. to_numpy,
  34. )
  35. from .utils.constants import ( # noqa: F401
  36. IMAGENET_DEFAULT_MEAN,
  37. IMAGENET_DEFAULT_STD,
  38. IMAGENET_STANDARD_MEAN,
  39. IMAGENET_STANDARD_STD,
  40. OPENAI_CLIP_MEAN,
  41. OPENAI_CLIP_STD,
  42. )
# Optional third-party backends. Each import is guarded so this module can be imported even when
# Pillow, torchvision or torch are not installed.
if is_vision_available():
    import PIL.Image
    import PIL.ImageOps

    # Alias for Pillow's resampling-filter enum.
    PILImageResampling = PIL.Image.Resampling

    # NOTE(review): nesting reconstructed from a copy that lost its indentation. The mapping keys use
    # PILImageResampling, so this part must sit inside the vision guard.
    if is_torchvision_available():
        from torchvision.transforms import InterpolationMode

        # Translates a PIL resampling filter into the corresponding torchvision InterpolationMode.
        # NEAREST deliberately maps to NEAREST_EXACT (not InterpolationMode.NEAREST).
        pil_torch_interpolation_mapping = {
            PILImageResampling.NEAREST: InterpolationMode.NEAREST_EXACT,
            PILImageResampling.BOX: InterpolationMode.BOX,
            PILImageResampling.BILINEAR: InterpolationMode.BILINEAR,
            PILImageResampling.HAMMING: InterpolationMode.HAMMING,
            PILImageResampling.BICUBIC: InterpolationMode.BICUBIC,
            PILImageResampling.LANCZOS: InterpolationMode.LANCZOS,
        }
    else:
        # Without torchvision there is no interpolation mode to translate to.
        pil_torch_interpolation_mapping = {}

if is_torch_available():
    import torch
# Module-level logger (used, for example, for the ambiguous-channel warning).
logger = logging.get_logger(__name__)

# A single image or a list of images, in any of the containers the helpers here accept.
# PIL and torch names are quoted (forward references) because those imports are optional.
ImageInput = Union[
    "PIL.Image.Image", np.ndarray, "torch.Tensor", list["PIL.Image.Image"], list[np.ndarray], list["torch.Tensor"]
]
class ChannelDimension(ExplicitEnum):
    """
    Position of the channel axis in an image array: `FIRST` means (..., channels, height, width) and `LAST`
    means (..., height, width, channels), matching how `get_image_size` indexes the shape.
    """

    FIRST = "channels_first"
    LAST = "channels_last"
class AnnotationFormat(ExplicitEnum):
    """COCO annotation formats: object detection and panoptic segmentation."""

    COCO_DETECTION = "coco_detection"
    COCO_PANOPTIC = "coco_panoptic"
class AnnotionFormat(ExplicitEnum):
    """
    Misspelled duplicate of `AnnotationFormat` whose members reuse the same string values.

    NOTE(review): presumably kept so existing imports of the old name keep working; prefer `AnnotationFormat`.
    """

    COCO_DETECTION = AnnotationFormat.COCO_DETECTION.value
    COCO_PANOPTIC = AnnotationFormat.COCO_PANOPTIC.value
# A single annotation record: string keys mapping to ints, strings or lists of dicts
# (presumably COCO-style records such as {"image_id": ..., "annotations": [...]}).
AnnotationType = dict[str, Union[int, str, list[dict]]]
  75. def is_pil_image(img):
  76. return is_vision_available() and isinstance(img, PIL.Image.Image)
class ImageType(ExplicitEnum):
    """Library/container an image comes from, as reported by `get_image_type`."""

    PIL = "pillow"
    TORCH = "torch"
    NUMPY = "numpy"
    TENSORFLOW = "tensorflow"
    JAX = "jax"
  83. def get_image_type(image):
  84. if is_pil_image(image):
  85. return ImageType.PIL
  86. if is_torch_tensor(image):
  87. return ImageType.TORCH
  88. if is_numpy_array(image):
  89. return ImageType.NUMPY
  90. if is_tf_tensor(image):
  91. return ImageType.TENSORFLOW
  92. if is_jax_tensor(image):
  93. return ImageType.JAX
  94. raise ValueError(f"Unrecognized image type {type(image)}")
  95. def is_valid_image(img):
  96. return is_pil_image(img) or is_numpy_array(img) or is_torch_tensor(img) or is_tf_tensor(img) or is_jax_tensor(img)
  97. def is_valid_list_of_images(images: list):
  98. return images and all(is_valid_image(image) for image in images)
  99. def concatenate_list(input_list):
  100. if isinstance(input_list[0], list):
  101. return [item for sublist in input_list for item in sublist]
  102. elif isinstance(input_list[0], np.ndarray):
  103. return np.concatenate(input_list, axis=0)
  104. elif isinstance(input_list[0], torch.Tensor):
  105. return torch.cat(input_list, dim=0)
  106. def valid_images(imgs):
  107. # If we have an list of images, make sure every image is valid
  108. if isinstance(imgs, (list, tuple)):
  109. for img in imgs:
  110. if not valid_images(img):
  111. return False
  112. # If not a list of tuple, we have been given a single image or batched tensor of images
  113. elif not is_valid_image(imgs):
  114. return False
  115. return True
  116. def is_batched(img):
  117. if isinstance(img, (list, tuple)):
  118. return is_valid_image(img[0])
  119. return False
  120. def is_scaled_image(image: np.ndarray) -> bool:
  121. """
  122. Checks to see whether the pixel values have already been rescaled to [0, 1].
  123. """
  124. if image.dtype == np.uint8:
  125. return False
  126. # It's possible the image has pixel values in [0, 255] but is of floating type
  127. return np.min(image) >= 0 and np.max(image) <= 1
  128. def make_list_of_images(images, expected_ndims: int = 3) -> list[ImageInput]:
  129. """
  130. Ensure that the output is a list of images. If the input is a single image, it is converted to a list of length 1.
  131. If the input is a batch of images, it is converted to a list of images.
  132. Args:
  133. images (`ImageInput`):
  134. Image of images to turn into a list of images.
  135. expected_ndims (`int`, *optional*, defaults to 3):
  136. Expected number of dimensions for a single input image. If the input image has a different number of
  137. dimensions, an error is raised.
  138. """
  139. if is_batched(images):
  140. return images
  141. # Either the input is a single image, in which case we create a list of length 1
  142. if is_pil_image(images):
  143. # PIL images are never batched
  144. return [images]
  145. if is_valid_image(images):
  146. if images.ndim == expected_ndims + 1:
  147. # Batch of images
  148. images = list(images)
  149. elif images.ndim == expected_ndims:
  150. # Single image
  151. images = [images]
  152. else:
  153. raise ValueError(
  154. f"Invalid image shape. Expected either {expected_ndims + 1} or {expected_ndims} dimensions, but got"
  155. f" {images.ndim} dimensions."
  156. )
  157. return images
  158. raise ValueError(
  159. "Invalid image type. Expected either PIL.Image.Image, numpy.ndarray, torch.Tensor, tf.Tensor or "
  160. f"jax.ndarray, but got {type(images)}."
  161. )
  162. def make_flat_list_of_images(
  163. images: Union[list[ImageInput], ImageInput],
  164. expected_ndims: int = 3,
  165. ) -> ImageInput:
  166. """
  167. Ensure that the output is a flat list of images. If the input is a single image, it is converted to a list of length 1.
  168. If the input is a nested list of images, it is converted to a flat list of images.
  169. Args:
  170. images (`Union[list[ImageInput], ImageInput]`):
  171. The input image.
  172. expected_ndims (`int`, *optional*, defaults to 3):
  173. The expected number of dimensions for a single input image.
  174. Returns:
  175. list: A list of images or a 4d array of images.
  176. """
  177. # If the input is a nested list of images, we flatten it
  178. if (
  179. isinstance(images, (list, tuple))
  180. and all(isinstance(images_i, (list, tuple)) for images_i in images)
  181. and all(is_valid_list_of_images(images_i) or not images_i for images_i in images)
  182. ):
  183. return [img for img_list in images for img in img_list]
  184. if isinstance(images, (list, tuple)) and is_valid_list_of_images(images):
  185. if is_pil_image(images[0]) or images[0].ndim == expected_ndims:
  186. return images
  187. if images[0].ndim == expected_ndims + 1:
  188. return [img for img_list in images for img in img_list]
  189. if is_valid_image(images):
  190. if is_pil_image(images) or images.ndim == expected_ndims:
  191. return [images]
  192. if images.ndim == expected_ndims + 1:
  193. return list(images)
  194. raise ValueError(f"Could not make a flat list of images from {images}")
  195. def make_nested_list_of_images(
  196. images: Union[list[ImageInput], ImageInput],
  197. expected_ndims: int = 3,
  198. ) -> list[ImageInput]:
  199. """
  200. Ensure that the output is a nested list of images.
  201. Args:
  202. images (`Union[list[ImageInput], ImageInput]`):
  203. The input image.
  204. expected_ndims (`int`, *optional*, defaults to 3):
  205. The expected number of dimensions for a single input image.
  206. Returns:
  207. list: A list of list of images or a list of 4d array of images.
  208. """
  209. # If it's a list of batches, it's already in the right format
  210. if (
  211. isinstance(images, (list, tuple))
  212. and all(isinstance(images_i, (list, tuple)) for images_i in images)
  213. and all(is_valid_list_of_images(images_i) or not images_i for images_i in images)
  214. ):
  215. return images
  216. # If it's a list of images, it's a single batch, so convert it to a list of lists
  217. if isinstance(images, (list, tuple)) and is_valid_list_of_images(images):
  218. if is_pil_image(images[0]) or images[0].ndim == expected_ndims:
  219. return [images]
  220. if images[0].ndim == expected_ndims + 1:
  221. return [list(image) for image in images]
  222. # If it's a single image, convert it to a list of lists
  223. if is_valid_image(images):
  224. if is_pil_image(images) or images.ndim == expected_ndims:
  225. return [[images]]
  226. if images.ndim == expected_ndims + 1:
  227. return [list(images)]
  228. raise ValueError("Invalid input type. Must be a single image, a list of images, or a list of batches of images.")
  229. def to_numpy_array(img) -> np.ndarray:
  230. if not is_valid_image(img):
  231. raise ValueError(f"Invalid image type: {type(img)}")
  232. if is_vision_available() and isinstance(img, PIL.Image.Image):
  233. return np.array(img)
  234. return to_numpy(img)
  235. def infer_channel_dimension_format(
  236. image: np.ndarray, num_channels: Optional[Union[int, tuple[int, ...]]] = None
  237. ) -> ChannelDimension:
  238. """
  239. Infers the channel dimension format of `image`.
  240. Args:
  241. image (`np.ndarray`):
  242. The image to infer the channel dimension of.
  243. num_channels (`int` or `tuple[int, ...]`, *optional*, defaults to `(1, 3)`):
  244. The number of channels of the image.
  245. Returns:
  246. The channel dimension of the image.
  247. """
  248. num_channels = num_channels if num_channels is not None else (1, 3)
  249. num_channels = (num_channels,) if isinstance(num_channels, int) else num_channels
  250. if image.ndim == 3:
  251. first_dim, last_dim = 0, 2
  252. elif image.ndim == 4:
  253. first_dim, last_dim = 1, 3
  254. elif image.ndim == 5:
  255. first_dim, last_dim = 2, 4
  256. else:
  257. raise ValueError(f"Unsupported number of image dimensions: {image.ndim}")
  258. if image.shape[first_dim] in num_channels and image.shape[last_dim] in num_channels:
  259. logger.warning(
  260. f"The channel dimension is ambiguous. Got image shape {image.shape}. Assuming channels are the first dimension. Use the [input_data_format](https://huggingface.co/docs/transformers/main/internal/image_processing_utils#transformers.image_transforms.rescale.input_data_format) parameter to assign the channel dimension."
  261. )
  262. return ChannelDimension.FIRST
  263. elif image.shape[first_dim] in num_channels:
  264. return ChannelDimension.FIRST
  265. elif image.shape[last_dim] in num_channels:
  266. return ChannelDimension.LAST
  267. raise ValueError("Unable to infer channel dimension format")
  268. def get_channel_dimension_axis(
  269. image: np.ndarray, input_data_format: Optional[Union[ChannelDimension, str]] = None
  270. ) -> int:
  271. """
  272. Returns the channel dimension axis of the image.
  273. Args:
  274. image (`np.ndarray`):
  275. The image to get the channel dimension axis of.
  276. input_data_format (`ChannelDimension` or `str`, *optional*):
  277. The channel dimension format of the image. If `None`, will infer the channel dimension from the image.
  278. Returns:
  279. The channel dimension axis of the image.
  280. """
  281. if input_data_format is None:
  282. input_data_format = infer_channel_dimension_format(image)
  283. if input_data_format == ChannelDimension.FIRST:
  284. return image.ndim - 3
  285. elif input_data_format == ChannelDimension.LAST:
  286. return image.ndim - 1
  287. raise ValueError(f"Unsupported data format: {input_data_format}")
  288. def get_image_size(image: np.ndarray, channel_dim: Optional[ChannelDimension] = None) -> tuple[int, int]:
  289. """
  290. Returns the (height, width) dimensions of the image.
  291. Args:
  292. image (`np.ndarray`):
  293. The image to get the dimensions of.
  294. channel_dim (`ChannelDimension`, *optional*):
  295. Which dimension the channel dimension is in. If `None`, will infer the channel dimension from the image.
  296. Returns:
  297. A tuple of the image's height and width.
  298. """
  299. if channel_dim is None:
  300. channel_dim = infer_channel_dimension_format(image)
  301. if channel_dim == ChannelDimension.FIRST:
  302. return image.shape[-2], image.shape[-1]
  303. elif channel_dim == ChannelDimension.LAST:
  304. return image.shape[-3], image.shape[-2]
  305. else:
  306. raise ValueError(f"Unsupported data format: {channel_dim}")
  307. def get_image_size_for_max_height_width(
  308. image_size: tuple[int, int],
  309. max_height: int,
  310. max_width: int,
  311. ) -> tuple[int, int]:
  312. """
  313. Computes the output image size given the input image and the maximum allowed height and width. Keep aspect ratio.
  314. Important, even if image_height < max_height and image_width < max_width, the image will be resized
  315. to at least one of the edges be equal to max_height or max_width.
  316. For example:
  317. - input_size: (100, 200), max_height: 50, max_width: 50 -> output_size: (25, 50)
  318. - input_size: (100, 200), max_height: 200, max_width: 500 -> output_size: (200, 400)
  319. Args:
  320. image_size (`tuple[int, int]`):
  321. The image to resize.
  322. max_height (`int`):
  323. The maximum allowed height.
  324. max_width (`int`):
  325. The maximum allowed width.
  326. """
  327. height, width = image_size
  328. height_scale = max_height / height
  329. width_scale = max_width / width
  330. min_scale = min(height_scale, width_scale)
  331. new_height = int(height * min_scale)
  332. new_width = int(width * min_scale)
  333. return new_height, new_width
  334. def is_valid_annotation_coco_detection(annotation: dict[str, Union[list, tuple]]) -> bool:
  335. if (
  336. isinstance(annotation, dict)
  337. and "image_id" in annotation
  338. and "annotations" in annotation
  339. and isinstance(annotation["annotations"], (list, tuple))
  340. and (
  341. # an image can have no annotations
  342. len(annotation["annotations"]) == 0 or isinstance(annotation["annotations"][0], dict)
  343. )
  344. ):
  345. return True
  346. return False
  347. def is_valid_annotation_coco_panoptic(annotation: dict[str, Union[list, tuple]]) -> bool:
  348. if (
  349. isinstance(annotation, dict)
  350. and "image_id" in annotation
  351. and "segments_info" in annotation
  352. and "file_name" in annotation
  353. and isinstance(annotation["segments_info"], (list, tuple))
  354. and (
  355. # an image can have no segments
  356. len(annotation["segments_info"]) == 0 or isinstance(annotation["segments_info"][0], dict)
  357. )
  358. ):
  359. return True
  360. return False
  361. def valid_coco_detection_annotations(annotations: Iterable[dict[str, Union[list, tuple]]]) -> bool:
  362. return all(is_valid_annotation_coco_detection(ann) for ann in annotations)
  363. def valid_coco_panoptic_annotations(annotations: Iterable[dict[str, Union[list, tuple]]]) -> bool:
  364. return all(is_valid_annotation_coco_panoptic(ann) for ann in annotations)
  365. def load_image(image: Union[str, "PIL.Image.Image"], timeout: Optional[float] = None) -> "PIL.Image.Image":
  366. """
  367. Loads `image` to a PIL Image.
  368. Args:
  369. image (`str` or `PIL.Image.Image`):
  370. The image to convert to the PIL Image format.
  371. timeout (`float`, *optional*):
  372. The timeout value in seconds for the URL request.
  373. Returns:
  374. `PIL.Image.Image`: A PIL Image.
  375. """
  376. requires_backends(load_image, ["vision"])
  377. if isinstance(image, str):
  378. if image.startswith("http://") or image.startswith("https://"):
  379. # We need to actually check for a real protocol, otherwise it's impossible to use a local file
  380. # like http_huggingface_co.png
  381. image = PIL.Image.open(BytesIO(requests.get(image, timeout=timeout).content))
  382. elif os.path.isfile(image):
  383. image = PIL.Image.open(image)
  384. else:
  385. if image.startswith("data:image/"):
  386. image = image.split(",")[1]
  387. # Try to load as base64
  388. try:
  389. b64 = base64.decodebytes(image.encode())
  390. image = PIL.Image.open(BytesIO(b64))
  391. except Exception as e:
  392. raise ValueError(
  393. f"Incorrect image source. Must be a valid URL starting with `http://` or `https://`, a valid path to an image file, or a base64 encoded string. Got {image}. Failed with {e}"
  394. )
  395. elif not isinstance(image, PIL.Image.Image):
  396. raise TypeError(
  397. "Incorrect format used for image. Should be an url linking to an image, a base64 string, a local path, or a PIL image."
  398. )
  399. image = PIL.ImageOps.exif_transpose(image)
  400. image = image.convert("RGB")
  401. return image
  402. def load_images(
  403. images: Union[list, tuple, str, "PIL.Image.Image"], timeout: Optional[float] = None
  404. ) -> Union["PIL.Image.Image", list["PIL.Image.Image"], list[list["PIL.Image.Image"]]]:
  405. """Loads images, handling different levels of nesting.
  406. Args:
  407. images: A single image, a list of images, or a list of lists of images to load.
  408. timeout: Timeout for loading images.
  409. Returns:
  410. A single image, a list of images, a list of lists of images.
  411. """
  412. if isinstance(images, (list, tuple)):
  413. if len(images) and isinstance(images[0], (list, tuple)):
  414. return [[load_image(image, timeout=timeout) for image in image_group] for image_group in images]
  415. else:
  416. return [load_image(image, timeout=timeout) for image in images]
  417. else:
  418. return load_image(images, timeout=timeout)
  419. def validate_preprocess_arguments(
  420. do_rescale: Optional[bool] = None,
  421. rescale_factor: Optional[float] = None,
  422. do_normalize: Optional[bool] = None,
  423. image_mean: Optional[Union[float, list[float]]] = None,
  424. image_std: Optional[Union[float, list[float]]] = None,
  425. do_pad: Optional[bool] = None,
  426. pad_size: Optional[Union[dict[str, int], int]] = None,
  427. do_center_crop: Optional[bool] = None,
  428. crop_size: Optional[dict[str, int]] = None,
  429. do_resize: Optional[bool] = None,
  430. size: Optional[dict[str, int]] = None,
  431. resample: Optional["PILImageResampling"] = None,
  432. interpolation: Optional["InterpolationMode"] = None,
  433. ):
  434. """
  435. Checks validity of typically used arguments in an `ImageProcessor` `preprocess` method.
  436. Raises `ValueError` if arguments incompatibility is caught.
  437. Many incompatibilities are model-specific. `do_pad` sometimes needs `size_divisor`,
  438. sometimes `size_divisibility`, and sometimes `size`. New models and processors added should follow
  439. existing arguments when possible.
  440. """
  441. if do_rescale and rescale_factor is None:
  442. raise ValueError("`rescale_factor` must be specified if `do_rescale` is `True`.")
  443. if do_pad and pad_size is None:
  444. # Processors pad images using different args depending on the model, so the below check is pointless
  445. # but we keep it for BC for now. TODO: remove in v5
  446. # Usually padding can be called with:
  447. # - "pad_size/size" if we're padding to specific values
  448. # - "size_divisor" if we're padding to any value divisible by X
  449. # - "None" if we're padding to the maximum size image in batch
  450. raise ValueError(
  451. "Depending on the model, `size_divisor` or `pad_size` or `size` must be specified if `do_pad` is `True`."
  452. )
  453. if do_normalize and (image_mean is None or image_std is None):
  454. raise ValueError("`image_mean` and `image_std` must both be specified if `do_normalize` is `True`.")
  455. if do_center_crop and crop_size is None:
  456. raise ValueError("`crop_size` must be specified if `do_center_crop` is `True`.")
  457. if interpolation is not None and resample is not None:
  458. raise ValueError(
  459. "Only one of `interpolation` and `resample` should be specified, depending on image processor type."
  460. )
  461. if do_resize and not (size is not None and (resample is not None or interpolation is not None)):
  462. raise ValueError("`size` and `resample/interpolation` must be specified if `do_resize` is `True`.")
  463. # In the future we can add a TF implementation here when we have TF models.
  464. class ImageFeatureExtractionMixin:
  465. """
  466. Mixin that contain utilities for preparing image features.
  467. """
  468. def _ensure_format_supported(self, image):
  469. if not isinstance(image, (PIL.Image.Image, np.ndarray)) and not is_torch_tensor(image):
  470. raise ValueError(
  471. f"Got type {type(image)} which is not supported, only `PIL.Image.Image`, `np.ndarray` and "
  472. "`torch.Tensor` are."
  473. )
  474. def to_pil_image(self, image, rescale=None):
  475. """
  476. Converts `image` to a PIL Image. Optionally rescales it and puts the channel dimension back as the last axis if
  477. needed.
  478. Args:
  479. image (`PIL.Image.Image` or `numpy.ndarray` or `torch.Tensor`):
  480. The image to convert to the PIL Image format.
  481. rescale (`bool`, *optional*):
  482. Whether or not to apply the scaling factor (to make pixel values integers between 0 and 255). Will
  483. default to `True` if the image type is a floating type, `False` otherwise.
  484. """
  485. self._ensure_format_supported(image)
  486. if is_torch_tensor(image):
  487. image = image.numpy()
  488. if isinstance(image, np.ndarray):
  489. if rescale is None:
  490. # rescale default to the array being of floating type.
  491. rescale = isinstance(image.flat[0], np.floating)
  492. # If the channel as been moved to first dim, we put it back at the end.
  493. if image.ndim == 3 and image.shape[0] in [1, 3]:
  494. image = image.transpose(1, 2, 0)
  495. if rescale:
  496. image = image * 255
  497. image = image.astype(np.uint8)
  498. return PIL.Image.fromarray(image)
  499. return image
  500. def convert_rgb(self, image):
  501. """
  502. Converts `PIL.Image.Image` to RGB format.
  503. Args:
  504. image (`PIL.Image.Image`):
  505. The image to convert.
  506. """
  507. self._ensure_format_supported(image)
  508. if not isinstance(image, PIL.Image.Image):
  509. return image
  510. return image.convert("RGB")
  511. def rescale(self, image: np.ndarray, scale: Union[float, int]) -> np.ndarray:
  512. """
  513. Rescale a numpy image by scale amount
  514. """
  515. self._ensure_format_supported(image)
  516. return image * scale
  517. def to_numpy_array(self, image, rescale=None, channel_first=True):
  518. """
  519. Converts `image` to a numpy array. Optionally rescales it and puts the channel dimension as the first
  520. dimension.
  521. Args:
  522. image (`PIL.Image.Image` or `np.ndarray` or `torch.Tensor`):
  523. The image to convert to a NumPy array.
  524. rescale (`bool`, *optional*):
  525. Whether or not to apply the scaling factor (to make pixel values floats between 0. and 1.). Will
  526. default to `True` if the image is a PIL Image or an array/tensor of integers, `False` otherwise.
  527. channel_first (`bool`, *optional*, defaults to `True`):
  528. Whether or not to permute the dimensions of the image to put the channel dimension first.
  529. """
  530. self._ensure_format_supported(image)
  531. if isinstance(image, PIL.Image.Image):
  532. image = np.array(image)
  533. if is_torch_tensor(image):
  534. image = image.numpy()
  535. rescale = isinstance(image.flat[0], np.integer) if rescale is None else rescale
  536. if rescale:
  537. image = self.rescale(image.astype(np.float32), 1 / 255.0)
  538. if channel_first and image.ndim == 3:
  539. image = image.transpose(2, 0, 1)
  540. return image
  541. def expand_dims(self, image):
  542. """
  543. Expands 2-dimensional `image` to 3 dimensions.
  544. Args:
  545. image (`PIL.Image.Image` or `np.ndarray` or `torch.Tensor`):
  546. The image to expand.
  547. """
  548. self._ensure_format_supported(image)
  549. # Do nothing if PIL image
  550. if isinstance(image, PIL.Image.Image):
  551. return image
  552. if is_torch_tensor(image):
  553. image = image.unsqueeze(0)
  554. else:
  555. image = np.expand_dims(image, axis=0)
  556. return image
    def normalize(self, image, mean, std, rescale=False):
        """
        Normalizes `image` with `mean` and `std`: computes `(image - mean) / std`. Note that this will trigger a
        conversion of `image` to a NumPy array if it's a PIL Image.

        Args:
            image (`PIL.Image.Image` or `np.ndarray` or `torch.Tensor`):
                The image to normalize.
            mean (`list[float]` or `np.ndarray` or `torch.Tensor`):
                The mean (per channel) to use for normalization.
            std (`list[float]` or `np.ndarray` or `torch.Tensor`):
                The standard deviation (per channel) to use for normalization.
            rescale (`bool`, *optional*, defaults to `False`):
                Whether or not to rescale the image to be between 0 and 1. If a PIL image is provided, scaling will
                happen automatically.

        Returns:
            The normalized image as an `np.ndarray` or `torch.Tensor`.
        """
        self._ensure_format_supported(image)

        if isinstance(image, PIL.Image.Image):
            image = self.to_numpy_array(image, rescale=True)
        # If the input image is a PIL image, it automatically gets rescaled. If it's another
        # type it may need rescaling.
        elif rescale:
            if isinstance(image, np.ndarray):
                image = self.rescale(image.astype(np.float32), 1 / 255.0)
            elif is_torch_tensor(image):
                image = self.rescale(image.float(), 1 / 255.0)

        # Bring mean/std into the same library as the image so the arithmetic below broadcasts.
        if isinstance(image, np.ndarray):
            # NOTE(review): mean/std are cast to the image dtype. For an integer array with rescale=False,
            # fractional values such as 0.5 are truncated. Confirm this is intended.
            if not isinstance(mean, np.ndarray):
                mean = np.array(mean).astype(image.dtype)
            if not isinstance(std, np.ndarray):
                std = np.array(std).astype(image.dtype)
        elif is_torch_tensor(image):
            import torch

            if not isinstance(mean, torch.Tensor):
                if isinstance(mean, np.ndarray):
                    mean = torch.from_numpy(mean)
                else:
                    mean = torch.tensor(mean)
            if not isinstance(std, torch.Tensor):
                if isinstance(std, np.ndarray):
                    std = torch.from_numpy(std)
                else:
                    std = torch.tensor(std)

        # A 3-D image with 1 or 3 leading channels is treated as channels-first: mean/std get trailing
        # singleton axes so they broadcast per channel. This indexing assumes mean/std are 1-D.
        if image.ndim == 3 and image.shape[0] in [1, 3]:
            return (image - mean[:, None, None]) / std[:, None, None]
        else:
            return (image - mean) / std
    def resize(self, image, size, resample=None, default_to_square=True, max_size=None):
        """
        Resizes `image`. Enforces conversion of input to PIL.Image.

        Args:
            image (`PIL.Image.Image` or `np.ndarray` or `torch.Tensor`):
                The image to resize.
            size (`int` or `tuple[int, int]`):
                The size to use for resizing the image. If `size` is a sequence like (h, w), output size will be
                matched to this.

                If `size` is an int and `default_to_square` is `True`, then image will be resized to (size, size). If
                `size` is an int and `default_to_square` is `False`, then smaller edge of the image will be matched to
                this number. i.e, if height > width, then image will be rescaled to (size * height / width, size).
            resample (`int`, *optional*, defaults to `PILImageResampling.BILINEAR`):
                The filter to use for resampling.
            default_to_square (`bool`, *optional*, defaults to `True`):
                How to convert `size` when it is a single int. If set to `True`, the `size` will be converted to a
                square (`size`,`size`). If set to `False`, will replicate
                [`torchvision.transforms.Resize`](https://pytorch.org/vision/stable/transforms.html#torchvision.transforms.Resize)
                with support for resizing only the smallest edge and providing an optional `max_size`.
            max_size (`int`, *optional*, defaults to `None`):
                The maximum allowed for the longer edge of the resized image: if the longer edge of the image is
                greater than `max_size` after being resized according to `size`, then the image is resized again so
                that the longer edge is equal to `max_size`. As a result, `size` might be overruled, i.e the smaller
                edge may be shorter than `size`. Only used if `default_to_square` is `False`.

        Returns:
            image: A resized `PIL.Image.Image`.

        Raises:
            `ValueError`: If `max_size` is not strictly greater than the requested shorter-edge size.
        """
        resample = resample if resample is not None else PILImageResampling.BILINEAR

        self._ensure_format_supported(image)
        if not isinstance(image, PIL.Image.Image):
            image = self.to_pil_image(image)

        if isinstance(size, list):
            size = tuple(size)

        # NOTE(review): a two-element `size` is passed straight to `PIL.Image.resize`, which takes (width, height),
        # although the docstring describes it as (h, w). Confirm the expected order with callers.
        if isinstance(size, int) or len(size) == 1:
            if default_to_square:
                size = (size, size) if isinstance(size, int) else (size[0], size[0])
            else:
                width, height = image.size
                # specified size only for the smallest edge
                short, long = (width, height) if width <= height else (height, width)
                requested_new_short = size if isinstance(size, int) else size[0]

                # NOTE(review): this early return skips the `max_size` check below, so an image whose short edge
                # already matches is returned even if its long edge exceeds `max_size`.
                if short == requested_new_short:
                    return image

                new_short, new_long = requested_new_short, int(requested_new_short * long / short)

                if max_size is not None:
                    if max_size <= requested_new_short:
                        raise ValueError(
                            f"max_size = {max_size} must be strictly greater than the requested "
                            f"size for the smaller edge size = {size}"
                        )
                    if new_long > max_size:
                        new_short, new_long = int(max_size * new_short / new_long), max_size

                # Map (short, long) back to PIL's (width, height) order.
                size = (new_short, new_long) if width <= height else (new_long, new_short)

        return image.resize(size, resample=resample)
  def center_crop(self, image, size):
      """
      Crops `image` to the given size using a center crop. Note that if the image is too small to be cropped to the
      size given, it will be padded (so the returned result has the size asked).

      Args:
          image (`PIL.Image.Image` or `np.ndarray` or `torch.Tensor` of shape (n_channels, height, width) or (height, width, n_channels)):
              The image to crop.
          size (`int` or `tuple[int, int]` or `list[int]`):
              The size to which crop the image, as (height, width). An int, or a one-element tuple/list, is used for
              both edges.

      Returns:
          new_image: A center cropped `PIL.Image.Image` or `np.ndarray` or `torch.Tensor` of shape: (n_channels,
          height, width).
      """
      self._ensure_format_supported(image)

      # Normalize `size` to a (height, width) tuple. Lists are converted element-wise (like `resize` does) rather
      # than being wrapped whole, which would produce an unusable tuple of lists such as ([h, w], [h, w]).
      if isinstance(size, (list, tuple)):
          size = (size[0], size[0]) if len(size) == 1 else tuple(size)
      else:
          size = (size, size)

      # PIL Image.size is (width, height) but NumPy array and torch Tensors have (height, width)
      if is_torch_tensor(image) or isinstance(image, np.ndarray):
          if image.ndim == 2:
              # `expand_dims` is expected to add a channel axis so 2D inputs take the channel-first path below.
              image = self.expand_dims(image)
          image_shape = image.shape[1:] if image.shape[0] in [1, 3] else image.shape[:2]
      else:
          image_shape = (image.size[1], image.size[0])

      top = (image_shape[0] - size[0]) // 2
      bottom = top + size[0]  # In case size is odd, (image_shape[0] + size[0]) // 2 won't give the proper result.
      left = (image_shape[1] - size[1]) // 2
      right = left + size[1]  # In case size is odd, (image_shape[1] + size[1]) // 2 won't give the proper result.

      # For PIL Images we have a method to crop directly.
      if isinstance(image, PIL.Image.Image):
          return image.crop((left, top, right, bottom))

      # Check if image is in (n_channels, height, width) or (height, width, n_channels) format
      channel_first = image.shape[0] in [1, 3]

      # Transpose (height, width, n_channels) format images
      if not channel_first:
          if isinstance(image, np.ndarray):
              image = image.transpose(2, 0, 1)
          if is_torch_tensor(image):
              image = image.permute(2, 0, 1)

      # Check if cropped area is within image boundaries
      if top >= 0 and bottom <= image_shape[0] and left >= 0 and right <= image_shape[1]:
          return image[..., top:bottom, left:right]

      # Otherwise, we may need to pad if the image is too small. Oh joy...
      new_shape = image.shape[:-2] + (max(size[0], image_shape[0]), max(size[1], image_shape[1]))
      if isinstance(image, np.ndarray):
          new_image = np.zeros_like(image, shape=new_shape)
      elif is_torch_tensor(image):
          new_image = image.new_zeros(new_shape)

      # Place the original image in the middle of the zero canvas, then shift the crop window accordingly.
      top_pad = (new_shape[-2] - image_shape[0]) // 2
      bottom_pad = top_pad + image_shape[0]
      left_pad = (new_shape[-1] - image_shape[1]) // 2
      right_pad = left_pad + image_shape[1]
      new_image[..., top_pad:bottom_pad, left_pad:right_pad] = image

      top += top_pad
      bottom += top_pad
      left += left_pad
      right += left_pad

      new_image = new_image[
          ..., max(0, top) : min(new_image.shape[-2], bottom), max(0, left) : min(new_image.shape[-1], right)
      ]

      return new_image
  def flip_channel_order(self, image):
      """
      Flips the channel order of `image` from RGB to BGR, or vice versa. Note that this will trigger a conversion of
      `image` to a NumPy array if it's a PIL Image.

      Args:
          image (`PIL.Image.Image` or `np.ndarray` or `torch.Tensor`):
              The image whose color channels to flip. If `np.ndarray` or `torch.Tensor`, the channel dimension should
              be first.

      Returns:
          The image with its first (channel) axis reversed. NumPy inputs (and converted PIL inputs) are returned as a
          reversed view; `torch.Tensor` inputs are returned as a flipped copy.
      """
      self._ensure_format_supported(image)
      if isinstance(image, PIL.Image.Image):
          image = self.to_numpy_array(image)
      # torch tensors reject negative slice steps (`t[::-1]` raises ValueError), so use `flip` on the channel axis.
      if is_torch_tensor(image):
          return image.flip(0)
      return image[::-1, :, :]
  def rotate(self, image, angle, resample=None, expand=0, center=None, translate=None, fillcolor=None):
      """
      Returns a copy of `image` rotated `angle` degrees counter clockwise around its centre.

      Args:
          image (`PIL.Image.Image` or `np.ndarray` or `torch.Tensor`):
              The image to rotate. Anything that is not already a `PIL.Image.Image` is converted with
              `to_pil_image` before rotating.
          angle:
              Rotation angle in degrees, counter clockwise.
          resample (*optional*):
              Resampling filter forwarded to `PIL.Image.Image.rotate`; `PIL.Image.NEAREST` is used when `None`.
          expand, center, translate, fillcolor (*optional*):
              Forwarded unchanged to `PIL.Image.Image.rotate`.

      Returns:
          image: A rotated `PIL.Image.Image`.
      """
      if resample is None:
          resample = PIL.Image.NEAREST
      self._ensure_format_supported(image)

      pil_image = image if isinstance(image, PIL.Image.Image) else self.to_pil_image(image)
      rotate_options = dict(
          resample=resample,
          expand=expand,
          center=center,
          translate=translate,
          fillcolor=fillcolor,
      )
      return pil_image.rotate(angle, **rotate_options)
  def validate_annotations(
      annotation_format: AnnotationFormat,
      supported_annotation_formats: tuple[AnnotationFormat, ...],
      annotations: list[dict],
  ) -> None:
      """
      Check that `annotation_format` is supported and that `annotations` are valid for that format.

      Args:
          annotation_format (`AnnotationFormat`):
              The format the annotations are expected to be in.
          supported_annotation_formats (`tuple[AnnotationFormat, ...]`):
              The annotation formats accepted by the caller.
          annotations (`list[dict]`):
              The annotations to validate.

      Raises:
          ValueError: If `annotation_format` is not one of `supported_annotation_formats`, or if `annotations` fail
              the COCO detection / COCO panoptic validity check matching `annotation_format`.
      """
      if annotation_format not in supported_annotation_formats:
          # Report the rejected value itself (the message previously interpolated the builtin `format` function).
          raise ValueError(
              f"Unsupported annotation format: {annotation_format} must be one of {supported_annotation_formats}"
          )

      if annotation_format is AnnotationFormat.COCO_DETECTION:
          if not valid_coco_detection_annotations(annotations):
              raise ValueError(
                  "Invalid COCO detection annotations. Annotations must be a dict (single image) or list of dicts "
                  "(batch of images) with the following keys: `image_id` and `annotations`, with the latter "
                  "being a list of annotations in the COCO format."
              )
      elif annotation_format is AnnotationFormat.COCO_PANOPTIC:
          if not valid_coco_panoptic_annotations(annotations):
              raise ValueError(
                  "Invalid COCO panoptic annotations. Annotations must be a dict (single image) or list of dicts "
                  "(batch of images) with the following keys: `image_id`, `file_name` and `segments_info`, with "
                  "the latter being a list of annotations in the COCO format."
              )
  def validate_kwargs(valid_processor_keys: list[str], captured_kwargs: list[str]):
      """
      Log a warning naming every captured kwarg that is not a valid processor key.

      Args:
          valid_processor_keys (`list[str]`):
              The kwarg names the processor recognizes.
          captured_kwargs (`list[str]`):
              The kwarg names that were actually passed.
      """
      unused_keys = set(captured_kwargs).difference(valid_processor_keys)
      if unused_keys:
          # Sort so the message is stable: iteration order of a set of strings changes between interpreter runs
          # because of string hash randomization.
          unused_key_str = ", ".join(sorted(unused_keys))
          # TODO raise a warning here instead of simply logging?
          logger.warning(f"Unused or unrecognized kwargs: {unused_key_str}.")
  @dataclass(frozen=True)
  class SizeDict:
      """
      Hashable dictionary to store image size information.

      Values are dataclass fields that can also be read with `size_dict["field_name"]`; a field that was not set
      is `None`.
      """

      height: Optional[int] = None
      width: Optional[int] = None
      longest_edge: Optional[int] = None
      shortest_edge: Optional[int] = None
      max_height: Optional[int] = None
      max_width: Optional[int] = None

      def __getitem__(self, key):
          # Only the declared fields are valid keys. A plain `hasattr` check would also expose methods and dunder
          # attributes (e.g. `"__class__"`) and raise TypeError instead of KeyError for non-string keys.
          from dataclasses import fields

          if isinstance(key, str) and key in {f.name for f in fields(self)}:
              return getattr(self, key)
          raise KeyError(f"Key {key} not found in SizeDict.")