| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693694695696697698699700701702
70370470570670770870971071171271371471571671771871972072172272372472572672772872973073173273373473573673773873974074174274374474574674774874975075175275375475575675775875976076176276376476576676776876977077177277377477577677777877978078178278378478578678778878979079179279379479579679779879980080180280380480580680780880981081181281381481581681781881982082182282382482582682782882983083183283383483583683783883984084184284384484584684784884985085185285385485585685785885986086186286386486586686786886987087187287387487587687787887988088188288388488588688788888989089189289389489589689789889990090190290390490590690790890991091191291391491591691791891992092192292392492592692792892993093193293393493593693793893994094194294394494594694794894995095195295395495595695795895996096196296396496596696796896997097197297397497597697797897998098198298398498598698798898999099199299399499599699799899910001001100210031004100510061007100810091010101110121013101410151016101710181019102010211022102310241025102610271028102910301031103210331034103510361037103810391040104110421043104410451046104710481049105010511052105310541055105610571058105910601061106210631064106510661067106810691070107110721073107410751076107710781079108010811082108310841085108610871088108910901091109210931094109510961097109810991100110111021103110411051106110711081109111011111112111311141115111611171118111911201121112211231124112511261127112811291130113111321133113411351136113711381139114011411142114311441145114611471148114911501151115211531154115511561157115811591160116111621163116411651166116711681169117011711172117311741175117611771178117911801181118211831184118511861187118811891190119111921193119411951196119711981199120012011202120312041205120612071208120912101211121212131214121512161217121812191220122112221223122412251226122712281229123012311232123312341235123612371238123912401241124212431244124512461247124812491250125112521253125412551256125712581259126012611262126312641265126612671268126912701271127212731274127512761
27712781279128012811282128312841285128612871288128912901291129212931294129512961297129812991300130113021303130413051306130713081309131013111312131313141315131613171318131913201321132213231324132513261327132813291330133113321333133413351336133713381339134013411342134313441345134613471348134913501351135213531354135513561357135813591360136113621363136413651366136713681369137013711372137313741375137613771378137913801381138213831384138513861387138813891390139113921393139413951396139713981399140014011402140314041405140614071408140914101411141214131414141514161417141814191420142114221423142414251426142714281429143014311432143314341435143614371438143914401441144214431444144514461447144814491450145114521453145414551456145714581459146014611462146314641465146614671468146914701471147214731474147514761477147814791480148114821483148414851486148714881489149014911492149314941495149614971498149915001501150215031504150515061507150815091510151115121513151415151516151715181519152015211522152315241525152615271528152915301531153215331534153515361537153815391540154115421543154415451546154715481549155015511552155315541555155615571558155915601561156215631564156515661567156815691570157115721573157415751576157715781579158015811582158315841585158615871588158915901591159215931594159515961597159815991600160116021603160416051606160716081609161016111612161316141615161616171618161916201621162216231624162516261627162816291630163116321633163416351636163716381639164016411642164316441645164616471648164916501651165216531654165516561657165816591660166116621663166416651666166716681669167016711672167316741675167616771678167916801681168216831684168516861687168816891690169116921693169416951696169716981699170017011702170317041705170617071708170917101711171217131714171517161717171817191720172117221723172417251726172717281729173017311732173317341735173617371738173917401741174217431744174517461747174817491750175117521753175417551756175717581759176017611762176317641765176617671768176917701771177217731774177517761
77717781779178017811782178317841785178617871788178917901791179217931794179517961797179817991800180118021803180418051806180718081809181018111812181318141815181618171818181918201821182218231824182518261827182818291830183118321833183418351836183718381839184018411842184318441845184618471848184918501851185218531854185518561857185818591860186118621863186418651866186718681869187018711872187318741875187618771878187918801881188218831884188518861887188818891890189118921893189418951896189718981899190019011902190319041905190619071908190919101911191219131914191519161917191819191920192119221923192419251926192719281929193019311932193319341935193619371938193919401941194219431944194519461947194819491950195119521953195419551956195719581959196019611962196319641965196619671968196919701971197219731974197519761977197819791980198119821983198419851986198719881989199019911992199319941995199619971998199920002001200220032004200520062007200820092010201120122013201420152016201720182019202020212022202320242025202620272028202920302031203220332034203520362037203820392040204120422043204420452046204720482049205020512052205320542055205620572058205920602061206220632064206520662067206820692070207120722073207420752076207720782079208020812082208320842085208620872088208920902091209220932094209520962097209820992100210121022103210421052106210721082109211021112112211321142115211621172118211921202121212221232124212521262127212821292130213121322133213421352136213721382139214021412142214321442145214621472148214921502151215221532154215521562157215821592160216121622163216421652166216721682169217021712172217321742175217621772178217921802181218221832184218521862187218821892190219121922193219421952196219721982199220022012202220322042205220622072208220922102211221222132214221522162217221822192220222122222223222422252226222722282229223022312232223322342235223622372238223922402241224222432244224522462247224822492250225122522253225422552256225722582259226022612262226322642265226622672268226922702271227222732274227522762
27722782279228022812282228322842285228622872288228922902291229222932294229522962297229822992300230123022303230423052306230723082309231023112312231323142315231623172318231923202321232223232324232523262327232823292330233123322333233423352336233723382339234023412342234323442345234623472348234923502351235223532354235523562357235823592360236123622363236423652366236723682369237023712372237323742375237623772378237923802381238223832384238523862387238823892390239123922393239423952396239723982399240024012402240324042405240624072408240924102411241224132414241524162417241824192420242124222423242424252426242724282429243024312432243324342435243624372438243924402441244224432444244524462447244824492450245124522453245424552456245724582459246024612462246324642465246624672468246924702471247224732474247524762477247824792480248124822483248424852486248724882489249024912492249324942495249624972498249925002501250225032504250525062507250825092510251125122513251425152516251725182519252025212522252325242525252625272528252925302531253225332534253525362537253825392540254125422543254425452546254725482549255025512552255325542555255625572558255925602561256225632564256525662567256825692570257125722573257425752576257725782579258025812582258325842585258625872588258925902591259225932594259525962597259825992600260126022603260426052606260726082609261026112612261326142615261626172618261926202621262226232624262526262627262826292630263126322633263426352636263726382639264026412642264326442645264626472648264926502651265226532654265526562657265826592660266126622663266426652666266726682669267026712672267326742675267626772678267926802681268226832684268526862687268826892690269126922693269426952696269726982699270027012702270327042705270627072708270927102711271227132714271527162717271827192720272127222723272427252726272727282729273027312732273327342735273627372738273927402741274227432744274527462747274827492750275127522753275427552756275727582759276027612762276327642765276627672768276927702771277227732774277527762
77727782779278027812782278327842785278627872788278927902791279227932794279527962797279827992800280128022803280428052806280728082809281028112812281328142815281628172818281928202821282228232824282528262827282828292830283128322833283428352836283728382839284028412842284328442845284628472848284928502851285228532854285528562857285828592860286128622863286428652866286728682869287028712872287328742875287628772878287928802881288228832884288528862887288828892890289128922893289428952896289728982899290029012902290329042905290629072908290929102911291229132914291529162917291829192920292129222923292429252926292729282929293029312932293329342935293629372938293929402941294229432944294529462947294829492950295129522953295429552956295729582959296029612962296329642965296629672968296929702971297229732974297529762977297829792980298129822983298429852986298729882989299029912992299329942995299629972998299930003001300230033004300530063007300830093010301130123013301430153016301730183019302030213022302330243025302630273028302930303031303230333034303530363037303830393040304130423043304430453046304730483049305030513052305330543055305630573058305930603061306230633064306530663067306830693070307130723073307430753076307730783079308030813082308330843085308630873088308930903091309230933094309530963097309830993100310131023103310431053106310731083109311031113112311331143115311631173118311931203121312231233124312531263127312831293130313131323133313431353136313731383139314031413142314331443145314631473148314931503151315231533154315531563157315831593160316131623163316431653166316731683169317031713172317331743175317631773178317931803181318231833184318531863187318831893190319131923193319431953196319731983199320032013202320332043205320632073208320932103211321232133214321532163217321832193220322132223223322432253226322732283229323032313232323332343235323632373238323932403241324232433244324532463247324832493250325132523253325432553256325732583259326032613262326332643265326632673268326932703271327232733274327532763
277327832793280328132823283328432853286328732883289329032913292329332943295329632973298329933003301330233033304330533063307330833093310331133123313331433153316331733183319332033213322332333243325332633273328332933303331333233333334333533363337333833393340334133423343334433453346334733483349335033513352335333543355335633573358335933603361336233633364336533663367336833693370337133723373337433753376337733783379338033813382338333843385338633873388338933903391339233933394339533963397339833993400340134023403340434053406340734083409341034113412341334143415341634173418341934203421342234233424342534263427342834293430343134323433343434353436343734383439344034413442344334443445344634473448344934503451345234533454345534563457345834593460346134623463346434653466346734683469347034713472347334743475347634773478347934803481348234833484348534863487348834893490349134923493349434953496349734983499350035013502350335043505350635073508350935103511351235133514351535163517351835193520352135223523352435253526352735283529353035313532353335343535353635373538353935403541354235433544354535463547354835493550355135523553355435553556355735583559356035613562 |
- """
- Core graph building functionality for PyTorch's Dynamo system. This module contains
- the essential components for constructing and managing FX graphs during compilation:
- - OutputGraph: Manages the overall graph construction and compilation process. It owns
- a SubgraphTracer and handles graph compilation, execution, and state management.
- OutputGraph also manages features like graph deduplication, symbolic shape handling,
- and tracking of side effects.
- - SubgraphTracer: Handles the actual FX graph construction by tracing Python code.
- It supports advanced features like higher-order operators through nested tracers,
- lifting of free variables, and handling of symbolic shapes.
- The module supports key Dynamo features including:
- - Higher-order operators through nested SubgraphTracers
- - Graph deduplication for optimization
- - Symbolic shape handling and propagation
- - Side effect tracking and management
- - Guard insertion and management
- """
- import collections
- import contextlib
- import copy
- import functools
- import inspect
- import itertools
- import logging
- import operator
- import re
- import sys
- import traceback
- import warnings
- import weakref
- from collections.abc import Generator, Sequence
- from dataclasses import dataclass, field as dc_field
- from types import CodeType
- from typing import Any, Callable, cast, Optional, TYPE_CHECKING, Union
- from typing_extensions import ParamSpec, TypeVar
- import sympy
- import torch._guards
- import torch._logging
- import torch.distributed as dist
- import torch.nn
- import torch.utils._pytree as pytree
- from torch import fx, Tensor
- from torch._C._dynamo import guards
- from torch._dynamo.exc import ShortenTraceback, TensorifyScalarRestartAnalysis
- from torch._guards import (
- CompileContext,
- CompileId,
- GlobalContextCheckpointState,
- Source,
- tracing,
- TracingContext,
- )
- from torch._subclasses.fake_tensor import FakeTensor
- from torch._utils_internal import signpost_event
- from torch.export.dynamic_shapes import _ConstraintTarget
- from torch.fx._lazy_graph_module import _make_graph_module # type: ignore[attr-defined]
- from torch.fx.experimental._backward_state import BackwardState
- from torch.fx.experimental.symbolic_shapes import (
- free_symbols,
- guard_scalar,
- is_symbolic,
- ShapeEnv,
- Specialization,
- )
- from torch.fx.node import Target
- from torch.fx.passes.runtime_assert import insert_deferred_runtime_asserts
- from torch.multiprocessing.reductions import StorageWeakRef
- from torch.utils._ordered_set import OrderedSet
- from torch.utils._python_dispatch import is_traceable_wrapper_subclass
- from . import config, exc, logging as torchdynamo_logging, variables
- from .backends.registry import CompiledFn, CompilerFn
- from .bytecode_transformation import (
- create_binary_slice,
- create_call_function,
- create_dup_top,
- create_instruction,
- create_load_const,
- create_rot_n,
- create_swap,
- Instruction,
- unique_id,
- )
- from .code_context import code_context
- from .codegen import PyCodegen
- from .current_scope_id import enter_new_scope
- from .device_interface import get_interface_for_device
- from .exc import (
- BackendCompilerFailed,
- exceptions_allowed_to_be_fallback,
- SkipFrame,
- unimplemented_v2,
- unimplemented_v2_with_warning,
- )
- from .graph_deduplication import apply_graph_deduplication
- from .graph_region_tracker import GraphRegionTracker
- from .guards import GuardBuilder, install_guard
- from .mutation_guard import is_dynamic_nn_module
- from .side_effects import AttributeMutationExisting, SideEffects, ValueMutationExisting
- from .source import (
- _get_source_debug_name,
- AttrSource,
- BackwardStateSource,
- ConstantSource,
- GetItemSource,
- GlobalStateSource,
- is_constant_source,
- is_from_local_source,
- LocalSource,
- NumpyTensorSource,
- ParamBufferSource,
- ShapeEnvSource,
- SyntheticLocalSource,
- TensorProperty,
- TensorPropertySource,
- )
- from .utils import (
- _extract_tensor_dict,
- checkpoint_params,
- CleanupHook,
- clone_inputs,
- count_calls,
- counters,
- dynamo_timed,
- get_instruction_source_311,
- get_locals_to_steal,
- get_static_address_type,
- get_unique_name_wrt,
- graph_break_reasons,
- increment_op_count,
- istype,
- lazy_format_graph_code,
- LazyString,
- nn_module_proxy,
- same,
- set_example_value,
- )
- from .variables.base import VariableTracker
- from .variables.builder import (
- BackwardStateGraphArg,
- GraphArg,
- TrackedFake,
- wrap_fx_proxy,
- )
- from .variables.ctx_manager import ContextWrappingVariable
- from .variables.lists import BaseListVariable
- from .variables.misc import NullVariable
- from .variables.nn_module import NNModuleVariable
- from .variables.tensor import (
- NumpyNdarrayVariable,
- SymNodeVariable,
- TensorVariable,
- UnspecializedPythonVariable,
- )
- from .variables.torch_function import TensorWithTFOverrideVariable
- from .variables.user_defined import UserDefinedDictVariable
- if TYPE_CHECKING:
- from torch._dynamo.package import CompilePackage
- from torch._dynamo.symbolic_convert import InstructionTranslatorBase
# Standard module-level logger for this file.
log = logging.getLogger(__name__)
# Artifact loggers registered under this module's name; the second argument
# is the artifact name each one is bound to.
graph_tabular_log = torch._logging.getArtifactLogger(__name__, "graph")
graph_code_log = torch._logging.getArtifactLogger(__name__, "graph_code")
graph_sizes_log = torch._logging.getArtifactLogger(__name__, "graph_sizes")
trace_call_log = torch._logging.getArtifactLogger(__name__, "trace_call")
# Short alias for torch._C._dynamo.guards.RootGuardManager.
RootGuardManager = guards.RootGuardManager
@dataclass(frozen=True)
class VariableTrackerCacheKey:
    """Immutable, hashable key used by VariableTrackerCache.

    VariableTrackerCache builds it from ``id(value)`` and the value's source.
    """

    vt_id: int
    # Two different sources can point to the same object. However, Dynamo
    # handles global and local sources differently when it comes to guards
    # (and possibly some other parts as well), so the cache key also includes
    # the source.
    source: Source
@dataclass(frozen=True)
class AliasingInfo:
    """Immutable record pairing a ``has_aliasing`` flag with a message string."""

    has_aliasing: bool
    # NOTE(review): presumably a human-readable explanation of the flag —
    # confirm against the code that constructs this record.
    msg: str
@dataclass(frozen=True)
class MutationInfo:
    """Immutable record pairing a ``has_mutation`` flag with a message string."""

    has_mutation: bool
    # NOTE(review): presumably a human-readable explanation of the flag —
    # confirm against the code that constructs this record.
    msg: str
class VariableTrackerCache:
    """Maps ``(id(value), source)`` keys to previously built VariableTrackers."""

    def __init__(self) -> None:
        self.cache: dict[VariableTrackerCacheKey, VariableTracker] = {}

    def lookup(self, value: Any, source: Source) -> Optional[VariableTracker]:
        """Return the tracker stored for ``(value, source)``, or None if absent."""
        return self.cache.get(VariableTrackerCacheKey(id(value), source))

    def add(self, value: Any, source: Source, vt: VariableTracker) -> None:
        """Store ``vt`` under ``(value, source)``, replacing any existing entry."""
        self.cache[VariableTrackerCacheKey(id(value), source)] = vt

    def clone(self) -> "VariableTrackerCache":
        """Return a new cache with its own dict holding the same entries."""
        # Needed for copy and restore graph state
        duplicate = VariableTrackerCache()
        duplicate.cache = dict(self.cache)
        return duplicate

    def clear(self) -> None:
        """Drop every cached entry."""
        self.cache.clear()
@functools.cache
def _step_logger() -> Any:
    """Return the step logger built from this module's ``log``.

    ``functools.cache`` memoizes this zero-argument function, so the value
    produced by the first call is reused by later calls.
    """
    return torchdynamo_logging.get_step_logger(log)
@dataclass
class GraphCompileReason:
    """Records why an output graph was compiled, i.e. what caused the graph break."""

    reason: str
    user_stack: list[traceback.FrameSummary]

    # Whether this compile reason is an actual graph break.
    graph_break: bool = True

    def __post_init__(self) -> None:
        # Only reasons flagged as graph breaks are appended to the
        # module-level graph_break_reasons list.
        if not self.graph_break:
            return
        graph_break_reasons.append(self)
- def _get_gen_rand_values_fn(random_calls: Any) -> Callable[[], list[Any]]:
- def _gen_rand_values() -> list[Any]:
- return [fn(*args, **kwargs) for fn, args, kwargs in random_calls]
- return _gen_rand_values
class FakeRootModule(torch.nn.Module):
    """Trick the constructor of fx.GraphModule"""

    def __init__(self, nn_modules: dict[str, torch.nn.Module]):
        super().__init__()
        # Attach every given module as an attribute named by its dict key.
        self.add_nn_modules(nn_modules)

    def __repr__(self) -> str:
        return "FakeRootModule(...)"

    def add_nn_modules(self, nn_modules: dict[str, torch.nn.Module]) -> None:
        """Set each ``name -> module`` pair as an attribute on this module."""
        for name, submodule in nn_modules.items():
            setattr(self, name, submodule)
class WrapperBackend:
    """Callable wrapper around a backend compiler function.

    When ``config.verify_correctness`` is set, the compiled candidate's
    output on the example inputs is compared with the original graph
    module's output before the candidate is returned.
    """

    def __init__(self, backend: CompilerFn) -> None:
        self.backend: CompilerFn = backend

    def __call__(
        self, gm: torch.fx.GraphModule, example_inputs: list[torch.Tensor]
    ) -> CompiledFn:
        """Compile a deep copy of ``gm`` with the backend and return a callable.

        Returns ``gm.forward`` when the backend yields None (or that same
        forward object), otherwise the backend's candidate.

        Raises:
            RuntimeError: if verification is enabled and the candidate's
                result differs from the original's according to ``same``.
        """
        # NOTE(review): presumably a callback that restores gm's parameters —
        # confirm in checkpoint_params. It is only invoked on the
        # verification path below.
        self.restore = checkpoint_params(gm)
        self.gm = gm
        gm_copy = copy.deepcopy(gm)
        self.candidate = self.backend(gm_copy, example_inputs)
        candidate = self.candidate

        if candidate is None or candidate is self.gm.forward:
            return self.gm.forward

        if not config.verify_correctness:
            return candidate

        # if verify_correctness=True
        try:
            expected = self.gm.forward(*clone_inputs(example_inputs))
            actual = candidate(*clone_inputs(example_inputs))

            # TODO: replace `same` function with the one in testing
            if not same(expected, actual):
                raise RuntimeError(f"incorrect results of backend {self}")
            return candidate
        except Exception:
            log.exception("error in verify_correctness")
            raise
        finally:
            self.restore()
