| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693694695696697698699700701702
70370470570670770870971071171271371471571671771871972072172272372472572672772872973073173273373473573673773873974074174274374474574674774874975075175275375475575675775875976076176276376476576676776876977077177277377477577677777877978078178278378478578678778878979079179279379479579679779879980080180280380480580680780880981081181281381481581681781881982082182282382482582682782882983083183283383483583683783883984084184284384484584684784884985085185285385485585685785885986086186286386486586686786886987087187287387487587687787887988088188288388488588688788888989089189289389489589689789889990090190290390490590690790890991091191291391491591691791891992092192292392492592692792892993093193293393493593693793893994094194294394494594694794894995095195295395495595695795895996096196296396496596696796896997097197297397497597697797897998098198298398498598698798898999099199299399499599699799899910001001100210031004100510061007100810091010101110121013101410151016101710181019102010211022102310241025102610271028102910301031103210331034103510361037103810391040104110421043104410451046104710481049105010511052105310541055105610571058105910601061106210631064106510661067106810691070107110721073107410751076107710781079108010811082108310841085108610871088108910901091109210931094109510961097109810991100110111021103110411051106110711081109111011111112111311141115111611171118111911201121112211231124112511261127112811291130113111321133113411351136113711381139114011411142114311441145114611471148114911501151115211531154115511561157115811591160116111621163116411651166116711681169117011711172117311741175117611771178117911801181118211831184118511861187118811891190119111921193119411951196119711981199120012011202120312041205120612071208120912101211121212131214121512161217121812191220122112221223122412251226122712281229123012311232123312341235123612371238123912401241124212431244124512461247124812491250125112521253125412551256125712581259126012611262126312641265126612671268126912701271127212731274127512761
27712781279128012811282128312841285128612871288128912901291129212931294129512961297129812991300130113021303130413051306130713081309131013111312131313141315131613171318131913201321132213231324132513261327132813291330133113321333133413351336133713381339134013411342134313441345134613471348134913501351135213531354135513561357135813591360136113621363136413651366136713681369137013711372137313741375137613771378137913801381138213831384138513861387138813891390139113921393139413951396139713981399140014011402140314041405140614071408140914101411141214131414141514161417141814191420142114221423142414251426142714281429143014311432143314341435143614371438143914401441144214431444144514461447144814491450145114521453145414551456145714581459146014611462146314641465146614671468146914701471147214731474147514761477147814791480148114821483148414851486148714881489149014911492149314941495149614971498149915001501150215031504150515061507150815091510151115121513151415151516151715181519152015211522152315241525152615271528152915301531153215331534153515361537153815391540154115421543154415451546154715481549155015511552155315541555155615571558155915601561156215631564156515661567156815691570157115721573157415751576157715781579158015811582158315841585158615871588158915901591159215931594159515961597159815991600160116021603160416051606160716081609161016111612161316141615161616171618161916201621162216231624162516261627162816291630163116321633163416351636163716381639164016411642164316441645164616471648164916501651165216531654165516561657165816591660166116621663166416651666166716681669167016711672167316741675167616771678167916801681168216831684168516861687168816891690169116921693169416951696169716981699170017011702170317041705170617071708170917101711171217131714171517161717171817191720172117221723172417251726172717281729173017311732173317341735173617371738173917401741174217431744174517461747174817491750175117521753175417551756175717581759176017611762176317641765176617671768176917701771177217731774177517761
77717781779178017811782178317841785178617871788178917901791179217931794179517961797179817991800180118021803180418051806180718081809181018111812181318141815181618171818181918201821182218231824182518261827182818291830183118321833183418351836183718381839184018411842184318441845184618471848184918501851185218531854185518561857185818591860186118621863186418651866186718681869187018711872187318741875187618771878187918801881188218831884188518861887188818891890189118921893189418951896189718981899190019011902190319041905190619071908190919101911191219131914191519161917191819191920192119221923192419251926192719281929193019311932193319341935193619371938193919401941194219431944194519461947194819491950195119521953195419551956195719581959196019611962196319641965196619671968196919701971197219731974197519761977197819791980198119821983198419851986198719881989199019911992199319941995199619971998199920002001200220032004200520062007200820092010201120122013201420152016201720182019202020212022202320242025202620272028202920302031203220332034203520362037203820392040204120422043204420452046204720482049205020512052205320542055205620572058205920602061206220632064206520662067206820692070207120722073207420752076207720782079208020812082208320842085208620872088208920902091209220932094209520962097209820992100210121022103210421052106210721082109211021112112211321142115211621172118211921202121212221232124212521262127212821292130213121322133213421352136213721382139214021412142214321442145214621472148214921502151215221532154215521562157215821592160216121622163216421652166216721682169217021712172217321742175217621772178217921802181218221832184218521862187218821892190219121922193219421952196219721982199220022012202220322042205220622072208220922102211221222132214221522162217221822192220222122222223222422252226222722282229223022312232223322342235223622372238223922402241224222432244224522462247224822492250225122522253225422552256225722582259226022612262226322642265226622672268226922702271227222732274227522762
27722782279228022812282228322842285228622872288228922902291229222932294229522962297229822992300230123022303230423052306230723082309231023112312231323142315231623172318231923202321232223232324232523262327232823292330233123322333233423352336233723382339234023412342234323442345234623472348234923502351235223532354235523562357235823592360236123622363236423652366236723682369237023712372237323742375237623772378237923802381238223832384238523862387238823892390239123922393239423952396239723982399240024012402240324042405240624072408240924102411241224132414241524162417241824192420242124222423242424252426242724282429243024312432243324342435243624372438243924402441244224432444244524462447244824492450245124522453245424552456245724582459246024612462246324642465246624672468246924702471247224732474247524762477247824792480248124822483248424852486248724882489249024912492249324942495249624972498249925002501250225032504250525062507250825092510251125122513251425152516251725182519252025212522252325242525252625272528252925302531253225332534253525362537253825392540254125422543254425452546254725482549255025512552255325542555255625572558255925602561256225632564256525662567256825692570257125722573257425752576257725782579258025812582258325842585258625872588258925902591259225932594259525962597259825992600260126022603260426052606260726082609261026112612261326142615261626172618261926202621262226232624262526262627262826292630263126322633263426352636263726382639264026412642264326442645264626472648264926502651265226532654265526562657265826592660266126622663266426652666266726682669267026712672267326742675267626772678267926802681268226832684268526862687268826892690269126922693269426952696269726982699270027012702270327042705270627072708270927102711271227132714271527162717271827192720272127222723272427252726272727282729273027312732273327342735273627372738273927402741274227432744274527462747274827492750275127522753275427552756275727582759276027612762276327642765276627672768276927702771277227732774277527762
77727782779278027812782278327842785278627872788278927902791279227932794279527962797279827992800280128022803280428052806280728082809281028112812281328142815281628172818281928202821282228232824282528262827282828292830283128322833283428352836283728382839284028412842284328442845284628472848284928502851285228532854285528562857285828592860286128622863286428652866286728682869287028712872287328742875287628772878287928802881288228832884288528862887288828892890289128922893289428952896289728982899290029012902290329042905290629072908290929102911291229132914291529162917291829192920292129222923292429252926292729282929293029312932293329342935293629372938293929402941294229432944294529462947294829492950295129522953295429552956295729582959296029612962296329642965296629672968296929702971297229732974297529762977297829792980298129822983298429852986298729882989299029912992299329942995299629972998299930003001300230033004300530063007300830093010301130123013301430153016301730183019302030213022302330243025302630273028302930303031303230333034303530363037303830393040304130423043304430453046304730483049305030513052305330543055305630573058305930603061306230633064306530663067306830693070307130723073307430753076307730783079308030813082308330843085308630873088308930903091309230933094309530963097309830993100310131023103310431053106310731083109311031113112311331143115311631173118311931203121312231233124312531263127312831293130313131323133313431353136313731383139314031413142314331443145314631473148314931503151315231533154315531563157315831593160316131623163316431653166316731683169317031713172317331743175317631773178317931803181318231833184318531863187318831893190319131923193319431953196319731983199320032013202320332043205320632073208320932103211321232133214321532163217321832193220322132223223322432253226322732283229323032313232323332343235323632373238323932403241324232433244324532463247324832493250325132523253325432553256325732583259326032613262326332643265326632673268326932703271327232733274327532763
27732783279328032813282328332843285328632873288328932903291329232933294329532963297329832993300330133023303330433053306330733083309331033113312331333143315331633173318331933203321332233233324332533263327332833293330333133323333333433353336333733383339334033413342334333443345334633473348334933503351335233533354335533563357335833593360336133623363336433653366336733683369337033713372337333743375337633773378337933803381338233833384338533863387338833893390339133923393339433953396339733983399340034013402340334043405340634073408340934103411341234133414341534163417341834193420342134223423342434253426342734283429343034313432343334343435343634373438343934403441344234433444344534463447344834493450345134523453345434553456345734583459346034613462346334643465346634673468346934703471347234733474347534763477347834793480348134823483348434853486348734883489349034913492349334943495349634973498349935003501350235033504350535063507350835093510351135123513351435153516351735183519352035213522352335243525352635273528352935303531353235333534353535363537353835393540354135423543354435453546354735483549355035513552355335543555355635573558355935603561356235633564356535663567356835693570357135723573357435753576357735783579358035813582358335843585358635873588358935903591359235933594359535963597359835993600360136023603360436053606360736083609361036113612361336143615361636173618361936203621362236233624362536263627362836293630363136323633363436353636363736383639364036413642364336443645364636473648364936503651365236533654365536563657365836593660366136623663366436653666366736683669367036713672367336743675367636773678367936803681368236833684368536863687368836893690369136923693369436953696369736983699370037013702370337043705370637073708370937103711371237133714371537163717371837193720372137223723372437253726372737283729373037313732373337343735373637373738373937403741374237433744374537463747374837493750375137523753375437553756375737583759376037613762376337643765 |
- # mypy: ignore-errors
- """
- This module contains classes and utilities for building variable trackers in Dynamo.
- Variable trackers are used to convert Python values into symbolic representations
- that can be traced and transformed during graph capture.
- The key classes are:
- - VariableBuilder: Handles source-tracked objects that need guards and proper
- reconstruction in the output graph. Used for inputs, module attributes, etc.
- - SourcelessBuilder: Handles ephemeral objects created during tracing that don't
- need source tracking or guards. Used for temporary lists, intermediate values, etc.
- Variable trackers enable Dynamo to track the flow of values through the program,
- maintain guards for dynamic properties, and reconstruct values in the output graph.
- The builders in this module handle converting Python values into appropriate
- VariableTracker instances based on their type and usage context.
- """
- import abc
- import collections
- import contextlib
- import copy
- import dataclasses
- import enum
- import functools
- import inspect
- import itertools
- import logging
- import math
- import operator
- import random
- import re
- import sys
- import traceback
- import types
- import weakref
- from collections.abc import MutableMapping
- from typing import Any, Callable, NamedTuple, Optional, TYPE_CHECKING, Union
- import sympy
- import torch
- from torch import SymInt
- from torch._dispatch.python import enable_python_dispatcher
- from torch._dynamo.utils import (
- get_metrics_context,
- is_int_specialization_case,
- is_torch_sym,
- set_feature_use,
- )
- from torch._guards import TracingContext
- from torch._higher_order_ops.flat_apply import flat_apply
- from torch._higher_order_ops.torchbind import call_torchbind
- from torch._ops import HigherOrderOperator
- from torch._subclasses.fake_tensor import FakeTensor, is_fake, maybe_get_fake_mode
- from torch._subclasses.meta_utils import is_sparse_any, safe_grad
- from torch._utils_internal import justknobs_check
- from torch.fx.experimental._backward_state import BackwardState
- from torch.fx.experimental._dynamism import normalize_source_name
- from torch.fx.experimental.symbolic_shapes import (
- _constrain_range_for_size,
- _nested_int_aware_sort,
- DimDynamic,
- RelaxedUnspecConstraint,
- StatefulSymbolicContext,
- SubclassSymbolicContext,
- SymbolicContext,
- SymIntSymbolicContext,
- TrackedFake,
- )
- from torch.fx.immutable_collections import immutable_dict, immutable_list
- from torch.nn.utils._expanded_weights import ExpandedWeight
- from torch.utils._python_dispatch import (
- is_traceable_wrapper_subclass,
- is_traceable_wrapper_subclass_type,
- )
- from torch.utils._sympy.value_ranges import ValueRanges
- from torch.utils.weak import TensorWeakRef
- from .. import config, graph_break_hints, mutation_guard, replay_record, trace_rules
- from ..device_interface import get_registered_device_interfaces
- from ..exc import InternalTorchDynamoError, raise_observed_exception, unimplemented_v2
- from ..guards import GuardBuilder, install_guard, make_dupe_guard
- from ..pgo import (
- auto_dynamic,
- auto_unset,
- FrameStateSizeEntry,
- InferStride,
- process_automatic_dynamic,
- )
- from ..side_effects import SideEffects
- from ..source import (
- AttrProxySource,
- AttrSource,
- CallMethodItemSource,
- ChainedSource,
- ConstDictKeySource,
- ConvertIntSource,
- DictGetItemSource,
- DictSubclassGetItemSource,
- FloatTensorSource,
- GetItemSource,
- GradSource,
- is_constant_source,
- is_from_closure_source,
- is_from_global_source,
- is_from_nonlocal_source,
- is_from_optimizer_source,
- is_from_unspecialized_nn_module_source,
- ListGetItemSource,
- LocalSource,
- NonSerializableSetGetItemSource,
- NumpyTensorSource,
- OptimizerSource,
- RandomValueSource,
- Source,
- SubclassAttrListSource,
- TupleIteratorGetItemSource,
- UnspecializedBuiltinNNModuleSource,
- UnspecializedNNModuleSource,
- )
- from ..utils import (
- _extract_tensor_dict,
- build_checkpoint_variable,
- build_invoke_subgraph_variable,
- clone_input,
- common_constant_types,
- dict_keys,
- get_fake_value,
- get_items_from_dict,
- get_locals_to_steal,
- get_static_address_type,
- is_frozen_dataclass,
- is_function,
- is_function_or_wrapper,
- is_invoke_subgraph,
- is_lru_cache_wrapped_function,
- is_namedtuple,
- is_parameter_freezing,
- is_typing,
- is_utils_checkpoint,
- is_wrapper_or_member_descriptor,
- istype,
- namedtuple_fields,
- odict_values,
- proxy_args_kwargs,
- range_iterator,
- set_example_value,
- tensor_always_has_static_shape,
- tuple_iterator,
- tuple_iterator_getitem,
- tuple_iterator_len,
- unwrap_with_attr_name_if_wrapper,
- wrap_fake_exception,
- )
- from .base import (
- AttributeMutationNew,
- typestr,
- ValueMutationExisting,
- ValueMutationNew,
- VariableTracker,
- VariableTrackerMeta,
- )
- from .builtin import BuiltinVariable
- from .constant import ConstantVariable, EnumVariable
- from .ctx_manager import (
- AutocastModeVariable,
- DynamoConfigPatchVariable,
- ErrorOnGraphBreakVariable,
- EventVariable,
- NullContextVariable,
- PreserveVersionContextVariable,
- StreamContextVariable,
- StreamVariable,
- )
- from .dicts import (
- ConstDictVariable,
- DefaultDictVariable,
- DictKeySetVariable,
- FrozensetVariable,
- MappingProxyVariable,
- SetVariable,
- )
- from .distributed import (
- DeviceMeshVariable,
- PlacementClassVariable,
- PlacementVariable,
- ProcessGroupVariable,
- WorldMetaClassVariable,
- )
- from .functions import (
- BuiltinMethodVariable,
- CollectionsNamedTupleFunction,
- CollectiveFunctionRewriteVariable,
- CreateTMADescriptorExperimentalVariable,
- CreateTMADescriptorStableVariable,
- FunctoolsPartialVariable,
- FunctoolsWrapsVariable,
- SysFunctionVariable,
- TracebackVariable,
- TritonKernelVariable,
- UserFunctionVariable,
- UserMethodVariable,
- WrapperUserFunctionVariable,
- )
- from .higher_order_ops import TorchHigherOrderOperatorVariable
- from .iter import ItertoolsVariable
- from .lazy import LazyVariableTracker
- from .lists import (
- BaseListVariable,
- ListIteratorVariable,
- ListVariable,
- NamedTupleVariable,
- RangeVariable,
- SizeVariable,
- SliceVariable,
- TupleIteratorVariable,
- TupleVariable,
- )
- from .misc import (
- AutogradEngineVariable,
- AutogradFunctionContextVariable,
- AutogradFunctionVariable,
- ComptimeVariable,
- DebuggingVariable,
- DelayGraphBreakVariable,
- GetAttrVariable,
- GetSetDescriptorVariable,
- LambdaVariable,
- LoggingLoggerVariable,
- MethodWrapperVariable,
- NumpyDTypeVariable,
- NumpyTypeInfoVariable,
- NumpyVariable,
- PythonModuleVariable,
- RandomClassVariable,
- RandomVariable,
- RegexPatternVariable,
- SavedTensorBox,
- TorchVersionVariable,
- TypingVariable,
- WeakRefVariable,
- )
- from .nn_module import (
- FSDPManagedNNModuleVariable,
- UnspecializedBuiltinNNModuleVariable,
- UnspecializedNNModuleVariable,
- )
- from .optimizer import OptimizerVariable
- from .script_object import TorchScriptObjectVariable
- from .sdpa import SDPAParamsVariable
- from .tensor import (
- NumpyNdarrayVariable,
- supported_const_comparison_op_values,
- SymNodeVariable,
- TensorSubclassVariable,
- TensorVariable,
- UnspecializedPythonVariable,
- )
- from .torch import (
- DispatchKeySetVariable,
- FuncTorchInterpreterVariable,
- TorchCtxManagerClassVariable,
- TorchInGraphFunctionVariable,
- )
- from .torch_function import (
- TensorWithTFOverrideVariable,
- torch_function_mode_stack_state_mgr,
- TorchFunctionModeVariable,
- )
- from .user_defined import (
- FrozenDataClassVariable,
- IntWrapperVariable,
- KeyedJaggedTensorVariable,
- MutableMappingVariable,
- SourcelessGraphModuleVariable,
- UserDefinedClassVariable,
- UserDefinedDictVariable,
- UserDefinedExceptionClassVariable,
- UserDefinedListVariable,
- UserDefinedObjectVariable,
- UserDefinedSetVariable,
- UserDefinedTupleVariable,
- )
- try:
- import numpy as np
- except ModuleNotFoundError:
- np = None
- if TYPE_CHECKING:
- from torch._dynamo.codegen import PyCodegen
- from torch._dynamo.symbolic_convert import InstructionTranslator
- log = logging.getLogger(__name__)
- static_inputs_log = torch._logging.getArtifactLogger(
- __name__, "cudagraph_static_inputs"
- )
- DimList = list
- def safe_has_grad(t):
- with torch._logging.hide_warnings(torch._logging._internal.safe_grad_filter):
- return hasattr(t, "grad")
class _missing:
    """Empty marker class; it defines no attributes or behavior of its own."""
- @dataclasses.dataclass
- class GraphArg:
- source: Source
- # TODO: storing a SymInt here but not a FakeTensor is a pretty strange
- # thing to do. Probably should have example (which stores an int) and
- # fake_example
- _example: Union[TensorWeakRef, torch.SymInt]
- # When True, this indicates that this GraphArg is a Python quantity (e.g.,
- # a float or int) which we pass to the FX graph as a Tensor. This
- # controls how we codegen calls into the Dynamo graph: we will call
- # torch.as_tensor on the quantity before passing it in.
- #
- # Note that we typically do not pass dynamic integers as tensors, because
- # they will most frequently just be used for size computation. But this
- # is a policy decision that we can change our mind on; in particular, when
- # an int comes from a random number generator (e.g., random.randint), we
- # DO pass it as a tensor.
- #
- # It's also worth noting that our current tracing rules for
- # pass_arg_as_tensor as subtly broken: we just pun the variable as a
- # 0d scalar Tensor and pray that the semantics are the same. Which they
- # often are, but not necessarily. ezyang(May 2024) plans to fix this
- # soon.
- pass_arg_as_tensor: bool
- fake_tensor: Optional[torch._subclasses.fake_tensor.FakeTensor]
- # UnspecializedPythonVariable often masquerades as a tensor.
- # We MUST NOT generate shape guard code
- # that actually tries to access tensor properties on these values.
- # is_tensor lets us tell if this graph arg actually is a tensor
- # or not.
- is_tensor: bool = True
- # Sometimes, the Tensor we pass to example is freshly allocated (smh).
- # Then we cannot only keep a weak reference to it. This lets you
- # stash a strong reference too.
- example_strong_ref: Optional[torch.Tensor] = None
- @property
- def example(self):
- if isinstance(self._example, TensorWeakRef):
- r = self._example()
- assert r is not None
- return r
- else:
- return self._example
- def __post_init__(self):
- if isinstance(self._example, torch.Tensor):
- self._example = TensorWeakRef(self._example)
- assert is_fake(self.fake_tensor)
- def reconstruct(self, codegen: "PyCodegen"):
- codegen(self.source)
- def erase(self):
- self._example = None
- self.example_strong_ref = None
- def __eq__(self, other):
- return self.source.name() == other.source.name()
class BackwardStateGraphArg(GraphArg):
    """GraphArg with no source whose example is a fresh ``BackwardState``.

    It is never passed as a tensor, carries no fake tensor, and is flagged
    as not being a tensor.
    """

    def __init__(self) -> None:
        init_kwargs = dict(
            source=None,
            _example=BackwardState(),
            pass_arg_as_tensor=False,
            fake_tensor=None,
            is_tensor=False,
        )
        super().__init__(**init_kwargs)

    def reconstruct(self, codegen: "PyCodegen"):
        """Emit code that builds a ``BackwardState`` and stores it in the
        output graph's ``backward_state_var``, leaving a copy on the stack.
        """
        # A target variable must already have been assigned on the output graph.
        assert codegen.tx.output.backward_state_var

        def load_backward_state_cls():
            return codegen.load_import_from(BackwardState.__module__, "BackwardState")

        codegen.add_push_null(load_backward_state_cls)
        # Call the loaded class with zero positional arguments.
        codegen.call_function(0, False)
        # Duplicate the result so one copy survives the store below.
        codegen.dup_top()
        codegen.store(codegen.tx.output.backward_state_var)
