modeling_idefics3.py 43 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862863864865866867868869870871872873874875876877878879880881882883884885886887888889890891892893894895896897898899900901902903904905906907908909910911912913914915916917918919920921922923924925926927928929930931932933934935936937938939940941942943944945946947948949950951952953954955956957958959960961962963964965966967968969970971972973974975
  1. # coding=utf-8
  2. # Copyright 2024 the HuggingFace Inc. team. All rights reserved.
  3. #
  4. # Licensed under the Apache License, Version 2.0 (the "License");
  5. # you may not use this file except in compliance with the License.
  6. # You may obtain a copy of the License at
  7. #
  8. # http://www.apache.org/licenses/LICENSE-2.0
  9. #
  10. # Unless required by applicable law or agreed to in writing, software
  11. # distributed under the License is distributed on an "AS IS" BASIS,
  12. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. # See the License for the specific language governing permissions and
  14. # limitations under the License.
  15. """PyTorch Idefics3 model."""
  16. from dataclasses import dataclass
  17. from typing import Callable, Optional, Union
  18. import torch
  19. from torch import nn
  20. from ...activations import ACT2FN
  21. from ...cache_utils import Cache, DynamicCache
  22. from ...generation import GenerationMixin
  23. from ...modeling_attn_mask_utils import _prepare_4d_attention_mask
  24. from ...modeling_flash_attention_utils import FlashAttentionKwargs
  25. from ...modeling_layers import GradientCheckpointingLayer
  26. from ...modeling_outputs import BaseModelOutput, ModelOutput
  27. from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
  28. from ...processing_utils import Unpack
  29. from ...utils import TransformersKwargs, auto_docstring, can_return_tuple, logging
  30. from ...utils.generic import check_model_inputs
  31. from ..auto import AutoModel
  32. from .configuration_idefics3 import Idefics3Config, Idefics3VisionConfig
  33. logger = logging.get_logger(__name__)
@dataclass
@auto_docstring(
    custom_intro="""
    Base class for Idefics3 model's outputs that may also contain a past key/values (to speed up sequential decoding).
    """
)
class Idefics3BaseModelOutputWithPast(ModelOutput):
    r"""
    last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
        Sequence of hidden-states at the output of the last layer of the model.
        If `past_key_values` is used only the last hidden-state of the sequences of shape `(batch_size, 1,
        hidden_size)` is output.
    past_key_values (`Cache`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
        It is a [`~cache_utils.Cache`] instance. For more details, see our [kv cache guide](https://huggingface.co/docs/transformers/en/kv_cache).

        Contains pre-computed hidden-states (key and values in the self-attention blocks and optionally if
        `config.is_encoder_decoder=True` in the cross-attention blocks) that can be used (see `past_key_values`
        input) to speed up sequential decoding.
    image_hidden_states (`tuple(torch.FloatTensor)`, *optional*):
        Tuple of `torch.FloatTensor` (one for the output of the image embeddings) of shape `(batch_size, num_images,
        sequence_length, hidden_size)`.

        Image hidden states of the model, produced by the vision encoder.
    """

    # Every field defaults to None; keep the declaration order stable, since ModelOutput
    # presumably derives its tuple/positional order from it (TODO confirm against ModelOutput).
    last_hidden_state: Optional[torch.FloatTensor] = None
    past_key_values: Optional[Cache] = None
    hidden_states: Optional[tuple[torch.FloatTensor]] = None
    attentions: Optional[tuple[torch.FloatTensor]] = None
    image_hidden_states: Optional[tuple[torch.FloatTensor]] = None
