modeling_internvl.py 38 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862863864865866867868869870871872873874875876877878879880881882883884885886887888889890891892893894895896897898899900901902903904905906907908909910911912913914915916917918919920921922923924925926927928929930931932933934935936937938939940941942943944945946947948949
  1. # 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
  2. # This file was automatically generated from src/transformers/models/internvl/modular_internvl.py.
  3. # Do NOT edit this file manually as any edits will be overwritten by the generation of
  4. # the file from the modular. If any change should be done, please apply the change to the
  5. # modular_internvl.py file directly. One of our CI enforces this.
  6. # 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
  7. # coding=utf-8
  8. # Copyright 2025 HuggingFace Inc. team. All rights reserved.
  9. #
  10. # Licensed under the Apache License, Version 2.0 (the "License");
  11. # you may not use this file except in compliance with the License.
  12. # You may obtain a copy of the License at
  13. #
  14. # http://www.apache.org/licenses/LICENSE-2.0
  15. #
  16. # Unless required by applicable law or agreed to in writing, software
  17. # distributed under the License is distributed on an "AS IS" BASIS,
  18. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  19. # See the License for the specific language governing permissions and
  20. # limitations under the License.
  21. import collections.abc
  22. from dataclasses import dataclass
  23. from typing import Callable, Optional, Union
  24. import torch
  25. import torch.nn as nn
  26. from ...activations import ACT2FN
  27. from ...cache_utils import Cache
  28. from ...generation import GenerationMixin
  29. from ...integrations import use_kernel_forward_from_hub
  30. from ...modeling_layers import GradientCheckpointingLayer
  31. from ...modeling_outputs import BaseModelOutput, BaseModelOutputWithPast, BaseModelOutputWithPooling
  32. from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
  33. from ...processing_utils import Unpack
  34. from ...utils import ModelOutput, TransformersKwargs, auto_docstring, can_return_tuple, torch_int
  35. from ...utils.generic import check_model_inputs
  36. from ..auto import AutoModel
  37. from .configuration_internvl import InternVLConfig, InternVLVisionConfig
  38. @use_kernel_forward_from_hub("RMSNorm")
  39. class InternVLVisionRMSNorm(nn.Module):
  40. def __init__(self, hidden_size, eps=1e-6):
  41. """
  42. InternVLVisionRMSNorm is equivalent to T5LayerNorm
  43. """
  44. super().__init__()
  45. self.weight = nn.Parameter(torch.ones(hidden_size))
  46. self.variance_epsilon = eps
  47. def forward(self, hidden_states):
  48. input_dtype = hidden_states.dtype
  49. hidden_states = hidden_states.to(torch.float32)
  50. variance = hidden_states.pow(2).mean(-1, keepdim=True)
  51. hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
  52. return self.weight * hidden_states.to(input_dtype)
  53. def extra_repr(self):
  54. return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}"
  55. def eager_attention_forward(
  56. module: nn.Module,
  57. query: torch.Tensor,
  58. key: torch.Tensor,
  59. value: torch.Tensor,
  60. attention_mask: Optional[torch.Tensor],
  61. scaling: float,
  62. dropout: float = 0.0,
  63. **kwargs,
  64. ):
  65. key_states = key
  66. value_states = value
  67. attn_weights = torch.matmul(query, key_states.transpose(2, 3)) * scaling
  68. if attention_mask is not None:
  69. causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]
  70. attn_weights = attn_weights + causal_mask
  71. # No upcasting of the attention weights to float32 in this implementation
  72. attn_weights = nn.functional.softmax(attn_weights, dim=-1)
  73. attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training)
  74. attn_output = torch.matmul(attn_weights, value_states)
  75. attn_output = attn_output.transpose(1, 2).contiguous()
  76. return attn_output, attn_weights
  77. class InternVLVisionAttention(nn.Module):
  78. """Attention Class for InternVL Vision Encoder"""
  79. def __init__(self, config: InternVLVisionConfig):
  80. super().__init__()
  81. self.config = config
  82. self.embed_dim = config.hidden_size
  83. self.num_heads = config.num_attention_heads
  84. self.head_dim = self.embed_dim // self.num_heads
  85. if self.head_dim * self.num_heads != self.embed_dim:
  86. raise ValueError(
  87. f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim} and `num_heads`:"
  88. f" {self.num_heads})."
  89. )
  90. self.scale = self.head_dim**-0.5
  91. self.attention_dropout = config.attention_dropout
  92. proj_dropout = config.projection_dropout
  93. qk_norm = config.use_qk_norm
  94. # Needed for flash attention
  95. self.is_causal = False
  96. self.q_proj = nn.Linear(self.embed_dim, self.num_heads * self.head_dim, bias=config.attention_bias)
  97. self.k_proj = nn.Linear(self.embed_dim, self.num_heads * self.head_dim, bias=config.attention_bias)
  98. self.v_proj = nn.Linear(self.embed_dim, self.num_heads * self.head_dim, bias=config.attention_bias)
  99. self.projection_layer = nn.Linear(self.embed_dim, self.embed_dim)
  100. self.projection_dropout = nn.Dropout(proj_dropout) if proj_dropout > 0 else nn.Identity()
  101. self.q_norm = InternVLVisionRMSNorm(self.embed_dim) if qk_norm else nn.Identity()
  102. self.k_norm = InternVLVisionRMSNorm(self.embed_dim) if qk_norm else nn.Identity()
  103. def forward(
  104. self,
  105. hidden_states: torch.Tensor,
  106. attention_mask: Optional[torch.Tensor] = None,
  107. **kwargs: Unpack[TransformersKwargs],
  108. ):
  109. batch_size, seq_len, _ = hidden_states.size()
  110. query_states = self.q_proj(hidden_states)
  111. key_states = self.k_proj(hidden_states)
  112. value_states = self.v_proj(hidden_states)
  113. query_states = self.q_norm(query_states)
  114. key_states = self.k_norm(key_states)
  115. query_states = query_states.reshape(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2)
  116. key_states = key_states.reshape(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2)
  117. value_states = value_states.view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2)
  118. attention_interface: Callable = eager_attention_forward
  119. if self.config._attn_implementation != "eager":
  120. attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]
  121. attn_output, attn_weights = attention_interface(
  122. self,
  123. query_states,
  124. key_states,
  125. value_states,
  126. attention_mask,
  127. dropout=0.0 if not self.training else self.attention_dropout,
  128. scaling=self.scale,
  129. is_causal=False,
  130. **kwargs,
  131. )
  132. attn_output = attn_output.reshape(batch_size, seq_len, self.embed_dim)
  133. output = self.projection_layer(attn_output)
  134. output = self.projection_dropout(output)
  135. return output, attn_weights