# ids of every public class exposed by the itertools module (the class-based
# iterators).
# NOTE: id() is used because some objects are not hashable, which would raise
# an error during lookup.
ITERTOOLS_TYPE_IDS: frozenset[int] = frozenset(
    map(
        id,
        (
            obj
            for attr, obj in itertools.__dict__.items()
            if inspect.isclass(obj) and not attr.startswith("_")
        ),
    )
)

# Starts empty; updated later by substitute_in_graph in
# torch/_dynamo/polyfills/itertools.py
ITERTOOLS_POLYFILLED_TYPE_IDS: set[int] = set()

# Capture the function pointers at import time. This guards against trying to
# mark the iterated tensors as static in case the user overrides the function
# pointer.
og_module_named_parameters_fn_ptr = torch.nn.Module.named_parameters
og_module_named_buffers_fn_ptr = torch.nn.Module.named_buffers
- class VariableBuilder:
- """Wrap a python value in a VariableTracker() instance"""
def __init__(self, tx, source: Source) -> None:
    """Bind the builder to ``tx`` and to the source of the value it will wrap.

    Args:
        tx: stored as ``self.tx`` (presumably the active instruction
            translator; confirm against callers).
        source: origin of the value; must not be None. Its ``name()`` is
            cached as ``self.name``.

    Raises:
        AssertionError: if ``source`` is None or no TracingContext is active.
    """
    assert source is not None, (
        "Consider SourcelessBuilder for ephemeral objects, usually objects created locally."
    )
    active_context = TracingContext.try_get()
    assert active_context is not None, "Expected active TracingContext"
    super().__init__()
    self.tx, self.source = tx, source
    self.name = source.name()
def __call__(self, value):
    """Return a VariableTracker for ``value``, reusing earlier results when present.

    Lookup order: an entry already tracked in side effects, then the
    per-output variable tracker cache, and finally a fresh ``_wrap`` whose
    result is cached before being returned.
    """
    # Already tracked as a side effect: reuse it, guarding on the aliasing
    # between this source and the tracked one when such a guard exists.
    if value in self.tx.output.side_effects:
        tracked = self.tx.output.side_effects[value]
        alias_guard = make_dupe_guard(self.source, tracked.source)
        if alias_guard:
            self.install_guards(alias_guard)
        return tracked

    cache_hit = self.tx.output.variable_tracker_cache.lookup(value, self.source)
    if cache_hit:
        return cache_hit

    vt = self._wrap(value)
    if vt.source is None:
        vt.source = self.source

    # Only real symbolic values wrapped as SymNodeVariable are deduplicated.
    # Constants like 0, 1, 2, etc. can sometimes be unspecialized as
    # SymNodeVariables, but they should NOT be tracked: sharing a single
    # SymNodeVariable across multiple uses would make guards created for one
    # usage incorrectly apply to all other usages of that constant, leading
    # to unnecessary recompilations.
    wants_tracking = self._can_lift_attrs_to_inputs(vt) or (
        is_torch_sym(value) and isinstance(vt, SymNodeVariable)
    )
    if (
        wants_tracking
        and value not in self.tx.output.side_effects
        and not is_wrapper_or_member_descriptor(value)
    ):
        vt = self.tx.output.side_effects.track_object_existing(value, vt)

    self.tx.output.variable_tracker_cache.add(value, self.source, vt)
    return vt
def _can_lift_attrs_to_inputs(self, vt):
    """Return True when the exact type of ``vt`` is one of the listed tracker classes.

    The check uses ``type(vt)``, so subclasses of these classes do not match.
    """
    liftable_types = {
        TensorVariable,
        TensorWithTFOverrideVariable,
        UserDefinedObjectVariable,
        NumpyNdarrayVariable,
    }
    return type(vt) in liftable_types
def get_source(self):
    """Return the source stored on this builder."""
    source = self.source
    return source
def install_guards(self, *guards):
    """Create guards on this builder's source and install them.

    Each entry of ``guards`` is passed to ``source.make_guard``. If any of
    those calls raises NotImplementedError, nothing is installed and None is
    returned. Otherwise the guards are installed and an empty dict is
    returned, which lets callers splat the result as keyword arguments.
    """
    src = self.get_source()
    built = []
    try:
        for guard in guards:
            built.append(src.make_guard(guard))
    except NotImplementedError:
        return None
    # Keep this call directly in this frame: ``skip=1`` is passed through to
    # install_guard (presumably a stack-frame skip count -- confirm there).
    install_guard(*built, skip=1)
    return {}
@classmethod
def _type_dispatch(cls):
    """Return the exact-type dispatch table for the current ``config.trace_numpy`` setting."""
    numpy_tracing_enabled = config.trace_numpy
    return cls._type_dispatch_impl(numpy_tracing_enabled)
@classmethod
@functools.cache
def _type_dispatch_impl(cls, trace_numpy):
    """Build the table mapping an exact type to the method that wraps it.

    The result is cached per ``(cls, trace_numpy)``. ``np.ndarray`` is only
    registered when ``trace_numpy`` is truthy and numpy was importable.

    Raises:
        AssertionError: if the same type is registered more than once.
    """
    # NB: reference only ``cls`` here; closing over an instance would create
    # a reference cycle through the cache.
    tensor_types = (
        torch.Tensor,
        torch.nn.Parameter,
        torch._subclasses.FakeTensor,
        torch._subclasses.functional_tensor.FunctionalTensor,
    )
    listlike_types = (tuple, list, odict_values, collections.deque, torch.Size)

    entries = [
        (tensor_types, cls.wrap_tensor),
        (listlike_types, cls.wrap_listlike),
        (tuple_iterator, cls.wrap_tuple_iterator),
        (range_iterator, cls.wrap_range_iterator),
        ((slice, range), cls.wrap_slice_range),
        (tuple(common_constant_types), cls.wrap_literal),
        (re.Pattern, cls.wrap_regex_pattern),
        (weakref.ReferenceType, cls.wrap_weakref),
        (torch.utils.hooks.RemovableHandle, cls.wrap_removable_handle),
        (torch.jit.ScriptFunction, cls.wrap_jit_function),
        (types.MappingProxyType, cls.wrap_mapping_proxy),
    ]
    if trace_numpy and np:
        entries.append((np.ndarray, cls.wrap_numpy_ndarray))

    # Flatten: a tuple key registers the same handler for each member type.
    table = {}
    for key, handler in entries:
        exact_types = key if isinstance(key, tuple) else (key,)
        for exact_type in exact_types:
            assert exact_type not in table
            table[exact_type] = handler
    return table
def wrap_regex_pattern(self, value: re.Pattern):
    """Wrap a compiled regex in a RegexPatternVariable, guarded with ID_MATCH."""
    # TODO(jansel): something like a REPR_MATCH might be more robust here
    self.install_guards(GuardBuilder.ID_MATCH)
    pattern_vt = RegexPatternVariable(value)
    return pattern_vt
def wrap_weakref(self, value: weakref.ReferenceType):
    """Wrap a weak reference via ``WeakRefVariable.build``, guarded with TYPE_MATCH."""
    self.install_guards(GuardBuilder.TYPE_MATCH)
    weakref_vt = WeakRefVariable.build(self.tx, value, source=self.source)
    return weakref_vt
def wrap_removable_handle(self, value):
    """Report a RemovableHandle as unsupported through ``unimplemented_v2``.

    Reaching this point means the removable handle was created in some other
    frame. The current infrastructure requires the hook to be registered and
    removed in the same frame, so this is a graph break.
    Related test - PYTORCH_TEST_WITH_DYNAMO=1 python test/test_autograd.py -k TestAutograd.test_hooks
    """
    explanation = (
        "Dynamo attempted to build a representation of a torch.utils.hooks.RemovableHandle, "
        "which is not supported. This happens because the RemovableHandle was created in another frame."
    )
    unimplemented_v2(
        gb_type="Attempted to represent unregistered RemovableHandle",
        context="",
        explanation=explanation,
        hints=[],
    )
def wrap_jit_function(self, value):
    """Wrap ``value`` in a WrapperUserFunctionVariable using the
    ``_torchdynamo_inline`` attribute name, guarded with TYPE_MATCH.
    """
    self.install_guards(GuardBuilder.TYPE_MATCH)
    wrapper_vt = WrapperUserFunctionVariable(
        value,
        "_torchdynamo_inline",
        source=self.source,
    )
    return wrapper_vt
def wrap_mapping_proxy(self, value):
    """Wrap a mappingproxy as a MappingProxyVariable over a ConstDictVariable.

    Keys must all be literals; otherwise ``unimplemented_v2`` is called.
    Values are wrapped lazily, each with a ``GetItemSource`` keyed by its
    mapping key. The result is registered with ``side_effects.track_mutable``.
    """
    self.install_guards(GuardBuilder.TYPE_MATCH)
    # Guarding on every key might be suboptimal compared to dict guards, but
    # mappingproxy is not very common, so it's OK to guard on all keys.
    self.install_guards(GuardBuilder.MAPPING_KEYS_CHECK)

    has_non_const_key = any(
        not ConstantVariable.is_literal(k) for k in value.keys()
    )
    if has_non_const_key:
        unimplemented_v2(
            gb_type="non-const keys in mappingproxy",
            context=f"non-const keys: {[k for k in value.keys() if not ConstantVariable.is_literal(k)]}",
            explanation="Dynamo expects mappingproxy keys to be constants.",
            hints=[
                "Ensure your mappingproxy keys are constants (e.g. int, float, strings)",
            ],
        )

    # The key tracker is created before the lazy value tracker for each item.
    items = {
        ConstantVariable.create(k): LazyVariableTracker.create(
            v, GetItemSource(self.get_source(), k)
        )
        for k, v in value.items()
    }

    # The inner dict tracker has no source; only the proxy wrapper carries one.
    inner_dict_vt = ConstDictVariable(items, source=None)
    proxy_vt = MappingProxyVariable(inner_dict_vt, source=self.source)
    return self.tx.output.side_effects.track_mutable(value, proxy_vt)
@classmethod
@functools.cache
def _id_dispatch(
    cls,
) -> dict[int, Callable[["VariableBuilder", Any], VariableTracker]]:
    """Build the table mapping ``id(obj)`` of specific objects to their wrapper.

    The table is keyed by ``id()`` and is built once per class because of
    ``functools.cache``. Each handler takes ``(builder, value)`` and returns
    a VariableTracker.

    Raises:
        AssertionError: if the same object is registered more than once.
    """
    # Imported locally, like the original, rather than at module import time.
    from ..comptime import comptime

    entries = [
        (comptime, lambda self, value: ComptimeVariable()),
        (
            dataclasses.fields,
            # NOTE(review): install_guards returns None when make_guard raises
            # NotImplementedError, and ``**None`` would raise TypeError here.
            lambda self, value: LambdaVariable(
                _dataclasses_fields_lambda,
                source=self.source,
                **self.install_guards(GuardBuilder.FUNCTION_MATCH),
            ),
        ),
        (torch.__version__, lambda self, value: TorchVersionVariable()),
    ]

    result = {}
    for ts, fn in entries:
        for t in ts if isinstance(ts, (tuple, list)) else (ts,):
            # The table is keyed by id(t), so the duplicate check must test
            # id(t); testing ``t`` itself could never find an existing entry.
            assert id(t) not in result
            result[id(t)] = fn
    return result
