| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693694695696697698699700701702
703704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833 |
- # Copyright 2024 The HuggingFace Inc. team.
- #
- # Licensed under the Apache License, Version 2.0 (the "License");
- # you may not use this file except in compliance with the License.
- # You may obtain a copy of the License at
- #
- # http://www.apache.org/licenses/LICENSE-2.0
- #
- # Unless required by applicable law or agreed to in writing, software
- # distributed under the License is distributed on an "AS IS" BASIS,
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- # See the License for the specific language governing permissions and
- # limitations under the License.
- from collections.abc import Iterable
- from copy import deepcopy
- from functools import lru_cache, partial
- from typing import Any, Optional, TypedDict, Union
- import numpy as np
- from .image_processing_utils import BaseImageProcessor, BatchFeature, get_size_dict
- from .image_transforms import (
- convert_to_rgb,
- get_resize_output_image_size,
- get_size_with_aspect_ratio,
- group_images_by_shape,
- reorder_images,
- )
- from .image_utils import (
- ChannelDimension,
- ImageInput,
- ImageType,
- SizeDict,
- get_image_size,
- get_image_size_for_max_height_width,
- get_image_type,
- infer_channel_dimension_format,
- make_flat_list_of_images,
- validate_kwargs,
- validate_preprocess_arguments,
- )
- from .processing_utils import Unpack
- from .utils import (
- TensorType,
- auto_docstring,
- is_torch_available,
- is_torchvision_available,
- is_vision_available,
- logging,
- )
- from .utils.import_utils import is_rocm_platform
# Optional backends: each name below is only bound when the corresponding library is installed.
if is_vision_available():
    from .image_utils import PILImageResampling

if is_torch_available():
    import torch

if is_torchvision_available():
    from torchvision.transforms.v2 import functional as F

    from .image_utils import pil_torch_interpolation_mapping
else:
    # Keep the module-level name defined even when torchvision is missing.
    pil_torch_interpolation_mapping = None

logger = logging.get_logger(__name__)
@lru_cache(maxsize=10)
def validate_fast_preprocess_arguments(
    do_rescale: Optional[bool] = None,
    rescale_factor: Optional[float] = None,
    do_normalize: Optional[bool] = None,
    image_mean: Optional[Union[float, list[float]]] = None,
    image_std: Optional[Union[float, list[float]]] = None,
    do_center_crop: Optional[bool] = None,
    crop_size: Optional[SizeDict] = None,
    do_resize: Optional[bool] = None,
    size: Optional[SizeDict] = None,
    interpolation: Optional["F.InterpolationMode"] = None,
    return_tensors: Optional[Union[str, TensorType]] = None,
    data_format: ChannelDimension = ChannelDimension.FIRST,
):
    """
    Validate the arguments typically passed to an `ImageProcessorFast` `preprocess` method.

    The generic checks are delegated to `validate_preprocess_arguments`. On top of those, fast processors
    only support PyTorch outputs and channel-first data.

    The function is wrapped in `lru_cache`, so every argument must be hashable (e.g. `image_mean` /
    `image_std` as tuples rather than lists).

    Raises:
        `ValueError`: if an incompatible combination of arguments is detected.
    """
    # Checks shared with the slow image processors.
    validate_preprocess_arguments(
        do_rescale=do_rescale,
        rescale_factor=rescale_factor,
        do_normalize=do_normalize,
        image_mean=image_mean,
        image_std=image_std,
        do_center_crop=do_center_crop,
        crop_size=crop_size,
        do_resize=do_resize,
        size=size,
        interpolation=interpolation,
    )

    # Constraints that only apply to the fast (torchvision-backed) processors.
    wants_non_torch_output = return_tensors is not None and return_tensors != "pt"
    if wants_non_torch_output:
        raise ValueError("Only returning PyTorch tensors is currently supported.")
    if data_format != ChannelDimension.FIRST:
        raise ValueError("Only channel first data format is currently supported.")
def safe_squeeze(tensor: "torch.Tensor", axis: Optional[int] = None) -> "torch.Tensor":
    """
    Squeeze `tensor`, removing `axis` only when that axis has size 1.

    With `axis=None` every size-1 dimension is removed. If squeezing the requested axis raises a
    `ValueError` (as NumPy does for an axis whose size is not 1), the input is returned unchanged.
    """
    if axis is not None:
        try:
            squeezed = tensor.squeeze(axis=axis)
        except ValueError:
            squeezed = tensor
        return squeezed
    return tensor.squeeze()
