modular_olmo2.py 14 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321
  1. from typing import Callable, Optional
  2. import torch
  3. import torch.nn as nn
  4. from transformers.utils.generic import TransformersKwargs
  5. from ...cache_utils import Cache
  6. from ...modeling_utils import ALL_ATTENTION_FUNCTIONS
  7. from ...processing_utils import Unpack
  8. from ...utils import logging
  9. from ...utils.deprecation import deprecate_kwarg
  10. from ..llama.modeling_llama import LlamaPreTrainedModel, LlamaRMSNorm, eager_attention_forward
  11. from ..olmo.configuration_olmo import OlmoConfig
  12. from ..olmo.modeling_olmo import (
  13. OlmoAttention,
  14. OlmoDecoderLayer,
  15. OlmoForCausalLM,
  16. OlmoModel,
  17. OlmoRotaryEmbedding,
  18. apply_rotary_pos_emb,
  19. )
# Module-level logger obtained from the library's own logging utilities (`...utils.logging`).
logger = logging.get_logger(__name__)
  21. class Olmo2Config(OlmoConfig):
  22. r"""
  23. This is the configuration class to store the configuration of a [`Olmo2Model`]. It is used to instantiate an OLMo2
  24. model according to the specified arguments, defining the model architecture. Instantiating a configuration with the
  25. defaults will yield a similar configuration to that of the [allenai/Olmo2-7B-1124-hf](https://huggingface.co/allenai/Olmo2-7B-1124-hf).
  26. Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
  27. documentation from [`PretrainedConfig`] for more information.
  28. Args:
  29. vocab_size (`int`, *optional*, defaults to 50304):
  30. Vocabulary size of the Olmo2 model. Defines the number of different tokens that can be represented by the
  31. `inputs_ids` passed when calling [`Olmo2Model`]
  32. hidden_size (`int`, *optional*, defaults to 4096):
  33. Dimension of the hidden representations.
  34. intermediate_size (`int`, *optional*, defaults to 11008):
  35. Dimension of the MLP representations.
  36. num_hidden_layers (`int`, *optional*, defaults to 32):
  37. Number of hidden layers in the Transformer decoder.
  38. num_attention_heads (`int`, *optional*, defaults to 32):
  39. Number of attention heads for each attention layer in the Transformer decoder.
  40. num_key_value_heads (`int`, *optional*):
  41. This is the number of key_value heads that should be used to implement Grouped Query Attention. If
  42. `num_key_value_heads=num_attention_heads`, the model will use Multi Head Attention (MHA), if
  43. `num_key_value_heads=1` the model will use Multi Query Attention (MQA) otherwise GQA is used. When
  44. converting a multi-head checkpoint to a GQA checkpoint, each group key and value head should be constructed
  45. by meanpooling all the original heads within that group. For more details, check out [this
  46. paper](https://huggingface.co/papers/2305.13245). If it is not specified, will default to
  47. `num_attention_heads`.
  48. hidden_act (`str` or `function`, *optional*, defaults to `"silu"`):
  49. The non-linear activation function (function or string) in the decoder.
  50. max_position_embeddings (`int`, *optional*, defaults to 2048):
  51. The maximum sequence length that this model might ever be used with.
  52. initializer_range (`float`, *optional*, defaults to 0.02):
  53. The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
  54. use_cache (`bool`, *optional*, defaults to `True`):
  55. Whether or not the model should return the last key/values attentions (not used by all models). Only
  56. relevant if `config.is_decoder=True`.
  57. pad_token_id (`int`, *optional*, defaults to 1):
  58. Padding token id.
  59. bos_token_id (`int`, *optional*):
  60. Beginning of stream token id.
  61. eos_token_id (`int`, *optional*, defaults to 50279):
  62. End of stream token id.
  63. tie_word_embeddings (`bool`, *optional*, defaults to `False`):
  64. Whether to tie weight embeddings
  65. rope_theta (`float`, *optional*, defaults to 10000.0):
  66. The base period of the RoPE embeddings.
  67. rope_scaling (`Dict`, *optional*):
  68. Dictionary containing the scaling configuration for the RoPE embeddings. Currently supports two scaling
  69. strategies: linear and dynamic. Their scaling factor must be a float greater than 1. The expected format is
  70. `{"type": strategy name, "factor": scaling factor}`. When using this flag, don't update
  71. `max_position_embeddings` to the expected new maximum. See the following thread for more information on how
  72. these scaling strategies behave:
  73. https://www.reddit.com/r/LocalLLaMA/comments/14mrgpr/dynamically_scaled_rope_further_increases/. This is an
  74. experimental feature, subject to breaking API changes in future versions.
  75. attention_bias (`bool`, defaults to `False`, *optional*, defaults to `False`):
  76. Whether to use a bias in the query, key, value and output projection layers during self-attention.
  77. attention_dropout (`float`, *optional*, defaults to 0.0):
  78. The dropout ratio for the attention probabilities.
  79. rms_norm_eps (`float`, *optional*, defaults to 1e-05):
  80. The epsilon used by the rms normalization layers.
  81. ```python
  82. >>> from transformers import Olmo2Model, Olmo2Config
  83. >>> # Initializing a Olmo2 7B style configuration
  84. >>> configuration = Olmo2Config()
  85. >>> # Initializing a model from the Olmo2 7B style configuration
  86. >>> model = Olmo2Model(configuration)
  87. >>> # Accessing the model configuration
  88. >>> configuration = model.config
  89. ```
  90. """
  91. model_type = "olmo2"
  92. base_model_tp_plan = {
  93. "layers.*.self_attn.q_proj": "colwise_rep", # we need to replicate here due to the added norm on q and k
  94. "layers.*.self_attn.k_proj": "colwise_rep", # we need to replicate here due to the added norm on q and k
  95. "layers.*.self_attn.v_proj": "colwise_rep", # we need to replicate here due to the added norm on q and k
  96. "layers.*.self_attn.o_proj": "rowwise_rep", # we need to replicate here due to the added norm on q and k
  97. "layers.*.mlp.gate_proj": "colwise",
  98. "layers.*.mlp.up_proj": "colwise",
  99. "layers.*.mlp.down_proj": "rowwise",
  100. }
  101. base_model_pp_plan = {
  102. "embed_tokens": (["input_ids"], ["inputs_embeds"]),
  103. "layers": (["hidden_states", "attention_mask"], ["hidden_states"]),
  104. "norm": (["hidden_states"], ["hidden_states"]),
  105. }
  106. def __init__(
  107. self,
  108. vocab_size=50304,
  109. hidden_size=4096,
  110. intermediate_size=11008,
  111. num_hidden_layers=32,
  112. num_attention_heads=32,
  113. num_key_value_heads=None,
  114. hidden_act="silu",
  115. max_position_embeddings=2048,
  116. initializer_range=0.02,
  117. use_cache=True,
  118. pad_token_id=1,
  119. bos_token_id=None,
  120. eos_token_id=50279,
  121. tie_word_embeddings=False,
  122. rope_theta=10000.0,
  123. rope_scaling=None,
  124. attention_bias=False,
  125. attention_dropout=0.0,
  126. rms_norm_eps=1e-5,
  127. **kwargs,
  128. ):
  129. super().__init__(
  130. vocab_size=vocab_size,
  131. hidden_size=hidden_size,
  132. intermediate_size=intermediate_size,
  133. num_hidden_layers=num_hidden_layers,
  134. num_attention_heads=num_attention_heads,
  135. num_key_value_heads=num_key_value_heads,
  136. hidden_act=hidden_act,
  137. max_position_embeddings=max_position_embeddings,
  138. initializer_range=initializer_range,
  139. use_cache=use_cache,
  140. pad_token_id=pad_token_id,
  141. bos_token_id=bos_token_id,
  142. eos_token_id=eos_token_id,
  143. tie_word_embeddings=tie_word_embeddings,
  144. rope_theta=rope_theta,
  145. rope_scaling=rope_scaling,
  146. attention_bias=attention_bias,
  147. attention_dropout=attention_dropout,
  148. **kwargs,
  149. )
  150. self.rms_norm_eps = rms_norm_eps
  151. del self.clip_qkv
  152. # OLMo2 RMS norm is identical to Llama RMS norm except:
  153. # - Weight and hidden states are multiplied before converting back to the input dtype, rather than after.
  154. class Olmo2RMSNorm(LlamaRMSNorm):
  155. def forward(self, hidden_states):
  156. input_dtype = hidden_states.dtype
  157. hidden_states = hidden_states.to(torch.float32)
  158. variance = hidden_states.pow(2).mean(-1, keepdim=True)
  159. hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
  160. return (self.weight * hidden_states).to(input_dtype)
  161. def rotate_half(x):
  162. """Rotates half the hidden dims of the input."""
  163. x1 = x[..., : x.shape[-1] // 2]
  164. x2 = x[..., x.shape[-1] // 2 :]
  165. return torch.cat((-x2, x1), dim=-1)
  166. # Olmo2 attention is identical to OLMo attention except:
  167. # - Norm is applied to attention queries and keys.
  168. # - No qkv clipping.
  169. class Olmo2Attention(OlmoAttention):
  170. def __init__(self, config: Olmo2Config, layer_idx: Optional[int] = None):
  171. super().__init__(config, layer_idx=layer_idx)
  172. self.q_norm = Olmo2RMSNorm(config.num_attention_heads * self.head_dim, config.rms_norm_eps)
  173. self.k_norm = Olmo2RMSNorm(config.num_key_value_heads * self.head_dim, config.rms_norm_eps)
  174. @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58")
  175. def forward(
  176. self,
  177. hidden_states: torch.Tensor,
  178. position_embeddings: tuple[torch.Tensor, torch.Tensor],
  179. attention_mask: Optional[torch.Tensor],
  180. past_key_values: Optional[Cache] = None,
  181. cache_position: Optional[torch.LongTensor] = None,
  182. **kwargs: Unpack[TransformersKwargs],
  183. ) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
  184. input_shape = hidden_states.shape[:-1]
  185. hidden_shape = (*input_shape, -1, self.head_dim)
  186. query_states = self.q_norm(self.q_proj(hidden_states))
  187. key_states = self.k_norm(self.k_proj(hidden_states))
  188. value_states = self.v_proj(hidden_states)
  189. query_states = query_states.view(hidden_shape).transpose(1, 2)
  190. key_states = key_states.view(hidden_shape).transpose(1, 2)
  191. value_states = value_states.view(hidden_shape).transpose(1, 2)
  192. cos, sin = position_embeddings
  193. query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
  194. if past_key_values is not None:
  195. # sin and cos are specific to RoPE models; cache_position needed for the static cache
  196. cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
  197. key_states, value_states = past_key_values.update(key_states, value_states, self.layer_idx, cache_kwargs)
  198. attention_interface: Callable = eager_attention_forward
  199. if self.config._attn_implementation != "eager":
  200. attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]
  201. attn_output, attn_weights = attention_interface(
  202. self,
  203. query_states,
  204. key_states,
  205. value_states,
  206. attention_mask,
  207. dropout=0.0 if not self.training else self.attention_dropout,
  208. scaling=self.scaling,
  209. **kwargs,
  210. )
  211. attn_output = attn_output.reshape(*input_shape, -1).contiguous()
  212. attn_output = self.o_proj(attn_output)
  213. return attn_output, attn_weights
  214. # The OLMo2 layers are identical to those of the OLMo model except:
  215. # - RMSNorm is used instead of standard layer norm.
  216. # - Norm is applied after attention/feedforward rather than before.
  217. class Olmo2DecoderLayer(OlmoDecoderLayer):
  218. def __init__(self, config: Olmo2Config, layer_idx: int):
  219. super().__init__(config, layer_idx=layer_idx)
  220. self.post_attention_layernorm = Olmo2RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
  221. self.post_feedforward_layernorm = Olmo2RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
  222. self.self_attn = Olmo2Attention(config=config, layer_idx=layer_idx)
  223. del self.input_layernorm
  224. @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58")
  225. def forward(
  226. self,
  227. hidden_states: torch.Tensor,
  228. attention_mask: Optional[torch.Tensor] = None,
  229. position_ids: Optional[torch.LongTensor] = None,
  230. past_key_values: Optional[Cache] = None,
  231. use_cache: Optional[bool] = False,
  232. cache_position: Optional[torch.LongTensor] = None,
  233. position_embeddings: Optional[tuple[torch.Tensor, torch.Tensor]] = None, # necessary, but kept here for BC
  234. **kwargs: Unpack[TransformersKwargs],
  235. ) -> torch.Tensor:
  236. residual = hidden_states
  237. hidden_states, _ = self.self_attn(
  238. hidden_states=hidden_states,
  239. attention_mask=attention_mask,
  240. position_ids=position_ids,
  241. past_key_values=past_key_values,
  242. use_cache=use_cache,
  243. cache_position=cache_position,
  244. position_embeddings=position_embeddings,
  245. **kwargs,
  246. )
  247. hidden_states = self.post_attention_layernorm(hidden_states)
  248. hidden_states = residual + hidden_states
  249. # Fully Connected
  250. residual = hidden_states
  251. hidden_states = self.mlp(hidden_states)
  252. hidden_states = self.post_feedforward_layernorm(hidden_states)
  253. hidden_states = residual + hidden_states
  254. return hidden_states
  255. class Olmo2RotaryEmbedding(OlmoRotaryEmbedding):
  256. pass
  257. class Olmo2PreTrainedModel(LlamaPreTrainedModel):
  258. pass
  259. # The OLMo2 model is identical to the OLMo model, except RMSNorm is used instead of
  260. # standard layer norm for the output norm.
  261. class Olmo2Model(OlmoModel):
  262. def __init__(self, config: Olmo2Config):
  263. super().__init__(config)
  264. self.norm = Olmo2RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
  265. self.layers = nn.ModuleList(
  266. [Olmo2DecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
  267. )
  268. # The heads now only need to redefine the model inside to the correct `RobertaModel`
  269. class Olmo2ForCausalLM(OlmoForCausalLM):
  270. pass
  271. __all__ = [
  272. "Olmo2Config",
  273. "Olmo2ForCausalLM",
  274. "Olmo2Model",
  275. "Olmo2PreTrainedModel",
  276. ]