_guards.py 41 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156115711581159116011611162116311641165116611671168116911701171117211731174117511761177117811791180
  1. from __future__ import annotations
  2. import contextlib
  3. import dataclasses
  4. import enum
  5. import functools
  6. import logging
  7. import re
  8. import threading
  9. import traceback
  10. import unittest.mock
  11. import weakref
  12. from abc import abstractmethod
  13. from collections import defaultdict
  14. from contextlib import contextmanager
  15. from dataclasses import dataclass
  16. from typing import (
  17. Any,
  18. Callable,
  19. Generic,
  20. NamedTuple,
  21. Optional,
  22. TYPE_CHECKING,
  23. TypeVar,
  24. Union,
  25. )
  26. import torch
  27. from torch.utils import _pytree as pytree
  28. from torch.utils._backport_slots import dataclass_slots
  29. from torch.utils._traceback import CapturedTraceback, format_frame
  30. from torch.utils.weak import WeakTensorKeyDictionary
  31. log = logging.getLogger(__name__)
  32. if TYPE_CHECKING:
  33. from collections.abc import Generator, Iterator
  34. from types import CodeType
  35. import sympy
  36. from torch._dynamo.backends.distributed import DDPOptimizerContext
  37. from torch._dynamo.codegen import PyCodegen
  38. from torch._functorch._aot_autograd.schemas import ViewAndMutationMeta
  39. from torch._subclasses.fake_tensor import FakeTensorMode
  40. """
  41. torch._guards is the definitional source of truth for general purpose guard structures.
  42. An important thing to keep in mind here is the preservation of layering. There should be no dynamo notions,
  43. and no guard installation notions here.
  44. """
# Matches the plain CompileId string form "<frame_id>/<frame_compile_id>",
# e.g. "3/1". Both parts are captured as named groups.
COMPILE_ID_PATTERN = re.compile(r"^(?P<frame_id>\d+)/(?P<frame_compile_id>\d+)$")
# Matches the compiled-autograd form "!<compiled_autograd_id>", optionally
# followed by "/<frame_id>/<frame_compile_id>", e.g. "!0" or "!0/3/1".
CA_COMPILE_ID_PATTERN = re.compile(
    r"^!(?P<compiled_autograd_id>\d+)(?:/(?P<frame_id>\d+)/(?P<frame_compile_id>\d+))?$"
)
# [Note: Updating CompileId]
#
# CompileId represents a unique program-level identifier, and we want to keep that
# property as the codebase evolves. This property is relied on even outside of the pytorch
# repo, e.g. tlparse or other internal tooling. The in-memory format can be freely changed,
# as those dependencies only consume the string serialization.
#
# The string form should be:
# 1. Program-level uid: CompileId can uniquely identify a compiled graph.
# 2. Storage efficient: This object is logged in nearly every entry. We should elide symbols when possible.
# 3. Compact: The string form is directly displayed by some tools. Special symbols are okay.
# TODO: mark as kw_only=True once we drop support for <Python 3.10
  61. @dataclass(frozen=True)
  62. class CompileId:
  63. frame_id: Optional[int]
  64. # This id is per-frame, and counts how many times we've compiled this
  65. # frame. This could have been a global id but having this be per-frame
  66. # gives you a better intuitive sense for how many recompiles have occurred
  67. # so far.
  68. frame_compile_id: Optional[int]
  69. # torch.compiling a compiled autograd graph
  70. compiled_autograd_id: Optional[int] = None
  71. # TODO: consider also tracking the recompilation count
  72. # See Note: Updating CompileId
  73. def __str__(self) -> str:
  74. # NOTE: Keep this in sync with both from_string and the tlparse repo
  75. if self.compiled_autograd_id is not None:
  76. assert (self.frame_id is None) == (self.frame_compile_id is None)
  77. frame_str = ""
  78. if self.frame_id is not None:
  79. frame_str = f"/{self.frame_id}/{self.frame_compile_id}"
  80. return f"!{self.compiled_autograd_id}{frame_str}"
  81. else:
  82. assert self.frame_id is not None and self.frame_compile_id is not None
  83. return f"{self.frame_id}/{self.frame_compile_id}"
  84. @classmethod
  85. def from_string(cls, compile_id: Optional[str]) -> Optional[CompileId]:
  86. """
  87. Factory method that creates a CompileId from its string representation.
  88. Keep this in sync with the __str__ method.
  89. """
  90. if compile_id is None:
  91. return None
  92. try:
  93. for pattern in (COMPILE_ID_PATTERN, CA_COMPILE_ID_PATTERN):
  94. if match := pattern.match(compile_id):
  95. groups = match.groupdict()
  96. for k, v in groups.items():
  97. if v is not None:
  98. groups[k] = int(v)
  99. return cls(**groups) # type: ignore[arg-type]
  100. else:
  101. raise ValueError
  102. except Exception as e:
  103. raise ValueError(f"Invalid compile_id '{compile_id}'") from e
  104. class TraceId(NamedTuple):
  105. compile_id: CompileId
  106. # This starts off as 0, and every time we restart analysis it goes
  107. # up by one
  108. attempt: int
  109. def __str__(self) -> str:
  110. # Keep this in sync with tlparse repo
  111. if self.attempt == 0:
  112. return str(self.compile_id)
  113. else:
  114. return f"{self.compile_id}_{self.attempt}"
  115. class GuardSource(enum.Enum):
  116. LOCAL = 0
  117. GLOBAL = 1
  118. LOCAL_SPECIALIZED_NN_MODULE = 2
  119. GLOBAL_SPECIALIZED_NN_MODULE = 3
  120. CONSTANT = 4
  121. RANDOM_VALUE = 5
  122. SHAPE_ENV = 6
  123. LOCAL_FSDP_MODULE = 7
  124. GLOBAL_FSDP_MODULE = 8
  125. BACKWARD_STATE = 9
  126. EPHEMERAL = 10
  127. SYNTHETIC_LOCAL = 11
  128. LOCAL_UNSPECIALIZED_NN_MODULE = 12
  129. GLOBAL_UNSPECIALIZED_NN_MODULE = 13
  130. LOCAL_UNSPECIALIZED_BUILTIN_NN_MODULE = 14
  131. GLOBAL_UNSPECIALIZED_BUILTIN_NN_MODULE = 15
  132. def is_fsdp_module(self) -> bool:
  133. return self in (GuardSource.GLOBAL_FSDP_MODULE, GuardSource.LOCAL_FSDP_MODULE)
  134. def is_specialized_nn_module(self) -> bool:
  135. import torch._dynamo.config as config
  136. if config._unsafe_skip_fsdp_module_guards:
  137. return (
  138. self
  139. in (
  140. GuardSource.GLOBAL_SPECIALIZED_NN_MODULE,
  141. GuardSource.LOCAL_SPECIALIZED_NN_MODULE,
  142. )
  143. or self.is_fsdp_module()
  144. )
  145. return self in (
  146. GuardSource.GLOBAL_SPECIALIZED_NN_MODULE,
  147. GuardSource.LOCAL_SPECIALIZED_NN_MODULE,
  148. )
  149. def is_unspecialized_nn_module(self) -> bool:
  150. return self in (
  151. GuardSource.GLOBAL_UNSPECIALIZED_NN_MODULE,
  152. GuardSource.LOCAL_UNSPECIALIZED_NN_MODULE,
  153. GuardSource.GLOBAL_UNSPECIALIZED_BUILTIN_NN_MODULE,
  154. GuardSource.LOCAL_UNSPECIALIZED_BUILTIN_NN_MODULE,
  155. )
  156. def is_unspecialized_builtin_nn_module(self) -> bool:
  157. return self in (
  158. GuardSource.GLOBAL_UNSPECIALIZED_BUILTIN_NN_MODULE,
  159. GuardSource.LOCAL_UNSPECIALIZED_BUILTIN_NN_MODULE,
  160. )
  161. def is_local(self) -> bool:
  162. return self in (
  163. GuardSource.LOCAL,
  164. GuardSource.LOCAL_SPECIALIZED_NN_MODULE,
  165. GuardSource.LOCAL_FSDP_MODULE,
  166. GuardSource.LOCAL_UNSPECIALIZED_NN_MODULE,
  167. GuardSource.LOCAL_UNSPECIALIZED_BUILTIN_NN_MODULE,
  168. )