# Type alias for a name -> object mapping; used in this file as the type of
# the local_scope / global_scope fields and parameters.
Scope = dict[str, object]
@dataclass
class OutputGraphGuardsState:
    """
    A base class containing fields that are considered "persistent" when we
    want to save all the important state for reconstructing guards in a different
    process. Normally we don't need to add states here, but we may have to when
    the information is needed to serialize the guards, so the fields here are
    supposed to be serializable as a requirement.
    """

    local_scope: Scope
    global_scope: Scope
    # This records the initial torch function mode stack for guarding
    torch_function_mode_stack: list[torch.overrides.TorchFunctionMode]
    guard_on_key_order: set[Source]
    # Map from graph input's `Source` to sizes / strides metadata
    input_source_to_sizes_strides: dict[Source, dict[str, Any]]
    dual_level: int
    functorch_layers: list[torch._functorch.pyfunctorch.FuncTorchInterpreter]
    current_device: Optional[torch.device]
    global_state_guard: torch._C._dynamo.guards.GlobalStateGuard
    # Backing storage for the `guards` / `aotautograd_guards` properties below.
    _guards: torch._guards.GuardsSet
    _aotautograd_guards: list[torch._guards.GuardEnvExpr]
    # Whether or not the guards should be checked for correctness
    # NOTE(review): this comment reads as if it describes `skip_guards_check`
    # rather than `export` — confirm which field it belongs to.
    export: bool = False
    skip_guards_check: bool = False
    export_constraints: bool = False
    name_of_builtins_dict_key_in_fglobals: Optional[str] = None

    @property
    def shape_env(self) -> ShapeEnv:
        """Always raises AssertionError here; this class carries no ShapeEnv."""
        raise AssertionError(f"shape_env shouldn't be accessed from {type(self)}")

    @property
    def guards(self) -> torch._guards.GuardsSet:
        """Return the stored ``_guards`` set."""
        return self._guards

    @property
    def aotautograd_guards(self) -> list[torch._guards.GuardEnvExpr]:
        """Return the stored ``_aotautograd_guards`` list."""
        return self._aotautograd_guards
@dataclass
class StackLocalsMetadata:
    """
    Stores metadata for a frame's stack and locals for the purposes of building resume functions
    """

    # Container-valued fields use default_factory so that each instance gets
    # its own fresh dict/list rather than sharing one default object.
    num_stack: int = 0  # number of stack elements, minus removed NULLs
    locals_names: dict[str, int] = dc_field(
        default_factory=dict
    )  # order of locals codegen'd to the stack
    # NOTE(review): presumably stack indices / local names that held NULL —
    # confirm against the code that fills these in.
    stack_null_idxes: list[int] = dc_field(default_factory=list)
    locals_null_keys: list[str] = dc_field(default_factory=list)
    # NOTE(review): presumably (stack index or local name, constructor args)
    # pairs for context-manager values — confirm against the producer.
    stack_ctx_args: list[tuple[int, tuple[Any, ...]]] = dc_field(default_factory=list)
    stack_ctx_idxes_orig: list[int] = dc_field(default_factory=list)
    locals_ctx_args: list[tuple[str, tuple[Any, ...]]] = dc_field(default_factory=list)
# TODO we should expand this to make it work for arbitrary in/out
@dataclass
class ExportMetaData:
    # maps graph input index to its source, which is later
    # used in export to map to the correct user input. In its flat form,
    # it just looks like GetItem(base=LocalSource("foo", idx=0))
    graph_input_idx_to_local_source: dict[int, Source] = dc_field(default_factory=dict)
    # maps user output idx to what type of output it is. There are 3 options:
    # 1) graph out
    # 2) user input
    # 3) constants
    output_return_type: dict[int, tuple[str, Any]] = dc_field(default_factory=dict)
    # output spec of the traced function
    # (the default is the module-level _LEAF_SPEC object itself, shared by
    # every instance that does not override it)
    out_spec: Union[torch.utils._pytree.TreeSpec, torch.utils._pytree.LeafSpec] = (
        torch.utils._pytree._LEAF_SPEC
    )
def get_builtins_dict(global_scope: Scope) -> dict[str, Any]:
    """Return the builtins namespace of ``global_scope`` as a dict.

    ``f_globals["__builtins__"]`` can be a dict or a module; this is an
    implementation detail, see
    https://docs.python.org/3/library/builtins.html.
    A dict is returned unchanged; for anything else its ``__dict__`` is
    returned.

    Raises:
        KeyError: if ``global_scope`` has no ``"__builtins__"`` entry.
    """
    # Guarding on a builtin would otherwise be messy: the guard check_fn
    # would have to test whether __builtins__ is a module or a dict and then
    # use getattr or getitem respectively. To solve this, a new entry is
    # inserted into f_globals that points to the builtins __dict__, and any
    # builtin is guarded on that dict. To avoid collisions with pre-existing
    # keys, install_global is used to obtain a unique dict key.
    builtins_entry = global_scope["__builtins__"]
    if isinstance(builtins_entry, dict):
        return builtins_entry
    return builtins_entry.__dict__
- class OutputGraph(OutputGraphGuardsState):
- """
- Wrapper class to hold outputs of InstructionTranslator. Mainly the
- generated fx.Graph.
- OutputGraph is 1:1 with a frame being processed. Each frame is associated
- with some root InstructionTranslator. When user code calls a function,
- we construct a InliningInstructionTranslator that continues to write into
- the root InstructionTranslator's OutputGraph.
- """
- side_effects: SideEffects
- def __init__(
- self,
- code_options: dict[str, Any],
- compiler_fn: Optional[CompilerFn],
- root_tx: "InstructionTranslatorBase",
- export: bool,
- export_constraints: Sequence[_ConstraintTarget],
- frame_state: Any,
- local_scope: Scope,
- global_scope: Scope,
- f_code: CodeType,
- torch_function_mode_stack: list[torch.overrides.TorchFunctionMode],
- package: Optional["CompilePackage"],
- ) -> None:
- super().__init__(
- local_scope,
- global_scope,
- torch_function_mode_stack,
- guard_on_key_order=set(),
- input_source_to_sizes_strides={},
- dual_level=torch.autograd.forward_ad._current_level,
- functorch_layers=torch._functorch.pyfunctorch.retrieve_all_functorch_interpreters(),
- current_device=torch.utils._device.CURRENT_DEVICE,
- # initial_global_state is only None during NopTest.
- global_state_guard=torch._dynamo.convert_frame.initial_global_state
- or torch._C._dynamo.guards.GlobalStateGuard(),
- # These are set by @property instead, just initialize them as blank
- _guards=torch._guards.GuardsSet(),
- _aotautograd_guards=[],
- )
- self.tracers = [SubgraphTracer(self, is_export=export)]
- # Map from graph input's `Source` to its `VariableTracker` to
- # de-duplicate graph inputs by source and reuse the tracker
- self.input_source_to_var: dict[Source, VariableTracker] = {}
- self.export = export
- self.export_constraints = export_constraints # type: ignore[assignment]
- self.frame_state = frame_state
- self.cleanup_hooks: list[Callable[[], Any]] = []
- # compile_id is an id number for the current torch.compile
- self.compile_id: int = next(_compile_id_counter)
- # Set of globals installed via install_global* APIs
- self.installed_globals: set[str] = set()
- # TODO: maybe should just pass the entire f_code in here? Not
- # sure...
- self.co_fields = {
- "co_name": f_code.co_name,
- "co_filename": f_code.co_filename,
- "co_firstlineno": f_code.co_firstlineno,
- }
- self.region_tracker = GraphRegionTracker()
- # tracked_fakes says where any tensor that was wrapped to fake came
- # from. It is similar to GraphArg, in that all GraphArgs will get
- # added to TrackedFakes, but TrackedFakes also contains
- # GraphArgs that got pruned, and things like Tensor attributes which
- # aren't explicit graph inputs. Used by shape guard
- self.tracked_fakes: list[TrackedFake] = []
- shape_env = ShapeEnv(
- # Reference Cycle!
- # Share a reference to the list of TrackedFake.
- #
- # ShapeEnv needs this in order to be able to reproduce the call
- # to produce_guards at an arbitrary time point. That is because
- # TrackedFake instances may have its metadata changed throughout
- # the program execution.
- tracked_fakes=self.tracked_fakes,
- allow_scalar_outputs=config.capture_scalar_outputs,
- allow_dynamic_output_shape_ops=config.capture_dynamic_output_shape_ops,
- prefer_deferred_runtime_asserts_over_guards=config.prefer_deferred_runtime_asserts_over_guards,
- co_fields=self.co_fields,
- )
- # In export mode, we force the shape_env to strictly disallow any constraining
- # of the user marked dynamic dims
- import torch._functorch.config as _config
- with _config.patch(fake_tensor_allow_unsafe_data_ptr_access=False):
- fake_mode = torch._subclasses.FakeTensorMode(
- shape_env=shape_env,
- # TODO (tmanlaibaatar) Remove this once we always lift params and buffers
- allow_non_fake_inputs=True if self.export else False,
- export=self.export,
- )
- self.tracing_context: TracingContext = TracingContext(fake_mode)
- self.tracing_context.traced_code.append(f_code)
- self.dynamo_compile_id: Optional[CompileId] = (
- CompileContext.current_compile_id()
- )
- self.init_ambient_guards()
- # Map each tensor id to a list of sources. This is necessary because
- # tensor ids cannot be recovered from tracked fakes (in general).
- # We use this map to interpret (i.e., check for violations of) constraints,
- # specifically equality constraints, which have shared tensor ids in them.
- # This map should also be generally useful, e.g., for (de)serialization.
- self.tracked_fakes_id_to_source: dict[int, list[Source]] = (
- collections.defaultdict(list)
- )
- # Stores the full fqn of a param or buffer to the relevant source.
- self.param_name_to_source: Optional[dict[str, Source]] = {}
- self.side_effects = SideEffects(self)
- # Cached variable trackers. This makes symbolic analysis of LOAD_GLOBAL
- # and LOAD_ATTR for same python objects free.
- self.variable_tracker_cache = VariableTrackerCache()
- self.unique_var_id = itertools.count()
- self.code_options: dict[str, Any] = dict(code_options)
- self.output_instructions: list[Instruction] = []
- # used to track nodes that are added between calls of copy_graphstate
- # and restore_graphstate
- self.timestamp = 0
- # A list of register_finalizer_fns to apply to the output graph module
- self.register_finalizer_fns: list[Callable[[fx.GraphModule], None]] = []
- # Not checkpointed
- self.compiler_fn: Optional[CompilerFn] = compiler_fn
- self.root_tx = root_tx
- self.package = package
- # Given a source, what are the user stacks of all locations that
- # accessed it?
- #
- # For efficiency, we only populate this:
- # - During export, and
- # - If the source could potentially lead to a spurious export input
- #
- # Feel free to populate this more frequently if other use-cases arise,
- # but be aware that we have to generate full stacks for each
- # recording!
- self.source_to_user_stacks: dict[Source, list[traceback.StackSummary]] = {}
- self._current_tx: list[InstructionTranslatorBase] = []
- self.cleanups: list[CleanupHook] = []
- self.should_exit = False
- self.unspec_variable_map: dict[str, UnspecializedPythonVariable] = {}
- # This returns false if TF Overall (both mode and subclass) is disabled OR that TF Mode stack is empty
- self.torch_function_mode_enabled = torch._C._is_torch_function_mode_enabled()
- # Tracks if the output graph has a user defined allowed function in the
- # graph. This is used later to determine if we should fallback to eager
- # for certain exceptions. The idea is that if the user has applied
- # allow_in_graph, they would like to see the error instead of falling
- # back for backend errors.
- self.has_user_defined_allowed_in_graph = False
- # Tracks a list of called ops that were not tagged with "pt2_compliant_tag".
- # This information is useful for logging.
- self.non_compliant_ops: set[torch._ops.OpOverload] = set({})
- # Tracks a list of called custom ops that were tagged with "pt2_compliant_tag".
- # This information is useful for logging.
- self.compliant_custom_ops: set[torch._ops.OpOverload] = set({})
- # We save the global torch state here to be restored in case of graph
- # breaks. The relevant issue is seen here
- # https://github.com/pytorch/pytorch/pull/100570#issuecomment-1543427086
- # where inlining of a function changes the global state (because of the
- # presence of torch.no_grad) and there is a graph break.
- self.save_global_state()
- # Tracks the original FQNs of the constant tensors from the original graph,
- # i.e. buffers and parameters.
- self.dynamo_flat_name_to_original_fqn: dict[str, str] = {}
- # All calls to random() are replaced with a single call to __gen_rand_values
- # functions that returns a tuple of random values for each original call.
- # random_calls tracks calls to random() and random_values_var stores the name of
- # the variable that stores __gen_rand_values results.
- self.random_calls: list[
- tuple[Callable[..., object], tuple[object, ...], dict[str, object]]
- ] = []
- self.random_values_var: Any = None
- # Bytecode to insert right before we call the graph
- self.pregraph_bytecode: list[Instruction] = []
- # Use to pass values to backward hooks when using compiled autograd
- self.backward_state: dict[str, VariableTracker] = {}
- self.backward_state_proxy: Optional[torch.fx.Proxy] = None
- self.backward_state_var: Optional[str] = None
- self.name_of_builtins_dict_key_in_fglobals: str = (
- self.install_builtins_dict_in_fglobals()
- )
- self.compiler_trace_stack = contextlib.ExitStack()
- # These are the ambient, currently-global saved_tensor_hooks stashed in autograd,
- # that are set for the entire duration of the compiled region.
- # This is an invariant today because we graph break on the saved_tensor_hook
- # context manager inside a compiled region
- self.saved_tensors_hooks_subgraph_names: Optional[list[str]] = (
- self.maybe_install_saved_tensors_hooks_subgraphs()
- )
- # mangled alias -> module fqn name
- self.import_sources: dict[str, str] = {}
- self.export_metadata = ExportMetaData()
def mark_bytecode_tracing_start(self) -> None:
    """Open the ``"bytecode_tracing"`` timing region.

    The ``dynamo_timed`` context is entered on ``self.compiler_trace_stack``
    and stays open until that stack is closed (see
    ``mark_bytecode_tracing_stop``).
    """
    tracing_timer = dynamo_timed("bytecode_tracing", log_pt2_compile_event=True)
    self.compiler_trace_stack.enter_context(tracing_timer)
def mark_bytecode_tracing_stop(self) -> None:
    """Close ``self.compiler_trace_stack``, exiting every context entered on it."""
    trace_stack = self.compiler_trace_stack
    trace_stack.close()
def install_builtins_dict_in_fglobals(self) -> str:
    """Install the builtins dict derived from ``self.global_scope`` as a global.

    Returns:
        The name returned by ``self.install_global`` for the installed dict.
    """
    return self.install_global(
        "__builtins_dict__", get_builtins_dict(self.global_scope)
    )
def add_backward_state_hook(
    self, hook: VariableTracker, prefix: str = "hook"
) -> tuple[str, torch.fx.Proxy]:
    """Store ``hook`` in ``self.backward_state`` under a freshly generated name.

    The name is ``prefix`` followed by the number of hooks already stored.

    Returns:
        ``(name, proxy)`` where ``proxy`` is the result of
        ``self.get_backward_state_proxy()``.
    """
    slot = len(self.backward_state)
    name = f"{prefix}{slot}"
    assert name not in self.backward_state
    self.backward_state[name] = hook
    proxy = self.get_backward_state_proxy()
    return name, proxy
def get_backward_state_proxy(self) -> torch.fx.Proxy:
    """Return the backward-state graph-input proxy, creating it on first use.

    On first use this also allocates ``self.backward_state_var`` and attaches
    a ``BackwardStateGraphArg`` to the new placeholder's ``meta["grapharg"]``.
    """
    if self.backward_state_proxy is not None:
        return self.backward_state_proxy

    if self.export:
        unimplemented_v2(
            gb_type="backward_state does not support export",
            context="",
            explanation="Compiled autograd doesn't work with `torch.export`.",
            hints=[],
        )

    placeholder_value = BackwardState()
    self.backward_state_proxy = self.root_tracer.create_graph_input(
        "dynamo_backward_state",
        type(placeholder_value),
        placeholder_value,
        source=BackwardStateSource(),
    )
    node_meta = self.backward_state_proxy.node.meta
    node_meta["grapharg"] = BackwardStateGraphArg()
    self.backward_state_var = self.new_var()
    return self.backward_state_proxy
# This gets its own helper function so guards DEBUG logs are more informative
def init_ambient_guards(self) -> None:
    """Add the baseline ShapeEnv and global-state guards to ``self.guards``."""
    # Register a SHAPE_ENV guard to make sure we setup shape guards
    # that show up in ShapeEnv
    self.guards.add(ShapeEnvSource().make_guard(GuardBuilder.SHAPE_ENV))

    # Global-state guards that are always installed, in this order.
    unconditional_global_guards = (
        GuardBuilder.DETERMINISTIC_ALGORITHMS,
        GuardBuilder.GRAD_MODE,
        GuardBuilder.DEFAULT_DEVICE,
        GuardBuilder.TORCH_FUNCTION_STATE,
    )
    for create_fn in unconditional_global_guards:
        self.guards.add(GlobalStateSource().make_guard(create_fn))

    # Guard on the functorch stack only when peeking it yields an interpreter.
    if torch._C._functorch.peek_interpreter_stack() is not None:
        self.guards.add(
            GlobalStateSource().make_guard(GuardBuilder.FUNCTORCH_STACK_MATCH)
        )

    # The saved-tensors-hooks guard is skipped inside a compiled autograd region.
    if not torch._dynamo.compiled_autograd.in_compiled_autograd_region:
        self.guards.add(
            GlobalStateSource().make_guard(GuardBuilder.AUTOGRAD_SAVED_TENSORS_HOOKS)
        )
def maybe_install_saved_tensors_hooks_subgraphs(self) -> Optional[list[str]]:
    """Install the top saved-tensors pack/unpack hooks as subgraphs, if inlineable.

    Returns:
        ``[pack_subgraph_name, unpack_subgraph_name]`` when the hooks were
        installed; ``None`` inside a compiled autograd region or when the
        hooks are not inlineable.
    """
    if torch._dynamo.compiled_autograd.in_compiled_autograd_region:
        return None

    get_hooks = torch._functorch._aot_autograd.utils.top_saved_tensors_hooks
    are_inline_hooks = (
        torch._functorch._aot_autograd.utils.saved_tensors_hooks_are_inlineable
    )
    hooks = get_hooks()
    if not are_inline_hooks(hooks):
        return None

    # If GraphModule provided by user contains fx.wrap,
    # We can only rely on user provided cache hash in this case.
    # If user did not provide cache hash - then we always bypass cache.
    pack_gm, unpack_gm = hooks

    pack_module = torch.fx.GraphModule(self.nn_modules, pack_gm.graph)
    pack_name = self.install_subgraph("saved_tensors_hooks_pack", pack_module)

    unpack_module = torch.fx.GraphModule(self.nn_modules, unpack_gm.graph)
    unpack_name = self.install_subgraph("saved_tensors_hooks_unpack", unpack_module)

    # These exact names are asserted, so the subgraphs must be the first of
    # their kind to be installed.
    assert pack_name == "saved_tensors_hooks_pack_0"
    assert unpack_name == "saved_tensors_hooks_unpack_0"
    return [pack_name, unpack_name]
def dump_guards_state(self) -> OutputGraphGuardsState:
    """Build an ``OutputGraphGuardsState`` from this object's guard-related fields."""
    # Dump a serializable version of self without extras
    guards_state = OutputGraphGuardsState(
        # Guard collections (read through the tracing-context properties).
        _guards=self.guards,
        _aotautograd_guards=self.aotautograd_guards,
        skip_guards_check=self.skip_guards_check,
        # Frame scopes and the name of the installed builtins dict.
        local_scope=self.local_scope,
        global_scope=self.global_scope,
        name_of_builtins_dict_key_in_fglobals=self.name_of_builtins_dict_key_in_fglobals,
        # Remaining state copied verbatim from attributes of the same name.
        torch_function_mode_stack=self.torch_function_mode_stack,
        guard_on_key_order=self.guard_on_key_order,
        input_source_to_sizes_strides=self.input_source_to_sizes_strides,
        dual_level=self.dual_level,
        functorch_layers=self.functorch_layers,
        current_device=self.current_device,
        global_state_guard=self.global_state_guard,
        export=self.export,
        export_constraints=self.export_constraints,
    )
    return guards_state
