modular_minimax.py 27 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605
  1. # coding=utf-8
  2. # Copyright 2025 MiniMaxAI and HuggingFace Inc. teams. All rights reserved.
  3. #
  4. #
  5. # Licensed under the Apache License, Version 2.0 (the "License");
  6. # you may not use this file except in compliance with the License.
  7. # You may obtain a copy of the License at
  8. #
  9. # http://www.apache.org/licenses/LICENSE-2.0
  10. #
  11. # Unless required by applicable law or agreed to in writing, software
  12. # distributed under the License is distributed on an "AS IS" BASIS,
  13. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  14. # See the License for the specific language governing permissions and
  15. # limitations under the License.
  16. """PyTorch MiniMax model."""
  17. from typing import Optional
  18. import torch
  19. import torch.nn.functional as F
  20. from torch import nn
  21. from ...activations import ACT2FN
  22. from ...cache_utils import Cache, DynamicCache
  23. from ...configuration_utils import layer_type_validation
  24. from ...masking_utils import create_causal_mask, create_sliding_window_causal_mask
  25. from ...modeling_flash_attention_utils import FlashAttentionKwargs
  26. from ...modeling_layers import GradientCheckpointingLayer
  27. from ...modeling_outputs import MoeModelOutputWithPast
  28. from ...processing_utils import Unpack
  29. from ...utils import TransformersKwargs, logging
  30. from ...utils.deprecation import deprecate_kwarg
  31. from ...utils.generic import OutputRecorder, check_model_inputs
  32. from ..mixtral.configuration_mixtral import MixtralConfig
  33. from ..mixtral.modeling_mixtral import (
  34. MixtralAttention,
  35. MixtralDecoderLayer,
  36. MixtralForCausalLM,
  37. MixtralForQuestionAnswering,
  38. MixtralForSequenceClassification,
  39. MixtralForTokenClassification,
  40. MixtralModel,
  41. MixtralPreTrainedModel,
  42. MixtralRMSNorm,
  43. MixtralSparseMoeBlock,
  44. )
