| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693694695696697698699700701702
70370470570670770870971071171271371471571671771871972072172272372472572672772872973073173273373473573673773873974074174274374474574674774874975075175275375475575675775875976076176276376476576676776876977077177277377477577677777877978078178278378478578678778878979079179279379479579679779879980080180280380480580680780880981081181281381481581681781881982082182282382482582682782882983083183283383483583683783883984084184284384484584684784884985085185285385485585685785885986086186286386486586686786886987087187287387487587687787887988088188288388488588688788888989089189289389489589689789889990090190290390490590690790890991091191291391491591691791891992092192292392492592692792892993093193293393493593693793893994094194294394494594694794894995095195295395495595695795895996096196296396496596696796896997097197297397497597697797897998098198298398498598698798898999099199299399499599699799899910001001100210031004100510061007100810091010101110121013101410151016101710181019102010211022102310241025102610271028102910301031103210331034103510361037103810391040104110421043104410451046104710481049105010511052105310541055105610571058105910601061106210631064106510661067106810691070107110721073107410751076107710781079108010811082108310841085108610871088108910901091109210931094109510961097109810991100110111021103110411051106110711081109111011111112111311141115111611171118111911201121112211231124112511261127112811291130113111321133113411351136113711381139114011411142114311441145114611471148114911501151115211531154115511561157115811591160116111621163116411651166116711681169117011711172117311741175117611771178117911801181118211831184118511861187118811891190119111921193119411951196119711981199120012011202120312041205120612071208120912101211121212131214121512161217121812191220122112221223122412251226122712281229123012311232123312341235123612371238123912401241124212431244124512461247124812491250125112521253125412551256125712581259126012611262126312641265126612671268126912701271127212731274127512761
277127812791280128112821283128412851286128712881289129012911292129312941295129612971298129913001301130213031304130513061307130813091310131113121313131413151316131713181319 |
- from __future__ import annotations
- import contextlib
- import dataclasses
- import enum
- import functools
- import logging
- import re
- import sys
- import threading
- import traceback
- import unittest.mock
- import weakref
- from abc import abstractmethod
- from collections import defaultdict
- from contextlib import contextmanager
- from dataclasses import dataclass
- from typing import Any, Generic, NamedTuple, Optional, overload, TYPE_CHECKING, TypeVar
if sys.version_info >= (3, 11):
    from typing import dataclass_transform
else:

    def dataclass_transform(**kwargs: Any):
        """No-op stand-in for ``typing.dataclass_transform`` on Python < 3.11.

        The real ``typing.dataclass_transform`` accepts keyword-only options
        (``eq_default``, ``kw_only_default``, ``field_specifiers``, ...). This
        fallback accepts and ignores any keyword arguments, so decorator call
        sites written against the 3.11 signature keep working on older
        interpreters. The returned decorator returns its argument unchanged.
        """

        def decorator(fn):
            return fn

        return decorator
- import torch
- from torch.utils import _pytree as pytree
- from torch.utils._ordered_set import OrderedSet
- from torch.utils._python_dispatch import is_traceable_wrapper_subclass
- from torch.utils._traceback import CapturedTraceback, format_frame
- from torch.utils.weak import WeakTensorKeyDictionary
# Module-level logger named after this module; Guard.create uses it to report
# failures while building a guard.
log = logging.getLogger(__name__)
- if TYPE_CHECKING:
- from collections.abc import Callable, Generator, Iterator
- from types import CodeType
- import sympy
- from torch._dynamo.backends.distributed import DDPOptimizerContext
- from torch._dynamo.codegen import PyCodegen
- from torch._functorch._aot_autograd.schemas import ViewAndMutationMeta
- from torch._subclasses.fake_tensor import FakeTensorMode
- """
- torch._guards is the definitional source of truth for general purpose guard structures.
- An important thing to keep in mind here is the preservation of layering. There should be no dynamo notions,
- and no guard installation notions here.
- """
# Parses the string form of a regular CompileId: "<frame_id>/<frame_compile_id>",
# e.g. "0/1". Used by CompileId.from_string.
# NOTE(review): `$` also matches just before a trailing "\n", so "0/1\n" is
# accepted as well; confirm that is intended.
COMPILE_ID_PATTERN = re.compile(r"^(?P<frame_id>\d+)/(?P<frame_compile_id>\d+)$")
# Parses the string form of a compiled-autograd CompileId:
# "!<compiled_autograd_id>", optionally followed by
# "/<frame_id>/<frame_compile_id>", e.g. "!2" or "!2/0/1".
CA_COMPILE_ID_PATTERN = re.compile(
    r"^!(?P<compiled_autograd_id>\d+)(?:/(?P<frame_id>\d+)/(?P<frame_compile_id>\d+))?$"
)
- # [Note: Updating CompileId]
- #
- # CompileId represents a unique program-level identifier, and we want to keep that
- # property as the codebase evolves. This property is relied on even outside of the pytorch
- # repo, e.g. tlparse or other internal tooling. The in-memory format can be freely changed,
- # as those dependencies only consume the string serialization.
- #
- # The string form should be:
- # 1. Program-level uid: CompileId can uniquely identify a compiled graph.
- # 2. Storage efficient: This object is logged in nearly every entry. We should elide symbols when possible.
- # 3. Compact: The string form is directly displayed by some tools. Special symbols are okay.
- @dataclass(frozen=True, kw_only=True, slots=True)
- class CompileId:
- frame_id: int | None
- # This id is per-frame, and counts how many times we've compiled this
- # frame. This could have been a global id but having this be per-frame
- # gives you a better intuitive sense for how many recompiles have occurred
- # so far.
- frame_compile_id: int | None
- # torch.compiling a compiled autograd graph
- compiled_autograd_id: int | None = None
- # TODO: consider also tracking the recompilation count
- # See Note: Updating CompileId
- def __str__(self) -> str:
- # NOTE: Keep this in sync with both from_string and the tlparse repo
- if self.compiled_autograd_id is not None:
- assert (self.frame_id is None) == (self.frame_compile_id is None)
- frame_str = ""
- if self.frame_id is not None:
- frame_str = f"/{self.frame_id}/{self.frame_compile_id}"
- return f"!{self.compiled_autograd_id}{frame_str}"
- else:
- assert self.frame_id is not None and self.frame_compile_id is not None
- return f"{self.frame_id}/{self.frame_compile_id}"
- @classmethod
- def from_string(cls, compile_id: str | None) -> CompileId | None:
- """
- Factory method that creates a CompileId from its string representation.
- Keep this in sync with the __str__ method.
- """
- if compile_id is None:
- return None
- try:
- for pattern in (COMPILE_ID_PATTERN, CA_COMPILE_ID_PATTERN):
- if match := pattern.match(compile_id):
- groups = match.groupdict()
- for k, v in groups.items():
- if v is not None:
- groups[k] = int(v)
- return cls(**groups) # type: ignore[arg-type]
- else:
- raise ValueError
- except Exception as e:
- raise ValueError(f"Invalid compile_id '{compile_id}'") from e
class TraceId(NamedTuple):
    """A CompileId paired with the analysis-restart attempt number."""

    compile_id: CompileId
    # This starts off as 0, and every time we restart analysis it goes
    # up by one
    attempt: int

    def __str__(self) -> str:
        # Keep this in sync with tlparse repo: the first attempt is rendered
        # as the bare compile id, later ones get an "_<attempt>" suffix.
        base = str(self.compile_id)
        return base if self.attempt == 0 else f"{base}_{self.attempt}"