def synthetic_graph_input(
    self, fn: Callable[..., Any], args: tuple[Any, ...]
) -> VariableTracker:
    """
    Call fn(*args) before the graph runs and turn the result into a fake input.

    The call is also emitted as bytecode appended to ``self.pregraph_bytecode``,
    storing its result into a newly allocated local variable.
    """
    example_value = fn(*args)
    varname = self.new_var()

    # Codegen for `varname = fn(*args)`.
    cg = PyCodegen(self.root_tx)

    def load_fn() -> Any:
        return cg.load_import_from(fn.__module__, fn.__name__)

    cg.add_push_null(load_fn)
    cg.foreach(variables.ConstantVariable.create(arg) for arg in args)
    cg.call_function(len(args), False)
    cg.store(varname)
    self.pregraph_bytecode.extend(cg.get_instructions())

    source = SyntheticLocalSource(varname)
    # Realize the VT because we will delete the guards on it in the next line.
    result = VariableTracker.build(self.root_tx, example_value, source).realize()
    dynamo_guards = TracingContext.get().guards_context.dynamo_guards
    dynamo_guards.remove_guards_with_source(source)
    return result
def add_cleanup_hook(self, fn: Callable[[], Any]) -> None:
    """Append ``fn`` to ``self.cleanup_hooks`` so ``call_cleanup_hooks`` runs it."""
    hooks = self.cleanup_hooks
    hooks.append(fn)
def call_cleanup_hooks(self) -> None:
    """Invoke the registered cleanup hooks, last-added first, then clear the list."""
    for cleanup in reversed(self.cleanup_hooks):
        cleanup()
    # Hooks are one-shot: drop them once they have all run.
    self.cleanup_hooks.clear()
@property
def root_tracer(self) -> "SubgraphTracer":
    """The first entry of ``self.tracers``."""
    tracer_stack = self.tracers
    return tracer_stack[0]
@property
def current_tracer(self) -> "SubgraphTracer":
    """The last (most recently pushed) entry of ``self.tracers``."""
    tracer_stack = self.tracers
    return tracer_stack[-1]
def is_root_tracer(self) -> bool:
    """Return True when exactly one tracer is on ``self.tracers``."""
    # Helper to tell if we are inside the higher order operator tracing.
    depth = len(self.tracers)
    return depth == 1
@property
def graph(self) -> torch.fx.Graph:
    """The FX graph of the current tracer."""
    tracer = self.current_tracer
    return tracer.graph


# TODO(rzou): can delete after we refactor speculate_subgraph to use nested GraphTracer.
@graph.setter
def graph(self, value: torch.fx.Graph) -> None:
    """Replace the current tracer's FX graph with ``value``."""
    tracer = self.current_tracer
    tracer.graph = value
@property
def input_name_to_proxy(self) -> dict[str, fx.Proxy]:
    """The current tracer's ``input_name_to_proxy`` mapping."""
    tracer = self.current_tracer
    return tracer.input_name_to_proxy
@property
def real_value_cache(self) -> dict[fx.Node, torch.Tensor]:
    """The current tracer's ``real_value_cache`` mapping."""
    tracer = self.current_tracer
    return tracer.real_value_cache
@property
def bound_symbols(self) -> dict[sympy.Symbol, Union[torch.fx.Proxy, "LazyProxy"]]:
    """The current tracer's ``bound_symbols`` mapping."""
    tracer = self.current_tracer
    return tracer.bound_symbols
# If you are here, and you're looking for create_graph_input,
# to avoid ambiguity, please call one of the following:
# - self.current_tracer.create_graph_input
# - self.root_tracer.create_graph_input
# See NOTE [HigherOrderOperator tracing design] for more context.


def create_proxy(self, *args: Any, **kwargs: Any) -> torch.fx.Proxy:
    """Forward all arguments to ``create_proxy`` on the current tracer."""
    tracer = self.current_tracer
    return tracer.create_proxy(*args, **kwargs)
def create_node(self, *args: Any, **kwargs: Any) -> torch.fx.Node:
    """Forward all arguments to ``create_node`` on the current tracer."""
    tracer = self.current_tracer
    return tracer.create_node(*args, **kwargs)
def remove_node(self, *args: Any, **kwargs: Any) -> None:
    """Forward all arguments to ``remove_node`` on the current tracer."""
    tracer = self.current_tracer
    return tracer.remove_node(*args, **kwargs)
@contextlib.contextmanager
def subtracer(
    self, source_target: Optional[Target], prior_tracer: "SubgraphTracer"
) -> Generator[fx.Tracer, None, None]:
    """Push a subgraph tracer onto ``self.tracers`` for the duration of the body.

    Args:
        source_target: Passed to the new ``SubgraphTracer`` as ``source_target``;
            not used when ``prior_tracer`` is reused.
        prior_tracer: An existing tracer to reuse. When falsy, a new
            ``SubgraphTracer`` parented to the current tracer is created.

    Yields:
        The tracer that was pushed.

    Raises:
        AssertionError: If ``prior_tracer`` is given but its parent is not the
            current tracer.
    """
    new_scope_ctx = enter_new_scope()

    # Validate before mutating anything: a failed check must leave both the
    # scope and the tracer stack untouched.
    if prior_tracer:
        # Lineage MUST stay preserved
        assert prior_tracer.parent is self.current_tracer

    new_scope_ctx.__enter__()
    pushed = False
    try:
        tracer = (
            prior_tracer
            if prior_tracer
            else SubgraphTracer(
                self,
                parent=self.current_tracer,
                source_target=source_target,
                is_export=self.current_tracer.is_export,
            )
        )
        self.tracers.append(tracer)
        pushed = True
        yield tracer
    finally:
        new_scope_ctx.__exit__(None, None, None)
        # Only undo the push if it actually happened; otherwise we would pop
        # a tracer that belongs to an enclosing scope.
        if pushed:
            self.tracers.pop()
@property
def output(self) -> "OutputGraph":
    """Return this object itself."""
    return self
@property
def fake_mode(self) -> torch._subclasses.FakeTensorMode:
    """The tracing context's fake mode; asserts that it is set."""
    mode = self.tracing_context.fake_mode
    assert mode is not None
    return mode
@property
def shape_env(self) -> ShapeEnv:
    """The ``shape_env`` of the tracing context's fake mode; asserts both are set."""
    mode = self.tracing_context.fake_mode
    assert mode is not None
    env = mode.shape_env
    assert env is not None
    return env
- @property
- def guards(self) -> torch._guards.GuardsSet:
- return self.tracing_context.guards_context.dynamo_guards
@property
def nn_modules(self) -> dict[str, Any]:
    """The ``nn_modules`` dict held by the tracing context's module context."""
    module_context = self.tracing_context.module_context
    return module_context.nn_modules
@property
def aotautograd_guards(self) -> list[torch._guards.GuardEnvExpr]:
    """The ``aotautograd_guards`` list held by the tracing context's guards context."""
    guards_context = self.tracing_context.guards_context
    return guards_context.aotautograd_guards
- def save_global_state(
- self, out: Optional[dict[str, tuple[Callable[..., Any], bool]]] = None
- ) -> None:
- """
- Saves to out if it is provided. Else saves to the tracing context's global_state.
- """
- global_state = cast(
- dict[str, tuple[Callable[..., Any], bool]],
- (
- out
- if out is not None
- else self.tracing_context.global_context.global_state
- ),
- )
- global_state["grad_enabled"] = (torch.set_grad_enabled, torch.is_grad_enabled())
- global_state["autocast_enabled"] = (
- functools.partial(torch.set_autocast_enabled, "cuda"),
- torch.is_autocast_enabled("cuda"),
- )
- global_state["autocast_cpu_enabled"] = (
- functools.partial(torch.set_autocast_enabled, "cpu"),
- torch.is_autocast_enabled("cpu"),
- )
- global_state["autocast_gpu_dtype"] = ( # type:ignore[assignment]
- functools.partial(torch.set_autocast_dtype, "cuda"),
- torch.get_autocast_dtype("cuda"),
- )
- global_state["autocast_cpu_dtype"] = ( # type:ignore[assignment]
- functools.partial(torch.set_autocast_dtype, "cpu"),
- torch.get_autocast_dtype("cpu"),
- )
- global_state["autocast_cache_enabled"] = (
- torch.set_autocast_cache_enabled,
- torch.is_autocast_cache_enabled(),
- )
def push_tx(self, tx: "InstructionTranslatorBase") -> None:
    """Push ``tx`` onto the ``self._current_tx`` stack."""
    tx_stack = self._current_tx
    tx_stack.append(tx)
def pop_tx(self) -> "InstructionTranslatorBase":
    """Pop and return the top of the ``self._current_tx`` stack."""
    tx_stack = self._current_tx
    return tx_stack.pop()
@property
def current_tx(self) -> "InstructionTranslatorBase":
    """The top of ``self._current_tx``, or ``self.root_tx`` when that stack is empty."""
    if self._current_tx:
        return self._current_tx[-1]
    return self.root_tx
def count_calls(self) -> int:
    """Return the result of the ``count_calls`` helper applied to ``self.graph``."""
    current_graph = self.graph
    return count_calls(current_graph)
def is_empty_graph(self) -> bool:
    """Return True when ``self.graph`` contains no nodes."""
    for _ in self.graph.nodes:
        # At least one node exists.
        return False
    return True
def has_outputs(self) -> bool:
    """Return True when ``self.graph`` has at least one node whose op is ``"output"``."""
    return any(node.op == "output" for node in self.graph.nodes)
def get_submodule(self, keys: str) -> Union[torch.nn.Module, Any]:
    """Resolve a dotted path starting from ``self.nn_modules``.

    Each component is looked up by key while the current object is a dict and
    by attribute otherwise. ``keys`` must be non-empty.
    """
    assert keys
    current: Union[torch.nn.Module, dict[str, torch.nn.Module]] = self.nn_modules
    for part in keys.split("."):
        current = (
            current[part] if isinstance(current, dict) else getattr(current, part)
        )
    return current
def new_var(self, name: str = "tmp") -> str:
    """Allocate a fresh variable name and append it to ``co_varnames``.

    Candidates have the form ``f"{name}_{n}"`` with ``n`` drawn from
    ``self.unique_var_id``; the first one not already in ``co_varnames`` wins.
    """
    taken = set(self.code_options["co_varnames"])

    # In common case, this will be O(1)
    while True:
        candidate = f"{name}_{next(self.unique_var_id)}"
        if candidate in taken:
            continue
        self.code_options["co_varnames"] += (candidate,)
        return candidate
def update_co_names(self, name: str) -> None:
    """Ensure self.code_options.co_names contains name"""
    if name in self.code_options["co_names"]:
        # Already present; nothing to add.
        return
    self.code_options["co_names"] += (name,)
- @staticmethod
- def module_key_name(*names: Any) -> str:
- # create a new unique name
- name = "_".join(map(str, names))
- # Strip the guard lookup L/G access
- name = re.sub(r"^[GL]\['?(.*?)'?\]$", r"\1", name)
- # e.g. replace abc.xyz[123].qkv with abc.xyz_123.qkv
- name = re.sub(r"\[(\d+)\]", r"_\g<1>", name)
- # e.g. replace abc.xyz_123.qkv with abc_xyz_123_qkv
- name = re.sub(r"[^a-zA-Z0-9]", "_", name)
- if not name or not name[0].isalpha():
- name = "sub" + name
- return name
def register_static_attr_and_return_proxy(
    self, attr_prefix: str, attr_value: Any
) -> fx.Proxy:
    """Store ``attr_value`` in ``self.nn_modules`` and return a ``get_attr`` proxy for it.

    The attribute name is derived from ``attr_prefix`` by ``get_unique_name_wrt``
    against the existing ``self.nn_modules``; ``attr_value`` is also recorded
    as the example value of the proxy's node.
    """
    unique_name = get_unique_name_wrt(attr_prefix, self.nn_modules)
    # TODO `nn_modules` has been historically overloaded to store a lot more
    # than just nn module objects, fix that.
    self.nn_modules[unique_name] = attr_value

    attr_proxy = self.create_proxy("get_attr", unique_name, (), {})
    set_example_value(attr_proxy.node, attr_value)
    return attr_proxy
def register_attr_or_module(
    self,
    target: Union[torch.nn.Module, torch.Tensor, Any],
    *names: Any,
    **options: Any,
) -> VariableTracker:
    """Register ``target`` in ``self.nn_modules`` and return a VariableTracker for it.

    Dynamic nn modules (per ``is_dynamic_nn_module``) are not registered; they
    are handed straight to ``VariableTracker.build``. For everything else:

    * ``options`` must contain ``"source"``, which must not be a
      ``ParamBufferSource``.
    * A ``wrap_name(module_key)`` closure is selected based on the type of
      ``target`` (tensor, ``nn.Module``, ``SymInt``/``SymFloat``, or anything
      else); it builds the returned VariableTracker for a given key.
    * If ``target`` is already a value in ``self.nn_modules`` (by identity), its
      existing key is reused. Otherwise a unique key is derived from ``names``
      and ``target`` is stored under it.
    * For newly registered ``nn.Module`` targets, a ``ParamBufferSource`` is
      recorded in ``self.param_name_to_source`` for every named parameter and
      buffer.

    Args:
        target: The object to register.
        *names: Name components passed to ``OutputGraph.module_key_name``.
        **options: Must include ``source``; forwarded to the VariableTracker
            constructors used by ``wrap_name``.

    Returns:
        The VariableTracker produced for ``target``.
    """
    if is_dynamic_nn_module(target, self.export):
        # Instead of returning UnspecializedNNModuleVariable, call
        # VariableTracker.build so that it is tracked for mutation.
        return VariableTracker.build(self.current_tx, target, **options)

    # Copy so the caller's kwargs dict is never aliased by the closures below.
    options = dict(options)
    assert "source" in options
    source = options["source"]
    assert not isinstance(source, ParamBufferSource)

    if isinstance(target, torch.Tensor):
        tracer = self.current_tracer
        if not self.is_root_tracer():
            # For higher order ops, we don't want to insert the get_attr in
            # innermost graph. Instead, we want to raise the params/buffers
            # as inputs to the higher-order graph, and register them as
            # get_attrs in the root tracer.

            # Note that Dynamo will still call lift_tracked_freevar_to_input
            # when these inputs are encountered for the inner graph. The
            # only difference is what happens at the root tracer for
            # nn.Parameters vs free inputs. The free inputs are registered
            # as placeholders in the root graph, whereas the nn.Parameters
            # are registered as get_attr nodes in the root graph.
            tracer = self.root_tracer

        def wrap_name(module_key: str) -> VariableTracker:
            assert self.param_name_to_source is not None
            self.param_name_to_source[module_key] = source

            # Check if the attr has already been registered. This can happen
            # when two different sources point to the same tensor.
            assert self.root_tx is not None
            if target in self.root_tx.output.side_effects:
                return self.root_tx.output.side_effects[target]

            # Tensors whose static address type is "guarded" get an ID_MATCH
            # guard (except for numpy-backed sources); other non-constant
            # sources get a TENSOR_MATCH guard.
            if get_static_address_type(target) == "guarded" and not isinstance(
                source, NumpyTensorSource
            ):
                install_guard(source.make_guard(GuardBuilder.ID_MATCH))
            elif not is_constant_source(source):
                install_guard(source.make_guard(GuardBuilder.TENSOR_MATCH))

            vt = wrap_fx_proxy(
                self.root_tx,
                tracer.create_proxy("get_attr", module_key, (), {}),
                example_value=target,
                **options,
            )

            # Track the object so as to avoid duplicate registration in case of
            # different sources pointing to the same tensor object.
            vt = self.root_tx.output.side_effects.track_object_existing(target, vt)

            assert "tensor_dict" not in vt.as_proxy().node.meta
            vt.as_proxy().node.meta["tensor_dict"] = _extract_tensor_dict(target)

            return vt

    elif isinstance(target, torch.nn.Module):
        assert isinstance(target, torch.nn.Module)

        if source:
            install_guard(source.make_guard(GuardBuilder.NN_MODULE))

            def wrap_name(module_key: str) -> VariableTracker:
                return NNModuleVariable(type(target), module_key, target, **options)

        else:
            # This is a Dynamo-created graph module, e.g., a graph module coming
            # from higher order ops. An NNModuleVariable tracker can't be
            # sourceless, so let's return an UnspecializedNNModuleVariable
            # tracker.
            def wrap_name(module_key: str) -> VariableTracker:
                return variables.UnspecializedNNModuleVariable(target, **options)

    elif isinstance(target, (torch.SymInt, torch.SymFloat)):
        # HACKY CODE REGION BEGIN
        # WE ARE PIGGYBACKING ON EXISTING INFRA TO REGISTER ATTRS
        # This ultimately gets written to self.nn_modules, which is unfortunate
        # Attrs that are tensors and symints and such need to be migrated to have their
        # own storage
        # alas, this is like this for now

        def wrap_name(module_key: str) -> VariableTracker:
            return SymNodeVariable.create(
                self,
                self.create_proxy("get_attr", module_key, (), {}),
                sym_num=target,
                **options,
            )

        # HACKY CODE REGION END
    else:

        def wrap_name(module_key: str) -> VariableTracker:
            # Anything else is installed as a global under `module_key` and
            # rebuilt from a ConstantSource with that name.
            self.output.update_co_names(module_key)
            self.global_scope[module_key] = target
            return VariableTracker.build(
                self,  # type: ignore[arg-type]
                target,
                ConstantSource(source_name=module_key),
            )

    # Identity scan: reuse the existing key if this exact object is registered.
    for k, v in self.nn_modules.items():
        if v is target:
            # it already exists
            return wrap_name(k)

    name = OutputGraph.module_key_name(*names)
    name = get_unique_name_wrt(name, self.nn_modules, self.global_scope)
    self.nn_modules[name] = target
    if isinstance(target, torch.nn.Module):

        def register_leaf_name(leaf_name: str) -> None:
            # Record the source of `<name>.<leaf_name>`; for modules reached
            # through a LocalSource, also remember the original leaf name
            # under its flattened key.
            assert self.param_name_to_source is not None
            new_source = ParamBufferSource(source, leaf_name)
            new_name = f"{name}.{leaf_name}"
            self.param_name_to_source[new_name] = new_source
            if isinstance(source, LocalSource):
                self.dynamo_flat_name_to_original_fqn[
                    OutputGraph.module_key_name(new_source.name())
                ] = leaf_name

        # annoying, but there are cases when we do not have parameters
        # see test_nn_moduledict_contains
        if hasattr(target, "_parameters"):
            for leaf_name, _ in target.named_parameters():
                register_leaf_name(leaf_name)
        if hasattr(target, "_buffers"):
            for leaf_name, _ in target.named_buffers():
                register_leaf_name(leaf_name)

    return wrap_name(name)
def handle_aliases_for_stolen_lists(
    self, tx: "InstructionTranslatorBase"
) -> tuple[list[Instruction], dict[Source, Source]]:
    """Alias items of stolen list inputs that are still needed after the call.

    If list inputs are stolen, but still needed after the function call, create
    aliases to keep them alive.

    Args:
        tx: Translator whose stack and symbolic locals are scanned, together
            with the keys of ``self.side_effects.store_attr_mutations``, for
            values whose source is an item of a stolen local list.

    Returns:
        ``(alias_insts, overridden_sources)``: bytecode that performs
        ``alias_name = list_name[list_idx]`` for every aliased item, and a
        mapping from each item's original source to the ``LocalSource`` of its
        alias. Both are empty when no locals are stolen.
    """
    maybe_gm = self.local_scope.get("self")
    stolen_list_names = get_locals_to_steal(maybe_gm)
    if not stolen_list_names:
        return [], {}

    alias_insts: list[Instruction] = []
    needs_alias: dict[str, list[VariableTracker]] = {}

    queue = [
        *tx.stack,
        *tx.symbolic_locals.values(),
        *self.side_effects.store_attr_mutations.keys(),
    ]

    while queue:
        x = queue.pop()
        if isinstance(x, BaseListVariable):
            # Look through list containers at their items.
            assert isinstance(x.items, list)
            queue += x.items
            continue

        if not (
            (
                x not in self.side_effects.store_attr_mutations
                or isinstance(x.mutation_type, AttributeMutationExisting)
            )
            and isinstance(x.source, GetItemSource)
            and isinstance(x.source.base, LocalSource)
            and x.source.base.local_name in stolen_list_names
        ):
            continue

        stolen_name = x.source.base.local_name
        needs_alias.setdefault(stolen_name, []).append(x)

    # Alias already allocated for a given (list name, index) pair. The list
    # name must be part of the key: the same index in two different stolen
    # lists refers to two different objects and needs two different aliases.
    visited: dict[tuple[str, Any], str] = {}
    overridden_sources: dict[Source, Source] = {}
    for arg in self.graphargs:
        if not (
            isinstance(arg._example, list)
            and isinstance(arg.source, LocalSource)
            and arg.source.local_name in needs_alias
        ):
            continue

        # arg is a list that will be cleared by the compiled function
        list_name = arg.source.local_name
        assert list_name in self.code_options["co_varnames"]
        for x in needs_alias[list_name]:
            # Skip if already handled.
            if x.source in overridden_sources:
                continue

            # A small codegen optimization because we might have different
            # VariableTrackers that share the same source.
            list_idx = x.source.index  # type: ignore[attr-defined]
            alias_key = (list_name, list_idx)
            if alias_key not in visited:
                alias_name = self.new_var(
                    f"{list_name}_ref"
                )  # self.new_var already adds unique id suffix

                visited[alias_key] = alias_name
                # bytecode of `alias_name = list_name[list_idx]`
                alias_insts.extend(
                    [
                        create_instruction("LOAD_FAST", argval=list_name),
                        create_load_const(list_idx),
                        create_instruction("BINARY_SUBSCR"),
                        create_instruction("STORE_FAST", argval=alias_name),
                    ]
                )

            # operate on alias, handled by suffix codegen
            old_source = x.source
            overridden_sources[old_source] = LocalSource(visited[alias_key])

    # NOTE: we need `overridden_sources` because (1) we want to codegen for
    # these list items to use the new local source, but (2) we want to avoid
    # updating `source` in place because that might break invariants in
    # other parts of Dynamo like guards.
    return alias_insts, overridden_sources
