# modular_mixtral.py (19 KB)
  1. # coding=utf-8
  2. # Copyright 2023 Mistral AI and the HuggingFace Inc. team. All rights reserved.
  3. #
  4. # This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
  5. # and OPT implementations in this library. It has been modified from its
  6. # original forms to accommodate minor architectural differences compared
  7. # to GPT-NeoX and OPT used by the Meta AI team that trained the model.
  8. #
  9. # Licensed under the Apache License, Version 2.0 (the "License");
  10. # you may not use this file except in compliance with the License.
  11. # You may obtain a copy of the License at
  12. #
  13. # http://www.apache.org/licenses/LICENSE-2.0
  14. #
  15. # Unless required by applicable law or agreed to in writing, software
  16. # distributed under the License is distributed on an "AS IS" BASIS,
  17. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  18. # See the License for the specific language governing permissions and
  19. # limitations under the License.
  20. """PyTorch Mixtral model."""
  21. from typing import Optional, Union
  22. import torch
  23. import torch.nn.functional as F
  24. from torch import nn
  25. from ...activations import ACT2FN
  26. from ...cache_utils import Cache, DynamicCache
  27. from ...masking_utils import create_causal_mask, create_sliding_window_causal_mask
  28. from ...modeling_layers import GradientCheckpointingLayer
  29. from ...modeling_outputs import MoeCausalLMOutputWithPast, MoeModelOutputWithPast
  30. from ...processing_utils import Unpack
  31. from ...utils import TransformersKwargs, logging
  32. from ...utils.deprecation import deprecate_kwarg
  33. from ...utils.generic import OutputRecorder
  34. from ..mistral.modeling_mistral import (
  35. MistralAttention,
  36. MistralForCausalLM,
  37. MistralForQuestionAnswering,
  38. MistralForSequenceClassification,
  39. MistralForTokenClassification,
  40. MistralModel,
  41. MistralPreTrainedModel,
  42. MistralRMSNorm,
  43. MistralRotaryEmbedding,
  44. )
  45. from .configuration_mixtral import MixtralConfig
