modeling_minimax.py 41 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862863864865866867868869870871872873874875876877878879880881882883884885886887888889890891892893894895896897898899900901902903904905906907908909910911912913914915916917918919920921922923924925926927928929930
  1. # 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
  2. # This file was automatically generated from src/transformers/models/minimax/modular_minimax.py.
  3. # Do NOT edit this file manually as any edits will be overwritten by the generation of
  4. # the file from the modular. If any change should be done, please apply the change to the
  5. # modular_minimax.py file directly. One of our CI enforces this.
  6. # 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
  7. # coding=utf-8
  8. # Copyright 2025 MiniMaxAI and HuggingFace Inc. teams. All rights reserved.
  9. #
  10. #
  11. # Licensed under the Apache License, Version 2.0 (the "License");
  12. # you may not use this file except in compliance with the License.
  13. # You may obtain a copy of the License at
  14. #
  15. # http://www.apache.org/licenses/LICENSE-2.0
  16. #
  17. # Unless required by applicable law or agreed to in writing, software
  18. # distributed under the License is distributed on an "AS IS" BASIS,
  19. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  20. # See the License for the specific language governing permissions and
  21. # limitations under the License.
  22. from typing import Callable, Optional, Union
  23. import torch
  24. import torch.nn.functional as F
  25. from torch import nn
  26. from ...activations import ACT2FN
  27. from ...cache_utils import Cache, DynamicCache
  28. from ...generation import GenerationMixin
  29. from ...integrations import use_kernel_forward_from_hub
  30. from ...masking_utils import create_causal_mask, create_sliding_window_causal_mask
  31. from ...modeling_flash_attention_utils import FlashAttentionKwargs
  32. from ...modeling_layers import (
  33. GenericForQuestionAnswering,
  34. GenericForSequenceClassification,
  35. GenericForTokenClassification,
  36. GradientCheckpointingLayer,
  37. )
  38. from ...modeling_outputs import MoeCausalLMOutputWithPast, MoeModelOutputWithPast
  39. from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update
  40. from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
  41. from ...processing_utils import Unpack
  42. from ...utils import TransformersKwargs, auto_docstring, can_return_tuple
  43. from ...utils.deprecation import deprecate_kwarg
  44. from ...utils.generic import OutputRecorder, check_model_inputs
  45. from .configuration_minimax import MiniMaxConfig
  46. @use_kernel_forward_from_hub("RMSNorm")
  47. class MiniMaxRMSNorm(nn.Module):
  48. def __init__(self, hidden_size, eps=1e-6):
  49. """
  50. MiniMaxRMSNorm is equivalent to T5LayerNorm
  51. """
  52. super().__init__()
  53. self.weight = nn.Parameter(torch.ones(hidden_size))
  54. self.variance_epsilon = eps
  55. def forward(self, hidden_states):
  56. input_dtype = hidden_states.dtype
  57. hidden_states = hidden_states.to(torch.float32)
  58. variance = hidden_states.pow(2).mean(-1, keepdim=True)
  59. hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
  60. return self.weight * hidden_states.to(input_dtype)
  61. def extra_repr(self):
  62. return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}"
  63. class MiniMaxCache(DynamicCache):
  64. def __init__(self):
  65. super().__init__()
  66. self.linear_cache: list[torch.Tensor] = []
  67. def set_linear_cache(self, layer_idx, linear_cache):
  68. # There may be skipped layers, fill them with empty lists
  69. for _ in range(len(self.linear_cache), layer_idx + 1):
  70. self.linear_cache.append([])
  71. self.linear_cache[layer_idx] = linear_cache
  72. def get_linear_cache(self, layer_idx: int):
  73. if layer_idx < len(self):
  74. return self.linear_cache[layer_idx]
  75. return None
  76. def __len__(self):
  77. return max(super().__len__(), len(self.linear_cache))
  78. def __getitem__(self, layer_idx: int):
  79. if layer_idx < len(self.linear_cache) and self.linear_cache[layer_idx] != []:
  80. return (self.linear_cache[layer_idx],)
  81. return super().__getitem__(layer_idx)
  82. def __iter__(self):
  83. for layer_idx in range(len(self)):
  84. yield self[layer_idx]
  85. def batch_repeat_interleave(self, repeats: int):
  86. for layer_idx in range(len(self)):
  87. if self.linear_cache[layer_idx] != []:
  88. self.linear_cache[layer_idx] = self.linear_cache[layer_idx].repeat_interleave(repeats, dim=0)
  89. else:
  90. self.layers[layer_idx].batch_repeat_interleave(repeats)
  91. def batch_select_indices(self, indices: torch.Tensor):
  92. for layer_idx in range(len(self)):
  93. if self.linear_cache[layer_idx] != []:
  94. self.linear_cache[layer_idx] = self.linear_cache[layer_idx][indices, ...]
  95. else:
  96. self.layers[layer_idx].batch_select_indices(indices)
  97. def crop(self, max_length: int):
  98. raise RuntimeError("MiniMaxCache doesnot support `crop` method")
