| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693694695696697698699700701702
70370470570670770870971071171271371471571671771871972072172272372472572672772872973073173273373473573673773873974074174274374474574674774874975075175275375475575675775875976076176276376476576676776876977077177277377477577677777877978078178278378478578678778878979079179279379479579679779879980080180280380480580680780880981081181281381481581681781881982082182282382482582682782882983083183283383483583683783883984084184284384484584684784884985085185285385485585685785885986086186286386486586686786886987087187287387487587687787887988088188288388488588688788888989089189289389489589689789889990090190290390490590690790890991091191291391491591691791891992092192292392492592692792892993093193293393493593693793893994094194294394494594694794894995095195295395495595695795895996096196296396496596696796896997097197297397497597697797897998098198298398498598698798898999099199299399499599699799899910001001100210031004100510061007100810091010101110121013101410151016101710181019102010211022102310241025102610271028102910301031103210331034103510361037103810391040104110421043104410451046104710481049105010511052105310541055105610571058105910601061106210631064106510661067106810691070107110721073107410751076107710781079108010811082108310841085108610871088108910901091109210931094109510961097109810991100110111021103110411051106110711081109111011111112111311141115111611171118111911201121112211231124112511261127112811291130113111321133113411351136113711381139114011411142114311441145114611471148114911501151115211531154115511561157115811591160116111621163116411651166116711681169117011711172117311741175117611771178117911801181118211831184118511861187118811891190119111921193119411951196119711981199120012011202120312041205120612071208120912101211121212131214121512161217121812191220122112221223122412251226122712281229123012311232123312341235123612371238123912401241124212431244124512461247124812491250125112521253125412551256125712581259126012611262126312641265126612671268126912701271127212731274127512761
27712781279128012811282128312841285128612871288128912901291129212931294129512961297129812991300130113021303130413051306130713081309131013111312131313141315131613171318131913201321132213231324132513261327132813291330133113321333133413351336133713381339134013411342134313441345134613471348134913501351135213531354135513561357135813591360136113621363136413651366136713681369137013711372137313741375137613771378137913801381138213831384138513861387138813891390139113921393139413951396139713981399140014011402140314041405140614071408140914101411141214131414141514161417141814191420142114221423142414251426142714281429143014311432143314341435143614371438143914401441144214431444144514461447144814491450145114521453145414551456145714581459146014611462146314641465146614671468146914701471147214731474147514761477147814791480148114821483148414851486148714881489149014911492149314941495149614971498149915001501150215031504150515061507150815091510151115121513151415151516151715181519152015211522152315241525152615271528152915301531153215331534153515361537153815391540154115421543154415451546154715481549155015511552155315541555155615571558155915601561156215631564156515661567156815691570157115721573157415751576157715781579158015811582158315841585158615871588158915901591159215931594159515961597159815991600160116021603160416051606160716081609161016111612161316141615161616171618161916201621162216231624162516261627162816291630163116321633163416351636163716381639164016411642164316441645164616471648164916501651165216531654165516561657165816591660166116621663166416651666166716681669167016711672167316741675167616771678167916801681168216831684168516861687168816891690169116921693169416951696169716981699170017011702170317041705170617071708170917101711171217131714171517161717171817191720172117221723172417251726172717281729173017311732173317341735173617371738173917401741174217431744174517461747174817491750175117521753175417551756175717581759176017611762176317641765176617671768176917701771177217731774177517761
77717781779178017811782178317841785178617871788178917901791179217931794179517961797179817991800180118021803180418051806180718081809181018111812181318141815181618171818181918201821182218231824182518261827182818291830183118321833183418351836183718381839184018411842184318441845184618471848184918501851185218531854185518561857185818591860186118621863186418651866186718681869187018711872187318741875187618771878187918801881188218831884188518861887188818891890189118921893189418951896189718981899190019011902190319041905190619071908190919101911191219131914191519161917191819191920192119221923192419251926192719281929193019311932193319341935193619371938193919401941194219431944194519461947 |
- # mypy: ignore-errors
- """
- This module contains miscellaneous variable tracker implementations for various Python types
- and features used in Dynamo's symbolic execution. These classes help track and propagate
- information about different kinds of variables during graph capture.
- Key classes include:
- - SuperVariable: Handles super() calls and method resolution
- - ExceptionVariable: Tracks exception objects
- - RandomVariable: Manages random number generators
- - GetAttrVariable: Tracks attribute access
- - MethodWrapperVariable: Handles method wrappers
- - PythonModuleVariable: Tracks Python modules
- - NumpyVariable: Handles numpy functions and types
- - StringFormatVariable: Manages string formatting
- - DebuggingVariable: Handles print and logging
- """
- import dataclasses
- import functools
- import inspect
- import itertools
- import random
- import re
- import sys
- import types
- import warnings
- from typing import Optional, TYPE_CHECKING
- import torch._C
- import torch._numpy as tnp
- import torch.utils._pytree as pytree
- from .. import config, graph_break_hints, trace_rules, variables
- from ..bytecode_transformation import create_call_function, create_instruction
- from ..create_parameter_op import do_not_convert_to_tracable_parameter
- from ..exc import raise_observed_exception, unimplemented, unimplemented_v2
- from ..guards import GuardBuilder, install_guard
- from ..mutation_guard import unpatched_nn_module_init
- from ..source import (
- AttrSource,
- GenericAttrSource,
- GetItemSource,
- TypeMROSource,
- TypeSource,
- WeakRefCallSource,
- )
- from ..utils import (
- check_unspec_or_constant_args,
- cmp_name_to_op_mapping,
- identity,
- is_tensor_base_attr_getter,
- istype,
- list_methods,
- proxy_args_kwargs,
- tuple_methods,
- )
- from .base import VariableTracker
- from .constant import ConstantVariable
- from .functions import NestedUserFunctionVariable, UserFunctionVariable
- from .user_defined import call_random_fn, is_standard_setattr, UserDefinedObjectVariable
- if TYPE_CHECKING:
- from torch._dynamo.codegen import PyCodegen
- from torch._dynamo.symbolic_convert import InstructionTranslator
class NO_SUCH_SUBOBJ:
    """Sentinel meaning "name not present in this class's ``__dict__``".

    It is passed as the default to ``dict.get`` while walking an MRO, so a
    miss can be told apart from an attribute that is actually stored with a
    value such as ``None``.
    """
