modeling_dots1.py 27 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612
  1. # 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
  2. # This file was automatically generated from src/transformers/models/dots1/modular_dots1.py.
  3. # Do NOT edit this file manually as any edits will be overwritten by the generation of
  4. # the file from the modular. If any change should be done, please apply the change to the
  5. # modular_dots1.py file directly. One of our CI enforces this.
  6. # 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
  7. # coding=utf-8
  8. # Copyright 2025 The rednote-hilab team and the HuggingFace Inc. team. All rights reserved.
  9. #
  10. # Licensed under the Apache License, Version 2.0 (the "License");
  11. # you may not use this file except in compliance with the License.
  12. # You may obtain a copy of the License at
  13. #
  14. # http://www.apache.org/licenses/LICENSE-2.0
  15. #
  16. # Unless required by applicable law or agreed to in writing, software
  17. # distributed under the License is distributed on an "AS IS" BASIS,
  18. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  19. # See the License for the specific language governing permissions and
  20. # limitations under the License.
  21. from typing import Callable, Optional, Union
  22. import torch
  23. import torch.nn.functional as F
  24. from torch import nn
  25. from ...activations import ACT2FN
  26. from ...cache_utils import Cache, DynamicCache
  27. from ...generation import GenerationMixin
  28. from ...integrations import use_kernel_forward_from_hub
  29. from ...masking_utils import create_causal_mask, create_sliding_window_causal_mask
  30. from ...modeling_flash_attention_utils import FlashAttentionKwargs
  31. from ...modeling_layers import GradientCheckpointingLayer
  32. from ...modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
  33. from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update
  34. from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
  35. from ...processing_utils import Unpack
  36. from ...utils import TransformersKwargs, auto_docstring, can_return_tuple
  37. from ...utils.deprecation import deprecate_kwarg
  38. from ...utils.generic import check_model_inputs
  39. from .configuration_dots1 import Dots1Config
  40. @use_kernel_forward_from_hub("RMSNorm")
  41. class Dots1RMSNorm(nn.Module):
  42. def __init__(self, hidden_size, eps: float = 1e-6) -> None:
  43. """
  44. Dots1RMSNorm is equivalent to T5LayerNorm
  45. """
  46. super().__init__()
  47. self.weight = nn.Parameter(torch.ones(hidden_size))
  48. self.variance_epsilon = eps
  49. def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
  50. input_dtype = hidden_states.dtype
  51. hidden_states = hidden_states.to(torch.float32)
  52. variance = hidden_states.pow(2).mean(-1, keepdim=True)
  53. hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
  54. return self.weight * hidden_states.to(input_dtype)
  55. def extra_repr(self):
  56. return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}"
  57. class Dots1RotaryEmbedding(nn.Module):
  58. inv_freq: torch.Tensor # fix linting for `register_buffer`
  59. def __init__(self, config: Dots1Config, device=None):
  60. super().__init__()
  61. # BC: "rope_type" was originally "type"
  62. if hasattr(config, "rope_scaling") and isinstance(config.rope_scaling, dict):
  63. self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type"))
  64. else:
  65. self.rope_type = "default"
  66. self.max_seq_len_cached = config.max_position_embeddings
  67. self.original_max_seq_len = config.max_position_embeddings
  68. self.config = config
  69. self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]
  70. inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device)
  71. self.register_buffer("inv_freq", inv_freq, persistent=False)
  72. self.original_inv_freq = self.inv_freq
  73. @torch.no_grad()
  74. @dynamic_rope_update # power user: used with advanced RoPE types (e.g. dynamic rope)
  75. def forward(self, x, position_ids):
  76. inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1).to(x.device)
  77. position_ids_expanded = position_ids[:, None, :].float()
  78. device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu"
  79. with torch.autocast(device_type=device_type, enabled=False): # Force float32
  80. freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
  81. emb = torch.cat((freqs, freqs), dim=-1)
  82. cos = emb.cos() * self.attention_scaling
  83. sin = emb.sin() * self.attention_scaling
  84. return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
  85. def rotate_half(x):
  86. """Rotates half the hidden dims of the input."""
  87. x1 = x[..., : x.shape[-1] // 2]
  88. x2 = x[..., x.shape[-1] // 2 :]
  89. return torch.cat((-x2, x1), dim=-1)
  90. def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
  91. """Applies Rotary Position Embedding to the query and key tensors.
  92. Args:
  93. q (`torch.Tensor`): The query tensor.
  94. k (`torch.Tensor`): The key tensor.
  95. cos (`torch.Tensor`): The cosine part of the rotary embedding.
  96. sin (`torch.Tensor`): The sine part of the rotary embedding.
  97. position_ids (`torch.Tensor`, *optional*):
  98. Deprecated and unused.
  99. unsqueeze_dim (`int`, *optional*, defaults to 1):
  100. The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and
  101. sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note
  102. that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and
  103. k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes
  104. cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have
  105. the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2.
  106. Returns:
  107. `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding.
  108. """
  109. cos = cos.unsqueeze(unsqueeze_dim)
  110. sin = sin.unsqueeze(unsqueeze_dim)
  111. q_embed = (q * cos) + (rotate_half(q) * sin)
  112. k_embed = (k * cos) + (rotate_half(k) * sin)
  113. return q_embed, k_embed
  114. def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
  115. """
  116. This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch,
  117. num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim)
  118. """
  119. batch, num_key_value_heads, slen, head_dim = hidden_states.shape
  120. if n_rep == 1:
  121. return hidden_states
  122. hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim)
  123. return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
  124. def eager_attention_forward(
  125. module: nn.Module,
  126. query: torch.Tensor,
  127. key: torch.Tensor,
  128. value: torch.Tensor,
  129. attention_mask: Optional[torch.Tensor],
  130. scaling: float,
  131. dropout: float = 0.0,
  132. **kwargs: Unpack[TransformersKwargs],
  133. ):
  134. key_states = repeat_kv(key, module.num_key_value_groups)
  135. value_states = repeat_kv(value, module.num_key_value_groups)
  136. attn_weights = torch.matmul(query, key_states.transpose(2, 3)) * scaling
  137. if attention_mask is not None:
  138. causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]
  139. attn_weights = attn_weights + causal_mask
  140. attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype)
  141. attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training)
  142. attn_output = torch.matmul(attn_weights, value_states)
  143. attn_output = attn_output.transpose(1, 2).contiguous()
  144. return attn_output, attn_weights
  145. class Dots1Attention(nn.Module):
  146. """Multi-headed attention from 'Attention Is All You Need' paper"""
  147. def __init__(self, config: Dots1Config, layer_idx: int):
  148. super().__init__()
  149. self.config = config
  150. self.layer_idx = layer_idx
  151. self.head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads)
  152. self.num_key_value_groups = config.num_attention_heads // config.num_key_value_heads
  153. self.scaling = self.head_dim**-0.5
  154. self.attention_dropout = config.attention_dropout
  155. self.is_causal = True
  156. self.q_proj = nn.Linear(
  157. config.hidden_size, config.num_attention_heads * self.head_dim, bias=config.attention_bias
  158. )
  159. self.k_proj = nn.Linear(
  160. config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.attention_bias
  161. )
  162. self.v_proj = nn.Linear(
  163. config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.attention_bias
  164. )
  165. self.o_proj = nn.Linear(
  166. config.num_attention_heads * self.head_dim, config.hidden_size, bias=config.attention_bias
  167. )
  168. self.q_norm = Dots1RMSNorm(self.head_dim, eps=config.rms_norm_eps) # unlike olmo, only on the head dim!
  169. self.k_norm = Dots1RMSNorm(self.head_dim, eps=config.rms_norm_eps) # thus post q_norm does not need reshape
  170. self.sliding_window = config.sliding_window if config.layer_types[layer_idx] == "sliding_attention" else None
  171. @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58")
  172. def forward(
  173. self,
  174. hidden_states: torch.Tensor,
  175. position_embeddings: tuple[torch.Tensor, torch.Tensor],
  176. attention_mask: Optional[torch.Tensor],
  177. past_key_values: Optional[Cache] = None,
  178. cache_position: Optional[torch.LongTensor] = None,
  179. **kwargs: Unpack[FlashAttentionKwargs],
  180. ) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
  181. input_shape = hidden_states.shape[:-1]
  182. hidden_shape = (*input_shape, -1, self.head_dim)
  183. query_states = self.q_norm(self.q_proj(hidden_states).view(hidden_shape)).transpose(1, 2)
  184. key_states = self.k_norm(self.k_proj(hidden_states).view(hidden_shape)).transpose(1, 2)
  185. value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2)
  186. cos, sin = position_embeddings
  187. query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
  188. if past_key_values is not None:
  189. # sin and cos are specific to RoPE models; cache_position needed for the static cache
  190. cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
  191. key_states, value_states = past_key_values.update(key_states, value_states, self.layer_idx, cache_kwargs)
  192. attention_interface: Callable = eager_attention_forward
  193. if self.config._attn_implementation != "eager":
  194. attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]
  195. attn_output, attn_weights = attention_interface(
  196. self,
  197. query_states,
  198. key_states,
  199. value_states,
  200. attention_mask,
  201. dropout=0.0 if not self.training else self.attention_dropout,
  202. scaling=self.scaling,
  203. sliding_window=self.sliding_window, # diff with Llama
  204. **kwargs,
  205. )
  206. attn_output = attn_output.reshape(*input_shape, -1).contiguous()
  207. attn_output = self.o_proj(attn_output)
  208. return attn_output, attn_weights
  209. class Dots1MLP(nn.Module):
  210. def __init__(self, config, hidden_size=None, intermediate_size=None):
  211. super().__init__()
  212. self.config = config
  213. self.hidden_size = config.hidden_size if hidden_size is None else hidden_size
  214. self.intermediate_size = config.intermediate_size if intermediate_size is None else intermediate_size
  215. self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
  216. self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
  217. self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False)
  218. self.act_fn = ACT2FN[config.hidden_act]
  219. def forward(self, x):
  220. down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))
  221. return down_proj
  222. class Dots1MoE(nn.Module):
  223. """
  224. A mixed expert module containing shared experts.
  225. """
  226. def __init__(self, config):
  227. super().__init__()
  228. self.config = config
  229. self.experts = nn.ModuleList(
  230. [Dots1MLP(config, intermediate_size=config.moe_intermediate_size) for _ in range(config.n_routed_experts)]
  231. )
  232. self.gate = Dots1TopkRouter(config)
  233. self.shared_experts = Dots1MLP(
  234. config=config, intermediate_size=config.moe_intermediate_size * config.n_shared_experts
  235. )
  236. def moe(self, hidden_states: torch.Tensor, topk_indices: torch.Tensor, topk_weights: torch.Tensor):
  237. r"""
  238. CALL FOR CONTRIBUTION! I don't have time to optimise this right now, but expert weights need to be fused
  239. to not have to do a loop here (deepseek has 256 experts soooo yeah).
  240. """
  241. final_hidden_states = torch.zeros_like(hidden_states, dtype=topk_weights.dtype)
  242. expert_mask = torch.nn.functional.one_hot(topk_indices, num_classes=len(self.experts))
  243. expert_mask = expert_mask.permute(2, 0, 1)
  244. for expert_idx in range(len(self.experts)):
  245. expert = self.experts[expert_idx]
  246. mask = expert_mask[expert_idx]
  247. token_indices, weight_indices = torch.where(mask)
  248. if token_indices.numel() > 0:
  249. expert_weights = topk_weights[token_indices, weight_indices]
  250. expert_input = hidden_states[token_indices]
  251. expert_output = expert(expert_input)
  252. weighted_output = expert_output * expert_weights.unsqueeze(-1)
  253. final_hidden_states.index_add_(0, token_indices, weighted_output)
  254. # in original deepseek, the output of the experts are gathered once we leave this module
  255. # thus the moe module is itelsf an IsolatedParallel module
  256. # and all expert are "local" meaning we shard but we don't gather
  257. return final_hidden_states.type(hidden_states.dtype)
  258. def forward(self, hidden_states):
  259. residuals = hidden_states
  260. orig_shape = hidden_states.shape
  261. topk_indices, topk_weights = self.gate(hidden_states)
  262. hidden_states = hidden_states.view(-1, hidden_states.shape[-1])
  263. hidden_states = self.moe(hidden_states, topk_indices, topk_weights).view(*orig_shape)
  264. hidden_states = hidden_states + self.shared_experts(residuals)
  265. return hidden_states
  266. class Dots1TopkRouter(nn.Module):
  267. def __init__(self, config):
  268. super().__init__()
  269. self.config = config
  270. self.top_k = config.num_experts_per_tok
  271. self.n_routed_experts = config.n_routed_experts
  272. self.routed_scaling_factor = config.routed_scaling_factor
  273. self.n_group = config.n_group
  274. self.topk_group = config.topk_group
  275. self.norm_topk_prob = config.norm_topk_prob
  276. self.weight = nn.Parameter(torch.empty((self.n_routed_experts, config.hidden_size)))
  277. self.register_buffer("e_score_correction_bias", torch.zeros(self.n_routed_experts))
  278. @torch.no_grad()
  279. def get_topk_indices(self, scores):
  280. scores_for_choice = scores.view(-1, self.n_routed_experts) + self.e_score_correction_bias.unsqueeze(0)
  281. group_scores = (
  282. scores_for_choice.view(-1, self.n_group, self.n_routed_experts // self.n_group)
  283. .topk(2, dim=-1)[0]
  284. .sum(dim=-1)
  285. )
  286. group_idx = torch.topk(group_scores, k=self.topk_group, dim=-1, sorted=False)[1]
  287. group_mask = torch.zeros_like(group_scores)
  288. group_mask.scatter_(1, group_idx, 1)
  289. score_mask = (
  290. group_mask.unsqueeze(-1)
  291. .expand(-1, self.n_group, self.n_routed_experts // self.n_group)
  292. .reshape(-1, self.n_routed_experts)
  293. )
  294. scores_for_choice = scores_for_choice.masked_fill(~score_mask.bool(), 0.0)
  295. topk_indices = torch.topk(scores_for_choice, k=self.top_k, dim=-1, sorted=False)[1]
  296. return topk_indices
  297. def forward(self, hidden_states):
  298. hidden_states = hidden_states.view(-1, self.config.hidden_size)
  299. router_logits = F.linear(hidden_states.type(torch.float32), self.weight.type(torch.float32))
  300. scores = router_logits.sigmoid()
  301. topk_indices = self.get_topk_indices(scores)
  302. topk_weights = scores.gather(1, topk_indices)
  303. if self.norm_topk_prob:
  304. denominator = topk_weights.sum(dim=-1, keepdim=True) + 1e-20
  305. topk_weights /= denominator
  306. topk_weights = topk_weights * self.routed_scaling_factor
  307. return topk_indices, topk_weights
  308. class Dots1DecoderLayer(GradientCheckpointingLayer):
  309. def __init__(self, config: Dots1Config, layer_idx: int):
  310. super().__init__()
  311. self.hidden_size = config.hidden_size
  312. self.self_attn = Dots1Attention(config=config, layer_idx=layer_idx)
  313. if layer_idx >= config.first_k_dense_replace:
  314. self.mlp = Dots1MoE(config)
  315. else:
  316. self.mlp = Dots1MLP(config)
  317. self.input_layernorm = Dots1RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
  318. self.post_attention_layernorm = Dots1RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
  319. self.attention_type = config.layer_types[layer_idx]
  320. @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58")
  321. def forward(
  322. self,
  323. hidden_states: torch.Tensor,
  324. attention_mask: Optional[torch.Tensor] = None,
  325. position_ids: Optional[torch.LongTensor] = None,
  326. past_key_values: Optional[Cache] = None,
  327. use_cache: Optional[bool] = False,
  328. cache_position: Optional[torch.LongTensor] = None,
  329. position_embeddings: Optional[tuple[torch.Tensor, torch.Tensor]] = None, # necessary, but kept here for BC
  330. **kwargs: Unpack[TransformersKwargs],
  331. ) -> torch.Tensor:
  332. residual = hidden_states
  333. hidden_states = self.input_layernorm(hidden_states)
  334. # Self Attention
  335. hidden_states, _ = self.self_attn(
  336. hidden_states=hidden_states,
  337. attention_mask=attention_mask,
  338. position_ids=position_ids,
  339. past_key_values=past_key_values,
  340. use_cache=use_cache,
  341. cache_position=cache_position,
  342. position_embeddings=position_embeddings,
  343. **kwargs,
  344. )
  345. hidden_states = residual + hidden_states
  346. # Fully Connected
  347. residual = hidden_states
  348. hidden_states = self.post_attention_layernorm(hidden_states)
  349. hidden_states = self.mlp(hidden_states)
  350. hidden_states = residual + hidden_states
  351. return hidden_states
  352. @auto_docstring
  353. class Dots1PreTrainedModel(PreTrainedModel):
  354. config: Dots1Config
  355. base_model_prefix = "model"
  356. supports_gradient_checkpointing = True
  357. _no_split_modules = ["Dots1DecoderLayer"]
  358. _skip_keys_device_placement = ["past_key_values"]
  359. _supports_flash_attn = True
  360. _supports_sdpa = True
  361. _supports_flex_attn = True
  362. _can_compile_fullgraph = False
  363. _supports_attention_backend = True
  364. _can_record_outputs = {
  365. "hidden_states": Dots1DecoderLayer,
  366. "attentions": Dots1Attention,
  367. }
  368. def _init_weights(self, module):
  369. super()._init_weights(module)
  370. if isinstance(module, Dots1TopkRouter):
  371. module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
  372. @auto_docstring
  373. class Dots1Model(Dots1PreTrainedModel):
  374. def __init__(self, config: Dots1Config):
  375. super().__init__(config)
  376. self.padding_idx = config.pad_token_id
  377. self.vocab_size = config.vocab_size
  378. self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
  379. self.layers = nn.ModuleList(
  380. [Dots1DecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
  381. )
  382. self.norm = Dots1RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
  383. self.rotary_emb = Dots1RotaryEmbedding(config=config)
  384. self.gradient_checkpointing = False
  385. self.has_sliding_layers = "sliding_attention" in self.config.layer_types
  386. # Initialize weights and apply final processing
  387. self.post_init()
  388. @check_model_inputs()
  389. @auto_docstring
  390. def forward(
  391. self,
  392. input_ids: Optional[torch.LongTensor] = None,
  393. attention_mask: Optional[torch.Tensor] = None,
  394. position_ids: Optional[torch.LongTensor] = None,
  395. past_key_values: Optional[Cache] = None,
  396. inputs_embeds: Optional[torch.FloatTensor] = None,
  397. use_cache: Optional[bool] = None,
  398. cache_position: Optional[torch.LongTensor] = None,
  399. **kwargs: Unpack[TransformersKwargs],
  400. ) -> BaseModelOutputWithPast:
  401. if (input_ids is None) ^ (inputs_embeds is not None):
  402. raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
  403. if inputs_embeds is None:
  404. inputs_embeds = self.embed_tokens(input_ids)
  405. if use_cache and past_key_values is None:
  406. past_key_values = DynamicCache(config=self.config)
  407. if cache_position is None:
  408. past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
  409. cache_position = torch.arange(
  410. past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device
  411. )
  412. if position_ids is None:
  413. position_ids = cache_position.unsqueeze(0)
  414. # It may already have been prepared by e.g. `generate`
  415. if not isinstance(causal_mask_mapping := attention_mask, dict):
  416. # Prepare mask arguments
  417. mask_kwargs = {
  418. "config": self.config,
  419. "input_embeds": inputs_embeds,
  420. "attention_mask": attention_mask,
  421. "cache_position": cache_position,
  422. "past_key_values": past_key_values,
  423. "position_ids": position_ids,
  424. }
  425. # Create the masks
  426. causal_mask_mapping = {
  427. "full_attention": create_causal_mask(**mask_kwargs),
  428. }
  429. # The sliding window alternating layers are not always activated depending on the config
  430. if self.has_sliding_layers:
  431. causal_mask_mapping["sliding_attention"] = create_sliding_window_causal_mask(**mask_kwargs)
  432. hidden_states = inputs_embeds
  433. # create position embeddings to be shared across the decoder layers
  434. position_embeddings = self.rotary_emb(hidden_states, position_ids)
  435. for decoder_layer in self.layers[: self.config.num_hidden_layers]:
  436. hidden_states = decoder_layer(
  437. hidden_states,
  438. attention_mask=causal_mask_mapping[decoder_layer.attention_type],
  439. position_ids=position_ids,
  440. past_key_values=past_key_values,
  441. use_cache=use_cache,
  442. cache_position=cache_position,
  443. position_embeddings=position_embeddings,
  444. **kwargs,
  445. )
  446. hidden_states = self.norm(hidden_states)
  447. return BaseModelOutputWithPast(
  448. last_hidden_state=hidden_states,
  449. past_key_values=past_key_values if use_cache else None,
  450. )
  451. @auto_docstring
  452. class Dots1ForCausalLM(Dots1PreTrainedModel, GenerationMixin):
  453. _tied_weights_keys = ["lm_head.weight"]
  454. _tp_plan = {"lm_head": "colwise_rep"}
  455. _pp_plan = {"lm_head": (["hidden_states"], ["logits"])}
  456. def __init__(self, config):
  457. super().__init__(config)
  458. self.model = Dots1Model(config)
  459. self.vocab_size = config.vocab_size
  460. self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
  461. # Initialize weights and apply final processing
  462. self.post_init()
  463. @can_return_tuple
  464. @auto_docstring
  465. def forward(
  466. self,
  467. input_ids: Optional[torch.LongTensor] = None,
  468. attention_mask: Optional[torch.Tensor] = None,
  469. position_ids: Optional[torch.LongTensor] = None,
  470. past_key_values: Optional[Cache] = None,
  471. inputs_embeds: Optional[torch.FloatTensor] = None,
  472. labels: Optional[torch.LongTensor] = None,
  473. use_cache: Optional[bool] = None,
  474. cache_position: Optional[torch.LongTensor] = None,
  475. logits_to_keep: Union[int, torch.Tensor] = 0,
  476. **kwargs: Unpack[TransformersKwargs],
  477. ) -> CausalLMOutputWithPast:
  478. r"""
  479. labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
  480. Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
  481. config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
  482. (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
  483. Example:
  484. ```python
  485. >>> from transformers import AutoTokenizer, Dots1ForCausalLM
  486. >>> model = Dots1ForCausalLM.from_pretrained("rednote-hilab/dots1.llm1.inst")
  487. >>> tokenizer = AutoTokenizer.from_pretrained("rednote-hilab/dots1.llm1.inst")
  488. >>> prompt = "Hey, are you conscious? Can you talk to me?"
  489. >>> inputs = tokenizer(prompt, return_tensors="pt")
  490. >>> # Generate
  491. >>> generate_ids = model.generate(inputs.input_ids, max_length=30)
  492. >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
  493. "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you."
  494. ```"""
  495. outputs: BaseModelOutputWithPast = self.model(
  496. input_ids=input_ids,
  497. attention_mask=attention_mask,
  498. position_ids=position_ids,
  499. past_key_values=past_key_values,
  500. inputs_embeds=inputs_embeds,
  501. use_cache=use_cache,
  502. cache_position=cache_position,
  503. **kwargs,
  504. )
  505. hidden_states = outputs.last_hidden_state
  506. # Only compute necessary logits, and do not upcast them to float if we are not computing the loss
  507. slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
  508. logits = self.lm_head(hidden_states[:, slice_indices, :])
  509. loss = None
  510. if labels is not None:
  511. loss = self.loss_function(logits=logits, labels=labels, vocab_size=self.config.vocab_size, **kwargs)
  512. return CausalLMOutputWithPast(
  513. loss=loss,
  514. logits=logits,
  515. past_key_values=outputs.past_key_values,
  516. hidden_states=outputs.hidden_states,
  517. attentions=outputs.attentions,
  518. )
  519. __all__ = ["Dots1PreTrainedModel", "Dots1Model", "Dots1ForCausalLM"]