library.py 66 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156115711581159116011611162116311641165116611671168116911701171117211731174117511761177117811791180118111821183118411851186118711881189119011911192119311941195119611971198119912001201120212031204120512061207120812091210121112121213121412151216121712181219122012211222122312241225122612271228122912301231123212331234123512361237123812391240124112421243124412451246124712481249125012511252125312541255125612571258125912601261126212631264126512661267126812691270127112721273127412751276127
7127812791280128112821283128412851286128712881289129012911292129312941295129612971298129913001301130213031304130513061307130813091310131113121313131413151316131713181319132013211322132313241325132613271328132913301331133213331334133513361337133813391340134113421343134413451346134713481349135013511352135313541355135613571358135913601361136213631364136513661367136813691370137113721373137413751376137713781379138013811382138313841385138613871388138913901391139213931394139513961397139813991400140114021403140414051406140714081409141014111412141314141415141614171418141914201421142214231424142514261427142814291430143114321433143414351436143714381439144014411442144314441445144614471448144914501451145214531454145514561457145814591460146114621463146414651466146714681469147014711472147314741475147614771478147914801481148214831484148514861487148814891490149114921493149414951496149714981499150015011502150315041505150615071508150915101511151215131514151515161517151815191520152115221523152415251526152715281529153015311532153315341535153615371538153915401541154215431544154515461547154815491550155115521553155415551556155715581559156015611562156315641565156615671568156915701571157215731574157515761577157815791580158115821583158415851586158715881589159015911592159315941595159615971598159916001601160216031604160516061607160816091610161116121613161416151616161716181619162016211622162316241625162616271628162916301631163216331634163516361637163816391640164116421643164416451646164716481649165016511652165316541655165616571658165916601661166216631664166516661667166816691670167116721673167416751676167716781679168016811682168316841685168616871688168916901691169216931694169516961697169816991700170117021703170417051706170717081709171017111712171317141715171617171718171917201721172217231724172517261727172817291730173117321733173417351736
  1. # mypy: allow-untyped-defs
  2. import contextlib
  3. import functools
  4. import inspect
  5. import re
  6. import sys
  7. import traceback
  8. import weakref
  9. from collections.abc import Callable, Sequence
  10. from typing import Any, overload, TYPE_CHECKING, TypeVar, Union
  11. from typing_extensions import deprecated, ParamSpec
  12. import torch
  13. import torch._library as _library
  14. from torch._library.custom_ops import (
  15. _cast,
  16. _maybe_get_opdef,
  17. custom_op,
  18. CustomOpDef,
  19. device_types_t,
  20. )
  21. from torch._library.effects import EffectType
  22. from torch._library.infer_schema import infer_schema # noqa: F401
  23. from torch._library.triton import triton_op, wrap_triton
  24. from torch._ops import OpOverload
  25. from torch.types import _dtype
# Public API of this module.
__all__ = [
    "Library",
    "impl",
    "define",
    "fallthrough_kernel",
    "impl_abstract",
    "register_autocast",
    "register_fake",
    "register_torch_dispatch",
    "register_vmap",
    "get_ctx",
    "get_kernel",
    "custom_op",
    "triton_op",
    "wrap_triton",
    "infer_schema",
]

# Generic helpers used to type decorators that return the decorated callable unchanged.
_T = TypeVar("_T")
_P = ParamSpec("_P")

# Set containing the combination of (namespace, operator, DispatchKey) for which a new kernel has been registered.
# The keys in the set are of the form `namespace + "/" + op_name + "/" + dispatch_key`.
# This set is maintained to ensure that two libraries don't try to override the exact same functionality, to avoid
# libraries calling into kernels not intended to be called.
_impls: set[str] = set()
# Qualified names (`namespace + "::" + op_name`) of the operators defined through `Library.define`.
_defs: set[str] = set()

