# modular_mistral3.py (14 KB)
  1. # coding=utf-8
  2. # Copyright 2025 HuggingFace Inc. team. All rights reserved.
  3. #
  4. #
  5. # Licensed under the Apache License, Version 2.0 (the "License");
  6. # you may not use this file except in compliance with the License.
  7. # You may obtain a copy of the License at
  8. #
  9. # http://www.apache.org/licenses/LICENSE-2.0
  10. #
  11. # Unless required by applicable law or agreed to in writing, software
  12. # distributed under the License is distributed on an "AS IS" BASIS,
  13. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  14. # See the License for the specific language governing permissions and
  15. # limitations under the License.
  16. from typing import Optional, Union
  17. import torch
  18. from torch import nn
  19. from ...activations import ACT2FN
  20. from ...cache_utils import Cache
  21. from ...processing_utils import Unpack
  22. from ...utils import logging
  23. from ..llava.modeling_llava import (
  24. LlavaCausalLMOutputWithPast,
  25. LlavaForConditionalGeneration,
  26. LlavaModel,
  27. LlavaModelOutputWithPast,
  28. LlavaPreTrainedModel,
  29. TransformersKwargs,
  30. )
  31. from ..mistral.modeling_mistral import MistralRMSNorm
  32. from .configuration_mistral3 import Mistral3Config
