modeling_flash_attention_utils.py 29 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668
  1. # Copyright 2025 The Fairseq Authors and the HuggingFace Inc. team. All rights reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. import inspect
  15. import os
  16. from functools import partial
  17. from typing import Optional, TypedDict
  18. import torch
  19. import torch.nn.functional as F
  20. from .utils import (
  21. is_flash_attn_2_available,
  22. is_flash_attn_3_available,
  23. is_flash_attn_greater_or_equal_2_10,
  24. is_torch_npu_available,
  25. logging,
  26. )
  27. logger = logging.get_logger(__name__)
  28. # TODO Deprecate when all models have the attention interface
  29. def flash_attn_supports_top_left_mask():
  30. if is_flash_attn_3_available():
  31. return False
  32. if is_flash_attn_2_available():
  33. return not is_flash_attn_greater_or_equal_2_10()
  34. from .integrations.npu_flash_attention import is_npu_fa2_top_left_aligned_causal_mask
  35. return is_npu_fa2_top_left_aligned_causal_mask()
  36. # TODO Deprecate when all models have the attention interface
  37. def is_flash_attn_available():
  38. return is_flash_attn_3_available() or is_flash_attn_2_available() or is_torch_npu_available()
# `globals()` is not compatible with dynamo, hence we have to define them in global scope ourselves.
# These act as a module-level cache: they start as `None` and are filled in by
# `lazy_import_flash_attention` on first use (or again when `force_import` is set).
_flash_fn = None
_flash_varlen_fn = None
_pad_fn = None
_unpad_fn = None

# function that processes kwargs, generalized to handle any supported kwarg within the function
# (a `functools.partial` of `_process_flash_attention_kwargs`, built by `_lazy_define_process_function`)
_process_flash_kwargs_fn = None

