modeling_siglip.py 41 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862863864865866867868869870871872873874875876877878879880881882883884885886887888889890891892893894895896897898899900901902903904905906907908909910911912913914915916917918919920921922923924925926927928929930931932933934935936937938939940941942943944945946947948949950951952953954955956957958959960961962963964965966967968969970971972973974975976977978979980981982983984985986987988989990991992993994995996997998999100010011002100310041005100610071008100910101011101210131014101510161017101810191020102110221023102410251026102710281029103010311032103310341035103610371038103910401041104210431044104510461047104810491050105110521053105410551056105710581059106010611062106310641065106610671068106910701071107210731074107510761077107810791080108110821083108410851086108710881089109010911092109310941095
  1. # coding=utf-8
  2. # Copyright 2024 Google AI and The HuggingFace Team. All rights reserved.
  3. #
  4. # Licensed under the Apache License, Version 2.0 (the "License");
  5. # you may not use this file except in compliance with the License.
  6. # You may obtain a copy of the License at
  7. #
  8. # http://www.apache.org/licenses/LICENSE-2.0
  9. #
  10. # Unless required by applicable law or agreed to in writing, software
  11. # distributed under the License is distributed on an "AS IS" BASIS,
  12. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. # See the License for the specific language governing permissions and
  14. # limitations under the License.
  15. """PyTorch Siglip model."""
  16. import math
  17. import warnings
  18. from dataclasses import dataclass
  19. from typing import Any, Callable, Optional, Union
  20. import numpy as np
  21. import torch
  22. from torch import nn
  23. from torch.nn.init import _calculate_fan_in_and_fan_out
  24. from ...activations import ACT2FN
  25. from ...modeling_attn_mask_utils import _prepare_4d_attention_mask
  26. from ...modeling_layers import GradientCheckpointingLayer
  27. from ...modeling_outputs import BaseModelOutput, BaseModelOutputWithPooling, ImageClassifierOutput
  28. from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
  29. from ...processing_utils import Unpack
  30. from ...utils import (
  31. ModelOutput,
  32. TransformersKwargs,
  33. auto_docstring,
  34. can_return_tuple,
  35. filter_out_non_signature_kwargs,
  36. torch_int,
  37. )
  38. from ...utils.generic import check_model_inputs
  39. from .configuration_siglip import SiglipConfig, SiglipTextConfig, SiglipVisionConfig
  40. def _trunc_normal_(tensor, mean, std, a, b):
  41. # Cut & paste from PyTorch official master until it's in a few official releases - RW
  42. # Method based on https://people.sc.fsu.edu/~jburkardt/presentations/truncated_normal.pdf
  43. def norm_cdf(x):
  44. # Computes standard normal cumulative distribution function
  45. return (1.0 + math.erf(x / math.sqrt(2.0))) / 2.0
  46. if (mean < a - 2 * std) or (mean > b + 2 * std):
  47. warnings.warn(
  48. "mean is more than 2 std from [a, b] in nn.init.trunc_normal_. "
  49. "The distribution of values may be incorrect.",
  50. stacklevel=2,
  51. )
  52. # Values are generated by using a truncated uniform distribution and
  53. # then using the inverse CDF for the normal distribution.
  54. # Get upper and lower cdf values
  55. l = norm_cdf((a - mean) / std)
  56. u = norm_cdf((b - mean) / std)
  57. # Uniformly fill tensor with values from [l, u], then translate to
  58. # [2l-1, 2u-1].
  59. tensor.uniform_(2 * l - 1, 2 * u - 1)
  60. # Use inverse cdf transform for normal distribution to get truncated
  61. # standard normal
  62. tensor.erfinv_()
  63. # Transform to proper mean, std
  64. tensor.mul_(std * math.sqrt(2.0))
  65. tensor.add_(mean)
  66. # Clamp to ensure it's in the proper range
  67. tensor.clamp_(min=a, max=b)
  68. def trunc_normal_tf_(
  69. tensor: torch.Tensor, mean: float = 0.0, std: float = 1.0, a: float = -2.0, b: float = 2.0
  70. ) -> torch.Tensor:
  71. """Fills the input Tensor with values drawn from a truncated
  72. normal distribution. The values are effectively drawn from the
  73. normal distribution :math:`\\mathcal{N}(\text{mean}, \text{std}^2)`
  74. with values outside :math:`[a, b]` redrawn until they are within
  75. the bounds. The method used for generating the random values works
  76. best when :math:`a \\leq \text{mean} \\leq b`.
  77. NOTE: this 'tf' variant behaves closer to Tensorflow / JAX impl where the
  78. bounds [a, b] are applied when sampling the normal distribution with mean=0, std=1.0
  79. and the result is subsequently scaled and shifted by the mean and std args.
  80. Args:
  81. tensor: an n-dimensional `torch.Tensor`
  82. mean: the mean of the normal distribution
  83. std: the standard deviation of the normal distribution
  84. a: the minimum cutoff value
  85. b: the maximum cutoff value
  86. """
  87. with torch.no_grad():
  88. _trunc_normal_(tensor, 0, 1.0, a, b)
  89. tensor.mul_(std).add_(mean)
  90. def variance_scaling_(tensor, scale=1.0, mode="fan_in", distribution="normal"):
  91. fan_in, fan_out = _calculate_fan_in_and_fan_out(tensor)
  92. if mode == "fan_in":
  93. denom = fan_in
  94. elif mode == "fan_out":
  95. denom = fan_out
  96. elif mode == "fan_avg":
  97. denom = (fan_in + fan_out) / 2
  98. variance = scale / denom
  99. if distribution == "truncated_normal":
  100. # constant is stddev of standard normal truncated to (-2, 2)
  101. trunc_normal_tf_(tensor, std=math.sqrt(variance) / 0.87962566103423978)
  102. elif distribution == "normal":
  103. with torch.no_grad():
  104. tensor.normal_(std=math.sqrt(variance))
  105. elif distribution == "uniform":
  106. bound = math.sqrt(3 * variance)
  107. with torch.no_grad():
  108. tensor.uniform_(-bound, bound)
  109. else:
  110. raise ValueError(f"invalid distribution {distribution}")
  111. def lecun_normal_(tensor):
  112. variance_scaling_(tensor, mode="fan_in", distribution="truncated_normal")
  113. def default_flax_embed_init(tensor):
  114. variance_scaling_(tensor, mode="fan_in", distribution="normal")
