library.py 64 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156115711581159116011611162116311641165116611671168116911701171117211731174117511761177117811791180118111821183118411851186118711881189119011911192119311941195119611971198119912001201120212031204120512061207120812091210121112121213121412151216121712181219122012211222122312241225122612271228122912301231123212331234123512361237123812391240124112421243124412451246124712481249125012511252125312541255125612571258125912601261126212631264126512661267126812691270127112721273127412751276127
7127812791280128112821283128412851286128712881289129012911292129312941295129612971298129913001301130213031304130513061307130813091310131113121313131413151316131713181319132013211322132313241325132613271328132913301331133213331334133513361337133813391340134113421343134413451346134713481349135013511352135313541355135613571358135913601361136213631364136513661367136813691370137113721373137413751376137713781379138013811382138313841385138613871388138913901391139213931394139513961397139813991400140114021403140414051406140714081409141014111412141314141415141614171418141914201421142214231424142514261427142814291430143114321433143414351436143714381439144014411442144314441445144614471448144914501451145214531454145514561457145814591460146114621463146414651466146714681469147014711472147314741475147614771478147914801481148214831484148514861487148814891490149114921493149414951496149714981499150015011502150315041505150615071508150915101511151215131514151515161517151815191520152115221523152415251526152715281529153015311532153315341535153615371538153915401541154215431544154515461547154815491550155115521553155415551556155715581559156015611562156315641565156615671568156915701571157215731574157515761577157815791580158115821583158415851586158715881589159015911592159315941595159615971598159916001601160216031604160516061607160816091610161116121613161416151616161716181619162016211622162316241625162616271628162916301631163216331634163516361637163816391640164116421643164416451646164716481649165016511652165316541655165616571658165916601661166216631664166516661667166816691670167116721673167416751676167716781679168016811682168316841685168616871688
  1. # mypy: allow-untyped-defs
  2. import contextlib
  3. import functools
  4. import inspect
  5. import re
  6. import sys
  7. import traceback
  8. import weakref
  9. from collections.abc import Sequence
  10. from typing import (
  11. Any,
  12. Callable,
  13. Literal,
  14. Optional,
  15. overload,
  16. TYPE_CHECKING,
  17. TypeVar,
  18. Union,
  19. )
  20. from typing_extensions import deprecated, ParamSpec
  21. import torch
  22. import torch._library as _library
  23. from torch._library.custom_ops import (
  24. _cast,
  25. _maybe_get_opdef,
  26. custom_op,
  27. CustomOpDef,
  28. device_types_t,
  29. )
  30. from torch._library.infer_schema import infer_schema # noqa: F401
  31. from torch._library.triton import triton_op, wrap_triton
  32. from torch._ops import OpOverload
  33. from torch.types import _dtype
# Public names re-exported by ``from torch.library import *``.
__all__ = [
    "Library",
    "impl",
    "define",
    "fallthrough_kernel",
    "impl_abstract",
    "register_autocast",
    "register_fake",
    "register_torch_dispatch",
    "register_vmap",
    "get_ctx",
    "get_kernel",
    "custom_op",
    "triton_op",
    "wrap_triton",
    "infer_schema",
]

# Type variables used by the legacy decorator overloads to preserve the
# signature of the decorated function (``Callable[_P, _T]``).
_T = TypeVar("_T")
_P = ParamSpec("_P")

# Set containing the combination of (namespace, operator, DispatchKey) for which a new kernel has been registered
# The keys in the set are of the form `namespace + "/" + op_name + "/" + dispatch_key`.
# This set is maintained to ensure that two libraries don't try to override the exact same functionality to avoid
# libraries calling into kernels not intended to be called.
_impls: set[str] = set()
# Qualified names ("namespace::op" or "namespace::op.overload") of operators
# defined through ``Library.define``. A Library's entries are subtracted from
# this set when its finalizer runs.
_defs: set[str] = set()

