modeling_doge.py 35 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812
  1. # 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
  2. # This file was automatically generated from src/transformers/models/doge/modular_doge.py.
  3. # Do NOT edit this file manually as any edits will be overwritten by the generation of
  4. # the file from the modular. If any change should be done, please apply the change to the
  5. # modular_doge.py file directly. One of our CI enforces this.
  6. # 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
  7. # coding=utf-8
  8. # Copyright 2025 Jingze Shi and the HuggingFace Inc. team. All rights reserved.
  9. #
  10. # The Doge family of small language models is trained by SmallDoge Team.
  11. #
  12. # Licensed under the Apache License, Version 2.0 (the "License");
  13. # you may not use this file except in compliance with the License.
  14. # You may obtain a copy of the License at
  15. #
  16. # http://www.apache.org/licenses/LICENSE-2.0
  17. #
  18. # Unless required by applicable law or agreed to in writing, software
  19. # distributed under the License is distributed on an "AS IS" BASIS,
  20. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  21. # See the License for the specific language governing permissions and
  22. # limitations under the License.
  23. import math
  24. from typing import Callable, Optional, Union
  25. import torch
  26. import torch.nn.functional as F
  27. from torch import nn
  28. from ...activations import ACT2FN
  29. from ...cache_utils import Cache, DynamicCache
  30. from ...generation import GenerationMixin
  31. from ...integrations import use_kernel_forward_from_hub
  32. from ...integrations.flex_attention import compile_friendly_flex_attention
  33. from ...masking_utils import create_causal_mask, create_sliding_window_causal_mask
  34. from ...modeling_layers import GenericForSequenceClassification, GradientCheckpointingLayer
  35. from ...modeling_outputs import MoeCausalLMOutputWithPast, MoeModelOutputWithPast
  36. from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update
  37. from ...modeling_utils import AttentionInterface, PreTrainedModel
  38. from ...processing_utils import Unpack
  39. from ...utils import TransformersKwargs, auto_docstring, can_return_tuple, is_torch_flex_attn_available
  40. from ...utils.deprecation import deprecate_kwarg
  41. from ...utils.generic import OutputRecorder, check_model_inputs
  42. from .configuration_doge import DogeConfig
  43. if is_torch_flex_attn_available():
  44. from torch.nn.attention.flex_attention import BlockMask
  45. @use_kernel_forward_from_hub("RMSNorm")
  46. class DogeRMSNorm(nn.Module):
  47. def __init__(self, hidden_size, eps=1e-6):
  48. """
  49. DogeRMSNorm is equivalent to T5LayerNorm
  50. """
  51. super().__init__()
  52. self.weight = nn.Parameter(torch.ones(hidden_size))
  53. self.variance_epsilon = eps
  54. def forward(self, hidden_states):
  55. input_dtype = hidden_states.dtype
  56. hidden_states = hidden_states.to(torch.float32)
  57. variance = hidden_states.pow(2).mean(-1, keepdim=True)
  58. hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
  59. return self.weight * hidden_states.to(input_dtype)
  60. def extra_repr(self):
  61. return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}"
  62. class DogeRotaryEmbedding(nn.Module):
  63. inv_freq: torch.Tensor # fix linting for `register_buffer`
  64. def __init__(self, config: DogeConfig, device=None):
  65. super().__init__()
  66. # BC: "rope_type" was originally "type"
  67. if hasattr(config, "rope_scaling") and isinstance(config.rope_scaling, dict):
  68. self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type"))
  69. else:
  70. self.rope_type = "default"
  71. self.max_seq_len_cached = config.max_position_embeddings
  72. self.original_max_seq_len = config.max_position_embeddings
  73. self.config = config
  74. self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]
  75. inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device)
  76. self.register_buffer("inv_freq", inv_freq, persistent=False)
  77. self.original_inv_freq = self.inv_freq
  78. @torch.no_grad()
  79. @dynamic_rope_update # power user: used with advanced RoPE types (e.g. dynamic rope)
  80. def forward(self, x, position_ids):
  81. inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1).to(x.device)
  82. position_ids_expanded = position_ids[:, None, :].float()
  83. device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu"
  84. with torch.autocast(device_type=device_type, enabled=False): # Force float32
  85. freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
  86. emb = torch.cat((freqs, freqs), dim=-1)
  87. cos = emb.cos() * self.attention_scaling
  88. sin = emb.sin() * self.attention_scaling
  89. return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
  90. def rotate_half(x):
  91. """Rotates half the hidden dims of the input."""
  92. x1 = x[..., : x.shape[-1] // 2]
  93. x2 = x[..., x.shape[-1] // 2 :]
  94. return torch.cat((-x2, x1), dim=-1)
  95. def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
  96. """Applies Rotary Position Embedding to the query and key tensors.
  97. Args:
  98. q (`torch.Tensor`): The query tensor.
  99. k (`torch.Tensor`): The key tensor.
  100. cos (`torch.Tensor`): The cosine part of the rotary embedding.
  101. sin (`torch.Tensor`): The sine part of the rotary embedding.
  102. position_ids (`torch.Tensor`, *optional*):
  103. Deprecated and unused.
  104. unsqueeze_dim (`int`, *optional*, defaults to 1):
  105. The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and
  106. sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note
  107. that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and
  108. k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes
  109. cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have
  110. the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2.
  111. Returns:
  112. `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding.
  113. """
  114. cos = cos.unsqueeze(unsqueeze_dim)
  115. sin = sin.unsqueeze(unsqueeze_dim)
  116. q_embed = (q * cos) + (rotate_half(q) * sin)
  117. k_embed = (k * cos) + (rotate_half(k) * sin)
  118. return q_embed, k_embed
  119. def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
  120. """
  121. This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch,
  122. num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim)
  123. """
  124. batch, num_key_value_heads, slen, head_dim = hidden_states.shape
  125. if n_rep == 1:
  126. return hidden_states
  127. hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim)
  128. return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
  129. def eager_attention_forward(
  130. module: nn.Module,
  131. query: torch.Tensor,
  132. key: torch.Tensor,
  133. value: torch.Tensor,
  134. attention_mask: Optional[torch.Tensor],
  135. scaling: float,
  136. dropout: float = 0.0,
  137. **kwargs: Unpack[TransformersKwargs],
  138. ):
  139. key_states = repeat_kv(key, module.num_key_value_groups)
  140. value_states = repeat_kv(value, module.num_key_value_groups)
  141. attn_weights = torch.matmul(query, key_states.transpose(2, 3)) * scaling
  142. if attention_mask is not None:
  143. causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]
  144. attn_weights = attn_weights + causal_mask
  145. attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype)
  146. attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training)
  147. attn_output = torch.matmul(attn_weights, value_states)
  148. attn_output = attn_output.transpose(1, 2).contiguous()
  149. return attn_output, attn_weights
  150. def flex_attention_forward(
  151. module: nn.Module,
  152. query: torch.Tensor,
  153. key: torch.Tensor,
  154. value: torch.Tensor,
  155. attention_mask: Union[torch.Tensor, "BlockMask"],
  156. scaling: Optional[float] = None,
  157. softcap: Optional[float] = None,
  158. head_mask: Optional[torch.Tensor] = None,
  159. **kwargs,
  160. ) -> tuple[torch.Tensor, torch.Tensor]:
  161. block_mask = None
  162. causal_mask = None
  163. if isinstance(attention_mask, BlockMask):
  164. block_mask = attention_mask
  165. else:
  166. causal_mask = attention_mask
  167. if causal_mask is not None:
  168. causal_mask = causal_mask[:, :, :, : key.shape[-2]]
  169. def score_mod(score, batch_idx, head_idx, q_idx, kv_idx):
  170. if softcap is not None:
  171. score = softcap * torch.tanh(score / softcap)
  172. if causal_mask is not None:
  173. score = score + causal_mask[batch_idx][head_idx][q_idx][kv_idx]
  174. if head_mask is not None:
  175. score = score + head_mask[batch_idx][head_idx][0][0]
  176. return score
  177. attn_output, attention_weights = compile_friendly_flex_attention(
  178. query,
  179. key,
  180. value,
  181. score_mod=score_mod,
  182. block_mask=block_mask,
  183. enable_gqa=True,
  184. scale=scaling,
  185. # Last time checked on PyTorch == 2.5.1: Flex Attention always computes the lse regardless.
  186. # For simplification, we thus always return it as no additional computations are introduced.
  187. return_lse=True,
  188. )
  189. # lse is returned in float32
  190. attention_weights = attention_weights.to(value.dtype)
  191. attn_output = attn_output.transpose(1, 2).contiguous()
  192. return attn_output, attention_weights