# Module-level logger, named after this module.
logger = logging.get_logger(__name__)
  46. class MiniMaxConfig(MixtralConfig):
  47. r"""
  48. This is the configuration class to store the configuration of a [`MiniMaxModel`]. It is used to instantiate an
  49. MiniMax model according to the specified arguments, defining the model architecture. Instantiating a configuration
  50. with the defaults will yield a similar configuration to that of the MiniMax.
  51. [MiniMaxAI/MiniMax-Text-01-hf](https://huggingface.co/MiniMaxAI/MiniMax-Text-01-hf)
  52. Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
  53. documentation from [`PretrainedConfig`] for more information.
  54. Args:
  55. vocab_size (`int`, *optional*, defaults to 32000):
  56. Vocabulary size of the MiniMax model. Defines the number of different tokens that can be represented by the
  57. `inputs_ids` passed when calling [`MiniMaxModel`]
  58. hidden_size (`int`, *optional*, defaults to 4096):
  59. Dimension of the hidden representations.
  60. intermediate_size (`int`, *optional*, defaults to 14336):
  61. Dimension of the MLP representations.
  62. num_hidden_layers (`int`, *optional*, defaults to 32):
  63. Number of hidden layers in the Transformer encoder.
  64. num_attention_heads (`int`, *optional*, defaults to 32):
  65. Number of attention heads for each attention layer in the Transformer encoder.
  66. num_key_value_heads (`int`, *optional*, defaults to 8):
  67. This is the number of key_value heads that should be used to implement Grouped Query Attention. If
  68. `num_key_value_heads=num_attention_heads`, the model will use Multi Head Attention (MHA), if
  69. `num_key_value_heads=1` the model will use Multi Query Attention (MQA) otherwise GQA is used. When
  70. converting a multi-head checkpoint to a GQA checkpoint, each group key and value head should be constructed
  71. by meanpooling all the original heads within that group. For more details, check out [this
  72. paper](https://huggingface.co/papers/2305.13245). If it is not specified, will default to `8`.
  73. head_dim (`int`, *optional*, defaults to `hidden_size // num_attention_heads`):
  74. The attention head dimension.
  75. hidden_act (`str` or `function`, *optional*, defaults to `"silu"`):
  76. The non-linear activation function (function or string) in the decoder.
  77. max_position_embeddings (`int`, *optional*, defaults to `4096*32`):
  78. The maximum sequence length that this model might ever be used with. MiniMax's sliding window attention
  79. allows sequence of up to 4096*32 tokens.
  80. initializer_range (`float`, *optional*, defaults to 0.02):
  81. The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
  82. rms_norm_eps (`float`, *optional*, defaults to 1e-05):
  83. The epsilon used by the rms normalization layers.
  84. use_cache (`bool`, *optional*, defaults to `True`):
  85. Whether or not the model should return the last key/values attentions (not used by all models). Only
  86. relevant if `config.is_decoder=True`.
  87. pad_token_id (`int`, *optional*):
  88. The id of the padding token.
  89. bos_token_id (`int`, *optional*, defaults to 1):
  90. The id of the "beginning-of-sequence" token.
  91. eos_token_id (`int`, *optional*, defaults to 2):
  92. The id of the "end-of-sequence" token.
  93. tie_word_embeddings (`bool`, *optional*, defaults to `False`):
  94. Whether the model's input and output word embeddings should be tied.
  95. rope_theta (`float`, *optional*, defaults to 1000000.0):
  96. The base period of the RoPE embeddings.
  97. sliding_window (`int`, *optional*):
  98. Sliding window attention window size. If not specified, will default to `4096`.
  99. attention_dropout (`float`, *optional*, defaults to 0.0):
  100. The dropout ratio for the attention probabilities.
  101. num_experts_per_tok (`int`, *optional*, defaults to 2):
  102. The number of experts to route per-token, can be also interpreted as the `top-k` routing
  103. parameter
  104. num_local_experts (`int`, *optional*, defaults to 8):
  105. Number of experts per Sparse MLP layer.
  106. output_router_logits (`bool`, *optional*, defaults to `False`):
  107. Whether or not the router logits should be returned by the model. Enabling this will also
  108. allow the model to output the auxiliary loss. See [here]() for more details
  109. router_aux_loss_coef (`float`, *optional*, defaults to 0.001):
  110. The aux loss factor for the total loss.
  111. router_jitter_noise (`float`, *optional*, defaults to 0.0):
  112. Amount of noise to add to the router.
  113. layer_types (`list`, *optional*):
  114. Attention pattern for each layer.
  115. block_size (`int`, *optional*, defaults to 256):
  116. The length of each attention block, determining how queries, keys, and values
  117. are grouped and processed for intra- and inter-block attention.
  118. full_attn_alpha_factor (`float`, *optional*, defaults to 1):
  119. Weight for residual value in residual connection after normal attention.
  120. full_attn_beta_factor (`float`, *optional*, defaults to 1):
  121. Weight for hidden state value in residual connection after normal attention.
  122. linear_attn_alpha_factor (`float`, *optional*, defaults to 1):
  123. Weight for residual value in residual connection after lightning attention.
  124. linear_attn_beta_factor (`float`, *optional*, defaults to 1):
  125. Weight for hidden state value in residual connection after lightning attention.
  126. mlp_alpha_factor (`float`, *optional*, defaults to 1):
  127. Weight for residual value in residual connection after MLP.
  128. mlp_beta_factor (`float`, *optional*, defaults to 1):
  129. Weight for hidden state value in residual connection after MLP.
  130. ```python
  131. >>> from transformers import MiniMaxModel, MiniMaxConfig
  132. >>> # Initializing a MiniMax style configuration
  133. >>> configuration = MiniMaxConfig()
  134. >>> # Initializing a model from the MiniMax style configuration
  135. >>> model = MiniMaxModel(configuration)
  136. >>> # Accessing the model configuration
  137. >>> configuration = model.config
  138. ```"""
  139. def __init__(
  140. self,
  141. layer_types=None,
  142. block_size=256,
  143. full_attn_alpha_factor=1,
  144. full_attn_beta_factor=1,
  145. linear_attn_alpha_factor=1,
  146. linear_attn_beta_factor=1,
  147. mlp_alpha_factor=1,
  148. mlp_beta_factor=1,
  149. **super_kwargs,
  150. ):
  151. super().__init__(**super_kwargs)
  152. self.layer_types = layer_types
  153. self.block_size = block_size
  154. self.full_attn_alpha_factor = full_attn_alpha_factor
  155. self.full_attn_beta_factor = full_attn_beta_factor
  156. self.linear_attn_alpha_factor = linear_attn_alpha_factor
  157. self.linear_attn_beta_factor = linear_attn_beta_factor
  158. self.mlp_alpha_factor = mlp_alpha_factor
  159. self.mlp_beta_factor = mlp_beta_factor
  160. if self.layer_types is None:
  161. self.layer_types = [
  162. "full_attention" if bool((i + 1) % 2) else "linear_attention" for i in range(self.num_hidden_layers)
  163. ]
  164. layer_type_validation(self.layer_types, self.num_hidden_layers)
