external_utils.py 8.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280
"""
This module contains utility functions that are explicitly allowed to be called during
TorchDynamo compilation. These functions are carefully vetted to ensure they work
correctly within the TorchDynamo tracing and compilation process.

Key functionality groups:

- Compilation State:
  Functions for checking compilation state (is_compiling)

- Function Wrapping:
  Utilities for wrapping functions (wrap_inline, wrap_numpy) to work with
  TorchDynamo compilation

- Autograd Hooks:
  Functions and classes for handling autograd hooks and backward passes
  (call_hook, FakeBackwardCFunction, etc.)

- Tensor Operations:
  Utility functions for tensor operations and transformations
"""
  17. import functools
  18. import warnings
  19. from typing import Any, Callable, Optional, TYPE_CHECKING, TypeVar, Union
  20. from typing_extensions import deprecated, ParamSpec
  21. import torch
  22. import torch.utils._pytree as pytree
try:
    import numpy as np
except ModuleNotFoundError:
    # numpy is optional: when it is missing, `np` is None and code in this
    # module must check for that before touching numpy.
    np = None  # type: ignore[assignment]
# Type variables for wrappers that preserve the wrapped callable's parameter
# list (_P) and return type (_R).
_P = ParamSpec("_P")
_R = TypeVar("_R")
  29. if TYPE_CHECKING:
  30. # TorchScript does not support `@deprecated`
  31. # This is a workaround to avoid breaking TorchScript
  32. @deprecated(
  33. "`torch._dynamo.external_utils.is_compiling` is deprecated. Use `torch.compiler.is_compiling` instead.",
  34. category=FutureWarning,
  35. )
  36. def is_compiling() -> bool:
  37. return torch.compiler.is_compiling()
  38. else:
  39. def is_compiling() -> bool:
  40. """
  41. Indicates whether we are tracing/compiling with torch.compile() or torch.export().
  42. """
  43. # NOTE: With `@torch.compile(backend="eager")`, torch._dynamo.is_compiling() will get traced
  44. # and return true. torch.compiler.is_compiling() is skipped and will return false.
  45. return torch.compiler.is_compiling()
  46. def wrap_inline(fn: Callable[_P, _R]) -> Callable[_P, _R]:
  47. """
  48. Create an extra frame around fn that is not in skipfiles.
  49. """
  50. @functools.wraps(fn)
  51. def inner(*args: _P.args, **kwargs: _P.kwargs) -> _R:
  52. return fn(*args, **kwargs)
  53. return inner
  54. def call_hook(
  55. hook: Callable[..., Optional[torch.Tensor]], *args: Any, **kwargs: Any
  56. ) -> torch.Tensor:
  57. """
  58. Used by compiled autograd to handle hook returning None.
  59. """
  60. result = hook(*args)
  61. if result is None:
  62. return args[0]
  63. elif kwargs.get("hook_type") == "post_acc_grad_hook":
  64. raise RuntimeError("Tensor post accumulate grad hooks should return None.")
  65. return result
  66. def wrap_numpy(f: Callable[_P, _R]) -> Callable[_P, _R]:
  67. r"""Decorator that turns a function from ``np.ndarray``s to ``np.ndarray``s into a function
  68. from ``torch.Tensor``s to ``torch.Tensor``s.
  69. """
  70. if not np:
  71. return f
  72. @functools.wraps(f)
  73. def wrap(*args: _P.args, **kwargs: _P.kwargs) -> pytree.PyTree:
  74. args, kwargs = pytree.tree_map_only(
  75. torch.Tensor, lambda x: x.numpy(), (args, kwargs)
  76. )
  77. out = f(*args, **kwargs)
  78. return pytree.tree_map_only(np.ndarray, lambda x: torch.as_tensor(x), out)
  79. return wrap
  80. class FakeBackwardCFunction:
  81. def __init__(
  82. self,
  83. real: torch.autograd.function.BackwardCFunction,
  84. saved_tensors: list[torch.Tensor],
  85. ) -> None:
  86. self.real = real
  87. self.saved_tensors = saved_tensors
  88. def __getattr__(self, name: str) -> Any:
  89. if name == "saved_variables":
  90. warnings.warn(
  91. "'saved_variables' is deprecated; use 'saved_tensors'",
  92. DeprecationWarning,
  93. )
  94. return self.saved_tensors
  95. return getattr(self.real, name)
  96. def call_backward(
  97. backward_c_function: torch.autograd.function.BackwardCFunction,
  98. saved_tensors: list[torch.Tensor],
  99. *args: Any,
  100. ) -> Union[torch.Tensor, tuple[torch.Tensor, ...]]:
  101. fake = FakeBackwardCFunction(backward_c_function, saved_tensors)
  102. grads = fake._forward_cls.backward(fake, *args) # type: ignore[attr-defined]
  103. if not isinstance(grads, tuple):
  104. grads = (grads,)
  105. return grads
  106. def normalize_as_list(x: Any) -> list[Any]:
  107. if isinstance(x, tuple):
  108. return list(x)
  109. elif isinstance(x, list):
  110. return x
  111. return [x]
  112. def untyped_storage_size(x: torch.Tensor) -> int:
  113. return x.untyped_storage().size()
  114. class FakeCompiledAutogradEngine:
  115. @staticmethod
  116. def queue_callback(
  117. final_callbacks: list[Callable[[], None]], cb: Callable[[], None]
  118. ) -> None:
  119. final_callbacks.append(cb)
  120. @staticmethod
  121. def exec_final_callbacks(final_callbacks: list[Callable[[], None]]) -> None:
  122. i = 0
  123. while i < len(final_callbacks):
  124. cb = final_callbacks[i]
  125. cb()
  126. i += 1
  127. final_callbacks.clear()
  128. @staticmethod
  129. def _exec_final_callbacks_stub() -> None:
  130. pass
  131. def call_hook_from_backward_state(
  132. *args: Any, bw_state: Any, hook_name: str, **kwargs: Any
  133. ) -> Any:
  134. return getattr(bw_state, hook_name)(*args, **kwargs)
  135. def call_module_hooks_from_backward_state(
  136. _: Any, result: Any, *args: Any, bw_state: Any, hooks_name: str, module_name: str
  137. ) -> Any:
  138. module = getattr(bw_state, module_name)
  139. hooks = getattr(bw_state, hooks_name)
  140. for hook in hooks:
  141. new_result = hook(module, result, *args)
  142. if new_result is not None:
  143. result = new_result
  144. return result
