| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290 |
- # mypy: ignore-errors
- """
- Constant and enum variable tracking in Dynamo.
- This module is fundamental to Dynamo's ability to track and propagate constant
- values during compilation, ensuring proper handling of Python literals and
- maintaining type safety through the compilation process.
- """
- import operator
- from typing import TYPE_CHECKING
- import torch
- from torch._dynamo.source import AttrSource, GetItemSource
- from .. import graph_break_hints, variables
- from ..exc import raise_observed_exception, unimplemented_v2
- from ..utils import cmp_name_to_op_mapping, common_constant_types, istype, np
- from .base import VariableTracker
- if TYPE_CHECKING:
- from torch._dynamo.symbolic_convert import InstructionTranslator
class ConstantVariable(VariableTracker):
    """
    Variable tracker for Python literals and basic immutable types.

    Use `create()` rather than the constructor when the value may be a
    collection: it routes sets, frozensets, lists and tuples to the matching
    collection tracker and wraps the elements recursively.
    """

    @staticmethod
    def create(value, **kwargs) -> VariableTracker:
        """
        Build the tracker that represents `value`.

        Sets, frozensets, lists and tuples are routed to their collection
        trackers, with every element wrapped through `create()` as well; any
        other value is wrapped directly in a `ConstantVariable`.

        NOTE: the caller must install the proper guards if needed; most often
        the guard will be `CONSTANT_MATCH`.
        """
        if isinstance(value, set):
            wrapped = [ConstantVariable.create(elem) for elem in value]
            return variables.SetVariable(wrapped, **kwargs)

        if isinstance(value, frozenset):
            wrapped = [ConstantVariable.create(elem) for elem in value]
            return variables.FrozensetVariable(wrapped, **kwargs)

        if isinstance(value, (list, tuple)):
            parent_source = kwargs.get("source", None)
            # Each element gets its own GetItemSource when the container has
            # a source; otherwise elements are created without one.
            wrapped = [
                ConstantVariable.create(
                    elem,
                    source=GetItemSource(parent_source, idx) if parent_source else None,
                )
                for idx, elem in enumerate(value)
            ]
            return variables.BaseListVariable.cls_for(type(value))(wrapped, **kwargs)

        return ConstantVariable(value, **kwargs)

    def __init__(self, value, **kwargs) -> None:
        """Store `value`, which must be one of the `common_constant_types`."""
        super().__init__(**kwargs)
        assert ConstantVariable.is_base_literal(value), f"""
Cannot construct `ConstantVariable` for value of type {type(value)}.
This failure likely due to PyTorch-internal use of `ConstantVariable` on
non-literal python values, please try using `VariableTracker.build` instead. If
you believe it's a necessary and legitimate use case (the value is immutable and
can't easily be represented with another `VariableTracker` class), please add
its type to `common_constant_types`.
"""
        # NumPy scalar numbers are stored as the result of their `.item()`.
        is_np_number = np is not None and isinstance(value, np.number)
        self.value = value.item() if is_np_number else value

    def as_proxy(self):
        """Return the wrapped value itself; constants need no graph proxy."""
        return self.value

    def __repr__(self) -> str:
        return f"ConstantVariable({type(self.value).__name__}: {self.value!r})"

    def as_python_constant(self):
        """Return the wrapped Python value."""
        return self.value

    def is_python_constant(self):
        """A `ConstantVariable` is always a Python constant."""
        return True

    @property
    def items(self):
        """
        Need this when adding a BaseListVariable and a ConstantVariable together.
        Happens in detectron2.
        """
        return self.unpack_var_sequence(tx=None)

    def getitem_const(self, tx: "InstructionTranslator", arg: VariableTracker):
        """Index the wrapped value with the constant held by `arg`."""
        index = arg.as_python_constant()
        return ConstantVariable.create(self.value[index])

    @staticmethod
    def is_base_literal(obj):
        """Return True when the exact type of `obj` is in `common_constant_types`."""
        return type(obj) in common_constant_types

    @staticmethod
    def is_literal(obj):
        """
        Return True when `obj` is a base literal, or a list/tuple/set/
        frozenset/torch.Size whose elements are all (recursively) literals.
        """
        if type(obj) not in (list, tuple, set, frozenset, torch.Size):
            return ConstantVariable.is_base_literal(obj)
        return all(ConstantVariable.is_literal(elem) for elem in obj)

    def unpack_var_sequence(self, tx):
        """
        Wrap each element of the wrapped value via `create()`.

        A `TypeError` raised while doing so (e.g. the value is not iterable)
        is re-raised as `NotImplementedError`.
        """
        try:
            return list(map(ConstantVariable.create, self.as_python_constant()))
        except TypeError as err:
            raise NotImplementedError from err

    def const_getattr(self, tx: "InstructionTranslator", name):
        """
        Return the raw attribute `name` of the wrapped value.

        A missing attribute is reported through `raise_observed_exception`
        with `AttributeError`; callable attributes raise `NotImplementedError`.
        """
        if not hasattr(self.value, name):
            raise_observed_exception(AttributeError, tx, args=[name])
        attr = getattr(self.value, name)
        if callable(attr):
            raise NotImplementedError
        return attr

    def call_method(
        self,
        tx: "InstructionTranslator",
        name,
        args: "list[VariableTracker]",
        kwargs: "dict[str, VariableTracker]",
    ) -> "VariableTracker":
        """
        Evaluate method `name` on the wrapped constant.

        When all arguments are Python constants the call is performed directly
        on the wrapped value and the result is wrapped via `create()`.
        Anything not handled here is delegated to the base class.
        """
        # Function-scope import, as in the original (presumably to avoid an
        # import cycle with `.tensor` -- TODO confirm).
        from .tensor import SymNodeVariable

        def wrap_exc_args(exc):
            # Wrap each exception argument for `raise_observed_exception`.
            return list(map(ConstantVariable.create, exc.args))

        value_is_exact_str = istype(self.value, str)
        if name == "format" and value_is_exact_str:
            return variables.BuiltinVariable(str.format).call_function(
                tx, [self, *args], kwargs
            )
        elif name == "join" and value_is_exact_str:
            assert len(args) == 1 and len(kwargs) == 0
            pieces = args[0].force_unpack_var_sequence(tx)
            try:
                piece_values = [piece.as_python_constant() for piece in pieces]
                return ConstantVariable.create(self.value.join(piece_values))
            except NotImplementedError:
                return super().call_method(tx, name, args, kwargs)

        if any(isinstance(arg, SymNodeVariable) for arg in args):
            # Promote to SymNodeVariable for operations involving dynamic shapes.
            promoted = variables.SymNodeVariable(self.as_proxy(), self.value)
            return promoted.call_method(tx, name, args, kwargs)

        try:
            const_args = [arg.as_python_constant() for arg in args]
            const_kwargs = {
                key: val.as_python_constant() for key, val in kwargs.items()
            }
        except NotImplementedError:
            # At least one argument is not a constant.
            return super().call_method(tx, name, args, kwargs)

        if isinstance(self.value, str) and name in str.__dict__:
            bound = getattr(self.value, name)
            try:
                return ConstantVariable.create(bound(*const_args, **const_kwargs))
            except Exception as err:
                raise_observed_exception(type(err), tx)
        elif isinstance(self.value, (float, int)):
            if not args and not kwargs:
                # Zero-argument numeric method.
                try:
                    return ConstantVariable.create(getattr(self.value, name)())
                except (OverflowError, ValueError) as err:
                    raise_observed_exception(
                        type(err),
                        tx,
                        args=wrap_exc_args(err),
                    )
            if (
                hasattr(operator, name)
                and len(args) == 1
                and args[0].is_python_constant()
            ):
                operand = const_args[0]
                op_fn = getattr(operator, name)
                if isinstance(operand, (torch.SymBool, torch.SymFloat, torch.SymInt)):
                    # Addition between a non sym and sym makes a sym
                    proxy = tx.output.create_proxy(
                        "call_function", op_fn, (self.value, operand), {}
                    )
                    return SymNodeVariable.create(tx, proxy, operand)
                else:
                    try:
                        return ConstantVariable.create(op_fn(self.value, operand))
                    except Exception as err:
                        raise_observed_exception(
                            type(err), tx, args=wrap_exc_args(err)
                        )
        elif isinstance(self.value, bytes) and name == "decode":
            bound = getattr(self.value, name)
            return ConstantVariable.create(bound(*const_args, **const_kwargs))
        elif type(self.value) is complex and name in complex.__dict__:
            bound = getattr(self.value, name)
            try:
                return ConstantVariable.create(bound(*const_args, **const_kwargs))
            except Exception as err:
                raise_observed_exception(type(err), tx)

        if name == "__len__" and not args and not kwargs:
            return ConstantVariable.create(len(self.value))
        elif name == "__round__" and len(args) == 1 and args[0].is_python_constant():
            try:
                ndigits = args[0].as_python_constant()
                return ConstantVariable.create(round(self.value, ndigits))
            except Exception as err:
                raise_observed_exception(type(err), tx, args=wrap_exc_args(err))
        elif name == "__contains__" and len(args) == 1 and args[0].is_python_constant():
            assert not kwargs
            needle = args[0].as_python_constant()
            try:
                found = needle in self.value
                return ConstantVariable.create(found)
            except TypeError as err:
                raise_observed_exception(type(err), tx, args=wrap_exc_args(err))

        return super().call_method(tx, name, args, kwargs)

    def call_obj_hasattr(
        self, tx: "InstructionTranslator", name: str
    ) -> "VariableTracker":
        """Return a constant tracker for `hasattr(self.value, name)`."""
        has_attr = hasattr(self.value, name)
        return variables.ConstantVariable.create(has_attr)
