modeling_mixtral.py 30 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689
  1. # 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
  2. # This file was automatically generated from src/transformers/models/mixtral/modular_mixtral.py.
  3. # Do NOT edit this file manually as any edits will be overwritten by the generation of
  4. # the file from the modular. If any change should be done, please apply the change to the
  5. # modular_mixtral.py file directly. One of our CI enforces this.
  6. # 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
  7. # coding=utf-8
  8. # Copyright 2023 Mistral AI and the HuggingFace Inc. team. All rights reserved.
  9. #
  10. # This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
  11. # and OPT implementations in this library. It has been modified from its
  12. # original forms to accommodate minor architectural differences compared
  13. # to GPT-NeoX and OPT used by the Meta AI team that trained the model.
  14. #
  15. # Licensed under the Apache License, Version 2.0 (the "License");
  16. # you may not use this file except in compliance with the License.
  17. # You may obtain a copy of the License at
  18. #
  19. # http://www.apache.org/licenses/LICENSE-2.0
  20. #
  21. # Unless required by applicable law or agreed to in writing, software
  22. # distributed under the License is distributed on an "AS IS" BASIS,
  23. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  24. # See the License for the specific language governing permissions and
  25. # limitations under the License.
  26. from typing import Callable, Optional, Union
  27. import torch
  28. import torch.nn.functional as F
  29. from torch import nn
  30. from transformers.utils.generic import check_model_inputs
  31. from ...activations import ACT2FN
  32. from ...cache_utils import Cache, DynamicCache
  33. from ...generation import GenerationMixin
  34. from ...integrations import use_kernel_forward_from_hub
  35. from ...masking_utils import create_causal_mask, create_sliding_window_causal_mask
  36. from ...modeling_flash_attention_utils import FlashAttentionKwargs
  37. from ...modeling_layers import (
  38. GenericForQuestionAnswering,
  39. GenericForSequenceClassification,
  40. GenericForTokenClassification,
  41. GradientCheckpointingLayer,
  42. )
  43. from ...modeling_outputs import MoeCausalLMOutputWithPast, MoeModelOutputWithPast
  44. from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update
  45. from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
  46. from ...processing_utils import Unpack
  47. from ...utils import TransformersKwargs, auto_docstring, can_return_tuple
  48. from ...utils.deprecation import deprecate_kwarg
  49. from ...utils.generic import OutputRecorder
  50. from .configuration_mixtral import MixtralConfig
  51. class MixtralBlockSparseTop2MLP(nn.Module):
  52. def __init__(self, config: MixtralConfig):
  53. super().__init__()
  54. self.ffn_dim = config.intermediate_size
  55. self.hidden_dim = config.hidden_size
  56. self.w1 = nn.Linear(self.hidden_dim, self.ffn_dim, bias=False)
  57. self.w2 = nn.Linear(self.ffn_dim, self.hidden_dim, bias=False)
  58. self.w3 = nn.Linear(self.hidden_dim, self.ffn_dim, bias=False)
  59. self.act_fn = ACT2FN[config.hidden_act]
  60. def forward(self, hidden_states):
  61. current_hidden_states = self.act_fn(self.w1(hidden_states)) * self.w3(hidden_states)
  62. current_hidden_states = self.w2(current_hidden_states)
  63. return current_hidden_states
  64. class MixtralSparseMoeBlock(nn.Module):
  65. """
  66. This implementation is
  67. strictly equivalent to standard MoE with full capacity (no
  68. dropped tokens). It's faster since it formulates MoE operations
  69. in terms of block-sparse operations to accommodate imbalanced
  70. assignments of tokens to experts, whereas standard MoE either
  71. (1) drop tokens at the cost of reduced performance or (2) set
  72. capacity factor to number of experts and thus waste computation
  73. and memory on padding.
  74. """
  75. def __init__(self, config):
  76. super().__init__()
  77. self.hidden_dim = config.hidden_size
  78. self.ffn_dim = config.intermediate_size
  79. self.num_experts = config.num_local_experts
  80. self.top_k = config.num_experts_per_tok
  81. # gating
  82. self.gate = nn.Linear(self.hidden_dim, self.num_experts, bias=False)
  83. self.experts = nn.ModuleList([MixtralBlockSparseTop2MLP(config) for _ in range(self.num_experts)])
  84. # Jitter parameters
  85. self.jitter_noise = config.router_jitter_noise
  86. def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
  87. """ """
  88. batch_size, sequence_length, hidden_dim = hidden_states.shape
  89. if self.training and self.jitter_noise > 0:
  90. hidden_states *= torch.empty_like(hidden_states).uniform_(1.0 - self.jitter_noise, 1.0 + self.jitter_noise)
  91. hidden_states = hidden_states.view(-1, hidden_dim)
  92. # router_logits: (batch * sequence_length, n_experts)
  93. router_logits = self.gate(hidden_states)
  94. routing_weights = F.softmax(router_logits, dim=1, dtype=torch.float)
  95. routing_weights, selected_experts = torch.topk(routing_weights, self.top_k, dim=-1)
  96. routing_weights /= routing_weights.sum(dim=-1, keepdim=True)
  97. # we cast back to the input dtype
  98. routing_weights = routing_weights.to(hidden_states.dtype)
  99. final_hidden_states = torch.zeros(
  100. (batch_size * sequence_length, hidden_dim), dtype=hidden_states.dtype, device=hidden_states.device
  101. )
  102. # One hot encode the selected experts to create an expert mask
  103. # this will be used to easily index which expert is going to be sollicitated
  104. expert_mask = torch.nn.functional.one_hot(selected_experts, num_classes=self.num_experts).permute(2, 1, 0)
  105. expert_hit = torch.greater(expert_mask.sum(dim=(-1, -2)), 0).nonzero()
  106. for expert_idx in expert_hit:
  107. expert_layer = self.experts[expert_idx]
  108. idx, top_x = torch.where(expert_mask[expert_idx].squeeze(0))
  109. # Index the correct hidden states and compute the expert hidden state for
  110. # the current expert. We need to make sure to multiply the output hidden
  111. # states by `routing_weights` on the corresponding tokens (top-1 and top-2)
  112. current_state = hidden_states[None, top_x].reshape(-1, hidden_dim)
  113. current_hidden_states = expert_layer(current_state) * routing_weights[top_x, idx, None]
  114. # However `index_add_` only support torch tensors for indexing so we'll use
  115. # the `top_x` tensor here.
  116. final_hidden_states.index_add_(0, top_x, current_hidden_states.to(hidden_states.dtype))
  117. final_hidden_states = final_hidden_states.reshape(batch_size, sequence_length, hidden_dim)
  118. return final_hidden_states, router_logits
  119. @use_kernel_forward_from_hub("RMSNorm")
  120. class MixtralRMSNorm(nn.Module):
  121. def __init__(self, hidden_size, eps=1e-6):
  122. """
  123. MixtralRMSNorm is equivalent to T5LayerNorm
  124. """
  125. super().__init__()
  126. self.weight = nn.Parameter(torch.ones(hidden_size))
  127. self.variance_epsilon = eps
  128. def forward(self, hidden_states):
  129. input_dtype = hidden_states.dtype
  130. hidden_states = hidden_states.to(torch.float32)
  131. variance = hidden_states.pow(2).mean(-1, keepdim=True)
  132. hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
  133. return self.weight * hidden_states.to(input_dtype)
  134. def extra_repr(self):
  135. return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}"
  136. def rotate_half(x):
  137. """Rotates half the hidden dims of the input."""
  138. x1 = x[..., : x.shape[-1] // 2]
  139. x2 = x[..., x.shape[-1] // 2 :]
  140. return torch.cat((-x2, x1), dim=-1)
  141. def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
  142. """Applies Rotary Position Embedding to the query and key tensors.
  143. Args:
  144. q (`torch.Tensor`): The query tensor.
  145. k (`torch.Tensor`): The key tensor.
  146. cos (`torch.Tensor`): The cosine part of the rotary embedding.
  147. sin (`torch.Tensor`): The sine part of the rotary embedding.
  148. position_ids (`torch.Tensor`, *optional*):
  149. Deprecated and unused.
  150. unsqueeze_dim (`int`, *optional*, defaults to 1):
  151. The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and
  152. sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note
  153. that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and
  154. k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes
  155. cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have
  156. the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2.
  157. Returns:
  158. `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding.
  159. """
  160. cos = cos.unsqueeze(unsqueeze_dim)
  161. sin = sin.unsqueeze(unsqueeze_dim)
  162. q_embed = (q * cos) + (rotate_half(q) * sin)
  163. k_embed = (k * cos) + (rotate_half(k) * sin)
  164. return q_embed, k_embed
  165. def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
  166. """
  167. This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch,
  168. num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim)
  169. """
  170. batch, num_key_value_heads, slen, head_dim = hidden_states.shape
  171. if n_rep == 1:
  172. return hidden_states
  173. hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim)
  174. return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
  175. def eager_attention_forward(
  176. module: nn.Module,
  177. query: torch.Tensor,
  178. key: torch.Tensor,
  179. value: torch.Tensor,
  180. attention_mask: Optional[torch.Tensor],
  181. scaling: float,
  182. dropout: float = 0.0,
  183. **kwargs: Unpack[TransformersKwargs],
  184. ):
  185. key_states = repeat_kv(key, module.num_key_value_groups)
  186. value_states = repeat_kv(value, module.num_key_value_groups)
  187. attn_weights = torch.matmul(query, key_states.transpose(2, 3)) * scaling
  188. if attention_mask is not None:
  189. causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]
  190. attn_weights = attn_weights + causal_mask
  191. attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype)
  192. attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training)
  193. attn_output = torch.matmul(attn_weights, value_states)
  194. attn_output = attn_output.transpose(1, 2).contiguous()
  195. return attn_output, attn_weights
  196. class MixtralAttention(nn.Module):
  197. """Multi-headed attention from 'Attention Is All You Need' paper"""
  198. def __init__(self, config: MixtralConfig, layer_idx: int):
  199. super().__init__()
  200. self.config = config
  201. self.layer_idx = layer_idx
  202. self.head_dim = getattr(config, "head_dim", None) or config.hidden_size // config.num_attention_heads
  203. self.num_key_value_groups = config.num_attention_heads // config.num_key_value_heads
  204. self.scaling = self.head_dim**-0.5
  205. self.attention_dropout = config.attention_dropout
  206. self.is_causal = True
  207. self.q_proj = nn.Linear(config.hidden_size, config.num_attention_heads * self.head_dim, bias=False)
  208. self.k_proj = nn.Linear(config.hidden_size, config.num_key_value_heads * self.head_dim, bias=False)
  209. self.v_proj = nn.Linear(config.hidden_size, config.num_key_value_heads * self.head_dim, bias=False)
  210. self.o_proj = nn.Linear(config.num_attention_heads * self.head_dim, config.hidden_size, bias=False)
  211. @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58")
  212. def forward(
  213. self,
  214. hidden_states: torch.Tensor,
  215. position_embeddings: tuple[torch.Tensor, torch.Tensor],
  216. attention_mask: Optional[torch.Tensor],
  217. past_key_values: Optional[Cache] = None,
  218. cache_position: Optional[torch.LongTensor] = None,
  219. **kwargs: Unpack[FlashAttentionKwargs],
  220. ) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
  221. input_shape = hidden_states.shape[:-1]
  222. hidden_shape = (*input_shape, -1, self.head_dim)
  223. query_states = self.q_proj(hidden_states).view(hidden_shape).transpose(1, 2)
  224. key_states = self.k_proj(hidden_states).view(hidden_shape).transpose(1, 2)
  225. value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2)
  226. cos, sin = position_embeddings
  227. query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
  228. if past_key_values is not None:
  229. # sin and cos are specific to RoPE models; cache_position needed for the static cache
  230. cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
  231. key_states, value_states = past_key_values.update(key_states, value_states, self.layer_idx, cache_kwargs)
  232. attention_interface: Callable = eager_attention_forward
  233. if self.config._attn_implementation != "eager":
  234. attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]
  235. attn_output, attn_weights = attention_interface(
  236. self,
  237. query_states,
  238. key_states,
  239. value_states,
  240. attention_mask,
  241. dropout=0.0 if not self.training else self.attention_dropout,
  242. scaling=self.scaling,
  243. sliding_window=getattr(self.config, "sliding_window", None), # main diff with Llama
  244. **kwargs,
  245. )
  246. attn_output = attn_output.reshape(*input_shape, -1).contiguous()
  247. attn_output = self.o_proj(attn_output)
  248. return attn_output, attn_weights
  249. class MixtralDecoderLayer(GradientCheckpointingLayer):
  250. def __init__(self, config: MixtralConfig, layer_idx: int):
  251. super().__init__()
  252. self.hidden_size = config.hidden_size
  253. self.self_attn = MixtralAttention(config, layer_idx)
  254. self.block_sparse_moe = MixtralSparseMoeBlock(config)
  255. self.input_layernorm = MixtralRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
  256. self.post_attention_layernorm = MixtralRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
  257. @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58")
  258. def forward(
  259. self,
  260. hidden_states: torch.Tensor,
  261. position_embeddings: tuple[torch.Tensor, torch.Tensor],
  262. attention_mask: Optional[torch.Tensor] = None,
  263. position_ids: Optional[torch.LongTensor] = None,
  264. past_key_values: Optional[Cache] = None,
  265. cache_position: Optional[torch.LongTensor] = None,
  266. **kwargs: Unpack[TransformersKwargs],
  267. ) -> torch.FloatTensor:
  268. residual = hidden_states
  269. hidden_states = self.input_layernorm(hidden_states)
  270. # Self Attention
  271. hidden_states, _ = self.self_attn(
  272. hidden_states=hidden_states,
  273. position_embeddings=position_embeddings,
  274. attention_mask=attention_mask,
  275. position_ids=position_ids,
  276. past_key_values=past_key_values,
  277. cache_position=cache_position,
  278. **kwargs,
  279. )
  280. hidden_states = residual + hidden_states
  281. # Fully Connected
  282. residual = hidden_states
  283. hidden_states = self.post_attention_layernorm(hidden_states)
  284. hidden_states, _ = self.block_sparse_moe(hidden_states)
  285. hidden_states = residual + hidden_states
  286. return hidden_states
  287. class MixtralRotaryEmbedding(nn.Module):
  288. inv_freq: torch.Tensor # fix linting for `register_buffer`
  289. def __init__(self, config: MixtralConfig, device=None):
  290. super().__init__()
  291. # BC: "rope_type" was originally "type"
  292. if hasattr(config, "rope_scaling") and isinstance(config.rope_scaling, dict):
  293. self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type"))
  294. else:
  295. self.rope_type = "default"
  296. self.max_seq_len_cached = config.max_position_embeddings
  297. self.original_max_seq_len = config.max_position_embeddings
  298. self.config = config
  299. self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]
  300. inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device)
  301. self.register_buffer("inv_freq", inv_freq, persistent=False)
  302. self.original_inv_freq = self.inv_freq
  303. @torch.no_grad()
  304. @dynamic_rope_update # power user: used with advanced RoPE types (e.g. dynamic rope)
  305. def forward(self, x, position_ids):
  306. inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1).to(x.device)
  307. position_ids_expanded = position_ids[:, None, :].float()
  308. device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu"
  309. with torch.autocast(device_type=device_type, enabled=False): # Force float32
  310. freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
  311. emb = torch.cat((freqs, freqs), dim=-1)
  312. cos = emb.cos() * self.attention_scaling
  313. sin = emb.sin() * self.attention_scaling
  314. return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
