_vmap_internals.py 9.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245
  1. # mypy: allow-untyped-defs
  2. import functools
  3. from typing import Any, Callable, Optional, Union
  4. from typing_extensions import deprecated
  5. import torch
  6. from torch import Tensor
  7. from torch.utils._pytree import _broadcast_to_and_flatten, tree_flatten, tree_unflatten
  8. in_dims_t = Union[int, tuple]
  9. out_dims_t = Union[int, tuple[int, ...]]
  10. # Checks that all args-to-be-batched have the same batch dim size
  11. def _validate_and_get_batch_size(
  12. flat_in_dims: list[Optional[int]],
  13. flat_args: list,
  14. ) -> int:
  15. batch_sizes = [
  16. arg.size(in_dim)
  17. for in_dim, arg in zip(flat_in_dims, flat_args)
  18. if in_dim is not None
  19. ]
  20. if batch_sizes and any(size != batch_sizes[0] for size in batch_sizes):
  21. raise ValueError(
  22. f"vmap: Expected all tensors to have the same size in the mapped "
  23. f"dimension, got sizes {batch_sizes} for the mapped dimension"
  24. )
  25. return batch_sizes[0]
  26. def _num_outputs(batched_outputs: Union[Tensor, tuple[Tensor, ...]]) -> int:
  27. if isinstance(batched_outputs, tuple):
  28. return len(batched_outputs)
  29. return 1
  30. # If value is a tuple, check it has length `num_elements`.
  31. # If value is not a tuple, make a tuple with `value` repeated `num_elements` times
  32. def _as_tuple(
  33. value: Any,
  34. num_elements: int,
  35. error_message_lambda: Callable[[], str],
  36. ) -> tuple:
  37. if not isinstance(value, tuple):
  38. return (value,) * num_elements
  39. if len(value) != num_elements:
  40. raise ValueError(error_message_lambda())
  41. return value
  42. # Creates BatchedTensors for every Tensor in arg that should be batched.
  43. # Returns the (potentially) batched arguments and the batch_size.
  44. def _create_batched_inputs(
  45. in_dims: in_dims_t,
  46. args: tuple,
  47. vmap_level: int,
  48. func: Callable,
  49. ) -> tuple[tuple, int]:
  50. if not isinstance(in_dims, int) and not isinstance(in_dims, tuple):
  51. raise ValueError(
  52. f"vmap({_get_name(func)}, in_dims={in_dims}, ...)(<inputs>): "
  53. f"expected `in_dims` to be int or a (potentially nested) tuple "
  54. f"matching the structure of inputs, got: {type(in_dims)}."
  55. )
  56. if len(args) == 0:
  57. raise ValueError(
  58. f"vmap({_get_name(func)})(<inputs>): got no inputs. Maybe you forgot to add "
  59. f"inputs, or you are trying to vmap over a function with no inputs. "
  60. f"The latter is unsupported."
  61. )
  62. flat_args, args_spec = tree_flatten(args)
  63. flat_in_dims = _broadcast_to_and_flatten(in_dims, args_spec)
  64. if flat_in_dims is None:
  65. raise ValueError(
  66. f"vmap({_get_name(func)}, in_dims={in_dims}, ...)(<inputs>): "
  67. f"in_dims is not compatible with the structure of `inputs`. "
  68. f"in_dims has structure {tree_flatten(in_dims)[1]} but inputs "
  69. f"has structure {args_spec}."
  70. )
  71. for arg, in_dim in zip(flat_args, flat_in_dims):
  72. if not isinstance(in_dim, int) and in_dim is not None:
  73. raise ValueError(
  74. f"vmap({_get_name(func)}, in_dims={in_dims}, ...)(<inputs>): "
  75. f"Got in_dim={in_dim} for an input but in_dim must be either "
  76. f"an integer dimension or None."
  77. )
  78. if isinstance(in_dim, int) and not isinstance(arg, Tensor):
  79. raise ValueError(
  80. f"vmap({_get_name(func)}, in_dims={in_dims}, ...)(<inputs>): "
  81. f"Got in_dim={in_dim} for an input but the input is of type "
  82. f"{type(arg)}. We cannot vmap over non-Tensor arguments, "
  83. f"please use None as the respective in_dim"
  84. )
  85. if in_dim is not None and (in_dim < 0 or in_dim >= arg.dim()):
  86. raise ValueError(
  87. f"vmap({_get_name(func)}, in_dims={in_dims}, ...)(<inputs>): "
  88. f"Got in_dim={in_dim} for some input, but that input is a Tensor "
  89. f"of dimensionality {arg.dim()} so expected in_dim to satisfy "
  90. f"0 <= in_dim < {arg.dim()}."
  91. )
  92. batch_size = _validate_and_get_batch_size(flat_in_dims, flat_args)
  93. # See NOTE [Ignored _remove_batch_dim, _add_batch_dim]
  94. batched_inputs = [
  95. arg if in_dim is None else torch._add_batch_dim(arg, in_dim, vmap_level)
  96. for in_dim, arg in zip(flat_in_dims, flat_args)
  97. ]
  98. return tree_unflatten(batched_inputs, args_spec), batch_size
  99. # Undos the batching (and any batch dimensions) associated with the `vmap_level`.
  100. def _unwrap_batched(
  101. batched_outputs: Union[Tensor, tuple[Tensor, ...]],
  102. out_dims: out_dims_t,
  103. vmap_level: int,
  104. batch_size: int,
  105. func: Callable,
  106. allow_none_pass_through: bool = False,
  107. ) -> tuple:
  108. num_outputs = _num_outputs(batched_outputs)
  109. out_dims_as_tuple = _as_tuple(
  110. out_dims,
  111. num_outputs,
  112. lambda: f"vmap({_get_name(func)}, ..., out_dims={out_dims}): `out_dims` must "
  113. f"have one dim per output (got {num_outputs} outputs) of {_get_name(func)}.",
  114. )
  115. # NOTE [Ignored _remove_batch_dim, _add_batch_dim]
  116. # There is something wrong with our type bindings for functions that begin
  117. # with '_', see #40397.
  118. if isinstance(batched_outputs, Tensor):
  119. out_dim = out_dims_as_tuple[0]
  120. return torch._remove_batch_dim(batched_outputs, vmap_level, batch_size, out_dim) # type: ignore[return-value]
  121. if allow_none_pass_through:
  122. return tuple(
  123. (
  124. torch._remove_batch_dim(out, vmap_level, batch_size, out_dim)
  125. if out is not None
  126. else None
  127. )
  128. for out, out_dim in zip(batched_outputs, out_dims_as_tuple)
  129. )
  130. else:
  131. return tuple(
  132. torch._remove_batch_dim(out, vmap_level, batch_size, out_dim)
  133. for out, out_dim in zip(batched_outputs, out_dims_as_tuple)
  134. )
  135. # Checks that `fn` returned one or more Tensors and nothing else.
  136. # NB: A python function that return multiple arguments returns a single tuple,
  137. # so we are effectively checking that `outputs` is a single Tensor or a tuple of
  138. # Tensors.
  139. def _validate_outputs(outputs: Any, func: Callable) -> None:
  140. if isinstance(outputs, Tensor):
  141. return
  142. if not isinstance(outputs, tuple):
  143. raise ValueError(
  144. f"vmap({_get_name(func)}, ...): `{_get_name(func)}` must only return "
  145. f"Tensors, got type {type(outputs)} as the return."
  146. )
  147. for idx, output in enumerate(outputs):
  148. if isinstance(output, Tensor):
  149. continue
  150. raise ValueError(
  151. f"vmap({_get_name(func)}, ...): `{_get_name(func)}` must only return "
  152. f"Tensors, got type {type(output)} for return {idx}."
  153. )
  154. def _check_out_dims_is_int_or_int_tuple(out_dims: out_dims_t, func: Callable) -> None:
  155. if isinstance(out_dims, int):
  156. return
  157. if not isinstance(out_dims, tuple) or not all(
  158. isinstance(out_dim, int) for out_dim in out_dims
  159. ):
  160. raise ValueError(
  161. f"vmap({_get_name(func)}, ..., out_dims={out_dims}): `out_dims` must be "
  162. f"an int or a tuple of int representing where in the outputs the "
  163. f"vmapped dimension should appear."
  164. )
  165. def _get_name(func: Callable):
  166. if hasattr(func, "__name__"):
  167. return func.__name__
  168. # Not all callables have __name__, in fact, only static functions/methods do.
  169. # A callable created via functools.partial or an nn.Module, to name some
  170. # examples, don't have a __name__.
  171. return repr(func)
  172. # vmap(func)(inputs) wraps all Tensor inputs to be batched in BatchedTensors,
  173. # sends those into func, and then unwraps the output BatchedTensors. Operations
  174. # on BatchedTensors perform the batched operations that the user is asking for.
  175. @deprecated(
  176. "Please use `torch.vmap` instead of `torch._vmap_internals.vmap`.",
  177. category=FutureWarning,
  178. )
  179. def vmap(func: Callable, in_dims: in_dims_t = 0, out_dims: out_dims_t = 0) -> Callable:
  180. """
  181. Please use torch.vmap instead of this API.
  182. """
  183. return _vmap(func, in_dims, out_dims)
  184. # A version of vmap but without the initial "experimental prototype" warning
  185. def _vmap(
  186. func: Callable,
  187. in_dims: in_dims_t = 0,
  188. out_dims: out_dims_t = 0,
  189. allow_none_pass_through: bool = False,
  190. ) -> Callable:
  191. # The `allow_none_pass_through` argument is a temporary workaround may be removed.
  192. # Currently it enables us to wrap the call in `autograd.grad` to the autograd engine,
  193. # which may return None if any of the inputs are unused. See the issue discussing this:
  194. # https://github.com/pytorch/functorch/issues/159.
  195. @functools.wraps(func)
  196. def wrapped(*args):
  197. _check_out_dims_is_int_or_int_tuple(out_dims, func)
  198. vmap_level = torch._C._vmapmode_increment_nesting()
  199. try:
  200. batched_inputs, batch_size = _create_batched_inputs(
  201. in_dims, args, vmap_level, func
  202. )
  203. batched_outputs = func(*batched_inputs)
  204. if not allow_none_pass_through:
  205. _validate_outputs(batched_outputs, func)
  206. return _unwrap_batched(
  207. batched_outputs,
  208. out_dims,
  209. vmap_level,
  210. batch_size,
  211. func,
  212. allow_none_pass_through=allow_none_pass_through,
  213. )
  214. finally:
  215. torch._C._vmapmode_decrement_nesting()
  216. return wrapped