modeling_pixtral.py 20 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516
  1. # coding=utf-8
  2. # Copyright 2024 Mistral and the HuggingFace Inc. team. All rights reserved.
  3. #
  4. # Licensed under the Apache License, Version 2.0 (the "License");
  5. # you may not use this file except in compliance with the License.
  6. # You may obtain a copy of the License at
  7. #
  8. # http://www.apache.org/licenses/LICENSE-2.0
  9. #
  10. # Unless required by applicable law or agreed to in writing, software
  11. # distributed under the License is distributed on an "AS IS" BASIS,
  12. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. # See the License for the specific language governing permissions and
  14. # limitations under the License.
  15. """PyTorch Pixtral model."""
  16. from collections.abc import Callable
  17. from typing import Optional, Union
  18. import torch
  19. from torch import nn
  20. from ...activations import ACT2FN
  21. from ...modeling_flash_attention_utils import FlashAttentionKwargs
  22. from ...modeling_layers import GradientCheckpointingLayer
  23. from ...modeling_outputs import BaseModelOutput
  24. from ...modeling_rope_utils import dynamic_rope_update
  25. from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
  26. from ...processing_utils import Unpack
  27. from ...utils import auto_docstring, can_return_tuple, logging
  28. from .configuration_pixtral import PixtralVisionConfig
