modeling_phimoe.py 58 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156115711581159116011611162116311641165116611671168116911701171117211731174117511761177117811791180118111821183118411851186118711881189119011911192119311941195119611971198119912001201120212031204120512061207120812091210121112121213121412151216121712181219122012211222122312241225122612271228122912301231123212331234123512361237123812391240124112421243124412451246124712481249125012511252125312541255125612571258125912601261126212631264126512661267126812691270127112721273127412751276127
712781279128012811282128312841285128612871288128912901291129212931294129512961297129812991300130113021303130413051306130713081309131013111312131313141315131613171318131913201321132213231324132513261327132813291330133113321333133413351336133713381339134013411342134313441345134613471348134913501351135213531354
  1. # coding=utf-8
  2. # Copyright 2024 Microsoft and the HuggingFace Inc. team. All rights reserved.
  3. #
  4. # Licensed under the Apache License, Version 2.0 (the "License");
  5. # you may not use this file except in compliance with the License.
  6. # You may obtain a copy of the License at
  7. #
  8. # http://www.apache.org/licenses/LICENSE-2.0
  9. #
  10. # Unless required by applicable law or agreed to in writing, software
  11. # distributed under the License is distributed on an "AS IS" BASIS,
  12. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. # See the License for the specific language governing permissions and
  14. # limitations under the License.
  15. """PyTorch Phimoe model."""
  16. import math
  17. from typing import Optional, Union
  18. import torch
  19. from torch import nn
  20. from ...activations import ACT2FN
  21. from ...cache_utils import Cache, DynamicCache, StaticCache
  22. from ...generation import GenerationMixin
  23. from ...modeling_attn_mask_utils import AttentionMaskConverter, _prepare_4d_causal_attention_mask
  24. from ...modeling_flash_attention_utils import is_flash_attn_available
  25. from ...modeling_layers import (
  26. GenericForSequenceClassification,
  27. GradientCheckpointingLayer,
  28. )
  29. from ...modeling_outputs import MoeCausalLMOutputWithPast, MoeModelOutputWithPast
  30. from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS
  31. from ...modeling_utils import PreTrainedModel
  32. from ...utils import auto_docstring, can_return_tuple, is_torch_flex_attn_available, logging
  33. from ...utils.deprecation import deprecate_kwarg
  34. from .configuration_phimoe import PhimoeConfig
  35. if is_flash_attn_available():
  36. from ...modeling_flash_attention_utils import _flash_attention_forward
  37. if is_torch_flex_attn_available():
  38. from torch.nn.attention.flex_attention import BlockMask
  39. from ...integrations.flex_attention import make_flex_block_causal_mask