def _get_stack_values_to_restore(
    self, tx: "InstructionTranslatorBase", stack_pops: int
) -> tuple[list[VariableTracker], StackLocalsMetadata]:
    """
    Gets the stack + locals values belonging to tx that need to be restored.

    Also prunes dead tx locals and realizes all VTs in the tx's stack.

    NullVariables in stack/locals will NOT be restored, unless they are the top `stack_pops`
    elements of the stack - it is expected that the next instruction to run will pop the top
    `stack_pops` elements of the stack, so we should codegen NULLs.

    Args:
        tx: the translator whose stack and symbolic locals are inspected
        stack_pops: number of values at the top of the stack that are kept
            unconditionally (including NULLs)

    Returns:
        - stack_values: stack and locals values that need to be restored
        - meta: locations of NULLs and ContextWrappingVariables in the stack/locals
            (ignores the top `stack_pops` values on the stack)
    """
    tx.prune_dead_locals()

    stack_values = []
    meta = StackLocalsMetadata()

    # realize any unrealized tensor VTs in case they
    # need to be added to self.nn_modules as attributes
    for i, value in enumerate(tx.stack):
        variables.LazyVariableTracker.realize_all(value)
        # ignore top `stack_pops` values on the stack
        if len(tx.stack) - i <= stack_pops:
            stack_values.append(value)
            continue
        # Below the top `stack_pops` entries: NULLs are dropped and only their
        # original stack index is remembered.
        if isinstance(value, NullVariable):
            meta.stack_null_idxes.append(i)
        else:
            stack_values.append(value)
        if isinstance(value, ContextWrappingVariable):
            target_values = (
                () if value.target_values is None else tuple(value.target_values)
            )
            # NOTE: track index in stack after NULLs have been removed
            meta.stack_ctx_args.append((len(stack_values) - 1, target_values))
            meta.stack_ctx_idxes_orig.append(i)

    # Everything collected so far came from the stack; locals are appended after.
    meta.num_stack = len(stack_values)

    cell_and_freevars = set(tx.cellvars() + tx.freevars())

    # NB: Typically (i.e., for graph compile from RETURN_VALUE),
    # symbolic_locals will be empty at this point, as prune_dead_locals
    # will clear out all of symbolic_locals because RETURN_VALUE is the
    # last instruction and no more locals are used.  The fanciness here
    # is only needed for partial graphs.
    # NOTE: All cell and free variables are represented as CellVariable,
    # so checks for NULLs and context managers in the case of codegen'ing resume
    # functions will not be performed on them. This is expected behavior.
    for k, v in tx.symbolic_locals.items():
        # Note! this explicitly uses .local_name for matching
        # Failure to do so will cause spurious registrations in val_to_names.
        # This will in turn result in spurious variables showing up in the graph.
        # This was very tricky to debug. For an example, dump the graph at call_user_compiler
        # while running test_subgraphs.py
        # Do not include top-frame unmodified locals here - otherwise, the compiled graph may
        # erroneously include them as part of the return. We manually codegen them afterward.
        if (
            isinstance(v.source, LocalSource)
            and v.source.local_name == k
            and tx is self.root_tx
        ):
            continue
        # Do not load cell/free vars
        if k in cell_and_freevars:
            continue
        # Do not load variable if it is NULL.
        if sys.version_info >= (3, 12):
            # NOTE: do not use isinstance, since it realizes lazy VT's
            # Continuation function will load the NULL for v.
            if type.__instancecheck__(NullVariable, v):
                meta.locals_null_keys.append(k)
                continue
        else:
            # A variable should never be NULL in < 3.12
            assert not type.__instancecheck__(NullVariable, v)
        # Record the position of this local among the restored locals.
        meta.locals_names[k] = len(meta.locals_names)
        if isinstance(v, ContextWrappingVariable):
            target_values = (
                () if v.target_values is None else tuple(v.target_values)
            )
            meta.locals_ctx_args.append((k, target_values))
        stack_values.append(v)

    return stack_values, meta
- def compile_subgraph(
- self,
- tx: "InstructionTranslatorBase",
- reason: GraphCompileReason,
- partial_convert: bool = False,
- stack_pops: int = 0,
- ) -> list[StackLocalsMetadata]:
- """
- Compiles the current subgraph, with inputs w.r.t. self.root_tx, and codegens:
- - Call the compiled subgraph
- - Apply side effects
- - Codegen stack and locals
- - Store the locals
- Python does not allow NULL to be an arg to a function, so we do not codegen NULLs on the stack,
- unless the value is one of the top `stack_pops` values on the stack (these values are expected to be
- popped immediately after this generated code. The prologue of the resume function is expected to restore
- any dropped NULLs.
- Returns stack indices and locals keys where we dropped NULLs, and where we found inactive context manager objects.
- """
- assert self.root_tx is not None
- if not config.nested_graph_breaks:
- # expect to only compile 1 frame
- assert self.root_tx is tx
- # bytecode tracing has finished. Pop the context manager for dynamo_timed
- self.mark_bytecode_tracing_stop()
- self.partial_convert = partial_convert
- self.compile_subgraph_reason = reason
- self.should_exit = True
- log.debug("COMPILING GRAPH due to %s", reason)
- # prefix instructions (Python 3.11+)
- prefix_insts: list[Instruction] = []
- if sys.version_info >= (3, 11):
- for inst in self.root_tx.prefix_insts:
- if inst.opname == "COPY_FREE_VARS":
- prefix_insts.append(
- create_instruction(
- "COPY_FREE_VARS",
- arg=len(self.root_tx.code_options["co_freevars"]),
- )
- )
- else:
- prefix_insts.append(copy.copy(inst))
- # stack values and restore vars for each frame are pushed in reverse order
- # i.e. last element corresponds to root frame (1),
- # first element corresponds to current frame (N)
- all_stack_values = []
- all_stack_locals_metas = []
- cur_tx: Optional[InstructionTranslatorBase] = tx
- while cur_tx is not None:
- # this should have been checked by the caller
- assert all(block.can_restore() for block in cur_tx.block_stack)
- stack_values, meta = self._get_stack_values_to_restore(
- cur_tx, stack_pops if cur_tx is tx else 0
- )
- all_stack_values.append(stack_values)
- all_stack_locals_metas.append(meta)
- # Exit from all context manager variables to make sure global state is restored
- for block in reversed(cur_tx.block_stack):
- block.exit(cur_tx, is_graph_break=reason.graph_break)
- cur_tx = cur_tx.parent
- # "Garbage collect the heap".
- self.side_effects.prune_dead_object_new(tx)
- self.add_output_instructions(prefix_insts)
- assert not (self.pregraph_bytecode and self.export), (
- "export does not support pregraph_bytecode"
- )
- self.add_output_instructions(self.pregraph_bytecode)
- alias_insts, overridden_sources = self.handle_aliases_for_stolen_lists(
- self.root_tx
- )
- self.add_output_instructions(alias_insts)
- self.cleanup_graph()
- # Use nn.Module "proxies" in the constructed GraphModule so that
- # the resulting GM does not hold additional strong references to the original modules.
- # This prevents a strong ref cycle where Dynamo created code holds on to references
- # to modules that also have Dynamo code cache invalidation checks.
- # When cache invalidation runs, the generated GM will be invalidated, which also deletes
- # the proxies.
- nn_modules_proxies = {
- name: nn_module_proxy(mod) for name, mod in self.nn_modules.items()
- }
- root = FakeRootModule(nn_modules_proxies)
- from .decorators import disable
- # to handle random calls
- if len(self.random_calls) > 0:
- random_calls_instructions = []
- self.random_values_var = self.new_var("random_values")
- rand_fn = disable(
- _get_gen_rand_values_fn(self.random_calls),
- reason="do not trace into Dynamo rng recovery function",
- )
- rand_fn_name = self.install_global("__gen_rand_values", rand_fn)
- codegen = PyCodegen(
- self.root_tx, root, overridden_sources=overridden_sources
- )
- random_calls_instructions.extend(
- codegen.load_function_name(rand_fn_name, True)
- )
- random_calls_instructions.extend(create_call_function(0, False))
- random_calls_instructions.append(
- codegen.create_store(self.random_values_var),
- )
- self.add_output_instructions(random_calls_instructions)
- # Codegen stack convention before the unsupported instruction
- # NOTE: in these comment blocks, "locals" EXCLUDE free and cell vars.
- # NOTE: stack and locals must be codegen'd BEFORE the unsupported instruction, since the latter
- # can arbitrarily mutate the former.
- # [
- # frame N locals,
- # frame N-1 stack + locals,
- # ...,
- # frame 1 stack + locals,
- # ], frame N stack
- # see symbolic_convert.py for
- # codegen stack convention after the unsupported instruction
- # NOTE: cells are loaded into continuation functions directly
- # this determines the order that values are codegen'd to the stack
- stack_values_flat = [val for vals in all_stack_values for val in vals]
- stored_graph_output_var = False
- graph_output_var = None
- # call compiled fx graph and codegen all values - stack and locals
- if (
- self.root_tx is tx # single frame
- and stack_values_flat
- and all(
- not isinstance(
- v,
- (
- UnspecializedPythonVariable,
- NumpyNdarrayVariable,
- TensorWithTFOverrideVariable,
- ),
- )
- and not (isinstance(v, SymNodeVariable) and v.python_type() is float)
- for v in stack_values_flat
- )
- and all(isinstance(x, TensorVariable) for x in stack_values_flat)
- and len(set(stack_values_flat)) == len(stack_values_flat)
- and self.side_effects.is_empty()
- and not tx.debug_locals
- and not self.backward_state
- and not all_stack_locals_metas[-1].stack_null_idxes
- and not all_stack_locals_metas[-1].locals_null_keys
- ):
- # optimization to generate better code in a common case
- self.add_output_instructions(
- [
- # load in reverse since UNPACK_SEQUENCE will reverse
- *self.compile_and_call_fx_graph(
- tx, list(reversed(stack_values_flat)), root
- ),
- create_instruction("UNPACK_SEQUENCE", arg=len(stack_values_flat)),
- ]
- )
- # function output will be moved to the correct places below
- else:
- graph_output_var = self.new_var("graph_out")
- # load stack values in a flat manner - we will codegen bytecode to place them correctly
- # according to our convention above
- pass1 = PyCodegen(
- self.root_tx,
- root,
- graph_output_var,
- overridden_sources=overridden_sources,
- )
- self.codegen_suffix(tx, stack_values_flat, pass1)
- # Use `pass1.uses` to selectively cache multi-user variables into a
- # temporary local source. This (a). speeds up loading VTs with long
- # chained source, and (b). avoids redundantly saving single-user VT
- # into a temporary local.
- tempvars = {} # type: ignore[var-annotated]
- for val, count in pass1.uses.items():
- # If it's already a local source, no need to cache it
- if count > 1 and not istype(val, (SyntheticLocalSource, LocalSource)):
- tempvars[val] = None
- pass2 = PyCodegen(
- self.root_tx,
- root,
- graph_output_var,
- tempvars=tempvars,
- overridden_sources=overridden_sources,
- )
- self.codegen_suffix(tx, stack_values_flat, pass2)
- if (
- torch._dynamo.config.log_graph_in_out_metadata
- and stack_values_flat
- and len(stack_values_flat) == 1
- ):
- vt = stack_values_flat[0]
- if (
- isinstance(vt, torch._dynamo.variables.NamedTupleVariable)
- and vt.tuple_cls
- is torch._dynamo.functional_export.ExportTracerOutput
- ):
- flat_returns = vt.items[0]
- out_spec = vt.items[1]
- assert isinstance(
- flat_returns, torch._dynamo.variables.ListVariable
- )
- vt_to_graph_out_idx: dict[VariableTracker, int] = {}
- for value in pass2.graph_outputs.values():
- assert isinstance(value, torch._dynamo.codegen.GraphOutputEntry)
- variable: VariableTracker = value.variable
- vt_to_graph_out_idx[variable] = value.index
- for idx, vt in enumerate(flat_returns.items):
- if vt in vt_to_graph_out_idx:
- self.export_metadata.output_return_type[idx] = (
- "graph_out",
- vt_to_graph_out_idx[vt],
- )
- elif (
- vt.source is not None
- and (source := getattr(vt.source, "base", None))
- and source.is_input
- ):
- self.export_metadata.output_return_type[idx] = (
- "input",
- vt.source,
- )
- elif isinstance(vt, torch._dynamo.variables.ConstantVariable):
- self.export_metadata.output_return_type[idx] = (
- "constant",
- vt.as_python_constant(),
- )
- else:
- assert f"Encountered unrecognized type {vt} at output {idx}" # noqa: PLW0129
- self.export_metadata.out_spec = out_spec.as_python_constant()
- output = []
- if count_calls(self.graph) != 0 or len(pass2.graph_outputs) != 0:
- output.extend(
- self.compile_and_call_fx_graph(tx, pass2.graph_output_vars(), root)
- )
- if len(pass2.graph_outputs) != 0:
- output.append(pass2.create_store(graph_output_var))
- stored_graph_output_var = True
- else:
- output.append(create_instruction("POP_TOP"))
- else:
- # NB: Important to run compiler collective even when there is
- # a graph break
- self.run_compiler_collective()
- self.add_output_instructions(output + pass2.get_instructions())
- # store all stack and locals for each frame
- # current state of the stack:
- # *(frame N stack), *(frame N locals),
- # ...,
- # *(frame 1 stack), *(frame 1 locals)
- self.add_output_instructions(
- [
- create_instruction(
- "BUILD_LIST",
- arg=len(stack_values_flat) - all_stack_locals_metas[0].num_stack,
- ),
- ]
- )
- # current state of the stack:
- # *(frame N stack), [
- # *(frame N locals),
- # *(frame N-1 stack), *(frame N-1 locals),
- # ...
- # *(frame 1 stack), *(frame 1 locals),
- # ]
- # iterate current frame (N) to root frame (1)
- # sliding window over frame stack/locals
- start_idx = 0
- end_idx = 0
- for i, meta in enumerate(all_stack_locals_metas):
- # do not pack frame N's stack into the value list
- n_vals = len(meta.locals_names)
- if i != 0:
- n_vals += meta.num_stack
- if n_vals == 0:
- self.add_output_instructions(
- [
- create_instruction("BUILD_LIST", arg=0),
- *create_swap(2),
- ]
- )
- # [], stack_values_flat
- else:
- end_idx += n_vals
- self.add_output_instructions(
- [
- create_dup_top(),
- *create_binary_slice(start_idx, end_idx),
- *create_swap(2),
- ]
- )
- start_idx += n_vals
- # stack_values_flat[x:y], stack_values_flat
- # add root frame's unmodified locals here
- if i == len(all_stack_locals_metas) - 1:
- root_cg = PyCodegen(self.root_tx)
- unmodified_locals_names: dict[str, int] = {}
- for k, v in self.root_tx.symbolic_locals.items():
- if isinstance(v.source, LocalSource) and v.source.local_name == k:
- root_cg.append_output(root_cg.create_load(k))
- unmodified_locals_names[k] = len(meta.locals_names) + len(
- unmodified_locals_names
- )
- self.add_output_instructions(
- root_cg.get_instructions()
- + [
- create_instruction(
- "BUILD_LIST", arg=len(unmodified_locals_names)
- ),
- # arg=2 because we already swapped the locals list back
- create_instruction("LIST_EXTEND", arg=2),
- ]
- )
- meta.locals_names.update(unmodified_locals_names)
- # *(frame N stack), metas[0] stack + locals, ..., metas[i] stack + locals, stack_values_flat
- # current state of the stack:
- # *(frame N stack)
- # frame N locals,
- # frame N-1 stack, frame N-1 locals,
- # ...
- # frame 1 stack, frame 1 locals,
- # stack_values_flat
- #
- self.add_output_instructions(
- [
- create_instruction("POP_TOP"),
- create_instruction("BUILD_LIST", arg=len(all_stack_locals_metas)),
- *create_rot_n(all_stack_locals_metas[0].num_stack + 1),
- ]
- )
- # final state of the stack before running the unsupported bytecode:
- # [
- # [frame N locals],
- # [frame N-1 stack + locals],
- # ...,
- # [frame 1 stack + locals],
- # ], *(frame N stack)
- if graph_output_var and stored_graph_output_var:
- self.add_output_instructions(
- [create_instruction("DELETE_FAST", argval=graph_output_var)]
- )
- if self.export:
- from torch.export._trace import _ExportModuleSpecTrackerDict
- potential_side_effects = []
- for var in self.side_effects._get_modified_vars():
- if hasattr(var, "mutation_type"):
- mut_type = var.mutation_type
- # Make sure to skip codegen specific mutations
- if isinstance(
- mut_type, (AttributeMutationExisting, ValueMutationExisting)
- ):
- # export uses tracepoint pass to dump submodule inp/out spec
- # into global state, so we filter it here
- if not (
- isinstance(var, UserDefinedDictVariable)
- and isinstance(var.value, _ExportModuleSpecTrackerDict)
- ):
- potential_side_effects.append(var)
- side_effect_refs = [
- _get_source_debug_name(var.source) for var in potential_side_effects
- ]
- if len(side_effect_refs):
- warnings.warn(
- f"While exporting, we found certain side effects happened in the model.forward. "
- f"Here are the list of potential sources you can double check: {side_effect_refs}"
- )
- return all_stack_locals_metas
def codegen_suffix(
    self,
    tx: "InstructionTranslatorBase",
    stack_values: list[VariableTracker],
    cg: PyCodegen,
) -> None:
    """
    Emit, through ``cg``, the code that runs after the compiled graph call.

    Steps, in order: save side-effect temporaries, store every
    ``backward_state`` entry onto ``backward_state_var``, emit side-effect
    hooks, call each ``tx.debug_locals`` entry (discarding the result),
    restore ``stack_values`` and finally emit the mutation updates.
    """
    # NOTE: `codegen_save_tempvars` must run first to update `source` fields
    # for variables with `AttributeMutationNew`, as they don't implement
    # `reconstruct` themselves.
    self.side_effects.codegen_save_tempvars(cg)

    backward_state = self.backward_state
    if backward_state:
        assert not self.export
        for attr_name, attr_value in backward_state.items():
            cg(attr_value)
            assert self.backward_state_var is not None
            cg.append_output(cg.create_load(self.backward_state_var))
            cg.store_attr(attr_name)

    self.side_effects.codegen_hooks(cg)

    # Return variables used for logging at the end
    for logger_var, logger_args in tx.debug_locals:
        cg.add_push_null(lambda: cg(logger_var))
        for logger_arg in logger_args:
            cg(logger_arg)
        cg.extend_output(create_call_function(len(logger_args), False))
        # the call result is not needed
        cg.extend_output([create_instruction("POP_TOP")])

    cg.restore_stack(stack_values, value_from_source=not tx.export)
    self.side_effects.codegen_update_mutated(cg)
