# custom_ops.py
  1. # mypy: allow-untyped-defs
  2. import collections
  3. import inspect
  4. import logging
  5. import weakref
  6. from collections.abc import Iterable, Sequence
  7. from contextlib import contextmanager
  8. from typing import Any, Callable, Literal, Optional, overload, Union
  9. import torch
  10. from torch import _C, _ops, Tensor
  11. from torch.types import _dtype
  12. from torch.utils._exposed_in import exposed_in
  13. from . import autograd, utils
  14. device_types_t = Optional[Union[str, Sequence[str]]]
  15. log = logging.getLogger(__name__)
  16. @overload
  17. def custom_op(
  18. name: str,
  19. fn: Literal[None] = None,
  20. /,
  21. *,
  22. mutates_args: Union[str, Iterable[str]],
  23. device_types: device_types_t = None,
  24. schema: Optional[str] = None,
  25. ) -> Callable[[Callable[..., object]], "CustomOpDef"]: ...
  26. @overload
  27. def custom_op(
  28. name: str,
  29. fn: Callable[..., object],
  30. /,
  31. *,
  32. mutates_args: Union[str, Iterable[str]],
  33. device_types: device_types_t = None,
  34. schema: Optional[str] = None,
  35. ) -> "CustomOpDef": ...
  36. @exposed_in("torch.library")
  37. def custom_op(
  38. name: str,
  39. fn: Optional[Callable] = None,
  40. /,
  41. *,
  42. mutates_args: Union[str, Iterable[str]],
  43. device_types: device_types_t = None,
  44. schema: Optional[str] = None,
  45. tags: Optional[Sequence[_C.Tag]] = None,
  46. ) -> Union[Callable[[Callable[..., object]], "CustomOpDef"], "CustomOpDef"]:
  47. """Wraps a function into custom operator.
  48. Reasons why you may want to create a custom op include:
  49. - Wrapping a third-party library or custom kernel to work with PyTorch
  50. subsystems like Autograd.
  51. - Preventing torch.compile/export/FX tracing from peeking inside your function.
  52. This API is used as a decorator around a function (please see examples).
  53. The provided function must have type hints; these are needed to interface
  54. with PyTorch's various subsystems.
  55. Args:
  56. name (str): A name for the custom op that looks like "{namespace}::{name}",
  57. e.g. "mylib::my_linear". The name is used as the op's stable identifier
  58. in PyTorch subsystems (e.g. torch.export, FX graphs).
  59. To avoid name collisions, please use your project name as the namespace;
  60. e.g. all custom ops in pytorch/fbgemm use "fbgemm" as the namespace.
  61. mutates_args (Iterable[str] or "unknown"): The names of args that the function mutates.
  62. This MUST be accurate, otherwise, the behavior is undefined. If "unknown",
  63. it pessimistically assumes that all inputs to the operator are being mutated.
  64. device_types (None | str | Sequence[str]): The device type(s) the function
  65. is valid for. If no device type is provided, then the function
  66. is used as the default implementation for all device types.
  67. Examples: "cpu", "cuda".
  68. When registering a device-specific implementation for an operator that accepts no Tensors,
  69. we require the operator to have a "device: torch.device argument".
  70. schema (None | str): A schema string for the operator. If None
  71. (recommended) we'll infer a schema for the operator from its type
  72. annotations. We recommend letting us infer a schema unless you
  73. have a specific reason not to.
  74. Example: "(Tensor x, int y) -> (Tensor, Tensor)".
  75. .. note::
  76. We recommend not passing in a ``schema`` arg and instead letting us infer
  77. it from the type annotations. It is error-prone to write your own schema.
  78. You may wish to provide your own schema if our interpretation of
  79. the type annotation is not what you want.
  80. For more info on how to write a schema string, see
  81. `here <https://github.com/pytorch/pytorch/blob/main/aten/src/ATen/native/README.md#func>`_
  82. Examples::
  83. >>> import torch
  84. >>> from torch import Tensor
  85. >>> from torch.library import custom_op
  86. >>> import numpy as np
  87. >>>
  88. >>> @custom_op("mylib::numpy_sin", mutates_args=())
  89. >>> def numpy_sin(x: Tensor) -> Tensor:
  90. >>> x_np = x.cpu().numpy()
  91. >>> y_np = np.sin(x_np)
  92. >>> return torch.from_numpy(y_np).to(device=x.device)
  93. >>>
  94. >>> x = torch.randn(3)
  95. >>> y = numpy_sin(x)
  96. >>> assert torch.allclose(y, x.sin())
  97. >>>
  98. >>> # Example of a custom op that only works for one device type.
  99. >>> @custom_op("mylib::numpy_sin_cpu", mutates_args=(), device_types="cpu")
  100. >>> def numpy_sin_cpu(x: Tensor) -> Tensor:
  101. >>> x_np = x.numpy()
  102. >>> y_np = np.sin(x_np)
  103. >>> return torch.from_numpy(y_np)
  104. >>>
  105. >>> x = torch.randn(3)
  106. >>> y = numpy_sin_cpu(x)
  107. >>> assert torch.allclose(y, x.sin())
  108. >>>
  109. >>> # Example of a custom op that mutates an input
  110. >>> @custom_op("mylib::numpy_sin_inplace", mutates_args={"x"}, device_types="cpu")
  111. >>> def numpy_sin_inplace(x: Tensor) -> None:
  112. >>> x_np = x.numpy()
  113. >>> np.sin(x_np, out=x_np)
  114. >>>
  115. >>> x = torch.randn(3)
  116. >>> expected = x.sin()
  117. >>> numpy_sin_inplace(x)
  118. >>> assert torch.allclose(x, expected)
  119. >>>
  120. >>> # Example of a factory function
  121. >>> @torch.library.custom_op("mylib::bar", mutates_args={}, device_types="cpu")
  122. >>> def bar(device: torch.device) -> Tensor:
  123. >>> return torch.ones(3)
  124. >>>
  125. >>> bar("cpu")
  126. """
  127. def inner(fn: Callable[..., object]) -> CustomOpDef:
  128. import torch
  129. if schema is None:
  130. schema_str = torch.library.infer_schema(fn, mutates_args=mutates_args)
  131. else:
  132. schema_str = schema
  133. namespace, opname = name.split("::")
  134. result = CustomOpDef(namespace, opname, schema_str, fn, tags)
  135. if schema is not None:
  136. # Check that schema's alias annotations match those of `mutates_args`.
  137. expected = set()
  138. for arg in result._opoverload._schema.arguments:
  139. if arg.alias_info is not None and arg.alias_info.is_write:
  140. expected.add(arg.name)
  141. if expected != set(mutates_args):
  142. raise ValueError(
  143. f"Attempted to create a custom op with `mutates_args={mutates_args}` "
  144. f"and `schema={schema}. The schema suggests that the op mutates {expected}"
  145. f"which is different from what was provided to us in `mutates_args`. "
  146. f"Please make these consistent."
  147. )
  148. result.register_kernel(device_types)(fn)
  149. return result
  150. if fn is None:
  151. return inner
  152. return inner(fn)
  153. class CustomOpDef:
  154. """CustomOpDef is a wrapper around a function that turns it into a custom op.
  155. It has various methods for registering additional behavior for this
  156. custom op.
  157. You should not instantiate CustomOpDef directly; instead, use the
  158. :func:`torch.library.custom_op` API.
  159. """
  160. def __init__(
  161. self,
  162. namespace: str,
  163. name: str,
  164. schema: str,
  165. fn: Callable,
  166. tags: Optional[Sequence[_C.Tag]] = None,
  167. ) -> None:
  168. # Fields used to interface with the PyTorch dispatcher
  169. self._namespace = namespace
  170. self._name = name
  171. self._schema = schema
  172. self._tags = tags if tags is not None else []
  173. self._init_fn = fn
  174. self._backend_fns: dict[Union[str, None], Callable] = {}
  175. self._abstract_fn: Optional[Callable] = None
  176. self._setup_context_fn: Optional[Callable] = None
  177. self._backward_fn: Optional[Callable] = None
  178. self._torch_dispatch_fns: dict[type, Callable] = {}
  179. self._vmap_fn: Optional[Callable] = None
  180. self._autocast_cuda_dtype: Optional[_dtype] = None
  181. self._autocast_cpu_dtype: Optional[_dtype] = None
  182. self._lib = get_library_allowing_overwrite(self._namespace, self._name)
  183. self._register_to_dispatcher(self._tags)
  184. self._disabled_kernel: set = set()
  185. self._used_triton_kernels: list[Any] = list()
  186. OPDEFS[self._qualname] = self
  187. @property
  188. def _qualname(self) -> str:
  189. return f"{self._namespace}::{self._name}"
  190. def __repr__(self) -> str:
  191. return f"<CustomOpDef({self._qualname})>"
  192. @contextmanager
  193. def set_kernel_enabled(self, device_type: str, enabled: bool = True):
  194. """
  195. Disable or re-enable an already registered kernel for this custom operator.
  196. If the kernel is already disabled/enabled, this is a no-op.
  197. Note:
  198. If a kernel is first disabled and then registered, it is disabled until enabled again.
  199. Args:
  200. device_type (str): The device type to disable/enable the kernel for.
  201. disable (bool): Whether to disable or enable the kernel.
  202. Example:
  203. >>> inp = torch.randn(1)
  204. >>>
  205. >>> # define custom op `f`.
  206. >>> @custom_op("mylib::f", mutates_args=())
  207. >>> def f(x: Tensor) -> Tensor:
  208. >>> return torch.zeros(1)
  209. >>>
  210. >>> print(f(inp)) # tensor([0.]), default kernel
  211. >>>
  212. >>> @f.register_kernel("cpu")
  213. >>> def _(x):
  214. >>> return torch.ones(1)
  215. >>>
  216. >>> print(f(inp)) # tensor([1.]), CPU kernel
  217. >>>
  218. >>> # temporarily disable the CPU kernel
  219. >>> with f.set_kernel_enabled("cpu", enabled = False):
  220. >>> print(f(inp)) # tensor([0.]) with CPU kernel disabled
  221. """
  222. action = "enable" if enabled else "disable"
  223. originally_disabled = device_type in self._disabled_kernel
  224. if device_type not in self._backend_fns:
  225. log.warning(
  226. "Attempted to %s kernel for %s but no kernel was registered for this device type.",
  227. action,
  228. device_type,
  229. )
  230. if not enabled:
  231. if originally_disabled:
  232. log.warning(
  233. "Attempted to disable kernel for %s but it was already disabled.",
  234. device_type,
  235. )
  236. else:
  237. self._disabled_kernel.add(device_type)
  238. else: # enable the kernel
  239. if not originally_disabled:
  240. log.warning(
  241. "Attempted to enable kernel for %s but it was already enabled.",
  242. device_type,
  243. )
  244. else:
  245. self._disabled_kernel.remove(device_type)
  246. try:
  247. yield
  248. finally:
  249. # restore original state
  250. if originally_disabled:
  251. self._disabled_kernel.add(device_type)
  252. else:
  253. self._disabled_kernel.discard(device_type)
  254. def register_kernel(
  255. self, device_types: device_types_t, fn: Optional[Callable] = None, /
  256. ) -> Callable:
  257. """Register an implementation for a device type for this operator.
  258. Some valid device_types are: "cpu", "cuda", "xla", "mps", "ipu", "xpu".
  259. This API may be used as a decorator.
  260. Args:
  261. fn (Callable): The function to register as the implementation for
  262. the given device types.
  263. device_types (str | Sequence[str]): The device device_types to register an impl to.
  264. Examples::
  265. >>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_CUDA)
  266. >>> import torch
  267. >>> from torch import Tensor
  268. >>> from torch.library import custom_op
  269. >>> import numpy as np
  270. >>>
  271. >>> # Create a custom op that works on cpu
  272. >>> @custom_op("mylib::numpy_sin", mutates_args=(), device_types="cpu")
  273. >>> def numpy_sin(x: Tensor) -> Tensor:
  274. >>> x_np = x.numpy()
  275. >>> y_np = np.sin(x_np)
  276. >>> return torch.from_numpy(y_np)
  277. >>>
  278. >>> # Add implementations for the cuda device
  279. >>> @numpy_sin.register_kernel("cuda")
  280. >>> def _(x):
  281. >>> x_np = x.cpu().numpy()
  282. >>> y_np = np.sin(x_np)
  283. >>> return torch.from_numpy(y_np).to(device=x.device)
  284. >>>
  285. >>> x_cpu = torch.randn(3)
  286. >>> x_cuda = x_cpu.cuda()
  287. >>> assert torch.allclose(numpy_sin(x_cpu), x_cpu.sin())
  288. >>> assert torch.allclose(numpy_sin(x_cuda), x_cuda.sin())
  289. """
  290. def inner(fn):
  291. if device_types is None or isinstance(device_types, str):
  292. dtypes: list[Union[str, None]] = [device_types]
  293. else:
  294. dtypes = list(device_types)
  295. for device_type in dtypes:
  296. if device_type not in self._backend_fns:
  297. def backend_impl(*args, **kwargs):
  298. result = self._backend_fns[device_type](*args, **kwargs)
  299. def get_module():
  300. fn = self._backend_fns[device_type]
  301. return inspect.getmodule(fn)
  302. utils._c_check_aliasing_constraint(
  303. self._name,
  304. args,
  305. kwargs,
  306. result,
  307. get_module,
  308. )
  309. return result
  310. if device_type is None:
  311. self._lib.impl(
  312. self._name, backend_impl, "CompositeExplicitAutograd"
  313. )
  314. else:
  315. self._lib.impl(
  316. self._name,
  317. backend_impl,
  318. _C._dispatch_key_for_device(device_type),
  319. )
  320. # Wrap function to choose between the default implementation or the device-specific
  321. # implementation depending on if the kernel is disabled.
  322. @torch._disable_dynamo
  323. def wrapped_fn(*args, **kwargs):
  324. if device_type in self._disabled_kernel:
  325. return self._init_fn(*args, **kwargs)
  326. else:
  327. return fn(*args, **kwargs)
  328. self._backend_fns[device_type] = wrapped_fn
  329. return fn
  330. if device_types is not None and not utils.has_tensor_arg(
  331. self._opoverload._schema
  332. ):
  333. device_arg_index = utils.get_device_arg_index(self._opoverload._schema)
  334. if device_arg_index is None:
  335. raise ValueError(
  336. "Functions without tensor inputs are required to have a `device: torch.device` argument"
  337. )
  338. self._register_backend_select_dispatcher(device_arg_index)
  339. # See NOTE: [Supporting decorator and non-decorator usage]
  340. if fn is None:
  341. return inner
  342. return inner(fn)
  343. def register_fake(self, fn: Callable, /) -> Callable:
  344. r"""Register a FakeTensor implementation for this custom op.
  345. This is necessary to get the operator to work efficiently with torch.compile.
  346. The Fake impl (sometimes also known as a meta kernel or abstract impl)
  347. specifies the behavior of this operator on Tensors that carry no data.
  348. Given some input Tensors with certain properties
  349. (sizes/strides/storage_offset/device), it specifies what the properties of
  350. the output Tensors are.
  351. Please see :func:`torch.library.register_fake` for more details.
  352. Args:
  353. fn (Callable): The function to register as the FakeTensor
  354. implementation.
  355. Examples:
  356. >>> import torch
  357. >>> import numpy as np
  358. >>> from torch import Tensor
  359. >>>
  360. >>> # Example 1: an operator without data-dependent output shape
  361. >>> @torch.library.custom_op("mylib::linear", mutates_args=())
  362. >>> def linear(x: Tensor, weight: Tensor, bias: Tensor) -> Tensor:
  363. >>> return (x @ weight.t()) + bias
  364. >>>
  365. >>> @linear.register_fake
  366. >>> def _(x, weight, bias):
  367. >>> assert x.dim() == 2
  368. >>> assert weight.dim() == 2
  369. >>> assert bias.dim() == 1
  370. >>> assert x.shape[1] == weight.shape[1]
  371. >>> assert weight.shape[0] == bias.shape[0]
  372. >>> assert x.device == weight.device
  373. >>> return x.new_empty(x.size(0), weight.size(0))
  374. >>>
  375. >>> x = torch.randn(2, 2)
  376. >>> weight = torch.randn(2, 2)
  377. >>> bias = torch.randn(2)
  378. >>> # xdoctest: +SKIP("Requires Python <= 3.11")
  379. >>> out = torch.compile(linear, fullgraph=True)(x, weight, bias)
  380. >>> # xdoctest: +SKIP("Requires Python <= 3.11")
  381. >>> assert torch.allclose(out, torch.nn.functional.linear(x, weight, bias))
  382. >>>
  383. >>> # Example 2: an operator with data-dependent output shape
  384. >>> @torch.library.custom_op("mylib::nonzero", mutates_args=())
  385. >>> def nonzero(x: Tensor) -> Tensor:
  386. >>> x_np = x.cpu().numpy()
  387. >>> res = np.stack(np.nonzero(x_np), axis=1)
  388. >>> return torch.tensor(res, device=x.device)
  389. >>>
  390. >>> @nonzero.register_fake
  391. >>> def _(x):
  392. >>> # Number of nonzero-elements is data-dependent.
  393. >>> # Since we cannot peek at the data in an abstract impl,
  394. >>> # we use the ctx object to construct a new symint that
  395. >>> # represents the data-dependent size.
  396. >>> ctx = torch.library.get_ctx()
  397. >>> nnz = ctx.new_dynamic_size()
  398. >>> shape = [nnz, x.dim()]
  399. >>> result = x.new_empty(shape, dtype=torch.int64)
  400. >>> return result
  401. >>>
  402. >>> x = torch.tensor([0, 1, 2, 0, 0, 1])
  403. >>> # xdoctest: +SKIP("Requires Python <= 3.11")
  404. >>> out = torch.compile(nonzero, fullgraph=True)(x)
  405. >>> # xdoctest: +SKIP("Requires Python <= 3.11")
  406. >>> assert torch.allclose(out, x.nonzero())
  407. """
  408. self._abstract_fn = fn
  409. return fn
  410. def register_torch_dispatch(
  411. self, torch_dispatch_class: Any, fn: Optional[Callable] = None, /
  412. ) -> Callable:
  413. r"""Registers a torch_dispatch rule for the given operator and ``torch_dispatch_class``.
  414. This allows for open registration to specify the behavior between the operator
  415. and the ``torch_dispatch_class`` without needing to modify the ``torch_dispatch_class``
  416. or the operator directly.
  417. Please see :func:`torch.library.register_torch_dispatch` for examples and more details.
  418. """
  419. def register(fn):
  420. if torch_dispatch_class not in self._torch_dispatch_fns:
  421. def inner(*args, **kwargs):
  422. return self._torch_dispatch_fns[torch_dispatch_class](
  423. *args, **kwargs
  424. )
  425. self._lib._register_torch_dispatch_rule(
  426. self._name, torch_dispatch_class, inner
  427. )
  428. self._torch_dispatch_fns[torch_dispatch_class] = fn
  429. return fn
  430. if fn is None:
  431. return register
  432. else:
  433. return register(fn)
  434. def register_autograd(
  435. self,
  436. backward: Callable,
  437. /,
  438. *,
  439. setup_context: Optional[Callable] = None,
  440. ) -> None:
  441. r"""Register a backward formula for this custom op.
  442. In order for an operator to work with autograd, you need to register
  443. a backward formula:
  444. 1. You must tell us how to compute gradients during the backward pass
  445. by providing us a "backward" function.
  446. 2. If you need any values from the forward to compute gradients, you can
  447. use `setup_context` to save values for backward.
  448. ``backward_fn`` runs during the backward pass. It accepts ``(ctx, *grads)``:
  449. - ``grads`` is one or more gradients. The number of gradients matches
  450. the number of outputs of the operator.
  451. The ``ctx`` object is `the same ctx object <context_method_mixins>`_ used by
  452. :class:`torch.autograd.Function`. The semantics of ``backward_fn`` are the
  453. same as :meth:`torch.autograd.Function.backward`.
  454. ``setup_context(ctx, inputs, output)`` runs during the forward pass.
  455. Please save quantities needed for backward onto the ``ctx`` object via
  456. either :meth:`torch.autograd.function.FunctionCtx.save_for_backward`
  457. or assigning them as attributes of ``ctx``. If your custom op has
  458. kwarg-only arguments, we expect the signature of ``setup_context``
  459. to be ``setup_context(ctx, inputs, keyword_only_inputs, output)``.
  460. Both ``setup_context_fn`` and ``backward_fn`` must be traceable. That is,
  461. they may not directly access :meth:`torch.Tensor.data_ptr` and they must
  462. not depend on or mutate global state. If you need a non-traceable backward,
  463. you can make it a separate custom_op that you call inside ``backward_fn``.
  464. If you need different autograd behavior on different devices, then we
  465. recommend creating two different custom operators, one for each device
  466. that needs different behavior, and switching between them at runtime.
  467. Examples:
  468. >>> import torch
  469. >>> import numpy as np
  470. >>> from torch import Tensor
  471. >>>
  472. >>> @torch.library.custom_op("mylib::numpy_sin", mutates_args=())
  473. >>> def numpy_sin(x: Tensor) -> Tensor:
  474. >>> x_np = x.cpu().numpy()
  475. >>> y_np = np.sin(x_np)
  476. >>> return torch.from_numpy(y_np).to(device=x.device)
  477. >>>
  478. >>> def setup_context(ctx, inputs, output) -> Tensor:
  479. >>> x, = inputs
  480. >>> ctx.save_for_backward(x)
  481. >>>
  482. >>> def backward(ctx, grad):
  483. >>> x, = ctx.saved_tensors
  484. >>> return grad * x.cos()
  485. >>>
  486. >>> numpy_sin.register_autograd(backward, setup_context=setup_context)
  487. >>>
  488. >>> x = torch.randn(3, requires_grad=True)
  489. >>> y = numpy_sin(x)
  490. >>> (grad_x,) = torch.autograd.grad(y, x, torch.ones_like(y))
  491. >>> assert torch.allclose(grad_x, x.cos())
  492. >>>
  493. >>> # Example with a keyword-only arg
  494. >>> @torch.library.custom_op("mylib::numpy_mul", mutates_args=())
  495. >>> def numpy_mul(x: Tensor, *, val: float) -> Tensor:
  496. >>> x_np = x.cpu().numpy()
  497. >>> y_np = x_np * val
  498. >>> return torch.from_numpy(y_np).to(device=x.device)
  499. >>>
  500. >>> def setup_context(ctx, inputs, keyword_only_inputs, output) -> Tensor:
  501. >>> ctx.val = keyword_only_inputs["val"]
  502. >>>
  503. >>> def backward(ctx, grad):
  504. >>> return grad * ctx.val
  505. >>>
  506. >>> numpy_mul.register_autograd(backward, setup_context=setup_context)
  507. >>>
  508. >>> x = torch.randn(3, requires_grad=True)
  509. >>> y = numpy_mul(x, val=3.14)
  510. >>> (grad_x,) = torch.autograd.grad(y, x, torch.ones_like(y))
  511. >>> assert torch.allclose(grad_x, torch.full_like(x, 3.14))
  512. """
  513. schema = self._opoverload._schema
  514. if not utils.is_functional_schema(schema):
  515. raise RuntimeError(
  516. f"Cannot register autograd formula for non-functional operator "
  517. f"{self} with schema {schema}. Please create "
  518. f"a functional operator and register an autograd formula for that."
  519. )
  520. self._backward_fn = backward
  521. self._setup_context_fn = setup_context
  522. def _register_to_dispatcher(self, tags: Sequence[_C.Tag]) -> None:
  523. lib = self._lib
  524. schema_str = self._name + self._schema
  525. cpp_schema = _C.parse_schema(schema_str)
  526. if utils.has_kwarg_only_tensors(cpp_schema):
  527. # If you want to support this, the progression is:
  528. # - supporting kwarg-only Tensors that are non-differentiable
  529. # - supporting kwarg-only Tensors (regardless of differentiability)
  530. raise NotImplementedError(
  531. f"custom_op with kwarg-only Tensor args. Please make your "
  532. f"tensors not kwarg-only. Got: {schema_str}"
  533. )
  534. lib.define(
  535. schema_str,
  536. tags=[_C.Tag.pt2_compliant_tag, *tags],
  537. )
  538. self._opoverload = utils.lookup_op(self._qualname)
  539. def fake_impl(*args, **kwargs):
  540. if self._abstract_fn is None:
  541. if utils.can_generate_trivial_fake_impl(self._opoverload):
  542. return None
  543. raise RuntimeError(
  544. f"There was no fake impl registered for {self}. "
  545. f"This is necessary for torch.compile/export/fx tracing to work. "
  546. f"Please use `{self._init_fn.__name__}.register_fake` to add an "
  547. f"fake impl."
  548. )
  549. return self._abstract_fn(*args, **kwargs)
  550. lib._register_fake(self._name, fake_impl, _stacklevel=4)
  551. autograd_impl = autograd.make_autograd_impl(self._opoverload, self)
  552. lib.impl(self._name, autograd_impl, "Autograd", with_keyset=True)
  553. schema = self._opoverload._schema
  554. if schema.is_mutable:
  555. mutated_idxs, mutated_keys = utils.mutated_args_kwargs(schema)
  556. def adinplaceorview_impl(keyset, *args, **kwargs):
  557. for idx in mutated_idxs:
  558. increment_version(args[idx])
  559. for key in mutated_keys:
  560. increment_version(kwargs[key])
  561. with _C._AutoDispatchBelowADInplaceOrView():
  562. return self._opoverload.redispatch(
  563. keyset & _C._after_ADInplaceOrView_keyset, *args, **kwargs
  564. )
  565. lib.impl(
  566. self._name,
  567. adinplaceorview_impl,
  568. "ADInplaceOrView",
  569. with_keyset=True,
  570. )
  571. def _register_backend_select_dispatcher(self, device_arg_index: int):
  572. """
  573. Switch on the device argument to select the correct backend to dispatch to.
  574. """
  575. def backend_select(keyset, *args, **kwargs):
  576. device = args[device_arg_index].type
  577. if device not in self._backend_fns:
  578. raise RuntimeError(
  579. f"{self._name} does not have a kernel registered for {device}. "
  580. "Please use register_kernel to do so."
  581. )
  582. dispatch_key = _C._dispatch_key_for_device(device)
  583. dispatch_key = getattr(_C.DispatchKey, dispatch_key)
  584. return self._opoverload.redispatch(
  585. _C.DispatchKeySet(dispatch_key), *args, **kwargs
  586. )
  587. self._lib.impl(self._name, backend_select, "BackendSelect", with_keyset=True)
  588. def __call__(self, *args, **kwargs):
  589. return self._opoverload(*args, **kwargs)
  590. def register_vmap(
  591. self,
  592. func: Optional[Callable] = None,
  593. ):
  594. r"""Register a vmap implementation to support :func:`torch.vmap` for this custom op.
  595. This API may be used as a decorator.
  596. In order for an operator to work with :func:`torch.vmap`, you may need to register a
  597. vmap implementation in the following signature:
  598. ``vmap_func(info, in_dims: Tuple[Optional[int]], *args, **kwargs)``,
  599. where ``*args`` and ``**kwargs`` are the arguments and kwargs for ``op``.
  600. It specifies how do we compute the batched version of ``op`` given inputs with an additional
  601. dimension (specified by ``in_dims``).
  602. For each arg in ``args``, ``in_dims`` has a corresponding ``Optional[int]``. It is ``None``
  603. if the arg is not a Tensor or if the arg is not being vmapped over, otherwise, it is an integer
  604. specifying what dimension of the Tensor is being vmapped over.
  605. ``info`` is a collection of additional metadata that may be helpful:
  606. ``info.batch_size`` specifies the size of the dimension being vmapped over, while
  607. ``info.randomness`` is the ``randomness`` option that was passed to :func:`torch.vmap`.
  608. The return of the function ``func`` is a tuple of ``(output, out_dims)``. Similar to ``in_dims``,
  609. ``out_dims`` should be of the same structure as ``output`` and contain one ``out_dim``
  610. per output that specifies if the output has the vmapped dimension and what index it is in.
  611. Examples:
  612. >>> import torch
  613. >>> import numpy as np
  614. >>> from torch import Tensor
  615. >>> from typing import Tuple
  616. >>>
  617. >>> def to_numpy(tensor):
  618. >>> return tensor.cpu().numpy()
  619. >>>
  620. >>> lib = torch.library.Library("mylib", "FRAGMENT")
  621. >>> @torch.library.custom_op("mylib::numpy_cube", mutates_args=())
  622. >>> def numpy_cube(x: Tensor) -> Tuple[Tensor, Tensor]:
  623. >>> x_np = to_numpy(x)
  624. >>> dx = torch.tensor(3 * x_np ** 2, device=x.device)
  625. >>> return torch.tensor(x_np ** 3, device=x.device), dx
  626. >>>
  627. >>> def numpy_cube_vmap(info, in_dims, x):
  628. >>> result = numpy_cube(x)
  629. >>> return result, (in_dims[0], in_dims[0])
  630. >>>
  631. >>> numpy_cube.register_vmap(numpy_cube_vmap)
  632. >>>
  633. >>> x = torch.randn(3)
  634. >>> torch.vmap(numpy_cube)(x)
  635. >>>
  636. >>> @torch.library.custom_op("mylib::numpy_mul", mutates_args=())
  637. >>> def numpy_mul(x: Tensor, y: Tensor) -> Tensor:
  638. >>> return torch.tensor(to_numpy(x) * to_numpy(y), device=x.device)
  639. >>>
  640. >>> @numpy_mul.register_vmap
  641. >>> def numpy_mul_vmap(info, in_dims, x, y):
  642. >>> x_bdim, y_bdim = in_dims
  643. >>> x = x.movedim(x_bdim, -1) if x_bdim is not None else x.unsqueeze(-1)
  644. >>> y = y.movedim(y_bdim, -1) if y_bdim is not None else y.unsqueeze(-1)
  645. >>> result = x * y
  646. >>> result = result.movedim(-1, 0)
  647. >>> return result, 0
  648. >>>
  649. >>>
  650. >>> x = torch.randn(3)
  651. >>> y = torch.randn(3)
  652. >>> torch.vmap(numpy_mul)(x, y)
  653. """
  654. from torch._functorch.autograd_function import custom_function_call_vmap_helper
  655. from torch._functorch.pyfunctorch import retrieve_current_functorch_interpreter
  656. def register(func):
  657. need_register = self._vmap_fn is None
  658. self._vmap_fn = func
  659. if need_register:
  660. def wrapped_func(keyset, *args, **kwargs):
  661. interpreter = retrieve_current_functorch_interpreter()
  662. return custom_function_call_vmap_helper(
  663. interpreter, self._vmap_fn, self._opoverload, *args, **kwargs
  664. )
  665. self._lib.impl(
  666. self._name, wrapped_func, "FuncTorchBatched", with_keyset=True
  667. )
  668. if func is None:
  669. return register
  670. else:
  671. return register(func)
  672. def register_autocast(
  673. self,
  674. device_type: str,
  675. cast_inputs: _dtype,
  676. ):
  677. r"""Register an autocast dispatch rule for this custom op.
  678. Valid `device_type` include: "cpu" and "cuda".
  679. Args:
  680. op (str | OpOverload): The operator to register an autocast dispatch rule to.
  681. device_type(str): Device type to use. 'cuda' or 'cpu'.
  682. The type is the same as the `type` attribute of a :class:`torch.device`.
  683. Thus, you may obtain the device type of a tensor using `Tensor.device.type`.
  684. cast_inputs (:class:`torch.dtype`): When custom op runs in an autocast-enabled region,
  685. casts incoming floating-point Tensors to the target dtype (non-floating-point Tensors
  686. are not affected), then executes custom op with autocast disabled.
  687. lib (Optional[Library]): If provided, the lifetime of this registration
  688. Examples::
  689. >>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_CUDA)
  690. >>> import torch
  691. >>> from torch import Tensor
  692. >>> from torch.library import custom_op
  693. >>>
  694. >>> # Create a custom op that works on cuda
  695. >>> @torch.library.custom_op("mylib::my_sin", mutates_args=())
  696. >>> def my_sin(x: Tensor) -> Tensor:
  697. >>> return torch.sin(x)
  698. >>>
  699. >>> # Register autocast dispatch rule for the cuda device
  700. >>> torch.library.register_autocast("mylib::my_sin", "cuda", torch.float16)
  701. >>>
  702. >>> x = torch.randn(3, dtype=torch.float32, device="cuda")
  703. >>> with torch.autocast("cuda", dtype=torch.float16):
  704. >>> y = torch.ops.mylib.my_sin(x)
  705. >>> assert y.dtype == torch.float16
  706. """
  707. if not isinstance(device_type, str):
  708. raise ValueError(
  709. f"Expected `device_type` of type `str`, got: `{type(device_type)}`"
  710. )
  711. if device_type not in ["cpu", "cuda"]:
  712. raise ValueError(f"Unknown device type: {device_type}")
  713. need_register_cuda = self._autocast_cuda_dtype is None
  714. need_register_cpu = self._autocast_cpu_dtype is None
  715. if device_type == "cuda":
  716. self._autocast_cuda_dtype = cast_inputs
  717. else:
  718. self._autocast_cpu_dtype = cast_inputs
  719. def kernel(_, *args, **kwargs):
  720. assert len(kwargs) == 0, "Custom ops do not support kwargs yet."
  721. autocast_keyset = torch._C.DispatchKeySet(
  722. torch._C.DispatchKey.AutocastCPU
  723. ) | torch._C.DispatchKeySet(torch._C.DispatchKey.AutocastCUDA)
  724. with torch._C._ExcludeDispatchKeyGuard(autocast_keyset):
  725. return self._opoverload(*_cast(args, device_type, cast_inputs))
  726. if need_register_cuda and self._autocast_cuda_dtype:
  727. self._lib.impl(self._name, kernel, "AutocastCUDA", with_keyset=True)
  728. elif need_register_cpu and self._autocast_cpu_dtype:
  729. self._lib.impl(self._name, kernel, "AutocastCPU", with_keyset=True)
  730. return kernel
  731. # TODO: Merge this function with torch.amp.autocast_mode._cast, and refactor it
  732. # into a utility function once custom ops support arbitrary input types.
  733. def _cast(value, device_type: str, dtype: _dtype):
  734. if isinstance(value, torch.Tensor):
  735. is_eligible = (
  736. value.is_floating_point()
  737. and value.device.type == device_type
  738. and (value.dtype is not torch.float64)
  739. )
  740. return value.to(dtype) if is_eligible else value
  741. elif isinstance(value, (str, bytes)):
  742. return value
  743. elif isinstance(value, collections.abc.Iterable):
  744. iterable = (_cast(v, device_type, dtype) for v in value)
  745. if isinstance(value, (list, tuple)):
  746. return type(value)(iterable)
  747. else:
  748. return iterable
  749. else:
  750. return value
  751. def increment_version(val: Any) -> None:
  752. if isinstance(val, Tensor):
  753. torch.autograd.graph.increment_version(val)
  754. elif isinstance(val, (tuple, list)):
  755. for v in val:
  756. if isinstance(v, Tensor):
  757. torch.autograd.graph.increment_version(v)
  758. # NOTE: [Supporting decorator and non-decorator usage]
  759. #
  760. # Some APIs may be both used as a decorator and not as a decorator.
  761. # For example:
  762. #
  763. # >>> def fn(x):
  764. # >>> return x.sin()
  765. # >>>
  766. # >>> # Usage 1: not as a decorator
  767. # >>> numpy_sin.register_kernel("cuda", fn)
  768. # >>>
  769. # >>> # Usage 2: as a decorator
  770. # >>> @numpy_sin.register_kernel("cuda")
  771. # >>> def fn2(x):
