modular_olmo3.py 19 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427
  1. # coding=utf-8
  2. # Copyright 2025 the HuggingFace Team. All rights reserved.
  3. #
  4. # Licensed under the Apache License, Version 2.0 (the "License");
  5. # you may not use this file except in compliance with the License.
  6. # You may obtain a copy of the License at
  7. #
  8. # http://www.apache.org/licenses/LICENSE-2.0
  9. #
  10. # Unless required by applicable law or agreed to in writing, software
  11. # distributed under the License is distributed on an "AS IS" BASIS,
  12. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. # See the License for the specific language governing permissions and
  14. # limitations under the License.
  15. from typing import Callable, Optional
  16. import torch
  17. import torch.nn as nn
  18. from transformers.utils.generic import TransformersKwargs
  19. from ...cache_utils import Cache, DynamicCache
  20. from ...configuration_utils import layer_type_validation
  21. from ...masking_utils import create_causal_mask, create_sliding_window_causal_mask
  22. from ...modeling_outputs import BaseModelOutputWithPast
  23. from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, rope_config_validation
  24. from ...modeling_utils import ALL_ATTENTION_FUNCTIONS
  25. from ...processing_utils import Unpack
  26. from ..olmo2.configuration_olmo2 import Olmo2Config
  27. from ..olmo2.modeling_olmo2 import (
  28. Olmo2Attention,
  29. Olmo2DecoderLayer,
  30. Olmo2ForCausalLM,
  31. Olmo2Model,
  32. Olmo2PreTrainedModel,
  33. Olmo2RMSNorm,
  34. Olmo2RotaryEmbedding,
  35. apply_rotary_pos_emb,
  36. eager_attention_forward,
  37. )
  38. class Olmo3Config(Olmo2Config):
  39. r"""
  40. This is the configuration class to store the configuration of a [`Olmo3Model`]. It is used to instantiate an OLMo3
  41. model according to the specified arguments, defining the model architecture. Instantiating a configuration with the
  42. defaults will yield a similar configuration to that of the [allenai/OLMo-3-0725-1B](https://huggingface.co/allenai/OLMo-3-0725-1B).
  43. Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
  44. documentation from [`PretrainedConfig`] for more information.
  45. Args:
  46. vocab_size (`int`, *optional*, defaults to 50304):
  47. Vocabulary size of the Olmo3 model. Defines the number of different tokens that can be represented by the
  48. `inputs_ids` passed when calling [`Olmo3Model`]
  49. hidden_size (`int`, *optional*, defaults to 4096):
  50. Dimension of the hidden representations.
  51. intermediate_size (`int`, *optional*, defaults to 11008):
  52. Dimension of the MLP representations.
  53. num_hidden_layers (`int`, *optional*, defaults to 32):
  54. Number of hidden layers in the Transformer decoder.
  55. num_attention_heads (`int`, *optional*, defaults to 32):
  56. Number of attention heads for each attention layer in the Transformer decoder.
  57. num_key_value_heads (`int`, *optional*):
  58. This is the number of key_value heads that should be used to implement Grouped Query Attention. If
  59. `num_key_value_heads=num_attention_heads`, the model will use Multi Head Attention (MHA), if
  60. `num_key_value_heads=1` the model will use Multi Query Attention (MQA) otherwise GQA is used. When
  61. converting a multi-head checkpoint to a GQA checkpoint, each group key and value head should be constructed
  62. by meanpooling all the original heads within that group. For more details, check out [this
  63. paper](https://huggingface.co/papers/2305.13245). If it is not specified, will default to
  64. `num_attention_heads`.
  65. hidden_act (`str` or `function`, *optional*, defaults to `"silu"`):
  66. The non-linear activation function (function or string) in the decoder.
  67. max_position_embeddings (`int`, *optional*, defaults to 2048):
  68. The maximum sequence length that this model might ever be used with.
  69. initializer_range (`float`, *optional*, defaults to 0.02):
  70. The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
  71. use_cache (`bool`, *optional*, defaults to `True`):
  72. Whether or not the model should return the last key/values attentions (not used by all models). Only
  73. relevant if `config.is_decoder=True`.
  74. pad_token_id (`int`, *optional*, defaults to 1):
  75. Padding token id.
  76. bos_token_id (`int`, *optional*):
  77. Beginning of stream token id.
  78. eos_token_id (`int`, *optional*, defaults to 50279):
  79. End of stream token id.
  80. tie_word_embeddings (`bool`, *optional*, defaults to `False`):
  81. Whether to tie weight embeddings
  82. rope_theta (`float`, *optional*, defaults to 10000.0):
  83. The base period of the RoPE embeddings.
  84. rope_scaling (`Dict`, *optional*):
  85. Dictionary containing the scaling configuration for the RoPE embeddings. NOTE: if you apply new rope type
  86. and you expect the model to work on longer `max_position_embeddings`, we recommend you to update this value
  87. accordingly.
  88. Expected contents:
  89. `rope_type` (`str`):
  90. The sub-variant of RoPE to use. Can be one of ['default', 'linear', 'dynamic', 'yarn', 'longrope',
  91. 'llama3'], with 'default' being the original RoPE implementation.
  92. `factor` (`float`, *optional*):
  93. Used with all rope types except 'default'. The scaling factor to apply to the RoPE embeddings. In
  94. most scaling types, a `factor` of x will enable the model to handle sequences of length x *
  95. original maximum pre-trained length.
  96. `original_max_position_embeddings` (`int`, *optional*):
  97. Used with 'dynamic', 'longrope' and 'llama3'. The original max position embeddings used during
  98. pretraining.
  99. `attention_factor` (`float`, *optional*):
  100. Used with 'yarn' and 'longrope'. The scaling factor to be applied on the attention
  101. computation. If unspecified, it defaults to value recommended by the implementation, using the
  102. `factor` field to infer the suggested value.
  103. `beta_fast` (`float`, *optional*):
  104. Only used with 'yarn'. Parameter to set the boundary for extrapolation (only) in the linear
  105. ramp function. If unspecified, it defaults to 32.
  106. `beta_slow` (`float`, *optional*):
  107. Only used with 'yarn'. Parameter to set the boundary for interpolation (only) in the linear
  108. ramp function. If unspecified, it defaults to 1.
  109. `short_factor` (`list[float]`, *optional*):
  110. Only used with 'longrope'. The scaling factor to be applied to short contexts (<
  111. `original_max_position_embeddings`). Must be a list of numbers with the same length as the hidden
  112. size divided by the number of attention heads divided by 2
  113. `long_factor` (`list[float]`, *optional*):
  114. Only used with 'longrope'. The scaling factor to be applied to long contexts (<
  115. `original_max_position_embeddings`). Must be a list of numbers with the same length as the hidden
  116. size divided by the number of attention heads divided by 2
  117. `low_freq_factor` (`float`, *optional*):
  118. Only used with 'llama3'. Scaling factor applied to low frequency components of the RoPE
  119. `high_freq_factor` (`float`, *optional*):
  120. Only used with 'llama3'. Scaling factor applied to high frequency components of the RoPE
  121. attention_bias (`bool`, defaults to `False`, *optional*, defaults to `False`):
  122. Whether to use a bias in the query, key, value and output projection layers during self-attention.
  123. attention_dropout (`float`, *optional*, defaults to 0.0):
  124. The dropout ratio for the attention probabilities.
  125. rms_norm_eps (`float`, *optional*, defaults to 1e-05):
  126. The epsilon used by the rms normalization layers.
  127. sliding_window (`int`, *optional*, defaults to 4096):
  128. Size of the sliding window for sliding window attention.
  129. layer_types (`list`, *optional*):
  130. Attention pattern for each layer. Defaults to sliding window attention
  131. for 3 out of 4 layers, and full attention for every 4th layer.
  132. ```python
  133. >>> from transformers import Olmo3Model, Olmo3Config
  134. >>> # Initializing a Olmo3 7B style configuration
  135. >>> configuration = Olmo3Config()
  136. >>> # Initializing a model from the Olmo3 7B style configuration
  137. >>> model = Olmo3Model(configuration)
  138. >>> # Accessing the model configuration
  139. >>> configuration = model.config
  140. ```
  141. """
  142. model_type = "olmo3"
  143. base_model_tp_plan = {
  144. "layers.*.self_attn.q_proj": "colwise_rep", # we need to replicate here due to the added norm on q and k
  145. "layers.*.self_attn.k_proj": "colwise_rep", # we need to replicate here due to the added norm on q and k
  146. "layers.*.self_attn.v_proj": "colwise_rep", # we need to replicate here due to the added norm on q and k
  147. "layers.*.self_attn.o_proj": "rowwise_rep", # we need to replicate here due to the added norm on q and k
  148. "layers.*.mlp.gate_proj": "colwise",
  149. "layers.*.mlp.up_proj": "colwise",
  150. "layers.*.mlp.down_proj": "rowwise",
  151. }
  152. base_model_pp_plan = {
  153. "embed_tokens": (["input_ids"], ["inputs_embeds"]),
  154. "layers": (["hidden_states", "attention_mask"], ["hidden_states"]),
  155. "norm": (["hidden_states"], ["hidden_states"]),
  156. }
  157. def __init__(
  158. self,
  159. vocab_size=50304,
  160. hidden_size=4096,
  161. intermediate_size=11008,
  162. num_hidden_layers=32,
  163. num_attention_heads=32,
  164. num_key_value_heads=None,
  165. hidden_act="silu",
  166. max_position_embeddings=2048,
  167. initializer_range=0.02,
  168. use_cache=True,
  169. pad_token_id=1,
  170. bos_token_id=None,
  171. eos_token_id=50279,
  172. tie_word_embeddings=False,
  173. rope_theta=10000.0,
  174. rope_scaling=None,
  175. attention_bias=False,
  176. attention_dropout=0.0,
  177. rms_norm_eps=1e-5,
  178. sliding_window=4096,
  179. layer_types=None,
  180. **kwargs,
  181. ):
  182. super().__init__(
  183. vocab_size=vocab_size,
  184. hidden_size=hidden_size,
  185. intermediate_size=intermediate_size,
  186. num_hidden_layers=num_hidden_layers,
  187. num_attention_heads=num_attention_heads,
  188. num_key_value_heads=num_key_value_heads,
  189. hidden_act=hidden_act,
  190. max_position_embeddings=max_position_embeddings,
  191. initializer_range=initializer_range,
  192. use_cache=use_cache,
  193. pad_token_id=pad_token_id,
  194. bos_token_id=bos_token_id,
  195. eos_token_id=eos_token_id,
  196. tie_word_embeddings=tie_word_embeddings,
  197. rope_theta=rope_theta,
  198. rope_scaling=rope_scaling,
  199. attention_bias=attention_bias,
  200. attention_dropout=attention_dropout,
  201. rms_norm_eps=rms_norm_eps,
  202. **kwargs,
  203. )
  204. self.sliding_window = sliding_window
  205. self.layer_types = layer_types
  206. if self.layer_types is None:
  207. self.layer_types = [
  208. "sliding_attention" if (i + 1) % 4 != 0 else "full_attention" for i in range(self.num_hidden_layers)
  209. ]
  210. layer_type_validation(self.layer_types)
  211. def _rope_scaling_validation(self):
  212. """
  213. Validate the `rope_scaling` configuration.
  214. """
  215. rope_config_validation(self)
  216. class Olmo3RMSNorm(Olmo2RMSNorm):
  217. pass