class GuardSource(enum.Enum):
    """Kinds of origin a guarded value can have.

    The predicate methods below group related members (local vs. global,
    specialized / unspecialized / FSDP nn modules).
    """

    LOCAL = 0
    GLOBAL = 1
    LOCAL_SPECIALIZED_NN_MODULE = 2
    GLOBAL_SPECIALIZED_NN_MODULE = 3
    CONSTANT = 4
    RANDOM_VALUE = 5
    SHAPE_ENV = 6
    LOCAL_FSDP_MODULE = 7
    GLOBAL_FSDP_MODULE = 8
    BACKWARD_STATE = 9
    EPHEMERAL = 10
    SYNTHETIC_LOCAL = 11
    LOCAL_UNSPECIALIZED_NN_MODULE = 12
    GLOBAL_UNSPECIALIZED_NN_MODULE = 13
    LOCAL_UNSPECIALIZED_BUILTIN_NN_MODULE = 14
    GLOBAL_UNSPECIALIZED_BUILTIN_NN_MODULE = 15
    TEMP_LOCAL = 16

    def is_fsdp_module(self) -> bool:
        """True for the local and global FSDP-module sources."""
        return self in (GuardSource.GLOBAL_FSDP_MODULE, GuardSource.LOCAL_FSDP_MODULE)

    def is_specialized_nn_module(self) -> bool:
        """True for specialized nn-module sources.

        When ``torch._dynamo.config._unsafe_skip_fsdp_module_guards`` is set,
        FSDP-module sources are treated as specialized too.
        """
        import torch._dynamo.config as config

        specialized = self in (
            GuardSource.GLOBAL_SPECIALIZED_NN_MODULE,
            GuardSource.LOCAL_SPECIALIZED_NN_MODULE,
        )
        if specialized:
            return True
        return bool(config._unsafe_skip_fsdp_module_guards) and self.is_fsdp_module()

    def is_unspecialized_nn_module(self) -> bool:
        """True for any unspecialized nn-module source, builtin or not."""
        return self.is_unspecialized_builtin_nn_module() or self in (
            GuardSource.GLOBAL_UNSPECIALIZED_NN_MODULE,
            GuardSource.LOCAL_UNSPECIALIZED_NN_MODULE,
        )

    def is_unspecialized_builtin_nn_module(self) -> bool:
        """True for the unspecialized *builtin* nn-module sources."""
        return self in (
            GuardSource.GLOBAL_UNSPECIALIZED_BUILTIN_NN_MODULE,
            GuardSource.LOCAL_UNSPECIALIZED_BUILTIN_NN_MODULE,
        )

    def is_local(self) -> bool:
        """True for LOCAL and the LOCAL_* nn-module / FSDP sources.

        SYNTHETIC_LOCAL and TEMP_LOCAL are not included.
        """
        local_members = (
            GuardSource.LOCAL,
            GuardSource.LOCAL_SPECIALIZED_NN_MODULE,
            GuardSource.LOCAL_FSDP_MODULE,
            GuardSource.LOCAL_UNSPECIALIZED_NN_MODULE,
            GuardSource.LOCAL_UNSPECIALIZED_BUILTIN_NN_MODULE,
        )
        return self in local_members
- """
- Base class for a "GuardBuilder" role.
- The GuardBuilderBase role is to represent a scope within which to build a guard. The name is a little
- confusing, as it's not a builder, but it is kept to avoid a lot of renames and to preserve the original reference
- to torchdynamo's GuardBuilder.
- Note: create_fn is invoked with a GuardBuilderBase and a Guard. A GuardBuilder is chosen based
- on GuardSource's select function.
- There is value in keeping this GuardBuilderBase empty to keep layering clean.
- """
class GuardBuilderBase:
    # Intentionally empty base type: Guard.create passes an instance of (a
    # subclass of) this type, together with the Guard, to the guard's
    # create_fn. Kept empty to keep layering clean.
    pass
@dataclasses.dataclass(frozen=True)
class SLoc:
    """A source location: a framework location plus an optional user location."""

    # Either a captured frame (formatted with format_frame) or a ready-made string.
    framework_loc: traceback.FrameSummary | str | None
    maybe_user_loc: str | None

    def __str__(self) -> str:
        """Render as "<user_loc> (<framework_loc>)", or "(<framework_loc>)"."""
        framework = self.framework_loc
        if not isinstance(framework, str):
            framework = format_frame(framework)
        if self.maybe_user_loc is None:
            return f"({framework})"
        return f"{self.maybe_user_loc} ({framework})"
class ShapeGuard(NamedTuple):
    # The boolean (sympy) expression being guarded on.
    expr: sympy.logic.boolalg.Boolean
    # Source location associated with this guard.
    sloc: SLoc
    # NOTE(review): presumably marks a guard produced by size-oblivious
    # reasoning; confirm against the producers of ShapeGuard.
    size_oblivious: bool
