| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612 |
- # 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
- # This file was automatically generated from src/transformers/models/dots1/modular_dots1.py.
- # Do NOT edit this file manually as any edits will be overwritten by the generation of
- # the file from the modular. If any change should be done, please apply the change to the
- # modular_dots1.py file directly. One of our CI enforces this.
- # 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
- # coding=utf-8
- # Copyright 2025 The rednote-hilab team and the HuggingFace Inc. team. All rights reserved.
- #
- # Licensed under the Apache License, Version 2.0 (the "License");
- # you may not use this file except in compliance with the License.
- # You may obtain a copy of the License at
- #
- # http://www.apache.org/licenses/LICENSE-2.0
- #
- # Unless required by applicable law or agreed to in writing, software
- # distributed under the License is distributed on an "AS IS" BASIS,
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- # See the License for the specific language governing permissions and
- # limitations under the License.
- from typing import Callable, Optional, Union
- import torch
- import torch.nn.functional as F
- from torch import nn
- from ...activations import ACT2FN
- from ...cache_utils import Cache, DynamicCache
- from ...generation import GenerationMixin
- from ...integrations import use_kernel_forward_from_hub
- from ...masking_utils import create_causal_mask, create_sliding_window_causal_mask
- from ...modeling_flash_attention_utils import FlashAttentionKwargs
- from ...modeling_layers import GradientCheckpointingLayer
- from ...modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
- from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update
- from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
- from ...processing_utils import Unpack
- from ...utils import TransformersKwargs, auto_docstring, can_return_tuple
- from ...utils.deprecation import deprecate_kwarg
- from ...utils.generic import check_model_inputs
- from .configuration_dots1 import Dots1Config
@use_kernel_forward_from_hub("RMSNorm")
class Dots1RMSNorm(nn.Module):
    """Root-mean-square normalisation over the last dimension with a learnable gain."""

    def __init__(self, hidden_size, eps: float = 1e-6) -> None:
        """
        Dots1RMSNorm is equivalent to T5LayerNorm

        Args:
            hidden_size: Size passed to `torch.ones` to build the learnable `weight`.
            eps: Constant added to the mean square before taking the reciprocal square root.
        """
        super().__init__()
        self.variance_epsilon = eps
        self.weight = nn.Parameter(torch.ones(hidden_size))

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """Normalise `hidden_states` in float32, cast back to the input dtype and apply `weight`."""
        original_dtype = hidden_states.dtype
        upcast = hidden_states.to(torch.float32)
        mean_square = upcast.pow(2).mean(-1, keepdim=True)
        normalized = upcast * torch.rsqrt(mean_square + self.variance_epsilon)
        # The gain is applied after casting back, so the product follows the usual type promotion rules.
        return self.weight * normalized.to(original_dtype)

    def extra_repr(self):
        """Return the weight shape and epsilon shown in the module's `repr`."""
        return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}"
class Dots1RotaryEmbedding(nn.Module):
    """Produces the cos/sin rotary tables for a batch of position ids."""

    inv_freq: torch.Tensor  # fix linting for `register_buffer`

    def __init__(self, config: Dots1Config, device=None):
        """Select the RoPE variant from `config` and register the inverse-frequency buffer."""
        super().__init__()
        self.config = config
        self.max_seq_len_cached = config.max_position_embeddings
        self.original_max_seq_len = config.max_position_embeddings

        # BC: "rope_type" was originally "type"
        rope_scaling = getattr(config, "rope_scaling", None)
        if isinstance(rope_scaling, dict):
            self.rope_type = rope_scaling.get("rope_type", rope_scaling.get("type"))
        else:
            self.rope_type = "default"

        self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]
        inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device)
        # Non-persistent buffer: it is not written to the state dict.
        self.register_buffer("inv_freq", inv_freq, persistent=False)
        self.original_inv_freq = self.inv_freq

    @torch.no_grad()
    @dynamic_rope_update  # power user: used with advanced RoPE types (e.g. dynamic rope)
    def forward(self, x, position_ids):
        """
        Compute the rotary cos/sin tables.

        Args:
            x: Tensor whose device and dtype the outputs follow; its values are not read.
            position_ids: Positions, indexed here as `[batch, seq]`.

        Returns:
            `(cos, sin)`, both cast to `x.dtype` and scaled by `self.attention_scaling`.
        """
        batch_size = position_ids.shape[0]
        inv_freq_expanded = self.inv_freq[None, :, None].float().expand(batch_size, -1, 1).to(x.device)
        position_ids_expanded = position_ids[:, None, :].float()

        device_name = x.device.type
        autocast_device = device_name if isinstance(device_name, str) and device_name != "mps" else "cpu"
        with torch.autocast(device_type=autocast_device, enabled=False):  # Force float32
            freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
            emb = torch.cat((freqs, freqs), dim=-1)
            cos = emb.cos() * self.attention_scaling
            sin = emb.sin() * self.attention_scaling

        return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