class SuperVariable(VariableTracker):
    """Tracks a ``super()`` object during symbolic execution.

    Attribute lookups are resolved by walking the MRO of the bound object
    (``objvar``), starting right after ``typevar``, and reading each class's
    ``__dict__`` directly, mirroring CPython's ``super_getattro``. This lets
    Dynamo build an accurate source (and therefore accurate guards) for the
    attribute that was actually found.
    """

    _nonvar_fields = {
        *VariableTracker._nonvar_fields,
    }

    def __init__(self, typevar, objvar=None, **kwargs) -> None:
        super().__init__(**kwargs)
        # typevar is the first argument to super(). In the case where no argument
        # is provided to super(), it is the __class__ object where
        # the super() function is being called
        self.typevar = typevar
        # objvar here must be an instance or subtype of typevar.
        # In the case where super() is called without arguments, it is the first argument
        # to the current function where super() is called from (self for regular method,
        # cls for a classmethod)
        self.objvar = objvar

    def reconstruct(self, codegen: "PyCodegen"):
        """Emit bytecode that rebuilds ``super(typevar)`` or ``super(typevar, objvar)``."""
        codegen.add_push_null(lambda: codegen(variables.BuiltinVariable(super)))
        codegen(self.typevar)
        if self.objvar is not None:
            codegen(self.objvar)
            codegen.extend_output(create_call_function(2, False))
        else:
            codegen.extend_output(create_call_function(1, False))

    def _resolved_getattr_and_source(self, tx: "InstructionTranslator", name):
        """Resolve ``super(typevar, objvar).<name>`` without applying descriptors.

        Returns:
            ``(value, source)``. ``value`` is the raw ``__dict__`` entry of the
            first class after ``typevar`` in the MRO that defines ``name``.
            ``source`` describes how to reach that entry, or is ``None`` when
            no source can be built.

        Graph-breaks via ``unimplemented_v2`` if no class after ``typevar`` in
        the MRO defines ``name``.
        """
        assert self.objvar, "1-arg super not implemented"
        search_type = self.typevar.as_python_constant()

        # The rest of this function does two things:
        #   - Walk the mro to find where the attribute comes from to be
        #     able to provide accurate source
        #   - Call the getattr to get the object

        # Find the class object, where the function lives.
        # When objvar is "self", use type(self), when objvar is "cls", use it as-is
        type_to_use = self.objvar.python_type()
        type_to_use_source = (
            TypeSource(self.objvar.source) if self.objvar.source else None
        )
        if issubclass(type_to_use, type):
            type_to_use = self.objvar.value
            type_to_use_source = self.objvar.source

        source = None
        search_mro = type_to_use.__mro__

        try:
            start_index = search_mro.index(search_type) + 1
        except ValueError:
            # Corner case where the typevar is not in the mro of the objvar
            # https://github.com/python/cpython/blob/3.11/Objects/typeobject.c#L8843-L8844
            return getattr(super(search_type, type_to_use), name), None
        # Implemented based on https://github.com/python/cpython/blob/3.11/Objects/typeobject.c#L8812
        # super has its getattro implementation. The key point is that instead of calling getattr, it checks the
        # attribute in the class __dict__
        for index in range(start_index, len(search_mro)):
            # Dont call getattr, just check the __dict__ of the class.
            # Compare with the sentinel by identity ONLY: an attribute stored
            # with a falsy value (0, None, False, "") is still a hit and must
            # not fall through to classes later in the MRO.
            resolved_getattr = search_mro[index].__dict__.get(name, NO_SUCH_SUBOBJ)
            if resolved_getattr is not NO_SUCH_SUBOBJ:
                # Equivalent of something like type(L['self']).__mro__[1].attr_name
                if type_to_use_source:
                    source = AttrSource(
                        GetItemSource(TypeMROSource(type_to_use_source), index),
                        name,
                    )
                return resolved_getattr, source

        unimplemented_v2(
            gb_type="Unable to resolve super getattr",
            context="",
            explanation=f"Dynamo failed to trace attribute `{name}` accessed "
            f"via `super()` (for type `{self.typevar}` and object `{self.objvar}`) "
            "because the resolved attribute type is not supported.",
            hints=[
                "Ensure the attribute exists in the parent class.",
                "Check the arguments passed to `super()`.",
            ],
        )

    def var_getattr(self, tx: "InstructionTranslator", name: str) -> "VariableTracker":
        """Handle ``super().<name>`` used as a value rather than called.

        Literal constants are returned as a guarded ``ConstantVariable``.
        Anything else is wrapped in ``GetAttrVariable`` so the real work is
        deferred, typically to ``call_method``.
        """
        # Check if getattr is a constant. If not, delay the actual work by
        # wrapping the result in GetAttrVariable. Mostly super is called with a
        # method, so most of the work is delayed to call_function.
        #
        # We could have just implemented a const_getattr. However, super is
        # special when it comes to finding sources. Compared to other VTs, super
        # requires the attr name to walk the mro and find the actual source (and
        # not just AttrSource).
        value, source = self._resolved_getattr_and_source(tx, name)
        if not variables.ConstantVariable.is_literal(value):
            return GetAttrVariable(self, name)
        if source:
            install_guard(source.make_guard(GuardBuilder.CONSTANT_MATCH))
        return variables.ConstantVariable.create(value, source=source)

    def call_method(
        self,
        tx: "InstructionTranslator",
        name,
        args: "list[VariableTracker]",
        kwargs: "dict[str, VariableTracker]",
    ) -> "VariableTracker":
        """Trace ``super().<name>(*args, **kwargs)``.

        The raw MRO entry is resolved first. It is then dispatched on its
        kind: special-cased builtins, ``__new__``, ``staticmethod``,
        ``classmethod``, plain functions, bound methods, container methods
        of user-defined subclasses, and tensor C methods. Anything else
        graph-breaks.
        """
        inner_fn, source = self._resolved_getattr_and_source(tx, name)
        # This essentially simulates CPython's `super_getattro`:
        # https://github.com/python/cpython/blob/a1c52d1265c65bcf0d9edf87e143843ad54f9b8f/Objects/typeobject.c#L11138-L11168
        # where `inner_fn` is the VT for `res = _super_lookup_descr(...)`.
        #
        # However, `res`'s type needs to be checked for `tp_descr_get`, and
        # applied if it has one. We currently don't have polyfills for all the
        # relevant `tp_descr_get`, so we explicitly handle the cases we care
        # about here (e.g., note the staticmethod, classmethod cases).
        if inner_fn is object.__init__:
            return LambdaVariable(identity)
        elif inner_fn is torch.nn.Module.__init__:
            objvar = self.objvar
            from ..side_effects import AttributeMutationNew

            if (
                isinstance(objvar, variables.UserDefinedObjectVariable)
                and isinstance(objvar.mutation_type, AttributeMutationNew)
                and not (args or kwargs)
            ):
                with do_not_convert_to_tracable_parameter():
                    return variables.UserFunctionVariable(
                        unpatched_nn_module_init, source=source
                    ).call_function(tx, [self.objvar] + args, kwargs)
            else:
                unimplemented_v2(
                    gb_type="Unsupported super().__init__() call",
                    context=f"call_method {self} {name} {args} {kwargs}",
                    explanation="Dynamo encountered a super().__init__() call "
                    f"on {objvar} that resolved to a `torch.nn.Module.__init__()` "
                    "call that we cannot trace.",
                    hints=[*graph_break_hints.DIFFICULT],
                )
        elif (
            self.objvar.source
            and hasattr(inner_fn, "__name__")
            and inner_fn.__name__ == "__new__"
            and variables.UserDefinedClassVariable.is_supported_new_method(inner_fn)
        ):
            user_cls = inner_fn.__self__
            if hasattr(user_cls, "__module__") and user_cls.__module__ == "builtins":
                user_cls_vt = variables.BuiltinVariable(user_cls)
            else:
                # `user_cls` is `inner_fn.__self__`, so its source is that
                # attribute of the resolved `__new__`. (`source.member` would
                # be the plain string "__new__", not a Source.)
                user_cls_source = (
                    AttrSource(source, "__self__") if source is not None else None
                )
                user_cls_vt = variables.UserDefinedClassVariable(
                    user_cls, source=user_cls_source
                )
            return user_cls_vt.call_method(tx, "__new__", args, kwargs)
        elif isinstance(inner_fn, staticmethod) and isinstance(
            inner_fn.__func__, types.FunctionType
        ):
            return variables.UserFunctionVariable(
                inner_fn.__func__, source=source
            ).call_function(tx, args, kwargs)
        elif isinstance(inner_fn, classmethod) and isinstance(
            inner_fn.__func__, types.FunctionType
        ):
            if isinstance(self.objvar, variables.UserDefinedClassVariable):
                # super().classmethod is called from a classmethod itself. So,
                # super was converted to super(__class__, cls) in bytecode and
                # therefore we have to propagate the cls.
                cls_variable = self.objvar
            else:
                # current function is an instance method, therefore super was
                # converted to super(__class__, self). We have to find
                # type(self) to bind the cls to the parent classmethod.
                # Note that it can't be the self.typevar because __class__ is
                # the class where the method is defined, which could be
                # different from type(self) with polymorphism.
                cls_source = None
                if self.objvar.source:
                    cls_source = TypeSource(self.objvar.source)
                cls_variable = VariableTracker.build(
                    tx, self.objvar.value_type, cls_source
                )

            return variables.UserFunctionVariable(
                inner_fn.__func__, source=AttrSource(source, "__func__")
            ).call_function(tx, [cls_variable, *args], kwargs)
        elif isinstance(inner_fn, types.FunctionType):
            return variables.UserFunctionVariable(
                inner_fn, source=source
            ).call_function(tx, [self.objvar] + args, kwargs)
        elif isinstance(inner_fn, types.MethodType):
            return variables.UserMethodVariable(
                inner_fn.__func__, self.objvar, source=source
            ).call_function(tx, args, kwargs)
        elif is_standard_setattr(inner_fn) and isinstance(
            self.objvar, UserDefinedObjectVariable
        ):
            return self.objvar.method_setattr_standard(tx, *args, **kwargs)
        elif inner_fn is object.__delattr__:
            attr = args[0]
            try:
                attr = attr.as_python_constant()
            except NotImplementedError as exc:
                unimplemented_v2(
                    gb_type="Non-constant attribute given to `super().__delattr__()`",
                    context=f"call_method {self} {name}",
                    explanation="Dynamo requires the attribute name passed to "
                    "`super().__delattr__(...)` to be a constant (string).",
                    hints=[
                        "Ensure the attribute name is a string literal or a constant variable."
                    ],
                    from_exc=exc,
                )
            if not tx.output.side_effects.is_attribute_mutation(self.objvar):
                unimplemented_v2(
                    gb_type="Attempted super().__delattr__() on an object without mutation tracking",
                    context=f"call_method {self} {name}",
                    explanation="Dynamo needs to track mutations on an object "
                    "before `super().__delattr__` can be used on it. But the "
                    f"object ({self.objvar}) doesn't have attribute mutation "
                    "tracking enabled.",
                    hints=[
                        "Ensure the object is tracked by Dynamo's side effect system.",
                        *graph_break_hints.DYNAMO_BUG,
                    ],
                )

            tx.output.side_effects.store_attr(
                self.objvar, attr, variables.DeletedVariable()
            )
            return variables.ConstantVariable(None)
        elif (
            isinstance(self.objvar, variables.UserDefinedDictVariable)
            and inner_fn in self.objvar._dict_methods
        ):
            return self.objvar._dict_vt.call_method(tx, name, args, kwargs)
        elif (
            isinstance(self.objvar, variables.UserDefinedSetVariable)
            and inner_fn in self.objvar._set_methods
        ):
            return self.objvar._set_vt.call_method(tx, name, args, kwargs)
        elif (
            isinstance(self.objvar, variables.UserDefinedTupleVariable)
            and inner_fn in tuple_methods
        ):
            return self.objvar._tuple_vt.call_method(tx, name, args, kwargs)
        elif (
            isinstance(self.objvar, variables.UserDefinedListVariable)
            and inner_fn in list_methods
        ):
            return self.objvar._list_vt.call_method(tx, name, args, kwargs)
        elif inner_fn is object.__getattribute__:
            # object.__getattribute__ has no side-effects. We can directly call
            # __getattribute__ to access the attribute.
            attr_name = args[0].value
            if tx.output.side_effects.has_pending_mutation_of_attr(
                self.objvar, attr_name
            ):
                result = tx.output.side_effects.load_attr(
                    self.objvar, attr_name, deleted_ok=True
                )
                if isinstance(result, variables.DeletedVariable):
                    raise_observed_exception(AttributeError, tx)
                return result

            try:
                # NB - use object.__getattribute__ to prevent running any user code
                attr_value = object.__getattribute__(self.objvar.value, attr_name)
            except AttributeError:
                raise_observed_exception(AttributeError, tx)

            attr_source = None
            if self.objvar.source is not None:
                # setup a object.__getattribute__(self.objvar, name) source
                attr_source = GenericAttrSource(self.objvar.source, attr_name)
            return VariableTracker.build(tx, attr_value, attr_source)
        elif inner_fn is torch._C._disabled_torch_function_impl:
            # See `THPModule_disable_torch_function` for the C impl.
            # The signature of _disabled_torch_function_impl is similar to
            # `__torch_function__`, just without the first `cls` argument:
            #  * (func, types, args, kwargs)
            func = args[0]
            tf_kwargs = {}
            tf_args = args[2].items
            for hash_key_vt, value_vt in args[3].items.items():
                key_str = hash_key_vt.vt.as_python_constant()
                tf_kwargs[key_str] = value_vt

            tx_old = tx.symbolic_torch_function_state.torch_function_subclass_enabled
            tx.symbolic_torch_function_state.torch_function_subclass_enabled = False
            try:
                return func.call_function(tx, tf_args, tf_kwargs)
            finally:
                tx.symbolic_torch_function_state.torch_function_subclass_enabled = (
                    tx_old
                )
        elif (
            isinstance(inner_fn, types.MethodDescriptorType)
            and inner_fn in trace_rules.get_tensor_method()
        ):
            # FunctionType but implementation is in C, we support some of these,
            # e.g., tensor ops like `torch.Tensor.to`.
            fn_var = VariableTracker.build(tx, inner_fn, source)
            return fn_var.call_function(tx, [self.objvar] + args, kwargs)

        unimplemented_v2(
            gb_type="Attempted to call a super() attribute that is "
            "not a function or method",
            context=f"call_method {self} {name}",
            explanation="Dynamo does not know how to trace the call "
            f"`super().{name}()` because `super().{name}` is not a "
            "function or method attribute.",
            hints=[
                "Ensure the attribute accessed via `super()` is a standard method or function.",
            ],
        )
