| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693694695696697698699700701702
703704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754 |
- # 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
- # This file was automatically generated from src/transformers/models/lfm2/modular_lfm2.py.
- # Do NOT edit this file manually as any edits will be overwritten by the generation of
- # the file from the modular. If any change should be done, please apply the change to the
- # modular_lfm2.py file directly. One of our CI enforces this.
- # 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
- # Copyright 2025 The HuggingFace Team. All rights reserved.
- #
- # Licensed under the Apache License, Version 2.0 (the "License");
- # you may not use this file except in compliance with the License.
- # You may obtain a copy of the License at
- #
- # http://www.apache.org/licenses/LICENSE-2.0
- #
- # Unless required by applicable law or agreed to in writing, software
- # distributed under the License is distributed on an "AS IS" BASIS,
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- # See the License for the specific language governing permissions and
- # limitations under the License.
- from typing import Any, Callable, Optional, Union
- import torch
- import torch.nn.functional as F
- from torch import nn
- from ...cache_utils import Cache
- from ...generation import GenerationMixin
- from ...integrations import use_kernel_forward_from_hub
- from ...masking_utils import create_causal_mask
- from ...modeling_layers import GradientCheckpointingLayer
- from ...modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
- from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update
- from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
- from ...processing_utils import Unpack
- from ...utils import TransformersKwargs, auto_docstring, can_return_tuple
- from ...utils.deprecation import deprecate_kwarg
- from ...utils.generic import check_model_inputs
- from ...utils.import_utils import is_causal_conv1d_available
- from .configuration_lfm2 import Lfm2Config
- if is_causal_conv1d_available():
- from causal_conv1d import causal_conv1d_fn, causal_conv1d_update
- else:
- causal_conv1d_fn, causal_conv1d_update = None, None
@use_kernel_forward_from_hub("RMSNorm")
class Lfm2RMSNorm(nn.Module):
    """Root-mean-square normalisation over the last dimension with a learnable scale."""

    def __init__(self, hidden_size, eps=1e-6):
        """
        Lfm2RMSNorm is equivalent to T5LayerNorm

        Args:
            hidden_size: Length of the learnable `weight` vector (initialised to ones).
            eps: Constant added to the mean of squares before the reciprocal square root.
        """
        super().__init__()
        self.variance_epsilon = eps
        self.weight = nn.Parameter(torch.ones(hidden_size))

    def forward(self, hidden_states):
        """Normalise `hidden_states` along its last dimension and apply `weight`.

        The statistics are computed in float32; the normalised tensor is cast back
        to the input dtype before it is multiplied by `weight`.
        """
        original_dtype = hidden_states.dtype
        upcast = hidden_states.to(torch.float32)
        mean_square = upcast.pow(2).mean(-1, keepdim=True)
        normalised = upcast * torch.rsqrt(mean_square + self.variance_epsilon)
        return self.weight * normalised.to(original_dtype)

    def extra_repr(self):
        """Return the weight shape and epsilon for the module's printed representation."""
        return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}"
class Lfm2RotaryEmbedding(nn.Module):
    """Produces the cos/sin tables consumed by `apply_rotary_pos_emb`."""

    inv_freq: torch.Tensor  # fix linting for `register_buffer`

    def __init__(self, config: Lfm2Config, device=None):
        """Select the RoPE variant from `config` and register its inverse frequencies.

        Args:
            config: Model configuration; `rope_scaling` (when it is a dict) and
                `max_position_embeddings` are read here.
            device: Forwarded to the RoPE initialisation function.
        """
        super().__init__()
        # BC: "rope_type" was originally "type"
        rope_scaling = getattr(config, "rope_scaling", None)
        if isinstance(rope_scaling, dict):
            self.rope_type = rope_scaling.get("rope_type", rope_scaling.get("type"))
        else:
            self.rope_type = "default"

        self.max_seq_len_cached = config.max_position_embeddings
        self.original_max_seq_len = config.max_position_embeddings
        self.config = config

        self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]
        inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device)
        # Non-persistent: the buffer is rebuilt from the config instead of being saved.
        self.register_buffer("inv_freq", inv_freq, persistent=False)
        self.original_inv_freq = self.inv_freq

    @torch.no_grad()
    @dynamic_rope_update  # power user: used with advanced RoPE types (e.g. dynamic rope)
    def forward(self, x, position_ids):
        """Return `(cos, sin)` for `position_ids`, cast to the dtype of `x`.

        `x` is only used for its device and dtype.
        """
        batch = position_ids.shape[0]
        inv_freq_expanded = self.inv_freq[None, :, None].float().expand(batch, -1, 1).to(x.device)
        position_ids_expanded = position_ids[:, None, :].float()

        # autocast is entered with "cpu" when the device type is "mps" or not a string.
        device_type = x.device.type
        if not isinstance(device_type, str) or device_type == "mps":
            device_type = "cpu"

        with torch.autocast(device_type=device_type, enabled=False):  # Force float32
            freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
            emb = torch.cat((freqs, freqs), dim=-1)
            cos = emb.cos() * self.attention_scaling
            sin = emb.sin() * self.attention_scaling

        return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