@dataclass
@auto_docstring(
    custom_intro="""
    Base class for Idefics causal language model (or autoregressive) outputs.
    """
)
class Idefics3CausalLMOutputWithPast(ModelOutput):
    r"""
    loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided):
        Language modeling loss (for next-token prediction).
    logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.vocab_size)`):
        Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
    past_key_values (`Cache`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
        It is a [`~cache_utils.Cache`] instance. For more details, see our [kv cache guide](https://huggingface.co/docs/transformers/en/kv_cache).

        Contains pre-computed hidden-states (key and values in the self-attention blocks) that can be used (see
        `past_key_values` input) to speed up sequential decoding.
    image_hidden_states (`tuple(torch.FloatTensor)`, *optional*):
        Tuple of `torch.FloatTensor` (one for the output of the image embeddings) of shape `(batch_size, num_images,
        sequence_length, hidden_size)`.

        Image hidden states of the model, produced by the vision encoder.
    """

    # Every field defaults to None; keep the declaration order stable, since ModelOutput
    # presumably derives its tuple/positional order from it (TODO confirm against ModelOutput).
    loss: Optional[torch.FloatTensor] = None
    logits: Optional[torch.FloatTensor] = None
    past_key_values: Optional[Cache] = None
    hidden_states: Optional[tuple[torch.FloatTensor]] = None
    attentions: Optional[tuple[torch.FloatTensor]] = None
    image_hidden_states: Optional[tuple[torch.FloatTensor]] = None
  88. # Copied from transformers.models.idefics2.modeling_idefics2.Idefics2VisionEmbeddings with Idefics2->Idefics3
  89. class Idefics3VisionEmbeddings(nn.Module):
  90. """
  91. This is a modified version of `siglip.modelign_siglip.SiglipVisionEmbeddings` to enable images of variable
  92. resolution.
  93. The modifications are adapted from [Patch n' Pack: NaViT, a Vision Transformer for any Aspect Ratio and Resolution](https://huggingface.co/papers/2307.06304)
  94. which allows treating images in their native aspect ratio and without the need to resize them to the same
  95. fixed size. In particular, we start from the original pre-trained SigLIP model
  96. (which uses images of fixed-size square images) and adapt it by training on images of variable resolutions.
  97. """
  98. def __init__(self, config: Idefics3VisionConfig):
  99. super().__init__()
  100. self.embed_dim = config.hidden_size
  101. self.image_size = config.image_size
  102. self.patch_size = config.patch_size
  103. self.patch_embedding = nn.Conv2d(
  104. in_channels=config.num_channels,
  105. out_channels=self.embed_dim,
  106. kernel_size=self.patch_size,
  107. stride=self.patch_size,
  108. padding="valid",
  109. )
  110. self.num_patches_per_side = self.image_size // self.patch_size
  111. self.num_patches = self.num_patches_per_side**2
  112. self.num_positions = self.num_patches
  113. self.position_embedding = nn.Embedding(self.num_positions, self.embed_dim)
  114. def forward(self, pixel_values: torch.FloatTensor, patch_attention_mask: torch.BoolTensor) -> torch.Tensor:
  115. batch_size, _, max_im_h, max_im_w = pixel_values.shape
  116. patch_embeds = self.patch_embedding(pixel_values)
  117. embeddings = patch_embeds.flatten(2).transpose(1, 2)
  118. max_nb_patches_h, max_nb_patches_w = max_im_h // self.patch_size, max_im_w // self.patch_size
  119. boundaries = torch.arange(
  120. 1 / self.num_patches_per_side, 1.0, 1 / self.num_patches_per_side, device=pixel_values.device
  121. )
  122. position_ids = torch.full(
  123. size=(batch_size, max_nb_patches_h * max_nb_patches_w), fill_value=0, device=pixel_values.device
  124. )
  125. for batch_idx, p_attn_mask in enumerate(patch_attention_mask):
  126. nb_patches_h = p_attn_mask[:, 0].sum()
  127. nb_patches_w = p_attn_mask[0].sum()
  128. h_indices = torch.arange(nb_patches_h, device=position_ids.device, dtype=pixel_values.dtype)
  129. w_indices = torch.arange(nb_patches_w, device=position_ids.device, dtype=pixel_values.dtype)
  130. fractional_coords_h = h_indices / nb_patches_h * (1 - 1e-6)
  131. fractional_coords_w = w_indices / nb_patches_w * (1 - 1e-6)
  132. bucket_coords_h = torch.bucketize(fractional_coords_h, boundaries, right=True)
  133. bucket_coords_w = torch.bucketize(fractional_coords_w, boundaries, right=True)
  134. pos_ids = (bucket_coords_h[:, None] * self.num_patches_per_side + bucket_coords_w).flatten()
  135. position_ids[batch_idx][p_attn_mask.view(-1)] = pos_ids
  136. embeddings = embeddings + self.position_embedding(position_ids)
  137. return embeddings
  138. # Copied from transformers.models.siglip.modeling_siglip.eager_attention_forward
  139. def eager_attention_forward(
  140. module: nn.Module,
  141. query: torch.Tensor,
  142. key: torch.Tensor,
  143. value: torch.Tensor,
  144. attention_mask: Optional[torch.Tensor],
  145. scaling: float,
  146. dropout: float = 0.0,
  147. **kwargs,
  148. ):
  149. attn_weights = torch.matmul(query, key.transpose(-1, -2)) * scaling
  150. if attention_mask is not None:
  151. attn_weights = attn_weights + attention_mask
  152. attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype)
  153. attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training)
  154. attn_output = torch.matmul(attn_weights, value)
  155. attn_output = attn_output.transpose(1, 2).contiguous()
  156. return attn_output, attn_weights
  157. # Copied from transformers.models.siglip.modeling_siglip.SiglipAttention with Siglip->Idefics3Vision
  158. class Idefics3VisionAttention(nn.Module):
  159. """Multi-headed attention from 'Attention Is All You Need' paper"""
  160. # Copied from transformers.models.clip.modeling_clip.CLIPAttention.__init__
  161. def __init__(self, config):
  162. super().__init__()
  163. self.config = config
  164. self.embed_dim = config.hidden_size
  165. self.num_heads = config.num_attention_heads
  166. self.head_dim = self.embed_dim // self.num_heads
  167. if self.head_dim * self.num_heads != self.embed_dim:
  168. raise ValueError(
  169. f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim} and `num_heads`:"
  170. f" {self.num_heads})."
  171. )
  172. self.scale = self.head_dim**-0.5
  173. self.dropout = config.attention_dropout
  174. self.k_proj = nn.Linear(self.embed_dim, self.embed_dim)
  175. self.v_proj = nn.Linear(self.embed_dim, self.embed_dim)
  176. self.q_proj = nn.Linear(self.embed_dim, self.embed_dim)
  177. self.out_proj = nn.Linear(self.embed_dim, self.embed_dim)
  178. # Ignore copy
  179. self.is_causal = False
  180. def forward(
  181. self,
  182. hidden_states: torch.Tensor,
  183. attention_mask: Optional[torch.Tensor] = None,
  184. **kwargs,
  185. ) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
  186. """Input shape: Batch x Time x Channel"""
  187. batch_size, seq_length, embed_dim = hidden_states.shape
  188. queries = self.q_proj(hidden_states)
  189. keys = self.k_proj(hidden_states)
  190. values = self.v_proj(hidden_states)
  191. queries = queries.view(batch_size, seq_length, self.num_heads, self.head_dim).transpose(1, 2)
  192. keys = keys.view(batch_size, seq_length, self.num_heads, self.head_dim).transpose(1, 2)
  193. values = values.view(batch_size, seq_length, self.num_heads, self.head_dim).transpose(1, 2)
  194. attention_interface: Callable = eager_attention_forward
  195. if self.config._attn_implementation != "eager":
  196. attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]
  197. attn_output, attn_weights = attention_interface(
  198. self,
  199. queries,
  200. keys,
  201. values,
  202. attention_mask,
  203. is_causal=self.is_causal,
  204. scaling=self.scale,
  205. dropout=0.0 if not self.training else self.dropout,
  206. )
  207. attn_output = attn_output.reshape(batch_size, seq_length, embed_dim).contiguous()
  208. attn_output = self.out_proj(attn_output)
  209. return attn_output, attn_weights
  210. # Copied from transformers.models.siglip.modeling_siglip.SiglipMLP with Siglip->Idefics3Vision
  211. class Idefics3VisionMLP(nn.Module):
  212. def __init__(self, config):
  213. super().__init__()
  214. self.config = config
  215. self.activation_fn = ACT2FN[config.hidden_act]
  216. self.fc1 = nn.Linear(config.hidden_size, config.intermediate_size)
  217. self.fc2 = nn.Linear(config.intermediate_size, config.hidden_size)
  218. def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
  219. hidden_states = self.fc1(hidden_states)
  220. hidden_states = self.activation_fn(hidden_states)
  221. hidden_states = self.fc2(hidden_states)
  222. return hidden_states
  223. class Idefics3SimpleMLP(nn.Module):
  224. def __init__(self, config):
  225. super().__init__()
  226. input_size = config.vision_config.hidden_size * (config.scale_factor**2)
  227. output_size = config.text_config.hidden_size
  228. self.proj = nn.Linear(input_size, output_size, bias=False)
  229. def forward(self, x):
  230. return self.proj(x)
  231. # Copied from transformers.models.idefics2.modeling_idefics2.Idefics2EncoderLayer with Idefics2->Idefics3
  232. class Idefics3EncoderLayer(GradientCheckpointingLayer):
  233. def __init__(self, config: Idefics3VisionConfig):
  234. super().__init__()
  235. self.embed_dim = config.hidden_size
  236. self.self_attn = Idefics3VisionAttention(config)
  237. self.layer_norm1 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps)
  238. self.mlp = Idefics3VisionMLP(config)
  239. self.layer_norm2 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps)
  240. @auto_docstring
  241. # Copied from transformers.models.siglip.modeling_siglip.SiglipEncoderLayer.forward
  242. def forward(
  243. self,
  244. hidden_states: torch.Tensor,
  245. attention_mask: torch.Tensor,
  246. **kwargs: Unpack[TransformersKwargs],
  247. ) -> torch.FloatTensor:
  248. residual = hidden_states
  249. hidden_states = self.layer_norm1(hidden_states)
  250. hidden_states, _ = self.self_attn(
  251. hidden_states=hidden_states,
  252. attention_mask=attention_mask,
  253. **kwargs,
  254. )
  255. hidden_states = residual + hidden_states
  256. residual = hidden_states
  257. hidden_states = self.layer_norm2(hidden_states)
  258. hidden_states = self.mlp(hidden_states)
  259. hidden_states = residual + hidden_states
  260. return hidden_states
  261. # Copied from transformers.models.siglip.modeling_siglip.SiglipEncoder with Siglip->Idefics3
  262. class Idefics3Encoder(nn.Module):
  263. """
  264. Transformer encoder consisting of `config.num_hidden_layers` self attention layers. Each layer is a
  265. [`Idefics3EncoderLayer`].
  266. Args:
  267. config: Idefics3Config
  268. """
  269. def __init__(self, config: Idefics3Config):
  270. super().__init__()
  271. self.config = config
  272. self.layers = nn.ModuleList([Idefics3EncoderLayer(config) for _ in range(config.num_hidden_layers)])
  273. self.gradient_checkpointing = False
  274. # Ignore copy
  275. @auto_docstring
  276. def forward(
  277. self,
  278. inputs_embeds,
  279. attention_mask: Optional[torch.Tensor] = None,
  280. ) -> Union[tuple, BaseModelOutput]:
  281. hidden_states = inputs_embeds
  282. for encoder_layer in self.layers:
  283. layer_outputs = encoder_layer(
  284. hidden_states,
  285. attention_mask,
  286. )
  287. hidden_states = layer_outputs
  288. return BaseModelOutput(last_hidden_state=hidden_states)
  289. # Copied from transformers.models.llama.modeling_llama.repeat_kv
  290. def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
  291. """
  292. This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch,
  293. num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim)
  294. """
  295. batch, num_key_value_heads, slen, head_dim = hidden_states.shape
  296. if n_rep == 1:
  297. return hidden_states
  298. hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim)
  299. return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
  300. # Copied from transformers.models.llama.modeling_llama.LlamaRMSNorm with Llama->Idefics3
  301. class Idefics3RMSNorm(nn.Module):
  302. def __init__(self, hidden_size, eps=1e-6):
  303. """
  304. Idefics3RMSNorm is equivalent to T5LayerNorm
  305. """
  306. super().__init__()
  307. self.weight = nn.Parameter(torch.ones(hidden_size))
  308. self.variance_epsilon = eps
  309. def forward(self, hidden_states):
  310. input_dtype = hidden_states.dtype
  311. hidden_states = hidden_states.to(torch.float32)
  312. variance = hidden_states.pow(2).mean(-1, keepdim=True)
  313. hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
  314. return self.weight * hidden_states.to(input_dtype)
  315. def extra_repr(self):
  316. return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}"
  317. class Idefics3Connector(nn.Module):
  318. def __init__(self, config):
  319. super().__init__()
  320. self.scale_factor = config.scale_factor
  321. self.modality_projection = Idefics3SimpleMLP(config)
  322. def pixel_shuffle(self, x, scale_factor=2):
  323. bsz, seq, embed_dim = x.size()
  324. height = width = int(seq**0.5)
  325. x = x.view(bsz, height, width, embed_dim)
  326. x = x.view(bsz, height, int(width / scale_factor), embed_dim * scale_factor)
  327. x = x.permute(0, 2, 1, 3)
  328. x = x.reshape(bsz, int(width / scale_factor), int(height / scale_factor), embed_dim * (scale_factor**2))
  329. x = x.permute(0, 2, 1, 3)
  330. x = x.reshape(bsz, int(seq / (scale_factor**2)), embed_dim * (scale_factor**2))
  331. return x
  332. def forward(self, image_hidden_states):
  333. image_hidden_states = self.pixel_shuffle(image_hidden_states, self.scale_factor)
  334. image_hidden_states = self.modality_projection(image_hidden_states)
  335. return image_hidden_states
  336. @auto_docstring
  337. class Idefics3PreTrainedModel(PreTrainedModel):
  338. config: Idefics3Config
  339. base_model_prefix = "model"
  340. supports_gradient_checkpointing = True
  341. _no_split_modules = ["Idefics3VisionAttention", "Idefics3DecoderLayer"]
  342. _skip_keys_device_placement = "past_key_values"
  343. _supports_flash_attn = True
  344. _supports_sdpa = True
  345. _supports_flex_attn = True
  346. _supports_attention_backend = True
  347. def _init_weights(self, module):
  348. std = getattr(self.config, "initializer_range", self.config.get_text_config().initializer_range)
  349. if isinstance(module, (nn.Linear, nn.Conv2d)):
  350. module.weight.data.normal_(mean=0.0, std=std)
  351. if module.bias is not None:
  352. module.bias.data.zero_()
  353. elif isinstance(module, nn.Embedding):
  354. module.weight.data.normal_(mean=0.0, std=std)
  355. if module.padding_idx is not None:
  356. module.weight.data[module.padding_idx].zero_()
  357. elif isinstance(module, nn.LayerNorm):
  358. module.weight.data.fill_(1.0)
  359. module.bias.data.zero_()
  360. elif isinstance(module, Idefics3RMSNorm):
  361. module.weight.data.fill_(1.0)
  362. @auto_docstring(
  363. custom_intro="""
  364. The Idefics3 Vision Transformer Model outputting raw image embedding.
  365. """
  366. )
  367. class Idefics3VisionTransformer(Idefics3PreTrainedModel):
  368. config: Idefics3VisionConfig
  369. _supports_sdpa = True
  370. _supports_flash_attn = True
  371. _supports_flex_attn = True
  372. _can_record_outputs = {
  373. "hidden_states": Idefics3EncoderLayer,
  374. "attentions": Idefics3VisionAttention,
  375. }
  376. def __init__(self, config: Idefics3VisionConfig):
  377. super().__init__(config)
  378. embed_dim = config.hidden_size
  379. self.embeddings = Idefics3VisionEmbeddings(config)
  380. self.encoder = Idefics3Encoder(config)
  381. self.patch_size = config.patch_size
  382. self.post_layernorm = nn.LayerNorm(embed_dim, eps=config.layer_norm_eps)
  383. # Copied from transformers.models.idefics2.modeling_idefics2.Idefics2VisionTransformer.get_input_embeddings
  384. def get_input_embeddings(self):
  385. return self.embeddings
  386. # Copied from transformers.models.idefics2.modeling_idefics2.Idefics2VisionTransformer.set_input_embeddings
  387. def set_input_embeddings(self, value):
  388. self.embeddings = value
  389. @check_model_inputs(tie_last_hidden_states=False)
  390. def forward(
  391. self,
  392. pixel_values,
  393. patch_attention_mask: Optional[torch.BoolTensor] = None,
  394. **kwargs: Unpack[TransformersKwargs],
  395. ) -> Union[tuple, BaseModelOutput]:
  396. batch_size = pixel_values.size(0)
  397. if patch_attention_mask is None:
  398. patch_size = self.patch_size
  399. patch_attention_mask = torch.ones(
  400. (
  401. batch_size,
  402. pixel_values.size(2) // patch_size,
  403. pixel_values.size(3) // patch_size,
  404. )
  405. )
  406. patch_attention_mask = patch_attention_mask.to(dtype=torch.bool, device=pixel_values.device)
  407. hidden_states = self.embeddings(pixel_values=pixel_values, patch_attention_mask=patch_attention_mask)
  408. patch_attention_mask = patch_attention_mask.view(batch_size, -1)
  409. # The call to `_upad_input` in `_flash_attention_forward` is expensive
  410. # So when the `patch_attention_mask` is full of 1s (i.e. attending to the whole sequence),
  411. # avoiding passing the attention_mask, which is equivalent to attending to the full sequence
  412. if self.config._attn_implementation != "flash_attention_2":
  413. patch_attention_mask = _prepare_4d_attention_mask(patch_attention_mask, hidden_states.dtype)
  414. elif not torch.any(~patch_attention_mask):
  415. patch_attention_mask = None
  416. encoder_outputs: BaseModelOutput = self.encoder(
  417. inputs_embeds=hidden_states,
  418. attention_mask=patch_attention_mask,
  419. )
  420. last_hidden_state = encoder_outputs.last_hidden_state
  421. last_hidden_state = self.post_layernorm(last_hidden_state)
  422. return BaseModelOutput(
  423. last_hidden_state=last_hidden_state,
  424. )
  425. @auto_docstring(
  426. custom_intro="""
  427. Idefics3 model consisting of a SIGLIP vision encoder and Llama3 language decoder
  428. """
  429. )
  430. class Idefics3Model(Idefics3PreTrainedModel):
  431. def __init__(self, config: Idefics3Config):
  432. super().__init__(config)
  433. self.padding_idx = self.config.text_config.pad_token_id
  434. self.vocab_size = self.config.text_config.vocab_size
  435. self.vision_model = Idefics3VisionTransformer._from_config(config.vision_config)
  436. self.connector = Idefics3Connector(config)
  437. self.text_model = AutoModel.from_config(config.text_config)
  438. self.image_seq_len = int(
  439. ((config.vision_config.image_size // config.vision_config.patch_size) ** 2) / (config.scale_factor**2)
  440. )
  441. self.image_token_id = self.config.image_token_id
  442. self.post_init()
  443. # Copied from transformers.models.idefics2.modeling_idefics2.Idefics2Model.enable_input_require_grads
  444. def enable_input_require_grads(self):
  445. """
  446. Enables the gradients for the input embeddings.
  447. This is useful for lora when using gradient checkpointing.
  448. c.f. https://github.com/huggingface/peft/issues/1402#issuecomment-1913675032
  449. Override to set output.requires_grad = True for both the decoder's and vision model's embeddings.
  450. """
  451. def get_lowest_module(module):
  452. if len(list(module.children())) == 0:
  453. # If the module has no children, it is a leaf module (e.g., Linear, Conv2d, etc.)
  454. return module
  455. else:
  456. # Recursively call the function on each child module
  457. return get_lowest_module(list(module.children())[0])
  458. def make_inputs_require_grads(module, input, output):
  459. output.requires_grad_(True)
  460. self._text_require_grads_hook = self.get_input_embeddings().register_forward_hook(make_inputs_require_grads)
  461. self._vision_require_grads_hook = get_lowest_module(self.vision_model).register_forward_hook(
  462. make_inputs_require_grads
  463. )
  464. # Copied from transformers.models.idefics2.modeling_idefics2.Idefics2Model.disable_input_require_grads
  465. def disable_input_require_grads(self):
  466. self._text_require_grads_hook.remove()
  467. self._vision_require_grads_hook.remove()
  468. # Copied from transformers.models.idefics2.modeling_idefics2.Idefics2Model.get_input_embeddings
  469. def get_input_embeddings(self):
  470. return self.text_model.get_input_embeddings()
  471. # Copied from transformers.models.idefics2.modeling_idefics2.Idefics2Model.set_input_embeddings
  472. def set_input_embeddings(self, value):
  473. self.text_model.set_input_embeddings(value)
  474. def inputs_merger(
  475. self,
  476. input_ids: torch.LongTensor,
  477. inputs_embeds: Optional[torch.Tensor],
  478. image_hidden_states: Optional[torch.Tensor],
  479. ):
  480. """
  481. This method aims at merging the token embeddings with the image hidden states into one single sequence of vectors that are fed to the transformer LM.
  482. The merging happens as follows:
  483. - The text token sequence is: `tok_1 tok_2 tok_3 <fake_token_around_image> <image> <image> ... <image> <fake_token_around_image> tok_4`.
  484. - We get the image hidden states for the image through the vision encoder and that hidden state, after a pixel shuffle operation, is then projected into the text embedding space.
  485. We thus have a sequence of image hidden states of size (1, image_seq_len, hidden_dim), where 1 is for batch_size of 1 image and hidden_dim is the hidden_dim of the LM transformer.
  486. - The merging happens so that we obtain the following sequence: `vector_tok_1 vector_tok_2 vector_tok_3 vector_fake_tok_around_image {sequence of image_seq_len image hidden states} vector_fake_toke_around_image vector_tok_4`. That sequence is fed to the LM.
  487. - To fit the format of that sequence, `input_ids`, `input_embeds`, `attention_mask` are all 3 adapted to insert the image hidden states.
  488. """
  489. if input_ids is None:
  490. special_image_mask = inputs_embeds == self.get_input_embeddings()(
  491. torch.tensor(self.config.image_token_id, dtype=torch.long, device=inputs_embeds.device)
  492. )
  493. special_image_mask = special_image_mask.all(-1)
  494. else:
  495. special_image_mask = input_ids == self.config.image_token_id
  496. special_image_mask = special_image_mask.unsqueeze(-1).expand_as(inputs_embeds).to(inputs_embeds.device)
  497. image_hidden_states = image_hidden_states.to(inputs_embeds.device, inputs_embeds.dtype)
  498. inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, image_hidden_states)
  499. return inputs_embeds
  500. def get_image_features(
  501. self, pixel_values: torch.FloatTensor, pixel_attention_mask: Optional[torch.LongTensor] = None
  502. ):
  503. """
  504. Encodes images into continuous embeddings that can be forwarded to the language model.
  505. Args:
  506. pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, image_size, image_size)`):
  507. The tensors corresponding to the input images.
  508. pixel_attention_mask (`torch.LongTensor`, *optional*):
  509. The attention mask indicating padded regions in the image.
  510. """
  511. batch_size, num_images, num_channels, height, width = pixel_values.shape
  512. pixel_values = pixel_values.to(dtype=self.dtype) # fp16 compatibility
  513. pixel_values = pixel_values.view(batch_size * num_images, *pixel_values.shape[2:])
  514. # Remove padding images - padding images are full 0.
  515. nb_values_per_image = pixel_values.shape[1:].numel()
  516. real_images_inds = (pixel_values == 0.0).sum(dim=(-1, -2, -3)) != nb_values_per_image
  517. pixel_values = pixel_values[real_images_inds].contiguous()
  518. # Handle the vision attention mask
  519. if pixel_attention_mask is None:
  520. pixel_attention_mask = torch.ones(
  521. size=(pixel_values.size(0), pixel_values.size(2), pixel_values.size(3)),
  522. dtype=torch.bool,
  523. device=pixel_values.device,
  524. )
  525. else:
  526. # Remove padding images from the mask
  527. pixel_attention_mask = pixel_attention_mask.view(batch_size * num_images, *pixel_attention_mask.shape[2:])
  528. pixel_attention_mask = pixel_attention_mask[real_images_inds].contiguous()
  529. patch_size = self.config.vision_config.patch_size
  530. patches_subgrid = pixel_attention_mask.unfold(dimension=1, size=patch_size, step=patch_size)
  531. patches_subgrid = patches_subgrid.unfold(dimension=2, size=patch_size, step=patch_size)
  532. patch_attention_mask = (patches_subgrid.sum(dim=(-1, -2)) > 0).bool()
  533. # Get sequence from the vision encoder
  534. image_hidden_states = self.vision_model(pixel_values=pixel_values, patch_attention_mask=patch_attention_mask)
  535. image_hidden_states.last_hidden_state
  536. # Modality projection & resampling
  537. image_hidden_states = self.connector(image_hidden_states.last_hidden_state)
  538. return image_hidden_states
  539. @can_return_tuple
  540. @auto_docstring(
  541. custom_intro="""
  542. Inputs fed to the model can have an arbitrary number of images. To account for this, pixel_values fed to
  543. the model have image padding -> (batch_size, max_num_images, 3, max_heights, max_widths) where
  544. max_num_images is the maximum number of images among the batch_size samples in the batch.
  545. Padding images are not needed beyond padding the pixel_values at the entrance of the model.
  546. For efficiency, we only pass through the vision_model's forward the real images by
  547. discarding the padding images i.e. pixel_values of size (image_batch_size, 3, height, width) where
  548. image_batch_size would be 7 when num_images_per_sample=[1, 3, 1, 2] and max_num_images would be 3.
  549. """
  550. )
  551. def forward(
  552. self,
  553. input_ids: Optional[torch.LongTensor] = None,
  554. attention_mask: Optional[torch.Tensor] = None,
  555. position_ids: Optional[torch.LongTensor] = None,
  556. past_key_values: Optional[Cache] = None,
  557. inputs_embeds: Optional[torch.FloatTensor] = None,
  558. pixel_values: Optional[torch.FloatTensor] = None,
  559. pixel_attention_mask: Optional[torch.BoolTensor] = None,
  560. image_hidden_states: Optional[torch.FloatTensor] = None,
  561. use_cache: Optional[bool] = None,
  562. output_attentions: Optional[bool] = None,
  563. output_hidden_states: Optional[bool] = None,
  564. cache_position: Optional[torch.LongTensor] = None,
  565. return_dict: Optional[bool] = None,
  566. **kwargs: Unpack[FlashAttentionKwargs],
  567. ) -> Union[tuple, Idefics3BaseModelOutputWithPast]:
  568. r"""
  569. pixel_attention_mask (`torch.Tensor` of shape `(batch_size, image_size, image_size)`, *optional*):
  570. Mask to avoid performing attention on padding pixel indices.
  571. image_hidden_states (`torch.FloatTensor` of shape `(batch_size, num_channels, image_size, image_size)`):
  572. The hidden states of the image encoder after modality projection.
  573. """
  574. output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
  575. output_hidden_states = (
  576. output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
  577. )
  578. use_cache = use_cache if use_cache is not None else self.config.use_cache
  579. return_dict = return_dict if return_dict is not None else self.config.use_return_dict
  580. if self.training and self.text_model.gradient_checkpointing and use_cache:
  581. logger.warning_once(
  582. "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
  583. )
  584. use_cache = False
  585. # retrieve input_ids and inputs_embeds
  586. if input_ids is not None:
  587. batch_size, seq_length = input_ids.shape
  588. elif inputs_embeds is not None:
  589. batch_size, seq_length, _ = inputs_embeds.shape
  590. else:
  591. raise ValueError("You have to specify either input_ids or inputs_embeds")
  592. if use_cache and past_key_values is None:
  593. past_key_values = DynamicCache(config=self.config)
  594. if inputs_embeds is None:
  595. inputs_embeds = self.text_model.get_input_embeddings()(input_ids).to(self.device)
  596. # START VISUAL INPUTS INTEGRATION
  597. if pixel_values is not None and image_hidden_states is not None:
  598. raise ValueError("You cannot specify both pixel_values and image_hidden_states at the same time")
  599. elif pixel_values is not None:
  600. image_hidden_states = self.get_image_features(pixel_values, pixel_attention_mask)
  601. elif image_hidden_states is not None:
  602. image_hidden_states = image_hidden_states.to(dtype=self.dtype, device=input_ids.device)
  603. if image_hidden_states is not None:
  604. # When we generate, we don't want to replace the potential image_token_id that we generated by images
  605. # that simply don't exist
  606. inputs_embeds = self.inputs_merger(
  607. input_ids=input_ids,
  608. inputs_embeds=inputs_embeds,
  609. image_hidden_states=image_hidden_states,
  610. )
  611. outputs = self.text_model(
  612. inputs_embeds=inputs_embeds,
  613. attention_mask=attention_mask,
  614. position_ids=position_ids,
  615. past_key_values=past_key_values,
  616. use_cache=use_cache,
  617. output_attentions=output_attentions,
  618. output_hidden_states=output_hidden_states,
  619. cache_position=cache_position,
  620. return_dict=True,
  621. **kwargs,
  622. )
  623. return Idefics3BaseModelOutputWithPast(
  624. last_hidden_state=outputs.last_hidden_state,
  625. past_key_values=outputs.past_key_values,
  626. hidden_states=outputs.hidden_states,
  627. attentions=outputs.attentions,
  628. image_hidden_states=image_hidden_states,
  629. )
  630. @auto_docstring(
  631. custom_intro="""
  632. The Idefics3 Model with a language modeling head. It is made up a SigLIP vision encoder, with a language modeling head on top.
  633. """
  634. )
  635. class Idefics3ForConditionalGeneration(Idefics3PreTrainedModel, GenerationMixin):
  636. _tied_weights_keys = ["lm_head.weight"]
  637. # Copied from transformers.models.idefics2.modeling_idefics2.Idefics2ForConditionalGeneration.__init__ with Idefics2->Idefics3
  638. def __init__(self, config):
  639. super().__init__(config)
  640. self.model = Idefics3Model(config)
  641. self.image_token_id = self.config.image_token_id
  642. self.lm_head = nn.Linear(config.text_config.hidden_size, config.text_config.vocab_size, bias=False)
  643. self.vocab_size = config.text_config.vocab_size
  644. # Initialize weights and apply final processing
  645. self.post_init()
  646. # Copied from transformers.models.idefics2.modeling_idefics2.Idefics2ForConditionalGeneration.enable_input_require_grads
  647. def enable_input_require_grads(self):
  648. """
  649. Enables the gradients for the input embeddings. This is useful for fine-tuning adapter weights while keeping
  650. the model weights fixed.
  651. """
  652. def make_inputs_require_grads(module, input, output):
  653. output.requires_grad_(True)
  654. self._text_require_grads_hook = self.get_input_embeddings().register_forward_hook(make_inputs_require_grads)
  655. self._vision_require_grads_hook = self.model.vision_model.get_input_embeddings().register_forward_hook(
  656. make_inputs_require_grads
  657. )
  658. # Copied from transformers.models.idefics2.modeling_idefics2.Idefics2ForConditionalGeneration.disable_input_require_grads
  659. def disable_input_require_grads(self):
  660. self._text_require_grads_hook.remove()
  661. self._vision_require_grads_hook.remove()
  662. # Copied from transformers.models.idefics2.modeling_idefics2.Idefics2ForConditionalGeneration.get_input_embeddings
  663. def get_input_embeddings(self):
  664. return self.model.text_model.get_input_embeddings()
  665. # Copied from transformers.models.idefics2.modeling_idefics2.Idefics2ForConditionalGeneration.set_input_embeddings
  666. def set_input_embeddings(self, value):
  667. self.model.text_model.set_input_embeddings(value)
  668. def get_image_features(
  669. self, pixel_values: torch.FloatTensor, pixel_attention_mask: Optional[torch.LongTensor] = None
  670. ):
  671. return self.model.get_image_features(pixel_values=pixel_values, pixel_attention_mask=pixel_attention_mask)
  672. @can_return_tuple
  673. @auto_docstring
  674. def forward(
  675. self,
  676. input_ids: Optional[torch.LongTensor] = None,
  677. attention_mask: Optional[torch.Tensor] = None,
  678. position_ids: Optional[torch.LongTensor] = None,
  679. past_key_values: Optional[Cache] = None,
  680. inputs_embeds: Optional[torch.FloatTensor] = None,
  681. pixel_values: Optional[torch.FloatTensor] = None,
  682. pixel_attention_mask: Optional[torch.BoolTensor] = None,
  683. image_hidden_states: Optional[torch.FloatTensor] = None,
  684. labels: Optional[torch.LongTensor] = None,
  685. use_cache: Optional[bool] = None,
  686. output_attentions: Optional[bool] = None,
  687. output_hidden_states: Optional[bool] = None,
  688. cache_position: Optional[torch.LongTensor] = None,
  689. return_dict: Optional[bool] = None,
  690. logits_to_keep: Union[int, torch.Tensor] = 0,
  691. **kwargs: Unpack[TransformersKwargs],
  692. ) -> Union[tuple, Idefics3CausalLMOutputWithPast]:
  693. r"""
  694. pixel_attention_mask (`torch.Tensor` of shape `(batch_size, image_size, image_size)`, *optional*):
  695. Mask to avoid performing attention on padding pixel indices.
  696. image_hidden_states (`torch.FloatTensor` of shape `(batch_size, num_channels, image_size, image_size)`):
  697. The hidden states of the image encoder after modality projection.
  698. labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
  699. Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
  700. config.vocab_size]` or `model.image_token_id` (where `model` is your instance of `Idefics3ForConditionalGeneration`).
  701. Tokens with indices set to `model.image_token_id` are ignored (masked), the loss is only
  702. computed for the tokens with labels in `[0, ..., config.vocab_size]`.
  703. Example:
  704. ```python
  705. >>> import requests
  706. >>> import torch
  707. >>> from PIL import Image
  708. >>> from io import BytesIO
  709. >>> from transformers import AutoProcessor, AutoModelForVision2Seq
  710. >>> from transformers.image_utils import load_image
  711. >>> # Note that passing the image urls (instead of the actual pil images) to the processor is also possible
  712. >>> image1 = load_image("https://cdn.britannica.com/61/93061-050-99147DCE/Statue-of-Liberty-Island-New-York-Bay.jpg")
  713. >>> image2 = load_image("https://cdn.britannica.com/59/94459-050-DBA42467/Skyline-Chicago.jpg")
  714. >>> image3 = load_image("https://cdn.britannica.com/68/170868-050-8DDE8263/Golden-Gate-Bridge-San-Francisco.jpg")
  715. >>> processor = AutoProcessor.from_pretrained("HuggingFaceM4/Idefics3-8B-Llama3")
  716. >>> model = AutoModelForVision2Seq.from_pretrained("HuggingFaceM4/Idefics3-8B-Llama3", dtype=torch.bfloat16, device_map="auto")
  717. >>> # Create inputs
  718. >>> messages = [
  719. ... {
  720. ... "role": "user",
  721. ... "content": [
  722. ... {"type": "image"},
  723. ... {"type": "text", "text": "In this image, we can see the city of New York, and more specifically the Statue of Liberty."},
  724. ... {"type": "image"},
  725. ... {"type": "text", "text": "What can we see in this image?"},
  726. ... ]
  727. ... },
  728. ... {
  729. ... "role": "user",
  730. ... "content": [
  731. ... {"type": "image"},
  732. ... {"type": "text", "text": "In which city is that bridge located?"},
  733. ... ]
  734. ... }
  735. ... ]
  736. >>> prompts = [processor.apply_chat_template([message], add_generation_prompt=True) for message in messages]
  737. >>> images = [[image1, image2], [image3]]
  738. >>> inputs = processor(text=prompts, images=images, padding=True, return_tensors="pt").to(model.device)
  739. >>> # Generate
  740. >>> generated_ids = model.generate(**inputs, max_new_tokens=256)
  741. >>> generated_texts = processor.batch_decode(generated_ids, skip_special_tokens=True)
  742. >>> print(generated_texts[0])
  743. Assistant: There are buildings, trees, lights, and water visible in this image.
  744. >>> print(generated_texts[1])
  745. Assistant: The bridge is in San Francisco.
  746. ```"""
  747. output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
  748. output_hidden_states = (
  749. output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
  750. )
  751. return_dict = return_dict if return_dict is not None else self.config.use_return_dict
  752. # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
  753. outputs = self.model(
  754. input_ids=input_ids,
  755. attention_mask=attention_mask,
  756. position_ids=position_ids,
  757. past_key_values=past_key_values,
  758. inputs_embeds=inputs_embeds,
  759. pixel_values=pixel_values,
  760. pixel_attention_mask=pixel_attention_mask,
  761. image_hidden_states=image_hidden_states,
  762. use_cache=use_cache,
  763. output_attentions=output_attentions,
  764. output_hidden_states=output_hidden_states,
  765. cache_position=cache_position,
  766. return_dict=True,
  767. **kwargs,
  768. )
  769. hidden_states = outputs[0]
  770. # Only compute necessary logits, and do not upcast them to float if we are not computing the loss
  771. slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
  772. logits = self.lm_head(hidden_states[:, slice_indices, :])
  773. loss = None
  774. if labels is not None:
  775. loss = self.loss_function(
  776. logits=logits, labels=labels, vocab_size=self.config.text_config.vocab_size, **kwargs
  777. )
  778. return Idefics3CausalLMOutputWithPast(
  779. loss=loss,
  780. logits=logits,
  781. past_key_values=outputs.past_key_values,
  782. hidden_states=outputs.hidden_states,
  783. attentions=outputs.attentions,
  784. image_hidden_states=outputs.image_hidden_states,
  785. )
  786. # Copied from transformers.models.idefics2.modeling_idefics2.Idefics2ForConditionalGeneration.prepare_inputs_for_generation
  787. def prepare_inputs_for_generation(
  788. self,
  789. input_ids,
  790. past_key_values=None,
  791. attention_mask=None,
  792. inputs_embeds=None,
  793. cache_position=None,
  794. pixel_values=None,
  795. pixel_attention_mask=None,
  796. image_hidden_states=None,
  797. logits_to_keep=None,
  798. **kwargs,
  799. ):
  800. # Overwritten -- there are mutually exclusive inputs (if the logic to make `image_hidden_states` take
  801. # precedence is moved to the model, we can remove this fn)
  802. model_inputs = super().prepare_inputs_for_generation(
  803. input_ids,
  804. past_key_values=past_key_values,
  805. attention_mask=attention_mask,
  806. inputs_embeds=inputs_embeds,
  807. cache_position=cache_position,
  808. pixel_values=pixel_values,
  809. pixel_attention_mask=pixel_attention_mask,
  810. image_hidden_states=image_hidden_states,
  811. logits_to_keep=logits_to_keep,
  812. **kwargs,
  813. )
  814. if image_hidden_states is not None or cache_position[0] != 0:
  815. model_inputs["pixel_values"] = None
  816. model_inputs["pixel_attention_mask"] = None
  817. return model_inputs
  818. __all__ = ["Idefics3ForConditionalGeneration", "Idefics3PreTrainedModel", "Idefics3Model", "Idefics3VisionTransformer"]