modular_siglip2.py 25 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605
  1. # coding=utf-8
  2. # Copyright 2025 The HuggingFace Inc. team.
  3. #
  4. # Licensed under the Apache License, Version 2.0 (the "License");
  5. # you may not use this file except in compliance with the License.
  6. # You may obtain a copy of the License at
  7. #
  8. # http://www.apache.org/licenses/LICENSE-2.0
  9. #
  10. # Unless required by applicable law or agreed to in writing, software
  11. # distributed under the License is distributed on an "AS IS" BASIS,
  12. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. # See the License for the specific language governing permissions and
  14. # limitations under the License.
  15. from typing import Optional
  16. import torch
  17. import torch.nn as nn
  18. import torch.nn.functional as F
  19. from transformers.models.siglip.configuration_siglip import SiglipConfig, SiglipTextConfig, SiglipVisionConfig
  20. from transformers.models.siglip.modeling_siglip import (
  21. BaseModelOutput,
  22. BaseModelOutputWithPooling,
  23. ImageClassifierOutput,
  24. SiglipForImageClassification,
  25. SiglipModel,
  26. SiglipMultiheadAttentionPoolingHead,
  27. SiglipOutput,
  28. SiglipPreTrainedModel,
  29. SiglipTextModel,
  30. SiglipTextModelOutput,
  31. SiglipVisionModel,
  32. SiglipVisionModelOutput,
  33. SiglipVisionTransformer,
  34. )
  35. from ...modeling_attn_mask_utils import _prepare_4d_attention_mask
  36. from ...utils import auto_docstring, filter_out_non_signature_kwargs
class Siglip2TextConfig(SiglipTextConfig):
    # Text-tower configuration for Siglip2: identical to SiglipTextConfig, subclassed only to
    # expose the Siglip2 name.
    pass
  39. class Siglip2VisionConfig(SiglipVisionConfig):
  40. r"""
  41. This is the configuration class to store the configuration of a [`Siglip2VisionModel`]. It is used to instantiate a
  42. Siglip2 vision encoder according to the specified arguments, defining the model architecture. Instantiating a
  43. configuration with the defaults will yield a similar configuration to that of the vision encoder of the Siglip2
  44. [google/siglip2-base-patch16-naflex](https://huggingface.co/google/siglip2-base-patch16-naflex) architecture.
  45. Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
  46. documentation from [`PretrainedConfig`] for more information.
  47. Args:
  48. hidden_size (`int`, *optional*, defaults to 768):
  49. Dimensionality of the encoder layers and the pooler layer.
  50. intermediate_size (`int`, *optional*, defaults to 3072):
  51. Dimensionality of the "intermediate" (i.e., feed-forward) layer in the Transformer encoder.
  52. num_hidden_layers (`int`, *optional*, defaults to 12):
  53. Number of hidden layers in the Transformer encoder.
  54. num_attention_heads (`int`, *optional*, defaults to 12):
  55. Number of attention heads for each attention layer in the Transformer encoder.
  56. num_channels (`int`, *optional*, defaults to 3):
  57. Number of channels in the input images.
  58. num_patches (`int`, *optional*, defaults to 256):
  59. The number of patches in the image with the size of (`patch_size`, `patch_size`).
  60. The image is resized to fill maximum of this number of patches, and to preserve
  61. the aspect ratio. In case the resulted number of patches is lower, the image is
  62. padded in "patch" dimension.
  63. patch_size (`int`, *optional*, defaults to 16):
  64. The size (resolution) of each patch.
  65. hidden_act (`str` or `function`, *optional*, defaults to `"gelu_pytorch_tanh"`):
  66. The non-linear activation function (function or string) in the encoder and pooler. If string, `"gelu"`,
  67. `"relu"`, `"selu"` and `"gelu_new"` `"quick_gelu"` are supported.
  68. layer_norm_eps (`float`, *optional*, defaults to 1e-06):
  69. The epsilon used by the layer normalization layers.
  70. attention_dropout (`float`, *optional*, defaults to 0.0):
  71. The dropout ratio for the attention probabilities.
  72. Example:
  73. ```python
  74. >>> from transformers import Siglip2VisionConfig, Siglip2VisionModel
  75. >>> # Initializing a Siglip2VisionConfig with google/siglip2-base-patch16-naflex style configuration
  76. >>> configuration = Siglip2VisionConfig()
  77. >>> # Initializing a Siglip2VisionModel (with random weights) from the google/siglip2-base-patch16-naflex style configuration
  78. >>> model = Siglip2VisionModel(configuration)
  79. >>> # Accessing the model configuration
  80. >>> configuration = model.config
  81. ```"""
  82. def __init__(
  83. self,
  84. hidden_size=768,
  85. intermediate_size=3072,
  86. num_hidden_layers=12,
  87. num_attention_heads=12,
  88. num_channels=3,
  89. num_patches=256,
  90. patch_size=16,
  91. hidden_act="gelu_pytorch_tanh",
  92. layer_norm_eps=1e-6,
  93. attention_dropout=0.0,
  94. **kwargs,
  95. ):
  96. super().__init__(**kwargs)
  97. self.num_patches = num_patches
  98. del self.image_size