# Attention backend registry consulted by `DogeAttention.forward` for every non-"eager" implementation.
# A fresh AttentionInterface instance is created here and Doge's flex variant is added under the key
# "doge_flex_attention". NOTE(review): presumably the instance also exposes the library's built-in
# backends (e.g. "sdpa"), since `_supports_sdpa = True` relies on that lookup — confirm.
ALL_ATTENTION_FUNCTIONS = AttentionInterface()
ALL_ATTENTION_FUNCTIONS["doge_flex_attention"] = flex_attention_forward
  195. class DogeAttention(nn.Module):
  196. def __init__(self, config: DogeConfig, layer_idx: Optional[int] = None):
  197. super().__init__()
  198. self.config = config
  199. self.layer_idx = layer_idx
  200. self.head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads)
  201. self.num_key_value_groups = config.num_attention_heads // config.num_key_value_heads
  202. self.scaling = self.head_dim**-0.5
  203. self.attention_dropout = config.attention_dropout
  204. self.keep_window_size = config.keep_window_size
  205. self.q_proj = nn.Linear(
  206. config.hidden_size, config.num_attention_heads * self.head_dim, bias=config.attention_bias
  207. )
  208. self.k_proj = nn.Linear(
  209. config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.attention_bias
  210. )
  211. self.v_proj = nn.Linear(
  212. config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.attention_bias
  213. )
  214. # dynamic mask for the QK^T attention weights matrix
  215. self.A = nn.Parameter(torch.zeros(config.num_key_value_heads))
  216. self.dt_proj = nn.Linear(
  217. config.num_key_value_heads * self.head_dim, config.num_key_value_heads, bias=config.attention_bias
  218. )
  219. self.o_proj = nn.Linear(
  220. config.num_attention_heads * self.head_dim, config.hidden_size, bias=config.attention_bias
  221. )
  222. self.q_norm = DogeRMSNorm(self.head_dim, eps=config.rms_norm_eps)
  223. self.k_norm = DogeRMSNorm(self.head_dim, eps=config.rms_norm_eps)
  224. @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58")
  225. def forward(
  226. self,
  227. hidden_states: torch.Tensor,
  228. position_embeddings: tuple[torch.Tensor, torch.Tensor],
  229. attention_mask: Optional[torch.Tensor] = None,
  230. past_key_values: Optional[Cache] = None,
  231. cache_position: Optional[torch.LongTensor] = None,
  232. **kwargs,
  233. ) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[tuple[torch.Tensor]]]:
  234. input_shape = hidden_states.shape[:-1]
  235. hidden_shape = (*input_shape, -1, self.head_dim)
  236. query_states = self.q_norm(self.q_proj(hidden_states).view(hidden_shape)).transpose(1, 2)
  237. key_states = self.k_norm(self.k_proj(hidden_states).view(hidden_shape)).transpose(1, 2)
  238. value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2)
  239. cos, sin = position_embeddings
  240. query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
  241. if past_key_values is not None:
  242. # sin and cos are specific to RoPE models; cache_position needed for the static cache
  243. cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
  244. key_states, value_states = past_key_values.update(key_states, value_states, self.layer_idx, cache_kwargs)
  245. # calculate dynamic mask from value_states
  246. dt_states = self.dt_proj(
  247. value_states.transpose(1, 2).reshape(value_states.shape[0], value_states.shape[-2], -1)
  248. )
  249. dt_states = torch.exp(self.A * F.softplus(dt_states)).transpose(-1, -2)
  250. attn_mask = self.prepare_dynamic_mask(
  251. hidden_states=hidden_states,
  252. dt_states=dt_states,
  253. keep_window_size=self.keep_window_size,
  254. attention_mask=attention_mask,
  255. )
  256. attn_mask = repeat_kv(attn_mask, self.num_key_value_groups)
  257. attention_interface: Callable = eager_attention_forward
  258. if self.config._attn_implementation != "eager":
  259. attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]
  260. attn_output, attn_weights = attention_interface(
  261. self,
  262. query_states,
  263. key_states,
  264. value_states,
  265. attention_mask=attn_mask,
  266. dropout=0.0 if not self.training else self.attention_dropout,
  267. scaling=self.scaling,
  268. **kwargs,
  269. )
  270. attn_output = attn_output.reshape(*input_shape, -1).contiguous()
  271. attn_output = self.o_proj(attn_output)
  272. return attn_output, attn_weights
  273. def prepare_dynamic_mask(
  274. self,
  275. hidden_states: torch.Tensor,
  276. dt_states: torch.Tensor,
  277. keep_window_size: int = 2048,
  278. attention_mask: Optional[torch.Tensor] = None,
  279. ):
  280. """
  281. The core idea of DMA is to calculate the dynamic attention mask to mask the tokens that should be masked, so as to form sparse attention.
  282. Combine `dt_states` with `attention_mask` to generate the final `attn_mask`.
  283. Args:
  284. hidden_states (`torch.Tensor`): The input hidden_states, used to determine the minimum value of the current input precision.
  285. dt_states (`torch.Tensor`): dt_states of shape `(batch_size, num_heads, key_sequence_length)`.
  286. keep_window_size (`int`): The window size of tokens that are not dynamically masked, and dynamic masking is only performed when the sequence length exceeds this value.
  287. attention_mask (`torch.Tensor`, *optional*): attention mask of shape `(batch_size, 1, query_sequence_length, key_sequence_length)`.
  288. """
  289. min_dtype = torch.finfo(hidden_states.dtype).min
  290. dtype = hidden_states.dtype
  291. attn_mask = dt_states[:, :, None, :].expand(
  292. -1, -1, hidden_states.shape[1], -1
  293. ) # [batch_size, num_heads, query_len, key_len]
  294. if attention_mask is not None and not isinstance(attention_mask, BlockMask):
  295. if attention_mask.dtype == torch.bool:
  296. dtype = hidden_states.dtype
  297. attention_mask = torch.where(
  298. attention_mask, torch.tensor(0.0, device=attention_mask.device, dtype=dtype), min_dtype
  299. )
  300. attn_mask = attn_mask.masked_fill(attention_mask[:, :, :, : attn_mask.shape[-1]] != 0, min_dtype)
  301. if attn_mask.shape[-1] > keep_window_size:
  302. active_mask = torch.zeros_like(attn_mask, dtype=dtype, device=attn_mask.device)
  303. topk_indices = torch.topk(attn_mask, keep_window_size, dim=-1, largest=True, sorted=False).indices
  304. active_mask = active_mask.scatter(-1, topk_indices, 1.0)
  305. attn_mask = attn_mask.masked_fill(active_mask == 0.0, min_dtype)
  306. return attn_mask
  307. class DogeMLP(nn.Module):
  308. def __init__(self, config):
  309. super().__init__()
  310. self.config = config
  311. self.hidden_size = config.hidden_size
  312. self.intermediate_size = config.intermediate_size
  313. self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=config.mlp_bias)
  314. self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=config.mlp_bias)
  315. self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=config.mlp_bias)
  316. self.act_fn = ACT2FN[config.hidden_act]
  317. def forward(self, x):
  318. down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))
  319. return down_proj
