modular_internvl.py 25 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655
  1. # coding=utf-8
  2. # Copyright 2025 HuggingFace Inc. team. All rights reserved.
  3. #
  4. # Licensed under the Apache License, Version 2.0 (the "License");
  5. # you may not use this file except in compliance with the License.
  6. # You may obtain a copy of the License at
  7. #
  8. # http://www.apache.org/licenses/LICENSE-2.0
  9. #
  10. # Unless required by applicable law or agreed to in writing, software
  11. # distributed under the License is distributed on an "AS IS" BASIS,
  12. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. # See the License for the specific language governing permissions and
  14. # limitations under the License.
  15. import collections.abc
  16. from dataclasses import dataclass
  17. from typing import Callable, Optional, Union
  18. import torch
  19. import torch.nn as nn
  20. from ...activations import ACT2FN
  21. from ...cache_utils import Cache
  22. from ...modeling_layers import GradientCheckpointingLayer
  23. from ...modeling_outputs import BaseModelOutput, BaseModelOutputWithPooling
  24. from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
  25. from ...processing_utils import Unpack
  26. from ...utils import TransformersKwargs, auto_docstring, can_return_tuple, logging, torch_int
  27. from ...utils.generic import check_model_inputs
  28. from ..clip.modeling_clip import CLIPMLP
  29. from ..janus.modeling_janus import JanusVisionAttention
  30. from ..llama.modeling_llama import LlamaRMSNorm
  31. from ..llava.modeling_llava import (
  32. LlavaCausalLMOutputWithPast,
  33. LlavaForConditionalGeneration,
  34. LlavaModel,
  35. LlavaModelOutputWithPast,
  36. LlavaPreTrainedModel,
  37. )
  38. from .configuration_internvl import InternVLConfig, InternVLVisionConfig