class Siglip2Config(SiglipConfig):
    # Top-level Siglip2 configuration; inherits everything from SiglipConfig unchanged.
    # NOTE(review): at runtime any sub-config construction is SiglipConfig's own; presumably the
    # modular converter rewrites it to build the Siglip2 text/vision configs — confirm.
    pass
class Siglip2VisionOutput(SiglipVisionModelOutput):
    # Output container identical to SiglipVisionModelOutput; subclassed only to expose a Siglip2 name.
    pass
class Siglip2TextOutput(SiglipTextModelOutput):
    # Output container identical to SiglipTextModelOutput; subclassed only to expose a Siglip2 name.
    pass
class Siglip2Output(SiglipOutput):
    # Output container returned by Siglip2Model.forward; identical to SiglipOutput.
    pass
  107. class Siglip2VisionEmbeddings(nn.Module):
  108. def __init__(self, config: Siglip2VisionConfig):
  109. super().__init__()
  110. self.config = config
  111. self.embed_dim = config.hidden_size
  112. self.patch_size = config.patch_size
  113. self.patch_embedding = nn.Linear(
  114. in_features=config.num_channels * self.patch_size * self.patch_size,
  115. out_features=self.embed_dim,
  116. )
  117. self.num_patches = config.num_patches
  118. self.position_embedding_size = int(self.num_patches**0.5)
  119. self.position_embedding = nn.Embedding(self.num_patches, self.embed_dim)
  120. @staticmethod
  121. def resize_positional_embeddings(
  122. positional_embeddings: torch.Tensor,
  123. spatial_shapes: torch.LongTensor,
  124. max_length: int,
  125. ) -> torch.Tensor:
  126. """
  127. Resize positional embeddings to image-specific size and pad to a fixed size.
  128. Args:
  129. positional_embeddings (`torch.Tensor`):
  130. Position embeddings of shape (height, width, embed_dim)
  131. spatial_shapes (`torch.LongTensor`):
  132. Spatial shapes of shape (batch_size, 2) to resize the positional embeddings to
  133. max_length (`int`):
  134. Maximum length of the positional embeddings to pad resized positional embeddings to
  135. Returns:
  136. `torch.Tensor`: Embeddings of shape (batch_size, max_length, embed_dim)
  137. """
  138. batch_size = spatial_shapes.shape[0]
  139. embed_dim = positional_embeddings.shape[-1]
  140. source_dtype = positional_embeddings.dtype
  141. resulted_positional_embeddings = torch.empty(
  142. (batch_size, max_length, embed_dim),
  143. device=positional_embeddings.device,
  144. dtype=source_dtype,
  145. )
  146. # (height, width, embed_dim) -> (1, embed_dim, height, width) for interpolation
  147. positional_embeddings = positional_embeddings.permute(2, 0, 1).unsqueeze(0)
  148. # Upcast to float32 on CPU because antialias is not supported for bfloat16/float16 on CPU
  149. if positional_embeddings.device.type == "cpu":
  150. positional_embeddings = positional_embeddings.to(torch.float32)
  151. for i in range(batch_size):
  152. # (1, dim, height, width) -> (1, dim, target_height, target_width)
  153. height, width = spatial_shapes[i]
  154. resized_embeddings = F.interpolate(
  155. positional_embeddings,
  156. size=(height, width),
  157. mode="bilinear",
  158. align_corners=False,
  159. antialias=True,
  160. )
  161. # (1, dim, target_height, target_width) -> (target_height * target_width, dim)
  162. resized_embeddings = resized_embeddings.reshape(embed_dim, height * width).transpose(0, 1)
  163. # Cast to original dtype
  164. resized_embeddings = resized_embeddings.to(source_dtype)
  165. resulted_positional_embeddings[i, : height * width] = resized_embeddings
  166. resulted_positional_embeddings[i, height * width :] = resized_embeddings[0]
  167. return resulted_positional_embeddings
  168. def forward(self, pixel_values: torch.FloatTensor, spatial_shapes: torch.LongTensor) -> torch.Tensor:
  169. """
  170. Args:
  171. pixel_values (`torch.FloatTensor`):
  172. Pixel values of shape (batch_size, max_num_patches, num_channels * patch_size * patch_size)
  173. spatial_shapes (`list[tuple[int, int]]`):
  174. Spatial shapes of shape (batch_size, 2) to resize the positional embeddings to
  175. """
  176. # Apply patch embeddings to already patchified pixel values
  177. target_dtype = self.patch_embedding.weight.dtype
  178. patch_embeds = self.patch_embedding(pixel_values.to(dtype=target_dtype))
  179. # Get positional resized and padded positional embeddings
  180. positional_embeddings = self.position_embedding.weight.reshape(
  181. self.position_embedding_size, self.position_embedding_size, -1
  182. )
  183. resized_positional_embeddings = self.resize_positional_embeddings(
  184. positional_embeddings, spatial_shapes, max_length=pixel_values.shape[1]
  185. )
  186. # Add positional embeddings to patch embeddings
  187. embeddings = patch_embeds + resized_positional_embeddings
  188. return embeddings
  189. class Siglip2VisionTransformer(SiglipVisionTransformer):
  190. def __init__(self, config: Siglip2VisionConfig):
  191. super().__init__(config)
  192. # Update: add `spatial_shapes` and `attention_mask`
  193. def forward(
  194. self,
  195. pixel_values: torch.FloatTensor,
  196. attention_mask: torch.Tensor,
  197. spatial_shapes: torch.LongTensor,
  198. output_attentions: Optional[bool] = None,
  199. output_hidden_states: Optional[bool] = None,
  200. ) -> BaseModelOutputWithPooling:
  201. r"""
  202. spatial_shapes (`torch.LongTensor` of shape `(batch_size, 2)`):
  203. Tensor containing the spatial dimensions (height, width) of the input images.
  204. """
  205. output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
  206. output_hidden_states = (
  207. output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
  208. )
  209. hidden_states = self.embeddings(pixel_values, spatial_shapes)
  210. if attention_mask is not None and self.config._attn_implementation != "flash_attention_2":
  211. # [batch_size, seq_len] -> [batch_size, 1, tgt_seq_len, src_seq_len]
  212. encoder_attention_mask = _prepare_4d_attention_mask(attention_mask, hidden_states.dtype)
  213. else:
  214. encoder_attention_mask = attention_mask
  215. encoder_outputs: BaseModelOutput = self.encoder(
  216. inputs_embeds=hidden_states,
  217. attention_mask=encoder_attention_mask,
  218. output_attentions=output_attentions,
  219. output_hidden_states=output_hidden_states,
  220. )
  221. last_hidden_state = encoder_outputs.last_hidden_state
  222. last_hidden_state = self.post_layernorm(last_hidden_state)
  223. pooler_output = self.head(last_hidden_state, attention_mask) if self.use_head else None
  224. return BaseModelOutputWithPooling(
  225. last_hidden_state=last_hidden_state,
  226. pooler_output=pooler_output,
  227. hidden_states=encoder_outputs.hidden_states,
  228. attentions=encoder_outputs.attentions,
  229. )