# Module-level logger, named after this module.
logger = logging.get_logger(__name__)
  47. def load_balancing_loss_func(
  48. gate_logits: Union[torch.Tensor, tuple[torch.Tensor], None],
  49. num_experts: Optional[int] = None,
  50. top_k=2,
  51. attention_mask: Optional[torch.Tensor] = None,
  52. ) -> Union[torch.Tensor, int]:
  53. r"""
  54. Computes auxiliary load balancing loss as in Switch Transformer - implemented in Pytorch.
  55. See Switch Transformer (https://huggingface.co/papers/2101.03961) for more details. This function implements the loss
  56. function presented in equations (4) - (6) of the paper. It aims at penalizing cases where the routing between
  57. experts is too unbalanced.
  58. Args:
  59. gate_logits:
  60. Logits from the `gate`, should be a tuple of model.config.num_hidden_layers tensors of
  61. shape [batch_size X sequence_length, num_experts].
  62. num_experts:
  63. Number of experts
  64. top_k:
  65. The number of experts to route per-token, can be also interpreted as the `top-k` routing
  66. parameter.
  67. attention_mask (`torch.Tensor`, *optional*):
  68. The attention_mask used in forward function
  69. shape [batch_size X sequence_length] if not None.
  70. Returns:
  71. The auxiliary loss.
  72. """
  73. if gate_logits is None or not isinstance(gate_logits, tuple):
  74. return 0
  75. if isinstance(gate_logits, tuple):
  76. compute_device = gate_logits[0].device
  77. concatenated_gate_logits = torch.cat([layer_gate.to(compute_device) for layer_gate in gate_logits], dim=0)
  78. routing_weights = torch.nn.functional.softmax(concatenated_gate_logits, dim=-1)
  79. _, selected_experts = torch.topk(routing_weights, top_k, dim=-1)
  80. expert_mask = torch.nn.functional.one_hot(selected_experts, num_experts)
  81. if attention_mask is None:
  82. # Compute the percentage of tokens routed to each experts
  83. tokens_per_expert = torch.mean(expert_mask.float(), dim=0)
  84. # Compute the average probability of routing to these experts
  85. router_prob_per_expert = torch.mean(routing_weights, dim=0)
  86. else:
  87. batch_size, sequence_length = attention_mask.shape
  88. num_hidden_layers = concatenated_gate_logits.shape[0] // (batch_size * sequence_length)
  89. # Compute the mask that masks all padding tokens as 0 with the same shape of expert_mask
  90. expert_attention_mask = (
  91. attention_mask[None, :, :, None, None]
  92. .expand((num_hidden_layers, batch_size, sequence_length, top_k, num_experts))
  93. .reshape(-1, top_k, num_experts)
  94. .to(compute_device)
  95. )
  96. # Compute the percentage of tokens routed to each experts
  97. tokens_per_expert = torch.sum(expert_mask.float() * expert_attention_mask, dim=0) / torch.sum(
  98. expert_attention_mask, dim=0
  99. )
  100. # Compute the mask that masks all padding tokens as 0 with the same shape of tokens_per_expert
  101. router_per_expert_attention_mask = (
  102. attention_mask[None, :, :, None]
  103. .expand((num_hidden_layers, batch_size, sequence_length, num_experts))
  104. .reshape(-1, num_experts)
  105. .to(compute_device)
  106. )
  107. # Compute the average probability of routing to these experts
  108. router_prob_per_expert = torch.sum(routing_weights * router_per_expert_attention_mask, dim=0) / torch.sum(
  109. router_per_expert_attention_mask, dim=0
  110. )
  111. overall_loss = torch.sum(tokens_per_expert * router_prob_per_expert.unsqueeze(0))
  112. return overall_loss * num_experts
  113. class MixtralBlockSparseTop2MLP(nn.Module):
  114. def __init__(self, config: MixtralConfig):
  115. super().__init__()
  116. self.ffn_dim = config.intermediate_size
  117. self.hidden_dim = config.hidden_size
  118. self.w1 = nn.Linear(self.hidden_dim, self.ffn_dim, bias=False)
  119. self.w2 = nn.Linear(self.ffn_dim, self.hidden_dim, bias=False)
  120. self.w3 = nn.Linear(self.hidden_dim, self.ffn_dim, bias=False)
  121. self.act_fn = ACT2FN[config.hidden_act]
  122. def forward(self, hidden_states):
  123. current_hidden_states = self.act_fn(self.w1(hidden_states)) * self.w3(hidden_states)
  124. current_hidden_states = self.w2(current_hidden_states)
  125. return current_hidden_states
  126. class MixtralSparseMoeBlock(nn.Module):
  127. """
  128. This implementation is
  129. strictly equivalent to standard MoE with full capacity (no
  130. dropped tokens). It's faster since it formulates MoE operations
  131. in terms of block-sparse operations to accommodate imbalanced
  132. assignments of tokens to experts, whereas standard MoE either
  133. (1) drop tokens at the cost of reduced performance or (2) set
  134. capacity factor to number of experts and thus waste computation
  135. and memory on padding.
  136. """
  137. def __init__(self, config):
  138. super().__init__()
  139. self.hidden_dim = config.hidden_size
  140. self.ffn_dim = config.intermediate_size
  141. self.num_experts = config.num_local_experts
  142. self.top_k = config.num_experts_per_tok
  143. # gating
  144. self.gate = nn.Linear(self.hidden_dim, self.num_experts, bias=False)
  145. self.experts = nn.ModuleList([MixtralBlockSparseTop2MLP(config) for _ in range(self.num_experts)])
  146. # Jitter parameters
  147. self.jitter_noise = config.router_jitter_noise
  148. def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
  149. """ """
  150. batch_size, sequence_length, hidden_dim = hidden_states.shape
  151. if self.training and self.jitter_noise > 0:
  152. hidden_states *= torch.empty_like(hidden_states).uniform_(1.0 - self.jitter_noise, 1.0 + self.jitter_noise)
  153. hidden_states = hidden_states.view(-1, hidden_dim)
  154. # router_logits: (batch * sequence_length, n_experts)
  155. router_logits = self.gate(hidden_states)
  156. routing_weights = F.softmax(router_logits, dim=1, dtype=torch.float)
  157. routing_weights, selected_experts = torch.topk(routing_weights, self.top_k, dim=-1)
  158. routing_weights /= routing_weights.sum(dim=-1, keepdim=True)
  159. # we cast back to the input dtype
  160. routing_weights = routing_weights.to(hidden_states.dtype)
  161. final_hidden_states = torch.zeros(
  162. (batch_size * sequence_length, hidden_dim), dtype=hidden_states.dtype, device=hidden_states.device
  163. )
  164. # One hot encode the selected experts to create an expert mask
  165. # this will be used to easily index which expert is going to be sollicitated
  166. expert_mask = torch.nn.functional.one_hot(selected_experts, num_classes=self.num_experts).permute(2, 1, 0)
  167. expert_hit = torch.greater(expert_mask.sum(dim=(-1, -2)), 0).nonzero()
  168. for expert_idx in expert_hit:
  169. expert_layer = self.experts[expert_idx]
  170. idx, top_x = torch.where(expert_mask[expert_idx].squeeze(0))
  171. # Index the correct hidden states and compute the expert hidden state for
  172. # the current expert. We need to make sure to multiply the output hidden
  173. # states by `routing_weights` on the corresponding tokens (top-1 and top-2)
  174. current_state = hidden_states[None, top_x].reshape(-1, hidden_dim)
  175. current_hidden_states = expert_layer(current_state) * routing_weights[top_x, idx, None]
  176. # However `index_add_` only support torch tensors for indexing so we'll use
  177. # the `top_x` tensor here.
  178. final_hidden_states.index_add_(0, top_x, current_hidden_states.to(hidden_states.dtype))
  179. final_hidden_states = final_hidden_states.reshape(batch_size, sequence_length, hidden_dim)
  180. return final_hidden_states, router_logits
