flex_attention.py 13 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336
  1. """
  2. Partially inspired by torchtune's flex attention implementation
  3. Citation:
  4. @software{torchtune,
  5. title = {torchtune: PyTorch's finetuning library},
  6. author = {torchtune maintainers and contributors},
  7. url = {https://github.com/pytorch/torchtune},
  8. license = {BSD-3-Clause},
  9. month = apr,
  10. year = {2024}
  11. }
  12. """
  13. # coding=utf-8
  14. # Copyright 2025 The HuggingFace Inc. team.
  15. #
  16. # Licensed under the Apache License, Version 2.0 (the "License");
  17. # you may not use this file except in compliance with the License.
  18. # You may obtain a copy of the License at
  19. #
  20. # http://www.apache.org/licenses/LICENSE-2.0
  21. #
  22. # Unless required by applicable law or agreed to in writing, software
  23. # distributed under the License is distributed on an "AS IS" BASIS,
  24. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  25. # See the License for the specific language governing permissions and
  26. # limitations under the License.
  27. from typing import Optional, Union
  28. import torch
  29. from packaging import version
  30. from ..utils import is_torch_flex_attn_available, logging
  31. from ..utils.import_utils import _torch_version, is_torch_less_or_equal, is_torchdynamo_compiling
  32. if is_torch_flex_attn_available():
  33. from torch.nn.attention.flex_attention import _DEFAULT_SPARSE_BLOCK_SIZE as flex_default_block_size
  34. from torch.nn.attention.flex_attention import BlockMask, create_block_mask, flex_attention
# Module-level logger; `flex_attention_forward` uses it for the one-time `head_mask` warning.
logger = logging.get_logger(__name__)
  36. class WrappedFlexAttention:
  37. """
  38. We are doing a singleton class so that flex attention is compiled once when it's first called.
  39. """
  40. _instance = None
  41. _is_flex_compiled = False
  42. _compiled_flex_attention = None
  43. def __new__(cls, *args, **kwargs):
  44. if cls._instance is None:
  45. # Create a new instance if one doesn't already exist
  46. cls._instance = super().__new__(cls)
  47. return cls._instance
  48. @torch.compiler.disable(recursive=False)
  49. def __init__(self, training):
  50. """
  51. Initialize or update the singleton instance.
  52. """
  53. if not self._is_flex_compiled or training != self.training:
  54. self.training = training
  55. if is_torch_less_or_equal("2.5.1"):
  56. self._compiled_flex_attention = torch.compile(flex_attention, dynamic=False)
  57. # In PyTorch 2.6.0, there's a known issue with flex attention compilation which may
  58. # cause errors. The suggested fix is to compile with "max-autotune-no-cudagraphs"
  59. # see https://github.com/pytorch/pytorch/issues/146260 for training
  60. elif version.parse(_torch_version).base_version == "2.6.0" and training:
  61. self._compiled_flex_attention = torch.compile(
  62. flex_attention, dynamic=False, mode="max-autotune-no-cudagraphs"
  63. )
  64. # Fallback, usually the most recent torch 2.7.x+ versions
  65. else:
  66. self._compiled_flex_attention = torch.compile(flex_attention)
  67. self._is_flex_compiled = True
  68. def __call__(self):
  69. return self._compiled_flex_attention
  70. def compile_friendly_flex_attention(
  71. query: torch.Tensor,
  72. key: torch.Tensor,
  73. value: torch.Tensor,
  74. training=False,
  75. **kwargs,
  76. ) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]:
  77. # First call initialise singleton wrapper object, second call invokes the object method to return compiled flex attention
  78. # Do not use compiled version if already compiling forward (it raises issues)
  79. flex_attention_compiled = WrappedFlexAttention(training)() if not is_torchdynamo_compiling() else flex_attention
  80. return flex_attention_compiled(
  81. query,
  82. key,
  83. value,
  84. **kwargs,
  85. )
