_tensor.py 73 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156115711581159116011611162116311641165116611671168116911701171117211731174117511761177117811791180118111821183118411851186118711881189119011911192119311941195119611971198119912001201120212031204120512061207120812091210121112121213121412151216121712181219122012211222122312241225122612271228122912301231123212331234123512361237123812391240124112421243124412451246124712481249125012511252125312541255125612571258125912601261126212631264126512661267126812691270127112721273127412751276127
71278127912801281128212831284128512861287128812891290129112921293129412951296129712981299130013011302130313041305130613071308130913101311131213131314131513161317131813191320132113221323132413251326132713281329133013311332133313341335133613371338133913401341134213431344134513461347134813491350135113521353135413551356135713581359136013611362136313641365136613671368136913701371137213731374137513761377137813791380138113821383138413851386138713881389139013911392139313941395139613971398139914001401140214031404140514061407140814091410141114121413141414151416141714181419142014211422142314241425142614271428142914301431143214331434143514361437143814391440144114421443144414451446144714481449145014511452145314541455145614571458145914601461146214631464146514661467146814691470147114721473147414751476147714781479148014811482148314841485148614871488148914901491149214931494149514961497149814991500150115021503150415051506150715081509151015111512151315141515151615171518151915201521152215231524152515261527152815291530153115321533153415351536153715381539154015411542154315441545154615471548154915501551155215531554155515561557155815591560156115621563156415651566156715681569157015711572157315741575157615771578157915801581158215831584158515861587158815891590159115921593159415951596159715981599160016011602160316041605160616071608160916101611161216131614161516161617161816191620162116221623162416251626162716281629163016311632163316341635163616371638163916401641164216431644164516461647164816491650165116521653165416551656165716581659166016611662166316641665166616671668166916701671167216731674167516761677167816791680168116821683168416851686168716881689169016911692169316941695169616971698169917001701170217031704170517061707170817091710171117121713171417151716171717181719172017211722172317241725172617271728172917301731173217331734173517361737173817391740174117421743174417451746174717481749175017511752175317541755175617571758175917601761176217631764176517661767176817691770177117721773177417751776177
7177817791780178117821783178417851786178717881789179017911792179317941795179617971798179918001801180218031804180518061807180818091810181118121813181418151816181718181819182018211822182318241825182618271828182918301831183218331834183518361837
  1. # mypy: allow-untyped-defs
  2. import copyreg
  3. import enum
  4. import functools
  5. import warnings
  6. from collections import OrderedDict
  7. from copy import deepcopy
  8. from numbers import Number
  9. from typing import Any, Callable, cast, Optional, TypeVar, Union
  10. from typing_extensions import Concatenate, ParamSpec
  11. import torch
  12. import torch._C as _C
  13. from torch._namedtensor_internals import (
  14. check_serializing_named_tensor,
  15. is_ellipsis,
  16. resolve_ellipsis,
  17. single_ellipsis_index,
  18. unzip_namedshape,
  19. update_names,
  20. )
  21. from torch.overrides import (
  22. get_default_nowrap_functions,
  23. handle_torch_function,
  24. has_torch_function,
  25. has_torch_function_unary,
  26. has_torch_function_variadic,
  27. )
  28. _P = ParamSpec("_P")
  29. _TensorLike = TypeVar("_TensorLike", bound=_C.TensorBase)
  30. def _handle_torch_function_and_wrap_type_error_to_not_implemented(
  31. f: Callable[Concatenate[_TensorLike, _P], "Tensor"],
  32. ) -> Callable[Concatenate[_TensorLike, _P], "Tensor"]:
  33. @functools.wraps(f)
  34. def wrapped(self: _TensorLike, *args: _P.args, **kwargs: _P.kwargs) -> "Tensor":
  35. try:
  36. # See https://github.com/pytorch/pytorch/issues/75462
  37. sargs = self, *args
  38. if has_torch_function(sargs):
  39. return handle_torch_function(wrapped, sargs, *sargs, **kwargs)
  40. return f(self, *args, **kwargs)
  41. except TypeError:
  42. return NotImplemented
  43. return wrapped
  44. # Should not be used, this is kept only for BC of loading old serialized Tensor subclasses
  45. def _rebuild_from_type(func, type, args, dict):
  46. if type is Tensor:
  47. return func(*args)
  48. ret = func(*args).as_subclass(type)
  49. ret.__dict__ = dict
  50. return ret
  51. def _rebuild_from_type_v2(func, new_type, args, state):
  52. ret = func(*args)
  53. if type(ret) is not new_type:
  54. ret = ret.as_subclass(new_type)
  55. # Tensor does define __setstate__ even though it doesn't define
  56. # __getstate__. So only use __setstate__ if it is NOT the one defined
  57. # on Tensor
  58. if (
  59. getattr(ret.__class__, "__setstate__", Tensor.__setstate__)
  60. is not Tensor.__setstate__
  61. ):
  62. ret.__setstate__(state)
  63. else:
  64. ret = torch._utils._set_obj_state(ret, state)
  65. return ret
  66. def _dtype_to_typestr(dtype):
  67. # CUDA devices are little-endian and tensors are stored in native byte
  68. # order. 1-byte entries are endian-agnostic.
  69. return {
  70. torch.complex64: "<c8",
  71. torch.complex128: "<c16",
  72. torch.bfloat16: "<V2", # Same as ml_dtypes.bfloat16.dtype.str.
  73. torch.float16: "<f2",
  74. torch.float32: "<f4",
  75. torch.float64: "<f8",
  76. torch.uint8: "|u1",
  77. torch.int8: "|i1",
  78. torch.uint16: "<u2",
  79. torch.int16: "<i2",
  80. torch.uint32: "<u4",
  81. torch.int32: "<i4",
  82. torch.uint64: "<u8",
  83. torch.int64: "<i8",
  84. torch.bool: "|b1",
  85. }[dtype]
  86. # NB: If you subclass Tensor, and want to share the subclassed class
  87. # across processes, you must also update torch/multiprocessing/reductions.py
  88. # to define a ForkingPickler serialization mode for the class.
  89. #
  90. # NB: If you add a new method to Tensor, you must update
  91. # torch/_C/__init__.pyi.in to add a type annotation for your method;
  92. # otherwise, it will not show up in autocomplete.
  93. class Tensor(torch._C.TensorBase):
  94. _is_param: bool
  95. def _clear_non_serializable_cached_data(self):
  96. r"""Clears any data cached in the tensor's ``__dict__`` that would prevent the tensor
  97. from being serialized.
  98. For example, subclasses with custom dispatched sizes / strides cache this info in
  99. non-serializable PyCapsules within the ``__dict__``, and this must be cleared out for
  100. serialization to function.
  101. Any subclass that overrides this MUST call ``super()._clear_non_serializable_cached_data().``
  102. Additional data cleared within the override must be able to be re-cached transparently
  103. to avoid breaking subclass functionality.
  104. """
  105. if has_torch_function_unary(self):
  106. return handle_torch_function(
  107. Tensor._clear_non_serializable_cached_data, (self,), self
  108. )
  109. # NB: Wrapper subclasses that implement custom-dispatched sizes / strides cache
  110. # this info via non-serializable PyCapsules.
  111. CACHED_SIZES_STRIDES_KEYS = [
  112. "_sym_sizes_capsule",
  113. "_sym_sizes_capsule_len",
  114. "_sym_strides_capsule",
  115. "_sym_strides_capsule_len",
  116. ]
  117. for key in CACHED_SIZES_STRIDES_KEYS:
  118. self.__dict__.pop(key, None)
    def __deepcopy__(self, memo):
        """Support ``copy.deepcopy`` for leaf tensors.

        Inside ``torch.no_grad()`` the copy is produced in one of three ways:

        * sparse tensors, tensors on the devices listed below, storage-less
          privateuse1 tensors, and subclass instances whose ``data_ptr()`` is 0
          are copied with ``clone()``;
        * quantized tensors are rebuilt with ``torch._utils._rebuild_qtensor``
          from a deep copy of their storage plus their quantizer parameters;
        * every other tensor gets a fresh ``new_empty([])`` that is ``set_`` onto
          a deep copy of the storage. Conj/neg bits are re-applied with
          ``conj_physical()`` / ``neg()``.

        The copy then receives ``requires_grad``, a deep-copied ``.grad``, the
        slots of subclass instances, and a deep copy of ``__dict__``. It is
        recorded in ``memo`` under ``id(self)`` before being returned.

        Args:
            memo: the ``copy.deepcopy`` memo dict, keyed by ``id()`` of objects
                already copied during this deepcopy call.

        Returns:
            The copied tensor. Its type is checked to match ``type(self)``.

        Raises:
            RuntimeError: for non-leaf tensors, unsupported qschemes, or when
                ``clone()`` / ``_rebuild_qtensor`` / ``new_empty()`` produce an
                object of a different type than ``self``.
        """
        if has_torch_function_unary(self):
            return handle_torch_function(Tensor.__deepcopy__, (self,), self, memo)
        if not self.is_leaf:
            raise RuntimeError(
                "Only Tensors created explicitly by the user "
                "(graph leaves) support the deepcopy protocol at the moment. "
                "If you were attempting to deepcopy a module, this may be because "
                "of a torch.nn.utils.weight_norm usage, "
                "see https://github.com/pytorch/pytorch/pull/103001"
            )
        # This tensor was already copied earlier in this deepcopy call; reuse it.
        if id(self) in memo:
            return memo[id(self)]
        with torch.no_grad():
            # TODO: skipping storage copy is wrong for meta, as meta
            # does accurate alias tracking; however, the code below
            # doesn't work because of
            # https://github.com/pytorch/pytorch/issues/47442
            # Update the test in test_serialization if you remove 'meta' from here
            if (
                self.is_sparse
                or self.device.type
                in ["lazy", "xla", "mtia", "mps", "maia", "meta", "ipu"]
                or (
                    not torch._C._has_storage(self)
                    and self.device.type == torch._C._get_privateuse1_backend_name()
                )
                or (type(self) is not Tensor and self.data_ptr() == 0)
            ):
                # Tensors whose storage is not copied directly: rely on clone().
                new_tensor = self.clone()
                if type(new_tensor) is not type(self):
                    raise RuntimeError(
                        "The default implementation of __deepcopy__() for wrapper subclasses "
                        "only works for subclass types that implement clone() and for which "
                        "cloning returns another instance of the same subclass. You should either "
                        "properly implement clone() for your subclass or override __deepcopy__() "
                        "if it is intended behavior for clone() to return an instance of a "
                        "different type."
                    )
            else:
                # Deep-copy the underlying storage, passing the same memo along
                # (presumably so that tensors sharing one storage end up sharing
                # one copied storage -- TODO confirm TypedStorage._deepcopy).
                new_storage = self._typed_storage()._deepcopy(memo)
                if self.is_quantized:
                    # quantizer_params can be different type based on torch attribute
                    quantizer_params: Union[
                        tuple[torch.qscheme, float, int],
                        tuple[torch.qscheme, Tensor, Tensor, int],
                    ]
                    if self.qscheme() == torch.per_tensor_affine:
                        quantizer_params = (
                            self.qscheme(),
                            self.q_scale(),
                            self.q_zero_point(),
                        )
                    elif self.qscheme() in (
                        torch.per_channel_affine,
                        torch.per_channel_affine_float_qparams,
                    ):
                        quantizer_params = (
                            self.qscheme(),
                            self.q_per_channel_scales(),
                            self.q_per_channel_zero_points(),
                            self.q_per_channel_axis(),
                        )
                    else:
                        raise RuntimeError(
                            f"Unsupported qscheme {self.qscheme()} in deepcopy"
                        )
                    # TODO: Once we decide to break serialization FC, no longer
                    # need to wrap with TypedStorage
                    new_tensor = torch._utils._rebuild_qtensor(
                        torch.storage.TypedStorage(
                            wrap_storage=new_storage._untyped_storage,
                            dtype=self.dtype,
                            _internal=True,
                        ),
                        self.storage_offset(),
                        self.size(),
                        self.stride(),
                        quantizer_params,
                        self.requires_grad,
                        self._backward_hooks,
                    )
                    if type(new_tensor) is not type(self):
                        raise RuntimeError(
                            "The default implementation of __deepcopy__() for quantized tensors "
                            "expects the tensor returned by torch._utils._rebuild_qtensor() to "
                            "match the type of the instance being copied. If you encounter this, "
                            "please open an issue on PyTorch's GitHub."
                        )
                else:
                    # Start from an empty tensor of the same type, then point it
                    # at the copied storage with self's offset/size/stride.
                    new_tensor = self.new_empty([])
                    if type(new_tensor) is not type(self):
                        raise RuntimeError(
                            "The default implementation of __deepcopy__() for non-wrapper subclasses "
                            "only works for subclass types that implement new_empty() and for which "
                            "that function returns another instance of the same subclass. You should "
                            "either properly implement new_empty() for your subclass or override "
                            "__deepcopy__() if it is intended behavior for new_empty() to return "
                            "an instance of a different type."
                        )
                    new_tensor.set_(
                        new_storage, self.storage_offset(), self.size(), self.stride()
                    )
                    # Re-apply the conj / neg bits of self on the copy.
                    if self.is_conj():
                        new_tensor = new_tensor.conj_physical()
                    if self.is_neg():
                        new_tensor = new_tensor.neg()
            if self.requires_grad:
                new_tensor.requires_grad_()
            if self.grad is not None:
                new_tensor.grad = self.grad.__deepcopy__(memo)

            if type(self) is not Tensor:
                if type(new_tensor) is not type(self):
                    raise RuntimeError(
                        "Type of deepcopy result does not match the type of the source tensor. "
                        "If you encounter this, please open an issue on PyTorch's GitHub."
                    )

                # Plain Tensors don't have slots
                slots_to_save = copyreg._slotnames(self.__class__)  # type: ignore[attr-defined]
                for slot in slots_to_save:
                    if hasattr(self, slot):
                        setattr(new_tensor, slot, deepcopy(getattr(self, slot), memo))

            # don't try to deepcopy non-serializable cached data
            self._clear_non_serializable_cached_data()
            new_tensor.__dict__ = deepcopy(self.__dict__, memo)

            # Record the copy so later references to self in this deepcopy reuse it.
            memo[id(self)] = new_tensor
            return new_tensor
  246. def __reduce_ex__(self, proto):
  247. materialize_fake_tensors = (
  248. torch.serialization._serialization_tls.materialize_fake_tensors
  249. )
  250. state = torch._utils._get_obj_state(self)
  251. # Ignore all state when using FakeTensor with skip_data(materialize_fake_tensors) because FakeTensor has
  252. # some state that cannot be pickled
  253. if (
  254. # TODO: remove hasattr, it's a hack to support versions of torch that
  255. # don't have _subclasses
  256. hasattr(torch, "_subclasses")
  257. and type(self) is torch._subclasses.fake_tensor.FakeTensor
  258. and materialize_fake_tensors
  259. ) or (type(self) is Tensor and not state):
  260. # Fast path for regular tensor without Python state.
  261. return self._reduce_ex_internal(proto)
  262. if has_torch_function_unary(self):
  263. return handle_torch_function(Tensor.__reduce_ex__, (self,), self, proto)
  264. func, args = self._reduce_ex_internal(proto)
  265. # sizes / strides cache needs to be cleared here because it'll just be re-cached
  266. # if cleared earlier. Note that state references the -actual- tensor dict.
  267. self._clear_non_serializable_cached_data()
  268. return (_rebuild_from_type_v2, (func, type(self), args, state))
  269. def storage(self):
  270. r"""
  271. storage() -> torch.TypedStorage
  272. Returns the underlying :class:`TypedStorage`.
  273. .. warning::
  274. :class:`TypedStorage` is deprecated. It will be removed in the future, and
  275. :class:`UntypedStorage` will be the only storage class. To access the
  276. :class:`UntypedStorage` directly, use :attr:`Tensor.untyped_storage()`.
  277. """
  278. if has_torch_function_unary(self):
  279. return handle_torch_function(Tensor.storage, (self,), self)
  280. torch.storage._warn_typed_storage_removal(stacklevel=2)
  281. return self._typed_storage()
  282. # For internal use only, to avoid raising deprecation warning
  283. def _typed_storage(self):
  284. untyped_storage = self.untyped_storage()
  285. return torch.TypedStorage(
  286. wrap_storage=untyped_storage, dtype=self.dtype, _internal=True
  287. )
  288. def _reduce_ex_internal(self, proto):
  289. check_serializing_named_tensor(self)
  290. from torch.utils.hooks import warn_if_has_hooks
  291. # See Note [Don't serialize hooks]
  292. warn_if_has_hooks(self)
  293. backward_hooks: dict[Any, Any] = OrderedDict()
  294. skip_data = torch.serialization._serialization_tls.skip_data
  295. materialize_fake_tensors = (
  296. torch.serialization._serialization_tls.materialize_fake_tensors
  297. )
  298. if self.device.type in ["xla", "maia", "mtia"] or (
  299. not torch._C._has_storage(self)
  300. and self.device.type == torch._C._get_privateuse1_backend_name()
  301. ):
  302. if skip_data:
  303. raise RuntimeError(
  304. "Cannot serialize tensors on backends with no storage under skip_data context manager"
  305. )
  306. cpu_tensor = self.cpu()
  307. return (
  308. torch._utils._rebuild_device_tensor_from_cpu_tensor,
  309. (cpu_tensor, self.dtype, str(self.device), self.requires_grad),
  310. )
  311. if self.device.type == "meta":
  312. # NB: This implementation BREAKS storage sharing. Current
  313. # hypothesis is that no one cares for meta tensors.
  314. if skip_data:
  315. warnings.warn(
  316. "Serializing tensors on the meta device under skip_data context manager is a no-op"
  317. )
  318. arg_meta = (
  319. self.dtype,
  320. tuple(self.size()),
  321. self.stride(),
  322. self.requires_grad,
  323. )
  324. return (torch._utils._rebuild_meta_tensor_no_storage, arg_meta)
  325. if self.is_quantized:
  326. if skip_data:
  327. raise RuntimeError(
  328. "Cannot serialize qtensor under skip_data context manager, file an issue if you need this feature"
  329. )
  330. # quantizer_params can be different type based on torch attribute
  331. quantizer_params: Union[
  332. tuple[torch.qscheme, float, int], tuple[Any, Tensor, Tensor, int]
  333. ]
  334. if self.qscheme() == torch.per_tensor_affine:
  335. quantizer_params = (
  336. torch.per_tensor_affine,
  337. self.q_scale(),
  338. self.q_zero_point(),
  339. )
  340. elif self.qscheme() in (
  341. torch.per_channel_affine,
  342. torch.per_channel_affine_float_qparams,
  343. ):
  344. # convert scales and zero points to tuple to avoid recursive calls
  345. # when/if we get multi-axis quantized tensors in the future, the shape
  346. # is recoverable from the main tensor shape
  347. quantizer_params = (
  348. torch.per_channel_affine,
  349. self.q_per_channel_scales(),
  350. self.q_per_channel_zero_points(),
  351. self.q_per_channel_axis(),
  352. )
  353. else:
  354. raise RuntimeError(
  355. f"Serialization is not supported for tensors of type {self.qscheme()}"
  356. )
  357. # TODO: Once we decide to break serialization FC, no longer
  358. # need to wrap with TypedStorage
  359. args_qtensor = (
  360. torch.storage.TypedStorage(
  361. wrap_storage=self._typed_storage()._untyped_storage,
  362. dtype=self.dtype,
  363. _internal=True,
  364. ),
  365. self.storage_offset(),
  366. tuple(self.size()),
  367. self.stride(),
  368. quantizer_params,
  369. self.requires_grad,
  370. backward_hooks,
  371. )
  372. return (torch._utils._rebuild_qtensor, args_qtensor)
  373. elif self.is_sparse:
  374. if self.layout == torch.sparse_coo:
  375. args_sparse = (
  376. self.layout,
  377. (self._indices(), self._values(), self.size(), self.is_coalesced()),
  378. )
  379. else:
  380. raise NotImplementedError(
  381. f"sparse tensor __reduce_ex__ for layout `{self.layout}`"
  382. )
  383. return (torch._utils._rebuild_sparse_tensor, args_sparse)
  384. elif self.layout in {
  385. torch.sparse_csr,
  386. torch.sparse_csc,
  387. torch.sparse_bsr,
  388. torch.sparse_bsc,
  389. }:
  390. if self.layout in {torch.sparse_csr, torch.sparse_bsr}:
  391. compressed_indices, plain_indices = (
  392. self.crow_indices(),
  393. self.col_indices(),
  394. )
  395. else:
  396. compressed_indices, plain_indices = (
  397. self.ccol_indices(),
  398. self.row_indices(),
  399. )
  400. args_sparse_compressed = (
  401. self.layout,
  402. (
  403. compressed_indices,
  404. plain_indices,
  405. self.values(),
  406. self.size(),
  407. ),
  408. )
  409. return (torch._utils._rebuild_sparse_tensor, args_sparse_compressed)
  410. elif self.is_nested:
  411. if skip_data:
  412. raise RuntimeError(
  413. "Cannot serialize nested tensor under skip_data context manager, file an issue if you need this feature"
  414. )
  415. args_nested = (
  416. # NB: values() currently returns the storage as a buffer in an unsafe way.
  417. # Ideally, we'd use a private API for this instead. TODO: Switch to this if
  418. # we ever get around to adding it.
  419. self.values(),
  420. self._nested_tensor_size(),
  421. self._nested_tensor_strides(),
  422. self._nested_tensor_storage_offsets(),
  423. )
  424. return (torch._utils._rebuild_nested_tensor, args_nested)
  425. elif (
  426. type(self) is not torch.Tensor
  427. and type(self).__torch_dispatch__ is not torch.Tensor.__torch_dispatch__
  428. and (
  429. isinstance(self, torch._subclasses.functional_tensor.FunctionalTensor)
  430. or (
  431. not isinstance(self, torch._subclasses.fake_tensor.FakeTensor)
  432. and self.data_ptr() == 0
  433. )
  434. )
  435. ):
  436. arg_wrapper_subclass = (
  437. type(self),
  438. self.dtype,
  439. tuple(self.size()),
  440. self.stride(),
  441. self.storage_offset(),
  442. self.layout,
  443. self.device,
  444. self.requires_grad,
  445. )
  446. return (torch._utils._rebuild_wrapper_subclass, arg_wrapper_subclass)
  447. elif (
  448. type(self) is not torch.Tensor
  449. and type(self).__torch_dispatch__ is not torch.Tensor.__torch_dispatch__
  450. and (
  451. isinstance(self, torch._subclasses.fake_tensor.FakeTensor)
  452. and not (skip_data and materialize_fake_tensors)
  453. )
  454. ):
  455. arg_wrapper_subclass = (
  456. type(self),
  457. self.dtype,
  458. tuple(self.size()),
  459. self.stride(),
  460. self.storage_offset(),
  461. self.layout,
  462. self.device,
  463. self.requires_grad,
  464. )
  465. return (torch._utils._rebuild_wrapper_subclass, arg_wrapper_subclass)
  466. else:
  467. v3_dtypes = torch.storage._new_dtypes()
  468. if self.dtype in v3_dtypes:
  469. rebuild_func = torch._utils._rebuild_tensor_v3
  470. storage = self.untyped_storage()
  471. else:
  472. # TODO: Once we decide to break serialization FC, no longer
  473. # need to wrap with TypedStorage
  474. rebuild_func = torch._utils._rebuild_tensor_v2 # type: ignore[assignment]
  475. storage = torch.storage.TypedStorage(
  476. wrap_storage=self._typed_storage()._untyped_storage,
  477. dtype=self.dtype,
  478. _internal=True,
  479. ) # type: ignore[assignment]
  480. # TODO: remove hasattr, it's a hack to support versions of torch that
  481. # don't have _subclasses
  482. if (
  483. hasattr(torch, "_subclasses")
  484. and isinstance(self, torch._subclasses.fake_tensor.FakeTensor)
  485. and skip_data
  486. ):
  487. storage._fake_device = self.device
  488. args = (
  489. storage,
  490. self.storage_offset(),
  491. tuple(self.size()),
  492. self.stride(),
  493. self.requires_grad,
  494. backward_hooks,
  495. ) # previously was self._backward_hooks
  496. if isinstance(storage, torch.storage.UntypedStorage):
  497. args = args + (self.dtype,) # type: ignore[assignment]
  498. metadata = torch._utils.get_tensor_metadata(self)
  499. if metadata:
  500. args = args + (metadata,) # type: ignore[assignment]
  501. return (rebuild_func, args)
  502. def __setstate__(self, state):
  503. if has_torch_function_unary(self):
  504. return handle_torch_function(Tensor.__setstate__, (self,), self, state)
  505. # Warning: this method is NOT called when you torch.load() a tensor;
  506. # that is managed by _rebuild_tensor_v2
  507. if not self.is_leaf:
  508. raise RuntimeError("__setstate__ can be only called on leaf Tensors")
  509. if len(state) == 4:
  510. # legacy serialization of Tensor
  511. self.set_(*state)
  512. return
  513. elif len(state) == 5:
  514. # legacy serialization of Variable
  515. self.data = state[0]
  516. state = (state[3], state[4], state[2])
  517. # The setting of _backward_hooks is expected to be a no-op.
  518. # See Note [Don't serialize hooks]
  519. self.requires_grad, _, self._backward_hooks = state
  520. def __repr__(self, *, tensor_contents=None):
  521. if has_torch_function_unary(self):
  522. return handle_torch_function(
  523. Tensor.__repr__, (self,), self, tensor_contents=tensor_contents
  524. )
  525. # All strings are unicode in Python 3.
  526. return torch._tensor_str._str(self, tensor_contents=tensor_contents)
  527. def backward(
  528. self, gradient=None, retain_graph=None, create_graph=False, inputs=None
  529. ):
  530. r"""Computes the gradient of current tensor wrt graph leaves.
  531. The graph is differentiated using the chain rule. If the tensor is
  532. non-scalar (i.e. its data has more than one element) and requires
  533. gradient, the function additionally requires specifying a ``gradient``.
  534. It should be a tensor of matching type and shape, that represents
  535. the gradient of the differentiated function w.r.t. ``self``.
  536. This function accumulates gradients in the leaves - you might need to zero
  537. ``.grad`` attributes or set them to ``None`` before calling it.
  538. See :ref:`Default gradient layouts<default-grad-layouts>`
  539. for details on the memory layout of accumulated gradients.
  540. .. note::
  541. If you run any forward ops, create ``gradient``, and/or call ``backward``
  542. in a user-specified CUDA stream context, see
  543. :ref:`Stream semantics of backward passes<bwd-cuda-stream-semantics>`.
  544. .. note::
  545. When ``inputs`` are provided and a given input is not a leaf,
  546. the current implementation will call its grad_fn (though it is not strictly needed to get this gradients).
  547. It is an implementation detail on which the user should not rely.
  548. See https://github.com/pytorch/pytorch/pull/60521#issuecomment-867061780 for more details.
  549. Args:
  550. gradient (Tensor, optional): The gradient of the function
  551. being differentiated w.r.t. ``self``.
  552. This argument can be omitted if ``self`` is a scalar. Defaults to ``None``.
  553. retain_graph (bool, optional): If ``False``, the graph used to compute the grads will be freed;
  554. If ``True``, it will be retained. The default is ``None``, in which case the value is inferred from ``create_graph``
  555. (i.e., the graph is retained only when higher-order derivative tracking is requested). Note that in nearly all cases
  556. setting this option to True is not needed and often can be worked around in a much more efficient way.
  557. create_graph (bool, optional): If ``True``, graph of the derivative will
  558. be constructed, allowing to compute higher order derivative
  559. products. Defaults to ``False``.
  560. inputs (Sequence[Tensor], optional): Inputs w.r.t. which the gradient will be
  561. accumulated into ``.grad``. All other tensors will be ignored. If not
  562. provided, the gradient is accumulated into all the leaf Tensors that were
  563. used to compute the :attr:`tensors`. Defaults to ``None``.
  564. """
  565. if has_torch_function_unary(self):
  566. return handle_torch_function(
  567. Tensor.backward,
  568. (self,),
  569. self,
  570. gradient=gradient,
  571. retain_graph=retain_graph,
  572. create_graph=create_graph,
  573. inputs=inputs,
  574. )
  575. torch.autograd.backward(
  576. self, gradient, retain_graph, create_graph, inputs=inputs
  577. )
  578. def register_hook(self, hook):
  579. r"""Registers a backward hook.
  580. The hook will be called every time a gradient with respect to the
  581. Tensor is computed. The hook should have the following signature::
  582. hook(grad) -> Tensor or None
  583. The hook should not modify its argument, but it can optionally return
  584. a new gradient which will be used in place of :attr:`grad`.
  585. This function returns a handle with a method ``handle.remove()``
  586. that removes the hook from the module.
  587. .. note::
  588. See :ref:`backward-hooks-execution` for more information on how when this hook
  589. is executed, and how its execution is ordered relative to other hooks.
  590. Example::
  591. >>> v = torch.tensor([0., 0., 0.], requires_grad=True)
  592. >>> h = v.register_hook(lambda grad: grad * 2) # double the gradient
  593. >>> v.backward(torch.tensor([1., 2., 3.]))
  594. >>> v.grad
  595. 2
  596. 4
  597. 6
  598. [torch.FloatTensor of size (3,)]
  599. >>> h.remove() # removes the hook
  600. """
  601. if has_torch_function_unary(self):
  602. return handle_torch_function(Tensor.register_hook, (self,), self, hook)
  603. if not self.requires_grad:
  604. raise RuntimeError(
  605. "cannot register a hook on a tensor that doesn't require gradient"
  606. )
  607. if self._backward_hooks is None:
  608. self._backward_hooks = OrderedDict()
  609. if self.grad_fn is not None:
  610. self.grad_fn._register_hook_dict(self)
  611. from torch.utils.hooks import RemovableHandle
  612. handle = RemovableHandle(self._backward_hooks)
  613. self._backward_hooks[handle.id] = hook
  614. return handle
  615. def register_post_accumulate_grad_hook(self, hook):
  616. r"""Registers a backward hook that runs after grad accumulation.
  617. The hook will be called after all gradients for a tensor have been accumulated,
  618. meaning that the .grad field has been updated on that tensor. The post
  619. accumulate grad hook is ONLY applicable for leaf tensors (tensors without a
  620. .grad_fn field). Registering this hook on a non-leaf tensor will error!
  621. The hook should have the following signature::
  622. hook(param: Tensor) -> None
  623. Note that, unlike other autograd hooks, this hook operates on the tensor
  624. that requires grad and not the grad itself. The hook can in-place modify
  625. and access its Tensor argument, including its .grad field.
  626. This function returns a handle with a method ``handle.remove()``
  627. that removes the hook from the module.
  628. .. note::
  629. See :ref:`backward-hooks-execution` for more information on how when this hook
  630. is executed, and how its execution is ordered relative to other hooks. Since
  631. this hook runs during the backward pass, it will run in no_grad mode (unless
  632. create_graph is True). You can use torch.enable_grad() to re-enable autograd
  633. within the hook if you need it.
  634. Example::
  635. >>> v = torch.tensor([0., 0., 0.], requires_grad=True)
  636. >>> lr = 0.01
  637. >>> # simulate a simple SGD update
  638. >>> h = v.register_post_accumulate_grad_hook(lambda p: p.add_(p.grad, alpha=-lr))
  639. >>> v.backward(torch.tensor([1., 2., 3.]))
  640. >>> v
  641. tensor([-0.0100, -0.0200, -0.0300], requires_grad=True)
  642. >>> h.remove() # removes the hook
  643. """
  644. if has_torch_function_unary(self):
  645. return handle_torch_function(
  646. Tensor.register_post_accumulate_grad_hook, (self,), self, hook
  647. )
  648. if not self.requires_grad:
  649. raise RuntimeError(
  650. "cannot register a hook on a tensor that doesn't require gradient"
  651. )
  652. if self.grad_fn is not None:
  653. raise RuntimeError(
  654. "post accumulate grad hooks cannot be registered on non-leaf tensors"
  655. )
  656. if self._post_accumulate_grad_hooks is None:
  657. self._post_accumulate_grad_hooks: dict[Any, Any] = OrderedDict()
  658. from torch.utils.hooks import RemovableHandle
  659. handle = RemovableHandle(self._post_accumulate_grad_hooks)
  660. self._post_accumulate_grad_hooks[handle.id] = hook
  661. return handle
  662. def reinforce(self, reward):
  663. def trim(str):
  664. return "\n".join([line.strip() for line in str.split("\n")])
  665. raise RuntimeError(
  666. trim(
  667. r"""reinforce() was removed.
  668. Use torch.distributions instead.
  669. See https://pytorch.org/docs/main/distributions.html
  670. Instead of:
  671. probs = policy_network(state)
  672. action = probs.multinomial()
  673. next_state, reward = env.step(action)
  674. action.reinforce(reward)
  675. action.backward()
  676. Use:
  677. probs = policy_network(state)
  678. # NOTE: categorical is equivalent to what used to be called multinomial
  679. m = torch.distributions.Categorical(probs)
  680. action = m.sample()
  681. next_state, reward = env.step(action)
  682. loss = -m.log_prob(action) * reward
  683. loss.backward()
  684. """
  685. )
  686. )
