modular_diffllama.py 20 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447
  1. # coding=utf-8
  2. # Copyright 2024 weak-kajuma and the HuggingFace Inc. team. All rights reserved.
  3. #
  4. # This code is based on Llama implementations in this library and Microsoft's
  5. # Differential Transformer implementations.
  6. # Licensed under the Apache License, Version 2.0 (the "License");
  7. # you may not use this file except in compliance with the License.
  8. # You may obtain a copy of the License at
  9. #
  10. # http://www.apache.org/licenses/LICENSE-2.0
  11. #
  12. # Unless required by applicable law or agreed to in writing, software
  13. # distributed under the License is distributed on an "AS IS" BASIS,
  14. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  15. # See the License for the specific language governing permissions and
  16. # limitations under the License.
  17. import math
  18. from typing import Optional
  19. import torch
  20. from torch import nn
  21. from ...cache_utils import Cache, StaticCache
  22. from ...modeling_flash_attention_utils import _flash_attention_forward, flash_attn_supports_top_left_mask
  23. from ...modeling_utils import PreTrainedModel
  24. from ...utils import logging
  25. from ...utils.deprecation import deprecate_kwarg
  26. from ..gemma.modeling_gemma import GemmaForCausalLM
  27. from ..llama.modeling_llama import (
  28. LlamaDecoderLayer,
  29. LlamaForQuestionAnswering,
  30. LlamaForSequenceClassification,
  31. LlamaForTokenClassification,
  32. LlamaModel,
  33. LlamaPreTrainedModel,
  34. apply_rotary_pos_emb,
  35. repeat_kv,
  36. )
  37. from ..mistral.modeling_mistral import MistralMLP
  38. from .configuration_diffllama import DiffLlamaConfig
  39. logger = logging.get_logger(__name__)
  40. _CHECKPOINT_FOR_DOC = "kajuma/DiffLlama-0.3B-handcut"
  41. _CONFIG_FOR_DOC = "DiffLlamaConfig"
  42. class DiffLlamaMLP(MistralMLP):
  43. pass
  44. def lambda_init_fn(layer_idx):
  45. return 0.8 - 0.6 * math.exp(-0.3 * layer_idx)
  46. class DiffLlamaAttention(nn.Module):
  47. """Multi-headed attention from 'Attention Is All You Need' paper"""
  48. def __init__(self, config: DiffLlamaConfig, layer_idx: Optional[int] = None):
  49. super().__init__()
  50. self.config = config
  51. self.layer_idx = layer_idx
  52. if layer_idx is None:
  53. logger.warning_once(
  54. f"Instantiating {self.__class__.__name__} without passing a `layer_idx` is not recommended and will "
  55. "lead to errors during the forward call if caching is used. Please make sure to provide a `layer_idx` "
  56. "when creating this class."
  57. )
  58. self.attention_dropout = config.attention_dropout
  59. self.hidden_size = config.hidden_size
  60. self.num_heads = config.num_attention_heads
  61. self.head_dim = getattr(config, "head_dim", self.hidden_size // self.num_heads)
  62. self.num_key_value_heads = config.num_key_value_heads
  63. self.num_key_value_groups = self.num_heads // self.num_key_value_heads
  64. # under this are not used
  65. self.max_position_embeddings = config.max_position_embeddings
  66. self.rope_theta = config.rope_theta
  67. self.is_causal = True
  68. self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=config.attention_bias)
  69. self.k_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=config.attention_bias)
  70. self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=config.attention_bias)
  71. self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=config.attention_bias)
  72. self.lambda_init = lambda_init_fn(layer_idx)
  73. self.lambda_q1 = nn.Parameter(torch.normal(0, config.lambda_std_dev, size=(self.head_dim,)))
  74. self.lambda_k1 = nn.Parameter(torch.normal(0, config.lambda_std_dev, size=(self.head_dim,)))
  75. self.lambda_q2 = nn.Parameter(torch.normal(0, config.lambda_std_dev, size=(self.head_dim,)))
  76. self.lambda_k2 = nn.Parameter(torch.normal(0, config.lambda_std_dev, size=(self.head_dim,)))
  77. self.groupnorm = nn.RMSNorm(2 * self.head_dim, eps=config.rms_norm_eps, elementwise_affine=False)
  78. @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58")
  79. def forward(
  80. self,
  81. hidden_states: torch.Tensor,
  82. position_embeddings: tuple[torch.Tensor, torch.Tensor],
  83. attention_mask: Optional[torch.Tensor] = None,
  84. position_ids: Optional[torch.LongTensor] = None,
  85. past_key_values: Optional[Cache] = None,
  86. use_cache: bool = False,
  87. cache_position: Optional[torch.LongTensor] = None,
  88. **kwargs,
  89. ) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[tuple[torch.Tensor]]]:
  90. bsz, target_len, _ = hidden_states.size()
  91. q_len = target_len
  92. query_states = self.q_proj(hidden_states)
  93. key_states = self.k_proj(hidden_states)
  94. value_states = self.v_proj(hidden_states)
  95. query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
  96. key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
  97. value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
  98. cos, sin = position_embeddings
  99. query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
  100. if past_key_values is not None:
  101. # sin and cos are specific to RoPE models; cache_position needed for the static cache
  102. cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
  103. key_states, value_states = past_key_values.update(key_states, value_states, self.layer_idx, cache_kwargs)
  104. key_states = repeat_kv(key_states, self.num_key_value_groups)
  105. value_states = repeat_kv(value_states, self.num_key_value_groups)
  106. value_states = torch.cat(torch.chunk(value_states, 2, dim=1), dim=-1)
  107. value_states = value_states.repeat(1, 2, 1, 1)
  108. attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim)
  109. if attention_mask is not None: # no matter the length, we just slice it
  110. causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]
  111. attn_weights = attn_weights + causal_mask
  112. # upcast attention to fp32
  113. attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
  114. attn_weights = nn.functional.dropout(attn_weights, p=self.attention_dropout, training=self.training)
  115. lambda_1 = torch.exp(torch.sum(self.lambda_q1 * self.lambda_k1, dim=-1, dtype=torch.float32)).to(
  116. query_states.dtype
  117. )
  118. lambda_2 = torch.exp(torch.sum(self.lambda_q2 * self.lambda_k2, dim=-1, dtype=torch.float32)).to(
  119. query_states.dtype
  120. )
  121. lambda_full = lambda_1 - lambda_2 + self.lambda_init
  122. attn_output = torch.matmul(attn_weights, value_states)
  123. attn_output1, attn_output2 = torch.chunk(attn_output, 2, dim=1)
  124. attn_output = attn_output1 - lambda_full * attn_output2
  125. attn_output = (1 - self.lambda_init) * self.groupnorm(attn_output)
  126. attn_output = attn_output.transpose(1, 2).contiguous()
  127. attn_output = attn_output.reshape(bsz, q_len, -1)
  128. attn_output = self.o_proj(attn_output)
  129. return attn_output, attn_weights
  130. class DiffLlamaFlashAttention2(DiffLlamaAttention):
  131. """
  132. DiffLlama flash attention module. This module inherits from `DiffLlamaAttention` as the weights of the module stays
  133. untouched. The only required change would be on the forward pass where it needs to correctly call the public API of
  134. flash attention and deal with padding tokens in case the input contains any of them.
  135. """
  136. def __init__(self, *args, **kwargs):
  137. super().__init__(*args, **kwargs)
  138. # TODO: Should be removed once Flash Attention for RoCm is bumped to 2.1.
  139. # flash_attn<2.1 generates top-left aligned causal mask, while what is needed here is bottom-right alignment, that was made default for flash_attn>=2.1. This attribute is used to handle this difference. Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0.
  140. # Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left).
  141. self._flash_attn_uses_top_left_mask = flash_attn_supports_top_left_mask()
  142. @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58")
  143. def forward(
  144. self,
  145. hidden_states: torch.Tensor,
  146. position_embeddings: tuple[torch.Tensor, torch.Tensor],
  147. attention_mask: Optional[torch.LongTensor] = None,
  148. position_ids: Optional[torch.LongTensor] = None,
  149. past_key_values: Optional[Cache] = None,
  150. use_cache: bool = False,
  151. cache_position: Optional[torch.LongTensor] = None,
  152. ) -> tuple[torch.Tensor, None]:
  153. if isinstance(past_key_values, StaticCache):
  154. raise ValueError(
  155. "`static` cache implementation is not compatible with `attn_implementation==flash_attention_2` "
  156. "make sure to use `sdpa` in the mean time, and open an issue at https://github.com/huggingface/transformers"
  157. )
  158. bsz, q_len, _ = hidden_states.size()
  159. query_states = self.q_proj(hidden_states)
  160. key_states = self.k_proj(hidden_states)
  161. value_states = self.v_proj(hidden_states)
  162. # Flash attention requires the input to have the shape
  163. # batch_size x seq_length x head_dim x hidden_dim
  164. # therefore we just need to keep the original shape
  165. query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
  166. key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
  167. value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
  168. if position_embeddings is None:
  169. logger.warning_once(
  170. "The attention layers in this model are transitioning from computing the RoPE embeddings internally "
  171. "through `position_ids` (2D tensor with the indexes of the tokens), to using externally computed "
  172. "`position_embeddings` (Tuple of tensors, containing cos and sin). In v4.46 `position_ids` will be "
  173. "removed and `position_embeddings` will be mandatory."
  174. )
  175. cos, sin = self.rotary_emb(value_states, position_ids)
  176. else:
  177. cos, sin = position_embeddings
  178. query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
  179. if past_key_values is not None:
  180. # sin and cos are specific to RoPE models; cache_position needed for the static cache
  181. cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
  182. key_states, value_states = past_key_values.update(key_states, value_states, self.layer_idx, cache_kwargs)
  183. # TODO: These transpose are quite inefficient but Flash Attention requires the layout [batch_size, sequence_length, num_heads, head_dim]. We would need to refactor the KV cache
  184. # to be able to avoid many of these transpose/reshape/view.
  185. query_states = query_states.transpose(1, 2)
  186. key_states = key_states.transpose(1, 2)
  187. value_states = value_states.transpose(1, 2)
  188. dropout_rate = self.attention_dropout if self.training else 0.0
  189. # In PEFT, usually we cast the layer norms in float32 for training stability reasons
  190. # therefore the input hidden states gets silently casted in float32. Hence, we need
  191. # cast them back in the correct dtype just to be sure everything works as expected.
  192. # This might slowdown training & inference so it is recommended to not cast the LayerNorms
  193. # in fp32. (DiffLlamaRMSNorm handles it correctly)
  194. input_dtype = query_states.dtype
  195. device_type = query_states.device.type if query_states.device.type != "mps" else "cpu"
  196. if input_dtype == torch.float32:
  197. if torch.is_autocast_enabled():
  198. target_dtype = (
  199. torch.get_autocast_dtype(device_type)
  200. if hasattr(torch, "get_autocast_dtype")
  201. else torch.get_autocast_gpu_dtype()
  202. )
  203. # Handle the case where the model is quantized
  204. elif hasattr(self.config, "_pre_quantization_dtype"):
  205. target_dtype = self.config._pre_quantization_dtype
  206. else:
  207. target_dtype = self.q_proj.weight.dtype
  208. logger.warning_once(
  209. f"The input hidden states seems to be silently casted in float32, this might be related to"
  210. f" the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in"
  211. f" {target_dtype}."
  212. )
  213. query_states = query_states.to(target_dtype)
  214. key_states = key_states.to(target_dtype)
  215. value_states = value_states.to(target_dtype)
  216. value_states1, value_states2 = torch.chunk(value_states, 2, dim=2)
  217. value_states1 = value_states1.repeat(1, 1, 2, 1)
  218. value_states2 = value_states2.repeat(1, 1, 2, 1)
  219. attn_output1 = _flash_attention_forward(
  220. query_states,
  221. key_states,
  222. value_states1,
  223. attention_mask,
  224. q_len,
  225. position_ids=position_ids,
  226. dropout=dropout_rate,
  227. sliding_window=getattr(self, "sliding_window", None),
  228. use_top_left_mask=self._flash_attn_uses_top_left_mask,
  229. is_causal=self.is_causal,
  230. )
  231. attn_output2 = _flash_attention_forward(
  232. query_states,
  233. key_states,
  234. value_states2,
  235. attention_mask,
  236. q_len,
  237. position_ids=position_ids,
  238. dropout=dropout_rate,
  239. sliding_window=getattr(self, "sliding_window", None),
  240. use_top_left_mask=self._flash_attn_uses_top_left_mask,
  241. is_causal=self.is_causal,
  242. )
  243. attn_output = torch.cat([attn_output1, attn_output2], dim=-1)
  244. attn_output1, attn_output2 = torch.chunk(attn_output, 2, dim=2)
  245. lambda_1 = torch.exp(torch.sum(self.lambda_q1 * self.lambda_k1, dim=-1, dtype=torch.float32)).to(
  246. query_states.dtype
  247. )
  248. lambda_2 = torch.exp(torch.sum(self.lambda_q2 * self.lambda_k2, dim=-1, dtype=torch.float32)).to(
  249. query_states.dtype
  250. )
  251. lambda_full = lambda_1 - lambda_2 + self.lambda_init
  252. attn_output = attn_output1 - lambda_full * attn_output2
  253. attn_output = (1 - self.lambda_init) * self.groupnorm(attn_output)
  254. attn_output = attn_output.reshape(bsz, q_len, -1).contiguous()
  255. attn_output = self.o_proj(attn_output)
  256. return attn_output, None
  257. class DiffLlamaSdpaAttention(DiffLlamaAttention):
  258. """
  259. DiffLlama attention module using torch.nn.functional.scaled_dot_product_attention. This module inherits from
  260. `DiffLlamaAttention` as the weights of the module stays untouched. The only changes are on the forward pass to adapt to
  261. SDPA API.
  262. """
  263. # Adapted from DiffLlamaAttention.forward
  264. @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58")
  265. def forward(
  266. self,
  267. hidden_states: torch.Tensor,
  268. position_embeddings: tuple[torch.Tensor, torch.Tensor],
  269. attention_mask: Optional[torch.Tensor] = None,
  270. position_ids: Optional[torch.LongTensor] = None,
  271. past_key_values: Optional[Cache] = None,
  272. use_cache: bool = False,
  273. cache_position: Optional[torch.LongTensor] = None,
  274. **kwargs,
  275. ) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[tuple[torch.Tensor]]]:
  276. bsz, q_len, _ = hidden_states.size()
  277. query_states = self.q_proj(hidden_states)
  278. key_states = self.k_proj(hidden_states)
  279. value_states = self.v_proj(hidden_states)
  280. query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
  281. key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
  282. value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
  283. cos, sin = position_embeddings
  284. query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
  285. if past_key_values is not None:
  286. # sin and cos are specific to RoPE models; cache_position needed for the static cache
  287. cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
  288. key_states, value_states = past_key_values.update(key_states, value_states, self.layer_idx, cache_kwargs)
  289. key_states = repeat_kv(key_states, self.num_key_value_groups)
  290. value_states = repeat_kv(value_states, self.num_key_value_groups)
  291. value_states = torch.cat(torch.chunk(value_states, 2, dim=1), dim=-1)
  292. value_states = value_states.repeat(1, 2, 1, 1)
  293. causal_mask = attention_mask
  294. if attention_mask is not None:
  295. causal_mask = causal_mask[:, :, :, : key_states.shape[-2]]
  296. # SDPA with memory-efficient backend is currently (torch==2.1.2) bugged with non-contiguous inputs with custom attn_mask,
  297. # Reference: https://github.com/pytorch/pytorch/issues/112577.
  298. if query_states.device.type == "cuda" and causal_mask is not None:
  299. query_states = query_states.contiguous()
  300. key_states = key_states.contiguous()
  301. value_states = value_states.contiguous()
  302. # We dispatch to SDPA's Flash Attention or Efficient kernels via this `is_causal` if statement instead of an inline conditional assignment
  303. # in SDPA to support both torch.compile's dynamic shapes and full graph options. An inline conditional prevents dynamic shapes from compiling.
  304. is_causal = causal_mask is None and q_len > 1
  305. attn_output = torch.nn.functional.scaled_dot_product_attention(
  306. query_states,
  307. key_states,
  308. value_states,
  309. attn_mask=causal_mask,
  310. dropout_p=self.attention_dropout if self.training else 0.0,
  311. is_causal=is_causal,
  312. )
  313. attn_output1, attn_output2 = torch.chunk(attn_output, 2, dim=1)
  314. lambda_1 = torch.exp(torch.sum(self.lambda_q1 * self.lambda_k1, dim=-1, dtype=torch.float32)).to(
  315. query_states.dtype
  316. )
  317. lambda_2 = torch.exp(torch.sum(self.lambda_q2 * self.lambda_k2, dim=-1, dtype=torch.float32)).to(
  318. query_states.dtype
  319. )
  320. lambda_full = lambda_1 - lambda_2 + self.lambda_init
  321. attn_output = attn_output1 - lambda_full * attn_output2
  322. attn_output = (1 - self.lambda_init) * self.groupnorm(attn_output)
  323. attn_output = attn_output.transpose(1, 2).contiguous()
  324. attn_output = attn_output.view(bsz, q_len, -1)
  325. attn_output = self.o_proj(attn_output)
  326. return attn_output, None
  327. DIFFLLAMA_ATTENTION_CLASSES = {
  328. "eager": DiffLlamaAttention,
  329. "flash_attention_2": DiffLlamaFlashAttention2,
  330. "sdpa": DiffLlamaSdpaAttention,
  331. }
  332. class DiffLlamaDecoderLayer(LlamaDecoderLayer):
  333. def __init__(self, config: DiffLlamaConfig, layer_idx: int):
  334. super().__init__(config, layer_idx)
  335. self.self_attn = DIFFLLAMA_ATTENTION_CLASSES[config._attn_implementation](config=config, layer_idx=layer_idx)
  336. class DiffLlamaPreTrainedModel(LlamaPreTrainedModel):
  337. _supports_flex_attn = False
  338. _supports_attention_backend = False
  339. def _init_weights(self, module):
  340. PreTrainedModel._init_weights(self, module)
  341. if isinstance(module, DiffLlamaAttention):
  342. module.lambda_q1.data.normal_(0, self.config.lambda_std_dev)
  343. module.lambda_k1.data.normal_(0, self.config.lambda_std_dev)
  344. module.lambda_q2.data.normal_(0, self.config.lambda_std_dev)
  345. module.lambda_k2.data.normal_(0, self.config.lambda_std_dev)
  346. class DiffLlamaModel(LlamaModel):
  347. pass
  348. class DiffLlamaForCausalLM(GemmaForCausalLM):
  349. pass
  350. class DiffLlamaForSequenceClassification(LlamaForSequenceClassification):
  351. pass
  352. class DiffLlamaForQuestionAnswering(LlamaForQuestionAnswering):
  353. pass
  354. class DiffLlamaForTokenClassification(LlamaForTokenClassification):
  355. pass
# Public classes exported by this module. Kept as a plain list literal of names so that
# static tooling that may read `__all__` can parse it.
__all__ = [
    "DiffLlamaPreTrainedModel",
    "DiffLlamaModel",
    "DiffLlamaForCausalLM",
    "DiffLlamaForSequenceClassification",
    "DiffLlamaForQuestionAnswering",
    "DiffLlamaForTokenClassification",
]