class EnumVariable(VariableTracker):
    """VariableTracker for enum.Enum and enum.IntEnum instances

    Wraps a single enum member and exposes it as a Python constant, with
    attribute access forwarded to the member.
    """

    def __init__(self, value, **kwargs) -> None:
        """Store the enum member `value`."""
        super().__init__(**kwargs)
        self.value = value

    @classmethod
    def create(cls, cls_type, value_vt, options):
        """
        Build an `EnumVariable` for the member of `cls_type` whose `.value`
        equals the constant held by `value_vt`.

        If `value_vt` is not a `ConstantVariable`, or no member matches,
        `unimplemented_v2` is called to report the failure.
        """
        if isinstance(value_vt, variables.ConstantVariable):
            for candidate in list(cls_type):
                if candidate.value == value_vt.as_python_constant():
                    return cls(candidate, **options)
        unimplemented_v2(
            gb_type="Failed to construct Enum variable",
            context=f"value: {value_vt}, allowed enum values: {list(cls_type)}",
            explanation="Attempted to construct an Enum value that is non-constant (e.g. int, string) "
            "or is not an acceptable value for the Enum. "
            f"Acceptable values for Enum `{cls_type}`: {list(cls_type)}.",
            hints=[*graph_break_hints.USER_ERROR, *graph_break_hints.SUPPORTABLE],
        )

    def as_proxy(self):
        """Return the member, converted to a plain int when it is an int (IntEnum)."""
        # convert IntEnum to a normal int
        return int(self.value) if isinstance(self.value, int) else self.value

    def __repr__(self) -> str:
        return f"EnumVariable({type(self.value)})"

    def as_python_constant(self):
        """Return the wrapped enum member."""
        return self.value

    def var_getattr(self, tx: "InstructionTranslator", name):
        """
        Resolve attribute `name` on the wrapped enum member.

        Raises `NotImplementedError` when the member has no such attribute.
        Names in `cmp_name_to_op_mapping` yield a `GetAttrVariable`; every
        other attribute is wrapped through `VariableTracker.build`.
        """
        if not hasattr(self.value, name):
            raise NotImplementedError
        if name in cmp_name_to_op_mapping:
            return variables.GetAttrVariable(self, name)
        attr_value = getattr(self.value, name)
        # Falls through to the (falsy) `self.source` when there is none.
        attr_source = self.source and AttrSource(self.source, name)
        return VariableTracker.build(tx, attr_value, source=attr_source)
|