class ExceptionVariable(VariableTracker):
    """Tracks an instance of a builtin exception type during tracing.

    The variable models the constructor ``args`` and the exception-chaining
    attributes ``__context__``, ``__cause__``, ``__suppress_context__`` and
    ``__traceback__`` (always ``None``). ``exc_type`` must be importable from
    the ``builtins`` module, because ``reconstruct`` loads it from there.
    """

    # The ExceptionVariable corresponds to the BaseException class in Python
    def __init__(self, exc_type, args, **kwargs) -> None:
        super().__init__(**kwargs)
        self.exc_type = exc_type
        self.args = args
        # When raising a new exception while another exception is already being
        # handled, the new exception's __context__ attribute is automatically
        # set to the handled exception.
        self.__context__ = ConstantVariable(None)
        # Set when user raised an exception from another:
        # raise ... from ...
        self.__cause__ = ConstantVariable(None)
        # Boolean flag that controls whether the __context__ attribute is set
        self.__suppress_context__ = ConstantVariable(False)
        # Contains the call stack where the exception was raised. Dynamo does
        # not track traceback. So, this variable is always set to None
        self.__traceback__ = ConstantVariable(None)

    def set_context(self, context: "ExceptionVariable"):
        """Record ``context`` as this exception's ``__context__``."""
        self.__context__ = context

    def reconstruct(self, codegen: "PyCodegen"):
        """Emit bytecode that rebuilds the exception and its chaining attributes."""
        codegen.add_push_null(
            lambda: codegen.load_import_from("builtins", self.exc_type.__name__)
        )
        codegen.foreach(self.args)
        codegen.call_function(len(self.args), False)

        def codegen_attr(name: str) -> None:
            attr = getattr(self, name)
            if istype(attr, ConstantVariable):
                assert attr.value in (True, False, None), attr
                # A freshly constructed exception already has None/False here,
                # so constants normally need no store. The exception is
                # `__suppress_context__ = True` with a None cause
                # (`raise ... from None`): it is not implied by any other
                # store, so it must be written explicitly.
                if not (name == "__suppress_context__" and attr.value):
                    return
            codegen.dup_top()
            codegen(attr)
            codegen.extend_output(codegen.rot_n(2))
            codegen.store_attr(name)

        # Order matters: storing a non-None `__cause__` also flips
        # `__suppress_context__` at runtime, so the flag is written last.
        codegen_attr("__context__")
        codegen_attr("__cause__")
        codegen_attr("__suppress_context__")

    def python_type(self):
        return self.exc_type

    def call_setattr(
        self,
        tx: "InstructionTranslator",
        name_var: VariableTracker,
        val: VariableTracker,
    ):
        """Trace an assignment to one of the supported exception attributes.

        Invalid values raise an observed ``TypeError``, mirroring CPython.
        Setting ``__traceback__`` to non-None, or any other attribute,
        graph-breaks.
        """

        def raise_error(msg):
            raise_observed_exception(TypeError, tx, args=[ConstantVariable(msg)])

        name = name_var.as_python_constant()
        if name == "__context__":
            self.set_context(val)
        elif name == "__cause__":
            if (isinstance(val, ConstantVariable) and val.value is None) or isinstance(
                val,
                (
                    variables.BuiltinVariable,
                    variables.ExceptionVariable,
                    variables.UserDefinedExceptionClassVariable,
                    variables.UserDefinedExceptionObjectVariable,
                ),
            ):
                self.__cause__ = val
                self.__suppress_context__ = variables.ConstantVariable(True)
            else:
                raise_error("exception cause must be None or derive from BaseException")
        elif name == "__suppress_context__":
            # CPython accepts only real bools here (not 0/1) and otherwise
            # raises TypeError("attribute value type must be bool").
            if isinstance(val, ConstantVariable) and isinstance(val.value, bool):
                self.__suppress_context__ = val
            else:
                raise_error("attribute value type must be bool")
        elif name == "__traceback__":
            if isinstance(val, ConstantVariable) and val.value is None:
                self.__traceback__ = val
            else:
                unimplemented_v2(
                    gb_type="Set Exception object `__traceback__` attribute to not-`None`",
                    context=f"call_setattr {self} {name}",
                    explanation="Dynamo does not support setting the attribute "
                    "'__traceback__' on tracked exception objects to anything "
                    "other than None.",
                    hints=[
                        "Avoid setting '__traceback__' on exception objects "
                        "within traced code, or set it to None."
                    ],
                )
        else:
            unimplemented_v2(
                gb_type="Unsupported attribute assignment on Exception object",
                context=f"call_setattr {self} {name}",
                explanation="Dynamo does not support setting the attribute "
                f"'{name}' on tracked exception objects. Only `__context__`, "
                "`__cause__`, `__suppress_context__`, and `__traceback__` are supported.",
                hints=[*graph_break_hints.SUPPORTABLE],
            )
        return variables.ConstantVariable(None)

    def call_method(self, tx, name, args, kwargs):
        """Handle ``__setattr__`` and ``with_traceback``; defer everything else."""
        if name == "__setattr__":
            return self.call_setattr(tx, *args)
        elif name == "with_traceback":
            [tb] = args
            self.call_setattr(tx, ConstantVariable("__traceback__"), tb)
            return self
        else:
            return super().call_method(tx, name, args, kwargs)

    def var_getattr(self, tx, name):
        """Read one of the modelled exception attributes."""
        if name == "__context__":
            return self.__context__
        elif name == "__cause__":
            return self.__cause__
        elif name == "__suppress_context__":
            return self.__suppress_context__
        elif name == "__traceback__":
            return variables.ConstantVariable(None)
        elif name == "args":
            # `BaseException.args` is a tuple, and its source is `<exc>.args`.
            # The exception's own source would reconstruct the exception
            # object instead of its args.
            args_source = AttrSource(self.source, "args") if self.source else None
            return variables.TupleVariable(list(self.args), source=args_source)
        return super().var_getattr(tx, name)

    def __str__(self):
        return f"{self.__class__.__name__}({self.exc_type})"

    __repr__ = __str__
class UnknownVariable(VariableTracker):
    """Variable whose value Dynamo does not model at all ("it could be anything").

    It adds no behaviour of its own. In this module it is the base of
    ``DelayGraphBreakVariable``.
    """
class DelayGraphBreakVariable(UnknownVariable):
    """
    Used to insert a dummy variable in the stack to do the graph break at CALL_FUNCTION.

    Pushing it is harmless. The graph break only fires when the value is
    actually called.
    """

    def __init__(self, msg=None, **kwargs):
        super().__init__(**kwargs)
        # Reason reported in the graph-break explanation.
        self.msg = msg

    def call_function(
        self,
        tx: "InstructionTranslator",
        args: "list[VariableTracker]",
        kwargs: "dict[str, VariableTracker]",
    ) -> "VariableTracker":
        """Graph-break unconditionally, reporting ``self.msg`` as the reason."""
        # Tolerate a missing source. Otherwise `None.name()` would raise an
        # AttributeError instead of the intended graph break.
        fn_name = self.source.name() if self.source is not None else "<unknown>"
        unimplemented_v2(
            gb_type="Unsupported function call (delayed)",
            context=f"source: {self.source}",
            explanation="Dynamo determined that a graph break should occur "
            f"when calling `{fn_name}`. Reason: {self.msg}",
            hints=[],
        )
class ComptimeVariable(VariableTracker):
    """Stand-in for the ``comptime`` helper inside traced code.

    Calling it runs a user-supplied function right away, during Dynamo
    compilation, passing it a ``ComptimeContext`` for the current translator.
    Attribute access exposes the helpers defined on ``comptime``. It can never
    be materialized back into bytecode.
    """

    def reconstruct(self, codegen: "PyCodegen"):
        raise NotImplementedError("comptime is special form")

    def var_getattr(self, tx: "InstructionTranslator", name: str) -> "VariableTracker":
        # Lets convenience accessors such as `comptime.print_graph` be used.
        from ..comptime import comptime

        helper = getattr(comptime, name)
        return UserFunctionVariable(helper, source=AttrSource(self.source, name))

    def call_function(
        self,
        tx: "InstructionTranslator",
        args: "list[VariableTracker]",
        kwargs: "dict[str, VariableTracker]",
    ) -> "VariableTracker":
        """Run the given function at compile time and evaluate to ``None``.

        An optional second positional argument (the runtime lambda) is
        accepted and ignored. Keyword arguments are not supported.
        """
        from ..comptime import ComptimeContext

        # TODO: support an expression form as well
        assert not kwargs
        # Second argument is runtime lambda, ignored
        assert len(args) <= 2

        target = args[0]
        if isinstance(target, UserFunctionVariable):
            compile_time_fn = target.get_function()
        elif isinstance(target, NestedUserFunctionVariable):
            # Nested functions have no Python function object yet, so build
            # one by hand. Free variables cannot be bound here, so reject them.
            code = target.get_code()
            assert not target.closure, (
                "comptime function must not have free variables, "
                f"but these variables were free: {code.co_freevars}"
            )
            compile_time_fn = types.FunctionType(
                code,
                target.f_globals,
                target.fn_name.as_python_constant(),
                tuple(target.defaults.items) if target.defaults else None,
                # Free variables are deliberately NOT promoted to ComptimeVar:
                # reading a free variable whose runtime value we actually do
                # have would otherwise behave confusingly.
                (),
            )
        else:
            raise RuntimeError(f"unsupported argument to comptime: {type(target)}")

        compile_time_fn(ComptimeContext(tx))
        return variables.ConstantVariable.create(None)
class CellVariable(VariableTracker):
    """Represents a closure cell.

    The cell's content is not updated on this object when it is mutated.
    Those mutations are buffered in SideEffects, which makes this effectively
    a special case of ``UserDefinedObjectVariable``.
    """

    # Tracker for the content of a cell that already existed before tracing
    # began. It may be None, which is the constructor default.
    pre_existing_contents: Optional[VariableTracker]

    # Name through which the root frame can reach this cell with
    # LOAD_DEREF/STORE_DEREF (a name in co_cellvars/co_freevars).
    # None unless such a name exists.
    local_name: Optional[str] = None

    def __init__(
        self, pre_existing_contents: Optional[VariableTracker] = None, **kwargs
    ) -> None:
        super().__init__(**kwargs)
        self.pre_existing_contents = pre_existing_contents
class NewGlobalVariable(VariableTracker):
    """Marker variable with no state beyond what ``VariableTracker`` keeps.

    NOTE(review): the name suggests it stands for a global created during
    tracing. Confirm against the call sites.
    """

    def __init__(self, **kwargs) -> None:
        # Keyword-only construction; everything is forwarded to the base class.
        super().__init__(**kwargs)
def produce_trampoline_autograd_apply(fn_cls):
    """Build a plain function that forwards every call to ``fn_cls.apply``.

    ``fn_cls.apply`` is looked up on each call, not captured up front. The
    returned function carries an ``_origin`` attribute that points back at
    this factory.
    """

    def trampoline_autograd_apply(*args, **kwargs):
        return fn_cls.apply(*args, **kwargs)

    setattr(trampoline_autograd_apply, "_origin", produce_trampoline_autograd_apply)
    return trampoline_autograd_apply