# Register `_prepare_4d_causal_attention_mask` with torch.fx as a leaf function: symbolic tracing does not step
# into its body, and each call appears as a single node in the traced graph.
_prepare_4d_causal_attention_mask = torch.fx.wrap(_prepare_4d_causal_attention_mask)
# Module-level logger; the attention classes below emit their one-time warnings through it.
logger = logging.get_logger(__name__)
  44. # Copied from transformers.models.qwen2_moe.modeling_qwen2_moe.load_balancing_loss_func
  45. def load_balancing_loss_func(
  46. gate_logits: Union[torch.Tensor, tuple[torch.Tensor], None],
  47. num_experts: Optional[int] = None,
  48. top_k=2,
  49. attention_mask: Optional[torch.Tensor] = None,
  50. ) -> Union[torch.Tensor, int]:
  51. r"""
  52. Computes auxiliary load balancing loss as in Switch Transformer - implemented in Pytorch.
  53. See Switch Transformer (https://huggingface.co/papers/2101.03961) for more details. This function implements the loss
  54. function presented in equations (4) - (6) of the paper. It aims at penalizing cases where the routing between
  55. experts is too unbalanced.
  56. Args:
  57. gate_logits:
  58. Logits from the `gate`, should be a tuple of model.config.num_hidden_layers tensors of
  59. shape [batch_size X sequence_length, num_experts].
  60. num_experts:
  61. Number of experts
  62. top_k:
  63. The number of experts to route per-token, can be also interpreted as the `top-k` routing
  64. parameter.
  65. attention_mask (`torch.Tensor`, *optional*):
  66. The attention_mask used in forward function
  67. shape [batch_size X sequence_length] if not None.
  68. Returns:
  69. The auxiliary loss.
  70. """
  71. if gate_logits is None or not isinstance(gate_logits, tuple):
  72. return 0
  73. if isinstance(gate_logits, tuple):
  74. compute_device = gate_logits[0].device
  75. concatenated_gate_logits = torch.cat([layer_gate.to(compute_device) for layer_gate in gate_logits], dim=0)
  76. routing_weights = torch.nn.functional.softmax(concatenated_gate_logits, dim=-1)
  77. _, selected_experts = torch.topk(routing_weights, top_k, dim=-1)
  78. expert_mask = torch.nn.functional.one_hot(selected_experts, num_experts)
  79. if attention_mask is None:
  80. # Compute the percentage of tokens routed to each experts
  81. tokens_per_expert = torch.mean(expert_mask.float(), dim=0)
  82. # Compute the average probability of routing to these experts
  83. router_prob_per_expert = torch.mean(routing_weights, dim=0)
  84. else:
  85. batch_size, sequence_length = attention_mask.shape
  86. num_hidden_layers = concatenated_gate_logits.shape[0] // (batch_size * sequence_length)
  87. # Compute the mask that masks all padding tokens as 0 with the same shape of expert_mask
  88. expert_attention_mask = (
  89. attention_mask[None, :, :, None, None]
  90. .expand((num_hidden_layers, batch_size, sequence_length, top_k, num_experts))
  91. .reshape(-1, top_k, num_experts)
  92. .to(compute_device)
  93. )
  94. # Compute the percentage of tokens routed to each experts
  95. tokens_per_expert = torch.sum(expert_mask.float() * expert_attention_mask, dim=0) / torch.sum(
  96. expert_attention_mask, dim=0
  97. )
  98. # Compute the mask that masks all padding tokens as 0 with the same shape of tokens_per_expert
  99. router_per_expert_attention_mask = (
  100. attention_mask[None, :, :, None]
  101. .expand((num_hidden_layers, batch_size, sequence_length, routing_weights.shape[1]))
  102. .reshape(-1, routing_weights.shape[1])
  103. .to(compute_device)
  104. )
  105. # Compute the average probability of routing to these experts
  106. router_prob_per_expert = torch.sum(routing_weights * router_per_expert_attention_mask, dim=0) / torch.sum(
  107. router_per_expert_attention_mask, dim=0
  108. )
  109. device_index = routing_weights.device.index if routing_weights.device.index is not None else 0
  110. rank = routing_weights.shape[1] * int(device_index)
  111. overall_loss = torch.sum(
  112. tokens_per_expert[:, rank : rank + routing_weights.shape[1]] * router_prob_per_expert.unsqueeze(0)
  113. )
  114. return overall_loss * num_experts
  115. class PhimoeRotaryEmbedding(nn.Module):
  116. def __init__(
  117. self,
  118. config: Optional[PhimoeConfig] = None,
  119. ):
  120. super().__init__()
  121. self.config = config
  122. if config.rope_scaling is not None:
  123. self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type"))
  124. self.short_mscale = config.rope_scaling.get("short_mscale")
  125. self.long_mscale = config.rope_scaling.get("long_mscale")
  126. else:
  127. self.rope_type = "default"
  128. self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]
  129. def forward(self, x, seq_len=None):
  130. mscale = None
  131. if self.config.rope_scaling and seq_len:
  132. mscale = (
  133. self.long_mscale
  134. if seq_len > self.config.rope_scaling["original_max_position_embeddings"]
  135. else self.short_mscale
  136. )
  137. inv_freq, attention_scaling = self.rope_init_fn(self.config, x.device, seq_len)
  138. mscale = attention_scaling if mscale is None else mscale
  139. t = torch.arange(seq_len, device=x.device, dtype=torch.float32)
  140. freqs = torch.outer(t, inv_freq)
  141. emb = torch.cat((freqs, freqs), dim=-1)
  142. return (emb.cos() * mscale).to(x.dtype), (emb.sin() * mscale).to(x.dtype)
  143. # Copied from transformers.models.llama.modeling_llama.rotate_half
  144. def rotate_half(x):
  145. """Rotates half the hidden dims of the input."""
  146. x1 = x[..., : x.shape[-1] // 2]
  147. x2 = x[..., x.shape[-1] // 2 :]
  148. return torch.cat((-x2, x1), dim=-1)
  149. def apply_rotary_pos_emb(q, k, cos, sin, position_ids, unsqueeze_dim=1):
  150. """Applies Rotary Position Embedding to the query and key tensors.
  151. Args:
  152. q (`torch.Tensor`): The query tensor.
  153. k (`torch.Tensor`): The key tensor.
  154. cos (`torch.Tensor`): The cosine part of the rotary embedding.
  155. sin (`torch.Tensor`): The sine part of the rotary embedding.
  156. position_ids (`torch.Tensor`):
  157. The position indices of the tokens corresponding to the query and key tensors. For example, this can be
  158. used to pass offsetted position ids when working with a KV-cache.
  159. unsqueeze_dim (`int`, *optional*, defaults to 1):
  160. The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and
  161. sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note
  162. that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and
  163. k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes
  164. cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have
  165. the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2.
  166. Returns:
  167. `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding.
  168. """
  169. cos = cos[position_ids].unsqueeze(unsqueeze_dim)
  170. sin = sin[position_ids].unsqueeze(unsqueeze_dim)
  171. q_embed = (q * cos) + (rotate_half(q) * sin)
  172. k_embed = (k * cos) + (rotate_half(k) * sin)
  173. return q_embed, k_embed
  174. # Copied from transformers.models.llama.modeling_llama.repeat_kv
  175. def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
  176. """
  177. This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch,
  178. num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim)
  179. """
  180. batch, num_key_value_heads, slen, head_dim = hidden_states.shape
  181. if n_rep == 1:
  182. return hidden_states
  183. hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim)
  184. return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
  185. class PhimoeAttention(nn.Module):
  186. """
  187. Multi-headed attention from 'Attention Is All You Need' paper. Modified to use sliding window attention: Longformer
  188. and "Generating Long Sequences with Sparse Transformers".
  189. """
  190. def __init__(self, config: PhimoeConfig, layer_idx: Optional[int] = None):
  191. super().__init__()
  192. self.config = config
  193. self.layer_idx = layer_idx
  194. if layer_idx is None:
  195. logger.warning_once(
  196. f"Instantiating {self.__class__.__name__} without passing a `layer_idx` is not recommended and will "
  197. "lead to errors during the forward call if caching is used. Please make sure to provide a `layer_idx` "
  198. "when creating this class."
  199. )
  200. self.hidden_size = config.hidden_size
  201. self.num_heads = config.num_attention_heads
  202. self.head_dim = self.hidden_size // self.num_heads
  203. self.num_key_value_heads = config.num_key_value_heads
  204. self.num_key_value_groups = self.num_heads // self.num_key_value_heads
  205. self.max_position_embeddings = config.max_position_embeddings
  206. self.rope_theta = config.rope_theta
  207. self.is_causal = True
  208. self.attention_dropout = config.attention_dropout
  209. if (self.head_dim * self.num_heads) != self.hidden_size:
  210. raise ValueError(
  211. f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}"
  212. f" and `num_heads`: {self.num_heads})."
  213. )
  214. self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=self.config.attention_bias)
  215. self.k_proj = nn.Linear(
  216. self.hidden_size, self.num_key_value_heads * self.head_dim, bias=self.config.attention_bias
  217. )
  218. self.v_proj = nn.Linear(
  219. self.hidden_size, self.num_key_value_heads * self.head_dim, bias=self.config.attention_bias
  220. )
  221. self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=self.config.attention_bias)
  222. def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
  223. return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous()
  224. @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58")
  225. def forward(
  226. self,
  227. hidden_states: torch.Tensor,
  228. attention_mask: Optional[torch.Tensor] = None,
  229. position_ids: Optional[torch.LongTensor] = None,
  230. past_key_values: Optional[Cache] = None,
  231. output_attentions: bool = False,
  232. use_cache: bool = False,
  233. cache_position: Optional[torch.LongTensor] = None,
  234. position_embeddings: Optional[tuple[torch.Tensor, torch.Tensor]] = None,
  235. ) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[tuple[torch.Tensor]]]:
  236. bsz, q_len, _ = hidden_states.size()
  237. query_states = self.q_proj(hidden_states)
  238. key_states = self.k_proj(hidden_states)
  239. value_states = self.v_proj(hidden_states)
  240. query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
  241. key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
  242. value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
  243. cos, sin = position_embeddings
  244. query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)
  245. if past_key_values is not None:
  246. cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} # Specific to RoPE models
  247. key_states, value_states = past_key_values.update(key_states, value_states, self.layer_idx, cache_kwargs)
  248. # repeat k/v heads if n_kv_heads < n_heads
  249. key_states = repeat_kv(key_states, self.num_key_value_groups)
  250. value_states = repeat_kv(value_states, self.num_key_value_groups)
  251. attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim)
  252. if attention_mask is not None: # no matter the length, we just slice it
  253. causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]
  254. attn_weights = attn_weights + causal_mask
  255. # upcast attention to fp32
  256. attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
  257. attn_weights = nn.functional.dropout(attn_weights, p=self.attention_dropout, training=self.training)
  258. attn_output = torch.matmul(attn_weights, value_states)
  259. if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim):
  260. raise ValueError(
  261. f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is"
  262. f" {attn_output.size()}"
  263. )
  264. attn_output = attn_output.transpose(1, 2).contiguous()
  265. attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
  266. attn_output = self.o_proj(attn_output)
  267. if not output_attentions:
  268. attn_weights = None
  269. return attn_output, attn_weights
  270. class PhimoeFlashAttention2(PhimoeAttention):
  271. """
  272. Phimoe flash attention module. This module inherits from `PhimoeAttention` as the weights of the module stays
  273. untouched. The only required change would be on the forward pass where it needs to correctly call the public API of
  274. flash attention and deal with padding tokens in case the input contains any of them.
  275. """
  276. @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58")
  277. def forward(
  278. self,
  279. hidden_states: torch.Tensor,
  280. attention_mask: Optional[torch.Tensor] = None,
  281. position_ids: Optional[torch.LongTensor] = None,
  282. past_key_values: Optional[Cache] = None,
  283. output_attentions: bool = False,
  284. use_cache: bool = False,
  285. cache_position: Optional[torch.LongTensor] = None,
  286. position_embeddings: Optional[tuple[torch.Tensor, torch.Tensor]] = None,
  287. ):
  288. bsz, q_len, _ = hidden_states.size()
  289. query_states = self.q_proj(hidden_states)
  290. key_states = self.k_proj(hidden_states)
  291. value_states = self.v_proj(hidden_states)
  292. query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
  293. key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
  294. value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
  295. cos, sin = position_embeddings
  296. query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)
  297. if past_key_values is not None:
  298. cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} # Specific to RoPE models
  299. key_states, value_states = past_key_values.update(key_states, value_states, self.layer_idx, cache_kwargs)
  300. # repeat k/v heads if n_kv_heads < n_heads
  301. key_states = repeat_kv(key_states, self.num_key_value_groups)
  302. value_states = repeat_kv(value_states, self.num_key_value_groups)
  303. dropout_rate = 0.0 if not self.training else self.attention_dropout
  304. # In PEFT, usually we cast the layer norms in float32 for training stability reasons
  305. # therefore the input hidden states gets silently casted in float32. Hence, we need
  306. # cast them back in float16 just to be sure everything works as expected.
  307. input_dtype = query_states.dtype
  308. device_type = query_states.device.type if query_states.device.type != "mps" else "cpu"
  309. if input_dtype == torch.float32:
  310. if torch.is_autocast_enabled():
  311. target_dtype = (
  312. torch.get_autocast_dtype(device_type)
  313. if hasattr(torch, "get_autocast_dtype")
  314. else torch.get_autocast_gpu_dtype()
  315. )
  316. # Handle the case where the model is quantized
  317. elif hasattr(self.config, "_pre_quantization_dtype"):
  318. target_dtype = self.config._pre_quantization_dtype
  319. else:
  320. target_dtype = self.q_proj.weight.dtype
  321. logger.warning_once(
  322. f"The input hidden states seems to be silently casted in float32, this might be related to"
  323. f" the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in"
  324. f" {target_dtype}."
  325. )
  326. query_states = query_states.to(target_dtype)
  327. key_states = key_states.to(target_dtype)
  328. value_states = value_states.to(target_dtype)
  329. # Reashape to the expected shape for Flash Attention
  330. query_states = query_states.transpose(1, 2)
  331. key_states = key_states.transpose(1, 2)
  332. value_states = value_states.transpose(1, 2)
  333. attn_output = _flash_attention_forward(
  334. query_states,
  335. key_states,
  336. value_states,
  337. attention_mask,
  338. q_len,
  339. position_ids=position_ids,
  340. dropout=dropout_rate,
  341. sliding_window=getattr(self.config, "sliding_window", None),
  342. is_causal=self.is_causal,
  343. )
  344. attn_output = attn_output.reshape(bsz, q_len, self.hidden_size).contiguous()
  345. attn_output = self.o_proj(attn_output)
  346. if not output_attentions:
  347. attn_weights = None
  348. return attn_output, attn_weights
  349. class PhimoeSdpaAttention(PhimoeAttention):
  350. """
  351. Phimoe attention module using torch.nn.functional.scaled_dot_product_attention. This module inherits from
  352. `PhimoeAttention` as the weights of the module stays untouched. The only changes are on the forward pass to adapt to
  353. SDPA API.
  354. """
  355. # Adapted from PhimoeAttention.forward
  356. @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58")
  357. def forward(
  358. self,
  359. hidden_states: torch.Tensor,
  360. attention_mask: Optional[torch.Tensor] = None,
  361. position_ids: Optional[torch.LongTensor] = None,
  362. past_key_values: Optional[Cache] = None,
  363. output_attentions: bool = False,
  364. use_cache: bool = False,
  365. cache_position: Optional[torch.LongTensor] = None,
  366. position_embeddings: Optional[tuple[torch.Tensor, torch.Tensor]] = None,
  367. ) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[tuple[torch.Tensor]]]:
  368. if output_attentions:
  369. # TODO: Improve this warning with e.g. `model.config.attn_implementation = "manual"` once this is implemented.
  370. logger.warning_once(
  371. "PhimoeModel is using PhimoeSdpaAttention, but `torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to the manual attention implementation, "
  372. 'but specifying the manual implementation will be required from Transformers version v5.0.0 onwards. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.'
  373. )
  374. return super().forward(
  375. hidden_states=hidden_states,
  376. attention_mask=attention_mask,
  377. position_ids=position_ids,
  378. past_key_values=past_key_values,
  379. output_attentions=output_attentions,
  380. use_cache=use_cache,
  381. position_embeddings=position_embeddings,
  382. )
  383. bsz, q_len, _ = hidden_states.size()
  384. query_states = self.q_proj(hidden_states)
  385. key_states = self.k_proj(hidden_states)
  386. value_states = self.v_proj(hidden_states)
  387. query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
  388. key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
  389. value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
  390. cos, sin = position_embeddings
  391. query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)
  392. if past_key_values is not None:
  393. cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} # Specific to RoPE models
  394. key_states, value_states = past_key_values.update(key_states, value_states, self.layer_idx, cache_kwargs)
  395. key_states = repeat_kv(key_states, self.num_key_value_groups)
  396. value_states = repeat_kv(value_states, self.num_key_value_groups)
  397. causal_mask = attention_mask
  398. if attention_mask is not None: # no matter the length, we just slice it
  399. causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]
  400. # SDPA with memory-efficient backend is currently (torch==2.1.2) bugged with non-contiguous inputs with custom attn_mask,
  401. # Reference: https://github.com/pytorch/pytorch/issues/112577.
  402. if query_states.device.type == "cuda" and attention_mask is not None:
  403. query_states = query_states.contiguous()
  404. key_states = key_states.contiguous()
  405. value_states = value_states.contiguous()
  406. # We dispatch to SDPA's Flash Attention or Efficient kernels via this `is_causal` if statement instead of an inline conditional assignment
  407. # in SDPA to support both torch.compile's dynamic shapes and full graph options. An inline conditional prevents dynamic shapes from compiling.
  408. # The q_len > 1 is necessary to match with AttentionMaskConverter.to_causal_4d that does not create a causal mask in case q_len == 1.
  409. is_causal = causal_mask is None and q_len > 1
  410. attn_output = torch.nn.functional.scaled_dot_product_attention(
  411. query_states,
  412. key_states,
  413. value_states,
  414. attn_mask=causal_mask,
  415. dropout_p=self.attention_dropout if self.training else 0.0,
  416. is_causal=is_causal,
  417. )
  418. attn_output = attn_output.transpose(1, 2).contiguous()
  419. attn_output = attn_output.view(bsz, q_len, self.hidden_size)
  420. attn_output = self.o_proj(attn_output)
  421. return attn_output, None
  422. PHIMOE_ATTENTION_CLASSES = {
  423. "eager": PhimoeAttention,
  424. "flash_attention_2": PhimoeFlashAttention2,
  425. "sdpa": PhimoeSdpaAttention,
  426. }
  427. # Copied from transformers.models.mixtral.modeling_mixtral.MixtralBlockSparseTop2MLP with Mixtral->Phimoe
  428. class PhimoeBlockSparseTop2MLP(nn.Module):
  429. def __init__(self, config: PhimoeConfig):
  430. super().__init__()
  431. self.ffn_dim = config.intermediate_size
  432. self.hidden_dim = config.hidden_size
  433. self.w1 = nn.Linear(self.hidden_dim, self.ffn_dim, bias=False)
  434. self.w2 = nn.Linear(self.ffn_dim, self.hidden_dim, bias=False)
  435. self.w3 = nn.Linear(self.hidden_dim, self.ffn_dim, bias=False)
  436. self.act_fn = ACT2FN[config.hidden_act]
  437. def forward(self, hidden_states):
  438. current_hidden_states = self.act_fn(self.w1(hidden_states)) * self.w3(hidden_states)
  439. current_hidden_states = self.w2(current_hidden_states)
  440. return current_hidden_states
  441. class MultiplierProcessor(torch.autograd.Function):
  442. @staticmethod
  443. def forward(
  444. ctx,
  445. scores: torch.Tensor,
  446. multiplier: torch.Tensor,
  447. selected_experts: torch.Tensor,
  448. masked_gates: torch.Tensor,
  449. mask_for_one: torch.Tensor,
  450. ):
  451. """
  452. Forward pass for the custom autograd function.
  453. Args:
  454. ctx: Context object to save information for backward computation.
  455. scores (torch.Tensor): Input scores tensor.
  456. multiplier (torch.Tensor): Multiplier tensor.
  457. selected_experts (torch.Tensor): Tensor of selected experts.
  458. masked_gates (torch.Tensor): Masked gates tensor.
  459. mask_for_one (torch.Tensor): Mask for one tensor.
  460. Returns:
  461. torch.Tensor: Result of the forward pass.
  462. """
  463. ctx.save_for_backward(multiplier, selected_experts, masked_gates)
  464. return multiplier * mask_for_one
  465. @staticmethod
  466. def backward(
  467. ctx,
  468. grad_at_output: torch.Tensor,
  469. ):
  470. """
  471. Backward pass for the custom autograd function.
  472. Args:
  473. ctx: Context object with saved tensors from the forward pass.
  474. grad_at_output (torch.Tensor): Gradient at the output.
  475. Returns:
  476. tuple[torch.Tensor, None, None, None, None]: Gradients for the inputs.
  477. """
  478. multiplier, selected_experts, masked_gates = ctx.saved_tensors
  479. grad_at_output = grad_at_output * multiplier
  480. grad_at_scores_expanded = masked_gates * grad_at_output.mul(-1)
  481. grad_at_scores_expanded.scatter_add_(
  482. dim=-1,
  483. index=selected_experts,
  484. src=grad_at_output,
  485. )
  486. return (
  487. grad_at_scores_expanded,
  488. None,
  489. None,
  490. None,
  491. None,
  492. )
  493. def sparsemixer(scores, jitter_eps, training, top_k=2):
  494. """
  495. Sparse mixer function to select top-k experts and compute multipliers.
  496. Based on the paper: https://huggingface.co/papers/2409.12136
  497. We first replace the TopK(·) function as random sampling of discrete variables
  498. in model training. Then, following Liu et al. (2023a) and Liu et al. (2023b), we apply Heun's
  499. third order method to approximate the expert routing gradient and construct a modified
  500. back-propagation to give a mathematically sound gradient estimation for expert routing.
  501. Args:
  502. scores (torch.Tensor): Input scores tensor.
  503. jitter_eps (float): Jitter epsilon for numerical stability.
  504. training (bool): Flag indicating if the model is in training mode.
  505. top_k (int): Number of top experts to select.
  506. Returns:
  507. tuple[torch.Tensor, torch.Tensor]: Multiplier and selected experts tensors.
  508. """
  509. if top_k != 2:
  510. raise ValueError("top_k must be equal to 2")
  511. # first expert
  512. with torch.no_grad():
  513. # Compute mask for sparsity
  514. mask_logits_threshold, max_ind = scores.max(dim=-1, keepdim=True)
  515. factor = scores.abs().clamp(min=mask_logits_threshold)
  516. mask_logits_threshold = ((mask_logits_threshold - scores) / factor) > (2 * jitter_eps)
  517. # Apply mask
  518. masked_gates = scores.masked_fill(mask_logits_threshold, float("-inf"))
  519. if training:
  520. selected_experts = (
  521. (
  522. masked_gates
  523. - torch.empty_like(masked_gates, memory_format=torch.legacy_contiguous_format).exponential_().log()
  524. )
  525. .max(dim=-1)[1]
  526. .unsqueeze(-1)
  527. ) # Gumbel sampling, more robust than the multinomial method
  528. else:
  529. selected_experts = max_ind
  530. # Compute scores for gradients
  531. masked_gates = torch.softmax(masked_gates, dim=-1)
  532. multiplier_o = masked_gates.gather(dim=-1, index=selected_experts)
  533. if training:
  534. # Compute midpoint mask
  535. max_scores, max_ind = masked_gates.max(dim=-1, keepdim=True)
  536. mask_for_one = torch.logical_or(
  537. selected_experts == max_ind,
  538. torch.rand_like(max_scores) > 0.75, # Heun's third-order method
  539. )
  540. # 1 -> 1.0 & 0 -> 1./3: lambda x: (x + 0.5) / 1.5
  541. mask_for_one = torch.add(0.3333, mask_for_one, alpha=0.6667).type_as(masked_gates)
  542. multiplier = MultiplierProcessor.apply(
  543. scores,
  544. multiplier_o,
  545. selected_experts,
  546. masked_gates,
  547. mask_for_one,
  548. )
  549. else:
  550. multiplier = multiplier_o
  551. # Masked out first expert
  552. masked_scores = torch.scatter(
  553. scores,
  554. -1,
  555. selected_experts,
  556. float("-inf"),
  557. )
  558. with torch.no_grad():
  559. # Compute mask for sparsity
  560. mask_logits_threshold, max_ind = masked_scores.max(dim=-1, keepdim=True)
  561. factor = scores.abs().clamp(min=mask_logits_threshold)
  562. mask_logits_threshold = ((mask_logits_threshold - scores) / factor) > (2 * jitter_eps)
  563. # Apply mask
  564. masked_gates_top2 = masked_scores.masked_fill(mask_logits_threshold, float("-inf"))
  565. if training:
  566. selected_experts_top2 = (
  567. (
  568. masked_gates_top2
  569. - torch.empty_like(masked_gates_top2, memory_format=torch.legacy_contiguous_format)
  570. .exponential_()
  571. .log()
  572. )
  573. .max(dim=-1)[1]
  574. .unsqueeze(-1)
  575. ) # Gumbel sampling, more robust than the multinomial method
  576. else:
  577. selected_experts_top2 = max_ind
  578. # Compute scores for gradients
  579. masked_gates_top2 = torch.softmax(masked_gates_top2, dim=-1)
  580. multiplier_top2_o = masked_gates_top2.gather(dim=-1, index=selected_experts_top2)
  581. if training:
  582. # Compute midpoint mask
  583. max_scores, max_ind = masked_gates_top2.max(dim=-1, keepdim=True)
  584. mask_for_one_top2 = torch.logical_or(
  585. selected_experts_top2 == max_ind,
  586. torch.rand_like(max_scores).uniform_() > 0.75, # Heun's third-order method
  587. )
  588. # 1 -> 1.0 & 0 -> 1./3: lambda x: (x + 0.5) / 1.5
  589. mask_for_one_top2 = torch.add(0.3333, mask_for_one_top2, alpha=0.6667).type_as(masked_gates_top2)
  590. multiplier_top2 = MultiplierProcessor.apply(
  591. scores,
  592. multiplier_top2_o,
  593. selected_experts_top2,
  594. masked_gates_top2,
  595. mask_for_one_top2,
  596. )
  597. else:
  598. multiplier_top2 = multiplier_top2_o
  599. multiplier = torch.concat((multiplier, multiplier_top2), dim=-1)
  600. selected_experts = torch.concat((selected_experts, selected_experts_top2), dim=-1)
  601. return (
  602. multiplier,
  603. selected_experts,
  604. )
  605. class PhimoeSparseMoeBlock(nn.Module):
  606. """
  607. This implementation is
  608. strictly equivalent to standard MoE with full capacity (no
  609. dropped tokens). It's faster since it formulates MoE operations
  610. in terms of block-sparse operations to accommodate imbalanced
  611. assignments of tokens to experts, whereas standard MoE either
  612. (1) drop tokens at the cost of reduced performance or (2) set
  613. capacity factor to number of experts and thus waste computation
  614. and memory on padding.
  615. """
  616. def __init__(self, config):
  617. super().__init__()
  618. self.hidden_dim = config.hidden_size
  619. self.ffn_dim = config.intermediate_size
  620. self.num_experts = config.num_local_experts
  621. self.top_k = config.num_experts_per_tok
  622. # gating
  623. self.gate = nn.Linear(self.hidden_dim, self.num_experts, bias=False)
  624. self.experts = nn.ModuleList([PhimoeBlockSparseTop2MLP(config) for _ in range(self.num_experts)])
  625. # Jitter parameters
  626. self.router_jitter_noise = config.router_jitter_noise
  627. self.input_jitter_noise = config.input_jitter_noise
  628. def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
  629. """ """
  630. batch_size, sequence_length, hidden_dim = hidden_states.shape
  631. if self.training and self.input_jitter_noise > 0:
  632. hidden_states *= torch.empty_like(hidden_states).uniform_(
  633. 1.0 - self.input_jitter_noise, 1.0 + self.input_jitter_noise
  634. )
  635. hidden_states = hidden_states.view(-1, hidden_dim)
  636. router_logits = self.gate(hidden_states)
  637. routing_weights, selected_experts = sparsemixer(
  638. router_logits,
  639. jitter_eps=self.router_jitter_noise,
  640. training=self.training,
  641. )
  642. final_hidden_states = torch.zeros(
  643. (batch_size * sequence_length, hidden_dim), dtype=hidden_states.dtype, device=hidden_states.device
  644. )
  645. # One hot encode the selected experts to create an expert mask
  646. # this will be used to easily index which expert is going to be sollicitated
  647. expert_mask = torch.nn.functional.one_hot(selected_experts, num_classes=self.num_experts).permute(2, 1, 0)
  648. # Loop over all available experts in the model and perform the computation on each expert
  649. for expert_idx in range(self.num_experts):
  650. expert_layer = self.experts[expert_idx]
  651. idx, top_x = torch.where(expert_mask[expert_idx])
  652. if top_x.shape[0] == 0:
  653. continue
  654. # Index the correct hidden states and compute the expert hidden state for
  655. # the current expert. We need to make sure to multiply the output hidden
  656. # states by `routing_weights` on the corresponding tokens (top-1 and top-2)
  657. current_state = hidden_states[None, top_x].reshape(-1, hidden_dim)
  658. current_hidden_states = expert_layer(current_state) * routing_weights[top_x, idx, None]
  659. # However `index_add_` only support torch tensors for indexing so we'll use
  660. # the `top_x` tensor here.
  661. final_hidden_states.index_add_(0, top_x, current_hidden_states.to(hidden_states.dtype))
  662. final_hidden_states = final_hidden_states.reshape(batch_size, sequence_length, hidden_dim)
  663. return final_hidden_states, router_logits
  664. class PhimoeDecoderLayer(GradientCheckpointingLayer):
  665. def __init__(self, config: PhimoeConfig, layer_idx: int):
  666. super().__init__()
  667. self.hidden_size = config.hidden_size
  668. self.self_attn = PHIMOE_ATTENTION_CLASSES[config._attn_implementation](config, layer_idx)
  669. self.block_sparse_moe = PhimoeSparseMoeBlock(config)
  670. self.input_layernorm = nn.LayerNorm(config.hidden_size, eps=config.rms_norm_eps, elementwise_affine=True)
  671. self.post_attention_layernorm = nn.LayerNorm(
  672. config.hidden_size, eps=config.rms_norm_eps, elementwise_affine=True
  673. )
  674. @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58")
  675. def forward(
  676. self,
  677. hidden_states: torch.Tensor,
  678. attention_mask: Optional[torch.Tensor] = None,
  679. position_ids: Optional[torch.LongTensor] = None,
  680. past_key_values: Optional[Cache] = None,
  681. output_attentions: Optional[bool] = False,
  682. output_router_logits: Optional[bool] = False,
  683. use_cache: Optional[bool] = False,
  684. cache_position: Optional[torch.LongTensor] = None,
  685. position_embeddings: Optional[tuple[torch.Tensor, torch.Tensor]] = None,
  686. **kwargs,
  687. ) -> tuple[torch.FloatTensor, Optional[tuple[torch.FloatTensor, torch.FloatTensor]]]:
  688. """
  689. Args:
  690. hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
  691. attention_mask (`torch.FloatTensor`, *optional*): attention mask of size
  692. `(batch, sequence_length)` where padding elements are indicated by 0.
  693. past_key_values (`Cache`, *optional*): cached past key and value projection states
  694. output_attentions (`bool`, *optional*):
  695. Whether or not to return the attentions tensors of all attention layers. See `attentions` under
  696. returned tensors for more detail.
  697. output_router_logits (`bool`, *optional*):
  698. Whether or not to return the logits of all the routers. They are useful for computing the router loss, and
  699. should not be returned during inference.
  700. use_cache (`bool`, *optional*):
  701. If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding
  702. (see `past_key_values`).
  703. cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*):
  704. Indices depicting the position of the input sequence tokens in the sequence.
  705. kwargs (`dict`, *optional*):
  706. Arbitrary kwargs to be ignored, used for FSDP and other methods that injects code
  707. into the model
  708. """
  709. residual = hidden_states
  710. hidden_states = self.input_layernorm(hidden_states)
  711. # Self Attention
  712. hidden_states, self_attn_weights = self.self_attn(
  713. hidden_states=hidden_states,
  714. attention_mask=attention_mask,
  715. position_ids=position_ids,
  716. past_key_values=past_key_values,
  717. output_attentions=output_attentions,
  718. use_cache=use_cache,
  719. cache_position=cache_position,
  720. position_embeddings=position_embeddings,
  721. )
  722. hidden_states = residual + hidden_states
  723. # Fully Connected
  724. residual = hidden_states
  725. hidden_states = self.post_attention_layernorm(hidden_states)
  726. hidden_states, router_logits = self.block_sparse_moe(hidden_states)
  727. hidden_states = residual + hidden_states
  728. outputs = (hidden_states,)
  729. if output_attentions:
  730. outputs += (self_attn_weights,)
  731. if output_router_logits:
  732. outputs += (router_logits,)
  733. return outputs
  734. @auto_docstring
  735. class PhimoePreTrainedModel(PreTrainedModel):
  736. config: PhimoeConfig
  737. base_model_prefix = "model"
  738. supports_gradient_checkpointing = True
  739. _no_split_modules = ["PhimoeDecoderLayer"]
  740. _skip_keys_device_placement = ["past_key_values"]
  741. _supports_flash_attn = True
  742. _supports_sdpa = True
  743. _can_compile_fullgraph = False # MoE models don't work with torch.compile (`torch.where(condition)` not supported)
  744. def _init_weights(self, module):
  745. std = self.config.initializer_range
  746. if isinstance(module, nn.Linear):
  747. module.weight.data.normal_(mean=0.0, std=std)
  748. if module.bias is not None:
  749. module.bias.data.zero_()
  750. elif isinstance(module, nn.Embedding):
  751. module.weight.data.normal_(mean=0.0, std=std)
  752. if module.padding_idx is not None:
  753. module.weight.data[module.padding_idx].zero_()
  754. elif isinstance(module, nn.LayerNorm):
  755. module.bias.data.zero_()
  756. module.weight.data.fill_(1.0)