# exceptions where hf API doesn't match the original flash attention API
# key: kwarg name used on the HF side, value: parameter name looked up in the flash attention signature
_hf_api_to_flash_mapping = {
    "dropout": "dropout_p",
    "sliding_window": "window_size",
}
  51. def _lazy_imports(implementation: Optional[str]):
  52. """
  53. Lazy loads the respective flash attention implementations.
  54. Return:
  55. flash_attn_func: The base flash attention function.
  56. flash_attn_varlen_func: The flash attention function supporting variable sequence lengths,
  57. e.g. for padding-free training.
  58. pad_input: The function to pad inputs into one sequence and returning the respective kwargs.
  59. unpad_input: The function to unpad outputs based on the kwargs (from pad_input).
  60. """
  61. is_fa2 = is_flash_attn_2_available()
  62. is_fa3 = is_flash_attn_3_available()
  63. pad_input, unpad_input = _pad_input, _unpad_input
  64. if (implementation == "flash_attention_2" and is_fa2) or (implementation is None and is_fa2 and not is_fa3):
  65. from flash_attn import flash_attn_func, flash_attn_varlen_func
  66. from flash_attn.bert_padding import pad_input, unpad_input
  67. elif is_torch_npu_available():
  68. # Package `flash-attn` is unavailable on Ascend NPU, which will cause ImportError
  69. # Flash-Attention2 related apis for Ascend NPU must be imported from `.integrations.npu_flash_attention` module
  70. from .integrations.npu_flash_attention import npu_flash_attn_func as flash_attn_func
  71. from .integrations.npu_flash_attention import npu_flash_attn_varlen_func as flash_attn_varlen_func
  72. else:
  73. if implementation == "flash_attention_3" or (implementation is None and is_fa3):
  74. from flash_attn_interface import flash_attn_func, flash_attn_varlen_func
  75. # Kernels fallback
  76. else:
  77. flash_attn_func = getattr(implementation, "flash_attn_func", None)
  78. flash_attn_varlen_func = getattr(implementation, "flash_attn_varlen_func", None)
  79. if flash_attn_varlen_func is None or flash_attn_func is None:
  80. raise ValueError(
  81. f"Could not find the currently requested flash attention implementation at `{implementation}`."
  82. f"Make sure that you request a valid kernel from the hub, e.g. `kernels-community/flash-attn`."
  83. )
  84. return flash_attn_func, flash_attn_varlen_func, pad_input, unpad_input
  85. def _lazy_define_process_function(flash_function):
  86. """
  87. Depending on the version and kernel some features are not supported. Due to limitations in
  88. `torch.compile`, we opt to statically type which (optional) kwarg parameters are supported
  89. within `_process_flash_attention_kwargs`.
  90. NOTE: While all supported kwargs are marked as `True`, everything else is marked as `False`.
  91. This might be confusing for kwargs that we use in any case, e.g. `is_causal`.
  92. """
  93. flash_parameters = inspect.signature(flash_function).parameters
  94. process_parameters = inspect.signature(_process_flash_attention_kwargs).parameters
  95. supports_mapping = {}
  96. for param in process_parameters:
  97. fa_param = _hf_api_to_flash_mapping.get(param, param)
  98. supports_mapping[fa_param] = fa_param in flash_parameters
  99. return partial(_process_flash_attention_kwargs, supports_mapping=supports_mapping)
  100. def lazy_import_flash_attention(implementation: Optional[str], force_import: Optional[bool] = False):
  101. """
  102. Lazily import flash attention and return the respective functions + flags.
  103. NOTE: For fullgraph, this needs to be called before compile, while no fullgraph can
  104. work without preloading. See `load_and_register_kernel` in `integrations.hub_kernels`.
  105. """
  106. global _flash_fn, _flash_varlen_fn, _pad_fn, _unpad_fn
  107. if force_import or any(k is None for k in [_flash_fn, _flash_varlen_fn, _pad_fn, _unpad_fn]):
  108. _flash_fn, _flash_varlen_fn, _pad_fn, _unpad_fn = _lazy_imports(implementation)
  109. global _process_flash_kwargs_fn
  110. if force_import or _process_flash_kwargs_fn is None:
  111. _process_flash_kwargs_fn = _lazy_define_process_function(_flash_varlen_fn)
  112. return (_flash_fn, _flash_varlen_fn, _pad_fn, _unpad_fn), _process_flash_kwargs_fn
  113. def _index_first_axis(tensor, indices):
  114. """
  115. A local implementation of the PyTorch indexing operation `tensor[indices]` on the first axis,
  116. after flattening the first two dimensions of the tensor. This is functionally equivalent to
  117. FA2's `index_first_axis` and replaces the need to import it.
  118. """
  119. # The input tensor is expected to be of shape (batch, seq_len, ...). We flatten the first
  120. # two dimensions to get (total_tokens, ...) before indexing.
  121. reshaped_tensor = tensor.reshape(-1, *tensor.shape[2:])
  122. return reshaped_tensor[indices]
  123. def _unpad_input(hidden_states, attention_mask, unused_mask=None):
  124. """
  125. unpad_input function for flash attention variants that do not have them within their pkg themselves, e.g. fa3.
  126. Arguments:
  127. hidden_states: (batch, seqlen, ...)
  128. attention_mask: (batch, seqlen), bool / int, 1 means valid and 0 means not valid.
  129. unused_mask: (batch, seqlen), bool / int, 1 means the element is allocated but unused.
  130. Return:
  131. hidden_states: (total_nnz, ...), where total_nnz = number of tokens selected in attention_mask + unused_mask.
  132. indices: (total_nnz), the indices of masked tokens from the flattened input sequence.
  133. cu_seqlens: (batch + 1), the cumulative sequence lengths, used to index into hidden_states.
  134. max_seqlen_in_batch: int
  135. seqused: (batch), returns the number of tokens selected in attention_mask + unused_mask.
  136. """
  137. all_masks = (attention_mask + unused_mask) if unused_mask is not None else attention_mask
  138. seqlens_in_batch = all_masks.sum(dim=-1, dtype=torch.int32)
  139. used_seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32)
  140. indices = torch.nonzero(all_masks.flatten(), as_tuple=False).flatten()
  141. max_seqlen_in_batch = seqlens_in_batch.max().item()
  142. cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0))
  143. return (
  144. _index_first_axis(hidden_states, indices),
  145. indices,
  146. cu_seqlens,
  147. max_seqlen_in_batch,
  148. used_seqlens_in_batch,
  149. )
  150. def _pad_input(hidden_states, indices, batch, seqlen):
  151. """
  152. pad_input function for flash attention variants that do not have them within their pkg themselves, e.g. fa3.
  153. Arguments:
  154. hidden_states: (total_nnz, ...), where total_nnz = number of tokens in selected in attention_mask.
  155. indices: (total_nnz), the indices that represent the non-masked tokens of the original padded input sequence.
  156. batch: int, batch size for the padded sequence.
  157. seqlen: int, maximum sequence length for the padded sequence.
  158. Return:
  159. hidden_states: (batch, seqlen, ...)
  160. """
  161. dim = hidden_states.shape[1:]
  162. output = torch.zeros((batch * seqlen), *dim, device=hidden_states.device, dtype=hidden_states.dtype)
  163. output[indices] = hidden_states
  164. return output.view(batch, seqlen, *dim)
  165. def _get_unpad_data(attention_mask: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor, int]:
  166. """
  167. Retrieves indexing data required to repad unpadded (ragged) tensors.
  168. Arguments:
  169. attention_mask (`torch.Tensor`):
  170. Boolean or int tensor of shape (batch_size, sequence_length), 1 means valid and 0 means not valid.
  171. Return:
  172. indices (`torch.Tensor`):
  173. The indices of non-masked tokens from the flattened input sequence.
  174. cu_seqlens (`torch.Tensor`):
  175. The cumulative sequence lengths, used to index into ragged (unpadded) tensors. `cu_seqlens` shape is (batch_size + 1,).
  176. max_seqlen_in_batch (`int`):
  177. Maximum sequence length in batch.
  178. """
  179. seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32)
  180. indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten()
  181. # NOTE: Similar to the `.item()` in prepare_fa2_from_position_ids, with torch compile,
  182. # this might cause a graph break
  183. max_seqlen_in_batch = seqlens_in_batch.max().item()
  184. cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0))
  185. return (
  186. indices,
  187. cu_seqlens,
  188. max_seqlen_in_batch,
  189. )
  190. def _upad_input(
  191. query_layer: torch.Tensor,
  192. key_layer: torch.Tensor,
  193. value_layer: torch.Tensor,
  194. attention_mask: torch.Tensor,
  195. query_length: int,
  196. unpad_input_func,
  197. ):
  198. """
  199. Unpads query, key, and values tensors, using a single dimension for all tokens even though they belong to different batches.
  200. This function is used instead of `flash_attn.bert_padding.unpad_input` in order to avoid the recomputation of the same intermediary
  201. tensors for query, key, value tensors.
  202. Arguments:
  203. query_layer (`torch.Tensor`):
  204. Query state with padding. Shape: (batch_size, query_length, num_heads, head_dim).
  205. key_layer (`torch.Tensor`):
  206. Key state with padding. Shape: (batch_size, kv_seq_len, num_key_value_heads, head_dim).
  207. value_layer (`torch.Tensor`):
  208. Value state with padding. Shape: (batch_size, kv_seq_len, num_key_value_heads, head_dim).
  209. attention_mask (`torch.Tensor`):
  210. Boolean or int tensor of shape (batch_size, sequence_length), 1 means valid and 0 means not valid.
  211. query_length (`int`):
  212. Target length.
  213. unpad_input_func:
  214. The function to use for unpadding the input tensors.
  215. Return:
  216. query_layer (`torch.Tensor`):
  217. Query state without padding. Shape: (total_target_length, num_heads, head_dim).
  218. key_layer (`torch.Tensor`):
  219. Key state with padding. Shape: (total_source_length, num_key_value_heads, head_dim).
  220. value_layer (`torch.Tensor`):
  221. Value state with padding. Shape: (total_source_length, num_key_value_heads, head_dim).
  222. indices_q (`torch.Tensor`):
  223. The indices of non-masked tokens from the flattened input target sequence.
  224. (cu_seqlens_q, cu_seqlens_k) (`tuple[int]`):
  225. The cumulative sequence lengths for the target (query) and source (key, value), used to index into ragged (unpadded) tensors. `cu_seqlens` shape is (batch_size + 1,).
  226. (max_seqlen_in_batch_q, max_seqlen_in_batch_k) (`tuple[int]`):
  227. Maximum sequence length in batch (`max_seqlen_in_batch_q` for the target sequence i.e. query, `max_seqlen_in_batch_k` for the source sequence i.e. key/value).
  228. """
  229. indices_k, cu_seqlens_k, max_seqlen_in_batch_k = _get_unpad_data(attention_mask)
  230. # With static caches, the k/v states may be larger than the mask -> we need to slice them to avoid generating garbage
  231. # It's a bit of an anti-pattern, but otherwise we silently compute wrong attentions scores
  232. if key_layer.shape[1] > (seq_len := attention_mask.shape[-1]):
  233. key_layer, value_layer = key_layer[:, :seq_len, :, :], value_layer[:, :seq_len, :, :]
  234. batch_size, kv_seq_len, num_key_value_heads, head_dim = key_layer.shape
  235. key_layer = _index_first_axis(key_layer, indices_k)
  236. value_layer = _index_first_axis(value_layer, indices_k)
  237. if query_length == kv_seq_len:
  238. query_layer = _index_first_axis(query_layer, indices_k)
  239. cu_seqlens_q = cu_seqlens_k
  240. max_seqlen_in_batch_q = max_seqlen_in_batch_k
  241. indices_q = indices_k
  242. elif query_length == 1:
  243. max_seqlen_in_batch_q = 1
  244. cu_seqlens_q = torch.arange(
  245. batch_size + 1, dtype=torch.int32, device=query_layer.device
  246. ) # There is a memcpy here, that is very bad.
  247. indices_q = cu_seqlens_q[:-1]
  248. query_layer = query_layer.squeeze(1)
  249. else:
  250. # The -q_len: slice assumes left padding.
  251. attention_mask = attention_mask[:, -query_length:]
  252. query_layer, indices_q, cu_seqlens_q, max_seqlen_in_batch_q, *_ = unpad_input_func(query_layer, attention_mask)
  253. return (
  254. query_layer,
  255. key_layer,
  256. value_layer,
  257. indices_q,
  258. (cu_seqlens_q, cu_seqlens_k),
  259. (max_seqlen_in_batch_q, max_seqlen_in_batch_k),
  260. )
  261. def prepare_fa_kwargs_from_position_ids(position_ids):
  262. """
  263. This function returns all the necessary kwargs to call `flash_attn_varlen_func` extracted from position_ids.
  264. Arguments:
  265. position_ids (`torch.Tensor`):
  266. Boolean or int tensor of shape (batch_size, sequence_length), 1 means valid and 0 means not valid.
  267. Return:
  268. (cu_seqlens_q, cu_seqlens_k) (`tuple[int]`):
  269. The cumulative sequence lengths for the target (query) and source (key, value), used to index into
  270. ragged (unpadded) tensors. `cu_seqlens` shape is (batch_size + 1,).
  271. (max_seqlen_in_batch_q, max_seqlen_in_batch_k) (`tuple[int]`):
  272. Maximum sequence length in batch (`max_seqlen_in_batch_q` for the target sequence i.e. query,
  273. `max_seqlen_in_batch_k` for the source sequence i.e. key/value).
  274. """
  275. tensor_kwargs = {"dtype": torch.int32, "device": position_ids.device}
  276. position_ids = position_ids.view(-1)
  277. indices_q = (position_ids == 0).nonzero().view(-1)
  278. cu_seq_lens_q = torch.cat(
  279. (
  280. indices_q.to(**tensor_kwargs),
  281. torch.tensor(position_ids.size(), **tensor_kwargs),
  282. )
  283. )
  284. cu_seq_lens_k = cu_seq_lens_q
  285. # https://github.com/Dao-AILab/flash-attention/blob/2dd8078adc1d9b74e315ee99718c0dea0de8eeb6/flash_attn/flash_attn_interface.py#L1423-L1424
  286. # We should use cu_seq_lens instead of position_ids to get the max length since position_ids is not always increasing
  287. # for some models (e.g. qwen2-vl).
  288. max_length_q = cu_seq_lens_q.diff().max()
  289. # NOTE: With torch compile, this will cause a graph break if you don't set
  290. # `TORCHDYNAMO_CAPTURE_SCALAR_OUTPUTS=1` in the environment or call
  291. # `torch._dynamo.config.capture_scalar_outputs = True` before doing the forward pass.
  292. # This is a limitation of flash attention API, as the function `flash_attn_varlen_func`
  293. # requires `max_length_q`, `max_length_k` to be passed as `int` and not `torch.Tensor`.
  294. max_length_q = max_length_q.item()
  295. max_length_k = max_length_q
  296. return (cu_seq_lens_q, cu_seq_lens_k), (max_length_q, max_length_k)
  297. def _prepare_from_posids(query, key, value, position_ids):
  298. """
  299. This function returns necessary arguments to call `flash_attn_varlen_func`.
  300. All three query, key, value states will be flattened.
  301. Cumulative lengths of each examples in the batch will be extracted from position_ids.
  302. NOTE: ideally cumulative lengths should be prepared at the data collator stage
  303. Arguments:
  304. query (`torch.Tensor`):
  305. Query state with padding. Shape: (batch_size, query_length, num_heads, head_dim).
  306. key (`torch.Tensor`):
  307. Key state with padding. Shape: (batch_size, kv_seq_len, num_key_value_heads, head_dim).
  308. value (`torch.Tensor`):
  309. Value state with padding. Shape: (batch_size, kv_seq_len, num_key_value_heads, head_dim).
  310. position_ids (`torch.Tensor`):
  311. Boolean or int tensor of shape (batch_size, sequence_length), 1 means valid and 0 means not valid.
  312. Return:
  313. query (`torch.Tensor`):
  314. Query state without padding. Shape: (total_target_length, num_heads, head_dim).
  315. key (`torch.Tensor`):
  316. Key state with padding. Shape: (total_source_length, num_key_value_heads, head_dim).
  317. value (`torch.Tensor`):
  318. Value state with padding. Shape: (total_source_length, num_key_value_heads, head_dim).
  319. (cu_seqlens_q, cu_seqlens_k) (`tuple[int]`):
  320. The cumulative sequence lengths for the target (query) and source (key, value), used to index into ragged (unpadded) tensors. `cu_seqlens` shape is (batch_size + 1,).
  321. (max_seqlen_in_batch_q, max_seqlen_in_batch_k) (`tuple[int]`):
  322. Maximum sequence length in batch (`max_seqlen_in_batch_q` for the target sequence i.e. query, `max_seqlen_in_batch_k` for the source sequence i.e. key/value).
  323. """
  324. query = query.contiguous().view(-1, query.size(-2), query.size(-1))
  325. key = key.contiguous().view(-1, key.size(-2), key.size(-1))
  326. value = value.contiguous().view(-1, value.size(-2), value.size(-1))
  327. (cu_seq_lens_q, cu_seq_lens_k), (max_length_q, max_length_k) = prepare_fa_kwargs_from_position_ids(position_ids)
  328. return (query, key, value, (cu_seq_lens_q, cu_seq_lens_k), (max_length_q, max_length_k))
  329. def _is_packed_sequence(position_ids, batch_size):
  330. """
  331. Check the position ids whether packed sequences are indicated or not
  332. 1. Position ids exist
  333. 2. Flattened sequences only are supported
  334. 3. Compile-friendly `not (torch.diff(position_ids, dim=-1) >= 0).all()`, i.e. we have multiple increasing sequences
  335. """
  336. if position_ids is None:
  337. return False
  338. increasing_position_sequences = (
  339. torch.arange(position_ids.shape[1], device=position_ids.device) + position_ids.min()
  340. )
  341. return batch_size == 1 and (increasing_position_sequences - position_ids).abs().sum().bool()
  342. def fa_peft_integration_check(
  343. q: torch.Tensor,
  344. k: torch.Tensor,
  345. v: torch.Tensor,
  346. target_dtype: Optional[torch.dtype] = None,
  347. ):
  348. """
  349. PEFT usually casts the layer norms in float32 for training stability reasons
  350. therefore the input hidden states gets silently casted in float32. Hence, we need
  351. cast them back in float16 / bfloat16 just to be sure everything works as expected.
  352. This might slowdown training & inference so it is recommended to not cast the LayerNorms!
  353. """
  354. if target_dtype and q.dtype == torch.float32:
  355. logger.warning_once(f"Casting fp32 inputs back to {target_dtype} for flash-attn compatibility.")
  356. q, k, v = q.to(target_dtype), k.to(target_dtype), v.to(target_dtype)
  357. return q, k, v
  358. class FlashAttentionKwargs(TypedDict, total=False):
  359. """
  360. Keyword arguments for Flash Attention with Compile.
  361. Attributes:
  362. cu_seq_lens_q (`torch.LongTensor`, *optional*)
  363. Gets cumulative sequence length for query state.
  364. cu_seq_lens_k (`torch.LongTensor`, *optional*)
  365. Gets cumulative sequence length for key state.
  366. max_length_q (`int`, *optional*):
  367. Maximum sequence length for query state.
  368. max_length_k (`int`, *optional*):
  369. Maximum sequence length for key state.
  370. """
  371. cu_seq_lens_q: Optional[torch.LongTensor]
  372. cu_seq_lens_k: Optional[torch.LongTensor]
  373. max_length_q: Optional[int]
  374. max_length_k: Optional[int]
  375. def _process_flash_attention_kwargs(
  376. query_length: int,
  377. key_length: int,
  378. is_causal: bool,
  379. dropout: float = 0.0,
  380. softmax_scale: Optional[float] = None,
  381. sliding_window: Optional[int] = None,
  382. use_top_left_mask: bool = False,
  383. softcap: Optional[float] = None,
  384. deterministic: Optional[bool] = None,
  385. s_aux: Optional[torch.Tensor] = None,
  386. supports_mapping: Optional[dict[str, bool]] = None,
  387. **kwargs,
  388. ):
  389. """
  390. Returns a set of kwargs that are passed down to the according flash attention function based on
  391. requested features and whether it is supported - depends on the version and kernel implementation
  392. which is dynamically configured at `lazy_import_flash_attention`. The (un)supported features can be
  393. inspected in `supports_mapping`, see `_lazy_define_process_function` for more details.
  394. Args:
  395. query_length (`int`):
  396. Length of the query states
  397. key_length (`int`):
  398. Length of the key states
  399. is_causal (`bool`):
  400. Whether we perform causal (decoder) attention or full attention.
  401. dropout (`float`):
  402. Attention dropout.
  403. softmax_scale (`float`, *optional*):
  404. The scaling of QK^T before applying softmax. Default to `1 / sqrt(head_dim)`.
  405. sliding_window (`int`, *optional*):
  406. The size of the sliding window, i.e. we look at a max of `sliding_window` tokens back.
  407. use_top_left_mask (`bool`):
  408. Deprecated behavior of older versions of flash attention requiring different masking.
  409. softcap (`float`, *optional*):
  410. Softcap for the attention logits, used e.g. in gemma2.
  411. deterministic (`bool`, *optional*):
  412. Determines if the deterministic option introduced in flash_attn>=2.4.1 is enabled.
  413. s_aux (`torch.Tensor`, *optional*):
  414. Attention sink auxiliary that adds a `bias` to the attention calculation via an additional head.
  415. Return:
  416. flash_kwargs (`dict`):
  417. A dict of kwargs that are requested and supported.
  418. """
  419. flash_kwargs = {
  420. "causal": is_causal and not (use_top_left_mask and query_length == 1),
  421. "softmax_scale": softmax_scale,
  422. }
  423. if supports_mapping["dropout_p"]:
  424. flash_kwargs["dropout_p"] = dropout
  425. if supports_mapping["window_size"] and sliding_window is not None and key_length > sliding_window:
  426. # The flash attention API sets inclusive boundaries, i.e. (4, 0) would take 4 tokens to the left
  427. # and the current token for a total size of 5. However, we usually define our window sizes by
  428. # their total window size (when causal). Encoder models as of now seldom use SWA and when they
  429. # do, they have a custom workaround (e.g. ModernBERT) which would align with this symmetric logic, i.e.
  430. # for a total of `2*sliding_window + 1`.
  431. flash_kwargs["window_size"] = (sliding_window - 1, sliding_window - 1)
  432. if supports_mapping["deterministic"]:
  433. flash_kwargs["deterministic"] = (
  434. deterministic if deterministic is not None else os.getenv("FLASH_ATTENTION_DETERMINISTIC", "0") == "1"
  435. )
  436. if supports_mapping["softcap"] and softcap is not None:
  437. flash_kwargs["softcap"] = softcap
  438. # Only within kernel implementation atm
  439. if supports_mapping["s_aux"] and s_aux is not None:
  440. flash_kwargs["s_aux"] = s_aux
  441. return flash_kwargs