- class AutogradFunctionVariable(VariableTracker):
- """represents a torch.autograd.Function subclass"""
- _nonvar_fields = {
- "fn_cls",
- *VariableTracker._nonvar_fields,
- }
- def __init__(self, fn_cls, **kwargs) -> None:
- super().__init__(**kwargs)
- self.fn_cls = fn_cls
- def call_apply(self, tx: "InstructionTranslator", args, kwargs):
- requires_grad = False
- def visit(vt):
- nonlocal requires_grad
- if isinstance(vt, variables.TensorVariable):
- if vt.requires_grad is not False:
- requires_grad = True
- if isinstance(vt, variables.NNModuleVariable):
- if vt.is_training(tx):
- requires_grad = True
- VariableTracker.visit(visit, (args, kwargs))
- if requires_grad and torch.is_grad_enabled():
- if config.capture_autograd_function is False:
- warnings.warn(
- "The config.capture_autograd_function flag is deprecated, it's now always true."
- )
- from torch._functorch.autograd_function import (
- autograd_function_forward_rewritten,
- )
- from torch.autograd.function import _is_setup_context_defined
- forward_fn = self.fn_cls.forward
- is_setup_ctx_defined = _is_setup_context_defined(self.fn_cls.setup_context)
- if is_setup_ctx_defined:
- # If setup_context is defined, we generate a new forward function which includes
- # the original forward and setup_context function, and trace the new forward function.
- forward_fn = autograd_function_forward_rewritten(
- self.fn_cls.forward, self.fn_cls.setup_context
- )
- vjp_fn = self.fn_cls.vjp # type: ignore[attr-defined]
- if vjp_fn is not torch.autograd.Function.vjp:
- unimplemented_v2(
- gb_type="Unsupported custom vjp",
- context=f"call_apply {self} {args} {kwargs}",
- explanation="Dynamo does not support tracing "
- "`torch.autograd.Function` subclasses that define "
- "a custom `vjp` method.",
- hints=[
- "Remove the custom `vjp` method if possible.",
- "Use standard `backward` instead if applicable.",
- *graph_break_hints.SUPPORTABLE,
- ],
- )
- jvp_fn = self.fn_cls.jvp # type: ignore[attr-defined]
- if jvp_fn is not torch.autograd.Function.jvp:
- unimplemented_v2(
- gb_type="Unsupported custom jvp",
- context=f"call_apply {self} {args} {kwargs}",
- explanation="Dynamo does not support tracing "
- "`torch.autograd.Function` subclasses that define "
- "a custom `jvp` method.",
- hints=[
- "Remove the custom `jvp` method if possible.",
- *graph_break_hints.SUPPORTABLE,
- ],
- )
- from .higher_order_ops import AutogradFunctionApplyVariable
- source = self.source
- if source is None:
- source = AttrSource(
- tx.import_source(self.fn_cls.__module__), self.fn_cls.__name__
- )
- val = AutogradFunctionApplyVariable(
- forward_fn,
- self.fn_cls.backward,
- source,
- source=AttrSource(source, member="apply"),
- ).call_function(tx, args, kwargs)
- # Inside of AutogradFunctionApplyVariable.call_function, we use sourceless variable wrapping
- # the forward function, as we don't want to generate guards for new_forward.__closure__
- # if forward is rewritten by autograd_function_forward_rewritten.
- # But we still need to generate correct guards for the original forward and setup_context
- # functions, so we have to add guards manually.
- if self.source:
- fwd_src = AttrSource(self.source, "forward")
- install_guard(fwd_src.make_guard(GuardBuilder.FUNCTION_MATCH))
- if is_setup_ctx_defined:
- setup_ctx_src = AttrSource(self.source, "setup_context")
- install_guard(setup_ctx_src.make_guard(GuardBuilder.FUNCTION_MATCH))
- return val
- if self.source:
- source = AttrSource(self.source, "forward")
- else:
- source = None
- fn = self.fn_cls.forward
- ctx = AutogradFunctionContextVariable.create(tx, args, kwargs)
- args = [ctx, *args]
- if isinstance(fn, types.FunctionType):
- sig = inspect.signature(fn)
- if len(args) - 1 == len(sig._parameters):
- args = args[1:] # Don't use context
- return variables.UserFunctionVariable(fn, source=source).call_function(
- tx, args, kwargs
- )
- elif isinstance(fn, types.MethodType):
- return variables.UserMethodVariable(
- fn.__func__,
- variables.UserDefinedClassVariable(self.fn_cls),
- source=source,
- ).call_function(tx, args, kwargs)
- else:
- unimplemented_v2(
- gb_type="Non-function or method in subclass of torch.autograd.Function",
- context=f"call_apply {self} {args} {kwargs}",
- explanation="Dynamo requires the `forward` attribute of a "
- "`torch.autograd.Function` subclass to be a standard Python "
- f"function or method. Found type `{type(fn).__name__}` instead.",
- hints=[
- "Ensure the `forward` method is defined as a regular "
- "function or instance method."
- ],
- )
def call_backward(self, tx: "InstructionTranslator", args, kwargs):
    """Inline the user-defined ``backward`` of ``self.fn_cls``.

    Args:
        tx: the active instruction translator.
        args: positional VariableTrackers; ``args[0]`` must wrap a
            ``torch._dynamo.external_utils.FakeBackwardCFunction`` (asserted).
        kwargs: keyword VariableTrackers, forwarded unchanged.

    Returns:
        The VariableTracker produced by tracing ``fn_cls.backward``.
    """
    fn = self.fn_cls.backward
    assert type(args[0].value) is torch._dynamo.external_utils.FakeBackwardCFunction
    assert isinstance(fn, types.FunctionType)
    # Mirror the None-source handling used elsewhere in this class (call_apply,
    # call_method) instead of building an AttrSource on a None base.
    fn_source = (
        AttrSource(self.source, "backward") if self.source is not None else None
    )
    return variables.UserFunctionVariable(fn, source=fn_source).call_function(
        tx, args, kwargs
    )
def call_function(self, tx: "InstructionTranslator", args, kwargs):
    """Model ``FnCls(...)``: return a fresh AutogradFunctionVariable.

    The call arguments are ignored and the new variable carries no source.
    """
    fn_cls = self.fn_cls
    instance_var = AutogradFunctionVariable(fn_cls)
    return instance_var
def call_method(
    self,
    tx: "InstructionTranslator",
    name,
    args: "list[VariableTracker]",
    kwargs: "dict[str, VariableTracker]",
):
    """Handle ``FnCls.<name>(...)`` on a torch.autograd.Function subclass.

    - ``apply``: if trace_rules allows calling ``fn_cls`` directly, record a
      single graph call to the trampoline built by
      ``produce_trampoline_autograd_apply``; otherwise trace it through
      ``call_apply``.
    - ``backward``: delegate to ``call_backward``.
    - any other name: look the attribute up statically on ``fn_cls``; trace it
      if it is a ``staticmethod`` or ``classmethod``, otherwise graph-break.
    """
    from .builder import wrap_fx_proxy

    if name == "apply":
        if trace_rules.is_callable_allowed(self.fn_cls):
            # Put the whole apply into the graph as one opaque call.
            trampoline_autograd_apply = produce_trampoline_autograd_apply(
                self.fn_cls
            )
            return wrap_fx_proxy(
                tx=tx,
                proxy=tx.output.create_proxy(
                    "call_function",
                    trampoline_autograd_apply,
                    *proxy_args_kwargs(args, kwargs),
                ),
            )
        else:
            return self.call_apply(tx, args, kwargs)

    elif name == "backward":
        return self.call_backward(tx, args, kwargs)
    else:
        source = AttrSource(self.source, name) if self.source is not None else None
        # getattr_static returns the raw descriptor so staticmethod /
        # classmethod wrappers can be distinguished below.
        try:
            obj = inspect.getattr_static(self.fn_cls, name)
        except AttributeError:
            obj = None

        if isinstance(obj, staticmethod):
            # Unwrap to the underlying function and let trace_rules pick the
            # VariableTracker class for it.
            func = obj.__get__(self.fn_cls)
            if source is not None:
                return (
                    trace_rules.lookup(func)
                    .create_with_source(func, source=source)
                    .call_function(tx, args, kwargs)
                )
            else:
                return trace_rules.lookup(func)(func).call_function(
                    tx, args, kwargs
                )
        elif isinstance(obj, classmethod):
            # Bind the classmethod's function to this class variable.
            return variables.UserMethodVariable(
                obj.__func__, self, source=source
            ).call_function(tx, args, kwargs)
        else:
            unimplemented_v2(
                gb_type="Unsupported autograd.Function method",
                context=f"call_method {self} {name}",
                explanation="Dynamo does not support calling the method "
                f"`{name}` directly on the `torch.autograd.Function` "
                "instance. Supported methods include `apply`, `backward`, "
                "static methods, and class methods.",
                hints=[
                    "Ensure the method is decorated with `@staticmethod` "
                    "or `@classmethod` if it's meant to be called on the class.",
                ],
            )
@dataclasses.dataclass
class SavedTensorBox:
    """Mutable holder for the values passed to ``ctx.save_for_backward``.

    ``AutogradFunctionContextVariable.create`` gives each new context its own
    box, and ``save_for_backward`` replaces the box's contents on every call.
    """

    # VariableTrackers from the most recent save_for_backward call; each box
    # gets its own list via default_factory.
    tensors: list[VariableTracker] = dataclasses.field(default_factory=list)