# Mixtral-named subclass of MistralRMSNorm; it adds and overrides nothing.
class MixtralRMSNorm(MistralRMSNorm):
    pass
# Mixtral-named subclass of MistralAttention; it adds and overrides nothing.
class MixtralAttention(MistralAttention):
    pass
  185. class MixtralDecoderLayer(GradientCheckpointingLayer):
  186. def __init__(self, config: MixtralConfig, layer_idx: int):
  187. super().__init__()
  188. self.hidden_size = config.hidden_size
  189. self.self_attn = MixtralAttention(config, layer_idx)
  190. self.block_sparse_moe = MixtralSparseMoeBlock(config)
  191. self.input_layernorm = MixtralRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
  192. self.post_attention_layernorm = MixtralRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
  193. @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58")
  194. def forward(
  195. self,
  196. hidden_states: torch.Tensor,
  197. position_embeddings: tuple[torch.Tensor, torch.Tensor],
  198. attention_mask: Optional[torch.Tensor] = None,
  199. position_ids: Optional[torch.LongTensor] = None,
  200. past_key_values: Optional[Cache] = None,
  201. cache_position: Optional[torch.LongTensor] = None,
  202. **kwargs: Unpack[TransformersKwargs],
  203. ) -> torch.FloatTensor:
  204. residual = hidden_states
  205. hidden_states = self.input_layernorm(hidden_states)
  206. # Self Attention
  207. hidden_states, _ = self.self_attn(
  208. hidden_states=hidden_states,
  209. position_embeddings=position_embeddings,
  210. attention_mask=attention_mask,
  211. position_ids=position_ids,
  212. past_key_values=past_key_values,
  213. cache_position=cache_position,
  214. **kwargs,
  215. )
  216. hidden_states = residual + hidden_states
  217. # Fully Connected
  218. residual = hidden_states
  219. hidden_states = self.post_attention_layernorm(hidden_states)
  220. hidden_states, _ = self.block_sparse_moe(hidden_states)
  221. hidden_states = residual + hidden_states
  222. return hidden_states