def _flash_attention_forward(
    query_states: torch.Tensor,
    key_states: torch.Tensor,
    value_states: torch.Tensor,
    attention_mask: Optional[torch.Tensor],
    query_length: int,
    is_causal: bool,
    dropout: float = 0.0,
    position_ids: Optional[torch.Tensor] = None,
    softmax_scale: Optional[float] = None,
    sliding_window: Optional[int] = None,
    use_top_left_mask: bool = False,
    softcap: Optional[float] = None,
    deterministic: Optional[bool] = None,
    cu_seq_lens_q: Optional[torch.LongTensor] = None,
    cu_seq_lens_k: Optional[torch.LongTensor] = None,
    max_length_q: Optional[int] = None,
    max_length_k: Optional[int] = None,
    target_dtype: Optional[torch.dtype] = None,
    implementation: Optional[str] = None,
    **kwargs,
):
    """
    Calls the forward method of Flash Attention - if the input hidden states contain at least one padding token
    first unpad the input, then computes the attention scores and pad the final attention scores.

    One of three paths is taken, checked in this order:
        1. `attention_mask` is given: the inputs are unpadded with `_upad_input`, run through the varlen
           function and the output is padded back to `(batch_size, query_length, ...)`.
        2. Padding free: either all of `cu_seq_lens_q`, `cu_seq_lens_k`, `max_length_q`, `max_length_k` are
           given or `position_ids` indicate packed sequences (see `_is_packed_sequence`). The varlen
           function is used on the flattened inputs.
        3. Otherwise the base flash attention function is called directly on the inputs.

    (Optional) kwargs are described further in `_process_flash_attention_kwargs` and `FlashAttentionKwargs`.

    Args:
        query_states (`torch.Tensor`):
            Input query states to be passed to Flash Attention API
        key_states (`torch.Tensor`):
            Input key states to be passed to Flash Attention API
        value_states (`torch.Tensor`):
            Input value states to be passed to Flash Attention API
        attention_mask (`torch.Tensor`, *optional*):
            The padding mask - corresponds to a tensor of size `(batch_size, seq_len)` where 0 stands for the
            position of padding tokens and 1 for the position of non-padding tokens.
        query_length (`int`):
            Length of the query states, used for unpadding / padding and for processing the flash kwargs.
        is_causal (`bool`):
            Whether we perform causal (decoder) attention or full attention.
        dropout (`float`):
            Attention dropout.
        position_ids (`torch.Tensor`, *optional*):
            Position ids, used to detect packed sequences and to derive the cumulative sequence lengths
            when `cu_seq_lens_q` / `cu_seq_lens_k` are not given.
        softmax_scale, sliding_window, use_top_left_mask, softcap, deterministic:
            Forwarded to `_process_flash_attention_kwargs`, see there.
        cu_seq_lens_q, cu_seq_lens_k, max_length_q, max_length_k:
            Pre-computed varlen arguments, see `FlashAttentionKwargs`.
        target_dtype (`torch.dtype`, *optional*):
            Forwarded to `fa_peft_integration_check`, which casts fp32 inputs to this dtype.
        implementation (`str`, *optional*):
            The attention implementation to use. If None, will default to the one based on the environment.
        kwargs:
            Forwarded to `_process_flash_attention_kwargs` (e.g. `s_aux`).

    Return:
        The attention output. If the flash attention function returns a tuple, only its first element is kept.
    """
    (flash_fn, flash_varlen_fn, pad_fn, unpad_fn), process_flash_kwargs_fn = lazy_import_flash_attention(
        implementation
    )

    # PEFT possibly silently casts tensors to fp32, this potentially reconverts to correct dtype or is a no op
    query_states, key_states, value_states = fa_peft_integration_check(
        query_states, key_states, value_states, target_dtype
    )

    # Extract the flash attention kwargs that have been requested (and are supported by the implementation)
    flash_kwargs = process_flash_kwargs_fn(
        query_length=query_length,
        key_length=key_states.size(1),
        is_causal=is_causal,
        dropout=dropout,
        softmax_scale=softmax_scale,
        sliding_window=sliding_window,
        use_top_left_mask=use_top_left_mask,
        softcap=softcap,
        deterministic=deterministic,
        **kwargs,
    )

    # We will use `flash_varlen_fn` to prevent cross-example attention and also allow padding free approach under two cases:
    # Case 1. If position ids is provided and the position ids indicate packed sequences, see `_is_packed_sequence`.
    # Case 2. Some models pass directly pre-computed `cu_seqlens` so we don't need to infer it from position ids. It is safe to
    # use `flash_varlen_fn` knowing we already have all necessary the kwargs.
    #
    # NOTE: it is user's responsibility to take care of flattening `position_ids` if that's needed by the model.
    # See #39121 for more information.
    is_fa_with_position_ids = _is_packed_sequence(position_ids, batch_size=query_states.size(0))
    is_fa_with_varlen_kwargs = all(
        kwarg is not None for kwarg in (cu_seq_lens_q, cu_seq_lens_k, max_length_q, max_length_k)
    )

    # Contains at least one padding token in the sequence
    if attention_mask is not None:
        # The varlen arguments passed by the caller are overwritten by the ones derived from the mask
        q, k, v, indices_q, (cu_seq_lens_q, cu_seq_lens_k), (max_length_q, max_length_k) = _upad_input(
            query_states, key_states, value_states, attention_mask, query_length, unpad_fn
        )

        # TODO for now this is required to work with
        # https://huggingface.co/kernels-community/metal-flash-sdpa/blob/main/torch-ext/metal_flash_sdpa/__init__.py
        if "mps" in str(q.device):
            cu_seq_lens_k = cu_seq_lens_k.clone()

        out_unpad = flash_varlen_fn(
            q,
            k,
            v,
            cu_seqlens_q=cu_seq_lens_q,
            cu_seqlens_k=cu_seq_lens_k,
            max_seqlen_q=max_length_q,
            max_seqlen_k=max_length_k,
            **flash_kwargs,
        )
        # Some implementations return a tuple, only the first element is used
        if isinstance(out_unpad, tuple):
            out_unpad = out_unpad[0]

        out = pad_fn(out_unpad, indices_q, query_states.size(0), query_length)

    # Padding free, i.e. sequences flattened into one total sequence
    elif is_fa_with_varlen_kwargs or is_fa_with_position_ids:
        if cu_seq_lens_q is None or cu_seq_lens_k is None:
            # Derive the cumulative / maximum lengths from the position ids
            q, k, v, (cu_seq_lens_q, cu_seq_lens_k), (max_length_q, max_length_k) = _prepare_from_posids(
                query_states, key_states, value_states, position_ids
            )
        else:
            # Cumulative lengths were provided, only flatten the states to (total_tokens, heads, head_dim)
            q = query_states.reshape(-1, query_states.size(-2), query_states.size(-1))
            k = key_states.reshape(-1, key_states.size(-2), key_states.size(-1))
            v = value_states.reshape(-1, value_states.size(-2), value_states.size(-1))

        # TODO for now this is required to work with
        # https://huggingface.co/kernels-community/metal-flash-sdpa/blob/main/torch-ext/metal_flash_sdpa/__init__.py
        if "mps" in str(q.device):
            cu_seq_lens_k = cu_seq_lens_k.clone()

        out = flash_varlen_fn(
            q,
            k,
            v,
            cu_seqlens_q=cu_seq_lens_q,
            cu_seqlens_k=cu_seq_lens_k,
            max_seqlen_q=max_length_q,
            max_seqlen_k=max_length_k,
            **flash_kwargs,
        )
        # Some implementations return a tuple, only the first element is used
        if isinstance(out, tuple):
            out = out[0]

        # Restore the batch dimension of the original query states
        out = out.view(query_states.size(0), -1, out.size(-2), out.size(-1))

    # No padding
    else:
        out = flash_fn(query_states, key_states, value_states, **flash_kwargs)
        # Some implementations return a tuple, only the first element is used
        if isinstance(out, tuple):
            out = out[0]

    return out