# Expose ``_C.TensorBase.detach`` on this class with the docstring below
# attached via ``_C._add_docstr``; only documentation is added here, the
# implementation is ``_C.TensorBase.detach`` itself.
detach = _C._add_docstr(
_C.TensorBase.detach,
r"""
Returns a new Tensor, detached from the current graph.
The result will never require gradient.
This method also affects forward mode AD gradients and the result will never
have forward mode AD gradients.
.. note::
Returned Tensor shares the same storage with the original one.
In-place modifications on either of them will be seen, and may trigger
errors in correctness checks.
""",
)
# In-place counterpart of ``detach``: binds ``_C.TensorBase.detach_`` on this
# class with the docstring below attached via ``_C._add_docstr``.
detach_ = _C._add_docstr(
_C.TensorBase.detach_,
r"""
Detaches the Tensor from the graph that created it, making it a leaf.
Views cannot be detached in-place.
This method also affects forward mode AD gradients and the result will never
have forward mode AD gradients.
""",
)
  709. def is_shared(self):
  710. r"""Checks if tensor is in shared memory.
  711. This is always ``True`` for CUDA tensors.
  712. """
  713. if has_torch_function_unary(self):
  714. return handle_torch_function(Tensor.is_shared, (self,), self)
  715. return self._typed_storage()._is_shared()
  716. def share_memory_(self):
  717. r"""Moves the underlying storage to shared memory.
  718. This is a no-op if the underlying storage is already in shared memory
  719. and for CUDA tensors. Tensors in shared memory cannot be resized.
  720. See :meth:`torch.UntypedStorage.share_memory_` for more details.
  721. """
  722. if has_torch_function_unary(self):
  723. return handle_torch_function(Tensor.share_memory_, (self,), self)
  724. self._typed_storage()._share_memory_()
  725. return self
  726. def module_load(self, other, assign=False):
  727. r"""Defines how to transform ``other`` when loading it into ``self`` in :meth:`~nn.Module.load_state_dict`.
  728. Used when :func:`~torch.__future__.get_swap_module_params_on_conversion` is ``True``.
  729. It is expected that ``self`` is a parameter or buffer in an ``nn.Module`` and ``other`` is the
  730. value in the state dictionary with the corresponding key, this method defines
  731. how ``other`` is remapped before being swapped with ``self`` via
  732. :func:`~torch.utils.swap_tensors` in :meth:`~nn.Module.load_state_dict`.
  733. .. note::
  734. This method should always return a new object that is not ``self`` or ``other``.
  735. For example, the default implementation returns ``self.copy_(other).detach()``
  736. if ``assign`` is ``False`` or ``other.detach()`` if ``assign`` is ``True``.
  737. Args:
  738. other (Tensor): value in state dict with key corresponding to ``self``
  739. assign (bool): the assign argument passed to :meth:`nn.Module.load_state_dict`
  740. """
  741. if has_torch_function_variadic(self, other):
  742. return handle_torch_function(
  743. Tensor.module_load, (self, other), self, other, assign=assign
  744. )
  745. if assign:
  746. return other.detach()
  747. else:
  748. return self.copy_(other).detach()
  749. def __reversed__(self):
  750. r"""Reverses the tensor along dimension 0."""
  751. if has_torch_function_unary(self):
  752. return handle_torch_function(Tensor.__reversed__, (self,), self)
  753. if self.dim() == 0:
  754. return self
  755. else:
  756. return self.flip(0)
  757. def norm(
  758. self,
  759. p: Optional[Union[float, str]] = "fro",
  760. dim=None,
  761. keepdim=False,
  762. dtype=None,
  763. ):
  764. r"""See :func:`torch.norm`"""
  765. if has_torch_function_unary(self):
  766. return handle_torch_function(
  767. Tensor.norm, (self,), self, p=p, dim=dim, keepdim=keepdim, dtype=dtype
  768. )
  769. return torch.norm(self, p, dim, keepdim, dtype=dtype)
  770. def solve(self, other):
  771. from torch._linalg_utils import solve
  772. return solve(self, other)
  773. def lstsq(self, other):
  774. from torch._linalg_utils import lstsq
  775. return lstsq(self, other)
  776. def eig(self, eigenvectors=False):
  777. from torch._linalg_utils import eig
  778. return eig(self, eigenvectors=eigenvectors)
  779. def symeig(self, eigenvectors=False):
  780. from torch._linalg_utils import _symeig
  781. return _symeig(self, eigenvectors=eigenvectors)
  782. def lu(self, pivot=True, get_infos=False):
  783. r"""See :func:`torch.lu`"""
  784. # If get_infos is True, then we don't need to check for errors and vice versa
  785. if has_torch_function_unary(self):
  786. return handle_torch_function(
  787. Tensor.lu, (self,), self, pivot=pivot, get_infos=get_infos
  788. )
  789. LU, pivots, infos = torch._lu_with_info(
  790. self, pivot=pivot, check_errors=(not get_infos)
  791. )
  792. if get_infos:
  793. return LU, pivots, infos
  794. else:
  795. return LU, pivots
  796. def stft(
  797. self,
  798. n_fft: int,
  799. hop_length: Optional[int] = None,
  800. win_length: Optional[int] = None,
  801. window: "Optional[Tensor]" = None,
  802. center: bool = True,
  803. pad_mode: str = "reflect",
  804. normalized: bool = False,
  805. onesided: Optional[bool] = None,
  806. return_complex: Optional[bool] = None,
  807. align_to_window: Optional[bool] = None,
  808. ):
  809. r"""See :func:`torch.stft`
  810. .. warning::
  811. This function changed signature at version 0.4.1. Calling with
  812. the previous signature may cause error or return incorrect result.
  813. """
  814. if has_torch_function_unary(self):
  815. return handle_torch_function(
  816. Tensor.stft,
  817. (self,),
  818. self,
  819. n_fft,
  820. hop_length=hop_length,
  821. win_length=win_length,
  822. window=window,
  823. center=center,
  824. pad_mode=pad_mode,
  825. normalized=normalized,
  826. onesided=onesided,
  827. return_complex=return_complex,
  828. align_to_window=align_to_window,
  829. )
  830. return torch.stft(
  831. self,
  832. n_fft,
  833. hop_length,
  834. win_length,
  835. window,
  836. center,
  837. pad_mode,
  838. normalized,
  839. onesided,
  840. return_complex=return_complex,
  841. align_to_window=align_to_window,
  842. )
  843. def istft(
  844. self,
  845. n_fft: int,
  846. hop_length: Optional[int] = None,
  847. win_length: Optional[int] = None,
  848. window: "Optional[Tensor]" = None,
  849. center: bool = True,
  850. normalized: bool = False,
  851. onesided: Optional[bool] = None,
  852. length: Optional[int] = None,
  853. return_complex: bool = False,
  854. ):
  855. r"""See :func:`torch.istft`"""
  856. if has_torch_function_unary(self):
  857. return handle_torch_function(
  858. Tensor.istft,
  859. (self,),
  860. self,
  861. n_fft,
  862. hop_length=hop_length,
  863. win_length=win_length,
  864. window=window,
  865. center=center,
  866. normalized=normalized,
  867. onesided=onesided,
  868. length=length,
  869. return_complex=return_complex,
  870. )
  871. return torch.istft(
  872. self,
  873. n_fft,
  874. hop_length,
  875. win_length,
  876. window,
  877. center,
  878. normalized,
  879. onesided,
  880. length,
  881. return_complex=return_complex,
  882. )
  883. def resize(self, *sizes):
  884. if has_torch_function_unary(self):
  885. return handle_torch_function(Tensor.resize, (self,), self, *sizes)
  886. warnings.warn("non-inplace resize is deprecated")
  887. from torch.autograd._functions import Resize
  888. return Resize.apply(self, sizes)
  889. def resize_as(self, tensor):
  890. if has_torch_function_variadic(self, tensor):
  891. return handle_torch_function(Tensor.resize_as, (self, tensor), self, tensor)
  892. warnings.warn("non-inplace resize_as is deprecated")
  893. from torch.autograd._functions import Resize
  894. return Resize.apply(self, tensor.size())
  895. def split(self, split_size, dim=0):
  896. r"""See :func:`torch.split`"""
  897. if has_torch_function_unary(self):
  898. return handle_torch_function(
  899. Tensor.split, (self,), self, split_size, dim=dim
  900. )
  901. if isinstance(split_size, Tensor):
  902. try:
  903. split_size = int(split_size)
  904. except ValueError:
  905. pass
  906. if isinstance(split_size, (int, torch.SymInt)):
  907. return torch._VF.split(self, split_size, dim) # type: ignore[attr-defined]
  908. else:
  909. return torch._VF.split_with_sizes(self, split_size, dim)
  910. def unique(self, sorted=True, return_inverse=False, return_counts=False, dim=None):
  911. r"""Returns the unique elements of the input tensor.
  912. See :func:`torch.unique`
  913. """
  914. if has_torch_function_unary(self):
  915. return handle_torch_function(
  916. Tensor.unique,
  917. (self,),
  918. self,
  919. sorted=sorted,
  920. return_inverse=return_inverse,
  921. return_counts=return_counts,
  922. dim=dim,
  923. )
  924. return torch.unique(
  925. self,
  926. sorted=sorted,
  927. return_inverse=return_inverse,
  928. return_counts=return_counts,
  929. dim=dim,
  930. )
  931. def unique_consecutive(self, return_inverse=False, return_counts=False, dim=None):
  932. r"""Eliminates all but the first element from every consecutive group of equivalent elements.
  933. See :func:`torch.unique_consecutive`
  934. """
  935. if has_torch_function_unary(self):
  936. return handle_torch_function(
  937. Tensor.unique_consecutive,
  938. (self,),
  939. self,
  940. return_inverse=return_inverse,
  941. return_counts=return_counts,
  942. dim=dim,
  943. )
  944. return torch.unique_consecutive(
  945. self, return_inverse=return_inverse, return_counts=return_counts, dim=dim
  946. )
  947. @_handle_torch_function_and_wrap_type_error_to_not_implemented
  948. def __rsub__(self, other: Union["Tensor", int, float, bool, complex]) -> "Tensor":
  949. return _C._VariableFunctions.rsub(self, other)
  950. @_handle_torch_function_and_wrap_type_error_to_not_implemented
  951. def __rdiv__(self, other: Union["Tensor", int, float, bool, complex]) -> "Tensor":
  952. return self.reciprocal() * other
