modular_apertus.py 16 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371
  1. # coding=utf-8
  2. # Copyright 2025 the HuggingFace Inc. team and the Swiss AI Initiative. All rights reserved.
  3. #
  4. #
  5. # Licensed under the Apache License, Version 2.0 (the "License");
  6. # you may not use this file except in compliance with the License.
  7. # You may obtain a copy of the License at
  8. #
  9. # http://www.apache.org/licenses/LICENSE-2.0
  10. #
  11. # Unless required by applicable law or agreed to in writing, software
  12. # distributed under the License is distributed on an "AS IS" BASIS,
  13. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  14. # See the License for the specific language governing permissions and
  15. # limitations under the License.
  16. from typing import Callable, Optional
  17. import torch
  18. from torch import nn
  19. from ...cache_utils import Cache
  20. from ...modeling_utils import ALL_ATTENTION_FUNCTIONS
  21. from ...processing_utils import Unpack
  22. from ...utils import TransformersKwargs, logging
  23. from ..llama.configuration_llama import LlamaConfig
  24. from ..llama.modeling_llama import (
  25. LlamaAttention,
  26. LlamaDecoderLayer,
  27. LlamaForCausalLM,
  28. LlamaForTokenClassification,
  29. LlamaModel,
  30. LlamaPreTrainedModel,
  31. LlamaRMSNorm,
  32. LlamaRotaryEmbedding,
  33. apply_rotary_pos_emb,
  34. eager_attention_forward,
  35. )
  36. from ..nemotron.modeling_nemotron import NemotronMLP
