modular_aimv2.py 28 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707
  1. # coding=utf-8
  2. # Copyright 2025 Apple Inc. and The HuggingFace Team. All rights reserved.
  3. #
  4. # Licensed under the Apache License, Version 2.0 (the "License");
  5. # you may not use this file except in compliance with the License.
  6. # You may obtain a copy of the License at
  7. #
  8. # http://www.apache.org/licenses/LICENSE-2.0
  9. #
  10. # Unless required by applicable law or agreed to in writing, software
  11. # distributed under the License is distributed on an "AS IS" BASIS,
  12. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. # See the License for the specific language governing permissions and
  14. # limitations under the License.
  15. """Pytorch implementation of AIMv2 Model"""
  16. import math
  17. from typing import Optional
  18. import torch
  19. import torch.nn.functional as F
  20. from torch import nn
  21. from ...masking_utils import create_causal_mask
  22. from ...modeling_layers import GradientCheckpointingLayer
  23. from ...modeling_outputs import BaseModelOutput, BaseModelOutputWithPooling
  24. from ...modeling_utils import PreTrainedModel
  25. from ...processing_utils import Unpack
  26. from ...utils import (
  27. TransformersKwargs,
  28. auto_docstring,
  29. can_return_tuple,
  30. )
  31. from ...utils.deprecation import deprecate_kwarg
  32. from ...utils.generic import check_model_inputs
  33. from ..clip.modeling_clip import CLIPModel, CLIPTextEmbeddings, _get_vector_norm
  34. from ..llama.modeling_llama import LlamaMLP, LlamaRMSNorm
  35. from ..siglip.configuration_siglip import SiglipConfig, SiglipTextConfig, SiglipVisionConfig
  36. from ..siglip.modeling_siglip import SiglipAttention, SiglipEncoder, SiglipOutput
  37. class Aimv2VisionConfig(SiglipVisionConfig):
  38. r"""
  39. This is the configuration class to store the configuration of a [`Aimv2VisionModel`]. It is used to instantiate a
  40. AIMv2 vision encoder according to the specified arguments, defining the model architecture. Instantiating a
  41. configuration with the defaults will yield a similar configuration to that of the vision encoder of the AIMv2
  42. [apple/aimv2-large-patch14-224](https://huggingface.co/apple/aimv2-large-patch14-224) architecture.
  43. Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
  44. documentation from [`PretrainedConfig`] for more information.
  45. Args:
  46. hidden_size (`int`, *optional*, defaults to 1024):
  47. Dimensionality of the encoder layers and the pooler layer.
  48. intermediate_size (`int`, *optional*, defaults to 2816):
  49. Dimensionality of the "intermediate" (i.e., feed-forward) layer in the Transformer encoder.
  50. num_hidden_layers (`int`, *optional*, defaults to 24):
  51. Number of hidden layers in the Transformer encoder.
  52. num_attention_heads (`int`, *optional*, defaults to 8):
  53. Number of attention heads for each attention layer in the Transformer encoder.
  54. num_channels (`int`, *optional*, defaults to 3):
  55. Number of channels in the input images.
  56. image_size (`int`, *optional*, defaults to 224):
  57. The size (resolution) of each image.
  58. patch_size (`int`, *optional*, defaults to 14):
  59. The size (resolution) of each patch.
  60. rms_norm_eps (`float`, *optional*, defaults to 1e-05):
  61. The epsilon used by the rms normalization layers.
  62. attention_dropout (`float`, *optional*, defaults to 0.0):
  63. The dropout ratio for the attention probabilities.
  64. qkv_bias (`bool`, *optional*, defaults to `False`):
  65. Whether to add a bias to the queries, keys and values.
  66. mlp_bias (`bool`, *optional*, defaults to `False`):
  67. Whether to add a bias to the Linear layers or Not.
  68. hidden_act (`str` or `function`, *optional*, defaults to `"silu"`):
  69. The non-linear activation function (function or string) in the encoder and pooler. If string, `"gelu"`,
  70. `"relu"`, `"selu"` and `"gelu_new"` `"quick_gelu"` are supported.
  71. initializer_range (`float`, *optional*, defaults to 0.02):
  72. The standard deviation of the for initializing all weight matrices.
  73. use_head (`str`, *optional*, defaults to `True`):
  74. Whether to use Attention Pooling Head or Not.
  75. is_native (`str`, *optional*, defaults to `False`):
  76. Whether to use ckpt trained for image native resolution or not.
  77. Example:
  78. ```python
  79. >>> from transformers import SiglipVisionConfig, SiglipVisionModel
  80. >>> # Initializing a Aimv2VisionConfig with apple/aimv2-large-patch14-224 style configuration
  81. >>> configuration = Aimv2VisionConfig()
  82. >>> # Initializing a Aimv2VisionModel (with random weights) from the apple/aimv2-large-patch14-224 style configuration
  83. >>> model = Aimv2VisionModel(configuration)
  84. >>> # Accessing the model configuration
  85. >>> configuration = model.config
  86. ```"""
  87. def __init__(
  88. self,
  89. hidden_size: int = 1024,
  90. intermediate_size: int = 2816,
  91. num_hidden_layers: int = 24,
  92. num_attention_heads: int = 8,
  93. num_channels: int = 3,
  94. image_size: int = 224,
  95. patch_size: int = 14,
  96. rms_norm_eps: float = 1e-5,
  97. attention_dropout: float = 0.0,
  98. qkv_bias: bool = False,
  99. mlp_bias: bool = False,
  100. hidden_act: str = "silu",
  101. initializer_range: float = 0.02,
  102. use_head: bool = True,
  103. is_native: bool = False,
  104. **kwargs,
  105. ):
  106. super().__init__(
  107. hidden_size=hidden_size,
  108. intermediate_size=intermediate_size,
  109. num_hidden_layers=num_hidden_layers,
  110. num_attention_heads=num_attention_heads,
  111. hidden_act=hidden_act,
  112. num_channels=num_channels,
  113. image_size=image_size,
  114. patch_size=patch_size,
  115. qkv_bias=qkv_bias,
  116. **kwargs,
  117. )
  118. self.use_head = use_head
  119. self.initializer_range = initializer_range
  120. self.attention_dropout = attention_dropout
  121. self.mlp_bias = mlp_bias
  122. self.qkv_bias = qkv_bias
  123. self.rms_norm_eps = rms_norm_eps
  124. self.is_native = is_native
  125. del self.layer_norm_eps
  126. class Aimv2TextConfig(SiglipTextConfig):
  127. r"""
  128. This is the configuration class to store the configuration of a [`Aimv2TextModel`]. It is used to instantiate a
  129. AIMv2 text encoder according to the specified arguments, defining the model architecture. Instantiating a
  130. configuration with the defaults will yield a similar configuration to that of the text encoder of the AIMv2
  131. [apple/aimv2-large-patch14-224-lit](https://huggingface.co/apple/aimv2-large-patch14-224-lit) architecture.
  132. Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
  133. documentation from [`PretrainedConfig`] for more information.
  134. Args:
  135. vocab_size (`int`, *optional*, defaults to 49408):
  136. Vocabulary size of the AIMv2 text model. Defines the number of different tokens that can be represented by
  137. the `inputs_ids` passed when calling [`Aimv2Model`].
  138. hidden_size (`int`, *optional*, defaults to 768):
  139. Dimensionality of the encoder layers and the pooler layer.
  140. intermediate_size (`int`, *optional*, defaults to 2048):
  141. Dimensionality of the "intermediate" (i.e., feed-forward) layer in the Transformer encoder.
  142. num_hidden_layers (`int`, *optional*, defaults to 12):
  143. Number of hidden layers in the Transformer encoder.
  144. num_attention_heads (`int`, *optional*, defaults to 6):
  145. Number of attention heads for each attention layer in the Transformer encoder.
  146. rms_norm_eps (`float`, *optional*, defaults to 1e-05):
  147. The epsilon used by the rms normalization layers.
  148. attention_dropout (`float`, *optional*, defaults to 0.0):
  149. The dropout ratio for the attention probabilities.
  150. qkv_bias (`bool`, *optional*, defaults to `False`):
  151. Whether to add a bias to the queries, keys and values.
  152. mlp_bias (`bool`, *optional*, defaults to `False`):
  153. Whether to add a bias to the Linear layers or Not.
  154. hidden_act (`str` or `function`, *optional*, defaults to `"silu"`):
  155. The non-linear activation function (function or string) in the encoder and pooler. If string, `"gelu"`,
  156. `"relu"`, `"selu"` and `"gelu_new"` `"quick_gelu"` are supported.
  157. pad_token_id (`int`, *optional*, defaults to 1):
  158. The id of the padding token in the vocabulary.
  159. bos_token_id (`int`, *optional*, defaults to 49406):
  160. The id of the beginning-of-sequence token in the vocabulary.
  161. eos_token_id (`int`, *optional*, defaults to 49407):
  162. The id of the end-of-sequence token in the vocabulary.
  163. max_position_embeddings (`int`, *optional*, defaults to 77):
  164. The maximum sequence length that this model might ever be used with. Typically set this to something large
  165. just in case (e.g., 512 or 1024 or 2048).
  166. initializer_range (`float`, *optional*, defaults to 0.02):
  167. The standard deviation of the for initializing all weight matrices.
  168. """
  169. def __init__(
  170. self,
  171. vocab_size: int = 49408,
  172. hidden_size: int = 768,
  173. intermediate_size: int = 2048,
  174. num_hidden_layers: int = 12,
  175. num_attention_heads: int = 6,
  176. rms_norm_eps: float = 1e-5,
  177. attention_dropout: float = 0.0,
  178. qkv_bias: bool = False,
  179. mlp_bias: bool = False,
  180. hidden_act: str = "silu",
  181. pad_token_id: Optional[int] = None,
  182. bos_token_id: Optional[int] = None,
  183. eos_token_id: int = 49407,
  184. max_position_embeddings: int = 77,
  185. initializer_range: bool = 0.02,
  186. **kwargs,
  187. ):
  188. super().__init__(
  189. vocab_size=vocab_size,
  190. hidden_size=hidden_size,
  191. intermediate_size=intermediate_size,
  192. num_hidden_layers=num_hidden_layers,
  193. num_attention_heads=num_attention_heads,
  194. hidden_act=hidden_act,
  195. max_position_embeddings=max_position_embeddings,
  196. pad_token_id=pad_token_id,
  197. bos_token_id=bos_token_id,
  198. eos_token_id=eos_token_id,
  199. **kwargs,
  200. )
  201. self.initializer_range = initializer_range
  202. self.attention_dropout = attention_dropout
  203. self.mlp_bias = mlp_bias
  204. self.qkv_bias = qkv_bias
  205. self.rms_norm_eps = rms_norm_eps
  206. del self.bos_token_id
  207. del self.pad_token_id
  208. del self.projection_size
  209. del self.layer_norm_eps
  210. class Aimv2Config(SiglipConfig):
  211. r"""
  212. [`Aimv2Config`] is the configuration class to store the configuration of a [`Aimv2Model`]. It is used to
  213. instantiate a AIMv2 model according to the specified arguments, defining the text model and vision model configs.
  214. Instantiating a configuration with the defaults will yield a similar configuration to that of the AIMv2
  215. [apple/aimv2-large-patch14-224-lit](https://huggingface.co/apple/aimv2-large-patch14-224-lit) architecture.
  216. Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
  217. documentation from [`PretrainedConfig`] for more information.
  218. Args:
  219. text_config (`dict`, *optional*):
  220. Dictionary of configuration options used to initialize [`Aimv2TextConfig`].
  221. vision_config (`dict`, *optional*):
  222. Dictionary of configuration options used to initialize [`Aimv2VisionConfig`].
  223. projection_dim (`int`, *optional*, defaults to 512):
  224. Dimensionality of text and vision projection layers.
  225. logit_scale_init_value (`float`, *optional*, defaults to 2.6592):
  226. The initial value of the *logit_scale* parameter.
  227. kwargs (*optional*):
  228. Dictionary of keyword arguments.
  229. Example:
  230. ```python
  231. >>> from transformers import Aimv2Config, Aimv2Model
  232. >>> # Initializing a Aimv2Config with apple/aimv2-large-patch14-224-lit style configuration
  233. >>> configuration = Aimv2Config()
  234. >>> # Initializing a Aimv2Model (with random weights) from the apple/aimv2-large-patch14-224-lit style configuration
  235. >>> model = Aimv2Model(configuration)
  236. >>> # Accessing the model configuration
  237. >>> configuration = model.config
  238. >>> # We can also initialize a Aimv2Config from a Aimv2TextConfig and a Aimv2VisionConfig
  239. >>> from transformers import Aimv2TextConfig, Aimv2VisionConfig
  240. >>> # Initializing a AIMv2Text and AIMv2Vision configuration
  241. >>> config_text = Aimv2TextConfig()
  242. >>> config_vision = Aimv2VisionConfig()
  243. >>> config = Aimv2Config(text_config=config_text, vision_config=config_vision)
  244. ```"""
  245. def __init__(
  246. self, text_config=None, vision_config=None, projection_dim=512, logit_scale_init_value=2.6592, **kwargs
  247. ):
  248. super().__init__(text_config, vision_config, **kwargs)
  249. self.projection_dim = projection_dim
  250. self.logit_scale_init_value = logit_scale_init_value
  251. self.max_logit_scale = 100.0
  252. del self.initializer_factor