def cleanup_graph(self) -> None:
    """
    Tidy the traced graph before it is handed off.

    * Drops the "creation_timestamp" entry from every node's meta.
    * Erases adjacent pairs of calls that toggle grad mode and immediately
      toggle it back, e.g.:
        torch._C._set_grad_enabled(False)
        torch._C._set_grad_enabled(True)
    """
    assert self.should_exit
    all_nodes = list(self.graph.nodes)
    for graph_node in all_nodes:
        graph_node.meta.pop("creation_timestamp", None)

    def flips_mode(candidate: Any, current_mode: Any) -> bool:
        # A live `_set_grad_enabled` call whose sole argument inverts
        # the grad mode tracked so far.
        return (
            candidate.target is torch._C._set_grad_enabled
            and tuple(candidate.args) == (not current_mode,)
            and not candidate._erased
        )

    grad_mode = torch.is_grad_enabled()
    for first, second in zip(all_nodes, all_nodes[1:]):
        if not flips_mode(first, grad_mode):
            continue
        grad_mode = first.args[0]
        if not flips_mode(second, grad_mode):
            continue
        # `second` undoes `first`: the pair is a no-op, drop both
        grad_mode = second.args[0]
        self.graph.erase_node(first)
        self.graph.erase_node(second)
- def bypass_package(self, reason: str = "", **kwargs: Any) -> None:
- """
- Do not save this output graph to the CompilePackage
- """
- if not self.package:
- return
- if torch._dynamo.config.strict_precompile:
- raise torch._dynamo.exc.PackageError(
- "Detected a package bypass: %s", reason
- )
- log.warning("Detected a package bypass: %s", reason)
- torch._logging.trace_structured(
- "artifact",
- metadata_fn=lambda: {
- "name": "precompile_cache_bypass",
- "encoding": "json",
- },
- payload_fn=lambda: {
- # precede with underscore so it always appear first in JSON in tlparse
- "_reason": reason,
- **kwargs,
- },
- )
- self.package.bypass_current_entry()
- self.package = None
def get_graph_sizes_structured(self) -> dict[str, list[Union[int, str]]]:
    """
    Map node name -> size list for every node whose "example_value" meta
    entry is a FakeTensor. Plain ``int`` dimensions are kept as-is, any
    other dimension is stored as its ``repr``.
    """
    sizes_by_node: dict[str, list[Union[int, str]]] = {}
    for graph_node in self.graph.nodes:
        fake = graph_node.meta.get("example_value", None)
        if not isinstance(fake, torch._subclasses.FakeTensor):
            continue
        sizes_by_node[graph_node.name] = [
            dim if isinstance(dim, int) else repr(dim) for dim in fake.size()
        ]
    return sizes_by_node
def get_graph_sizes(self, name: str) -> str:
    """
    Render a text report of the sizes of every node whose "example_value"
    meta entry is a FakeTensor.

    Each such node gets a ``<node>: <size tuple>`` line. When at least one
    dimension is a SymInt and every dimension is an int or a SymInt, a
    second ``<node> (concrete): ...`` line lists the sizes with each SymInt
    replaced by ``sz.node.hint``.
    """
    chunks = ["TRACED GRAPH TENSOR SIZES\n", f"===== {name} =====\n"]
    for graph_node in self.graph.nodes:
        fake = graph_node.meta.get("example_value", None)
        if not isinstance(fake, torch._subclasses.FakeTensor):
            continue
        dims = fake.size()
        chunks.append(f"{graph_node.name}: {tuple(dims)}\n")

        hinted = []
        saw_symint = False
        fully_resolved = True
        for dim in dims:
            if isinstance(dim, int):
                hinted.append(dim)
            elif isinstance(dim, torch.SymInt):
                saw_symint = True
                hinted.append(dim.node.hint)
            else:
                # unknown dimension kind: no concrete line for this node
                fully_resolved = False
                break

        if fully_resolved and saw_symint:
            chunks.append(f"{graph_node.name} (concrete): {tuple(hinted)}\n")
    return "".join(chunks)
- @contextlib.contextmanager
- def restore_global_state(self) -> Any:
- """
- Momentarily restores the global state to what it was prior to tracing the current output
- """
- prior_global_state = self.tracing_context.global_context.copy_graphstate()
- current_global_state: dict[str, tuple[Any, bool]] = {}
- self.save_global_state(out=current_global_state)
- try:
- # Set to state prior to tracing the graph
- self.tracing_context.global_context.restore_graphstate(prior_global_state)
- yield
- finally:
- # Reset to state at the current time (e.g. before calling the user compiler)
- self.tracing_context.global_context.restore_graphstate(
- GlobalContextCheckpointState(current_global_state)
- )
def run_compiler_collective(self) -> None:
    """
    Gather every rank's local compile state, if that has not happened yet.

    No-op when the root translator has no distributed state or when
    ``all_states`` is already populated. Otherwise the local state is
    all-gathered over ``compile_pg``, stored on the distributed state, the
    speculation log is cleared and analysis is restarted.

    Raises:
        exc.CompileCollectiveRestartAnalysis: always, after a gather ran.
    """
    tx = self.root_tx
    assert tx is not None

    ds = tx.distributed_state
    if ds is None or ds.all_states is not None:
        return

    compile_pg = ds.compile_pg
    log.info("compiler_collective %s", ds.local_state)
    torch._logging.trace_structured(
        "artifact",
        metadata_fn=lambda: {
            "name": "compiler_collective",
            "encoding": "string",
        },
        payload_fn=lambda: ds.local_state.render(),
    )

    device_types = compile_pg._device_types
    assert len(device_types) == 1, (
        "Expect only one device type but got {}".format("+".join(device_types))
    )
    with (
        get_interface_for_device(device_types.pop()).device(  # type: ignore[attr-defined]
            compile_pg.rank() % torch.accelerator.device_count()
        ),
        dynamo_timed("compiler_collective", log_pt2_compile_event=True),
    ):
        gathered: list[Any] = [None] * compile_pg.size()
        dist.all_gather_object(gathered, ds.local_state, group=compile_pg)
        ds.all_states = gathered

    # Clear speculation log, because our tracing may diverge due to
    # this information from the compiler collective
    tx.speculation_log.clear()
    raise exc.CompileCollectiveRestartAnalysis
def compile_and_call_fx_graph(
    self,
    tx: "InstructionTranslatorBase",
    rv: list[VariableTracker],
    root: FakeRootModule,
) -> list[Instruction]:
    """
    Generate code from self.graph and return the Instruction()s to
    call that generated code.

    Code is generated w.r.t. self.root_tx.
    tx is only used for preserving GraphModule metadata

    Args:
        tx: translator passed to `_maybe_preserve_original_meta` for the
            output node.
        rv: variables whose proxies become the graph's output tuple.
        root: module used as the root of the GraphModule that is built.

    Returns:
        Instructions that call the installed compiled function, or an
        empty list when the graph has no calls and ``rv`` is empty.
    """
    with torch._guards.TracingContext.clear_frame():
        from .decorators import disable

        assert self.should_exit

        self.run_compiler_collective()
        # Nothing to compile and nothing to return: emit no code at all
        if count_calls(self.graph) == 0 and len(rv) == 0:
            return []

        name = unique_id("__compiled_fn", with_uuid=True)

        assert isinstance(rv, list)
        assert isinstance(root, FakeRootModule)

        # Close the graph with an output node returning all of `rv`
        output_node = self.create_node(
            "output",
            "output",
            (self.current_tracer.create_arg(tuple(x.as_proxy() for x in rv)),),
            {},
        )
        sub_gms = self.dedup_pass()
        root.add_nn_modules(sub_gms)  # type: ignore[arg-type]

        self.current_tracer._maybe_preserve_original_meta(tx, output_node)
        if not config.do_not_emit_runtime_asserts:
            # There is a rare scenario where codegen_suffix adds a new entry
            # to self.nn_modules while `root` knows only about the
            # nn_modules at the time of its creation. This causes failures
            # while creating the graph module because self.graph and root
            # are out of sync. This only happens for `get_attr` nodes, so
            # here we clean up the get_attr nodes that are unused.
            self.remove_unused_get_attr_nodes()
            insert_deferred_runtime_asserts(
                fx.GraphModule(root, self.graph),
                self.shape_env,
                name,
                export=self.export,
            )
        # NB: deferred runtime asserts can keep graphargs live, so make sure
        # those are inserted before pruning
        self.remove_unused_graphargs()
        ncalls = count_calls(self.graph)
        counters["stats"]["calls_captured"] += ncalls

        self.remove_tensorify_specialized_graphargs()

        # free a bit of memory
        self.real_value_cache.clear()

        gm = _make_graph_module(root, self.graph)

        # Saved tensors hooks are not used by the graph.
        # GraphModule by default only copies used in the graph submodules.
        # Copying them into the result graph manually.
        if self.saved_tensors_hooks_subgraph_names:
            for subgraph_name in self.saved_tensors_hooks_subgraph_names:
                setattr(gm, subgraph_name, getattr(root, subgraph_name))

        for register_finalizer in self.register_finalizer_fns:
            register_finalizer(gm)

        if next(gm.parameters(), None) is not None:
            # If dynamo produces a graph with parameters, skip package stuff
            # Bypass output graph
            self.bypass_package(
                "Graph contains named parameters: either inline_inbuilt_nn_modules=False or there are static addresses.",
                inline_builtin_nn_modules=torch._dynamo.config.inline_inbuilt_nn_modules,
                gm=gm.print_readable(
                    print_output=False, include_stride=True, include_device=True
                ),
            )

        # `bypass_package` above may have reset self.package to None
        if self.package is not None:
            gm._backend_id = name

        # Attach bookkeeping metadata to the graph module
        gm.compile_subgraph_reason = self.compile_subgraph_reason
        gm.meta["dynamo_flat_name_to_original_fqn"] = (
            self.dynamo_flat_name_to_original_fqn.copy()
        )
        gm.meta["dynamo_compile_id"] = self.dynamo_compile_id
        gm.meta["backend_id"] = name

        graph_code_log.debug(
            "%s",
            lazy_format_graph_code(
                name, gm, include_stride=True, include_device=True, colored=True
            ),
        )
        torch._logging.trace_structured(
            "dynamo_output_graph",
            lambda: {"sizes": self.get_graph_sizes_structured()},
            payload_fn=lambda: gm.print_readable(
                print_output=False, include_stride=True, include_device=True
            ),
        )
        self.call_cleanup_hooks()
        old_fake_mode = self.tracing_context.fake_mode
        assert old_fake_mode is not None
        if not self.export:
            import torch._functorch.config as _config

            with _config.patch(fake_tensor_allow_unsafe_data_ptr_access=False):
                # TODO(voz): The way export uses gm, and fake tensors, is not supported with us resetting
                backend_fake_mode = torch._subclasses.FakeTensorMode(
                    shape_env=old_fake_mode.shape_env,
                )
            # TODO(voz): Ostensibly, this should be scoped and
            # restore back to old_fake_mode, but doing so currently violates
            # a lot of fake_tensor ownership assumptions and runs afoul of detect_fake_mode
            self.tracing_context.fake_mode = backend_fake_mode

        # Run the backend with the global state as it was before tracing
        with self.restore_global_state():
            compiled_fn = self.call_user_compiler(gm, self.example_inputs())

        from torch.fx._lazy_graph_module import _LazyGraphModule

        if isinstance(compiled_fn, _LazyGraphModule) or (
            isinstance(getattr(compiled_fn, "__self__", None), _LazyGraphModule)
            and compiled_fn.__name__ == "_lazy_forward"  # type: ignore[attr-defined]
        ):
            # Since dynamo will run the forward method for the GraphModule shortly
            # anyways, it does not hurt to do the real recompilation here if
            # this is a _LazyGraphModule. This makes it easier for dynamo to
            # optimize a _LazyGraphModule.

            lazy_gm = (
                compiled_fn
                if isinstance(compiled_fn, _LazyGraphModule)
                else compiled_fn.__self__  # type: ignore[attr-defined]
            )

            _LazyGraphModule.force_recompile(lazy_gm)

            if not isinstance(compiled_fn, _LazyGraphModule):
                # replace compiled_fn with the real forward method
                compiled_fn = lazy_gm.forward

        if self.package is not None:
            self.package.add_backend_id(name, compiled_fn)

        compiled_fn = disable(
            compiled_fn, reason="do not trace Dynamo-compiled graph"
        )

        counters["stats"]["unique_graphs"] += 1
        assert old_fake_mode.shape_env is not None
        if specializations := old_fake_mode.shape_env.specializations:
            # Each entry pairs a predicate over the call args with the
            # specialization it selects.
            specialization_guards = []
            # Compiled functions, filled in on first use per specialization
            specialization_cache: dict[Specialization, Callable[[Any], Any]] = {}
            sources = [a.source for a in self.graphargs]
            for specialization in specializations:
                source_index = sources.index(specialization.source)
                check_fn_source = inspect.getsource(specialization.check_fn).strip()
                # Required because the LAMBDA_GUARD API requires a root guard manager
                unused_root_guard_manager = RootGuardManager()
                check_fn = guards.LAMBDA_GUARD(  # type: ignore[attr-defined]
                    unused_root_guard_manager,
                    specialization.check_fn,
                    [check_fn_source],
                )

                log.debug(
                    "Compiling backend specialized graph with specialization=%s",
                    check_fn_source,
                )

                specialization_guards.append(
                    (
                        # `check_fn` is bound as a default and the index via
                        # partial, so each entry keeps its own values
                        functools.partial(
                            lambda idx, args, check_fn=check_fn: check_fn(
                                args[idx]
                            ),
                            source_index,
                        ),
                        specialization,
                    )
                )

            @torch._dynamo.disable(reason="do not trace Dynamo-compiled graph")  # type: ignore[misc]
            def specialized_dispatch(*args: Any, **kwargs: Any) -> Any:
                # The first specialization whose guard accepts the args wins;
                # otherwise fall back to the generic compiled function.
                for check_fn, specialization in specialization_guards:
                    if check_fn(args):
                        if specialization in specialization_cache:
                            return specialization_cache[specialization](
                                *args, **kwargs
                            )

                        with self.shape_env.patch_source_specialization(
                            specialization.source, specialization.check_fn
                        ):
                            # Modify gm so AOTAutogradCache key changes per specialization
                            gm.meta["specialization"] = specialization
                            example_inputs: list[Tensor] = list(args)
                            with tracing(self.tracing_context):
                                specialization_cache[specialization] = (
                                    self.call_user_compiler(gm, example_inputs)
                                )

                        return specialization_cache[specialization](*args, **kwargs)
                return compiled_fn(*args, **kwargs)

            # This is safe because we pre-process name to be unique
            self.install_global_unsafe(name, specialized_dispatch)
        else:
            # This is safe because we pre-process name to be unique
            self.install_global_unsafe(name, compiled_fn)

        assert self.root_tx is not None
        cg = PyCodegen(self.root_tx)

        # Record which source feeds each graph input
        for idx, arg in enumerate(self.graphargs):
            self.export_metadata.graph_input_idx_to_local_source[idx] = arg.source

        cg.make_call_generated_code(name)
        return cg.get_instructions()
@property
def placeholders(self) -> list[fx.Node]:
    """All nodes of ``self.graph`` whose op is "placeholder"."""
    placeholder_nodes = self.graph.find_nodes(op="placeholder")
    return placeholder_nodes
@property
def graphargs(self) -> list[GraphArg]:
    """The "grapharg" meta entry of every placeholder node, in order."""
    collected = []
    for placeholder in self.placeholders:
        collected.append(placeholder.meta["grapharg"])
    return collected
def call_user_compiler(
    self, gm: fx.GraphModule, example_inputs: list[Tensor]
) -> CompiledFn:
    """
    Run `_call_user_compiler` on ``gm`` / ``example_inputs`` inside a
    `dynamo_timed` scope and return what it returns.
    """
    timing_scope = dynamo_timed(
        "OutputGraph.call_user_compiler",
        phase_name="backend_compile",
        log_pt2_compile_event=True,
        log_waitcounter=True,
        waitcounter_name_override="compile_aot_autograd",
        dynamo_compile_column_us="aot_autograd_cumulative_compile_time_us",
    )
    with timing_scope:
        compiled = self._call_user_compiler(gm, example_inputs)
    return compiled
def _call_user_compiler(
    self, gm: fx.GraphModule, example_inputs: list[Tensor]
) -> CompiledFn:
    """
    Invoke ``self.compiler_fn`` on ``gm`` and return the compiled callable.

    Before the call, call-type nodes are counted (reported through
    `increment_op_count`), every placeholder without a `_dynamo_source`
    attribute gets one from its grapharg, and the source maps are attached
    to ``gm``. After a successful call a signpost event is logged.

    Raises:
        BackendCompilerFailed: for an unexpected backend exception, or for
            a fallback-allowed exception when
            ``self.has_user_defined_allowed_in_graph`` is set.
        TensorifyScalarRestartAnalysis, ShortenTraceback, SkipFrame:
            re-raised unchanged.
    """
    assert self.compiler_fn is not None
    # Number of call_function / call_method / call_module nodes
    tot = 0
    placeholders = []
    for node in gm.graph.nodes:
        if node.op in ("call_function", "call_method", "call_module"):
            tot += 1
        if node.op == "placeholder":
            placeholders.append(node)
    increment_op_count(tot)
    for pl in placeholders:
        if not hasattr(pl, "_dynamo_source"):
            arg = pl.meta["grapharg"]
            # TODO: Why isn't this stored in meta :think:
            # NOTE: can't move these into meta: https://github.com/pytorch/pytorch/issues/141640
            pl._dynamo_source = arg.source

    # NOTE: can't move these into meta: https://github.com/pytorch/pytorch/issues/141640
    gm._param_name_to_source = self.param_name_to_source  # type: ignore[assignment]
    gm._source_to_user_stacks = self.source_to_user_stacks  # type: ignore[assignment]

    # Backend name used in log and error messages
    name = (
        self.compiler_fn.__name__
        if hasattr(self.compiler_fn, "__name__")
        else "<unknown compiler_fn>"
    )
    try:
        _step_logger()(logging.INFO, f"calling compiler function {name}")
        compiler_fn = self.compiler_fn
        if config.verify_correctness:
            compiler_fn = WrapperBackend(compiler_fn)
        compiled_fn = compiler_fn(gm, example_inputs)
        _step_logger()(logging.INFO, f"done compiler function {name}")
        assert callable(compiled_fn), "compiler_fn did not return callable"
    except (TensorifyScalarRestartAnalysis, ShortenTraceback):
        # Propagate untouched; must be handled before the broader clauses
        raise
    except exceptions_allowed_to_be_fallback as e:
        if self.has_user_defined_allowed_in_graph:
            raise BackendCompilerFailed(
                self.compiler_fn, e, inspect.currentframe()
            ).with_traceback(e.__traceback__) from None
        # NOTE(review): presumably this call raises (graph break), otherwise
        # `compiled_fn` would be unbound below -- confirm.
        unimplemented_v2_with_warning(
            e,
            self.root_tx.f_code,
            gb_type="Backend compiler exception",
            context=f"Backend: {name}\nException:{str(e)}\nTraceback:\n{self.root_tx.format_frame_summary()}",
            explanation=f"Backend compiler `{name}` failed with {str(e)}. Adding a graph break.",
            hints=[
                "Report an issue to the backend compiler repo.",
            ],
        )
    except SkipFrame as e:
        # The backend compiler has requested that we skip the frame, instead of
        # aborting execution.
        raise e
    except Exception as e:
        # Anything else is reported as a backend failure, keeping the
        # original traceback
        raise BackendCompilerFailed(
            self.compiler_fn, e, inspect.currentframe()
        ).with_traceback(e.__traceback__) from None

    signpost_event(
        "dynamo",
        "OutputGraph.call_user_compiler",
        {
            **self.co_fields,
            "op_count": tot,
            "node_count": len(gm.graph.nodes),
            "input_count": len(placeholders),
        },
    )

    return compiled_fn
- def dedup_pass(self) -> dict[str, torch.fx.GraphModule]:
- if torch._dynamo.config.use_graph_deduplication:
- return apply_graph_deduplication(self)
- else:
- return {}
- def install_subgraph(self, name: str, sub_gm: torch.fx.GraphModule) -> str:
- next_name = get_unique_name_wrt(name, self.nn_modules, requires_suffix=True)
- sub_gm.__name__ = next_name # type: ignore[assignment]
- sub_gm.torchdynamo_force_dynamic = False # type: ignore[assignment]
- # This graph module is not present in the user space, so it can't be
- # accessed by a source. Set source=None.
- self.register_attr_or_module(sub_gm, next_name, source=None)
- return next_name
def example_inputs(self) -> list[torch.Tensor]:
    """Return the ``example`` of every grapharg, in grapharg order."""
    return [grapharg.example for grapharg in self.graphargs]
