| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693694695696697698699700701702
703704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862863864865866867868869870871872873874875876877878879880881882883884885886887888889890891892893894895896897898899900901902903904905906907908909910911912913914915916917918919920921922923924925926927928929930931932933934935936937938939940941942943944945946947948949950951952953954955956957958959960961962963964965966967968969970971972973974975976977978979980981982983984985986987988989990991992993994995996 |
- # coding=utf-8
- # Copyright 2024 EleutherAI and the HuggingFace Inc. team. All rights reserved.
- #
- # This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
- # and OPT implementations in this library. It has been modified from its
- # original forms to accommodate minor architectural differences compared
- # to GPT-NeoX and OPT used by the Meta AI team that trained the model.
- #
- # Licensed under the Apache License, Version 2.0 (the "License");
- # you may not use this file except in compliance with the License.
- # You may obtain a copy of the License at
- #
- # http://www.apache.org/licenses/LICENSE-2.0
- #
- # Unless required by applicable law or agreed to in writing, software
- # distributed under the License is distributed on an "AS IS" BASIS,
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- # See the License for the specific language governing permissions and
- # limitations under the License.
- """PyTorch StableLM model."""
- import math
- from typing import Optional, Union
- import torch
- from torch import nn
- from ...activations import ACT2FN
- from ...cache_utils import Cache, DynamicCache
- from ...generation import GenerationMixin
- from ...modeling_attn_mask_utils import AttentionMaskConverter
- from ...modeling_flash_attention_utils import flash_attn_supports_top_left_mask, is_flash_attn_available
- from ...modeling_layers import (
- GenericForSequenceClassification,
- GenericForTokenClassification,
- GradientCheckpointingLayer,
- )
- from ...modeling_outputs import (
- BaseModelOutputWithPast,
- CausalLMOutputWithPast,
- )
- from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update
- from ...modeling_utils import PreTrainedModel
- from ...utils import auto_docstring, can_return_tuple, is_torch_flex_attn_available, logging
- from ...utils.deprecation import deprecate_kwarg
- from .configuration_stablelm import StableLmConfig
- if is_torch_flex_attn_available():
- from torch.nn.attention.flex_attention import BlockMask
- from ...integrations.flex_attention import make_flex_block_causal_mask
- if is_flash_attn_available():
- from ...modeling_flash_attention_utils import _flash_attention_forward
- logger = logging.get_logger(__name__)
- # Copied from transformers.models.llama.modeling_llama.LlamaRotaryEmbedding with Llama->StableLm
class StableLmRotaryEmbedding(nn.Module):
    """Produces the cosine / sine rotary position tables for a batch of position ids.

    The inverse-frequency vector is created once by the initializer registered in
    `ROPE_INIT_FUNCTIONS` under the configured RoPE type and kept as a non-persistent buffer.
    """

    inv_freq: torch.Tensor  # fix linting for `register_buffer`

    def __init__(self, config: StableLmConfig, device=None):
        super().__init__()
        # BC: "rope_type" was originally "type"
        rope_scaling = getattr(config, "rope_scaling", None)
        if isinstance(rope_scaling, dict):
            legacy_type = rope_scaling.get("type")
            self.rope_type = rope_scaling.get("rope_type", legacy_type)
        else:
            self.rope_type = "default"

        max_positions = config.max_position_embeddings
        self.max_seq_len_cached = max_positions
        self.original_max_seq_len = max_positions

        self.config = config
        self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]

        inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device)
        self.register_buffer("inv_freq", inv_freq, persistent=False)
        # Keep a handle on the initial frequencies (presumably consumed by `dynamic_rope_update`).
        self.original_inv_freq = self.inv_freq

    @torch.no_grad()
    @dynamic_rope_update  # power user: used with advanced RoPE types (e.g. dynamic rope)
    def forward(self, x, position_ids):
        """Return `(cos, sin)` for `position_ids`, cast to the dtype of `x`.

        `x` is only used for its device and dtype.
        """
        batch_size = position_ids.shape[0]
        inv_freq_expanded = self.inv_freq[None, :, None].float().expand(batch_size, -1, 1).to(x.device)
        position_ids_expanded = position_ids[:, None, :].float()

        # Autocast is looked up by device type; anything that is not a string, or is "mps", maps to "cpu".
        device_type = x.device.type
        if not isinstance(device_type, str) or device_type == "mps":
            device_type = "cpu"

        with torch.autocast(device_type=device_type, enabled=False):  # Force float32
            freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
            emb = torch.cat((freqs, freqs), dim=-1)
            cos = emb.cos() * self.attention_scaling
            sin = emb.sin() * self.attention_scaling

        return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
- # Copied from transformers.models.llama.modeling_llama.rotate_half
def rotate_half(x):
    """Rotates half the hidden dims of the input.

    The last dimension is cut at its midpoint into `(first, second)`; the result is
    `(-second, first)` concatenated back along that dimension.
    """
    midpoint = x.shape[-1] // 2
    first, second = x[..., :midpoint], x[..., midpoint:]
    return torch.cat((-second, first), dim=-1)