# Module-level logger, named after this module via `__name__`.
logger = logging.get_logger(__name__)
  40. def eager_attention_forward(
  41. module: nn.Module,
  42. query: torch.Tensor,
  43. key: torch.Tensor,
  44. value: torch.Tensor,
  45. attention_mask: Optional[torch.Tensor],
  46. scaling: float,
  47. dropout: float = 0.0,
  48. **kwargs,
  49. ):
  50. key_states = key
  51. value_states = value
  52. attn_weights = torch.matmul(query, key_states.transpose(2, 3)) * scaling
  53. if attention_mask is not None:
  54. causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]
  55. attn_weights = attn_weights + causal_mask
  56. # No upcasting of the attention weights to float32 in this implementation
  57. attn_weights = nn.functional.softmax(attn_weights, dim=-1)
  58. attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training)
  59. attn_output = torch.matmul(attn_weights, value_states)
  60. attn_output = attn_output.transpose(1, 2).contiguous()
  61. return attn_output, attn_weights
  62. class InternVLVisionRMSNorm(LlamaRMSNorm):
  63. pass
  64. class InternVLVisionAttention(JanusVisionAttention):
  65. def __init__(self, config: InternVLVisionConfig):
  66. super().__init__(config)
  67. del self.num_key_value_groups
  68. # Needed for flash attention
  69. self.is_causal = False
  70. qk_norm = config.use_qk_norm
  71. self.q_norm = InternVLVisionRMSNorm(self.embed_dim) if qk_norm else nn.Identity()
  72. self.k_norm = InternVLVisionRMSNorm(self.embed_dim) if qk_norm else nn.Identity()
  73. def forward(
  74. self,
  75. hidden_states: torch.Tensor,
  76. attention_mask: Optional[torch.Tensor] = None,
  77. **kwargs: Unpack[TransformersKwargs],
  78. ):
  79. batch_size, seq_len, _ = hidden_states.size()
  80. query_states = self.q_proj(hidden_states)
  81. key_states = self.k_proj(hidden_states)
  82. value_states = self.v_proj(hidden_states)
  83. query_states = self.q_norm(query_states)
  84. key_states = self.k_norm(key_states)
  85. query_states = query_states.reshape(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2)
  86. key_states = key_states.reshape(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2)
  87. value_states = value_states.view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2)
  88. attention_interface: Callable = eager_attention_forward
  89. if self.config._attn_implementation != "eager":
  90. attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]
  91. attn_output, attn_weights = attention_interface(
  92. self,
  93. query_states,
  94. key_states,
  95. value_states,
  96. attention_mask,
  97. dropout=0.0 if not self.training else self.attention_dropout,
  98. scaling=self.scale,
  99. is_causal=False,
  100. **kwargs,
  101. )
  102. attn_output = attn_output.reshape(batch_size, seq_len, self.embed_dim)
  103. output = self.projection_layer(attn_output)
  104. output = self.projection_dropout(output)
  105. return output, attn_weights
  106. @dataclass
  107. @auto_docstring(
  108. custom_intro="""
  109. Class for outputs of [`InternVLVisionModel`].
  110. """
  111. )
  112. class InternVLVisionModelOutputWithPooling(BaseModelOutputWithPooling):
  113. r"""
  114. pooler_output (`torch.FloatTensor` of shape `(batch_size, hidden_size)`):
  115. Average of the last layer hidden states of the patch tokens (excluding the *[CLS]* token) if
  116. *config.use_mean_pooling* is set to True. If set to False, then the final hidden state of the *[CLS]* token
  117. will be returned.
  118. """
  119. class InternVLVisionPatchEmbeddings(nn.Module):
  120. """
  121. This class turns `pixel_values` of shape `(batch_size, num_channels, height, width)` into the initial
  122. `hidden_states` (patch embeddings) of shape `(batch_size, seq_length, hidden_size)` to be consumed by a
  123. Transformer.
  124. """
  125. def __init__(self, config):
  126. super().__init__()
  127. image_size, patch_size = config.image_size, config.patch_size
  128. num_channels, hidden_size = config.num_channels, config.hidden_size
  129. num_patches = (image_size[1] // patch_size[1]) * (image_size[0] // patch_size[0])
  130. patch_shape = (image_size[0] // patch_size[0], image_size[1] // patch_size[1])
  131. self.image_size = image_size
  132. self.patch_size = patch_size
  133. self.num_channels = num_channels
  134. self.num_patches = num_patches
  135. self.patch_shape = patch_shape
  136. self.projection = nn.Conv2d(num_channels, hidden_size, kernel_size=patch_size, stride=patch_size)
  137. def forward(self, pixel_values: torch.Tensor) -> torch.Tensor:
  138. batch_size, num_channels, height, width = pixel_values.shape
  139. if num_channels != self.num_channels:
  140. raise ValueError(
  141. "Make sure that the channel dimension of the pixel values match with the one set in the configuration."
  142. )
  143. embeddings = self.projection(pixel_values)
  144. patch_height, patch_width = embeddings.shape[2], embeddings.shape[3]
  145. embeddings = embeddings.flatten(2).transpose(1, 2)
  146. return embeddings, (patch_height, patch_width)
  147. # Based on timm implementation, which can be found here:
  148. # https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py
  149. class InternVLVisionEmbeddings(nn.Module):
  150. """
  151. Construct the CLS token, position and patch embeddings. Optionally, also the mask token.
  152. """
  153. def __init__(self, config: InternVLVisionConfig) -> None:
  154. super().__init__()
  155. self.cls_token = nn.Parameter(torch.zeros(1, 1, config.hidden_size))
  156. if config.use_mask_token:
  157. self.mask_token = nn.Parameter(torch.zeros(1, 1, config.hidden_size))
  158. else:
  159. self.mask_token = None
  160. self.patch_embeddings = InternVLVisionPatchEmbeddings(config)
  161. self.patch_size = config.patch_size
  162. self.image_size = (
  163. config.image_size
  164. if isinstance(config.image_size, collections.abc.Iterable)
  165. else (config.image_size, config.image_size)
  166. )
  167. num_patches = self.patch_embeddings.num_patches
  168. if config.use_absolute_position_embeddings:
  169. self.position_embeddings = nn.Parameter(torch.zeros(1, num_patches + 1, config.hidden_size))
  170. else:
  171. self.position_embeddings = None
  172. self.dropout = nn.Dropout(config.hidden_dropout_prob)
  173. def interpolate_pos_encoding(self, embeddings: torch.Tensor, height: int, width: int) -> torch.Tensor:
  174. """
  175. This method allows to interpolate the pre-trained position encodings, to be able to use the model on higher resolution
  176. images. This method is also adapted to support torch.jit tracing.
  177. Adapted from:
  178. - https://github.com/facebookresearch/dino/blob/de9ee3df6cf39fac952ab558447af1fa1365362a/vision_transformer.py#L174-L194, and
  179. - https://github.com/facebookresearch/dinov2/blob/e1277af2ba9496fbadf7aec6eba56e8d882d1e35/dinov2/models/vision_transformer.py#L179-L211
  180. """
  181. num_patches = embeddings.shape[1] - 1
  182. num_positions = self.position_embeddings.shape[1] - 1
  183. # always interpolate when tracing to ensure the exported model works for dynamic input shapes
  184. if not torch.jit.is_tracing() and num_patches == num_positions and height == width:
  185. return self.position_embeddings
  186. class_pos_embed = self.position_embeddings[:, :1]
  187. patch_pos_embed = self.position_embeddings[:, 1:]
  188. dim = embeddings.shape[-1]
  189. new_height = height // self.patch_size[0]
  190. new_width = width // self.patch_size[1]
  191. sqrt_num_positions = torch_int(num_positions**0.5)
  192. patch_pos_embed = patch_pos_embed.reshape(1, sqrt_num_positions, sqrt_num_positions, dim)
  193. patch_pos_embed = patch_pos_embed.permute(0, 3, 1, 2)
  194. patch_pos_embed = nn.functional.interpolate(
  195. patch_pos_embed,
  196. size=(new_height, new_width),
  197. mode="bicubic",
  198. align_corners=False,
  199. )
  200. patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim)
  201. return torch.cat((class_pos_embed, patch_pos_embed), dim=1)
  202. def forward(
  203. self,
  204. pixel_values: torch.Tensor,
  205. bool_masked_pos: Optional[torch.BoolTensor] = None,
  206. ) -> torch.Tensor:
  207. _, _, height, width = pixel_values.shape
  208. embeddings, (patch_height, patch_width) = self.patch_embeddings(pixel_values)
  209. batch_size, seq_len, _ = embeddings.size()
  210. if bool_masked_pos is not None:
  211. mask_tokens = self.mask_token.expand(batch_size, seq_len, -1)
  212. # replace the masked visual tokens by mask_tokens
  213. w = bool_masked_pos.unsqueeze(-1).type_as(mask_tokens)
  214. embeddings = embeddings * (1 - w) + mask_tokens * w
  215. cls_tokens = self.cls_token.expand(batch_size, -1, -1)
  216. embeddings = torch.cat((cls_tokens, embeddings), dim=1)
  217. if self.position_embeddings is not None:
  218. embeddings = embeddings + self.interpolate_pos_encoding(embeddings, height, width)
  219. embeddings = self.dropout(embeddings)
  220. return embeddings, (patch_height, patch_width)
  221. class InternVLVisionMLP(CLIPMLP):
  222. pass
  223. NORM2FN = {"layer_norm": nn.LayerNorm, "rms_norm": InternVLVisionRMSNorm}
  224. class InternVLVisionLayer(GradientCheckpointingLayer):
  225. """This corresponds to the Block class in the timm implementation."""
  226. def __init__(self, config: InternVLVisionConfig) -> None:
  227. super().__init__()
  228. self.chunk_size_feed_forward = config.chunk_size_feed_forward
  229. self.seq_len_dim = 1
  230. self.attention = InternVLVisionAttention(config)
  231. self.mlp = InternVLVisionMLP(config)
  232. # InternVL uses different layernorm implementations for different models
  233. self.layernorm_before = NORM2FN[config.norm_type](config.hidden_size, eps=config.layer_norm_eps)
  234. self.layernorm_after = NORM2FN[config.norm_type](config.hidden_size, eps=config.layer_norm_eps)
  235. init_values = config.layer_scale_init_value
  236. self.lambda_1 = nn.Parameter(init_values * torch.ones(config.hidden_size), requires_grad=True)
  237. self.lambda_2 = nn.Parameter(init_values * torch.ones(config.hidden_size), requires_grad=True)
  238. self.dropout = nn.Dropout(config.hidden_dropout_prob)
  239. def forward(
  240. self,
  241. hidden_states: torch.Tensor,
  242. ) -> Union[tuple[torch.Tensor], tuple[torch.Tensor, torch.Tensor]]:
  243. attention_output, _ = self.attention(
  244. self.layernorm_before(hidden_states), # in InternVLVision, layernorm is applied before self-attention
  245. )
  246. attention_output = self.lambda_1 * attention_output
  247. # first residual connection
  248. hidden_states = attention_output + hidden_states
  249. # in InternVLVision, layernorm is also applied after self-attention
  250. layer_output = self.layernorm_after(hidden_states)
  251. layer_output = self.mlp(layer_output)
  252. layer_output = self.dropout(layer_output)
  253. if self.lambda_2 is not None:
  254. layer_output = self.lambda_2 * layer_output
  255. # second residual connection
  256. layer_output = layer_output + hidden_states
  257. return layer_output
  258. class InternVLVisionEncoder(nn.Module):
  259. def __init__(self, config: InternVLVisionConfig) -> None:
  260. super().__init__()
  261. self.config = config
  262. self.layer = nn.ModuleList([InternVLVisionLayer(config) for i in range(config.num_hidden_layers)])
  263. self.gradient_checkpointing = False
  264. def forward(
  265. self,
  266. hidden_states: torch.Tensor,
  267. ) -> Union[tuple, BaseModelOutput]:
  268. for layer_module in self.layer:
  269. hidden_states = layer_module(hidden_states)
  270. return BaseModelOutput(
  271. last_hidden_state=hidden_states,
  272. )
  273. @auto_docstring
  274. class InternVLVisionPreTrainedModel(PreTrainedModel):
  275. config: InternVLVisionConfig
  276. base_model_prefix = "internvl_vision"
  277. main_input_name = "pixel_values"
  278. supports_gradient_checkpointing = True
  279. _no_split_modules = ["InternVLVisionLayer"]
  280. _supports_sdpa = True
  281. _supports_flash_attn = True
  282. _supports_flex_attn = True
  283. _supports_attention_backend = True
  284. _can_record_outputs = {
  285. "hidden_states": InternVLVisionLayer,
  286. "attentions": InternVLVisionAttention,
  287. }
  288. def _init_weights(self, module):
  289. """Initialize the weights"""
  290. super()._init_weights(module)
  291. if isinstance(module, InternVLVisionEmbeddings):
  292. module.cls_token.data.zero_()
  293. if module.mask_token is not None:
  294. module.mask_token.data.zero_()
  295. if module.position_embeddings is not None:
  296. module.position_embeddings.data.zero_()
  297. elif isinstance(module, InternVLVisionLayer):
  298. module.lambda_1.data.fill_(self.config.layer_scale_init_value)
  299. module.lambda_2.data.fill_(self.config.layer_scale_init_value)
  300. @auto_docstring
  301. class InternVLVisionModel(InternVLVisionPreTrainedModel):
  302. def __init__(self, config: InternVLVisionConfig) -> None:
  303. super().__init__(config)
  304. self.config = config
  305. self.embeddings = InternVLVisionEmbeddings(config)
  306. self.encoder = InternVLVisionEncoder(config)
  307. self.layernorm = (
  308. nn.Identity() if config.use_mean_pooling else nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
  309. )
  310. # Initialize weights and apply final processing
  311. self.post_init()
  312. def get_input_embeddings(self):
  313. return self.embeddings.patch_embeddings
  314. @check_model_inputs(tie_last_hidden_states=False)
  315. @auto_docstring
  316. def forward(
  317. self,
  318. pixel_values: torch.Tensor,
  319. bool_masked_pos: Optional[torch.BoolTensor] = None,
  320. ) -> Union[tuple, InternVLVisionModelOutputWithPooling]:
  321. r"""
  322. bool_masked_pos (`torch.BoolTensor` of shape `(batch_size, num_patches)`, *optional*):
  323. Boolean masked positions. Indicates which patches are masked (1) and which aren't (0).
  324. """
  325. embedding_output, _ = self.embeddings(pixel_values, bool_masked_pos=bool_masked_pos)
  326. encoder_outputs = self.encoder(embedding_output)
  327. sequence_output = encoder_outputs[0]
  328. sequence_output = self.layernorm(sequence_output)
  329. return InternVLVisionModelOutputWithPooling(
  330. last_hidden_state=sequence_output,
  331. hidden_states=encoder_outputs.hidden_states,
  332. attentions=encoder_outputs.attentions,
  333. )
# Base class for the InternVL multimodal models; inherits everything from `LlavaPreTrainedModel`.
# Body intentionally left as `pass` (no docstring) so the parent's documentation is reused.
class InternVLPreTrainedModel(LlavaPreTrainedModel):
    pass
# NOTE(review): not referenced anywhere else in this file; presumably kept as a placeholder for code that
# imports it — confirm before removing.
INTERNVL_INPUTS_DOCSTRING = None
  337. class InternVLMultiModalProjector(nn.Module):
  338. def __init__(self, config: InternVLConfig):
  339. super().__init__()
  340. self.layer_norm = nn.LayerNorm(config.vision_config.hidden_size * int(1 / config.downsample_ratio) ** 2)
  341. self.linear_1 = nn.Linear(
  342. config.vision_config.hidden_size * int(1 / config.downsample_ratio) ** 2, config.text_config.hidden_size
  343. )
  344. self.act = ACT2FN[config.projector_hidden_act]
  345. self.linear_2 = nn.Linear(config.text_config.hidden_size, config.text_config.hidden_size)
  346. def forward(self, image_features):
  347. hidden_states = self.layer_norm(image_features)
  348. hidden_states = self.linear_1(hidden_states)
  349. hidden_states = self.act(hidden_states)
  350. hidden_states = self.linear_2(hidden_states)
  351. return hidden_states
# Output container of `InternVLModel`; same fields as `LlavaModelOutputWithPast`, redeclared under the
# InternVL name. Body intentionally left as `pass` so the parent's field documentation is reused.
class InternVLModelOutputWithPast(LlavaModelOutputWithPast):
    pass
  354. class InternVLModel(LlavaModel):
  355. def pixel_shuffle(self, vision_features: torch.Tensor, scale_factor: float = 0.5):
  356. """Perform pixel shuffle downsampling on vision features.
  357. Args:
  358. vision_features (`torch.Tensor`):
  359. Input tensor of shape (batch_size, width, height, channels).
  360. scale_factor (`float`, *optional*, defaults to `0.5`):
  361. Factor by which to downsample. Default is 0.5, which halves the dimensions.
  362. Returns:
  363. vision_features (`torch.Tensor`):
  364. Downsampled tensor of shape (batch_size, height*scale_factor, width*scale_factor, channels/(scale_factor^2)).
  365. """
  366. batch_size, width, height, channels = vision_features.size()
  367. if height % scale_factor != 0 or width % scale_factor != 0:
  368. raise ValueError("Height and width must be divisible by scale_factor for proper downsampling.")
  369. # Reshape to allow downsampling
  370. vision_features = vision_features.view(
  371. batch_size, width, int(height * scale_factor), int(channels / scale_factor)
  372. )
  373. # Permute dimensions to align downsampled axis correctly
  374. vision_features = vision_features.permute(0, 2, 1, 3).contiguous()
  375. # Reshape to achieve final downsampled dimensions
  376. vision_features = vision_features.view(
  377. batch_size, int(height * scale_factor), int(width * scale_factor), int(channels / (scale_factor**2))
  378. )
  379. # Swap height and width back for proper orientation
  380. vision_features = vision_features.permute(0, 2, 1, 3).contiguous()
  381. return vision_features
  382. def get_image_features(
  383. self,
  384. pixel_values: torch.FloatTensor,
  385. vision_feature_layer: Optional[Union[int, list[int]]] = None,
  386. vision_feature_select_strategy: Optional[str] = None,
  387. **kwargs,
  388. ):
  389. """
  390. Obtains image last hidden states from the vision tower and apply multimodal projection.
  391. Args:
  392. pixel_values (`torch.FloatTensor]` of shape `(batch_size, channels, height, width)`)
  393. The tensors corresponding to the input images.
  394. vision_feature_layer (`int` or `list[int]`):
  395. Layer index or list of layer indices to extract features from.
  396. Returns:
  397. vision_features (`torch.Tensor`): Image feature tensor of shape `(num_images, image_length, embed_dim)`.
  398. """
  399. vision_feature_layer = (
  400. vision_feature_layer if vision_feature_layer is not None else self.config.vision_feature_layer
  401. )
  402. vision_feature_select_strategy = (
  403. vision_feature_select_strategy
  404. if vision_feature_select_strategy is not None
  405. else self.config.vision_feature_select_strategy
  406. )
  407. pixel_values = pixel_values.to(dtype=self.dtype) # fp16 compatibility
  408. downsample_ratio = self.config.downsample_ratio
  409. if vision_feature_layer == -1:
  410. vision_features = self.vision_tower(pixel_values=pixel_values).last_hidden_state
  411. else:
  412. vision_features = self.vision_model(pixel_values=pixel_values).hidden_states[vision_feature_layer]
  413. if vision_feature_select_strategy == "default":
  414. vision_features = vision_features[:, 1:, :]
  415. # Calculate dimensions based on vision features
  416. channels = vision_features.shape[1]
  417. feature_size = int(channels**0.5)
  418. batch_size = vision_features.shape[0]
  419. # Reshape tensor to spatial dimensions
  420. vision_features = vision_features.reshape(batch_size, feature_size, feature_size, -1)
  421. # Apply downsampling using pixel shuffle
  422. vision_features = self.pixel_shuffle(vision_features, scale_factor=downsample_ratio)
  423. # Reshape tensor to prepare for projection
  424. vision_features = vision_features.reshape(batch_size, -1, vision_features.shape[-1])
  425. # Project features through multi-modal projector
  426. vision_features = self.multi_modal_projector(vision_features)
  427. return vision_features
  428. @can_return_tuple
  429. @auto_docstring
  430. def forward(
  431. self,
  432. input_ids: Optional[torch.LongTensor] = None,
  433. pixel_values: Optional[torch.FloatTensor] = None,
  434. attention_mask: Optional[torch.Tensor] = None,
  435. position_ids: Optional[torch.LongTensor] = None,
  436. past_key_values: Optional[Cache] = None,
  437. inputs_embeds: Optional[torch.FloatTensor] = None,
  438. vision_feature_layer: Optional[Union[int, list[int]]] = None,
  439. vision_feature_select_strategy: Optional[str] = None,
  440. cache_position: Optional[torch.LongTensor] = None,
  441. **kwargs: Unpack[TransformersKwargs],
  442. ) -> Union[tuple, InternVLModelOutputWithPast]:
  443. vision_feature_layer = (
  444. vision_feature_layer if vision_feature_layer is not None else self.config.vision_feature_layer
  445. )
  446. vision_feature_select_strategy = (
  447. vision_feature_select_strategy
  448. if vision_feature_select_strategy is not None
  449. else self.config.vision_feature_select_strategy
  450. )
  451. if (input_ids is None) ^ (inputs_embeds is not None):
  452. raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
  453. if inputs_embeds is None:
  454. inputs_embeds = self.get_input_embeddings()(input_ids)
  455. if pixel_values is not None:
  456. image_features = self.get_image_features(
  457. pixel_values=pixel_values,
  458. vision_feature_layer=vision_feature_layer,
  459. vision_feature_select_strategy=vision_feature_select_strategy,
  460. )
  461. image_features = image_features.to(inputs_embeds.device, inputs_embeds.dtype)
  462. special_image_mask = self.get_placeholder_mask(
  463. input_ids, inputs_embeds=inputs_embeds, image_features=image_features
  464. )
  465. inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, image_features)
  466. outputs = self.language_model(
  467. attention_mask=attention_mask,
  468. position_ids=position_ids,
  469. past_key_values=past_key_values,
  470. inputs_embeds=inputs_embeds,
  471. cache_position=cache_position,
  472. **kwargs,
  473. )
  474. return InternVLModelOutputWithPast(
  475. last_hidden_state=outputs.last_hidden_state,
  476. past_key_values=outputs.past_key_values,
  477. hidden_states=outputs.hidden_states,
  478. attentions=outputs.attentions,
  479. image_hidden_states=image_features if pixel_values is not None else None,
  480. )
