modeling_mistral.py 20 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480
  1. # 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
  2. # This file was automatically generated from src/transformers/models/mistral/modular_mistral.py.
  3. # Do NOT edit this file manually as any edits will be overwritten by the generation of
  4. # the file from the modular. If any change should be done, please apply the change to the
  5. # modular_mistral.py file directly. One of our CI enforces this.
  6. # 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
  7. from typing import Callable, Optional, Union
  8. import torch
  9. from torch import nn
  10. from transformers.utils.generic import check_model_inputs
  11. from ...activations import ACT2FN
  12. from ...cache_utils import Cache, DynamicCache
  13. from ...generation import GenerationMixin
  14. from ...integrations import use_kernel_forward_from_hub
  15. from ...masking_utils import create_causal_mask, create_sliding_window_causal_mask
  16. from ...modeling_flash_attention_utils import FlashAttentionKwargs
  17. from ...modeling_layers import (
  18. GenericForQuestionAnswering,
  19. GenericForSequenceClassification,
  20. GenericForTokenClassification,
  21. GradientCheckpointingLayer,
  22. )
  23. from ...modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
  24. from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update
  25. from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
  26. from ...processing_utils import Unpack
  27. from ...utils import TransformersKwargs, auto_docstring, can_return_tuple
  28. from ...utils.deprecation import deprecate_kwarg
  29. from .configuration_mistral import MistralConfig
  30. class MistralMLP(nn.Module):
  31. def __init__(self, config):
  32. super().__init__()
  33. self.config = config
  34. self.hidden_size = config.hidden_size
  35. self.intermediate_size = config.intermediate_size
  36. self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
  37. self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
  38. self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False)
  39. self.act_fn = ACT2FN[config.hidden_act]
  40. def forward(self, x):
  41. down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))
  42. return down_proj
  43. def rotate_half(x):
  44. """Rotates half the hidden dims of the input."""
  45. x1 = x[..., : x.shape[-1] // 2]
  46. x2 = x[..., x.shape[-1] // 2 :]
  47. return torch.cat((-x2, x1), dim=-1)
  48. def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
  49. """Applies Rotary Position Embedding to the query and key tensors.
  50. Args:
  51. q (`torch.Tensor`): The query tensor.
  52. k (`torch.Tensor`): The key tensor.
  53. cos (`torch.Tensor`): The cosine part of the rotary embedding.
  54. sin (`torch.Tensor`): The sine part of the rotary embedding.
  55. position_ids (`torch.Tensor`, *optional*):
  56. Deprecated and unused.
  57. unsqueeze_dim (`int`, *optional*, defaults to 1):
  58. The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and
  59. sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note
  60. that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and
  61. k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes
  62. cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have
  63. the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2.
  64. Returns:
  65. `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding.
  66. """
  67. cos = cos.unsqueeze(unsqueeze_dim)
  68. sin = sin.unsqueeze(unsqueeze_dim)
  69. q_embed = (q * cos) + (rotate_half(q) * sin)
  70. k_embed = (k * cos) + (rotate_half(k) * sin)
  71. return q_embed, k_embed
  72. def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
  73. """
  74. This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch,
  75. num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim)
  76. """
  77. batch, num_key_value_heads, slen, head_dim = hidden_states.shape
  78. if n_rep == 1:
  79. return hidden_states
  80. hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim)
  81. return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
  82. def eager_attention_forward(
  83. module: nn.Module,
  84. query: torch.Tensor,
  85. key: torch.Tensor,
  86. value: torch.Tensor,
  87. attention_mask: Optional[torch.Tensor],
  88. scaling: float,
  89. dropout: float = 0.0,
  90. **kwargs: Unpack[TransformersKwargs],
  91. ):
  92. key_states = repeat_kv(key, module.num_key_value_groups)
  93. value_states = repeat_kv(value, module.num_key_value_groups)
  94. attn_weights = torch.matmul(query, key_states.transpose(2, 3)) * scaling
  95. if attention_mask is not None:
  96. causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]
  97. attn_weights = attn_weights + causal_mask
  98. attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype)
  99. attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training)
  100. attn_output = torch.matmul(attn_weights, value_states)
  101. attn_output = attn_output.transpose(1, 2).contiguous()
  102. return attn_output, attn_weights
  103. class MistralAttention(nn.Module):
  104. """Multi-headed attention from 'Attention Is All You Need' paper"""
  105. def __init__(self, config: MistralConfig, layer_idx: int):
  106. super().__init__()
  107. self.config = config
  108. self.layer_idx = layer_idx
  109. self.head_dim = getattr(config, "head_dim", None) or config.hidden_size // config.num_attention_heads
  110. self.num_key_value_groups = config.num_attention_heads // config.num_key_value_heads
  111. self.scaling = self.head_dim**-0.5
  112. self.attention_dropout = config.attention_dropout
  113. self.is_causal = True
  114. self.q_proj = nn.Linear(config.hidden_size, config.num_attention_heads * self.head_dim, bias=False)
  115. self.k_proj = nn.Linear(config.hidden_size, config.num_key_value_heads * self.head_dim, bias=False)
  116. self.v_proj = nn.Linear(config.hidden_size, config.num_key_value_heads * self.head_dim, bias=False)
  117. self.o_proj = nn.Linear(config.num_attention_heads * self.head_dim, config.hidden_size, bias=False)
  118. @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58")
  119. def forward(
  120. self,
  121. hidden_states: torch.Tensor,
  122. position_embeddings: tuple[torch.Tensor, torch.Tensor],
  123. attention_mask: Optional[torch.Tensor],
  124. past_key_values: Optional[Cache] = None,
  125. cache_position: Optional[torch.LongTensor] = None,
  126. **kwargs: Unpack[FlashAttentionKwargs],
  127. ) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
  128. input_shape = hidden_states.shape[:-1]
  129. hidden_shape = (*input_shape, -1, self.head_dim)
  130. query_states = self.q_proj(hidden_states).view(hidden_shape).transpose(1, 2)
  131. key_states = self.k_proj(hidden_states).view(hidden_shape).transpose(1, 2)
  132. value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2)
  133. cos, sin = position_embeddings
  134. query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
  135. if past_key_values is not None:
  136. # sin and cos are specific to RoPE models; cache_position needed for the static cache
  137. cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
  138. key_states, value_states = past_key_values.update(key_states, value_states, self.layer_idx, cache_kwargs)
  139. attention_interface: Callable = eager_attention_forward
  140. if self.config._attn_implementation != "eager":
  141. attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]
  142. attn_output, attn_weights = attention_interface(
  143. self,
  144. query_states,
  145. key_states,
  146. value_states,
  147. attention_mask,
  148. dropout=0.0 if not self.training else self.attention_dropout,
  149. scaling=self.scaling,
  150. sliding_window=getattr(self.config, "sliding_window", None), # main diff with Llama
  151. **kwargs,
  152. )
  153. attn_output = attn_output.reshape(*input_shape, -1).contiguous()
  154. attn_output = self.o_proj(attn_output)
  155. return attn_output, attn_weights
  156. @use_kernel_forward_from_hub("RMSNorm")
  157. class MistralRMSNorm(nn.Module):
  158. def __init__(self, hidden_size, eps=1e-6):
  159. """
  160. MistralRMSNorm is equivalent to T5LayerNorm
  161. """
  162. super().__init__()
  163. self.weight = nn.Parameter(torch.ones(hidden_size))
  164. self.variance_epsilon = eps
  165. def forward(self, hidden_states):
  166. input_dtype = hidden_states.dtype
  167. hidden_states = hidden_states.to(torch.float32)
  168. variance = hidden_states.pow(2).mean(-1, keepdim=True)
  169. hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
  170. return self.weight * hidden_states.to(input_dtype)
  171. def extra_repr(self):
  172. return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}"
  173. class MistralDecoderLayer(GradientCheckpointingLayer):
  174. def __init__(self, config: MistralConfig, layer_idx: int):
  175. super().__init__()
  176. self.hidden_size = config.hidden_size
  177. self.self_attn = MistralAttention(config=config, layer_idx=layer_idx)
  178. self.mlp = MistralMLP(config)
  179. self.input_layernorm = MistralRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
  180. self.post_attention_layernorm = MistralRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
  181. @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58")
  182. def forward(
  183. self,
  184. hidden_states: torch.Tensor,
  185. attention_mask: Optional[torch.Tensor] = None,
  186. position_ids: Optional[torch.LongTensor] = None,
  187. past_key_values: Optional[Cache] = None,
  188. use_cache: Optional[bool] = False,
  189. cache_position: Optional[torch.LongTensor] = None,
  190. position_embeddings: Optional[tuple[torch.Tensor, torch.Tensor]] = None, # necessary, but kept here for BC
  191. **kwargs: Unpack[TransformersKwargs],
  192. ) -> torch.Tensor:
  193. residual = hidden_states
  194. hidden_states = self.input_layernorm(hidden_states)
  195. # Self Attention
  196. hidden_states, _ = self.self_attn(
  197. hidden_states=hidden_states,
  198. attention_mask=attention_mask,
  199. position_ids=position_ids,
  200. past_key_values=past_key_values,
  201. use_cache=use_cache,
  202. cache_position=cache_position,
  203. position_embeddings=position_embeddings,
  204. **kwargs,
  205. )
  206. hidden_states = residual + hidden_states
  207. # Fully Connected
  208. residual = hidden_states
  209. hidden_states = self.post_attention_layernorm(hidden_states)
  210. hidden_states = self.mlp(hidden_states)
  211. hidden_states = residual + hidden_states
  212. return hidden_states
  213. @auto_docstring
  214. class MistralPreTrainedModel(PreTrainedModel):
  215. config: MistralConfig
  216. base_model_prefix = "model"
  217. supports_gradient_checkpointing = True
  218. _no_split_modules = ["MistralDecoderLayer"]
  219. _skip_keys_device_placement = ["past_key_values"]
  220. _supports_flash_attn = True
  221. _supports_sdpa = True
  222. _supports_flex_attn = True
  223. _can_compile_fullgraph = True
  224. _supports_attention_backend = True
  225. _can_record_outputs = {
  226. "hidden_states": MistralDecoderLayer,
  227. "attentions": MistralAttention,
  228. }
  229. class MistralRotaryEmbedding(nn.Module):
  230. inv_freq: torch.Tensor # fix linting for `register_buffer`
  231. def __init__(self, config: MistralConfig, device=None):
  232. super().__init__()
  233. # BC: "rope_type" was originally "type"
  234. if hasattr(config, "rope_scaling") and isinstance(config.rope_scaling, dict):
  235. self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type"))
  236. else:
  237. self.rope_type = "default"
  238. self.max_seq_len_cached = config.max_position_embeddings
  239. self.original_max_seq_len = config.max_position_embeddings
  240. self.config = config
  241. self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]
  242. inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device)
  243. self.register_buffer("inv_freq", inv_freq, persistent=False)
  244. self.original_inv_freq = self.inv_freq
  245. @torch.no_grad()
  246. @dynamic_rope_update # power user: used with advanced RoPE types (e.g. dynamic rope)
  247. def forward(self, x, position_ids):
  248. inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1).to(x.device)
  249. position_ids_expanded = position_ids[:, None, :].float()
  250. device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu"
  251. with torch.autocast(device_type=device_type, enabled=False): # Force float32
  252. freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
  253. emb = torch.cat((freqs, freqs), dim=-1)
  254. cos = emb.cos() * self.attention_scaling
  255. sin = emb.sin() * self.attention_scaling
  256. return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
  257. @auto_docstring
  258. class MistralModel(MistralPreTrainedModel):
  259. def __init__(self, config: MistralConfig):
  260. super().__init__(config)
  261. self.padding_idx = config.pad_token_id
  262. self.vocab_size = config.vocab_size
  263. self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
  264. self.layers = nn.ModuleList(
  265. [MistralDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
  266. )
  267. self.norm = MistralRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
  268. self.rotary_emb = MistralRotaryEmbedding(config=config)
  269. self.gradient_checkpointing = False
  270. # Initialize weights and apply final processing
  271. self.post_init()
  272. @check_model_inputs()
  273. @auto_docstring
  274. def forward(
  275. self,
  276. input_ids: Optional[torch.LongTensor] = None,
  277. attention_mask: Optional[torch.Tensor] = None,
  278. position_ids: Optional[torch.LongTensor] = None,
  279. past_key_values: Optional[Cache] = None,
  280. inputs_embeds: Optional[torch.FloatTensor] = None,
  281. use_cache: Optional[bool] = None,
  282. cache_position: Optional[torch.LongTensor] = None,
  283. **kwargs: Unpack[TransformersKwargs],
  284. ) -> BaseModelOutputWithPast:
  285. if (input_ids is None) ^ (inputs_embeds is not None):
  286. raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
  287. if inputs_embeds is None:
  288. inputs_embeds = self.embed_tokens(input_ids)
  289. if use_cache and past_key_values is None:
  290. past_key_values = DynamicCache(config=self.config)
  291. if cache_position is None:
  292. past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
  293. cache_position = torch.arange(
  294. past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device
  295. )
  296. if position_ids is None:
  297. position_ids = cache_position.unsqueeze(0)
  298. mask_function = create_causal_mask if self.config.sliding_window is None else create_sliding_window_causal_mask
  299. causal_mask = mask_function(
  300. config=self.config,
  301. input_embeds=inputs_embeds,
  302. attention_mask=attention_mask,
  303. cache_position=cache_position,
  304. past_key_values=past_key_values,
  305. position_ids=position_ids,
  306. )
  307. hidden_states = inputs_embeds
  308. position_embeddings = self.rotary_emb(hidden_states, position_ids)
  309. for decoder_layer in self.layers[: self.config.num_hidden_layers]:
  310. hidden_states = decoder_layer(
  311. hidden_states,
  312. attention_mask=causal_mask,
  313. position_ids=position_ids,
  314. past_key_values=past_key_values,
  315. use_cache=use_cache,
  316. cache_position=cache_position,
  317. position_embeddings=position_embeddings,
  318. **kwargs,
  319. )
  320. hidden_states = self.norm(hidden_states)
  321. return BaseModelOutputWithPast(
  322. last_hidden_state=hidden_states,
  323. past_key_values=past_key_values if use_cache else None,
  324. )
  325. @auto_docstring
  326. class MistralForCausalLM(MistralPreTrainedModel, GenerationMixin):
  327. _tied_weights_keys = ["lm_head.weight"]
  328. _tp_plan = {"lm_head": "colwise_rep"}
  329. _pp_plan = {"lm_head": (["hidden_states"], ["logits"])}
  330. def __init__(self, config):
  331. super().__init__(config)
  332. self.model = MistralModel(config)
  333. self.vocab_size = config.vocab_size
  334. self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
  335. # Initialize weights and apply final processing
  336. self.post_init()
  337. @can_return_tuple
  338. @auto_docstring
  339. def forward(
  340. self,
  341. input_ids: Optional[torch.LongTensor] = None,
  342. attention_mask: Optional[torch.Tensor] = None,
  343. position_ids: Optional[torch.LongTensor] = None,
  344. past_key_values: Optional[Cache] = None,
  345. inputs_embeds: Optional[torch.FloatTensor] = None,
  346. labels: Optional[torch.LongTensor] = None,
  347. use_cache: Optional[bool] = None,
  348. cache_position: Optional[torch.LongTensor] = None,
  349. logits_to_keep: Union[int, torch.Tensor] = 0,
  350. **kwargs: Unpack[TransformersKwargs],
  351. ) -> CausalLMOutputWithPast:
  352. r"""
  353. Example:
  354. ```python
  355. >>> from transformers import AutoTokenizer, MistralForCausalLM
  356. >>> model = MistralForCausalLM.from_pretrained("meta-mistral/Mistral-2-7b-hf")
  357. >>> tokenizer = AutoTokenizer.from_pretrained("meta-mistral/Mistral-2-7b-hf")
  358. >>> prompt = "Hey, are you conscious? Can you talk to me?"
  359. >>> inputs = tokenizer(prompt, return_tensors="pt")
  360. >>> # Generate
  361. >>> generate_ids = model.generate(inputs.input_ids, max_length=30)
  362. >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
  363. "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you."
  364. ```"""
  365. outputs: BaseModelOutputWithPast = self.model(
  366. input_ids=input_ids,
  367. attention_mask=attention_mask,
  368. position_ids=position_ids,
  369. past_key_values=past_key_values,
  370. inputs_embeds=inputs_embeds,
  371. use_cache=use_cache,
  372. cache_position=cache_position,
  373. **kwargs,
  374. )
  375. hidden_states = outputs.last_hidden_state
  376. # Only compute necessary logits, and do not upcast them to float if we are not computing the loss
  377. slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
  378. logits = self.lm_head(hidden_states[:, slice_indices, :])
  379. loss = None
  380. if labels is not None:
  381. loss = self.loss_function(logits=logits, labels=labels, vocab_size=self.config.vocab_size, **kwargs)
  382. return CausalLMOutputWithPast(
  383. loss=loss,
  384. logits=logits,
  385. past_key_values=outputs.past_key_values,
  386. hidden_states=outputs.hidden_states,
  387. attentions=outputs.attentions,
  388. )
  389. class MistralForTokenClassification(GenericForTokenClassification, MistralPreTrainedModel):
  390. pass
  391. class MistralForSequenceClassification(GenericForSequenceClassification, MistralPreTrainedModel):
  392. pass
  393. class MistralForQuestionAnswering(GenericForQuestionAnswering, MistralPreTrainedModel): ...
  394. __all__ = [
  395. "MistralForCausalLM",
  396. "MistralForQuestionAnswering",
  397. "MistralModel",
  398. "MistralPreTrainedModel",
  399. "MistralForSequenceClassification",
  400. "MistralForTokenClassification",
  401. ]