def max_across_indices(values: Iterable[Any]) -> list[Any]:
    """
    Return, for each position, the maximum over all sequences in `values`.

    The sequences are aligned position-wise with `zip`, so the result is as long as the shortest
    sequence; an empty `values` yields an empty list.
    """
    columns = zip(*values)
    return list(map(max, columns))
def get_max_height_width(images: list["torch.Tensor"]) -> tuple[int, ...]:
    """
    Return `(max_height, max_width)` over a batch of images.

    Each image is expected to have exactly three dimensions, in `(channels, height, width)` order; the
    leading (channel) maximum is discarded.
    """
    shapes = [image.shape for image in images]
    _, tallest, widest = max_across_indices(shapes)
    return (tallest, widest)
def divide_to_patches(
    image: Union[np.ndarray, "torch.Tensor"], patch_size: int
) -> list[Union[np.ndarray, "torch.Tensor"]]:
    """
    Split a channel-first image into square patches, scanning row by row.

    Patches on the bottom/right border are smaller when the image size is not a multiple of
    `patch_size`.

    Args:
        image (`Union[np.array, "torch.Tensor"]`):
            The input image, in channel-first layout.
        patch_size (`int`):
            Side length of each patch.

    Returns:
        list: The patches (same array/tensor type as `image`), ordered top-to-bottom, left-to-right.
    """
    height, width = get_image_size(image, channel_dim=ChannelDimension.FIRST)
    return [
        image[:, top : top + patch_size, left : left + patch_size]
        for top in range(0, height, patch_size)
        for left in range(0, width, patch_size)
    ]
class DefaultFastImageProcessorKwargs(TypedDict, total=False):
    """
    Keyword arguments understood by `BaseImageProcessorFast` (both its constructor and `preprocess`).

    All keys are optional (`total=False`). When this class is used as a processor's `valid_kwargs`,
    the declaration order below is the order of the processor's `_valid_kwargs_names`.
    """

    # Resizing
    do_resize: Optional[bool]
    size: Optional[dict[str, int]]
    default_to_square: Optional[bool]
    resample: Optional[Union["PILImageResampling", "F.InterpolationMode"]]
    # Center cropping
    do_center_crop: Optional[bool]
    crop_size: Optional[dict[str, int]]
    # Rescaling and normalization
    do_rescale: Optional[bool]
    rescale_factor: Optional[Union[int, float]]
    do_normalize: Optional[bool]
    image_mean: Optional[Union[float, list[float]]]
    image_std: Optional[Union[float, list[float]]]
    # Padding
    do_pad: Optional[bool]
    pad_size: Optional[dict[str, int]]
    # Input conversion and output format
    do_convert_rgb: Optional[bool]
    return_tensors: Optional[Union[str, TensorType]]
    data_format: Optional[ChannelDimension]
    input_data_format: Optional[Union[str, ChannelDimension]]
    # Execution
    device: Optional["torch.device"]
    disable_grouping: Optional[bool]