# used for torch._dynamo.disable(recursive=False)
def get_nonrecursive_disable_wrapper(fn: Callable[_P, _R]) -> Callable[_P, _R]:
    """
    Return a pass-through wrapper around ``fn``.

    The wrapper forwards ``*args``/``**kwargs`` to ``fn`` and returns its result.
    It carries ``fn``'s metadata via ``functools.wraps``. Its value lies in the
    frame it adds, which is defined in this module.
    """
    # wrap function to get the right error message
    # this function is in external_utils so that convert_frame doesn't skip it.
    # NOTE(review): the inner function's name/location may be relied on elsewhere
    # in Dynamo — confirm before renaming or moving it.
    @functools.wraps(fn)
    def nonrecursive_disable_wrapper(*args: _P.args, **kwargs: _P.kwargs) -> _R:
        return fn(*args, **kwargs)

    return nonrecursive_disable_wrapper
def wrap_dunder_call_ctx_manager(self: Any, func: Callable[_P, _R]) -> Callable[_P, _R]:
    """
    Apply self as a ctx manager around a call to func.

    Returns a new callable. On each call it enters ``self`` in a ``with``
    statement, calls ``func(*args, **kwargs)`` inside that context, and returns
    the result. ``self`` must therefore support the context-manager protocol.
    """
    # NOTE: do not functools.wraps(func) because we don't ever want this frame to be skipped!
    def inner(*args: _P.args, **kwargs: _P.kwargs) -> _R:
        with self:
            return func(*args, **kwargs)

    return inner
  162. # Use only on ints marked dynamic via torch.empty(0, integer)
  163. # Currently only way to mark ints as dynamic: https://github.com/pytorch/pytorch/issues/129623
  164. def unwrap_maybe_dynamic_int(x: Union[torch.Tensor, int]) -> int:
  165. if isinstance(x, torch.Tensor):
  166. # x.size() is expected to be [0, dynamic_int]
  167. return x.size(1)
  168. return x
  169. def call_accumulate_grad(
  170. variable: torch.Tensor, grad: torch.Tensor, has_post_hooks: bool
  171. ) -> None:
  172. updated_grad = torch._dynamo.compiled_autograd.ops.AccumulateGrad( # type: ignore[attr-defined]
  173. [grad], variable, variable.grad, has_post_hooks
  174. )
  175. variable.grad = updated_grad[0]
def wrap_inline_with_error_on_graph_break(
    fn: Callable[_P, _R], error_on_graph_break: bool
) -> Callable[_P, _R]:
    """
    Return a pass-through wrapper that calls ``fn`` inside
    ``torch._dynamo.error_on_graph_break(error_on_graph_break)``.

    Args:
        fn: callable to wrap. It is called with the wrapper's
            ``*args``/``**kwargs``, and its result is returned.
        error_on_graph_break: flag value passed to
            ``torch._dynamo.error_on_graph_break`` on every call.

    Returns:
        A new callable. ``fn``'s metadata is intentionally not copied onto it.
    """
    # NB: need multiple definitions in order to prevent `fullgraph` from
    # being a freevar of wrapper
    # NOTE(review): `fullgraph` above appears to refer to what is now the
    # `error_on_graph_break` parameter — each branch hard-codes the literal, so
    # `fn` is the only closed-over variable.
    # NOTE: do not functools.wraps(fn) because we don't ever want these wrappers to be skipped!
    if error_on_graph_break:

        def wrapper(*args: _P.args, **kwargs: _P.kwargs) -> _R:
            with torch._dynamo.error_on_graph_break(True):
                return fn(*args, **kwargs)

    else:

        def wrapper(*args: _P.args, **kwargs: _P.kwargs) -> _R:
            with torch._dynamo.error_on_graph_break(False):
                return fn(*args, **kwargs)

    return wrapper
  191. def filter_out_const_values(tup: tuple[Any, ...], masks: list[bool]) -> tuple[Any, ...]:
  192. """
  193. masks is a list of bools, where True means the corresponding element in tup
  194. is a const value. Filter out the const values.
  195. """
  196. out = []
  197. for mask_idx, mask in enumerate(masks):
  198. if not mask:
  199. out.append(tup[mask_idx])
  200. return tuple(out)
  201. def insert_const_values_with_mask(
  202. tup: tuple[Any, ...], masks: list[bool], values: tuple[Any, ...]
  203. ) -> tuple[Any, ...]:
  204. """
  205. masks and values are of same length. For indices where the mask is True, use
  206. the const_values to fill in.
  207. """
  208. out = []
  209. idx = 0
  210. for mask_idx, mask in enumerate(masks):
  211. if mask:
  212. out.append(values[mask_idx])
  213. else:
  214. out.append(tup[idx])
  215. idx += 1
  216. return tuple(out)