_ops.py 59 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156115711581159116011611162116311641165116611671168116911701171117211731174117511761177117811791180118111821183118411851186118711881189119011911192119311941195119611971198119912001201120212031204120512061207120812091210121112121213121412151216121712181219122012211222122312241225122612271228122912301231123212331234123512361237123812391240124112421243124412451246124712481249125012511252125312541255125612571258125912601261126212631264126512661267126812691270127112721273127412751276127
71278127912801281128212831284128512861287128812891290129112921293129412951296129712981299130013011302130313041305130613071308130913101311131213131314131513161317131813191320132113221323132413251326132713281329133013311332133313341335133613371338133913401341134213431344134513461347134813491350135113521353135413551356135713581359136013611362136313641365136613671368136913701371137213731374137513761377137813791380138113821383138413851386138713881389139013911392139313941395139613971398139914001401140214031404140514061407140814091410141114121413141414151416141714181419142014211422142314241425142614271428142914301431143214331434143514361437143814391440144114421443144414451446144714481449
  1. # mypy: allow-untyped-defs
  2. import abc
  3. import contextlib
  4. import ctypes
  5. import importlib
  6. import inspect
  7. import sys
  8. import types
  9. from collections.abc import Callable, Iterator
  10. from functools import cached_property
  11. from typing import Any, ClassVar, Concatenate, final, Generic, TYPE_CHECKING
  12. from typing_extensions import ParamSpec, TypeVar
  13. import torch
  14. import torch.utils._pytree as pytree
  15. from torch import _utils_internal
  16. from torch._C import _dispatch_is_included_in_alias as is_included_in_alias, DispatchKey
  17. from torch._functorch.pyfunctorch import dispatch_functorch, TransformType
  18. from torch.utils._python_dispatch import TorchDispatchMode
  19. if TYPE_CHECKING:
  20. from torch._subclasses.functional_tensor import BaseFunctionalizeAPI
# Type variables with defaults (`default=`), used to type decorator helpers such
# as `py_impl` that hand back the decorated callable with its signature intact.
_T = TypeVar("_T", default=Any)
_P = ParamSpec("_P", default=...)
  23. # Query `hasattr` only once.
  24. _SET_GLOBAL_FLAGS = hasattr(sys, "getdlopenflags") and hasattr(sys, "setdlopenflags")
  25. @contextlib.contextmanager
  26. def dl_open_guard():
  27. """
  28. Context manager to set the RTLD_GLOBAL dynamic linker flag while we open a
  29. shared library to load custom operators.
  30. """
  31. if not _SET_GLOBAL_FLAGS:
  32. yield
  33. return
  34. old_flags = sys.getdlopenflags()
  35. sys.setdlopenflags(old_flags | ctypes.RTLD_GLOBAL)
  36. try:
  37. yield
  38. finally:
  39. sys.setdlopenflags(old_flags)
  40. class OperatorBase:
  41. """
  42. Base class for OpOverload (which represents C++ ATen operators) and HigherOrderOperator
  43. (which represents Python-only operators that are unrepresentable in TorchScript).
  44. """
  45. def __init__(self):
  46. # The dispatch cache precomputes a mapping of dispatch key that the
  47. # dispatcher wants to dispatch to, to an actual implementation of the
  48. # dispatch key. Confusingly, the actual implementation could *also* be a
  49. # dispatch key, but in this case, this refers to the C++ kernel that
  50. # was registered to some dispatch key. Aliases are permitted in the
  51. # latter but not the former; for example, you might lookup the
  52. # entry for AutogradCPU, and this maps you to the Autograd key for
  53. # the generic autograd kernel that works for all devices. Since this
  54. # is the Python dispatcher, you can also put an arbitrary Python
  55. # callable to call instead. This handler gets precisely the
  56. # args/kwargs that the operator was __call__'ed with.
  57. # NB: This name is hard-coded in torch/csrc/autograd/python_variable.cpp
  58. # for use with OpOverload; cache lookup is done entirely from C++
  59. # for speed.
  60. # TODO: The cache is NOT currently used by HigherOrderOperator, but it should!
  61. self._dispatch_cache: dict[DispatchKey, DispatchKey | Callable[..., Any]] = {}
  62. # This table allows you to override the behavior of a particular
  63. # dispatch key to call a custom Python function, rather than the
  64. # ordinary C++ configured behavior. This is the raison d'etre of # codespell:ignore
  65. # Python dispatcher: to let you program the dispatcher from Python
  66. # in case you need something unusual, and don't want to clobber
  67. # the existing registrations using the Python operator registration
  68. # API.
  69. self.py_kernels: dict[DispatchKey, Callable[..., Any]] = {}
  70. # This table allows you to override the behavior of a particular
  71. # operator for a particular TorchDispatchMode. In practice,
  72. # we are using this mostly for ProxyTensorMode. Modes can be
  73. # thought of as an open world extension of dispatch keys, so it
  74. # makes sense that you should be able to register them, the same
  75. # way you can register dispatch keys.
  76. self.python_key_table: dict[
  77. type[TorchDispatchMode | torch.Tensor], Callable[..., Any]
  78. ] = {}
  79. # This table allows you to override the behavior of functorch
  80. # transformations. NB: this currently only does something for
  81. # HigherOrderOperator
  82. self.functorch_table = {}
  83. def __call__(self, *args, **kwargs):
  84. raise NotImplementedError
  85. def has_kernel_for_dispatch_key(self, k):
  86. return k in self.py_kernels
  87. def has_kernel_for_any_dispatch_key(self, ks):
  88. for k in self.py_kernels:
  89. if not torch._C._dispatch_is_alias_key(k) and ks.has(k):
  90. return True
  91. return False
  92. def py_impl(
  93. self,
  94. k: type[TorchDispatchMode] | type[torch.Tensor] | TransformType | DispatchKey,
  95. ) -> Callable[[Callable[_P, _T]], Callable[_P, _T]]:
  96. def inner(fn: Callable[_P, _T]) -> Callable[_P, _T]:
  97. if inspect.isclass(k) and (
  98. issubclass(k, TorchDispatchMode) or issubclass(k, torch.Tensor)
  99. ):
  100. assert k not in self.python_key_table
  101. # TODO(voz): Should we replace setting DispatchKey.Python entirely with setting mode keys?
  102. self.python_key_table[k] = fn
  103. self._dispatch_cache.clear()
  104. return fn
  105. if isinstance(k, TransformType):
  106. assert k not in self.functorch_table
  107. self.functorch_table[k] = fn
  108. return fn
  109. assert isinstance(k, DispatchKey)
  110. assert k != DispatchKey.Python, (
  111. "Please register a mode for the DispatchKey.Python key instead."
  112. )
  113. if k in self.py_kernels:
  114. raise RuntimeError(
  115. f"Trying to override a python impl for {k} on operator {self.name()}"
  116. )
  117. self.py_kernels[k] = fn
  118. self._dispatch_cache.clear()
  119. return fn
  120. return inner
  121. # Registers an implementation to all **3** variants of functionalization that we have:
  122. # - DispatchKey.Functionalize
  123. # - functorch.TransformType.Functionalize
  124. # - FunctionalTensorMode
  125. # Example:
  126. # @py_functionalize_impl
  127. # def functionalize_rule(ctx, inner_f, *args):
  128. # args_unwrapped = ctx.unwrap_tensors(args)
  129. # with ctx.redispatch_to_next():
  130. # out = ctx.functionalize(inner_f)(*args_unwrapped)
  131. # return ctx.wrap_tensors(out)
  132. def py_functionalize_impl(
  133. self, fn: Callable[Concatenate["BaseFunctionalizeAPI", _P], _T]
  134. ) -> Callable[Concatenate["BaseFunctionalizeAPI", _P], _T]:
  135. from torch._subclasses.functional_tensor import (
  136. CppFunctionalizeAPI,
  137. FunctionalTensorMode,
  138. FunctorchFunctionalizeAPI,
  139. PythonFunctionalizeAPI,
  140. )
  141. # Construct our three flavors of functionalization,
  142. # each of which have slightly different wrap/unwrap/redispatch policies
  143. def functionalize_dk_fn(*args: _P.args, **kwargs: _P.kwargs) -> _T:
  144. return fn(CppFunctionalizeAPI(), *args, **kwargs)
  145. def functionalize_dispatch_mode_fn(
  146. mode: FunctionalTensorMode | None, *args: _P.args, **kwargs: _P.kwargs
  147. ) -> _T:
  148. return fn(PythonFunctionalizeAPI(mode), *args, **kwargs)
  149. def functionalize_functorch_fn(
  150. interpreter, *args: _P.args, **kwargs: _P.kwargs
  151. ) -> _T:
  152. return fn(FunctorchFunctionalizeAPI(interpreter), *args, **kwargs)
  153. self.py_impl(DispatchKey.Functionalize)(functionalize_dk_fn)
  154. self.py_impl(FunctionalTensorMode)(functionalize_dispatch_mode_fn)
  155. self.py_impl(TransformType.Functionalize)(functionalize_functorch_fn)
  156. return fn
  157. def name(self):
  158. raise NotImplementedError
