# modular_olmo.py
  from typing import Callable, Optional

  import torch
  import torch.nn as nn
  import torch.nn.functional as F

  from ...cache_utils import Cache
  from ...modeling_utils import ALL_ATTENTION_FUNCTIONS
  from ...utils import logging
  from ...utils.deprecation import deprecate_kwarg
  from ..llama.modeling_llama import (
      LlamaAttention,
      LlamaDecoderLayer,
      LlamaForCausalLM,
      LlamaMLP,
      LlamaModel,
      LlamaRotaryEmbedding,
      eager_attention_forward,
      rotate_half,
  )
  from .configuration_olmo import OlmoConfig


  # Module-level logger, named after this module.
  logger = logging.get_logger(__name__)
  class OlmoLayerNorm(nn.Module):
      """LayerNorm but with no learnable weight or bias.

      The normalization is computed in float32 and the result is cast back to the
      dtype of the input, so the module can be used with lower-precision inputs.

      Args:
          hidden_size (`int`):
              Size of the last dimension, over which the normalization is applied.
          eps (`float`, *optional*, defaults to 1e-5):
              Value added to the variance for numerical stability.
      """

      def __init__(self, hidden_size: int, eps: float = 1e-5) -> None:
          super().__init__()
          self.normalized_shape = (hidden_size,)
          self.eps = eps

      def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
          """Normalize `hidden_states` over its last dimension and return it in its original dtype."""
          orig_dtype = hidden_states.dtype
          # Upcast before normalizing; weight and bias are passed as `None` (no affine transform).
          normed = F.layer_norm(
              hidden_states.to(dtype=torch.float32),
              self.normalized_shape,
              None,
              None,
              eps=self.eps,
          )
          return normed.to(orig_dtype)
  class OlmoMLP(LlamaMLP):
      """MLP block derived from `LlamaMLP` whose three projections are always bias-free."""

      def __init__(self, config):
          super().__init__(config)
          # Assign the projections without bias, replacing anything the parent constructor
          # stored under the same attribute names. `hidden_size` and `intermediate_size`
          # are read from `self` and are expected to be set by `LlamaMLP.__init__`.
          self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
          self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
          self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False)
  def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
      """Applies Rotary Position Embedding to the query and key tensors.

      Args:
          q (`torch.Tensor`): The query tensor.
          k (`torch.Tensor`): The key tensor.
          cos (`torch.Tensor`): The cosine part of the rotary embedding.
          sin (`torch.Tensor`): The sine part of the rotary embedding.
          position_ids (`torch.Tensor`, *optional*):
              Deprecated and unused.
          unsqueeze_dim (`int`, *optional*, defaults to 1):
              The dimension along which `cos` and `sin` are unsqueezed so that they can be broadcast against
              `q` and `k`. For example, if `cos` and `sin` have the shape [batch_size, seq_len, head_dim] and
              `q` and `k` have the shape [batch_size, heads, seq_len, head_dim], then `unsqueeze_dim=1` makes
              them broadcastable. Similarly, if `q` and `k` have the shape
              [batch_size, seq_len, heads, head_dim], then set `unsqueeze_dim=2`.

      Returns:
          `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position
          Embedding, each cast back to the dtype of the corresponding input.
      """
      # Remember the input dtypes: multiplying by `cos`/`sin` may promote the result
      # (e.g. when they are float32), so the outputs are cast back below.
      q_type, k_type = q.dtype, k.dtype
      cos = cos.unsqueeze(unsqueeze_dim)
      sin = sin.unsqueeze(unsqueeze_dim)
      q_embed = (q * cos) + (rotate_half(q) * sin)
      k_embed = (k * cos) + (rotate_half(k) * sin)
      return q_embed.to(q_type), k_embed.to(k_type)
  class OlmoAttention(LlamaAttention):
      """Attention layer derived from `LlamaAttention` with optional clamping of the Q/K/V projections.

      When `config.clip_qkv` is not `None`, the projected query, key and value tensors are clamped
      in place to the range `[-clip_qkv, clip_qkv]` before being split into heads.
      """

      @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58")
      def forward(
          self,
          hidden_states: torch.Tensor,
          position_embeddings: tuple[torch.Tensor, torch.Tensor],
          attention_mask: Optional[torch.Tensor],
          past_key_values: Optional[Cache] = None,
          cache_position: Optional[torch.LongTensor] = None,
          **kwargs,
      ) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
          """Run the attention layer.

          Args:
              hidden_states (`torch.Tensor`): Input to the q/k/v projections; its last dimension is projected.
              position_embeddings (`tuple[torch.Tensor, torch.Tensor]`): The `(cos, sin)` rotary tensors.
              attention_mask (`torch.Tensor`, *optional*): Mask forwarded to the attention implementation.
              past_key_values (`Cache`, *optional*): If given, updated with this layer's key/value states.
              cache_position (`torch.LongTensor`, *optional*): Forwarded to the cache update.
              **kwargs: Forwarded unchanged to the attention implementation.

          Returns:
              The output of `o_proj` and the attention weights returned by the attention implementation.
          """
          input_shape = hidden_states.shape[:-1]
          # -1 lets `view` infer the number of heads from the projection width.
          hidden_shape = (*input_shape, -1, self.head_dim)

          query_states = self.q_proj(hidden_states)
          key_states = self.k_proj(hidden_states)
          value_states = self.v_proj(hidden_states)

          if self.config.clip_qkv is not None:
              query_states.clamp_(min=-self.config.clip_qkv, max=self.config.clip_qkv)
              key_states.clamp_(min=-self.config.clip_qkv, max=self.config.clip_qkv)
              value_states.clamp_(min=-self.config.clip_qkv, max=self.config.clip_qkv)

          # Split the last dimension into heads, then swap dims 1 and 2.
          query_states = query_states.view(hidden_shape).transpose(1, 2)
          key_states = key_states.view(hidden_shape).transpose(1, 2)
          value_states = value_states.view(hidden_shape).transpose(1, 2)

          cos, sin = position_embeddings
          query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)

          if past_key_values is not None:
              # sin and cos are specific to RoPE models; cache_position needed for the static cache
              cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
              key_states, value_states = past_key_values.update(key_states, value_states, self.layer_idx, cache_kwargs)

          # Default to the eager implementation; otherwise look up the configured one.
          attention_interface: Callable = eager_attention_forward
          if self.config._attn_implementation != "eager":
              attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]

          attn_output, attn_weights = attention_interface(
              self,
              query_states,
              key_states,
              value_states,
              attention_mask,
              dropout=0.0 if not self.training else self.attention_dropout,
              scaling=self.scaling,
              **kwargs,
          )

          # Merge the trailing dimensions back into one before the output projection.
          attn_output = attn_output.reshape(*input_shape, -1).contiguous()
          attn_output = self.o_proj(attn_output)
          return attn_output, attn_weights
  class OlmoDecoderLayer(LlamaDecoderLayer):
      """Decoder layer derived from `LlamaDecoderLayer` using Olmo's layer norm and attention.

      After the parent constructor runs, the `input_layernorm`, `post_attention_layernorm`
      and `self_attn` attributes are set to `OlmoLayerNorm` / `OlmoAttention` instances.
      """

      def __init__(self, config: OlmoConfig, layer_idx: int):
          super().__init__(config, layer_idx)
          self.input_layernorm = OlmoLayerNorm(config.hidden_size)
          self.post_attention_layernorm = OlmoLayerNorm(config.hidden_size)
          self.self_attn = OlmoAttention(config=config, layer_idx=layer_idx)
  # This is identical to LlamaRotaryEmbedding except the output cos and sin are returned
  # as float32 rather than the input type.
  class OlmoRotaryEmbedding(LlamaRotaryEmbedding):
      def forward(self, x, position_ids):
          """Compute the rotary `cos` and `sin` tensors for `position_ids`.

          Args:
              x (`torch.Tensor`): Only its device is used here.
              position_ids (`torch.Tensor`): Indexed as a 2-D tensor; its first dimension sizes the batch.

          Returns:
              `tuple(torch.Tensor)`: `(cos, sin)`, each scaled by `self.attention_scaling` and computed
              in float32 (they are not cast to the dtype of `x`).
          """
          # (len(inv_freq),) -> (batch, len(inv_freq), 1)
          inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1).to(x.device)
          # (batch, seq) -> (batch, 1, seq)
          position_ids_expanded = position_ids[:, None, :].float()

          # The "mps" device type is replaced by "cpu" for the autocast context.
          device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu"
          with torch.autocast(device_type=device_type, enabled=False):  # Force float32
              freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
              # Duplicate the frequencies along the last dimension.
              emb = torch.cat((freqs, freqs), dim=-1)
              cos = emb.cos() * self.attention_scaling
              sin = emb.sin() * self.attention_scaling
              return cos, sin
  class OlmoModel(LlamaModel):
      """Model derived from `LlamaModel` built from `OlmoDecoderLayer` blocks with an `OlmoLayerNorm` final norm."""

      def __init__(self, config: OlmoConfig):
          super().__init__(config)
          # One decoder layer per `config.num_hidden_layers`, each given its own index.
          self.layers = nn.ModuleList(
              [OlmoDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
          )
          self.norm = OlmoLayerNorm(config.hidden_size)
  class OlmoForCausalLM(LlamaForCausalLM):
      """Olmo causal language model; adds nothing to `LlamaForCausalLM` in this file."""

      pass
  # Public names exported by this module. `OlmoPreTrainedModel` is not defined in the
  # visible part of this file, hence the F822 (undefined name in `__all__`) suppression.
  __all__ = [
      "OlmoForCausalLM",
      "OlmoModel",
      "OlmoPreTrainedModel",  # noqa: F822
  ]