configuration_olmo3.py 12 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225
  1. # 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
  2. # This file was automatically generated from src/transformers/models/olmo3/modular_olmo3.py.
  3. # Do NOT edit this file manually as any edits will be overwritten by the generation of
  4. # the file from the modular. If any change should be done, please apply the change to the
  5. # modular_olmo3.py file directly. One of our CI enforces this.
  6. # 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
  7. # coding=utf-8
  8. # Copyright 2025 the HuggingFace Team. All rights reserved.
  9. #
  10. # Licensed under the Apache License, Version 2.0 (the "License");
  11. # you may not use this file except in compliance with the License.
  12. # You may obtain a copy of the License at
  13. #
  14. # http://www.apache.org/licenses/LICENSE-2.0
  15. #
  16. # Unless required by applicable law or agreed to in writing, software
  17. # distributed under the License is distributed on an "AS IS" BASIS,
  18. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  19. # See the License for the specific language governing permissions and
  20. # limitations under the License.
  21. from ...configuration_utils import PretrainedConfig, layer_type_validation
  22. from ...modeling_rope_utils import rope_config_validation
  23. class Olmo3Config(PretrainedConfig):
  24. r"""
  25. This is the configuration class to store the configuration of a [`Olmo3Model`]. It is used to instantiate an OLMo3
  26. model according to the specified arguments, defining the model architecture. Instantiating a configuration with the
  27. defaults will yield a similar configuration to that of the [allenai/OLMo-3-0725-1B](https://huggingface.co/allenai/OLMo-3-0725-1B).
  28. Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
  29. documentation from [`PretrainedConfig`] for more information.
  30. Args:
  31. vocab_size (`int`, *optional*, defaults to 50304):
  32. Vocabulary size of the Olmo3 model. Defines the number of different tokens that can be represented by the
  33. `inputs_ids` passed when calling [`Olmo3Model`]
  34. hidden_size (`int`, *optional*, defaults to 4096):
  35. Dimension of the hidden representations.
  36. intermediate_size (`int`, *optional*, defaults to 11008):
  37. Dimension of the MLP representations.
  38. num_hidden_layers (`int`, *optional*, defaults to 32):
  39. Number of hidden layers in the Transformer decoder.
  40. num_attention_heads (`int`, *optional*, defaults to 32):
  41. Number of attention heads for each attention layer in the Transformer decoder.
  42. num_key_value_heads (`int`, *optional*):
  43. This is the number of key_value heads that should be used to implement Grouped Query Attention. If
  44. `num_key_value_heads=num_attention_heads`, the model will use Multi Head Attention (MHA), if
  45. `num_key_value_heads=1` the model will use Multi Query Attention (MQA) otherwise GQA is used. When
  46. converting a multi-head checkpoint to a GQA checkpoint, each group key and value head should be constructed
  47. by meanpooling all the original heads within that group. For more details, check out [this
  48. paper](https://huggingface.co/papers/2305.13245). If it is not specified, will default to
  49. `num_attention_heads`.
  50. hidden_act (`str` or `function`, *optional*, defaults to `"silu"`):
  51. The non-linear activation function (function or string) in the decoder.
  52. max_position_embeddings (`int`, *optional*, defaults to 2048):
  53. The maximum sequence length that this model might ever be used with.
  54. initializer_range (`float`, *optional*, defaults to 0.02):
  55. The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
  56. use_cache (`bool`, *optional*, defaults to `True`):
  57. Whether or not the model should return the last key/values attentions (not used by all models). Only
  58. relevant if `config.is_decoder=True`.
  59. pad_token_id (`int`, *optional*, defaults to 1):
  60. Padding token id.
  61. bos_token_id (`int`, *optional*):
  62. Beginning of stream token id.
  63. eos_token_id (`int`, *optional*, defaults to 50279):
  64. End of stream token id.
  65. tie_word_embeddings (`bool`, *optional*, defaults to `False`):
  66. Whether to tie weight embeddings
  67. rope_theta (`float`, *optional*, defaults to 10000.0):
  68. The base period of the RoPE embeddings.
  69. rope_scaling (`Dict`, *optional*):
  70. Dictionary containing the scaling configuration for the RoPE embeddings. NOTE: if you apply new rope type
  71. and you expect the model to work on longer `max_position_embeddings`, we recommend you to update this value
  72. accordingly.
  73. Expected contents:
  74. `rope_type` (`str`):
  75. The sub-variant of RoPE to use. Can be one of ['default', 'linear', 'dynamic', 'yarn', 'longrope',
  76. 'llama3'], with 'default' being the original RoPE implementation.
  77. `factor` (`float`, *optional*):
  78. Used with all rope types except 'default'. The scaling factor to apply to the RoPE embeddings. In
  79. most scaling types, a `factor` of x will enable the model to handle sequences of length x *
  80. original maximum pre-trained length.
  81. `original_max_position_embeddings` (`int`, *optional*):
  82. Used with 'dynamic', 'longrope' and 'llama3'. The original max position embeddings used during
  83. pretraining.
  84. `attention_factor` (`float`, *optional*):
  85. Used with 'yarn' and 'longrope'. The scaling factor to be applied on the attention
  86. computation. If unspecified, it defaults to value recommended by the implementation, using the
  87. `factor` field to infer the suggested value.
  88. `beta_fast` (`float`, *optional*):
  89. Only used with 'yarn'. Parameter to set the boundary for extrapolation (only) in the linear
  90. ramp function. If unspecified, it defaults to 32.
  91. `beta_slow` (`float`, *optional*):
  92. Only used with 'yarn'. Parameter to set the boundary for interpolation (only) in the linear
  93. ramp function. If unspecified, it defaults to 1.
  94. `short_factor` (`list[float]`, *optional*):
  95. Only used with 'longrope'. The scaling factor to be applied to short contexts (<
  96. `original_max_position_embeddings`). Must be a list of numbers with the same length as the hidden
  97. size divided by the number of attention heads divided by 2
  98. `long_factor` (`list[float]`, *optional*):
  99. Only used with 'longrope'. The scaling factor to be applied to long contexts (<
  100. `original_max_position_embeddings`). Must be a list of numbers with the same length as the hidden
  101. size divided by the number of attention heads divided by 2
  102. `low_freq_factor` (`float`, *optional*):
  103. Only used with 'llama3'. Scaling factor applied to low frequency components of the RoPE
  104. `high_freq_factor` (`float`, *optional*):
  105. Only used with 'llama3'. Scaling factor applied to high frequency components of the RoPE
  106. attention_bias (`bool`, defaults to `False`, *optional*, defaults to `False`):
  107. Whether to use a bias in the query, key, value and output projection layers during self-attention.
  108. attention_dropout (`float`, *optional*, defaults to 0.0):
  109. The dropout ratio for the attention probabilities.
  110. rms_norm_eps (`float`, *optional*, defaults to 1e-05):
  111. The epsilon used by the rms normalization layers.
  112. sliding_window (`int`, *optional*, defaults to 4096):
  113. Size of the sliding window for sliding window attention.
  114. layer_types (`list`, *optional*):
  115. Attention pattern for each layer. Defaults to sliding window attention
  116. for 3 out of 4 layers, and full attention for every 4th layer.
  117. ```python
  118. >>> from transformers import Olmo3Model, Olmo3Config
  119. >>> # Initializing a Olmo3 7B style configuration
  120. >>> configuration = Olmo3Config()
  121. >>> # Initializing a model from the Olmo3 7B style configuration
  122. >>> model = Olmo3Model(configuration)
  123. >>> # Accessing the model configuration
  124. >>> configuration = model.config
  125. ```
  126. """
  127. model_type = "olmo3"
  128. keys_to_ignore_at_inference = ["past_key_values"]
  129. base_model_tp_plan = {
  130. "layers.*.self_attn.q_proj": "colwise_rep", # we need to replicate here due to the added norm on q and k
  131. "layers.*.self_attn.k_proj": "colwise_rep", # we need to replicate here due to the added norm on q and k
  132. "layers.*.self_attn.v_proj": "colwise_rep", # we need to replicate here due to the added norm on q and k
  133. "layers.*.self_attn.o_proj": "rowwise_rep", # we need to replicate here due to the added norm on q and k
  134. "layers.*.mlp.gate_proj": "colwise",
  135. "layers.*.mlp.up_proj": "colwise",
  136. "layers.*.mlp.down_proj": "rowwise",
  137. }
  138. base_model_pp_plan = {
  139. "embed_tokens": (["input_ids"], ["inputs_embeds"]),
  140. "layers": (["hidden_states", "attention_mask"], ["hidden_states"]),
  141. "norm": (["hidden_states"], ["hidden_states"]),
  142. }
  143. def __init__(
  144. self,
  145. vocab_size=50304,
  146. hidden_size=4096,
  147. intermediate_size=11008,
  148. num_hidden_layers=32,
  149. num_attention_heads=32,
  150. num_key_value_heads=None,
  151. hidden_act="silu",
  152. max_position_embeddings=2048,
  153. initializer_range=0.02,
  154. use_cache=True,
  155. pad_token_id=1,
  156. bos_token_id=None,
  157. eos_token_id=50279,
  158. tie_word_embeddings=False,
  159. rope_theta=10000.0,
  160. rope_scaling=None,
  161. attention_bias=False,
  162. attention_dropout=0.0,
  163. rms_norm_eps=1e-5,
  164. sliding_window=4096,
  165. layer_types=None,
  166. **kwargs,
  167. ):
  168. super().__init__(
  169. pad_token_id=pad_token_id,
  170. bos_token_id=bos_token_id,
  171. eos_token_id=eos_token_id,
  172. tie_word_embeddings=tie_word_embeddings,
  173. **kwargs,
  174. )
  175. self.vocab_size = vocab_size
  176. self.max_position_embeddings = max_position_embeddings
  177. self.hidden_size = hidden_size
  178. self.intermediate_size = intermediate_size
  179. self.num_hidden_layers = num_hidden_layers
  180. self.num_attention_heads = num_attention_heads
  181. # for backward compatibility
  182. if num_key_value_heads is None:
  183. num_key_value_heads = num_attention_heads
  184. self.num_key_value_heads = num_key_value_heads
  185. self.hidden_act = hidden_act
  186. self.initializer_range = initializer_range
  187. self.use_cache = use_cache
  188. self.rope_theta = rope_theta
  189. self.rope_scaling = rope_scaling
  190. self._rope_scaling_validation()
  191. self.attention_bias = attention_bias
  192. self.attention_dropout = attention_dropout
  193. self.rms_norm_eps = rms_norm_eps
  194. self.sliding_window = sliding_window
  195. self.layer_types = layer_types
  196. if self.layer_types is None:
  197. self.layer_types = [
  198. "sliding_attention" if (i + 1) % 4 != 0 else "full_attention" for i in range(self.num_hidden_layers)
  199. ]
  200. layer_type_validation(self.layer_types)
  201. def _rope_scaling_validation(self):
  202. """
  203. Validate the `rope_scaling` configuration.
  204. """
  205. rope_config_validation(self)
  206. __all__ = ["Olmo3Config"]