# prim is reserved by the TorchScript interpreter
_reserved_namespaces = ["prim"]
  53. def fallthrough_kernel():
  54. """
  55. A dummy function to pass to ``Library.impl`` in order to register a fallthrough.
  56. """
  57. raise NotImplementedError("fallthrough_kernel() should never be called.")
  58. class Library:
  59. """
  60. A class to create libraries that can be used to register new operators or
  61. override operators in existing libraries from Python.
  62. A user can optionally pass in a dispatch keyname if they only want to register
  63. kernels corresponding to only one specific dispatch key.
  64. To create a library to override operators in an existing library (with name ns), set the kind to "IMPL".
  65. To create a new library (with name ns) to register new operators, set the kind to "DEF".
  66. To create a fragment of a possibly existing library to register operators (and bypass
  67. the limitation that there is only one library for a given namespace), set the kind to
  68. "FRAGMENT".
  69. Args:
  70. ns: library name
  71. kind: "DEF", "IMPL", "FRAGMENT"
  72. dispatch_key: PyTorch dispatch key (default: "")
  73. """
  74. def __init__(self, ns, kind, dispatch_key=""):
  75. from torch.fx.operator_schemas import _SCHEMA_TO_SIGNATURE_CACHE
  76. if kind not in ("IMPL", "DEF", "FRAGMENT"):
  77. raise ValueError("Unsupported kind: ", kind)
  78. if ns in _reserved_namespaces and (kind == "DEF" or kind == "FRAGMENT"):
  79. raise ValueError(
  80. ns,
  81. " is a reserved namespace. Please try creating a library with another name.",
  82. )
  83. frame = traceback.extract_stack(limit=2)[0]
  84. filename, lineno = frame.filename, frame.lineno
  85. self.m: Any | None = torch._C._dispatch_library(
  86. kind, ns, dispatch_key, filename, lineno
  87. )
  88. self.ns = ns
  89. self._op_defs: set[str] = set()
  90. self._op_impls: set[str] = set()
  91. self._registration_handles: list[torch._library.utils.RegistrationHandle] = []
  92. self.kind = kind
  93. self.dispatch_key = dispatch_key
  94. # Use a finalizer to setup the "destructor" instead of __del__.
  95. # Python __del__ can lead to weird things (globals and locals may already
  96. # be gone when __del__ actually gets called!). finalizers help the
  97. # situation because it lets us capture references and keeps them alive
  98. weakref.finalize(
  99. self,
  100. _del_library,
  101. _impls,
  102. self._op_impls,
  103. _defs,
  104. self._op_defs,
  105. self._registration_handles,
  106. self.m,
  107. _SCHEMA_TO_SIGNATURE_CACHE,
  108. )
  109. def __repr__(self):
  110. return f"Library(kind={self.kind}, ns={self.ns}, dispatch_key={self.dispatch_key})>"
  111. def define(self, schema, alias_analysis="", *, tags=()):
  112. r"""Defines a new operator and its semantics in the ns namespace.
  113. Args:
  114. schema: function schema to define a new operator.
  115. alias_analysis (optional): Indicates if the aliasing properties of the operator arguments can be
  116. inferred from the schema (default behavior) or not ("CONSERVATIVE").
  117. tags (Tag | Sequence[Tag]): one or more torch.Tag to apply to this
  118. operator. Tagging an operator changes the operator's behavior
  119. under various PyTorch subsystems; please read the docs for the
  120. torch.Tag carefully before applying it.
  121. Returns:
  122. name of the operator as inferred from the schema.
  123. Example::
  124. >>> my_lib = Library("mylib", "DEF")
  125. >>> my_lib.define("sum(Tensor self) -> Tensor")
  126. """
  127. # This is added because we also want to disallow PURE_FUNCTION alias analysis which is a valid
  128. # AliasAnalysis type in C++
  129. if alias_analysis not in ["", "FROM_SCHEMA", "CONSERVATIVE"]:
  130. raise RuntimeError(f"Invalid alias_analysis type {alias_analysis}")
  131. assert self.m is not None
  132. if isinstance(tags, torch.Tag):
  133. tags = (tags,)
  134. name = schema.split("(")[0]
  135. packet_name = name.split(".")[0] if "." in name else name
  136. has_preexisting_packet = hasattr(torch.ops, self.ns) and hasattr(
  137. getattr(torch.ops, self.ns), packet_name
  138. )
  139. result = self.m.define(schema, alias_analysis, tuple(tags))
  140. name = schema.split("(")[0]
  141. qualname = self.ns + "::" + name
  142. # If the OpOverloadPacket exists already, then this means we're adding a
  143. # new OpOverload for it. Refresh the packet to include the new OpOverload.
  144. if has_preexisting_packet:
  145. ns = getattr(torch.ops, self.ns)
  146. packet = getattr(ns, packet_name)
  147. torch._ops._refresh_packet(packet)
  148. self._op_defs.add(qualname)
  149. _defs.add(qualname)
  150. return result
  151. def _register_fake(self, op_name, fn, _stacklevel=1, *, allow_override=False):
  152. r"""Registers the fake impl for an operator defined in the library."""
  153. source = torch._library.utils.get_source(_stacklevel + 1)
  154. frame = sys._getframe(_stacklevel)
  155. caller_module = inspect.getmodule(frame)
  156. # Can be none if you call register_fake from somewhere there isn't a module
  157. # (e.g. __main__)
  158. caller_module_name = None if caller_module is None else caller_module.__name__
  159. # TODO(rzou): We're gonna need to stage this change with torchvision,
  160. # since torchvision is github first.
  161. if caller_module_name is not None and caller_module_name.startswith(
  162. "torchvision."
  163. ):
  164. caller_module_name = None
  165. qualname = f"{self.ns}::{op_name}"
  166. entry = torch._library.simple_registry.singleton.find(qualname)
  167. if caller_module_name is not None:
  168. func_to_register = _check_pystubs_once(fn, qualname, caller_module_name)
  169. else:
  170. func_to_register = fn
  171. handle = entry.fake_impl.register(
  172. func_to_register, source, lib=self, allow_override=allow_override
  173. )
  174. self._registration_handles.append(handle)
  175. def _register_torch_dispatch_rule(self, op_name, torch_dispatch_class, fn):
  176. r"""Registers a torch_dispatch rule for the given operator and torch_dispatch_class.
  177. This allows for open registration to specify the behavior between the operator
  178. and the torch_dispatch_class without needing to modify the torch_dispatch_class
  179. or the operator directly.
  180. The torch_dispatch_class is either a Tensor subclass with `__torch_dispatch__` or a
  181. TorchDispatchMode.
  182. If it is a Tensor subclass, we expect fn to have the following signature:
  183. (cls, func: OpOverload, types: Tuple[type, ...], args, kwargs) -> Any
  184. If it is a TorchDispatchMode, we expect fn to have the following signature:
  185. (mode, func: OpOverload, types: Tuple[type, ...], args, kwargs) -> Any
  186. """
  187. qualname = f"{self.ns}::{op_name}"
  188. entry = torch._library.simple_registry.singleton.find(qualname)
  189. handle = entry.torch_dispatch_rules.register(torch_dispatch_class, fn)
  190. self._registration_handles.append(handle)
  191. def _impl_with_aoti_compile(self, op_name, dispatch_key=""):
  192. r"""Register the operator to use the AOTI-compiled implementation.
  193. Args:
  194. op_name: operator name (along with the overload) or OpOverload object.
  195. dispatch_key: dispatch key that the input function should be registered for. By default, it uses
  196. the dispatch key that the library was created with.
  197. Example::
  198. >>> my_lib = Library("aten", "IMPL")
  199. >>> my_lib._impl_with_aoti_compile("div.Tensor", "CPU")
  200. """
  201. if dispatch_key == "":
  202. dispatch_key = self.dispatch_key
  203. # pyrefly: ignore [bad-argument-type]
  204. assert torch.DispatchKeySet(dispatch_key).has(torch._C.DispatchKey.Dense)
  205. if isinstance(op_name, str):
  206. name = op_name
  207. elif isinstance(op_name, OpOverload):
  208. name = op_name._schema.name
  209. overload_name = op_name._schema.overload_name
  210. if overload_name != "":
  211. name = name + "." + overload_name
  212. else:
  213. raise RuntimeError(
  214. "_impl_with_aoti_compile should be passed either a name or an OpOverload object "
  215. "as the first argument"
  216. )
  217. key = self.ns + "/" + name.split("::")[-1] + "/" + dispatch_key
  218. if key in _impls:
  219. # TODO: in future, add more info about where the existing function is registered (this info is
  220. # today already returned by the C++ warning when _impl_with_aoti_compile is called but we error out before that)
  221. raise RuntimeError(
  222. "This is not allowed since there's already a kernel registered from python overriding {}"
  223. "'s behavior for {} dispatch key and {} namespace.".format(
  224. name.split("::")[-1], dispatch_key, self.ns
  225. )
  226. )
  227. assert self.m is not None
  228. impl_fn: Callable = self.m.impl_with_aoti_compile
  229. impl_fn(self.ns, name.split("::")[-1], dispatch_key)
  230. _impls.add(key)
  231. self._op_impls.add(key)
  232. def impl(
  233. self, op_name, fn, dispatch_key="", *, with_keyset=False, allow_override=False
  234. ):
  235. r"""Registers the function implementation for an operator defined in the library.
  236. Args:
  237. op_name: operator name (along with the overload) or OpOverload object.
  238. fn: function that's the operator implementation for the input dispatch key or :func:`~fallthrough_kernel`
  239. to register a fallthrough.
  240. dispatch_key: dispatch key that the input function should be registered for. By default, it uses
  241. the dispatch key that the library was created with.
  242. with_keyset: flag controlling if the current dispatcher call keyset should be passed as the first argument
  243. to :attr:`fn` when calling. This should be used to create the appropriate keyset for redispatch calls.
  244. allow_override: Flag controlling if we want to override an
  245. existing registered kernel implementation. This is by
  246. default off, and will error you're trying to register a
  247. kernel to a dispatch key with a kernel already
  248. registered.
  249. Example::
  250. >>> my_lib = Library("aten", "IMPL")
  251. >>> def div_cpu(self, other):
  252. >>> return self * (1 / other)
  253. >>> my_lib.impl("div.Tensor", div_cpu, "CPU")
  254. """
  255. if not callable(fn):
  256. raise TypeError(
  257. f"Input function is required to be a callable but found type {type(fn)}"
  258. )
  259. if dispatch_key == "":
  260. dispatch_key = self.dispatch_key
  261. if isinstance(op_name, str):
  262. name = op_name
  263. elif isinstance(op_name, OpOverload):
  264. name = op_name._schema.name
  265. overload_name = op_name._schema.overload_name
  266. if overload_name != "":
  267. name = name + "." + overload_name
  268. else:
  269. raise RuntimeError(
  270. "impl should be passed either a name or an OpOverload object as the first argument"
  271. )
  272. key = self.ns + "/" + name.split("::")[-1] + "/" + dispatch_key
  273. if (not allow_override) and key in _impls:
  274. # TODO: in future, add more info about where the existing function is registered (this info is
  275. # today already returned by the C++ warning when impl is called but we error out before that)
  276. raise RuntimeError(
  277. "This is not allowed since there's already a kernel registered from python overriding {}"
  278. "'s behavior for {} dispatch key and {} namespace.".format(
  279. name.split("::")[-1], dispatch_key, self.ns
  280. )
  281. )
  282. if dispatch_key == "Meta":
  283. dispatcher_op_name = name
  284. if "::" not in dispatcher_op_name:
  285. dispatcher_op_name = f"{self.ns}::{dispatcher_op_name}"
  286. # Internally, we shouldn't be registering meta kernels for any operators that
  287. # have CompositeImplicitAutograd kernels.
  288. # Instead, we should be letting those decompositions run, and writing meta kernels
  289. # only for the base operators.
  290. if torch._C._dispatch_has_kernel_for_dispatch_key(
  291. dispatcher_op_name, "CompositeImplicitAutograd"
  292. ):
  293. raise RuntimeError(
  294. f"We should not register a meta kernel directly to the operator '{name}',"
  295. " because it has a CompositeImplicitAutograd kernel in core."
  296. " Instead we should let the operator decompose, and ensure that we have meta kernels"
  297. " for the base ops that it decomposes into."
  298. )
  299. assert self.m is not None
  300. self.m.impl(
  301. name,
  302. dispatch_key if dispatch_key != "" else "CompositeImplicitAutograd",
  303. fn,
  304. with_keyset,
  305. )
  306. _impls.add(key)
  307. self._op_impls.add(key)
  308. def fallback(self, fn, dispatch_key="", *, with_keyset=False):
  309. r"""Registers the function implementation as the fallback for the given key.
  310. This function only works for a library with global namespace ("_").
  311. Args:
  312. fn: function used as fallback for the given dispatch key or :func:`~fallthrough_kernel`
  313. to register a fallthrough.
  314. dispatch_key: dispatch key that the input function should be registered for. By default, it uses
  315. the dispatch key that the library was created with.
  316. with_keyset: flag controlling if the current dispatcher call keyset should be passed as the first argument
  317. to :attr:`fn` when calling. This should be used to create the appropriate keyset for redispatch calls.
  318. Example::
  319. >>> my_lib = Library("_", "IMPL")
  320. >>> def fallback_kernel(op, *args, **kwargs):
  321. >>> # Handle all autocast ops generically
  322. >>> # ...
  323. >>> my_lib.fallback(fallback_kernel, "Autocast")
  324. """
  325. if dispatch_key == "":
  326. dispatch_key = self.dispatch_key
  327. if self.ns != "_":
  328. raise RuntimeError(
  329. f"""Fallback can only be registered using library fragment on the global namespace "_" but it is {self.ns}"""
  330. )
  331. assert dispatch_key != ""
  332. assert self.m is not None
  333. self.m.fallback(dispatch_key, fn, with_keyset)
  334. def _register_effectful_op(self, op_name: str, effect: EffectType | None):
  335. """
  336. Registers an effect to an operator. This is used to register an op that
  337. has side effects that is not capturable by the schema.
  338. Args:
  339. op_name: operator name (along with the overload) or OpOverload object.
  340. effect: The effect of the op.
  341. """
  342. from torch._higher_order_ops.effects import (
  343. _register_effectful_op as hoo_register_effect,
  344. )
  345. handle = hoo_register_effect(op_name, effect)
  346. self._registration_handles.append(handle)
  347. def _destroy(self):
  348. if self.m is not None:
  349. self.m.reset()
  350. self.m = None
  351. for handle in self._registration_handles:
  352. handle.destroy()
  353. self._registration_handles.clear()
  354. global _impls
  355. _impls -= self._op_impls
  356. for name in self._op_defs:
  357. # Delete the cached torch.ops.ns.foo if it was registered.
  358. # Otherwise, accessing it leads to a segfault.
  359. # It's possible that we only registered an overload in this Library
  360. # and another library owns an alive overload.
  361. # That's OK - the next time torch.ops.ns.foo gets called, it'll be
  362. # recomputed to point at the right collection of overloads.
  363. ns, name_with_overload = name.split("::")
  364. name = name_with_overload.split(".")[0]
  365. if not hasattr(torch.ops, ns):
  366. continue
  367. namespace = getattr(torch.ops, ns)
  368. if not hasattr(namespace, name):
  369. continue
  370. delattr(namespace, name)
  371. namespace._dir.remove(name)
  372. def _del_library(
  373. captured_impls,
  374. op_impls,
  375. captured_defs,
  376. op_defs,
  377. registration_handles,
  378. m,
  379. schema_to_signature_cache,
  380. ):
  381. for op_def in op_defs:
  382. name = op_def
  383. overload_name = ""
  384. if "." in op_def:
  385. name, overload_name = op_def.split(".")
  386. if (
  387. name,
  388. overload_name,
  389. ) in schema_to_signature_cache:
  390. del schema_to_signature_cache[(name, overload_name)]
  391. captured_impls -= op_impls
  392. captured_defs -= op_defs
  393. for handle in registration_handles:
  394. handle.destroy()
  395. if m is not None:
  396. m.reset()
  397. @contextlib.contextmanager
  398. def _scoped_library(*args, **kwargs):
  399. try:
  400. lib = Library(*args, **kwargs)
  401. yield lib
  402. finally:
  403. lib._destroy()
