image_processing_utils.py 13 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317
  1. # Copyright 2022 The HuggingFace Inc. team.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. import math
  15. from collections.abc import Iterable
  16. from typing import Optional, Union
  17. import numpy as np
  18. from .image_processing_base import BatchFeature, ImageProcessingMixin
  19. from .image_transforms import center_crop, normalize, rescale
  20. from .image_utils import ChannelDimension, get_image_size
  21. from .utils import logging
  22. from .utils.import_utils import requires
  23. logger = logging.get_logger(__name__)
  24. INIT_SERVICE_KWARGS = [
  25. "processor_class",
  26. "image_processor_type",
  27. ]
  28. @requires(backends=("vision",))
  29. class BaseImageProcessor(ImageProcessingMixin):
  30. def __init__(self, **kwargs):
  31. super().__init__(**kwargs)
  32. @property
  33. def is_fast(self) -> bool:
  34. """
  35. `bool`: Whether or not this image processor is a fast processor (backed by PyTorch and TorchVision).
  36. """
  37. return False
  38. def __call__(self, images, **kwargs) -> BatchFeature:
  39. """Preprocess an image or a batch of images."""
  40. return self.preprocess(images, **kwargs)
  41. def preprocess(self, images, **kwargs) -> BatchFeature:
  42. raise NotImplementedError("Each image processor must implement its own preprocess method")
  43. def rescale(
  44. self,
  45. image: np.ndarray,
  46. scale: float,
  47. data_format: Optional[Union[str, ChannelDimension]] = None,
  48. input_data_format: Optional[Union[str, ChannelDimension]] = None,
  49. **kwargs,
  50. ) -> np.ndarray:
  51. """
  52. Rescale an image by a scale factor. image = image * scale.
  53. Args:
  54. image (`np.ndarray`):
  55. Image to rescale.
  56. scale (`float`):
  57. The scaling factor to rescale pixel values by.
  58. data_format (`str` or `ChannelDimension`, *optional*):
  59. The channel dimension format for the output image. If unset, the channel dimension format of the input
  60. image is used. Can be one of:
  61. - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
  62. - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
  63. input_data_format (`ChannelDimension` or `str`, *optional*):
  64. The channel dimension format for the input image. If unset, the channel dimension format is inferred
  65. from the input image. Can be one of:
  66. - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
  67. - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
  68. Returns:
  69. `np.ndarray`: The rescaled image.
  70. """
  71. return rescale(image, scale=scale, data_format=data_format, input_data_format=input_data_format, **kwargs)
  72. def normalize(
  73. self,
  74. image: np.ndarray,
  75. mean: Union[float, Iterable[float]],
  76. std: Union[float, Iterable[float]],
  77. data_format: Optional[Union[str, ChannelDimension]] = None,
  78. input_data_format: Optional[Union[str, ChannelDimension]] = None,
  79. **kwargs,
  80. ) -> np.ndarray:
  81. """
  82. Normalize an image. image = (image - image_mean) / image_std.
  83. Args:
  84. image (`np.ndarray`):
  85. Image to normalize.
  86. mean (`float` or `Iterable[float]`):
  87. Image mean to use for normalization.
  88. std (`float` or `Iterable[float]`):
  89. Image standard deviation to use for normalization.
  90. data_format (`str` or `ChannelDimension`, *optional*):
  91. The channel dimension format for the output image. If unset, the channel dimension format of the input
  92. image is used. Can be one of:
  93. - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
  94. - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
  95. input_data_format (`ChannelDimension` or `str`, *optional*):
  96. The channel dimension format for the input image. If unset, the channel dimension format is inferred
  97. from the input image. Can be one of:
  98. - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
  99. - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
  100. Returns:
  101. `np.ndarray`: The normalized image.
  102. """
  103. return normalize(
  104. image, mean=mean, std=std, data_format=data_format, input_data_format=input_data_format, **kwargs
  105. )
  106. def center_crop(
  107. self,
  108. image: np.ndarray,
  109. size: dict[str, int],
  110. data_format: Optional[Union[str, ChannelDimension]] = None,
  111. input_data_format: Optional[Union[str, ChannelDimension]] = None,
  112. **kwargs,
  113. ) -> np.ndarray:
  114. """
  115. Center crop an image to `(size["height"], size["width"])`. If the input size is smaller than `crop_size` along
  116. any edge, the image is padded with 0's and then center cropped.
  117. Args:
  118. image (`np.ndarray`):
  119. Image to center crop.
  120. size (`dict[str, int]`):
  121. Size of the output image.
  122. data_format (`str` or `ChannelDimension`, *optional*):
  123. The channel dimension format for the output image. If unset, the channel dimension format of the input
  124. image is used. Can be one of:
  125. - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
  126. - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
  127. input_data_format (`ChannelDimension` or `str`, *optional*):
  128. The channel dimension format for the input image. If unset, the channel dimension format is inferred
  129. from the input image. Can be one of:
  130. - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
  131. - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
  132. """
  133. size = get_size_dict(size)
  134. if "height" not in size or "width" not in size:
  135. raise ValueError(f"The size dictionary must have keys 'height' and 'width'. Got {size.keys()}")
  136. return center_crop(
  137. image,
  138. size=(size["height"], size["width"]),
  139. data_format=data_format,
  140. input_data_format=input_data_format,
  141. **kwargs,
  142. )
  143. def to_dict(self):
  144. encoder_dict = super().to_dict()
  145. encoder_dict.pop("_valid_processor_keys", None)
  146. return encoder_dict
  147. VALID_SIZE_DICT_KEYS = (
  148. {"height", "width"},
  149. {"shortest_edge"},
  150. {"shortest_edge", "longest_edge"},
  151. {"longest_edge"},
  152. {"max_height", "max_width"},
  153. )
  154. def is_valid_size_dict(size_dict):
  155. if not isinstance(size_dict, dict):
  156. return False
  157. size_dict_keys = set(size_dict.keys())
  158. for allowed_keys in VALID_SIZE_DICT_KEYS:
  159. if size_dict_keys == allowed_keys:
  160. return True
  161. return False
  162. def convert_to_size_dict(
  163. size, max_size: Optional[int] = None, default_to_square: bool = True, height_width_order: bool = True
  164. ):
  165. # By default, if size is an int we assume it represents a tuple of (size, size).
  166. if isinstance(size, int) and default_to_square:
  167. if max_size is not None:
  168. raise ValueError("Cannot specify both size as an int, with default_to_square=True and max_size")
  169. return {"height": size, "width": size}
  170. # In other configs, if size is an int and default_to_square is False, size represents the length of
  171. # the shortest edge after resizing.
  172. elif isinstance(size, int) and not default_to_square:
  173. size_dict = {"shortest_edge": size}
  174. if max_size is not None:
  175. size_dict["longest_edge"] = max_size
  176. return size_dict
  177. # Otherwise, if size is a tuple it's either (height, width) or (width, height)
  178. elif isinstance(size, (tuple, list)) and height_width_order:
  179. return {"height": size[0], "width": size[1]}
  180. elif isinstance(size, (tuple, list)) and not height_width_order:
  181. return {"height": size[1], "width": size[0]}
  182. elif size is None and max_size is not None:
  183. if default_to_square:
  184. raise ValueError("Cannot specify both default_to_square=True and max_size")
  185. return {"longest_edge": max_size}
  186. raise ValueError(f"Could not convert size input to size dict: {size}")
  187. def get_size_dict(
  188. size: Optional[Union[int, Iterable[int], dict[str, int]]] = None,
  189. max_size: Optional[int] = None,
  190. height_width_order: bool = True,
  191. default_to_square: bool = True,
  192. param_name="size",
  193. ) -> dict:
  194. """
  195. Converts the old size parameter in the config into the new dict expected in the config. This is to ensure backwards
  196. compatibility with the old image processor configs and removes ambiguity over whether the tuple is in (height,
  197. width) or (width, height) format.
  198. - If `size` is tuple, it is converted to `{"height": size[0], "width": size[1]}` or `{"height": size[1], "width":
  199. size[0]}` if `height_width_order` is `False`.
  200. - If `size` is an int, and `default_to_square` is `True`, it is converted to `{"height": size, "width": size}`.
  201. - If `size` is an int and `default_to_square` is False, it is converted to `{"shortest_edge": size}`. If `max_size`
  202. is set, it is added to the dict as `{"longest_edge": max_size}`.
  203. Args:
  204. size (`Union[int, Iterable[int], dict[str, int]]`, *optional*):
  205. The `size` parameter to be cast into a size dictionary.
  206. max_size (`Optional[int]`, *optional*):
  207. The `max_size` parameter to be cast into a size dictionary.
  208. height_width_order (`bool`, *optional*, defaults to `True`):
  209. If `size` is a tuple, whether it's in (height, width) or (width, height) order.
  210. default_to_square (`bool`, *optional*, defaults to `True`):
  211. If `size` is an int, whether to default to a square image or not.
  212. """
  213. if not isinstance(size, dict):
  214. size_dict = convert_to_size_dict(size, max_size, default_to_square, height_width_order)
  215. logger.info(
  216. f"{param_name} should be a dictionary on of the following set of keys: {VALID_SIZE_DICT_KEYS}, got {size}."
  217. f" Converted to {size_dict}.",
  218. )
  219. else:
  220. size_dict = size
  221. if not is_valid_size_dict(size_dict):
  222. raise ValueError(
  223. f"{param_name} must have one of the following set of keys: {VALID_SIZE_DICT_KEYS}, got {size_dict.keys()}"
  224. )
  225. return size_dict
  226. def select_best_resolution(original_size: tuple, possible_resolutions: list) -> tuple:
  227. """
  228. Selects the best resolution from a list of possible resolutions based on the original size.
  229. This is done by calculating the effective and wasted resolution for each possible resolution.
  230. The best fit resolution is the one that maximizes the effective resolution and minimizes the wasted resolution.
  231. Args:
  232. original_size (tuple):
  233. The original size of the image in the format (height, width).
  234. possible_resolutions (list):
  235. A list of possible resolutions in the format [(height1, width1), (height2, width2), ...].
  236. Returns:
  237. tuple: The best fit resolution in the format (height, width).
  238. """
  239. original_height, original_width = original_size
  240. best_fit = None
  241. max_effective_resolution = 0
  242. min_wasted_resolution = float("inf")
  243. for height, width in possible_resolutions:
  244. scale = min(width / original_width, height / original_height)
  245. downscaled_width, downscaled_height = int(original_width * scale), int(original_height * scale)
  246. effective_resolution = min(downscaled_width * downscaled_height, original_width * original_height)
  247. wasted_resolution = (width * height) - effective_resolution
  248. if effective_resolution > max_effective_resolution or (
  249. effective_resolution == max_effective_resolution and wasted_resolution < min_wasted_resolution
  250. ):
  251. max_effective_resolution = effective_resolution
  252. min_wasted_resolution = wasted_resolution
  253. best_fit = (height, width)
  254. return best_fit
  255. def get_patch_output_size(image, target_resolution, input_data_format):
  256. """
  257. Given an image and a target resolution, calculate the output size of the image after cropping to the target
  258. """
  259. original_height, original_width = get_image_size(image, channel_dim=input_data_format)
  260. target_height, target_width = target_resolution
  261. scale_w = target_width / original_width
  262. scale_h = target_height / original_height
  263. if scale_w < scale_h:
  264. new_width = target_width
  265. new_height = min(math.ceil(original_height * scale_w), target_height)
  266. else:
  267. new_height = target_height
  268. new_width = min(math.ceil(original_width * scale_h), target_width)
  269. return new_height, new_width