class Lfm2MLP(nn.Module):
    """Gated feed-forward block: `w2(silu(w1(x)) * w3(x))`, all projections bias-free."""

    def __init__(self, config: Lfm2Config):
        """Derive the inner width from `config` and build the three projections.

        Reads `intermediate_size`, `block_auto_adjust_ff_dim`,
        `block_ffn_dim_multiplier`, `block_multiple_of` and `hidden_size`.
        """
        super().__init__()
        width = config.intermediate_size
        # NOTE(review): the nesting below was reconstructed from a listing that had
        # lost its indentation (multiplier and round-up both inside the auto-adjust
        # branch) -- confirm against modular_lfm2.py.
        if config.block_auto_adjust_ff_dim:
            width = int(2 * width / 3)
            # custom dim factor multiplier
            multiplier = config.block_ffn_dim_multiplier
            if multiplier is not None:
                width = int(multiplier * width)
                # Round up to the next multiple of `block_multiple_of`.
                multiple = config.block_multiple_of
                width = multiple * ((width + multiple - 1) // multiple)

        self.w1 = nn.Linear(config.hidden_size, width, bias=False)
        self.w3 = nn.Linear(config.hidden_size, width, bias=False)
        self.w2 = nn.Linear(width, config.hidden_size, bias=False)

    def forward(self, x):
        """Apply the gated projection to `x` and map it back to `hidden_size`."""
        gate = F.silu(self.w1(x))
        up = self.w3(x)
        return self.w2(gate * up)
- class Lfm2HybridConvCache:
- """
- Attention and conv cache for Lfm2.
- It stores the Key and Value states as a list of tensors, one for each layer.
- Attention layer cache shape: `[batch_size, num_heads, seq_len, head_dim]`.
- Conv layer cache shape: `[batch_size, hidden_size, L_cache-1]`.
- """
- # Override @property existing in Cache
- max_batch_size = None
- is_compileable = False
- key_cache = None
- value_cache = None
- def __init__(
- self,
- config: Lfm2Config,
- max_batch_size: int,
- dtype: torch.dtype = torch.float32,
- device: Union[torch.device, str, None] = None,
- ):
- self.key_cache = []
- self.value_cache = []
- self.max_batch_size = max_batch_size
- self.layer_types = config.layer_types
- self.first_attention_layer = self.layer_types.index("full_attention")
- self.conv_L_cache = config.conv_L_cache
- self._dtype = dtype
- self.conv_cache: list[torch.Tensor] = []
- device = torch.device(device) if device is not None else None
- for _ in range(config.num_hidden_layers):
- conv_state = torch.zeros(
- self.max_batch_size,
- config.hidden_size,
- self.conv_L_cache,
- dtype=self._dtype,
- device=device,
- )
- torch._dynamo.mark_static_address(conv_state)
- self.conv_cache.append(conv_state)
- def update(
- self,
- key_states: torch.Tensor,
- value_states: torch.Tensor,
- layer_idx: int,
- cache_kwargs: Optional[dict[str, Any]] = None,
- ) -> tuple[torch.Tensor, torch.Tensor]:
- """
- Updates the cache with the new `key_states` and `value_states` for the layer `layer_idx`.
- Parameters:
- key_states (`torch.Tensor`):
- The new key states to cache.
- value_states (`torch.Tensor`):
- The new value states to cache.
- layer_idx (`int`):
- The index of the layer to cache the states for.
- cache_kwargs (`Dict[str, Any]`, `optional`):
- Additional arguments for the cache subclass. No additional arguments are used in `DynamicCache`.
- Return:
- A tuple containing the updated key and value states.
- """
- # Update the cache
- if key_states is not None:
- if len(self.key_cache) <= layer_idx:
- # There may be skipped layers, fill them with empty lists
- for _ in range(len(self.key_cache), layer_idx):
- self.key_cache.append(torch.tensor([]))
- self.value_cache.append(torch.tensor([]))
- self.key_cache.append(key_states)
- self.value_cache.append(value_states)
- elif (
- not self.key_cache[layer_idx].numel() # prefers not t.numel() to len(t) == 0 to export the model
- ): # fills previously skipped layers; checking for tensor causes errors
- self.key_cache[layer_idx] = key_states
- self.value_cache[layer_idx] = value_states
- else:
- self.key_cache[layer_idx] = torch.cat([self.key_cache[layer_idx], key_states], dim=-2)
- self.value_cache[layer_idx] = torch.cat([self.value_cache[layer_idx], value_states], dim=-2)
- return self.key_cache[layer_idx], self.value_cache[layer_idx]
- def reorder_cache(self, beam_idx: torch.LongTensor):
- """Reorders the cache for beam search, given the selected beam indices."""
- for layer_idx in range(len(self.key_cache)):
- device = self.key_cache[layer_idx].device
- self.key_cache[layer_idx] = self.key_cache[layer_idx].index_select(0, beam_idx.to(device))
- device = self.value_cache[layer_idx].device
- self.value_cache[layer_idx] = self.value_cache[layer_idx].index_select(0, beam_idx.to(device))
- device = self.conv_cache[layer_idx].device
- self.conv_cache[layer_idx] = self.conv_cache[layer_idx].index_select(0, beam_idx.to(device))
- def get_seq_length(self, layer_idx: Optional[int] = 0) -> int:
- """Returns the sequence length of the cached states. A layer index can be optionally passed."""
- # take any layer that contains cache and not empty tensor
- layer_idx = self.first_attention_layer if self.layer_types[layer_idx] != "full_attention" else layer_idx
- if len(self.key_cache) <= layer_idx or self.key_cache[layer_idx].numel() == 0:
- return 0
- return self.key_cache[layer_idx].shape[-2]
- def get_mask_sizes(self, cache_position: torch.Tensor, layer_idx: int) -> tuple[int, int]:
- """
- Return a tuple (kv_length, kv_offset) corresponding to the length and offset that will be returned for
- the given layer at `layer_idx`.
- The masks are then prepared according to the given lengths (kv_length, kv_offset) and patterns (i.e. sliding_window, chunk_size),
- for each layer.
- """
- full_mask_kv_offset = 0
- query_length = cache_position.shape[0]
- past_seen_tokens = self.get_seq_length()
- kv_length = query_length + past_seen_tokens
- return kv_length, full_mask_kv_offset
- def crop(self, max_length: int):
- """Crop the cache to the given length"""
- if max_length < 0:
- max_length = self.get_seq_length() - abs(max_length)
- if self.get_seq_length() <= max_length:
- return
- for idx in range(len(self.key_cache)):
- if self.key_cache[idx].numel():
- self.key_cache[idx] = self.key_cache[idx][..., :max_length, :]
- self.value_cache[idx] = self.value_cache[idx][..., :max_length, :]
- def __len__(self) -> int:
- return len(self.key_cache)
- def __getitem__(self, layer_idx: int) -> tuple[torch.Tensor, torch.Tensor]:
- return self.key_cache[layer_idx], self.value_cache[layer_idx]
- def reset(self):
- for layer_idx in range(len(self.conv_cache)):
- # In-place ops prevent breaking the static address
- self.conv_cache[layer_idx].zero_()
def rotate_half(x):
    """Rotates half the hidden dims of the input.

    The last dimension is cut at `x.shape[-1] // 2`; the result is the negated
    second part followed by the first part.
    """
    width = x.shape[-1]
    half = width // 2
    front, back = x.split((half, width - half), dim=-1)
    return torch.cat((-back, front), dim=-1)
def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
    """Applies Rotary Position Embedding to the query and key tensors.

    Args:
        q (`torch.Tensor`): The query tensor.
        k (`torch.Tensor`): The key tensor.
        cos (`torch.Tensor`): The cosine part of the rotary embedding.
        sin (`torch.Tensor`): The sine part of the rotary embedding.
        position_ids (`torch.Tensor`, *optional*):
            Deprecated and unused.
        unsqueeze_dim (`int`, *optional*, defaults to 1):
            Dimension inserted into `cos` and `sin` so that they broadcast against `q` and `k`.
            For example, if `cos`/`sin` have the shape [batch_size, seq_len, head_dim] and `q`/`k`
            have the shape [batch_size, heads, seq_len, head_dim], use unsqueeze_dim=1; if `q`/`k`
            have the shape [batch_size, seq_len, heads, head_dim], use unsqueeze_dim=2.

    Returns:
        `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding.
    """
    cos_b = cos.unsqueeze(unsqueeze_dim)
    sin_b = sin.unsqueeze(unsqueeze_dim)

    def _rotate(states):
        # Same formula for queries and keys.
        return (states * cos_b) + (rotate_half(states) * sin_b)

    return _rotate(q), _rotate(k)
def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
    """
    This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch,
    num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim)

    When `n_rep == 1` the input tensor itself is returned.
    """
    # Unpacking first means a non-4D input fails even when no repetition is needed.
    batch, kv_heads, seq_len, head_dim = hidden_states.shape
    if n_rep != 1:
        # Add a repeat axis after the head axis, broadcast it, then fold it into the heads.
        expanded = hidden_states[:, :, None, :, :].expand(batch, kv_heads, n_rep, seq_len, head_dim)
        return expanded.reshape(batch, kv_heads * n_rep, seq_len, head_dim)
    return hidden_states
def eager_attention_forward(
    module: nn.Module,
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    attention_mask: Optional[torch.Tensor],
    scaling: float,
    dropout: float = 0.0,
    **kwargs: Unpack[TransformersKwargs],
):
    """Reference attention built from matmul and softmax.

    Args:
        module: Provides `num_key_value_groups` (key/value head repetition factor)
            and `training` (whether dropout is active).
        query, key, value: Attention inputs; `key` and `value` are repeated with
            `repeat_kv` before use.
        attention_mask: Optional additive mask; it is sliced on its last axis to
            the key length before being added to the scores.
        scaling: Factor the raw scores are multiplied by.
        dropout: Dropout probability applied to the attention weights.
        **kwargs: Accepted and ignored.

    Returns:
        `(attn_output, attn_weights)`; `attn_output` has axes 1 and 2 swapped
        relative to the matmul result and is contiguous.
    """
    groups = module.num_key_value_groups
    key_states = repeat_kv(key, groups)
    value_states = repeat_kv(value, groups)

    scores = torch.matmul(query, key_states.transpose(2, 3)) * scaling
    if attention_mask is not None:
        kv_len = key_states.shape[-2]
        scores = scores + attention_mask[:, :, :, :kv_len]

    # Softmax runs in float32; the weights are cast back to the query dtype.
    attn_weights = nn.functional.softmax(scores, dim=-1, dtype=torch.float32).to(query.dtype)
    attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training)

    attn_output = torch.matmul(attn_weights, value_states).transpose(1, 2).contiguous()
    return attn_output, attn_weights
