# modular_mistral.py (7.5 KB)
  1. from typing import Callable, Optional
  2. import torch
  3. from torch import nn
  4. from transformers.utils.generic import check_model_inputs
  5. from ...cache_utils import Cache, DynamicCache
  6. from ...masking_utils import create_causal_mask, create_sliding_window_causal_mask
  7. from ...modeling_flash_attention_utils import FlashAttentionKwargs
  8. from ...modeling_layers import (
  9. GenericForQuestionAnswering,
  10. )
  11. from ...modeling_outputs import BaseModelOutputWithPast
  12. from ...modeling_utils import ALL_ATTENTION_FUNCTIONS
  13. from ...processing_utils import Unpack
  14. from ...utils import TransformersKwargs, auto_docstring, logging
  15. from ...utils.deprecation import deprecate_kwarg
  16. from ..llama.modeling_llama import (
  17. LlamaAttention,
  18. LlamaDecoderLayer,
  19. LlamaForCausalLM,
  20. LlamaForSequenceClassification,
  21. LlamaForTokenClassification,
  22. LlamaMLP,
  23. LlamaModel,
  24. LlamaPreTrainedModel,
  25. apply_rotary_pos_emb,
  26. eager_attention_forward,
  27. )
  28. from .configuration_mistral import MistralConfig
# Module-level logger for this file (not referenced elsewhere in the visible code).
logger = logging.get_logger(__name__)


  30. class MistralMLP(LlamaMLP):
  31. def __init__(self, config):
  32. super().__init__(config)
  33. self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
  34. self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
  35. self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False)
  36. class MistralAttention(LlamaAttention):
  37. def __init__(self, config: MistralConfig, layer_idx: int):
  38. super().__init__(config, layer_idx)
  39. self.head_dim = getattr(config, "head_dim", None) or config.hidden_size // config.num_attention_heads
  40. self.q_proj = nn.Linear(config.hidden_size, config.num_attention_heads * self.head_dim, bias=False)
  41. self.k_proj = nn.Linear(config.hidden_size, config.num_key_value_heads * self.head_dim, bias=False)
  42. self.v_proj = nn.Linear(config.hidden_size, config.num_key_value_heads * self.head_dim, bias=False)
  43. self.o_proj = nn.Linear(config.num_attention_heads * self.head_dim, config.hidden_size, bias=False)
  44. @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58")
  45. def forward(
  46. self,
  47. hidden_states: torch.Tensor,
  48. position_embeddings: tuple[torch.Tensor, torch.Tensor],
  49. attention_mask: Optional[torch.Tensor],
  50. past_key_values: Optional[Cache] = None,
  51. cache_position: Optional[torch.LongTensor] = None,
  52. **kwargs: Unpack[FlashAttentionKwargs],
  53. ) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
  54. input_shape = hidden_states.shape[:-1]
  55. hidden_shape = (*input_shape, -1, self.head_dim)
  56. query_states = self.q_proj(hidden_states).view(hidden_shape).transpose(1, 2)
  57. key_states = self.k_proj(hidden_states).view(hidden_shape).transpose(1, 2)
  58. value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2)
  59. cos, sin = position_embeddings
  60. query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
  61. if past_key_values is not None:
  62. # sin and cos are specific to RoPE models; cache_position needed for the static cache
  63. cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
  64. key_states, value_states = past_key_values.update(key_states, value_states, self.layer_idx, cache_kwargs)
  65. attention_interface: Callable = eager_attention_forward
  66. if self.config._attn_implementation != "eager":
  67. attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]
  68. attn_output, attn_weights = attention_interface(
  69. self,
  70. query_states,
  71. key_states,
  72. value_states,
  73. attention_mask,
  74. dropout=0.0 if not self.training else self.attention_dropout,
  75. scaling=self.scaling,
  76. sliding_window=getattr(self.config, "sliding_window", None), # main diff with Llama
  77. **kwargs,
  78. )
  79. attn_output = attn_output.reshape(*input_shape, -1).contiguous()
  80. attn_output = self.o_proj(attn_output)
  81. return attn_output, attn_weights
  82. class MistralDecoderLayer(LlamaDecoderLayer):
  83. def __init__(self, config: MistralConfig, layer_idx: int):
  84. super().__init__(config, layer_idx)
  85. self.self_attn = MistralAttention(config=config, layer_idx=layer_idx)
  86. self.mlp = MistralMLP(config)
  87. class MistralPreTrainedModel(LlamaPreTrainedModel):
  88. _can_record_outputs = {
  89. "hidden_states": MistralDecoderLayer,
  90. "attentions": MistralAttention,
  91. }
  92. class MistralModel(LlamaModel):
  93. @check_model_inputs()
  94. @auto_docstring
  95. def forward(
  96. self,
  97. input_ids: Optional[torch.LongTensor] = None,
  98. attention_mask: Optional[torch.Tensor] = None,
  99. position_ids: Optional[torch.LongTensor] = None,
  100. past_key_values: Optional[Cache] = None,
  101. inputs_embeds: Optional[torch.FloatTensor] = None,
  102. use_cache: Optional[bool] = None,
  103. cache_position: Optional[torch.LongTensor] = None,
  104. **kwargs: Unpack[TransformersKwargs],
  105. ) -> BaseModelOutputWithPast:
  106. if (input_ids is None) ^ (inputs_embeds is not None):
  107. raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
  108. if inputs_embeds is None:
  109. inputs_embeds = self.embed_tokens(input_ids)
  110. if use_cache and past_key_values is None:
  111. past_key_values = DynamicCache(config=self.config)
  112. if cache_position is None:
  113. past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
  114. cache_position = torch.arange(
  115. past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device
  116. )
  117. if position_ids is None:
  118. position_ids = cache_position.unsqueeze(0)
  119. mask_function = create_causal_mask if self.config.sliding_window is None else create_sliding_window_causal_mask
  120. causal_mask = mask_function(
  121. config=self.config,
  122. input_embeds=inputs_embeds,
  123. attention_mask=attention_mask,
  124. cache_position=cache_position,
  125. past_key_values=past_key_values,
  126. position_ids=position_ids,
  127. )
  128. hidden_states = inputs_embeds
  129. position_embeddings = self.rotary_emb(hidden_states, position_ids)
  130. for decoder_layer in self.layers[: self.config.num_hidden_layers]:
  131. hidden_states = decoder_layer(
  132. hidden_states,
  133. attention_mask=causal_mask,
  134. position_ids=position_ids,
  135. past_key_values=past_key_values,
  136. use_cache=use_cache,
  137. cache_position=cache_position,
  138. position_embeddings=position_embeddings,
  139. **kwargs,
  140. )
  141. hidden_states = self.norm(hidden_states)
  142. return BaseModelOutputWithPast(
  143. last_hidden_state=hidden_states,
  144. past_key_values=past_key_values if use_cache else None,
  145. )
  146. class MistralForCausalLM(LlamaForCausalLM):
  147. pass
  148. class MistralForTokenClassification(LlamaForTokenClassification):
  149. pass
  150. class MistralForSequenceClassification(LlamaForSequenceClassification):
  151. pass
  152. class MistralForQuestionAnswering(GenericForQuestionAnswering, MistralPreTrainedModel): ...
# Public names exported by this module.
__all__ = [
    "MistralForCausalLM",
    "MistralForQuestionAnswering",
    "MistralModel",
    "MistralPreTrainedModel",
    "MistralForSequenceClassification",
    "MistralForTokenClassification",
]