def rotate_half(x):
    """Rotates half the hidden dims of the input.

    The last dimension is split at `x.shape[-1] // 2`; the result is the negated second part followed by
    the first part.
    """
    split_at = x.shape[-1] // 2
    front, back = x[..., :split_at], x[..., split_at:]
    return torch.cat((-back, front), dim=-1)
def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
    """Applies Rotary Position Embedding to the query and key tensors.

    Args:
        q (`torch.Tensor`): The query tensor.
        k (`torch.Tensor`): The key tensor.
        cos (`torch.Tensor`): The cosine part of the rotary embedding.
        sin (`torch.Tensor`): The sine part of the rotary embedding.
        position_ids (`torch.Tensor`, *optional*):
            Deprecated and unused.
        unsqueeze_dim (`int`, *optional*, defaults to 1):
            Dimension along which `cos` and `sin` are unsqueezed so that they broadcast against `q` and `k`.
            For example, with `cos`/`sin` of shape [batch_size, seq_len, head_dim] and `q`/`k` of shape
            [batch_size, heads, seq_len, head_dim], use `unsqueeze_dim=1`; if `q`/`k` have the shape
            [batch_size, seq_len, heads, head_dim], use `unsqueeze_dim=2`.

    Returns:
        `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding.
    """
    cos_b = cos.unsqueeze(unsqueeze_dim)
    sin_b = sin.unsqueeze(unsqueeze_dim)

    def _rotate(states):
        # Same formula for queries and keys: x * cos + rotate_half(x) * sin.
        return (states * cos_b) + (rotate_half(states) * sin_b)

    return _rotate(q), _rotate(k)
def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
    """
    Repeat every key/value head `n_rep` times along the head dimension.

    This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch,
    num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim). When `n_rep` is 1
    the input tensor itself is returned.
    """
    batch, kv_heads, seq_len, head_dim = hidden_states.shape
    if n_rep == 1:
        return hidden_states
    # Insert a repeat axis right after the head axis, broadcast it, then fold it into the head axis.
    repeated = hidden_states.unsqueeze(2).expand(batch, kv_heads, n_rep, seq_len, head_dim)
    return repeated.reshape(batch, kv_heads * n_rep, seq_len, head_dim)
def eager_attention_forward(
    module: nn.Module,
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    attention_mask: Optional[torch.Tensor],
    scaling: float,
    dropout: float = 0.0,
    **kwargs: Unpack[TransformersKwargs],
):
    """
    Plain matmul/softmax attention.

    Args:
        module: Provides `num_key_value_groups` (key/value head repetition factor) and `training`.
        query, key, value: Attention inputs; `key` and `value` are expanded with `repeat_kv` first.
        attention_mask: Optional additive mask; it is cropped on its last axis to the key length.
        scaling: Factor applied to the raw attention scores.
        dropout: Dropout probability on the attention weights, active only when `module.training`.
        **kwargs: Accepted for interface compatibility; not read here.

    Returns:
        `(attn_output, attn_weights)`; the output has dims 1 and 2 swapped and is made contiguous.
    """
    n_groups = module.num_key_value_groups
    key_states = repeat_kv(key, n_groups)
    value_states = repeat_kv(value, n_groups)

    scores = torch.matmul(query, key_states.transpose(2, 3)) * scaling
    if attention_mask is not None:
        scores = scores + attention_mask[:, :, :, : key_states.shape[-2]]

    # Softmax runs in float32 and is cast back to the query dtype.
    probs = nn.functional.softmax(scores, dim=-1, dtype=torch.float32).to(query.dtype)
    probs = nn.functional.dropout(probs, p=dropout, training=module.training)

    context = torch.matmul(probs, value_states).transpose(1, 2).contiguous()
    return context, probs