- @dataclasses.dataclass(slots=True)
- class Guard:
- # originating_source is the source that called the make_guard method to
- # construct this guard object. The property name specifies what exactly it
- # is the guard is guarding on. The meaning of the name is dependent on the
- # create_fn; you must look at the use-site inside create_fn to know what
- # name means.
- #
- # That being said, although you might think this is just a "name", name is
- # usually an arbitrary Python expression that will be evaluated with all
- # globals (and locals, if you create a LOCAL guard) to extract the Python
- # object that we want to perform guard tests on. This evaluation
- # typically happens in GuardBuilder.eval. In these cases, name is
- # typically produced by originating_source.name (not to be confused with
- # GuardSource - the property source).
- #
- # Occasionally, name is not a valid Python expression; sometimes
- # it is meaningless. Example create_fns that are like this include
- # GRAD_MODE and SHAPE_ENV.
- originating_source: Source
- create_fn: Callable[[GuardBuilderBase, Guard], None]
- # Export only. These values are written to at time of guard check_fn creation.
- guard_types: list[str] | None = None
- code_list: list[str] | None = None
- obj_weakref: object | None = None
- guarded_class_weakref: weakref.ReferenceType[Any] | None = None
- stack: CapturedTraceback | None = None
- user_stack: traceback.StackSummary | None = None
- _hash: int | None = None
- _unserializable: bool = False
- def __hash__(self) -> int:
- if self._hash is None:
- self._hash = hash((self.name, self.source, id(self.create_fn)))
- return self._hash
- def sort_key(self) -> tuple[bool, int, int, str, int]:
- # Put the duplicate input guards at the end. The duplicate guards have
- # two sources while guard.name only considers one source.
- is_duplicate_input = (
- isinstance(self.create_fn, functools.partial)
- and self.create_fn.func is torch._dynamo.guards.GuardBuilder.DUPLICATE_INPUT
- )
- return (
- is_duplicate_input,
- self.source.value if self.source else -1,
- len(self.name),
- self.name,
- self.inner_create_fn().__code__.co_firstlineno,
- )
- def __lt__(self, other: Guard) -> bool:
- return self.sort_key() < other.sort_key()
- def inner_create_fn(self) -> Callable[[GuardBuilderBase, Guard], Any]:
- if isinstance(self.create_fn, functools.partial):
- return self.create_fn.func
- else:
- return self.create_fn
- @property
- def name(self) -> str:
- return self.originating_source.name
- @property
- def source(self) -> GuardSource:
- return self.originating_source.guard_source
- @staticmethod
- def weakref_to_str(obj_weakref: object) -> str:
- """
- This is a workaround of a Python weakref bug.
- `obj_weakref` is instance returned by `weakref.ref`,
- `str(obj_weakref)` is buggy if the original obj overrides __getattr__, e.g:
- class MyConfig(dict):
- def __getattr__(self, x):
- return self[x]
- obj = MyConfig(offset=5)
- obj_weakref = weakref.ref(obj)
- str(obj_weakref) # raise error: KeyError: '__name__'
- """
- if isinstance(obj_weakref, weakref.ReferenceType):
- obj = obj_weakref()
- if obj is not None:
- return f"<weakref at {hex(id(obj_weakref))}; to '{obj.__class__.__name__}' at {hex(id(obj))}>"
- else:
- return f"<weakref at {hex(id(obj_weakref))}; dead>"
- else:
- return str(obj_weakref)
- def __repr__(self) -> str:
- s = f"""
- {self.source.name.lower() if self.source else ""} {repr(self.name)} {self.inner_create_fn().__name__}
- {{
- 'guard_types': {self.guard_types},
- 'code': {self.code_list},
- 'obj_weakref': {self.weakref_to_str(self.obj_weakref)}
- 'guarded_class': {self.guarded_class_weakref}
- }}
- """
- return s
- def __str__(self) -> str:
- output = f"Name: {repr(self.name)}\n"
- source = self.source.name.lower() if self.source else ""
- output += f" Source: {source}\n"
- output += f" Create Function: {self.inner_create_fn().__name__}\n"
- output += f" Guard Types: {self.guard_types}\n"
- output += f" Code List: {self.code_list}\n"
- output += f" Object Weakref: {self.weakref_to_str(self.obj_weakref)}\n"
- output += f" Guarded Class Weakref: {self.guarded_class_weakref}\n"
- return output
- def create(self, builder: GuardBuilderBase) -> Any:
- try:
- return self.create_fn(builder, self)
- except Exception:
- log.exception("Error while creating guard:\n%s", str(self).rstrip())
- if self.stack:
- log.error("Created at:\n%s", "".join(self.stack.format()[-4:]).rstrip())
- raise
- def is_specialized_nn_module(self) -> bool:
- return self.source.is_specialized_nn_module()
- def is_fsdp_module(self) -> bool:
- return self.source.is_fsdp_module()
- def is_local(self) -> bool:
- return self.source.is_local()
- def create_fn_name(self) -> str:
- if isinstance(self.create_fn, functools.partial):
- create_fn = self.create_fn.func # type: ignore[attr-defined]
- else:
- create_fn = self.create_fn
- return create_fn.__name__
- def set_export_info(
- self,
- guard_type: str,
- guarded_class: weakref.ReferenceType[Any] | None,
- code_list: list[str],
- obj_weakref: object,
- ) -> None:
- if not self.guard_types:
- self.guard_types = []
- self.guard_types.append(guard_type)
- assert self.guarded_class_weakref in (
- guarded_class,
- None,
- ), "Guarded class id must be identical, or None"
- self.guarded_class_weakref = guarded_class
- if not self.code_list:
- self.code_list = code_list
- else:
- self.code_list.extend(code_list)
- # Some objects are ephemeral, e.g., list[slice(1, 2)]. If we have
- # multiple guards on the same object, the weakref can die between the
- # invocation of set_export_info calls. So a dead weakref is also
- # acceptable.
- assert (
- self.obj_weakref in (obj_weakref, None)
- or callable(self.obj_weakref)
- and self.obj_weakref() is None
- ), "Guarded object must be identical, None or ephemeral (dead weakref)"
- self.obj_weakref = obj_weakref
# Type variable for the snapshot type produced/consumed by Checkpointable[T].
T = TypeVar("T")
- """
- Parent structure for guard env expressions.
- A GuardEnvExpr can have any subtype.
- Note: All subtypes must be handled exhaustively in
- torch._dynamo.guards._parse_guard_env_guards to avoid a RuntimeError.
- """
@dataclasses.dataclass(frozen=True)
class GuardEnvExpr:
    # Intentionally empty, frozen base class; concrete guard-env expressions
    # (DuplicateInputs and StorageOverlap in this module) subclass it.
    pass
- """
- A class representing a pair of duplicate inputs.
- input_source_a and input_source_b are the sources of the two inputs we have deduped.
- """
@dataclasses.dataclass(frozen=True)
class DuplicateInputs(GuardEnvExpr):
    """A pair of graph inputs that were deduplicated into one."""

    input_source_a: Source
    input_source_b: Source

    def __post_init__(self) -> None:
        # A duplicate pair must name two distinct sources.
        first, second = self.input_source_a, self.input_source_b
        assert first != second
- """
- A class representing storage overlap relations among inputs that alias the same storage.
- Given that a set of tensors alias the same storage, this guard checks whether they actually
- have overlapping storages.
- While non_overlapping_sources represent input tensors that definitely don't have any storage
- overlapping with any other input, overlapping_sources represent tensors that either:
- 1. Do overlap some other input tensor
- 2. Might not overlap some other input tensor, but we are not sure
- """
@dataclasses.dataclass(frozen=True)
class StorageOverlap(GuardEnvExpr):
    # Inputs that do (or might) overlap the storage of some other input.
    overlapping_sources: list[Source]
    # Inputs known not to overlap the storage of any other input.
    non_overlapping_sources: list[Source]
    # NOTE(review): with frozen=True and the default eq=True, dataclasses
    # generates __hash__ from the fields; both fields are lists, so hashing an
    # instance raises TypeError. Confirm instances are never hashed.
- """
- Checkpointable is an interface for driving state snapshotting, left purposely vague for now.
- copy_graphstate() -> T, a somewhat legacy name, is expected to emit a snapshot of any type that
- can also be taken in at restore_graphstate(T) calls.
- When to snapshot is, at the moment, an implementation detail of upstream callers. Checkpointable
- does not provide any guarantees around consistency, idempotency, or safety of calling its APIs, yet.
- In the future, it will have a closer coupling to a generic Checkpoint management system.
- """
class Checkpointable(Generic[T]):
    # Interface: copy_graphstate() returns a snapshot of type T that
    # restore_graphstate() accepts.
    # NOTE(review): this class does not use abc.ABC/ABCMeta, so the
    # @abstractmethod markers below are not enforced at instantiation time.

    # Return a snapshot of the current state.
    @abstractmethod
    def copy_graphstate(self) -> T: ...

    # Restore state from a snapshot previously produced by copy_graphstate().
    @abstractmethod
    def restore_graphstate(self, state: T) -> None: ...
class GuardsCheckpointState:
    """
    The GuardsCheckpointState - it is the T of Checkpointable[T] for GuardsContext
    """

    dynamo_guards: OrderedSet[Guard]

    def __init__(self, dynamo_guards: OrderedSet[Guard]) -> None:
        self.dynamo_guards = dynamo_guards

    def diff(self, other: GuardsCheckpointState) -> Optional[OrderedSet[Guard]]:
        """
        Produces a delta against another GuardsCheckpointState.

        Returns None if no delta is found, otherwise, return an OrderedSet() of mismatched
        Guard type objects.

        The delta is one-directional: it contains the guards present in
        ``self`` but missing from ``other``.
        """
        r = self.dynamo_guards.difference(other.dynamo_guards)
        if len(r) == 0:
            return None
        return r

    def __eq__(self, other: object) -> bool:
        if not isinstance(other, GuardsCheckpointState):
            return False
        # diff() only finds guards missing from `other`; checking both
        # directions keeps equality symmetric (a subset is not "equal").
        return self.diff(other) is None and other.diff(self) is None