# Equivalent to computeDispatchTableEntryWithDebug
def resolve_key(op: OperatorBase, k: DispatchKey):  # type: ignore[valid-type]
    """
    Resolve which registered kernel key of ``op`` should handle dispatch key ``k``.

    Python mirror of the C++ dispatcher's computeDispatchTableEntryWithDebug;
    the order of the checks below is significant.

    Args:
        op: operator whose Python kernel registrations are consulted via
            ``has_kernel_for_dispatch_key`` / ``has_kernel_for_any_dispatch_key``.
        k: the runtime dispatch key being resolved.

    Returns:
        The DispatchKey whose kernel should run: ``k`` itself (direct
        registration, or backend fallback) or one of the alias keys checked below.

    Raises:
        RuntimeError: if ``k`` is AutogradOther, ``op`` has a
            CompositeImplicitAutograd kernel, and ``op`` also has a kernel for one
            of the AutogradOther backends (the choice would be ambiguous).
        NotImplementedError: if no kernel and no backend fallback is found.
    """
    # 1. (Direct) operator registration
    if op.has_kernel_for_dispatch_key(k):
        return k
    # 2.1 Use CompositeExplicitAutogradNonFunctional kernel if available
    cand = DispatchKey.CompositeExplicitAutogradNonFunctional
    if (
        k == DispatchKey.Undefined or is_included_in_alias(k, cand)
    ) and op.has_kernel_for_dispatch_key(cand):
        return cand
    # 2.2 Use CompositeExplicitAutograd kernel if available
    cand = DispatchKey.CompositeExplicitAutograd
    if (
        k == DispatchKey.Undefined or is_included_in_alias(k, cand)
    ) and op.has_kernel_for_dispatch_key(cand):
        return cand
    # True when a backend-specific kernel (for a backend covered by k's
    # autograd key) or a CompositeExplicitAutograd kernel exists; such kernels
    # take precedence over the implicit-autograd composites below.
    has_backend_kernel = op.has_kernel_for_any_dispatch_key(
        torch._C._dispatch_get_backend_keyset_from_autograd(k)
    ) or op.has_kernel_for_dispatch_key(DispatchKey.CompositeExplicitAutograd)
    # 2.3. Use CompositeImplicitAutograd kernel if available
    # The NestedTensor variant is only considered for a concrete (non-Undefined) key.
    cand = DispatchKey.CompositeImplicitAutogradNestedTensor
    if (
        (k != DispatchKey.Undefined and is_included_in_alias(k, cand))
        and op.has_kernel_for_dispatch_key(cand)
        and not has_backend_kernel
    ):
        return cand
    cand = DispatchKey.CompositeImplicitAutograd
    if (
        k == DispatchKey.Undefined or is_included_in_alias(k, cand)
    ) and op.has_kernel_for_dispatch_key(cand):
        if k == DispatchKey.AutogradOther and op.has_kernel_for_any_dispatch_key(
            torch._C._dispatch_autogradother_backends
        ):
            raise RuntimeError("ambiguous autogradother kernel")
        elif not has_backend_kernel:
            return cand
    # 2.4. For autograd backend keys, use kernel from DispatchKey::Autograd if available
    cand = DispatchKey.Autograd
    if is_included_in_alias(k, cand) and op.has_kernel_for_dispatch_key(cand):
        return cand
    # 2.5 Use kernel from DispatchKey::FuncTorchBatchedDecomposition if available
    cand = DispatchKey.FuncTorchBatchedDecomposition
    if is_included_in_alias(k, cand) and op.has_kernel_for_dispatch_key(cand):
        return cand
    # Backend fallback
    if torch._C._dispatch_has_backend_fallback(k):
        # The dispatch key itself will implicitly route to backend fallback.
        # This is probably not great for the pure Python implementation.
        return k
    raise NotImplementedError(f"could not find kernel for {op} at dispatch key {k}")
# Registry of every HigherOrderOperator instance, keyed by its name.
# Populated by HigherOrderOperator.__init__ (a later instance with the same
# name replaces the earlier entry).
_higher_order_ops: dict[str, "HigherOrderOperator"] = {}