- # Copied from transformers.models.llama.modeling_llama.apply_rotary_pos_emb
def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
    """Applies Rotary Position Embedding to the query and key tensors.

    Args:
        q (`torch.Tensor`): The query tensor.
        k (`torch.Tensor`): The key tensor.
        cos (`torch.Tensor`): The cosine part of the rotary embedding.
        sin (`torch.Tensor`): The sine part of the rotary embedding.
        position_ids (`torch.Tensor`, *optional*):
            Deprecated and unused.
        unsqueeze_dim (`int`, *optional*, defaults to 1):
            Dimension along which `cos` and `sin` are unsqueezed so that they broadcast against `q` and `k`.
            For example, with `cos` / `sin` of shape [batch_size, seq_len, head_dim]: if `q` and `k` have the shape
            [batch_size, heads, seq_len, head_dim], use `unsqueeze_dim=1`; if they have the shape
            [batch_size, seq_len, heads, head_dim], use `unsqueeze_dim=2`.

    Returns:
        `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding.
    """
    cos, sin = cos.unsqueeze(unsqueeze_dim), sin.unsqueeze(unsqueeze_dim)
    # Same rotation for queries and keys: x * cos + rotate_half(x) * sin
    q_embed, k_embed = ((states * cos) + (rotate_half(states) * sin) for states in (q, k))
    return q_embed, k_embed
- # Copied from transformers.models.mistral.modeling_mistral.MistralMLP with Mistral->StableLm
class StableLmMLP(nn.Module):
    """Gated feed-forward block: `down_proj(act_fn(gate_proj(x)) * up_proj(x))`, all projections bias-free."""

    def __init__(self, config):
        super().__init__()
        self.config = config
        self.hidden_size = config.hidden_size
        self.intermediate_size = config.intermediate_size

        model_dim, inner_dim = self.hidden_size, self.intermediate_size
        self.gate_proj = nn.Linear(model_dim, inner_dim, bias=False)
        self.up_proj = nn.Linear(model_dim, inner_dim, bias=False)
        self.down_proj = nn.Linear(inner_dim, model_dim, bias=False)
        # Activation is looked up by name from the config.
        self.act_fn = ACT2FN[config.hidden_act]

    def forward(self, x):
        """Apply the gated MLP to `x` and return a tensor with the same trailing size as the input."""
        gate = self.act_fn(self.gate_proj(x))
        up = self.up_proj(x)
        return self.down_proj(gate * up)
- class StableLmLayerNormPerHead(nn.Module):
- def __init__(self, dim, num_heads, eps=1e-5, bias=False):
- super().__init__()
- self.dim = dim
- self.num_heads = num_heads
- self.norms = nn.ModuleList([nn.LayerNorm(dim, eps=eps, bias=bias) for _ in range(self.num_heads)])
- def forward(self, hidden_states: torch.Tensor):
- # Split along the num_heads axis to get per-head inputs
- # [batch_size, num_heads, seq_len, head_dim] -> [batch_size, 1, seq_len, head_dim] * num_heads
- states_per_heads = torch.split(hidden_states, 1, dim=1)
- # Normalize and merge the heads back together
- return torch.cat([norm(hidden_states) for norm, hidden_states in zip(self.norms, states_per_heads)], dim=1)
- # Copied from transformers.models.llama.modeling_llama.repeat_kv
def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
    """
    Repeat every key/value head `n_rep` times, equivalent to torch.repeat_interleave(x, dim=1, repeats=n_rep).
    The hidden states go from (batch, num_key_value_heads, seqlen, head_dim) to
    (batch, num_attention_heads, seqlen, head_dim). With `n_rep == 1` the input tensor itself is returned.
    """
    batch, num_key_value_heads, slen, head_dim = hidden_states.shape
    if n_rep != 1:
        # Insert a repeat axis right after the head axis, broadcast it, then fold it into the head axis.
        expanded = hidden_states.unsqueeze(2).expand(batch, num_key_value_heads, n_rep, slen, head_dim)
        hidden_states = expanded.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
    return hidden_states