class ModuleContextCheckpointState:
    """Snapshot of ModuleContext.nn_modules (the T of Checkpointable[T] for ModuleContext)."""

    nn_modules: dict[str, torch.nn.Module] = {}

    def __init__(self, nn_modules: dict[str, torch.nn.Module]) -> None:
        self.nn_modules = nn_modules

    def diff(self, other: ModuleContextCheckpointState) -> set[str] | None:
        """
        Produces a delta against another ModuleContextCheckpointState.

        Returns None if no delta is found, otherwise, return a set() of mismatched
        module key names.

        Only keys are compared, and only in one direction: the result holds
        the keys of ``self`` that are missing from ``other``.
        """
        r = set(self.nn_modules.keys()).difference(set(other.nn_modules.keys()))
        if len(r) == 0:
            return None
        return r

    def __eq__(self, other: object) -> bool:
        if not isinstance(other, ModuleContextCheckpointState):
            return False
        # diff() is one-directional; check both ways so equality is symmetric.
        return self.diff(other) is None and other.diff(self) is None
class ModuleContext(Checkpointable[ModuleContextCheckpointState]):
    """Checkpointable holder of a str-keyed ``nn_modules`` dict."""

    def __init__(self) -> None:
        self.nn_modules: dict[str, Any] = {}

    def copy_graphstate(self) -> ModuleContextCheckpointState:
        # Snapshot a shallow copy so later additions do not leak into it.
        snapshot = dict(self.nn_modules)
        return ModuleContextCheckpointState(snapshot)

    def restore_graphstate(self, state: ModuleContextCheckpointState) -> None:
        assert isinstance(state, ModuleContextCheckpointState)
        # Adopts the snapshot's dict as-is (no copy is made here).
        self.nn_modules = state.nn_modules
class GlobalContextCheckpointState:
    """Snapshot of GlobalContext.global_state (the T of Checkpointable[T] for GlobalContext)."""

    global_state: dict[str, tuple[Callable, Any]] = {}

    def __init__(self, global_states: dict[str, tuple[Callable, Any]]) -> None:
        self.global_state = global_states

    def diff(self, other: GlobalContextCheckpointState) -> set[str] | None:
        """
        Produces a delta against another GlobalContextCheckpointState.

        Returns None if no delta is found, otherwise, return a set() of mismatched
        global key names.

        Only keys are compared, and only in one direction: the result holds
        the keys of ``self`` that are missing from ``other``.
        """
        r = set(self.global_state.keys()).difference(set(other.global_state.keys()))
        if len(r) == 0:
            return None
        return r

    def __eq__(self, other: object) -> bool:
        if not isinstance(other, GlobalContextCheckpointState):
            return False
        # diff() is one-directional; check both ways so equality is symmetric.
        return self.diff(other) is None and other.diff(self) is None
class GlobalContext(Checkpointable[GlobalContextCheckpointState]):
    """
    This keeps track of the global torch state during tracing of a function.
    For example, torch.is_grad_enabled.
    """

    # The exact key set restore_graphstate() requires a snapshot to contain.
    _supported_global_states = {
        "grad_enabled",
        "autocast_enabled",
        "autocast_cpu_enabled",
        "autocast_gpu_dtype",
        "autocast_cpu_dtype",
        "autocast_cache_enabled",
    }

    def __init__(self) -> None:
        # Maps a state name to (callable, argument); restore_graphstate()
        # calls callable(argument) for every entry.
        self.global_state: dict[str, tuple[Callable, Any]] = {}

    def copy_graphstate(self) -> GlobalContextCheckpointState:
        # Snapshot a shallow copy (as ModuleContext.copy_graphstate does):
        # handing out the live dict would let later in-place updates to
        # self.global_state silently rewrite the checkpoint. The values are
        # (callable, argument) tuples, so a shallow copy is sufficient.
        return GlobalContextCheckpointState(dict(self.global_state))

    def restore_graphstate(self, state: GlobalContextCheckpointState) -> None:
        """Adopt ``state`` and re-apply every recorded global setting.

        Raises AssertionError if the snapshot's keys are not exactly the
        supported global states.
        """
        assert isinstance(state, GlobalContextCheckpointState)
        self.global_state = state.global_state
        assert (
            len(self.global_state) == len(self._supported_global_states)
            and set(self.global_state.keys()) == self._supported_global_states
        ), "Global state mismatch"
        for func, args in self.global_state.values():
            func(args)
class GuardsSet:
    """Like a Set[Guard], but records the framework and user stack on each
    guard at the time it is installed at its destination."""

    def __init__(self, inner: Optional[OrderedSet[Guard]] = None) -> None:
        self.inner: OrderedSet[Guard] = OrderedSet() if inner is None else inner

    def __iter__(self) -> Iterator[Guard]:
        return iter(self.inner)

    def __len__(self) -> int:
        return len(self.inner)

    # Subtraction along with bool is typically used to determine the delta of
    # added guards between checkpoints for higher order ops
    def __sub__(self, other: GuardsSet) -> GuardsSet:
        remaining = self.inner - other.inner
        return GuardsSet(remaining)

    def __bool__(self) -> bool:
        return bool(self.inner)

    def add(
        self, guard: Guard, *, collect_debug_stack: bool = True, skip: int = 0
    ) -> None:
        """Insert ``guard`` unless it is already present.

        When ``collect_debug_stack`` is true, unset ``guard.stack`` /
        ``guard.user_stack`` are filled in before insertion. ``skip`` is added
        to the number of frames dropped from the captured framework stack.
        """
        if guard not in self.inner:
            if collect_debug_stack:
                # The extra 1 presumably drops this add() frame itself.
                if guard.stack is None:
                    guard.stack = CapturedTraceback.extract(skip=1 + skip)
                if guard.user_stack is None:
                    guard.user_stack = TracingContext.extract_stack()
            self.inner.add(guard)

    def update(self, *others: set[Guard]) -> None:
        """Add every guard from each of ``others``."""
        for group in others:
            for guard in group:
                # skip=1 also accounts for this update() frame.
                self.add(guard, skip=1)

    def remove_guards_with_source(self, source: Source) -> None:
        """Drop every guard whose originating source derives from ``source``
        (as decided by ``is_from_source``)."""
        from ._dynamo.source import is_from_source

        kept = (
            g for g in self.inner if not is_from_source(g.originating_source, source)
        )
        self.inner = OrderedSet(kept)
- """
- A GuardsContext is a checkpointable representation of all the guards in the current tracing
- context. Its lifecycle is bound 1:1 to the tracing context, and it should never be instantiated
- directly outside of it. For passing around internal state representations of this object,
- prefer to extract them with copy_graphstate to produce a GuardsCheckpointState.
- """
class GuardsContext(Checkpointable[GuardsCheckpointState]):
    """Checkpointable container for the guards accumulated while tracing."""

    def __init__(self) -> None:
        self.dynamo_guards: GuardsSet = GuardsSet()
        self.aotautograd_guards: list[GuardEnvExpr] = []

    def copy_graphstate(self) -> GuardsCheckpointState:
        # Snapshot a copy of the dynamo guards so later additions don't leak
        # into it. NOTE(review): aotautograd_guards is not part of the snapshot.
        guards_copy = OrderedSet(self.dynamo_guards.inner)
        return GuardsCheckpointState(guards_copy)

    def restore_graphstate(self, state: GuardsCheckpointState) -> None:
        # NB: "steals" the passed in state -- its OrderedSet is reused, not copied.
        assert isinstance(state, GuardsCheckpointState)
        restored = GuardsSet(state.dynamo_guards)
        self.dynamo_guards = restored
