vmap.py 17 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487
  1. # mypy: ignore-errors
  2. # Copyright (c) Facebook, Inc. and its affiliates.
  3. # All rights reserved.
  4. #
  5. # This source code is licensed under the BSD-style license found in the
  6. # LICENSE file in the root directory of this source tree.
  7. import contextlib
  8. import functools
  9. import itertools
  10. from functools import partial
  11. from typing import Any, Callable, Optional, Union
  12. import torch
  13. from torch import Tensor
  14. from torch._C._functorch import is_batchedtensor
  15. from torch._functorch.predispatch import (
  16. _add_batch_dim,
  17. _remove_batch_dim,
  18. _vmap_decrement_nesting,
  19. _vmap_increment_nesting,
  20. lazy_load_decompositions,
  21. )
  22. from torch.utils._pytree import (
  23. _broadcast_to_and_flatten,
  24. tree_flatten,
  25. tree_map_,
  26. tree_unflatten,
  27. TreeSpec,
  28. )
# Type of the `in_dims` argument: a single int, or a (possibly nested) tuple
# that is broadcast against the input pytree. Individual leaves may be None
# (see _process_batched_inputs), which this alias does not spell out.
in_dims_t = Union[int, tuple]
# Type of the `out_dims` argument: a single int, or a tuple of ints.
# NOTE(review): None leaves and nested pytrees of Optional[int] are also
# accepted (see _check_out_dims_is_int_or_int_pytree); the alias is narrower
# than what the code actually allows.
out_dims_t = Union[int, tuple[int, ...]]
  31. def doesnt_support_saved_tensors_hooks(f):
  32. message = (
  33. "torch.func.{grad, vjp, jacrev, hessian} don't yet support saved tensor hooks. "
  34. "Please open an issue with your use case."
  35. )
  36. @functools.wraps(f)
  37. def fn(*args, **kwargs):
  38. with torch.autograd.graph.disable_saved_tensors_hooks(message):
  39. return f(*args, **kwargs)
  40. return fn
  41. # Checks that all args-to-be-batched have the same batch dim size
  42. def _validate_and_get_batch_size(
  43. flat_in_dims: list[Optional[int]], flat_args: list
  44. ) -> int:
  45. batch_sizes = [
  46. arg.size(in_dim)
  47. for in_dim, arg in zip(flat_in_dims, flat_args)
  48. if in_dim is not None
  49. ]
  50. if len(batch_sizes) == 0:
  51. raise ValueError("vmap: Expected at least one Tensor to vmap over")
  52. if batch_sizes and any(size != batch_sizes[0] for size in batch_sizes):
  53. raise ValueError(
  54. f"vmap: Expected all tensors to have the same size in the mapped "
  55. f"dimension, got sizes {batch_sizes} for the mapped dimension"
  56. )
  57. return batch_sizes[0]
  58. def _num_outputs(batched_outputs: Union[Tensor, tuple[Tensor, ...]]) -> int:
  59. if isinstance(batched_outputs, tuple):
  60. return len(batched_outputs)
  61. return 1
  62. # If value is a tuple, check it has length `num_elements`.
  63. # If value is not a tuple, make a tuple with `value` repeated `num_elements` times
  64. def _as_tuple(
  65. value: Any, num_elements: int, error_message_lambda: Callable[[], str]
  66. ) -> tuple:
  67. if not isinstance(value, tuple):
  68. return (value,) * num_elements
  69. if len(value) != num_elements:
  70. raise ValueError(error_message_lambda())
  71. return value
  72. def _process_batched_inputs(
  73. in_dims: in_dims_t, args: tuple, func: Callable
  74. ) -> tuple[int, list[Any], list[Any], TreeSpec]:
  75. if not isinstance(in_dims, int) and not isinstance(in_dims, tuple):
  76. raise ValueError(
  77. f"vmap({_get_name(func)}, in_dims={in_dims}, ...)(<inputs>): "
  78. f"expected `in_dims` to be int or a (potentially nested) tuple "
  79. f"matching the structure of inputs, got: {type(in_dims)}."
  80. )
  81. if len(args) == 0:
  82. raise ValueError(
  83. f"vmap({_get_name(func)})(<inputs>): got no inputs. Maybe you forgot to add "
  84. f"inputs, or you are trying to vmap over a function with no inputs. "
  85. f"The latter is unsupported."
  86. )
  87. flat_args, args_spec = tree_flatten(args)
  88. flat_in_dims = _broadcast_to_and_flatten(in_dims, args_spec)
  89. if flat_in_dims is None:
  90. raise ValueError(
  91. f"vmap({_get_name(func)}, in_dims={in_dims}, ...)(<inputs>): "
  92. f"in_dims is not compatible with the structure of `inputs`. "
  93. f"in_dims has structure {tree_flatten(in_dims)[1]} but inputs "
  94. f"has structure {args_spec}."
  95. )
  96. for i, (arg, in_dim) in enumerate(zip(flat_args, flat_in_dims)):
  97. if not isinstance(in_dim, int) and in_dim is not None:
  98. raise ValueError(
  99. f"vmap({_get_name(func)}, in_dims={in_dims}, ...)(<inputs>): "
  100. f"Got in_dim={in_dim} for an input but in_dim must be either "
  101. f"an integer dimension or None."
  102. )
  103. if isinstance(in_dim, int) and not isinstance(arg, Tensor):
  104. raise ValueError(
  105. f"vmap({_get_name(func)}, in_dims={in_dims}, ...)(<inputs>): "
  106. f"Got in_dim={in_dim} for an input but the input is of type "
  107. f"{type(arg)}. We cannot vmap over non-Tensor arguments, "
  108. f"please use None as the respective in_dim"
  109. )
  110. if in_dim is not None and (in_dim < -arg.dim() or in_dim >= arg.dim()):
  111. raise ValueError(
  112. f"vmap({_get_name(func)}, in_dims={in_dims}, ...)(<inputs>): "
  113. f"Got in_dim={in_dim} for some input, but that input is a Tensor "
  114. f"of dimensionality {arg.dim()} so expected in_dim to satisfy "
  115. f"-{arg.dim()} <= in_dim < {arg.dim()}."
  116. )
  117. if in_dim is not None and in_dim < 0:
  118. flat_in_dims[i] = in_dim % arg.dim()
  119. return (
  120. _validate_and_get_batch_size(flat_in_dims, flat_args),
  121. flat_in_dims,
  122. flat_args,
  123. args_spec,
  124. )
  125. # Creates BatchedTensors for every Tensor in arg that should be batched.
  126. # Returns the (potentially) batched arguments and the batch_size.
  127. def _create_batched_inputs(
  128. flat_in_dims: list[Any], flat_args: list[Any], vmap_level: int, args_spec
  129. ) -> tuple:
  130. # See NOTE [Ignored _remove_batch_dim, _add_batch_dim]
  131. batched_inputs = [
  132. arg if in_dim is None else _add_batch_dim(arg, in_dim, vmap_level)
  133. for in_dim, arg in zip(flat_in_dims, flat_args)
  134. ]
  135. return tree_unflatten(batched_inputs, args_spec)
  136. def _maybe_remove_batch_dim(name, batched_output, vmap_level, batch_size, out_dim):
  137. if out_dim is None:
  138. if isinstance(batched_output, torch.Tensor) and is_batchedtensor(
  139. batched_output
  140. ):
  141. raise ValueError(
  142. f"vmap({name}, ...): `{name}` can not return a "
  143. f"BatchedTensor when out_dim is None"
  144. )
  145. return batched_output
  146. # out_dim is non None
  147. if not isinstance(batched_output, torch.Tensor):
  148. raise ValueError(
  149. f"vmap({name}, ...): `{name}` must only return "
  150. f"Tensors, got type {type(batched_output)}. "
  151. "Did you mean to set out_dims= to None for output?"
  152. )
  153. return _remove_batch_dim(batched_output, vmap_level, batch_size, out_dim)
  154. # Undos the batching (and any batch dimensions) associated with the `vmap_level`.
  155. def _unwrap_batched(
  156. batched_outputs: Union[Tensor, tuple[Tensor, ...]],
  157. out_dims: out_dims_t,
  158. vmap_level: int,
  159. batch_size: int,
  160. func: Callable,
  161. ) -> tuple:
  162. flat_batched_outputs, output_spec = tree_flatten(batched_outputs)
  163. def incompatible_error():
  164. raise ValueError(
  165. f"vmap({_get_name(func)}, ..., out_dims={out_dims})(<inputs>): "
  166. f"out_dims is not compatible with the structure of `outputs`. "
  167. f"out_dims has structure {tree_flatten(out_dims)[1]} but outputs "
  168. f"has structure {output_spec}."
  169. )
  170. if isinstance(batched_outputs, torch.Tensor):
  171. # Some weird edge case requires us to spell out the following
  172. # see test_out_dims_edge_case
  173. if isinstance(out_dims, int):
  174. flat_out_dims = [out_dims]
  175. elif isinstance(out_dims, tuple) and len(out_dims) == 1:
  176. flat_out_dims = out_dims
  177. elif out_dims is None:
  178. flat_out_dims = [out_dims]
  179. else:
  180. incompatible_error()
  181. else:
  182. flat_out_dims = _broadcast_to_and_flatten(out_dims, output_spec)
  183. if flat_out_dims is None:
  184. incompatible_error()
  185. flat_outputs = [
  186. _maybe_remove_batch_dim(
  187. _get_name(func), batched_output, vmap_level, batch_size, out_dim
  188. )
  189. for batched_output, out_dim in zip(flat_batched_outputs, flat_out_dims)
  190. ]
  191. return tree_unflatten(flat_outputs, output_spec)
  192. def _check_int_or_none(x, func, out_dims):
  193. if isinstance(x, int):
  194. return
  195. if x is None:
  196. return
  197. raise ValueError(
  198. f"vmap({_get_name(func)}, ..., out_dims={out_dims}): `out_dims` must be "
  199. f"an int, None or a python collection of ints representing where in the outputs the "
  200. f"vmapped dimension should appear."
  201. )
  202. def _check_out_dims_is_int_or_int_pytree(out_dims: out_dims_t, func: Callable) -> None:
  203. if isinstance(out_dims, int):
  204. return
  205. tree_map_(partial(_check_int_or_none, func=func, out_dims=out_dims), out_dims)
  206. def _get_name(func: Callable):
  207. if hasattr(func, "__name__"):
  208. return func.__name__
  209. if isinstance(func, functools.partial):
  210. return f"functools.partial({_get_name(func.func)}, ...)"
  211. # Not all callables have __name__, in fact, only static functions/methods
  212. # do. A callable created via nn.Module, to name one example, doesn't have a
  213. # __name__.
  214. return repr(func)
  215. def vmap_impl(func, in_dims, out_dims, randomness, chunk_size, *args, **kwargs):
  216. lazy_load_decompositions()
  217. _check_out_dims_is_int_or_int_pytree(out_dims, func)
  218. batch_size, flat_in_dims, flat_args, args_spec = _process_batched_inputs(
  219. in_dims, args, func
  220. )
  221. if chunk_size is not None:
  222. chunks_flat_args = _get_chunked_inputs(
  223. flat_args, flat_in_dims, batch_size, chunk_size
  224. )
  225. return _chunked_vmap(
  226. func,
  227. flat_in_dims,
  228. chunks_flat_args,
  229. args_spec,
  230. out_dims,
  231. randomness,
  232. **kwargs,
  233. )
  234. # If chunk_size is not specified.
  235. return _flat_vmap(
  236. func,
  237. batch_size,
  238. flat_in_dims,
  239. flat_args,
  240. args_spec,
  241. out_dims,
  242. randomness,
  243. **kwargs,
  244. )
  245. def get_chunk_sizes(total_elems, chunk_size):
  246. n_chunks = n_chunks = total_elems // chunk_size
  247. chunk_sizes = [chunk_size] * n_chunks
  248. # remainder chunk
  249. remainder = total_elems % chunk_size
  250. if remainder != 0:
  251. chunk_sizes.append(remainder)
  252. return chunk_sizes
  253. def _get_chunked_inputs(flat_args, flat_in_dims, batch_size, chunk_size):
  254. split_idxs = (batch_size,)
  255. if chunk_size is not None:
  256. chunk_sizes = get_chunk_sizes(batch_size, chunk_size)
  257. split_idxs = tuple(itertools.accumulate(chunk_sizes))
  258. flat_args_chunks = tuple(
  259. (
  260. t.tensor_split(split_idxs, dim=in_dim)
  261. if in_dim is not None
  262. else [
  263. t,
  264. ]
  265. * len(split_idxs)
  266. )
  267. for t, in_dim in zip(flat_args, flat_in_dims)
  268. )
  269. # transpose chunk dim and flatten structure
  270. # chunks_flat_args is a list of flatten args
  271. chunks_flat_args = zip(*flat_args_chunks)
  272. return chunks_flat_args
  273. def _flatten_chunks_output(chunks_output_):
  274. # chunks_output is a list of chunked outputs
  275. # flatten chunked outputs:
  276. flat_chunks_output = []
  277. arg_spec = None
  278. for output in chunks_output_:
  279. flat_output, arg_specs = tree_flatten(output)
  280. flat_chunks_output.append(flat_output)
  281. if arg_spec is None:
  282. arg_spec = arg_specs
  283. # transpose chunk dim and flatten structure
  284. # flat_output_chunks is flat list of chunks
  285. flat_output_chunks = list(zip(*flat_chunks_output))
  286. return flat_output_chunks, arg_spec
  287. def _concat_chunked_outputs(out_dims, arg_spec, flat_output_chunks):
  288. # concat chunks on out_dim
  289. flat_out_dims = _broadcast_to_and_flatten(out_dims, arg_spec)
  290. assert len(flat_out_dims) == len(flat_output_chunks)
  291. flat_output = []
  292. for idx, out_dim in enumerate(flat_out_dims):
  293. flat_output.append(torch.cat(flat_output_chunks[idx], dim=out_dim))
  294. # release tensors
  295. flat_output_chunks[idx] = None
  296. return flat_output
  297. # Applies vmap on chunked_input and returns concatenated output over the chunks.
  298. def _chunked_vmap(
  299. func, flat_in_dims, chunks_flat_args, args_spec, out_dims, randomness, **kwargs
  300. ):
  301. chunks_output = []
  302. rs = torch.get_rng_state() if randomness == "same" else None
  303. for flat_args in chunks_flat_args:
  304. batch_size = _validate_and_get_batch_size(flat_in_dims, flat_args)
  305. # The way we compute split the input in `_get_chunked_inputs`,
  306. # we may get a tensor with `0` batch-size. We skip any computation
  307. # in that case.
  308. # Eg.
  309. # >>> chunk_size = 1
  310. # >>> batch_size = 6
  311. # >>> t = torch.zeros(batch_size, 1)
  312. # >>> t.tensor_split([1, 2, 3, 4, 5, 6])
  313. # (tensor([[0.]]), tensor([[0.]]), tensor([[0.]]), tensor([[0.]]),
  314. # tensor([[0.]]), tensor([[0.]]), tensor([], size=(0, 1)))
  315. if batch_size == 0:
  316. continue
  317. if rs is not None:
  318. torch.set_rng_state(rs)
  319. chunks_output.append(
  320. _flat_vmap(
  321. func,
  322. batch_size,
  323. flat_in_dims,
  324. flat_args,
  325. args_spec,
  326. out_dims,
  327. randomness,
  328. **kwargs,
  329. )
  330. )
  331. flat_output_chunks, arg_spec = _flatten_chunks_output(chunks_output)
  332. # chunked output tensors are held by both `flat_output_chunks` and `chunks_output`.
  333. # eagerly remove the reference from `chunks_output`.
  334. del chunks_output
  335. # concat chunks on out_dim
  336. flat_output = _concat_chunked_outputs(out_dims, arg_spec, flat_output_chunks)
  337. # finally unflatten the output
  338. return tree_unflatten(flat_output, arg_spec)
  339. # Vmap refactored helper functions:
  340. def _check_randomness_arg(randomness):
  341. if randomness not in ["error", "different", "same"]:
  342. raise RuntimeError(
  343. f"Only allowed values for randomness are 'error', 'different', or 'same'. Got {randomness}"
  344. )
  345. @contextlib.contextmanager
  346. def vmap_increment_nesting(batch_size, randomness):
  347. try:
  348. vmap_level = _vmap_increment_nesting(batch_size, randomness)
  349. yield vmap_level
  350. finally:
  351. _vmap_decrement_nesting()
  352. def _flat_vmap(
  353. func, batch_size, flat_in_dims, flat_args, args_spec, out_dims, randomness, **kwargs
  354. ):
  355. with vmap_increment_nesting(batch_size, randomness) as vmap_level:
  356. batched_inputs = _create_batched_inputs(
  357. flat_in_dims, flat_args, vmap_level, args_spec
  358. )
  359. batched_outputs = func(*batched_inputs, **kwargs)
  360. return _unwrap_batched(batched_outputs, out_dims, vmap_level, batch_size, func)
  361. # `restore_vmap` is a private helper function. It is vmap but has the following
  362. # differences:
  363. # - instead of returning outputs, it returns an (outputs, out_dims) tuple.
  364. # out_dims is a pytree of same shape as outputs and contains Optional[int]
  365. # specifying where the vmapped dimension, if it exists, is in the corresponding output.
  366. # - does no validation on in_dims or inputs (vmap expects at least one Tensor to be vmapped).
  367. # restore_vmap allows for no inputs to have the vmap dimension
  368. # - does no validation on outputs (vmap expects only Tensor outputs)
  369. # restore_vmap allows for return of arbitrary outputs (not just Tensors)
  370. #
  371. # The TL;DR is that restore_vmap is more general than vmap and has a slightly
  372. # different API. The relaxations are so that we can "pause" vmap in the middle
  373. # of its execution and then "restore" it later (this is what we do in
  374. # the generate_vmap_rule=True implementation of autograd.Function).
  375. #
  376. # restore_vmap can be technically used in the implementation of vmap, but doing
  377. # that refactor is a bit technically challenging because:
  378. # - vmap couples the tensor-wrapping code with error checking
  379. # - vmap's tensor unwrapping code is in C++; we would need to rewrite part of it
  380. # in python because it overlaps with unwrap_batched
  381. def restore_vmap(func, in_dims, batch_size, randomness):
  382. def inner(*args, **kwargs):
  383. with vmap_increment_nesting(batch_size, randomness) as vmap_level:
  384. batched_inputs = wrap_batched(args, in_dims, vmap_level)
  385. batched_outputs = func(*batched_inputs, **kwargs)
  386. return unwrap_batched(batched_outputs, vmap_level)
  387. return inner
  388. def wrap_batched(args, bdims, level):
  389. flat_args, spec = tree_flatten(args)
  390. flat_bdims = _broadcast_to_and_flatten(bdims, spec)
  391. assert flat_bdims is not None
  392. result = _create_batched_inputs(flat_bdims, flat_args, level, spec)
  393. return result
  394. def unwrap_batched(args, level):
  395. flat_args, spec = tree_flatten(args)
  396. if len(flat_args) == 0:
  397. return args, ()
  398. result = [
  399. (
  400. torch._C._functorch._unwrap_batched(arg, level)
  401. if isinstance(arg, torch.Tensor)
  402. else (arg, None)
  403. )
  404. for arg in flat_args
  405. ]
  406. output, bdims = zip(*result)
  407. return tree_unflatten(output, spec), tree_unflatten(bdims, spec)