modeling_siglip2.py 47 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156115711581159116011611162116311641165116611671168116911701171117211731174117511761177117811791180118111821183118411851186118711881189119011911192119311941195
  1. # 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
  2. # This file was automatically generated from src/transformers/models/siglip2/modular_siglip2.py.
  3. # Do NOT edit this file manually as any edits will be overwritten by the generation of
  4. # the file from the modular. If any change should be done, please apply the change to the
  5. # modular_siglip2.py file directly. One of our CI enforces this.
  6. # 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
  7. # coding=utf-8
  8. # Copyright 2025 The HuggingFace Inc. team.
  9. #
  10. # Licensed under the Apache License, Version 2.0 (the "License");
  11. # you may not use this file except in compliance with the License.
  12. # You may obtain a copy of the License at
  13. #
  14. # http://www.apache.org/licenses/LICENSE-2.0
  15. #
  16. # Unless required by applicable law or agreed to in writing, software
  17. # distributed under the License is distributed on an "AS IS" BASIS,
  18. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  19. # See the License for the specific language governing permissions and
  20. # limitations under the License.
  21. import math
  22. import warnings
  23. from dataclasses import dataclass
  24. from typing import Any, Callable, Optional, Union
  25. import numpy as np
  26. import torch
  27. import torch.nn as nn
  28. import torch.nn.functional as F
  29. from torch.nn.init import _calculate_fan_in_and_fan_out
  30. from ...activations import ACT2FN
  31. from ...modeling_attn_mask_utils import _prepare_4d_attention_mask
  32. from ...modeling_layers import GradientCheckpointingLayer
  33. from ...modeling_outputs import BaseModelOutput, BaseModelOutputWithPooling, ImageClassifierOutput
  34. from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
  35. from ...processing_utils import Unpack
  36. from ...utils import ModelOutput, TransformersKwargs, auto_docstring, can_return_tuple, filter_out_non_signature_kwargs
  37. from ...utils.generic import check_model_inputs
  38. from .configuration_siglip2 import Siglip2Config, Siglip2TextConfig, Siglip2VisionConfig
@dataclass
@auto_docstring(
    custom_intro="""
    Base class for vision model's outputs that also contains image embeddings of the pooling of the last hidden states.
    """
)
class Siglip2VisionOutput(ModelOutput):
    r"""
    image_embeds (`torch.FloatTensor` of shape `(batch_size, output_dim)`, *optional*, returned when model is initialized with `with_projection=True`):
        The image embeddings obtained by applying the projection layer to the pooler_output.
    """

    # NOTE(review): ModelOutput presumably exposes fields positionally in declaration
    # order -- keep this order stable unless that is confirmed not to matter.
    image_embeds: Optional[torch.FloatTensor] = None
    last_hidden_state: Optional[torch.FloatTensor] = None
    hidden_states: Optional[tuple[torch.FloatTensor, ...]] = None
    attentions: Optional[tuple[torch.FloatTensor, ...]] = None