class HopSubgraphCache:
    # Interface for a per-higher-order-op cache; InvokeSubgraphCache in this
    # module implements it.
    # NOTE(review): not derived from abc.ABC, so the @abstractmethod markers
    # are not enforced at instantiation time.

    # Record that submodule `identifier` was installed for function id `fn_id`.
    @abstractmethod
    def add_dynamo_installed_submodule(self, fn_id: int, identifier: str) -> None: ...

    # Return the identifiers recorded for `fn_id`.
    @abstractmethod
    def get_dynamo_installed_submodules(self, fn_id: int) -> list[str]: ...

    # Store / look up an autograd entry for `identifier`.
    @abstractmethod
    def add_autograd_key_entry(self, identifier: str, key: Callable) -> None: ...

    @abstractmethod
    def get_autograd_key_entry(self, identifier: str) -> Callable | None: ...

    # Store / look up a proxy-dispatch entry for `identifier`.
    @abstractmethod
    def add_proxy_dispatch_entry(self, identifier: str, key: Callable) -> None: ...

    @abstractmethod
    def get_proxy_dispatch_entry(self, identifier: str) -> Callable | None: ...

    # Cache a backward GraphModule keyed by (identifier, tangent_metadata) and
    # return an int (InvokeSubgraphCache returns the entry's insertion index).
    @abstractmethod
    def add_lazy_bwd_entry(
        self,
        identifier: str,
        tangent_metadata: tuple[object],
        gmod: torch.fx.GraphModule,
    ) -> int: ...

    # Look up a cached backward entry; (None, None) when absent.
    @abstractmethod
    def get_lazy_bwd_entry(
        self, identifier: str, tangent_metadata: tuple[object]
    ) -> tuple[torch.fx.GraphModule | None, int | None]: ...
class InvokeSubgraphCache(HopSubgraphCache):
    """HopSubgraphCache implementation backed by in-memory dicts."""

    def __init__(self) -> None:
        self.autograd_cache: dict[str, Callable] = {}
        self.proxy_dispatch_cache: dict[str, Callable] = {}
        # fn_id -> identifiers of the submodules installed for it.
        self.dynamo_installed_submodules: dict[int, list[str]] = defaultdict(list)
        # identifier -> {tangent_metadata -> (graph module, insertion index)}
        self.lazy_bwd_cache: dict[
            str, dict[tuple[object], tuple[torch.fx.GraphModule, int]]
        ] = defaultdict(dict)
        # Maps identifier -> set of effect types
        self.effects_cache: dict[str, set] = {}

    def add_dynamo_installed_submodule(self, fn_id: int, identifier: str) -> None:
        self.dynamo_installed_submodules[fn_id].append(identifier)

    def get_dynamo_installed_submodules(self, fn_id: int) -> list[str]:
        # .get() avoids inserting an empty entry into the defaultdict.
        return self.dynamo_installed_submodules.get(fn_id, [])

    def add_autograd_key_entry(self, identifier: str, key: Callable) -> None:
        self.autograd_cache[identifier] = key

    def get_autograd_key_entry(self, identifier: str) -> Callable | None:
        return self.autograd_cache.get(identifier, None)

    def add_proxy_dispatch_entry(self, identifier: str, key: Callable) -> None:
        self.proxy_dispatch_cache[identifier] = key

    def get_proxy_dispatch_entry(self, identifier: str) -> Callable | None:
        return self.proxy_dispatch_cache.get(identifier, None)

    def add_lazy_bwd_entry(
        self,
        identifier: str,
        tangent_metadata: tuple[object],
        gmod: torch.fx.GraphModule,
    ) -> int:
        """Cache ``gmod`` and return its index among ``identifier``'s entries."""
        # Save the number of existing graph modules in the dictionary to get the suffix
        num_gmods = len(self.lazy_bwd_cache[identifier])
        self.lazy_bwd_cache[identifier][tangent_metadata] = (gmod, num_gmods)
        return num_gmods

    def get_lazy_bwd_entry(
        self, identifier: str, tangent_metadata: tuple[object]
    ) -> tuple[torch.fx.GraphModule | None, int | None]:
        """Return the cached (graph module, index) or (None, None)."""
        # Membership test first so the defaultdict doesn't grow on a miss.
        if identifier not in self.lazy_bwd_cache:
            return (None, None)

        return self.lazy_bwd_cache[identifier].get(tangent_metadata, (None, None))

    def add_effects(self, identifier: str, effects: set) -> None:
        """Store the effect types for a given invoke_subgraph identifier.

        Raises AssertionError if effects were already recorded for
        ``identifier`` and differ from ``effects``.
        """
        prev_effects = self.effects_cache.get(identifier, None)
        # Compare against any previously recorded set, including an empty
        # one; a truthiness check would skip validation when the earlier call
        # recorded no effects.
        if prev_effects is not None:
            assert effects == prev_effects, (
                "Different number of effects were found for invoke_subgraph "
                f"call with identifier {identifier}. \n"
                f"Previously we had the following effects: {prev_effects}.\n"
                f"But now we have: {effects}."
            )
        self.effects_cache[identifier] = effects

    def get_effects(self, identifier: str) -> set | None:
        """Retrieve the effect types for a given invoke_subgraph identifier."""
        return self.effects_cache.get(identifier, None)
class HopDispatchSetCache:
    """Maps a higher-order operator to its HopSubgraphCache.

    Only invoke_subgraph has a cache registered.
    """

    def __init__(self) -> None:
        # Delayed import to avoid circular dependency
        from torch._higher_order_ops.invoke_subgraph import invoke_subgraph

        self.hop_cache_map = {invoke_subgraph: InvokeSubgraphCache()}

    def get_cache(self, op: torch._ops.HigherOrderOperator) -> HopSubgraphCache | None:
        """Return the cache registered for ``op``, or None if there is none."""
        return self.hop_cache_map.get(op)