@auto_docstring
class PhimoeModel(PhimoePreTrainedModel):
    """
    Transformer decoder consisting of *config.num_hidden_layers* layers. Each layer is a [`PhimoeDecoderLayer`]

    Args:
        config: PhimoeConfig
    """

    def __init__(self, config: PhimoeConfig):
        super().__init__(config)
        self.padding_idx = config.pad_token_id
        self.vocab_size = config.vocab_size

        self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
        self.layers = nn.ModuleList(
            [PhimoeDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
        )
        self._attn_implementation = config._attn_implementation
        # The final norm is an affine LayerNorm; the config field used for its epsilon is named `rms_norm_eps`.
        self.norm = nn.LayerNorm(config.hidden_size, eps=config.rms_norm_eps, elementwise_affine=True)
        self.rotary_emb = PhimoeRotaryEmbedding(config=config)

        self.gradient_checkpointing = False
        # Initialize weights and apply final processing
        self.post_init()

    @can_return_tuple
    @auto_docstring
    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[Cache] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        output_router_logits: Optional[bool] = None,
        cache_position: Optional[torch.LongTensor] = None,
    ) -> MoeModelOutputWithPast:
        # Every flag the caller left as None falls back to the corresponding config value.
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_router_logits = (
            output_router_logits if output_router_logits is not None else self.config.output_router_logits
        )
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        use_cache = use_cache if use_cache is not None else self.config.use_cache

        # Exactly one of `input_ids` / `inputs_embeds` must be provided: this XOR is True when both or neither are.
        if (input_ids is None) ^ (inputs_embeds is not None):
            raise ValueError(
                "You cannot specify both input_ids and inputs_embeds at the same time, and must specify either one"
            )

        # Caching is disabled while training with gradient checkpointing.
        if self.gradient_checkpointing and self.training:
            if use_cache:
                logger.warning_once(
                    "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
                )
                use_cache = False

        # Create a fresh cache when caching is requested but none was passed in.
        if use_cache and past_key_values is None:
            past_key_values = DynamicCache(config=self.config)

        if inputs_embeds is None:
            inputs_embeds = self.embed_tokens(input_ids)

        # Default positions continue right after the tokens already stored in the cache (0 without a cache).
        if cache_position is None:
            past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
            cache_position = torch.arange(
                past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device
            )
        if position_ids is None:
            position_ids = cache_position.unsqueeze(0)

        causal_mask = self._update_causal_mask(
            attention_mask, inputs_embeds, cache_position, past_key_values, output_attentions
        )

        hidden_states = inputs_embeds

        # Rotary embeddings are computed once and shared by all layers. `seq_len` is the last cache position + 1.
        # NOTE(review): presumably the rotary module uses `seq_len` to choose its RoPE scaling; confirm in
        # PhimoeRotaryEmbedding.
        position_embeddings = self.rotary_emb(hidden_states, seq_len=cache_position[-1] + 1)

        # decoder layers
        all_hidden_states = () if output_hidden_states else None
        all_self_attns = () if output_attentions else None
        all_router_logits = () if output_router_logits else None

        for decoder_layer in self.layers:
            # Record the input of each layer (the embeddings for the first one).
            if output_hidden_states:
                all_hidden_states += (hidden_states,)

            layer_outputs = decoder_layer(
                hidden_states,
                attention_mask=causal_mask,
                position_ids=position_ids,
                past_key_values=past_key_values,
                output_attentions=output_attentions,
                output_router_logits=output_router_logits,
                use_cache=use_cache,
                cache_position=cache_position,
                position_embeddings=position_embeddings,
            )

            hidden_states = layer_outputs[0]

            if output_attentions:
                all_self_attns += (layer_outputs[1],)

            # The decoder layer appends router logits as the last element of its output tuple.
            if output_router_logits:
                all_router_logits += (layer_outputs[-1],)

        hidden_states = self.norm(hidden_states)

        # add hidden states from the last decoder layer
        if output_hidden_states:
            all_hidden_states += (hidden_states,)

        return MoeModelOutputWithPast(
            last_hidden_state=hidden_states,
            past_key_values=past_key_values,
            hidden_states=all_hidden_states,
            attentions=all_self_attns,
            router_logits=all_router_logits,
        )

    def _update_causal_mask(
        self,
        attention_mask: Union[torch.Tensor, "BlockMask"],
        input_tensor: torch.Tensor,
        cache_position: torch.Tensor,
        past_key_values: Cache,
        output_attentions: bool = False,
    ):
        """
        Build the attention mask handed to every decoder layer, depending on the attention implementation.

        - `flash_attention_2`: raises on right padding when a cache and mask are given. Returns the 2D mask
          only if it contains a 0.0 entry (padding), otherwise None.
        - `flex_attention`: converts a tensor mask with `make_flex_block_causal_mask`.
        - `sdpa`: returns None when `AttentionMaskConverter._ignore_causal_mask_sdpa` allows relying on
          SDPA's `is_causal` (only without a static cache and without `output_attentions`).
        - otherwise: a 4D additive mask from `_prepare_4d_causal_attention_mask_with_cache_position`.
        """
        if self.config._attn_implementation == "flash_attention_2":
            if attention_mask is not None and past_key_values is not None:
                # The last mask column sums to the batch size only if no row ends in padding.
                # NOTE(review): assumes a 0/1 mask.
                is_padding_right = attention_mask[:, -1].sum().item() != input_tensor.size()[0]
                if is_padding_right:
                    raise ValueError(
                        "You are attempting to perform batched generation with padding_side='right'"
                        " this may lead to unexpected behaviour for Flash Attention version of Phimoe. Make sure to "
                        " call `tokenizer.padding_side = 'left'` before tokenizing the input. "
                    )
            if attention_mask is not None and 0.0 in attention_mask:
                return attention_mask
            return None
        if self.config._attn_implementation == "flex_attention":
            if isinstance(attention_mask, torch.Tensor):
                attention_mask = make_flex_block_causal_mask(attention_mask)
            return attention_mask

        # For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in
        # order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail
        # to infer the attention mask.
        past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
        using_static_cache = isinstance(past_key_values, StaticCache)

        # When output attentions is True, sdpa implementation's forward method calls the eager implementation's forward
        if self.config._attn_implementation == "sdpa" and not using_static_cache and not output_attentions:
            if AttentionMaskConverter._ignore_causal_mask_sdpa(
                attention_mask,
                inputs_embeds=input_tensor,
                past_key_values_length=past_seen_tokens,
                sliding_window=self.config.sliding_window,
                is_training=self.training,
            ):
                return None

        dtype = input_tensor.dtype
        min_dtype = torch.finfo(dtype).min
        sequence_length = input_tensor.shape[1]
        # StaticCache
        if using_static_cache:
            # The key dimension spans the whole pre-allocated static cache.
            target_length = past_key_values.get_max_cache_shape()
        # DynamicCache or no cache
        else:
            # Use the mask's key length when a tensor mask is given. Otherwise use past + current tokens,
            # plus one extra column (NOTE(review): the reason for the `+ 1` is not evident here).
            target_length = (
                attention_mask.shape[-1]
                if isinstance(attention_mask, torch.Tensor)
                else past_seen_tokens + sequence_length + 1
            )

        # In case the provided `attention` mask is 2D, we generate a causal mask here (4D).
        causal_mask = self._prepare_4d_causal_attention_mask_with_cache_position(
            attention_mask,
            sequence_length=sequence_length,
            target_length=target_length,
            dtype=dtype,
            cache_position=cache_position,
            batch_size=input_tensor.shape[0],
            config=self.config,
            past_key_values=past_key_values,
        )

        if (
            self.config._attn_implementation == "sdpa"
            and attention_mask is not None
            and attention_mask.device.type in ["cuda", "xpu", "npu"]
            and not output_attentions
        ):
            # Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when
            # using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path.
            # Details: https://github.com/pytorch/pytorch/issues/110213
            causal_mask = AttentionMaskConverter._unmask_unattended(causal_mask, min_dtype)

        return causal_mask

    @staticmethod
    def _prepare_4d_causal_attention_mask_with_cache_position(
        attention_mask: torch.Tensor,
        sequence_length: int,
        target_length: int,
        dtype: torch.dtype,
        cache_position: torch.Tensor,
        batch_size: int,
        config: PhimoeConfig,
        past_key_values: Cache,
    ):
        """
        Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape
        `(batch_size, key_value_length)`, or if the input `attention_mask` is already 4D, do nothing.

        The result is additive: 0 where attention is allowed, `torch.finfo(dtype).min` where it is blocked.

        Args:
            attention_mask (`torch.Tensor`):
                A 2D attention mask of shape `(batch_size, key_value_length)` or a 4D attention mask of shape `(batch_size, 1, query_length, key_value_length)`.
            sequence_length (`int`):
                The sequence length being processed.
            target_length (`int`):
                The target length: when generating with static cache, the mask should be as long as the static cache, to account for the 0 padding, the part of the cache that is not filled yet.
            dtype (`torch.dtype`):
                The dtype to use for the 4D attention mask.
            cache_position (`torch.Tensor`):
                Indices depicting the position of the input sequence tokens in the sequence.
            batch_size (`int`):
                Batch size.
            config (`PhimoeConfig`):
                The model's configuration class
            past_key_values (`Cache`):
                The cache class that is being used currently to generate
        """
        if attention_mask is not None and attention_mask.dim() == 4:
            # In this case we assume that the mask comes already in inverted form and requires no inversion or slicing.
            causal_mask = attention_mask
        else:
            min_dtype = torch.finfo(dtype).min
            causal_mask = torch.full(
                (sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=cache_position.device
            )
            # True for key positions strictly after each query's cache position (the future); these stay blocked.
            diagonal_attend_mask = torch.arange(target_length, device=cache_position.device) > cache_position.reshape(
                -1, 1
            )
            text_config = config.get_text_config()
            if getattr(text_config, "use_sliding_window", True) and text_config.sliding_window is not None:
                # if we have sliding window, we should not attend to tokens beyond sliding window length, so we mask them out also
                # the check is needed to verify is current checkpoint was trained with sliding window or not
                is_static_sliding_cache = isinstance(past_key_values, StaticCache) and all(past_key_values.is_sliding)
                if not is_static_sliding_cache or sequence_length > target_length:
                    # Also block keys at least `sliding_window` positions behind the query.
                    sliding_attend_mask = torch.arange(target_length, device=cache_position.device) <= (
                        cache_position.reshape(-1, 1) - text_config.sliding_window
                    )
                    diagonal_attend_mask.bitwise_or_(sliding_attend_mask)
            # Multiplying by the boolean mask zeroes (unblocks) every allowed position and keeps min_dtype elsewhere.
            causal_mask *= diagonal_attend_mask
            causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1)
            if attention_mask is not None:
                causal_mask = causal_mask.clone()  # copy to contiguous memory for in-place edit
                if attention_mask.shape[-1] > target_length:
                    attention_mask = attention_mask[:, :target_length]
                mask_length = attention_mask.shape[-1]
                # A sum of 0 means "causally allowed (0) and padding (0) in the 2D mask": block those positions too.
                padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :].to(
                    causal_mask.device
                )
                padding_mask = padding_mask == 0
                causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill(
                    padding_mask, min_dtype
                )
        return causal_mask
  1003. class PhimoeForCausalLM(PhimoePreTrainedModel, GenerationMixin):
  1004. _tied_weights_keys = ["lm_head.weight"]
  1005. def __init__(self, config):
  1006. super().__init__(config)
  1007. self.model = PhimoeModel(config)
  1008. self.vocab_size = config.vocab_size
  1009. self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=self.config.lm_head_bias)
  1010. self.router_aux_loss_coef = config.router_aux_loss_coef
  1011. self.num_experts = config.num_local_experts
  1012. self.num_experts_per_tok = config.num_experts_per_tok
  1013. # Initialize weights and apply final processing
  1014. self.post_init()
  1015. @can_return_tuple
  1016. @auto_docstring
  1017. def forward(
  1018. self,
  1019. input_ids: Optional[torch.LongTensor] = None,
  1020. attention_mask: Optional[torch.Tensor] = None,
  1021. position_ids: Optional[torch.LongTensor] = None,
  1022. past_key_values: Optional[Cache] = None,
  1023. inputs_embeds: Optional[torch.FloatTensor] = None,
  1024. labels: Optional[torch.LongTensor] = None,
  1025. use_cache: Optional[bool] = None,
  1026. output_attentions: Optional[bool] = None,
  1027. output_hidden_states: Optional[bool] = None,
  1028. output_router_logits: Optional[bool] = None,
  1029. cache_position: Optional[torch.LongTensor] = None,
  1030. logits_to_keep: Union[int, torch.Tensor] = 0,
  1031. **kwargs,
  1032. ) -> MoeCausalLMOutputWithPast:
  1033. r"""
  1034. labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
  1035. Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
  1036. config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
  1037. (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
  1038. Example:
  1039. ```python
  1040. >>> from transformers import AutoTokenizer, PhimoeForCausalLM
  1041. >>> model = PhimoeForCausalLM.from_pretrained("microsoft/Phi-3.5-MoE-instruct")
  1042. >>> tokenizer = AutoTokenizer.from_pretrained("microsoft/Phi-3.5-MoE-instruct")
  1043. >>> prompt = "Hey, are you conscious? Can you talk to me?"
  1044. >>> inputs = tokenizer(prompt, return_tensors="pt")
  1045. >>> # Generate
  1046. >>> generate_ids = model.generate(inputs.input_ids, max_length=30)
  1047. >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
  1048. "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you."
  1049. ```"""
  1050. if (
  1051. use_cache
  1052. and self.config.rope_scaling
  1053. and cache_position is not None
  1054. and cache_position[0] == self.config.original_max_position_embeddings
  1055. ):
  1056. logger.warning(
  1057. f"If you are not using the generate method, you may encounter nonsensical outputs after the {self.config.original_max_position_embeddings}th token, as the KV cache needs to be recomputed."
  1058. )
  1059. output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
  1060. output_router_logits = (
  1061. output_router_logits if output_router_logits is not None else self.config.output_router_logits
  1062. )
  1063. output_hidden_states = (
  1064. output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
  1065. )
  1066. # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
  1067. outputs: MoeModelOutputWithPast = self.model(
  1068. input_ids=input_ids,
  1069. attention_mask=attention_mask,
  1070. position_ids=position_ids,
  1071. past_key_values=past_key_values,
  1072. inputs_embeds=inputs_embeds,
  1073. use_cache=use_cache,
  1074. output_attentions=output_attentions,
  1075. output_hidden_states=output_hidden_states,
  1076. output_router_logits=output_router_logits,
  1077. cache_position=cache_position,
  1078. )
  1079. hidden_states = outputs.last_hidden_state
  1080. # Only compute necessary logits, and do not upcast them to float if we are not computing the loss
  1081. slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
  1082. logits = self.lm_head(hidden_states[:, slice_indices, :])
  1083. loss = None
  1084. if labels is not None:
  1085. loss = self.loss_function(logits, labels, self.vocab_size, **kwargs)
  1086. aux_loss = None
  1087. if output_router_logits:
  1088. aux_loss = load_balancing_loss_func(
  1089. outputs.router_logits,
  1090. self.num_experts,
  1091. self.num_experts_per_tok,
  1092. attention_mask,
  1093. )
  1094. if labels is not None:
  1095. loss += self.router_aux_loss_coef * aux_loss.to(loss.device) # make sure to reside in the same device
  1096. return MoeCausalLMOutputWithPast(
  1097. loss=loss,
  1098. aux_loss=aux_loss,
  1099. logits=logits,
  1100. past_key_values=outputs.past_key_values,
  1101. hidden_states=outputs.hidden_states,
  1102. attentions=outputs.attentions,
  1103. router_logits=outputs.router_logits,
  1104. )
  1105. # Copied from transformers.models.phi3.modeling_phi3.Phi3ForCausalLM.prepare_inputs_for_generation
  1106. def prepare_inputs_for_generation(
  1107. self,
  1108. input_ids,
  1109. past_key_values=None,
  1110. attention_mask=None,
  1111. inputs_embeds=None,
  1112. cache_position=None,
  1113. position_ids=None,
  1114. use_cache=True,
  1115. logits_to_keep=None,
  1116. **kwargs,
  1117. ):
  1118. # Overwritten -- this model may need to switch between short and long rope, invalidating the cache in the
  1119. # process
  1120. # When the first time input length reached long and short factor switching point, enforce re-compute cache
  1121. # It will cause downside of slower at this single token position, however, better than current failure.
  1122. if (
  1123. past_key_values
  1124. and self.config.rope_scaling
  1125. and input_ids.shape[1] >= self.config.original_max_position_embeddings + 1
  1126. ):
  1127. past_length = cache_position[0]
  1128. if past_length <= self.config.original_max_position_embeddings:
  1129. past_key_values = None
  1130. model_inputs = super().prepare_inputs_for_generation(
  1131. input_ids=input_ids,
  1132. past_key_values=past_key_values,
  1133. attention_mask=attention_mask,
  1134. inputs_embeds=inputs_embeds,
  1135. cache_position=cache_position,
  1136. position_ids=position_ids,
  1137. use_cache=use_cache,
  1138. logits_to_keep=logits_to_keep,
  1139. **kwargs,
  1140. )
  1141. return model_inputs
# Sequence-classification variant: nothing is overridden here; all behavior comes from the two base classes
# (`GenericForSequenceClassification` combined with the Phimoe pretrained-model base).
class PhimoeForSequenceClassification(GenericForSequenceClassification, PhimoePreTrainedModel): ...
# Public classes exported by this module (also what `from ... import *` brings in).
__all__ = [
    "PhimoePreTrainedModel",
    "PhimoeModel",
    "PhimoeForCausalLM",
    "PhimoeForSequenceClassification",
]