class Lfm2Attention(nn.Module):
    """Multi-headed attention from 'Attention Is All You Need' paper"""

    def __init__(self, config: Lfm2Config, layer_idx: int):
        """Build the bias-free projections and the per-head RMS norms for queries and keys.

        Args:
            config: Model configuration; head counts, `hidden_size`, `norm_eps` and
                the optional `head_dim` are read.
            layer_idx: Index handed to the cache when key/value states are stored.
        """
        super().__init__()
        self.config = config
        self.layer_idx = layer_idx
        # Falls back to hidden_size // num_attention_heads when the config has no `head_dim`.
        self.head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads)
        self.num_key_value_groups = config.num_attention_heads // config.num_key_value_heads
        self.scaling = self.head_dim**-0.5
        self.is_causal = True

        query_width = config.num_attention_heads * self.head_dim
        kv_width = config.num_key_value_heads * self.head_dim
        self.q_proj = nn.Linear(config.hidden_size, query_width, bias=False)
        self.k_proj = nn.Linear(config.hidden_size, kv_width, bias=False)
        self.v_proj = nn.Linear(config.hidden_size, kv_width, bias=False)
        self.out_proj = nn.Linear(query_width, config.hidden_size, bias=False)
        self.q_layernorm = Lfm2RMSNorm(self.head_dim, eps=config.norm_eps)
        self.k_layernorm = Lfm2RMSNorm(self.head_dim, eps=config.norm_eps)

    @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58")
    def forward(
        self,
        hidden_states: torch.Tensor,
        position_embeddings: tuple[torch.Tensor, torch.Tensor],
        attention_mask: Optional[torch.Tensor],
        past_key_values: Optional[Lfm2HybridConvCache] = None,
        cache_position: Optional[torch.LongTensor] = None,
        **kwargs,
    ) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
        """Run attention over `hidden_states`.

        Args:
            hidden_states: Input whose last dimension is projected to heads.
            position_embeddings: `(cos, sin)` pair passed to `apply_rotary_pos_emb`.
            attention_mask: Forwarded unchanged to the attention implementation.
            past_key_values: When given, the new key/value states are stored through
                its `update` method and the returned states are attended over.
            cache_position: Forwarded to the cache inside `cache_kwargs`.
            **kwargs: Forwarded to the attention implementation.

        Returns:
            `(output, attn_weights)` where `output` went through `out_proj`.
        """
        leading_shape = hidden_states.shape[:-1]
        per_head_shape = (*leading_shape, -1, self.head_dim)

        def _split_heads(projected):
            # Split the last axis into (heads, head_dim); the head axis moves to position 1 later.
            return projected.view(*per_head_shape)

        query_states = self.q_layernorm(_split_heads(self.q_proj(hidden_states))).transpose(1, 2)
        key_states = self.k_layernorm(_split_heads(self.k_proj(hidden_states))).transpose(1, 2)
        value_states = _split_heads(self.v_proj(hidden_states)).transpose(1, 2)

        cos, sin = position_embeddings
        query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)

        if past_key_values is not None:
            cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
            key_states, value_states = past_key_values.update(key_states, value_states, self.layer_idx, cache_kwargs)

        if self.config._attn_implementation == "eager":
            attention_interface: Callable = eager_attention_forward
        else:
            attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]

        attn_output, attn_weights = attention_interface(
            self,
            query_states,
            key_states,
            value_states,
            attention_mask,
            dropout=0.0,
            scaling=self.scaling,
            **kwargs,
        )

        # Merge the head axes back into a single feature axis before the output projection.
        merged = attn_output.reshape(*leading_shape, -1).contiguous()
        return self.out_proj(merged), attn_weights
