side_effects.py 51 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862863864865866867868869870871872873874875876877878879880881882883884885886887888889890891892893894895896897898899900901902903904905906907908909910911912913914915916917918919920921922923924925926927928929930931932933934935936937938939940941942943944945946947948949950951952953954955956957958959960961962963964965966967968969970971972973974975976977978979980981982983984985986987988989990991992993994995996997998999100010011002100310041005100610071008100910101011101210131014101510161017101810191020102110221023102410251026102710281029103010311032103310341035103610371038103910401041104210431044104510461047104810491050105110521053105410551056105710581059106010611062106310641065106610671068106910701071107210731074107510761077107810791080108110821083108410851086108710881089109010911092109310941095109610971098109911001101110211031104110511061107110811091110111111121113111411151116111711181119112011211122112311241125112611271128112911301131113211331134113511361137113811391140114111421143114411451146114711481149115011511152115311541155115611571158115911601161116211631164116511661167116811691170117111721173117411751176117711781179118011811182118311841185118611871188118911901191119211931194119511961197119811991200120112021203120412051206120712081209121012111212121312141215121612171218
  1. """
  2. Side effect tracking and management for TorchDynamo's compilation system.
  3. This module provides infrastructure for tracking and managing side effects that occur
  4. during symbolic execution, including:
  5. - Tracking mutations to objects, attributes, and variables
  6. - Managing context changes (cell variables, global namespace modifications)
  7. - Handling aliasing and object identity preservation
  8. - Managing stack frame state and local variable changes
  9. - Tracking function calls with side effects
  10. Key classes:
  11. - SideEffects: Main container for tracking all side effects during execution
  12. - MutableSideEffects: Specialization for mutable object tracking
  13. - AttributeMutation/ValueMutation: Track specific types of mutations
  14. - Various specialized side effect classes for different scenarios
  15. The side effect system ensures that mutations performed during symbolic execution
  16. are properly replayed during runtime, maintaining the correctness of compiled code
  17. while enabling optimizations where safe.
  18. """
  19. import collections
  20. import contextlib
  21. import inspect
  22. import warnings
  23. import weakref
  24. from collections.abc import Generator, MutableMapping
  25. from types import CellType
  26. from typing import Any, Optional, TYPE_CHECKING
  27. import torch.nn
  28. from torch._dynamo.variables.misc import AutogradFunctionContextVariable
  29. from . import graph_break_hints, utils, variables
  30. from .bytecode_transformation import (
  31. bytecode_from_template,
  32. create_call_function,
  33. create_call_method,
  34. create_instruction,
  35. )
  36. from .codegen import PyCodegen
  37. from .exc import SideEffectsError, unimplemented_v2
  38. from .source import GlobalSource, LocalCellSource, LocalSource, Source
  39. from .utils import is_frozen_dataclass, nn_module_new, object_new
  40. from .variables.base import (
  41. AttributeMutation,
  42. AttributeMutationExisting,
  43. AttributeMutationNew,
  44. is_side_effect_safe,
  45. ValueMutationExisting,
  46. ValueMutationNew,
  47. VariableTracker,
  48. )
  49. from .variables.user_defined import FrozenDataClassVariable
  50. if TYPE_CHECKING:
  51. from torch._dynamo.output_graph import OutputGraph
  52. from torch._dynamo.symbolic_convert import InstructionTranslatorBase
  53. from torch._dynamo.variables.lists import ListVariable
  54. def _manual_dict_setitem(
  55. dict_from: dict[Any, Any], dict_to: dict[Any, Any], mro_index: int
  56. ) -> None:
  57. # Carefully calls the dict or OrderedDict `clear` or `__setitem__`. We have
  58. # to be careful because we don't want to trigger the user defined object
  59. # setitem or clear. The mro_index is used to find the dict/OrderedDict from
  60. # the class mro.
  61. dict_class = type(dict_to).__mro__[mro_index]
  62. dict_class.clear(dict_to) # type: ignore[attr-defined]
  63. for k, v in dict_from.items():
  64. dict_class.__setitem__(dict_to, k, v) # type: ignore[index]
  65. def _manual_list_update(list_from: list[Any], list_to: list[Any]) -> None:
  66. list.clear(list_to)
  67. list.extend(list_to, list_from)
  68. class SideEffects:
  69. """
  70. Maintain records of mutations and provide methods to apply them during code generation.
  71. Handles tracking and applying side effects during PyTorch Dynamo compilation,
  72. maintaining Python semantics by managing mutations, attribute modifications,
  73. and other side effects that occur during program execution.
  74. Key responsibilities:
  75. - Tracks mutations to Python objects, lists, and dictionaries that need to be
  76. applied after an FX graph is run.
  77. - Manages attribute modifications and deletions
  78. - Handles tensor hooks and backward pass state
  79. - Tracks cell variable mutations and global variable changes
  80. - Ensures correct ordering and application of side effects after graph execution
  81. This ensures that optimized code behaves identically to the original Python code with
  82. respect to object mutations and other side effects.
  83. """
  84. id_to_variable: dict[int, VariableTracker]
  85. store_attr_mutations: dict[VariableTracker, dict[str, VariableTracker]]
  86. keepalive: list[Any]
  87. def __init__(
  88. self,
  89. output_graph: "OutputGraph",
  90. id_to_variable: Optional[dict[int, VariableTracker]] = None,
  91. store_attr_mutations: Optional[
  92. dict[VariableTracker, dict[str, VariableTracker]]
  93. ] = None,
  94. keepalive: Optional[list[Any]] = None,
  95. save_for_backward: Optional[
  96. list[tuple[AutogradFunctionContextVariable, list[VariableTracker]]]
  97. ] = None,
  98. tensor_hooks: Optional[
  99. dict[
  100. int,
  101. tuple[
  102. "variables.TensorVariable",
  103. VariableTracker,
  104. "variables.RemovableHandleVariable",
  105. str,
  106. ],
  107. ]
  108. ] = None,
  109. ) -> None:
  110. super().__init__()
  111. self.output_graph_weakref = weakref.ref(output_graph)
  112. self.id_to_variable = id_to_variable or {}
  113. self.store_attr_mutations = store_attr_mutations or {}
  114. self.keepalive = keepalive or []
  115. self.save_for_backward = save_for_backward or []
  116. self.tensor_hooks = tensor_hooks or {}
  117. # Used by MappingProxyVariable to graph break in case of any mutated
  118. # dict
  119. self._has_existing_dict_mutation = False
  120. # Track Compiled Autograd final callbacks that must be called at the end of Compiled Autograd backward graph.
  121. # Only applicable if this graph is created from Dynamo tracing in Compiled Autograd.
  122. self.ca_final_callbacks_var: Optional[ListVariable] = None
  123. # Tracks VariableTracker objects whose mutations can be skipped.
  124. # For normal mutated variables, Dynamo generates code to replay/reconstruct
  125. # the mutations after graph execution. However, variables in this set have
  126. # their mutations ignored - the mutations happen during
  127. # execution but don't need to be replayed in the generated code.
  128. # Used for temporary mutations in contexts like torch.func.functional_call,
  129. # where module parameters/buffers are modified but later restored.
  130. self.ignore_mutation_on_these_variables: set[VariableTracker] = set()
  131. def ignore_mutations_on(self, var: VariableTracker) -> None:
  132. """Mutations to this variable will be executed but not not tracked,
  133. typically used for temporary mutations that are later restored."""
  134. self.ignore_mutation_on_these_variables.add(var)
  135. def stop_ignoring_mutations_on(self, var: VariableTracker) -> None:
  136. """Remove a variable from the skip mutation set, restoring normal mutation tracking."""
  137. if var in self.ignore_mutation_on_these_variables:
  138. self.ignore_mutation_on_these_variables.remove(var)
  139. def __eq__(self, other: object) -> bool:
  140. assert isinstance(other, SideEffects)
  141. # NB: do NOT test keepalive
  142. return (
  143. self.id_to_variable == other.id_to_variable
  144. and self.store_attr_mutations == other.store_attr_mutations
  145. and self.save_for_backward == other.save_for_backward
  146. and self.tensor_hooks == other.tensor_hooks
  147. )
  148. def diff(self, other: "SideEffects") -> Optional[str]:
  149. if self.id_to_variable != other.id_to_variable:
  150. sk_itv = self.id_to_variable.keys()
  151. ok_itv = other.id_to_variable.keys()
  152. if sk_itv != ok_itv:
  153. return f"id_to_variable keys: {sk_itv} != {ok_itv}"
  154. # Feel free to augment this with more fancy diffing logic
  155. # if needed for debugging
  156. return "id_to_variable: unknown diff"
  157. elif self.store_attr_mutations != other.store_attr_mutations:
  158. sk_sam = self.store_attr_mutations.keys()
  159. ok_sam = other.store_attr_mutations.keys()
  160. if sk_sam != ok_sam:
  161. return f"store_attr_mutations keys: {sk_sam} != {ok_sam}"
  162. return "store_attr_mutations: unknown diff"
  163. elif self.save_for_backward != other.save_for_backward:
  164. return "save_for_backward"
  165. elif self.tensor_hooks != other.tensor_hooks:
  166. return "tensor_hooks"
  167. else:
  168. return None
  169. def clone(self) -> "SideEffects":
  170. """Create a shallow copy"""
  171. ref = self.output_graph_weakref()
  172. assert ref is not None
  173. return self.__class__(
  174. output_graph=ref,
  175. id_to_variable=dict(self.id_to_variable),
  176. store_attr_mutations={
  177. k: dict(v) for k, v in self.store_attr_mutations.items()
  178. },
  179. keepalive=list(self.keepalive),
  180. save_for_backward=self.save_for_backward,
  181. tensor_hooks=self.tensor_hooks,
  182. )
  183. def __contains__(self, item: Any) -> bool:
  184. return id(item) in self.id_to_variable
  185. def __getitem__(self, item: Any) -> VariableTracker:
  186. return self.id_to_variable[id(item)]
  187. def should_allow_side_effects_under_checkpoint(self) -> bool:
  188. output_graph = self.output_graph_weakref()
  189. return bool(
  190. output_graph
  191. and output_graph.current_tx.output.current_tracer.under_activation_checkpoint
  192. and output_graph.current_tx.output.current_tracer.allow_side_effects_under_checkpoint
  193. )
  194. def should_allow_externally_visible_side_effects_in_subtracer(self) -> bool:
  195. output_graph = self.output_graph_weakref()
  196. return bool(
  197. output_graph
  198. and output_graph.current_tx.output.current_tracer.unsafe_allow_externally_visible_side_effects
  199. )
  200. def is_reconstructing_generator(self) -> bool:
  201. output_graph = self.output_graph_weakref()
  202. return bool(
  203. output_graph
  204. and output_graph.current_tx.output.current_tracer.is_reconstructing_generator
  205. )
  206. def check_allowed_side_effect(self, item: VariableTracker) -> bool:
  207. from torch._dynamo.variables.misc import AutogradFunctionContextVariable
  208. # People do things like self.dim = dim inside autograd.Function.
  209. # These are benign.
  210. if isinstance(item, AutogradFunctionContextVariable):
  211. return True
  212. if self.should_allow_externally_visible_side_effects_in_subtracer():
  213. return True
  214. if self.should_allow_side_effects_under_checkpoint():
  215. return True
  216. if self.is_reconstructing_generator():
  217. # This is missing the case where one mutates a tensor. See
  218. # test_generator.py::test_reconstruct_generator_tensor_mutation
  219. raise SideEffectsError(
  220. "Cannot reconstruct a generator with variable mutations. "
  221. "Dynamo needs to fully exhaust the generator, which may cause "
  222. "unintended variable modifications."
  223. )
  224. if not is_side_effect_safe(item.mutation_type):
  225. # TODO plumb HOP information here
  226. unimplemented_v2(
  227. gb_type="HigherOrderOperator: Mutating a variable not in the current scope (SideEffects)",
  228. context="",
  229. explanation="This is not supported.",
  230. hints=[],
  231. )
  232. return False
  233. def store_attr(
  234. self, item: VariableTracker, name: str, value: VariableTracker
  235. ) -> None:
  236. assert self.is_attribute_mutation(item)
  237. self.check_allowed_side_effect(item)
  238. if item not in self.store_attr_mutations:
  239. self.store_attr_mutations[item] = {}
  240. self.store_attr_mutations[item][name] = value
  241. def load_attr(
  242. self,
  243. item: VariableTracker,
  244. name: str,
  245. deleted_ok: bool = False,
  246. check: bool = False,
  247. ) -> VariableTracker:
  248. if check:
  249. assert self.is_attribute_mutation(item)
  250. result = self.store_attr_mutations[item][name]
  251. if not deleted_ok and isinstance(result, variables.DeletedVariable):
  252. unimplemented_v2(
  253. gb_type="Attempted to read a deleted variable",
  254. context=f"item: {item}, name: {name}",
  255. explanation="",
  256. hints=[*graph_break_hints.USER_ERROR],
  257. )
  258. return result
  259. def store_cell(self, cellvar: VariableTracker, value: VariableTracker) -> None:
  260. if cellvar.is_immutable():
  261. unimplemented_v2(
  262. gb_type="Write to immutable cell",
  263. context=f"cellvar: {cellvar}, value: {value}",
  264. explanation="Dynamo doesn't support writing to immutable/sourceless cell variables.",
  265. hints=[*graph_break_hints.DIFFICULT],
  266. )
  267. assert isinstance(cellvar, variables.CellVariable)
  268. assert isinstance(value, variables.VariableTracker)
  269. self.store_attr(cellvar, "cell_contents", value)
  270. def load_cell(self, cellvar: VariableTracker) -> VariableTracker:
  271. assert isinstance(cellvar, variables.CellVariable)
  272. if self.has_pending_mutation_of_attr(cellvar, "cell_contents"):
  273. return self.load_attr(cellvar, "cell_contents", check=False)
  274. if cellvar.pre_existing_contents:
  275. return cellvar.pre_existing_contents
  276. unimplemented_v2(
  277. gb_type="Read uninitialized cell",
  278. context=str(cellvar),
  279. explanation="Attempted to read a cell variable that has not been populated yet.",
  280. hints=[*graph_break_hints.USER_ERROR],
  281. )
  282. def load_global(self, gvar: VariableTracker, name: str) -> VariableTracker:
  283. assert isinstance(gvar, variables.VariableTracker)
  284. return self.load_attr(gvar, name)
  285. def store_global(
  286. self, gvar: VariableTracker, name: str, value: VariableTracker
  287. ) -> None:
  288. assert isinstance(gvar, variables.VariableTracker)
  289. assert isinstance(value, variables.VariableTracker)
  290. self.store_attr(gvar, name, value)
  291. @staticmethod
  292. def cls_supports_mutation_side_effects(cls: type) -> bool:
  293. return inspect.getattr_static(cls, "__getattribute__", None) in (
  294. object.__getattribute__,
  295. dict.__getattribute__,
  296. set.__getattribute__,
  297. frozenset.__getattribute__,
  298. int.__getattribute__,
  299. str.__getattribute__,
  300. list.__getattribute__,
  301. tuple.__getattribute__,
  302. BaseException.__getattribute__,
  303. )
  304. def is_attribute_mutation(self, item: VariableTracker) -> bool:
  305. return isinstance(item.mutation_type, AttributeMutation)
  306. def has_pending_mutation(self, item: VariableTracker) -> bool:
  307. return self.is_attribute_mutation(item) and bool(
  308. self.store_attr_mutations.get(item)
  309. )
  310. def has_pending_mutation_of_attr(self, item: VariableTracker, name: str) -> bool:
  311. return self.is_attribute_mutation(
  312. item
  313. ) and name in self.store_attr_mutations.get(item, ())
  314. def is_modified(self, item: VariableTracker) -> bool:
  315. if item.is_immutable():
  316. return False
  317. if isinstance(item.mutation_type, (AttributeMutationNew, ValueMutationNew)):
  318. return True
  319. if isinstance(item, variables.UserDefinedObjectVariable):
  320. # Checks if the underlying dict or tuple vt has been modified
  321. return item in self.store_attr_mutations or item.is_underlying_vt_modified(
  322. self
  323. )
  324. if self.is_attribute_mutation(item):
  325. return item in self.store_attr_mutations
  326. return item.mutation_type.is_modified # type: ignore[attr-defined]
  327. def _track_obj(
  328. self,
  329. item: Any,
  330. variable: VariableTracker,
  331. mutation_type_cls: type = ValueMutationExisting,
  332. ) -> VariableTracker:
  333. """Start tracking an existing or new variable for mutation"""
  334. if id(item) in self.id_to_variable:
  335. raise AssertionError(
  336. f"{variable} is already tracked for mutation. This could be "
  337. "because you are not using VariableBuilder to construct "
  338. "the variable tracker. "
  339. f"Source of new object: {variable.source}. "
  340. f"Source of previously tracked object: {self.id_to_variable[id(item)].source}."
  341. )
  342. variable.mutation_type = mutation_type_cls()
  343. self.id_to_variable[id(item)] = variable
  344. self.keepalive.append(item)
  345. return variable
  346. track_mutable = _track_obj
  347. def track_object_existing(
  348. self,
  349. item: Any,
  350. variable: VariableTracker,
  351. ) -> VariableTracker:
  352. return self._track_obj(
  353. item,
  354. variable,
  355. mutation_type_cls=AttributeMutationExisting,
  356. )
  357. def track_object_new(
  358. self,
  359. cls_source: Source,
  360. user_cls: Any,
  361. variable_cls: Any,
  362. options: dict[str, Any],
  363. ) -> VariableTracker:
  364. if user_cls is torch.autograd.function.FunctionCtx:
  365. with warnings.catch_warnings(record=True):
  366. obj = torch.autograd.Function()
  367. else:
  368. obj = object_new(user_cls)
  369. variable = variable_cls(
  370. obj,
  371. mutation_type=AttributeMutationNew(cls_source),
  372. **options,
  373. )
  374. self.id_to_variable[id(obj)] = variable
  375. self.keepalive.append(obj)
  376. return variable
  377. def get_variable_cls(self, user_cls: type) -> type:
  378. from torch.overrides import TorchFunctionMode
  379. from .variables.ctx_manager import GenericContextWrappingVariable
  380. from .variables.torch_function import TorchFunctionModeVariable
  381. from .variables.user_defined import is_forbidden_context_manager
  382. variable_cls: type[variables.UserDefinedObjectVariable] = (
  383. variables.UserDefinedObjectVariable
  384. )
  385. if issubclass(
  386. user_cls, TorchFunctionMode
  387. ) and TorchFunctionModeVariable.is_supported_torch_function_mode(user_cls):
  388. variable_cls = TorchFunctionModeVariable
  389. elif (
  390. hasattr(user_cls, "__enter__")
  391. and hasattr(user_cls, "__exit__")
  392. and not is_forbidden_context_manager(user_cls)
  393. ):
  394. variable_cls = GenericContextWrappingVariable
  395. elif issubclass(user_cls, torch.nn.Module):
  396. variable_cls = variables.UnspecializedNNModuleVariable
  397. elif issubclass(user_cls, (dict, collections.OrderedDict)):
  398. variable_cls = variables.UserDefinedDictVariable
  399. elif issubclass(user_cls, (set, frozenset)):
  400. variable_cls = variables.UserDefinedSetVariable
  401. elif issubclass(user_cls, tuple):
  402. variable_cls = variables.UserDefinedTupleVariable
  403. elif issubclass(user_cls, list):
  404. variable_cls = variables.UserDefinedListVariable
  405. elif issubclass(user_cls, MutableMapping):
  406. variable_cls = variables.MutableMappingVariable
  407. elif is_frozen_dataclass(user_cls):
  408. variable_cls = FrozenDataClassVariable
  409. elif issubclass(user_cls, BaseException):
  410. variable_cls = variables.UserDefinedExceptionObjectVariable
  411. assert issubclass(variable_cls, variables.UserDefinedObjectVariable)
  412. return variable_cls
  413. def get_example_value(
  414. self,
  415. base_cls_vt: VariableTracker,
  416. cls_vt: VariableTracker,
  417. init_args: list[VariableTracker],
  418. ) -> Any:
  419. user_cls = cls_vt.value # type: ignore[attr-defined]
  420. if issubclass(user_cls, torch.nn.Module):
  421. # TODO(anijain2305) - Is it possible to remove this specialization?
  422. obj = nn_module_new(user_cls)
  423. else:
  424. if isinstance(base_cls_vt, variables.BuiltinVariable):
  425. base_cls = base_cls_vt.fn
  426. elif isinstance(base_cls_vt, variables.UserDefinedClassVariable):
  427. base_cls = base_cls_vt.value
  428. else:
  429. raise RuntimeError(f"Unexpected base_cls_vt {base_cls_vt}")
  430. assert variables.UserDefinedClassVariable.is_supported_new_method(
  431. base_cls.__new__
  432. )
  433. # TODO(anijain2305) - Consider adding get_example_value method to
  434. # each VT to get an example value for all args. As we expand the
  435. # scope to other __new__ methods, we might need to call __new__ with
  436. # init_args (like functools.partial)
  437. # init_args = [arg.get_example_value() for arg in init_args]
  438. # obj = base_cls.__new__(user_cls, *init_args)
  439. obj = base_cls.__new__(user_cls)
  440. return obj
  441. def track_new_user_defined_object(
  442. self,
  443. base_cls_vt: VariableTracker,
  444. cls_vt: VariableTracker,
  445. init_args: list[VariableTracker],
  446. ) -> VariableTracker:
  447. """
  448. Creates a UserDefinedObjectVariable (or its subclass) variable tracker
  449. and mark it for attribute mutation tracking.
  450. Also records the variable trackers to call __new__ method on
  451. reconstruction. Roughly, the reconstruction looks like this
  452. base_cls_vt.__new__(user_cls, *init_args)
  453. """
  454. cls_source = cls_vt.source
  455. user_cls = cls_vt.value # type: ignore[attr-defined]
  456. variable_cls = self.get_variable_cls(user_cls)
  457. obj = self.get_example_value(base_cls_vt, cls_vt, init_args)
  458. variable = variable_cls(
  459. obj,
  460. cls_source=cls_vt.source,
  461. base_cls_vt=base_cls_vt,
  462. init_args=init_args,
  463. mutation_type=AttributeMutationNew(cls_source),
  464. )
  465. self.id_to_variable[id(obj)] = variable
  466. self.keepalive.append(obj)
  467. return variable
  468. def track_cell_new(
  469. self,
  470. ) -> VariableTracker:
  471. obj = object()
  472. variable = variables.CellVariable(
  473. mutation_type=AttributeMutationNew(),
  474. )
  475. self.id_to_variable[id(obj)] = variable
  476. self.keepalive.append(obj)
  477. return variable
  478. def track_cell_existing(
  479. self, source: Optional[Source], cell: CellType, contents: VariableTracker
  480. ) -> VariableTracker:
  481. variable = variables.CellVariable(
  482. # We don't support mutation to cell without source because we need
  483. # source to properly codegen the mutations.
  484. mutation_type=None if source is None else AttributeMutationExisting(),
  485. pre_existing_contents=contents,
  486. source=source,
  487. )
  488. self.id_to_variable[id(cell)] = variable
  489. self.keepalive.append(cell)
  490. return variable
  491. def track_global_existing(self, source: Source, item: Any) -> VariableTracker:
  492. variable = variables.NewGlobalVariable(
  493. mutation_type=AttributeMutationExisting(),
  494. source=source,
  495. )
  496. self.id_to_variable[id(item)] = variable
  497. self.keepalive.append(item)
  498. return variable
  499. def track_save_for_backward(
  500. self, ctx: VariableTracker, args: list[VariableTracker]
  501. ) -> None:
  502. assert isinstance(ctx, variables.AutogradFunctionContextVariable)
  503. self.save_for_backward.append((ctx, args))
  504. def track_runahead_tensor_and_symvar_side_effects(
  505. self, other: "SideEffects"
  506. ) -> None:
  507. # In higher order ops we want to keep track of tensors seen in the
  508. # speculate_subgraph so that we don't lift them again as a new input in
  509. # other speculate_subgraph or in the root tracer.
  510. for other_item in other.keepalive:
  511. other_id = id(other_item)
  512. other_variable = other.id_to_variable[other_id]
  513. if other_id not in self.id_to_variable and isinstance(
  514. other_variable, (variables.TensorVariable, variables.SymNodeVariable)
  515. ):
  516. self.track_object_existing(other_item, other_variable)
  517. def prune_dead_object_new(self, tx: "InstructionTranslatorBase") -> None:
  518. # Avoid VT cycles from e.g., recursive function.
  519. visited: set[VariableTracker] = set()
  520. live_new_objects: set[VariableTracker] = set()
  521. def visit(var: VariableTracker) -> None:
  522. if var in visited:
  523. return
  524. visited.add(var)
  525. # Object may have been mutated, store this mutation.
  526. if isinstance(var.mutation_type, AttributeMutationNew):
  527. live_new_objects.add(var)
  528. # It's possible that we have mutated the value of this variable
  529. # to be another one. The new value is in store_attr_mutations.
  530. # Also recurse through the new value to detect alive AttributeMutationNew.
  531. if var in self.store_attr_mutations:
  532. VariableTracker.visit(
  533. visit, # noqa: F821
  534. self.store_attr_mutations[var],
  535. )
  536. def is_live(var: VariableTracker) -> bool:
  537. if isinstance(var.mutation_type, AttributeMutationNew):
  538. return var in live_new_objects
  539. return True
  540. pre_existing_vars = [
  541. var
  542. for var in self.id_to_variable.values()
  543. if not isinstance(var.mutation_type, AttributeMutationNew)
  544. ]
  545. # The only live side effects come from returns (tx.stack), any intermediates
  546. # during a graph break (tx.symbolic_locals), and mutation on pre-existing variables.
  547. # Recursively visit Variables and see if any of them have been mutated.
  548. init_live_vars = []
  549. # gather stack/symbolic_locals for all tx's up the chain
  550. cur_tx: Optional[InstructionTranslatorBase] = tx
  551. while cur_tx is not None:
  552. init_live_vars.extend([cur_tx.stack, cur_tx.symbolic_locals])
  553. cur_tx = cur_tx.parent
  554. VariableTracker.visit(
  555. visit,
  556. # TODO track from all possible sources.
  557. init_live_vars
  558. + [
  559. pre_existing_vars,
  560. tx.output.backward_state,
  561. self.tensor_hooks,
  562. ],
  563. )
  564. # Manually release the self-referential function, which indirectly
  565. # captures certain `VariableTracker` and affects parts of PT test/logic
  566. # that are sensitive to when certain objects get released.
  567. del visit
  568. # NB: cell variable handling.is tricky.
  569. # cell variables must stay alive if any NestedUserFunctionVariable
  570. # are live. "visit"-ing the NestedUserFunctionVariable visits
  571. # the .closures field, from which we will see if we need to keep
  572. # any mutations to cell variables alive.
  573. self.id_to_variable = {
  574. k: v for k, v in self.id_to_variable.items() if is_live(v)
  575. }
  576. self.store_attr_mutations = {
  577. k: v for k, v in self.store_attr_mutations.items() if is_live(k)
  578. }
  579. def mutation(self, var: VariableTracker) -> None:
  580. if var in self.ignore_mutation_on_these_variables:
  581. return
  582. self.check_allowed_side_effect(var)
  583. if isinstance(var.mutation_type, ValueMutationExisting):
  584. var.mutation_type.is_modified = True
  585. if (
  586. var.source
  587. and isinstance(var, variables.ConstDictVariable)
  588. and not isinstance(var, variables.SetVariable)
  589. ):
  590. self._has_existing_dict_mutation = True
  591. def has_existing_dict_mutation(self) -> bool:
  592. return self._has_existing_dict_mutation
  593. def _get_modified_vars(self) -> list[VariableTracker]:
  594. return [var for var in self.id_to_variable.values() if self.is_modified(var)]
    def codegen_save_tempvars(self, cg: PyCodegen) -> None:
        """Give modified, newly constructed objects a source codegen can use.

        Every modified variable whose mutation type is not
        ``AttributeMutationNew`` must already have a source and is skipped.
        For the ``AttributeMutationNew`` ones:

        * non-root-frame cells get a ``make_cell()`` call stored into a temp;
          root-frame cells only get a ``LocalCellSource`` if they lack a source;
        * plain tensors are left alone, but tensor-subclass objects
          (``TensorWithTFOverrideVariable``) are reconstructed into a temp;
        * ``AutogradFunctionContextVariable`` is reported via
          ``unimplemented_v2``;
        * everything else is rebuilt as ``base_cls.__new__(user_cls, *args)``
          (or ``utils.object_new``) and stored into a temp.

        A temp stored via ``cg.add_cache`` becomes the variable's
        ``LocalSource``. Finally, one ``ctx.save_for_backward(*args)`` call is
        emitted per recorded ``save_for_backward`` entry.

        Args:
            cg: the bytecode generator that instructions are appended to.
        """
        # We must codegen modified VT to their source by default, so that
        # mutation and aliasing are properly accounted for.
        #
        # Since newly constructed objects don't have a source, we manually
        # codegen their construction and store them to a newly assigned local
        # source. Note that `ValueMutationNew` isn't tracked by SideEffects.
        for var in self._get_modified_vars():
            if not isinstance(var.mutation_type, AttributeMutationNew):
                # Pre-existing objects are reachable through their source, so
                # nothing needs to be constructed for them here.
                assert var.source is not None
                continue
            if isinstance(var, variables.CellVariable):
                # Cells created in the root frame are created either by
                # `MAKE_CELL` or by them being in `co_cellvars`, so we only emit
                # `make_cell` for the non-root-frame cells here.
                # TODO generalize this so we never need to call `make_cell`.
                if var.local_name is None:
                    # Emit `make_cell()` and stash the new cell in a temp local.
                    cg.add_push_null(
                        lambda: cg.load_import_from(utils.__name__, "make_cell")
                    )
                    cg.extend_output(create_call_function(0, False))
                    cg.add_cache(var)
                    var.source = LocalSource(cg.tempvars[var])  # type: ignore[attr-defined]
                elif var.source is None:
                    var.source = LocalCellSource(var.local_name)
            elif isinstance(var, variables.TensorVariable):
                # NOTE: for historical reasons we never assigned local sources
                # to newly constructed tensor object, so we keep it that way.
                # They are always loaded from output of the fx graph, so one can
                # think of it as having a "OutputGraphSource" for codegen
                # purposes.
                #
                # However, tensor subclass objects are different, because the
                # reconstruction logic in `PyCodegen` loads the data tensor from
                # graph output and then calls `as_subclass`, meaning we must
                # assign a source to it to ensure we only reconstruct one
                # subclass instance.
                if isinstance(
                    var, variables.torch_function.TensorWithTFOverrideVariable
                ):
                    # Don't codegen from temp source assigned from the 1st pass.
                    cg(var, allow_cache=False)
                    cg.add_cache(var)
                    # `add_cache` generates STORE and consumes TOS, but we never
                    # cleared it. TODO move this call into `add_cache`
                    cg.clear_tos()
                    var.source = LocalSource(cg.tempvars[var])
            elif isinstance(var, variables.AutogradFunctionContextVariable):
                unimplemented_v2(
                    gb_type="AutogradFunctionContextVariable escaped Dynamo-traced region",
                    context="",
                    explanation="We cannot reconstruct a torch.autograd.Function's context object.",
                    hints=[],
                )
            else:
                # Reconstruct the bytecode for
                # base_cls.__new__(user_cls, *args)
                if isinstance(var, variables.UserDefinedObjectVariable):

                    # Pushes `base_cls.__new__`; invoked right away by
                    # `add_push_null`, so capturing the loop variable is fine.
                    def load_new_method() -> None:
                        assert var.base_cls_vt is not None
                        cg(var.base_cls_vt)  # type: ignore[attr-defined]
                        cg.extend_output([cg.create_load_attr("__new__")])

                    cg.add_push_null(load_new_method)
                else:
                    cg.add_push_null(
                        lambda: cg.load_import_from(utils.__name__, "object_new")
                    )
                assert var.mutation_type.cls_source is not None
                cg(var.mutation_type.cls_source)
                # Generate the args to the __new__ method
                for arg in var.init_args:  # type: ignore[attr-defined]
                    cg(arg)
                # Call the __new__ method
                # (+1 accounts for the class pushed above as the first argument)
                cg.extend_output(create_call_function(1 + len(var.init_args), False))  # type: ignore[attr-defined]
                cg.add_cache(var)
                var.source = LocalSource(cg.tempvars[var])
        for ctx, args in self.save_for_backward:
            # Emit `ctx.save_for_backward(*args)` and discard its return value.
            cg(ctx.source)
            cg.load_method("save_for_backward")
            for arg in args:
                cg(arg)
            cg.extend_output(
                [
                    *create_call_method(len(args)),
                    create_instruction("POP_TOP"),
                ]
            )
  682. def register_hook(
  683. self,
  684. tensor: "variables.TensorVariable",
  685. hook: VariableTracker,
  686. handle: "variables.RemovableHandleVariable",
  687. name: str,
  688. ) -> None:
  689. assert isinstance(tensor, variables.TensorVariable)
  690. assert isinstance(hook, variables.VariableTracker)
  691. assert (
  692. isinstance(handle, variables.RemovableHandleVariable)
  693. and handle.is_mutable()
  694. )
  695. assert hasattr(torch.Tensor, name)
  696. idx = len(self.tensor_hooks.keys())
  697. # duplicate index possible because of self.remove_hook()
  698. while idx in self.tensor_hooks:
  699. idx += 1
  700. self.tensor_hooks[idx] = (tensor, hook, handle, name)
  701. assert not handle.idx
  702. handle.idx = idx
  703. def remove_hook(self, idx: int) -> None:
  704. del self.tensor_hooks[idx]
    def codegen_hooks(self, cg: PyCodegen) -> None:
        """Emit bytecode that re-registers every recorded tensor hook.

        For each ``(tensor, hook, handle, name)`` entry in ``tensor_hooks`` this
        emits the equivalent of ``tensor.<name>(hook)``, where ``name`` is the
        registration method recorded by ``register_hook``. The returned value is
        bound to ``handle`` via ``cg.add_cache``. Only tensors that have a
        source are supported here; this is asserted.

        Args:
            cg: the bytecode generator that instructions are appended to.
        """
        for (
            tensor,
            hook,
            handle,
            name,
        ) in self.tensor_hooks.values():
            # Note: [On tensor.register_hook]
            #
            # register_hook on a tensor, AKA backward hooks, have slightly nuanced differences in how they are implemented
            # when it comes to hooks on objects with sources (inputs, params) vs objects without sources (intermediaries).
            #
            # For tensors with a source, we bypass direct inclusion of register_hook calls in the graph.
            # Instead, these are tracked and stashed as a global variable, enabling their association with tensors in
            # the residuals. During dynamo's frame creation, these hooks are invoked seamlessly on known reconstructible/fetch-able
            # tensors. Because a source indicates knowledge of this object outside the torch compile region, and
            # because we are running residuals firmly before .backward() can be run, it is sound to invoke
            # `register_hook` on a known tensor.
            #
            # For tensors without a source, we support a limited subset of hooks. Global functions only, and
            # compiled_autograd must be enabled or we will graph break.
            #
            # Handling the Handle: When a user retains the register_hook result in a handle, we intercept the
            # STORE_FAST operation to record the user-designated local variable name. This ensures the reconstructed
            # bytecode retains this name. If no handle is defined, we simply pop the generated value to keep the
            # stack intact.
            #
            # Dynamo Tensor Hooks Workflow:
            # - Functions passed to register_hook are lifted globally.
            # - For tensors with sources:
            #   - In the "side_effects" phase of codegen, we iterate over tensors with hooks to:
            #     - Generate the tensor.
            #     - Issue a register_hook call on the tensor, linking to the globally stored function.
            #     - Incorporate a handle if one was established in the eager phase.
            # - For tensors without sources:
            #   - We don't generate any instructions for registering a hook.
            #   - Handles from intermediary hooks are NYI.
            #   - We produce a call function that utilizes the trace_wrapped higher order op, closing over it.
            #   - We then manually insert the call function above into the graph.
            # - The handle's exact user-specified name, "user_code_variable_name", is discerned and associated during STORE_FAST.
            assert tensor.source, "Hooks on non input tensors NYI - should not get here"

            # Pushes the bound registration method `tensor.<name>`. It is
            # invoked immediately by `add_push_null`, so closing over the loop
            # variables is safe.
            def gen_fn() -> None:
                cg(tensor)
                cg.extend_output([cg.create_load_attr(name)])

            cg.add_push_null(gen_fn)
            cg(hook)
            # Call the registration method with the hook as its only argument.
            cg.extend_output(create_call_function(1, False))
            # Adding the handle to the cache means RemovableHandleVariable().reconstruct() will
            # be associated with the return value of register_hook(). This consumes the top of stack.
            cg.add_cache(handle)
  755. def get_ca_final_callbacks_var(self) -> "variables.ListVariable":
  756. from .variables.base import ValueMutationNew
  757. if self.ca_final_callbacks_var is None:
  758. self.ca_final_callbacks_var = variables.ListVariable(
  759. [], mutation_type=ValueMutationNew()
  760. )
  761. return self.ca_final_callbacks_var
    def codegen_update_mutated(self, cg: PyCodegen) -> None:
        """Emit bytecode that writes every tracked mutation back to real objects.

        This works in two phases so that all actual mutations happen at the
        very end, which handles dependencies between them:

        1. For each modified variable, push the operands its mutation needs
           (target object, reconstructed new value, bound method, ...). The
           instructions that consume those operands are queued in
           ``suffixes``. Some branches (the torch-function mode stack, list
           iterators) emit their full instruction sequence immediately instead.
        2. Emit the queued suffixes in reverse order. Operands were pushed in
           forward order, so after reversing, each suffix pops exactly the
           operands that were pushed for it.

        Args:
            cg: the bytecode generator that instructions are appended to.

        Raises:
            AssertionError: if a modified variable has a type that no branch
                below knows how to write back.
        """
        # Each entry is an instruction list that consumes operands pushed in
        # the loop below; they are emitted in reverse order at the end.
        suffixes = []
        for var in self._get_modified_vars():
            if isinstance(var, variables.ListVariable):
                # old[:] = new
                cg(var, allow_cache=False)  # Don't codegen via source
                cg(var.source)  # type: ignore[attr-defined]
                cg.extend_output(
                    [
                        cg.create_load_const(None),
                        cg.create_load_const(None),
                        create_instruction("BUILD_SLICE", arg=2),
                    ]
                )
                # Stack is now [new, old, slice(None, None)], so STORE_SUBSCR
                # performs old[slice(None, None)] = new.
                suffixes.append([create_instruction("STORE_SUBSCR")])
            elif isinstance(var, variables.lists.DequeVariable):
                # For limited maxlen, the order of operations matters for side
                # effect, but we currently don't track the order, so no support.
                if not (
                    isinstance(var.maxlen, variables.ConstantVariable)
                    and var.maxlen.value is None
                ):
                    unimplemented_v2(
                        gb_type="Side effect on existing deque with limited maxlen",
                        context="",
                        explanation="This is not supported.",
                        hints=[
                            "Don't use a deque with `maxlen` specified.",
                        ],
                    )
                # old.extend(new), this runs last
                cg(var.source)
                cg.load_method("extend")
                cg(var, allow_cache=False)  # Don't codegen via source
                suffixes.append(
                    [
                        *create_call_method(1),
                        create_instruction("POP_TOP"),
                    ]
                )
                # old.clear(), this runs first
                cg(var.source)
                cg.load_method("clear")
                suffixes.append(
                    [
                        *create_call_method(0),
                        create_instruction("POP_TOP"),
                    ]
                )
            elif isinstance(var, variables.ConstDictVariable):
                # Reconstruct works as follow:
                # (1) Skip codegen if there are no new items
                # (2) codegen(...) each pair of key/value
                # (3) create a new dictionary with the pairs of key/values above
                # (4) clear the original dictionary
                #   + only if a key was removed from the input dict
                # (5) update the original dictionary with the dict created in (3)
                if var.has_new_items():
                    cg(var.source)  # type: ignore[attr-defined]
                    cg.load_method("update")
                    cg(var, allow_cache=False)  # Don't codegen via source
                    if var.should_reconstruct_all:
                        cg(var.source)  # type: ignore[attr-defined]
                        cg.load_method("clear")
                    suffixes.append(
                        [
                            *create_call_method(1),  # update
                            create_instruction("POP_TOP"),
                        ]
                    )
                    if var.should_reconstruct_all:
                        # clear will appear before "update" as the suffixes are
                        # applied in reverse order.
                        suffixes.append(
                            [
                                *create_call_method(0),  # clear
                                create_instruction("POP_TOP"),
                            ]
                        )
            elif isinstance(
                var, variables.torch_function.TorchFunctionModeStackVariable
            ):
                # Needed in the finally block for stack restoration
                # Save the current mode stack into a freshly declared local...
                cg.add_push_null(
                    lambda: cg.load_import_from(
                        utils.__name__, "get_torch_function_mode_stack"
                    )
                )
                cg.call_function(0, False)
                name = variables.torch_function.get_prev_stack_var_name()
                cg.code_options["co_varnames"] += (name,)
                cg.append_output(create_instruction("STORE_FAST", argval=name))
                # ...then replace it with the traced (symbolic) mode stack.
                # This is emitted in full right away; no suffix is queued.
                cg.add_push_null(
                    lambda: cg.load_import_from(
                        utils.__name__, "set_torch_function_mode_stack"
                    )
                )
                cg.foreach(var.symbolic_stack)
                cg.append_output(
                    create_instruction("BUILD_LIST", arg=len(var.symbolic_stack))
                )
                cg.call_function(1, False)
                cg.append_output(create_instruction("POP_TOP"))
            elif isinstance(var, variables.CellVariable) and var.local_name is not None:
                # Emit more readable and performant bytecode.
                # TODO generalize this for cells created during inlining.
                if var in self.store_attr_mutations:
                    # Push the new cell contents; the queued STORE_DEREF writes
                    # them into the named cell.
                    contents_var = self.load_cell(var)
                    cg(contents_var)
                    suffixes.append([cg.create_store_deref(var.local_name)])
            elif self.is_attribute_mutation(var):
                if isinstance(
                    var, variables.UserDefinedDictVariable
                ) and self.is_modified(var._dict_vt):
                    # Do dict related update manually here. The store_attr
                    # mutations will be applied later.
                    # Map each local name of the `_manual_dict_setitem` template
                    # to a fresh variable name in the output code.
                    varname_map = {}
                    for name in _manual_dict_setitem.__code__.co_varnames:
                        varname_map[name] = cg.tx.output.new_var()
                    # Index of the dict base class in the MRO, preferring
                    # OrderedDict when it is present.
                    try:
                        mro_index = type(var.value).__mro__.index(
                            collections.OrderedDict
                        )
                    except ValueError:
                        mro_index = type(var.value).__mro__.index(dict)
                    cg.extend_output(
                        [
                            create_instruction("LOAD_CONST", argval=mro_index),
                            create_instruction(
                                "STORE_FAST", argval=varname_map["mro_index"]
                            ),
                        ]
                    )
                    cg(var.source)  # type: ignore[attr-defined]
                    cg.extend_output(
                        [
                            create_instruction(
                                "STORE_FAST", argval=varname_map["dict_to"]
                            )
                        ]
                    )
                    cg(var._dict_vt, allow_cache=False)  # Don't codegen via source
                    cg.extend_output(
                        [
                            create_instruction(
                                "STORE_FAST", argval=varname_map["dict_from"]
                            )
                        ]
                    )
                    dict_update_insts = bytecode_from_template(
                        _manual_dict_setitem, varname_map=varname_map
                    )
                    suffixes.append(
                        [
                            *dict_update_insts,
                            create_instruction("POP_TOP"),
                        ]
                    )
                elif isinstance(
                    var, variables.UserDefinedListVariable
                ) and self.is_modified(var._list_vt):
                    # Update the list to the updated items. Be careful in
                    # calling the list methods and not the overridden methods.
                    varname_map = {}
                    for name in _manual_list_update.__code__.co_varnames:
                        varname_map[name] = cg.tx.output.new_var()
                    cg(var.source)  # type: ignore[attr-defined]
                    cg.extend_output(
                        [
                            create_instruction(
                                "STORE_FAST", argval=varname_map["list_to"]
                            )
                        ]
                    )
                    cg(var._list_vt, allow_cache=False)  # Don't codegen via source
                    cg.extend_output(
                        [
                            create_instruction(
                                "STORE_FAST", argval=varname_map["list_from"]
                            )
                        ]
                    )
                    list_update_insts = bytecode_from_template(
                        _manual_list_update, varname_map=varname_map
                    )
                    suffixes.append(
                        [
                            *list_update_insts,
                            create_instruction("POP_TOP"),
                        ]
                    )
                # Applying mutations involves two steps: 1) Push all
                # reconstructed objects onto the stack. 2) Call STORE_ATTR to
                # apply the mutations.
                #
                # Dynamo must ensure that mutations are applied in the same
                # order as in the original program. Therefore, two reverse
                # operations occur below.
                #
                # The first reverse operation concerns `suffixes`. We apply
                # suffixes in reverse order due to the way Python handles the
                # stack. In Step 1, we push all reconstructed objects onto the
                # stack, but the item at the top of the stack refers to the last
                # attribute in the mutation order. If not fixed, this will apply
                # the mutations of attributes in the reverse order. To account
                # for this reversal, we iterate through the mutable attributes
                # in reverse order.
                for name, value in reversed(
                    self.store_attr_mutations.get(var, {}).items()
                ):
                    if isinstance(var, variables.NewGlobalVariable):
                        cg.tx.output.update_co_names(name)
                        cg(value)
                        assert isinstance(var.source, GlobalSource)  # type: ignore[attr-defined]
                        suffixes.append(
                            [create_instruction("STORE_GLOBAL", argval=name)]
                        )
                    elif isinstance(value, variables.DeletedVariable):
                        # Only emit `del obj.name` when the real, pre-existing
                        # object actually has that attribute.
                        if isinstance(
                            var.mutation_type, AttributeMutationExisting
                        ) and hasattr(getattr(var, "value", None), name):
                            cg.tx.output.update_co_names(name)
                            cg(var.source)
                            suffixes.append(
                                [create_instruction("DELETE_ATTR", argval=name)]
                            )
                    elif isinstance(
                        var, variables.UserDefinedObjectVariable
                    ) and var.should_skip_descriptor_setter(name):
                        # object_setattr_ignore_descriptor(obj, name, value)
                        cg.add_push_null(
                            lambda: cg.load_import_from(
                                utils.__name__, "object_setattr_ignore_descriptor"
                            )
                        )
                        cg(var.source)  # type: ignore[attr-defined]
                        cg(variables.ConstantVariable(name))
                        cg(value)
                        suffixes.append(
                            [
                                *create_call_function(3, False),
                                create_instruction("POP_TOP"),
                            ]
                        )
                    elif (
                        isinstance(var, variables.UserDefinedObjectVariable)
                        and var.needs_slow_setattr()
                    ):
                        # __setattr__ is defined on this object, so call object.__setattr__ directly
                        cg.load_import_from("builtins", "object")
                        cg.load_method("__setattr__")
                        cg(var.source)  # type: ignore[attr-defined]
                        cg(variables.ConstantVariable(name))
                        cg(value)
                        suffixes.append(
                            [*create_call_method(3), create_instruction("POP_TOP")]
                        )
                    else:
                        # Plain `obj.name = value`: STORE_ATTR pops the object
                        # (TOS) and the value (TOS1).
                        cg.tx.output.update_co_names(name)
                        cg(value)
                        cg(var)
                        suffixes.append([create_instruction("STORE_ATTR", argval=name)])
            elif isinstance(var, variables.ListIteratorVariable):
                # Call `utils.iter_next` on the real iterator `var.index` times
                # and discard each result (presumably to advance it to the
                # traced position). This is emitted immediately; no suffix.
                for _ in range(var.index):
                    cg.add_push_null(
                        lambda: cg.load_import_from(utils.__name__, "iter_next")
                    )
                    cg(var.source)  # type: ignore[attr-defined]
                    cg.call_function(1, False)
                    cg.pop_top()
            elif isinstance(var, variables.RandomVariable):
                # set correct random seed state
                # i.e. `<source>.setstate(<traced state>)`
                def gen_fn() -> None:
                    cg(var.source)  # type: ignore[attr-defined]
                    cg.load_attr("setstate")

                cg.add_push_null(gen_fn)
                cg(var.wrap_state(var.random.getstate()))
                suffixes.append(
                    [
                        *create_call_function(1, False),  # setstate
                        create_instruction("POP_TOP"),
                    ]
                )
            else:
                raise AssertionError(type(var))
        # do all the actual mutations at the very end to handle dependencies
        for suffix in reversed(suffixes):
            cg.extend_output(suffix)
  1049. def is_empty(self) -> bool:
  1050. return not (
  1051. any(map(self.is_modified, self.id_to_variable.values()))
  1052. or self.tensor_hooks
  1053. or self.save_for_backward
  1054. or self.tensor_hooks
  1055. )
  1056. def clear(self) -> None:
  1057. self.keepalive.clear()
  1058. self.id_to_variable.clear()
  1059. @contextlib.contextmanager
  1060. def allow_side_effects_under_checkpoint(
  1061. tx: "InstructionTranslatorBase",
  1062. ) -> Generator[None, None, None]:
  1063. assert tx.output.current_tracer.under_activation_checkpoint
  1064. orig_val = tx.output.current_tracer.allow_side_effects_under_checkpoint
  1065. try:
  1066. tx.output.current_tracer.allow_side_effects_under_checkpoint = True
  1067. yield
  1068. finally:
  1069. tx.output.current_tracer.allow_side_effects_under_checkpoint = orig_val
  1070. @contextlib.contextmanager
  1071. def allow_externally_visible_side_effects_in_subtracer(
  1072. tx: "InstructionTranslatorBase",
  1073. ) -> Generator[None, None, None]:
  1074. orig_val = tx.output.current_tracer.unsafe_allow_externally_visible_side_effects
  1075. try:
  1076. tx.output.current_tracer.unsafe_allow_externally_visible_side_effects = True
  1077. yield
  1078. finally:
  1079. tx.output.current_tracer.unsafe_allow_externally_visible_side_effects = orig_val
  1080. @contextlib.contextmanager
  1081. def disallow_side_effects_in_generator(
  1082. tx: "InstructionTranslatorBase",
  1083. ) -> Generator[None, None, None]:
  1084. orig_val = tx.output.current_tracer.is_reconstructing_generator
  1085. try:
  1086. tx.output.current_tracer.is_reconstructing_generator = True
  1087. yield
  1088. finally:
  1089. tx.output.current_tracer.is_reconstructing_generator = orig_val