overrides.py 103 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156115711581159116011611162116311641165116611671168116911701171117211731174117511761177117811791180118111821183118411851186118711881189119011911192119311941195119611971198119912001201120212031204120512061207120812091210121112121213121412151216121712181219122012211222122312241225122612271228122912301231123212331234123512361237123812391240124112421243124412451246124712481249125012511252125312541255125612571258125912601261126212631264126512661267126812691270127112721273127412751276127
71278127912801281128212831284128512861287128812891290129112921293129412951296129712981299130013011302130313041305130613071308130913101311131213131314131513161317131813191320132113221323132413251326132713281329133013311332133313341335133613371338133913401341134213431344134513461347134813491350135113521353135413551356135713581359136013611362136313641365136613671368136913701371137213731374137513761377137813791380138113821383138413851386138713881389139013911392139313941395139613971398139914001401140214031404140514061407140814091410141114121413141414151416141714181419142014211422142314241425142614271428142914301431143214331434143514361437143814391440144114421443144414451446144714481449145014511452145314541455145614571458145914601461146214631464146514661467146814691470147114721473147414751476147714781479148014811482148314841485148614871488148914901491149214931494149514961497149814991500150115021503150415051506150715081509151015111512151315141515151615171518151915201521152215231524152515261527152815291530153115321533153415351536153715381539154015411542154315441545154615471548154915501551155215531554155515561557155815591560156115621563156415651566156715681569157015711572157315741575157615771578157915801581158215831584158515861587158815891590159115921593159415951596159715981599160016011602160316041605160616071608160916101611161216131614161516161617161816191620162116221623162416251626162716281629163016311632163316341635163616371638163916401641164216431644164516461647164816491650165116521653165416551656165716581659166016611662166316641665166616671668166916701671167216731674167516761677167816791680168116821683168416851686168716881689169016911692169316941695169616971698169917001701170217031704170517061707170817091710171117121713171417151716171717181719172017211722172317241725172617271728172917301731173217331734173517361737173817391740174117421743174417451746174717481749175017511752175317541755175617571758175917601761176217631764176517661767176817691770177117721773177417751776177
71778177917801781178217831784178517861787178817891790179117921793179417951796179717981799180018011802180318041805180618071808180918101811181218131814181518161817181818191820182118221823182418251826182718281829183018311832183318341835183618371838183918401841184218431844184518461847184818491850185118521853185418551856185718581859186018611862186318641865186618671868186918701871187218731874187518761877187818791880188118821883188418851886188718881889189018911892189318941895189618971898189919001901190219031904190519061907190819091910191119121913191419151916191719181919192019211922192319241925192619271928192919301931193219331934193519361937193819391940194119421943194419451946194719481949195019511952195319541955195619571958195919601961196219631964196519661967196819691970197119721973197419751976197719781979198019811982198319841985198619871988198919901991199219931994199519961997199819992000200120022003200420052006200720082009201020112012201320142015201620172018201920202021202220232024202520262027202820292030203120322033203420352036203720382039204020412042204320442045204620472048204920502051205220532054205520562057205820592060206120622063206420652066206720682069207020712072207320742075207620772078207920802081208220832084208520862087208820892090209120922093209420952096209720982099210021012102210321042105210621072108210921102111211221132114211521162117211821192120
  1. """
  2. Python implementation of ``__torch_function__``
  3. While most of the torch API and handling for ``__torch_function__`` happens
  4. at the C++ level, some of the torch API is written in Python so we need
  5. python-level handling for ``__torch_function__`` overrides as well. The main
  6. developer-facing functions in this file are ``handle_torch_function`` and
  7. ``has_torch_function``. See ``torch/functional.py`` and ``test/test_overrides.py``
  8. for usage examples.
  9. Note
  10. ----
  11. This module is heavily inspired by NumPy's ``__array_function__`` (see:
  12. https://github.com/pytorch/pytorch/issues/24015 and
  13. https://www.numpy.org/neps/nep-0018-array-function-protocol.html
  14. )
  15. If changing this file in a way that can affect ``__torch_function__`` overhead,
  16. please report the benchmarks in ``benchmarks/overrides_benchmark``. See the
  17. instructions in the ``README.md`` in that directory.
  18. """
  19. import __future__ # noqa: F404
  20. import collections
  21. import contextlib
  22. import functools
  23. import types
  24. import warnings
  25. from collections.abc import Iterable
  26. from functools import wraps
  27. from typing import Any, Callable, Optional, TypeVar
  28. from typing_extensions import ParamSpec
  29. import torch
  30. from torch._C import (
  31. _add_docstr,
  32. _get_function_stack_at,
  33. _has_torch_function,
  34. _has_torch_function_unary,
  35. _has_torch_function_variadic,
  36. _is_torch_function_mode_enabled,
  37. _len_torch_function_stack,
  38. _pop_torch_function_stack,
  39. _push_on_torch_function_stack,
  40. )
  41. __all__ = [
  42. "get_ignored_functions",
  43. "get_overridable_functions",
  44. "get_testing_overrides",
  45. "handle_torch_function",
  46. "has_torch_function",
  47. "resolve_name",
  48. "is_tensor_like",
  49. "is_tensor_method_or_property",
  50. "wrap_torch_function",
  51. "enable_reentrant_dispatch",
  52. ]
  53. _P = ParamSpec("_P")
  54. _R = TypeVar("_R")
  55. def _disable_user_warnings(
  56. func: Callable[_P, _R],
  57. regex: str = ".*is deprecated, please use.*",
  58. module: str = "torch",
  59. ) -> Callable[_P, _R]:
  60. """
  61. Decorator that temporarily disables ``UserWarning``s for the given ``module`` if the warning message matches the
  62. given ``regex`` pattern.
  63. Arguments
  64. ---------
  65. func : function
  66. Function to disable the warnings for.
  67. regex : str
  68. A regex pattern compilable by ``re.compile``. This is used to match the ``UserWarning`` message.
  69. module : str
  70. The python module to which the filtering should be restricted.
  71. Returns
  72. -------
  73. function
  74. The wrapped function.
  75. """
  76. @wraps(func)
  77. def wrapper(*args: _P.args, **kwargs: _P.kwargs) -> _R:
  78. with warnings.catch_warnings():
  79. warnings.filterwarnings(
  80. "ignore", category=UserWarning, message=regex, module=module
  81. )
  82. return func(*args, **kwargs)
  83. return wrapper
  84. @functools.cache
  85. @_disable_user_warnings
  86. def get_ignored_functions() -> set[Callable]:
  87. """
  88. Return public functions that cannot be overridden by ``__torch_function__``.
  89. Returns
  90. -------
  91. set[Callable]
  92. A tuple of functions that are publicly available in the torch API but cannot
  93. be overridden with ``__torch_function__``. Mostly this is because none of the
  94. arguments of these functions are tensors or tensor-likes.
  95. Examples
  96. --------
  97. >>> torch.Tensor.as_subclass in torch.overrides.get_ignored_functions()
  98. True
  99. >>> torch.add in torch.overrides.get_ignored_functions()
  100. False
  101. """
  102. Tensor = torch.Tensor
  103. return {
  104. torch.typename,
  105. torch.is_tensor,
  106. torch.is_storage,
  107. torch.set_default_tensor_type,
  108. torch.set_default_device,
  109. torch.get_default_device,
  110. torch.set_rng_state,
  111. torch.get_rng_state,
  112. torch.manual_seed,
  113. torch.initial_seed,
  114. torch.seed,
  115. torch.save,
  116. torch.load,
  117. torch.set_printoptions,
  118. torch.fork,
  119. torch.get_default_dtype,
  120. torch.get_num_interop_threads,
  121. torch.get_num_threads,
  122. torch.init_num_threads,
  123. torch.import_ir_module,
  124. torch.import_ir_module_from_buffer,
  125. torch.is_anomaly_enabled,
  126. torch.is_anomaly_check_nan_enabled,
  127. torch.is_grad_enabled,
  128. torch.merge_type_from_type_comment,
  129. torch.parse_ir,
  130. torch.parse_schema,
  131. torch.parse_type_comment,
  132. torch.set_anomaly_enabled,
  133. torch.set_flush_denormal,
  134. torch.set_num_interop_threads,
  135. torch.set_num_threads,
  136. torch.wait,
  137. torch.as_tensor,
  138. torch.from_numpy,
  139. torch.tensor,
  140. torch.default_generator,
  141. torch.has_cuda,
  142. torch.has_cudnn,
  143. torch.has_lapack,
  144. torch.device,
  145. torch.dtype,
  146. torch.finfo,
  147. torch.has_mkl,
  148. torch.has_mps,
  149. torch.has_mkldnn,
  150. torch.has_openmp,
  151. torch.iinfo,
  152. torch.memory_format,
  153. torch.qscheme,
  154. torch.set_grad_enabled,
  155. torch.no_grad,
  156. torch.enable_grad,
  157. torch.inference_mode,
  158. torch.is_inference_mode_enabled,
  159. torch.layout,
  160. torch.align_tensors,
  161. torch.arange,
  162. torch.as_strided,
  163. torch.bartlett_window,
  164. torch.blackman_window,
  165. torch.broadcast_shapes,
  166. torch.can_cast,
  167. torch.compile,
  168. torch.cudnn_affine_grid_generator,
  169. torch.cudnn_batch_norm,
  170. torch.cudnn_convolution,
  171. torch.cudnn_convolution_transpose,
  172. torch.cudnn_convolution_relu,
  173. torch.cudnn_convolution_add_relu,
  174. torch.cudnn_grid_sampler,
  175. torch.cudnn_is_acceptable,
  176. torch.empty,
  177. torch.empty_permuted,
  178. torch.empty_strided,
  179. torch.empty_quantized,
  180. torch.export.export,
  181. torch.export.load,
  182. torch.export.register_dataclass,
  183. torch.export.save,
  184. torch.eye,
  185. torch.fft.fftfreq,
  186. torch.fft.rfftfreq,
  187. torch.from_file,
  188. torch.full,
  189. torch.fill,
  190. torch.hamming_window,
  191. torch.hann_window,
  192. torch.kaiser_window,
  193. torch.linspace,
  194. torch.logspace,
  195. torch.mkldnn_adaptive_avg_pool2d,
  196. torch.mkldnn_convolution,
  197. torch.mkldnn_max_pool2d,
  198. torch.mkldnn_max_pool3d,
  199. torch.mkldnn_linear_backward_weights,
  200. torch.mkldnn_rnn_layer,
  201. torch.normal,
  202. torch.ones,
  203. torch.promote_types,
  204. torch.rand,
  205. torch.randn,
  206. torch.randint,
  207. torch.randperm,
  208. torch.range,
  209. torch.result_type,
  210. torch.scalar_tensor,
  211. torch.sparse_coo_tensor,
  212. torch.sparse_compressed_tensor,
  213. torch.sparse_csr_tensor,
  214. torch.sparse_csc_tensor,
  215. torch.sparse_bsr_tensor,
  216. torch.sparse_bsc_tensor,
  217. torch.sym_constrain_range,
  218. torch.sym_constrain_range_for_size,
  219. torch.sym_fresh_size,
  220. torch.tril_indices,
  221. torch.triu_indices,
  222. torch.vander,
  223. torch.zeros,
  224. torch._jit_internal.boolean_dispatch,
  225. torch.nn.functional.assert_int_or_pair,
  226. torch.nn.functional.upsample,
  227. torch.nn.functional.upsample_bilinear,
  228. torch.nn.functional.upsample_nearest,
  229. torch.nn.functional.has_torch_function,
  230. torch.nn.functional.has_torch_function_unary,
  231. torch.nn.functional.has_torch_function_variadic,
  232. torch.nn.functional.handle_torch_function,
  233. torch.nn.functional.sigmoid,
  234. torch.nn.functional.hardsigmoid,
  235. torch.nn.functional.tanh,
  236. torch.nn.functional._canonical_mask,
  237. torch.nn.functional._none_or_dtype,
  238. # Doesn't actually take or return tensor arguments
  239. torch.nn.init.calculate_gain,
  240. # These are deprecated; don't test them
  241. torch.nn.init.uniform,
  242. torch.nn.init.normal,
  243. torch.nn.init.constant,
  244. torch.nn.init.eye,
  245. torch.nn.init.dirac,
  246. torch.nn.init.xavier_uniform,
  247. torch.nn.init.xavier_normal,
  248. torch.nn.init.kaiming_uniform,
  249. torch.nn.init.kaiming_normal,
  250. torch.nn.init.orthogonal,
  251. torch.nn.init.sparse,
  252. torch.nested.to_padded_tensor,
  253. has_torch_function,
  254. handle_torch_function,
  255. torch.set_autocast_enabled,
  256. torch.is_autocast_enabled,
  257. torch.set_autocast_dtype,
  258. torch.get_autocast_dtype,
  259. torch.clear_autocast_cache,
  260. torch.set_autocast_cpu_enabled,
  261. torch.is_autocast_cpu_enabled,
  262. torch.set_autocast_xla_enabled,
  263. torch.is_autocast_xla_enabled,
  264. torch.set_autocast_ipu_enabled,
  265. torch.is_autocast_ipu_enabled,
  266. torch.set_autocast_cpu_dtype,
  267. torch.get_autocast_cpu_dtype,
  268. torch.set_autocast_ipu_dtype,
  269. torch.get_autocast_ipu_dtype,
  270. torch.get_autocast_gpu_dtype,
  271. torch.set_autocast_gpu_dtype,
  272. torch.get_autocast_xla_dtype,
  273. torch.set_autocast_xla_dtype,
  274. torch.autocast_increment_nesting,
  275. torch.autocast_decrement_nesting,
  276. torch.is_autocast_cache_enabled,
  277. torch.set_autocast_cache_enabled,
  278. torch.nn.functional.hardswish,
  279. torch.is_vulkan_available,
  280. torch.are_deterministic_algorithms_enabled,
  281. torch.use_deterministic_algorithms,
  282. torch.is_deterministic_algorithms_warn_only_enabled,
  283. torch.set_deterministic_debug_mode,
  284. torch.get_device_module,
  285. torch.get_deterministic_debug_mode,
  286. torch.set_float32_matmul_precision,
  287. torch.get_float32_matmul_precision,
  288. torch.unify_type_list,
  289. torch.is_warn_always_enabled,
  290. torch.set_warn_always,
  291. torch.vitals_enabled,
  292. torch.set_vital,
  293. torch.read_vitals,
  294. torch.vmap,
  295. torch.cond,
  296. torch.frombuffer,
  297. torch.asarray,
  298. torch._functional_sym_constrain_range,
  299. torch._make_dep_token,
  300. Tensor.__delitem__,
  301. Tensor.__dir__,
  302. Tensor.__getattribute__,
  303. Tensor.__init__,
  304. Tensor.__iter__,
  305. Tensor.__init_subclass__,
  306. Tensor.__delattr__,
  307. Tensor.__setattr__,
  308. Tensor.__torch_function__,
  309. Tensor.__torch_dispatch__,
  310. Tensor.__new__,
  311. Tensor.__class__,
  312. Tensor.__subclasshook__,
  313. Tensor.__hash__,
  314. Tensor.as_subclass,
  315. Tensor.eig,
  316. Tensor.lstsq,
  317. Tensor.reinforce,
  318. Tensor.new,
  319. Tensor.new_tensor,
  320. Tensor.new_empty,
  321. Tensor.new_empty_strided,
  322. Tensor.new_zeros,
  323. Tensor.new_ones,
  324. Tensor.new_full,
  325. Tensor._make_subclass,
  326. Tensor.solve,
  327. Tensor.symeig,
  328. Tensor.stride,
  329. Tensor.unflatten,
  330. Tensor.to_sparse_coo,
  331. Tensor.to_sparse_csr,
  332. Tensor.to_sparse_csc,
  333. Tensor.to_sparse_bsr,
  334. Tensor.to_sparse_bsc,
  335. Tensor._to_sparse,
  336. Tensor._to_sparse_csr,
  337. Tensor._to_sparse_csc,
  338. Tensor._to_sparse_bsr,
  339. Tensor._to_sparse_bsc,
  340. Tensor._typed_storage,
  341. Tensor._reduce_ex_internal,
  342. Tensor._fix_weakref,
  343. Tensor._view_func,
  344. Tensor._view_func_unsafe,
  345. Tensor._rev_view_func_unsafe,
  346. Tensor._make_dtensor,
  347. Tensor._make_wrapper_subclass,
  348. Tensor._python_dispatch.__get__,
  349. Tensor._has_symbolic_sizes_strides.__get__,
  350. Tensor._conj,
  351. Tensor._conj_physical,
  352. Tensor._lazy_clone,
  353. Tensor._neg_view,
  354. Tensor._is_zerotensor,
  355. Tensor._is_all_true,
  356. Tensor._is_any_true,
  357. Tensor._addmm_activation,
  358. Tensor.to_padded_tensor,
  359. Tensor._use_count,
  360. }
  361. @functools.cache
  362. def get_default_nowrap_functions() -> set[Callable]:
  363. """
  364. Return public functions that do not wrap in a subclass when invoked by
  365. the default ``Tensor.__torch_function__`` that preserves subclasses. Typically,
  366. these functions represent field accesses (i.e., retrieving a Tensor that
  367. is stored somewhere on the Tensor) as opposed to computation. Users of
  368. these functions expect object identity to be preserved over multiple accesses
  369. (e.g., ``a.grad is a.grad``) which cannot be upheld if we're wrapping on
  370. the fly every time (furthermore, the tensor stored here might already be
  371. the subclass, in which case wrapping really ought not to happen).
  372. Not ALL property accessors have this property; for example ``Tensor.T`` actually
  373. just creates a new transposed tensor on the fly, and so we SHOULD interpose on
  374. these calls (you need to check the implementation of the function to see if
  375. this is the case or not). Additionally, if a property accessor doesn't return a Tensor,
  376. it doesn't have to be on this list (though it is harmless if it is).
  377. """
  378. Tensor = torch.Tensor
  379. return {
  380. Tensor._base.__get__,
  381. Tensor.grad.__get__,
  382. Tensor._grad.__get__,
  383. }
  384. @functools.cache
  385. @_disable_user_warnings
  386. def get_testing_overrides() -> dict[Callable, Callable]:
  387. """Return a dict containing dummy overrides for all overridable functions
  388. Returns
  389. -------
  390. Dict[Callable, Callable]
  391. A dictionary that maps overridable functions in the PyTorch API to
  392. lambda functions that have the same signature as the real function
  393. and unconditionally return -1. These lambda functions are useful
  394. for testing API coverage for a type that defines ``__torch_function__``.
  395. Examples
  396. --------
  397. >>> import inspect
  398. >>> my_add = torch.overrides.get_testing_overrides()[torch.add]
  399. >>> inspect.signature(my_add)
  400. <Signature (input, other, out=None)>
  401. """
  402. # Every function in the PyTorchAPI that can be overridden needs an entry
  403. # in this dict.
  404. #
  405. # Optimally we would use inspect to get the function signature and define
  406. # the lambda function procedurally but that is blocked by generating
  407. # function signatures for native kernels that can be consumed by inspect.
  408. # See Issue #28233.
  409. Tensor = torch.Tensor
  410. ret: dict[Callable, Callable] = {
  411. torch.abs: lambda input, out=None: -1,
  412. torch.absolute: lambda input, out=None: -1,
  413. torch.adaptive_avg_pool1d: lambda input, output_size: -1,
  414. torch.adaptive_max_pool1d: lambda inputs, output_size: -1,
  415. torch.acos: lambda input, out=None: -1,
  416. torch.adjoint: lambda input: -1,
  417. torch.arccos: lambda input, out=None: -1,
  418. torch.acosh: lambda input, out=None: -1,
  419. torch.arccosh: lambda input, out=None: -1,
  420. torch.add: lambda input, other, out=None: -1,
  421. torch.addbmm: lambda input, batch1, batch2, alpha=1, beta=1, out=None: -1,
  422. torch.addcdiv: lambda input, tensor1, tensor2, value=1, out=None: -1,
  423. torch.addcmul: lambda input, tensor1, tensor2, value=1, out=None: -1,
  424. torch.addmm: lambda input, mat1, mat2, beta=1, alpha=1, out=None: -1,
  425. torch.addmv: lambda input, mat, vec, beta=1, alpha=1, out=None: -1,
  426. torch.addr: lambda input, vec1, vec2, beta=1, alpha=1, out=None: -1,
  427. torch.affine_grid_generator: lambda theta, size, align_corners: -1,
  428. torch.all: lambda input, dim=None: -1,
  429. torch.allclose: lambda input, other, trol=1e-05, atol=1e-08, equal_nan=False: -1,
  430. torch.alpha_dropout: lambda input, p, train, inplace=False: -1,
  431. torch.amax: lambda input, dim=None: -1,
  432. torch.amin: lambda input, dim=None: -1,
  433. torch.aminmax: lambda input, dim=None, keepdim=False, out=None: -1,
  434. torch.angle: lambda input, out=None: -1,
  435. torch.any: lambda input, dim=None, keepdim=False, out=None: -1,
  436. torch.argmax: lambda input: -1,
  437. torch.argmin: lambda input: -1,
  438. torch.argsort: lambda input, dim=None: -1,
  439. torch.asin: lambda input, out=None: -1,
  440. torch._assert_async: lambda input, msg: -1,
  441. torch.arcsin: lambda input, out=None: -1,
  442. torch.asinh: lambda input, out=None: -1,
  443. torch.arcsinh: lambda input, out=None: -1,
  444. torch.atan: lambda input, out=None: -1,
  445. torch.arctan: lambda input, out=None: -1,
  446. torch.atan2: lambda input, other, out=None: -1,
  447. torch.arctan2: lambda input, other, out=None: -1,
  448. torch.atanh: lambda input, out=None: -1,
  449. torch.arctanh: lambda input, out=None: -1,
  450. torch.atleast_1d: lambda *tensors: -1,
  451. torch.atleast_2d: lambda *tensors: -1,
  452. torch.atleast_3d: lambda *tensors: -1,
  453. torch.avg_pool1d: lambda input, kernel_size, stride=None, padding=0, ceil_mode=False, count_include_pad=True: -1,
  454. torch.baddbmm: lambda input, batch1, batch2, alpha=1, beta=1, out=None: -1,
  455. torch.batch_norm: lambda input, weight, bias, running_mean, running_var, training, momentum, eps, cudnn_enabled: -1,
  456. torch.batch_norm_backward_elemt: lambda grad_out, input, mean, invstd, weight, sum_dy, sum_dy_xmu, count_tensor: -1,
  457. torch.batch_norm_backward_reduce: lambda grad_out, input, mean, invstd, weight, input_g, weight_g, bias_g: -1,
  458. torch.batch_norm_elemt: lambda input, weight, bias, mean, invstd, eps: -1,
  459. torch.batch_norm_gather_stats: lambda input, mean, invstd, running_mean, running_var, momentum, eps, count: -1,
  460. torch.batch_norm_gather_stats_with_counts: lambda input, mean, invstd, running_mean, running_var, momentum, eps, count: -1,
  461. torch.batch_norm_stats: lambda input, eps: -1,
  462. torch.batch_norm_update_stats: lambda input, running_mean, running_var, momentum: -1,
  463. torch.bernoulli: lambda input, generator=None, out=None: -1,
  464. torch.bilinear: lambda input1, input2, weight, bias: -1,
  465. torch.binary_cross_entropy_with_logits: (
  466. lambda input, target, weight=None, size_average=None, reduce=None, reduction="mean", pos_weight=None: -1
  467. ),
  468. torch.bincount: lambda input, weights=None, minlength=0: -1,
  469. torch.binomial: lambda count, prob, generator=None: -1,
  470. torch.bitwise_and: lambda input, other, out=None: -1,
  471. torch.bitwise_not: lambda input, out=None: -1,
  472. torch.bitwise_or: lambda input, other, out=None: -1,
  473. torch.bitwise_xor: lambda input, other, out=None: -1,
  474. torch.bitwise_left_shift: lambda input, other, out=None: -1,
  475. torch.bitwise_right_shift: lambda input, other, out=None: -1,
  476. torch.block_diag: lambda *tensors: -1,
  477. torch.bmm: lambda input, mat2, out_dtype=None, out=None: -1,
  478. torch.broadcast_tensors: lambda *tensors: -1,
  479. torch.broadcast_to: lambda self, size: -1,
  480. torch.bucketize: lambda input, boundaries, out_int32=False, right=False, out=None: -1,
  481. torch.cartesian_prod: lambda *tensors: -1,
  482. torch.cat: lambda tensors, dim=0, out=None: -1,
  483. torch.concat: lambda tensors, dim=0, out=None: -1, # alias for torch.cat
  484. torch.concatenate: lambda tensors, dim=0, out=None: -1, # alias for torch.concatenate
  485. torch.cdist: lambda x1, x2, p=2.0, compute_mode="use_mm_for_euclid_dist_if_necessary": -1,
  486. torch.ceil: lambda input, out=None: -1,
  487. torch.celu: lambda input, alpha=1.0, inplace=False: -1,
  488. torch.chain_matmul: lambda *matrices, out=None: -1,
  489. torch.channel_shuffle: lambda input, groups: -1,
  490. torch.cholesky: lambda input, upper=False, out=None: -1,
  491. torch.linalg.cholesky: lambda input, out=None: -1,
  492. torch.linalg.cholesky_ex: lambda input, check_errors=False, out=None: -1,
  493. torch.cholesky_inverse: lambda input, upper=False, out=None: -1,
  494. torch.cholesky_solve: lambda input1, input2, upper=False, out=None: -1,
  495. torch.choose_qparams_optimized: lambda input, numel, n_bins, ratio, bit_width: -1,
  496. torch.chunk: lambda input, chunks, dim=0: -1,
  497. torch.clamp: lambda input, min=None, max=None, out=None: -1,
  498. torch.clip: lambda input, min=None, max=None, out=None: -1,
  499. torch.clamp_min: lambda input, min, out=None: -1,
  500. torch.clamp_max: lambda input, max, out=None: -1,
  501. torch.column_stack: lambda tensors, out=None: -1,
  502. torch.cov: lambda input, correction=1, fweights=None, aweights=None: -1,
  503. torch.clone: lambda input: -1,
  504. torch.combinations: lambda input, r=2, with_replacement=False: -1,
  505. torch.complex: lambda real, imag: -1,
  506. torch.copysign: lambda input, other, out=None: -1,
  507. torch.polar: lambda abs, ang: -1,
  508. torch.linalg.cond: lambda input, ord=None: -1,
  509. torch.conj: lambda input, out=None: -1,
  510. torch.conj_physical: lambda input, out=None: -1,
  511. torch.resolve_conj: lambda input, out=None: -1,
  512. torch.resolve_neg: lambda input, out=None: -1,
  513. torch.constant_pad_nd: lambda input, pad, value=0: -1,
  514. torch.conv1d: lambda input, weight, bias=None, stride=1, padding=0, dilation=1, groups=1: -1,
  515. torch.conv2d: lambda input, weight, bias=None, stride=1, padding=0, dilation=1, groups=1: -1,
  516. torch.conv3d: lambda input, weight, bias=None, stride=1, padding=0, dilation=1, groups=1: -1,
  517. torch.convolution: lambda input, weight, bias, stride, padding, dilation, transposed, output_adding, groups: -1,
  518. torch.conv_tbc: lambda input, weight, bias, pad=0: -1,
  519. torch.conv_transpose1d: lambda input, weight, bias=None, stride=1, padding=0, output_padding=0, groups=1, dilation=1: -1,
  520. torch.conv_transpose2d: lambda input, weight, bias=None, stride=1, padding=0, output_padding=0, groups=1, dilation=1: -1,
  521. torch.conv_transpose3d: lambda input, weight, bias=None, stride=1, padding=0, output_padding=0, groups=1, dilation=1: -1,
  522. torch.corrcoef: lambda input: -1,
  523. torch.cos: lambda input, out=None: -1,
  524. torch.cosine_embedding_loss: lambda input1, input2, target, margin=0, size_average=None, reduce=None, reduction="mean": -1,
  525. torch.cosh: lambda input, out=None: -1,
  526. torch.cosine_similarity: lambda x1, x2, dim=1, eps=1e-8: -1,
  527. torch.count_nonzero: lambda input: -1,
  528. torch.cross: lambda input, other, dim=None, out=None: -1,
  529. torch.linalg.cross: lambda input, other, dim=-1, out=None: -1,
  530. torch.ctc_loss: (
  531. lambda log_probs, targets, input_lengths, target_lengths, blank=0, reduction="mean", zero_infinity=False: -1
  532. ),
  533. torch.cummax: lambda input, dim, out=None: -1,
  534. torch.cummin: lambda input, dim, out=None: -1,
  535. torch.cumprod: lambda input, dim, out=None, dtype=None: -1,
  536. torch.cumsum: lambda input, dim, out=None, dtype=None: -1,
  537. torch.cumulative_trapezoid: lambda y, x=None, dim=-1: -1,
  538. torch.logcumsumexp: lambda input, dim, out=None: -1,
  539. torch.deg2rad: lambda input, out=None: -1,
  540. torch.dequantize: lambda input: -1,
  541. torch.det: lambda input: -1,
  542. torch.linalg.det: lambda input: -1, # alias for torch.det # type: ignore[attr-defined]
  543. torch.detach: lambda input: -1,
  544. torch.diag: lambda input, diagonal=0, out=None: -1,
  545. torch.diag_embed: lambda input, diagonal=0, out=None: -1,
  546. torch.diagflat: lambda input, offset=0: -1,
  547. torch.diff: lambda input, n=1, dim=-1, prepend=None, append=None, out=None: -1,
  548. torch.diagonal: lambda input, offset=0, dim1=0, dim2=1: -1,
  549. torch.linalg.diagonal: lambda input, offset=0, dim1=-2, dim2=-1: -1,
  550. torch.diagonal_scatter: lambda input, src, offset=0, dim1=0, dim2=1: -1,
  551. torch.as_strided_scatter: lambda self, src, size, stride, storage_offset=None: -1,
  552. torch.digamma: lambda input, out=None: -1,
  553. torch.dist: lambda input, other, p=2: -1,
  554. torch.div: lambda input, other, rounding_mode=None, out=None: -1,
  555. torch.divide: lambda input, other, rounding_mode=None, out=None: -1,
  556. torch.dot: lambda input, other, out=None: -1,
  557. torch.dropout: lambda input, p, train, inplace=False: -1,
  558. torch.dsmm: lambda input, mat2, out_dtype=None: -1,
  559. torch.hsmm: lambda mat1, mat2: -1,
  560. torch.dsplit: lambda input, indices_or_sections: -1,
  561. torch.dstack: lambda tensors, out=None: -1,
  562. torch.linalg.eig: lambda input, out=None: -1,
  563. torch.linalg.eigvals: lambda input, out=None: -1,
  564. torch.linalg.eigh: lambda input, UPLO="L", out=None: -1,
  565. torch.linalg.eigvalsh: lambda input, UPLO="L", out=None: -1,
  566. torch.einsum: lambda equation, *operands: -1,
  567. torch.embedding: (
  568. lambda input, weight, padding_idx=None, max_norm=None, norm_type=2.0, scale_grad_by_freq=False, sparse=False: -1 # noqa: B950
  569. ),
  570. torch.embedding_bag: (
  571. lambda input, weight, offsets, max_norm=None, norm_type=2, scale_grad_by_freq=False, mode="mean", sparse=False, per_sample_weights=None, padding_idx=None: -1 # noqa: B950
  572. ),
  573. torch.empty_like: lambda input, dtype=None, layout=None, device=None, requires_grad=False: -1,
  574. torch.eq: lambda input, other, out=None: -1,
  575. torch.equal: lambda input, other: -1,
  576. torch.erf: lambda input, out=None: -1,
  577. torch.erfc: lambda input, out=None: -1,
  578. torch.erfinv: lambda input, out=None: -1,
  579. torch.exp: lambda input, out=None: -1,
  580. torch.exp2: lambda input, out=None: -1,
  581. torch.expm1: lambda input, out=None: -1,
  582. torch.fake_quantize_per_channel_affine: lambda input, scale, zero_point, axis, quant_min, quant_max: -1,
  583. torch.fake_quantize_per_tensor_affine: lambda input, scale, zero_point, quant_min, quant_max: -1,
  584. torch.fused_moving_avg_obs_fake_quant: (
  585. lambda x, observer_on, fake_quant_on, averaging_const, running_min, running_max, scale, zero_point, quant_min, quant_max, ch_axis, per_row_fake_quant=False, symmetric_quant=False: -1 # noqa: B950
  586. ),
  587. torch.fbgemm_linear_fp16_weight: lambda input, packed_weight, bias, output: -1,
  588. torch.fbgemm_linear_fp16_weight_fp32_activation: lambda input, packed_weight, bias, output: -1,
  589. torch.fbgemm_linear_int8_weight: lambda input, weight, packed, col_offsets, weight_scale, weight_zero_point, bias: -1, # noqa: B950
  590. torch.fbgemm_linear_int8_weight_fp32_activation: (
  591. lambda input, weight, packed, col_offsets, weight_scale, weight_zero_point, bias: -1
  592. ),
  593. torch.fbgemm_linear_quantize_weight: lambda input: -1,
  594. torch.fbgemm_pack_gemm_matrix_fp16: lambda input: -1,
  595. torch.fbgemm_pack_quantized_matrix: lambda input, a, b: -1,
  596. torch.feature_alpha_dropout: lambda input, p, train: -1,
  597. torch.feature_dropout: lambda input, p, train: -1,
  598. torch.fft.ifft: lambda input, n=None, dim=-1, norm=None: -1,
  599. torch.fft.rfft: lambda input, n=None, dim=-1, norm=None: -1,
  600. torch.fft.irfft: lambda input, n=None, dim=-1, norm=None: -1,
  601. torch.fft.hfft: lambda input, n=None, dim=-1, norm=None: -1,
  602. torch.fft.ihfft: lambda input, n=None, dim=-1, norm=None: -1,
  603. torch.fft.hfft2: lambda input, s=None, dim=(-2, -1), norm=None: -1,
  604. torch.fft.ihfft2: lambda input, s=None, dim=(-2, -1), norm=None: -1,
  605. torch.fft.hfftn: lambda input, s=None, dim=-1, norm=None: -1,
  606. torch.fft.ihfftn: lambda input, s=None, dim=-1, norm=None: -1,
  607. torch.fft.fftn: lambda input, s=None, dim=None, norm=None: -1,
  608. torch.fft.ifftn: lambda input, s=None, dim=None, norm=None: -1,
  609. torch.fft.rfftn: lambda input, s=None, dim=None, norm=None: -1,
  610. torch.fft.irfftn: lambda input, s=None, dim=None, norm=None: -1,
  611. torch.fft.fft2: lambda input, s=None, dim=(-2, -1), norm=None: -1,
  612. torch.fft.ifft2: lambda input, s=None, dim=(-2, -1), norm=None: -1,
  613. torch.fft.rfft2: lambda input, s=None, dim=(-2, -1), norm=None: -1,
  614. torch.fft.irfft2: lambda input, s=None, dim=(-2, -1), norm=None: -1,
  615. torch.fft.fftshift: lambda input, dim=None: -1,
  616. torch.fft.ifftshift: lambda input, dim=None: -1,
  617. torch.fft.fft: lambda input, n=None, dim=-1, norm=None: -1,
  618. torch.fix: lambda input, out=None: -1,
  619. torch.flatten: lambda input, start_dim=0, end_dim=-1: -1,
  620. torch.flip: lambda input, dims: -1,
  621. torch.fliplr: lambda input: -1,
  622. torch.flipud: lambda input: -1,
  623. torch.frobenius_norm: lambda input, dim=None, keepdim=False, out=None: -1,
  624. torch.floor: lambda input, out=None: -1,
  625. torch.floor_divide: lambda input, other: -1,
  626. torch.float_power: lambda input, exponent, out=None: -1,
  627. torch.fmod: lambda input, other, out=None: -1,
  628. torch.frac: lambda input, out=None: -1,
  629. torch.frexp: lambda input, out=None: -1,
  630. torch.full_like: lambda input, fill_value, out=None, dtype=None, layout=torch.strided, device=None, requires_grad=False: -1, # noqa: B950
  631. torch._functional_assert_async: lambda input, msg, dep_token: -1,
  632. torch.lu_unpack: lambda LU_data, LU_pivots, unpack_data=True, unpack_pivots=True: -1,
  633. torch.gather: lambda input, dim, index, out=None, sparse_grad=False: -1,
  634. torch.gcd: lambda input, other, out=None: -1,
  635. torch.ge: lambda input, other, out=None: -1,
  636. torch.get_device: lambda input: -1,
  637. torch.greater_equal: lambda input, other, out=None: -1,
  638. torch.geqrf: lambda input, out=None: -1,
  639. torch.i0: lambda input, out=None: -1,
  640. torch.inner: lambda input, other, out=None: -1,
  641. torch.outer: lambda input, vec2, out=None: -1,
  642. torch.ger: lambda input, vec2, out=None: -1, # alias for torch.outer
  643. torch.gradient: lambda input, spacing=None, dim=None, edge_order=1: -1,
  644. torch.grid_sampler: lambda input, grid, interpolation_mode, padding_mode, align_corners: -1,
  645. torch.grid_sampler_2d: lambda input, grid, interpolation_mode, padding_mode, align_corners: -1,
  646. torch.grid_sampler_3d: lambda input, grid, interpolation_mode, padding_mode, align_corners: -1,
  647. torch.group_norm: lambda input, num_groups, weight=None, bias=None, eps=1e-05, cudnn_enabled=True: -1,
  648. torch.gru: lambda input, hx, params, has_biases, num_layers, dropout, train, bidirectional, batch_first: -1,
  649. torch.gru_cell: lambda input, hx, w_ih, w_hh, b_ih=None, b_hh=None: -1,
  650. torch.gt: lambda input, other, out=None: -1,
  651. torch.greater: lambda input, other, out=None: -1,
  652. torch.hardshrink: lambda input, lambd=0.5: -1,
  653. torch.hash_tensor: lambda input, dim=None, keepdim=False, mode=0, out=None: -1,
  654. torch.heaviside: lambda input, values, out=None: -1,
  655. torch.hinge_embedding_loss: lambda input, target, margin=1.0, size_average=None, reduce=None, reduction="mean": -1, # noqa: B950
  656. torch.histc: lambda input, bins=100, min=0, max=0, out=None: -1,
  657. torch.histogram: lambda input, bins=100, min=None, max=None, weight=None, density=False, out=None: -1,
  658. torch.histogramdd: lambda input, bins, range=None, weight=None, density=False: -1,
  659. torch.linalg.householder_product: lambda input, tau: -1,
  660. torch.hspmm: lambda mat1, mat2, out=None: -1,
  661. torch.hsplit: lambda input, indices_or_sections: -1,
  662. torch.hstack: lambda tensors, out=None: -1,
  663. torch.hypot: lambda input, other, out=None: -1,
  664. torch.igamma: lambda input, other, out=None: -1,
  665. torch.igammac: lambda input, other, out=None: -1,
  666. torch.imag: lambda input, out=None: -1,
  667. torch.index_add: lambda input, dim, index, source: -1,
  668. torch.index_copy: lambda input, dim, index, source: -1,
  669. torch.index_put: lambda input, indices, values, accumulate=False: -1,
  670. torch.index_select: lambda input, dim, index, out=None: -1,
  671. torch.index_fill: lambda input, dim, index, value: -1,
  672. torch.index_reduce: lambda input, dim, index, source, reduce, include_input=True: -1,
  673. torch.isfinite: lambda tensor: -1,
  674. torch.isin: lambda e, te, assume_unique=False, invert=False: -1,
  675. torch.isinf: lambda tensor: -1,
  676. torch.isreal: lambda tensor: -1,
  677. torch.isposinf: lambda input, out=None: -1,
  678. torch.isneginf: lambda input, out=None: -1,
  679. torch.instance_norm: (
  680. lambda input, running_mean, running_var, weight, bias, use_input_stats, momentum, eps, cudnn_enabled: -1
  681. ),
  682. torch.int_repr: lambda input: -1,
  683. torch.inverse: lambda input, out=None: -1,
  684. torch.linalg.inv: lambda input, out=None: -1,
  685. torch.linalg.inv_ex: lambda input, check_errors=False, out=None: -1,
  686. torch.is_complex: lambda input: -1,
  687. torch.is_conj: lambda input: -1,
  688. torch.is_neg: lambda input: -1,
  689. torch.is_distributed: lambda input: -1,
  690. torch.is_inference: lambda input: -1,
  691. torch.is_floating_point: lambda input: -1,
  692. torch.is_nonzero: lambda input: -1,
  693. torch.is_same_size: lambda input, other: -1,
  694. torch.is_signed: lambda input: -1,
  695. torch.isclose: lambda input, other, rtol=1e-05, atol=1e-08, equal_nan=False: -1,
  696. torch.isnan: lambda input: -1,
  697. torch.istft: (
  698. lambda input, n_fft, hop_length=None, win_length=None, window=None, center=True, normalized=False, onesided=None, length=None, return_complex=False: -1 # noqa: B950
  699. ),
  700. torch.kl_div: lambda input, target, size_average=None, reduce=None, reduction="mean", log_target=False: -1,
  701. torch.kron: lambda input, other: -1,
  702. torch.kthvalue: lambda input, k, dim=None, keepdim=False, out=None: -1,
  703. torch.linalg.ldl_factor_ex: lambda input, hermitian=False, check_errors=False, out=None: -1,
  704. torch.linalg.ldl_factor: lambda input, hermitian=False, out=None: -1,
  705. torch.linalg.ldl_solve: lambda LD, pivots, B, hermitian=False, out=None: -1,
  706. torch.layer_norm: lambda input, normalized_shape, weight=None, bias=None, esp=1e-05, cudnn_enabled=True: -1,
  707. torch.lcm: lambda input, other, out=None: -1,
  708. torch.ldexp: lambda input, other, out=None: -1,
  709. torch.le: lambda input, other, out=None: -1,
  710. torch.less_equal: lambda input, other, out=None: -1,
  711. torch.lerp: lambda input, end, weight, out=None: -1,
  712. torch.lgamma: lambda input, out=None: -1,
  713. torch.lobpcg: lambda input, k=None, B=None, X=None, n=None, iK=None, niter=None, tol=None, largest=None, method=None, tracker=None, ortho_iparams=None, ortho_fparams=None, ortho_bparams=None: -1, # noqa: B950
  714. torch.log: lambda input, out=None: -1,
  715. torch.log_softmax: lambda input, dim, dtype=None: -1,
  716. torch.log10: lambda input, out=None: -1,
  717. torch.log1p: lambda input, out=None: -1,
  718. torch.log2: lambda input, out=None: -1,
  719. torch.logaddexp: lambda input, other, out=None: -1,
  720. torch.logaddexp2: lambda input, other, out=None: -1,
  721. torch.logdet: lambda input: -1,
  722. torch.xlogy: lambda x, y, out=None: -1,
  723. torch.logical_and: lambda input, other, out=None: -1,
  724. torch.logical_not: lambda input, out=None: -1,
  725. torch.logical_or: lambda input, other, out=None: -1,
  726. torch.logical_xor: lambda input, other, out=None: -1,
  727. torch.logit: lambda input, eps=None: -1,
  728. torch.logsumexp: lambda input, names, keepdim=False, out=None: -1,
  729. torch.lstm: lambda data, batch_sizes, hx, params, has_biases, num_layers, dropout, train, bidirectional: -1,
  730. torch.lstm_cell: lambda input, hx, w_ih, w_hh, b_ih=None, b_hh=None: -1,
  731. torch.lt: lambda input, other, out=None: -1,
  732. torch.less: lambda input, other, out=None: -1,
  733. torch.lu: lambda A, pivot=True, get_infos=False, out=None: -1,
  734. torch.lu_solve: lambda b, LU_data, LU_pivots, out=None: -1,
  735. torch.margin_ranking_loss: lambda input1, input2, target, margin=0, size_average=None, reduce=None, reduction="mean": -1, # type: ignore[attr-defined] # noqa: B950
  736. torch.masked_fill: lambda input, mask, value: -1,
  737. torch.masked_scatter: lambda input, mask, source: -1,
  738. torch.masked_select: lambda input, mask, out=None: -1,
  739. torch.matmul: lambda input, other, out=None: -1,
  740. torch.linalg.lu: lambda input, pivot=True, out=None: -1,
  741. torch.linalg.lu_factor: lambda input, pivot=True, out=None: -1,
  742. torch.linalg.lu_factor_ex: lambda input, pivot=True, check_errors=False, out=None: -1,
  743. torch.linalg.lu_solve: lambda LU, pivots, B, left=True, adjoint=False, out=None: -1,
  744. torch.linalg.matmul: lambda input, other, out=None: -1, # alias for torch.matmul
  745. torch.matrix_power: lambda input, n: -1,
  746. torch.linalg.matrix_power: lambda input, n, out=None: -1,
  747. torch.linalg.matrix_rank: lambda input, tol=None, hermitian=False: -1,
  748. torch.linalg.multi_dot: lambda tensors, out=None: -1,
  749. torch.matrix_exp: lambda input: -1,
  750. torch.linalg.matrix_exp: lambda input: -1,
  751. torch.max: lambda input, out=None: -1,
  752. torch.maximum: lambda input, other, out=None: -1,
  753. torch.fmax: lambda input, other, out=None: -1,
  754. torch.max_pool1d: lambda input, kernel_size, stride=None, padding=0, dilation=1, ceil_mode=False: -1,
  755. torch.max_pool2d: lambda input, kernel_size, stride=None, padding=0, dilation=1, ceil_mode=False: -1,
  756. torch.max_pool3d: lambda input, kernel_size, stride=None, padding=0, dilation=1, ceil_mode=False: -1,
  757. torch.max_pool1d_with_indices: (
  758. lambda input, kernel_size, stride=None, padding=0, dilation=1, return_indices=False, ceil_mode=False: -1
  759. ),
  760. torch.mean: lambda input, dim=None: -1,
  761. torch.nanmean: lambda input, dim=None, keepdim=False, dtype=None, out=None: -1,
  762. torch.median: lambda input, dim=None: -1,
  763. torch.nanmedian: lambda input, dim=None: -1,
  764. torch.meshgrid: lambda *tensors, **kwargs: -1,
  765. torch.min: lambda input, out=None: -1,
  766. torch.minimum: lambda input, other, out=None: -1,
  767. torch.fmin: lambda input, other, out=None: -1,
  768. torch.miopen_batch_norm: (
  769. lambda input, weight, bias, running_mean, running_var, training, exponential_average_factor, epsilon: -1
  770. ),
  771. torch.miopen_convolution: lambda input, weight, bias, padding, stride, dilation, groups, benchmark, deterministic: -1, # noqa: B950
  772. torch.miopen_convolution_add_relu: lambda input, weight, z, alpha, bias, stride, padding, dilation, groups: -1,
  773. torch.miopen_convolution_relu: lambda input, weight, bias, stride, padding, dilation, groups: -1,
  774. torch.miopen_convolution_transpose: (
  775. lambda input, weight, bias, padding, output_padding, stride, dilation, groups, benchmark, deterministic: -1
  776. ),
  777. torch.miopen_depthwise_convolution: (
  778. lambda input, weight, bias, padding, stride, dilation, groups, benchmark, deterministic: -1
  779. ),
  780. torch.miopen_rnn: (
  781. lambda input, weight, weight_stride0, hx, cx, mode, hidden_size, num_layers, batch_first, dropout, train, bidirectional, batch_sizes, dropout_state: -1 # noqa: B950
  782. ),
  783. torch.mm: lambda input, mat2, out_dtype=None, out=None: -1,
  784. torch.mode: lambda input, dim=-1, keepdim=False, out=None: -1,
  785. torch.movedim: lambda input, source, destination: -1,
  786. torch.moveaxis: lambda input, source, destination: -1,
  787. torch.msort: lambda input, descending=False, out=None: -1,
  788. torch.mul: lambda input, other, out=None: -1,
  789. torch.multiply: lambda input, other, out=None: -1,
  790. torch.multinomial: lambda input, num_samples, replacement=False, out=None: -1,
  791. torch.mv: lambda input, vec, out=None: -1,
  792. torch.mvlgamma: lambda input, p: -1,
  793. torch.narrow: lambda input, dim, start, length: -1,
  794. torch.nan_to_num: lambda input, nan=0.0, posinf=None, neginf=None, out=None: -1,
  795. torch.native_batch_norm: lambda input, weight, bias, running_mean, running_var, training, momentum, eps: -1,
  796. torch._native_batch_norm_legit: lambda input, weight, bias, training, momentum, eps: -1,
  797. torch.native_dropout: lambda input, p, train: -1,
  798. torch.native_layer_norm: lambda input, normalized_shape, weight=None, bias=None, eps=1e-05: -1,
  799. torch._fused_rms_norm: lambda input, normalized_shape, weight=None, eps=1e-05: -1,
  800. torch.native_group_norm: lambda input, weight, bias, N, C, HxW, group, eps: -1,
  801. torch.native_norm: lambda input, p=2, dim=None, keepdim=False, dtype=None: -1,
  802. torch.native_channel_shuffle: lambda input, groups: -1,
  803. torch.ne: lambda input, other, out=None: -1,
  804. torch.not_equal: lambda input, other, out=None: -1,
  805. torch.neg: lambda input, out=None: -1,
  806. torch.negative: lambda input, out=None: -1,
  807. torch.nextafter: lambda input, other, out=None: -1,
  808. torch.nn.functional.adaptive_avg_pool2d: lambda input, output_size: -1,
  809. torch.nn.functional.adaptive_avg_pool3d: lambda input, output_size: -1,
  810. torch.nn.functional.adaptive_max_pool1d: lambda input, output_size, return_indices=False: -1,
  811. torch.nn.functional.adaptive_max_pool1d_with_indices: lambda input, output_size, return_indices=False: -1,
  812. torch.nn.functional.adaptive_max_pool2d: lambda input, output_size, return_indices=False: -1,
  813. torch.nn.functional.adaptive_max_pool2d_with_indices: lambda input, output_size, return_indices=False: -1,
  814. torch.nn.functional.adaptive_max_pool3d: lambda input, output_size, return_indices=False: -1,
  815. torch.nn.functional.adaptive_max_pool3d_with_indices: lambda input, output_size, return_indices=False: -1,
  816. torch.nn.functional.affine_grid: lambda theta, size, align_corners=None: -1,
  817. torch.nn.functional.alpha_dropout: lambda input, p=0.5, training=False, inplace=False: -1,
  818. torch.nn.functional.avg_pool2d: (
  819. lambda input, kernel_size, stride=None, padding=0, ceil_mode=False, count_include_pad=True, divisor_override=None: -1 # noqa: B950
  820. ),
  821. torch.nn.functional.avg_pool3d: (
  822. lambda input, kernel_size, stride=None, padding=0, ceil_mode=False, count_include_pad=True, divisor_override=None: -1 # noqa: B950
  823. ),
  824. torch.nn.functional.batch_norm: (
  825. lambda input, running_mean, running_var, weight=None, bias=None, training=False, momentum=0.1, eps=1e-05: -1
  826. ),
  827. torch.nn.functional.bilinear: lambda input1, input2, weight, bias=None: -1,
  828. torch.nn.functional.binary_cross_entropy: (
  829. lambda input, target, weight=None, size_average=None, reduce=None, reduction="mean": -1
  830. ),
  831. torch.nn.functional.binary_cross_entropy_with_logits: (
  832. lambda input, target, weight=None, size_average=None, reduce=None, reduction="mean", pos_weight=None: -1
  833. ),
  834. torch.nn.functional.celu: lambda input, alpha=1.0, inplace=False: -1,
  835. torch.nn.functional.cosine_embedding_loss: (
  836. lambda input1, input2, target, margin=0, size_average=None, reduce=None, reduction="mean": -1
  837. ),
  838. torch.nn.functional.cross_entropy: (
  839. lambda input, target, weight=None, size_average=None, ignore_index=-100, reduce=None, reduction="mean", label_smoothing=0.0: -1 # noqa: B950
  840. ),
  841. torch.nn.functional.ctc_loss: (
  842. lambda log_probs, targets, input_lengths, target_lengths, blank=0, reduction="mean", zero_infinity=False: -1
  843. ),
  844. torch.nn.functional.dropout: lambda input, p=0.5, training=True, inplace=False: -1,
  845. torch.nn.functional.dropout1d: lambda input, p=0.5, training=True, inplace=False: -1,
  846. torch.nn.functional.dropout2d: lambda input, p=0.5, training=True, inplace=False: -1,
  847. torch.nn.functional.dropout3d: lambda input, p=0.5, training=True, inplace=False: -1,
  848. torch.nn.functional.elu: lambda input, alpha=1.0, inplace=False: -1,
  849. torch.nn.functional.embedding: (
  850. lambda input, weight, padding_idx=None, max_norm=None, norm_type=2.0, scale_grad_by_freq=False, sparse=False: -1 # noqa: B950
  851. ),
  852. torch.nn.functional.embedding_bag: (
  853. lambda input, weight, offsets=None, max_norm=None, norm_type=2, scale_grad_by_freq=False, mode="mean", sparse=False, per_sample_weights=None, include_last_offset=False, padding_idx=None: -1 # noqa: B950
  854. ),
  855. torch.nn.functional.feature_alpha_dropout: lambda input, p=0.5, training=False, inplace=False: -1,
  856. torch.nn.functional.fold: lambda input, output_size, kernel_size, dilation=1, padding=0, stride=1: -1,
  857. torch.nn.functional.fractional_max_pool2d: (
  858. lambda input, kernel_size, output_size=None, output_ratio=None, return_indices=False, _random_samples=None: -1 # noqa: B950
  859. ),
  860. torch.nn.functional.fractional_max_pool2d_with_indices: (
  861. lambda input, kernel_size, output_size=None, output_ratio=None, return_indices=False, _random_samples=None: -1 # noqa: B950
  862. ),
  863. torch.nn.functional.fractional_max_pool3d: (
  864. lambda input, kernel_size, output_size=None, output_ratio=None, return_indices=False, _random_samples=None: -1 # noqa: B950
  865. ),
  866. torch.nn.functional.fractional_max_pool3d_with_indices: (
  867. lambda input, kernel_size, output_size=None, output_ratio=None, return_indices=False, _random_samples=None: -1 # noqa: B950
  868. ),
  869. torch.nn.functional.gaussian_nll_loss: lambda input, target, var, full=False, eps=1e-06, reduction="mean": -1,
  870. torch.nn.functional.gelu: lambda input, approximate="none": -1,
  871. torch.nn.functional.glu: lambda input, dim=-1: -1,
  872. torch.nn.functional.grid_sample: lambda input, grid, mode="bilinear", padding_mode="zeros", align_corners=None: -1, # noqa: B950
  873. torch.nn.functional.group_norm: lambda input, num_groups, weight=None, bias=None, eps=1e-05: -1,
  874. torch.nn.functional.gumbel_softmax: lambda logits, tau=1, hard=False, eps=1e-10, dim=-1: -1,
  875. torch.nn.functional.hardshrink: lambda input, lambd=0.5: -1,
  876. torch.nn.functional.hardtanh: lambda input, min_val=-1.0, max_val=1.0, inplace=False: -1,
  877. torch.nn.functional.hinge_embedding_loss: (
  878. lambda input, target, margin=1.0, size_average=None, reduce=None, reduction="mean": -1
  879. ),
  880. torch.nn.functional.instance_norm: (
  881. lambda input, running_mean=None, running_var=None, weight=None, bias=None, use_input_stats=True, momentum=0.1, eps=1e-05: -1 # noqa: B950
  882. ),
  883. torch.nn.functional.interpolate: (
  884. lambda input, size=None, scale_factor=None, mode="nearest", align_corners=None, recompute_scale_factor=None, antialias=False: -1 # noqa: B950
  885. ),
  886. torch.nn.functional.kl_div: lambda input, target, size_average=None, reduce=None, reduction="mean", log_target=False: -1, # noqa: B950
  887. torch.nn.functional.l1_loss: lambda input, target, size_average=None, reduce=None, reduction="mean", weight=None: -1,
  888. torch.nn.functional.layer_norm: lambda input, normalized_shape, weight=None, bias=None, eps=1e-05: -1,
  889. torch.nn.functional.leaky_relu: lambda input, negative_slope=0.01, inplace=False: -1,
  890. torch.nn.functional.linear: lambda input, weight, bias=None: -1,
  891. torch.nn.functional.local_response_norm: lambda input, size, alpha=0.0001, beta=0.75, k=1.0: -1,
  892. torch.nn.functional.log_softmax: lambda input, dim=None, _stacklevel=3, dtype=None: -1,
  893. torch.nn.functional.logsigmoid: lambda input: -1,
  894. torch.nn.functional.lp_pool1d: lambda input, norm_type, kernel_size, stride=None, ceil_mode=False: -1,
  895. torch.nn.functional.lp_pool2d: lambda input, norm_type, kernel_size, stride=None, ceil_mode=False: -1,
  896. torch.nn.functional.lp_pool3d: lambda input, norm_type, kernel_size, stride=None, ceil_mode=False: -1,
  897. torch.nn.functional.margin_ranking_loss: (
  898. lambda input1, input2, target, margin=0, size_average=None, reduce=None, reduction="mean": -1
  899. ),
  900. torch.nn.functional.max_pool1d: (
  901. lambda input, kernel_size, stride=None, padding=0, dilation=1, ceil_mode=False, return_indices=False: -1
  902. ),
  903. torch.nn.functional.max_pool1d_with_indices: (
  904. lambda input, kernel_size, stride=None, padding=0, dilation=1, return_indices=False, ceil_mode=False: -1
  905. ),
  906. torch.nn.functional.max_pool2d: (
  907. lambda input, kernel_size, stride=None, padding=0, dilation=1, ceil_mode=False, return_indices=False: -1
  908. ),
  909. torch.nn.functional.max_pool2d_with_indices: (
  910. lambda input, kernel_size, stride=None, padding=0, dilation=1, return_indices=False, ceil_mode=False: -1
  911. ),
  912. torch.nn.functional.max_pool3d: (
  913. lambda input, kernel_size, stride=None, padding=0, dilation=1, return_indices=False, ceil_mode=False: -1
  914. ),
  915. torch.nn.functional.max_pool3d_with_indices: (
  916. lambda input, kernel_size, stride=None, padding=0, dilation=1, return_indices=False, ceil_mode=False: -1
  917. ),
  918. torch.nn.functional.max_unpool1d: lambda input, indices, kernel_size, stride=None, padding=0, output_size=None: -1, # noqa: B950
  919. torch.nn.functional.max_unpool2d: lambda input, indices, kernel_size, stride=None, padding=0, output_size=None: -1, # noqa: B950
  920. torch.nn.functional.max_unpool3d: lambda input, indices, kernel_size, stride=None, padding=0, output_size=None: -1, # noqa: B950
  921. torch.nn.functional.mse_loss: lambda input, target, size_average=None, reduce=None, reduction="mean", weight=None: -1,
  922. torch.nn.functional.multi_head_attention_forward: (
  923. lambda query, key, value, embed_dim_to_check, num_heads, in_proj_weight, in_proj_bias, bias_k, bias_v, add_zero_attn, dropout_p, out_proj_weight, out_proj_bias, training=True, key_padding_mask=None, need_weights=True, attn_mask=None, use_separate_proj_weight=False, q_proj_weight=None, k_proj_weight=None, v_proj_weight=None, static_k=None, static_v=None, average_attn_weights=None, is_causal=False: -1 # noqa: B950
  924. ),
  925. torch.nn.functional.multi_margin_loss: (
  926. lambda input, target, p=1, margin=1.0, weight=None, size_average=None, reduce=None, reduction="mean": -1
  927. ),
  928. torch.nn.functional.multilabel_margin_loss: (
  929. lambda input, target, size_average=None, reduce=None, reduction="mean": -1
  930. ),
  931. torch.nn.functional.multilabel_soft_margin_loss: (
  932. lambda input, target, weight=None, size_average=None, reduce=None, reduction="mean": -1
  933. ),
  934. torch.nn.functional.nll_loss: (
  935. lambda input, target, weight=None, size_average=None, ignore_index=-100, reduce=None, reduction="mean": -1
  936. ),
  937. torch.nn.functional.normalize: lambda input, p=2, dim=1, eps=1e-12, out=None: -1,
  938. torch.nn.functional.one_hot: lambda tensor, num_classes=-1: -1,
  939. torch.nn.functional.pad: lambda input, pad, mode="constant", value=0: -1,
  940. torch.nn.functional.pairwise_distance: lambda x1, x2, p=2.0, eps=1e-06, keepdim=False: -1,
  941. torch.nn.functional.poisson_nll_loss: (
  942. lambda input, target, log_input=True, full=False, size_average=None, eps=1e-08, reduce=None, reduction="mean": -1 # noqa: B950
  943. ),
  944. torch.nn.functional.prelu: lambda input, weight: -1,
  945. torch.nn.functional.relu: lambda input, inplace=False: -1,
  946. torch.nn.functional.relu6: lambda input, inplace=False: -1,
  947. torch.nn.functional.rms_norm: lambda input, normalized_shape, weight=None, eps=1e-6: -1,
  948. torch.nn.functional.rrelu: lambda input, lower=0.125, upper=0.3333333333333333, training=False, inplace=False: -1, # noqa: B950
  949. torch.nn.functional.selu: lambda input, inplace=False: -1,
  950. torch.nn.functional.silu: lambda input, inplace=False: -1,
  951. torch.nn.functional.mish: lambda input, inplace=False: -1,
  952. torch.nn.functional.scaled_dot_product_attention: lambda query, key, value, attn_mask=None, dropout_p=0.0: -1,
  953. torch.nn.functional.smooth_l1_loss: lambda input, target, size_average=None, reduce=None, reduction="mean", beta=1.0: -1, # noqa: B950
  954. torch.nn.functional.huber_loss: lambda input, target, reduction="mean", delta=1.0, weight=None: -1,
  955. torch.nn.functional.soft_margin_loss: lambda input, target, size_average=None, reduce=None, reduction="mean": -1, # noqa: B950
  956. torch.nn.functional.softmax: lambda input, dim=None, _stacklevel=3, dtype=None: -1,
  957. torch.nn.functional.softmin: lambda input, dim=None, _stacklevel=3, dtype=None: -1,
  958. torch.nn.functional.softplus: lambda input, beta=1, threshold=20: -1,
  959. torch.nn.functional.softshrink: lambda input, lambd=0.5: -1,
  960. torch.nn.functional.softsign: lambda input: -1,
  961. torch.nn.functional.tanhshrink: lambda input: -1,
  962. torch.nn.functional.threshold: lambda input, threshold, value, inplace=False: -1,
  963. torch.nn.functional.triplet_margin_loss: (
  964. lambda anchor, positive, negative, margin=1.0, p=2, eps=1e-06, swap=False, size_average=None, reduce=None, reduction="mean": -1 # noqa: B950
  965. ),
  966. torch.nn.functional.triplet_margin_with_distance_loss: (
  967. lambda anchor, positive, negative, *, distance_function=None, margin=1.0, swap=False, reduction="mean": -1
  968. ),
  969. torch.nn.functional.unfold: lambda input, kernel_size, dilation=1, padding=0, stride=1: -1,
  970. torch.nn.init.uniform_: lambda tensor, a=0.0, b=1.0, generator=None: -1,
  971. torch.nn.init.normal_: lambda tensor, mean=0.0, std=1.0, generator=None: -1,
  972. torch.nn.init.constant_: lambda tensor, val: -1,
  973. torch.nn.init.kaiming_uniform_: lambda tensor, a=0, mode="fan_in", nonlinearity="leaky_relu", generator=None: -1, # noqa: B950
  974. torch.nonzero: lambda input, as_tuple=False: -1,
  975. torch.nonzero_static: lambda input, *, size, fill_value=-1: -1,
  976. torch.argwhere: lambda input: -1,
  977. torch.norm: lambda input, p="fro", dim=None, keepdim=False, out=None, dtype=None: -1,
  978. torch.linalg.norm: lambda input, ord=None, dim=None, keepdim=False, out=None, dtype=None: -1,
  979. torch.linalg.vector_norm: lambda input, ord=2, dim=None, keepdim=False, out=None, dtype=None: -1,
  980. torch.linalg.matrix_norm: lambda input, ord="fro", dim=(
  981. -2,
  982. -1,
  983. ), keepdim=False, out=None, dtype=None: -1,
  984. torch.norm_except_dim: lambda v, pow=2, dim=0: -1,
  985. torch.nuclear_norm: lambda input, p="fro", dim=None, keepdim=False, out=None, dtype=None: -1,
  986. torch.numel: lambda input: -1,
  987. torch.orgqr: lambda input, tau: -1,
  988. torch.ormqr: lambda input, input2, input3, left=True, transpose=False: -1,
  989. torch.pairwise_distance: lambda x1, x2, p=2.0, eps=1e-06, keepdim=False: -1,
  990. torch.permute: lambda self, dim: -1,
  991. torch.pca_lowrank: lambda input, q=None, center=True, niter=2: -1,
  992. torch.pdist: lambda input, p=2: -1,
  993. torch.pinverse: lambda input, rcond=1e-15: -1,
  994. torch.linalg.pinv: lambda input, rcond=1e-15, hermitian=False: -1,
  995. torch.pixel_shuffle: lambda input, upscale_factor: -1,
  996. torch.pixel_unshuffle: lambda input, downscale_factor: -1,
  997. torch.poisson: lambda input, generator=None: -1,
  998. torch.poisson_nll_loss: lambda input, target, log_input, full, eps, reduction: -1,
  999. torch.polygamma: lambda input, n, out=None: -1,
  1000. torch.positive: lambda input, out=None: -1,
  1001. torch.prelu: lambda input, weight: -1,
  1002. torch.ones_like: lambda input, dtype=None, layout=None, device=None, requires_grad=False: -1,
  1003. torch.pow: lambda input, exponent, out=None: -1,
  1004. torch.prod: lambda input, dtype=None: -1,
  1005. torch.put: lambda input, index, source, accumulate=False: -1,
  1006. torch.q_per_channel_axis: lambda input: -1,
  1007. torch.q_per_channel_scales: lambda input: -1,
  1008. torch.q_per_channel_zero_points: lambda input: -1,
  1009. torch.q_scale: lambda input: -1,
  1010. torch.q_zero_point: lambda input: -1,
  1011. torch.qr: lambda input, some=True, out=None: -1,
  1012. torch.linalg.qr: lambda input, mode="reduced", out=None: -1,
  1013. torch.quantile: lambda input, q, dim=None, keepdim=False, interpolation="linear", out=None: -1,
  1014. torch.nanquantile: lambda input, q, dim=None, keepdim=False, interpolation="linear", out=None: -1,
  1015. torch.quantize_per_channel: lambda input, scales, zero_points, axis, dtype: -1,
  1016. torch.quantize_per_tensor: lambda input, scale, zero_point, dtype: -1,
  1017. torch.quantize_per_tensor_dynamic: lambda input, dtype, reduce_range: -1,
  1018. torch.quantized_batch_norm: lambda input, weight, bias, mean, var, eps, output_scale, output_zero_point: -1,
  1019. torch.quantized_gru_cell: (
  1020. lambda input, hx, w_ih, w_hh, b_ih, b_hh, packed_ih, packed_hh, col_offsets_ih, col_offsets_hh, scale_ih, scale_hh, zero_point_ih, zero_point_hh: -1 # noqa: B950
  1021. ),
  1022. torch.quantized_lstm_cell: (
  1023. lambda input, hx, w_ih, w_hh, b_ih, b_hh, packed_ih, packed_hh, col_offsets_ih, col_offsets_hh, scale_ih, scale_hh, zero_point_ih, zero_point_hh: -1 # noqa: B950
  1024. ),
  1025. torch.quantized_max_pool1d: (
  1026. lambda input, kernel_size, stride=(), padding=(0,), dilation=(
  1027. 1,
  1028. ), ceil_mode=False: -1
  1029. ),
  1030. torch.quantized_max_pool2d: (
  1031. lambda input, kernel_size, stride=(), padding=(0, 0), dilation=(
  1032. 1,
  1033. 1,
  1034. ), ceil_mode=False: -1
  1035. ),
  1036. torch.quantized_max_pool3d: (
  1037. lambda input, kernel_size, stride=(), padding=(0, 0, 0), dilation=(
  1038. 1,
  1039. 1,
  1040. 1,
  1041. ), ceil_mode=False: -1
  1042. ),
  1043. torch.quantized_rnn_relu_cell: (
  1044. lambda input, hx, w_ih, w_hh, b_ih, b_hh, packed_ih, packed_hh, col_offsets_ih, col_offsets_hh, scale_ih, scale_hh, zero_point_ih, zero_point_hh: -1 # noqa: B950
  1045. ),
  1046. torch.quantized_rnn_tanh_cell: (
  1047. lambda input, hx, w_ih, w_hh, b_ih, b_hh, packed_ih, packed_hh, col_offsets_ih, col_offsets_hh, scale_ih, scale_hh, zero_point_ih, zero_point_hh: -1 # noqa: B950
  1048. ),
  1049. torch.rad2deg: lambda input, out=None: -1,
  1050. torch.rand_like: lambda input, dtype=None, layout=None, device=None, requires_grad=False: -1,
  1051. torch.randint_like: lambda input, high, dtype=None, layout=torch.strided, device=None, requires_grad=False: -1,
  1052. torch.randn_like: lambda input, dtype=None, layout=None, device=None, requires_grad=False: -1,
  1053. torch.ravel: lambda input: -1,
  1054. torch.real: lambda input, out=None: -1,
  1055. torch.vdot: lambda input, other, out=None: -1,
  1056. torch.linalg.vecdot: lambda input, other, dim=-1, out=None: -1,
  1057. torch.view_as_real: lambda input: -1,
  1058. torch.view_as_complex: lambda input: -1,
  1059. torch.reciprocal: lambda input, out=None: -1,
  1060. torch.relu: lambda input, inplace=False: -1,
  1061. torch.remainder: lambda input, other, out=None: -1,
  1062. torch.renorm: lambda input, p, dim, maxnorm, out=None: -1,
  1063. torch.repeat_interleave: lambda input, dim=None: -1,
  1064. torch.reshape: lambda input, shape: -1,
  1065. torch.rms_norm: lambda input, normalized_shape, weight=None, eps=1e-6: -1,
  1066. torch.rnn_relu: lambda input, hx, params, has_biases, num_layers, dropout, train, bidirectional, batch_first: -1, # noqa: B950
  1067. torch.rnn_relu_cell: lambda input, hx, w_ih, w_hh, b_ih=None, b_hh=None: -1,
  1068. torch.rnn_tanh: lambda input, hx, params, has_biases, num_layers, dropout, train, bidirectional, batch_first: -1, # noqa: B950
  1069. torch.rnn_tanh_cell: lambda input, hx, w_ih, w_hh, b_ih=None, b_hh=None: -1,
  1070. torch.roll: lambda input, shifts, dims=None: -1,
  1071. torch.rot90: lambda input, k=1, dims=(0, 1): -1,
  1072. torch.round: lambda input, out=None: -1,
  1073. torch.row_stack: lambda tensors, out=None: -1, # alias for torch.vstack
  1074. torch._rowwise_prune: (lambda weight, mask, compressed_indices_dtype: -1),
  1075. torch.rrelu: lambda input, lower=1.0 / 8, upper=1.0 / 3, training=False, inplace=False: -1,
  1076. torch.rsqrt: lambda input, out=None: -1,
  1077. torch.rsub: lambda input, other, alpha=1: -1,
  1078. torch.saddmm: lambda input, mat1, mat2, beta=1, alpha=1, out=None: -1,
  1079. torch.scatter: lambda input, dim, index, src: -1,
  1080. torch.scatter_add: lambda input, dim, index, src: -1,
  1081. torch.scatter_reduce: lambda input, dim, index, src, reduce, include_self=True: -1,
  1082. torch.searchsorted: lambda sorted_sequence, input, out_int32=False, right=False, out=None: -1,
  1083. torch._segment_reduce: lambda data, reduce="max", lengths=None, indices=None, offsets=None, axis=0, unsafe=False: -1, # noqa: B950
  1084. torch.select: lambda input, dim, index: -1,
  1085. torch.select_scatter: lambda input, src, dim, index: -1,
  1086. torch.slice_inverse: lambda input, src, dim=0, start=None, end=None, step=1: -1,
  1087. torch.slice_scatter: lambda input, src, dim=0, start=None, end=None, step=1: -1,
  1088. torch.selu: lambda input, inplace=False: -1,
  1089. torch.sigmoid: lambda input, out=None: -1,
  1090. torch.sign: lambda input, out=None: -1,
  1091. torch.signbit: lambda input, out=None: -1,
  1092. torch.sgn: lambda input, out=None: -1,
  1093. torch.sin: lambda input, out=None: -1,
  1094. torch.sinc: lambda input, out=None: -1,
  1095. torch.sinh: lambda input, out=None: -1,
  1096. torch.slogdet: lambda input: -1,
  1097. torch.linalg.slogdet: lambda input: -1,
  1098. torch.smm: lambda input, mat2, out_dtype=None: -1,
  1099. torch.spmm: lambda input, mat2, out_dtype=None: -1,
  1100. torch.softmax: lambda input, dim, dtype=None: -1,
  1101. torch.linalg.solve: lambda A, B, left=True, out=None: -1,
  1102. torch.linalg.solve_ex: lambda A, B, left=True, check_errors=False, out=None: -1,
  1103. torch.sort: lambda input, dim=-1, descending=False, *, stable=False, out=None: -1,
  1104. torch.split: lambda tensor, split_size_or_sections, dim=0: -1,
  1105. torch.split_with_sizes: lambda tensor, split_size_or_sections, dim=0: -1,
  1106. torch.sqrt: lambda input, out=None: -1,
  1107. torch.square: lambda input, out=None: -1,
  1108. torch.squeeze: lambda input, dim=None, out=None: -1,
  1109. torch.sspaddmm: lambda input, mat1, mat2, beta=1, alpha=1, out=None: -1,
  1110. torch.stack: lambda tensors, dim=0, out=None: -1,
  1111. torch.std: lambda input, dim=None: -1,
  1112. torch.std_mean: lambda input, dim=None: -1,
  1113. torch.stft: (
  1114. lambda input, n_fft, hop_length=None, win_length=None, window=None, center=True, pad_mode="reflect", normalized=False, onesided=True, return_complex=None, align_to_window=None: -1 # noqa: B950
  1115. ),
  1116. torch.sub: lambda input, other, out=None: -1,
  1117. torch.subtract: lambda input, other, out=None: -1,
  1118. torch.sum: lambda input, dim=None: -1,
  1119. torch.sym_float: lambda input: -1,
  1120. torch.sym_int: lambda input: -1,
  1121. torch.sym_max: lambda a, b: -1,
  1122. torch.sym_min: lambda a, b: -1,
  1123. torch.sym_not: lambda input: -1,
  1124. torch.sym_ite: lambda a, b, c: -1,
  1125. torch.sym_sum: lambda args: -1,
  1126. torch._sym_sqrt: lambda input: -1,
  1127. torch._sym_cos: lambda input: -1,
  1128. torch._sym_cosh: lambda input: -1,
  1129. torch._sym_sin: lambda input: -1,
  1130. torch._sym_sinh: lambda input: -1,
  1131. torch._sym_tan: lambda input: -1,
  1132. torch._sym_tanh: lambda input: -1,
  1133. torch._sym_asin: lambda input: -1,
  1134. torch._sym_acos: lambda input: -1,
  1135. torch._sym_atan: lambda input: -1,
  1136. torch.nansum: lambda input, dim=None: -1,
  1137. torch.svd: lambda input, some=True, compute_uv=True, out=None: -1,
  1138. torch.svd_lowrank: lambda input, q=6, niter=2, M=None: -1,
  1139. torch.linalg.svd: lambda input, full_matrices=True, out=None: -1,
  1140. torch.linalg.svdvals: lambda input, out=None: -1,
  1141. torch.swapaxes: lambda input, dim0, dim1: -1,
  1142. torch.swapdims: lambda input, axis0, axis1: -1,
  1143. torch.special.airy_ai: lambda input: -1,
  1144. torch.special.bessel_j0: lambda input: -1,
  1145. torch.special.bessel_j1: lambda input: -1,
  1146. torch.special.bessel_y0: lambda input: -1,
  1147. torch.special.bessel_y1: lambda input: -1,
  1148. torch.special.chebyshev_polynomial_t: lambda input, n, out=None: -1,
  1149. torch.special.chebyshev_polynomial_u: lambda input, n, out=None: -1,
  1150. torch.special.chebyshev_polynomial_v: lambda input, n, out=None: -1,
  1151. torch.special.chebyshev_polynomial_w: lambda input, n, out=None: -1,
  1152. torch.special.digamma: lambda input: -1,
  1153. torch.special.entr: lambda input: -1,
  1154. torch.special.erf: lambda input: -1,
  1155. torch.special.erfc: lambda input: -1,
  1156. torch.special.erfcx: lambda input: -1,
  1157. torch.special.erfinv: lambda input: -1,
  1158. torch.special.exp2: lambda input: -1,
  1159. torch.special.expit: lambda input: -1,
  1160. torch.special.expm1: lambda input: -1,
  1161. torch.special.gammainc: lambda input, other, out=None: -1,
  1162. torch.special.gammaincc: lambda input, other, out=None: -1,
  1163. torch.special.gammaln: lambda input: -1,
  1164. torch.special.hermite_polynomial_h: lambda input, n, out=None: -1,
  1165. torch.special.hermite_polynomial_he: lambda input, n, out=None: -1,
  1166. torch.special.i0: lambda input: -1,
  1167. torch.special.i0e: lambda input: -1,
  1168. torch.special.i1: lambda input: -1,
  1169. torch.special.i1e: lambda input: -1,
  1170. torch.special.laguerre_polynomial_l: lambda input, n, out=None: -1,
  1171. torch.special.legendre_polynomial_p: lambda input, n, out=None: -1,
  1172. torch.special.log1p: lambda input: -1,
  1173. torch.special.log_ndtr: lambda input: -1,
  1174. torch.special.log_softmax: lambda input, dim, dtype=None: -1,
  1175. torch.special.logit: lambda input: -1,
  1176. torch.special.logsumexp: lambda input, dim, keepdim=False, out=None: -1,
  1177. torch.special.modified_bessel_i0: lambda input: -1,
  1178. torch.special.modified_bessel_i1: lambda input: -1,
  1179. torch.special.modified_bessel_k0: lambda input: -1,
  1180. torch.special.modified_bessel_k1: lambda input: -1,
  1181. torch.special.multigammaln: lambda input, p: -1,
  1182. torch.special.ndtr: lambda input: -1,
  1183. torch.special.ndtri: lambda input: -1,
  1184. torch.special.polygamma: lambda input, n, out=None: -1,
  1185. torch.special.psi: lambda input: -1,
  1186. torch.special.round: lambda input: -1,
  1187. torch.special.scaled_modified_bessel_k0: lambda input: -1,
  1188. torch.special.scaled_modified_bessel_k1: lambda input: -1,
  1189. torch.special.shifted_chebyshev_polynomial_t: lambda input, n, out=None: -1,
  1190. torch.special.shifted_chebyshev_polynomial_u: lambda input, n, out=None: -1,
  1191. torch.special.shifted_chebyshev_polynomial_v: lambda input, n, out=None: -1,
  1192. torch.special.shifted_chebyshev_polynomial_w: lambda input, n, out=None: -1,
  1193. torch.special.sinc: lambda input: -1,
  1194. torch.special.softmax: lambda input, dim, dtype=None: -1,
  1195. torch.special.spherical_bessel_j0: lambda input: -1,
  1196. torch.special.xlog1py: lambda input, other, out=None: -1,
  1197. torch.special.xlogy: lambda input, other, out=None: -1,
  1198. torch.special.zeta: lambda self, other, out=None: -1,
  1199. torch.t: lambda input: -1,
  1200. torch.take: lambda input, index: -1,
  1201. torch.take_along_dim: lambda input, indices, dim=None, out=None: -1,
  1202. torch.tan: lambda input, out=None: -1,
  1203. torch.tanh: lambda input, out=None: -1,
  1204. torch.linalg.tensorinv: lambda a, ind=2: -1,
  1205. torch.linalg.tensorsolve: lambda a, b, dims=None: -1,
  1206. torch.tensordot: lambda a, b, dims=2, out=None: -1,
  1207. torch.tensor_split: lambda input, indices_or_sections, dim=0: -1,
  1208. torch.threshold: lambda input, threshold, value, inplace=False: -1,
  1209. torch.tile: lambda input, dims: -1,
  1210. torch.topk: lambda input, k, dim=-1, descending=False, out=None: -1,
  1211. torch.trace: lambda input: -1,
  1212. torch.transpose: lambda input, dim0, dim1: -1,
  1213. torch.trapz: lambda y, x=None, dim=-1: -1,
  1214. torch.trapezoid: lambda y, x=None, dim=-1: -1,
  1215. torch.triangular_solve: lambda input, A, upper=True, transpose=False, unitriangular=False: -1,
  1216. torch.linalg.solve_triangular: lambda input, B, upper, left=True, unitriangular=False: -1,
  1217. torch.tril: lambda input, diagonal=0, out=None: -1,
  1218. torch.triplet_margin_loss: (
  1219. lambda anchor, positive, negative, margin=1.0, p=2, eps=1e-06, swap=False, size_average=None, reduce=None, reduction="mean": -1 # noqa: B950
  1220. ),
  1221. torch.triu: lambda input, diagonal=0, out=None: -1,
  1222. torch.true_divide: lambda input, other: -1,
  1223. torch.trunc: lambda input, out=None: -1,
  1224. torch.unbind: lambda input, dim=0: -1,
  1225. torch.unflatten: lambda input, dim, sizes, names: -1,
  1226. torch.unique: lambda input, sorted=True, return_inverse=False, return_counts=False, dim=None: -1,
  1227. torch.unique_consecutive: lambda input, return_inverse=False, return_counts=False, dim=None: -1,
  1228. torch.unravel_index: lambda indices, shape: -1,
  1229. torch.unsafe_chunk: lambda input, chunks, dim=0: -1,
  1230. torch.unsafe_split: lambda tensor, split_size_or_sections, dim=0: -1,
  1231. torch.unsafe_split_with_sizes: lambda tensor, split_size_or_sections, dim=0: -1,
  1232. torch.unsqueeze: lambda input, dim, out=None: -1,
  1233. torch.linalg.vander: lambda x, N=None: -1,
  1234. torch.var: lambda input, dim=None: -1,
  1235. torch.var_mean: lambda input, dim=None: -1,
  1236. torch.vsplit: lambda input, indices_or_sections: -1,
  1237. torch.vstack: lambda tensors, out=None: -1,
  1238. torch.where: lambda condition, x=None, y=None: -1,
  1239. torch._wrapped_linear_prepack: lambda weight, weight_scale, weight_zero_point, bias : -1,
  1240. torch._wrapped_quantized_linear_prepacked: (
  1241. lambda input, input_scale, input_zero_point, prepacked, out_scale, out_zero_point, out_channel : -1 # noqa: B950
  1242. ),
  1243. torch.zeros_like: lambda input, dtype=None, layout=None, device=None, requires_grad=False: -1,
  1244. torch._fw_primal_copy: lambda self, level: -1,
  1245. torch._make_dual_copy: lambda primal, tangent, level: -1,
  1246. torch.view_as_real_copy: lambda self: -1,
  1247. torch.view_as_complex_copy: lambda self: -1,
  1248. torch._conj_copy: lambda self: -1,
  1249. torch._neg_view_copy: lambda self: -1,
  1250. torch.as_strided_copy: lambda self, size, stride, storage_offset=None: -1,
  1251. torch._sparse_broadcast_to_copy: lambda self, size: -1,
  1252. torch.diagonal_copy: lambda self, offset=0, dim1=0, dim2=1: -1,
  1253. torch.expand_copy: lambda self, size, *, implicit=False: -1,
  1254. torch.narrow_copy: lambda self, dim, start, length: -1,
  1255. torch.permute_copy: lambda self, dims: -1,
  1256. torch._reshape_alias_copy: lambda self, size, stride: -1,
  1257. torch.select_copy: lambda self, dim, index: -1,
  1258. torch.detach_copy: lambda self: -1,
  1259. torch.slice_copy: lambda self, dim=0, start=None, end=None, step=1: -1,
  1260. torch.split_copy: lambda self, split_size, dim=0: -1,
  1261. torch.split_with_sizes_copy: lambda self, split_sizes, dim=0: -1,
  1262. torch.squeeze_copy: lambda self, dim: -1,
  1263. torch.t_copy: lambda self: -1,
  1264. torch.transpose_copy: lambda self, dim0, dim1: -1,
  1265. torch.unsqueeze_copy: lambda self, dim: -1,
  1266. torch._indices_copy: lambda self: -1,
  1267. torch._values_copy: lambda self: -1,
  1268. torch.indices_copy: lambda self: -1,
  1269. torch.values_copy: lambda self: -1,
  1270. torch.crow_indices_copy: lambda self: -1,
  1271. torch.col_indices_copy: lambda self: -1,
  1272. torch.ccol_indices_copy: lambda self: -1,
  1273. torch.row_indices_copy: lambda self: -1,
  1274. torch.unbind_copy: lambda self, dim=0: -1,
  1275. torch.view_copy: lambda self, dtype: -1,
  1276. torch.unfold_copy: lambda self, dimension, size, step: -1,
  1277. torch.alias_copy: lambda self: -1,
  1278. Tensor.__floordiv__: lambda self, other: -1,
  1279. Tensor.__rfloordiv__: lambda self, other: -1,
  1280. Tensor.__ifloordiv__: lambda self, other: -1,
  1281. Tensor.__truediv__: lambda self, other: -1,
  1282. Tensor.__rtruediv__: lambda self, other: -1,
  1283. Tensor.__itruediv__: lambda self, other: -1,
  1284. Tensor.__lshift__: lambda self, other: -1,
  1285. Tensor.__rlshift__: lambda self, other: -1,
  1286. Tensor.__ilshift__: lambda self, other: -1,
  1287. Tensor.__rshift__: lambda self, other: -1,
  1288. Tensor.__rrshift__: lambda self, other: -1,
  1289. Tensor.__irshift__: lambda self, other: -1,
  1290. Tensor.__and__: lambda self, other: -1,
  1291. Tensor.__or__: lambda self, other: -1,
  1292. Tensor.__xor__: lambda self, other: -1,
  1293. Tensor.__float__: lambda self: -1,
  1294. Tensor.__complex__: lambda self: -1,
  1295. Tensor.__array__: lambda self, dtype: -1,
  1296. Tensor.__bool__: lambda self: -1,
  1297. Tensor.__contains__: lambda self, other: -1,
  1298. Tensor.__neg__: lambda self: -1,
  1299. Tensor.__invert__: lambda self: -1,
  1300. Tensor.__mod__: lambda self, other: -1,
  1301. Tensor.__rmod__: lambda self, other: -1,
  1302. Tensor.__imod__: lambda self, other: -1,
  1303. Tensor.__array_wrap__: lambda self, array: -1,
  1304. Tensor.__getitem__: lambda self, idx: -1,
  1305. Tensor.__deepcopy__: lambda self, memo: -1,
  1306. Tensor.__int__: lambda self: -1,
  1307. Tensor.__long__: lambda self: -1,
  1308. Tensor.__index__: lambda self: -1,
  1309. Tensor.__len__: lambda self: -1,
  1310. Tensor.__format__: lambda self, format_spec: -1,
  1311. Tensor.__reduce_ex__: lambda self, proto: -1,
  1312. Tensor.__reversed__: lambda self: -1,
  1313. Tensor.__repr__: lambda self, *, tensor_contents=None: -1,
  1314. Tensor.__setitem__: lambda self, k, v: -1,
  1315. Tensor.__setstate__: lambda self, d: -1,
  1316. Tensor.T.__get__: lambda self: -1,
  1317. Tensor.H.__get__: lambda self: -1,
  1318. Tensor.mT.__get__: lambda self: -1,
  1319. Tensor.mH.__get__: lambda self: -1,
  1320. Tensor._backward_hooks.__get__: lambda self: -1,
  1321. Tensor._post_accumulate_grad_hooks.__get__: lambda self: -1,
  1322. Tensor._base.__get__: lambda self: -1,
  1323. Tensor._cdata.__get__: lambda self: -1,
  1324. Tensor.grad.__get__: lambda self: -1,
  1325. Tensor._grad.__get__: lambda self: -1,
  1326. Tensor._grad_fn.__get__: lambda self: -1,
  1327. Tensor.grad_fn.__get__: lambda self: -1,
  1328. Tensor._version.__get__: lambda self: -1,
  1329. Tensor._autocast_to_reduced_precision: lambda self, cuda_enabled, cpu_enabled, cuda_dtype, cpu_dtype: -1,
  1330. Tensor._autocast_to_full_precision: lambda self, cuda_enabled, cpu_enabled: -1,
  1331. Tensor._clear_non_serializable_cached_data: lambda self: -1,
  1332. Tensor.data.__get__: lambda self: -1,
  1333. Tensor.device.__get__: lambda self: -1,
  1334. Tensor.dtype.__get__: lambda self: -1,
  1335. Tensor.is_cuda.__get__: lambda self: -1,
  1336. Tensor.is_cpu.__get__: lambda self: -1,
  1337. Tensor.is_xla.__get__: lambda self: -1,
  1338. Tensor.is_xpu.__get__: lambda self: -1,
  1339. Tensor.is_ipu.__get__: lambda self: -1,
  1340. Tensor.is_leaf.__get__: lambda self: -1,
  1341. Tensor.retains_grad.__get__: lambda self: -1,
  1342. Tensor.is_meta.__get__: lambda self: -1,
  1343. Tensor.is_mps.__get__: lambda self: -1,
  1344. Tensor.is_mtia.__get__: lambda self: -1,
  1345. Tensor.is_nested.__get__: lambda self: -1,
  1346. Tensor.is_maia.__get__: lambda self: -1,
  1347. Tensor.is_mkldnn.__get__: lambda self: -1,
  1348. Tensor.is_quantized.__get__: lambda self: -1,
  1349. Tensor.is_sparse.__get__: lambda self: -1,
  1350. Tensor.is_sparse_csr.__get__: lambda self: -1,
  1351. Tensor.is_vulkan.__get__: lambda self: -1,
  1352. Tensor.itemsize.__get__: lambda self: -1,
  1353. Tensor.layout.__get__: lambda self: -1,
  1354. Tensor.name.__get__: lambda self: -1,
  1355. Tensor.names.__get__: lambda self: -1,
  1356. Tensor.nbytes.__get__: lambda self: -1,
  1357. Tensor.ndim.__get__: lambda self: -1,
  1358. Tensor.output_nr.__get__: lambda self: -1,
  1359. Tensor.requires_grad.__get__: lambda self: -1,
  1360. Tensor.shape.__get__: lambda self: -1,
  1361. Tensor.volatile.__get__: lambda self: -1,
  1362. Tensor.real.__get__: lambda self: -1,
  1363. Tensor.imag.__get__: lambda self: -1,
  1364. Tensor.__cuda_array_interface__.__get__: lambda self: -1,
  1365. Tensor.type: lambda self, dtype=None, non_blocking=False, **kwargs: -1,
  1366. Tensor._dimI: lambda self: -1,
  1367. Tensor._dimV: lambda self: -1,
  1368. Tensor._indices: lambda self: -1,
  1369. Tensor._is_view: lambda self: -1,
  1370. Tensor._nnz: lambda self: -1,
  1371. Tensor.crow_indices: lambda self: -1,
  1372. Tensor.col_indices: lambda self: -1,
  1373. Tensor.ccol_indices: lambda self: -1,
  1374. Tensor.row_indices: lambda self: -1,
  1375. Tensor._update_names: lambda self, names, inplace: -1,
  1376. Tensor._values: lambda self: -1,
  1377. Tensor.adjoint: lambda self: -1,
  1378. Tensor.align_as: lambda self, other: -1,
  1379. Tensor.align_to: lambda self, order, ellipsis_idx: -1,
  1380. Tensor.apply_: lambda self, callable: -1,
  1381. Tensor.as_strided: lambda self, size, stride: -1,
  1382. Tensor.as_strided_: lambda self, size, stride: -1,
  1383. Tensor.backward: lambda self, gradient=None, retain_graph=None, create_graph=False, inputs=None: -1,
  1384. Tensor.bfloat16: lambda self, memory_format=torch.preserve_format: -1,
  1385. Tensor.bool: lambda self, memory_format=torch.preserve_format: -1,
  1386. Tensor.byte: lambda self, memory_format=torch.preserve_format: -1,
  1387. Tensor.char: lambda self, memory_format=torch.preserve_format: -1,
  1388. Tensor.cauchy_: lambda self, median=0, sigma=1, *, generator=None: -1,
  1389. Tensor.coalesce: lambda self: -1,
  1390. Tensor._coalesced_: lambda self, coalesced: -1,
  1391. Tensor.contiguous: lambda self, memory_format=torch.contiguous_format: -1,
  1392. Tensor.copy_: lambda self, src, non_blocking=False: -1,
  1393. Tensor.cpu: lambda self, memory_format=torch.preserve_format: -1,
  1394. Tensor.cuda: lambda self, memory_format=torch.preserve_format: -1,
  1395. Tensor.mtia: lambda self, memory_format=torch.preserve_format: -1,
  1396. Tensor.xpu: lambda self, memory_format=torch.preserve_format: -1,
  1397. Tensor.ipu: lambda self, memory_format=torch.preserve_format: -1,
  1398. Tensor.data_ptr: lambda self: -1,
  1399. Tensor.dense_dim: lambda self: -1,
  1400. Tensor.diagonal_scatter: lambda self, src, offset=0, dim1=0, dim2=1: -1,
  1401. Tensor.dim: lambda self: -1,
  1402. Tensor.dim_order: lambda self, ambiguity_check=False: -1,
  1403. Tensor.double: lambda self, memory_format=torch.preserve_format: -1,
  1404. Tensor.cdouble: lambda self, memory_format=torch.preserve_format: -1,
  1405. Tensor.element_size: lambda self: -1,
  1406. Tensor.expand: lambda self, size: -1,
  1407. Tensor.expand_as: lambda self, other: -1,
  1408. Tensor.exponential_: lambda self, lambd=1, *, generator=None: -1,
  1409. Tensor.fill_: lambda self, value: -1,
  1410. Tensor.fill_diagonal_: lambda self, value: -1,
  1411. Tensor.float: lambda self, memory_format=torch.preserve_format: -1,
  1412. Tensor.cfloat: lambda self, memory_format=torch.preserve_format: -1,
  1413. Tensor.geometric_: lambda self, p, *, generator=None: -1,
  1414. Tensor.get_device: lambda self: -1,
  1415. Tensor.half: lambda self, memory_format=torch.preserve_format: -1,
  1416. Tensor.chalf: lambda self, memory_format=torch.preserve_format: -1,
  1417. Tensor.has_names: lambda self: -1,
  1418. Tensor.indices: lambda self: -1,
  1419. Tensor.int: lambda self, memory_format=torch.preserve_format: -1,
  1420. Tensor.is_coalesced: lambda self: -1,
  1421. Tensor.is_contiguous: lambda self: -1,
  1422. Tensor.is_inference: lambda self: -1,
  1423. Tensor.is_pinned: lambda self: -1,
  1424. Tensor.is_set_to: lambda self, tensor: -1,
  1425. Tensor.is_shared: lambda self: -1,
  1426. Tensor.item: lambda self: -1,
  1427. Tensor.log_normal_: lambda self, mean=1, std=2, *, generator=None: -1,
  1428. Tensor.log_softmax: lambda self, dim: -1,
  1429. Tensor.long: lambda self, memory_format=torch.preserve_format: -1,
  1430. Tensor.map_: lambda self, tensor, callable: -1,
  1431. Tensor.map2_: lambda self, x, y, callable: -1,
  1432. Tensor.mm: lambda self, mat2, out_dtype=None: -1,
  1433. Tensor.module_load: lambda self, other, assign=False: -1,
  1434. Tensor.narrow_copy: lambda self, dimension, start, length: -1,
  1435. Tensor.ndimension: lambda self: -1,
  1436. Tensor.nelement: lambda self: -1,
  1437. Tensor._nested_tensor_size: lambda self: -1,
  1438. Tensor._nested_tensor_storage_offsets: lambda self: -1,
  1439. Tensor._nested_tensor_strides: lambda self: -1,
  1440. Tensor.normal_: lambda self: -1,
  1441. Tensor.numpy: lambda self: -1,
  1442. Tensor.permute: lambda self, dim: -1,
  1443. Tensor.pin_memory: lambda self: -1,
  1444. Tensor.put_: lambda self, indices, tensor, accumulate=False: -1,
  1445. Tensor.qscheme: lambda self: -1,
  1446. Tensor.random_: lambda self, from_=0, to=None, *, generator=None: -1,
  1447. Tensor.record_stream: lambda self, stream: -1,
  1448. Tensor.refine_names: lambda self, names: -1,
  1449. Tensor.register_hook: lambda self, hook: -1,
  1450. Tensor.register_post_accumulate_grad_hook: lambda self, hook: -1,
  1451. Tensor.rename: lambda self, name: -1,
  1452. Tensor.repeat: lambda self, *size: -1,
  1453. Tensor.requires_grad_: lambda self, requires_grad=True: -1,
  1454. Tensor.reshape_as: lambda self, other: -1,
  1455. Tensor.resize: lambda self, *size: -1,
  1456. Tensor.resize_: lambda self, size: -1,
  1457. Tensor.resize_as: lambda self, other: -1,
  1458. Tensor.resize_as_sparse_: lambda self, other: -1,
  1459. Tensor.retain_grad: lambda self: -1,
  1460. Tensor.set_: lambda self, source=None, storage_offset=0, size=None, stride=None: -1,
  1461. Tensor.select_scatter: lambda self, src, dim, index: -1,
  1462. Tensor.share_memory_: lambda self: -1,
  1463. Tensor.short: lambda self, memory_format=torch.preserve_format: -1,
  1464. Tensor.size: lambda self: -1,
  1465. Tensor.slice_scatter: lambda self, src, dim=0, start=None, end=None, step=1: -1,
  1466. Tensor.sparse_dim: lambda self: -1,
  1467. Tensor.sparse_mask: lambda self, mask: -1,
  1468. Tensor._sparse_mask_projection: lambda self, mask, accumulate_matches=False: -1,
  1469. Tensor.sparse_resize_: lambda self, size1, size2, dense_dim: -1,
  1470. Tensor.sparse_resize_and_clear_: lambda self, size1, size2, dense_dim: -1,
  1471. Tensor.sspaddmm: lambda self, mat1, mat2, beta=1, alpha=1, out=None: -1,
  1472. Tensor.storage: lambda self: -1,
  1473. Tensor.untyped_storage: lambda self: -1,
  1474. Tensor.storage_offset: lambda self: -1,
  1475. Tensor.storage_type: lambda self: -1,
  1476. Tensor.sum_to_size: lambda self, size: -1,
  1477. Tensor.tile: lambda self, *reps: -1,
  1478. Tensor.to: lambda self, dtype, non_blocking=False, copy=False, memory_format=torch.preserve_format: -1,
  1479. Tensor.to_dense: lambda self, dtype=None, *, masked_grad=None: -1,
  1480. Tensor._to_dense: lambda self, dtype=None, masked_grad=None: -1,
  1481. Tensor.to_sparse: lambda self: -1,
  1482. Tensor.tolist: lambda self: -1,
  1483. Tensor.to_mkldnn: lambda self: -1,
  1484. Tensor.type_as: lambda self, other: -1,
  1485. Tensor.unfold: lambda self, dimension, size, step: -1,
  1486. Tensor.uniform_: lambda self, from_=0, to=1: -1,
  1487. Tensor.values: lambda self: -1,
  1488. Tensor.view: lambda self, shape: -1,
  1489. Tensor.view_as: lambda self, other: -1,
  1490. Tensor.zero_: lambda self: -1,
  1491. Tensor.__dlpack__: lambda self, stream=None, max_version=None, dl_device=None, copy=None: -1,
  1492. Tensor.__dlpack_device__: lambda self: -1,
  1493. torch.linalg.lstsq: lambda self, b, cond=None, driver=None: -1,
  1494. } # fmt: skip
  1495. privateuse1_backend_name = (
  1496. torch.utils.backend_registration._privateuse1_backend_name
  1497. )
  1498. if hasattr(Tensor, privateuse1_backend_name):
  1499. ret[getattr(Tensor, privateuse1_backend_name)] = (
  1500. lambda self, device=None, non_blocking=False, **kwargs: -1
  1501. )
  1502. ret[getattr(Tensor, f"is_{privateuse1_backend_name}").__get__] = lambda self: -1
  1503. ret2 = {}
  1504. ignored = get_ignored_functions()
  1505. for k, v in ret.items():
  1506. # Generate methods like __add__ and add_ by default from add
  1507. names = [
  1508. k.__name__, # Default method
  1509. k.__name__ + "_", # Inplace variant
  1510. "__" + k.__name__ + "__", # Dunder method
  1511. "__i" + k.__name__ + "__", # Inplace dunder method
  1512. "__r" + k.__name__ + "__", # Reverse dunder method
  1513. ]
  1514. if k.__name__.startswith("bitwise_"):
  1515. # bitwise_<op> have dunder methods of the form __<op>__
  1516. # And so on.
  1517. subname = k.__name__[len("bitwise_") :]
  1518. names.extend(
  1519. ["__" + subname + "__", "__i" + subname + "__", "__r" + subname + "__"]
  1520. )
  1521. for name in names:
  1522. func = getattr(Tensor, name, None)
  1523. if callable(func) and func not in ret and func not in ignored:
  1524. ret2[func] = v
  1525. ret.update(ret2)
  1526. return ret
  1527. def wrap_torch_function(dispatcher: Callable):
  1528. """Wraps a given function with ``__torch_function__`` -related functionality.
  1529. Parameters
  1530. ----------
  1531. dispatcher: Callable
  1532. A callable that returns an iterable of Tensor-likes passed into the function.
  1533. Note
  1534. ----
  1535. This decorator may reduce the performance of your code. Generally, it's enough to express
  1536. your code as a series of functions that, themselves, support __torch_function__. If you
  1537. find yourself in the rare situation where this is not the case, e.g. if you're wrapping a
  1538. low-level library and you also need it to work for Tensor-likes, then this function is available.
  1539. Examples
  1540. --------
  1541. >>> def dispatcher(a): # Must have the same signature as func
  1542. ... return (a,)
  1543. >>> @torch.overrides.wrap_torch_function(dispatcher)
  1544. >>> def func(a): # This will make func dispatchable by __torch_function__
  1545. ... return a + 0
  1546. """
  1547. def inner(func):
  1548. @functools.wraps(func)
  1549. def wrapped(*args, **kwargs):
  1550. relevant_args = dispatcher(*args, **kwargs)
  1551. if has_torch_function(relevant_args):
  1552. return handle_torch_function(wrapped, relevant_args, *args, **kwargs)
  1553. return func(*args, **kwargs)
  1554. return wrapped
  1555. return inner
  1556. def _get_overloaded_args(
  1557. relevant_args: Iterable[Any],
  1558. get_type_fn: Optional[Callable[[Any], type]] = None,
  1559. ) -> list[Any]:
  1560. """Returns a list of arguments on which to call __torch_function__.
  1561. Checks arguments in relevant_args for __torch_function__ implementations,
  1562. storing references to the arguments and their types in overloaded_args and
  1563. overloaded_types in order of calling precedence. Only distinct types are
  1564. considered. If a type is a subclass of another type it will have higher
  1565. precedence, otherwise the precedence order is the same as the order of
  1566. arguments in relevant_args, that is, from left-to-right in the argument list.
  1567. The precedence-determining algorithm implemented in this function is
  1568. described in `NEP-0018`_.
  1569. See torch::append_overloaded_arg for the equivalent function in the C++
  1570. implementation.
  1571. Parameters
  1572. ----------
  1573. relevant_args : iterable of array-like
  1574. Iterable of array-like arguments to check for __torch_function__
  1575. methods.
  1576. get_type_fn : callable, optional
  1577. Function to call on each argument in relevant_args to get its type.
  1578. Returns
  1579. -------
  1580. overloaded_args : list
  1581. Arguments from relevant_args on which to call __torch_function__
  1582. methods, in the order in which they should be called.
  1583. .. _NEP-0018:
  1584. https://numpy.org/neps/nep-0018-array-function-protocol.html
  1585. """
  1586. if get_type_fn is None:
  1587. get_type_fn = type
  1588. # If torch function is not enabled, there are no overloaded types
  1589. if not torch._C._is_torch_function_enabled():
  1590. return []
  1591. # Runtime is O(num_arguments * num_unique_types)
  1592. overloaded_types: set[type] = set()
  1593. overloaded_args: list[Any] = []
  1594. for arg in relevant_args:
  1595. arg_type = get_type_fn(arg)
  1596. # We only collect arguments if they have a unique type, which ensures
  1597. # reasonable performance even with a long list of possibly overloaded
  1598. # arguments.
  1599. #
  1600. # NB: Important to exclude _disabled_torch_function_impl, otherwise
  1601. # https://github.com/pytorch/pytorch/issues/64687
  1602. if (
  1603. arg_type not in overloaded_types
  1604. and hasattr(arg_type, "__torch_function__")
  1605. and arg_type.__torch_function__ != torch._C._disabled_torch_function_impl
  1606. ):
  1607. # Create lists explicitly for the first type (usually the only one
  1608. # done) to avoid setting up the iterator for overloaded_args.
  1609. if overloaded_types:
  1610. overloaded_types.add(arg_type)
  1611. # By default, insert argument at the end, but if it is
  1612. # subclass of another argument, insert it before that argument.
  1613. # This ensures "subclasses before superclasses".
  1614. index = len(overloaded_args)
  1615. for i, old_arg in enumerate(overloaded_args):
  1616. if issubclass(arg_type, get_type_fn(old_arg)):
  1617. index = i
  1618. break
  1619. overloaded_args.insert(index, arg)
  1620. else:
  1621. overloaded_types = {arg_type}
  1622. overloaded_args = [arg]
  1623. return overloaded_args
  1624. def handle_torch_function(
  1625. public_api: Callable,
  1626. relevant_args: Iterable[Any],
  1627. *args,
  1628. **kwargs,
  1629. ) -> Any:
  1630. """Implement a function with checks for ``__torch_function__`` overrides.
  1631. See torch::autograd::handle_torch_function for the equivalent of this
  1632. function in the C++ implementation.
  1633. Arguments
  1634. ---------
  1635. public_api : function
  1636. Function exposed by the public torch API originally called like
  1637. ``public_api(*args, **kwargs)`` on which arguments are now being
  1638. checked.
  1639. relevant_args : iterable
  1640. Iterable of arguments to check for __torch_function__ methods.
  1641. args : tuple
  1642. Arbitrary positional arguments originally passed into ``public_api``.
  1643. kwargs : tuple
  1644. Arbitrary keyword arguments originally passed into ``public_api``.
  1645. Returns
  1646. -------
  1647. object
  1648. Result from calling ``implementation`` or an ``__torch_function__``
  1649. method, as appropriate.
  1650. Raises
  1651. ------
  1652. TypeError : if no implementation is found.
  1653. Example
  1654. -------
  1655. >>> def func(a):
  1656. ... if has_torch_function_unary(a):
  1657. ... return handle_torch_function(func, (a,), a)
  1658. ... return a + 0
  1659. """
  1660. # Check for __torch_function__ methods.
  1661. overloaded_args = _get_overloaded_args(relevant_args)
  1662. # overloaded_args already have unique types.
  1663. types = tuple(map(type, overloaded_args))
  1664. # Check for __torch_function__ mode.
  1665. if _is_torch_function_mode_enabled():
  1666. # if we're here, the mode must be set to a TorchFunctionStackMode
  1667. # this unsets it and calls directly into TorchFunctionStackMode's torch function
  1668. with _pop_mode_temporarily() as mode:
  1669. result = mode.__torch_function__(public_api, types, args, kwargs)
  1670. if result is not NotImplemented:
  1671. return result
  1672. # Call overrides
  1673. for overloaded_arg in overloaded_args:
  1674. # This call needs to become a classmethod call in the future.
  1675. # See https://github.com/pytorch/pytorch/issues/63767
  1676. torch_func_method = overloaded_arg.__torch_function__
  1677. if (
  1678. hasattr(torch_func_method, "__self__")
  1679. and torch_func_method.__self__ is overloaded_arg
  1680. and torch_func_method is not torch._C._disabled_torch_function_impl
  1681. ):
  1682. warnings.warn(
  1683. "Defining your `__torch_function__ as a plain method is deprecated and "
  1684. "will be an error in future, please define it as a classmethod.",
  1685. DeprecationWarning,
  1686. )
  1687. # Use `public_api` instead of `implementation` so __torch_function__
  1688. # implementations can do equality/identity comparisons.
  1689. result = torch_func_method(public_api, types, args, kwargs)
  1690. if result is not NotImplemented:
  1691. return result
  1692. func_name = f"{public_api.__module__}.{public_api.__name__}"
  1693. msg = (
  1694. f"no implementation found for '{func_name}' on types that implement "
  1695. f"__torch_function__: {[type(arg) for arg in overloaded_args]}"
  1696. )
  1697. if _is_torch_function_mode_enabled():
  1698. msg += f" nor in mode {_get_current_function_mode()}"
  1699. raise TypeError(msg)
  1700. has_torch_function = _add_docstr(
  1701. _has_torch_function,
  1702. r"""Check for __torch_function__ implementations in the elements of an iterable
  1703. or if a __torch_function__ mode is enabled. Considers exact ``Tensor`` s
  1704. and ``Parameter`` s non-dispatchable. Use this to guard a call to
  1705. :func:`handle_torch_function`; don't use it to test if something
  1706. is Tensor-like, use :func:`is_tensor_like` instead.
  1707. Arguments
  1708. ---------
  1709. relevant_args : iterable
  1710. Iterable or arguments to check for __torch_function__ methods.
  1711. Returns
  1712. -------
  1713. bool
  1714. True if any of the elements of relevant_args have __torch_function__
  1715. implementations, False otherwise.
  1716. See Also
  1717. ________
  1718. torch.is_tensor_like
  1719. Checks if something is a Tensor-like, including an exact ``Tensor``.
  1720. """,
  1721. )
  1722. has_torch_function_unary = _add_docstr(
  1723. _has_torch_function_unary,
  1724. r"""Special case of `has_torch_function` for single inputs.
  1725. Instead of:
  1726. `has_torch_function((t,))`
  1727. call:
  1728. `has_torch_function_unary(t)`
  1729. which skips unnecessary packing and unpacking work.
  1730. """,
  1731. )
  1732. has_torch_function_variadic = _add_docstr(
  1733. _has_torch_function_variadic,
  1734. r"""Special case of `has_torch_function` that skips tuple creation.
  1735. This uses the METH_FASTCALL protocol introduced in Python 3.7
  1736. Instead of:
  1737. `has_torch_function((a, b))`
  1738. call:
  1739. `has_torch_function_variadic(a, b)`
  1740. which skips unnecessary packing and unpacking work.
  1741. """,
  1742. )
  1743. @functools.cache
  1744. def _get_overridable_functions() -> tuple[
  1745. dict[Any, list[Callable]], dict[Callable, str]
  1746. ]:
  1747. overridable_funcs = collections.defaultdict(list)
  1748. index = {}
  1749. tested_namespaces = [
  1750. ("torch", torch, torch.__all__),
  1751. ("torch.functional", torch.functional, torch.functional.__all__),
  1752. ("torch.nn.functional", torch.nn.functional, dir(torch.nn.functional)),
  1753. ("torch.nn.init", torch.nn.init, dir(torch.nn.init)),
  1754. ("torch.Tensor", torch.Tensor, dir(torch.Tensor)),
  1755. ("torch.linalg", torch.linalg, dir(torch.linalg)),
  1756. ("torch.fft", torch.fft, dir(torch.fft)),
  1757. ("torch.special", torch.special, dir(torch.special)),
  1758. ]
  1759. for namespace_str, namespace, ns_funcs in tested_namespaces:
  1760. for func_name in ns_funcs:
  1761. ignore = False
  1762. # ignore private functions or functions that are deleted in torch.__init__
  1763. if namespace is not torch.Tensor:
  1764. if func_name.startswith("__"):
  1765. continue
  1766. elif func_name.startswith("_"):
  1767. ignore = True
  1768. elif func_name.endswith("_"):
  1769. ignore = True
  1770. elif not func_name[0].islower():
  1771. ignore = True
  1772. elif func_name == "unique_dim":
  1773. continue
  1774. else:
  1775. func = getattr(namespace, func_name)
  1776. if getattr(object, func_name, None) == func:
  1777. continue
  1778. if func_name == "__weakref__":
  1779. continue
  1780. func = getattr(namespace, func_name)
  1781. if namespace is torch.Tensor and getattr(object, func_name, None) == func:
  1782. continue
  1783. # ignore re-exported modules
  1784. if isinstance(func, types.ModuleType):
  1785. continue
  1786. # ignore __future__ imports
  1787. if isinstance(func, __future__._Feature):
  1788. continue
  1789. if not callable(func) and hasattr(func, "__get__"):
  1790. index[func.__get__] = f"{namespace_str}.{func_name}.__get__"
  1791. index[func.__set__] = f"{namespace_str}.{func_name}.__set__"
  1792. if ignore:
  1793. continue
  1794. if func.__get__ in get_ignored_functions():
  1795. msg = (
  1796. "{}.{} is in the tuple returned by torch._overrides.get_ignored_functions "
  1797. "but still has an explicit override"
  1798. )
  1799. assert func.__get__ not in get_testing_overrides(), msg.format(
  1800. namespace, func.__name__
  1801. )
  1802. continue
  1803. else:
  1804. overridable_funcs[func].append(func.__get__)
  1805. continue
  1806. if not callable(func):
  1807. continue
  1808. index[func] = f"{namespace_str}.{func_name}"
  1809. if ignore:
  1810. continue
  1811. # cannot be overridden by __torch_function__
  1812. if func in get_ignored_functions():
  1813. msg = (
  1814. "{}.{} is in the tuple returned by torch._overrides.get_ignored_functions "
  1815. "but still has an explicit override"
  1816. )
  1817. assert func not in get_testing_overrides(), msg.format(
  1818. namespace, func.__name__
  1819. )
  1820. continue
  1821. overridable_funcs[namespace].append(func)
  1822. return overridable_funcs, index
  1823. @_disable_user_warnings
  1824. def get_overridable_functions() -> dict[Any, list[Callable]]:
  1825. """List functions that are overridable via __torch_function__
  1826. Returns
  1827. -------
  1828. Dict[Any, List[Callable]]
  1829. A dictionary that maps namespaces that contain overridable functions
  1830. to functions in that namespace that can be overridden.
  1831. """
  1832. return _get_overridable_functions()[0]
  1833. @_disable_user_warnings
  1834. def resolve_name(f):
  1835. """Get a human readable string name for a function passed to
  1836. __torch_function__
  1837. Arguments
  1838. ---------
  1839. f : Callable
  1840. Function to resolve the name of.
  1841. Returns
  1842. -------
  1843. str
  1844. Name of the function; if eval'ed it should give back the input
  1845. function.
  1846. """
  1847. if isinstance(f, (torch._ops.OpOverload, torch._ops.OpOverloadPacket)):
  1848. return str(f)
  1849. return _get_overridable_functions()[1].get(f)
  1850. @functools.cache
  1851. def _get_tensor_methods() -> set[Callable]:
  1852. """Returns a set of the overridable methods on ``torch.Tensor``"""
  1853. overridable_funcs = get_overridable_functions()
  1854. methods = set(overridable_funcs[torch.Tensor])
  1855. return methods
  1856. @_disable_user_warnings
  1857. def is_tensor_method_or_property(func: Callable) -> bool:
  1858. """
  1859. Returns True if the function passed in is a handler for a
  1860. method or property belonging to ``torch.Tensor``, as passed
  1861. into ``__torch_function__``.
  1862. .. note::
  1863. For properties, their ``__get__`` method must be passed in.
  1864. This may be needed, in particular, for the following reasons:
  1865. 1. Methods/properties sometimes don't contain a `__module__` slot.
  1866. 2. They require that the first passed-in argument is an instance
  1867. of ``torch.Tensor``.
  1868. Examples
  1869. --------
  1870. >>> is_tensor_method_or_property(torch.Tensor.add)
  1871. True
  1872. >>> is_tensor_method_or_property(torch.add)
  1873. False
  1874. """
  1875. return func in _get_tensor_methods() or func.__name__ == "__get__"
  1876. def is_tensor_like(inp):
  1877. """
  1878. Returns ``True`` if the passed-in input is a Tensor-like.
  1879. Currently, this occurs whenever there's a ``__torch_function__``
  1880. attribute on the type of the input.
  1881. Examples
  1882. --------
  1883. A subclass of tensor is generally a Tensor-like.
  1884. >>> class SubTensor(torch.Tensor): ...
  1885. >>> is_tensor_like(SubTensor([0]))
  1886. True
  1887. Built-in or user types aren't usually Tensor-like.
  1888. >>> is_tensor_like(6)
  1889. False
  1890. >>> is_tensor_like(None)
  1891. False
  1892. >>> class NotATensor: ...
  1893. >>> is_tensor_like(NotATensor())
  1894. False
  1895. But, they can be made Tensor-like by implementing __torch_function__.
  1896. >>> class TensorLike:
  1897. ... @classmethod
  1898. ... def __torch_function__(cls, func, types, args, kwargs):
  1899. ... return -1
  1900. >>> is_tensor_like(TensorLike())
  1901. True
  1902. """
  1903. return type(inp) is torch.Tensor or hasattr(inp, "__torch_function__")
  1904. class TorchFunctionMode:
  1905. """
  1906. A ``TorchFunctionMode`` allows you to override the meaning of all
  1907. ``__torch_function__`` overridable functions within a dynamic scope,
  1908. without having to actually create a tensor subclass or manually
  1909. monkey-patch functions in the PyTorch API. Some common situations
  1910. where you should use a mode:
  1911. * You want to override the meaning of factory functions, or other
  1912. functions that do not otherwise take a tensor as an argument
  1913. (these cannot be overridden with tensor subclasses).
  1914. * You want to override the behavior of all functions without needing
  1915. to wrap your inputs in tensor subclasses; e.g., if you are just
  1916. interested in logging intermediate computations.
  1917. * You want to control the order of execution of various tensor
  1918. subclasses explicitly, rather than implicitly via the return of
  1919. ``NotImplemented``.
  1920. Independent subclasses of :class:`TorchFunctionMode` are compositional:
  1921. modes can be pushed onto a stack using ``with MyMode():``.
  1922. When you call functions in the PyTorch API inside your
  1923. ``__torch_function__`` implementation, by default, they will forward on to
  1924. the next mode on the mode stack. If you want recursively call back into
  1925. your current ``__torch_function__`` implementation, either explicitly
  1926. invoke ``self.__torch_function__(...)``, or use the context manager
  1927. ``enable_torch_function_mode(self, replace=self.inner)`` to make PyTorch
  1928. API self-referential (beware of infinite loops, in this case!)
  1929. """
  1930. inner: "TorchFunctionMode"
  1931. # Force metaclass to generate constructor at the base of the hierarchy
  1932. def __init__(self) -> None:
  1933. pass
  1934. def __torch_function__(self, func, types, args=(), kwargs=None):
  1935. raise NotImplementedError
  1936. def __enter__(self):
  1937. _push_mode(self)
  1938. return self
  1939. def __exit__(self, exc_type, exc_val, exc_tb):
  1940. _pop_mode()
  1941. @classmethod
  1942. def push(cls, *args, **kwargs):
  1943. warnings.warn(
  1944. "`Mode.push()` is no longer necessary and can be replaced with just `with Mode()`"
  1945. )
  1946. instance = cls(*args, **kwargs)
  1947. return instance
  1948. def _get_current_function_mode():
  1949. stack_len = _len_torch_function_stack()
  1950. return _get_function_stack_at(stack_len - 1) if stack_len > 0 else None
  1951. def _get_current_function_mode_stack():
  1952. stack_len = _len_torch_function_stack()
  1953. return [_get_function_stack_at(i) for i in range(stack_len)]
  1954. def _push_mode(mode):
  1955. _push_on_torch_function_stack(mode)
  1956. def _pop_mode():
  1957. old = _pop_torch_function_stack()
  1958. return old
  1959. @contextlib.contextmanager
  1960. def _pop_mode_temporarily():
  1961. old = _pop_mode()
  1962. try:
  1963. yield old
  1964. finally:
  1965. _push_mode(old)
  1966. class BaseTorchFunctionMode(TorchFunctionMode):
  1967. def __torch_function__(self, func, types, args=(), kwargs=None):
  1968. if kwargs is None:
  1969. kwargs = {}
  1970. return func(*args, **kwargs)
  1971. @contextlib.contextmanager
  1972. def _enable_torch_function():
  1973. old_state = torch._C._get_torch_function_state()
  1974. try:
  1975. torch._C._set_torch_function_state(torch._C._TorchFunctionState.ENABLED)
  1976. yield
  1977. finally:
  1978. torch._C._set_torch_function_state(old_state)
@contextlib.contextmanager
def enable_reentrant_dispatch():
    """Context manager that runs its body inside ``torch._C._RestorePythonTLSSnapshot()``.

    NOTE(review): what ``_RestorePythonTLSSnapshot`` restores is defined in the
    C++ extension and is not visible from this file.
    """
    # NB: this can't simply be
    # `enable_reentrant_dispatch = torch._C._RestorePythonTLSSnapshot`
    # because:
    # 1. torch._C._RestorePythonTLSSnapshot is unavailable when this file
    #    initially gets imported. Probably an import order thing.
    # 2. enable_reentrant_dispatch is technically public API; assigning
    #    it the object would change the __module__ to look private.
    with torch._C._RestorePythonTLSSnapshot():
        try:
            yield
        # NOTE(review): this empty ``finally`` does not change the result; it is
        # presumably kept on purpose (e.g. for tooling that inspects this
        # generator) -- confirm before removing.
        finally:
            pass