flash_paged.py 4.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596
  1. from typing import Optional
  2. import torch
  3. from ..generation.continuous_batching import PagedAttentionCache
  4. from ..utils import is_flash_attn_2_available
  5. # For some reason, if we dont assign the function to a variable here, it will be garbage collected
  6. try:
  7. if is_flash_attn_2_available():
  8. from flash_attn import flash_attn_varlen_func # noqa: F401
  9. FLASH_ATTN_VARLEN_FUNC = flash_attn_varlen_func
  10. else:
  11. raise RuntimeError(
  12. "Flash Attention 2 is not installed. Please refer to https://huggingface.co/docs/transformers/perf_infer_gpu_one#flashattention-2 to install it"
  13. )
  14. except Exception as e:
  15. msg = repr(e)
  16. def FLASH_ATTN_VARLEN_FUNC(*args, **kwargs):
  17. raise Exception(f"flash_attn_varlen_func is not available: {msg}")
  18. def paged_attention_forward(
  19. module: torch.nn.Module,
  20. q: torch.Tensor,
  21. k: torch.Tensor,
  22. v: torch.Tensor,
  23. attention_mask: Optional[torch.Tensor] = None,
  24. cache: PagedAttentionCache = None,
  25. cu_seq_lens_q=None,
  26. cu_seq_lens_k=None,
  27. max_seqlen_q=None,
  28. max_seqlen_k=None,
  29. implementation=None,
  30. **kwargs,
  31. ) -> torch.Tensor:
  32. r"""Perform the forward pass of attention with paged key-value cache.
  33. This function handles the cache updates and performs the attention computation
  34. using the flash_attn_varlen_func for efficient processing.
  35. Args:
  36. q: (total_q, nheads, headdim), where total_q = total number of query tokens in the batch.
  37. k: (total_k, nheads_k, headdim), where total_k = total number of key tokens in the batch. but if there is a block table it can be the full k
  38. v: (total_k, nheads_k, headdim), where total_k = total number of key tokens in the batch. but if there is a block table it can be the full v
  39. cu_seq_lens_q: (batch_size + 1,), dtype torch.int32. The cumulative sequence lengths
  40. of the sequences in the batch, used to index into q.
  41. cu_seq_lens_k: (batch_size + 1,), dtype torch.int32. The cumulative sequence lengths
  42. of the sequences in the batch, used to index into kv.
  43. max_seqlen_q: int. Maximum query sequence length in the batch.
  44. max_seqlen_k: int. Maximum key sequence length in the batch.
  45. dropout_p: float. Dropout probability.
  46. softmax_scale: float. The scaling of QK^T before applying softmax.
  47. Default to 1 / sqrt(headdim).
  48. causal: bool. Whether to apply causal attention mask (e.g., for auto-regressive modeling).
  49. window_size: (left, right). If not (-1, -1), implements sliding window local attention.
  50. softcap: float. Anything > 0 activates softcapping attention.
  51. """
  52. sliding_window = (-1, -1) if not getattr(module, "sliding_window", False) else (module.sliding_window - 1, 0)
  53. layer_type = "full_attention" if sliding_window == (-1, -1) else "sliding_attention"
  54. # .update changes the shape of k and v from [1, num_kv_heads, seqlen_kv, head_dim] to [-1, num_kv_heads, head_dim]
  55. if cache is not None:
  56. k, v = cache.update(k, v, module.layer_idx, **kwargs)
  57. # Retrieve the cumulative sequence lengths for the current layer
  58. if isinstance(cu_seq_lens_k, dict):
  59. cu_seq_lens_k = cu_seq_lens_k[layer_type]
  60. max_seqlen_k = max_seqlen_k[layer_type]
  61. if implementation is not None and hasattr(implementation, "flash_attn_varlen_func"):
  62. flash_attn_varlen_func = implementation.flash_attn_varlen_func
  63. else:
  64. flash_attn_varlen_func = FLASH_ATTN_VARLEN_FUNC
  65. custom_kwargs = {"s_aux": kwargs.get("s_aux")} if "s_aux" in kwargs else {}
  66. attn_output = flash_attn_varlen_func(
  67. q.transpose(1, 2).squeeze(0).contiguous(),
  68. k.contiguous(),
  69. v.contiguous(),
  70. cu_seq_lens_q.to(torch.int32),
  71. cu_seq_lens_k.to(torch.int32).clone(),
  72. max_seqlen_q,
  73. max_seqlen_k,
  74. softmax_scale=module.scaling,
  75. causal=True, # kind of a must, it automatically aligns the mask for q < k
  76. window_size=sliding_window, # -1 means infinite context window
  77. **custom_kwargs,
  78. )
  79. if isinstance(attn_output, tuple):
  80. attn_output = attn_output[0]
  81. return attn_output, None