| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173 |
- from typing import Callable, Optional
- import torch
- import torch.nn as nn
- import torch.nn.functional as F
- from ...cache_utils import Cache
- from ...modeling_utils import ALL_ATTENTION_FUNCTIONS
- from ...utils import logging
- from ...utils.deprecation import deprecate_kwarg
- from ..llama.modeling_llama import (
- LlamaAttention,
- LlamaDecoderLayer,
- LlamaForCausalLM,
- LlamaMLP,
- LlamaModel,
- LlamaRotaryEmbedding,
- eager_attention_forward,
- rotate_half,
- )
- from .configuration_olmo import OlmoConfig
# Module-level logger obtained from the package's logging utilities.
logger = logging.get_logger(__name__)
class OlmoLayerNorm(nn.Module):
    """LayerNorm but with no learnable weight or bias.

    Normalizes over the last dimension in float32 and casts the result back to the
    input dtype, so low-precision inputs are normalized with float32 statistics.

    Args:
        hidden_size: Size of the last (normalized) dimension.
        eps: Value added to the variance for numerical stability. Defaults to ``1e-5``,
            so existing ``OlmoLayerNorm(hidden_size)`` callers behave exactly as before.
    """

    def __init__(self, hidden_size: int, eps: float = 1e-5) -> None:
        super().__init__()
        self.normalized_shape = (hidden_size,)
        # Plain attribute (not a parameter/buffer): the module stays parameter-free.
        self.eps = eps

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """Return ``hidden_states`` normalized over its last dimension, in its original dtype."""
        orig_dtype = hidden_states.dtype
        # No weight/bias (the two ``None`` arguments); compute in float32, then cast back.
        normalized = F.layer_norm(
            hidden_states.to(dtype=torch.float32), self.normalized_shape, None, None, eps=self.eps
        )
        return normalized.to(orig_dtype)
class OlmoMLP(LlamaMLP):
    """Llama-style MLP whose gate, up and down projections are all bias-free."""

    def __init__(self, config):
        super().__init__(config)
        # Replace the three projections with bias-free linear layers.
        self.gate_proj = nn.Linear(in_features=self.hidden_size, out_features=self.intermediate_size, bias=False)
        self.up_proj = nn.Linear(in_features=self.hidden_size, out_features=self.intermediate_size, bias=False)
        self.down_proj = nn.Linear(in_features=self.intermediate_size, out_features=self.hidden_size, bias=False)
def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
    """Rotate the query and key tensors with rotary position embeddings.

    The rotated tensors are cast back to the original dtypes of ``q`` and ``k``. This
    matters when ``cos``/``sin`` are float32 (as returned by ``OlmoRotaryEmbedding`` in
    this file). Without the cast, the products would be promoted to float32.

    Args:
        q (`torch.Tensor`): The query tensor.
        k (`torch.Tensor`): The key tensor.
        cos (`torch.Tensor`): The cosine part of the rotary embedding.
        sin (`torch.Tensor`): The sine part of the rotary embedding.
        position_ids (`torch.Tensor`, *optional*):
            Deprecated and unused; kept only for signature compatibility.
        unsqueeze_dim (`int`, *optional*, defaults to 1):
            Dimension inserted into ``cos`` and ``sin`` so they broadcast against ``q`` and ``k``.
            For example, with ``cos``/``sin`` of shape [batch_size, seq_len, head_dim]:
            use 1 when ``q``/``k`` are [batch_size, heads, seq_len, head_dim], and use 2
            when they are [batch_size, seq_len, heads, head_dim].

    Returns:
        `tuple(torch.Tensor)`: the rotated query and key tensors, in the dtypes of ``q`` and ``k``.
    """
    cos_b = cos.unsqueeze(unsqueeze_dim)
    sin_b = sin.unsqueeze(unsqueeze_dim)

    def _rotate(t):
        # x * cos + rotate_half(x) * sin, with cos/sin broadcast to t's layout.
        return t * cos_b + rotate_half(t) * sin_b

    return _rotate(q).to(q.dtype), _rotate(k).to(k.dtype)