# Module-level logger for this file.
# NOTE(review): not referenced by any code visible in this module; kept for parity with sibling model files.
logger = logging.get_logger(__name__)
  38. class ApertusConfig(LlamaConfig):
  39. r"""
  40. This is the configuration class to store the configuration of a [`ApertusModel`]. It is used to instantiate a Apertus
  41. model according to the specified arguments, defining the model architecture. Instantiating a configuration with the
  42. defaults will yield a similar configuration to that of the Apertus-8B.
  43. e.g. [swiss-ai/Apertus-8B](https://huggingface.co/swiss-ai/Apertus-8B)
  44. Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
  45. documentation from [`PretrainedConfig`] for more information.
  46. Args:
  47. vocab_size (`int`, *optional*, defaults to 131072):
  48. Vocabulary size of the Apertus model. Defines the number of different tokens that can be represented by the
  49. `inputs_ids` passed when calling [`ApertusModel`]
  50. hidden_size (`int`, *optional*, defaults to 4096):
  51. Dimension of the hidden representations.
  52. intermediate_size (`int`, *optional*, defaults to 14336):
  53. Dimension of the MLP representations.
  54. num_hidden_layers (`int`, *optional*, defaults to 32):
  55. Number of hidden layers in the Transformer decoder.
  56. num_attention_heads (`int`, *optional*, defaults to 32):
  57. Number of attention heads for each attention layer in the Transformer decoder.
  58. num_key_value_heads (`int`, *optional*):
  59. This is the number of key_value heads that should be used to implement Grouped Query Attention. If
  60. `num_key_value_heads=num_attention_heads`, the model will use Multi Head Attention (MHA), if
  61. `num_key_value_heads=1` the model will use Multi Query Attention (MQA) otherwise GQA is used. When
  62. converting a multi-head checkpoint to a GQA checkpoint, each group key and value head should be constructed
  63. by meanpooling all the original heads within that group. For more details, check out [this
  64. paper](https://huggingface.co/papers/2305.13245). If it is not specified, will default to
  65. `num_attention_heads`.
  66. hidden_act (`str` or `function`, *optional*, defaults to `"xielu"`):
  67. The non-linear activation function (function or string) in the decoder.
  68. max_position_embeddings (`int`, *optional*, defaults to 65536):
  69. The maximum sequence length that this model might ever be used with. Apertus supports up to 65536 tokens.
  70. initializer_range (`float`, *optional*, defaults to 0.02):
  71. The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
  72. rms_norm_eps (`float`, *optional*, defaults to 1e-05):
  73. The epsilon used by the rms normalization layers.
  74. use_cache (`bool`, *optional*, defaults to `True`):
  75. Whether or not the model should return the last key/values attentions (not used by all models). Only
  76. relevant if `config.is_decoder=True`.
  77. pad_token_id (`int`, *optional*, defaults to 3):
  78. Padding token id.
  79. bos_token_id (`int`, *optional*, defaults to 1):
  80. Beginning of stream token id.
  81. eos_token_id (`int`, *optional*, defaults to 2):
  82. End of stream token id.
  83. tie_word_embeddings (`bool`, *optional*, defaults to `False`):
  84. Whether to tie weight embeddings
  85. rope_theta (`float`, *optional*, defaults to 12000000.0):
  86. The base period of the RoPE embeddings.
  87. rope_scaling (`Dict`, *optional*):
  88. Dictionary containing the scaling configuration for the RoPE embeddings. NOTE: if you apply new rope type
  89. and you expect the model to work on longer `max_position_embeddings`, we recommend you to update this value
  90. accordingly.
  91. Expected contents:
  92. `rope_type` (`str`):
  93. The sub-variant of RoPE to use. Can be one of ['default', 'linear', 'dynamic', 'yarn', 'longrope',
  94. 'llama3'], with 'default' being the original RoPE implementation.
  95. `factor` (`float`, *optional*):
  96. Used with all rope types except 'default'. The scaling factor to apply to the RoPE embeddings. In
  97. most scaling types, a `factor` of x will enable the model to handle sequences of length x *
  98. original maximum pre-trained length.
  99. `original_max_position_embeddings` (`int`, *optional*):
  100. Used with 'dynamic', 'longrope' and 'llama3'. The original max position embeddings used during
  101. pretraining.
  102. `attention_factor` (`float`, *optional*):
  103. Used with 'yarn' and 'longrope'. The scaling factor to be applied on the attention
  104. computation. If unspecified, it defaults to value recommended by the implementation, using the
  105. `factor` field to infer the suggested value.
  106. `beta_fast` (`float`, *optional*):
  107. Only used with 'yarn'. Parameter to set the boundary for extrapolation (only) in the linear
  108. ramp function. If unspecified, it defaults to 32.
  109. `beta_slow` (`float`, *optional*):
  110. Only used with 'yarn'. Parameter to set the boundary for interpolation (only) in the linear
  111. ramp function. If unspecified, it defaults to 1.
  112. `short_factor` (`list[float]`, *optional*):
  113. Only used with 'longrope'. The scaling factor to be applied to short contexts (<
  114. `original_max_position_embeddings`). Must be a list of numbers with the same length as the hidden
  115. size divided by the number of attention heads divided by 2
  116. `long_factor` (`list[float]`, *optional*):
  117. Only used with 'longrope'. The scaling factor to be applied to long contexts (<
  118. `original_max_position_embeddings`). Must be a list of numbers with the same length as the hidden
  119. size divided by the number of attention heads divided by 2
  120. `low_freq_factor` (`float`, *optional*):
  121. Only used with 'llama3'. Scaling factor applied to low frequency components of the RoPE
  122. `high_freq_factor` (`float`, *optional*):
  123. Only used with 'llama3'. Scaling factor applied to high frequency components of the RoPE
  124. attention_bias (`bool`, *optional*, defaults to `False`):
  125. Whether to use a bias in the query, key, value and output projection layers during self-attention.
  126. attention_dropout (`float`, *optional*, defaults to 0.0):
  127. The dropout ratio for the attention probabilities.
  128. ```python
  129. >>> from transformers import ApertusModel, ApertusConfig
  130. >>> # Initializing a Apertus-8B style configuration
  131. >>> configuration = ApertusConfig()
  132. >>> # Initializing a model from the Apertus-8B style configuration
  133. >>> model = ApertusModel(configuration)
  134. >>> # Accessing the model configuration
  135. >>> configuration = model.config
  136. ```"""
  137. model_type = "apertus"
  138. base_model_tp_plan = {
  139. "layers.*.self_attn.q_proj": "colwise_rep", # we need to replicate here due to the added norm on q and k
  140. "layers.*.self_attn.k_proj": "colwise_rep", # we need to replicate here due to the added norm on q and k
  141. "layers.*.self_attn.v_proj": "colwise_rep", # we need to replicate here due to the added norm on q and k
  142. "layers.*.self_attn.o_proj": "rowwise_rep", # we need to replicate here due to the added norm on q and k
  143. "layers.*.mlp.up_proj": "colwise",
  144. "layers.*.mlp.down_proj": "rowwise",
  145. "layers.*.mlp.gate_proj": "colwise",
  146. }
  147. def __init__(
  148. self,
  149. vocab_size=131072,
  150. hidden_size=4096,
  151. intermediate_size=14336,
  152. num_hidden_layers=32,
  153. num_attention_heads=32,
  154. num_key_value_heads=None,
  155. hidden_act="xielu",
  156. max_position_embeddings=65536,
  157. initializer_range=0.02,
  158. rms_norm_eps=1e-5,
  159. use_cache=True,
  160. pad_token_id=3,
  161. bos_token_id=1,
  162. eos_token_id=2,
  163. tie_word_embeddings=False,
  164. rope_theta=12000000.0,
  165. rope_scaling={
  166. "rope_type": "llama3",
  167. "factor": 8.0,
  168. "original_max_position_embeddings": 8192,
  169. "low_freq_factor": 1.0,
  170. "high_freq_factor": 4.0,
  171. },
  172. attention_bias=False,
  173. attention_dropout=0.0,
  174. **kwargs,
  175. ):
  176. super().__init__(
  177. vocab_size=vocab_size,
  178. hidden_size=hidden_size,
  179. intermediate_size=intermediate_size,
  180. num_hidden_layers=num_hidden_layers,
  181. num_attention_heads=num_attention_heads,
  182. num_key_value_heads=num_key_value_heads,
  183. hidden_act=hidden_act,
  184. max_position_embeddings=max_position_embeddings,
  185. initializer_range=initializer_range,
  186. rms_norm_eps=rms_norm_eps,
  187. use_cache=use_cache,
  188. pad_token_id=pad_token_id,
  189. bos_token_id=bos_token_id,
  190. eos_token_id=eos_token_id,
  191. tie_word_embeddings=tie_word_embeddings,
  192. rope_theta=rope_theta,
  193. rope_scaling=rope_scaling,
  194. attention_bias=attention_bias,
  195. attention_dropout=attention_dropout,
  196. **kwargs,
  197. )
  198. del self.pretraining_tp
  199. del self.mlp_bias
  200. del self.head_dim
