# function.py
# mypy: allow-untyped-defs
import functools
import inspect
import itertools
import warnings
from collections import OrderedDict
from typing import Any, Callable, Optional, TypeVar
from typing_extensions import Concatenate, deprecated, ParamSpec

import torch
import torch._C as _C
import torch._functorch as _functorch
import torch.utils.hooks as hooks
from torch._C import _functions
from torch._functorch.autograd_function import custom_function_call
  15. __all__ = [
  16. "FunctionCtx",
  17. "BackwardCFunction",
  18. "FunctionMeta",
  19. "Function",
  20. "once_differentiable",
  21. "InplaceFunction",
  22. "NestedIOFunction",
  23. ]
  24. # Unique id provider for each class inheriting from Function
  25. # This is incremented in FunctionMeta during class definition
  26. AUTOGRAD_FUNCTION_COUNTER = itertools.count()
  27. _T = TypeVar("_T")
  28. _R = TypeVar("_R")
  29. _P = ParamSpec("_P")
  30. # Formerly known as: _ContextMethodMixin
  31. class FunctionCtx:
  32. def save_for_backward(self, *tensors: torch.Tensor):
  33. r"""Save given tensors for a future call to :func:`~Function.backward`.
  34. ``save_for_backward`` should be called at most once, in either the
  35. :func:`setup_context` or :func:`forward` methods, and only with tensors.
  36. All tensors intended to be used in the backward pass should be saved
  37. with ``save_for_backward`` (as opposed to directly on ``ctx``) to prevent
  38. incorrect gradients and memory leaks, and enable the application of saved
  39. tensor hooks. See :class:`torch.autograd.graph.saved_tensors_hooks`.
  40. See :ref:`extending-autograd` for more details.
  41. Note that if intermediary tensors, tensors that are neither inputs
  42. nor outputs of :func:`forward`, are saved for backward, your custom Function
  43. may not support double backward.
  44. Custom Functions that do not support double backward should decorate their
  45. :func:`backward` method with ``@once_differentiable`` so that performing
  46. double backward raises an error. If you'd like to support double backward,
  47. you can either recompute intermediaries based on the inputs during backward
  48. or return the intermediaries as the outputs of the custom Function. See the
  49. `double backward tutorial <https://pytorch.org/tutorials/intermediate/custom_function_double_backward_tutorial.html>`_
  50. for more details.
  51. In :func:`backward`, saved tensors can be accessed through the :attr:`saved_tensors`
  52. attribute. Before returning them to the user, a check is made to ensure
  53. they weren't used in any in-place operation that modified their content.
  54. Arguments can also be ``None``. This is a no-op.
  55. See :ref:`extending-autograd` for more details on how to use this method.
  56. Example::
  57. >>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_AUTOGRAD)
  58. >>> class Func(Function):
  59. >>> @staticmethod
  60. >>> def forward(ctx, x: torch.Tensor, y: torch.Tensor, z: int):
  61. >>> w = x * z
  62. >>> out = x * y + y * z + w * y
  63. >>> ctx.save_for_backward(x, y, w, out)
  64. >>> ctx.z = z # z is not a tensor
  65. >>> return out
  66. >>>
  67. >>> @staticmethod
  68. >>> @once_differentiable
  69. >>> def backward(ctx, grad_out):
  70. >>> x, y, w, out = ctx.saved_tensors
  71. >>> z = ctx.z
  72. >>> gx = grad_out * (y + y * z)
  73. >>> gy = grad_out * (x + z + w)
  74. >>> gz = None
  75. >>> return gx, gy, gz
  76. >>>
  77. >>> a = torch.tensor(1., requires_grad=True, dtype=torch.double)
  78. >>> b = torch.tensor(2., requires_grad=True, dtype=torch.double)
  79. >>> c = 4
  80. >>> d = Func.apply(a, b, c)
  81. """
  82. self.to_save = tensors
  83. def save_for_forward(self, *tensors: torch.Tensor):
  84. r"""Save given tensors for a future call to :func:`~Function.jvp`.
  85. ``save_for_forward`` should be called at most once, in either the
  86. :func:`setup_context` or :func:`forward` methods, and all arguments
  87. should be tensors.
  88. In :func:`jvp`, saved objects can be accessed through the :attr:`saved_tensors`
  89. attribute.
  90. Arguments can also be ``None``. This is a no-op.
  91. See :ref:`extending-autograd` for more details on how to use this method.
  92. Example::
  93. >>> # xdoctest: +SKIP
  94. >>> class Func(torch.autograd.Function):
  95. >>> @staticmethod
  96. >>> def forward(ctx, x: torch.Tensor, y: torch.Tensor, z: int):
  97. >>> ctx.save_for_backward(x, y)
  98. >>> ctx.save_for_forward(x, y)
  99. >>> ctx.z = z
  100. >>> return x * y * z
  101. >>>
  102. >>> @staticmethod
  103. >>> def jvp(ctx, x_t, y_t, _):
  104. >>> x, y = ctx.saved_tensors
  105. >>> z = ctx.z
  106. >>> return z * (y * x_t + x * y_t)
  107. >>>
  108. >>> @staticmethod
  109. >>> def vjp(ctx, grad_out):
  110. >>> x, y = ctx.saved_tensors
  111. >>> z = ctx.z
  112. >>> return z * grad_out * y, z * grad_out * x, None
  113. >>>
  114. >>> a = torch.tensor(1., requires_grad=True, dtype=torch.double)
  115. >>> t = torch.tensor(1., dtype=torch.double)
  116. >>> b = torch.tensor(2., requires_grad=True, dtype=torch.double)
  117. >>> c = 4
  118. >>>
  119. >>> with fwAD.dual_level():
  120. >>> a_dual = fwAD.make_dual(a, t)
  121. >>> d = Func.apply(a_dual, b, c)
  122. """
  123. for tensor in tensors:
  124. assert isinstance(tensor, torch.Tensor) or tensor is None, (
  125. "save_for_forward expects all arguments to be tensors; you should "
  126. "save non-tensors as attributes on ctx."
  127. )
  128. self.saved_for_forward = tensors
  129. def mark_dirty(self, *args: torch.Tensor):
  130. r"""Mark given tensors as modified in an in-place operation.
  131. This should be called at most once, in either the :func:`setup_context`
  132. or :func:`forward` methods, and all arguments should be inputs.
  133. Every tensor that's been modified in-place in a call to :func:`forward`
  134. should be given to this function, to ensure correctness of our checks.
  135. It doesn't matter whether the function is called before or after
  136. modification.
  137. Examples::
  138. >>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_AUTOGRAD)
  139. >>> class Inplace(Function):
  140. >>> @staticmethod
  141. >>> def forward(ctx, x):
  142. >>> x_npy = x.numpy() # x_npy shares storage with x
  143. >>> x_npy += 1
  144. >>> ctx.mark_dirty(x)
  145. >>> return x
  146. >>>
  147. >>> @staticmethod
  148. >>> @once_differentiable
  149. >>> def backward(ctx, grad_output):
  150. >>> return grad_output
  151. >>>
  152. >>> a = torch.tensor(1., requires_grad=True, dtype=torch.double).clone()
  153. >>> b = a * a
  154. >>> Inplace.apply(a) # This would lead to wrong gradients!
  155. >>> # but the engine would not know unless we mark_dirty
  156. >>> # xdoctest: +SKIP
  157. >>> b.backward() # RuntimeError: one of the variables needed for gradient
  158. >>> # computation has been modified by an inplace operation
  159. """
  160. self.dirty_tensors = args
  161. @deprecated(
  162. "`mark_shared_storage` is deprecated. "
  163. "Tensors with shared storages are automatically tracked. "
  164. "Note that calls to `set_()` are not tracked",
  165. category=FutureWarning,
  166. )
  167. def mark_shared_storage(self, *pairs):
  168. pass
  169. def mark_non_differentiable(self, *args: torch.Tensor):
  170. r"""Mark outputs as non-differentiable.
  171. This should be called at most once, in either the :func:`setup_context`
  172. or :func:`forward` methods, and all arguments should be tensor outputs.
  173. This will mark outputs as not requiring gradients, increasing the
  174. efficiency of backward computation. You still need to accept a gradient
  175. for each output in :meth:`~Function.backward`, but it's always going to
  176. be a zero tensor with the same shape as the shape of a corresponding
  177. output.
  178. This is used e.g. for indices returned from a sort. See example::
  179. >>> class Func(Function):
  180. >>> @staticmethod
  181. >>> def forward(ctx, x):
  182. >>> sorted, idx = x.sort()
  183. >>> ctx.mark_non_differentiable(idx)
  184. >>> ctx.save_for_backward(x, idx)
  185. >>> return sorted, idx
  186. >>>
  187. >>> @staticmethod
  188. >>> @once_differentiable
  189. >>> def backward(ctx, g1, g2): # still need to accept g2
  190. >>> x, idx = ctx.saved_tensors
  191. >>> grad_input = torch.zeros_like(x)
  192. >>> grad_input.index_add_(0, idx, g1)
  193. >>> return grad_input
  194. """
  195. self.non_differentiable = args
  196. def set_materialize_grads(self, value: bool):
  197. r"""Set whether to materialize grad tensors. Default is ``True``.
  198. This should be called only from either the :func:`setup_context` or
  199. :func:`forward` methods.
  200. If ``True``, undefined grad tensors will be expanded to tensors full of zeros
  201. prior to calling the :func:`backward` and :func:`jvp` methods.
  202. Example::
  203. >>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_AUTOGRAD)
  204. >>> class SimpleFunc(Function):
  205. >>> @staticmethod
  206. >>> def forward(ctx, x):
  207. >>> return x.clone(), x.clone()
  208. >>>
  209. >>> @staticmethod
  210. >>> @once_differentiable
  211. >>> def backward(ctx, g1, g2):
  212. >>> return g1 + g2 # No check for None necessary
  213. >>>
  214. >>> # We modify SimpleFunc to handle non-materialized grad outputs
  215. >>> class Func(Function):
  216. >>> @staticmethod
  217. >>> def forward(ctx, x):
  218. >>> ctx.set_materialize_grads(False)
  219. >>> ctx.save_for_backward(x)
  220. >>> return x.clone(), x.clone()
  221. >>>
  222. >>> @staticmethod
  223. >>> @once_differentiable
  224. >>> def backward(ctx, g1, g2):
  225. >>> x, = ctx.saved_tensors
  226. >>> grad_input = torch.zeros_like(x)
  227. >>> if g1 is not None: # We must check for None now
  228. >>> grad_input += g1
  229. >>> if g2 is not None:
  230. >>> grad_input += g2
  231. >>> return grad_input
  232. >>>
  233. >>> a = torch.tensor(1., requires_grad=True)
  234. >>> b, _ = Func.apply(a) # induces g2 to be undefined
  235. """
  236. self.materialize_grads = value
  237. # DO NOT USE: This is only defined to be able to load old serialized models
  238. _ContextMethodMixin = FunctionCtx
  239. class _HookMixin:
  240. @staticmethod
  241. def _register_hook(backward_hooks, hook):
  242. if backward_hooks is None:
  243. backward_hooks = OrderedDict()
  244. handle = hooks.RemovableHandle(backward_hooks)
  245. backward_hooks[handle.id] = hook
  246. return backward_hooks, handle
  247. class BackwardCFunction(_C._FunctionBase, FunctionCtx, _HookMixin):
  248. r"""
  249. This class is used for internal autograd work. Do not use.
  250. """
  251. def apply(self, *args):
  252. r"""
  253. Apply method used when executing this Node during the backward
  254. """
  255. # _forward_cls is defined by derived class
  256. # The user should define either backward or vjp but never both.
  257. backward_fn = self._forward_cls.backward # type: ignore[attr-defined]
  258. vjp_fn = self._forward_cls.vjp # type: ignore[attr-defined]
  259. if backward_fn is not Function.backward and vjp_fn is not Function.vjp:
  260. raise RuntimeError(
  261. "Implementing both 'backward' and 'vjp' for a custom "
  262. "Function is not allowed. You should only implement one "
  263. "of them."
  264. )
  265. user_fn = vjp_fn if vjp_fn is not Function.vjp else backward_fn
  266. return user_fn(self, *args)
  267. def apply_jvp(self, *args):
  268. r"""
  269. Apply method used when executing forward mode AD during the forward
  270. """
  271. # _forward_cls is defined by derived class
  272. return self._forward_cls.jvp(self, *args) # type: ignore[attr-defined]
  273. def _compiled_autograd_key(self):
  274. return self._forward_cls._compiled_autograd_key(self) # type: ignore[attr-defined]
  275. class FunctionMeta(type):
  276. """Function metaclass.
  277. This metaclass sets up the following properties:
  278. _backward_cls: The Function class corresponding to the differentiated
  279. version of this function (which is generated on the fly by this
  280. metaclass).
  281. """
  282. def __init__(cls, name, bases, attrs):
  283. backward_fn = type(
  284. name + "Backward", (BackwardCFunction,), {"_forward_cls": cls}
  285. )
  286. backward_fn._autograd_function_id = next(AUTOGRAD_FUNCTION_COUNTER) # type: ignore[attr-defined]
  287. cls._backward_cls = backward_fn
  288. super().__init__(name, bases, attrs)
  289. class _SingleLevelFunction(
  290. _C._FunctionBase, FunctionCtx, _HookMixin, metaclass=FunctionMeta
  291. ):
  292. @staticmethod
  293. def forward(*args: Any, **kwargs: Any) -> Any:
  294. r"""Define the forward of the custom autograd Function.
  295. This function is to be overridden by all subclasses.
  296. There are two ways to define forward:
  297. Usage 1 (Combined forward and ctx)::
  298. @staticmethod
  299. def forward(ctx: Any, *args: Any, **kwargs: Any) -> Any:
  300. pass
  301. - It must accept a context ctx as the first argument, followed by any
  302. number of arguments (tensors or other types).
  303. - See :ref:`combining-forward-context` for more details
  304. Usage 2 (Separate forward and ctx)::
  305. @staticmethod
  306. def forward(*args: Any, **kwargs: Any) -> Any:
  307. pass
  308. @staticmethod
  309. def setup_context(ctx: Any, inputs: Tuple[Any, ...], output: Any) -> None:
  310. pass
  311. - The forward no longer accepts a ctx argument.
  312. - Instead, you must also override the :meth:`torch.autograd.Function.setup_context`
  313. staticmethod to handle setting up the ``ctx`` object.
  314. ``output`` is the output of the forward, ``inputs`` are a Tuple of inputs
  315. to the forward.
  316. - See :ref:`extending-autograd` for more details
  317. The context can be used to store arbitrary data that can be then
  318. retrieved during the backward pass. Tensors should not be stored
  319. directly on `ctx` (though this is not currently enforced for
  320. backward compatibility). Instead, tensors should be saved either with
  321. :func:`ctx.save_for_backward` if they are intended to be used in
  322. ``backward`` (equivalently, ``vjp``) or :func:`ctx.save_for_forward`
  323. if they are intended to be used for in ``jvp``.
  324. """
  325. raise NotImplementedError(
  326. "You must implement the forward function for custom autograd.Function."
  327. )
  328. @staticmethod
  329. def setup_context(ctx: Any, inputs: tuple[Any, ...], output: Any) -> Any:
  330. r"""There are two ways to define the forward pass of an autograd.Function.
  331. Either:
  332. 1. Override forward with the signature ``forward(ctx, *args, **kwargs)``.
  333. ``setup_context`` is not overridden. Setting up the ctx for backward
  334. happens inside the ``forward``.
  335. 2. Override forward with the signature ``forward(*args, **kwargs)`` and
  336. override ``setup_context``. Setting up the ctx for backward happens
  337. inside ``setup_context`` (as opposed to inside the ``forward``)
  338. See :meth:`torch.autograd.Function.forward` and :ref:`extending-autograd` for more details.
  339. """
  340. raise NotImplementedError("setup_context is not implemented.")
  341. @staticmethod
  342. def backward(ctx: Any, *grad_outputs: Any) -> Any:
  343. r"""Define a formula for differentiating the operation with backward mode automatic differentiation.
  344. This function is to be overridden by all subclasses.
  345. (Defining this function is equivalent to defining the ``vjp`` function.)
  346. It must accept a context :attr:`ctx` as the first argument, followed by
  347. as many outputs as the :func:`forward` returned (None will be passed in
  348. for non tensor outputs of the forward function),
  349. and it should return as many tensors, as there were inputs to
  350. :func:`forward`. Each argument is the gradient w.r.t the given output,
  351. and each returned value should be the gradient w.r.t. the
  352. corresponding input. If an input is not a Tensor or is a Tensor not
  353. requiring grads, you can just pass None as a gradient for that input.
  354. The context can be used to retrieve tensors saved during the forward
  355. pass. It also has an attribute :attr:`ctx.needs_input_grad` as a tuple
  356. of booleans representing whether each input needs gradient. E.g.,
  357. :func:`backward` will have ``ctx.needs_input_grad[0] = True`` if the
  358. first input to :func:`forward` needs gradient computed w.r.t. the
  359. output.
  360. """
  361. raise NotImplementedError(
  362. "You must implement either the backward or vjp method for "
  363. "your custom autograd.Function to use it with backward "
  364. "mode AD."
  365. )
  366. # vjp and backward are alias of each other
  367. vjp = backward
  368. @staticmethod
  369. def jvp(ctx: Any, *grad_inputs: Any) -> Any:
  370. r"""Define a formula for differentiating the operation with forward mode automatic differentiation.
  371. This function is to be overridden by all subclasses.
  372. It must accept a context :attr:`ctx` as the first argument, followed by
  373. as many inputs as the :func:`forward` got (None will be passed in
  374. for non tensor inputs of the forward function),
  375. and it should return as many tensors as there were outputs to
  376. :func:`forward`. Each argument is the gradient w.r.t the given input,
  377. and each returned value should be the gradient w.r.t. the
  378. corresponding output. If an output is not a Tensor or the function is not
  379. differentiable with respect to that output, you can just pass None as a
  380. gradient for that input.
  381. You can use the :attr:`ctx` object to pass any value from the forward to this
  382. functions.
  383. """
  384. raise NotImplementedError(
  385. "You must implement the jvp function for custom "
  386. "autograd.Function to use it with forward mode AD."
  387. )
  388. class Function(_SingleLevelFunction):
  389. r"""Base class to create custom `autograd.Function`.
  390. To create a custom `autograd.Function`, subclass this class and implement
  391. the :meth:`forward` and :meth:`backward` static methods. Then, to use your custom
  392. op in the forward pass, call the class method ``apply``. Do not call
  393. :meth:`forward` directly.
  394. To ensure correctness and best performance, make sure you are calling the
  395. correct methods on ``ctx`` and validating your backward function using
  396. :func:`torch.autograd.gradcheck`.
  397. See :ref:`extending-autograd` for more details on how to use this class.
  398. Examples::
  399. >>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_AUTOGRAD)
  400. >>> class Exp(Function):
  401. >>> @staticmethod
  402. >>> def forward(ctx, i):
  403. >>> result = i.exp()
  404. >>> ctx.save_for_backward(result)
  405. >>> return result
  406. >>>
  407. >>> @staticmethod
  408. >>> def backward(ctx, grad_output):
  409. >>> result, = ctx.saved_tensors
  410. >>> return grad_output * result
  411. >>>
  412. >>> # Use it by calling the apply method:
  413. >>> # xdoctest: +SKIP
  414. >>> output = Exp.apply(input)
  415. """
  416. def __init__(self, *args, **kwargs):
  417. warnings.warn(
  418. f"{self.__class__} should not be instantiated. Methods on autograd functions "
  419. "are all static, so you should invoke them on the class itself. "
  420. "Instantiating an autograd function will raise an "
  421. "error in a future version of PyTorch.",
  422. DeprecationWarning,
  423. stacklevel=2,
  424. )
  425. def __call__(self, *args, **kwargs):
  426. raise RuntimeError(
  427. "Legacy autograd function with non-static forward method is deprecated. "
  428. "Please use new-style autograd function with static forward method. "
  429. "(Example: https://pytorch.org/docs/stable/autograd.html#torch.autograd.Function)"
  430. )
  431. """
  432. Bool that specifies if PyTorch should attempt to autogenerate
  433. :func:`torch.vmap` support for this autograd.Function. You may set this to
  434. True only if this autograd.Function's forward, backward, and jvp (if they
  435. exist) are written using PyTorch operations; otherwise, please override
  436. :meth:`torch.autograd.Function.vmap` to add support for :func:`torch.vmap`.
  437. Please see :ref:`func-autograd-function` for more details.
  438. """
  439. generate_vmap_rule = False
  440. @staticmethod
  441. def vmap(info, in_dims, *args):
  442. r"""Define the behavior for this autograd.Function underneath :func:`torch.vmap`.
  443. For a :func:`torch.autograd.Function` to support
  444. :func:`torch.vmap`, you must either override this static method, or set
  445. ``generate_vmap_rule`` to ``True`` (you may not do both).
  446. If you choose to override this staticmethod: it must accept
  447. - an ``info`` object as the first argument. ``info.batch_size``
  448. specifies the size of the dimension being vmapped over,
  449. while ``info.randomness`` is the randomness option passed to
  450. :func:`torch.vmap`.
  451. - an ``in_dims`` tuple as the second argument.
  452. For each arg in ``args``, ``in_dims`` has a corresponding
  453. ``Optional[int]``. It is ``None`` if the arg is not a Tensor or if
  454. the arg is not being vmapped over, otherwise, it is an integer
  455. specifying what dimension of the Tensor is being vmapped over.
  456. - ``*args``, which is the same as the args to :meth:`~Function.forward`.
  457. The return of the vmap staticmethod is a tuple of ``(output, out_dims)``.
  458. Similar to ``in_dims``, ``out_dims`` should be of the same structure as
  459. ``output`` and contain one ``out_dim`` per output that specifies if the
  460. output has the vmapped dimension and what index it is in.
  461. Please see :ref:`func-autograd-function` for more details.
  462. """
  463. raise NotImplementedError(
  464. "To use autograd.Function with vmap, you must either override the "
  465. "vmap staticmethod or set generate_vmap_rule=True."
  466. )
  467. @classmethod
  468. def apply(cls, *args, **kwargs):
  469. def bind_default_args(func, *args, **kwargs):
  470. signature = inspect.signature(func)
  471. bound_args = signature.bind(*args, **kwargs)
  472. bound_args.apply_defaults()
  473. return bound_args.args
  474. is_setup_ctx_defined = _is_setup_context_defined(cls.setup_context)
  475. if is_setup_ctx_defined:
  476. args = bind_default_args(cls.forward, *args, **kwargs)
  477. if not torch._C._are_functorch_transforms_active():
  478. # See NOTE: [functorch vjp and autograd interaction]
  479. args = _functorch.utils.unwrap_dead_wrappers(args)
  480. return super().apply(*args, **kwargs) # type: ignore[misc]
  481. if not is_setup_ctx_defined:
  482. raise RuntimeError(
  483. "In order to use an autograd.Function with functorch transforms "
  484. "(vmap, grad, jvp, jacrev, ...), it must override the setup_context "
  485. "staticmethod. For more details, please see "
  486. "https://pytorch.org/docs/main/notes/extending.func.html"
  487. )
  488. return custom_function_call(cls, *args, **kwargs)
  489. @staticmethod
  490. def _compiled_autograd_key(ctx):
  491. return (ctx._autograd_function_id,)
  492. def _is_setup_context_defined(fn):
  493. return fn != _SingleLevelFunction.setup_context
  494. def once_differentiable(
  495. fn: Callable[Concatenate[_T, _P], _R],
  496. ) -> Callable[Concatenate[_T, _P], _R]:
  497. @functools.wraps(fn)
  498. def wrapper(ctx: _T, *args: _P.args, **kwargs: _P.kwargs) -> _R:
  499. with torch.no_grad():
  500. outputs = fn(ctx, *args, **kwargs)
  501. if not torch.is_grad_enabled():
  502. return outputs
  503. # If any of the inputs have requires_grad=True, we force the outputs
  504. # to have requires_grad=True but point to a grad_fn which throws an
  505. # error message during (double) back-propagation.
  506. # XXX: this is only an approximation of requires_grad - there's no way
  507. # to figure out if fn didn't use ctx.saved_tensors and as a result
  508. # some Tensors might require grad, even if no args do.
  509. # Unfortunately, this leads to unexpected error messages ("no nodes
  510. # require computing gradients"), but I don't have a better idea.
  511. # These functions would raise an error in backward anyway.
  512. requires_grad = any(
  513. isinstance(arg, torch.Tensor) and arg.requires_grad for arg in args
  514. )
  515. if not requires_grad:
  516. return outputs
  517. if not isinstance(outputs, tuple):
  518. outputs_ = (outputs,)
  519. else:
  520. outputs_ = outputs
  521. err_fn = _functions.DelayedError(
  522. b"trying to differentiate twice a function that was marked "
  523. b"with @once_differentiable",
  524. len(outputs_),
  525. )
  526. # Create aliases of each output that has requires_grad=True. We need
  527. # at least one of the inputs to err_fn to require grad so that the
  528. # output will have a grad_fn.
  529. def fake_requires_grad(var):
  530. if var is not None:
  531. var = var.detach()
  532. var.requires_grad = True
  533. return var
  534. return err_fn(*[fake_requires_grad(v) for v in outputs_]) # type: ignore[return-value]
  535. return wrapper
  536. class InplaceFunction(Function):
  537. r"""
  538. This class is here only for backward compatibility reasons.
  539. Use :class:`Function` instead of this for any new use case.
  540. """
  541. def __init__(self, inplace=False):
  542. super().__init__()
  543. self.inplace = inplace
  544. def _nested_map(condition, fn, condition_msg=None):
  545. def _map(obj):
  546. if condition(obj):
  547. return fn(obj)
  548. elif obj is None:
  549. return None
  550. elif isinstance(obj, (list, tuple)):
  551. mapped = (_map(x) for x in obj)
  552. if hasattr(obj, "_fields"):
  553. # obj is namedtuple
  554. return type(obj)(*mapped)
  555. return type(obj)(mapped)
  556. elif isinstance(obj, dict):
  557. return {x: _map(obj[x]) for x in obj}
  558. else:
  559. raise ValueError(
  560. "Auto nesting doesn't know how to process "
  561. "an input object of type "
  562. + torch.typename(obj)
  563. + (
  564. ". Accepted types: " + condition_msg + ", or lists/tuples of them"
  565. if condition_msg
  566. else ""
  567. )
  568. )
  569. return _map
  570. def _jit_unwrap_structured(obj):
  571. if hasattr(obj, "_jit_unwrap"):
  572. return obj._jit_unwrap()
  573. return obj
  574. def _iter_filter(condition, allow_unknown=False, condition_msg=None, conversion=None):
  575. def _iter(obj):
  576. if conversion is not None:
  577. obj = conversion(obj)
  578. if condition(obj):
  579. yield obj
  580. elif obj is None:
  581. return
  582. elif isinstance(obj, (list, tuple)):
  583. for o in obj:
  584. yield from _iter(o)
  585. elif isinstance(obj, dict):
  586. # We only accept primitive key types, so we needn't inspect them
  587. for o in obj.values():
  588. yield from _iter(o)
  589. elif allow_unknown:
  590. yield obj
  591. else:
  592. raise ValueError(
  593. "Auto nesting doesn't know how to process "
  594. "an input object of type "
  595. + torch.typename(obj)
  596. + (
  597. ". Accepted types: " + condition_msg + ", or lists/tuples of them"
  598. if condition_msg
  599. else ""
  600. )
  601. )
  602. return _iter
  603. def _unflatten(input, proto):
  604. # unflatten a list or tuple input into a nested list/tuple structure
  605. # specified by proto
  606. def unflatten_helper(input, proto):
  607. res: list[Optional[torch.Tensor]] = []
  608. if hasattr(proto, "_jit_wrap"):
  609. return proto._jit_wrap(input)
  610. if not isinstance(proto, (list, tuple)):
  611. return input[0], input[1:]
  612. for e in proto:
  613. if e is None:
  614. res.append(e)
  615. else:
  616. res_e, input = unflatten_helper(input, e)
  617. res.append(res_e)
  618. return type(proto)(res), input
  619. return unflatten_helper(input, proto)[0]
  620. _iter_jit_values = _iter_filter(
  621. lambda o: o is None or isinstance(o, torch._C.Value),
  622. condition_msg="jit's Values or None",
  623. )
  624. _iter_tensors = _iter_filter(
  625. lambda x: isinstance(x, torch.Tensor),
  626. condition_msg="Tensors",
  627. conversion=_jit_unwrap_structured,
  628. )
  629. _iter_tensors_permissive = _iter_filter(
  630. lambda x: isinstance(x, torch.Tensor),
  631. allow_unknown=True,
  632. condition_msg="Tensors (permissive)",
  633. )
  634. _iter_None_tensors = _iter_filter(
  635. lambda o: o is None or isinstance(o, torch.Tensor), condition_msg="Tensors or None"
  636. )
  637. _map_tensor_data = _nested_map(
  638. lambda x: isinstance(x, torch.Tensor), lambda o: o.data, condition_msg="Tensors"
  639. )
  640. class NestedIOFunction(Function):
  641. r"""
  642. This class is here only for backward compatibility reasons.
  643. Use :class:`Function` instead of this for any new use case.
  644. """
  645. # The 'type: ignore' statements are needed here because these functions are declared as '@staticmethod' in the
  646. # superclass (Function) but are instance methods here, which mypy reports as incompatible.
  647. def _do_forward(self, *input):
  648. self._nested_input = input
  649. flat_input = tuple(_iter_tensors(input))
  650. flat_output = super()._do_forward(*flat_input) # type: ignore[misc]
  651. nested_tensors = _unflatten(flat_output, self._nested_output)
  652. return nested_tensors
  653. def _do_backward(self, gradients, retain_variables):
  654. self.retain_variables = retain_variables
  655. result = super()._do_backward(gradients, retain_variables) # type: ignore[misc]
  656. if not retain_variables:
  657. del self._nested_output
  658. del self._to_save_nested
  659. return result
  660. def backward(self, *gradients: Any) -> Any: # type: ignore[override]
  661. r"""
  662. Shared backward utility.
  663. """
  664. nested_gradients = _unflatten(gradients, self._nested_output)
  665. result = self.backward_extended(*nested_gradients) # type: ignore[func-returns-value]
  666. return tuple(_iter_None_tensors(result))
  667. __call__ = _do_forward
  668. def forward(self, *args: Any) -> Any: # type: ignore[override]
  669. r"""
  670. Shared forward utility.
  671. """
  672. nested_tensors = _map_tensor_data(self._nested_input)
  673. result = self.forward_extended(*nested_tensors) # type: ignore[func-returns-value]
  674. del self._nested_input
  675. self._nested_output = result
  676. return tuple(_iter_tensors(result))
  677. def save_for_backward(self, *args: Any) -> None:
  678. r"""
  679. See :meth:`Function.save_for_backward`.
  680. """
  681. self.to_save = tuple(_iter_tensors(args))
  682. self._to_save_nested = args
  683. @property
  684. def saved_tensors(self): # type: ignore[override]
  685. r"""
  686. See :meth:`Function.saved_tensors`.
  687. """
  688. flat_tensors = super().saved_tensors # type: ignore[misc]
  689. return _unflatten(flat_tensors, self._to_save_nested)
  690. def mark_dirty(self, *args: Any, **kwargs: Any) -> None:
  691. r"""
  692. See :meth:`Function.mark_dirty`.
  693. """
  694. self.dirty_tensors = tuple(_iter_tensors((args, kwargs)))
  695. def mark_non_differentiable(self, *args: Any, **kwargs: Any) -> None:
  696. r"""
  697. See :meth:`Function.mark_non_differentiable`.
  698. """
  699. self.non_differentiable = tuple(_iter_tensors((args, kwargs)))
  700. def forward_extended(self, *input: Any) -> None:
  701. r"""
  702. User defined forward.
  703. """
  704. raise NotImplementedError
  705. def backward_extended(self, *grad_output: Any) -> None:
  706. r"""
  707. User defined backward.
  708. """
  709. raise NotImplementedError