class Dots1Attention(nn.Module):
    """Multi-headed attention from 'Attention Is All You Need' paper"""

    def __init__(self, config: Dots1Config, layer_idx: int):
        """Create the q/k/v/o projections and the per-head q/k norms for layer `layer_idx`."""
        super().__init__()
        self.config = config
        self.layer_idx = layer_idx
        # An explicit `config.head_dim` takes precedence over hidden_size // num_attention_heads.
        self.head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads)
        self.num_key_value_groups = config.num_attention_heads // config.num_key_value_heads
        self.scaling = self.head_dim**-0.5
        self.attention_dropout = config.attention_dropout
        self.is_causal = True

        use_bias = config.attention_bias
        query_width = config.num_attention_heads * self.head_dim
        key_value_width = config.num_key_value_heads * self.head_dim
        self.q_proj = nn.Linear(config.hidden_size, query_width, bias=use_bias)
        self.k_proj = nn.Linear(config.hidden_size, key_value_width, bias=use_bias)
        self.v_proj = nn.Linear(config.hidden_size, key_value_width, bias=use_bias)
        self.o_proj = nn.Linear(query_width, config.hidden_size, bias=use_bias)
        self.q_norm = Dots1RMSNorm(self.head_dim, eps=config.rms_norm_eps)  # unlike olmo, only on the head dim!
        self.k_norm = Dots1RMSNorm(self.head_dim, eps=config.rms_norm_eps)  # thus post q_norm does not need reshape

        # Only layers typed "sliding_attention" carry a window size.
        uses_sliding_window = config.layer_types[layer_idx] == "sliding_attention"
        self.sliding_window = config.sliding_window if uses_sliding_window else None

    @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58")
    def forward(
        self,
        hidden_states: torch.Tensor,
        position_embeddings: tuple[torch.Tensor, torch.Tensor],
        attention_mask: Optional[torch.Tensor],
        past_key_values: Optional[Cache] = None,
        cache_position: Optional[torch.LongTensor] = None,
        **kwargs: Unpack[FlashAttentionKwargs],
    ) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
        """
        Run self-attention over `hidden_states`.

        Args:
            hidden_states: Layer input; its last dimension is fed to the projections.
            position_embeddings: `(cos, sin)` pair applied to queries and keys.
            attention_mask: Mask forwarded unchanged to the attention backend.
            past_key_values: Optional cache; when given it is updated with this layer's keys and values.
            cache_position: Forwarded to the cache update.
            **kwargs: Forwarded to the attention backend.

        Returns:
            `(attn_output, attn_weights)` where `attn_output` has been passed through `o_proj`.
        """
        leading_shape = hidden_states.shape[:-1]
        per_head_shape = (*leading_shape, -1, self.head_dim)

        # Queries and keys are normalised per head before being moved to a heads-first layout.
        queries = self.q_proj(hidden_states).view(per_head_shape)
        query_states = self.q_norm(queries).transpose(1, 2)
        keys = self.k_proj(hidden_states).view(per_head_shape)
        key_states = self.k_norm(keys).transpose(1, 2)
        value_states = self.v_proj(hidden_states).view(per_head_shape).transpose(1, 2)

        cos, sin = position_embeddings
        query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)

        if past_key_values is not None:
            # sin and cos are specific to RoPE models; cache_position needed for the static cache
            cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
            key_states, value_states = past_key_values.update(key_states, value_states, self.layer_idx, cache_kwargs)

        attn_implementation = self.config._attn_implementation
        attention_interface: Callable = (
            eager_attention_forward
            if attn_implementation == "eager"
            else ALL_ATTENTION_FUNCTIONS[attn_implementation]
        )

        attn_output, attn_weights = attention_interface(
            self,
            query_states,
            key_states,
            value_states,
            attention_mask,
            dropout=self.attention_dropout if self.training else 0.0,
            scaling=self.scaling,
            sliding_window=self.sliding_window,  # diff with Llama
            **kwargs,
        )

        # Merge the head axes back into one feature axis before the output projection.
        merged = attn_output.reshape(*leading_shape, -1).contiguous()
        return self.o_proj(merged), attn_weights
