convert_frame.py 69 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156115711581159116011611162116311641165116611671168116911701171117211731174117511761177117811791180118111821183118411851186118711881189119011911192119311941195119611971198119912001201120212031204120512061207120812091210121112121213121412151216121712181219122012211222122312241225122612271228122912301231123212331234123512361237123812391240124112421243124412451246124712481249125012511252125312541255125612571258125912601261126212631264126512661267126812691270127112721273127412751276127
71278127912801281128212831284128512861287128812891290129112921293129412951296129712981299130013011302130313041305130613071308130913101311131213131314131513161317131813191320132113221323132413251326132713281329133013311332133313341335133613371338133913401341134213431344134513461347134813491350135113521353135413551356135713581359136013611362136313641365136613671368136913701371137213731374137513761377137813791380138113821383138413851386138713881389139013911392139313941395139613971398139914001401140214031404140514061407140814091410141114121413141414151416141714181419142014211422142314241425142614271428142914301431143214331434143514361437143814391440144114421443144414451446144714481449145014511452145314541455145614571458145914601461146214631464146514661467146814691470147114721473147414751476147714781479148014811482148314841485148614871488148914901491149214931494149514961497149814991500150115021503150415051506150715081509151015111512151315141515151615171518151915201521152215231524152515261527152815291530153115321533153415351536153715381539154015411542154315441545154615471548154915501551155215531554155515561557155815591560156115621563156415651566156715681569157015711572157315741575157615771578157915801581158215831584158515861587158815891590159115921593159415951596159715981599160016011602160316041605160616071608160916101611161216131614161516161617161816191620162116221623162416251626162716281629163016311632163316341635163616371638163916401641164216431644164516461647164816491650165116521653165416551656165716581659166016611662166316641665166616671668166916701671167216731674167516761677167816791680168116821683168416851686168716881689169016911692169316941695169616971698169917001701170217031704170517061707170817091710171117121713171417151716171717181719172017211722172317241725172617271728172917301731173217331734173517361737173817391740174117421743174417451746174717481749175017511752175317541755175617571758175917601761176217631764176517661767176817691770177117721773177417751776177
717781779178017811782178317841785178617871788178917901791179217931794179517961797179817991800180118021803180418051806180718081809181018111812181318141815181618171818181918201821182218231824182518261827182818291830183118321833183418351836183718381839184018411842184318441845184618471848184918501851185218531854185518561857185818591860186118621863186418651866186718681869187018711872187318741875187618771878187918801881188218831884
  1. """
  2. This module implements TorchDynamo's core frame conversion functionality, transforming Python
  3. frames into FX graphs. It handles:
  4. - Frame analysis and bytecode transformation
  5. - Guard creation and management for dynamic behaviors
  6. - Cache management for recompilation
  7. - Error handling and fallback mechanisms
  8. Key classes:
  9. - ConvertFrame: Main entry point for frame conversion with error handling
  10. - ConvertFrameAssert: Implements core frame to graph conversion logic
  11. - Tracker: Tracks input/output code objects during conversion
  12. - CatchErrorsWrapper: Provides error handling and suppression logic
  13. The conversion process preserves program semantics while enabling optimizations
  14. through torch.compile() and related systems.
  15. NOTE: _torchdynamo_orig_backend is used for convert frame wrappers to identify the inner wrapped function.
  16. By going down the _torchdynamo_orig_backend chain, one can recover the original unwrapped backend,
  17. which is checked for during the Dynamo cache lookup.
  18. """
  19. from __future__ import annotations
  20. import collections
  21. import contextlib
  22. import cProfile
  23. import dis
  24. import functools
  25. import gc
  26. import itertools
  27. import logging
  28. import os
  29. import pstats
  30. import random
  31. import subprocess
  32. import sys
  33. import threading
  34. import time
  35. import traceback
  36. import types
  37. import typing
  38. import weakref
  39. from dataclasses import dataclass
  40. from pathlib import Path
  41. from types import CellType, CodeType, FunctionType, ModuleType
  42. from typing import Any, Callable, Optional, TypeVar, Union
  43. from typing_extensions import ParamSpec
  44. from weakref import ReferenceType
  45. import torch
  46. import torch._logging
  47. from torch._C._dynamo.guards import GlobalStateGuard
  48. from torch._dynamo.callback import CallbackTrigger
  49. from torch._dynamo.distributed import get_compile_pg
  50. from torch._dynamo.symbolic_convert import TensorifyState
  51. from torch._guards import compile_context, CompileContext, CompileId, tracing
  52. from torch._logging import structured
  53. from torch._utils_internal import (
  54. compile_time_strobelight_meta,
  55. justknobs_check,
  56. maybe_upload_prof_stats_to_manifold,
  57. signpost_event,
  58. )
  59. from torch.fx._lazy_graph_module import _use_lazy_graph_module
  60. from torch.fx.experimental.symbolic_shapes import (
  61. ConstraintViolationError,
  62. GuardOnDataDependentSymNode,
  63. )
  64. from torch.fx.graph_module import _forward_from_src as original_forward_from_src
  65. from torch.monitor import _WaitCounter
  66. from torch.nn.parallel.distributed import DistributedDataParallel
  67. from torch.utils._python_dispatch import (
  68. _disable_current_modes,
  69. is_in_any_mode_without_ignore_compile_internals,
  70. is_in_torch_dispatch_mode,
  71. )
  72. from torch.utils._traceback import CapturedTraceback, format_traceback_short
  73. from . import config, decorators, exc, graph_break_hints, trace_rules
  74. from .bytecode_analysis import remove_dead_code, remove_pointless_jumps
  75. from .bytecode_transformation import (
  76. check_inst_exn_tab_entries_valid,
  77. Instruction,
  78. is_generator,
  79. propagate_inst_exn_table_entries,
  80. transform_code_object,
  81. )
  82. from .cache_size import (
  83. CacheSizeRelevantForFrame,
  84. compute_cache_size,
  85. exceeds_recompile_limit,
  86. is_recompilation,
  87. )
  88. from .eval_frame import (
  89. always_optimize_code_objects,
  90. dynamo_tls,
  91. skip_code,
  92. TorchPatcher,
  93. )
  94. from .exc import (
  95. augment_exc_message,
  96. BackendCompilerFailed,
  97. FailOnRecompileLimitHit,
  98. format_error_msg,
  99. InternalTorchDynamoError,
  100. PackageError,
  101. RecompileLimitExceeded,
  102. ResumePrologueTracingError,
  103. ShortenTraceback,
  104. SkipCodeRecursiveException,
  105. TorchRuntimeError,
  106. UncapturedHigherOrderOpError,
  107. unimplemented_v2,
  108. Unsupported,
  109. )
  110. from .guards import (
  111. CheckFunctionManager,
  112. get_and_maybe_log_recompilation_reasons,
  113. GuardedCode,
  114. )
  115. from .hooks import Hooks
  116. from .output_graph import DynamoTracerOutput
  117. from .pgo import log_frame_dynamic_whitelist, put_code_state
  118. from .replay_record import ExecutionRecord
  119. from .resume_execution import TORCH_DYNAMO_RESUME_IN_PREFIX
  120. from .symbolic_convert import (
  121. DistributedState,
  122. ExceptionStack,
  123. InstructionTranslator,
  124. LocalState,
  125. SpeculationLog,
  126. )
  127. from .trace_rules import is_numpy
  128. from .types import ConvertFrameReturn, FrameAction, FrameExecStrategy, wrap_guarded_code
  129. from .utils import (
  130. _get_error_on_graph_break,
  131. chromium_event_timed,
  132. CleanupManager,
  133. CompileTimeInstructionCounter,
  134. counters,
  135. dynamo_timed,
  136. format_bytecode,
  137. gen_record_file_name,
  138. get_hook_for_recompile_user_context,
  139. get_metrics_context,
  140. increment_frame,
  141. is_namedtuple,
  142. istype,
  143. LazyString,
  144. maybe_disable_inference_mode,
  145. maybe_disable_inference_mode_for_fake_prop,
  146. orig_code_map,
  147. reset_graph_break_dup_checker,
  148. setup_compile_debug,
  149. to_int_us,
  150. troubleshooting_url,
  151. write_record_to_file,
  152. )
  153. from .variables.torch_function import torch_function_mode_stack_state_mgr
  154. np: Optional[ModuleType]
  155. try:
  156. import numpy as np
  157. except ModuleNotFoundError:
  158. np = None
  159. if typing.TYPE_CHECKING:
  160. from .backends.registry import CompilerFn
  161. from .package import CompilePackage
  162. from .repro.after_dynamo import WrapBackendDebug
  163. from .types import BytecodeHook, CacheEntry, DynamoFrameType
  164. from .variables.builder import FrameStateSizeEntry
# Module logger, plus artifact loggers registered under the "bytecode" and
# "graph_breaks" artifact names.
log = logging.getLogger(__name__)
bytecode_log = torch._logging.getArtifactLogger(__name__, "bytecode")
graph_break_log = torch._logging.getArtifactLogger(__name__, "graph_breaks")

# Re-entrant lock object.
# NOTE(review): presumably serializes compilation across threads -- confirm at
# the acquisition sites, which are outside this part of the file.
compile_lock = threading.RLock()

# Return-type variable and parameter spec used to type the decorators below.
_T = TypeVar("_T")
_P = ParamSpec("_P")
  171. class TODO_UNKNOWN:
  172. pass
  173. class Tracker:
  174. def __init__(self) -> None:
  175. self.seen: list[ReferenceType[CodeType]] = []
  176. self.seen_ids: set[int] = set()
  177. def add(self, strong_obj: CodeType) -> None:
  178. idx = id(strong_obj)
  179. if idx not in self.seen_ids:
  180. obj = weakref.ref(strong_obj, lambda _: self.seen_ids.remove(idx))
  181. self.seen.append(obj)
  182. self.seen_ids.add(idx)
  183. def __contains__(self, item: CodeType) -> bool:
  184. return id(item) in self.seen_ids
  185. def clear(self) -> None:
  186. self.seen.clear()
  187. self.seen_ids.clear()
