eval_frame.py 90 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156115711581159116011611162116311641165116611671168116911701171117211731174117511761177117811791180118111821183118411851186118711881189119011911192119311941195119611971198119912001201120212031204120512061207120812091210121112121213121412151216121712181219122012211222122312241225122612271228122912301231123212331234123512361237123812391240124112421243124412451246124712481249125012511252125312541255125612571258125912601261126212631264126512661267126812691270127112721273127412751276127
71278127912801281128212831284128512861287128812891290129112921293129412951296129712981299130013011302130313041305130613071308130913101311131213131314131513161317131813191320132113221323132413251326132713281329133013311332133313341335133613371338133913401341134213431344134513461347134813491350135113521353135413551356135713581359136013611362136313641365136613671368136913701371137213731374137513761377137813791380138113821383138413851386138713881389139013911392139313941395139613971398139914001401140214031404140514061407140814091410141114121413141414151416141714181419142014211422142314241425142614271428142914301431143214331434143514361437143814391440144114421443144414451446144714481449145014511452145314541455145614571458145914601461146214631464146514661467146814691470147114721473147414751476147714781479148014811482148314841485148614871488148914901491149214931494149514961497149814991500150115021503150415051506150715081509151015111512151315141515151615171518151915201521152215231524152515261527152815291530153115321533153415351536153715381539154015411542154315441545154615471548154915501551155215531554155515561557155815591560156115621563156415651566156715681569157015711572157315741575157615771578157915801581158215831584158515861587158815891590159115921593159415951596159715981599160016011602160316041605160616071608160916101611161216131614161516161617161816191620162116221623162416251626162716281629163016311632163316341635163616371638163916401641164216431644164516461647164816491650165116521653165416551656165716581659166016611662166316641665166616671668166916701671167216731674167516761677167816791680168116821683168416851686168716881689169016911692169316941695169616971698169917001701170217031704170517061707170817091710171117121713171417151716171717181719172017211722172317241725172617271728172917301731173217331734173517361737173817391740174117421743174417451746174717481749175017511752175317541755175617571758175917601761176217631764176517661767176817691770177117721773177417751776177
71778177917801781178217831784178517861787178817891790179117921793179417951796179717981799180018011802180318041805180618071808180918101811181218131814181518161817181818191820182118221823182418251826182718281829183018311832183318341835183618371838183918401841184218431844184518461847184818491850185118521853185418551856185718581859186018611862186318641865186618671868186918701871187218731874187518761877187818791880188118821883188418851886188718881889189018911892189318941895189618971898189919001901190219031904190519061907190819091910191119121913191419151916191719181919192019211922192319241925192619271928192919301931193219331934193519361937193819391940194119421943194419451946194719481949195019511952195319541955195619571958195919601961196219631964196519661967196819691970197119721973197419751976197719781979198019811982198319841985198619871988198919901991199219931994199519961997199819992000200120022003200420052006200720082009201020112012201320142015201620172018201920202021202220232024202520262027202820292030203120322033203420352036203720382039204020412042204320442045204620472048204920502051205220532054205520562057205820592060206120622063206420652066206720682069207020712072207320742075207620772078207920802081208220832084208520862087208820892090209120922093209420952096209720982099210021012102210321042105210621072108210921102111211221132114211521162117211821192120212121222123212421252126212721282129213021312132213321342135213621372138213921402141214221432144214521462147214821492150215121522153215421552156215721582159216021612162216321642165216621672168216921702171217221732174217521762177217821792180218121822183218421852186218721882189219021912192219321942195219621972198219922002201220222032204220522062207220822092210221122122213221422152216221722182219222022212222222322242225222622272228222922302231223222332234223522362237223822392240224122422243224422452246224722482249225022512252225322542255225622572258225922602261226222632264226522662267226822692270227122722273227422752276227
72278227922802281228222832284228522862287228822892290229122922293229422952296229722982299230023012302230323042305230623072308230923102311231223132314231523162317231823192320232123222323232423252326232723282329233023312332233323342335233623372338233923402341234223432344234523462347234823492350
  1. # mypy: disable-error-code="method-assign"
  2. """
  3. This module implements the core frame evaluation handler for TorchDynamo's compilation system.
  4. The eval frame handler intercepts Python bytecode execution at runtime to enable dynamic
  5. compilation and optimization of PyTorch code.
  6. Key components defined here:
  7. - Frame evaluation handlers that intercept and analyze Python execution frames
  8. - Guards management for tracking dependencies and invalidating compiled code
  9. - Optimization contexts and decorators (optimize, run_once, disable, etc.)
  10. - Export functionality for saving optimized graphs
  11. - Backend compiler integrations and callback management
  12. Functions in this file are responsible for modifying the eval frame handler at RUNTIME.
  13. Therefore, all functions in this file are hot and performance-critical. Functions that
  14. only execute at compile time should be placed in torch._dynamo.convert_frame.
  15. The eval frame handler is the core mechanism that enables TorchDynamo to dynamically
  16. intercept, analyze and optimize PyTorch code during execution. It works by registering
  17. a custom frame evaluation function that gets called for every Python frame, allowing
  18. us to detect PyTorch operations and trigger compilation as needed.
  19. """
  20. from __future__ import annotations
  21. import atexit
  22. import contextlib
  23. import functools
  24. import inspect
  25. import logging
  26. import os
  27. import sys
  28. import sysconfig
  29. import textwrap
  30. import threading
  31. import traceback
  32. import types
  33. import unittest
  34. import warnings
  35. import weakref
  36. from dataclasses import dataclass
  37. from enum import Enum
  38. from os.path import dirname, join
  39. from typing import Any, Callable, NamedTuple, Optional, TYPE_CHECKING, Union
  40. from unittest.mock import patch
  41. import sympy
  42. import torch
  43. import torch.fx
  44. import torch.utils._pytree as pytree
  45. import torch.utils.checkpoint
  46. from torch import _guards
  47. # see discussion at https://github.com/pytorch/pytorch/issues/120699
  48. from torch._C._dynamo.eval_frame import ( # noqa: F401
  49. reset_code,
  50. set_code_exec_strategy,
  51. set_eval_frame,
  52. set_guard_complete_hook,
  53. set_guard_error_hook,
  54. set_skip_guard_eval_unsafe,
  55. unsupported,
  56. )
  57. from torch._dispatch.python import enable_python_dispatcher
  58. from torch._dynamo.types import ConvertFrameReturn, FrameAction, FrameExecStrategy
  59. from torch._export.utils import _compiling_state_context
  60. from torch._subclasses.fake_tensor import unset_fake_temporarily
  61. from torch._utils_internal import justknobs_check, log_export_usage
  62. from torch.export.dynamic_shapes import (
  63. _combine_args,
  64. _DimHint,
  65. _DimHintType,
  66. _IntWrapper,
  67. _process_dynamic_shapes,
  68. _RelaxedConstraint,
  69. Constraint,
  70. )
  71. from torch.fx import GraphModule
  72. from torch.fx.experimental._dynamism import (
  73. clone_and_convert_to_meta,
  74. track_dynamism_across_examples,
  75. )
  76. from torch.fx.experimental.proxy_tensor import make_fx
  77. from torch.fx.experimental.symbolic_shapes import (
  78. ConstraintViolationError,
  79. DimDynamic,
  80. ShapeEnv,
  81. StatelessSymbolicContext,
  82. )
  83. from torch.fx.graph import _PyTreeCodeGen, _PyTreeInfo
  84. from . import config, convert_frame, distributed, external_utils, trace_rules, utils
  85. from .backends.registry import CompilerFn, lookup_backend
  86. from .code_context import code_context
  87. from .exc import (
  88. CondOpArgsMismatchError,
  89. ShortenTraceback,
  90. Unsupported,
  91. UserError,
  92. UserErrorType,
  93. )
  94. from .hooks import Hooks
  95. from .mutation_guard import install_generation_tagging_init
  96. from .utils import (
  97. _get_error_on_graph_break,
  98. _set_error_on_graph_break,
  99. common_constant_types,
  100. compile_times,
  101. )
  102. if TYPE_CHECKING:
  103. from collections.abc import Iterable, Sequence
  104. from torch._dynamo.package import CompilePackage
  105. from torch._dynamo.repro.after_dynamo import WrapBackendDebug
  106. from torch._subclasses import fake_tensor
  107. from torch.fx.node import Argument, Node, Target
  108. from .types import (
  109. CacheEntry,
  110. DynamoCallback,
  111. DynamoFrameType,
  112. GuardFail,
  113. GuardFilterEntry,
  114. )
  115. log = logging.getLogger(__name__)
  116. always_optimize_code_objects = utils.ExactWeakKeyDictionary()
  117. null_context = contextlib.nullcontext
  118. # See https://github.com/python/typing/pull/240
  119. class Unset(Enum):
  120. token = 0
  121. cached_backends: dict[int, CompilerFn] = {}
  122. unset = Unset.token
  123. def _maybe_set_eval_frame(callback: DynamoCallback) -> DynamoCallback:
  124. # A wrapper on set_eval_frame that is guarded by a Justknob.
  125. # Users can disable torchDynamo by setting the JK to False.
  126. if not justknobs_check("pytorch/compiler:enable_compiler_set_eval_frame"):
  127. torch._dynamo.utils.warn_once(
  128. "Dynamo disabled by Justknob: enable_compiler_set_eval_frame, skipping set_eval_frame"
  129. )
  130. return callback
  131. else:
  132. return set_eval_frame(callback)
  133. @dataclass
  134. class DynamoStance:
  135. stance: str = "default"
  136. skip_guard_eval_unsafe: bool = False
  137. backend: Union[str, Callable[..., Any], None] = None
  138. _stance = DynamoStance()
  139. def _set_stance(stance: DynamoStance) -> DynamoStance:
  140. global _stance
  141. from torch._C._dynamo.eval_frame import get_eval_frame_callback
  142. callback = get_eval_frame_callback()
  143. if callback is not False and callback is not None:
  144. raise RuntimeError("attempted to set_stance in a torch.compile region")
  145. prior = _stance
  146. _stance = stance
  147. return prior
  148. _set_stance._dynamo_forbidden = True # type: ignore[attr-defined]
  149. _EXAMPLE_INPUTS: Optional[dict[str, list[Any]]] = None
  150. def get_example_inputs(key: str) -> list[Any]:
  151. global _EXAMPLE_INPUTS
  152. if _EXAMPLE_INPUTS is None:
  153. _EXAMPLE_INPUTS = {}
  154. if key not in _EXAMPLE_INPUTS:
  155. _EXAMPLE_INPUTS[key] = []
  156. return _EXAMPLE_INPUTS[key]
  157. def _callback_from_stance(callback: DynamoCallback) -> DynamoCallback:
  158. if _stance.stance == "default":
  159. # force_backend
  160. if _stance.backend is not None and callback not in (False, None):
  161. callback = _create_wrapped_callback(get_compiler_fn(_stance.backend))
  162. return callback
  163. elif _stance.stance == "eager_then_compile":
  164. if callback not in (False, None):
  165. return _create_delayed_compile_callback(callback, _stance.stance)
  166. return callback
  167. elif _stance.stance == "aot_eager_then_compile":
  168. if callback not in (False, None):
  169. return _create_delayed_compile_callback(callback, _stance.stance)
  170. return callback
  171. elif _stance.stance == "force_eager":
  172. # disable
  173. return None
  174. elif _stance.stance == "eager_on_recompile":
  175. # run mode
  176. return False
  177. elif _stance.stance == "fail_on_recompile":
  178. if callback in (False, None):
  179. return callback
  180. def fail_callback(
  181. frame: DynamoFrameType, *args: Any, **kwargs: Any
  182. ) -> ConvertFrameReturn:
  183. if trace_rules.check(frame.f_code):
  184. return ConvertFrameReturn()
  185. if not convert_frame.has_tensor_in_frame(frame):
  186. return ConvertFrameReturn()
  187. from torch._C._dynamo.eval_frame import _debug_get_precompile_entries
  188. message = (
  189. "Detected recompile when torch.compile stance is 'fail_on_recompile'. "
  190. + f"filename: '{frame.f_code.co_filename}', "
  191. + f"function name: '{frame.f_code.co_name}', "
  192. + f"line number: {frame.f_lineno}"
  193. )
  194. precompile_entries = _debug_get_precompile_entries(frame.f_code)
  195. if len(precompile_entries) > 0:
  196. message += "\nFailed on the following precompiled guards: "
  197. for entry in precompile_entries:
  198. message += f"\n{entry.guard_manager}{entry.guard_manager.check_verbose(frame.f_locals)}" # type: ignore[attr-defined]
  199. raise RuntimeError(message)
  200. # to prevent cache miss due to different backend
  201. fail_callback._torchdynamo_orig_backend = callback # type: ignore[attr-defined]
  202. return fail_callback
  203. else:
  204. raise RuntimeError(f"invalid torch.compile stance '{_stance}'")
  205. def _create_wrapped_callback(
  206. compiler_fn: CompilerFn,
  207. ) -> convert_frame.CatchErrorsWrapper:
  208. hooks = Hooks()
  209. return convert_frame.catch_errors_wrapper(
  210. convert_frame.convert_frame( # type: ignore[arg-type]
  211. compiler_fn,
  212. hooks,
  213. ),
  214. hooks,
  215. )
  216. def _get_or_add_example_inputs(frame: DynamoFrameType) -> list[Any]:
  217. key = frame.f_code.co_filename + str(frame.f_code.co_firstlineno)
  218. example_inputs = get_example_inputs(key)
  219. if len(example_inputs) < 2:
  220. example_inputs.append(clone_and_convert_to_meta(frame.f_locals))
  221. return example_inputs
  222. def _create_delayed_compile_callback(
  223. callback: DynamoCallback, stance: str
  224. ) -> Callable[..., Any]:
  225. def callback_fn(*args: Any, **kwargs: Any) -> convert_frame.ConvertFrameReturn:
  226. frame = args[0]
  227. example_inputs = _get_or_add_example_inputs(frame)
  228. if len(example_inputs) == 1:
  229. if stance == "eager_then_compile":
  230. return ConvertFrameReturn(
  231. frame_exec_strategy=FrameExecStrategy(
  232. FrameAction.DEFAULT, FrameAction.DEFAULT
  233. )
  234. )
  235. elif stance == "aot_eager_then_compile":
  236. aot_eager_fn = get_compiler_fn("aot_eager")
  237. return _create_wrapped_callback(aot_eager_fn)(*args, **kwargs)
  238. dynamism = track_dynamism_across_examples(example_inputs)
  239. code_context.get_context(frame.f_code)["dynamism"] = dynamism
  240. compiler_fn = callback._torchdynamo_orig_backend._torchdynamo_orig_backend # type: ignore[union-attr]
  241. return _create_wrapped_callback(compiler_fn)(*args, **kwargs)
  242. # to prevent cache miss due to different backend
  243. callback_fn._torchdynamo_orig_backend = callback # type: ignore[attr-defined]
  244. return callback_fn
  245. def _is_skip_guard_eval_unsafe_stance() -> bool:
  246. return _stance.skip_guard_eval_unsafe
  247. def _reset_guarded_backend_cache() -> None:
  248. global cached_backends
  249. for backend in cached_backends.values():
  250. if hasattr(backend, "reset"):
  251. backend.reset()
  252. cached_backends.clear()
