resume_execution.py 29 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723
  1. """
  2. This module provides functionality for resuming Python execution at specific points in code,
  3. primarily used by PyTorch Dynamo for control flow handling and optimization. It implements
  4. bytecode transformation and execution state management to enable:
  5. - Resuming execution at arbitrary points in Python bytecode
  6. - Managing context managers and their state across execution boundaries
  7. - Transforming and generating new code objects with preserved execution state
  8. - Supporting Python 3.11+ exception handling and block management
  9. - Restoring torch function mode stacks and other execution context
  10. The module is critical for PyTorch Dynamo's ability to optimize code while preserving
  11. Python semantics and execution state.
  12. """
  13. import copy
  14. import dataclasses
  15. import sys
  16. import types
  17. from collections.abc import Iterable
  18. from contextlib import AbstractContextManager
  19. from typing import Any, Callable, cast, Optional
  20. from .bytecode_transformation import (
  21. add_push_null,
  22. bytecode_from_template,
  23. create_call_function,
  24. create_instruction,
  25. create_jump_absolute,
  26. create_load_const,
  27. Instruction,
  28. overwrite_instruction,
  29. transform_code_object,
  30. unique_id,
  31. )
  32. from .utils import ExactWeakKeyDictionary
# Code-object flag bits (tested against `co_flags`), taken from code.h in CPython.
CO_OPTIMIZED = 0x0001
CO_NEWLOCALS = 0x0002
CO_VARARGS = 0x0004
CO_VARKEYWORDS = 0x0008
CO_NESTED = 0x0010
CO_GENERATOR = 0x0020
CO_NOFREE = 0x0040
CO_COROUTINE = 0x0080
CO_ITERABLE_COROUTINE = 0x0100
CO_ASYNC_GENERATOR = 0x0200

# Prefix of the `co_name` given to every generated resume function.
# trace_rules.py imports this constant for consistency.
TORCH_DYNAMO_RESUME_IN_PREFIX = "torch_dynamo_resume_in"
# Name of the local variable that a resume function's prologue sets to True on
# entry and back to False right before control leaves the prologue.
IS_TRACING_RESUME_PROLOGUE_VARNAME = "__is_tracing_resume_prologue"
  47. def _initial_push_null(insts: list[Instruction]) -> None:
  48. if sys.version_info >= (3, 11):
  49. insts.append(create_instruction("PUSH_NULL"))
  50. if sys.version_info < (3, 13):
  51. insts.append(create_instruction("SWAP", arg=2))
  52. # Generates bytecode from template and splits the code where LOAD_FAST dummy is present.
  53. def _bytecode_from_template_with_split(
  54. template: Callable[..., Any],
  55. stack_index: int,
  56. varname_map: Optional[dict[str, Any]] = None,
  57. ) -> tuple[list[Instruction], list[Instruction]]:
  58. template_code = bytecode_from_template(template, varname_map=varname_map)
  59. template_code.append(create_instruction("POP_TOP"))
  60. # adjust exception table entry depth
  61. for inst in template_code:
  62. if inst.exn_tab_entry:
  63. inst.exn_tab_entry.depth += stack_index
  64. # search for LOAD_FAST dummy and replace it with 2 NOPs (we can break up the bytecode between them)
  65. dummy_idx, dummy_inst = next(
  66. (
  67. (i, inst)
  68. for i, inst in enumerate(template_code)
  69. if inst.opname == "LOAD_FAST" and inst.argval == "dummy"
  70. ),
  71. (None, None),
  72. )
  73. assert dummy_idx is not None and dummy_inst is not None
  74. # replace LOAD_FAST dummy with first NOP marking exception area
  75. overwrite_instruction(dummy_inst, [create_instruction("NOP")])
  76. # POP_TOP follows LOAD_FAST dummy - replace with NOP marking end of exception area
  77. assert template_code[dummy_idx + 1].opname == "POP_TOP"
  78. overwrite_instruction(template_code[dummy_idx + 1], [create_instruction("NOP")])
  79. return template_code[: dummy_idx + 1], template_code[dummy_idx + 1 :]
