modular_owlv2.py 8.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228
  1. # coding=utf-8
  2. # Copyright 2025 The HuggingFace Inc. team. All rights reserved.
  3. #
  4. # Licensed under the Apache License, Version 2.0 (the "License");
  5. # you may not use this file except in compliance with the License.
  6. # You may obtain a copy of the License at
  7. #
  8. # http://www.apache.org/licenses/LICENSE-2.0
  9. #
  10. # Unless required by applicable law or agreed to in writing, software
  11. # distributed under the License is distributed on an "AS IS" BASIS,
  12. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. # See the License for the specific language governing permissions and
  14. # limitations under the License.
  15. """Fast Image processor class for OWLv2."""
  16. import warnings
  17. from typing import Optional, Union
  18. import torch
  19. from torchvision.transforms.v2 import functional as F
  20. from ...image_processing_utils_fast import (
  21. BaseImageProcessorFast,
  22. BatchFeature,
  23. DefaultFastImageProcessorKwargs,
  24. )
  25. from ...image_transforms import group_images_by_shape, reorder_images
  26. from ...image_utils import (
  27. OPENAI_CLIP_MEAN,
  28. OPENAI_CLIP_STD,
  29. ChannelDimension,
  30. ImageInput,
  31. PILImageResampling,
  32. SizeDict,
  33. )
  34. from ...processing_utils import Unpack
  35. from ...utils import (
  36. TensorType,
  37. auto_docstring,
  38. )
  39. from ..owlvit.image_processing_owlvit_fast import OwlViTImageProcessorFast
  40. class Owlv2FastImageProcessorKwargs(DefaultFastImageProcessorKwargs): ...
  41. @auto_docstring
  42. class Owlv2ImageProcessorFast(OwlViTImageProcessorFast):
  43. resample = PILImageResampling.BILINEAR
  44. image_mean = OPENAI_CLIP_MEAN
  45. image_std = OPENAI_CLIP_STD
  46. size = {"height": 960, "width": 960}
  47. rescale_factor = 1 / 255
  48. do_resize = True
  49. do_rescale = True
  50. do_normalize = True
  51. do_pad = True
  52. valid_kwargs = Owlv2FastImageProcessorKwargs
  53. crop_size = None
  54. do_center_crop = None
  55. def __init__(self, **kwargs: Unpack[Owlv2FastImageProcessorKwargs]):
  56. BaseImageProcessorFast.__init__(self, **kwargs)
  57. @auto_docstring
  58. def preprocess(self, images: ImageInput, **kwargs: Unpack[Owlv2FastImageProcessorKwargs]):
  59. return BaseImageProcessorFast.preprocess(self, images, **kwargs)
  60. def _pad_images(self, images: "torch.Tensor", constant_value: float = 0.5) -> "torch.Tensor":
  61. """
  62. Pad an image with zeros to the given size.
  63. """
  64. height, width = images.shape[-2:]
  65. size = max(height, width)
  66. pad_bottom = size - height
  67. pad_right = size - width
  68. padding = (0, 0, pad_right, pad_bottom)
  69. padded_image = F.pad(images, padding, fill=constant_value)
  70. return padded_image
  71. def pad(
  72. self,
  73. images: list["torch.Tensor"],
  74. disable_grouping: Optional[bool],
  75. constant_value: float = 0.5,
  76. **kwargs,
  77. ) -> list["torch.Tensor"]:
  78. """
  79. Unlike the Base class `self.pad` where all images are padded to the maximum image size,
  80. Owlv2 pads an image to square.
  81. """
  82. grouped_images, grouped_images_index = group_images_by_shape(images, disable_grouping=disable_grouping)
  83. processed_images_grouped = {}
  84. for shape, stacked_images in grouped_images.items():
  85. stacked_images = self._pad_images(
  86. stacked_images,
  87. constant_value=constant_value,
  88. )
  89. processed_images_grouped[shape] = stacked_images
  90. processed_images = reorder_images(processed_images_grouped, grouped_images_index)
  91. return processed_images
  92. def resize(
  93. self,
  94. image: "torch.Tensor",
  95. size: SizeDict,
  96. anti_aliasing: bool = True,
  97. anti_aliasing_sigma=None,
  98. **kwargs,
  99. ) -> "torch.Tensor":
  100. """
  101. Resize an image as per the original implementation.
  102. Args:
  103. image (`Tensor`):
  104. Image to resize.
  105. size (`dict[str, int]`):
  106. Dictionary containing the height and width to resize the image to.
  107. anti_aliasing (`bool`, *optional*, defaults to `True`):
  108. Whether to apply anti-aliasing when downsampling the image.
  109. anti_aliasing_sigma (`float`, *optional*, defaults to `None`):
  110. Standard deviation for Gaussian kernel when downsampling the image. If `None`, it will be calculated
  111. automatically.
  112. """
  113. output_shape = (size.height, size.width)
  114. input_shape = image.shape
  115. # select height and width from input tensor
  116. factors = torch.tensor(input_shape[2:]).to(image.device) / torch.tensor(output_shape).to(image.device)
  117. if anti_aliasing:
  118. if anti_aliasing_sigma is None:
  119. anti_aliasing_sigma = ((factors - 1) / 2).clamp(min=0)
  120. else:
  121. anti_aliasing_sigma = torch.atleast_1d(anti_aliasing_sigma) * torch.ones_like(factors)
  122. if torch.any(anti_aliasing_sigma < 0):
  123. raise ValueError("Anti-aliasing standard deviation must be greater than or equal to zero")
  124. elif torch.any((anti_aliasing_sigma > 0) & (factors <= 1)):
  125. warnings.warn(
  126. "Anti-aliasing standard deviation greater than zero but not down-sampling along all axes"
  127. )
  128. if torch.any(anti_aliasing_sigma == 0):
  129. filtered = image
  130. else:
  131. kernel_sizes = 2 * torch.ceil(3 * anti_aliasing_sigma).int() + 1
  132. filtered = F.gaussian_blur(
  133. image, (kernel_sizes[0], kernel_sizes[1]), sigma=anti_aliasing_sigma.tolist()
  134. )
  135. else:
  136. filtered = image
  137. out = F.resize(filtered, size=(size.height, size.width), antialias=False)
  138. return out
  139. def _preprocess(
  140. self,
  141. images: list["torch.Tensor"],
  142. do_resize: bool,
  143. size: SizeDict,
  144. interpolation: Optional["F.InterpolationMode"],
  145. do_pad: bool,
  146. do_rescale: bool,
  147. rescale_factor: float,
  148. do_normalize: bool,
  149. image_mean: Optional[Union[float, list[float]]],
  150. image_std: Optional[Union[float, list[float]]],
  151. disable_grouping: Optional[bool],
  152. return_tensors: Optional[Union[str, TensorType]],
  153. **kwargs,
  154. ) -> BatchFeature:
  155. # Group images by size for batched resizing
  156. grouped_images, grouped_images_index = group_images_by_shape(images, disable_grouping=disable_grouping)
  157. processed_images_grouped = {}
  158. for shape, stacked_images in grouped_images.items():
  159. # Rescale images before other operations as done in original implementation
  160. stacked_images = self.rescale_and_normalize(
  161. stacked_images, do_rescale, rescale_factor, False, image_mean, image_std
  162. )
  163. processed_images_grouped[shape] = stacked_images
  164. processed_images = reorder_images(processed_images_grouped, grouped_images_index)
  165. if do_pad:
  166. processed_images = self.pad(processed_images, constant_value=0.5, disable_grouping=disable_grouping)
  167. grouped_images, grouped_images_index = group_images_by_shape(
  168. processed_images, disable_grouping=disable_grouping
  169. )
  170. resized_images_grouped = {}
  171. for shape, stacked_images in grouped_images.items():
  172. if do_resize:
  173. resized_stack = self.resize(
  174. image=stacked_images,
  175. size=size,
  176. interpolation=interpolation,
  177. input_data_format=ChannelDimension.FIRST,
  178. )
  179. resized_images_grouped[shape] = resized_stack
  180. resized_images = reorder_images(resized_images_grouped, grouped_images_index)
  181. # Group images by size for further processing
  182. # Needed in case do_resize is False, or resize returns images with different sizes
  183. grouped_images, grouped_images_index = group_images_by_shape(resized_images, disable_grouping=disable_grouping)
  184. processed_images_grouped = {}
  185. for shape, stacked_images in grouped_images.items():
  186. # Fused rescale and normalize
  187. stacked_images = self.rescale_and_normalize(
  188. stacked_images, False, rescale_factor, do_normalize, image_mean, image_std
  189. )
  190. processed_images_grouped[shape] = stacked_images
  191. processed_images = reorder_images(processed_images_grouped, grouped_images_index)
  192. processed_images = torch.stack(processed_images, dim=0) if return_tensors else processed_images
  193. return BatchFeature(data={"pixel_values": processed_images}, tensor_type=return_tensors)
  194. __all__ = ["Owlv2ImageProcessorFast"]