# prim is reserved by TorchScript interpreter
# (``Library`` refuses to create "DEF"/"FRAGMENT" libraries for these names).
_reserved_namespaces = ["prim"]
  61. def fallthrough_kernel():
  62. """
  63. A dummy function to pass to ``Library.impl`` in order to register a fallthrough.
  64. """
  65. raise NotImplementedError("fallthrough_kernel() should never be called.")
  66. class Library:
  67. """
  68. A class to create libraries that can be used to register new operators or
  69. override operators in existing libraries from Python.
  70. A user can optionally pass in a dispatch keyname if they only want to register
  71. kernels corresponding to only one specific dispatch key.
  72. To create a library to override operators in an existing library (with name ns), set the kind to "IMPL".
  73. To create a new library (with name ns) to register new operators, set the kind to "DEF".
  74. To create a fragment of a possibly existing library to register operators (and bypass
  75. the limitation that there is only one library for a given namespace), set the kind to
  76. "FRAGMENT".
  77. Args:
  78. ns: library name
  79. kind: "DEF", "IMPL", "FRAGMENT"
  80. dispatch_key: PyTorch dispatch key (default: "")
  81. """
  82. def __init__(self, ns, kind, dispatch_key=""):
  83. from torch.fx.operator_schemas import _SCHEMA_TO_SIGNATURE_CACHE
  84. if kind not in ("IMPL", "DEF", "FRAGMENT"):
  85. raise ValueError("Unsupported kind: ", kind)
  86. if ns in _reserved_namespaces and (kind == "DEF" or kind == "FRAGMENT"):
  87. raise ValueError(
  88. ns,
  89. " is a reserved namespace. Please try creating a library with another name.",
  90. )
  91. frame = traceback.extract_stack(limit=2)[0]
  92. filename, lineno = frame.filename, frame.lineno
  93. self.m: Optional[Any] = torch._C._dispatch_library(
  94. kind, ns, dispatch_key, filename, lineno
  95. )
  96. self.ns = ns
  97. self._op_defs: set[str] = set()
  98. self._op_impls: set[str] = set()
  99. self._registration_handles: list[torch._library.utils.RegistrationHandle] = []
  100. self.kind = kind
  101. self.dispatch_key = dispatch_key
  102. # Use a finalizer to setup the "destructor" instead of __del__.
  103. # Python __del__ can lead to weird things (globals and locals may already
  104. # be gone when __del__ actually gets called!). finalizers help the
  105. # situation because it lets us capture references and keeps them alive
  106. weakref.finalize(
  107. self,
  108. _del_library,
  109. _impls,
  110. self._op_impls,
  111. _defs,
  112. self._op_defs,
  113. self._registration_handles,
  114. self.m,
  115. _SCHEMA_TO_SIGNATURE_CACHE,
  116. )
  117. def __repr__(self):
  118. return f"Library(kind={self.kind}, ns={self.ns}, dispatch_key={self.dispatch_key})>"
  119. def define(self, schema, alias_analysis="", *, tags=()):
  120. r"""Defines a new operator and its semantics in the ns namespace.
  121. Args:
  122. schema: function schema to define a new operator.
  123. alias_analysis (optional): Indicates if the aliasing properties of the operator arguments can be
  124. inferred from the schema (default behavior) or not ("CONSERVATIVE").
  125. tags (Tag | Sequence[Tag]): one or more torch.Tag to apply to this
  126. operator. Tagging an operator changes the operator's behavior
  127. under various PyTorch subsystems; please read the docs for the
  128. torch.Tag carefully before applying it.
  129. Returns:
  130. name of the operator as inferred from the schema.
  131. Example::
  132. >>> my_lib = Library("mylib", "DEF")
  133. >>> my_lib.define("sum(Tensor self) -> Tensor")
  134. """
  135. # This is added because we also want to disallow PURE_FUNCTION alias analysis which is a valid
  136. # AliasAnalysis type in C++
  137. if alias_analysis not in ["", "FROM_SCHEMA", "CONSERVATIVE"]:
  138. raise RuntimeError(f"Invalid alias_analysis type {alias_analysis}")
  139. assert self.m is not None
  140. if isinstance(tags, torch.Tag):
  141. tags = (tags,)
  142. name = schema.split("(")[0]
  143. packet_name = name.split(".")[0] if "." in name else name
  144. has_preexisting_packet = hasattr(torch.ops, self.ns) and hasattr(
  145. getattr(torch.ops, self.ns), packet_name
  146. )
  147. result = self.m.define(schema, alias_analysis, tuple(tags))
  148. name = schema.split("(")[0]
  149. qualname = self.ns + "::" + name
  150. # If the OpOverloadPacket exists already, then this means we're adding a
  151. # new OpOverload for it. Refresh the packet to include the new OpOverload.
  152. if has_preexisting_packet:
  153. ns = getattr(torch.ops, self.ns)
  154. packet = getattr(ns, packet_name)
  155. torch._ops._refresh_packet(packet)
  156. self._op_defs.add(qualname)
  157. _defs.add(qualname)
  158. return result
  159. def _register_fake(self, op_name, fn, _stacklevel=1, *, allow_override=False):
  160. r"""Registers the fake impl for an operator defined in the library."""
  161. source = torch._library.utils.get_source(_stacklevel + 1)
  162. frame = sys._getframe(_stacklevel)
  163. caller_module = inspect.getmodule(frame)
  164. # Can be none if you call register_fake from somewhere there isn't a module
  165. # (e.g. __main__)
  166. caller_module_name = None if caller_module is None else caller_module.__name__
  167. # TODO(rzou): We're gonna need to stage this change with torchvision,
  168. # since torchvision is github first.
  169. if caller_module_name is not None and caller_module_name.startswith(
  170. "torchvision."
  171. ):
  172. caller_module_name = None
  173. qualname = f"{self.ns}::{op_name}"
  174. entry = torch._library.simple_registry.singleton.find(qualname)
  175. if caller_module_name is not None:
  176. func_to_register = _check_pystubs_once(fn, qualname, caller_module_name)
  177. else:
  178. func_to_register = fn
  179. handle = entry.fake_impl.register(
  180. func_to_register, source, lib=self, allow_override=allow_override
  181. )
  182. self._registration_handles.append(handle)
  183. def _register_torch_dispatch_rule(self, op_name, torch_dispatch_class, fn):
  184. r"""Registers a torch_dispatch rule for the given operator and torch_dispatch_class.
  185. This allows for open registration to specify the behavior between the operator
  186. and the torch_dispatch_class without needing to modify the torch_dispatch_class
  187. or the operator directly.
  188. The torch_dispatch_class is either a Tensor subclass with `__torch_dispatch__` or a
  189. TorchDispatchMode.
  190. If it is a Tensor subclass, we expect fn to have the following signature:
  191. (cls, func: OpOverload, types: Tuple[type, ...], args, kwargs) -> Any
  192. If it is a TorchDispatchMode, we expect fn to have the following signature:
  193. (mode, func: OpOverload, types: Tuple[type, ...], args, kwargs) -> Any
  194. """
  195. qualname = f"{self.ns}::{op_name}"
  196. entry = torch._library.simple_registry.singleton.find(qualname)
  197. handle = entry.torch_dispatch_rules.register(torch_dispatch_class, fn)
  198. self._registration_handles.append(handle)
  199. def _impl_with_aoti_compile(self, op_name, dispatch_key=""):
  200. r"""Register the operator to use the AOTI-compiled implementation.
  201. Args:
  202. op_name: operator name (along with the overload) or OpOverload object.
  203. dispatch_key: dispatch key that the input function should be registered for. By default, it uses
  204. the dispatch key that the library was created with.
  205. Example::
  206. >>> my_lib = Library("aten", "IMPL")
  207. >>> my_lib._impl_with_aoti_compile("div.Tensor", "CPU")
  208. """
  209. if dispatch_key == "":
  210. dispatch_key = self.dispatch_key
  211. assert torch.DispatchKeySet(dispatch_key).has(torch._C.DispatchKey.Dense)
  212. if isinstance(op_name, str):
  213. name = op_name
  214. elif isinstance(op_name, OpOverload):
  215. name = op_name._schema.name
  216. overload_name = op_name._schema.overload_name
  217. if overload_name != "":
  218. name = name + "." + overload_name
  219. else:
  220. raise RuntimeError(
  221. "_impl_with_aoti_compile should be passed either a name or an OpOverload object "
  222. "as the first argument"
  223. )
  224. key = self.ns + "/" + name.split("::")[-1] + "/" + dispatch_key
  225. if key in _impls:
  226. # TODO: in future, add more info about where the existing function is registered (this info is
  227. # today already returned by the C++ warning when _impl_with_aoti_compile is called but we error out before that)
  228. raise RuntimeError(
  229. "This is not allowed since there's already a kernel registered from python overriding {}"
  230. "'s behavior for {} dispatch key and {} namespace.".format(
  231. name.split("::")[-1], dispatch_key, self.ns
  232. )
  233. )
  234. assert self.m is not None
  235. impl_fn: Callable = self.m.impl_with_aoti_compile
  236. impl_fn(self.ns, name.split("::")[-1], dispatch_key)
  237. _impls.add(key)
  238. self._op_impls.add(key)
  239. def impl(
  240. self, op_name, fn, dispatch_key="", *, with_keyset=False, allow_override=False
  241. ):
  242. r"""Registers the function implementation for an operator defined in the library.
  243. Args:
  244. op_name: operator name (along with the overload) or OpOverload object.
  245. fn: function that's the operator implementation for the input dispatch key or :func:`~fallthrough_kernel`
  246. to register a fallthrough.
  247. dispatch_key: dispatch key that the input function should be registered for. By default, it uses
  248. the dispatch key that the library was created with.
  249. with_keyset: flag controlling if the current dispatcher call keyset should be passed as the first argument
  250. to :attr:`fn` when calling. This should be used to create the appropriate keyset for redispatch calls.
  251. allow_override: Flag controlling if we want to override an
  252. existing registered kernel implementation. This is by
  253. default off, and will error you're trying to register a
  254. kernel to a dispatch key with a kernel already
  255. registered.
  256. Example::
  257. >>> my_lib = Library("aten", "IMPL")
  258. >>> def div_cpu(self, other):
  259. >>> return self * (1 / other)
  260. >>> my_lib.impl("div.Tensor", div_cpu, "CPU")
  261. """
  262. if not callable(fn):
  263. raise TypeError(
  264. f"Input function is required to be a callable but found type {type(fn)}"
  265. )
  266. if dispatch_key == "":
  267. dispatch_key = self.dispatch_key
  268. if isinstance(op_name, str):
  269. name = op_name
  270. elif isinstance(op_name, OpOverload):
  271. name = op_name._schema.name
  272. overload_name = op_name._schema.overload_name
  273. if overload_name != "":
  274. name = name + "." + overload_name
  275. else:
  276. raise RuntimeError(
  277. "impl should be passed either a name or an OpOverload object as the first argument"
  278. )
  279. key = self.ns + "/" + name.split("::")[-1] + "/" + dispatch_key
  280. if (not allow_override) and key in _impls:
  281. # TODO: in future, add more info about where the existing function is registered (this info is
  282. # today already returned by the C++ warning when impl is called but we error out before that)
  283. raise RuntimeError(
  284. "This is not allowed since there's already a kernel registered from python overriding {}"
  285. "'s behavior for {} dispatch key and {} namespace.".format(
  286. name.split("::")[-1], dispatch_key, self.ns
  287. )
  288. )
  289. if dispatch_key == "Meta":
  290. dispatcher_op_name = name
  291. if "::" not in dispatcher_op_name:
  292. dispatcher_op_name = f"{self.ns}::{dispatcher_op_name}"
  293. # Internally, we shouldn't be registering meta kernels for any operators that
  294. # have CompositeImplicitAutograd kernels.
  295. # Instead, we should be letting those decompositions run, and writing meta kernels
  296. # only for the base operators.
  297. if torch._C._dispatch_has_kernel_for_dispatch_key(
  298. dispatcher_op_name, "CompositeImplicitAutograd"
  299. ):
  300. raise RuntimeError(
  301. f"We should not register a meta kernel directly to the operator '{name}',"
  302. " because it has a CompositeImplicitAutograd kernel in core."
  303. " Instead we should let the operator decompose, and ensure that we have meta kernels"
  304. " for the base ops that it decomposes into."
  305. )
  306. assert self.m is not None
  307. self.m.impl(
  308. name,
  309. dispatch_key if dispatch_key != "" else "CompositeImplicitAutograd",
  310. fn,
  311. with_keyset,
  312. )
  313. _impls.add(key)
  314. self._op_impls.add(key)
  315. def fallback(self, fn, dispatch_key="", *, with_keyset=False):
  316. r"""Registers the function implementation as the fallback for the given key.
  317. This function only works for a library with global namespace ("_").
  318. Args:
  319. fn: function used as fallback for the given dispatch key or :func:`~fallthrough_kernel`
  320. to register a fallthrough.
  321. dispatch_key: dispatch key that the input function should be registered for. By default, it uses
  322. the dispatch key that the library was created with.
  323. with_keyset: flag controlling if the current dispatcher call keyset should be passed as the first argument
  324. to :attr:`fn` when calling. This should be used to create the appropriate keyset for redispatch calls.
  325. Example::
  326. >>> my_lib = Library("_", "IMPL")
  327. >>> def fallback_kernel(op, *args, **kwargs):
  328. >>> # Handle all autocast ops generically
  329. >>> # ...
  330. >>> my_lib.fallback(fallback_kernel, "Autocast")
  331. """
  332. if dispatch_key == "":
  333. dispatch_key = self.dispatch_key
  334. if self.ns != "_":
  335. raise RuntimeError(
  336. f"""Fallback can only be registered using library fragment on the global namespace "_" but it is {self.ns}"""
  337. )
  338. assert dispatch_key != ""
  339. assert self.m is not None
  340. self.m.fallback(dispatch_key, fn, with_keyset)
  341. def _destroy(self):
  342. if self.m is not None:
  343. self.m.reset()
  344. self.m = None
  345. for handle in self._registration_handles:
  346. handle.destroy()
  347. self._registration_handles.clear()
  348. global _impls
  349. _impls -= self._op_impls
  350. for name in self._op_defs:
  351. # Delete the cached torch.ops.ns.foo if it was registered.
  352. # Otherwise, accessing it leads to a segfault.
  353. # It's possible that we only registered an overload in this Library
  354. # and another library owns an alive overload.
  355. # That's OK - the next time torch.ops.ns.foo gets called, it'll be
  356. # recomputed to point at the right collection of overloads.
  357. ns, name_with_overload = name.split("::")
  358. name = name_with_overload.split(".")[0]
  359. if not hasattr(torch.ops, ns):
  360. continue
  361. namespace = getattr(torch.ops, ns)
  362. if not hasattr(namespace, name):
  363. continue
  364. delattr(namespace, name)
  365. namespace._dir.remove(name)
  366. def _del_library(
  367. captured_impls,
  368. op_impls,
  369. captured_defs,
  370. op_defs,
  371. registration_handles,
  372. m,
  373. schema_to_signature_cache,
  374. ):
  375. for op_def in op_defs:
  376. name = op_def
  377. overload_name = ""
  378. if "." in op_def:
  379. name, overload_name = op_def.split(".")
  380. if (
  381. name,
  382. overload_name,
  383. ) in schema_to_signature_cache:
  384. del schema_to_signature_cache[(name, overload_name)]
  385. captured_impls -= op_impls
  386. captured_defs -= op_defs
  387. for handle in registration_handles:
  388. handle.destroy()
  389. if m is not None:
  390. m.reset()
  391. @contextlib.contextmanager
  392. def _scoped_library(*args, **kwargs):
  393. try:
  394. lib = Library(*args, **kwargs)
  395. yield lib
  396. finally:
  397. lib._destroy()