# Thread-local storage for the active contexts: CompileContext reads its
# `compile_context` attribute and TracingContext its `tracing_context`
# attribute.
_TLS = threading.local()
- """
- TracingContext is the source of truth for all currently accumulated information
- needed to trace. Its lifecycle is kept 1:1 when using TorchDynamo, but other systems
- are open to managing their own TracingContext with that in mind.
- The purpose of TracingContext is not to be a dumping ground, or god object, but rather to avoid
- having to plumb complex subsystems across multiple verticals.
- Ex: A common example is guard accumulation between dynamo, shape_env, aot_autograd, and inductor.
- Accessing the current tracing context via
- TracingContext.get() allows users to accumulate their own guards for processing, without needing to know how
- to plumb objects back up to where frame interpretation happened.
- Note that you can end up with multiple TracingContext for a single compilation
- of a frame, as we reset the TracingContext whenever we restart analysis.
- CompileContext is a more overarching context that encompasses multiple restarts.
- """
class CompileContext:
    """State for one compilation of a frame, shared across analysis restarts.

    The active instance lives on the thread-local ``_TLS.compile_context``.
    """

    @staticmethod
    def get() -> CompileContext:
        """Return the installed CompileContext; assert that one is installed."""
        # Go through try_get() so that a thread which never installed a
        # context fails this assertion rather than raising AttributeError on
        # the missing thread-local attribute.
        ctx = CompileContext.try_get()
        assert ctx is not None
        return ctx

    @staticmethod
    def try_get() -> CompileContext | None:
        """Return the installed CompileContext, or None if there is none."""
        return getattr(_TLS, "compile_context", None)

    def __init__(self, compile_id: CompileId | None) -> None:
        assert compile_id is None or isinstance(compile_id, CompileId)
        self.compile_id: CompileId | None = compile_id
        # Analysis-restart counter; paired with compile_id in current_trace_id().
        self.attempt = 0
        # Verbose ShapeEnv guards produced.
        self.shape_env_guards: list[str] = []

    @staticmethod
    def current_compile_id() -> CompileId | None:
        """CompileId of the installed context, or None."""
        ctx = CompileContext.try_get()
        if ctx is None:
            return None
        return ctx.compile_id

    @staticmethod
    def current_trace_id() -> TraceId | None:
        """TraceId(compile_id, attempt) of the installed context, or None."""
        ctx = CompileContext.try_get()
        if ctx is None:
            return None
        if ctx.compile_id is None:
            return None
        return TraceId(ctx.compile_id, ctx.attempt)
class TracingContext:
    """
    Holds the state accumulated while tracing: the guard, module and global
    contexts, the fake tensor mode, the user-level frame stack that is
    attached to exceptions as ``real_stack``, and metadata exchanged between
    Dynamo, AOTAutograd and backend compilers.

    The active instance is installed with ``tracing()``. Read it with
    ``TracingContext.try_get()`` (None if nothing is installed) or
    ``TracingContext.get()`` (raises RuntimeError if nothing is installed).
    """

    @staticmethod
    def try_get() -> TracingContext | None:
        """
        Provides the currently installed TracingContext, or None.

        Note that it is a staticmethod, and invocations outside of
        `with tracing()` (see `tracing` below) are valid but will return None.
        """
        return getattr(_TLS, "tracing_context", None)

    @staticmethod
    def get() -> TracingContext:
        """
        Return the currently installed TracingContext.

        Raises:
            RuntimeError: if no TracingContext is installed.
        """
        if ctx := TracingContext.try_get():
            return ctx
        raise RuntimeError(
            "TracingContext.get() must be called within an ongoing trace."
        )

    def __init__(self, fake_mode: FakeTensorMode | None) -> None:
        self.guards_context = GuardsContext()
        self.module_context = ModuleContext()
        self.global_context = GlobalContext()
        self.previously_inlined_functions: dict[Any, Any] = {}
        self.previously_cleaned_instructions: dict[Any, Any] = {}
        self.fake_mode: FakeTensorMode | None = fake_mode
        self.frame_summary_stack: list[traceback.FrameSummary] = []
        # This is morally part of frame_summary_stack, but it is kept separate
        # for clarity. As we process a frame, this variable gets updated
        # to keep track of what line we are in the function. When we make a
        # function call, this gets cleared and the frame location is pushed
        # to frame_summary_stack (prepping this variable for the inner frame's
        # progress)
        self.loc_in_frame: tuple[str, int, str] | None = None
        # this is only set after aot_autograd
        self.fw_metadata: ViewAndMutationMeta | None = None
        # this is only set when the DDPOptimizer is used
        self.ddp_optimizer_ctx: DDPOptimizerContext | None = None
        # this is only set after aot_autograd
        self.aot_graph_name: list[str] | None = None
        self.params_flat: list[Any] | None = None
        self.params_flat_unwrap_subclasses: list[Any] | None = None
        self.params_unwrapped_to_flat_index: list[Any] | None = None
        # this is for extended return calling convention from backend
        # compiler to aot_autograd
        # Per output, what the compiler specified stride of the output is,
        # or None if no stride is known. This is always the HINT, it
        # is never a SymInt (it would be better if it was a SymInt, but
        # I can't conveniently get this from Inductor atm. Also, be
        # careful not to accidentally induce guards on the SymInt if
        # you ever do change this in aot_autograd.py; you should check
        # on permutations preferentially.)
        self.output_strides: list[tuple[int, ...] | None] | None = None
        # When this is True, whenever we encounter an int in Dynamo tracing,
        # we will (1) force unspec it and (2) force it as a size-like unbacked
        # integer. This is currently used when processing certain lists of
        # ints that are known to be size-like and may have 0/1 entries that we
        # must not specialize on.
        self.force_unspec_int_unbacked_size_like = False
        # See note [Tensor Fakification and Symbol Caching]
        self.tensor_to_context = WeakTensorKeyDictionary()

        # If this is true, AOTAutograd will return output Fake Tensors with
        # appropriate meta on the first invocation
        # see note: [Returning Fake Tensors on First AOT Autograd Call]
        self.fakify_first_call = False
        self.hop_dispatch_set_cache = HopDispatchSetCache()
        # list of code objects for inlined functions
        self.traced_code: list[CodeType] = []

    def clear(self) -> None:
        """Drop the saved global state and the inlining/cleaning caches."""
        # Look at the note in output_graph.py in function `save_global_state`
        # for the context on clearing global context.
        self.global_context.global_state = {}
        self.previously_inlined_functions.clear()
        self.previously_cleaned_instructions.clear()

    @staticmethod
    @contextmanager
    def patch(**kwargs: Any) -> Generator[None, None, None]:
        """
        Temporarily set attributes of the current TracingContext.

        Every key must name an existing attribute: an unknown key raises
        AttributeError before anything is modified. On exit (normal or via
        exception) each attribute that was actually set is restored to its
        previous value, including when a later ``setattr`` fails midway.

        Raises:
            RuntimeError: if no TracingContext is installed.
            AttributeError: if a key is not an existing attribute.
        """
        ctx = TracingContext.get()
        # AttributeError on invalid entry (getattr without a default).
        prior = {key: getattr(ctx, key) for key in kwargs}
        applied: list[str] = []
        try:
            for key, val in kwargs.items():
                setattr(ctx, key, val)
                applied.append(key)
            yield
        finally:
            for key in applied:
                setattr(ctx, key, prior[key])

    @staticmethod
    def extract_stack() -> traceback.StackSummary:
        """
        Return the user-level stack of the installed context.

        The stack is ``frame_summary_stack`` plus, when ``loc_in_frame`` is
        set, a summary for the current location in the innermost frame. An
        empty StackSummary is returned when no context is installed.
        """
        self = TracingContext.try_get()
        if self is None:
            return traceback.StackSummary()
        stack = self.frame_summary_stack
        if self.loc_in_frame is not None:
            stack = stack + [self._populate_loc_in_frame_summary()]
        return traceback.StackSummary.from_list(stack)

    def _populate_loc_in_frame_summary(self) -> traceback.FrameSummary:
        """Build a FrameSummary from ``loc_in_frame`` without reading the source line."""
        assert self.loc_in_frame is not None
        filename, lineno, frame_name = self.loc_in_frame
        return traceback.FrameSummary(filename, lineno, frame_name, lookup_line=False)

    # Call this when you want to call into some code that isn't necessarily
    # associated with the current frame state
    @staticmethod
    @contextlib.contextmanager
    def clear_frame() -> Generator[None, None, None]:
        """
        Run the body with an empty frame stack and no current location.

        Exceptions that escape without a ``real_stack`` get ``real_stack =
        None`` so outer frames do not attach an unrelated stack.
        """
        tc = TracingContext.get()
        with (
            unittest.mock.patch.object(tc, "frame_summary_stack", []),
            unittest.mock.patch.object(tc, "loc_in_frame", None),
        ):
            try:
                yield
            except Exception as e:
                # Prevent real_stack from getting attached
                #
                # The invariant is that if an Exception has real_stack, we've
                # appropriately attached a user stack and we no longer need to
                # attach anything. Because we cannot conveniently interpose
                # when an exception is thrown, we instead interpose everywhere
                # we set the user stack (using the context manager). However,
                # our compiler stack does "tail calls" (when it calls into the
                # user compiler), at which point the parent exception frames
                # would attach an incorrect frame.
                #
                # However, if, somehow, someone raised an exception within
                # this scope that had a stack (for example, because they are
                # restoring the user stack state appropriately as they process
                # node by node), we should respect it. Thus, we cannot
                # unconditionally set None.
                if not hasattr(e, "real_stack"):
                    e.real_stack = None  # type: ignore[attr-defined]
                raise

    @staticmethod
    @contextlib.contextmanager
    def current_frame(
        frame_summary: traceback.FrameSummary | None,
    ) -> Generator[None, None, None]:
        """
        Push ``frame_summary`` (if not None) for the duration of the body.

        ``loc_in_frame`` is cleared on entry and restored on exit. Exceptions
        escaping the body without a ``real_stack`` get the stack as seen from
        inside this frame.
        """
        # frame_summary can be None to solely take advantage of real_stack
        # attachment to thrown exceptions
        tc = TracingContext.get()
        if frame_summary is not None:
            tc.frame_summary_stack.append(frame_summary)
        old = tc.loc_in_frame
        tc.loc_in_frame = None
        try:
            yield
        except Exception as e:
            if not hasattr(e, "real_stack"):
                e.real_stack = tc.extract_stack()  # type: ignore[attr-defined]
            raise
        finally:
            if frame_summary is not None:
                tc.frame_summary_stack.pop()
            tc.loc_in_frame = old

    @staticmethod
    @contextlib.contextmanager
    def report_output_strides() -> Generator[
        list[tuple[int, ...] | None] | None, None, None
    ]:
        """
        Yield a fresh list installed as the context's ``output_strides``.

        Yields None when no context is installed. The previous
        ``output_strides`` value is restored on exit.
        """
        tc = TracingContext.try_get()
        if tc is None:
            yield None
            return
        old_output_strides = tc.output_strides
        tc.output_strides = []
        try:
            yield tc.output_strides
        finally:
            tc.output_strides = old_output_strides

    @staticmethod
    def set_current_loc(filename: str, lineno: int, frame_name: str) -> None:
        """Record the current location; raises RuntimeError if no context."""
        # Save the current location in the frame. Lazily generate the
        # framesummary.
        TracingContext.get().loc_in_frame = (filename, lineno, frame_name)

    @staticmethod
    def get_traced_code() -> list[CodeType] | None:
        """Return the installed context's ``traced_code`` list, or None."""
        tc = TracingContext.try_get()
        if tc is None:
            return None
        return tc.traced_code