class Dots1MLP(nn.Module):
    """Gated feed-forward block: `down_proj(act_fn(gate_proj(x)) * up_proj(x))`."""

    def __init__(self, config, hidden_size=None, intermediate_size=None):
        """
        Args:
            config: Supplies `hidden_act` and the default `hidden_size` / `intermediate_size`.
            hidden_size: Optional override for `config.hidden_size`.
            intermediate_size: Optional override for `config.intermediate_size`.
        """
        super().__init__()
        self.config = config
        self.hidden_size = hidden_size if hidden_size is not None else config.hidden_size
        self.intermediate_size = intermediate_size if intermediate_size is not None else config.intermediate_size

        # All three projections are bias-free.
        self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
        self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
        self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False)
        self.act_fn = ACT2FN[config.hidden_act]

    def forward(self, x):
        """Apply the gated projection to `x` and project back down."""
        gate = self.act_fn(self.gate_proj(x))
        up = self.up_proj(x)
        return self.down_proj(gate * up)
class Dots1MoE(nn.Module):
    """
    A mixed expert module containing shared experts.

    Every token is sent to the routed experts chosen by `self.gate`, and the weighted sum of their outputs is
    added to the output of `self.shared_experts`, which runs on every token.
    """

    def __init__(self, config):
        """
        Args:
            config: Supplies `n_routed_experts`, `moe_intermediate_size` and `n_shared_experts`, and is
                forwarded to the expert MLPs and the router.
        """
        super().__init__()
        self.config = config
        self.experts = nn.ModuleList(
            [Dots1MLP(config, intermediate_size=config.moe_intermediate_size) for _ in range(config.n_routed_experts)]
        )
        self.gate = Dots1TopkRouter(config)
        # The shared experts are built as one MLP whose width is n_shared_experts * moe_intermediate_size.
        self.shared_experts = Dots1MLP(
            config=config, intermediate_size=config.moe_intermediate_size * config.n_shared_experts
        )

    def moe(self, hidden_states: torch.Tensor, topk_indices: torch.Tensor, topk_weights: torch.Tensor):
        r"""
        Dispatch each token to its selected experts and accumulate the weighted expert outputs.

        CALL FOR CONTRIBUTION! I don't have time to optimise this right now, but expert weights need to be fused
        to not have to do a loop here (deepseek has 256 experts soooo yeah).

        Args:
            hidden_states: Flattened tokens, indexed here as `[token, feature]`.
            topk_indices: Expert ids per token, indexed here as `[token, slot]`.
            topk_weights: Routing weights with the same layout as `topk_indices`.

        Returns:
            A tensor shaped like `hidden_states`, cast back to `hidden_states.dtype`.
        """
        # Accumulate in the routing-weight dtype; the result is cast back on return.
        final_hidden_states = torch.zeros_like(hidden_states, dtype=topk_weights.dtype)
        expert_mask = torch.nn.functional.one_hot(topk_indices, num_classes=len(self.experts))
        # Put the expert axis first so that `expert_mask[i]` is the [token, slot] mask of expert i.
        expert_mask = expert_mask.permute(2, 0, 1)

        for expert_idx, expert in enumerate(self.experts):
            token_indices, weight_indices = torch.where(expert_mask[expert_idx])
            if token_indices.numel() == 0:
                # No token was routed to this expert.
                continue

            expert_weights = topk_weights[token_indices, weight_indices]
            expert_output = expert(hidden_states[token_indices])
            weighted_output = expert_output * expert_weights.unsqueeze(-1)
            final_hidden_states.index_add_(0, token_indices, weighted_output)

        # in original deepseek, the output of the experts are gathered once we leave this module
        # thus the moe module is itself an IsolatedParallel module
        # and all expert are "local" meaning we shard but we don't gather
        return final_hidden_states.type(hidden_states.dtype)

    def forward(self, hidden_states):
        """Route `hidden_states` through the selected experts and add the shared-expert output."""
        residuals = hidden_states
        orig_shape = hidden_states.shape
        topk_indices, topk_weights = self.gate(hidden_states)
        # The expert loop works on a flat list of tokens; the original shape is restored afterwards.
        hidden_states = hidden_states.view(-1, hidden_states.shape[-1])
        hidden_states = self.moe(hidden_states, topk_indices, topk_weights).view(*orig_shape)
        hidden_states = hidden_states + self.shared_experts(residuals)
        return hidden_states
