autograd.py 8.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239
  1. # mypy: allow-untyped-defs
  2. import dataclasses
  3. from dataclasses import dataclass
  4. from typing import Any, Callable, Optional, Protocol
  5. from torch import _C, _ops, autograd, Tensor
  6. from torch.utils import _pytree
  7. from . import utils
  8. class InfoProtocol(Protocol):
  9. _backward_fn: Optional[Callable]
  10. _setup_context_fn: Optional[Callable]
  11. @dataclasses.dataclass
  12. class Info:
  13. _backward_fn: Optional[Callable]
  14. _setup_context_fn: Optional[Callable]
  15. def make_autograd_impl(op: _ops.OpOverload, info: InfoProtocol) -> Callable:
  16. name: str = f"GeneratedBackwardFor_{op._namespace}_{op._opname}_{op._overloadname}"
  17. has_kwarg_only_args = utils.has_kwarg_only_args(op._schema)
  18. @dataclass
  19. class Metadata:
  20. keyset: _C.DispatchKeySet
  21. keyword_only_args: dict[str, Any]
  22. def forward_no_grad(*args):
  23. metadata = args[-1]
  24. args = args[:-1]
  25. with _C._AutoDispatchBelowAutograd():
  26. keyset = metadata.keyset
  27. kwargs = metadata.keyword_only_args
  28. result = op.redispatch(keyset & _C._after_autograd_keyset, *args, **kwargs)
  29. return result
  30. def forward(ctx, *args):
  31. metadata = args[-1]
  32. args = args[:-1]
  33. with _C._AutoDispatchBelowAutograd():
  34. keyset = metadata.keyset
  35. kwargs = metadata.keyword_only_args
  36. result = op.redispatch(keyset & _C._after_autograd_keyset, *args, **kwargs)
  37. if info._setup_context_fn:
  38. # The Dispatcher will remove args that are equal to their default
  39. # values from (args, kwargs). We're going to add it back so that
  40. # the user can access them.
  41. #
  42. # This is OK to do: The Dispatcher removed the args for serialization
  43. # FC/BC reasons (that is, a graph will not store args that are equal
  44. # to their default values), but that doesn't matter here. If the user
  45. # adds a new default arg, then they must update
  46. # their setup_context (along with the rest of their operator
  47. # registrations)
  48. args, kwargs = utils.fill_defaults(op._schema, args, kwargs)
  49. if has_kwarg_only_args:
  50. info._setup_context_fn(
  51. ctx=ctx, inputs=args, keyword_only_inputs=kwargs, output=result
  52. )
  53. else:
  54. info._setup_context_fn(ctx=ctx, inputs=args, output=result)
  55. return result
  56. def backward(ctx, *grads):
  57. if info._backward_fn:
  58. try:
  59. prev_needs_input_grad = ctx.needs_input_grad
  60. ctx.needs_input_grad = ctx.needs_input_grad[:-1]
  61. result = info._backward_fn(ctx, *grads)
  62. finally:
  63. ctx.needs_input_grad = prev_needs_input_grad
  64. if isinstance(result, tuple):
  65. return (*result, None)
  66. return result, None
  67. raise RuntimeError(
  68. f"Trying to backward through {op} but no autograd "
  69. f"formula was registered. "
  70. f"Please use register_autograd to add one."
  71. )
  72. Generated = type(
  73. name,
  74. (autograd.Function,),
  75. {
  76. "forward": staticmethod(forward),
  77. "backward": staticmethod(backward),
  78. },
  79. )
  80. schema = op._schema
  81. if any(
  82. utils.is_tensorlist_like_type(a.type)
  83. for a in (*schema.arguments, *schema.returns)
  84. ):
  85. Generated = supports_tensorlist(Generated)
  86. # The dispatcher passes any keyword-only-args as kwargs and the
  87. # rest of the args (even if specified as kwargs) as args.
  88. def autograd_impl(keyset, *args, **keyword_only_args):
  89. if _C.is_grad_enabled() and _C._any_requires_grad(*args):
  90. result = Generated.apply(*args, Metadata(keyset, keyword_only_args)) # type: ignore[attr-defined]
  91. else:
  92. result = forward_no_grad(*args, Metadata(keyset, keyword_only_args))
  93. return result
  94. return autograd_impl
  95. def supports_tensorlist(cls: Any) -> Any:
  96. """Allows a given autograd.Function class to support List[Tensor] inputs/outputs.
  97. Regular autograd.Function has a constraint that it only directly supports autograd for
  98. Tensors. Applying @supports_tensorlist enables an autograd.Function to support
  99. autograd for List[Tensor] inputs and outputs.
  100. """
  101. orig_forward = cls.forward
  102. orig_backward = cls.backward
  103. orig_apply = cls.apply
  104. @dataclass
  105. class Metadata:
  106. input_spec: spec_t
  107. output_spec: Optional[spec_t] = None
  108. result_is_tuple: Optional[bool] = None
  109. def new_forward(ctx, *args):
  110. metadata = args[-1]
  111. args = args[:-1]
  112. if not isinstance(metadata, Metadata):
  113. raise NotImplementedError(
  114. "NYI: calling supports_tensorlist autograd.Function.forward directly. "
  115. "You should probably be calling .apply instead. "
  116. "Please file an issue if not."
  117. )
  118. args = unflatten(list(args), metadata.input_spec)
  119. result = orig_forward(ctx, *args)
  120. metadata.result_is_tuple = isinstance(result, tuple)
  121. if not metadata.result_is_tuple:
  122. result = (result,)
  123. flat_result, output_spec = flatten(result, not_list_of_tensor)
  124. metadata.output_spec = output_spec
  125. if hasattr(ctx, "_pt_metadata"):
  126. raise RuntimeError(
  127. "Please don't set ctx._pt_metadata; PyTorch uses it to store info"
  128. )
  129. ctx._pt_metadata = metadata
  130. return tuple(flat_result)
  131. def new_backward(ctx, *grads):
  132. if not hasattr(ctx, "_pt_metadata"):
  133. raise NotImplementedError(
  134. "NYI: calling supports_tensorlist autograd.Function.backward directly. "
  135. "This will automatically get called by PyTorch autograd. "
  136. "Please file an issue if you need this."
  137. )
  138. metadata = ctx._pt_metadata
  139. grads = unflatten(list(grads), metadata.output_spec)
  140. # If the user's input is ([x, y, z], w),
  141. # then needs_input_grad is (bool, bool, bool, bool, bool).
  142. # We need to
  143. # 1. get rid of the additional bool (which comes from the extra
  144. # `metadata input`)
  145. # 2. unflatten to get the right structure.
  146. prev_needs_input_grad = ctx.needs_input_grad
  147. try:
  148. ctx.needs_input_grad = unflatten(
  149. list(ctx.needs_input_grad[:-1]), metadata.input_spec
  150. )
  151. grad_inputs = orig_backward(ctx, *grads)
  152. finally:
  153. ctx.needs_input_grad = prev_needs_input_grad
  154. if not isinstance(grad_inputs, tuple):
  155. grad_inputs = (grad_inputs,)
  156. # Assume that any Nones in the backward are Tensors.
  157. # If the forward has an arg that is [1, 2, 3], the backward should
  158. # return None as the grad.
  159. # If the forward has an arg that is [tensor, tensor], the backward
  160. # may return [None, None], [grad, None], [None, grad], or [grad, grad].
  161. flat_grad_inputs, grad_inputs_spec = flatten(
  162. grad_inputs, not_list_of_optional_tensor
  163. )
  164. if grad_inputs_spec != metadata.input_spec:
  165. raise RuntimeError(
  166. f"Expected the return from backward to be of the same structure "
  167. f"as the inputs. Got: {grad_inputs_spec} (return from backward), "
  168. f"{metadata.input_spec} (inputs)"
  169. )
  170. return tuple(flat_grad_inputs + [None])
  171. def new_apply(*args):
  172. flat_args, input_spec = flatten(args, is_leaf=not_list_of_tensor)
  173. metadata = Metadata(input_spec)
  174. result = orig_apply(*flat_args, metadata) # type: ignore[misc]
  175. assert metadata.output_spec is not None
  176. result = unflatten(list(result), metadata.output_spec)
  177. if not metadata.result_is_tuple:
  178. assert isinstance(result, tuple)
  179. assert len(result) == 1
  180. return result[0]
  181. return result
  182. cls.forward = new_forward
  183. cls.backward = new_backward
  184. cls.apply = new_apply
  185. return cls
  186. def not_list_of_tensor(tree):
  187. if isinstance(tree, tuple):
  188. return False
  189. if isinstance(tree, list):
  190. return any(not isinstance(l, Tensor) for l in tree)
  191. return True
  192. def not_list_of_optional_tensor(tree):
  193. if isinstance(tree, tuple):
  194. return False
  195. if isinstance(tree, list):
  196. return any(l is not None and not isinstance(l, Tensor) for l in tree)
  197. return True
  198. flatten = _pytree.tree_flatten
  199. unflatten = _pytree.tree_unflatten
  200. spec_t = _pytree.TreeSpec