@contextmanager
def compile_context(
    context: CompileContext | None,
) -> Generator[CompileContext | None, None, None]:
    """
    Temporarily install ``context`` as the current CompileContext.

    Yields ``context`` itself. However the ``with`` block exits, whatever was
    installed beforehand (or None, if nothing was) is put back.
    """
    saved = getattr(_TLS, "compile_context", None)
    _TLS.compile_context = context
    try:
        yield context
    finally:
        # Reinstate the outer context so nesting behaves like a stack.
        _TLS.compile_context = saved
@contextmanager
def tracing(
    context: TracingContext | None,
) -> Generator[TracingContext | None, None, None]:
    """
    Install ``context`` as the current TracingContext for the duration of
    the ``with`` block, then restore the previously installed one.

    While no context is installed (outside ``with tracing(...)``, or inside
    ``with tracing(None)``), ``TracingContext.try_get()`` returns None and
    ``TracingContext.get()`` raises RuntimeError.

    Exceptions escaping the block are given a ``real_stack`` attribute (from
    ``TracingContext.extract_stack()``) unless they already have one. On
    exit, ``context.fake_mode.shape_env.cleanup()`` is called when those
    objects exist; the previous context is restored even if that cleanup
    raises.
    """
    old_context = getattr(_TLS, "tracing_context", None)
    _TLS.tracing_context = context
    try:
        yield context
    except Exception as e:
        if not hasattr(e, "real_stack") and context is not None:
            e.real_stack = context.extract_stack()  # type: ignore[attr-defined]
        raise
    finally:
        try:
            if (
                context is not None
                and context.fake_mode is not None
                and context.fake_mode.shape_env is not None
            ):
                context.fake_mode.shape_env.cleanup()
        finally:
            # Always restore the outer context, so a failing cleanup cannot
            # leave this context installed after the block has exited.
            _TLS.tracing_context = old_context
- @overload
- def dataclass_with_cached_hash(cls: type[T], **kwargs: Any) -> type[T]: ...
- @overload
- def dataclass_with_cached_hash(
- cls: None = None, **kwargs: Any
- ) -> Callable[[type[T]], type[T]]: ...
- @dataclass_transform()
- def dataclass_with_cached_hash(
- cls: type[T] | None = None, **kwargs: Any
- ) -> type[T] | Callable[[type[T]], type[T]]:
- def wrap(cls_inner: type[T]) -> type[T]:
- new_cls = dataclasses.dataclass(cls_inner, **kwargs)
- old_hash = cls_inner.__hash__
- def __hash__(self) -> int:
- if not hasattr(self, "_hash"):
- object.__setattr__(self, "_hash", old_hash(self))
- return self._hash
- def __reduce__(self):
- # Exclude _hash from pickling to ensure deterministic cache keys.
- # The _hash is a cached value that can be nondeterministically computed
- # (e.g., based on id() of objects), so it should not affect pickling.
- fields = dataclasses.fields(self)
- field_values = tuple(getattr(self, f.name) for f in fields)
- return (self.__class__, field_values)
- new_cls.__hash__ = __hash__
- new_cls.__reduce__ = __reduce__
- return new_cls # type: ignore[return-value]
- if cls is None:
- return wrap
- return wrap(cls)
# Subclasses can be found in torch/_dynamo/source.py
# TODO(voz): Consider a toplevel torch/_source.py
@dataclass_with_cached_hash(frozen=True)
class Source:
    """
    Base class for objects describing how to obtain a value.

    A Source can be evaluated against a ``globals``/``locals`` namespace
    (``get_value``), rendered as a name (``name``) and used to build guards
    (``make_guard``). The base class treats ``_name_template`` as a complete
    expression; most hooks are left to subclasses.
    """

    def is_dict_key(self) -> bool:
        # A plain Source is never a dict key.
        return False

    def is_ephemeral(self) -> bool:
        # A plain Source is never ephemeral.
        return False

    def reconstruct(self, codegen: PyCodegen) -> None:
        """Subclass hook that uses ``codegen`` to rebuild this source."""
        raise NotImplementedError

    @functools.cached_property
    def guard_source(self) -> GuardSource:
        """The GuardSource of this source; subclasses must provide it."""
        raise NotImplementedError

    @property
    def _name_template(self) -> str:
        """
        A template for the name of the source. Used to prevent code duplication between
        `name` and `get_value`.

        For non-ChainedSources, `name` and `get_value` use the returned string directly.

        For ChainedSources, `name` and `get_value` expect the return to be a format string
        with `{0}` present - `name` and `get_value` will apply different values to this function's
        returned format string.
        """
        raise NotImplementedError

    @functools.cached_property
    def name(self) -> str:
        """For a plain Source, the name is the template text itself."""
        return self._name_template

    def get_value(
        self,
        globals: dict[str, Any],
        locals: dict[str, Any],
        cache: weakref.WeakKeyDictionary[Source, Any],
    ) -> Any:
        """
        Evaluate ``_name_template`` in ``globals``/``locals``.

        The result is memoized in ``cache`` keyed by this source, and a
        cached result is returned without re-evaluating.
        """
        if self not in cache:
            cache[self] = eval(self._name_template, globals, locals)
        return cache[self]

    def make_guard(self, fn: Callable[..., Any]) -> Guard:
        """Wrap ``fn`` in a Guard on this source; CONSTANT sources refuse."""
        if self.guard_source is not GuardSource.CONSTANT:
            return Guard(self, fn)
        raise NotImplementedError

    def is_specialized_nn_module(self) -> bool:
        # Delegates entirely to the guard source.
        return self.guard_source.is_specialized_nn_module()

    def subguards_allowed(self) -> bool:
        """True if you can guard on attributes of this"""
        # Everything except synthetic locals permits attribute guards.
        return self.guard_source != GuardSource.SYNTHETIC_LOCAL