- def _wrap(self, value):
- # import here to avoid circular dependencies
- from torch.utils._triton import (
- has_triton,
- has_triton_experimental_host_tma,
- has_triton_tensor_descriptor_host_tma,
- )
- from ..decorators import (
- DynamoConfigPatchProxy,
- ErrorOnGraphBreakDecoratorContextManager,
- )
- if has_triton():
- from triton.runtime.autotuner import Autotuner
- from triton.runtime.jit import JITFunction
- else:
- class JITFunction:
- pass
- class Autotuner:
- pass
- # default implementations, in case we don't have triton (or the wrong triton version)
- def create_1d_tma_descriptor():
- pass
- def create_2d_tma_descriptor():
- pass
- class TensorDescriptor:
- @staticmethod
- def from_tensor():
- pass
- if has_triton_experimental_host_tma():
- from triton.tools.experimental_descriptor import ( # noqa: F811
- create_1d_tma_descriptor,
- create_2d_tma_descriptor,
- )
- if has_triton_tensor_descriptor_host_tma():
- from triton.tools.tensor_descriptor import TensorDescriptor # noqa: F811
- # Handle exact type() match
- type_dispatch = self._type_dispatch().get(type(value))
- if type_dispatch is not None:
- return type_dispatch(self, value)
- # Handle exact id() match
- id_dispatch = self._id_dispatch().get(id(value))
- if id_dispatch is not None:
- return id_dispatch(self, value)
- # Everything else (NB: order matters!)
- if (
- isinstance(value, torch.Tensor)
- and type(value)
- not in (
- # These torch-native subclasses have overly restrictive
- # `__torch_function__` which prevents Dynamo from reading their
- # tensor attributes like `is_nested` or calling methods like
- # `_is_view`.
- torch.nn.parameter.UninitializedBuffer,
- torch.nn.parameter.UninitializedParameter,
- ExpandedWeight,
- )
- and type(value) not in config.nontraceable_tensor_subclasses
- ):
- if (
- type(value).__torch_dispatch__ is torch.Tensor.__torch_dispatch__
- or is_traceable_wrapper_subclass(value)
- ):
- return self.wrap_tensor(value)
- if is_namedtuple(value):
- self.install_guards(GuardBuilder.SEQUENCE_LENGTH)
- output = [
- LazyVariableTracker.create(
- getattr(value, name),
- source=AttrSource(self.source, name),
- )
- for name in namedtuple_fields(type(value))
- ]
- result = NamedTupleVariable(
- output, tuple_cls=type(value), source=self.source
- )
- return result
- elif istype(value, (dict, collections.defaultdict, collections.OrderedDict)):
- self.install_guards(GuardBuilder.TYPE_MATCH)
- all_const = all(ConstantVariable.is_literal(k) for k in value.keys())
- # For all_const, we don't have to guard on anything yet. We guard on
- # keys lazily by adding a dict_getitem entry for each accessed key.
- # For cases where we need to guard on all keys, we lazily put guards
- # during the dict call_method (check dicts.py)
- if not all_const:
- # Guard on the key order
- # This is not ideal, i.e., there is no need to guard on the key
- # order. But we guard on the key order because of the complexity
- #
- # 1) For non-constant objects, we can't save the key in the
- # guard context because it can be memory heavy. We can add
- # weakrefs but this complicates the accesses.
- #
- # 2) For non-constant objects, we also have to guard on the keys
- # (like TENSOR_MATCH on tensor). We might also have guards on
- # the attributes of the keys (like tensor.grad). To make this
- # work in tree structure is complicated.
- #
- # So, instead we guard on the key order. While guarding on key
- # order, we just save the indices and use it to access keys and
- # values. Indices are cheap to save.
- self.tx.output.guard_on_key_order.add(self.source)
- # We need all the keys to be hashable. We do this within the
- # _HashableTracker class in dicts.py
- def build_key_value(i, k, v):
- base = self.get_source()
- if all_const:
- key = ConstantVariable.create(k)
- source_key = k
- else:
- source_key = ConstDictKeySource(base, i)
- key = LazyVariableTracker.create(k, source_key)
- source_value = DictGetItemSource(base, source_key)
- res_value = LazyVariableTracker.create(v, source_value)
- return key, res_value
- # Ensure that we call dict.keys and not value.keys (which can call
- # overridden keys method). In the C++ guards, we relied on
- # PyDict_Next to traverse the dictionary, which uses the internal
- # data structure and does not call the overridden keys method.
- result = dict(
- build_key_value(i, k, v)
- for i, (k, v) in enumerate(get_items_from_dict(value))
- )
- if istype(value, collections.defaultdict):
- factory_source = AttrSource(self.source, "default_factory")
- result = DefaultDictVariable(
- result,
- type(value),
- default_factory=VariableBuilder(self.tx, factory_source)(
- value.default_factory
- ),
- source=self.source,
- )
- else:
- result = ConstDictVariable(
- result, user_cls=type(value), source=self.source
- )
- return self.tx.output.side_effects.track_mutable(value, result)
- elif isinstance(value, torch.nn.Module):
- return self.wrap_module(value)
- elif ConstantVariable.is_literal(value): # non-atomic literals
- return self.wrap_literal(value)
- elif isinstance(value, torch.overrides.TorchFunctionMode):
- var = TorchFunctionModeVariable(value, source=self.source)
- self.tx.output.side_effects.track_object_existing(value, var)
- return var
- elif istype(value, set):
- if any(isinstance(x, torch.Tensor) for x in value):
- unimplemented_v2(
- gb_type="Attempted to wrap a set with tensors",
- context="Python set containing torch.Tensor elements",
- explanation=(
- "Dynamo cannot trace sets of tensors. To get a stable ordering, "
- "Dynamo needs to convert the set into a list and the order might not be "
- "stable if the set contains tensors."
- ),
- hints=[
- "Use a dictionary where the keys are tensors.",
- *graph_break_hints.SUPPORTABLE,
- ],
- )
- self.install_guards(GuardBuilder.TYPE_MATCH)
- self.install_guards(GuardBuilder.SEQUENCE_LENGTH)
- # The list gives a ordering for the set items. The ordering is based
- # on the Python hash and it is not related to object ordering inside
- # the set object. The order being incorrect at runtime will lead to
- # a recompilation.
- L = list(value)
- items = [
- LazyVariableTracker.create(
- v, source=NonSerializableSetGetItemSource(self.source, i)
- )
- for i, v in enumerate(L)
- ]
- result = SetVariable(items, source=self.source)
- return self.tx.output.side_effects.track_object_existing(value, result)
- elif istype(value, frozenset) and all(
- (
- # For DBR quantization, we could get a frozenset of torch funcs.
- (type(x) is types.BuiltinMethodType and x.__module__ == "torch")
- or
- # Another commonly used frozenset of types.
- x in torch.utils._pytree.BUILTIN_TYPES
- )
- for x in value
- ):
- # For the limited cases of frozenset here, we know the items won't
- # change across runs, so we can safely create sourceless VTs for
- # them and only guard on the frozenset id.
- # TODO support source for sets and remove the special logics here.
- items = [SourcelessBuilder.create(self.tx, v) for v in value]
- self.install_guards(GuardBuilder.ID_MATCH)
- return FrozensetVariable(items, source=self.source)
- elif isinstance(
- value, (enum.Enum, torch.DispatchKey, torch._C._functorch.TransformType)
- ):
- self.install_guards(GuardBuilder.ID_MATCH)
- return EnumVariable(value=value, source=self.source)
- elif DebuggingVariable.is_reorderable_logging_function(value):
- # Put this above builtin_callable so that print() can be handled
- # along with other builtin debugging functions
- self.install_guards(GuardBuilder.BUILTIN_MATCH)
- return DebuggingVariable(value, source=self.source)
- elif isinstance(value, logging.Logger):
- self.install_guards(GuardBuilder.FUNCTION_MATCH)
- return LoggingLoggerVariable(value, source=self.source)
- elif is_utils_checkpoint(value):
- return build_checkpoint_variable(source=self.source)
- elif is_invoke_subgraph(value):
- return build_invoke_subgraph_variable(source=self.source)
- elif isinstance(value, functools.partial):
- func_src = AttrSource(self.get_source(), "func")
- func_obj = VariableBuilder(self.tx, func_src)(value.func)
- args = []
- args_source = AttrSource(self.get_source(), "args")
- for i, arg in enumerate(value.args):
- args.append(
- VariableBuilder(self.tx, GetItemSource(args_source, i))(arg)
- )
- keywords = {}
- keywords_source = AttrSource(self.get_source(), "keywords")
- for k, v in value.keywords.items():
- if not ConstantVariable.is_literal(k):
- unimplemented_v2(
- gb_type="functools.partial() with non-literal keyword",
- context=f"non-literal keyword: {k}",
- explanation="functools.partial() expects literal/string keywords",
- hints=[*graph_break_hints.USER_ERROR],
- )
- keywords[k] = VariableBuilder(
- self.tx, DictGetItemSource(keywords_source, k)
- )(v)
- install_guard(
- self.get_source().make_guard(GuardBuilder.TYPE_MATCH),
- keywords_source.make_guard(GuardBuilder.DICT_KEYS_MATCH),
- args_source.make_guard(GuardBuilder.SEQUENCE_LENGTH),
- )
- return FunctoolsPartialVariable(func_obj, args, keywords)
- elif is_typing(value):
- # typing.List, typing.Mapping, etc.
- self.install_guards(GuardBuilder.ID_MATCH)
- return TypingVariable(
- value,
- source=self.source,
- )
- elif np is not None and isinstance(value, np.generic):
- # numpy array scalars: convert to 0D arrays
- return self.wrap_numpy_ndarray(np.asarray(value))
- elif trace_rules.is_numpy(value):
- assert np
- self.install_guards(
- GuardBuilder.FUNCTION_MATCH
- if callable(value)
- else GuardBuilder.TYPE_MATCH
- )
- return NumpyVariable(value, source=self.source)
- elif trace_rules.is_numpy_dtype(value):
- self.install_guards(GuardBuilder.ID_MATCH)
- return NumpyDTypeVariable(value, source=self.source)
- elif trace_rules.is_numpy_type_info(value):
- if isinstance(value, np.iinfo):
- self.install_guards(GuardBuilder.TYPE_MATCH)
- dt_source = AttrSource(self.source, "dtype")
- install_guard(dt_source.make_guard(GuardBuilder.ID_MATCH))
- else:
- self.install_guards(GuardBuilder.ID_MATCH)
- return NumpyTypeInfoVariable(value, source=self.source)
- # NB: These can't be put in type_dispatch, they have to run later
- elif CollectiveFunctionRewriteVariable.can_rewrite(value):
- self.install_guards(GuardBuilder.FUNCTION_MATCH)
- return CollectiveFunctionRewriteVariable.create(
- self.tx,
- value,
- source=self.source,
- )
- elif istype(value, torch.autograd.function.FunctionMeta):
- self.install_guards(GuardBuilder.FUNCTION_MATCH)
- return AutogradFunctionVariable(
- value,
- source=self.source,
- )
- elif isinstance(value, torch.autograd.function.FunctionCtx):
- actual_saved_tensors = None
- try:
- actual_saved_tensors = value.saved_tensors
- except RuntimeError:
- pass
- saved_tensors = []
- guards = [self.source.make_guard(GuardBuilder.TYPE_MATCH)]
- if isinstance(actual_saved_tensors, tuple):
- saved_tensors_source = AttrSource(self.source, "saved_tensors")
- guards.append(
- saved_tensors_source.make_guard(GuardBuilder.SEQUENCE_LENGTH)
- )
- for i, v in enumerate(actual_saved_tensors):
- saved_tensors.append(
- VariableBuilder(
- self.tx, GetItemSource(saved_tensors_source, i)
- )(v)
- )
- install_guard(*guards)
- return self.tx.output.side_effects.track_object_existing(
- value,
- AutogradFunctionContextVariable(
- value,
- source=self.source,
- saved_tensors=SavedTensorBox(saved_tensors),
- ),
- )
- elif (
- isinstance(value, types.MethodType)
- and istype(
- getattr(value, "__self__", None), torch.autograd.function.FunctionMeta
- )
- and getattr(value, "__name__", "") == "apply"
- and value == getattr(value.__self__, "apply", None)
- ):
- # handle aliased autograd function `apply` calls
- self.install_guards(GuardBuilder.FUNCTION_MATCH)
- return GetAttrVariable(
- AutogradFunctionVariable(
- value.__self__, source=AttrSource(self.source, member="__self__")
- ),
- "apply",
- )
- elif isinstance(value, torch._C._ImperativeEngine):
- self.install_guards(GuardBuilder.ID_MATCH)
- return AutogradEngineVariable(value, source=self.source)
- elif (
- value
- is torch._dynamo.external_utils.FakeCompiledAutogradEngine._exec_final_callbacks_stub
- ):
- self.install_guards(GuardBuilder.FUNCTION_MATCH)
- return LambdaVariable(
- lambda: UserFunctionVariable(
- torch._dynamo.external_utils.FakeCompiledAutogradEngine.exec_final_callbacks,
- ).call_function(
- self.tx,
- (self.tx.output.side_effects.get_ca_final_callbacks_var(),),
- {},
- )
- )
- elif isinstance(value, DynamoConfigPatchProxy):
- return DynamoConfigPatchVariable(value.changes)
- elif isinstance(value, ErrorOnGraphBreakDecoratorContextManager):
- return ErrorOnGraphBreakVariable(value.error_on_graph_break)
- elif callable(value) and trace_rules.lookup_callable(value) is not None:
- if trace_rules.is_callable_allowed(value):
- self.tx.output.has_user_defined_allowed_in_graph = True
- return trace_rules.lookup_callable(value).create_with_source(
- value, source=self.source
- )
- elif np and isinstance(value, np.number):
- return self.wrap_unspecialized_primitive(value)
- elif isinstance(value, HigherOrderOperator):
- if value is torch._higher_order_ops.invoke_subgraph:
- unimplemented_v2(
- gb_type="Attempted to wrap torch._higher_order_ops.invoke_subgraph",
- context="",
- explanation="Directly using invoke_subgraph is not supported. Use nested_compile_region",
- hints=[],
- )
- self.install_guards(GuardBuilder.TYPE_MATCH, GuardBuilder.NAME_MATCH)
- return TorchHigherOrderOperatorVariable.make(value, source=self.source)
- elif isinstance(value, torch.cuda.StreamContext):
- self.install_guards(GuardBuilder.ID_MATCH)
- stream_source = AttrSource(self.source, "stream")
- stream_var = VariableBuilder(self.tx, stream_source)(value.stream)
- return StreamContextVariable.create(self.tx, stream_var)
- elif isinstance(value, torch.Stream):
- self.install_guards(GuardBuilder.ID_MATCH)
- stream_proxy = self.tx.output.create_proxy(
- "call_function",
- type(value),
- (),
- {
- "stream_id": value.stream_id,
- "device_index": value.device_index,
- "device_type": value.device_type,
- },
- )
- set_example_value(stream_proxy.node, value)
- return StreamVariable(
- stream_proxy,
- value,
- value.device,
- source=self.source,
- )
- elif isinstance(value, (torch._C._SDPAParams)):
- self.install_guards(GuardBuilder.TYPE_MATCH)
- return SDPAParamsVariable.create(self.tx, value, self.source)
- elif isinstance(value, torch._functorch.pyfunctorch.FuncTorchInterpreter):
- self.install_guards(GuardBuilder.ID_MATCH)
- return FuncTorchInterpreterVariable(value)
- elif isinstance(value, torch.Event):
- self.install_guards(GuardBuilder.ID_MATCH)
- torch._dynamo.utils.store_user_object_weakref(value)
- event_proxy = self.tx.output.create_proxy(
- "call_function",
- torch._dynamo.utils.get_user_object_from_id,
- (id(value),),
- {},
- )
- set_example_value(event_proxy.node, value)
- return EventVariable(
- event_proxy,
- value,
- source=self.source,
- )
- elif (
- istype(value, contextlib.nullcontext)
- and inspect.getattr_static(value, "enter_result", None) is None
- ):
- self.install_guards(GuardBuilder.TYPE_MATCH)
- return NullContextVariable(source=self.source)
- elif KeyedJaggedTensorVariable.is_matching_object(value):
- self.install_guards(GuardBuilder.TYPE_MATCH)
- result = KeyedJaggedTensorVariable(value, source=self.source)
- # TODO: this doing it manually is bad
- return self.tx.output.side_effects.track_object_existing(value, result)
- elif isinstance(value, torch.optim.Optimizer):
- self.install_guards(GuardBuilder.ID_MATCH)
- self.source = OptimizerSource(self.source)
- return OptimizerVariable(value, source=self.source)
- elif isinstance(value, torch.DispatchKeySet):
- self.install_guards(GuardBuilder.DISPATCH_KEY_SET_MATCH)
- return DispatchKeySetVariable(value)
- elif WorldMetaClassVariable.is_group_member_type(value):
- return WorldMetaClassVariable(value, source=self.source)
- elif ProcessGroupVariable.is_process_group(value):
- self.install_guards(GuardBuilder.ID_MATCH)
- return ProcessGroupVariable(value, source=self.source)
- elif DeviceMeshVariable.is_device_mesh(value):
- # TODO: see if we need to add custom guard instead of a simple ID_MATCH
- self.install_guards(GuardBuilder.EQUALS_MATCH)
- return DeviceMeshVariable(value, source=self.source)
- elif PlacementClassVariable.is_placement_type(value):
- # TODO: see if we need to add custom guard instead of a simple ID_MATCH
- self.install_guards(GuardBuilder.ID_MATCH)
- return PlacementClassVariable(value, source=self.source)
- elif PlacementVariable.is_placement(value):
- # TODO: see if we need to add custom guard instead of a simple ID_MATCH
- self.install_guards(GuardBuilder.EQUALS_MATCH)
- return PlacementVariable(
- value,
- source=self.source,
- )
- elif (
- id(value) in ITERTOOLS_TYPE_IDS
- and id(value) not in ITERTOOLS_POLYFILLED_TYPE_IDS
- ):
- self.install_guards(GuardBuilder.FUNCTION_MATCH)
- return ItertoolsVariable(value, source=self.source)
- elif is_torch_sym(value):
- # Note: this doesn't handle nested symints.
- # For SymBool input, we reuse the infra for SymInt by simulating SymBool with a SymInt in dynamo.
- # Concretely,
- # 1. We create a SymInt in dynamo's shape_env, whose source is constructed as ConvertIntSource(self.source).
- # so that guards on the SymInts can be effectively applied on the original SymBool in user program.
- # 2. We create a SymBool based on the SymInt in dynamo's ShapeEnv. Because the original user program
- # depends on the value being a SymBool. This allows dynamo to interpret the user's program correctly.
- source = (
- self.source
- if isinstance(value, torch.SymInt)
- else ConvertIntSource(self.source)
- )
- if value.node.has_hint():
- new_symint = (
- self.tx.output.shape_env.create_unspecified_symint_and_symbol(
- int(value.node.hint),
- source,
- dynamic_dim=DimDynamic.DYNAMIC,
- )
- )
- else:
- if isinstance(value, torch.SymBool):
- # We need to create an unbacked symint to replace the unbacked symbool.
- new_symint = self.tx.output.shape_env.create_unbacked_symint()
- else:
- # TODO (yidi): we need to figure out a way to propagate the guards
- # we accumulated when tracing the subggraph to outer shape_env. For normal symints,
- # this is automatically done by evaluating the guards once but this
- # will cause data-dependent error when we evaluate the outer unbacked symints.
- # The test case that triggers this graph break is test_cond_unbacked_symint_closure
- unimplemented_v2(
- gb_type="Attempted to wrap unbacked SymInt",
- context="",
- explanation="Unbacked SymInt input is not supported yet.",
- hints=[*graph_break_hints.SUPPORTABLE],
- )
- sym_node_proxy = self.tx.output.root_tracer.create_graph_input(
- re.sub(r"[^a-zA-Z0-9]+", "_", self.name),
- type(new_symint),
- new_symint,
- source=source,
- )
- sym_node_proxy.node.meta["grapharg"] = GraphArg(
- source,
- new_symint,
- False,
- None,
- is_tensor=False,
- example_strong_ref=new_symint,
- )
- # We bind the new_symint to graph input.
- sym_expr = new_symint.node.expr
- assert isinstance(sym_expr, sympy.Symbol), (
- f"{sym_expr} is not a basic Symbol."
- )
- self.tx.output.tracked_fakes.append(TrackedFake(new_symint, source, None))
- tracing_symint = (
- new_symint if isinstance(value, torch.SymInt) else new_symint == 1
- ) # cast it back to symbool for tracing
- return SymNodeVariable(sym_node_proxy, tracing_symint)
- elif isinstance(value, (JITFunction, Autotuner)):
- self.install_guards(GuardBuilder.ID_MATCH)
- return TritonKernelVariable(
- value,
- None, # No kernel idx provided
- None, # No grid provided
- source=self.source,
- )
- elif value is create_1d_tma_descriptor:
- return CreateTMADescriptorExperimentalVariable(rank=1)
- elif value is create_2d_tma_descriptor:
- return CreateTMADescriptorExperimentalVariable(rank=2)
- elif value is TensorDescriptor.from_tensor:
- return CreateTMADescriptorStableVariable()
- elif isinstance(value, torch.amp.autocast_mode.autocast):
- self.install_guards(GuardBuilder.ID_MATCH)
- return AutocastModeVariable(
- target_values=[
- value.device,
- value.fast_dtype,
- value._enabled,
- value._cache_enabled,
- ],
- source=self.source,
- )
- elif TorchCtxManagerClassVariable.is_matching_cls(value):
- self.install_guards(GuardBuilder.FUNCTION_MATCH)
- return TorchCtxManagerClassVariable(value, source=self.source)
- elif inspect.getattr_static(value, "__script_if_tracing_wrapper", False):
- self.install_guards(GuardBuilder.TYPE_MATCH)
- return WrapperUserFunctionVariable(
- value, "__original_fn", source=self.source
- )
- elif is_lru_cache_wrapped_function(value):
- self.install_guards(GuardBuilder.TYPE_MATCH)
- return WrapperUserFunctionVariable(value, "__wrapped__", source=self.source)
- elif value is traceback.clear_frames:
- return TracebackVariable(source=self.source)
- elif value is sys.exc_info or (
- sys.version_info >= (3, 11) and value is sys.exception
- ):
- return SysFunctionVariable(value, source=self.source)
- elif is_function_or_wrapper(value) and inspect.getattr_static(
- value, "_torchdynamo_inline", False
- ):
- self.install_guards(GuardBuilder.TYPE_MATCH)
- return WrapperUserFunctionVariable(
- value, "_torchdynamo_inline", source=self.source
- )
- elif value is functools.wraps:
- self.install_guards(GuardBuilder.ID_MATCH)
- return FunctoolsWrapsVariable(value, source=self.source)
- elif value is collections.namedtuple:
- self.install_guards(GuardBuilder.ID_MATCH)
- return CollectionsNamedTupleFunction(value, source=self.source)
- elif isinstance(
- value, types.BuiltinMethodType
- ) and BuiltinMethodVariable.is_supported_builtin_method(value):
- self.install_guards(GuardBuilder.ID_MATCH)
- return BuiltinMethodVariable(value, source=self.source)
- elif is_function(value) and value in (float.fromhex, float.hex):
- self.install_guards(GuardBuilder.ID_MATCH)
- return GetAttrVariable(
- BuiltinVariable(float, source=self.source),
- value.__name__,
- )
- elif is_function_or_wrapper(value):
- value, attr_name = unwrap_with_attr_name_if_wrapper(value)
- # For these wrappers, Dynamo points to the wrapped function,
- # so source needs to be updated as well.
- if attr_name is not None:
- self.source = AttrSource(self.source, attr_name)
- return trace_rules.lookup(value).create_with_source(
- value, source=self.source
- )
- elif value is random.Random:
- self.install_guards(GuardBuilder.ID_MATCH)
- return RandomClassVariable(source=self.source)
- elif istype(value, random.Random) and RandomVariable.is_supported_random_obj(
- value
- ):
- self.install_guards(GuardBuilder.TYPE_MATCH)
- result = RandomVariable(value, source=self.source)
- self.tx.output.side_effects.track_mutable(value, result)
- return result
- # Don't use istype, since some python modules are not subclasses of types.ModuleType directly.
- # E.g, type(torch.ops) -> <class 'torch._ops._Ops'>,
- # type(torch.backends.cudnn) -> <class 'torch.backends.cudnn.CudnnModule'>
- elif isinstance(value, (types.ModuleType, replay_record.DummyModule)):
- self.install_guards(GuardBuilder.FUNCTION_MATCH)
- result = PythonModuleVariable(
- value,
- source=self.source,
- )
- self.tx.output.side_effects.track_object_existing(value, result)
- return result
- elif isinstance(value, types.MethodType) and isinstance(
- value.__self__, (torch.nn.Module, torch.utils._pytree.TreeSpec)
- ):
- # don't let MethodTypes fall through to UserDefinedObject,
- # which doesn't support 'CALL_FUNCTION'
- # TODO(whc): Why do we limit this to methods on NNModules?
- # I don't have a good reason for this, but it preserves the existing behavior
- # for MBartForConditionalGeneration, which generates many graph breaks and OOMs otherwise.
- # I suspect we probably want to relax this check and dig deeper there.
- # In order to construct a MethodVariable in Dynamo, we start with an actual method obj from python,
- # but need to separately wrap its underlying `__func__` and its `self` argument. We wrap `self` here
- # and then `__func__` gets wrapped inside UserMethodVariable.
- self_obj = VariableBuilder(
- self.tx, source=AttrSource(self.source, "__self__")
- )(value.__self__)
- assert self_obj and isinstance(self_obj, VariableTracker), (
- "Failed to produce a valid self obj"
- )
- self.install_guards(GuardBuilder.FUNCTION_MATCH)
- return UserMethodVariable(
- value.__func__,
- self_obj,
- source=self.source,
- )
- elif isinstance(value, types.GetSetDescriptorType):
- # GetSet descriptors are C functions attached to an attribute lookup
- # using PyGetSetDef. Python, on attribute lookup, can decide to
- # create a new object on the fly, and therefore the `id` of the
- # descriptors is not guaranteed to be same for different attribute
- # accesses. Since these are unlikely to change during the program
- # execution, we can skip guarding on them.
- return GetSetDescriptorVariable(value)
- elif isinstance(value, types.MethodWrapperType):
- # Method-wrappers are written in C, and they are not guaranteed to
- # return the same object on attribute lookup. Therefore, we cannot
- # insert a FUNCTION_MATCH guard here. method-wrappers are very
- # unlikely to change, so its ok to skip the guard here.
- return MethodWrapperVariable(value)
- elif issubclass(type(value), type) and issubclass(value, BaseException):
- # match user defined exceptions
- self.install_guards(GuardBuilder.ID_MATCH)
- return UserDefinedExceptionClassVariable(value)
- elif issubclass(type(value), type):
- if value in (
- torch.utils.hooks.BackwardHook,
- torch.nn.Parameter,
- torch.nn.Buffer,
- ):
- # TODO(jansel): combine this case with the one above
- return trace_rules.lookup(value).create_with_source(
- value, source=self.source
- )
- if value is torch.autograd._unsafe_preserve_version_counter:
- self.install_guards(GuardBuilder.FUNCTION_MATCH)
- return PreserveVersionContextVariable.constructor(self.tx)
- if (
- # `value` must be a strict subclass of `torch.Tensor`
- issubclass(value, torch.Tensor)
- and value is not torch.Tensor
- # `TensorSubclassVariable` is not for subclass that overrides
- # `torch_dispatch`.
- and value.__torch_dispatch__ is torch.Tensor.__torch_dispatch__
- # `TensorSubclassVariable` would lead to construction of
- # `TensorWithTFOverrideVariable`, but we don't want that for
- # traceable wrapper subclasses (we wrap those subclass instances
- # into `TensorVariable`).
- and not is_traceable_wrapper_subclass_type(value)
- ):
- return TensorSubclassVariable(value, source=self.source)
- if not is_from_closure_source(self.source):
- # For closure source, the variable comes from LOAD_SUPER_ATTR,
- # which calls self.__class__. This is internal Cpython
- # implementation, and it is rare for the user to modify
- # self.__class__ manually.
- # For other cases, this is a userdefined class, so install an
- # ID_MATCH even if its a global variable.
- self.install_guards(GuardBuilder.ID_MATCH)
- return UserDefinedClassVariable(
- value,
- source=self.source,
- )
- elif TorchScriptObjectVariable.is_matching_cls(type(value)):
- from ..source import (
- FlattenScriptObjectSource,
- ScriptObjectQualifiedNameSource,
- )
- if torch._library.fake_class_registry.tracing_with_real(value):
- proxy = self.tx.output.root_tracer.create_graph_input(
- re.sub(r"[^a-zA-Z0-9]+", "_", self.name),
- type(value),
- value,
- source=self.source,
- )
- # setting is_unspecialized=False to not insert a as_tensor call in reconstruct by default
- # setting example to be real value because these example values will be used
- # as example_inputs for user compiler.
- proxy.node.meta["grapharg"] = GraphArg(
- self.source, value, False, None, False, value
- )
- return TorchScriptObjectVariable.create(
- proxy,
- value,
- source=self.source,
- )
- # This exists to allow a smoother transition.
- # The implications are:
- # The script objects won't be tracked as proxies.
- # Methods on these objects won't show up in the graph.
- # The original script object might be mutated.
- if not hasattr(value, "__obj_flatten__"):
- return self.wrap_user_defined(value)
- # Install the guards on the fully qualified name of the script object
- LazyVariableTracker.realize_all(
- VariableBuilder(self.tx, ScriptObjectQualifiedNameSource(self.source))(
- value._type().qualified_name() # type: ignore[attr-defined]
- )
- )
- # Install the guards on the content of the script object by setting the source
- # to be FlattenScriptObjectSource, which calls __obj_flatten__() to get the contents.
- LazyVariableTracker.realize_all(
- VariableBuilder(self.tx, FlattenScriptObjectSource(self.source))(
- value.__obj_flatten__()
- )
- )
- fake_script_obj = torch._library.fake_class_registry.maybe_to_fake_obj(
- self.tx.output.fake_mode, value
- )
- proxy = self.tx.output.root_tracer.create_graph_input(
- re.sub(r"[^a-zA-Z0-9]+", "_", self.name),
- type(value),
- fake_script_obj,
- source=self.source,
- )
- # setting is_unspecialized=False to not insert a as_tensor call in reconstruct by default
- # setting example to be real value because these example values will be used
- # as example_inputs for user compiler.
- proxy.node.meta["grapharg"] = GraphArg(
- self.source, value, False, None, False, fake_script_obj
- )
- return TorchScriptObjectVariable.create(
- proxy,
- fake_script_obj,
- source=self.source,
- )
- elif (
- isinstance(value, (dict, collections.OrderedDict))
- and type(value).__new__ is dict.__new__
- ):
- # Construct a dict_vt that will reside inside the UserDefinedDictVariable
- self.install_guards(GuardBuilder.TYPE_MATCH)
- self.install_guards(GuardBuilder.SEQUENCE_LENGTH)
- # Guard on the key order
- self.tx.output.guard_on_key_order.add(self.source)
- # We need all the keys to be hashable. We do this within the
- # _HashableTracker class in dicts.py
- def build_key_value(i, k, v):
- base = self.get_source()
- source_key = ConstDictKeySource(base, i)
- key = LazyVariableTracker.create(k, source_key)
- source_value = DictSubclassGetItemSource(base, source_key)
- res_value = LazyVariableTracker.create(v, source_value)
- return key, res_value
- # Ensure that we call dict.keys and not value.keys (which can call
- # overridden keys method). In the C++ guards, we relied on
- # PyDict_Next to traverse the dictionary, which uses the internal
- # data structure and does not call the overridden keys method.
- result = dict(
- build_key_value(i, k, v)
- for i, (k, v) in enumerate(get_items_from_dict(value))
- )
- dict_vt = ConstDictVariable(
- result,
- user_cls=(
- collections.OrderedDict
- if isinstance(value, collections.OrderedDict)
- else dict
- ),
- mutation_type=ValueMutationExisting(),
- source=self.source,
- )
- # Force this to reconstruct on mutation to keep the reconstruction
- # bytecode simple
- dict_vt.should_reconstruct_all = True
- result = UserDefinedDictVariable(value, dict_vt=dict_vt, source=self.source)
- return self.tx.output.side_effects.track_object_existing(value, result)
- elif isinstance(value, tuple):
- self.install_guards(GuardBuilder.TYPE_MATCH)
- self.install_guards(GuardBuilder.SEQUENCE_LENGTH)
- # NB - Be careful in not triggering user code. Guards also work on
- # the underlying tuple data structure.
- output = [
- LazyVariableTracker.create(
- tuple.__getitem__(value, i),
- source=GetItemSource(self.get_source(), i),
- )
- for i in range(tuple.__len__(value))
- ]
- tuple_vt = TupleVariable(
- output, source=self.source, mutation_type=ValueMutationExisting()
- )
- result = UserDefinedTupleVariable(
- value, tuple_vt=tuple_vt, source=self.source
- )
- return self.tx.output.side_effects.track_object_existing(value, result)
- elif isinstance(value, list):
- self.install_guards(GuardBuilder.TYPE_MATCH)
- self.install_guards(GuardBuilder.SEQUENCE_LENGTH)
- # NB - Be careful in not triggering user code. Guards also work on
- # the underlying list data structure.
- output = [
- LazyVariableTracker.create(
- list.__getitem__(value, i),
- source=ListGetItemSource(self.get_source(), i),
- )
- for i in range(list.__len__(value))
- ]
- list_vt = ListVariable(
- output, source=self.source, mutation_type=ValueMutationExisting()
- )
- result = UserDefinedListVariable(value, list_vt=list_vt, source=self.source)
- return self.tx.output.side_effects.track_object_existing(value, result)
- elif isinstance(value, (set, frozenset)):
- self.install_guards(GuardBuilder.TYPE_MATCH)
- self.install_guards(GuardBuilder.SEQUENCE_LENGTH)
- L = list(dict.fromkeys(value))
- output = [
- LazyVariableTracker.create(
- list.__getitem__(L, i),
- source=NonSerializableSetGetItemSource(self.get_source(), i),
- )
- for i in range(list.__len__(L))
- ]
- set_vt_cls = SetVariable if isinstance(value, set) else FrozensetVariable
- set_vt = set_vt_cls(
- output, source=self.source, mutation_type=ValueMutationExisting()
- )
- result = UserDefinedSetVariable(value, set_vt=set_vt, source=self.source)
- return self.tx.output.side_effects.track_object_existing(value, result)
- elif issubclass(type(value), MutableMapping):
- self.install_guards(GuardBuilder.TYPE_MATCH)
- result = MutableMappingVariable(value, source=self.source)
- return self.tx.output.side_effects.track_object_existing(value, result)
- elif is_frozen_dataclass(value):
- self.install_guards(GuardBuilder.TYPE_MATCH)
- result = FrozenDataClassVariable.create(self.tx, value, source=self.source)
- return self.tx.output.side_effects.track_object_existing(value, result)
- elif isinstance(value, dict_keys):
- if all(ConstantVariable.is_literal(k) for k in value):
- # If the dict_keys object is passed from outside the compile region, it must either be passed along with
- # the corresponding dict object or treated as a set (when only the keys are passed into the compiled region).
- # - If it is passed along with the dict, the dict object itself is already guarded.
- # - If only the dict_keys object is passed, we add EQUALS_MATCH and SEQUENCE_LENGTH guards
- # to ensure it remains unchanged across multiple runs.
- items = [SourcelessBuilder.create(self.tx, v) for v in value]
- install_guard(
- self.get_source().make_guard(GuardBuilder.SEQUENCE_LENGTH),
- self.get_source().make_guard(GuardBuilder.EQUALS_MATCH),
- )
- return DictKeySetVariable(items, source=self.source)
- else:
- unimplemented_v2(
- gb_type="non-const keys in dict_keys",
- context=f"non-const keys: {[k for k in value if not ConstantVariable.is_literal(k)]}",
- explanation="Dynamo expects dict_keys keys to be constants.",
- hints=[
- "Ensure your dict_keys keys are constants (e.g. int, float, strings)",
- ],
- )
- elif IntWrapperVariable.is_matching_object(value):
- from torch.export.dynamic_shapes import _DimHintType
- if value.dynamism is None or value.dynamism.type == _DimHintType.STATIC:
- return self.wrap_symint(value.val)
- elif value.dynamism.type == _DimHintType.DYNAMIC:
- log.debug(
- "%s marked %s via IntWrapper",
- self.source.name(),
- DimDynamic.DYNAMIC,
- )
- return self.wrap_symint(
- value.val,
- dynamism=DimDynamic.DYNAMIC,
- context=SymIntSymbolicContext(
- constraint=RelaxedUnspecConstraint(warn_only=False)
- ),
- )
- elif value.dynamism.type == _DimHintType.AUTO:
- log.debug(
- "%s marked %s via IntWrapper",
- self.source.name(),
- DimDynamic.DYNAMIC,
- )
- return self.wrap_symint(value.val, dynamism=DimDynamic.DYNAMIC)
- else:
- raise RuntimeError(f"Undefined dynamism {value.dynamism}")
- else:
- return self.wrap_user_defined(value)
def wrap_user_defined(self, value: Any):
    """Wrap an arbitrary object as a ``UserDefinedObjectVariable``.

    A TYPE_MATCH guard is installed on the current source first. The new
    variable is registered with ``side_effects`` (and the tracked result is
    returned) only when ``SideEffects`` reports that the object's class
    supports mutation side effects; otherwise the plain variable is returned.
    """
    self.install_guards(GuardBuilder.TYPE_MATCH)
    user_obj_vt = UserDefinedObjectVariable(value, source=self.source)
    if SideEffects.cls_supports_mutation_side_effects(type(value)):
        return self.tx.output.side_effects.track_object_existing(
            value, user_obj_vt
        )
    # don't allow STORE_ATTR mutation with custom __setattr__
    return user_obj_vt
