# wrappers.py
  1. # mypy: allow-untyped-defs
  2. import inspect
  3. import types
  4. import warnings
  5. from collections.abc import Sequence
  6. from functools import wraps
  7. from types import GenericAlias
  8. from typing import Callable, NamedTuple, Optional, overload, TypeVar, Union
  9. from typing_extensions import ParamSpec
  10. import torch
  11. import torch._prims_common as utils
  12. from torch._prims_common import (
  13. CustomOutParamAnnotation,
  14. ELEMENTWISE_TYPE_PROMOTION_KIND,
  15. Number,
  16. NumberType,
  17. ShapeType,
  18. TensorLike,
  19. TensorLikeType,
  20. )
  21. from torch.utils import _pytree as pytree
  22. from torch.utils._pytree import tree_flatten, tree_unflatten
# Return type of a callable being wrapped; used in the decorator signatures
# in this module (e.g. Callable[_P, _T]).
_T = TypeVar("_T")
# Parameter specification of a callable being wrapped, so decorators can be
# annotated as preserving the wrapped callable's parameters.
_P = ParamSpec("_P")
  25. @overload
  26. def _maybe_convert_to_dtype(a: TensorLikeType, dtype: torch.dtype) -> TensorLikeType:
  27. pass
  28. @overload
  29. def _maybe_convert_to_dtype(a: NumberType, dtype: torch.dtype) -> NumberType:
  30. pass
  31. @overload
  32. def _maybe_convert_to_dtype(a: Sequence, dtype: torch.dtype) -> Sequence:
  33. pass
  34. @overload
  35. def _maybe_convert_to_dtype(a: None, dtype: torch.dtype) -> None:
  36. pass
  37. # TODO: implement ref.cast with an option to enforce safe casting
  38. def _maybe_convert_to_dtype(a, dtype):
  39. if isinstance(a, TensorLike):
  40. if a.dtype != dtype:
  41. return a.to(dtype)
  42. return a
  43. if isinstance(a, Number):
  44. return utils.dtype_to_type_ctor(dtype)(a) # type: ignore[arg-type]
  45. if isinstance(a, Sequence):
  46. return tuple(_maybe_convert_to_dtype(x, dtype) for x in a)
  47. # Passthrough None because some functions wrapped with type promotion
  48. # wrapper might have optional args
  49. if a is None:
  50. return None
  51. raise ValueError(
  52. f"Received unsupported type {type(a)}. Expected TensorLike, Number, or Sequence."
  53. )
  54. def _maybe_convert_to_type(a: NumberType, typ: type) -> NumberType:
  55. if not isinstance(a, Number):
  56. msg = f"Found unknown type {type(a)} when trying to convert scalars!"
  57. raise ValueError(msg)
  58. if not utils.is_weakly_lesser_type(type(a), typ):
  59. msg = f"Scalar {a} of type {type(a)} cannot be safely cast to type {typ}!"
  60. raise ValueError(msg)
  61. return typ(a)
  62. def _annotation_has_type(*, typ, annotation):
  63. if hasattr(annotation, "__args__"):
  64. for a in annotation.__args__:
  65. if _annotation_has_type(typ=typ, annotation=a):
  66. return True
  67. return False
  68. return typ is annotation
  69. class elementwise_type_promotion_wrapper:
  70. """
  71. Adds elementwise type promotion to a Python reference implementation.
  72. Takes two kwargs, type_promoting_args and type_promotion_kind.
  73. type_promoting_args must be a string Sequence specifying the argument names of all
  74. arguments that participate in type promotion (and should be type promoted). If the
  75. arg specifies a Sequence-type then every element of the Sequence will participate in
  76. type promotion.
  77. type_promotion_kind must be one of the kinds specified by ELEMENTWISE_TYPE_PROMOTION_KIND.
  78. See its documentation for details.
  79. The return_dtype will be coerced to the wrapped function's dtype arg if it is available and
  80. not None.
  81. Other type promotion behavior, like validating the Python type of scalar arguments, must
  82. be handled separately.
  83. """
  84. def __init__(
  85. self,
  86. *,
  87. type_promotion_kind: ELEMENTWISE_TYPE_PROMOTION_KIND,
  88. type_promoting_args: Optional[Sequence[str]] = None,
  89. ):
  90. self.type_promoting_arg_names = type_promoting_args
  91. self.type_promotion_kind = type_promotion_kind
  92. def __call__(self, fn: Callable) -> Callable:
  93. sig = inspect.signature(fn)
  94. # TorchDynamo tracing of inspect causes fake tensor dynamo_wrapped tests to fail
  95. # PYTORCH_TEST_WITH_DYNAMO=1 python test/test_fake_tensor.py FakeTensorTest.test_basic
  96. @torch._disable_dynamo
  97. @wraps(fn)
  98. def _fn(*args, **kwargs):
  99. bound = sig.bind(*args, **kwargs)
  100. type_promoting_args = tuple(
  101. bound.arguments[x]
  102. for x in self.type_promoting_arg_names # type: ignore[union-attr]
  103. if x in bound.arguments.keys()
  104. )
  105. flattened_type_promoting_args = pytree.arg_tree_leaves(*type_promoting_args)
  106. compute_dtype, result_dtype = utils.elementwise_dtypes(
  107. *flattened_type_promoting_args,
  108. type_promotion_kind=self.type_promotion_kind,
  109. )
  110. promoted_args = {
  111. x: _maybe_convert_to_dtype(bound.arguments[x], compute_dtype)
  112. for x in self.type_promoting_arg_names # type: ignore[union-attr]
  113. if x in bound.arguments.keys()
  114. }
  115. bound.arguments.update(promoted_args)
  116. result = fn(**bound.arguments)
  117. # Override the return_dtype if a dtype arg is present and not None
  118. if "dtype" in bound.arguments:
  119. maybe_dtype = bound.arguments["dtype"]
  120. if maybe_dtype: # dtype cannot be None
  121. result_dtype = maybe_dtype
  122. if isinstance(result, TensorLike):
  123. return _maybe_convert_to_dtype(result, result_dtype)
  124. if isinstance(result, Sequence):
  125. return tuple(_maybe_convert_to_dtype(x, result_dtype) for x in result)
  126. raise AssertionError(f"Unhandled result type: {type(result)}")
  127. _fn.__signature__ = sig # type: ignore[attr-defined]
  128. return _fn
  129. # Returns True if resize is necessary
  130. def _resize_output_check(out: TensorLikeType, shape: ShapeType):
  131. # If the shapes are correct there's nothing to do
  132. if utils.same_shape(out.shape, shape):
  133. return False
  134. if out.numel() != 0:
  135. msg = (
  136. f"An output with one or more elements was resized since it had shape {str(out.shape)} "
  137. "which does not match the required output shape {str(shape)}. "
  138. "This behavior is deprecated, and in a future PyTorch release outputs will not "
  139. "be resized unless they have zero elements. "
  140. "You can explicitly reuse an out tensor t by resizing it, inplace, to zero elements with t.resize_(0)."
  141. )
  142. warnings.warn(msg)
  143. return True
  144. # TODO: handle tuples of tensors
  145. def _maybe_resize_out(
  146. out: TensorLikeType,
  147. shape: ShapeType,
  148. memory_format: Optional[torch.memory_format] = None,
  149. ):
  150. if _resize_output_check(out, shape):
  151. return out.resize_(shape, memory_format=memory_format)
  152. else:
  153. return out
  154. def is_cpu_scalar(x: TensorLikeType) -> bool:
  155. return x.dim() == 0 and x.device.type == "cpu"
  156. def check_copy_devices(*, copy_from: TensorLikeType, copy_to: TensorLikeType) -> None:
  157. if copy_from.device != copy_to.device:
  158. msg = (
  159. f"Attempting to copy from device {copy_from.device} "
  160. f"to device {copy_to.device}, but cross-device copies are not allowed!"
  161. )
  162. raise RuntimeError(msg)
  163. def _safe_copy_out(
  164. *, copy_from: TensorLikeType, copy_to: TensorLikeType, exact_dtype: bool = False
  165. ):
  166. # Checks same device
  167. if not is_cpu_scalar(copy_from):
  168. check_copy_devices(copy_from=copy_from, copy_to=copy_to)
  169. # Checks safe cast
  170. if exact_dtype:
  171. torch._check(
  172. copy_from.dtype == copy_to.dtype,
  173. lambda: f"Expected out tensor to have dtype {copy_from.dtype} "
  174. f"but got {copy_to.dtype} instead",
  175. )
  176. else:
  177. torch._check(
  178. utils.can_safe_cast_to(cast_from=copy_from.dtype, cast_to=copy_to.dtype),
  179. lambda: f"Attempting to cast from {copy_from.dtype} to out tensor with dtype {copy_to.dtype}, "
  180. "but this can't be cast because it is not safe!",
  181. )
  182. return copy_to.copy_(copy_from)
  183. def out_wrapper(
  184. *out_names: str,
  185. exact_dtype: bool = False,
  186. pass_is_out: bool = False,
  187. preserve_memory_format: bool = False,
  188. ) -> Callable[[Callable[_P, _T]], Callable[_P, _T]]:
  189. # The wrapped function needs to convert the output parameters to ensure
  190. # compatibility between the Python API (which always uses "out" as the
  191. # parameter name and may be a tuple) and the Aten API (which may have
  192. # multiple output parameters and use different parameter names such as
  193. # "grad_input", "indices" or "values".)
  194. default_out_names = ("out",)
  195. if len(out_names) == 0:
  196. # Use default in out name
  197. out_names = default_out_names
  198. is_tensor = len(out_names) == 1
  199. def maybe_compute_memory_format(t):
  200. return utils.suggest_memory_format(t) if preserve_memory_format else None
  201. def _out_wrapper(fn: Callable[_P, _T]) -> Callable[_P, _T]:
  202. """
  203. Adds the out parameter to a Python reference.
  204. """
  205. out_type = (
  206. TensorLikeType
  207. if is_tensor
  208. else GenericAlias(
  209. tuple, tuple(TensorLikeType for _ in range(len(out_names)))
  210. )
  211. )
  212. # For backward compatibility - should be able to remove once PEP585
  213. # conversion is complete.
  214. bc_out_type = (
  215. TensorLikeType
  216. if is_tensor
  217. else types.GenericAlias(
  218. tuple, tuple(TensorLikeType for _ in range(len(out_names)))
  219. )
  220. )
  221. return_type = (
  222. TensorLikeType
  223. if is_tensor
  224. else NamedTuple(
  225. f"return_types_{fn.__name__}", [(o, TensorLikeType) for o in out_names]
  226. )
  227. )
  228. sig = inspect.signature(fn)
  229. factory_kwargs = ("device", "dtype")
  230. is_factory_fn = all(p in sig.parameters for p in factory_kwargs)
  231. @wraps(fn)
  232. def _fn(*args: _P.args, **kwargs: _P.kwargs):
  233. out = kwargs.pop("out", None)
  234. if is_factory_fn and out is not None:
  235. for k in factory_kwargs:
  236. out_attr = getattr(out, k)
  237. if k not in kwargs:
  238. kwargs[k] = out_attr
  239. def maybe_check_copy_devices(out):
  240. if isinstance(out, TensorLike) and isinstance(args[0], TensorLike):
  241. check_copy_devices(copy_from=args[0], copy_to=out)
  242. if isinstance(out, (tuple, list)):
  243. for o in out:
  244. maybe_check_copy_devices(o)
  245. else:
  246. maybe_check_copy_devices(out)
  247. if pass_is_out:
  248. result = fn(*args, is_out=(out is not None), **kwargs) # type: ignore[arg-type]
  249. else:
  250. result = fn(*args, **kwargs)
  251. if result is NotImplemented:
  252. return NotImplemented
  253. assert (
  254. (isinstance(result, TensorLike) and is_tensor)
  255. or (
  256. isinstance(result, tuple) # type: ignore[arg-type]
  257. and len(result) == len(out_names) # type: ignore[arg-type]
  258. )
  259. or (
  260. fn.__name__ == "unbind" and isinstance(result, (list, tuple)) # type: ignore[arg-type]
  261. )
  262. )
  263. # unbind_copy is a special case: see https://github.com/pytorch/pytorch/issues/130829
  264. if out is not None:
  265. # Naively you might expect this assert to be true, but
  266. # it's not:
  267. #
  268. # assert type(out) == type(result)
  269. #
  270. # The reason is that functions under this wrapper can
  271. # get registered to the Meta dispatch key, and that
  272. # means they can be executed in a context where tensor
  273. # subclasses are disabled (with no_dispatch), which is a
  274. # handy way for an is-a tensor subclass (e.g.,
  275. # FakeTensor) to have the normal meta backend create a
  276. # meta tensor, to be wrapped once it gets returned.
  277. # In this situation, you will get a FakeTensor as
  278. # the output tensor, but not the result--which will
  279. # be a normal meta tensor, but this is perfectly
  280. # harmless.
  281. if is_tensor and fn.__name__ != "unbind":
  282. assert isinstance(out, TensorLike)
  283. # These two operations are done in-place
  284. _maybe_resize_out(
  285. out,
  286. result.shape, # type: ignore[union-attr]
  287. maybe_compute_memory_format(result),
  288. )
  289. _safe_copy_out(
  290. copy_from=result, # type: ignore[arg-type]
  291. copy_to=out,
  292. exact_dtype=exact_dtype,
  293. )
  294. else:
  295. if fn.__name__ != "unbind":
  296. assert isinstance(out, tuple) # type: ignore[arg-type]
  297. else:
  298. assert isinstance(out, (list, tuple)) # type: ignore[arg-type]
  299. torch._check_type(
  300. len(out) == len(result), # type: ignore[arg-type]
  301. lambda: f"expected tuple of {len(result)} elements but got {len(out)}", # type: ignore[arg-type]
  302. )
  303. for r, o in zip(result, out): # type: ignore[arg-type]
  304. # These two operations are done in-place
  305. _maybe_resize_out(o, r.shape, maybe_compute_memory_format(r))
  306. _safe_copy_out(copy_from=r, copy_to=o, exact_dtype=exact_dtype) # type: ignore[arg-type]
  307. else:
  308. out = result
  309. # mypy does not see through the definition of out_type given that it's in a different scope
  310. return out if is_tensor else return_type(*out) # type: ignore[operator]
  311. out_param = inspect.Parameter(
  312. "out",
  313. kind=inspect.Parameter.KEYWORD_ONLY,
  314. default=None,
  315. annotation=out_type,
  316. )
  317. # Mark that the function now returns a tuple
  318. assert isinstance(
  319. sig.return_annotation, (str, TypeVar)
  320. ) or sig.return_annotation in (
  321. sig.empty,
  322. out_type,
  323. bc_out_type,
  324. )
  325. params = *sig.parameters.values(), out_param
  326. # If there's a Parameter.VAR_KEYWORD parameter (like **kwds), it must appear
  327. # after the out= parameter, which is Parameter.KEYWORD_ONLY. Sorting by
  328. # Parameter.kind guarantees that all the parameters are in legal order.
  329. params = sorted(params, key=lambda p: p.kind)
  330. _fn.__signature__ = inspect.Signature( # type: ignore[attr-defined]
  331. parameters=params,
  332. return_annotation=return_type, # type: ignore[arg-type]
  333. )
  334. _fn.__annotations__ = dict(getattr(fn, "__annotations__", {}))
  335. _fn.__annotations__["out"] = out_type
  336. _fn.__annotations__["return"] = return_type
  337. # In the special case of having a single tensor out parameter with a
  338. # name other than out, add a special annotation to name the parameter
  339. if is_tensor and out_names != default_out_names:
  340. _fn.__annotations__[CustomOutParamAnnotation] = out_names[0]
  341. # Add an indicator attribute that can be used in special cases
  342. # where having a function wrapped by `out_wrapper` is not desirable e.g.
  343. # jit
  344. _fn._torch_decompositions_out_wrapper = ( # type: ignore[attr-defined]
  345. f"This function is wrapped by {out_wrapper.__module__}.out_wrapper"
  346. )
  347. return _fn
  348. return _out_wrapper
  349. def _maybe_remove_out_wrapper(fn: Callable):
  350. return inspect.unwrap(
  351. fn,
  352. stop=lambda f: not hasattr(f, "_torch_decompositions_out_wrapper"),
  353. )
  354. def backwards_not_supported(prim):
  355. def redispatch_prim(args, kwargs):
  356. with torch._C._AutoDispatchBelowAutograd():
  357. return prim(*args, **kwargs)
  358. class BackwardsNotSupported(torch.autograd.Function):
  359. @staticmethod
  360. def forward(ctx, args_spec, *flat_args):
  361. args, kwargs = tree_unflatten(flat_args, args_spec) # type: ignore[arg-type]
  362. return redispatch_prim(args, kwargs)
  363. @staticmethod
  364. def backward(ctx, *args):
  365. raise RuntimeError("backwards not supported on prim")
  366. @wraps(prim)
  367. def _autograd_impl(*args, **kwargs):
  368. flat_args, args_spec = tree_flatten((args, kwargs))
  369. if torch.is_grad_enabled() and any(
  370. a.requires_grad for a in flat_args if isinstance(a, torch.Tensor)
  371. ):
  372. # TODO: There is a subtle bug here: prims like copy_to
  373. # return their input argument after mutating it; and custom
  374. # autograd function will incorrectly turn the result into
  375. # a view which will fail test_python_ref_executor tests.
  376. # At the moment, we sidestep this by observing that the
  377. # unit tests don't ever try to run the executor with
  378. # autograd, so we don't exercise the buggy case, but if
  379. # you ever want to feed autograd through this, be aware
  380. # of it! We need a way of properly implementing autograd
  381. # for mutating operations in Python to do this.
  382. return BackwardsNotSupported.apply(args_spec, *flat_args)
  383. else:
  384. return redispatch_prim(args, kwargs)
  385. return _autograd_impl
  386. # TODO: when tracing this will add torch tensors and not TensorMeta objects
  387. # to the trace -- we should fix this by adding a tracing context and NumberMeta classes
  388. # TODO: this wrapper is currently untested
  389. def elementwise_unary_scalar_wrapper(
  390. fn: Callable[_P, _T],
  391. ) -> Callable[_P, Union[_T, NumberType]]:
  392. """
  393. Allows unary operators that accept tensors to work with Python numbers.
  394. """
  395. sig = inspect.signature(fn)
  396. @wraps(fn)
  397. def _fn(*args, **kwargs):
  398. if len(args) > 0 and isinstance(args[0], Number):
  399. dtype = utils.type_to_dtype(type(args[0]))
  400. args_ = list(args)
  401. args_[0] = torch.tensor(args[0], dtype=dtype)
  402. result = fn(*args_, **kwargs)
  403. assert isinstance(result, torch.Tensor)
  404. return result.item()
  405. return fn(*args, **kwargs)
  406. _fn.__signature__ = sig # type: ignore[attr-defined]
  407. return _fn