# Subclasses can be found in torch/_dynamo/source.py
@dataclass_with_cached_hash(frozen=True)
class ChainedSource(Source):
    """
    A Source derived from another Source, ``base``.

    ``_name_template`` is a format string with a ``{0}`` placeholder. ``name``
    fills it with ``base.name``. ``get_value`` fills it with a temporary
    variable bound to ``base``'s value.
    """

    base: Source

    def is_dict_key(self) -> bool:
        # Recurse until you either hit a ConstDictKey or a Source
        return self.base.is_dict_key()

    def is_ephemeral(self) -> bool:
        return self.base.is_ephemeral()

    @functools.cached_property
    def guard_source(self) -> GuardSource:
        # A chained source inherits the guard source of its base.
        return self.base.guard_source

    def get_base(self) -> Source:
        """Follow ``base`` links until reaching a non-chained Source."""
        current: Source = self
        while isinstance(current, ChainedSource):
            current = current.base
        return current

    @functools.cached_property
    def name(self) -> str:
        return self._name_template.format(self.base.name)

    def get_value(
        self,
        globals: dict[str, Any],
        locals: dict[str, Any],
        cache: weakref.WeakKeyDictionary[Source, Any],
    ) -> Any:
        """
        Evaluate this source by binding ``base``'s value to a temporary name
        in ``locals`` and evaluating the formatted template.

        The result is memoized in ``cache``. The temporary name is always
        removed from ``locals`` again, even if evaluation raises, so the
        caller's mapping is left as it was.
        """
        if self in cache:
            return cache[self]
        # Pick a temporary name that does not collide with an existing local.
        tmpvar = "tmp"
        counter = 0
        while tmpvar in locals:
            tmpvar = f"tmp{counter}"
            counter += 1
        locals[tmpvar] = self.base.get_value(globals, locals, cache)
        try:
            value = eval(self._name_template.format(tmpvar), globals, locals)
        finally:
            # Without this, a failing eval would leave a stray temporary in
            # the caller-supplied locals mapping.
            del locals[tmpvar]
        cache[self] = value
        return value
def detect_fake_mode(inputs: Any = None) -> FakeTensorMode | None:
    """
    Attempts to "detect" what the current fake mode is.

    Candidate fake modes are collected from, in priority order:

    - the installed TracingContext's ``fake_mode``, when a context is
      installed and its fake mode is set;
    - FakeTensorModes on the current dispatch mode stack, visited in reverse
      stack order;
    - the ``fake_mode`` of FakeTensors among ``inputs`` (any pytree; it does
      not have to be flattened), including FakeTensors nested inside
      traceable wrapper subclasses.

    All candidates are asserted to be the very same object. The
    highest-priority candidate is returned, or None if there is none.
    """
    from torch._subclasses.fake_tensor import (
        FakeTensor,
        FakeTensorMode,
        get_plain_tensors,
    )

    # Each entry: (fake mode, where it was found, index within that origin).
    found: list[tuple[Any, str, int]] = []

    if tracing_ctx := TracingContext.try_get():
        if tracing_ctx.fake_mode is not None:
            found.append((tracing_ctx.fake_mode, "tracing context", 0))

    from torch.utils._python_dispatch import _get_current_dispatch_mode_stack

    mode_stack = _get_current_dispatch_mode_stack()
    found.extend(
        # pyrefly: ignore [bad-argument-type]
        (mode, "active fake mode", pos)
        for pos, mode in enumerate(reversed(mode_stack))
        if isinstance(mode, FakeTensorMode)
    )

    for pos, leaf in enumerate(pytree.tree_leaves(inputs)):
        if isinstance(leaf, FakeTensor):
            # pyrefly: ignore [bad-argument-type]
            found.append((leaf.fake_mode, "fake tensor input", pos))
        if not is_traceable_wrapper_subclass(leaf):
            continue
        plain: list[torch.Tensor | int | torch.SymInt] = []
        get_plain_tensors(leaf, out=plain)  # type: ignore[arg-type]
        nested_fakes = [t for t in plain if isinstance(t, FakeTensor)]
        for sub_pos, fake in enumerate(nested_fakes):
            # pyrefly: ignore [bad-argument-type]
            found.append((fake.fake_mode, f"subclass input {pos}", sub_pos))

    if not found:
        return None

    chosen, chosen_desc, chosen_idx = found[0]
    for other, other_desc, other_idx in found[1:]:
        assert chosen is other, (
            f"fake mode ({chosen}) from {chosen_desc} {chosen_idx} doesn't match mode ({other}) from {other_desc} {other_idx}\n\n"
            # pyrefly: ignore [missing-attribute]
            f"fake mode from {chosen_desc} {chosen_idx} allocated at:\n{chosen.stack}\n"
            # pyrefly: ignore [missing-attribute]
            f"fake mode from {other_desc} {other_idx} allocated at:\n{other.stack}"
        )
    # pyrefly: ignore [bad-return]
    return chosen
def active_fake_mode() -> FakeTensorMode | None:
    """
    Inspects the dispatch mode stack for an active fake mode and returns it.

    The stack is scanned in reverse order and the first FakeTensorMode
    encountered is returned. Returns None if no fake mode is active. Unlike
    ``detect_fake_mode``, this does not consult the TracingContext or any
    tensors.
    """
    from torch._subclasses.fake_tensor import FakeTensorMode
    from torch.utils._python_dispatch import _get_current_dispatch_mode_stack

    return next(
        (
            mode
            for mode in reversed(_get_current_dispatch_mode_stack())
            if isinstance(mode, FakeTensorMode)
        ),
        None,
    )
|