- def wrap_listlike(self, value: Union[tuple, list, odict_values, NamedTuple]):  # Wrap a list-like input into a list/tuple VariableTracker; items are wrapped lazily under GetItemSource(..., i).
- for item in value:  # reject containers that hold a direct reference to themselves
- if item is value:
- unimplemented_v2(
- gb_type="list elements are pointing to the list itself",
- context="",
- explanation="Dynamo does not support lists whose items reference to itself",
- hints=["Avoid using self referential list"],
- )
- if config.specialize_int and type(value) is torch.Size:  # an exact torch.Size is returned as a CONSTANT_MATCH-guarded constant when ints are specialized
- self.install_guards(GuardBuilder.CONSTANT_MATCH)
- return ConstantVariable.create(value=value)
- # One can index a tensor with a list/tuple. Therefore, we need to
- # have a stricter match.
- self.install_guards(GuardBuilder.SEQUENCE_LENGTH)
- # Tuples are immutable objects, so we should mark its items static. This
- # avoids wrapping of tuple items as symints. This helps for nn module
- # attributes like conv2d strides, dilations.
- if (
- istype(value, tuple)
- and all(ConstantVariable.is_literal(item) for item in value)
- and self.source.guard_source().is_unspecialized_nn_module()
- ):
- self.install_guards(GuardBuilder.CONSTANT_MATCH)
- return TupleVariable([ConstantVariable.create(item) for item in value])
- output = [  # one lazy tracker per element; nothing is built until it is realized
- LazyVariableTracker.create(
- item,
- source=GetItemSource(self.get_source(), i),
- )
- for i, item in enumerate(value)
- ]
- maybe_gm = self.tx.output.local_scope.get("self")  # whatever is bound to the local name "self" (may be None); handed to get_locals_to_steal
- if isinstance(
- self.source, LocalSource
- ) and self.source.local_name in get_locals_to_steal(maybe_gm):
- # The input tensor list to dynamo from compiled autograd may contain activations
- # which are freed as they are used in inductor. Dynamo's default behavior is to
- # lift all tensors to the graph inputs, but this will cause dynamo to hold an
- # extra reference to the activation tensors and increase peak memory usage.
- # To allow freeing ASAP, we keep the list as graph argument to the dynamo output
- # graph, and unpack it locally.
- # e.g. instead of `def forward(self, L_inputs_0_, L_inputs_1_, ...):`, we have
- # `def forward(self, L_inputs_):`
- source = self.source
- assert isinstance(value, list)
- tensor_list_proxy = self.tx.output.root_tracer.create_graph_input(
- re.sub(r"[^a-zA-Z0-9]+", "_", self.name),
- type(value),
- value,
- source=source,
- )
- tensor_list_proxy.node.meta["steal_arg"] = True  # marks the whole list as a single stolen graph argument
- list_variable = wrap_fx_proxy_cls(
- target_cls=TensorVariable,
- tx=self.tx,
- proxy=tensor_list_proxy,
- example_value=value,
- subclass_type=None,
- source=source,
- )
- # Apply relevant logic from `VariableTracker.build(value[i])`
- # (except for the `create_graph_input` stuff).
- guards = []
- for i, tensor_variable in enumerate(list_variable.items):
- source_i = GetItemSource(base=source, index=i, index_is_slice=False)
- # access unpacked tensor from this list instead of from a lifted arg
- self.tx.output.input_source_to_var[source_i] = tensor_variable
- tensor_variable.proxy.node.meta["tensor_dict"] = _extract_tensor_dict(
- value[i]
- )
- guard = functools.partial(  # TENSOR_MATCH guard bound to a weak reference of the element
- GuardBuilder.TENSOR_MATCH, value=TensorWeakRef(value[i])
- )
- guards.append(source_i.make_guard(guard))
- install_guard(*guards, skip=1)
- grapharg = GraphArg(
- source,
- value,
- pass_arg_as_tensor=False,
- fake_tensor=None,
- is_tensor=False,
- )
- tensor_list_proxy.node.meta["grapharg"] = grapharg
- # The following is very important for maintaining the "python object
- # <==> variable tracker" 1-to-1 mapping, which is mainly handled via
- # `side_effects`. Note that constructing `tensor_variable` above
- # already adds it to graph arg, but we never registered it with
- # `side_effects`. The preemptive `realize` calls here basically
- # does that registration (at the end of `self.__call__`).
- #
- # A slightly cleaner alternative is to register the
- # `tensor_variable`s above with `side_effects` directly, and just
- # return the `list_variable`, but that breaks some tensor-subclass
- # related tests like `test_inputs_aliasing_bytecode_stack_restore`,
- # because `tensor_variable` is constructed via
- # `handle_traced_output`, which doesn't really expect/handle tensor
- # subclass.
- #
- # Eventually, we expect to remove all of these by having Dynamo
- # auto-boxing inputs to the compiled graph, see
- # https://github.com/pytorch/pytorch/issues/153701.
- for vt in output:
- vt.realize()
- result = BaseListVariable.cls_for_instance(value)(output, source=self.source)  # variable class is chosen from the runtime type of `value`
- if istype(value, (list, collections.deque)):  # only exact list/deque inputs are registered as mutable with side_effects
- return self.tx.output.side_effects.track_mutable(value, result)
- return result
def wrap_tuple_iterator(self, value: tuple_iterator):
    """Wrap a tuple iterator as a ``TupleIteratorVariable``.

    Installs a TUPLE_ITERATOR_LEN guard, eagerly builds a variable for each
    remaining element (each under its own ``TupleIteratorGetItemSource``),
    and registers the resulting iterator variable as mutable with
    ``side_effects``.
    """
    self.install_guards(GuardBuilder.TUPLE_ITERATOR_LEN)
    elements = []
    for idx in range(tuple_iterator_len(value)):
        element_builder = VariableBuilder(
            self.tx, TupleIteratorGetItemSource(self.get_source(), idx)
        )
        elements.append(element_builder(tuple_iterator_getitem(value, idx)))
    iterator_vt = TupleIteratorVariable(elements, source=self.source)
    return self.tx.output.side_effects.track_mutable(value, iterator_vt)
def wrap_range_iterator(self, value: range_iterator):
    """Wrap a range iterator as a ``ListIteratorVariable`` of constants.

    The result is registered as mutable with ``side_effects``.
    """
    self.install_guards(GuardBuilder.RANGE_ITERATOR_MATCH)
    # Items need no guards of their own: `RANGE_ITERATOR_MATCH` guarantees
    # the same items. The values are read from a deep copy of the iterator
    # rather than from `value` itself.
    constants = []
    for number in copy.deepcopy(value):
        constants.append(ConstantVariable.create(number))
    iterator_vt = ListIteratorVariable(constants, source=self.source)
    return self.tx.output.side_effects.track_mutable(value, iterator_vt)
def wrap_slice_range(self, value: Union[slice, range]):
    """Wrap a ``slice`` or ``range`` object.

    ``start``, ``stop`` and ``step`` are each built through a
    ``VariableBuilder`` rooted at the matching ``AttrSource``; afterwards a
    TYPE_MATCH guard is installed on the object itself. Returns a
    ``SliceVariable`` for slices and a ``RangeVariable`` otherwise.
    """
    components = []
    for attr_name in ("start", "stop", "step"):
        attr_builder = VariableBuilder(
            self.tx, AttrSource(self.get_source(), attr_name)
        )
        components.append(attr_builder(getattr(value, attr_name)))
    self.install_guards(GuardBuilder.TYPE_MATCH)
    variable_cls = SliceVariable if isinstance(value, slice) else RangeVariable
    return variable_cls(components, source=self.source)
def mark_static_input(self, value: torch.Tensor, guard: bool):
    """Mark ``value`` as a static-address input and refresh graph metadata.

    Args:
        value: tensor to mark via ``mark_static_address``.
        guard: forwarded unchanged to ``mark_static_address``.

    If the tensor is already tracked by ``side_effects``, the
    ``_dynamo_static_input_type`` entry in its proxy node's ``tensor_dict``
    metadata is updated to match the freshly marked tensor.
    """
    from ..decorators import mark_static_address

    # The message previously ended with an unbalanced ")".
    static_inputs_log.debug(
        "Marking static input %s, id: %s", self.source.name(), id(value)
    )
    mark_static_address(value, guard=guard)

    # Check if we've seen this tensor before and update graph metadata if needed
    # As long as this runs before AOT this is sound
    if value in self.tx.output.side_effects:
        var = self.tx.output.side_effects[value]
        var.proxy.node.meta["tensor_dict"]["_dynamo_static_input_type"] = (
            value._dynamo_static_input_type
        )
