flash_attention.py 3.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384
  1. from typing import Optional
  2. import torch
  3. from ..modeling_flash_attention_utils import _flash_attention_forward, flash_attn_supports_top_left_mask
  4. from ..utils import logging
# Module-level logger; flash_attention_forward uses it for a one-time warning when
# `output_attentions` / `head_mask` are requested.
logger = logging.get_logger(__name__)
# Evaluated once at import time and forwarded as `use_top_left_mask` to `_flash_attention_forward`.
# NOTE(review): presumably reports whether the installed flash-attn build aligns causal masks to the
# top-left corner — confirm against `flash_attn_supports_top_left_mask`.
_use_top_left_mask = flash_attn_supports_top_left_mask()
  7. def flash_attention_forward(
  8. module: torch.nn.Module,
  9. query: torch.Tensor,
  10. key: torch.Tensor,
  11. value: torch.Tensor,
  12. attention_mask: Optional[torch.Tensor],
  13. dropout: float = 0.0,
  14. scaling: Optional[float] = None,
  15. sliding_window: Optional[int] = None,
  16. softcap: Optional[float] = None,
  17. **kwargs,
  18. ) -> tuple[torch.Tensor, None]:
  19. if kwargs.get("output_attentions", False) or kwargs.get("head_mask") is not None:
  20. logger.warning_once(
  21. "`flash_attention_2` does not support `output_attentions=True` or `head_mask`."
  22. " Please set your attention to `eager` if you want any of these features."
  23. )
  24. # This is before the transpose
  25. seq_len = query.shape[2]
  26. if any(dim == 0 for dim in query.shape):
  27. raise ValueError(
  28. "Tensor query has shape with a zero dimension.\n"
  29. "FlashAttention does not support inputs with dim=0.\n"
  30. "Please check your input shapes or use SDPA instead."
  31. )
  32. # FA2 uses non-transposed inputs
  33. query = query.transpose(1, 2)
  34. key = key.transpose(1, 2)
  35. value = value.transpose(1, 2)
  36. # In PEFT, usually we cast the layer norms in float32 for training stability reasons
  37. # therefore the input hidden states gets silently casted in float32. Hence, we need
  38. # cast them back in the correct dtype just to be sure everything works as expected.
  39. # This might slowdown training & inference so it is recommended to not cast the LayerNorms
  40. # in fp32. (usually our RMSNorm modules handle it correctly)
  41. target_dtype = None
  42. if query.dtype == torch.float32:
  43. if torch.is_autocast_enabled():
  44. target_dtype = torch.get_autocast_gpu_dtype()
  45. # Handle the case where the model is quantized
  46. elif hasattr(module.config, "_pre_quantization_dtype"):
  47. target_dtype = module.config._pre_quantization_dtype
  48. else:
  49. target_dtype = next(layer for layer in module.modules() if isinstance(layer, torch.nn.Linear)).weight.dtype
  50. # Instead of relying on the value set in the module directly, we use the is_causal passed in kwargs if it is presented
  51. is_causal = kwargs.pop("is_causal", None)
  52. if is_causal is None:
  53. is_causal = module.is_causal
  54. attn_output = _flash_attention_forward(
  55. query,
  56. key,
  57. value,
  58. attention_mask,
  59. query_length=seq_len,
  60. is_causal=is_causal,
  61. dropout=dropout,
  62. softmax_scale=scaling,
  63. sliding_window=sliding_window,
  64. softcap=softcap,
  65. use_top_left_mask=_use_top_left_mask,
  66. target_dtype=target_dtype,
  67. attn_implementation=module.config._attn_implementation,
  68. layer_idx=module.layer_idx if hasattr(module, "layer_idx") else None,
  69. **kwargs,
  70. )
  71. return attn_output, None