# Source files whose frames are excluded from wrapping.
# NOTE(review): the consumer of this set lives outside this chunk; confirm the
# exact meaning of "wrap" there.
DONT_WRAP_FILES = {
    # For tracing into fx modules
    inspect.getsourcefile(GraphModule),
    # Path relative to the directory two levels above this file (presumably the
    # ``torch`` package root). The forward slashes are kept verbatim on purpose:
    # membership tests compare exact strings.
    join(dirname(dirname(__file__)), "onnx/_internal/fx/dynamo_graph_extractor.py"),
}
  258. def _debug_get_cache_entry_list(
  259. code: Union[types.CodeType, Callable[..., Any]],
  260. ) -> list[CacheEntry]:
  261. """
  262. Given a code object or a callable object, retrieve the cache entries
  263. stored in this code.
  264. """
  265. if callable(code):
  266. code = code.__code__
  267. return torch._C._dynamo.eval_frame._debug_get_cache_entry_list(code)
class OptimizedModule(torch.nn.Module):
    """
    Wraps the original nn.Module object and later patches its
    forward method to optimized self.forward method.

    The wrapped module is kept as ``_orig_mod``. Attribute reads that fail on
    the wrapper fall through to ``_orig_mod`` (``__getattr__``). Attribute
    writes and deletes are forwarded to ``_orig_mod`` unless the name exists on
    the class or is listed in ``_opt_mod_attributes`` (``__setattr__`` /
    ``__delattr__``). ``training`` is mirrored from ``_orig_mod`` via a property.
    """

    # Declared for type checkers only; no value is assigned in this class body.
    _torchdynamo_orig_callable: Callable[..., Any]
    get_compiler_config: Callable[[], Any]

    # Names stored on the wrapper itself instead of being forwarded to
    # ``_orig_mod`` by ``__setattr__`` / ``__delattr__``.
    _opt_mod_attributes = {
        "_orig_mod",
        "dynamo_ctx",
        "_torchdynamo_orig_callable",
        "get_compiler_config",
        "forward",
        "_forward",
        "__dict__",
        "named_children_walk",
        "_super_module_initialized",
    }

    def __init__(self, mod: torch.nn.Module, dynamo_ctx: _TorchDynamoContext) -> None:
        """Wrap ``mod`` so that calling this module goes through ``dynamo_ctx``.

        Args:
            mod: the module being optimized; it becomes ``self._orig_mod``.
            dynamo_ctx: the Dynamo context applied to ``mod``'s call path by
                ``_initialize``.
        """
        # NOTE: this must go first, because attribute reads/writes of `self`
        # uses `_orig_mod`, and sometimes users override `Module.__init__` to
        # do attribute reads/writes on `self`.
        #
        # We also can't use regular setattr because `super().__setattr__` will
        # complain for module value before `super().__init__()`
        object.__setattr__(self, "_orig_mod", mod)
        # While False, the ``training`` setter ignores writes (nn.Module.__init__
        # assigns ``training``).
        self._super_module_initialized = False
        super().__init__()
        self._super_module_initialized = True

        # Installs the params/buffer
        self._orig_mod = mod  # `super().__setattr__` will register this module
        self.dynamo_ctx = dynamo_ctx
        self._initialize()
        # Goes through the ``training`` property setter with the value read from
        # ``_orig_mod`` itself, so the wrapped module's mode is left unchanged.
        self.training = self._orig_mod.training

    def _initialize(self) -> None:
        """(Re)build ``self.forward`` by wrapping ``_orig_mod`` with ``dynamo_ctx``.

        Called from ``__init__`` and again from ``__setstate__``.
        """
        # Do this stuff in constructor to lower overhead slightly
        if isinstance(self.dynamo_ctx, DisableContext):
            # No need to check trace rules
            self.forward = self.dynamo_ctx(self._orig_mod.__call__)
        elif config.wrap_top_frame or (
            isinstance(self._orig_mod.forward, types.MethodType)
            and (
                trace_rules.check(self._orig_mod.forward)
                or getattr(self._orig_mod, "_is_fsdp_managed_module", False)
            )
        ):
            # This may be a torch.nn.* instance in trace_rules.py which
            # won't trigger a frame evaluation workaround to add an extra
            # frame we can capture
            self.forward = self.dynamo_ctx(external_utils.wrap_inline(self._orig_mod))
        else:
            # Invoke hooks outside of dynamo then pickup the inner frame
            self.forward = self.dynamo_ctx(self._orig_mod.__call__)

        # Modules exposing ``_initialize_hook`` (lazy modules) get routed through
        # ``_call_lazy_check`` first; the compiled callable is kept in ``_forward``.
        if hasattr(self._orig_mod, "_initialize_hook"):
            self._forward = self.forward
            self.forward = self._call_lazy_check

    def __call__(self, *args: Any, **kwargs: Any) -> Any:
        """Call the wrapper, warning first if any global module hooks exist."""
        if torch.nn.modules.module._has_any_global_hook():
            warnings.warn(
                "Using `torch.compile(module)` when there are global hooks on "
                "modules (e.g., from `register_module_forward_hook`); this will"
                " cause the hooks to fire an extra time for the "
                "`OptimizedModule` created by `torch.compile(module)`. If this "
                "causes undesired behavior, please try using `module.compile()`"
                ", or use the per-module hooks instead",
                stacklevel=2,
            )
        return super().__call__(*args, **kwargs)

    def __reduce__(
        self,
    ) -> tuple[type[OptimizedModule], tuple[torch.nn.Module, _TorchDynamoContext]]:
        """Reconstruct from the wrapped module and its dynamo context."""
        return (self.__class__, (self._orig_mod, self.dynamo_ctx))

    def __getstate__(self) -> dict[str, Any]:
        """Return a shallow copy of the instance dict without callable entries.

        ``forward`` (and ``__call__``, if present on the instance) are dropped;
        ``__setstate__`` rebuilds ``forward`` via ``_initialize``.
        """
        state = dict(self.__dict__)
        state.pop("forward", None)
        state.pop("__call__", None)
        return state

    def __setstate__(self, state: dict[str, Any]) -> None:
        """Restore the instance dict, then rebuild ``forward``."""
        self.__dict__ = state
        self._initialize()

    @property
    def training(self) -> bool:
        """Training flag, read from the wrapped module."""
        return self._orig_mod.training

    @training.setter
    def training(self, value: bool) -> None:
        # Ignore the `training` mutation in `super().__init__()`, since that's
        # setting the default on `nn.Module`, but we are mirroring the
        # `training` attr in `self._orig_mod`.
        if self._super_module_initialized:
            self._orig_mod.training = value

    def __getattr__(self, name: str) -> Any:
        """Fallback lookup (only reached when normal lookup fails).

        ``_orig_mod`` is registered in ``self._modules``, so it is fetched from
        there explicitly; every other name is forwarded to ``_orig_mod``.
        """
        if name == "_orig_mod":
            return self._modules["_orig_mod"]
        return getattr(self._orig_mod, name)

    def __setattr__(self, name: str, val: Any) -> None:
        """Store wrapper-owned names locally; forward everything else to ``_orig_mod``."""
        # Allow patching over class attributes
        if hasattr(type(self), name):
            return super().__setattr__(name, val)

        if name in OptimizedModule._opt_mod_attributes:
            return super().__setattr__(name, val)
        return setattr(self._orig_mod, name, val)

    def __delattr__(self, name: str) -> None:
        """Delete wrapper-owned names locally; forward everything else to ``_orig_mod``."""
        # This mirrors `__setattr__`
        if hasattr(type(self), name):
            return super().__delattr__(name)

        if name in OptimizedModule._opt_mod_attributes:
            return super().__delattr__(name)
        return delattr(self._orig_mod, name)

    def _call_lazy_check(self, *args: Any, **kwargs: Any) -> Any:
        """Run lazy-module parameter inference if still pending, then ``_forward``."""
        if (
            hasattr(self._orig_mod, "_initialize_hook")
            and hasattr(self._orig_mod, "_infer_parameters")
            and callable(self._orig_mod._infer_parameters)
        ):
            # In the case of a lazy module, we want to run
            # the pre-hooks which initialize it.
            # Afterwards, lazy module deletes its pre-hooks
            # to avoid treating it as lazy on subsequent recompile.
            self._orig_mod._infer_parameters(self._orig_mod, args, kwargs)
        return self._forward(*args, **kwargs)

    def __dir__(self) -> list[str]:
        """List ``_orig_mod``'s attributes first, then wrapper-only attributes."""
        orig_mod_attrs = self._orig_mod.__dir__()
        return orig_mod_attrs + [
            attr for attr in super().__dir__() if attr not in orig_mod_attrs
        ]
  393. def remove_from_cache(f: Any) -> None:
  394. """
  395. Make sure f.__code__ is not cached to force a recompile
  396. """
  397. if isinstance(f, types.CodeType):
  398. reset_code(f)
  399. elif hasattr(f, "__code__"):
  400. reset_code(f.__code__)
  401. elif hasattr(getattr(f, "forward", None), "__code__"):
  402. reset_code(f.forward.__code__)
  403. else:
  404. from . import reset # type: ignore[attr-defined]
  405. reset()
  406. log.warning("could not determine __code__ for %s", f)
  407. def nothing() -> None:
  408. pass
  409. def always_false() -> bool:
  410. return False
  411. def innermost_fn(
  412. fn: Callable[..., Any], unaltered_fn_attr: str = "_torchdynamo_orig_callable"
  413. ) -> Callable[..., Any]:
  414. """
  415. In case of nesting of _TorchDynamoContext calls, find the innermost
  416. function. TorchDynamo caches on fn.__code__ object, so its necessary to find
  417. the innermost function to pass on the optimize, run, disable etc.
  418. """
  419. unaltered_fn = fn
  420. while hasattr(unaltered_fn, unaltered_fn_attr):
  421. unaltered_fn = getattr(unaltered_fn, unaltered_fn_attr)
  422. assert callable(unaltered_fn), (
  423. f"A callable function is expected, but {type(unaltered_fn)} is provided."
  424. )
  425. return unaltered_fn
  426. def make_set_enable_dynamic(enable: bool) -> Any:
  427. assert isinstance(enable, bool)
  428. if enable:
  429. # Assume everything is dynamic by default
  430. return config._make_closure_patcher(assume_static_by_default=False)
  431. else:
  432. return config._make_closure_patcher(
  433. automatic_dynamic_shapes=False, assume_static_by_default=True
  434. )
  435. # A thread local storage that serves to store information as Dynamo traces
  436. # through a user provided function.
  437. class DynamoTLS(threading.local):
  438. # Each string is a summary of a frame Dynamo attempted to trace, stored in
  439. # temporal order.
  440. traced_frame_infos: list[str] = []
  441. dynamo_tls = DynamoTLS()
  442. def clear_dynamo_tls() -> None:
  443. dynamo_tls.traced_frame_infos.clear()
  444. @atexit.register
  445. def _log_traced_frames() -> None:
  446. """
  447. At program exit, log all of the frames Dynamo has attempted to trace from,
  448. excluding the continuation frames generated by Dynamo.
  449. """
  450. msg = "\n".join(dynamo_tls.traced_frame_infos)
  451. msg = textwrap.indent(msg, " * ")
  452. msg = f"TorchDynamo attempted to trace the following frames: [\n{msg}\n]"
  453. log.info(msg)
  454. def guard_collectives_hook(guard_eval_result: bool) -> bool:
  455. import torch.distributed as dist
  456. from torch._dynamo.utils import dynamo_timed
  457. # guard_eval_result == True ==> cache hit
  458. if pg := distributed.get_guard_pg():
  459. with dynamo_timed(
  460. "guard_collective", log_pt2_compile_event=False, log_waitcounter=True
  461. ):
  462. log.debug("guard_collective %s", guard_eval_result)
  463. # TODO: a bit awkward to time, this isn't inside of the dynamo compile region
  464. all_results = [None] * pg.size()
  465. dist.all_gather_object(all_results, guard_eval_result, group=pg)
  466. # True = everyone hit, OK to run
  467. # False = someone missed, force recompile everywhere
  468. res = all(all_results)
  469. log.debug("guard_collective %s -> %s", guard_eval_result, res)
  470. return res
  471. return guard_eval_result