"""
Base class for a "GuardBuilder" role.
The GuardBuilderBase role is to represent a scope within which to build a guard. The name is a little
confusing, as it's not a builder, but it is kept for the sake of avoiding a lot of renames and keeping
the original reference to torchdynamo's GuardBuilder.
Note: create_fn is invoked with a GuardBuilderBase and a Guard. A GuardBuilder is chosen based
on GuardSource's select function.
There is value in keeping this GuardBuilderBase empty to keep layering clean.
"""
# Intentionally empty base class: it only provides a common type for guard
# builders (it is the first argument type in Guard.create_fn's annotation).
class GuardBuilderBase:
    pass
  180. @dataclasses.dataclass(frozen=True)
  181. class SLoc:
  182. framework_loc: Optional[Union[traceback.FrameSummary, str]]
  183. maybe_user_loc: Optional[str]
  184. def __str__(self) -> str:
  185. floc = (
  186. self.framework_loc
  187. if isinstance(self.framework_loc, str)
  188. else format_frame(self.framework_loc)
  189. )
  190. if self.maybe_user_loc is not None:
  191. return f"{self.maybe_user_loc} ({floc})"
  192. else:
  193. return f"({floc})"
class ShapeGuard(NamedTuple):
    # Boolean sympy expression carried by this guard.
    expr: sympy.logic.boolalg.Boolean
    # Location record (see SLoc) associated with the guard.
    sloc: SLoc
    # Flag whose meaning is defined by the code that creates ShapeGuards;
    # that code is not in this file. NOTE(review): confirm the semantics there.
    size_oblivious: bool
  198. @dataclass_slots
  199. @dataclasses.dataclass
  200. class Guard:
  201. # originating_source is the source that called the make_guard method to
  202. # construct this guard object. The property name specifies what exactly it
  203. # is the guard is guarding on. The meaning of the name is dependent on the
  204. # create_fn; you must look at the use-site inside create_fn to know what
  205. # name means.
  206. #
  207. # That being said, although you might think this is just a "name", name is
  208. # usually an arbitrary Python expression that will be evaluated with all
  209. # globals (and locals, if you create a LOCAL guard) to extract the Python
  210. # object that we want to perform guard tests on. This evaluation
  211. # typically happens in GuardBuilder.eval. In these cases, name is
  212. # typically produced by originating_source.name() (not to be confused with
  213. # GuardSource - the property source).
  214. #
  215. # Occasionally, name is not a valid Python expression; sometimes
  216. # it is meaningless. Example create_fns that are like this include
  217. # GRAD_MODE and SHAPE_ENV.
  218. originating_source: Source
  219. create_fn: Callable[[GuardBuilderBase, Guard], None]
  220. # Export only. These values are written to at time of guard check_fn creation.
  221. guard_types: Optional[list[str]] = None
  222. code_list: Optional[list[str]] = None
  223. obj_weakref: Optional[object] = None
  224. guarded_class_weakref: Optional[weakref.ReferenceType[Any]] = None
  225. stack: Optional[CapturedTraceback] = None
  226. user_stack: Optional[traceback.StackSummary] = None
  227. _hash: Optional[int] = None
  228. _unserializable: bool = False
  229. def __hash__(self) -> int:
  230. if self._hash is None:
  231. self._hash = hash((self.name, self.source, id(self.create_fn)))
  232. return self._hash
  233. def sort_key(self) -> tuple[bool, int, int, str, int]:
  234. # Put the duplicate input guards at the end. The duplicate guards have
  235. # two sources while guard.name only considers one source.
  236. is_duplicate_input = (
  237. isinstance(self.create_fn, functools.partial)
  238. and self.create_fn.func is torch._dynamo.guards.GuardBuilder.DUPLICATE_INPUT
  239. )
  240. return (
  241. is_duplicate_input,
  242. self.source.value if self.source else -1,
  243. len(self.name),
  244. self.name,
  245. self.inner_create_fn().__code__.co_firstlineno,
  246. )
  247. def __lt__(self, other: Guard) -> bool:
  248. return self.sort_key() < other.sort_key()
  249. def inner_create_fn(self) -> Callable[[GuardBuilderBase, Guard], Any]:
  250. if isinstance(self.create_fn, functools.partial):
  251. return self.create_fn.func
  252. else:
  253. return self.create_fn
  254. @property
  255. def name(self) -> str:
  256. return self.originating_source.name()
  257. @property
  258. def source(self) -> GuardSource:
  259. return self.originating_source.guard_source()
  260. @staticmethod
  261. def weakref_to_str(obj_weakref: object) -> str:
  262. """
  263. This is a workaround of a Python weakref bug.
  264. `obj_weakref` is instance returned by `weakref.ref`,
  265. `str(obj_weakref)` is buggy if the original obj overrides __getattr__, e.g:
  266. class MyConfig(dict):
  267. def __getattr__(self, x):
  268. return self[x]
  269. obj = MyConfig(offset=5)
  270. obj_weakref = weakref.ref(obj)
  271. str(obj_weakref) # raise error: KeyError: '__name__'
  272. """
  273. if isinstance(obj_weakref, weakref.ReferenceType):
  274. obj = obj_weakref()
  275. if obj is not None:
  276. return f"<weakref at {hex(id(obj_weakref))}; to '{obj.__class__.__name__}' at {hex(id(obj))}>"
  277. else:
  278. return f"<weakref at {hex(id(obj_weakref))}; dead>"
  279. else:
  280. return str(obj_weakref)
  281. def __repr__(self) -> str:
  282. s = f"""
  283. {self.source.name.lower() if self.source else ""} {repr(self.name)} {self.inner_create_fn().__name__}
  284. {{
  285. 'guard_types': {self.guard_types},
  286. 'code': {self.code_list},
  287. 'obj_weakref': {self.weakref_to_str(self.obj_weakref)}
  288. 'guarded_class': {self.guarded_class_weakref}
  289. }}
  290. """
  291. return s
  292. def __str__(self) -> str:
  293. output = f"Name: {repr(self.name)}\n"
  294. source = self.source.name.lower() if self.source else ""
  295. output += f" Source: {source}\n"
  296. output += f" Create Function: {self.inner_create_fn().__name__}\n"
  297. output += f" Guard Types: {self.guard_types}\n"
  298. output += f" Code List: {self.code_list}\n"
  299. output += f" Object Weakref: {self.weakref_to_str(self.obj_weakref)}\n"
  300. output += f" Guarded Class Weakref: {self.guarded_class_weakref}\n"
  301. return output
  302. def create(self, builder: GuardBuilderBase) -> Any:
  303. try:
  304. return self.create_fn(builder, self)
  305. except Exception:
  306. log.exception("Error while creating guard:\n%s", str(self).rstrip())
  307. if self.stack:
  308. log.error("Created at:\n%s", "".join(self.stack.format()[-4:]).rstrip())
  309. raise
  310. def is_specialized_nn_module(self) -> bool:
  311. return self.source.is_specialized_nn_module()
  312. def is_fsdp_module(self) -> bool:
  313. return self.source.is_fsdp_module()
  314. def is_local(self) -> bool:
  315. return self.source.is_local()
  316. def create_fn_name(self) -> str:
  317. if isinstance(self.create_fn, functools.partial):
  318. create_fn = self.create_fn.func # type: ignore[attr-defined]
  319. else:
  320. create_fn = self.create_fn
  321. return create_fn.__name__
  322. def set_export_info(
  323. self,
  324. guard_type: str,
  325. guarded_class: Optional[weakref.ReferenceType[Any]],
  326. code_list: list[str],
  327. obj_weakref: object,
  328. ) -> None:
  329. if not self.guard_types:
  330. self.guard_types = []
  331. self.guard_types.append(guard_type)
  332. assert self.guarded_class_weakref in (
  333. guarded_class,
  334. None,
  335. ), "Guarded class id must be identical, or None"
  336. self.guarded_class_weakref = guarded_class
  337. if not self.code_list:
  338. self.code_list = code_list
  339. else:
  340. self.code_list.extend(code_list)
  341. # Some objects are ephemeral, e.g., list[slice(1, 2)]. If we have
  342. # multiple guards on the same object, the weakref can die between the
  343. # invocation of set_export_info calls. So a dead weakref is also
  344. # acceptable.
  345. assert (
  346. self.obj_weakref in (obj_weakref, None)
  347. or callable(self.obj_weakref)
  348. and self.obj_weakref() is None
  349. ), "Guarded object must be identical, None or ephemeral (dead weakref)"
  350. self.obj_weakref = obj_weakref
