_guards.py 46 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270370470570670770870971071171271371471571671771871972072172272372472572672772872973073173273373473573673773873974074174274374474574674774874975075175275375475575675775875976076176276376476576676776876977077177277377477577677777877978078178278378478578678778878979079179279379479579679779879980080180280380480580680780880981081181281381481581681781881982082182282382482582682782882983083183283383483583683783883984084184284384484584684784884985085185285385485585685785885986086186286386486586686786886987087187287387487587687787887988088188288388488588688788888989089189289389489589689789889990090190290390490590690790890991091191291391491591691791891992092192292392492592692792892993093193293393493593693793893994094194294394494594694794894995095195295395495595695795895996096196296396496596696796896997097197297397497597697797897998098198298398498598698798898999099199299399499599699799899910001001100210031004100510061007100810091010101110121013101410151016101710181019102010211022102310241025102610271028102910301031103210331034103510361037103810391040104110421043104410451046104710481049105010511052105310541055105610571058105910601061106210631064106510661067106810691070107110721073107410751076107710781079108010811082108310841085108610871088108910901091109210931094109510961097109810991100110111021103110411051106110711081109111011111112111311141115111611171118111911201121112211231124112511261127112811291130113111321133113411351136113711381139114011411142114311441145114611471148114911501151115211531154115511561157115811591160116111621163116411651166116711681169117011711172117311741175117611771178117911801181118211831184118511861187118811891190119111921193119411951196119711981199120012011202120312041205120612071208120912101211121212131214121512161217121812191220122112221223122412251226122712281229123012311232123312341235123612371238123912401241124212431244124512461247124812491250125112521253125412551256125712581259126012611262126312641265126612671268126912701271127212731274127512761277127812791280128112821283128412851286128712881289129012911292129312941295129612971298129913001301130213031304130513061307130813091310131113121313131413151316131713181319
  1. from __future__ import annotations
  2. import contextlib
  3. import dataclasses
  4. import enum
  5. import functools
  6. import logging
  7. import re
  8. import sys
  9. import threading
  10. import traceback
  11. import unittest.mock
  12. import weakref
  13. from abc import abstractmethod
  14. from collections import defaultdict
  15. from contextlib import contextmanager
  16. from dataclasses import dataclass
  17. from typing import Any, Generic, NamedTuple, Optional, overload, TYPE_CHECKING, TypeVar
  18. if sys.version_info >= (3, 11):
  19. from typing import dataclass_transform
  20. else:
  21. def dataclass_transform():
  22. def decorator(fn):
  23. return fn
  24. return decorator
  25. import torch
  26. from torch.utils import _pytree as pytree
  27. from torch.utils._ordered_set import OrderedSet
  28. from torch.utils._python_dispatch import is_traceable_wrapper_subclass
  29. from torch.utils._traceback import CapturedTraceback, format_frame
  30. from torch.utils.weak import WeakTensorKeyDictionary
  31. log = logging.getLogger(__name__)
  32. if TYPE_CHECKING:
  33. from collections.abc import Callable, Generator, Iterator
  34. from types import CodeType
  35. import sympy
  36. from torch._dynamo.backends.distributed import DDPOptimizerContext
  37. from torch._dynamo.codegen import PyCodegen
  38. from torch._functorch._aot_autograd.schemas import ViewAndMutationMeta
  39. from torch._subclasses.fake_tensor import FakeTensorMode
  40. """
  41. torch._guards is the definitional source of truth for general purpose guard structures.
  42. An important thing to keep in mind here is the preservation of layering. There should be no dynamo notions,
  43. and no guard installation notions here.
  44. """
  45. COMPILE_ID_PATTERN = re.compile(r"^(?P<frame_id>\d+)/(?P<frame_compile_id>\d+)$")
  46. CA_COMPILE_ID_PATTERN = re.compile(
  47. r"^!(?P<compiled_autograd_id>\d+)(?:/(?P<frame_id>\d+)/(?P<frame_compile_id>\d+))?$"
  48. )
  49. # [Note: Updating CompiledId]
  50. #
  51. # CompiledId represents a unique program-level identifier, and we want to keep that
  52. # property as the codebase evolves. This property is relied on even outside of the pytorch
  53. # repo, e.g. tlparse or other internal tooling. The in-memory format can be freely changed,
  54. # as those dependencies only consume the string serialization.
  55. #
  56. # The string form should be:
  57. # 1. Program-level uid: CompileId can uniquely identify a compiled graph.
  58. # 2. Storage efficient: This object is logged in nearly every entry. We should elide symbols when possible.
  59. # 3. Compact: The string form is directly displayed by some tools. Special symbols are okay.
  60. @dataclass(frozen=True, kw_only=True, slots=True)
  61. class CompileId:
  62. frame_id: int | None
  63. # This id is per-frame, and counts how many times we've compiled this
  64. # frame. This could have been a global id but having this be per-frame
  65. # gives you a better intuitive sense for how many recompiles have occurred
  66. # so far.
  67. frame_compile_id: int | None
  68. # torch.compiling a compiled autograd graph
  69. compiled_autograd_id: int | None = None
  70. # TODO: consider also tracking the recompilation count
  71. # See Note: Updating CompileId
  72. def __str__(self) -> str:
  73. # NOTE: Keep this in sync with both from_string and the tlparse repo
  74. if self.compiled_autograd_id is not None:
  75. assert (self.frame_id is None) == (self.frame_compile_id is None)
  76. frame_str = ""
  77. if self.frame_id is not None:
  78. frame_str = f"/{self.frame_id}/{self.frame_compile_id}"
  79. return f"!{self.compiled_autograd_id}{frame_str}"
  80. else:
  81. assert self.frame_id is not None and self.frame_compile_id is not None
  82. return f"{self.frame_id}/{self.frame_compile_id}"
  83. @classmethod
  84. def from_string(cls, compile_id: str | None) -> CompileId | None:
  85. """
  86. Factory method that creates a CompileId from its string representation.
  87. Keep this in sync with the __str__ method.
  88. """
  89. if compile_id is None:
  90. return None
  91. try:
  92. for pattern in (COMPILE_ID_PATTERN, CA_COMPILE_ID_PATTERN):
  93. if match := pattern.match(compile_id):
  94. groups = match.groupdict()
  95. for k, v in groups.items():
  96. if v is not None:
  97. groups[k] = int(v)
  98. return cls(**groups) # type: ignore[arg-type]
  99. else:
  100. raise ValueError
  101. except Exception as e:
  102. raise ValueError(f"Invalid compile_id '{compile_id}'") from e
  103. class TraceId(NamedTuple):
  104. compile_id: CompileId
  105. # This starts off as 0, and every time we restart analysis it goes
  106. # up by one
  107. attempt: int
  108. def __str__(self) -> str:
  109. # Keep this in sync with tlparse repo
  110. if self.attempt == 0:
  111. return str(self.compile_id)
  112. else:
  113. return f"{self.compile_id}_{self.attempt}"
  114. class GuardSource(enum.Enum):
  115. LOCAL = 0
  116. GLOBAL = 1
  117. LOCAL_SPECIALIZED_NN_MODULE = 2
  118. GLOBAL_SPECIALIZED_NN_MODULE = 3
  119. CONSTANT = 4
  120. RANDOM_VALUE = 5
  121. SHAPE_ENV = 6
  122. LOCAL_FSDP_MODULE = 7
  123. GLOBAL_FSDP_MODULE = 8
  124. BACKWARD_STATE = 9
  125. EPHEMERAL = 10
  126. SYNTHETIC_LOCAL = 11
  127. LOCAL_UNSPECIALIZED_NN_MODULE = 12
  128. GLOBAL_UNSPECIALIZED_NN_MODULE = 13
  129. LOCAL_UNSPECIALIZED_BUILTIN_NN_MODULE = 14
  130. GLOBAL_UNSPECIALIZED_BUILTIN_NN_MODULE = 15
  131. TEMP_LOCAL = 16
  132. def is_fsdp_module(self) -> bool:
  133. return self in (GuardSource.GLOBAL_FSDP_MODULE, GuardSource.LOCAL_FSDP_MODULE)
  134. def is_specialized_nn_module(self) -> bool:
  135. import torch._dynamo.config as config
  136. if config._unsafe_skip_fsdp_module_guards:
  137. return (
  138. self
  139. in (
  140. GuardSource.GLOBAL_SPECIALIZED_NN_MODULE,
  141. GuardSource.LOCAL_SPECIALIZED_NN_MODULE,
  142. )
  143. or self.is_fsdp_module()
  144. )
  145. return self in (
  146. GuardSource.GLOBAL_SPECIALIZED_NN_MODULE,
  147. GuardSource.LOCAL_SPECIALIZED_NN_MODULE,
  148. )
  149. def is_unspecialized_nn_module(self) -> bool:
  150. return self in (
  151. GuardSource.GLOBAL_UNSPECIALIZED_NN_MODULE,
  152. GuardSource.LOCAL_UNSPECIALIZED_NN_MODULE,
  153. GuardSource.GLOBAL_UNSPECIALIZED_BUILTIN_NN_MODULE,
  154. GuardSource.LOCAL_UNSPECIALIZED_BUILTIN_NN_MODULE,
  155. )
  156. def is_unspecialized_builtin_nn_module(self) -> bool:
  157. return self in (
  158. GuardSource.GLOBAL_UNSPECIALIZED_BUILTIN_NN_MODULE,
  159. GuardSource.LOCAL_UNSPECIALIZED_BUILTIN_NN_MODULE,
  160. )
  161. def is_local(self) -> bool:
  162. return self in (
  163. GuardSource.LOCAL,
  164. GuardSource.LOCAL_SPECIALIZED_NN_MODULE,
  165. GuardSource.LOCAL_FSDP_MODULE,
  166. GuardSource.LOCAL_UNSPECIALIZED_NN_MODULE,
  167. GuardSource.LOCAL_UNSPECIALIZED_BUILTIN_NN_MODULE,
  168. )
  169. """
  170. Base class for a "GuardBuilder" role.
  171. The GuardBuilderBase role is to represent a scope within which to build a guard. The name is a little
  172. confusing, as its not a builder, but for the sake of avoiding a lot of renames and keeping the original reference
  173. to torchdynamo's GuardBuilder.
  174. Note: create_fn is invoked with a GuardBuilderBase and a Guard. A GuardBuilder is chosen based
  175. on GuardSource's select function.
  176. There is value in keeping this GuardBuilderBase empty to keep layering clean.
  177. """
  178. class GuardBuilderBase:
  179. pass
  180. @dataclasses.dataclass(frozen=True)
  181. class SLoc:
  182. framework_loc: traceback.FrameSummary | str | None
  183. maybe_user_loc: str | None
  184. def __str__(self) -> str:
  185. floc = (
  186. self.framework_loc
  187. if isinstance(self.framework_loc, str)
  188. else format_frame(self.framework_loc)
  189. )
  190. if self.maybe_user_loc is not None:
  191. return f"{self.maybe_user_loc} ({floc})"
  192. else:
  193. return f"({floc})"
  194. class ShapeGuard(NamedTuple):
  195. expr: sympy.logic.boolalg.Boolean
  196. sloc: SLoc
  197. size_oblivious: bool
  198. @dataclasses.dataclass(slots=True)
  199. class Guard:
  200. # originating_source is the source that called the make_guard method to
  201. # construct this guard object. The property name specifies what exactly it
  202. # is the guard is guarding on. The meaning of the name is dependent on the
  203. # create_fn; you must look at the use-site inside create_fn to know what
  204. # name means.
  205. #
  206. # That being said, although you might think this is just a "name", name is
  207. # usually an arbitrary Python expression that will be evaluated with all
  208. # globals (and locals, if you create a LOCAL guard) to extract the Python
  209. # object that we want to perform guard tests on. This evaluation
  210. # typically happens in GuardBuilder.eval. In these cases, name is
  211. # typically produced by originating_source.name (not to be confused with
  212. # GuardSource - the property source).
  213. #
  214. # Occasionally, name is not a valid Python expression; sometimes
  215. # it is meaningless. Example create_fns that are like this include
  216. # GRAD_MODE and SHAPE_ENV.
  217. originating_source: Source
  218. create_fn: Callable[[GuardBuilderBase, Guard], None]
  219. # Export only. These values are written to at time of guard check_fn creation.
  220. guard_types: list[str] | None = None
  221. code_list: list[str] | None = None
  222. obj_weakref: object | None = None
  223. guarded_class_weakref: weakref.ReferenceType[Any] | None = None
  224. stack: CapturedTraceback | None = None
  225. user_stack: traceback.StackSummary | None = None
  226. _hash: int | None = None
  227. _unserializable: bool = False
  228. def __hash__(self) -> int:
  229. if self._hash is None:
  230. self._hash = hash((self.name, self.source, id(self.create_fn)))
  231. return self._hash
  232. def sort_key(self) -> tuple[bool, int, int, str, int]:
  233. # Put the duplicate input guards at the end. The duplicate guards have
  234. # two sources while guard.name only considers one source.
  235. is_duplicate_input = (
  236. isinstance(self.create_fn, functools.partial)
  237. and self.create_fn.func is torch._dynamo.guards.GuardBuilder.DUPLICATE_INPUT
  238. )
  239. return (
  240. is_duplicate_input,
  241. self.source.value if self.source else -1,
  242. len(self.name),
  243. self.name,
  244. self.inner_create_fn().__code__.co_firstlineno,
  245. )
  246. def __lt__(self, other: Guard) -> bool:
  247. return self.sort_key() < other.sort_key()
  248. def inner_create_fn(self) -> Callable[[GuardBuilderBase, Guard], Any]:
  249. if isinstance(self.create_fn, functools.partial):
  250. return self.create_fn.func
  251. else:
  252. return self.create_fn
  253. @property
  254. def name(self) -> str:
  255. return self.originating_source.name
  256. @property
  257. def source(self) -> GuardSource:
  258. return self.originating_source.guard_source
  259. @staticmethod
  260. def weakref_to_str(obj_weakref: object) -> str:
  261. """
  262. This is a workaround of a Python weakref bug.
  263. `obj_weakref` is instance returned by `weakref.ref`,
  264. `str(obj_weakref)` is buggy if the original obj overrides __getattr__, e.g:
  265. class MyConfig(dict):
  266. def __getattr__(self, x):
  267. return self[x]
  268. obj = MyConfig(offset=5)
  269. obj_weakref = weakref.ref(obj)
  270. str(obj_weakref) # raise error: KeyError: '__name__'
  271. """
  272. if isinstance(obj_weakref, weakref.ReferenceType):
  273. obj = obj_weakref()
  274. if obj is not None:
  275. return f"<weakref at {hex(id(obj_weakref))}; to '{obj.__class__.__name__}' at {hex(id(obj))}>"
  276. else:
  277. return f"<weakref at {hex(id(obj_weakref))}; dead>"
  278. else:
  279. return str(obj_weakref)
  280. def __repr__(self) -> str:
  281. s = f"""
  282. {self.source.name.lower() if self.source else ""} {repr(self.name)} {self.inner_create_fn().__name__}
  283. {{
  284. 'guard_types': {self.guard_types},
  285. 'code': {self.code_list},
  286. 'obj_weakref': {self.weakref_to_str(self.obj_weakref)}
  287. 'guarded_class': {self.guarded_class_weakref}
  288. }}
  289. """
  290. return s
  291. def __str__(self) -> str:
  292. output = f"Name: {repr(self.name)}\n"
  293. source = self.source.name.lower() if self.source else ""
  294. output += f" Source: {source}\n"
  295. output += f" Create Function: {self.inner_create_fn().__name__}\n"
  296. output += f" Guard Types: {self.guard_types}\n"
  297. output += f" Code List: {self.code_list}\n"
  298. output += f" Object Weakref: {self.weakref_to_str(self.obj_weakref)}\n"
  299. output += f" Guarded Class Weakref: {self.guarded_class_weakref}\n"
  300. return output
  301. def create(self, builder: GuardBuilderBase) -> Any:
  302. try:
  303. return self.create_fn(builder, self)
  304. except Exception:
  305. log.exception("Error while creating guard:\n%s", str(self).rstrip())
  306. if self.stack:
  307. log.error("Created at:\n%s", "".join(self.stack.format()[-4:]).rstrip())
  308. raise
  309. def is_specialized_nn_module(self) -> bool:
  310. return self.source.is_specialized_nn_module()
  311. def is_fsdp_module(self) -> bool:
  312. return self.source.is_fsdp_module()
  313. def is_local(self) -> bool:
  314. return self.source.is_local()
  315. def create_fn_name(self) -> str:
  316. if isinstance(self.create_fn, functools.partial):
  317. create_fn = self.create_fn.func # type: ignore[attr-defined]
  318. else:
  319. create_fn = self.create_fn
  320. return create_fn.__name__
  321. def set_export_info(
  322. self,
  323. guard_type: str,
  324. guarded_class: weakref.ReferenceType[Any] | None,
  325. code_list: list[str],
  326. obj_weakref: object,
  327. ) -> None:
  328. if not self.guard_types:
  329. self.guard_types = []
  330. self.guard_types.append(guard_type)
  331. assert self.guarded_class_weakref in (
  332. guarded_class,
  333. None,
  334. ), "Guarded class id must be identical, or None"
  335. self.guarded_class_weakref = guarded_class
  336. if not self.code_list:
  337. self.code_list = code_list
  338. else:
  339. self.code_list.extend(code_list)
  340. # Some objects are ephemeral, e.g., list[slice(1, 2)]. If we have
  341. # multiple guards on the same object, the weakref can die between the
  342. # invocation of set_export_info calls. So a dead weakref is also
  343. # acceptable.
  344. assert (
  345. self.obj_weakref in (obj_weakref, None)
  346. or callable(self.obj_weakref)
  347. and self.obj_weakref() is None
  348. ), "Guarded object must be identical, None or ephemeral (dead weakref)"
  349. self.obj_weakref = obj_weakref
  350. T = TypeVar("T")
  351. """
  352. Parent structure for guard env expressions.
  353. A GuardEnvExpr can have any subtype.
  354. Note: All subtypes must be handled exhaustively in
  355. torch._dynamo.guards._parse_guard_env_guards to avoid a RuntimeError.
  356. """
  357. @dataclasses.dataclass(frozen=True)
  358. class GuardEnvExpr:
  359. pass
  360. """
  361. A class representing a pair of duplicate inputs.
  362. input_pos_a and input_pos_b are input positions we have deduped.
  363. """
  364. @dataclasses.dataclass(frozen=True)
  365. class DuplicateInputs(GuardEnvExpr):
  366. input_source_a: Source
  367. input_source_b: Source
  368. def __post_init__(self) -> None:
  369. assert self.input_source_a != self.input_source_b
  370. """
  371. A class representing storage overlap relations among inputs that aliases the same storage.
  372. Given that a set of tensors alias the same storage, this guard checks whether they actually
  373. have overlapping storages.
  374. While non_overlapping_sources represent input tensors that definitely don't have any storage
  375. overlapping with any other input, overlapping_sources represent tensors that either:
  376. 1. Do overlap some other input tensor
  377. 2. Might not overlap some other input tensor, but we are not sure
  378. """
  379. @dataclasses.dataclass(frozen=True)
  380. class StorageOverlap(GuardEnvExpr):
  381. overlapping_sources: list[Source]
  382. non_overlapping_sources: list[Source]
  383. """
  384. Checkpointable is an interface for driving state snapshotting, left purposely vague for now.
  385. copy_graphstate() -> T, a somewhat legacy name, is expected to emit a snapshot of any type that
  386. can also be taken in at restore_graphstate(T) calls.
  387. When to snapshot, is, at the moment, an implementation detail of upstream callers. Checkpointable
  388. does not provide any guarantees around consistency, idempotency, or safety of calling its APIs, yet.
  389. In the future, it will have a closer coupling to a generic Checkpoint management system.
  390. """
  391. class Checkpointable(Generic[T]):
  392. @abstractmethod
  393. def copy_graphstate(self) -> T: ...
  394. @abstractmethod
  395. def restore_graphstate(self, state: T) -> None: ...
  396. class GuardsCheckpointState:
  397. """
  398. The GuardCheckpointState - it is the T of Checkpointable[T] for GuardsContext
  399. """
  400. dynamo_guards: OrderedSet[Guard]
  401. def __init__(self, dynamo_guards: OrderedSet[Guard]) -> None:
  402. self.dynamo_guards = dynamo_guards
  403. def diff(self, other: GuardsCheckpointState) -> Optional[OrderedSet[Guard]]:
  404. """
  405. Produces a delta against another GuardsCheckpointState.
  406. Returns None if no delta is found, otherwise, return an OrderedSet() of mismatched
  407. Guard type objects.
  408. """
  409. r = self.dynamo_guards.difference(other.dynamo_guards)
  410. if len(r) == 0:
  411. return None
  412. return r
  413. def __eq__(self, other: object) -> bool:
  414. if not isinstance(other, GuardsCheckpointState):
  415. return False
  416. return self.diff(other) is None
  417. class ModuleContextCheckpointState:
  418. nn_modules: dict[str, torch.nn.Module] = {}
  419. def __init__(self, nn_modules: dict[str, torch.nn.Module]) -> None:
  420. self.nn_modules = nn_modules
  421. def diff(self, other: ModuleContextCheckpointState) -> set[str] | None:
  422. """
  423. Produces a delta against another ModuleContextCheckpointState.
  424. Returns None if no delta is found, otherwise, return a set() of mismatched
  425. module key names.
  426. """
  427. r = set(self.nn_modules.keys()).difference(set(other.nn_modules.keys()))
  428. if len(r) == 0:
  429. return None
  430. return r
  431. def __eq__(self, other: object) -> bool:
  432. if not isinstance(other, ModuleContextCheckpointState):
  433. return False
  434. return self.diff(other) is None
  435. class ModuleContext(Checkpointable[ModuleContextCheckpointState]):
  436. def __init__(self) -> None:
  437. self.nn_modules: dict[str, Any] = {}
  438. def copy_graphstate(self) -> ModuleContextCheckpointState:
  439. return ModuleContextCheckpointState(dict(self.nn_modules))
  440. def restore_graphstate(self, state: ModuleContextCheckpointState) -> None:
  441. assert isinstance(state, ModuleContextCheckpointState)
  442. self.nn_modules = state.nn_modules
  443. class GlobalContextCheckpointState:
  444. global_state: dict[str, tuple[Callable, Any]] = {}
  445. def __init__(self, global_states: dict[str, tuple[Callable, Any]]) -> None:
  446. self.global_state = global_states
  447. def diff(self, other: GlobalContextCheckpointState) -> set[str] | None:
  448. """
  449. Produces a delta against another GlobalContextCheckpointState.
  450. Returns None if no delta is found, otherwise, return a set() of mismatched
  451. global key names.
  452. """
  453. r = set(self.global_state.keys()).difference(set(other.global_state.keys()))
  454. if len(r) == 0:
  455. return None
  456. return r
  457. def __eq__(self, other: object) -> bool:
  458. if not isinstance(other, GlobalContextCheckpointState):
  459. return False
  460. return self.diff(other) is None
  461. class GlobalContext(Checkpointable[GlobalContextCheckpointState]):
  462. """
  463. This keeps track of the global torch state during tracing of a function.
  464. For example, torch.is_grad_enabled.
  465. """
  466. _supported_global_states = {
  467. "grad_enabled",
  468. "autocast_enabled",
  469. "autocast_cpu_enabled",
  470. "autocast_gpu_dtype",
  471. "autocast_cpu_dtype",
  472. "autocast_cache_enabled",
  473. }
  474. def __init__(self) -> None:
  475. self.global_state: dict[str, tuple[Callable, Any]] = {}
  476. def copy_graphstate(self) -> GlobalContextCheckpointState:
  477. return GlobalContextCheckpointState(self.global_state)
  478. def restore_graphstate(self, state: GlobalContextCheckpointState) -> None:
  479. assert isinstance(state, GlobalContextCheckpointState)
  480. self.global_state = state.global_state
  481. assert (
  482. len(self.global_state) == len(self._supported_global_states)
  483. and set(self.global_state.keys()) == self._supported_global_states
  484. ), "Global state mismatch"
  485. for func, args in self.global_state.values():
  486. func(args)
  487. # Like a Set[Guard] but will record the user stack on all guards at the
  488. # time they were installed at their destination
  489. class GuardsSet:
  490. def __init__(self, inner: Optional[OrderedSet[Guard]] = None) -> None:
  491. if inner is None:
  492. self.inner: OrderedSet[Guard] = OrderedSet()
  493. else:
  494. self.inner = inner
  495. def __iter__(self) -> Iterator[Guard]:
  496. return iter(self.inner)
  497. def __len__(self) -> int:
  498. return len(self.inner)
  499. # Subtraction along with bool is typically used to determine the delta of
  500. # added guards between checkpoints for higher order ops
  501. def __sub__(self, other: GuardsSet) -> GuardsSet:
  502. return GuardsSet(self.inner - other.inner)
  503. def __bool__(self) -> bool:
  504. return bool(self.inner)
  505. def add(
  506. self, guard: Guard, *, collect_debug_stack: bool = True, skip: int = 0
  507. ) -> None:
  508. if guard in self.inner:
  509. return
  510. if collect_debug_stack:
  511. if guard.stack is None:
  512. guard.stack = CapturedTraceback.extract(skip=1 + skip)
  513. if guard.user_stack is None:
  514. guard.user_stack = TracingContext.extract_stack()
  515. self.inner.add(guard)
  516. def update(self, *others: set[Guard]) -> None:
  517. for o in others:
  518. for g in o:
  519. self.add(g, skip=1)
  520. def remove_guards_with_source(self, source: Source) -> None:
  521. """Delete all guards that contains a given source"""
  522. from ._dynamo.source import is_from_source
  523. self.inner = OrderedSet(
  524. g for g in self.inner if not is_from_source(g.originating_source, source)
  525. )
  526. """
  527. A GuardsContext is a checkpointable representation of all the guards in the current tracing
  528. context. It's lifecycle is bound 1:1 to the tracing context, and it should never be instantiated
  529. directly outside of it. For passing around internal state representations of this object,
  530. prefer to extract them with copy_graphstate to produce a GuardsCheckpointState.
  531. """
  532. class GuardsContext(Checkpointable[GuardsCheckpointState]):
  533. def __init__(self) -> None:
  534. self.dynamo_guards: GuardsSet = GuardsSet()
  535. self.aotautograd_guards: list[GuardEnvExpr] = []
  536. def copy_graphstate(self) -> GuardsCheckpointState:
  537. return GuardsCheckpointState(OrderedSet(self.dynamo_guards.inner))
  538. def restore_graphstate(self, state: GuardsCheckpointState) -> None:
  539. # NB: "steals" the passed in state
  540. assert isinstance(state, GuardsCheckpointState)
  541. self.dynamo_guards = GuardsSet(state.dynamo_guards)
  542. class HopSubgraphCache:
  543. @abstractmethod
  544. def add_dynamo_installed_submodule(self, fn_id: int, identifier: str) -> None: ...
  545. @abstractmethod
  546. def get_dynamo_installed_submodules(self, fn_id: int) -> list[str]: ...
  547. @abstractmethod
  548. def add_autograd_key_entry(self, identifier: str, key: Callable) -> None: ...
  549. @abstractmethod
  550. def get_autograd_key_entry(self, identifier: str) -> Callable | None: ...
  551. @abstractmethod
  552. def add_proxy_dispatch_entry(self, identifier: str, key: Callable) -> None: ...
  553. @abstractmethod
  554. def get_proxy_dispatch_entry(self, identifier: str) -> Callable | None: ...
  555. @abstractmethod
  556. def add_lazy_bwd_entry(
  557. self,
  558. identifier: str,
  559. tangent_metadata: tuple[object],
  560. gmod: torch.fx.GraphModule,
  561. ) -> int: ...
  562. @abstractmethod
  563. def get_lazy_bwd_entry(
  564. self, identifier: str, tangent_metadata: tuple[object]
  565. ) -> tuple[torch.fx.GraphModule | None, int | None]: ...
  566. class InvokeSubgraphCache(HopSubgraphCache):
  567. def __init__(self) -> None:
  568. self.autograd_cache: dict[str, Callable] = {}
  569. self.proxy_dispatch_cache: dict[str, Callable] = {}
  570. self.dynamo_installed_submodules: dict[int, list[str]] = defaultdict(list)
  571. self.lazy_bwd_cache: dict[
  572. str, dict[tuple[object], tuple[torch.fx.GraphModule, int]]
  573. ] = defaultdict(dict)
  574. self.effects_cache: dict[
  575. str, set
  576. ] = {} # Maps identifier -> set of effect types
  577. def add_dynamo_installed_submodule(self, fn_id: int, identifier: str) -> None:
  578. self.dynamo_installed_submodules[fn_id].append(identifier)
  579. def get_dynamo_installed_submodules(self, fn_id: int) -> list[str]:
  580. return self.dynamo_installed_submodules.get(fn_id, [])
  581. def add_autograd_key_entry(self, identifier: str, key: Callable) -> None:
  582. self.autograd_cache[identifier] = key
  583. def get_autograd_key_entry(self, identifier: str) -> Callable | None:
  584. return self.autograd_cache.get(identifier, None)
  585. def add_proxy_dispatch_entry(self, identifier: str, key: Callable) -> None:
  586. self.proxy_dispatch_cache[identifier] = key
  587. def get_proxy_dispatch_entry(self, identifier: str) -> Callable | None:
  588. return self.proxy_dispatch_cache.get(identifier, None)
  589. def add_lazy_bwd_entry(
  590. self,
  591. identifier: str,
  592. tangent_metadata: tuple[object],
  593. gmod: torch.fx.GraphModule,
  594. ) -> int:
  595. # Save the number of existing graph modules in the dictionary to get the suffix
  596. num_gmods = len(self.lazy_bwd_cache[identifier])
  597. self.lazy_bwd_cache[identifier][tangent_metadata] = (gmod, num_gmods)
  598. return num_gmods
  599. def get_lazy_bwd_entry(
  600. self, identifier: str, tangent_metadata: tuple[object]
  601. ) -> tuple[torch.fx.GraphModule | None, int | None]:
  602. if identifier not in self.lazy_bwd_cache:
  603. return (None, None)
  604. return self.lazy_bwd_cache[identifier].get(tangent_metadata, (None, None))
  605. def add_effects(self, identifier: str, effects: set) -> None:
  606. """Store the effect types for a given invoke_subgraph identifier."""
  607. if prev_effects := self.effects_cache.get(identifier, None):
  608. assert effects == prev_effects, (
  609. "Different number of effects were found for invoke_subgraph "
  610. f"call with identifier {identifier}. \n"
  611. f"Previously we had the following effects: {prev_effects}.\n"
  612. f"But now we have: {effects}."
  613. )
  614. self.effects_cache[identifier] = effects
  615. def get_effects(self, identifier: str) -> set | None:
  616. """Retrieve the effect types for a given invoke_subgraph identifier."""
  617. return self.effects_cache.get(identifier, None)
  618. class HopDispatchSetCache:
  619. def __init__(self) -> None:
  620. # Delayed import to avoid circular dependency
  621. from torch._higher_order_ops.invoke_subgraph import invoke_subgraph
  622. self.hop_cache_map = {invoke_subgraph: InvokeSubgraphCache()}
  623. def get_cache(self, op: torch._ops.HigherOrderOperator) -> HopSubgraphCache | None:
  624. if op not in self.hop_cache_map:
  625. return None
  626. return self.hop_cache_map[op] # type: ignore[index]
  627. _TLS = threading.local()
  628. """
  629. TracingContext is the source of truth for all currently accumulated information
  630. needed to trace. Its lifecycle is kept 1:1 when using TorchDynamo, but other systems
  631. are open to managing their own TracingContext with that in mind.
  632. The purpose of TracingContext is not to be a dumping ground, or god object, but rather to avoid
  633. having to plumb complex subsystems across multiple verticals.
  634. Ex: A common example is guard accumulation between dynamo, shape_env, aot_autograd, and inductor.
  635. Accessing the current tracing context via
  636. TracingContext.get() allows users to accumulate their own guards for processing, without needing to know how
  637. to plumb objects back up to where frame interpretation happened.
  638. Note that you can end up with multiple TracingContext for a single compilation
  639. of a frame, as we reset the TracingContext whenever we restart analysis.
  640. CompileContext is a more overarching context that encompasses multiple restarts.
  641. """
  642. class CompileContext:
  643. @staticmethod
  644. def get() -> CompileContext:
  645. assert _TLS.compile_context is not None
  646. return _TLS.compile_context
  647. @staticmethod
  648. def try_get() -> CompileContext | None:
  649. return getattr(_TLS, "compile_context", None)
  650. def __init__(self, compile_id: CompileId | None) -> None:
  651. assert compile_id is None or isinstance(compile_id, CompileId)
  652. self.compile_id: CompileId | None = compile_id
  653. self.attempt = 0
  654. # Verbose ShapeEnv guards produced.
  655. self.shape_env_guards: list[str] = []
  656. @staticmethod
  657. def current_compile_id() -> CompileId | None:
  658. self = CompileContext.try_get()
  659. if self is None:
  660. return None
  661. return self.compile_id
  662. @staticmethod
  663. def current_trace_id() -> TraceId | None:
  664. self = CompileContext.try_get()
  665. if self is None:
  666. return None
  667. if self.compile_id is None:
  668. return None
  669. return TraceId(self.compile_id, self.attempt)
  670. class TracingContext:
  671. """
  672. Provides the currently installed TracingContext, or None.
  673. Note that it is a staticmethod, and invocations outside of `with tracing()` (see below), are valid but
  674. will return None.
  675. """
  676. @staticmethod
  677. def try_get() -> TracingContext | None:
  678. return getattr(_TLS, "tracing_context", None)
  679. @staticmethod
  680. def get() -> TracingContext:
  681. if ctx := TracingContext.try_get():
  682. return ctx
  683. raise RuntimeError(
  684. "TracingContext.get() must be called within an ongoing trace."
  685. )
  686. def __init__(self, fake_mode: FakeTensorMode | None) -> None:
  687. self.guards_context = GuardsContext()
  688. self.module_context = ModuleContext()
  689. self.global_context = GlobalContext()
  690. self.previously_inlined_functions: dict[Any, Any] = dict()
  691. self.previously_cleaned_instructions: dict[Any, Any] = dict()
  692. self.fake_mode: FakeTensorMode | None = fake_mode
  693. self.frame_summary_stack: list[traceback.FrameSummary] = []
  694. # This is morally part of frame_summary_stack, but it is kept separate
  695. # for clarity. As we process a frame, this variable gets updated
  696. # to keep track of what line we are in the function. We make a
  697. # function call, this gets cleared and the frame location is pushed
  698. # to frame_summary_stack (prepping this variable for the inner frame's
  699. # progress)
  700. self.loc_in_frame: tuple[str, int, str] | None = None
  701. # this is only set after aot_autograd
  702. self.fw_metadata: ViewAndMutationMeta | None = None
  703. # this is only set when the DDPOptimizer is used
  704. self.ddp_optimizer_ctx: DDPOptimizerContext | None = None
  705. # this is only set after aot_autograd
  706. self.aot_graph_name: list[str] | None = None
  707. self.params_flat: list[Any] | None = None
  708. self.params_flat_unwrap_subclasses: list[Any] | None = None
  709. self.params_unwrapped_to_flat_index: list[Any] | None = None
  710. # this is for extended return calling convention from backend
  711. # compiler to aot_autograd
  712. # Per output, what the compiler specified stride of the output is,
  713. # or None if no stride is known. This is always the HINT, it
  714. # is never a SymInt (it would be better if it was a SymInt, but
  715. # I can't conveniently get this from Inductor atm. Also, be
  716. # careful not to accidentally induce guards on the SymInt if
  717. # you ever do change this in aot_autograd.py; you should check
  718. # on permutations preferentially.)
  719. self.output_strides: list[tuple[int, ...] | None] | None = None
  720. # When this is True, whenever we encounter an int in Dynamo tracing,
  721. # we will (1) force unspec it and (2) force it as a size-like unbacked
  722. # integer. This is currently used when processing certain lists of
  723. # ints that are known to be size-like and may have 0/1 entries that we
  724. # must not specialize on.
  725. self.force_unspec_int_unbacked_size_like = False
  726. # See note [Tensor Fakification and Symbol Caching]
  727. self.tensor_to_context = WeakTensorKeyDictionary()
  728. # If this true, Aot Autograd will return output Fake Tensors with appropriate
  729. # meta on the first invocation
  730. # see note: [Returning Fake Tensors on First AOT Autograd Call]
  731. self.fakify_first_call = False
  732. self.hop_dispatch_set_cache = HopDispatchSetCache()
  733. # list of code objects for inlined functions
  734. self.traced_code: list[CodeType] = []
  735. def clear(self) -> None:
  736. # Look at the note in output_graph.py in function `save_global_state`
  737. # for the context on clearing global context.
  738. self.global_context.global_state = {}
  739. self.previously_inlined_functions.clear()
  740. self.previously_cleaned_instructions.clear()
  741. @staticmethod
  742. @contextmanager
  743. def patch(**kwargs: Any) -> Generator[None, None, None]:
  744. prior = {}
  745. ctx = TracingContext.get()
  746. for key in kwargs:
  747. # KeyError on invalid entry
  748. prior[key] = getattr(ctx, key)
  749. for key, val in kwargs.items():
  750. setattr(ctx, key, val)
  751. try:
  752. yield
  753. finally:
  754. for key, val in prior.items():
  755. setattr(ctx, key, val)
  756. @staticmethod
  757. def extract_stack() -> traceback.StackSummary:
  758. self = TracingContext.try_get()
  759. if self is None:
  760. return traceback.StackSummary()
  761. stack = self.frame_summary_stack
  762. if self.loc_in_frame is not None:
  763. stack = stack + [self._populate_loc_in_frame_summary()]
  764. return traceback.StackSummary.from_list(stack)
  765. def _populate_loc_in_frame_summary(self) -> traceback.FrameSummary:
  766. assert self.loc_in_frame is not None
  767. filename, lineno, frame_name = self.loc_in_frame
  768. return traceback.FrameSummary(filename, lineno, frame_name, lookup_line=False)
  769. # Call this when you want to call into some code that isn't necessarily
  770. # associated with the current frame state
  771. @staticmethod
  772. @contextlib.contextmanager
  773. def clear_frame() -> Generator[None, None, None]:
  774. tc = TracingContext.get()
  775. with (
  776. unittest.mock.patch.object(tc, "frame_summary_stack", []),
  777. unittest.mock.patch.object(tc, "loc_in_frame", None),
  778. ):
  779. try:
  780. yield
  781. except Exception as e:
  782. # Prevent real_stack from getting attached
  783. #
  784. # The invariant is that if an Exception as real_stack, we've
  785. # appropriately attached a user stack and we no longer need to
  786. # attach anything. Because we cannot conveniently interpose
  787. # when an exception is thrown, we instead interpose everywhere
  788. # we set what the user stack is set (using the context
  789. # manager). However, our compiler stack does "tail calls"
  790. # (when it calls into user compiler), at which point the
  791. # parent exception frames would incorrectly attach an
  792. # incorrect frame.
  793. #
  794. # However, if, somehow, someone raised an exception with this
  795. # scope that had a stack (for example, because they are
  796. # restoring the user stack state appropriately as they process
  797. # node by node), we should respect it. Thus, we cannot
  798. # unconditionally set None.
  799. if not hasattr(e, "real_stack"):
  800. e.real_stack = None # type: ignore[attr-defined]
  801. raise
  802. @staticmethod
  803. @contextlib.contextmanager
  804. def current_frame(
  805. frame_summary: traceback.FrameSummary | None,
  806. ) -> Generator[None, None, None]:
  807. # frame_summary can be None to solely take advantage of real_stack
  808. # attachment to thrown exceptions
  809. tc = TracingContext.get()
  810. if frame_summary is not None:
  811. tc.frame_summary_stack.append(frame_summary)
  812. old = tc.loc_in_frame
  813. tc.loc_in_frame = None
  814. try:
  815. yield
  816. except Exception as e:
  817. if not hasattr(e, "real_stack"):
  818. e.real_stack = tc.extract_stack() # type: ignore[attr-defined]
  819. raise
  820. finally:
  821. if frame_summary is not None:
  822. tc.frame_summary_stack.pop()
  823. tc.loc_in_frame = old
  824. @staticmethod
  825. @contextlib.contextmanager
  826. def report_output_strides() -> Generator[
  827. list[tuple[int, ...] | None] | None, None, None
  828. ]:
  829. tc = TracingContext.try_get()
  830. if tc is None:
  831. yield None
  832. return
  833. old_output_strides = tc.output_strides
  834. tc.output_strides = []
  835. try:
  836. yield tc.output_strides
  837. finally:
  838. tc.output_strides = old_output_strides
  839. @staticmethod
  840. def set_current_loc(filename: str, lineno: int, frame_name: str) -> None:
  841. # Save the current location in the frame. Lazily generate the
  842. # framesummary.
  843. TracingContext.get().loc_in_frame = (filename, lineno, frame_name)
  844. @staticmethod
  845. def get_traced_code() -> list[CodeType] | None:
  846. tc = TracingContext.try_get()
  847. if tc is None:
  848. return None
  849. return tc.traced_code
  850. @contextmanager
  851. def compile_context(
  852. context: CompileContext | None,
  853. ) -> Generator[CompileContext | None, None, None]:
  854. old_context = getattr(_TLS, "compile_context", None)
  855. _TLS.compile_context = context
  856. try:
  857. yield context
  858. finally:
  859. _TLS.compile_context = old_context
  860. @contextmanager
  861. def tracing(
  862. context: TracingContext | None,
  863. ) -> Generator[TracingContext | None, None, None]:
  864. """
  865. This function installs the passed in tracing context as a dynamic scoped
  866. global variable.
  867. Calls to TracingContext.get() while not under a `with tracing()` context
  868. will return None.
  869. """
  870. old_context = getattr(_TLS, "tracing_context", None)
  871. _TLS.tracing_context = context
  872. try:
  873. yield context
  874. except Exception as e:
  875. if not hasattr(e, "real_stack") and context is not None:
  876. e.real_stack = context.extract_stack() # type: ignore[attr-defined]
  877. raise
  878. finally:
  879. if (
  880. context is not None
  881. and context.fake_mode is not None
  882. and context.fake_mode.shape_env is not None
  883. ):
  884. context.fake_mode.shape_env.cleanup()
  885. _TLS.tracing_context = old_context
  886. @overload
  887. def dataclass_with_cached_hash(cls: type[T], **kwargs: Any) -> type[T]: ...
  888. @overload
  889. def dataclass_with_cached_hash(
  890. cls: None = None, **kwargs: Any
  891. ) -> Callable[[type[T]], type[T]]: ...
  892. @dataclass_transform()
  893. def dataclass_with_cached_hash(
  894. cls: type[T] | None = None, **kwargs: Any
  895. ) -> type[T] | Callable[[type[T]], type[T]]:
  896. def wrap(cls_inner: type[T]) -> type[T]:
  897. new_cls = dataclasses.dataclass(cls_inner, **kwargs)
  898. old_hash = cls_inner.__hash__
  899. def __hash__(self) -> int:
  900. if not hasattr(self, "_hash"):
  901. object.__setattr__(self, "_hash", old_hash(self))
  902. return self._hash
  903. def __reduce__(self):
  904. # Exclude _hash from pickling to ensure deterministic cache keys.
  905. # The _hash is a cached value that can be nondeterministically computed
  906. # (e.g., based on id() of objects), so it should not affect pickling.
  907. fields = dataclasses.fields(self)
  908. field_values = tuple(getattr(self, f.name) for f in fields)
  909. return (self.__class__, field_values)
  910. new_cls.__hash__ = __hash__
  911. new_cls.__reduce__ = __reduce__
  912. return new_cls # type: ignore[return-value]
  913. if cls is None:
  914. return wrap
  915. return wrap(cls)
  916. # Subclasses can be found in torch/_dynamo/source.py
  917. # TODO(voz): Consider a toplevel torch/_source.py
  918. @dataclass_with_cached_hash(frozen=True)
  919. class Source:
  920. def is_dict_key(self) -> bool:
  921. return False
  922. def is_ephemeral(self) -> bool:
  923. return False
  924. def reconstruct(self, codegen: PyCodegen) -> None:
  925. raise NotImplementedError
  926. @functools.cached_property
  927. def guard_source(self) -> GuardSource:
  928. raise NotImplementedError
  929. @property
  930. def _name_template(self) -> str:
  931. """
  932. A template for the name of the source. Used to prevent code duplication between
  933. `name` and `get_value`.
  934. For non-ChainedSources, `name` and `get_value` use the returned string directly.
  935. For ChainedSources, `name` and `get_value` expect the return to be a format string
  936. with `{0}` present - `name` and `get_value` will apply different values to this function's
  937. returned format string.
  938. """
  939. raise NotImplementedError
  940. @functools.cached_property
  941. def name(self) -> str:
  942. return self._name_template
  943. def get_value(
  944. self,
  945. globals: dict[str, Any],
  946. locals: dict[str, Any],
  947. cache: weakref.WeakKeyDictionary[Source, Any],
  948. ) -> Any:
  949. if self in cache:
  950. return cache[self]
  951. value = eval(self._name_template, globals, locals)
  952. cache[self] = value
  953. return value
  954. def make_guard(self, fn: Callable[..., Any]) -> Guard:
  955. if self.guard_source is GuardSource.CONSTANT:
  956. raise NotImplementedError
  957. return Guard(self, fn)
  958. def is_specialized_nn_module(self) -> bool:
  959. return self.guard_source.is_specialized_nn_module()
  960. def subguards_allowed(self) -> bool:
  961. """True if you can guard on attributes of this"""
  962. return self.guard_source != GuardSource.SYNTHETIC_LOCAL
  963. # Subclasses can be found in torch/_dynamo/source.py
  964. @dataclass_with_cached_hash(frozen=True)
  965. class ChainedSource(Source):
  966. base: Source
  967. def is_dict_key(self) -> bool:
  968. # Recurse until you either hit a ConstDictKey or a Source
  969. return self.base.is_dict_key()
  970. def is_ephemeral(self) -> bool:
  971. return self.base.is_ephemeral()
  972. @functools.cached_property
  973. def guard_source(self) -> GuardSource:
  974. return self.base.guard_source
  975. def get_base(self) -> Source:
  976. current: Source = self
  977. while isinstance(current, ChainedSource):
  978. current = current.base
  979. return current
  980. @functools.cached_property
  981. def name(self) -> str:
  982. return self._name_template.format(self.base.name)
  983. def get_value(
  984. self,
  985. globals: dict[str, Any],
  986. locals: dict[str, Any],
  987. cache: weakref.WeakKeyDictionary[Source, Any],
  988. ) -> Any:
  989. if self in cache:
  990. return cache[self]
  991. tmpvar = "tmp"
  992. counter = 0
  993. while tmpvar in locals:
  994. tmpvar = f"tmp{counter}"
  995. counter += 1
  996. locals[tmpvar] = self.base.get_value(globals, locals, cache)
  997. value = eval(self._name_template.format(tmpvar), globals, locals)
  998. del locals[tmpvar]
  999. cache[self] = value
  1000. return value
  1001. def detect_fake_mode(inputs: Any = None) -> FakeTensorMode | None:
  1002. """
  1003. Attempts to "detect" what the current fake mode is. If there is one ambiently
  1004. available from TracingContext, we preferentially use that. Otherwise, we
  1005. heuristically detect the fake mode via the following sources, in order of
  1006. priority:
  1007. - Currently active fake mode on stack
  1008. - Fake mode associated with passed in tensors (inputs does not
  1009. have to be flattened)
  1010. """
  1011. from torch._subclasses.fake_tensor import (
  1012. FakeTensor,
  1013. FakeTensorMode,
  1014. get_plain_tensors,
  1015. )
  1016. fake_modes = []
  1017. if context := TracingContext.try_get():
  1018. fake_mode = context.fake_mode
  1019. if fake_mode is not None:
  1020. fake_modes.append((fake_mode, "tracing context", 0))
  1021. from torch.utils._python_dispatch import _get_current_dispatch_mode_stack
  1022. for i, m in enumerate(reversed(_get_current_dispatch_mode_stack())):
  1023. if isinstance(m, FakeTensorMode):
  1024. # pyrefly: ignore [bad-argument-type]
  1025. fake_modes.append((m, "active fake mode", i))
  1026. flat_inputs = pytree.tree_leaves(inputs)
  1027. for i, flat_input in enumerate(flat_inputs):
  1028. if isinstance(flat_input, FakeTensor):
  1029. # pyrefly: ignore [bad-argument-type]
  1030. fake_modes.append((flat_input.fake_mode, "fake tensor input", i))
  1031. if is_traceable_wrapper_subclass(flat_input):
  1032. out: list[torch.Tensor | int | torch.SymInt] = []
  1033. get_plain_tensors(flat_input, out=out) # type: ignore[arg-type]
  1034. fake_tensors: list[FakeTensor] = [
  1035. x for x in out if isinstance(x, FakeTensor)
  1036. ]
  1037. fake_modes.extend(
  1038. # pyrefly: ignore [bad-argument-type]
  1039. [
  1040. (tensor.fake_mode, f"subclass input {i}", ix)
  1041. for ix, tensor in enumerate(fake_tensors)
  1042. ]
  1043. )
  1044. if fake_modes:
  1045. fake_mode, desc1, i1 = fake_modes[0]
  1046. for m, desc2, i2 in fake_modes[1:]:
  1047. assert fake_mode is m, (
  1048. f"fake mode ({fake_mode}) from {desc1} {i1} doesn't match mode ({m}) from {desc2} {i2}\n\n"
  1049. # pyrefly: ignore [missing-attribute]
  1050. f"fake mode from {desc1} {i1} allocated at:\n{fake_mode.stack}\n"
  1051. # pyrefly: ignore [missing-attribute]
  1052. f"fake mode from {desc2} {i2} allocated at:\n{m.stack}"
  1053. )
  1054. # pyrefly: ignore [bad-return]
  1055. return fake_mode
  1056. else:
  1057. return None
  1058. def active_fake_mode() -> FakeTensorMode | None:
  1059. """
  1060. Inspects the dispatch mode stack for an active fake mode and returns it.
  1061. Returns None if no fake mode is active.
  1062. """
  1063. from torch._subclasses.fake_tensor import FakeTensorMode
  1064. from torch.utils._python_dispatch import _get_current_dispatch_mode_stack
  1065. for _, m in enumerate(reversed(_get_current_dispatch_mode_stack())):
  1066. if isinstance(m, FakeTensorMode):
  1067. return m
  1068. return None