@dataclass
@auto_docstring(
    custom_intro="""
    Base class for text model's outputs that also contains a pooling of the last hidden states.
    """
)
class Siglip2TextOutput(ModelOutput):
    r"""
    text_embeds (`torch.FloatTensor` of shape `(batch_size, output_dim)`, *optional*, returned when model is initialized with `with_projection=True`):
        The text embeddings obtained by applying the projection layer to the pooler_output.
    """

    # NOTE(review): ModelOutput presumably exposes fields positionally in declaration
    # order -- keep this order stable unless that is confirmed not to matter.
    text_embeds: Optional[torch.FloatTensor] = None
    last_hidden_state: Optional[torch.FloatTensor] = None
    hidden_states: Optional[tuple[torch.FloatTensor, ...]] = None
    attentions: Optional[tuple[torch.FloatTensor, ...]] = None
  69. @dataclass
  70. @auto_docstring
  71. class Siglip2Output(ModelOutput):
  72. r"""
  73. loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `return_loss` is `True`):
  74. Contrastive loss for image-text similarity.
  75. logits_per_image (`torch.FloatTensor` of shape `(image_batch_size, text_batch_size)`):
  76. The scaled dot product scores between `image_embeds` and `text_embeds`. This represents the image-text
  77. similarity scores.
  78. logits_per_text (`torch.FloatTensor` of shape `(text_batch_size, image_batch_size)`):
  79. The scaled dot product scores between `text_embeds` and `image_embeds`. This represents the text-image
  80. similarity scores.
  81. text_embeds (`torch.FloatTensor` of shape `(batch_size, output_dim`):
  82. The text embeddings obtained by applying the projection layer to the pooled output of [`Siglip2TextModel`].
  83. image_embeds (`torch.FloatTensor` of shape `(batch_size, output_dim`):
  84. The image embeddings obtained by applying the projection layer to the pooled output of [`Siglip2VisionModel`].
  85. text_model_output (`BaseModelOutputWithPooling`):
  86. The output of the [`Siglip2TextModel`].
  87. vision_model_output (`BaseModelOutputWithPooling`):
  88. The output of the [`Siglip2VisionModel`].
  89. """
  90. loss: Optional[torch.FloatTensor] = None
  91. logits_per_image: Optional[torch.FloatTensor] = None
  92. logits_per_text: Optional[torch.FloatTensor] = None
  93. text_embeds: Optional[torch.FloatTensor] = None
  94. image_embeds: Optional[torch.FloatTensor] = None
  95. text_model_output: BaseModelOutputWithPooling = None
  96. vision_model_output: BaseModelOutputWithPooling = None
  97. def to_tuple(self) -> tuple[Any]:
  98. return tuple(
  99. self[k] if k not in ["text_model_output", "vision_model_output"] else getattr(self, k).to_tuple()
  100. for k in self.keys()
  101. )
  102. class Siglip2VisionEmbeddings(nn.Module):
  103. def __init__(self, config: Siglip2VisionConfig):
  104. super().__init__()
  105. self.config = config
  106. self.embed_dim = config.hidden_size
  107. self.patch_size = config.patch_size
  108. self.patch_embedding = nn.Linear(
  109. in_features=config.num_channels * self.patch_size * self.patch_size,
  110. out_features=self.embed_dim,
  111. )
  112. self.num_patches = config.num_patches
  113. self.position_embedding_size = int(self.num_patches**0.5)
  114. self.position_embedding = nn.Embedding(self.num_patches, self.embed_dim)
  115. @staticmethod
  116. def resize_positional_embeddings(
  117. positional_embeddings: torch.Tensor,
  118. spatial_shapes: torch.LongTensor,
  119. max_length: int,
  120. ) -> torch.Tensor:
  121. """
  122. Resize positional embeddings to image-specific size and pad to a fixed size.
  123. Args:
  124. positional_embeddings (`torch.Tensor`):
  125. Position embeddings of shape (height, width, embed_dim)
  126. spatial_shapes (`torch.LongTensor`):
  127. Spatial shapes of shape (batch_size, 2) to resize the positional embeddings to
  128. max_length (`int`):
  129. Maximum length of the positional embeddings to pad resized positional embeddings to
  130. Returns:
  131. `torch.Tensor`: Embeddings of shape (batch_size, max_length, embed_dim)
  132. """
  133. batch_size = spatial_shapes.shape[0]
  134. embed_dim = positional_embeddings.shape[-1]
  135. source_dtype = positional_embeddings.dtype
  136. resulted_positional_embeddings = torch.empty(
  137. (batch_size, max_length, embed_dim),
  138. device=positional_embeddings.device,
  139. dtype=source_dtype,
  140. )
  141. # (height, width, embed_dim) -> (1, embed_dim, height, width) for interpolation
  142. positional_embeddings = positional_embeddings.permute(2, 0, 1).unsqueeze(0)
  143. # Upcast to float32 on CPU because antialias is not supported for bfloat16/float16 on CPU
  144. if positional_embeddings.device.type == "cpu":
  145. positional_embeddings = positional_embeddings.to(torch.float32)
  146. for i in range(batch_size):
  147. # (1, dim, height, width) -> (1, dim, target_height, target_width)
  148. height, width = spatial_shapes[i]
  149. resized_embeddings = F.interpolate(
  150. positional_embeddings,
  151. size=(height, width),
  152. mode="bilinear",
  153. align_corners=False,
  154. antialias=True,
  155. )
  156. # (1, dim, target_height, target_width) -> (target_height * target_width, dim)
  157. resized_embeddings = resized_embeddings.reshape(embed_dim, height * width).transpose(0, 1)
  158. # Cast to original dtype
  159. resized_embeddings = resized_embeddings.to(source_dtype)
  160. resulted_positional_embeddings[i, : height * width] = resized_embeddings
  161. resulted_positional_embeddings[i, height * width :] = resized_embeddings[0]
  162. return resulted_positional_embeddings
  163. def forward(self, pixel_values: torch.FloatTensor, spatial_shapes: torch.LongTensor) -> torch.Tensor:
  164. """
  165. Args:
  166. pixel_values (`torch.FloatTensor`):
  167. Pixel values of shape (batch_size, max_num_patches, num_channels * patch_size * patch_size)
  168. spatial_shapes (`list[tuple[int, int]]`):
  169. Spatial shapes of shape (batch_size, 2) to resize the positional embeddings to
  170. """
  171. # Apply patch embeddings to already patchified pixel values
  172. target_dtype = self.patch_embedding.weight.dtype
  173. patch_embeds = self.patch_embedding(pixel_values.to(dtype=target_dtype))
  174. # Get positional resized and padded positional embeddings
  175. positional_embeddings = self.position_embedding.weight.reshape(
  176. self.position_embedding_size, self.position_embedding_size, -1
  177. )
  178. resized_positional_embeddings = self.resize_positional_embeddings(
  179. positional_embeddings, spatial_shapes, max_length=pixel_values.shape[1]
  180. )
  181. # Add positional embeddings to patch embeddings
  182. embeddings = patch_embeds + resized_positional_embeddings
  183. return embeddings
  184. def eager_attention_forward(
  185. module: nn.Module,
  186. query: torch.Tensor,
  187. key: torch.Tensor,
  188. value: torch.Tensor,
  189. attention_mask: Optional[torch.Tensor],
  190. scaling: float,
  191. dropout: float = 0.0,
  192. **kwargs,
  193. ):
  194. attn_weights = torch.matmul(query, key.transpose(-1, -2)) * scaling
  195. if attention_mask is not None:
  196. attn_weights = attn_weights + attention_mask
  197. attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype)
  198. attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training)
  199. attn_output = torch.matmul(attn_weights, value)
  200. attn_output = attn_output.transpose(1, 2).contiguous()
  201. return attn_output, attn_weights
  202. class Siglip2Attention(nn.Module):
  203. """Multi-headed attention from 'Attention Is All You Need' paper"""
  204. def __init__(self, config):
  205. super().__init__()
  206. self.config = config
  207. self.embed_dim = config.hidden_size
  208. self.num_heads = config.num_attention_heads
  209. self.head_dim = self.embed_dim // self.num_heads
  210. if self.head_dim * self.num_heads != self.embed_dim:
  211. raise ValueError(
  212. f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim} and `num_heads`:"
  213. f" {self.num_heads})."
  214. )
  215. self.scale = self.head_dim**-0.5
  216. self.dropout = config.attention_dropout
  217. self.is_causal = False
  218. self.k_proj = nn.Linear(self.embed_dim, self.embed_dim)
  219. self.v_proj = nn.Linear(self.embed_dim, self.embed_dim)
  220. self.q_proj = nn.Linear(self.embed_dim, self.embed_dim)
  221. self.out_proj = nn.Linear(self.embed_dim, self.embed_dim)
  222. def forward(
  223. self,
  224. hidden_states: torch.Tensor,
  225. attention_mask: Optional[torch.Tensor] = None,
  226. **kwargs,
  227. ) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
  228. """Input shape: Batch x Time x Channel"""
  229. batch_size, seq_length, embed_dim = hidden_states.shape
  230. queries = self.q_proj(hidden_states)
  231. keys = self.k_proj(hidden_states)
  232. values = self.v_proj(hidden_states)
  233. queries = queries.view(batch_size, seq_length, self.num_heads, self.head_dim).transpose(1, 2)
  234. keys = keys.view(batch_size, seq_length, self.num_heads, self.head_dim).transpose(1, 2)
  235. values = values.view(batch_size, seq_length, self.num_heads, self.head_dim).transpose(1, 2)
  236. attention_interface: Callable = eager_attention_forward
  237. if self.config._attn_implementation != "eager":
  238. attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]
  239. attn_output, attn_weights = attention_interface(
  240. self,
  241. queries,
  242. keys,
  243. values,
  244. attention_mask,
  245. is_causal=self.is_causal,
  246. scaling=self.scale,
  247. dropout=0.0 if not self.training else self.dropout,
  248. )
  249. attn_output = attn_output.reshape(batch_size, seq_length, embed_dim).contiguous()
  250. attn_output = self.out_proj(attn_output)
  251. return attn_output, attn_weights
  252. class Siglip2MLP(nn.Module):
  253. def __init__(self, config):
  254. super().__init__()
  255. self.config = config
  256. self.activation_fn = ACT2FN[config.hidden_act]
  257. self.fc1 = nn.Linear(config.hidden_size, config.intermediate_size)
  258. self.fc2 = nn.Linear(config.intermediate_size, config.hidden_size)
  259. def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
  260. hidden_states = self.fc1(hidden_states)
  261. hidden_states = self.activation_fn(hidden_states)
  262. hidden_states = self.fc2(hidden_states)
  263. return hidden_states
  264. class Siglip2EncoderLayer(GradientCheckpointingLayer):
  265. def __init__(self, config: Union[Siglip2VisionConfig, Siglip2TextConfig]):
  266. super().__init__()
  267. self.embed_dim = config.hidden_size
  268. self.layer_norm1 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps)
  269. self.self_attn = Siglip2Attention(config)
  270. self.layer_norm2 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps)
  271. self.mlp = Siglip2MLP(config)
  272. @auto_docstring
  273. def forward(
  274. self,
  275. hidden_states: torch.Tensor,
  276. attention_mask: torch.Tensor,
  277. **kwargs: Unpack[TransformersKwargs],
  278. ) -> torch.FloatTensor:
  279. residual = hidden_states
  280. hidden_states = self.layer_norm1(hidden_states)
  281. hidden_states, _ = self.self_attn(
  282. hidden_states=hidden_states,
  283. attention_mask=attention_mask,
  284. **kwargs,
  285. )
  286. hidden_states = residual + hidden_states
  287. residual = hidden_states
  288. hidden_states = self.layer_norm2(hidden_states)
  289. hidden_states = self.mlp(hidden_states)
  290. hidden_states = residual + hidden_states
  291. return hidden_states
  292. class Siglip2Encoder(nn.Module):
  293. """
  294. Transformer encoder consisting of `config.num_hidden_layers` self attention layers. Each layer is a
  295. [`Siglip2EncoderLayer`].
  296. Args:
  297. config: Siglip2Config
  298. """
  299. def __init__(self, config: Siglip2Config):
  300. super().__init__()
  301. self.config = config
  302. self.layers = nn.ModuleList([Siglip2EncoderLayer(config) for _ in range(config.num_hidden_layers)])
  303. self.gradient_checkpointing = False
  304. # Ignore copy
  305. @auto_docstring
  306. def forward(
  307. self,
  308. inputs_embeds,
  309. attention_mask: Optional[torch.Tensor] = None,
  310. **kwargs: Unpack[TransformersKwargs],
  311. ) -> BaseModelOutput:
  312. hidden_states = inputs_embeds
  313. for encoder_layer in self.layers:
  314. hidden_states = encoder_layer(
  315. hidden_states,
  316. attention_mask,
  317. **kwargs,
  318. )
  319. return BaseModelOutput(last_hidden_state=hidden_states)
  320. class Siglip2VisionTransformer(nn.Module):
  321. def __init__(self, config: Siglip2VisionConfig):
  322. super().__init__()
  323. self.config = config
  324. embed_dim = config.hidden_size
  325. self.embeddings = Siglip2VisionEmbeddings(config)
  326. self.encoder = Siglip2Encoder(config)
  327. self.post_layernorm = nn.LayerNorm(embed_dim, eps=config.layer_norm_eps)
  328. self.use_head = True if not hasattr(config, "vision_use_head") else config.vision_use_head
  329. if self.use_head:
  330. self.head = Siglip2MultiheadAttentionPoolingHead(config)
  331. @auto_docstring
  332. def forward(
  333. self,
  334. pixel_values: torch.FloatTensor,
  335. attention_mask: torch.Tensor,
  336. spatial_shapes: torch.LongTensor,
  337. output_attentions: Optional[bool] = None,
  338. output_hidden_states: Optional[bool] = None,
  339. ) -> BaseModelOutputWithPooling:
  340. r"""
  341. spatial_shapes (`torch.LongTensor` of shape `(batch_size, 2)`):
  342. Tensor containing the spatial dimensions (height, width) of the input images.
  343. """
  344. output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
  345. output_hidden_states = (
  346. output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
  347. )
  348. hidden_states = self.embeddings(pixel_values, spatial_shapes)
  349. if attention_mask is not None and self.config._attn_implementation != "flash_attention_2":
  350. # [batch_size, seq_len] -> [batch_size, 1, tgt_seq_len, src_seq_len]
  351. encoder_attention_mask = _prepare_4d_attention_mask(attention_mask, hidden_states.dtype)
  352. else:
  353. encoder_attention_mask = attention_mask
  354. encoder_outputs: BaseModelOutput = self.encoder(
  355. inputs_embeds=hidden_states,
  356. attention_mask=encoder_attention_mask,
  357. output_attentions=output_attentions,
  358. output_hidden_states=output_hidden_states,
  359. )
  360. last_hidden_state = encoder_outputs.last_hidden_state
  361. last_hidden_state = self.post_layernorm(last_hidden_state)
  362. pooler_output = self.head(last_hidden_state, attention_mask) if self.use_head else None
  363. return BaseModelOutputWithPooling(
  364. last_hidden_state=last_hidden_state,
  365. pooler_output=pooler_output,
  366. hidden_states=encoder_outputs.hidden_states,
  367. attentions=encoder_outputs.attentions,
  368. )
  369. def _trunc_normal_(tensor, mean, std, a, b):
  370. # Cut & paste from PyTorch official master until it's in a few official releases - RW
  371. # Method based on https://people.sc.fsu.edu/~jburkardt/presentations/truncated_normal.pdf
  372. def norm_cdf(x):
  373. # Computes standard normal cumulative distribution function
  374. return (1.0 + math.erf(x / math.sqrt(2.0))) / 2.0
  375. if (mean < a - 2 * std) or (mean > b + 2 * std):
  376. warnings.warn(
  377. "mean is more than 2 std from [a, b] in nn.init.trunc_normal_. "
  378. "The distribution of values may be incorrect.",
  379. stacklevel=2,
  380. )
  381. # Values are generated by using a truncated uniform distribution and
  382. # then using the inverse CDF for the normal distribution.
  383. # Get upper and lower cdf values
  384. l = norm_cdf((a - mean) / std)
  385. u = norm_cdf((b - mean) / std)
  386. # Uniformly fill tensor with values from [l, u], then translate to
  387. # [2l-1, 2u-1].
  388. tensor.uniform_(2 * l - 1, 2 * u - 1)
  389. # Use inverse cdf transform for normal distribution to get truncated
  390. # standard normal
  391. tensor.erfinv_()
  392. # Transform to proper mean, std
  393. tensor.mul_(std * math.sqrt(2.0))
  394. tensor.add_(mean)
  395. # Clamp to ensure it's in the proper range
  396. tensor.clamp_(min=a, max=b)
  397. def trunc_normal_tf_(
  398. tensor: torch.Tensor, mean: float = 0.0, std: float = 1.0, a: float = -2.0, b: float = 2.0
  399. ) -> torch.Tensor:
  400. """Fills the input Tensor with values drawn from a truncated
  401. normal distribution. The values are effectively drawn from the
  402. normal distribution :math:`\\mathcal{N}(\text{mean}, \text{std}^2)`
  403. with values outside :math:`[a, b]` redrawn until they are within
  404. the bounds. The method used for generating the random values works
  405. best when :math:`a \\leq \text{mean} \\leq b`.
  406. NOTE: this 'tf' variant behaves closer to Tensorflow / JAX impl where the
  407. bounds [a, b] are applied when sampling the normal distribution with mean=0, std=1.0
  408. and the result is subsequently scaled and shifted by the mean and std args.
  409. Args:
  410. tensor: an n-dimensional `torch.Tensor`
  411. mean: the mean of the normal distribution
  412. std: the standard deviation of the normal distribution
  413. a: the minimum cutoff value
  414. b: the maximum cutoff value
  415. """
  416. with torch.no_grad():
  417. _trunc_normal_(tensor, 0, 1.0, a, b)
  418. tensor.mul_(std).add_(mean)
  419. def variance_scaling_(tensor, scale=1.0, mode="fan_in", distribution="normal"):
  420. fan_in, fan_out = _calculate_fan_in_and_fan_out(tensor)
  421. if mode == "fan_in":
  422. denom = fan_in
  423. elif mode == "fan_out":
  424. denom = fan_out
  425. elif mode == "fan_avg":
  426. denom = (fan_in + fan_out) / 2
  427. variance = scale / denom
  428. if distribution == "truncated_normal":
  429. # constant is stddev of standard normal truncated to (-2, 2)
  430. trunc_normal_tf_(tensor, std=math.sqrt(variance) / 0.87962566103423978)
  431. elif distribution == "normal":
  432. with torch.no_grad():
  433. tensor.normal_(std=math.sqrt(variance))
  434. elif distribution == "uniform":
  435. bound = math.sqrt(3 * variance)
  436. with torch.no_grad():
  437. tensor.uniform_(-bound, bound)
  438. else:
  439. raise ValueError(f"invalid distribution {distribution}")