@auto_docstring
class MixtralPreTrainedModel(PreTrainedModel):
    # Common base class for every Mixtral model/head in this file. It only declares class-level metadata;
    # NOTE(review): these attributes are presumably consumed by name by `PreTrainedModel` — confirm there.
    config: MixtralConfig
    base_model_prefix = "model"
    supports_gradient_checkpointing = True
    # Each decoder layer is kept whole (not split across devices) when dispatching the model.
    _no_split_modules = ["MixtralDecoderLayer"]
    _skip_keys_device_placement = ["past_key_values"]
    _supports_flash_attn = True
    _supports_sdpa = True
    _supports_flex_attn = True
    _can_compile_fullgraph = False  # MoE models don't work with torch.compile (`torch.where(condition)` not supported)
    _supports_attention_backend = True
    # Maps output names to the modules whose outputs are recorded. `index=1` selects the second element of
    # MixtralSparseMoeBlock's returned `(hidden_states, router_logits)` tuple, i.e. the router logits.
    _can_record_outputs = {
        "router_logits": OutputRecorder(MixtralSparseMoeBlock, index=1),
        "hidden_states": MixtralDecoderLayer,
        "attentions": MixtralAttention,
    }
  332. @auto_docstring
  333. class MixtralModel(MixtralPreTrainedModel):
  334. def __init__(self, config: MixtralConfig):
  335. super().__init__(config)
  336. self.padding_idx = config.pad_token_id
  337. self.vocab_size = config.vocab_size
  338. self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
  339. self.layers = nn.ModuleList(
  340. [MixtralDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
  341. )
  342. self.norm = MixtralRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
  343. self.rotary_emb = MixtralRotaryEmbedding(config=config)
  344. self.gradient_checkpointing = False
  345. # Initialize weights and apply final processing
  346. self.post_init()
  347. @check_model_inputs()
  348. @auto_docstring
  349. def forward(
  350. self,
  351. input_ids: Optional[torch.LongTensor] = None,
  352. attention_mask: Optional[torch.Tensor] = None,
  353. position_ids: Optional[torch.LongTensor] = None,
  354. past_key_values: Optional[Cache] = None,
  355. inputs_embeds: Optional[torch.FloatTensor] = None,
  356. use_cache: Optional[bool] = None,
  357. cache_position: Optional[torch.LongTensor] = None,
  358. **kwargs: Unpack[TransformersKwargs],
  359. ) -> MoeModelOutputWithPast:
  360. if (input_ids is None) ^ (inputs_embeds is not None):
  361. raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
  362. if use_cache and past_key_values is None:
  363. past_key_values = DynamicCache(config=self.config)
  364. if inputs_embeds is None:
  365. inputs_embeds = self.embed_tokens(input_ids)
  366. if cache_position is None:
  367. past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
  368. cache_position = torch.arange(
  369. past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device
  370. )
  371. if position_ids is None:
  372. position_ids = cache_position.unsqueeze(0)
  373. mask_function = create_causal_mask if self.config.sliding_window is None else create_sliding_window_causal_mask
  374. causal_mask = mask_function(
  375. config=self.config,
  376. input_embeds=inputs_embeds,
  377. attention_mask=attention_mask,
  378. cache_position=cache_position,
  379. past_key_values=past_key_values,
  380. position_ids=position_ids,
  381. )
  382. hidden_states = inputs_embeds
  383. # create position embeddings to be shared across the decoder layers
  384. position_embeddings = self.rotary_emb(hidden_states, position_ids)
  385. for decoder_layer in self.layers[: self.config.num_hidden_layers]:
  386. hidden_states = decoder_layer(
  387. hidden_states,
  388. position_embeddings=position_embeddings,
  389. attention_mask=causal_mask,
  390. position_ids=position_ids,
  391. past_key_values=past_key_values,
  392. use_cache=use_cache,
  393. cache_position=cache_position,
  394. **kwargs,
  395. )
  396. hidden_states = self.norm(hidden_states)
  397. return MoeModelOutputWithPast( # only diff with Mistral is the output type, we need MoE
  398. last_hidden_state=hidden_states,
  399. past_key_values=past_key_values,
  400. )
  401. def load_balancing_loss_func(
  402. gate_logits: Union[torch.Tensor, tuple[torch.Tensor], None],
  403. num_experts: Optional[int] = None,
  404. top_k=2,
  405. attention_mask: Optional[torch.Tensor] = None,
  406. ) -> Union[torch.Tensor, int]:
  407. r"""
  408. Computes auxiliary load balancing loss as in Switch Transformer - implemented in Pytorch.
  409. See Switch Transformer (https://huggingface.co/papers/2101.03961) for more details. This function implements the loss
  410. function presented in equations (4) - (6) of the paper. It aims at penalizing cases where the routing between
  411. experts is too unbalanced.
  412. Args:
  413. gate_logits:
  414. Logits from the `gate`, should be a tuple of model.config.num_hidden_layers tensors of
  415. shape [batch_size X sequence_length, num_experts].
  416. num_experts:
  417. Number of experts
  418. top_k:
  419. The number of experts to route per-token, can be also interpreted as the `top-k` routing
  420. parameter.
  421. attention_mask (`torch.Tensor`, *optional*):
  422. The attention_mask used in forward function
  423. shape [batch_size X sequence_length] if not None.
  424. Returns:
  425. The auxiliary loss.
  426. """
  427. if gate_logits is None or not isinstance(gate_logits, tuple):
  428. return 0
  429. if isinstance(gate_logits, tuple):
  430. compute_device = gate_logits[0].device
  431. concatenated_gate_logits = torch.cat([layer_gate.to(compute_device) for layer_gate in gate_logits], dim=0)
  432. routing_weights = torch.nn.functional.softmax(concatenated_gate_logits, dim=-1)
  433. _, selected_experts = torch.topk(routing_weights, top_k, dim=-1)
  434. expert_mask = torch.nn.functional.one_hot(selected_experts, num_experts)
  435. if attention_mask is None:
  436. # Compute the percentage of tokens routed to each experts
  437. tokens_per_expert = torch.mean(expert_mask.float(), dim=0)
  438. # Compute the average probability of routing to these experts
  439. router_prob_per_expert = torch.mean(routing_weights, dim=0)
  440. else:
  441. batch_size, sequence_length = attention_mask.shape
  442. num_hidden_layers = concatenated_gate_logits.shape[0] // (batch_size * sequence_length)
  443. # Compute the mask that masks all padding tokens as 0 with the same shape of expert_mask
  444. expert_attention_mask = (
  445. attention_mask[None, :, :, None, None]
  446. .expand((num_hidden_layers, batch_size, sequence_length, top_k, num_experts))
  447. .reshape(-1, top_k, num_experts)
  448. .to(compute_device)
  449. )
  450. # Compute the percentage of tokens routed to each experts
  451. tokens_per_expert = torch.sum(expert_mask.float() * expert_attention_mask, dim=0) / torch.sum(
  452. expert_attention_mask, dim=0
  453. )
  454. # Compute the mask that masks all padding tokens as 0 with the same shape of tokens_per_expert
  455. router_per_expert_attention_mask = (
  456. attention_mask[None, :, :, None]
  457. .expand((num_hidden_layers, batch_size, sequence_length, num_experts))
  458. .reshape(-1, num_experts)
  459. .to(compute_device)
  460. )
  461. # Compute the average probability of routing to these experts
  462. router_prob_per_expert = torch.sum(routing_weights * router_per_expert_attention_mask, dim=0) / torch.sum(
  463. router_per_expert_attention_mask, dim=0
  464. )
  465. overall_loss = torch.sum(tokens_per_expert * router_prob_per_expert.unsqueeze(0))
  466. return overall_loss * num_experts
  467. @auto_docstring
  468. class MixtralForCausalLM(MixtralPreTrainedModel, GenerationMixin):
  469. _tied_weights_keys = ["lm_head.weight"]
  470. _tp_plan = {"lm_head": "colwise_rep"}
  471. _pp_plan = {"lm_head": (["hidden_states"], ["logits"])}
  472. def __init__(self, config):
  473. super().__init__(config)
  474. self.model = MixtralModel(config)
  475. self.vocab_size = config.vocab_size
  476. self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
  477. self.router_aux_loss_coef = config.router_aux_loss_coef
  478. self.num_experts = config.num_local_experts
  479. self.num_experts_per_tok = config.num_experts_per_tok
  480. # Initialize weights and apply final processing
  481. self.post_init()
  482. @can_return_tuple
  483. @auto_docstring
  484. def forward(
  485. self,
  486. input_ids: Optional[torch.LongTensor] = None,
  487. attention_mask: Optional[torch.Tensor] = None,
  488. position_ids: Optional[torch.LongTensor] = None,
  489. past_key_values: Optional[Cache] = None,
  490. inputs_embeds: Optional[torch.FloatTensor] = None,
  491. labels: Optional[torch.LongTensor] = None,
  492. use_cache: Optional[bool] = None,
  493. output_router_logits: Optional[bool] = None,
  494. cache_position: Optional[torch.LongTensor] = None,
  495. logits_to_keep: Union[int, torch.Tensor] = 0,
  496. **kwargs: Unpack[TransformersKwargs],
  497. ) -> MoeCausalLMOutputWithPast:
  498. r"""
  499. labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
  500. Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
  501. config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
  502. (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
  503. Example:
  504. ```python
  505. >>> from transformers import AutoTokenizer, MixtralForCausalLM
  506. >>> model = MixtralForCausalLM.from_pretrained("mistralai/Mixtral-8x7B-v0.1")
  507. >>> tokenizer = AutoTokenizer.from_pretrained("mistralai/Mixtral-8x7B-v0.1")
  508. >>> prompt = "Hey, are you conscious? Can you talk to me?"
  509. >>> inputs = tokenizer(prompt, return_tensors="pt")
  510. >>> # Generate
  511. >>> generate_ids = model.generate(inputs.input_ids, max_length=30)
  512. >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
  513. "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you."
  514. ```"""
  515. output_router_logits = (
  516. output_router_logits if output_router_logits is not None else self.config.output_router_logits
  517. )
  518. # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
  519. outputs: MoeModelOutputWithPast = self.model(
  520. input_ids=input_ids,
  521. attention_mask=attention_mask,
  522. position_ids=position_ids,
  523. past_key_values=past_key_values,
  524. inputs_embeds=inputs_embeds,
  525. use_cache=use_cache,
  526. output_router_logits=output_router_logits,
  527. cache_position=cache_position,
  528. **kwargs,
  529. )
  530. hidden_states = outputs.last_hidden_state
  531. # Only compute necessary logits, and do not upcast them to float if we are not computing the loss
  532. slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
  533. logits = self.lm_head(hidden_states[:, slice_indices, :])
  534. loss = None
  535. if labels is not None:
  536. loss = self.loss_function(logits, labels, self.vocab_size, **kwargs)
  537. aux_loss = None
  538. if output_router_logits:
  539. aux_loss = load_balancing_loss_func(
  540. outputs.router_logits,
  541. self.num_experts,
  542. self.num_experts_per_tok,
  543. attention_mask,
  544. )
  545. if labels is not None:
  546. loss += self.router_aux_loss_coef * aux_loss.to(loss.device) # make sure to reside in the same device
  547. return MoeCausalLMOutputWithPast(
  548. loss=loss,
  549. aux_loss=aux_loss,
  550. logits=logits,
  551. past_key_values=outputs.past_key_values,
  552. hidden_states=outputs.hidden_states,
  553. attentions=outputs.attentions,
  554. router_logits=outputs.router_logits,
  555. )
class MixtralForSequenceClassification(GenericForSequenceClassification, MixtralPreTrainedModel):
    # No Mixtral-specific code: all behavior comes from the two base classes, with
    # GenericForSequenceClassification first in the MRO.
    pass
class MixtralForTokenClassification(GenericForTokenClassification, MixtralPreTrainedModel):
    # No Mixtral-specific code: all behavior comes from the two base classes, with
    # GenericForTokenClassification first in the MRO.
    pass
class MixtralForQuestionAnswering(GenericForQuestionAnswering, MixtralPreTrainedModel):
    # No Mixtral-specific code: all behavior comes from the two base classes, with
    # GenericForQuestionAnswering first in the MRO.
    pass
  562. __all__ = [
  563. "MixtralForCausalLM",
  564. "MixtralForQuestionAnswering",
  565. "MixtralModel",
  566. "MixtralPreTrainedModel",
  567. "MixtralForSequenceClassification",
  568. "MixtralForTokenClassification",
  569. ]