codegen.py 28 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693694695696697698699700701702703704705706707708709710
  1. """
  2. This module provides utilities for generating Python bytecode in PyTorch's Dynamo system.
  3. It includes functionality for:
  4. - Constructing bytecode sequences for Python operations
  5. - Managing stack operations and variable tracking
  6. - Handling graph outputs and their conversions
  7. - Supporting different Python versions (3.11+, 3.12+, 3.13+)
  8. - Converting high-level operations to low-level bytecode instructions
  9. - Managing constant loading and attribute access
  10. - Supporting function creation and closure handling
  11. """
  12. import collections
  13. import dataclasses
  14. import re
  15. import sys
  16. import types
  17. from collections import Counter, deque
  18. from collections.abc import Callable, Iterable
  19. from typing import Any, Optional, TYPE_CHECKING, Union
  20. import torch.nn
  21. from torch.utils._ordered_set import OrderedSet
  22. from . import config, graph_break_hints, utils
  23. from .bytecode_transformation import (
  24. add_push_null,
  25. add_push_null_call_function_ex,
  26. create_binary_subscr,
  27. create_build_tuple,
  28. create_call_function,
  29. create_call_function_ex,
  30. create_call_method,
  31. create_dup_top,
  32. create_instruction,
  33. create_load_const,
  34. create_load_method,
  35. create_rot_n,
  36. Instruction,
  37. )
  38. from .exc import IncorrectUsage, unimplemented
  39. from .source import AttrSource, ChainedSource, DictGetItemSource, Source
  40. from .utils import is_safe_constant, rot_n_helper
  41. from .variables.base import ValueMutationExisting, VariableTracker
  42. from .variables.functions import (
  43. ContextlibContextManagerLocalGeneratorObjectVariable,
  44. LocalGeneratorObjectVariable,
  45. )
  46. from .variables.nn_module import NNModuleVariable
  47. from .variables.tensor import (
  48. NumpyNdarrayVariable,
  49. SymNodeVariable,
  50. TensorVariable,
  51. UnspecializedPythonVariable,
  52. )
  53. from .variables.torch_function import TensorWithTFOverrideVariable
  54. if TYPE_CHECKING:
  55. from torch._dynamo.variables.builder import GraphArg
  56. from .symbolic_convert import InstructionTranslatorBase
  57. @dataclasses.dataclass
  58. class GraphOutputEntry:
  59. index: int
  60. variable: VariableTracker
  61. class PyCodegen:
  62. """
  63. Helper class uses for constructing Python bytecode
  64. """
  65. def __init__(
  66. self,
  67. tx: "InstructionTranslatorBase",
  68. root: Optional[torch.nn.Module] = None,
  69. graph_output_var: Optional[str] = None,
  70. tempvars: Optional[dict[Union[VariableTracker, Source], Any]] = None,
  71. overridden_sources: Optional[dict[Source, Source]] = None,
  72. ) -> None:
  73. self.root = root
  74. self.top_of_stack: Optional[Union[VariableTracker, Source]] = None
  75. self.uses: Counter[Union[VariableTracker, Source]] = collections.Counter()
  76. self.graph_outputs: dict[int, GraphOutputEntry] = {}
  77. self._output: list[Instruction] = []
  78. # This determines which VariableTracker/Source should be stored as
  79. # locals, and maps the VariableTracker/Source to the local variable
  80. # name. Note that it could map to None initially, in which case we'll
  81. # overwrite it to map to real temporary names via `add_cache`.
  82. self.tempvars: dict[Union[VariableTracker, Source], Any] = tempvars or {}
  83. self.tx = tx
  84. self.graph_output_var = graph_output_var
  85. self.code_options = self.tx.output.code_options
  86. self.cell_and_freevars = self.tx.cell_and_freevars
  87. self.new_var = self.tx.output.new_var
  88. self.value_from_source: bool = True
  89. # This serves as a way for codegen to use a different source; we need
  90. # this because sometimes we can't easily modify the original source
  91. # without affecting other components, e.g., guards.
  92. self.overridden_sources: dict[Source, Source] = overridden_sources or {}
  93. def restore_stack(
  94. self, stack_values: list[Any], *, value_from_source: bool = True
  95. ) -> None:
  96. prev = self.value_from_source
  97. self.value_from_source &= value_from_source
  98. try:
  99. self.foreach(stack_values)
  100. finally:
  101. self.value_from_source = prev
  102. def graph_output_vars(self) -> list[VariableTracker]:
  103. return [x.variable for x in self.graph_outputs.values()]
  104. def call_reconstruct(
  105. self, value: Union[VariableTracker, Source, "GraphArg"]
  106. ) -> None:
  107. res = value.reconstruct(self)
  108. assert res is None, f"reconstruct!=None {value}"
  109. def add_push_null(
  110. self, gen_fn: Callable[[], None], call_function_ex: bool = False
  111. ) -> None:
  112. """
  113. `gen_fn` generates instructions via PyCodegen methods
  114. that push a single callable to the stack.
  115. `add_push_null` pushes a NULL to the stack before or after the
  116. instructions generated by `gen_fn`, depending on Python version.
  117. Will attempt to use the NULL push bit for instructions
  118. with such bits (LOAD_GLOBAL 3.11+, LOAD_ATTR 3.12+, LOAD_SUPER_ATTR).
  119. """
  120. old_len = len(self._output)
  121. if sys.version_info < (3, 13):
  122. # gen_fn may DUP_TOP instead if TOS is not cleared.
  123. # Will cause problems since NULL will be pushed right
  124. # before the generated instructions in <= 3.12
  125. self.clear_tos()
  126. gen_fn()
  127. # inplace modify self._output
  128. added_insts = self._output[old_len:]
  129. del self._output[old_len:]
  130. if call_function_ex:
  131. self._output.extend(add_push_null_call_function_ex(added_insts))
  132. else:
  133. self._output.extend(add_push_null(added_insts))
  134. if sys.version_info >= (3, 13):
  135. # NULL will be at top of stack
  136. self.clear_tos()
  137. def __call__(
  138. self, value: Union[VariableTracker, Source, None], allow_cache: bool = True
  139. ) -> None:
  140. """
  141. Generate code such that top-of-stack (TOS) is set to value.
  142. `allow_cache` controls the behavior in the following manner. `value` can
  143. either be a VariableTracker or a Source.
  144. If `value` is a `Source`, `allow_cache` must be True (invariant asserted
  145. below). If the source was reconstructed earlier, we will reuse the
  146. generated code by loading from top of stack or tempvars.
  147. If `value` is a `VariableTracker`, we have the following cases:
  148. 1) `allow_cache=True`
  149. a) If the value.source is not None, we will emit the code based on
  150. `value.source` to handle aliasing.
  151. b) If value.source is None (example reconstructing a local list
  152. returned by the compiled function), we will reconstruct the variable
  153. tracker (w/o any source) to emit bytecode that generates a new
  154. python object.
  155. In both cases of value.source being None or not, if the value was
  156. reconstructed earlier, we will reuse the generated code by loading from
  157. top of stack or tempvars.
  158. 2) `allow_cache=False` - This is a special case (allow_cache defaults to
  159. True).
  160. a) If the value.source is not None, we reconstruct the variable
  161. tracker and emit a new python object. You might wonder what about
  162. aliasing? The place where we use this config also has the followup
  163. code where the original python object is assigned to this new python
  164. value to handle aliasing (check side_effects.py and search for
  165. allow_cache=False).
  166. b) If value.source is None, this is not allowed
  167. Notable effects:
  168. 1. `self.top_of_stack` will be set to `value`, if we don't codegen
  169. `value` based on source.
  170. 2. `self.uses[value]` will increment, unless (a). we codegen via
  171. `top_of_stack` or cached `tempvars`, or (b). `value` has special VT
  172. types like `NNModuleVariable`, etc.
  173. """
  174. assert value is not None
  175. if isinstance(value, Source):
  176. # If the source needs to be overridden, use the new one.
  177. source = self.overridden_sources.get(value, value)
  178. assert allow_cache is True, "allow_cache must be True for Source"
  179. if self.top_of_stack is value:
  180. self._output.append(create_dup_top())
  181. return
  182. if self.tempvars.get(source) is not None:
  183. self._output.append(self.create_load(self.tempvars[source]))
  184. self.top_of_stack = source
  185. return
  186. self.uses[source] += 1
  187. try:
  188. self.call_reconstruct(source)
  189. except NotImplementedError:
  190. unimplemented(
  191. gb_type="Reconstruction failure: source.reconstruct not implemented",
  192. context=str(source),
  193. explanation=f"Dynamo has no bytecode reconstruction implemented for {type(source)} variable {source}.",
  194. hints=[*graph_break_hints.DYNAMO_BUG],
  195. )
  196. if source in self.tempvars:
  197. self._output.append(create_dup_top())
  198. self.add_cache(source)
  199. self.top_of_stack = source
  200. return
  201. assert isinstance(value, VariableTracker)
  202. output = self._output
  203. graph_outputs = self.graph_outputs
  204. if allow_cache:
  205. if self.top_of_stack is value:
  206. output.append(create_dup_top())
  207. return
  208. if self.tempvars.get(value) is not None:
  209. output.append(self.create_load(self.tempvars[value]))
  210. self.top_of_stack = value
  211. return
  212. if value.is_realized() and isinstance(
  213. value, ContextlibContextManagerLocalGeneratorObjectVariable
  214. ):
  215. raise IncorrectUsage(
  216. "NYI: Returning a @contextmanager object from a torch.compile function"
  217. )
  218. # Dynamo normally prefers codegen from source to account for aliasing.
  219. if (
  220. value.source is not None
  221. and allow_cache
  222. and not (
  223. value.is_realized() and isinstance(value, LocalGeneratorObjectVariable)
  224. )
  225. ):
  226. # There's a corner case for export: for instance, if the computation
  227. # graph is just identity on an input tensor, Dynamo would just emit
  228. # a `LOAD_FAST` from the input source, rather than generating an
  229. # identity FX graph.
  230. #
  231. # However, export wants to maximize graph capture; in the case
  232. # above, export _wants to_ obtain an identity FX graph (despite it
  233. # appears unnecessarily expensive for `torch.compile`), so we have
  234. # the following option to override Dynamo's preference for codegen
  235. # from source. Moreover, this option applies recursively, for cases
  236. # like input tensor being returned in a new dictionary.
  237. #
  238. # And why the `ValueMutationExisting` check? Not sure, so leaving it
  239. # to keep the old behavior, as when `value_from_source` was
  240. # introduced. TODO sort out the invariants among side effect,
  241. # codegen and export.
  242. if (
  243. isinstance(value.mutation_type, ValueMutationExisting)
  244. or self.value_from_source
  245. ):
  246. return self(value.source)
  247. if value.is_python_constant() and is_safe_constant(value.as_python_constant()):
  248. output.append(self.create_load_const(value.as_python_constant()))
  249. elif isinstance(value, TensorWithTFOverrideVariable):
  250. graph_outputs_key = self.add_graph_output(value)
  251. self.add_push_null(
  252. lambda: self.load_import_from(utils.__name__, "to_subclass")
  253. )
  254. self.load_graph_output(graph_outputs[graph_outputs_key].index)
  255. output.append(
  256. self.create_load_global(
  257. value.global_mangled_class_name(self.tx), # type: ignore[arg-type]
  258. add=True,
  259. )
  260. )
  261. output.extend(create_call_function(2, False))
  262. elif (
  263. isinstance(value, SymNodeVariable)
  264. and value.python_type() is float
  265. and not self.tx.export
  266. ):
  267. # This is a little unusual; force the output convention to be a
  268. # Tensor here. Don't do this for export because this is
  269. # apparently load bearing for export tests (but I am a bit
  270. # doubtful it actually works in the real world)
  271. # NB: It works to add_graph_output on a computed expression
  272. # as_tensor here, because we memoize as_tensor calls on
  273. # SymNodeVariable!
  274. graph_outputs_key = self.add_graph_output(
  275. value.as_tensor(self.tx, torch.float64)
  276. )
  277. def gen_fn() -> None:
  278. self.load_graph_output(graph_outputs[graph_outputs_key].index)
  279. output.append(self.create_load_attr("item"))
  280. self.add_push_null(gen_fn)
  281. output.extend(create_call_function(0, False))
  282. elif isinstance(
  283. value,
  284. (
  285. TensorVariable,
  286. SymNodeVariable,
  287. UnspecializedPythonVariable,
  288. NumpyNdarrayVariable,
  289. ),
  290. ):
  291. graph_outputs_key = self.add_graph_output(value)
  292. if isinstance(value, NumpyNdarrayVariable):
  293. self.add_push_null(
  294. lambda: self.load_import_from(utils.__name__, "to_numpy_helper")
  295. )
  296. self.load_graph_output(graph_outputs[graph_outputs_key].index)
  297. output.extend(create_call_function(1, False))
  298. elif isinstance(value, UnspecializedPythonVariable) and value.need_unwrap:
  299. def gen_fn() -> None:
  300. self.load_graph_output(graph_outputs[graph_outputs_key].index)
  301. output.append(self.create_load_attr("item"))
  302. self.add_push_null(gen_fn)
  303. output.extend(create_call_function(0, False))
  304. else:
  305. self.load_graph_output(graph_outputs[graph_outputs_key].index)
  306. elif isinstance(value, NNModuleVariable):
  307. parts = value.module_key.split(".")
  308. if parts[0] in self.code_options["co_varnames"]:
  309. output.append(self.create_load(parts[0]))
  310. parts = parts[1:]
  311. else:
  312. assert self.root is not None
  313. output.append(self.create_load_const_unchecked(self.root))
  314. for part in parts:
  315. output.append(self.create_load_attr(part))
  316. else:
  317. self.uses[value] += 1
  318. try:
  319. self.call_reconstruct(value)
  320. except NotImplementedError:
  321. unimplemented(
  322. gb_type="Reconstruction failure",
  323. context=str(value),
  324. explanation=f"Dynamo has no bytecode reconstruction implemented for sourceless variable {value}.",
  325. hints=[
  326. "If Dynamo is attempting to trace a return statement and your code is attempting to return a variable "
  327. "that Dynamo cannot reconstruct, then remove it from the return statement.",
  328. *graph_break_hints.CAUSED_BY_EARLIER_GRAPH_BREAK,
  329. "Report an issue to PyTorch if you need reconstrtuction support. Note that objects that don't have "
  330. "reconstruction rules may be fundamentally unreconstructable.",
  331. ],
  332. )
  333. if allow_cache and value in self.tempvars:
  334. self._output.append(create_dup_top())
  335. self.add_cache(value)
  336. self.top_of_stack = value
  337. def add_graph_output(self, value: VariableTracker) -> int:
  338. graph_outputs_key = id(value.as_proxy())
  339. if graph_outputs_key not in self.graph_outputs:
  340. self.graph_outputs[graph_outputs_key] = GraphOutputEntry(
  341. len(self.graph_outputs), value
  342. )
  343. return graph_outputs_key
  344. def load_graph_output(self, index: int) -> None:
  345. output = self._output
  346. assert self.graph_output_var is not None
  347. output.append(self.create_load(self.graph_output_var))
  348. output.append(self.create_load_const(index))
  349. output.append(self.create_binary_subscr())
  350. def add_cache(self, value: Union[VariableTracker, Source]) -> None:
  351. var = self.new_var()
  352. self.tempvars[value] = var
  353. self._output.append(self.create_store(var))
  354. def foreach(self, items: Iterable[Union[VariableTracker, Source]]) -> None:
  355. for i in items:
  356. self(i)
  357. def create_binary_subscr(self) -> Instruction:
  358. return create_binary_subscr()
  359. def setup_globally_cached(self, name: str, value: Any) -> list[Instruction]:
  360. """Store value in a new global"""
  361. name = re.sub(r"[^a-zA-Z0-9_]+", "_", name)
  362. f_globals = self.tx.f_globals
  363. if name in f_globals:
  364. assert id(f_globals[name]) == id(value)
  365. else:
  366. f_globals[name] = value
  367. return [self.create_load_global(name, add=True)]
  368. def clear_tos(self) -> None:
  369. self.top_of_stack = None
  370. def append_output(self, inst: Instruction) -> None:
  371. assert isinstance(inst, Instruction)
  372. self._output.append(inst)
  373. self.clear_tos()
  374. def extend_output(self, insts: list[Instruction]) -> None:
  375. assert all(isinstance(x, Instruction) for x in insts)
  376. self._output.extend(insts)
  377. self.clear_tos()
  378. def get_instructions(self) -> list[Instruction]:
  379. return self._output
  380. def create_load(self, name: str) -> Instruction:
  381. assert name in self.code_options["co_varnames"], f"{name} missing"
  382. return create_instruction("LOAD_FAST", argval=name)
  383. def create_load_closure(self, name: str) -> Instruction:
  384. assert name in self.cell_and_freevars()
  385. inst_name = "LOAD_FAST" if sys.version_info >= (3, 13) else "LOAD_CLOSURE"
  386. return create_instruction(inst_name, argval=name)
  387. def create_load_deref(self, name: str) -> Instruction:
  388. assert name in self.cell_and_freevars()
  389. return create_instruction("LOAD_DEREF", argval=name)
  390. def create_store(self, name: str) -> Instruction:
  391. assert name in self.code_options["co_varnames"], f"{name} missing"
  392. return create_instruction("STORE_FAST", argval=name)
  393. def create_store_deref(self, name: str) -> Instruction:
  394. assert name in self.cell_and_freevars()
  395. return create_instruction("STORE_DEREF", argval=name)
  396. def create_load_global(self, name: str, add: bool = False) -> Instruction:
  397. if add:
  398. self.tx.output.update_co_names(name)
  399. assert name in self.code_options["co_names"], f"{name} not in co_names"
  400. return create_instruction("LOAD_GLOBAL", argval=name)
  401. def create_load_const(self, value: Any) -> Instruction:
  402. return create_load_const(value)
  403. def create_load_const_unchecked(self, value: Any) -> Instruction:
  404. return create_load_const(value, checked=False)
  405. def load_method(self, name: str) -> None:
  406. self.tx.output.update_co_names(name)
  407. self.append_output(create_load_method(name))
  408. def call_method(self, nargs: int) -> None:
  409. self.extend_output(create_call_method(nargs))
  410. def create_load_attr(self, name: str) -> Instruction:
  411. if name not in self.code_options["co_names"]:
  412. self.code_options["co_names"] += (name,)
  413. return create_instruction("LOAD_ATTR", argval=name)
  414. def load_attr(self, name: str) -> None:
  415. self.append_output(self.create_load_attr(name))
  416. def create_load_attrs(self, names: str) -> list[Instruction]:
  417. return [self.create_load_attr(name) for name in names.split(".")]
  418. def create_store_attr(self, name: str) -> Instruction:
  419. if name not in self.code_options["co_names"]:
  420. self.code_options["co_names"] += (name,)
  421. return create_instruction("STORE_ATTR", argval=name)
  422. def store_attr(self, name: str) -> None:
  423. self.append_output(self.create_store_attr(name))
  424. def load_function_name(
  425. self, fn_name: str, push_null: bool, num_on_stack: int = 0
  426. ) -> list[Instruction]:
  427. """Load the global fn_name on the stack num_on_stack down"""
  428. output = []
  429. if push_null and sys.version_info >= (3, 11):
  430. output.extend(add_push_null(self.create_load_global(fn_name, add=True)))
  431. if num_on_stack > 0:
  432. output.extend(
  433. [
  434. *self.rot_n(num_on_stack + 2),
  435. *self.rot_n(num_on_stack + 2),
  436. ]
  437. )
  438. else:
  439. output.extend(
  440. [
  441. self.create_load_global(fn_name, add=True),
  442. *self.rot_n(num_on_stack + 1),
  443. ]
  444. )
  445. return output
  446. def rot_n(self, n: int) -> list[Instruction]:
  447. try:
  448. return create_rot_n(n)
  449. except AttributeError:
  450. # desired rotate bytecode doesn't exist, generate equivalent bytecode
  451. return [
  452. create_build_tuple(n),
  453. self.create_load_const_unchecked(rot_n_helper(n)),
  454. *create_rot_n(2),
  455. *create_call_function_ex(False, False),
  456. create_instruction("UNPACK_SEQUENCE", arg=n),
  457. ]
  458. def pop_top(self) -> None:
  459. self.append_output(create_instruction("POP_TOP"))
  460. def call_function(self, nargs: int, push_null: bool) -> None:
  461. self.extend_output(create_call_function(nargs, push_null=push_null))
  462. def dup_top(self) -> None:
  463. self.append_output(create_dup_top())
  464. def store(self, varname: str) -> None:
  465. self.append_output(self.create_store(varname))
  466. def load_deref(self, varname: str) -> None:
  467. self.append_output(self.create_load_deref(varname))
  468. def make_function_with_closure(
  469. self,
  470. fn_name: str,
  471. code: types.CodeType,
  472. ) -> None:
  473. """Creates a closure with code object `code`.
  474. Expects the TOS to be the tuple of cells to use for this closure.
  475. TOS will be popped to create the closure.
  476. Args:
  477. - fn_name: name of the function
  478. - code: code object of the function
  479. (does not include the tuple of cells on the TOS)
  480. """
  481. output = self._output
  482. output.append(self.create_load_const(code))
  483. if sys.version_info < (3, 11):
  484. output.append(self.create_load_const(fn_name))
  485. if sys.version_info >= (3, 13):
  486. output.extend(
  487. [
  488. create_instruction("MAKE_FUNCTION"),
  489. create_instruction("SET_FUNCTION_ATTRIBUTE", arg=0x08),
  490. ]
  491. )
  492. else:
  493. output.append(create_instruction("MAKE_FUNCTION", arg=0x08))
  494. self.clear_tos()
  495. def create_load_python_module(self, mod: types.ModuleType) -> Instruction:
  496. """
  497. Generate a LOAD_GLOBAL instruction to fetch a given python module.
  498. """
  499. output = self.tx.output
  500. global_scope = output.global_scope
  501. name = re.sub(r"^.*[.]", "", mod.__name__)
  502. if global_scope.get(name, None) is mod:
  503. return self.create_load_global(name, add=True)
  504. prefix = f"___module_{name}"
  505. global_name = self.tx.output.install_global_by_id(prefix, mod)
  506. return self.create_load_global(global_name, add=True)
  507. def mark_source_temp(self, source: Source) -> None:
  508. """
  509. Mark a source as a temp variable, so that it can be reused.
  510. """
  511. if source not in self.tempvars:
  512. self.tempvars[source] = None
  513. def make_call_generated_code(self, fn_name: str) -> None:
  514. """Call the generated code function stored in fn_name"""
  515. self.extend_output(self.load_function_name(fn_name, True))
  516. graphargs = self.tx.output.graphargs
  517. def extract_nested_sources(source: Source) -> list[Source]:
  518. nested_sources: list[Source] = []
  519. if isinstance(source, ChainedSource):
  520. nested_sources.append(source.base)
  521. if isinstance(source, DictGetItemSource) and isinstance(
  522. source.index, Source
  523. ):
  524. nested_sources.append(source.index)
  525. return nested_sources
  526. def collect_temp_sources(sources: deque[Source], codegen: PyCodegen) -> None:
  527. seen_sources: OrderedSet[Source] = OrderedSet()
  528. while sources:
  529. current_source = sources.popleft()
  530. if current_source in seen_sources:
  531. # This source is used at least twice, so it can be reused
  532. codegen.mark_source_temp(current_source)
  533. # Dont trace source further. This prevents us from marking too
  534. # many nodes as temp sources.
  535. continue
  536. seen_sources.add(current_source)
  537. sources.extend(extract_nested_sources(current_source))
  538. # Collect all the sources that are used more than once, so that we can
  539. # generate tmp variables in the generated pre-graph bytecode. This
  540. # essentially implements CSE.
  541. collect_temp_sources(
  542. deque([arg.source for arg in graphargs if arg.source is not None]), self
  543. )
  544. cm_var = None
  545. if config.record_runtime_overhead:
  546. # Record the pregraph bytecode start
  547. self.add_push_null(
  548. lambda: self.load_import_from(
  549. utils.__name__, "record_pregraph_bytecode_enter"
  550. )
  551. )
  552. self.extend_output(create_call_function(0, False))
  553. cm_var = self.new_var()
  554. self.store(cm_var)
  555. for arg in graphargs:
  556. if arg.pass_arg_as_tensor:
  557. self.add_push_null(
  558. lambda: self.extend_output(
  559. [
  560. self.create_load_python_module(torch),
  561. self.create_load_attr("_as_tensor_fullprec"),
  562. ]
  563. )
  564. )
  565. self.call_reconstruct(arg)
  566. self.extend_output(create_call_function(1, False))
  567. else:
  568. self.call_reconstruct(arg)
  569. if config.record_runtime_overhead:
  570. # Record the pregraph bytecode end
  571. self.add_push_null(
  572. lambda: self.load_import_from(
  573. utils.__name__, "record_pregraph_bytecode_exit"
  574. )
  575. )
  576. assert cm_var is not None
  577. self.extend_output([self.create_load(cm_var)])
  578. self.extend_output(create_call_function(1, False))
  579. self.pop_top()
  580. self.extend_output(create_call_function(len(graphargs), False))
  581. def create_import_name(self, module_name: str) -> Instruction:
  582. return create_instruction("IMPORT_NAME", argval=module_name)
  583. def load_import_from(self, module_name: str, object_name: str) -> None:
  584. source = AttrSource(self.tx.import_source(module_name), object_name)
  585. # Note: This approach is somewhat aggressive because typically, a source is marked
  586. # as a tempvar only when it is used more than once. In this case, we're marking it
  587. # as a tempvar without performing that analysis. However, this is a simple solution,
  588. # and in many cases, load imports are reused multiple times.
  589. self.mark_source_temp(source)
  590. self(source)
  591. def create_call_function_kw(
  592. self, nargs: int, kw_names: Iterable[str], push_null: bool
  593. ) -> list[Instruction]:
  594. if sys.version_info >= (3, 13):
  595. output = create_call_function(nargs, push_null)
  596. assert output[-1].opname == "CALL"
  597. output.insert(-1, self.create_load_const(kw_names))
  598. output[-1] = create_instruction("CALL_KW", arg=nargs)
  599. return output
  600. elif sys.version_info >= (3, 11):
  601. output = create_call_function(nargs, push_null)
  602. if sys.version_info >= (3, 12):
  603. idx = -1
  604. expected_inst = "CALL"
  605. else:
  606. idx = -2
  607. expected_inst = "PRECALL"
  608. assert output[idx].opname == expected_inst
  609. kw_names_inst = create_instruction("KW_NAMES", argval=kw_names)
  610. output.insert(idx, kw_names_inst)
  611. return output
  612. return [
  613. self.create_load_const(kw_names),
  614. create_instruction("CALL_FUNCTION_KW", arg=nargs),
  615. ]
  616. def create_delete(self, value: object) -> Instruction:
  617. return create_instruction("DELETE_FAST", argval=value)