graphs.py 27 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606
  1. from __future__ import annotations
  2. import gc
  3. import typing
  4. from typing import Callable, Optional, overload, TYPE_CHECKING, Union
  5. from typing_extensions import ParamSpec, Self, TypeAlias, TypeVar
  6. import torch
  7. from torch import Tensor
  8. if TYPE_CHECKING:
  9. # importing _POOL_HANDLE at runtime toplevel causes an import cycle
  10. from torch.cuda import _POOL_HANDLE
  11. from .._utils import _dummy_type
  # Public API of this module: the names a star-import of this module exposes.
  __all__ = [
      "is_current_stream_capturing",
      "graph_pool_handle",
      "CUDAGraph",
      "graph",
      "make_graphed_callables",
  ]
  19. _R = TypeVar("_R")
  20. _P = ParamSpec("_P")
  21. if not hasattr(torch._C, "_CudaStreamBase"):
  22. # Define dummy base classes
  23. torch._C.__dict__["_CUDAGraph"] = _dummy_type("_CUDAGraph")
  24. torch._C.__dict__["_graph_pool_handle"] = _dummy_type("_graph_pool_handle")
  25. torch._C.__dict__["_cuda_isCurrentStreamCapturing"] = _dummy_type(
  26. "_cuda_isCurrentStreamCapturing"
  27. )
  28. from torch._C import ( # noqa: F401
  29. _cuda_isCurrentStreamCapturing,
  30. _CUDAGraph,
  31. _graph_pool_handle,
  32. )
  def is_current_stream_capturing() -> bool:
      r"""Report whether a CUDA graph capture is in progress on the current CUDA stream.

      Returns ``True`` while CUDA graph capture is underway on the current CUDA
      stream and ``False`` otherwise. When no CUDA context exists on the current
      device, ``False`` is returned and no context is initialized.
      """
      capturing = _cuda_isCurrentStreamCapturing()
      return capturing
  # Python shim helps Sphinx process docstrings more reliably.
  def graph_pool_handle() -> _POOL_HANDLE:
      r"""Create an opaque token that identifies a graph memory pool.

      The token can be passed as ``pool`` to ``CUDAGraph.capture_begin``,
      :class:`~torch.cuda.graph` or :func:`~torch.cuda.make_graphed_callables`.
      See :ref:`Graph memory management<graph-memory-management>`.

      .. warning::
          This API is in beta and may change in future releases.
      """
      raw_handle = _graph_pool_handle()
      return torch.cuda._POOL_HANDLE(raw_handle)
  46. # Python shim helps Sphinx process docstrings more reliably.
  47. class CUDAGraph(torch._C._CUDAGraph):
  48. r"""Wrapper around a CUDA graph.
  49. Arguments:
  50. keep_graph (bool, optional): If ``keep_graph=False``, the
  51. cudaGraphExec_t will be instantiated on GPU at the end of
  52. ``capture_end`` and the underlying cudaGraph_t will be
  53. destroyed. Users who want to query or otherwise modify the
  54. underlying cudaGraph_t before instantiation can set
  55. ``keep_graph=True`` and access it via ``raw_cuda_graph`` after
  56. ``capture_end``. Note that the cudaGraphExec_t will not be
  57. instantiated at the end of ``capture_end`` in this
  58. case. Instead, it will be instantiated via an explicit called
  59. to ``instantiate`` or automatically on the first call to
  60. ``replay`` if ``instantiate`` was not already called. Calling
  61. ``instantiate`` manually before ``replay`` is recommended to
  62. prevent increased latency on the first call to ``replay``. It
  63. is allowed to modify the raw cudaGraph_t after first calling
  64. ``instantiate``, but the user must call ``instantiate`` again
  65. manually to make sure the instantiated graph has these
  66. changes. Pytorch has no means of tracking these changes.
  67. .. warning::
  68. This API is in beta and may change in future releases.
  69. """
  70. def __new__(cls, keep_graph: bool = False) -> Self:
  71. return super().__new__(cls, keep_graph)
  72. def capture_begin(
  73. self, pool: Optional[_POOL_HANDLE] = None, capture_error_mode: str = "global"
  74. ) -> None:
  75. r"""Begin capturing CUDA work on the current stream.
  76. Typically, you shouldn't call ``capture_begin`` yourself.
  77. Use :class:`~torch.cuda.graph` or :func:`~torch.cuda.make_graphed_callables`,
  78. which call ``capture_begin`` internally.
  79. Arguments:
  80. pool (optional): Token (returned by :func:`~torch.cuda.graph_pool_handle` or
  81. :meth:`other_Graph_instance.pool()<torch.cuda.CUDAGraph.pool>`) that hints this graph may share memory
  82. with the indicated pool. See :ref:`Graph memory management<graph-memory-management>`.
  83. capture_error_mode (str, optional): specifies the cudaStreamCaptureMode for the graph capture stream.
  84. Can be "global", "thread_local" or "relaxed". During cuda graph capture, some actions, such as cudaMalloc,
  85. may be unsafe. "global" will error on actions in other threads, "thread_local" will only error for
  86. actions in the current thread, and "relaxed" will not error on these actions. Do NOT change this setting
  87. unless you're familiar with `cudaStreamCaptureMode <https://docs.nvidia.com/cuda/cuda-runtime-api/group__CUDART__STREAM.html#group__CUDART__STREAM_1g9d0535d93a214cbf126835257b16ba85>`_
  88. """ # noqa: B950
  89. super().capture_begin(pool=pool, capture_error_mode=capture_error_mode)
  90. def capture_end(self) -> None:
  91. r"""End CUDA graph capture on the current stream.
  92. After ``capture_end``, ``replay`` may be called on this instance.
  93. Typically, you shouldn't call ``capture_end`` yourself.
  94. Use :class:`~torch.cuda.graph` or :func:`~torch.cuda.make_graphed_callables`,
  95. which call ``capture_end`` internally.
  96. """
  97. super().capture_end()
  98. def instantiate(self) -> None:
  99. r"""Instantiate the CUDA graph. Will be called by
  100. ``capture_end`` if ``keep_graph=False``, or by ``replay`` if
  101. ``keep_graph=True`` and ``instantiate`` has not already been
  102. explicitly called. Does not destroy the cudaGraph_t returned
  103. by ``raw_cuda_graph``.
  104. """
  105. super().instantiate()
  106. def replay(self) -> None:
  107. r"""Replay the CUDA work captured by this graph."""
  108. super().replay()
  109. def reset(self) -> None:
  110. r"""Delete the graph currently held by this instance."""
  111. super().reset()
  112. def pool(self) -> _POOL_HANDLE:
  113. r"""Return an opaque token representing the id of this graph's memory pool.
  114. This id can optionally be passed to another graph's ``capture_begin``,
  115. which hints the other graph may share the same memory pool.
  116. """
  117. return super().pool()
  118. def enable_debug_mode(self) -> None:
  119. r"""Enable debugging mode for CUDAGraph.debug_dump."""
  120. return super().enable_debug_mode()
  121. def debug_dump(self, debug_path: str) -> None:
  122. r"""
  123. Arguments:
  124. debug_path (required): Path to dump the graph to.
  125. Calls a debugging function to dump the graph if the debugging is
  126. enabled via CUDAGraph.enable_debug_mode()
  127. """
  128. return super().debug_dump(debug_path)
  129. def raw_cuda_graph(self) -> int:
  130. r"""Returns the underlying cudaGraph_t. ``keep_graph`` must be True.
  131. See the following for APIs for how to manipulate this object: `Graph Managmement <https://docs.nvidia.com/cuda/cuda-runtime-api/group__CUDART__GRAPH.html>`_ and `cuda-python Graph Management bindings <https://nvidia.github.io/cuda-python/cuda-bindings/latest/module/runtime.html#graph-management>`_
  132. """ # noqa: B950
  133. return super().raw_cuda_graph()
  134. def raw_cuda_graph_exec(self) -> int:
  135. r"""Returns the underlying cudaGraphExec_t. ``instantiate`` must have been called if ``keep_graph`` is True, or ``capture_end`` must have been called if ``keep_graph`` is False. If you call ``instantiate()`` after ``raw_cuda_graph_exec()``, the previously returned cudaGraphExec_t will be destroyed. It is your responsibility not to use this object after destruction.
  136. See the following for APIs for how to manipulate this object: `Graph Execution <https://docs.nvidia.com/cuda/cuda-runtime-api/group__CUDART__GRAPH__EXEC.html>`_ and `cuda-python Graph Execution bindings <https://nvidia.github.io/cuda-python/cuda-bindings/latest/module/runtime.html#graph-execution>`_
  137. """ # noqa: B950
  138. return super().raw_cuda_graph_exec()
  class graph:
      r"""Context-manager that captures CUDA work into a :class:`torch.cuda.CUDAGraph` object for later replay.

      See :ref:`CUDA Graphs <cuda-graph-semantics>` for a general introduction,
      detailed use, and constraints.

      Arguments:
          cuda_graph (torch.cuda.CUDAGraph): Graph object used for capture.
          pool (optional): Opaque token (returned by a call to :func:`~torch.cuda.graph_pool_handle()` or
              :meth:`other_Graph_instance.pool()<torch.cuda.CUDAGraph.pool>`) hinting this graph's capture
              may share memory from the specified pool. See :ref:`Graph memory management<graph-memory-management>`.
          stream (torch.cuda.Stream, optional): If supplied, will be set as the current stream in the context.
              If not supplied, ``graph`` sets its own internal side stream as the current stream in the context.
          capture_error_mode (str, optional): specifies the cudaStreamCaptureMode for the graph capture stream.
              Can be "global", "thread_local" or "relaxed". During cuda graph capture, some actions, such as cudaMalloc,
              may be unsafe. "global" will error on actions in other threads, "thread_local" will only error for
              actions in the current thread, and "relaxed" will not error on these actions. Do NOT change this setting
              unless you're familiar with `cudaStreamCaptureMode <https://docs.nvidia.com/cuda/cuda-runtime-api/group__CUDART__STREAM.html#group__CUDART__STREAM_1g9d0535d93a214cbf126835257b16ba85>`_

      The stream context is always exited again, even if ``capture_begin`` or
      ``capture_end`` raises.

      .. note::
          For effective memory sharing, if you pass a ``pool`` used by a previous capture and the previous capture
          used an explicit ``stream`` argument, you should pass the same ``stream`` argument to this capture.

      .. warning::
          This API is in beta and may change in future releases.

      .. _cudaStreamCaptureMode:
          https://docs.nvidia.com/cuda/cuda-runtime-api/group__CUDART__STREAM.html#group__CUDART__STREAM_1g9d0535d93a214cbf126835257b16ba85
      """  # noqa: B950

      # Shared side stream used when no explicit ``stream`` is given; created lazily.
      default_capture_stream: Optional[torch.cuda.Stream] = None

      def __init__(
          self,
          cuda_graph: CUDAGraph,
          pool: Optional[_POOL_HANDLE] = None,
          stream: Optional[torch.cuda.Stream] = None,
          capture_error_mode: str = "global",
      ):
          # Lazy-init of default_capture_stream helps avoid circular-import errors.
          # Not thread safe, but graphs already have the general (explicitly documented)
          # restriction that only one capture may be underway at a time in the process.
          if self.__class__.default_capture_stream is None:
              self.__class__.default_capture_stream = torch.cuda.Stream()

          # Stored as a 0- or 1-tuple so it can be splatted into capture_begin,
          # omitting the argument entirely when no pool was given.
          self.pool: Union[tuple[()], tuple[_POOL_HANDLE]] = (
              () if pool is None else (pool,)
          )
          self.capture_stream = (
              stream if stream is not None else self.__class__.default_capture_stream
          )
          assert self.capture_stream is not None
          self.stream_ctx = torch.cuda.stream(self.capture_stream)
          self.cuda_graph = cuda_graph
          self.capture_error_mode = capture_error_mode

      def __enter__(self) -> None:
          # Free as much memory as we can for the graph
          torch.cuda.synchronize()

          if torch.compiler.config.force_cudagraph_gc:
              # Originally we unconditionally garbage collected here. On one hand
              # that's nice because we have a chance to collect more memory, but
              # on the other hand it is REALLY expensive, especially for doing
              # multiple cudagraph captures in a row. In theory it will only help
              # when a dead python cycle is holding onto CUDA memory.
              gc.collect()

          torch.cuda.empty_cache()

          # Stackoverflow seems comfortable with this pattern
          # https://stackoverflow.com/questions/26635684/calling-enter-and-exit-manually#39172487
          self.stream_ctx.__enter__()
          try:
              self.cuda_graph.capture_begin(
                  # type: ignore[misc]
                  *self.pool,
                  capture_error_mode=self.capture_error_mode,
              )
          except BaseException:
              # ``__exit__`` is not invoked when ``__enter__`` raises, so leave the
              # stream context entered above before propagating the error.
              self.stream_ctx.__exit__(None, None, None)
              raise

      def __exit__(self, *args: object) -> None:
          try:
              self.cuda_graph.capture_end()
          finally:
              # Exit the stream context even when ``capture_end`` raises, so the
              # capture stream does not remain the current stream.
              self.stream_ctx.__exit__(*args)
          # returning None should propagate exceptions from either capture_end or stream_ctx.__exit__()
  # What make_graphed_callables accepts (and returns) per entry: an nn.Module or
  # any plain Python callable. ``torch.nn.Module`` is written as a string
  # forward reference here.
  _ModuleOrCallable: TypeAlias = Union["torch.nn.Module", Callable[..., object]]
  # Typing overload: a single callable with a flat tuple of sample Tensors
  # yields a single graphed callable.
  @overload
  def make_graphed_callables(
      callables: _ModuleOrCallable,
      sample_args: tuple[Tensor, ...],
      num_warmup_iters: int = 3,
      allow_unused_input: bool = False,
      pool: Optional[_POOL_HANDLE] = None,
  ) -> _ModuleOrCallable: ...
  # Typing overload: a tuple of callables with one sample-args tuple per
  # callable yields a tuple of graphed callables in the same order.
  @overload
  def make_graphed_callables(
      callables: tuple[_ModuleOrCallable, ...],
      sample_args: tuple[tuple[Tensor, ...], ...],
      num_warmup_iters: int = 3,
      allow_unused_input: bool = False,
      pool: Optional[_POOL_HANDLE] = None,
  ) -> tuple[_ModuleOrCallable, ...]: ...
  226. def make_graphed_callables(
  227. callables: Union[_ModuleOrCallable, tuple[_ModuleOrCallable, ...]],
  228. sample_args: Union[tuple[Tensor, ...], tuple[tuple[Tensor, ...], ...]],
  229. num_warmup_iters: int = 3,
  230. allow_unused_input: bool = False,
  231. pool: Optional[_POOL_HANDLE] = None,
  232. ) -> Union[_ModuleOrCallable, tuple[_ModuleOrCallable, ...]]:
  233. r"""Accept callables (functions or :class:`nn.Module<torch.nn.Module>`\ s) and returns graphed versions.
  234. Each graphed callable's forward pass runs its source callable's
  235. forward CUDA work as a CUDA graph inside a single autograd node.
  236. The graphed callable's forward pass also appends
  237. a backward node to the autograd graph. During backward, this node runs the
  238. callable's backward work as a CUDA graph.
  239. Therefore, each graphed callable should be a drop-in replacement for its source callable
  240. in an autograd-enabled training loop.
  241. See :ref:`Partial-network capture<partial-network-capture>` for detailed use and constraints.
  242. If you pass a tuple of several callables, their captures will use the same memory pool.
  243. See :ref:`Graph memory management<graph-memory-management>` for when this is appropriate.
  244. Arguments:
  245. callables (torch.nn.Module or Python function, or tuple of these): Callable or callables to graph.
  246. See :ref:`Graph memory management<graph-memory-management>` for when passing a tuple of callables
  247. is appropriate. If you pass a tuple of callables, their order in the tuple must be the same order
  248. they'll run in the live workload.
  249. sample_args (tuple of Tensors, or tuple of tuples of Tensors): Samples args for each callable.
  250. If a single callable was passed, ``sample_args`` must be a single tuple of argument Tensors.
  251. If a tuple of callables was passed, ``sample_args`` must be tuple of tuples of argument Tensors.
  252. num_warmup_iters (int): The number of warmup iterations. Currently, ``DataDistributedParallel`` needs
  253. 11 iterations for warm up. Default: ``3``.
  254. allow_unused_input (bool): If False, specifying inputs that were not used when computing outputs
  255. (and therefore their grad is always zero) is an error. Defaults to False.
  256. pool (optional): Token (returned by :func:`~torch.cuda.graph_pool_handle` or
  257. :meth:`other_Graph_instance.pool()<torch.cuda.CUDAGraph.pool>`) that hints this graph may share memory
  258. with the indicated pool. See :ref:`Graph memory management<graph-memory-management>`.
  259. .. note::
  260. The ``requires_grad`` state of each Tensor in ``sample_args`` must match the state
  261. that's expected for the corresponding real input in the training loop.
  262. .. warning::
  263. This API is in beta and may change in future releases.
  264. .. warning::
  265. ``sample_args`` for each callable must contain only Tensors. Other types are not allowed.
  266. .. warning::
  267. Returned callables do not support higher order differentiation (e.g., double backward).
  268. .. warning::
  269. In any :class:`~torch.nn.Module` passed to :func:`~make_graphed_callables`, only parameters
  270. may be trainable. Buffers must have ``requires_grad=False``.
  271. .. warning::
  272. After you pass a :class:`torch.nn.Module` through :func:`~make_graphed_callables`,
  273. you may not add or remove any of that Module's parameters or buffers.
  274. .. warning::
  275. :class:`torch.nn.Module`\s passed to :func:`~torch.cuda.make_graphed_callables` must not have module hooks
  276. registered on them at the time they are passed. However, registering hooks on modules *after* passing them
  277. through :func:`~torch.cuda.make_graphed_callables` is allowed.
  278. .. warning::
  279. When running a graphed callable, you must pass its arguments in the same order and format
  280. they appeared in that callable's ``sample_args``.
  281. .. warning::
  282. The automatic mixed precision is supported in :func:`~torch.cuda.make_graphed_callables` only with disabled
  283. caching. The context manager `torch.cuda.amp.autocast()` must have `cache_enabled=False`.
  284. """
  285. if torch.is_autocast_enabled() and torch.is_autocast_cache_enabled():
  286. raise RuntimeError(
  287. "make_graphed_callables does not support the autocast caching. Please set `cache_enabled=False`."
  288. )
  289. just_one_callable = False
  290. _sample_args: tuple[tuple[Tensor, ...], ...]
  291. if not isinstance(callables, tuple):
  292. just_one_callable = True
  293. callables = (callables,)
  294. _sample_args = (typing.cast(tuple[Tensor, ...], sample_args),)
  295. else:
  296. _sample_args = typing.cast(tuple[tuple[Tensor, ...], ...], sample_args)
  297. flatten_sample_args = []
  298. for c, args in zip(callables, _sample_args):
  299. if isinstance(c, torch.nn.Module):
  300. assert (
  301. len(c._backward_hooks) == 0
  302. and len(c._forward_hooks) == 0
  303. and len(c._forward_pre_hooks) == 0
  304. ), (
  305. "Modules must not have hooks registered at the time they are passed. However, registering hooks "
  306. + "on modules after passing them through make_graphed_callables is allowed."
  307. )
  308. assert all(b.requires_grad is False for b in c.buffers()), (
  309. "In any :class:`~torch.nn.Module` passed to "
  310. + ":func:`~make_graphed_callables`, only parameters may be trainable. All buffers must have "
  311. + "``requires_grad=False``."
  312. )
  313. flatten_arg = torch.utils._pytree.arg_tree_leaves(*args)
  314. flatten_sample_args.append(tuple(flatten_arg))
  315. assert all(isinstance(arg, torch.Tensor) for arg in flatten_arg), (
  316. "In the beta API, sample_args "
  317. + "for each callable must contain only Tensors. Other types are not allowed."
  318. )
  319. # If a callable is an nn.Module, its graph's full input surface is the args the user explicitly
  320. # passes to forward (ie, its sample_args) AND the module's parameter attributes.
  321. per_callable_len_user_args = [len(args) for args in flatten_sample_args]
  322. per_callable_module_params = [
  323. tuple(c.parameters()) if isinstance(c, torch.nn.Module) else ()
  324. for c in callables
  325. ]
  326. per_callable_static_input_surfaces = [
  327. flatten_sample_args[i] + per_callable_module_params[i]
  328. for i in range(len(callables))
  329. ]
  330. fwd_graphs = [torch.cuda.CUDAGraph() for _ in range(len(callables))]
  331. bwd_graphs = [torch.cuda.CUDAGraph() for _ in range(len(callables))]
  332. mempool = graph_pool_handle() if pool is None else pool
  333. # Warmup
  334. # Hopefully prevents cudnn benchmarking and other lazy-initialization cuda work
  335. # from ending up in any captures.
  336. torch.cuda.synchronize()
  337. with torch.cuda.stream(torch.cuda.Stream()):
  338. for func, args, static_input_surface in zip(
  339. callables, _sample_args, per_callable_static_input_surfaces
  340. ):
  341. grad_inputs, outputs, outputs_grad = None, None, None
  342. for _ in range(num_warmup_iters):
  343. outputs = torch.utils._pytree.tree_leaves(func(*args))
  344. outputs_grad = tuple(o for o in outputs if o.requires_grad)
  345. if len(outputs_grad) > 0:
  346. grad_inputs = torch.autograd.grad(
  347. outputs=outputs_grad,
  348. inputs=tuple(
  349. i for i in static_input_surface if i.requires_grad
  350. ),
  351. grad_outputs=tuple(
  352. torch.empty_like(o) for o in outputs if o.requires_grad
  353. ),
  354. only_inputs=True,
  355. allow_unused=allow_unused_input,
  356. )
  357. for v in [outputs, outputs_grad, grad_inputs]:
  358. del v
  359. torch.cuda.synchronize()
  360. # All captures here share a mempool. To avoid replays corrupting each other's memory,
  361. # the safest approach is to capture all passes in the same order they'll run:
  362. # fwd 1, fwd 2, ... fwd N, then bwd N, bwd N-1, ... bwd 1.
  363. # Capture forward graphs
  364. per_callable_static_outputs = []
  365. per_callable_output_unflatten_spec = []
  366. for func, args, fwd_graph in zip(callables, _sample_args, fwd_graphs):
  367. with torch.cuda.graph(fwd_graph, pool=mempool):
  368. func_outputs = func(*args)
  369. flatten_outputs, spec = torch.utils._pytree.tree_flatten(func_outputs)
  370. per_callable_static_outputs.append(tuple(flatten_outputs))
  371. per_callable_output_unflatten_spec.append(spec)
  372. # Capture backward graphs in reverse order
  373. per_callable_static_grad_outputs = []
  374. per_callable_static_grad_inputs = []
  375. for static_input_surface, static_outputs, bwd_graph in zip(
  376. reversed(per_callable_static_input_surfaces),
  377. reversed(per_callable_static_outputs),
  378. reversed(bwd_graphs),
  379. ):
  380. # For now, assumes all static_outputs require grad
  381. # assert all(o.requires_grad for o in static_outputs), "Outputs of graphed callables must require grad."
  382. static_grad_outputs = tuple(
  383. torch.empty_like(o) if o.requires_grad else None for o in static_outputs
  384. )
  385. outputs_grad = tuple(o for o in static_outputs if o.requires_grad)
  386. grad_inputs = None
  387. if len(outputs_grad) > 0:
  388. with torch.cuda.graph(bwd_graph, pool=mempool):
  389. grad_inputs = torch.autograd.grad(
  390. outputs=outputs_grad,
  391. inputs=tuple(i for i in static_input_surface if i.requires_grad),
  392. grad_outputs=tuple(o for o in static_grad_outputs if o is not None),
  393. only_inputs=True,
  394. allow_unused=allow_unused_input,
  395. )
  396. # Constructs a tuple suitable for returning from Graphed.backward:
  397. # Pads out the actually-needed grads with Nones in gradient slots for inputs that don't require grad.
  398. # I couldn't think of a slick one-liner for this pattern.
  399. static_grad_inputs = []
  400. grad_idx = 0
  401. for arg in static_input_surface:
  402. if arg.requires_grad and grad_inputs is not None:
  403. static_grad_inputs.append(grad_inputs[grad_idx])
  404. grad_idx += 1
  405. else:
  406. static_grad_inputs.append(None) # type: ignore[arg-type]
  407. static_grad_inputs = tuple(static_grad_inputs) # type: ignore[assignment]
  408. per_callable_static_grad_outputs.append(static_grad_outputs)
  409. per_callable_static_grad_inputs.append(static_grad_inputs)
  410. # Reverses the most recent two lists
  411. per_callable_static_grad_outputs.reverse()
  412. per_callable_static_grad_inputs.reverse()
  413. # Now for every per_callable list, per_callable_*[i] holds the stuff for the ith callable.
  414. def make_graphed_autograd_function(
  415. fwd_graph: CUDAGraph,
  416. bwd_graph: CUDAGraph,
  417. module_params: tuple[torch.nn.Parameter, ...],
  418. len_user_args: int,
  419. output_unflatten_spec: torch.utils._pytree.TreeSpec,
  420. static_input_surface: tuple[Tensor, ...],
  421. static_outputs: tuple[Tensor, ...],
  422. static_grad_outputs: tuple[Optional[Tensor], ...],
  423. static_grad_inputs: tuple[Tensor, ...],
  424. ) -> Callable[..., object]:
  425. class Graphed(torch.autograd.Function):
  426. @staticmethod
  427. def forward(ctx: object, *inputs: Tensor) -> tuple[Tensor, ...]:
  428. # At this stage, only the user args may (potentially) be new tensors.
  429. for i in range(len_user_args):
  430. if static_input_surface[i].data_ptr() != inputs[i].data_ptr():
  431. static_input_surface[i].copy_(inputs[i])
  432. fwd_graph.replay()
  433. assert isinstance(static_outputs, tuple)
  434. return tuple(o.detach() for o in static_outputs)
  435. @staticmethod
  436. @torch.autograd.function.once_differentiable
  437. def backward(ctx: object, *grads: Tensor) -> tuple[Tensor, ...]:
  438. assert len(grads) == len(static_grad_outputs)
  439. for g, grad in zip(static_grad_outputs, grads):
  440. if g is not None:
  441. # don't copy if autograd gods have been kind and the
  442. # incoming grad is already in the right place
  443. if g.data_ptr() != grad.data_ptr():
  444. g.copy_(grad)
  445. bwd_graph.replay()
  446. # Input args that didn't require grad expect a None gradient.
  447. assert isinstance(static_grad_inputs, tuple)
  448. return tuple(
  449. b.detach() if b is not None else b for b in static_grad_inputs
  450. )
  451. def functionalized(*user_args: object) -> object:
  452. # Runs the autograd function with inputs == all inputs to the graph that might require grad
  453. # (explicit user args + module parameters)
  454. # Assumes module params didn't change since capture.
  455. flatten_user_args = torch.utils._pytree.arg_tree_leaves(*user_args)
  456. out = Graphed.apply(*(tuple(flatten_user_args) + module_params))
  457. return torch.utils._pytree.tree_unflatten(out, output_unflatten_spec)
  458. return functionalized
  459. # Put together the final graphed callables
  460. ret: list[_ModuleOrCallable] = []
  461. for i, func in enumerate(callables):
  462. graphed = make_graphed_autograd_function(
  463. fwd_graphs[i],
  464. bwd_graphs[i],
  465. per_callable_module_params[i],
  466. per_callable_len_user_args[i],
  467. per_callable_output_unflatten_spec[i],
  468. per_callable_static_input_surfaces[i],
  469. per_callable_static_outputs[i],
  470. per_callable_static_grad_outputs[i],
  471. per_callable_static_grad_inputs[i],
  472. )
  473. if isinstance(func, torch.nn.Module):
  474. def make_graphed_forward(
  475. func: torch.nn.Module,
  476. graph_training_state: bool,
  477. graphed: Callable[_P, _R],
  478. orig_fwd: Callable[_P, _R],
  479. ) -> Callable[_P, _R]:
  480. def new_fwd(*user_args: _P.args, **user_kwargs: _P.kwargs) -> _R:
  481. # If the module's training-or-eval state matches what we graphed,
  482. # run the graph, otherwise run the original forward method
  483. if func.training == graph_training_state:
  484. return graphed(*user_args, **user_kwargs)
  485. else:
  486. return orig_fwd(*user_args, **user_kwargs)
  487. return new_fwd
  488. func.forward = make_graphed_forward(
  489. func, func.training, graphed, func.forward
  490. )
  491. ret.append(func)
  492. else:
  493. ret.append(graphed)
  494. if just_one_callable:
  495. return ret[0]
  496. return tuple(ret)