# Module-private sentinel distinct from every user value (including None).
# NOTE(review): its readers are outside this chunk of the file.
_not_set = object()
  473. class _TorchDynamoContext:
  474. def __init__(
  475. self,
  476. callback: DynamoCallback,
  477. on_enter: Callable[[], Any] = nothing,
  478. backend_ctx_ctor: Callable[
  479. [], contextlib.AbstractContextManager[Any]
  480. ] = null_context,
  481. patch_fn: Callable[[], Any] = nothing,
  482. first_ctx: bool = False,
  483. *,
  484. fullgraph: bool = False,
  485. error_on_graph_break: Optional[bool] = None,
  486. export: bool = False,
  487. dynamic: Optional[bool] = None,
  488. compiler_config: Optional[Any] = None,
  489. package: Optional[CompilePackage] = None,
  490. hooks: Optional[Hooks] = None,
  491. ) -> None:
  492. super().__init__()
  493. assert callable(callback) or callback is False or callback is None
  494. self.callback: DynamoCallback = callback
  495. self._backend_ctx_ctor = backend_ctx_ctor
  496. self.prior: Union[Unset, DynamoCallback] = unset
  497. self.first_ctx = first_ctx
  498. self.fullgraph = fullgraph
  499. self.error_on_graph_break = error_on_graph_break
  500. self.export = export
  501. self._dynamic = dynamic
  502. self.compiler_config = compiler_config
  503. self.cleanup_fns: list[Callable[[], Any]] = []
  504. self.enter_exit_hooks = []
  505. self._package = package
  506. self._hooks = hooks
  507. patch_fn()
  508. # Save the backends so that we can reset them during torch._dynamo.reset
  509. backend = innermost_fn(callback, unaltered_fn_attr="_torchdynamo_orig_backend") # type: ignore[arg-type]
  510. cached_backends.setdefault(id(backend), backend) # type: ignore[arg-type]
  511. if dynamic is not None:
  512. self.enter_exit_hooks.append(make_set_enable_dynamic(dynamic))
  513. if on_enter is not nothing:
  514. # this case is not common
  515. def call_on_enter() -> Callable[[], None]:
  516. on_enter()
  517. return nothing
  518. self.enter_exit_hooks.append(call_on_enter)
  519. if backend_ctx_ctor is not contextlib.nullcontext:
  520. # this case is not common
  521. def call_backend_ctx() -> functools.partial[Optional[bool]]:
  522. ctx = backend_ctx_ctor()
  523. ctx.__enter__()
  524. return functools.partial(ctx.__exit__, None, None, None)
  525. self.enter_exit_hooks.append(call_backend_ctx)
  526. def __enter__(self) -> None:
  527. if config.raise_on_ctx_manager_usage:
  528. raise RuntimeError(
  529. "torch._dynamo.optimize(...) is used with a context manager. "
  530. "Please refer to https://pytorch.org/tutorials/intermediate/torch_compile_tutorial.html "
  531. "to use torch._dynamo.optimize(...) as an annotation/decorator. "
  532. )
  533. self.prior = set_eval_frame(None)
  534. self.cleanup_fns = [enter() for enter in self.enter_exit_hooks]
  535. self.prior_skip_guard_eval_unsafe = set_skip_guard_eval_unsafe(
  536. _is_skip_guard_eval_unsafe_stance()
  537. )
  538. _maybe_set_eval_frame(_callback_from_stance(self.callback))
  539. def __exit__(
  540. self,
  541. exc_type: Optional[type[BaseException]],
  542. exc_val: Optional[BaseException],
  543. exc_tb: Optional[types.TracebackType],
  544. ) -> Optional[bool]:
  545. assert self.prior is not unset
  546. set_eval_frame(None)
  547. set_skip_guard_eval_unsafe(self.prior_skip_guard_eval_unsafe)
  548. for cleanup in self.cleanup_fns:
  549. cleanup()
  550. self.cleanup_fns.clear()
  551. _maybe_set_eval_frame(_callback_from_stance(self.prior))
  552. self.prior = unset
  553. return None
    def __call__(self, fn: Any) -> Any:
        """Apply this Dynamo context to ``fn`` (decorator form).

        Returns:
            * for a ``torch.nn.Module``: a new ``OptimizedModule`` wrapping it;
            * for a class: the same class object, with ``__call__`` (and
              ``_call_impl`` for ``nn.Module`` subclasses) replaced in place by
              wrapped versions;
            * otherwise: ``compile_wrapper``, a ``functools.wraps`` wrapper
              that installs this context's callback around every call.

        Raises:
            AssertionError: if ``fn`` is not callable.
            RuntimeError: if the callback is neither ``None`` nor ``False`` and
                ``fn`` has no ``__code__`` attribute.
        """

        # public api for compiler config/options
        def get_compiler_config() -> Any:
            return self.compiler_config

        from .package import DynamoCache

        # If self._package is lazily initialized, we should check the dynamo cache now
        if config.caching_precompile:
            if self._package is not None and not self._package.is_initialized():
                result = DynamoCache.load(fn)
                if result is None:
                    # Create a fresh CompilePackage
                    self._package.initialize(fn, None, ignore_inlined_sources=False)
                else:
                    cache_entry, backends = result
                    try:
                        self._package.initialize(
                            fn, cache_entry, ignore_inlined_sources=False
                        )
                        self._package.install(backends)
                    except RuntimeError as e:
                        # Fall back to a fresh package when the cached entry
                        # cannot be loaded.
                        log.warning("Failed to load entry from dynamo cache: %s", e)
                        self._package.initialize(fn, None, ignore_inlined_sources=False)

        # Strip any previous Dynamo wrappers to reach the user's callable.
        fn = innermost_fn(fn)

        # NOTE(review): this closure late-binds ``fn``; ``fn`` may be rebound
        # below to ``external_utils.wrap_inline(fn)``, in which case the wrapped
        # callable is what reaches aot_compile_fullgraph. Confirm intended.
        def aot_compile(example_inputs: tuple[tuple[Any, ...], dict[str, Any]]) -> Any:
            from torch._dynamo.aot_compile import aot_compile_fullgraph

            if not self.fullgraph:
                raise RuntimeError(
                    "Graph breaks are not supported with aot compile. Please use torch.compile(fullgraph=True)."
                )

            if not callable(self.callback):
                raise RuntimeError("aot compile requires a callable dynamo callback.")

            assert self._hooks is not None
            return aot_compile_fullgraph(
                fn,
                example_inputs,
                hooks=self._hooks,
                backend=innermost_fn(
                    self.callback, unaltered_fn_attr="_torchdynamo_orig_backend"
                ),
            )

        if isinstance(fn, GraphModule):
            # add context containing GraphModule to any GraphModule forward functions
            # (stored as a weak reference keyed by forward's code object).
            code_context.get_context(fn.forward.__code__)["orig_graphmodule"] = (
                weakref.ref(fn)
            )

        # Optimize the forward method of torch.nn.Module object
        if isinstance(fn, torch.nn.Module):
            mod = fn
            new_mod = OptimizedModule(mod, self)
            # Save the function pointer to find the original callable while nesting
            # of decorators.
            new_mod._torchdynamo_orig_callable = mod.forward

            # when compiling torch.nn.Module,
            # provide public api OptimizedModule.get_compiler_config()
            assert not hasattr(new_mod, "get_compiler_config")
            new_mod.get_compiler_config = get_compiler_config

            return new_mod

        if inspect.isclass(fn):
            # User has wrapped the class with compile/disable decorator. Apply
            # this context to the call method (the class is mutated in place).
            cls_obj = fn
            cls_obj.__call__ = self(cls_obj.__call__)
            if issubclass(cls_obj, torch.nn.Module):
                # NN module variable tracker directly inlines the _call_impl.
                cls_obj._call_impl = self(cls_obj._call_impl)
            return cls_obj

        assert callable(fn), (
            f"A callable function is expected, but {type(fn)} is provided."
        )

        try:
            filename = inspect.getsourcefile(fn)
        except TypeError:
            # e.g. builtins have no source file
            filename = None
        if config.debug_force_nested_calls:
            fn = external_utils.wrap_inline(fn)
        elif config.wrap_top_frame or (
            (filename is None or trace_rules.check(fn))
            and (
                getattr(fn, "__name__", "")
                not in ["_call_impl", "_wrapped_call_impl", "_lazy_forward"]
            )
            and filename not in DONT_WRAP_FILES
        ):
            # call to a builtin without a frame for us to capture
            fn = external_utils.wrap_inline(fn)

        def do_nothing(*arg: Any, **kwargs: Any) -> None:
            pass

        # Fallback callback when this context has no ``callback`` attribute.
        callback: Callable[..., Any] = do_nothing
        if hasattr(self, "callback"):
            callback = self.callback  # type: ignore[assignment]

        # Bind the tracing predicates once instead of looking them up per call.
        is_jit_tracing = torch._C._is_tracing
        is_fx_symbolic_tracing = torch.fx._symbolic_trace.is_fx_symbolic_tracing

        @functools.wraps(fn)
        def compile_wrapper(*args: Any, **kwargs: Any) -> Any:
            # Disable Dynamo while setting up; the outer finally restores
            # whatever callback was active on entry.
            prior = set_eval_frame(None)
            try:
                if is_fx_symbolic_tracing():
                    if config.error_on_nested_fx_trace:
                        raise RuntimeError(
                            "Detected that you are using FX to symbolically trace "
                            "a dynamo-optimized function. This is not supported at the moment."
                        )
                    else:
                        # Run eagerly (Dynamo stays disabled) under FX tracing.
                        return fn(*args, **kwargs)

                if is_jit_tracing():
                    raise RuntimeError(
                        "Detected that you are using FX to torch.jit.trace "
                        "a dynamo-optimized function. This is not supported at the moment."
                    )

                # Each enter hook returns the cleanup run in the inner finally.
                cleanups = [enter() for enter in self.enter_exit_hooks]
                prior_skip_guard_eval_unsafe = set_skip_guard_eval_unsafe(
                    _is_skip_guard_eval_unsafe_stance()
                )
                # Only override error_on_graph_break when not fullgraph and a
                # value was requested; remember the old value to restore.
                prior_error_on_graph_break = None
                if not self.fullgraph and self.error_on_graph_break is not None:
                    prior_error_on_graph_break = _get_error_on_graph_break()
                    _set_error_on_graph_break(self.error_on_graph_break)

                # Ensure that if an assertion occurs after graph pushes
                # something onto the DynamicLayerStack then we pop it off (the
                # constructed graph code isn't guarded with try/finally).
                #
                # This used to be a context but putting a `with` here is a noticeable
                # perf regression (#126293)
                saved_dynamic_layer_stack_depth = (
                    torch._C._functorch.get_dynamic_layer_stack_depth()
                )

                _maybe_set_eval_frame(_callback_from_stance(callback))

                try:
                    return fn(*args, **kwargs)
                except Unsupported as e:
                    if config.verbose:
                        raise
                    # strip internal tracebacks from causes
                    cur_exn: BaseException = e
                    while cur_exn.__cause__ is not None:
                        cur_exn.__cause__.with_traceback(None)
                        cur_exn = cur_exn.__cause__
                    raise e.with_traceback(None) from e.__cause__  # User compiler error
                except ShortenTraceback as e:
                    # Failures in the backend likely don't have useful
                    # data in the TorchDynamo frames, so we strip them out.
                    raise e.remove_dynamo_frames() from None  # see TORCHDYNAMO_VERBOSE=1
                finally:
                    set_eval_frame(None)
                    if prior_error_on_graph_break is not None:
                        _set_error_on_graph_break(prior_error_on_graph_break)
                    # Restore the dynamic layer stack depth if necessary.
                    torch._C._functorch.pop_dynamic_layer_stack_and_undo_to_depth(
                        saved_dynamic_layer_stack_depth
                    )

                    set_skip_guard_eval_unsafe(prior_skip_guard_eval_unsafe)
                    for cleanup in cleanups:
                        cleanup()
            finally:
                _maybe_set_eval_frame(prior)

        # hooks to properly handle inlining
        if self.error_on_graph_break is not None:
            compile_wrapper._torchdynamo_inline = (  # type: ignore[attr-defined]
                external_utils.wrap_inline_with_error_on_graph_break(
                    fn, self.error_on_graph_break
                )
            )
        else:
            compile_wrapper._torchdynamo_inline = fn  # type: ignore[attr-defined]

        # Save the function pointer to find the original callable while nesting
        # of decorators.
        compile_wrapper._torchdynamo_orig_callable = fn  # type: ignore[attr-defined]

        # when compiling user function instead of nn.Module
        # provide public api _fn.get_compiler_config()
        assert not hasattr(compile_wrapper, "get_compiler_config")
        compile_wrapper.get_compiler_config = get_compiler_config  # type: ignore[attr-defined]
        if torch._dynamo.config.enable_aot_compile:
            compile_wrapper.aot_compile = aot_compile  # type: ignore[attr-defined]

        # If the function is called using torch._dynamo.optimize decorator, we
        # should prevent any type of skipping.
        if callback not in (None, False):
            if not hasattr(fn, "__code__"):
                raise RuntimeError(
                    textwrap.dedent(
                        """
torch._dynamo.optimize is called on a non function object.
If this is a callable class, please wrap the relevant code into a function and optimize the
wrapper function.
>> class CallableClass:
>> def __init__(self) -> None:
>> super().__init__()
>> self.relu = torch.nn.ReLU()
>>
>> def __call__(self, x):
>> return self.relu(torch.sin(x))
>>
>> def print_hello(self):
>> print("Hello world")
>>
>> mod = CallableClass()
If you want to optimize the __call__ function and other code, wrap that up in a function
>> def wrapper_fn(x):
>> y = mod(x)
>> return y.sum()
and then optimize the wrapper_fn
>> opt_wrapper_fn = torch._dynamo.optimize(wrapper_fn)
"""
                    )
                )
            always_optimize_code_objects[fn.__code__] = True

        return compile_wrapper