class StableLmAttention(nn.Module):
    """Multi-headed attention from 'Attention Is All You Need' paper

    Eager (matmul + softmax) implementation with grouped key/value heads, optional per-head LayerNorm on
    queries and keys, and rotary embeddings applied to the first `rotary_ndims` channels of each head only.
    """

    def __init__(self, config: StableLmConfig, layer_idx: Optional[int] = None):
        super().__init__()
        self.config = config
        self.layer_idx = layer_idx
        if layer_idx is None:
            logger.warning_once(
                f"Instantiating {self.__class__.__name__} without passing a `layer_idx` is not recommended and will "
                "lead to errors during the forward call if caching is used. Please make sure to provide a `layer_idx` "
                "when creating this class."
            )

        self.hidden_size = config.hidden_size
        self.num_heads = config.num_attention_heads
        self.head_dim = self.hidden_size // self.num_heads
        self.num_key_value_heads = config.num_key_value_heads
        self.num_key_value_groups = self.num_heads // self.num_key_value_heads
        self.rope_theta = config.rope_theta
        # Number of leading channels per head that receive the rotary embedding.
        self.rotary_ndims = int(self.head_dim * config.partial_rotary_factor)
        self.is_causal = True

        if self.head_dim * self.num_heads != self.hidden_size:
            raise ValueError(
                f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}"
                f" and `num_heads`: {self.num_heads})."
            )

        query_width = self.num_heads * self.head_dim
        key_value_width = self.num_key_value_heads * self.head_dim
        qkv_bias = config.use_qkv_bias
        self.q_proj = nn.Linear(self.hidden_size, query_width, bias=qkv_bias)
        self.k_proj = nn.Linear(self.hidden_size, key_value_width, bias=qkv_bias)
        self.v_proj = nn.Linear(self.hidden_size, key_value_width, bias=qkv_bias)
        self.o_proj = nn.Linear(self.hidden_size, self.hidden_size, bias=False)

        self.qk_layernorm = config.qk_layernorm
        if self.qk_layernorm:
            norm_eps = config.layer_norm_eps
            self.q_layernorm = StableLmLayerNormPerHead(self.head_dim, self.num_heads, eps=norm_eps)
            self.k_layernorm = StableLmLayerNormPerHead(self.head_dim, self.num_key_value_heads, eps=norm_eps)

        self.attention_dropout = nn.Dropout(config.attention_dropout)
        self.rotary_emb = StableLmRotaryEmbedding(config=self.config)

    @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58")
    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[Cache] = None,
        output_attentions: bool = False,
        use_cache: bool = False,
        cache_position: Optional[torch.LongTensor] = None,
        position_embeddings: Optional[tuple[torch.Tensor, torch.Tensor]] = None,  # necessary, but kept here for BC
    ) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[tuple[torch.Tensor]]]:
        """Run eager attention over `hidden_states` of shape `(batch, seq_len, hidden_size)`.

        `position_embeddings` is the `(cos, sin)` pair and is unpacked unconditionally, so it must be
        provided. `attention_mask`, when given, is added to the raw scores after being sliced to the key
        length. `position_ids` and `use_cache` are accepted but not read here.

        Returns:
            A 2-tuple `(attn_output, attn_weights)`; `attn_weights` is `None` unless `output_attentions`
            is truthy. (The return annotation lists three entries, but only two values are returned.)
        """
        bsz, q_len, _ = hidden_states.size()

        def split_heads(projected, n_heads):
            # (batch, seq_len, n_heads * head_dim) -> (batch, n_heads, seq_len, head_dim)
            return projected.view(bsz, q_len, n_heads, self.head_dim).transpose(1, 2)

        query_states = split_heads(self.q_proj(hidden_states), self.num_heads)
        key_states = split_heads(self.k_proj(hidden_states), self.num_key_value_heads)
        value_states = split_heads(self.v_proj(hidden_states), self.num_key_value_heads)

        if self.qk_layernorm:
            query_states = self.q_layernorm(query_states)
            key_states = self.k_layernorm(key_states)

        cos, sin = position_embeddings

        # Partial rotary embedding: rotate the first `rotary_ndims` channels, pass the rest through untouched
        rot = self.rotary_ndims
        query_rot, key_rot = apply_rotary_pos_emb(query_states[..., :rot], key_states[..., :rot], cos, sin)
        # back to (batch, heads, seq_len, head_dim)
        query_states = torch.cat((query_rot, query_states[..., rot:]), dim=-1)
        key_states = torch.cat((key_rot, key_states[..., rot:]), dim=-1)

        if past_key_values is not None:
            # Specific to RoPE models with partial rotation
            cache_kwargs = {
                "sin": sin,
                "cos": cos,
                "partial_rotation_size": rot,
                "cache_position": cache_position,
            }
            key_states, value_states = past_key_values.update(key_states, value_states, self.layer_idx, cache_kwargs)

        # Repeat k/v heads if n_kv_heads < n_heads
        key_states = repeat_kv(key_states, self.num_key_value_groups)
        value_states = repeat_kv(value_states, self.num_key_value_groups)

        attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim)
        if attention_mask is not None:
            # no matter the length, we just slice it
            attn_weights += attention_mask[:, :, :, : key_states.shape[-2]]

        # upcast attention to fp32 for the softmax, then return to the query dtype
        attn_weights = nn.functional.softmax(attn_weights, dtype=torch.float32, dim=-1).to(query_states.dtype)
        attn_weights = self.attention_dropout(attn_weights)
        attn_output = torch.matmul(attn_weights, value_states)

        expected_size = (bsz, self.num_heads, q_len, self.head_dim)
        if attn_output.size() != expected_size:
            raise ValueError(f"`attn_output` should be of size {expected_size}, but is {attn_output.size()}")

        attn_output = attn_output.transpose(1, 2).contiguous().reshape(bsz, q_len, self.hidden_size)
        attn_output = self.o_proj(attn_output)

        return attn_output, (attn_weights if output_attentions else None)
