impl.py 26 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715
  1. # mypy: allow-untyped-defs
  2. import dataclasses
  3. import functools
  4. import inspect
  5. import sys
  6. import typing
  7. import warnings
  8. import weakref
  9. import torch
  10. import torch._C as _C
  11. import torch._library.infer_schema
  12. import torch.library as library
  13. from torch._library.infer_schema import infer_schema
  14. from torch.library import get_ctx
  15. from torchgen.model import (
  16. BaseTy,
  17. BaseType,
  18. FunctionSchema,
  19. ListType,
  20. OperatorName,
  21. SchemaKind,
  22. )
  23. from .autograd import autograd_kernel_indirection, construct_autograd_kernel
# NOTE(review): this string comes after the imports, so Python does not bind
# it to the module's __doc__; it is an expression statement that is evaluated
# and discarded. It would need to move above the imports to become the real
# module docstring.
"""
torch._custom_op is deprecated. We shipped a production-ready version of it into torch.library.
Please use those APIs instead.
"""
# Names exported by a star-import of this module.
__all__ = ["custom_op", "CustomOp", "get_ctx"]
  29. SUPPORTED_DEVICE_TYPE_TO_KEY = {
  30. "cpu": "CPU",
  31. "cuda": "CUDA",
  32. }
  33. # We will not let users register CustomOps with anything that could look like
  34. # PyTorch internals to avoid confusion.
  35. RESERVED_NS = {
  36. "prim",
  37. "prims",
  38. "aten",
  39. "at",
  40. "torch",
  41. "pytorch",
  42. }
  43. def warn_deprecated():
  44. warnings.warn(
  45. "torch._custom_op is deprecated and will be removed in PyTorch 2.6, please "
  46. "use the equivalent torch.library API instead.",
  47. DeprecationWarning,
  48. )
  49. def custom_op(
  50. qualname: str, manual_schema: typing.Optional[str] = None
  51. ) -> typing.Callable:
  52. r"""
  53. This API is deprecated, please use torch.library.custom_op instead
  54. """
  55. warn_deprecated()
  56. def inner(func):
  57. if not inspect.isfunction(func):
  58. raise ValueError(
  59. f"custom_op(...)(func): Expected `func` to be a Python "
  60. f"function, got: {type(func)}"
  61. )
  62. ns, name = parse_qualname(qualname)
  63. validate_namespace(ns)
  64. if func.__name__ != name:
  65. raise ValueError(
  66. f"custom_op(qualname='{qualname}', ...)(func): expected `func` "
  67. f"to have name '{name}' but got '{func.__name__}'. "
  68. f"Please either change the name of `func` or the qualname that "
  69. f"is passed to `custom_op`"
  70. )
  71. schema = (
  72. infer_schema(func, mutates_args=())
  73. if manual_schema is None
  74. else manual_schema
  75. )
  76. schema_str = f"{name}{schema}"
  77. function_schema = FunctionSchema.parse(schema_str)
  78. validate_schema(function_schema)
  79. if manual_schema is not None:
  80. validate_function_matches_schema(function_schema, func)
  81. lib = library.Library(ns, "FRAGMENT")
  82. lib.define(schema_str)
  83. ophandle = find_ophandle_or_throw(ns, function_schema.name)
  84. result = CustomOp(
  85. lib, ns, function_schema, name, ophandle, _private_access=True
  86. )
  87. result.__name__ = func.__name__
  88. result.__module__ = func.__module__
  89. result.__doc__ = func.__doc__
  90. library.impl(lib, result._opname, "Autograd")(
  91. autograd_kernel_indirection(weakref.proxy(result))
  92. )
  93. torch._C._dispatch_set_report_error_callback(
  94. ophandle, functools.partial(report_error_callback, weakref.proxy(result))
  95. )
  96. return result
  97. return inner