# Mixtral-named subclass of MistralRotaryEmbedding; it adds and overrides nothing.
class MixtralRotaryEmbedding(MistralRotaryEmbedding):
    pass
# Base class for the Mixtral models: inherits MistralPreTrainedModel and sets two class-level attributes.
class MixtralPreTrainedModel(MistralPreTrainedModel):
    _can_compile_fullgraph = False  # MoE models don't work with torch.compile (`torch.where(condition)` not supported)
    # Maps an output name to the module class associated with it. How these entries are consumed is
    # defined outside this file.
    _can_record_outputs = {
        # MixtralSparseMoeBlock.forward returns (final_hidden_states, router_logits); `index=1` presumably
        # selects the router logits - confirm against OutputRecorder.
        "router_logits": OutputRecorder(MixtralSparseMoeBlock, index=1),
        "hidden_states": MixtralDecoderLayer,
        "attentions": MixtralAttention,
    }
  232. class MixtralModel(MistralModel):
  233. def forward(
  234. self,
  235. input_ids: Optional[torch.LongTensor] = None,
  236. attention_mask: Optional[torch.Tensor] = None,
  237. position_ids: Optional[torch.LongTensor] = None,
  238. past_key_values: Optional[Cache] = None,
  239. inputs_embeds: Optional[torch.FloatTensor] = None,
  240. use_cache: Optional[bool] = None,
  241. cache_position: Optional[torch.LongTensor] = None,
  242. **kwargs: Unpack[TransformersKwargs],
  243. ) -> MoeModelOutputWithPast:
  244. if (input_ids is None) ^ (inputs_embeds is not None):
  245. raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
  246. if use_cache and past_key_values is None:
  247. past_key_values = DynamicCache(config=self.config)
  248. if inputs_embeds is None:
  249. inputs_embeds = self.embed_tokens(input_ids)
  250. if cache_position is None:
  251. past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
  252. cache_position = torch.arange(
  253. past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device
  254. )
  255. if position_ids is None:
  256. position_ids = cache_position.unsqueeze(0)
  257. mask_function = create_causal_mask if self.config.sliding_window is None else create_sliding_window_causal_mask
  258. causal_mask = mask_function(
  259. config=self.config,
  260. input_embeds=inputs_embeds,
  261. attention_mask=attention_mask,
  262. cache_position=cache_position,
  263. past_key_values=past_key_values,
  264. position_ids=position_ids,
  265. )
  266. hidden_states = inputs_embeds
  267. # create position embeddings to be shared across the decoder layers
  268. position_embeddings = self.rotary_emb(hidden_states, position_ids)
  269. for decoder_layer in self.layers[: self.config.num_hidden_layers]:
  270. hidden_states = decoder_layer(
  271. hidden_states,
  272. position_embeddings=position_embeddings,
  273. attention_mask=causal_mask,
  274. position_ids=position_ids,
  275. past_key_values=past_key_values,
  276. use_cache=use_cache,
  277. cache_position=cache_position,
  278. **kwargs,
  279. )
  280. hidden_states = self.norm(hidden_states)
  281. return MoeModelOutputWithPast( # only diff with Mistral is the output type, we need MoE
  282. last_hidden_state=hidden_states,
  283. past_key_values=past_key_values,
  284. )
  285. class MixtralForCausalLM(MistralForCausalLM):
  286. _tied_weights_keys = ["lm_head.weight"]
  287. def __init__(self, config):
  288. super().__init__(config)
  289. self.model = MixtralModel(config)
  290. self.router_aux_loss_coef = config.router_aux_loss_coef
  291. self.num_experts = config.num_local_experts
  292. self.num_experts_per_tok = config.num_experts_per_tok
  293. def forward(
  294. self,
  295. input_ids: Optional[torch.LongTensor] = None,
  296. attention_mask: Optional[torch.Tensor] = None,
  297. position_ids: Optional[torch.LongTensor] = None,
  298. past_key_values: Optional[Cache] = None,
  299. inputs_embeds: Optional[torch.FloatTensor] = None,
  300. labels: Optional[torch.LongTensor] = None,
  301. use_cache: Optional[bool] = None,
  302. output_router_logits: Optional[bool] = None,
  303. cache_position: Optional[torch.LongTensor] = None,
  304. logits_to_keep: Union[int, torch.Tensor] = 0,
  305. **kwargs: Unpack[TransformersKwargs],
  306. ) -> MoeCausalLMOutputWithPast:
  307. r"""
  308. labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
  309. Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
  310. config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
  311. (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
  312. Example:
  313. ```python
  314. >>> from transformers import AutoTokenizer, MixtralForCausalLM
  315. >>> model = MixtralForCausalLM.from_pretrained("mistralai/Mixtral-8x7B-v0.1")
  316. >>> tokenizer = AutoTokenizer.from_pretrained("mistralai/Mixtral-8x7B-v0.1")
  317. >>> prompt = "Hey, are you conscious? Can you talk to me?"
  318. >>> inputs = tokenizer(prompt, return_tensors="pt")
  319. >>> # Generate
  320. >>> generate_ids = model.generate(inputs.input_ids, max_length=30)
  321. >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
  322. "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you."
  323. ```"""
  324. output_router_logits = (
  325. output_router_logits if output_router_logits is not None else self.config.output_router_logits
  326. )
  327. # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
  328. outputs: MoeModelOutputWithPast = self.model(
  329. input_ids=input_ids,
  330. attention_mask=attention_mask,
  331. position_ids=position_ids,
  332. past_key_values=past_key_values,
  333. inputs_embeds=inputs_embeds,
  334. use_cache=use_cache,
  335. output_router_logits=output_router_logits,
  336. cache_position=cache_position,
  337. **kwargs,
  338. )
  339. hidden_states = outputs.last_hidden_state
  340. # Only compute necessary logits, and do not upcast them to float if we are not computing the loss
  341. slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
  342. logits = self.lm_head(hidden_states[:, slice_indices, :])
  343. loss = None
  344. if labels is not None:
  345. loss = self.loss_function(logits, labels, self.vocab_size, **kwargs)
  346. aux_loss = None
  347. if output_router_logits:
  348. aux_loss = load_balancing_loss_func(
  349. outputs.router_logits,
  350. self.num_experts,
  351. self.num_experts_per_tok,
  352. attention_mask,
  353. )
  354. if labels is not None:
  355. loss += self.router_aux_loss_coef * aux_loss.to(loss.device) # make sure to reside in the same device
  356. return MoeCausalLMOutputWithPast(
  357. loss=loss,
  358. aux_loss=aux_loss,
  359. logits=logits,
  360. past_key_values=outputs.past_key_values,
  361. hidden_states=outputs.hidden_states,
  362. attentions=outputs.attentions,
  363. router_logits=outputs.router_logits,
  364. )
# Mixtral-named subclass of MistralForSequenceClassification; it adds and overrides nothing.
class MixtralForSequenceClassification(MistralForSequenceClassification):
    pass
# Mixtral-named subclass of MistralForTokenClassification; it adds and overrides nothing.
class MixtralForTokenClassification(MistralForTokenClassification):
    pass
# Mixtral-named subclass of MistralForQuestionAnswering; it adds and overrides nothing.
class MixtralForQuestionAnswering(MistralForQuestionAnswering):
    pass
# Names exported by `from <module> import *`; the remaining classes defined in this file are not listed.
__all__ = [
    "MixtralForCausalLM",
    "MixtralForQuestionAnswering",
    "MixtralModel",
    "MixtralPreTrainedModel",
    "MixtralForSequenceClassification",
    "MixtralForTokenClassification",
]