def lecun_normal_(tensor):
    """Initialize `tensor` in place via `variance_scaling_` with fan-in scaling and a truncated normal."""
    variance_scaling_(tensor, mode="fan_in", distribution="truncated_normal")
def default_flax_embed_init(tensor):
    """Initialize `tensor` in place via `variance_scaling_` with fan-in scaling and an (untruncated) normal."""
    variance_scaling_(tensor, mode="fan_in", distribution="normal")
  444. @auto_docstring
  445. class Siglip2PreTrainedModel(PreTrainedModel):
  446. config: Siglip2Config
  447. base_model_prefix = "siglip2"
  448. supports_gradient_checkpointing = True
  449. _no_split_modules = [
  450. "Siglip2TextEmbeddings",
  451. "Siglip2VisionEmbeddings",
  452. "Siglip2EncoderLayer",
  453. "Siglip2MultiheadAttentionPoolingHead",
  454. ]
  455. _supports_flash_attn = True
  456. _supports_sdpa = True
  457. _supports_flex_attn = True
  458. _supports_attention_backend = True
  459. _can_record_outputs = {
  460. "hidden_states": Siglip2EncoderLayer,
  461. "attentions": Siglip2Attention,
  462. }
  463. def _init_weights(self, module):
  464. """Initialize the weights"""
  465. if isinstance(module, Siglip2VisionEmbeddings):
  466. width = (
  467. self.config.vision_config.hidden_size
  468. if isinstance(self.config, Siglip2Config)
  469. else self.config.hidden_size
  470. )
  471. nn.init.normal_(module.position_embedding.weight, std=1 / np.sqrt(width))
  472. elif isinstance(module, nn.Embedding):
  473. default_flax_embed_init(module.weight)
  474. elif isinstance(module, Siglip2Attention):
  475. nn.init.xavier_uniform_(module.q_proj.weight)
  476. nn.init.xavier_uniform_(module.k_proj.weight)
  477. nn.init.xavier_uniform_(module.v_proj.weight)
  478. nn.init.xavier_uniform_(module.out_proj.weight)
  479. nn.init.zeros_(module.q_proj.bias)
  480. nn.init.zeros_(module.k_proj.bias)
  481. nn.init.zeros_(module.v_proj.bias)
  482. nn.init.zeros_(module.out_proj.bias)
  483. elif isinstance(module, Siglip2MLP):
  484. nn.init.xavier_uniform_(module.fc1.weight)
  485. nn.init.xavier_uniform_(module.fc2.weight)
  486. nn.init.normal_(module.fc1.bias, std=1e-6)
  487. nn.init.normal_(module.fc2.bias, std=1e-6)
  488. elif isinstance(module, Siglip2MultiheadAttentionPoolingHead):
  489. nn.init.xavier_uniform_(module.probe.data)
  490. nn.init.xavier_uniform_(module.attention.in_proj_weight.data)
  491. nn.init.zeros_(module.attention.in_proj_bias.data)
  492. elif isinstance(module, Siglip2Model):
  493. logit_scale_init = torch.log(torch.tensor(1.0))
  494. module.logit_scale.data.fill_(logit_scale_init)
  495. module.logit_bias.data.zero_()
  496. elif isinstance(module, Siglip2ForImageClassification):
  497. nn.init.normal_(
  498. module.classifier.weight,
  499. std=self.config.vision_config.hidden_size**-0.5 * self.config.initializer_factor,
  500. )
  501. elif isinstance(module, (nn.Linear, nn.Conv2d)):
  502. lecun_normal_(module.weight)
  503. if module.bias is not None:
  504. nn.init.zeros_(module.bias)
  505. elif isinstance(module, nn.LayerNorm):
  506. module.bias.data.zero_()
  507. module.weight.data.fill_(1.0)
  508. class Siglip2TextEmbeddings(nn.Module):
  509. def __init__(self, config: Siglip2TextConfig):
  510. super().__init__()
  511. embed_dim = config.hidden_size
  512. self.token_embedding = nn.Embedding(config.vocab_size, embed_dim)
  513. self.position_embedding = nn.Embedding(config.max_position_embeddings, embed_dim)
  514. # position_ids (1, len position emb) is contiguous in memory and exported when serialized
  515. self.register_buffer(
  516. "position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)), persistent=False
  517. )
  518. def forward(
  519. self,
  520. input_ids: Optional[torch.LongTensor] = None,
  521. position_ids: Optional[torch.LongTensor] = None,
  522. inputs_embeds: Optional[torch.FloatTensor] = None,
  523. ) -> torch.Tensor:
  524. seq_length = input_ids.shape[-1] if input_ids is not None else inputs_embeds.shape[-2]
  525. max_position_embedding = self.position_embedding.weight.shape[0]
  526. if seq_length > max_position_embedding:
  527. raise ValueError(
  528. f"Sequence length must be less than max_position_embeddings (got `sequence length`: "
  529. f"{seq_length} and max_position_embeddings: {max_position_embedding}"
  530. )
  531. if position_ids is None:
  532. position_ids = self.position_ids[:, :seq_length]
  533. if inputs_embeds is None:
  534. inputs_embeds = self.token_embedding(input_ids)
  535. position_embeddings = self.position_embedding(position_ids)
  536. embeddings = inputs_embeds + position_embeddings
  537. return embeddings
  538. class Siglip2TextTransformer(nn.Module):
  539. def __init__(self, config: Siglip2TextConfig):
  540. super().__init__()
  541. self.config = config
  542. embed_dim = config.hidden_size
  543. self.embeddings = Siglip2TextEmbeddings(config)
  544. self.encoder = Siglip2Encoder(config)
  545. self.final_layer_norm = nn.LayerNorm(embed_dim, eps=config.layer_norm_eps)
  546. self.head = nn.Linear(embed_dim, config.projection_size)
  547. @can_return_tuple
  548. @auto_docstring
  549. def forward(
  550. self,
  551. input_ids: Optional[torch.Tensor] = None,
  552. attention_mask: Optional[torch.Tensor] = None,
  553. position_ids: Optional[torch.Tensor] = None,
  554. **kwargs: Unpack[TransformersKwargs],
  555. ) -> BaseModelOutputWithPooling:
  556. if input_ids is None:
  557. raise ValueError("You have to specify input_ids")
  558. input_shape = input_ids.size()
  559. input_ids = input_ids.view(-1, input_shape[-1])
  560. hidden_states = self.embeddings(input_ids=input_ids, position_ids=position_ids)
  561. # note: Siglip2's text model does not use a causal mask, unlike the original CLIP model.
  562. # expand attention_mask
  563. uses_flash_attention = "flash" in self.config._attn_implementation
  564. if uses_flash_attention:
  565. attention_mask = None
  566. elif attention_mask is not None and not uses_flash_attention:
  567. # [batch_size, seq_len] -> [batch_size, 1, tgt_seq_len, src_seq_len]
  568. attention_mask = _prepare_4d_attention_mask(attention_mask, hidden_states.dtype)
  569. encoder_outputs: BaseModelOutput = self.encoder(
  570. inputs_embeds=hidden_states,
  571. attention_mask=attention_mask,
  572. **kwargs,
  573. )
  574. last_hidden_state = encoder_outputs.last_hidden_state
  575. last_hidden_state = self.final_layer_norm(last_hidden_state)
  576. # The model uses the last token's hidden state, which may be padding.
  577. pooled_output = last_hidden_state[:, -1, :]
  578. pooled_output = self.head(pooled_output)
  579. return BaseModelOutputWithPooling(
  580. last_hidden_state=last_hidden_state,
  581. pooler_output=pooled_output,
  582. )
