output_graph.py 148 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156115711581159116011611162116311641165116611671168116911701171117211731174117511761177117811791180118111821183118411851186118711881189119011911192119311941195119611971198119912001201120212031204120512061207120812091210121112121213121412151216121712181219122012211222122312241225122612271228122912301231123212331234123512361237123812391240124112421243124412451246124712481249125012511252125312541255125612571258125912601261126212631264126512661267126812691270127112721273127412751276127
71278127912801281128212831284128512861287128812891290129112921293129412951296129712981299130013011302130313041305130613071308130913101311131213131314131513161317131813191320132113221323132413251326132713281329133013311332133313341335133613371338133913401341134213431344134513461347134813491350135113521353135413551356135713581359136013611362136313641365136613671368136913701371137213731374137513761377137813791380138113821383138413851386138713881389139013911392139313941395139613971398139914001401140214031404140514061407140814091410141114121413141414151416141714181419142014211422142314241425142614271428142914301431143214331434143514361437143814391440144114421443144414451446144714481449145014511452145314541455145614571458145914601461146214631464146514661467146814691470147114721473147414751476147714781479148014811482148314841485148614871488148914901491149214931494149514961497149814991500150115021503150415051506150715081509151015111512151315141515151615171518151915201521152215231524152515261527152815291530153115321533153415351536153715381539154015411542154315441545154615471548154915501551155215531554155515561557155815591560156115621563156415651566156715681569157015711572157315741575157615771578157915801581158215831584158515861587158815891590159115921593159415951596159715981599160016011602160316041605160616071608160916101611161216131614161516161617161816191620162116221623162416251626162716281629163016311632163316341635163616371638163916401641164216431644164516461647164816491650165116521653165416551656165716581659166016611662166316641665166616671668166916701671167216731674167516761677167816791680168116821683168416851686168716881689169016911692169316941695169616971698169917001701170217031704170517061707170817091710171117121713171417151716171717181719172017211722172317241725172617271728172917301731173217331734173517361737173817391740174117421743174417451746174717481749175017511752175317541755175617571758175917601761176217631764176517661767176817691770177117721773177417751776177
71778177917801781178217831784178517861787178817891790179117921793179417951796179717981799180018011802180318041805180618071808180918101811181218131814181518161817181818191820182118221823182418251826182718281829183018311832183318341835183618371838183918401841184218431844184518461847184818491850185118521853185418551856185718581859186018611862186318641865186618671868186918701871187218731874187518761877187818791880188118821883188418851886188718881889189018911892189318941895189618971898189919001901190219031904190519061907190819091910191119121913191419151916191719181919192019211922192319241925192619271928192919301931193219331934193519361937193819391940194119421943194419451946194719481949195019511952195319541955195619571958195919601961196219631964196519661967196819691970197119721973197419751976197719781979198019811982198319841985198619871988198919901991199219931994199519961997199819992000200120022003200420052006200720082009201020112012201320142015201620172018201920202021202220232024202520262027202820292030203120322033203420352036203720382039204020412042204320442045204620472048204920502051205220532054205520562057205820592060206120622063206420652066206720682069207020712072207320742075207620772078207920802081208220832084208520862087208820892090209120922093209420952096209720982099210021012102210321042105210621072108210921102111211221132114211521162117211821192120212121222123212421252126212721282129213021312132213321342135213621372138213921402141214221432144214521462147214821492150215121522153215421552156215721582159216021612162216321642165216621672168216921702171217221732174217521762177217821792180218121822183218421852186218721882189219021912192219321942195219621972198219922002201220222032204220522062207220822092210221122122213221422152216221722182219222022212222222322242225222622272228222922302231223222332234223522362237223822392240224122422243224422452246224722482249225022512252225322542255225622572258225922602261226222632264226522662267226822692270227122722273227422752276227
72278227922802281228222832284228522862287228822892290229122922293229422952296229722982299230023012302230323042305230623072308230923102311231223132314231523162317231823192320232123222323232423252326232723282329233023312332233323342335233623372338233923402341234223432344234523462347234823492350235123522353235423552356235723582359236023612362236323642365236623672368236923702371237223732374237523762377237823792380238123822383238423852386238723882389239023912392239323942395239623972398239924002401240224032404240524062407240824092410241124122413241424152416241724182419242024212422242324242425242624272428242924302431243224332434243524362437243824392440244124422443244424452446244724482449245024512452245324542455245624572458245924602461246224632464246524662467246824692470247124722473247424752476247724782479248024812482248324842485248624872488248924902491249224932494249524962497249824992500250125022503250425052506250725082509251025112512251325142515251625172518251925202521252225232524252525262527252825292530253125322533253425352536253725382539254025412542254325442545254625472548254925502551255225532554255525562557255825592560256125622563256425652566256725682569257025712572257325742575257625772578257925802581258225832584258525862587258825892590259125922593259425952596259725982599260026012602260326042605260626072608260926102611261226132614261526162617261826192620262126222623262426252626262726282629263026312632263326342635263626372638263926402641264226432644264526462647264826492650265126522653265426552656265726582659266026612662266326642665266626672668266926702671267226732674267526762677267826792680268126822683268426852686268726882689269026912692269326942695269626972698269927002701270227032704270527062707270827092710271127122713271427152716271727182719272027212722272327242725272627272728272927302731273227332734273527362737273827392740274127422743274427452746274727482749275027512752275327542755275627572758275927602761276227632764276527662767276827692770277127722773277427752776277
72778277927802781278227832784278527862787278827892790279127922793279427952796279727982799280028012802280328042805280628072808280928102811281228132814281528162817281828192820282128222823282428252826282728282829283028312832283328342835283628372838283928402841284228432844284528462847284828492850285128522853285428552856285728582859286028612862286328642865286628672868286928702871287228732874287528762877287828792880288128822883288428852886288728882889289028912892289328942895289628972898289929002901290229032904290529062907290829092910291129122913291429152916291729182919292029212922292329242925292629272928292929302931293229332934293529362937293829392940294129422943294429452946294729482949295029512952295329542955295629572958295929602961296229632964296529662967296829692970297129722973297429752976297729782979298029812982298329842985298629872988298929902991299229932994299529962997299829993000300130023003300430053006300730083009301030113012301330143015301630173018301930203021302230233024302530263027302830293030303130323033303430353036303730383039304030413042304330443045304630473048304930503051305230533054305530563057305830593060306130623063306430653066306730683069307030713072307330743075307630773078307930803081308230833084308530863087308830893090309130923093309430953096309730983099310031013102310331043105310631073108310931103111311231133114311531163117311831193120312131223123312431253126312731283129313031313132313331343135313631373138313931403141314231433144314531463147314831493150315131523153315431553156315731583159316031613162316331643165316631673168316931703171317231733174317531763177317831793180318131823183318431853186318731883189319031913192319331943195319631973198319932003201320232033204320532063207320832093210321132123213321432153216321732183219322032213222322332243225322632273228322932303231323232333234323532363237323832393240324132423243324432453246324732483249325032513252325332543255325632573258325932603261326232633264326532663267326832693270327132723273327432753276327
7327832793280328132823283328432853286328732883289329032913292329332943295329632973298329933003301330233033304330533063307330833093310331133123313331433153316331733183319332033213322332333243325332633273328332933303331333233333334333533363337333833393340334133423343334433453346334733483349335033513352335333543355335633573358335933603361336233633364336533663367336833693370337133723373337433753376337733783379338033813382338333843385338633873388338933903391339233933394339533963397339833993400340134023403340434053406340734083409341034113412341334143415341634173418341934203421342234233424342534263427342834293430343134323433343434353436343734383439344034413442344334443445344634473448344934503451345234533454345534563457345834593460346134623463346434653466346734683469347034713472347334743475347634773478347934803481348234833484348534863487348834893490349134923493349434953496349734983499350035013502350335043505350635073508350935103511351235133514351535163517351835193520352135223523352435253526352735283529353035313532353335343535353635373538353935403541354235433544354535463547354835493550355135523553355435553556355735583559356035613562
  1. """
  2. Core graph building functionality for PyTorch's Dynamo system. This module contains
  3. the essential components for constructing and managing FX graphs during compilation:
  4. - OutputGraph: Manages the overall graph construction and compilation process. It owns
  5. a SubgraphTracer and handles graph compilation, execution, and state management.
  6. OutputGraph also manages features like graph deduplication, symbolic shape handling,
  7. and tracking of side effects.
  8. - SubgraphTracer: Handles the actual FX graph construction by tracing Python code.
  9. It supports advanced features like higher-order operators through nested tracers,
  10. lifting of free variables, and handling of symbolic shapes.
  11. The module supports key Dynamo features including:
  12. - Higher-order operators through nested SubgraphTracers
  13. - Graph deduplication for optimization
  14. - Symbolic shape handling and propagation
  15. - Side effect tracking and management
  16. - Guard insertion and management
  17. """
  18. import collections
  19. import contextlib
  20. import copy
  21. import functools
  22. import inspect
  23. import itertools
  24. import logging
  25. import operator
  26. import re
  27. import sys
  28. import traceback
  29. import warnings
  30. import weakref
  31. from collections.abc import Generator, Sequence
  32. from dataclasses import dataclass, field as dc_field
  33. from types import CodeType
  34. from typing import Any, Callable, cast, Optional, TYPE_CHECKING, Union
  35. from typing_extensions import ParamSpec, TypeVar
  36. import sympy
  37. import torch._guards
  38. import torch._logging
  39. import torch.distributed as dist
  40. import torch.nn
  41. import torch.utils._pytree as pytree
  42. from torch import fx, Tensor
  43. from torch._C._dynamo import guards
  44. from torch._dynamo.exc import ShortenTraceback, TensorifyScalarRestartAnalysis
  45. from torch._guards import (
  46. CompileContext,
  47. CompileId,
  48. GlobalContextCheckpointState,
  49. Source,
  50. tracing,
  51. TracingContext,
  52. )
  53. from torch._subclasses.fake_tensor import FakeTensor
  54. from torch._utils_internal import signpost_event
  55. from torch.export.dynamic_shapes import _ConstraintTarget
  56. from torch.fx._lazy_graph_module import _make_graph_module # type: ignore[attr-defined]
  57. from torch.fx.experimental._backward_state import BackwardState
  58. from torch.fx.experimental.symbolic_shapes import (
  59. free_symbols,
  60. guard_scalar,
  61. is_symbolic,
  62. ShapeEnv,
  63. Specialization,
  64. )
  65. from torch.fx.node import Target
  66. from torch.fx.passes.runtime_assert import insert_deferred_runtime_asserts
  67. from torch.multiprocessing.reductions import StorageWeakRef
  68. from torch.utils._ordered_set import OrderedSet
  69. from torch.utils._python_dispatch import is_traceable_wrapper_subclass
  70. from . import config, exc, logging as torchdynamo_logging, variables
  71. from .backends.registry import CompiledFn, CompilerFn
  72. from .bytecode_transformation import (
  73. create_binary_slice,
  74. create_call_function,
  75. create_dup_top,
  76. create_instruction,
  77. create_load_const,
  78. create_rot_n,
  79. create_swap,
  80. Instruction,
  81. unique_id,
  82. )
  83. from .code_context import code_context
  84. from .codegen import PyCodegen
  85. from .current_scope_id import enter_new_scope
  86. from .device_interface import get_interface_for_device
  87. from .exc import (
  88. BackendCompilerFailed,
  89. exceptions_allowed_to_be_fallback,
  90. SkipFrame,
  91. unimplemented_v2,
  92. unimplemented_v2_with_warning,
  93. )
  94. from .graph_deduplication import apply_graph_deduplication
  95. from .graph_region_tracker import GraphRegionTracker
  96. from .guards import GuardBuilder, install_guard
  97. from .mutation_guard import is_dynamic_nn_module
  98. from .side_effects import AttributeMutationExisting, SideEffects, ValueMutationExisting
  99. from .source import (
  100. _get_source_debug_name,
  101. AttrSource,
  102. BackwardStateSource,
  103. ConstantSource,
  104. GetItemSource,
  105. GlobalStateSource,
  106. is_constant_source,
  107. is_from_local_source,
  108. LocalSource,
  109. NumpyTensorSource,
  110. ParamBufferSource,
  111. ShapeEnvSource,
  112. SyntheticLocalSource,
  113. TensorProperty,
  114. TensorPropertySource,
  115. )
  116. from .utils import (
  117. _extract_tensor_dict,
  118. checkpoint_params,
  119. CleanupHook,
  120. clone_inputs,
  121. count_calls,
  122. counters,
  123. dynamo_timed,
  124. get_instruction_source_311,
  125. get_locals_to_steal,
  126. get_static_address_type,
  127. get_unique_name_wrt,
  128. graph_break_reasons,
  129. increment_op_count,
  130. istype,
  131. lazy_format_graph_code,
  132. LazyString,
  133. nn_module_proxy,
  134. same,
  135. set_example_value,
  136. )
  137. from .variables.base import VariableTracker
  138. from .variables.builder import (
  139. BackwardStateGraphArg,
  140. GraphArg,
  141. TrackedFake,
  142. wrap_fx_proxy,
  143. )
  144. from .variables.ctx_manager import ContextWrappingVariable
  145. from .variables.lists import BaseListVariable
  146. from .variables.misc import NullVariable
  147. from .variables.nn_module import NNModuleVariable
  148. from .variables.tensor import (
  149. NumpyNdarrayVariable,
  150. SymNodeVariable,
  151. TensorVariable,
  152. UnspecializedPythonVariable,
  153. )
  154. from .variables.torch_function import TensorWithTFOverrideVariable
  155. from .variables.user_defined import UserDefinedDictVariable
  156. if TYPE_CHECKING:
  157. from torch._dynamo.package import CompilePackage
  158. from torch._dynamo.symbolic_convert import InstructionTranslatorBase
# Standard module-level logger for this file.
log = logging.getLogger(__name__)
# Artifact loggers registered under this module's name, one per artifact name
# ("graph", "graph_code", "graph_sizes", "trace_call").
# NOTE(review): presumably each is enabled independently through torch._logging
# settings (e.g. TORCH_LOGS) — confirm.
graph_tabular_log = torch._logging.getArtifactLogger(__name__, "graph")
graph_code_log = torch._logging.getArtifactLogger(__name__, "graph_code")
graph_sizes_log = torch._logging.getArtifactLogger(__name__, "graph_sizes")
trace_call_log = torch._logging.getArtifactLogger(__name__, "trace_call")
# Short alias for `RootGuardManager` from the `torch._C._dynamo.guards` extension.
RootGuardManager = guards.RootGuardManager
@dataclass(frozen=True)
class VariableTrackerCacheKey:
    """Hashable cache key: the ``id()`` of a Python object plus the Source it was reached by.

    Being a frozen dataclass, instances are immutable and hashable, so they can
    be used as dict keys by ``VariableTrackerCache``.
    """

    # id() of the cached Python object (VariableTrackerCache builds keys from id(value)).
    vt_id: int
    # Two different sources can point to the same object. However, Dynamo handles
    # global and local sources differently when it comes to guards and possibly
    # some other parts as well. So, the cache also keys on the source.
    source: Source
@dataclass(frozen=True)
class AliasingInfo:
    """Immutable pair of an aliasing flag and an accompanying message.

    NOTE(review): presumably ``msg`` describes the detected aliasing when
    ``has_aliasing`` is True — confirm at the call sites.
    """

    # True when aliasing was detected.
    has_aliasing: bool
    # Human-readable message accompanying the flag.
    msg: str