# Code objects that have been handed to ConvertFrameAssert.__call__.
input_codes = Tracker()
# Code objects that ConvertFrameAssert.__call__ refuses to convert again.
# NOTE(review): presumably Dynamo-generated code; it is populated outside the
# part of the file visible here -- confirm.
output_codes = Tracker()
# GlobalStateGuard snapshot taken by ConvertFrameAssert.__call__ right before
# it starts compiling a frame; None until the first such call.
initial_global_state: Optional[GlobalStateGuard] = None
  191. @functools.wraps(original_forward_from_src)
  192. def fx_forward_from_src_skip_result(
  193. src: str, globals: dict[str, Any], co_fields: Optional[dict[str, str]] = None
  194. ) -> FunctionType:
  195. # we monkey patch FX to prevent infinite loop of trying to convert
  196. # our generated code
  197. result = original_forward_from_src(src, globals, co_fields)
  198. skip_code(result.__code__)
  199. return result
  200. def log_dynamo_start(code: CodeType, skip: int = 0) -> list[str]:
  201. convert_frame_intern = structured.intern_string(__file__)
  202. captured_tb = CapturedTraceback.extract(skip=4 + skip).summary()
  203. frames_interned = structured.from_traceback(captured_tb)
  204. # Extract and filter the stack
  205. stack = list(
  206. itertools.takewhile(
  207. lambda f: f["filename"] != convert_frame_intern,
  208. frames_interned,
  209. )
  210. ) + [
  211. {
  212. "line": code.co_firstlineno,
  213. "name": code.co_name,
  214. "filename": structured.intern_string(code.co_filename),
  215. }
  216. ]
  217. # Initialize the ChromiumEventLogger on start
  218. torch._logging.trace_structured(
  219. "dynamo_start",
  220. lambda: {"stack": stack},
  221. )
  222. # Capture stack separately without using from_traceback to get the actual filenames
  223. stack_strings = [
  224. f"Line: {frame.lineno}, Name: {frame.name}, Filename: {frame.filename}"
  225. for frame in captured_tb
  226. if frame.filename != convert_frame_intern
  227. ] + [
  228. f"Line: {code.co_firstlineno}, Name: {code.co_name}, Filename: {code.co_filename}"
  229. ]
  230. return stack_strings
def preserve_global_state(fn: Callable[_P, _T]) -> Callable[_P, _T]:
    """
    Decorator that runs ``fn`` and restores global state afterwards.

    The returned wrapper:
        1) Saves/restores torch.is_grad_enabled() and inference mode state
        2) Saves/restores the deterministic-algorithms settings
        3) Saves/restores the python, torch and (when available) CUDA random state
        4) Saves/restores the default dtype and the "cuda"/"matmul" fp32 precision
        5) Monkey patches torch.fx.graph_module._forward_from_src for the
           duration of the call and puts the previous value back afterwards
        6) Asserts that the torch function mode stack is empty and that the
           GlobalStateGuard taken on entry still checks out

    The original function is exposed on the wrapper as
    ``_torchdynamo_orig_backend``.
    """

    @functools.wraps(fn)
    def _fn(*args: _P.args, **kwargs: _P.kwargs) -> _T:
        # Snapshot used by the final assertion to detect leaked global state.
        guards = GlobalStateGuard()
        prior_grad_mode = torch.is_grad_enabled()
        # Just in case we get left in a bad dispatch state we want to restore
        # it. This can happen because the dispatch bits aren't a true
        # stack/counter - so we can't just increment/decrement them as we enter
        # and leave.
        with (
            torch._C._PreserveDispatchKeyGuard(),
            maybe_disable_inference_mode(),
            maybe_disable_inference_mode_for_fake_prop(),
        ):
            # Record everything that is restored in the ``finally`` below.
            prior_inference_mode = torch.is_inference_mode_enabled()
            prior_deterministic = torch.are_deterministic_algorithms_enabled()
            prior_warn_only = torch.is_deterministic_algorithms_warn_only_enabled()
            prior_mobile_allocator_state = (
                torch._C._is_default_mobile_cpu_allocator_set()
            )
            py_rng_state = random.getstate()
            prior_dtype = torch.get_default_dtype()
            torch_rng_state = torch.random.get_rng_state()
            # Stays None when CUDA is unavailable, which also disables the
            # CUDA RNG restore below.
            cuda_rng_state = None
            if torch.cuda.is_available():
                cuda_rng_state = torch.cuda.get_rng_state()
            cuda_matmul_fp32_prec = torch._C._get_fp32_precision_getter(
                "cuda", "matmul"
            )
            # Swap in the FX hook that marks generated forward code as skipped.
            prior_fwd_from_src = torch.fx.graph_module._forward_from_src
            torch.fx.graph_module._forward_from_src = fx_forward_from_src_skip_result
            cleanup = setup_compile_debug()
            exit_stack = contextlib.ExitStack()
            exit_stack.enter_context(
                torch.fx._symbolic_trace._maybe_revert_all_patches()
            )
            exit_stack.enter_context(torch_function_mode_stack_state_mgr)
            try:
                return fn(*args, **kwargs)
            finally:
                cleanup.close()
                # Checked before the exit stack is unwound.
                assert torch._C._len_torch_function_stack() == 0, (
                    "Torch function mode stack state changed while dynamo tracing, please report a bug"
                )
                exit_stack.close()
                torch._C._set_grad_enabled(prior_grad_mode)
                torch.autograd.grad_mode._enter_inference_mode(prior_inference_mode)
                torch.use_deterministic_algorithms(
                    prior_deterministic, warn_only=prior_warn_only
                )
                random.setstate(py_rng_state)
                torch.random.set_rng_state(torch_rng_state)
                torch.set_default_dtype(prior_dtype)
                curr_mobile_allocator_state = (
                    torch._C._is_default_mobile_cpu_allocator_set()
                )
                # Only the "unset" direction is handled when the state differs.
                if prior_mobile_allocator_state != curr_mobile_allocator_state:
                    torch._C._unset_default_mobile_cpu_allocator()
                if cuda_rng_state is not None:
                    torch.cuda.set_rng_state(cuda_rng_state)
                torch._C._set_fp32_precision_setter(
                    "cuda", "matmul", cuda_matmul_fp32_prec
                )
                torch.fx.graph_module._forward_from_src = prior_fwd_from_src
                # Runs last, after every restore above.
                assert guards.check(), (
                    f"Global {guards.reason()}state changed while dynamo tracing, please report a bug"
                )

    # Lets callers walk back to the wrapped function (see the module docstring).
    _fn._torchdynamo_orig_backend = fn  # type: ignore[attr-defined]
    return _fn
  307. @TorchPatcher.suppress_torch_distributed_warnings
  308. def has_tensor_in_frame(frame: DynamoFrameType) -> bool:
  309. """Check if the frame has torch.* related bits"""
  310. # Check if the function was decorated using torch._dynamo.optimize
  311. if frame.f_code in always_optimize_code_objects:
  312. return True
  313. # Check if there is global import of torch.*
  314. for co_name in frame.f_code.co_names:
  315. if co_name in frame.f_globals:
  316. obj = frame.f_globals[co_name]
  317. if isinstance(obj, ModuleType) and (
  318. obj.__name__.startswith("torch.") or obj is torch
  319. ):
  320. return True
  321. # ... or a global import of numpy.*
  322. if np and config.trace_numpy and (obj is np or is_numpy(obj)):
  323. return True
  324. seen_ids: dict[int, bool] = {}
  325. def has_tensor(obj: object) -> bool:
  326. """Recursively check if the obj has a tensor"""
  327. obj_id = id(obj)
  328. if obj_id in seen_ids:
  329. return seen_ids[obj_id]
  330. seen_ids[obj_id] = False
  331. if isinstance(obj, (torch.Tensor, torch.nn.Module)) or (
  332. istype(obj, type) and issubclass(obj, torch.nn.Module)
  333. ):
  334. seen_ids[obj_id] = True
  335. return seen_ids[obj_id]
  336. elif (
  337. config.trace_numpy
  338. and np
  339. and (istype(obj, np.ndarray) or isinstance(obj, np.generic))
  340. ):
  341. seen_ids[obj_id] = True
  342. return seen_ids[obj_id]
  343. elif istype(obj, (list, tuple)):
  344. seen_ids[obj_id] = any(has_tensor(v) for v in obj)
  345. return seen_ids[obj_id]
  346. elif istype(obj, dict):
  347. # Some packages like pytest can be updated during runtime. So, make a
  348. # copy of values to avoid issues like "RuntimeError: dictionary
  349. # changed size during iteration"
  350. values = list(obj.values())
  351. seen_ids[obj_id] = any(has_tensor(v) for v in values)
  352. return seen_ids[obj_id]
  353. elif istype(obj, (str, int, float, type(None), bool)):
  354. seen_ids[obj_id] = False
  355. return seen_ids[obj_id]
  356. elif is_namedtuple(obj) and hasattr(obj, "_fields"):
  357. seen_ids[obj_id] = any(has_tensor(getattr(obj, v)) for v in obj._fields)
  358. return seen_ids[obj_id]
  359. else:
  360. # if config.debug:
  361. # print(
  362. # f"Assuming that object of type {type(obj)} does not have a tensor"
  363. # )
  364. return False
  365. # Check if the passed arguments are of type Tensor
  366. for value in frame.f_locals.values():
  367. if has_tensor(value):
  368. return True
  369. log.debug(
  370. "skipping because no torch.* %s \
  371. %s %s",
  372. frame.f_code.co_name,
  373. frame.f_code.co_filename,
  374. frame.f_code.co_firstlineno,
  375. )
  376. return False
  377. def exception_handler(
  378. e: Exception,
  379. code: CodeType,
  380. frame: Optional[DynamoFrameType] = None,
  381. export: bool = False,
  382. ) -> None:
  383. record_filename = None
  384. if hasattr(e, "exec_record"):
  385. record_filename = gen_record_file_name(e, code)
  386. write_record_to_file(record_filename, e.exec_record)
  387. e.record_filename = record_filename # type: ignore[attr-defined]
  388. augment_exc_message(e, export=export)