- def wrap_module(self, value: torch.nn.Module):  # Wrap an nn.Module input; the VariableTracker kind depends on the module (OptimizedModule, FSDP-managed, dynamic/unspecialized, DDP, or registered attribute).
- from ..eval_frame import OptimizedModule
- if len(value.__dict__) == 0:  # an empty instance __dict__ is reported as an uninitialized module
- unimplemented_v2(
- gb_type="Uninitialized nn.Module",
- context=typestr(value),
- explanation=f"Attempted to trace an uninitialized nn.Module of type {typestr(value)}.",
- hints=[
- *graph_break_hints.USER_ERROR,
- "Ensure your nn.Module instance has called `super().__init__()`.",
- ],
- )
- if istype(value, OptimizedModule):
- # Check if the optimized module was disabled
- if inspect.getattr_static(value.forward, "_torchdynamo_disable", False):
- # This bytecode is mostly of kind LOAD_ATTR or LOAD_METHOD. If
- # we graph break here, Dynamo does not know how to create
- # continuation functions for such bytecodes. So, we delay the
- # graph break to CALL_FUNCTION.
- msg = inspect.getattr_static(
- value.forward, "_torchdynamo_disable_msg", None
- )
- return DelayGraphBreakVariable(
- source=self.source,
- msg=f"Optimized `nn.Module` is wrapped with `torch.compiler.disable` (reason: {msg})",
- )
- self.install_guards(GuardBuilder.TYPE_MATCH)  # guard the wrapper's type, then recurse into the wrapped module
- self.source = AttrSource(self.source, "_orig_mod")  # note: rebinds this builder's source to the `_orig_mod` attribute
- return self.wrap_module(value._orig_mod)
- if (
- isinstance(value, (torch.nn.RNN, torch.nn.GRU, torch.nn.LSTM))
- and not config.allow_rnn
- ):
- unimplemented_v2(
- gb_type="Attempted to wrap RNN, GRU, or LSTM",
- context=str(value),
- explanation="Dynamo does not support RNN, GRU, or LSTM.",
- hints=[*graph_break_hints.SUPPORTABLE],
- )
- if getattr(value, "_is_fsdp_managed_module", False):
- # See note [Dynamo treats FSDP wrapped modules as UnspecializedNNModule]
- # in fully_sharded_data_parallel.py for more information
- # we can't do this assert inside FSDP constructor,
- # since we don't know yet whether dynamo will be used
- if not getattr(value, "_fsdp_use_orig_params", False):
- unimplemented_v2(
- gb_type="FSDP with use_orig_params=False",
- context="",
- explanation="Dynamo only supports FSDP with use_orig_params=True",
- hints=[],
- )
- # Note on FSDP guarding
- # Eager FSDP already assumes (requires, but without enforcement)
- # that users don't mutate their model parameters/structure after
- # FSDP wrapping, because FSDP wouldn't notice or update its
- # FlatParams.
- #
- # Therefore, torch.compile can skip guarding on params or submodule
- # structure of fsdp_managed modules, by using FSDPNNModuleSource as
- # the guard source. This behavior is gated on
- # config.skip_fsdp_guards.
- self.install_guards(GuardBuilder.TYPE_MATCH)
- result = FSDPManagedNNModuleVariable(value, source=self.get_source())
- if not SideEffects.cls_supports_mutation_side_effects(type(value)):
- # don't allow STORE_ATTR mutation with custom __setattr__
- return result
- return self.tx.output.side_effects.track_object_existing(value, result)
- elif mutation_guard.is_dynamic_nn_module(value, self.tx.export):
- # created dynamically, don't specialize on it
- # Note [Tracing a torch.compiled function]
- # when make_fx is tracing a compiled function the module can be an _AttrProxy (NOTE(review): inferred from the check below); unwrap it to its base module and wrap the source to match
- if isinstance(value, torch.fx.experimental.proxy_tensor._AttrProxy):
- value = value.get_base()
- self.source = AttrProxySource(self.source)
- if torch._dynamo.config.inline_inbuilt_nn_modules:
- freezing = is_parameter_freezing()
- # Guard against the case where user may overwrite named parameters
- # / named buffers
- # NOTE: This is not likely to happen but worth guarding to avoid
- # exception
- if (
- callable(value.named_parameters)
- and value.named_parameters.__func__
- is og_module_named_parameters_fn_ptr
- ):
- try: # catch TypeErrors in named_parameters() from unserializable nn modules
- for _, p in value.named_parameters():
- self.mark_static_input(p, guard=freezing)
- except TypeError as e:
- raise_observed_exception(type(e), self.tx, args=list(e.args))
- if (
- callable(value.named_buffers)
- and value.named_buffers.__func__ is og_module_named_buffers_fn_ptr
- ):
- try: # catch TypeErrors in named_buffers() from unserializable nn modules
- for _, b in value.named_buffers():
- self.mark_static_input(b, guard=freezing)
- except TypeError as e:
- raise_observed_exception(type(e), self.tx, args=list(e.args))
- if freezing:
- # we need to add the module to tracing context
- # in order to allow its params to get invalidated
- # this will get cleaned up once compile ends
- self.tx.output.nn_modules[self.name] = value
- if (
- value.__module__.startswith(("torch.nn.modules", "torch.ao."))
- and not value.__module__.startswith("torch.nn.modules.container")
- ) or getattr(value.__class__, "_dynamo_marked_static", False):  # built-in (non-container) torch modules, or classes flagged `_dynamo_marked_static`
- new_source = self.source
- if config.inline_inbuilt_nn_modules and (
- not self.tx.output.export or config.install_free_tensors
- ):
- # Export corner case - look at test_repros.py test_inlining_cornercase
- new_source = UnspecializedBuiltinNNModuleSource(self.source)
- result = UnspecializedBuiltinNNModuleVariable(value, source=new_source)
- install_guard(new_source.make_guard(GuardBuilder.TYPE_MATCH))
- else:
- new_source = self.source
- if config.inline_inbuilt_nn_modules and (
- not self.tx.output.export or config.install_free_tensors
- ):
- # Export corner case - look at test_repros.py test_inlining_cornercase
- new_source = UnspecializedNNModuleSource(self.source)
- result = UnspecializedNNModuleVariable(value, source=new_source)
- install_guard(new_source.make_guard(GuardBuilder.TYPE_MATCH))
- if not SideEffects.cls_supports_mutation_side_effects(type(value)):
- # don't allow STORE_ATTR mutation with custom __setattr__
- return result
- return self.tx.output.side_effects.track_object_existing(value, result)
- elif issubclass(
- value.__class__, torch.nn.parallel.distributed.DistributedDataParallel
- ):
- self.install_guards(GuardBuilder.TYPE_MATCH)
- return UnspecializedNNModuleVariable(value, source=self.get_source())  # returned directly, without side_effects registration
- else:
- return self.tx.output.register_attr_or_module(
- value,
- self.name,
- source=self.get_source(),
- # Guards are added inside register_attr_or_module
- )
- def wrap_literal(self, value):  # Wrap a literal: ints and floats may become symbolic depending on config and source; anything else is guarded as a constant.
- if type(value) is int:  # exact int only; the check is on type(value), so int subclasses such as bool take the other branches
- # allowlist has higher precedence over specialization control.
- if is_dynamic_source(self.source.name()):
- log.debug("%s marked dynamic via source whitelist", self.source.name())
- return self.wrap_symint(value, dynamism=DimDynamic.DYNAMIC)
- if is_unbacked_source(self.source.name()):
- log.debug("%s marked unbacked via source whitelist", self.source.name())
- return self.wrap_symint(value, dynamism=DimDynamic.SIZE_LIKE_UNBACKED)
- if not config.specialize_int:
- # unspecializing int by default, but still
- # specialize for the following conditions
- if is_int_specialization_case(value, self.source):
- recompile_hint = None
- if (
- self.source.guard_source().is_unspecialized_builtin_nn_module()
- or self.source.guard_source().is_unspecialized_nn_module()
- ):
- # This means that it is an integer from a NN module.
- # Dynamo considers nn module int attributes to be static
- # (a good heuristic). But a user might want to mark the
- # int attribute to be a symint, so track this integer
- # for recompilation later.
- recompile_hint = (
- "torch.compile considers integer attributes of the nn.Module to be static. "
- "If you are observing recompilation, you might want to make this integer dynamic "
- "using torch._dynamo.config.allow_unspec_int_on_nn_module = True, or convert this "
- "integer into a tensor."
- )
- process_automatic_dynamic(
- self.tx,
- self.source.name(),
- FrameStateSizeEntry.make_scalar(value),
- is_unspecialized_nn_module=self.source.guard_source().is_unspecialized_nn_module(),
- )
- self.install_guards(
- functools.partial(  # EQUALS_MATCH guard carrying the optional recompile hint
- GuardBuilder.EQUALS_MATCH, recompile_hint=recompile_hint
- )
- )
- return ConstantVariable.create(value=value, source=self.source)
- return self.wrap_symint(value)
- elif not config.specialize_float and type(value) is float:
- return self.wrap_symfloat(value)
- else:
- self.install_guards(GuardBuilder.CONSTANT_MATCH)
- result = ConstantVariable.create(value=value, source=self.source)
- if isinstance(value, (list, set)):  # list/set values are additionally registered as mutable with side_effects
- return self.tx.output.side_effects.track_mutable(value, result)
- return result
- def assert_not_wrapped_by_this_graph(self, value: torch.Tensor):
- if is_fake(value) and maybe_get_fake_mode(value) is self.tx.fake_mode:
- raise InternalTorchDynamoError(
- "Cannot wrap a Tensor that has already been",
- "wrapped by this instance of Dynamo",
- )
- def wrap_tensor(self, value: torch.Tensor):  # Wrap a tensor input: either register it as a graph attribute, or lift it to a root-tracer graph input and install its guards.
- source = self.get_source()
- # We cannot already be tracking the tensor, which implies
- # it would have already been wrapped
- assert value not in self.tx.output.side_effects
- is_static_input = get_static_address_type(value) is not None
- if (
- config.inline_inbuilt_nn_modules
- and not is_static_input
- and (
- isinstance(value, torch.nn.Parameter)
- # mark tensor attributes of nn modules static. This is done to keep inline_inbuilt_nn_modules behavior
- # compatible with previous behavior.
- or (source and source.guard_source().is_unspecialized_nn_module())
- )
- ):
- self.mark_static_input(value, guard=is_parameter_freezing())
- is_static_input = True
- # Install any tensors which are "free" variables; that is:
- # 1. Globals
- # 2. NonLocals
- # 3. tensors that are attributes of nn module
- should_install_free_tensor = config.install_free_tensors and (
- is_from_global_source(source)
- or is_from_nonlocal_source(source)
- or is_from_unspecialized_nn_module_source(source)
- )
- make_graph_attribute = is_static_input and (
- not config.inline_inbuilt_nn_modules
- or is_parameter_freezing()
- or torch._dynamo.config.prepare_freezing
- )
- if should_install_free_tensor or (
- (source.guard_source().is_specialized_nn_module() or make_graph_attribute)
- and not source.guard_source().is_fsdp_module()
- ):
- self.assert_not_wrapped_by_this_graph(value)
- return self.tx.output.register_attr_or_module(
- value, self.name, source=source
- )
- if get_static_address_type(value) == "guarded":
- # If it's a guarded tensor, we can install the parameter directly
- # into the Fx graph instead of lifting it as an input. Lifting
- # offers no benefit, such as regional compilation, since we still
- # guard on the tensor's ID. Moreover, installing it in the Fx graph
- # eliminates the pre-graph bytecode required to extract the tensor
- # from locals/globals, reducing overhead. This can lead to
- # significant cost savings, especially for optimizers handling many
- # tensors.
- self.install_guards(GuardBuilder.ID_MATCH)
- self.assert_not_wrapped_by_this_graph(value)
- return self.tx.output.register_attr_or_module(
- value, self.name, source=source
- )
- if is_constant_source(source):
- self.assert_not_wrapped_by_this_graph(value)
- return self.tx.output.register_attr_or_module(
- value,
- re.sub(r"[^a-zA-Z0-9]+", "_", self.name),
- source=source,
- # Guards are added inside register_attr_or_module
- )
- # NB: this just says we accessed a tensor from the same source again
- # (e.g., a tensor lives in a global foo, and we LOAD_GLOBAL it twice).
- # This is distinct from two distinct sources mapping to the same
- # Tensor (per id())! No guard is necessary here. See below for the
- # other case.
- is_duplicate_tensor = source in self.tx.output.input_source_to_var
- if is_duplicate_tensor:
- return self.tx.output.input_source_to_var[source]
- options = {}  # never populated in this method; forwarded as **options to wrap_fx_proxy below
- subclass_type = infer_subclass_type(value)
- if subclass_type is not None:
- self.install_guards(GuardBuilder.TYPE_MATCH)
- if get_static_address_type(value) == "guarded":
- self.install_guards(GuardBuilder.ID_MATCH)
- # By this point, we should have deduplicated all tensors
- self.assert_not_wrapped_by_this_graph(value)
- if (
- isinstance(value, torch.Tensor)
- and value.is_nested
- and not isinstance(value, torch.nested._internal.nested_tensor.NestedTensor)
- ):
- unimplemented_v2(
- gb_type="Attempted to wrap strided NestedTensor",
- context="",
- explanation="torch.compile does not support strided NestedTensor",
- hints=[],
- )
- # TODO(pearu,sparse-team) - Add the corresponding SPARSE_TENSOR_MATCH guards
- if (
- isinstance(value, torch.Tensor)
- and is_sparse_any(value)
- and (not self.tx.export or not config.capture_sparse_compute)
- ):
- # A hot fix for sparse tensors + torch.compile. Support for
- # export + sparsity is being added but we need to create
- # SPARSE_TENSOR_GUARDS for guards to work properly.
- unimplemented_v2(
- gb_type="Attempted to wrap sparse Tensor",
- context="",
- explanation="torch.compile does not support sparse Tensors",
- hints=[*graph_break_hints.SUPPORTABLE],
- )
- if (
- safe_has_grad(value)
- and safe_grad(value) is not None
- and value.dtype != safe_grad(value).dtype
- ):
- unimplemented_v2(
- gb_type="dtype mismatch between tensor and its gradient",
- context=f"tensor dtype: {value.dtype}; grad dtype: {safe_grad(value).dtype}",
- explanation="Inconsistent dtype between tensor and its gradient. "
- "This can happen in FSDP and crashes meta tensor creation.",
- hints=[*graph_break_hints.SUPPORTABLE],
- )
- # tx.output has multiple tracers if we're introspecting HigherOrderOperator.
- # When we've discovered an untracked tensor, then we actually need
- # to get Dynamo to track the tensor (which is what this function does)
- # and put it as a graph input on the root tracer. Later on,
- # if the input is actually used in the body of the HigherOrderOperator,
- # then the relevant SubgraphTracer will lift it to being an input of
- # the subgraph.
- # See NOTE [HigherOrderOperator tracing design] for more details.
- example_value = wrap_to_fake_tensor_and_record(
- value, tx=self.tx, is_tensor=True, source=source
- )
- tensor_proxy = self.tx.output.root_tracer.create_graph_input(
- re.sub(r"[^a-zA-Z0-9]+", "_", self.name),
- type(value),
- example_value,
- source=source,
- )
- cache_real_value_when_export(self.tx, tensor_proxy, value)
- tensor_variable = wrap_fx_proxy(
- tx=self.tx,
- proxy=tensor_proxy,
- example_value=example_value,
- subclass_type=subclass_type,
- source=source,
- **options,
- )
- if value._is_view():
- # If value is a view, add its base tensor to the tracked fakes list.
- # This is so we are able to access the correct source for its symbolic
- # shape values, in case we need them.
- wrap_to_fake_tensor_and_record(
- value._base,
- tx=self.tx,
- source=AttrSource(source, "_base"),
- is_tensor=True,
- )
- guard_type = GuardBuilder.TENSOR_MATCH
- if isinstance(source, GradSource) and is_from_optimizer_source(source):  # grads reached through an optimizer source get NOT_NONE_MATCH instead of TENSOR_MATCH
- guard_type = GuardBuilder.NOT_NONE_MATCH
- self.install_guards(
- functools.partial(
- guard_type,
- value=(  # the tensor itself for NumpyTensorSource, otherwise a TensorWeakRef to it
- value
- if isinstance(source, NumpyTensorSource)
- else TensorWeakRef(value)
- ),
- )
- )
- # We install TYPE_MATCH guards for traceable wrapper subclass object,
- # and recursively install corresponding guard for each inner attribute.
- if is_traceable_wrapper_subclass(value):
- self.install_guards(GuardBuilder.TENSOR_SUBCLASS_METADATA_MATCH)
- self.install_guards(GuardBuilder.TYPE_MATCH)
- install_guard(
- SubclassAttrListSource(source).make_guard(GuardBuilder.EQUALS_MATCH)
- )
- attrs, _ = value.__tensor_flatten__()
- for attr in attrs:
- inner_value = getattr(value, attr)
- inner_source = AttrSource(self.source, attr)
- LazyVariableTracker.realize_all(  # build each inner attribute now; the returned tracker is discarded
- VariableBuilder(self.tx, inner_source)(inner_value)
- )
- self.tx.output.input_source_to_var[source] = tensor_variable  # recorded so a later wrap of the same source hits the duplicate check above
- assert "tensor_dict" not in tensor_proxy.node.meta
- tensor_proxy.node.meta["tensor_dict"] = _extract_tensor_dict(value)
- # Note: this information is conveyed via subclass_type now
- fake_tensor_value = tensor_variable.proxy.node.meta["example_value"]
- if maybe_get_fake_mode(fake_tensor_value) is not self.tx.fake_mode:
- raise InternalTorchDynamoError("Wrapped Tensor must be this graph's fake")
- grapharg = GraphArg(source, value, False, fake_tensor_value)  # NOTE(review): positional False/fake_tensor_value presumably map to pass_arg_as_tensor/fake_tensor - confirm against GraphArg
- tensor_proxy.node.meta["grapharg"] = grapharg
- return tensor_variable
- def wrap_numpy_ndarray(self, value):  # Wrap a numpy ndarray: convert it to a tensor, lift that tensor as a graph input, and return a NumpyNdarrayVariable.
- assert np is not None
- assert isinstance(value, np.ndarray)
- source = NumpyTensorSource(self.get_source())
- from torch._numpy import _util
- readonly = not value.flags.writeable  # remembered so the converted tensor can be cloned below
- if readonly:
- try:
- value.flags.writeable = True  # flip the flag before conversion
- except ValueError:
- # One can not easily make nditer elements writable,
- # but warning is not the end of the world
- assert isinstance(value.base, np.nditer)
- with torch_function_mode_stack_state_mgr.temp_restore_stack():
- try:
- tensor_value = _util._try_convert_to_tensor(value)
- if readonly:
- from torch._prims_common import clone_preserve_strides
- tensor_value = clone_preserve_strides(tensor_value)  # read-only arrays are wrapped via a clone rather than the converted tensor itself
- except NotImplementedError as e:
- # failed to convert to tensor, graph break
- unimplemented_v2(
- gb_type="failed to convert numpy.ndarray to Tensor",
- context=str(value),
- explanation="Exception encountered when attempting to convert numpy.ndarray to Tensor",
- hints=[],
- from_exc=e,
- )
- # We do this because we want the full behavior of guarding the numpy ndarray as if it were
- # a tensor. It's a little annoying to make a VT to throw out, but there's so many side effects here
- # that there's not another great way to do this atm.
- # This creates the right graphargs, as well as registration for guards in tensor names and shape env.
- LazyVariableTracker.realize_all(VariableBuilder(self.tx, source)(tensor_value))
- example_value = wrap_to_fake_tensor_and_record(
- tensor_value,
- tx=self.tx,
- is_tensor=False,
- source=source,
- )
- proxy = self.tx.output.root_tracer.create_graph_input(
- re.sub(r"[^a-zA-Z0-9]+", "_", self.name),
- type(tensor_value),
- example_value,
- source=source,
- )
- cache_real_value_when_export(self.tx, proxy, tensor_value)
- options = {"source": source}
- numpy_ndarray_variable = wrap_fx_proxy_cls(
- target_cls=NumpyNdarrayVariable,
- tx=self.tx,
- proxy=proxy,
- example_value=example_value,
- **options,
- )
- self.tx.output.input_source_to_var[source] = numpy_ndarray_variable
- example_value = numpy_ndarray_variable.proxy.node.meta["example_value"]
- # pass_arg_as_tensor should be true because we are wrapping a np.ndarray as argument input, and it needs to be
- # converted to a tensor.
- grapharg = GraphArg(
- source,
- tensor_value,
- pass_arg_as_tensor=True,
- fake_tensor=example_value,
- is_tensor=True,
- example_strong_ref=tensor_value,
- )
- proxy.node.meta["grapharg"] = grapharg
- # TODO - Why do we need to set the source of the np ndarray vt back to the
- # original source? Many tests fail otherwise.
- numpy_ndarray_variable.source = self.source
- return numpy_ndarray_variable
- def wrap_symint(
- self,
- value,
- dynamism: Optional[DimDynamic] = None,
- context: Optional[SymIntSymbolicContext] = None,
- ):  # Wrap a Python int either as a guarded constant or as a SymInt graph input, depending on source, config and `dynamism`.
- assert type(value) is int
- if self.name in self.tx.output.unspec_variable_map:  # reuse the variable already created under this name
- return self.tx.output.unspec_variable_map[self.name]
- shape_env = self.tx.output.shape_env
- if TracingContext.get().force_unspec_int_unbacked_size_like:
- wrapped_value = shape_env.create_unbacked_symint()
- _constrain_range_for_size(wrapped_value)
- self.tx.output.tracked_fakes.append(
- TrackedFake(wrapped_value, self.source, None)
- )
- # NB: We do not do float. For motivation, see
- # https://docs.google.com/document/d/1INSCdYu1PxXcr43HrD82OudeEuS-qxQe1yZmLg2wy6A/edit
- # but the general idea is that we generate kernels that can
- # take unspecialized floats and use them in sizevar computation
- elif not is_constant_source(self.get_source()):
- if dynamism is None and torch._dynamo.config.specialize_int:
- # If specialize_int is True and no dynamism was requested, return
- # a constant (but this should have been handled
- # in the caller, TBH). But if `dynamism` is set, then actually
- # turn it into a symint
- self.install_guards(GuardBuilder.CONSTANT_MATCH)
- return ConstantVariable.create(value=value, source=self.source)
- name = self.source.name()
- frame_state_entry = process_automatic_dynamic(
- self.tx,
- name,
- FrameStateSizeEntry.make_scalar(value),
- is_unspecialized_nn_module=self.source.guard_source().is_unspecialized_nn_module(),
- )
- # TODO: This should be dynamic, as we in general do not
- # know if bare integers are actually going to be sizevars
- # and it is inappropriate to eagerly duck size them with
- # real sizevars
- normalized_source_name = normalize_source_name(self.source.name())
- base_source = self.source
- if isinstance(base_source, ChainedSource):
- base_source = base_source.get_base()
- if dynamism is not None:  # an explicit request from the caller takes precedence
- dynamic_dim = dynamism
- elif (
- config.automatic_dynamic_shapes
- and frame_state_entry.scalar is auto_dynamic
- ):
- set_feature_use("dynamo.automatic_dynamic_shapes", True)
- dynamic_dim = get_automatic_dynamic_shapes_mark_as()
- elif (
- isinstance(base_source, LocalSource)
- and base_source.dynamism is not None
- and dict(base_source.dynamism).get(normalized_source_name, {0: False})[
- 0
- ]
- ) or not config.assume_static_by_default:
- dynamic_dim = DimDynamic.DYNAMIC
- else: # assume_static_by_default
- # TODO: dynamic_dim = DimDynamic.STATIC should work but
- # for some reason it doesn't
- if frame_state_entry.scalar is auto_dynamic:
- set_feature_use("dynamo.automatic_dynamic_shapes", False)
- self.install_guards(GuardBuilder.CONSTANT_MATCH)
- return ConstantVariable.create(value=value)
- wrapped_value = shape_env.create_unspecified_symint_and_symbol(
- value,
- source=self.source,
- dynamic_dim=dynamic_dim,
- )
- self.tx.output.tracked_fakes.append(
- TrackedFake(wrapped_value, self.source, context)
- )
- else:
- assert is_constant_source(self.get_source())
- # TODO: Do I actually need guard for constant source?
- self.install_guards(GuardBuilder.CONSTANT_MATCH)
- return ConstantVariable.create(value=value, source=self.source)
- assert not isinstance(self.get_source(), RandomValueSource)
- install_guard(self.get_source().make_guard(GuardBuilder.TYPE_MATCH))
- options = {"source": self.get_source()}
- proxy = self.tx.output.root_tracer.create_graph_input(
- re.sub(r"[^a-zA-Z0-9]+", "_", self.name),
- type(wrapped_value),
- wrapped_value,
- source=self.get_source(),
- )
- sym_expr = wrapped_value.node.expr
- assert isinstance(sym_expr, sympy.Symbol), f"{sym_expr} is not a basic Symbol."
- self.tx.output.root_tracer.bound_symbols[sym_expr] = proxy
- unspec_var = SymNodeVariable(proxy, wrapped_value, **options)
- self.tx.output.unspec_variable_map[self.name] = unspec_var  # cached so a later wrap under the same name returns this variable
- if not is_constant_source(self.get_source()):
- proxy.node.meta["grapharg"] = GraphArg(
- self.get_source(),
- wrapped_value,
- pass_arg_as_tensor=False,
- fake_tensor=None,
- is_tensor=False,
- example_strong_ref=wrapped_value,
- )
- return unspec_var
def wrap_symfloat(self, value):
    """Wrap a Python float read from ``self.source`` for tracing.

    If this name was already wrapped, the cached entry from
    ``unspec_variable_map`` is returned.  Otherwise the float is either
    specialized (a ``ConstantVariable`` guarded by ``CONSTANT_MATCH``) or
    lifted to a ``float64`` tensor graph input; in the latter case the
    variable wrapping ``item()`` on that input is returned and its symbolic
    number is appended to ``tracked_fakes``.

    Args:
        value: the float being wrapped.

    Raises:
        AssertionError: when exporting and the source is not a
            ``LocalSource`` (this would add an extra graph input), or when
            one of the internal consistency assertions fails.
    """
    # SymFloat wrapping is special.  We first wrap it in the same way we
    # do an unspecialized primitive, and then we item() it into a
    # SymFloat.  Removal of the item() call is left to a later FX pass,
    # mostly because that pass is more easily done after we have lowered
    # to ATen ops.  (Dynamo doesn't do decomposition right now).

    if self.name in self.tx.output.unspec_variable_map:
        return self.tx.output.unspec_variable_map[self.name]

    frame_state_entry = process_automatic_dynamic(
        self.tx,
        self.source.name(),
        FrameStateSizeEntry.make_scalar(value),
        is_unspecialized_nn_module=self.source.guard_source().is_unspecialized_nn_module(),
    )

    # NB: we specialize on nan input, because our guard modeling in
    # ShapeEnv cannot deal with nan
    if (
        torch._dynamo.config.specialize_float
        or is_constant_source(self.get_source())
        or math.isnan(value)
        or math.isinf(value)
        # We don't support cudagraphs for now. Without this cudagraphs
        # break because they expect all cuda inputs but our tensorified
        # float will be a f64[] cpu tensor. Fixes the following test
        # when specialize_float=False
        # python test/inductor/test_compiled_optimizers.py CompiledOptimizerTests.test_rmsprop_weight_decay_maximize_capturable_cuda # noqa: B950
        or torch._inductor.config.triton.cudagraphs
        or justknobs_check("pytorch/compiler:unspecialize_float_killswitch", False)
        or (
            config.assume_static_by_default
            and frame_state_entry.scalar is not auto_dynamic
        )
    ):
        self.install_guards(GuardBuilder.CONSTANT_MATCH)
        return ConstantVariable.create(value=value, source=self.source)

    # NB: At the point we've gotten here, we don't assume static by
    # default.  Since we have a guard mechanism, there isn't really any
    # downside to trying to be dynamic for float all the time.  Unlike
    # ints, this won't make codegen perf worse.  Modest cost to compile
    # time.

    wrapped_value = torch.tensor(value, dtype=torch.float64)

    # We don't support specializing floats for grad checking tensors
    # See https://github.com/pytorch/pytorch/pull/140828 for more
    # context.
    if torch._C._functorch.is_gradtrackingtensor(wrapped_value):
        self.install_guards(GuardBuilder.CONSTANT_MATCH)
        return ConstantVariable.create(value=value, source=self.source)

    # TODO: Switch RandomValueSource over to use this, this is more
    # accurate
    assert not isinstance(self.get_source(), RandomValueSource)
    install_guard(self.get_source().make_guard(GuardBuilder.TYPE_MATCH))

    # The FloatTensorSource here is just for pedantic correctness: if you
    # guard against an UnspecializedPythonVariable, you need to guard
    # against the tensor-ified version of the local, otherwise it's not a
    # Tensor.  However, we never let the UnspecializedPythonVariable escape
    # here, so there should never actually be any guards against this
    # source.
    source = FloatTensorSource(self.get_source())
    options = {"source": source, "raw_value": value}

    # TODO: Maybe the tensor-ification should be built into the source,
    # rather than by special pattern match
    example_value = wrap_to_fake_tensor_and_record(
        wrapped_value, tx=self.tx, is_tensor=False, source=source
    )
    proxy = self.tx.output.root_tracer.create_graph_input(
        re.sub(r"[^a-zA-Z0-9]+", "_", self.name),
        type(wrapped_value),
        example_value,
        source=source,
    )
    cache_real_value_when_export(self.tx, proxy, wrapped_value)

    unspec_var = wrap_fx_proxy_cls(
        UnspecializedPythonVariable,
        tx=self.tx,
        proxy=proxy,
        example_value=example_value,
        **options,
    )
    assert isinstance(unspec_var, UnspecializedPythonVariable)
    self.tx.output.unspec_variable_map[self.name] = unspec_var

    if self.tx.export and not isinstance(self.get_source(), LocalSource):
        raise AssertionError(
            f"Dynamo attempts to add additional input during export: value={wrapped_value}, source={self.get_source()}"
        )

    example_value = unspec_var.proxy.node.meta["example_value"]
    assert is_fake(example_value)

    fake_tensor_value = example_value
    # Both halves of the message must be f-strings: adjacent literals are
    # concatenated, but the ``f`` prefix applies to each literal separately.
    assert fake_tensor_value.fake_mode is self.tx.fake_mode, (
        f"fake mode ({fake_tensor_value.fake_mode}) from fake tensor metadata doesn't match mode "
        f"({self.tx.fake_mode}) from InstructionTranslator"
    )

    # There's something a bit incoherent about pass_arg_as_tensor,
    # specifically regarding sources.
    #
    # Specifically, suppose we have "x: float" local argument.  We
    # eventually end up with an UnspecializedPythonVariable denoting
    # torch.as_tensor(x)... but its source is still L['x'] (which if you
    # accessed it directly is a float!)  So you gotta be careful when
    # setting up your guards, because it's still going to be a float at
    # this point, the conversion happens only precisely at the point we're
    # actually calling the FX graph.  This happens to be what we want for
    # shape guard generation, but it's kind of unintuitive.
    proxy.node.meta["grapharg"] = GraphArg(
        self.get_source(),
        wrapped_value,
        pass_arg_as_tensor=True,
        fake_tensor=fake_tensor_value,
        is_tensor=False,
        example_strong_ref=wrapped_value,
    )

    # Directly do item to bypass capture_scalar_outputs
    r = wrap_fx_proxy(
        self.tx,
        self.tx.output.create_proxy(
            "call_method",
            "item",
            *proxy_args_kwargs([unspec_var], {}),
        ),
    )
    self.tx.output.tracked_fakes.append(TrackedFake(r.sym_num, self.source, None))

    get_metrics_context().set("tensorify_float_attempt", True, overwrite=True)

    return r