@dataclass
@auto_docstring(
    custom_intro="""
    Base class for vision model's outputs that also contains image embeddings of the pooling of the last hidden states.
    """
)
# Copied from transformers.models.clip.modeling_clip.CLIPVisionModelOutput with CLIP->Siglip
class SiglipVisionModelOutput(ModelOutput):
    r"""
    image_embeds (`torch.FloatTensor` of shape `(batch_size, output_dim)` *optional* returned when model is initialized with `with_projection=True`):
        The image embeddings obtained by applying the projection layer to the pooler_output.
    """

    # Only `image_embeds` is described in the docstring above; the remaining fields are presumably
    # documented by `@auto_docstring` (NOTE(review): confirm against the decorator's behaviour).
    # Every field is optional and defaults to `None`.
    image_embeds: Optional[torch.FloatTensor] = None
    last_hidden_state: Optional[torch.FloatTensor] = None
    hidden_states: Optional[tuple[torch.FloatTensor, ...]] = None
    attentions: Optional[tuple[torch.FloatTensor, ...]] = None
@dataclass
@auto_docstring(
    custom_intro="""
    Base class for text model's outputs that also contains a pooling of the last hidden states.
    """
)
# Copied from transformers.models.clip.modeling_clip.CLIPTextModelOutput with CLIP->Siglip
class SiglipTextModelOutput(ModelOutput):
    r"""
    text_embeds (`torch.FloatTensor` of shape `(batch_size, output_dim)` *optional* returned when model is initialized with `with_projection=True`):
        The text embeddings obtained by applying the projection layer to the pooler_output.
    """

    # Only `text_embeds` is described in the docstring above; the remaining fields are presumably
    # documented by `@auto_docstring` (NOTE(review): confirm against the decorator's behaviour).
    # Every field is optional and defaults to `None`.
    text_embeds: Optional[torch.FloatTensor] = None
    last_hidden_state: Optional[torch.FloatTensor] = None
    hidden_states: Optional[tuple[torch.FloatTensor, ...]] = None
    attentions: Optional[tuple[torch.FloatTensor, ...]] = None