# Next frame id to hand out; advanced by get_compile_id() whenever it sees a
# frame_state without an "_id" entry.
FRAME_COUNTER = 0
# Number of compile ids issued so far per frame id; read and incremented by
# get_compile_id().
FRAME_COMPILE_COUNTER: typing.Counter[Union[int, FrameStateSizeEntry]] = (
    collections.Counter()
)
  393. def maybe_cprofile(func: Callable[_P, _T]) -> Callable[_P, _T]:
  394. if config.cprofile:
  395. return cprofile_wrapper(func)
  396. return func
  397. def cprofile_wrapper(func: Callable[_P, _T]) -> Callable[_P, _T]:
  398. @functools.wraps(func)
  399. def profile_wrapper(*args: _P.args, **kwargs: _P.kwargs) -> _T:
  400. trace_id = CompileContext.current_trace_id()
  401. assert trace_id, "Trace id is None"
  402. profile_path = Path(
  403. f"/tmp/{func.__name__}_{str(trace_id).replace('/', '_')}.profile"
  404. )
  405. prof = cProfile.Profile()
  406. try:
  407. prof.enable()
  408. start_ts = time.time()
  409. retval = prof.runcall(func, *args, **kwargs)
  410. profile_latency = time.time() - start_ts
  411. prof.disable()
  412. except ValueError:
  413. log.exception("failed to enable cProfile")
  414. profile_latency = 0
  415. retval = func(*args, **kwargs)
  416. log.warning(
  417. "### Cprofile for %s trace id [%s] took %.3f seconds ###",
  418. func.__name__,
  419. trace_id,
  420. profile_latency,
  421. )
  422. ps = pstats.Stats(prof)
  423. try:
  424. prof.dump_stats(profile_path)
  425. except OSError:
  426. log.exception("Cannot write to %s", profile_path)
  427. log.warning("Raw profile at %s", profile_path)
  428. svg_path = profile_path.with_suffix(".svg")
  429. try:
  430. gprof2dot_process = subprocess.Popen(
  431. [
  432. "gprof2dot",
  433. "-f",
  434. "pstats",
  435. "--node-label=total-time-percentage",
  436. "--node-label=self-time-percentage",
  437. "--node-label=total-time",
  438. str(profile_path),
  439. ],
  440. stdout=subprocess.PIPE,
  441. )
  442. subprocess.check_call(
  443. ["dot", "-Tsvg", "-o", str(svg_path)],
  444. stdin=gprof2dot_process.stdout,
  445. )
  446. log.warning("Generated SVG from profile at %s", svg_path)
  447. except FileNotFoundError:
  448. log.warning(
  449. "Failed to generate SVG from profile -- dumping stats instead."
  450. "Try installing gprof2dot and dot for a better visualization"
  451. )
  452. ps.sort_stats(pstats.SortKey.TIME).print_stats(20)
  453. ps.sort_stats(pstats.SortKey.CUMULATIVE).print_stats(20)
  454. if manifold_link := maybe_upload_prof_stats_to_manifold(
  455. str(profile_path)
  456. ): # fb-only
  457. torch._logging.trace_structured(
  458. "link",
  459. lambda: {"name": "cprofile_manifold_url", "url": manifold_link},
  460. )
  461. return retval
  462. return profile_wrapper
@dataclass
class ConvertFrameBox:
    """Small mutable record attached to each ConvertFrameAssert instance.

    ConvertFrameAssert creates one in ``__init__`` and passes it to
    ``_compile`` as ``convert_frame_box``.
    """

    # None until a value is assigned.
    # NOTE(review): presumably filled in during compilation with the
    # error-on-graph-break setting -- confirm at the assignment site.
    error_on_graph_break: Optional[bool] = None
  466. def get_compile_id(
  467. frame_state: dict[str, Union[int, FrameStateSizeEntry]],
  468. ) -> CompileId:
  469. global FRAME_COUNTER
  470. if "_id" not in frame_state:
  471. frame_state["_id"] = FRAME_COUNTER
  472. FRAME_COUNTER += 1
  473. frame_id = frame_state["_id"]
  474. assert isinstance(frame_id, int)
  475. frame_compile_id = FRAME_COMPILE_COUNTER[frame_id]
  476. FRAME_COMPILE_COUNTER[frame_id] += 1
  477. compiled_autograd_id = None
  478. if prior := CompileContext.current_compile_id():
  479. compiled_autograd_id = prior.compiled_autograd_id
  480. return CompileId(
  481. compiled_autograd_id=compiled_autograd_id,
  482. frame_id=frame_id,
  483. frame_compile_id=frame_compile_id,
  484. )
class ConvertFrameAssert:
    """Callable that decides whether a frame is compiled and, if so, compiles it.

    Calling an instance with a frame either returns a default
    ``ConvertFrameReturn()`` (the frame is one of the cases filtered out
    below) or the result of ``_compile`` for that frame. The wrapped backend
    is stored as ``_torchdynamo_orig_backend``.
    """

    def __init__(
        self,
        compiler_fn: CompilerFn,
        one_graph: bool = True,
        export: bool = False,
        export_constraints: Optional[typing.Never] = None,
        package: Optional[CompilePackage] = None,
    ) -> None:
        """Store the backend and the options that are forwarded to ``_compile``.

        Also calls ``reset_graph_break_dup_checker()`` and creates a fresh
        ``ConvertFrameBox`` for this instance.
        """
        # assert export_constraints is None
        reset_graph_break_dup_checker()
        self._torchdynamo_orig_backend = compiler_fn
        self._one_graph = one_graph
        self._export = export
        self._export_constraints = export_constraints
        self._package = package
        self._box = ConvertFrameBox()

    @property
    def _clone_with_backend(self) -> Callable[[CompilerFn], ConvertFrameAssert]:
        """Factory building a new converter with the same options but another backend.

        NOTE(review): ``self._package`` is not forwarded, so the clone is
        created with ``package=None`` -- confirm that this is intended.
        """
        return lambda backend: convert_frame_assert(
            backend,
            self._one_graph,
            self._export,
            self._export_constraints,
        )

    def __call__(
        self,
        frame: DynamoFrameType,
        cache_entry: Optional[CacheEntry],
        hooks: Hooks,
        frame_state: dict[str, Union[int, FrameStateSizeEntry]],
        *,
        skip: int = 0,
    ) -> ConvertFrameReturn:
        """Compile ``frame`` unless it matches one of the skip conditions.

        Args:
            frame: Frame to convert.
            cache_entry: Passed to ``compute_cache_size`` and ``_compile``.
            hooks: Passed through to ``_compile``.
            frame_state: Mutable per-frame state; ``get_compile_id`` may add
                an ``"_id"`` entry to it.
            skip: Passed to ``_compile`` as ``skip + 1``.

        Returns:
            A default ``ConvertFrameReturn()`` for skipped frames,
            ``ConvertFrameReturn(apply_to_code=False)`` when the caller is the
            non-recursive disable wrapper, otherwise the ``_compile`` result.
        """
        increment_frame()
        code = frame.f_code
        cache_size = compute_cache_size(frame, cache_entry)
        input_codes.add(code)
        # Code tracked in output_codes is never converted again.
        if code in output_codes:
            return ConvertFrameReturn()
        # When TORCHDYNAMO_DEBUG_FUNCTION is set, only the function with that
        # name is compiled.
        if (
            os.environ.get("TORCHDYNAMO_DEBUG_FUNCTION")
            and os.environ.get("TORCHDYNAMO_DEBUG_FUNCTION") != code.co_name
        ):
            return ConvertFrameReturn()
        if code.co_name == "<genexpr>" and code.co_filename.endswith(
            (
                "transformers/file_utils.py",
                "transformers/utils/generic.py",
                "diffusers/utils/outputs.py",
            )
        ):
            # not needed, but cleans up torchbench error stats
            return ConvertFrameReturn()
        if code.co_name == "__setattr__":
            # setattr could be tricky to handle generally,
            # but also not likely useful to compile- skip the whole frame
            return ConvertFrameReturn()
        if code.co_name == "__init__" and code.co_filename.startswith(
            os.path.dirname(torch.optim.__file__)
        ):
            # optimizer support is still incomplete see
            # test_state_dict in test/dynamo/test_optimizers.py
            return ConvertFrameReturn()
        # Check if the frame is generated by an exec builtin call
        # TODO - Running exec generated frame seems propagates f_globals to the
        # next frames.
        if code.co_name == "<module>" and code.co_filename == "<string>":
            return ConvertFrameReturn()
        if (
            code.co_name == "<lambda>"
            and code.co_filename == "<string>"
            and not bool(frame.f_builtins)
        ):
            # namedtuple subclass constructor. Empty builtins cause issue with
            # len keyword in LIST_LEN guard.
            return ConvertFrameReturn()
        if is_generator(code):
            unimplemented_v2(
                gb_type="Attempt to trace generator",
                context="",
                explanation="Generators cannot be compiled directly with `torch.compile`.",
                hints=[
                    "Call a generator from inside of a non-generator Python function and "
                    "compile that function instead.",
                    *graph_break_hints.FUNDAMENTAL,
                ],
            )
        if not has_tensor_in_frame(frame):
            return ConvertFrameReturn()
        # skip tracing non-recursive disabled functions
        # detect if the previous frame (non-convert_frame) is a non-recursive disable wrapper
        # NOTE(review): the file is recognized through a "/"-separated path
        # fragment; confirm this also matches on platforms that use another
        # path separator.
        prev_frame = sys._getframe()
        while (
            prev_frame
            and "torch/_dynamo/convert_frame.py" in prev_frame.f_code.co_filename
        ):
            prev_frame = prev_frame.f_back  # type: ignore[assignment]
        if (
            prev_frame
            and prev_frame.f_code is decorators._nonrecursive_disable_wrapper_code
        ):
            return ConvertFrameReturn(apply_to_code=False)
        # Snapshot of the global state right before compilation starts.
        global initial_global_state
        initial_global_state = GlobalStateGuard()
        compile_id = get_compile_id(frame_state)
        frame_id = compile_id.frame_id
        signpost_event(
            "dynamo",
            "_convert_frame_assert._compile",
            {
                "co_name": code.co_name,
                "frame_id": frame_id,
                "compile_id": str(compile_id),
                "co_filename": code.co_filename,
                "co_firstlineno": code.co_firstlineno,
                "cache_size": cache_size.num_cache_entries_with_same_id_matched_objs,
                "accumulated_cache_size": cache_size.num_cache_entries,
            },
        )
        # Record traced frames, skipping Dynamo generated ones.
        if not code.co_name.startswith(TORCH_DYNAMO_RESUME_IN_PREFIX):
            info = f"{code.co_name} {code.co_filename}:{code.co_firstlineno}"
            dynamo_tls.traced_frame_infos.append(info)
        with compile_context(CompileContext(compile_id)):
            result = _compile(
                frame.f_code,
                frame.f_globals,
                frame.f_locals,
                frame.f_builtins,
                frame.closure,
                self._torchdynamo_orig_backend,
                self._one_graph,
                self._export,
                self._export_constraints,
                hooks,
                cache_entry,
                cache_size,
                frame,
                frame_state=frame_state,
                compile_id=compile_id,
                skip=skip + 1,
                package=self._package,
                convert_frame_box=self._box,
            )
        if config.caching_precompile and self._package is not None:
            from .package import DynamoCache
            # Record that the dynamo package has changed
            DynamoCache.record_package(self._package)
        return result
  635. def convert_frame_assert(
  636. compiler_fn: CompilerFn,
  637. one_graph: bool = True,
  638. export: bool = False,
  639. export_constraints: Optional[typing.Never] = None,
  640. package: Optional[CompilePackage] = None,
  641. ) -> ConvertFrameAssert:
  642. """Fully convert a frame into an FX graph, raising an exception if we fail."""
  643. return ConvertFrameAssert(
  644. compiler_fn, one_graph, export, export_constraints, package
  645. )
  646. from collections import OrderedDict
  647. from torch.utils.hooks import RemovableHandle