# OLMo 3 attention is identical to OLMo 2 attention, except:
# - Layers marked "sliding_attention" in `config.layer_types` (by default 3 out of 4 layers) use sliding window attention.
  220. class Olmo3Attention(Olmo2Attention):
  221. def __init__(self, config: Olmo3Config, layer_idx: int):
  222. super().__init__(config, layer_idx=layer_idx)
  223. assert config.layer_types is not None
  224. self.attention_type = config.layer_types[layer_idx]
  225. self.sliding_window = config.sliding_window if self.attention_type == "sliding_attention" else None
  226. def forward(
  227. self,
  228. hidden_states: torch.Tensor,
  229. position_embeddings: tuple[torch.Tensor, torch.Tensor],
  230. attention_mask: Optional[torch.Tensor],
  231. past_key_values: Optional[Cache] = None,
  232. cache_position: Optional[torch.LongTensor] = None,
  233. **kwargs: Unpack[TransformersKwargs],
  234. ) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
  235. input_shape = hidden_states.shape[:-1]
  236. hidden_shape = (*input_shape, -1, self.head_dim)
  237. query_states = self.q_norm(self.q_proj(hidden_states))
  238. key_states = self.k_norm(self.k_proj(hidden_states))
  239. value_states = self.v_proj(hidden_states)
  240. query_states = query_states.view(hidden_shape).transpose(1, 2)
  241. key_states = key_states.view(hidden_shape).transpose(1, 2)
  242. value_states = value_states.view(hidden_shape).transpose(1, 2)
  243. cos, sin = position_embeddings
  244. query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
  245. if past_key_values is not None:
  246. # sin and cos are specific to RoPE models; cache_position needed for the static cache
  247. cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
  248. key_states, value_states = past_key_values.update(key_states, value_states, self.layer_idx, cache_kwargs)
  249. attention_interface: Callable = eager_attention_forward
  250. if self.config._attn_implementation != "eager":
  251. attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]
  252. attn_output, attn_weights = attention_interface(
  253. self,
  254. query_states,
  255. key_states,
  256. value_states,
  257. attention_mask,
  258. dropout=0.0 if not self.training else self.attention_dropout,
  259. scaling=self.scaling,
  260. sliding_window=self.sliding_window,
  261. **kwargs,
  262. )
  263. attn_output = attn_output.reshape(*input_shape, -1).contiguous()
  264. attn_output = self.o_proj(attn_output)
  265. return attn_output, attn_weights
  266. class Olmo3DecoderLayer(Olmo2DecoderLayer):
  267. pass
  268. # OLMo 3 RoPE is identical to OLMo 2 RoPE, except:
  269. # - RoPE scaling is not applied to sliding window attention layers.
  270. class Olmo3RotaryEmbedding(Olmo2RotaryEmbedding):
  271. def __init__(self, config: Olmo3Config, device=None, rope_type: Optional[str] = None):
  272. nn.Module.__init__(self)
  273. if rope_type is not None:
  274. self.rope_type = rope_type
  275. elif hasattr(config, "rope_scaling") and isinstance(config.rope_scaling, dict):
  276. # BC: "rope_type" was originally "type"
  277. self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type"))
  278. else:
  279. self.rope_type = "default"
  280. assert self.rope_type is not None
  281. self.max_seq_len_cached = config.max_position_embeddings
  282. self.original_max_seq_len = config.max_position_embeddings
  283. self.config = config
  284. self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]
  285. inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device)
  286. self.register_buffer("inv_freq", inv_freq, persistent=False)
  287. self.original_inv_freq = self.inv_freq
  288. class Olmo3PreTrainedModel(Olmo2PreTrainedModel):
  289. pass