@dataclass
@auto_docstring
# Copied from transformers.models.clip.modeling_clip.CLIPOutput with CLIP->Siglip
class SiglipOutput(ModelOutput):
    r"""
    loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `return_loss` is `True`):
        Contrastive loss for image-text similarity.
    logits_per_image (`torch.FloatTensor` of shape `(image_batch_size, text_batch_size)`):
        The scaled dot product scores between `image_embeds` and `text_embeds`. This represents the image-text
        similarity scores.
    logits_per_text (`torch.FloatTensor` of shape `(text_batch_size, image_batch_size)`):
        The scaled dot product scores between `text_embeds` and `image_embeds`. This represents the text-image
        similarity scores.
    text_embeds (`torch.FloatTensor` of shape `(batch_size, output_dim)`):
        The text embeddings obtained by applying the projection layer to the pooled output of [`SiglipTextModel`].
    image_embeds (`torch.FloatTensor` of shape `(batch_size, output_dim)`):
        The image embeddings obtained by applying the projection layer to the pooled output of [`SiglipVisionModel`].
    text_model_output (`BaseModelOutputWithPooling`):
        The output of the [`SiglipTextModel`].
    vision_model_output (`BaseModelOutputWithPooling`):
        The output of the [`SiglipVisionModel`].
    """

    # Every field defaults to `None`, including the two nested model outputs.
    loss: Optional[torch.FloatTensor] = None
    logits_per_image: Optional[torch.FloatTensor] = None
    logits_per_text: Optional[torch.FloatTensor] = None
    text_embeds: Optional[torch.FloatTensor] = None
    image_embeds: Optional[torch.FloatTensor] = None
    text_model_output: Optional[BaseModelOutputWithPooling] = None
    vision_model_output: Optional[BaseModelOutputWithPooling] = None

    def to_tuple(self) -> tuple[Any, ...]:
        # Convert to a plain tuple in key order. The two nested model outputs are themselves converted
        # with their own `to_tuple()`; every other entry is taken as-is.
        return tuple(
            self[k] if k not in ["text_model_output", "vision_model_output"] else getattr(self, k).to_tuple()
            for k in self.keys()
        )
  181. class SiglipVisionEmbeddings(nn.Module):
  182. def __init__(self, config: SiglipVisionConfig):
  183. super().__init__()
  184. self.config = config
  185. self.embed_dim = config.hidden_size
  186. self.image_size = config.image_size
  187. self.patch_size = config.patch_size
  188. self.patch_embedding = nn.Conv2d(
  189. in_channels=config.num_channels,
  190. out_channels=self.embed_dim,
  191. kernel_size=self.patch_size,
  192. stride=self.patch_size,
  193. padding="valid",
  194. )
  195. self.num_patches = (self.image_size // self.patch_size) ** 2
  196. self.num_positions = self.num_patches
  197. self.position_embedding = nn.Embedding(self.num_positions, self.embed_dim)
  198. self.register_buffer("position_ids", torch.arange(self.num_positions).expand((1, -1)), persistent=False)
  199. def interpolate_pos_encoding(self, embeddings: torch.Tensor, height: int, width: int) -> torch.Tensor:
  200. """
  201. This method allows to interpolate the pre-trained position encodings, to be able to use the model on higher resolution
  202. images. This method is also adapted to support torch.jit tracing and no class embeddings.
  203. Adapted from:
  204. - https://github.com/facebookresearch/dino/blob/de9ee3df6cf39fac952ab558447af1fa1365362a/vision_transformer.py#L174-L194, and
  205. - https://github.com/facebookresearch/dinov2/blob/e1277af2ba9496fbadf7aec6eba56e8d882d1e35/dinov2/models/vision_transformer.py#L179-L211
  206. """
  207. num_patches = embeddings.shape[1]
  208. num_positions = self.position_embedding.weight.shape[0]
  209. # always interpolate when tracing to ensure the exported model works for dynamic input shapes
  210. if not torch.jit.is_tracing() and num_patches == num_positions and height == width:
  211. return self.position_embedding(self.position_ids)
  212. patch_pos_embed = self.position_embedding.weight.unsqueeze(0)
  213. dim = embeddings.shape[-1]
  214. new_height = height // self.patch_size
  215. new_width = width // self.patch_size
  216. sqrt_num_positions = torch_int(num_positions**0.5)
  217. patch_pos_embed = patch_pos_embed.reshape(1, sqrt_num_positions, sqrt_num_positions, dim)
  218. patch_pos_embed = patch_pos_embed.permute(0, 3, 1, 2)
  219. patch_pos_embed = nn.functional.interpolate(
  220. patch_pos_embed,
  221. size=(new_height, new_width),
  222. mode="bicubic",
  223. align_corners=False,
  224. )
  225. patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim)
  226. return patch_pos_embed
  227. def forward(self, pixel_values: torch.FloatTensor, interpolate_pos_encoding=False) -> torch.Tensor:
  228. _, _, height, width = pixel_values.shape
  229. target_dtype = self.patch_embedding.weight.dtype
  230. patch_embeds = self.patch_embedding(pixel_values.to(dtype=target_dtype)) # shape = [*, width, grid, grid]
  231. embeddings = patch_embeds.flatten(2).transpose(1, 2)
  232. if interpolate_pos_encoding:
  233. embeddings = embeddings + self.interpolate_pos_encoding(embeddings, height, width)
  234. else:
  235. embeddings = embeddings + self.position_embedding(self.position_ids)
  236. return embeddings
  237. # Copied from transformers.models.clip.modeling_clip.CLIPTextEmbeddings with CLIP->Siglip
  238. class SiglipTextEmbeddings(nn.Module):
  239. def __init__(self, config: SiglipTextConfig):
  240. super().__init__()
  241. embed_dim = config.hidden_size
  242. self.token_embedding = nn.Embedding(config.vocab_size, embed_dim)
  243. self.position_embedding = nn.Embedding(config.max_position_embeddings, embed_dim)
  244. # position_ids (1, len position emb) is contiguous in memory and exported when serialized
  245. self.register_buffer(
  246. "position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)), persistent=False
  247. )
  248. def forward(
  249. self,
  250. input_ids: Optional[torch.LongTensor] = None,
  251. position_ids: Optional[torch.LongTensor] = None,
  252. inputs_embeds: Optional[torch.FloatTensor] = None,
  253. ) -> torch.Tensor:
  254. seq_length = input_ids.shape[-1] if input_ids is not None else inputs_embeds.shape[-2]
  255. max_position_embedding = self.position_embedding.weight.shape[0]
  256. if seq_length > max_position_embedding:
  257. raise ValueError(
  258. f"Sequence length must be less than max_position_embeddings (got `sequence length`: "
  259. f"{seq_length} and max_position_embeddings: {max_position_embedding}"
  260. )
  261. if position_ids is None:
  262. position_ids = self.position_ids[:, :seq_length]
  263. if inputs_embeds is None:
  264. inputs_embeds = self.token_embedding(input_ids)
  265. position_embeddings = self.position_embedding(position_ids)
  266. embeddings = inputs_embeds + position_embeddings
  267. return embeddings
  268. def eager_attention_forward(
  269. module: nn.Module,
  270. query: torch.Tensor,
  271. key: torch.Tensor,
  272. value: torch.Tensor,
  273. attention_mask: Optional[torch.Tensor],
  274. scaling: float,
  275. dropout: float = 0.0,
  276. **kwargs,
  277. ):
  278. attn_weights = torch.matmul(query, key.transpose(-1, -2)) * scaling
  279. if attention_mask is not None:
  280. attn_weights = attn_weights + attention_mask
  281. attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype)
  282. attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training)
  283. attn_output = torch.matmul(attn_weights, value)
  284. attn_output = attn_output.transpose(1, 2).contiguous()
  285. return attn_output, attn_weights
  286. class SiglipAttention(nn.Module):
  287. """Multi-headed attention from 'Attention Is All You Need' paper"""
  288. def __init__(self, config):
  289. super().__init__()
  290. self.config = config
  291. self.embed_dim = config.hidden_size
  292. self.num_heads = config.num_attention_heads
  293. self.head_dim = self.embed_dim // self.num_heads
  294. if self.head_dim * self.num_heads != self.embed_dim:
  295. raise ValueError(
  296. f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim} and `num_heads`:"
  297. f" {self.num_heads})."
  298. )
  299. self.scale = self.head_dim**-0.5
  300. self.dropout = config.attention_dropout
  301. self.is_causal = False
  302. self.k_proj = nn.Linear(self.embed_dim, self.embed_dim)
  303. self.v_proj = nn.Linear(self.embed_dim, self.embed_dim)
  304. self.q_proj = nn.Linear(self.embed_dim, self.embed_dim)
  305. self.out_proj = nn.Linear(self.embed_dim, self.embed_dim)
  306. def forward(
  307. self,
  308. hidden_states: torch.Tensor,
  309. attention_mask: Optional[torch.Tensor] = None,
  310. **kwargs,
  311. ) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
  312. """Input shape: Batch x Time x Channel"""
  313. batch_size, seq_length, embed_dim = hidden_states.shape
  314. queries = self.q_proj(hidden_states)
  315. keys = self.k_proj(hidden_states)
  316. values = self.v_proj(hidden_states)
  317. queries = queries.view(batch_size, seq_length, self.num_heads, self.head_dim).transpose(1, 2)
  318. keys = keys.view(batch_size, seq_length, self.num_heads, self.head_dim).transpose(1, 2)
  319. values = values.view(batch_size, seq_length, self.num_heads, self.head_dim).transpose(1, 2)
  320. attention_interface: Callable = eager_attention_forward
  321. if self.config._attn_implementation != "eager":
  322. attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]
  323. attn_output, attn_weights = attention_interface(
  324. self,
  325. queries,
  326. keys,
  327. values,
  328. attention_mask,
  329. is_causal=self.is_causal,
  330. scaling=self.scale,
  331. dropout=0.0 if not self.training else self.dropout,
  332. )
  333. attn_output = attn_output.reshape(batch_size, seq_length, embed_dim).contiguous()
  334. attn_output = self.out_proj(attn_output)
  335. return attn_output, attn_weights
  336. # Copied from transformers.models.clip.modeling_clip.CLIPMLP with CLIP->Siglip
  337. class SiglipMLP(nn.Module):
  338. def __init__(self, config):
  339. super().__init__()
  340. self.config = config
  341. self.activation_fn = ACT2FN[config.hidden_act]
  342. self.fc1 = nn.Linear(config.hidden_size, config.intermediate_size)
  343. self.fc2 = nn.Linear(config.intermediate_size, config.hidden_size)
  344. def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
  345. hidden_states = self.fc1(hidden_states)
  346. hidden_states = self.activation_fn(hidden_states)
  347. hidden_states = self.fc2(hidden_states)
  348. return hidden_states
  349. class SiglipEncoderLayer(GradientCheckpointingLayer):
  350. def __init__(self, config: Union[SiglipVisionConfig, SiglipTextConfig]):
  351. super().__init__()
  352. self.embed_dim = config.hidden_size
  353. self.layer_norm1 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps)
  354. self.self_attn = SiglipAttention(config)
  355. self.layer_norm2 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps)
  356. self.mlp = SiglipMLP(config)
  357. @auto_docstring
  358. def forward(
  359. self,
  360. hidden_states: torch.Tensor,
  361. attention_mask: torch.Tensor,
  362. **kwargs: Unpack[TransformersKwargs],
  363. ) -> torch.FloatTensor:
  364. residual = hidden_states
  365. hidden_states = self.layer_norm1(hidden_states)
  366. hidden_states, _ = self.self_attn(
  367. hidden_states=hidden_states,
  368. attention_mask=attention_mask,
  369. **kwargs,
  370. )
  371. hidden_states = residual + hidden_states
  372. residual = hidden_states
  373. hidden_states = self.layer_norm2(hidden_states)
  374. hidden_states = self.mlp(hidden_states)
  375. hidden_states = residual + hidden_states
  376. return hidden_states