class DogeCDMoE(nn.Module):
    """
    Feed-forward block that adds a retrieval mixture of tiny experts to a dense gated MLP (the "shared expert").

    The router produces `2 * num_keys` logits per row, treated as an "x" half and a "y" half. The top keys of
    each half are paired (expert id = x_key * num_keys + y_key, score = x_score + y_score), and the `top_k` best
    pairs are selected. Each selected expert `e` contributes
    `act_fn(<down_embed[e], h>) * routing_weight * up_embed[e]`.
    """

    def __init__(self, config: DogeConfig):
        super().__init__()
        self.hidden_size = config.hidden_size
        self.intermediate_size = config.intermediate_size
        self.act_fn = ACT2FN[config.hidden_act]

        self.num_experts = config.num_experts
        # Expert ids are built as x_key * num_keys + y_key, so only num_keys ** 2 (<= num_experts) embedding
        # rows are reachable when num_experts is not a perfect square.
        self.num_keys = math.floor(math.sqrt(self.num_experts))
        self.top_k = config.num_experts_per_tok
        self.norm_topk_prob = config.norm_topk_prob

        # shared expert
        self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=config.mlp_bias)
        self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=config.mlp_bias)
        self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=config.mlp_bias)

        # router gate for retrieval experts
        self.router_gate = nn.Linear(self.hidden_size, self.num_keys * 2, bias=False)

        # routed experts: each expert is a pair of hidden_size vectors; `down_embed` is dotted with the
        # token's hidden state and `up_embed` is the vector that gets added to the output
        self.down_embed = nn.Embedding(self.num_experts, self.hidden_size)
        self.up_embed = nn.Embedding(self.num_experts, self.hidden_size)

    def forward(
        self,
        hidden_states: torch.Tensor,
        **kwargs,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """
        Args:
            hidden_states: tensor of shape (batch, seq_len, hidden_size).

        Returns:
            `(hidden_states, router_logits)`: the block output of shape (batch, seq_len, hidden_size), and the
            router logits viewed as (2, batch * seq_len, num_keys).
        """
        bsz, seq_len, _ = hidden_states.shape

        # get routing logits with router gate
        # NOTE(review): `.view(2, bsz * seq_len, -1)` reinterprets the contiguous (bsz, seq_len, 2 * num_keys)
        # output in memory order. When bsz * seq_len > 1, row n takes its "x" logits from token n // 2 and its
        # "y" logits from token (bsz * seq_len + n) // 2, so a row's routing is not computed from its own
        # token. Only when bsz * seq_len == 1 (e.g. single-token decoding) is the split per-token. A
        # per-token split would be `.view(bsz * seq_len, 2, -1).transpose(0, 1)`. Left unchanged because
        # existing checkpoints may depend on this layout — confirm intent.
        router_logits = self.router_gate(hidden_states).view(2, bsz * seq_len, -1)

        # get experts with the highest routing logits
        # topk over all num_keys candidates per half (i.e. a descending sort); unpacking splits dim 0 into x / y
        (scores_x, scores_y), (indices_x, indices_y) = router_logits.topk(self.num_keys, dim=-1)
        # score of every (x, y) key pair and its flat expert id, flattened to (N, num_keys ** 2)
        all_scores = scores_x.unsqueeze(-1) + scores_y.unsqueeze(-2)
        all_indices = indices_x.unsqueeze(-1) * self.num_keys + indices_y.unsqueeze(-2)
        all_scores = all_scores.view(*all_scores.shape[:-2], -1)
        all_indices = all_indices.view(*all_indices.shape[:-2], -1)
        scores, position_indices = all_scores.topk(self.top_k, dim=-1)
        indices = all_indices.gather(-1, position_indices)  # (N, top_k) expert ids
        routing_weights = F.softmax(scores, dim=-1)
        # softmax output already sums to 1 along the last dim, so this only affects rounding
        if self.norm_topk_prob:
            routing_weights /= routing_weights.sum(dim=-1, keepdim=True)

        # mix routed experts states with shared expert states
        down_embed = self.down_embed(indices)  # (N, top_k, hidden_size)
        up_embed = self.up_embed(indices)  # (N, top_k, hidden_size)
        # one scalar per selected expert: dot product of its down embedding with the token's hidden state
        experts_weights = torch.matmul(down_embed, hidden_states.view(bsz * seq_len, -1, 1)).view(bsz * seq_len, -1)
        experts_weights = self.act_fn(experts_weights) * routing_weights
        # weighted sum of the selected up embeddings, reshaped back to (bsz, seq_len, hidden_size)
        experts_states = torch.matmul(experts_weights.view(bsz * seq_len, 1, -1), up_embed).view(bsz, seq_len, -1)
        # shared (dense) expert
        hidden_states = self.down_proj(self.act_fn(self.gate_proj(hidden_states)) * self.up_proj(hidden_states))
        hidden_states = hidden_states + experts_states
        return hidden_states, router_logits
  367. class DogeDecoderLayer(GradientCheckpointingLayer):
  368. def __init__(self, config: DogeConfig, layer_idx: Optional[int] = None):
  369. super().__init__()
  370. self.hidden_dropout = config.hidden_dropout
  371. self.input_layernorm = DogeRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
  372. self.self_attn = DogeAttention(config=config, layer_idx=layer_idx)
  373. self.input_residual = nn.Parameter(torch.ones(config.hidden_size))
  374. self.post_attention_layernorm = DogeRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
  375. self.mlp = DogeMLP(config) if not config.is_moe else DogeCDMoE(config)
  376. self.post_attention_residual = nn.Parameter(torch.ones(config.hidden_size))
  377. @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58")
  378. def forward(
  379. self,
  380. hidden_states: torch.Tensor,
  381. position_embeddings: tuple[torch.Tensor, torch.Tensor],
  382. attention_mask: Optional[torch.Tensor] = None,
  383. position_ids: Optional[torch.LongTensor] = None,
  384. past_key_values: Optional[Cache] = None,
  385. use_cache: Optional[bool] = False,
  386. cache_position: Optional[torch.LongTensor] = None,
  387. **kwargs: Unpack[TransformersKwargs],
  388. ) -> tuple[torch.FloatTensor, Optional[tuple[torch.FloatTensor, torch.FloatTensor]]]:
  389. # sequence transformation
  390. residual = hidden_states
  391. hidden_states = self.input_layernorm(hidden_states)
  392. hidden_states, self_attn_weights = self.self_attn(
  393. hidden_states=hidden_states,
  394. position_embeddings=position_embeddings,
  395. attention_mask=attention_mask,
  396. position_ids=position_ids,
  397. past_key_values=past_key_values,
  398. use_cache=use_cache,
  399. cache_position=cache_position,
  400. **kwargs,
  401. )
  402. hidden_states = F.dropout(hidden_states, p=self.hidden_dropout, training=self.training)
  403. hidden_states = self.input_residual * residual + hidden_states
  404. # state transformation
  405. residual = hidden_states
  406. hidden_states = self.post_attention_layernorm(hidden_states)
  407. hidden_states = self.mlp(hidden_states)
  408. hidden_states = F.dropout(hidden_states, p=self.hidden_dropout, training=self.training)
  409. hidden_states = self.post_attention_residual * residual + hidden_states
  410. return hidden_states
  411. @auto_docstring
  412. class DogePreTrainedModel(PreTrainedModel):
  413. config: DogeConfig
  414. base_model_prefix = "model"
  415. supports_gradient_checkpointing = True
  416. _no_split_modules = ["DogeDecoderLayer"]
  417. _skip_keys_device_placement = ["past_key_values"]
  418. _supports_flash_attn = False
  419. _supports_sdpa = True
  420. _supports_flex_attn = True
  421. _can_compile_fullgraph = False
  422. _supports_attention_backend = True
  423. _can_record_outputs = {
  424. "router_logits": OutputRecorder(DogeCDMoE, index=1),
  425. "hidden_states": DogeDecoderLayer,
  426. "attentions": DogeAttention,
  427. }
  428. def _init_weights(self, module):
  429. """Initialize the weights"""
  430. super()._init_weights(module)
  431. if isinstance(module, DogeAttention):
  432. if hasattr(module, "A"):
  433. module.A.data.zero_()
  434. elif isinstance(module, DogeDecoderLayer):
  435. if hasattr(module, "input_residual"):
  436. module.input_residual.data.fill_(1.0)
  437. if hasattr(module, "post_attention_residual"):
  438. module.post_attention_residual.data.fill_(1.0)
  439. @auto_docstring
  440. class DogeModel(DogePreTrainedModel):
  441. def __init__(self, config: DogeConfig):
  442. super().__init__(config)
  443. self.padding_idx = config.pad_token_id
  444. self.vocab_size = config.vocab_size
  445. self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
  446. self.layers = nn.ModuleList(
  447. [DogeDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
  448. )
  449. self.norm = DogeRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
  450. self.rotary_emb = DogeRotaryEmbedding(config=config)
  451. self.gradient_checkpointing = False
  452. # Initialize weights and apply final processing
  453. self.post_init()
  454. @check_model_inputs()
  455. @auto_docstring
  456. def forward(
  457. self,
  458. input_ids: Optional[torch.LongTensor] = None,
  459. attention_mask: Optional[torch.Tensor] = None,
  460. position_ids: Optional[torch.LongTensor] = None,
  461. past_key_values: Optional[Cache] = None,
  462. inputs_embeds: Optional[torch.FloatTensor] = None,
  463. use_cache: Optional[bool] = None,
  464. cache_position: Optional[torch.LongTensor] = None,
  465. **kwargs: Unpack[TransformersKwargs],
  466. ) -> MoeModelOutputWithPast:
  467. if (input_ids is None) ^ (inputs_embeds is not None):
  468. raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
  469. if use_cache and past_key_values is None:
  470. past_key_values = DynamicCache(config=self.config)
  471. if inputs_embeds is None:
  472. inputs_embeds = self.embed_tokens(input_ids)
  473. if cache_position is None:
  474. past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
  475. cache_position = torch.arange(
  476. past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device
  477. )
  478. if position_ids is None:
  479. position_ids = cache_position.unsqueeze(0)
  480. mask_function = create_causal_mask if self.config.sliding_window is None else create_sliding_window_causal_mask
  481. causal_mask = mask_function(
  482. config=self.config,
  483. input_embeds=inputs_embeds,
  484. attention_mask=attention_mask,
  485. cache_position=cache_position,
  486. past_key_values=past_key_values,
  487. position_ids=position_ids,
  488. )
  489. hidden_states = inputs_embeds
  490. # create position embeddings to be shared across the decoder layers
  491. position_embeddings = self.rotary_emb(hidden_states, position_ids)
  492. for decoder_layer in self.layers[: self.config.num_hidden_layers]:
  493. hidden_states = decoder_layer(
  494. hidden_states,
  495. position_embeddings=position_embeddings,
  496. attention_mask=causal_mask,
  497. position_ids=position_ids,
  498. past_key_values=past_key_values,
  499. use_cache=use_cache,
  500. cache_position=cache_position,
  501. **kwargs,
  502. )
  503. hidden_states = self.norm(hidden_states)
  504. return MoeModelOutputWithPast( # only diff with Mistral is the output type, we need MoE
  505. last_hidden_state=hidden_states,
  506. past_key_values=past_key_values,
  507. )
  508. def load_balancing_loss_func(
  509. gate_logits: Union[torch.Tensor, tuple[torch.Tensor], None],
  510. num_experts: Optional[int] = None,
  511. num_keys: Optional[int] = None,
  512. top_k: int = 2,
  513. attention_mask: Optional[torch.Tensor] = None,
  514. ) -> Union[torch.Tensor, int]:
  515. r"""
  516. Computes auxiliary load balancing loss as in Switch Transformer - implemented in Pytorch.
  517. See Switch Transformer (https://huggingface.co/papers/2101.03961) for more details. This function implements the loss
  518. function presented in equations (4) - (6) of the paper. It aims at penalizing cases where the routing between
  519. experts is too unbalanced.
  520. Args:
  521. gate_logits:
  522. Logits from the `router_gate`, should be a tuple of model.config.num_hidden_layers tensors of
  523. shape [2, batch_size * sequence_length, num_keys].
  524. num_experts:
  525. Number of experts
  526. num_keys:
  527. Number of keys
  528. top_k:
  529. The number of experts to route per-token, can be also interpreted as the `top-k` routing
  530. parameter.
  531. attention_mask (`torch.Tensor`, *optional*):
  532. The attention_mask used in forward function
  533. shape [batch_size X sequence_length] if not None.
  534. Returns:
  535. The auxiliary loss.
  536. """
  537. if gate_logits is None or not isinstance(gate_logits, tuple):
  538. return 0
  539. compute_dtype = gate_logits[0].dtype
  540. compute_device = gate_logits[0].device
  541. all_expert_indices = []
  542. all_routing_weights = []
  543. for layer_gate_logits in gate_logits:
  544. layer_gate_logits = layer_gate_logits.to(compute_device)
  545. (scores_x, scores_y), (indices_x, indices_y) = layer_gate_logits.topk(num_keys, dim=-1)
  546. all_scores = scores_x.unsqueeze(-1) + scores_y.unsqueeze(-2)
  547. all_indices = indices_x.unsqueeze(-1) * num_keys + indices_y.unsqueeze(-2)
  548. all_scores = all_scores.view(*all_scores.shape[:-2], -1)
  549. all_indices = all_indices.view(*all_indices.shape[:-2], -1)
  550. _, position_indices = all_scores.topk(top_k, dim=-1)
  551. expert_indices = all_indices.gather(-1, position_indices)
  552. routing_weights = F.softmax(all_scores, dim=-1)
  553. all_expert_indices.append(expert_indices)
  554. all_routing_weights.append(routing_weights)
  555. all_expert_indices = torch.cat(all_expert_indices, dim=0)
  556. all_routing_weights = torch.cat(all_routing_weights, dim=0)
  557. if attention_mask is None:
  558. # Compute the percentage of tokens routed to each experts
  559. all_expert_indices = all_expert_indices.view(-1)
  560. tokens_per_expert = torch.zeros(num_experts, dtype=compute_dtype, device=compute_device)
  561. pad = torch.ones_like(all_expert_indices, dtype=compute_dtype, device=compute_device)
  562. tokens_per_expert = tokens_per_expert.scatter_add_(0, all_expert_indices, pad) / all_expert_indices.shape[0]
  563. # Compute the average probability of routing to these experts
  564. router_prob_per_expert = torch.mean(all_routing_weights, dim=0)
  565. else:
  566. batch_size, sequence_length = attention_mask.shape
  567. num_hidden_layers = len(gate_logits)
  568. # Compute the mask that masks all padding tokens as 0 with the same shape of expert_mask
  569. expert_attention_mask = (
  570. attention_mask[None, :, :, None]
  571. .expand((num_hidden_layers, batch_size, sequence_length, top_k))
  572. .reshape(-1)
  573. .to(compute_device)
  574. )
  575. all_expert_indices = all_expert_indices.view(-1)[expert_attention_mask.bool()]
  576. # Compute the percentage of tokens routed to each experts
  577. tokens_per_expert = torch.zeros(num_experts, dtype=compute_dtype, device=compute_device)
  578. pad = torch.ones_like(all_expert_indices, dtype=compute_dtype, device=compute_device)
  579. tokens_per_expert = tokens_per_expert.scatter_add_(0, all_expert_indices, pad) / torch.sum(
  580. expert_attention_mask
  581. )
  582. # Compute the mask that masks all padding tokens as 0 with the same shape of tokens_per_expert
  583. router_per_expert_attention_mask = (
  584. attention_mask[None, :, :, None]
  585. .expand((num_hidden_layers, batch_size, sequence_length, num_experts))
  586. .reshape(-1, num_experts)
  587. .to(compute_device)
  588. )
  589. # Compute the average probability of routing to these experts
  590. router_prob_per_expert = torch.sum(all_routing_weights * router_per_expert_attention_mask, dim=0) / torch.sum(
  591. router_per_expert_attention_mask, dim=0
  592. )
  593. overall_loss = torch.sum(tokens_per_expert * router_prob_per_expert)
  594. return overall_loss * num_experts
  595. @auto_docstring
  596. class DogeForCausalLM(DogePreTrainedModel, GenerationMixin):
  597. _tied_weights_keys = ["lm_head.weight"]
  598. _tp_plan = {"lm_head": "colwise_rep"}
  599. _pp_plan = {"lm_head": (["hidden_states"], ["logits"])}
  600. def __init__(self, config):
  601. super().__init__(config)
  602. self.model = DogeModel(config)
  603. self.vocab_size = config.vocab_size
  604. self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
  605. self.router_aux_loss_coef = config.router_aux_loss_coef
  606. self.num_experts = config.num_experts
  607. self.num_experts_per_tok = config.num_experts_per_tok
  608. # Initialize weights and apply final processing
  609. self.post_init()
  610. @can_return_tuple
  611. @auto_docstring
  612. def forward(
  613. self,
  614. input_ids: Optional[torch.LongTensor] = None,
  615. attention_mask: Optional[torch.Tensor] = None,
  616. position_ids: Optional[torch.LongTensor] = None,
  617. past_key_values: Optional[Cache] = None,
  618. inputs_embeds: Optional[torch.FloatTensor] = None,
  619. labels: Optional[torch.LongTensor] = None,
  620. use_cache: Optional[bool] = None,
  621. cache_position: Optional[torch.LongTensor] = None,
  622. logits_to_keep: Union[int, torch.Tensor] = 0,
  623. output_router_logits: Optional[bool] = None,
  624. **kwargs: Unpack[TransformersKwargs],
  625. ) -> MoeCausalLMOutputWithPast:
  626. r"""
  627. labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
  628. Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
  629. config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
  630. (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
  631. Example:
  632. ```python
  633. >>> from transformers import AutoTokenizer, DogeForCausalLM
  634. >>> model = DogeForCausalLM.from_pretrained("SmallDoge/Doge-320M")
  635. >>> tokenizer = AutoTokenizer.from_pretrained("SmallDoge/Doge-320M")
  636. >>> prompt = "Hey, are you conscious? Can you talk to me?"
  637. >>> inputs = tokenizer(prompt, return_tensors="pt")
  638. >>> # Generate
  639. >>> generate_ids = model.generate(inputs.input_ids, max_length=30)
  640. >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
  641. "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you."
  642. ```"""
  643. output_router_logits = (
  644. output_router_logits if output_router_logits is not None else self.config.output_router_logits
  645. )
  646. # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
  647. outputs: MoeModelOutputWithPast = self.model(
  648. input_ids=input_ids,
  649. attention_mask=attention_mask,
  650. position_ids=position_ids,
  651. past_key_values=past_key_values,
  652. inputs_embeds=inputs_embeds,
  653. use_cache=use_cache,
  654. cache_position=cache_position,
  655. **kwargs,
  656. )
  657. hidden_states = outputs.last_hidden_state
  658. # Only compute necessary logits, and do not upcast them to float if we are not computing the loss
  659. slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
  660. logits = self.lm_head(hidden_states[:, slice_indices, :])
  661. loss = None
  662. if labels is not None:
  663. loss = self.loss_function(logits, labels, self.vocab_size, **kwargs)
  664. aux_loss = None
  665. if output_router_logits:
  666. aux_loss = load_balancing_loss_func(
  667. outputs.router_logits,
  668. self.num_experts,
  669. math.floor(math.sqrt(self.num_experts)),
  670. self.num_experts_per_tok,
  671. attention_mask,
  672. )
  673. if labels is not None:
  674. loss += self.router_aux_loss_coef * aux_loss.to(loss.device) # make sure to reside in the same device
  675. return MoeCausalLMOutputWithPast(
  676. loss=loss,
  677. aux_loss=aux_loss,
  678. logits=logits,
  679. past_key_values=outputs.past_key_values,
  680. hidden_states=outputs.hidden_states,
  681. attentions=outputs.attentions,
  682. router_logits=outputs.router_logits,
  683. )
  684. class DogeForSequenceClassification(GenericForSequenceClassification, DogePreTrainedModel):
  685. pass
  686. __all__ = ["DogeForCausalLM", "DogeModel", "DogePreTrainedModel", "DogeForSequenceClassification"]