class AutogradFunctionContextVariable(UserDefinedObjectVariable):
    """
    Tracks an autograd.Function() context using mutation tracking in side_effects.py

    Models the ``ctx`` object seen by a traced autograd.Function. The only
    supported method calls are ``__setattr__``, ``save_for_backward`` and
    ``mark_non_differentiable``. Reads of ``saved_tensors`` and
    ``needs_input_grad`` are answered from state kept on this variable when
    it is available.
    """

    # Plain Python state that is not itself a VariableTracker field.
    _nonvar_fields = {
        "proxy",
        "inference",
        "saved_tensors",
        *UserDefinedObjectVariable._nonvar_fields,
    }

    def __init__(
        self,
        value,
        value_type=None,
        inference=False,
        saved_tensors=None,
        needs_input_grad=None,
        non_differentiable=None,
        **kwargs,
    ) -> None:
        super().__init__(value=value, value_type=value_type, **kwargs)
        # When False, save_for_backward also reports the call to
        # side_effects.track_save_for_backward (see call_method).
        self.inference = inference
        # SavedTensorBox holding save_for_backward values, or None.
        self.saved_tensors = saved_tensors
        # Tuple of per-argument flags computed in ``create``, or None.
        self.needs_input_grad = needs_input_grad
        # Positional-proxy part of proxy_args_kwargs(args, {}) from the last
        # mark_non_differentiable call, or None.
        self.non_differentiable = non_differentiable

    @staticmethod
    def create(tx: "InstructionTranslator", args=None, kwargs=None):
        """Allocate a fresh ``FunctionCtx`` variable tracked by side_effects.

        ``needs_input_grad`` is precomputed only when positional args are
        given and no kwargs are: one entry per arg, truthy for a
        TensorVariable whose ``requires_grad`` is set and False otherwise.
        """
        needs_input_grad = None
        if args and not kwargs:
            needs_input_grad = tuple(
                isinstance(x, variables.TensorVariable) and x.requires_grad
                for x in args
            )
        # Contexts created here are "inference" contexts with an empty box.
        out = tx.output.side_effects.track_object_new(
            None,
            torch.autograd.function.FunctionCtx,
            functools.partial(
                AutogradFunctionContextVariable,
                inference=True,
                saved_tensors=SavedTensorBox(),
                needs_input_grad=needs_input_grad,
            ),
            {},
        )
        return out

    def as_proxy(self):
        """Return the FX proxy for this ctx; graph-break if none was set."""
        if self.proxy is None:
            unimplemented_v2(
                gb_type="proxy not set",
                context=f"as_proxy {self}",
                explanation="Dynamo requires the autograd.Function context "
                "to be initialized with a proxy.",
                hints=[*graph_break_hints.DYNAMO_BUG],
            )
        return self.proxy

    def call_method(
        self,
        tx: "InstructionTranslator",
        name,
        args: "list[VariableTracker]",
        kwargs: "dict[str, VariableTracker]",
    ) -> "VariableTracker":
        """Support ``__setattr__``, ``mark_non_differentiable`` and ``save_for_backward``."""
        if name == "__setattr__":
            return super().call_method(tx, name, args, kwargs)
        elif name == "mark_non_differentiable":
            assert len(kwargs) == 0
            self.non_differentiable = proxy_args_kwargs(args, {})[0]
            return variables.ConstantVariable.create(None)

        if name != "save_for_backward":
            unimplemented_v2(
                gb_type="Unsupported autograd.Function context method",
                context=f"call_method {self} {name}",
                explanation="Dynamo does not support calling the method "
                f"`{name}` on `autograd.Function` context objects. Supported "
                "methods are `__setattr__`, `save_for_backward` and "
                "`mark_non_differentiable`.",
                hints=[*graph_break_hints.SUPPORTABLE],
            )
        if self.saved_tensors is None:
            unimplemented_v2(
                gb_type="Unsupported autograd.Function context `save_for_backward`",
                context=f"call_method {self} {name}",
                explanation="Dynamo requires the `saved_tensors` attribute "
                "to be initialized on the `autograd.Function` context object.",
                hints=[
                    "Ensure that the `saved_tensors` attribute is properly "
                    "initialized before calling `save_for_backward`. "
                    "`save_for_backward` only supported on a newly constructed `torch.autograd.function.FunctionCtx`.",
                ],
            )
        if not self.inference:
            # Contexts not built via ``create`` must have a source; the call is
            # recorded as a side effect as well.
            assert self.source and not kwargs
            tx.output.side_effects.track_save_for_backward(self, args)

        # In eager mode, multiple calls to .save_for_backward() will overwrite previous calls.
        if len(self.saved_tensors.tensors) > 0:
            self.saved_tensors.tensors = []
        for arg in args:
            self.saved_tensors.tensors.append(arg)
        return variables.ConstantVariable.create(None)

    def var_getattr(self, tx: "InstructionTranslator", name):
        """Resolve ``ctx.<name>`` using tracked state where possible."""
        if name in ["save_for_backward", "mark_non_differentiable"]:
            # Return a callable so `ctx.save_for_backward(...)` (attribute load
            # followed by a call) routes back into call_method.
            return LambdaVariable(
                lambda *args, **kwargs: self.call_method(tx, name, args, kwargs)
            )
        if name == "saved_tensors" and self.saved_tensors is not None:
            return variables.TupleVariable(list(self.saved_tensors.tensors))
        if name == "needs_input_grad":
            if self.needs_input_grad is not None:
                return variables.ConstantVariable.create(self.needs_input_grad)
            # Fall back to the real ctx value when it is reachable by source.
            if self.source:
                source = AttrSource(self.source, "needs_input_grad")
                return VariableTracker.build(tx, self.value.needs_input_grad, source)
        return super().var_getattr(tx, name)
class AutogradEngineVariable(UserDefinedObjectVariable):
    """
    Represents a torch._C._ImperativeEngine instance.

    Only ``queue_callback`` is modelled, and only while Dynamo is inside a
    compiled-autograd region. Any other use graph-breaks.
    """

    def __init__(
        self,
        value,
        value_type=None,
        **kwargs,
    ) -> None:
        super().__init__(value=value, value_type=value_type, **kwargs)

    def call_method(
        self,
        tx: "InstructionTranslator",
        name,
        args: "list[VariableTracker]",
        kwargs: "dict[str, VariableTracker]",
    ) -> "VariableTracker":
        if name != "queue_callback":
            unimplemented_v2(
                gb_type="Unsupported torch._C._ImperativeEngine method",
                context=f"call_method {self} {name}",
                explanation="Dynamo only supports the `queue_callback` method "
                f"on a torch._C._ImperativeEngine instance, but found: `{name}`.",
                hints=[],
            )
        elif not torch._dynamo.compiled_autograd.in_compiled_autograd_region:
            unimplemented_v2(
                gb_type="Unsupported torch._C._ImperativeEngine.queue_callback()",
                context=f"call_method {self} {name}",
                explanation="queue_callback() is only supported when "
                "Compiled Autograd is enabled with fullgraph=True.",
                hints=[],
            )
        else:
            assert tx.one_graph or tx.error_on_graph_break, (
                "queue_callback() is only supported when Compiled Autograd is enabled with fullgraph=True"
            )
            # Route to the fake engine's queue_callback, passing the
            # compiled-autograd final-callbacks variable as the first argument.
            fake_queue_callback = variables.UserFunctionVariable(
                torch._dynamo.external_utils.FakeCompiledAutogradEngine.queue_callback,
                source=self.source,
            )
            final_callbacks = tx.output.side_effects.get_ca_final_callbacks_var()
            return fake_queue_callback.call_function(
                tx, (final_callbacks, *args), kwargs
            )
class LambdaVariable(VariableTracker):
    """Wrap a plain Python callable that operates on VariableTrackers.

    Calling the variable forwards the positional and keyword VariableTrackers
    unchanged to the wrapped callable and returns its result.
    """

    def __init__(self, fn, **kwargs) -> None:
        super().__init__(**kwargs)
        self.fn = fn

    def call_function(
        self,
        tx: "InstructionTranslator",
        args: "list[VariableTracker]",
        kwargs: "dict[str, VariableTracker]",
    ) -> "VariableTracker":
        wrapped = self.fn
        return wrapped(*args, **kwargs)
class GetAttrVariable(VariableTracker):
    """The result of ``obj.name`` when Dynamo has not resolved it further.

    Calling it dispatches to ``obj.call_method(name, ...)``. A few ``__dict__``
    operations (``__getitem__``, ``get``, ``__contains__``, ``__setitem__``) are
    special-cased for user-defined objects/classes and nn.Modules.
    """

    _nonvar_fields = {
        "name",
        "py_type",
        *VariableTracker._nonvar_fields,
    }

    def __init__(self, obj, name, py_type=None, **kwargs) -> None:
        super().__init__(**kwargs)
        assert isinstance(obj, VariableTracker)
        assert isinstance(name, str)
        self.obj = obj
        self.name = name
        self.py_type = py_type  # In some cases we know the type (ex. tensor methods)

    def python_type(self):
        if self.py_type is not None:
            return self.py_type
        else:
            return super().python_type()

    def __repr__(self) -> str:
        return f"{self.__class__.__name__}({self.obj}, {self.name})"

    @staticmethod
    def create_getattr_proxy(base_proxy: torch.fx.Proxy, attr):
        return getattr(base_proxy, attr)

    def as_proxy(self):
        return GetAttrVariable.create_getattr_proxy(self.obj.as_proxy(), self.name)

    def as_python_constant(self):
        constant = self.obj.as_python_constant()
        try:
            return getattr(constant, self.name)
        except AttributeError:
            raise NotImplementedError(f"{self} is not a constant") from None

    def const_getattr(self, tx: "InstructionTranslator", name):
        """Statically read ``module.<self.name>.<name>`` from a real nn.Module.

        Raises:
            NotImplementedError: when the value cannot be read this way.
        """
        if not isinstance(self.obj, variables.NNModuleVariable):
            raise NotImplementedError
        step1 = tx.output.get_submodule(self.obj.module_key)
        if self.name not in step1.__dict__:
            raise NotImplementedError
        step2 = inspect.getattr_static(step1, self.name)
        # Values without an instance __dict__ (slotted or builtin objects)
        # cannot hold ``name`` there; report "not constant" rather than
        # leaking an AttributeError to callers expecting NotImplementedError.
        if name not in getattr(step2, "__dict__", {}):
            raise NotImplementedError
        return inspect.getattr_static(step2, name)

    def reconstruct(self, codegen: "PyCodegen"):
        codegen(self.obj)
        codegen.extend_output(codegen.create_load_attrs(self.name))

    def call_function(
        self,
        tx: "InstructionTranslator",
        args: "list[VariableTracker]",
        kwargs: "dict[str, VariableTracker]",
    ) -> "VariableTracker":
        return self.obj.call_method(tx, self.name, args, kwargs)

    def _is_generic_dict_owner(self):
        # Object kinds whose __dict__ can be queried via has_key_in_generic_dict.
        return isinstance(
            self.obj,
            (
                variables.UserDefinedObjectVariable,
                variables.NNModuleVariable,
                variables.UserDefinedClassVariable,
            ),
        )

    def call_method(
        self,
        tx: "InstructionTranslator",
        name,
        args: list[VariableTracker],
        kwargs: dict[str, VariableTracker],
    ) -> VariableTracker:
        if (
            name in ("__getitem__", "get")
            and self.name == "__dict__"
            and not kwargs
            # Guard: a zero-arg call must fall through rather than IndexError.
            and args
            and args[0].is_python_constant()
            and self._is_generic_dict_owner()
        ):
            obj = self.obj
            key = args[0].as_python_constant()
            if obj.has_key_in_generic_dict(tx, key):
                # redirect to var_getattr on the original obj
                return obj.var_getattr(tx, key)

            # Return the default value for get
            if name == "get":
                if len(args) == 2:
                    return args[1]
                else:
                    return variables.ConstantVariable.create(None)

        elif (
            name == "__contains__"
            and self.name == "__dict__"
            and len(args) == 1
            and args[0].is_python_constant()
            and not kwargs
            and self._is_generic_dict_owner()
        ):
            key = args[0].as_python_constant()
            return variables.ConstantVariable.create(
                bool(self.obj.has_key_in_generic_dict(tx, key))
            )

        elif name == "__setitem__" and self.name == "__dict__" and not kwargs:
            if isinstance(self.obj, variables.UserDefinedObjectVariable):
                # Bypass any custom setattr as we are updating the `__dict__` itself
                return self.obj.method_setattr_standard(
                    tx, args[0], args[1], directly_update_dict=True
                )
            if isinstance(self.obj, variables.NNModuleVariable):
                # This matches how `setattr` is handled for NNModuleVariable
                self.obj.convert_to_unspecialized(tx)

        return super().call_method(tx, name, args, kwargs)

    def get_forwarded_dict(self, tx):
        """Build a VariableTracker for a class's real ``__dict__``.

        Only valid for ``Cls.__dict__`` with no pending mutation on the class.
        Sets ``ban_mutation`` on the class variable so the forwarded dict
        stays consistent with it.
        """
        assert (
            self.name == "__dict__"
            and isinstance(self.obj, variables.UserDefinedClassVariable)
            and not tx.output.side_effects.has_pending_mutation(self.obj)
        )

        self.obj.ban_mutation = True
        return VariableTracker.build(tx, self.obj.value.__dict__, self.source)