def remove_unused_get_attr_nodes(self) -> None:
    """Remove every ``get_attr`` node that has no users, in descending sort order."""
    get_attr_nodes = sorted(self.graph.find_nodes(op="get_attr"), reverse=True)
    for candidate in get_attr_nodes:
        if len(list(candidate.users)) != 0:
            continue
        self.remove_node(candidate)
def remove_unused_graphargs(self) -> None:
    """
    Prune trivially dead nodes and unused graph inputs from ``self.graph``.

    Three passes, in order:
      1. a small DCE sweep (reverse node order) over user-less nodes of a
         few obviously removable kinds;
      2. removal of user-less placeholders that do not bind a symbol,
         while collecting the free symbols of the inputs that stay;
      3. removal of user-less symbol-binding placeholders whose symbol is
         not among the collected ones.
    """
    # NB: It's OK to drop GraphArg for symbols that ended up being
    # specialized iff they are not used in runtime assertions.  You don't
    # even have to make a guard for it, because ShapeEnv produce_guards
    # operates on tracked_fakes, which never gets pruned.
    # That being said, you'll get marginally better generated
    # guard code if you promote the guard into a Dynamo guard (since that
    # allows for the guard to be done using C++ guards.)  If we get
    # ShapeEnv guards to go into C++ guards, this will stop being a thing
    # though!

    assert self.should_exit

    # Miniature DCE pass, but only for obviously trivial operations
    def is_static_true(b_node: fx.node.Argument) -> bool:
        # True only when the argument is literally True, or a node whose
        # example value is True / a SymBool already known to be true.
        if b_node is True:
            return True
        if not isinstance(b_node, fx.Node):
            return False
        b = b_node.meta.get("example_value")
        if b is None:
            return False
        if b is True:
            return True
        if (
            isinstance(b, torch.SymBool)
            and (r := b.node.maybe_as_bool()) is not None
        ):
            return r
        # TODO: We can also technically remove all cases when the input
        # doesn't have unbacked inputs, since it's all in the ShapeEnv
        return False

    def is_symnode_arg(a: fx.node.Argument) -> bool:
        # Plain int/float/bool, or a node whose example value is a SymType
        from torch.fx.experimental.sym_node import SymTypes

        if isinstance(a, (int, float, bool)):
            return True
        if isinstance(a, fx.Node):
            return isinstance(a.meta.get("example_value"), SymTypes)
        return False

    # NB: We assume that you cannot do mutations on int/float/bool,
    # because they are immutable types, and therefore is always safe to
    # DCE.
    def is_symnode_compute_node(node: fx.Node) -> bool:
        from torch.fx.experimental.sym_node import SymTypes

        if node.op != "call_function":
            return False
        # TODO: I don't think it's possible to have a bare int/float here?
        if not isinstance(node.meta.get("example_value"), SymTypes):
            return False
        # TODO: This will bail here if you ever end up with a more complicated
        # computation function, like sum(list_of_ints), even though it
        # should be DCE'able
        if not all(is_symnode_arg(a) for a in node.args):
            return False
        if not all(is_symnode_arg(a) for a in node.kwargs.values()):
            return False
        return True

    from torch.fx.experimental.symbolic_shapes import is_accessor_node

    # Pass 1: walk in reverse so a removed node's inputs can become
    # user-less before they are visited
    for node in reversed(list(self.graph.nodes)):
        if len(list(node.users)) == 0:
            if (
                node.op == "get_attr"
                or (node.op == "call_function" and node.target is operator.getitem)
                or (
                    node.op == "call_function"
                    and node.target is torch._check
                    and is_static_true(node.args[0])
                )
                or is_symnode_compute_node(node)
                or is_accessor_node(node)
            ):
                self.remove_node(node)

    def placeholder_binds_symbol(node: fx.Node) -> Optional[sympy.Symbol]:
        # The symbol when the grapharg example is a SymInt whose expression
        # is a bare symbol, else None
        arg = node.meta["grapharg"]
        example = arg.example
        if isinstance(example, torch.SymInt) and isinstance(
            example.node.expr, sympy.Symbol
        ):
            return example.node.expr
        return None

    def remove_unused(node: fx.Node) -> None:
        log.debug("REMOVE UNUSED GRAPHARG %s", node.meta["grapharg"].source.name())
        # I'm not really sure why you need to delete these from the
        # node since the node is going to get removed
        del node.meta["grapharg"]
        self.remove_node(node)
        self.real_value_cache.pop(node, None)

    # Symbols referenced by the graph inputs that are kept
    used_symbols: set[sympy.Symbol] = set()

    def update_used_symbols(
        used_symbols: set[sympy.Symbol], fake: Union[torch.SymInt, torch.Tensor]
    ) -> None:
        # In-place union: mutates the set passed in
        used_symbols |= free_symbols(fake)

    # Pass 2: user-less symbol-binding placeholders are only queued here
    recheck_placeholders = []
    for node in self.placeholders:
        binds_symbol = placeholder_binds_symbol(node) is not None
        # Don't delete symbol bindings yet
        if binds_symbol:
            if not node.users:
                recheck_placeholders.append(node)
        else:
            if not node.users and not isinstance(
                node.meta["grapharg"], BackwardStateGraphArg
            ):
                remove_unused(node)
            else:
                # Register the free symbols as uses
                arg = node.meta["grapharg"]
                if isinstance(arg, BackwardStateGraphArg):
                    continue
                if isinstance(node.meta["grapharg"].example, torch.ScriptObject):
                    real_script_obj = node.meta["grapharg"].example
                    fake_script_obj = node.meta["grapharg"].example_strong_ref
                    if not torch._library.fake_class_registry.tracing_with_real(
                        real_script_obj
                    ):
                        # Collect symbols from every flattened attribute of
                        # the fake script object
                        flat_dict = dict(real_script_obj.__obj_flatten__())  # type: ignore[attr-defined]
                        for attr in flat_dict.keys():
                            fake_attr_val = getattr(
                                fake_script_obj.wrapped_obj, attr
                            )
                            pytree.tree_map_only(
                                (torch.SymInt, torch.Tensor),
                                lambda t: update_used_symbols(used_symbols, t),
                                fake_attr_val,
                            )
                    continue
                fake = (
                    arg.fake_tensor if arg.fake_tensor is not None else arg.example
                )
                update_used_symbols(used_symbols, fake)

    # After removing unused graphargs, prune unused binds_symbol
    for node in recheck_placeholders:
        symbol = placeholder_binds_symbol(node)
        if symbol is not None:
            if symbol not in used_symbols:
                remove_unused(node)
            else:
                # Make sure we delete later occurrences of the same symbol
                used_symbols.remove(symbol)
- def remove_tensorify_specialized_graphargs(self) -> None:
- # This is a pretty interesting function. Basically we have this problem
- # where our compiler tends to choke when we have unused inputs. The way
- # we support dynamic float arguments is by doing a joint fx pass and
- # tensorifying away as many symfloats as we can. For the remaining symfloats
- # we have no choice but to specialize... HOWEVER at that point in time
- # we can no longer remove graph inputs. So our sledgehammer solution is to
- # save the state of what inputs we should have specialized in dynamo and
- # restart analysis. This function incorporates this "view from the future"
- # state and specializes inputs that we know we won't be able to tensorify
- # away in the joint pass. In principle we shouldn't choke on unused inputs
- # and so this shouldn't be necessary. In practice CUDA graphs choke on
- # unused inputs so we need this for now.
- # Import here to prevent circular import
- from torch._dynamo.symbolic_convert import TensorifyState
- for node in self.graph.nodes:
- example_value = node.meta.get("example_value")
- if (
- isinstance(example_value, FakeTensor)
- and example_value.item_memo is not None
- and hasattr(example_value.item_memo.node._expr, "name")
- and all(u.target == "item" for u in node.users)
- and TensorifyState.should_specialize(
- # We use _expr instead of expr b/c we want the symbol not the replacement
- example_value.item_memo.node._expr.name
- )
- ):
- for u in list(node.users):
- u.replace_all_uses_with(guard_scalar(example_value.item_memo))
- self.remove_node(u)
- self.remove_node(node)
def add_output_instructions(self, prefix: list[Instruction]) -> None:
    """
    Append ``prefix`` to ``self.output_instructions`` and set
    ``self.should_exit``.

    We call this on the creation of a new compiled subgraph that is inserted
    before user code.
    """
    pending = self.output_instructions
    pending.extend(prefix)
    self.should_exit = True
def install_global_unsafe(self, name: str, value: Any) -> None:
    """
    Install ``value`` under ``name`` in the global scope and register the
    matching cleanup hook. ``name`` must not have been installed before.

    WARNING: prefer the safer `install_global_by_id/install_global`.
    torch.compile instances should be independent of each other;
    one footgun is to have one instance depend on the existence of
    a global installed by another instance. This can happen if we mangle
    a global the same way across both instances.
    """
    assert name not in self.installed_globals
    self.installed_globals.add(name)
    hook = CleanupHook.create(self.global_scope, name, value)
    self.cleanups.append(hook)
def install_global_by_id(self, prefix: str, value: Any) -> str:
    """
    Install ``value`` as a global keyed by the (prefix, id(value)) pair,
    unless a global with that key was already installed.

    Returns the global's name in both cases.
    """
    # NB: need self.compile_id to distinguish this global
    # from another global created in a different torch.compile instance
    global_name = f"{prefix}_{id(value)}_c{self.compile_id}"
    if global_name not in self.installed_globals:
        self.install_global_unsafe(global_name, value)
    return global_name
- def install_global(self, prefix: str, value: Any) -> str:
- """
- Installs a global, generating a unique name for it.
- Returns the name of the newly installed global.
- """
- # NB: unique_id is unique, even across torch.compile instances
- name = unique_id(prefix)
- self.install_global_unsafe(name, value)
- return name
def cleanup(self) -> None:
    """
    Drop references held by this object: reset ``root_tx`` and
    ``param_name_to_source``, strip "grapharg" from every node's meta and
    clear the caches and bookkeeping containers.
    """
    # There is a reference cycle between tracer and OutputGraph, causing
    # some of the tensor objects to be held alive for longer than necessary.
    self.root_tx = None  # type: ignore[assignment]
    self.nn_modules.clear()
    self.param_name_to_source = None

    for graph_node in self.graph.nodes:
        graph_node.meta.pop("grapharg", None)

    for holder in (
        self.real_value_cache,
        self.input_name_to_proxy,
        self.side_effects,
        self.variable_tracker_cache,
        self.register_finalizer_fns,
        self.dynamo_flat_name_to_original_fqn,
        self.tracing_context,
        self.input_source_to_var,
        self.unspec_variable_map,
        self.backward_state,
    ):
        holder.clear()
def add_graph_finalizer(
    self, register_finalizer: Callable[[fx.GraphModule], None]
) -> None:
    """Append ``register_finalizer`` to ``self.register_finalizer_fns``."""
    finalizer_fns = self.register_finalizer_fns
    finalizer_fns.append(register_finalizer)
def example_value_from_input_node(self, node: torch.fx.Node) -> Any:
    """
    Extract the non-fake example tensor for an input node.

    A "placeholder" node yields the ``example`` of its "grapharg" meta entry;
    any other node must be a "get_attr" node and is looked up in
    ``self.nn_modules`` by its target.
    """
    if node.op != "placeholder":
        assert node.op == "get_attr"
        return self.nn_modules[node.target]  # type: ignore[index]
    return node.meta["grapharg"].example
class DynamoTracerOutput:
    """
    Snapshot of selected tracer state taken when a trace ends.

    ``output_graph`` is None when an error was supplied, otherwise it is the
    tracer's ``output``.
    """

    error_on_graph_break: bool
    is_tracing_resume_prologue: bool
    output_graph: Optional[OutputGraph]

    def __init__(
        self, tracer: "InstructionTranslatorBase", error: Optional[Any] = None
    ) -> None:
        self.error_on_graph_break = tracer.error_on_graph_break
        self.is_tracing_resume_prologue = tracer.is_tracing_resume_prologue
        # Any truthy error means there is no usable graph to hand back.
        self.output_graph = None if error else tracer.output
# Suffix appended to the explanation of every "non-PT2-compliant op" graph
# break raised from check_pt2_compliant_op.
err_epilogue = (
    "With the current config, we will graph break "
    "(and fall back to eager-mode PyTorch) on all ops "
    "that do not have the 'pt2_compliant_tag'. "
    "Please see the following doc for how to mark this op as PT2 compliant "
    "https://pytorch.org/tutorials/advanced/custom_ops_landing_page.html"
)
def check_pt2_compliant_op(
    output_graph: OutputGraph, kind: str, target: Any, args: Any, kwargs: Any
) -> None:
    """
    Record whether the op about to be added to the graph carries the
    ``pt2_compliant_tag``, and graph break on non-compliant ops when
    ``config.only_allow_pt2_compliant_ops`` is set.

    Only "call_function" nodes whose target is an ``OpOverload`` or an
    ``OpOverloadPacket`` are inspected; every other input is ignored.

    Args:
        output_graph: Receives the op in ``compliant_custom_ops`` or
            ``non_compliant_ops``; its ``current_tx`` is used when a packet
            has to be resolved to a concrete overload.
        kind: The FX node kind being created.
        target: The call target of the node.
        args, kwargs: Node arguments, used only for overload resolution.
    """
    if kind != "call_function":
        return

    def has_compliant_tag(op: torch._ops.OpOverload) -> bool:
        return torch.Tag.pt2_compliant_tag in op.tags

    def record_compliant(op: torch._ops.OpOverload) -> None:
        # Ops from the built-in namespaces are not tracked as custom ops.
        if op.namespace not in {"prim", "prims", "aten"}:
            output_graph.compliant_custom_ops.add(op)

    def record_non_compliant(op: torch._ops.OpOverload, msg: str) -> None:
        output_graph.non_compliant_ops.add(op)
        if not config.only_allow_pt2_compliant_ops:
            return
        unimplemented_v2(
            gb_type="Encountered non-PT2-compliant op",
            context="",
            explanation=msg + " " + err_epilogue,
            hints=[],
        )

    if isinstance(target, torch._ops.OpOverload):
        if has_compliant_tag(target):
            record_compliant(target)
        else:
            record_non_compliant(
                target,
                f"Encountered the torch.ops.OpOverload {target} that is not PT2 compliant.",
            )
        return

    if not isinstance(target, torch._ops.OpOverloadPacket):
        return

    overload_names = tuple(target.overloads())
    # Optimization: Overload resolution is expensive.
    # If there's only one overload, we know what it will resolve to.
    if len(overload_names) == 1:
        sole_op = getattr(target, overload_names[0])
        if has_compliant_tag(sole_op):
            record_compliant(sole_op)
        else:
            record_non_compliant(
                sole_op,
                f"Encountered the non-overloaded "
                f"torch.ops.OpOverloadPacket {target} "
                f"that is not PT2 compliant. ",
            )
        return

    # Several overloads: resolve against fake values of the node arguments.
    fake_args, fake_kwargs = torch._dynamo.utils.get_fake_values_from_nodes(
        output_graph.current_tx, (args, kwargs), False
    )
    try:
        overload = torch._C._jit_resolve_packet(
            target._qualified_op_name, *fake_args, **fake_kwargs
        )
    except RuntimeError as e:
        unimplemented_v2(
            gb_type="Error when attempting to resolve op packet",
            context="",
            explanation=str(e),
            hints=[],
        )

    resolved_op = getattr(target, overload)
    if has_compliant_tag(resolved_op):
        record_compliant(resolved_op)
        return
    record_non_compliant(
        resolved_op,
        f"Encountered the torch.ops.OpOverloadPacket {target} "
        f"which resolves to the overload ({overload}) that is "
        f"not PT2 compliant.",
    )
# Module-level counter; itertools.count() yields 0, 1, 2, ... on successive
# next() calls. NOTE(review): presumably the source of compile ids - its
# consumers are not visible in this part of the file.
_compile_id_counter = itertools.count()

# Type parameters describing the callable (and its arguments) held by LazyProxy.
P = ParamSpec("P")
R = TypeVar("R")
class LazyProxy:
    """
    A deferred call: stores a function together with its arguments and only
    invokes it when the instance itself is called.

    ``tracer`` is stored alongside the deferred call so that the owning tracer
    can be read without running it.
    """

    def __init__(
        self,
        tracer: "SubgraphTracer",
        fn: Callable[P, R],
        *args: P.args,
        **kwargs: P.kwargs,
    ) -> None:
        self.tracer = tracer
        self.fn = fn
        self.args = args
        self.kwargs = kwargs

    def __call__(self) -> Any:
        # Every call re-invokes fn; nothing is cached here.
        deferred_fn = self.fn
        return deferred_fn(*self.args, **self.kwargs)
- class SubgraphTracer(fx.Tracer):
- """
- Holds an FX graph that is being traced. OutputGraph owns a SubgraphTracer
- and the separation of responsibilities is that SubgraphTracer is
- responsible for building the graph while OutputGraph is responsible for
- compiling and executing the graph.
- """
def __init__(
    self,
    output_graph: "OutputGraph",
    parent: Optional["SubgraphTracer"] = None,
    is_export: bool = False,
    source_target: Optional[Target] = None,
) -> None:
    """
    Create an empty FX graph plus the bookkeeping used while tracing into it.

    Args:
        output_graph: The owning OutputGraph; only a weak proxy to it is kept.
        parent: The enclosing tracer when this tracer is nested, or None for
            the root tracer.
        is_export: Stored as ``self.is_export``.
            See note [Export inputs must be explicitly passed in].
        source_target: The operator this subgraph is attached to; used to
            extend the parent's ``source_fn_stack``.

    Raises:
        RuntimeError: If inference mode is enabled when the tracer is built.
    """
    super().__init__()
    # Held through weakref.proxy, so this tracer does not keep a strong
    # reference to the OutputGraph.
    self.output_graph = weakref.proxy(output_graph)
    self.graph = torch.fx.Graph()

    # See note [Export inputs must be explicitly passed in]
    self.is_export = is_export
    # Map from graph input name to its placeholder proxy object, where the
    # map's keys give all current placeholder node names and can be used to
    # create unique node names
    self.input_name_to_proxy: dict[str, fx.Proxy] = {}
    # Node => computed real value (see utils.get_real_value)
    self.real_value_cache: dict[fx.Node, torch.Tensor] = {}

    # SubgraphTracers can be nested. See NOTE [HigherOrderOperator tracing design]
    self.parent = parent
    self.source_target = source_target
    # A dict mapping previously free variables (Proxy objects)
    # to new Proxy objects that wrap inputs to this subgraph.
    #
    # This dict maps proxies in outer graphs to placeholders in current graph.
    # It serves two purposes:
    # - Proxies are associated with VariableTrackers. If we see
    # the same VariableTracker twice (and it is a free variable),
    # then we want to use the same Proxy in the current subgraph to
    # record the tracing.
    # - If we are tracing a HigherOrderOperator's body_fn, then we
    # need to keep track of what free variables were lifted so we can
    # rewrite the HigherOrderOperator call using the traced body_fn.
    # Dicts maintain the order of args for the HigherOrderOperator call.
    self.lifted_freevars: dict[fx.Proxy, fx.Proxy] = {}

    # Map basic symbols (backed and unbacked) to their bound proxies.
    # There are only two cases where bound_symbols will be recorded:
    # 1. when we create_graph_input for a backed SymInt that's a basic symbol
    # 2. when we track_produced_symints for intermediate results
    # bound_symbols always maps the symbol to the proxy whose tracer is the
    # current tracer, i.e. one that is readily accessible in this tracer's graph.
    self.bound_symbols: dict[sympy.Symbol, Union[torch.fx.Proxy, LazyProxy]] = {}

    # Last instruction for which create_proxy emitted a trace-call log line.
    self.prev_inst = None
    # True if this tracer is currently tracing into torch.utils.checkpoint
    # as part of speculate_subgraph.
    self.under_activation_checkpoint = False
    # True if we want to allow externally visible side-effects (doesn't throw error on their existence)
    # during this tracer's tracing of torch.utils.checkpoint (via speculate_subgraph).
    # Only safe if we know for sure that *NOT* replaying these side-effects during
    # backward recomputation of the checkpoint region doesn't affect its correctness.
    self.allow_side_effects_under_checkpoint = False
    # True if we want to allow externally visible side-effects (doesn't throw error on their existence)
    # during this tracer's tracing. This is currently only used by experimental AC out-of-tree
    # via torch._dynamo.utils._disable_side_effect_safety_checks_for_current_subtracer.
    # Note: Externally visible side-effects are allowed if this flag OR the above flag is True.
    self.unsafe_allow_externally_visible_side_effects = False
    # True if this tracer is currently tracing (reconstructing) into a Python generator
    self.is_reconstructing_generator = False

    # Nesting depth: 0 for the root tracer, parent's depth + 1 otherwise.
    self.debug_level: int = parent.debug_level + 1 if parent is not None else 0

    # State used by create_proxy / _maybe_preserve_original_meta to carry node
    # metadata over from an original GraphModule.
    self._cur_code = None
    self._orig_gm_meta: Optional[list[Any]] = None
    self._orig_gm_lineno_map: Optional[dict[int, Optional[int]]] = None
    self._orig_gm_firstlineno: Optional[int] = None
    # Each SubgraphTracer is associated with a source target, which indicates
    # which operator this subgraph is attached to. We compute a source_fn_stack
    # based on the source target. For the root tracer, it's set to [].
    # This is useful for debugging and transforming the exported graph.
    if self.parent is None:
        self.source_fn_stack: list[Any] = []
    else:
        self.source_fn_stack = self.parent.source_fn_stack + [
            (self.graph._target_to_str(source_target), source_target)
        ]

    # This is used to create a unique name for the placeholder
    self._used_names: OrderedSet[str] = OrderedSet()

    # Stores the versions of the input tensors at the time they are inserted
    # as placeholders in the graph. This is used to track input mutation.
    self._input_versions_at_beginning: list[int] = []
    if torch.is_inference_mode_enabled():
        raise RuntimeError(
            "Inference mode is supposed to be disabled during compilation. Please open an issue."
        )