# Output container used by `InternVLVisionModel.forward`. The class body declares no fields of
# its own; it only carries a description of `pooler_output` (presumably a field inherited from
# `BaseModelOutputWithPooling` and picked up by `@auto_docstring` -- TODO confirm).
# NOTE(review): `InternVLVisionModel.forward` in this file builds this output without passing
# `pooler_output`, so the pooling behavior described below is not produced by the visible code.
@dataclass
@auto_docstring(
    custom_intro="""
    Class for outputs of [`InternVLVisionModel`].
    """
)
class InternVLVisionModelOutputWithPooling(BaseModelOutputWithPooling):
    r"""
    pooler_output (`torch.FloatTensor` of shape `(batch_size, hidden_size)`):
        Average of the last layer hidden states of the patch tokens (excluding the *[CLS]* token) if
        *config.use_mean_pooling* is set to True. If set to False, then the final hidden state of the *[CLS]* token
        will be returned.
    """
  149. class InternVLVisionPatchEmbeddings(nn.Module):
  150. """
  151. This class turns `pixel_values` of shape `(batch_size, num_channels, height, width)` into the initial
  152. `hidden_states` (patch embeddings) of shape `(batch_size, seq_length, hidden_size)` to be consumed by a
  153. Transformer.
  154. """
  155. def __init__(self, config):
  156. super().__init__()
  157. image_size, patch_size = config.image_size, config.patch_size
  158. num_channels, hidden_size = config.num_channels, config.hidden_size
  159. num_patches = (image_size[1] // patch_size[1]) * (image_size[0] // patch_size[0])
  160. patch_shape = (image_size[0] // patch_size[0], image_size[1] // patch_size[1])
  161. self.image_size = image_size
  162. self.patch_size = patch_size
  163. self.num_channels = num_channels
  164. self.num_patches = num_patches
  165. self.patch_shape = patch_shape
  166. self.projection = nn.Conv2d(num_channels, hidden_size, kernel_size=patch_size, stride=patch_size)
  167. def forward(self, pixel_values: torch.Tensor) -> torch.Tensor:
  168. batch_size, num_channels, height, width = pixel_values.shape
  169. if num_channels != self.num_channels:
  170. raise ValueError(
  171. "Make sure that the channel dimension of the pixel values match with the one set in the configuration."
  172. )
  173. embeddings = self.projection(pixel_values)
  174. patch_height, patch_width = embeddings.shape[2], embeddings.shape[3]
  175. embeddings = embeddings.flatten(2).transpose(1, 2)
  176. return embeddings, (patch_height, patch_width)
  177. # Based on timm implementation, which can be found here:
  178. # https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py
  179. class InternVLVisionEmbeddings(nn.Module):
  180. """
  181. Construct the CLS token, position and patch embeddings. Optionally, also the mask token.
  182. """
  183. def __init__(self, config: InternVLVisionConfig) -> None:
  184. super().__init__()
  185. self.cls_token = nn.Parameter(torch.zeros(1, 1, config.hidden_size))
  186. if config.use_mask_token:
  187. self.mask_token = nn.Parameter(torch.zeros(1, 1, config.hidden_size))
  188. else:
  189. self.mask_token = None
  190. self.patch_embeddings = InternVLVisionPatchEmbeddings(config)
  191. self.patch_size = config.patch_size
  192. self.image_size = (
  193. config.image_size
  194. if isinstance(config.image_size, collections.abc.Iterable)
  195. else (config.image_size, config.image_size)
  196. )
  197. num_patches = self.patch_embeddings.num_patches
  198. if config.use_absolute_position_embeddings:
  199. self.position_embeddings = nn.Parameter(torch.zeros(1, num_patches + 1, config.hidden_size))
  200. else:
  201. self.position_embeddings = None
  202. self.dropout = nn.Dropout(config.hidden_dropout_prob)
  203. def interpolate_pos_encoding(self, embeddings: torch.Tensor, height: int, width: int) -> torch.Tensor:
  204. """
  205. This method allows to interpolate the pre-trained position encodings, to be able to use the model on higher resolution
  206. images. This method is also adapted to support torch.jit tracing.
  207. Adapted from:
  208. - https://github.com/facebookresearch/dino/blob/de9ee3df6cf39fac952ab558447af1fa1365362a/vision_transformer.py#L174-L194, and
  209. - https://github.com/facebookresearch/dinov2/blob/e1277af2ba9496fbadf7aec6eba56e8d882d1e35/dinov2/models/vision_transformer.py#L179-L211
  210. """
  211. num_patches = embeddings.shape[1] - 1
  212. num_positions = self.position_embeddings.shape[1] - 1
  213. # always interpolate when tracing to ensure the exported model works for dynamic input shapes
  214. if not torch.jit.is_tracing() and num_patches == num_positions and height == width:
  215. return self.position_embeddings
  216. class_pos_embed = self.position_embeddings[:, :1]
  217. patch_pos_embed = self.position_embeddings[:, 1:]
  218. dim = embeddings.shape[-1]
  219. new_height = height // self.patch_size[0]
  220. new_width = width // self.patch_size[1]
  221. sqrt_num_positions = torch_int(num_positions**0.5)
  222. patch_pos_embed = patch_pos_embed.reshape(1, sqrt_num_positions, sqrt_num_positions, dim)
  223. patch_pos_embed = patch_pos_embed.permute(0, 3, 1, 2)
  224. patch_pos_embed = nn.functional.interpolate(
  225. patch_pos_embed,
  226. size=(new_height, new_width),
  227. mode="bicubic",
  228. align_corners=False,
  229. )
  230. patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim)
  231. return torch.cat((class_pos_embed, patch_pos_embed), dim=1)
  232. def forward(
  233. self,
  234. pixel_values: torch.Tensor,
  235. bool_masked_pos: Optional[torch.BoolTensor] = None,
  236. ) -> torch.Tensor:
  237. _, _, height, width = pixel_values.shape
  238. embeddings, (patch_height, patch_width) = self.patch_embeddings(pixel_values)
  239. batch_size, seq_len, _ = embeddings.size()
  240. if bool_masked_pos is not None:
  241. mask_tokens = self.mask_token.expand(batch_size, seq_len, -1)
  242. # replace the masked visual tokens by mask_tokens
  243. w = bool_masked_pos.unsqueeze(-1).type_as(mask_tokens)
  244. embeddings = embeddings * (1 - w) + mask_tokens * w
  245. cls_tokens = self.cls_token.expand(batch_size, -1, -1)
  246. embeddings = torch.cat((cls_tokens, embeddings), dim=1)
  247. if self.position_embeddings is not None:
  248. embeddings = embeddings + self.interpolate_pos_encoding(embeddings, height, width)
  249. embeddings = self.dropout(embeddings)
  250. return embeddings, (patch_height, patch_width)
  251. class InternVLVisionMLP(nn.Module):
  252. def __init__(self, config):
  253. super().__init__()
  254. self.config = config
  255. self.activation_fn = ACT2FN[config.hidden_act]
  256. self.fc1 = nn.Linear(config.hidden_size, config.intermediate_size)
  257. self.fc2 = nn.Linear(config.intermediate_size, config.hidden_size)
  258. def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
  259. hidden_states = self.fc1(hidden_states)
  260. hidden_states = self.activation_fn(hidden_states)
  261. hidden_states = self.fc2(hidden_states)
  262. return hidden_states
  263. NORM2FN = {"layer_norm": nn.LayerNorm, "rms_norm": InternVLVisionRMSNorm}
  264. class InternVLVisionLayer(GradientCheckpointingLayer):
  265. """This corresponds to the Block class in the timm implementation."""
  266. def __init__(self, config: InternVLVisionConfig) -> None:
  267. super().__init__()
  268. self.chunk_size_feed_forward = config.chunk_size_feed_forward
  269. self.seq_len_dim = 1
  270. self.attention = InternVLVisionAttention(config)
  271. self.mlp = InternVLVisionMLP(config)
  272. # InternVL uses different layernorm implementations for different models
  273. self.layernorm_before = NORM2FN[config.norm_type](config.hidden_size, eps=config.layer_norm_eps)
  274. self.layernorm_after = NORM2FN[config.norm_type](config.hidden_size, eps=config.layer_norm_eps)
  275. init_values = config.layer_scale_init_value
  276. self.lambda_1 = nn.Parameter(init_values * torch.ones(config.hidden_size), requires_grad=True)
  277. self.lambda_2 = nn.Parameter(init_values * torch.ones(config.hidden_size), requires_grad=True)
  278. self.dropout = nn.Dropout(config.hidden_dropout_prob)
  279. def forward(
  280. self,
  281. hidden_states: torch.Tensor,
  282. ) -> Union[tuple[torch.Tensor], tuple[torch.Tensor, torch.Tensor]]:
  283. attention_output, _ = self.attention(
  284. self.layernorm_before(hidden_states), # in InternVLVision, layernorm is applied before self-attention
  285. )
  286. attention_output = self.lambda_1 * attention_output
  287. # first residual connection
  288. hidden_states = attention_output + hidden_states
  289. # in InternVLVision, layernorm is also applied after self-attention
  290. layer_output = self.layernorm_after(hidden_states)
  291. layer_output = self.mlp(layer_output)
  292. layer_output = self.dropout(layer_output)
  293. if self.lambda_2 is not None:
  294. layer_output = self.lambda_2 * layer_output
  295. # second residual connection
  296. layer_output = layer_output + hidden_states
  297. return layer_output
  298. class InternVLVisionEncoder(nn.Module):
  299. def __init__(self, config: InternVLVisionConfig) -> None:
  300. super().__init__()
  301. self.config = config
  302. self.layer = nn.ModuleList([InternVLVisionLayer(config) for i in range(config.num_hidden_layers)])
  303. self.gradient_checkpointing = False
  304. def forward(
  305. self,
  306. hidden_states: torch.Tensor,
  307. ) -> Union[tuple, BaseModelOutput]:
  308. for layer_module in self.layer:
  309. hidden_states = layer_module(hidden_states)
  310. return BaseModelOutput(
  311. last_hidden_state=hidden_states,
  312. )
  313. @auto_docstring
  314. class InternVLVisionPreTrainedModel(PreTrainedModel):
  315. config: InternVLVisionConfig
  316. base_model_prefix = "internvl_vision"
  317. main_input_name = "pixel_values"
  318. supports_gradient_checkpointing = True
  319. _no_split_modules = ["InternVLVisionLayer"]
  320. _supports_sdpa = True
  321. _supports_flash_attn = True
  322. _supports_flex_attn = True
  323. _supports_attention_backend = True
  324. _can_record_outputs = {
  325. "hidden_states": InternVLVisionLayer,
  326. "attentions": InternVLVisionAttention,
  327. }
  328. def _init_weights(self, module):
  329. """Initialize the weights"""
  330. super()._init_weights(module)
  331. if isinstance(module, InternVLVisionEmbeddings):
  332. module.cls_token.data.zero_()
  333. if module.mask_token is not None:
  334. module.mask_token.data.zero_()
  335. if module.position_embeddings is not None:
  336. module.position_embeddings.data.zero_()
  337. elif isinstance(module, InternVLVisionLayer):
  338. module.lambda_1.data.fill_(self.config.layer_scale_init_value)
  339. module.lambda_2.data.fill_(self.config.layer_scale_init_value)
  340. @auto_docstring
  341. class InternVLVisionModel(InternVLVisionPreTrainedModel):
  342. def __init__(self, config: InternVLVisionConfig) -> None:
  343. super().__init__(config)
  344. self.config = config
  345. self.embeddings = InternVLVisionEmbeddings(config)
  346. self.encoder = InternVLVisionEncoder(config)
  347. self.layernorm = (
  348. nn.Identity() if config.use_mean_pooling else nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
  349. )
  350. # Initialize weights and apply final processing
  351. self.post_init()
  352. def get_input_embeddings(self):
  353. return self.embeddings.patch_embeddings
  354. @check_model_inputs(tie_last_hidden_states=False)
  355. @auto_docstring
  356. def forward(
  357. self,
  358. pixel_values: torch.Tensor,
  359. bool_masked_pos: Optional[torch.BoolTensor] = None,
  360. ) -> Union[tuple, InternVLVisionModelOutputWithPooling]:
  361. r"""
  362. bool_masked_pos (`torch.BoolTensor` of shape `(batch_size, num_patches)`, *optional*):
  363. Boolean masked positions. Indicates which patches are masked (1) and which aren't (0).
  364. """
  365. embedding_output, _ = self.embeddings(pixel_values, bool_masked_pos=bool_masked_pos)
  366. encoder_outputs = self.encoder(embedding_output)
  367. sequence_output = encoder_outputs[0]
  368. sequence_output = self.layernorm(sequence_output)
  369. return InternVLVisionModelOutputWithPooling(
  370. last_hidden_state=sequence_output,
  371. hidden_states=encoder_outputs.hidden_states,
  372. attentions=encoder_outputs.attentions,
  373. )
# Shared base class for the composite (vision + language) InternVL models. It contains only
# class-level declarations; how each flag is interpreted is decided by `PreTrainedModel`, which
# is not part of this file.
@auto_docstring
class InternVLPreTrainedModel(PreTrainedModel):
    config: InternVLConfig
    # Empty prefix: no base-model attribute name is declared for this class.
    base_model_prefix = ""
    supports_gradient_checkpointing = True
    # presumably keeps the KV cache from being moved during device placement -- TODO confirm
    _skip_keys_device_placement = "past_key_values"

    # Attention backends / compilation modes this model declares support for.
    _supports_flash_attn = True
    _supports_sdpa = True

    _can_compile_fullgraph = True
    _supports_flex_attn = True
    _supports_attention_backend = True
  385. class InternVLMultiModalProjector(nn.Module):
  386. def __init__(self, config: InternVLConfig):
  387. super().__init__()
  388. self.layer_norm = nn.LayerNorm(config.vision_config.hidden_size * int(1 / config.downsample_ratio) ** 2)
  389. self.linear_1 = nn.Linear(
  390. config.vision_config.hidden_size * int(1 / config.downsample_ratio) ** 2, config.text_config.hidden_size
  391. )
  392. self.act = ACT2FN[config.projector_hidden_act]
  393. self.linear_2 = nn.Linear(config.text_config.hidden_size, config.text_config.hidden_size)
  394. def forward(self, image_features):
  395. hidden_states = self.layer_norm(image_features)
  396. hidden_states = self.linear_1(hidden_states)
  397. hidden_states = self.act(hidden_states)
  398. hidden_states = self.linear_2(hidden_states)
  399. return hidden_states
# Output container returned by `InternVLModel.forward`: the fields of `BaseModelOutputWithPast`
# plus `image_hidden_states`.
@dataclass
@auto_docstring(
    custom_intro="""
    Base class for InternVL outputs, with hidden states and attentions.
    """
)
class InternVLModelOutputWithPast(BaseModelOutputWithPast):
    r"""
    past_key_values (`Cache`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
        It is a [`~cache_utils.Cache`] instance. For more details, see our [kv cache guide](https://huggingface.co/docs/transformers/en/kv_cache).

        Contains pre-computed hidden-states (key and values in the self-attention blocks) that can be used (see
        `past_key_values` input) to speed up sequential decoding.
    image_hidden_states (`torch.FloatTensor`, *optional*):
        A `torch.FloatTensor` of size `(batch_size, num_images, sequence_length, hidden_size)`.
        image_hidden_states of the model produced by the vision encoder and after projecting the last hidden state.
    """

    # Set by `InternVLModel.forward` to the projected image features when `pixel_values` is given,
    # otherwise left as None.
    # NOTE(review): `InternVLModel.get_image_features` documents its result as
    # `(num_images, image_length, embed_dim)`, which differs from the 4-D size stated in the
    # docstring above -- confirm which one is correct.
    image_hidden_states: Optional[torch.FloatTensor] = None
  417. @auto_docstring(
  418. custom_intro="""
  419. The InternVL model which consists of a vision backbone and a language model, without a language modeling head.
  420. """
  421. )
  422. class InternVLModel(InternVLPreTrainedModel):
  423. _checkpoint_conversion_mapping = {"language_model.model": "language_model"}
  424. def __init__(self, config: InternVLConfig):
  425. super().__init__(config)
  426. self.vision_tower = AutoModel.from_config(config.vision_config)
  427. self.multi_modal_projector = InternVLMultiModalProjector(config)
  428. self.language_model = AutoModel.from_config(config.text_config)
  429. self.post_init()
  430. def get_input_embeddings(self):
  431. return self.language_model.get_input_embeddings()
  432. def set_input_embeddings(self, value):
  433. self.language_model.set_input_embeddings(value)
  434. def set_decoder(self, decoder):
  435. self.language_model = decoder
  436. def get_decoder(self):
  437. return self.language_model
  438. def get_image_features(
  439. self,
  440. pixel_values: torch.FloatTensor,
  441. vision_feature_layer: Optional[Union[int, list[int]]] = None,
  442. vision_feature_select_strategy: Optional[str] = None,
  443. **kwargs,
  444. ):
  445. """
  446. Obtains image last hidden states from the vision tower and apply multimodal projection.
  447. Args:
  448. pixel_values (`torch.FloatTensor]` of shape `(batch_size, channels, height, width)`)
  449. The tensors corresponding to the input images.
  450. vision_feature_layer (`int` or `list[int]`):
  451. Layer index or list of layer indices to extract features from.
  452. Returns:
  453. vision_features (`torch.Tensor`): Image feature tensor of shape `(num_images, image_length, embed_dim)`.
  454. """
  455. vision_feature_layer = (
  456. vision_feature_layer if vision_feature_layer is not None else self.config.vision_feature_layer
  457. )
  458. vision_feature_select_strategy = (
  459. vision_feature_select_strategy
  460. if vision_feature_select_strategy is not None
  461. else self.config.vision_feature_select_strategy
  462. )
  463. pixel_values = pixel_values.to(dtype=self.dtype) # fp16 compatibility
  464. downsample_ratio = self.config.downsample_ratio
  465. if vision_feature_layer == -1:
  466. vision_features = self.vision_tower(pixel_values=pixel_values).last_hidden_state
  467. else:
  468. vision_features = self.vision_model(pixel_values=pixel_values).hidden_states[vision_feature_layer]
  469. if vision_feature_select_strategy == "default":
  470. vision_features = vision_features[:, 1:, :]
  471. # Calculate dimensions based on vision features
  472. channels = vision_features.shape[1]
  473. feature_size = int(channels**0.5)
  474. batch_size = vision_features.shape[0]
  475. # Reshape tensor to spatial dimensions
  476. vision_features = vision_features.reshape(batch_size, feature_size, feature_size, -1)
  477. # Apply downsampling using pixel shuffle
  478. vision_features = self.pixel_shuffle(vision_features, scale_factor=downsample_ratio)
  479. # Reshape tensor to prepare for projection
  480. vision_features = vision_features.reshape(batch_size, -1, vision_features.shape[-1])
  481. # Project features through multi-modal projector
  482. vision_features = self.multi_modal_projector(vision_features)
  483. return vision_features
  484. def get_placeholder_mask(
  485. self, input_ids: torch.LongTensor, inputs_embeds: torch.FloatTensor, image_features: torch.FloatTensor
  486. ):
  487. """
  488. Obtains multimodal placeholder mask from `input_ids` or `inputs_embeds`, and checks that the placeholder token count is
  489. equal to the length of multimodal features. If the lengths are different, an error is raised.
  490. """
  491. if input_ids is None:
  492. special_image_mask = inputs_embeds == self.get_input_embeddings()(
  493. torch.tensor(self.config.image_token_id, dtype=torch.long, device=inputs_embeds.device)
  494. )
  495. special_image_mask = special_image_mask.all(-1)
  496. else:
  497. special_image_mask = input_ids == self.config.image_token_id
  498. n_image_tokens = special_image_mask.sum()
  499. special_image_mask = special_image_mask.unsqueeze(-1).expand_as(inputs_embeds).to(inputs_embeds.device)
  500. n_image_features = image_features.shape[0] * image_features.shape[1]
  501. if inputs_embeds[special_image_mask].numel() != image_features.numel():
  502. raise ValueError(
  503. f"Image features and image tokens do not match: tokens: {n_image_tokens}, features {n_image_features}"
  504. )
  505. return special_image_mask
  506. @can_return_tuple
  507. @auto_docstring
  508. def forward(
  509. self,
  510. input_ids: Optional[torch.LongTensor] = None,
  511. pixel_values: Optional[torch.FloatTensor] = None,
  512. attention_mask: Optional[torch.Tensor] = None,
  513. position_ids: Optional[torch.LongTensor] = None,
  514. past_key_values: Optional[Cache] = None,
  515. inputs_embeds: Optional[torch.FloatTensor] = None,
  516. vision_feature_layer: Optional[Union[int, list[int]]] = None,
  517. vision_feature_select_strategy: Optional[str] = None,
  518. cache_position: Optional[torch.LongTensor] = None,
  519. **kwargs: Unpack[TransformersKwargs],
  520. ) -> Union[tuple, InternVLModelOutputWithPast]:
  521. vision_feature_layer = (
  522. vision_feature_layer if vision_feature_layer is not None else self.config.vision_feature_layer
  523. )
  524. vision_feature_select_strategy = (
  525. vision_feature_select_strategy
  526. if vision_feature_select_strategy is not None
  527. else self.config.vision_feature_select_strategy
  528. )
  529. if (input_ids is None) ^ (inputs_embeds is not None):
  530. raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
  531. if inputs_embeds is None:
  532. inputs_embeds = self.get_input_embeddings()(input_ids)
  533. if pixel_values is not None:
  534. image_features = self.get_image_features(
  535. pixel_values=pixel_values,
  536. vision_feature_layer=vision_feature_layer,
  537. vision_feature_select_strategy=vision_feature_select_strategy,
  538. )
  539. image_features = image_features.to(inputs_embeds.device, inputs_embeds.dtype)
  540. special_image_mask = self.get_placeholder_mask(
  541. input_ids, inputs_embeds=inputs_embeds, image_features=image_features
  542. )
  543. inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, image_features)
  544. outputs = self.language_model(
  545. attention_mask=attention_mask,
  546. position_ids=position_ids,
  547. past_key_values=past_key_values,
  548. inputs_embeds=inputs_embeds,
  549. cache_position=cache_position,
  550. **kwargs,
  551. )
  552. return InternVLModelOutputWithPast(
  553. last_hidden_state=outputs.last_hidden_state,
  554. past_key_values=outputs.past_key_values,
  555. hidden_states=outputs.hidden_states,
  556. attentions=outputs.attentions,
  557. image_hidden_states=image_features if pixel_values is not None else None,
  558. )
  559. def pixel_shuffle(self, vision_features: torch.Tensor, scale_factor: float = 0.5):
  560. """Perform pixel shuffle downsampling on vision features.
  561. Args:
  562. vision_features (`torch.Tensor`):
  563. Input tensor of shape (batch_size, width, height, channels).
  564. scale_factor (`float`, *optional*, defaults to `0.5`):
  565. Factor by which to downsample. Default is 0.5, which halves the dimensions.
  566. Returns:
  567. vision_features (`torch.Tensor`):
  568. Downsampled tensor of shape (batch_size, height*scale_factor, width*scale_factor, channels/(scale_factor^2)).
  569. """
  570. batch_size, width, height, channels = vision_features.size()
  571. if height % scale_factor != 0 or width % scale_factor != 0:
  572. raise ValueError("Height and width must be divisible by scale_factor for proper downsampling.")
  573. # Reshape to allow downsampling
  574. vision_features = vision_features.view(
  575. batch_size, width, int(height * scale_factor), int(channels / scale_factor)
  576. )
  577. # Permute dimensions to align downsampled axis correctly
  578. vision_features = vision_features.permute(0, 2, 1, 3).contiguous()
  579. # Reshape to achieve final downsampled dimensions
  580. vision_features = vision_features.view(
  581. batch_size, int(height * scale_factor), int(width * scale_factor), int(channels / (scale_factor**2))
  582. )
  583. # Swap height and width back for proper orientation
  584. vision_features = vision_features.permute(0, 2, 1, 3).contiguous()
  585. return vision_features
  586. @dataclass
  587. @auto_docstring(
  588. custom_intro="""
  589. Base class for InternVL causal language model (or autoregressive) outputs.
  590. """
  591. )
  592. class InternVLCausalLMOutputWithPast(ModelOutput):
  593. r"""
  594. loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided):
  595. Language modeling loss (for next-token prediction).
  596. logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.vocab_size)`):
  597. Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
  598. past_key_values (`Cache`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
  599. It is a [`~cache_utils.Cache`] instance. For more details, see our [kv cache guide](https://huggingface.co/docs/transformers/en/kv_cache).
  600. Contains pre-computed hidden-states (key and values in the self-attention blocks) that can be used (see
  601. `past_key_values` input) to speed up sequential decoding.
  602. image_hidden_states (`torch.FloatTensor`, *optional*):
  603. A `torch.FloatTensor` of size `(batch_size, num_images, sequence_length, hidden_size)`.
  604. image_hidden_states of the model produced by the vision encoder and after projecting the last hidden state.
  605. """
  606. loss: Optional[torch.FloatTensor] = None
  607. logits: Optional[torch.FloatTensor] = None
  608. past_key_values: Optional[Cache] = None
  609. hidden_states: Optional[tuple[torch.FloatTensor]] = None
  610. attentions: Optional[tuple[torch.FloatTensor]] = None
  611. image_hidden_states: Optional[torch.FloatTensor] = None
  612. @auto_docstring(
  613. custom_intro="""
  614. The INTERNVL model which consists of a vision backbone and a language model.
  615. """
  616. )
  617. class InternVLForConditionalGeneration(InternVLPreTrainedModel, GenerationMixin):
  618. _checkpoint_conversion_mapping = {
  619. "^language_model.model": "model.language_model",
  620. "^vision_tower": "model.vision_tower",
  621. "^multi_modal_projector": "model.multi_modal_projector",
  622. "^language_model.lm_head": "lm_head",
  623. }
  624. _tied_weights_keys = ["lm_head.weight"]
  625. def __init__(self, config: InternVLConfig):
  626. super().__init__(config)
  627. self.model = InternVLModel(config)
  628. self.lm_head = nn.Linear(config.text_config.hidden_size, config.text_config.vocab_size, bias=False)
  629. self.post_init()
  630. def get_input_embeddings(self):
  631. return self.model.get_input_embeddings()
  632. def set_input_embeddings(self, value):
  633. self.model.set_input_embeddings(value)
  634. def get_output_embeddings(self) -> nn.Module:
  635. return self.lm_head
  636. def set_decoder(self, decoder):
  637. self.model.set_decoder(decoder)
  638. def get_decoder(self):
  639. return self.model.get_decoder()
  640. def get_image_features(
  641. self,
  642. pixel_values: torch.FloatTensor,
  643. vision_feature_layer: Optional[Union[int, list[int]]] = None,
  644. vision_feature_select_strategy: Optional[str] = None,
  645. **kwargs,
  646. ):
  647. return self.model.get_image_features(
  648. pixel_values=pixel_values,
  649. vision_feature_layer=vision_feature_layer,
  650. vision_feature_select_strategy=vision_feature_select_strategy,
  651. **kwargs,
  652. )
  653. # Make modules available through conditional class for BC
  654. @property
  655. def language_model(self):
  656. return self.model.language_model
  657. @property
  658. def vision_tower(self):
  659. return self.model.vision_tower
  660. @property
  661. def multi_modal_projector(self):
  662. return self.model.multi_modal_projector
  663. @can_return_tuple
  664. @auto_docstring
  665. def forward(
  666. self,
  667. input_ids: Optional[torch.LongTensor] = None,
  668. pixel_values: Optional[torch.FloatTensor] = None,
  669. attention_mask: Optional[torch.Tensor] = None,
  670. position_ids: Optional[torch.LongTensor] = None,
  671. past_key_values: Optional[Cache] = None,
  672. inputs_embeds: Optional[torch.FloatTensor] = None,
  673. vision_feature_layer: Optional[Union[int, list[int]]] = None,
  674. vision_feature_select_strategy: Optional[str] = None,
  675. labels: Optional[torch.LongTensor] = None,
  676. cache_position: Optional[torch.LongTensor] = None,
  677. logits_to_keep: Union[int, torch.Tensor] = 0,
  678. image_sizes: Optional[torch.Tensor] = None,
  679. **kwargs: Unpack[TransformersKwargs],
  680. ) -> Union[tuple, InternVLCausalLMOutputWithPast]:
  681. r"""
  682. Example:
  683. ```python
  684. >>> import torch
  685. >>> from transformers import AutoProcessor, AutoModelForImageTextToText
  686. >>> torch_device = "cuda"
  687. >>> processor = AutoProcessor.from_pretrained("OpenGVLab/InternVL3-1B-hf")
  688. >>> model = AutoModelForImageTextToText.from_pretrained(
  689. ... "OpenGVLab/InternVL3-1B-hf", dtype=torch.bfloat16, device_map=torch_device
  690. ... )
  691. >>> messages = [
  692. ... {
  693. ... "role": "user",
  694. ... "content": [
  695. ... {
  696. ... "type": "image",
  697. ... "url": "https://cdn.britannica.com/61/93061-050-99147DCE/Statue-of-Liberty-Island-New-York-Bay.jpg",
  698. ... },
  699. ... {
  700. ... "type": "image",
  701. ... "url": "https://thumbs.dreamstime.com/b/golden-gate-bridge-san-francisco-purple-flowers-california-echium-candicans-36805947.jpg",
  702. ... },
  703. ... {"type": "text", "text": "These images depict two different landmarks. Can you identify them?"},
  704. ... ],
  705. ... },
  706. ... ]
  707. >>> inputs = processor.apply_chat_template(messages, add_generation_prompt=True, tokenize=True, return_dict=True, return_tensors="pt").to(torch_device)
  708. >>> generate_ids = model.generate(**inputs, max_new_tokens=200)
  709. >>> print(processor.decode(generate_ids[0, inputs["input_ids"].shape[1] :], skip_special_tokens=True))
  710. The images depict the Statue of Liberty and the Golden Gate Bridge.
  711. ```"""
  712. vision_feature_layer = (
  713. vision_feature_layer if vision_feature_layer is not None else self.config.vision_feature_layer
  714. )
  715. vision_feature_select_strategy = (
  716. vision_feature_select_strategy
  717. if vision_feature_select_strategy is not None
  718. else self.config.vision_feature_select_strategy
  719. )
  720. outputs = self.model(
  721. input_ids=input_ids,
  722. pixel_values=pixel_values,
  723. attention_mask=attention_mask,
  724. position_ids=position_ids,
  725. past_key_values=past_key_values,
  726. inputs_embeds=inputs_embeds,
  727. vision_feature_layer=vision_feature_layer,
  728. vision_feature_select_strategy=vision_feature_select_strategy,
  729. cache_position=cache_position,
  730. image_sizes=image_sizes,
  731. **kwargs,
  732. )
  733. hidden_states = outputs[0]
  734. # Only compute necessary logits, and do not upcast them to float if we are not computing the loss
  735. slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
  736. logits = self.lm_head(hidden_states[:, slice_indices, :])
  737. loss = None
  738. if labels is not None:
  739. loss = self.loss_function(
  740. logits=logits, labels=labels, vocab_size=self.config.text_config.vocab_size, **kwargs
  741. )
  742. return InternVLCausalLMOutputWithPast(
  743. loss=loss,
  744. logits=logits,
  745. past_key_values=outputs.past_key_values,
  746. hidden_states=outputs.hidden_states,
  747. attentions=outputs.attentions,
  748. image_hidden_states=outputs.image_hidden_states,
  749. )
  750. def prepare_inputs_for_generation(
  751. self,
  752. input_ids,
  753. past_key_values=None,
  754. inputs_embeds=None,
  755. pixel_values=None,
  756. attention_mask=None,
  757. cache_position=None,
  758. logits_to_keep=None,
  759. **kwargs,
  760. ):
  761. # Overwritten -- in specific circumstances we don't want to forward image inputs to the model
  762. model_inputs = super().prepare_inputs_for_generation(
  763. input_ids,
  764. past_key_values=past_key_values,
  765. inputs_embeds=inputs_embeds,
  766. attention_mask=attention_mask,
  767. cache_position=cache_position,
  768. logits_to_keep=logits_to_keep,
  769. **kwargs,
  770. )
  771. if cache_position[0] == 0:
  772. # If we're in cached decoding stage, pixel values should be None because input ids do not contain special image token anymore
  773. # Otherwise we need pixel values to be passed to model
  774. model_inputs["pixel_values"] = pixel_values
  775. return model_inputs
  776. __all__ = [
  777. "InternVLVisionPreTrainedModel",
  778. "InternVLVisionModel",
  779. "InternVLPreTrainedModel",
  780. "InternVLModel",
  781. "InternVLForConditionalGeneration",
  782. ]