@dataclass(frozen=True)
class MutationInfo:
    """Immutable pair of a mutation flag and an accompanying message.

    NOTE(review): presumably ``msg`` describes the detected mutation when
    ``has_mutation`` is True — confirm at the call sites.
    """

    # True when a mutation was detected.
    has_mutation: bool
    # Human-readable message accompanying the flag.
    msg: str
  180. class VariableTrackerCache:
  181. def __init__(self) -> None:
  182. self.cache: dict[VariableTrackerCacheKey, VariableTracker] = {}
  183. def lookup(self, value: Any, source: Source) -> Optional[VariableTracker]:
  184. key = VariableTrackerCacheKey(id(value), source)
  185. if key not in self.cache:
  186. return None
  187. return self.cache[key]
  188. def add(self, value: Any, source: Source, vt: VariableTracker) -> None:
  189. key = VariableTrackerCacheKey(id(value), source)
  190. self.cache[key] = vt
  191. def clone(self) -> "VariableTrackerCache":
  192. # Needed for copy and restore graph state
  193. new_cache = VariableTrackerCache()
  194. new_cache.cache.update(self.cache)
  195. return new_cache
  196. def clear(self) -> None:
  197. self.cache.clear()
  198. @functools.cache
  199. def _step_logger() -> Any:
  200. return torchdynamo_logging.get_step_logger(log)
  201. @dataclass
  202. class GraphCompileReason:
  203. """Stores why a given output graph was compiled; i.e. what caused the graph break."""
  204. reason: str
  205. user_stack: list[traceback.FrameSummary]
  206. # Indicates if this was a graph break reason due to graph break.
  207. graph_break: bool = True
  208. def __post_init__(self) -> None:
  209. if self.graph_break:
  210. graph_break_reasons.append(self)
  211. def _get_gen_rand_values_fn(random_calls: Any) -> Callable[[], list[Any]]:
  212. def _gen_rand_values() -> list[Any]:
  213. return [fn(*args, **kwargs) for fn, args, kwargs in random_calls]
  214. return _gen_rand_values
  215. class FakeRootModule(torch.nn.Module):
  216. """Trick the constructor of fx.GraphModule"""
  217. def __init__(self, nn_modules: dict[str, torch.nn.Module]):
  218. super().__init__()
  219. for k, v in nn_modules.items():
  220. setattr(self, k, v)
  221. def __repr__(self) -> str:
  222. return "FakeRootModule(...)"
  223. def add_nn_modules(self, nn_modules: dict[str, torch.nn.Module]) -> None:
  224. for k, v in nn_modules.items():
  225. setattr(self, k, v)
  226. class WrapperBackend:
  227. def __init__(self, backend: CompilerFn) -> None:
  228. self.backend: CompilerFn = backend
  229. def __call__(
  230. self, gm: torch.fx.GraphModule, example_inputs: list[torch.Tensor]
  231. ) -> CompiledFn:
  232. self.restore = checkpoint_params(gm)
  233. self.gm = gm
  234. copy_gm = copy.deepcopy(self.gm)
  235. self.candidate = self.backend(copy_gm, example_inputs)
  236. if self.candidate is None or self.candidate is self.gm.forward:
  237. return self.gm.forward
  238. if not config.verify_correctness:
  239. return self.candidate
  240. # if verify_correctness=True
  241. try:
  242. correct = self.gm.forward(*clone_inputs(example_inputs))
  243. result = self.candidate(*clone_inputs(example_inputs))
  244. # TODO: replace `same` function with the one in testing
  245. if same(correct, result):
  246. return self.candidate
  247. raise RuntimeError(f"incorrect results of backend {self}")
  248. except Exception:
  249. log.exception("error in verify_correctness")
  250. raise
  251. finally:
  252. self.restore()
  253. Scope = dict[str, object]
  254. @dataclass
  255. class OutputGraphGuardsState:
  256. """
  257. A base class containing fields that are considered "persistent" when we
  258. want to save all the important state for reconstrucing guards in a different
  259. process. Normally we don't need to add states here, but we may have to when
  260. the information is needed to serialize the guards, so the fields here are
  261. supposed to be serializable as a requirement.
  262. """
  263. local_scope: Scope
  264. global_scope: Scope
  265. # This records the initial torch function mode stack for guarding
  266. torch_function_mode_stack: list[torch.overrides.TorchFunctionMode]
  267. guard_on_key_order: set[Source]
  268. # Map from graph input's `Source` to sizes / strides metadata
  269. input_source_to_sizes_strides: dict[Source, dict[str, Any]]
  270. dual_level: int
  271. functorch_layers: list[torch._functorch.pyfunctorch.FuncTorchInterpreter]
  272. current_device: Optional[torch.device]
  273. global_state_guard: torch._C._dynamo.guards.GlobalStateGuard
  274. _guards: torch._guards.GuardsSet
  275. _aotautograd_guards: list[torch._guards.GuardEnvExpr]
  276. # Whether or not the guards should be checked for correctness
  277. export: bool = False
  278. skip_guards_check: bool = False
  279. export_constraints: bool = False
  280. name_of_builtins_dict_key_in_fglobals: Optional[str] = None
  281. @property
  282. def shape_env(self) -> ShapeEnv:
  283. raise AssertionError(f"shape_env shouldn't be accessed from {type(self)}")
  284. @property
  285. def guards(self) -> torch._guards.GuardsSet:
  286. return self._guards
  287. @property
  288. def aotautograd_guards(self) -> list[torch._guards.GuardEnvExpr]:
  289. return self._aotautograd_guards
  290. @dataclass
  291. class StackLocalsMetadata:
  292. """
  293. Stores metadata for a frame's stack and locals for the purposes of building resume functions
  294. """
  295. num_stack: int = 0 # number of stack elements, minus removed NULLs
  296. locals_names: dict[str, int] = dc_field(
  297. default_factory=dict
  298. ) # order of locals codegen'd to the stack
  299. stack_null_idxes: list[int] = dc_field(default_factory=list)
  300. locals_null_keys: list[str] = dc_field(default_factory=list)
  301. stack_ctx_args: list[tuple[int, tuple[Any, ...]]] = dc_field(default_factory=list)
  302. stack_ctx_idxes_orig: list[int] = dc_field(default_factory=list)
  303. locals_ctx_args: list[tuple[str, tuple[Any, ...]]] = dc_field(default_factory=list)
  304. # TODO we should expand this to make it work for atribtrary in/out
  305. @dataclass
  306. class ExportMetaData:
  307. # maps graph input index to its' source which is later
  308. # used in export to map to correct user input. In its' flat form,
  309. # just looks like GetItem(base=LocalSource("foo", idx=0))
  310. graph_input_idx_to_local_source: dict[int, Source] = dc_field(default_factory=dict)
  311. # maps user output idx to what type of output it is. There are 3 options:
  312. # 1) graph out
  313. # 2) user input
  314. # 3) constants
  315. output_return_type: dict[int, tuple[str, Any]] = dc_field(default_factory=dict)
  316. # output spec of the traced function
  317. out_spec: Union[torch.utils._pytree.TreeSpec, torch.utils._pytree.LeafSpec] = (
  318. torch.utils._pytree._LEAF_SPEC
  319. )
  320. def get_builtins_dict(global_scope: Scope) -> dict[str, Any]:
  321. # f_globals["__builtins__"] can be a dict or a module. This is an
  322. # implementation detail -
  323. # https://docs.python.org/3/library/builtins.html.
  324. # This makes guarding on any builtin messy because the guard check_fn
  325. # has to check if the __builtins__ is a module or dict, and then access
  326. # by either using getattr or getitem respectively.
  327. # To solve this problem, we insert a new entry in f_globals which points
  328. # to the builtins __dict__ and then we guard any builtin on this dict.
  329. # To avoid any collision with the pre-existing keys, we use the
  330. # install_global to give us a unique dict key.
  331. f_builtins = global_scope["__builtins__"]
  332. if not isinstance(f_builtins, dict):
  333. f_builtins = f_builtins.__dict__
  334. return f_builtins
  335. class OutputGraph(OutputGraphGuardsState):
  336. """
  337. Wrapper class to hold outputs of InstructionTranslator. Mainly the
  338. generated fx.Graph.
  339. OutputGraph is 1:1 with a frame being processed. Each frame is associated
  340. with some root InstructionTranslator. When user code calls a function,
  341. we construct a InliningInstructionTranslator that continues to write into
  342. the root InstructionTranslator's OutputGraph.
  343. """
  344. side_effects: SideEffects
  345. def __init__(
  346. self,
  347. code_options: dict[str, Any],
  348. compiler_fn: Optional[CompilerFn],
  349. root_tx: "InstructionTranslatorBase",
  350. export: bool,
  351. export_constraints: Sequence[_ConstraintTarget],
  352. frame_state: Any,
  353. local_scope: Scope,
  354. global_scope: Scope,
  355. f_code: CodeType,
  356. torch_function_mode_stack: list[torch.overrides.TorchFunctionMode],
  357. package: Optional["CompilePackage"],
  358. ) -> None:
  359. super().__init__(
  360. local_scope,
  361. global_scope,
  362. torch_function_mode_stack,
  363. guard_on_key_order=set(),
  364. input_source_to_sizes_strides={},
  365. dual_level=torch.autograd.forward_ad._current_level,
  366. functorch_layers=torch._functorch.pyfunctorch.retrieve_all_functorch_interpreters(),
  367. current_device=torch.utils._device.CURRENT_DEVICE,
  368. # initial_global_state is only None during NopTest.
  369. global_state_guard=torch._dynamo.convert_frame.initial_global_state
  370. or torch._C._dynamo.guards.GlobalStateGuard(),
  371. # These are set by @property instead, just initialize them as blank
  372. _guards=torch._guards.GuardsSet(),
  373. _aotautograd_guards=[],
  374. )
  375. self.tracers = [SubgraphTracer(self, is_export=export)]
  376. # Map from graph input's `Source` to its `VariableTracker` to
  377. # de-duplicate graph inputs by source and reuse the tracker
  378. self.input_source_to_var: dict[Source, VariableTracker] = {}
  379. self.export = export
  380. self.export_constraints = export_constraints # type: ignore[assignment]
  381. self.frame_state = frame_state
  382. self.cleanup_hooks: list[Callable[[], Any]] = []
  383. # compile_id is an id number for the current torch.compile
  384. self.compile_id: int = next(_compile_id_counter)
  385. # Set of globals installed via install_global* APIs
  386. self.installed_globals: set[str] = set()
  387. # TODO: maybe should just pass the entire f_code in here? Not
  388. # sure...
  389. self.co_fields = {
  390. "co_name": f_code.co_name,
  391. "co_filename": f_code.co_filename,
  392. "co_firstlineno": f_code.co_firstlineno,
  393. }
  394. self.region_tracker = GraphRegionTracker()
  395. # tracked_fakes says where any tensor that was wrapped to fake came
  396. # from. It is similar to GraphArg, in that all GraphArgs will get
  397. # will get added to TrackedFakes, but TrackedFakes also contains
  398. # GraphArgs that got pruned, and things like Tensor attributes which
  399. # aren't explicit graph inputs. Used by shape guard
  400. self.tracked_fakes: list[TrackedFake] = []
  401. shape_env = ShapeEnv(
  402. # Reference Cycle!
  403. # Share a reference to the list of TrackedFake.
  404. #
  405. # ShapeEnv needs this in order to be able to reproduce the call
  406. # to produce_guards at an arbitrary time point. That is because
  407. # TrackedFake instances may have its metadata changed throughout
  408. # the program execution.
  409. tracked_fakes=self.tracked_fakes,
  410. allow_scalar_outputs=config.capture_scalar_outputs,
  411. allow_dynamic_output_shape_ops=config.capture_dynamic_output_shape_ops,
  412. prefer_deferred_runtime_asserts_over_guards=config.prefer_deferred_runtime_asserts_over_guards,
  413. co_fields=self.co_fields,
  414. )
  415. # In export mode, we force the shape_env to strictly disallow any constraining
  416. # of the user marked dynamic dims
  417. import torch._functorch.config as _config
  418. with _config.patch(fake_tensor_allow_unsafe_data_ptr_access=False):
  419. fake_mode = torch._subclasses.FakeTensorMode(
  420. shape_env=shape_env,
  421. # TODO (tmanlaibaatar) Remove this once we always lift params and buffers
  422. allow_non_fake_inputs=True if self.export else False,
  423. export=self.export,
  424. )
  425. self.tracing_context: TracingContext = TracingContext(fake_mode)
  426. self.tracing_context.traced_code.append(f_code)
  427. self.dynamo_compile_id: Optional[CompileId] = (
  428. CompileContext.current_compile_id()
  429. )
  430. self.init_ambient_guards()
  431. # Map each tensor id to a list of sources. This is necessary because
  432. # tensor ids cannot be recovered from tracked fakes (in general).
  433. # We use this map to interpret (i.e., check for violations of) constraints,
  434. # specifically equality constraints, which have shared tensor ids in them.
  435. # This map should also be generally useful, e.g., for (de)serialization.
  436. self.tracked_fakes_id_to_source: dict[int, list[Source]] = (
  437. collections.defaultdict(list)
  438. )
  439. # Stores the full fqn of a param or buffer to the relevant source.
  440. self.param_name_to_source: Optional[dict[str, Source]] = {}
  441. self.side_effects = SideEffects(self)
  442. # Cached variable trackers. This makes symbolic analysis of LOAD_GLOBAL
  443. # and LOAD_ATTR for same python objects free.
  444. self.variable_tracker_cache = VariableTrackerCache()
  445. self.unique_var_id = itertools.count()
  446. self.code_options: dict[str, Any] = dict(code_options)
  447. self.output_instructions: list[Instruction] = []
  448. # used to track nodes that are added between calls of copy_graphstate
  449. # and restore_graphstate
  450. self.timestamp = 0
  451. # A list of register_finalizer_fns to apply to the output graph module
  452. self.register_finalizer_fns: list[Callable[[fx.GraphModule], None]] = []
  453. # Not checkpointed
  454. self.compiler_fn: Optional[CompilerFn] = compiler_fn
  455. self.root_tx = root_tx
  456. self.package = package
  457. # Given a source, what are the user stacks of all locations that
  458. # accessed it?
  459. #
  460. # For efficiency, we only populate this:
  461. # - During export, and
  462. # - If the source could potentially lead to a spurious export input
  463. #
  464. # Feel free to populate this more frequently if other use-cases arise,
  465. # but be aware that we have to generate full stacks for each
  466. # recording!
  467. self.source_to_user_stacks: dict[Source, list[traceback.StackSummary]] = {}
  468. self._current_tx: list[InstructionTranslatorBase] = []
  469. self.cleanups: list[CleanupHook] = []
  470. self.should_exit = False
  471. self.unspec_variable_map: dict[str, UnspecializedPythonVariable] = {}
  472. # This returns false if TF Overall (both mode and subclass) is disabled OR that TF Mode stack is empty
  473. self.torch_function_mode_enabled = torch._C._is_torch_function_mode_enabled()
  474. # Tracks if the output graph has a user defined allowed function in the
  475. # graph. This is used later to determine if we should fallback to eager
  476. # for certain exceptions. The idea is that if the user has applied
  477. # allow_in_graph, they would like to see the error instead of falling
  478. # back for backend errors.
  479. self.has_user_defined_allowed_in_graph = False
  480. # Tracks a list of called ops that were not tagged with "pt2_compliant_tag".
  481. # This information is useful for logging.
  482. self.non_compliant_ops: set[torch._ops.OpOverload] = set({})
  483. # Tracks a list of called custom ops that were tagged with "pt2_compliant_tag".
  484. # This information is useful for logging.
  485. self.compliant_custom_ops: set[torch._ops.OpOverload] = set({})
  486. # We save the global torch state here to be restored in case of graph
  487. # breaks. The relevant issue is seen here
  488. # https://github.com/pytorch/pytorch/pull/100570#issuecomment-1543427086
  489. # where inlining of a function changes the global state (because of the
  490. # presence of torch.no_grad) and there is a graph break.
  491. self.save_global_state()
  492. # Tracks the original FQNs of the constant tensors from the original graph,
  493. # i.e. buffers and parameters.
  494. self.dynamo_flat_name_to_original_fqn: dict[str, str] = {}
  495. # All calls to random() are replaced with a single call to __gen_rand_values
  496. # functions that returns a tuple of random values for each original call.
  497. # random_calls tracks calls to random() and random_values_var stores the name of
  498. # the variable that stores __gen_rand_values results.
  499. self.random_calls: list[
  500. tuple[Callable[..., object], tuple[object, ...], dict[str, object]]
  501. ] = []
  502. self.random_values_var: Any = None
  503. # Bytecode to insert right before we call the graph
  504. self.pregraph_bytecode: list[Instruction] = []
  505. # Used to pass values to backward hooks when using compiled autograd
  506. self.backward_state: dict[str, VariableTracker] = {}
  507. self.backward_state_proxy: Optional[torch.fx.Proxy] = None
  508. self.backward_state_var: Optional[str] = None
  509. self.name_of_builtins_dict_key_in_fglobals: str = (
  510. self.install_builtins_dict_in_fglobals()
  511. )
  512. self.compiler_trace_stack = contextlib.ExitStack()
  513. # These are the ambient, currently-global saved_tensor_hooks stashed in autograd,
  514. # that are set for the entire duration of the compiled region.
  515. # This is an invariant today because we graph break on the saved_tensor_hook
  516. # context manager inside a compiled region
  517. self.saved_tensors_hooks_subgraph_names: Optional[list[str]] = (
  518. self.maybe_install_saved_tensors_hooks_subgraphs()
  519. )
  520. # mangled alias -> module fqn name
  521. self.import_sources: dict[str, str] = {}
  522. self.export_metadata = ExportMetaData()
  523. def mark_bytecode_tracing_start(self) -> None:
  524. self.compiler_trace_stack.enter_context(
  525. dynamo_timed(
  526. "bytecode_tracing",
  527. log_pt2_compile_event=True,
  528. )
  529. )
  530. def mark_bytecode_tracing_stop(self) -> None:
  531. self.compiler_trace_stack.close()
  532. def install_builtins_dict_in_fglobals(self) -> str:
  533. f_builtins = get_builtins_dict(self.global_scope)
  534. return self.install_global("__builtins_dict__", f_builtins)
  535. def add_backward_state_hook(
  536. self, hook: VariableTracker, prefix: str = "hook"
  537. ) -> tuple[str, torch.fx.Proxy]:
  538. name = f"{prefix}{len(self.backward_state)}"
  539. assert name not in self.backward_state
  540. self.backward_state[name] = hook
  541. return name, self.get_backward_state_proxy()
  542. def get_backward_state_proxy(self) -> torch.fx.Proxy:
  543. if self.backward_state_proxy is None:
  544. if self.export:
  545. unimplemented_v2(
  546. gb_type="backward_state does not support export",
  547. context="",
  548. explanation="Compiled autograd doesn't work with `torch.export`.",
  549. hints=[],
  550. )
  551. example_value = BackwardState()
  552. self.backward_state_proxy = self.root_tracer.create_graph_input(
  553. "dynamo_backward_state",
  554. type(example_value),
  555. example_value,
  556. source=BackwardStateSource(),
  557. )
  558. self.backward_state_proxy.node.meta["grapharg"] = BackwardStateGraphArg()
  559. self.backward_state_var = self.new_var()
  560. return self.backward_state_proxy
  561. # This gets its own helper function so guards DEBUG logs are more informative
  562. def init_ambient_guards(self) -> None:
  563. # Register a SHAPE_ENV guard to make sure we setup shape guards
  564. # that show up in ShapeEnv
  565. self.guards.add(ShapeEnvSource().make_guard(GuardBuilder.SHAPE_ENV))
  566. self.guards.add(
  567. GlobalStateSource().make_guard(GuardBuilder.DETERMINISTIC_ALGORITHMS)
  568. )
  569. self.guards.add(GlobalStateSource().make_guard(GuardBuilder.GRAD_MODE))
  570. self.guards.add(GlobalStateSource().make_guard(GuardBuilder.DEFAULT_DEVICE))
  571. self.guards.add(
  572. GlobalStateSource().make_guard(GuardBuilder.TORCH_FUNCTION_STATE)
  573. )
  574. ci = torch._C._functorch.peek_interpreter_stack()
  575. if ci is not None:
  576. self.guards.add(
  577. GlobalStateSource().make_guard(GuardBuilder.FUNCTORCH_STACK_MATCH)
  578. )
  579. if not torch._dynamo.compiled_autograd.in_compiled_autograd_region:
  580. self.guards.add(
  581. GlobalStateSource().make_guard(
  582. GuardBuilder.AUTOGRAD_SAVED_TENSORS_HOOKS
  583. )
  584. )
  585. def maybe_install_saved_tensors_hooks_subgraphs(self) -> Optional[list[str]]:
  586. if torch._dynamo.compiled_autograd.in_compiled_autograd_region:
  587. return None
  588. get_hooks = torch._functorch._aot_autograd.utils.top_saved_tensors_hooks
  589. are_inline_hooks = (
  590. torch._functorch._aot_autograd.utils.saved_tensors_hooks_are_inlineable
  591. )
  592. hooks = get_hooks()
  593. if not are_inline_hooks(hooks):
  594. return None
  595. # If GraphModule provided by user contains fx.wrap,
  596. # We can only rely on user provided cache hash in this case.
  597. # If user did not provide cache hash - then we always bypass cache.
  598. pack_gm, unpack_gm = hooks
  599. pack_subgraph_name = self.install_subgraph(
  600. "saved_tensors_hooks_pack",
  601. torch.fx.GraphModule(self.nn_modules, pack_gm.graph),
  602. )
  603. unpack_subgraph_name = self.install_subgraph(
  604. "saved_tensors_hooks_unpack",
  605. torch.fx.GraphModule(self.nn_modules, unpack_gm.graph),
  606. )
  607. assert pack_subgraph_name == "saved_tensors_hooks_pack_0"
  608. assert unpack_subgraph_name == "saved_tensors_hooks_unpack_0"
  609. return [pack_subgraph_name, unpack_subgraph_name]
  610. def dump_guards_state(self) -> OutputGraphGuardsState:
  611. # Dump a serializable version of self without extras
  612. return OutputGraphGuardsState(
  613. local_scope=self.local_scope,
  614. global_scope=self.global_scope,
  615. torch_function_mode_stack=self.torch_function_mode_stack,
  616. guard_on_key_order=self.guard_on_key_order,
  617. input_source_to_sizes_strides=self.input_source_to_sizes_strides,
  618. dual_level=self.dual_level,
  619. functorch_layers=self.functorch_layers,
  620. current_device=self.current_device,
  621. global_state_guard=self.global_state_guard,
  622. name_of_builtins_dict_key_in_fglobals=self.name_of_builtins_dict_key_in_fglobals,
  623. export=self.export,
  624. export_constraints=self.export_constraints,
  625. _guards=self.guards,
  626. _aotautograd_guards=self.aotautograd_guards,
  627. skip_guards_check=self.skip_guards_check,
  628. )
  629. def synthetic_graph_input(
  630. self, fn: Callable[..., Any], args: tuple[Any, ...]
  631. ) -> VariableTracker:
  632. """
  633. call fn(*args) before the graph runs and turn the result into a fake input.
  634. """
  635. example_value = fn(*args)
  636. varname = self.new_var()
  637. cg = PyCodegen(self.root_tx)
  638. cg.add_push_null(
  639. lambda: cg.load_import_from(
  640. fn.__module__,
  641. fn.__name__,
  642. )
  643. )
  644. cg.foreach(map(variables.ConstantVariable.create, args))
  645. cg.call_function(len(args), False)
  646. cg.store(varname)
  647. self.pregraph_bytecode.extend(cg.get_instructions())
  648. source = SyntheticLocalSource(varname)
  649. result = VariableTracker.build(self.root_tx, example_value, source)
  650. # Realize the VT because we will delete the guards on it in the next line.
  651. result = result.realize()
  652. TracingContext.get().guards_context.dynamo_guards.remove_guards_with_source(
  653. source
  654. )
  655. return result
  656. def add_cleanup_hook(self, fn: Callable[[], Any]) -> None:
  657. self.cleanup_hooks.append(fn)
  658. def call_cleanup_hooks(self) -> None:
  659. for hook in reversed(self.cleanup_hooks):
  660. hook()
  661. self.cleanup_hooks.clear()
  662. @property
  663. def root_tracer(self) -> "SubgraphTracer":
  664. return self.tracers[0]
  665. @property
  666. def current_tracer(self) -> "SubgraphTracer":
  667. return self.tracers[-1]
  668. def is_root_tracer(self) -> bool:
  669. # Helper to tell if we are inside the higher order operator tracing.
  670. return len(self.tracers) == 1
  671. @property
  672. def graph(self) -> torch.fx.Graph:
  673. return self.current_tracer.graph
  674. # TODO(rzou): can delete after we refactor speculate_subgraph to use nested GraphTracer.
  675. @graph.setter
  676. def graph(self, value: torch.fx.Graph) -> None:
  677. self.current_tracer.graph = value
  678. @property
  679. def input_name_to_proxy(self) -> dict[str, fx.Proxy]:
  680. return self.current_tracer.input_name_to_proxy
  681. @property
  682. def real_value_cache(self) -> dict[fx.Node, torch.Tensor]:
  683. return self.current_tracer.real_value_cache
  684. @property
  685. def bound_symbols(self) -> dict[sympy.Symbol, Union[torch.fx.Proxy, "LazyProxy"]]:
  686. return self.current_tracer.bound_symbols
  687. # If you are here, and you're looking for create_graph_input,
  688. # to avoid ambiguity, please call one of the following:
  689. # - self.current_tracer.create_graph_input
  690. # - self.root_tracer.create_graph_input
  691. # See NOTE [HigherOrderOperator tracing design] for more context.
  692. def create_proxy(self, *args: Any, **kwargs: Any) -> torch.fx.Proxy:
  693. return self.current_tracer.create_proxy(*args, **kwargs)
  694. def create_node(self, *args: Any, **kwargs: Any) -> torch.fx.Node:
  695. return self.current_tracer.create_node(*args, **kwargs)
  696. def remove_node(self, *args: Any, **kwargs: Any) -> None:
  697. return self.current_tracer.remove_node(*args, **kwargs)
  698. @contextlib.contextmanager
  699. def subtracer(
  700. self, source_target: Optional[Target], prior_tracer: "SubgraphTracer"
  701. ) -> Generator[fx.Tracer, None, None]:
  702. new_scope_ctx = enter_new_scope()
  703. try:
  704. if prior_tracer:
  705. # Lineage MUST stay preserved
  706. assert prior_tracer.parent is self.current_tracer
  707. new_scope_ctx.__enter__()
  708. tracer = (
  709. prior_tracer
  710. if prior_tracer
  711. else SubgraphTracer(
  712. self,
  713. parent=self.current_tracer,
  714. source_target=source_target,
  715. is_export=self.current_tracer.is_export,
  716. )
  717. )
  718. self.tracers.append(tracer)
  719. yield tracer
  720. finally:
  721. new_scope_ctx.__exit__(None, None, None)
  722. self.tracers.pop()
  723. @property
  724. def output(self) -> "OutputGraph":
  725. return self
  726. @property
  727. def fake_mode(self) -> torch._subclasses.FakeTensorMode:
  728. assert self.tracing_context.fake_mode is not None
  729. return self.tracing_context.fake_mode
  730. @property
  731. def shape_env(self) -> ShapeEnv:
  732. assert self.tracing_context.fake_mode is not None
  733. assert self.tracing_context.fake_mode.shape_env is not None
  734. return self.tracing_context.fake_mode.shape_env
  735. @property
  736. def guards(self) -> torch._guards.GuardsSet:
  737. return self.tracing_context.guards_context.dynamo_guards
  738. @property
  739. def nn_modules(self) -> dict[str, Any]:
  740. return self.tracing_context.module_context.nn_modules
  741. @property
  742. def aotautograd_guards(self) -> list[torch._guards.GuardEnvExpr]:
  743. return self.tracing_context.guards_context.aotautograd_guards
  744. def save_global_state(
  745. self, out: Optional[dict[str, tuple[Callable[..., Any], bool]]] = None
  746. ) -> None:
  747. """
  748. Saves to out if it is provided. Else saves to the tracing context's global_state.
  749. """
  750. global_state = cast(
  751. dict[str, tuple[Callable[..., Any], bool]],
  752. (
  753. out
  754. if out is not None
  755. else self.tracing_context.global_context.global_state
  756. ),
  757. )
  758. global_state["grad_enabled"] = (torch.set_grad_enabled, torch.is_grad_enabled())
  759. global_state["autocast_enabled"] = (
  760. functools.partial(torch.set_autocast_enabled, "cuda"),
  761. torch.is_autocast_enabled("cuda"),
  762. )
  763. global_state["autocast_cpu_enabled"] = (
  764. functools.partial(torch.set_autocast_enabled, "cpu"),
  765. torch.is_autocast_enabled("cpu"),
  766. )
  767. global_state["autocast_gpu_dtype"] = ( # type:ignore[assignment]
  768. functools.partial(torch.set_autocast_dtype, "cuda"),
  769. torch.get_autocast_dtype("cuda"),
  770. )
  771. global_state["autocast_cpu_dtype"] = ( # type:ignore[assignment]
  772. functools.partial(torch.set_autocast_dtype, "cpu"),
  773. torch.get_autocast_dtype("cpu"),
  774. )
  775. global_state["autocast_cache_enabled"] = (
  776. torch.set_autocast_cache_enabled,
  777. torch.is_autocast_cache_enabled(),
  778. )
  779. def push_tx(self, tx: "InstructionTranslatorBase") -> None:
  780. self._current_tx.append(tx)
  781. def pop_tx(self) -> "InstructionTranslatorBase":
  782. return self._current_tx.pop()
  783. @property
  784. def current_tx(self) -> "InstructionTranslatorBase":
  785. return self.root_tx if not self._current_tx else self._current_tx[-1]
  786. def count_calls(self) -> int:
  787. return count_calls(self.graph)
  788. def is_empty_graph(self) -> bool:
  789. return len(list(self.graph.nodes)) == 0
  790. def has_outputs(self) -> bool:
  791. return len([x for x in self.graph.nodes if x.op == "output"]) > 0
  792. def get_submodule(self, keys: str) -> Union[torch.nn.Module, Any]:
  793. assert keys
  794. obj: Union[torch.nn.Module, dict[str, torch.nn.Module]] = self.nn_modules
  795. for k in keys.split("."):
  796. if isinstance(obj, dict):
  797. obj = obj[k]
  798. else:
  799. obj = getattr(obj, k)
  800. return obj
  801. def new_var(self, name: str = "tmp") -> str:
  802. existing = set(self.code_options["co_varnames"])
  803. # In common case, this will be O(1)
  804. while True:
  805. var = f"{name}_{next(self.unique_var_id)}"
  806. if var not in existing:
  807. self.code_options["co_varnames"] += (var,)
  808. return var
  809. def update_co_names(self, name: str) -> None:
  810. """Ensure self.code_options.co_names contains name"""
  811. if name not in self.code_options["co_names"]:
  812. self.code_options["co_names"] += (name,)
  813. @staticmethod
  814. def module_key_name(*names: Any) -> str:
  815. # create a new unique name
  816. name = "_".join(map(str, names))
  817. # Strip the guard lookup L/G access
  818. name = re.sub(r"^[GL]\['?(.*?)'?\]$", r"\1", name)
  819. # e.g. replace abc.xyz[123].qkv with abc.xyz_123.qkv
  820. name = re.sub(r"\[(\d+)\]", r"_\g<1>", name)
  821. # e.g. replace abc.xyz_123.qkv with abc_xyz_123_qkv
  822. name = re.sub(r"[^a-zA-Z0-9]", "_", name)
  823. if not name or not name[0].isalpha():
  824. name = "sub" + name
  825. return name
  826. def register_static_attr_and_return_proxy(
  827. self, attr_prefix: str, attr_value: Any
  828. ) -> fx.Proxy:
  829. attr_name = get_unique_name_wrt(attr_prefix, self.nn_modules)
  830. # TODO `nn_modules` has been historically overloaded to store a lot more
  831. # than just nn module objects, fix that.
  832. self.nn_modules[attr_name] = attr_value
  833. proxy = self.create_proxy("get_attr", attr_name, (), {})
  834. set_example_value(proxy.node, attr_value)
  835. return proxy