# RMS normalization layer for MiniMax; adds nothing on top of `MixtralRMSNorm`.
class MiniMaxRMSNorm(MixtralRMSNorm):
    pass
  167. class MiniMaxCache(DynamicCache):
  168. def __init__(self):
  169. super().__init__()
  170. self.linear_cache: list[torch.Tensor] = []
  171. def set_linear_cache(self, layer_idx, linear_cache):
  172. # There may be skipped layers, fill them with empty lists
  173. for _ in range(len(self.linear_cache), layer_idx + 1):
  174. self.linear_cache.append([])
  175. self.linear_cache[layer_idx] = linear_cache
  176. def get_linear_cache(self, layer_idx: int):
  177. if layer_idx < len(self):
  178. return self.linear_cache[layer_idx]
  179. return None
  180. def __len__(self):
  181. return max(super().__len__(), len(self.linear_cache))
  182. def __getitem__(self, layer_idx: int):
  183. if layer_idx < len(self.linear_cache) and self.linear_cache[layer_idx] != []:
  184. return (self.linear_cache[layer_idx],)
  185. return super().__getitem__(layer_idx)
  186. def __iter__(self):
  187. for layer_idx in range(len(self)):
  188. yield self[layer_idx]
  189. def batch_repeat_interleave(self, repeats: int):
  190. for layer_idx in range(len(self)):
  191. if self.linear_cache[layer_idx] != []:
  192. self.linear_cache[layer_idx] = self.linear_cache[layer_idx].repeat_interleave(repeats, dim=0)
  193. else:
  194. self.layers[layer_idx].batch_repeat_interleave(repeats)
  195. def batch_select_indices(self, indices: torch.Tensor):
  196. for layer_idx in range(len(self)):
  197. if self.linear_cache[layer_idx] != []:
  198. self.linear_cache[layer_idx] = self.linear_cache[layer_idx][indices, ...]
  199. else:
  200. self.layers[layer_idx].batch_select_indices(indices)
  201. def crop(self, max_length: int):
  202. raise RuntimeError("MiniMaxCache doesnot support `crop` method")
  203. class MiniMaxLightningAttention(nn.Module):
  204. def __init__(self, config: MiniMaxConfig, layer_idx: int):
  205. super().__init__()
  206. self.layer_idx = layer_idx
  207. self.head_dim = getattr(config, "head_dim", None) or config.hidden_size // config.num_attention_heads
  208. self.num_attention_heads = config.num_attention_heads
  209. self.num_hidden_layers = config.num_hidden_layers
  210. self.block_size = config.block_size
  211. self.act_fn = ACT2FN[config.hidden_act]
  212. self.norm = MiniMaxRMSNorm(self.head_dim * self.num_attention_heads)
  213. self.qkv_proj = nn.Linear(config.hidden_size, self.num_attention_heads * self.head_dim * 3, bias=False)
  214. self.out_proj = nn.Linear(self.num_attention_heads * self.head_dim, config.hidden_size, bias=False)
  215. self.output_gate = nn.Linear(config.hidden_size, self.num_attention_heads * self.head_dim, bias=False)
  216. slope_rate = self.get_slope_rate()
  217. query_decay, key_decay, diagonal_decay = self.decay_factors(slope_rate)
  218. self.register_buffer("slope_rate", slope_rate)
  219. self.register_buffer("query_decay", query_decay)
  220. self.register_buffer("key_decay", key_decay)
  221. self.register_buffer("diagonal_decay", diagonal_decay)
  222. def get_slope_rate(self):
  223. base = 1 / (2 ** (8 / self.num_attention_heads))
  224. exponent = torch.arange(self.num_attention_heads) + 1
  225. factor = 1 - self.layer_idx / (self.num_hidden_layers - 1 + 1e-5) + 1e-5
  226. rate = base**exponent
  227. rate = rate * factor
  228. rate = rate[:, None, None]
  229. return rate
  230. def decay_factors(self, slope_rate):
  231. block_size_range = torch.arange(self.block_size) + 1
  232. query_decay = torch.exp(-slope_rate * block_size_range[:, None])
  233. key_decay = torch.exp(-slope_rate * (self.block_size - block_size_range[:, None]))
  234. diagonal_decay = block_size_range[:, None] - block_size_range[None, :]
  235. diagonal_decay = diagonal_decay[None, None, :, :]
  236. diagonal_decay = slope_rate * diagonal_decay
  237. diagonal_decay = torch.where(diagonal_decay >= 0, -diagonal_decay, float("-inf"))
  238. diagonal_decay = torch.exp(diagonal_decay)
  239. return query_decay, key_decay, diagonal_decay
  240. @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58")
  241. def forward(
  242. self,
  243. hidden_states: torch.Tensor,
  244. position_embeddings: tuple[torch.Tensor, torch.Tensor],
  245. attention_mask: Optional[torch.Tensor],
  246. past_key_values: Optional[Cache] = None,
  247. cache_position: Optional[torch.LongTensor] = None,
  248. **kwargs: Unpack[FlashAttentionKwargs],
  249. ) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[tuple[torch.Tensor]]]:
  250. batch_size, seq_len, hidden_size = hidden_states.shape
  251. num_blocks = (seq_len + self.block_size - 1) // self.block_size
  252. qkv_states = self.act_fn(self.qkv_proj(hidden_states))
  253. qkv_states = qkv_states.reshape(batch_size, seq_len, self.num_attention_heads, 3 * self.head_dim)
  254. query_states, key_states, value_states = torch.split(qkv_states, self.head_dim, dim=3)
  255. query_states = query_states.transpose(1, 2)
  256. key_states = key_states.transpose(1, 2)
  257. value_states = value_states.transpose(1, 2)
  258. # calculated (K.T @ V) and saved as cache
  259. attn_weights_inter = None
  260. if past_key_values is not None:
  261. attn_weights_inter = past_key_values.get_linear_cache(self.layer_idx)
  262. if attn_weights_inter is None:
  263. attn_weights_inter = torch.zeros(batch_size, self.num_attention_heads, self.head_dim, self.head_dim).to(
  264. value_states
  265. )
  266. # apply attention_mask
  267. if attention_mask is not None:
  268. attention_mask = attention_mask.to(dtype=torch.bool) # Ensure it's a boolean tensor
  269. value_states = value_states.masked_fill(~attention_mask.unsqueeze(1).unsqueeze(-1), 0)
  270. attn_output = []
  271. for i in range(num_blocks):
  272. start_idx = i * self.block_size
  273. end_idx = min(start_idx + self.block_size, seq_len)
  274. current_block_size = end_idx - start_idx
  275. current_query_states = query_states[:, :, start_idx:end_idx]
  276. current_key_states = key_states[:, :, start_idx:end_idx]
  277. current_value_states = value_states[:, :, start_idx:end_idx]
  278. current_query_decay = self.query_decay[:, :current_block_size]
  279. current_key_decay = self.key_decay[:, -current_block_size:]
  280. current_diagonal_decay = self.diagonal_decay[:, :, :current_block_size, :current_block_size]
  281. block_decay = torch.exp(-self.slope_rate * current_block_size)
  282. # intra: ( Q @ K.T ) @ V -> QK * V
  283. attn_weights_intra = torch.matmul(current_query_states, current_key_states.transpose(-1, -2))
  284. attn_output_intra = torch.matmul(attn_weights_intra * current_diagonal_decay, current_value_states)
  285. # inter: Q @ ( K.T @ V ) -> Q * KV
  286. attn_output_inter = torch.matmul(current_query_states * current_query_decay, attn_weights_inter)
  287. # final attention output
  288. current_attn_output = attn_output_inter + attn_output_intra
  289. attn_output.append(current_attn_output)
  290. # calculate attn_weights_inter for next block or cache
  291. next_attn_weights_inter = torch.matmul(
  292. (current_key_states * current_key_decay).transpose(-1, -2), current_value_states
  293. )
  294. attn_weights_inter = attn_weights_inter * block_decay + next_attn_weights_inter
  295. else:
  296. ratio = torch.exp(-self.slope_rate)
  297. attn_output = []
  298. for i in range(seq_len):
  299. current_query_states = query_states[:, :, i : i + 1]
  300. current_key_states = key_states[:, :, i : i + 1]
  301. current_value_states = value_states[:, :, i : i + 1]
  302. current_attn_weights_inter = torch.matmul(current_key_states.transpose(-1, -2), current_value_states)
  303. attn_weights_inter = ratio * attn_weights_inter + current_attn_weights_inter
  304. current_attn_output = torch.matmul(current_query_states, attn_weights_inter)
  305. attn_output.append(current_attn_output)
  306. # concatenate attention outputs over all blocks
  307. attn_output = torch.cat(attn_output, dim=-2)
  308. # final output projection
  309. attn_output = attn_output.transpose(1, 2)
  310. attn_output = attn_output.reshape(batch_size, seq_len, self.num_attention_heads * self.head_dim)
  311. attn_output = self.norm(attn_output)
  312. attn_output = F.sigmoid(self.output_gate(hidden_states)) * attn_output
  313. attn_output = self.out_proj(attn_output)
  314. # update cache
  315. if past_key_values is not None:
  316. past_key_values.set_linear_cache(self.layer_idx, attn_weights_inter)
  317. return attn_output, attn_weights_inter