@auto_docstring(
    custom_intro="""
    The text model from Siglip2 without any head or projection on top.
    """
)
class Siglip2TextModel(Siglip2PreTrainedModel):
    # Thin wrapper that owns a `Siglip2TextTransformer` as `text_model` and delegates to it.
    config: Siglip2TextConfig

    def __init__(self, config: Siglip2TextConfig):
        super().__init__(config)
        self.text_model = Siglip2TextTransformer(config)
        # Initialize weights and apply final processing
        self.post_init()

    def get_input_embeddings(self) -> nn.Module:
        # The input embedding is the token embedding table inside the text tower.
        return self.text_model.embeddings.token_embedding

    def set_input_embeddings(self, value):
        # Swap the token embedding table in place on the wrapped text tower.
        self.text_model.embeddings.token_embedding = value

    @check_model_inputs(tie_last_hidden_states=False)
    @auto_docstring
    def forward(
        self,
        input_ids: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.Tensor] = None,
        **kwargs: Unpack[TransformersKwargs],
    ) -> BaseModelOutputWithPooling:
        r"""
        Examples:

        ```python
        >>> from transformers import AutoTokenizer, Siglip2TextModel

        >>> model = Siglip2TextModel.from_pretrained("google/siglip2-base-patch16-224")
        >>> tokenizer = AutoTokenizer.from_pretrained("google/siglip2-base-patch16-224")

        >>> # important: make sure to set padding="max_length" as that's how the model was trained
        >>> inputs = tokenizer(["a photo of a cat", "a photo of a dog"], padding="max_length", return_tensors="pt")

        >>> outputs = model(**inputs)
        >>> last_hidden_state = outputs.last_hidden_state
        >>> pooled_output = outputs.pooler_output  # pooled (EOS token) states
        ```"""
        # All arguments are forwarded unchanged to the wrapped transformer.
        return self.text_model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            **kwargs,
        )
  626. class Siglip2MultiheadAttentionPoolingHead(nn.Module):
  627. """Multihead Attention Pooling."""
  628. def __init__(self, config: Siglip2VisionConfig):
  629. super().__init__()
  630. self.probe = nn.Parameter(torch.randn(1, 1, config.hidden_size))
  631. self.attention = torch.nn.MultiheadAttention(config.hidden_size, config.num_attention_heads, batch_first=True)
  632. self.layernorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
  633. self.mlp = Siglip2MLP(config)
  634. self.num_heads = config.num_attention_heads
  635. def forward(self, hidden_state: torch.Tensor, attention_mask: Optional[torch.Tensor] = None) -> torch.Tensor:
  636. batch_size = hidden_state.shape[0]
  637. probe = self.probe.repeat(batch_size, 1, 1)
  638. if attention_mask is not None:
  639. target_len, source_len = probe.shape[1], hidden_state.shape[1]
  640. attention_mask = _prepare_4d_attention_mask(attention_mask, hidden_state.dtype, target_len)
  641. attention_mask = attention_mask.repeat(1, self.num_heads, target_len, 1)
  642. attention_mask = attention_mask.reshape(-1, target_len, source_len)
  643. hidden_state = self.attention(probe, hidden_state, hidden_state, attn_mask=attention_mask)[0]
  644. residual = hidden_state
  645. hidden_state = self.layernorm(hidden_state)
  646. hidden_state = residual + self.mlp(hidden_state)
  647. return hidden_state[:, 0]
  648. @auto_docstring(
  649. custom_intro="""
  650. The vision model from Siglip2 without any head or projection on top.
  651. """
  652. )
  653. class Siglip2VisionModel(Siglip2PreTrainedModel):
  654. config: Siglip2VisionConfig
  655. main_input_name = "pixel_values"
  656. def __init__(self, config: Siglip2VisionConfig):
  657. super().__init__(config)
  658. self.vision_model = Siglip2VisionTransformer(config)
  659. # Initialize weights and apply final processing
  660. self.post_init()
  661. def get_input_embeddings(self) -> nn.Module:
  662. return self.vision_model.embeddings.patch_embedding
  663. @check_model_inputs(tie_last_hidden_states=False)
  664. @auto_docstring
  665. def forward(
  666. self,
  667. pixel_values: torch.FloatTensor,
  668. pixel_attention_mask: torch.Tensor,
  669. spatial_shapes: torch.LongTensor,
  670. output_attentions: Optional[bool] = None,
  671. output_hidden_states: Optional[bool] = None,
  672. ) -> BaseModelOutputWithPooling:
  673. r"""
  674. pixel_attention_mask (`torch.Tensor` of shape `(batch_size, image_size, image_size)`, *optional*):
  675. Mask to avoid performing attention on padding pixel indices.
  676. spatial_shapes (`torch.LongTensor` of shape `(batch_size, 2)`):
  677. Tensor containing the spatial dimensions (height, width) of the input images.
  678. Examples:
  679. ```python
  680. >>> from PIL import Image
  681. >>> import requests
  682. >>> from transformers import AutoProcessor, Siglip2VisionModel
  683. >>> model = Siglip2VisionModel.from_pretrained("google/siglip2-base-patch16-224")
  684. >>> processor = AutoProcessor.from_pretrained("google/siglip2-base-patch16-224")
  685. >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
  686. >>> image = Image.open(requests.get(url, stream=True).raw)
  687. >>> inputs = processor(images=image, return_tensors="pt")
  688. >>> outputs = model(**inputs)
  689. >>> last_hidden_state = outputs.last_hidden_state
  690. >>> pooled_output = outputs.pooler_output # pooled features
  691. ```"""
  692. return self.vision_model(
  693. pixel_values=pixel_values,
  694. attention_mask=pixel_attention_mask,
  695. spatial_shapes=spatial_shapes,
  696. output_attentions=output_attentions,
  697. output_hidden_states=output_hidden_states,
  698. )
  699. @auto_docstring
  700. class Siglip2Model(Siglip2PreTrainedModel):
  701. config: Siglip2Config
  702. def __init__(self, config: Siglip2Config):
  703. super().__init__(config)
  704. if not isinstance(config.text_config, Siglip2TextConfig):
  705. raise TypeError(
  706. "config.text_config is expected to be of type Siglip2TextConfig but is of type"
  707. f" {type(config.text_config)}."
  708. )
  709. if not isinstance(config.vision_config, Siglip2VisionConfig):
  710. raise TypeError(
  711. "config.vision_config is expected to be of type Siglip2VisionConfig but is of type"
  712. f" {type(config.vision_config)}."
  713. )
  714. text_config = config.text_config
  715. vision_config = config.vision_config
  716. # First, initialize the text and vision models with proper attention implementation
  717. text_model = Siglip2TextModel._from_config(text_config)
  718. vision_model = Siglip2VisionModel._from_config(vision_config)
  719. # Second, get the text and vision submodules (for backward compatibility)
  720. self.text_model = text_model.text_model
  721. self.vision_model = vision_model.vision_model
  722. self.logit_scale = nn.Parameter(torch.randn(1))
  723. self.logit_bias = nn.Parameter(torch.randn(1))
  724. # Initialize weights and apply final processing
  725. self.post_init()
  726. @filter_out_non_signature_kwargs()
  727. @auto_docstring
  728. def get_text_features(
  729. self,
  730. input_ids: torch.Tensor,
  731. attention_mask: Optional[torch.Tensor] = None,
  732. position_ids: Optional[torch.Tensor] = None,
  733. ) -> torch.FloatTensor:
  734. r"""
  735. Returns:
  736. text_features (`torch.FloatTensor` of shape `(batch_size, output_dim`): The text embeddings obtained by
  737. applying the projection layer to the pooled output of [`Siglip2TextModel`].
  738. Examples:
  739. ```python
  740. >>> from transformers import AutoTokenizer, AutoModel
  741. >>> import torch
  742. >>> model = AutoModel.from_pretrained("google/siglip2-base-patch16-224")
  743. >>> tokenizer = AutoTokenizer.from_pretrained("google/siglip2-base-patch16-224")
  744. >>> # important: make sure to set padding="max_length" as that's how the model was trained
  745. >>> inputs = tokenizer(["a photo of a cat", "a photo of a dog"], padding="max_length", return_tensors="pt")
  746. >>> with torch.no_grad():
  747. ... text_features = model.get_text_features(**inputs)
  748. ```"""
  749. text_outputs: BaseModelOutputWithPooling = self.text_model(
  750. input_ids=input_ids,
  751. attention_mask=attention_mask,
  752. position_ids=position_ids,
  753. )
  754. pooled_output = text_outputs.pooler_output
  755. return pooled_output
  756. @filter_out_non_signature_kwargs()
  757. @auto_docstring
  758. def get_image_features(
  759. self,
  760. pixel_values: Optional[torch.FloatTensor] = None,
  761. pixel_attention_mask: Optional[torch.Tensor] = None,
  762. spatial_shapes: Optional[torch.LongTensor] = None,
  763. ) -> torch.FloatTensor:
  764. r"""
  765. pixel_attention_mask (`torch.Tensor` of shape `(batch_size, image_size, image_size)`, *optional*):
  766. Mask to avoid performing attention on padding pixel indices.
  767. spatial_shapes (`torch.LongTensor` of shape `(batch_size, 2)`):
  768. Tensor containing the spatial dimensions (height, width) of the input images.
  769. Returns:
  770. image_features (`torch.FloatTensor` of shape `(batch_size, output_dim`): The image embeddings obtained by
  771. applying the projection layer to the pooled output of [`Siglip2VisionModel`].
  772. Examples:
  773. ```python
  774. >>> import torch
  775. >>> from transformers import AutoProcessor, AutoModel
  776. >>> from transformers.image_utils import load_image
  777. >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
  778. >>> image = load_image(url)
  779. >>> model = AutoModel.from_pretrained("google/siglip2-base-patch16-224")
  780. >>> processor = AutoProcessor.from_pretrained("google/siglip2-base-patch16-224")
  781. >>> inputs = processor(images=image, return_tensors="pt")
  782. >>> with torch.no_grad():
  783. ... image_features = model.get_image_features(**inputs)
  784. ```
  785. """
  786. vision_outputs: BaseModelOutputWithPooling = self.vision_model(
  787. pixel_values=pixel_values,
  788. attention_mask=pixel_attention_mask,
  789. spatial_shapes=spatial_shapes,
  790. )
  791. pooled_output = vision_outputs.pooler_output
  792. return pooled_output
  793. # NOTE: Siglip2Model uses Pretrained backbones, so we don't need to add `check_model_inputs` here
  794. @can_return_tuple
  795. @auto_docstring
  796. def forward(
  797. self,
  798. input_ids: Optional[torch.LongTensor] = None,
  799. pixel_values: Optional[torch.FloatTensor] = None,
  800. pixel_attention_mask: Optional[torch.Tensor] = None,
  801. spatial_shapes: Optional[torch.LongTensor] = None,
  802. attention_mask: Optional[torch.Tensor] = None,
  803. position_ids: Optional[torch.LongTensor] = None,
  804. return_loss: Optional[bool] = None,
  805. output_attentions: Optional[bool] = None,
  806. output_hidden_states: Optional[bool] = None,
  807. ) -> Siglip2Output:
  808. r"""
  809. pixel_attention_mask (`torch.Tensor` of shape `(batch_size, image_size, image_size)`, *optional*):
  810. Mask to avoid performing attention on padding pixel indices.
  811. spatial_shapes (`torch.LongTensor` of shape `(batch_size, 2)`):
  812. Tensor containing the spatial dimensions (height, width) of the input images.
  813. return_loss (`bool`, *optional*):
  814. Whether or not to return the contrastive loss.
  815. Examples:
  816. ```python
  817. >>> from PIL import Image
  818. >>> import requests
  819. >>> from transformers import AutoProcessor, AutoModel
  820. >>> import torch
  821. >>> model = AutoModel.from_pretrained("google/siglip2-base-patch16-224")
  822. >>> processor = AutoProcessor.from_pretrained("google/siglip2-base-patch16-224")
  823. >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
  824. >>> image = Image.open(requests.get(url, stream=True).raw)
  825. >>> texts = ["a photo of 2 cats", "a photo of 2 dogs"]
  826. >>> # important: we pass `padding=max_length` since the model was trained with this
  827. >>> inputs = processor(text=texts, images=image, padding="max_length", return_tensors="pt")
  828. >>> with torch.no_grad():
  829. ... outputs = model(**inputs)
  830. >>> logits_per_image = outputs.logits_per_image
  831. >>> probs = torch.sigmoid(logits_per_image) # these are the probabilities
  832. >>> print(f"{probs[0][0]:.1%} that image 0 is '{texts[0]}'")
  833. 31.9% that image 0 is 'a photo of 2 cats'
  834. ```
  835. """
  836. # Use Siglip2 model's config for some fields (if specified) instead of those of vision & text components.
  837. output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
  838. output_hidden_states = (
  839. output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
  840. )
  841. vision_outputs: BaseModelOutputWithPooling = self.vision_model(
  842. pixel_values=pixel_values,
  843. attention_mask=pixel_attention_mask,
  844. spatial_shapes=spatial_shapes,
  845. output_attentions=output_attentions,
  846. output_hidden_states=output_hidden_states,
  847. )
  848. text_outputs: BaseModelOutputWithPooling = self.text_model(
  849. input_ids=input_ids,
  850. attention_mask=attention_mask,
  851. position_ids=position_ids,
  852. output_attentions=output_attentions,
  853. output_hidden_states=output_hidden_states,
  854. )
  855. image_embeds = vision_outputs.pooler_output
  856. text_embeds = text_outputs.pooler_output
  857. # normalized features
  858. image_embeds = image_embeds / image_embeds.norm(p=2, dim=-1, keepdim=True)
  859. text_embeds = text_embeds / text_embeds.norm(p=2, dim=-1, keepdim=True)
  860. # cosine similarity as logits
  861. logits_per_text = torch.matmul(text_embeds, image_embeds.t().to(text_embeds.device))
  862. logit_scale, logit_bias = self.logit_scale.to(text_embeds.device), self.logit_bias.to(text_embeds.device)
  863. logits_per_text = logits_per_text * logit_scale.exp() + logit_bias
  864. logits_per_image = logits_per_text.t()
  865. loss = None
  866. if return_loss:
  867. # Adapted from https://github.com/google-research/big_vision/blob/01edb81a4716f93a48be43b3a4af14e29cdb3a7f/big_vision/trainers/proj/image_text/siglip2.py#L287
  868. eye = torch.eye(logits_per_text.size(0), device=logits_per_text.device)
  869. m1_diag1 = -torch.ones_like(logits_per_text) + 2 * eye
  870. loglik = torch.nn.functional.logsigmoid(m1_diag1 * logits_per_text)
  871. nll = -torch.sum(loglik, dim=-1)
  872. loss = nll.mean()
  873. return Siglip2Output(
  874. loss=loss,
  875. logits_per_image=logits_per_image,
  876. logits_per_text=logits_per_text,
  877. text_embeds=text_embeds,
  878. image_embeds=image_embeds,
  879. text_model_output=text_outputs,
  880. vision_model_output=vision_outputs,
  881. )
  882. @auto_docstring(
  883. custom_intro="""
  884. Siglip2 vision encoder with an image classification head on top (a linear layer on top of the pooled final hidden states of
  885. the patch tokens) e.g. for ImageNet.
  886. """
  887. )
  888. class Siglip2ForImageClassification(Siglip2PreTrainedModel):
  889. main_input_name = "pixel_values"
  890. def __init__(self, config: Siglip2Config) -> None:
  891. super().__init__(config)
  892. self.num_labels = config.num_labels
  893. # Create the vision model with proper attention
  894. # and take only vision_model submodule (for backward compatibility)
  895. vision_model = Siglip2VisionModel._from_config(config.vision_config)
  896. self.vision_model = vision_model.vision_model
  897. # Classifier head
  898. self.classifier = (
  899. nn.Linear(config.vision_config.hidden_size, config.num_labels) if config.num_labels > 0 else nn.Identity()
  900. )
  901. # Initialize weights and apply final processing
  902. self.post_init()
  903. @check_model_inputs()
  904. @auto_docstring
  905. def forward(
  906. self,
  907. pixel_values: Optional[torch.Tensor] = None,
  908. pixel_attention_mask: Optional[torch.Tensor] = None,
  909. spatial_shapes: Optional[torch.LongTensor] = None,
  910. labels: Optional[torch.Tensor] = None,
  911. output_attentions: Optional[bool] = None,
  912. output_hidden_states: Optional[bool] = None,
  913. ) -> ImageClassifierOutput:
  914. r"""
  915. pixel_attention_mask (`torch.Tensor` of shape `(batch_size, image_size, image_size)`, *optional*):
  916. Mask to avoid performing attention on padding pixel indices.
  917. spatial_shapes (`torch.LongTensor` of shape `(batch_size, 2)`):
  918. Tensor containing the spatial dimensions (height, width) of the input images.
  919. labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
  920. Labels for computing the image classification/regression loss. Indices should be in `[0, ...,
  921. config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
  922. `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
  923. Examples:
  924. ```python
  925. >>> from transformers import AutoImageProcessor, Siglip2ForImageClassification
  926. >>> import torch
  927. >>> from PIL import Image
  928. >>> import requests
  929. >>> torch.manual_seed(3) # doctest: +IGNORE_RESULT
  930. >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
  931. >>> image = Image.open(requests.get(url, stream=True).raw)
  932. >>> # note: we are loading a `Siglip2Model` from the hub here,
  933. >>> # so the head will be randomly initialized, hence the predictions will be random if seed is not set above.
  934. >>> image_processor = AutoImageProcessor.from_pretrained("google/siglip2-base-patch16-224")
  935. >>> model = Siglip2ForImageClassification.from_pretrained("google/siglip2-base-patch16-224")
  936. >>> inputs = image_processor(images=image, return_tensors="pt")
  937. >>> outputs = model(**inputs)
  938. >>> logits = outputs.logits
  939. >>> # model predicts one of the two classes
  940. >>> predicted_class_idx = logits.argmax(-1).item()
  941. >>> print("Predicted class:", model.config.id2label[predicted_class_idx])
  942. Predicted class: LABEL_1
  943. ```
  944. """
  945. output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
  946. output_hidden_states = (
  947. output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
  948. )
  949. outputs: BaseModelOutputWithPooling = self.vision_model(
  950. pixel_values,
  951. attention_mask=pixel_attention_mask,
  952. spatial_shapes=spatial_shapes,
  953. output_attentions=output_attentions,
  954. output_hidden_states=output_hidden_states,
  955. )
  956. sequence_output = outputs.last_hidden_state
  957. # average pool the patch tokens
  958. if pixel_attention_mask is not None:
  959. pool_mask = pixel_attention_mask[..., None].to(sequence_output.device)
  960. sequence_output = torch.sum(sequence_output * pool_mask, dim=1) / torch.sum(pool_mask, dim=1)
  961. else:
  962. sequence_output = torch.mean(sequence_output, dim=1)
  963. # apply classifier
  964. logits = self.classifier(sequence_output)
  965. loss = None
  966. if labels is not None:
  967. loss = self.loss_function(labels, logits, self.config)
  968. return ImageClassifierOutput(
  969. loss=loss,
  970. logits=logits,
  971. hidden_states=outputs.hidden_states,
  972. attentions=outputs.attentions,
  973. )
# Public names of this module: the ones exported by a wildcard import.
__all__ = [
    "Siglip2Model",
    "Siglip2PreTrainedModel",
    "Siglip2TextModel",
    "Siglip2VisionModel",
    "Siglip2ForImageClassification",
]