class Dots1TopkRouter(nn.Module):
    """Sigmoid router that first picks expert groups, then the top-k experts inside those groups."""

    def __init__(self, config):
        """Copy the routing hyper-parameters from `config` and allocate the router weight and bias buffer."""
        super().__init__()
        self.config = config
        self.top_k = config.num_experts_per_tok
        self.n_routed_experts = config.n_routed_experts
        self.routed_scaling_factor = config.routed_scaling_factor
        self.n_group = config.n_group
        self.topk_group = config.topk_group
        self.norm_topk_prob = config.norm_topk_prob

        # `torch.empty` leaves the weight uninitialised; it has to be filled elsewhere before use.
        self.weight = nn.Parameter(torch.empty((self.n_routed_experts, config.hidden_size)))
        self.register_buffer("e_score_correction_bias", torch.zeros(self.n_routed_experts))

    @torch.no_grad()
    def get_topk_indices(self, scores):
        """
        Pick `top_k` expert ids per token, restricted to the `topk_group` best-scoring groups.

        The correction bias is added for the selection only; a group is scored by the sum of its two highest
        biased scores, and experts outside the selected groups have their score set to 0.0.
        """
        experts_per_group = self.n_routed_experts // self.n_group
        biased_scores = scores.view(-1, self.n_routed_experts) + self.e_score_correction_bias.unsqueeze(0)

        per_group = biased_scores.view(-1, self.n_group, experts_per_group)
        group_scores = per_group.topk(2, dim=-1)[0].sum(dim=-1)
        selected_groups = torch.topk(group_scores, k=self.topk_group, dim=-1, sorted=False)[1]

        group_mask = torch.zeros_like(group_scores)
        group_mask.scatter_(1, selected_groups, 1)
        # Expand the per-group mask to one entry per expert.
        expert_mask = group_mask.unsqueeze(-1).expand(-1, self.n_group, experts_per_group)
        expert_mask = expert_mask.reshape(-1, self.n_routed_experts)

        biased_scores = biased_scores.masked_fill(~expert_mask.bool(), 0.0)
        return torch.topk(biased_scores, k=self.top_k, dim=-1, sorted=False)[1]

    def forward(self, hidden_states):
        """
        Route tokens to experts.

        Returns:
            `(topk_indices, topk_weights)`. The weights are the unbiased sigmoid scores of the selected
            experts, optionally normalised to sum to one per token, then multiplied by
            `routed_scaling_factor`.
        """
        flat_states = hidden_states.view(-1, self.config.hidden_size)
        # The router logits are computed in float32.
        router_logits = F.linear(flat_states.type(torch.float32), self.weight.type(torch.float32))
        scores = router_logits.sigmoid()

        topk_indices = self.get_topk_indices(scores)
        topk_weights = scores.gather(1, topk_indices)
        if self.norm_topk_prob:
            topk_weights /= topk_weights.sum(dim=-1, keepdim=True) + 1e-20
        return topk_indices, topk_weights * self.routed_scaling_factor
