modeling_arcee.py 21 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506
  1. # 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
  2. # This file was automatically generated from src/transformers/models/arcee/modular_arcee.py.
  3. # Do NOT edit this file manually as any edits will be overwritten by the generation of
  4. # the file from the modular. If any change should be done, please apply the change to the
  5. # modular_arcee.py file directly. One of our CI enforces this.
  6. # 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
  7. # coding=utf-8
  8. # Copyright 2025 Arcee AI and the HuggingFace Inc. team. All rights reserved.
  9. #
  10. # Licensed under the Apache License, Version 2.0 (the "License");
  11. # you may not use this file except in compliance with the License.
  12. # You may obtain a copy of the License at
  13. #
  14. # http://www.apache.org/licenses/LICENSE-2.0
  15. #
  16. # Unless required by applicable law or agreed to in writing, software
  17. # distributed under the License is distributed on an "AS IS" BASIS,
  18. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  19. # See the License for the specific language governing permissions and
  20. # limitations under the License.
  21. from typing import Callable, Optional, Union
  22. import torch
  23. from torch import nn
  24. from transformers.utils import auto_docstring
  25. from ...activations import ACT2FN
  26. from ...cache_utils import Cache, DynamicCache
  27. from ...generation import GenerationMixin
  28. from ...integrations import use_kernel_forward_from_hub
  29. from ...masking_utils import create_causal_mask
  30. from ...modeling_layers import (
  31. GenericForQuestionAnswering,
  32. GenericForSequenceClassification,
  33. GenericForTokenClassification,
  34. GradientCheckpointingLayer,
  35. )
  36. from ...modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
  37. from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update
  38. from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
  39. from ...processing_utils import Unpack
  40. from ...utils import TransformersKwargs, can_return_tuple
  41. from ...utils.deprecation import deprecate_kwarg
  42. from ...utils.generic import check_model_inputs
  43. from .configuration_arcee import ArceeConfig
  44. class ArceeMLP(nn.Module):
  45. def __init__(self, config):
  46. super().__init__()
  47. self.config = config
  48. self.hidden_size = config.hidden_size
  49. self.intermediate_size = config.intermediate_size
  50. self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=config.mlp_bias)
  51. self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=config.mlp_bias)
  52. self.act_fn = ACT2FN[config.hidden_act]
  53. def forward(self, x):
  54. return self.down_proj(self.act_fn(self.up_proj(x)))
  55. @use_kernel_forward_from_hub("RMSNorm")
  56. class ArceeRMSNorm(nn.Module):
  57. def __init__(self, hidden_size, eps=1e-6):
  58. """
  59. ArceeRMSNorm is equivalent to T5LayerNorm
  60. """
  61. super().__init__()
  62. self.weight = nn.Parameter(torch.ones(hidden_size))
  63. self.variance_epsilon = eps
  64. def forward(self, hidden_states):
  65. input_dtype = hidden_states.dtype
  66. hidden_states = hidden_states.to(torch.float32)
  67. variance = hidden_states.pow(2).mean(-1, keepdim=True)
  68. hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
  69. return self.weight * hidden_states.to(input_dtype)
  70. def extra_repr(self):
  71. return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}"
  72. class ArceeRotaryEmbedding(nn.Module):
  73. inv_freq: torch.Tensor # fix linting for `register_buffer`
  74. def __init__(self, config: ArceeConfig, device=None):
  75. super().__init__()
  76. # BC: "rope_type" was originally "type"
  77. if hasattr(config, "rope_scaling") and isinstance(config.rope_scaling, dict):
  78. self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type"))
  79. else:
  80. self.rope_type = "default"
  81. self.max_seq_len_cached = config.max_position_embeddings
  82. self.original_max_seq_len = config.max_position_embeddings
  83. self.config = config
  84. self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]
  85. inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device)
  86. self.register_buffer("inv_freq", inv_freq, persistent=False)
  87. self.original_inv_freq = self.inv_freq
  88. @torch.no_grad()
  89. @dynamic_rope_update # power user: used with advanced RoPE types (e.g. dynamic rope)
  90. def forward(self, x, position_ids):
  91. inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1).to(x.device)
  92. position_ids_expanded = position_ids[:, None, :].float()
  93. device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu"
  94. with torch.autocast(device_type=device_type, enabled=False): # Force float32
  95. freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
  96. emb = torch.cat((freqs, freqs), dim=-1)
  97. cos = emb.cos() * self.attention_scaling
  98. sin = emb.sin() * self.attention_scaling
  99. return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
  100. def rotate_half(x):
  101. """Rotates half the hidden dims of the input."""
  102. x1 = x[..., : x.shape[-1] // 2]
  103. x2 = x[..., x.shape[-1] // 2 :]
  104. return torch.cat((-x2, x1), dim=-1)
  105. def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
  106. """Applies Rotary Position Embedding to the query and key tensors.
  107. Args:
  108. q (`torch.Tensor`): The query tensor.
  109. k (`torch.Tensor`): The key tensor.
  110. cos (`torch.Tensor`): The cosine part of the rotary embedding.
  111. sin (`torch.Tensor`): The sine part of the rotary embedding.
  112. position_ids (`torch.Tensor`, *optional*):
  113. Deprecated and unused.
  114. unsqueeze_dim (`int`, *optional*, defaults to 1):
  115. The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and
  116. sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note
  117. that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and
  118. k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes
  119. cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have
  120. the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2.
  121. Returns:
  122. `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding.
  123. """
  124. cos = cos.unsqueeze(unsqueeze_dim)
  125. sin = sin.unsqueeze(unsqueeze_dim)
  126. q_embed = (q * cos) + (rotate_half(q) * sin)
  127. k_embed = (k * cos) + (rotate_half(k) * sin)
  128. return q_embed, k_embed
  129. def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
  130. """
  131. This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch,
  132. num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim)
  133. """
  134. batch, num_key_value_heads, slen, head_dim = hidden_states.shape
  135. if n_rep == 1:
  136. return hidden_states
  137. hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim)
  138. return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
  139. def eager_attention_forward(
  140. module: nn.Module,
  141. query: torch.Tensor,
  142. key: torch.Tensor,
  143. value: torch.Tensor,
  144. attention_mask: Optional[torch.Tensor],
  145. scaling: float,
  146. dropout: float = 0.0,
  147. **kwargs: Unpack[TransformersKwargs],
  148. ):
  149. key_states = repeat_kv(key, module.num_key_value_groups)
  150. value_states = repeat_kv(value, module.num_key_value_groups)
  151. attn_weights = torch.matmul(query, key_states.transpose(2, 3)) * scaling
  152. if attention_mask is not None:
  153. causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]
  154. attn_weights = attn_weights + causal_mask
  155. attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype)
  156. attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training)
  157. attn_output = torch.matmul(attn_weights, value_states)
  158. attn_output = attn_output.transpose(1, 2).contiguous()
  159. return attn_output, attn_weights
  160. class ArceeAttention(nn.Module):
  161. """Multi-headed attention from 'Attention Is All You Need' paper"""
  162. def __init__(self, config: ArceeConfig, layer_idx: int):
  163. super().__init__()
  164. self.config = config
  165. self.layer_idx = layer_idx
  166. self.head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads)
  167. self.num_key_value_groups = config.num_attention_heads // config.num_key_value_heads
  168. self.scaling = self.head_dim**-0.5
  169. self.attention_dropout = config.attention_dropout
  170. self.is_causal = True
  171. self.q_proj = nn.Linear(
  172. config.hidden_size, config.num_attention_heads * self.head_dim, bias=config.attention_bias
  173. )
  174. self.k_proj = nn.Linear(
  175. config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.attention_bias
  176. )
  177. self.v_proj = nn.Linear(
  178. config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.attention_bias
  179. )
  180. self.o_proj = nn.Linear(
  181. config.num_attention_heads * self.head_dim, config.hidden_size, bias=config.attention_bias
  182. )
  183. @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58")
  184. def forward(
  185. self,
  186. hidden_states: torch.Tensor,
  187. position_embeddings: tuple[torch.Tensor, torch.Tensor],
  188. attention_mask: Optional[torch.Tensor],
  189. past_key_values: Optional[Cache] = None,
  190. cache_position: Optional[torch.LongTensor] = None,
  191. **kwargs: Unpack[TransformersKwargs],
  192. ) -> tuple[torch.Tensor, torch.Tensor]:
  193. input_shape = hidden_states.shape[:-1]
  194. hidden_shape = (*input_shape, -1, self.head_dim)
  195. query_states = self.q_proj(hidden_states).view(hidden_shape).transpose(1, 2)
  196. key_states = self.k_proj(hidden_states).view(hidden_shape).transpose(1, 2)
  197. value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2)
  198. cos, sin = position_embeddings
  199. query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
  200. if past_key_values is not None:
  201. # sin and cos are specific to RoPE models; cache_position needed for the static cache
  202. cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
  203. key_states, value_states = past_key_values.update(key_states, value_states, self.layer_idx, cache_kwargs)
  204. attention_interface: Callable = eager_attention_forward
  205. if self.config._attn_implementation != "eager":
  206. attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]
  207. attn_output, attn_weights = attention_interface(
  208. self,
  209. query_states,
  210. key_states,
  211. value_states,
  212. attention_mask,
  213. dropout=0.0 if not self.training else self.attention_dropout,
  214. scaling=self.scaling,
  215. **kwargs,
  216. )
  217. attn_output = attn_output.reshape(*input_shape, -1).contiguous()
  218. attn_output = self.o_proj(attn_output)
  219. return attn_output, attn_weights
  220. class ArceeDecoderLayer(GradientCheckpointingLayer):
  221. def __init__(self, config: ArceeConfig, layer_idx: int):
  222. super().__init__()
  223. self.hidden_size = config.hidden_size
  224. self.self_attn = ArceeAttention(config=config, layer_idx=layer_idx)
  225. self.mlp = ArceeMLP(config)
  226. self.input_layernorm = ArceeRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
  227. self.post_attention_layernorm = ArceeRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
  228. @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58")
  229. def forward(
  230. self,
  231. hidden_states: torch.Tensor,
  232. attention_mask: Optional[torch.Tensor] = None,
  233. position_ids: Optional[torch.LongTensor] = None,
  234. past_key_values: Optional[Cache] = None,
  235. use_cache: Optional[bool] = False,
  236. cache_position: Optional[torch.LongTensor] = None,
  237. position_embeddings: Optional[tuple[torch.Tensor, torch.Tensor]] = None, # necessary, but kept here for BC
  238. **kwargs: Unpack[TransformersKwargs],
  239. ) -> torch.Tensor:
  240. residual = hidden_states
  241. hidden_states = self.input_layernorm(hidden_states)
  242. # Self Attention
  243. hidden_states, _ = self.self_attn(
  244. hidden_states=hidden_states,
  245. attention_mask=attention_mask,
  246. position_ids=position_ids,
  247. past_key_values=past_key_values,
  248. use_cache=use_cache,
  249. cache_position=cache_position,
  250. position_embeddings=position_embeddings,
  251. **kwargs,
  252. )
  253. hidden_states = residual + hidden_states
  254. # Fully Connected
  255. residual = hidden_states
  256. hidden_states = self.post_attention_layernorm(hidden_states)
  257. hidden_states = self.mlp(hidden_states)
  258. hidden_states = residual + hidden_states
  259. return hidden_states
  260. @auto_docstring
  261. class ArceePreTrainedModel(PreTrainedModel):
  262. config: ArceeConfig
  263. base_model_prefix = "model"
  264. supports_gradient_checkpointing = True
  265. _no_split_modules = ["ArceeDecoderLayer"]
  266. _skip_keys_device_placement = ["past_key_values"]
  267. _supports_flash_attn = True
  268. _supports_sdpa = True
  269. _supports_flex_attn = True
  270. _can_compile_fullgraph = True
  271. _supports_attention_backend = True
  272. _can_record_outputs = {
  273. "hidden_states": ArceeDecoderLayer,
  274. "attentions": ArceeAttention,
  275. }
  276. @auto_docstring
  277. class ArceeModel(ArceePreTrainedModel):
  278. def __init__(self, config: ArceeConfig):
  279. super().__init__(config)
  280. self.padding_idx = config.pad_token_id
  281. self.vocab_size = config.vocab_size
  282. self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
  283. self.layers = nn.ModuleList(
  284. [ArceeDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
  285. )
  286. self.norm = ArceeRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
  287. self.rotary_emb = ArceeRotaryEmbedding(config=config)
  288. self.gradient_checkpointing = False
  289. # Initialize weights and apply final processing
  290. self.post_init()
  291. @check_model_inputs()
  292. @auto_docstring
  293. def forward(
  294. self,
  295. input_ids: Optional[torch.LongTensor] = None,
  296. attention_mask: Optional[torch.Tensor] = None,
  297. position_ids: Optional[torch.LongTensor] = None,
  298. past_key_values: Optional[Cache] = None,
  299. inputs_embeds: Optional[torch.FloatTensor] = None,
  300. cache_position: Optional[torch.LongTensor] = None,
  301. use_cache: Optional[bool] = None,
  302. **kwargs: Unpack[TransformersKwargs],
  303. ) -> BaseModelOutputWithPast:
  304. if (input_ids is None) ^ (inputs_embeds is not None):
  305. raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
  306. if inputs_embeds is None:
  307. inputs_embeds: torch.Tensor = self.embed_tokens(input_ids)
  308. if use_cache and past_key_values is None:
  309. past_key_values = DynamicCache(config=self.config)
  310. if cache_position is None:
  311. past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
  312. cache_position: torch.Tensor = torch.arange(
  313. past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device
  314. )
  315. if position_ids is None:
  316. position_ids = cache_position.unsqueeze(0)
  317. causal_mask = create_causal_mask(
  318. config=self.config,
  319. input_embeds=inputs_embeds,
  320. attention_mask=attention_mask,
  321. cache_position=cache_position,
  322. past_key_values=past_key_values,
  323. position_ids=position_ids,
  324. )
  325. hidden_states = inputs_embeds
  326. position_embeddings = self.rotary_emb(hidden_states, position_ids)
  327. for decoder_layer in self.layers[: self.config.num_hidden_layers]:
  328. hidden_states = decoder_layer(
  329. hidden_states,
  330. attention_mask=causal_mask,
  331. position_ids=position_ids,
  332. past_key_values=past_key_values,
  333. cache_position=cache_position,
  334. position_embeddings=position_embeddings,
  335. **kwargs,
  336. )
  337. hidden_states = self.norm(hidden_states)
  338. return BaseModelOutputWithPast(
  339. last_hidden_state=hidden_states,
  340. past_key_values=past_key_values,
  341. )
  342. @auto_docstring(checkpoint="arcee-ai/AFM-4.5B")
  343. class ArceeForCausalLM(ArceePreTrainedModel, GenerationMixin):
  344. _tied_weights_keys = ["lm_head.weight"]
  345. _tp_plan = {"lm_head": "colwise_rep"}
  346. _pp_plan = {"lm_head": (["hidden_states"], ["logits"])}
  347. def __init__(self, config):
  348. super().__init__(config)
  349. self.model = ArceeModel(config)
  350. self.vocab_size = config.vocab_size
  351. self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
  352. # Initialize weights and apply final processing
  353. self.post_init()
  354. @can_return_tuple
  355. @auto_docstring
  356. def forward(
  357. self,
  358. input_ids: Optional[torch.LongTensor] = None,
  359. attention_mask: Optional[torch.Tensor] = None,
  360. position_ids: Optional[torch.LongTensor] = None,
  361. past_key_values: Optional[Cache] = None,
  362. inputs_embeds: Optional[torch.FloatTensor] = None,
  363. labels: Optional[torch.LongTensor] = None,
  364. use_cache: Optional[bool] = None,
  365. cache_position: Optional[torch.LongTensor] = None,
  366. logits_to_keep: Union[int, torch.Tensor] = 0,
  367. **kwargs: Unpack[TransformersKwargs],
  368. ) -> CausalLMOutputWithPast:
  369. r"""
  370. Example:
  371. ```python
  372. >>> from transformers import AutoTokenizer, ArceeForCausalLM
  373. >>> model = ArceeForCausalLM.from_pretrained("meta-arcee/Arcee-2-7b-hf")
  374. >>> tokenizer = AutoTokenizer.from_pretrained("meta-arcee/Arcee-2-7b-hf")
  375. >>> prompt = "Hey, are you conscious? Can you talk to me?"
  376. >>> inputs = tokenizer(prompt, return_tensors="pt")
  377. >>> # Generate
  378. >>> generate_ids = model.generate(inputs.input_ids, max_length=30)
  379. >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
  380. "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you."
  381. ```"""
  382. outputs: BaseModelOutputWithPast = self.model(
  383. input_ids=input_ids,
  384. attention_mask=attention_mask,
  385. position_ids=position_ids,
  386. past_key_values=past_key_values,
  387. inputs_embeds=inputs_embeds,
  388. use_cache=use_cache,
  389. cache_position=cache_position,
  390. **kwargs,
  391. )
  392. hidden_states = outputs.last_hidden_state
  393. # Only compute necessary logits, and do not upcast them to float if we are not computing the loss
  394. slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
  395. logits = self.lm_head(hidden_states[:, slice_indices, :])
  396. loss = None
  397. if labels is not None:
  398. loss = self.loss_function(logits=logits, labels=labels, vocab_size=self.config.vocab_size, **kwargs)
  399. return CausalLMOutputWithPast(
  400. loss=loss,
  401. logits=logits,
  402. past_key_values=outputs.past_key_values,
  403. hidden_states=outputs.hidden_states,
  404. attentions=outputs.attentions,
  405. )
  406. @auto_docstring(checkpoint="arcee-ai/AFM-4.5B")
  407. class ArceeForSequenceClassification(GenericForSequenceClassification, ArceePreTrainedModel):
  408. pass
  409. @auto_docstring(checkpoint="arcee-ai/AFM-4.5B")
  410. class ArceeForQuestionAnswering(GenericForQuestionAnswering, ArceePreTrainedModel):
  411. base_model_prefix = "transformer" # For BC, where `transformer` was used instead of `model`
  412. @auto_docstring(checkpoint="arcee-ai/AFM-4.5B")
  413. class ArceeForTokenClassification(GenericForTokenClassification, ArceePreTrainedModel):
  414. pass
  415. __all__ = [
  416. "ArceeForCausalLM",
  417. "ArceeForQuestionAnswering",
  418. "ArceeForSequenceClassification",
  419. "ArceeForTokenClassification",
  420. "ArceeModel",
  421. "ArceePreTrainedModel",
  422. ]