# Bytecode template (handed to _bytecode_from_template_with_split, not called
# directly). The bare `dummy` statement marks where the generated bytecode is
# split; `stack_var_name` is remapped to a real variable name via varname_map.
def _try_except_tf_mode_template(dummy: Any, stack_var_name: Any) -> None:
    # NOTE: Make sure this name matches what is generated by symbolic_convert:import_source
    # on torch._dynamo.utils.
    global __import_torch_dot__dynamo_dot_utils
    try:
        dummy
    except:  # noqa: E722, B001
        __import_torch_dot__dynamo_dot_utils.set_torch_function_mode_stack(  # type: ignore[name-defined]
            stack_var_name
        )
        raise
@dataclasses.dataclass(frozen=True)
class ReenterWith:
    """
    Generates the instructions that re-establish a `with` / `try` block around
    the remainder of a resume function.

    Every codegen method returns the setup instructions and prepends the
    matching epilogue (handler / exit code) to the caller's `cleanup` list.
    """

    # Added to the depth of every exception table entry in the generated template code.
    stack_index: int
    # Constants loaded as positional arguments when calling the value on top of
    # the stack; None or empty means the call takes no arguments.
    target_values: Optional[tuple[Any, ...]] = None

    def try_except_torch_function_mode(
        self, code_options: dict[str, Any], cleanup: list[Instruction]
    ) -> list[Instruction]:
        """
        Codegen based off of:
        try:
            (rest)
        except:
            (restore previous tf mode stack)
            raise

        `code_options` is not used here. `cleanup` is updated in place: the
        except-handler instructions are inserted at its front.
        Returns the instructions that set up the try block.
        """
        from .variables.torch_function import get_prev_stack_var_name

        setup_try_except, epilogue = _bytecode_from_template_with_split(
            _try_except_tf_mode_template,
            self.stack_index,
            varname_map={"stack_var_name": get_prev_stack_var_name()},
        )
        # in-place update so the caller's list object sees the new epilogue
        cleanup[:] = epilogue + cleanup

        return setup_try_except

    # If we do not want to destroy the stack, we can do the same thing as a
    # `SETUP_WITH` block, only that we store the context manager in a local_symbol
    def try_finally(
        self, code_options: dict[str, Any], cleanup: list[Instruction]
    ) -> list[Instruction]:
        """
        Codegen based off of:
        load args
        enter context
        try:
            (rest)
        finally:
            exit context

        Registers a fresh local variable name in `code_options["co_varnames"]`
        and makes sure `__enter__` / `__exit__` are in `code_options["co_names"]`.
        `cleanup` is updated in place with the finally-handler instructions
        inserted at its front.
        Returns the instructions that construct the context manager, store it
        in the new local, and enter the try block.
        """
        # NOTE: we assume that TOS is a context manager CLASS!
        load_args = []
        if self.target_values:
            load_args = [create_load_const(val) for val in self.target_values]
        ctx_name = unique_id(f"___context_manager_{self.stack_index}")
        if ctx_name not in code_options["co_varnames"]:
            code_options["co_varnames"] += (ctx_name,)
        for name in ["__enter__", "__exit__"]:
            if name not in code_options["co_names"]:
                code_options["co_names"] += (name,)

        # call TOS with the constant args and keep the result in the new local
        create_ctx: list[Instruction] = []
        _initial_push_null(create_ctx)
        create_ctx.extend(
            [
                *load_args,
                *create_call_function(len(load_args), False),
                create_instruction("STORE_FAST", argval=ctx_name),
            ]
        )

        def _template(ctx: AbstractContextManager[Any], dummy: Any) -> None:
            ctx.__enter__()
            try:
                dummy
            finally:
                ctx.__exit__(None, None, None)

        # `ctx` in the template is remapped to the local created above
        setup_try_finally, epilogue = _bytecode_from_template_with_split(
            _template, self.stack_index, varname_map={"ctx": ctx_name}
        )
        cleanup[:] = epilogue + cleanup
        return create_ctx + setup_try_finally

    def __call__(
        self, code_options: dict[str, Any], cleanup: list[Instruction]
    ) -> tuple[list[Instruction], Optional[Instruction]]:
        """
        Codegen based off of:
        with ctx(args):
            (rest)

        `code_options` is not used here. `cleanup` is updated in place with the
        `with`-exit instructions inserted at its front.
        Returns `(setup instructions, PUSH_EXC_INFO instruction found in the
        epilogue or None if there is none)`.
        """
        # NOTE: we assume that TOS is a context manager CLASS!
        load_args = []
        if self.target_values:
            load_args = [create_load_const(val) for val in self.target_values]

        # call TOS with the constant args; the result stays on the stack
        create_ctx: list[Instruction] = []
        _initial_push_null(create_ctx)
        create_ctx.extend(
            [
                *load_args,
                *create_call_function(len(load_args), False),
            ]
        )

        def _template(ctx: AbstractContextManager[Any], dummy: Any) -> None:
            with ctx:
                dummy

        setup_with, epilogue = _bytecode_from_template_with_split(
            _template, self.stack_index
        )
        cleanup[:] = epilogue + cleanup

        load_fast_ctx_inst = next(
            (
                inst
                for inst in setup_with
                if inst.opname == "LOAD_FAST" and inst.argval == "ctx"
            ),
            None,
        )
        assert load_fast_ctx_inst is not None
        # ctx already loaded on stack before the template - no need to LOAD_FAST
        overwrite_instruction(load_fast_ctx_inst, [create_instruction("NOP")])

        # 3.11+ only
        push_exc_info_gen = (
            inst for inst in epilogue if inst.opname == "PUSH_EXC_INFO"
        )
        push_exc_info_inst = next(push_exc_info_gen, None)
        # expect only 1 PUSH_EXC_INFO in epilogue
        assert next(push_exc_info_gen, None) is None

        return create_ctx + setup_with, push_exc_info_inst
  204. @dataclasses.dataclass
  205. class ResumeFunctionMetadata:
  206. code: types.CodeType
  207. instructions: list[Instruction] = dataclasses.field(default_factory=list)
  208. # Python 3.11+ fields
  209. # NOTE: Python 3.11 removed blocks, but for our purposes, a "block" consists
  210. # of instructions of all exception table entries that have the same target.
  211. # map from PUSH_EXC_INFO's in the prefix to original block target offset
  212. prefix_block_target_offset_remap: list[int] = dataclasses.field(
  213. default_factory=list
  214. )
  215. # per-offset map from new block target offsets to original block target offsets
  216. block_target_offset_remap: dict[tuple[int, int], dict[int, int]] = (
  217. dataclasses.field(default_factory=dict)
  218. )
  219. def _filter_iter(
  220. l1: Iterable[Any],
  221. l2: Iterable[Any],
  222. cond: Callable[[Any, Any], bool],
  223. ) -> list[Any]:
  224. """
  225. Two-pointer conditional filter.
  226. e.g. _filter_iter(insts, sorted_offsets, lambda i, o: i.offset == o)
  227. returns the instructions with offsets in sorted_offsets
  228. """
  229. it = iter(l2)
  230. res: list[Instruction] = []
  231. try:
  232. cur = next(it)
  233. for val in l1:
  234. if cond(val, cur):
  235. res.append(val)
  236. cur = next(it)
  237. except StopIteration:
  238. pass
  239. return res
  240. def _load_tuple_and_call(tup: tuple[Any, ...]) -> list[Instruction]:
  241. insts: list[Instruction] = []
  242. _initial_push_null(insts)
  243. insts.extend(create_load_const(val) for val in tup)
  244. insts.extend(create_call_function(len(tup), False))
  245. return insts
