modular_smollm3.py 16 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361
  1. # coding=utf-8
  2. # Copyright 2025 The HuggingFace Inc. team. All rights reserved.
  3. #
  4. # Licensed under the Apache License, Version 2.0 (the "License");
  5. # you may not use this file except in compliance with the License.
  6. # You may obtain a copy of the License at
  7. #
  8. # http://www.apache.org/licenses/LICENSE-2.0
  9. #
  10. # Unless required by applicable law or agreed to in writing, software
  11. # distributed under the License is distributed on an "AS IS" BASIS,
  12. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. # See the License for the specific language governing permissions and
  14. # limitations under the License.
  15. from typing import Callable, Optional
  16. import torch
  17. from ...cache_utils import Cache
  18. from ...configuration_utils import PretrainedConfig, layer_type_validation
  19. from ...modeling_flash_attention_utils import FlashAttentionKwargs
  20. from ...modeling_rope_utils import rope_config_validation
  21. from ...modeling_utils import ALL_ATTENTION_FUNCTIONS
  22. from ...processing_utils import Unpack
  23. from ...utils import logging
  24. from ...utils.deprecation import deprecate_kwarg
  25. from ..llama.modeling_llama import (
  26. LlamaAttention,
  27. LlamaDecoderLayer,
  28. LlamaForCausalLM,
  29. LlamaForQuestionAnswering,
  30. LlamaForSequenceClassification,
  31. LlamaForTokenClassification,
  32. LlamaPreTrainedModel,
  33. apply_rotary_pos_emb,
  34. eager_attention_forward,
  35. )
  36. from ..qwen2.modeling_qwen2 import Qwen2Model
  37. logger = logging.get_logger(__name__)
  38. class SmolLM3Config(PretrainedConfig):
  39. r"""
  40. This is the configuration class to store the configuration of a [`SmolLM3Model`]. It is used to instantiate a
  41. SmolLM3 model according to the specified arguments, defining the model architecture. Instantiating a configuration
  42. with the defaults will yield a similar configuration to that of the SmolLM3 3B.
  43. e.g. [HuggingFaceTB/SmolLM3-3B](https://huggingface.co/HuggingFaceTB/SmolLM3-3B)
  44. Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
  45. documentation from [`PretrainedConfig`] for more information.
  46. Args:
  47. vocab_size (`int`, *optional*, defaults to 128256):
  48. Vocabulary size of the SmolLM3 model. Defines the number of different tokens that can be represented by the
  49. `inputs_ids` passed when calling [`SmolLM3Model`]
  50. hidden_size (`int`, *optional*, defaults to 2048):
  51. Dimension of the hidden representations.
  52. intermediate_size (`int`, *optional*, defaults to 11008):
  53. Dimension of the MLP representations.
  54. num_hidden_layers (`int`, *optional*, defaults to 36):
  55. Number of hidden layers in the Transformer encoder.
  56. num_attention_heads (`int`, *optional*, defaults to 16):
  57. Number of attention heads for each attention layer in the Transformer encoder.
  58. num_key_value_heads (`int`, *optional*, defaults to 4):
  59. This is the number of key_value heads that should be used to implement Grouped Query Attention. If
  60. `num_key_value_heads=num_attention_heads`, the model will use Multi Head Attention (MHA), if
  61. `num_key_value_heads=1` the model will use Multi Query Attention (MQA) otherwise GQA is used. When
  62. converting a multi-head checkpoint to a GQA checkpoint, each group key and value head should be constructed
  63. by meanpooling all the original heads within that group. For more details checkout [this
  64. paper](https://huggingface.co/papers/2305.13245). If it is not specified, will default to `16`.
  65. hidden_act (`str` or `function`, *optional*, defaults to `"silu"`):
  66. The non-linear activation function (function or string) in the decoder.
  67. max_position_embeddings (`int`, *optional*, defaults to 32768):
  68. The maximum sequence length that this model might ever be used with.
  69. initializer_range (`float`, *optional*, defaults to 0.02):
  70. The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
  71. rms_norm_eps (`float`, *optional*, defaults to 1e-06):
  72. The epsilon used by the rms normalization layers.
  73. use_cache (`bool`, *optional*, defaults to `True`):
  74. Whether or not the model should return the last key/values attentions (not used by all models). Only
  75. relevant if `config.is_decoder=True`.
  76. pad_token_id (`int`, *optional*, defaults to 128004):
  77. The id of the padding token.
  78. bos_token_id (`int`, *optional*, defaults to 128000):
  79. The id of the beginning of sentence token.
  80. eos_token_id (`int`, *optional*, defaults to 128001):
  81. The id of the end of sentence token.
  82. rope_theta (`float`, *optional*, defaults to 2000000.0):
  83. The base period of the RoPE embeddings.
  84. rope_scaling (`Dict`, *optional*):
  85. Dictionary containing the scaling configuration for the RoPE embeddings. NOTE: if you apply new rope type
  86. and you expect the model to work on longer `max_position_embeddings`, we recommend you to update this value
  87. accordingly.
  88. Expected contents:
  89. `rope_type` (`str`):
  90. The sub-variant of RoPE to use. Can be one of ['default', 'linear', 'dynamic', 'yarn', 'longrope',
  91. 'llama3'], with 'default' being the original RoPE implementation.
  92. `factor` (`float`, *optional*):
  93. Used with all rope types except 'default'. The scaling factor to apply to the RoPE embeddings. In
  94. most scaling types, a `factor` of x will enable the model to handle sequences of length x *
  95. original maximum pre-trained length.
  96. `original_max_position_embeddings` (`int`, *optional*):
  97. Used with 'dynamic', 'longrope' and 'llama3'. The original max position embeddings used during
  98. pretraining.
  99. `attention_factor` (`float`, *optional*):
  100. Used with 'yarn' and 'longrope'. The scaling factor to be applied on the attention
  101. computation. If unspecified, it defaults to value recommended by the implementation, using the
  102. `factor` field to infer the suggested value.
  103. `beta_fast` (`float`, *optional*):
  104. Only used with 'yarn'. Parameter to set the boundary for extrapolation (only) in the linear
  105. ramp function. If unspecified, it defaults to 32.
  106. `beta_slow` (`float`, *optional*):
  107. Only used with 'yarn'. Parameter to set the boundary for interpolation (only) in the linear
  108. ramp function. If unspecified, it defaults to 1.
  109. `short_factor` (`List[float]`, *optional*):
  110. Only used with 'longrope'. The scaling factor to be applied to short contexts (<
  111. `original_max_position_embeddings`). Must be a list of numbers with the same length as the hidden
  112. size divided by the number of attention heads divided by 2
  113. `long_factor` (`List[float]`, *optional*):
  114. Only used with 'longrope'. The scaling factor to be applied to long contexts (<
  115. `original_max_position_embeddings`). Must be a list of numbers with the same length as the hidden
  116. size divided by the number of attention heads divided by 2
  117. `low_freq_factor` (`float`, *optional*):
  118. Only used with 'llama3'. Scaling factor applied to low frequency components of the RoPE
  119. `high_freq_factor` (`float`, *optional*):
  120. Only used with 'llama3'. Scaling factor applied to high frequency components of the RoPE
  121. use_sliding_window (`bool`, *optional*, defaults to `False`):
  122. Whether to use sliding window attention.
  123. sliding_window (`int`, *optional*):
  124. Sliding window attention (SWA) window size. If not specified, will default to `None`.
  125. no_rope_layers (`List[int]`, *optional*):
  126. List with at least the same length as the number of layers in the model.
  127. A `1` at an index position indicates that the corresponding layer will use RoPE,
  128. while a `0` indicates that it's a NoPE layer.
  129. no_rope_layer_interval (`int`, *optional*, defaults to 4):
  130. If `no_rope_layers` is `None`, it will be created using a NoPE layer every
  131. `no_rope_layer_interval` layers.
  132. layer_types (`list`, *optional*):
  133. Attention pattern for each layer. Automatically computed based on sliding window and NoPE settings.
  134. attention_bias (`bool`, *optional*, defaults to `False`):
  135. Whether to use a bias in the query, key, value and output projection layers during self-attention.
  136. attention_dropout (`float`, *optional*, defaults to 0.0):
  137. The dropout ratio for the attention probabilities.
  138. ```python
  139. >>> from transformers import SmolLM3Model, SmolLM3Config
  140. >>> # Initializing a SmolLM3 style configuration
  141. >>> configuration = SmolLM3Config()
  142. >>> # Initializing a model from the SmolLM3 style configuration
  143. >>> model = SmolLM3Model(configuration)
  144. >>> # Accessing the model configuration
  145. >>> configuration = model.config
  146. ```"""
  147. model_type = "smollm3"
  148. keys_to_ignore_at_inference = ["past_key_values"]
  149. base_model_tp_plan = {
  150. "layers.*.self_attn.q_proj": "colwise",
  151. "layers.*.self_attn.k_proj": "colwise",
  152. "layers.*.self_attn.v_proj": "colwise",
  153. "layers.*.self_attn.o_proj": "rowwise",
  154. "layers.*.mlp.gate_proj": "colwise",
  155. "layers.*.mlp.up_proj": "colwise",
  156. "layers.*.mlp.down_proj": "rowwise",
  157. }
  158. base_model_pp_plan = {
  159. "embed_tokens": (["input_ids"], ["inputs_embeds"]),
  160. "layers": (["hidden_states", "attention_mask"], ["hidden_states"]),
  161. "norm": (["hidden_states"], ["hidden_states"]),
  162. }
  163. def __init__(
  164. self,
  165. vocab_size=128256,
  166. hidden_size=2048,
  167. intermediate_size=11008,
  168. num_hidden_layers=36,
  169. num_attention_heads=16,
  170. num_key_value_heads=4,
  171. hidden_act="silu",
  172. max_position_embeddings=32768,
  173. initializer_range=0.02,
  174. rms_norm_eps=1e-6,
  175. use_cache=True,
  176. pad_token_id=128004,
  177. bos_token_id=128000,
  178. eos_token_id=128001,
  179. rope_theta=2000000.0,
  180. rope_scaling=None,
  181. use_sliding_window=False,
  182. sliding_window=None,
  183. no_rope_layers=None,
  184. no_rope_layer_interval=4,
  185. layer_types=None,
  186. attention_bias=False,
  187. attention_dropout=0.0,
  188. mlp_bias=False,
  189. **kwargs,
  190. ):
  191. super().__init__(
  192. pad_token_id=pad_token_id,
  193. bos_token_id=bos_token_id,
  194. eos_token_id=eos_token_id,
  195. **kwargs,
  196. )
  197. self.vocab_size = vocab_size
  198. self.max_position_embeddings = max_position_embeddings
  199. self.mlp_bias = mlp_bias
  200. self.hidden_size = hidden_size
  201. self.intermediate_size = intermediate_size
  202. self.num_hidden_layers = num_hidden_layers
  203. self.num_attention_heads = num_attention_heads
  204. self.use_sliding_window = use_sliding_window
  205. self.sliding_window = sliding_window
  206. # for backward compatibility
  207. if num_key_value_heads is None:
  208. num_key_value_heads = num_attention_heads
  209. self.num_key_value_heads = num_key_value_heads
  210. self.hidden_act = hidden_act
  211. self.initializer_range = initializer_range
  212. self.rms_norm_eps = rms_norm_eps
  213. self.use_cache = use_cache
  214. self.rope_theta = rope_theta
  215. self.rope_scaling = rope_scaling
  216. self.attention_bias = attention_bias
  217. self.attention_dropout = attention_dropout
  218. if no_rope_layers is None:
  219. self.no_rope_layers = [
  220. int((layer_idx + 1) % no_rope_layer_interval != 0) for layer_idx in range(num_hidden_layers)
  221. ]
  222. else:
  223. self.no_rope_layers = no_rope_layers
  224. self.no_rope_layer_interval = no_rope_layer_interval
  225. # Update layer_types based on sliding window and NoPE pattern
  226. if layer_types is None:
  227. layer_types = []
  228. for layer_idx in range(num_hidden_layers):
  229. has_rope = self.no_rope_layers[layer_idx]
  230. if use_sliding_window and sliding_window is not None and not has_rope:
  231. layer_types.append("sliding_attention")
  232. else:
  233. layer_types.append("full_attention")
  234. self.layer_types = layer_types
  235. layer_type_validation(self.layer_types, self.num_hidden_layers)
  236. # Validate the correctness of rotary position embeddings parameters
  237. # BC: if there is a 'type' field, move it to 'rope_type'.
  238. if self.rope_scaling is not None and "type" in self.rope_scaling:
  239. self.rope_scaling["rope_type"] = self.rope_scaling["type"]
  240. rope_config_validation(self)
  241. class SmolLM3Attention(LlamaAttention):
  242. def __init__(self, config: SmolLM3Config, layer_idx: int):
  243. super().__init__(config, layer_idx)
  244. self.use_rope = config.no_rope_layers[layer_idx]
  245. self.sliding_window = (
  246. config.sliding_window
  247. if config.use_sliding_window and config.layer_types[layer_idx] == "sliding_attention"
  248. else None
  249. )
  250. @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58")
  251. def forward(
  252. self,
  253. hidden_states: torch.Tensor,
  254. position_embeddings: tuple[torch.Tensor, torch.Tensor],
  255. attention_mask: Optional[torch.Tensor],
  256. past_key_values: Optional[Cache] = None,
  257. cache_position: Optional[torch.LongTensor] = None,
  258. **kwargs: Unpack[FlashAttentionKwargs],
  259. ) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
  260. input_shape = hidden_states.shape[:-1]
  261. hidden_shape = (*input_shape, -1, self.head_dim)
  262. query_states = self.q_proj(hidden_states).view(hidden_shape).transpose(1, 2)
  263. key_states = self.k_proj(hidden_states).view(hidden_shape).transpose(1, 2)
  264. value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2)
  265. if self.use_rope:
  266. cos, sin = position_embeddings
  267. query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
  268. if past_key_values is not None:
  269. cache_kwargs = {"cache_position": cache_position}
  270. key_states, value_states = past_key_values.update(key_states, value_states, self.layer_idx, cache_kwargs)
  271. attention_interface: Callable = eager_attention_forward
  272. if self.config._attn_implementation != "eager":
  273. attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]
  274. attn_output, attn_weights = attention_interface(
  275. self,
  276. query_states,
  277. key_states,
  278. value_states,
  279. attention_mask,
  280. dropout=0.0 if not self.training else self.attention_dropout,
  281. scaling=self.scaling,
  282. sliding_window=self.sliding_window,
  283. **kwargs,
  284. )
  285. attn_output = attn_output.reshape(*input_shape, -1).contiguous()
  286. attn_output = self.o_proj(attn_output)
  287. return attn_output, attn_weights
  288. class SmolLM3DecoderLayer(LlamaDecoderLayer):
  289. def __init__(self, config: SmolLM3Config, layer_idx: int):
  290. super().__init__(config, layer_idx)
  291. self.attention_type = config.layer_types[layer_idx]
  292. class SmolLM3PreTrainedModel(LlamaPreTrainedModel):
  293. pass
  294. class SmolLM3Model(Qwen2Model):
  295. pass
  296. class SmolLM3ForCausalLM(LlamaForCausalLM):
  297. pass
  298. class SmolLM3ForSequenceClassification(LlamaForSequenceClassification):
  299. pass
  300. class SmolLM3ForTokenClassification(LlamaForTokenClassification):
  301. pass
  302. class SmolLM3ForQuestionAnswering(LlamaForQuestionAnswering):
  303. pass
  304. __all__ = [
  305. "SmolLM3Config",
  306. "SmolLM3PreTrainedModel",
  307. "SmolLM3Model",
  308. "SmolLM3ForCausalLM",
  309. "SmolLM3ForSequenceClassification",
  310. "SmolLM3ForTokenClassification",
  311. "SmolLM3ForQuestionAnswering",
  312. ]