# Registered bytecode hooks, keyed by the id of the RemovableHandle returned
# from register_bytecode_hook().
# we have to use `OrderedDict` to make `RemovableHandle` work.
_bytecode_hooks: dict[int, BytecodeHook] = OrderedDict()
  650. def register_bytecode_hook(hook: BytecodeHook) -> RemovableHandle:
  651. """Register hooks for bytecode generated by Dynamo. The hook can do some
  652. logging, as well as return a new code object to be used. Please refer
  653. to `BytecodeHook` for the hook signature.
  654. """
  655. handle = RemovableHandle(_bytecode_hooks)
  656. _bytecode_hooks[handle.id] = hook
  657. return handle
@preserve_global_state
def trace_frame(
    code: types.CodeType,
    globals: dict[str, object],
    locals: dict[str, object],
    builtins: dict[str, object],
    closure: tuple[CellType],
    compiler_fn: CompilerFn,
    tf_mode_stack: list[torch.overrides.TorchFunctionMode],
    one_graph: bool,
    speculation_log: SpeculationLog,
    instructions: list[Instruction],
    code_options: dict[str, object],
    *,
    export: bool = False,
    export_constraints: Optional[typing.Never] = None,
    frame_state: Optional[dict[str, Union[int, FrameStateSizeEntry]]] = None,
    distributed_state: Optional[DistributedState] = None,
    package: Optional[CompilePackage] = None,
) -> DynamoTracerOutput:
    """
    Run one tracing attempt over ``instructions`` with an ``InstructionTranslator``.

    On success, ``instructions`` and ``code_options`` are mutated in place:
    ``instructions`` is replaced by the output graph's generated instructions
    (after exception-table propagation/validation, dead-code removal and
    pointless-jump removal) and ``code_options`` is updated from the output
    graph's code options.

    ``speculation_log`` is restarted at the beginning of every call, and is
    cleared if tracing raises ``exc.UnspecializeRestartAnalysis``.

    Returns:
        A ``DynamoTracerOutput`` wrapping the tracer that was run.

    Raises:
        Any exception raised while tracing or post-processing is re-raised.
        Before re-raising, a ``DynamoTracerOutput(tracer, error=True)`` is
        attached to the exception as ``_torch_dynamo_tracer_output``.
    """
    from torch.fx.experimental.validator import bisect, translation_validation_enabled

    speculation_log.restart()  # type: ignore[has-type]
    exn_vt_stack = ExceptionStack()
    tracer = InstructionTranslator(
        instructions,
        code,
        locals,
        globals,
        builtins,
        closure,
        tf_mode_stack,
        code_options,
        compiler_fn,
        one_graph,
        export,
        export_constraints,
        frame_state=frame_state,
        speculation_log=speculation_log,  # type: ignore[has-type]
        exn_vt_stack=exn_vt_stack,
        distributed_state=distributed_state,  # type: ignore[has-type]
        package=package,
    )

    def run_tracer() -> None:
        # Drives the tracer. The order of the `except` clauses matters: the
        # restart/skip exceptions must be matched before the generic
        # `except Exception` so they never reach the bisect path.
        try:
            tracer.output.mark_bytecode_tracing_start()
            with tracing(tracer.output.tracing_context), tracer.set_current_tx():
                tracer.run()
        except exc.UnspecializeRestartAnalysis:
            # This restart kind additionally discards the speculation log.
            speculation_log.clear()  # type: ignore[has-type]
            raise
        except (
            exc.SpeculationRestartAnalysis,
            exc.TensorifyScalarRestartAnalysis,
            exc.SkipFrame,
        ):
            # Propagated untouched.
            raise
        except Exception:
            # Any other failure: bisect the shape env when translation
            # validation is enabled, then propagate the original exception.
            if translation_validation_enabled():
                bisect(tracer.output.shape_env)
            raise
        finally:
            # Cleanup hooks run on both the success and the failure path.
            tracer.output.call_cleanup_hooks()

    try:
        run_tracer()
        tracer_output = DynamoTracerOutput(tracer)
        output = tracer_output.output_graph
        assert output is not None
        assert output.output_instructions
        # Slice assignment: the caller's list object itself is updated.
        instructions[:] = output.output_instructions
        code_options.update(output.code_options)
        propagate_inst_exn_table_entries(instructions)
        check_inst_exn_tab_entries_valid(instructions)
        instructions[:] = remove_pointless_jumps(remove_dead_code(instructions))
    except Exception as e:
        # Attach the (error) tracer output so callers can still inspect it.
        e._torch_dynamo_tracer_output = DynamoTracerOutput(tracer, error=True)  # type: ignore[attr-defined]
        raise

    return tracer_output
  735. @dataclass
  736. class DynamoOutput:
  737. """
  738. Represents the core data returned from a single dynamo run, including:
  739. - Guards, wrapped inside tracer_output.output_graph.guards
  740. - Generated bytecode
  741. - Other information needed for compilation.
  742. This data structure should capture all the "interesting" information dynamo
  743. produces on the frontend side before it enters user backend.
  744. """
  745. tracer_output: DynamoTracerOutput
  746. bytecode: types.CodeType
  747. last_attempt_start_time: Optional[float]
  748. def build_guards(
  749. self,
  750. code: types.CodeType,
  751. hooks: Optional[Hooks] = None,
  752. save: bool = False,
  753. cache_entry: Optional[CacheEntry] = None,
  754. strict_error: bool = False,
  755. ) -> CheckFunctionManager:
  756. assert self.tracer_output.output_graph is not None
  757. return CheckFunctionManager(
  758. code,
  759. self.tracer_output.output_graph,
  760. cache_entry,
  761. hooks.guard_fail_fn if hooks else None,
  762. hooks.guard_filter_fn if hooks else None,
  763. save_guards=save,
  764. strict_error=strict_error,
  765. )
@dataclass
class BackendInput:
    """
    Represents the core data structure that dynamo passes to a backend, including:
    - Graph module
    - Example inputs
    - The FakeTensorMode used for compiling the graph.

    This data structure should capture all the information dynamo produces
    for the user backend.
    """

    # Unique id of the backend invocation (fullgraph_capture reads it from
    # gm.meta["backend_id"]).
    backend_id: str
    # The captured FX graph module.
    graph_module: torch.fx.GraphModule
    # Example inputs handed to the backend together with the graph module.
    example_inputs: Any
    # Fake tensor mode that was active in the tracing context during capture.
    fake_mode: torch._subclasses.fake_tensor.FakeTensorMode
@dataclass
class CaptureOutput:
    """
    CaptureOutput should represent all the information produced by the torch
    compiler for a single graph capture. It is intended to be consumed by
    various compiler frontends so that we can share as much of the compiler
    internals as possible and avoid great divergence between different stacks.

    This data structure should eventually contain all the information the
    compiler produces, as more refactors happen to converge the different
    compiler frontends.
    """

    # Dynamo-side results: tracer output, generated bytecode, timing info.
    dynamo_output: DynamoOutput
    # What was handed to the backend: graph module, example inputs, fake mode.
    backend_input: BackendInput