- # preserve original meta if it is available
- def _maybe_preserve_original_meta(
- self, tx: "InstructionTranslatorBase", node: fx.Node
- ) -> None:
- if (
- self._orig_gm_meta
- and self._orig_gm_lineno_map
- and self._orig_gm_firstlineno
- ):
- lineno = tx.current_instruction.starts_line
- node_idx = None
- if lineno is not None:
- node_idx = self._orig_gm_lineno_map.get(
- lineno - self._orig_gm_firstlineno, None
- )
- if node_idx is not None:
- meta = self._orig_gm_meta[node_idx]
- for field in fx.proxy._COPY_META_FIELDS:
- if field in meta:
- node.meta[field] = meta[field]
- if "stack_trace" in meta:
- node.meta["stack_trace"] = meta["stack_trace"]
def create_proxy(
    self,
    kind: str,
    target: Any,
    args: Any,
    kwargs: Any,
    name: Optional[str] = None,
    type_expr: Optional[Any] = None,
    proxy_factory_fn: Optional[Callable[[fx.Node], fx.Proxy]] = None,
) -> fx.Proxy:
    """
    Create a proxy through ``fx.Tracer.create_proxy`` and annotate its node.

    When this tracer is nested, proxies among ``args``/``kwargs`` that belong
    to another tracer are first lifted to inputs of this subgraph. After the
    node is created, this method emits a trace-call log line (Python 3.11+),
    fills in ``nn_module_stack`` / ``source_fn_stack`` metadata, copies
    metadata from an original GraphModule if one is being retraced, records a
    stack trace, and optionally registers the node for graph deduplication.

    Args:
        kind, target, args, kwargs, name, type_expr, proxy_factory_fn:
            Forwarded to ``fx.Tracer.create_proxy`` (``args``/``kwargs`` after
            free-variable lifting).

    Returns:
        The proxy produced by the superclass.
    """
    # NOTE: [Nested SubgraphTracer and free_variable handling]
    # --------------------------------------------------------
    # Read NOTE [HigherOrderOperator tracing design] first.
    #
    # Let's say we're in the middle of introspecting the body of a possibly
    # nested HigherOrderOperator, and we see a free variable.
    #
    # There are two cases:
    # 1. We see a free variable that is already tracked by Dynamo.
    # 2. We see a free variable that has not been tracked by Dynamo
    #
    # In case 1, we call `maybe_lift_tracked_freevar_to_input` (below)
    # which will lift the freevar to be an input of this subgraph
    # and also recursively lift it to be an input on the parent(s).
    #
    # In case 2, before the call to `create_proxy`, the InstructionTranslator
    # will see the freevar when it gets loaded by Python bytecode.
    # E.g. for Python 3.11 the bytecodes that may do this are LOAD_DEREF or
    # LOAD_GLOBAL.
    # There, the InstructionTranslator asks Dynamo to begin tracking the
    # freevar by building a new Variable.
    # Building a new Variable automatically lifts the freevar to be an
    # input of the root SubgraphTracer.
    #
    # The implications for the code below are:
    # - We will always be in Case 1 when we get to this code.
    # - Any "free variable" we encounter here is guaranteed to already be
    # bound, that is, it is either a graph input of the root graph, or
    # some local variable of the root graph or a subgraph.
    # - The additional work we need to do here is *only* that we need to
    # lift this free variable into inputs (recursively) of each nested
    # higher-order-op subgraph until we hit the subgraph where the free
    # variable is bound
    if self.parent is not None:
        flat_args, tree_spec = pytree.tree_flatten((args, kwargs))
        new_flat_args = []
        for arg in flat_args:
            maybe_new_arg = self.maybe_lift_tracked_freevar_to_input(arg)
            new_flat_args.append(maybe_new_arg)

        args, kwargs = pytree.tree_unflatten(new_flat_args, tree_spec)

    rv = super().create_proxy(
        kind,
        target,
        args,
        kwargs,
        name,
        type_expr,
        proxy_factory_fn,  # type: ignore[arg-type]
    )

    # append stack trace to fx node
    tx = self.output_graph.current_tx

    # log detailed location of line of code in 3.11
    if sys.version_info >= (3, 11) and kind in (
        "call_function",
        "call_method",
        "call_module",
    ):
        cur_inst = tx.current_instruction
        # Log at most once per instruction: skip if this instruction was the
        # last one logged, or if it carries no line information.
        if (
            cur_inst is not self.prev_inst
            and cur_inst.positions is not None
            and cur_inst.positions.lineno is not None
        ):
            tx_code = tx.f_code
            header = tx.get_line_of_code_header(lineno=cur_inst.positions.lineno)

            # Wrapped in LazyString below so the source line is only fetched
            # if the log record is actually formatted.
            def get_trace_call_log_str() -> str:
                line = get_instruction_source_311(tx_code, cur_inst).rstrip()
                return f"TRACE FX call {rv.node.name} from {header}\n{line}"

            trace_call_log.debug("%s", LazyString(get_trace_call_log_str))
            self.prev_inst = cur_inst

    # update reference to original meta if we're tracing a new code object
    # NOTE(review): self._cur_code is not assigned anywhere in this method, so
    # this comparison appears to hold on every call unless it is set
    # elsewhere - confirm.
    is_retracing = False
    if tx.f_code is not self._cur_code:
        orig_graphmodule_maybe = code_context.get_context(tx.f_code).get(
            "orig_graphmodule", lambda: None
        )()
        if isinstance(orig_graphmodule_maybe, torch.fx.GraphModule):
            is_retracing = True
            self._orig_gm_meta = [
                nd.meta for nd in orig_graphmodule_maybe.graph.nodes
            ]
            self._orig_gm_lineno_map = orig_graphmodule_maybe._lineno_map
            self._orig_gm_firstlineno = (
                orig_graphmodule_maybe.forward.__code__.co_firstlineno
            )
        else:
            self._orig_gm_meta = None
            self._orig_gm_lineno_map = None
            self._orig_gm_firstlineno = None
    nn_module_stack = tx.nn_module_stack
    if nn_module_stack:
        # Copy so later changes to the translator's stack do not leak into
        # this node's metadata.
        rv.node.meta["nn_module_stack"] = nn_module_stack.copy()

    if kind in {"call_function", "call_method"}:
        rv.node.meta["source_fn_stack"] = self.source_fn_stack + [
            (rv.node.name, target)
        ]
    elif kind == "call_module":
        if self.parent is not None:
            # TODO can remove once inline_inbuilt_nn_modules is always True
            unimplemented_v2(
                gb_type="Invoking an nn.Module inside a higher order operator",
                context=f"Higher order op name: {self.source_target}",
                explanation="This is not supported.",
                hints=[],
            )
        # For modules we store the class
        rv.node.meta["source_fn_stack"] = self.source_fn_stack + [
            (
                rv.node.name,
                # First nn_module_stack entry whose key, up to any "@"
                # suffix, equals the module target.
                next(
                    ty
                    for k, (_, ty) in rv.node.meta["nn_module_stack"].items()
                    if k.split("@")[0] == target
                ),
            )
        ]

    self._maybe_preserve_original_meta(tx, rv.node)

    # NOTE(review): the block above already sets both keys whenever the
    # conditions below could set them, so this fallback looks redundant -
    # confirm before relying on it.
    if not is_retracing:
        if "nn_module_stack" not in rv.node.meta:
            nn_module_stack = tx.nn_module_stack
            if nn_module_stack:
                rv.node.meta["nn_module_stack"] = nn_module_stack.copy()

        if "source_fn_stack" not in rv.node.meta:
            if kind in {"call_function", "call_method"}:
                rv.node.meta["source_fn_stack"] = self.source_fn_stack + [
                    (rv.node.name, target)
                ]
            elif kind == "call_module":
                if self.parent is not None:
                    # TODO can remove once inline_inbuilt_nn_modules is always True
                    unimplemented_v2(
                        gb_type="Invoking an nn.Module inside a HigherOrderOperator",
                        context="",
                        explanation="This is not supported.",
                        hints=[],
                    )
                # For modules we store the class
                rv.node.meta["source_fn_stack"] = self.source_fn_stack + [
                    (
                        rv.node.name,
                        rv.node.meta["nn_module_stack"][target][1],
                    )
                ]

    if "stack_trace" not in rv.node.meta:
        frame_summaries: list[traceback.FrameSummary] = []
        # Walk from the current translator outwards through its parents.
        while tx:
            # Avoid frame summaries from inside the torch/nn/modules. This ensures that we keep the stack trace of
            # the user code.
            if not tx.is_co_filename_from_nn_modules():
                frame_summaries.append(tx.frame_summary())
            tx = getattr(tx, "parent", None)
        # Reverse the frame_summaries, such that the innermost frame is at the last
        frame_summaries.reverse()

        # official from_list stub doesn't have new-style type
        msgs = traceback.StackSummary.from_list(frame_summaries).format()
        rv.node.stack_trace = "".join(msgs)

    if (
        torch._dynamo.config.use_graph_deduplication
        or torch._dynamo.config.track_nodes_for_deduplication
    ):
        # `tx` may have been consumed by the loop above, so the current
        # translator is re-read from the output graph here.
        self.output_graph.region_tracker.track_node(
            self.output_graph.current_tx, rv.node
        )
    return rv
def create_node(
    self,
    op: str,
    target: Target,
    args: Any = None,
    kwargs: Any = None,
    name: Optional[str] = None,
    type_expr: Optional[Any] = None,
) -> fx.Node:
    """
    Create an FX node, after checking PT2 compliance of the target and, for
    nested tracers, that every node argument belongs to this tracer's graph.

    Args:
        op: The FX node kind.
        target: The node target.
        args: Positional node arguments; ``None`` is treated as ``()``.
        kwargs: Keyword node arguments; ``None`` is treated as ``{}``.
        name: Optional node name, forwarded to the superclass.
        type_expr: Optional type annotation, forwarded to the superclass.

    Returns:
        The new node, stamped with ``meta["creation_timestamp"]`` and with its
        name recorded in ``self._used_names``.
    """
    # The signature allows args/kwargs to be omitted, but the nested-tracer
    # check below unpacks them with * / **, which raises TypeError on None.
    # Normalise once so every consumer sees a real tuple/dict.
    if args is None:
        args = ()
    if kwargs is None:
        kwargs = {}

    check_pt2_compliant_op(self.output_graph, op, target, args, kwargs)
    if self.parent is not None:
        # A subgraph may only reference its own nodes; nodes from outer graphs
        # must have been lifted to inputs before reaching this point.
        for arg in pytree.arg_tree_leaves(*args, **kwargs):
            if not isinstance(arg, torch.fx.Node):
                continue
            assert arg.graph == self.graph, (
                "create_node using arg not from this SubgraphTracer"
            )

    node = super().create_node(op, target, args, kwargs, name, type_expr)
    node.meta["creation_timestamp"] = self.output_graph.timestamp
    self._used_names.add(node.name)
    return node
# Note: erase_node is deliberately not overridden, because
# self.graph.erase_node is called directly elsewhere.
def remove_node(self, node: fx.Node) -> None:
    """
    Erase ``node`` from this tracer's graph and forget it as a graph input.

    If ``node`` still has users living in other graphs, every node of those
    graphs is erased first (in reverse order).
    """
    if node.users:
        foreign_nodes: list[torch.fx.Node] = []
        for user in node.users:
            # A user in this very graph is a real bug; it is left in place so
            # that erase_node below raises properly.
            if user.graph == self.graph:
                continue
            # The user lives in a nested graph, which must be deleted or the
            # erase below would raise. This only happens during restoration
            # cleanup, so dropping the nested graph is sound.
            foreign_nodes.extend(reversed(list(user.graph.nodes)))
        for foreign in foreign_nodes:
            foreign.graph.erase_node(foreign)
    self.graph.erase_node(node)
    self.input_name_to_proxy.pop(node.name, None)
# when before=True, we will insert this input before the most recent
# inserted proxy. This is a hack to get around an ordering problem,
# where we first insert a tensor argument, and then insert bindings
# for SymInts that may occur in the tensor argument.
# Remove this if https://github.com/pytorch/pytorch/issues/99007 gets
# fixed.
def create_graph_input(
    self,
    name: str,
    type_expr: Any,
    example_value: Any,
    before: bool = False,
    source: Optional[Source] = None,
) -> fx.Proxy:
    """
    Add a placeholder to this tracer's graph and return its proxy.

    Args:
        name: Requested placeholder name; made unique against
            ``self._used_names`` before use.
        type_expr: Type annotation forwarded to ``create_proxy``.
        example_value: Example value attached to the new node. Tensors (and
            tensors inside a list/tuple) may additionally have the basic
            symbols of their shapes lifted as inputs.
        before: Insert before, rather than after, the most recently
            registered input.
        source: Where the input comes from. Required on the root tracer.

    Returns:
        The proxy of the new placeholder node.
    """
    if isinstance(example_value, torch.Tensor):
        self._input_versions_at_beginning.append(example_value._version)
    log.debug(
        "create_graph_input %s %s %s at debug_level %s before=%s",
        name,
        source.name() if source is not None else "(none)",
        example_value,
        self.debug_level,
        before,
    )
    if source is None:
        assert self.parent is not None, (
            f"you are required to provide a source for inputs {name} example_val {example_value} on the root tracer"
        )

    # Note [Export inputs must be explicitly passed in]
    # In eager, we are generally OK with adding graph inputs whenever we
    # want, because we take care of writing the bytecode that knows how
    # to source all the inputs.
    #
    # In export, this is bad, because you want a self-contained export
    # object which only depends on the inputs you explicitly passed to it.
    # So we are a bit more strict about what sources can become inputs
    # in export
    if self.is_export and self.parent is None:
        assert source is not None
        if not is_from_local_source(source, only_allow_input=True):
            # Record the current stack against the offending source.
            self.output_graph.source_to_user_stacks.setdefault(source, []).append(
                TracingContext.extract_stack()
            )

    # _used_names contains the names of all the nodes in the graph,
    # including intermediates. This ensures that we do not have a name
    # collision.
    name = get_unique_name_wrt(name, self._used_names)
    if self.input_name_to_proxy:
        # Position relative to the most recently registered input.
        prev_name = next(reversed(self.input_name_to_proxy))
        node = self.input_name_to_proxy[prev_name].node
        if before:
            ctx = self.graph.inserting_before(node)
        else:
            ctx = self.graph.inserting_after(node)
    else:
        ctx = self.graph.inserting_before(None)
    with ctx:
        proxy = self.create_proxy("placeholder", name, (), {}, type_expr=type_expr)
        set_example_value(proxy.node, example_value)
        if self.input_name_to_proxy and before:
            # Keep the dict order in sync with the graph order: the new input
            # goes just ahead of the previously last entry.
            k, v = self.input_name_to_proxy.popitem()
            self.input_name_to_proxy[name] = proxy
            self.input_name_to_proxy[k] = v
        else:
            self.input_name_to_proxy[name] = proxy

        # For placeholder nodes, `name` is passed as a str to the target,
        # and then torch.fx decides the node.name. So, record the `target`
        # name as well in the _used_names to prevent any collision.
        self._used_names.add(name)

        # NOTE: [Auto lift basic free symbols when create_graph_input]
        # There are two sources of basic symbols:
        #
        # - They can come from inputs, e.g. when an input tensor is specified as dynamic. We handle
        # this case by intercepting at create_graph_input. Whenever we call create_graph_input, we
        # try to also lift the basic symbols in example values as graph input.
        #
        # 1. When create_graph_input for a tensor that has symbolic shapes,
        # we look for basic symbols in its size and stride, we check if the symbol is bound
        # in current graph (i.e. bound_symbols). If it's not bound, we'll create a placeholder
        # for it, then recursively check its parent, creating a placeholder wherever it is not
        # bound, until reaching the top-level, where we require a source to be attached to the proxy.
        #
        # 2. When create_graph_input for a tensor that contains compound exprs,
        # for example, if an input to subgraph takes size [s1+s2//8], we'll look for
        # the free basic symbols in the sizes and lift all of them following 1.
        #
        # 3. When create_graph_input for a symint. The following invariants hold:
        # a. if symint's expr is a basic symbol, we only lift it once.
        # b. if symint's expr is compound, we lift the expr as a single input. The basic symbols
        # in the compound expr are NOT lifted, because if the basic symbols are used inside the
        # subgraph they will be lifted according to 3.a
        #
        # - They can come from intermediate results:
        # For example, data-dependent operators such as t.item(), t.nonzero(), where basic symbols
        # might be created. For this purpose, we track the basic symbols of intermediate results
        # immediately after they're created at wrap_fx_proxy with track_produced_symints. Notice
        # that for basic symbols that're already tracked by create_graph_input, we won't track it again.
        #
        # Also see NOTE: [Export inputs must be explicitly passed in]
        is_strict_export = self.is_export
        is_non_strict_export = torch.compiler.is_compiling()
        if not is_strict_export and not is_non_strict_export:
            if isinstance(example_value, torch.Tensor):
                self._lift_basic_symbols(example_value, source)
            elif isinstance(example_value, (list, tuple)):
                for i, e in enumerate(example_value):
                    if not isinstance(e, torch.Tensor):
                        continue

                    # Element sources are only derivable when the container
                    # itself has a source.
                    e_source = None
                    if source:
                        e_source = GetItemSource(
                            base=source, index=i, index_is_slice=False
                        )

                    self._lift_basic_symbols(e, e_source)

        # Bind the symbol to the placeholder if example_value is a SymInt with a basic symbol.
        if isinstance(example_value, torch.SymInt) and isinstance(
            example_value.node.expr, sympy.Symbol
        ):
            self.bound_symbols[example_value.node.expr] = proxy
        return proxy