class MethodWrapperVariable(VariableTracker):
    """Wraps a ``method-wrapper`` object (a C slot method bound to ``__self__``).

    Two call forms are modelled: tensor base attribute getters applied to a
    TensorVariable, and ``__init__`` wrappers whose type uses
    ``object.__init__``. Everything else defers to the base implementation.
    """

    def __init__(self, method_wrapper, **kwargs) -> None:
        super().__init__(**kwargs)
        self.method_wrapper = method_wrapper
        self._builtin_fns = {}

    def call_function(
        self,
        tx: "InstructionTranslator",
        args: "list[VariableTracker]",
        kwargs: "dict[str, VariableTracker]",
    ) -> "VariableTracker":
        # Guard on ``args`` so a zero-argument call falls through to the
        # generic path instead of raising IndexError inside Dynamo.
        if (
            is_tensor_base_attr_getter(self.method_wrapper)
            and args
            and isinstance(args[0], variables.TensorVariable)
        ):
            assert len(args) == 1 and len(kwargs) == 0

            return args[0].var_getattr(tx, self.method_wrapper.__self__.__name__)

        # method-wrapper variables are common in __init__ calls. For example,
        # str("foo").__init__ is a method-wrapper. These method wrappers point
        # to C functions.  Here we intercept if these method-wrappers are from
        # builtins and then call the function counterpart directly by obtaining
        # the self object.
        self_obj = self.method_wrapper.__self__
        wrapper_name = self.method_wrapper.__name__
        # TODO(dynamo-team) - We can perhaps expand the scope to more names and
        # more builtins.
        if wrapper_name == "__init__":
            fn_obj = type(self_obj).__init__
            if fn_obj is object.__init__:
                return variables.BuiltinVariable(object).call_method(
                    tx, wrapper_name, [self_obj, *args], kwargs
                )

        return super().call_function(tx, args, kwargs)

    def is_python_constant(self):
        return True

    def as_python_constant(self):
        return self.method_wrapper
class GetSetDescriptorVariable(VariableTracker):
    """Wrap a getset descriptor object and treat it as a Python constant.

    Only ``__get__`` gets a source-tracked attribute lookup, and only when
    this variable has a source. Every other attribute goes to the base class.
    """

    def __init__(self, desc, **kwargs) -> None:
        super().__init__(**kwargs)
        self.desc = desc

    def var_getattr(self, tx: "InstructionTranslator", name):
        if name != "__get__" or not self.source:
            return super().var_getattr(tx, name)
        get_source = AttrSource(self.source, "__get__")
        return VariableTracker.build(tx, self.desc.__get__, get_source)

    def is_python_constant(self):
        return True

    def as_python_constant(self):
        return self.desc
class PythonModuleVariable(VariableTracker):
    """A Python module object (``types.ModuleType``) seen during tracing.

    ``is_torch`` flags ``torch`` itself and any ``torch.*`` submodule.
    Attribute reads on those always use ``getattr`` instead of the module's
    ``__dict__``.
    """

    _nonvar_fields = {
        "value",
        "is_torch",
        *VariableTracker._nonvar_fields,
    }

    def __init__(self, value: types.ModuleType, **kwargs) -> None:
        super().__init__(**kwargs)
        self.value = value
        self.is_torch = value is torch or value.__name__.startswith("torch.")

    def python_type(self):
        return types.ModuleType

    def as_python_constant(self):
        return self.value

    def __repr__(self) -> str:
        return f"PythonModuleVariable({self.value})"

    def call_obj_hasattr(self, tx: "InstructionTranslator", name):
        return variables.ConstantVariable.create(hasattr(self.value, name))

    def var_getattr(self, tx: "InstructionTranslator", name):
        side_effects = tx.output.side_effects
        if side_effects.has_pending_mutation_of_attr(self, name):
            return side_effects.load_attr(self, name)

        if not self.is_torch and name in self.value.__dict__:
            # Ordinary module with the name present: read the namespace directly.
            attr_value = self.value.__dict__[name]
        else:
            # torch modules, or names absent from __dict__: go through getattr;
            # a miss is reported via raise_observed_exception.
            try:
                attr_value = getattr(self.value, name)
            except AttributeError:
                raise_observed_exception(AttributeError, tx)

        attr_source = AttrSource(self.source, name) if self.source else self.source
        return VariableTracker.build(tx, attr_value, attr_source)
class TypingVariable(VariableTracker):
    """An object from the ``typing`` module, e.g. ``List`` or ``List[int]``.

    Subscripting with a constant produces a new TypingVariable. Attribute
    reads honour pending side-effect mutations; otherwise they wrap the real
    attribute value, with a source when this variable has one.
    """

    def __init__(self, value, **kwargs) -> None:
        super().__init__(**kwargs)
        self.value = value

    def call_method(
        self,
        tx: "InstructionTranslator",
        name,
        args: "list[VariableTracker]",
        kwargs: "dict[str, VariableTracker]",
    ) -> "VariableTracker":
        # Create a new typing variable, e.g., `List[int]`
        if name == "__getitem__" and len(args) == 1:
            new_typing = self.value[args[0].as_python_constant()]
            return TypingVariable(new_typing)
        unimplemented("unsupported method call on typing variable")

    def var_getattr(self, tx: "InstructionTranslator", name: str):
        from .builder import SourcelessBuilder, VariableBuilder

        if name in cmp_name_to_op_mapping:
            return variables.GetAttrVariable(self, name)

        side_effects = tx.output.side_effects
        if side_effects.has_pending_mutation_of_attr(self, name):
            # Load from the same SideEffects object that reported the pending
            # mutation (as PythonModuleVariable.var_getattr does).
            return side_effects.load_attr(self, name)

        value = getattr(self.value, name)
        if self.source:
            attr_source = AttrSource(self.source, name)
            return VariableBuilder(tx, attr_source)(value)
        else:
            return SourcelessBuilder.create(tx, value)

    def as_python_constant(self):
        return self.value

    def reconstruct(self, codegen: "PyCodegen") -> None:
        # We're just trying to load the type here. Reconstructing the type from
        # scratch is tricky - for a type like `typing.List[int]` we'd need to
        # deconstruct the origin and args.  The origin for `List[int]` is `list`
        # and the args is `(int,)`. When we recombine those we get the parts
        # back and need to emit code for:
        #
        #   `typing.List[int]`
        #
        # But it's worse than that - what if `typing` isn't in the globals (or
        # was loaded like `import typing as _typing ; _typing.List[int]`?) so we
        # really need to do something like:
        #
        #   `sys.modules["typing"].List[int]`
        #
        # Argh - but what if they rewrote the global `int`?  So we have to do:
        #
        #   `sys.modules["typing"].List[sys.modules["builtins"].int]`
        #
        # But where do we get `sys`? What if they never imported it or have
        # something ELSE called `sys`?
        #
        # Let's skip all that noise and just emit it as a simple const.
        #
        codegen.append_output(codegen.create_load_const(self.value))
@functools.lru_cache(maxsize=1)
def get_np_to_tnp_map():
    """
    Build (once, then cache) a dict from numpy callables to their
    ``torch._numpy`` counterparts, walking the module pairs in
    ``NP_TO_TNP_MODULE``.
    """
    from ..utils import NP_TO_TNP_MODULE

    mapping = {}
    for np_mod, tnp_mod in NP_TO_TNP_MODULE.items():
        tnp_callables = (
            (attr_name, tnp_obj)
            for attr_name, tnp_obj in tnp_mod.__dict__.items()
            if callable(tnp_obj)
        )
        for attr_name, tnp_obj in tnp_callables:
            # tnp leaks some internal helpers that are not numpy API; those are
            # dropped because numpy has no (truthy) attribute of that name.
            np_obj = getattr(np_mod, attr_name, None)
            if np_obj:
                mapping[np_obj] = tnp_obj
    return mapping
@functools.lru_cache(maxsize=1)
def get_tnp_to_np_map():
    """
    Inverse of ``get_np_to_tnp_map()``: map each ``torch._numpy`` callable
    back to the numpy callable it stands in for (cached after the first call).
    """
    inverse = {}
    for np_obj, tnp_obj in get_np_to_tnp_map().items():
        inverse[tnp_obj] = np_obj
    return inverse
class NumpyVariable(VariableTracker):
    """
    Wrapper around `numpy.*`. Currently, is able to trace a small subset of numpy functions as well as numpy dtypes.

    A call is redirected to the matching ``torch._numpy`` function (found via
    ``get_np_to_tnp_map``) and recorded in the graph. Two kinds of function
    are instead evaluated eagerly:

    - functions that build constant collections (dtype / iinfo / finfo);
    - constant-foldable functions whose arguments all pass
      ``check_unspec_or_constant_args``.
    """

    # torch._numpy functions that may be evaluated at trace time.
    constant_fold_functions = (tnp.issubdtype,)

    def __init__(self, value, **kwargs) -> None:
        super().__init__(**kwargs)
        self.value = value

    @classmethod
    def can_constant_fold_through(cls, fn):
        # ``fn`` must be the torch._numpy function, not the numpy original.
        mod = fn.__module__.split(".")
        assert len(mod) >= 2 and mod[:2] == ["torch", "_numpy"]
        return fn in cls.constant_fold_functions

    @classmethod
    def get_constant_collection_for_func(cls, fn):
        # Returns the VariableTracker class wrapping fn's constant-collection
        # result, or None if fn does not build one.
        mod = fn.__module__.split(".")
        assert len(mod) >= 2 and mod[:2] == ["torch", "_numpy"]
        return np_constant_collections_map.get(fn, None)

    def call_function(
        self,
        tx: "InstructionTranslator",
        args: "list[VariableTracker]",
        kwargs: "dict[str, VariableTracker]",
    ) -> "VariableTracker":
        if not config.trace_numpy:
            unimplemented(f"numpy.{self.value}()")

        from ..utils import numpy_to_tensor_wrapper
        from .tensor import NumpyNdarrayVariable

        func = get_np_to_tnp_map().get(self.value)
        if func is None:
            unimplemented(
                f"Can't find numpy function {self.value} in torch._numpy. "
                " Please file an issue to request support for this function."
            )

        # We are dealing with a function that produces a const collection type (np.dtype, np.iinfo/np.finfo)
        if (
            collection_variable_typ := self.get_constant_collection_for_func(func)
        ) is not None:
            # Run the original numpy callable on constant args; non-constant
            # args raise NotImplementedError and graph-break.
            try:
                return collection_variable_typ(
                    self.value(
                        *[x.as_python_constant() for x in args],
                        **{k: v.as_python_constant() for k, v in kwargs.items()},
                    )
                )
            except NotImplementedError:
                unimplemented(
                    f"{self.value.__name__} with non-const args: {args} {kwargs}"
                )
        else:
            # With use_numpy_random_stream, numpy random functions are left to
            # NumPy itself by graph-breaking.
            if (
                func.__module__ == "torch._numpy.random"
                and config.use_numpy_random_stream
            ):
                msg = f"delegate '{func.__qualname__}' to NumPy itself via "
                msg += (
                    f"config.use_numpy_random_stream={config.use_numpy_random_stream}"
                )
                unimplemented(msg)

            args, kwargs = NumpyNdarrayVariable.patch_args(func.__name__, args, kwargs)

            if self.can_constant_fold_through(func) and (
                check_unspec_or_constant_args(args, kwargs)
            ):
                # constant fold
                # (evaluates the original numpy callable, self.value)
                return variables.ConstantVariable.create(
                    self.as_python_constant()(
                        *[x.as_python_constant() for x in args],
                        **{k: v.as_python_constant() for k, v in kwargs.items()},
                    ),
                )

            # TODO Add all the functions that go from constants to constants to can_constant_fold_through
            proxy = tx.output.create_proxy(
                "call_function",
                numpy_to_tensor_wrapper(func),
                *proxy_args_kwargs(args, kwargs),
            )
            return NumpyNdarrayVariable.create(tx, proxy)

    def call_method(
        self,
        tx: "InstructionTranslator",
        name,
        args: "list[VariableTracker]",
        kwargs: "dict[str, VariableTracker]",
    ) -> "VariableTracker":
        # No method calls on numpy module-level objects are supported.
        unimplemented("numpy")

    def as_python_constant(self):
        return self.value

    def as_proxy(self):
        if config.trace_numpy and isinstance(self.value, type):
            # This handles numpy dtype attributes such as np.float32
            # We return a string as we don't want to serialize non-PyTorch objects in the output FX graph
            # In torch/_numpy we normalize strings to their dtypes when the input is a dtype, as NumPy does
            return self.value.__name__

        return super().as_proxy()