class OptimizeContext(_TorchDynamoContext):
    """``_TorchDynamoContext`` configured for compilation.

    Installs ``install_generation_tagging_init`` as the on-enter hook and
    ``TorchPatcher.patch`` as ``patch_fn``. When ``config.compiled_autograd``
    is set at construction time, it also registers an enter/exit hook that
    enables compiled autograd using a compiler obtained from ``rebuild_ctx()``.
    """

    def __init__(
        self,
        callback: DynamoCallback,
        backend_ctx_ctor: Callable[[], contextlib.AbstractContextManager[Any]],
        first_ctx: bool = False,
        *,
        fullgraph: bool = False,
        error_on_graph_break: Optional[bool] = None,
        export: bool = False,
        dynamic: Optional[bool] = None,
        compiler_config: Optional[Any] = None,
        rebuild_ctx: Optional[
            Callable[[], Union[OptimizeContext, _NullDecorator]]
        ] = None,
        package: Optional[CompilePackage] = None,
        hooks: Optional[Hooks] = None,
    ) -> None:
        def on_enter() -> None:
            install_generation_tagging_init()

        super().__init__(
            callback=callback,
            on_enter=on_enter,
            backend_ctx_ctor=backend_ctx_ctor,
            patch_fn=TorchPatcher.patch,
            first_ctx=first_ctx,
            fullgraph=fullgraph,
            error_on_graph_break=error_on_graph_break,
            export=export,
            dynamic=dynamic,
            compiler_config=compiler_config,
            package=package,
            hooks=hooks,
        )

        if config.compiled_autograd:
            # When no explicit dynamic setting was given, derive it from the
            # assume_static_by_default config.
            _dynamic = self._dynamic
            if _dynamic is None:
                _dynamic = not torch._dynamo.config.assume_static_by_default

            def call_compiled_autograd() -> functools.partial[Optional[bool]]:
                # rebuild_ctx is only captured here (not forwarded to the base
                # class); it must be provided when compiled autograd is on.
                assert rebuild_ctx is not None
                compiler_fn = rebuild_ctx()
                ctx = torch._dynamo.compiled_autograd._enable(
                    compiler_fn, dynamic=_dynamic, ignore_active_disable_ctx=False
                )
                ctx.__enter__()
                # The returned partial is the cleanup that exits the context.
                return functools.partial(ctx.__exit__, None, None, None)

            self.enter_exit_hooks.append(call_compiled_autograd)

    def __reduce__(
        self,
    ) -> tuple[type[OptimizeContext], tuple[Any, ...], dict[str, Any]]:
        # NOTE(review): the third element of a __reduce__ tuple is pickle
        # *state*: it is applied via __setstate__ (if one exists) or by updating
        # the instance __dict__ -- it is NOT passed to the constructor. So the
        # object is rebuilt with default keyword arguments and these keys are
        # then set as attributes; e.g. "dynamic" would land on ``self.dynamic``
        # while __init__ above reads ``self._dynamic``, and fullgraph,
        # error_on_graph_break, hooks, package and rebuild_ctx are not carried
        # over. Confirm whether this is intended.
        return (
            self.__class__,
            (self.callback, self._backend_ctx_ctor, self.first_ctx),
            {
                "export": self.export,
                "dynamic": self._dynamic,
                "compiler_config": self.compiler_config,
            },
        )
  820. class RunOnlyContext(_TorchDynamoContext):
  821. def __init__(self) -> None:
  822. # cudagraph trees relies on generation increment
  823. def on_enter() -> None:
  824. torch._dynamo.mutation_guard.GenerationTracker.generation += 1
  825. super().__init__(callback=False, on_enter=on_enter)
  826. def __reduce__(self) -> tuple[type[RunOnlyContext], tuple[Any, ...]]:
  827. return (self.__class__, ())
  828. class DisableContext(_TorchDynamoContext):
  829. def __init__(self, msg: Optional[str] = None, wrapping: bool = True) -> None:
  830. super().__init__(callback=None)
  831. self.msg = msg
  832. self.wrapping = wrapping
  833. def __call__(self, fn: Callable[..., Any]) -> Callable[..., Any]:
  834. # Earlier this code was in the base class _TorchDynamoContext. But we
  835. # moved it here to have better code organization. For disable, we just
  836. # want the callback to be None. We don't have to check trace_rules or
  837. # create any wrapper.
  838. fn = innermost_fn(fn)
  839. if isinstance(fn, torch.nn.Module):
  840. mod = fn
  841. new_mod = OptimizedModule(mod, self)
  842. new_mod._torchdynamo_orig_callable = mod.forward
  843. return new_mod
  844. if isinstance(fn, type):
  845. # User has wrapped the class with compile/disable decorator. Apply
  846. # disable to init/call method.
  847. cls_obj = fn
  848. # Disable on init is useful for reconstruction of bytecodes where we
  849. # want to prevent Dynamo from tracing into the init function. Check
  850. # test_reconstruction in test_model_output.py.
  851. cls_obj.__init__ = self(cls_obj.__init__) # type: ignore[misc]
  852. cls_obj.__call__ = self(cls_obj.__call__)
  853. if issubclass(cls_obj, torch.nn.Module):
  854. # NN module variable tracker directly inlines the _call_impl. Disable it.
  855. cls_obj._call_impl = self(cls_obj._call_impl)
  856. return cls_obj
  857. assert callable(fn), (
  858. f"A callable function is expected, but {type(fn)} is provided."
  859. )
  860. def _fn(*args: Any, **kwargs: Any) -> Any:
  861. prior = set_eval_frame(None)
  862. try:
  863. _maybe_set_eval_frame(_callback_from_stance(self.callback))
  864. try:
  865. return fn(*args, **kwargs)
  866. finally:
  867. set_eval_frame(None)
  868. finally:
  869. _maybe_set_eval_frame(prior)
  870. # Under some circumstances (e.g. precompile) we can end up calling @disable
  871. # decorator in generated bytecode and trigger recompile. This is due to the
  872. # fact that the old callback from torch.compile() is still active and under
  873. # this circumstance we will trigger a failure with set_stance("fail_on_recompile").
  874. # Therefore we want to skip calling into any frame in this case.
  875. if self.wrapping:
  876. _fn = functools.wraps(fn)(_fn)
  877. _fn._torchdynamo_disable = True # type: ignore[attr-defined]
  878. _fn._torchdynamo_disable_msg = self.msg # type: ignore[attr-defined]
  879. # Save the function pointer to find the original callable while nesting
  880. # of decorators.
  881. _fn._torchdynamo_orig_callable = fn # type: ignore[attr-defined]
  882. return _fn
  883. def __reduce__(self) -> tuple[type[DisableContext], tuple[Any, ...]]:
  884. return (self.__class__, ())
  885. def _optimize_catch_errors(
  886. compile_fn: convert_frame.ConvertFrameProtocol,
  887. hooks: Hooks,
  888. backend_ctx_ctor: Callable[
  889. [], contextlib.AbstractContextManager[Any]
  890. ] = null_context,
  891. fullgraph: bool = False,
  892. error_on_graph_break: Optional[bool] = None,
  893. export: bool = False,
  894. dynamic: Optional[bool] = None,
  895. compiler_config: Optional[Any] = None,
  896. rebuild_ctx: Optional[Callable[[], Union[OptimizeContext, _NullDecorator]]] = None,
  897. package: Optional[CompilePackage] = None,
  898. ) -> OptimizeContext:
  899. return OptimizeContext(
  900. convert_frame.catch_errors_wrapper(compile_fn, hooks),
  901. backend_ctx_ctor=backend_ctx_ctor,
  902. first_ctx=True,
  903. fullgraph=fullgraph,
  904. error_on_graph_break=error_on_graph_break,
  905. export=export,
  906. dynamic=dynamic,
  907. compiler_config=compiler_config,
  908. rebuild_ctx=rebuild_ctx,
  909. package=package,
  910. hooks=hooks,
  911. )
  912. def get_compiler_fn(
  913. compiler_fn: Union[str, Callable[..., Any], None],
  914. ) -> WrapBackendDebug:
  915. from .repro.after_dynamo import wrap_backend_debug
  916. if compiler_fn is None:
  917. # Special case None to avoid crashing in hasattr
  918. compiler_str = None
  919. elif hasattr(compiler_fn, "compiler_name"):
  920. compiler_str = compiler_fn.compiler_name # type: ignore[union-attr]
  921. assert isinstance(compiler_str, str)
  922. elif isinstance(compiler_fn, str):
  923. compiler_str = compiler_fn
  924. else:
  925. compiler_str = None
  926. compiler_fn = lookup_backend(compiler_fn) # type: ignore[arg-type]
  927. return wrap_backend_debug(compiler_fn, compiler_str)
  928. class _NullDecorator(contextlib.nullcontext): # type: ignore[type-arg]
  929. def __call__(self, fn: Callable[..., Any]) -> Callable[..., Any]:
  930. assert callable(fn), (
  931. f"A callable function is expected, but {type(fn)} is provided."
  932. )
  933. return fn
  934. # Make dynamo graph to have same input/output spec as user code
  935. def argument_names(
  936. f_sig: inspect.Signature, args: list[Any], kwargs: dict[str, Any]
  937. ) -> list[str]:
  938. def signature_to_fullargspec(sig: inspect.Signature) -> inspect.FullArgSpec:
  939. # Get a list of Parameter objects from the Signature object
  940. params = list(sig.parameters.values())
  941. # Separate positional arguments, keyword-only arguments and varargs/varkw
  942. args = [
  943. p.name for p in params if p.kind == inspect.Parameter.POSITIONAL_OR_KEYWORD
  944. ]
  945. kwonlyargs = [
  946. p.name for p in params if p.kind == inspect.Parameter.KEYWORD_ONLY
  947. ]
  948. varargs = next(
  949. (p.name for p in params if p.kind == inspect.Parameter.VAR_POSITIONAL),
  950. None,
  951. )
  952. varkw = next(
  953. (p.name for p in params if p.kind == inspect.Parameter.VAR_KEYWORD),
  954. None,
  955. )
  956. # Get default values for positional arguments and keyword-only arguments
  957. defaults = tuple(
  958. p.default
  959. for p in params
  960. if p.kind == inspect.Parameter.POSITIONAL_OR_KEYWORD
  961. and p.default is not inspect.Parameter.empty
  962. )
  963. kwonlydefaults = {
  964. p.name: p.default
  965. for p in params
  966. if p.kind == inspect.Parameter.KEYWORD_ONLY
  967. and p.default is not inspect.Parameter.empty
  968. }
  969. # Get annotations for parameters and return value
  970. annotations = {}
  971. if sig.return_annotation:
  972. annotations = {"return": sig.return_annotation}
  973. for parameter in params:
  974. annotations[parameter.name] = parameter.annotation
  975. # Return a FullArgSpec object with the extracted attributes
  976. return inspect.FullArgSpec(
  977. args, varargs, varkw, defaults, kwonlyargs, kwonlydefaults, annotations
  978. )
  979. fullargspec = signature_to_fullargspec(f_sig)
  980. # 1. Map `args` 1-to-1 to positional arguments in original signature.
  981. input_strs = fullargspec.args[: len(args)]
  982. if len(args) > len(fullargspec.args):
  983. # 2. If there are more arguments left in `args`, they map to varargs in original
  984. # signature. Assign names as {varargs}_0, {varargs}_1, ...
  985. assert fullargspec.varargs is not None, "More arguments than expected"
  986. input_strs += [
  987. f"{fullargspec.varargs}_{i}" for i in range(0, len(args) - len(input_strs))
  988. ]
  989. elif len(args) < len(fullargspec.args):
  990. # 3. If there are fewer arguments in `args` than `fullargspec.args`,
  991. # it implies these are arguments either with default values, or provided in
  992. # `kwargs`. The former can be safely ignored. Because Dynamo.export does not
  993. # export them as part of the function signature. The latter will be handled
  994. # in the next step.
  995. for unprovided_arg in fullargspec.args[
  996. len(args) : -len(fullargspec.defaults or [])
  997. ]:
  998. assert unprovided_arg in kwargs, f"Missing argument {unprovided_arg}"
  999. # 4. Keyword arguments provided in `kwargs`.
  1000. input_strs += list(kwargs.keys())
  1001. # 5. Keyword-only arguments with default values if not provided are not exported
  1002. # as part of the function signature.
  1003. for kwonly_arg in fullargspec.kwonlyargs:
  1004. kwonlydefaults = fullargspec.kwonlydefaults or {}
  1005. assert kwonly_arg in kwargs or kwonly_arg in kwonlydefaults, (
  1006. f"Missing keyword only argument {kwonly_arg}"
  1007. )
  1008. return input_strs
  1009. def check_if_dynamo_supported() -> None:
  1010. if sys.version_info >= (3, 14):
  1011. raise RuntimeError("Python 3.14+ not yet supported for torch.compile")
  1012. elif sysconfig.get_config_var("Py_GIL_DISABLED") == 1 and sys.version_info < (
  1013. 3,
  1014. 13,
  1015. 3,
  1016. ):
  1017. raise RuntimeError(
  1018. "torch.compile is not supported on Python < 3.13.3 built with GIL disabled. "
  1019. "Please use Python 3.13.3+."
  1020. )
  1021. def is_dynamo_supported() -> bool:
  1022. try:
  1023. check_if_dynamo_supported()
  1024. return True
  1025. except Exception:
  1026. return False
  1027. def check_if_inductor_supported() -> None:
  1028. check_if_dynamo_supported()
  1029. def is_inductor_supported() -> bool:
  1030. try:
  1031. check_if_inductor_supported()
  1032. return True
  1033. except Exception:
  1034. return False
  1035. def check_for_incompatible_configs() -> None:
  1036. # Some of the configs should be mutually exclusive
  1037. assert not (config.suppress_errors and config.fail_on_recompile_limit_hit), (
  1038. "Dynamo configs suppress_error and fail_on_recompile_limit_hit can not both be active at the same time."
  1039. )
  1040. def optimize(*args: Any, **kwargs: Any) -> Union[OptimizeContext, _NullDecorator]:
  1041. def rebuild_ctx() -> Union[OptimizeContext, _NullDecorator]:
  1042. ca_kwargs_override = config.compiled_autograd_kwargs_override
  1043. if ca_kwargs_override:
  1044. # NOTE: The process of translating other `torch.compile` kwargs to `torch._dynamo.optimize` kwargs
  1045. # is more complicated, we will add it in the future when needed.
  1046. assert set(ca_kwargs_override.keys()) == {"fullgraph"}, (
  1047. f"Only `fullgraph` kwarg override is supported for now, but got {ca_kwargs_override.keys()}"
  1048. )
  1049. kwargs["nopython"] = ca_kwargs_override["fullgraph"]
  1050. return optimize(*args, **kwargs)
  1051. return _optimize(rebuild_ctx, *args, **kwargs)