# Module-level logger, named after this module via `__name__`.
logger = logging.get_logger(__name__)
  30. def position_ids_in_meshgrid(patch_embeds_list, max_width):
  31. positions = []
  32. for patch in patch_embeds_list:
  33. height, width = patch.shape[-2:]
  34. mesh = torch.meshgrid(torch.arange(height), torch.arange(width), indexing="ij")
  35. h_grid, v_grid = torch.stack(mesh, dim=-1).reshape(-1, 2).chunk(2, -1)
  36. ids = h_grid * max_width + v_grid
  37. positions.append(ids[:, 0])
  38. return torch.cat(positions)
  39. class PixtralRotaryEmbedding(nn.Module):
  40. """
  41. The key with pixtral embedding is just that you have a frequency for each pixel positions.
  42. If you have height x width pixels (or embedding pixels), then the frequency used for ROPE
  43. is given by indexing the pre_computed frequency on the width and height.
  44. What you output is of dimension (batch, height * width, dim) with dim the embed dim.
  45. This simply means that for each image hidden state, you are going to add
  46. a corresponding positional embedding, based on its index in the grid.
  47. """
  48. inv_freq: torch.Tensor # fix linting for `register_buffer`
  49. def __init__(self, config, device=None):
  50. super().__init__()
  51. self.rope_type = "default"
  52. self.dim = config.head_dim
  53. self.base = config.rope_theta
  54. max_patches_per_side = config.image_size // config.patch_size
  55. freqs = 1.0 / (self.base ** (torch.arange(0, self.dim, 2).float() / self.dim))
  56. h = torch.arange(max_patches_per_side, device=freqs.device)
  57. w = torch.arange(max_patches_per_side, device=freqs.device)
  58. freqs_h = torch.outer(h, freqs[::2]).float()
  59. freqs_w = torch.outer(w, freqs[1::2]).float()
  60. inv_freq = torch.cat(
  61. [
  62. freqs_h[:, None, :].repeat(1, max_patches_per_side, 1),
  63. freqs_w[None, :, :].repeat(max_patches_per_side, 1, 1),
  64. ],
  65. dim=-1,
  66. ).reshape(-1, self.dim // 2) # we reshape to only index on the position indexes, not tuple of indexes
  67. # Different from paper, but it uses a different permutation in order to obtain the same calculation
  68. # TODO maybe make it torch compatible later on. We can also just slice
  69. self.register_buffer("inv_freq", torch.cat((inv_freq, inv_freq), dim=-1), persistent=False)
  70. @torch.no_grad()
  71. @dynamic_rope_update # power user: used with advanced RoPE types (e.g. dynamic rope)
  72. def forward(self, x, position_ids):
  73. freqs = self.inv_freq[position_ids]
  74. device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu"
  75. with torch.autocast(device_type=device_type, enabled=False): # Force float32
  76. emb = freqs
  77. cos = emb.cos()
  78. sin = emb.sin()
  79. return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
  80. # Copied from transformers.models.llama.modeling_llama.rotate_half
  81. def rotate_half(x):
  82. """Rotates half the hidden dims of the input."""
  83. x1 = x[..., : x.shape[-1] // 2]
  84. x2 = x[..., x.shape[-1] // 2 :]
  85. return torch.cat((-x2, x1), dim=-1)
  86. def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
  87. """Applies Rotary Position Embedding to the query and key tensors.
  88. Args:
  89. q (`torch.Tensor`): The query tensor.
  90. k (`torch.Tensor`): The key tensor.
  91. cos (`torch.Tensor`): The cosine part of the rotary embedding.
  92. sin (`torch.Tensor`): The sine part of the rotary embedding.
  93. position_ids (`torch.Tensor`, *optional*):
  94. Deprecated and unused.
  95. unsqueeze_dim (`int`, *optional*, defaults to 1):
  96. The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and
  97. sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note
  98. that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and
  99. k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes
  100. cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have
  101. the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2.
  102. Returns:
  103. `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding.
  104. """
  105. cos = cos.unsqueeze(unsqueeze_dim)
  106. sin = sin.unsqueeze(unsqueeze_dim)
  107. q_embed = (q * cos) + (rotate_half(q) * sin)
  108. k_embed = (k * cos) + (rotate_half(k) * sin)
  109. return q_embed, k_embed
  110. # Copied from transformers.models.siglip.modeling_siglip.eager_attention_forward
  111. def eager_attention_forward(
  112. module: nn.Module,
  113. query: torch.Tensor,
  114. key: torch.Tensor,
  115. value: torch.Tensor,
  116. attention_mask: Optional[torch.Tensor],
  117. scaling: float,
  118. dropout: float = 0.0,
  119. **kwargs,
  120. ):
  121. attn_weights = torch.matmul(query, key.transpose(-1, -2)) * scaling
  122. if attention_mask is not None:
  123. attn_weights = attn_weights + attention_mask
  124. attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype)
  125. attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training)
  126. attn_output = torch.matmul(attn_weights, value)
  127. attn_output = attn_output.transpose(1, 2).contiguous()
  128. return attn_output, attn_weights
  129. class PixtralAttention(nn.Module):
  130. """
  131. Multi-headed attention compatible with ALL_ATTENTION_FUNCTIONS.
  132. """
  133. def __init__(self, config):
  134. super().__init__()
  135. self.config = config
  136. self.embed_dim = config.hidden_size
  137. self.num_heads = config.num_attention_heads
  138. self.head_dim = self.embed_dim // self.num_heads
  139. self.is_causal = False
  140. self.scaling = self.head_dim**-0.5
  141. self.is_causal = False
  142. self.dropout = config.attention_dropout
  143. self.k_proj = nn.Linear(self.embed_dim, self.embed_dim, bias=False)
  144. self.v_proj = nn.Linear(self.embed_dim, self.embed_dim, bias=False)
  145. self.q_proj = nn.Linear(self.embed_dim, self.embed_dim, bias=False)
  146. self.o_proj = nn.Linear(self.embed_dim, self.embed_dim, bias=False)
  147. def forward(
  148. self,
  149. hidden_states: torch.Tensor,
  150. attention_mask: Optional[torch.Tensor] = None,
  151. position_embeddings: Optional[tuple[torch.Tensor, torch.Tensor]] = None,
  152. output_attentions: Optional[bool] = False,
  153. **kwargs: Unpack[FlashAttentionKwargs],
  154. ) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
  155. """Input shape: Batch x Time x Channel"""
  156. batch_size, patches, _ = hidden_states.size()
  157. query_states = self.q_proj(hidden_states)
  158. key_states = self.k_proj(hidden_states)
  159. value_states = self.v_proj(hidden_states)
  160. query_states = query_states.view(batch_size, patches, self.num_heads, self.head_dim).transpose(1, 2)
  161. key_states = key_states.view(batch_size, patches, self.num_heads, self.head_dim).transpose(1, 2)
  162. value_states = value_states.view(batch_size, patches, self.num_heads, self.head_dim).transpose(1, 2)
  163. cos, sin = position_embeddings
  164. query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, unsqueeze_dim=0)
  165. attention_interface: Callable = eager_attention_forward
  166. if self.config._attn_implementation != "eager":
  167. attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]
  168. # Since we use packing, if flash_attention_2 is selected we rely on position_ids
  169. if self.config._attn_implementation == "flash_attention_2":
  170. kwargs["position_ids"] = kwargs["position_ids"].to(hidden_states.device, non_blocking=True)
  171. attn_output, attn_weights = attention_interface(
  172. self,
  173. query_states,
  174. key_states,
  175. value_states,
  176. attention_mask,
  177. dropout=0.0 if not self.training else self.dropout,
  178. scaling=self.scaling,
  179. **kwargs,
  180. )
  181. attn_output = attn_output.reshape(batch_size, patches, -1).contiguous()
  182. attn_output = self.o_proj(attn_output)
  183. if not output_attentions:
  184. attn_weights = None
  185. return attn_output, attn_weights
  186. # Copied from transformers.models.mistral.modeling_mistral.MistralMLP with Mistral->Pixtral
  187. class PixtralMLP(nn.Module):
  188. def __init__(self, config):
  189. super().__init__()
  190. self.config = config
  191. self.hidden_size = config.hidden_size
  192. self.intermediate_size = config.intermediate_size
  193. self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
  194. self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
  195. self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False)
  196. self.act_fn = ACT2FN[config.hidden_act]
  197. def forward(self, x):
  198. down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))
  199. return down_proj
  200. # Copied from transformers.models.llama.modeling_llama.LlamaRMSNorm with Llama->Pixtral
  201. class PixtralRMSNorm(nn.Module):
  202. def __init__(self, hidden_size, eps=1e-6):
  203. """
  204. PixtralRMSNorm is equivalent to T5LayerNorm
  205. """
  206. super().__init__()
  207. self.weight = nn.Parameter(torch.ones(hidden_size))
  208. self.variance_epsilon = eps
  209. def forward(self, hidden_states):
  210. input_dtype = hidden_states.dtype
  211. hidden_states = hidden_states.to(torch.float32)
  212. variance = hidden_states.pow(2).mean(-1, keepdim=True)
  213. hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
  214. return self.weight * hidden_states.to(input_dtype)
  215. def extra_repr(self):
  216. return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}"
  217. class PixtralAttentionLayer(GradientCheckpointingLayer):
  218. def __init__(self, config):
  219. super().__init__()
  220. self.attention_norm = PixtralRMSNorm(config.hidden_size, eps=1e-5)
  221. self.feed_forward = PixtralMLP(config)
  222. self.attention = PixtralAttention(config)
  223. self.ffn_norm = PixtralRMSNorm(config.hidden_size, eps=1e-5)
  224. def forward(
  225. self,
  226. hidden_states: torch.Tensor,
  227. attention_mask: torch.Tensor,
  228. position_embeddings: Optional[tuple[torch.Tensor, torch.Tensor]] = None,
  229. output_attentions: Optional[bool] = None,
  230. **kwargs: Unpack[FlashAttentionKwargs],
  231. ) -> tuple[torch.FloatTensor]:
  232. """
  233. Args:
  234. hidden_states (`torch.FloatTensor`):
  235. Input to the layer of shape `(batch, seq_len, embed_dim)`.
  236. attention_mask (`torch.FloatTensor`):
  237. Attention mask of shape `(batch, 1, q_len, k_v_seq_len)` where padding elements are indicated by very large negative values.
  238. output_attentions (`bool`, *optional*, defaults to `False`):
  239. Whether or not to return the attentions tensors of all attention layers. See `attentions` under
  240. returned tensors for more detail.
  241. """
  242. residual = hidden_states
  243. hidden_states = self.attention_norm(hidden_states)
  244. hidden_states, attn_weights = self.attention(
  245. hidden_states=hidden_states,
  246. attention_mask=attention_mask,
  247. position_embeddings=position_embeddings,
  248. output_attentions=output_attentions,
  249. **kwargs,
  250. )
  251. hidden_states = residual + hidden_states
  252. residual = hidden_states
  253. hidden_states = self.ffn_norm(hidden_states)
  254. hidden_states = self.feed_forward(hidden_states)
  255. hidden_states = residual + hidden_states
  256. outputs = (hidden_states,)
  257. if output_attentions:
  258. outputs += (attn_weights,)
  259. return outputs
  260. class PixtralTransformer(nn.Module):
  261. def __init__(self, config):
  262. super().__init__()
  263. self.config = config
  264. self.layers = torch.nn.ModuleList()
  265. for _ in range(config.num_hidden_layers):
  266. self.layers.append(PixtralAttentionLayer(config))
  267. self.gradient_checkpointing = False
  268. def forward(
  269. self,
  270. inputs_embeds,
  271. attention_mask: Optional[torch.Tensor] = None,
  272. position_embeddings: Optional[tuple[torch.Tensor, torch.Tensor]] = None,
  273. output_attentions: Optional[bool] = None,
  274. output_hidden_states: Optional[bool] = None,
  275. return_dict: Optional[bool] = None,
  276. **kwargs: Unpack[FlashAttentionKwargs],
  277. ) -> Union[tuple, BaseModelOutput]:
  278. r"""
  279. Args:
  280. inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
  281. Embeddings which serve as input to the Transformer.
  282. attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
  283. Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
  284. - 1 for tokens that are **not masked**,
  285. - 0 for tokens that are **masked**.
  286. [What are attention masks?](../glossary#attention-mask)
  287. output_attentions (`bool`, *optional*):
  288. Whether or not to return the attentions tensors of all attention layers. See `attentions` under
  289. returned tensors for more detail.
  290. output_hidden_states (`bool`, *optional*):
  291. Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors
  292. for more detail.
  293. return_dict (`bool`, *optional*):
  294. Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
  295. """
  296. output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
  297. output_hidden_states = (
  298. output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
  299. )
  300. return_dict = return_dict if return_dict is not None else self.config.use_return_dict
  301. encoder_states = () if output_hidden_states else None
  302. all_attentions = () if output_attentions else None
  303. hidden_states = inputs_embeds
  304. for encoder_layer in self.layers:
  305. if output_hidden_states:
  306. encoder_states = encoder_states + (hidden_states,)
  307. layer_outputs = encoder_layer(
  308. hidden_states,
  309. attention_mask,
  310. position_embeddings=position_embeddings,
  311. output_attentions=output_attentions,
  312. **kwargs,
  313. )
  314. hidden_states = layer_outputs[0]
  315. if output_attentions:
  316. all_attentions = all_attentions + (layer_outputs[1],)
  317. if output_hidden_states:
  318. encoder_states = encoder_states + (hidden_states,)
  319. if not return_dict:
  320. return tuple(v for v in [hidden_states, encoder_states, all_attentions] if v is not None)
  321. return BaseModelOutput(
  322. last_hidden_state=hidden_states, hidden_states=encoder_states, attentions=all_attentions
  323. )
  324. @auto_docstring
  325. class PixtralPreTrainedModel(PreTrainedModel):
  326. config: PixtralVisionConfig
  327. base_model_prefix = "model"
  328. main_input_name = "pixel_values"
  329. supports_gradient_checkpointing = True
  330. _supports_attention_backend = True
  331. _supports_flash_attn = True
  332. _supports_sdpa = True
  333. _supports_flex_attn = True
  334. _no_split_modules = ["PixtralAttentionLayer"]
  335. _supports_flash_attn = True
  336. _supports_sdpa = True
  337. _supports_flex_attn = True
  338. _supports_attention_backend = True
  339. def _init_weights(self, module):
  340. std = self.config.initializer_range
  341. if isinstance(module, (nn.Linear, nn.Conv2d)):
  342. module.weight.data.normal_(mean=0.0, std=std)
  343. if module.bias is not None:
  344. module.bias.data.zero_()
  345. elif isinstance(module, PixtralRMSNorm):
  346. module.weight.data.fill_(1.0)
  347. def generate_block_attention_mask(patch_embeds_list, tensor):
  348. dtype = tensor.dtype
  349. device = tensor.device
  350. seq_len = tensor.shape[1]
  351. d_min = torch.finfo(dtype).min
  352. causal_mask = torch.full((seq_len, seq_len), fill_value=d_min, dtype=dtype, device=device)
  353. block_end_idx = torch.tensor(patch_embeds_list).cumsum(-1)
  354. block_start_idx = torch.tensor([0] + patch_embeds_list[:-1]).cumsum(-1)
  355. for start, end in zip(block_start_idx, block_end_idx):
  356. causal_mask[start:end, start:end] = 0
  357. causal_mask = causal_mask[None, None, :, :].expand(tensor.shape[0], 1, -1, -1)
  358. return causal_mask
  359. @auto_docstring
  360. class PixtralVisionModel(PixtralPreTrainedModel):
  361. base_model_prefix = "vision_encoder"
  362. def __init__(self, config):
  363. super().__init__(config)
  364. self.config = config
  365. self.patch_conv = nn.Conv2d(
  366. in_channels=config.num_channels,
  367. out_channels=config.hidden_size,
  368. kernel_size=config.patch_size,
  369. stride=config.patch_size,
  370. bias=False,
  371. )
  372. self.patch_size = config.patch_size
  373. self.ln_pre = PixtralRMSNorm(config.hidden_size, eps=1e-5)
  374. self.transformer = PixtralTransformer(config)
  375. self.patch_positional_embedding = PixtralRotaryEmbedding(config)
  376. self.post_init()
  377. def get_input_embeddings(self):
  378. return self.patch_conv
  379. @can_return_tuple
  380. @auto_docstring
  381. def forward(
  382. self,
  383. pixel_values: torch.Tensor,
  384. image_sizes: Optional[torch.Tensor] = None,
  385. output_hidden_states: Optional[bool] = None,
  386. output_attentions: Optional[bool] = None,
  387. return_dict: Optional[bool] = None,
  388. *args,
  389. **kwargs: Unpack[FlashAttentionKwargs],
  390. ) -> Union[tuple, BaseModelOutput]:
  391. if image_sizes is None:
  392. batch_size, _, height, width = pixel_values.shape
  393. image_sizes = [(height, width)] * batch_size
  394. # pass images through initial convolution independently
  395. patch_embeds = self.patch_conv(pixel_values)
  396. patch_embeds_list = [
  397. embed[..., : (size[0] // self.patch_size), : (size[1] // self.patch_size)]
  398. for embed, size in zip(patch_embeds, image_sizes)
  399. ]
  400. # flatten to a single sequence
  401. patch_embeds = torch.cat([p.flatten(1).T for p in patch_embeds_list], dim=0).unsqueeze(0)
  402. patch_embeds = self.ln_pre(patch_embeds)
  403. # positional embeddings
  404. position_ids = position_ids_in_meshgrid(
  405. patch_embeds_list, max_width=self.config.image_size // self.config.patch_size
  406. )
  407. kwargs["position_ids"] = position_ids
  408. position_embeddings = self.patch_positional_embedding(patch_embeds, position_ids)
  409. if self.config._attn_implementation == "flash_attention_2":
  410. # We only rely on position_ids when using flash_attention_2
  411. attention_mask = None
  412. else:
  413. attention_mask = generate_block_attention_mask(
  414. [p.shape[-2] * p.shape[-1] for p in patch_embeds_list], patch_embeds
  415. )
  416. return self.transformer(
  417. patch_embeds,
  418. attention_mask=attention_mask,
  419. position_embeddings=position_embeddings,
  420. output_hidden_states=output_hidden_states,
  421. output_attentions=output_attentions,
  422. return_dict=True,
  423. **kwargs,
  424. )
# Names exported by a star-import of this module.
__all__ = ["PixtralVisionModel", "PixtralPreTrainedModel"]