# Placeholder for the NULL entries that Python 3.11+ function calls push onto the stack
class NullVariable(VariableTracker):
    """Stack placeholder for a NULL pushed ahead of a call.

    Reconstruction emits ``PUSH_NULL``; before Python 3.11 it graph-breaks
    instead.
    """

    def __init__(self, **kwargs) -> None:
        super().__init__(**kwargs)

    def __repr__(self) -> str:
        return "NullVariable"

    def reconstruct(self, codegen: "PyCodegen"):
        if sys.version_info >= (3, 11):
            codegen.append_output(create_instruction("PUSH_NULL"))
            return
        unimplemented("cannot reconstruct NullVariable in < Python 3.11")
class DeletedVariable(VariableTracker):
    """Marker used to implement delattr()

    Carries no state of its own. NOTE(review): presumably stored in place of
    an attribute value to record a deletion; confirm against the delattr
    handling in side_effects.
    """
class StringFormatVariable(VariableTracker):
    """
    Represents a call to str.format(), we delay calling format until after the graph.
    """

    # format_string is a plain Python str, not a VariableTracker.
    _nonvar_fields = {"format_string", *VariableTracker._nonvar_fields}

    @classmethod
    def create(cls, format_string, sym_args, sym_kwargs):
        """Format eagerly when every argument is a constant; otherwise defer."""
        if all(
            x.is_python_constant()
            for x in itertools.chain(sym_args, sym_kwargs.values())
        ):
            return variables.ConstantVariable.create(
                format_string.format(
                    *[v.as_python_constant() for v in sym_args],
                    **{k: v.as_python_constant() for k, v in sym_kwargs.items()},
                )
            )
        return cls(format_string, list(sym_args), dict(sym_kwargs))

    def __init__(self, format_string, sym_args, sym_kwargs, **kwargs) -> None:
        super().__init__(**kwargs)
        assert isinstance(format_string, str)
        self.format_string = format_string
        self.sym_args = sym_args
        self.sym_kwargs = sym_kwargs

    def __repr__(self) -> str:
        return f"{self.__class__.__name__}({self.format_string!r}, {self.sym_args!r}, {self.sym_kwargs!r})"

    def reconstruct(self, codegen: "PyCodegen"):
        # Emit `<format_string>.format` as the callable; NULL placement for a
        # CALL_FUNCTION_EX call is delegated to add_push_null.
        codegen.add_push_null(
            lambda: codegen.extend_output(
                [
                    codegen.create_load_const(self.format_string),
                    codegen.create_load_attr("format"),
                ]
            ),
            call_function_ex=True,
        )
        # Positional arguments as a tuple...
        codegen(variables.TupleVariable(self.sym_args))
        # ...then keyword arguments as a dict with constant string keys.
        kwargs = {
            variables.ConstantVariable.create(k): v for k, v in self.sym_kwargs.items()
        }
        codegen(variables.ConstDictVariable(kwargs))
        # arg=1: a keyword-argument mapping sits on the stack above the tuple.
        codegen.append_output(create_instruction("CALL_FUNCTION_EX", arg=1))
class DebuggingVariable(VariableTracker):
    """
    Represents a call to a debugging function like print(), or something
    registered to config.reorderable_logging_functions.
    """

    def __init__(self, value, **kwargs) -> None:
        super().__init__(**kwargs)
        self.value = value

    @staticmethod
    def is_reorderable_logging_function(obj):
        """True for a Python/builtin function listed in reorderable_logging_functions."""
        return (
            callable(obj)
            and isinstance(obj, (types.FunctionType, types.BuiltinFunctionType))
            and obj in torch._dynamo.config.reorderable_logging_functions
        )

    def call_function(self, tx: "InstructionTranslator", args, kwargs):
        """Record the call on ``tx.debug_locals`` instead of tracing it.

        Graph-breaks when any argument is of an unsupported kind (see
        ``can_reorder_logs``). Under export the call is dropped entirely.
        """
        if tx.export:
            # For export cases, we can just make debugging functions no-ops
            return

        if not self.can_reorder_logs(self.value, args, kwargs):
            unimplemented(
                f"Reordering debugging function {self.value} "
                f"with inputs {args} {kwargs} is not yet implemented."
            )

        # NOTE(review): only positional args are recorded here; kwargs that
        # passed can_reorder_logs are not stored. Confirm whether consumers of
        # debug_locals are meant to drop them.
        tx.debug_locals.append((self, list(args)))

    def reconstruct(self, codegen: "PyCodegen"):
        return self.source.reconstruct(codegen)

    @staticmethod
    def can_reorder_logs(fn, args, kwargs) -> bool:
        """
        Run some additional checks for what sort of function calls can we
        actually reorder.

        Returns True only if every leaf of ``args``/``kwargs`` is a tensor,
        constant, or deferred ``str.format`` variable.
        """

        allowed_input_types = (
            variables.TensorVariable,
            variables.ConstantVariable,
            StringFormatVariable,
        )

        flat_args = pytree.tree_leaves([args, kwargs])
        return all(isinstance(arg, allowed_input_types) for arg in flat_args)
class LoggingLoggerVariable(VariableTracker):
    """
    Wraps a ``logging.Logger`` so that calls to its methods can be traced.

    Under export every call is a no-op. Otherwise, methods listed in
    ``torch._dynamo.config.ignore_logger_methods`` become no-ops and any other
    method graph-breaks.
    """

    def __init__(self, value, **kwargs) -> None:
        super().__init__(**kwargs)
        self.value = value

    def call_method(
        self,
        tx: "InstructionTranslator",
        name,
        args: "list[VariableTracker]",
        kwargs: "dict[str, VariableTracker]",
    ) -> "VariableTracker":
        if tx.export:
            # For export cases, we can just make debugging functions no-ops
            return
        bound_method = getattr(self.value, name, None)
        underlying_fn = getattr(bound_method, "__func__", None)
        ignored_methods = torch._dynamo.config.ignore_logger_methods
        # Either the bound method or its underlying function may be listed.
        candidates = {bound_method, underlying_fn}
        if candidates.intersection(ignored_methods):
            return variables.ConstantVariable.create(None)
        unimplemented(
            "Logger not supported for non-export cases. "
            "To avoid graph breaks caused by logger in compile-mode, it is recommended to"
            " disable logging by adding logging methods to config.ignore_logger_methods"
        )
class ConstantLikeVariable(VariableTracker):
    """self.value is a compile-time constant, but not a literal

    Method calls are constant-propagated. Every argument must be a Python
    constant, and the real method runs on ``self.value``. The result must be
    a literal or an ``re.Match``; anything else graph-breaks.
    """

    _error_prefix = "ConstantLikeVariable"
    # numpy is optional. Without it, fall back to fresh placeholder classes so
    # the isinstance/issubclass checks in var_getattr simply never match.
    try:
        from numpy import (
            dtype as np_dtype,
            floating as np_floating,
            generic as np_generic,
        )
    except ImportError:
        np_floating = type("invalid_type", (), {})
        np_dtype = type("invalid_type", (), {})
        # Previously missing: var_getattr reads self.np_generic unconditionally.
        np_generic = type("invalid_type", (), {})

    def __init__(self, value, **kwargs) -> None:
        super().__init__(**kwargs)
        self.value = value

    def as_python_constant(self):
        return self.value

    def call_method(
        self,
        tx: "InstructionTranslator",
        name,
        args: list[VariableTracker],
        kwargs: dict[str, VariableTracker],
    ) -> VariableTracker:
        try:
            # we only support constant propagation for methods
            cargs = [x.as_python_constant() for x in args]
            ckwargs = {k: v.as_python_constant() for k, v in kwargs.items()}
        except NotImplementedError:
            unimplemented(f"{self._error_prefix}.{name}(*{args}, **{kwargs})")

        result = getattr(self.value, name)(*cargs, **ckwargs)

        if variables.ConstantVariable.is_literal(result):
            return variables.ConstantVariable.create(result)
        if isinstance(result, re.Match):
            return ConstantRegexMatchVariable(result)

        unimplemented(f"{self._error_prefix}.{name}() -> {result}")

    def var_getattr(self, tx: "InstructionTranslator", name: str) -> VariableTracker:
        result = getattr(self.value, name)
        # numpy floating scalars are converted so they become literal floats.
        if isinstance(result, self.np_floating):
            result = float(result)
        if isinstance(result, self.np_dtype):
            return NumpyDTypeVariable(result)
        if isinstance(result, type) and issubclass(result, self.np_generic):
            # things like x.dtype.type
            return NumpyVariable(result)
        if variables.ConstantVariable.is_literal(result):
            return variables.ConstantVariable.create(result)
        return GetAttrVariable(self, name)
class RegexPatternVariable(ConstantLikeVariable):
    """Constant-like wrapper for a regex pattern object (``re.Pattern``, per its error prefix)."""

    _error_prefix = "re.Pattern"


class ConstantRegexMatchVariable(ConstantLikeVariable):
    """Constant-like wrapper for an ``re.Match``.

    Produced by ConstantLikeVariable.call_method when a constant-propagated
    method call returns a match object.
    """

    _error_prefix = "re.Match"