def apply_mask_to_padding_states(hidden_states, attention_mask):
    """
    Tunes out the hidden states for padding tokens, see https://github.com/state-spaces/mamba/issues/66

    The input is returned untouched when there is no mask or when either of the
    mask's first two dimensions is 1 or smaller; otherwise it is multiplied by the
    mask (broadcast over a trailing axis) and cast back to its original dtype.
    """
    if attention_mask is None:
        return hidden_states
    if attention_mask.shape[1] <= 1 or attention_mask.shape[0] <= 1:
        return hidden_states

    original_dtype = hidden_states.dtype
    masked = hidden_states * attention_mask[:, :, None]
    return masked.to(original_dtype)
# Both entries are None when the optional `causal_conv1d` package could not be imported
# (see the guarded import near the top of this file).
kernel_modules = causal_conv1d_fn, causal_conv1d_update
# True only when every kernel entry point is usable.
is_fast_path_available = all(kernel_modules)
class Lfm2ShortConv(nn.Module):
    """Gated depthwise convolution block with an optional rolling conv-state cache."""

    def __init__(
        self,
        config: Lfm2Config,
        layer_idx: int,
    ):
        """Build the depthwise convolution and the input/output projections.

        Args:
            config: Model configuration; `conv_L_cache`, `conv_bias` and `hidden_size` are read.
            layer_idx: Index of this layer's entry in the cache's `conv_cache` list.
        """
        super().__init__()
        self.config = config
        self.layer_idx = layer_idx
        self.L_cache = config.conv_L_cache
        self.bias = config.conv_bias

        channels = config.hidden_size
        # groups == channels: every channel is convolved independently.
        self.conv = nn.Conv1d(
            in_channels=channels,
            out_channels=channels,
            kernel_size=self.L_cache,
            groups=channels,
            bias=self.bias,
            padding=self.L_cache - 1,
        )
        self.in_proj = nn.Linear(channels, 3 * channels, bias=self.bias)
        self.out_proj = nn.Linear(channels, channels, bias=self.bias)

    @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58")
    def cuda_kernels_forward(
        self,
        x: torch.Tensor,
        past_key_values: Optional[Lfm2HybridConvCache] = None,
        cache_position: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
    ):
        """Forward pass that delegates the convolution to the `causal_conv1d` kernels.

        With a cache and `cache_position[0] > 0` the single-step update kernel is
        used on this layer's conv state; otherwise the full-sequence kernel runs and,
        if a cache is present, its conv state is first filled from the gated input.
        """
        x = apply_mask_to_padding_states(x, attention_mask)
        # Project to 3 * hidden_size features, put them on axis -2, and split in three.
        projected = self.in_proj(x).transpose(-1, -2)
        b_gate, c_gate, x_proj = projected.chunk(3, dim=-2)
        gated = b_gate * x_proj

        # Drop the singleton middle axis of the depthwise Conv1d weight.
        weight = self.conv.weight
        conv_weights = weight.view(weight.size(0), weight.size(2))

        if past_key_values is not None and cache_position[0] > 0:
            conv_out = causal_conv1d_update(
                gated.squeeze(-1),
                past_key_values.conv_cache[self.layer_idx],
                conv_weights,
                self.conv.bias,
                None,
            )
            conv_out = conv_out.unsqueeze(-1)
        else:
            if past_key_values is not None:
                # Negative padding crops from the left when the input is longer than L_cache.
                padded = nn.functional.pad(gated, (self.L_cache - gated.shape[-1], 0))
                past_key_values.conv_cache[self.layer_idx].copy_(padded)
            conv_out = causal_conv1d_fn(gated, conv_weights, self.conv.bias, activation=None)

        y = c_gate * conv_out
        return self.out_proj(y.transpose(-1, -2).contiguous())

    @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58")
    def slow_forward(
        self,
        x: torch.Tensor,
        past_key_values: Optional[Lfm2HybridConvCache] = None,
        cache_position: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
    ):
        """Pure-PyTorch forward pass, used when the fused kernels are not selected.

        With a cache and `cache_position[0] > 0` the cached conv state is rolled by
        one step, the new gated input is written into it, and the convolution is
        evaluated as a weighted sum over the state. Otherwise `self.conv` runs over
        the whole input and, if a cache is present, its conv state is filled first.
        """
        seqlen = x.shape[1]

        x = apply_mask_to_padding_states(x, attention_mask)
        projected = self.in_proj(x).transpose(-1, -2)
        b_gate, c_gate, x_proj = projected.chunk(3, dim=-2)
        gated = b_gate * x_proj

        if past_key_values is not None and cache_position[0] > 0:
            layer_state = past_key_values.conv_cache[self.layer_idx]
            cache_position = cache_position.clamp(0, self.L_cache - 1)
            # `roll` returns a new tensor; the cache entry itself is updated by `copy_` below.
            rolled = layer_state.roll(shifts=-1, dims=-1)
            rolled[:, :, cache_position] = gated.to(device=rolled.device, dtype=rolled.dtype)
            layer_state.copy_(rolled)
            conv_out = torch.sum(rolled.to(gated.device) * self.conv.weight[:, 0, :], dim=-1)
            if self.bias:
                conv_out += self.conv.bias
            conv_out = conv_out.unsqueeze(-1)
        else:
            if past_key_values is not None:
                # Negative padding crops from the left when the input is longer than L_cache.
                padded = nn.functional.pad(gated, (self.L_cache - gated.shape[-1], 0))
                past_key_values.conv_cache[self.layer_idx].copy_(padded)
            # The convolution pads both sides; keep only the first `seqlen` outputs.
            conv_out = self.conv(gated)[..., :seqlen]

        y = c_gate * conv_out
        y = y.transpose(-1, -2).contiguous()
        return self.out_proj(y)

    @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58")
    def forward(
        self,
        hidden_states: torch.Tensor,
        past_key_values: Optional[Lfm2HybridConvCache] = None,
        cache_position: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
    ):
        """Dispatch to the kernel-backed or the pure-PyTorch implementation.

        The kernel path is taken only when the kernels were imported, the input
        lives on a device whose type contains "cuda", and dynamo is not compiling.
        """
        use_kernels = (
            is_fast_path_available and "cuda" in hidden_states.device.type and not torch._dynamo.is_compiling()
        )
        implementation = self.cuda_kernels_forward if use_kernels else self.slow_forward
        return implementation(hidden_states, past_key_values, cache_position, attention_mask)