@dataclass
class FrameInfo:
    """
    The pieces of a Python frame that fullgraph_capture() forwards to
    compile_frame(): the code object plus the namespaces and closure cells
    it runs with.
    """

    # Code object to be traced.
    code: types.CodeType
    # Global namespace of the frame.
    globals: dict[str, object]
    # Local namespace of the frame.
    locals: dict[str, object]
    # Builtins namespace of the frame.
    builtins: dict[str, object]
    # Closure cells of the frame.
    closure: tuple[CellType]
  800. def fullgraph_capture(
  801. frame: FrameInfo, *, _is_export_deprecated_do_not_use: bool = False
  802. ) -> CaptureOutput:
  803. """
  804. A standalone function which takes a frame and returns dynamo captured graph
  805. plus other important compile information. This should serve as the common
  806. interface for different torch compiler AOT frontengs (e.g. precompile, export).
  807. Note that this function doesn't apply context managers like metrics context
  808. or compile id, and the expectation is that the caller will apply them depending
  809. on the use case.
  810. The CaptureOutput is separated into two parts:
  811. 1. Dynamo specific information from DynamoOutput, which includes:
  812. - guards
  813. - generated bytecode
  814. - other information tracked by OutputGraph.
  815. 2. Backend specific information (indexed by unique backend id) such as:
  816. - fx graph
  817. - example inputs
  818. """
  819. from torch._guards import TracingContext
  820. backend_input: Optional[BackendInput] = None
  821. def fullgraph_compiler(
  822. gm: torch.fx.GraphModule, example_inputs: list[torch.Tensor]
  823. ) -> torch.fx.GraphModule:
  824. nonlocal backend_input
  825. fake_mode = TracingContext.get().fake_mode
  826. assert fake_mode is not None
  827. assert isinstance(gm.meta["backend_id"], str)
  828. backend_input = BackendInput(
  829. gm.meta["backend_id"], gm, example_inputs, fake_mode
  830. )
  831. return gm
  832. try:
  833. dynamo_output = compile_frame(
  834. frame.code,
  835. frame.globals,
  836. frame.locals,
  837. frame.builtins,
  838. frame.closure,
  839. compiler_fn=fullgraph_compiler,
  840. export=_is_export_deprecated_do_not_use,
  841. one_graph=True,
  842. restart_reasons=set(),
  843. )
  844. # https://github.com/pytorch/pytorch/blob/main/torch/_dynamo/eval_frame.py#L831
  845. except Unsupported as e:
  846. augment_exc_message(e)
  847. if config.verbose:
  848. raise
  849. # strip internal tracebacks from causes
  850. cur_exn: BaseException = e
  851. while cur_exn.__cause__ is not None:
  852. cur_exn.__cause__.with_traceback(None)
  853. cur_exn = cur_exn.__cause__
  854. raise e.with_traceback(None) from e.__cause__ # User compiler error
  855. assert backend_input is not None
  856. return CaptureOutput(dynamo_output, backend_input)
  857. def compile_frame( # type: ignore[return]
  858. code: types.CodeType,
  859. globals: dict[str, object],
  860. locals: dict[str, object],
  861. builtins: dict[str, object],
  862. closure: tuple[CellType],
  863. compiler_fn: CompilerFn,
  864. one_graph: bool,
  865. restart_reasons: set[str],
  866. *,
  867. export: bool = False,
  868. export_constraints: Optional[typing.Never] = None,
  869. frame_state: Optional[dict[str, Union[int, FrameStateSizeEntry]]] = None,
  870. distributed_state: Optional[DistributedState] = None,
  871. package: Optional[CompilePackage] = None,
  872. ) -> DynamoOutput:
  873. """
  874. A helper function taking a frame and backend, then return the generated bytecode
  875. and guards as a common data structure.
  876. This is a shared interface for multiple compiler frontends (e.g. torch.compile,
  877. torch.export) that needs to capture a graph out of python code.
  878. """
  879. # This is shared across restarts
  880. speculation_log = SpeculationLog()
  881. def transform(
  882. instructions: list[Instruction], code_options: dict[str, object]
  883. ) -> DynamoTracerOutput:
  884. tf_mode_stack: list[torch.overrides.TorchFunctionMode] = (
  885. torch.overrides._get_current_function_mode_stack()
  886. )
  887. tracer_output = trace_frame(
  888. code,
  889. globals,
  890. locals,
  891. builtins,
  892. closure,
  893. compiler_fn,
  894. tf_mode_stack,
  895. one_graph,
  896. speculation_log,
  897. instructions,
  898. code_options,
  899. export=export,
  900. export_constraints=export_constraints,
  901. frame_state=frame_state,
  902. distributed_state=distributed_state,
  903. package=package,
  904. )
  905. assert tracer_output is not None
  906. return tracer_output
  907. last_attempt_start_time = None
  908. for attempt in itertools.count():
  909. CompileContext.get().attempt = attempt
  910. try:
  911. with dynamo_timed(f"compile_attempt_{attempt}", log_pt2_compile_event=True):
  912. bytecode, tracer_output = transform_code_object(code, transform)
  913. assert tracer_output is not None
  914. return DynamoOutput(
  915. tracer_output=tracer_output,
  916. bytecode=bytecode,
  917. last_attempt_start_time=last_attempt_start_time,
  918. )
  919. except exc.RestartAnalysis as e:
  920. if not isinstance(e, exc.TensorifyScalarRestartAnalysis):
  921. TensorifyState.clear()
  922. log.info(
  923. "Restarting analysis due to %s",
  924. LazyString(format_traceback_short, e.__traceback__),
  925. )
  926. # If restart reason is None just log the type of the exception
  927. restart_reasons.add(e.restart_reason or str(type(e)))
  928. # We now have a new "last attempt", reset the clock
  929. last_attempt_start_time = time.time()
  930. if attempt > 100:
  931. unimplemented_v2(
  932. gb_type="Excessive RestartAnalysis() calls",
  933. context="",
  934. explanation="Dynamo attempted to trace the same frame 100+ times. "
  935. "Giving up on compiling as the compile time tradeoff is likely not "
  936. "worth the performance gain.",
  937. hints=[],
  938. )
  939. except exc.SkipFrame as e:
  940. if not isinstance(e, exc.TensorifyScalarRestartAnalysis):
  941. TensorifyState.clear()
  942. log.debug(
  943. "Skipping frame %s %s \
  944. %s %s",
  945. e,
  946. code.co_name,
  947. code.co_filename,
  948. code.co_firstlineno,
  949. )
  950. raise