# Type parameter for Checkpointable: the type of the snapshot object.
T = TypeVar("T")
  352. """
  353. Parent structure for guard env expressions.
  354. A GuardEnvExpr can have any subtype.
  355. Note: All subtypes must be handled exhaustively in
  356. torch._dynamo.guards._parse_guard_env_guards to avoid a RuntimeError.
  357. """
# Empty frozen dataclass that serves as the common base type for guard env
# expressions; it carries no fields of its own.
@dataclasses.dataclass(frozen=True)
class GuardEnvExpr:
    pass
"""
A class representing a pair of duplicate inputs.
input_source_a and input_source_b are the sources of the two inputs we have deduped.
"""
  365. @dataclasses.dataclass(frozen=True)
  366. class DuplicateInputs(GuardEnvExpr):
  367. input_source_a: Source
  368. input_source_b: Source
  369. def __post_init__(self) -> None:
  370. assert self.input_source_a != self.input_source_b
"""
A class representing storage overlap relations among inputs that alias the same storage.
Given that a set of tensors alias the same storage, this guard checks whether they actually
have overlapping storages.
While non_overlapping_sources represent input tensors that definitely don't have any storage
overlapping with any other input, overlapping_sources represent tensors that either:
1. Do overlap some other input tensor
2. Might not overlap some other input tensor, but we are not sure
"""
# Frozen record of two source lists. Note that frozen=True only blocks
# attribute rebinding; the list objects themselves remain mutable.
@dataclasses.dataclass(frozen=True)
class StorageOverlap(GuardEnvExpr):
    # Sources that overlap, or may overlap, some other input.
    overlapping_sources: list[Source]
    # Sources known not to overlap any other input.
    non_overlapping_sources: list[Source]
  384. """
  385. Checkpointable is an interface for driving state snapshotting, left purposely vague for now.
  386. copy_graphstate() -> T, a somewhat legacy name, is expected to emit a snapshot of any type that
  387. can also be taken in at restore_graphstate(T) calls.
  388. When to snapshot, is, at the moment, an implementation detail of upstream callers. Checkpointable
  389. does not provide any guarantees around consistency, idempotency, or safety of calling its APIs, yet.
  390. In the future, it will have a closer coupling to a generic Checkpoint management system.
  391. """