# Output container of `InternVLForConditionalGeneration`; same fields as `LlavaCausalLMOutputWithPast`.
# Body intentionally left as `pass` so the parent's field documentation is reused.
class InternVLCausalLMOutputWithPast(LlavaCausalLMOutputWithPast):
    pass
class InternVLForConditionalGeneration(LlavaForConditionalGeneration):
    # NOTE(review): `forward(**super_kwargs)` (no `self`) with a bare `super().forward(**super_kwargs)` and no
    # `return` would not work as ordinary Python. It looks like the transformers "modular" file convention for
    # "reuse the parent's forward, replacing only its docstring". Confirm against the modular converter
    # before "fixing" it.
    def forward(**super_kwargs):
        r"""
        Example:

        ```python
        >>> import torch
        >>> from transformers import AutoProcessor, AutoModelForImageTextToText

        >>> torch_device = "cuda"
        >>> processor = AutoProcessor.from_pretrained("OpenGVLab/InternVL3-1B-hf")
        >>> model = AutoModelForImageTextToText.from_pretrained(
        ...     "OpenGVLab/InternVL3-1B-hf", dtype=torch.bfloat16, device_map=torch_device
        ... )

        >>> messages = [
        ...     {
        ...         "role": "user",
        ...         "content": [
        ...             {
        ...                 "type": "image",
        ...                 "url": "https://cdn.britannica.com/61/93061-050-99147DCE/Statue-of-Liberty-Island-New-York-Bay.jpg",
        ...             },
        ...             {
        ...                 "type": "image",
        ...                 "url": "https://thumbs.dreamstime.com/b/golden-gate-bridge-san-francisco-purple-flowers-california-echium-candicans-36805947.jpg",
        ...             },
        ...             {"type": "text", "text": "These images depict two different landmarks. Can you identify them?"},
        ...         ],
        ...     },
        ... ]

        >>> inputs = processor.apply_chat_template(messages, add_generation_prompt=True, tokenize=True, return_dict=True, return_tensors="pt").to(torch_device)
        >>> generate_ids = model.generate(**inputs, max_new_tokens=200)
        >>> print(processor.decode(generate_ids[0, inputs["input_ids"].shape[1] :], skip_special_tokens=True))
        The images depict the Statue of Liberty and the Golden Gate Bridge.
        ```"""
        super().forward(**super_kwargs)
# Public names exported by this module.
__all__ = [
    "InternVLVisionPreTrainedModel",
    "InternVLVisionModel",
    "InternVLPreTrainedModel",
    "InternVLModel",
    "InternVLForConditionalGeneration",
]