def wrap_unspecialized_primitive(self, value):
    """Lift a Python primitive into the graph as a tensor-valued input.

    If this name was already wrapped, the cached entry from
    ``unspec_variable_map`` is returned.  Otherwise ``value`` is converted
    with ``torch.tensor``, fakeified, registered as a graph input and wrapped
    as an ``UnspecializedPythonVariable`` that keeps the raw Python value.
    For non-constant sources a ``GraphArg`` with ``pass_arg_as_tensor=True``
    is attached to the placeholder node.

    Args:
        value: the Python primitive being wrapped.

    Raises:
        AssertionError: when exporting and the (non-constant) source is not a
            ``LocalSource``, since this would add an extra graph input.
    """
    if self.name in self.tx.output.unspec_variable_map:
        return self.tx.output.unspec_variable_map[self.name]

    wrapped_value = torch.tensor(value)
    if not isinstance(self.get_source(), RandomValueSource):
        install_guard(self.get_source().make_guard(GuardBuilder.TYPE_MATCH))

    options = {"source": self.get_source(), "raw_value": value}

    example_value = wrap_to_fake_tensor_and_record(
        wrapped_value, tx=self.tx, is_tensor=False, source=self.get_source()
    )
    proxy = self.tx.output.root_tracer.create_graph_input(
        re.sub(r"[^a-zA-Z0-9]+", "_", self.name),
        type(wrapped_value),
        example_value,
        source=self.get_source(),
    )
    cache_real_value_when_export(self.tx, proxy, wrapped_value)

    unspec_var = wrap_fx_proxy_cls(
        UnspecializedPythonVariable,
        tx=self.tx,
        proxy=proxy,
        example_value=example_value,
        **options,
    )
    self.tx.output.unspec_variable_map[self.name] = unspec_var

    if not is_constant_source(self.get_source()):
        if self.tx.export and not isinstance(self.get_source(), LocalSource):
            raise AssertionError(
                f"Dynamo attempts to add additional input during export: value={wrapped_value}, source={self.get_source()}"
            )
        fake_tensor_value = None
        if isinstance(unspec_var, ConstantVariable):
            # TODO: when can this happen?
            example_value = unspec_var.value
        else:
            example_value = unspec_var.proxy.node.meta["example_value"]
        assert is_fake(example_value)

        fake_tensor_value = example_value
        # Both halves of the message must be f-strings: adjacent literals
        # are concatenated, but the ``f`` prefix applies to each separately.
        assert fake_tensor_value.fake_mode is self.tx.fake_mode, (
            f"fake mode ({fake_tensor_value.fake_mode}) from fake tensor metadata doesn't match mode "
            f"({self.tx.fake_mode}) from InstructionTranslator"
        )

        proxy.node.meta["grapharg"] = GraphArg(
            self.get_source(),
            wrapped_value,
            pass_arg_as_tensor=True,
            fake_tensor=fake_tensor_value,
            is_tensor=False,
            example_strong_ref=wrapped_value,
        )
    return unspec_var
def _dataclasses_fields_lambda(obj):
    """Build a ``TupleVariable`` holding one variable per dataclass field.

    ``obj`` must be a ``UserDefinedObjectVariable``; anything else is reported
    through ``unimplemented_v2``.  When ``obj`` has a source, each field
    variable is sourced from ``obj.__dataclass_fields__[field.name]``.
    """
    if not isinstance(obj, UserDefinedObjectVariable):
        unimplemented_v2(
            gb_type="dataclass fields failure",
            context=f"obj: {obj}; variable type: {type(obj)}",
            explanation=f"Dataclass fields handling fails for {obj}. Expected it to be a user-defined object.",
            hints=[],
        )
    dataclass_instance = obj.value

    wrapped_fields = []
    for fld in dataclasses.fields(dataclass_instance):
        if obj.source:
            # Fields live in the ``__dataclass_fields__`` dict, keyed by name.
            fld_source = DictGetItemSource(
                AttrSource(obj.source, "__dataclass_fields__"), fld.name
            )
        else:
            fld_source = None
        wrapped_fields.append(UserDefinedObjectVariable(fld, source=fld_source))
    return TupleVariable(wrapped_fields)
- def _clone_input(value, fake_mode):
- if isinstance(value, torch.Tensor):
- # tensor subclasses will not be converted to FakeTensors and need to be cloned
- if not (
- isinstance(value, FakeTensor)
- or (
- # Is functional tensor fakeified by this instance of Dynamo
- torch._is_functional_tensor(value)
- and maybe_get_fake_mode(value) is fake_mode
- )
- or value.is_nested
- ):
- # NB: ensure strides are preserved
- value = clone_input(value)
- return value
def wrap_fx_proxy(
    tx, proxy, example_value=None, subclass_type=None, **options
) -> VariableTracker:
    """Wrap ``proxy`` as a tensor variable via ``wrap_fx_proxy_cls``.

    Without a ``subclass_type`` the target class is ``TensorVariable``.  With
    one, the target class is ``TensorWithTFOverrideVariable`` and
    ``install_global(tx)`` is called on the result before it is returned.
    """
    has_subclass = subclass_type is not None
    target_cls = TensorWithTFOverrideVariable if has_subclass else TensorVariable

    wrapped = wrap_fx_proxy_cls(
        target_cls=target_cls,
        tx=tx,
        proxy=proxy,
        example_value=example_value,
        subclass_type=subclass_type,
        **options,
    )
    if has_subclass:
        wrapped.install_global(tx)
    return wrapped
def cache_real_value_when_export(tx, proxy, example_value):
    """When exporting, store a clone of ``example_value`` for ``proxy.node``.

    Does nothing unless ``tx.export`` is truthy.  The clone is stored in
    ``proxy.tracer.real_value_cache`` keyed by the proxy's node.
    """
    if not tx.export:
        return

    # The legacy behavior for real value cache with subclasses was
    # to perform a clone WITHOUT preserving the subclass.  It's
    # not entirely clear this is what you actually want though.
    with torch._C.DisableTorchFunctionSubclass():
        real_copy = _clone_input(example_value, tx.fake_mode)
        proxy.tracer.real_value_cache[proxy.node] = real_copy
- # Note: Unfortunate split due to some gross classes existing that subclass TensorVariable
- # Should be compositional instead
- #
- # This is a horribly complicated function that does too many things, to
- # explain what it does, let's first talk about the classic usage wrap_fx_proxy
- # for a TensorVariable. There are two primary modes of use:
- #
- # 1. Wrapping a pre-existing Tensor. In this case, example_value is set
- # to the pre-existing Tensor. (Note that this example_value will NOT
- # be the final example_value we put into node.meta['example_value'],
- # instead it is converted into a fake tensor using
- # wrap_to_fake_tensor_and_record and registered as a graph input.)
- #
- # 2. "Wrapping" the result of some Tensor operation Dynamo traced over. In
- # this case, example_value is None (and we are going to figure it out
- # ourselves using FakeTensors, via get_fake_value, which will run
- # the operation represented by the (singular!) FX node referenced by
- # the passed in proxy.)
- #
- # The expectation is you end up with a Tensor output, and everything is
- # straightforwardly traced into the graph.
- #
- # In all cases, the returned `TensorVariable` subclass will have an `example_value`
- # and that `example_value` must be a `FakeTensor` produced by the currently running
- # instance of Dynamo.
- #
- # Upon closer inspection, you may notice that there are a slurry of non-Tensor
- # output cases in handle_traced_output. What gives? Well, we sometimes trace operations into the
- # graph that don't involve tensors.
- #
- # * Some operators return tuples; we need to recursively handle their
- # contents
- #
- # * Some operators have side effects that will affect subsequent AOTAutograd
- # tracing but don't otherwise return anything.
- #
- # * Some operators return symbolic ints/floats/bools which can go in the
- # graph and be traced (but only if they're actually symbolic! If they're
- # static you don't want to put them in the graph, which means you
- # shouldn't call this function.)
- #
- # The common theme is that you only use this function WHEN YOU ARE TRACING
- # SOMETHING INTO THE GRAPH. This is sort of obvious, because you can't call
- # this function without a proxy.
def wrap_fx_proxy_cls(
    target_cls, tx, proxy, example_value=None, subclass_type=None, **options
):
    """Dispatch proxy wrapping based on what ``example_value`` is.

    * ``None``: the value is computed from the traced node (``_wrap_fx_proxy``).
    * a ``torch.Tensor``: a pre-existing tensor is being wrapped
      (``_wrap_fx_preexisting_tensor``).
    * anything else: handed directly to ``handle_traced_output``.
    """
    if example_value is None:
        return _wrap_fx_proxy(target_cls, tx, proxy, None, subclass_type, **options)

    if isinstance(example_value, torch.Tensor):
        return _wrap_fx_preexisting_tensor(
            target_cls, tx, proxy, example_value, subclass_type, **options
        )

    # This will skip tracing an op and recursively reinvoke wrap_fx_proxy_cls on supported
    # data structures. In essence this just handles tracing some other value which may
    # contain Fake Tensors or is otherwise proxyable.
    return handle_traced_output(
        example_value, tx, proxy, options, subclass_type, target_cls
    )
- # This is 1 above (wrapping a preexisting tensor)
def _wrap_fx_preexisting_tensor(
    target_cls, tx, proxy, tensor, subclass_type=None, **options
):
    """Wrap an already-existing tensor as a ``target_cls`` variable.

    If ``tensor`` is not yet a fake tensor of ``tx.fake_mode`` it is fakeified
    with ``wrap_to_fake_tensor_and_record`` (which requires a non-None
    ``options["source"]``), after the real value has been cached for export.

    Args:
        target_cls: variable class to construct.
        tx: the active ``InstructionTranslatorBase``.
        proxy: FX proxy for the tensor.
        tensor: the tensor to wrap.
        subclass_type: forwarded to ``construct_tensor_variable``.
        **options: extra constructor options; ``guards`` (if present and not
            None) are merged into ``tx.output.guards``.

    Raises:
        InternalTorchDynamoError: if, after processing, ``tensor`` is neither
            on the ``meta`` device nor a fake tensor of ``tx.fake_mode``.
    """
    from ..symbolic_convert import InstructionTranslatorBase

    assert isinstance(tensor, torch.Tensor), (
        f"_wrap_fx_preexisting_tensor expected tensor, got {type(tensor)}"
    )

    assert isinstance(tx, InstructionTranslatorBase)
    if "guards" in options and options["guards"] is not None:
        tx.output.guards.update(options["guards"])

    # Placeholders always carry example_value in node.meta.
    # non-placeholders always have no example_value in node.meta
    if proxy.node.op == "placeholder":
        assert "example_value" in proxy.node.meta, (
            f"placeholder {proxy} doesn't have 'example_value' in node.meta"
        )
    else:
        assert "example_value" not in proxy.node.meta, (
            f"{proxy.node.meta['example_value']}"
        )

    # See NOTE: [Deferring tensor pack/unpack hooks until runtime]
    with torch._dynamo.utils._disable_saved_tensors_hooks_during_tracing():
        # Handle recursive calls here: a tensor that is already fake under
        # this tx needs no further processing.
        if maybe_get_fake_mode(tensor) is not tx.fake_mode:
            # cache_real_value_when_export already performs the export-only
            # clone into proxy.tracer.real_value_cache; the block used to
            # repeat that logic inline right after this call, cloning the
            # tensor a second time for no effect.
            cache_real_value_when_export(tx, proxy, tensor)

            # NB: If we're ignoring subclass, then the expectation is you will
            # take the returned TensorVariable and wrap it into a more
            # accurate TensorVariable that is able to track subclass-ness;
            # otherwise this is wrong!
            kwargs = {
                "is_tensor": target_cls
                in (TensorVariable, TensorWithTFOverrideVariable),
            }
            assert "source" in options and options["source"] is not None
            kwargs["source"] = options["source"]

            tensor = wrap_to_fake_tensor_and_record(tensor, tx=tx, **kwargs)

        if tensor.device.type != "meta" and (
            maybe_get_fake_mode(tensor) is not tx.fake_mode
        ):
            raise InternalTorchDynamoError(
                "`tensor` needs to be a `FakeTensor` "
                f"wrapped by this instance of Dynamo. Found: {tensor}"
            )

    return construct_tensor_variable(
        target_cls, tx, proxy, tensor, subclass_type, options
    )