class Checkpointable(Generic[T]):
    # Interface with two operations: produce a snapshot of type T, and restore
    # from a previously produced snapshot.
    # NOTE(review): the class does not derive from abc.ABC, so @abstractmethod
    # does not prevent instantiation here; the decorators act as markers only.

    # Return a snapshot of the current state.
    @abstractmethod
    def copy_graphstate(self) -> T: ...

    # Restore state from a snapshot previously returned by copy_graphstate.
    @abstractmethod
    def restore_graphstate(self, state: T) -> None: ...
  397. class GuardsCheckpointState:
  398. """
  399. The GuardCheckpointState - it is the T of Checkpointable[T] for GuardsContext
  400. """
  401. dynamo_guards: set[Guard] = set()
  402. def __init__(self, dynamo_guards: set[Guard]) -> None:
  403. self.dynamo_guards = dynamo_guards
  404. def diff(self, other: GuardsCheckpointState) -> Optional[set[Guard]]:
  405. """
  406. Produces a delta against another GuardsCheckpointState.
  407. Returns None if no delta is found, otherwise, return a set() of mismatched
  408. Guard type objects.
  409. """
  410. r = self.dynamo_guards.difference(other.dynamo_guards)
  411. if len(r) == 0:
  412. return None
  413. return r
  414. def __eq__(self, other: object) -> bool:
  415. if not isinstance(other, GuardsCheckpointState):
  416. return False
  417. return self.diff(other) is None
  418. class ModuleContextCheckpointState:
  419. nn_modules: dict[str, torch.nn.Module] = {}
  420. def __init__(self, nn_modules: dict[str, torch.nn.Module]) -> None:
  421. self.nn_modules = nn_modules
  422. def diff(self, other: ModuleContextCheckpointState) -> Optional[set[str]]:
  423. """
  424. Produces a delta against another ModuleContextCheckpointState.
  425. Returns None if no delta is found, otherwise, return a set() of mismatched
  426. module key names.
  427. """
  428. r = set(self.nn_modules.keys()).difference(set(other.nn_modules.keys()))
  429. if len(r) == 0:
  430. return None
  431. return r
  432. def __eq__(self, other: object) -> bool:
  433. if not isinstance(other, ModuleContextCheckpointState):
  434. return False
  435. return self.diff(other) is None
  436. class ModuleContext(Checkpointable[ModuleContextCheckpointState]):
  437. def __init__(self) -> None:
  438. self.nn_modules: dict[str, Any] = {}
  439. def copy_graphstate(self) -> ModuleContextCheckpointState:
  440. return ModuleContextCheckpointState(dict(self.nn_modules))
  441. def restore_graphstate(self, state: ModuleContextCheckpointState) -> None:
  442. assert isinstance(state, ModuleContextCheckpointState)
  443. self.nn_modules = state.nn_modules
  444. class GlobalContextCheckpointState:
  445. global_state: dict[str, tuple[Callable, Any]] = {}
  446. def __init__(self, global_states: dict[str, tuple[Callable, Any]]) -> None:
  447. self.global_state = global_states
  448. def diff(self, other: GlobalContextCheckpointState) -> Optional[set[str]]:
  449. """
  450. Produces a delta against another GlobalContextCheckpointState.
  451. Returns None if no delta is found, otherwise, return a set() of mismatched
  452. global key names.
  453. """
  454. r = set(self.global_state.keys()).difference(set(other.global_state.keys()))
  455. if len(r) == 0:
  456. return None
  457. return r
  458. def __eq__(self, other: object) -> bool:
  459. if not isinstance(other, GlobalContextCheckpointState):
  460. return False
  461. return self.diff(other) is None
  462. class GlobalContext(Checkpointable[GlobalContextCheckpointState]):
  463. """
  464. This keeps track of the global torch state during tracing of a function.
  465. For example, torch.is_grad_enabled.
  466. """
  467. _supported_global_states = {
  468. "grad_enabled",
  469. "autocast_enabled",
  470. "autocast_cpu_enabled",
  471. "autocast_gpu_dtype",
  472. "autocast_cpu_dtype",
  473. "autocast_cache_enabled",
  474. }
  475. def __init__(self) -> None:
  476. self.global_state: dict[str, tuple[Callable, Any]] = {}
  477. def copy_graphstate(self) -> GlobalContextCheckpointState:
  478. return GlobalContextCheckpointState(self.global_state)
  479. def restore_graphstate(self, state: GlobalContextCheckpointState) -> None:
  480. assert isinstance(state, GlobalContextCheckpointState)
  481. self.global_state = state.global_state
  482. assert (
  483. len(self.global_state) == len(self._supported_global_states)
  484. and set(self.global_state.keys()) == self._supported_global_states
  485. ), "Global state mismatch"
  486. for func, args in self.global_state.values():
  487. func(args)
  488. # Like a Set[Guard] but will record the user stack on all guards at the
  489. # time they were installed at their destination
  490. class GuardsSet:
  491. def __init__(self, inner: Optional[set[Guard]] = None) -> None:
  492. if inner is None:
  493. inner = set()
  494. self.inner = inner
  495. def __iter__(self) -> Iterator[Guard]:
  496. return iter(self.inner)
  497. def __len__(self) -> int:
  498. return len(self.inner)
  499. # Subtraction along with bool is typically used to determine the delta of
  500. # added guards between checkpoints for higher order ops
  501. def __sub__(self, other: GuardsSet) -> GuardsSet:
  502. return GuardsSet(self.inner - other.inner)
  503. def __bool__(self) -> bool:
  504. return bool(self.inner)
  505. def add(
  506. self, guard: Guard, *, collect_debug_stack: bool = True, skip: int = 0
  507. ) -> None:
  508. if guard in self.inner:
  509. return
  510. if collect_debug_stack:
  511. if guard.stack is None:
  512. guard.stack = CapturedTraceback.extract(skip=1 + skip)
  513. if guard.user_stack is None:
  514. guard.user_stack = TracingContext.extract_stack()
  515. self.inner.add(guard)
  516. def update(self, *others: set[Guard]) -> None:
  517. for o in others:
  518. for g in o:
  519. self.add(g, skip=1)
  520. def remove_guards_with_source(self, source: Source) -> None:
  521. """Delete all guards that contains a given source"""
  522. from ._dynamo.source import is_from_source
  523. self.inner = {
  524. g for g in self.inner if not is_from_source(g.originating_source, source)
  525. }
