sdpa.py 2.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778
  1. # mypy: ignore-errors
  2. from inspect import getattr_static
  3. from typing import TYPE_CHECKING
  4. from ..bytecode_transformation import create_call_function
  5. from ..exc import Unsupported
  6. from ..source import AttrSource
  7. from .base import VariableTracker
  8. if TYPE_CHECKING:
  9. from torch._dynamo.codegen import PyCodegen
  10. from torch._dynamo.symbolic_convert import InstructionTranslator
  11. PARAM_NAMES = "query key value attn_mask dropout is_causal enable_gqa".split()
  12. class SDPAParamsVariable(VariableTracker):
  13. """Represents the c++ params struct for scaled dot product attention.
  14. This is a read-only container."""
  15. @staticmethod
  16. def create(tx: "InstructionTranslator", value, source):
  17. from torch.backends.cuda import SDPAParams
  18. from .torch import TorchInGraphFunctionVariable
  19. params = [
  20. VariableTracker.build(tx, getattr(value, p), AttrSource(source, p))
  21. for p in PARAM_NAMES
  22. ]
  23. return TorchInGraphFunctionVariable(SDPAParams).call_function(tx, params, {})
  24. def __init__(self, proxy, param_vars, **kwargs) -> None:
  25. self.proxy = proxy
  26. self.param_vars = param_vars
  27. super().__init__(**kwargs)
  28. def reconstruct(self, codegen: "PyCodegen"):
  29. assert self.source is None
  30. assert self.param_vars is not None
  31. codegen.add_push_null(
  32. lambda: codegen.load_import_from("torch._C", "_SDPAParams")
  33. )
  34. codegen.foreach(self.param_vars)
  35. codegen.extend_output(create_call_function(len(self.param_vars), False))
  36. def as_proxy(self):
  37. return self.proxy
  38. def var_getattr(self, tx: "InstructionTranslator", name: str) -> VariableTracker:
  39. import torch._C
  40. from .builder import wrap_fx_proxy
  41. from .misc import GetAttrVariable
  42. try:
  43. getattr_static(torch._C._SDPAParams, name)
  44. except AttributeError:
  45. # Using raise from is too verbose here
  46. raise Unsupported(
  47. f"Unsupported torch._C._SDPAParams attribute {name}"
  48. ) from None
  49. proxy = GetAttrVariable.create_getattr_proxy(self.as_proxy(), name)
  50. if self.source is not None:
  51. return wrap_fx_proxy(
  52. tx=tx, proxy=proxy, source=AttrSource(self.source, name)
  53. )
  54. else:
  55. return wrap_fx_proxy(tx=tx, proxy=proxy)
  56. @staticmethod
  57. def is_sdpa_params(value):
  58. from torch.backends.cuda import SDPAParams
  59. return value is SDPAParams