# Full (softmax) attention for MiniMax; adds nothing on top of `MixtralAttention`.
class MiniMaxAttention(MixtralAttention):
    pass
# Sparse mixture-of-experts block for MiniMax; adds nothing on top of `MixtralSparseMoeBlock`.
class MiniMaxSparseMoeBlock(MixtralSparseMoeBlock):
    pass
  322. class MiniMaxDecoderLayer(MixtralDecoderLayer, GradientCheckpointingLayer):
  323. def __init__(self, config: MiniMaxConfig, layer_idx: int):
  324. super().__init__(config, layer_idx)
  325. self.layer_idx = layer_idx
  326. self.layer_type = config.layer_types[layer_idx]
  327. self.mlp_alpha_factor = config.mlp_alpha_factor
  328. self.mlp_beta_factor = config.mlp_beta_factor
  329. if self.layer_type == "linear_attention":
  330. self.self_attn = MiniMaxLightningAttention(config, layer_idx)
  331. self.attn_alpha_factor = config.linear_attn_alpha_factor
  332. self.attn_beta_factor = config.linear_attn_beta_factor
  333. else:
  334. self.self_attn = MiniMaxAttention(config, layer_idx)
  335. self.attn_alpha_factor = config.full_attn_alpha_factor
  336. self.attn_beta_factor = config.full_attn_beta_factor
  337. @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58")
  338. def forward(
  339. self,
  340. hidden_states: torch.Tensor,
  341. position_embeddings: tuple[torch.Tensor, torch.Tensor],
  342. attention_mask: Optional[torch.Tensor] = None,
  343. position_ids: Optional[torch.LongTensor] = None,
  344. past_key_values: Optional[Cache] = None,
  345. output_attentions: Optional[bool] = False,
  346. output_router_logits: Optional[bool] = False,
  347. use_cache: Optional[bool] = False,
  348. cache_position: Optional[torch.LongTensor] = None,
  349. **kwargs: Unpack[FlashAttentionKwargs],
  350. ) -> tuple[torch.FloatTensor, Optional[tuple[torch.FloatTensor, torch.FloatTensor]]]:
  351. """
  352. Args:
  353. hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
  354. position_embeddings (`tuple[torch.FloatTensor, torch.FloatTensor]`):
  355. Tuple containing the cosine and sine positional embeddings of shape `(batch_size, seq_len, head_dim)`,
  356. with `head_dim` being the embedding dimension of each attention head.
  357. attention_mask (`torch.Tensor`, *optional*): attention mask of size
  358. `(batch, sequence_length)` where padding elements are indicated by 0.
  359. past_key_values (`Cache`, *optional*): cached past key and value projection states
  360. output_attentions (`bool`, *optional*):
  361. Whether or not to return the attentions tensors of all attention layers. See `attentions` under
  362. returned tensors for more detail.
  363. output_router_logits (`bool`, *optional*):
  364. Whether or not to return the logits of all the routers. They are useful for computing the router loss, and
  365. should not be returned during inference.
  366. use_cache (`bool`, *optional*):
  367. If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding
  368. (see `past_key_values`).
  369. cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*):
  370. Indices depicting the position of the input sequence tokens in the sequence.
  371. kwargs (`dict`, *optional*):
  372. Arbitrary kwargs to be ignored, used for FSDP and other methods that injects code
  373. into the model
  374. """
  375. hidden_states = self.input_layernorm(hidden_states)
  376. residual = hidden_states
  377. # Self Attention
  378. hidden_states, _ = self.self_attn(
  379. hidden_states=hidden_states,
  380. position_embeddings=position_embeddings,
  381. attention_mask=attention_mask,
  382. position_ids=position_ids,
  383. past_key_values=past_key_values,
  384. output_attentions=output_attentions,
  385. use_cache=use_cache,
  386. cache_position=cache_position,
  387. **kwargs,
  388. )
  389. hidden_states = residual * self.attn_alpha_factor + hidden_states * self.attn_beta_factor
  390. # Fully Connected
  391. hidden_states = self.post_attention_layernorm(hidden_states)
  392. residual = hidden_states
  393. hidden_states, _ = self.block_sparse_moe(hidden_states)
  394. hidden_states = residual * self.mlp_alpha_factor + hidden_states * self.mlp_beta_factor
  395. return hidden_states