class Lfm2DecoderLayer(GradientCheckpointingLayer):
    """One decoder block: a sequence operator (attention or short conv) followed by an MLP."""

    def __init__(self, config: Lfm2Config, layer_idx: int):
        """Pick the operator from `config.layer_types[layer_idx]` and build the sub-modules.

        A `"full_attention"` entry creates `self.self_attn`; any other value creates `self.conv`.
        """
        super().__init__()
        self.is_attention_layer = config.layer_types[layer_idx] == "full_attention"

        if not self.is_attention_layer:
            self.conv = Lfm2ShortConv(config, layer_idx)
        else:
            self.self_attn = Lfm2Attention(config, layer_idx)
        self.feed_forward = Lfm2MLP(config)
        self.operator_norm = Lfm2RMSNorm(config.hidden_size, eps=config.norm_eps)
        self.ffn_norm = Lfm2RMSNorm(config.hidden_size, eps=config.norm_eps)

    @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58")
    def forward(
        self,
        hidden_states: torch.Tensor,
        position_embeddings: tuple[torch.Tensor, torch.Tensor],
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[Lfm2HybridConvCache] = None,
        cache_position: Optional[torch.LongTensor] = None,
        **kwargs,
    ) -> torch.Tensor:
        """Apply the operator and the MLP, each with a pre-norm and a residual connection.

        `position_embeddings`, `position_ids` and `**kwargs` are only forwarded on
        attention layers; conv layers receive the cache, `cache_position` and
        `attention_mask`.
        """
        residual = hidden_states
        normed = self.operator_norm(hidden_states)

        if self.is_attention_layer:
            # The attention weights returned alongside the output are discarded here.
            mixed, _ = self.self_attn(
                hidden_states=normed,
                position_embeddings=position_embeddings,
                attention_mask=attention_mask,
                position_ids=position_ids,
                past_key_values=past_key_values,
                cache_position=cache_position,
                **kwargs,
            )
        else:
            mixed = self.conv(
                hidden_states=normed,
                past_key_values=past_key_values,
                cache_position=cache_position,
                attention_mask=attention_mask,
            )

        hidden_states = mixed + residual
        hidden_states = hidden_states + self.feed_forward(self.ffn_norm(hidden_states))
        return hidden_states