# Libraries created implicitly by module-level helpers (e.g. ``define`` and
# ``_impl`` when no ``lib`` is passed) are appended here. Holding a reference
# stops them from being garbage-collected, which would run their finalizer and
# undo their registrations. Nothing in this section ever removes entries.
_keep_alive: list[Library] = []

# Matches a schema without an operator name, e.g. "(Tensor x) -> Tensor";
# ``define`` uses it (with ``fullmatch``) to validate its ``schema`` argument.
NAMELESS_SCHEMA = re.compile(r"\(.*\) -> .*")
  400. @functools.singledispatch
  401. def define(qualname, schema, *, lib=None, tags=()):
  402. r"""Defines a new operator.
  403. In PyTorch, defining an op (short for "operator") is a two step-process:
  404. - we need to define the op (by providing an operator name and schema)
  405. - we need to implement behavior for how the operator interacts with
  406. various PyTorch subsystems, like CPU/CUDA Tensors, Autograd, etc.
  407. This entrypoint defines the custom operator (the first step)
  408. you must then perform the second step by calling various
  409. ``impl_*`` APIs, like :func:`torch.library.impl` or
  410. :func:`torch.library.register_fake`.
  411. Args:
  412. qualname (str): The qualified name for the operator. Should be
  413. a string that looks like "namespace::name", e.g. "aten::sin".
  414. Operators in PyTorch need a namespace to
  415. avoid name collisions; a given operator may only be created once.
  416. If you are writing a Python library, we recommend the namespace to
  417. be the name of your top-level module.
  418. schema (str): The schema of the operator. E.g. "(Tensor x) -> Tensor"
  419. for an op that accepts one Tensor and returns one Tensor. It does
  420. not contain the operator name (that is passed in ``qualname``).
  421. lib (Optional[Library]): If provided, the lifetime of this operator
  422. will be tied to the lifetime of the Library object.
  423. tags (Tag | Sequence[Tag]): one or more torch.Tag to apply to this
  424. operator. Tagging an operator changes the operator's behavior
  425. under various PyTorch subsystems; please read the docs for the
  426. torch.Tag carefully before applying it.
  427. Example::
  428. >>> import torch
  429. >>> import numpy as np
  430. >>>
  431. >>> # Define the operator
  432. >>> torch.library.define("mylib::sin", "(Tensor x) -> Tensor")
  433. >>>
  434. >>> # Add implementations for the operator
  435. >>> @torch.library.impl("mylib::sin", "cpu")
  436. >>> def f(x):
  437. >>> return torch.from_numpy(np.sin(x.numpy()))
  438. >>>
  439. >>> # Call the new operator from torch.ops.
  440. >>> x = torch.randn(3)
  441. >>> y = torch.ops.mylib.sin(x)
  442. >>> assert torch.allclose(y, x.sin())
  443. """
  444. if not isinstance(qualname, str):
  445. raise ValueError(
  446. f"define(qualname, schema): expected qualname "
  447. f"to be instance of str, got {type(qualname)}"
  448. )
  449. namespace, name = torch._library.utils.parse_namespace(qualname)
  450. if lib is None:
  451. lib = Library(namespace, "FRAGMENT")
  452. _keep_alive.append(lib)
  453. if not NAMELESS_SCHEMA.fullmatch(schema):
  454. raise ValueError(
  455. f"define(qualname, schema, ...): expected schema "
  456. f'to look like e.g. "(Tensor x) -> Tensor" but '
  457. f'got "{schema}"'
  458. )
  459. lib.define(name + schema, alias_analysis="", tags=tags)
  460. @define.register
  461. def _(lib: Library, schema, alias_analysis=""):
  462. """The old torch.library.define.
  463. We're keeping this around for BC reasons
  464. """
  465. def wrap(f):
  466. name = lib.define(schema, alias_analysis)
  467. lib.impl(name, f)
  468. return f
  469. return wrap
  470. @overload
  471. def impl(
  472. qualname: str,
  473. types: Union[str, Sequence[str]],
  474. func: Literal[None] = None,
  475. *,
  476. lib: Optional[Library] = None,
  477. ) -> Callable[[Callable[..., object]], None]: ...
  478. @overload
  479. def impl(
  480. qualname: str,
  481. types: Union[str, Sequence[str]],
  482. func: Callable[..., object],
  483. *,
  484. lib: Optional[Library] = None,
  485. ) -> None: ...
  486. # Deprecated BC API
  487. @overload
  488. def impl(
  489. lib: Library,
  490. name: str,
  491. dispatch_key: str = "",
  492. ) -> Callable[[Callable[_P, _T]], Callable[_P, _T]]: ...
  493. @functools.singledispatch
  494. def impl(
  495. qualname: str,
  496. types: Union[str, Sequence[str]],
  497. func: Optional[Callable[_P, _T]] = None,
  498. *,
  499. lib: Optional[Library] = None,
  500. ) -> object:
  501. """Register an implementation for a device type for this operator.
  502. You may pass "default" for ``types`` to register this implementation as the
  503. default implementation for ALL device types.
  504. Please only use this if the implementation truly supports all device types;
  505. for example, this is true if it is a composition of built-in PyTorch operators.
  506. This API may be used as a decorator. You can use nested decorators
  507. with this API provided they return a function and are placed inside
  508. this API (see Example 2).
  509. Some valid types are: "cpu", "cuda", "xla", "mps", "ipu", "xpu".
  510. Args:
  511. qualname (str): Should be a string that looks like "namespace::operator_name".
  512. types (str | Sequence[str]): The device types to register an impl to.
  513. lib (Optional[Library]): If provided, the lifetime of this registration
  514. will be tied to the lifetime of the Library object.
  515. Examples:
  516. >>> import torch
  517. >>> import numpy as np
  518. >>> # Example 1: Register function.
  519. >>> # Define the operator
  520. >>> torch.library.define("mylib::mysin", "(Tensor x) -> Tensor")
  521. >>>
  522. >>> # Add implementations for the cpu device
  523. >>> @torch.library.impl("mylib::mysin", "cpu")
  524. >>> def f(x):
  525. >>> return torch.from_numpy(np.sin(x.numpy()))
  526. >>>
  527. >>> x = torch.randn(3)
  528. >>> y = torch.ops.mylib.mysin(x)
  529. >>> assert torch.allclose(y, x.sin())
  530. >>>
  531. >>> # Example 2: Register function with decorator.
  532. >>> def custom_decorator(func):
  533. >>> def wrapper(*args, **kwargs):
  534. >>> return func(*args, **kwargs) + 1
  535. >>> return wrapper
  536. >>>
  537. >>> # Define the operator
  538. >>> torch.library.define("mylib::sin_plus_one", "(Tensor x) -> Tensor")
  539. >>>
  540. >>> # Add implementations for the operator
  541. >>> @torch.library.impl("mylib::sin_plus_one", "cpu")
  542. >>> @custom_decorator
  543. >>> def f(x):
  544. >>> return torch.from_numpy(np.sin(x.numpy()))
  545. >>>
  546. >>> # Call the new operator from torch.ops.
  547. >>> x = torch.randn(3)
  548. >>>
  549. >>> y1 = torch.ops.mylib.sin_plus_one(x)
  550. >>> y2 = torch.sin(x) + 1
  551. >>> assert torch.allclose(y1, y2)
  552. """
  553. return _impl(qualname, types, func, lib=lib, disable_dynamo=False)
  554. if not TYPE_CHECKING:
  555. @impl.register
  556. def _(
  557. lib: Library, name: str, dispatch_key: str = ""
  558. ) -> Callable[[Callable[_P, _T]], Callable[_P, _T]]:
  559. """Legacy torch.library.impl API. Kept around for BC"""
  560. def wrap(f: Callable[_P, _T]) -> Callable[_P, _T]:
  561. lib.impl(name, f, dispatch_key)
  562. return f
  563. return wrap
  564. @overload
  565. def _impl(
  566. qualname: str,
  567. types: Union[str, Sequence[str]],
  568. func: Literal[None] = None,
  569. *,
  570. lib: Optional[Library] = None,
  571. disable_dynamo: bool = False,
  572. ) -> Callable[[Callable[..., object]], None]: ...
  573. @overload
  574. def _impl(
  575. qualname: str,
  576. types: Union[str, Sequence[str]],
  577. func: Callable[..., object],
  578. *,
  579. lib: Optional[Library] = None,
  580. disable_dynamo: bool = False,
  581. ) -> None: ...
  582. def _impl(
  583. qualname: str,
  584. types: Union[str, Sequence[str]],
  585. func: Optional[Callable[..., object]] = None,
  586. *,
  587. lib: Optional[Library] = None,
  588. disable_dynamo: bool = False,
  589. ) -> Optional[Callable[[Callable[..., object]], None]]:
  590. # See impl()
  591. if isinstance(types, str):
  592. types = (types,)
  593. keys = set({})
  594. for typ in types:
  595. is_dispatch_key = torch._C._parse_dispatch_key(typ)
  596. if is_dispatch_key:
  597. # We also support passing a DispatchKey to impl. Please prefer using
  598. # the higher-level torch.library APIs and only pass DispatchKey to
  599. # torch.library.impl with caution (or even better, don't use this
  600. # option and file an issue on GitHub for what you need).
  601. # We don't advertise this to users because
  602. # it is very easy to shoot yourself in the foot.
  603. keys.add(typ)
  604. else:
  605. keys.add(_device_type_to_key(typ))
  606. def register_(func: Callable[..., object]) -> None:
  607. namespace, _ = torch._library.utils.parse_namespace(qualname)
  608. if lib is None:
  609. use_lib = Library(namespace, "FRAGMENT")
  610. _keep_alive.append(use_lib)
  611. else:
  612. use_lib = lib
  613. if disable_dynamo:
  614. @torch._disable_dynamo
  615. def func_no_dynamo(*args, **kwargs):
  616. return func(*args, **kwargs)
  617. for key in keys:
  618. use_lib.impl(qualname, func_no_dynamo, key)
  619. else:
  620. for key in keys:
  621. use_lib.impl(qualname, func, key)
  622. if func is None:
  623. return register_
  624. else:
  625. register_(func)
  626. return None
  627. def _device_type_to_key(device_type: str) -> str:
  628. if device_type == "default":
  629. # This is technically not correct, because although all device_type
  630. # DispatchKeys are included in CompositeExplicitAutograd,
  631. # not everything in CompositeExplicitAutograd is associated with a
  632. # device_type. I don't really care that much about the difference.
  633. return "CompositeExplicitAutograd"
  634. return torch._C._dispatch_key_for_device(device_type)
  635. @deprecated(
  636. "`torch.library.impl_abstract` was renamed to `torch.library.register_fake`. Please use that "
  637. "instead; we will remove `torch.library.impl_abstract` in a future version of PyTorch.",
  638. category=FutureWarning,
  639. )
  640. def impl_abstract(qualname, func=None, *, lib=None, _stacklevel=1):
  641. r"""This API was renamed to :func:`torch.library.register_fake` in PyTorch 2.4.
  642. Please use that instead.
  643. """
  644. if func is not None:
  645. _stacklevel = _stacklevel + 1
  646. return register_fake(qualname, func, lib=lib, _stacklevel=_stacklevel)
  647. _op_identifier = Union[
  648. str, "torch._ops.OpOverload", "torch._library.custom_ops.CustomOpDef"
  649. ]
  650. def register_kernel(
  651. op: _op_identifier,
  652. device_types: device_types_t,
  653. func: Optional[Callable] = None,
  654. /,
  655. *,
  656. lib: Optional[Library] = None,
  657. ):
  658. """Register an implementation for a device type for this operator.
  659. Some valid device_types are: "cpu", "cuda", "xla", "mps", "ipu", "xpu".
  660. This API may be used as a decorator.
  661. Args:
  662. op (str | OpOverload): The operator to register an impl to.
  663. device_types (None | str | Sequence[str]): The device_types to register an impl to.
  664. If None, we will register to all device types -- please only use
  665. this option if your implementation is truly device-type-agnostic.
  666. func (Callable): The function to register as the implementation for
  667. the given device types.
  668. lib (Optional[Library]): If provided, the lifetime of this registration
  669. Examples::
  670. >>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_CUDA)
  671. >>> import torch
  672. >>> from torch import Tensor
  673. >>> from torch.library import custom_op
  674. >>> import numpy as np
  675. >>>
  676. >>> # Create a custom op that works on cpu
  677. >>> @custom_op("mylib::numpy_sin", mutates_args=(), device_types="cpu")
  678. >>> def numpy_sin(x: Tensor) -> Tensor:
  679. >>> x_np = x.numpy()
  680. >>> y_np = np.sin(x_np)
  681. >>> return torch.from_numpy(y_np)
  682. >>>
  683. >>> # Add implementations for the cuda device
  684. >>> @torch.library.register_kernel("mylib::numpy_sin", "cuda")
  685. >>> def _(x):
  686. >>> x_np = x.cpu().numpy()
  687. >>> y_np = np.sin(x_np)
  688. >>> return torch.from_numpy(y_np).to(device=x.device)
  689. >>>
  690. >>> x_cpu = torch.randn(3)
  691. >>> x_cuda = x_cpu.cuda()
  692. >>> assert torch.allclose(numpy_sin(x_cpu), x_cpu.sin())
  693. >>> assert torch.allclose(numpy_sin(x_cuda), x_cuda.sin())
  694. """
  695. if not isinstance(
  696. op, (str, torch._ops.OpOverload, torch._library.custom_ops.CustomOpDef)
  697. ):
  698. raise ValueError(
  699. f"register_kernel({op}): got unexpected type for op: {type(op)}"
  700. )
  701. if isinstance(op, torch._ops.OpOverload):
  702. op = op._name
  703. opdef = _maybe_get_opdef(op)
  704. if opdef is not None:
  705. return opdef.register_kernel(device_types, func)
  706. assert isinstance(op, str)
  707. if device_types is None:
  708. device_types = "CompositeExplicitAutograd"
  709. return _impl(op, device_types, func, lib=lib, disable_dynamo=True)
  710. def register_autocast(
  711. op: _op_identifier,
  712. device_type: str,
  713. cast_inputs: _dtype,
  714. /,
  715. *,
  716. lib: Optional[Library] = None,
  717. ):
  718. r"""Register an autocast dispatch rule for this custom op.
  719. Valid `device_type` include: "cpu" and "cuda".
  720. Args:
  721. op (str | OpOverload): The operator to register an autocast dispatch rule to.
  722. device_type(str): Device type to use. 'cuda' or 'cpu'.
  723. The type is the same as the `type` attribute of a :class:`torch.device`.
  724. Thus, you may obtain the device type of a tensor using `Tensor.device.type`.
  725. cast_inputs (:class:`torch.dtype`): When custom op runs in an autocast-enabled region,
  726. casts incoming floating-point Tensors to the target dtype (non-floating-point Tensors
  727. are not affected), then executes custom op with autocast disabled.
  728. lib (Optional[Library]): If provided, the lifetime of this registration
  729. Examples::
  730. >>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_CUDA)
  731. >>> import torch
  732. >>> from torch import Tensor
  733. >>> from torch.library import custom_op
  734. >>>
  735. >>> # Create a custom op that works on cuda
  736. >>> @torch.library.custom_op("mylib::my_sin", mutates_args=())
  737. >>> def my_sin(x: Tensor) -> Tensor:
  738. >>> return torch.sin(x)
  739. >>>
  740. >>> # Register autocast dispatch rule for the cuda device
  741. >>> torch.library.register_autocast("mylib::my_sin", "cuda", torch.float16)
  742. >>>
  743. >>> x = torch.randn(3, dtype=torch.float32, device="cuda")
  744. >>> with torch.autocast("cuda", dtype=torch.float16):
  745. >>> y = torch.ops.mylib.my_sin(x)
  746. >>> assert y.dtype == torch.float16
  747. """
  748. if not isinstance(
  749. op, (str, torch._ops.OpOverload, torch._library.custom_ops.CustomOpDef)
  750. ):
  751. raise ValueError(
  752. f"register_autocast({op}): got unexpected type for op: {type(op)}"
  753. )
  754. if device_type not in ["cpu", "cuda"]:
  755. raise ValueError(f"Unknown device type: {device_type}")
  756. if isinstance(op, torch._ops.OpOverload):
  757. op = op._name
  758. opdef = _maybe_get_opdef(op)
  759. if opdef is not None:
  760. return opdef.register_autocast(device_type, cast_inputs)
  761. assert isinstance(op, str)
  762. qualname = op
  763. _op = torch._library.utils.lookup_op(qualname)
  764. namespace, opname = torch._library.utils.parse_namespace(qualname)
  765. if lib is None:
  766. lib = Library(namespace, "FRAGMENT")
  767. _keep_alive.append(lib)
  768. def _maybe_override_py_impl(op: torch._ops.OpOverload, dispatch_key):
  769. def inner(kernel):
  770. if op.has_kernel_for_dispatch_key(dispatch_key):
  771. op.py_kernels.pop(dispatch_key)
  772. return op.py_impl(dispatch_key)(kernel)
  773. return inner
  774. @_maybe_override_py_impl(_op, torch._C.DispatchKey.AutocastCPU)
  775. @_maybe_override_py_impl(_op, torch._C.DispatchKey.AutocastCUDA)
  776. def _autocast_py_impl(*args, **kwargs):
  777. assert len(kwargs) == 0, "Custom ops do not support kwargs yet."
  778. autocast_keyset = torch._C.DispatchKeySet(
  779. torch._C.DispatchKey.AutocastCPU
  780. ) | torch._C.DispatchKeySet(torch._C.DispatchKey.AutocastCUDA)
  781. with torch._C._ExcludeDispatchKeyGuard(autocast_keyset):
  782. return _op(*_cast(args, device_type, cast_inputs))
  783. def kernel(_, *args, **kwargs):
  784. assert len(kwargs) == 0, "Custom ops do not support kwargs yet."
  785. return _autocast_py_impl(*args, **kwargs)
  786. if device_type == "cuda":
  787. return lib.impl(opname, kernel, "AutocastCUDA", with_keyset=True)
  788. else:
  789. # device_type is "cpu"
  790. return lib.impl(opname, kernel, "AutocastCPU", with_keyset=True)
  791. def register_fake(
  792. op: _op_identifier,
  793. func: Optional[Callable] = None,
  794. /,
  795. *,
  796. lib: Optional[Library] = None,
  797. _stacklevel: int = 1,
  798. allow_override: bool = False,
  799. ):
  800. r"""Register a FakeTensor implementation ("fake impl") for this operator.
  801. Also sometimes known as a "meta kernel", "abstract impl".
  802. An "FakeTensor implementation" specifies the behavior of this operator on
  803. Tensors that carry no data ("FakeTensor"). Given some input Tensors with
  804. certain properties (sizes/strides/storage_offset/device), it specifies
  805. what the properties of the output Tensors are.
  806. The FakeTensor implementation has the same signature as the operator.
  807. It is run for both FakeTensors and meta tensors. To write a FakeTensor
  808. implementation, assume that all Tensor inputs to the operator are
  809. regular CPU/CUDA/Meta tensors, but they do not have storage, and
  810. you are trying to return regular CPU/CUDA/Meta tensor(s) as output.
  811. The FakeTensor implementation must consist of only PyTorch operations
  812. (and may not directly access the storage or data of any input or
  813. intermediate Tensors).
  814. This API may be used as a decorator (see examples).
  815. For a detailed guide on custom ops, please see
  816. https://pytorch.org/tutorials/advanced/custom_ops_landing_page.html
  817. Args:
  818. op_name: Operator name (along with the overload) or OpOverload object.
  819. func: Fake tensor implementation.
  820. lib (Optional[Library]): Library to register the fake tensor to.
  821. allow_override: Flag controlling if we want to override an
  822. existing registered fake impl. This is by default off,
  823. and will error you're trying to register a fake impl to
  824. an operator that already has a fake impl. This also only
  825. applies if the custom operator was not created via
  826. torch.library.custom_op, as overriding and existing fake
  827. impl is already allowed.
  828. Examples:
  829. >>> import torch
  830. >>> import numpy as np
  831. >>> from torch import Tensor
  832. >>>
  833. >>> # Example 1: an operator without data-dependent output shape
  834. >>> @torch.library.custom_op("mylib::custom_linear", mutates_args=())
  835. >>> def custom_linear(x: Tensor, weight: Tensor, bias: Tensor) -> Tensor:
  836. >>> raise NotImplementedError("Implementation goes here")
  837. >>>
  838. >>> @torch.library.register_fake("mylib::custom_linear")
  839. >>> def _(x, weight, bias):
  840. >>> assert x.dim() == 2
  841. >>> assert weight.dim() == 2
  842. >>> assert bias.dim() == 1
  843. >>> assert x.shape[1] == weight.shape[1]
  844. >>> assert weight.shape[0] == bias.shape[0]
  845. >>> assert x.device == weight.device
  846. >>>
  847. >>> return (x @ weight.t()) + bias
  848. >>>
  849. >>> with torch._subclasses.fake_tensor.FakeTensorMode():
  850. >>> x = torch.randn(2, 3)
  851. >>> w = torch.randn(3, 3)
  852. >>> b = torch.randn(3)
  853. >>> y = torch.ops.mylib.custom_linear(x, w, b)
  854. >>>
  855. >>> assert y.shape == (2, 3)
  856. >>>
  857. >>> # Example 2: an operator with data-dependent output shape
  858. >>> @torch.library.custom_op("mylib::custom_nonzero", mutates_args=())
  859. >>> def custom_nonzero(x: Tensor) -> Tensor:
  860. >>> x_np = x.numpy(force=True)
  861. >>> res = np.stack(np.nonzero(x_np), axis=1)
  862. >>> return torch.tensor(res, device=x.device)
  863. >>>
  864. >>> @torch.library.register_fake("mylib::custom_nonzero")
  865. >>> def _(x):
  866. >>> # Number of nonzero-elements is data-dependent.
  867. >>> # Since we cannot peek at the data in an fake impl,
  868. >>> # we use the ctx object to construct a new symint that
  869. >>> # represents the data-dependent size.
  870. >>> ctx = torch.library.get_ctx()
  871. >>> nnz = ctx.new_dynamic_size()
  872. >>> shape = [nnz, x.dim()]
  873. >>> result = x.new_empty(shape, dtype=torch.int64)
  874. >>> return result
  875. >>>
  876. >>> from torch.fx.experimental.proxy_tensor import make_fx
  877. >>>
  878. >>> x = torch.tensor([0, 1, 2, 3, 4, 0])
  879. >>> trace = make_fx(torch.ops.mylib.custom_nonzero, tracing_mode="symbolic")(x)
  880. >>> trace.print_readable()
  881. >>>
  882. >>> assert torch.allclose(trace(x), torch.ops.mylib.custom_nonzero(x))
  883. """
  884. if not isinstance(
  885. op, (str, torch._ops.OpOverload, torch._library.custom_ops.CustomOpDef)
  886. ):
  887. raise ValueError(f"register_fake({op}): got unexpected type for op: {type(op)}")
  888. if isinstance(op, torch._ops.OpOverload):
  889. op = op._name
  890. opdef = _maybe_get_opdef(op)
  891. if opdef is not None:
  892. if func is None:
  893. return opdef.register_fake
  894. else:
  895. return opdef.register_fake(func)
  896. assert isinstance(op, str)
  897. stacklevel = _stacklevel
  898. def register(func):
  899. namespace, op_name = torch._library.utils.parse_namespace(op)
  900. if lib is None:
  901. use_lib = Library(namespace, "FRAGMENT")
  902. _keep_alive.append(use_lib)
  903. else:
  904. use_lib = lib
  905. use_lib._register_fake(
  906. op_name, func, _stacklevel=stacklevel + 1, allow_override=allow_override
  907. )
  908. return func
  909. if func is None:
  910. return register
  911. else:
  912. stacklevel += 1
  913. return register(func)
  914. def register_autograd(
  915. op: _op_identifier,
  916. backward: Callable,
  917. /,
  918. *,
  919. setup_context: Optional[Callable] = None,
  920. lib=None,
  921. ) -> None:
  922. r"""Register a backward formula for this custom op.
  923. In order for an operator to work with autograd, you need to register
  924. a backward formula:
  925. 1. You must tell us how to compute gradients during the backward pass
  926. by providing us a "backward" function.
  927. 2. If you need any values from the forward to compute gradients, you can
  928. use `setup_context` to save values for backward.
  929. ``backward`` runs during the backward pass. It accepts ``(ctx, *grads)``:
  930. - ``grads`` is one or more gradients. The number of gradients matches
  931. the number of outputs of the operator.
  932. The ``ctx`` object is `the same ctx object <context_method_mixins>`_ used by
  933. :class:`torch.autograd.Function`. The semantics of ``backward_fn`` are the
  934. same as :meth:`torch.autograd.Function.backward`.
  935. ``setup_context(ctx, inputs, output)`` runs during the forward pass.
  936. Please save quantities needed for backward onto the ``ctx`` object via
  937. either :meth:`torch.autograd.function.FunctionCtx.save_for_backward`
  938. or assigning them as attributes of ``ctx``. If your custom op has
  939. kwarg-only arguments, we expect the signature of ``setup_context``
  940. to be ``setup_context(ctx, inputs, keyword_only_inputs, output)``.
  941. Both ``setup_context_fn`` and ``backward_fn`` must be traceable. That is,
  942. they may not directly access :meth:`torch.Tensor.data_ptr` and they must
  943. not depend on or mutate global state. If you need a non-traceable backward,
  944. you can make it a separate custom_op that you call inside ``backward_fn``.
  945. If you need different autograd behavior on different devices, then we
  946. recommend creating two different custom operators, one for each device
  947. that needs different behavior, and switching between them at runtime.
  948. Examples:
  949. >>> import torch
  950. >>> import numpy as np
  951. >>> from torch import Tensor
  952. >>>
  953. >>> @torch.library.custom_op("mylib::numpy_sin", mutates_args=())
  954. >>> def numpy_sin(x: Tensor) -> Tensor:
  955. >>> x_np = x.cpu().numpy()
  956. >>> y_np = np.sin(x_np)
  957. >>> return torch.from_numpy(y_np).to(device=x.device)
  958. >>>
  959. >>> def setup_context(ctx, inputs, output) -> Tensor:
  960. >>> x, = inputs
  961. >>> ctx.save_for_backward(x)
  962. >>>
  963. >>> def backward(ctx, grad):
  964. >>> x, = ctx.saved_tensors
  965. >>> return grad * x.cos()
  966. >>>
  967. >>> torch.library.register_autograd(
  968. ... "mylib::numpy_sin", backward, setup_context=setup_context
  969. ... )
  970. >>>
  971. >>> x = torch.randn(3, requires_grad=True)
  972. >>> y = numpy_sin(x)
  973. >>> (grad_x,) = torch.autograd.grad(y, x, torch.ones_like(y))
  974. >>> assert torch.allclose(grad_x, x.cos())
  975. >>>
  976. >>> # Example with a keyword-only arg
  977. >>> @torch.library.custom_op("mylib::numpy_mul", mutates_args=())
  978. >>> def numpy_mul(x: Tensor, *, val: float) -> Tensor:
  979. >>> x_np = x.cpu().numpy()
  980. >>> y_np = x_np * val
  981. >>> return torch.from_numpy(y_np).to(device=x.device)
  982. >>>
  983. >>> def setup_context(ctx, inputs, keyword_only_inputs, output) -> Tensor:
  984. >>> ctx.val = keyword_only_inputs["val"]
  985. >>>
  986. >>> def backward(ctx, grad):
  987. >>> return grad * ctx.val
  988. >>>
  989. >>> torch.library.register_autograd(
  990. ... "mylib::numpy_mul", backward, setup_context=setup_context
  991. ... )
  992. >>>
  993. >>> x = torch.randn(3, requires_grad=True)
  994. >>> y = numpy_mul(x, val=3.14)
  995. >>> (grad_x,) = torch.autograd.grad(y, x, torch.ones_like(y))
  996. >>> assert torch.allclose(grad_x, torch.full_like(x, 3.14))
  997. """
  998. if not isinstance(
  999. op, (str, torch._ops.OpOverload, torch._library.custom_ops.CustomOpDef)
  1000. ):
  1001. raise ValueError(
  1002. f"register_autograd({op}): got unexpected type for op: {type(op)}"
  1003. )
  1004. if isinstance(op, torch._ops.OpOverload):
  1005. op = op._name
  1006. opdef = _maybe_get_opdef(op)
  1007. if opdef is not None:
  1008. opdef.register_autograd(backward, setup_context=setup_context)
  1009. return
  1010. assert isinstance(op, str)
  1011. qualname = op
  1012. op = torch._library.utils.lookup_op(qualname)
  1013. schema = op._schema
  1014. if not _library.utils.is_functional_schema(schema):
  1015. raise RuntimeError(
  1016. f"Cannot register autograd formula for non-functional operator "
  1017. f"{op} with schema {schema}. Please create "
  1018. f"a functional operator and register an autograd formula for that."
  1019. )
  1020. if _library.utils.has_kwarg_only_tensors(schema):
  1021. raise NotImplementedError(
  1022. f"register_autograd with kwarg-only Tensor args. In the original "
  1023. f"definition of the op, please make your tensors not kwarg-only. "
  1024. f"Got: {schema}"
  1025. )
  1026. info = _library.autograd.Info(backward, setup_context)
  1027. autograd_kernel = _library.autograd.make_autograd_impl(op, info)
  1028. namespace, opname = torch._library.utils.parse_namespace(qualname)
  1029. if lib is None:
  1030. lib = Library(namespace, "FRAGMENT")
  1031. _keep_alive.append(lib)
  1032. lib.impl(opname, autograd_kernel, "Autograd", with_keyset=True)
  1033. def register_torch_dispatch(
  1034. op: _op_identifier,
  1035. torch_dispatch_class: Any,
  1036. func: Optional[Callable] = None,
  1037. /,
  1038. *,
  1039. lib: Optional[Library] = None,
  1040. ):
  1041. r"""Registers a torch_dispatch rule for the given operator and ``torch_dispatch_class``.
  1042. This allows for open registration to specify the behavior between the operator
  1043. and the ``torch_dispatch_class`` without needing to modify the ``torch_dispatch_class``
  1044. or the operator directly.
  1045. The ``torch_dispatch_class`` is either a Tensor subclass with ``__torch_dispatch__`` or a
  1046. TorchDispatchMode.
  1047. If it is a Tensor subclass, we expect ``func`` to have the following signature:
  1048. ``(cls, func: OpOverload, types: Tuple[type, ...], args, kwargs) -> Any``
  1049. If it is a TorchDispatchMode, we expect ``func`` to have the following signature:
  1050. ``(mode, func: OpOverload, types: Tuple[type, ...], args, kwargs) -> Any``
  1051. ``args`` and ``kwargs`` will have been normalized the same way they are
  1052. in ``__torch_dispatch__`` (see :ref:`torch-dispatch-calling-convention`).
  1053. Examples:
  1054. >>> import torch
  1055. >>>
  1056. >>> @torch.library.custom_op("mylib::foo", mutates_args={})
  1057. >>> def foo(x: torch.Tensor) -> torch.Tensor:
  1058. >>> return x.clone()
  1059. >>>
  1060. >>> class MyMode(torch.utils._python_dispatch.TorchDispatchMode):
  1061. >>> def __torch_dispatch__(self, func, types, args=(), kwargs=None):
  1062. >>> return func(*args, **kwargs)
  1063. >>>
  1064. >>> @torch.library.register_torch_dispatch("mylib::foo", MyMode)
  1065. >>> def _(mode, func, types, args, kwargs):
  1066. >>> x, = args
  1067. >>> return x + 1
  1068. >>>
  1069. >>> x = torch.randn(3)
  1070. >>> y = foo(x)
  1071. >>> assert torch.allclose(y, x)
  1072. >>>
  1073. >>> with MyMode():
  1074. >>> y = foo(x)
  1075. >>> assert torch.allclose(y, x + 1)
  1076. """
  1077. if not isinstance(
  1078. op, (str, torch._ops.OpOverload, torch._library.custom_ops.CustomOpDef)
  1079. ):
  1080. raise ValueError(
  1081. f"register_torch_dispatch({op}): got unexpected type for op: {type(op)}"
  1082. )
  1083. if isinstance(op, torch._ops.OpOverload):
  1084. op = op._name
  1085. opdef = _maybe_get_opdef(op)
  1086. if opdef is not None:
  1087. return opdef.register_torch_dispatch(torch_dispatch_class, func)
  1088. assert isinstance(op, str)
  1089. def register(func):
  1090. namespace, op_name = torch._library.utils.parse_namespace(op)
  1091. if lib is None:
  1092. use_lib = Library(namespace, "FRAGMENT")
  1093. _keep_alive.append(use_lib)
  1094. else:
  1095. use_lib = lib
  1096. use_lib._register_torch_dispatch_rule(op_name, torch_dispatch_class, func)
  1097. return func
  1098. if func is None:
  1099. return register
  1100. else:
  1101. return register(func)
  1102. def register_vmap(
  1103. op: _op_identifier,
  1104. func: Optional[Callable] = None,
  1105. /,
  1106. *,
  1107. lib=None,
  1108. ):
  1109. r"""Register a vmap implementation to support :func:`torch.vmap` for this custom op.
  1110. This API may be used as a decorator (see examples).
  1111. In order for an operator to work with :func:`torch.vmap`, you may need to register a
  1112. vmap implementation in the following signature:
  1113. ``vmap_func(info, in_dims: Tuple[Optional[int]], *args, **kwargs)``,
  1114. where ``*args`` and ``**kwargs`` are the arguments and kwargs for ``op``.
  1115. We do not support kwarg-only Tensor args.
  1116. It specifies how do we compute the batched version of ``op`` given inputs with an additional
  1117. dimension (specified by ``in_dims``).
  1118. For each arg in ``args``, ``in_dims`` has a corresponding ``Optional[int]``. It is ``None``
  1119. if the arg is not a Tensor or if the arg is not being vmapped over, otherwise, it is an integer
  1120. specifying what dimension of the Tensor is being vmapped over.
  1121. ``info`` is a collection of additional metadata that may be helpful:
  1122. ``info.batch_size`` specifies the size of the dimension being vmapped over, while
  1123. ``info.randomness`` is the ``randomness`` option that was passed to :func:`torch.vmap`.
  1124. The return of the function ``func`` is a tuple of ``(output, out_dims)``. Similar to ``in_dims``,
  1125. ``out_dims`` should be of the same structure as ``output`` and contain one ``out_dim``
  1126. per output that specifies if the output has the vmapped dimension and what index it is in.
  1127. Examples:
  1128. >>> import torch
  1129. >>> import numpy as np
  1130. >>> from torch import Tensor
  1131. >>> from typing import Tuple
  1132. >>>
  1133. >>> def to_numpy(tensor):
  1134. >>> return tensor.cpu().numpy()
  1135. >>>
  1136. >>> lib = torch.library.Library("mylib", "FRAGMENT")
  1137. >>> @torch.library.custom_op("mylib::numpy_cube", mutates_args=())
  1138. >>> def numpy_cube(x: Tensor) -> Tuple[Tensor, Tensor]:
  1139. >>> x_np = to_numpy(x)
  1140. >>> dx = torch.tensor(3 * x_np ** 2, device=x.device)
  1141. >>> return torch.tensor(x_np ** 3, device=x.device), dx
  1142. >>>
  1143. >>> def numpy_cube_vmap(info, in_dims, x):
  1144. >>> result = numpy_cube(x)
  1145. >>> return result, (in_dims[0], in_dims[0])
  1146. >>>
  1147. >>> torch.library.register_vmap(numpy_cube, numpy_cube_vmap)
  1148. >>>
  1149. >>> x = torch.randn(3)
  1150. >>> torch.vmap(numpy_cube)(x)
  1151. >>>
  1152. >>> @torch.library.custom_op("mylib::numpy_mul", mutates_args=())
  1153. >>> def numpy_mul(x: Tensor, y: Tensor) -> Tensor:
  1154. >>> return torch.tensor(to_numpy(x) * to_numpy(y), device=x.device)
  1155. >>>
  1156. >>> @torch.library.register_vmap("mylib::numpy_mul")
  1157. >>> def numpy_mul_vmap(info, in_dims, x, y):
  1158. >>> x_bdim, y_bdim = in_dims
  1159. >>> x = x.movedim(x_bdim, -1) if x_bdim is not None else x.unsqueeze(-1)
  1160. >>> y = y.movedim(y_bdim, -1) if y_bdim is not None else y.unsqueeze(-1)
  1161. >>> result = x * y
  1162. >>> result = result.movedim(-1, 0)
  1163. >>> return result, 0
  1164. >>>
  1165. >>>
  1166. >>> x = torch.randn(3)
  1167. >>> y = torch.randn(3)
  1168. >>> torch.vmap(numpy_mul)(x, y)
  1169. .. note::
  1170. The vmap function should aim to preserve the semantics of the entire custom operator.
  1171. That is, ``grad(vmap(op))`` should be replaceable with a ``grad(map(op))``.
  1172. If your custom operator has any custom behavior in the backward pass, please
  1173. keep this in mind.
  1174. """
  1175. if not isinstance(
  1176. op, (str, torch._ops.OpOverload, torch._library.custom_ops.CustomOpDef)
  1177. ):
  1178. raise ValueError(f"register_vmap({op}): got unexpected type for op: {type(op)}")
  1179. if isinstance(op, torch._ops.OpOverload):
  1180. op = op._name
  1181. opdef = _maybe_get_opdef(op)
  1182. if opdef is not None:
  1183. return opdef.register_vmap(func)
  1184. assert isinstance(op, str)
  1185. qualname = op
  1186. op = torch._library.utils.lookup_op(qualname)
  1187. schema = op._schema
  1188. if _library.utils.has_kwarg_only_tensors(schema):
  1189. raise NotImplementedError(
  1190. f"register_vmap with kwarg-only Tensor args. In the original "
  1191. f"definition of the op, please make your tensors not kwarg-only. "
  1192. f"Got: {schema}"
  1193. )
  1194. def register(func):
  1195. nonlocal op, lib
  1196. namespace, opname = torch._library.utils.parse_namespace(qualname)
  1197. if lib is None:
  1198. lib = Library(namespace, "FRAGMENT")
  1199. _keep_alive.append(lib)
  1200. from torch._functorch.autograd_function import custom_function_call_vmap_helper
  1201. from torch._functorch.pyfunctorch import retrieve_current_functorch_interpreter
  1202. def wrapped_func(keyset, *args, **kwargs):
  1203. interpreter = retrieve_current_functorch_interpreter()
  1204. return custom_function_call_vmap_helper(
  1205. interpreter, func, op, *args, **kwargs
  1206. )
  1207. lib.impl(opname, wrapped_func, "FuncTorchBatched", with_keyset=True)
  1208. if func is None:
  1209. return register
  1210. else:
  1211. return register(func)
  1212. # If the op was defined in C++, then we want to make sure there was an
  1213. # m.set_python_module(module, ...) call and that the module is the
  1214. # same as the module that called torch.library.register_fake.
  1215. def _check_pystubs_once(func, qualname, actual_module_name):
  1216. checked = False
  1217. def inner(*args, **kwargs):
  1218. nonlocal checked
  1219. if checked:
  1220. return func(*args, **kwargs)
  1221. op = torch._library.utils.lookup_op(qualname)
  1222. if op._defined_in_python:
  1223. checked = True
  1224. return func(*args, **kwargs)
  1225. maybe_pystub = torch._C._dispatch_pystub(
  1226. op._schema.name, op._schema.overload_name
  1227. )
  1228. if maybe_pystub is None:
  1229. if torch._library.utils.requires_set_python_module():
  1230. namespace = op.namespace
  1231. cpp_filename = op._handle.debug()
  1232. raise RuntimeError(
  1233. f"Operator '{qualname}' was defined in C++ and has a Python "
  1234. f"fake impl. In this situation, we require there to also be a "
  1235. f'companion C++ `m.set_python_module("{actual_module_name}")` '
  1236. f"call, but we could not find one. Please add that to "
  1237. f"to the top of the C++ TORCH_LIBRARY({namespace}, ...) block the "
  1238. f"operator was registered in ({cpp_filename})"
  1239. )
  1240. else:
  1241. pystub_module = maybe_pystub[0]
  1242. if actual_module_name != pystub_module:
  1243. cpp_filename = op._handle.debug()
  1244. raise RuntimeError(
  1245. f"Operator '{qualname}' specified that its python fake impl "
  1246. f"is in the Python module '{pystub_module}' but it was actually found "
  1247. f"in '{actual_module_name}'. Please either move the fake impl "
  1248. f"or correct the m.set_python_module call ({cpp_filename})"
  1249. )
  1250. checked = True
  1251. return func(*args, **kwargs)
  1252. return inner
  1253. # NOTE [ctx inside the fake implementation]
  1254. # If a user has an operator with data-dependent output shape, then when writing
  1255. # a fake implementation they must query the current ctx and use methods on the
  1256. # ctx to construct a new unbacked symint.
  1257. #
  1258. # This is done via us setting the global_ctx_getter function every time a fake
  1259. # implementation is invoked.
  1260. def get_ctx() -> "torch._library.fake_impl.FakeImplCtx":
  1261. """get_ctx() returns the current AbstractImplCtx object.
  1262. Calling ``get_ctx()`` is only valid inside of an fake impl
  1263. (see :func:`torch.library.register_fake` for more usage details.
  1264. """
  1265. return torch._library.fake_impl.global_ctx_getter()
  1266. def get_kernel(
  1267. op: _op_identifier, dispatch_key: Union[str, torch.DispatchKey]
  1268. ) -> torch._C._SafeKernelFunction:
  1269. """Returns the computed kernel for a given operator and dispatch key.
  1270. This function retrieves the kernel that would be executed for a given
  1271. operator and dispatch key combination. The returned SafeKernelFunction
  1272. can be used to call the kernel in a boxed fashion. The intended use
  1273. case for this function is to retrieve the original kernel for a given
  1274. dispatch key and then register another kernel to the same dispatch key
  1275. that calls into the original kernel for certain cases.
  1276. Args:
  1277. op: Operator name (along with the overload) or OpOverload object
  1278. Can be a string (e.g., "aten::add.Tensor"), an OpOverload, or a CustomOpDef.
  1279. dispatch_key (str | torch.DispatchKey): The dispatch key to get the kernel for.
  1280. Can be a string (e.g., "CPU", "CUDA") or a DispatchKey enum value.
  1281. Returns:
  1282. torch._C._SafeKernelFunction: A safe kernel function that can be used to
  1283. call the kernel.
  1284. Raises:
  1285. RuntimeError: If the operator does not exist.
  1286. Example:
  1287. >>> # Get the CPU kernel for torch.add
  1288. >>> kernel = torch.library.get_kernel("aten::add.Tensor", "CPU")
  1289. >>>
  1290. >>> # You can also use DispatchKey enum
  1291. >>> kernel = torch.library.get_kernel("aten::add.Tensor", torch.DispatchKey.CPU)
  1292. >>>
  1293. >>> # Or use an OpOverload directly
  1294. >>> kernel = torch.library.get_kernel(torch.ops.aten.add.Tensor, "CPU")
  1295. >>>
  1296. >>> # Example: Using get_kernel in a custom op with conditional dispatch
  1297. >>> # Get the original kernel for torch.sin
  1298. >>> original_sin_kernel = torch.library.get_kernel("aten::sin", "CPU")
  1299. >>>
  1300. >>> # If input has negative values, use original sin, otherwise return zeros
  1301. >>> def conditional_sin_impl(dispatch_keys, x):
  1302. >>> if (x < 0).any():
  1303. >>> return original_sin_kernel.call_boxed(dispatch_keys, x)
  1304. >>> else:
  1305. >>> return torch.zeros_like(x)
  1306. >>>
  1307. >>> lib = torch.library.Library("aten", "IMPL")
  1308. >>> # with_keyset=True so the first argument to the impl is the current DispatchKeySet
  1309. >>> which needs to be the first argument to ``kernel.call_boxed``
  1310. >>> lib.impl("sin", conditional_sin_impl, "CPU", with_keyset=True)
  1311. >>>
  1312. >>> # Test the conditional behavior
  1313. >>> x_positive = torch.tensor([1.0, 2.0])
  1314. >>> x_mixed = torch.tensor([-1.0, 2.0])
  1315. >>> torch.sin(x_positive)
  1316. tensor([0., 0.])
  1317. >>> torch.sin(x_mixed)
  1318. tensor([-0.8415, 0.9093])
  1319. """
  1320. if not isinstance(op, (str, torch._ops.OpOverload)):
  1321. raise ValueError(f"get_kernel({op}): got unexpected type for op: {type(op)}")
  1322. if isinstance(op, torch._ops.OpOverload):
  1323. op = op._name
  1324. if isinstance(dispatch_key, str):
  1325. try:
  1326. dispatch_key = torch._C.DispatchKey.__members__[dispatch_key]
  1327. except KeyError:
  1328. raise ValueError(f"Invalid dispatch key: {dispatch_key}") from None
  1329. return torch._C._dispatch_get_computed_kernel_for_dispatch_key(op, dispatch_key)
  1330. _OPCHECK_DEFAULT_UTILS = (
  1331. "test_schema",
  1332. "test_autograd_registration",
  1333. "test_faketensor",
  1334. "test_aot_dispatch_dynamic",
  1335. )
  1336. def opcheck(
  1337. op: Union[torch._ops.OpOverload, torch._ops.OpOverloadPacket, CustomOpDef],
  1338. args: tuple[Any, ...],
  1339. kwargs: Optional[dict[str, Any]] = None,
  1340. *,
  1341. test_utils: Union[str, Sequence[str]] = _OPCHECK_DEFAULT_UTILS,
  1342. raise_exception: bool = True,
  1343. atol=None,
  1344. rtol=None,
  1345. ) -> dict[str, str]:
  1346. """Given an operator and some sample arguments, tests if the operator is
  1347. registered correctly.
  1348. That is, when you use the torch.library/TORCH_LIBRARY APIs to create a
  1349. custom op, you specified metadata (e.g. mutability info) about the custom op
  1350. and these APIs require that the functions you pass them satisfy certain
  1351. properties (e.g. no data pointer access in the fake/meta/abstract kernel)
  1352. ``opcheck`` tests these metadata and properties.
  1353. Concretely, we test the following:
  1354. - test_schema: If the schema matches the implementation of
  1355. the operator. For example: if the schema specifies a Tensor is mutated,
  1356. then we check the implementation mutates the Tensor. If the schema
  1357. specifies that we return a new Tensor, then we check that the
  1358. implementation returns a new Tensor (instead of an existing one or
  1359. a view of an existing one).
  1360. - test_autograd_registration: If the operator supports training
  1361. (autograd): we check that its autograd formula is registered via
  1362. torch.library.register_autograd or a manual registration to one
  1363. or more DispatchKey::Autograd keys. Any other DispatchKey-based
  1364. registrations may lead to undefined behavior.
  1365. - test_faketensor: If the operator has a FakeTensor kernel
  1366. (and if it is correct). The FakeTensor kernel is necessary (
  1367. but not sufficient) for the operator to work with PyTorch compilation
  1368. APIs (torch.compile/export/FX). We check that a FakeTensor kernel
  1369. (also sometimes known as a meta kernel) was registered for the
  1370. operator and that it is correct. This test takes the result of
  1371. running the operator on real tensors and the result of running
  1372. the operator on FakeTensors and checks that they have the same
  1373. Tensor metadata (sizes/strides/dtype/device/etc).
  1374. - test_aot_dispatch_dynamic: If the operator has correct behavior
  1375. with PyTorch compilation APIs (torch.compile/export/FX).
  1376. This checks that the outputs (and gradients, if applicable) are the
  1377. same under eager-mode PyTorch and torch.compile.
  1378. This test is a superset of ``test_faketensor`` and is an e2e test;
  1379. other things it tests are that the operator supports
  1380. functionalization and that the backward pass (if it exists) also
  1381. supports FakeTensor and functionalization.
  1382. For best results, please call ``opcheck`` multiple times with a
  1383. representative set of inputs. If your operator supports
  1384. autograd, please use ``opcheck`` with inputs with ``requires_grad = True``;
  1385. if your operator supports multiple devices (e.g. CPU and CUDA), please
  1386. use ``opcheck`` with inputs on all supported devices.
  1387. Args:
  1388. op: The operator. Must either be a function decorated with
  1389. :func:`torch.library.custom_op` or an OpOverload/OpOverloadPacket
  1390. found in torch.ops.* (e.g. torch.ops.aten.sin, torch.ops.mylib.foo)
  1391. args: The args to the operator
  1392. kwargs: The kwargs to the operator
  1393. test_utils: Tests that we should run. Default: all of them.
  1394. Example: ("test_schema", "test_faketensor")
  1395. raise_exception: If we should raise an exception on the first
  1396. error. If False, we will return a dict with information
  1397. on if each test passed or not.
  1398. rtol (Optional[float]): Relative tolerance for floating point comparisons.
  1399. If specified ``atol`` must also be specified.
  1400. If omitted, default values based on the ``dtype`` are selected
  1401. (see the table in :func:`torch.testing.assert_close`).
  1402. atol (Optional[float]): Absolute tolerance for floating point comparisons.
  1403. If specified ``rtol`` must also be specified.
  1404. If omitted, default values based on the ``dtype`` are selected
  1405. (see the table in :func:`torch.testing.assert_close`).
  1406. .. warning::
  1407. opcheck and :func:`torch.autograd.gradcheck` test different things;
  1408. opcheck tests if your usage of torch.library APIs is correct while
  1409. :func:`torch.autograd.gradcheck` tests if your autograd formula is
  1410. mathematically correct. Use both to test custom ops that support
  1411. gradient computation.
  1412. Example:
  1413. >>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_CUDA)
  1414. >>> @torch.library.custom_op("mylib::numpy_mul", mutates_args=())
  1415. >>> def numpy_mul(x: Tensor, y: float) -> Tensor:
  1416. >>> x_np = x.numpy(force=True)
  1417. >>> z_np = x_np * y
  1418. >>> return torch.from_numpy(z_np).to(x.device)
  1419. >>>
  1420. >>> @numpy_mul.register_fake
  1421. >>> def _(x, y):
  1422. >>> return torch.empty_like(x)
  1423. >>>
  1424. >>> def setup_context(ctx, inputs, output):
  1425. >>> y, = inputs
  1426. >>> ctx.y = y
  1427. >>>
  1428. >>> def backward(ctx, grad):
  1429. >>> return grad * ctx.y, None
  1430. >>>
  1431. >>> numpy_mul.register_autograd(backward, setup_context=setup_context)
  1432. >>>
  1433. >>> sample_inputs = [
  1434. >>> (torch.randn(3), 3.14),
  1435. >>> (torch.randn(2, 3, device='cuda'), 2.718),
  1436. >>> (torch.randn(1, 10, requires_grad=True), 1.234),
  1437. >>> (torch.randn(64, 64, device='cuda', requires_grad=True), 90.18),
  1438. >>> ]
  1439. >>>
  1440. >>> for args in sample_inputs:
  1441. >>> torch.library.opcheck(numpy_mul, args)
  1442. """
  1443. import torch.testing._internal.optests as optests
  1444. return optests.opcheck(
  1445. op,
  1446. args,
  1447. kwargs,
  1448. test_utils=test_utils,
  1449. raise_exception=raise_exception,
  1450. rtol=rtol,
  1451. atol=atol,
  1452. )