apis.py 19 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456
  1. # mypy: allow-untyped-defs
  2. # NOTE: We allow Dynamo to see this file (via torch/_dynamo/trace_rules.py) so that it can
  3. # trace through functorch transforms.
  4. # Currently, we can't allow Dynamo to see `eager_transforms.py`/`vmap.py` as that breaks a lot of things
  5. # and there isn't a mechanism to selectively expose only some functions (e.g. grad) from a file
  6. # to Dynamo.
  7. import functools
  8. from torch._functorch.utils import argnums_t, exposed_in
  9. from torch._functorch.vmap import (
  10. _check_out_dims_is_int_or_int_pytree,
  11. _check_randomness_arg,
  12. _chunked_vmap,
  13. _process_batched_inputs,
  14. Callable,
  15. in_dims_t,
  16. out_dims_t,
  17. vmap_impl,
  18. )
  19. # vmap(func)(inputs) wraps all Tensor inputs to be batched in BatchedTensors,
  20. # sends those into func, and then unwraps the output BatchedTensors. Operations
  21. # on BatchedTensors perform the batched operations that the user is asking for.
  22. #
  23. # vmap's randomness behavior differs from JAX's, which would require a PRNG key
  24. # to be passed everywhere.
  25. @exposed_in("torch.func")
  26. def vmap(
  27. func: Callable,
  28. in_dims: in_dims_t = 0,
  29. out_dims: out_dims_t = 0,
  30. randomness: str = "error",
  31. *,
  32. chunk_size=None,
  33. ) -> Callable:
  34. """
  35. vmap is the vectorizing map; ``vmap(func)`` returns a new function that
  36. maps ``func`` over some dimension of the inputs. Semantically, vmap
  37. pushes the map into PyTorch operations called by ``func``, effectively
  38. vectorizing those operations.
  39. vmap is useful for handling batch dimensions: one can write a function
  40. ``func`` that runs on examples and then lift it to a function that can
  41. take batches of examples with ``vmap(func)``. vmap can also be used to
  42. compute batched gradients when composed with autograd.
  43. .. note::
  44. :func:`torch.vmap` is aliased to :func:`torch.func.vmap` for
  45. convenience. Use whichever one you'd like.
  46. Args:
  47. func (function): A Python function that takes one or more arguments.
  48. Must return one or more Tensors.
  49. in_dims (int or nested structure): Specifies which dimension of the
  50. inputs should be mapped over. ``in_dims`` should have a
  51. structure like the inputs. If the ``in_dim`` for a particular
  52. input is None, then that indicates there is no map dimension.
  53. Default: 0.
  54. out_dims (int or Tuple[int]): Specifies where the mapped dimension
  55. should appear in the outputs. If ``out_dims`` is a Tuple, then
  56. it should have one element per output. Default: 0.
  57. randomness (str): Specifies whether the randomness in this
  58. vmap should be the same or different across batches. If 'different',
  59. the randomness for each batch will be different. If 'same', the
  60. randomness will be the same across batches. If 'error', any calls to
  61. random functions will error. Default: 'error'. WARNING: this flag
  62. only applies to random PyTorch operations and does not apply to
  63. Python's random module or numpy randomness.
  64. chunk_size (None or int): If None (default), apply a single vmap over inputs.
  65. If not None, then compute the vmap :attr:`chunk_size` samples at a time.
  66. Note that :attr:`chunk_size=1` is equivalent to computing the vmap with a for-loop.
  67. If you run into memory issues computing the vmap, please try a non-None chunk_size.
  68. Returns:
  69. Returns a new "batched" function. It takes the same inputs as
  70. ``func``, except each input has an extra dimension at the index
  71. specified by ``in_dims``. It takes returns the same outputs as
  72. ``func``, except each output has an extra dimension at the index
  73. specified by ``out_dims``.
  74. .. warning:
  75. :func:`vmap` works best with functional-style code. Please do not
  76. perform any side-effects in ``func``, with the exception of
  77. in-place PyTorch operations. Examples of side-effects include mutating
  78. Python data structures and assigning values to variables not captured
  79. in ``func``.
  80. One example of using :func:`vmap` is to compute batched dot products. PyTorch
  81. doesn't provide a batched ``torch.dot`` API; instead of unsuccessfully
  82. rummaging through docs, use :func:`vmap` to construct a new function.
  83. >>> torch.dot # [D], [D] -> []
  84. >>> batched_dot = torch.func.vmap(torch.dot) # [N, D], [N, D] -> [N]
  85. >>> x, y = torch.randn(2, 5), torch.randn(2, 5)
  86. >>> batched_dot(x, y)
  87. :func:`vmap` can be helpful in hiding batch dimensions, leading to a simpler
  88. model authoring experience.
  89. >>> batch_size, feature_size = 3, 5
  90. >>> weights = torch.randn(feature_size, requires_grad=True)
  91. >>>
  92. >>> def model(feature_vec):
  93. >>> # Very simple linear model with activation
  94. >>> return feature_vec.dot(weights).relu()
  95. >>>
  96. >>> examples = torch.randn(batch_size, feature_size)
  97. >>> result = torch.vmap(model)(examples)
  98. :func:`vmap` can also help vectorize computations that were previously difficult
  99. or impossible to batch. One example is higher-order gradient computation.
  100. The PyTorch autograd engine computes vjps (vector-Jacobian products).
  101. Computing a full Jacobian matrix for some function f: R^N -> R^N usually
  102. requires N calls to ``autograd.grad``, one per Jacobian row. Using :func:`vmap`,
  103. we can vectorize the whole computation, computing the Jacobian in a single
  104. call to ``autograd.grad``.
  105. >>> # Setup
  106. >>> N = 5
  107. >>> f = lambda x: x**2
  108. >>> x = torch.randn(N, requires_grad=True)
  109. >>> y = f(x)
  110. >>> I_N = torch.eye(N)
  111. >>>
  112. >>> # Sequential approach
  113. >>> jacobian_rows = [torch.autograd.grad(y, x, v, retain_graph=True)[0]
  114. >>> for v in I_N.unbind()]
  115. >>> jacobian = torch.stack(jacobian_rows)
  116. >>>
  117. >>> # vectorized gradient computation
  118. >>> def get_vjp(v):
  119. >>> return torch.autograd.grad(y, x, v)
  120. >>> jacobian = torch.vmap(get_vjp)(I_N)
  121. :func:`vmap` can also be nested, producing an output with multiple batched dimensions
  122. >>> torch.dot # [D], [D] -> []
  123. >>> batched_dot = torch.vmap(
  124. ... torch.vmap(torch.dot)
  125. ... ) # [N1, N0, D], [N1, N0, D] -> [N1, N0]
  126. >>> x, y = torch.randn(2, 3, 5), torch.randn(2, 3, 5)
  127. >>> batched_dot(x, y) # tensor of size [2, 3]
  128. If the inputs are not batched along the first dimension, ``in_dims`` specifies
  129. the dimension that each inputs are batched along as
  130. >>> torch.dot # [N], [N] -> []
  131. >>> batched_dot = torch.vmap(torch.dot, in_dims=1) # [N, D], [N, D] -> [D]
  132. >>> x, y = torch.randn(2, 5), torch.randn(2, 5)
  133. >>> batched_dot(
  134. ... x, y
  135. ... ) # output is [5] instead of [2] if batched along the 0th dimension
  136. If there are multiple inputs each of which is batched along different dimensions,
  137. ``in_dims`` must be a tuple with the batch dimension for each input as
  138. >>> torch.dot # [D], [D] -> []
  139. >>> batched_dot = torch.vmap(torch.dot, in_dims=(0, None)) # [N, D], [D] -> [N]
  140. >>> x, y = torch.randn(2, 5), torch.randn(5)
  141. >>> batched_dot(
  142. ... x, y
  143. ... ) # second arg doesn't have a batch dim because in_dim[1] was None
  144. If the input is a Python struct, ``in_dims`` must be a tuple containing a struct
  145. matching the shape of the input:
  146. >>> f = lambda dict: torch.dot(dict["x"], dict["y"])
  147. >>> x, y = torch.randn(2, 5), torch.randn(5)
  148. >>> input = {"x": x, "y": y}
  149. >>> batched_dot = torch.vmap(f, in_dims=({"x": 0, "y": None},))
  150. >>> batched_dot(input)
  151. By default, the output is batched along the first dimension. However, it can be batched
  152. along any dimension by using ``out_dims``
  153. >>> f = lambda x: x**2
  154. >>> x = torch.randn(2, 5)
  155. >>> batched_pow = torch.vmap(f, out_dims=1)
  156. >>> batched_pow(x) # [5, 2]
  157. For any function that uses kwargs, the returned function will not batch the kwargs but will
  158. accept kwargs
  159. >>> x = torch.randn([2, 5])
  160. >>> def fn(x, scale=4.):
  161. >>> return x * scale
  162. >>>
  163. >>> batched_pow = torch.vmap(fn)
  164. >>> assert torch.allclose(batched_pow(x), x * 4)
  165. >>> batched_pow(x, scale=x) # scale is not batched, output has shape [2, 2, 5]
  166. .. note::
  167. vmap does not provide general autobatching or handle variable-length
  168. sequences out of the box.
  169. """
  170. from torch.compiler import is_compiling
  171. _check_randomness_arg(randomness)
  172. if not (chunk_size is None or chunk_size > 0):
  173. raise ValueError(
  174. f"vmap: chunk_size should be None or greater than 0. (got {chunk_size})"
  175. )
  176. def wrapped(*args, **kwargs):
  177. return vmap_impl(
  178. func, in_dims, out_dims, randomness, chunk_size, *args, **kwargs
  179. )
  180. if not is_compiling():
  181. wrapped = functools.wraps(func)(wrapped)
  182. return wrapped
  183. def chunk_vmap(
  184. func: Callable,
  185. in_dims: in_dims_t = 0,
  186. out_dims: out_dims_t = 0,
  187. randomness: str = "error",
  188. chunks=2,
  189. ) -> Callable:
  190. """
  191. chunk_vmap is the vectorizing map (vmap) using chunks of input data. It is a mix of vmap (which vectorizes
  192. everything) and map (which executes things sequentially). ``chunk_vmap`` vectorizes the input with number of
  193. chunks at a time. For more details about vectorizing map, see :func:`vmap`.
  194. .. note::
  195. Please use :func:`vmap` with ``chunk_size`` argument instead of this API.
  196. Args:
  197. func (function): A Python function that takes one or more arguments.
  198. Must return one or more Tensors.
  199. in_dims (int or nested structure): Specifies which dimension of the
  200. inputs should be mapped over. ``in_dims`` should have a
  201. structure like the inputs. If the ``in_dim`` for a particular
  202. input is None, then that indicates there is no map dimension.
  203. Default: 0.
  204. out_dims (int or Tuple[int]): Specifies where the mapped dimension
  205. should appear in the outputs. If ``out_dims`` is a Tuple, then
  206. it should have one element per output. Default: 0.
  207. randomness (str): Specifies whether the randomness in this
  208. vmap should be the same or different across batches. If 'different',
  209. the randomness for each batch will be different. If 'same', the
  210. randomness will be the same across batches. If 'error', any calls to
  211. random functions will error. Default: 'error'. WARNING: this flag
  212. only applies to random PyTorch operations and does not apply to
  213. Python's random module or numpy randomness.
  214. chunks (int): Number of chunks to use to split the input data. Default is 2.
  215. If equals to 1 then :func:`vmap` is called.
  216. Returns:
  217. Returns a new "batched" function. It takes the same inputs as
  218. ``func``, except each input has an extra dimension at the index
  219. specified by ``in_dims``. It takes returns the same outputs as
  220. ``func``, except each output has an extra dimension at the index
  221. specified by ``out_dims``.
  222. """
  223. _check_randomness_arg(randomness)
  224. if chunks == 1:
  225. return vmap(func, in_dims=in_dims, out_dims=out_dims, randomness=randomness)
  226. def _get_chunk_flat_args(flat_args_, flat_in_dims_, chunks_):
  227. flat_args_chunks = tuple(
  228. t.chunk(chunks_, dim=in_dim)
  229. if in_dim is not None
  230. else [
  231. t,
  232. ]
  233. * chunks_
  234. for t, in_dim in zip(flat_args_, flat_in_dims_)
  235. )
  236. # transpose chunk dim and flatten structure
  237. # chunks_flat_args is a list of flatten args
  238. chunks_flat_args = zip(*flat_args_chunks)
  239. return chunks_flat_args
  240. @functools.wraps(func)
  241. def wrapped_with_chunks(*args, **kwargs):
  242. _check_out_dims_is_int_or_int_pytree(out_dims, func)
  243. _, flat_in_dims, flat_args, args_spec = _process_batched_inputs(
  244. in_dims, args, func
  245. )
  246. # Chunk flat arguments
  247. chunks_flat_args = _get_chunk_flat_args(flat_args, flat_in_dims, chunks)
  248. # Apply vmap on chunks
  249. return _chunked_vmap(
  250. func,
  251. flat_in_dims,
  252. chunks_flat_args,
  253. args_spec,
  254. out_dims,
  255. randomness,
  256. **kwargs,
  257. )
  258. return wrapped_with_chunks
  259. @exposed_in("torch.func")
  260. def grad(func: Callable, argnums: argnums_t = 0, has_aux: bool = False) -> Callable:
  261. """``grad`` operator helps computing gradients of ``func`` with respect to the
  262. input(s) specified by ``argnums``. This operator can be nested to
  263. compute higher-order gradients.
  264. Args:
  265. func (Callable): A Python function that takes one or more arguments.
  266. Must return a single-element Tensor. If specified ``has_aux`` equals ``True``,
  267. function can return a tuple of single-element Tensor and other auxiliary objects:
  268. ``(output, aux)``.
  269. argnums (int or Tuple[int]): Specifies arguments to compute gradients with respect to.
  270. ``argnums`` can be single integer or tuple of integers. Default: 0.
  271. has_aux (bool): Flag indicating that ``func`` returns a tensor and other
  272. auxiliary objects: ``(output, aux)``. Default: False.
  273. Returns:
  274. Function to compute gradients with respect to its inputs. By default, the output of
  275. the function is the gradient tensor(s) with respect to the first argument.
  276. If specified ``has_aux`` equals ``True``, tuple of gradients and output auxiliary objects
  277. is returned. If ``argnums`` is a tuple of integers, a tuple of output gradients with
  278. respect to each ``argnums`` value is returned.
  279. Example of using ``grad``:
  280. >>> # xdoctest: +SKIP
  281. >>> from torch.func import grad
  282. >>> x = torch.randn([])
  283. >>> cos_x = grad(lambda x: torch.sin(x))(x)
  284. >>> assert torch.allclose(cos_x, x.cos())
  285. >>>
  286. >>> # Second-order gradients
  287. >>> neg_sin_x = grad(grad(lambda x: torch.sin(x)))(x)
  288. >>> assert torch.allclose(neg_sin_x, -x.sin())
  289. When composed with ``vmap``, ``grad`` can be used to compute per-sample-gradients:
  290. >>> # xdoctest: +SKIP
  291. >>> from torch.func import grad, vmap
  292. >>> batch_size, feature_size = 3, 5
  293. >>>
  294. >>> def model(weights, feature_vec):
  295. >>> # Very simple linear model with activation
  296. >>> assert feature_vec.dim() == 1
  297. >>> return feature_vec.dot(weights).relu()
  298. >>>
  299. >>> def compute_loss(weights, example, target):
  300. >>> y = model(weights, example)
  301. >>> return ((y - target) ** 2).mean() # MSELoss
  302. >>>
  303. >>> weights = torch.randn(feature_size, requires_grad=True)
  304. >>> examples = torch.randn(batch_size, feature_size)
  305. >>> targets = torch.randn(batch_size)
  306. >>> inputs = (weights, examples, targets)
  307. >>> grad_weight_per_example = vmap(grad(compute_loss), in_dims=(None, 0, 0))(
  308. ... *inputs
  309. ... )
  310. Example of using ``grad`` with ``has_aux`` and ``argnums``:
  311. >>> # xdoctest: +SKIP
  312. >>> from torch.func import grad
  313. >>> def my_loss_func(y, y_pred):
  314. >>> loss_per_sample = (0.5 * y_pred - y) ** 2
  315. >>> loss = loss_per_sample.mean()
  316. >>> return loss, (y_pred, loss_per_sample)
  317. >>>
  318. >>> fn = grad(my_loss_func, argnums=(0, 1), has_aux=True)
  319. >>> y_true = torch.rand(4)
  320. >>> y_preds = torch.rand(4, requires_grad=True)
  321. >>> out = fn(y_true, y_preds)
  322. >>> # > output is ((grads w.r.t y_true, grads w.r.t y_preds), (y_pred, loss_per_sample))
  323. .. note::
  324. Using PyTorch ``torch.no_grad`` together with ``grad``.
  325. Case 1: Using ``torch.no_grad`` inside a function:
  326. >>> # xdoctest: +SKIP
  327. >>> def f(x):
  328. >>> with torch.no_grad():
  329. >>> c = x ** 2
  330. >>> return x - c
  331. In this case, ``grad(f)(x)`` will respect the inner ``torch.no_grad``.
  332. Case 2: Using ``grad`` inside ``torch.no_grad`` context manager:
  333. >>> # xdoctest: +SKIP
  334. >>> with torch.no_grad():
  335. >>> grad(f)(x)
  336. In this case, ``grad`` will respect the inner ``torch.no_grad``, but not the
  337. outer one. This is because ``grad`` is a "function transform": its result
  338. should not depend on the result of a context manager outside of ``f``.
  339. """
  340. # To avoid cyclical dependency.
  341. import torch._functorch.eager_transforms as eager_transforms
  342. from torch.compiler import is_compiling
  343. def wrapper(*args, **kwargs):
  344. return eager_transforms.grad_impl(func, argnums, has_aux, args, kwargs)
  345. if not is_compiling():
  346. wrapper = functools.wraps(func)(wrapper)
  347. return wrapper
  348. @exposed_in("torch.func")
  349. def grad_and_value(
  350. func: Callable, argnums: argnums_t = 0, has_aux: bool = False
  351. ) -> Callable:
  352. """
  353. Returns a function to compute a tuple of the gradient and primal, or
  354. forward, computation.
  355. Args:
  356. func (Callable): A Python function that takes one or more arguments.
  357. Must return a single-element Tensor. If specified ``has_aux``
  358. equals ``True``, function can return a tuple of single-element
  359. Tensor and other auxiliary objects: ``(output, aux)``.
  360. argnums (int or Tuple[int]): Specifies arguments to compute gradients
  361. with respect to. ``argnums`` can be single integer or tuple of
  362. integers. Default: 0.
  363. has_aux (bool): Flag indicating that ``func`` returns a tensor and
  364. other auxiliary objects: ``(output, aux)``. Default: False.
  365. Returns:
  366. Function to compute a tuple of gradients with respect to its inputs
  367. and the forward computation. By default, the output of the function is
  368. a tuple of the gradient tensor(s) with respect to the first argument
  369. and the primal computation. If specified ``has_aux`` equals
  370. ``True``, tuple of gradients and tuple of the forward computation with
  371. output auxiliary objects is returned. If ``argnums`` is a tuple of
  372. integers, a tuple of a tuple of the output gradients with respect to
  373. each ``argnums`` value and the forward computation is returned.
  374. See :func:`grad` for examples
  375. """
  376. from torch._functorch import eager_transforms
  377. from torch.compiler import is_compiling
  378. def wrapper(*args, **kwargs):
  379. return eager_transforms.grad_and_value_impl(
  380. func, argnums, has_aux, args, kwargs
  381. )
  382. if not is_compiling():
  383. wrapper = functools.wraps(func)(wrapper)
  384. return wrapper