def register_attr_or_module(
    self,
    target: Union[torch.nn.Module, torch.Tensor, Any],
    *names: Any,
    **options: Any,
) -> VariableTracker:
    """Register ``target`` as an attribute of the output graph and return a VT for it.

    Unless ``target`` is a dynamic nn.Module (returned via
    ``VariableTracker.build`` without registration), it is stored in
    ``self.nn_modules``. The key is built from ``names`` via
    ``module_key_name`` and made unique against ``self.nn_modules`` and
    ``self.global_scope``. If the very same object is already stored in
    ``self.nn_modules``, its existing key is reused instead.

    The returned VariableTracker depends on the type of ``target``:

    - ``torch.Tensor``: a ``get_attr`` proxy created on the ROOT tracer, even
      inside a higher-order-op subtracer. An ID_MATCH or TENSOR_MATCH guard
      may be installed on ``source``. A VT already tracked in
      ``side_effects`` for this tensor is returned if present.
    - ``torch.nn.Module``: an ``NNModuleVariable`` (with an NN_MODULE guard)
      when ``source`` is truthy, else an ``UnspecializedNNModuleVariable``.
      On first registration, each parameter/buffer is also recorded in
      ``self.param_name_to_source``.
    - ``SymInt``/``SymFloat``: a ``SymNodeVariable`` over a ``get_attr`` proxy.
    - anything else: written into ``self.global_scope`` and built with a
      ``ConstantSource``.

    Args:
        target: the object to register.
        *names: parts joined by ``module_key_name`` to form the attribute key.
        **options: VariableTracker options. Must contain ``"source"``, which
            must not be a ``ParamBufferSource``.
    """
    if is_dynamic_nn_module(target, self.export):
        # Instead of returning UnspecializedNNModuleVariable, call
        # VariableTracker.build so that it is tracked for mutation.
        return VariableTracker.build(self.current_tx, target, **options)

    options = dict(options)
    assert "source" in options
    source = options["source"]
    assert not isinstance(source, ParamBufferSource)

    # Each branch below only defines `wrap_name`, which turns the chosen
    # nn_modules key into a VariableTracker; it is invoked at the bottom.
    if isinstance(target, torch.Tensor):
        tracer = self.current_tracer
        if not self.is_root_tracer():
            # For higher order ops, we don't want to insert the get_attr in
            # innermost graph. Instead, we want to raise the params/buffers
            # as inputs to the higher-order graph, and register them as
            # get_attrs in the root tracer.

            # Note that Dynamo will still call lift_tracked_freevar_to_input
            # when these inputs are encountered for the inner graph. The
            # only difference is what happens at the root tracer for
            # nn.Parameters vs free inputs. The free inputs are registered
            # as placeholders in the root graph, whereas the nn.Parameters
            # are registered as get_attr nodes in the root graph.
            tracer = self.root_tracer

        def wrap_name(module_key: str) -> VariableTracker:
            assert self.param_name_to_source is not None
            self.param_name_to_source[module_key] = source

            # Check if the attr has already been registered. This can happen
            # when two different sources point to the same tensor.
            assert self.root_tx is not None
            if target in self.root_tx.output.side_effects:
                return self.root_tx.output.side_effects[target]

            # Statically-addressed ("guarded") tensors are guarded by identity;
            # otherwise non-constant sources get a TENSOR_MATCH guard.
            if get_static_address_type(target) == "guarded" and not isinstance(
                source, NumpyTensorSource
            ):
                install_guard(source.make_guard(GuardBuilder.ID_MATCH))
            elif not is_constant_source(source):
                install_guard(source.make_guard(GuardBuilder.TENSOR_MATCH))

            vt = wrap_fx_proxy(
                self.root_tx,
                tracer.create_proxy("get_attr", module_key, (), {}),
                example_value=target,
                **options,
            )

            # Track the object so as to avoid duplicate registration in case of
            # different sources pointing to the same tensor object.
            vt = self.root_tx.output.side_effects.track_object_existing(target, vt)

            assert "tensor_dict" not in vt.as_proxy().node.meta
            vt.as_proxy().node.meta["tensor_dict"] = _extract_tensor_dict(target)

            return vt

    elif isinstance(target, torch.nn.Module):
        assert isinstance(target, torch.nn.Module)

        if source:
            install_guard(source.make_guard(GuardBuilder.NN_MODULE))

            def wrap_name(module_key: str) -> VariableTracker:
                return NNModuleVariable(type(target), module_key, target, **options)

        else:
            # This is Dynamo created graph module, e.g., graph module coming
            # from higher order ops. NNModuleVariable tracker can't be
            # sourceless, so let's return a unspecializedNNModule variable
            # tracker.
            def wrap_name(module_key: str) -> VariableTracker:
                return variables.UnspecializedNNModuleVariable(target, **options)

    elif isinstance(target, (torch.SymInt, torch.SymFloat)):
        # HACKY CODE REGION BEGIN
        # WE ARE PIGGYBACKING ON EXISTING INFRA TO REGISTER ATTRS
        # This ultimately gets written to self.nn_modules, which is unfortunate
        # Attrs that are tensors and symints and such need to be migrated to have their
        # own storage
        # alas, this is like this for now

        # NOTE(review): unlike the tensor branch, this get_attr is created on
        # the current tracer (self.create_proxy), not the root tracer; confirm
        # this is intended when called inside a higher-order-op subtracer.
        def wrap_name(module_key: str) -> VariableTracker:
            return SymNodeVariable.create(
                self,
                self.create_proxy("get_attr", module_key, (), {}),
                sym_num=target,
                **options,
            )

        # HACKY CODE REGION END
    else:

        # Any other object is exposed to the generated bytecode as a global
        # named module_key. NOTE(review): passes the OutputGraph itself where
        # a translator is expected (hence the type: ignore).
        def wrap_name(module_key: str) -> VariableTracker:
            self.output.update_co_names(module_key)
            self.global_scope[module_key] = target
            return VariableTracker.build(
                self,  # type: ignore[arg-type]
                target,
                ConstantSource(source_name=module_key),
            )

    # Reuse the existing key when this exact object was registered before.
    for k, v in self.nn_modules.items():
        if v is target:
            # it already exists
            return wrap_name(k)

    name = OutputGraph.module_key_name(*names)
    name = get_unique_name_wrt(name, self.nn_modules, self.global_scope)

    self.nn_modules[name] = target
    if isinstance(target, torch.nn.Module):

        # Record "<name>.<leaf>" -> ParamBufferSource for each param/buffer;
        # for LocalSource modules also keep the flat-name -> original FQN map.
        def register_leaf_name(leaf_name: str) -> None:
            assert self.param_name_to_source is not None
            new_source = ParamBufferSource(source, leaf_name)
            new_name = f"{name}.{leaf_name}"
            self.param_name_to_source[new_name] = new_source
            if isinstance(source, LocalSource):
                self.dynamo_flat_name_to_original_fqn[
                    OutputGraph.module_key_name(new_source.name())
                ] = leaf_name

        # annoying, but there are cases when we do not have parameters
        # see test_nn_moduledict_contains
        if hasattr(target, "_parameters"):
            for leaf_name, _ in target.named_parameters():
                register_leaf_name(leaf_name)
        if hasattr(target, "_buffers"):
            for leaf_name, _ in target.named_buffers():
                register_leaf_name(leaf_name)

    return wrap_name(name)
  953. def handle_aliases_for_stolen_lists(
  954. self, tx: "InstructionTranslatorBase"
  955. ) -> tuple[list[Instruction], dict[Source, Source]]:
  956. # If list inputs are stolen, but still needed after the function call, create aliases to keep them alive
  957. maybe_gm = self.local_scope.get("self")
  958. stolen_list_names = get_locals_to_steal(maybe_gm)
  959. if not stolen_list_names:
  960. return [], {}
  961. alias_insts = []
  962. needs_alias: dict[str, list[VariableTracker]] = {}
  963. queue = [
  964. *tx.stack,
  965. *tx.symbolic_locals.values(),
  966. *self.side_effects.store_attr_mutations.keys(),
  967. ]
  968. while queue:
  969. x = queue.pop()
  970. if isinstance(x, BaseListVariable):
  971. assert isinstance(x.items, list)
  972. queue += x.items
  973. continue
  974. if not (
  975. (
  976. x not in self.side_effects.store_attr_mutations
  977. or isinstance(x.mutation_type, AttributeMutationExisting)
  978. )
  979. and isinstance(x.source, GetItemSource)
  980. and isinstance(x.source.base, LocalSource)
  981. and x.source.base.local_name in stolen_list_names
  982. ):
  983. continue
  984. stolen_name = x.source.base.local_name
  985. if stolen_name not in needs_alias:
  986. needs_alias[stolen_name] = []
  987. needs_alias[stolen_name].append(x)
  988. visited = {}
  989. overridden_sources: dict[Source, Source] = {}
  990. for arg in self.graphargs:
  991. if not (
  992. isinstance(arg._example, list)
  993. and isinstance(arg.source, LocalSource)
  994. and arg.source.local_name in needs_alias
  995. ):
  996. continue
  997. # arg is a list that will be cleared by the compiled function
  998. list_name = arg.source.local_name
  999. assert list_name in self.code_options["co_varnames"]
  1000. for x in needs_alias[list_name]:
  1001. # Skip if already handled.
  1002. if x.source in overridden_sources:
  1003. continue
  1004. # A small codegen optimization because we might have different
  1005. # VariableTrackers that share the same source.
  1006. list_idx = x.source.index # type: ignore[attr-defined]
  1007. if list_idx not in visited:
  1008. alias_name = self.new_var(
  1009. f"{list_name}_ref"
  1010. ) # self.new_var already adds unique id suffix
  1011. visited[list_idx] = alias_name
  1012. # bytecode of `alias_name = list_name[list_idx]`
  1013. alias_insts.extend(
  1014. [
  1015. create_instruction("LOAD_FAST", argval=list_name),
  1016. create_load_const(list_idx),
  1017. create_instruction("BINARY_SUBSCR"),
  1018. create_instruction("STORE_FAST", argval=alias_name),
  1019. ]
  1020. )
  1021. # operate on alias, handled by suffix codegen
  1022. old_source = x.source
  1023. overridden_sources[old_source] = LocalSource(visited[list_idx])
  1024. # NOTE: we need `overridden_sources` because (1) we want to codegen for
  1025. # these list items to use the new local source, but (2) we want to avoid
  1026. # updating `source` in place because that might break invariants in
  1027. # other parts of Dynamo like guards.
  1028. return alias_insts, overridden_sources
