common.py 6.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177
"""
This module provides common utilities and base classes for TorchDynamo backends.

Key components:
- AotAutograd: Base class for implementing AOT (Ahead-of-Time) autograd backends
- Backend utilities for handling:
  - Fake tensor conversion
  - Device/dtype detection from inputs
  - Memory efficient fusion
  - Graph flattening
  - Common compiler configurations

The utilities here are used by various backend implementations to handle
common operations and provide consistent behavior across different backends.
AOT autograd functionality is particularly important as it enables ahead-of-time
optimization of both forward and backward passes.
"""
import contextlib
import functools
import logging
from collections.abc import Iterable
from typing import Any, Callable
from typing_extensions import ParamSpec, TypeVar
from unittest.mock import patch

import torch
from torch._dynamo import disable
from torch._dynamo.exc import TensorifyScalarRestartAnalysis
from torch._dynamo.utils import counters, defake, flatten_graph_inputs
from torch._functorch.aot_autograd import (
    aot_module_simplified,
    SerializableAOTDispatchCompiler,
)
from torch.utils._python_dispatch import _disable_current_modes


# Module logger; AotAutograd.__call__ uses it to warn about ignored kwargs.
log = logging.getLogger(__name__)

# Type variables used in annotations in this module: P captures the
# parameter spec of a wrapped callable (wrap_bw_compiler inside
# AotAutograd.__call__) and R its return type (also used by
# fake_tensor_unsupported's signature). Note these are evaluated at
# definition time, so they must exist before those functions run/are defined.
P = ParamSpec("P")
R = TypeVar("R")
  35. class AotAutograd:
  36. def __init__(self, **kwargs: Any) -> None:
  37. self.__name__ = "compiler_fn"
  38. self.kwargs = kwargs
  39. def __call__(
  40. self, gm: torch.fx.GraphModule, example_inputs: Iterable[Any], **kwargs: Any
  41. ) -> Callable[..., Any]:
  42. if kwargs:
  43. log.warning("aot_autograd-based backend ignoring extra kwargs %s", kwargs)
  44. if any(isinstance(x, (list, tuple, dict)) for x in example_inputs):
  45. return flatten_graph_inputs(
  46. gm,
  47. example_inputs,
  48. self,
  49. )
  50. # Hack to get around circular import problems with aot_eager_decomp_partition
  51. if callable(self.kwargs.get("decompositions")):
  52. self.kwargs["decompositions"] = self.kwargs["decompositions"]()
  53. # NB: dont delete counter increment
  54. counters["aot_autograd"]["total"] += 1
  55. use_fallback = False
  56. if use_fallback:
  57. log.debug("Unable to use AOT Autograd because graph has mutation")
  58. counters["aot_autograd"]["not_ok"] += 1
  59. return gm
  60. def wrap_bw_compiler(bw_compiler_fn: Callable[P, R]) -> Callable[..., R]:
  61. def _wrapped_bw_compiler(*args: P.args, **kwargs: P.kwargs) -> R:
  62. # Note [Wrapping bw_compiler in disable]
  63. # The two disables here:
  64. # - stop TorchDynamo from trying to compile the bw_compiler function itself
  65. # - stop TorchDynamo from trying to compile our the generated backwards pass bw_compiler produces
  66. return disable(
  67. disable(
  68. bw_compiler_fn, reason="do not trace backward compiler function"
  69. )(*args, **kwargs), # type: ignore[misc]
  70. reason="do not trace generated backwards pass",
  71. )
  72. return _wrapped_bw_compiler
  73. bw_compiler = self.kwargs.get("bw_compiler") or self.kwargs["fw_compiler"]
  74. if isinstance(bw_compiler, SerializableAOTDispatchCompiler):
  75. bw_compiler.compiler_fn = wrap_bw_compiler(bw_compiler.compiler_fn)
  76. else:
  77. bw_compiler = wrap_bw_compiler(bw_compiler)
  78. self.kwargs["bw_compiler"] = bw_compiler
  79. self.kwargs["inference_compiler"] = (
  80. self.kwargs.get("inference_compiler") or self.kwargs["fw_compiler"]
  81. )
  82. from functorch.compile import nop
  83. from torch._inductor.debug import enable_aot_logging
  84. # debug asserts slow down compile time noticeably,
  85. # So only default them on when the aot_eager backend is used.
  86. if self.kwargs.get("fw_compiler", None) == nop:
  87. patch_config: contextlib.AbstractContextManager[Any] = patch(
  88. "functorch.compile.config.debug_assert", True
  89. )
  90. else:
  91. patch_config = contextlib.nullcontext()
  92. try:
  93. # NB: NOT cloned!
  94. with enable_aot_logging(), patch_config:
  95. cg = aot_module_simplified(gm, example_inputs, **self.kwargs)
  96. counters["aot_autograd"]["ok"] += 1
  97. return disable(cg, reason="do not trace AOT-compiled graph")
  98. except TensorifyScalarRestartAnalysis:
  99. raise
  100. except Exception:
  101. counters["aot_autograd"]["not_ok"] += 1
  102. raise
  103. def aot_autograd(**kwargs: Any) -> AotAutograd:
  104. return AotAutograd(**kwargs)
  105. def mem_efficient_fusion_kwargs(use_decomps: bool) -> dict[str, Any]:
  106. from functorch.compile import (
  107. default_decompositions,
  108. min_cut_rematerialization_partition,
  109. ts_compile,
  110. )
  111. kwargs = {
  112. # these are taken from memory_efficient_fusion()
  113. "fw_compiler": ts_compile,
  114. "bw_compiler": ts_compile,
  115. "partition_fn": min_cut_rematerialization_partition,
  116. }
  117. if use_decomps:
  118. kwargs["decompositions"] = default_decompositions
  119. return kwargs
  120. def fake_tensor_unsupported(fn: Callable[[Any, list[Any], Any], R]) -> Any:
  121. """
  122. Decorator for backends that need real inputs. We swap out fake
  123. tensors for zero tensors.
  124. """
  125. @functools.wraps(fn)
  126. def wrapper(model: Any, inputs: Any, **kwargs: Any) -> Any:
  127. with _disable_current_modes():
  128. inputs = list(map(defake, inputs))
  129. return fn(model, inputs, **kwargs) # type: ignore[call-arg]
  130. return wrapper
  131. def device_from_inputs(example_inputs: Iterable[Any]) -> torch.device:
  132. for x in example_inputs:
  133. if hasattr(x, "device"):
  134. return x.device
  135. return torch.device("cpu") # Default fallback
  136. def dtype_from_inputs(example_inputs: Iterable[Any]) -> torch.dtype:
  137. for x in example_inputs:
  138. if hasattr(x, "dtype"):
  139. return x.dtype
  140. return torch.float32 # Default fallback