def _optimize(
    rebuild_ctx: Callable[[], Union[OptimizeContext, _NullDecorator]],
    backend: Union[str, Callable[..., Any]] = "inductor",
    *,
    nopython: bool = False,
    error_on_graph_break: Optional[bool] = None,
    guard_export_fn: Optional[Callable[[_guards.GuardsSet], None]] = None,
    guard_fail_fn: Optional[Callable[[GuardFail], None]] = None,
    guard_filter_fn: Optional[Callable[[list[GuardFilterEntry]], list[bool]]] = None,
    disable: bool = False,
    dynamic: Optional[bool] = None,
    package: Optional[CompilePackage] = None,
) -> Union[OptimizeContext, _NullDecorator]:
    """
    The main entrypoint of TorchDynamo. Do graph capture and call
    backend() to optimize extracted graphs.

    Args:
        rebuild_ctx: Zero-argument callable that rebuilds an equivalent
            context; forwarded to the resulting context (used by the
            compiled-autograd hook in OptimizeContext).
        backend: One of the two things:
            - Either, a function/callable taking a torch.fx.GraphModule and
            example_inputs and returning a python callable that runs the
            graph faster.
            One can also provide additional context for the backend, like
            torch.jit.fuser("fuser2"), by setting the backend_ctx_ctor attribute.
            See AOTAutogradMemoryEfficientFusionWithContext for the usage.
            - Or, a string backend name in `torch._dynamo.list_backends()`
        nopython: If True, graph breaks will be errors and there will
            be a single whole-program graph (handled by ``optimize_assert``,
            unless ``config.debug_force_graph_break_on_leaf_return`` is set).
        error_on_graph_break: If not None, the current `error_on_graph_break` setting is set to the given value.
            See `torch._dynamo.error_on_graph_break()` for more details on what `error_on_graph_break` means.
            Unlike `nopython=True` (i.e. `fullgraph=True`), there is no guarantee of a single whole-program graph.
            If `nopython` is True, `error_on_graph_break` does nothing.
        guard_export_fn, guard_fail_fn, guard_filter_fn: Stored in the
            ``Hooks`` object passed down to frame conversion.
        disable: If True, turn this decorator into a no-op
        dynamic: If True, upfront compile as dynamic a kernel as possible. If False,
            disable all dynamic shapes support (always specialize). If None, automatically
            detect when sizes vary and generate dynamic kernels upon recompile.
        package: Optional ``CompilePackage``. When ``config.caching_precompile``
            is set and none is given, an uninitialized one is created here.

    Returns:
        A ``_NullDecorator`` when ``disable`` is True, the ``TORCHDYNAMO_DISABLE``
        environment variable is ``"1"``, or the
        ``pytorch/compiler:enable_dynamo`` justknob is off; otherwise the
        compiling context.

    Raises:
        RuntimeError: via ``check_if_dynamo_supported`` on unsupported Pythons.

    Example Usage::

        @torch._dynamo.optimize()
        def toy_example(a, b): ...
    """
    check_if_dynamo_supported()
    check_for_incompatible_configs()
    # Note: The hooks object could be global instead of passed around, *however* that would make
    # for a confusing API usage and plumbing story wherein we nest multiple .optimize calls.
    # There is some prior art around this, w/r/t nesting backend calls are enforced to be the same
    # compiler, however, this feels onerous for callback and hooks, and it feels better to give our users an
    # easier to understand UX at the cost of a little more plumbing on our end.
    hooks = Hooks(
        guard_export_fn=guard_export_fn,
        guard_fail_fn=guard_fail_fn,
        guard_filter_fn=guard_filter_fn,
    )
    torch._C._log_api_usage_once("torch._dynamo.optimize")
    if (
        disable
        or os.environ.get("TORCHDYNAMO_DISABLE", "") == "1"
        or (not justknobs_check("pytorch/compiler:enable_dynamo"))
    ):
        return _NullDecorator()

    # fullgraph mode goes through optimize_assert; error_on_graph_break is
    # not forwarded on this path.
    if nopython and not config.debug_force_graph_break_on_leaf_return:
        return optimize_assert(
            backend,
            dynamic=dynamic,
            hooks=hooks,
            rebuild_ctx=rebuild_ctx,
            package=package,
        )

    backend = get_compiler_fn(backend)

    # Find if backend has any extra context manager
    backend_ctx_ctor = getattr(backend, "backend_ctx_ctor", null_context)

    # The backend function is stashed in the callable returned by
    # _optimize_catch_errors in the field _torchdynamo_orig_backend. This can
    # be used by eval_frame.c to insert a guard on the backend.

    # With CachingPrecompile, instantiate an uninitialized CompilePackage
    # which gets initialized by _optimize_catch_errors.__call__ once we have a function
    if config.caching_precompile and package is None:
        from .package import CompilePackage

        package = CompilePackage(fn=None, dynamo=None, ignore_inlined_sources=False)

    return _optimize_catch_errors(
        convert_frame.convert_frame(
            backend,
            hooks,
            package=package,
        ),
        hooks,
        backend_ctx_ctor,
        fullgraph=False,
        # None stays None; the debug flag forces False when a value was given.
        error_on_graph_break=error_on_graph_break
        and not config.debug_force_graph_break_on_leaf_return,
        dynamic=dynamic,
        compiler_config=(
            backend.get_compiler_config()
            if hasattr(backend, "get_compiler_config")
            else None
        ),
        rebuild_ctx=rebuild_ctx,
        package=package,
    )
  1149. # TODO(voz): Consider making "explain" output alongside a run / part of a run
# NOTE(review): assuming ``patch`` is ``unittest.mock.patch``, the decorator is
# active only while ``explain`` itself executes. In the non-deprecated form
# ``explain(f)(*args)``, ``explain`` returns ``inner`` without calling it, so the
# patched flag is already restored when ``inner`` runs -- confirm whether the
# flag matters during ``inner``.
@patch("torch._dynamo.symbolic_convert.explain", True)
def explain(f: Callable[..., Any], *extra_args: Any, **extra_kwargs: Any) -> Any:
    """Run ``f`` under Dynamo and report captured graphs and graph breaks.

    Preferred usage is ``explain(f)(*args, **kwargs)``, which returns ``inner``.
    Passing arguments directly (``explain(f, *args, **kwargs)``) emits a
    ``FutureWarning`` and returns ``inner(*args, **kwargs)`` immediately.

    ``inner`` resets Dynamo, compiles and runs ``f`` once with a backend that
    accumulates graph details, resets Dynamo again, and returns an
    ``ExplainOutput``.
    """
    from .backends.debugging import ExplainOutput

    def inner(*args: Any, **kwargs: Any) -> ExplainOutput:
        # TODO(voz): Do we want a decorator for this?
        from . import reset  # type: ignore[attr-defined]

        reset()

        # Accumulators updated by the nested callbacks below.
        graphs: list[torch.fx.GraphModule] = []
        break_reasons: list[Any] = []
        op_count: int = 0
        ops_per_graph: list[list[Target]] = []
        out_guards: list[_guards.Guard] = []

        def dynamo_graph_accumulating_compiler(
            gm: torch.fx.GraphModule, example_inputs: Any
        ) -> Callable[..., Any]:
            # Backend that records graph details and runs the graph eagerly
            # via gm.forward.
            from .backends.debugging import _explain_graph_detail

            nonlocal graphs
            nonlocal op_count
            nonlocal ops_per_graph
            nonlocal break_reasons

            gm, graphs, op_count, ops_per_graph, break_reasons = _explain_graph_detail(
                gm, graphs, op_count, ops_per_graph, break_reasons
            )

            return gm.forward

        def guard_export_print(guards: Iterable[_guards.Guard]) -> None:
            nonlocal out_guards
            out_guards.extend(guards)

        opt_f = optimize(
            dynamo_graph_accumulating_compiler,
            nopython=False,
            guard_export_fn=guard_export_print,
        )(f)
        # TODO(voz): We may have instances of `f` that mutate inputs, we should track sideeffects and reject.
        opt_f(*args, **kwargs)

        graph_count = len(graphs)
        # NOTE(review): this is -1 when no graph was captured.
        graph_break_count = graph_count - 1
        compile_time = compile_times(repr="str")

        # TODO(voz): Do we want a decorator for this?
        reset()

        return ExplainOutput(
            graphs,
            graph_count,
            graph_break_count,
            break_reasons,
            op_count,
            ops_per_graph,
            out_guards,
            compile_time,
        )

    if extra_args or extra_kwargs:
        warnings.warn(
            "explain(f, *args, **kwargs) is deprecated, use explain(f)(*args, **kwargs) instead. "
            "If you don't migrate, we may break your explain call in the future if your user defined kwargs "
            "conflict with future kwargs added to explain(f).",
            FutureWarning,
            stacklevel=2,
        )
        return inner(*extra_args, **extra_kwargs)
    else:
        return inner