class ApertusMLP(NemotronMLP):
    # Apertus feed-forward block: reuses NemotronMLP and (re)creates the up/down projections without bias.
    def __init__(self, config):
        # NOTE(review): `super().__init__()` is called without `config`, and `self.hidden_size` /
        # `self.intermediate_size` are read below without being assigned in this method. This presumably relies on
        # the modular-file converter inlining the parent initializer (which would set those attributes from `config`).
        # Executed as plain Python, it depends on NemotronMLP.__init__'s signature. Confirm before running this file
        # directly.
        super().__init__()
        # Bias-free projections; these assignments replace any `up_proj` / `down_proj` the parent may have created.
        self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
        self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False)
class ApertusRMSNorm(LlamaRMSNorm):
    # Identical to LlamaRMSNorm. The subclass only provides the Apertus name; this file uses it for the per-head
    # q/k norms in ApertusAttention and for the two pre-norms in ApertusDecoderLayer.
    pass
class ApertusRotaryEmbedding(LlamaRotaryEmbedding):
    # Identical to LlamaRotaryEmbedding; adds no behaviour.
    # NOTE(review): not referenced directly in this file — presumably it is picked up under the Apertus name by the
    # model code inherited from LlamaModel. Confirm against the generated modeling file.
    pass
  210. class ApertusAttention(LlamaAttention):
  211. def __init__(self, config: ApertusConfig, layer_idx: Optional[int] = None):
  212. super().__init__(config, layer_idx)
  213. self.q_norm = ApertusRMSNorm(self.head_dim, config.rms_norm_eps)
  214. self.k_norm = ApertusRMSNorm(self.head_dim, config.rms_norm_eps)
  215. def forward(
  216. self,
  217. hidden_states: torch.Tensor,
  218. position_embeddings: tuple[torch.Tensor, torch.Tensor],
  219. attention_mask: Optional[torch.Tensor],
  220. past_key_values: Optional[Cache] = None,
  221. cache_position: Optional[torch.LongTensor] = None,
  222. **kwargs: Unpack[TransformersKwargs],
  223. ) -> tuple[torch.Tensor, torch.Tensor]:
  224. input_shape = hidden_states.shape[:-1]
  225. hidden_shape = (*input_shape, -1, self.head_dim)
  226. query_states = self.q_proj(hidden_states).view(hidden_shape).transpose(1, 2)
  227. key_states = self.k_proj(hidden_states).view(hidden_shape).transpose(1, 2)
  228. value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2)
  229. query_states = self.q_norm(query_states)
  230. key_states = self.k_norm(key_states)
  231. cos, sin = position_embeddings
  232. query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
  233. if past_key_values is not None:
  234. cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
  235. key_states, value_states = past_key_values.update(key_states, value_states, self.layer_idx, cache_kwargs)
  236. attention_interface: Callable = eager_attention_forward
  237. if self.config._attn_implementation != "eager":
  238. attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]
  239. attn_output, attn_weights = attention_interface(
  240. self,
  241. query_states,
  242. key_states,
  243. value_states,
  244. attention_mask,
  245. dropout=0.0 if not self.training else self.attention_dropout,
  246. scaling=self.scaling,
  247. **kwargs,
  248. )
  249. attn_output = attn_output.reshape(*input_shape, -1).contiguous()
  250. attn_output = self.o_proj(attn_output)
  251. return attn_output, attn_weights
  252. class ApertusDecoderLayer(LlamaDecoderLayer):
  253. def __init__(self, config: ApertusConfig, layer_idx: int):
  254. super().__init__(config, layer_idx)
  255. self.attention_layernorm = ApertusRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
  256. self.feedforward_layernorm = ApertusRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
  257. del self.input_layernorm
  258. del self.post_attention_layernorm
  259. def forward(
  260. self,
  261. hidden_states: torch.Tensor,
  262. attention_mask: Optional[torch.Tensor] = None,
  263. position_ids: Optional[torch.LongTensor] = None,
  264. past_key_values: Optional[Cache] = None,
  265. use_cache: Optional[bool] = False,
  266. cache_position: Optional[torch.LongTensor] = None,
  267. position_embeddings: Optional[tuple[torch.Tensor, torch.Tensor]] = None,
  268. **kwargs: Unpack[TransformersKwargs],
  269. ) -> tuple[torch.Tensor]:
  270. residual = hidden_states
  271. hidden_states = self.attention_layernorm(hidden_states)
  272. hidden_states, _ = self.self_attn(
  273. hidden_states=hidden_states,
  274. attention_mask=attention_mask,
  275. position_ids=position_ids,
  276. past_key_values=past_key_values,
  277. use_cache=use_cache,
  278. cache_position=cache_position,
  279. position_embeddings=position_embeddings,
  280. **kwargs,
  281. )
  282. hidden_states = residual + hidden_states
  283. # Fully Connected
  284. residual = hidden_states
  285. hidden_states = self.feedforward_layernorm(hidden_states)
  286. hidden_states = self.mlp(hidden_states)
  287. hidden_states = residual + hidden_states
  288. return hidden_states