def _get_stack_values_to_restore(
    self, tx: "InstructionTranslatorBase", stack_pops: int
) -> tuple[list[VariableTracker], StackLocalsMetadata]:
    """
    Gets the stack + locals values belonging to tx that need to be restored.

    Also prunes dead tx locals and realizes all VTs in the tx's stack.

    NullVariables in stack/locals will NOT be restored, unless they are the top `stack_pops`
    elements of the stack - it is expected that the next instruction to run will pop the top
    `stack_pops` elements of the stack, so we should codegen NULLs.

    Args:
        tx: the translator whose stack and symbolic locals are collected.
        stack_pops: number of top-of-stack values kept verbatim (NULLs
            included) and excluded from the NULL/context-manager metadata.

    Returns:
        - stack_values: stack and locals values that need to be restored.
          The first ``meta.num_stack`` entries are the (NULL-free, except the
          top ``stack_pops``) stack values in stack order. They are followed
          by the locals values, in the order recorded by ``meta.locals_names``.
        - meta: locations of NULLs and ContextWrappingVariables in the stack/locals
            (ignores the top `stack_pops` values on the stack)
    """
    tx.prune_dead_locals()
    stack_values = []
    meta = StackLocalsMetadata()

    # realize any unrealized tensor VTs in case they
    # need to be added to self.nn_modules as attributes
    for i, value in enumerate(tx.stack):
        variables.LazyVariableTracker.realize_all(value)
        # ignore top `stack_pops` values on the stack
        if len(tx.stack) - i <= stack_pops:
            stack_values.append(value)
            continue
        # NULLs are dropped from the restored stack; their original index is
        # recorded so they can be recreated later.
        if isinstance(value, NullVariable):
            meta.stack_null_idxes.append(i)
        else:
            stack_values.append(value)
        if isinstance(value, ContextWrappingVariable):
            target_values = (
                () if value.target_values is None else tuple(value.target_values)
            )
            # NOTE: track index in stack after NULLs have been removed
            meta.stack_ctx_args.append((len(stack_values) - 1, target_values))
            meta.stack_ctx_idxes_orig.append(i)

    meta.num_stack = len(stack_values)

    cell_and_freevars = set(tx.cellvars() + tx.freevars())

    # NB: Typically (i.e., for graph compile from RETURN_VALUE),
    # symbolic_locals will be empty at this point, as prune_dead_locals
    # will clear out all of symbolic_locals because RETURN_VALUE is the
    # last instruction and no more locals are used. The fanciness here
    # is only needed for partial graphs.
    # NOTE: All cell and free variables are represented as CellVariable,
    # so checks for NULLs and context managers in the case of codegen'ing resume
    # functions will not be performed on them. This is expected behavior.
    for k, v in tx.symbolic_locals.items():
        # Note! this explicitly uses .local_name for matching
        # Failure to do so will cause spurious registrations in val_to_names.
        # This will in turn result in spurious variables showing up in the graph.
        # This was very tricky to debug. For an example, dump the graph at call_user_compiler
        # while running test_subgraphs.py
        # Do not include top-frame unmodified locals here - otherwise, the compiled graph may
        # erroneously include them as part of the return. We manually codegen them afterward.
        if (
            isinstance(v.source, LocalSource)
            and v.source.local_name == k
            and tx is self.root_tx
        ):
            continue
        # Do not load cell/free vars
        if k in cell_and_freevars:
            continue
        # Do not load variable if it is NULL.
        if sys.version_info >= (3, 12):
            # NOTE: do not use isinstance, since it realizes lazy VT's
            # Continuation function will load the NULL for v.
            if type.__instancecheck__(NullVariable, v):
                meta.locals_null_keys.append(k)
                continue
        else:
            # A variable should never be NULL in < 3.12
            assert not type.__instancecheck__(NullVariable, v)
        # Position of this local among the restored locals (0-based).
        meta.locals_names[k] = len(meta.locals_names)
        if isinstance(v, ContextWrappingVariable):
            target_values = (
                () if v.target_values is None else tuple(v.target_values)
            )
            meta.locals_ctx_args.append((k, target_values))
        stack_values.append(v)

    return stack_values, meta
  def compile_subgraph(
      self,
      tx: "InstructionTranslatorBase",
      reason: GraphCompileReason,
      partial_convert: bool = False,
      stack_pops: int = 0,
  ) -> list[StackLocalsMetadata]:
      """
      Compiles the current subgraph, with inputs w.r.t. self.root_tx, and codegens:
          - Call the compiled subgraph
          - Apply side effects
          - Codegen stack and locals
          - Store the locals

      Python does not allow NULL to be an arg to a function, so we do not codegen NULLs on the stack,
      unless the value is one of the top `stack_pops` values on the stack (these values are expected to be
      popped immediately after this generated code). The prologue of the resume function is expected to restore
      any dropped NULLs.

      Args:
          tx: The innermost (current) instruction translator. Unless
              ``config.nested_graph_breaks`` is enabled it must be ``self.root_tx``.
          reason: Why the graph is compiled. Stored on ``self.compile_subgraph_reason``;
              its ``graph_break`` flag is forwarded to every block's ``exit``.
          partial_convert: Stored on ``self.partial_convert``.
          stack_pops: Number of top-of-stack values of ``tx`` that are popped right
              after the generated code. Only applied to ``tx``; parent frames use 0.

      Returns:
          One ``StackLocalsMetadata`` per frame, ordered from the current frame
          (N, index 0) to the root frame (1, last index). Each one records the stack
          indices and locals keys where we dropped NULLs, and where we found inactive
          context manager objects. The root frame's ``locals_names`` is extended in
          place with the root frame's unmodified locals.

      Raises:
          Whatever ``run_compiler_collective`` raises (e.g.
          ``CompileCollectiveRestartAnalysis``), since it may be invoked here.
      """
      assert self.root_tx is not None

      if not config.nested_graph_breaks:
          # expect to only compile 1 frame
          assert self.root_tx is tx

      # bytecode tracing has finished. Pop the context manager for dynamo_timed
      self.mark_bytecode_tracing_stop()

      self.partial_convert = partial_convert
      self.compile_subgraph_reason = reason
      self.should_exit = True

      log.debug("COMPILING GRAPH due to %s", reason)

      # prefix instructions (Python 3.11+)
      prefix_insts: list[Instruction] = []
      if sys.version_info >= (3, 11):
          for inst in self.root_tx.prefix_insts:
              if inst.opname == "COPY_FREE_VARS":
                  # Re-create with the root frame's current free-var count
                  # instead of copying the traced instruction's arg.
                  prefix_insts.append(
                      create_instruction(
                          "COPY_FREE_VARS",
                          arg=len(self.root_tx.code_options["co_freevars"]),
                      )
                  )
              else:
                  prefix_insts.append(copy.copy(inst))

      # stack values and restore vars for each frame are pushed in reverse order
      # i.e. last element corresponds to root frame (1),
      # first element corresponds to current frame (N)
      all_stack_values = []
      all_stack_locals_metas = []
      cur_tx: Optional[InstructionTranslatorBase] = tx
      while cur_tx is not None:
          # this should have been checked by the caller
          assert all(block.can_restore() for block in cur_tx.block_stack)

          # Only the innermost frame (tx) has values that get popped right after
          # the generated code; parent frames pass 0.
          stack_values, meta = self._get_stack_values_to_restore(
              cur_tx, stack_pops if cur_tx is tx else 0
          )
          all_stack_values.append(stack_values)
          all_stack_locals_metas.append(meta)

          # Exit from all context manager variables to make sure global state is restored
          for block in reversed(cur_tx.block_stack):
              block.exit(cur_tx, is_graph_break=reason.graph_break)

          cur_tx = cur_tx.parent

      # "Garbage collect the heap".
      self.side_effects.prune_dead_object_new(tx)

      self.add_output_instructions(prefix_insts)

      assert not (self.pregraph_bytecode and self.export), (
          "export does not support pregraph_bytecode"
      )
      self.add_output_instructions(self.pregraph_bytecode)

      alias_insts, overridden_sources = self.handle_aliases_for_stolen_lists(
          self.root_tx
      )
      self.add_output_instructions(alias_insts)

      self.cleanup_graph()

      # Use nn.Module "proxies" in the constructed GraphModule so that
      # the resulting GM does not hold additional strong references to the original modules.
      # This prevents a strong ref cycle where Dynamo created code holds on to references
      # to modules that also have Dynamo code cache invalidation checks.
      # When cache invalidation runs, the generated GM will be invalidated, which also deletes
      # the proxies.
      nn_modules_proxies = {
          name: nn_module_proxy(mod) for name, mod in self.nn_modules.items()
      }
      root = FakeRootModule(nn_modules_proxies)

      from .decorators import disable

      # to handle random calls
      if len(self.random_calls) > 0:
          # Emit: random_values = __gen_rand_values(), calling a generated
          # function that is wrapped with `disable` so Dynamo does not trace it.
          random_calls_instructions = []
          self.random_values_var = self.new_var("random_values")
          rand_fn = disable(
              _get_gen_rand_values_fn(self.random_calls),
              reason="do not trace into Dynamo rng recovery function",
          )
          rand_fn_name = self.install_global("__gen_rand_values", rand_fn)
          codegen = PyCodegen(
              self.root_tx, root, overridden_sources=overridden_sources
          )
          random_calls_instructions.extend(
              codegen.load_function_name(rand_fn_name, True)
          )
          random_calls_instructions.extend(create_call_function(0, False))
          random_calls_instructions.append(
              codegen.create_store(self.random_values_var),
          )
          self.add_output_instructions(random_calls_instructions)

      # Codegen stack convention before the unsupported instruction
      # NOTE: in these comment blocks, "locals" EXCLUDE free and cell vars.
      # NOTE: stack and locals must be codegen'd BEFORE the unsupported instruction, since the latter
      # can arbitrarily mutate the former.
      # [
      #   frame N locals,
      #   frame N-1 stack + locals,
      #   ...,
      #   frame 1 stack + locals,
      # ], frame N stack

      # see symbolic_convert.py for
      # codegen stack convention after the unsupported instruction
      # NOTE: cells are loaded into continuation functions directly

      # this determines the order that values are codegen'd to the stack
      stack_values_flat = [val for vals in all_stack_values for val in vals]

      # Set once graph outputs have been stored into graph_output_var, so the
      # temporary can be deleted with DELETE_FAST at the end.
      stored_graph_output_var = False
      graph_output_var = None

      # call compiled fx graph and codegen all values - stack and locals
      # Fast path: a single frame whose values are all distinct plain tensors
      # with no side effects, debug locals, backward state or NULLs. The graph
      # output tuple can then be unpacked straight onto the stack.
      if (
          self.root_tx is tx  # single frame
          and stack_values_flat
          and all(
              not isinstance(
                  v,
                  (
                      UnspecializedPythonVariable,
                      NumpyNdarrayVariable,
                      TensorWithTFOverrideVariable,
                  ),
              )
              and not (isinstance(v, SymNodeVariable) and v.python_type() is float)
              for v in stack_values_flat
          )
          and all(isinstance(x, TensorVariable) for x in stack_values_flat)
          and len(set(stack_values_flat)) == len(stack_values_flat)
          and self.side_effects.is_empty()
          and not tx.debug_locals
          and not self.backward_state
          and not all_stack_locals_metas[-1].stack_null_idxes
          and not all_stack_locals_metas[-1].locals_null_keys
      ):
          # optimization to generate better code in a common case
          self.add_output_instructions(
              [
                  # load in reverse since UNPACK_SEQUENCE will reverse
                  *self.compile_and_call_fx_graph(
                      tx, list(reversed(stack_values_flat)), root
                  ),
                  create_instruction("UNPACK_SEQUENCE", arg=len(stack_values_flat)),
              ]
          )
          # function output will be moved to the correct places below
      else:
          graph_output_var = self.new_var("graph_out")
          # load stack values in a flat manner - we will codegen bytecode to place them correctly
          # according to our convention above
          # pass1 is a dry run: its output is discarded and only its `uses`
          # counts are consumed below to choose tempvars for pass2.
          pass1 = PyCodegen(
              self.root_tx,
              root,
              graph_output_var,
              overridden_sources=overridden_sources,
          )
          self.codegen_suffix(tx, stack_values_flat, pass1)

          # Use `pass1.uses` to selectively cache multi-user variables into a
          # temporary local source. This (a). speeds up loading VTs with long
          # chained source, and (b). avoids redundantly saving single-user VT
          # into a temporary local.
          tempvars = {}  # type: ignore[var-annotated]
          for val, count in pass1.uses.items():
              # If it's already a local source, no need to cache it
              if count > 1 and not istype(val, (SyntheticLocalSource, LocalSource)):
                  tempvars[val] = None
          pass2 = PyCodegen(
              self.root_tx,
              root,
              graph_output_var,
              tempvars=tempvars,
              overridden_sources=overridden_sources,
          )
          self.codegen_suffix(tx, stack_values_flat, pass2)

          # Export-only bookkeeping: when the single returned value is an
          # ExportTracerOutput, record where each flat return value comes from
          # (graph output index, an input source, or a constant).
          if (
              torch._dynamo.config.log_graph_in_out_metadata
              and stack_values_flat
              and len(stack_values_flat) == 1
          ):
              vt = stack_values_flat[0]
              if (
                  isinstance(vt, torch._dynamo.variables.NamedTupleVariable)
                  and vt.tuple_cls
                  is torch._dynamo.functional_export.ExportTracerOutput
              ):
                  flat_returns = vt.items[0]
                  out_spec = vt.items[1]
                  assert isinstance(
                      flat_returns, torch._dynamo.variables.ListVariable
                  )

                  vt_to_graph_out_idx: dict[VariableTracker, int] = {}
                  for value in pass2.graph_outputs.values():
                      assert isinstance(value, torch._dynamo.codegen.GraphOutputEntry)
                      variable: VariableTracker = value.variable
                      vt_to_graph_out_idx[variable] = value.index

                  for idx, vt in enumerate(flat_returns.items):
                      if vt in vt_to_graph_out_idx:
                          self.export_metadata.output_return_type[idx] = (
                              "graph_out",
                              vt_to_graph_out_idx[vt],
                          )
                      elif (
                          vt.source is not None
                          and (source := getattr(vt.source, "base", None))
                          and source.is_input
                      ):
                          self.export_metadata.output_return_type[idx] = (
                              "input",
                              vt.source,
                          )
                      elif isinstance(vt, torch._dynamo.variables.ConstantVariable):
                          self.export_metadata.output_return_type[idx] = (
                              "constant",
                              vt.as_python_constant(),
                          )
                      else:
                          # NOTE(review): asserting a non-empty f-string is always
                          # true, so this line never fails. Unrecognized outputs are
                          # silently left out of output_return_type. If a failure was
                          # intended this should be `raise AssertionError(...)`.
                          # Confirm before changing, since that would make such
                          # exports error out.
                          assert f"Encountered unrecognized type {vt} at output {idx}"  # noqa: PLW0129

                  self.export_metadata.out_spec = out_spec.as_python_constant()

          output = []
          if count_calls(self.graph) != 0 or len(pass2.graph_outputs) != 0:
              output.extend(
                  self.compile_and_call_fx_graph(tx, pass2.graph_output_vars(), root)
              )

              if len(pass2.graph_outputs) != 0:
                  output.append(pass2.create_store(graph_output_var))
                  stored_graph_output_var = True
              else:
                  # The graph's return value is unused; discard it.
                  output.append(create_instruction("POP_TOP"))
          else:
              # NB: Important to run compiler collective even when there is
              # a graph break
              self.run_compiler_collective()
          self.add_output_instructions(output + pass2.get_instructions())

      # store all stack and locals for each frame
      # current state of the stack:
      #   *(frame N stack), *(frame N locals),
      #   ...,
      #   *(frame 1 stack), *(frame 1 locals)

      # Pack everything except frame N's stack into one list.
      self.add_output_instructions(
          [
              create_instruction(
                  "BUILD_LIST",
                  arg=len(stack_values_flat) - all_stack_locals_metas[0].num_stack,
              ),
          ]
      )

      # current state of the stack:
      #   *(frame N stack), [
      #       *(frame N locals),
      #       *(frame N-1 stack), *(frame N-1 locals),
      #       ...
      #       *(frame 1 stack), *(frame 1 locals),
      #   ]
      # iterate current frame (N) to root frame (1)
      # sliding window over frame stack/locals
      start_idx = 0
      end_idx = 0
      for i, meta in enumerate(all_stack_locals_metas):
          # do not pack frame N's stack into the value list
          n_vals = len(meta.locals_names)
          if i != 0:
              n_vals += meta.num_stack
          if n_vals == 0:
              self.add_output_instructions(
                  [
                      create_instruction("BUILD_LIST", arg=0),
                      *create_swap(2),
                  ]
              )
              # [], stack_values_flat
          else:
              end_idx += n_vals
              # Slice this frame's window out of a copy of the packed list and
              # swap so the packed list stays on top for the next iteration.
              self.add_output_instructions(
                  [
                      create_dup_top(),
                      *create_binary_slice(start_idx, end_idx),
                      *create_swap(2),
                  ]
              )
              start_idx += n_vals
              # stack_values_flat[x:y], stack_values_flat

          # add root frame's unmodified locals here
          if i == len(all_stack_locals_metas) - 1:
              root_cg = PyCodegen(self.root_tx)
              unmodified_locals_names: dict[str, int] = {}
              for k, v in self.root_tx.symbolic_locals.items():
                  if isinstance(v.source, LocalSource) and v.source.local_name == k:
                      root_cg.append_output(root_cg.create_load(k))
                      # Index continues after the root frame's already-recorded locals.
                      unmodified_locals_names[k] = len(meta.locals_names) + len(
                          unmodified_locals_names
                      )
              self.add_output_instructions(
                  root_cg.get_instructions()
                  + [
                      create_instruction(
                          "BUILD_LIST", arg=len(unmodified_locals_names)
                      ),
                      # arg=2 because we already swapped the locals list back
                      create_instruction("LIST_EXTEND", arg=2),
                  ]
              )
              meta.locals_names.update(unmodified_locals_names)

          # *(frame N stack), metas[0] stack + locals, ..., metas[i] stack + locals, stack_values_flat

      # current state of the stack:
      #   *(frame N stack)
      #   frame N locals,
      #   frame N-1 stack, frame N-1 locals,
      #   ...
      #   frame 1 stack, frame 1 locals,
      #   stack_values_flat
      #

      # Drop the packed flat list, collect the per-frame lists into one list,
      # and rotate that list underneath frame N's stack values.
      self.add_output_instructions(
          [
              create_instruction("POP_TOP"),
              create_instruction("BUILD_LIST", arg=len(all_stack_locals_metas)),
              *create_rot_n(all_stack_locals_metas[0].num_stack + 1),
          ]
      )

      # final state of the stack before running the unsupported bytecode:
      #   [
      #       [frame N locals],
      #       [frame N-1 stack + locals],
      #       ...,
      #       [frame 1 stack + locals],
      #   ], *(frame N stack)

      if graph_output_var and stored_graph_output_var:
          self.add_output_instructions(
              [create_instruction("DELETE_FAST", argval=graph_output_var)]
          )

      if self.export:
          from torch.export._trace import _ExportModuleSpecTrackerDict

          potential_side_effects = []
          for var in self.side_effects._get_modified_vars():
              if hasattr(var, "mutation_type"):
                  mut_type = var.mutation_type
                  # Make sure to skip codegen specific mutations
                  if isinstance(
                      mut_type, (AttributeMutationExisting, ValueMutationExisting)
                  ):
                      # export uses tracepoint pass to dump submodule inp/out spec
                      # into global state, so we filter it here
                      if not (
                          isinstance(var, UserDefinedDictVariable)
                          and isinstance(var.value, _ExportModuleSpecTrackerDict)
                      ):
                          potential_side_effects.append(var)

          side_effect_refs = [
              _get_source_debug_name(var.source) for var in potential_side_effects
          ]

          if len(side_effect_refs):
              warnings.warn(
                  f"While exporting, we found certain side effects happened in the model.forward. "
                  f"Here are the list of potential sources you can double check: {side_effect_refs}"
              )

      return all_stack_locals_metas
  def codegen_suffix(
      self,
      tx: "InstructionTranslatorBase",
      stack_values: list[VariableTracker],
      cg: PyCodegen,
  ) -> None:
      """
      Emit into ``cg`` the bytecode that runs after the compiled graph.

      Emission order:
        1. side-effect temporaries (must come first, see note below),
        2. assignments of backward-state values onto ``self.backward_state_var``,
        3. side-effect hooks,
        4. a call to each ``tx.debug_locals`` function with its arguments, with
           the result discarded,
        5. restoration of ``stack_values`` onto the stack,
        6. updates of mutated objects.
      """

      def emit_discarded_call(fn_vt, arg_vts) -> None:
          # fn(*arg_vts) followed by POP_TOP to drop its result.
          cg.add_push_null(lambda: cg(fn_vt))
          for arg_vt in arg_vts:
              cg(arg_vt)
          cg.extend_output(create_call_function(len(arg_vts), False))
          cg.extend_output([create_instruction("POP_TOP")])

      # NOTE: `codegen_save_tempvars` must run first to update `source` fields
      # for variables with `AttributeMutationNew`, as they don't implement
      # `reconstruct` themselves.
      self.side_effects.codegen_save_tempvars(cg)

      if self.backward_state:
          assert not self.export
          for attr_name, attr_vt in self.backward_state.items():
              cg(attr_vt)
              assert self.backward_state_var is not None
              cg.append_output(cg.create_load(self.backward_state_var))
              cg.store_attr(attr_name)

      self.side_effects.codegen_hooks(cg)

      # Return variables used for logging at the end
      for log_fn_vt, log_arg_vts in tx.debug_locals:
          emit_discarded_call(log_fn_vt, log_arg_vts)

      cg.restore_stack(stack_values, value_from_source=not tx.export)
      self.side_effects.codegen_update_mutated(cg)
  def cleanup_graph(self) -> None:
      """
      Tidy ``self.graph`` in place.

      * Drops the "creation_timestamp" entry from every node's meta.
      * Erases adjacent node pairs that flip grad mode and immediately flip it
        back, i.e. this pattern:
            torch._C._set_grad_enabled(False)
            torch._C._set_grad_enabled(True)
      """
      assert self.should_exit

      ordered = list(self.graph.nodes)
      for n in ordered:
          n.meta.pop("creation_timestamp", None)

      def inverts(candidate, state) -> bool:
          # A live _set_grad_enabled call whose single arg is the opposite of `state`.
          return (
              candidate.target is torch._C._set_grad_enabled
              and tuple(candidate.args) == (not state,)
              and not candidate._erased
          )

      grad_state = torch.is_grad_enabled()
      for first, second in zip(ordered, ordered[1:]):
          if not inverts(first, grad_state):
              continue
          grad_state = first.args[0]
          if not inverts(second, grad_state):
              continue
          # The pair cancels out, so neither call needs to stay in the graph.
          grad_state = second.args[0]
          self.graph.erase_node(first)
          self.graph.erase_node(second)
  1528. def bypass_package(self, reason: str = "", **kwargs: Any) -> None:
  1529. """
  1530. Do not save this output graph to the CompilePackage
  1531. """
  1532. if not self.package:
  1533. return
  1534. if torch._dynamo.config.strict_precompile:
  1535. raise torch._dynamo.exc.PackageError(
  1536. "Detected a package bypass: %s", reason
  1537. )
  1538. log.warning("Detected a package bypass: %s", reason)
  1539. torch._logging.trace_structured(
  1540. "artifact",
  1541. metadata_fn=lambda: {
  1542. "name": "precompile_cache_bypass",
  1543. "encoding": "json",
  1544. },
  1545. payload_fn=lambda: {
  1546. # precede with underscore so it always appear first in JSON in tlparse
  1547. "_reason": reason,
  1548. **kwargs,
  1549. },
  1550. )
  1551. self.package.bypass_current_entry()
  1552. self.package = None
  def get_graph_sizes_structured(self) -> dict[str, list[Union[int, str]]]:
      """
      Map each node whose ``example_value`` is a FakeTensor to its size.

      Integer dimensions are kept as-is. Any other dimension (e.g. a SymInt)
      is rendered with ``repr``. Nodes without a FakeTensor example value are
      omitted.
      """
      sizes: dict[str, list[Union[int, str]]] = {}
      for node in self.graph.nodes:
          fake = node.meta.get("example_value", None)
          if not isinstance(fake, torch._subclasses.FakeTensor):
              continue
          sizes[node.name] = [
              dim if isinstance(dim, int) else repr(dim) for dim in fake.size()
          ]
      return sizes
  def get_graph_sizes(self, name: str) -> str:
      """
      Render a human-readable report of FakeTensor sizes in ``self.graph``.

      Each FakeTensor node contributes a "<node>: <size>" line. If every
      dimension is an int or SymInt and at least one is a SymInt, a second
      "<node> (concrete): ..." line follows, with SymInts replaced by their
      hints.
      """
      parts = ["TRACED GRAPH TENSOR SIZES\n", f"===== {name} =====\n"]
      for node in self.graph.nodes:
          fake = node.meta.get("example_value", None)
          if not isinstance(fake, torch._subclasses.FakeTensor):
              continue
          shape = fake.size()
          parts.append(f"{node.name}: {tuple(shape)}\n")

          hinted: list = []
          saw_symbolic = False
          fully_resolved = True
          for dim in shape:
              if isinstance(dim, int):
                  hinted.append(dim)
              elif isinstance(dim, torch.SymInt):
                  saw_symbolic = True
                  hinted.append(dim.node.hint)
              else:
                  # Unknown dimension kind: skip the concrete line for this node.
                  fully_resolved = False
                  break

          if fully_resolved and saw_symbolic:
              parts.append(f"{node.name} (concrete): {tuple(hinted)}\n")
      return "".join(parts)
  1585. @contextlib.contextmanager
  1586. def restore_global_state(self) -> Any:
  1587. """
  1588. Momentarily restores the global state to what it was prior to tracing the current output
  1589. """
  1590. prior_global_state = self.tracing_context.global_context.copy_graphstate()
  1591. current_global_state: dict[str, tuple[Any, bool]] = {}
  1592. self.save_global_state(out=current_global_state)
  1593. try:
  1594. # Set to state prior to tracing the graph
  1595. self.tracing_context.global_context.restore_graphstate(prior_global_state)
  1596. yield
  1597. finally:
  1598. # Reset to state at the current time (e.g. before calling the user compiler)
  1599. self.tracing_context.global_context.restore_graphstate(
  1600. GlobalContextCheckpointState(current_global_state)
  1601. )
  def run_compiler_collective(self) -> None:
      """
      All-gather per-rank compile state across the compile process group, once.

      This is a no-op unless the root translator has a distributed state whose
      ``all_states`` is still unset. In that case every rank's ``local_state``
      is all-gathered into ``all_states``, the speculation log is cleared, and
      ``CompileCollectiveRestartAnalysis`` is raised.
      """
      root_tx = self.root_tx
      assert root_tx is not None
      dist_state = root_tx.distributed_state
      if dist_state is None or dist_state.all_states is not None:
          return

      group = dist_state.compile_pg
      log.info("compiler_collective %s", dist_state.local_state)
      torch._logging.trace_structured(
          "artifact",
          metadata_fn=lambda: {
              "name": "compiler_collective",
              "encoding": "string",
          },
          payload_fn=lambda: dist_state.local_state.render(),
      )
      device_types = group._device_types
      assert len(device_types) == 1, (
          "Expect only one device type but got {}".format("+".join(device_types))
      )
      device_interface = get_interface_for_device(device_types.pop())
      local_device_index = group.rank() % torch.accelerator.device_count()
      with (
          device_interface.device(local_device_index),  # type: ignore[attr-defined]
          dynamo_timed("compiler_collective", log_pt2_compile_event=True),
      ):
          gathered: list[Any] = [None] * group.size()
          dist.all_gather_object(gathered, dist_state.local_state, group=group)
          dist_state.all_states = gathered
      # Clear speculation log, because our tracing may diverge due to
      # this information from the compiler collective
      root_tx.speculation_log.clear()
      raise exc.CompileCollectiveRestartAnalysis
  def compile_and_call_fx_graph(
      self,
      tx: "InstructionTranslatorBase",
      rv: list[VariableTracker],
      root: FakeRootModule,
  ) -> list[Instruction]:
      """
      Generate code from self.graph and return the Instruction()s to
      call that generated code.

      Code is generated w.r.t. self.root_tx.
      tx is only used for preserving GraphModule metadata.

      Steps:
        1. add an output node returning ``rv``,
        2. prune the graph,
        3. build a GraphModule,
        4. pass it to the user backend via ``call_user_compiler``,
        5. wrap the result with ``disable``,
        6. install it as a global under a unique ``__compiled_fn`` name,
        7. emit bytecode calling it.

      Args:
          tx: translator whose metadata is preserved on the output node.
          rv: variables the graph returns, in output order.
          root: module holding the nn.Module proxies referenced by the graph.

      Returns:
          Instructions that call the compiled function, or ``[]`` if the graph
          has no calls and ``rv`` is empty.
      """
      with torch._guards.TracingContext.clear_frame():
          from .decorators import disable

          assert self.should_exit

          self.run_compiler_collective()
          if count_calls(self.graph) == 0 and len(rv) == 0:
              return []

          name = unique_id("__compiled_fn", with_uuid=True)

          assert isinstance(rv, list)
          assert isinstance(root, FakeRootModule)

          output_node = self.create_node(
              "output",
              "output",
              (self.current_tracer.create_arg(tuple(x.as_proxy() for x in rv)),),
              {},
          )
          sub_gms = self.dedup_pass()
          root.add_nn_modules(sub_gms)  # type: ignore[arg-type]

          self.current_tracer._maybe_preserve_original_meta(tx, output_node)
          if not config.do_not_emit_runtime_asserts:
              # There is a rare scenario where codegen_suffix adds a new entry
              # to self.nn_modules while `root` knows only about the
              # nn_modules at the time of its creation. This causes failures
              # while creating the graph module because self.graph and root
              # are out of sync. This only happens for `get_attr` nodes, so
              # here we clean up the get_attr nodes that are unused.
              self.remove_unused_get_attr_nodes()
              insert_deferred_runtime_asserts(
                  fx.GraphModule(root, self.graph),
                  self.shape_env,
                  name,
                  export=self.export,
              )
          # NB: deferred runtime asserts can keep graphargs live, so make sure
          # those are inserted before pruning
          self.remove_unused_graphargs()
          ncalls = count_calls(self.graph)
          counters["stats"]["calls_captured"] += ncalls

          self.remove_tensorify_specialized_graphargs()

          # free a bit of memory
          self.real_value_cache.clear()

          gm = _make_graph_module(root, self.graph)

          # Saved tensors hooks are not used by the graph.
          # GraphModule by default only copies used in the graph submodules.
          # Copying them into the result graph manually.
          if self.saved_tensors_hooks_subgraph_names:
              for subgraph_name in self.saved_tensors_hooks_subgraph_names:
                  setattr(gm, subgraph_name, getattr(root, subgraph_name))

          for register_finalizer in self.register_finalizer_fns:
              register_finalizer(gm)

          if next(gm.parameters(), None) is not None:
              # If dynamo produces a graph with parameters, skip package stuff
              # Bypass output graph
              self.bypass_package(
                  "Graph contains named parameters: either inline_inbuilt_nn_modules=False or there are static addresses.",
                  inline_builtin_nn_modules=torch._dynamo.config.inline_inbuilt_nn_modules,
                  gm=gm.print_readable(
                      print_output=False, include_stride=True, include_device=True
                  ),
              )

          if self.package is not None:
              gm._backend_id = name

          gm.compile_subgraph_reason = self.compile_subgraph_reason
          gm.meta["dynamo_flat_name_to_original_fqn"] = (
              self.dynamo_flat_name_to_original_fqn.copy()
          )
          gm.meta["dynamo_compile_id"] = self.dynamo_compile_id
          gm.meta["backend_id"] = name

          graph_code_log.debug(
              "%s",
              lazy_format_graph_code(
                  name, gm, include_stride=True, include_device=True, colored=True
              ),
          )
          torch._logging.trace_structured(
              "dynamo_output_graph",
              lambda: {"sizes": self.get_graph_sizes_structured()},
              payload_fn=lambda: gm.print_readable(
                  print_output=False, include_stride=True, include_device=True
              ),
          )
          self.call_cleanup_hooks()
          old_fake_mode = self.tracing_context.fake_mode
          assert old_fake_mode is not None
          if not self.export:
              import torch._functorch.config as _config

              with _config.patch(fake_tensor_allow_unsafe_data_ptr_access=False):
                  # TODO(voz): The way export uses gm, and fake tensors, is not supported with us resetting
                  # The backend gets a fresh FakeTensorMode that shares the
                  # tracing ShapeEnv.
                  backend_fake_mode = torch._subclasses.FakeTensorMode(
                      shape_env=old_fake_mode.shape_env,
                  )
              # TODO(voz): Ostensibly, this should be scoped and
              # restore back to old_fake_mode, but doing so currently violates
              # a lot of fake_tensor ownership assumptions and runs afoul of detect_fake_mode
              self.tracing_context.fake_mode = backend_fake_mode

          with self.restore_global_state():
              compiled_fn = self.call_user_compiler(gm, self.example_inputs())

          from torch.fx._lazy_graph_module import _LazyGraphModule

          if isinstance(compiled_fn, _LazyGraphModule) or (
              isinstance(getattr(compiled_fn, "__self__", None), _LazyGraphModule)
              and compiled_fn.__name__ == "_lazy_forward"  # type: ignore[attr-defined]
          ):
              # Since dynamo will run the forward method for the GraphModule shortly
              # anyways, it does not hurt to do the real recompilation here if
              # this is a _LazyGraphModule. This makes it easier for dynamo to
              # optimize a _LazyGraphModule.

              lazy_gm = (
                  compiled_fn
                  if isinstance(compiled_fn, _LazyGraphModule)
                  else compiled_fn.__self__  # type: ignore[attr-defined]
              )

              _LazyGraphModule.force_recompile(lazy_gm)

              if not isinstance(compiled_fn, _LazyGraphModule):
                  # replace compiled_fn with the real forward method
                  compiled_fn = lazy_gm.forward

          if self.package is not None:
              self.package.add_backend_id(name, compiled_fn)

          compiled_fn = disable(
              compiled_fn, reason="do not trace Dynamo-compiled graph"
          )

          counters["stats"]["unique_graphs"] += 1
          assert old_fake_mode.shape_env is not None
          # If the ShapeEnv recorded specializations, install a dispatcher instead
          # of compiled_fn. For the first specialization whose check passes on
          # the runtime args, the dispatcher compiles gm on first use, caches it
          # in specialization_cache and reuses it afterwards. If none match it
          # falls back to compiled_fn.
          if specializations := old_fake_mode.shape_env.specializations:
              specialization_guards = []
              specialization_cache: dict[Specialization, Callable[[Any], Any]] = {}
              sources = [a.source for a in self.graphargs]
              for specialization in specializations:
                  # Position of the specialized input among the graph args.
                  source_index = sources.index(specialization.source)
                  check_fn_source = inspect.getsource(specialization.check_fn).strip()
                  # Required because the LAMBDA_GUARD API requires a root guard manager
                  unused_root_guard_manager = RootGuardManager()
                  check_fn = guards.LAMBDA_GUARD(  # type: ignore[attr-defined]
                      unused_root_guard_manager,
                      specialization.check_fn,
                      [check_fn_source],
                  )

                  log.debug(
                      "Compiling backend specialized graph with specialization=%s",
                      check_fn_source,
                  )

                  # `check_fn=check_fn` binds this iteration's guard (avoids
                  # late binding); the partial fixes the arg index to test.
                  specialization_guards.append(
                      (
                          functools.partial(
                              lambda idx, args, check_fn=check_fn: check_fn(
                                  args[idx]
                              ),
                              source_index,
                          ),
                          specialization,
                      )
                  )

              @torch._dynamo.disable(reason="do not trace Dynamo-compiled graph")  # type: ignore[misc]
              def specialized_dispatch(*args: Any, **kwargs: Any) -> Any:
                  for check_fn, specialization in specialization_guards:
                      if check_fn(args):
                          if specialization in specialization_cache:
                              return specialization_cache[specialization](
                                  *args, **kwargs
                              )

                          with self.shape_env.patch_source_specialization(
                              specialization.source, specialization.check_fn
                          ):
                              # Modify gm so AOTAutogradCache key changes per specialization
                              gm.meta["specialization"] = specialization
                              example_inputs: list[Tensor] = list(args)
                              with tracing(self.tracing_context):
                                  specialization_cache[specialization] = (
                                      self.call_user_compiler(gm, example_inputs)
                                  )

                          return specialization_cache[specialization](*args, **kwargs)
                  return compiled_fn(*args, **kwargs)

              # This is safe because we pre-process name to be unique
              self.install_global_unsafe(name, specialized_dispatch)
          else:
              # This is safe because we pre-process name to be unique
              self.install_global_unsafe(name, compiled_fn)

          assert self.root_tx is not None
          cg = PyCodegen(self.root_tx)

          for idx, arg in enumerate(self.graphargs):
              self.export_metadata.graph_input_idx_to_local_source[idx] = arg.source

          cg.make_call_generated_code(name)
          return cg.get_instructions()
  1826. @property
  1827. def placeholders(self) -> list[fx.Node]:
  1828. return self.graph.find_nodes(op="placeholder")
  @property
  def graphargs(self) -> list[GraphArg]:
      """The ``GraphArg`` stored in each placeholder node's ``meta["grapharg"]``."""
      return [ph.meta["grapharg"] for ph in self.placeholders]
  1832. def call_user_compiler(
  1833. self, gm: fx.GraphModule, example_inputs: list[Tensor]
  1834. ) -> CompiledFn:
  1835. with dynamo_timed(
  1836. "OutputGraph.call_user_compiler",
  1837. phase_name="backend_compile",
  1838. log_pt2_compile_event=True,
  1839. log_waitcounter=True,
  1840. waitcounter_name_override="compile_aot_autograd",
  1841. dynamo_compile_column_us="aot_autograd_cumulative_compile_time_us",
  1842. ):
  1843. return self._call_user_compiler(gm, example_inputs)
  def _call_user_compiler(
      self, gm: fx.GraphModule, example_inputs: list[Tensor]
  ) -> CompiledFn:
      """
      Invoke ``self.compiler_fn`` on ``gm`` and return the callable it produces.

      Before compiling, this counts call nodes, tags each placeholder with its
      ``_dynamo_source``, and attaches source maps to ``gm``.

      Error handling:
        * restart/traceback-shortening exceptions are re-raised as-is;
        * exceptions in ``exceptions_allowed_to_be_fallback`` raise
          ``BackendCompilerFailed`` if user-defined allowed_in_graph functions
          are present, otherwise they are reported through
          ``unimplemented_v2_with_warning``;
        * ``SkipFrame`` is re-raised;
        * any other exception becomes ``BackendCompilerFailed``.
      """
      assert self.compiler_fn is not None
      all_nodes = list(gm.graph.nodes)
      op_total = sum(
          1
          for node in all_nodes
          if node.op in ("call_function", "call_method", "call_module")
      )
      placeholder_nodes = [node for node in all_nodes if node.op == "placeholder"]
      increment_op_count(op_total)

      for ph in placeholder_nodes:
          if hasattr(ph, "_dynamo_source"):
              continue
          # TODO: Why isn't this stored in meta :think:
          # NOTE: can't move these into meta: https://github.com/pytorch/pytorch/issues/141640
          ph._dynamo_source = ph.meta["grapharg"].source

      # NOTE: can't move these into meta: https://github.com/pytorch/pytorch/issues/141640
      gm._param_name_to_source = self.param_name_to_source  # type: ignore[assignment]
      gm._source_to_user_stacks = self.source_to_user_stacks  # type: ignore[assignment]

      backend_name = getattr(self.compiler_fn, "__name__", "<unknown compiler_fn>")

      try:
          _step_logger()(logging.INFO, f"calling compiler function {backend_name}")
          backend = self.compiler_fn
          if config.verify_correctness:
              backend = WrapperBackend(backend)
          compiled_fn = backend(gm, example_inputs)
          _step_logger()(logging.INFO, f"done compiler function {backend_name}")
          assert callable(compiled_fn), "compiler_fn did not return callable"
      except (TensorifyScalarRestartAnalysis, ShortenTraceback):
          raise
      except exceptions_allowed_to_be_fallback as e:
          if self.has_user_defined_allowed_in_graph:
              raise BackendCompilerFailed(
                  self.compiler_fn, e, inspect.currentframe()
              ).with_traceback(e.__traceback__) from None
          unimplemented_v2_with_warning(
              e,
              self.root_tx.f_code,
              gb_type="Backend compiler exception",
              context=f"Backend: {backend_name}\nException:{str(e)}\nTraceback:\n{self.root_tx.format_frame_summary()}",
              explanation=f"Backend compiler `{backend_name}` failed with {str(e)}. Adding a graph break.",
              hints=[
                  "Report an issue to the backend compiler repo.",
              ],
          )
      except SkipFrame as e:
          # The backend compiler has requested that we skip the frame, instead of
          # aborting execution.
          raise e
      except Exception as e:
          raise BackendCompilerFailed(
              self.compiler_fn, e, inspect.currentframe()
          ).with_traceback(e.__traceback__) from None

      signpost_event(
          "dynamo",
          "OutputGraph.call_user_compiler",
          {
              **self.co_fields,
              "op_count": op_total,
              "node_count": len(gm.graph.nodes),
              "input_count": len(placeholder_nodes),
          },
      )

      return compiled_fn
  def dedup_pass(self) -> dict[str, torch.fx.GraphModule]:
      """Run graph-region deduplication if it is enabled in the Dynamo config.

      Returns:
          The mapping produced by ``apply_graph_deduplication(self)`` when
          ``torch._dynamo.config.use_graph_deduplication`` is truthy; an empty
          dict otherwise.
      """
      dedup_enabled = torch._dynamo.config.use_graph_deduplication
      return apply_graph_deduplication(self) if dedup_enabled else {}
  def install_subgraph(self, name: str, sub_gm: torch.fx.GraphModule) -> str:
      """Register ``sub_gm`` as an attribute under a fresh name derived from ``name``.

      The name is made unique with respect to ``self.nn_modules`` (a suffix is
      requested via ``requires_suffix=True``). It is also stored on
      ``sub_gm.__name__``.

      Returns:
          The name under which the subgraph was registered.
      """
      unique_name = get_unique_name_wrt(name, self.nn_modules, requires_suffix=True)
      sub_gm.__name__ = unique_name  # type: ignore[assignment]
      sub_gm.torchdynamo_force_dynamic = False  # type: ignore[assignment]
      # The subgraph does not exist in user space, so no Source can reach it.
      self.register_attr_or_module(sub_gm, unique_name, source=None)
      return unique_name
  def example_inputs(self) -> list[torch.Tensor]:
      """Return each graph argument's ``example`` value, in ``self.graphargs`` order."""
      return [grapharg.example for grapharg in self.graphargs]
  1930. def remove_unused_get_attr_nodes(self) -> None:
  1931. for node in sorted(self.graph.find_nodes(op="get_attr"), reverse=True):
  1932. if len(list(node.users)) == 0:
  1933. self.remove_node(node)
  def remove_unused_graphargs(self) -> None:
      """Prune trivially-dead nodes, then drop graph inputs that nothing uses.

      Three phases:
        1. A miniature DCE over the graph, in reverse node order. It removes
           user-less ``get_attr`` nodes, ``operator.getitem`` calls,
           ``torch._check`` calls whose condition is statically true, pure
           SymInt/SymFloat/SymBool compute nodes, and accessor nodes.
        2. A walk over the placeholders. Unused placeholders that do not
           merely bind a symbol (and are not ``BackwardStateGraphArg``) are
           removed. Every kept placeholder's free symbols are collected into
           ``used_symbols``.
        3. Placeholders that only bind a ``sympy.Symbol`` and have no users
           are removed when their symbol is not in ``used_symbols``.

      Requires ``self.should_exit`` to be set (asserted).
      """
      # NB: It's OK to drop GraphArg for symbols that ended up being
      # specialized iff they are not used in runtime assertions. You don't
      # even have to make a guard for it, because ShapeEnv produce_guards
      # operates on tracked_fakes, which never gets pruned.
      # That being said, you'll get marginally better generated
      # guard code if you promote the guard into a Dynamo guard (since that
      # allows for the guard to be done using C++ guards.) If we get
      # ShapeEnv guards to go into C++ guards, this will stop being a thing
      # though!
      assert self.should_exit

      # Miniature DCE pass, but only for obviously trivial operations
      def is_static_true(b_node: fx.node.Argument) -> bool:
          # True for a literal True, or a node whose example_value is True or
          # a SymBool that the symbolic machinery can already decide.
          if b_node is True:
              return True
          if not isinstance(b_node, fx.Node):
              return False
          b = b_node.meta.get("example_value")
          if b is None:
              return False
          if b is True:
              return True
          if (
              isinstance(b, torch.SymBool)
              and (r := b.node.maybe_as_bool()) is not None
          ):
              return r
          # TODO: We can also technically remove all cases when the input
          # doesn't have unbacked inputs, since it's all in the ShapeEnv
          return False

      def is_symnode_arg(a: fx.node.Argument) -> bool:
          # Python scalars, or nodes whose example_value is a SymInt/SymFloat/SymBool.
          from torch.fx.experimental.sym_node import SymTypes

          if isinstance(a, (int, float, bool)):
              return True
          if isinstance(a, fx.Node):
              return isinstance(a.meta.get("example_value"), SymTypes)
          return False

      # NB: We assume that you cannot do mutations on int/float/bool,
      # because they are immutable types, and therefore is always safe to
      # DCE.
      def is_symnode_compute_node(node: fx.Node) -> bool:
          from torch.fx.experimental.sym_node import SymTypes

          if node.op != "call_function":
              return False
          # TODO: I don't think it's possible to have a bare int/float here?
          if not isinstance(node.meta.get("example_value"), SymTypes):
              return False
          # TODO: This will bail here if you ever end up with a more complicated
          # computation function, like sum(list_of_ints), even though it
          # should be DCE'able
          if not all(is_symnode_arg(a) for a in node.args):
              return False
          if not all(is_symnode_arg(a) for a in node.kwargs.values()):
              return False
          return True

      from torch.fx.experimental.symbolic_shapes import is_accessor_node

      # Reverse order so that a consumer is removed before its producer is
      # examined, letting chains of dead nodes collapse in one sweep.
      for node in reversed(list(self.graph.nodes)):
          if len(list(node.users)) == 0:
              if (
                  node.op == "get_attr"
                  or (node.op == "call_function" and node.target is operator.getitem)
                  or (
                      node.op == "call_function"
                      and node.target is torch._check
                      and is_static_true(node.args[0])
                  )
                  or is_symnode_compute_node(node)
                  or is_accessor_node(node)
              ):
                  self.remove_node(node)

      def placeholder_binds_symbol(node: fx.Node) -> Optional[sympy.Symbol]:
          # The bare sympy.Symbol when the placeholder's example is a SymInt
          # whose expression is exactly one symbol; None otherwise.
          arg = node.meta["grapharg"]
          example = arg.example
          if isinstance(example, torch.SymInt) and isinstance(
              example.node.expr, sympy.Symbol
          ):
              return example.node.expr
          return None

      def remove_unused(node: fx.Node) -> None:
          log.debug("REMOVE UNUSED GRAPHARG %s", node.meta["grapharg"].source.name())
          # I'm not really sure why you need to delete these from the
          # node since the node is going to get removed
          del node.meta["grapharg"]
          self.remove_node(node)
          self.real_value_cache.pop(node, None)

      used_symbols: set[sympy.Symbol] = set()

      def update_used_symbols(
          used_symbols: set[sympy.Symbol], fake: Union[torch.SymInt, torch.Tensor]
      ) -> None:
          # In-place union: mutates the caller's set.
          used_symbols |= free_symbols(fake)

      # Unused symbol-binding placeholders, decided after all uses are known.
      recheck_placeholders: list[fx.Node] = []
      for node in self.placeholders:
          binds_symbol = placeholder_binds_symbol(node) is not None
          # Don't delete symbol bindings yet
          if binds_symbol:
              if not node.users:
                  recheck_placeholders.append(node)
          else:
              if not node.users and not isinstance(
                  node.meta["grapharg"], BackwardStateGraphArg
              ):
                  remove_unused(node)
              else:
                  # Register the free symbols as uses
                  arg = node.meta["grapharg"]
                  if isinstance(arg, BackwardStateGraphArg):
                      continue
                  if isinstance(node.meta["grapharg"].example, torch.ScriptObject):
                      # For script objects, collect symbols from the fake
                      # object's attributes (the names come from the real
                      # object's __obj_flatten__), unless tracing with the
                      # real object.
                      real_script_obj = node.meta["grapharg"].example
                      fake_script_obj = node.meta["grapharg"].example_strong_ref
                      if not torch._library.fake_class_registry.tracing_with_real(
                          real_script_obj
                      ):
                          flat_dict = dict(real_script_obj.__obj_flatten__())  # type: ignore[attr-defined]
                          for attr in flat_dict.keys():
                              fake_attr_val = getattr(
                                  fake_script_obj.wrapped_obj, attr
                              )
                              pytree.tree_map_only(
                                  (torch.SymInt, torch.Tensor),
                                  lambda t: update_used_symbols(used_symbols, t),
                                  fake_attr_val,
                              )
                      continue
                  fake = (
                      arg.fake_tensor if arg.fake_tensor is not None else arg.example
                  )
                  update_used_symbols(used_symbols, fake)

      # After removing unused graphargs, prune unused binds_symbol
      for node in recheck_placeholders:
          symbol = placeholder_binds_symbol(node)
          if symbol is not None:
              if symbol not in used_symbols:
                  remove_unused(node)
              else:
                  # Make sure we delete later occurrences of the same symbol
                  used_symbols.remove(symbol)
  def remove_tensorify_specialized_graphargs(self) -> None:
      """Replace ``.item()`` reads of to-be-specialized tensor inputs by constants.

      This is a pretty interesting function. Our compiler tends to choke on
      unused inputs. Dynamic float arguments are supported by tensorifying away
      as many symfloats as possible in a joint fx pass. The remaining ones must
      be specialized, but by then graph inputs can no longer be removed. The
      sledgehammer solution is to record, in ``TensorifyState``, which inputs
      should have been specialized and to restart analysis. This function
      applies that "view from the future": for a node whose ``item_memo``
      symbol is marked for specialization and whose users are all ``item``
      calls, every such user is replaced by ``guard_scalar(item_memo)`` and
      removed, and then the node itself is removed. In principle unused inputs
      shouldn't be a problem, but CUDA graphs choke on them, so this is needed
      for now.
      """
      # Import here to prevent circular import
      from torch._dynamo.symbolic_convert import TensorifyState

      def _must_specialize(candidate) -> bool:
          fake = candidate.meta.get("example_value")
          if not isinstance(fake, FakeTensor) or fake.item_memo is None:
              return False
          # We use _expr instead of expr b/c we want the symbol not the replacement
          if not hasattr(fake.item_memo.node._expr, "name"):
              return False
          if any(u.target != "item" for u in candidate.users):
              return False
          return TensorifyState.should_specialize(fake.item_memo.node._expr.name)

      for node in self.graph.nodes:
          if not _must_specialize(node):
              continue
          item_memo = node.meta["example_value"].item_memo
          for item_user in list(node.users):
              item_user.replace_all_uses_with(guard_scalar(item_memo))
              self.remove_node(item_user)
          self.remove_node(node)
  def add_output_instructions(self, prefix: list[Instruction]) -> None:
      """Append ``prefix`` to the output instructions and set ``should_exit``.

      Called when a newly compiled subgraph is inserted before user code.
      """
      pending = self.output_instructions
      pending.extend(prefix)
      self.should_exit = True
  def install_global_unsafe(self, name: str, value: Any) -> None:
      """Install ``value`` as global ``name`` and register a cleanup hook for it.

      WARNING: prefer the safer `install_global_by_id/install_global`.
      torch.compile instances should be independent of each other. One
      footgun is for one instance to depend on a global installed by another
      instance, which can happen if a global is mangled the same way in both.

      Asserts that ``name`` has not already been installed by this graph.
      """
      assert name not in self.installed_globals
      self.installed_globals.add(name)
      hook = CleanupHook.create(self.global_scope, name, value)
      self.cleanups.append(hook)
  def install_global_by_id(self, prefix: str, value: Any) -> str:
      """Install ``value`` as a global keyed by ``(prefix, id(value))``, at most once.

      Returns:
          The global's name, ``f"{prefix}_{id(value)}_c{self.compile_id}"``.
          The same name is returned on repeated calls without reinstalling.
      """
      # NB: need self.compile_id to distinguish this global
      # from another global created in a different torch.compile instance
      global_name = f"{prefix}_{id(value)}_c{self.compile_id}"
      if global_name not in self.installed_globals:
          self.install_global_unsafe(global_name, value)
      return global_name
  def install_global(self, prefix: str, value: Any) -> str:
      """Install ``value`` as a global under a freshly generated name.

      Returns:
          The generated name (``unique_id(prefix)``).
      """
      # NB: unique_id is unique, even across torch.compile instances
      global_name = unique_id(prefix)
      self.install_global_unsafe(global_name, value)
      return global_name
  def cleanup(self) -> None:
      """Release references held by this OutputGraph.

      Sets ``root_tx`` and ``param_name_to_source`` to None and removes the
      ``grapharg`` entry from every graph node's meta. It then clears the
      module table, caches and registries, one at a time in a fixed order.
      """
      # There is a reference cycle between tracer and OutputGraph, causing
      # some of the tensor objects to be held alive for longer than necessary.
      self.root_tx = None  # type: ignore[assignment]
      self.nn_modules.clear()
      self.param_name_to_source = None

      for graph_node in self.graph.nodes:
          graph_node.meta.pop("grapharg", None)

      # Looked up and cleared one attribute at a time, preserving order.
      for attr_name in (
          "real_value_cache",
          "input_name_to_proxy",
          "side_effects",
          "variable_tracker_cache",
          "register_finalizer_fns",
          "dynamo_flat_name_to_original_fqn",
          "tracing_context",
          "input_source_to_var",
          "unspec_variable_map",
          "backward_state",
      ):
          getattr(self, attr_name).clear()
  def add_graph_finalizer(
      self, register_finalizer: Callable[[fx.GraphModule], None]
  ) -> None:
      """Append ``register_finalizer`` to ``self.register_finalizer_fns``."""
      finalizers = self.register_finalizer_fns
      finalizers.append(register_finalizer)
  def example_value_from_input_node(self, node: torch.fx.Node) -> Any:
      """Extract the non-fake example value behind an input node.

      Placeholders yield ``node.meta["grapharg"].example``. ``get_attr`` nodes
      yield ``self.nn_modules[node.target]``. Any other op fails an assertion.
      """
      if node.op != "placeholder":
          assert node.op == "get_attr"
          return self.nn_modules[node.target]  # type: ignore[index]
      return node.meta["grapharg"].example
  class DynamoTracerOutput:
      """Summary of a tracer after a tracing attempt.

      Copies the tracer's ``error_on_graph_break`` and
      ``is_tracing_resume_prologue`` flags. The tracer's ``output`` graph is
      kept only when ``error`` is falsy; otherwise ``output_graph`` is None.
      """

      error_on_graph_break: bool
      is_tracing_resume_prologue: bool
      output_graph: Optional[OutputGraph]

      def __init__(
          self, tracer: "InstructionTranslatorBase", error: Optional[Any] = None
      ) -> None:
          self.error_on_graph_break = tracer.error_on_graph_break
          self.is_tracing_resume_prologue = tracer.is_tracing_resume_prologue
          # A truthy error means tracing failed: drop the graph reference.
          self.output_graph = None if error else tracer.output
  # Appended (after a space) to the explanation of every
  # "Encountered non-PT2-compliant op" graph break raised in
  # check_pt2_compliant_op. Fixed the garbled "ops that have do not have".
  err_epilogue = (
      "With the current config, we will graph break "
      "(and fall back to eager-mode PyTorch) on all ops "
      "that do not have the 'pt2_compliant_tag'. "
      "Please see the following doc for how to mark this op as PT2 compliant "
      "https://pytorch.org/tutorials/advanced/custom_ops_landing_page.html"
  )
  def check_pt2_compliant_op(
      output_graph: OutputGraph, kind: str, target: Any, args: Any, kwargs: Any
  ) -> None:
      """Record whether a ``call_function`` target carries ``pt2_compliant_tag``.

      Tagged ops outside the ``prim``/``prims``/``aten`` namespaces are added to
      ``output_graph.compliant_custom_ops``. Untagged ops are added to
      ``output_graph.non_compliant_ops``. When
      ``config.only_allow_pt2_compliant_ops`` is set, an untagged op also
      triggers ``unimplemented_v2``. An ``OpOverloadPacket`` with several
      overloads is first resolved against fake versions of ``args``/``kwargs``.
      Other target types, and non-``call_function`` kinds, are ignored.
      """
      if kind != "call_function":
          return

      def _is_compliant(op: torch._ops.OpOverload) -> bool:
          # False when untagged; tagged non-builtin ops are recorded.
          if torch.Tag.pt2_compliant_tag not in op.tags:
              return False
          if op.namespace not in {"prim", "prims", "aten"}:
              output_graph.compliant_custom_ops.add(op)
          return True

      def _report_non_compliant(op: torch._ops.OpOverload, msg: str) -> None:
          output_graph.non_compliant_ops.add(op)
          if not config.only_allow_pt2_compliant_ops:
              return
          unimplemented_v2(
              gb_type="Encountered non-PT2-compliant op",
              context="",
              explanation=msg + " " + err_epilogue,
              hints=[],
          )

      if isinstance(target, torch._ops.OpOverload):
          if not _is_compliant(target):
              _report_non_compliant(
                  target,
                  f"Encountered the torch.ops.OpOverload {target} that is not PT2 compliant.",
              )
          return

      if not isinstance(target, torch._ops.OpOverloadPacket):
          return

      overloads = tuple(target.overloads())
      # Optimization: Overload resolution is expensive.
      # If there's only one overload, we know what it will resolve to.
      if len(overloads) == 1:
          op = getattr(target, overloads[0])
          if not _is_compliant(op):
              _report_non_compliant(
                  op,
                  f"Encountered the non-overloaded "
                  f"torch.ops.OpOverloadPacket {target} "
                  f"that is not PT2 compliant. ",
              )
          return

      # Several overloads: resolve using fake values of the call arguments.
      fake_args, fake_kwargs = torch._dynamo.utils.get_fake_values_from_nodes(
          output_graph.current_tx, (args, kwargs), False
      )
      try:
          overload = torch._C._jit_resolve_packet(
              target._qualified_op_name, *fake_args, **fake_kwargs
          )
      except RuntimeError as e:
          unimplemented_v2(
              gb_type="Error when attempting to resolve op packet",
              context="",
              explanation=str(e),
              hints=[],
          )

      op = getattr(target, overload)
      if not _is_compliant(op):
          _report_non_compliant(
              op,
              f"Encountered the torch.ops.OpOverloadPacket {target} "
              f"which resolves to the overload ({overload}) that is "
              f"not PT2 compliant.",
          )
  2258. _compile_id_counter = itertools.count()
  2259. P = ParamSpec("P")
  2260. R = TypeVar("R")
  class LazyProxy:
      """A deferred call: stores ``fn`` and its arguments and runs it on demand.

      Calling the instance returns ``fn(*args, **kwargs)``. ``tracer`` is only
      stored as an attribute; this class does not use it.
      """

      def __init__(
          self,
          tracer: "SubgraphTracer",
          fn: Callable[P, R],
          *args: P.args,
          **kwargs: P.kwargs,
      ) -> None:
          self.tracer, self.fn = tracer, fn
          self.args, self.kwargs = args, kwargs

      def __call__(self) -> Any:
          target_fn, call_args, call_kwargs = self.fn, self.args, self.kwargs
          return target_fn(*call_args, **call_kwargs)
  2275. class SubgraphTracer(fx.Tracer):
  2276. """
  2277. Holds an FX graph that is being traced. OutputGraph owns a SubgraphTracer
  2278. and the separation of responsibilities is that SubgraphTracer is
  2279. responsible for building the graph while OutputGraph is responsible for
  2280. compiling and executing the graph.
  2281. """
  def __init__(
      self,
      output_graph: "OutputGraph",
      parent: Optional["SubgraphTracer"] = None,
      is_export: bool = False,
      source_target: Optional[Target] = None,
  ) -> None:
      """Create an empty tracer, optionally nested under ``parent``.

      Args:
          output_graph: Owning OutputGraph. Only a weakref proxy to it is kept.
          parent: Enclosing tracer for nested (higher-order-op) subgraphs;
              None for the root tracer.
          is_export: Whether tracing for export, which restricts which
              sources may become root inputs.
          source_target: Operator this subgraph is attached to. It is used to
              extend the parent's ``source_fn_stack``.

      Raises:
          RuntimeError: if inference mode is enabled.
      """
      super().__init__()
      # weakref.proxy: this tracer does not keep its OutputGraph alive.
      self.output_graph = weakref.proxy(output_graph)
      self.graph = torch.fx.Graph()
      # See note [Export inputs must be explicitly passed in]
      self.is_export = is_export
      # Map from graph input name to its placeholder proxy object, where the
      # map's keys give all current placeholder node names and can be used to
      # create unique node names
      self.input_name_to_proxy: dict[str, fx.Proxy] = {}
      # Node => computed real value (see utils.get_real_value)
      self.real_value_cache: dict[fx.Node, torch.Tensor] = {}
      # SubgraphTracers can be nested. See NOTE [HigherOrderOperator tracing design]
      self.parent = parent
      self.source_target = source_target
      # A dict mapping previously free variables (Proxy objects)
      # to new Proxy objects that wrap inputs to this subgraph.
      #
      # This dict maps proxies in outer graphs to placeholders in current graph.
      # It serves two purposes:
      # - Proxies are associated with VariableTrackers. If we see
      # the same VariableTracker twice (and it is a free variable),
      # then we want to use the same Proxy in the current subgraph to
      # record the tracing.
      # - If we are tracing a HigherOrderOperator's body_fn, then we
      # need to keep track of what free variables were lifted so we can
      # rewrite the HigherOrderOperator call using the traced body_fn.
      # Dicts maintain the order of args for the HigherOrderOperator call.
      self.lifted_freevars: dict[fx.Proxy, fx.Proxy] = {}
      # map basic symbols (backed and unbacked) to their bound proxies.
      # There are only two cases where bound_symbols will be recorded:
      # 1. when we create_graph_input for a backed SymInt that's basic symbol
      # 2. when we track_produced_symints for intermediate results
      # bound_symbols always map the symbol to the proxy whose
      # tracer is the current tracer that's readily accessible in current tracer's graph.
      self.bound_symbols: dict[sympy.Symbol, Union[torch.fx.Proxy, LazyProxy]] = {}
      # Last instruction logged by create_proxy's TRACE FX call logging, so the
      # same instruction is not logged twice in a row.
      self.prev_inst = None
      # True if this tracer is currently tracing into torch.utils.checkpoint
      # as part of speculate_subgraph.
      self.under_activation_checkpoint = False
      # True if we want to allow externally visible side-effects (doesn't throw error on their existence)
      # during this tracer's tracing of torch.utils.checkpoint (via speculate_subgraph).
      # Only safe if we know for sure that *NOT* replaying these side-effects during
      # backward recomputation of the checkpoint region doesn't affect its correctness.
      self.allow_side_effects_under_checkpoint = False
      # True if we want to allow externally visible side-effects (doesn't throw error on their existence)
      # during this tracer's tracing. This is currently only used by experimental AC out-of-tree
      # via torch._dynamo.utils._disable_side_effect_safety_checks_for_current_subtracer.
      # Note: Externally visible side-effects are allowed if this flag OR the above flag is True.
      self.unsafe_allow_externally_visible_side_effects = False
      # True if this tracer is currently tracing (reconstructing) into a Python generator
      self.is_reconstructing_generator = False
      # Nesting depth: 0 for the root tracer, parent's depth + 1 otherwise.
      self.debug_level: int = parent.debug_level + 1 if parent is not None else 0
      # Compared against tx.f_code in create_proxy. NOTE(review): no assignment
      # other than this one is visible here, so it appears to stay None; confirm.
      self._cur_code = None
      # Metadata, line-number map and first line of an original GraphModule
      # being re-traced; consumed by _maybe_preserve_original_meta.
      self._orig_gm_meta: Optional[list[Any]] = None
      self._orig_gm_lineno_map: Optional[dict[int, Optional[int]]] = None
      self._orig_gm_firstlineno: Optional[int] = None
      # Each SubgraphTracer is associated with a source target, which indicates
      # which operator this subgraph is attached to. We compute a source_fn_stack
      # based on the source target. For the root tracer, it's set to [].
      # This is useful for debugging and transforming the exported graph.
      if self.parent is None:
          self.source_fn_stack: list[Any] = []
      else:
          self.source_fn_stack = self.parent.source_fn_stack + [
              (self.graph._target_to_str(source_target), source_target)
          ]
      # This is used to create a unique name for the placeholder
      self._used_names: OrderedSet[str] = OrderedSet()
      # Stores the versions of the input tensors at the time they are inserted
      # as placeholders in the graph. This is used to track input mutation.
      self._input_versions_at_beginning: list[int] = []
      if torch.is_inference_mode_enabled():
          raise RuntimeError(
              "Inference mode is supposed to be disabled during compilation. Please open an issue."
          )
  2364. # preserve original meta if it is available
  2365. def _maybe_preserve_original_meta(
  2366. self, tx: "InstructionTranslatorBase", node: fx.Node
  2367. ) -> None:
  2368. if (
  2369. self._orig_gm_meta
  2370. and self._orig_gm_lineno_map
  2371. and self._orig_gm_firstlineno
  2372. ):
  2373. lineno = tx.current_instruction.starts_line
  2374. node_idx = None
  2375. if lineno is not None:
  2376. node_idx = self._orig_gm_lineno_map.get(
  2377. lineno - self._orig_gm_firstlineno, None
  2378. )
  2379. if node_idx is not None:
  2380. meta = self._orig_gm_meta[node_idx]
  2381. for field in fx.proxy._COPY_META_FIELDS:
  2382. if field in meta:
  2383. node.meta[field] = meta[field]
  2384. if "stack_trace" in meta:
  2385. node.meta["stack_trace"] = meta["stack_trace"]
  def create_proxy(
      self,
      kind: str,
      target: Any,
      args: Any,
      kwargs: Any,
      name: Optional[str] = None,
      type_expr: Optional[Any] = None,
      proxy_factory_fn: Optional[Callable[[fx.Node], fx.Proxy]] = None,
  ) -> fx.Proxy:
      """Create a proxy/node in this tracer's graph and attach Dynamo metadata.

      For nested tracers, free-variable arguments are first lifted into this
      subgraph's inputs. After ``fx.Tracer.create_proxy``, the new node gets
      ``nn_module_stack``, ``source_fn_stack`` and ``stack_trace`` meta. These
      are taken from the current instruction translator, or from the original
      GraphModule's meta when re-tracing one. The node is also registered with
      the region tracker when graph deduplication is enabled. Calling an
      nn.Module (``call_module``) from a nested tracer raises via
      ``unimplemented_v2``.
      """
      # NOTE: [Nested SubgraphTracer and free_variable handling]
      # --------------------------------------------------------
      # Read NOTE [HigherOrderOperator tracing design] first.
      #
      # Let's say we're in the middle of introspecting the body of a possibly
      # nested HigherOrderOperator, and we see a free variable.
      #
      # There are two cases:
      # 1. We see a free variable that is already tracked by Dynamo.
      # 2. We see a free variable that has not been tracked by Dynamo
      #
      # In case 1, we call `maybe_lift_tracked_freevar_to_input` (below)
      # which will lift the freevar to be an input of this subgraph
      # and also recursively lift it to be an input on the parent(s).
      #
      # In case 2, before the call to `create_proxy`, the InstructionTranslator
      # will see the freevar when it gets loaded by Python bytecode.
      # E.g. for Python 3.11 the bytecodes that may do this are LOAD_DEREF or
      # LOAD_GLOBAL.
      # There, the InstructionTranslator asks Dynamo to begin tracking the
      # freevar by building a new Variable.
      # Building a new Variable automatically lifts the freevar to be an
      # input of the root SubgraphTracer.
      #
      # The implications for the code below are:
      # - We will always be in Case 1 when we get to this code.
      # - Any "free variable" we encounter here is guaranteed to already be
      # bound, that is, it is either a graph input of the root graph, or
      # some local variable of the root graph or a subgraph.
      # - The additional work we need to do here is *only* that we need to
      # lift this free variable into inputs (recursively) of each nested
      # higher-order-op subgraph until we hit the subgraph where the free
      # variable is bound
      if self.parent is not None:
          flat_args, tree_spec = pytree.tree_flatten((args, kwargs))
          new_flat_args = []
          for arg in flat_args:
              maybe_new_arg = self.maybe_lift_tracked_freevar_to_input(arg)
              new_flat_args.append(maybe_new_arg)
          args, kwargs = pytree.tree_unflatten(new_flat_args, tree_spec)
      rv = super().create_proxy(
          kind,
          target,
          args,
          kwargs,
          name,
          type_expr,
          proxy_factory_fn,  # type: ignore[arg-type]
      )
      # append stack trace to fx node
      tx = self.output_graph.current_tx
      # log detailed location of line of code in 3.11
      if sys.version_info >= (3, 11) and kind in (
          "call_function",
          "call_method",
          "call_module",
      ):
          cur_inst = tx.current_instruction
          # Log once per distinct instruction that has a known source line.
          if (
              cur_inst is not self.prev_inst
              and cur_inst.positions is not None
              and cur_inst.positions.lineno is not None
          ):
              tx_code = tx.f_code
              header = tx.get_line_of_code_header(lineno=cur_inst.positions.lineno)

              def get_trace_call_log_str() -> str:
                  line = get_instruction_source_311(tx_code, cur_inst).rstrip()
                  return f"TRACE FX call {rv.node.name} from {header}\n{line}"

              # LazyString: the message is built only if the log record is emitted.
              trace_call_log.debug("%s", LazyString(get_trace_call_log_str))
              self.prev_inst = cur_inst
      # update reference to original meta if we're tracing a new code object
      # NOTE(review): self._cur_code is not reassigned anywhere visible, so this
      # branch appears to run on every call. is_retracing depends on that; confirm
      # before "fixing" it.
      is_retracing = False
      if tx.f_code is not self._cur_code:
          orig_graphmodule_maybe = code_context.get_context(tx.f_code).get(
              "orig_graphmodule", lambda: None
          )()
          if isinstance(orig_graphmodule_maybe, torch.fx.GraphModule):
              is_retracing = True
              self._orig_gm_meta = [
                  nd.meta for nd in orig_graphmodule_maybe.graph.nodes
              ]
              self._orig_gm_lineno_map = orig_graphmodule_maybe._lineno_map
              self._orig_gm_firstlineno = (
                  orig_graphmodule_maybe.forward.__code__.co_firstlineno
              )
          else:
              self._orig_gm_meta = None
              self._orig_gm_lineno_map = None
              self._orig_gm_firstlineno = None
      nn_module_stack = tx.nn_module_stack
      if nn_module_stack:
          rv.node.meta["nn_module_stack"] = nn_module_stack.copy()
      if kind in {"call_function", "call_method"}:
          rv.node.meta["source_fn_stack"] = self.source_fn_stack + [
              (rv.node.name, target)
          ]
      elif kind == "call_module":
          if self.parent is not None:
              # TODO can remove once inline_inbuilt_nn_modules is always True
              unimplemented_v2(
                  gb_type="Invoking an nn.Module inside a higher order operator",
                  context=f"Higher order op name: {self.source_target}",
                  explanation="This is not supported.",
                  hints=[],
              )
          # For modules we store the class
          # (the type stored in the nn_module_stack entry whose key, before any
          # "@" suffix, equals target).
          rv.node.meta["source_fn_stack"] = self.source_fn_stack + [
              (
                  rv.node.name,
                  next(
                      ty
                      for k, (_, ty) in rv.node.meta["nn_module_stack"].items()
                      if k.split("@")[0] == target
                  ),
              )
          ]
      # When re-tracing, this may overwrite the meta set above with the
      # original graph's values.
      self._maybe_preserve_original_meta(tx, rv.node)
      # Fill in only what is still missing when not re-tracing.
      if not is_retracing:
          if "nn_module_stack" not in rv.node.meta:
              nn_module_stack = tx.nn_module_stack
              if nn_module_stack:
                  rv.node.meta["nn_module_stack"] = nn_module_stack.copy()
          if "source_fn_stack" not in rv.node.meta:
              if kind in {"call_function", "call_method"}:
                  rv.node.meta["source_fn_stack"] = self.source_fn_stack + [
                      (rv.node.name, target)
                  ]
              elif kind == "call_module":
                  if self.parent is not None:
                      # TODO can remove once inline_inbuilt_nn_modules is always True
                      unimplemented_v2(
                          gb_type="Invoking an nn.Module inside a HigherOrderOperator",
                          context="",
                          explanation="This is not supported.",
                          hints=[],
                      )
                  # For modules we store the class
                  rv.node.meta["source_fn_stack"] = self.source_fn_stack + [
                      (
                          rv.node.name,
                          rv.node.meta["nn_module_stack"][target][1],
                      )
                  ]
      if "stack_trace" not in rv.node.meta:
          # Walk outward through tx.parent, collecting one frame per translator.
          frame_summaries: list[traceback.FrameSummary] = []
          while tx:
              # Avoid frame summaries from inside the torch/nn/modules. This ensures that we keep the stack trace of
              # the user code.
              if not tx.is_co_filename_from_nn_modules():
                  frame_summaries.append(tx.frame_summary())
              tx = getattr(tx, "parent", None)
          # Reverse the frame_summaries, such that the innermost frame is at the last
          frame_summaries.reverse()
          # official from_list stub doesn't have new-style type
          msgs = traceback.StackSummary.from_list(frame_summaries).format()
          rv.node.stack_trace = "".join(msgs)
      if (
          torch._dynamo.config.use_graph_deduplication
          or torch._dynamo.config.track_nodes_for_deduplication
      ):
          self.output_graph.region_tracker.track_node(
              self.output_graph.current_tx, rv.node
          )
      return rv
  def create_node(
      self,
      op: str,
      target: Target,
      args: Any = None,
      kwargs: Any = None,
      name: Optional[str] = None,
      type_expr: Optional[Any] = None,
  ) -> fx.Node:
      """Create an FX node after PT2-compliance and graph-ownership checks.

      Runs ``check_pt2_compliant_op``. For nested tracers it asserts that every
      ``fx.Node`` argument belongs to this tracer's graph. It then delegates to
      ``fx.Tracer.create_node``. The new node is stamped with the OutputGraph's
      ``timestamp`` and its name is recorded in ``self._used_names``.

      ``args``/``kwargs`` may be None (their defaults); None is treated as "no
      arguments" for the ownership check, and the original values are still
      passed through unchanged to the compliance check and to fx.
      """
      check_pt2_compliant_op(self.output_graph, op, target, args, kwargs)
      if self.parent is not None:
          # Unpacking None with */** would raise TypeError, so normalize
          # the None defaults before flattening.
          flat_args = pytree.arg_tree_leaves(
              *(() if args is None else args),
              **({} if kwargs is None else kwargs),
          )
          for arg in flat_args:
              if not isinstance(arg, torch.fx.Node):
                  continue
              assert arg.graph == self.graph, (
                  "create_node using arg not from this SubgraphTracer"
              )
      node = super().create_node(op, target, args, kwargs, name, type_expr)
      node.meta["creation_timestamp"] = self.output_graph.timestamp
      self._used_names.add(node.name)
      return node
  # Note: we did not override erase_node since
  # we call self.graph.erase_node elsewhere
  def remove_node(self, node: fx.Node) -> None:
      """Erase ``node`` from this tracer's graph and drop its input-name entry.

      Users of ``node`` that live in a different (nested) graph are handled by
      first erasing every node of that nested graph, in reverse order. Each
      nested graph is processed at most once, even when several users belong to
      it, so no node is erased twice. In torch.fx, erasing an already-erased
      node only emits a warning. Users in this same graph are not handled:
      ``self.graph.erase_node`` raises for them, which surfaces a real bug.
      """
      user_graph_nodes: list[torch.fx.Node] = []
      seen_graph_ids: set[int] = set()
      for user in node.users:
          # For the case where user.graph == self.graph, that is a real bug and will raise
          # properly.
          if user.graph == self.graph or id(user.graph) in seen_graph_ids:
              continue
          seen_graph_ids.add(id(user.graph))
          # This is a nested graph, which needs to be deleted.
          # If we do not do this, we will raise on attempting to remove this.
          # As we only get here during restoration cleanup, this is sound.
          user_graph_nodes.extend(reversed(list(user.graph.nodes)))
      for other_graph_node in user_graph_nodes:
          other_graph_node.graph.erase_node(other_graph_node)
      self.graph.erase_node(node)
      self.input_name_to_proxy.pop(node.name, None)
    # When before=True, we will insert this input before the most recently
    # inserted proxy. This is a hack to get around an ordering problem,
    # where we first insert a tensor argument, and then insert bindings
    # for SymInts that may occur in the tensor argument.
    # Remove this if https://github.com/pytorch/pytorch/issues/99007 gets
    # fixed.
    def create_graph_input(
        self,
        name: str,
        type_expr: Any,
        example_value: Any,
        before: bool = False,
        source: Optional[Source] = None,
    ) -> fx.Proxy:
        """Add a ``placeholder`` node to this tracer's graph and return its proxy.

        Args:
            name: requested placeholder name. It is made unique with respect to
                ``self._used_names`` before use, and the final name is then
                added to ``self._used_names``.
            type_expr: forwarded to ``create_proxy`` as the node's type.
            example_value: attached to the new node via ``set_example_value``.
                If it is a tensor, its current ``_version`` is appended to
                ``self._input_versions_at_beginning`` (read by
                ``has_input_mutation``).
            before: when inputs already exist, insert the placeholder before
                (instead of after) the node of the last entry in
                ``self.input_name_to_proxy``. The new entry is also placed just
                before that last entry in ``input_name_to_proxy``.
            source: where the value comes from. Required on the root tracer
                (``self.parent is None``); may be None on subgraph tracers.

        Returns:
            The proxy wrapping the new placeholder node.

        Side effects:
            - In strict export on the root tracer, a source that is not a local
              input source gets the current stack recorded in
              ``output_graph.source_to_user_stacks``.
            - Unless ``self.is_export`` or ``torch.compiler.is_compiling()`` is
              true, basic symbols in a tensor's (or each tensor in a list/tuple's)
              sizes/strides/storage offset are lifted via ``_lift_basic_symbols``.
            - If ``example_value`` is a SymInt whose expr is a plain
              ``sympy.Symbol``, that symbol is bound to the new proxy in
              ``self.bound_symbols``.

        Raises:
            AssertionError: if ``source`` is None on the root tracer.
        """
        if isinstance(example_value, torch.Tensor):
            # Snapshot the version counter so has_input_mutation can compare
            # it with the version at the end of tracing.
            self._input_versions_at_beginning.append(example_value._version)
        log.debug(
            "create_graph_input %s %s %s at debug_level %s before=%s",
            name,
            source.name() if source is not None else "(none)",
            example_value,
            self.debug_level,
            before,
        )
        if source is None:
            assert self.parent is not None, (
                f"you are required to provide a source for inputs {name} example_val {example_value} on the root tracer"
            )

        # Note [Export inputs must be explicitly passed in]
        # In eager, we are generally OK with adding graph inputs whenever we
        # want, because we take care of writing the bytecode that knows how
        # to source all the inputs.
        #
        # In export, this is bad, because you want a self-contained export
        # object which only depends on the inputs you explicitly passed to it.
        # So we are a bit more strict about what sources can become inputs
        # in export.
        if self.is_export and self.parent is None:
            assert source is not None
            if not is_from_local_source(source, only_allow_input=True):
                # Record where this non-local source was requested; presumably
                # consumed later to report it to the user (NOTE(review): confirm
                # at the reader of source_to_user_stacks).
                self.output_graph.source_to_user_stacks.setdefault(source, []).append(
                    TracingContext.extract_stack()
                )

        # _used_names contains the names of all the nodes in the graph,
        # including intermediates. This ensures that we do not have a name
        # collision.
        name = get_unique_name_wrt(name, self._used_names)
        if self.input_name_to_proxy:
            # Anchor on the node of the last entry in input_name_to_proxy.
            prev_name = next(reversed(self.input_name_to_proxy))
            node = self.input_name_to_proxy[prev_name].node
            if before:
                ctx = self.graph.inserting_before(node)
            else:
                ctx = self.graph.inserting_after(node)
        else:
            # No inputs yet: insert at the very beginning of the graph.
            ctx = self.graph.inserting_before(None)
        with ctx:
            proxy = self.create_proxy("placeholder", name, (), {}, type_expr=type_expr)
            set_example_value(proxy.node, example_value)
            if self.input_name_to_proxy and before:
                # Put the new entry just before the previous last entry, so
                # that entry stays last (and remains the anchor for the next
                # insertion).
                k, v = self.input_name_to_proxy.popitem()
                self.input_name_to_proxy[name] = proxy
                self.input_name_to_proxy[k] = v
            else:
                self.input_name_to_proxy[name] = proxy

            # For placeholder nodes, `name` is passed as a str to the target,
            # and then torch.fx decides the node.name. So, record the `target`
            # name as well in the _used_names to prevent any collision.
            self._used_names.add(name)

            # NOTE: [Auto lift basic free symbols when create_graph_input]
            # There are two sources of basic symbols:
            #
            # - They can come from inputs, e.g. when an input tensor is specified as dynamic. We handle
            # this case by intercepting at create_graph_input. Whenever we call create_graph_input, we
            # try to also lift the basic symbols in example values as graph input.
            #
            # 1. When create_graph_input for a tensor that has symbolic shapes,
            # we look for basic symbols in its size and stride, and check whether each symbol is bound
            # in the current graph (i.e. bound_symbols). If it's not bound, we create a placeholder
            # for it, then recursively check its parent, creating a placeholder wherever it is not
            # bound, until reaching the top level, where we require a source to be attached to the proxy.
            #
            # 2. When create_graph_input for a tensor that contains compound exprs,
            # for example, if an input to subgraph takes size [s1+s2//8], we'll look for
            # the free basic symbols in the sizes and lift all of them following 1.
            #
            # 3. When create_graph_input for a symint. The following invariants hold:
            # a. if symint's expr is a basic symbol, we only lift it once.
            # b. if symint's expr is compound, we lift the expr as a single input. The basic symbols
            # in the compound expr are NOT lifted, because if the basic symbols are used inside the
            # subgraph they will be lifted according to 3.a
            #
            # - They can come from intermediate results:
            # For example, data-dependent operators such as t.item(), t.nonzero(), where basic symbols
            # might be created. For this purpose, we track the basic symbols of intermediate results
            # immediately after they're created at wrap_fx_proxy with track_produced_symints. Notice
            # that for basic symbols that are already tracked by create_graph_input, we won't track them again.
            #
            # Also see NOTE: [Export inputs must be explicitly passed in]
            is_strict_export = self.is_export
            # NOTE(review): torch.compiler.is_compiling() is used here as the
            # non-strict-export signal (per the variable name); it may also be
            # true in other compiling contexts -- confirm intent.
            is_non_strict_export = torch.compiler.is_compiling()
            if not is_strict_export and not is_non_strict_export:
                if isinstance(example_value, torch.Tensor):
                    self._lift_basic_symbols(example_value, source)
                elif isinstance(example_value, (list, tuple)):
                    for i, e in enumerate(example_value):
                        if not isinstance(e, torch.Tensor):
                            continue

                        # Each tensor element gets a GetItemSource derived from
                        # the container's source (None when there is no source).
                        e_source = None
                        if source:
                            e_source = GetItemSource(
                                base=source, index=i, index_is_slice=False
                            )

                        self._lift_basic_symbols(e, e_source)

            # Bind the symbol to the placeholder if example_value is a SymInt with a basic symbol.
            if isinstance(example_value, torch.SymInt) and isinstance(
                example_value.node.expr, sympy.Symbol
            ):
                self.bound_symbols[example_value.node.expr] = proxy
            return proxy
  2719. # See NOTE: [Nested SubgraphTracer and free_variable handling] for more details
  2720. def lift_tracked_freevar_to_input(
  2721. self, proxy: fx.Proxy
  2722. ) -> Union[LazyProxy, fx.Proxy]:
  2723. # You're doing something wrong if we are the root SubgraphTracer because
  2724. # Dynamo adds tensors to graph inputs before creating a proxy for them.
  2725. assert self.parent is not None, (
  2726. "lift_tracked_freevar_to_input should not be called on root SubgraphTracer"
  2727. )
  2728. example_value = proxy.node.meta["example_value"]
  2729. # To avoid lifting the same symbol twice, we check whether basic symbols has been tracked.
  2730. # For example, the basic symbols may have already been lifted for current subgraph when
  2731. # we automatically lift basic symbols in the sizes/strides of a tensor t.
  2732. # Suppose parent graph calls sz = t.size()[0], it creates
  2733. # a proxy in parent and the subgraph accesses sz via closure. sz's proxy is not tracked
  2734. # in current sub-tracer so we may lift the same symbol twice.
  2735. if (
  2736. isinstance(example_value, torch.SymInt)
  2737. and example_value.node.expr in self.bound_symbols
  2738. ):
  2739. return self.bound_symbols[example_value.node.expr]
  2740. # Proxies are associated with VariableTracker.
  2741. # It is possible that we've already lifted the Proxy to be an input.
  2742. # If that is the case, just return the already lifted Proxy.
  2743. if proxy in self.lifted_freevars:
  2744. return self.lifted_freevars[proxy]
  2745. # We first lift proxy to parent's graph then lift to current grpah's input
  2746. # so that when we bind symints of the sizes in current graph, those symints
  2747. # would already be lifted as inputs to parent graph.
  2748. if proxy.tracer != self.parent:
  2749. self.parent.lift_tracked_freevar_to_input(proxy)
  2750. example_value = proxy.node.meta["example_value"]
  2751. new_proxy = self.create_graph_input(
  2752. proxy.node.name, type(example_value), example_value
  2753. )
  2754. self.lifted_freevars[proxy] = new_proxy
  2755. return new_proxy
  2756. def maybe_lift_tracked_freevar_to_input(self, arg: Any) -> Any:
  2757. """
  2758. If arg is a free variable, then lift it to be an input.
  2759. Returns the new lifted arg (if arg was a freevar), else the
  2760. original arg.
  2761. """
  2762. if not isinstance(arg, torch.fx.Proxy):
  2763. # Note: arg can be a python built-in slice type e.g.
  2764. # x[:max_seq] is represented as get_item(t, (slice(None, max_seq, None)))
  2765. # we need to also look into the slice variable itself to lift the
  2766. # proxies there.
  2767. if isinstance(arg, slice):
  2768. return slice(
  2769. *(
  2770. self.maybe_lift_tracked_freevar_to_input(sub_arg)
  2771. for sub_arg in (arg.start, arg.stop, arg.step)
  2772. )
  2773. )
  2774. else:
  2775. return arg
  2776. elif arg.tracer == self:
  2777. return arg
  2778. return self.lift_tracked_freevar_to_input(arg)
    # See NOTE: [Auto lift basic free symbols when create_graph_input] for overall design
    # You MUST call this API every time when creating a proxy in wrap_fx_proxy for a call
    # that produced symints or tensors with unbacked symint shapes.
    # This function is used to track the symints with its proxies created during
    # dynamo tracing so that subgraph knows how to bind a symbol input with parent's proxy.
    # LazyProxy are created for tensor shapes that're unbacked so that we don't create proxies
    # for symbols that're not going to be used, the LazyProxy will be turned into a proxy
    # when it's lifted as input to subgraph.
    def track_produced_symints(
        self, example_value: Any, e_proxy: Union[LazyProxy, torch.fx.Proxy]
    ) -> None:
        """Record proxies for the basic symbols produced by ``example_value``.

        Bindings go into the ``bound_symbols`` of ``e_proxy``'s tracer, which is
        not necessarily ``self`` (see the comment below).

        - Tensor: each size, the storage offset and (strided layout only) each
          stride accepted by ``need_bind`` gets a ``LazyProxy``. When that
          LazyProxy is materialized, it emits ``aten.sym_size.int`` /
          ``aten.sym_storage_offset`` / ``aten.sym_stride.int`` on ``e_proxy``,
          inserted right after ``e_proxy``'s node.
        - Sparse layouts recurse into their component tensors.
        - Traceable wrapper subclasses recurse into each flattened inner tensor,
          using ``getattr(e_proxy, attr)`` as its proxy.
        - SymInt: ``e_proxy`` itself becomes the binding.
        - Any other type is ignored.

        Args:
            example_value: the value produced by the traced call.
            e_proxy: the proxy (or LazyProxy) that produces ``example_value``.
        """
        # When binding the symbols in an example_value, we bind the symbols
        # to the proxy's associated Tracer instead of current tracer.
        # This is because:
        # 1. We may be calling wrap_tensors during speculate_subgraph because
        # the variables are lazily realized. The proxy are top-level phs but
        # current tracer is a subtracer.
        # 2. For autograd.Function, we trace the backward graph with a new tracer
        # whose parent is the forward tracer, but we're using all the proxies created
        # in forward tracer to trace the backward.
        # For example, forward calls save_for_backward for a input tensor t.
        # Backward calls t.tolist(). In this case, all the proxies that backward tracer
        # sees are from parent tracer (i.e. the forward tracer). (e.g. t[0].item())
        # See test_validate_outputs_unbacked for repro on 2.
        tracer = e_proxy.tracer
        assert isinstance(tracer, SubgraphTracer)

        def need_bind(s: Any) -> bool:
            # True for a symbolic value whose expr is a plain sympy.Symbol that
            # is not yet in self.bound_symbols.
            # NOTE(review): membership is checked against self.bound_symbols,
            # but the binding below is written to tracer.bound_symbols; the two
            # differ whenever tracer is not self -- confirm this is intended.
            from torch.fx.experimental.symbolic_shapes import is_symbolic

            return (
                is_symbolic(s)
                and isinstance(s.node.expr, sympy.Symbol)
                and s.node.expr not in self.bound_symbols
            )

        def _proxy_with_example_value(
            example_value: Any, *args: Any, **kwargs: Any
        ) -> fx.Proxy:
            # We need to insert proxy for creating sym_size/sym_stride/sym_storage right after e_proxy
            nonlocal e_proxy
            # Materialize e_proxy if it is itself lazy (rebinds the enclosing name).
            e_proxy = e_proxy() if isinstance(e_proxy, LazyProxy) else e_proxy
            assert isinstance(e_proxy, torch.fx.Proxy)
            with tracer.graph.inserting_after(e_proxy.node):
                proxy = tracer.create_proxy(*args, **kwargs)
                set_example_value(proxy.node, example_value)
                return proxy

        if isinstance(example_value, torch.Tensor):
            for i, s in enumerate(example_value.size()):
                if need_bind(s):
                    log.debug(
                        "track_produced_symints %s for %s.size()[%s] at debug_level %s",
                        s,
                        e_proxy,
                        i,
                        tracer.debug_level,
                    )
                    lazy_proxy = LazyProxy(
                        tracer,
                        _proxy_with_example_value,
                        s,
                        "call_function",
                        torch.ops.aten.sym_size.int,
                        (e_proxy, i),
                        {},
                        type_expr=type(s),
                    )
                    self.track_produced_symints(s, lazy_proxy)

            storage_offset = example_value.storage_offset()
            if need_bind(storage_offset):
                log.debug(
                    "track_produced_symints %s for %s.storage_offset() at debug_level %s",
                    storage_offset,
                    e_proxy,
                    tracer.debug_level,
                )
                lazy_proxy = LazyProxy(
                    tracer,
                    _proxy_with_example_value,
                    storage_offset,
                    "call_function",
                    torch.ops.aten.sym_storage_offset,
                    (e_proxy,),
                    {},
                    type_expr=type(storage_offset),
                )
                self.track_produced_symints(storage_offset, lazy_proxy)

            if example_value.layout is torch.strided:
                for i, s in enumerate(example_value.stride()):
                    if need_bind(s):
                        log.debug(
                            "track_produced_symints %s for %s.stride()[%s] at debug_level %s",
                            s,
                            e_proxy,
                            i,
                            tracer.debug_level,
                        )
                        lazy_proxy = LazyProxy(
                            tracer,
                            _proxy_with_example_value,
                            s,
                            "call_function",
                            torch.ops.aten.sym_stride.int,
                            (e_proxy, i),
                            {},
                            type_expr=type(s),
                        )
                        self.track_produced_symints(s, lazy_proxy)
            # NOTE(review): the sparse branches below pass the sparse tensor's
            # own e_proxy for its component tensors, so a symbol found in e.g.
            # _indices().size(i) would be bound to sym_size(e_proxy, i) of the
            # sparse tensor rather than of the component -- confirm.
            elif example_value.layout is torch.sparse_coo:
                self.track_produced_symints(example_value._indices(), e_proxy)
                self.track_produced_symints(example_value._values(), e_proxy)
            elif example_value.layout in {torch.sparse_csr, torch.sparse_bsr}:
                self.track_produced_symints(example_value.crow_indices(), e_proxy)
                self.track_produced_symints(example_value.col_indices(), e_proxy)
            elif example_value.layout in {torch.sparse_csc, torch.sparse_bsc}:
                self.track_produced_symints(example_value.ccol_indices(), e_proxy)
                self.track_produced_symints(example_value.row_indices(), e_proxy)
            if is_traceable_wrapper_subclass(example_value):
                # Only the attribute names are needed; ctx is unused here.
                attrs, ctx = example_value.__tensor_flatten__()
                for attr in attrs:
                    inner_t = getattr(example_value, attr)
                    self.track_produced_symints(inner_t, getattr(e_proxy, attr))
        elif isinstance(example_value, torch.SymInt):
            if need_bind(example_value):
                expr = example_value.node.expr
                tracer.bound_symbols[expr] = e_proxy
    # See Note [Auto lift basic free symbols when create_graph_input]
    def _lift_basic_symbols(
        self, example_value: Union[torch.SymInt, torch.Tensor], src: Optional[Source]
    ) -> None:
        """Lift the not-yet-bound basic symbols of ``example_value`` as graph inputs.

        - Tensor: visits the sizes, then (strided layout only) the strides and
          the storage offset. Each gets a ``TensorPropertySource`` derived from
          ``src`` when ``src`` is given.
        - Sparse layouts recurse into their component tensors with the same
          ``src``.
        - Traceable wrapper subclasses recurse into each flattened inner tensor
          with an ``AttrSource``.
        - SymInt: lifts the symbols of the SymInt itself.

        On a subgraph tracer, the symbols are first lifted in the parent (which
        recurses up to the root). Then a placeholder carrying the parent's
        example value is created here and recorded in ``self.lifted_freevars``
        against the parent's proxy. On the root tracer, exactly one unbound
        symbol and a non-None source are required; the new placeholder gets a
        ``GraphArg`` in its meta. Binding into ``bound_symbols`` happens inside
        ``create_graph_input`` when the placeholder's SymInt expr is a plain
        ``sympy.Symbol``.

        Args:
            example_value: tensor or SymInt whose symbols should be lifted.
            src: source of ``example_value``; may be None inside subgraphs.
        """

        # The before arg is for inserting symints in the sizes/strides of a tensor
        # before the tensor. This ordering ensures that when we look at the tensor's
        # symbols, they're already lifted/tracked. E.g. this assumption is used
        # in insert_deferred_runtime_asserts.
        def _lift_symbols_in_symint(
            s: Union[int, torch.SymInt],
            source: Optional[Source],
            before: bool = False,
        ) -> None:
            # Plain ints carry no symbols.
            if not is_symbolic(s):
                return

            assert isinstance(s, torch.SymInt)
            self_to_be_bound = self.lookup_unbound_symbols(s)
            if len(self_to_be_bound) == 0:
                return

            # For subgraph
            if self.parent is not None:
                # Recursively lift symbols in symint until top-level.
                self.parent._lift_basic_symbols(s, source)
                for s0 in self_to_be_bound:
                    # After the parent lift above, the parent has a binding
                    # for s0 (its LazyProxies are materialized by its own
                    # lookup_unbound_symbols call).
                    parent_proxy = self.parent.bound_symbols[s0]
                    example_val = parent_proxy.node.meta["example_value"]  # type: ignore[union-attr]
                    assert isinstance(example_val, torch.SymInt)
                    ph = self.create_graph_input(
                        str(s0),
                        type(example_val),
                        example_val,
                        before=before,
                        source=source,
                    )
                    log.debug(
                        "_lift_symbols_in_symint %s from %s at debug_level %s",
                        s0,
                        source.name() if source is not None else "subgraph inputs",
                        self.debug_level,
                    )
                    self.lifted_freevars[parent_proxy] = ph  # type: ignore[index]
            # For root_tracer:
            else:
                assert len(self_to_be_bound) == 1, (
                    f"For root tracer, we only expect to bind basic symbols (compound symbols "
                    f"should be cached before) but got unbound symbols {self_to_be_bound} in {s}"
                )
                # NOTE(review): this message mentions lazy_bind_unbacked_symbols,
                # which may be a stale name for track_produced_symints -- confirm.
                assert source is not None, (
                    f"Source of '{s}' is None when lifting it to input of top-level. If it's an unbacked symbol, "
                    "this could be because it's not tracked with lazy_bind_unbacked_symbols. "
                    f"Otherwise, should provide a source when create_graph_input for `{s}` at root tracer."
                )
                s0 = next(iter(self_to_be_bound))
                ph = self.create_graph_input(
                    str(s0),
                    type(s),
                    s,
                    before=before,
                    source=source,
                )
                # source is asserted non-None above, so the fallback text in
                # this log call is never used on the root tracer.
                log.debug(
                    "_lift_symbols_in_symint %s from %s at debug_level %s",
                    s,
                    source.name() if source is not None else "subgraph inputs",
                    self.debug_level,
                )
                ph.node.meta["grapharg"] = GraphArg(
                    source,
                    s,
                    pass_arg_as_tensor=False,
                    fake_tensor=None,
                    is_tensor=False,
                )

        if isinstance(example_value, torch.Tensor):
            for i, s in enumerate(example_value.size()):
                _lift_symbols_in_symint(
                    s,
                    (
                        TensorPropertySource(src, TensorProperty.SIZE, i)
                        if src is not None
                        else None
                    ),
                    before=True,
                )
            if example_value.layout is torch.strided:
                for i, s in enumerate(example_value.stride()):
                    _lift_symbols_in_symint(
                        s,
                        (
                            TensorPropertySource(src, TensorProperty.STRIDE, i)
                            if src is not None
                            else None
                        ),
                        before=True,
                    )
                _lift_symbols_in_symint(
                    example_value.storage_offset(),
                    (
                        TensorPropertySource(src, TensorProperty.STORAGE_OFFSET)
                        if src is not None
                        else None
                    ),
                    before=True,
                )
            elif example_value.layout is torch.sparse_coo:
                self._lift_basic_symbols(example_value._indices(), src)
                self._lift_basic_symbols(example_value._values(), src)
            elif example_value.layout in {torch.sparse_csr, torch.sparse_bsr}:
                self._lift_basic_symbols(example_value.crow_indices(), src)
                self._lift_basic_symbols(example_value.col_indices(), src)
            elif example_value.layout in {torch.sparse_csc, torch.sparse_bsc}:
                self._lift_basic_symbols(example_value.ccol_indices(), src)
                self._lift_basic_symbols(example_value.row_indices(), src)
            if is_traceable_wrapper_subclass(example_value):
                # Only the attribute names are needed; ctx is unused here.
                attrs, ctx = example_value.__tensor_flatten__()
                for attr in attrs:
                    inner_t = getattr(example_value, attr)
                    self._lift_basic_symbols(
                        inner_t, AttrSource(src, attr) if src is not None else None
                    )
        elif isinstance(example_value, torch.SymInt):
            _lift_symbols_in_symint(
                example_value,
                src,
            )
  3027. # Lookup the proxy in current tracer for each symbol in expressions of s,
  3028. # See Note [Auto lift basic free symbols when create_graph_input]
  3029. def lookup_unbound_symbols(self, s: torch.SymInt) -> list[sympy.Symbol]:
  3030. free_symbols = s.node.expr.free_symbols
  3031. if len(free_symbols) == 0:
  3032. return []
  3033. to_be_bound = []
  3034. for s0 in free_symbols:
  3035. if s0 not in self.bound_symbols:
  3036. to_be_bound.append(s0)
  3037. continue
  3038. proxy = self.bound_symbols[s0]
  3039. if isinstance(proxy, LazyProxy):
  3040. proxy = proxy()
  3041. self.bound_symbols[s0] = proxy
  3042. assert isinstance(proxy, torch.fx.Proxy) and proxy.tracer is self, (
  3043. f"The proxy of symbol {s0} doesn't belong to current tracer."
  3044. )
  3045. # Sort the symbols so that we can have a deterministic lifting order
  3046. return sorted(to_be_bound, key=lambda s: s.name)
  3047. def has_input_mutation(self) -> MutationInfo:
  3048. input_versions_at_beginning = self._input_versions_at_beginning
  3049. input_nodes = []
  3050. input_versions_at_end = []
  3051. for node in self.graph.nodes:
  3052. if node.op == "placeholder":
  3053. example_value = node.meta["example_value"]
  3054. if isinstance(example_value, torch.Tensor):
  3055. input_versions_at_end.append(example_value._version)
  3056. input_nodes.append(node)
  3057. else:
  3058. break
  3059. mutated_inputs = [
  3060. i
  3061. for i, (v1, v2) in enumerate(
  3062. zip(input_versions_at_beginning, input_versions_at_end)
  3063. )
  3064. if v1 != v2
  3065. ]
  3066. if len(mutated_inputs):
  3067. mutated_nodes = [input_nodes[i] for i in mutated_inputs]
  3068. msg = f"Input mutation detected at {mutated_nodes}"
  3069. return MutationInfo(True, msg)
  3070. return MutationInfo(False, "")
  3071. def has_aliasing(self) -> AliasingInfo:
  3072. from torch._higher_order_ops.utils import _collect_fake_inputs
  3073. input_storages: dict[StorageWeakRef, torch.fx.Node] = dict()
  3074. for node in self.graph.nodes:
  3075. if node.op == "placeholder":
  3076. example_value = _collect_fake_inputs([node])[0]
  3077. if isinstance(example_value, torch.Tensor):
  3078. storage = StorageWeakRef(example_value._typed_storage())
  3079. if storage in input_storages:
  3080. # input-input aliasing
  3081. msg = f"Input-to-input aliasing detected at nodes {input_storages[storage]} and {node}"
  3082. return AliasingInfo(True, msg)
  3083. input_storages[storage] = node
  3084. else:
  3085. break
  3086. output_storages: dict[StorageWeakRef, torch.fx.Node] = dict()
  3087. out_nodes = self.graph.find_nodes(op="output")[0]
  3088. for out_node in pytree.tree_leaves(out_nodes.args[0]):
  3089. if out_node:
  3090. example_value = _collect_fake_inputs([out_node])[0]
  3091. assert not isinstance(example_value, list)
  3092. if isinstance(example_value, torch.Tensor):
  3093. storage = StorageWeakRef(example_value._typed_storage())
  3094. if storage in output_storages:
  3095. # output-output aliasing
  3096. msg = f"Output-to-output aliasing detected at nodes {output_storages[storage]} and {out_node}"
  3097. return AliasingInfo(True, msg)
  3098. output_storages[storage] = out_node
  3099. intersected_storages = input_storages.keys() & output_storages.keys()
  3100. if len(intersected_storages) > 0:
  3101. # input-output aliasing
  3102. aliased = [
  3103. (input_storages[s], output_storages[s]) for s in intersected_storages
  3104. ]
  3105. aliased = ", ".join([f"{i} and {o}" for i, o in aliased])
  3106. msg = f"Input-to-output aliasing detected at nodes {aliased}"
  3107. return AliasingInfo(True, msg)
  3108. return AliasingInfo(False, "")
  3109. # NOTE: [HigherOrderOperator tracing design]
  3110. # Ignoring HigherOrderOperators for a moment,
  3111. # OutputGraph represents the graph being built by Dynamo that may be compiled
  3112. # and executed. It holds a root SubgraphTracer where the FX graph is built.
  3113. #
  3114. # HigherOrderOperators are operators that take functions as their arguments.
  3115. # When Dynamo encounters a HigherOrderOperator, then it attempts to introspect
  3116. # the function passed to it (call this the "body function"), capture it into a
  3117. # GraphModule, and rewrite the call to the HigherOrderOperator to use the
  3118. # GraphModule.
  3119. #
  3120. # The way we handle the capture of body functions is through having
  3121. # (possibly nested) SubgraphTracers, one per body function.
  3122. #
  3123. # Mechanically, we do the introspection by:
  3124. # - Creating a new SubgraphTracer via OutputGraph.subtracer
  3125. # - Executing the body function.
  3126. # This constructs the graph of the body function in the new SubgraphTracer
  3127. # while modifying the state of the OutputGraph. For example:
  3128. # - the OutputGraph can receive new GraphArgs (if we discover any new
  3129. # untracked Tensors)
  3130. # - side effects from the body function get accumulated into
  3131. # OutputGraph.side_effects
  3132. # - guards produced by the body function get accumulated into OutputGraph.guards
  3133. #
  3134. # The traced function has some special properties that make it easier for us
  3135. # to transform later down the line:
  3136. # - we lift all free variables to being inputs.
  3137. #
  3138. # If the introspection fails (due to the existence of graph breaks), then
  3139. # we roll back the current OutputGraph state and graph break on the
  3140. # HigherOrderOperator.