# >>>     return x.sin()
  773. #
  774. # The way we support this is that `register_kernel` accepts an optional `fn`.
  775. # If `fn` is provided (Usage 1), then we know that the user is using it not
  776. # as a decorator.
  777. # If `fn` is not provided (Usage 2), then `register_kernel` needs to return a
  778. # decorator.
# Maps an operator's qualified name ("namespace::name") to the Library
# fragment most recently created for it by `get_library_allowing_overwrite`.
OPDEF_TO_LIB: dict[str, "torch.library.Library"] = {}
# Registry consulted by `_maybe_get_opdef`, keyed by operator name string.
# Values are held weakly, so an entry disappears once its value is no longer
# referenced elsewhere. Presumably the values are CustomOpDef instances --
# TODO confirm against the code that populates this dict.
OPDEFS: weakref.WeakValueDictionary = weakref.WeakValueDictionary()
  781. def get_library_allowing_overwrite(
  782. namespace: str, name: str
  783. ) -> "torch.library.Library":
  784. qualname = f"{namespace}::{name}"
  785. if qualname in OPDEF_TO_LIB:
  786. OPDEF_TO_LIB[qualname]._destroy()
  787. del OPDEF_TO_LIB[qualname]
  788. lib = torch.library.Library(namespace, "FRAGMENT") # noqa: TOR901
  789. OPDEF_TO_LIB[qualname] = lib
  790. return lib
  791. def _maybe_get_opdef(
  792. op: Union[CustomOpDef, _ops.OpOverload, str],
  793. ) -> Optional[CustomOpDef]:
  794. if isinstance(op, CustomOpDef):
  795. return op
  796. if isinstance(op, _ops.OpOverload):
  797. op = op._name
  798. assert isinstance(op, str)
  799. if op in OPDEFS:
  800. return OPDEFS[op]
  801. return None