# Base class for the MiniMax models; overrides two class-level settings of `MixtralPreTrainedModel`.
class MiniMaxPreTrainedModel(MixtralPreTrainedModel):
    # Full-graph compilation is declared unsupported for this model.
    _can_compile_fullgraph = False
    # Output name -> module class(es) the output is recorded from.
    # NOTE(review): `index=1` presumably selects the second element of the MoE block's output tuple — confirm
    # against `OutputRecorder`.
    _can_record_outputs = {
        "router_logits": OutputRecorder(MiniMaxSparseMoeBlock, index=1),
        "hidden_states": MiniMaxDecoderLayer,
        "attentions": [MiniMaxAttention, MiniMaxLightningAttention],
    }
  403. class MiniMaxModel(MixtralModel):
  404. @check_model_inputs()
  405. def forward(
  406. self,
  407. input_ids: Optional[torch.LongTensor] = None,
  408. attention_mask: Optional[torch.Tensor] = None,
  409. position_ids: Optional[torch.LongTensor] = None,
  410. past_key_values: Optional[MiniMaxCache] = None,
  411. inputs_embeds: Optional[torch.FloatTensor] = None,
  412. use_cache: Optional[bool] = None,
  413. output_attentions: Optional[bool] = None,
  414. cache_position: Optional[torch.LongTensor] = None,
  415. **kwargs: Unpack[TransformersKwargs],
  416. ) -> MoeModelOutputWithPast:
  417. if (input_ids is None) ^ (inputs_embeds is not None):
  418. raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
  419. if use_cache and past_key_values is None:
  420. past_key_values = MiniMaxCache()
  421. elif use_cache and not isinstance(past_key_values, MiniMaxCache):
  422. raise ValueError(
  423. f"MiniMax uses cache of its own and is not compatible with `past_key_values` of type {type(past_key_values)}."
  424. )
  425. if inputs_embeds is None:
  426. inputs_embeds = self.embed_tokens(input_ids)
  427. if cache_position is None:
  428. past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
  429. cache_position = torch.arange(
  430. past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device
  431. )
  432. if position_ids is None:
  433. position_ids = cache_position.unsqueeze(0)
  434. mask_function = create_causal_mask if self.config.sliding_window is None else create_sliding_window_causal_mask
  435. causal_mask = mask_function(
  436. config=self.config,
  437. input_embeds=inputs_embeds,
  438. attention_mask=attention_mask,
  439. cache_position=cache_position,
  440. past_key_values=past_key_values,
  441. position_ids=position_ids,
  442. )
  443. hidden_states = inputs_embeds
  444. # create position embeddings to be shared across the decoder layers
  445. position_embeddings = self.rotary_emb(hidden_states, position_ids)
  446. for decoder_layer in self.layers:
  447. if decoder_layer.layer_type == "full_attention":
  448. input_attention_mask = causal_mask
  449. else:
  450. # lightning attention uses original attention_mask, and uses it only for the first step
  451. input_attention_mask = attention_mask
  452. hidden_states = decoder_layer(
  453. hidden_states,
  454. position_embeddings=position_embeddings,
  455. attention_mask=input_attention_mask,
  456. position_ids=position_ids,
  457. past_key_values=past_key_values,
  458. use_cache=use_cache,
  459. cache_position=cache_position,
  460. **kwargs,
  461. )
  462. hidden_states = self.norm(hidden_states)
  463. return MoeModelOutputWithPast(
  464. last_hidden_state=hidden_states,
  465. past_key_values=past_key_values,
  466. )