class FlattenInputOutputSignature(torch.fx.Transformer):
    """Rewrite a Dynamo graph so its inputs/outputs follow a flat user signature.

    One new placeholder ``arg{i}`` is created per entry of ``flat_args``. The
    original graph's placeholders are then replaced, in the order given by
    ``matched_input_elements_positions``, with the matching new placeholders.
    Outputs are rebuilt from ``matched_output_elements_positions``: each entry
    indexes into (dynamo outputs + new placeholders), or is ``None`` to emit the
    constant at the same position of ``flat_results``.
    """

    def __init__(
        self,
        m: torch.fx.GraphModule,
        flat_args: list[Any],
        matched_input_elements_positions: list[int],
        flat_results: Sequence[Any],
        matched_output_elements_positions: list[Optional[int]],
        example_fake_inputs: list[torch.Tensor],
        flat_args_dynamic_dims: list[set[int]],
        fake_mode: Optional[fake_tensor.FakeTensorMode] = None,
    ) -> None:
        super().__init__(m)

        assert len(flat_args_dynamic_dims) == len(flat_args)
        # flat-arg index -> example fake input of the graph input it matched.
        matched_input_elements_to_fake = {
            val: example_fake_inputs[ix]
            for ix, val in enumerate(matched_input_elements_positions)
        }

        self.new_args = []
        for i in range(0, len(flat_args)):
            arg = super().placeholder(f"arg{i}", (), {})
            if i in matched_input_elements_to_fake:
                arg.node.meta["val"] = matched_input_elements_to_fake[i]
            else:
                # Fill node.meta["val"] with faketensor from the input,
                # if it's not found in matched_input_elements_positions
                if fake_mode is not None and isinstance(flat_args[i], torch.Tensor):
                    # TODO(zhxchen17) Also preserve all the user constraints here.
                    # Dimensions listed in flat_args_dynamic_dims[i] become
                    # DYNAMIC, all others STATIC.
                    arg.node.meta["val"] = fake_mode.from_tensor(
                        flat_args[i],
                        symbolic_context=StatelessSymbolicContext(
                            dynamic_sizes=[
                                (
                                    DimDynamic.DYNAMIC
                                    if d in flat_args_dynamic_dims[i]
                                    else DimDynamic.STATIC
                                )
                                for d in range(len(flat_args[i].shape))
                            ],
                            constraint_sizes=[None] * len(flat_args[i].shape),
                        ),
                    )
                elif isinstance(flat_args[i], _IntWrapper):
                    arg.node.meta["val"] = flat_args[i].val
                else:
                    arg.node.meta["val"] = flat_args[i]

            self.new_args.append(arg)
        # Consumed lazily by placeholder(); assumes the original placeholders
        # are visited in the same order as matched_input_elements_positions.
        self.old_args_gen = (self.new_args[i] for i in matched_input_elements_positions)
        self.matched_output_elements_positions = matched_output_elements_positions
        self.flat_results = flat_results

    def placeholder(
        self, target: Target, args: tuple[Argument, ...], kwargs: dict[str, Any]
    ) -> Any:
        """Replace an original placeholder with its matched new placeholder.

        Copies the ``val``, ``tensor_dict``, ``example_value`` and
        ``unbacked_bindings`` metadata from the original node when present.
        """
        arg = next(self.old_args_gen)
        if "val" in self.current_node.meta:
            arg.node.meta["val"] = self.current_node.meta["val"]
        if "tensor_dict" in self.current_node.meta:
            arg.node.meta["tensor_dict"] = self.current_node.meta["tensor_dict"]
        if "example_value" in self.current_node.meta:
            # NB: intentionally do not use set_example_value
            arg.node.meta["example_value"] = self.current_node.meta["example_value"]
        if "unbacked_bindings" in self.current_node.meta:
            arg.node.meta["unbacked_bindings"] = self.current_node.meta[
                "unbacked_bindings"
            ]
        return arg

    def output(
        self, target: Target, args: tuple[Argument, ...], kwargs: dict[str, Any]
    ) -> Any:
        """Emit outputs in the flat user order.

        Position ``i`` is taken from ``dynamo outputs + new placeholders`` at
        ``matched_output_elements_positions[i]``, or, when that entry is
        ``None``, is the constant ``flat_results[i]`` (which must be one of
        ``common_constant_types``).
        """
        dynamo_result_flat = args[0]
        lookup = [*dynamo_result_flat, *self.new_args]  # type: ignore[misc]
        new_results_flat = []
        for i in range(len(self.flat_results)):
            if self.matched_output_elements_positions[i] is not None:
                new_results_flat.append(
                    lookup[self.matched_output_elements_positions[i]]
                )
            else:
                const_val = self.flat_results[i]
                assert isinstance(const_val, tuple(common_constant_types))
                new_results_flat.append(const_val)
        return super().output(target, (new_results_flat,), {})

    def run_node(self, n: Node) -> Any:
        """Transform ``n``, carrying over its metadata and (non-output) name."""
        # Remembered so placeholder() can read the original node's meta.
        self.current_node = n
        result_proxy = super().run_node(n)
        if "val" in self.current_node.meta:
            result_proxy.node.meta["val"] = self.current_node.meta["val"]
        if "example_value" in self.current_node.meta:
            # NB: intentionally do not use set_example_value
            result_proxy.node.meta["example_value"] = self.current_node.meta[
                "example_value"
            ]
        if "unbacked_bindings" in self.current_node.meta:
            result_proxy.node.meta["unbacked_bindings"] = self.current_node.meta[
                "unbacked_bindings"
            ]
        if self.current_node.op != "output":
            result_proxy.node._rename(
                getattr(self.current_node, "name", result_proxy.node.name)
            )
        return result_proxy

    def transform(self) -> torch.fx.GraphModule:
        """Run the transformation and copy selected module-level metadata."""
        result_gm = super().transform()
        if "dynamo_flat_name_to_original_fqn" in self.module.meta:  # type: ignore[operator]
            result_gm.meta["dynamo_flat_name_to_original_fqn"] = self.module.meta[  # type: ignore[index]
                "dynamo_flat_name_to_original_fqn"  # type: ignore[index]
            ]
        if "dynamo_compile_id" in self.module.meta:  # type: ignore[operator]
            result_gm.meta["dynamo_compile_id"] = self.module.meta["dynamo_compile_id"]  # type: ignore[index]
        return result_gm
class ExportResult(NamedTuple):
    """Value produced by the callable returned from ``export``.

    Holds the exported FX GraphModule and the guards recorded for it during
    tracing.
    """

    graph_module: torch.fx.GraphModule
    guards: _guards.GuardsSet
    # NB: Do not add new fields without overriding __iter__; people are
    # destructuring so it is BC-breaking
  1325. # NOTE: this function only supports graphs created by Dynamo's OutputGraph module
  1326. def check_signature_rewritable(graph: torch.fx.GraphModule) -> None:
  1327. input_errors = []
  1328. for node in graph.graph.find_nodes(op="placeholder"):
  1329. # set in OutputGraph._call_user_compiler
  1330. assert hasattr(node, "_dynamo_source")
  1331. assert hasattr(graph, "_source_to_user_stacks")
  1332. # NOTE: We can safely ignore these type warnings if and only if
  1333. # the function is made from OutputGraph (checked in the assertions)
  1334. source = node._dynamo_source # type: ignore[attr-defined]
  1335. user_stacks = graph._source_to_user_stacks.get(source) # type: ignore[operator, union-attr]
  1336. if user_stacks is None:
  1337. continue
  1338. assert len(user_stacks) > 0
  1339. # In some cases we may not have a useful stack. Look for a
  1340. # useful stack
  1341. stack = None
  1342. for s in user_stacks:
  1343. if len(s) == 0:
  1344. continue
  1345. stack = s
  1346. break
  1347. if stack is None:
  1348. msg = f"{source.name()}, a closed over free variable"
  1349. else:
  1350. tb = "".join(traceback.format_list(stack))
  1351. extra = ""
  1352. if len(user_stacks) > 1:
  1353. extra = f"(elided {len(user_stacks) - 1} more accesses)"
  1354. msg = f"{source.name()}, accessed at:\n{tb}{extra}"
  1355. # TODO: option to print ALL of the stack traces at once
  1356. input_errors.append(msg)
  1357. if input_errors:
  1358. raise UserError(
  1359. UserErrorType.INVALID_INPUT,
  1360. "Cannot export model which references tensors that are neither "
  1361. "buffers/parameters/constants nor are direct inputs. For each tensor, if you'd "
  1362. "like this tensor to be an explicit input, add it as a dummy argument "
  1363. "to the top-level model definition you are exporting; if you would "
  1364. "like its value to be embedded as an exported constant, wrap its access "
  1365. "in a function marked with @assume_constant_result.\n\n"
  1366. + "\n\n".join(input_errors),
  1367. )
  1368. def rewrite_signature(
  1369. f_sig: inspect.Signature,
  1370. graph: torch.fx.GraphModule,
  1371. fake_mode: Optional[fake_tensor.FakeTensorMode],
  1372. flat_args: list[Any],
  1373. in_spec: pytree.TreeSpec,
  1374. example_fake_inputs: list[Any],
  1375. graph_captured_input: Iterable[Any],
  1376. graph_captured_output: Optional[Iterable[Any]],
  1377. dynamo_traced_result: Any,
  1378. flat_args_dynamic_dims: list[set[int]],
  1379. ) -> torch.fx.GraphModule:
  1380. orig_args, orig_kwargs = pytree.tree_unflatten(flat_args, in_spec)
  1381. def check_user_input_output(
  1382. flat_values: list[Any], error_type: UserErrorType
  1383. ) -> None:
  1384. supported_types = [
  1385. torch.Tensor,
  1386. torch.SymInt,
  1387. torch.SymFloat,
  1388. torch.SymBool,
  1389. torch._C.ScriptObject,
  1390. _IntWrapper,
  1391. ] + list(common_constant_types)
  1392. def is_supported_type(val: Any) -> bool:
  1393. return isinstance(val, tuple(supported_types))
  1394. value_type = "input" if error_type == UserErrorType.INVALID_INPUT else "output"
  1395. # We only check that the outputs are not None. Inputs can be None.
  1396. for v in flat_values:
  1397. if not is_supported_type(v):
  1398. if error_type == UserErrorType.INVALID_INPUT and v is None:
  1399. continue
  1400. raise UserError(
  1401. error_type,
  1402. f"It looks like one of the {value_type}s with type `{type(v)}` "
  1403. "is not supported or pytree-flattenable. \n"
  1404. f"Exported graphs {value_type}s can only contain the "
  1405. f"following supported types: {supported_types}. \n"
  1406. "If you are using a custom class object, "
  1407. "please register a pytree_flatten/unflatten function "
  1408. "using `torch.utils._pytree.register_pytree_node` or "
  1409. "`torch.export.register_dataclass`.",
  1410. )
  1411. check_user_input_output(flat_args, UserErrorType.INVALID_INPUT)
  1412. flat_results_traced, out_spec_traced = pytree.tree_flatten(dynamo_traced_result)
  1413. check_user_input_output(flat_results_traced, UserErrorType.INVALID_OUTPUT)
  1414. def check_optional_input_and_error(f_sig: inspect.Signature) -> None:
  1415. # Check if function has optional input.
  1416. for name, param in f_sig.parameters.items():
  1417. if param.default is not inspect.Parameter.empty:
  1418. from torch._dynamo.exc import Unsupported
  1419. log.error(
  1420. "Parameter %s is optional with a default value of %s",
  1421. name,
  1422. param.default,
  1423. )
  1424. raise Unsupported(
  1425. "Tracing through optional input is not supported yet",
  1426. case_name="optional_input",
  1427. )
  1428. def produce_matching(
  1429. debug_type: str, sources: Iterable[Any], candidates: Iterable[Any]
  1430. ) -> list[Optional[int]]:
  1431. matched_elements_positions: list[Optional[int]] = []
  1432. dict_of_source_vals = {}
  1433. for i, val in enumerate(sources):
  1434. dict_of_source_vals[id(val)] = i
  1435. for i, val in enumerate(candidates):
  1436. if isinstance(val, tuple(common_constant_types)):
  1437. matched_elements_positions.append(None)
  1438. elif id(val) not in dict_of_source_vals:
  1439. if debug_type == "inputs":
  1440. check_optional_input_and_error(f_sig)
  1441. raise AssertionError(
  1442. f"Unexpectedly found a {type(val)} in the {debug_type}.\n"
  1443. 'Please file an issue along with a paste of the logs from TORCH_LOGS="+export"',
  1444. )
  1445. else:
  1446. matched_elements_positions.append(dict_of_source_vals[id(val)])
  1447. return matched_elements_positions
  1448. matched_input_elements_positions = produce_matching(
  1449. "inputs", flat_args, graph_captured_input
  1450. )
  1451. assert graph_captured_output is not None
  1452. matched_output_elements_positions = produce_matching(
  1453. "outputs", list(graph_captured_output) + flat_args, flat_results_traced
  1454. )
  1455. new_graph = FlattenInputOutputSignature(
  1456. graph,
  1457. flat_args,
  1458. matched_input_elements_positions, # type: ignore[arg-type]
  1459. flat_results_traced,
  1460. matched_output_elements_positions, # type: ignore[arg-type]
  1461. example_fake_inputs,
  1462. flat_args_dynamic_dims,
  1463. fake_mode,
  1464. ).transform()
  1465. new_graph.graph._codegen = _PyTreeCodeGen(
  1466. _PyTreeInfo(
  1467. argument_names(f_sig, orig_args, orig_kwargs),
  1468. in_spec,
  1469. out_spec_traced,
  1470. )
  1471. )
  1472. new_graph.recompile()
  1473. return new_graph