@auto_docstring
class BaseImageProcessorFast(BaseImageProcessor):
    # Class-level defaults. Subclasses override these; `__init__` either replaces each one with a
    # user-provided kwarg or stores a deep copy of the default on the instance.
    resample = None
    image_mean = None
    image_std = None
    size = None
    default_to_square = True
    crop_size = None
    do_resize = None
    do_center_crop = None
    do_pad = None
    pad_size = None
    do_rescale = None
    rescale_factor = 1 / 255
    do_normalize = None
    do_convert_rgb = None
    return_tensors = None
    data_format = ChannelDimension.FIRST
    input_data_format = None
    device = None
    model_input_names = ["pixel_values"]
    valid_kwargs = DefaultFastImageProcessorKwargs
    # Optional iterable of kwarg names this processor ignores (see `filter_out_unused_kwargs`).
    unused_kwargs = None

    def __init__(self, **kwargs: Unpack[DefaultFastImageProcessorKwargs]):
        super().__init__(**kwargs)
        kwargs = self.filter_out_unused_kwargs(kwargs)
        # Size-like kwargs are normalized into size dicts before being stored.
        size = kwargs.pop("size", self.size)
        self.size = (
            get_size_dict(size=size, default_to_square=kwargs.pop("default_to_square", self.default_to_square))
            if size is not None
            else None
        )
        crop_size = kwargs.pop("crop_size", self.crop_size)
        self.crop_size = get_size_dict(crop_size, param_name="crop_size") if crop_size is not None else None
        pad_size = kwargs.pop("pad_size", self.pad_size)
        self.pad_size = get_size_dict(size=pad_size, param_name="pad_size") if pad_size is not None else None
        # Every supported kwarg is either taken from the caller or deep-copied from the current
        # attribute, so mutable class defaults (e.g. lists) are never shared with the instance.
        for key in self.valid_kwargs.__annotations__:
            kwarg = kwargs.pop(key, None)
            if kwarg is not None:
                setattr(self, key, kwarg)
            else:
                setattr(self, key, deepcopy(getattr(self, key, None)))

        # get valid kwargs names
        self._valid_kwargs_names = list(self.valid_kwargs.__annotations__.keys())

    @property
    def is_fast(self) -> bool:
        """
        `bool`: Whether or not this image processor is a fast processor (backed by PyTorch and TorchVision).
        """
        return True

    def pad(
        self,
        images: list["torch.Tensor"],
        pad_size: SizeDict = None,
        fill_value: Optional[int] = 0,
        padding_mode: Optional[str] = "constant",
        return_mask: bool = False,
        disable_grouping: Optional[bool] = False,
        **kwargs,
    ) -> Union[list["torch.Tensor"], tuple[list["torch.Tensor"], list["torch.Tensor"]]]:
        """
        Pads images to `(pad_size["height"], pad_size["width"])` or to the largest size in the batch.

        Padding is only added on the bottom and right, so the original content stays anchored at the
        top-left corner.

        Args:
            images (`list[torch.Tensor]`):
                Channel-first images to pad.
            pad_size (`SizeDict`, *optional*):
                Dictionary in the format `{"height": int, "width": int}` specifying the size of the output image.
                When omitted, the maximum height and width across `images` is used.
            fill_value (`int`, *optional*, defaults to `0`):
                The constant value used to fill the padded area.
            padding_mode (`str`, *optional*, defaults to "constant"):
                The padding mode to use. Can be any of the modes supported by
                `torch.nn.functional.pad` (e.g. constant, reflection, replication).
            return_mask (`bool`, *optional*, defaults to `False`):
                Whether to return a pixel mask to denote padded regions.
            disable_grouping (`bool`, *optional*, defaults to `False`):
                Whether to disable grouping of images by size.

        Returns:
            `list[torch.Tensor]`: The padded images, in the input order. If `return_mask=True`, a tuple
            `(padded_images, pixel_masks)` is returned instead, where each `int64` mask is `1` on original
            pixels and `0` on padding.

        Raises:
            `ValueError`: if `pad_size` lacks a height or width, or is smaller than an image.
        """
        if pad_size is not None:
            if not (pad_size.height and pad_size.width):
                raise ValueError(f"Pad size must contain 'height' and 'width' keys only. Got pad_size={pad_size}.")
            pad_size = (pad_size.height, pad_size.width)
        else:
            pad_size = get_max_height_width(images)

        grouped_images, grouped_images_index = group_images_by_shape(images, disable_grouping=disable_grouping)
        processed_images_grouped = {}
        processed_masks_grouped = {}
        for shape, stacked_images in grouped_images.items():
            image_size = stacked_images.shape[-2:]
            padding_height = pad_size[0] - image_size[0]
            padding_width = pad_size[1] - image_size[1]
            if padding_height < 0 or padding_width < 0:
                raise ValueError(
                    f"Padding dimensions are negative. Please make sure that the `pad_size` is larger than the "
                    f"image size. Got pad_size={pad_size}, image_size={image_size}."
                )
            if image_size != pad_size:
                # (left, top, right, bottom): pad only on the right and bottom.
                padding = (0, 0, padding_width, padding_height)
                stacked_images = F.pad(stacked_images, padding, fill=fill_value, padding_mode=padding_mode)
            processed_images_grouped[shape] = stacked_images

            if return_mask:
                # keep only one from the channel dimension in pixel mask
                stacked_masks = torch.zeros_like(stacked_images, dtype=torch.int64)[..., 0, :, :]
                stacked_masks[..., : image_size[0], : image_size[1]] = 1
                processed_masks_grouped[shape] = stacked_masks

        processed_images = reorder_images(processed_images_grouped, grouped_images_index)
        if return_mask:
            processed_masks = reorder_images(processed_masks_grouped, grouped_images_index)
            return processed_images, processed_masks

        return processed_images

    def resize(
        self,
        image: "torch.Tensor",
        size: SizeDict,
        interpolation: Optional["F.InterpolationMode"] = None,
        antialias: bool = True,
        **kwargs,
    ) -> "torch.Tensor":
        """
        Resize an image (or a stacked batch of channel-first images).

        The target size is chosen from `size`, in this order of precedence:
        `shortest_edge` + `longest_edge` (aspect ratio kept), `shortest_edge` alone (aspect ratio kept),
        `max_height` + `max_width` (fit inside the box), then exact `height` + `width`.

        Args:
            image (`torch.Tensor`):
                Image to resize.
            size (`SizeDict`):
                Size specification, see above.
            interpolation (`InterpolationMode`, *optional*, defaults to `InterpolationMode.BILINEAR`):
                `InterpolationMode` filter to use when resizing the image e.g. `InterpolationMode.BICUBIC`.
            antialias (`bool`, *optional*, defaults to `True`):
                Whether to apply antialiasing, forwarded to `F.resize`.

        Returns:
            `torch.Tensor`: The resized image.

        Raises:
            `ValueError`: if `size` contains none of the supported key combinations.
        """
        interpolation = interpolation if interpolation is not None else F.InterpolationMode.BILINEAR
        if size.shortest_edge and size.longest_edge:
            # Resize the image so that the shortest edge or the longest edge is of the given size
            # while maintaining the aspect ratio of the original image.
            new_size = get_size_with_aspect_ratio(
                image.size()[-2:],
                size.shortest_edge,
                size.longest_edge,
            )
        elif size.shortest_edge:
            new_size = get_resize_output_image_size(
                image,
                size=size.shortest_edge,
                default_to_square=False,
                input_data_format=ChannelDimension.FIRST,
            )
        elif size.max_height and size.max_width:
            new_size = get_image_size_for_max_height_width(image.size()[-2:], size.max_height, size.max_width)
        elif size.height and size.width:
            new_size = (size.height, size.width)
        else:
            raise ValueError(
                "Size must contain 'height' and 'width' keys, or 'max_height' and 'max_width', or 'shortest_edge' key. Got"
                f" {size}."
            )
        # This is a workaround to avoid a bug in torch.compile when dealing with uint8 on AMD MI3XX GPUs
        # Tracked in PyTorch issue: https://github.com/pytorch/pytorch/issues/155209
        # TODO: remove this once the bug is fixed (detected with torch==2.7.0+git1fee196, torchvision==0.22.0+9eb57cd)
        if torch.compiler.is_compiling() and is_rocm_platform():
            return self.compile_friendly_resize(image, new_size, interpolation, antialias)
        return F.resize(image, new_size, interpolation=interpolation, antialias=antialias)

    @staticmethod
    def compile_friendly_resize(
        image: "torch.Tensor",
        new_size: tuple[int, int],
        interpolation: Optional["F.InterpolationMode"] = None,
        antialias: bool = True,
    ) -> "torch.Tensor":
        """
        A wrapper around `F.resize` so that it is compatible with torch.compile when the image is a uint8 tensor.

        uint8 inputs are resized in float space and converted back to uint8 (rounded and clipped to [0, 255]);
        other dtypes are passed straight to `F.resize`.
        """
        if image.dtype == torch.uint8:
            # 256 is used on purpose instead of 255 to avoid numerical differences
            # see https://github.com/huggingface/transformers/pull/38540#discussion_r2127165652
            image = image.float() / 256
            image = F.resize(image, new_size, interpolation=interpolation, antialias=antialias)
            image = image * 256
            # torch.where is used on purpose instead of torch.clamp to avoid bug in torch.compile
            # see https://github.com/huggingface/transformers/pull/38540#discussion_r2126888471
            image = torch.where(image > 255, 255, image)
            image = torch.where(image < 0, 0, image)
            image = image.round().to(torch.uint8)
        else:
            image = F.resize(image, new_size, interpolation=interpolation, antialias=antialias)
        return image

    def rescale(
        self,
        image: "torch.Tensor",
        scale: float,
        **kwargs,
    ) -> "torch.Tensor":
        """
        Rescale an image by a scale factor. image = image * scale.

        Args:
            image (`torch.Tensor`):
                Image to rescale.
            scale (`float`):
                The scaling factor to rescale pixel values by.

        Returns:
            `torch.Tensor`: The rescaled image.
        """
        return image * scale

    def normalize(
        self,
        image: "torch.Tensor",
        mean: Union[float, Iterable[float]],
        std: Union[float, Iterable[float]],
        **kwargs,
    ) -> "torch.Tensor":
        """
        Normalize an image. image = (image - image_mean) / image_std.

        Args:
            image (`torch.Tensor`):
                Image to normalize.
            mean (`torch.Tensor`, `float` or `Iterable[float]`):
                Image mean to use for normalization.
            std (`torch.Tensor`, `float` or `Iterable[float]`):
                Image standard deviation to use for normalization.

        Returns:
            `torch.Tensor`: The normalized image.
        """
        return F.normalize(image, mean, std)

    # Cached as a staticmethod rather than an instance method: the result depends only on the
    # arguments, and caching a bound method would key on (and keep alive) `self` (ruff B019).
    # All arguments must therefore be hashable (mean/std as tuples, see `_further_process_kwargs`).
    @staticmethod
    @lru_cache(maxsize=10)
    def _fuse_mean_std_and_rescale_factor(
        do_normalize: Optional[bool] = None,
        image_mean: Optional[Union[float, list[float]]] = None,
        image_std: Optional[Union[float, list[float]]] = None,
        do_rescale: Optional[bool] = None,
        rescale_factor: Optional[float] = None,
        device: Optional["torch.device"] = None,
    ) -> tuple:
        """
        Fold the rescale factor into the normalization statistics when both steps are enabled.

        Since `(x * r - mean) / std == (x - mean / r) / (std / r)`, rescaling can be skipped by dividing
        mean and std by `rescale_factor`.

        Returns:
            `tuple`: `(image_mean, image_std, do_rescale)`. When both `do_rescale` and `do_normalize` are set,
            mean/std are returned as tensors on `device` and `do_rescale` is `False`; otherwise the inputs
            are returned unchanged.
        """
        if do_rescale and do_normalize:
            # Fused rescale and normalize
            image_mean = torch.tensor(image_mean, device=device) * (1.0 / rescale_factor)
            image_std = torch.tensor(image_std, device=device) * (1.0 / rescale_factor)
            do_rescale = False
        return image_mean, image_std, do_rescale

    def rescale_and_normalize(
        self,
        images: "torch.Tensor",
        do_rescale: bool,
        rescale_factor: float,
        do_normalize: bool,
        image_mean: Union[float, list[float]],
        image_std: Union[float, list[float]],
    ) -> "torch.Tensor":
        """
        Rescale and normalize images, fusing both steps into a single normalization when both are enabled.
        """
        image_mean, image_std, do_rescale = self._fuse_mean_std_and_rescale_factor(
            do_normalize=do_normalize,
            image_mean=image_mean,
            image_std=image_std,
            do_rescale=do_rescale,
            rescale_factor=rescale_factor,
            device=images.device,
        )
        # if/elif as we use fused rescale and normalize if both are set to True
        if do_normalize:
            images = self.normalize(images.to(dtype=torch.float32), image_mean, image_std)
        elif do_rescale:
            images = self.rescale(images, rescale_factor)

        return images

    def center_crop(
        self,
        image: "torch.Tensor",
        size: SizeDict,
        **kwargs,
    ) -> "torch.Tensor":
        """
        Note: override torchvision's center_crop to have the same behavior as the slow processor.
        Center crop an image to `(size["height"], size["width"])`. If the input size is smaller than `crop_size` along
        any edge, the image is padded with 0's and then center cropped.

        Args:
            image (`"torch.Tensor"`):
                Image to center crop.
            size (`SizeDict`):
                Size of the output image; must define `height` and `width`.

        Returns:
            `torch.Tensor`: The center cropped image.

        Raises:
            `ValueError`: if `size` does not define both `height` and `width`.
        """
        if size.height is None or size.width is None:
            # `SizeDict` is not a mapping, so report the whole object rather than calling `.keys()` on it.
            raise ValueError(f"The size dictionary must have keys 'height' and 'width'. Got {size}")
        image_height, image_width = image.shape[-2:]
        crop_height, crop_width = size.height, size.width

        if crop_width > image_width or crop_height > image_height:
            # Split the missing pixels between both sides; the extra pixel (odd difference) goes right/bottom.
            padding_ltrb = [
                (crop_width - image_width) // 2 if crop_width > image_width else 0,
                (crop_height - image_height) // 2 if crop_height > image_height else 0,
                (crop_width - image_width + 1) // 2 if crop_width > image_width else 0,
                (crop_height - image_height + 1) // 2 if crop_height > image_height else 0,
            ]
            image = F.pad(image, padding_ltrb, fill=0)  # PIL uses fill value 0
            image_height, image_width = image.shape[-2:]
            if crop_width == image_width and crop_height == image_height:
                return image

        crop_top = int((image_height - crop_height) / 2.0)
        crop_left = int((image_width - crop_width) / 2.0)
        return F.crop(image, crop_top, crop_left, crop_height, crop_width)

    def convert_to_rgb(
        self,
        image: ImageInput,
    ) -> ImageInput:
        """
        Converts an image to RGB format. Only converts if the image is of type PIL.Image.Image, otherwise returns the image
        as is.

        Args:
            image (ImageInput):
                The image to convert.

        Returns:
            ImageInput: The converted image.
        """
        return convert_to_rgb(image)

    def filter_out_unused_kwargs(self, kwargs: dict):
        """
        Filter out the unused kwargs from the kwargs dictionary.

        Names listed in `self.unused_kwargs` are removed from `kwargs` in place (with a one-time warning),
        and the same dict is returned.
        """
        if self.unused_kwargs is None:
            return kwargs

        for kwarg_name in self.unused_kwargs:
            if kwarg_name in kwargs:
                logger.warning_once(f"This processor does not use the `{kwarg_name}` parameter. It will be ignored.")
                kwargs.pop(kwarg_name)

        return kwargs

    def _prepare_images_structure(
        self,
        images: ImageInput,
        expected_ndims: int = 3,
    ) -> ImageInput:
        """
        Prepare the images structure for processing.

        Args:
            images (`ImageInput`):
                The input images to process.
            expected_ndims (`int`, *optional*, defaults to 3):
                Forwarded to `make_flat_list_of_images`.

        Returns:
            `ImageInput`: The images with a valid nesting.
        """
        # Checks for `str` in case of URL/local path and optionally loads images
        images = self.fetch_images(images)
        return make_flat_list_of_images(images, expected_ndims=expected_ndims)

    def _process_image(
        self,
        image: ImageInput,
        do_convert_rgb: Optional[bool] = None,
        input_data_format: Optional[Union[str, ChannelDimension]] = None,
        device: Optional["torch.device"] = None,
    ) -> "torch.Tensor":
        """
        Convert a single PIL image, NumPy array or torch tensor into a channel-first torch tensor,
        optionally converting to RGB first and moving the result to `device`.

        Raises:
            `ValueError`: for unsupported input types.
        """
        image_type = get_image_type(image)
        if image_type not in [ImageType.PIL, ImageType.TORCH, ImageType.NUMPY]:
            raise ValueError(f"Unsupported input image type {image_type}")

        if do_convert_rgb:
            image = self.convert_to_rgb(image)

        if image_type == ImageType.PIL:
            image = F.pil_to_tensor(image)
        elif image_type == ImageType.NUMPY:
            # not using F.to_tensor as it doesn't handle (C, H, W) numpy arrays
            image = torch.from_numpy(image).contiguous()

        # If the image is 2D, we need to unsqueeze it to add a channel dimension for processing
        if image.ndim == 2:
            image = image.unsqueeze(0)

        # Infer the channel dimension format if not provided
        if input_data_format is None:
            input_data_format = infer_channel_dimension_format(image)

        if input_data_format == ChannelDimension.LAST:
            # We force the channel dimension to be first for torch tensors as this is what torchvision expects.
            image = image.permute(2, 0, 1).contiguous()

        # Now that we have torch tensors, we can move them to the right device
        if device is not None:
            image = image.to(device)

        return image

    def _prepare_image_like_inputs(
        self,
        images: ImageInput,
        do_convert_rgb: Optional[bool] = None,
        input_data_format: Optional[Union[str, ChannelDimension]] = None,
        device: Optional["torch.device"] = None,
        expected_ndims: int = 3,
    ) -> list["torch.Tensor"]:
        """
        Prepare image-like inputs for processing.

        Args:
            images (`ImageInput`):
                The image-like inputs to process.
            do_convert_rgb (`bool`, *optional*):
                Whether to convert the images to RGB.
            input_data_format (`str` or `ChannelDimension`, *optional*):
                The input data format of the images.
            device (`torch.device`, *optional*):
                The device to put the processed images on.
            expected_ndims (`int`, *optional*):
                The expected number of dimensions for the images. (can be 2 for segmentation maps etc.)

        Returns:
            List[`torch.Tensor`]: The processed images (a list of lists when the input structure is nested).
        """

        # Get structured images (potentially nested)
        images = self._prepare_images_structure(images, expected_ndims=expected_ndims)

        process_image_partial = partial(
            self._process_image, do_convert_rgb=do_convert_rgb, input_data_format=input_data_format, device=device
        )

        # Check if we have nested structure, assuming the nesting is consistent
        has_nested_structure = len(images) > 0 and isinstance(images[0], (list, tuple))

        if has_nested_structure:
            processed_images = [[process_image_partial(img) for img in nested_list] for nested_list in images]
        else:
            processed_images = [process_image_partial(img) for img in images]

        return processed_images

    def _further_process_kwargs(
        self,
        size: Optional[SizeDict] = None,
        crop_size: Optional[SizeDict] = None,
        pad_size: Optional[SizeDict] = None,
        default_to_square: Optional[bool] = None,
        image_mean: Optional[Union[float, list[float]]] = None,
        image_std: Optional[Union[float, list[float]]] = None,
        data_format: Optional[ChannelDimension] = None,
        **kwargs,
    ) -> dict:
        """
        Update kwargs that need further processing before being validated
        Can be overridden by subclasses to customize the processing of kwargs.

        Size-like kwargs become `SizeDict`s, list means/stds become tuples (so they are hashable for the
        cached validation/fusion helpers), `data_format` defaults to channel-first, and `resample` is
        replaced by the torchvision `interpolation` equivalent.
        """
        # `default_to_square` is consumed here and not forwarded.
        if size is not None:
            size = SizeDict(**get_size_dict(size=size, default_to_square=default_to_square))
        if crop_size is not None:
            crop_size = SizeDict(**get_size_dict(crop_size, param_name="crop_size"))
        if pad_size is not None:
            pad_size = SizeDict(**get_size_dict(size=pad_size, param_name="pad_size"))
        if isinstance(image_mean, list):
            image_mean = tuple(image_mean)
        if isinstance(image_std, list):
            image_std = tuple(image_std)
        if data_format is None:
            data_format = ChannelDimension.FIRST

        kwargs["size"] = size
        kwargs["crop_size"] = crop_size
        kwargs["pad_size"] = pad_size
        kwargs["image_mean"] = image_mean
        kwargs["image_std"] = image_std
        kwargs["data_format"] = data_format

        # torch resize uses interpolation instead of resample
        # Check if resample is an int before checking if it's an instance of PILImageResampling
        # because if pillow < 9.1.0, resample is an int and PILImageResampling is a module.
        # Checking PILImageResampling will fail with error `TypeError: isinstance() arg 2 must be a type or tuple of types`.
        # `isinstance(..., int)` is evaluated first (and short-circuits), and PILImageResampling is only
        # used as an isinstance target when it actually is a type.
        resample = kwargs.pop("resample", None)
        is_pil_resample = isinstance(resample, int) or (
            isinstance(PILImageResampling, type) and isinstance(resample, PILImageResampling)
        )
        kwargs["interpolation"] = pil_torch_interpolation_mapping[resample] if is_pil_resample else resample

        return kwargs

    def _validate_preprocess_kwargs(
        self,
        do_rescale: Optional[bool] = None,
        rescale_factor: Optional[float] = None,
        do_normalize: Optional[bool] = None,
        image_mean: Optional[Union[float, tuple[float]]] = None,
        image_std: Optional[Union[float, tuple[float]]] = None,
        do_resize: Optional[bool] = None,
        size: Optional[SizeDict] = None,
        do_center_crop: Optional[bool] = None,
        crop_size: Optional[SizeDict] = None,
        interpolation: Optional["F.InterpolationMode"] = None,
        return_tensors: Optional[Union[str, TensorType]] = None,
        data_format: Optional[ChannelDimension] = None,
        **kwargs,
    ):
        """
        validate the kwargs for the preprocess method.

        Delegates to `validate_fast_preprocess_arguments`; extra kwargs are accepted and ignored.
        """
        validate_fast_preprocess_arguments(
            do_rescale=do_rescale,
            rescale_factor=rescale_factor,
            do_normalize=do_normalize,
            image_mean=image_mean,
            image_std=image_std,
            do_resize=do_resize,
            size=size,
            do_center_crop=do_center_crop,
            crop_size=crop_size,
            interpolation=interpolation,
            return_tensors=return_tensors,
            data_format=data_format,
        )

    def __call__(self, images: ImageInput, *args, **kwargs: Unpack[DefaultFastImageProcessorKwargs]) -> BatchFeature:
        return self.preprocess(images, *args, **kwargs)

    @auto_docstring
    def preprocess(self, images: ImageInput, *args, **kwargs: Unpack[DefaultFastImageProcessorKwargs]) -> BatchFeature:
        # args are not validated, but their order in the `preprocess` and `_preprocess` signatures must be the same
        validate_kwargs(captured_kwargs=kwargs.keys(), valid_processor_keys=self._valid_kwargs_names)
        # Set default kwargs from self. This ensures that if a kwarg is not provided
        # by the user, it gets its default value from the instance, or is set to None.
        for kwarg_name in self._valid_kwargs_names:
            kwargs.setdefault(kwarg_name, getattr(self, kwarg_name, None))

        # Extract parameters that are only used for preparing the input images
        do_convert_rgb = kwargs.pop("do_convert_rgb")
        input_data_format = kwargs.pop("input_data_format")
        device = kwargs.pop("device")
        # Update kwargs that need further processing before being validated
        kwargs = self._further_process_kwargs(**kwargs)

        # Validate kwargs
        self._validate_preprocess_kwargs(**kwargs)

        # Pop kwargs that are not needed in _preprocess
        kwargs.pop("data_format")

        return self._preprocess_image_like_inputs(
            images, *args, do_convert_rgb=do_convert_rgb, input_data_format=input_data_format, device=device, **kwargs
        )

    def _preprocess_image_like_inputs(
        self,
        images: ImageInput,
        *args,
        do_convert_rgb: bool,
        input_data_format: ChannelDimension,
        device: Optional[Union[str, "torch.device"]] = None,
        **kwargs: Unpack[DefaultFastImageProcessorKwargs],
    ) -> BatchFeature:
        """
        Preprocess image-like inputs.
        To be overridden by subclasses when image-like inputs other than images should be processed.
        It can be used for segmentation maps, depth maps, etc.
        """
        # Prepare input images
        images = self._prepare_image_like_inputs(
            images=images, do_convert_rgb=do_convert_rgb, input_data_format=input_data_format, device=device
        )
        return self._preprocess(images, *args, **kwargs)

    def _preprocess(
        self,
        images: list["torch.Tensor"],
        do_resize: bool,
        size: SizeDict,
        interpolation: Optional["F.InterpolationMode"],
        do_center_crop: bool,
        crop_size: SizeDict,
        do_rescale: bool,
        rescale_factor: float,
        do_normalize: bool,
        image_mean: Optional[Union[float, list[float]]],
        image_std: Optional[Union[float, list[float]]],
        do_pad: Optional[bool],
        pad_size: Optional[SizeDict],
        disable_grouping: Optional[bool],
        return_tensors: Optional[Union[str, TensorType]],
        **kwargs,
    ) -> BatchFeature:
        """
        Run resize -> center crop -> rescale/normalize -> pad on a list of channel-first tensors and wrap
        the result in a `BatchFeature` under `"pixel_values"`.
        """
        # Group images by size for batched resizing
        grouped_images, grouped_images_index = group_images_by_shape(images, disable_grouping=disable_grouping)
        resized_images_grouped = {}
        for shape, stacked_images in grouped_images.items():
            if do_resize:
                stacked_images = self.resize(image=stacked_images, size=size, interpolation=interpolation)
            resized_images_grouped[shape] = stacked_images
        resized_images = reorder_images(resized_images_grouped, grouped_images_index)

        # Group images by size for further processing
        # Needed in case do_resize is False, or resize returns images with different sizes
        grouped_images, grouped_images_index = group_images_by_shape(resized_images, disable_grouping=disable_grouping)
        processed_images_grouped = {}
        for shape, stacked_images in grouped_images.items():
            if do_center_crop:
                stacked_images = self.center_crop(stacked_images, crop_size)
            # Fused rescale and normalize
            stacked_images = self.rescale_and_normalize(
                stacked_images, do_rescale, rescale_factor, do_normalize, image_mean, image_std
            )
            processed_images_grouped[shape] = stacked_images
        processed_images = reorder_images(processed_images_grouped, grouped_images_index)

        if do_pad:
            processed_images = self.pad(processed_images, pad_size=pad_size, disable_grouping=disable_grouping)

        processed_images = torch.stack(processed_images, dim=0) if return_tensors else processed_images
        return BatchFeature(data={"pixel_values": processed_images}, tensor_type=return_tensors)

    def to_dict(self):
        """Serialize the processor, dropping private bookkeeping attributes."""
        encoder_dict = super().to_dict()
        encoder_dict.pop("_valid_processor_keys", None)
        encoder_dict.pop("_valid_kwargs_names", None)
        return encoder_dict
|