debugging.py 18 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531
  1. """
  2. This module provides debugging backends for TorchDynamo to help diagnose and troubleshoot
  3. compilation and execution issues. It includes:
  4. Key Debugging Backends:
  5. - eager: Simple pass-through backend that runs models in eager mode
  6. - eager_noexcept: Similar to eager but with additional exception handling
  7. - eager_debug: Adds schema validation checks for custom operators
  8. - aot_eager: Uses AOT Autograd with nop compiler for debugging
  9. - aot_eager_decomp_partition: Uses TorchInductor decompositions for debugging
  10. - torchscript: Compiles using TorchScript for debugging JIT-related issues
  11. Testing and Development Tools:
  12. - Backends for inducing specific errors (compile/runtime/accuracy)
  13. - ExplainOutput class for detailed graph compilation analysis
  14. - Utilities for cross-referencing and mode management
  15. - Tools for graph detail inspection and break reason analysis
  16. These backends are primarily used for:
  17. 1. Debugging graph breaks and compilation failures
  18. 2. Testing error handling and recovery mechanisms
  19. 3. Analyzing performance bottlenecks
  20. 4. Validating operator schemas and decompositions
  21. """
  22. import dataclasses
  23. import functools
  24. import logging
  25. from collections.abc import Iterable
  26. from importlib import import_module
  27. from typing import Any, Callable, Optional, TYPE_CHECKING, Union
  28. import torch
  29. from functorch.compile import min_cut_rematerialization_partition
  30. from torch import _guards
  31. from torch._dynamo.output_graph import GraphCompileReason
  32. from torch._functorch import config as functorch_config
  33. from torch._functorch.compilers import ts_compile
  34. from .common import aot_autograd
  35. from .registry import CompiledFn, CompilerFn, register_debug_backend as register_backend
  36. if TYPE_CHECKING:
  37. from torch.fx.node import Target