@auto_docstring
class Lfm2PreTrainedModel(PreTrainedModel):
    # Configuration class and the attribute name under which heads store the base model.
    config: Lfm2Config
    base_model_prefix = "model"

    # Module classes listed here are declared as not splittable.
    _no_split_modules = ["Lfm2DecoderLayer"]
    _skip_keys_device_placement = ["past_key_values"]

    # Capability flags read by the base class.
    supports_gradient_checkpointing = True
    _supports_flash_attn = True
    _supports_sdpa = True
    _supports_flex_attn = True
    _supports_attention_backend = True
    _can_compile_fullgraph = False

    # Output name -> module class whose outputs are recorded under that name.
    _can_record_outputs = {
        "hidden_states": Lfm2DecoderLayer,
        "attentions": Lfm2Attention,
    }
@auto_docstring
class Lfm2Model(Lfm2PreTrainedModel):
    def __init__(self, config: Lfm2Config):
        super().__init__(config)
        self.padding_idx = config.pad_token_id
        self.vocab_size = config.vocab_size

        self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
        decoder_layers = [Lfm2DecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
        self.layers = nn.ModuleList(decoder_layers)
        # NOTE(review): `rotary_emb` is not used by `forward` in this class, which reads
        # `pos_emb` instead; it is kept so existing attribute access keeps working.
        self.rotary_emb = Lfm2RotaryEmbedding(config=config)
        self.gradient_checkpointing = False
        self.pos_emb = Lfm2RotaryEmbedding(config)
        self.embedding_norm = Lfm2RMSNorm(config.hidden_size, eps=config.norm_eps)

        # Initialize weights and apply final processing
        self.post_init()

    @check_model_inputs()
    @auto_docstring
    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[Lfm2HybridConvCache] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        use_cache: Optional[bool] = None,
        cache_position: Optional[torch.LongTensor] = None,
        **kwargs: Unpack[TransformersKwargs],
    ) -> BaseModelOutputWithPast:
        # Exactly one of the two inputs must be provided: both or neither is an error.
        if (input_ids is None) == (inputs_embeds is None):
            raise ValueError("You must specify exactly one of input_ids or inputs_embeds")

        if inputs_embeds is None:
            inputs_embeds = self.embed_tokens(input_ids)

        if use_cache and past_key_values is None:
            # Create a cache sized for the current batch.
            past_key_values = Lfm2HybridConvCache(
                config=self.config,
                max_batch_size=inputs_embeds.shape[0],
                dtype=self.dtype,
                device=self.device,
            )

        if cache_position is None:
            # Continue counting from the number of tokens already cached.
            past_seen_tokens = 0 if past_key_values is None else past_key_values.get_seq_length()
            cache_position = torch.arange(
                past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device
            )

        if position_ids is None:
            position_ids = cache_position.unsqueeze(0)

        causal_mask = create_causal_mask(
            config=self.config,
            input_embeds=inputs_embeds,
            attention_mask=attention_mask,
            cache_position=cache_position,
            past_key_values=past_key_values,
            position_ids=position_ids,
        )

        hidden_states = inputs_embeds
        position_embeddings = self.pos_emb(hidden_states, position_ids)

        # decoder layers
        active_layers = self.layers[: self.config.num_hidden_layers]
        for decoder_layer in active_layers:
            hidden_states = decoder_layer(
                hidden_states,
                attention_mask=causal_mask,
                position_ids=position_ids,
                past_key_values=past_key_values,
                cache_position=cache_position,
                position_embeddings=position_embeddings,
                **kwargs,
            )

        # Final normalisation of the last layer's output.
        normed_states = self.embedding_norm(hidden_states)

        return BaseModelOutputWithPast(
            last_hidden_state=normed_states,
            past_key_values=past_key_values,
        )