class TorchVersionVariable(ConstantLikeVariable):
    """Constant-like wrapper for ``torch.__version__``."""

    _error_prefix = "torch.__version__"

    def __init__(self, **kwargs) -> None:
        # ``value`` defaults to torch.__version__ and, if supplied, must be that
        # exact object (identity check), so only the real version is wrapped.
        kwargs.setdefault("value", torch.__version__)
        assert kwargs["value"] is torch.__version__
        super().__init__(**kwargs)
class NumpyTypeInfoVariable(ConstantLikeVariable):
    """Constant-like wrapper for an ``np.iinfo`` / ``np.finfo`` result."""

    _error_prefix = "np.iinfo/np.finfo"
class NumpyDTypeVariable(ConstantLikeVariable):
    """Constant-like wrapper for a numpy ``dtype`` instance."""

    _error_prefix = "np.dtype[...]"

    def as_proxy(self):
        """Similar to how numpy dtype descriptors (e.g. np.float32 ) are handled by NumpyVariable:

        np.dtype() objects are serialized as strings, torch._numpy wrappers will normalize to the torch dtype.
        This also handles unsupported things nicely (i.e. structured arrays and object arrays).
        """
        return self.value.type.__name__
# torch._numpy functions whose results are constant collections, mapped to the
# VariableTracker class that wraps the eagerly computed result. Consulted by
# NumpyVariable.get_constant_collection_for_func.
np_constant_collections_map = {
    tnp.finfo: NumpyTypeInfoVariable,
    tnp.iinfo: NumpyTypeInfoVariable,
    tnp.dtype: NumpyDTypeVariable,
}
class RandomClassVariable(VariableTracker):
    """The ``random.Random`` class; calling it builds a new RandomVariable.

    At most one positional seed is accepted and keyword arguments graph-break.
    With no seed, a constant ``None`` seed is used.
    """

    def __init__(self, **kwargs) -> None:
        super().__init__(**kwargs)

    def call_function(self, tx: "InstructionTranslator", args, kwargs):
        if len(args) > 1:
            unimplemented("random.Random() with > 1 arg")
        elif kwargs:
            unimplemented("random.Random() with kwargs")
        if args:
            seed = args[0]
        else:
            seed = variables.ConstantVariable.create(None)
        new_instance = RandomVariable(
            seed=seed, mutation_type=variables.base.ValueMutationNew()
        )
        return new_instance
class RandomVariable(VariableTracker):
    """random.Random()

    Implemented by wrapping a VariableTracker around a random.Random object.
    The supported methods for the random.Random object cannot be overridden.
    Assumes that random objects behave the same given a set seed or state.

    The tracked generator is stored in ``self.random``. It is always a freshly
    constructed ``random.Random`` (see ``__init__``), so advancing it while
    tracing never touches a user-supplied generator object.
    """

    # ``random`` holds a plain ``random.Random`` instance, not a VariableTracker.
    _nonvar_fields = {
        "random",
        *VariableTracker._nonvar_fields,
    }

    # Generator methods routed through ``call_random_fn`` in ``call_method``;
    # ``seed``/``getstate``/``setstate`` are handled by dedicated branches there.
    _supported_fn_names = {
        "random",
        "randint",
        "randrange",
        "uniform",
    }

    def __init__(
        self,
        rand: Optional[random.Random] = None,
        seed: Optional[VariableTracker] = None,
        **kwargs,
    ) -> None:
        """Create the tracked generator.

        Args:
            rand: existing generator whose state is copied into a fresh
                ``random.Random``. It must pass ``is_supported_random_obj``.
            seed: used only when ``rand`` is None. It must be convertible via
                ``as_python_constant()``; ``None`` yields ``random.Random(None)``.
        """
        super().__init__(**kwargs)
        if rand is not None:
            assert self.is_supported_random_obj(rand)
            # Copy the state rather than aliasing ``rand``.
            self.random = random.Random()
            self.random.setstate(rand.getstate())
        else:
            seed = seed.as_python_constant() if seed is not None else None
            self.random = random.Random(seed)

    def python_type(self):
        return random.Random

    def as_python_constant(self):
        return self.random

    @staticmethod
    def is_supported_random_obj(val):
        """Return True if ``val`` is exactly a ``random.Random``.

        The supported methods, plus ``seed``/``getstate``/``setstate``, must
        also resolve to the unmodified ``random.Random`` implementations.
        """
        if type(val) is not random.Random:
            return False
        for name in itertools.chain(
            RandomVariable._supported_fn_names, ("seed", "getstate", "setstate")
        ):
            if not hasattr(val, name):
                return False
            meth = getattr(val, name)
            if inspect.isbuiltin(meth):
                # e.g. random.Random.random (C-implemented): compare with the
                # class attribute bound to ``val``.
                if meth != getattr(random.Random, name).__get__(val):
                    return False
            else:
                # Python-level method: reject anything that is not the class's function.
                if getattr(meth, "__func__", None) is not getattr(random.Random, name):
                    return False
        return True

    @staticmethod
    def check_state(state):
        """Assert ``state`` has the 3-tuple layout handled by wrap/unwrap_state.

        Expected layout: ``(int, tuple of ints, float or None)``.
        """
        assert type(state) is tuple
        # Guard the indexing below against truncated or padded tuples.
        assert len(state) == 3
        assert type(state[0]) is int
        assert type(state[1]) is tuple
        assert all(type(x) is int for x in state[1])
        assert state[2] is None or type(state[2]) is float

    @staticmethod
    def wrap_state(state):
        """Convert a ``random.Random.getstate()`` tuple into a nested TupleVariable."""
        RandomVariable.check_state(state)
        return variables.TupleVariable(
            [
                variables.ConstantVariable.create(state[0]),
                variables.TupleVariable(
                    [variables.ConstantVariable.create(x) for x in state[1]]
                ),
                variables.ConstantVariable.create(state[2]),
            ]
        )

    @staticmethod
    def unwrap_state(state):
        """Convert a tracked state VariableTracker back into a validated Python tuple."""
        state_obj = state.as_python_constant()
        RandomVariable.check_state(state_obj)
        return state_obj

    def call_method(
        self,
        tx: "InstructionTranslator",
        name,
        args: list[VariableTracker],
        kwargs: dict[str, VariableTracker],
    ) -> VariableTracker:
        """Model ``seed``/``getstate``/``setstate`` and the supported generator methods.

        Anything else, and ``setstate`` called with an unexpected signature, is
        delegated to ``VariableTracker.call_method``.
        """
        if name == "seed":
            tx.output.side_effects.mutation(self)
            self.random.seed(
                *[x.as_python_constant() for x in args],
                **{key: val.as_python_constant() for key, val in kwargs.items()},
            )
            return variables.ConstantVariable.create(None)
        elif name == "getstate":
            return self.wrap_state(self.random.getstate())
        elif name == "setstate":
            if len(args) != 1 or kwargs:
                # Previously ``args[0]`` raised IndexError here (and extra kwargs
                # were silently dropped); defer to the generic handling instead.
                return super().call_method(tx, name, args, kwargs)
            tx.output.side_effects.mutation(self)
            self.random.setstate(self.unwrap_state(args[0]))
            return variables.ConstantVariable.create(None)
        elif name in self._supported_fn_names:
            tx.output.side_effects.mutation(self)
            # Snapshot taken *before* advancing self.random, so the replay below
            # produces the same value the advance consumes.
            state = self.random.getstate()

            def call_random_meth(*args, **kwargs):
                r = random.Random()
                r.setstate(state)
                return getattr(r, name)(*args, **kwargs)

            # self.random state not actually updated by call_random_meth, so update here
            # by calling the method
            getattr(self.random, name)(
                *[x.as_python_constant() for x in args],
                **{k: v.as_python_constant() for k, v in kwargs.items()},
            )

            return call_random_fn(tx, call_random_meth, args, kwargs)
        return super().call_method(tx, name, args, kwargs)

    def reconstruct(self, codegen: "PyCodegen"):
        """Emit bytecode that rebuilds this generator.

        The emitted sequence is ``random.Random()``, then a duplicate of it,
        then ``.setstate(<current state>)``. The call's return value is popped,
        which leaves the new generator on the stack.
        """
        codegen.add_push_null(
            lambda: codegen.extend_output(
                [
                    codegen.create_load_python_module(random),
                    codegen.create_load_attr("Random"),
                ]
            )
        )
        codegen.call_function(0, False)
        # NOTE using add_push_null may result in NULL being duplicated
        # so defer the push_null to call_function
        codegen.dup_top()
        codegen.load_attr("setstate")
        codegen(self.wrap_state(self.random.getstate()))
        codegen.call_function(1, True)
        # Discard setstate's return value.
        codegen.pop_top()
class WeakRefVariable(VariableTracker):
    """Tracker for a ``weakref.ref``, stored as its referent plus its callback.

    Calling the weakref yields the tracked referent. ``reconstruct`` rebuilds
    ``weakref.ref(referent, callback)``.
    """

    @staticmethod
    def build(tx, weakref_value, **options):
        """Build a WeakRefVariable from the concrete weakref ``weakref_value``.

        The callback and referent are wrapped eagerly via ``VariableTracker.build``.
        When ``options`` carries a ``source`` for the weakref, the callback is
        given ``AttrSource(source, "__callback__")`` and the referent is given
        ``WeakRefCallSource(source)``. The returned WeakRefVariable keeps the
        weakref's own ``source``.
        """
        source = options.get("source", None)
        callback = weakref_value.__callback__
        callback_source = source and AttrSource(source, "__callback__")
        callback_vt = VariableTracker.build(tx, callback, callback_source)
        referent = weakref_value()
        referent_source = source and WeakRefCallSource(source)
        referent_vt = VariableTracker.build(tx, referent, referent_source)
        # Previously options["source"] was overwritten with the referent's
        # WeakRefCallSource, so this tracker's source described ``ref()``
        # (the referent) instead of ``ref`` itself.
        options["source"] = source
        return WeakRefVariable(referent_vt, callback_vt, **options)

    def __init__(self, referent_vt, callback_vt, **options):
        super().__init__(**options)
        # Tracker for the object the weakref pointed to when it was built.
        self.referent_vt = referent_vt
        # Tracker for the weakref's ``__callback__`` (may wrap None).
        self.callback_vt = callback_vt

    def call_function(
        self,
        tx: "InstructionTranslator",
        args: "list[VariableTracker]",
        kwargs: "dict[str, VariableTracker]",
    ) -> "VariableTracker":
        """Model ``ref()`` by returning the tracked referent."""
        if args or kwargs:
            # A weakref takes no call arguments; do not silently drop them.
            unimplemented("weakref.ref() called with arguments")
        return self.referent_vt

    def reconstruct(self, codegen: "PyCodegen"):
        """Emit ``weakref.ref(<referent>, <callback>)``."""
        codegen.add_push_null(lambda: codegen.load_import_from("weakref", "ref"))
        codegen(self.referent_vt)
        codegen(self.callback_vt)
        codegen.extend_output(create_call_function(2, False))
|