@auto_docstring
class SiglipPreTrainedModel(PreTrainedModel):
    # Common base class of the SigLIP models in this file. It declares class-level capability flags and
    # implements the per-module weight initialisation in `_init_weights`.
    config: SiglipConfig
    base_model_prefix = "siglip"
    supports_gradient_checkpointing = True

    _no_split_modules = [
        "SiglipTextEmbeddings",
        "SiglipVisionEmbeddings",
        "SiglipEncoderLayer",
        "SiglipMultiheadAttentionPoolingHead",
    ]
    _supports_flash_attn = True
    _supports_sdpa = True
    _supports_flex_attn = True
    _supports_attention_backend = True

    # Maps an output name to the module class that produces it.
    # NOTE(review): presumably consumed by `check_model_inputs` to collect these outputs — confirm.
    _can_record_outputs = {
        "hidden_states": SiglipEncoderLayer,
        "attentions": SiglipAttention,
    }

    def _init_weights(self, module):
        """Initialize the weights of a single `module`, choosing the scheme from the module's type.

        The first matching `isinstance` branch wins; modules matching no branch are left untouched.
        """
        if isinstance(module, SiglipVisionEmbeddings):
            # The hidden size lives on `vision_config` for the composite config, on the config itself otherwise.
            width = (
                self.config.vision_config.hidden_size
                if isinstance(self.config, SiglipConfig)
                else self.config.hidden_size
            )
            # Only the position embedding is handled in this branch.
            nn.init.normal_(module.position_embedding.weight, std=1 / np.sqrt(width))
        elif isinstance(module, nn.Embedding):
            default_flax_embed_init(module.weight)
        elif isinstance(module, SiglipAttention):
            # Xavier-uniform weights and zero biases for all four projections.
            nn.init.xavier_uniform_(module.q_proj.weight)
            nn.init.xavier_uniform_(module.k_proj.weight)
            nn.init.xavier_uniform_(module.v_proj.weight)
            nn.init.xavier_uniform_(module.out_proj.weight)
            nn.init.zeros_(module.q_proj.bias)
            nn.init.zeros_(module.k_proj.bias)
            nn.init.zeros_(module.v_proj.bias)
            nn.init.zeros_(module.out_proj.bias)
        elif isinstance(module, SiglipMLP):
            # Xavier-uniform weights; biases get a very small normal noise (std=1e-6) rather than exact zeros.
            nn.init.xavier_uniform_(module.fc1.weight)
            nn.init.xavier_uniform_(module.fc2.weight)
            nn.init.normal_(module.fc1.bias, std=1e-6)
            nn.init.normal_(module.fc2.bias, std=1e-6)
        elif isinstance(module, SiglipMultiheadAttentionPoolingHead):
            # Probe and the fused input projection of the `nn.MultiheadAttention` submodule.
            nn.init.xavier_uniform_(module.probe.data)
            nn.init.xavier_uniform_(module.attention.in_proj_weight.data)
            nn.init.zeros_(module.attention.in_proj_bias.data)
        elif isinstance(module, SiglipModel):
            # log(1.0) == 0, so `logit_scale` is filled with 0; `logit_bias` is zeroed as well.
            logit_scale_init = torch.log(torch.tensor(1.0))
            module.logit_scale.data.fill_(logit_scale_init)
            module.logit_bias.data.zero_()
        elif isinstance(module, SiglipForImageClassification):
            # Only the classifier weight is initialised in this branch.
            nn.init.normal_(
                module.classifier.weight,
                std=self.config.vision_config.hidden_size**-0.5 * self.config.initializer_factor,
            )
        elif isinstance(module, (nn.Linear, nn.Conv2d)):
            # Remaining linear / conv layers: LeCun-normal weights, zero bias when a bias exists.
            lecun_normal_(module.weight)
            if module.bias is not None:
                nn.init.zeros_(module.bias)
        elif isinstance(module, nn.LayerNorm):
            # Identity-like LayerNorm: bias 0, weight 1.
            module.bias.data.zero_()
            module.weight.data.fill_(1.0)
  441. # Copied from transformers.models.altclip.modeling_altclip.AltCLIPEncoder with AltCLIP->Siglip
  442. class SiglipEncoder(nn.Module):
  443. """
  444. Transformer encoder consisting of `config.num_hidden_layers` self attention layers. Each layer is a
  445. [`SiglipEncoderLayer`].
  446. Args:
  447. config: SiglipConfig
  448. """
  449. def __init__(self, config: SiglipConfig):
  450. super().__init__()
  451. self.config = config
  452. self.layers = nn.ModuleList([SiglipEncoderLayer(config) for _ in range(config.num_hidden_layers)])
  453. self.gradient_checkpointing = False
  454. # Ignore copy
  455. @auto_docstring
  456. def forward(
  457. self,
  458. inputs_embeds,
  459. attention_mask: Optional[torch.Tensor] = None,
  460. **kwargs: Unpack[TransformersKwargs],
  461. ) -> BaseModelOutput:
  462. hidden_states = inputs_embeds
  463. for encoder_layer in self.layers:
  464. hidden_states = encoder_layer(
  465. hidden_states,
  466. attention_mask,
  467. **kwargs,
  468. )
  469. return BaseModelOutput(last_hidden_state=hidden_states)
  470. class SiglipTextTransformer(nn.Module):
  471. def __init__(self, config: SiglipTextConfig):
  472. super().__init__()
  473. self.config = config
  474. embed_dim = config.hidden_size
  475. self.embeddings = SiglipTextEmbeddings(config)
  476. self.encoder = SiglipEncoder(config)
  477. self.final_layer_norm = nn.LayerNorm(embed_dim, eps=config.layer_norm_eps)
  478. self.head = nn.Linear(embed_dim, config.projection_size)
  479. @can_return_tuple
  480. @auto_docstring
  481. def forward(
  482. self,
  483. input_ids: Optional[torch.Tensor] = None,
  484. attention_mask: Optional[torch.Tensor] = None,
  485. position_ids: Optional[torch.Tensor] = None,
  486. **kwargs: Unpack[TransformersKwargs],
  487. ) -> BaseModelOutputWithPooling:
  488. if input_ids is None:
  489. raise ValueError("You have to specify input_ids")
  490. input_shape = input_ids.size()
  491. input_ids = input_ids.view(-1, input_shape[-1])
  492. hidden_states = self.embeddings(input_ids=input_ids, position_ids=position_ids)
  493. # note: SigLIP's text model does not use a causal mask, unlike the original CLIP model.
  494. # expand attention_mask
  495. uses_flash_attention = "flash" in self.config._attn_implementation
  496. if uses_flash_attention:
  497. attention_mask = None
  498. elif attention_mask is not None and not uses_flash_attention:
  499. # [batch_size, seq_len] -> [batch_size, 1, tgt_seq_len, src_seq_len]
  500. attention_mask = _prepare_4d_attention_mask(attention_mask, hidden_states.dtype)
  501. encoder_outputs: BaseModelOutput = self.encoder(
  502. inputs_embeds=hidden_states,
  503. attention_mask=attention_mask,
  504. **kwargs,
  505. )
  506. last_hidden_state = encoder_outputs.last_hidden_state
  507. last_hidden_state = self.final_layer_norm(last_hidden_state)
  508. # The model uses the last token's hidden state, which may be padding.
  509. pooled_output = last_hidden_state[:, -1, :]
  510. pooled_output = self.head(pooled_output)
  511. return BaseModelOutputWithPooling(
  512. last_hidden_state=last_hidden_state,
  513. pooler_output=pooled_output,
  514. )