# Libraries created implicitly by the module-level registration APIs (when the
# caller passes no ``lib``); holding them here keeps them referenced for the
# lifetime of the module.
_keep_alive: list[Library] = []
  405. NAMELESS_SCHEMA = re.compile(r"\(.*\) -> .*")
  406. @functools.singledispatch
  407. def define(qualname, schema, *, lib=None, tags=()):
  408. r"""Defines a new operator.
  409. In PyTorch, defining an op (short for "operator") is a two step-process:
  410. - we need to define the op (by providing an operator name and schema)
  411. - we need to implement behavior for how the operator interacts with
  412. various PyTorch subsystems, like CPU/CUDA Tensors, Autograd, etc.
  413. This entrypoint defines the custom operator (the first step)
  414. you must then perform the second step by calling various
  415. ``impl_*`` APIs, like :func:`torch.library.impl` or
  416. :func:`torch.library.register_fake`.
  417. Args:
  418. qualname (str): The qualified name for the operator. Should be
  419. a string that looks like "namespace::name", e.g. "aten::sin".
  420. Operators in PyTorch need a namespace to
  421. avoid name collisions; a given operator may only be created once.
  422. If you are writing a Python library, we recommend the namespace to
  423. be the name of your top-level module.
  424. schema (str): The schema of the operator. E.g. "(Tensor x) -> Tensor"
  425. for an op that accepts one Tensor and returns one Tensor. It does
  426. not contain the operator name (that is passed in ``qualname``).
  427. lib (Optional[Library]): If provided, the lifetime of this operator
  428. will be tied to the lifetime of the Library object.
  429. tags (Tag | Sequence[Tag]): one or more torch.Tag to apply to this
  430. operator. Tagging an operator changes the operator's behavior
  431. under various PyTorch subsystems; please read the docs for the
  432. torch.Tag carefully before applying it.
  433. Example::
  434. >>> import torch
  435. >>> import numpy as np
  436. >>>
  437. >>> # Define the operator
  438. >>> torch.library.define("mylib::sin", "(Tensor x) -> Tensor")
  439. >>>
  440. >>> # Add implementations for the operator
  441. >>> @torch.library.impl("mylib::sin", "cpu")
  442. >>> def f(x):
  443. >>> return torch.from_numpy(np.sin(x.numpy()))
  444. >>>
  445. >>> # Call the new operator from torch.ops.
  446. >>> x = torch.randn(3)
  447. >>> y = torch.ops.mylib.sin(x)
  448. >>> assert torch.allclose(y, x.sin())
  449. """
  450. if not isinstance(qualname, str):
  451. raise ValueError(
  452. f"define(qualname, schema): expected qualname "
  453. f"to be instance of str, got {type(qualname)}"
  454. )
  455. namespace, name = torch._library.utils.parse_namespace(qualname)
  456. if lib is None:
  457. lib = Library(namespace, "FRAGMENT")
  458. _keep_alive.append(lib)
  459. if not NAMELESS_SCHEMA.fullmatch(schema):
  460. raise ValueError(
  461. f"define(qualname, schema, ...): expected schema "
  462. f'to look like e.g. "(Tensor x) -> Tensor" but '
  463. f'got "{schema}"'
  464. )
  465. lib.define(name + schema, alias_analysis="", tags=tags)
  466. @define.register
  467. def _(lib: Library, schema, alias_analysis=""):
  468. """The old torch.library.define.
  469. We're keeping this around for BC reasons
  470. """
  471. def wrap(f):
  472. name = lib.define(schema, alias_analysis)
  473. lib.impl(name, f)
  474. return f
  475. return wrap
# Decorator form: ``func`` omitted, a registering decorator is returned.
@overload
def impl(
    qualname: str,
    types: str | Sequence[str],
    func: None = None,
    *,
    lib: Library | None = None,
) -> Callable[[Callable[..., object]], None]: ...


# Direct form: ``func`` given, registration happens immediately.
@overload
def impl(
    qualname: str,
    types: str | Sequence[str],
    func: Callable[..., object],
    *,
    lib: Library | None = None,
) -> None: ...


# Deprecated BC API
@overload
def impl(
    lib: Library,
    name: str,
    dispatch_key: str = "",
) -> Callable[[Callable[_P, _T]], Callable[_P, _T]]: ...


@functools.singledispatch
def impl(
    qualname: str,
    types: str | Sequence[str],
    func: Callable[_P, _T] | None = None,
    *,
    lib: Library | None = None,
) -> object:
    """Register an implementation for a device type for this operator.

    You may pass "default" for ``types`` to register this implementation as the
    default implementation for ALL device types.
    Please only use this if the implementation truly supports all device types;
    for example, this is true if it is a composition of built-in PyTorch operators.

    This API may be used as a decorator. You can use nested decorators
    with this API provided they return a function and are placed inside
    this API (see Example 2).

    Some valid types are: "cpu", "cuda", "xla", "mps", "ipu", "xpu".

    Args:
        qualname (str): Should be a string that looks like "namespace::operator_name".
        types (str | Sequence[str]): The device types to register an impl to.
        func (Callable | None): The implementation to register. If omitted,
            a decorator that performs the registration is returned instead.
        lib (Optional[Library]): If provided, the lifetime of this registration
            will be tied to the lifetime of the Library object.

    Examples:
        >>> import torch
        >>> import numpy as np
        >>> # Example 1: Register function.
        >>> # Define the operator
        >>> torch.library.define("mylib::mysin", "(Tensor x) -> Tensor")
        >>>
        >>> # Add implementations for the cpu device
        >>> @torch.library.impl("mylib::mysin", "cpu")
        >>> def f(x):
        >>>     return torch.from_numpy(np.sin(x.numpy()))
        >>>
        >>> x = torch.randn(3)
        >>> y = torch.ops.mylib.mysin(x)
        >>> assert torch.allclose(y, x.sin())
        >>>
        >>> # Example 2: Register function with decorator.
        >>> def custom_decorator(func):
        >>>     def wrapper(*args, **kwargs):
        >>>         return func(*args, **kwargs) + 1
        >>>     return wrapper
        >>>
        >>> # Define the operator
        >>> torch.library.define("mylib::sin_plus_one", "(Tensor x) -> Tensor")
        >>>
        >>> # Add implementations for the operator
        >>> @torch.library.impl("mylib::sin_plus_one", "cpu")
        >>> @custom_decorator
        >>> def f(x):
        >>>     return torch.from_numpy(np.sin(x.numpy()))
        >>>
        >>> # Call the new operator from torch.ops.
        >>> x = torch.randn(3)
        >>>
        >>> y1 = torch.ops.mylib.sin_plus_one(x)
        >>> y2 = torch.sin(x) + 1
        >>> assert torch.allclose(y1, y2)
    """
    # The public entry point never wraps the kernel in a dynamo-disabled shim.
    return _impl(qualname, types, func, lib=lib, disable_dynamo=False)
  560. if not TYPE_CHECKING:
  561. @impl.register
  562. def _(
  563. lib: Library, name: str, dispatch_key: str = ""
  564. ) -> Callable[[Callable[_P, _T]], Callable[_P, _T]]:
  565. """Legacy torch.library.impl API. Kept around for BC"""
  566. def wrap(f: Callable[_P, _T]) -> Callable[_P, _T]:
  567. lib.impl(name, f, dispatch_key)
  568. return f
  569. return wrap
  570. @overload
  571. def _impl(
  572. qualname: str,
  573. types: str | Sequence[str],
  574. func: None = None,
  575. *,
  576. lib: Library | None = None,
  577. disable_dynamo: bool = False,
  578. ) -> Callable[[Callable[..., object]], None]: ...
  579. @overload
  580. def _impl(
  581. qualname: str,
  582. types: str | Sequence[str],
  583. func: Callable[..., object],
  584. *,
  585. lib: Library | None = None,
  586. disable_dynamo: bool = False,
  587. ) -> None: ...
  588. def _impl(
  589. qualname: str,
  590. types: str | Sequence[str],
  591. func: Callable[..., object] | None = None,
  592. *,
  593. lib: Library | None = None,
  594. disable_dynamo: bool = False,
  595. ) -> Callable[[Callable[..., object]], None] | None:
  596. # See impl()
  597. if isinstance(types, str):
  598. types = (types,)
  599. keys = set({})
  600. for typ in types:
  601. is_dispatch_key = torch._C._parse_dispatch_key(typ)
  602. if is_dispatch_key:
  603. # We also support passing a DispatchKey to impl. Please prefer using
  604. # the higher-level torch.library APIs and only pass DispatchKey to
  605. # torch.library.impl with caution (or even better, don't use this
  606. # option and file an issue on GitHub for what you need).
  607. # We don't advertise this to users because
  608. # it is very easy to shoot yourself in the foot.
  609. keys.add(typ)
  610. else:
  611. keys.add(_device_type_to_key(typ))
  612. def register_(func: Callable[..., object]) -> None:
  613. namespace, _ = torch._library.utils.parse_namespace(qualname)
  614. if lib is None:
  615. use_lib = Library(namespace, "FRAGMENT")
  616. _keep_alive.append(use_lib)
  617. else:
  618. use_lib = lib
  619. if disable_dynamo:
  620. @torch._disable_dynamo
  621. def func_no_dynamo(*args, **kwargs):
  622. return func(*args, **kwargs)
  623. for key in keys:
  624. use_lib.impl(qualname, func_no_dynamo, key)
  625. else:
  626. for key in keys:
  627. use_lib.impl(qualname, func, key)
  628. if func is None:
  629. return register_
  630. else:
  631. register_(func)
  632. return None
  633. def _device_type_to_key(device_type: str) -> str:
  634. if device_type == "default":
  635. # This is technically not correct, because although all device_type
  636. # DispatchKeys are included in CompositeExplicitAutograd,
  637. # not everything in CompositeExplicitAutograd is associated with a
  638. # device_type. I don't really care that much about the difference.
  639. return "CompositeExplicitAutograd"
  640. return torch._C._dispatch_key_for_device(device_type)
  641. @deprecated(
  642. "`torch.library.impl_abstract` was renamed to `torch.library.register_fake`. Please use that "
  643. "instead; we will remove `torch.library.impl_abstract` in a future version of PyTorch.",
  644. category=FutureWarning,
  645. )
  646. def impl_abstract(qualname, func=None, *, lib=None, _stacklevel=1):
  647. r"""This API was renamed to :func:`torch.library.register_fake` in PyTorch 2.4.
  648. Please use that instead.
  649. """
  650. if func is not None:
  651. _stacklevel = _stacklevel + 1
  652. return register_fake(qualname, func, lib=lib, _stacklevel=_stacklevel)
