modeling_aimv2.py 28 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744
  1. # 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
  2. # This file was automatically generated from src/transformers/models/aimv2/modular_aimv2.py.
  3. # Do NOT edit this file manually as any edits will be overwritten by the generation of
  4. # the file from the modular. If any change should be done, please apply the change to the
  5. # modular_aimv2.py file directly. One of our CI enforces this.
  6. # 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
  7. # coding=utf-8
  8. # Copyright 2025 Apple Inc. and The HuggingFace Team. All rights reserved.
  9. #
  10. # Licensed under the Apache License, Version 2.0 (the "License");
  11. # you may not use this file except in compliance with the License.
  12. # You may obtain a copy of the License at
  13. #
  14. # http://www.apache.org/licenses/LICENSE-2.0
  15. #
  16. # Unless required by applicable law or agreed to in writing, software
  17. # distributed under the License is distributed on an "AS IS" BASIS,
  18. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  19. # See the License for the specific language governing permissions and
  20. # limitations under the License.
  21. import math
  22. from dataclasses import dataclass
  23. from typing import Any, Callable, Optional
  24. import torch
  25. import torch.nn.functional as F
  26. from torch import nn
  27. from ...activations import ACT2FN
  28. from ...integrations import use_kernel_forward_from_hub
  29. from ...masking_utils import create_causal_mask
  30. from ...modeling_layers import GradientCheckpointingLayer
  31. from ...modeling_outputs import BaseModelOutput, BaseModelOutputWithPooling
  32. from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
  33. from ...processing_utils import Unpack
  34. from ...utils import ModelOutput, TransformersKwargs, auto_docstring, can_return_tuple, filter_out_non_signature_kwargs
  35. from ...utils.deprecation import deprecate_kwarg
  36. from ...utils.generic import check_model_inputs
  37. from .configuration_aimv2 import Aimv2Config, Aimv2TextConfig, Aimv2VisionConfig
  @dataclass
  @auto_docstring
  class Aimv2Output(ModelOutput):
      r"""
      loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `return_loss` is `True`):
          Contrastive loss for image-text similarity.
      logits_per_image (`torch.FloatTensor` of shape `(image_batch_size, text_batch_size)`):
          The scaled dot product scores between `image_embeds` and `text_embeds`. This represents the image-text
          similarity scores.
      logits_per_text (`torch.FloatTensor` of shape `(text_batch_size, image_batch_size)`):
          The scaled dot product scores between `text_embeds` and `image_embeds`. This represents the text-image
          similarity scores.
      text_embeds (`torch.FloatTensor` of shape `(batch_size, output_dim)`):
          The text embeddings obtained by applying the projection layer to the pooled output of [`Aimv2TextModel`].
      image_embeds (`torch.FloatTensor` of shape `(batch_size, output_dim)`):
          The image embeddings obtained by applying the projection layer to the pooled output of [`Aimv2VisionModel`].
      text_model_output (`BaseModelOutputWithPooling`):
          The output of the [`Aimv2TextModel`].
      vision_model_output (`BaseModelOutputWithPooling`):
          The output of the [`Aimv2VisionModel`].
      """

      loss: Optional[torch.FloatTensor] = None
      logits_per_image: Optional[torch.FloatTensor] = None
      logits_per_text: Optional[torch.FloatTensor] = None
      text_embeds: Optional[torch.FloatTensor] = None
      image_embeds: Optional[torch.FloatTensor] = None
      # Both nested outputs default to `None`, so they are annotated as optional.
      text_model_output: Optional[BaseModelOutputWithPooling] = None
      vision_model_output: Optional[BaseModelOutputWithPooling] = None

      def to_tuple(self) -> tuple[Any, ...]:
          """Return the fields as a tuple, converting the two nested model outputs with their own `to_tuple()`."""
          nested_keys = ("text_model_output", "vision_model_output")
          return tuple(getattr(self, k).to_tuple() if k in nested_keys else self[k] for k in self.keys())
  @use_kernel_forward_from_hub("RMSNorm")
  class Aimv2RMSNorm(nn.Module):
      """Root-mean-square normalization over the last dimension with a learned per-channel scale."""

      def __init__(self, hidden_size, eps=1e-6):
          """
          Aimv2RMSNorm is equivalent to T5LayerNorm
          """
          super().__init__()
          self.weight = nn.Parameter(torch.ones(hidden_size))
          self.variance_epsilon = eps

      def forward(self, hidden_states):
          # The statistics are computed in float32; the result is cast back to the
          # caller's dtype before the learned scale is applied.
          original_dtype = hidden_states.dtype
          upcast = hidden_states.to(torch.float32)
          mean_square = upcast.pow(2).mean(-1, keepdim=True)
          normalized = upcast * torch.rsqrt(mean_square + self.variance_epsilon)
          return self.weight * normalized.to(original_dtype)

      def extra_repr(self):
          return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}"
  class Aimv2MLP(nn.Module):
      """Gated feed-forward block computing `down_proj(act_fn(gate_proj(x)) * up_proj(x))`."""

      def __init__(self, config):
          super().__init__()
          self.config = config
          self.hidden_size = config.hidden_size
          self.intermediate_size = config.intermediate_size

          # All three projections share the same bias setting.
          use_bias = config.mlp_bias
          self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=use_bias)
          self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=use_bias)
          self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=use_bias)
          self.act_fn = ACT2FN[config.hidden_act]

      def forward(self, x):
          gate = self.act_fn(self.gate_proj(x))
          return self.down_proj(gate * self.up_proj(x))
  class Aimv2VisionEmbeddings(nn.Module):
      """
      Turns pixel values into patch embeddings, applies RMS normalization and adds position embeddings.

      When `config.is_native` is set, fixed 2D sin-cos position embeddings are built on the fly from the input
      resolution. Otherwise a learned `nn.Embedding` table with `(image_size // patch_size) ** 2` entries is used.
      """

      def __init__(self, config: Aimv2VisionConfig):
          super().__init__()
          self.config = config
          self.patch_size = config.patch_size
          # Non-overlapping patches: kernel size and stride are both the patch size.
          self.patch_embed = nn.Conv2d(
              config.num_channels, config.hidden_size, kernel_size=config.patch_size, stride=config.patch_size
          )
          self.rms_norm = Aimv2RMSNorm(config.hidden_size, config.rms_norm_eps)

          num_patches = (config.image_size // config.patch_size) ** 2
          if not self.config.is_native:
              self.position_embedding = nn.Embedding(num_patches, config.hidden_size)
          self.register_buffer("position_ids", torch.arange(num_patches).expand((1, -1)), persistent=False)

      @staticmethod
      def build_2d_sincos_position_embedding(
          height, width, embed_dim=256, temperature=10000.0, device="cpu", dtype=torch.float32
      ) -> torch.Tensor:
          """
          Build a fixed 2D sin-cos position embedding for a grid of `height` x `width` patches.

          Args:
              height: Number of patch rows.
              width: Number of patch columns.
              embed_dim: Size of the embedding; must be divisible by 4 (sin/cos for each of the two axes).
              temperature: Base used to derive the frequencies.
              device: Device on which the embedding is created.
              dtype: Floating point dtype of the embedding.

          Returns:
              A tensor of shape `(1, height * width, embed_dim)`.

          Raises:
              ValueError: If `embed_dim` is not divisible by 4.
          """
          if embed_dim % 4 != 0:
              # Otherwise the result would silently have `4 * (embed_dim // 4)` channels and fail
              # later with an opaque shape mismatch when added to the patch embeddings.
              raise ValueError(
                  f"`embed_dim` must be divisible by 4 for 2D sin-cos position embedding (got {embed_dim})."
              )

          grid_w = torch.arange(int(width), dtype=dtype, device=device)
          grid_h = torch.arange(int(height), dtype=dtype, device=device)
          # With `indexing="xy"` both outputs have shape `(height, width)`. The argument order and the
          # target names are kept exactly as they were so the channel layout of the embedding is unchanged.
          grid_h, grid_w = torch.meshgrid(grid_w, grid_h, indexing="xy")

          pos_dim = embed_dim // 4
          omega = torch.arange(pos_dim, dtype=dtype, device=device) / pos_dim
          omega = 1.0 / (temperature**omega)

          # Outer products: one row per grid position, one column per frequency.
          out_h = grid_h.flatten()[..., None] @ omega[None, :]
          out_w = grid_w.flatten()[..., None] @ omega[None, :]

          return torch.concat([out_h.sin(), out_h.cos(), out_w.sin(), out_w.cos()], dim=1)[None, :, :]

      def forward(self, pixel_values: torch.Tensor) -> torch.Tensor:
          """Embed `pixel_values` (indexed as batch, channels, height, width) into a sequence of patch embeddings."""
          _, _, height, width = pixel_values.size()
          # (batch, hidden, h, w) -> (batch, h * w, hidden)
          hidden_states = self.patch_embed(pixel_values).flatten(2).transpose(1, 2)
          hidden_states = self.rms_norm(hidden_states)

          if self.config.is_native:
              pos_embed = self.build_2d_sincos_position_embedding(
                  height // self.patch_size,
                  width // self.patch_size,
                  embed_dim=self.config.hidden_size,
                  device=hidden_states.device,
                  dtype=hidden_states.dtype,
              )
          else:
              pos_embed = self.position_embedding(self.position_ids)

          hidden_states = hidden_states + pos_embed
          return hidden_states
  class Aimv2TextEmbeddings(nn.Module):
      """Sums learned token embeddings and learned absolute position embeddings."""

      def __init__(self, config: Aimv2TextConfig):
          super().__init__()
          embed_dim = config.hidden_size

          self.token_embedding = nn.Embedding(config.vocab_size, embed_dim)
          self.position_embedding = nn.Embedding(config.max_position_embeddings, embed_dim)

          # `position_ids` has shape (1, max_position_embeddings); it is registered as a non-persistent buffer.
          self.register_buffer(
              "position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)), persistent=False
          )

      def forward(
          self,
          input_ids: Optional[torch.LongTensor] = None,
          position_ids: Optional[torch.LongTensor] = None,
          inputs_embeds: Optional[torch.FloatTensor] = None,
      ) -> torch.Tensor:
          """
          Embed tokens and add position embeddings.

          Args:
              input_ids: Token ids; used to look up token embeddings when `inputs_embeds` is not given.
              position_ids: Optional position indices; defaults to `0..seq_length-1`.
              inputs_embeds: Optional precomputed token embeddings, used instead of `input_ids`.

          Returns:
              The sum of the token embeddings and the position embeddings.

          Raises:
              ValueError: If neither `input_ids` nor `inputs_embeds` is given, or if the sequence is longer than
                  the position embedding table.
          """
          if input_ids is None and inputs_embeds is None:
              raise ValueError("You have to specify either `input_ids` or `inputs_embeds`.")

          seq_length = input_ids.shape[-1] if input_ids is not None else inputs_embeds.shape[-2]
          max_position_embedding = self.position_embedding.weight.shape[0]

          if seq_length > max_position_embedding:
              raise ValueError(
                  f"Sequence length must not exceed max_position_embeddings (got `sequence length`: "
                  f"{seq_length} and max_position_embeddings: {max_position_embedding})"
              )

          if position_ids is None:
              position_ids = self.position_ids[:, :seq_length]

          if inputs_embeds is None:
              inputs_embeds = self.token_embedding(input_ids)

          position_embeddings = self.position_embedding(position_ids)
          embeddings = inputs_embeds + position_embeddings

          return embeddings
  def eager_attention_forward(
      module: nn.Module,
      query: torch.Tensor,
      key: torch.Tensor,
      value: torch.Tensor,
      attention_mask: Optional[torch.Tensor],
      scaling: float,
      dropout: float = 0.0,
      **kwargs,
  ):
      """
      Plain PyTorch scaled dot-product attention.

      The scores `query @ key^T * scaling` get the optional additive `attention_mask`, are softmax-normalized in
      float32 and cast back to the query dtype, then go through dropout (active only when `module.training`).

      Returns:
          A tuple `(attn_output, attn_weights)`; `attn_output` has dimensions 1 and 2 swapped relative to
          `attn_weights @ value` and is contiguous.
      """
      scores = torch.matmul(query, key.transpose(-1, -2)) * scaling
      if attention_mask is not None:
          scores = scores + attention_mask

      probs = nn.functional.softmax(scores, dim=-1, dtype=torch.float32).to(query.dtype)
      probs = nn.functional.dropout(probs, p=dropout, training=module.training)

      context = torch.matmul(probs, value).transpose(1, 2).contiguous()
      return context, probs
  class Aimv2Attention(nn.Module):
      """Multi-headed attention from 'Attention Is All You Need' paper"""

      def __init__(self, config):
          super().__init__()
          self.config = config
          self.embed_dim = config.hidden_size
          self.num_heads = config.num_attention_heads
          self.head_dim = self.embed_dim // self.num_heads
          if self.head_dim * self.num_heads != self.embed_dim:
              raise ValueError(
                  f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim} and `num_heads`:"
                  f" {self.num_heads})."
              )
          self.scale = self.head_dim**-0.5
          self.dropout = config.attention_dropout
          self.is_causal = False

          # The output projection uses the same bias flag as the q/k/v projections.
          use_bias = config.qkv_bias
          self.k_proj = nn.Linear(self.embed_dim, self.embed_dim, bias=use_bias)
          self.v_proj = nn.Linear(self.embed_dim, self.embed_dim, bias=use_bias)
          self.q_proj = nn.Linear(self.embed_dim, self.embed_dim, bias=use_bias)
          self.out_proj = nn.Linear(self.embed_dim, self.embed_dim, bias=use_bias)

      def forward(
          self,
          hidden_states: torch.Tensor,
          attention_mask: Optional[torch.Tensor] = None,
          **kwargs,
      ) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
          """Input shape: Batch x Time x Channel"""
          batch_size, seq_length, embed_dim = hidden_states.shape
          per_head_shape = (batch_size, seq_length, self.num_heads, self.head_dim)

          # Project, then split channels into heads: (batch, heads, time, head_dim).
          queries = self.q_proj(hidden_states).view(per_head_shape).transpose(1, 2)
          keys = self.k_proj(hidden_states).view(per_head_shape).transpose(1, 2)
          values = self.v_proj(hidden_states).view(per_head_shape).transpose(1, 2)

          # Pick the attention backend configured on the model.
          if self.config._attn_implementation == "eager":
              attention_interface: Callable = eager_attention_forward
          else:
              attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]

          attn_output, attn_weights = attention_interface(
              self,
              queries,
              keys,
              values,
              attention_mask,
              is_causal=self.is_causal,
              scaling=self.scale,
              dropout=self.dropout if self.training else 0.0,
          )

          # Merge the heads back into the channel dimension before the output projection.
          merged = attn_output.reshape(batch_size, seq_length, embed_dim).contiguous()
          return self.out_proj(merged), attn_weights
  class Aimv2EncoderLayer(GradientCheckpointingLayer):
      """Pre-norm transformer block: RMSNorm -> attention -> residual, then RMSNorm -> MLP -> residual."""

      def __init__(self, config: Aimv2VisionConfig):
          super().__init__()
          self.attention = Aimv2Attention(config)
          self.ffn = Aimv2MLP(config)
          self.rms_norm1 = Aimv2RMSNorm(config.hidden_size, config.rms_norm_eps)
          self.rms_norm2 = Aimv2RMSNorm(config.hidden_size, config.rms_norm_eps)

      def forward(
          self,
          hidden_states: torch.Tensor,
          attention_mask: Optional[torch.Tensor] = None,
          **kwargs: Unpack[TransformersKwargs],
      ) -> torch.Tensor:
          # Attention sub-layer; the attention weights are not needed here.
          attn_output, _ = self.attention(
              hidden_states=self.rms_norm1(hidden_states), attention_mask=attention_mask, **kwargs
          )
          after_attention = hidden_states + attn_output

          # Feed-forward sub-layer.
          return after_attention + self.ffn(self.rms_norm2(after_attention))
  class Aimv2Encoder(nn.Module):
      """
      Transformer encoder consisting of `config.num_hidden_layers` self attention layers. Each layer is a
      [`Aimv2EncoderLayer`].

      Args:
          config: Aimv2Config
      """

      def __init__(self, config: Aimv2Config):
          super().__init__()
          self.config = config
          layer_stack = [Aimv2EncoderLayer(config) for _ in range(config.num_hidden_layers)]
          self.layers = nn.ModuleList(layer_stack)
          self.gradient_checkpointing = False

      # Ignore copy
      @auto_docstring
      def forward(
          self,
          inputs_embeds,
          attention_mask: Optional[torch.Tensor] = None,
          **kwargs: Unpack[TransformersKwargs],
      ) -> BaseModelOutput:
          # Run the layers in order; each one consumes the previous layer's output.
          hidden_states = inputs_embeds
          for layer in self.layers:
              hidden_states = layer(hidden_states, attention_mask, **kwargs)

          return BaseModelOutput(last_hidden_state=hidden_states)
  class Aimv2AttentionPoolingHead(nn.Module):
      """Pools a sequence into one vector by attending from a learned `cls_token` query over all positions."""

      def __init__(self, config: Aimv2VisionConfig):
          super().__init__()
          self.hidden_size = config.hidden_size
          self.num_heads = config.num_attention_heads

          self.k_proj = nn.Linear(self.hidden_size, self.hidden_size, bias=config.qkv_bias)
          self.v_proj = nn.Linear(self.hidden_size, self.hidden_size, bias=config.qkv_bias)

          self.cls_token = nn.Parameter(torch.zeros(1, 1, self.hidden_size))
          self.output_proj = nn.Linear(self.hidden_size, self.hidden_size, bias=True)

      def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
          batch_size, seq_len, hidden_dim = hidden_states.shape
          head_dim = hidden_dim // self.num_heads
          kv_shape = (batch_size, seq_len, self.num_heads, head_dim)

          # A single query per sample, taken from the learned token (not projected).
          query = self.cls_token.expand(batch_size, -1, -1).reshape(batch_size, 1, self.num_heads, head_dim)
          key = self.k_proj(hidden_states).reshape(kv_shape)
          value = self.v_proj(hidden_states).reshape(kv_shape)

          # Move the head dimension in front of the sequence dimension for attention.
          pooled = F.scaled_dot_product_attention(
              query.permute(0, 2, 1, 3),
              key.permute(0, 2, 1, 3),
              value.permute(0, 2, 1, 3),
          )
          # Merge heads, then reduce over the length-1 query dimension.
          pooled = pooled.transpose(1, 2).reshape(batch_size, 1, hidden_dim).mean(dim=1)

          return self.output_proj(pooled)
  @auto_docstring
  class Aimv2PreTrainedModel(PreTrainedModel):
      """
      An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
      models. The model is only intended for inference and doesn't support finetuning.
      """

      config: Aimv2Config
      base_model_prefix = "aimv2"
      supports_gradient_checkpointing = True
      _no_split_modules = [
          "Aimv2EncoderLayer",
          "Aimv2AttentionPoolingHead",
          "Aimv2VisionEmbeddings",
          "Aimv2TextEmbeddings",
      ]
      _supports_sdpa = True
      _supports_flash_attn = True
      _supports_flex_attn = True

      def _init_weights(self, module):
          super()._init_weights(module)

          # Modules exposing `logit_scale` only get that parameter reset; nothing else applies to them.
          if hasattr(module, "logit_scale"):
              if isinstance(module.logit_scale, nn.Parameter):
                  module.logit_scale.data.fill_(math.log(1 / 0.07))
              return

          # The pooling head's learned query token is drawn from a normal distribution.
          if isinstance(module, Aimv2AttentionPoolingHead):
              module.cls_token.data.normal_(mean=0.0, std=self.config.initializer_range)
  @auto_docstring(
      custom_intro="""
      The Vision model from AIMv2 without any head or projection on top.
      """
  )
  class Aimv2VisionModel(Aimv2PreTrainedModel):
      config: Aimv2VisionConfig
      main_input_name = "pixel_values"
      _can_record_outputs = {
          "hidden_states": Aimv2EncoderLayer,
          "attentions": Aimv2Attention,
      }

      def __init__(self, config: Aimv2VisionConfig):
          super().__init__(config)
          self.config = config
          self.embeddings = Aimv2VisionEmbeddings(config)
          self.encoder = Aimv2Encoder(config)
          # The only change from SiglipVisionTransformer is, layernorm -> rms_norm.
          self.rms_norm = Aimv2RMSNorm(config.hidden_size, config.rms_norm_eps)

          # The attention pooling head is optional; without it `pooler_output` is `None`.
          self.use_head = config.use_head
          if self.use_head:
              self.head = Aimv2AttentionPoolingHead(config)

          self.post_init()

      def get_input_embeddings(self) -> nn.Module:
          return self.embeddings.patch_embed

      @deprecate_kwarg("attention_mask", version="v4.58.0")
      @check_model_inputs(tie_last_hidden_states=False)
      @auto_docstring
      def forward(
          self,
          pixel_values,
          attention_mask: Optional[torch.Tensor] = None,
          **kwargs: Unpack[TransformersKwargs],
      ) -> BaseModelOutputWithPooling:
          r"""
          Examples:

          ```python
          >>> from PIL import Image
          >>> import requests
          >>> from transformers import AutoProcessor, Aimv2VisionModel

          >>> model = Aimv2VisionModel.from_pretrained("apple/aimv2-large-patch14-native")
          >>> processor = AutoProcessor.from_pretrained("apple/aimv2-large-patch14-native")

          >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
          >>> image = Image.open(requests.get(url, stream=True).raw)

          >>> inputs = processor(images=image, return_tensors="pt")

          >>> outputs = model(**inputs)
          >>> last_hidden_state = outputs.last_hidden_state
          >>> pooled_output = outputs.pooler_output  # pooled features
          ```"""
          hidden_states = self.embeddings(pixel_values)

          # `attention_mask` is deprecated and is not passed on to the encoder.
          encoder_outputs: BaseModelOutput = self.encoder(
              inputs_embeds=hidden_states,
              **kwargs,
          )

          last_hidden_state = encoder_outputs.last_hidden_state
          last_hidden_state = self.rms_norm(last_hidden_state)

          pooler_output = self.head(last_hidden_state) if self.use_head else None

          return BaseModelOutputWithPooling(
              last_hidden_state=last_hidden_state,
              pooler_output=pooler_output,
          )
  @auto_docstring(
      custom_intro="""
      The text model from AIMv2 without any head or projection on top.
      """
  )
  class Aimv2TextModel(Aimv2PreTrainedModel):
      main_input_name = "input_ids"
      _can_record_outputs = {
          "hidden_states": Aimv2EncoderLayer,
          "attentions": Aimv2Attention,
      }

      def __init__(self, config: Aimv2TextConfig):
          super().__init__(config)
          self.config = config
          self.embeddings = Aimv2TextEmbeddings(config)
          self.encoder = Aimv2Encoder(config)
          self.rms_norm = Aimv2RMSNorm(config.hidden_size, config.rms_norm_eps)

          # Used in `forward` to locate the token whose hidden state becomes the pooled output.
          self.eos_token_id = config.eos_token_id

          self.post_init()

      def get_input_embeddings(self) -> nn.Module:
          return self.embeddings.token_embedding

      def set_input_embeddings(self, value):
          self.embeddings.token_embedding = value

      @check_model_inputs(tie_last_hidden_states=False)
      @auto_docstring
      def forward(
          self,
          input_ids,
          attention_mask: Optional[torch.Tensor] = None,
          **kwargs: Unpack[TransformersKwargs],
      ) -> BaseModelOutputWithPooling:
          hidden_states = self.embeddings(input_ids)
          batch_size, seq_len, _ = hidden_states.shape

          cache_position = torch.arange(seq_len, dtype=torch.long, device=hidden_states.device)
          position_ids = cache_position.unsqueeze(0).expand(batch_size, -1)

          # A causal mask is only built when the caller supplies an attention mask.
          if attention_mask is not None:
              attention_mask = create_causal_mask(
                  config=self.config,
                  input_embeds=hidden_states,
                  position_ids=position_ids,
                  attention_mask=attention_mask,
                  cache_position=cache_position,
                  past_key_values=None,
              )

          encoder_outputs = self.encoder(
              inputs_embeds=hidden_states,
              attention_mask=attention_mask,
              **kwargs,
          )
          last_hidden_state = self.rms_norm(encoder_outputs.last_hidden_state)

          # Pooled output: the hidden state at the first position whose token equals `eos_token_id`
          # (argmax over the 0/1 match mask; index 0 when there is no match).
          target_device = last_hidden_state.device
          batch_index = torch.arange(last_hidden_state.shape[0], device=target_device)
          eos_matches = input_ids.to(dtype=torch.int, device=target_device) == self.eos_token_id
          eos_index = eos_matches.int().argmax(dim=-1)
          pooled_output = last_hidden_state[batch_index, eos_index]

          return BaseModelOutputWithPooling(
              last_hidden_state=last_hidden_state,
              pooler_output=pooled_output,
          )
  458. def _get_vector_norm(tensor: torch.Tensor) -> torch.Tensor:
  459. """
  460. This method is equivalent to tensor.norm(p=2, dim=-1, keepdim=True) and used to make
  461. model `executorch` exportable. See issue https://github.com/pytorch/executorch/issues/3566
  462. """
  463. square_tensor = torch.pow(tensor, 2)
  464. sum_tensor = torch.sum(square_tensor, dim=-1, keepdim=True)
  465. normed_tensor = torch.pow(sum_tensor, 0.5)
  466. return normed_tensor
  @auto_docstring
  class Aimv2Model(Aimv2PreTrainedModel):
      config: Aimv2Config
      _no_split_modules = ["Aimv2TextEmbeddings", "Aimv2EncoderLayer", "Aimv2VisionEmbeddings"]
      _supports_flash_attn = True

      def __init__(self, config: Aimv2Config):
          super().__init__(config)

          self.projection_dim = config.projection_dim
          self.vision_embed_dim = config.vision_config.hidden_size
          self.text_embed_dim = config.text_config.hidden_size

          self.vision_model = Aimv2VisionModel._from_config(config.vision_config)
          self.text_model = Aimv2TextModel._from_config(config.text_config)

          # Bias-free projections from each tower's hidden size into the shared embedding space.
          self.visual_projection = nn.Linear(self.vision_embed_dim, self.projection_dim, bias=False)
          self.text_projection = nn.Linear(self.text_embed_dim, self.projection_dim, bias=False)

          self.logit_scale = nn.Parameter(torch.tensor(self.config.logit_scale_init_value))
          # Upper bound (in log space) used to clamp `logit_scale` in `forward`.
          self.max_log_logit_scale = math.log(config.max_logit_scale)

          self.post_init()

      @filter_out_non_signature_kwargs()
      @auto_docstring
      def get_text_features(
          self,
          input_ids: torch.Tensor,
          attention_mask: Optional[torch.Tensor] = None,
          position_ids: Optional[torch.Tensor] = None,
      ) -> torch.FloatTensor:
          r"""
          Returns:
              text_features (`torch.FloatTensor` of shape `(batch_size, output_dim)`): The text embeddings obtained by
              applying the projection layer to the pooled output of [`Aimv2TextModel`].

          Examples:

          ```python
          >>> import torch
          >>> from transformers import AutoTokenizer, Aimv2Model

          >>> model = Aimv2Model.from_pretrained("apple/aimv2-large-patch14-224-lit")
          >>> tokenizer = AutoTokenizer.from_pretrained("apple/aimv2-large-patch14-224-lit")

          >>> inputs = tokenizer(["a photo of a cat", "a photo of a dog"], padding=True, return_tensors="pt")

          >>> with torch.inference_mode():
          ...     text_features = model.get_text_features(**inputs)
          ```"""
          # NOTE(review): `Aimv2TextModel.forward` declares no `position_ids` parameter, so the value
          # ends up in its `**kwargs`; confirm whether it has any effect.
          text_outputs: BaseModelOutputWithPooling = self.text_model(
              input_ids=input_ids,
              attention_mask=attention_mask,
              position_ids=position_ids,
          )
          pooled_output = text_outputs.pooler_output
          text_features = self.text_projection(pooled_output)

          return text_features

      @filter_out_non_signature_kwargs()
      @auto_docstring
      def get_image_features(
          self,
          pixel_values: torch.FloatTensor,
          interpolate_pos_encoding: bool = False,
      ) -> torch.FloatTensor:
          r"""
          Returns:
              image_features (`torch.FloatTensor` of shape `(batch_size, output_dim)`): The image embeddings obtained by
              applying the projection layer to the pooled output of [`Aimv2VisionModel`].

          Examples:

          ```python
          >>> import torch
          >>> from transformers import AutoProcessor, Aimv2Model
          >>> from transformers.image_utils import load_image

          >>> model = Aimv2Model.from_pretrained("apple/aimv2-large-patch14-224-lit")
          >>> processor = AutoProcessor.from_pretrained("apple/aimv2-large-patch14-224-lit")

          >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
          >>> image = load_image(url)

          >>> inputs = processor(images=image, return_tensors="pt")

          >>> with torch.inference_mode():
          ...     image_features = model.get_image_features(**inputs)
          ```"""
          # NOTE(review): `Aimv2VisionModel.forward` declares no `interpolate_pos_encoding` parameter, so
          # the flag ends up in its `**kwargs`; confirm whether it has any effect.
          vision_outputs: BaseModelOutputWithPooling = self.vision_model(
              pixel_values=pixel_values,
              interpolate_pos_encoding=interpolate_pos_encoding,
          )
          pooled_output = vision_outputs.pooler_output
          image_features = self.visual_projection(pooled_output)

          return image_features

      @auto_docstring
      @can_return_tuple
      def forward(
          self,
          input_ids: Optional[torch.LongTensor] = None,
          pixel_values: Optional[torch.FloatTensor] = None,
          attention_mask: Optional[torch.Tensor] = None,
          **kwargs: Unpack[TransformersKwargs],
      ) -> Aimv2Output:
          r"""
          Examples:

          ```python
          >>> from PIL import Image
          >>> import requests
          >>> from transformers import AutoProcessor, Aimv2Model

          >>> model = Aimv2Model.from_pretrained("apple/aimv2-large-patch14-224-lit")
          >>> processor = AutoProcessor.from_pretrained("apple/aimv2-large-patch14-224-lit")

          >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
          >>> image = Image.open(requests.get(url, stream=True).raw)

          >>> inputs = processor(
          ...     text=["a photo of a cat", "a photo of a dog"], images=image, return_tensors="pt", padding=True
          ... )

          >>> outputs = model(**inputs)
          >>> logits_per_image = outputs.logits_per_image  # this is the image-text similarity score
          >>> probs = logits_per_image.softmax(dim=1)  # we can take the softmax to get the label probabilities
          ```"""
          vision_outputs: BaseModelOutputWithPooling = self.vision_model(
              pixel_values=pixel_values,
              **kwargs,
          )

          text_outputs: BaseModelOutputWithPooling = self.text_model(
              input_ids=input_ids,
              attention_mask=attention_mask,
              **kwargs,
          )

          image_embeds = vision_outputs.pooler_output
          image_embeds = self.visual_projection(image_embeds)

          text_embeds = text_outputs.pooler_output
          text_embeds = self.text_projection(text_embeds)

          # normalized features
          image_embeds = image_embeds / _get_vector_norm(image_embeds)
          text_embeds = text_embeds / _get_vector_norm(text_embeds)

          # `logit_scale` is stored in log space: clamp it there, then exponentiate.
          logit_scale = self.logit_scale.clamp(0.0, self.max_log_logit_scale).exp().to(text_embeds.device)
          logits_per_text = (logit_scale * text_embeds) @ image_embeds.t()
          logits_per_image = logits_per_text.t()

          # No loss is computed here, so `loss` keeps its default of `None`.
          return Aimv2Output(
              logits_per_image=logits_per_image,
              logits_per_text=logits_per_text,
              text_embeds=text_embeds,
              image_embeds=image_embeds,
              text_model_output=text_outputs,
              vision_model_output=vision_outputs,
          )
  # Public names exported by this module.
  __all__ = [
      "Aimv2VisionModel",
      "Aimv2Model",
      "Aimv2PreTrainedModel",
      "Aimv2TextModel",
  ]