# Dispatch keys that every HigherOrderOperator falls through by default:
# HigherOrderOperator.__init__ removes each of them from the operator's
# `non_fallthrough_keys`, and registering a py_impl for one of these keys
# adds it back.
_HIGHER_ORDER_OP_DEFAULT_FALLTHROUGH_DISPATCH_KEYS = [
    DispatchKey.PythonDispatcher,  # type: ignore[attr-defined]
    DispatchKey.PythonTLSSnapshot,  # type: ignore[attr-defined]
    DispatchKey.ADInplaceOrView,
    DispatchKey.BackendSelect,
    DispatchKey.AutocastCPU,  # type: ignore[attr-defined]
    DispatchKey.AutocastCUDA,  # type: ignore[attr-defined]
    DispatchKey.AutocastXPU,  # type: ignore[attr-defined]
]
  221. class HigherOrderOperator(OperatorBase, abc.ABC):
  222. # The HigherOrderOperator will appear as torch.ops.higher_order.{name}
  223. #
  224. # If you're creating a new HigherOrderOperator, please do not change the
  225. # default. Adding operators to the global torch.ops namespace is a bad
  226. # practice due to name collisions.
  227. def __init__(self, name, *, cacheable=False):
  228. super().__init__()
  229. if type(self) is HigherOrderOperator:
  230. raise RuntimeError(
  231. "Direct instantiation of HigherOrderOperator is not allowed. Please subclass it."
  232. )
  233. self._name = name
  234. # Make _OPNamespace not scream, this whole name based association needs a good hard look
  235. self.__name__ = name
  236. _higher_order_ops[name] = self
  237. self._ns = "higher_order"
  238. self.__module__ = "torch.ops.higher_order"
  239. self._cacheable = cacheable
  240. self.non_fallthrough_keys = torch._C._dispatch_keyset_full()
  241. for dispatch_key in _HIGHER_ORDER_OP_DEFAULT_FALLTHROUGH_DISPATCH_KEYS:
  242. self.fallthrough(dispatch_key)
  243. # [NOTE] We have to register pre-dispatch key implementation
  244. # because sometimes HOP use aot-dispatch tracing to detect certain
  245. # mutations. This is problematic when we are functionalizing HOP
  246. # during pre-dispatch because when the inner tracer starts, it will see
  247. # that PreDispatch key is still active. In that case, we just redispatch
  248. # it to next key. This is only safe to do when PreDispatch key stack has no
  249. # active modes.
  250. def py_impl(
  251. self,
  252. k: type[TorchDispatchMode] | type[torch.Tensor] | TransformType | DispatchKey,
  253. ) -> Callable[[Callable[_P, _T]], Callable[_P, _T]]:
  254. if isinstance(k, DispatchKey) and not self.non_fallthrough_keys.has(k):
  255. self.non_fallthrough_keys = self.non_fallthrough_keys.add(k)
  256. return super().py_impl(k)
  257. def py_autograd_impl(
  258. self,
  259. fn: Callable[_P, _T],
  260. ) -> Callable[_P, _T]:
  261. def maybe_run_autograd(*args: _P.args, **kwargs: _P.kwargs) -> _T:
  262. if not torch.is_grad_enabled() or pytree.tree_all_only(
  263. torch.Tensor,
  264. lambda t: not t.requires_grad, # type: ignore[union-attr]
  265. (*args, kwargs),
  266. ):
  267. with torch._C._AutoDispatchBelowAutograd():
  268. return self(*args, **kwargs)
  269. from torch._higher_order_ops.utils import _has_gen_schema
  270. if _has_gen_schema(self):
  271. schema = self.gen_schema(*args, **kwargs)
  272. if any(arg.is_write for arg in schema.arguments):
  273. raise RuntimeError(
  274. f"The {self.name()} HigherOrderOperator does not currently support training "
  275. "with in-place input or buffer mutations "
  276. "If you require this feature, please submit an issue to PyTorch. "
  277. "Alternatively, consider creating your own custom autograd.Function. "
  278. )
  279. return fn(*args, **kwargs)
  280. self.py_impl(DispatchKey.Autograd)(maybe_run_autograd)
  281. return fn
  282. @property
  283. def namespace(self):
  284. return self._ns
  285. @final
  286. def cacheable(self) -> bool:
  287. from torch._functorch.autograd_function import AutogradFunctionApply
  288. return (
  289. self._cacheable
  290. or f"{self.__module__}.{self.__name__}"
  291. in torch._inductor.config.unsafe_marked_cacheable_functions
  292. or (
  293. isinstance(self, AutogradFunctionApply)
  294. and torch._functorch.config.autograd_cache_allow_custom_autograd_functions
  295. )
  296. )
  297. def fallthrough(self, dispatch_key):
  298. self.non_fallthrough_keys = self.non_fallthrough_keys.remove(dispatch_key)
  299. # Use positional-only argument to avoid naming collide with custom ops arguments
  300. # that are named "self".
  301. def dispatch(self, /, dispatch_key, *args, **kwargs):
  302. from torch.utils._python_dispatch import _get_current_dispatch_mode
  303. if dispatch_key in self._dispatch_cache:
  304. kernel = self._dispatch_cache[dispatch_key]
  305. assert not isinstance(kernel, DispatchKey)
  306. return kernel(*args, **kwargs)
  307. if dispatch_key == DispatchKey.FuncTorchDynamicLayerFrontMode:
  308. return dispatch_functorch(self, args, kwargs)
  309. if dispatch_key == DispatchKey.Python:
  310. # Keep the following 1:1 with handle_torch_function_no_python_arg_parser
  311. # in torch/csrc/utils/python_arg_parser.cpp
  312. overloaded_args_list = []
  313. def has_python_key(tensor):
  314. return torch._C._dispatch_keys(tensor).has("Python")
  315. def check_overloaded(arg):
  316. if isinstance(arg, torch.Tensor) and has_python_key(arg):
  317. overloaded_args_list.append(arg)
  318. for arg in (*args, *kwargs.values()):
  319. check_overloaded(arg)
  320. if isinstance(arg, (list, tuple)):
  321. for a in arg:
  322. check_overloaded(a)
  323. overloaded_args = tuple(overloaded_args_list)
  324. # Step 1: dispatch on any user TorchDispatchModes
  325. from torch.utils._python_dispatch import _pop_mode_temporarily
  326. curr_mode = _get_current_dispatch_mode()
  327. if curr_mode is not None:
  328. if type(curr_mode) in self.python_key_table:
  329. handler = self.python_key_table[type(curr_mode)]
  330. with _pop_mode_temporarily() as mode:
  331. # "natural" calling convention: (mode, *args, **kwargs)
  332. # TODO(rzou): we should support torch_dispatch calling convention too.
  333. result = handler(mode, *args, **kwargs)
  334. else:
  335. if curr_mode.supports_higher_order_operators:
  336. with _pop_mode_temporarily() as mode:
  337. return curr_mode.__torch_dispatch__(self, [], args, kwargs)
  338. else:
  339. raise NotImplementedError(
  340. f"There was no rule registered for HigherOrderOperator {self._name} and mode {curr_mode}."
  341. f"Hint: set {curr_mode}'s supports_higher_order_operators to True."
  342. f" This causes all higher order operators to pass through {curr_mode}'s __torch_dispatch__,"
  343. f" so handle them accordingly by"
  344. f" adding support for HigerOrderOperators (in this case, {self._name}) in"
  345. f" {curr_mode}.__torch_dispatch__ or"
  346. f" returning NotImplemented when not supported."
  347. )
  348. if result is not NotImplemented:
  349. return result
  350. # Step 2: dispatch on any subclasses
  351. for arg in overloaded_args:
  352. subclass_type = type(arg)
  353. if (
  354. subclass_type.__torch_dispatch__
  355. is torch._C._disabled_torch_dispatch_impl
  356. ):
  357. continue
  358. # In some case, people are using FakeTensor without a FakeTensorMode.
  359. # For example, some sparse arch model has a mix of FakeTensor and real
  360. # tensor for weights during lowering, and ppl tends to run eager evaluation
  361. # on the model without setting up the FakeTensorMode.
  362. # In this case, we pull FakeTensorMode impl.
  363. if subclass_type is torch._subclasses.fake_tensor.FakeTensor:
  364. subclass_type = torch._subclasses.fake_tensor.FakeTensorMode # type: ignore[assignment]
  365. handler = self.python_key_table[subclass_type]
  366. result = handler(arg.fake_mode, *args, **kwargs) # type: ignore[attr-defined]
  367. return result
  368. if subclass_type in self.python_key_table:
  369. handler = self.python_key_table[subclass_type]
  370. # "natural" calling convention: (*args, **kwargs)
  371. # TODO(rzou): we should support torch_dispatch calling convention too.
  372. result = handler(*args, **kwargs)
  373. else:
  374. raise NotImplementedError(
  375. f"There was no rule registered for HOP {self._name} and subclass {subclass_type}. "
  376. f"We recommend filing an issue."
  377. )
  378. if result is not NotImplemented:
  379. return result
  380. # All handlers returned NotImplemented
  381. raise TypeError(
  382. f"HigherOrderOperator '{self._name}' is not supported for the given input types. "
  383. f"This typically happens when using custom tensor types or dispatch modes that don't "
  384. f"have implementations for this operation.\n\n"
  385. f"Current mode: {curr_mode}\n"
  386. f"Input types: {[type(a).__name__ for a in overloaded_args]}\n\n"
  387. f"To fix this, can add support for '{self._name}' in {curr_mode}'s __torch_dispatch__\n"
  388. )
  389. functionality_key = torch._C._to_functionality_key(dispatch_key) # type: ignore[attr-defined]
  390. if functionality_key == DispatchKey.PreDispatch:
  391. from torch.utils._python_dispatch import _pop_mode_temporarily
  392. # The check for Python in the exclude set is so we properly respect `with no_dispatch()`
  393. # calls inside of a mode.
  394. if (
  395. _len_torch_dispatch_stack_pre_dispatch() > 0
  396. ) and not torch._C._dispatch_tls_is_dispatch_key_excluded(
  397. DispatchKey.Python
  398. ):
  399. curr_mode = _get_current_dispatch_mode_pre_dispatch()
  400. assert curr_mode is not None, (
  401. "Illegal invocation of dispatch on DispatchKey.PreDispatch without a mode."
  402. )
  403. assert type(curr_mode) in self.python_key_table, (
  404. f"Current active mode {curr_mode} not registered"
  405. )
  406. handler = self.python_key_table[type(curr_mode)]
  407. with _pop_mode_temporarily(functionality_key) as mode:
  408. return handler(mode, *args, **kwargs)
  409. final_key = resolve_key(self, dispatch_key)
  410. # This can current fail due to backend fallbacks. You just have to
  411. # register them by hand for HigherOrderOperator.
  412. if final_key not in self.py_kernels:
  413. raise NotImplementedError(
  414. f"could not find kernel for HigherOrderOperator {self._name} "
  415. f"at dispatch key {final_key} (resolved from {dispatch_key})"
  416. )
  417. # [NOTE] We shouldn't cache PreDispatch kernel here because depending
  418. # on what modes are active, predispatch behaviour is different.
  419. # Also we do same thing for normal ops:
  420. # See Note [Not Caching Per-Dispatch-Key Mode Handlers]
  421. if dispatch_key != DispatchKey.PreDispatch:
  422. self._dispatch_cache[dispatch_key] = self.py_kernels[final_key]
  423. kernel = self.py_kernels[final_key]
  424. # It's illegal to register DispatchKey to py_kernels, since there's no
  425. # C++ kernel to call into
  426. assert not isinstance(kernel, DispatchKey)
  427. return kernel(*args, **kwargs)
  428. @abc.abstractmethod
  429. def __call__(self, /, *args, **kwargs):
  430. flat_args = _to_flat_tuple(args, kwargs)
  431. if torch.overrides.has_torch_function(flat_args):
  432. return torch.overrides.handle_torch_function(
  433. self, flat_args, *args, **kwargs
  434. )
  435. dispatch_key_set = _compute_keyset(args, kwargs, self.non_fallthrough_keys)
  436. return self.dispatch(dispatch_key_set.highestPriorityTypeId(), *args, **kwargs)
  437. # NOTE [HigherOrderOperator Schema]
  438. # Each invocation of a HigherOrderOperator (hop) should have its own schema because
  439. # the subgraphs and the arguments can be different even for the same hop.
  440. #
  441. # Each hop should implement its own gen_schema method, which should
  442. # take the same input as the __call__ method and returns a FunctionSchema.
  443. # The schema provides a unified way to check if the hop mutates its inputs,
  444. # which can be useful in implementing optimizations.
  445. #
  446. # If the hop doesn't implement the gen_schema method,
  447. # we expect it to be functional. It should not mutate its inputs and there
  448. # are no input, output aliasing via views or direct referencing.
  449. def gen_schema(self, *args, **kwargs):
  450. raise NotImplementedError(
  451. f"HigherOrderOperator {self._name} does not implement a gen_schema. "
  452. f"This is OK as long as the hop is functional. "
  453. f"e.g. it should not mutate its inputs and there are no input, output aliasing "
  454. f"via views or direct referencing."
  455. )
  456. def __str__(self):
  457. return f"{self.name()}"
  458. def name(self):
  459. return self._name
  460. # it's a no-op since HigherOrderOperator is immutable and must be unique for a given op.
  461. def __deepcopy__(self, memo=None):
  462. return self
  463. def _to_flat_tuple(args, kwargs):
  464. return pytree.arg_tree_leaves(*args, **kwargs)
  465. def _compute_keyset(args, kwargs, non_fallthrough_keys):
  466. tensors = _get_tensors(args, kwargs)
  467. return key_extractor(tensors, non_fallthrough_keys)
  468. def _get_tensors(args, kwargs):
  469. flat_all = _to_flat_tuple(args, kwargs)
  470. tensor_args = [t for t in flat_all if isinstance(t, torch.Tensor)]
  471. return tuple(tensor_args)
  472. # Note - this should maintain identical impl to the C++ dispatcher key extraction logic
  473. # at ATen/core/dispatch/DispatchKeyExtractor.h
  474. def key_extractor(tensors, key_mask):
  475. key_set = torch._C._dispatch_tls_local_include_set()
  476. for tensor in tensors:
  477. key_set = key_set | torch._C._dispatch_keys(tensor)
  478. key_set = key_set - torch._C._dispatch_tls_local_exclude_set()
  479. key_set = key_set & key_mask
  480. return key_set
  481. # Mode stack for PreDispatchKey
  482. # it should always have three keys with
  483. # priority given to FunctionalTensorMode and
  484. # then ProxyTorchDispatchMode. It means that
  485. # slot 0 belongs to ProxyTorchDispatchMode and
  486. # slot 1 belongs to FunctionalTensorMode.
  487. #
  488. # SchemaCheckMode is separate from the other 2,
  489. # and is only valid when the stack is empty.
  490. # SchemaCheckMode is for testing purposes, and
  491. # is meant to run in eager mode on concrete inputs,
  492. # checking for incorrect schemas in regards to
  493. # aliasing or mutating ops.
  494. class _ModeStackStateForPreDispatch:
  495. def __init__(self):
  496. self.__infra_modes = [None, None]
  497. self._schema_check_mode = None
  498. def set(self, index, mode):
  499. assert index < len(self.__infra_modes)
  500. self.__infra_modes[index] = mode
  501. def get(self, index):
  502. assert index < len(self.__infra_modes)
  503. return self.__infra_modes[index]
  504. def count(self):
  505. return len([i for i in self.__infra_modes if i is not None]) + int(
  506. self._schema_check_mode is not None
  507. )