class MiniMaxLightningAttention(nn.Module):
    """Linear ("lightning") attention with a per-head exponential decay.

    Each head keeps a running `head_dim x head_dim` state built from `K.T @ V` that is
    decayed by `exp(-slope_rate)` per position. Without a cached state the sequence is
    processed in blocks of `config.block_size`. With a cached state it is processed token by
    token. The final state is returned and, when a cache is given, stored via
    `set_linear_cache`.
    """

    def __init__(self, config: MiniMaxConfig, layer_idx: int):
        super().__init__()
        self.layer_idx = layer_idx
        self.head_dim = getattr(config, "head_dim", None) or config.hidden_size // config.num_attention_heads
        self.num_attention_heads = config.num_attention_heads
        self.num_hidden_layers = config.num_hidden_layers
        self.block_size = config.block_size
        # Applied to the fused q/k/v projection output in `forward`.
        self.act_fn = ACT2FN[config.hidden_act]
        # Normalizes the concatenated per-head outputs before the output gate.
        self.norm = MiniMaxRMSNorm(self.head_dim * self.num_attention_heads)
        # Single fused projection; split into query/key/value per head in `forward`.
        self.qkv_proj = nn.Linear(config.hidden_size, self.num_attention_heads * self.head_dim * 3, bias=False)
        self.out_proj = nn.Linear(self.num_attention_heads * self.head_dim, config.hidden_size, bias=False)
        # Its sigmoid gates the normalized attention output in `forward`.
        self.output_gate = nn.Linear(config.hidden_size, self.num_attention_heads * self.head_dim, bias=False)
        # Decay rates/tables depend only on config and layer index: precompute once as buffers.
        slope_rate = self.get_slope_rate()
        query_decay, key_decay, diagonal_decay = self.decay_factors(slope_rate)
        self.register_buffer("slope_rate", slope_rate)
        self.register_buffer("query_decay", query_decay)
        self.register_buffer("key_decay", key_decay)
        self.register_buffer("diagonal_decay", diagonal_decay)

    def get_slope_rate(self):
        """Return per-head decay rates of shape `(num_attention_heads, 1, 1)`.

        Head `h` (1-based) gets `2 ** (-8 * h / num_attention_heads)`. This is multiplied by a
        factor that decreases linearly with `layer_idx`, from about 1 at layer 0 to close to 0
        at the last layer, so deeper layers decay more slowly.
        """
        base = 1 / (2 ** (8 / self.num_attention_heads))
        exponent = torch.arange(self.num_attention_heads) + 1
        factor = 1 - self.layer_idx / (self.num_hidden_layers - 1 + 1e-5) + 1e-5
        rate = base**exponent
        rate = rate * factor
        rate = rate[:, None, None]
        return rate

    def decay_factors(self, slope_rate):
        """Precompute the decay tables used by the blockwise path of `forward`.

        With `i`, `j` the 1-based positions inside a block of size `block_size`, and `s_h`
        the slope of head `h`:
          - query_decay[h, i-1, 0] = exp(-s_h * i)                  -> (H, block_size, 1)
          - key_decay[h, i-1, 0]   = exp(-s_h * (block_size - i))   -> (H, block_size, 1)
          - diagonal_decay[0, h, i-1, j-1] = exp(-s_h * (i - j)) for i >= j, and 0 for i < j
                                                                    -> (1, H, block_size, block_size)
        """
        block_size_range = torch.arange(self.block_size) + 1
        query_decay = torch.exp(-slope_rate * block_size_range[:, None])
        key_decay = torch.exp(-slope_rate * (self.block_size - block_size_range[:, None]))
        diagonal_decay = block_size_range[:, None] - block_size_range[None, :]
        diagonal_decay = diagonal_decay[None, None, :, :]
        diagonal_decay = slope_rate * diagonal_decay
        # Future positions (i < j, negative difference) become exp(-inf) = 0: a causal mask.
        diagonal_decay = torch.where(diagonal_decay >= 0, -diagonal_decay, float("-inf"))
        diagonal_decay = torch.exp(diagonal_decay)
        return query_decay, key_decay, diagonal_decay

    @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58")
    def forward(
        self,
        hidden_states: torch.Tensor,
        position_embeddings: tuple[torch.Tensor, torch.Tensor],
        attention_mask: Optional[torch.Tensor],
        past_key_values: Optional[Cache] = None,
        cache_position: Optional[torch.LongTensor] = None,
        **kwargs: Unpack[FlashAttentionKwargs],
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """Run lightning attention over `hidden_states` of shape `(batch, seq_len, hidden)`.

        `position_embeddings` and `cache_position` are accepted for interface parity with
        `MiniMaxAttention` but are not used here.

        Returns:
            `(attn_output, attn_weights_inter)`: the projected output `(batch, seq_len, hidden)`
            and the final `K.T @ V` state `(batch, heads, head_dim, head_dim)`. The second item
            is the recurrent state, not softmax attention probabilities.
        """
        batch_size, seq_len, hidden_size = hidden_states.shape
        num_blocks = (seq_len + self.block_size - 1) // self.block_size

        # Fused projection -> (batch, seq, heads, 3*head_dim) -> q, k, v each (batch, heads, seq, head_dim).
        qkv_states = self.act_fn(self.qkv_proj(hidden_states))
        qkv_states = qkv_states.reshape(batch_size, seq_len, self.num_attention_heads, 3 * self.head_dim)
        query_states, key_states, value_states = torch.split(qkv_states, self.head_dim, dim=3)
        query_states = query_states.transpose(1, 2)
        key_states = key_states.transpose(1, 2)
        value_states = value_states.transpose(1, 2)

        # calculated (K.T @ V) and saved as cache
        attn_weights_inter = None
        if past_key_values is not None:
            attn_weights_inter = past_key_values.get_linear_cache(self.layer_idx)

        if attn_weights_inter is None:
            # No previous state: start from zeros and process the sequence block by block.
            attn_weights_inter = torch.zeros(batch_size, self.num_attention_heads, self.head_dim, self.head_dim).to(
                value_states
            )

            # apply attention_mask
            # Only value_states are zeroed at masked positions, so they contribute nothing to
            # the state. Presumably `attention_mask` is a (batch, seq_len) padding mask here;
            # the unsqueezes below require a shape broadcastable to (batch, 1, seq_len, 1).
            # TODO(review): confirm against the caller.
            if attention_mask is not None:
                attention_mask = attention_mask.to(dtype=torch.bool)  # Ensure it's a boolean tensor
                value_states = value_states.masked_fill(~attention_mask.unsqueeze(1).unsqueeze(-1), 0)

            attn_output = []
            for i in range(num_blocks):
                start_idx = i * self.block_size
                end_idx = min(start_idx + self.block_size, seq_len)
                # The last block may be shorter than block_size.
                current_block_size = end_idx - start_idx

                current_query_states = query_states[:, :, start_idx:end_idx]
                current_key_states = key_states[:, :, start_idx:end_idx]
                current_value_states = value_states[:, :, start_idx:end_idx]

                # Slice decay tables to the block length. Key decay is taken from the end, so
                # the last key in the block always gets exp(0) = 1.
                current_query_decay = self.query_decay[:, :current_block_size]
                current_key_decay = self.key_decay[:, -current_block_size:]
                current_diagonal_decay = self.diagonal_decay[:, :, :current_block_size, :current_block_size]
                # Total decay of the carried state across this whole block.
                block_decay = torch.exp(-self.slope_rate * current_block_size)

                # intra: ( Q @ K.T ) @ V -> QK * V
                attn_weights_intra = torch.matmul(current_query_states, current_key_states.transpose(-1, -2))
                attn_output_intra = torch.matmul(attn_weights_intra * current_diagonal_decay, current_value_states)

                # inter: Q @ ( K.T @ V ) -> Q * KV
                attn_output_inter = torch.matmul(current_query_states * current_query_decay, attn_weights_inter)

                # final attention output
                current_attn_output = attn_output_inter + attn_output_intra
                attn_output.append(current_attn_output)

                # calculate attn_weights_inter for next block or cache
                next_attn_weights_inter = torch.matmul(
                    (current_key_states * current_key_decay).transpose(-1, -2), current_value_states
                )
                attn_weights_inter = attn_weights_inter * block_decay + next_attn_weights_inter

        else:
            # A cached state exists: apply the recurrence one token at a time.
            # Note: attention_mask is not applied in this path.
            ratio = torch.exp(-self.slope_rate)
            attn_output = []
            for i in range(seq_len):
                current_query_states = query_states[:, :, i : i + 1]
                current_key_states = key_states[:, :, i : i + 1]
                current_value_states = value_states[:, :, i : i + 1]

                # state_t = exp(-slope) * state_{t-1} + k_t.T @ v_t ; output_t = q_t @ state_t
                current_attn_weights_inter = torch.matmul(current_key_states.transpose(-1, -2), current_value_states)
                attn_weights_inter = ratio * attn_weights_inter + current_attn_weights_inter
                current_attn_output = torch.matmul(current_query_states, attn_weights_inter)

                attn_output.append(current_attn_output)

        # concatenate attention outputs over all blocks
        attn_output = torch.cat(attn_output, dim=-2)

        # final output projection
        attn_output = attn_output.transpose(1, 2)
        attn_output = attn_output.reshape(batch_size, seq_len, self.num_attention_heads * self.head_dim)
        attn_output = self.norm(attn_output)
        # Gate is computed from the layer input, not from the attention output.
        attn_output = F.sigmoid(self.output_gate(hidden_states)) * attn_output
        attn_output = self.out_proj(attn_output)

        # update cache
        if past_key_values is not None:
            past_key_values.set_linear_cache(self.layer_idx, attn_weights_inter)

        return attn_output, attn_weights_inter
  214. def rotate_half(x):
  215. """Rotates half the hidden dims of the input."""
  216. x1 = x[..., : x.shape[-1] // 2]
  217. x2 = x[..., x.shape[-1] // 2 :]
  218. return torch.cat((-x2, x1), dim=-1)
  219. def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
  220. """Applies Rotary Position Embedding to the query and key tensors.
  221. Args:
  222. q (`torch.Tensor`): The query tensor.
  223. k (`torch.Tensor`): The key tensor.
  224. cos (`torch.Tensor`): The cosine part of the rotary embedding.
  225. sin (`torch.Tensor`): The sine part of the rotary embedding.
  226. position_ids (`torch.Tensor`, *optional*):
  227. Deprecated and unused.
  228. unsqueeze_dim (`int`, *optional*, defaults to 1):
  229. The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and
  230. sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note
  231. that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and
  232. k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes
  233. cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have
  234. the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2.
  235. Returns:
  236. `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding.
  237. """
  238. cos = cos.unsqueeze(unsqueeze_dim)
  239. sin = sin.unsqueeze(unsqueeze_dim)
  240. q_embed = (q * cos) + (rotate_half(q) * sin)
  241. k_embed = (k * cos) + (rotate_half(k) * sin)
  242. return q_embed, k_embed
  243. def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
  244. """
  245. This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch,
  246. num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim)
  247. """
  248. batch, num_key_value_heads, slen, head_dim = hidden_states.shape
  249. if n_rep == 1:
  250. return hidden_states
  251. hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim)
  252. return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
  253. def eager_attention_forward(
  254. module: nn.Module,
  255. query: torch.Tensor,
  256. key: torch.Tensor,
  257. value: torch.Tensor,
  258. attention_mask: Optional[torch.Tensor],
  259. scaling: float,
  260. dropout: float = 0.0,
  261. **kwargs: Unpack[TransformersKwargs],
  262. ):
  263. key_states = repeat_kv(key, module.num_key_value_groups)
  264. value_states = repeat_kv(value, module.num_key_value_groups)
  265. attn_weights = torch.matmul(query, key_states.transpose(2, 3)) * scaling
  266. if attention_mask is not None:
  267. causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]
  268. attn_weights = attn_weights + causal_mask
  269. attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype)
  270. attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training)
  271. attn_output = torch.matmul(attn_weights, value_states)
  272. attn_output = attn_output.transpose(1, 2).contiguous()
  273. return attn_output, attn_weights
  274. class MiniMaxAttention(nn.Module):
  275. """Multi-headed attention from 'Attention Is All You Need' paper"""
  276. def __init__(self, config: MiniMaxConfig, layer_idx: int):
  277. super().__init__()
  278. self.config = config
  279. self.layer_idx = layer_idx
  280. self.head_dim = getattr(config, "head_dim", None) or config.hidden_size // config.num_attention_heads
  281. self.num_key_value_groups = config.num_attention_heads // config.num_key_value_heads
  282. self.scaling = self.head_dim**-0.5
  283. self.attention_dropout = config.attention_dropout
  284. self.is_causal = True
  285. self.q_proj = nn.Linear(config.hidden_size, config.num_attention_heads * self.head_dim, bias=False)
  286. self.k_proj = nn.Linear(config.hidden_size, config.num_key_value_heads * self.head_dim, bias=False)
  287. self.v_proj = nn.Linear(config.hidden_size, config.num_key_value_heads * self.head_dim, bias=False)
  288. self.o_proj = nn.Linear(config.num_attention_heads * self.head_dim, config.hidden_size, bias=False)
  289. @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58")
  290. def forward(
  291. self,
  292. hidden_states: torch.Tensor,
  293. position_embeddings: tuple[torch.Tensor, torch.Tensor],
  294. attention_mask: Optional[torch.Tensor],
  295. past_key_values: Optional[Cache] = None,
  296. cache_position: Optional[torch.LongTensor] = None,
  297. **kwargs: Unpack[FlashAttentionKwargs],
  298. ) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
  299. input_shape = hidden_states.shape[:-1]
  300. hidden_shape = (*input_shape, -1, self.head_dim)
  301. query_states = self.q_proj(hidden_states).view(hidden_shape).transpose(1, 2)
  302. key_states = self.k_proj(hidden_states).view(hidden_shape).transpose(1, 2)
  303. value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2)
  304. cos, sin = position_embeddings
  305. query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
  306. if past_key_values is not None:
  307. # sin and cos are specific to RoPE models; cache_position needed for the static cache
  308. cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
  309. key_states, value_states = past_key_values.update(key_states, value_states, self.layer_idx, cache_kwargs)
  310. attention_interface: Callable = eager_attention_forward
  311. if self.config._attn_implementation != "eager":
  312. attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]
  313. attn_output, attn_weights = attention_interface(
  314. self,
  315. query_states,
  316. key_states,
  317. value_states,
  318. attention_mask,
  319. dropout=0.0 if not self.training else self.attention_dropout,
  320. scaling=self.scaling,
  321. sliding_window=getattr(self.config, "sliding_window", None), # main diff with Llama
  322. **kwargs,
  323. )
  324. attn_output = attn_output.reshape(*input_shape, -1).contiguous()
  325. attn_output = self.o_proj(attn_output)
  326. return attn_output, attn_weights
  327. class MiniMaxBlockSparseTop2MLP(nn.Module):
  328. def __init__(self, config: MiniMaxConfig):
  329. super().__init__()
  330. self.ffn_dim = config.intermediate_size
  331. self.hidden_dim = config.hidden_size
  332. self.w1 = nn.Linear(self.hidden_dim, self.ffn_dim, bias=False)
  333. self.w2 = nn.Linear(self.ffn_dim, self.hidden_dim, bias=False)
  334. self.w3 = nn.Linear(self.hidden_dim, self.ffn_dim, bias=False)
  335. self.act_fn = ACT2FN[config.hidden_act]
  336. def forward(self, hidden_states):
  337. current_hidden_states = self.act_fn(self.w1(hidden_states)) * self.w3(hidden_states)
  338. current_hidden_states = self.w2(current_hidden_states)
  339. return current_hidden_states
  340. class MiniMaxSparseMoeBlock(nn.Module):
  341. """
  342. This implementation is
  343. strictly equivalent to standard MoE with full capacity (no
  344. dropped tokens). It's faster since it formulates MoE operations
  345. in terms of block-sparse operations to accommodate imbalanced
  346. assignments of tokens to experts, whereas standard MoE either
  347. (1) drop tokens at the cost of reduced performance or (2) set
  348. capacity factor to number of experts and thus waste computation
  349. and memory on padding.
  350. """
  351. def __init__(self, config):
  352. super().__init__()
  353. self.hidden_dim = config.hidden_size
  354. self.ffn_dim = config.intermediate_size
  355. self.num_experts = config.num_local_experts
  356. self.top_k = config.num_experts_per_tok
  357. # gating
  358. self.gate = nn.Linear(self.hidden_dim, self.num_experts, bias=False)
  359. self.experts = nn.ModuleList([MiniMaxBlockSparseTop2MLP(config) for _ in range(self.num_experts)])
  360. # Jitter parameters
  361. self.jitter_noise = config.router_jitter_noise
  362. def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
  363. """ """
  364. batch_size, sequence_length, hidden_dim = hidden_states.shape
  365. if self.training and self.jitter_noise > 0:
  366. hidden_states *= torch.empty_like(hidden_states).uniform_(1.0 - self.jitter_noise, 1.0 + self.jitter_noise)
  367. hidden_states = hidden_states.view(-1, hidden_dim)
  368. # router_logits: (batch * sequence_length, n_experts)
  369. router_logits = self.gate(hidden_states)
  370. routing_weights = F.softmax(router_logits, dim=1, dtype=torch.float)
  371. routing_weights, selected_experts = torch.topk(routing_weights, self.top_k, dim=-1)
  372. routing_weights /= routing_weights.sum(dim=-1, keepdim=True)
  373. # we cast back to the input dtype
  374. routing_weights = routing_weights.to(hidden_states.dtype)
  375. final_hidden_states = torch.zeros(
  376. (batch_size * sequence_length, hidden_dim), dtype=hidden_states.dtype, device=hidden_states.device
  377. )
  378. # One hot encode the selected experts to create an expert mask
  379. # this will be used to easily index which expert is going to be sollicitated
  380. expert_mask = torch.nn.functional.one_hot(selected_experts, num_classes=self.num_experts).permute(2, 1, 0)
  381. expert_hit = torch.greater(expert_mask.sum(dim=(-1, -2)), 0).nonzero()
  382. for expert_idx in expert_hit:
  383. expert_layer = self.experts[expert_idx]
  384. idx, top_x = torch.where(expert_mask[expert_idx].squeeze(0))
  385. # Index the correct hidden states and compute the expert hidden state for
  386. # the current expert. We need to make sure to multiply the output hidden
  387. # states by `routing_weights` on the corresponding tokens (top-1 and top-2)
  388. current_state = hidden_states[None, top_x].reshape(-1, hidden_dim)
  389. current_hidden_states = expert_layer(current_state) * routing_weights[top_x, idx, None]
  390. # However `index_add_` only support torch tensors for indexing so we'll use
  391. # the `top_x` tensor here.
  392. final_hidden_states.index_add_(0, top_x, current_hidden_states.to(hidden_states.dtype))
  393. final_hidden_states = final_hidden_states.reshape(batch_size, sequence_length, hidden_dim)
  394. return final_hidden_states, router_logits
  395. class MiniMaxDecoderLayer(GradientCheckpointingLayer):
  396. def __init__(self, config: MiniMaxConfig, layer_idx: int):
  397. super().__init__()
  398. self.hidden_size = config.hidden_size
  399. self.self_attn = MiniMaxAttention(config, layer_idx)
  400. self.block_sparse_moe = MiniMaxSparseMoeBlock(config)
  401. self.input_layernorm = MiniMaxRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
  402. self.post_attention_layernorm = MiniMaxRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
  403. self.layer_idx = layer_idx
  404. self.layer_type = config.layer_types[layer_idx]
  405. self.mlp_alpha_factor = config.mlp_alpha_factor
  406. self.mlp_beta_factor = config.mlp_beta_factor
  407. if self.layer_type == "linear_attention":
  408. self.self_attn = MiniMaxLightningAttention(config, layer_idx)
  409. self.attn_alpha_factor = config.linear_attn_alpha_factor
  410. self.attn_beta_factor = config.linear_attn_beta_factor
  411. else:
  412. self.self_attn = MiniMaxAttention(config, layer_idx)
  413. self.attn_alpha_factor = config.full_attn_alpha_factor
  414. self.attn_beta_factor = config.full_attn_beta_factor
  415. @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58")
  416. def forward(
  417. self,
  418. hidden_states: torch.Tensor,
  419. position_embeddings: tuple[torch.Tensor, torch.Tensor],
  420. attention_mask: Optional[torch.Tensor] = None,
  421. position_ids: Optional[torch.LongTensor] = None,
  422. past_key_values: Optional[Cache] = None,
  423. output_attentions: Optional[bool] = False,
  424. output_router_logits: Optional[bool] = False,
  425. use_cache: Optional[bool] = False,
  426. cache_position: Optional[torch.LongTensor] = None,
  427. **kwargs: Unpack[FlashAttentionKwargs],
  428. ) -> tuple[torch.FloatTensor, Optional[tuple[torch.FloatTensor, torch.FloatTensor]]]:
  429. """
  430. Args:
  431. hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
  432. position_embeddings (`tuple[torch.FloatTensor, torch.FloatTensor]`):
  433. Tuple containing the cosine and sine positional embeddings of shape `(batch_size, seq_len, head_dim)`,
  434. with `head_dim` being the embedding dimension of each attention head.
  435. attention_mask (`torch.Tensor`, *optional*): attention mask of size
  436. `(batch, sequence_length)` where padding elements are indicated by 0.
  437. past_key_values (`Cache`, *optional*): cached past key and value projection states
  438. output_attentions (`bool`, *optional*):
  439. Whether or not to return the attentions tensors of all attention layers. See `attentions` under
  440. returned tensors for more detail.
  441. output_router_logits (`bool`, *optional*):
  442. Whether or not to return the logits of all the routers. They are useful for computing the router loss, and
  443. should not be returned during inference.
  444. use_cache (`bool`, *optional*):
  445. If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding
  446. (see `past_key_values`).
  447. cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*):
  448. Indices depicting the position of the input sequence tokens in the sequence.
  449. kwargs (`dict`, *optional*):
  450. Arbitrary kwargs to be ignored, used for FSDP and other methods that injects code
  451. into the model
  452. """
  453. hidden_states = self.input_layernorm(hidden_states)
  454. residual = hidden_states
  455. # Self Attention
  456. hidden_states, _ = self.self_attn(
  457. hidden_states=hidden_states,
  458. position_embeddings=position_embeddings,
  459. attention_mask=attention_mask,
  460. position_ids=position_ids,
  461. past_key_values=past_key_values,
  462. output_attentions=output_attentions,
  463. use_cache=use_cache,
  464. cache_position=cache_position,
  465. **kwargs,
  466. )
  467. hidden_states = residual * self.attn_alpha_factor + hidden_states * self.attn_beta_factor
  468. # Fully Connected
  469. hidden_states = self.post_attention_layernorm(hidden_states)
  470. residual = hidden_states
  471. hidden_states, _ = self.block_sparse_moe(hidden_states)
  472. hidden_states = residual * self.mlp_alpha_factor + hidden_states * self.mlp_beta_factor
  473. return hidden_states