class StableLmSdpaAttention(StableLmAttention):
    """Variant of `StableLmAttention` whose forward pass calls `torch.nn.functional.scaled_dot_product_attention`.

    Weights are inherited unchanged; when attention weights are requested the call is delegated to the
    parent (eager) implementation.
    """

    @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58")
    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[Cache] = None,
        output_attentions: bool = False,
        use_cache: bool = False,
        cache_position: Optional[torch.LongTensor] = None,
        position_embeddings: Optional[tuple[torch.Tensor, torch.Tensor]] = None,  # necessary, but kept here for BC
    ) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[tuple[torch.Tensor]]]:
        """Run SDPA attention and return the 2-tuple `(attn_output, None)`.

        With `output_attentions` truthy, a one-time warning is logged and the result of the parent
        `forward` is returned instead. (The return annotation lists three entries, but only two values
        are returned.)
        """
        if output_attentions:
            # TODO: Improve this warning with e.g. `model.config.attn_implementation = "manual"` once this is implemented.
            logger.warning_once(
                "StableLmModel is using StableLmSdpaAttention, but `torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to the manual attention implementation, "
                'but specifying the manual implementation will be required from Transformers version v5.0.0 onwards. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.'
            )
            return super().forward(
                hidden_states=hidden_states,
                attention_mask=attention_mask,
                position_ids=position_ids,
                past_key_values=past_key_values,
                output_attentions=output_attentions,
                use_cache=use_cache,
                cache_position=cache_position,
                position_embeddings=position_embeddings,
            )

        bsz, q_len, _ = hidden_states.size()

        def split_heads(projected, n_heads):
            # (batch, seq_len, n_heads * head_dim) -> (batch, n_heads, seq_len, head_dim)
            return projected.view(bsz, q_len, n_heads, self.head_dim).transpose(1, 2)

        query_states = split_heads(self.q_proj(hidden_states), self.num_heads)
        key_states = split_heads(self.k_proj(hidden_states), self.num_key_value_heads)
        value_states = split_heads(self.v_proj(hidden_states), self.num_key_value_heads)

        if self.qk_layernorm:
            query_states = self.q_layernorm(query_states)
            key_states = self.k_layernorm(key_states)

        cos, sin = position_embeddings

        # Partial rotary embedding: rotate the first `rotary_ndims` channels, pass the rest through untouched
        rot = self.rotary_ndims
        query_rot, key_rot = apply_rotary_pos_emb(query_states[..., :rot], key_states[..., :rot], cos, sin)
        # back to (batch, heads, seq_len, head_dim)
        query_states = torch.cat((query_rot, query_states[..., rot:]), dim=-1)
        key_states = torch.cat((key_rot, key_states[..., rot:]), dim=-1)

        if past_key_values is not None:
            # Specific to RoPE models with partial rotation
            cache_kwargs = {
                "sin": sin,
                "cos": cos,
                "partial_rotation_size": rot,
                "cache_position": cache_position,
            }
            key_states, value_states = past_key_values.update(key_states, value_states, self.layer_idx, cache_kwargs)

        # Repeat k/v heads if n_kv_heads < n_heads
        key_states = repeat_kv(key_states, self.num_key_value_groups)
        value_states = repeat_kv(value_states, self.num_key_value_groups)

        # no matter the length, we just slice the mask to the key length
        causal_mask = None if attention_mask is None else attention_mask[:, :, :, : key_states.shape[-2]]

        # SDPA with memory-efficient backend is currently (torch==2.1.2) bugged with non-contiguous inputs with custom attn_mask,
        # Reference: https://github.com/pytorch/pytorch/issues/112577.
        if query_states.device.type == "cuda" and attention_mask is not None:
            query_states, key_states, value_states = (
                query_states.contiguous(),
                key_states.contiguous(),
                value_states.contiguous(),
            )

        # We dispatch to SDPA's Flash Attention or Efficient kernels via this `is_causal` variable instead of an inline
        # conditional inside the SDPA call, to support both torch.compile's dynamic shapes and full graph options.
        # The q_len > 1 is necessary to match with AttentionMaskConverter.to_causal_4d that does not create a causal
        # mask in case q_len == 1.
        is_causal = bool(causal_mask is None and q_len > 1)
        dropout_p = self.attention_dropout.p if self.training else 0.0

        attn_output = torch.nn.functional.scaled_dot_product_attention(
            query_states,
            key_states,
            value_states,
            attn_mask=causal_mask,
            dropout_p=dropout_p,
            is_causal=is_causal,
        )

        attn_output = attn_output.transpose(1, 2).contiguous().view(bsz, q_len, self.hidden_size)
        attn_output = self.o_proj(attn_output)

        return attn_output, None