# Module-level instance holding the PreDispatch mode slots.
# NOTE(review): presumably this is the object returned by
# `mode_stack_state_for_pre_dispatch()` (defined outside this view) — confirm.
_mode_stack_state_for_pre_dispatch = _ModeStackStateForPreDispatch()
  509. def unset_mode_pre_dispatch(mode_key, schema_check=False):
  510. current_mode_stack_pre_dispatch = mode_stack_state_for_pre_dispatch()
  511. assert mode_key is None or mode_key in (
  512. torch._C._TorchDispatchModeKey.PROXY,
  513. torch._C._TorchDispatchModeKey.FUNCTIONAL,
  514. )
  515. if schema_check:
  516. assert mode_key is None
  517. def _unset_mode():
  518. # NOTE: Using `is` rather than `==` to work around slow enum comparison in
  519. # pybind11.
  520. if mode_key is torch._C._TorchDispatchModeKey.PROXY:
  521. current_mode = current_mode_stack_pre_dispatch.get(0)
  522. mode_stack_state_for_pre_dispatch().set(0, None)
  523. return current_mode
  524. elif mode_key is torch._C._TorchDispatchModeKey.FUNCTIONAL:
  525. current_mode = current_mode_stack_pre_dispatch.get(1)
  526. mode_stack_state_for_pre_dispatch().set(1, None)
  527. return current_mode
  528. else:
  529. current_mode = mode_stack_state_for_pre_dispatch()._schema_check_mode
  530. mode_stack_state_for_pre_dispatch()._schema_check_mode = None
  531. return current_mode
  532. current_mode = _unset_mode()
  533. new_pre_dispatch_len = _len_torch_dispatch_stack_pre_dispatch()
  534. # When we are unsetting a mode, we need to check if there is
  535. # active mode left on the PreDispatch key. If there is nothing
  536. # active, we need to remove PreDispatch key from local dispatch include
  537. # set.
  538. if new_pre_dispatch_len == 0:
  539. torch._C._dispatch_tls_set_dispatch_key_included(DispatchKey.PreDispatch, False)
  540. return current_mode
  541. def _set_mode_pre_dispatch(mode):
  542. from torch._subclasses.functional_tensor import FunctionalTensorMode
  543. from torch._subclasses.schema_check_mode import SchemaCheckMode
  544. from torch.fx.experimental.proxy_tensor import ProxyTorchDispatchMode
  545. assert isinstance(
  546. mode,
  547. (
  548. FunctionalTensorMode,
  549. ProxyTorchDispatchMode,
  550. SchemaCheckMode,
  551. ),
  552. )
  553. previous_mode_stack_len = _len_torch_dispatch_stack_pre_dispatch()
  554. if isinstance(mode, SchemaCheckMode):
  555. current_mode = mode_stack_state_for_pre_dispatch()._schema_check_mode
  556. if previous_mode_stack_len > 0:
  557. raise AssertionError(
  558. "SchemaCheckMode for pre-dispatch must be used exclusively, found other modes on the stack"
  559. )
  560. mode_stack_state_for_pre_dispatch()._schema_check_mode = mode
  561. elif isinstance(mode, FunctionalTensorMode):
  562. current_mode = mode_stack_state_for_pre_dispatch().get(1)
  563. assert current_mode is None
  564. mode_stack_state_for_pre_dispatch().set(1, mode)
  565. else:
  566. current_mode = mode_stack_state_for_pre_dispatch().get(0)
  567. assert current_mode is None
  568. mode_stack_state_for_pre_dispatch().set(0, mode)
  569. # When we are setting a mode, we need to check if there is
  570. # active mode left on the PreDispatch key. If there was nothing
  571. # active before setting this mode, it means that PreDispatch key
  572. # was turned off. So we need to turn it on again.
  573. if previous_mode_stack_len == 0:
  574. torch._C._dispatch_tls_set_dispatch_key_included(DispatchKey.PreDispatch, True)
  575. def _pop_mode_from_pre_dispatch():
  576. mode_stack = mode_stack_state_for_pre_dispatch()
  577. pre_dispatch_len = _len_torch_dispatch_stack_pre_dispatch()
  578. if pre_dispatch_len == 0:
  579. raise AssertionError("Trying to pop empty mode stack")
  580. if mode_stack._schema_check_mode is not None:
  581. return unset_mode_pre_dispatch(None, schema_check=True)
  582. if mode_stack.get(1) is not None:
  583. return unset_mode_pre_dispatch(torch._C._TorchDispatchModeKey.FUNCTIONAL)
  584. if mode_stack.get(0) is not None:
  585. return unset_mode_pre_dispatch(torch._C._TorchDispatchModeKey.PROXY)
  586. def _len_torch_dispatch_stack_pre_dispatch():
  587. return mode_stack_state_for_pre_dispatch().count()
  588. def _get_dispatch_mode_pre_dispatch(mode_key):
  589. # NOTE: Using `is` rather than `==` to work around slow enum comparison in pybind11.
  590. if mode_key is torch._C._TorchDispatchModeKey.PROXY:
  591. return mode_stack_state_for_pre_dispatch().get(0)
  592. else:
  593. assert mode_key is torch._C._TorchDispatchModeKey.FUNCTIONAL
  594. return mode_stack_state_for_pre_dispatch().get(1)
  595. def _get_current_dispatch_mode_pre_dispatch():
  596. if mode_stack_state_for_pre_dispatch()._schema_check_mode is not None:
  597. return mode_stack_state_for_pre_dispatch()._schema_check_mode
  598. else:
  599. stack_len = mode_stack_state_for_pre_dispatch().count()
  600. if stack_len == 2:
  601. return mode_stack_state_for_pre_dispatch().get(1)
  602. if stack_len == 1:
  603. return (
  604. mode_stack_state_for_pre_dispatch().get(1)
  605. if mode_stack_state_for_pre_dispatch().get(1) is not None
  606. else mode_stack_state_for_pre_dispatch().get(0)
  607. )
  608. return None
  609. def mode_stack_state_for_pre_dispatch():
  610. global _mode_stack_state_for_pre_dispatch
  611. return _mode_stack_state_for_pre_dispatch
  612. cached_ops: set["OpOverload"] = set()
  613. def add_cached_op(op_overload):
  614. global cached_ops
  615. cached_ops.add(op_overload)
  616. def reset_cached_ops():
  617. global cached_ops
  618. cached_ops.clear()
  619. def get_cached_ops():
  620. global cached_ops
  621. return cached_ops
  622. # Each OpOverload object contains pointer to a specific operator overload, a pointer to the parent `OpOverloadPacket` object.
  623. # You can obtain an OpOverload object through attribute query on OpOverloadPacket.
  624. class OpOverload(OperatorBase, Generic[_P, _T]):
  625. def __init__(
  626. self,
  627. overloadpacket: "OpOverloadPacket",
  628. op: Callable[_P, _T],
  629. op_dk: Callable[Concatenate[DispatchKey, _P], _T],
  630. schema: torch._C.FunctionSchema,
  631. tags: list[Any],
  632. ) -> None:
  633. super().__init__()
  634. self._op = op
  635. self._op_dk = op_dk
  636. self._schema = schema
  637. self._overloadpacket = overloadpacket
  638. self._tags = tags
  639. self._overloadname = (
  640. "default" if schema.overload_name == "" else schema.overload_name
  641. )
  642. if tags:
  643. self._nondeterministic_seeded = torch.Tag.nondeterministic_seeded in tags
  644. self._name = self._schema.name
  645. if schema.overload_name:
  646. self._name += "." + schema.overload_name
  647. self.__name__ = f"{self._schema.name.split('::')[1]}.{self._overloadname}"
  648. self.__module__ = overloadpacket.__module__
  649. op.__module__ = overloadpacket.__module__
  650. self.__qualname__ = self._name
  651. self.__annotations__ = {}
  652. # If the OpOverload was constructed from a Library.def in Python.
  653. self._defined_in_python = self.__qualname__ in torch.library._defs
  654. # Logic replicated from aten/src/ATen/native/MathBitsFallback.h
  655. is_write = None
  656. for a in self._schema.arguments: # pyrefly: ignore # bad-assignment
  657. if a.alias_info is None:
  658. continue
  659. if is_write is None:
  660. is_write = a.alias_info.is_write
  661. else:
  662. # We will conservatively call mixed mutable/non-mutable
  663. # aliased inputs as NOT a view
  664. is_write = a.alias_info.is_write or is_write
  665. self.is_view = is_write is not None and not is_write
  666. @cached_property
  667. def _namespace(self) -> str:
  668. return self._schema.name.split("::", maxsplit=1)[0]
  669. @cached_property
  670. def _opname(self) -> str:
  671. return self._schema.name.split("::", maxsplit=1)[1]
  672. @cached_property
  673. def _handle(self) -> torch._C._DispatchOperatorHandle:
  674. return torch._C._dispatch_find_schema_or_throw(
  675. self._schema.name, self._schema.overload_name
  676. )
  677. # it's a no-op since OpOverload object is immutable and must be unique for a given op overload.
  678. def __deepcopy__(self, memo=None):
  679. return self
  680. def __repr__(self):
  681. return f"<OpOverload(op='{self._namespace}.{self._opname}', overload='{self._overloadname}')>"
  682. # Use positional-only argument to avoid naming collision with aten ops arguments
  683. # that are named "self". This way, all the aten ops can be called by kwargs.
  684. def __call__(self, /, *args: _P.args, **kwargs: _P.kwargs) -> _T:
  685. return self._op(*args, **kwargs)
  686. # Use positional-only argument to avoid naming collision with aten ops arguments
  687. # that are named "self". This way, all the aten ops can be called by kwargs.
  688. def redispatch(
  689. self, /, keyset: torch._C.DispatchKeySet, *args: _P.args, **kwargs: _P.kwargs
  690. ) -> _T:
  691. return self._handle.redispatch_boxed(keyset, *args, **kwargs) # type: ignore[return-value]
  692. def __hash__(self):
  693. return hash(self._op)
  694. # `my_namespace.my_op_name.overload_name`
  695. def __str__(self):
  696. return "{}.{}.{}".format(*self._schema.name.split("::"), self._overloadname)
  697. def has_kernel_for_dispatch_key(self, k: DispatchKey) -> bool:
  698. return super().has_kernel_for_dispatch_key(
  699. k
  700. ) or torch._C._dispatch_has_kernel_for_dispatch_key(self.name(), k)
  701. def has_kernel_for_any_dispatch_key(self, ks: torch._C.DispatchKeySet) -> bool:
  702. return torch._C._dispatch_has_kernel_for_any_dispatch_key(
  703. self.name(), ks
  704. ) or super().has_kernel_for_any_dispatch_key(ks)
  705. @property
  706. def namespace(self) -> str:
  707. return self._namespace
  708. def _can_decompose(self) -> bool:
  709. dk = DispatchKey.CompositeImplicitAutograd
  710. return dk in self.py_kernels or torch._C._dispatch_has_kernel_for_dispatch_key(
  711. self.name(), dk
  712. )
  713. def decompose(self, *args: _P.args, **kwargs: _P.kwargs) -> _T:
  714. dk = DispatchKey.CompositeImplicitAutograd
  715. if dk in self.py_kernels:
  716. # NB: This branch is not too necessary anymore, because we can
  717. # apply Python CompositeImplicitAutograd *before* tracing
  718. # using Python dispatcher (also taking advantage of the autograd
  719. # formula). But it's included for completeness
  720. return self.py_kernels[dk](*args, **kwargs)
  721. elif torch._C._dispatch_has_kernel_for_dispatch_key(self.name(), dk):
  722. return self._op_dk(dk, *args, **kwargs)
  723. else:
  724. return NotImplemented # pyrefly: ignore [bad-return]
  725. # Remove a dispatch key from the dispatch cache. This will force it to get
  726. # recomputed the next time. Does nothing
  727. # WARNING: if you register a dispatch key to py_kernels of an OpOverload,
  728. # calling _del_dispatch on that key is NOT sufficient to apply your change,
  729. # because a single registration may affect MULTIPLE dispatch keys (e.g.,
  730. # registering Autograd affects AutogradCPU). del_dispatch is to be used
  731. # only if you are specifically modifying how get_dispatch handles a
  732. # particular input 'key'.
  733. def _uncache_dispatch(self, key: DispatchKey) -> None:
  734. self._dispatch_cache.pop(key, None)
  735. # This implements the pre-computation logic for the Python dispatcher.
  736. def _get_dispatch(self, key: DispatchKey) -> DispatchKey | Callable[_P, _T]:
  737. # This is only called upon a cache miss
  738. assert key not in self._dispatch_cache, f"{self} {key}"
  739. if key == DispatchKey.Python:
  740. if not isinstance(self, TorchBindOpOverload) and not self.python_key_table:
  741. self._dispatch_cache[key] = key
  742. add_cached_op(self)
  743. return key
  744. def handler(*args: _P.args, **kwargs: _P.kwargs) -> _T:
  745. from torch.utils._python_dispatch import _get_current_dispatch_mode
  746. # TODO: We also need to handle tensor subclasses here
  747. # TODO(voz): We should walk all the nodes here / turn it into a list, topmode is ok for now.
  748. curr_mode = type(_get_current_dispatch_mode())
  749. assert curr_mode is not None, (
  750. "Illegal invocation of dispatch on DispatchKey.Python without a mode."
  751. )
  752. if curr_mode not in self.python_key_table:
  753. if isinstance(self, TorchBindOpOverload):
  754. with (
  755. torch.utils._python_dispatch._pop_mode_temporarily() as mode
  756. ):
  757. return torch._library.utils.handle_dispatch_mode(
  758. mode, self, *args, **kwargs
  759. )
  760. else:
  761. return self._op_dk(key, *args, **kwargs)
  762. with torch.utils._python_dispatch._pop_mode_temporarily() as mode:
  763. return self.python_key_table[curr_mode](mode, *args, **kwargs) # type: ignore[index]
  764. self._dispatch_cache[key] = handler
  765. add_cached_op(self)
  766. return handler
  767. functionality_key = torch._C._to_functionality_key(key) # type: ignore[attr-defined]
  768. if functionality_key == DispatchKey.PreDispatch:
  769. curr_stack_len = _len_torch_dispatch_stack_pre_dispatch()
  770. # The check for Python in the exclude set is so we properly respect `with no_dispatch()`
  771. # calls inside of a mode.
  772. if (
  773. curr_stack_len > 0
  774. and not torch._C._dispatch_tls_is_dispatch_key_excluded(
  775. DispatchKey.Python
  776. )
  777. ):
  778. def handler(*args: _P.args, **kwargs: _P.kwargs) -> _T:
  779. @contextlib.contextmanager
  780. def _temporarily_pop_modes_from_pre_dispatch():
  781. top_mode = _pop_mode_from_pre_dispatch()
  782. try:
  783. yield top_mode
  784. finally:
  785. _set_mode_pre_dispatch(top_mode)
  786. with _temporarily_pop_modes_from_pre_dispatch() as curr_mode:
  787. return torch._library.utils.handle_dispatch_mode(
  788. curr_mode, self, *args, **kwargs
  789. )
  790. # Note [Not Caching Per-Dispatch-Key Mode Handlers]
  791. # Note that we're not caching this handler. There isn't really a point, since the slow bit
  792. # is the handler itself (in python).
  793. # Also, not caching means that we don't have to reset the cache when any existing
  794. # modes go out of scope (which in of itself takes time to loop through all operators).
  795. return handler
  796. final_key = resolve_key(self, key)
  797. # See Note [Not Caching Per-Dispatch-Key Mode Handlers]
  798. cache_result = key != DispatchKey.PreDispatch
  799. # TODO: We could potentially have lots of debugging wrappers against
  800. # dispatch keys; design some general registration mechanism instead of
  801. # having if statement for each of them
  802. if key == DispatchKey.Functionalize:
  803. import torch._dispatch.python as pydispatch
  804. if pydispatch.CROSSREF_FUNCTIONALIZE:
  805. handler = pydispatch.make_crossref_functionalize(self, final_key) # type: ignore[assignment]
  806. if cache_result:
  807. self._dispatch_cache[key] = handler
  808. add_cached_op(self)
  809. return handler
  810. r = self.py_kernels.get(final_key, final_key)
  811. if cache_result:
  812. self._dispatch_cache[key] = r # pyrefly: ignore [unsupported-operation]
  813. add_cached_op(self)
  814. return r # pyrefly: ignore [bad-return]
  815. def name(self):
  816. return self._name
  817. @property
  818. def overloadpacket(self):
  819. return self._overloadpacket
  820. @property
  821. def op(self):
  822. return self._op
  823. @property
  824. def tags(self):
  825. return self._tags
  826. # TODO: add more methods to expose information about input and output arguments
  827. # TorchBindOpOverload are those custom ops which have at least one overload's
  828. # schema consists of torch.ScriptObject (i.e. custom class) input.
  829. # TorchBindOpOverload will skip C++ dispatcher and purely dispatched in python
  830. # when its inputs contain FakeScriptObject in a similar way as higher order ops.
  831. class TorchBindOpOverload(OpOverload[_P, _T]):
  832. def _fallthrough_keys(self) -> list[DispatchKey]:
  833. # TODO: we should be calling the fallback for these, but a fallthrough is almost close
  834. # enough to the fallback in most cases that we care about.
  835. _DEFAULT_FALLTHROUGH_KEYS = [
  836. DispatchKey.Autograd,
  837. DispatchKey.AutogradCPU,
  838. DispatchKey.AutogradCUDA,
  839. DispatchKey.ADInplaceOrView,
  840. DispatchKey.BackendSelect,
  841. DispatchKey.PythonTLSSnapshot,
  842. DispatchKey.PythonDispatcher,
  843. DispatchKey.Functionalize,
  844. ]
  845. def _may_use_fallthrough_instead_of_fallback(key: DispatchKey):
  846. if torch._C._dispatch_has_kernel_for_dispatch_key(self.name(), key):
  847. return torch._C._dispatch_kernel_for_dispatch_key_is_fallthrough(
  848. self.name(), key
  849. )
  850. return (
  851. key not in self.py_kernels
  852. or self.py_kernels[key] is torch.library.fallthrough_kernel
  853. )
  854. return [
  855. key
  856. for key in _DEFAULT_FALLTHROUGH_KEYS
  857. if _may_use_fallthrough_instead_of_fallback(key)
  858. ]
  859. # Use positional-only argument to avoid naming collision with aten ops arguments
  860. # that are named "self". This way, all the aten ops can be called by kwargs.
  861. def __call__(self, /, *args: _P.args, **kwargs: _P.kwargs) -> _T:
  862. if _must_dispatch_in_python(args, kwargs):
  863. # When any inputs are FakeScriptObject, we need to
  864. # skip c++ dispatcher and dispatch in python through _get_dispatch of python_dispatcher
  865. # because C++ dispatcher will check the schema and cannot recognize FakeScriptObject.
  866. return self._dispatch_in_python(self._fallthrough_keys(), *args, **kwargs)
  867. return self._op(*args, **kwargs)
  868. def _dispatch_in_python(
  869. self, fallthrough_keys: list[DispatchKey], *args: _P.args, **kwargs: _P.kwargs
  870. ) -> _T:
  871. non_fallthrough_keys = torch._C._dispatch_keyset_full()
  872. for key in fallthrough_keys:
  873. non_fallthrough_keys = non_fallthrough_keys.remove(key)
  874. dispatch_key_set = _compute_keyset(args, kwargs, non_fallthrough_keys)
  875. dispatch_key = dispatch_key_set.highestPriorityTypeId()
  876. handler = (
  877. self._get_dispatch(dispatch_key)
  878. if dispatch_key not in self._dispatch_cache
  879. else self._dispatch_cache[dispatch_key]
  880. )
  881. if isinstance(handler, DispatchKey):
  882. # fallthrough keys can be registered at runtime via torch.library.impl
  883. # so need to add it to fallthrough_keys and re-dispatch.
  884. if torch._C._dispatch_kernel_for_dispatch_key_is_fallthrough(
  885. self.name(), dispatch_key
  886. ):
  887. return self._dispatch_in_python(
  888. fallthrough_keys + [dispatch_key],
  889. *args,
  890. **kwargs,
  891. )
  892. raise RuntimeError(
  893. f"Torchbind op {self} received a FakeScriptObject input when dispatching {handler}."
  894. f" but no python implementation is found."
  895. f" Please file an issue on this when you encounter this error."
  896. f" This error can happen when you export or compile the model."
  897. f" It can still happen even if a C++ implementation for {dispatch_key}. "
  898. f" has been registered. That's because FakeScriptObject purely lives in python and cannot work "
  899. f" with a C++ implementation."
  900. )
  901. assert isinstance(handler, Callable) # type: ignore[arg-type]
  902. return handler(*args, **kwargs) # pyrefly: ignore [bad-return]
  903. def _must_dispatch_in_python(args, kwargs):
  904. return pytree.tree_any(
  905. lambda obj: isinstance(
  906. obj, torch._library.fake_class_registry.FakeScriptObject
  907. ),
  908. (args, kwargs),
  909. )
  910. def _has_script_object_arg(schema: torch.FunctionSchema) -> bool:
  911. return any(isinstance(arg.type, torch.ClassType) for arg in schema.arguments)
  912. # OpOverloadPacket class contains pointer to a base unresolved operator that doesn't correspond to a specific operator
  913. # You can obtain an OpOverload object through attribute query.
  914. class OpOverloadPacket(Generic[_P, _T]):
  915. __file__: ClassVar[str] = "torch.ops"
  916. def __init__(
  917. self,
  918. qualified_op_name: str,
  919. op_name: str,
  920. op: Callable[_P, _T],
  921. overload_names: list[str],
  922. ) -> None:
  923. # These attributes are accessible on the object through the properties
  924. # defined below but are immutable
  925. self._qualified_op_name = qualified_op_name
  926. self.__name__ = op_name
  927. self._op = op
  928. self._overload_names = overload_names
  929. self._dir: list[str] = []
  930. self._has_torchbind_op_overload = any(
  931. _has_script_object_arg(schema) for schema in self._schemas.values()
  932. )
  933. # it's a no-op since OpOverloadPacket object is immutable and must be unique for a given op.
  934. def __deepcopy__(self, memo=None):
  935. return self
  936. def __repr__(self):
  937. return "<OpOverloadPacket(op='{}.{}')>".format(
  938. *self._qualified_op_name.split("::")
  939. )
  940. def __hash__(self):
  941. return hash(self._op)
  942. def __str__(self):
  943. return "{}.{}".format(*self._qualified_op_name.split("::"))
  944. @property
  945. def op(self):
  946. return self._op
  947. @property
  948. def _schemas(self):
  949. return {
  950. overload_name: torch._C._get_schema(self._qualified_op_name, overload_name)
  951. for overload_name in self._overload_names
  952. }
  953. def __getattr__(self, key: str) -> OpOverload[_P, _T]:
  954. # ensure that query for dunder attributes that does not exist on
  955. # opoverloadpacket but instead exists on the self._op object does not unnecessarily call
  956. # `_get_operation_overload` (which is an expensive operation).
  957. # This is done to prevent any potential slowdown. This list can be extended
  958. # if there exists other attributes like `__name__` that only exist on self._op and not on the
  959. # opoverloadpacket.
  960. # This is ok since we are guaranteed that an overload name for an aten op can't start with '__'
  961. try:
  962. if key.startswith("__"):
  963. return getattr(self._op, key)
  964. except AttributeError:
  965. # for consistency because it seems weird to
  966. # throw an attribute error with a message containing
  967. # an object name different from the one the attribute
  968. # query was performed on.
  969. raise AttributeError(
  970. f"'{str(self)}' can't have an overload name beginning with '__' and the "
  971. f"underlying op {str(self._op)} has no attribute {key} either."
  972. ) from None
  973. try:
  974. # This is ok since we are guaranteed that an overload name for an aten op can't be 'default'
  975. use_key = "" if key == "default" else key
  976. # TODO: disallow access to overloads registered by JIT
  977. op_dk_tags = torch._C._get_operation_overload(
  978. self._qualified_op_name, use_key
  979. )
  980. if op_dk_tags is None:
  981. raise AttributeError(
  982. f"The underlying op of '{str(self)}' has no overload name '{key}'"
  983. )
  984. op_, op_dk_, tags = op_dk_tags
  985. schema = torch._C._get_schema(self._qualified_op_name, use_key)
  986. overload: OpOverload[_P, _T] = (
  987. OpOverload(self, op_, op_dk_, schema, tags)
  988. if not _has_script_object_arg(schema)
  989. else TorchBindOpOverload(self, op_, op_dk_, schema, tags)
  990. )
  991. # cache the overload object
  992. setattr(self, key, overload)
  993. self._dir.append(key)
  994. return overload
  995. except RuntimeError:
  996. raise AttributeError(
  997. f"The underlying op of '{str(self)}' has no overload name '{key}'"
  998. ) from None
  999. def __iter__(self) -> Iterator[str]:
  1000. return iter(self._dir)
  1001. # Use positional-only argument to avoid naming collision with aten ops arguments
  1002. # that are named "self". This way, all the aten ops can be called by kwargs.
  1003. def __call__(self, /, *args: _P.args, **kwargs: _P.kwargs) -> _T:
  1004. # overloading __call__ to ensure torch.ops.foo.bar()
  1005. # is still callable from JIT
  1006. # We save the function ptr as the `op` attribute on
  1007. # OpOverloadPacket to access it here.
  1008. # Directly calling OverloadPacket goes into C++, which will check
  1009. # the schema and cause an error for torchbind op when inputs consist of FakeScriptObject so we
  1010. # intercept it here and call TorchBindOpverload instead.
  1011. if self._has_torchbind_op_overload and _must_dispatch_in_python(args, kwargs):
  1012. # pyrefly: ignore [bad-argument-type]
  1013. return _call_overload_packet_from_python(self, *args, **kwargs)
  1014. return self._op(*args, **kwargs)
  1015. # TODO: use this to make a __dir__
  1016. def overloads(self):
  1017. return [n if n else "default" for n in self._overload_names]
  1018. # Note - this mirrors the logic of the cpp_function defined in jit/python/init.cpp
  1019. # _jit_get_operations, which calls _get_operation_for_overload_or_packet.
  1020. def _call_overload_packet_from_python(
  1021. op: OpOverloadPacket[_P, _T], *args: _P.args, **kwargs: _P.kwargs
  1022. ) -> _T:
  1023. # Reuse the torch function handling logic in cpp
  1024. torch_function_called, ret = torch._C._maybe_call_torch_function_for_op_packet(
  1025. op, *args, **kwargs
  1026. )
  1027. if torch_function_called:
  1028. return ret
  1029. # The following mirrors getOpWithStack.
  1030. # In cpp, we do a schema matching for the arguments, and call ToIValue to
  1031. # to check whether the arguments are valid. But need to do similar things here
  1032. # and check the schema whether the FakeScriptObject is the corresponding fake class
  1033. # of the actual class used in schema.
  1034. exceptions = {}
  1035. found_op = None
  1036. for overload_name in op.overloads():
  1037. op_overload = getattr(op, overload_name)
  1038. try:
  1039. _ = torch._C._check_schema_allow_fake_script_object(
  1040. op_overload._schema, *args, **kwargs
  1041. )
  1042. found_op = op_overload
  1043. break
  1044. except RuntimeError as e:
  1045. exceptions[overload_name] = e
  1046. if found_op:
  1047. return found_op(*args, **kwargs)
  1048. err_msg = (
  1049. f"Fail to match any TorchBindOverload of {op} with following exceptions:\n"
  1050. )
  1051. for key, msg in exceptions.items():
  1052. err_msg += f"Overload name {key}:\n {msg}\n"
  1053. raise RuntimeError(err_msg)
  1054. # Resolution of torch.fn is different from torch.ops.aten.fn
  1055. # torch.fn uses the Python argparser, matches with the
  1056. # appropriate schema, and calls into the unboxed version of the method
  1057. # torch.ops.aten.fn resolution is done via the mechanism defined in JIT.
  1058. # JIT creates a stack of all the overloads and then tries to match the
  1059. # correct one at runtime and always calls into the boxed version of the method
  1060. # Autograd codegen creates VariableType, TracerType,
  1061. # inplace or view type and python bindings.
  1062. # Aten codegen generates tensor methods for the tensor class.
  1063. # _OpNamespace is a subclass of ModuleType because the torch script
  1064. # allows attribute lookups on modules only. Since we want torch.ops.foo.bar()
  1065. # to work from script, we need to ensure ops and foo are modules
  1066. class _OpNamespace(types.ModuleType):
  1067. """
  1068. An op namespace to dynamically bind Operators into Python.
  1069. Say a user has created a custom Operator called "my_namespace::my_op". To
  1070. call this op, the user will write torch.ops.my_namespace.my_op(...).
  1071. At startup, this operation will not yet be bound into Python. Instead, the
  1072. following sequence of magic tricks will occur:
  1073. 1. `torch.ops.my_namespace` will invoke the `__getattr__` magic method
  1074. on the `torch.ops` object, which will create a new `_OpNamespace`
  1075. object called `my_namespace` and set it as an attribute on the `ops`
  1076. object.
  1077. 2. `torch.ops.my_namespace.my_op` will then invoke `__getattr__` on
  1078. the `my_namespace` object, which will retrieve the operation via
  1079. `torch.get_operation`, a function bound from C++, and then in a similar
  1080. fashion bind this new object onto the `my_namespace` object.
  1081. 3. `torch.ops.my_namespace.my_op(...)` then calls this new operation
  1082. and subsequent accesses will incur no further lookup (the namespace and
  1083. operation will already exist).
  1084. """
  1085. __file__ = "torch.ops"
  1086. def __init__(self, name: str) -> None:
  1087. super().__init__("torch.ops." + name)
  1088. self.name = name
  1089. self._dir: list[str] = []
  1090. def __iter__(self) -> Iterator[str]:
  1091. return iter(self._dir)
  1092. def __getattr__(self, op_name: str) -> OpOverloadPacket:
  1093. if op_name in ("__origin__", "__self__"):
  1094. raise AttributeError(
  1095. f"Invalid attribute '{op_name}' for '_OpNamespace' '{self.name}'"
  1096. )
  1097. # Get the op `my_namespace::my_op` if available. This will also check
  1098. # for overloads and raise an exception if there are more than one.
  1099. namespace_name = self.name
  1100. qualified_op_name = f"{namespace_name}::{op_name}"
  1101. module_name = self.__module__ + "." + namespace_name
  1102. try:
  1103. op, overload_names = _get_packet(qualified_op_name, module_name)
  1104. if op is None:
  1105. raise AttributeError(
  1106. f"'_OpNamespace' '{self.name}' object has no attribute '{op_name}'"
  1107. )
  1108. except RuntimeError as e:
  1109. # Turn this into AttributeError so getattr(obj, key, default)
  1110. # works (this is called by TorchScript with __origin__)
  1111. raise AttributeError(
  1112. f"'_OpNamespace' '{self.name}' object has no attribute '{op_name}'"
  1113. ) from e
  1114. op.__module__ = module_name
  1115. opoverloadpacket = OpOverloadPacket(
  1116. qualified_op_name, op_name, op, overload_names
  1117. )
  1118. opoverloadpacket.__module__ = self.__module__ + "." + namespace_name
  1119. # cache the opoverloadpacket to ensure that each op corresponds to
  1120. # a unique OpOverloadPacket object
  1121. setattr(self, op_name, opoverloadpacket)
  1122. self._dir.append(op_name)
  1123. return opoverloadpacket
  1124. def _get_packet(qualname, op_module):
  1125. op, overload_names = torch._C._jit_get_operation(qualname)
  1126. if op is not None:
  1127. # let the script frontend know that op is identical to the builtin op
  1128. # with qualified_op_name
  1129. torch.jit._builtins._register_builtin(op, qualname)
  1130. op.__module__ = op_module
  1131. return op, overload_names
  1132. def _refresh_packet(packet):
  1133. op, overload_names = _get_packet(packet._qualified_op_name, packet._op.__module__)
  1134. assert op is not None
  1135. packet._op = op
  1136. packet._overload_names = overload_names
  1137. class _HigherOrderNamespace(types.ModuleType):
  1138. __file__ = "torch.ops"
  1139. def __init__(self) -> None:
  1140. super().__init__("torch.ops.higher_order")
  1141. self._dir: list[str] = []
  1142. def __iter__(self) -> Iterator[str]:
  1143. return iter(self._dir)
  1144. def __getattr__(self, name: str) -> HigherOrderOperator:
  1145. # Following _OpNamespace.__getattr__, we cache the op on this object.
  1146. op = _higher_order_ops.get(name)
  1147. if op is None:
  1148. raise AttributeError(
  1149. f"'_HigherOrderNamespace' 'torch.ops.higher_order' object has no attribute '{name}'"
  1150. )
  1151. setattr(self, name, op)
  1152. self._dir.append(name)
  1153. return op
  1154. class _Ops(types.ModuleType):
  1155. __file__ = "_ops.py"
  1156. def __init__(self):
  1157. super().__init__("torch.ops")
  1158. self.loaded_libraries = set()
  1159. self.higher_order = _HigherOrderNamespace()
  1160. self._dir = []
  1161. def __getattr__(self, name: str) -> _OpNamespace:
  1162. # Here we are creating `torch.ops.my_namespace`
  1163. namespace = _OpNamespace(name)
  1164. setattr(self, name, namespace)
  1165. self._dir.append(name)
  1166. return namespace
  1167. def __iter__(self) -> Iterator[str]:
  1168. return iter(self._dir)
  1169. def import_module(self, module):
  1170. """
  1171. Imports a Python module that has torch.library registrations.
  1172. Generally, to extend PyTorch with custom operators, a user will
  1173. create a Python module whose import triggers registration of
  1174. the custom operators via a torch.ops.load_library call or a call
  1175. to one or more torch.library.* APIs.
  1176. It is unexpected for Python modules to have side effects, so some
  1177. linters and formatters will complain. Use this API to import Python
  1178. modules that contain these torch.library side effects.
  1179. Args:
  1180. module (str): The name of the Python module to import
  1181. """
  1182. importlib.import_module(module)
  1183. def load_library(self, path):
  1184. """
  1185. Loads a shared library from the given path into the current process.
  1186. The library being loaded may run global initialization code to register
  1187. custom operators with the PyTorch JIT runtime. This allows dynamically
  1188. loading custom operators. For this, you should compile your operator
  1189. and the static registration code into a shared library object, and then
  1190. call ``torch.ops.load_library('path/to/libcustom.so')`` to load the
  1191. shared object.
  1192. After the library is loaded, it is added to the
  1193. ``torch.ops.loaded_libraries`` attribute, a set that may be inspected
  1194. for the paths of all libraries loaded using this function.
  1195. Args:
  1196. path (str): A path to a shared library to load.
  1197. """
  1198. path = _utils_internal.resolve_library_path(path)
  1199. with dl_open_guard():
  1200. # Import the shared library into the process, thus running its
  1201. # static (global) initialization code in order to register custom
  1202. # operators with the JIT.
  1203. try:
  1204. ctypes.CDLL(path)
  1205. except Exception as e:
  1206. raise OSError(f"Could not load this library: {path}") from e
  1207. self.loaded_libraries.add(path)
  1208. # The ops "namespace"
  1209. ops: _Ops = _Ops()