modeling_idefics2.py 52 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
370470570670770870971071171271371471571671771871972072172272372472572672772872973073173273373473573673773873974074174274374474574674774874975075175275375475575675775875976076176276376476576676776876977077177277377477577677777877978078178278378478578678778878979079179279379479579679779879980080180280380480580680780880981081181281381481581681781881982082182282382482582682782882983083183283383483583683783883984084184284384484584684784884985085185285385485585685785885986086186286386486586686786886987087187287387487587687787887988088188288388488588688788888989089189289389489589689789889990090190290390490590690790890991091191291391491591691791891992092192292392492592692792892993093193293393493593693793893994094194294394494594694794894995095195295395495595695795895996096196296396496596696796896997097197297397497597697797897998098198298398498598698798898999099199299399499599699799899910001001100210031004100510061007100810091010101110121013101410151016101710181019102010211022102310241025102610271028102910301031103210331034103510361037103810391040104110421043104410451046104710481049105010511052105310541055105610571058105910601061106210631064106510661067106810691070107110721073107410751076107710781079108010811082108310841085108610871088108910901091109210931094109510961097109810991100110111021103110411051106110711081109111011111112111311141115111611171118111911201121112211231124112511261127112811291130113111321133113411351136113711381139114011411142114311441145114611471148114911501151115211531154115511561157115811591160116111621163116411651166116711681169117011711172117311741175117611771178117911801181118211831184118511861187118811891190119111921193119411951196
  1. # coding=utf-8
  2. # Copyright 2024 the HuggingFace Inc. team. All rights reserved.
  3. #
  4. # Licensed under the Apache License, Version 2.0 (the "License");
  5. # you may not use this file except in compliance with the License.
  6. # You may obtain a copy of the License at
  7. #
  8. # http://www.apache.org/licenses/LICENSE-2.0
  9. #
  10. # Unless required by applicable law or agreed to in writing, software
  11. # distributed under the License is distributed on an "AS IS" BASIS,
  12. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. # See the License for the specific language governing permissions and
  14. # limitations under the License.
  15. """PyTorch Idefics2 model."""
  16. from dataclasses import dataclass
  17. from typing import Callable, Optional, Union
  18. import torch
  19. from torch import nn
  20. from ...activations import ACT2FN
  21. from ...cache_utils import Cache, DynamicCache
  22. from ...generation import GenerationMixin
  23. from ...modeling_attn_mask_utils import _prepare_4d_attention_mask
  24. from ...modeling_flash_attention_utils import FlashAttentionKwargs
  25. from ...modeling_layers import GradientCheckpointingLayer
  26. from ...modeling_outputs import BaseModelOutput, ModelOutput
  27. from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
  28. from ...processing_utils import Unpack
  29. from ...utils import TransformersKwargs, auto_docstring, can_return_tuple, logging
  30. from ...utils.deprecation import deprecate_kwarg
  31. from ...utils.generic import check_model_inputs
  32. from ..auto import AutoModel
  33. from .configuration_idefics2 import Idefics2Config, Idefics2PerceiverConfig, Idefics2VisionConfig
# Module-level logger, obtained through the library's `logging.get_logger` helper.
logger = logging.get_logger(__name__)
@dataclass
@auto_docstring(
    custom_intro="""
    Base class for Idefics2 model's outputs that may also contain a past key/values (to speed up sequential decoding).
    """
)
class Idefics2BaseModelOutputWithPast(ModelOutput):
    r"""
    last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
        Sequence of hidden-states at the output of the last layer of the model.
        If `past_key_values` is used only the last hidden-state of the sequences of shape `(batch_size, 1,
        hidden_size)` is output.
    past_key_values (`Cache`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
        It is a [`~cache_utils.Cache`] instance. For more details, see our [kv cache guide](https://huggingface.co/docs/transformers/en/kv_cache).

        Contains pre-computed hidden-states (key and values in the self-attention blocks and optionally if
        `config.is_encoder_decoder=True` in the cross-attention blocks) that can be used (see `past_key_values`
        input) to speed up sequential decoding.
    image_hidden_states (`tuple(torch.FloatTensor)`, *optional*):
        Tuple of `torch.FloatTensor` (one for the output of the image embeddings, `(batch_size, num_images,
        sequence_length, hidden_size)`).

        image_hidden_states of the model produced by the vision encoder, and optionally by the perceiver.
    """

    last_hidden_state: Optional[torch.FloatTensor] = None
    past_key_values: Optional[Cache] = None
    # NOTE(review): `hidden_states` and `attentions` are not described in the docstring above; presumably
    # `auto_docstring` supplies the standard descriptions for them — TODO confirm.
    hidden_states: Optional[tuple[torch.FloatTensor]] = None
    attentions: Optional[tuple[torch.FloatTensor]] = None
    image_hidden_states: Optional[tuple[torch.FloatTensor]] = None
@dataclass
@auto_docstring(
    custom_intro="""
    Base class for Idefics2 causal language model (or autoregressive) outputs.
    """
)
# Copied from transformers.models.idefics.modeling_idefics.IdeficsCausalLMOutputWithPast with Idefics->Idefics2
class Idefics2CausalLMOutputWithPast(ModelOutput):
    r"""
    loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided):
        Language modeling loss (for next-token prediction).
    logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.vocab_size)`):
        Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
    past_key_values (`Cache`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
        It is a [`~cache_utils.Cache`] instance. For more details, see our [kv cache guide](https://huggingface.co/docs/transformers/en/kv_cache).

        Contains pre-computed hidden-states (key and values in the self-attention blocks) that can be used (see
        `past_key_values` input) to speed up sequential decoding.
    image_hidden_states (`tuple(torch.FloatTensor)`, *optional*):
        Tuple of `torch.FloatTensor` (one for the output of the image embeddings, `(batch_size, num_images,
        sequence_length, hidden_size)`).

        image_hidden_states of the model produced by the vision encoder, and optionally by the perceiver.
    """

    loss: Optional[torch.FloatTensor] = None
    logits: Optional[torch.FloatTensor] = None
    past_key_values: Optional[Cache] = None
    # NOTE(review): `hidden_states` and `attentions` are not described in the docstring above; presumably
    # `auto_docstring` supplies the standard descriptions for them — TODO confirm.
    hidden_states: Optional[tuple[torch.FloatTensor]] = None
    attentions: Optional[tuple[torch.FloatTensor]] = None
    image_hidden_states: Optional[tuple[torch.FloatTensor]] = None
