autograd.py 12 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307
  1. # mypy: allow-untyped-defs
  2. import functools
  3. from collections import namedtuple
  4. import torch
  5. import torch.utils._pytree as pytree
# NOTE [CustomOp autograd kernel indirection]
# We register `inner` as the autograd kernel for this custom_op.
# `inner` either calls the autograd formula registered by the user,
# or goes into an `autograd_not_implemented` kernel.
#
# This indirection exists so that we can swap out the autograd kernel
# (the PyTorch dispatcher doesn't actually allow us to do this). By
# default, we want the `autograd_not_implemented` behavior, but the user
# may later register an actual backward formula.
  16. def autograd_kernel_indirection(custom_op):
  17. autograd_fallback = autograd_not_implemented(custom_op)
  18. def inner(*args, **kwargs):
  19. if custom_op._has_impl("autograd"):
  20. kernel = custom_op._get_impl("autograd").func
  21. return kernel(*args, **kwargs)
  22. # As explained in NOTE ["backward", "save_for_backward", and "autograd"],
  23. # after the user gives us "backward" and "save_for_backward", we generate
  24. # the "autograd" impl. If the user only provided one, then we tell
  25. # the user they've done something wrong.
  26. if custom_op._has_impl("save_for_backward") or custom_op._has_impl("backward"):
  27. missing = (
  28. "save_for_backward" if custom_op._has_impl("backward") else "backward"
  29. )
  30. found = "save_for_backward" if missing == "backward" else "backward"
  31. loc = custom_op._get_impl(found).location
  32. raise RuntimeError(
  33. f"We found a '{found}' registration for {custom_op} at "
  34. f"{loc} but were unable to find a '{missing}' registration. "
  35. f"To use the CustomOp API to register a backward formula, "
  36. f"please provide us both a backward function and a "
  37. f"'save for backward' function via `impl_backward` and "
  38. f"`impl_save_for_backward` respectively."
  39. )
  40. return autograd_fallback(*args, **kwargs)
  41. return inner
  42. # TODO(#101191): Use the actual C++ autograd not implemented fallback,
  43. # or change the default autograd fallback to the autograd not implemented fallback.
  44. def autograd_not_implemented(custom_op):
  45. def kernel(*args, **kwargs):
  46. if torch.is_grad_enabled() and pytree.tree_any(
  47. lambda x: isinstance(x, torch.Tensor) and x.requires_grad, (args, kwargs)
  48. ):
  49. raise RuntimeError("Autograd has not been implemented for operator")
  50. with torch._C._AutoDispatchBelowAutograd():
  51. return custom_op(*args, **kwargs)
  52. return kernel
  53. def mark_non_differentiable(ctx, output, output_differentiability):
  54. # Output types are restricted to be:
  55. # - Tensor
  56. # - Tensor[]
  57. # - int, bool, Scalar, float
  58. # See _check_can_register_backward
  59. if output_differentiability is not None:
  60. if not isinstance(output, tuple):
  61. tuple_output = (output,)
  62. else:
  63. tuple_output = output # type: ignore[assignment]
  64. assert len(output_differentiability) == len(tuple_output)
  65. non_differentiable_tensors = []
  66. for idx, (differentiable, out) in enumerate(
  67. zip(output_differentiability, tuple_output)
  68. ):
  69. if isinstance(out, torch.Tensor):
  70. if not differentiable:
  71. non_differentiable_tensors.append(out)
  72. continue
  73. if isinstance(out, list):
  74. if not differentiable:
  75. non_differentiable_tensors.extend(out)
  76. continue
  77. if differentiable:
  78. raise RuntimeError(
  79. f"With output_differentiability={output_differentiability}. "
  80. f"At idx {idx}, we received an object of type {type(out)} that "
  81. f"is not a Tensor, so it cannot have be marked as differentiable in "
  82. f"output_differentiability."
  83. )
  84. if non_differentiable_tensors:
  85. ctx.mark_non_differentiable(*non_differentiable_tensors)
  86. def construct_autograd_kernel(
  87. schema,
  88. output_differentiability,
  89. custom_op,
  90. op_overload,
  91. save_for_backward_fn,
  92. backward_fn,
  93. ):
  94. def apply(*args):
  95. flat_args, spec = pytree.tree_flatten(args)
  96. out_spec = None
  97. def forward(ctx, *flat_args):
  98. ctx.set_materialize_grads(True)
  99. args = pytree.tree_unflatten(list(flat_args), spec)
  100. with torch._C._AutoDispatchBelowAutograd():
  101. output = op_overload(*args)
  102. # We use the info about args to give better error messages in backward
  103. args_info = namedtuple_args(schema, pytree.tree_map(type, args))
  104. save_for_backward_fn_inputs = namedtuple_args(schema, args)
  105. to_save = save_for_backward_fn(save_for_backward_fn_inputs, output)
  106. save_pytree_for_backward(ctx, (to_save, args_info))
  107. mark_non_differentiable(ctx, output, output_differentiability)
  108. nonlocal out_spec
  109. flat_output, out_spec = pytree.tree_flatten(output)
  110. return tuple(flat_output)
  111. def backward(ctx, *flat_grad_output):
  112. assert out_spec is not None
  113. grads = pytree.tree_unflatten(list(flat_grad_output), out_spec)
  114. saved, args_info = unpack_saved(ctx)
  115. # There is nothing on the ctx object for now, it is just there so
  116. # that we can add additional things in the future.
  117. inner_ctx = object()
  118. if not isinstance(grads, tuple):
  119. grads = (grads,)
  120. grad_inputs_dict = backward_fn(inner_ctx, saved, *grads)
  121. # Massage the grad_inputs_dict to a form acceptable by
  122. # autograd.Function.
  123. validate_grad_inputs_dict(grad_inputs_dict, custom_op, args_info)
  124. return grad_inputs_dict_to_flat_tuple(grad_inputs_dict, args_info)
  125. generated_cls = gen_autograd_function(
  126. custom_op._opname + "_customop", forward, backward
  127. )
  128. flat_output = generated_cls.apply(*flat_args)
  129. assert out_spec is not None
  130. return pytree.tree_unflatten(list(flat_output), out_spec)
  131. return apply
  132. def gen_autograd_function(name, forward, backward):
  133. generated_cls = type(
  134. name,
  135. (torch.autograd.Function,),
  136. {
  137. "forward": staticmethod(forward),
  138. "backward": staticmethod(backward),
  139. },
  140. )
  141. return generated_cls
  142. @functools.lru_cache
  143. def namedtuple_args_cls(schema):
  144. attribs = [arg.name for arg in schema.arguments.flat_all]
  145. name = str(schema.name) + "_args"
  146. # mypy doesn't support dynamic namedtuple name
  147. tuple_cls = namedtuple(name, attribs) # type: ignore[misc]
  148. return tuple_cls
  149. def namedtuple_args(schema, args):
  150. assert isinstance(args, tuple)
  151. tuple_cls = namedtuple_args_cls(schema)
  152. return tuple_cls(*args)
  153. def validate_grad_inputs_dict(grad_inputs_dict, forward_op, args_info):
  154. def error(what):
  155. backward = forward_op._get_impl("backward")
  156. raise RuntimeError(
  157. f"In the backward function defined for {forward_op} at "
  158. f"{backward.location} using the CustomOp API, {what}"
  159. )
  160. if not isinstance(grad_inputs_dict, dict):
  161. error(
  162. f"expected the output of the backward function to be a dict but "
  163. f"got {type(grad_inputs_dict)}"
  164. )
  165. expected_keys = {
  166. arg.name
  167. for arg in forward_op._schema.arguments.flat_all
  168. if arg.type.is_tensor_like()
  169. }
  170. actual_keys = grad_inputs_dict.keys()
  171. if expected_keys != actual_keys:
  172. error(
  173. f"expected the returned grad_input dict to have keys "
  174. f"{expected_keys} but got {actual_keys}. The backward "
  175. f"function must return a gradient (can be None) for each arg "
  176. f"to the CustomOp that may be a Tensor or Sequence[Tensor]. "
  177. f"Args declared to be non-Tensor-like types should not appear "
  178. f"in the grad_input dict"
  179. )
  180. for name, grad in grad_inputs_dict.items():
  181. arg_info = getattr(args_info, name)
  182. if isinstance(arg_info, list):
  183. if not isinstance(grad, (tuple, list)):
  184. error(
  185. f"for input '{name}' expected the grad_input dict to "
  186. f"hold a list of gradients but got object of type "
  187. f"{type(grad)}."
  188. )
  189. if not len(grad) == len(arg_info):
  190. error(
  191. f"for input '{name}' expected the grad_input dict to "
  192. f"hold a list of {len(arg_info)} gradients but got "
  193. f"{len(grad)}"
  194. )
  195. for idx, (g, info) in enumerate(zip(grad, arg_info)):
  196. if g is None:
  197. continue
  198. if not isinstance(g, torch.Tensor):
  199. error(
  200. f"for input '{name}' expected the grad_input dict to "
  201. f"hold a list of None or Tensor gradients but got "
  202. f"object of {type(g)} at index {idx}"
  203. )
  204. if not issubclass(info, torch.Tensor):
  205. error(
  206. f"for input '{name}', got a Tensor as the gradient "
  207. f"for the {idx}-th value but expected None because "
  208. f"the {idx}-th value was not a Tensor (it was "
  209. f"type {arg_info}"
  210. )
  211. continue
  212. if grad is None:
  213. continue
  214. if not isinstance(grad, torch.Tensor):
  215. error(
  216. f"got object of type {type(grad)} as the gradient for input "
  217. f"'{name}', "
  218. f"but expected the gradient to be either None or a Tensor"
  219. )
  220. if not issubclass(arg_info, torch.Tensor):
  221. error(
  222. f"got a Tensor as the gradient for input '{name}' but "
  223. f"expected None as the gradient because input '{name}' "
  224. f"was not a Tensor (it was type {arg_info})."
  225. )
  226. def grad_inputs_dict_to_flat_tuple(grad_inputs_dict, args_info):
  227. result = []
  228. for name, arg_info in args_info._asdict().items():
  229. if name not in grad_inputs_dict:
  230. result.append(pytree.tree_map(lambda x: None, arg_info))
  231. continue
  232. result.append(grad_inputs_dict[name])
  233. return tuple(pytree.tree_leaves(result))
  234. # Saves "stuff" (a pytree) onto the ctx object. Use unpack_saved to unpack it.
  235. # autograd.Function prefers that users use ctx.save_for_backward to
  236. # save Tensors (to avoid reference cycles) and for non-Tensors to go onto the
  237. # ctx object.
  238. def save_pytree_for_backward(ctx, stuff):
  239. flat_stuff, spec = pytree.tree_flatten(stuff)
  240. num_elts = len(flat_stuff)
  241. tensor_idxs = [
  242. idx for idx, thing in enumerate(flat_stuff) if isinstance(thing, torch.Tensor)
  243. ]
  244. non_tensor_idxs = [
  245. idx
  246. for idx, thing in enumerate(flat_stuff)
  247. if not isinstance(thing, torch.Tensor)
  248. ]
  249. tensors = [thing for thing in flat_stuff if isinstance(thing, torch.Tensor)]
  250. non_tensors = [thing for thing in flat_stuff if not isinstance(thing, torch.Tensor)]
  251. ctx.spec = spec
  252. ctx.num_elts = num_elts
  253. ctx.save_for_backward(*tensors)
  254. ctx.tensor_idxs = tensor_idxs
  255. ctx.saved_non_tensors = non_tensors
  256. ctx.non_tensor_idxs = non_tensor_idxs
  257. # Inverse operation to save_pytree_for_backward
  258. def unpack_saved(ctx):
  259. flat_stuff = [None] * ctx.num_elts
  260. for tensor, idx in zip(ctx.saved_tensors, ctx.tensor_idxs):
  261. flat_stuff[idx] = tensor
  262. for non_tensor, idx in zip(ctx.saved_non_tensors, ctx.non_tensor_idxs):
  263. flat_stuff[idx] = non_tensor
  264. stuff = pytree.tree_unflatten(flat_stuff, ctx.spec)
  265. return stuff