"""
A GuardsContext is a checkpointable representation of all the guards in the current tracing
context. Its lifecycle is bound 1:1 to the tracing context, and it should never be instantiated
directly outside of it. For passing around internal state representations of this object,
prefer to extract them with copy_graphstate to produce a GuardsCheckpointState.
"""
  532. class GuardsContext(Checkpointable[GuardsCheckpointState]):
  533. def __init__(self) -> None:
  534. self.dynamo_guards: GuardsSet = GuardsSet()
  535. self.aotautograd_guards: list[GuardEnvExpr] = []
  536. def copy_graphstate(self) -> GuardsCheckpointState:
  537. return GuardsCheckpointState(set(self.dynamo_guards.inner))
  538. def restore_graphstate(self, state: GuardsCheckpointState) -> None:
  539. # NB: "steals" the passed in state
  540. assert isinstance(state, GuardsCheckpointState)
  541. self.dynamo_guards = GuardsSet(state.dynamo_guards)
  542. class HopSubgraphCache:
  543. @abstractmethod
  544. def add_dynamo_installed_submodule(self, fn_id: int, identifier: str) -> None: ...
  545. @abstractmethod
  546. def get_dynamo_installed_submodules(self, fn_id: int) -> list[str]: ...
  547. @abstractmethod
  548. def add_autograd_key_entry(self, identifier: str, key: Callable) -> None: ...
  549. @abstractmethod
  550. def get_autograd_key_entry(self, identifier: str) -> Optional[Callable]: ...
  551. @abstractmethod
  552. def add_proxy_dispatch_entry(self, identifier: str, key: Callable) -> None: ...
  553. @abstractmethod
  554. def get_proxy_dispatch_entry(self, identifier: str) -> Optional[Callable]: ...
  555. @abstractmethod
  556. def add_lazy_bwd_entry(
  557. self,
  558. identifier: str,
  559. tangent_metadata: tuple[object],
  560. gmod: torch.fx.GraphModule,
  561. ) -> int: ...
  562. @abstractmethod
  563. def get_lazy_bwd_entry(
  564. self, identifier: str, tangent_metadata: tuple[object]
  565. ) -> tuple[Optional[torch.fx.GraphModule], Optional[int]]: ...
  566. class InvokeSubgraphCache(HopSubgraphCache):
  567. def __init__(self) -> None:
  568. self.autograd_cache: dict[str, Callable] = {}
  569. self.proxy_dispatch_cache: dict[str, Callable] = {}
  570. self.dynamo_installed_submodules: dict[int, list[str]] = defaultdict(list)
  571. self.lazy_bwd_cache: dict[
  572. str, dict[tuple[object], tuple[torch.fx.GraphModule, int]]
  573. ] = defaultdict(dict)
  574. def add_dynamo_installed_submodule(self, fn_id: int, identifier: str) -> None:
  575. self.dynamo_installed_submodules[fn_id].append(identifier)
  576. def get_dynamo_installed_submodules(self, fn_id: int) -> list[str]:
  577. return self.dynamo_installed_submodules.get(fn_id, [])
  578. def add_autograd_key_entry(self, identifier: str, key: Callable) -> None:
  579. self.autograd_cache[identifier] = key
  580. def get_autograd_key_entry(self, identifier: str) -> Optional[Callable]:
  581. return self.autograd_cache.get(identifier, None)
  582. def add_proxy_dispatch_entry(self, identifier: str, key: Callable) -> None:
  583. self.proxy_dispatch_cache[identifier] = key
  584. def get_proxy_dispatch_entry(self, identifier: str) -> Optional[Callable]:
  585. return self.proxy_dispatch_cache.get(identifier, None)
  586. def add_lazy_bwd_entry(
  587. self,
  588. identifier: str,
  589. tangent_metadata: tuple[object],
  590. gmod: torch.fx.GraphModule,
  591. ) -> int:
  592. # Save the number of existing graph modules in the dictionary to get the suffix
  593. num_gmods = len(self.lazy_bwd_cache[identifier])
  594. self.lazy_bwd_cache[identifier][tangent_metadata] = (gmod, num_gmods)
  595. return num_gmods
  596. def get_lazy_bwd_entry(
  597. self, identifier: str, tangent_metadata: tuple[object]
  598. ) -> tuple[Optional[torch.fx.GraphModule], Optional[int]]:
  599. if identifier not in self.lazy_bwd_cache:
  600. return (None, None)
  601. return self.lazy_bwd_cache[identifier].get(tangent_metadata, (None, None))
  602. class HopDispatchSetCache:
  603. def __init__(self) -> None:
  604. # Delayed import to avoid circular dependency
  605. from torch._higher_order_ops.invoke_subgraph import invoke_subgraph
  606. self.hop_cache_map = {invoke_subgraph: InvokeSubgraphCache()}
  607. def get_cache(
  608. self, op: torch._ops.HigherOrderOperator
  609. ) -> Optional[HopSubgraphCache]:
  610. if op not in self.hop_cache_map:
  611. return None
  612. return self.hop_cache_map[op] # type: ignore[index]
  613. _TLS = threading.local()
  614. """
  615. TracingContext is the source of truth for all currently accumulated information
  616. needed to trace. Its lifecycle is kept 1:1 when using TorchDynamo, but other systems
  617. are open to managing their own TracingContext with that in mind.
  618. The purpose of TracingContext is not to be a dumping ground, or god object, but rather to avoid
  619. having to plumb complex subsystems across multiple verticals.
  620. Ex: A common example is guard accumulation between dynamo, shape_env, aot_autograd, and inductor.
  621. Accessing the current tracing context via
  622. TracingContext.get() allows users to accumulate their own guards for processing, without needing to know how
  623. to plumb objects back up to where frame interpretation happened.
  624. Note that you can end up with multiple TracingContext for a single compilation
  625. of a frame, as we reset the TracingContext whenever we restart analysis.
  626. CompileContext is a more overarching context that encompasses multiple restarts.
  627. """
  628. class CompileContext:
  629. @staticmethod
  630. def get() -> CompileContext:
  631. assert _TLS.compile_context is not None
  632. return _TLS.compile_context
  633. @staticmethod
  634. def try_get() -> Optional[CompileContext]:
  635. return getattr(_TLS, "compile_context", None)
  636. def __init__(self, compile_id: Optional[CompileId]) -> None:
  637. assert compile_id is None or isinstance(compile_id, CompileId)
  638. self.compile_id: Optional[CompileId] = compile_id
  639. self.attempt = 0
  640. # Verbose ShapeEnv guards produced.
  641. self.shape_env_guards: list[str] = []
  642. @staticmethod
  643. def current_compile_id() -> Optional[CompileId]:
  644. self = CompileContext.try_get()
  645. if self is None:
  646. return None
  647. return self.compile_id
  648. @staticmethod
  649. def current_trace_id() -> Optional[TraceId]:
  650. self = CompileContext.try_get()
  651. if self is None:
  652. return None
  653. if self.compile_id is None:
  654. return None
  655. return TraceId(self.compile_id, self.attempt)
  656. class TracingContext:
  657. """
  658. Provides the currently installed TracingContext, or None.
  659. Note that it is a staticmethod, and invocations outside of `with tracing()` (see below), are valid but
  660. will return None.
  661. """
  662. @staticmethod
  663. def try_get() -> Optional[TracingContext]:
  664. return getattr(_TLS, "tracing_context", None)
  665. @staticmethod
  666. def get() -> TracingContext:
  667. if ctx := TracingContext.try_get():
  668. return ctx
  669. raise RuntimeError(
  670. "TracingContext.get() must be called within an ongoing trace."
  671. )
  672. def __init__(self, fake_mode: Optional[FakeTensorMode]) -> None:
  673. self.guards_context = GuardsContext()
  674. self.module_context = ModuleContext()
  675. self.global_context = GlobalContext()
  676. self.previously_inlined_functions: dict[Any, Any] = dict()
  677. self.previously_cleaned_instructions: dict[Any, Any] = dict()
  678. self.fake_mode: Optional[FakeTensorMode] = fake_mode
  679. self.frame_summary_stack: list[traceback.FrameSummary] = []
  680. # This is morally part of frame_summary_stack, but it is kept separate
  681. # for clarity. As we process a frame, this variable gets updated
  682. # to keep track of what line we are in the function. We make a
  683. # function call, this gets cleared and the frame location is pushed
  684. # to frame_summary_stack (prepping this variable for the inner frame's
  685. # progress)
  686. self.loc_in_frame: Optional[tuple[str, int, str]] = None
  687. # this is only set after aot_autograd
  688. self.fw_metadata: Optional[ViewAndMutationMeta] = None
  689. # this is only set when the DDPOptimizer is used
  690. self.ddp_optimizer_ctx: Optional[DDPOptimizerContext] = None
  691. # this is only set after aot_autograd
  692. self.aot_graph_name: Optional[list[str]] = None
  693. self.params_flat: Optional[list[Any]] = None
  694. self.params_flat_unwrap_subclasses: Optional[list[Any]] = None
  695. self.params_unwrapped_to_flat_index: Optional[list[Any]] = None
  696. # this is for extended return calling convention from backend
  697. # compiler to aot_autograd
  698. # Per output, what the compiler specified stride of the output is,
  699. # or None if no stride is known. This is always the HINT, it
  700. # is never a SymInt (it would be better if it was a SymInt, but
  701. # I can't conveniently get this from Inductor atm. Also, be
  702. # careful not to accidentally induce guards on the SymInt if
  703. # you ever do change this in aot_autograd.py; you should check
  704. # on permutations preferentially.)
  705. self.output_strides: Optional[list[Optional[tuple[int, ...]]]] = None
  706. # When this is True, whenever we encounter an int in Dynamo tracing,
  707. # we will (1) force unspec it and (2) force it as a size-like unbacked
  708. # integer. This is currently used when processing certain lists of
  709. # ints that are known to be size-like and may have 0/1 entries that we
  710. # must not specialize on.
  711. self.force_unspec_int_unbacked_size_like = False
  712. # See note [Tensor Fakification and Symbol Caching]
  713. self.tensor_to_context = WeakTensorKeyDictionary()
  714. # If this true, Aot Autograd will return output Fake Tensors with appropriate
  715. # meta on the first invocation
  716. # see note: [Returning Fake Tensors on First AOT Autograd Call]
  717. self.fakify_first_call = False
  718. self.hop_dispatch_set_cache = HopDispatchSetCache()
  719. # list of code objects for inlined functions
  720. self.traced_code: list[CodeType] = []
  721. def clear(self) -> None:
  722. # Look at the note in output_graph.py in function `save_global_state`
  723. # for the context on clearing global context.
  724. self.global_context.global_state = {}
  725. self.previously_inlined_functions.clear()
  726. self.previously_cleaned_instructions.clear()
  727. @staticmethod
  728. @contextmanager
  729. def patch(**kwargs: Any) -> Generator[None, None, None]:
  730. prior = {}
  731. ctx = TracingContext.get()
  732. for key in kwargs.keys():
  733. # KeyError on invalid entry
  734. prior[key] = getattr(ctx, key)
  735. for key, val in kwargs.items():
  736. setattr(ctx, key, val)
  737. try:
  738. yield
  739. finally:
  740. for key, val in prior.items():
  741. setattr(ctx, key, val)
  742. @staticmethod
  743. def extract_stack() -> traceback.StackSummary:
  744. self = TracingContext.try_get()
  745. if self is None:
  746. return traceback.StackSummary()
  747. stack = self.frame_summary_stack
  748. if self.loc_in_frame is not None:
  749. stack = stack + [self._populate_loc_in_frame_summary()]
  750. return traceback.StackSummary.from_list(stack)
  751. def _populate_loc_in_frame_summary(self) -> traceback.FrameSummary:
  752. assert self.loc_in_frame is not None
  753. filename, lineno, frame_name = self.loc_in_frame
  754. return traceback.FrameSummary(filename, lineno, frame_name, lookup_line=False)
  755. # Call this when you want to call into some code that isn't necessarily
  756. # associated with the current frame state
  757. @staticmethod
  758. @contextlib.contextmanager
  759. def clear_frame() -> Generator[None, None, None]:
  760. tc = TracingContext.get()
  761. with (
  762. unittest.mock.patch.object(tc, "frame_summary_stack", []),
  763. unittest.mock.patch.object(tc, "loc_in_frame", None),
  764. ):
  765. try:
  766. yield
  767. except Exception as e:
  768. # Prevent real_stack from getting attached
  769. #
  770. # The invariant is that if an Exception as real_stack, we've
  771. # appropriately attached a user stack and we no longer need to
  772. # attach anything. Because we cannot conveniently interpose
  773. # when an exception is thrown, we instead interpose everywhere
  774. # we set what the user stack is set (using the context
  775. # manager). However, our compiler stack does "tail calls"
  776. # (when it calls into user compiler), at which point the
  777. # parent exception frames would incorrectly attach an
  778. # incorrect frame.
  779. #
  780. # However, if, somehow, someone raised an exception with this
  781. # scope that had a stack (for example, because they are
  782. # restoring the user stack state appropriately as they process
  783. # node by node), we should respect it. Thus, we cannot
  784. # unconditionally set None.
  785. if not hasattr(e, "real_stack"):
  786. e.real_stack = None # type: ignore[attr-defined]
  787. raise
  788. @staticmethod
  789. @contextlib.contextmanager
  790. def current_frame(
  791. frame_summary: Optional[traceback.FrameSummary],
  792. ) -> Generator[None, None, None]:
  793. # frame_summary can be None to solely take advantage of real_stack
  794. # attachment to thrown exceptions
  795. tc = TracingContext.get()
  796. if frame_summary is not None:
  797. tc.frame_summary_stack.append(frame_summary)
  798. old = tc.loc_in_frame
  799. tc.loc_in_frame = None
  800. try:
  801. yield
  802. except Exception as e:
  803. if not hasattr(e, "real_stack"):
  804. e.real_stack = tc.extract_stack() # type: ignore[attr-defined]
  805. raise
  806. finally:
  807. if frame_summary is not None:
  808. tc.frame_summary_stack.pop()
  809. tc.loc_in_frame = old
  810. @staticmethod
  811. @contextlib.contextmanager
  812. def report_output_strides() -> Generator[
  813. Optional[list[Optional[tuple[int, ...]]]], None, None
  814. ]:
  815. tc = TracingContext.try_get()
  816. if tc is None:
  817. yield None
  818. return
  819. old_output_strides = tc.output_strides
  820. tc.output_strides = []
  821. try:
  822. yield tc.output_strides
  823. finally:
  824. tc.output_strides = old_output_strides
  825. @staticmethod
  826. def set_current_loc(filename: str, lineno: int, frame_name: str) -> None:
  827. # Save the current location in the frame. Lazily generate the
  828. # framesummary.
  829. TracingContext.get().loc_in_frame = (filename, lineno, frame_name)
  830. @staticmethod
  831. def get_traced_code() -> Optional[list[CodeType]]:
  832. tc = TracingContext.try_get()
  833. if tc is None:
  834. return None
  835. return tc.traced_code
  836. @contextmanager
  837. def compile_context(
  838. context: Optional[CompileContext],
  839. ) -> Generator[Optional[CompileContext], None, None]:
  840. old_context = getattr(_TLS, "compile_context", None)
  841. _TLS.compile_context = context
  842. try:
  843. yield context
  844. finally:
  845. _TLS.compile_context = old_context
  846. @contextmanager
  847. def tracing(
  848. context: Optional[TracingContext],
  849. ) -> Generator[Optional[TracingContext], None, None]:
  850. """
  851. This function installs the passed in tracing context as a dynamic scoped
  852. global variable.
  853. Calls to TracingContext.get() while not under a `with tracing()` context
  854. will return None.
  855. """
  856. old_context = getattr(_TLS, "tracing_context", None)
  857. _TLS.tracing_context = context
  858. try:
  859. yield context
  860. except Exception as e:
  861. if not hasattr(e, "real_stack") and context is not None:
  862. e.real_stack = context.extract_stack() # type: ignore[attr-defined]
  863. raise
  864. finally:
  865. if (
  866. context is not None
  867. and context.fake_mode is not None
  868. and context.fake_mode.shape_env is not None
  869. ):
  870. context.fake_mode.shape_env.cleanup()
  871. _TLS.tracing_context = old_context
  872. # Subclasses can be found in torch/_dynamo/source.py
  873. # TODO(voz): Consider a toplevel torch/_source.py
  874. @dataclasses.dataclass(frozen=True)
  875. class Source:
  876. def is_dict_key(self) -> bool:
  877. return False
  878. def is_ephemeral(self) -> bool:
  879. return False
  880. def reconstruct(self, codegen: PyCodegen) -> None:
  881. raise NotImplementedError
  882. def guard_source(self) -> GuardSource:
  883. raise NotImplementedError
  884. def name(self) -> str:
  885. raise NotImplementedError
  886. def make_guard(self, fn: Callable[..., Any]) -> Guard:
  887. if self.guard_source() is GuardSource.CONSTANT:
  888. raise NotImplementedError
  889. return Guard(self, fn)
  890. def is_specialized_nn_module(self) -> bool:
  891. return self.guard_source().is_specialized_nn_module()
  892. def subguards_allowed(self) -> bool:
  893. """True if you can guard on attributes of this"""
  894. return self.guard_source() != GuardSource.SYNTHETIC_LOCAL
  895. # Subclasses can be found in torch/_dynamo/source.py
  896. @dataclasses.dataclass(frozen=True)
  897. class ChainedSource(Source):
  898. base: Source
  899. def is_dict_key(self) -> bool:
  900. # Recurse until you either hit a ConstDictKey or a Source
  901. return self.base.is_dict_key()
  902. def is_ephemeral(self) -> bool:
  903. return self.base.is_ephemeral()
  904. def get_base(self) -> Source:
  905. current: Source = self
  906. while isinstance(current, ChainedSource):
  907. current = current.base
  908. return current
  909. def detect_fake_mode(inputs: Any = None) -> Optional[FakeTensorMode]:
  910. """
  911. Attempts to "detect" what the current fake mode is. If there is one ambiently
  912. available from TracingContext, we preferentially use that. Otherwise, we
  913. heuristically detect the fake mode via the following sources, in order of
  914. priority:
  915. - Currently active fake mode on stack
  916. - Fake mode associated with passed in tensors (inputs does not
  917. have to be flattened)
  918. """
  919. from torch._subclasses.fake_tensor import FakeTensor, FakeTensorMode
  920. fake_modes = []
  921. if context := TracingContext.try_get():
  922. fake_mode = context.fake_mode
  923. if fake_mode is not None:
  924. fake_modes.append((fake_mode, "tracing context", 0))
  925. from torch.utils._python_dispatch import _get_current_dispatch_mode_stack
  926. for i, m in enumerate(reversed(_get_current_dispatch_mode_stack())):
  927. if isinstance(m, FakeTensorMode):
  928. fake_modes.append((m, "active fake mode", i))
  929. flat_inputs = pytree.tree_leaves(inputs)
  930. for i, flat_input in enumerate(flat_inputs):
  931. if isinstance(flat_input, FakeTensor):
  932. fake_modes.append((flat_input.fake_mode, "fake tensor input", i))
  933. if fake_modes:
  934. fake_mode, desc1, i1 = fake_modes[0]
  935. for m, desc2, i2 in fake_modes[1:]:
  936. assert fake_mode is m, (
  937. f"fake mode ({fake_mode}) from {desc1} {i1} doesn't match mode ({m}) from {desc2} {i2}\n\n"
  938. f"fake mode from {desc1} {i1} allocated at:\n{fake_mode.stack}\n"
  939. f"fake mode from {desc2} {i2} allocated at:\n{m.stack}"
  940. )
  941. return fake_mode
  942. else:
  943. return None
  944. def active_fake_mode() -> Optional[FakeTensorMode]:
  945. """
  946. Inspects the dispatch mode stack for an active fake mode and returns it.
  947. Returns None if no fake mode is active.
  948. """
  949. from torch._subclasses.fake_tensor import FakeTensorMode
  950. from torch.utils._python_dispatch import _get_current_dispatch_mode_stack
  951. for _, m in enumerate(reversed(_get_current_dispatch_mode_stack())):
  952. if isinstance(m, FakeTensorMode):
  953. return m
  954. return None