# Type alias for the ``op`` argument of the ``register_*`` functions in this
# module: a qualified operator name string, an OpOverload, or a CustomOpDef.
# The latter two are spelled as strings (forward references).
_op_identifier = Union[
    str, "torch._ops.OpOverload", "torch._library.custom_ops.CustomOpDef"
]
  656. def register_kernel(
  657. op: _op_identifier,
  658. device_types: device_types_t,
  659. func: Callable | None = None,
  660. /,
  661. *,
  662. lib: Library | None = None,
  663. ):
  664. """Register an implementation for a device type for this operator.
  665. Some valid device_types are: "cpu", "cuda", "xla", "mps", "ipu", "xpu".
  666. This API may be used as a decorator.
  667. Args:
  668. op (str | OpOverload): The operator to register an impl to.
  669. device_types (None | str | Sequence[str]): The device_types to register an impl to.
  670. If None, we will register to all device types -- please only use
  671. this option if your implementation is truly device-type-agnostic.
  672. func (Callable): The function to register as the implementation for
  673. the given device types.
  674. lib (Optional[Library]): If provided, the lifetime of this registration
  675. Examples::
  676. >>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_CUDA)
  677. >>> import torch
  678. >>> from torch import Tensor
  679. >>> from torch.library import custom_op
  680. >>> import numpy as np
  681. >>>
  682. >>> # Create a custom op that works on cpu
  683. >>> @custom_op("mylib::numpy_sin", mutates_args=(), device_types="cpu")
  684. >>> def numpy_sin(x: Tensor) -> Tensor:
  685. >>> x_np = x.numpy()
  686. >>> y_np = np.sin(x_np)
  687. >>> return torch.from_numpy(y_np)
  688. >>>
  689. >>> # Add implementations for the cuda device
  690. >>> @torch.library.register_kernel("mylib::numpy_sin", "cuda")
  691. >>> def _(x):
  692. >>> x_np = x.cpu().numpy()
  693. >>> y_np = np.sin(x_np)
  694. >>> return torch.from_numpy(y_np).to(device=x.device)
  695. >>>
  696. >>> x_cpu = torch.randn(3)
  697. >>> x_cuda = x_cpu.cuda()
  698. >>> assert torch.allclose(numpy_sin(x_cpu), x_cpu.sin())
  699. >>> assert torch.allclose(numpy_sin(x_cuda), x_cuda.sin())
  700. """
  701. if not isinstance(
  702. op, (str, torch._ops.OpOverload, torch._library.custom_ops.CustomOpDef)
  703. ):
  704. raise ValueError(
  705. f"register_kernel({op}): got unexpected type for op: {type(op)}"
  706. )
  707. if isinstance(op, torch._ops.OpOverload):
  708. op = op._name
  709. opdef = _maybe_get_opdef(op)
  710. if opdef is not None:
  711. return opdef.register_kernel(device_types, func)
  712. assert isinstance(op, str)
  713. if device_types is None:
  714. device_types = "CompositeExplicitAutograd"
  715. return _impl(op, device_types, func, lib=lib, disable_dynamo=True)
def register_autocast(
    op: _op_identifier,
    device_type: str,
    cast_inputs: _dtype,
    /,
    *,
    lib: Library | None = None,
):
    r"""Register an autocast dispatch rule for this custom op.

    Valid `device_type` include: "cpu" and "cuda".

    Args:
        op (str | OpOverload | CustomOpDef): The operator to register an autocast dispatch rule to.
        device_type(str): Device type to use. 'cuda' or 'cpu'.
            The type is the same as the `type` attribute of a :class:`torch.device`.
            Thus, you may obtain the device type of a tensor using `Tensor.device.type`.
        cast_inputs (:class:`torch.dtype`): When custom op runs in an autocast-enabled region,
            casts incoming floating-point Tensors to the target dtype (non-floating-point Tensors
            are not affected), then executes custom op with autocast disabled.
        lib (Optional[Library]): If provided, the registration is made through
            this Library. Otherwise a new ``FRAGMENT`` Library is created for
            the operator's namespace and kept alive by this module. Not used
            when ``op`` resolves to a ``CustomOpDef``.

    Raises:
        ValueError: If ``op`` has an unexpected type, or if ``device_type`` is
            neither "cpu" nor "cuda".

    Examples::
        >>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_CUDA)
        >>> import torch
        >>> from torch import Tensor
        >>> from torch.library import custom_op
        >>>
        >>> # Create a custom op that works on cuda
        >>> @torch.library.custom_op("mylib::my_sin", mutates_args=())
        >>> def my_sin(x: Tensor) -> Tensor:
        >>>     return torch.sin(x)
        >>>
        >>> # Register autocast dispatch rule for the cuda device
        >>> torch.library.register_autocast("mylib::my_sin", "cuda", torch.float16)
        >>>
        >>> x = torch.randn(3, dtype=torch.float32, device="cuda")
        >>> with torch.autocast("cuda", dtype=torch.float16):
        >>>     y = torch.ops.mylib.my_sin(x)
        >>> assert y.dtype == torch.float16

    """
    if not isinstance(
        op, (str, torch._ops.OpOverload, torch._library.custom_ops.CustomOpDef)
    ):
        raise ValueError(
            f"register_autocast({op}): got unexpected type for op: {type(op)}"
        )
    if device_type not in ["cpu", "cuda"]:
        raise ValueError(f"Unknown device type: {device_type}")

    # An OpOverload is identified by its name from here on.
    if isinstance(op, torch._ops.OpOverload):
        op = op._name
    # Operators backed by a CustomOpDef handle their own registration.
    opdef = _maybe_get_opdef(op)
    if opdef is not None:
        return opdef.register_autocast(device_type, cast_inputs)

    assert isinstance(op, str)
    qualname = op
    _op = torch._library.utils.lookup_op(qualname)

    namespace, opname = torch._library.utils.parse_namespace(qualname)
    if lib is None:
        # No Library supplied: create a FRAGMENT one and hold a reference to it
        # in the module-level _keep_alive list.
        lib = Library(namespace, "FRAGMENT")
        _keep_alive.append(lib)

    def _maybe_override_py_impl(op: torch._ops.OpOverload, dispatch_key):
        # Decorator factory: registers ``kernel`` via op.py_impl(dispatch_key),
        # first removing the existing py_kernels entry when the op reports a
        # kernel for that key.
        # NOTE(review): pop() has no default, so this assumes any kernel reported
        # by has_kernel_for_dispatch_key is present in op.py_kernels -- confirm.
        def inner(kernel):
            if op.has_kernel_for_dispatch_key(dispatch_key):
                op.py_kernels.pop(dispatch_key)
            return op.py_impl(dispatch_key)(kernel)

        return inner

    # The same Python kernel is registered for both autocast keys, independent
    # of ``device_type``; decorators apply bottom-up (AutocastCUDA first).
    @_maybe_override_py_impl(_op, torch._C.DispatchKey.AutocastCPU)
    @_maybe_override_py_impl(_op, torch._C.DispatchKey.AutocastCUDA)
    def _autocast_py_impl(*args, **kwargs):
        assert len(kwargs) == 0, "Custom ops do not support kwargs yet."
        autocast_keyset = torch._C.DispatchKeySet(
            torch._C.DispatchKey.AutocastCPU
        ) | torch._C.DispatchKeySet(torch._C.DispatchKey.AutocastCUDA)
        # Re-invoke the op with both autocast keys excluded, on the arguments
        # produced by _cast(args, device_type, cast_inputs).
        with torch._C._ExcludeDispatchKeyGuard(autocast_keyset):
            return _op(*_cast(args, device_type, cast_inputs))

    def kernel(_, *args, **kwargs):
        # Registered with with_keyset=True below, so the first positional
        # argument is the keyset; it is ignored here.
        assert len(kwargs) == 0, "Custom ops do not support kwargs yet."
        return _autocast_py_impl(*args, **kwargs)

    if device_type == "cuda":
        return lib.impl(opname, kernel, "AutocastCUDA", with_keyset=True)
    else:
        # device_type is "cpu"
        return lib.impl(opname, kernel, "AutocastCPU", with_keyset=True)