class Idefics2VisionEmbeddings(nn.Module):
    """
    This is a modified version of `siglip.modeling_siglip.SiglipVisionEmbeddings` to enable images of variable
    resolution.

    The modifications are adapted from [Patch n' Pack: NaViT, a Vision Transformer for any Aspect Ratio and Resolution](https://huggingface.co/papers/2307.06304)
    which allows treating images in their native aspect ratio and without the need to resize them to the same
    fixed size. In particular, we start from the original pre-trained SigLIP model
    (which uses images of fixed-size square images) and adapt it by training on images of variable resolutions.

    Patch embeddings come from a non-overlapping `Conv2d` (kernel size == stride == `patch_size`). Position
    embeddings come from a learned table sized for a square reference grid of
    `(image_size // patch_size) ** 2` positions. The valid patches of each image are mapped onto that grid by
    bucketing their fractional (row, column) coordinates (see `forward`).
    """

    def __init__(self, config: Idefics2VisionConfig):
        super().__init__()
        self.embed_dim = config.hidden_size
        self.image_size = config.image_size
        self.patch_size = config.patch_size

        # Non-overlapping patchification: kernel_size == stride == patch_size, no padding.
        self.patch_embedding = nn.Conv2d(
            in_channels=config.num_channels,
            out_channels=self.embed_dim,
            kernel_size=self.patch_size,
            stride=self.patch_size,
            padding="valid",
        )

        # Side length (in patches) of the square reference grid the position table is built for.
        self.num_patches_per_side = self.image_size // self.patch_size
        self.num_patches = self.num_patches_per_side**2
        self.num_positions = self.num_patches
        self.position_embedding = nn.Embedding(self.num_positions, self.embed_dim)

    def forward(self, pixel_values: torch.FloatTensor, patch_attention_mask: torch.BoolTensor) -> torch.Tensor:
        """
        Embed a batch of (possibly padded) images.

        Args:
            pixel_values: 4D tensor unpacked as `(batch_size, num_channels, max_im_h, max_im_w)`.
            patch_attention_mask: boolean mask holding one `(max_im_h // patch_size, max_im_w // patch_size)`
                grid per image; True marks a real (non-padding) patch. NOTE(review): the per-image logic below
                counts valid rows from the first column and valid columns from the first row. It therefore
                assumes the True entries form a rectangle anchored at the top-left corner — confirm with callers.

        Returns:
            Tensor of shape `(batch_size, max_nb_patches_h * max_nb_patches_w, embed_dim)` holding the patch
            embeddings plus the position embeddings. Padding patches receive position id 0.
        """
        batch_size, _, max_im_h, max_im_w = pixel_values.shape

        patch_embeds = self.patch_embedding(pixel_values)
        # (batch, embed_dim, h_patches, w_patches) -> (batch, h_patches * w_patches, embed_dim), row-major order.
        embeddings = patch_embeds.flatten(2).transpose(1, 2)

        max_nb_patches_h, max_nb_patches_w = max_im_h // self.patch_size, max_im_w // self.patch_size
        # Interior bucket edges 1/n, 2/n, ..., (n-1)/n on [0, 1), where n = num_patches_per_side.
        boundaries = torch.arange(
            1 / self.num_patches_per_side, 1.0, 1 / self.num_patches_per_side, device=pixel_values.device
        )
        # Every slot starts at position id 0; only the slots of valid patches are overwritten below.
        position_ids = torch.full(
            size=(batch_size, max_nb_patches_h * max_nb_patches_w), fill_value=0, device=pixel_values.device
        )

        for batch_idx, p_attn_mask in enumerate(patch_attention_mask):
            # Number of valid patch rows / columns, read from the first column / first row of the mask.
            nb_patches_h = p_attn_mask[:, 0].sum()
            nb_patches_w = p_attn_mask[0].sum()

            h_indices = torch.arange(nb_patches_h, device=position_ids.device, dtype=pixel_values.dtype)
            w_indices = torch.arange(nb_patches_w, device=position_ids.device, dtype=pixel_values.dtype)

            # Relative coordinates in [0, 1). The (1 - 1e-6) factor appears intended to push a value that lands
            # exactly on a bucket edge just below it. In reduced-precision dtypes that nudge may be lost to rounding.
            fractional_coords_h = h_indices / nb_patches_h * (1 - 1e-6)
            fractional_coords_w = w_indices / nb_patches_w * (1 - 1e-6)

            # Map each coordinate to a row/column index in [0, num_patches_per_side - 1] of the reference grid.
            bucket_coords_h = torch.bucketize(fractional_coords_h, boundaries, right=True)
            bucket_coords_w = torch.bucketize(fractional_coords_w, boundaries, right=True)

            # Row-major (row, col) -> flat index into the position-embedding table.
            pos_ids = (bucket_coords_h[:, None] * self.num_patches_per_side + bucket_coords_w).flatten()
            # Scatter into the valid-patch slots. The mask must hold exactly nb_patches_h * nb_patches_w True entries.
            position_ids[batch_idx][p_attn_mask.view(-1)] = pos_ids

        embeddings = embeddings + self.position_embedding(position_ids)
        return embeddings
  139. def eager_attention_forward(
  140. module: nn.Module,
  141. query: torch.Tensor,
  142. key: torch.Tensor,
  143. value: torch.Tensor,
  144. attention_mask: Optional[torch.Tensor],
  145. scaling: float,
  146. dropout: float = 0.0,
  147. **kwargs,
  148. ):
  149. if hasattr(module, "num_key_value_groups"):
  150. key = repeat_kv(key, module.num_key_value_groups)
  151. value = repeat_kv(value, module.num_key_value_groups)
  152. attn_weights = torch.matmul(query, key.transpose(2, 3)) * scaling
  153. if attention_mask is not None:
  154. causal_mask = attention_mask[:, :, :, : key.shape[-2]]
  155. attn_weights = attn_weights + causal_mask
  156. attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype)
  157. attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training)
  158. attn_output = torch.matmul(attn_weights, value)
  159. attn_output = attn_output.transpose(1, 2).contiguous()
  160. return attn_output, attn_weights
  161. # Copied from transformers.models.siglip.modeling_siglip.SiglipAttention with Siglip->Idefics2Vision
  162. class Idefics2VisionAttention(nn.Module):
  163. """Multi-headed attention from 'Attention Is All You Need' paper"""
  164. # Copied from transformers.models.clip.modeling_clip.CLIPAttention.__init__
  165. def __init__(self, config):
  166. super().__init__()
  167. self.config = config
  168. self.embed_dim = config.hidden_size
  169. self.num_heads = config.num_attention_heads
  170. self.head_dim = self.embed_dim // self.num_heads
  171. if self.head_dim * self.num_heads != self.embed_dim:
  172. raise ValueError(
  173. f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim} and `num_heads`:"
  174. f" {self.num_heads})."
  175. )
  176. self.scale = self.head_dim**-0.5
  177. self.dropout = config.attention_dropout
  178. self.k_proj = nn.Linear(self.embed_dim, self.embed_dim)
  179. self.v_proj = nn.Linear(self.embed_dim, self.embed_dim)
  180. self.q_proj = nn.Linear(self.embed_dim, self.embed_dim)
  181. self.out_proj = nn.Linear(self.embed_dim, self.embed_dim)
  182. # Ignore copy
  183. self.is_causal = False
  184. def forward(
  185. self,
  186. hidden_states: torch.Tensor,
  187. attention_mask: Optional[torch.Tensor] = None,
  188. **kwargs,
  189. ) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
  190. """Input shape: Batch x Time x Channel"""
  191. batch_size, seq_length, embed_dim = hidden_states.shape
  192. queries = self.q_proj(hidden_states)
  193. keys = self.k_proj(hidden_states)
  194. values = self.v_proj(hidden_states)
  195. queries = queries.view(batch_size, seq_length, self.num_heads, self.head_dim).transpose(1, 2)
  196. keys = keys.view(batch_size, seq_length, self.num_heads, self.head_dim).transpose(1, 2)
  197. values = values.view(batch_size, seq_length, self.num_heads, self.head_dim).transpose(1, 2)
  198. attention_interface: Callable = eager_attention_forward
  199. if self.config._attn_implementation != "eager":
  200. attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]
  201. attn_output, attn_weights = attention_interface(
  202. self,
  203. queries,
  204. keys,
  205. values,
  206. attention_mask,
  207. is_causal=self.is_causal,
  208. scaling=self.scale,
  209. dropout=0.0 if not self.training else self.dropout,
  210. )
  211. attn_output = attn_output.reshape(batch_size, seq_length, embed_dim).contiguous()
  212. attn_output = self.out_proj(attn_output)
  213. return attn_output, attn_weights
  214. # Copied from transformers.models.siglip.modeling_siglip.SiglipMLP with Siglip->Idefics2Vision
  215. class Idefics2VisionMLP(nn.Module):
  216. def __init__(self, config):
  217. super().__init__()
  218. self.config = config
  219. self.activation_fn = ACT2FN[config.hidden_act]
  220. self.fc1 = nn.Linear(config.hidden_size, config.intermediate_size)
  221. self.fc2 = nn.Linear(config.intermediate_size, config.hidden_size)
  222. def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
  223. hidden_states = self.fc1(hidden_states)
  224. hidden_states = self.activation_fn(hidden_states)
  225. hidden_states = self.fc2(hidden_states)
  226. return hidden_states
  227. class Idefics2MLP(nn.Module):
  228. def __init__(
  229. self,
  230. hidden_size: int,
  231. intermediate_size: int,
  232. output_size: int,
  233. hidden_act: str,
  234. ):
  235. super().__init__()
  236. self.gate_proj = nn.Linear(hidden_size, intermediate_size, bias=False)
  237. self.up_proj = nn.Linear(hidden_size, intermediate_size, bias=False)
  238. self.down_proj = nn.Linear(intermediate_size, output_size, bias=False)
  239. self.act_fn = ACT2FN[hidden_act]
  240. def forward(self, x):
  241. return self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))
  242. # Copied from transformers.models.siglip.modeling_siglip.SiglipMultiheadAttentionPoolingHead with Siglip->Idefics2
  243. class Idefics2MultiheadAttentionPoolingHead(nn.Module):
  244. """Multihead Attention Pooling."""
  245. def __init__(self, config: Idefics2VisionConfig):
  246. super().__init__()
  247. self.probe = nn.Parameter(torch.randn(1, 1, config.hidden_size))
  248. self.attention = torch.nn.MultiheadAttention(config.hidden_size, config.num_attention_heads, batch_first=True)
  249. self.layernorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
  250. # Ignore copy
  251. self.mlp = Idefics2MLP(
  252. hidden_size=config.hidden_size,
  253. intermediate_size=config.intermediate_size,
  254. hidden_act=config.hidden_act,
  255. output_size=config.hidden_size,
  256. )
  257. def forward(self, hidden_state):
  258. batch_size = hidden_state.shape[0]
  259. probe = self.probe.repeat(batch_size, 1, 1)
  260. hidden_state = self.attention(probe, hidden_state, hidden_state)[0]
  261. residual = hidden_state
  262. hidden_state = self.layernorm(hidden_state)
  263. hidden_state = residual + self.mlp(hidden_state)
  264. return hidden_state[:, 0]
  265. class Idefics2EncoderLayer(GradientCheckpointingLayer):
  266. def __init__(self, config: Idefics2VisionConfig):
  267. super().__init__()
  268. self.embed_dim = config.hidden_size
  269. self.self_attn = Idefics2VisionAttention(config)
  270. self.layer_norm1 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps)
  271. self.mlp = Idefics2VisionMLP(config)
  272. self.layer_norm2 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps)
  273. @auto_docstring
  274. # Copied from transformers.models.siglip.modeling_siglip.SiglipEncoderLayer.forward
  275. def forward(
  276. self,
  277. hidden_states: torch.Tensor,
  278. attention_mask: torch.Tensor,
  279. **kwargs: Unpack[TransformersKwargs],
  280. ) -> torch.FloatTensor:
  281. residual = hidden_states
  282. hidden_states = self.layer_norm1(hidden_states)
  283. hidden_states, _ = self.self_attn(
  284. hidden_states=hidden_states,
  285. attention_mask=attention_mask,
  286. **kwargs,
  287. )
  288. hidden_states = residual + hidden_states
  289. residual = hidden_states
  290. hidden_states = self.layer_norm2(hidden_states)
  291. hidden_states = self.mlp(hidden_states)
  292. hidden_states = residual + hidden_states
  293. return hidden_states
  294. # Copied from transformers.models.siglip.modeling_siglip.SiglipEncoder with Siglip->Idefics2
  295. class Idefics2Encoder(nn.Module):
  296. """
  297. Transformer encoder consisting of `config.num_hidden_layers` self attention layers. Each layer is a
  298. [`Idefics2EncoderLayer`].
  299. Args:
  300. config: Idefics2Config
  301. """
  302. def __init__(self, config: Idefics2Config):
  303. super().__init__()
  304. self.config = config
  305. self.layers = nn.ModuleList([Idefics2EncoderLayer(config) for _ in range(config.num_hidden_layers)])
  306. self.gradient_checkpointing = False
  307. # Ignore copy
  308. @auto_docstring
  309. def forward(
  310. self,
  311. inputs_embeds,
  312. attention_mask: Optional[torch.Tensor] = None,
  313. **kwargs: Unpack[TransformersKwargs],
  314. ) -> BaseModelOutput:
  315. hidden_states = inputs_embeds
  316. for encoder_layer in self.layers:
  317. hidden_states = encoder_layer(
  318. hidden_states,
  319. attention_mask,
  320. **kwargs,
  321. )
  322. return BaseModelOutput(last_hidden_state=hidden_states)
  323. @auto_docstring
  324. class Idefics2PreTrainedModel(PreTrainedModel):
  325. config: Idefics2Config
  326. base_model_prefix = "model"
  327. supports_gradient_checkpointing = True
  328. _no_split_modules = ["Idefics2VisionAttention", "Idefics2MLP", "Idefics2PerceiverLayer", "Idefics2DecoderLayer"]
  329. _skip_keys_device_placement = "past_key_values"
  330. _supports_flash_attn = True
  331. _supports_sdpa = True
  332. _supports_flex_attn = True
  333. _supports_attention_backend = True
  334. def _init_weights(self, module):
  335. std = getattr(self.config, "initializer_range", self.config.get_text_config().initializer_range)
  336. if isinstance(module, (nn.Linear, nn.Conv2d)):
  337. module.weight.data.normal_(mean=0.0, std=std)
  338. if module.bias is not None:
  339. module.bias.data.zero_()
  340. elif isinstance(module, nn.Embedding):
  341. module.weight.data.normal_(mean=0.0, std=std)
  342. if module.padding_idx is not None:
  343. module.weight.data[module.padding_idx].zero_()
  344. elif isinstance(module, nn.LayerNorm):
  345. module.weight.data.fill_(1.0)
  346. module.bias.data.zero_()
  347. elif isinstance(module, Idefics2RMSNorm):
  348. module.weight.data.fill_(1.0)
  349. elif isinstance(module, nn.MultiheadAttention):
  350. module._reset_parameters() # native torch init
  351. elif isinstance(module, Idefics2MultiheadAttentionPoolingHead):
  352. module.probe.data.normal_()
  353. elif isinstance(module, Idefics2PerceiverResampler):
  354. module.latents.data.fill_(1.0)
  355. @auto_docstring(
  356. custom_intro="""
  357. Idefics2 vision encoder model that returnss raw image embeddings.
  358. """
  359. )
  360. class Idefics2VisionTransformer(Idefics2PreTrainedModel):
  361. config: Idefics2VisionConfig
  362. _supports_sdpa = True
  363. _supports_flash_attn = True
  364. _supports_flex_attn = True
  365. _can_record_outputs = {
  366. "hidden_states": Idefics2EncoderLayer,
  367. "attentions": Idefics2VisionAttention,
  368. }
  369. def __init__(self, config: Idefics2VisionConfig):
  370. super().__init__(config)
  371. embed_dim = config.hidden_size
  372. self.config = config
  373. self.embeddings = Idefics2VisionEmbeddings(config)
  374. self.encoder = Idefics2Encoder(config)
  375. self.post_layernorm = nn.LayerNorm(embed_dim, eps=config.layer_norm_eps)
  376. def get_input_embeddings(self):
  377. return self.embeddings
  378. def set_input_embeddings(self, value):
  379. self.embeddings = value
  380. @check_model_inputs(tie_last_hidden_states=False)
  381. @auto_docstring
  382. def forward(
  383. self,
  384. pixel_values,
  385. patch_attention_mask: Optional[torch.BoolTensor] = None,
  386. **kwargs: Unpack[TransformersKwargs],
  387. ) -> Union[tuple, BaseModelOutput]:
  388. r"""
  389. patch_attention_mask (`torch.BoolTensor` of shape `(batch_size, num_patches_height, num_patches_width)`, *optional*):
  390. The attention mask for the patches.
  391. """
  392. batch_size = pixel_values.size(0)
  393. if patch_attention_mask is None:
  394. patch_size = self.config.patch_size
  395. patch_attention_mask = torch.ones(
  396. (
  397. batch_size,
  398. pixel_values.size(2) // patch_size,
  399. pixel_values.size(3) // patch_size,
  400. )
  401. )
  402. patch_attention_mask = patch_attention_mask.to(dtype=torch.bool, device=pixel_values.device)
  403. hidden_states = self.embeddings(pixel_values=pixel_values, patch_attention_mask=patch_attention_mask)
  404. patch_attention_mask = patch_attention_mask.view(batch_size, -1)
  405. # The call to `_upad_input` in `_flash_attention_forward` is expensive
  406. # So when the `patch_attention_mask` is full of 1s (i.e. attending to the whole sequence),
  407. # avoiding passing the attention_mask, which is equivalent to attending to the full sequence
  408. if not torch.any(~patch_attention_mask):
  409. patch_attention_mask = None
  410. elif self.config._attn_implementation != "flash_attention_2":
  411. patch_attention_mask = _prepare_4d_attention_mask(patch_attention_mask, hidden_states.dtype)
  412. encoder_outputs: BaseModelOutput = self.encoder(
  413. inputs_embeds=hidden_states,
  414. attention_mask=patch_attention_mask,
  415. **kwargs,
  416. )
  417. last_hidden_state = encoder_outputs.last_hidden_state
  418. last_hidden_state = self.post_layernorm(last_hidden_state)
  419. return BaseModelOutput(last_hidden_state=last_hidden_state)
  420. # Copied from transformers.models.llama.modeling_llama.repeat_kv
  421. def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
  422. """
  423. This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch,
  424. num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim)
  425. """
  426. batch, num_key_value_heads, slen, head_dim = hidden_states.shape
  427. if n_rep == 1:
  428. return hidden_states
  429. hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim)
  430. return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
  431. # Copied from transformers.models.llama.modeling_llama.LlamaRMSNorm with Llama->Idefics2
  432. class Idefics2RMSNorm(nn.Module):
  433. def __init__(self, hidden_size, eps=1e-6):
  434. """
  435. Idefics2RMSNorm is equivalent to T5LayerNorm
  436. """
  437. super().__init__()
  438. self.weight = nn.Parameter(torch.ones(hidden_size))
  439. self.variance_epsilon = eps
  440. def forward(self, hidden_states):
  441. input_dtype = hidden_states.dtype
  442. hidden_states = hidden_states.to(torch.float32)
  443. variance = hidden_states.pow(2).mean(-1, keepdim=True)
  444. hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
  445. return self.weight * hidden_states.to(input_dtype)
  446. def extra_repr(self):
  447. return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}"
  448. class Idefics2PerceiverAttention(nn.Module):
  449. def __init__(self, config, layer_idx: Optional[int] = None) -> None:
  450. """Perceiver Cross-Attention Module --> let long-form inputs be `context`, resampled embeddings be `latents`"""
  451. super().__init__()
  452. self.config = config
  453. self.layer_idx = None
  454. self.hidden_size = config.hidden_size
  455. self.num_heads = config.resampler_n_heads
  456. self.head_dim = config.resampler_head_dim
  457. self.num_key_value_heads = config.num_key_value_heads
  458. self.num_key_value_groups = self.num_heads // self.num_key_value_heads
  459. self.attention_dropout = config.attention_dropout
  460. self.scaling = self.head_dim**-0.5
  461. self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=False)
  462. self.k_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=False)
  463. self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=False)
  464. self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=False)
  465. self.is_causal = False
  466. @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58")
  467. def forward(
  468. self,
  469. latents: torch.Tensor,
  470. context: torch.Tensor,
  471. attention_mask: Optional[torch.Tensor] = None,
  472. position_ids: Optional[torch.LongTensor] = None,
  473. past_key_values: Optional[Cache] = None,
  474. **kwargs: Unpack[TransformersKwargs],
  475. ) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
  476. """
  477. Runs Perceiver Self-Attention, with special (context, latents) appended along the `seq` dimension!
  478. Args:
  479. latents (`torch.Tensor`): Tensor of shape [bsz, n_latents, embed_dim] representing fixed length latents to compress to.
  480. context (`torch.Tensor`): Tensor of shape [bsz, seq, embed_dim] representing long-form context to resample.
  481. attention_mask (`torch.Tensor`, *optional*): Tensor of shape [bsz, 1, seq, n_latents] representing attention mask.
  482. position_ids (`torch.LongTensor`, *optional*): Tensor of shape [bsz, seq] representing position indices of each input token.
  483. past_key_values (`Cache`, *optional*): Tuple of tensors containing cached key and value states.
  484. output_attentions (`bool`, *optional*, defaults to `False`): Whether to return attention weights.
  485. use_cache (`bool`, *optional*, defaults to `False`): Whether to use past_key_values for caching.
  486. """
  487. bsz, q_len, _ = latents.size()
  488. kv_seq_len = q_len + context.size()[1]
  489. hidden_states = torch.concat([context, latents], dim=-2)
  490. queries = self.q_proj(latents)
  491. keys = self.k_proj(hidden_states)
  492. values = self.v_proj(hidden_states)
  493. queries = queries.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
  494. keys = keys.view(bsz, kv_seq_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
  495. values = values.view(bsz, kv_seq_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
  496. past_key_values = getattr(self, "past_key_values", past_key_values)
  497. if past_key_values is not None:
  498. keys, values = past_key_values.update(keys, values, self.layer_idx)
  499. attention_interface: Callable = eager_attention_forward
  500. if self.config._attn_implementation != "eager":
  501. attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]
  502. attn_output, attn_weights = attention_interface(
  503. self,
  504. queries,
  505. keys,
  506. values,
  507. attention_mask,
  508. is_causal=self.is_causal,
  509. scaling=self.scaling,
  510. dropout=0.0 if not self.training else self.attention_dropout,
  511. **kwargs,
  512. )
  513. attn_output = attn_output.reshape(bsz, q_len, self.num_heads * self.head_dim)
  514. attn_output = self.o_proj(attn_output)
  515. return attn_output, attn_weights
  516. class Idefics2PerceiverLayer(nn.Module):
  517. def __init__(self, config, layer_idx: int):
  518. super().__init__()
  519. self.hidden_size = config.hidden_size
  520. self.n_latents = config.resampler_n_latents
  521. self.depth = config.resampler_depth
  522. self.rms_norm_eps = config.rms_norm_eps
  523. self.input_latents_norm = Idefics2RMSNorm(self.hidden_size, eps=self.rms_norm_eps)
  524. self.input_context_norm = Idefics2RMSNorm(self.hidden_size, eps=self.rms_norm_eps)
  525. self.self_attn = Idefics2PerceiverAttention(config, layer_idx=layer_idx)
  526. self.post_attention_layernorm = Idefics2RMSNorm(self.hidden_size, eps=self.rms_norm_eps)
  527. self.mlp = Idefics2MLP(
  528. hidden_size=config.hidden_size,
  529. intermediate_size=config.hidden_size * 4,
  530. output_size=config.hidden_size,
  531. hidden_act=config.hidden_act,
  532. )
  533. @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58")
  534. def forward(
  535. self,
  536. latents: torch.Tensor,
  537. context: torch.Tensor,
  538. attention_mask: Optional[torch.Tensor] = None,
  539. position_ids: Optional[torch.LongTensor] = None,
  540. past_key_values: Optional[Cache] = None,
  541. **kwargs: Unpack[TransformersKwargs],
  542. ) -> torch.FloatTensor:
  543. """
  544. Args:
  545. latents (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
  546. context (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
  547. attention_mask (`torch.FloatTensor`, *optional*): attention mask of size
  548. `(batch, sequence_length)` where padding elements are indicated by 0.
  549. output_attentions (`bool`, *optional*):
  550. Whether or not to return the attentions tensors of all attention layers. See `attentions` under
  551. returned tensors for more detail.
  552. use_cache (`bool`, *optional*):
  553. If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding
  554. (see `past_key_values`).
  555. past_key_values (`Cache`, *optional*): cached past key and value projection states
  556. """
  557. residual = latents
  558. latents = self.input_latents_norm(latents)
  559. context = self.input_context_norm(context)
  560. latents, _ = self.self_attn(
  561. latents=latents,
  562. context=context,
  563. attention_mask=attention_mask,
  564. **kwargs,
  565. )
  566. latents = residual + latents
  567. residual = latents
  568. latents = self.post_attention_layernorm(latents)
  569. latents = self.mlp(latents)
  570. latents = residual + latents
  571. return latents
  572. @auto_docstring(
  573. custom_intro="""
  574. Idefics2 perceiver resampler model that performs `depth` blocks of cross-attention with a fixed
  575. """
  576. )
  577. class Idefics2PerceiverResampler(Idefics2PreTrainedModel):
  578. config: Idefics2PerceiverConfig
  579. _supports_sdpa = True
  580. _supports_flash_attention_2 = True
  581. _supports_flex_attn = True
  582. def __init__(self, config) -> None:
  583. super().__init__(config)
  584. self.hidden_size = config.hidden_size
  585. self.hidden_act = config.hidden_act
  586. self.n_latents = config.resampler_n_latents
  587. self.depth = config.resampler_depth
  588. self.rms_norm_eps = config.rms_norm_eps
  589. # Create Latents for Perceiver
  590. self.latents = nn.Parameter(torch.ones(self.n_latents, self.hidden_size))
  591. # Create Transformer Blocks
  592. self.layers = nn.ModuleList([Idefics2PerceiverLayer(config, idx) for idx in range(self.depth)])
  593. self.norm = Idefics2RMSNorm(self.hidden_size, eps=self.rms_norm_eps)
  594. @auto_docstring
  595. def forward(
  596. self,
  597. context: torch.Tensor,
  598. attention_mask: torch.Tensor,
  599. **kwargs: Unpack[TransformersKwargs],
  600. ) -> torch.Tensor:
  601. r"""
  602. context (`torch.FloatTensor` of shape `(batch, seq_len, embed_dim)`):
  603. Input to the layer.
  604. """
  605. # seq embed -> bsz seq embed
  606. latents = self.latents.unsqueeze(0).expand((context.shape[0], *self.latents.size()))
  607. latent_attention_mask = torch.ones(
  608. (attention_mask.size(0), latents.size(1)), dtype=attention_mask.dtype, device=attention_mask.device
  609. )
  610. attention_mask = torch.cat([attention_mask, latent_attention_mask], dim=-1)
  611. attention_mask = (
  612. _prepare_4d_attention_mask(attention_mask, latents.dtype, tgt_len=self.n_latents)
  613. if self.config._attn_implementation != "flash_attention_2"
  614. else attention_mask
  615. )
  616. compressed_context = latents
  617. for perceiver_layer in self.layers:
  618. compressed_context = perceiver_layer(
  619. compressed_context,
  620. context,
  621. attention_mask=attention_mask,
  622. position_ids=None,
  623. **kwargs,
  624. )
  625. compressed_context = self.norm(compressed_context)
  626. return compressed_context
  627. class Idefics2Connector(nn.Module):
  628. def __init__(self, config):
  629. super().__init__()
  630. self.modality_projection = Idefics2MLP(
  631. hidden_size=config.vision_config.hidden_size,
  632. intermediate_size=config.text_config.intermediate_size,
  633. output_size=config.text_config.hidden_size,
  634. hidden_act=config.text_config.hidden_act,
  635. )
  636. self.perceiver_resampler = Idefics2PerceiverResampler._from_config(config.perceiver_config)
  637. def forward(self, image_hidden_states, attention_mask):
  638. image_hidden_states = self.modality_projection(image_hidden_states)
  639. image_hidden_states = self.perceiver_resampler(context=image_hidden_states, attention_mask=attention_mask)
  640. return image_hidden_states
  641. @auto_docstring(
  642. custom_intro="""
  643. Idefics2 model consisting of a SIGLIP vision encoder and Mistral language decoder
  644. """
  645. )
  646. class Idefics2Model(Idefics2PreTrainedModel):
  647. def __init__(self, config: Idefics2Config):
  648. super().__init__(config)
  649. self.padding_idx = self.config.text_config.pad_token_id
  650. self.vocab_size = self.config.text_config.vocab_size
  651. self.vision_model = Idefics2VisionTransformer._from_config(config.vision_config)
  652. self.connector = Idefics2Connector(config)
  653. self.text_model = AutoModel.from_config(config.text_config)
  654. self.image_seq_len = config.perceiver_config.resampler_n_latents
  655. self.image_token_id = self.config.image_token_id
  656. self.post_init()
  657. def enable_input_require_grads(self):
  658. """
  659. Enables the gradients for the input embeddings.
  660. This is useful for lora when using gradient checkpointing.
  661. c.f. https://github.com/huggingface/peft/issues/1402#issuecomment-1913675032
  662. Override to set output.requires_grad = True for both the decoder's and vision model's embeddings.
  663. """
  664. def get_lowest_module(module):
  665. if len(list(module.children())) == 0:
  666. # If the module has no children, it is a leaf module (e.g., Linear, Conv2d, etc.)
  667. return module
  668. else:
  669. # Recursively call the function on each child module
  670. return get_lowest_module(list(module.children())[0])
  671. def make_inputs_require_grads(module, input, output):
  672. output.requires_grad_(True)
  673. self._text_require_grads_hook = self.get_input_embeddings().register_forward_hook(make_inputs_require_grads)
  674. self._vision_require_grads_hook = get_lowest_module(self.vision_model).register_forward_hook(
  675. make_inputs_require_grads
  676. )
  677. def disable_input_require_grads(self):
  678. self._text_require_grads_hook.remove()
  679. self._vision_require_grads_hook.remove()
  680. def get_input_embeddings(self):
  681. return self.text_model.get_input_embeddings()
  682. def set_input_embeddings(self, value):
  683. self.text_model.set_input_embeddings(value)
  684. def inputs_merger(
  685. self,
  686. input_ids: torch.LongTensor,
  687. inputs_embeds: Optional[torch.Tensor],
  688. image_hidden_states: Optional[torch.Tensor],
  689. ):
  690. """
  691. This method aims at merging the token embeddings with the image hidden states into one single sequence of vectors that are fed to the transformer LM.
  692. The merging happens as follows:
  693. - The text token sequence is: `tok_1 tok_2 tok_3 <fake_token_around_image> <image> <image> ... <image> <fake_token_around_image> tok_4`.
  694. - We get the image hidden states for the image through the vision encoder (and potentially the perceiver), and that hidden state is then projected into the text embedding space.
  695. We thus have a sequence of image hidden states of size (1, image_seq_len, hidden_dim), where 1 is for batch_size of 1 image and hidden_dim is the hidden_dim of the LM transformer.
  696. - The merging happens so that we obtain the following sequence: `vector_tok_1 vector_tok_2 vector_tok_3 vector_fake_tok_around_image {sequence of image_seq_len image hidden states} vector_fake_toke_around_image vector_tok_4`. That sequence is fed to the LM.
  697. - To fit the format of that sequence, `input_ids`, `input_embeds`, `attention_mask` are all 3 adapted to insert the image hidden states.
  698. """
  699. if input_ids is None:
  700. special_image_mask = inputs_embeds == self.get_input_embeddings()(
  701. torch.tensor(self.config.image_token_id, dtype=torch.long, device=inputs_embeds.device)
  702. )
  703. special_image_mask = special_image_mask.all(-1)
  704. else:
  705. special_image_mask = input_ids == self.config.image_token_id
  706. special_image_mask = special_image_mask.unsqueeze(-1).expand_as(inputs_embeds).to(inputs_embeds.device)
  707. image_hidden_states = image_hidden_states.to(inputs_embeds.device, inputs_embeds.dtype)
  708. inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, image_hidden_states)
  709. return inputs_embeds
  710. def get_image_features(
  711. self, pixel_values: torch.FloatTensor, pixel_attention_mask: Optional[torch.LongTensor] = None
  712. ):
  713. """
  714. Encodes images into continuous embeddings that can be forwarded to the language model.
  715. Args:
  716. pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, image_size, image_size)`):
  717. The tensors corresponding to the input images.
  718. pixel_attention_mask (`torch.LongTensor`, *optional*):
  719. The attention mask indicating padded regions in the image.
  720. """
  721. batch_size, num_images, num_channels, height, width = pixel_values.shape
  722. pixel_values = pixel_values.to(dtype=self.dtype) # fp16 compatibility
  723. pixel_values = pixel_values.view(batch_size * num_images, *pixel_values.shape[2:])
  724. # Remove padding images - padding images are full 0.
  725. nb_values_per_image = pixel_values.shape[1:].numel()
  726. real_images_inds = (pixel_values == 0.0).sum(dim=(-1, -2, -3)) != nb_values_per_image
  727. pixel_values = pixel_values[real_images_inds].contiguous()
  728. # Handle the vision attention mask
  729. if pixel_attention_mask is None:
  730. pixel_attention_mask = torch.ones(
  731. size=(pixel_values.size(0), pixel_values.size(2), pixel_values.size(3)),
  732. dtype=torch.bool,
  733. device=pixel_values.device,
  734. )
  735. else:
  736. # Remove padding images from the mask/pP p
  737. pixel_attention_mask = pixel_attention_mask.view(batch_size * num_images, *pixel_attention_mask.shape[2:])
  738. pixel_attention_mask = pixel_attention_mask[real_images_inds].contiguous()
  739. patch_size = self.config.vision_config.patch_size
  740. patches_subgrid = pixel_attention_mask.unfold(dimension=1, size=patch_size, step=patch_size)
  741. patches_subgrid = patches_subgrid.unfold(dimension=2, size=patch_size, step=patch_size)
  742. patch_attention_mask = (patches_subgrid.sum(dim=(-1, -2)) == patch_size * patch_size).bool()
  743. # Get sequence from the vision encoder
  744. image_hidden_states = self.vision_model(pixel_values=pixel_values, patch_attention_mask=patch_attention_mask)
  745. image_hidden_states = image_hidden_states.last_hidden_state
  746. # Modality projection & resampling
  747. image_hidden_states = self.connector(
  748. image_hidden_states, attention_mask=patch_attention_mask.view(pixel_values.size(0), -1)
  749. )
  750. image_hidden_states = image_hidden_states.view(-1, image_hidden_states.shape[-1])
  751. return image_hidden_states
  752. @can_return_tuple
  753. @auto_docstring(
  754. custom_intro="""
  755. Inputs fed to the model can have an arbitrary number of images. To account for this, pixel_values fed to
  756. the model have image padding -> (batch_size, max_num_images, 3, max_heights, max_widths) where
  757. max_num_images is the maximum number of images among the batch_size samples in the batch.
  758. Padding images are not needed beyond padding the pixel_values at the entrance of the model.
  759. For efficiency, we only pass through the vision_model's forward the real images by
  760. discarding the padding images i.e. pixel_values of size (image_batch_size, 3, height, width) where
  761. image_batch_size would be 7 when num_images_per_sample=[1, 3, 1, 2] and max_num_images would be 3.
  762. """
  763. )
  764. def forward(
  765. self,
  766. input_ids: Optional[torch.LongTensor] = None,
  767. attention_mask: Optional[torch.Tensor] = None,
  768. position_ids: Optional[torch.LongTensor] = None,
  769. past_key_values: Optional[Cache] = None,
  770. inputs_embeds: Optional[torch.FloatTensor] = None,
  771. pixel_values: Optional[torch.FloatTensor] = None,
  772. pixel_attention_mask: Optional[torch.BoolTensor] = None,
  773. image_hidden_states: Optional[torch.FloatTensor] = None,
  774. use_cache: Optional[bool] = None,
  775. cache_position: Optional[torch.LongTensor] = None,
  776. **kwargs: Unpack[FlashAttentionKwargs],
  777. ) -> Union[tuple, Idefics2BaseModelOutputWithPast]:
  778. r"""
  779. pixel_attention_mask (`torch.Tensor` of shape `(batch_size, image_size, image_size)`, *optional*):
  780. Mask to avoid performing attention on padding pixel indices.
  781. image_hidden_states (`torch.FloatTensor` of shape `(batch_size, num_channels, image_size, image_size)`):
  782. The hidden states of the image encoder after modality projection and perceiver resampling.
  783. """
  784. if self.training and self.text_model.gradient_checkpointing and use_cache:
  785. logger.warning_once(
  786. "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
  787. )
  788. use_cache = False
  789. # retrieve input_ids and inputs_embeds
  790. if input_ids is not None:
  791. batch_size, seq_length = input_ids.shape
  792. elif inputs_embeds is not None:
  793. batch_size, seq_length, _ = inputs_embeds.shape
  794. else:
  795. raise ValueError("You have to specify either input_ids or inputs_embeds")
  796. if use_cache and past_key_values is None:
  797. past_key_values = DynamicCache(config=self.config)
  798. if inputs_embeds is None:
  799. inputs_embeds = self.text_model.get_input_embeddings()(input_ids)
  800. # START VISUAL INPUTS INTEGRATION
  801. if pixel_values is not None and image_hidden_states is not None:
  802. raise ValueError("You cannot specify both pixel_values and image_hidden_states at the same time")
  803. elif pixel_values is not None:
  804. image_hidden_states = self.get_image_features(pixel_values, pixel_attention_mask)
  805. elif image_hidden_states is not None:
  806. image_hidden_states = image_hidden_states.to(dtype=self.dtype, device=input_ids.device)
  807. if image_hidden_states is not None:
  808. # When we generate, we don't want to replace the potential image_token_id that we generated by images
  809. # that simply don't exist
  810. inputs_embeds = self.inputs_merger(
  811. input_ids=input_ids,
  812. inputs_embeds=inputs_embeds,
  813. image_hidden_states=image_hidden_states,
  814. )
  815. kwargs["return_dict"] = True
  816. outputs = self.text_model(
  817. inputs_embeds=inputs_embeds,
  818. attention_mask=attention_mask,
  819. position_ids=position_ids,
  820. past_key_values=past_key_values,
  821. use_cache=use_cache,
  822. cache_position=cache_position,
  823. **kwargs,
  824. )
  825. return Idefics2BaseModelOutputWithPast(
  826. last_hidden_state=outputs.last_hidden_state,
  827. past_key_values=outputs.past_key_values,
  828. hidden_states=outputs.hidden_states,
  829. attentions=outputs.attentions,
  830. image_hidden_states=image_hidden_states,
  831. )
  832. @auto_docstring(
  833. custom_intro="""
  834. The Idefics2 Model with a language modeling head. It is made up a SigLIP vision encoder, with a language modeling head on top.
  835. """
  836. )
  837. class Idefics2ForConditionalGeneration(Idefics2PreTrainedModel, GenerationMixin):
  838. _tied_weights_keys = ["lm_head.weight"]
  839. def __init__(self, config):
  840. super().__init__(config)
  841. self.model = Idefics2Model(config)
  842. self.image_token_id = self.config.image_token_id
  843. self.lm_head = nn.Linear(config.text_config.hidden_size, config.text_config.vocab_size, bias=False)
  844. self.vocab_size = config.text_config.vocab_size
  845. # Initialize weights and apply final processing
  846. self.post_init()
  847. def enable_input_require_grads(self):
  848. """
  849. Enables the gradients for the input embeddings. This is useful for fine-tuning adapter weights while keeping
  850. the model weights fixed.
  851. """
  852. def make_inputs_require_grads(module, input, output):
  853. output.requires_grad_(True)
  854. self._text_require_grads_hook = self.get_input_embeddings().register_forward_hook(make_inputs_require_grads)
  855. self._vision_require_grads_hook = self.model.vision_model.get_input_embeddings().register_forward_hook(
  856. make_inputs_require_grads
  857. )
  858. def disable_input_require_grads(self):
  859. self._text_require_grads_hook.remove()
  860. self._vision_require_grads_hook.remove()
  861. def get_input_embeddings(self):
  862. return self.model.text_model.get_input_embeddings()
  863. def set_input_embeddings(self, value):
  864. self.model.text_model.set_input_embeddings(value)
  865. def get_image_features(
  866. self, pixel_values: torch.FloatTensor, pixel_attention_mask: Optional[torch.LongTensor] = None
  867. ):
  868. return self.model.get_image_features(pixel_values=pixel_values, pixel_attention_mask=pixel_attention_mask)
  869. @can_return_tuple
  870. @auto_docstring
  871. def forward(
  872. self,
  873. input_ids: Optional[torch.LongTensor] = None,
  874. attention_mask: Optional[torch.Tensor] = None,
  875. position_ids: Optional[torch.LongTensor] = None,
  876. past_key_values: Optional[Cache] = None,
  877. inputs_embeds: Optional[torch.FloatTensor] = None,
  878. pixel_values: Optional[torch.FloatTensor] = None,
  879. pixel_attention_mask: Optional[torch.BoolTensor] = None,
  880. image_hidden_states: Optional[torch.FloatTensor] = None,
  881. labels: Optional[torch.LongTensor] = None,
  882. use_cache: Optional[bool] = None,
  883. output_attentions: Optional[bool] = None,
  884. output_hidden_states: Optional[bool] = None,
  885. return_dict: Optional[bool] = None,
  886. cache_position: Optional[torch.LongTensor] = None,
  887. logits_to_keep: Union[int, torch.Tensor] = 0,
  888. **kwargs: Unpack[TransformersKwargs],
  889. ) -> Union[tuple, Idefics2CausalLMOutputWithPast]:
  890. r"""
  891. pixel_attention_mask (`torch.Tensor` of shape `(batch_size, image_size, image_size)`, *optional*):
  892. Mask to avoid performing attention on padding pixel indices.
  893. image_hidden_states (`torch.FloatTensor` of shape `(batch_size, num_channels, image_size, image_size)`):
  894. The hidden states of the image encoder after modality projection and perceiver resampling.
  895. labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
  896. Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
  897. config.vocab_size]` or `model.image_token_id` (where `model` is your instance of `Idefics2ForConditionalGeneration`).
  898. Tokens with indices set to `model.image_token_id` are ignored (masked), the loss is only
  899. computed for the tokens with labels in `[0, ..., config.vocab_size]`.
  900. Example:
  901. ```python
  902. >>> import requests
  903. >>> import torch
  904. >>> from PIL import Image
  905. >>> from io import BytesIO
  906. >>> from transformers import AutoProcessor, AutoModelForVision2Seq
  907. >>> from transformers.image_utils import load_image
  908. >>> # Note that passing the image urls (instead of the actual pil images) to the processor is also possible
  909. >>> image1 = load_image("https://cdn.britannica.com/61/93061-050-99147DCE/Statue-of-Liberty-Island-New-York-Bay.jpg")
  910. >>> image2 = load_image("https://cdn.britannica.com/59/94459-050-DBA42467/Skyline-Chicago.jpg")
  911. >>> image3 = load_image("https://cdn.britannica.com/68/170868-050-8DDE8263/Golden-Gate-Bridge-San-Francisco.jpg")
  912. >>> processor = AutoProcessor.from_pretrained("HuggingFaceM4/idefics2-8b-base")
  913. >>> model = AutoModelForVision2Seq.from_pretrained("HuggingFaceM4/idefics2-8b-base", device_map="auto")
  914. >>> BAD_WORDS_IDS = processor.tokenizer(["<image>", "<fake_token_around_image>"], add_special_tokens=False).input_ids
  915. >>> EOS_WORDS_IDS = [processor.tokenizer.eos_token_id]
  916. >>> # Create inputs
  917. >>> prompts = [
  918. ... "<image>In this image, we can see the city of New York, and more specifically the Statue of Liberty.<image>In this image,",
  919. ... "In which city is that bridge located?<image>",
  920. ... ]
  921. >>> images = [[image1, image2], [image3]]
  922. >>> inputs = processor(images=images, text=prompts, padding=True, return_tensors="pt").to("cuda")
  923. >>> # Generate
  924. >>> generated_ids = model.generate(**inputs, bad_words_ids=BAD_WORDS_IDS, max_new_tokens=20)
  925. >>> generated_texts = processor.batch_decode(generated_ids, skip_special_tokens=True)
  926. >>> print(generated_texts)
  927. ['In this image, we can see the city of New York, and more specifically the Statue of Liberty. In this image, we can see the city of New York, and more specifically the Statue of Liberty.\n\n', 'In which city is that bridge located?\n\nThe bridge is located in the city of Pittsburgh, Pennsylvania.\n\n\nThe bridge is']
  928. ```"""
  929. output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
  930. output_hidden_states = (
  931. output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
  932. )
  933. return_dict = return_dict if return_dict is not None else self.config.use_return_dict
  934. # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
  935. outputs = self.model(
  936. input_ids=input_ids,
  937. attention_mask=attention_mask,
  938. position_ids=position_ids,
  939. past_key_values=past_key_values,
  940. inputs_embeds=inputs_embeds,
  941. pixel_values=pixel_values,
  942. pixel_attention_mask=pixel_attention_mask,
  943. image_hidden_states=image_hidden_states,
  944. use_cache=use_cache,
  945. output_attentions=output_attentions,
  946. output_hidden_states=output_hidden_states,
  947. cache_position=cache_position,
  948. return_dict=True,
  949. **kwargs,
  950. )
  951. hidden_states = outputs[0]
  952. # Only compute necessary logits, and do not upcast them to float if we are not computing the loss
  953. slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
  954. logits = self.lm_head(hidden_states[:, slice_indices, :])
  955. loss = None
  956. if labels is not None:
  957. loss = self.loss_function(
  958. logits=logits, labels=labels, vocab_size=self.config.text_config.vocab_size, **kwargs
  959. )
  960. return Idefics2CausalLMOutputWithPast(
  961. loss=loss,
  962. logits=logits,
  963. past_key_values=outputs.past_key_values,
  964. hidden_states=outputs.hidden_states,
  965. attentions=outputs.attentions,
  966. image_hidden_states=outputs.image_hidden_states,
  967. )
  968. def prepare_inputs_for_generation(
  969. self,
  970. input_ids,
  971. past_key_values=None,
  972. attention_mask=None,
  973. inputs_embeds=None,
  974. cache_position=None,
  975. pixel_values=None,
  976. pixel_attention_mask=None,
  977. image_hidden_states=None,
  978. logits_to_keep=None,
  979. **kwargs,
  980. ):
  981. # Overwritten -- there are mutually exclusive inputs (if the logic to make `image_hidden_states` take
  982. # precedence is moved to the model, we can remove this fn)
  983. model_inputs = super().prepare_inputs_for_generation(
  984. input_ids,
  985. past_key_values=past_key_values,
  986. attention_mask=attention_mask,
  987. inputs_embeds=inputs_embeds,
  988. cache_position=cache_position,
  989. pixel_values=pixel_values,
  990. pixel_attention_mask=pixel_attention_mask,
  991. image_hidden_states=image_hidden_states,
  992. logits_to_keep=logits_to_keep,
  993. **kwargs,
  994. )
  995. if image_hidden_states is not None or cache_position[0] != 0:
  996. model_inputs["pixel_values"] = None
  997. model_inputs["pixel_attention_mask"] = None
  998. return model_inputs
  999. __all__ = ["Idefics2ForConditionalGeneration", "Idefics2PreTrainedModel", "Idefics2Model"]