@auto_docstring
class Lfm2ForCausalLM(Lfm2PreTrainedModel, GenerationMixin):
    _tied_weights_keys = ["lm_head.weight"]
    _tp_plan = {"lm_head": "colwise_rep"}
    _pp_plan = {"lm_head": (["hidden_states"], ["logits"])}

    def __init__(self, config):
        super().__init__(config)
        self.model = Lfm2Model(config)
        self.vocab_size = config.vocab_size
        # Bias-free projection from hidden states to vocabulary logits.
        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)

        # Initialize weights and apply final processing
        self.post_init()

    # NOTE(review): the checkpoint id in the example below looks like a placeholder
    # left over from code generation -- confirm it names a published Lfm2 checkpoint.
    @can_return_tuple
    @auto_docstring
    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[Cache] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        cache_position: Optional[torch.LongTensor] = None,
        logits_to_keep: Union[int, torch.Tensor] = 0,
        **kwargs: Unpack[TransformersKwargs],
    ) -> CausalLMOutputWithPast:
        r"""
        Example:

        ```python
        >>> from transformers import AutoTokenizer, Lfm2ForCausalLM

        >>> model = Lfm2ForCausalLM.from_pretrained("meta-lfm2/Lfm2-2-7b-hf")
        >>> tokenizer = AutoTokenizer.from_pretrained("meta-lfm2/Lfm2-2-7b-hf")

        >>> prompt = "Hey, are you conscious? Can you talk to me?"
        >>> inputs = tokenizer(prompt, return_tensors="pt")

        >>> # Generate
        >>> generate_ids = model.generate(inputs.input_ids, max_length=30)
        >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
        "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you."
        ```"""
        outputs: BaseModelOutputWithPast = self.model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            cache_position=cache_position,
            **kwargs,
        )

        hidden_states = outputs.last_hidden_state
        # Only compute necessary logits, and do not upcast them to float if we are not computing the loss
        if isinstance(logits_to_keep, int):
            # An int keeps the last `logits_to_keep` positions; 0 gives slice(0, None), i.e. all of them.
            slice_indices = slice(-logits_to_keep, None)
        else:
            slice_indices = logits_to_keep
        logits = self.lm_head(hidden_states[:, slice_indices, :])

        if labels is None:
            loss = None
        else:
            loss = self.loss_function(logits=logits, labels=labels, vocab_size=self.config.vocab_size, **kwargs)

        return CausalLMOutputWithPast(
            loss=loss,
            logits=logits,
            past_key_values=outputs.past_key_values,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )
# Names exported by `from <module> import *`.
__all__ = [
    "Lfm2ForCausalLM",
    "Lfm2Model",
    "Lfm2PreTrainedModel",
]
|