_utils.py 2.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960
  1. # mypy: allow-untyped-defs
  2. """Defines utilities for interacting with scaled_dot_product_attention"""
  3. import math
  4. from typing import Optional
  5. import torch
  # Nothing is exported by star-import: every helper defined in this module
  # is underscore-prefixed (internal).
  __all__: list[str] = list()
  7. def _input_requires_grad(*tensors: torch.Tensor) -> bool:
  8. """Returns True if any of the tensors requires grad"""
  9. return any(t.requires_grad for t in tensors)
  10. def _postprocess_flash_output(inpt_tensor: torch.Tensor, og_size: int) -> torch.Tensor:
  11. """Handles the unpad of the last dimension"""
  12. if inpt_tensor.size(-1) != og_size:
  13. return inpt_tensor[..., :og_size]
  14. return inpt_tensor
  15. def _calculate_scale(head_dim_size: int, scale: Optional[float]) -> float:
  16. """
  17. For FlashAttention we pad the head dimension to be a multiple of 8 so we need to scale the output
  18. by the original head size and not the padded.
  19. """
  20. if scale is not None:
  21. return scale
  22. return 1.0 / math.sqrt(head_dim_size)
  23. def _validate_sdpa_input(
  24. query: torch.Tensor,
  25. key: torch.Tensor,
  26. value: torch.Tensor,
  27. attn_mask: Optional[torch.Tensor] = None,
  28. dropout_p=0.0,
  29. is_causal=False,
  30. scale=None,
  31. ):
  32. if query.dtype != key.dtype or query.dtype != value.dtype:
  33. raise ValueError(
  34. f"Expected query, key, and value to have the same dtype, "
  35. f"but got query.dtype: {query.dtype}, key.dtype: {key.dtype}, "
  36. f"and value.dtype: {value.dtype} instead."
  37. )
  38. if query.device != key.device or query.device != value.device:
  39. raise ValueError(
  40. f"Expected query, key, and value to have the same device type, "
  41. f"but got query.device: {query.device}, key.device: {key.device}, "
  42. f"and value.device: {value.device} instead."
  43. )
  44. if query.dim() < 2 or key.dim() < 2 or value.dim() < 2:
  45. raise ValueError(
  46. f"Expected query, key, and value to all be at least 2 dimensional, but got query.dim: "
  47. f"{query.dim()}, key.dim: {key.dim()} and value.dim: {value.dim()} instead."
  48. )