@auto_docstring
class MiniMaxPreTrainedModel(PreTrainedModel):
    # Declarative capability flags for the shared PreTrainedModel machinery; no methods here.
    config: MiniMaxConfig
    base_model_prefix = "model"
    supports_gradient_checkpointing = True
    # Decoder layers are kept whole when splitting the model across devices.
    _no_split_modules = ["MiniMaxDecoderLayer"]
    _skip_keys_device_placement = ["past_key_values"]
    _supports_flash_attn = True
    _supports_sdpa = True
    _supports_flex_attn = True
    # NOTE(review): presumably False because of the Python-level loops in the lightning
    # attention path and the custom cache; confirm before enabling.
    _can_compile_fullgraph = False
    _supports_attention_backend = True
    # Which submodule outputs are captured for each returned field. `index=1` selects the
    # second element of MiniMaxSparseMoeBlock's output tuple (the router logits).
    _can_record_outputs = {
        "router_logits": OutputRecorder(MiniMaxSparseMoeBlock, index=1),
        "hidden_states": MiniMaxDecoderLayer,
        "attentions": [MiniMaxAttention, MiniMaxLightningAttention],
    }
  491. class MiniMaxRotaryEmbedding(nn.Module):
  492. inv_freq: torch.Tensor # fix linting for `register_buffer`
  493. def __init__(self, config: MiniMaxConfig, device=None):
  494. super().__init__()
  495. # BC: "rope_type" was originally "type"
  496. if hasattr(config, "rope_scaling") and isinstance(config.rope_scaling, dict):
  497. self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type"))
  498. else:
  499. self.rope_type = "default"
  500. self.max_seq_len_cached = config.max_position_embeddings
  501. self.original_max_seq_len = config.max_position_embeddings
  502. self.config = config
  503. self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]
  504. inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device)
  505. self.register_buffer("inv_freq", inv_freq, persistent=False)
  506. self.original_inv_freq = self.inv_freq
  507. @torch.no_grad()
  508. @dynamic_rope_update # power user: used with advanced RoPE types (e.g. dynamic rope)
  509. def forward(self, x, position_ids):
  510. inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1).to(x.device)
  511. position_ids_expanded = position_ids[:, None, :].float()
  512. device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu"
  513. with torch.autocast(device_type=device_type, enabled=False): # Force float32
  514. freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
  515. emb = torch.cat((freqs, freqs), dim=-1)
  516. cos = emb.cos() * self.attention_scaling
  517. sin = emb.sin() * self.attention_scaling
  518. return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
  519. @auto_docstring
  520. class MiniMaxModel(MiniMaxPreTrainedModel):
  521. def __init__(self, config: MiniMaxConfig):
  522. super().__init__(config)
  523. self.padding_idx = config.pad_token_id
  524. self.vocab_size = config.vocab_size
  525. self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
  526. self.layers = nn.ModuleList(
  527. [MiniMaxDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
  528. )
  529. self.norm = MiniMaxRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
  530. self.rotary_emb = MiniMaxRotaryEmbedding(config=config)
  531. self.gradient_checkpointing = False
  532. # Initialize weights and apply final processing
  533. self.post_init()
  534. @check_model_inputs()
  535. def forward(
  536. self,
  537. input_ids: Optional[torch.LongTensor] = None,
  538. attention_mask: Optional[torch.Tensor] = None,
  539. position_ids: Optional[torch.LongTensor] = None,
  540. past_key_values: Optional[MiniMaxCache] = None,
  541. inputs_embeds: Optional[torch.FloatTensor] = None,
  542. use_cache: Optional[bool] = None,
  543. output_attentions: Optional[bool] = None,
  544. cache_position: Optional[torch.LongTensor] = None,
  545. **kwargs: Unpack[TransformersKwargs],
  546. ) -> MoeModelOutputWithPast:
  547. if (input_ids is None) ^ (inputs_embeds is not None):
  548. raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
  549. if use_cache and past_key_values is None:
  550. past_key_values = MiniMaxCache()
  551. elif use_cache and not isinstance(past_key_values, MiniMaxCache):
  552. raise ValueError(
  553. f"MiniMax uses cache of its own and is not compatible with `past_key_values` of type {type(past_key_values)}."
  554. )
  555. if inputs_embeds is None:
  556. inputs_embeds = self.embed_tokens(input_ids)
  557. if cache_position is None:
  558. past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
  559. cache_position = torch.arange(
  560. past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device
  561. )
  562. if position_ids is None:
  563. position_ids = cache_position.unsqueeze(0)
  564. mask_function = create_causal_mask if self.config.sliding_window is None else create_sliding_window_causal_mask
  565. causal_mask = mask_function(
  566. config=self.config,
  567. input_embeds=inputs_embeds,
  568. attention_mask=attention_mask,
  569. cache_position=cache_position,
  570. past_key_values=past_key_values,
  571. position_ids=position_ids,
  572. )
  573. hidden_states = inputs_embeds
  574. # create position embeddings to be shared across the decoder layers
  575. position_embeddings = self.rotary_emb(hidden_states, position_ids)
  576. for decoder_layer in self.layers:
  577. if decoder_layer.layer_type == "full_attention":
  578. input_attention_mask = causal_mask
  579. else:
  580. # lightning attention uses original attention_mask, and uses it only for the first step
  581. input_attention_mask = attention_mask
  582. hidden_states = decoder_layer(
  583. hidden_states,
  584. position_embeddings=position_embeddings,
  585. attention_mask=input_attention_mask,
  586. position_ids=position_ids,
  587. past_key_values=past_key_values,
  588. use_cache=use_cache,
  589. cache_position=cache_position,
  590. **kwargs,
  591. )
  592. hidden_states = self.norm(hidden_states)
  593. return MoeModelOutputWithPast(
  594. last_hidden_state=hidden_states,
  595. past_key_values=past_key_values,
  596. )
  597. def load_balancing_loss_func(
  598. gate_logits: Union[torch.Tensor, tuple[torch.Tensor], None],
  599. num_experts: Optional[int] = None,
  600. top_k=2,
  601. attention_mask: Optional[torch.Tensor] = None,
  602. ) -> Union[torch.Tensor, int]:
  603. r"""
  604. Computes auxiliary load balancing loss as in Switch Transformer - implemented in Pytorch.
  605. See Switch Transformer (https://huggingface.co/papers/2101.03961) for more details. This function implements the loss
  606. function presented in equations (4) - (6) of the paper. It aims at penalizing cases where the routing between
  607. experts is too unbalanced.
  608. Args:
  609. gate_logits:
  610. Logits from the `gate`, should be a tuple of model.config.num_hidden_layers tensors of
  611. shape [batch_size X sequence_length, num_experts].
  612. num_experts:
  613. Number of experts
  614. top_k:
  615. The number of experts to route per-token, can be also interpreted as the `top-k` routing
  616. parameter.
  617. attention_mask (`torch.Tensor`, *optional*):
  618. The attention_mask used in forward function
  619. shape [batch_size X sequence_length] if not None.
  620. Returns:
  621. The auxiliary loss.
  622. """
  623. if gate_logits is None or not isinstance(gate_logits, tuple):
  624. return 0
  625. if isinstance(gate_logits, tuple):
  626. compute_device = gate_logits[0].device
  627. concatenated_gate_logits = torch.cat([layer_gate.to(compute_device) for layer_gate in gate_logits], dim=0)
  628. routing_weights = torch.nn.functional.softmax(concatenated_gate_logits, dim=-1)
  629. _, selected_experts = torch.topk(routing_weights, top_k, dim=-1)
  630. expert_mask = torch.nn.functional.one_hot(selected_experts, num_experts)
  631. if attention_mask is None:
  632. # Compute the percentage of tokens routed to each experts
  633. tokens_per_expert = torch.mean(expert_mask.float(), dim=0)
  634. # Compute the average probability of routing to these experts
  635. router_prob_per_expert = torch.mean(routing_weights, dim=0)
  636. else:
  637. batch_size, sequence_length = attention_mask.shape
  638. num_hidden_layers = concatenated_gate_logits.shape[0] // (batch_size * sequence_length)
  639. # Compute the mask that masks all padding tokens as 0 with the same shape of expert_mask
  640. expert_attention_mask = (
  641. attention_mask[None, :, :, None, None]
  642. .expand((num_hidden_layers, batch_size, sequence_length, top_k, num_experts))
  643. .reshape(-1, top_k, num_experts)
  644. .to(compute_device)
  645. )
  646. # Compute the percentage of tokens routed to each experts
  647. tokens_per_expert = torch.sum(expert_mask.float() * expert_attention_mask, dim=0) / torch.sum(
  648. expert_attention_mask, dim=0
  649. )
  650. # Compute the mask that masks all padding tokens as 0 with the same shape of tokens_per_expert
  651. router_per_expert_attention_mask = (
  652. attention_mask[None, :, :, None]
  653. .expand((num_hidden_layers, batch_size, sequence_length, num_experts))
  654. .reshape(-1, num_experts)
  655. .to(compute_device)
  656. )
  657. # Compute the average probability of routing to these experts
  658. router_prob_per_expert = torch.sum(routing_weights * router_per_expert_attention_mask, dim=0) / torch.sum(
  659. router_per_expert_attention_mask, dim=0
  660. )
  661. overall_loss = torch.sum(tokens_per_expert * router_prob_per_expert.unsqueeze(0))
  662. return overall_loss * num_experts
  663. @auto_docstring
  664. class MiniMaxForCausalLM(MiniMaxPreTrainedModel, GenerationMixin):
  665. _tied_weights_keys = ["lm_head.weight"]
  666. _tp_plan = {"lm_head": "colwise_rep"}
  667. _pp_plan = {"lm_head": (["hidden_states"], ["logits"])}
  668. def __init__(self, config):
  669. super().__init__(config)
  670. self.model = MiniMaxModel(config)
  671. self.vocab_size = config.vocab_size
  672. self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
  673. self.router_aux_loss_coef = config.router_aux_loss_coef
  674. self.num_experts = config.num_local_experts
  675. self.num_experts_per_tok = config.num_experts_per_tok
  676. # Initialize weights and apply final processing
  677. self.post_init()
  678. @can_return_tuple
  679. @auto_docstring
  680. def forward(
  681. self,
  682. input_ids: Optional[torch.LongTensor] = None,
  683. attention_mask: Optional[torch.Tensor] = None,
  684. position_ids: Optional[torch.LongTensor] = None,
  685. past_key_values: Optional[Cache] = None,
  686. inputs_embeds: Optional[torch.FloatTensor] = None,
  687. labels: Optional[torch.LongTensor] = None,
  688. use_cache: Optional[bool] = None,
  689. output_router_logits: Optional[bool] = None,
  690. cache_position: Optional[torch.LongTensor] = None,
  691. logits_to_keep: Union[int, torch.Tensor] = 0,
  692. **kwargs: Unpack[TransformersKwargs],
  693. ) -> MoeCausalLMOutputWithPast:
  694. r"""
  695. labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
  696. Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
  697. config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
  698. (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
  699. Example:
  700. ```python
  701. >>> from transformers import AutoTokenizer, MiniMaxForCausalLM
  702. >>> model = MiniMaxForCausalLM.from_pretrained("MiniMaxAI/MiniMax-Text-01-hf")
  703. >>> tokenizer = AutoTokenizer.from_pretrained("MiniMaxAI/MiniMax-Text-01-hf")
  704. >>> prompt = "Hey, are you conscious? Can you talk to me?"
  705. >>> inputs = tokenizer(prompt, return_tensors="pt")
  706. >>> # Generate
  707. >>> generate_ids = model.generate(inputs.input_ids, max_length=30)
  708. >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
  709. "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you."
  710. ```"""
  711. output_router_logits = (
  712. output_router_logits if output_router_logits is not None else self.config.output_router_logits
  713. )
  714. # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
  715. outputs: MoeModelOutputWithPast = self.model(
  716. input_ids=input_ids,
  717. attention_mask=attention_mask,
  718. position_ids=position_ids,
  719. past_key_values=past_key_values,
  720. inputs_embeds=inputs_embeds,
  721. use_cache=use_cache,
  722. output_router_logits=output_router_logits,
  723. cache_position=cache_position,
  724. **kwargs,
  725. )
  726. hidden_states = outputs.last_hidden_state
  727. # Only compute necessary logits, and do not upcast them to float if we are not computing the loss
  728. slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
  729. logits = self.lm_head(hidden_states[:, slice_indices, :])
  730. loss = None
  731. if labels is not None:
  732. loss = self.loss_function(logits, labels, self.vocab_size, **kwargs)
  733. aux_loss = None
  734. if output_router_logits:
  735. aux_loss = load_balancing_loss_func(
  736. outputs.router_logits,
  737. self.num_experts,
  738. self.num_experts_per_tok,
  739. attention_mask,
  740. )
  741. if labels is not None:
  742. loss += self.router_aux_loss_coef * aux_loss.to(loss.device) # make sure to reside in the same device
  743. return MoeCausalLMOutputWithPast(
  744. loss=loss,
  745. aux_loss=aux_loss,
  746. logits=logits,
  747. past_key_values=outputs.past_key_values,
  748. hidden_states=outputs.hidden_states,
  749. attentions=outputs.attentions,
  750. router_logits=outputs.router_logits,
  751. )
  752. class MiniMaxForSequenceClassification(GenericForSequenceClassification, MiniMaxPreTrainedModel):
  753. pass
  754. class MiniMaxForTokenClassification(GenericForTokenClassification, MiniMaxPreTrainedModel):
  755. pass
  756. class MiniMaxForQuestionAnswering(GenericForQuestionAnswering, MiniMaxPreTrainedModel):
  757. pass
  758. __all__ = [
  759. "MiniMaxPreTrainedModel",
  760. "MiniMaxModel",
  761. "MiniMaxForCausalLM",
  762. "MiniMaxForSequenceClassification",
  763. "MiniMaxForTokenClassification",
  764. "MiniMaxForQuestionAnswering",
  765. ]