class OlmoAttention(LlamaAttention):
    """Llama-style attention with OLMo's optional clamping of the q/k/v projections.

    When ``config.clip_qkv`` is not ``None``, the raw query/key/value projections are
    clamped in place to ``[-clip_qkv, clip_qkv]`` before being split into heads. Rotary
    embeddings are applied with this file's ``apply_rotary_pos_emb``, which casts the
    rotated tensors back to their input dtypes.
    """

    @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58")
    def forward(
        self,
        hidden_states: torch.Tensor,
        position_embeddings: tuple[torch.Tensor, torch.Tensor],
        attention_mask: Optional[torch.Tensor],
        past_key_values: Optional[Cache] = None,
        cache_position: Optional[torch.LongTensor] = None,
        **kwargs,
    ) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
        """Run self-attention over ``hidden_states``.

        Args:
            hidden_states: Input activations; the last dimension is projected by q/k/v.
            position_embeddings: ``(cos, sin)`` pair used for the rotary embedding.
            attention_mask: Forwarded unchanged to the attention implementation.
            past_key_values: Optional cache. When given, it is updated with this layer's
                keys/values, and the states it returns are used for attention.
            cache_position: Passed to the cache update.
            **kwargs: Forwarded to the attention implementation.

        Returns:
            ``(attn_output, attn_weights)``. ``attn_output`` keeps the leading dimensions
            of ``hidden_states`` and has been passed through ``o_proj``. ``attn_weights``
            is whatever the selected attention implementation returns.
        """
        leading_dims = hidden_states.shape[:-1]
        # -1 lets view() infer the number of heads from each projection's width.
        split_heads_shape = (*leading_dims, -1, self.head_dim)

        projected = [proj(hidden_states) for proj in (self.q_proj, self.k_proj, self.v_proj)]

        clip = self.config.clip_qkv
        if clip is not None:
            # In-place clamp of the freshly computed projections.
            for tensor in projected:
                tensor.clamp_(min=-clip, max=clip)

        # Assuming hidden_states is (batch, seq, hidden): (batch, seq, heads, head_dim) -> (batch, heads, seq, head_dim).
        query_states, key_states, value_states = (
            tensor.view(split_heads_shape).transpose(1, 2) for tensor in projected
        )

        cos, sin = position_embeddings
        query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)

        if past_key_values is not None:
            # sin and cos are specific to RoPE models; cache_position needed for the static cache
            key_states, value_states = past_key_values.update(
                key_states,
                value_states,
                self.layer_idx,
                {"sin": sin, "cos": cos, "cache_position": cache_position},
            )

        impl_name = self.config._attn_implementation
        attention_interface: Callable = (
            eager_attention_forward if impl_name == "eager" else ALL_ATTENTION_FUNCTIONS[impl_name]
        )

        attn_output, attn_weights = attention_interface(
            self,
            query_states,
            key_states,
            value_states,
            attention_mask,
            dropout=self.attention_dropout if self.training else 0.0,
            scaling=self.scaling,
            **kwargs,
        )

        merged_heads = attn_output.reshape(*leading_dims, -1).contiguous()
        return self.o_proj(merged_heads), attn_weights
class OlmoDecoderLayer(LlamaDecoderLayer):
    """Llama decoder layer that uses OLMo's parameter-free norms and OLMo attention."""

    def __init__(self, config: OlmoConfig, layer_idx: int):
        super().__init__(config, layer_idx)
        # After the parent constructor runs, assign the OLMo-specific norm and attention modules.
        self.input_layernorm = OlmoLayerNorm(hidden_size=config.hidden_size)
        self.post_attention_layernorm = OlmoLayerNorm(hidden_size=config.hidden_size)
        self.self_attn = OlmoAttention(config=config, layer_idx=layer_idx)
# Same computation as LlamaRotaryEmbedding, except that cos and sin are returned in
# float32 rather than in the input's dtype.
class OlmoRotaryEmbedding(LlamaRotaryEmbedding):
    """Rotary embedding whose ``cos``/``sin`` outputs are left in float32."""

    def forward(self, x, position_ids):
        """Compute the rotary ``(cos, sin)`` tables for ``position_ids``.

        Args:
            x: Only its device is used. The frequencies are moved there, and it selects
                the autocast device type. Its dtype does not affect the outputs.
            position_ids: Positions indexed as ``(batch, seq)``; the indexing below assumes 2-D.

        Returns:
            ``(cos, sin)``, each of shape ``(batch, seq, 2 * inv_freq.numel())``. Both are
            computed with autocast disabled and scaled by ``self.attention_scaling``.
        """
        # NOTE(review): this override does not re-apply any decorators the parent's forward
        # may carry (e.g. no_grad / dynamic-RoPE updates) -- confirm that is intended.
        freq_table = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1).to(x.device)
        positions = position_ids[:, None, :].float()

        autocast_device = x.device.type
        # Fall back to "cpu" for MPS (and, defensively, for a non-string device type).
        if not isinstance(autocast_device, str) or autocast_device == "mps":
            autocast_device = "cpu"

        with torch.autocast(device_type=autocast_device, enabled=False):  # Force float32
            angles = torch.matmul(freq_table.float(), positions.float()).transpose(1, 2)
            emb = torch.cat((angles, angles), dim=-1)
            scale = self.attention_scaling
            return emb.cos() * scale, emb.sin() * scale
class OlmoModel(LlamaModel):
    """Llama backbone built from OLMo decoder layers, with a parameter-free final norm."""

    def __init__(self, config: OlmoConfig):
        super().__init__(config)
        # One OLMo decoder layer per hidden layer, indexed from 0.
        self.layers = nn.ModuleList(
            OlmoDecoderLayer(config, idx) for idx in range(config.num_hidden_layers)
        )
        self.norm = OlmoLayerNorm(hidden_size=config.hidden_size)
class OlmoForCausalLM(LlamaForCausalLM):
    """OLMo causal language model class; all behavior is inherited unchanged from LlamaForCausalLM."""
# Public names exported by this module. "OlmoPreTrainedModel" is not defined in this
# module, which is why the undefined-name-in-__all__ lint (F822) is suppressed for it.
__all__ = [
    "OlmoForCausalLM",
    "OlmoModel",
    "OlmoPreTrainedModel",  # noqa: F822
]
|