# Module-level logger; the backends below use it to warn about ignored kwargs.
log = logging.getLogger(__name__)
  39. @register_backend
  40. def eager(
  41. gm: torch.fx.GraphModule, fake_tensor_inputs: list[torch.Tensor], **kwargs: Any
  42. ) -> Callable[..., Any]:
  43. if kwargs:
  44. log.warning("eager backend ignoring extra kwargs %s", kwargs)
  45. return gm.forward
  46. def make_eager_backend_with_torch_function_mode(
  47. mode: torch.overrides.TorchFunctionMode,
  48. ) -> Callable[..., Any]:
  49. return make_eager_backend_with_torch_function_modes([mode])
  50. def make_eager_backend_with_torch_function_modes(
  51. modes: Iterable[torch.overrides.TorchFunctionMode],
  52. ) -> Callable[..., Any]:
  53. """Used to trace HOPs (cond and while) for eager execution, the metadata
  54. TF mode mutates vars outside of the scope of the HOP, and we can't have graph breaks
  55. in the HOP, so we need to externally run this mode and not trace it."""
  56. from contextlib import ExitStack
  57. def fn(
  58. gm: torch.fx.GraphModule, fake_tensor_inputs: list[torch.Tensor], **kwargs: Any
  59. ) -> Callable[..., Any]:
  60. stack = ExitStack()
  61. for mode in modes:
  62. stack.enter_context(mode)
  63. result = gm.forward
  64. stack.close()
  65. return result
  66. return fn
  67. @register_backend
  68. def eager_noexcept(
  69. gm: torch.fx.GraphModule, fake_tensor_inputs: list[torch.Tensor], **kwargs: Any
  70. ) -> Callable[..., Any]:
  71. if kwargs:
  72. log.warning("eager_noexcept backend ignoring extra kwargs %s", kwargs)
  73. # This backend is intended to check that dynamo-generated GraphModules
  74. # do not cause errors.
  75. def inner(*args: Any) -> Any:
  76. try:
  77. return gm(*args)
  78. except Exception as e:
  79. raise torch._dynamo.exc.TorchDynamoException(
  80. "Unexpected exception when running generated GraphModule"
  81. ) from e
  82. return inner
  83. @register_backend
  84. def pre_dispatch_eager(
  85. gm: torch.fx.GraphModule, fake_tensor_inputs: list[torch.Tensor], **kwargs: Any
  86. ) -> torch.fx.GraphModule:
  87. if kwargs:
  88. log.warning("pre_dispatch_eager backend ignoring extra kwargs %s", kwargs)
  89. from torch.fx.experimental.proxy_tensor import make_fx
  90. def runnable_gm(*args: Any) -> Any:
  91. return torch.fx.Interpreter(gm).run(*args)
  92. pre_dispatch_gm = make_fx(runnable_gm, pre_dispatch=True)(*fake_tensor_inputs)
  93. pre_dispatch_gm.print_readable()
  94. return pre_dispatch_gm
  95. @register_backend
  96. def eager_debug(
  97. gm: torch.fx.GraphModule, fake_tensor_inputs: list[torch.Tensor], **kwargs: Any
  98. ) -> Callable[..., Any]:
  99. if kwargs:
  100. log.warning("eager_debug backend ignoring extra kwargs %s", kwargs)
  101. from torch._subclasses.schema_check_mode import SchemaCheckMode
  102. # We could add more debugging bits here.
  103. # Right now, this backend can be used to check for and error on
  104. # custom dispatcher ops that have incorrect schemas.
  105. def inner(*args: Any) -> Any:
  106. with SchemaCheckMode():
  107. return torch.fx.Interpreter(gm).run(*args)
  108. return inner
  109. @register_backend(name="ts") # type: ignore[misc]
  110. def torchscript(
  111. gm: torch.fx.GraphModule, fake_tensor_inputs: list[torch.Tensor]
  112. ) -> torch.jit.ScriptModule:
  113. return torch.jit.script(gm)
  114. # used boxed call to discard inputs when they are no longer needed
  115. def boxed_nop(
  116. fx_g: torch.fx.GraphModule, example_inputs: list[torch.Tensor]
  117. ) -> Callable[..., Any]:
  118. def run(args: Any) -> Any:
  119. return torch.fx.Interpreter(fx_g).boxed_run(args)
  120. run._boxed_call = True # type: ignore[attr-defined]
  121. return run
  122. def boxed_nop_with_mode(
  123. fx_g: torch.fx.GraphModule,
  124. example_inputs: list[torch.Tensor],
  125. *,
  126. mode: torch.overrides.TorchFunctionMode,
  127. ) -> Callable[..., Any]:
  128. def run(args: Any) -> Any:
  129. with mode:
  130. return torch.fx.Interpreter(fx_g).boxed_run(args)
  131. run._boxed_call = True # type: ignore[attr-defined]
  132. return run
  133. def fake_crossref_boxed_nop(
  134. fx_g: torch.fx.GraphModule,
  135. example_inputs: list[torch.Tensor],
  136. ignore_op_fn: Optional[Callable[[torch._ops.OpOverload], bool]] = None,
  137. ) -> Callable[..., Any]:
  138. def run(args: Any) -> Any:
  139. with torch._subclasses.CrossRefFakeMode(ignore_op_fn):
  140. return torch.fx.Interpreter(fx_g).boxed_run(args)
  141. run._boxed_call = True # type: ignore[attr-defined]
  142. return run
  143. def ignore_builtins(op: torch._ops.OpOverload) -> bool:
  144. return op.namespace in ("aten", "prims", "prim")
  145. def get_nop_func() -> Callable[
  146. [torch.fx.GraphModule, list[torch.Tensor]], Callable[..., Any]
  147. ]:
  148. if not torch._functorch.config.fake_tensor_crossref:
  149. return boxed_nop
  150. elif torch._functorch.config.fake_tensor_crossref == "all":
  151. return fake_crossref_boxed_nop
  152. else:
  153. assert torch._functorch.config.fake_tensor_crossref == "custom_ops"
  154. return functools.partial(fake_crossref_boxed_nop, ignore_op_fn=ignore_builtins)
  155. # Useful for debugging purpose
  156. # aot_eager uses AOT Autograd backend with nop compiler. It is helpful in debugging.
  157. def aot_eager(
  158. gm: torch.fx.GraphModule,
  159. fake_tensor_inputs: list[torch.Tensor],
  160. fw_compiler: Optional[Callable[..., Any]] = None,
  161. bw_compiler: Optional[Callable[..., Any]] = None,
  162. **kwargs: Any,
  163. ) -> Callable[..., Any]:
  164. return aot_autograd(
  165. fw_compiler=fw_compiler or boxed_nop,
  166. bw_compiler=bw_compiler or boxed_nop,
  167. partition_fn=min_cut_rematerialization_partition,
  168. keep_inference_input_mutations=True,
  169. )(gm, fake_tensor_inputs, **kwargs)
  170. register_backend(name="aot_eager", compiler_fn=aot_eager)
  171. aot_eager_default_partitioner = aot_autograd(
  172. fw_compiler=boxed_nop, keep_inference_input_mutations=True
  173. )
  174. register_backend(
  175. name="aot_eager_default_partitioner", compiler_fn=aot_eager_default_partitioner
  176. )
  177. # Uses TorchInductor AOT Autograd decomps and partitioner to isolate aot vs
  178. # inductor problems.
  179. # aot_eager_decomp_partition just replaces the inductor compiler with nop to help
  180. # isolate inductor vs aot_eager errors
  181. def aot_eager_decomp_partition(
  182. gm: torch.fx.GraphModule, fake_tensor_inputs: list[torch.Tensor], **kwargs: Any
  183. ) -> Callable[..., Any]:
  184. if kwargs:
  185. log.warning(
  186. "aot_eager_decomp_partition backend ignoring extra kwargs %s", kwargs
  187. )
  188. from torch._inductor.compiler_bisector import CompilerBisector
  189. config_patches = {"unlift_effect_tokens": True}
  190. if bisect_changes := CompilerBisector.get_config_change(
  191. "aot_eager_decomp_partition"
  192. ):
  193. config_patches.update(bisect_changes) # type: ignore[arg-type]
  194. with functorch_config.patch(config_patches):
  195. return aot_autograd(
  196. # these are taken from memory_efficient_fusion()
  197. fw_compiler=get_nop_func(),
  198. bw_compiler=get_nop_func(),
  199. # NB: lambda here is to delay import of inductor
  200. decompositions=lambda: import_module(
  201. "torch._inductor.compile_fx"
  202. ).select_decomp_table(),
  203. partition_fn=functools.partial(
  204. min_cut_rematerialization_partition, compiler="inductor"
  205. ),
  206. )(gm, fake_tensor_inputs)
  207. register_backend(
  208. name="aot_eager_decomp_partition", compiler_fn=aot_eager_decomp_partition
  209. )
  210. # aot_eager_decomp_partition_with_mode is similar as aot_eager_decomp_partition,
  211. # except that it takes a TorchDispatchMode mode and run the fw/bw in the mode
  212. def aot_eager_decomp_partition_with_mode(
  213. gm: torch.fx.GraphModule,
  214. fake_tensor_inputs: list[torch.Tensor],
  215. mode: Any,
  216. **kwarg: Any,
  217. ) -> Callable[..., Any]:
  218. return aot_autograd(
  219. # these are taken from memory_efficient_fusion()
  220. fw_compiler=functools.partial(boxed_nop_with_mode, mode=mode),
  221. bw_compiler=functools.partial(boxed_nop_with_mode, mode=mode),
  222. # NB: lambda here is to delay import of inductor
  223. decompositions=lambda: import_module(
  224. "torch._inductor.compile_fx"
  225. ).select_decomp_table(),
  226. partition_fn=functools.partial(
  227. min_cut_rematerialization_partition, compiler="inductor"
  228. ),
  229. )(gm, fake_tensor_inputs)
  230. register_backend(
  231. name="aot_eager_decomp_partition_with_mode",
  232. compiler_fn=aot_eager_decomp_partition_with_mode, # type: ignore[arg-type]
  233. )
  234. def aot_eager_decomp_partition_crossref(
  235. gm: torch.fx.GraphModule, fake_tensor_inputs: list[torch.Tensor], **kwargs: Any
  236. ) -> Callable[..., Any]:
  237. # if the config is set, respect it, otherwise only test custom_ops.
  238. # custom_op bad metas always manifest as an error whereas aten will only sometimes.
  239. # by default, use the less noisy option
  240. config_val = (
  241. "custom_ops"
  242. if not functorch_config.fake_tensor_crossref
  243. else functorch_config.fake_tensor_crossref
  244. )
  245. with functorch_config.patch(fake_tensor_crossref=config_val):
  246. return aot_eager_decomp_partition(gm, fake_tensor_inputs, **kwargs)
  247. register_backend(
  248. name="aot_eager_decomp_partition_crossref",
  249. compiler_fn=aot_eager_decomp_partition_crossref,
  250. )
  251. # AOT Autograd with torchscript backend. Default partitioner.
  252. # aot_ts uses torchscript backend. We can use this with both nnc and nvfuser
  253. # by using the relevant fuser with torch.jit.fuser(...)
  254. aot_ts = aot_autograd(fw_compiler=ts_compile)
  255. register_backend(name="aot_ts", compiler_fn=aot_ts)
  256. # These buggy backends are used for inducing bugs so that we can test
  257. # our repro extraction / minifier scripts
  258. class ReluCompileError(Exception):
  259. pass
  260. class TestingOnlyCompileError(Exception):
  261. pass
  262. @register_backend
  263. def relu_compile_error_TESTING_ONLY(
  264. gm: torch.fx.GraphModule, example_inputs: list[torch.Tensor]
  265. ) -> torch.fx.GraphModule:
  266. for node in gm.graph.nodes:
  267. if node.target == torch.relu:
  268. raise ReluCompileError
  269. return gm
  270. @register_backend
  271. def relu_runtime_error_TESTING_ONLY(
  272. gm: torch.fx.GraphModule, example_inputs: list[torch.Tensor]
  273. ) -> torch.fx.GraphModule:
  274. for node in gm.graph.nodes:
  275. if node.target == torch.relu:
  276. node.target = torch._assert
  277. node.args = (False, "ReluRuntimeError")
  278. gm.recompile()
  279. return gm
  280. @register_backend
  281. def relu_accuracy_error_TESTING_ONLY(
  282. gm: torch.fx.GraphModule, example_inputs: list[torch.Tensor]
  283. ) -> torch.fx.GraphModule:
  284. for node in gm.graph.nodes:
  285. if node.target == torch.relu:
  286. node.target = torch.add
  287. node.args = (node.args[0], 1)
  288. gm.recompile()
  289. return gm
  290. @register_backend
  291. def non_leaf_compile_error_TESTING_ONLY(
  292. gm: torch.fx.GraphModule, example_inputs: list[torch.Tensor]
  293. ) -> torch.fx.GraphModule:
  294. # Require at least one non-trivial thing in the graph,
  295. # see https://github.com/pytorch/pytorch/issues/102898
  296. for node in gm.graph.nodes:
  297. if node.op == "call_function":
  298. break
  299. else:
  300. return gm
  301. for t in example_inputs:
  302. if not t.is_leaf:
  303. raise TestingOnlyCompileError
  304. return gm
  305. @dataclasses.dataclass
  306. class ExplainOutput:
  307. """
  308. This is the output of :func:`torch._dynamo.explain()`
  309. There is no reason to create this class directly.
  310. """
  311. graphs: list[torch.fx.GraphModule]
  312. graph_count: int
  313. graph_break_count: int
  314. break_reasons: list[GraphCompileReason]
  315. op_count: int
  316. ops_per_graph: Optional[list[list["Target"]]] = None
  317. out_guards: Optional[list[_guards.Guard]] = None
  318. compile_times: Optional[str] = None
  319. def __str__(self) -> str:
  320. output = f"Graph Count: {self.graph_count}\n"
  321. output += f"Graph Break Count: {self.graph_break_count}\n"
  322. output += f"Op Count: {self.op_count}\n"
  323. output += "Break Reasons:\n"
  324. for idx, break_reason in enumerate(self.break_reasons):
  325. output += f" Break Reason {idx + 1}:\n"
  326. output += f" Reason: {break_reason.reason}\n"
  327. output += " User Stack:\n"
  328. for frame_summary in break_reason.user_stack:
  329. output += f" {frame_summary}\n"
  330. if self.ops_per_graph is not None:
  331. output += "Ops per Graph:\n"
  332. for idx, ops in enumerate(self.ops_per_graph):
  333. output += f" Ops {idx + 1}:\n"
  334. for op in ops:
  335. output += f" {op}\n"
  336. if self.out_guards is not None:
  337. output += "Out Guards:\n"
  338. for i, guard in enumerate(self.out_guards):
  339. output += f" Guard {i + 1}:\n"
  340. output += f" {str(guard)}"
  341. if self.compile_times is not None:
  342. output += f"Compile Times: {self.compile_times}\n"
  343. return output
  344. def _explain_graph_detail(
  345. gm: torch.fx.GraphModule,
  346. graphs: list[torch.fx.GraphModule],
  347. op_count: int,
  348. ops_per_graph: list[list["Target"]],
  349. break_reasons: list[GraphCompileReason],
  350. ) -> tuple[
  351. torch.fx.GraphModule,
  352. list[torch.fx.GraphModule],
  353. int,
  354. list[list["Target"]],
  355. list[GraphCompileReason],
  356. ]:
  357. """
  358. This function is a utility which processes a torch.fx.GraphModule and
  359. accumulates information about its ops, graph breaks, and other details. It
  360. is intended to be used by the ExplainWithBackend class and
  361. `torch._dynamo.explain()` to provide details from Dynamo's graph capture.
  362. Parameters:
  363. gm (torch.fx.GraphModule): The GraphModule to be processed.
  364. graphs (list): A list that accumulates all the GraphModules processed.
  365. op_count (int): The total count of operations in all GraphModules processed so far.
  366. ops_per_graph (list): A list that accumulates the operations of each GraphModule.
  367. break_reasons (list): A list that accumulates the reasons for breaks in each GraphModule.
  368. Returns:
  369. tuple: A tuple containing the processed GraphModule, the updated lists of graphs,
  370. operations per graph, and break reasons, and the updated operation count.
  371. """
  372. graphs.append(gm)
  373. ops = [node.target for node in gm.graph.nodes if node.op == "call_function"]
  374. op_count += len(ops)
  375. ops_per_graph.append(ops)
  376. if gm.compile_subgraph_reason.graph_break: # type: ignore[union-attr]
  377. break_reasons.append(gm.compile_subgraph_reason) # type: ignore[arg-type]
  378. return gm, graphs, op_count, ops_per_graph, break_reasons
  379. class ExplainWithBackend:
  380. """
  381. This class is intended to be used as a backend for `torch.compile`. It is
  382. composable with other backends. When used in this way, it accumulates
  383. information about graph breaks, ops, and other info and provides a string
  384. representation summarizing this information.
  385. Attributes:
  386. backend (str): The name of the backend to use for optimization.
  387. graphs (list): A list of the graphs captured by TorchDynamo.
  388. op_count (int): The total number of operations in all optimized graphs.
  389. break_reasons (list): A list of graph break reasons with stack traces.
  390. Example Usage:
  391. def fn(x):
  392. x = torch.sigmoid(x)
  393. return x
  394. torch._dynamo.reset()
  395. eb = ExplainWithBackend("inductor")
  396. optimized_fn = torch.compile(fn, backend=eb)
  397. result = optimized_fn(torch.randn(5))
  398. print(eb.output())
  399. """
  400. def __init__(self, backend: Union[CompilerFn, str]) -> None:
  401. from .registry import lookup_backend
  402. self.backend = lookup_backend(backend)
  403. self.graphs: list[torch.fx.GraphModule] = []
  404. self.op_count = 0
  405. self.break_reasons: list[GraphCompileReason] = []
  406. def __call__(
  407. self, gm: torch.fx.GraphModule, example_inputs: list[torch.Tensor]
  408. ) -> CompiledFn:
  409. ops_per_graph: list[list[Target]] = []
  410. gm, self.graphs, self.op_count, _, self.break_reasons = _explain_graph_detail(
  411. gm, self.graphs, self.op_count, ops_per_graph, self.break_reasons
  412. )
  413. return self.backend(gm, example_inputs)
  414. def output(self) -> ExplainOutput:
  415. graph_count = len(self.graphs)
  416. output = ExplainOutput(
  417. self.graphs,
  418. graph_count,
  419. graph_count - 1,
  420. self.break_reasons,
  421. self.op_count,
  422. )
  423. return output