modeling_mistral3.py 23 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536
  1. # 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
  2. # This file was automatically generated from src/transformers/models/mistral3/modular_mistral3.py.
  3. # Do NOT edit this file manually as any edits will be overwritten by the generation of
  4. # the file from the modular. If any change should be done, please apply the change to the
  5. # modular_mistral3.py file directly. One of our CI enforces this.
  6. # 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
  7. # coding=utf-8
  8. # Copyright 2025 HuggingFace Inc. team. All rights reserved.
  9. #
  10. #
  11. # Licensed under the Apache License, Version 2.0 (the "License");
  12. # you may not use this file except in compliance with the License.
  13. # You may obtain a copy of the License at
  14. #
  15. # http://www.apache.org/licenses/LICENSE-2.0
  16. #
  17. # Unless required by applicable law or agreed to in writing, software
  18. # distributed under the License is distributed on an "AS IS" BASIS,
  19. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  20. # See the License for the specific language governing permissions and
  21. # limitations under the License.
  22. from dataclasses import dataclass
  23. from typing import Optional, Union
  24. import torch
  25. from torch import nn
  26. from ...activations import ACT2FN
  27. from ...cache_utils import Cache
  28. from ...generation import GenerationMixin
  29. from ...integrations import use_kernel_forward_from_hub
  30. from ...modeling_outputs import BaseModelOutputWithPast, ModelOutput
  31. from ...modeling_utils import PreTrainedModel
  32. from ...processing_utils import Unpack
  33. from ...utils import TransformersKwargs, auto_docstring, can_return_tuple
  34. from ..auto import AutoModel
  35. from .configuration_mistral3 import Mistral3Config
  36. @use_kernel_forward_from_hub("RMSNorm")
  37. class Mistral3RMSNorm(nn.Module):
  38. def __init__(self, hidden_size, eps=1e-6):
  39. """
  40. Mistral3RMSNorm is equivalent to T5LayerNorm
  41. """
  42. super().__init__()
  43. self.weight = nn.Parameter(torch.ones(hidden_size))
  44. self.variance_epsilon = eps
  45. def forward(self, hidden_states):
  46. input_dtype = hidden_states.dtype
  47. hidden_states = hidden_states.to(torch.float32)
  48. variance = hidden_states.pow(2).mean(-1, keepdim=True)
  49. hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
  50. return self.weight * hidden_states.to(input_dtype)
  51. def extra_repr(self):
  52. return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}"
  53. class Mistral3PatchMerger(nn.Module):
  54. """
  55. Learned merging of spatial_merge_size ** 2 patches
  56. """
  57. def __init__(self, config: Mistral3Config):
  58. super().__init__()
  59. self.config = config
  60. hidden_size = config.vision_config.hidden_size
  61. self.spatial_merge_size = config.spatial_merge_size
  62. self.patch_size = self.config.vision_config.patch_size
  63. self.merging_layer = nn.Linear(hidden_size * self.spatial_merge_size**2, hidden_size, bias=False)
  64. def forward(self, image_features: torch.Tensor, image_sizes: torch.Tensor) -> torch.Tensor:
  65. image_sizes = [
  66. (image_size[0] // self.patch_size, image_size[1] // self.patch_size) for image_size in image_sizes
  67. ]
  68. tokens_per_image = [h * w for h, w in image_sizes]
  69. d = image_features.shape[-1]
  70. permuted_tensor = []
  71. for image_index, image_tokens in enumerate(image_features.split(tokens_per_image)):
  72. # Reshape image_tokens into a 2D grid
  73. h, w = image_sizes[image_index]
  74. image_grid = image_tokens.view(h, w, d).permute(2, 0, 1).unsqueeze(0)
  75. grid = torch.nn.functional.unfold(
  76. image_grid, kernel_size=self.spatial_merge_size, stride=self.spatial_merge_size
  77. )
  78. grid = grid.view(d * self.spatial_merge_size**2, -1).t()
  79. permuted_tensor.append(grid)
  80. image_features = torch.cat(permuted_tensor, dim=0)
  81. image_features = self.merging_layer(image_features)
  82. return image_features
  83. class Mistral3MultiModalProjector(nn.Module):
  84. def __init__(self, config: Mistral3Config):
  85. super().__init__()
  86. self.norm = Mistral3RMSNorm(config.vision_config.hidden_size, eps=config.text_config.rms_norm_eps)
  87. self.patch_merger = Mistral3PatchMerger(config)
  88. # We have hidden_size * the number of vision feature layers
  89. num_feature_layers = 1 if isinstance(config.vision_feature_layer, int) else len(config.vision_feature_layer)
  90. self.linear_1 = nn.Linear(
  91. config.vision_config.hidden_size * num_feature_layers,
  92. config.text_config.hidden_size,
  93. bias=config.multimodal_projector_bias,
  94. )
  95. self.act = ACT2FN[config.projector_hidden_act]
  96. self.linear_2 = nn.Linear(
  97. config.text_config.hidden_size, config.text_config.hidden_size, bias=config.multimodal_projector_bias
  98. )
  99. def forward(self, image_features: torch.Tensor, image_sizes: torch.Tensor):
  100. image_features = self.norm(image_features)
  101. image_features = self.patch_merger(image_features, image_sizes)
  102. hidden_states = self.linear_1(image_features)
  103. hidden_states = self.act(hidden_states)
  104. hidden_states = self.linear_2(hidden_states)
  105. return hidden_states
# Output container returned by `Mistral3ForConditionalGeneration.forward`.
# NOTE: the class docstring below is handed to `auto_docstring`; keep its text stable.
@dataclass
@auto_docstring(
    custom_intro="""
    Base class for Mistral3 causal language model (or autoregressive) outputs.
    """
)
class Mistral3CausalLMOutputWithPast(ModelOutput):
    r"""
    loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided):
        Language modeling loss (for next-token prediction).
    logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.vocab_size)`):
        Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
    past_key_values (`Cache`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
        It is a [`~cache_utils.Cache`] instance. For more details, see our [kv cache guide](https://huggingface.co/docs/transformers/en/kv_cache).

        Contains pre-computed hidden-states (key and values in the self-attention blocks) that can be used (see
        `past_key_values` input) to speed up sequential decoding.
    image_hidden_states (`torch.FloatTensor`, *optional*):
        A `torch.FloatTensor` of size `(batch_size, num_images, sequence_length, hidden_size)`.
        image_hidden_states of the model produced by the vision encoder and after projecting the last hidden state.
    """

    # Every field defaults to None, so the container can be built with any subset of outputs.
    loss: Optional[torch.FloatTensor] = None
    logits: Optional[torch.FloatTensor] = None
    past_key_values: Optional[Cache] = None
    # `hidden_states` / `attentions` are not described in the docstring above;
    # presumably `auto_docstring` supplies their standard descriptions — TODO confirm.
    hidden_states: Optional[tuple[torch.FloatTensor]] = None
    attentions: Optional[tuple[torch.FloatTensor]] = None
    image_hidden_states: Optional[torch.FloatTensor] = None
# Output container returned by `Mistral3Model.forward`: the inherited `BaseModelOutputWithPast`
# fields plus the projected image features.
# NOTE: the class docstring below is handed to `auto_docstring`; keep its text stable.
@dataclass
@auto_docstring(
    custom_intro="""
    Base class for Mistral3 outputs, with hidden states and attentions.
    """
)
class Mistral3ModelOutputWithPast(BaseModelOutputWithPast):
    r"""
    past_key_values (`Cache`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
        It is a [`~cache_utils.Cache`] instance. For more details, see our [kv cache guide](https://huggingface.co/docs/transformers/en/kv_cache).

        Contains pre-computed hidden-states (key and values in the self-attention blocks) that can be used (see
        `past_key_values` input) to speed up sequential decoding.
    image_hidden_states (`torch.FloatTensor`, *optional*):
        A `torch.FloatTensor` of size `(batch_size, num_images, sequence_length, hidden_size)`.
        image_hidden_states of the model produced by the vision encoder and after projecting the last hidden state.
    """

    # The only field added on top of the base class; None when it is not supplied.
    image_hidden_states: Optional[torch.FloatTensor] = None
# Common base for the Mistral3 model classes in this file. It defines no methods of its own,
# only class-level settings read by the `PreTrainedModel` machinery (defined outside this file).
@auto_docstring
class Mistral3PreTrainedModel(PreTrainedModel):
    # Configuration class associated with these models.
    config: Mistral3Config
    # Empty prefix: no base-model attribute name is declared here.
    base_model_prefix = ""
    supports_gradient_checkpointing = True
    # Name of the forward input excluded from device placement
    # (presumably by the device-dispatch hooks — TODO confirm).
    _skip_keys_device_placement = "past_key_values"

    # Capability flags; their exact semantics live in `modeling_utils` and are not visible here.
    _supports_flash_attn = True
    _supports_sdpa = True

    _can_compile_fullgraph = True
    _supports_flex_attn = True
    _supports_attention_backend = True
  160. @auto_docstring(
  161. custom_intro="""
  162. The Mistral3 model which consists of a vision backbone and a language model, without a language modeling head.
  163. """
  164. )
  165. class Mistral3Model(Mistral3PreTrainedModel):
  166. _checkpoint_conversion_mapping = {"language_model.model": "language_model"}
  167. def __init__(self, config: Mistral3Config):
  168. super().__init__(config)
  169. self.vision_tower = AutoModel.from_config(config.vision_config)
  170. self.multi_modal_projector = Mistral3MultiModalProjector(config)
  171. self.language_model = AutoModel.from_config(config.text_config)
  172. self.post_init()
  173. def get_input_embeddings(self):
  174. return self.language_model.get_input_embeddings()
  175. def set_input_embeddings(self, value):
  176. self.language_model.set_input_embeddings(value)
  177. def set_decoder(self, decoder):
  178. self.language_model = decoder
  179. def get_decoder(self):
  180. return self.language_model
  181. def get_image_features(
  182. self,
  183. pixel_values: torch.FloatTensor,
  184. image_sizes: torch.Tensor,
  185. vision_feature_layer: Optional[Union[int, list[int]]] = None,
  186. **kwargs,
  187. ):
  188. """
  189. Obtains image last hidden states from the vision tower and apply multimodal projection.
  190. Args:
  191. pixel_values (`torch.FloatTensor]` of shape `(batch_size, channels, height, width)`):
  192. The tensors corresponding to the input images.
  193. vision_feature_layer (`Union[int, list[int]]`, *optional*):
  194. The index of the layer to select the vision feature. If multiple indices are provided,
  195. the vision feature of the corresponding indices will be concatenated to form the
  196. vision features.
  197. image_sizes (`torch.Tensor`, *optional*):
  198. Tensor containing the image sizes as returned by the processor.
  199. Returns:
  200. image_features (`torch.Tensor`): Image feature tensor of shape `(num_images, image_length, embed_dim)`).
  201. """
  202. vision_feature_layer = (
  203. vision_feature_layer if vision_feature_layer is not None else self.config.vision_feature_layer
  204. )
  205. kwargs = {k: v for k, v in kwargs.items() if v is not None}
  206. # this is not memory efficient at all (output_hidden_states=True) will save all the hidden states.
  207. image_outputs = self.vision_tower(pixel_values, image_sizes=image_sizes, output_hidden_states=True, **kwargs)
  208. # If we have one vision feature layer, return the corresponding hidden states,
  209. # otherwise, select the hidden states of each feature layer and concatenate them
  210. if isinstance(vision_feature_layer, int):
  211. selected_image_feature = image_outputs.hidden_states[vision_feature_layer]
  212. else:
  213. hs_pool = [image_outputs.hidden_states[layer_idx] for layer_idx in vision_feature_layer]
  214. selected_image_feature = torch.cat(hs_pool, dim=-1)
  215. image_features = self.multi_modal_projector(selected_image_feature.squeeze(0), image_sizes)
  216. downsample_ratio = self.vision_tower.patch_size * self.config.spatial_merge_size
  217. split_sizes = [(height // downsample_ratio) * (width // downsample_ratio) for height, width in image_sizes]
  218. image_features = torch.split(image_features.squeeze(0), split_sizes)
  219. return image_features
  220. def get_placeholder_mask(
  221. self, input_ids: torch.LongTensor, inputs_embeds: torch.FloatTensor, image_features: torch.FloatTensor
  222. ):
  223. """
  224. Obtains multimodal placeholder mask from `input_ids` or `inputs_embeds`, and checks that the placeholder token count is
  225. equal to the length of multimodal features. If the lengths are different, an error is raised.
  226. """
  227. if input_ids is None:
  228. special_image_mask = inputs_embeds == self.get_input_embeddings()(
  229. torch.tensor(self.config.image_token_id, dtype=torch.long, device=inputs_embeds.device)
  230. )
  231. special_image_mask = special_image_mask.all(-1)
  232. else:
  233. special_image_mask = input_ids == self.config.image_token_id
  234. n_image_tokens = special_image_mask.sum()
  235. special_image_mask = special_image_mask.unsqueeze(-1).expand_as(inputs_embeds).to(inputs_embeds.device)
  236. n_image_features = image_features.shape[0] * image_features.shape[1]
  237. if inputs_embeds[special_image_mask].numel() != image_features.numel():
  238. raise ValueError(
  239. f"Image features and image tokens do not match: tokens: {n_image_tokens}, features {n_image_features}"
  240. )
  241. return special_image_mask
  242. @can_return_tuple
  243. @auto_docstring
  244. def forward(
  245. self,
  246. input_ids: Optional[torch.LongTensor] = None,
  247. pixel_values: Optional[torch.FloatTensor] = None,
  248. attention_mask: Optional[torch.Tensor] = None,
  249. position_ids: Optional[torch.LongTensor] = None,
  250. past_key_values: Optional[Cache] = None,
  251. inputs_embeds: Optional[torch.FloatTensor] = None,
  252. vision_feature_layer: Optional[Union[int, list[int]]] = None,
  253. use_cache: Optional[bool] = None,
  254. output_attentions: Optional[bool] = None,
  255. output_hidden_states: Optional[bool] = None,
  256. return_dict: Optional[bool] = None,
  257. cache_position: Optional[torch.LongTensor] = None,
  258. image_sizes: Optional[torch.Tensor] = None,
  259. **kwargs: Unpack[TransformersKwargs],
  260. ) -> Union[tuple, Mistral3ModelOutputWithPast]:
  261. output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
  262. output_hidden_states = (
  263. output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
  264. )
  265. return_dict = return_dict if return_dict is not None else self.config.use_return_dict
  266. vision_feature_layer = (
  267. vision_feature_layer if vision_feature_layer is not None else self.config.vision_feature_layer
  268. )
  269. if (input_ids is None) ^ (inputs_embeds is not None):
  270. raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
  271. if inputs_embeds is None:
  272. inputs_embeds = self.get_input_embeddings()(input_ids)
  273. if pixel_values is not None:
  274. image_features = self.get_image_features(
  275. pixel_values=pixel_values,
  276. vision_feature_layer=vision_feature_layer,
  277. image_sizes=image_sizes,
  278. )
  279. image_features = torch.cat(image_features, dim=0).to(inputs_embeds.device, inputs_embeds.dtype)
  280. special_image_mask = self.get_placeholder_mask(
  281. input_ids, inputs_embeds=inputs_embeds, image_features=image_features
  282. )
  283. inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, image_features)
  284. outputs = self.language_model(
  285. attention_mask=attention_mask,
  286. position_ids=position_ids,
  287. past_key_values=past_key_values,
  288. inputs_embeds=inputs_embeds,
  289. use_cache=use_cache,
  290. output_attentions=output_attentions,
  291. output_hidden_states=output_hidden_states,
  292. return_dict=True,
  293. cache_position=cache_position,
  294. **kwargs,
  295. )
  296. return Mistral3ModelOutputWithPast(
  297. last_hidden_state=outputs.last_hidden_state,
  298. past_key_values=outputs.past_key_values,
  299. hidden_states=outputs.hidden_states,
  300. attentions=outputs.attentions,
  301. image_hidden_states=image_features if pixel_values is not None else None,
  302. )
  303. @auto_docstring(
  304. custom_intro="""
  305. The MISTRAL3 model which consists of a vision backbone and a language model.
  306. """
  307. )
  308. class Mistral3ForConditionalGeneration(Mistral3PreTrainedModel, GenerationMixin):
  309. _checkpoint_conversion_mapping = {
  310. "^language_model.model": "model.language_model",
  311. "^vision_tower": "model.vision_tower",
  312. "^multi_modal_projector": "model.multi_modal_projector",
  313. "^language_model.lm_head": "lm_head",
  314. }
  315. _tied_weights_keys = ["lm_head.weight"]
  316. def __init__(self, config: Mistral3Config):
  317. super().__init__(config)
  318. self.model = Mistral3Model(config)
  319. self.lm_head = nn.Linear(config.text_config.hidden_size, config.text_config.vocab_size, bias=False)
  320. self.post_init()
  321. def get_input_embeddings(self):
  322. return self.model.get_input_embeddings()
  323. def set_input_embeddings(self, value):
  324. self.model.set_input_embeddings(value)
  325. def get_output_embeddings(self) -> nn.Module:
  326. return self.lm_head
  327. def set_decoder(self, decoder):
  328. self.model.set_decoder(decoder)
  329. def get_decoder(self):
  330. return self.model.get_decoder()
  331. def get_image_features(
  332. self,
  333. pixel_values: torch.FloatTensor,
  334. image_sizes: torch.Tensor,
  335. vision_feature_layer: Optional[Union[int, list[int]]] = None,
  336. **kwargs,
  337. ):
  338. return self.model.get_image_features(
  339. pixel_values=pixel_values,
  340. image_sizes=image_sizes,
  341. vision_feature_layer=vision_feature_layer,
  342. **kwargs,
  343. )
  344. # Make modules available through conditional class for BC
  345. @property
  346. def language_model(self):
  347. return self.model.language_model
  348. @property
  349. def vision_tower(self):
  350. return self.model.vision_tower
  351. @property
  352. def multi_modal_projector(self):
  353. return self.model.multi_modal_projector
  354. @can_return_tuple
  355. @auto_docstring
  356. def forward(
  357. self,
  358. input_ids: Optional[torch.LongTensor] = None,
  359. pixel_values: Optional[torch.FloatTensor] = None,
  360. attention_mask: Optional[torch.Tensor] = None,
  361. position_ids: Optional[torch.LongTensor] = None,
  362. past_key_values: Optional[Cache] = None,
  363. inputs_embeds: Optional[torch.FloatTensor] = None,
  364. labels: Optional[torch.LongTensor] = None,
  365. use_cache: Optional[bool] = None,
  366. output_attentions: Optional[bool] = None,
  367. output_hidden_states: Optional[bool] = None,
  368. return_dict: Optional[bool] = None,
  369. cache_position: Optional[torch.LongTensor] = None,
  370. logits_to_keep: Union[int, torch.Tensor] = 0,
  371. image_sizes: Optional[torch.Tensor] = None,
  372. **kwargs: Unpack[TransformersKwargs],
  373. ) -> Union[tuple, Mistral3CausalLMOutputWithPast]:
  374. r"""
  375. labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
  376. Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
  377. config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
  378. (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
  379. Example:
  380. ```python
  381. >>> from PIL import Image
  382. >>> import requests
  383. >>> from transformers import AutoProcessor, Mistral3ForConditionalGeneration
  384. >>> model = Mistral3ForConditionalGeneration.from_pretrained("mistralai/Mistral-Small-3.1-24B-Instruct-2503")
  385. >>> processor = AutoProcessor.from_pretrained("mistralai/Mistral-Small-3.1-24B-Instruct-2503")
  386. >>> prompt = "<s>[INST][IMG]What is the image?[/INST]"
  387. >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
  388. >>> image = Image.open(requests.get(url, stream=True).raw)
  389. >>> inputs = processor(images=image, text=prompt, return_tensors="pt")
  390. >>> # Generate
  391. >>> generate_ids = model.generate(**inputs, max_new_tokens=15)
  392. >>> processor.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
  393. "What is the image?The image depicts two cats lying on a pink blanket."
  394. ```"""
  395. output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
  396. output_hidden_states = (
  397. output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
  398. )
  399. return_dict = return_dict if return_dict is not None else self.config.use_return_dict
  400. outputs = self.model(
  401. input_ids=input_ids,
  402. pixel_values=pixel_values,
  403. attention_mask=attention_mask,
  404. position_ids=position_ids,
  405. past_key_values=past_key_values,
  406. inputs_embeds=inputs_embeds,
  407. use_cache=use_cache,
  408. output_attentions=output_attentions,
  409. output_hidden_states=output_hidden_states,
  410. return_dict=True,
  411. cache_position=cache_position,
  412. image_sizes=image_sizes,
  413. **kwargs,
  414. )
  415. hidden_states = outputs[0]
  416. # Only compute necessary logits, and do not upcast them to float if we are not computing the loss
  417. slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
  418. logits = self.lm_head(hidden_states[:, slice_indices, :])
  419. loss = None
  420. if labels is not None:
  421. loss = self.loss_function(
  422. logits=logits, labels=labels, vocab_size=self.config.text_config.vocab_size, **kwargs
  423. )
  424. return Mistral3CausalLMOutputWithPast(
  425. loss=loss,
  426. logits=logits,
  427. past_key_values=outputs.past_key_values,
  428. hidden_states=outputs.hidden_states,
  429. attentions=outputs.attentions,
  430. image_hidden_states=outputs.image_hidden_states,
  431. )
  432. def prepare_inputs_for_generation(
  433. self,
  434. input_ids,
  435. past_key_values=None,
  436. inputs_embeds=None,
  437. pixel_values=None,
  438. attention_mask=None,
  439. cache_position=None,
  440. logits_to_keep=None,
  441. **kwargs,
  442. ):
  443. # Overwritten -- in specific circumstances we don't want to forward image inputs to the model
  444. model_inputs = super().prepare_inputs_for_generation(
  445. input_ids,
  446. past_key_values=past_key_values,
  447. inputs_embeds=inputs_embeds,
  448. attention_mask=attention_mask,
  449. cache_position=cache_position,
  450. logits_to_keep=logits_to_keep,
  451. **kwargs,
  452. )
  453. if cache_position[0] == 0:
  454. # If we're in cached decoding stage, pixel values should be None because input ids do not contain special image token anymore
  455. # Otherwise we need pixel values to be passed to model
  456. model_inputs["pixel_values"] = pixel_values
  457. return model_inputs
# Public API of this module: the names exported by `from ... import *`.
__all__ = ["Mistral3Model", "Mistral3PreTrainedModel", "Mistral3ForConditionalGeneration"]