# Reflected true division reuses this class's ``__rdiv__`` method.
__rtruediv__ = __rdiv__
# In-place true division (``/=``) is bound to ``_C.TensorBase.__idiv__``.
__itruediv__ = _C.TensorBase.__idiv__
# ``tensor ** other``: ``_C.TensorBase.pow`` wrapped by
# ``_handle_torch_function_and_wrap_type_error_to_not_implemented``
# (presumably __torch_function__ dispatch plus TypeError -> NotImplemented,
# per its name — NOTE(review): confirm against its definition). ``cast`` is a
# runtime no-op that only gives static checkers the callable's signature.
__pow__ = cast(
Callable[
["torch._C.TensorBase", Union["Tensor", int, float, bool, complex]],
"Tensor",
],
_handle_torch_function_and_wrap_type_error_to_not_implemented(
_C.TensorBase.pow
),
)
# ``tensor **= other``: the same wrapper applied to in-place ``_C.TensorBase.pow_``.
__ipow__ = _handle_torch_function_and_wrap_type_error_to_not_implemented(
_C.TensorBase.pow_
)
  967. @_handle_torch_function_and_wrap_type_error_to_not_implemented
  968. def __rmod__(self, other: Union["Tensor", int, float, bool, complex]) -> "Tensor":
  969. return torch.remainder(other, self)
  970. def __format__(self, format_spec):
  971. if has_torch_function_unary(self):
  972. return handle_torch_function(Tensor.__format__, (self,), self, format_spec)
  973. if self.dim() == 0 and not self.is_meta and type(self) is Tensor:
  974. # Use detach() here to avoid the warning when converting a scalar Tensor that
  975. # requires gradients to a python number. It is ok for formatting.
  976. return self.detach().item().__format__(format_spec)
  977. return object.__format__(self, format_spec)
  978. @_handle_torch_function_and_wrap_type_error_to_not_implemented
  979. def __rpow__(self, other: Union["Tensor", int, float, bool, complex]) -> "Tensor":
  980. return torch.pow(other, self)
  981. @_handle_torch_function_and_wrap_type_error_to_not_implemented
  982. def __floordiv__(self, other: Union["Tensor", int, float, bool]) -> "Tensor": # type: ignore[override]
  983. # TODO(rec): the superclass says it accepts complex here,
  984. # but torch.floor_divide says it doesn't.
  985. return torch.floor_divide(self, other)
  986. @_handle_torch_function_and_wrap_type_error_to_not_implemented
  987. def __rfloordiv__(self, other: Union["Tensor", int, float, bool]) -> "Tensor": # type: ignore[override]
  988. return torch.floor_divide(other, self)
  989. @_handle_torch_function_and_wrap_type_error_to_not_implemented
  990. def __rlshift__(
  991. self, other: Union["Tensor", int, float, bool, complex]
  992. ) -> "Tensor":
  993. return torch.bitwise_left_shift(other, self)
  994. @_handle_torch_function_and_wrap_type_error_to_not_implemented
  995. def __rrshift__(
  996. self, other: Union["Tensor", int, float, bool, complex]
  997. ) -> "Tensor":
  998. return torch.bitwise_right_shift(other, self)
  999. @_handle_torch_function_and_wrap_type_error_to_not_implemented
  1000. def __rmatmul__(self, other: "Tensor") -> "Tensor":
  1001. return torch.matmul(other, self)