# Global dictionary holding references to all CustomOp objects, indexed by
# qualname ("<namespace>::<name>", e.g. "custom::foo").
# CustomOp.__init__ inserts each new object here and CustomOp._destroy removes
# it, so an entry keeps its CustomOp alive until it is explicitly destroyed
# (see NOTE [CustomOp lifetime]).
# Used to query the CustomOp associated with a specific C++ dispatcher operator.
# An example usage is FakeTensor: FakeTensor checks if a specific operator
# has an implementation registered via the CustomOp API.
global_registry: dict[str, "CustomOp"] = {}
  105. class CustomOp:
  106. r"""
  107. This API is deprecated, please use torch.library.custom_op instead
  108. """
  109. def __init__(
  110. self, lib, cpp_ns, schema, operator_name, ophandle, *, _private_access=False
  111. ):
  112. super().__init__()
  113. warn_deprecated()
  114. if not _private_access:
  115. raise RuntimeError(
  116. "The CustomOp constructor is private and we do not guarantee "
  117. "BC for it. Please use custom_op(...) to create a CustomOp object"
  118. )
  119. name = f"{cpp_ns}::{operator_name}"
  120. self._schema = schema
  121. self._cpp_ns = cpp_ns
  122. self._lib: library.Library = lib
  123. self._ophandle: _C._DispatchOperatorHandle = ophandle
  124. # Has the name of the op, e.g. "foo". We cache here for convenience.
  125. self._opname: str = operator_name
  126. # this is _opname but with namespace. e.g. "custom::foo"
  127. self._qualname: str = name
  128. self.__name__ = None # mypy requires this
  129. # NB: Some of these impls are registered as kernels to DispatchKeys.
  130. # Modifying the _impls dict directly won't do anything in that case.
  131. self._impls: dict[str, typing.Optional[FuncAndLocation]] = {}
  132. # See NOTE [CustomOp autograd kernel indirection]
  133. self._registered_autograd_kernel_indirection = False
  134. global_registry[self._qualname] = self
  135. def _register_autograd_kernel_indirection(self):
  136. assert not self._registered_autograd_kernel_indirection
  137. self._lib.impl(
  138. self._opname, autograd_kernel_indirection(weakref.proxy(self)), "Autograd"
  139. )
  140. self._registered_autograd_kernel_indirection = True
  141. # Records the impl and the source location in self._impls
  142. # Note that this doesn't cause torch.library to use the impl, that
  143. # needs to be done in a separate self._lib.impl call.
  144. def _register_impl(self, kind, func, stacklevel=2):
  145. if self._has_impl(kind):
  146. func_and_location = self._impls[kind]
  147. assert func_and_location is not None # Pacify mypy
  148. location = func_and_location.location
  149. raise RuntimeError(
  150. f"Attempting to register a {kind} impl for operator {self._qualname} "
  151. f"that already has a {kind} impl registered from Python at "
  152. f"{location}. This is not supported."
  153. )
  154. frame = inspect.getframeinfo(sys._getframe(stacklevel))
  155. location = f"{frame.filename}:{frame.lineno}"
  156. self._impls[kind] = FuncAndLocation(func, location)
  157. def _get_impl(self, kind):
  158. return self._impls[kind]
  159. def _has_impl(self, kind):
  160. return kind in self._impls
  161. def _destroy(self):
  162. # NOTE: [CustomOp lifetime]
  163. # A CustomOp, once created, lives forever. The mechanism is that the
  164. # global registry holds a reference to it. However, to make testing
  165. # easier, we want to be able to destroy CustomOp objects.
  166. # CustomOp._destroy does the job, though it leaves the CustomOp
  167. # in a garbage state.
  168. del self._lib
  169. opnamespace = getattr(torch.ops, self._cpp_ns)
  170. if hasattr(opnamespace, self._opname):
  171. delattr(opnamespace, self._opname)
  172. del global_registry[self._qualname]
  173. def __repr__(self):
  174. return f'<CustomOp(op="{self._qualname}")>'
  175. def __call__(self, *args, **kwargs):
  176. # Bypass torch.ops.* and directly do OperatorHandle::callBoxed.
  177. # Using torch.ops.* is a bit of a pain (it can be slow and it has lifetime
  178. # issues from caching operators that make testing CustomOp difficult).
  179. result = _C._dispatch_call_boxed(self._ophandle, *args, **kwargs)
  180. return result
  181. def impl(
  182. self,
  183. device_types: typing.Union[str, typing.Iterable[str]],
  184. _stacklevel=2,
  185. ) -> typing.Callable:
  186. r"""
  187. This API is deprecated, please use torch.library.custom_op instead
  188. """
  189. if isinstance(device_types, str):
  190. device_types = [device_types]
  191. for device_type in device_types:
  192. validate_device_type(device_type)
  193. def inner(f):
  194. for device_type in set(device_types):
  195. self._check_doesnt_have_library_impl(device_type)
  196. self._register_impl(device_type, f, stacklevel=_stacklevel)
  197. dispatch_key = SUPPORTED_DEVICE_TYPE_TO_KEY[device_type]
  198. library.impl(self._lib, self._opname, dispatch_key)(f)
  199. return f
  200. return inner
  201. def _check_doesnt_have_library_impl(self, device_type):
  202. if self._has_impl(device_type):
  203. return
  204. key = SUPPORTED_DEVICE_TYPE_TO_KEY[device_type]
  205. if _C._dispatch_has_computed_kernel_for_dispatch_key(self._qualname, key):
  206. raise RuntimeError(
  207. f"impl(..., device_types={device_type}): the operator {self._qualname} "
  208. f"already has an implementation for this device type via a "
  209. f"pre-existing torch.library or TORCH_LIBRARY registration."
  210. )
  211. def impl_factory(self) -> typing.Callable:
  212. r"""Register an implementation for a factory function."""
  213. def inner(f):
  214. self._register_impl("factory", f)
  215. library.impl(self._lib, self._opname, "BackendSelect")(f)
  216. return f
  217. return inner
  218. def impl_abstract(self, _stacklevel=2) -> typing.Callable:
  219. r"""
  220. This API is deprecated, please use torch.library.custom_op instead
  221. """
  222. def inner(f):
  223. self._check_doesnt_have_library_meta_impl()
  224. self._register_impl("abstract", f, stacklevel=_stacklevel)
  225. location = self._get_impl("abstract").location
  226. qualname = self._qualname
  227. # Handle DispatchKey.Meta registration
  228. @functools.wraps(f)
  229. def f_with_ctx(*args, **kwargs):
  230. def error_on_ctx():
  231. raise RuntimeError(
  232. f"Attempted to call get_ctx() for the meta implementation "
  233. f"for {qualname}."
  234. f"You have presumably called get_ctx() because the operator "
  235. f"has a data-dependent output shape; if so, there is no "
  236. f"such meta implementation and this error is the correct "
  237. f"behavior. Otherwise, please remove the call to get_ctx() "
  238. f"in the implementation registered with impl_abstract "
  239. f"at {location}"
  240. )
  241. with torch._library.fake_impl.set_ctx_getter(error_on_ctx):
  242. return f(*args, **kwargs)
  243. self._lib.impl(self._opname, f_with_ctx, "Meta")
  244. return f
  245. return inner
  246. def _check_can_register_backward(self):
  247. def error(detail):
  248. raise RuntimeError(
  249. f"Cannot use torch._custom_ops APIs to register backward "
  250. f"formula for {detail}. Got operator "
  251. f"{self._qualname} with schema: {schema}"
  252. )
  253. schema = self._schema
  254. if schema.kind() != SchemaKind.functional:
  255. error("non-functional operator")
  256. rets = schema.returns
  257. if not schema.returns:
  258. error("operator with no returns")
  259. assert len(rets) > 0
  260. is_non_mutating_view = any(
  261. r.annotation is not None and not r.annotation.is_write for r in rets
  262. )
  263. if is_non_mutating_view:
  264. error("operator that returns views")
  265. # We make assumptions about the schema's return types.
  266. allowed_return_types = {
  267. BaseType(BaseTy.int): "int",
  268. BaseType(BaseTy.SymInt): "SymInt",
  269. BaseType(BaseTy.bool): "bool",
  270. BaseType(BaseTy.float): "float",
  271. BaseType(BaseTy.Tensor): "Tensor",
  272. ListType(BaseType(BaseTy.Tensor), None): "List[Tensor]",
  273. }
  274. for ret in schema.returns:
  275. if ret.type in allowed_return_types:
  276. continue
  277. error(
  278. f"operator with return not in {list(allowed_return_types.values())} (got {ret.type})"
  279. )
  280. def _check_doesnt_have_library_autograd_impl(self):
  281. if self._registered_autograd_kernel_indirection:
  282. return
  283. if _C._dispatch_has_kernel_for_dispatch_key(
  284. self._qualname, "CompositeImplicitAutograd"
  285. ):
  286. raise RuntimeError(
  287. f"impl_backward/impl_save_for_backward: the operator {self._qualname} "
  288. f"already has an implementation for this device type via a "
  289. f"pre-existing registration to DispatchKey::CompositeImplicitAutograd."
  290. f"CompositeImplicitAutograd operators do not need an autograd formula; "
  291. f"instead, the operator will decompose into its constituents and those "
  292. f"can have autograd formulas defined on them."
  293. )
  294. # We can improve this by adding "all Autograd<BACKEND> keys", but
  295. # realistically people will just be using this API for CPU/CUDA for now.
  296. for key in ["Autograd", "AutogradCPU", "AutogradCUDA"]:
  297. if _C._dispatch_has_kernel_for_dispatch_key(self._qualname, key):
  298. raise RuntimeError(
  299. f"impl_backward/impl_save_for_backward: "
  300. f"the operator {self._qualname} already has an Autograd kernel "
  301. f"registered to DispatchKey::{key} vi a pre-existing "
  302. f"torch.library or TORCH_LIBRARY registration. Please either "
  303. f"remove those registrations or don't use the torch._custom_ops APIs"
  304. )
  305. def _check_doesnt_have_library_meta_impl(self):
  306. if self._has_impl("abstract"):
  307. return
  308. # If the user's operator is CompositeExplicitAutograd,
  309. # allow them to impl_abstract. This is being pragmatic
  310. # (existing custom ops may have CompositeExplicitAutograd
  311. # registration that don't work with Meta kernels, so this
  312. # gives them an escape hatch).
  313. if _C._dispatch_has_kernel_for_dispatch_key(
  314. self._qualname, "CompositeExplicitAutograd"
  315. ) and not _C._dispatch_has_kernel_for_dispatch_key(self._qualname, "Meta"):
  316. return
  317. # Otherwise, if the user's already has a Meta kernel or their
  318. # op is CompositeImplicitAutograd or some other alias dispatch key,
  319. # raise.
  320. # Special case for CompositeImplicitAutograd
  321. if _C._dispatch_has_kernel_for_dispatch_key(
  322. self._qualname, "CompositeImplicitAutograd"
  323. ):
  324. raise RuntimeError(
  325. f"impl_abstract(...): the operator {self._qualname} "
  326. f"already has an implementation for this device type via a "
  327. f"pre-existing registration to DispatchKey::CompositeImplicitAutograd."
  328. f"CompositeImplicitAutograd operators do not need an abstract impl; "
  329. f"instead, the operator will decompose into its constituents and those "
  330. f"can have abstract impls defined on them."
  331. )
  332. if _C._dispatch_has_kernel_for_dispatch_key(self._qualname, "Meta"):
  333. raise RuntimeError(
  334. f"impl_abstract(...): the operator {self._qualname} "
  335. f"already has an DispatchKey::Meta implementation via a "
  336. f"pre-existing torch.library or TORCH_LIBRARY registration. "
  337. f"Please either remove that registration or don't call impl_abstract."
  338. )
  339. # NOTE ["backward", "save_for_backward", and "autograd"]
  340. # As a part of the explicit autograd API, a user must provide us
  341. # a "save_for_backward" function and a "backward" function.
  342. # When both of these have been provided, then we automatically
  343. # construct the "autograd" kernel.
  344. def _register_autograd_kernel(self):
  345. assert self._has_impl("backward")
  346. assert self._has_impl("save_for_backward")
  347. kernel = construct_autograd_kernel(
  348. self._schema,
  349. self._output_differentiability,
  350. self,
  351. get_op(self._qualname),
  352. self._get_impl("save_for_backward").func,
  353. self._get_impl("backward").func,
  354. )
  355. self._register_impl("autograd", kernel)
  356. def impl_save_for_backward(self, _stacklevel=2):
  357. r"""Register a function that tells us what to save for backward.
  358. Please see impl_backward for more details.
  359. """
  360. def inner(f):
  361. self._check_can_register_backward()
  362. self._check_doesnt_have_library_autograd_impl()
  363. if not self._registered_autograd_kernel_indirection:
  364. self._register_autograd_kernel_indirection()
  365. self._register_impl("save_for_backward", f, stacklevel=_stacklevel)
  366. if self._has_impl("backward"):
  367. self._register_autograd_kernel()
  368. return inner
  369. def impl_backward(self, output_differentiability=None, _stacklevel=2):
  370. r"""
  371. This API is deprecated, please use torch.library.custom_op instead
  372. """
  373. if output_differentiability is not None:
  374. def yell():
  375. raise RuntimeError(
  376. f"impl_backward(output_differentiability): expected "
  377. f"output_differentiability to be a list of bools with "
  378. f"length equal to the number of outputs of this CustomOp "
  379. f"got: {output_differentiability}"
  380. )
  381. if not isinstance(output_differentiability, list):
  382. yell()
  383. for diff in output_differentiability:
  384. if not isinstance(diff, bool):
  385. yell()
  386. if len(self._schema.returns) != len(output_differentiability):
  387. yell()
  388. def inner(f):
  389. self._check_can_register_backward()
  390. self._check_doesnt_have_library_autograd_impl()
  391. if not self._registered_autograd_kernel_indirection:
  392. self._register_autograd_kernel_indirection()
  393. self._register_impl("backward", f, stacklevel=_stacklevel)
  394. self._output_differentiability = output_differentiability
  395. if self._has_impl("save_for_backward"):
  396. self._register_autograd_kernel()
  397. return inner
