modeling_diffllama.py 35 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767
  1. # 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
  2. # This file was automatically generated from src/transformers/models/diffllama/modular_diffllama.py.
  3. # Do NOT edit this file manually as any edits will be overwritten by the generation of
  4. # the file from the modular. If any change should be done, please apply the change to the
  5. # modular_diffllama.py file directly. One of our CI enforces this.
  6. # 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
  7. # coding=utf-8
  8. # Copyright 2024 weak-kajuma and the HuggingFace Inc. team. All rights reserved.
  9. #
  10. # This code is based on Llama implementations in this library and Microsoft's
  11. # Differential Transformer implementations.
  12. # Licensed under the Apache License, Version 2.0 (the "License");
  13. # you may not use this file except in compliance with the License.
  14. # You may obtain a copy of the License at
  15. #
  16. # http://www.apache.org/licenses/LICENSE-2.0
  17. #
  18. # Unless required by applicable law or agreed to in writing, software
  19. # distributed under the License is distributed on an "AS IS" BASIS,
  20. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  21. # See the License for the specific language governing permissions and
  22. # limitations under the License.
  23. import math
  24. from typing import Optional, Union
  25. import torch
  26. from torch import nn
  27. from ...activations import ACT2FN
  28. from ...cache_utils import Cache, DynamicCache, StaticCache
  29. from ...generation import GenerationMixin
  30. from ...integrations import use_kernel_forward_from_hub
  31. from ...masking_utils import create_causal_mask
  32. from ...modeling_flash_attention_utils import _flash_attention_forward, flash_attn_supports_top_left_mask
  33. from ...modeling_layers import (
  34. GenericForQuestionAnswering,
  35. GenericForSequenceClassification,
  36. GenericForTokenClassification,
  37. GradientCheckpointingLayer,
  38. )
  39. from ...modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
  40. from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update
  41. from ...modeling_utils import PreTrainedModel
  42. from ...processing_utils import Unpack
  43. from ...utils import TransformersKwargs, auto_docstring, can_return_tuple, logging
  44. from ...utils.deprecation import deprecate_kwarg
  45. from ...utils.generic import check_model_inputs
  46. from .configuration_diffllama import DiffLlamaConfig
  47. logger = logging.get_logger(__name__)
  48. class DiffLlamaMLP(nn.Module):
  49. def __init__(self, config):
  50. super().__init__()
  51. self.config = config
  52. self.hidden_size = config.hidden_size
  53. self.intermediate_size = config.intermediate_size
  54. self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
  55. self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
  56. self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False)
  57. self.act_fn = ACT2FN[config.hidden_act]
  58. def forward(self, x):
  59. down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))
  60. return down_proj
  61. def rotate_half(x):
  62. """Rotates half the hidden dims of the input."""
  63. x1 = x[..., : x.shape[-1] // 2]
  64. x2 = x[..., x.shape[-1] // 2 :]
  65. return torch.cat((-x2, x1), dim=-1)
  66. def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
  67. """Applies Rotary Position Embedding to the query and key tensors.
  68. Args:
  69. q (`torch.Tensor`): The query tensor.
  70. k (`torch.Tensor`): The key tensor.
  71. cos (`torch.Tensor`): The cosine part of the rotary embedding.
  72. sin (`torch.Tensor`): The sine part of the rotary embedding.
  73. position_ids (`torch.Tensor`, *optional*):
  74. Deprecated and unused.
  75. unsqueeze_dim (`int`, *optional*, defaults to 1):
  76. The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and
  77. sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note
  78. that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and
  79. k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes
  80. cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have
  81. the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2.
  82. Returns:
  83. `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding.
  84. """
  85. cos = cos.unsqueeze(unsqueeze_dim)
  86. sin = sin.unsqueeze(unsqueeze_dim)
  87. q_embed = (q * cos) + (rotate_half(q) * sin)
  88. k_embed = (k * cos) + (rotate_half(k) * sin)
  89. return q_embed, k_embed
  90. def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
  91. """
  92. This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch,
  93. num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim)
  94. """
  95. batch, num_key_value_heads, slen, head_dim = hidden_states.shape
  96. if n_rep == 1:
  97. return hidden_states
  98. hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim)
  99. return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
  100. def lambda_init_fn(layer_idx):
  101. return 0.8 - 0.6 * math.exp(-0.3 * layer_idx)
  102. class DiffLlamaAttention(nn.Module):
  103. """Multi-headed attention from 'Attention Is All You Need' paper"""
  104. def __init__(self, config: DiffLlamaConfig, layer_idx: Optional[int] = None):
  105. super().__init__()
  106. self.config = config
  107. self.layer_idx = layer_idx
  108. if layer_idx is None:
  109. logger.warning_once(
  110. f"Instantiating {self.__class__.__name__} without passing a `layer_idx` is not recommended and will "
  111. "lead to errors during the forward call if caching is used. Please make sure to provide a `layer_idx` "
  112. "when creating this class."
  113. )
  114. self.attention_dropout = config.attention_dropout
  115. self.hidden_size = config.hidden_size
  116. self.num_heads = config.num_attention_heads
  117. self.head_dim = getattr(config, "head_dim", self.hidden_size // self.num_heads)
  118. self.num_key_value_heads = config.num_key_value_heads
  119. self.num_key_value_groups = self.num_heads // self.num_key_value_heads
  120. # under this are not used
  121. self.max_position_embeddings = config.max_position_embeddings
  122. self.rope_theta = config.rope_theta
  123. self.is_causal = True
  124. self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=config.attention_bias)
  125. self.k_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=config.attention_bias)
  126. self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=config.attention_bias)
  127. self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=config.attention_bias)
  128. self.lambda_init = lambda_init_fn(layer_idx)
  129. self.lambda_q1 = nn.Parameter(torch.normal(0, config.lambda_std_dev, size=(self.head_dim,)))
  130. self.lambda_k1 = nn.Parameter(torch.normal(0, config.lambda_std_dev, size=(self.head_dim,)))
  131. self.lambda_q2 = nn.Parameter(torch.normal(0, config.lambda_std_dev, size=(self.head_dim,)))
  132. self.lambda_k2 = nn.Parameter(torch.normal(0, config.lambda_std_dev, size=(self.head_dim,)))
  133. self.groupnorm = nn.RMSNorm(2 * self.head_dim, eps=config.rms_norm_eps, elementwise_affine=False)
  134. @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58")
  135. def forward(
  136. self,
  137. hidden_states: torch.Tensor,
  138. position_embeddings: tuple[torch.Tensor, torch.Tensor],
  139. attention_mask: Optional[torch.Tensor] = None,
  140. position_ids: Optional[torch.LongTensor] = None,
  141. past_key_values: Optional[Cache] = None,
  142. use_cache: bool = False,
  143. cache_position: Optional[torch.LongTensor] = None,
  144. **kwargs,
  145. ) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[tuple[torch.Tensor]]]:
  146. bsz, target_len, _ = hidden_states.size()
  147. q_len = target_len
  148. query_states = self.q_proj(hidden_states)
  149. key_states = self.k_proj(hidden_states)
  150. value_states = self.v_proj(hidden_states)
  151. query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
  152. key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
  153. value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
  154. cos, sin = position_embeddings
  155. query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
  156. if past_key_values is not None:
  157. # sin and cos are specific to RoPE models; cache_position needed for the static cache
  158. cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
  159. key_states, value_states = past_key_values.update(key_states, value_states, self.layer_idx, cache_kwargs)
  160. key_states = repeat_kv(key_states, self.num_key_value_groups)
  161. value_states = repeat_kv(value_states, self.num_key_value_groups)
  162. value_states = torch.cat(torch.chunk(value_states, 2, dim=1), dim=-1)
  163. value_states = value_states.repeat(1, 2, 1, 1)
  164. attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim)
  165. if attention_mask is not None: # no matter the length, we just slice it
  166. causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]
  167. attn_weights = attn_weights + causal_mask
  168. # upcast attention to fp32
  169. attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
  170. attn_weights = nn.functional.dropout(attn_weights, p=self.attention_dropout, training=self.training)
  171. lambda_1 = torch.exp(torch.sum(self.lambda_q1 * self.lambda_k1, dim=-1, dtype=torch.float32)).to(
  172. query_states.dtype
  173. )
  174. lambda_2 = torch.exp(torch.sum(self.lambda_q2 * self.lambda_k2, dim=-1, dtype=torch.float32)).to(
  175. query_states.dtype
  176. )
  177. lambda_full = lambda_1 - lambda_2 + self.lambda_init
  178. attn_output = torch.matmul(attn_weights, value_states)
  179. attn_output1, attn_output2 = torch.chunk(attn_output, 2, dim=1)
  180. attn_output = attn_output1 - lambda_full * attn_output2
  181. attn_output = (1 - self.lambda_init) * self.groupnorm(attn_output)
  182. attn_output = attn_output.transpose(1, 2).contiguous()
  183. attn_output = attn_output.reshape(bsz, q_len, -1)
  184. attn_output = self.o_proj(attn_output)
  185. return attn_output, attn_weights
  186. class DiffLlamaFlashAttention2(DiffLlamaAttention):
  187. """
  188. DiffLlama flash attention module. This module inherits from `DiffLlamaAttention` as the weights of the module stays
  189. untouched. The only required change would be on the forward pass where it needs to correctly call the public API of
  190. flash attention and deal with padding tokens in case the input contains any of them.
  191. """
  192. def __init__(self, *args, **kwargs):
  193. super().__init__(*args, **kwargs)
  194. # TODO: Should be removed once Flash Attention for RoCm is bumped to 2.1.
  195. # flash_attn<2.1 generates top-left aligned causal mask, while what is needed here is bottom-right alignment, that was made default for flash_attn>=2.1. This attribute is used to handle this difference. Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0.
  196. # Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left).
  197. self._flash_attn_uses_top_left_mask = flash_attn_supports_top_left_mask()
  198. @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58")
  199. def forward(
  200. self,
  201. hidden_states: torch.Tensor,
  202. position_embeddings: tuple[torch.Tensor, torch.Tensor],
  203. attention_mask: Optional[torch.LongTensor] = None,
  204. position_ids: Optional[torch.LongTensor] = None,
  205. past_key_values: Optional[Cache] = None,
  206. use_cache: bool = False,
  207. cache_position: Optional[torch.LongTensor] = None,
  208. ) -> tuple[torch.Tensor, None]:
  209. if isinstance(past_key_values, StaticCache):
  210. raise ValueError(
  211. "`static` cache implementation is not compatible with `attn_implementation==flash_attention_2` "
  212. "make sure to use `sdpa` in the mean time, and open an issue at https://github.com/huggingface/transformers"
  213. )
  214. bsz, q_len, _ = hidden_states.size()
  215. query_states = self.q_proj(hidden_states)
  216. key_states = self.k_proj(hidden_states)
  217. value_states = self.v_proj(hidden_states)
  218. # Flash attention requires the input to have the shape
  219. # batch_size x seq_length x head_dim x hidden_dim
  220. # therefore we just need to keep the original shape
  221. query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
  222. key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
  223. value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
  224. if position_embeddings is None:
  225. logger.warning_once(
  226. "The attention layers in this model are transitioning from computing the RoPE embeddings internally "
  227. "through `position_ids` (2D tensor with the indexes of the tokens), to using externally computed "
  228. "`position_embeddings` (Tuple of tensors, containing cos and sin). In v4.46 `position_ids` will be "
  229. "removed and `position_embeddings` will be mandatory."
  230. )
  231. cos, sin = self.rotary_emb(value_states, position_ids)
  232. else:
  233. cos, sin = position_embeddings
  234. query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
  235. if past_key_values is not None:
  236. # sin and cos are specific to RoPE models; cache_position needed for the static cache
  237. cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
  238. key_states, value_states = past_key_values.update(key_states, value_states, self.layer_idx, cache_kwargs)
  239. # TODO: These transpose are quite inefficient but Flash Attention requires the layout [batch_size, sequence_length, num_heads, head_dim]. We would need to refactor the KV cache
  240. # to be able to avoid many of these transpose/reshape/view.
  241. query_states = query_states.transpose(1, 2)
  242. key_states = key_states.transpose(1, 2)
  243. value_states = value_states.transpose(1, 2)
  244. dropout_rate = self.attention_dropout if self.training else 0.0
  245. # In PEFT, usually we cast the layer norms in float32 for training stability reasons
  246. # therefore the input hidden states gets silently casted in float32. Hence, we need
  247. # cast them back in the correct dtype just to be sure everything works as expected.
  248. # This might slowdown training & inference so it is recommended to not cast the LayerNorms
  249. # in fp32. (DiffLlamaRMSNorm handles it correctly)
  250. input_dtype = query_states.dtype
  251. device_type = query_states.device.type if query_states.device.type != "mps" else "cpu"
  252. if input_dtype == torch.float32:
  253. if torch.is_autocast_enabled():
  254. target_dtype = (
  255. torch.get_autocast_dtype(device_type)
  256. if hasattr(torch, "get_autocast_dtype")
  257. else torch.get_autocast_gpu_dtype()
  258. )
  259. # Handle the case where the model is quantized
  260. elif hasattr(self.config, "_pre_quantization_dtype"):
  261. target_dtype = self.config._pre_quantization_dtype
  262. else:
  263. target_dtype = self.q_proj.weight.dtype
  264. logger.warning_once(
  265. f"The input hidden states seems to be silently casted in float32, this might be related to"
  266. f" the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in"
  267. f" {target_dtype}."
  268. )
  269. query_states = query_states.to(target_dtype)
  270. key_states = key_states.to(target_dtype)
  271. value_states = value_states.to(target_dtype)
  272. value_states1, value_states2 = torch.chunk(value_states, 2, dim=2)
  273. value_states1 = value_states1.repeat(1, 1, 2, 1)
  274. value_states2 = value_states2.repeat(1, 1, 2, 1)
  275. attn_output1 = _flash_attention_forward(
  276. query_states,
  277. key_states,
  278. value_states1,
  279. attention_mask,
  280. q_len,
  281. position_ids=position_ids,
  282. dropout=dropout_rate,
  283. sliding_window=getattr(self, "sliding_window", None),
  284. use_top_left_mask=self._flash_attn_uses_top_left_mask,
  285. is_causal=self.is_causal,
  286. )
  287. attn_output2 = _flash_attention_forward(
  288. query_states,
  289. key_states,
  290. value_states2,
  291. attention_mask,
  292. q_len,
  293. position_ids=position_ids,
  294. dropout=dropout_rate,
  295. sliding_window=getattr(self, "sliding_window", None),
  296. use_top_left_mask=self._flash_attn_uses_top_left_mask,
  297. is_causal=self.is_causal,
  298. )
  299. attn_output = torch.cat([attn_output1, attn_output2], dim=-1)
  300. attn_output1, attn_output2 = torch.chunk(attn_output, 2, dim=2)
  301. lambda_1 = torch.exp(torch.sum(self.lambda_q1 * self.lambda_k1, dim=-1, dtype=torch.float32)).to(
  302. query_states.dtype
  303. )
  304. lambda_2 = torch.exp(torch.sum(self.lambda_q2 * self.lambda_k2, dim=-1, dtype=torch.float32)).to(
  305. query_states.dtype
  306. )
  307. lambda_full = lambda_1 - lambda_2 + self.lambda_init
  308. attn_output = attn_output1 - lambda_full * attn_output2
  309. attn_output = (1 - self.lambda_init) * self.groupnorm(attn_output)
  310. attn_output = attn_output.reshape(bsz, q_len, -1).contiguous()
  311. attn_output = self.o_proj(attn_output)
  312. return attn_output, None
  313. class DiffLlamaSdpaAttention(DiffLlamaAttention):
  314. """
  315. DiffLlama attention module using torch.nn.functional.scaled_dot_product_attention. This module inherits from
  316. `DiffLlamaAttention` as the weights of the module stays untouched. The only changes are on the forward pass to adapt to
  317. SDPA API.
  318. """
  319. # Adapted from DiffLlamaAttention.forward
  320. @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58")
  321. def forward(
  322. self,
  323. hidden_states: torch.Tensor,
  324. position_embeddings: tuple[torch.Tensor, torch.Tensor],
  325. attention_mask: Optional[torch.Tensor] = None,
  326. position_ids: Optional[torch.LongTensor] = None,
  327. past_key_values: Optional[Cache] = None,
  328. use_cache: bool = False,
  329. cache_position: Optional[torch.LongTensor] = None,
  330. **kwargs,
  331. ) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[tuple[torch.Tensor]]]:
  332. bsz, q_len, _ = hidden_states.size()
  333. query_states = self.q_proj(hidden_states)
  334. key_states = self.k_proj(hidden_states)
  335. value_states = self.v_proj(hidden_states)
  336. query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
  337. key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
  338. value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
  339. cos, sin = position_embeddings
  340. query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
  341. if past_key_values is not None:
  342. # sin and cos are specific to RoPE models; cache_position needed for the static cache
  343. cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
  344. key_states, value_states = past_key_values.update(key_states, value_states, self.layer_idx, cache_kwargs)
  345. key_states = repeat_kv(key_states, self.num_key_value_groups)
  346. value_states = repeat_kv(value_states, self.num_key_value_groups)
  347. value_states = torch.cat(torch.chunk(value_states, 2, dim=1), dim=-1)
  348. value_states = value_states.repeat(1, 2, 1, 1)
  349. causal_mask = attention_mask
  350. if attention_mask is not None:
  351. causal_mask = causal_mask[:, :, :, : key_states.shape[-2]]
  352. # SDPA with memory-efficient backend is currently (torch==2.1.2) bugged with non-contiguous inputs with custom attn_mask,
  353. # Reference: https://github.com/pytorch/pytorch/issues/112577.
  354. if query_states.device.type == "cuda" and causal_mask is not None:
  355. query_states = query_states.contiguous()
  356. key_states = key_states.contiguous()
  357. value_states = value_states.contiguous()
  358. # We dispatch to SDPA's Flash Attention or Efficient kernels via this `is_causal` if statement instead of an inline conditional assignment
  359. # in SDPA to support both torch.compile's dynamic shapes and full graph options. An inline conditional prevents dynamic shapes from compiling.
  360. is_causal = causal_mask is None and q_len > 1
  361. attn_output = torch.nn.functional.scaled_dot_product_attention(
  362. query_states,
  363. key_states,
  364. value_states,
  365. attn_mask=causal_mask,
  366. dropout_p=self.attention_dropout if self.training else 0.0,
  367. is_causal=is_causal,
  368. )
  369. attn_output1, attn_output2 = torch.chunk(attn_output, 2, dim=1)
  370. lambda_1 = torch.exp(torch.sum(self.lambda_q1 * self.lambda_k1, dim=-1, dtype=torch.float32)).to(
  371. query_states.dtype
  372. )
  373. lambda_2 = torch.exp(torch.sum(self.lambda_q2 * self.lambda_k2, dim=-1, dtype=torch.float32)).to(
  374. query_states.dtype
  375. )
  376. lambda_full = lambda_1 - lambda_2 + self.lambda_init
  377. attn_output = attn_output1 - lambda_full * attn_output2
  378. attn_output = (1 - self.lambda_init) * self.groupnorm(attn_output)
  379. attn_output = attn_output.transpose(1, 2).contiguous()
  380. attn_output = attn_output.view(bsz, q_len, -1)
  381. attn_output = self.o_proj(attn_output)
  382. return attn_output, None
  383. @use_kernel_forward_from_hub("RMSNorm")
  384. class DiffLlamaRMSNorm(nn.Module):
  385. def __init__(self, hidden_size, eps=1e-6):
  386. """
  387. DiffLlamaRMSNorm is equivalent to T5LayerNorm
  388. """
  389. super().__init__()
  390. self.weight = nn.Parameter(torch.ones(hidden_size))
  391. self.variance_epsilon = eps
  392. def forward(self, hidden_states):
  393. input_dtype = hidden_states.dtype
  394. hidden_states = hidden_states.to(torch.float32)
  395. variance = hidden_states.pow(2).mean(-1, keepdim=True)
  396. hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
  397. return self.weight * hidden_states.to(input_dtype)
  398. def extra_repr(self):
  399. return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}"
  400. DIFFLLAMA_ATTENTION_CLASSES = {
  401. "eager": DiffLlamaAttention,
  402. "flash_attention_2": DiffLlamaFlashAttention2,
  403. "sdpa": DiffLlamaSdpaAttention,
  404. }
  405. class DiffLlamaDecoderLayer(GradientCheckpointingLayer):
  406. def __init__(self, config: DiffLlamaConfig, layer_idx: int):
  407. super().__init__()
  408. self.hidden_size = config.hidden_size
  409. self.self_attn = DIFFLLAMA_ATTENTION_CLASSES[config._attn_implementation](config=config, layer_idx=layer_idx)
  410. self.mlp = DiffLlamaMLP(config)
  411. self.input_layernorm = DiffLlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
  412. self.post_attention_layernorm = DiffLlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
  413. @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58")
  414. def forward(
  415. self,
  416. hidden_states: torch.Tensor,
  417. attention_mask: Optional[torch.Tensor] = None,
  418. position_ids: Optional[torch.LongTensor] = None,
  419. past_key_values: Optional[Cache] = None,
  420. use_cache: Optional[bool] = False,
  421. cache_position: Optional[torch.LongTensor] = None,
  422. position_embeddings: Optional[tuple[torch.Tensor, torch.Tensor]] = None, # necessary, but kept here for BC
  423. **kwargs: Unpack[TransformersKwargs],
  424. ) -> torch.Tensor:
  425. residual = hidden_states
  426. hidden_states = self.input_layernorm(hidden_states)
  427. # Self Attention
  428. hidden_states, _ = self.self_attn(
  429. hidden_states=hidden_states,
  430. attention_mask=attention_mask,
  431. position_ids=position_ids,
  432. past_key_values=past_key_values,
  433. use_cache=use_cache,
  434. cache_position=cache_position,
  435. position_embeddings=position_embeddings,
  436. **kwargs,
  437. )
  438. hidden_states = residual + hidden_states
  439. # Fully Connected
  440. residual = hidden_states
  441. hidden_states = self.post_attention_layernorm(hidden_states)
  442. hidden_states = self.mlp(hidden_states)
  443. hidden_states = residual + hidden_states
  444. return hidden_states
  445. @auto_docstring
  446. class DiffLlamaPreTrainedModel(PreTrainedModel):
  447. config: DiffLlamaConfig
  448. base_model_prefix = "model"
  449. supports_gradient_checkpointing = True
  450. _no_split_modules = ["DiffLlamaDecoderLayer"]
  451. _skip_keys_device_placement = ["past_key_values"]
  452. _supports_flash_attn = True
  453. _supports_sdpa = True
  454. _supports_flex_attn = False
  455. _can_compile_fullgraph = True
  456. _supports_attention_backend = False
  457. _can_record_outputs = {
  458. "hidden_states": DiffLlamaDecoderLayer,
  459. "attentions": DiffLlamaAttention,
  460. }
  461. def _init_weights(self, module):
  462. super()._init_weights(module)
  463. if isinstance(module, DiffLlamaAttention):
  464. module.lambda_q1.data.normal_(0, self.config.lambda_std_dev)
  465. module.lambda_k1.data.normal_(0, self.config.lambda_std_dev)
  466. module.lambda_q2.data.normal_(0, self.config.lambda_std_dev)
  467. module.lambda_k2.data.normal_(0, self.config.lambda_std_dev)
  468. class DiffLlamaRotaryEmbedding(nn.Module):
  469. inv_freq: torch.Tensor # fix linting for `register_buffer`
  470. def __init__(self, config: DiffLlamaConfig, device=None):
  471. super().__init__()
  472. # BC: "rope_type" was originally "type"
  473. if hasattr(config, "rope_scaling") and isinstance(config.rope_scaling, dict):
  474. self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type"))
  475. else:
  476. self.rope_type = "default"
  477. self.max_seq_len_cached = config.max_position_embeddings
  478. self.original_max_seq_len = config.max_position_embeddings
  479. self.config = config
  480. self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]
  481. inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device)
  482. self.register_buffer("inv_freq", inv_freq, persistent=False)
  483. self.original_inv_freq = self.inv_freq
  484. @torch.no_grad()
  485. @dynamic_rope_update # power user: used with advanced RoPE types (e.g. dynamic rope)
  486. def forward(self, x, position_ids):
  487. inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1).to(x.device)
  488. position_ids_expanded = position_ids[:, None, :].float()
  489. device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu"
  490. with torch.autocast(device_type=device_type, enabled=False): # Force float32
  491. freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
  492. emb = torch.cat((freqs, freqs), dim=-1)
  493. cos = emb.cos() * self.attention_scaling
  494. sin = emb.sin() * self.attention_scaling
  495. return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
  @auto_docstring
  class DiffLlamaModel(DiffLlamaPreTrainedModel):
      # Backbone without an output head: embeds tokens, runs them through the
      # stack of decoder layers, applies a final RMSNorm and returns the
      # resulting hidden states together with the key/value cache.

      def __init__(self, config: DiffLlamaConfig):
          super().__init__(config)
          self.padding_idx = config.pad_token_id
          self.vocab_size = config.vocab_size

          # Sub-modules are created in a fixed order: embedding, layers, norm, rotary.
          self.embed_tokens = nn.Embedding(
              config.vocab_size,
              config.hidden_size,
              padding_idx=self.padding_idx,
          )
          decoder_stack = [DiffLlamaDecoderLayer(config, idx) for idx in range(config.num_hidden_layers)]
          self.layers = nn.ModuleList(decoder_stack)
          self.norm = DiffLlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
          self.rotary_emb = DiffLlamaRotaryEmbedding(config=config)
          self.gradient_checkpointing = False

          # Initialize weights and apply final processing
          self.post_init()

      @check_model_inputs()
      @auto_docstring
      def forward(
          self,
          input_ids: Optional[torch.LongTensor] = None,
          attention_mask: Optional[torch.Tensor] = None,
          position_ids: Optional[torch.LongTensor] = None,
          past_key_values: Optional[Cache] = None,
          inputs_embeds: Optional[torch.FloatTensor] = None,
          cache_position: Optional[torch.LongTensor] = None,
          use_cache: Optional[bool] = None,
          **kwargs: Unpack[TransformersKwargs],
      ) -> BaseModelOutputWithPast:
          # Reject both "neither given" and "both given": the two presence flags
          # must differ for the call to be valid.
          ids_given = input_ids is not None
          embeds_given = inputs_embeds is not None
          if ids_given == embeds_given:
              raise ValueError("You must specify exactly one of input_ids or inputs_embeds")

          if inputs_embeds is None:
              inputs_embeds = self.embed_tokens(input_ids)

          # Create a fresh cache only when caching was requested and none was supplied.
          if past_key_values is None and use_cache:
              past_key_values = DynamicCache(config=self.config)

          if cache_position is None:
              # Start at the cache's reported sequence length (0 without a cache)
              # and cover one index per position of the current input.
              first_pos = 0 if past_key_values is None else past_key_values.get_seq_length()
              last_pos = first_pos + inputs_embeds.shape[1]
              cache_position = torch.arange(first_pos, last_pos, device=inputs_embeds.device)

          if position_ids is None:
              # Default positions are the cache positions with a leading dimension added.
              position_ids = cache_position.unsqueeze(0)

          causal_mask = create_causal_mask(
              config=self.config,
              input_embeds=inputs_embeds,
              attention_mask=attention_mask,
              cache_position=cache_position,
              past_key_values=past_key_values,
              position_ids=position_ids,
          )

          hidden_states = inputs_embeds
          # Computed once here and handed unchanged to every decoder layer.
          position_embeddings = self.rotary_emb(hidden_states, position_ids)

          active_layers = self.layers[: self.config.num_hidden_layers]
          for layer in active_layers:
              hidden_states = layer(
                  hidden_states,
                  attention_mask=causal_mask,
                  position_ids=position_ids,
                  past_key_values=past_key_values,
                  cache_position=cache_position,
                  position_embeddings=position_embeddings,
                  **kwargs,
              )

          return BaseModelOutputWithPast(
              last_hidden_state=self.norm(hidden_states),
              past_key_values=past_key_values,
          )
  @auto_docstring
  class DiffLlamaForCausalLM(DiffLlamaPreTrainedModel, GenerationMixin):
      # Causal language model: a DiffLlamaModel backbone followed by a bias-free
      # linear projection from hidden size to vocabulary size.

      _tied_weights_keys = ["lm_head.weight"]
      _tp_plan = {"lm_head": "colwise_rep"}
      _pp_plan = {"lm_head": (["hidden_states"], ["logits"])}

      def __init__(self, config):
          super().__init__(config)
          self.model = DiffLlamaModel(config)
          self.vocab_size = config.vocab_size
          self.lm_head = nn.Linear(
              config.hidden_size,
              config.vocab_size,
              bias=False,
          )

          # Initialize weights and apply final processing
          self.post_init()

      @can_return_tuple
      @auto_docstring
      def forward(
          self,
          input_ids: Optional[torch.LongTensor] = None,
          attention_mask: Optional[torch.Tensor] = None,
          position_ids: Optional[torch.LongTensor] = None,
          past_key_values: Optional[Cache] = None,
          inputs_embeds: Optional[torch.FloatTensor] = None,
          labels: Optional[torch.LongTensor] = None,
          use_cache: Optional[bool] = None,
          cache_position: Optional[torch.LongTensor] = None,
          logits_to_keep: Union[int, torch.Tensor] = 0,
          **kwargs: Unpack[TransformersKwargs],
      ) -> CausalLMOutputWithPast:
          r"""
          Example:

          ```python
          >>> from transformers import AutoTokenizer, DiffLlamaForCausalLM

          >>> model = DiffLlamaForCausalLM.from_pretrained("google/diffllama-7b")
          >>> tokenizer = AutoTokenizer.from_pretrained("google/diffllama-7b")

          >>> prompt = "What is your favorite condiment?"
          >>> inputs = tokenizer(prompt, return_tensors="pt")

          >>> # Generate
          >>> generate_ids = model.generate(inputs.input_ids, max_length=30)
          >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
          "What is your favorite condiment?"
          ```"""
          outputs: BaseModelOutputWithPast = self.model(
              input_ids=input_ids,
              attention_mask=attention_mask,
              position_ids=position_ids,
              past_key_values=past_key_values,
              inputs_embeds=inputs_embeds,
              use_cache=use_cache,
              cache_position=cache_position,
              **kwargs,
          )

          # Project only the requested sequence positions. An int keeps the last
          # `logits_to_keep` positions; 0 yields slice(0, None), i.e. all of them.
          # Anything else is used directly as the index along the sequence axis.
          if isinstance(logits_to_keep, int):
              kept_positions = slice(-logits_to_keep, None)
          else:
              kept_positions = logits_to_keep
          logits = self.lm_head(outputs.last_hidden_state[:, kept_positions, :])

          # The loss is computed only when labels are supplied.
          loss = (
              self.loss_function(logits=logits, labels=labels, vocab_size=self.config.vocab_size, **kwargs)
              if labels is not None
              else None
          )

          return CausalLMOutputWithPast(
              loss=loss,
              logits=logits,
              past_key_values=outputs.past_key_values,
              hidden_states=outputs.hidden_states,
              attentions=outputs.attentions,
          )
  # Sequence-classification variant. The class body is empty, so everything it
  # does is inherited from GenericForSequenceClassification and
  # DiffLlamaPreTrainedModel.
  class DiffLlamaForSequenceClassification(GenericForSequenceClassification, DiffLlamaPreTrainedModel):
      pass
  # Question-answering variant. Apart from the prefix override below, everything
  # is inherited from GenericForQuestionAnswering and DiffLlamaPreTrainedModel.
  class DiffLlamaForQuestionAnswering(GenericForQuestionAnswering, DiffLlamaPreTrainedModel):
      # For BC, where `transformer` was used instead of `model` as the name of
      # the base model prefix.
      base_model_prefix = "transformer"
  # Token-classification variant. The class body is empty, so everything it
  # does is inherited from GenericForTokenClassification and
  # DiffLlamaPreTrainedModel.
  class DiffLlamaForTokenClassification(GenericForTokenClassification, DiffLlamaPreTrainedModel):
      pass
  # Names exported by `from <this module> import *`; this is the module's
  # declared public API.
  __all__ = [
      "DiffLlamaPreTrainedModel",
      "DiffLlamaModel",
      "DiffLlamaForCausalLM",
      "DiffLlamaForSequenceClassification",
      "DiffLlamaForQuestionAnswering",
      "DiffLlamaForTokenClassification",
  ]