class Aimv2Output(SiglipOutput):
    # Output container of Aimv2Model, reused unchanged from SiglipOutput. Aimv2Model.forward fills
    # logits_per_image, logits_per_text, text_embeds, image_embeds, text_model_output and vision_model_output.
    pass


class Aimv2RMSNorm(LlamaRMSNorm):
    # RMS normalization layer reused unchanged from Llama; built here as Aimv2RMSNorm(hidden_size, rms_norm_eps).
    pass


class Aimv2MLP(LlamaMLP):
    # Feed-forward block reused unchanged from Llama; built from the AIMv2 vision/text config.
    pass
  259. class Aimv2VisionEmbeddings(nn.Module):
  260. def __init__(self, config: Aimv2VisionConfig):
  261. super().__init__()
  262. self.config = config
  263. self.patch_size = config.patch_size
  264. self.patch_embed = nn.Conv2d(
  265. config.num_channels, config.hidden_size, kernel_size=config.patch_size, stride=config.patch_size
  266. )
  267. self.rms_norm = Aimv2RMSNorm(config.hidden_size, config.rms_norm_eps)
  268. num_patches = (config.image_size // config.patch_size) ** 2
  269. if not self.config.is_native:
  270. self.position_embedding = nn.Embedding(num_patches, config.hidden_size)
  271. self.register_buffer("position_ids", torch.arange(num_patches).expand((1, -1)), persistent=False)
  272. @staticmethod
  273. def build_2d_sincos_position_embedding(
  274. height, width, embed_dim=256, temperature=10000.0, device="cpu", dtype=torch.float32
  275. ) -> torch.Tensor:
  276. grid_w = torch.arange(int(width), dtype=dtype, device=device)
  277. grid_h = torch.arange(int(height), dtype=dtype, device=device)
  278. grid_h, grid_w = torch.meshgrid(grid_w, grid_h, indexing="xy")
  279. pos_dim = embed_dim // 4
  280. omega = torch.arange(pos_dim, dtype=dtype, device=device) / pos_dim
  281. omega = 1.0 / (temperature**omega)
  282. out_h = grid_h.flatten()[..., None] @ omega[None, :]
  283. out_w = grid_w.flatten()[..., None] @ omega[None, :]
  284. return torch.concat([out_h.sin(), out_h.cos(), out_w.sin(), out_w.cos()], dim=1)[None, :, :]
  285. def forward(self, pixel_values: torch.Tensor) -> torch.Tensor:
  286. _, _, height, width = pixel_values.size()
  287. hidden_states = self.patch_embed(pixel_values).flatten(2).transpose(1, 2)
  288. hidden_states = self.rms_norm(hidden_states)
  289. if self.config.is_native:
  290. pos_embed = self.build_2d_sincos_position_embedding(
  291. height // self.patch_size,
  292. width // self.patch_size,
  293. embed_dim=self.config.hidden_size,
  294. device=hidden_states.device,
  295. dtype=hidden_states.dtype,
  296. )
  297. else:
  298. pos_embed = self.position_embedding(self.position_ids)
  299. hidden_states = hidden_states + pos_embed
  300. return hidden_states
class Aimv2TextEmbeddings(CLIPTextEmbeddings):
    # Text embedding layer reused unchanged from CLIP; Aimv2TextModel calls it with input_ids and
    # exposes its `token_embedding` submodule as the input embeddings.
    pass
  303. class Aimv2Attention(SiglipAttention):
  304. def __init__(self, config):
  305. super().__init__(config)
  306. self.k_proj = nn.Linear(self.embed_dim, self.embed_dim, bias=config.qkv_bias)
  307. self.v_proj = nn.Linear(self.embed_dim, self.embed_dim, bias=config.qkv_bias)
  308. self.q_proj = nn.Linear(self.embed_dim, self.embed_dim, bias=config.qkv_bias)
  309. self.out_proj = nn.Linear(self.embed_dim, self.embed_dim, bias=config.qkv_bias)
  310. class Aimv2EncoderLayer(GradientCheckpointingLayer):
  311. def __init__(self, config: Aimv2VisionConfig):
  312. super().__init__()
  313. self.attention = Aimv2Attention(config)
  314. self.ffn = Aimv2MLP(config)
  315. self.rms_norm1 = Aimv2RMSNorm(config.hidden_size, config.rms_norm_eps)
  316. self.rms_norm2 = Aimv2RMSNorm(config.hidden_size, config.rms_norm_eps)
  317. def forward(
  318. self,
  319. hidden_states: torch.Tensor,
  320. attention_mask: Optional[torch.Tensor] = None,
  321. **kwargs: Unpack[TransformersKwargs],
  322. ) -> torch.Tensor:
  323. norm_hidden_states = self.rms_norm1(hidden_states)
  324. attn_output, _ = self.attention(hidden_states=norm_hidden_states, attention_mask=attention_mask, **kwargs)
  325. hidden_states = hidden_states + attn_output
  326. norm_hidden_states = self.rms_norm2(hidden_states)
  327. mlp_output = self.ffn(norm_hidden_states)
  328. hidden_states = hidden_states + mlp_output
  329. return hidden_states
class Aimv2Encoder(SiglipEncoder):
    # Encoder reused unchanged from Siglip; the vision and text models call it with `inputs_embeds`
    # (plus `attention_mask` in the text model) and read `.last_hidden_state` from its output.
    pass
  332. class Aimv2AttentionPoolingHead(nn.Module):
  333. def __init__(self, config: Aimv2VisionConfig):
  334. super().__init__()
  335. self.hidden_size = config.hidden_size
  336. self.num_heads = config.num_attention_heads
  337. self.k_proj = nn.Linear(self.hidden_size, self.hidden_size, bias=config.qkv_bias)
  338. self.v_proj = nn.Linear(self.hidden_size, self.hidden_size, bias=config.qkv_bias)
  339. self.cls_token = nn.Parameter(torch.zeros(1, 1, self.hidden_size))
  340. self.output_proj = nn.Linear(self.hidden_size, self.hidden_size, bias=True)
  341. def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
  342. batch_size, seq_len, hidden_dim = hidden_states.shape
  343. cls_token = self.cls_token.expand(batch_size, -1, -1)
  344. key = self.k_proj(hidden_states).reshape(batch_size, seq_len, self.num_heads, hidden_dim // self.num_heads)
  345. value = self.v_proj(hidden_states).reshape(batch_size, seq_len, self.num_heads, hidden_dim // self.num_heads)
  346. query = cls_token.reshape(batch_size, 1, self.num_heads, hidden_dim // self.num_heads)
  347. key = key.permute(0, 2, 1, 3)
  348. value = value.permute(0, 2, 1, 3)
  349. query = query.permute(0, 2, 1, 3)
  350. attn_output = F.scaled_dot_product_attention(query, key, value)
  351. attn_output = attn_output.transpose(1, 2).reshape(batch_size, 1, hidden_dim)
  352. attn_output = attn_output.mean(dim=1)
  353. output = self.output_proj(attn_output)
  354. return output
  355. @auto_docstring
  356. class Aimv2PreTrainedModel(PreTrainedModel):
  357. """
  358. An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
  359. models. The model is only intended for inference and doesn't support finetuning.
  360. """
  361. config: Aimv2Config
  362. base_model_prefix = "aimv2"
  363. supports_gradient_checkpointing = True
  364. _no_split_modules = [
  365. "Aimv2EncoderLayer",
  366. "Aimv2AttentionPoolingHead",
  367. "Aimv2VisionEmbeddings",
  368. "Aimv2TextEmbeddings",
  369. ]
  370. _supports_sdpa = True
  371. _supports_flash_attn = True
  372. _supports_flex_attn = True
  373. def _init_weights(self, module):
  374. super()._init_weights(module)
  375. if hasattr(module, "logit_scale"):
  376. if isinstance(module.logit_scale, nn.Parameter):
  377. module.logit_scale.data.fill_(math.log(1 / 0.07))
  378. elif isinstance(module, Aimv2AttentionPoolingHead):
  379. module.cls_token.data.normal_(mean=0.0, std=self.config.initializer_range)
  380. @auto_docstring(
  381. custom_intro="""
  382. The Vision model from AIMv2 without any head or projection on top.
  383. """
  384. )
  385. class Aimv2VisionModel(Aimv2PreTrainedModel):
  386. config: Aimv2VisionConfig
  387. main_input_name = "pixel_values"
  388. _can_record_outputs = {
  389. "hidden_states": Aimv2EncoderLayer,
  390. "attentions": Aimv2Attention,
  391. }
  392. def __init__(self, config: Aimv2VisionConfig):
  393. super().__init__(config)
  394. self.config = config
  395. self.embeddings = Aimv2VisionEmbeddings(config)
  396. self.encoder = Aimv2Encoder(config)
  397. # The only change from SiglipVisionTransformer is, layernorm -> rms_norm.
  398. self.rms_norm = Aimv2RMSNorm(config.hidden_size, config.rms_norm_eps)
  399. self.use_head = config.use_head
  400. if self.use_head:
  401. self.head = Aimv2AttentionPoolingHead(config)
  402. self.post_init()
  403. def get_input_embeddings(self) -> nn.Module:
  404. return self.embeddings.patch_embed
  405. @deprecate_kwarg("attention_mask", version="v4.58.0")
  406. @check_model_inputs(tie_last_hidden_states=False)
  407. @auto_docstring
  408. def forward(
  409. self,
  410. pixel_values,
  411. attention_mask: Optional[torch.Tensor] = None,
  412. **kwargs: Unpack[TransformersKwargs],
  413. ) -> BaseModelOutputWithPooling:
  414. r"""
  415. Examples:
  416. ```python
  417. >>> from PIL import Image
  418. >>> import requests
  419. >>> from transformers import AutoProcessor, Siglip2VisionModel
  420. >>> model = Aimv2VisionModel.from_pretrained("apple/aimv2-large-patch14-native")
  421. >>> processor = AutoProcessor.from_pretrained("apple/aimv2-large-patch14-native")
  422. >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
  423. >>> image = Image.open(requests.get(url, stream=True).raw)
  424. >>> inputs = processor(images=image, return_tensors="pt")
  425. >>> outputs = model(**inputs)
  426. >>> last_hidden_state = outputs.last_hidden_state
  427. >>> pooled_output = outputs.pooler_output # pooled features
  428. ```"""
  429. hidden_states = self.embeddings(pixel_values)
  430. encoder_outputs: BaseModelOutput = self.encoder(
  431. inputs_embeds=hidden_states,
  432. **kwargs,
  433. )
  434. last_hidden_state = encoder_outputs.last_hidden_state
  435. last_hidden_state = self.rms_norm(last_hidden_state)
  436. pooler_output = self.head(last_hidden_state) if self.use_head else None
  437. return BaseModelOutputWithPooling(
  438. last_hidden_state=last_hidden_state,
  439. pooler_output=pooler_output,
  440. )
  441. @auto_docstring(
  442. custom_intro="""
  443. The text model from AIMv2 without any head or projection on top.
  444. """
  445. )
  446. class Aimv2TextModel(Aimv2PreTrainedModel):
  447. main_input_name = "input_ids"
  448. _can_record_outputs = {
  449. "hidden_states": Aimv2EncoderLayer,
  450. "attentions": Aimv2Attention,
  451. }
  452. def __init__(self, config: Aimv2TextConfig):
  453. super().__init__(config)
  454. self.config = config
  455. self.embeddings = Aimv2TextEmbeddings(config)
  456. self.encoder = Aimv2Encoder(config)
  457. self.rms_norm = Aimv2RMSNorm(config.hidden_size, config.rms_norm_eps)
  458. self.eos_token_id = config.eos_token_id
  459. self.post_init()
  460. def get_input_embeddings(self) -> nn.Module:
  461. return self.embeddings.token_embedding
  462. def set_input_embeddings(self, value):
  463. self.embeddings.token_embedding = value
  464. @check_model_inputs(tie_last_hidden_states=False)
  465. @auto_docstring
  466. def forward(
  467. self,
  468. input_ids,
  469. attention_mask: Optional[torch.Tensor] = None,
  470. **kwargs: Unpack[TransformersKwargs],
  471. ) -> BaseModelOutputWithPooling:
  472. hidden_states = self.embeddings(input_ids)
  473. batch_size, seq_len, _ = hidden_states.shape
  474. cache_position = torch.arange(seq_len, dtype=torch.long, device=hidden_states.device)
  475. position_ids = cache_position.unsqueeze(0).expand(batch_size, -1)
  476. if attention_mask is not None:
  477. attention_mask = create_causal_mask(
  478. config=self.config,
  479. input_embeds=hidden_states,
  480. position_ids=position_ids,
  481. attention_mask=attention_mask,
  482. cache_position=cache_position,
  483. past_key_values=None,
  484. )
  485. encoder_outputs = self.encoder(
  486. inputs_embeds=hidden_states,
  487. attention_mask=attention_mask,
  488. **kwargs,
  489. )
  490. last_hidden_state = encoder_outputs.last_hidden_state
  491. last_hidden_state = self.rms_norm(last_hidden_state)
  492. # Get pooled output
  493. pooled_output = last_hidden_state[
  494. torch.arange(last_hidden_state.shape[0], device=last_hidden_state.device),
  495. (input_ids.to(dtype=torch.int, device=last_hidden_state.device) == self.eos_token_id).int().argmax(dim=-1),
  496. ]
  497. return BaseModelOutputWithPooling(
  498. last_hidden_state=last_hidden_state,
  499. pooler_output=pooled_output,
  500. )
  501. @auto_docstring
  502. class Aimv2Model(CLIPModel):
  503. _supports_flash_attn = True
  504. def __init__(self, config: Aimv2Config):
  505. PreTrainedModel.__init__(self, config)
  506. self.projection_dim = config.projection_dim
  507. self.vision_embed_dim = config.vision_config.hidden_size
  508. self.text_embed_dim = config.text_config.hidden_size
  509. self.vision_model = Aimv2VisionModel._from_config(config.vision_config)
  510. self.text_model = Aimv2TextModel._from_config(config.text_config)
  511. self.visual_projection = nn.Linear(self.vision_embed_dim, self.projection_dim, bias=False)
  512. self.text_projection = nn.Linear(self.text_embed_dim, self.projection_dim, bias=False)
  513. self.logit_scale = nn.Parameter(torch.tensor(self.config.logit_scale_init_value))
  514. self.max_log_logit_scale = math.log(config.max_logit_scale)
  515. self.post_init()
  516. @auto_docstring
  517. @can_return_tuple
  518. def forward(
  519. self,
  520. input_ids: Optional[torch.LongTensor] = None,
  521. pixel_values: Optional[torch.FloatTensor] = None,
  522. attention_mask: Optional[torch.Tensor] = None,
  523. **kwargs: Unpack[TransformersKwargs],
  524. ) -> Aimv2Output:
  525. r"""
  526. Examples:
  527. ```python
  528. >>> from PIL import Image
  529. >>> import requests
  530. >>> from transformers import AutoProcessor, Aimv2Model
  531. >>> model = Aimv2Model.from_pretrained("apple/aimv2-large-patch14-224-lit")
  532. >>> processor = AutoProcessor.from_pretrained("apple/aimv2-large-patch14-224-lit")
  533. >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
  534. >>> image = Image.open(requests.get(url, stream=True).raw)
  535. >>> inputs = processor(
  536. ... text=["a photo of a cat", "a photo of a dog"], images=image, return_tensors="pt", padding=True
  537. ... )
  538. >>> outputs = model(**inputs)
  539. >>> logits_per_image = outputs.logits_per_image # this is the image-text similarity score
  540. >>> probs = logits_per_image.softmax(dim=1) # we can take the softmax to get the label probabilities
  541. ```"""
  542. vision_outputs: BaseModelOutputWithPooling = self.vision_model(
  543. pixel_values=pixel_values,
  544. **kwargs,
  545. )
  546. text_outputs: BaseModelOutputWithPooling = self.text_model(
  547. input_ids=input_ids,
  548. attention_mask=attention_mask,
  549. **kwargs,
  550. )
  551. image_embeds = vision_outputs.pooler_output
  552. image_embeds = self.visual_projection(image_embeds)
  553. text_embeds = text_outputs.pooler_output
  554. text_embeds = self.text_projection(text_embeds)
  555. # normalized features
  556. image_embeds = image_embeds / _get_vector_norm(image_embeds)
  557. text_embeds = text_embeds / _get_vector_norm(text_embeds)
  558. logit_scale = self.logit_scale.clamp(0.0, self.max_log_logit_scale).exp().to(text_embeds.device)
  559. logits_per_text = (logit_scale * text_embeds) @ image_embeds.t()
  560. logits_per_image = logits_per_text.t()
  561. return Aimv2Output(
  562. logits_per_image=logits_per_image,
  563. logits_per_text=logits_per_text,
  564. text_embeds=text_embeds,
  565. image_embeds=image_embeds,
  566. text_model_output=text_outputs,
  567. vision_model_output=vision_outputs,
  568. )
  569. __all__ = [
  570. "Aimv2Config",
  571. "Aimv2VisionConfig",
  572. "Aimv2TextConfig",
  573. "Aimv2VisionModel",
  574. "Aimv2Model",
  575. "Aimv2PreTrainedModel",
  576. "Aimv2TextModel",
  577. ]