modular_doge.py 36 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693694695696697698699700701702703704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800
  1. # coding=utf-8
  2. # Copyright 2025 Jingze Shi and the HuggingFace Inc. team. All rights reserved.
  3. #
  4. # The Doge family of small language models is trained by SmallDoge Team.
  5. #
  6. # Licensed under the Apache License, Version 2.0 (the "License");
  7. # you may not use this file except in compliance with the License.
  8. # You may obtain a copy of the License at
  9. #
  10. # http://www.apache.org/licenses/LICENSE-2.0
  11. #
  12. # Unless required by applicable law or agreed to in writing, software
  13. # distributed under the License is distributed on an "AS IS" BASIS,
  14. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  15. # See the License for the specific language governing permissions and
  16. # limitations under the License.
  17. """PyTorch Doge model."""
  18. import math
  19. from typing import Callable, Optional, Union
  20. import torch
  21. import torch.nn.functional as F
  22. from torch import nn
  23. from ...activations import ACT2FN
  24. from ...cache_utils import Cache
  25. from ...configuration_utils import PretrainedConfig
  26. from ...integrations.flex_attention import compile_friendly_flex_attention
  27. from ...modeling_layers import GradientCheckpointingLayer
  28. from ...modeling_outputs import MoeCausalLMOutputWithPast, MoeModelOutputWithPast
  29. from ...modeling_rope_utils import rope_config_validation
  30. from ...modeling_utils import AttentionInterface, PreTrainedModel
  31. from ...processing_utils import Unpack
  32. from ...utils import TransformersKwargs, is_torch_flex_attn_available
  33. from ...utils.deprecation import deprecate_kwarg
  34. from ...utils.generic import OutputRecorder
  35. from ..llama.modeling_llama import (
  36. LlamaForSequenceClassification,
  37. LlamaMLP,
  38. LlamaPreTrainedModel,
  39. LlamaRMSNorm,
  40. LlamaRotaryEmbedding,
  41. apply_rotary_pos_emb,
  42. eager_attention_forward,
  43. repeat_kv,
  44. )
  45. from ..mixtral.modeling_mixtral import MixtralForCausalLM, MixtralModel
  46. if is_torch_flex_attn_available():
  47. from torch.nn.attention.flex_attention import BlockMask
  48. class DogeConfig(PretrainedConfig):
  49. r"""
  50. This is the configuration class to store the configuration of a [`DogeModel`]. It is used to instantiate an Doge
  51. model according to the specified arguments, defining the model architecture like [SmallDoge/Doge-320M](https://huggingface.co/SmallDoge/Doge-320M).
  52. Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
  53. documentation from [`PretrainedConfig`] for more information.
  54. Args:
  55. vocab_size (`int`, *optional*, defaults to 32768):
  56. Vocabulary size of the Doge2 model. Defines the number of different tokens that can be represented by the `inputs_ids` passed when calling [`DogeModel`]
  57. hidden_size (`int`, *optional*, defaults to 1024):
  58. Dimension of the hidden representations.
  59. intermediate_size (`int`, *optional*, defaults to 2048):
  60. Dimension of the MLP representations.
  61. num_hidden_layers (`int`, *optional*, defaults to 32):
  62. Number of hidden layers in the Transformer decoder.
  63. hidden_dropout (`float`, *optional*, defaults to 0.0):
  64. Dropout probability for each sequence transformation and state transformation module.
  65. hidden_act (`str` or `function`, *optional*, defaults to `"silu"`):
  66. The non-linear activation function (function or string) in the decoder.
  67. initializer_range (`float`, *optional*, defaults to 0.02):
  68. The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
  69. rms_norm_eps (`float`, *optional*, defaults to 1e-06):
  70. The epsilon used by the rms normalization layers.
  71. use_cache (`bool`, *optional*, defaults to `True`):
  72. Whether or not the model should return the last key/values attentions (not used by all models). Only
  73. relevant if `config.is_decoder=True`.
  74. tie_word_embeddings (`bool`, *optional*, defaults to `False`):
  75. Whether the model's input and output word embeddings should be tied.
  76. max_position_embeddings (`int`, *optional*, defaults to 2048):
  77. The maximum sequence length that this model might ever be used with.
  78. rope_theta (`float`, *optional*, defaults to 10000.0):
  79. The base period of the RoPE embeddings.
  80. rope_scaling (`Dict`, *optional*):
  81. Dictionary containing the scaling configuration for the RoPE embeddings.
  82. NOTE: if you apply new rope type and you expect the model to work on longer `max_position_embeddings`, we recommend you to update this value accordingly.
  83. Doge family of small models use `{ 'rope_type': 'dynamic', 'factor': 4.0, 'original_max_position_embeddings': 2048 }` as the default value.
  84. Expected contents:
  85. `rope_type` (`str`):
  86. The sub-variant of RoPE to use. Can be one of ['default', 'linear', 'dynamic', 'yarn', 'longrope', 'llama3'], with 'default' being the original RoPE implementation.
  87. `factor` (`float`, *optional*):
  88. Used with all rope types except 'default'. The scaling factor to apply to the RoPE embeddings.
  89. In most scaling types, a `factor` of x will enable the model to handle sequences of length x * original maximum pre-trained length.
  90. `original_max_position_embeddings` (`int`, *optional*):
  91. Used with 'dynamic', 'longrope' and 'llama3'.
  92. The original max position embeddings used during pretraining.
  93. `attention_factor` (`float`, *optional*):
  94. Used with 'yarn' and 'longrope'. The scaling factor to be applied on the attention
  95. computation.
  96. If unspecified, it defaults to value recommended by the implementation, using the `factor` field to infer the suggested value.
  97. `beta_fast` (`float`, *optional*):
  98. Only used with 'yarn'. Parameter to set the boundary for extrapolation (only) in the linear
  99. ramp function. If unspecified, it defaults to 32.
  100. `beta_slow` (`float`, *optional*):
  101. Only used with 'yarn'. Parameter to set the boundary for interpolation (only) in the linear
  102. ramp function. If unspecified, it defaults to 1.
  103. `short_factor` (`List[float]`, *optional*):
  104. Only used with 'longrope'. The scaling factor to be applied to short contexts (<`original_max_position_embeddings`).
  105. Must be a list of numbers with the same length as the hidden size divided by the number of attention heads divided by 2
  106. `long_factor` (`List[float]`, *optional*):
  107. Only used with 'longrope'. The scaling factor to be applied to long contexts (<`original_max_position_embeddings`).
  108. Must be a list of numbers with the same length as the hidden size divided by the number of attention heads divided by 2
  109. `low_freq_factor` (`float`, *optional*):
  110. Only used with 'llama3'. Scaling factor applied to low frequency components of the RoPE
  111. `high_freq_factor` (`float`, *optional*):
  112. Only used with 'llama3'. Scaling factor applied to high frequency components of the RoPE
  113. num_attention_heads (`int`, *optional*, defaults to 8):
  114. Number of attention heads for each attention layer in the Transformer decoder.
  115. num_key_value_heads (`int`, *optional*):
  116. This is the number of key_value heads that should be used to implement Grouped Query Attention.
  117. If `num_key_value_heads=num_attention_heads`, the model will use Multi Head Attention (MHA), if
  118. `num_key_value_heads=1` the model will use Multi Query Attention (MQA) otherwise GQA is used.
  119. When converting a multi-head checkpoint to a GQA checkpoint, each group key and value head should be constructed by meanpooling all the original heads within that group.
  120. For more details checkout [this paper](https://huggingface.co/papers/2305.13245).
  121. If it is not specified, will default to `num_attention_heads`.
  122. attention_bias (`bool`, defaults to `False`, *optional*, defaults to `False`):
  123. Whether to use a bias in the query, key, value and output projection layers during self-attention.
  124. attention_dropout (`float`, *optional*, defaults to 0.0):
  125. The dropout ratio for the attention probabilities.
  126. mlp_bias (`bool`, *optional*, defaults to `False`):
  127. Whether to use a bias in up_proj, down_proj and gate_proj layers in the MLP layers.
  128. sliding_window (`int`, *optional*):
  129. Sliding window attention window size. If not specified, will default to `None`.
  130. keep_window_size (`int`, *optional*, defaults to 2048):
  131. The window size of tokens that are not dynamically masked, and dynamic masking is only performed when the sequence length exceeds this value.
  132. is_moe (`bool`, *optional*, defaults to `False`):
  133. Whether to use the Cross Domain Mixture of Experts, if `True`, the MoE will inherit the MLP to initialize.
  134. num_experts (`int`, *optional*, defaults to 16384):
  135. Number of routed experts in the model. This is only used when `is_moe=True`.
  136. num_experts_per_tok (`int`, *optional*, defaults to 64):
  137. Number of selected experts to route per-token.
  138. norm_topk_prob (`bool`, *optional*, defaults to `False`):
  139. Whether to normalize the topk probabilities.
  140. output_router_logits (`bool`, *optional*, defaults to `False`):
  141. Whether or not the router logits should be returned by the model. Enabling this will also
  142. allow the model to output the auxiliary loss, including load balancing loss and router z-loss.
  143. router_aux_loss_coef (`float`, *optional*, defaults to 0.001):
  144. The aux loss factor for the total loss.
  145. ```python
  146. >>> from transformers import DogeConfig, DogeModel
  147. >>> # Initializing a Doge-320M style configuration
  148. >>> configuration = DogeConfig()
  149. >>> # Initializing a model from the Doge-320M style configuration
  150. >>> model = DogeModel(configuration)
  151. >>> # Accessing the model configuration
  152. >>> configuration = model.config
  153. ```"""
  154. model_type = "doge"
  155. keys_to_ignore_at_inference = ["past_key_values"]
  156. # Default tensor parallel plan for base model `DogeModel`
  157. base_model_tp_plan = {
  158. "layers.*.self_attn.q_proj": "colwise",
  159. "layers.*.self_attn.k_proj": "colwise",
  160. "layers.*.self_attn.v_proj": "colwise",
  161. "layers.*.self_attn.dt_proj": "rowwise",
  162. "layers.*.self_attn.o_proj": "rowwise",
  163. "layers.*.input_layernorm.weight": "sequence_parallel",
  164. "layers.*.input_residual.weight": "sequence_parallel",
  165. "layers.*.post_attention_layernorm.weight": "sequence_parallel",
  166. "layers.*.post_attention_residual.weight": "sequence_parallel",
  167. "norm.weight": "sequence_parallel",
  168. "layers.*.mlp.gate_proj": "colwise",
  169. "layers.*.mlp.up_proj": "colwise",
  170. "layers.*.mlp.down_proj": "rowwise",
  171. "layers.*.mlp.router_gate": "colwise_rep",
  172. "layers.*.mlp.down_embed": "rowwise_rep",
  173. "layers.*.mlp.up_embed": "rowwise_rep",
  174. }
  175. base_model_pp_plan = {
  176. "embed_tokens": (["input_ids"], ["inputs_embeds"]),
  177. "layers": (["hidden_states", "attention_mask"], ["hidden_states"]),
  178. "norm": (["hidden_states"], ["hidden_states"]),
  179. }
  180. def __init__(
  181. self,
  182. vocab_size=32768,
  183. hidden_size=1024,
  184. intermediate_size=2048,
  185. num_hidden_layers=32,
  186. hidden_dropout=0.0,
  187. hidden_act="silu",
  188. initializer_range=0.02,
  189. rms_norm_eps=1e-06,
  190. use_cache=True,
  191. tie_word_embeddings=False,
  192. max_position_embeddings=2048,
  193. rope_theta=10000.0,
  194. rope_scaling=None,
  195. num_attention_heads=8,
  196. num_key_value_heads=None,
  197. attention_bias=False,
  198. attention_dropout=0.0,
  199. mlp_bias=False,
  200. sliding_window=None,
  201. keep_window_size=2048,
  202. is_moe=False,
  203. num_experts=16384,
  204. num_experts_per_tok=64,
  205. norm_topk_prob=False,
  206. output_router_logits=False,
  207. router_aux_loss_coef=0.001,
  208. **kwargs,
  209. ):
  210. self.vocab_size = vocab_size
  211. self.hidden_size = hidden_size
  212. self.intermediate_size = intermediate_size
  213. self.num_hidden_layers = num_hidden_layers
  214. self.hidden_dropout = hidden_dropout
  215. self.hidden_act = hidden_act
  216. self.initializer_range = initializer_range
  217. self.rms_norm_eps = rms_norm_eps
  218. self.use_cache = use_cache
  219. self.max_position_embeddings = max_position_embeddings
  220. self.rope_theta = rope_theta
  221. self.rope_scaling = rope_scaling
  222. self.num_attention_heads = num_attention_heads
  223. self.num_key_value_heads = num_key_value_heads
  224. self.attention_bias = attention_bias
  225. self.attention_dropout = attention_dropout
  226. self.mlp_bias = mlp_bias
  227. self.sliding_window = sliding_window
  228. self.keep_window_size = keep_window_size
  229. self.is_moe = is_moe
  230. self.num_experts = num_experts
  231. self.num_experts_per_tok = num_experts_per_tok
  232. self.norm_topk_prob = norm_topk_prob
  233. self.output_router_logits = output_router_logits
  234. self.router_aux_loss_coef = router_aux_loss_coef
  235. # Validate the correctness of rotary position embeddings parameters
  236. # BC: if there is a 'type' field, copy it it to 'rope_type'.
  237. if self.rope_scaling is not None and "type" in self.rope_scaling:
  238. self.rope_scaling["rope_type"] = self.rope_scaling["type"]
  239. rope_config_validation(self)
  240. # for backward compatibility
  241. if num_key_value_heads is None:
  242. self.num_key_value_heads = num_attention_heads
  243. super().__init__(
  244. tie_word_embeddings=tie_word_embeddings,
  245. **kwargs,
  246. )
  247. class DogeRMSNorm(LlamaRMSNorm):
  248. pass
  249. class DogeRotaryEmbedding(LlamaRotaryEmbedding):
  250. pass
  251. def flex_attention_forward(
  252. module: nn.Module,
  253. query: torch.Tensor,
  254. key: torch.Tensor,
  255. value: torch.Tensor,
  256. attention_mask: Union[torch.Tensor, "BlockMask"],
  257. scaling: Optional[float] = None,
  258. softcap: Optional[float] = None,
  259. head_mask: Optional[torch.Tensor] = None,
  260. **kwargs,
  261. ) -> tuple[torch.Tensor, torch.Tensor]:
  262. block_mask = None
  263. causal_mask = None
  264. if isinstance(attention_mask, BlockMask):
  265. block_mask = attention_mask
  266. else:
  267. causal_mask = attention_mask
  268. if causal_mask is not None:
  269. causal_mask = causal_mask[:, :, :, : key.shape[-2]]
  270. def score_mod(score, batch_idx, head_idx, q_idx, kv_idx):
  271. if softcap is not None:
  272. score = softcap * torch.tanh(score / softcap)
  273. if causal_mask is not None:
  274. score = score + causal_mask[batch_idx][head_idx][q_idx][kv_idx]
  275. if head_mask is not None:
  276. score = score + head_mask[batch_idx][head_idx][0][0]
  277. return score
  278. attn_output, attention_weights = compile_friendly_flex_attention(
  279. query,
  280. key,
  281. value,
  282. score_mod=score_mod,
  283. block_mask=block_mask,
  284. enable_gqa=True,
  285. scale=scaling,
  286. # Last time checked on PyTorch == 2.5.1: Flex Attention always computes the lse regardless.
  287. # For simplification, we thus always return it as no additional computations are introduced.
  288. return_lse=True,
  289. )
  290. # lse is returned in float32
  291. attention_weights = attention_weights.to(value.dtype)
  292. attn_output = attn_output.transpose(1, 2).contiguous()
  293. return attn_output, attention_weights
  294. ALL_ATTENTION_FUNCTIONS = AttentionInterface()
  295. ALL_ATTENTION_FUNCTIONS["doge_flex_attention"] = flex_attention_forward
  296. class DogeAttention(nn.Module):
  297. def __init__(self, config: DogeConfig, layer_idx: Optional[int] = None):
  298. super().__init__()
  299. self.config = config
  300. self.layer_idx = layer_idx
  301. self.head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads)
  302. self.num_key_value_groups = config.num_attention_heads // config.num_key_value_heads
  303. self.scaling = self.head_dim**-0.5
  304. self.attention_dropout = config.attention_dropout
  305. self.keep_window_size = config.keep_window_size
  306. self.q_proj = nn.Linear(
  307. config.hidden_size, config.num_attention_heads * self.head_dim, bias=config.attention_bias
  308. )
  309. self.k_proj = nn.Linear(
  310. config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.attention_bias
  311. )
  312. self.v_proj = nn.Linear(
  313. config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.attention_bias
  314. )
  315. # dynamic mask for the QK^T attention weights matrix
  316. self.A = nn.Parameter(torch.zeros(config.num_key_value_heads))
  317. self.dt_proj = nn.Linear(
  318. config.num_key_value_heads * self.head_dim, config.num_key_value_heads, bias=config.attention_bias
  319. )
  320. self.o_proj = nn.Linear(
  321. config.num_attention_heads * self.head_dim, config.hidden_size, bias=config.attention_bias
  322. )
  323. self.q_norm = DogeRMSNorm(self.head_dim, eps=config.rms_norm_eps)
  324. self.k_norm = DogeRMSNorm(self.head_dim, eps=config.rms_norm_eps)
  325. @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58")
  326. def forward(
  327. self,
  328. hidden_states: torch.Tensor,
  329. position_embeddings: tuple[torch.Tensor, torch.Tensor],
  330. attention_mask: Optional[torch.Tensor] = None,
  331. past_key_values: Optional[Cache] = None,
  332. cache_position: Optional[torch.LongTensor] = None,
  333. **kwargs,
  334. ) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[tuple[torch.Tensor]]]:
  335. input_shape = hidden_states.shape[:-1]
  336. hidden_shape = (*input_shape, -1, self.head_dim)
  337. query_states = self.q_norm(self.q_proj(hidden_states).view(hidden_shape)).transpose(1, 2)
  338. key_states = self.k_norm(self.k_proj(hidden_states).view(hidden_shape)).transpose(1, 2)
  339. value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2)
  340. cos, sin = position_embeddings
  341. query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
  342. if past_key_values is not None:
  343. # sin and cos are specific to RoPE models; cache_position needed for the static cache
  344. cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
  345. key_states, value_states = past_key_values.update(key_states, value_states, self.layer_idx, cache_kwargs)
  346. # calculate dynamic mask from value_states
  347. dt_states = self.dt_proj(
  348. value_states.transpose(1, 2).reshape(value_states.shape[0], value_states.shape[-2], -1)
  349. )
  350. dt_states = torch.exp(self.A * F.softplus(dt_states)).transpose(-1, -2)
  351. attn_mask = self.prepare_dynamic_mask(
  352. hidden_states=hidden_states,
  353. dt_states=dt_states,
  354. keep_window_size=self.keep_window_size,
  355. attention_mask=attention_mask,
  356. )
  357. attn_mask = repeat_kv(attn_mask, self.num_key_value_groups)
  358. attention_interface: Callable = eager_attention_forward
  359. if self.config._attn_implementation != "eager":
  360. attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]
  361. attn_output, attn_weights = attention_interface(
  362. self,
  363. query_states,
  364. key_states,
  365. value_states,
  366. attention_mask=attn_mask,
  367. dropout=0.0 if not self.training else self.attention_dropout,
  368. scaling=self.scaling,
  369. **kwargs,
  370. )
  371. attn_output = attn_output.reshape(*input_shape, -1).contiguous()
  372. attn_output = self.o_proj(attn_output)
  373. return attn_output, attn_weights
  374. def prepare_dynamic_mask(
  375. self,
  376. hidden_states: torch.Tensor,
  377. dt_states: torch.Tensor,
  378. keep_window_size: int = 2048,
  379. attention_mask: Optional[torch.Tensor] = None,
  380. ):
  381. """
  382. The core idea of DMA is to calculate the dynamic attention mask to mask the tokens that should be masked, so as to form sparse attention.
  383. Combine `dt_states` with `attention_mask` to generate the final `attn_mask`.
  384. Args:
  385. hidden_states (`torch.Tensor`): The input hidden_states, used to determine the minimum value of the current input precision.
  386. dt_states (`torch.Tensor`): dt_states of shape `(batch_size, num_heads, key_sequence_length)`.
  387. keep_window_size (`int`): The window size of tokens that are not dynamically masked, and dynamic masking is only performed when the sequence length exceeds this value.
  388. attention_mask (`torch.Tensor`, *optional*): attention mask of shape `(batch_size, 1, query_sequence_length, key_sequence_length)`.
  389. """
  390. min_dtype = torch.finfo(hidden_states.dtype).min
  391. dtype = hidden_states.dtype
  392. attn_mask = dt_states[:, :, None, :].expand(
  393. -1, -1, hidden_states.shape[1], -1
  394. ) # [batch_size, num_heads, query_len, key_len]
  395. if attention_mask is not None and not isinstance(attention_mask, BlockMask):
  396. if attention_mask.dtype == torch.bool:
  397. dtype = hidden_states.dtype
  398. attention_mask = torch.where(
  399. attention_mask, torch.tensor(0.0, device=attention_mask.device, dtype=dtype), min_dtype
  400. )
  401. attn_mask = attn_mask.masked_fill(attention_mask[:, :, :, : attn_mask.shape[-1]] != 0, min_dtype)
  402. if attn_mask.shape[-1] > keep_window_size:
  403. active_mask = torch.zeros_like(attn_mask, dtype=dtype, device=attn_mask.device)
  404. topk_indices = torch.topk(attn_mask, keep_window_size, dim=-1, largest=True, sorted=False).indices
  405. active_mask = active_mask.scatter(-1, topk_indices, 1.0)
  406. attn_mask = attn_mask.masked_fill(active_mask == 0.0, min_dtype)
  407. return attn_mask
  408. class DogeMLP(LlamaMLP):
  409. pass
  410. class DogeCDMoE(nn.Module):
  411. def __init__(self, config: DogeConfig):
  412. super().__init__()
  413. self.hidden_size = config.hidden_size
  414. self.intermediate_size = config.intermediate_size
  415. self.act_fn = ACT2FN[config.hidden_act]
  416. self.num_experts = config.num_experts
  417. self.num_keys = math.floor(math.sqrt(self.num_experts))
  418. self.top_k = config.num_experts_per_tok
  419. self.norm_topk_prob = config.norm_topk_prob
  420. # shared expert
  421. self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=config.mlp_bias)
  422. self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=config.mlp_bias)
  423. self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=config.mlp_bias)
  424. # router gate for retrieval experts
  425. self.router_gate = nn.Linear(self.hidden_size, self.num_keys * 2, bias=False)
  426. # routed experts
  427. self.down_embed = nn.Embedding(self.num_experts, self.hidden_size)
  428. self.up_embed = nn.Embedding(self.num_experts, self.hidden_size)
  429. def forward(
  430. self,
  431. hidden_states: torch.Tensor,
  432. **kwargs,
  433. ) -> torch.Tensor:
  434. bsz, seq_len, _ = hidden_states.shape
  435. # get routing logits with router gate
  436. router_logits = self.router_gate(hidden_states).view(2, bsz * seq_len, -1)
  437. # get experts with the highest routing logits
  438. (scores_x, scores_y), (indices_x, indices_y) = router_logits.topk(self.num_keys, dim=-1)
  439. all_scores = scores_x.unsqueeze(-1) + scores_y.unsqueeze(-2)
  440. all_indices = indices_x.unsqueeze(-1) * self.num_keys + indices_y.unsqueeze(-2)
  441. all_scores = all_scores.view(*all_scores.shape[:-2], -1)
  442. all_indices = all_indices.view(*all_indices.shape[:-2], -1)
  443. scores, position_indices = all_scores.topk(self.top_k, dim=-1)
  444. indices = all_indices.gather(-1, position_indices)
  445. routing_weights = F.softmax(scores, dim=-1)
  446. if self.norm_topk_prob:
  447. routing_weights /= routing_weights.sum(dim=-1, keepdim=True)
  448. # mix routed experts states with shared expert states
  449. down_embed = self.down_embed(indices)
  450. up_embed = self.up_embed(indices)
  451. experts_weights = torch.matmul(down_embed, hidden_states.view(bsz * seq_len, -1, 1)).view(bsz * seq_len, -1)
  452. experts_weights = self.act_fn(experts_weights) * routing_weights
  453. experts_states = torch.matmul(experts_weights.view(bsz * seq_len, 1, -1), up_embed).view(bsz, seq_len, -1)
  454. hidden_states = self.down_proj(self.act_fn(self.gate_proj(hidden_states)) * self.up_proj(hidden_states))
  455. hidden_states = hidden_states + experts_states
  456. return hidden_states, router_logits
  457. class DogeDecoderLayer(GradientCheckpointingLayer):
  458. def __init__(self, config: DogeConfig, layer_idx: Optional[int] = None):
  459. super().__init__()
  460. self.hidden_dropout = config.hidden_dropout
  461. self.input_layernorm = DogeRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
  462. self.self_attn = DogeAttention(config=config, layer_idx=layer_idx)
  463. self.input_residual = nn.Parameter(torch.ones(config.hidden_size))
  464. self.post_attention_layernorm = DogeRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
  465. self.mlp = DogeMLP(config) if not config.is_moe else DogeCDMoE(config)
  466. self.post_attention_residual = nn.Parameter(torch.ones(config.hidden_size))
  467. @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58")
  468. def forward(
  469. self,
  470. hidden_states: torch.Tensor,
  471. position_embeddings: tuple[torch.Tensor, torch.Tensor],
  472. attention_mask: Optional[torch.Tensor] = None,
  473. position_ids: Optional[torch.LongTensor] = None,
  474. past_key_values: Optional[Cache] = None,
  475. use_cache: Optional[bool] = False,
  476. cache_position: Optional[torch.LongTensor] = None,
  477. **kwargs: Unpack[TransformersKwargs],
  478. ) -> tuple[torch.FloatTensor, Optional[tuple[torch.FloatTensor, torch.FloatTensor]]]:
  479. # sequence transformation
  480. residual = hidden_states
  481. hidden_states = self.input_layernorm(hidden_states)
  482. hidden_states, self_attn_weights = self.self_attn(
  483. hidden_states=hidden_states,
  484. position_embeddings=position_embeddings,
  485. attention_mask=attention_mask,
  486. position_ids=position_ids,
  487. past_key_values=past_key_values,
  488. use_cache=use_cache,
  489. cache_position=cache_position,
  490. **kwargs,
  491. )
  492. hidden_states = F.dropout(hidden_states, p=self.hidden_dropout, training=self.training)
  493. hidden_states = self.input_residual * residual + hidden_states
  494. # state transformation
  495. residual = hidden_states
  496. hidden_states = self.post_attention_layernorm(hidden_states)
  497. hidden_states = self.mlp(hidden_states)
  498. hidden_states = F.dropout(hidden_states, p=self.hidden_dropout, training=self.training)
  499. hidden_states = self.post_attention_residual * residual + hidden_states
  500. return hidden_states
  501. class DogePreTrainedModel(LlamaPreTrainedModel):
  502. _supports_flash_attn = False
  503. _can_compile_fullgraph = False
  504. _can_record_outputs = {
  505. "router_logits": OutputRecorder(DogeCDMoE, index=1),
  506. "hidden_states": DogeDecoderLayer,
  507. "attentions": DogeAttention,
  508. }
  509. def _init_weights(self, module):
  510. """Initialize the weights"""
  511. PreTrainedModel._init_weights(self, module)
  512. if isinstance(module, DogeAttention):
  513. if hasattr(module, "A"):
  514. module.A.data.zero_()
  515. elif isinstance(module, DogeDecoderLayer):
  516. if hasattr(module, "input_residual"):
  517. module.input_residual.data.fill_(1.0)
  518. if hasattr(module, "post_attention_residual"):
  519. module.post_attention_residual.data.fill_(1.0)
  520. class DogeModel(MixtralModel):
  521. pass
  522. def load_balancing_loss_func(
  523. gate_logits: Union[torch.Tensor, tuple[torch.Tensor], None],
  524. num_experts: Optional[int] = None,
  525. num_keys: Optional[int] = None,
  526. top_k: int = 2,
  527. attention_mask: Optional[torch.Tensor] = None,
  528. ) -> Union[torch.Tensor, int]:
  529. r"""
  530. Computes auxiliary load balancing loss as in Switch Transformer - implemented in Pytorch.
  531. See Switch Transformer (https://huggingface.co/papers/2101.03961) for more details. This function implements the loss
  532. function presented in equations (4) - (6) of the paper. It aims at penalizing cases where the routing between
  533. experts is too unbalanced.
  534. Args:
  535. gate_logits:
  536. Logits from the `router_gate`, should be a tuple of model.config.num_hidden_layers tensors of
  537. shape [2, batch_size * sequence_length, num_keys].
  538. num_experts:
  539. Number of experts
  540. num_keys:
  541. Number of keys
  542. top_k:
  543. The number of experts to route per-token, can be also interpreted as the `top-k` routing
  544. parameter.
  545. attention_mask (`torch.Tensor`, *optional*):
  546. The attention_mask used in forward function
  547. shape [batch_size X sequence_length] if not None.
  548. Returns:
  549. The auxiliary loss.
  550. """
  551. if gate_logits is None or not isinstance(gate_logits, tuple):
  552. return 0
  553. compute_dtype = gate_logits[0].dtype
  554. compute_device = gate_logits[0].device
  555. all_expert_indices = []
  556. all_routing_weights = []
  557. for layer_gate_logits in gate_logits:
  558. layer_gate_logits = layer_gate_logits.to(compute_device)
  559. (scores_x, scores_y), (indices_x, indices_y) = layer_gate_logits.topk(num_keys, dim=-1)
  560. all_scores = scores_x.unsqueeze(-1) + scores_y.unsqueeze(-2)
  561. all_indices = indices_x.unsqueeze(-1) * num_keys + indices_y.unsqueeze(-2)
  562. all_scores = all_scores.view(*all_scores.shape[:-2], -1)
  563. all_indices = all_indices.view(*all_indices.shape[:-2], -1)
  564. _, position_indices = all_scores.topk(top_k, dim=-1)
  565. expert_indices = all_indices.gather(-1, position_indices)
  566. routing_weights = F.softmax(all_scores, dim=-1)
  567. all_expert_indices.append(expert_indices)
  568. all_routing_weights.append(routing_weights)
  569. all_expert_indices = torch.cat(all_expert_indices, dim=0)
  570. all_routing_weights = torch.cat(all_routing_weights, dim=0)
  571. if attention_mask is None:
  572. # Compute the percentage of tokens routed to each experts
  573. all_expert_indices = all_expert_indices.view(-1)
  574. tokens_per_expert = torch.zeros(num_experts, dtype=compute_dtype, device=compute_device)
  575. pad = torch.ones_like(all_expert_indices, dtype=compute_dtype, device=compute_device)
  576. tokens_per_expert = tokens_per_expert.scatter_add_(0, all_expert_indices, pad) / all_expert_indices.shape[0]
  577. # Compute the average probability of routing to these experts
  578. router_prob_per_expert = torch.mean(all_routing_weights, dim=0)
  579. else:
  580. batch_size, sequence_length = attention_mask.shape
  581. num_hidden_layers = len(gate_logits)
  582. # Compute the mask that masks all padding tokens as 0 with the same shape of expert_mask
  583. expert_attention_mask = (
  584. attention_mask[None, :, :, None]
  585. .expand((num_hidden_layers, batch_size, sequence_length, top_k))
  586. .reshape(-1)
  587. .to(compute_device)
  588. )
  589. all_expert_indices = all_expert_indices.view(-1)[expert_attention_mask.bool()]
  590. # Compute the percentage of tokens routed to each experts
  591. tokens_per_expert = torch.zeros(num_experts, dtype=compute_dtype, device=compute_device)
  592. pad = torch.ones_like(all_expert_indices, dtype=compute_dtype, device=compute_device)
  593. tokens_per_expert = tokens_per_expert.scatter_add_(0, all_expert_indices, pad) / torch.sum(
  594. expert_attention_mask
  595. )
  596. # Compute the mask that masks all padding tokens as 0 with the same shape of tokens_per_expert
  597. router_per_expert_attention_mask = (
  598. attention_mask[None, :, :, None]
  599. .expand((num_hidden_layers, batch_size, sequence_length, num_experts))
  600. .reshape(-1, num_experts)
  601. .to(compute_device)
  602. )
  603. # Compute the average probability of routing to these experts
  604. router_prob_per_expert = torch.sum(all_routing_weights * router_per_expert_attention_mask, dim=0) / torch.sum(
  605. router_per_expert_attention_mask, dim=0
  606. )
  607. overall_loss = torch.sum(tokens_per_expert * router_prob_per_expert)
  608. return overall_loss * num_experts
  609. class DogeForCausalLM(MixtralForCausalLM):
  610. def __init__(self, config):
  611. super().__init__(config)
  612. self.model = DogeModel(config)
  613. self.num_experts = config.num_experts
  614. def forward(
  615. self,
  616. input_ids: Optional[torch.LongTensor] = None,
  617. attention_mask: Optional[torch.Tensor] = None,
  618. position_ids: Optional[torch.LongTensor] = None,
  619. past_key_values: Optional[Cache] = None,
  620. inputs_embeds: Optional[torch.FloatTensor] = None,
  621. labels: Optional[torch.LongTensor] = None,
  622. use_cache: Optional[bool] = None,
  623. cache_position: Optional[torch.LongTensor] = None,
  624. logits_to_keep: Union[int, torch.Tensor] = 0,
  625. output_router_logits: Optional[bool] = None,
  626. **kwargs: Unpack[TransformersKwargs],
  627. ) -> MoeCausalLMOutputWithPast:
  628. r"""
  629. labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
  630. Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
  631. config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
  632. (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
  633. Example:
  634. ```python
  635. >>> from transformers import AutoTokenizer, DogeForCausalLM
  636. >>> model = DogeForCausalLM.from_pretrained("SmallDoge/Doge-320M")
  637. >>> tokenizer = AutoTokenizer.from_pretrained("SmallDoge/Doge-320M")
  638. >>> prompt = "Hey, are you conscious? Can you talk to me?"
  639. >>> inputs = tokenizer(prompt, return_tensors="pt")
  640. >>> # Generate
  641. >>> generate_ids = model.generate(inputs.input_ids, max_length=30)
  642. >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
  643. "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you."
  644. ```"""
  645. output_router_logits = (
  646. output_router_logits if output_router_logits is not None else self.config.output_router_logits
  647. )
  648. # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
  649. outputs: MoeModelOutputWithPast = self.model(
  650. input_ids=input_ids,
  651. attention_mask=attention_mask,
  652. position_ids=position_ids,
  653. past_key_values=past_key_values,
  654. inputs_embeds=inputs_embeds,
  655. use_cache=use_cache,
  656. cache_position=cache_position,
  657. **kwargs,
  658. )
  659. hidden_states = outputs.last_hidden_state
  660. # Only compute necessary logits, and do not upcast them to float if we are not computing the loss
  661. slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
  662. logits = self.lm_head(hidden_states[:, slice_indices, :])
  663. loss = None
  664. if labels is not None:
  665. loss = self.loss_function(logits, labels, self.vocab_size, **kwargs)
  666. aux_loss = None
  667. if output_router_logits:
  668. aux_loss = load_balancing_loss_func(
  669. outputs.router_logits,
  670. self.num_experts,
  671. math.floor(math.sqrt(self.num_experts)),
  672. self.num_experts_per_tok,
  673. attention_mask,
  674. )
  675. if labels is not None:
  676. loss += self.router_aux_loss_coef * aux_loss.to(loss.device) # make sure to reside in the same device
  677. return MoeCausalLMOutputWithPast(
  678. loss=loss,
  679. aux_loss=aux_loss,
  680. logits=logits,
  681. past_key_values=outputs.past_key_values,
  682. hidden_states=outputs.hidden_states,
  683. attentions=outputs.attentions,
  684. router_logits=outputs.router_logits,
  685. )
  686. class DogeForSequenceClassification(LlamaForSequenceClassification):
  687. pass
  688. __all__ = [
  689. "DogeConfig",
  690. "DogeForCausalLM",
  691. "DogeModel",
  692. "DogePreTrainedModel",
  693. "DogeForSequenceClassification",
  694. ]