@auto_docstring(
    custom_intro="""
    The text model from SigLIP without any head or projection on top.
    """
)
class SiglipTextModel(SiglipPreTrainedModel):
    # Thin `SiglipPreTrainedModel` wrapper around `SiglipTextTransformer`; `forward` delegates to it.
    config: SiglipTextConfig

    def __init__(self, config: SiglipTextConfig):
        super().__init__(config)
        self.text_model = SiglipTextTransformer(config)
        # Initialize weights and apply final processing
        self.post_init()

    def get_input_embeddings(self) -> nn.Module:
        # The token-embedding table of the wrapped text transformer.
        return self.text_model.embeddings.token_embedding

    def set_input_embeddings(self, value):
        # Replace the token-embedding table of the wrapped text transformer.
        self.text_model.embeddings.token_embedding = value

    @check_model_inputs(tie_last_hidden_states=False)
    @auto_docstring
    def forward(
        self,
        input_ids: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.Tensor] = None,
        **kwargs: Unpack[TransformersKwargs],
    ) -> BaseModelOutputWithPooling:
        r"""
        Examples:

        ```python
        >>> from transformers import AutoTokenizer, SiglipTextModel

        >>> model = SiglipTextModel.from_pretrained("google/siglip-base-patch16-224")
        >>> tokenizer = AutoTokenizer.from_pretrained("google/siglip-base-patch16-224")

        >>> # important: make sure to set padding="max_length" as that's how the model was trained
        >>> inputs = tokenizer(["a photo of a cat", "a photo of a dog"], padding="max_length", return_tensors="pt")

        >>> outputs = model(**inputs)
        >>> last_hidden_state = outputs.last_hidden_state
        >>> pooled_output = outputs.pooler_output  # pooled (last token) states, passed through the head
        ```"""

        # All arguments are forwarded unchanged to the wrapped transformer.
        return self.text_model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            **kwargs,
        )
  558. class SiglipVisionTransformer(nn.Module):
  559. def __init__(self, config: SiglipVisionConfig):
  560. super().__init__()
  561. self.config = config
  562. embed_dim = config.hidden_size
  563. self.embeddings = SiglipVisionEmbeddings(config)
  564. self.encoder = SiglipEncoder(config)
  565. self.post_layernorm = nn.LayerNorm(embed_dim, eps=config.layer_norm_eps)
  566. self.use_head = True if not hasattr(config, "vision_use_head") else config.vision_use_head
  567. if self.use_head:
  568. self.head = SiglipMultiheadAttentionPoolingHead(config)
  569. @auto_docstring
  570. def forward(
  571. self,
  572. pixel_values,
  573. interpolate_pos_encoding: Optional[bool] = False,
  574. **kwargs: Unpack[TransformersKwargs],
  575. ) -> BaseModelOutputWithPooling:
  576. hidden_states = self.embeddings(pixel_values, interpolate_pos_encoding=interpolate_pos_encoding)
  577. encoder_outputs: BaseModelOutput = self.encoder(
  578. inputs_embeds=hidden_states,
  579. **kwargs,
  580. )
  581. last_hidden_state = encoder_outputs.last_hidden_state
  582. last_hidden_state = self.post_layernorm(last_hidden_state)
  583. pooler_output = self.head(last_hidden_state) if self.use_head else None
  584. return BaseModelOutputWithPooling(
  585. last_hidden_state=last_hidden_state,
  586. pooler_output=pooler_output,
  587. )
  588. class SiglipMultiheadAttentionPoolingHead(nn.Module):
  589. """Multihead Attention Pooling."""
  590. def __init__(self, config: SiglipVisionConfig):
  591. super().__init__()
  592. self.probe = nn.Parameter(torch.randn(1, 1, config.hidden_size))
  593. self.attention = torch.nn.MultiheadAttention(config.hidden_size, config.num_attention_heads, batch_first=True)
  594. self.layernorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
  595. self.mlp = SiglipMLP(config)
  596. def forward(self, hidden_state):
  597. batch_size = hidden_state.shape[0]
  598. probe = self.probe.repeat(batch_size, 1, 1)
  599. hidden_state = self.attention(probe, hidden_state, hidden_state)[0]
  600. residual = hidden_state
  601. hidden_state = self.layernorm(hidden_state)
  602. hidden_state = residual + self.mlp(hidden_state)
  603. return hidden_state[:, 0]