@dataclasses.dataclass
class FuncAndLocation:
    # Record stored in CustomOp._impls for each registered impl.
    # The registered callable.
    func: typing.Callable
    # "<filename>:<lineno>" of the frame that registered `func`.
    location: str
  402. def find_ophandle_or_throw(cpp_ns: str, operator_name: OperatorName):
  403. overload_name = (
  404. "" if operator_name.overload_name is None else operator_name.overload_name
  405. )
  406. return _C._dispatch_find_schema_or_throw(
  407. f"{cpp_ns}::{str(operator_name.name)}", overload_name
  408. )
  409. def validate_namespace(ns: str) -> None:
  410. if "." in ns:
  411. raise ValueError(
  412. f'custom_op(..., ns="{ns}"): expected ns to not contain any . (and be a '
  413. f"valid variable name)"
  414. )
  415. if ns in RESERVED_NS:
  416. raise ValueError(
  417. f"custom_op(..., ns='{ns}'): '{ns}' is a reserved namespace, "
  418. f"please choose something else. "
  419. )
  420. def validate_schema(schema: FunctionSchema) -> None:
  421. if not torch._library.utils.is_functional_schema(schema):
  422. raise ValueError(
  423. f"custom_op only supports functional operators "
  424. f"(ops that do not mutate any inputs, do not return "
  425. f"views of the inputs, and has at least one return). "
  426. f"Got the following non-functional schema: {schema}"
  427. )
  428. # For simplicity: don't allow self arguments
  429. if schema.arguments.self_arg is not None:
  430. raise ValueError(
  431. f"custom_op does not support arguments named 'self'. Please "
  432. f"rename your argument. Got: {schema}"
  433. )
  434. def parse_qualname(qualname: str) -> tuple[str, str]:
  435. names = qualname.split("::", 1)
  436. if len(names) != 2:
  437. raise ValueError(
  438. f"Expected there to be a namespace in {qualname}, i.e. The "
  439. f"operator name should look something like ns::foo"
  440. )
  441. if "." in names[1]:
  442. raise ValueError(
  443. f"The torch.custom_ops APIs do not handle overloads, "
  444. f"i.e. operator names with '.' in them. "
  445. f"Please name your operator something like ns::foo. "
  446. f"Got: {qualname}"
  447. )
  448. return names[0], names[1]
  449. def validate_device_type(device_type: str) -> None:
  450. if device_type not in SUPPORTED_DEVICE_TYPE_TO_KEY:
  451. raise ValueError(
  452. f"CustomOp.impl(device_types=[{device_type}, ...]): we only support device_type "
  453. f"in {SUPPORTED_DEVICE_TYPE_TO_KEY.keys()}."
  454. )
  455. def supported_param(param: inspect.Parameter) -> bool:
  456. return param.kind in (
  457. inspect.Parameter.POSITIONAL_OR_KEYWORD,
  458. inspect.Parameter.KEYWORD_ONLY,
  459. )
  460. def validate_function_matches_schema(
  461. schema: FunctionSchema, func: typing.Callable
  462. ) -> None:
  463. sig = inspect.signature(func)
  464. if not all(supported_param(p) for _, p in sig.parameters.items()):
  465. raise ValueError(
  466. f"custom_op(..., manual_schema)(func): positional-only args, "
  467. f"varargs, and kwargs are not supported. Please rewrite `func` "
  468. f"to not have them. Got `func` with signature: {sig}"
  469. )
  470. if (
  471. any(
  472. p.annotation is not inspect.Parameter.empty
  473. for _, p in sig.parameters.items()
  474. )
  475. or sig.return_annotation is not inspect.Signature.empty
  476. ):
  477. raise ValueError(
  478. f"custom_op(..., manual_schema)(func): When passing in a manual "
  479. f"schema, we expect `func` to have no type annotations to avoid "
  480. f"ambiguity. Got `func` with signature: {sig}"
  481. )
  482. positional = [
  483. (name, param)
  484. for name, param in sig.parameters.items()
  485. if param.kind == inspect.Parameter.POSITIONAL_OR_KEYWORD
  486. ]
  487. kwargonly = [
  488. (name, param)
  489. for name, param in sig.parameters.items()
  490. if param.kind == inspect.Parameter.KEYWORD_ONLY
  491. ]
  492. def error():
  493. raise ValueError(
  494. f"custom_op(..., manual_schema)(func): When passing in a manual "
  495. f"schema, we expect `func`'s signature to match `manual_schema` "
  496. f"(aside from type annotations). "
  497. f"func's signature: {sig}, manual_schema: {schema}"
  498. )
  499. def error_default_args():
  500. raise ValueError(
  501. f"custom_op(..., manual_schema)(func): "
  502. f"neither func nor manual_schema should have default "
  503. f"arguments. Got "
  504. f"func's signature: {sig}, manual_schema: {schema}"
  505. )
  506. def compare(sig_args, schema_args):
  507. if len(sig_args) != len(schema_args):
  508. error()
  509. for (name, param), arg in zip(sig_args, schema_args):
  510. if name != arg.name:
  511. error()
  512. if param.default is not inspect.Parameter.empty or arg.default is not None:
  513. error_default_args()
  514. compare(positional, schema.arguments.flat_positional)
  515. compare(kwargonly, schema.arguments.flat_kwarg_only)
  516. def report_error_callback(custom_op: typing.Any, key: str) -> None:
  517. if key == "Undefined":
  518. raise NotImplementedError(
  519. f"{custom_op}: There were no Tensor inputs to this operator "
  520. f"(e.g. you passed an empty list of Tensors). If your operator is a "
  521. f"factory function (that is, it takes no Tensors and constructs "
  522. f"a new one), then please use CustomOp.impl_factory to register "
  523. f"an implementation for it"
  524. )
  525. if key == "Meta":
  526. raise NotImplementedError(
  527. f"{custom_op}: when running with device='Meta' tensors: there is no "
  528. f"abstract impl registered for this CustomOp. Please register one via "
  529. f"CustomOp.impl_abstract to get this CustomOp to work with Meta tensors"
  530. )
  531. if key in ("CPU", "CUDA"):
  532. device = key.lower()
  533. raise NotImplementedError(
  534. f"{custom_op}: when running with device='{device}' tensors: there is no "
  535. f"{device} impl registered for this CustomOp. Please register one via "
  536. f"CustomOp.impl(device_type='{device}')"
  537. )
  538. raise NotImplementedError(
  539. f"{custom_op}: No implementation for dispatch key {key}. It is likely "
  540. f"that we have not added this functionality yet, please either open an "
  541. f"issue or if you're feeling adventurous, use the low-level "
  542. f"torch.library API"
  543. )
  544. def custom_op_from_existing(op):
  545. ns = op.namespace
  546. lib = torch.library.Library(ns, "FRAGMENT")
  547. name = op.name().split("::")[-1]
  548. schema_str = str(op._schema)
  549. # CustomOp expects the schema string without the namespace
  550. schema_str = schema_str.rsplit("::", maxsplit=1)[-1]
  551. schema = FunctionSchema.parse(schema_str)
  552. return CustomOp(lib, ns, schema, name, op, _private_access=True)
  553. def get_op(qualname):
  554. def error_not_found():
  555. raise ValueError(
  556. f"Could not find the operator {qualname}. Please make sure you have "
  557. f"already registered the operator and (if registered from C++) "
  558. f"loaded it via torch.ops.load_library."
  559. )
  560. ns, name = parse_qualname(qualname)
  561. if not hasattr(torch.ops, ns):
  562. error_not_found()
  563. opnamespace = getattr(torch.ops, ns)
  564. if not hasattr(opnamespace, name):
  565. error_not_found()
  566. packet = getattr(opnamespace, name)
  567. if not hasattr(packet, "default"):
  568. error_not_found()
  569. return packet.default
  570. def _find_custom_op(qualname, also_check_torch_library=False):
  571. if qualname in global_registry:
  572. return global_registry[qualname]
  573. if not also_check_torch_library:
  574. raise RuntimeError(
  575. f'Could not find custom op "{qualname}". Did you register it via '
  576. f"the torch._custom_ops API?"
  577. )
  578. overload = get_op(qualname)
  579. result = custom_op_from_existing(overload)
  580. return result
  581. def get_abstract_impl(qualname):
  582. if qualname not in torch._custom_op.impl.global_registry:
  583. return None
  584. custom_op = torch._custom_op.impl.global_registry[qualname]
  585. if custom_op is None:
  586. return None
  587. if not custom_op._has_impl("abstract"):
  588. return None
  589. return custom_op._get_impl("abstract").func
  590. def _custom_op_with_schema(qualname, schema, needs_fixed_stride_order=True):
  591. ns, name = qualname.split("::")
  592. schema_str = f"{name}{schema}"
  593. function_schema = FunctionSchema.parse(schema_str)
  594. validate_schema(function_schema)
  595. tags = [torch._C.Tag.needs_fixed_stride_order] if needs_fixed_stride_order else []
  596. lib = library.Library(ns, "FRAGMENT")
  597. lib.define(schema_str, tags=tags)
  598. ophandle = find_ophandle_or_throw(ns, function_schema.name)
  599. result = CustomOp(lib, ns, function_schema, name, ophandle, _private_access=True)
  600. result._register_autograd_kernel_indirection()
  601. torch._C._dispatch_set_report_error_callback(
  602. ophandle, functools.partial(report_error_callback, weakref.proxy(result))
  603. )
  604. return get_op(qualname)