codegen.py 29 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693694695696697698699700701702703704705706707708709710711712713714715716717718719720
  1. """
  2. This module provides utilities for generating Python bytecode in PyTorch's Dynamo system.
  3. It includes functionality for:
  4. - Constructing bytecode sequences for Python operations
  5. - Managing stack operations and variable tracking
  6. - Handling graph outputs and their conversions
  7. - Supporting different Python versions (3.11+, 3.12+, 3.13+)
  8. - Converting high-level operations to low-level bytecode instructions
  9. - Managing constant loading and attribute access
  10. - Supporting function creation and closure handling
  11. """
  12. import collections
  13. import dataclasses
  14. import re
  15. import sys
  16. import types
  17. from collections import Counter
  18. from collections.abc import Iterable
  19. from typing import Any, Callable, Optional, TYPE_CHECKING, Union
  20. import torch.nn
  21. from torch.utils._ordered_set import OrderedSet
  22. from . import config, graph_break_hints, utils
  23. from .bytecode_transformation import (
  24. add_push_null,
  25. add_push_null_call_function_ex,
  26. create_call_function,
  27. create_call_method,
  28. create_dup_top,
  29. create_instruction,
  30. create_load_const,
  31. create_load_method,
  32. create_rot_n,
  33. Instruction,
  34. )
  35. from .exc import IncorrectUsage, unimplemented_v2
  36. from .source import AttrSource, ChainedSource, DictGetItemSource, Source
  37. from .utils import is_safe_constant, rot_n_helper
  38. from .variables.base import ValueMutationExisting, VariableTracker
  39. from .variables.functions import (
  40. ContextlibContextManagerLocalGeneratorObjectVariable,
  41. LocalGeneratorObjectVariable,
  42. )
  43. from .variables.nn_module import NNModuleVariable
  44. from .variables.tensor import (
  45. NumpyNdarrayVariable,
  46. SymNodeVariable,
  47. TensorVariable,
  48. UnspecializedPythonVariable,
  49. )
  50. from .variables.torch_function import TensorWithTFOverrideVariable
  51. if TYPE_CHECKING:
  52. from torch._dynamo.variables.builder import GraphArg
  53. from .symbolic_convert import InstructionTranslatorBase
  54. @dataclasses.dataclass
  55. class GraphOutputEntry:
  56. index: int
  57. variable: VariableTracker
  58. class PyCodegen:
  59. """
  60. Helper class uses for constructing Python bytecode
  61. """
  62. def __init__(
  63. self,
  64. tx: "InstructionTranslatorBase",
  65. root: Optional[torch.nn.Module] = None,
  66. graph_output_var: Optional[str] = None,
  67. tempvars: Optional[dict[Union[VariableTracker, Source], Any]] = None,
  68. overridden_sources: Optional[dict[Source, Source]] = None,
  69. ) -> None:
  70. self.root = root
  71. self.top_of_stack: Optional[Union[VariableTracker, Source]] = None
  72. self.uses: Counter[Union[VariableTracker, Source]] = collections.Counter()
  73. self.graph_outputs: dict[int, GraphOutputEntry] = {}
  74. self._output: list[Instruction] = []
  75. # This determines which VariableTracker/Source should be stored as
  76. # locals, and maps the VariableTracker/Source to the local variable
  77. # name. Note that it could map to None initially, in which case we'll
  78. # overwrite it to map to real temporary names via `add_cache`.
  79. self.tempvars: dict[Union[VariableTracker, Source], Any] = tempvars or {}
  80. self.tx = tx
  81. self.graph_output_var = graph_output_var
  82. self.code_options = self.tx.output.code_options
  83. self.cell_and_freevars = self.tx.cell_and_freevars
  84. self.new_var = self.tx.output.new_var
  85. self.value_from_source: bool = True
  86. # This serves as a way for codegen to use a different source; we need
  87. # this because sometimes we can't easily modify the original source
  88. # without affecting other components, e.g., guards.
  89. self.overridden_sources: dict[Source, Source] = overridden_sources or {}
  90. def restore_stack(
  91. self, stack_values: list[Any], *, value_from_source: bool = True
  92. ) -> None:
  93. prev = self.value_from_source
  94. self.value_from_source &= value_from_source
  95. try:
  96. self.foreach(stack_values)
  97. finally:
  98. self.value_from_source = prev
  99. def graph_output_vars(self) -> list[VariableTracker]:
  100. return [x.variable for x in self.graph_outputs.values()]
  101. def call_reconstruct(
  102. self, value: Union[VariableTracker, Source, "GraphArg"]
  103. ) -> None:
  104. res = value.reconstruct(self)
  105. assert res is None, f"reconstruct!=None {value}"
  106. def add_push_null(
  107. self, gen_fn: Callable[[], None], call_function_ex: bool = False
  108. ) -> None:
  109. """
  110. `gen_fn` generates instructions via PyCodegen methods
  111. that push a single callable to the stack.
  112. `add_push_null` pushes a NULL to the stack before or after the
  113. instructions generated by `gen_fn`, depending on Python version.
  114. Will attempt to use the NULL push bit for instructions
  115. with such bits (LOAD_GLOBAL 3.11+, LOAD_ATTR 3.12+, LOAD_SUPER_ATTR).
  116. """
  117. old_len = len(self._output)
  118. if sys.version_info < (3, 13):
  119. # gen_fn may DUP_TOP instead if TOS is not cleared.
  120. # Will cause problems since NULL will be pushed right
  121. # before the generated instructions in <= 3.12
  122. self.clear_tos()
  123. gen_fn()
  124. # inplace modify self._output
  125. added_insts = self._output[old_len:]
  126. del self._output[old_len:]
  127. if call_function_ex:
  128. self._output.extend(add_push_null_call_function_ex(added_insts))
  129. else:
  130. self._output.extend(add_push_null(added_insts))
  131. if sys.version_info >= (3, 13):
  132. # NULL will be at top of stack
  133. self.clear_tos()
  134. def __call__(
  135. self, value: Union[VariableTracker, Source], allow_cache: bool = True
  136. ) -> None:
  137. """
  138. Generate code such that top-of-stack (TOS) is set to value.
  139. `allow_cache` controls the behavior in the following manner. `value` can
  140. either be a VariableTracker or a Source.
  141. If `value` is a `Source`, `allow_cache` must be True (invariant asserted
  142. below). If the source was reconstructed earlier, we will reuse the
  143. generated code by loading from top of stack or tempvars.
  144. If `value` is a `VariableTracker`, we have the following cases:
  145. 1) `allow_cache=True`
  146. a) If the value.source is not None, we will emit the code based on
  147. `value.source` to handle aliasing.
  148. b) If value.source is None (example reconstructing a local list
  149. returned by the compiled function), we will reconstruct the variable
  150. tracker (w/o any source) to emit bytecode that generates a new
  151. python object.
  152. In both cases of value.source being None or not, if the value was
  153. reconstructed earlier, we will reuse the generated code by loading from
  154. top of stack or tempvars.
  155. 2) `allow_cache=False` - This is a special case (allow_cache defaults to
  156. True).
  157. a) If the value.source is not None, we reconstruct the variable
  158. tracker and emit a new python object. You might wonder what about
  159. aliasing? The place where we use this config also has the followup
  160. code where the original python object is assigned to this new python
  161. value to handle aliasing (check side_effects.py and search for
  162. allow_cache=False).
  163. b) If value.source is None, this is not allowed. TODO - assert this.
  164. Notable effects:
  165. 1. `self.top_of_stack` will be set to `value`, if we don't codegen
  166. `value` based on source.
  167. 2. `self.uses[value]` will increment, unless (a). we codegen via
  168. `top_of_stack` or cached `tempvars`, or (b). `value` has special VT
  169. types like `NNModuleVariable`, etc.
  170. """
  171. if isinstance(value, Source):
  172. # If the source needs to be overridden, use the new one.
  173. source = self.overridden_sources.get(value, value)
  174. assert allow_cache is True, "allow_cache must be True for Source"
  175. if self.top_of_stack is value:
  176. self._output.append(create_dup_top())
  177. return
  178. if self.tempvars.get(source) is not None:
  179. self._output.append(self.create_load(self.tempvars[source]))
  180. self.top_of_stack = source
  181. return
  182. self.uses[source] += 1
  183. try:
  184. self.call_reconstruct(source)
  185. except NotImplementedError:
  186. unimplemented_v2(
  187. gb_type="Reconstruction failure: source.reconstruct not implemented",
  188. context=str(source),
  189. explanation=f"Dynamo has no bytecode reconstruction implemented for {type(source)} variable {source}.",
  190. hints=[*graph_break_hints.DYNAMO_BUG],
  191. )
  192. if source in self.tempvars:
  193. self._output.append(create_dup_top())
  194. self.add_cache(source)
  195. self.top_of_stack = source
  196. return
  197. assert isinstance(value, VariableTracker)
  198. output = self._output
  199. graph_outputs = self.graph_outputs
  200. if allow_cache:
  201. if self.top_of_stack is value:
  202. output.append(create_dup_top())
  203. return
  204. if self.tempvars.get(value) is not None:
  205. output.append(self.create_load(self.tempvars[value]))
  206. self.top_of_stack = value
  207. return
  208. if value.is_realized() and isinstance(
  209. value, ContextlibContextManagerLocalGeneratorObjectVariable
  210. ):
  211. raise IncorrectUsage(
  212. "NYI: Returning a @contextmanager object from a torch.compile function"
  213. )
  214. # Dynamo normally prefers codegen from source to account for aliasing.
  215. if (
  216. value.source is not None
  217. and allow_cache
  218. and not (
  219. value.is_realized() and isinstance(value, LocalGeneratorObjectVariable)
  220. )
  221. ):
  222. # There's a corner case for export: for instance, if the computation
  223. # graph is just identity on an input tensor, Dynamo would just emit
  224. # a `LOAD_FAST` from the input source, rather than generating an
  225. # identity FX graph.
  226. #
  227. # However, export wants to maximize graph capture; in the case
  228. # above, export _wants to_ obtain an identity FX graph (despite it
  229. # appears unnecessarily expensive for `torch.compile`), so we have
  230. # the following option to override Dynamo's preference for codegen
  231. # from source. Moreover, this option applies recursively, for cases
  232. # like input tensor being returned in a new dictionary.
  233. #
  234. # And why the `ValueMutationExisting` check? Not sure, so leaving it
  235. # to keep the old behavior, as when `value_from_source` was
  236. # introduced. TODO sort out the invariants among side effect,
  237. # codegen and export.
  238. if (
  239. isinstance(value.mutation_type, ValueMutationExisting)
  240. or self.value_from_source
  241. ):
  242. return self(value.source)
  243. if value.is_python_constant() and is_safe_constant(value.as_python_constant()):
  244. output.append(self.create_load_const(value.as_python_constant()))
  245. elif isinstance(value, TensorWithTFOverrideVariable):
  246. graph_outputs_key = self.add_graph_output(value)
  247. self.add_push_null(
  248. lambda: self.load_import_from(utils.__name__, "to_subclass")
  249. )
  250. self.load_graph_output(graph_outputs[graph_outputs_key].index)
  251. output.append(
  252. self.create_load_global(
  253. value.global_mangled_class_name(self.tx), add=True
  254. )
  255. )
  256. output.extend(create_call_function(2, False))
  257. elif (
  258. isinstance(value, SymNodeVariable)
  259. and value.python_type() == float
  260. and not self.tx.export
  261. ):
  262. # This is a little unusual; force the output convention to be a
  263. # Tensor here. Don't do this for export because this is
  264. # apparently load bearing for export tests (but I am a bit
  265. # doubtful it actually works in the real world)
  266. # NB: It works to add_graph_output on a computed expression
  267. # as_tensor here, because we memoize as_tensor calls on
  268. # SymNodeVariable!
  269. graph_outputs_key = self.add_graph_output(
  270. value.as_tensor(self.tx, torch.float64)
  271. )
  272. def gen_fn() -> None:
  273. self.load_graph_output(graph_outputs[graph_outputs_key].index)
  274. output.append(self.create_load_attr("item"))
  275. self.add_push_null(gen_fn)
  276. output.extend(create_call_function(0, False))
  277. elif isinstance(
  278. value,
  279. (
  280. TensorVariable,
  281. SymNodeVariable,
  282. UnspecializedPythonVariable,
  283. NumpyNdarrayVariable,
  284. ),
  285. ):
  286. graph_outputs_key = self.add_graph_output(value)
  287. if isinstance(value, NumpyNdarrayVariable):
  288. self.add_push_null(
  289. lambda: self.load_import_from(utils.__name__, "to_numpy_helper")
  290. )
  291. self.load_graph_output(graph_outputs[graph_outputs_key].index)
  292. output.extend(create_call_function(1, False))
  293. elif isinstance(value, UnspecializedPythonVariable) and value.need_unwrap:
  294. def gen_fn() -> None:
  295. self.load_graph_output(graph_outputs[graph_outputs_key].index)
  296. output.append(self.create_load_attr("item"))
  297. self.add_push_null(gen_fn)
  298. output.extend(create_call_function(0, False))
  299. else:
  300. self.load_graph_output(graph_outputs[graph_outputs_key].index)
  301. elif isinstance(value, NNModuleVariable):
  302. parts = value.module_key.split(".")
  303. if parts[0] in self.code_options["co_varnames"]:
  304. output.append(self.create_load(parts[0]))
  305. parts = parts[1:]
  306. else:
  307. assert self.root is not None
  308. output.append(self.create_load_const_unchecked(self.root))
  309. for part in parts:
  310. output.append(self.create_load_attr(part))
  311. else:
  312. self.uses[value] += 1
  313. try:
  314. self.call_reconstruct(value)
  315. except NotImplementedError:
  316. unimplemented_v2(
  317. gb_type="Reconstruction failure",
  318. context=str(value),
  319. explanation=f"Dynamo has no bytecode reconstruction implemented for sourceless variable {value}.",
  320. hints=[
  321. "If Dynamo is attempting to trace a return statement and your code is attempting to return a variable "
  322. "that Dynamo cannot reconstruct, then remove it from the return statement.",
  323. *graph_break_hints.CAUSED_BY_EARLIER_GRAPH_BREAK,
  324. "Report an issue to PyTorch if you need reconstrtuction support. Note that objects that don't have "
  325. "reconstruction rules may be fundamentally unreconstructable.",
  326. ],
  327. )
  328. if allow_cache and value in self.tempvars:
  329. self._output.append(create_dup_top())
  330. self.add_cache(value)
  331. self.top_of_stack = value
  332. def add_graph_output(self, value: VariableTracker) -> int:
  333. graph_outputs_key = id(value.as_proxy())
  334. if graph_outputs_key not in self.graph_outputs:
  335. self.graph_outputs[graph_outputs_key] = GraphOutputEntry(
  336. len(self.graph_outputs), value
  337. )
  338. return graph_outputs_key
  339. def load_graph_output(self, index: int) -> None:
  340. output = self._output
  341. assert self.graph_output_var is not None
  342. output.append(self.create_load(self.graph_output_var))
  343. output.append(self.create_load_const(index))
  344. output.append(self.create_binary_subscr())
  345. def add_cache(self, value: Union[VariableTracker, Source]) -> None:
  346. var = self.new_var()
  347. self.tempvars[value] = var
  348. self._output.append(self.create_store(var))
  349. def foreach(self, items: Iterable[Union[VariableTracker, Source]]) -> None:
  350. for i in items:
  351. self(i)
  352. def create_binary_subscr(self) -> Instruction:
  353. return create_instruction("BINARY_SUBSCR")
  354. def setup_globally_cached(self, name: str, value: Any) -> list[Instruction]:
  355. """Store value in a new global"""
  356. name = re.sub(r"[^a-zA-Z0-9_]+", "_", name)
  357. f_globals = self.tx.f_globals
  358. if name in f_globals:
  359. assert id(f_globals[name]) == id(value)
  360. else:
  361. f_globals[name] = value
  362. return [self.create_load_global(name, add=True)]
  363. def clear_tos(self) -> None:
  364. self.top_of_stack = None
  365. def append_output(self, inst: Instruction) -> None:
  366. assert isinstance(inst, Instruction)
  367. self._output.append(inst)
  368. self.clear_tos()
  369. def extend_output(self, insts: list[Instruction]) -> None:
  370. assert all(isinstance(x, Instruction) for x in insts)
  371. self._output.extend(insts)
  372. self.clear_tos()
  373. def get_instructions(self) -> list[Instruction]:
  374. return self._output
  375. def create_load(self, name: str) -> Instruction:
  376. assert name in self.code_options["co_varnames"], f"{name} missing"
  377. return create_instruction("LOAD_FAST", argval=name)
  378. def create_load_closure(self, name: str) -> Instruction:
  379. assert name in self.cell_and_freevars()
  380. inst_name = "LOAD_FAST" if sys.version_info >= (3, 13) else "LOAD_CLOSURE"
  381. return create_instruction(inst_name, argval=name)
  382. def create_load_deref(self, name: str) -> Instruction:
  383. assert name in self.cell_and_freevars()
  384. return create_instruction("LOAD_DEREF", argval=name)
  385. def create_store(self, name: str) -> Instruction:
  386. assert name in self.code_options["co_varnames"], f"{name} missing"
  387. return create_instruction("STORE_FAST", argval=name)
  388. def create_store_deref(self, name: str) -> Instruction:
  389. assert name in self.cell_and_freevars()
  390. return create_instruction("STORE_DEREF", argval=name)
  391. def create_load_global(self, name: str, add: bool = False) -> Instruction:
  392. if add:
  393. self.tx.output.update_co_names(name)
  394. assert name in self.code_options["co_names"], f"{name} not in co_names"
  395. return create_instruction("LOAD_GLOBAL", argval=name)
  396. def create_load_const(self, value: Any) -> Instruction:
  397. return create_load_const(value)
  398. def create_load_const_unchecked(self, value: Any) -> Instruction:
  399. return create_load_const(value, checked=False)
  400. def load_method(self, name: str) -> None:
  401. self.tx.output.update_co_names(name)
  402. self.append_output(create_load_method(name))
  403. def call_method(self, nargs: int) -> None:
  404. self.extend_output(create_call_method(nargs))
  405. def create_load_attr(self, name: str) -> Instruction:
  406. if name not in self.code_options["co_names"]:
  407. self.code_options["co_names"] += (name,)
  408. return create_instruction("LOAD_ATTR", argval=name)
  409. def load_attr(self, name: str) -> None:
  410. self.append_output(self.create_load_attr(name))
  411. def create_load_attrs(self, names: str) -> list[Instruction]:
  412. return [self.create_load_attr(name) for name in names.split(".")]
  413. def create_store_attr(self, name: str) -> Instruction:
  414. if name not in self.code_options["co_names"]:
  415. self.code_options["co_names"] += (name,)
  416. return create_instruction("STORE_ATTR", argval=name)
  417. def store_attr(self, name: str) -> None:
  418. self.append_output(self.create_store_attr(name))
  419. def load_function_name(
  420. self, fn_name: str, push_null: bool, num_on_stack: int = 0
  421. ) -> list[Instruction]:
  422. """Load the global fn_name on the stack num_on_stack down"""
  423. output = []
  424. if push_null and sys.version_info >= (3, 11):
  425. output.extend(add_push_null(self.create_load_global(fn_name, add=True)))
  426. if num_on_stack > 0:
  427. output.extend(
  428. [
  429. *self.rot_n(num_on_stack + 2),
  430. *self.rot_n(num_on_stack + 2),
  431. ]
  432. )
  433. else:
  434. output.extend(
  435. [
  436. self.create_load_global(fn_name, add=True),
  437. *self.rot_n(num_on_stack + 1),
  438. ]
  439. )
  440. return output
  441. def rot_n(self, n: int) -> list[Instruction]:
  442. try:
  443. return create_rot_n(n)
  444. except AttributeError:
  445. # desired rotate bytecode doesn't exist, generate equivalent bytecode
  446. return [
  447. create_instruction("BUILD_TUPLE", arg=n),
  448. self.create_load_const_unchecked(rot_n_helper(n)),
  449. *create_rot_n(2),
  450. create_instruction("CALL_FUNCTION_EX", arg=0),
  451. create_instruction("UNPACK_SEQUENCE", arg=n),
  452. ]
  453. def pop_top(self) -> None:
  454. self.append_output(create_instruction("POP_TOP"))
  455. def call_function(self, nargs: int, push_null: bool) -> None:
  456. self.extend_output(create_call_function(nargs, push_null=push_null))
  457. def dup_top(self) -> None:
  458. self.append_output(create_dup_top())
  459. def store(self, varname: str) -> None:
  460. self.append_output(self.create_store(varname))
  461. def load_deref(self, varname: str) -> None:
  462. self.append_output(self.create_load_deref(varname))
  463. def make_function_with_closure(
  464. self,
  465. tx: "InstructionTranslatorBase",
  466. fn_name: str,
  467. code: types.CodeType,
  468. push_null: bool,
  469. num_on_stack: int = 0,
  470. ) -> None:
  471. freevars = code.co_freevars
  472. assert freevars
  473. output = self._output
  474. def gen_fn() -> None:
  475. self.clear_tos()
  476. # Emitting `LOAD_FAST/LOAD_CLOSURE` with names in `co_freevars`
  477. # requires that in the generated bytecode, these cells would keep
  478. # their original local names, which we ensure via
  479. # `CellVariable.local_name`.
  480. for var in freevars:
  481. if tx is self.tx: # root frame
  482. assert var in self.cell_and_freevars()
  483. output.append(self.create_load_closure(var))
  484. else: # nested frame
  485. assert var in tx.cell_and_freevars()
  486. assert tx.post_prune_cell_and_freevars
  487. self(tx.post_prune_cell_and_freevars[var])
  488. output.append(create_instruction("BUILD_TUPLE", arg=len(freevars)))
  489. output.append(self.create_load_const(code))
  490. if sys.version_info < (3, 11):
  491. output.append(self.create_load_const(fn_name))
  492. if sys.version_info >= (3, 13):
  493. output.extend(
  494. [
  495. create_instruction("MAKE_FUNCTION"),
  496. create_instruction("SET_FUNCTION_ATTRIBUTE", arg=0x08),
  497. ]
  498. )
  499. else:
  500. output.append(create_instruction("MAKE_FUNCTION", arg=0x08))
  501. if push_null and sys.version_info >= (3, 11):
  502. self.add_push_null(gen_fn)
  503. output.extend(self.rot_n(num_on_stack + 2))
  504. output.extend(self.rot_n(num_on_stack + 2))
  505. else:
  506. gen_fn()
  507. output.extend(self.rot_n(num_on_stack + 1))
  508. self.clear_tos()
  509. def create_load_python_module(self, mod: types.ModuleType) -> Instruction:
  510. """
  511. Generate a LOAD_GLOBAL instruction to fetch a given python module.
  512. """
  513. output = self.tx.output
  514. global_scope = output.global_scope
  515. name = re.sub(r"^.*[.]", "", mod.__name__)
  516. if global_scope.get(name, None) is mod:
  517. return self.create_load_global(name, add=True)
  518. prefix = f"___module_{name}"
  519. global_name = self.tx.output.install_global_by_id(prefix, mod)
  520. return self.create_load_global(global_name, add=True)
  521. def mark_source_temp(self, source: Source) -> None:
  522. """
  523. Mark a source as a temp variable, so that it can be reused.
  524. """
  525. if source not in self.tempvars:
  526. self.tempvars[source] = None
  527. def make_call_generated_code(self, fn_name: str) -> None:
  528. """Call the generated code function stored in fn_name"""
  529. self.extend_output(self.load_function_name(fn_name, True))
  530. graphargs = self.tx.output.graphargs
  531. seen_sources: OrderedSet[Source] = OrderedSet()
  532. def collect_temp_source(source: Source) -> None:
  533. if source in seen_sources:
  534. # This source is used at least twice, so it can be reused
  535. self.mark_source_temp(source)
  536. # Dont trace source further. This prevents us from marking too
  537. # many nodes as temp sources.
  538. return
  539. seen_sources.add(source)
  540. if isinstance(source, ChainedSource):
  541. collect_temp_source(source.base)
  542. if isinstance(source, DictGetItemSource) and isinstance(
  543. source.index, Source
  544. ):
  545. collect_temp_source(source.index)
  546. # Collect all the sources that are used more than once, so that we can
  547. # generate tmp variables in the generated pre-graph bytecode. This
  548. # essentially implements CSE.
  549. for arg in graphargs:
  550. if arg.source is not None:
  551. collect_temp_source(arg.source)
  552. cm_var = None
  553. if config.record_runtime_overhead:
  554. # Record the pregraph bytecode start
  555. self.add_push_null(
  556. lambda: self.load_import_from(
  557. utils.__name__, "record_pregraph_bytecode_enter"
  558. )
  559. )
  560. self.extend_output(create_call_function(0, False))
  561. cm_var = self.new_var()
  562. self.store(cm_var)
  563. for arg in graphargs:
  564. if arg.pass_arg_as_tensor:
  565. self.add_push_null(
  566. lambda: self.extend_output(
  567. [
  568. self.create_load_python_module(torch),
  569. self.create_load_attr("_as_tensor_fullprec"),
  570. ]
  571. )
  572. )
  573. self.call_reconstruct(arg)
  574. self.extend_output(create_call_function(1, False))
  575. else:
  576. self.call_reconstruct(arg)
  577. if config.record_runtime_overhead:
  578. # Record the pregraph bytecode end
  579. self.add_push_null(
  580. lambda: self.load_import_from(
  581. utils.__name__, "record_pregraph_bytecode_exit"
  582. )
  583. )
  584. assert cm_var is not None
  585. self.extend_output([self.create_load(cm_var)])
  586. self.extend_output(create_call_function(1, False))
  587. self.pop_top()
  588. self.extend_output(create_call_function(len(graphargs), False))
  589. def create_import_name(self, module_name: str) -> Instruction:
  590. return create_instruction("IMPORT_NAME", argval=module_name)
  591. def load_import_from(self, module_name: str, object_name: str) -> None:
  592. source = AttrSource(self.tx.import_source(module_name), object_name)
  593. # Note: This approach is somewhat aggressive because typically, a source is marked
  594. # as a tempvar only when it is used more than once. In this case, we're marking it
  595. # as a tempvar without performing that analysis. However, this is a simple solution,
  596. # and in many cases, load imports are reused multiple times.
  597. self.mark_source_temp(source)
  598. self(source)
  599. def create_call_function_kw(
  600. self, nargs: int, kw_names: Iterable[str], push_null: bool
  601. ) -> list[Instruction]:
  602. if sys.version_info >= (3, 13):
  603. output = create_call_function(nargs, push_null)
  604. assert output[-1].opname == "CALL"
  605. output.insert(-1, self.create_load_const(kw_names))
  606. output[-1] = create_instruction("CALL_KW", arg=nargs)
  607. return output
  608. elif sys.version_info >= (3, 11):
  609. output = create_call_function(nargs, push_null)
  610. if sys.version_info >= (3, 12):
  611. idx = -1
  612. expected_inst = "CALL"
  613. else:
  614. idx = -2
  615. expected_inst = "PRECALL"
  616. assert output[idx].opname == expected_inst
  617. kw_names_inst = create_instruction("KW_NAMES", argval=kw_names)
  618. output.insert(idx, kw_names_inst)
  619. return output
  620. return [
  621. self.create_load_const(kw_names),
  622. create_instruction("CALL_FUNCTION_KW", arg=nargs),
  623. ]
  624. def create_delete(self, value: object) -> Instruction:
  625. return create_instruction("DELETE_FAST", argval=value)