class Dots1DecoderLayer(GradientCheckpointingLayer):
    """One decoder block: pre-norm self-attention followed by a pre-norm feed-forward, each with a residual."""

    def __init__(self, config: Dots1Config, layer_idx: int):
        """Build the layer; indices at or above `config.first_k_dense_replace` get a MoE feed-forward."""
        super().__init__()
        self.hidden_size = config.hidden_size
        self.self_attn = Dots1Attention(config=config, layer_idx=layer_idx)
        self.mlp = Dots1MoE(config) if layer_idx >= config.first_k_dense_replace else Dots1MLP(config)
        self.input_layernorm = Dots1RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.post_attention_layernorm = Dots1RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.attention_type = config.layer_types[layer_idx]

    @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58")
    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[Cache] = None,
        use_cache: Optional[bool] = False,
        cache_position: Optional[torch.LongTensor] = None,
        position_embeddings: Optional[tuple[torch.Tensor, torch.Tensor]] = None,  # necessary, but kept here for BC
        **kwargs: Unpack[TransformersKwargs],
    ) -> torch.Tensor:
        """Apply the attention and feed-forward sub-layers to `hidden_states` and return the new hidden states."""
        # Self Attention
        attn_output, _ = self.self_attn(
            hidden_states=self.input_layernorm(hidden_states),
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            use_cache=use_cache,
            cache_position=cache_position,
            position_embeddings=position_embeddings,
            **kwargs,
        )
        hidden_states = hidden_states + attn_output

        # Fully Connected
        mlp_output = self.mlp(self.post_attention_layernorm(hidden_states))
        return hidden_states + mlp_output
@auto_docstring
class Dots1PreTrainedModel(PreTrainedModel):
    # Shared base class for the Dots1 models: class-level capability flags plus router weight init.
    config: Dots1Config
    base_model_prefix = "model"
    supports_gradient_checkpointing = True
    _no_split_modules = ["Dots1DecoderLayer"]
    _skip_keys_device_placement = ["past_key_values"]
    _supports_flash_attn = True
    _supports_sdpa = True
    _supports_flex_attn = True
    _can_compile_fullgraph = False
    _supports_attention_backend = True
    # Maps an output name to the module class whose outputs are recorded under it.
    _can_record_outputs = {
        "hidden_states": Dots1DecoderLayer,
        "attentions": Dots1Attention,
    }

    def _init_weights(self, module):
        # Run the parent initialisation first, then handle the router, whose weight is created with
        # `torch.empty` in this file.
        super()._init_weights(module)
        if not isinstance(module, Dots1TopkRouter):
            return
        module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
@auto_docstring
class Dots1Model(Dots1PreTrainedModel):
    # Decoder stack: token embedding -> `num_hidden_layers` decoder layers -> final RMSNorm.
    def __init__(self, config: Dots1Config):
        super().__init__(config)
        self.padding_idx = config.pad_token_id
        self.vocab_size = config.vocab_size

        self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
        decoder_layers = [Dots1DecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
        self.layers = nn.ModuleList(decoder_layers)
        self.norm = Dots1RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.rotary_emb = Dots1RotaryEmbedding(config=config)
        self.gradient_checkpointing = False
        # A sliding-window mask is only built when at least one layer is typed "sliding_attention".
        self.has_sliding_layers = "sliding_attention" in self.config.layer_types

        # Initialize weights and apply final processing
        self.post_init()

    @check_model_inputs()
    @auto_docstring
    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[Cache] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        use_cache: Optional[bool] = None,
        cache_position: Optional[torch.LongTensor] = None,
        **kwargs: Unpack[TransformersKwargs],
    ) -> BaseModelOutputWithPast:
        # Exactly one of `input_ids` / `inputs_embeds` must be given: both None or both set is an error.
        if (input_ids is None) == (inputs_embeds is None):
            raise ValueError("You must specify exactly one of input_ids or inputs_embeds")

        if inputs_embeds is None:
            inputs_embeds = self.embed_tokens(input_ids)

        if use_cache and past_key_values is None:
            past_key_values = DynamicCache(config=self.config)

        if cache_position is None:
            # Continue counting from the tokens the cache has already seen.
            past_seen_tokens = 0 if past_key_values is None else past_key_values.get_seq_length()
            cache_position = torch.arange(
                past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device
            )

        if position_ids is None:
            position_ids = cache_position.unsqueeze(0)

        # It may already have been prepared by e.g. `generate`
        causal_mask_mapping = attention_mask
        if not isinstance(causal_mask_mapping, dict):
            # Prepare mask arguments
            mask_kwargs = {
                "config": self.config,
                "input_embeds": inputs_embeds,
                "attention_mask": attention_mask,
                "cache_position": cache_position,
                "past_key_values": past_key_values,
                "position_ids": position_ids,
            }
            # Create the masks
            causal_mask_mapping = {"full_attention": create_causal_mask(**mask_kwargs)}
            # The sliding window alternating layers are not always activated depending on the config
            if self.has_sliding_layers:
                causal_mask_mapping["sliding_attention"] = create_sliding_window_causal_mask(**mask_kwargs)

        hidden_states = inputs_embeds

        # create position embeddings to be shared across the decoder layers
        position_embeddings = self.rotary_emb(hidden_states, position_ids)

        active_layers = self.layers[: self.config.num_hidden_layers]
        for decoder_layer in active_layers:
            # Each layer receives the mask matching its own attention type.
            layer_mask = causal_mask_mapping[decoder_layer.attention_type]
            hidden_states = decoder_layer(
                hidden_states,
                attention_mask=layer_mask,
                position_ids=position_ids,
                past_key_values=past_key_values,
                use_cache=use_cache,
                cache_position=cache_position,
                position_embeddings=position_embeddings,
                **kwargs,
            )

        hidden_states = self.norm(hidden_states)
        returned_cache = past_key_values if use_cache else None
        return BaseModelOutputWithPast(
            last_hidden_state=hidden_states,
            past_key_values=returned_cache,
        )