# The OLMo 3 model is identical to the OLMo 2 model, except:
# - Sliding window attention is used for the layers marked "sliding_attention" in `config.layer_types`
#   (by default 3 out of 4 layers).
# - RoPE scaling is not applied to sliding window attention layers.
class Olmo3Model(Olmo2Model):
    """OLMo 3 decoder stack.

    Differs from `Olmo2Model` in that every decoder layer receives the attention mask and rotary embeddings that
    match its `self_attn.attention_type` (`"sliding_attention"` or `"full_attention"`).
    """

    def __init__(self, config: Olmo3Config):
        super().__init__(config)
        self.norm = Olmo3RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.layers = nn.ModuleList(
            [Olmo3DecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
        )
        # One rotary embedding per attention type: sliding-window layers are forced to the "default" (unscaled)
        # RoPE, while full-attention layers use whatever RoPE type `config.rope_scaling` selects.
        self.rotary_embs = nn.ModuleDict(
            {
                "sliding_attention": Olmo3RotaryEmbedding(config=config, rope_type="default"),
                "full_attention": Olmo3RotaryEmbedding(config=config),
            }
        )
        # The single `rotary_emb` set up by the parent is superseded by `rotary_embs`.
        del self.rotary_emb

    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[Cache] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        cache_position: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        **kwargs: Unpack[TransformersKwargs],
    ) -> BaseModelOutputWithPast:
        """Run the embeddings, all decoder layers and the final norm.

        Exactly one of `input_ids` or `inputs_embeds` must be given. When `use_cache` is truthy and no cache is
        passed, a fresh `DynamicCache` is created; a `None` value does not create one here. `cache_position`
        defaults to the positions following the tokens already in the cache. `position_ids` defaults to
        `cache_position` unsqueezed to shape `(1, seq_len)`. If `attention_mask` is already a dict (e.g. prepared
        by `generate`), it is used as the per-attention-type mask mapping as-is. Otherwise full and
        sliding-window causal masks are built from it.

        Returns:
            `BaseModelOutputWithPast` holding the normalized final hidden states and the (possibly newly created)
            cache.

        Raises:
            ValueError: if both or neither of `input_ids` and `inputs_embeds` are provided.
        """
        if (input_ids is None) ^ (inputs_embeds is not None):
            raise ValueError("You must specify exactly one of input_ids or inputs_embeds")

        if inputs_embeds is None:
            inputs_embeds: torch.Tensor = self.embed_tokens(input_ids)

        if use_cache and past_key_values is None:
            past_key_values = DynamicCache(config=self.config)

        if cache_position is None:
            # Continue numbering after whatever the cache already holds (0 without a cache).
            past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
            cache_position: torch.Tensor = torch.arange(
                past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device
            )

        if position_ids is None:
            position_ids = cache_position.unsqueeze(0)

        # It may already have been prepared by e.g. `generate`
        if not isinstance(causal_mask_mapping := attention_mask, dict):
            # Prepare mask arguments
            mask_kwargs = {
                "config": self.config,
                "input_embeds": inputs_embeds,
                "attention_mask": attention_mask,
                "cache_position": cache_position,
                "past_key_values": past_key_values,
                "position_ids": position_ids,
            }
            # Create the masks
            causal_mask_mapping = {
                "full_attention": create_causal_mask(**mask_kwargs),
                "sliding_attention": create_sliding_window_causal_mask(**mask_kwargs),
            }

        hidden_states = inputs_embeds
        # Both rotary tables are computed once per forward and shared by all layers of the same attention type.
        position_embeddings_mapping = {
            "sliding_attention": self.rotary_embs["sliding_attention"](hidden_states, position_ids),
            "full_attention": self.rotary_embs["full_attention"](hidden_states, position_ids),
        }

        for decoder_layer in self.layers[: self.config.num_hidden_layers]:
            # Each layer picks the mask and rotary tables matching its own attention type.
            hidden_states = decoder_layer(
                hidden_states,
                attention_mask=causal_mask_mapping[decoder_layer.self_attn.attention_type],
                position_ids=position_ids,
                past_key_values=past_key_values,
                cache_position=cache_position,
                position_embeddings=position_embeddings_mapping[decoder_layer.self_attn.attention_type],
                **kwargs,
            )

        hidden_states = self.norm(hidden_states)

        return BaseModelOutputWithPast(
            last_hidden_state=hidden_states,
            past_key_values=past_key_values,
        )
  367. class Olmo3ForCausalLM(Olmo2ForCausalLM):
  368. pass
  369. __all__ = [
  370. "Olmo3Config",
  371. "Olmo3ForCausalLM",
  372. "Olmo3Model",
  373. "Olmo3PreTrainedModel",
  374. ]