def _compile(
    code: CodeType,
    globals: dict[str, object],
    locals: dict[str, object],
    builtins: dict[str, object],
    closure: tuple[CellType],
    compiler_fn: CompilerFn,
    one_graph: bool,
    export: bool,
    export_constraints: Optional[typing.Never],
    hooks: Hooks,
    cache_entry: Optional[CacheEntry],
    cache_size: CacheSizeRelevantForFrame,
    frame: Optional[DynamoFrameType] = None,
    frame_state: Optional[dict[str, Union[int, FrameStateSizeEntry]]] = None,
    *,
    compile_id: CompileId,
    skip: int = 0,
    package: Optional[CompilePackage] = None,
    # Can be used to record things for the caller, both
    # in the case of normal and exception code paths
    convert_frame_box: Optional[ConvertFrameBox] = None,
) -> ConvertFrameReturn:
    """
    Compile one frame's code object and return the result for the caller.

    The work happens in three stages, all inside the context managers set up
    in the ``with`` statement below:

    1. Recompilation bookkeeping: look up recompile reasons, record them in
       the metrics context, and raise if a recompile limit was exceeded.
    2. ``compile_inner`` -> ``_compile_inner``: run ``compile_frame``, apply
       the registered bytecode hooks, sanity-check the new code object,
       build guards and wrap everything into guarded code.
    3. A ``finally`` block that always records compilation metrics and, if
       ``convert_frame_box`` is given, stores ``error_on_graph_break`` on it.

    Args:
        code, globals, locals, builtins, closure: frame pieces forwarded to
            ``compile_frame``.
        compiler_fn, one_graph, export, export_constraints, frame_state:
            forwarded to ``compile_frame``.
        hooks: passed to guard building; ``hooks.guard_export_fn`` is called
            with the guards when the graph is not empty.
        cache_entry: passed to guard building and to the recompilation-reason
            lookup.
        cache_size: used for the recompilation checks and logged in metrics.
        frame: when truthy, used to look up recompilation reasons; also passed
            to ``exception_handler``.
        compile_id: id for the ``CompileContext`` and the guarded code.
        skip: extra caller frames to elide from the logged stack trace.
        package: when not None, guards are built with ``save=True`` and the
            guarded code and inlined sources are added to the package.
        convert_frame_box: optional out-parameter, see stage 3 above.

    Returns:
        The wrapped guarded code on success, or an empty
        ``ConvertFrameReturn()`` when the frame was skipped or an export run
        produced an empty graph.

    Raises:
        FailOnRecompileLimitHit, RecompileLimitExceeded: on recompile-limit
            violations (see the branches below).
        InternalTorchDynamoError: wraps any exception type that is not in the
            explicit re-raise list of the ``except`` block.
    """
    from torch._inductor.async_compile import async_compile_pool_manager
    from torch.fx.experimental.validator import (
        BisectValidationException,
        ValidationException,
    )

    # Only nonlocal defs here please!
    # Time spent compiling this frame before restarting or failing analysis
    dynamo_time_before_restart: float = 0.0

    @compile_time_strobelight_meta(phase_name="compile_inner")
    def compile_inner(
        code: CodeType, one_graph: bool, hooks: Hooks
    ) -> tuple[ConvertFrameReturn, Optional[DynamoTracerOutput]]:
        # Wraps _compile_inner with the DYNAMO callbacks and the compile-time
        # instruction counter.
        with contextlib.ExitStack() as stack:
            stack.enter_context(
                torch._dynamo.callback_handler.install_callbacks(
                    CallbackTrigger.DYNAMO, str(CompileContext.current_compile_id())
                )
            )
            stack.enter_context(CompileTimeInstructionCounter.record())
            return _compile_inner(code, one_graph, hooks)

        return (
            ConvertFrameReturn(),
            None,
        )  # dead, but see https://github.com/python/mypy/issues/7577

    @maybe_cprofile
    def _compile_inner(
        code: CodeType,
        one_graph: bool,
        hooks: Hooks,
    ) -> tuple[ConvertFrameReturn, DynamoTracerOutput]:
        # NOTE: `restart_reasons` and `distributed_state` used below are
        # closure variables that are only assigned inside the `with` block
        # further down, before compile_inner() is called.
        nonlocal dynamo_time_before_restart
        last_attempt_start_time = start_time = time.time()

        def log_bytecode(
            prefix: str, name: str, filename: str, line_no: int, code: CodeType
        ) -> None:
            # Only format the bytecode when DEBUG logging is enabled.
            if bytecode_log.isEnabledFor(logging.DEBUG):
                bytecode_log.debug(
                    format_bytecode(prefix, name, filename, line_no, code)
                )

        log_bytecode(
            "ORIGINAL BYTECODE",
            code.co_name,
            code.co_filename,
            code.co_firstlineno,
            code,
        )

        out_code = None
        try:
            dynamo_output = compile_frame(
                code,
                globals,
                locals,
                builtins,
                closure,
                compiler_fn,
                one_graph,
                restart_reasons,
                export=export,
                export_constraints=export_constraints,
                frame_state=frame_state,
                distributed_state=distributed_state,
                package=package,
            )
        except exc.SkipFrame as e:
            # A skipped frame yields an empty result, but the tracer output
            # attached to the exception is still handed back.
            if one_graph:
                log.debug("No graph captured with export/fullgraph=True")

            assert e._torch_dynamo_tracer_output is not None
            return ConvertFrameReturn(), e._torch_dynamo_tracer_output

        assert distributed_state is None or distributed_state.all_states is not None, (  # type: ignore[has-type]
            "compiler collective wasn't run before compilation completed"
        )
        out_code = dynamo_output.bytecode
        tracer_output = dynamo_output.tracer_output
        # Set only if compile_frame restarted at least once.
        if dynamo_output.last_attempt_start_time is not None:
            last_attempt_start_time = dynamo_output.last_attempt_start_time

        assert out_code is not None
        log_bytecode(
            "MODIFIED BYTECODE",
            code.co_name,
            code.co_filename,
            code.co_firstlineno,
            out_code,
        )

        # Each registered hook may replace the generated code object; hooks
        # see the output of the previous hook.
        for idx, hook in enumerate(_bytecode_hooks.values()):
            with dynamo_timed(f"bytecode_hooks_{idx}", log_pt2_compile_event=True):
                hook_output = hook(code, out_code)
                if hook_output is not None:
                    out_code = hook_output

        orig_code_map[out_code] = code
        output_codes.add(out_code)
        # Time elapsed between the first attempt and the start of the last one.
        dynamo_time_before_restart = last_attempt_start_time - start_time
        assert tracer_output.output_graph is not None
        output = tracer_output.output_graph

        # Tests for new code objects.
        # The rationale for these tests can be found in torch/csrc/dynamo/eval_frame.c
        # Only test once the code object is created.
        # They are not tested during runtime.

        def count_args(code: CodeType) -> int:
            # Positional + keyword-only args, plus one each for *args / **kwargs.
            import inspect

            return (
                code.co_argcount
                + code.co_kwonlyargcount
                + bool(code.co_flags & inspect.CO_VARARGS)
                + bool(code.co_flags & inspect.CO_VARKEYWORDS)
            )

        assert out_code is not None

        total_argcount_old = count_args(code)
        total_argcount_new = count_args(out_code)
        msg = "arg mismatch: "
        msg += f"old code object has args {code.co_varnames[:total_argcount_old]}, "
        msg += f"new code object has args {out_code.co_varnames[:total_argcount_new]}"
        assert (
            code.co_varnames[:total_argcount_old]
            == out_code.co_varnames[:total_argcount_new]
        ), msg

        msg = "free var mismatch: "
        msg += f"old code object has free var {code.co_freevars}, "
        msg += f"new code object has free var {out_code.co_freevars}"
        assert code.co_freevars == out_code.co_freevars, msg

        msg = "cell var mismatch: "
        msg += f"old code object has cell var {code.co_cellvars}, "
        msg += f"new code object has cell var {out_code.co_cellvars}"
        assert code.co_cellvars == out_code.co_cellvars, msg

        # Skipping Dynamo on a frame without any extracted graph.
        # This does not affect eager functionality. But this is necessary
        # for export for cases where Dynamo-reconstructed bytecode can create
        # new function frames, confusing export in thinking that there
        # are extra graphs now.

        if output.export and output.is_empty_graph():
            return ConvertFrameReturn(), tracer_output

        assert output.guards is not None
        CleanupManager.instance[out_code] = output.cleanups
        nonlocal cache_entry
        with dynamo_timed("build_guards", log_pt2_compile_event=True):
            check_fn = dynamo_output.build_guards(
                code,
                hooks=hooks,
                save=package is not None,
                cache_entry=cache_entry,
            )

        if package is not None:
            # Record the guards/code pair and the traced sources in the package.
            assert check_fn.guards_state is not None
            package.add_guarded_code(check_fn.guards_state, out_code)
            package.add_inlined_source(output.tracing_context.traced_code)

        compile_id_str = str(compile_id) if compile_id is not None else "Unknown"
        annotation_str = "Torch-Compiled Region: " + compile_id_str
        guarded_code = GuardedCode(
            out_code,
            check_fn.guard_manager,  # type: ignore[arg-type]
            compile_id,
            annotation_str,
        )

        if not output.is_empty_graph() and hooks.guard_export_fn is not None:
            # We should not run the guard_export_fn when Dynamo does not
            # generate any graph. This can happen in export when TorchDynamo
            # generated bytecode has some reconstruction logic for mutated
            # variables which can trigger TorchDynamo on the children frames but
            # they are benign and do not generate any new graphs.
            hooks.guard_export_fn(output.guards)

        return wrap_guarded_code(guarded_code), tracer_output

    metrics_context = get_metrics_context()
    code_context = (
        package.code_context(code) if package is not None else contextlib.nullcontext()
    )
    with (
        _use_lazy_graph_module(config.use_lazy_graph_module),
        compile_context(CompileContext(compile_id)),
        async_compile_pool_manager(),
        chromium_event_timed(
            "dynamo", reset_event_log_on_exit=True, log_pt2_compile_event=True
        ),
        _WaitCounter("pytorch.wait_counter.entire_forward_compile").guard(),
        metrics_context,
        dynamo_timed(
            "_compile.compile_inner",
            phase_name="entire_frame_compile",
            dynamo_compile_column_us="dynamo_cumulative_compile_time_us",
        ),
        code_context,
    ):
        # Both names below are read by _compile_inner through its closure.
        restart_reasons: set[str] = set()
        if compile_pg := get_compile_pg():
            distributed_state = DistributedState(compile_pg, LocalState())
        else:
            distributed_state = None

        # Check recompilations
        recompile_reason: Optional[str] = None
        if is_recompilation(cache_size) and frame:
            reasons = get_and_maybe_log_recompilation_reasons(cache_entry, frame)
            recompile_reason = (
                "Unable to find recompilation reasons" if not reasons else reasons[0]
            )
        # Recheck for recompilation, for when inline_inbuilt_nn_modules is set to False
        inline_inbuilt_nn_modules_candidate = False
        if not config.inline_inbuilt_nn_modules and frame:
            inbuilt_nn_reasons = get_and_maybe_log_recompilation_reasons(
                cache_entry, frame, skip_logging=True
            )
            inbuilt_nn_recompile_reason = (
                None if not inbuilt_nn_reasons else inbuilt_nn_reasons[0]
            )
            if (
                inbuilt_nn_recompile_reason is not None
                and "[inline-inbuilt-nn-modules-candidate]"
                in inbuilt_nn_recompile_reason
            ):
                inline_inbuilt_nn_modules_candidate = True

        # Set if the recompile is a candidate for inline_inbuilt_nn_modules
        # regardless of whether inline_inbuilt_nn_modules is set or not
        metrics_context.update_outer(
            {
                "recompile_reason": recompile_reason,
                "inline_inbuilt_nn_modules_candidate": inline_inbuilt_nn_modules_candidate,
            }
        )

        recompile_user_contexts = get_hook_for_recompile_user_context()
        if recompile_user_contexts:
            # cap each user context to N chars for data retention purposes. N=256
            # is chosen to be large enough to capture the most important info.
            user_contexts_msg = {
                user_context()[:256] for user_context in recompile_user_contexts
            }
            metrics_context.set("recompile_user_contexts", user_contexts_msg)

        exceeded, limit_type = exceeds_recompile_limit(cache_size, compile_id)
        if exceeded:

            def format_func_info(code: CodeType) -> str:
                return f"'{code.co_name}' ({code.co_filename}:{code.co_firstlineno})"

            # NS: Don't add period at the end of string, as it'll be added to URL
            # rendering it incorrect
            log.warning(
                "torch._dynamo hit config.%s (%s)\n"
                " function: %s\n"
                " last reason: %s\n"
                'To log all recompilation reasons, use TORCH_LOGS="recompiles".\n'
                "To diagnose recompilation issues, see %s",
                limit_type,
                getattr(config, limit_type),
                format_func_info(code),
                recompile_reason,
                troubleshooting_url,
            )
            # How the limit violation is reported depends on configuration;
            # the branches are checked in this order.
            if config.fail_on_recompile_limit_hit:
                raise FailOnRecompileLimitHit(
                    f"{limit_type} reached, because fail_on_recompile_limit_hit = True this is a HARD failure"
                )
            elif one_graph:
                raise FailOnRecompileLimitHit(
                    f"{limit_type} reached with fullgraph=True. Excessive recompilations can degrade "
                    "performance due to the compilation overhead of each recompilation. To monitor "
                    "recompilations, enable TORCH_LOGS=recompiles. If recompilations are expected, consider "
                    "increasing torch._dynamo.config.cache_size_limit to an appropriate value."
                )
            elif justknobs_check(
                "pytorch/compiler:skip_code_recursive_on_recompile_limit_hit"
            ):
                raise RecompileLimitExceeded(f"{limit_type} reached")
            else:
                # do not recursively skip frames
                # NOTE(review): the concatenated explanation below has no
                # space between "cache size limit." and "Giving up" - left
                # unchanged here in case the message text is mirrored
                # elsewhere; confirm and fix together.
                unimplemented_v2(
                    gb_type="Dynamo cache limit exceeded",
                    context=f"Limit type: {limit_type}",
                    explanation="Dynamo attempted to recompile the code object too many times, "
                    f"exceeding the {limit_type} cache size limit."
                    "Giving up on compiling as the compile time tradeoff is likely not "
                    "worth the performance gain.",
                    hints=[],
                )

        log.debug(
            "torchdynamo start compiling %s %s:%s, stack (elided %s frames):\n%s",
            code.co_name,
            code.co_filename,
            code.co_firstlineno,
            skip + 2,
            # -2: omit current frame, omit contextlib decorator
            "".join(CapturedTraceback.extract(skip=2 + skip).format()),
        )
        # -4: -2 as above, plus trace_structured frames
        #
        # NB: the frame looks like this:
        #
        # # handled by skip argument
        # torch/_dynamo/convert_frame.py:1069 in catch_errors
        # torch/_dynamo/convert_frame.py:910 in _convert_frame
        # torch/_dynamo/convert_frame.py:464 in _convert_frame_assert
        # torch/_utils_internal.py:70 in wrapper_function
        #
        # # 2 current frame and context lib
        # env/lib/python3.10/contextlib.py:79 in inner
        # torch/_dynamo/convert_frame.py:776 in _compile
        #
        # # 2 extra here
        # torch/_logging/_internal.py:1064 in trace_structured
        # torch/_dynamo/convert_frame.py:780 in <lambda>
        stack_trace = log_dynamo_start(code, skip)
        start_time_ns = time.time_ns()
        # Failure details; filled in by the `except` block and reported in
        # the metrics written by the `finally` block.
        fail_type: Optional[str] = None
        fail_reason: Optional[str] = None
        exception_stack_trace: Optional[list[str]] = None
        fail_user_frame_filename: Optional[str] = None
        fail_user_frame_lineno: Optional[int] = None
        torch._dynamo.utils.ReinplaceCounters.clear()
        guarded_code = None
        tracer_output = None
        try:
            guarded_code, tracer_output = compile_inner(code, one_graph, hooks)

            # NB: We only put_code_state in success case. Success case here
            # does include graph breaks; specifically, if a graph break still
            # resulted in a partially compiled graph, we WILL return here. An
            # Unsupported exception will only bubble to the top level if we
            # are unable to compile the frame at all. In this case, there's
            # no point in uploading the code state, because we will always
            # fail exactly the same way even without the update. (It's useful
            # to upload for graph break though, because this can prevent
            # extra graph break compilations.)
            put_code_state()
            if (
                tracer_output
                and (output_graph := tracer_output.output_graph)
                and output_graph.has_outputs()
            ):
                log_frame_dynamic_whitelist(code)

            return guarded_code
        except Exception as e:
            # NB: e's msg is mutated here to add user stack, but we DON'T want
            # that stack in the Scuba logged fail_reason. So we grab the fail
            # info here and add it to the metrics context below.
            fail_type = type(e).__qualname__
            fail_reason = str(e)
            exception_stack_trace = [traceback.format_exc()]

            exception_handler(e, code, frame, export=export)
            # NB: this is the post-mutation exception
            torch._logging.trace_structured(
                "artifact",
                metadata_fn=lambda: {
                    "name": "dynamo_error",
                    "encoding": "string",
                },
                payload_fn=lambda: traceback.format_exc(),
            )
            fail_user_frame_filename, fail_user_frame_lineno = exc.get_exc_message(
                e, compile_id
            )
            # Recover the tracer output attached by trace_frame, if any, so
            # the `finally` block can still clean it up.
            tracer_output = getattr(e, "_torch_dynamo_tracer_output", None)
            # Known exception types propagate unchanged; everything else is
            # rewrapped as InternalTorchDynamoError.
            if isinstance(
                e,
                (
                    Unsupported,
                    TorchRuntimeError,
                    BackendCompilerFailed,
                    AssertionError,
                    ConstraintViolationError,
                    GuardOnDataDependentSymNode,
                    ValidationException,
                    UncapturedHigherOrderOpError,
                    BisectValidationException,
                    ShortenTraceback,
                    PackageError,
                    ResumePrologueTracingError,
                ),
            ):
                raise
            else:
                # Rewrap for clarity
                raise InternalTorchDynamoError(
                    f"{type(e).__qualname__}: {str(e)}"
                ).with_traceback(e.__traceback__) from None
        finally:
            # === WARNING WARNING WARNING ===
            # If you commit a bug here, it will suppress writing to
            # dynamo_compile table, and we will not have telemetry.
            # Be extra careful when making changes here!

            if torch._dynamo.config.run_gc_after_compile:
                with dynamo_timed("gc", dynamo_compile_column_us="gc_time_us"):
                    log.info("run_gc_after_compile: running gc")
                    gc.collect(1)

            output = None
            if tracer_output:
                output = tracer_output.output_graph
                if output:
                    # Reset the references to the frame's locals held by the
                    # output graph.
                    output.local_scope = {}
                    # tracer should already be None, keep an extra check here just in case.
                    if tracer := output.root_tx:
                        tracer.f_locals = {}

            from .utils import curr_frame

            frame_key = str(curr_frame)
            if fail_reason is None and output is not None:
                # Success with an output graph: collect graph statistics.
                guard_count = len(output.guards)
                shape_env_guard_count = len(output.shape_env.guards)
                graph_op_count = output.count_calls()
                graph_node_count = len(output.graph.nodes)
                graph_node_shapes = output.get_graph_sizes_structured()
                graph_input_count = len(output.placeholders)
                non_compliant_ops = {op.__qualname__ for op in output.non_compliant_ops}
                compliant_custom_ops = {
                    op.__qualname__ for op in output.compliant_custom_ops
                }
                torch._dynamo.utils.ReinplaceCounters.log()
            else:
                # Failure (or no output graph): report empty statistics.
                guard_count = None
                shape_env_guard_count = None
                graph_op_count = None
                graph_node_count = None
                graph_node_shapes = {}
                graph_input_count = None
                non_compliant_ops = set({})
                compliant_custom_ops = set({})
                restart_reasons = set()
                # If compilation failed, the entire time is wasted
                dynamo_time_before_restart = (time.time_ns() - start_time_ns) / 1e9

            metrics = {
                "frame_key": frame_key,
                "co_name": code.co_name,
                "co_filename": code.co_filename,
                "co_firstlineno": code.co_firstlineno,
                "cache_size": cache_size.num_cache_entries_with_same_id_matched_objs,
                "accumulated_cache_size": cache_size.num_cache_entries,
                "guard_count": guard_count,
                "shape_env_guard_count": shape_env_guard_count,
                "graph_op_count": graph_op_count,
                "graph_node_count": graph_node_count,
                "graph_input_count": graph_input_count,
                "fail_type": fail_type,
                "fail_reason": fail_reason,
                "fail_user_frame_filename": fail_user_frame_filename,
                "fail_user_frame_lineno": fail_user_frame_lineno,
                "non_compliant_ops": non_compliant_ops,
                "compliant_custom_ops": compliant_custom_ops,
                "restart_reasons": restart_reasons,
                "dynamo_time_before_restart_s": dynamo_time_before_restart,
                "has_guarded_code": guarded_code is not None,
                "specialize_float": config.specialize_float,
                "is_forward": True,
                "dynamo_compile_time_before_restart_us": to_int_us(
                    dynamo_time_before_restart
                ),
                "stack_trace": stack_trace,
                "graph_node_shapes": str(graph_node_shapes),
                "exception_stack_trace": exception_stack_trace,
            }
            # TODO: replace with CompileEventLogger.compilation_metrics
            # There are some columns here not in PT2 Compile Events
            # so we need to slightly change it
            metrics_context.update_outer(metrics)
            # === END WARNING WARNING WARNING ===

            # If tracer is available, then tracer.error_on_graph_break reflects value of
            # global symbolic_convert.error_on_graph_break at the time of the graph break -
            # symbolic_convert.error_on_graph_break may have been (correctly) changed during cleanup.
            # If tracer is unavailable, then fallback to symbolic_convert.error_on_graph_break.
            if convert_frame_box:
                convert_frame_box.error_on_graph_break = (
                    tracer_output.error_on_graph_break
                    if tracer_output
                    else _get_error_on_graph_break()
                )
  1428. class ConvertFrame:
  1429. def __init__(
  1430. self,
  1431. compiler_fn: CompilerFn,
  1432. hooks: Hooks,
  1433. package: Optional[CompilePackage] = None,
  1434. ) -> None:
  1435. self._torchdynamo_orig_backend = compiler_fn
  1436. self._inner_convert = convert_frame_assert(
  1437. compiler_fn, one_graph=False, package=package
  1438. )
  1439. self._hooks = hooks
  1440. @property
  1441. def _clone_with_backend(self) -> Callable[[WrapBackendDebug], ConvertFrame]:
  1442. return lambda backend: convert_frame(
  1443. backend,
  1444. self._hooks,
  1445. )
  1446. def __call__(
  1447. self,
  1448. frame: DynamoFrameType,
  1449. cache_entry: Optional[CacheEntry],
  1450. hooks: Hooks,
  1451. frame_state: dict[str, Union[int, FrameStateSizeEntry]],
  1452. skip: int = 0,
  1453. ) -> ConvertFrameReturn:
  1454. input_codes.add(frame.f_code)
  1455. counters["frames"]["total"] += 1
  1456. try:
  1457. result = self._inner_convert(
  1458. frame, cache_entry, hooks, frame_state, skip=skip + 1
  1459. )
  1460. counters["frames"]["ok"] += 1
  1461. return result
  1462. except Exception as e:
  1463. # Do not allow errors to be suppressed if we're tracing a resume function prologue
  1464. if isinstance(e, ResumePrologueTracingError):
  1465. raise
  1466. error_on_graph_break = (
  1467. self._inner_convert._box.error_on_graph_break is not None
  1468. )
  1469. assert error_on_graph_break is not None
  1470. if self._inner_convert._box.error_on_graph_break:
  1471. # NOTE we _might_ have to wrap the current in a custom exception
  1472. # in order to correctly bubble up to the top-level compile wrapper in
  1473. # eval_frame.py. But re-raising seems to work for now because exceptions from tracing
  1474. # a nested call that results in a top-level frame compile will be handled by the caller
  1475. # as an observed exception - we don't expect that exception to be suppressed.
  1476. raise
  1477. # These two exception types are "soft" failure, in the sense that
  1478. # we know this is due to something we didn't implement all the
  1479. # way, scare the user less about it. That being said, if you
  1480. # are trying to understand why a graph break happened, it's still
  1481. # important to have this information, so offer it.
  1482. #
  1483. # NB: NotImplementedError used to be on this list, but actually
  1484. # it is impossible for it to reach here, as it is converted into
  1485. # InternalTorchDynamoError. This behavior seemed reasonable
  1486. # to me (ezyang, Aug 2023) so I kept it, but maybe at some point
  1487. # someone wanted these to also get suppressed. If so, you'll
  1488. # need to make these exceptions not get wrapped
  1489. # We intentionally don't want to suppress error here.
  1490. if isinstance(e, UncapturedHigherOrderOpError):
  1491. raise
  1492. soft_fail = isinstance(e, Unsupported)
  1493. # This is a soft failure. In the sense, the code path reaches here
  1494. # when we do not support graph breaks on bytecodes like LOAD_ATTR,
  1495. # BUILD_SET etc. In such case, we can fallback to eager without
  1496. # scaring users.
  1497. if soft_fail and graph_break_log.isEnabledFor(logging.DEBUG):
  1498. # Log this message in the graph break. Also use the string
  1499. # "skip: " to tell that the whole frame is falling back to
  1500. # eager.
  1501. if hasattr(e, "compile_id") and hasattr(e, "real_stack"):
  1502. with compile_context(CompileContext(e.compile_id)): # type: ignore[attr-defined]
  1503. user_stack = e.real_stack
  1504. user_stack_formatted = "".join(
  1505. traceback.format_list(user_stack)
  1506. )
  1507. user_stack_trace = f"Graph break: skip: from user code at:\n{user_stack_formatted}"
  1508. torch._logging.trace_structured(
  1509. "artifact",
  1510. metadata_fn=lambda: {
  1511. "name": "dynamo_graph_break_reason",
  1512. "encoding": "string",
  1513. },
  1514. payload_fn=lambda: f"{user_stack_trace}\n{traceback.format_exc()}",
  1515. )
  1516. graph_break_log.debug(
  1517. user_stack_trace,
  1518. exc_info=True,
  1519. )
  1520. if not config.suppress_errors and not soft_fail:
  1521. raise
  1522. # Suppress the error. NB: It's very important to do the
  1523. # suppression logging HERE, where the actual suppression
  1524. # happens. Previously it was somewhere else and so it was
  1525. # possible to accidentally not log at all.
  1526. record_filename = getattr(e, "record_filename", None)
  1527. code = frame.f_code
  1528. error_msg = format_error_msg(e, code, record_filename, frame)
  1529. if soft_fail:
  1530. log.info(error_msg, exc_info=True)
  1531. else:
  1532. log.warning(error_msg, exc_info=True)
  1533. if isinstance(e, SkipCodeRecursiveException):
  1534. return ConvertFrameReturn(
  1535. frame_exec_strategy=FrameExecStrategy(
  1536. FrameAction.SKIP, FrameAction.SKIP
  1537. )
  1538. )
  1539. elif isinstance(e, RecompileLimitExceeded):
  1540. return ConvertFrameReturn(
  1541. frame_exec_strategy=FrameExecStrategy(
  1542. FrameAction.RUN_ONLY, FrameAction.RUN_ONLY
  1543. )
  1544. )
  1545. return ConvertFrameReturn()
  1546. def convert_frame(
  1547. compiler_fn: CompilerFn,
  1548. hooks: Hooks,
  1549. package: Optional[CompilePackage] = None,
  1550. ) -> ConvertFrame:
  1551. """Try to convert a frame into an FX graph, if error leave frame unmodified"""
  1552. return ConvertFrame(compiler_fn, hooks, package=package)
  1553. # TODO mlazos: add support for same args, or record them
  1554. def replay(filename: str) -> None:
  1555. from .backends.debugging import eager
  1556. original_replay_val = config.replay_record_enabled
  1557. config.replay_record_enabled = False
  1558. with open(filename, "rb") as in_file:
  1559. record = ExecutionRecord.load(in_file)
  1560. record.globals = dict(itertools.chain(record.globals.items(), globals().items()))
  1561. with decorators.error_on_graph_break(False):
  1562. try:
  1563. _compile(
  1564. record.code,
  1565. record.globals,
  1566. record.locals,
  1567. record.builtins,
  1568. record.closure,
  1569. compiler_fn=eager,
  1570. one_graph=False,
  1571. export=False,
  1572. export_constraints=None,
  1573. hooks=Hooks(),
  1574. cache_size=CacheSizeRelevantForFrame(0, 0),
  1575. cache_entry=None,
  1576. frame=None,
  1577. frame_state={},
  1578. compile_id=CompileId(frame_id=42, frame_compile_id=999),
  1579. )
  1580. finally:
  1581. config.replay_record_enabled = original_replay_val
  1582. def first_real_inst_idx(code: CodeType) -> int:
  1583. if sys.version_info < (3, 11):
  1584. return 0
  1585. for inst in dis.get_instructions(code):
  1586. if inst.opname == "RESUME":
  1587. return inst.offset // 2
  1588. raise RuntimeError("RESUME instruction not found in code")
  1589. class ConvertFrameProtocol(typing.Protocol):
  1590. def __call__(
  1591. self,
  1592. frame: DynamoFrameType,
  1593. cache_entry: Optional[CacheEntry],
  1594. hooks: Hooks,
  1595. frame_state: dict[str, Union[int, FrameStateSizeEntry]],
  1596. *,
  1597. skip: int = 0,
  1598. ) -> ConvertFrameReturn: ...
  1599. def should_skip_due_to_torch_dispatch_mode() -> bool:
  1600. return is_in_any_mode_without_ignore_compile_internals()
  1601. class CatchErrorsWrapper:
  1602. def __init__(self, callback: ConvertFrameProtocol, hooks: Hooks) -> None:
  1603. functools.wraps(callback)(self)
  1604. self._torchdynamo_orig_backend = callback
  1605. self.hooks = hooks
  1606. def __call__(
  1607. self,
  1608. frame: DynamoFrameType,
  1609. cache_entry: Optional[CacheEntry],
  1610. frame_state: dict[str, Union[int, FrameStateSizeEntry]],
  1611. ) -> ConvertFrameReturn:
  1612. assert frame_state is not None
  1613. input_codes.add(frame.f_code)
  1614. is_skipfile = trace_rules.check(frame.f_code)
  1615. if sys.version_info >= (3, 13):
  1616. has_started_execution = frame.f_lasti > first_real_inst_idx(frame.f_code)
  1617. else:
  1618. has_started_execution = frame.f_lasti >= first_real_inst_idx(frame.f_code)
  1619. if (
  1620. # TODO: the first condition is not covered by any test
  1621. has_started_execution
  1622. or is_skipfile
  1623. or config.disable
  1624. or (
  1625. should_skip_due_to_torch_dispatch_mode()
  1626. and not getattr(self._torchdynamo_orig_backend, "_export", False)
  1627. )
  1628. ):
  1629. if log.isEnabledFor(logging.DEBUG):
  1630. if has_started_execution:
  1631. skip_reason = "traced frame already"
  1632. elif trace_rules.check(frame.f_code):
  1633. skip_reason = "in skipfiles"
  1634. elif is_in_torch_dispatch_mode(include_infra_modes=False):
  1635. skip_reason = "non-infra torch dispatch mode present, this is not supported today in torch.compile"
  1636. else:
  1637. skip_reason = "dynamo tracing is disabled"
  1638. log.debug(
  1639. "skipping: %s (reason: %s, file: %s)",
  1640. frame.f_code.co_name,
  1641. skip_reason,
  1642. frame.f_code.co_filename,
  1643. )
  1644. return ConvertFrameReturn()
  1645. if (
  1646. frame.f_code.co_filename == "<string>" and frame.f_code.co_name == "__new__"
  1647. ) or (
  1648. frame.f_code.co_filename.endswith("collections/__init__.py")
  1649. and frame.f_code.co_name == "_make"
  1650. ):
  1651. # nametuple constructor/_make
  1652. return ConvertFrameReturn()
  1653. if torch._dynamo.utils.get_optimize_ddp_mode() == "ddp_optimizer":
  1654. ddp_module = DistributedDataParallel._get_active_ddp_module()
  1655. if ddp_module:
  1656. with compile_lock:
  1657. from torch._dynamo.backends.distributed import DDPOptimizer
  1658. ddp_optimizer = DDPOptimizer(
  1659. bucket_bytes_cap=ddp_module.bucket_bytes_cap,
  1660. backend_compile_fn=self._torchdynamo_orig_backend._torchdynamo_orig_backend, # type: ignore[attr-defined]
  1661. )
  1662. assert hasattr(
  1663. self._torchdynamo_orig_backend, "_clone_with_backend"
  1664. ), (
  1665. "DDPOptimizer only supports callback fns that know how to clone themselves."
  1666. )
  1667. hijacked_callback = (
  1668. self._torchdynamo_orig_backend._clone_with_backend(
  1669. ddp_optimizer.compile_fn,
  1670. )
  1671. )
  1672. return hijacked_callback(
  1673. frame, cache_entry, self.hooks, frame_state
  1674. )
  1675. with compile_lock, _disable_current_modes():
  1676. # skip=1: skip this frame
  1677. result = self._torchdynamo_orig_backend(
  1678. frame, cache_entry, self.hooks, frame_state, skip=1
  1679. )
  1680. return result
  1681. def catch_errors_wrapper(
  1682. callback: ConvertFrameProtocol, hooks: Hooks
  1683. ) -> CatchErrorsWrapper:
  1684. return CatchErrorsWrapper(callback, hooks)