@auto_docstring
class Dots1ForCausalLM(Dots1PreTrainedModel, GenerationMixin):
    # Dots1Model backbone followed by a bias-free linear head over the vocabulary.
    _tied_weights_keys = ["lm_head.weight"]
    _tp_plan = {"lm_head": "colwise_rep"}
    _pp_plan = {"lm_head": (["hidden_states"], ["logits"])}

    def __init__(self, config):
        super().__init__(config)
        self.model = Dots1Model(config)
        self.vocab_size = config.vocab_size
        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)

        # Initialize weights and apply final processing
        self.post_init()

    @can_return_tuple
    @auto_docstring
    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[Cache] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        cache_position: Optional[torch.LongTensor] = None,
        logits_to_keep: Union[int, torch.Tensor] = 0,
        **kwargs: Unpack[TransformersKwargs],
    ) -> CausalLMOutputWithPast:
        r"""
        labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
            Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
            config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
            (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.

        Example:

        ```python
        >>> from transformers import AutoTokenizer, Dots1ForCausalLM

        >>> model = Dots1ForCausalLM.from_pretrained("rednote-hilab/dots1.llm1.inst")
        >>> tokenizer = AutoTokenizer.from_pretrained("rednote-hilab/dots1.llm1.inst")

        >>> prompt = "Hey, are you conscious? Can you talk to me?"
        >>> inputs = tokenizer(prompt, return_tensors="pt")

        >>> # Generate
        >>> generate_ids = model.generate(inputs.input_ids, max_length=30)
        >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
        "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you."
        ```"""
        outputs: BaseModelOutputWithPast = self.model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            cache_position=cache_position,
            **kwargs,
        )
        hidden_states = outputs.last_hidden_state

        # Only compute necessary logits, and do not upcast them to float if we are not computing the loss
        if isinstance(logits_to_keep, int):
            # An int keeps the trailing `logits_to_keep` positions; 0 yields slice(0, None), i.e. all of them.
            slice_indices = slice(-logits_to_keep, None)
        else:
            slice_indices = logits_to_keep
        logits = self.lm_head(hidden_states[:, slice_indices, :])

        loss = (
            self.loss_function(logits=logits, labels=labels, vocab_size=self.config.vocab_size, **kwargs)
            if labels is not None
            else None
        )

        return CausalLMOutputWithPast(
            loss=loss,
            logits=logits,
            past_key_values=outputs.past_key_values,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )
# Public names exported by this module.
__all__ = [
    "Dots1PreTrainedModel",
    "Dots1Model",
    "Dots1ForCausalLM",
]
|