# Position offset used in `make_flex_block_causal_mask(offsets=(q_offset, kv_offset))`: a tensor or a Python int.
Offset = Union[torch.Tensor, int]
  87. # TODO: deprecate / rename to make_flex_block_mask for clarity as it's not only causal anymore
  88. def make_flex_block_causal_mask(
  89. attention_mask_2d: torch.Tensor,
  90. attention_chunk_size: Optional[int] = None,
  91. query_length=None,
  92. key_length=None,
  93. offsets: Optional[tuple[Offset, Offset]] = None,
  94. is_causal: Optional[bool] = True,
  95. ) -> "BlockMask":
  96. """
  97. IMPORTANT NOTICE: This function is deprecated in favor of using the mask primitives in `masking_utils.py`,
  98. and will be removed in a future version without warnings. New code should not use it. It is only kept here
  99. for BC for now, while models using it are being patched accordingly.
  100. Create a block (causal) document mask for a batch of sequences, both packed and unpacked.
  101. Create Block (causal) logic and passing it into :func:`torch.nn.attention.flex_attention.create_block_mask`.
  102. The resultant BlockMask is a compressed representation of the full (causal) block
  103. mask. BlockMask is essential for performant computation of flex attention.
  104. See: https://pytorch.org/blog/flexattention/
  105. Args:
  106. attention_mask_2d (torch.Tensor): Attention mask for packed and padded sequences
  107. of shape (batch_size, total_seq_len). e.g.
  108. For unpacked sequence:
  109. [[1, 1, 1, 1, 0, 0, 0],
  110. [1, 1, 1, 1, 1, 0, 0]]
  111. For packed sequence:
  112. [[1, 1, 1, 2, 2, 2, 0],
  113. [1, 1, 2, 2, 2, 3, 3]]
  114. Returns:
  115. BlockMask
  116. """
  117. batch_size, total_seq_len = attention_mask_2d.shape
  118. if not key_length:
  119. key_length = total_seq_len
  120. if not query_length:
  121. query_length = total_seq_len
  122. # older torch (2.5.x) cannot handle sequences not in multiples of 128 (default block size)
  123. pad_len = ((key_length // flex_default_block_size) + 1) * flex_default_block_size
  124. attention_mask_2d = torch.nn.functional.pad(attention_mask_2d, value=0, pad=(0, pad_len - key_length))
  125. device = attention_mask_2d.device
  126. document_ids = attention_mask_2d.clone()
  127. if attention_chunk_size is not None:
  128. # we create an arange, then we just // by chunk size to get [0, 0, 0, 1, 1, 1, 2, 2, 2, 3, 3, 3]
  129. chunk_idxs = (document_ids.clone().fill_(1).cumsum(-1) - 1) // (attention_chunk_size)
  130. # Instead of passing a tensor mask, flex attention requires a mask_mod function
  131. # that determines which elements of QK^T should be included in the attention
  132. # computation prior to the softmax. For sample packing, we need both the
  133. # logic for both causal mask and document mask. See PyTorch's official
  134. # blog post for more details: https://pytorch.org/blog/flexattention/#mask-mods
  135. def causal_mask_mod(batch_idx, head_idx, q_idx, kv_idx):
  136. """
  137. Defines the logic of a block causal mask by combining both a standard causal mask
  138. and a block diagonal document mask.
  139. See :func:`~torchtune.modules.attention_utils.create_block_causal_mask`
  140. for an illustration.
  141. """
  142. causal_mask = q_idx >= kv_idx # not valid when decoding
  143. document_mask = document_ids[batch_idx, q_idx] == document_ids[batch_idx, kv_idx]
  144. padding_mask = attention_mask_2d[batch_idx, q_idx] > 0
  145. final_mask = causal_mask & padding_mask & document_mask
  146. return final_mask
  147. def chunk_causal_mask_mod(batch_idx, head_idx, q_idx, kv_idx):
  148. """
  149. Combines the chunk mask with the causal mask for chunked attention.
  150. """
  151. chunk_mask = chunk_idxs[batch_idx, q_idx] == chunk_idxs[batch_idx, kv_idx]
  152. causal_doc_mask = causal_mask_mod(batch_idx, head_idx, q_idx, kv_idx)
  153. return chunk_mask & causal_doc_mask
  154. def default_mask_mod(batch_idx, head_idx, q_idx, kv_idx):
  155. """
  156. Utilizes default attention mask to enable encoder and encoder-decoder
  157. attention masks.
  158. """
  159. document_mask = document_ids[batch_idx, q_idx] == document_ids[batch_idx, kv_idx]
  160. # kv indexing is crucial in order to work correctly
  161. padding_mask = attention_mask_2d[batch_idx, kv_idx] > 0
  162. final_mask = padding_mask & document_mask
  163. return final_mask
  164. if not is_causal:
  165. mask_mod_maybe_combined = default_mask_mod
  166. else:
  167. mask_mod_maybe_combined = causal_mask_mod if attention_chunk_size is None else chunk_causal_mask_mod
  168. if offsets is not None:
  169. q_offset = offsets[0].to(device)
  170. kv_offset = offsets[1].to(device)
  171. def mask_mod(batch_idx, head_idx, q_idx, kv_idx):
  172. offset_q = q_idx + q_offset
  173. offset_kv = kv_idx + kv_offset
  174. return mask_mod_maybe_combined(batch_idx, head_idx, offset_q, offset_kv)
  175. else:
  176. mask_mod = mask_mod_maybe_combined
  177. return create_block_mask(
  178. mask_mod=mask_mod,
  179. B=batch_size,
  180. H=None, # attention head
  181. Q_LEN=query_length,
  182. KV_LEN=key_length,
  183. device=device,
  184. # compiling the mask is not BC with older torch
  185. _compile=not is_torch_less_or_equal("2.5.1"),
  186. )
  187. def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
  188. """
  189. This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch,
  190. num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim)
  191. """
  192. batch, num_key_value_heads, slen, head_dim = hidden_states.shape
  193. if n_rep == 1:
  194. return hidden_states
  195. hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim)
  196. return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
  197. def flex_attention_forward(
  198. module: torch.nn.Module,
  199. query: torch.Tensor,
  200. key: torch.Tensor,
  201. value: torch.Tensor,
  202. attention_mask: Union[torch.Tensor, "BlockMask"],
  203. scaling: Optional[float] = None,
  204. softcap: Optional[float] = None,
  205. head_mask: Optional[torch.Tensor] = None,
  206. s_aux: Optional[torch.Tensor] = None,
  207. **kwargs,
  208. ) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
  209. if head_mask is not None:
  210. logger.warning_once(
  211. "`flex_attention` does not support `head_mask`. Please set your attention to `eager` if you want this feature."
  212. )
  213. if kwargs.get("dropout", 0.0) > 0:
  214. raise ValueError(
  215. "`flex_attention` does not support `dropout`. Please use it with inference"
  216. " only (`model.eval()`) or turn off the attention dropout in the respective config."
  217. )
  218. block_mask = None
  219. score_mask = None
  220. if isinstance(attention_mask, BlockMask):
  221. block_mask = attention_mask
  222. else:
  223. score_mask = attention_mask
  224. if score_mask is not None:
  225. score_mask = score_mask[:, :, :, : key.shape[-2]]
  226. def score_mod(score, batch_idx, head_idx, q_idx, kv_idx):
  227. if softcap is not None:
  228. score = softcap * torch.tanh(score / softcap)
  229. if score_mask is not None:
  230. score = score + score_mask[batch_idx][0][q_idx][kv_idx]
  231. if head_mask is not None:
  232. score = score + head_mask[batch_idx][head_idx][0][0]
  233. # Note: attention sinks cannot be correctly implemented in score_mod
  234. # because it requires operating on the full attention matrix before softmax.
  235. # ==> this is done after flex attention
  236. return score
  237. enable_gqa = True
  238. num_local_query_heads = query.shape[1]
  239. # When running TP this helps:
  240. if (num_local_query_heads & (num_local_query_heads - 1)) != 0:
  241. key = repeat_kv(key, query.shape[1] // key.shape[1])
  242. value = repeat_kv(value, query.shape[1] // value.shape[1])
  243. enable_gqa = False
  244. kernel_options = kwargs.get("kernel_options")
  245. # On CPU we must skip returning LSE due to a runtime issue; elsewhere, follow PyTorch API and return it
  246. return_lse = query.device.type != "cpu"
  247. if not return_lse and s_aux is not None:
  248. raise ValueError(
  249. "Attention sinks cannot be run on CPU with flex attention. Please switch to a different device, e.g. CUDA"
  250. )
  251. flex_attention_output = compile_friendly_flex_attention(
  252. query,
  253. key,
  254. value,
  255. score_mod=score_mod,
  256. block_mask=block_mask,
  257. enable_gqa=enable_gqa,
  258. scale=scaling,
  259. kernel_options=kernel_options,
  260. # Last time checked on PyTorch == 2.5.1: Flex Attention always computes the lse regardless.
  261. # For simplification, we thus always return it as no additional computations are introduced.
  262. return_lse=return_lse,
  263. training=module.training,
  264. )
  265. # lse is returned in float32
  266. if return_lse:
  267. attention_output, lse = flex_attention_output # type: ignore[misc]
  268. lse = lse.to(value.dtype)
  269. if s_aux is not None:
  270. # Apply attention sinks by renormalizing using LSE
  271. batch_size, num_heads, seq_len_q, _ = attention_output.shape # batch, num_heads, seq_len, head_dim
  272. sinks = s_aux.view(1, -1, 1, 1).expand(batch_size, num_heads, seq_len_q, 1)
  273. # We need to compute the normalization that includes the sinks
  274. # since log(sum(exp(scores))) = lse, exp(log(sum(exp(scores)))) = exp(lse)
  275. # NB: log(sum(exp(scores)) + exp(sink)) = log(exp(lse) + exp(sink))
  276. lse_expanded = lse.unsqueeze(-1) # [batch, num_heads, seq_len, 1]
  277. combined_lse = torch.logsumexp(torch.cat([lse_expanded, sinks], dim=-1), dim=-1, keepdim=True)
  278. # Use new_norm / old_norm = exp(combined_lse - lse) to compute renorm and apply
  279. renorm_factor = torch.exp(lse_expanded - combined_lse)
  280. attention_output = attention_output * renorm_factor
  281. else:
  282. attention_output = flex_attention_output # type: ignore[assignment]
  283. lse = None
  284. attention_output = attention_output.transpose(1, 2).contiguous()
  285. return attention_output, lse