# Module-level logger from the library's `logging` utilities (imported above from `...utils`), named after this module.
logger = logging.get_logger(__name__)
  34. class Mistral3RMSNorm(MistralRMSNorm):
  35. pass
  36. class Mistral3PatchMerger(nn.Module):
  37. """
  38. Learned merging of spatial_merge_size ** 2 patches
  39. """
  40. def __init__(self, config: Mistral3Config):
  41. super().__init__()
  42. self.config = config
  43. hidden_size = config.vision_config.hidden_size
  44. self.spatial_merge_size = config.spatial_merge_size
  45. self.patch_size = self.config.vision_config.patch_size
  46. self.merging_layer = nn.Linear(hidden_size * self.spatial_merge_size**2, hidden_size, bias=False)
  47. def forward(self, image_features: torch.Tensor, image_sizes: torch.Tensor) -> torch.Tensor:
  48. image_sizes = [
  49. (image_size[0] // self.patch_size, image_size[1] // self.patch_size) for image_size in image_sizes
  50. ]
  51. tokens_per_image = [h * w for h, w in image_sizes]
  52. d = image_features.shape[-1]
  53. permuted_tensor = []
  54. for image_index, image_tokens in enumerate(image_features.split(tokens_per_image)):
  55. # Reshape image_tokens into a 2D grid
  56. h, w = image_sizes[image_index]
  57. image_grid = image_tokens.view(h, w, d).permute(2, 0, 1).unsqueeze(0)
  58. grid = torch.nn.functional.unfold(
  59. image_grid, kernel_size=self.spatial_merge_size, stride=self.spatial_merge_size
  60. )
  61. grid = grid.view(d * self.spatial_merge_size**2, -1).t()
  62. permuted_tensor.append(grid)
  63. image_features = torch.cat(permuted_tensor, dim=0)
  64. image_features = self.merging_layer(image_features)
  65. return image_features
  66. class Mistral3MultiModalProjector(nn.Module):
  67. def __init__(self, config: Mistral3Config):
  68. super().__init__()
  69. self.norm = Mistral3RMSNorm(config.vision_config.hidden_size, eps=config.text_config.rms_norm_eps)
  70. self.patch_merger = Mistral3PatchMerger(config)
  71. # We have hidden_size * the number of vision feature layers
  72. num_feature_layers = 1 if isinstance(config.vision_feature_layer, int) else len(config.vision_feature_layer)
  73. self.linear_1 = nn.Linear(
  74. config.vision_config.hidden_size * num_feature_layers,
  75. config.text_config.hidden_size,
  76. bias=config.multimodal_projector_bias,
  77. )
  78. self.act = ACT2FN[config.projector_hidden_act]
  79. self.linear_2 = nn.Linear(
  80. config.text_config.hidden_size, config.text_config.hidden_size, bias=config.multimodal_projector_bias
  81. )
  82. def forward(self, image_features: torch.Tensor, image_sizes: torch.Tensor):
  83. image_features = self.norm(image_features)
  84. image_features = self.patch_merger(image_features, image_sizes)
  85. hidden_states = self.linear_1(image_features)
  86. hidden_states = self.act(hidden_states)
  87. hidden_states = self.linear_2(hidden_states)
  88. return hidden_states
  89. class Mistral3CausalLMOutputWithPast(LlavaCausalLMOutputWithPast):
  90. pass
  91. class Mistral3ModelOutputWithPast(LlavaModelOutputWithPast):
  92. pass
  93. class Mistral3PreTrainedModel(LlavaPreTrainedModel):
  94. pass
  95. class Mistral3Model(LlavaModel):
  96. def get_image_features(
  97. self,
  98. pixel_values: torch.FloatTensor,
  99. image_sizes: torch.Tensor,
  100. vision_feature_layer: Optional[Union[int, list[int]]] = None,
  101. **kwargs,
  102. ):
  103. """
  104. Obtains image last hidden states from the vision tower and apply multimodal projection.
  105. Args:
  106. pixel_values (`torch.FloatTensor]` of shape `(batch_size, channels, height, width)`):
  107. The tensors corresponding to the input images.
  108. vision_feature_layer (`Union[int, list[int]]`, *optional*):
  109. The index of the layer to select the vision feature. If multiple indices are provided,
  110. the vision feature of the corresponding indices will be concatenated to form the
  111. vision features.
  112. image_sizes (`torch.Tensor`, *optional*):
  113. Tensor containing the image sizes as returned by the processor.
  114. Returns:
  115. image_features (`torch.Tensor`): Image feature tensor of shape `(num_images, image_length, embed_dim)`).
  116. """
  117. vision_feature_layer = (
  118. vision_feature_layer if vision_feature_layer is not None else self.config.vision_feature_layer
  119. )
  120. kwargs = {k: v for k, v in kwargs.items() if v is not None}
  121. # this is not memory efficient at all (output_hidden_states=True) will save all the hidden states.
  122. image_outputs = self.vision_tower(pixel_values, image_sizes=image_sizes, output_hidden_states=True, **kwargs)
  123. # If we have one vision feature layer, return the corresponding hidden states,
  124. # otherwise, select the hidden states of each feature layer and concatenate them
  125. if isinstance(vision_feature_layer, int):
  126. selected_image_feature = image_outputs.hidden_states[vision_feature_layer]
  127. else:
  128. hs_pool = [image_outputs.hidden_states[layer_idx] for layer_idx in vision_feature_layer]
  129. selected_image_feature = torch.cat(hs_pool, dim=-1)
  130. image_features = self.multi_modal_projector(selected_image_feature.squeeze(0), image_sizes)
  131. downsample_ratio = self.vision_tower.patch_size * self.config.spatial_merge_size
  132. split_sizes = [(height // downsample_ratio) * (width // downsample_ratio) for height, width in image_sizes]
  133. image_features = torch.split(image_features.squeeze(0), split_sizes)
  134. return image_features
  135. def forward(
  136. self,
  137. input_ids: Optional[torch.LongTensor] = None,
  138. pixel_values: Optional[torch.FloatTensor] = None,
  139. attention_mask: Optional[torch.Tensor] = None,
  140. position_ids: Optional[torch.LongTensor] = None,
  141. past_key_values: Optional[Cache] = None,
  142. inputs_embeds: Optional[torch.FloatTensor] = None,
  143. vision_feature_layer: Optional[Union[int, list[int]]] = None,
  144. use_cache: Optional[bool] = None,
  145. output_attentions: Optional[bool] = None,
  146. output_hidden_states: Optional[bool] = None,
  147. return_dict: Optional[bool] = None,
  148. cache_position: Optional[torch.LongTensor] = None,
  149. image_sizes: Optional[torch.Tensor] = None,
  150. **kwargs: Unpack[TransformersKwargs],
  151. ) -> Union[tuple, Mistral3ModelOutputWithPast]:
  152. output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
  153. output_hidden_states = (
  154. output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
  155. )
  156. return_dict = return_dict if return_dict is not None else self.config.use_return_dict
  157. vision_feature_layer = (
  158. vision_feature_layer if vision_feature_layer is not None else self.config.vision_feature_layer
  159. )
  160. if (input_ids is None) ^ (inputs_embeds is not None):
  161. raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
  162. if inputs_embeds is None:
  163. inputs_embeds = self.get_input_embeddings()(input_ids)
  164. if pixel_values is not None:
  165. image_features = self.get_image_features(
  166. pixel_values=pixel_values,
  167. vision_feature_layer=vision_feature_layer,
  168. image_sizes=image_sizes,
  169. )
  170. image_features = torch.cat(image_features, dim=0).to(inputs_embeds.device, inputs_embeds.dtype)
  171. special_image_mask = self.get_placeholder_mask(
  172. input_ids, inputs_embeds=inputs_embeds, image_features=image_features
  173. )
  174. inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, image_features)
  175. outputs = self.language_model(
  176. attention_mask=attention_mask,
  177. position_ids=position_ids,
  178. past_key_values=past_key_values,
  179. inputs_embeds=inputs_embeds,
  180. use_cache=use_cache,
  181. output_attentions=output_attentions,
  182. output_hidden_states=output_hidden_states,
  183. return_dict=True,
  184. cache_position=cache_position,
  185. **kwargs,
  186. )
  187. return Mistral3ModelOutputWithPast(
  188. last_hidden_state=outputs.last_hidden_state,
  189. past_key_values=outputs.past_key_values,
  190. hidden_states=outputs.hidden_states,
  191. attentions=outputs.attentions,
  192. image_hidden_states=image_features if pixel_values is not None else None,
  193. )
  194. class Mistral3ForConditionalGeneration(LlavaForConditionalGeneration):
  195. def get_image_features(
  196. self,
  197. pixel_values: torch.FloatTensor,
  198. image_sizes: torch.Tensor,
  199. vision_feature_layer: Optional[Union[int, list[int]]] = None,
  200. **kwargs,
  201. ):
  202. return self.model.get_image_features(
  203. pixel_values=pixel_values,
  204. image_sizes=image_sizes,
  205. vision_feature_layer=vision_feature_layer,
  206. **kwargs,
  207. )
  208. def forward(
  209. self,
  210. input_ids: Optional[torch.LongTensor] = None,
  211. pixel_values: Optional[torch.FloatTensor] = None,
  212. attention_mask: Optional[torch.Tensor] = None,
  213. position_ids: Optional[torch.LongTensor] = None,
  214. past_key_values: Optional[Cache] = None,
  215. inputs_embeds: Optional[torch.FloatTensor] = None,
  216. labels: Optional[torch.LongTensor] = None,
  217. use_cache: Optional[bool] = None,
  218. output_attentions: Optional[bool] = None,
  219. output_hidden_states: Optional[bool] = None,
  220. return_dict: Optional[bool] = None,
  221. cache_position: Optional[torch.LongTensor] = None,
  222. logits_to_keep: Union[int, torch.Tensor] = 0,
  223. image_sizes: Optional[torch.Tensor] = None,
  224. **kwargs: Unpack[TransformersKwargs],
  225. ) -> Union[tuple, Mistral3CausalLMOutputWithPast]:
  226. r"""
  227. labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
  228. Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
  229. config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
  230. (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
  231. Example:
  232. ```python
  233. >>> from PIL import Image
  234. >>> import requests
  235. >>> from transformers import AutoProcessor, Mistral3ForConditionalGeneration
  236. >>> model = Mistral3ForConditionalGeneration.from_pretrained("mistralai/Mistral-Small-3.1-24B-Instruct-2503")
  237. >>> processor = AutoProcessor.from_pretrained("mistralai/Mistral-Small-3.1-24B-Instruct-2503")
  238. >>> prompt = "<s>[INST][IMG]What is the image?[/INST]"
  239. >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
  240. >>> image = Image.open(requests.get(url, stream=True).raw)
  241. >>> inputs = processor(images=image, text=prompt, return_tensors="pt")
  242. >>> # Generate
  243. >>> generate_ids = model.generate(**inputs, max_new_tokens=15)
  244. >>> processor.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
  245. "What is the image?The image depicts two cats lying on a pink blanket."
  246. ```"""
  247. output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
  248. output_hidden_states = (
  249. output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
  250. )
  251. return_dict = return_dict if return_dict is not None else self.config.use_return_dict
  252. outputs = self.model(
  253. input_ids=input_ids,
  254. pixel_values=pixel_values,
  255. attention_mask=attention_mask,
  256. position_ids=position_ids,
  257. past_key_values=past_key_values,
  258. inputs_embeds=inputs_embeds,
  259. use_cache=use_cache,
  260. output_attentions=output_attentions,
  261. output_hidden_states=output_hidden_states,
  262. return_dict=True,
  263. cache_position=cache_position,
  264. image_sizes=image_sizes,
  265. **kwargs,
  266. )
  267. hidden_states = outputs[0]
  268. # Only compute necessary logits, and do not upcast them to float if we are not computing the loss
  269. slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
  270. logits = self.lm_head(hidden_states[:, slice_indices, :])
  271. loss = None
  272. if labels is not None:
  273. loss = self.loss_function(
  274. logits=logits, labels=labels, vocab_size=self.config.text_config.vocab_size, **kwargs
  275. )
  276. return Mistral3CausalLMOutputWithPast(
  277. loss=loss,
  278. logits=logits,
  279. past_key_values=outputs.past_key_values,
  280. hidden_states=outputs.hidden_states,
  281. attentions=outputs.attentions,
  282. image_hidden_states=outputs.image_hidden_states,
  283. )
  284. __all__ = [
  285. "Mistral3Model",
  286. "Mistral3PreTrainedModel",
  287. "Mistral3ForConditionalGeneration",
  288. ]