class StableLmFlashAttention2(StableLmAttention):
    """
    StableLM flash attention module. This module inherits from `StableLmAttention` as the weights of the module stays
    untouched. The only required change would be on the forward pass where it needs to correctly call the public API of
    flash attention and deal with padding tokens in case the input contains any of them.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

        # TODO: Should be removed once Flash Attention for RoCm is bumped to 2.1.
        # flash_attn<2.1 generates top-left aligned causal mask, while what is needed here is bottom-right alignment,
        # that was made default for flash_attn>=2.1. This attribute is used to handle this difference.
        # Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0.
        # Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a
        # wrong mask (top-left).
        self._flash_attn_uses_top_left_mask = flash_attn_supports_top_left_mask()

    @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58")
    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.LongTensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[Cache] = None,
        output_attentions: bool = False,
        use_cache: bool = False,
        cache_position: Optional[torch.LongTensor] = None,
        position_embeddings: Optional[tuple[torch.Tensor, torch.Tensor]] = None,  # necessary, but kept here for BC
        **kwargs,
    ) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[tuple[torch.Tensor]]]:
        """Run flash attention and return the 2-tuple `(attn_output, None)`.

        Attention weights are never produced: `output_attentions` is ignored. Extra `**kwargs` are
        accepted but not forwarded. (The return annotation lists three entries, but only two values are
        returned.)
        """
        bsz, q_len, _ = hidden_states.size()

        def split_heads(projected, n_heads):
            # (batch, seq_len, n_heads * head_dim) -> (batch, n_heads, seq_len, head_dim)
            # This layout is what the layer norms, the rotary embedding and the cache below operate on.
            return projected.view(bsz, q_len, n_heads, self.head_dim).transpose(1, 2)

        query_states = split_heads(self.q_proj(hidden_states), self.num_heads)
        key_states = split_heads(self.k_proj(hidden_states), self.num_key_value_heads)
        value_states = split_heads(self.v_proj(hidden_states), self.num_key_value_heads)

        if self.qk_layernorm:
            query_states = self.q_layernorm(query_states)
            key_states = self.k_layernorm(key_states)

        cos, sin = position_embeddings

        # Partial rotary embedding: rotate the first `rotary_ndims` channels, pass the rest through untouched
        rot = self.rotary_ndims
        query_rot, key_rot = apply_rotary_pos_emb(query_states[..., :rot], key_states[..., :rot], cos, sin)
        # back to (batch, heads, seq_len, head_dim)
        query_states = torch.cat((query_rot, query_states[..., rot:]), dim=-1)
        key_states = torch.cat((key_rot, key_states[..., rot:]), dim=-1)

        if past_key_values is not None:
            cache_kwargs = {
                "sin": sin,
                "cos": cos,
                "partial_rotation_size": rot,
                "cache_position": cache_position,
            }
            key_states, value_states = past_key_values.update(key_states, value_states, self.layer_idx, cache_kwargs)

        # TODO: These transpose are quite inefficient but Flash Attention requires the layout
        # [batch_size, sequence_length, num_heads, head_dim]. We would need to refactor the KV cache
        # to be able to avoid many of these transpose/reshape/view.
        query_states, key_states, value_states = (
            states.transpose(1, 2) for states in (query_states, key_states, value_states)
        )

        attn_output = _flash_attention_forward(
            query_states,
            key_states,
            value_states,
            attention_mask,
            q_len,
            position_ids=position_ids,
            dropout=self.attention_dropout.p if self.training else 0.0,
            use_top_left_mask=self._flash_attn_uses_top_left_mask,
            is_causal=self.is_causal,
        )

        attn_output = attn_output.reshape(bsz, q_len, self.hidden_size).contiguous()
        attn_output = self.o_proj(attn_output)

        # StableLmFlashAttention2 attention does not support output_attentions
        return attn_output, None
# Attention module class for each supported `config._attn_implementation` value;
# `StableLmDecoderLayer` indexes this mapping when it builds its `self_attn`.
ATTENTION_CLASSES: dict[str, type[StableLmAttention]] = {
    "eager": StableLmAttention,
    "sdpa": StableLmSdpaAttention,
    "flash_attention_2": StableLmFlashAttention2,
}
class StableLmDecoderLayer(GradientCheckpointingLayer):
    """One decoder block: self-attention plus gated MLP, with either parallel or sequential residuals.

    With `config.use_parallel_residual` the attention and the MLP both read the same input LayerNorm output
    and are summed onto the input; otherwise the MLP reads a second LayerNorm applied after the attention
    residual.
    """

    def __init__(self, config: StableLmConfig, layer_idx: int):
        super().__init__()
        self.use_parallel_residual = config.use_parallel_residual
        self.hidden_size = config.hidden_size

        attention_cls = ATTENTION_CLASSES[config._attn_implementation]
        self.self_attn = attention_cls(config, layer_idx=layer_idx)
        self.mlp = StableLmMLP(config)
        self.input_layernorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
        # The second norm only exists in the sequential-residual layout.
        self.post_attention_layernorm = (
            None if self.use_parallel_residual else nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
        )
        self.dropout = nn.Dropout(config.hidden_dropout)

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[Cache] = None,
        output_attentions: Optional[bool] = False,
        use_cache: Optional[bool] = False,
        cache_position: Optional[torch.LongTensor] = None,
        position_embeddings: Optional[tuple[torch.Tensor, torch.Tensor]] = None,  # necessary, but kept here for BC
    ) -> tuple[torch.FloatTensor, Optional[tuple[torch.FloatTensor, torch.FloatTensor]]]:
        """
        Args:
            hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
            attention_mask (`torch.FloatTensor`, *optional*): attention mask of size
                `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values.
            position_ids (`torch.LongTensor` of shape `({0})`, *optional*):
                Indices of positions of each input sequence tokens in the position embeddings. Selected in the range
                `[0, config.n_positions - 1]`.
                [What are position IDs?](../glossary#position-ids)
            past_key_values (`Cache`, *optional*):
                cached past key and value projection states
            output_attentions (`bool`, *optional*):
                Whether or not to return the attentions tensors of all attention layers. See `attentions` under
                returned tensors for more detail.
            use_cache (`bool`, *optional*):
                If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding
                (see `past_key_values`).
            cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*):
                Indices depicting the position of the input sequence tokens in the sequence
            position_embeddings (`tuple[torch.FloatTensor, torch.FloatTensor]`, *optional*):
                Tuple containing the cosine and sine positional embeddings of shape `(batch_size, seq_len, head_dim)`,
                with `head_dim` being the embedding dimension of each attention head.

        Returns:
            `(hidden_states,)`, or `(hidden_states, self_attn_weights)` when `output_attentions` is truthy.
        """
        normed_states = self.input_layernorm(hidden_states)

        # Self Attention
        attn_output, attn_weights = self.self_attn(
            hidden_states=normed_states,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            output_attentions=output_attentions,
            use_cache=use_cache,
            cache_position=cache_position,
            position_embeddings=position_embeddings,
        )

        # copied from transformers.models.gpt_neox.modeling_gpt_neox.GPTNeoXLayer.forward
        if self.use_parallel_residual:
            # x = x + attn(ln1(x)) + mlp(ln1(x))
            mlp_output = self.dropout(self.mlp(normed_states))
            hidden_states = hidden_states + attn_output + mlp_output
        else:
            # x = x + attn(ln1(x))
            # x = x + mlp(ln2(x))
            hidden_states = hidden_states + attn_output
            mlp_output = self.dropout(self.mlp(self.post_attention_layernorm(hidden_states)))
            hidden_states = hidden_states + mlp_output

        if output_attentions:
            return (hidden_states, attn_weights)
        return (hidden_states,)
@auto_docstring
class StableLmPreTrainedModel(PreTrainedModel):
    # Base class of the StableLM models: declares the config type, the capability flags read by
    # `PreTrainedModel`, and the weight-initialization scheme.
    config: StableLmConfig
    base_model_prefix = "model"
    supports_gradient_checkpointing = True
    _no_split_modules = ["StableLmDecoderLayer"]
    _skip_keys_device_placement = "past_key_values"
    _supports_flash_attn = True
    _supports_sdpa = True
    _can_compile_fullgraph = True

    def _init_weights(self, module):
        """Initialize the parameters of a single sub-module in place.

        * `nn.Linear`: weight ~ N(0, `config.initializer_range`), bias (if any) zeroed.
        * `nn.Embedding`: weight ~ N(0, `config.initializer_range`), padding row (if any) zeroed.
        * `nn.LayerNorm`: weight set to 1 and bias set to 0, each only when the parameter exists.

        Other module types are left untouched.
        """
        std = self.config.initializer_range
        if isinstance(module, nn.Linear):
            module.weight.data.normal_(mean=0.0, std=std)
            if module.bias is not None:
                module.bias.data.zero_()
        elif isinstance(module, nn.Embedding):
            module.weight.data.normal_(mean=0.0, std=std)
            if module.padding_idx is not None:
                module.weight.data[module.padding_idx].zero_()
        elif isinstance(module, nn.LayerNorm):
            # `StableLmLayerNormPerHead` builds its norms with `bias=False`, so `bias` can be None here;
            # `weight` is None as well when a LayerNorm is created with `elementwise_affine=False`.
            if module.weight is not None:
                module.weight.data.fill_(1.0)
            if module.bias is not None:
                module.bias.data.zero_()
- @auto_docstring
- class StableLmModel(StableLmPreTrainedModel):
- """
- Transformer decoder consisting of *config.num_hidden_layers* layers. Each layer is a [`StableLmDecoderLayer`]
- Args:
- config: StableLmConfig
- """
def __init__(self, config: StableLmConfig):
    """Build the token embedding, the stack of decoder layers, the final LayerNorm and the shared rotary embedding."""
    super().__init__(config)
    self.padding_idx = config.pad_token_id
    self.vocab_size = config.vocab_size

    self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
    decoder_layers = [StableLmDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
    self.layers = nn.ModuleList(decoder_layers)
    self.norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
    self.rotary_emb = StableLmRotaryEmbedding(config=config)

    self._attn_implementation = config._attn_implementation
    self.gradient_checkpointing = False

    # Initialize weights and apply final processing
    self.post_init()
@can_return_tuple
@auto_docstring
def forward(
    self,
    input_ids: Optional[torch.LongTensor] = None,
    attention_mask: Optional[torch.Tensor] = None,
    position_ids: Optional[torch.LongTensor] = None,
    past_key_values: Optional[Cache] = None,
    inputs_embeds: Optional[torch.FloatTensor] = None,
    use_cache: Optional[bool] = None,
    output_attentions: Optional[bool] = None,
    output_hidden_states: Optional[bool] = None,
    cache_position: Optional[torch.LongTensor] = None,
) -> BaseModelOutputWithPast:
    # Flags left as None fall back to the values stored on the config.
    if output_attentions is None:
        output_attentions = self.config.output_attentions
    if output_hidden_states is None:
        output_hidden_states = self.config.output_hidden_states
    if use_cache is None:
        use_cache = self.config.use_cache

    # Exactly one of the two input forms must be given (both None or both set is an error).
    if (input_ids is None) == (inputs_embeds is None):
        raise ValueError("You must specify exactly one of input_ids or inputs_embeds")

    if self.gradient_checkpointing and self.training and use_cache:
        logger.warning_once(
            "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
        )
        use_cache = False

    if use_cache and past_key_values is None:
        past_key_values = DynamicCache(config=self.config)

    if inputs_embeds is None:
        inputs_embeds = self.embed_tokens(input_ids)

    if cache_position is None:
        # Positions continue from whatever the cache already holds.
        past_seen_tokens = 0 if past_key_values is None else past_key_values.get_seq_length()
        cache_position = torch.arange(
            past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device
        )

    if position_ids is None:
        position_ids = cache_position.unsqueeze(0)

    causal_mask = self._update_causal_mask(
        attention_mask, inputs_embeds, cache_position, past_key_values, output_attentions
    )

    hidden_states = inputs_embeds

    # create position embeddings to be shared across the decoder layers
    position_embeddings = self.rotary_emb(hidden_states, position_ids)

    # decoder layers
    collected_hidden_states = []
    collected_attentions = []
    for decoder_layer in self.layers:
        if output_hidden_states:
            # state entering this layer
            collected_hidden_states.append(hidden_states)

        layer_outputs = decoder_layer(
            hidden_states,
            attention_mask=causal_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            output_attentions=output_attentions,
            use_cache=use_cache,
            cache_position=cache_position,
            position_embeddings=position_embeddings,
        )
        hidden_states = layer_outputs[0]

        if output_attentions:
            collected_attentions.append(layer_outputs[1])

    hidden_states = self.norm(hidden_states)

    # add hidden states from the last decoder layer
    if output_hidden_states:
        collected_hidden_states.append(hidden_states)

    all_hidden_states = tuple(collected_hidden_states) if output_hidden_states else None
    all_self_attns = tuple(collected_attentions) if output_attentions else None

    return BaseModelOutputWithPast(
        last_hidden_state=hidden_states,
        past_key_values=past_key_values,
        hidden_states=all_hidden_states,
        attentions=all_self_attns,
    )
# Copied from transformers.models.gptj.modeling_gptj.GPTJModel._update_causal_mask
def _update_causal_mask(
    self,
    attention_mask: Union[torch.Tensor, "BlockMask"],
    input_tensor: torch.Tensor,
    cache_position: torch.Tensor,
    past_key_values: Cache,
    output_attentions: bool = False,
):
    """
    Pick the attention mask to hand to the decoder layers, based on `config._attn_implementation`.

    - `"flash_attention_2"`: the given mask is returned untouched when it contains at least one `0.0` entry,
      otherwise `None`.
    - `"flex_attention"`: a tensor mask is converted with `make_flex_block_causal_mask`; anything else is returned
      as received.
    - otherwise: a 4D mask is built with `_prepare_4d_causal_attention_mask_with_cache_position`, except that
      `None` is returned when the backend is `"sdpa"`, the cache is not compilable, attentions are not requested
      and `AttentionMaskConverter._ignore_causal_mask_sdpa` reports that the mask can be dropped.
    """
    attn_impl = self.config._attn_implementation

    if attn_impl == "flash_attention_2":
        has_masked_entry = attention_mask is not None and (attention_mask == 0.0).any()
        return attention_mask if has_masked_entry else None

    if attn_impl == "flex_attention":
        if isinstance(attention_mask, torch.Tensor):
            return make_flex_block_causal_mask(attention_mask)
        return attention_mask

    # For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in
    # order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail
    # to infer the attention mask.
    if past_key_values is None:
        seen_tokens = 0
        compilable_cache = False
    else:
        seen_tokens = past_key_values.get_seq_length()
        compilable_cache = past_key_values.is_compileable

    # When output attentions is True, sdpa implementation's forward method calls the eager implementation's forward
    may_skip_mask = attn_impl == "sdpa" and not compilable_cache and not output_attentions
    if may_skip_mask and AttentionMaskConverter._ignore_causal_mask_sdpa(
        attention_mask,
        inputs_embeds=input_tensor,
        past_key_values_length=seen_tokens,
        is_training=self.training,
    ):
        return None

    dtype = input_tensor.dtype
    sequence_length = input_tensor.shape[1]

    # Key/value extent of the mask: the full cache when it is compilable, else the width of the provided
    # mask, else the tokens seen so far plus the current ones plus one.
    if compilable_cache:
        target_length = past_key_values.get_max_cache_shape()
    elif isinstance(attention_mask, torch.Tensor):
        target_length = attention_mask.shape[-1]
    else:
        target_length = seen_tokens + sequence_length + 1

    # In case the provided `attention` mask is 2D, we generate a causal mask here (4D).
    causal_mask = self._prepare_4d_causal_attention_mask_with_cache_position(
        attention_mask,
        sequence_length=sequence_length,
        target_length=target_length,
        dtype=dtype,
        cache_position=cache_position,
        batch_size=input_tensor.shape[0],
    )

    needs_unmasking = (
        attn_impl == "sdpa"
        and attention_mask is not None
        and attention_mask.device.type in ["cuda", "xpu", "npu"]
        and not output_attentions
    )
    if needs_unmasking:
        # Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when
        # using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path.
        # Details: https://github.com/pytorch/pytorch/issues/110213
        causal_mask = AttentionMaskConverter._unmask_unattended(causal_mask, torch.finfo(dtype).min)

    return causal_mask
@staticmethod
# Copied from transformers.models.gptj.modeling_gptj.GPTJModel._prepare_4d_causal_attention_mask_with_cache_position
def _prepare_4d_causal_attention_mask_with_cache_position(
    attention_mask: torch.Tensor,
    sequence_length: int,
    target_length: int,
    dtype: torch.dtype,
    cache_position: torch.Tensor,
    batch_size: int,
    **kwargs,
):
    """
    Build an additive causal mask of shape `(batch_size, 1, sequence_length, target_length)`.

    Blocked positions hold `torch.finfo(dtype).min`, allowed positions hold zero. A 4D `attention_mask` is returned
    as-is; a 2D one is folded into the causal mask so that keys marked `0` are blocked for every query.

    Args:
        attention_mask (`torch.Tensor`):
            `None`, a 2D mask of shape `(batch_size, key_value_length)`, or a 4D mask of shape
            `(batch_size, 1, query_length, key_value_length)`.
        sequence_length (`int`):
            Number of query positions being processed.
        target_length (`int`):
            Number of key/value positions the mask must cover. When generating with a static cache, the mask
            should be as long as the static cache, to account for the part of the cache that is not filled yet.
        dtype (`torch.dtype`):
            The dtype of the generated mask.
        cache_position (`torch.Tensor`):
            Indices depicting the position of the input sequence tokens in the sequence.
        batch_size (`int`):
            Batch size.

    Returns:
        `torch.Tensor`: the 4D mask (the very same object as `attention_mask` when that one is already 4D).
    """
    # A 4D mask is assumed to come already in inverted form: no inversion or slicing is applied.
    if attention_mask is not None and attention_mask.dim() == 4:
        return attention_mask

    device = cache_position.device
    min_value = torch.finfo(dtype).min

    # Start fully blocked, then open up everything that is not strictly in the future.
    mask = torch.full((sequence_length, target_length), fill_value=min_value, dtype=dtype, device=device)
    if sequence_length != 1:
        mask = mask.triu(diagonal=1)
    key_index = torch.arange(target_length, device=device)
    mask *= key_index > cache_position.reshape(-1, 1)
    mask = mask.unsqueeze(0).unsqueeze(0).expand(batch_size, 1, -1, -1)

    if attention_mask is None:
        return mask

    # `expand` returns a broadcast view: copy to contiguous memory before the in-place edit.
    mask = mask.clone()
    mask_length = attention_mask.shape[-1]
    visible = mask[..., :mask_length]
    # A zero sum means "causally visible" (0 in the mask) AND "padding" (0 in `attention_mask`).
    is_padding = (visible + attention_mask[:, None, None, :].to(mask.device)) == 0
    mask[..., :mask_length] = visible.masked_fill(is_padding, min_value)
    return mask
# Copied from transformers.models.persimmon.modeling_persimmon.PersimmonForCausalLM with PERSIMMON->STABLELM,Persimmon->StableLm
class StableLmForCausalLM(StableLmPreTrainedModel, GenerationMixin):
    # Wraps `StableLmModel` and projects its last hidden state to vocabulary logits through `lm_head`,
    # a bias-free linear layer of shape (hidden_size -> vocab_size).
    _tied_weights_keys = ["lm_head.weight"]

    # Copied from transformers.models.llama.modeling_llama.LlamaForCausalLM.__init__ with LLAMA->STABLELM,Llama->StableLm
    def __init__(self, config):
        super().__init__(config)
        self.model = StableLmModel(config)
        self.vocab_size = config.vocab_size
        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)

        # Initialize weights and apply final processing
        self.post_init()

    @can_return_tuple
    @auto_docstring
    # Ignore copy
    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[Cache] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        cache_position: Optional[torch.LongTensor] = None,
        logits_to_keep: Union[int, torch.Tensor] = 0,
        **kwargs,
    ) -> CausalLMOutputWithPast:
        r"""
        labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
            Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
            config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
            (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.

        Example:

        ```python
        >>> from transformers import AutoTokenizer, StableLmForCausalLM

        >>> model = StableLmForCausalLM.from_pretrained("stabilityai/stablelm-3b-4e1t")
        >>> tokenizer = AutoTokenizer.from_pretrained("stabilityai/stablelm-3b-4e1t")

        >>> prompt = "The weather is always wonderful in"
        >>> inputs = tokenizer(prompt, return_tensors="pt")

        >>> # Generate
        >>> generate_ids = model.generate(inputs.input_ids, max_length=30)
        >>> completion = tokenizer.batch_decode(
        ...     generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False
        ... )[0]
        ```"""
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )

        # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
        outputs: BaseModelOutputWithPast = self.model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            cache_position=cache_position,
        )

        hidden_states = outputs.last_hidden_state
        # No upscaling to float was ever done for StableLm
        # An int keeps only the last `logits_to_keep` positions; the default 0 gives `slice(0, None)`, i.e. every
        # position. A tensor is used directly as an index along the sequence dimension.
        slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
        logits = self.lm_head(hidden_states[:, slice_indices, :])

        # The loss is only computed when labels are given; extra keyword arguments go to the loss function.
        loss = None
        if labels is not None:
            loss = self.loss_function(
                logits,
                labels,
                vocab_size=self.config.vocab_size,
                **kwargs,
            )

        return CausalLMOutputWithPast(
            loss=loss,
            logits=logits,
            past_key_values=outputs.past_key_values,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )
# Sequence-classification model: defines nothing of its own, everything is inherited from
# `GenericForSequenceClassification` and `StableLmPreTrainedModel`.
class StableLmForSequenceClassification(GenericForSequenceClassification, StableLmPreTrainedModel):
    pass
# Token-classification model: defines nothing of its own, everything is inherited from
# `GenericForTokenClassification` and `StableLmPreTrainedModel`.
class StableLmForTokenClassification(GenericForTokenClassification, StableLmPreTrainedModel):
    pass
# Public names exported by this module.
__all__ = [
    "StableLmForCausalLM",
    "StableLmModel",
    "StableLmPreTrainedModel",
    "StableLmForSequenceClassification",
    "StableLmForTokenClassification",
]
|