# Unary operators bind directly to the TensorBase implementations:
# ``+t`` -> positive, ``-t`` -> neg, ``abs(t)`` -> abs.
__pos__ = _C.TensorBase.positive
__neg__ = _C.TensorBase.neg
__abs__ = _C.TensorBase.abs
  1005. def __len__(self):
  1006. if has_torch_function_unary(self):
  1007. return handle_torch_function(Tensor.__len__, (self,), self)
  1008. if self.dim() == 0:
  1009. raise TypeError("len() of a 0-d tensor")
  1010. if torch._C._get_tracing_state():
  1011. warnings.warn(
  1012. "Using len to get tensor shape might cause the trace to be incorrect. "
  1013. "Recommended usage would be tensor.shape[0]. "
  1014. "Passing a tensor of different shape might lead to errors or silently give "
  1015. "incorrect results.",
  1016. category=torch.jit.TracerWarning,
  1017. stacklevel=2,
  1018. )
  1019. return self.shape[0]
  1020. def __iter__(self):
  1021. # NB: we use 'imap' and not 'map' here, so that in Python 2 we get a
  1022. # generator and don't eagerly perform all the indexes. This could
  1023. # save us work, and also helps keep trace ordering deterministic
  1024. # (e.g., if you zip(*hiddens), the eager map will force all the
  1025. # indexes of hiddens[0] before hiddens[1], while the generator
  1026. # map will interleave them.)
  1027. # NB: We have intentionally skipped __torch_function__ dispatch here.
  1028. # See gh-54457
  1029. if self.dim() == 0:
  1030. raise TypeError("iteration over a 0-d tensor")
  1031. if torch._C._get_tracing_state():
  1032. warnings.warn(
  1033. "Iterating over a tensor might cause the trace to be incorrect. "
  1034. "Passing a tensor of different shape won't change the number of "
  1035. "iterations executed (and might lead to errors or silently give "
  1036. "incorrect results).",
  1037. category=torch.jit.TracerWarning,
  1038. stacklevel=2,
  1039. )
  1040. return iter(self.unbind(0))
  1041. def __hash__(self):
  1042. # Do NOT handle __torch_function__ here as user's default
  1043. # implementation that handle most functions will most likely do it wrong.
  1044. # It can be easily overridden by defining this method on the user
  1045. # subclass if needed.
  1046. return id(self)
  1047. def __dir__(self):
  1048. if has_torch_function_unary(self):
  1049. return handle_torch_function(Tensor.__dir__, (self,), self)
  1050. tensor_methods = dir(self.__class__)
  1051. tensor_methods.remove("volatile") # deprecated
  1052. attrs = list(self.__dict__.keys())
  1053. keys = tensor_methods + attrs
  1054. # property only available dense, cuda tensors
  1055. if (not self.is_cuda) or self.is_sparse:
  1056. keys.remove("__cuda_array_interface__")
  1057. return sorted(keys)
  1058. # Numpy array interface, to support `numpy.asarray(tensor) -> ndarray`
  1059. __array_priority__ = 1000 # prefer Tensor ops over numpy ones
  1060. def __array__(self, dtype=None):
  1061. if has_torch_function_unary(self):
  1062. return handle_torch_function(Tensor.__array__, (self,), self, dtype=dtype)
  1063. if dtype is None:
  1064. return self.numpy()
  1065. else:
  1066. return self.numpy().astype(dtype, copy=False)
  1067. # Wrap Numpy array again in a suitable tensor when done, to support e.g.
  1068. # `numpy.sin(tensor) -> tensor` or `numpy.greater(tensor, 0) -> ByteTensor`
  1069. def __array_wrap__(self, array):
  1070. if has_torch_function_unary(self):
  1071. return handle_torch_function(
  1072. Tensor.__array_wrap__, (self,), self, array=array
  1073. )
  1074. if array.dtype == bool:
  1075. # Workaround, torch has no built-in bool tensor
  1076. array = array.astype("uint8")
  1077. return torch.from_numpy(array)
  1078. def __contains__(self, element: Any, /) -> bool:
  1079. r"""Check if `element` is present in tensor
  1080. Args:
  1081. element (Tensor or scalar): element to be checked
  1082. for presence in current tensor"
  1083. """
  1084. if has_torch_function_unary(self):
  1085. return handle_torch_function(Tensor.__contains__, (self,), self, element)
  1086. if isinstance(
  1087. element, (torch.Tensor, Number, torch.SymInt, torch.SymFloat, torch.SymBool)
  1088. ):
  1089. # type hint doesn't understand the __contains__ result array
  1090. return bool((element == self).any().item()) # type: ignore[union-attr]
  1091. raise RuntimeError(
  1092. f"Tensor.__contains__ only supports Tensor or scalar, but you passed in a {type(element)}."
  1093. )
  1094. @property
  1095. def __cuda_array_interface__(self):
  1096. """Array view description for cuda tensors.
  1097. See:
  1098. https://numba.pydata.org/numba-doc/dev/cuda/cuda_array_interface.html
  1099. """
  1100. if has_torch_function_unary(self):
  1101. # TODO mypy doesn't support @property, see: https://github.com/python/mypy/issues/6185
  1102. return handle_torch_function(
  1103. Tensor.__cuda_array_interface__.__get__, # type: ignore[attr-defined]
  1104. (self,),
  1105. self,
  1106. )
  1107. # raise AttributeError for unsupported tensors, so that
  1108. # hasattr(cpu_tensor, "__cuda_array_interface__") is False.
  1109. if not self.is_cuda:
  1110. raise AttributeError(
  1111. f"Can't get __cuda_array_interface__ on non-CUDA tensor type: {self.type()} "
  1112. "If CUDA data is required use tensor.cuda() to copy tensor to device memory."
  1113. )
  1114. if self.is_sparse:
  1115. raise AttributeError(
  1116. f"Can't get __cuda_array_interface__ on sparse type: {self.type()} "
  1117. "Use Tensor.to_dense() to convert to a dense tensor first."
  1118. )
  1119. # RuntimeError, matching tensor.__array__() behavior.
  1120. if self.requires_grad:
  1121. raise RuntimeError(
  1122. "Can't get __cuda_array_interface__ on Variable that requires grad. "
  1123. "If gradients aren't required, use var.detach() to get Variable that doesn't require grad."
  1124. )
  1125. typestr = _dtype_to_typestr(self.dtype)
  1126. itemsize = self.element_size()
  1127. shape = tuple(self.shape)
  1128. if self.is_contiguous():
  1129. # __cuda_array_interface__ v2 requires the strides to be omitted
  1130. # (either not set or set to None) for C-contiguous arrays.
  1131. strides = None
  1132. else:
  1133. strides = tuple(s * itemsize for s in self.stride())
  1134. data_ptr = self.data_ptr() if self.numel() > 0 else 0
  1135. data = (data_ptr, False) # read-only is false
  1136. return dict(typestr=typestr, shape=shape, strides=strides, data=data, version=2)
  1137. def storage_type(self):
  1138. r"""storage_type() -> type
  1139. Returns the type of the underlying storage.
  1140. """
  1141. if has_torch_function_unary(self):
  1142. return handle_torch_function(Tensor.storage_type, (self,), self)
  1143. torch.storage._warn_typed_storage_removal()
  1144. return self._typed_storage()._get_legacy_storage_class()
  1145. def refine_names(self, *names):
  1146. r"""Refines the dimension names of :attr:`self` according to :attr:`names`.
  1147. Refining is a special case of renaming that "lifts" unnamed dimensions.
  1148. A ``None`` dim can be refined to have any name; a named dim can only be
  1149. refined to have the same name.
  1150. Because named tensors can coexist with unnamed tensors, refining names
  1151. gives a nice way to write named-tensor-aware code that works with both
  1152. named and unnamed tensors.
  1153. :attr:`names` may contain up to one Ellipsis (``...``).
  1154. The Ellipsis is expanded greedily; it is expanded in-place to fill
  1155. :attr:`names` to the same length as ``self.dim()`` using names from the
  1156. corresponding indices of ``self.names``.
  1157. Python 2 does not support Ellipsis but one may use a string literal
  1158. instead (``'...'``).
  1159. Args:
  1160. names (iterable of str): The desired names of the output tensor. May
  1161. contain up to one Ellipsis.
  1162. Examples::
  1163. >>> imgs = torch.randn(32, 3, 128, 128)
  1164. >>> named_imgs = imgs.refine_names('N', 'C', 'H', 'W')
  1165. >>> named_imgs.names
  1166. ('N', 'C', 'H', 'W')
  1167. >>> tensor = torch.randn(2, 3, 5, 7, 11)
  1168. >>> tensor = tensor.refine_names('A', ..., 'B', 'C')
  1169. >>> tensor.names
  1170. ('A', None, None, 'B', 'C')
  1171. .. warning::
  1172. The named tensor API is experimental and subject to change.
  1173. """
  1174. if has_torch_function_unary(self):
  1175. return handle_torch_function(Tensor.refine_names, (self,), self, *names)
  1176. names = resolve_ellipsis(names, self.names, "refine_names")
  1177. return super().refine_names(names)
  1178. def align_to(self, *names):
  1179. r"""Permutes the dimensions of the :attr:`self` tensor to match the order
  1180. specified in :attr:`names`, adding size-one dims for any new names.
  1181. All of the dims of :attr:`self` must be named in order to use this method.
  1182. The resulting tensor is a view on the original tensor.
  1183. All dimension names of :attr:`self` must be present in :attr:`names`.
  1184. :attr:`names` may contain additional names that are not in ``self.names``;
  1185. the output tensor has a size-one dimension for each of those new names.
  1186. :attr:`names` may contain up to one Ellipsis (``...``).
  1187. The Ellipsis is expanded to be equal to all dimension names of :attr:`self`
  1188. that are not mentioned in :attr:`names`, in the order that they appear
  1189. in :attr:`self`.
  1190. Python 2 does not support Ellipsis but one may use a string literal
  1191. instead (``'...'``).
  1192. Args:
  1193. names (iterable of str): The desired dimension ordering of the
  1194. output tensor. May contain up to one Ellipsis that is expanded
  1195. to all unmentioned dim names of :attr:`self`.
  1196. Examples::
  1197. >>> tensor = torch.randn(2, 2, 2, 2, 2, 2)
  1198. >>> named_tensor = tensor.refine_names('A', 'B', 'C', 'D', 'E', 'F')
  1199. # Move the F and E dims to the front while keeping the rest in order
  1200. >>> named_tensor.align_to('F', 'E', ...)
  1201. .. warning::
  1202. The named tensor API is experimental and subject to change.
  1203. """
  1204. if has_torch_function_unary(self):
  1205. return handle_torch_function(Tensor.align_to, (self,), self, *names)
  1206. ellipsis_idx = single_ellipsis_index(names, "align_to")
  1207. if ellipsis_idx is None:
  1208. return super().align_to(names)
  1209. return super().align_to(
  1210. [name for name in names if not is_ellipsis(name)], ellipsis_idx
  1211. )
  1212. def unflatten(self, dim, sizes): # type: ignore[override]
  1213. r"""
  1214. unflatten(dim, sizes) -> Tensor
  1215. See :func:`torch.unflatten`.
  1216. """
  1217. if has_torch_function_unary(self):
  1218. return handle_torch_function(Tensor.unflatten, (self,), self, dim, sizes)
  1219. if not sizes:
  1220. raise RuntimeError("unflatten: sizes must be non-empty")
  1221. names = None
  1222. if isinstance(sizes, OrderedDict) or (
  1223. isinstance(sizes, (tuple, list)) and isinstance(sizes[0], (tuple, list))
  1224. ):
  1225. names, sizes = unzip_namedshape(sizes)
  1226. return super().unflatten(dim, sizes, names)
  1227. else:
  1228. return super().unflatten(dim, sizes)
  1229. def rename_(self, *names, **rename_map):
  1230. """In-place version of :meth:`~Tensor.rename`."""
  1231. if has_torch_function_unary(self):
  1232. return handle_torch_function(
  1233. Tensor.rename_, (self,), self, *names, **rename_map
  1234. )
  1235. # Note [rename_ / rename API]
  1236. # The Python API for these is different from the C++ API. In Python:
  1237. # 1) tensor.rename(*names) takes a vararglist of names
  1238. # 2) tensor.rename(**rename_map) takes a map of names to rename.
  1239. # C++ is static, making it difficult to implement similar behavior.
  1240. return update_names(self, names, rename_map, inplace=True)
  1241. def rename(self, *names, **rename_map):
  1242. """Renames dimension names of :attr:`self`.
  1243. There are two main usages:
  1244. ``self.rename(**rename_map)`` returns a view on tensor that has dims
  1245. renamed as specified in the mapping :attr:`rename_map`.
  1246. ``self.rename(*names)`` returns a view on tensor, renaming all
  1247. dimensions positionally using :attr:`names`.
  1248. Use ``self.rename(None)`` to drop names on a tensor.
  1249. One cannot specify both positional args :attr:`names` and keyword args
  1250. :attr:`rename_map`.
  1251. Examples::
  1252. >>> imgs = torch.rand(2, 3, 5, 7, names=('N', 'C', 'H', 'W'))
  1253. >>> renamed_imgs = imgs.rename(N='batch', C='channels')
  1254. >>> renamed_imgs.names
  1255. ('batch', 'channels', 'H', 'W')
  1256. >>> renamed_imgs = imgs.rename(None)
  1257. >>> renamed_imgs.names
  1258. (None, None, None, None)
  1259. >>> renamed_imgs = imgs.rename('batch', 'channel', 'height', 'width')
  1260. >>> renamed_imgs.names
  1261. ('batch', 'channel', 'height', 'width')
  1262. .. warning::
  1263. The named tensor API is experimental and subject to change.
  1264. """
  1265. if has_torch_function_unary(self):
  1266. return handle_torch_function(
  1267. Tensor.rename, (self,), self, *names, **rename_map
  1268. )
  1269. # See Note [rename_ / rename API]
  1270. return update_names(self, names, rename_map, inplace=False)
  1271. def to_sparse_coo(self):
  1272. """Convert a tensor to :ref:`coordinate format <sparse-coo-docs>`.
  1273. Examples::
  1274. >>> dense = torch.randn(5, 5)
  1275. >>> sparse = dense.to_sparse_coo()
  1276. >>> sparse._nnz()
  1277. 25
  1278. """
  1279. return self.to_sparse()
  1280. def dim_order(
  1281. self, *, ambiguity_check: Union[bool, list[torch.memory_format]] = False
  1282. ):
  1283. """
  1284. dim_order(ambiguity_check=False) -> tuple
  1285. Returns the uniquely determined tuple of int describing the dim order or
  1286. physical layout of :attr:`self`.
  1287. The dim order represents how dimensions are laid out in memory of dense tensors,
  1288. starting from the outermost to the innermost dimension.
  1289. Note that the dim order may not always be uniquely determined.
  1290. If `ambiguity_check` is True, this function raises a RuntimeError when the dim order cannot be uniquely determined;
  1291. If `ambiguity_check` is a list of memory formats, this function raises a RuntimeError when tensor can not be interpreted
  1292. into exactly one of the given memory formats, or it cannot be uniquely determined.
  1293. If `ambiguity_check` is False, it will return one of legal dim order(s) without checking its uniqueness.
  1294. Otherwise, it will raise TypeError.
  1295. Args:
  1296. ambiguity_check (bool or List[torch.memory_format]): The check method for ambiguity of dim order.
  1297. Examples::
  1298. >>> torch.empty((2, 3, 5, 7)).dim_order()
  1299. (0, 1, 2, 3)
  1300. >>> torch.empty((2, 3, 5, 7)).transpose(1, 2).dim_order()
  1301. (0, 2, 1, 3)
  1302. >>> torch.empty((2, 3, 5, 7), memory_format=torch.channels_last).dim_order()
  1303. (0, 2, 3, 1)
  1304. >>> torch.empty((1, 2, 3, 4)).dim_order()
  1305. (0, 1, 2, 3)
  1306. >>> try:
  1307. ... torch.empty((1, 2, 3, 4)).dim_order(ambiguity_check=True)
  1308. ... except RuntimeError as e:
  1309. ... print(e)
  1310. The tensor does not have unique dim order, or cannot map to exact one of the given memory formats.
  1311. >>> torch.empty((1, 2, 3, 4)).dim_order(
  1312. ... ambiguity_check=[torch.contiguous_format, torch.channels_last]
  1313. ... ) # It can be mapped to contiguous format
  1314. (0, 1, 2, 3)
  1315. >>> try:
  1316. ... torch.empty((1, 2, 3, 4)).dim_order(ambiguity_check="ILLEGAL")
  1317. ... except TypeError as e:
  1318. ... print(e)
  1319. The ambiguity_check argument must be a bool or a list of memory formats.
  1320. .. warning::
  1321. The dim_order tensor API is experimental and subject to change.
  1322. """
  1323. if has_torch_function_unary(self):
  1324. return handle_torch_function(Tensor.dim_order, (self,), self)
  1325. if self.is_sparse:
  1326. raise AttributeError(
  1327. f"Can't get dim order on sparse type: {self.type()} "
  1328. "Use Tensor.to_dense() to convert to a dense tensor first."
  1329. )
  1330. # Sanity check ambiguity_check data types
  1331. if not isinstance(ambiguity_check, bool):
  1332. if not isinstance(ambiguity_check, list):
  1333. raise TypeError(
  1334. "The ambiguity_check argument must be a bool or a list of memory formats."
  1335. )
  1336. for memory_format in ambiguity_check:
  1337. if not isinstance(memory_format, torch.memory_format):
  1338. raise TypeError(
  1339. "The ambiguity_check argument must be a bool or a list of memory formats."
  1340. )
  1341. def invalid_unique_memory_format(tensor, valid_memory_formats):
  1342. """
  1343. Returns True if the tensor cannot be uniquely mapped to any of the given memory formats, False otherwise.
  1344. """
  1345. n_legality = 0
  1346. for memory_format in valid_memory_formats:
  1347. if tensor.is_contiguous(memory_format=memory_format):
  1348. n_legality += 1
  1349. return n_legality != 1
  1350. def has_multiple_dim_order(tensor):
  1351. """
  1352. Returns True if there're multiple legal dim orders for given tensor, False otherwise.
  1353. The tensor is considered to have multiple legal dim orders if either of the following conditions is met:
  1354. * Singleton Dimensions: There's at least one singleteon dimension in the tensor.
  1355. Since their size is 1, they don't affect the memory offset (stride * index
  1356. is zero because index is always zero). Therefore, they can be placed anywhere
  1357. in the dimension order without changing how data is accessed.
  1358. * Same strides: Strides reflect how the tensor is stored in memory.
  1359. If any two dimensions have the same stride, swapping these dimensions won't
  1360. change how data is accessed, leading to multiple correct dimension orders.
  1361. """
  1362. sizes = tensor.size()
  1363. strides = tensor.stride()
  1364. # Check if there are any duplicate strides
  1365. has_duplicate_strides = any(
  1366. earlier == later for earlier, later in zip(strides, strides[1:])
  1367. )
  1368. # Check if there are any singleton dimensions
  1369. has_singleton_dims = any(size == 1 for size in sizes)
  1370. return has_duplicate_strides or has_singleton_dims
  1371. valid_memory_formats = (
  1372. ambiguity_check if isinstance(ambiguity_check, list) else []
  1373. )
  1374. check_multiple_dim_order = (
  1375. ambiguity_check if isinstance(ambiguity_check, bool) else True
  1376. )
  1377. if (
  1378. check_multiple_dim_order and has_multiple_dim_order(self)
  1379. ) and invalid_unique_memory_format(self, valid_memory_formats):
  1380. raise RuntimeError(
  1381. "The tensor does not have unique dim order, or cannot map to exact one of the given memory formats."
  1382. )
  1383. import torch._prims_common as utils
  1384. return tuple(utils.compute_elementwise_output_logical_to_physical_perm(self))
  1385. def _update_names(self, names, inplace):
  1386. if has_torch_function_unary(self):
  1387. return handle_torch_function(
  1388. Tensor._update_names, (self,), self, names, inplace
  1389. )
  1390. # See Note [rename_ / rename API]
  1391. if inplace:
  1392. return super().rename_(names)
  1393. else:
  1394. return super().rename(names)
  1395. @classmethod
  1396. def __torch_function__(cls, func, types, args=(), kwargs=None):
  1397. """
  1398. This __torch_function__ implementation wraps subclasses such that
  1399. methods called on subclasses return a subclass instance instead of
  1400. a ``torch.Tensor`` instance.
  1401. One corollary to this is that you need coverage for torch.Tensor
  1402. methods if implementing __torch_function__ for subclasses.
  1403. We recommend always calling ``super().__torch_function__`` as the base
  1404. case when doing the above.
  1405. While not mandatory, we recommend making `__torch_function__` a classmethod.
  1406. """
  1407. if kwargs is None:
  1408. kwargs = {}
  1409. if not all(issubclass(cls, t) for t in types):
  1410. return NotImplemented
  1411. with _C.DisableTorchFunctionSubclass():
  1412. ret = func(*args, **kwargs)
  1413. if func in get_default_nowrap_functions():
  1414. return ret
  1415. else:
  1416. return _convert(ret, cls)
  1417. __torch_dispatch__ = _C._disabled_torch_dispatch_impl
  1418. def __dlpack__(
  1419. self,
  1420. *,
  1421. stream: Optional[Any] = None,
  1422. max_version: Optional[tuple[int, int]] = None,
  1423. dl_device: Optional[tuple[enum.IntEnum, int]] = None,
  1424. copy: Optional[bool] = None,
  1425. ):
  1426. """
  1427. Creates a DLpack `capsule https://data-apis.org/array-api/latest/design_topics/data_interchange.html#data-interchange`_
  1428. of the current tensor to be exported to other libraries.
  1429. This function will be called from the `from_dlpack` method
  1430. of the library that will consume the capsule. `from_dlpack` passes the current
  1431. stream to this method as part of the specification.
  1432. Args:
  1433. stream (integer or None): An optional Python integer representing a
  1434. pointer to a CUDA stream. The current stream is synchronized with
  1435. this stream before the capsule is created, and since the capsule
  1436. shares its storage with the tensor this make it safe to access from
  1437. both streams. If None or -1 is passed then no synchronization is performed.
  1438. If 1 (on CUDA) or 0 (on ROCM) then the default stream is used for
  1439. synchronization.
  1440. max_version (tuple[int, int] or None): An optional Python tuple with
  1441. 2 integers, representing the maximum version the caller supports. If
  1442. None (default), PyTorch will fallback to DLPack 0.8.
  1443. dl_device (tuple[DLDeviceType, int] or None): An optional tuple specifying
  1444. in which device the exported DLPack capsule should be on. If None (default),
  1445. the exported DLPack capsule will be on the same device as ``self``.
  1446. copy (bool or None): An optional boolean indicating whether or not to copy
  1447. ``self``. If None, PyTorch will copy only if necessary.
  1448. """
  1449. if has_torch_function_unary(self):
  1450. args = (self,)
  1451. kwargs = {
  1452. "stream": stream,
  1453. "max_version": max_version,
  1454. "dl_device": dl_device,
  1455. "copy": copy,
  1456. }
  1457. return handle_torch_function(Tensor.__dlpack__, (self,), *args, **kwargs)
  1458. # DLPack capsules can't capture all of PyTorch's semantics,
  1459. # so we prohibit exporting tensors that would lose their properties like
  1460. # requires_grad and having the conjugate bit set.
  1461. if self.requires_grad:
  1462. raise BufferError(
  1463. "Can't export tensors that require gradient, use tensor.detach()"
  1464. )
  1465. if self.is_conj():
  1466. raise BufferError("Can't export tensors with the conjugate bit set")
  1467. if self.layout != torch.strided:
  1468. raise BufferError(
  1469. "Can't export tensors with layout other than torch.strided"
  1470. )
  1471. if (
  1472. self.device.type == "cuda"
  1473. and self.device.index != torch.cuda.current_device()
  1474. ):
  1475. raise BufferError(
  1476. "Can't export tensors on a different CUDA device index. "
  1477. f"Expected: {self.device.index}. "
  1478. f"Current device: {torch.cuda.current_device()}."
  1479. )
  1480. if stream is not None and type(stream) is not int:
  1481. # Stream pointers in CUDA/ROCm are uniquely numbered and can
  1482. # be retrieved from their integer value.
  1483. raise TypeError("stream must be ``int`` or ``none``")
  1484. elif self.device.type == "cuda" and stream != -1:
  1485. # NB: This logic handles the special case values for default
  1486. # streams and must be kept in sync with from_dlpack in
  1487. # torch/utils/dlpack.py
  1488. is_rocm = torch.version.hip is not None
  1489. is_cuda = not is_rocm
  1490. if stream is None or (is_rocm and stream == 0) or (is_cuda and stream == 1):
  1491. stream = torch.cuda.default_stream()
  1492. else:
  1493. if is_cuda and stream == 2:
  1494. raise BufferError("per-thread default stream is not supported.")
  1495. device_str = "CUDA" if is_cuda else "ROCm"
  1496. assert (is_cuda and stream != 0) or (
  1497. is_rocm and stream not in (1, 2)
  1498. ), f"unsupported stream on {device_str}: {stream}."
  1499. stream = torch.cuda.ExternalStream(stream)
  1500. # Only synchronize on different streams
  1501. current_stream = torch.cuda.current_stream()
  1502. if stream != current_stream:
  1503. event = torch.cuda.Event()
  1504. event.record(current_stream)
  1505. stream.wait_event(event)
  1506. elif self.device.type == "cpu":
  1507. assert stream is None, "stream should be None on cpu."
  1508. if self.device.type == "xla":
  1509. import torch_xla
  1510. import torch_xla.utils.dlpack as xla_dlpack
  1511. if (
  1512. len(torch_xla.real_devices()) <= 0
  1513. or "cuda" not in torch_xla.real_devices()[0].lower()
  1514. ):
  1515. raise RuntimeError(
  1516. "Can't export to dlpack an XLA tensor that is not on CUDA."
  1517. )
  1518. # Does not support DLPack 1.0, yet.
  1519. return xla_dlpack.to_dlpack(self)
  1520. if max_version is None or max_version[0] < 1:
  1521. # Fallback to the old, unversioned variant.
  1522. return _C._to_dlpack(self, dl_device=dl_device, copy=copy)
  1523. return _C._to_dlpack_versioned(self, dl_device=dl_device, copy=copy)
  1524. def __dlpack_device__(self) -> tuple[enum.IntEnum, int]:
  1525. if has_torch_function_unary(self):
  1526. return handle_torch_function(Tensor.__dlpack_device__, (self,), self)
  1527. from torch.utils.dlpack import DLDeviceType
  1528. device = self.device
  1529. idx = device.index if device.index is not None else 0
  1530. torch_device_type = device.type
  1531. if torch_device_type == "cuda" and torch.version.hip is not None:
  1532. device_type = DLDeviceType.kDLROCM
  1533. elif torch_device_type == "cpu" and self.is_pinned():
  1534. device_type = DLDeviceType.kDLCUDAHost
  1535. elif torch_device_type == "cuda":
  1536. device_type = DLDeviceType.kDLCUDA
  1537. elif torch_device_type == "cpu":
  1538. device_type = DLDeviceType.kDLCPU
  1539. elif torch_device_type == "xpu":
  1540. device_type = DLDeviceType.kDLOneAPI
  1541. elif self.device.type == "privateuse1":
  1542. device_type = DLDeviceType.kDLExtDev
  1543. elif torch_device_type == "xla":
  1544. import torch_xla
  1545. if (
  1546. len(torch_xla.real_devices()) <= 0
  1547. or "cuda" not in torch_xla.real_devices()[0].lower()
  1548. ):
  1549. raise ValueError(f"Unknown device type {torch_device_type} for Dlpack")
  1550. device_type = DLDeviceType.kDLCUDA
  1551. elif torch_device_type == "mps":
  1552. device_type = DLDeviceType.kDLMetal
  1553. else:
  1554. raise ValueError(f"Unknown device type {torch_device_type} for Dlpack")
  1555. return (device_type, idx)
  1556. __module__ = "torch"
  1557. def _convert(ret, cls):
  1558. if cls is Tensor:
  1559. return ret
  1560. if isinstance(ret, Tensor) and not isinstance(ret, cls):
  1561. ret = ret.as_subclass(cls)
  1562. if isinstance(ret, (tuple, list)):
  1563. # Also handles things like namedtuples
  1564. ret = type(ret)(_convert(r, cls) for r in ret)
  1565. return ret