def export(
    f: Callable[..., Any],
    *extra_args: Any,
    aten_graph: bool = False,
    pre_dispatch: bool = False,
    decomposition_table: Optional[
        dict[torch._ops.OpOverload, Callable[..., Any]]
    ] = None,
    tracing_mode: str = "symbolic",
    dynamic_shapes: Optional[Union[dict[str, Any], tuple[Any], list[Any]]] = None,
    specialize_float: bool = True,
    assume_static_by_default: bool = False,
    same_signature: bool = True,
    disable_constraint_solver: bool = False,
    prefer_deferred_runtime_asserts_over_guards: bool = False,
    _log_export_usage: bool = True,
    constraints: Optional[list[Constraint]] = None,
    **extra_kwargs: Any,
) -> Callable[..., ExportResult]:
    """
    Export an input function f to an FX graph by tracing it with Dynamo.

    Args:
        f (callable): A PyTorch function (or nn.Module) to be exported.

        aten_graph (bool): If True, exports a graph with ATen operators.
        If False, exports a graph with Python operators. Default is False.

        pre_dispatch (bool): If True, exports a graph with ATen operators,
        but before any logic in the PyTorch dispatcher has run.
        This can be useful if you want to apply further transformations on a graph before running it
        through autograd, autocast, or any other functionalities that are integrated into the dispatcher.
        This flag is only valid if aten_graph=True is set.
        Default is False.

        decomposition_table (dict): A dictionary that maps operators to their decomposition functions,
        used when retracing to an ATen graph. Only valid with aten_graph=True. Default is None.

        tracing_mode (str): If "symbolic", turn on dynamic shapes support. Any other value forces
        assume_static_by_default=True. Default is "symbolic".

        dynamic_shapes:
         An optional argument where the type should either be:
         1) a dict from argument names of ``f`` to their dynamic shape specifications,
         2) a tuple that specifies dynamic shape specifications for each input in original order.
         If you are specifying dynamism on keyword args, you will need to pass them in the order that
         is defined in the original function signature.

         The dynamic shape of a tensor argument can be specified as either
         (1) a dict from dynamic dimension indices to :func:`Dim` types, where it is
         not required to include static dimension indices in this dict, but when they are,
         they should be mapped to None; or (2) a tuple / list of :func:`Dim` types or None,
         where the :func:`Dim` types correspond to dynamic dimensions, and static dimensions
         are denoted by None. Arguments that are dicts or tuples / lists of tensors are
         recursively specified by using mappings or sequences of contained specifications.
         Ignored when ``constraints`` is given.

        specialize_float (bool): Forwarded to the Dynamo config patch used while tracing.

        assume_static_by_default (bool): Forwarded to the Dynamo config patch used while tracing
        (forced to True when tracing_mode is not "symbolic").

        same_signature (bool): If True, rewrite the returned graph's signature to be the same as f.

        disable_constraint_solver (bool): Whether the dim constraint solver must be disabled.

        prefer_deferred_runtime_asserts_over_guards (bool): Forwarded to the Dynamo config patch
        used while tracing.

        _log_export_usage (bool): If True, records an "export.private_api" usage event.

        constraints (list): If given, used as the export constraints instead of deriving them
        from ``dynamic_shapes``.

    Returns:
        A function that given args and kwargs, returns an ExportResult (graph_module, guards).
         graph_module: An FX graph representing the execution of the input PyTorch function with the provided arguments and options.
         guards: The guards we accumulated during tracing f above.
        (Deprecated: if extra args/kwargs are passed to ``export`` itself, that function is called
        immediately and its ExportResult is returned, with a FutureWarning.)

    Raises:
        unittest.SkipTest: If config.debug_force_graph_break_on_leaf_return is set.

        AssertionError: If decomposition_table or pre_dispatch is specified without setting aten_graph=True,
         or if graph breaks during tracing in export.

        AssertionError: If Dynamo input and output is not consistent with traced input/output.

        ConstraintViolationError: If tracing or the dim constraint solver finds violated
         dynamic-shape constraints.

        UserError: For unsupported inputs/outputs, references to non-input tensors, or
         cond operand mismatches when producing an ATen graph.

    Note - this headerdoc was authored by ChatGPT, with slight modifications by the author.
    """
    if config.debug_force_graph_break_on_leaf_return:
        raise unittest.SkipTest("Cannot force graph break on export")

    if _log_export_usage:
        log_export_usage(event="export.private_api", flags={"_dynamo"})

    # Deal with "local variable referenced before assignment":
    # inner() rebinds f, specialize_float, assume_static_by_default and
    # constraints, which makes those names local to inner(); these aliases let
    # it read the values passed to export().
    _f = f
    _specialize_float = specialize_float
    _assume_static_by_default = assume_static_by_default
    _constraints = constraints

    def inner(*args: Any, **kwargs: Any) -> ExportResult:
        """Trace ``f`` on ``args``/``kwargs`` and return its graph and guards."""
        # Explicit constraints win; otherwise derive them from dynamic_shapes.
        if not _constraints:
            combined_args = _combine_args(_f, args, kwargs)
            constraints = _process_dynamic_shapes(combined_args, dynamic_shapes)
        else:
            constraints = _constraints

        f = _f
        specialize_float = _specialize_float
        assume_static_by_default = _assume_static_by_default
        check_if_dynamo_supported()
        torch._C._log_api_usage_once("torch._dynamo.export")
        if decomposition_table is not None:
            assert aten_graph, (
                "Specifying a decomposition_table table or tracing mode is illegal without setting aten_graph=True"
            )
        if pre_dispatch:
            assert aten_graph, "pre_dispatch=True can only be used when aten_graph=True"
        f = innermost_fn(f)
        # For modules, the user-facing signature is that of forward().
        call_to_inspect = f.forward if isinstance(f, torch.nn.Module) else f
        original_signature = inspect.signature(call_to_inspect)  # type: ignore[arg-type]
        # State filled in by the capture callbacks below (via nonlocal).
        graph = None
        out_guards = None
        graph_captured_input = None
        graph_captured_result: Optional[tuple[torch.Tensor, ...]] = None
        fake_mode = None
        result_traced = None

        def guard_export_print(guards: _guards.GuardsSet) -> None:
            """Guard-export hook: record the guard set (at most once)."""
            nonlocal out_guards
            assert out_guards is None, (
                "whole graph export entails exactly one guard export"
            )
            out_guards = guards

        example_inputs: list[Any] = []

        def dynamo_normalization_capturing_compiler(
            gm: torch.fx.GraphModule, inner_example_inputs: list[Any]
        ) -> Callable[..., Any]:
            """Backend that records Dynamo's single graph instead of compiling it.

            Returns a wrapper that records the graph inputs it is called with
            and computes the graph outputs under a fake mode.
            """
            nonlocal graph
            assert graph is None, (
                "Tried to emit a second graph during export. Tracing through 'f' must produce a single graph."
            )
            graph = gm

            nonlocal fake_mode, example_inputs
            # NB: do NOT pass inner_example_inputs here, we are detecting the
            # Dynamo allocated fake mode, which should be DISTINCT from a
            # potential outer ambient fake mode which the user provided.
            # example_inputs is always the user specified inputs, so they
            # would have the wrong fake mode attached to them
            fake_mode = _guards.detect_fake_mode()
            example_inputs = inner_example_inputs

            def result_capturing_wrapper(*graph_inputs: Any) -> Any:
                """Record ``graph_inputs`` and return the graph's fake outputs."""
                nonlocal graph_captured_result
                nonlocal graph_captured_input

                graph_captured_input = graph_inputs
                assert graph is not None

                named_parameters = dict(graph.named_parameters(remove_duplicate=False))
                named_buffers = dict(graph.named_buffers(remove_duplicate=False))

                # Prefer a fake mode already attached to the inputs; otherwise
                # fall back to the Dynamo fake mode detected above.
                # NOTE(review): detect_fake_mode(graph_inputs) is evaluated twice.
                ambient_fake_mode = (
                    _guards.detect_fake_mode(graph_inputs)
                    if _guards.detect_fake_mode(graph_inputs) is not None
                    else fake_mode
                )

                # We reran fake tensor propagation, but we didn't do
                # anything with the resulting unbacked SymInts.  Drop them
                # from the pending list.
                # NB: this is wrong if graph_captured_result has
                # data-dependent output size!
                ignore_fresh_unbacked = null_context()
                assert ambient_fake_mode is not None
                if shape_env := ambient_fake_mode.shape_env:
                    ignore_fresh_unbacked = shape_env.ignore_fresh_unbacked_symbols()  # type: ignore[assignment]

                with (
                    ambient_fake_mode,
                    enable_python_dispatcher(),
                    ignore_fresh_unbacked,
                ):
                    params_and_buffers = {
                        **named_parameters,
                        **named_buffers,
                    }
                    fake_params_buffers = {}

                    for name, value in params_and_buffers.items():
                        fake_params_buffers[name] = ambient_fake_mode.from_tensor(
                            value, static_shapes=True
                        )

                    from torch._export.non_strict_utils import (
                        key_path_to_source,
                        KeyPath,
                    )

                    def fakify_with_ambient(
                        path: KeyPath, t: Union[torch.Tensor, _IntWrapper, Any]
                    ) -> Any:
                        """Fakify one input leaf.

                        Tensors become static-shape fakes. An _IntWrapper
                        whose hint is DYNAMIC or AUTO becomes a new dynamic
                        SymInt with a source derived from ``path``; any other
                        _IntWrapper is unwrapped to its value. Other leaves
                        pass through unchanged.
                        """
                        if isinstance(t, torch.Tensor):
                            return ambient_fake_mode.from_tensor(t, static_shapes=True)
                        elif isinstance(t, _IntWrapper):
                            if (
                                t.dynamism is not None
                                and isinstance(t.dynamism, _DimHint)
                                and t.dynamism.type
                                in (
                                    _DimHintType.DYNAMIC,
                                    _DimHintType.AUTO,
                                )
                            ):  # type: ignore[union-attr]
                                source = key_path_to_source(path)
                                symint = ambient_fake_mode.shape_env.create_unspecified_symint_and_symbol(  # type: ignore[union-attr]
                                    t.val, source, DimDynamic.DYNAMIC
                                )
                                return symint
                            else:
                                return t.val
                        else:
                            return t

                    fake_graph_inputs = pytree.tree_map_with_path(
                        fakify_with_ambient, graph_inputs
                    )
                    graph_captured_result = torch.func.functional_call(
                        graph,
                        fake_params_buffers,  # type: ignore[arg-type]
                        fake_graph_inputs,  # type: ignore[arg-type]
                    )

                return graph_captured_result

            return result_capturing_wrapper

        # Note: This is needed by rewrite_signature. We need to put it before
        # optimize_assert since user program may mutate the inputs.
        flat_args, in_spec = pytree.tree_flatten((args, kwargs))

        remove_from_cache(f)
        constraint_violation_error = None
        if tracing_mode != "symbolic":
            assume_static_by_default = True
        with (
            config.patch(
                specialize_int=True,
                specialize_float=specialize_float,
                assume_static_by_default=assume_static_by_default,
                automatic_dynamic_shapes=False,
                capture_dynamic_output_shape_ops=True,
                capture_scalar_outputs=True,
                prefer_deferred_runtime_asserts_over_guards=prefer_deferred_runtime_asserts_over_guards,
            ),
            _compiling_state_context(),
        ):
            opt_f = optimize_assert(
                dynamo_normalization_capturing_compiler,
                hooks=Hooks(
                    guard_export_fn=guard_export_print,
                    guard_fail_fn=None,
                ),
                export=True,
                export_constraints=constraints,
            )(f)
            # TODO(voz): We may have instances of `f` that mutate inputs, we should track sideeffects and reject.
            try:
                result_traced = opt_f(*args, **kwargs)
            except ConstraintViolationError as e:
                # Deferred: the solver below may append more detail before raising.
                constraint_violation_error = e
        remove_from_cache(f)

        # Run the dim constraint solver when possible and fold its report into
        # the error (or log it as a summary when nothing was violated).
        if (
            not disable_constraint_solver
            and (shape_env := getattr(fake_mode, "shape_env", None)) is not None
            and (dim_constraints := shape_env.dim_constraints) is not None
            and not isinstance(
                call_to_inspect, (torch._ops.OpOverloadPacket, torch._ops.OpOverload)
            )
            and not trace_rules.check(call_to_inspect)
        ):
            dim_constraints.solve()
            forced_specializations = dim_constraints.forced_specializations()
            msg = dim_constraints.prettify_results(
                original_signature,
                dynamic_shapes,
                constraint_violation_error,
                forced_specializations,
            )
            if constraint_violation_error:
                constraint_violation_error.args = (
                    constraint_violation_error.args[0] + msg,
                )
            else:
                if forced_specializations:
                    constraint_violation_error = ConstraintViolationError(msg)
                else:
                    log.info(
                        "Summary of dimension constraints:%s",
                        msg,
                    )

            # Error if we have any constraints on static values
            # NOTE(review): this replaces any error (and solver message) set
            # above rather than appending to it — confirm that is intended.
            for k in shape_env.var_to_range.keys():
                if isinstance(k, sympy.Integer):
                    constraint_violation_error = ConstraintViolationError(
                        f"{''.join(traceback.format_list(shape_env.var_to_stack[k]))}\n"
                        "It appears that you're trying to set a constraint on a "
                        f"value which we evaluated to have a static value of {k}. "
                        'Set TORCH_LOGS="+export" for more information.'
                    )
        if constraint_violation_error:
            raise constraint_violation_error

        if graph is None:
            assert same_signature, (
                "Failed to produce a graph during tracing as no tensor operations were found and same_signature is False."
            )
            # If the module does not contain any tensor computation, we would create a graph with inputs and outputs.
            # To be consistent with the graph traced by dynano, `graph` will have only tensor inputs as placeholders
            # and tensor outputs as output nodes. non-tensor inputs and outputs will be added when rewriting signature.
            # We will also construct the `example_inputs`, `graph_captured_input`, and `graph_captured_result` corresponding
            # to `graph`.
            example_inputs = []
            graph_captured_input = ()
            graph_captured_result = ()
            fake_mode = torch._subclasses.FakeTensorMode(
                shape_env=ShapeEnv(), export=True
            )
            if out_guards is None:
                out_guards = _guards.GuardsSet()
            assert out_guards is not None  # suppress mypy error
            parameter_names = list(original_signature.parameters.keys())
            fx_graph = torch.fx.Graph()
            # NOTE(review): parameter index i is used to index the pytree-
            # flattened flat_args; these only line up when every leading
            # argument is a single leaf — confirm for nested/keyword inputs.
            for i, name in enumerate(parameter_names):
                if torch.is_tensor(flat_args[i]):
                    node = fx_graph.placeholder(name)
                    node.meta["val"] = fake_mode.from_tensor(
                        flat_args[i], static_shapes=True
                    )
                    graph_captured_input = graph_captured_input + (flat_args[i],)
                    example_inputs.append(flat_args[i])
            fx_graph.output(graph_captured_result)
            module = torch.nn.Module()
            graph = torch.fx.GraphModule(module, fx_graph)
            log.info(
                "Failed to capture a graph during tracing as no tensor operations were found.:\n\n%s",
                graph.print_readable(print_output=False, colored=True),
            )
        else:
            assert out_guards is not None, "Failed to produce guards during tracing"
            assert fake_mode is not None

            log.info(
                "Dynamo captured graph:\n\n%s",
                graph.print_readable(print_output=False, colored=True),
            )

            # This check need to happened before aten_graph
            # because placeholder's _source_node attribute is not preserved by make_fx
            if same_signature:
                check_signature_rewritable(graph)

        # NB: This is mostly hitting the cache; Dynamo already converted these
        example_fake_inputs = [
            fake_mode.from_tensor(t) if isinstance(t, torch.Tensor) else t
            for t in example_inputs
        ]

        if aten_graph:
            # Running graph with interpreter is needed for propagating the stack_trace
            def graph_with_interpreter(*args: Any) -> Any:
                with torch.fx.traceback.preserve_node_meta():
                    return torch.fx.Interpreter(graph).run(*args)  # type: ignore[arg-type]

            with unset_fake_temporarily(), enable_python_dispatcher(), fake_mode:
                try:
                    graph = make_fx(
                        graph_with_interpreter,
                        decomposition_table=decomposition_table,
                        tracing_mode="real",
                        _allow_non_fake_inputs=True,
                        pre_dispatch=pre_dispatch,
                        _allow_fake_constant=False,
                    )(*example_fake_inputs)
                except CondOpArgsMismatchError as e:
                    # Wrap the internal error to the user-facing error
                    raise UserError(  # noqa: B904
                        UserErrorType.DYNAMIC_CONTROL_FLOW,
                        str(e),
                        case_name="cond_operands",
                    )

        assert graph is not None
        # Give tensor-valued get_attr nodes a static-shape fake "val".
        for node in graph.graph.find_nodes(op="get_attr"):
            if isinstance(getattr(graph, node.target), torch.Tensor):  # type: ignore[arg-type]
                node.meta["val"] = fake_mode.from_tensor(
                    getattr(graph, node.target),  # type: ignore[arg-type]
                    static_shapes=True,
                )

        if same_signature:
            # Per flat arg: dims with a non-relaxed constraint whose value
            # range is not a single point.
            flat_args_dynamic_dims = [
                {
                    c.dim
                    for c in (constraints or ())
                    if (
                        c.t_id == id(x)
                        and not isinstance(c, _RelaxedConstraint)
                        and c.constraint_range.vr.lower != c.constraint_range.vr.upper
                    )
                }
                for x in flat_args
            ]
            graph = rewrite_signature(
                original_signature,
                graph,
                fake_mode,
                flat_args,
                in_spec,
                example_fake_inputs,
                graph_captured_input,  # type: ignore[arg-type]
                graph_captured_result,
                result_traced,  # type: ignore[possibly-undefined]
                flat_args_dynamic_dims,
            )
        return ExportResult(graph, out_guards)

    if extra_args or extra_kwargs:
        warnings.warn(
            "export(f, *args, **kwargs) is deprecated, use export(f)(*args, **kwargs) instead.  "
            "If you don't migrate, we may break your export call in the future if your user defined kwargs "
            "conflict with future kwargs added to export(f).",
            FutureWarning,
            stacklevel=2,
        )
        return inner(*extra_args, **extra_kwargs)  # type: ignore[return-value]
    else:
        return inner
  1855. def optimize_assert(*args: Any, **kwargs: Any) -> OptimizeContext:
  1856. if "rebuild_ctx" in kwargs and kwargs["rebuild_ctx"] is not None:
  1857. # called from optimize
  1858. rebuild_ctx = kwargs["rebuild_ctx"]
  1859. del kwargs["rebuild_ctx"]
  1860. else:
  1861. def rebuild_ctx() -> OptimizeContext:
  1862. return optimize_assert(*args, **kwargs)
  1863. return _optimize_assert(rebuild_ctx, *args, **kwargs)
  1864. def _optimize_assert(
  1865. rebuild_ctx: Callable[[], OptimizeContext],
  1866. backend: Union[str, Callable[..., Any], None],
  1867. *,
  1868. hooks: Hooks = Hooks(None, None, None),
  1869. export: bool = False,
  1870. export_constraints: Optional[Any] = None,
  1871. dynamic: Optional[bool] = None,
  1872. package: Optional[CompilePackage] = None,
  1873. ) -> OptimizeContext:
  1874. """
  1875. Guarantees single-graph capture.
  1876. The same as `torch._dynamo.optimize(backend)` but ignores
  1877. symbolic_convert.error_on_graph_break setting.
  1878. Used for fullgraph=True and export, since we must always error on graph breaks and ignore
  1879. symbolic_convert.error_on_graph_break. Can also be used for testing.
  1880. """
  1881. backend = get_compiler_fn(backend)
  1882. # Find if backend has any extra context manager
  1883. backend_ctx_ctor = getattr(backend, "backend_ctx_ctor", null_context)
  1884. if config.caching_precompile and package is None:
  1885. # Create an uninitialized package that will be set/filled by
  1886. # _OptimizeContext.__call__
  1887. # We need to instantiate the object here because the same CompilePackage
  1888. # needs to be shared between convert_frame_assert
  1889. # and OptimizeContext.
  1890. from .package import CompilePackage
  1891. package = CompilePackage(fn=None, dynamo=None, ignore_inlined_sources=False)
  1892. return _optimize_catch_errors(
  1893. convert_frame.convert_frame_assert(
  1894. backend,
  1895. export=export,
  1896. export_constraints=export_constraints,
  1897. package=package,
  1898. ),
  1899. hooks,
  1900. backend_ctx_ctor,
  1901. fullgraph=True,
  1902. export=export,
  1903. dynamic=dynamic,
  1904. rebuild_ctx=rebuild_ctx,
  1905. package=package,
  1906. )
  1907. class TorchPatcher:
  1908. @staticmethod
  1909. @functools.cache
  1910. def patch() -> None:
  1911. # A better way to disable the following would be decorate the source
  1912. # functions with @torch._disable_dynamo. However, this causes issues
  1913. # with torch.deploy internally.
  1914. from .decorators import disable
  1915. torch.jit.trace = disable(
  1916. torch.jit.trace, reason="tracing into TorchScript not fully supported"
  1917. )
  1918. torch.jit.trace_module = disable(
  1919. torch.jit.trace_module,
  1920. reason="tracing into TorchScript not fully supported",
  1921. )
  1922. torch.jit._get_trace_graph = disable(
  1923. torch.jit._get_trace_graph,
  1924. reason="tracing into TorchScript not fully supported",
  1925. )
  1926. torch.fx._symbolic_trace.Tracer.trace = disable(
  1927. torch.fx._symbolic_trace.Tracer.trace,
  1928. reason="tracing into FX not fully supported",
  1929. )
  1930. torch.distributions.Distribution.set_default_validate_args(False)
  1931. from torch.optim import (
  1932. adadelta,
  1933. adagrad,
  1934. adam,
  1935. adamax,
  1936. adamw,
  1937. asgd,
  1938. lbfgs,
  1939. nadam,
  1940. radam,
  1941. rmsprop,
  1942. rprop,
  1943. sgd,
  1944. sparse_adam,
  1945. )
  1946. optimizer_modules = {
  1947. adadelta,
  1948. adagrad,
  1949. adam,
  1950. adamax,
  1951. adamw,
  1952. asgd,
  1953. lbfgs,
  1954. nadam,
  1955. radam,
  1956. rmsprop,
  1957. rprop,
  1958. sgd,
  1959. sparse_adam,
  1960. }
  1961. for opt_mod in optimizer_modules:
  1962. opt_name = opt_mod.__name__.split(".")[-1]
  1963. fused_fn_name = f"_fused_{opt_name}"
  1964. if hasattr(opt_mod, fused_fn_name):
  1965. setattr(
  1966. opt_mod,
  1967. fused_fn_name,
  1968. disable(
  1969. getattr(opt_mod, fused_fn_name),
  1970. reason="don't trace into fused optimizer",
  1971. ),
  1972. )
  1973. optimizer_classes = [
  1974. opt
  1975. for opt in torch.optim.__dict__.values()
  1976. if inspect.isclass(opt) and issubclass(opt, torch.optim.Optimizer)
  1977. ]
  1978. # Note: we don't support sparsity or tracing through backwards
  1979. excluded_optimizer_classes = {
  1980. torch.optim.SparseAdam,
  1981. torch.optim.LBFGS,
  1982. }
  1983. for opt in optimizer_classes:
  1984. if opt in excluded_optimizer_classes:
  1985. opt.step = disable(
  1986. opt.step, reason=f"optimizer {opt} step not supported"
  1987. )
  1988. if hasattr(opt, "_init_group"):
  1989. opt._init_group = disable(
  1990. opt._init_group, reason=f"optimizer {opt} _init_group not supported"
  1991. )
  1992. @staticmethod
  1993. def suppress_torch_distributed_warnings(
  1994. fn: Callable[..., Any],
  1995. ) -> Callable[..., Any]:
  1996. def inner_fn(*args: Any, **kwargs: Any) -> Any:
  1997. with torch._logging.hide_warnings(
  1998. torch._logging._internal.user_warning_filter
  1999. ):
  2000. return fn(*args, **kwargs)
  2001. return inner_fn
  2002. def skip_code(code: types.CodeType) -> None:
  2003. set_code_exec_strategy(
  2004. code, FrameExecStrategy(FrameAction.SKIP, FrameAction.DEFAULT)
  2005. )