class ApertusPreTrainedModel(LlamaPreTrainedModel):
    # Reuses LlamaPreTrainedModel unchanged; the subclass only provides the Apertus-named base class.
    pass
class ApertusModel(LlamaModel):
    # Reuses LlamaModel unchanged.
    # NOTE(review): presumably the modular-file converter rewrites the inherited Llama* component names to the
    # Apertus* classes defined in this file (e.g. ApertusDecoderLayer). Executed as plain Python, LlamaModel would
    # build its own Llama components — confirm against the generated modeling file.
    pass
class ApertusForCausalLM(LlamaForCausalLM):
    # `forward` is overridden only to supply an Apertus-specific docstring/example. All arguments are passed through
    # unchanged to `LlamaForCausalLM.forward` via `**super_kwargs`.
    def forward(self, **super_kwargs):
        r"""
        labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
            Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
            config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
            (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.

        Example:

        ```python
        >>> from transformers import AutoTokenizer, ApertusForCausalLM

        >>> model = ApertusForCausalLM.from_pretrained("swiss-ai/Apertus-8B")
        >>> tokenizer = AutoTokenizer.from_pretrained("swiss-ai/Apertus-8B")

        >>> prompt = "Hey, are you conscious? Can you talk to me?"
        >>> inputs = tokenizer(prompt, return_tensors="pt")

        >>> # Generate
        >>> generate_ids = model.generate(inputs.input_ids, max_length=30)
        >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
        "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you."
        ```"""
        return super().forward(**super_kwargs)
class ApertusForTokenClassification(LlamaForTokenClassification):
    # Reuses LlamaForTokenClassification unchanged; the subclass only provides the Apertus-named head.
    pass
# Public names exported by this module.
__all__ = [
    "ApertusConfig",
    "ApertusModel",
    "ApertusForCausalLM",
    "ApertusForTokenClassification",
    "ApertusPreTrainedModel",
]