class Siglip2PreTrainedModel(SiglipPreTrainedModel):
    # Shared Siglip2 base class; inherits all behavior from SiglipPreTrainedModel unchanged.
    pass
class Siglip2TextModel(SiglipTextModel):
    # Siglip2 text model: identical to SiglipTextModel, subclassed only to expose the Siglip2 name.
    pass
  234. class Siglip2MultiheadAttentionPoolingHead(SiglipMultiheadAttentionPoolingHead):
  235. def __init__(self, config: Siglip2VisionConfig):
  236. super().__init__(config)
  237. self.num_heads = config.num_attention_heads
  238. def forward(self, hidden_state: torch.Tensor, attention_mask: Optional[torch.Tensor] = None) -> torch.Tensor:
  239. batch_size = hidden_state.shape[0]
  240. probe = self.probe.repeat(batch_size, 1, 1)
  241. if attention_mask is not None:
  242. target_len, source_len = probe.shape[1], hidden_state.shape[1]
  243. attention_mask = _prepare_4d_attention_mask(attention_mask, hidden_state.dtype, target_len)
  244. attention_mask = attention_mask.repeat(1, self.num_heads, target_len, 1)
  245. attention_mask = attention_mask.reshape(-1, target_len, source_len)
  246. hidden_state = self.attention(probe, hidden_state, hidden_state, attn_mask=attention_mask)[0]
  247. residual = hidden_state
  248. hidden_state = self.layernorm(hidden_state)
  249. hidden_state = residual + self.mlp(hidden_state)
  250. return hidden_state[:, 0]
  251. class Siglip2VisionModel(SiglipVisionModel):
  252. # Update: add `spatial_shapes` and `pixel_attention_mask`
  253. def forward(
  254. self,
  255. pixel_values: torch.FloatTensor,
  256. pixel_attention_mask: torch.Tensor,
  257. spatial_shapes: torch.LongTensor,
  258. output_attentions: Optional[bool] = None,
  259. output_hidden_states: Optional[bool] = None,
  260. ) -> BaseModelOutputWithPooling:
  261. r"""
  262. pixel_attention_mask (`torch.Tensor` of shape `(batch_size, image_size, image_size)`, *optional*):
  263. Mask to avoid performing attention on padding pixel indices.
  264. spatial_shapes (`torch.LongTensor` of shape `(batch_size, 2)`):
  265. Tensor containing the spatial dimensions (height, width) of the input images.
  266. Examples:
  267. ```python
  268. >>> from PIL import Image
  269. >>> import requests
  270. >>> from transformers import AutoProcessor, Siglip2VisionModel
  271. >>> model = Siglip2VisionModel.from_pretrained("google/siglip2-base-patch16-224")
  272. >>> processor = AutoProcessor.from_pretrained("google/siglip2-base-patch16-224")
  273. >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
  274. >>> image = Image.open(requests.get(url, stream=True).raw)
  275. >>> inputs = processor(images=image, return_tensors="pt")
  276. >>> outputs = model(**inputs)
  277. >>> last_hidden_state = outputs.last_hidden_state
  278. >>> pooled_output = outputs.pooler_output # pooled features
  279. ```"""
  280. return self.vision_model(
  281. pixel_values=pixel_values,
  282. attention_mask=pixel_attention_mask,
  283. spatial_shapes=spatial_shapes,
  284. output_attentions=output_attentions,
  285. output_hidden_states=output_hidden_states,
  286. )
  287. class Siglip2Model(SiglipModel):
  288. # Update: add `spatial_shapes` and `pixel_attention_mask`
  289. @filter_out_non_signature_kwargs()
  290. @auto_docstring
  291. def get_image_features(
  292. self,
  293. pixel_values: Optional[torch.FloatTensor] = None,
  294. pixel_attention_mask: Optional[torch.Tensor] = None,
  295. spatial_shapes: Optional[torch.LongTensor] = None,
  296. ) -> torch.FloatTensor:
  297. r"""
  298. pixel_attention_mask (`torch.Tensor` of shape `(batch_size, image_size, image_size)`, *optional*):
  299. Mask to avoid performing attention on padding pixel indices.
  300. spatial_shapes (`torch.LongTensor` of shape `(batch_size, 2)`):
  301. Tensor containing the spatial dimensions (height, width) of the input images.
  302. Returns:
  303. image_features (`torch.FloatTensor` of shape `(batch_size, output_dim`): The image embeddings obtained by
  304. applying the projection layer to the pooled output of [`Siglip2VisionModel`].
  305. Examples:
  306. ```python
  307. >>> import torch
  308. >>> from transformers import AutoProcessor, AutoModel
  309. >>> from transformers.image_utils import load_image
  310. >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
  311. >>> image = load_image(url)
  312. >>> model = AutoModel.from_pretrained("google/siglip2-base-patch16-224")
  313. >>> processor = AutoProcessor.from_pretrained("google/siglip2-base-patch16-224")
  314. >>> inputs = processor(images=image, return_tensors="pt")
  315. >>> with torch.no_grad():
  316. ... image_features = model.get_image_features(**inputs)
  317. ```
  318. """
  319. vision_outputs: BaseModelOutputWithPooling = self.vision_model(
  320. pixel_values=pixel_values,
  321. attention_mask=pixel_attention_mask,
  322. spatial_shapes=spatial_shapes,
  323. )
  324. pooled_output = vision_outputs.pooler_output
  325. return pooled_output
  326. # Update: add `spatial_shapes` and `pixel_attention_mask`
  327. def forward(
  328. self,
  329. input_ids: Optional[torch.LongTensor] = None,
  330. pixel_values: Optional[torch.FloatTensor] = None,
  331. pixel_attention_mask: Optional[torch.Tensor] = None,
  332. spatial_shapes: Optional[torch.LongTensor] = None,
  333. attention_mask: Optional[torch.Tensor] = None,
  334. position_ids: Optional[torch.LongTensor] = None,
  335. return_loss: Optional[bool] = None,
  336. output_attentions: Optional[bool] = None,
  337. output_hidden_states: Optional[bool] = None,
  338. ) -> Siglip2Output:
  339. r"""
  340. pixel_attention_mask (`torch.Tensor` of shape `(batch_size, image_size, image_size)`, *optional*):
  341. Mask to avoid performing attention on padding pixel indices.
  342. spatial_shapes (`torch.LongTensor` of shape `(batch_size, 2)`):
  343. Tensor containing the spatial dimensions (height, width) of the input images.
  344. return_loss (`bool`, *optional*):
  345. Whether or not to return the contrastive loss.
  346. Examples:
  347. ```python
  348. >>> from PIL import Image
  349. >>> import requests
  350. >>> from transformers import AutoProcessor, AutoModel
  351. >>> import torch
  352. >>> model = AutoModel.from_pretrained("google/siglip2-base-patch16-224")
  353. >>> processor = AutoProcessor.from_pretrained("google/siglip2-base-patch16-224")
  354. >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
  355. >>> image = Image.open(requests.get(url, stream=True).raw)
  356. >>> texts = ["a photo of 2 cats", "a photo of 2 dogs"]
  357. >>> # important: we pass `padding=max_length` since the model was trained with this
  358. >>> inputs = processor(text=texts, images=image, padding="max_length", return_tensors="pt")
  359. >>> with torch.no_grad():
  360. ... outputs = model(**inputs)
  361. >>> logits_per_image = outputs.logits_per_image
  362. >>> probs = torch.sigmoid(logits_per_image) # these are the probabilities
  363. >>> print(f"{probs[0][0]:.1%} that image 0 is '{texts[0]}'")
  364. 31.9% that image 0 is 'a photo of 2 cats'
  365. ```
  366. """
  367. # Use Siglip2 model's config for some fields (if specified) instead of those of vision & text components.
  368. output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
  369. output_hidden_states = (
  370. output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
  371. )
  372. vision_outputs: BaseModelOutputWithPooling = self.vision_model(
  373. pixel_values=pixel_values,
  374. attention_mask=pixel_attention_mask,
  375. spatial_shapes=spatial_shapes,
  376. output_attentions=output_attentions,
  377. output_hidden_states=output_hidden_states,
  378. )
  379. text_outputs: BaseModelOutputWithPooling = self.text_model(
  380. input_ids=input_ids,
  381. attention_mask=attention_mask,
  382. position_ids=position_ids,
  383. output_attentions=output_attentions,
  384. output_hidden_states=output_hidden_states,
  385. )
  386. image_embeds = vision_outputs.pooler_output
  387. text_embeds = text_outputs.pooler_output
  388. # normalized features
  389. image_embeds = image_embeds / image_embeds.norm(p=2, dim=-1, keepdim=True)
  390. text_embeds = text_embeds / text_embeds.norm(p=2, dim=-1, keepdim=True)
  391. # cosine similarity as logits
  392. logits_per_text = torch.matmul(text_embeds, image_embeds.t().to(text_embeds.device))
  393. logit_scale, logit_bias = self.logit_scale.to(text_embeds.device), self.logit_bias.to(text_embeds.device)
  394. logits_per_text = logits_per_text * logit_scale.exp() + logit_bias
  395. logits_per_image = logits_per_text.t()
  396. loss = None
  397. if return_loss:
  398. # Adapted from https://github.com/google-research/big_vision/blob/01edb81a4716f93a48be43b3a4af14e29cdb3a7f/big_vision/trainers/proj/image_text/siglip2.py#L287
  399. eye = torch.eye(logits_per_text.size(0), device=logits_per_text.device)
  400. m1_diag1 = -torch.ones_like(logits_per_text) + 2 * eye
  401. loglik = torch.nn.functional.logsigmoid(m1_diag1 * logits_per_text)
  402. nll = -torch.sum(loglik, dim=-1)
  403. loss = nll.mean()
  404. return Siglip2Output(
  405. loss=loss,
  406. logits_per_image=logits_per_image,
  407. logits_per_text=logits_per_text,
  408. text_embeds=text_embeds,
  409. image_embeds=image_embeds,
  410. text_model_output=text_outputs,
  411. vision_model_output=vision_outputs,
  412. )
  413. class Siglip2ForImageClassification(SiglipForImageClassification):
  414. # Update: add `spatial_shapes` and `pixel_attention_mask`
  415. def forward(
  416. self,
  417. pixel_values: Optional[torch.Tensor] = None,
  418. pixel_attention_mask: Optional[torch.Tensor] = None,
  419. spatial_shapes: Optional[torch.LongTensor] = None,
  420. labels: Optional[torch.Tensor] = None,
  421. output_attentions: Optional[bool] = None,
  422. output_hidden_states: Optional[bool] = None,
  423. ) -> ImageClassifierOutput:
  424. r"""
  425. pixel_attention_mask (`torch.Tensor` of shape `(batch_size, image_size, image_size)`, *optional*):
  426. Mask to avoid performing attention on padding pixel indices.
  427. spatial_shapes (`torch.LongTensor` of shape `(batch_size, 2)`):
  428. Tensor containing the spatial dimensions (height, width) of the input images.
  429. labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
  430. Labels for computing the image classification/regression loss. Indices should be in `[0, ...,
  431. config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
  432. `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
  433. Examples:
  434. ```python
  435. >>> from transformers import AutoImageProcessor, Siglip2ForImageClassification
  436. >>> import torch
  437. >>> from PIL import Image
  438. >>> import requests
  439. >>> torch.manual_seed(3) # doctest: +IGNORE_RESULT
  440. >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
  441. >>> image = Image.open(requests.get(url, stream=True).raw)
  442. >>> # note: we are loading a `Siglip2Model` from the hub here,
  443. >>> # so the head will be randomly initialized, hence the predictions will be random if seed is not set above.
  444. >>> image_processor = AutoImageProcessor.from_pretrained("google/siglip2-base-patch16-224")
  445. >>> model = Siglip2ForImageClassification.from_pretrained("google/siglip2-base-patch16-224")
  446. >>> inputs = image_processor(images=image, return_tensors="pt")
  447. >>> outputs = model(**inputs)
  448. >>> logits = outputs.logits
  449. >>> # model predicts one of the two classes
  450. >>> predicted_class_idx = logits.argmax(-1).item()
  451. >>> print("Predicted class:", model.config.id2label[predicted_class_idx])
  452. Predicted class: LABEL_1
  453. ```
  454. """
  455. output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
  456. output_hidden_states = (
  457. output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
  458. )
  459. outputs: BaseModelOutputWithPooling = self.vision_model(
  460. pixel_values,
  461. attention_mask=pixel_attention_mask,
  462. spatial_shapes=spatial_shapes,
  463. output_attentions=output_attentions,
  464. output_hidden_states=output_hidden_states,
  465. )
  466. sequence_output = outputs.last_hidden_state
  467. # average pool the patch tokens
  468. if pixel_attention_mask is not None:
  469. pool_mask = pixel_attention_mask[..., None].to(sequence_output.device)
  470. sequence_output = torch.sum(sequence_output * pool_mask, dim=1) / torch.sum(pool_mask, dim=1)
  471. else:
  472. sequence_output = torch.mean(sequence_output, dim=1)
  473. # apply classifier
  474. logits = self.classifier(sequence_output)
  475. loss = None
  476. if labels is not None:
  477. loss = self.loss_function(labels, logits, self.config)
  478. return ImageClassifierOutput(
  479. loss=loss,
  480. logits=logits,
  481. hidden_states=outputs.hidden_states,
  482. attentions=outputs.attentions,
  483. )
# Public API of this module.
__all__ = [
    # configurations
    "Siglip2Config",
    "Siglip2TextConfig",
    "Siglip2VisionConfig",
    # models
    "Siglip2Model",
    "Siglip2PreTrainedModel",
    "Siglip2TextModel",
    "Siglip2VisionModel",
    "Siglip2ForImageClassification",
]