- # This is 2 in the above comment (wrapping the output of a traced op)
def _wrap_fx_proxy(
    target_cls, tx, proxy, example_value=None, subclass_type=None, **options
):
    """Wrap the output of a traced op whose example value is not yet known.

    The example value is obtained with ``get_fake_value`` on ``proxy.node``
    and then converted by ``handle_traced_output``.  ``guards`` in
    ``options`` (if present and not None) are merged into ``tx.output.guards``.
    The node must not already carry an ``example_value`` in its meta.
    """
    from ..symbolic_convert import InstructionTranslatorBase

    assert isinstance(tx, InstructionTranslatorBase)

    extra_guards = options.get("guards")
    if extra_guards is not None:
        tx.output.guards.update(extra_guards)

    assert "example_value" not in proxy.node.meta, f"{proxy.node.meta['example_value']}"

    # See NOTE: [Deferring tensor pack/unpack hooks until runtime]
    with torch._dynamo.utils._disable_saved_tensors_hooks_during_tracing():
        # only allow_non_graph_fake in this instance because we handle the non-fake
        # cases properly below.
        fake_value = get_fake_value(proxy.node, tx, allow_non_graph_fake=True)

    return handle_traced_output(
        fake_value, tx, proxy, options, subclass_type, target_cls
    )
- # This handles wrapping of the output of an op traced into the graph
def handle_traced_output(example_value, tx, proxy, options, subclass_type, target_cls):
    """Turn the example value of a traced FX node into a VariableTracker.

    The function is one long dispatch on the type of ``example_value`` and,
    for several branches, on ``proxy.node.target`` / ``proxy.node.op``.
    Branch order matters: the first matching branch wins.

    Args:
        example_value: value representing the node's output.
        tx: the active translator.
        proxy: FX proxy of the node that produced ``example_value``.
        options: dict of keyword options forwarded to the created variable.
        subclass_type: forwarded to ``construct_tensor_variable``.
        target_cls: tensor variable class used for tensor outputs (also used
            recursively for elements of tuple/list outputs).

    Returns:
        The variable built for the output.  If no branch matches, the case is
        reported through ``unimplemented_v2``.
    """
    import torch._functorch.vmap
    import torch._subclasses.fake_tensor
    import torch._utils

    if isinstance(example_value, torch.Tensor):
        var = construct_tensor_variable(
            target_cls, tx, proxy, example_value, subclass_type, options
        )
        # NOTE: [Side effect tracking for newly constructed tensor]
        # For newly constructed objects that have mutable attributes, we usually
        # construct their VariableTracker via `track_object_new`, but since
        # tensor variable construction is a bit different, we handle them
        # specially here. This ensures that codegen will actually generate the
        # attribute mutations on this tensor.
        #
        # NOTE we pass a dummy object as the `item` argument to avoid
        # constructing a dummy _tensor_ object. The object isn't used for
        # newly constructed VTs anyways.
        tx.output.side_effects._track_obj(
            proxy, var, mutation_type_cls=AttributeMutationNew
        )
        return var
    # `and` binds tighter than `or`: this is
    # (has __name__ and is Generator.set_state) or (target is set_rng_state).
    elif (
        hasattr(proxy.node.target, "__name__")
        and proxy.node.target.__name__ == "set_state"
        and isinstance(proxy.node.target.__self__, torch._C.Generator)
        or proxy.node.target == torch.random.set_rng_state
    ):
        # RNG-state setters: the result is the in-graph function itself.
        return TorchInGraphFunctionVariable(proxy.node.target)
    elif (
        proxy.node.target == torch._C._DisableFuncTorch
        or proxy.node.target == torch.cuda._is_in_bad_fork
    ):
        return UserDefinedObjectVariable(example_value)
    # A torch.Size made only of plain ints becomes a constant SizeVariable
    # (no proxy is attached to it).
    elif istype(example_value, torch.Size) and all(
        isinstance(x, int) for x in example_value
    ):
        sizes = [ConstantVariable.create(x) for x in example_value]
        return SizeVariable(sizes, **options)
    elif isinstance(example_value, (tuple, list)):
        # Sequence outputs: wrap each element via a getitem proxy, then pick
        # the container variable matching the concrete sequence type.
        set_example_value(proxy.node, example_value)
        unpacked = []
        for i, val in enumerate(example_value):
            if val is None:
                # nn.MultiheadAttention() can return None, see issue #175
                unpacked.append(
                    ConstantVariable.create(None, **options),
                )
            else:
                proxy_i = proxy.tracer.create_proxy(
                    kind="call_function",
                    target=operator.getitem,
                    args=(proxy, i),
                    kwargs={},
                )

                if "source" in options:
                    # This path should only trigger for list stealing, so it's
                    # safe to use `GetItemSource`.
                    assert isinstance(example_value, list)
                    source = options["source"]
                    # Copy so the per-element source does not leak into the
                    # parent's options.
                    options_i = options.copy()
                    options_i["source"] = GetItemSource(
                        base=source, index=i, index_is_slice=False
                    )
                else:
                    # use the same options object as parent
                    options_i = options

                # WARNING: this assumes the same target_cls as this tuple/list call
                unpacked.append(
                    wrap_fx_proxy_cls(
                        target_cls=target_cls,
                        tx=tx,
                        proxy=proxy_i,
                        example_value=val,
                        **options_i,
                    )
                )
        if isinstance(example_value, torch.Size):
            # NB: Keep the old proxy around.  See SizeVariable for an
            # explanation why
            return SizeVariable(unpacked, proxy, **options)
        elif istype(example_value, tuple):
            return TupleVariable(unpacked, **options)
        elif istype(example_value, (list, immutable_list)):
            return ListVariable(unpacked, **options)
        else:
            # Remaining sequence types must be torch.return_types structs or
            # named tuples (detected by the `_fields` attribute).
            assert (
                example_value.__class__.__module__ == "torch.return_types"
                or hasattr(example_value, "_fields")
            ), (
                f"expected {example_value.__class__.__module__} == torch.return_types or named tuple but got {type(example_value)}"
            )
            return NamedTupleVariable(unpacked, example_value.__class__, **options)
    elif example_value is None or proxy.node.target is torch.manual_seed:
        # torch.manual_seed's return value is deliberately replaced by None.
        return ConstantVariable.create(None, **options)
    elif isinstance(example_value, (torch.SymInt, torch.SymFloat, torch.SymBool)):
        tx.output.current_tracer.track_produced_symints(example_value, proxy)
        set_example_value(proxy.node, example_value)
        return SymNodeVariable(proxy, example_value, **options)
    # Stream construction, or a registered device interface's current_stream.
    elif (
        inspect.isclass(proxy.node.target)
        and issubclass(proxy.node.target, torch.Stream)
    ) or proxy.node.target in [
        device_interface.current_stream
        for _, device_interface in get_registered_device_interfaces()
    ]:
        set_example_value(proxy.node, example_value)
        return StreamVariable(proxy, example_value, example_value.device, **options)
    # Event construction, or a registered device interface's Event class.
    elif (
        inspect.isclass(proxy.node.target)
        and issubclass(proxy.node.target, torch.Event)
    ) or proxy.node.target in [
        device_interface.Event
        for _, device_interface in get_registered_device_interfaces()
    ]:
        set_example_value(proxy.node, example_value)
        return EventVariable(proxy, example_value, **options)
    elif proxy.node.target == "query" and proxy.node.op == "call_method":
        set_example_value(proxy.node, example_value)
        return ConstantVariable(example_value, **options)
    elif (
        example_value is not None
        and isinstance(example_value, torch.Event)
        and proxy.node.target == "record_event"
        and proxy.node.op == "call_method"
    ):
        set_example_value(proxy.node, example_value)
        return EventVariable(proxy, example_value, **options)
    # Plain ints are only accepted from an allowlist of targets (or from a
    # `bit_length` method call); they become constants.
    elif isinstance(example_value, int) and (
        proxy.node.target
        in [
            torch.sym_int,
            getattr,
            operator.getitem,
            torch._utils._element_size,
            torch.seed,
            operator.mod,
            torch._functorch.vmap._validate_and_get_batch_size,
            torch._functorch.predispatch._vmap_increment_nesting,
            torch._functorch.predispatch._vmap_decrement_nesting,
            # some mac builds are missing torch.distributed.get_rank()
            getattr(torch.distributed, "get_rank", _missing),
            getattr(torch.distributed, "get_world_size", _missing),
            # This always wants to be in the graph, even if the constraint
            # results in a constant int
            torch._constrain_as_size,
        ]
        or (
            # TODO: this is a little sus, because we didn't check what the self is
            proxy.node.op == "call_method" and proxy.node.target in ["bit_length"]
        )
    ):
        set_example_value(proxy.node, example_value)
        return ConstantVariable.create(example_value, **options)
    elif isinstance(example_value, torch.backends.cuda.SDPAParams):
        from .sdpa import SDPAParamsVariable

        set_example_value(proxy.node, example_value)
        return SDPAParamsVariable(proxy, **options)
    # Plain bools are only accepted from an allowlist of targets plus the
    # keys of supported_const_comparison_op_values.
    elif isinstance(example_value, bool) and (
        proxy.node.target
        in [
            torch._C._are_functorch_transforms_active,
            torch._C._functorch.is_batchedtensor,
            torch.backends.cuda.is_flash_attention_available,
            torch.backends.cuda.can_use_flash_attention,
            torch.backends.cuda.can_use_efficient_attention,
            "is_integer",
        ]
        + list(supported_const_comparison_op_values.keys())
    ):
        set_example_value(proxy.node, example_value)
        return ConstantVariable.create(example_value, **options)
    # Python scalars returned by torchbind calls, flat_apply, or `.item()`.
    elif isinstance(example_value, (int, float, bool)) and (
        proxy.node.target is call_torchbind
        or proxy.node.target is flat_apply
        or (proxy.node.op == "call_method" and proxy.node.target == "item")
    ):
        set_example_value(proxy.node, example_value)
        return ConstantVariable.create(example_value, **options)
    # Any float, or the result of a "hex" / "__round__" target.
    elif isinstance(example_value, float) or proxy.node.target in ["hex", "__round__"]:
        set_example_value(proxy.node, example_value)
        return ConstantVariable.create(example_value, **options)
    else:
        # Nothing matched: report the unsupported non-Tensor output.
        unimplemented_v2(
            gb_type="torch.* op returned non-Tensor",
            context=f"example_value type: {typestr(example_value)}; op: {proxy.node.op}; target: {proxy.node.target}",
            explanation="torch.* ops that return a non-Tensor cannot be traced into the Dynamo FX graph output",
            hints=[],
        )
- def infer_subclass_type(value):
- if type(value) in (
- torch.Tensor,
- torch.nn.Parameter,
- torch._subclasses.fake_tensor.FakeTensor,
- torch._subclasses.functional_tensor.FunctionalTensor,
- ) or is_traceable_wrapper_subclass(value):
- # Ordinarily, we would fakeify a tensor so that it can get dynamic
- # shapes and be computed on without triggering actual operations.
- # However, how can we fakeify a tensor subclass? Ordinary
- # inheritance (nor multiple inheritance) won't work work.
- #
- # Instead, our plan is to *manually simulate* the tensor subclass
- # inheriting from a fake tensor with dynamo. This means our
- # data representation for a tensor subclass will be a fake tensor
- # + tensor subclass type + any extra data the subclass may have
- # been storing on the tensor. Because all Python accesses are
- # mediated through TensorWithTFOverrideVariable, we can ensure
- # that we dispatch differently, e.g., according to
- # __torch_function__
- #
- # To simplify things for now, the __dict__ tracking bits haven't
- # been implemented yet, but they can be added into this design at
- # a later point in time.
- return None
- else:
- return type(value)
- def get_specialized_props(target_cls, tx, example_value, subclass_type):
- specialized_props = target_cls.specialize(example_value)
- # TODO: not sure about this fake mode test
- if (
- isinstance(example_value, torch._subclasses.fake_tensor.FakeTensor)
- and example_value.fake_mode is tx.fake_mode
- ):
- if subclass_type:
- tensor_type = subclass_type
- elif isinstance(example_value, torch.nn.Parameter):
- tensor_type = torch.nn.Parameter
- elif isinstance(example_value, torch.nn.Buffer):
- tensor_type = torch.nn.Buffer
- else:
- tensor_type = torch.Tensor
- specialized_props["class_type"] = tensor_type
- return specialized_props
def construct_tensor_variable(
    target_cls, tx, proxy, example_value, subclass_type, options
):
    """
    Build the final ``target_cls`` variable for a tensor once the caller has
    finished pre-processing a pre-existing or newly created tensor value.

    The (possibly cloned) example value is stored on ``proxy.node``, and
    ``options`` is updated in place with the specialized properties before
    being passed to ``target_cls``.
    """
    # NB: In most (all?) cases, this does not actually do a clone.
    # (WARNING: this means that if we mutate metadata on the fake
    # tensor, the stored example value will update too!)
    stored_example = _clone_input(example_value, tx.fake_mode)
    set_example_value(proxy.node, stored_example)

    # We bind the unbacked symints in sizes/strides of tensor lazily.
    # So that subgraphs can access the unbacked symbol's proxy in parent graph
    # when lifting unbacked symbols of input tensors to subgraph inputs.
    # We do it lazily because the tensor may not be used in subgraphs.
    is_placeholder = proxy.node.op == "placeholder"
    if not is_placeholder:
        tx.output.current_tracer.track_produced_symints(stored_example, proxy)

    specialized = get_specialized_props(target_cls, tx, stored_example, subclass_type)
    options.update(specialized)
    return target_cls(proxy, **options)
- def get_automatic_dynamic_shapes_mark_as():
- if config.automatic_dynamic_shapes_mark_as == "dynamic":
- return DimDynamic.DYNAMIC
- elif config.automatic_dynamic_shapes_mark_as == "unbacked":
- return DimDynamic.SIZE_LIKE_UNBACKED
- elif config.automatic_dynamic_shapes_mark_as == "oblivious":
- return DimDynamic.OBLIVIOUS_SIZE
- else:
- raise ValueError(
- f"invalid automatic_dynamic_shapes_mark_as = {config.automatic_dynamic_shapes_mark_as}"
- )
- _DYNAMIC_SOURCES: Optional[set[str]] = None
- _DYNAMIC_SOURCES_CONFIG_HASH: Optional[int] = None
- def get_dynamic_sources() -> set[str]:
- global _DYNAMIC_SOURCES, _DYNAMIC_SOURCES_CONFIG_HASH
- current_hash = hash(torch.compiler.config.dynamic_sources)
- # If we have already calculated the sources and the config hasn't changed, return cached result
- if _DYNAMIC_SOURCES is not None and _DYNAMIC_SOURCES_CONFIG_HASH == current_hash:
- return _DYNAMIC_SOURCES
- # Config has changed or first time, (re)calculate the sources
- _DYNAMIC_SOURCES = {
- s
- for s in torch.compiler.config.dynamic_sources.replace(" ", "").split(",")
- if s
- }
- _DYNAMIC_SOURCES_CONFIG_HASH = current_hash
- return _DYNAMIC_SOURCES
def is_dynamic_source(source_name: str) -> bool:
    """Report whether ``source_name`` matches the dynamic-source allowlist.

    An allowlist entry matches if it equals ``source_name`` or if it matches
    the start of ``source_name`` when used as a regex (``re.match``).  The
    first match is logged at debug level.
    """
    for pattern in get_dynamic_sources():
        # Exact equality is tested first, so the regex is not evaluated for
        # entries that are literally the source name.
        if pattern != source_name and not re.match(pattern, source_name):
            continue
        log.debug(
            "%s was marked dynamic due to dynamic source allowlist pattern: %s",
            source_name,
            pattern,
        )
        return True
    return False
def record_automatic_dynamic(
    tx: "InstructionTranslator", name: str, e: torch.Tensor
) -> FrameStateSizeEntry:
    """Record the size/stride of ``e`` under ``name`` for automatic dynamic.

    Builds a tensor ``FrameStateSizeEntry`` from ``e``'s sizes and inferred
    strides and passes it to ``process_automatic_dynamic``, whose result is
    returned.  For sparse tensors the stride tuple is empty.
    """
    # This mimics stride inference algorithm in _create_symbolic_sizes_strides_storage_offset
    ex_size = e.size()
    if is_sparse_any(e):
        stride = []
    else:
        ex_stride = e.stride()
        ndim = e.dim()
        stride = [None] * ndim
        # Visit dims in sorted stride order; the negated index is the
        # secondary component of each sort item.
        ordered = sorted(
            ((ex_stride[d], -d) for d in range(ndim)), key=_nested_int_aware_sort
        )
        inferred = {}
        for dim_stride, neg_d in ordered:
            d = -neg_d
            # Use the InferStride marker when this stride equals
            # size * stride of an already-visited dim; else keep the value.
            stride[d] = inferred.get(dim_stride, dim_stride)
            inferred.setdefault(dim_stride * ex_size[d], InferStride(d))

    entry = FrameStateSizeEntry.make_tensor(tuple(ex_size), tuple(stride))
    return process_automatic_dynamic(tx, name, entry)
- _UNBACKED_SOURCES: Optional[set[str]] = None
- _UNBACKED_SOURCES_CONFIG_HASH: Optional[int] = None
- def get_unbacked_sources() -> set[str]:
- global _UNBACKED_SOURCES, _UNBACKED_SOURCES_CONFIG_HASH
- current_hash = hash(torch.compiler.config.unbacked_sources)
- # If we have already calculated the sources and the config hasn't changed, return cached result
- if _UNBACKED_SOURCES is not None and _UNBACKED_SOURCES_CONFIG_HASH == current_hash:
- return _UNBACKED_SOURCES
- # Config has changed or first time, (re)calculate the sources
- _UNBACKED_SOURCES = {
- s
- for s in torch.compiler.config.unbacked_sources.replace(" ", "").split(",")
- if s
- }
- _UNBACKED_SOURCES_CONFIG_HASH = current_hash
- return _UNBACKED_SOURCES
def is_unbacked_source(source_name: str) -> bool:
    """Report whether ``source_name`` matches the unbacked-source allowlist.

    An allowlist entry matches if it equals ``source_name`` or if it matches
    the start of ``source_name`` when used as a regex (``re.match``).  The
    first match is logged at debug level.
    """
    for pattern in get_unbacked_sources():
        # Exact equality is tested first, so the regex is not evaluated for
        # entries that are literally the source name.
        if pattern != source_name and not re.match(pattern, source_name):
            continue
        log.debug(
            "%s was marked unbacked due to unbacked source allowlist pattern: %s",
            source_name,
            pattern,
        )
        return True
    return False