def register_fake(
    op: _op_identifier,
    func: Callable | None = None,
    /,
    *,
    lib: Library | None = None,
    _stacklevel: int = 1,
    allow_override: bool = False,
):
    r"""Register a FakeTensor implementation ("fake impl") for this operator.

    Also sometimes known as a "meta kernel", "abstract impl".

    A "FakeTensor implementation" specifies the behavior of this operator on
    Tensors that carry no data ("FakeTensor"). Given some input Tensors with
    certain properties (sizes/strides/storage_offset/device), it specifies
    what the properties of the output Tensors are.

    The FakeTensor implementation has the same signature as the operator.
    It is run for both FakeTensors and meta tensors. To write a FakeTensor
    implementation, assume that all Tensor inputs to the operator are
    regular CPU/CUDA/Meta tensors, but they do not have storage, and
    you are trying to return regular CPU/CUDA/Meta tensor(s) as output.
    The FakeTensor implementation must consist of only PyTorch operations
    (and may not directly access the storage or data of any input or
    intermediate Tensors).

    This API may be used as a decorator (see examples).

    For a detailed guide on custom ops, please see
    https://pytorch.org/tutorials/advanced/custom_ops_landing_page.html

    Args:
        op: Operator name (along with the overload), OpOverload object, or
            CustomOpDef.
        func: Fake tensor implementation. If omitted, a decorator is returned.
        lib (Optional[Library]): Library to register the fake tensor to. If not
            given, a new ``FRAGMENT`` Library is created for the operator's
            namespace and kept alive by this module.
        allow_override: Flag controlling if we want to override an
            existing registered fake impl. This is by default off,
            and will error if you're trying to register a fake impl to
            an operator that already has a fake impl. This also only
            applies if the custom operator was not created via
            torch.library.custom_op, as overriding an existing fake
            impl is already allowed.

    Raises:
        ValueError: If ``op`` is not a str, OpOverload or CustomOpDef.

    Examples:
        >>> import torch
        >>> import numpy as np
        >>> from torch import Tensor
        >>>
        >>> # Example 1: an operator without data-dependent output shape
        >>> @torch.library.custom_op("mylib::custom_linear", mutates_args=())
        >>> def custom_linear(x: Tensor, weight: Tensor, bias: Tensor) -> Tensor:
        >>>     raise NotImplementedError("Implementation goes here")
        >>>
        >>> @torch.library.register_fake("mylib::custom_linear")
        >>> def _(x, weight, bias):
        >>>     assert x.dim() == 2
        >>>     assert weight.dim() == 2
        >>>     assert bias.dim() == 1
        >>>     assert x.shape[1] == weight.shape[1]
        >>>     assert weight.shape[0] == bias.shape[0]
        >>>     assert x.device == weight.device
        >>>
        >>>     return (x @ weight.t()) + bias
        >>>
        >>> with torch._subclasses.fake_tensor.FakeTensorMode():
        >>>     x = torch.randn(2, 3)
        >>>     w = torch.randn(3, 3)
        >>>     b = torch.randn(3)
        >>>     y = torch.ops.mylib.custom_linear(x, w, b)
        >>>
        >>> assert y.shape == (2, 3)
        >>>
        >>> # Example 2: an operator with data-dependent output shape
        >>> @torch.library.custom_op("mylib::custom_nonzero", mutates_args=())
        >>> def custom_nonzero(x: Tensor) -> Tensor:
        >>>     x_np = x.numpy(force=True)
        >>>     res = np.stack(np.nonzero(x_np), axis=1)
        >>>     return torch.tensor(res, device=x.device)
        >>>
        >>> @torch.library.register_fake("mylib::custom_nonzero")
        >>> def _(x):
        >>>     # Number of nonzero-elements is data-dependent.
        >>>     # Since we cannot peek at the data in a fake impl,
        >>>     # we use the ctx object to construct a new symint that
        >>>     # represents the data-dependent size.
        >>>     ctx = torch.library.get_ctx()
        >>>     nnz = ctx.new_dynamic_size()
        >>>     shape = [nnz, x.dim()]
        >>>     result = x.new_empty(shape, dtype=torch.int64)
        >>>     return result
        >>>
        >>> from torch.fx.experimental.proxy_tensor import make_fx
        >>>
        >>> x = torch.tensor([0, 1, 2, 3, 4, 0])
        >>> trace = make_fx(torch.ops.mylib.custom_nonzero, tracing_mode="symbolic")(x)
        >>> trace.print_readable()
        >>>
        >>> assert torch.allclose(trace(x), torch.ops.mylib.custom_nonzero(x))

    """
    if not isinstance(
        op, (str, torch._ops.OpOverload, torch._library.custom_ops.CustomOpDef)
    ):
        raise ValueError(f"register_fake({op}): got unexpected type for op: {type(op)}")
    # An OpOverload is identified by its name from here on.
    if isinstance(op, torch._ops.OpOverload):
        op = op._name
    # Operators backed by a CustomOpDef handle their own registration
    # (``lib``, ``_stacklevel`` and ``allow_override`` are not forwarded).
    opdef = _maybe_get_opdef(op)
    if opdef is not None:
        if func is None:
            return opdef.register_fake
        else:
            return opdef.register_fake(func)
    assert isinstance(op, str)

    # Mutable copy of the caller-supplied stack level; ``register`` reads it at
    # call time, so the increment in the non-decorator branch below is seen.
    stacklevel = _stacklevel

    def register(func):
        namespace, op_name = torch._library.utils.parse_namespace(op)
        if lib is None:
            # No Library supplied: create a FRAGMENT one and keep it referenced.
            use_lib = Library(namespace, "FRAGMENT")
            _keep_alive.append(use_lib)
        else:
            use_lib = lib
        # One more level is added for this ``register`` call itself
        # (presumably so the level points at the user's frame -- confirm).
        use_lib._register_fake(
            op_name, func, _stacklevel=stacklevel + 1, allow_override=allow_override
        )
        return func

    if func is None:
        return register
    else:
        # Direct (non-decorator) call: ``register`` is invoked from inside this
        # function, so the stack level is bumped once more.
        stacklevel += 1
        return register(func)
  920. def _register_effectful_op(
  921. op: _op_identifier,
  922. effect: EffectType | None,
  923. *,
  924. lib: Library | None = None,
  925. ) -> None:
  926. r"""
  927. To specify that an operator has side-effects, we must register an effect
  928. type for the operator. This will prevent graph passes in torch.compile from
  929. reordering operations with the same effect type.
  930. Args:
  931. op_name: Operator name (along with the overload) or OpOverload object.
  932. effect: Effect type to register. None means the operator is not effectful.
  933. """
  934. if not isinstance(
  935. op, (str, torch._ops.OpOverload, torch._library.custom_ops.CustomOpDef)
  936. ):
  937. raise ValueError(
  938. f"register_effectful_op({op}): got unexpected type for op: {type(op)}"
  939. )
  940. if isinstance(op, torch._ops.OpOverload):
  941. op = op._name
  942. opdef = _maybe_get_opdef(op)
  943. if opdef is not None:
  944. opdef.register_effect(effect)
  945. assert isinstance(op, str)
  946. namespace, _ = torch._library.utils.parse_namespace(op)
  947. if lib is None:
  948. use_lib = Library(namespace, "FRAGMENT")
  949. _keep_alive.append(use_lib)
  950. else:
  951. use_lib = lib
  952. use_lib._register_effectful_op(op, effect)
  953. def register_autograd(
  954. op: _op_identifier,
  955. backward: Callable,
  956. /,
  957. *,
  958. setup_context: Callable | None = None,
  959. lib=None,
  960. ) -> None:
  961. r"""Register a backward formula for this custom op.
  962. In order for an operator to work with autograd, you need to register
  963. a backward formula:
  964. 1. You must tell us how to compute gradients during the backward pass
  965. by providing us a "backward" function.
  966. 2. If you need any values from the forward to compute gradients, you can
  967. use `setup_context` to save values for backward.
  968. ``backward`` runs during the backward pass. It accepts ``(ctx, *grads)``:
  969. - ``grads`` is one or more gradients. The number of gradients matches
  970. the number of outputs of the operator.
  971. The ``ctx`` object is `the same ctx object <context_method_mixins>`_ used by
  972. :class:`torch.autograd.Function`. The semantics of ``backward_fn`` are the
  973. same as :meth:`torch.autograd.Function.backward`.
  974. ``setup_context(ctx, inputs, output)`` runs during the forward pass.
  975. Please save quantities needed for backward onto the ``ctx`` object via
  976. either :meth:`torch.autograd.function.FunctionCtx.save_for_backward`
  977. or assigning them as attributes of ``ctx``. If your custom op has
  978. kwarg-only arguments, we expect the signature of ``setup_context``
  979. to be ``setup_context(ctx, inputs, keyword_only_inputs, output)``.
  980. Both ``setup_context_fn`` and ``backward_fn`` must be traceable. That is,
  981. they may not directly access :meth:`torch.Tensor.data_ptr` and they must
  982. not depend on or mutate global state. If you need a non-traceable backward,
  983. you can make it a separate custom_op that you call inside ``backward_fn``.
  984. If you need different autograd behavior on different devices, then we
  985. recommend creating two different custom operators, one for each device
  986. that needs different behavior, and switching between them at runtime.
  987. Examples:
  988. >>> import torch
  989. >>> import numpy as np
  990. >>> from torch import Tensor
  991. >>>
  992. >>> @torch.library.custom_op("mylib::numpy_sin", mutates_args=())
  993. >>> def numpy_sin(x: Tensor) -> Tensor:
  994. >>> x_np = x.cpu().numpy()
  995. >>> y_np = np.sin(x_np)
  996. >>> return torch.from_numpy(y_np).to(device=x.device)
  997. >>>
  998. >>> def setup_context(ctx, inputs, output) -> Tensor:
  999. >>> x, = inputs
  1000. >>> ctx.save_for_backward(x)
  1001. >>>
  1002. >>> def backward(ctx, grad):
  1003. >>> x, = ctx.saved_tensors
  1004. >>> return grad * x.cos()
  1005. >>>
  1006. >>> torch.library.register_autograd(
  1007. ... "mylib::numpy_sin", backward, setup_context=setup_context
  1008. ... )
  1009. >>>
  1010. >>> x = torch.randn(3, requires_grad=True)
  1011. >>> y = numpy_sin(x)
  1012. >>> (grad_x,) = torch.autograd.grad(y, x, torch.ones_like(y))
  1013. >>> assert torch.allclose(grad_x, x.cos())
  1014. >>>
  1015. >>> # Example with a keyword-only arg
  1016. >>> @torch.library.custom_op("mylib::numpy_mul", mutates_args=())
  1017. >>> def numpy_mul(x: Tensor, *, val: float) -> Tensor:
  1018. >>> x_np = x.cpu().numpy()
  1019. >>> y_np = x_np * val
  1020. >>> return torch.from_numpy(y_np).to(device=x.device)
  1021. >>>
  1022. >>> def setup_context(ctx, inputs, keyword_only_inputs, output) -> Tensor:
  1023. >>> ctx.val = keyword_only_inputs["val"]
  1024. >>>
  1025. >>> def backward(ctx, grad):
  1026. >>> return grad * ctx.val
  1027. >>>
  1028. >>> torch.library.register_autograd(
  1029. ... "mylib::numpy_mul", backward, setup_context=setup_context
  1030. ... )
  1031. >>>
  1032. >>> x = torch.randn(3, requires_grad=True)
  1033. >>> y = numpy_mul(x, val=3.14)
  1034. >>> (grad_x,) = torch.autograd.grad(y, x, torch.ones_like(y))
  1035. >>> assert torch.allclose(grad_x, torch.full_like(x, 3.14))
  1036. """
  1037. if not isinstance(
  1038. op, (str, torch._ops.OpOverload, torch._library.custom_ops.CustomOpDef)
  1039. ):
  1040. raise ValueError(
  1041. f"register_autograd({op}): got unexpected type for op: {type(op)}"
  1042. )
  1043. if isinstance(op, torch._ops.OpOverload):
  1044. op = op._name
  1045. opdef = _maybe_get_opdef(op)
  1046. if opdef is not None:
  1047. opdef.register_autograd(backward, setup_context=setup_context)
  1048. return
  1049. assert isinstance(op, str)
  1050. qualname = op
  1051. op = torch._library.utils.lookup_op(qualname)
  1052. schema = op._schema
  1053. if not _library.utils.is_functional_schema(schema):
  1054. raise RuntimeError(
  1055. f"Cannot register autograd formula for non-functional operator "
  1056. f"{op} with schema {schema}. Please create "
  1057. f"a functional operator and register an autograd formula for that."
  1058. )
  1059. if _library.utils.has_kwarg_only_tensors(schema):
  1060. raise NotImplementedError(
  1061. f"register_autograd with kwarg-only Tensor args. In the original "
  1062. f"definition of the op, please make your tensors not kwarg-only. "
  1063. f"Got: {schema}"
  1064. )
  1065. info = _library.autograd.Info(backward, setup_context)
  1066. autograd_kernel = _library.autograd.make_autograd_impl(op, info)
  1067. namespace, opname = torch._library.utils.parse_namespace(qualname)
  1068. if lib is None:
  1069. lib = Library(namespace, "FRAGMENT")
  1070. _keep_alive.append(lib)
  1071. lib.impl(opname, autograd_kernel, "Autograd", with_keyset=True)
  1072. def register_torch_dispatch(
  1073. op: _op_identifier,
  1074. torch_dispatch_class: Any,
  1075. func: Callable | None = None,
  1076. /,
  1077. *,
  1078. lib: Library | None = None,
  1079. ):
  1080. r"""Registers a torch_dispatch rule for the given operator and ``torch_dispatch_class``.
  1081. This allows for open registration to specify the behavior between the operator
  1082. and the ``torch_dispatch_class`` without needing to modify the ``torch_dispatch_class``
  1083. or the operator directly.
  1084. The ``torch_dispatch_class`` is either a Tensor subclass with ``__torch_dispatch__`` or a
  1085. TorchDispatchMode.
  1086. If it is a Tensor subclass, we expect ``func`` to have the following signature:
  1087. ``(cls, func: OpOverload, types: Tuple[type, ...], args, kwargs) -> Any``
  1088. If it is a TorchDispatchMode, we expect ``func`` to have the following signature:
  1089. ``(mode, func: OpOverload, types: Tuple[type, ...], args, kwargs) -> Any``
  1090. ``args`` and ``kwargs`` will have been normalized the same way they are
  1091. in ``__torch_dispatch__`` (see :ref:`torch-dispatch-calling-convention`).
  1092. Examples:
  1093. >>> import torch
  1094. >>>
  1095. >>> @torch.library.custom_op("mylib::foo", mutates_args={})
  1096. >>> def foo(x: torch.Tensor) -> torch.Tensor:
  1097. >>> return x.clone()
  1098. >>>
  1099. >>> class MyMode(torch.utils._python_dispatch.TorchDispatchMode):
  1100. >>> def __torch_dispatch__(self, func, types, args=(), kwargs=None):
  1101. >>> return func(*args, **kwargs)
  1102. >>>
  1103. >>> @torch.library.register_torch_dispatch("mylib::foo", MyMode)
  1104. >>> def _(mode, func, types, args, kwargs):
  1105. >>> x, = args
  1106. >>> return x + 1
  1107. >>>
  1108. >>> x = torch.randn(3)
  1109. >>> y = foo(x)
  1110. >>> assert torch.allclose(y, x)
  1111. >>>
  1112. >>> with MyMode():
  1113. >>> y = foo(x)
  1114. >>> assert torch.allclose(y, x + 1)
  1115. """
  1116. if not isinstance(
  1117. op, (str, torch._ops.OpOverload, torch._library.custom_ops.CustomOpDef)
  1118. ):
  1119. raise ValueError(
  1120. f"register_torch_dispatch({op}): got unexpected type for op: {type(op)}"
  1121. )
  1122. if isinstance(op, torch._ops.OpOverload):
  1123. op = op._name
  1124. opdef = _maybe_get_opdef(op)
  1125. if opdef is not None:
  1126. return opdef.register_torch_dispatch(torch_dispatch_class, func)
  1127. assert isinstance(op, str)
  1128. def register(func):
  1129. namespace, op_name = torch._library.utils.parse_namespace(op)
  1130. if lib is None:
  1131. use_lib = Library(namespace, "FRAGMENT")
  1132. _keep_alive.append(use_lib)
  1133. else:
  1134. use_lib = lib
  1135. use_lib._register_torch_dispatch_rule(op_name, torch_dispatch_class, func)
  1136. return func
  1137. if func is None:
  1138. return register
  1139. else:
  1140. return register(func)
  1141. def register_vmap(
  1142. op: _op_identifier,
  1143. func: Callable | None = None,
  1144. /,
  1145. *,
  1146. lib=None,
  1147. ):
  1148. r"""Register a vmap implementation to support :func:`torch.vmap` for this custom op.
  1149. This API may be used as a decorator (see examples).
  1150. In order for an operator to work with :func:`torch.vmap`, you may need to register a
  1151. vmap implementation in the following signature:
  1152. ``vmap_func(info, in_dims: Tuple[Optional[int]], *args, **kwargs)``,
  1153. where ``*args`` and ``**kwargs`` are the arguments and kwargs for ``op``.
  1154. We do not support kwarg-only Tensor args.
  1155. It specifies how do we compute the batched version of ``op`` given inputs with an additional
  1156. dimension (specified by ``in_dims``).
  1157. For each arg in ``args``, ``in_dims`` has a corresponding ``Optional[int]``. It is ``None``
  1158. if the arg is not a Tensor or if the arg is not being vmapped over, otherwise, it is an integer
  1159. specifying what dimension of the Tensor is being vmapped over.
  1160. ``info`` is a collection of additional metadata that may be helpful:
  1161. ``info.batch_size`` specifies the size of the dimension being vmapped over, while
  1162. ``info.randomness`` is the ``randomness`` option that was passed to :func:`torch.vmap`.
  1163. The return of the function ``func`` is a tuple of ``(output, out_dims)``. Similar to ``in_dims``,
  1164. ``out_dims`` should be of the same structure as ``output`` and contain one ``out_dim``
  1165. per output that specifies if the output has the vmapped dimension and what index it is in.
  1166. Examples:
  1167. >>> import torch
  1168. >>> import numpy as np
  1169. >>> from torch import Tensor
  1170. >>> from typing import Tuple
  1171. >>>
  1172. >>> def to_numpy(tensor):
  1173. >>> return tensor.cpu().numpy()
  1174. >>>
  1175. >>> lib = torch.library.Library("mylib", "FRAGMENT")
  1176. >>> @torch.library.custom_op("mylib::numpy_cube", mutates_args=())
  1177. >>> def numpy_cube(x: Tensor) -> Tuple[Tensor, Tensor]:
  1178. >>> x_np = to_numpy(x)
  1179. >>> dx = torch.tensor(3 * x_np ** 2, device=x.device)
  1180. >>> return torch.tensor(x_np ** 3, device=x.device), dx
  1181. >>>
  1182. >>> def numpy_cube_vmap(info, in_dims, x):
  1183. >>> result = numpy_cube(x)
  1184. >>> return result, (in_dims[0], in_dims[0])
  1185. >>>
  1186. >>> torch.library.register_vmap(numpy_cube, numpy_cube_vmap)
  1187. >>>
  1188. >>> x = torch.randn(3)
  1189. >>> torch.vmap(numpy_cube)(x)
  1190. >>>
  1191. >>> @torch.library.custom_op("mylib::numpy_mul", mutates_args=())
  1192. >>> def numpy_mul(x: Tensor, y: Tensor) -> Tensor:
  1193. >>> return torch.tensor(to_numpy(x) * to_numpy(y), device=x.device)
  1194. >>>
  1195. >>> @torch.library.register_vmap("mylib::numpy_mul")
  1196. >>> def numpy_mul_vmap(info, in_dims, x, y):
  1197. >>> x_bdim, y_bdim = in_dims
  1198. >>> x = x.movedim(x_bdim, -1) if x_bdim is not None else x.unsqueeze(-1)
  1199. >>> y = y.movedim(y_bdim, -1) if y_bdim is not None else y.unsqueeze(-1)
  1200. >>> result = x * y
  1201. >>> result = result.movedim(-1, 0)
  1202. >>> return result, 0
  1203. >>>
  1204. >>>
  1205. >>> x = torch.randn(3)
  1206. >>> y = torch.randn(3)
  1207. >>> torch.vmap(numpy_mul)(x, y)
  1208. .. note::
  1209. The vmap function should aim to preserve the semantics of the entire custom operator.
  1210. That is, ``grad(vmap(op))`` should be replaceable with a ``grad(map(op))``.
  1211. If your custom operator has any custom behavior in the backward pass, please
  1212. keep this in mind.
  1213. """
  1214. if not isinstance(
  1215. op, (str, torch._ops.OpOverload, torch._library.custom_ops.CustomOpDef)
  1216. ):
  1217. raise ValueError(f"register_vmap({op}): got unexpected type for op: {type(op)}")
  1218. if isinstance(op, torch._ops.OpOverload):
  1219. op = op._name
  1220. opdef = _maybe_get_opdef(op)
  1221. if opdef is not None:
  1222. return opdef.register_vmap(func)
  1223. assert isinstance(op, str)
  1224. qualname = op
  1225. op = torch._library.utils.lookup_op(qualname)
  1226. schema = op._schema
  1227. if _library.utils.has_kwarg_only_tensors(schema):
  1228. raise NotImplementedError(
  1229. f"register_vmap with kwarg-only Tensor args. In the original "
  1230. f"definition of the op, please make your tensors not kwarg-only. "
  1231. f"Got: {schema}"
  1232. )
  1233. def register(func):
  1234. nonlocal op, lib
  1235. namespace, opname = torch._library.utils.parse_namespace(qualname)
  1236. if lib is None:
  1237. lib = Library(namespace, "FRAGMENT")
  1238. _keep_alive.append(lib)
  1239. from torch._functorch.autograd_function import custom_function_call_vmap_helper
  1240. from torch._functorch.pyfunctorch import retrieve_current_functorch_interpreter
  1241. def wrapped_func(keyset, *args, **kwargs):
  1242. interpreter = retrieve_current_functorch_interpreter()
  1243. return custom_function_call_vmap_helper(
  1244. interpreter, func, op, *args, **kwargs
  1245. )
  1246. lib.impl(opname, wrapped_func, "FuncTorchBatched", with_keyset=True)
  1247. if func is None:
  1248. return register
  1249. else:
  1250. return register(func)
  1251. # If the op was defined in C++, then we want to make sure there was an
  1252. # m.set_python_module(module, ...) call and that the module is the
  1253. # same as the module that called torch.library.register_fake.
  1254. def _check_pystubs_once(func, qualname, actual_module_name):
  1255. checked = False
  1256. def inner(*args, **kwargs):
  1257. nonlocal checked
  1258. if checked:
  1259. return func(*args, **kwargs)
  1260. op = torch._library.utils.lookup_op(qualname)
  1261. if op._defined_in_python:
  1262. checked = True
  1263. return func(*args, **kwargs)
  1264. maybe_pystub = torch._C._dispatch_pystub(
  1265. op._schema.name, op._schema.overload_name
  1266. )
  1267. if maybe_pystub is None:
  1268. if torch._library.utils.requires_set_python_module():
  1269. namespace = op.namespace
  1270. cpp_filename = op._handle.debug()
  1271. raise RuntimeError(
  1272. f"Operator '{qualname}' was defined in C++ and has a Python "
  1273. f"fake impl. In this situation, we require there to also be a "
  1274. f'companion C++ `m.set_python_module("{actual_module_name}")` '
  1275. f"call, but we could not find one. Please add that to "
  1276. f"to the top of the C++ TORCH_LIBRARY({namespace}, ...) block the "
  1277. f"operator was registered in ({cpp_filename})"
  1278. )
  1279. else:
  1280. pystub_module = maybe_pystub[0]
  1281. if actual_module_name != pystub_module:
  1282. cpp_filename = op._handle.debug()
  1283. raise RuntimeError(
  1284. f"Operator '{qualname}' specified that its python fake impl "
  1285. f"is in the Python module '{pystub_module}' but it was actually found "
  1286. f"in '{actual_module_name}'. Please either move the fake impl "
  1287. f"or correct the m.set_python_module call ({cpp_filename})"
  1288. )
  1289. checked = True
  1290. return func(*args, **kwargs)
  1291. return inner
  1292. # NOTE [ctx inside the fake implementation]
  1293. # If a user has an operator with data-dependent output shape, then when writing
  1294. # a fake implementation they must query the current ctx and use methods on the
  1295. # ctx to construct a new unbacked symint.
  1296. #
  1297. # This is done via us setting the global_ctx_getter function every time a fake
  1298. # implementation is invoked.
  1299. def get_ctx() -> "torch._library.fake_impl.FakeImplCtx":
  1300. """get_ctx() returns the current AbstractImplCtx object.
  1301. Calling ``get_ctx()`` is only valid inside of an fake impl
  1302. (see :func:`torch.library.register_fake` for more usage details.
  1303. """
  1304. return torch._library.fake_impl.global_ctx_getter()
  1305. def get_kernel(
  1306. op: _op_identifier, dispatch_key: str | torch.DispatchKey
  1307. ) -> torch._C._SafeKernelFunction:
  1308. """Returns the computed kernel for a given operator and dispatch key.
  1309. This function retrieves the kernel that would be executed for a given
  1310. operator and dispatch key combination. The returned SafeKernelFunction
  1311. can be used to call the kernel in a boxed fashion. The intended use
  1312. case for this function is to retrieve the original kernel for a given
  1313. dispatch key and then register another kernel to the same dispatch key
  1314. that calls into the original kernel for certain cases.
  1315. Args:
  1316. op: Operator name (along with the overload) or OpOverload object
  1317. Can be a string (e.g., "aten::add.Tensor"), an OpOverload, or a CustomOpDef.
  1318. dispatch_key (str | torch.DispatchKey): The dispatch key to get the kernel for.
  1319. Can be a string (e.g., "CPU", "CUDA") or a DispatchKey enum value.
  1320. Returns:
  1321. torch._C._SafeKernelFunction: A safe kernel function that can be used to
  1322. call the kernel.
  1323. Raises:
  1324. RuntimeError: If the operator does not exist.
  1325. Example:
  1326. >>> # Get the CPU kernel for torch.add
  1327. >>> kernel = torch.library.get_kernel("aten::add.Tensor", "CPU")
  1328. >>>
  1329. >>> # You can also use DispatchKey enum
  1330. >>> kernel = torch.library.get_kernel("aten::add.Tensor", torch.DispatchKey.CPU)
  1331. >>>
  1332. >>> # Or use an OpOverload directly
  1333. >>> kernel = torch.library.get_kernel(torch.ops.aten.add.Tensor, "CPU")
  1334. >>>
  1335. >>> # Example: Using get_kernel in a custom op with conditional dispatch
  1336. >>> # Get the original kernel for torch.sin
  1337. >>> original_sin_kernel = torch.library.get_kernel("aten::sin", "CPU")
  1338. >>>
  1339. >>> # If input has negative values, use original sin, otherwise return zeros
  1340. >>> def conditional_sin_impl(dispatch_keys, x):
  1341. >>> if (x < 0).any():
  1342. >>> return original_sin_kernel.call_boxed(dispatch_keys, x)
  1343. >>> else:
  1344. >>> return torch.zeros_like(x)
  1345. >>>
  1346. >>> lib = torch.library.Library("aten", "IMPL")
  1347. >>> # with_keyset=True so the first argument to the impl is the current DispatchKeySet
  1348. >>> which needs to be the first argument to ``kernel.call_boxed``
  1349. >>> lib.impl("sin", conditional_sin_impl, "CPU", with_keyset=True)
  1350. >>>
  1351. >>> # Test the conditional behavior
  1352. >>> x_positive = torch.tensor([1.0, 2.0])
  1353. >>> x_mixed = torch.tensor([-1.0, 2.0])
  1354. >>> torch.sin(x_positive)
  1355. tensor([0., 0.])
  1356. >>> torch.sin(x_mixed)
  1357. tensor([-0.8415, 0.9093])
  1358. """
  1359. if not isinstance(op, (str, torch._ops.OpOverload)):
  1360. raise ValueError(f"get_kernel({op}): got unexpected type for op: {type(op)}")
  1361. if isinstance(op, torch._ops.OpOverload):
  1362. op = op._name
  1363. if isinstance(dispatch_key, str):
  1364. try:
  1365. dispatch_key = torch._C.DispatchKey.__members__[dispatch_key]
  1366. except KeyError:
  1367. raise ValueError(f"Invalid dispatch key: {dispatch_key}") from None
  1368. return torch._C._dispatch_get_computed_kernel_for_dispatch_key(op, dispatch_key)
  1369. _OPCHECK_DEFAULT_UTILS = (
  1370. "test_schema",
  1371. "test_autograd_registration",
  1372. "test_faketensor",
  1373. "test_aot_dispatch_dynamic",
  1374. )
  1375. def opcheck(
  1376. op: torch._ops.OpOverload | torch._ops.OpOverloadPacket | CustomOpDef,
  1377. args: tuple[Any, ...],
  1378. kwargs: dict[str, Any] | None = None,
  1379. *,
  1380. test_utils: str | Sequence[str] = _OPCHECK_DEFAULT_UTILS,
  1381. raise_exception: bool = True,
  1382. atol=None,
  1383. rtol=None,
  1384. ) -> dict[str, str]:
  1385. """Given an operator and some sample arguments, tests if the operator is
  1386. registered correctly.
  1387. That is, when you use the torch.library/TORCH_LIBRARY APIs to create a
  1388. custom op, you specified metadata (e.g. mutability info) about the custom op
  1389. and these APIs require that the functions you pass them satisfy certain
  1390. properties (e.g. no data pointer access in the fake/meta/abstract kernel)
  1391. ``opcheck`` tests these metadata and properties.
  1392. Concretely, we test the following:
  1393. - test_schema: If the schema matches the implementation of
  1394. the operator. For example: if the schema specifies a Tensor is mutated,
  1395. then we check the implementation mutates the Tensor. If the schema
  1396. specifies that we return a new Tensor, then we check that the
  1397. implementation returns a new Tensor (instead of an existing one or
  1398. a view of an existing one).
  1399. - test_autograd_registration: If the operator supports training
  1400. (autograd): we check that its autograd formula is registered via
  1401. torch.library.register_autograd or a manual registration to one
  1402. or more DispatchKey::Autograd keys. Any other DispatchKey-based
  1403. registrations may lead to undefined behavior.
  1404. - test_faketensor: If the operator has a FakeTensor kernel
  1405. (and if it is correct). The FakeTensor kernel is necessary (
  1406. but not sufficient) for the operator to work with PyTorch compilation
  1407. APIs (torch.compile/export/FX). We check that a FakeTensor kernel
  1408. (also sometimes known as a meta kernel) was registered for the
  1409. operator and that it is correct. This test takes the result of
  1410. running the operator on real tensors and the result of running
  1411. the operator on FakeTensors and checks that they have the same
  1412. Tensor metadata (sizes/strides/dtype/device/etc).
  1413. - test_aot_dispatch_dynamic: If the operator has correct behavior
  1414. with PyTorch compilation APIs (torch.compile/export/FX).
  1415. This checks that the outputs (and gradients, if applicable) are the
  1416. same under eager-mode PyTorch and torch.compile.
  1417. This test is a superset of ``test_faketensor`` and is an e2e test;
  1418. other things it tests are that the operator supports
  1419. functionalization and that the backward pass (if it exists) also
  1420. supports FakeTensor and functionalization.
  1421. For best results, please call ``opcheck`` multiple times with a
  1422. representative set of inputs. If your operator supports
  1423. autograd, please use ``opcheck`` with inputs with ``requires_grad = True``;
  1424. if your operator supports multiple devices (e.g. CPU and CUDA), please
  1425. use ``opcheck`` with inputs on all supported devices.
  1426. Args:
  1427. op: The operator. Must either be a function decorated with
  1428. :func:`torch.library.custom_op` or an OpOverload/OpOverloadPacket
  1429. found in torch.ops.* (e.g. torch.ops.aten.sin, torch.ops.mylib.foo)
  1430. args: The args to the operator
  1431. kwargs: The kwargs to the operator
  1432. test_utils: Tests that we should run. Default: all of them.
  1433. Example: ("test_schema", "test_faketensor")
  1434. raise_exception: If we should raise an exception on the first
  1435. error. If False, we will return a dict with information
  1436. on if each test passed or not.
  1437. rtol (Optional[float]): Relative tolerance for floating point comparisons.
  1438. If specified ``atol`` must also be specified.
  1439. If omitted, default values based on the ``dtype`` are selected
  1440. (see the table in :func:`torch.testing.assert_close`).
  1441. atol (Optional[float]): Absolute tolerance for floating point comparisons.
  1442. If specified ``rtol`` must also be specified.
  1443. If omitted, default values based on the ``dtype`` are selected
  1444. (see the table in :func:`torch.testing.assert_close`).
  1445. .. warning::
  1446. opcheck and :func:`torch.autograd.gradcheck` test different things;
  1447. opcheck tests if your usage of torch.library APIs is correct while
  1448. :func:`torch.autograd.gradcheck` tests if your autograd formula is
  1449. mathematically correct. Use both to test custom ops that support
  1450. gradient computation.
  1451. Example:
  1452. >>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_CUDA)
  1453. >>> @torch.library.custom_op("mylib::numpy_mul", mutates_args=())
  1454. >>> def numpy_mul(x: Tensor, y: float) -> Tensor:
  1455. >>> x_np = x.numpy(force=True)
  1456. >>> z_np = x_np * y
  1457. >>> return torch.from_numpy(z_np).to(x.device)
  1458. >>>
  1459. >>> @numpy_mul.register_fake
  1460. >>> def _(x, y):
  1461. >>> return torch.empty_like(x)
  1462. >>>
  1463. >>> def setup_context(ctx, inputs, output):
  1464. >>> y, = inputs
  1465. >>> ctx.y = y
  1466. >>>
  1467. >>> def backward(ctx, grad):
  1468. >>> return grad * ctx.y, None
  1469. >>>
  1470. >>> numpy_mul.register_autograd(backward, setup_context=setup_context)
  1471. >>>
  1472. >>> sample_inputs = [
  1473. >>> (torch.randn(3), 3.14),
  1474. >>> (torch.randn(2, 3, device='cuda'), 2.718),
  1475. >>> (torch.randn(1, 10, requires_grad=True), 1.234),
  1476. >>> (torch.randn(64, 64, device='cuda', requires_grad=True), 90.18),
  1477. >>> ]
  1478. >>>
  1479. >>> for args in sample_inputs:
  1480. >>> torch.library.opcheck(numpy_mul, args)
  1481. """
  1482. import torch.testing._internal.optests as optests
  1483. return optests.opcheck(
  1484. op,
  1485. args,
  1486. kwargs,
  1487. test_utils=test_utils,
  1488. raise_exception=raise_exception,
  1489. rtol=rtol,
  1490. atol=atol,
  1491. )