# See NOTE: [Nested SubgraphTracer and free_variable handling] for more details
def lift_tracked_freevar_to_input(
    self, proxy: fx.Proxy
) -> Union[LazyProxy, fx.Proxy]:
    """
    Turn a proxy owned by an outer tracer into an input of this subgraph.

    Returns the proxy already bound to the same basic symbol, or the
    placeholder already created for ``proxy``, when either exists; otherwise
    lifts ``proxy`` through the parent chain and creates a new graph input.
    Must not be called on the root tracer.
    """
    # Dynamo adds tensors to graph inputs before creating a proxy for them, so
    # reaching this point on the root SubgraphTracer means a caller bug.
    assert self.parent is not None, (
        "lift_tracked_freevar_to_input should not be called on root SubgraphTracer"
    )

    outer_value = proxy.node.meta["example_value"]

    # Avoid lifting the same symbol twice. The basic symbols may already have
    # been lifted into this subgraph while automatically lifting the symbols in
    # the sizes/strides of a tensor t. If the parent graph computes
    # sz = t.size()[0] and the subgraph reads sz through a closure, sz's proxy
    # is not tracked by this sub-tracer, so only the symbol lookup catches it.
    if isinstance(outer_value, torch.SymInt):
        sym_expr = outer_value.node.expr
        if sym_expr in self.bound_symbols:
            return self.bound_symbols[sym_expr]

    # Proxies are associated with VariableTrackers, so the same proxy may be
    # seen again after it has already been lifted; reuse that placeholder.
    if proxy in self.lifted_freevars:
        return self.lifted_freevars[proxy]

    # Lift into the parent's graph before this one, so that when the symints
    # of the sizes are bound here they are already inputs of the parent graph.
    if proxy.tracer != self.parent:
        self.parent.lift_tracked_freevar_to_input(proxy)

    lifted_value = proxy.node.meta["example_value"]
    placeholder = self.create_graph_input(
        proxy.node.name, type(lifted_value), lifted_value
    )
    self.lifted_freevars[proxy] = placeholder
    return placeholder
def maybe_lift_tracked_freevar_to_input(self, arg: Any) -> Any:
    """
    Lift ``arg`` to an input of this subgraph if it is a free variable.

    Returns the lifted replacement when ``arg`` is a proxy owned by another
    tracer, a rebuilt slice when ``arg`` is a slice, and ``arg`` itself in
    every other case.
    """
    if isinstance(arg, torch.fx.Proxy):
        # A proxy that already belongs to this tracer is not a free variable.
        if arg.tracer == self:
            return arg
        return self.lift_tracked_freevar_to_input(arg)

    if isinstance(arg, slice):
        # Note: arg can be a python built-in slice type e.g.
        # x[:max_seq] is represented as get_item(t, (slice(None, max_seq, None)))
        # so the slice components must be inspected as well, to lift any
        # proxies they contain.
        bounds = [
            self.maybe_lift_tracked_freevar_to_input(part)
            for part in (arg.start, arg.stop, arg.step)
        ]
        return slice(*bounds)

    return arg
# See NOTE: [Auto lift basic free symbols when create_graph_input] for overall design
# You MUST call this API every time when creating a proxy in wrap_fx_proxy for a call
# that produced symints or tensors with unbacked symint shapes.
# This function is used to track the symints with its proxies created during
# dynamo tracing so that subgraph knows how to bind a symbol input with parent's proxy.
# LazyProxy are created for tensor shapes that're unbacked so that we don't create proxies
# for symbols that're not going to be used, the LazyProxy will be turned into a proxy
# when it's lifted as input to subgraph.
def track_produced_symints(
    self, example_value: Any, e_proxy: Union[LazyProxy, torch.fx.Proxy]
) -> None:
    """
    Bind the not-yet-bound basic symbols found in ``example_value`` to proxies.

    For a SymInt, the symbol is bound directly to ``e_proxy``. For a tensor,
    each size, the storage offset, each stride (strided layout), the
    component tensors of sparse layouts and the inner tensors of traceable
    wrapper subclasses are visited; a LazyProxy computing the corresponding
    ``sym_size`` / ``sym_storage_offset`` / ``sym_stride`` call is bound for
    every symbol that needs it. Other value types are ignored.

    Args:
        example_value: The value produced by the traced call.
        e_proxy: The proxy (or deferred proxy) that produced ``example_value``.
    """
    # When binding the symbols in an example_value, we bind the symbols
    # to the proxy's associated Tracer instead of current tracer.
    # This is because:
    # 1. We may be calling wrap_tensors during speculate_subgraph because
    # the variables are lazily realized. The proxy are top-level phs but
    # current tracer is a subtracer.
    # 2. For autograd.Function, we trace the backward graph with a new tracer
    # whose parent is the forward tracer, but we're using all the proxies created
    # in forward tracer to trace the backward.
    # For example, forward calls save_for_backward for a input tensor t.
    # Backward calls t.tolist(). In this case, all the proxies that backward tracer
    # sees are from parent tracer (i.e. the forward tracer). (e.g. t[0].item())
    # See test_validate_outputs_unbacked for repro on 2.
    tracer = e_proxy.tracer
    assert isinstance(tracer, SubgraphTracer)

    # NOTE(review): this checks self.bound_symbols, while the binding at the
    # end of this method writes tracer.bound_symbols; the two differ whenever
    # e_proxy belongs to another tracer - confirm this is intended.
    def need_bind(s: Any) -> bool:
        from torch.fx.experimental.symbolic_shapes import is_symbolic

        return (
            is_symbolic(s)
            and isinstance(s.node.expr, sympy.Symbol)
            and s.node.expr not in self.bound_symbols
        )

    def _proxy_with_example_value(
        example_value: Any, *args: Any, **kwargs: Any
    ) -> fx.Proxy:
        # We need to insert proxy for creating sym_size/sym_stride/sym_storage right after e_proxy
        # e_proxy is rebound in the enclosing scope, so a LazyProxy is
        # evaluated here at most once and then reused.
        nonlocal e_proxy
        e_proxy = e_proxy() if isinstance(e_proxy, LazyProxy) else e_proxy
        assert isinstance(e_proxy, torch.fx.Proxy)
        with tracer.graph.inserting_after(e_proxy.node):
            proxy = tracer.create_proxy(*args, **kwargs)
            set_example_value(proxy.node, example_value)
            return proxy

    if isinstance(example_value, torch.Tensor):
        for i, s in enumerate(example_value.size()):
            if need_bind(s):
                log.debug(
                    "track_produced_symints %s for %s.size()[%s] at debug_level %s",
                    s,
                    e_proxy,
                    i,
                    tracer.debug_level,
                )
                lazy_proxy = LazyProxy(
                    tracer,
                    _proxy_with_example_value,
                    s,
                    "call_function",
                    torch.ops.aten.sym_size.int,
                    (e_proxy, i),
                    {},
                    type_expr=type(s),
                )
                # Recurses into the SymInt branch below to perform the binding.
                self.track_produced_symints(s, lazy_proxy)

        storage_offset = example_value.storage_offset()
        if need_bind(storage_offset):
            log.debug(
                "track_produced_symints %s for %s.storage_offset() at debug_level %s",
                storage_offset,
                e_proxy,
                tracer.debug_level,
            )
            lazy_proxy = LazyProxy(
                tracer,
                _proxy_with_example_value,
                storage_offset,
                "call_function",
                torch.ops.aten.sym_storage_offset,
                (e_proxy,),
                {},
                type_expr=type(storage_offset),
            )
            self.track_produced_symints(storage_offset, lazy_proxy)

        if example_value.layout is torch.strided:
            for i, s in enumerate(example_value.stride()):
                if need_bind(s):
                    log.debug(
                        "track_produced_symints %s for %s.stride()[%s] at debug_level %s",
                        s,
                        e_proxy,
                        i,
                        tracer.debug_level,
                    )
                    lazy_proxy = LazyProxy(
                        tracer,
                        _proxy_with_example_value,
                        s,
                        "call_function",
                        torch.ops.aten.sym_stride.int,
                        (e_proxy, i),
                        {},
                        type_expr=type(s),
                    )
                    self.track_produced_symints(s, lazy_proxy)

        # Sparse layouts: visit the component tensors, all attributed to e_proxy.
        elif example_value.layout is torch.sparse_coo:
            self.track_produced_symints(example_value._indices(), e_proxy)
            self.track_produced_symints(example_value._values(), e_proxy)
        elif example_value.layout in {torch.sparse_csr, torch.sparse_bsr}:
            self.track_produced_symints(example_value.crow_indices(), e_proxy)
            self.track_produced_symints(example_value.col_indices(), e_proxy)
        elif example_value.layout in {torch.sparse_csc, torch.sparse_bsc}:
            self.track_produced_symints(example_value.ccol_indices(), e_proxy)
            self.track_produced_symints(example_value.row_indices(), e_proxy)
        if is_traceable_wrapper_subclass(example_value):
            # Inner tensors are paired with the same-named attribute of e_proxy.
            attrs, ctx = example_value.__tensor_flatten__()
            for attr in attrs:
                inner_t = getattr(example_value, attr)
                self.track_produced_symints(inner_t, getattr(e_proxy, attr))
    elif isinstance(example_value, torch.SymInt):
        if need_bind(example_value):
            expr = example_value.node.expr
            tracer.bound_symbols[expr] = e_proxy
# See Note [Auto lift basic free symbols when create_graph_input]
def _lift_basic_symbols(
    self, example_value: Union[torch.SymInt, torch.Tensor], src: Optional[Source]
) -> None:
    """
    Lift the unbound basic symbols of ``example_value`` as inputs of this graph.

    For a tensor, the symbols in its sizes, and for the strided layout its
    strides and storage offset, are lifted ahead of the tensor's own
    placeholder; sparse component tensors and the inner tensors of traceable
    wrapper subclasses are handled recursively. For a SymInt, its own symbols
    are lifted. Other value types are ignored.

    Args:
        example_value: The value whose symbols should be lifted.
        src: Source of ``example_value``; derived sources are built from it
            for each property, or None is passed through when it is None.
    """
    # The before arg is for inserting symints in the sizes/strides of a tensor
    # before the tensor. This ordering ensures that when we look at the tensor's
    # symbols, they're already lifted/tracked. E.g. this assumption is used
    # in insert_deferred_runtime_asserts.
    def _lift_symbols_in_symint(
        s: Union[int, torch.SymInt],
        source: Optional[Source],
        before: bool = False,
    ) -> None:
        if not is_symbolic(s):
            return

        assert isinstance(s, torch.SymInt)
        self_to_be_bound = self.lookup_unbound_symbols(s)
        if len(self_to_be_bound) == 0:
            return

        # For subgraph
        if self.parent is not None:
            # Recursively lift symbols in symint until top-level.
            self.parent._lift_basic_symbols(s, source)
            for s0 in self_to_be_bound:
                # After the recursive call above, the parent is expected to
                # have a binding for every symbol in self_to_be_bound.
                parent_proxy = self.parent.bound_symbols[s0]
                example_val = parent_proxy.node.meta["example_value"]  # type: ignore[union-attr]
                assert isinstance(example_val, torch.SymInt)
                ph = self.create_graph_input(
                    str(s0),
                    type(example_val),
                    example_val,
                    before=before,
                    source=source,
                )
                log.debug(
                    "_lift_symbols_in_symint %s from %s at debug_level %s",
                    s0,
                    source.name() if source is not None else "subgraph inputs",
                    self.debug_level,
                )
                self.lifted_freevars[parent_proxy] = ph  # type: ignore[index]
        # For root_tracer:
        else:
            assert len(self_to_be_bound) == 1, (
                f"For root tracer, we only expect to bind basic symbols (compound symbols "
                f"should be cached before) but got unbound symbols {self_to_be_bound} in {s}"
            )
            assert source is not None, (
                f"Source of '{s}' is None when lifting it to input of top-level. If it's an unbacked symbol, "
                "this could be because it's not tracked with lazy_bind_unbacked_symbols. "
                f"Otherwise, should provide a source when create_graph_input for `{s}` at root tracer."
            )
            s0 = next(iter(self_to_be_bound))
            ph = self.create_graph_input(
                str(s0),
                type(s),
                s,
                before=before,
                source=source,
            )
            log.debug(
                "_lift_symbols_in_symint %s from %s at debug_level %s",
                s,
                source.name() if source is not None else "subgraph inputs",
                self.debug_level,
            )
            # Root-level symbol placeholders also carry a GraphArg in their meta.
            ph.node.meta["grapharg"] = GraphArg(
                source,
                s,
                pass_arg_as_tensor=False,
                fake_tensor=None,
                is_tensor=False,
            )

    if isinstance(example_value, torch.Tensor):
        for i, s in enumerate(example_value.size()):
            _lift_symbols_in_symint(
                s,
                (
                    TensorPropertySource(src, TensorProperty.SIZE, i)
                    if src is not None
                    else None
                ),
                before=True,
            )
        if example_value.layout is torch.strided:
            for i, s in enumerate(example_value.stride()):
                _lift_symbols_in_symint(
                    s,
                    (
                        TensorPropertySource(src, TensorProperty.STRIDE, i)
                        if src is not None
                        else None
                    ),
                    before=True,
                )
            _lift_symbols_in_symint(
                example_value.storage_offset(),
                (
                    TensorPropertySource(src, TensorProperty.STORAGE_OFFSET)
                    if src is not None
                    else None
                ),
                before=True,
            )
        # Sparse layouts: recurse into the component tensors with the same source.
        elif example_value.layout is torch.sparse_coo:
            self._lift_basic_symbols(example_value._indices(), src)
            self._lift_basic_symbols(example_value._values(), src)
        elif example_value.layout in {torch.sparse_csr, torch.sparse_bsr}:
            self._lift_basic_symbols(example_value.crow_indices(), src)
            self._lift_basic_symbols(example_value.col_indices(), src)
        elif example_value.layout in {torch.sparse_csc, torch.sparse_bsc}:
            self._lift_basic_symbols(example_value.ccol_indices(), src)
            self._lift_basic_symbols(example_value.row_indices(), src)
        if is_traceable_wrapper_subclass(example_value):
            attrs, ctx = example_value.__tensor_flatten__()
            for attr in attrs:
                inner_t = getattr(example_value, attr)
                self._lift_basic_symbols(
                    inner_t, AttrSource(src, attr) if src is not None else None
                )
    elif isinstance(example_value, torch.SymInt):
        _lift_symbols_in_symint(
            example_value,
            src,
        )
- # Lookup the proxy in current tracer for each symbol in expressions of s,
- # See Note [Auto lift basic free symbols when create_graph_input]
def lookup_unbound_symbols(self, s: torch.SymInt) -> list[sympy.Symbol]:
    """Return the free symbols of ``s`` that have no entry in ``self.bound_symbols``.

    See Note [Auto lift basic free symbols when create_graph_input].

    Every free symbol of ``s.node.expr`` is checked against
    ``self.bound_symbols``. For a symbol that is already bound, a
    ``LazyProxy`` entry is materialized by calling it and the result is
    stored back into ``self.bound_symbols``; the resulting proxy is then
    asserted to be a ``torch.fx.Proxy`` whose ``tracer`` is this tracer.

    Args:
        s: The symbolic integer whose expression is inspected.

    Returns:
        The unbound symbols, sorted by their ``name``. Empty when the
        expression has no free symbols or all of them are already bound.
    """
    unbound = []
    for sym in s.node.expr.free_symbols:
        if sym in self.bound_symbols:
            bound = self.bound_symbols[sym]
            if isinstance(bound, LazyProxy):
                # Materialize the lazy entry once and cache the real proxy.
                bound = bound()
                self.bound_symbols[sym] = bound
            assert isinstance(bound, torch.fx.Proxy) and bound.tracer is self, (
                f"The proxy of symbol {sym} doesn't belong to current tracer."
            )
        else:
            unbound.append(sym)

    # Sort the symbols so that we can have a deterministic lifting order
    unbound.sort(key=lambda sym: sym.name)
    return unbound
def has_input_mutation(self) -> MutationInfo:
    """Report whether any tensor input's version counter has changed.

    Walks the leading ``placeholder`` nodes of ``self.graph`` and, for each
    whose ``meta["example_value"]`` is a ``torch.Tensor``, reads the tensor's
    current ``_version``. These are compared positionally against
    ``self._input_versions_at_beginning``; a differing pair marks that
    placeholder as mutated.

    Returns:
        ``MutationInfo(True, msg)`` naming the mutated placeholder nodes if
        any version differs, otherwise ``MutationInfo(False, "")``.
    """
    versions_before = self._input_versions_at_beginning

    tensor_inputs = []
    versions_after = []
    for node in self.graph.nodes:
        if node.op != "placeholder":
            # Only the run of placeholders at the front of the graph is scanned.
            break
        value = node.meta["example_value"]
        if isinstance(value, torch.Tensor):
            versions_after.append(value._version)
            tensor_inputs.append(node)

    # zip pairs entries positionally and stops at the shortest sequence.
    mutated_nodes = [
        node
        for node, before, after in zip(tensor_inputs, versions_before, versions_after)
        if before != after
    ]

    if not mutated_nodes:
        return MutationInfo(False, "")

    msg = f"Input mutation detected at {mutated_nodes}"
    return MutationInfo(True, msg)
def has_aliasing(self) -> AliasingInfo:
    """Report whether tensor inputs/outputs of this graph share storage.

    Three checks are made, in this order, returning on the first hit:

    1. input-to-input: two leading ``placeholder`` nodes whose fake tensor
       values map to the same ``StorageWeakRef`` key;
    2. output-to-output: two leaves of the output node's first argument
       that map to the same key;
    3. input-to-output: a key present among both inputs and outputs.

    Returns:
        ``AliasingInfo(True, msg)`` describing the aliased nodes, or
        ``AliasingInfo(False, "")`` when none of the checks fires.
    """
    from torch._higher_order_ops.utils import _collect_fake_inputs

    inputs_by_storage: dict[StorageWeakRef, torch.fx.Node] = {}
    for node in self.graph.nodes:
        if node.op != "placeholder":
            # Only the run of placeholders at the front of the graph is scanned.
            break
        fake = _collect_fake_inputs([node])[0]
        if not isinstance(fake, torch.Tensor):
            continue
        key = StorageWeakRef(fake._typed_storage())
        if key in inputs_by_storage:
            # input-input aliasing
            return AliasingInfo(
                True,
                f"Input-to-input aliasing detected at nodes {inputs_by_storage[key]} and {node}",
            )
        inputs_by_storage[key] = node

    outputs_by_storage: dict[StorageWeakRef, torch.fx.Node] = {}
    output_node = self.graph.find_nodes(op="output")[0]
    for leaf in pytree.tree_leaves(output_node.args[0]):
        if not leaf:
            # Falsy leaves are skipped.
            continue
        fake = _collect_fake_inputs([leaf])[0]
        assert not isinstance(fake, list)
        if not isinstance(fake, torch.Tensor):
            continue
        key = StorageWeakRef(fake._typed_storage())
        if key in outputs_by_storage:
            # output-output aliasing
            return AliasingInfo(
                True,
                f"Output-to-output aliasing detected at nodes {outputs_by_storage[key]} and {leaf}",
            )
        outputs_by_storage[key] = leaf

    shared_keys = inputs_by_storage.keys() & outputs_by_storage.keys()
    if shared_keys:
        # input-output aliasing
        pairs = ", ".join(
            f"{inputs_by_storage[key]} and {outputs_by_storage[key]}"
            for key in shared_keys
        )
        return AliasingInfo(
            True, f"Input-to-output aliasing detected at nodes {pairs}"
        )

    return AliasingInfo(False, "")
- # NOTE: [HigherOrderOperator tracing design]
- # Ignoring HigherOrderOperators for a moment,
- # OutputGraph represents the graph being built by Dynamo that may be compiled
- # and executed. It holds a root SubgraphTracer where the FX graph is built.
- #
- # HigherOrderOperators are operators that take functions as their arguments.
- # When Dynamo encounters a HigherOrderOperator, then it attempts to introspect
- # the function passed to it (call this the "body function"), capture it into a
- # GraphModule, and rewrite the call to the HigherOrderOperator to use the
- # GraphModule.
- #
- # The way we handle the capture of body functions is through having
- # (possibly nested) SubgraphTracers, one per body function.
- #
- # Mechanically, we do the introspection by:
- # - Creating a new SubgraphTracer via OutputGraph.subtracer
- # - Executing the body function.
- # This constructs the graph of the body function in the new SubgraphTracer
- # while modifying the state of the OutputGraph. For example:
- # - the OutputGraph can receive new GraphArgs (if we discover any new
- # untracked Tensors)
- # - side effects from the body function get accumulated into
- # OutputGraph.side_effects
- # - guards produced by the body function get accumulated into OutputGraph.guards
- #
- # The traced function has some special properties that make it easier for us
- # to transform later down the line:
- # - we lift all free variables to being inputs.
- #
- # If the introspection fails (due to the existence of graph breaks), then
- # we roll back the current OutputGraph state and graph break on the
- # HigherOrderOperator.
|