- # Performs automatic dynamic dim determination.
- # Returns a SymbolicContext
- def _automatic_dynamic(
- e, tx, source, static_shapes, outer_only=False
- ) -> SymbolicContext:
- # strided NT not supported
- if e.is_nested and not isinstance(
- e, torch.nested._internal.nested_tensor.NestedTensor
- ):
- unimplemented_v2(
- gb_type="Encountered strided NestedTensor in automatic dynamic dim determination",
- context="",
- explanation="torch.compile does not support strided NestedTensor",
- hints=[],
- )
- name = source.name()
- prior_policy = tx.output.tracing_context.tensor_to_context.get(e, None)
- shape_env_to_source_to_symbol_cache = (
- prior_policy.shape_env_to_source_to_symbol_cache if prior_policy else None
- )
- # Get base context if the tensor is a view
- view_base_context: Optional[SymbolicContext] = None
- if e._is_view():
- base_source = AttrSource(source, "_base")
- view_base_context = _automatic_dynamic(e._base, tx, base_source, static_shapes)
- if is_traceable_wrapper_subclass(e) and not outer_only:
- # Get symbolic context for outer tensor
- outer_context = _automatic_dynamic(
- e, tx, source, static_shapes, outer_only=True
- )
- # Get symbolic contexts for inner tensors
- inner_contexts = {} # mapping from attr -> symbolic context
- attrs, _ = type(e).__tensor_flatten__(e)
- for attr in attrs:
- inner_tensor = getattr(e, attr)
- inner_source = AttrSource(source, attr)
- inner_contexts[attr] = _automatic_dynamic(
- inner_tensor, tx, inner_source, static_shapes
- )
- return SubclassSymbolicContext(
- dynamic_sizes=outer_context.dynamic_sizes,
- dynamic_strides=outer_context.dynamic_strides,
- constraint_sizes=outer_context.constraint_sizes,
- constraint_strides=outer_context.constraint_strides,
- view_base_context=view_base_context,
- tensor_source=outer_context.tensor_source,
- shape_env_to_source_to_symbol_cache=outer_context.shape_env_to_source_to_symbol_cache,
- inner_contexts=inner_contexts,
- )
- if static_shapes and not is_dynamic_source(name):
- return StatefulSymbolicContext(
- dynamic_sizes=[DimDynamic.STATIC] * e.dim(),
- dynamic_strides=[DimDynamic.INFER_STRIDE] * e.dim(),
- constraint_sizes=[None] * e.dim(),
- constraint_strides=[None] * e.dim(),
- view_base_context=view_base_context,
- tensor_source=source,
- shape_env_to_source_to_symbol_cache=shape_env_to_source_to_symbol_cache,
- )
- # We preserve the dynamism of inputs. For example, when users call
- # make_fx(torch.cond, tracing_mode="symbolic")(*args), inputs have SymInt sizes.
- from torch.fx.experimental.symbolic_shapes import is_nested_int
- if any(isinstance(s, SymInt) and not is_nested_int(s) for s in e.size()):
- return StatefulSymbolicContext(
- dynamic_sizes=[
- DimDynamic.DYNAMIC if isinstance(s, SymInt) else DimDynamic.STATIC
- for s in e.size()
- ],
- dynamic_strides=[DimDynamic.INFER_STRIDE] * e.dim(),
- constraint_sizes=[None] * e.dim(),
- constraint_strides=[None] * e.dim(),
- view_base_context=view_base_context,
- tensor_source=source,
- shape_env_to_source_to_symbol_cache=shape_env_to_source_to_symbol_cache,
- )
- # Prep for automatic dynamic
- frame_state_entry = record_automatic_dynamic(tx, name, e)
- # TODO: index export_constraints ahead of time so we don't have to
- # do a linear scan every time here
- t_id = id(e)
- dim2constraint = {}
- def update_dim2constraint(dim, constraint_range, name):
- if dim in dim2constraint:
- from torch.fx.experimental.symbolic_shapes import StrictMinMaxConstraint
- old_constraint_range, old_name = dim2constraint[dim]
- new_constraint_range = StrictMinMaxConstraint(
- vr=constraint_range.vr & old_constraint_range.vr,
- warn_only=False,
- )
- # It is possible for (non-None) old_name and name to be different,
- # but this will only happen when the corresponding Dims can be derived equal.
- new_name = old_name or name
- dim2constraint[dim] = new_constraint_range, new_name
- else:
- dim2constraint[dim] = constraint_range, name
- from torch.export.dynamic_shapes import _RelaxedConstraint
- if tx.output.export_constraints:
- for constraint in tx.output.export_constraints:
- if isinstance(constraint, _RelaxedConstraint):
- continue
- if constraint.t_id == t_id:
- update_dim2constraint(
- constraint.dim, constraint.constraint_range, constraint.name
- )
- dynamic_sizes = []
- dynamic_strides = []
- constraint_sizes = []
- constraint_strides = []
- specialize_on = []
- for i in range(e.dim()):
- # NB: mark dynamic has precedence over static
- marked_strict_unbacked = i in getattr(
- e, "_dynamo_strict_unbacked_indices", set()
- )
- marked_unbacked = i in getattr(e, "_dynamo_unbacked_indices", set())
- marked_dynamic = i in getattr(e, "_dynamo_dynamic_indices", set())
- marked_weak_dynamic = i in getattr(e, "_dynamo_weak_dynamic_indices", set())
- marked_static = i in getattr(e, "_dynamo_static_indices", set())
- specialize_on.append(getattr(e, "_specialize_on", {}).get(i, []))
- # Reflect the user directive in the frame_state
- # For dynamic, apply None always
- normalized_source_name = normalize_source_name(source.name())
- base_source = source
- if isinstance(base_source, ChainedSource):
- base_source = base_source.get_base()
- if marked_dynamic or (
- isinstance(base_source, LocalSource)
- and base_source.dynamism is not None
- and dict(base_source.dynamism).get(normalized_source_name, {i: False})[i]
- ):
- # TODO: This can be batched
- # TODO: Doing this here is kind of sus, maybe better to set this
- # up when we initially created the FrameStateSizeEntry to bong
- # into the mutable state
- log.debug("automatic dynamic %s marked dynamic", name)
- mark_size = [auto_unset] * e.dim()
- mark_size[i] = auto_dynamic
- frame_state_entry |= FrameStateSizeEntry.make_size(size=mark_size)
- # NB: both static and dynamic marks have precedence over automatic dynamic
- automatic_dynamic_size = (
- config.automatic_dynamic_shapes and frame_state_entry.is_size_dynamic(i)
- )
- # NB: previously, if size was dynamic, we wouldn't make its stride
- # dynamic. But now, because of InferStride concept, we will properly
- # not make stride dynamic even if it's wobbling
- automatic_dynamic_stride = (
- config.automatic_dynamic_shapes and frame_state_entry.is_stride_dynamic(i)
- )
- if is_dynamic_source(name):
- log.debug("%s marked dynamic via source whitelist", name)
- automatic_dynamic_size = True
- if is_unbacked_source(name):
- log.debug("%s marked unbacked via source whitelist", name)
- automatic_dynamic_size = True
- automatic_dynamic = automatic_dynamic_size or automatic_dynamic_stride
- # We will process constraints first, as they will imply that we
- # have a dynamic dimension
- # Precedence: export constraints > eager constraints
- constraint = dim2constraint.get(i)
- if constraint is None:
- constraint_size = None
- constraint_stride = None
- if marked_dynamic and not config.allow_ignore_mark_dynamic:
- # constraint_stride is deliberately kept None because there is no easy way to provide value ranges for mark dynamic
- constraint_stride = None
- if hasattr(e, "_dynamo_dynamic_range"):
- dim_range = [
- dr for dr in e._dynamo_dynamic_range if dr.dim == i
- ].pop()
- if dim_range.min is None and dim_range.max is None:
- constraint_size = RelaxedUnspecConstraint(warn_only=False)
- else:
- from torch.fx.experimental.symbolic_shapes import (
- StrictMinMaxConstraint,
- )
- constraint_size = StrictMinMaxConstraint(
- vr=ValueRanges(lower=dim_range.min, upper=dim_range.max),
- warn_only=False,
- )
- else:
- constraint_size = RelaxedUnspecConstraint(warn_only=False)
- elif marked_strict_unbacked:
- constraint_size = RelaxedUnspecConstraint(warn_only=False)
- elif not marked_static and automatic_dynamic:
- set_feature_use("dynamo.automatic_dynamic_shapes", True)
- if automatic_dynamic_size:
- constraint_size = RelaxedUnspecConstraint(warn_only=True)
- if automatic_dynamic_stride:
- constraint_stride = RelaxedUnspecConstraint(warn_only=True)
- else:
- if not marked_static and not config.automatic_dynamic_shapes:
- set_feature_use("dynamo.automatic_dynamic_shapes", False)
- constraint_size = None
- constraint_stride = None
- else:
- constraint_size, name_ = constraint
- constraint_stride = None
- dim_name = f"{name}.size()[{i}]"
- tx.output.shape_env.source_name_to_debug_name[dim_name] = name_
- constraint_sizes.append(constraint_size)
- constraint_strides.append(constraint_stride)
- if marked_unbacked or is_unbacked_source(name):
- dynamic_size = DimDynamic.SIZE_LIKE_UNBACKED
- elif (
- constraint_size is not None
- or marked_dynamic
- or marked_weak_dynamic
- or is_nested_int(e.size()[i])
- ):
- # NB: We could assert static_shapes is False here, but it
- # seems better to allow the user to override symbolic_context in this
- # case
- if automatic_dynamic:
- dynamic_size = get_automatic_dynamic_shapes_mark_as()
- else:
- dynamic_size = DimDynamic.DYNAMIC
- elif static_shapes or config.assume_static_by_default or marked_static:
- dynamic_size = DimDynamic.STATIC
- else:
- # TODO: When does this show up?
- dynamic_size = DimDynamic.DUCK
- if constraint_stride is not None:
- dynamic_stride = DimDynamic.DYNAMIC
- else:
- dynamic_stride = DimDynamic.INFER_STRIDE
- dynamic_sizes.append(dynamic_size)
- dynamic_strides.append(dynamic_stride)
- return StatefulSymbolicContext(
- dynamic_sizes=dynamic_sizes,
- dynamic_strides=dynamic_strides,
- constraint_sizes=constraint_sizes,
- constraint_strides=constraint_strides,
- specialize_on=specialize_on,
- view_base_context=view_base_context,
- tensor_source=source,
- shape_env_to_source_to_symbol_cache=shape_env_to_source_to_symbol_cache,
- )
- # See note [Tensor Fakification and Symbol Caching]
- def wrap_to_fake_tensor_and_record(
- e, tx, *, source: Optional[Source], is_tensor: bool, parent_context=None
- ):
- if (
- type(e) in (torch.Tensor, torch.nn.Parameter, FakeTensor)
- or isinstance(e, torch.Tensor)
- or is_traceable_wrapper_subclass(e)
- ):
- assert source is not None
- static_shapes, _reason = tensor_always_has_static_shape(
- e,
- is_tensor,
- tensor_source=source,
- )
- if not parent_context:
- symbolic_context = _automatic_dynamic(e, tx, source, static_shapes)
- else:
- # Parent contexts are passed in when we are recursively creating
- # fake tensors for subclasses. A better design would be not to create a
- # parent/child relationship, but to recursively call _automatic_dynamic
- # as we recursively call wrap_to_fake_tensor_and_record. This runs
- # into bugs around how meta_utils knows and works to create fake tensors
- # with tensor subclasses. Ideally, dynamo would drive both the recursive
- # wrap_to_fake_tensor_and_record and _automatic_dynamic policy creation.
- assert isinstance(source, AttrSource)
- inner_context_name = source.member
- symbolic_context = parent_context.inner_contexts[inner_context_name]
- log.debug(
- "wrap_to_fake %s %s %s %s",
- source.name(),
- tuple(e.shape),
- symbolic_context,
- type(e),
- )
- # Note [enable_python_dispatcher in dynamo]
- # Dynamo disables itself when it runs fake tensor prop, which means that tensor subclasses
- # have no way to know (purely based off of global state) if they are currently being run under compile or not.
- # we use enable_python_dispatcher mainly to tweak the DispatchKeyState so that subclass authors
- # can check it to know if they are running in an eager context or not
- with enable_python_dispatcher():
- fake_e = wrap_fake_exception(
- lambda: tx.fake_mode.from_tensor(
- e,
- source=source,
- symbolic_context=symbolic_context,
- )
- )
- if (
- source is not None
- and isinstance(fake_e, FakeTensor)
- and (sym_val := fake_e.item_memo) is not None
- ):
- tx.output.tracked_fakes.append(
- TrackedFake(sym_val, CallMethodItemSource(source), symbolic_context)
- )
- if is_traceable_wrapper_subclass(fake_e):
- attrs, _ = fake_e.__tensor_flatten__()
- for attr in attrs:
- fake_inner = getattr(fake_e, attr)
- inner = getattr(e, attr)
- inner_source = AttrSource(source, attr)
- wrap_to_fake_tensor_and_record(
- inner,
- tx,
- source=inner_source,
- is_tensor=isinstance(fake_inner, torch.Tensor),
- parent_context=symbolic_context,
- )
- tx.output.tracing_context.tensor_to_context[e] = symbolic_context
- if is_sparse_any(fake_e):
- # TODO: for TensorGuards, this eventually may need more
- # fields for the size/stride of any other constituents
- values = fake_e._values() if fake_e.is_sparse else fake_e.values()
- tx.output.input_source_to_sizes_strides[source] = {
- "size": fake_e.size(),
- # TODO: revise this, but for now this stride instead of ()
- # avoids SegFault with PYTORCH_TEST_WITH_DYNAMO=1
- "stride": (1,) * fake_e.ndim,
- "values_size": values.size(),
- "values_stride": values.stride(),
- }
- else:
- tx.output.input_source_to_sizes_strides[source] = {
- "size": fake_e.size(),
- "stride": fake_e.stride(),
- }
- if (
- is_tensor
- and not (static_shapes and source.is_specialized_nn_module())
- and not is_constant_source(source)
- ):
- tx.output.tracked_fakes.append(
- TrackedFake(fake_e, source, symbolic_context)
- )
- tx.output.tracked_fakes_id_to_source[id(e)].append(source)
- return fake_e
- else:
- return e
class SourcelessBuilder:
    """
    Source-free, stateless counterpart to the regular builder.

    Intended for simple type -> VariableTracker conversions and for objects that
    are created and discarded while inlining (for instance a list of tensors
    built locally and then iterated over). Such a value should not surface as
    an artifact from inputs, in reconstruction, or in the graph, even though
    there may be reasons to model it internally as e.g. a ListVariable.

    NOTE - Objects produced here are born UNGUARDED due to the nature of sources!

    NOTE - This class is comparatively new and may have rough edges; it was
    created to stem the spread of giant if/else type -> VariableTracker trees
    across dynamo.
    """

    def __init__(self) -> None:
        # Namespace-only class: all functionality is exposed as static methods.
        raise AssertionError("Use SourcelessBuilder.create()")

    @staticmethod
    def create(tx: "InstructionTranslator", value) -> VariableTracker:
        """Wrap ``value`` in a VariableTracker without attaching a source.

        The exact type of ``value`` is first looked up in
        ``SourcelessBuilder._type_handlers``. If no handler is registered, a
        fixed sequence of predicate checks is tried in order. Values that match
        nothing are reported through ``unimplemented_v2``.
        """
        value_type = type(value)
        handler = SourcelessBuilder._type_handlers.get(value_type)
        if handler:
            return handler(tx, value)

        if isinstance(value, VariableTracker):
            # This is always valid to call, and useful for recursive calls.
            return value
        if isinstance(value, dataclasses._HAS_DEFAULT_FACTORY_CLASS):
            return UserDefinedObjectVariable(value)
        if ConstantVariable.is_literal(value):
            return ConstantVariable.create(value)
        if callable(value) and trace_rules.lookup_callable(value) is not None:
            if trace_rules.is_callable_allowed(value):
                tx.output.has_user_defined_allowed_in_graph = True
            return trace_rules.lookup_callable(value)(value)
        if callable(value) and UserDefinedClassVariable.is_supported_new_method(
            value
        ):
            # NamedTuple._make uses an alias of tuple.__new__
            owner = trace_rules.lookup_callable(value.__self__)(value.__self__)
            return GetAttrVariable(owner, "__new__")
        if is_function_or_wrapper(value):
            return trace_rules.lookup(value)(value)
        if isinstance(
            value, (enum.Enum, torch.DispatchKey, torch._C._functorch.TransformType)
        ):
            return EnumVariable(value)
        if isinstance(value, (type, abc.ABCMeta)):
            return UserDefinedClassVariable(value)
        if isinstance(value, types.MethodWrapperType):
            return MethodWrapperVariable(value)

        # From here on the checks stay a single if/elif chain: when the
        # classmethod branch hits NotImplementedError it must skip every
        # remaining check and go straight to the unsupported-type report.
        if (
            isinstance(value, types.MethodType)
            # We only want to support sourceless class objects here
            # An instance variable is not allowed and it should have source
            and isinstance(value.__self__, (type, abc.ABCMeta))
        ):
            # value is a classmethod
            assert getattr(value.__self__, value.__func__.__name__) == value
            owner_vt = SourcelessBuilder.create(tx, value.__self__)
            try:
                return owner_vt.var_getattr(tx, value.__func__.__name__)
            except NotImplementedError:
                pass  # fall through to the unimplemented report below
        elif isinstance(value, torch.fx.graph_module.GraphModule):
            return SourcelessGraphModuleVariable(value)
        elif isinstance(
            value, (torch.utils._pytree.TreeSpec, torch.utils._pytree.LeafSpec)
        ):
            return UserDefinedObjectVariable(value)
        elif PlacementVariable.is_placement(value):
            return PlacementVariable(value)
        elif DeviceMeshVariable.is_device_mesh(value):
            return DeviceMeshVariable(value)
        elif value is functools.wraps:
            return FunctoolsWrapsVariable(value)
        elif isinstance(value, re.Pattern):
            return RegexPatternVariable(value)
        elif isinstance(value, torch._dynamo.variables.lazy.LazySymNodeFormatString):
            return ConstantVariable.create(str(value))
        elif isinstance(value, type(torch._higher_order_ops.flex_attention_backward)):
            return torch._dynamo.variables.higher_order_ops.FlexAttentionBackwardHighOrderVariable(
                value
            )
        elif isinstance(value, types.GenericAlias):
            return TypingVariable(value)
        elif is_namedtuple(value):
            field_vts = [
                SourcelessBuilder.create(tx, getattr(value, field))
                for field in namedtuple_fields(type(value))
            ]
            return NamedTupleVariable(field_vts, tuple_cls=type(value))
        elif (
            isinstance(value, torch.SymInt)
            and value.node.expr in tx.output.bound_symbols
        ):
            bound_proxy = tx.output.bound_symbols[value.node.expr]
            return SymNodeVariable.create(tx, bound_proxy)

        unimplemented_v2(
            gb_type="Unexpected type in sourceless builder",
            context=f"{value_type.__module__}.{value_type.__qualname__}",
            explanation=f"SourcelessBuilder.create does not know how to wrap {value_type}",
            hints=[*graph_break_hints.DYNAMO_BUG],
        )

    @staticmethod
    def wrap_constant_literal(value):
        """Wrap ``value`` as a ConstantVariable; asserts that it is a literal."""
        assert ConstantVariable.is_literal(value)
        return ConstantVariable.create(value=value)

    @staticmethod
    def make_type_handlers():
        """Build the exact-type -> handler table used as the fast path in ``create``.

        Each handler takes ``(tx, value)`` and returns the wrapped value.
        Container handlers wrap their elements through ``SourcelessBuilder.create``.
        Every class in ``VariableTrackerMeta.all_subclasses`` maps to a handler
        that returns the value unchanged.
        """
        create = SourcelessBuilder.create

        def wrap_constant(tx, value):
            return ConstantVariable(value)

        def wrap_set(tx, value):
            return SetVariable(
                [create(tx, item) for item in value],
                mutation_type=ValueMutationNew(),
            )

        def wrap_dict(tx, value):
            return ConstDictVariable(
                {create(tx, key): create(tx, val) for key, val in value.items()},
                type(value),
                mutation_type=ValueMutationNew(),
            )

        def wrap_list(tx, value):
            return ListVariable(
                [create(tx, item) for item in value],
                mutation_type=ValueMutationNew(),
            )

        def wrap_tuple(tx, value):
            return TupleVariable([create(tx, item) for item in value])

        def wrap_size(tx, value):
            return SizeVariable([create(tx, item) for item in value])

        def wrap_random(tx, value):
            return RandomClassVariable()

        def wrap_module(tx, value):
            return PythonModuleVariable(value)

        def wrap_dispatch_key_set(tx, value):
            return DispatchKeySetVariable(value, mutation_type=ValueMutationNew())

        def wrap_functorch_interpreter(tx, value):
            return FuncTorchInterpreterVariable(
                value, mutation_type=ValueMutationNew()
            )

        def wrap_new_user_object(tx, value):
            return UserDefinedObjectVariable(value, mutation_type=ValueMutationNew())

        def passthrough(tx: "InstructionTranslator", value):
            return value

        # Constant types are registered first so that the more specific
        # entries below override them if a type appears in both.
        handlers = dict.fromkeys(common_constant_types, wrap_constant)

        handlers[set] = wrap_set
        handlers[dict] = wrap_dict
        handlers[list] = wrap_list
        handlers[tuple] = wrap_tuple
        handlers[torch.Size] = wrap_size
        handlers[collections.OrderedDict] = handlers[dict]
        handlers[immutable_dict] = handlers[dict]
        handlers[immutable_list] = handlers[list]
        handlers[random.Random] = wrap_random
        handlers[types.ModuleType] = wrap_module
        handlers[torch.DispatchKeySet] = wrap_dispatch_key_set
        handlers[torch._functorch.pyfunctorch.FuncTorchInterpreter] = (
            wrap_functorch_interpreter
        )
        for constraint_cls in (
            torch.distributions.constraints._Real,
            torch.distributions.constraints._Interval,
            torch.distributions.constraints.Constraint,
        ):
            handlers[constraint_cls] = wrap_new_user_object

        for vt_cls in VariableTrackerMeta.all_subclasses:
            handlers[vt_cls] = passthrough
        return handlers
# Build the exact-type dispatch table once the class is fully defined;
# SourcelessBuilder.create() consults it before any isinstance-based checks.
- SourcelessBuilder._type_handlers = SourcelessBuilder.make_type_handlers()
class SourcelessUserDefinedObjectBuilder:
    """
    Sourceless builder for user-defined objects.

    SourcelessBuilder does not return a UserDefinedObjectVariable, but in some
    cases it might be ok to return UserDefinedObjects. In such cases, use this
    builder instead.
    """

    def __init__(self) -> None:
        # Namespace-only class: use the static ``create`` entry point.
        raise AssertionError("Use SourcelessUserDefinedObjectBuilder.create()")

    @staticmethod
    def create(tx: "InstructionTranslator", value) -> VariableTracker:
        """Wrap ``value`` in a freshly-created (``ValueMutationNew``) tracker.

        The tracker class is chosen from the value's type: MutableMapping
        subclasses, then ``torch.nn.Module`` instances, then a generic
        user-defined object as the fallback.
        """
        if issubclass(type(value), MutableMapping):
            tracker_cls = MutableMappingVariable
        elif isinstance(value, torch.nn.Module):
            tracker_cls = UnspecializedNNModuleVariable
        else:
            tracker_cls = UserDefinedObjectVariable
        return tracker_cls(value, mutation_type=ValueMutationNew())
|