class ContinueExecutionCache:
    """
    Generates and caches "resume" code objects: copies of a code object with a
    prologue that reloads the stack / re-enters blocks and then jumps to a
    given instruction offset.
    """

    # code object -> {key tuple -> generated resume code object}
    cache = ExactWeakKeyDictionary()
    # generated resume code object -> ResumeFunctionMetadata
    generated_code_metadata = ExactWeakKeyDictionary()

    @classmethod
    def lookup(
        cls, code: types.CodeType, lineno: int, init_offset: int, *key: Any
    ) -> types.CodeType:
        """
        Return the resume code object for `code` and `key`, generating it on a
        cache miss.

        NOTE: only `key` takes part in the cache lookup; `lineno` and
        `init_offset` are just forwarded to `generate`.
        """
        if code not in cls.cache:
            cls.cache[code] = {}
        key = tuple(key)
        if key not in cls.cache[code]:
            cls.cache[code][key] = cls.generate(code, lineno, init_offset, *key)
        return cls.cache[code][key]

    @classmethod
    def generate(
        cls,
        code: types.CodeType,
        lineno: int,
        init_offset: int,
        resume_offset: int,
        setup_fn_target_offsets: tuple[int, ...],  # only used in Python 3.11+
        nstack: int,
        argnames: tuple[str, ...],
        argnames_null: tuple[str, ...],
        setup_fns: tuple[ReenterWith, ...],
        stack_ctx_vars: tuple[tuple[int, tuple[Any, ...]], ...],
        argnames_ctx_vars: tuple[tuple[str, tuple[Any, ...]], ...],
        null_idxes: tuple[int, ...],
        # mainly used to ensure distinct code objects per stack trace,
        # which prevents excessive recompilation of inner frames
        nested_code_objs: tuple[types.CodeType, ...],
    ) -> types.CodeType:
        """
        Build a resume code object from `code` that continues at `resume_offset`.

        Args:
            code: code object to resume; must not be a generator/coroutine and
                must have CO_OPTIMIZED set (asserted).
            lineno: used in the generated name and as `co_firstlineno`.
            init_offset: only used when `code` is itself a generated resume
                function (forwarded to `generate_based_on_original_code_object`).
            resume_offset: offset of the instruction the prologue jumps to.
            setup_fn_target_offsets: for each entry of `setup_fns` (same
                order), the offset of its original block target.
            nstack: number of stack values, passed in as `___stack{i}` args.
            argnames: local variable names passed in as arguments.
            argnames_null: locals that the prologue stores NULL into (3.12+).
            setup_fns: hooks, looked up by their `stack_index`, run right
                after the matching stack value is loaded.
            stack_ctx_vars: `(stack index, args)` pairs; the loaded stack value
                is called with `args`.
            argnames_ctx_vars: `(name, args)` pairs; the local is loaded,
                called with `args`, and stored back.
            null_idxes: stack positions at which a PUSH_NULL is emitted.
            nested_code_objs: when non-empty, the prologue calls
                `__nested_resume_fns[-1]` before jumping.

        Returns:
            The new code object; its metadata is recorded in
            `generated_code_metadata`.
        """
        assert resume_offset is not None
        assert not (
            code.co_flags
            & (CO_GENERATOR | CO_COROUTINE | CO_ITERABLE_COROUTINE | CO_ASYNC_GENERATOR)
        )
        assert code.co_flags & CO_OPTIMIZED
        # `code` is itself a generated resume function: regenerate from the
        # code object it was derived from instead.
        if code in ContinueExecutionCache.generated_code_metadata:
            return cls.generate_based_on_original_code_object(
                code,
                lineno,
                init_offset,
                resume_offset,
                setup_fn_target_offsets,
                nstack,
                argnames,
                argnames_null,
                setup_fns,
                stack_ctx_vars,
                argnames_ctx_vars,
                null_idxes,
                nested_code_objs,
            )

        is_py311_plus = sys.version_info >= (3, 11)
        meta = ResumeFunctionMetadata(code)

        # Callback for transform_code_object: rewrites `instructions` and
        # `code_options` in place.
        def update(
            instructions: list[Instruction], code_options: dict[str, Any]
        ) -> None:
            # snapshot of the unmodified instructions, used later to map offsets back
            meta.instructions = copy.deepcopy(instructions)

            args = ["__nested_resume_fns", "__nested_frame_values"]
            args += [f"___stack{i}" for i in range(nstack)]
            args.extend(v for v in argnames if v not in args)
            # cell variables become free variables of the resume function
            freevars = tuple(code_options["co_cellvars"] or []) + tuple(
                code_options["co_freevars"] or []
            )
            freevars = tuple(sorted(freevars))
            code_options["co_name"] = (
                f"{TORCH_DYNAMO_RESUME_IN_PREFIX}_{code_options['co_name']}_at_{lineno}"
            )
            if is_py311_plus:
                qualified_path = code_options["co_qualname"].rsplit(".", maxsplit=1)
                if len(qualified_path) == 1:
                    code_options["co_qualname"] = code_options["co_name"]
                else:
                    assert len(qualified_path) == 2
                    module_name, co_name = qualified_path
                    code_options["co_qualname"] = (
                        f"{module_name}.{TORCH_DYNAMO_RESUME_IN_PREFIX}_{co_name}_at_{lineno}"
                    )
            code_options["co_firstlineno"] = lineno
            code_options["co_cellvars"] = ()
            code_options["co_freevars"] = freevars
            code_options["co_argcount"] = len(args)
            code_options["co_posonlyargcount"] = 0
            code_options["co_kwonlyargcount"] = 0
            code_options["co_varnames"] = tuple(
                args
                + [v for v in argnames_null if v not in args]
                + [v for v in code_options["co_varnames"] if v not in args]
                + [IS_TRACING_RESUME_PROLOGUE_VARNAME]
            )
            # all arguments are plain positional: drop *args / **kwargs flags
            code_options["co_flags"] = code_options["co_flags"] & ~(
                CO_VARARGS | CO_VARKEYWORDS
            )
            target = next(i for i in instructions if i.offset == resume_offset)

            prefix = []
            if is_py311_plus:
                if freevars:
                    prefix.append(
                        create_instruction("COPY_FREE_VARS", arg=len(freevars))
                    )
                prefix.append(create_instruction("RESUME", arg=0))

            # Set is_tracing_resume_prologue to prevent graph breaks.
            # This doesn't really do anything at runtime, but dynamo will trace this
            # and will know that we're in a resume function prologue.
            prefix.extend(
                [
                    create_instruction("LOAD_CONST", argval=True),
                    create_instruction(
                        "STORE_FAST", argval=IS_TRACING_RESUME_PROLOGUE_VARNAME
                    ),
                ]
            )

            cleanup: list[Instruction] = []
            hooks = {fn.stack_index: fn for fn in setup_fns}
            hook_target_offsets = {
                fn.stack_index: setup_fn_target_offsets[i]
                for i, fn in enumerate(setup_fns)
            }
            offset_to_inst = {inst.offset: inst for inst in instructions}
            # map old hook targets to new targets generated by the hook
            old_hook_target_remap = {}
            null_idxes_i = 0
            stack_ctx_vars_d = dict(stack_ctx_vars)  # type: ignore[var-annotated,arg-type]
            for i in range(nstack):
                # emit the NULLs that sit before stack value i; null_idxes
                # are positions counted with the NULLs included
                while (
                    null_idxes_i < len(null_idxes)
                    and null_idxes[null_idxes_i] == i + null_idxes_i
                ):
                    prefix.append(create_instruction("PUSH_NULL"))
                    null_idxes_i += 1
                prefix.append(create_instruction("LOAD_FAST", argval=f"___stack{i}"))
                if i in hooks:
                    hook = hooks.pop(i)
                    hook_insts, exn_target = hook(code_options, cleanup)
                    prefix.extend(hook_insts)
                    if is_py311_plus:
                        hook_target_offset = hook_target_offsets.pop(i)
                        old_hook_target = offset_to_inst[hook_target_offset]
                        meta.prefix_block_target_offset_remap.append(hook_target_offset)
                        old_hook_target_remap[old_hook_target] = exn_target
                if i in stack_ctx_vars_d:
                    # NOTE: we assume that current stack var is a context manager CLASS!
                    # Load args for context variable and construct it
                    prefix.extend(_load_tuple_and_call(stack_ctx_vars_d[i]))

            if is_py311_plus:
                # reverse the mapping since targets of later/nested contexts are inserted
                # into the mapping later, but show up earlier in the prefix.
                meta.prefix_block_target_offset_remap = list(
                    reversed(meta.prefix_block_target_offset_remap)
                )

            # every hook must have been consumed by some stack index
            assert not hooks

            # NOTE: we assume that local var is a context manager CLASS!
            # initialize inactive context vars in argnames
            for name, vals in argnames_ctx_vars:
                prefix.append(create_instruction("LOAD_FAST", argval=name))
                prefix.extend(_load_tuple_and_call(vals))
                prefix.append(create_instruction("STORE_FAST", argval=name))

            # 3.12+: store NULL into variables that were NULL
            if argnames_null:
                assert sys.version_info >= (3, 12)
                for v in argnames_null:
                    assert v not in args
                    prefix.extend(
                        [
                            create_instruction("PUSH_NULL"),
                            create_instruction("STORE_FAST", argval=v),
                        ]
                    )

            # Call nested resume function
            if nested_code_objs:
                prefix.extend(
                    [
                        # set up __nested_resume_fns[-1] call
                        *add_push_null(
                            [
                                create_instruction(
                                    "LOAD_FAST", argval="__nested_resume_fns"
                                ),
                                create_instruction("LOAD_CONST", argval=-1),
                                create_instruction("BINARY_SUBSCR"),
                            ]
                        ),
                        # del __nested_resume_fns[-1]
                        create_instruction("LOAD_FAST", argval="__nested_resume_fns"),
                        create_instruction("LOAD_CONST", argval=-1),
                        create_instruction("DELETE_SUBSCR"),
                        # load [__nested_resume_fns, __nested_frame_values]
                        create_instruction("LOAD_FAST", argval="__nested_resume_fns"),
                        create_instruction("LOAD_FAST", argval="__nested_frame_values"),
                        create_instruction("BUILD_LIST", arg=2),
                        # load __nested_frame_values[-1]
                        create_instruction("LOAD_FAST", argval="__nested_frame_values"),
                        create_instruction("LOAD_CONST", argval=-1),
                        create_instruction("BINARY_SUBSCR"),
                        # create [
                        #     __nested_resume_fns,
                        #     __nested_frame_values,
                        #     *__nested_frame_values[-1],
                        # ]
                        create_instruction("LIST_EXTEND", arg=1),
                        # del __nested_frame_values[-1]
                        create_instruction("LOAD_FAST", argval="__nested_frame_values"),
                        create_instruction("LOAD_CONST", argval=-1),
                        create_instruction("DELETE_SUBSCR"),
                        # delete __nested values
                        create_instruction("DELETE_FAST", argval="__nested_resume_fns"),
                        create_instruction(
                            "DELETE_FAST", argval="__nested_frame_values"
                        ),
                        # Set is_tracing_resume_prologue back to allow graph breaks
                        # in the nested resume
                        create_instruction("LOAD_CONST", argval=False),
                        create_instruction(
                            "STORE_FAST", argval=IS_TRACING_RESUME_PROLOGUE_VARNAME
                        ),
                        # finish the call
                        create_instruction("CALL_FUNCTION_EX", arg=0),
                    ]
                )
            else:
                # Set is_tracing_resume_prologue back to allow graph breaks after the jump
                prefix.extend(
                    [
                        create_instruction("LOAD_CONST", argval=False),
                        create_instruction(
                            "STORE_FAST", argval=IS_TRACING_RESUME_PROLOGUE_VARNAME
                        ),
                    ]
                )

            prefix.append(create_jump_absolute(target))

            # because the line number table monotonically increases from co_firstlineno
            # remove starts_line for any instructions before the graph break instruction
            # this will ensure the instructions after the break have the correct line numbers
            for inst in instructions:
                if inst.offset == target.offset:
                    break
                inst.starts_line = None
                if sys.version_info >= (3, 11):
                    inst.positions = None

            # handler / exit code produced by the hooks goes after the jump
            if cleanup:
                prefix.extend(cleanup)
                prefix.extend(cls.unreachable_codes(code_options))

            # remap original instructions' exception table entries
            if old_hook_target_remap:
                assert is_py311_plus
                for inst in instructions:
                    if (
                        inst.exn_tab_entry
                        and inst.exn_tab_entry.target in old_hook_target_remap
                    ):
                        inst.exn_tab_entry.target = old_hook_target_remap[  # type: ignore[assignment]
                            inst.exn_tab_entry.target
                        ]

            # TODO(jansel): add dead code elimination here
            instructions[:] = prefix + instructions

        new_code, _ = transform_code_object(code, update)
        ContinueExecutionCache.generated_code_metadata[new_code] = meta
        return new_code

    @staticmethod
    def unreachable_codes(code_options: dict[str, Any]) -> list[Instruction]:
        """Codegen a `raise None` to make analysis work for unreachable code"""
        # `code_options` is accepted but not used.
        return [
            create_load_const(None),
            create_instruction("RAISE_VARARGS", arg=1),
        ]

    @classmethod
    def generate_based_on_original_code_object(
        cls,
        code: types.CodeType,
        lineno: int,
        init_offset: int,
        resume_offset: int,
        setup_fn_target_offsets: tuple[int, ...],
        *args: Any,
    ) -> types.CodeType:
        """
        This handles the case of generating a resume into code generated
        to resume something else.  We want to always generate starting
        from the original code object so that if control flow paths
        converge we only generate 1 resume function (rather than 2^n
        resume functions).

        `code` must be a key of `generated_code_metadata`. All offsets passed
        in are relative to `code`; they are translated to offsets of
        `meta.code` before delegating to `lookup`. `args` are the remaining
        `generate` arguments, forwarded unchanged.
        """

        meta: ResumeFunctionMetadata = ContinueExecutionCache.generated_code_metadata[
            code
        ]

        # Translate an offset in `code` to the offset of the same instruction
        # in meta.code; returns -1 if it has no counterpart there.
        def find_orig_offset(cur_offset: int) -> int:
            orig_offset = -1

            def find_orig_offset_transform(
                instructions: list[Instruction], code_options: dict[str, Any]
            ) -> None:
                nonlocal orig_offset
                # exactly one instruction must have this offset
                (target,) = (i for i in instructions if i.offset == cur_offset)
                # match the functions starting at the last instruction as we have added a prefix
                new_target_tuple = tuple(
                    i2
                    for i1, i2 in zip(
                        reversed(instructions), reversed(meta.instructions)
                    )
                    if i1 is target
                )

                if not new_target_tuple:
                    # Instruction with cur_offset in instructions was not found
                    # in the original code - orig_offset left as -1.
                    # Caller expected to handle this case.
                    return

                assert len(new_target_tuple) == 1
                new_target = new_target_tuple[0]

                assert target.opcode == new_target.opcode
                assert new_target.offset is not None
                orig_offset = new_target.offset

            # transform_code_object is used here only to obtain `code`'s instructions
            transform_code_object(code, find_orig_offset_transform)
            return orig_offset

        orig_init_offset = find_orig_offset(init_offset)
        # It is fine if the initial instruction is not found in the original code;
        # this means we graph broke in the prefix, which only happens with nested graph breaks.
        # We should not be running into ambiguous graph break issues here.
        orig_resume_offset = find_orig_offset(resume_offset)
        assert orig_resume_offset > -1, (
            "resume instruction not found in original code - this is a bug."
        )

        if sys.version_info >= (3, 11):
            # setup_fn_target_offsets currently contains the target offset of
            # each setup_fn, based on `code`. When we codegen the resume function
            # based on the original code object, `meta.code`, the offsets in
            # setup_fn_target_offsets must be based on `meta.code` instead.
            offset_key = (orig_init_offset, orig_resume_offset)
            # NOTE: we key by offset_key since the same resume function may graph
            # break in multiple places and we need different block_target_offset_remap's
            # for each graph break location. Keying by orig_resume_offset may not be enough
            # if 2 graph breaks on different initial offsets resume on the same instruction
            # (although this is rare and not tested anywhere).
            if offset_key not in meta.block_target_offset_remap:
                block_target_offset_remap = meta.block_target_offset_remap[
                    offset_key
                ] = {}

                def remap_block_offsets(
                    instructions: list[Instruction], code_options: dict[str, Any]
                ) -> None:
                    # NOTE: each prefix block generates exactly one PUSH_EXC_INFO,
                    # so we can tell which block a prefix PUSH_EXC_INFO belongs to,
                    # by counting. Then we can use meta.prefix_block_target_offset_remap
                    # to determine where in the original code the PUSH_EXC_INFO offset
                    # replaced.
                    prefix_blocks: list[Instruction] = []
                    for inst in instructions:
                        # NOTE meta.prefix_block_target_offset_remap is based off of how we codegen'd
                        # context managers at the prefix/prologue of the resume function. It is the same for
                        # every graph break in the same resume function, so we do not need to recompute
                        # for each graph break (unlike for meta.block_target_offset_remap)
                        if len(prefix_blocks) == len(
                            meta.prefix_block_target_offset_remap
                        ):
                            break
                        if inst.opname == "PUSH_EXC_INFO":
                            prefix_blocks.append(inst)

                    # remap block target offsets for blocks generated in the resume prefix
                    for inst, o in zip(
                        prefix_blocks, meta.prefix_block_target_offset_remap
                    ):
                        block_target_offset_remap[cast(int, inst.offset)] = o

                    # current bytecode targets are after the prefix PUSH_EXC_INFO's
                    cur_start_offset = (
                        cast(int, prefix_blocks[-1].offset) if prefix_blocks else -1
                    )
                    # get the remaining block target offsets of the current bytecode
                    cur_inst_offsets = sorted(
                        n for n in setup_fn_target_offsets if n > cur_start_offset
                    )
                    targets = _filter_iter(
                        instructions, cur_inst_offsets, lambda inst, o: inst.offset == o
                    )
                    # The original code and resume code should have matching suffixes.
                    # Match the post-prefix block target offsets of the current resume code
                    # and the original code.
                    orig_targets = reversed(
                        _filter_iter(
                            zip(reversed(instructions), reversed(meta.instructions)),
                            reversed(targets),
                            lambda v1, v2: v1[0] is v2,
                        )
                    )
                    # orig is a (current instruction, original instruction) pair
                    for orig, cur in zip(orig_targets, targets):
                        block_target_offset_remap[cur.offset] = orig[1].offset

                transform_code_object(code, remap_block_offsets)

            # if offset_key or offset is not in setup_fn_target_offsets, it is an error
            # that needs to be fixed
            setup_fn_target_offsets = tuple(
                meta.block_target_offset_remap[offset_key][n]
                for n in setup_fn_target_offsets
            )
        return ContinueExecutionCache.lookup(
            meta.code,
            lineno,
            orig_init_offset,
            orig_resume_offset,
            setup_fn_target_offsets,
            *args,
        )