# Causal language-modeling head for MiniMax; `forward` only re-documents the parent implementation.
class MiniMaxForCausalLM(MixtralForCausalLM):
    def forward(self, **super_kwargs):
        r"""
        labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
            Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
            config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
            (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.

        Example:

        ```python
        >>> from transformers import AutoTokenizer, MiniMaxForCausalLM

        >>> model = MiniMaxForCausalLM.from_pretrained("MiniMaxAI/MiniMax-Text-01-hf")
        >>> tokenizer = AutoTokenizer.from_pretrained("MiniMaxAI/MiniMax-Text-01-hf")

        >>> prompt = "Hey, are you conscious? Can you talk to me?"
        >>> inputs = tokenizer(prompt, return_tensors="pt")

        >>> # Generate
        >>> generate_ids = model.generate(inputs.input_ids, max_length=30)
        >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
        "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you."
        ```"""
        # All arguments are passed through unchanged to the parent `forward`.
        return super().forward(**super_kwargs)
# Sequence-classification head for MiniMax; adds nothing on top of `MixtralForSequenceClassification`.
class MiniMaxForSequenceClassification(MixtralForSequenceClassification):
    pass
# Token-classification head for MiniMax; adds nothing on top of `MixtralForTokenClassification`.
class MiniMaxForTokenClassification(MixtralForTokenClassification):
    pass
# Question-answering head for MiniMax; adds nothing on top of `MixtralForQuestionAnswering`.
class MiniMaxForQuestionAnswering(MixtralForQuestionAnswering):
    pass
# Names exported by `from <module> import *`; helper classes (cache, attention, decoder layer) are not listed.
__all__ = [
    "MiniMaxConfig",
    "MiniMaxPreTrainedModel",
    "MiniMaxModel",
    "MiniMaxForCausalLM",
    "MiniMaxForSequenceClassification",
    "MiniMaxForTokenClassification",
    "MiniMaxForQuestionAnswering",
]