@auto_docstring(
    custom_intro="""
    The vision model from SigLIP without any head or projection on top.
    """
)
class SiglipVisionModel(SiglipPreTrainedModel):
    # Thin `SiglipPreTrainedModel` wrapper around `SiglipVisionTransformer`; `forward` delegates to it.
    config: SiglipVisionConfig
    main_input_name = "pixel_values"

    def __init__(self, config: SiglipVisionConfig):
        super().__init__(config)

        self.vision_model = SiglipVisionTransformer(config)

        # Initialize weights and apply final processing
        self.post_init()

    def get_input_embeddings(self) -> nn.Module:
        # For the vision model the "input embedding" is the patch-embedding convolution.
        return self.vision_model.embeddings.patch_embedding

    @check_model_inputs(tie_last_hidden_states=False)
    @auto_docstring
    def forward(
        self,
        pixel_values,
        interpolate_pos_encoding: bool = False,
        **kwargs: Unpack[TransformersKwargs],
    ) -> BaseModelOutputWithPooling:
        r"""
        Examples:

        ```python
        >>> from PIL import Image
        >>> import requests
        >>> from transformers import AutoProcessor, SiglipVisionModel

        >>> model = SiglipVisionModel.from_pretrained("google/siglip-base-patch16-224")
        >>> processor = AutoProcessor.from_pretrained("google/siglip-base-patch16-224")

        >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
        >>> image = Image.open(requests.get(url, stream=True).raw)

        >>> inputs = processor(images=image, return_tensors="pt")

        >>> outputs = model(**inputs)
        >>> last_hidden_state = outputs.last_hidden_state
        >>> pooled_output = outputs.pooler_output  # pooled features
        ```"""

        # All arguments are forwarded unchanged to the wrapped transformer.
        return self.vision_model(
            pixel_values=pixel_values,
            interpolate_pos_encoding=interpolate_pos_encoding,
            **kwargs,
        )
  647. @auto_docstring
  648. class SiglipModel(SiglipPreTrainedModel):
  649. config: SiglipConfig
  650. def __init__(self, config: SiglipConfig):
  651. super().__init__(config)
  652. if not isinstance(config.text_config, SiglipTextConfig):
  653. raise TypeError(
  654. "config.text_config is expected to be of type SiglipTextConfig but is of type"
  655. f" {type(config.text_config)}."
  656. )
  657. if not isinstance(config.vision_config, SiglipVisionConfig):
  658. raise TypeError(
  659. "config.vision_config is expected to be of type SiglipVisionConfig but is of type"
  660. f" {type(config.vision_config)}."
  661. )
  662. text_config = config.text_config
  663. vision_config = config.vision_config
  664. # First, initialize the text and vision models with proper attention implementation
  665. text_model = SiglipTextModel._from_config(text_config)
  666. vision_model = SiglipVisionModel._from_config(vision_config)
  667. # Second, get the text and vision submodules (for backward compatibility)
  668. self.text_model = text_model.text_model
  669. self.vision_model = vision_model.vision_model
  670. self.logit_scale = nn.Parameter(torch.randn(1))
  671. self.logit_bias = nn.Parameter(torch.randn(1))
  672. # Initialize weights and apply final processing
  673. self.post_init()
  674. @filter_out_non_signature_kwargs()
  675. @auto_docstring
  676. def get_text_features(
  677. self,
  678. input_ids: torch.Tensor,
  679. attention_mask: Optional[torch.Tensor] = None,
  680. position_ids: Optional[torch.Tensor] = None,
  681. ) -> torch.FloatTensor:
  682. r"""
  683. Returns:
  684. text_features (`torch.FloatTensor` of shape `(batch_size, output_dim`): The text embeddings obtained by
  685. applying the projection layer to the pooled output of [`SiglipTextModel`].
  686. Examples:
  687. ```python
  688. >>> from transformers import AutoTokenizer, AutoModel
  689. >>> import torch
  690. >>> model = AutoModel.from_pretrained("google/siglip-base-patch16-224")
  691. >>> tokenizer = AutoTokenizer.from_pretrained("google/siglip-base-patch16-224")
  692. >>> # important: make sure to set padding="max_length" as that's how the model was trained
  693. >>> inputs = tokenizer(["a photo of a cat", "a photo of a dog"], padding="max_length", return_tensors="pt")
  694. >>> with torch.no_grad():
  695. ... text_features = model.get_text_features(**inputs)
  696. ```"""
  697. text_outputs: BaseModelOutputWithPooling = self.text_model(
  698. input_ids=input_ids,
  699. attention_mask=attention_mask,
  700. position_ids=position_ids,
  701. )
  702. pooled_output = text_outputs.pooler_output
  703. return pooled_output
  704. @filter_out_non_signature_kwargs()
  705. @auto_docstring
  706. def get_image_features(
  707. self,
  708. pixel_values: torch.FloatTensor,
  709. interpolate_pos_encoding: bool = False,
  710. **kwargs: Unpack[TransformersKwargs],
  711. ) -> torch.FloatTensor:
  712. r"""
  713. Returns:
  714. image_features (`torch.FloatTensor` of shape `(batch_size, output_dim`): The image embeddings obtained by
  715. applying the projection layer to the pooled output of [`SiglipVisionModel`].
  716. Examples:
  717. ```python
  718. >>> import torch
  719. >>> from transformers import AutoProcessor, AutoModel
  720. >>> from transformers.image_utils import load_image
  721. >>> model = AutoModel.from_pretrained("google/siglip-base-patch16-224")
  722. >>> processor = AutoProcessor.from_pretrained("google/siglip-base-patch16-224")
  723. >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
  724. >>> image = load_image(url)
  725. >>> inputs = processor(images=image, return_tensors="pt")
  726. >>> with torch.no_grad():
  727. ... image_features = model.get_image_features(**inputs)
  728. ```"""
  729. vision_outputs: BaseModelOutputWithPooling = self.vision_model(
  730. pixel_values=pixel_values,
  731. interpolate_pos_encoding=interpolate_pos_encoding,
  732. **kwargs,
  733. )
  734. pooled_output = vision_outputs.pooler_output
  735. return pooled_output
  736. # NOTE: SiglipModel uses Pretrained backbones, so we don't need to add `check_model_inputs` here
  737. @can_return_tuple
  738. @auto_docstring
  739. def forward(
  740. self,
  741. input_ids: Optional[torch.LongTensor] = None,
  742. pixel_values: Optional[torch.FloatTensor] = None,
  743. attention_mask: Optional[torch.Tensor] = None,
  744. position_ids: Optional[torch.LongTensor] = None,
  745. return_loss: Optional[bool] = None,
  746. interpolate_pos_encoding: bool = False,
  747. **kwargs: Unpack[TransformersKwargs],
  748. ) -> SiglipOutput:
  749. r"""
  750. return_loss (`bool`, *optional*):
  751. Whether or not to return the contrastive loss.
  752. Examples:
  753. ```python
  754. >>> from PIL import Image
  755. >>> import requests
  756. >>> from transformers import AutoProcessor, AutoModel
  757. >>> import torch
  758. >>> model = AutoModel.from_pretrained("google/siglip-base-patch16-224")
  759. >>> processor = AutoProcessor.from_pretrained("google/siglip-base-patch16-224")
  760. >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
  761. >>> image = Image.open(requests.get(url, stream=True).raw)
  762. >>> texts = ["a photo of 2 cats", "a photo of 2 dogs"]
  763. >>> # important: we pass `padding=max_length` since the model was trained with this
  764. >>> inputs = processor(text=texts, images=image, padding="max_length", return_tensors="pt")
  765. >>> with torch.no_grad():
  766. ... outputs = model(**inputs)
  767. >>> logits_per_image = outputs.logits_per_image
  768. >>> probs = torch.sigmoid(logits_per_image) # these are the probabilities
  769. >>> print(f"{probs[0][0]:.1%} that image 0 is '{texts[0]}'")
  770. 31.9% that image 0 is 'a photo of 2 cats'
  771. ```"""
  772. vision_outputs: BaseModelOutputWithPooling = self.vision_model(
  773. pixel_values=pixel_values,
  774. interpolate_pos_encoding=interpolate_pos_encoding,
  775. **kwargs,
  776. )
  777. text_outputs: BaseModelOutputWithPooling = self.text_model(
  778. input_ids=input_ids,
  779. attention_mask=attention_mask,
  780. position_ids=position_ids,
  781. **kwargs,
  782. )
  783. image_embeds = vision_outputs.pooler_output
  784. text_embeds = text_outputs.pooler_output
  785. # normalized features
  786. image_embeds = image_embeds / image_embeds.norm(p=2, dim=-1, keepdim=True)
  787. text_embeds = text_embeds / text_embeds.norm(p=2, dim=-1, keepdim=True)
  788. # cosine similarity as logits
  789. logits_per_text = torch.matmul(text_embeds, image_embeds.t().to(text_embeds.device))
  790. logit_scale, logit_bias = self.logit_scale.to(text_embeds.device), self.logit_bias.to(text_embeds.device)
  791. logits_per_text = logits_per_text * logit_scale.exp() + logit_bias
  792. logits_per_image = logits_per_text.t()
  793. loss = None
  794. if return_loss:
  795. # Adapted from https://github.com/google-research/big_vision/blob/01edb81a4716f93a48be43b3a4af14e29cdb3a7f/big_vision/trainers/proj/image_text/siglip.py#L287
  796. eye = torch.eye(logits_per_text.size(0), device=logits_per_text.device)
  797. m1_diag1 = -torch.ones_like(logits_per_text) + 2 * eye
  798. loglik = torch.nn.functional.logsigmoid(m1_diag1 * logits_per_text)
  799. nll = -torch.sum(loglik, dim=-1)
  800. loss = nll.mean()
  801. return SiglipOutput(
  802. loss=loss,
  803. logits_per_image=logits_per_image,
  804. logits_per_text=logits_per_text,
  805. text_embeds=text_embeds,
  806. image_embeds=image_embeds,
  807. text_model_output=text_outputs,
  808. vision_model_output=vision_outputs,
  809. )
  810. @auto_docstring(
  811. custom_intro="""
  812. SigLIP vision encoder with an image classification head on top (a linear layer on top of the pooled final hidden states of
  813. the patch tokens) e.g. for ImageNet.
  814. """
  815. )
  816. class SiglipForImageClassification(SiglipPreTrainedModel):
  817. main_input_name = "pixel_values"
  818. def __init__(self, config: SiglipConfig) -> None:
  819. super().__init__(config)
  820. self.num_labels = config.num_labels
  821. # Create the vision model with proper attention
  822. # and take only vision_model submodule (for backward compatibility)
  823. vision_model = SiglipVisionModel._from_config(config.vision_config)
  824. self.vision_model = vision_model.vision_model
  825. # Classifier head
  826. self.classifier = (
  827. nn.Linear(config.vision_config.hidden_size, config.num_labels) if config.num_labels > 0 else nn.Identity()
  828. )
  829. # Initialize weights and apply final processing
  830. self.post_init()
  831. @check_model_inputs()
  832. @auto_docstring
  833. def forward(
  834. self,
  835. pixel_values: Optional[torch.Tensor] = None,
  836. labels: Optional[torch.Tensor] = None,
  837. interpolate_pos_encoding: bool = False,
  838. **kwargs: Unpack[TransformersKwargs],
  839. ) -> ImageClassifierOutput:
  840. r"""
  841. labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
  842. Labels for computing the image classification/regression loss. Indices should be in `[0, ...,
  843. config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
  844. `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
  845. Examples:
  846. ```python
  847. >>> from transformers import AutoImageProcessor, SiglipForImageClassification
  848. >>> import torch
  849. >>> from PIL import Image
  850. >>> import requests
  851. >>> torch.manual_seed(3) # doctest: +IGNORE_RESULT
  852. >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
  853. >>> image = Image.open(requests.get(url, stream=True).raw)
  854. >>> # note: we are loading a `SiglipModel` from the hub here,
  855. >>> # so the head will be randomly initialized, hence the predictions will be random if seed is not set above.
  856. >>> image_processor = AutoImageProcessor.from_pretrained("google/siglip-base-patch16-224")
  857. >>> model = SiglipForImageClassification.from_pretrained("google/siglip-base-patch16-224")
  858. >>> inputs = image_processor(images=image, return_tensors="pt")
  859. >>> outputs = model(**inputs)
  860. >>> logits = outputs.logits
  861. >>> # model predicts one of the two classes
  862. >>> predicted_class_idx = logits.argmax(-1).item()
  863. >>> print("Predicted class:", model.config.id2label[predicted_class_idx])
  864. Predicted class: LABEL_1
  865. ```"""
  866. outputs: BaseModelOutputWithPooling = self.vision_model(
  867. pixel_values,
  868. interpolate_pos_encoding=interpolate_pos_encoding,
  869. **kwargs,
  870. )
  871. sequence_output = outputs.last_hidden_state
  872. # average pool the patch tokens
  873. sequence_output = torch.mean(sequence_output, dim=1)
  874. # apply classifier
  875. logits = self.classifier(sequence_output)
  876. loss = None
  877. if labels is not None:
  878. loss = self.loss_function(labels, logits, self.config)
  879. return ImageClassifierOutput(
  880. loss=loss,
  881. logits=logits,
  882. )
# Names exported by `from <this module> import *`; this is the module's
# declared public API.
__all__ = [
    "SiglipModel",
    "SiglipPreTrainedModel",
    "SiglipTextModel",
    "SiglipVisionModel",
    "SiglipForImageClassification",
]