bytecode_analysis.py 8.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263
"""
This module provides utilities for analyzing and optimizing Python bytecode.
Key functionality includes:
- Dead code elimination
- Jump instruction optimization
- Stack size analysis and verification
- Live variable analysis
- Line number propagation and cleanup
- Exception table handling for Python 3.11+

The utilities in this module are used to analyze and transform bytecode
for better performance while maintaining correct semantics.
"""
  13. import bisect
  14. import dataclasses
  15. import dis
  16. import sys
  17. from typing import Any, TYPE_CHECKING, Union
  18. if TYPE_CHECKING:
  19. # TODO(lucaskabela): consider moving Instruction into this file
  20. # and refactoring in callsite; that way we don't have to guard this import
  21. from .bytecode_transformation import Instruction
  22. TERMINAL_OPCODES = {
  23. dis.opmap["RETURN_VALUE"],
  24. dis.opmap["JUMP_FORWARD"],
  25. dis.opmap["RAISE_VARARGS"],
  26. # TODO(jansel): double check exception handling
  27. }
  28. TERMINAL_OPCODES.add(dis.opmap["RERAISE"])
  29. if sys.version_info >= (3, 11):
  30. TERMINAL_OPCODES.add(dis.opmap["JUMP_BACKWARD"])
  31. TERMINAL_OPCODES.add(dis.opmap["JUMP_FORWARD"])
  32. else:
  33. TERMINAL_OPCODES.add(dis.opmap["JUMP_ABSOLUTE"])
  34. if (3, 12) <= sys.version_info < (3, 14):
  35. TERMINAL_OPCODES.add(dis.opmap["RETURN_CONST"])
  36. if sys.version_info >= (3, 13):
  37. TERMINAL_OPCODES.add(dis.opmap["JUMP_BACKWARD_NO_INTERRUPT"])
  38. JUMP_OPCODES = set(dis.hasjrel + dis.hasjabs)
  39. JUMP_OPNAMES = {dis.opname[opcode] for opcode in JUMP_OPCODES}
  40. HASLOCAL = set(dis.haslocal)
  41. HASFREE = set(dis.hasfree)
  42. stack_effect = dis.stack_effect
  43. def get_indexof(insts: list["Instruction"]) -> dict["Instruction", int]:
  44. """
  45. Get a mapping from instruction memory address to index in instruction list.
  46. Additionally checks that each instruction only appears once in the list.
  47. """
  48. indexof = {}
  49. for i, inst in enumerate(insts):
  50. assert inst not in indexof
  51. indexof[inst] = i
  52. return indexof
  53. def remove_dead_code(instructions: list["Instruction"]) -> list["Instruction"]:
  54. """Dead code elimination"""
  55. indexof = get_indexof(instructions)
  56. live_code = set()
  57. def find_live_code(start: int) -> None:
  58. for i in range(start, len(instructions)):
  59. if i in live_code:
  60. return
  61. live_code.add(i)
  62. inst = instructions[i]
  63. if inst.exn_tab_entry:
  64. find_live_code(indexof[inst.exn_tab_entry.target])
  65. if inst.opcode in JUMP_OPCODES:
  66. assert inst.target is not None
  67. find_live_code(indexof[inst.target])
  68. if inst.opcode in TERMINAL_OPCODES:
  69. return
  70. find_live_code(0)
  71. # change exception table entries if start/end instructions are dead
  72. # assumes that exception table entries have been propagated,
  73. # e.g. with bytecode_transformation.propagate_inst_exn_table_entries,
  74. # and that instructions with an exn_tab_entry lies within its start/end.
  75. if sys.version_info >= (3, 11):
  76. live_idx = sorted(live_code)
  77. for i, inst in enumerate(instructions):
  78. if i in live_code and inst.exn_tab_entry:
  79. # find leftmost live instruction >= start
  80. start_idx = bisect.bisect_left(
  81. live_idx, indexof[inst.exn_tab_entry.start]
  82. )
  83. assert start_idx < len(live_idx)
  84. # find rightmost live instruction <= end
  85. end_idx = (
  86. bisect.bisect_right(live_idx, indexof[inst.exn_tab_entry.end]) - 1
  87. )
  88. assert end_idx >= 0
  89. assert live_idx[start_idx] <= i <= live_idx[end_idx]
  90. inst.exn_tab_entry.start = instructions[live_idx[start_idx]]
  91. inst.exn_tab_entry.end = instructions[live_idx[end_idx]]
  92. return [inst for i, inst in enumerate(instructions) if i in live_code]
  93. def remove_pointless_jumps(instructions: list["Instruction"]) -> list["Instruction"]:
  94. """Eliminate jumps to the next instruction"""
  95. pointless_jumps = {
  96. id(a)
  97. for a, b in zip(instructions, instructions[1:])
  98. if a.opname == "JUMP_ABSOLUTE" and a.target is b
  99. }
  100. return [inst for inst in instructions if id(inst) not in pointless_jumps]
  101. def propagate_line_nums(instructions: list["Instruction"]) -> None:
  102. """Ensure every instruction has line number set in case some are removed"""
  103. cur_line_no = None
  104. def populate_line_num(inst: "Instruction") -> None:
  105. nonlocal cur_line_no
  106. if inst.starts_line:
  107. cur_line_no = inst.starts_line
  108. inst.starts_line = cur_line_no
  109. for inst in instructions:
  110. populate_line_num(inst)
  111. def remove_extra_line_nums(instructions: list["Instruction"]) -> None:
  112. """Remove extra starts line properties before packing bytecode"""
  113. cur_line_no = None
  114. def remove_line_num(inst: "Instruction") -> None:
  115. nonlocal cur_line_no
  116. if inst.starts_line is None:
  117. return
  118. elif inst.starts_line == cur_line_no:
  119. inst.starts_line = None
  120. else:
  121. cur_line_no = inst.starts_line
  122. for inst in instructions:
  123. remove_line_num(inst)
  124. @dataclasses.dataclass
  125. class ReadsWrites:
  126. reads: set[Any]
  127. writes: set[Any]
  128. visited: set[Any]
  129. def livevars_analysis(
  130. instructions: list["Instruction"], instruction: "Instruction"
  131. ) -> set[Any]:
  132. indexof = get_indexof(instructions)
  133. must = ReadsWrites(set(), set(), set())
  134. may = ReadsWrites(set(), set(), set())
  135. def walk(state: ReadsWrites, start: int) -> None:
  136. if start in state.visited:
  137. return
  138. state.visited.add(start)
  139. for i in range(start, len(instructions)):
  140. inst = instructions[i]
  141. if inst.opcode in HASLOCAL or inst.opcode in HASFREE:
  142. if "LOAD" in inst.opname or "DELETE" in inst.opname:
  143. if inst.argval not in must.writes:
  144. state.reads.add(inst.argval)
  145. elif "STORE" in inst.opname:
  146. state.writes.add(inst.argval)
  147. elif inst.opname == "MAKE_CELL":
  148. pass
  149. else:
  150. raise NotImplementedError(f"unhandled {inst.opname}")
  151. if inst.exn_tab_entry:
  152. walk(may, indexof[inst.exn_tab_entry.target])
  153. if inst.opcode in JUMP_OPCODES:
  154. assert inst.target is not None
  155. walk(may, indexof[inst.target])
  156. state = may
  157. if inst.opcode in TERMINAL_OPCODES:
  158. return
  159. walk(must, indexof[instruction])
  160. return must.reads | may.reads
  161. @dataclasses.dataclass
  162. class FixedPointBox:
  163. value: bool = True
  164. @dataclasses.dataclass
  165. class StackSize:
  166. low: Union[int, float]
  167. high: Union[int, float]
  168. fixed_point: FixedPointBox
  169. def zero(self) -> None:
  170. self.low = 0
  171. self.high = 0
  172. self.fixed_point.value = False
  173. def offset_of(self, other: "StackSize", n: int) -> None:
  174. prior = (self.low, self.high)
  175. self.low = min(self.low, other.low + n)
  176. self.high = max(self.high, other.high + n)
  177. if (self.low, self.high) != prior:
  178. self.fixed_point.value = False
  179. def exn_tab_jump(self, depth: int) -> None:
  180. prior = (self.low, self.high)
  181. self.low = min(self.low, depth)
  182. self.high = max(self.high, depth)
  183. if (self.low, self.high) != prior:
  184. self.fixed_point.value = False
  185. def stacksize_analysis(instructions: list["Instruction"]) -> Union[int, float]:
  186. assert instructions
  187. fixed_point = FixedPointBox()
  188. stack_sizes = {
  189. inst: StackSize(float("inf"), float("-inf"), fixed_point)
  190. for inst in instructions
  191. }
  192. stack_sizes[instructions[0]].zero()
  193. for _ in range(100):
  194. if fixed_point.value:
  195. break
  196. fixed_point.value = True
  197. for inst, next_inst in zip(instructions, instructions[1:] + [None]):
  198. stack_size = stack_sizes[inst]
  199. if inst.opcode not in TERMINAL_OPCODES:
  200. assert next_inst is not None, f"missing next inst: {inst}"
  201. eff = stack_effect(inst.opcode, inst.arg, jump=False)
  202. stack_sizes[next_inst].offset_of(stack_size, eff)
  203. if inst.opcode in JUMP_OPCODES:
  204. assert inst.target is not None, f"missing target: {inst}"
  205. stack_sizes[inst.target].offset_of(
  206. stack_size, stack_effect(inst.opcode, inst.arg, jump=True)
  207. )
  208. if inst.exn_tab_entry:
  209. # see https://github.com/python/cpython/blob/3.11/Objects/exception_handling_notes.txt
  210. # on why depth is computed this way.
  211. depth = inst.exn_tab_entry.depth + int(inst.exn_tab_entry.lasti) + 1
  212. stack_sizes[inst.exn_tab_entry.target].exn_tab_jump(depth)
  213. low = min(x.low for x in stack_sizes.values())
  214. high = max(x.high for x in stack_sizes.values())
  215. assert fixed_point.value, "failed to reach fixed point"
  216. assert low >= 0
  217. return high