_ndarray.py 21 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720
  1. # mypy: ignore-errors
  2. from __future__ import annotations
  3. import builtins
  4. import math
  5. import operator
  6. from collections.abc import Sequence
  7. import torch
  8. from . import _dtypes, _dtypes_impl, _funcs, _ufuncs, _util
  9. from ._normalizations import (
  10. ArrayLike,
  11. normalize_array_like,
  12. normalizer,
  13. NotImplementedType,
  14. )
  15. newaxis = None
  16. FLAGS = [
  17. "C_CONTIGUOUS",
  18. "F_CONTIGUOUS",
  19. "OWNDATA",
  20. "WRITEABLE",
  21. "ALIGNED",
  22. "WRITEBACKIFCOPY",
  23. "FNC",
  24. "FORC",
  25. "BEHAVED",
  26. "CARRAY",
  27. "FARRAY",
  28. ]
  29. SHORTHAND_TO_FLAGS = {
  30. "C": "C_CONTIGUOUS",
  31. "F": "F_CONTIGUOUS",
  32. "O": "OWNDATA",
  33. "W": "WRITEABLE",
  34. "A": "ALIGNED",
  35. "X": "WRITEBACKIFCOPY",
  36. "B": "BEHAVED",
  37. "CA": "CARRAY",
  38. "FA": "FARRAY",
  39. }
  40. class Flags:
  41. def __init__(self, flag_to_value: dict):
  42. assert all(k in FLAGS for k in flag_to_value.keys()) # sanity check
  43. self._flag_to_value = flag_to_value
  44. def __getattr__(self, attr: str):
  45. if attr.islower() and attr.upper() in FLAGS:
  46. return self[attr.upper()]
  47. else:
  48. raise AttributeError(f"No flag attribute '{attr}'")
  49. def __getitem__(self, key):
  50. if key in SHORTHAND_TO_FLAGS.keys():
  51. key = SHORTHAND_TO_FLAGS[key]
  52. if key in FLAGS:
  53. try:
  54. return self._flag_to_value[key]
  55. except KeyError as e:
  56. raise NotImplementedError(f"{key=}") from e
  57. else:
  58. raise KeyError(f"No flag key '{key}'")
  59. def __setattr__(self, attr, value):
  60. if attr.islower() and attr.upper() in FLAGS:
  61. self[attr.upper()] = value
  62. else:
  63. super().__setattr__(attr, value)
  64. def __setitem__(self, key, value):
  65. if key in FLAGS or key in SHORTHAND_TO_FLAGS.keys():
  66. raise NotImplementedError("Modifying flags is not implemented")
  67. else:
  68. raise KeyError(f"No flag key '{key}'")
  69. def create_method(fn, name=None):
  70. name = name or fn.__name__
  71. def f(*args, **kwargs):
  72. return fn(*args, **kwargs)
  73. f.__name__ = name
  74. f.__qualname__ = f"ndarray.{name}"
  75. return f
  76. # Map ndarray.name_method -> np.name_func
  77. # If name_func == None, it means that name_method == name_func
  78. methods = {
  79. "clip": None,
  80. "nonzero": None,
  81. "repeat": None,
  82. "round": None,
  83. "squeeze": None,
  84. "swapaxes": None,
  85. "ravel": None,
  86. # linalg
  87. "diagonal": None,
  88. "dot": None,
  89. "trace": None,
  90. # sorting
  91. "argsort": None,
  92. "searchsorted": None,
  93. # reductions
  94. "argmax": None,
  95. "argmin": None,
  96. "any": None,
  97. "all": None,
  98. "max": None,
  99. "min": None,
  100. "ptp": None,
  101. "sum": None,
  102. "prod": None,
  103. "mean": None,
  104. "var": None,
  105. "std": None,
  106. # scans
  107. "cumsum": None,
  108. "cumprod": None,
  109. # advanced indexing
  110. "take": None,
  111. "choose": None,
  112. }
  113. dunder = {
  114. "abs": "absolute",
  115. "invert": None,
  116. "pos": "positive",
  117. "neg": "negative",
  118. "gt": "greater",
  119. "lt": "less",
  120. "ge": "greater_equal",
  121. "le": "less_equal",
  122. }
  123. # dunder methods with right-looking and in-place variants
  124. ri_dunder = {
  125. "add": None,
  126. "sub": "subtract",
  127. "mul": "multiply",
  128. "truediv": "divide",
  129. "floordiv": "floor_divide",
  130. "pow": "power",
  131. "mod": "remainder",
  132. "and": "bitwise_and",
  133. "or": "bitwise_or",
  134. "xor": "bitwise_xor",
  135. "lshift": "left_shift",
  136. "rshift": "right_shift",
  137. "matmul": None,
  138. }
  139. def _upcast_int_indices(index):
  140. if isinstance(index, torch.Tensor):
  141. if index.dtype in (torch.int8, torch.int16, torch.int32, torch.uint8):
  142. return index.to(torch.int64)
  143. elif isinstance(index, tuple):
  144. return tuple(_upcast_int_indices(i) for i in index)
  145. return index
  146. def _has_advanced_indexing(index):
  147. """Check if there's any advanced indexing"""
  148. return any(
  149. isinstance(idx, (Sequence, bool))
  150. or (isinstance(idx, torch.Tensor) and (idx.dtype == torch.bool or idx.ndim > 0))
  151. for idx in index
  152. )
  153. def _numpy_compatible_indexing(index):
  154. """Convert scalar indices to lists when advanced indexing is present for NumPy compatibility."""
  155. if not isinstance(index, tuple):
  156. index = (index,)
  157. # Check if there's any advanced indexing (sequences, booleans, or tensors)
  158. has_advanced = _has_advanced_indexing(index)
  159. if not has_advanced:
  160. return index
  161. # Convert integer scalar indices to single-element lists when advanced indexing is present
  162. # Note: Do NOT convert boolean scalars (True/False) as they have special meaning in NumPy
  163. converted = []
  164. for idx in index:
  165. if isinstance(idx, int) and not isinstance(idx, bool):
  166. # Integer scalars should be converted to lists
  167. converted.append([idx])
  168. elif (
  169. isinstance(idx, torch.Tensor)
  170. and idx.ndim == 0
  171. and not torch.is_floating_point(idx)
  172. and idx.dtype != torch.bool
  173. ):
  174. # Zero-dimensional tensors holding integers should be treated the same as integer scalars
  175. converted.append([idx])
  176. else:
  177. # Everything else (booleans, lists, slices, etc.) stays as is
  178. converted.append(idx)
  179. return tuple(converted)
  180. def _get_bool_depth(s):
  181. """Returns the depth of a boolean sequence/tensor"""
  182. if isinstance(s, bool):
  183. return True, 0
  184. if isinstance(s, torch.Tensor) and s.dtype == torch.bool:
  185. return True, s.ndim
  186. if not (isinstance(s, Sequence) and s and s[0] != s):
  187. return False, 0
  188. is_bool, depth = _get_bool_depth(s[0])
  189. return is_bool, depth + 1
def _numpy_empty_ellipsis_patch(index, tensor_ndim):
    """
    Patch for NumPy-compatible ellipsis behavior when ellipsis doesn't match any dimensions.
    In NumPy, when an ellipsis (...) doesn't actually match any dimensions of the input array,
    it still acts as a separator between advanced indices. PyTorch doesn't have this behavior.
    This function detects when we have:
    1. Advanced indexing on both sides of an ellipsis
    2. The ellipsis doesn't actually match any dimensions

    Args:
        index: an index entry or tuple of entries (a bare entry is wrapped
            into a 1-tuple).
        tensor_ndim: number of dimensions of the tensor being indexed.

    Returns:
        ``(index, squeeze_fn, unsqueeze_fn)``. When the patch applies, the
        ellipsis is replaced by ``None`` (an extra size-1 axis that keeps the
        advanced indices apart). ``squeeze_fn`` squeezes axis ``-end_ndims``
        (meant for ``__getitem__`` results). ``unsqueeze_fn`` unsqueezes the
        same axis into tensor values with enough dims (meant for
        ``__setitem__`` values). Otherwise the tuple index is returned with two
        identity functions.
    """
    if not isinstance(index, tuple):
        index = (index,)
    # Find ellipsis position
    ellipsis_pos = None
    for i, idx in enumerate(index):
        if idx is Ellipsis:
            ellipsis_pos = i
            break
    # If no ellipsis, no patch needed
    if ellipsis_pos is None:
        return index, lambda x: x, lambda x: x
    # Count non-ellipsis dimensions consumed by the index:
    # boolean masks consume their depth, None/Ellipsis consume nothing,
    # every other entry consumes one dimension.
    consumed_dims = 0
    for idx in index:
        is_bool, depth = _get_bool_depth(idx)
        if is_bool:
            consumed_dims += depth
        elif idx is Ellipsis or idx is None:
            continue
        else:
            consumed_dims += 1
    # Calculate how many dimensions the ellipsis should match
    ellipsis_dims = tensor_ndim - consumed_dims
    # Check if ellipsis doesn't match any dimensions
    if ellipsis_dims == 0:
        # Check if we have advanced indexing on both sides of ellipsis
        left_advanced = _has_advanced_indexing(index[:ellipsis_pos])
        right_advanced = _has_advanced_indexing(index[ellipsis_pos + 1 :])
        if left_advanced and right_advanced:
            # This is the case where NumPy and PyTorch differ
            # We need to ensure the advanced indices are treated as separated
            new_index = index[:ellipsis_pos] + (None,) + index[ellipsis_pos + 1 :]
            # Position, counted from the end, of the axis added by None: the
            # None axis itself plus one per slice after the ellipsis.
            # NOTE(review): assumes those trailing slice dims are the only ones
            # laid out after the None axis (a None after the ellipsis would
            # add another) -- confirm.
            end_ndims = 1 + sum(
                1 for idx in index[ellipsis_pos + 1 :] if isinstance(idx, slice)
            )

            def squeeze_fn(x):
                return x.squeeze(-end_ndims)

            def unsqueeze_fn(x):
                # Scalars and low-rank values are left alone
                if isinstance(x, torch.Tensor) and x.ndim >= end_ndims:
                    return x.unsqueeze(-end_ndims)
                return x

            return new_index, squeeze_fn, unsqueeze_fn
    return index, lambda x: x, lambda x: x
  242. # Used to indicate that a parameter is unspecified (as opposed to explicitly
  243. # `None`)
  244. class _Unspecified:
  245. pass
  246. _Unspecified.unspecified = _Unspecified()
  247. ###############################################################
  248. # ndarray class #
  249. ###############################################################
  250. class ndarray:
  251. def __init__(self, t=None):
  252. if t is None:
  253. self.tensor = torch.Tensor()
  254. elif isinstance(t, torch.Tensor):
  255. self.tensor = t
  256. else:
  257. raise ValueError(
  258. "ndarray constructor is not recommended; prefer"
  259. "either array(...) or zeros/empty(...)"
  260. )
  261. # Register NumPy functions as methods
  262. for method, name in methods.items():
  263. fn = getattr(_funcs, name or method)
  264. vars()[method] = create_method(fn, method)
  265. # Regular methods but coming from ufuncs
  266. conj = create_method(_ufuncs.conjugate, "conj")
  267. conjugate = create_method(_ufuncs.conjugate)
  268. for method, name in dunder.items():
  269. fn = getattr(_ufuncs, name or method)
  270. method = f"__{method}__"
  271. vars()[method] = create_method(fn, method)
  272. for method, name in ri_dunder.items():
  273. fn = getattr(_ufuncs, name or method)
  274. plain = f"__{method}__"
  275. vars()[plain] = create_method(fn, plain)
  276. rvar = f"__r{method}__"
  277. vars()[rvar] = create_method(lambda self, other, fn=fn: fn(other, self), rvar)
  278. ivar = f"__i{method}__"
  279. vars()[ivar] = create_method(
  280. lambda self, other, fn=fn: fn(self, other, out=self), ivar
  281. )
  282. # There's no __idivmod__
  283. __divmod__ = create_method(_ufuncs.divmod, "__divmod__")
  284. __rdivmod__ = create_method(
  285. lambda self, other: _ufuncs.divmod(other, self), "__rdivmod__"
  286. )
  287. # prevent loop variables leaking into the ndarray class namespace
  288. del ivar, rvar, name, plain, fn, method
  289. @property
  290. def shape(self):
  291. return tuple(self.tensor.shape)
  292. @property
  293. def size(self):
  294. return self.tensor.numel()
  295. @property
  296. def ndim(self):
  297. return self.tensor.ndim
  298. @property
  299. def dtype(self):
  300. return _dtypes.dtype(self.tensor.dtype)
  301. @property
  302. def strides(self):
  303. elsize = self.tensor.element_size()
  304. return tuple(stride * elsize for stride in self.tensor.stride())
  305. @property
  306. def itemsize(self):
  307. return self.tensor.element_size()
  308. @property
  309. def flags(self):
  310. # Note contiguous in torch is assumed C-style
  311. return Flags(
  312. {
  313. "C_CONTIGUOUS": self.tensor.is_contiguous(),
  314. "F_CONTIGUOUS": self.T.tensor.is_contiguous(),
  315. "OWNDATA": self.tensor._base is None,
  316. "WRITEABLE": True, # pytorch does not have readonly tensors
  317. }
  318. )
  319. @property
  320. def data(self):
  321. return self.tensor.data_ptr()
  322. @property
  323. def nbytes(self):
  324. return self.tensor.storage().nbytes()
  325. @property
  326. def T(self):
  327. return self.transpose()
  328. @property
  329. def real(self):
  330. return _funcs.real(self)
  331. @real.setter
  332. def real(self, value):
  333. self.tensor.real = asarray(value).tensor
  334. @property
  335. def imag(self):
  336. return _funcs.imag(self)
  337. @imag.setter
  338. def imag(self, value):
  339. self.tensor.imag = asarray(value).tensor
  340. # ctors
  341. def astype(self, dtype, order="K", casting="unsafe", subok=True, copy=True):
  342. if order != "K":
  343. raise NotImplementedError(f"astype(..., order={order} is not implemented.")
  344. if casting != "unsafe":
  345. raise NotImplementedError(
  346. f"astype(..., casting={casting} is not implemented."
  347. )
  348. if not subok:
  349. raise NotImplementedError(f"astype(..., subok={subok} is not implemented.")
  350. if not copy:
  351. raise NotImplementedError(f"astype(..., copy={copy} is not implemented.")
  352. torch_dtype = _dtypes.dtype(dtype).torch_dtype
  353. t = self.tensor.to(torch_dtype)
  354. return ndarray(t)
  355. @normalizer
  356. def copy(self: ArrayLike, order: NotImplementedType = "C"):
  357. return self.clone()
  358. @normalizer
  359. def flatten(self: ArrayLike, order: NotImplementedType = "C"):
  360. return torch.flatten(self)
  361. def resize(self, *new_shape, refcheck=False):
  362. # NB: differs from np.resize: fills with zeros instead of making repeated copies of input.
  363. if refcheck:
  364. raise NotImplementedError(
  365. f"resize(..., refcheck={refcheck} is not implemented."
  366. )
  367. if new_shape in [(), (None,)]:
  368. return
  369. # support both x.resize((2, 2)) and x.resize(2, 2)
  370. if len(new_shape) == 1:
  371. new_shape = new_shape[0]
  372. if isinstance(new_shape, int):
  373. new_shape = (new_shape,)
  374. if builtins.any(x < 0 for x in new_shape):
  375. raise ValueError("all elements of `new_shape` must be non-negative")
  376. new_numel, old_numel = math.prod(new_shape), self.tensor.numel()
  377. self.tensor.resize_(new_shape)
  378. if new_numel >= old_numel:
  379. # zero-fill new elements
  380. assert self.tensor.is_contiguous()
  381. b = self.tensor.flatten() # does not copy
  382. b[old_numel:].zero_()
  383. def view(self, dtype=_Unspecified.unspecified, type=_Unspecified.unspecified):
  384. if dtype is _Unspecified.unspecified:
  385. dtype = self.dtype
  386. if type is not _Unspecified.unspecified:
  387. raise NotImplementedError(f"view(..., type={type} is not implemented.")
  388. torch_dtype = _dtypes.dtype(dtype).torch_dtype
  389. tview = self.tensor.view(torch_dtype)
  390. return ndarray(tview)
  391. @normalizer
  392. def fill(self, value: ArrayLike):
  393. # Both Pytorch and NumPy accept 0D arrays/tensors and scalars, and
  394. # error out on D > 0 arrays
  395. self.tensor.fill_(value)
  396. def tolist(self):
  397. return self.tensor.tolist()
  398. def __iter__(self):
  399. return (ndarray(x) for x in self.tensor.__iter__())
  400. def __str__(self):
  401. return (
  402. str(self.tensor)
  403. .replace("tensor", "torch.ndarray")
  404. .replace("dtype=torch.", "dtype=")
  405. )
  406. __repr__ = create_method(__str__)
  407. def __eq__(self, other):
  408. try:
  409. return _ufuncs.equal(self, other)
  410. except (RuntimeError, TypeError):
  411. # Failed to convert other to array: definitely not equal.
  412. falsy = torch.full(self.shape, fill_value=False, dtype=bool)
  413. return asarray(falsy)
  414. def __ne__(self, other):
  415. return ~(self == other)
  416. def __index__(self):
  417. try:
  418. return operator.index(self.tensor.item())
  419. except Exception as exc:
  420. raise TypeError(
  421. "only integer scalar arrays can be converted to a scalar index"
  422. ) from exc
  423. def __bool__(self):
  424. return bool(self.tensor)
  425. def __int__(self):
  426. return int(self.tensor)
  427. def __float__(self):
  428. return float(self.tensor)
  429. def __complex__(self):
  430. return complex(self.tensor)
  431. def is_integer(self):
  432. try:
  433. v = self.tensor.item()
  434. result = int(v) == v
  435. except Exception:
  436. result = False
  437. return result
  438. def __len__(self):
  439. return self.tensor.shape[0]
  440. def __contains__(self, x):
  441. return self.tensor.__contains__(x)
  442. def transpose(self, *axes):
  443. # np.transpose(arr, axis=None) but arr.transpose(*axes)
  444. return _funcs.transpose(self, axes)
  445. def reshape(self, *shape, order="C"):
  446. # arr.reshape(shape) and arr.reshape(*shape)
  447. return _funcs.reshape(self, shape, order=order)
  448. def sort(self, axis=-1, kind=None, order=None):
  449. # ndarray.sort works in-place
  450. _funcs.copyto(self, _funcs.sort(self, axis, kind, order))
  451. def item(self, *args):
  452. # Mimic NumPy's implementation with three special cases (no arguments,
  453. # a flat index and a multi-index):
  454. # https://github.com/numpy/numpy/blob/main/numpy/_core/src/multiarray/methods.c#L702
  455. if args == ():
  456. return self.tensor.item()
  457. elif len(args) == 1:
  458. # int argument
  459. return self.ravel()[args[0]]
  460. else:
  461. return self.__getitem__(args)
  462. def __getitem__(self, index):
  463. tensor = self.tensor
  464. def neg_step(i, s):
  465. if not (isinstance(s, slice) and s.step is not None and s.step < 0):
  466. return s
  467. nonlocal tensor
  468. tensor = torch.flip(tensor, (i,))
  469. # Account for the fact that a slice includes the start but not the end
  470. assert isinstance(s.start, int) or s.start is None
  471. assert isinstance(s.stop, int) or s.stop is None
  472. start = s.stop + 1 if s.stop else None
  473. stop = s.start + 1 if s.start else None
  474. return slice(start, stop, -s.step)
  475. if isinstance(index, Sequence):
  476. index = type(index)(neg_step(i, s) for i, s in enumerate(index))
  477. else:
  478. index = neg_step(0, index)
  479. index = _util.ndarrays_to_tensors(index)
  480. index = _upcast_int_indices(index)
  481. # Apply NumPy-compatible indexing conversion
  482. index = _numpy_compatible_indexing(index)
  483. # Apply NumPy-compatible empty ellipsis behavior
  484. index, maybe_squeeze, _ = _numpy_empty_ellipsis_patch(index, tensor.ndim)
  485. return maybe_squeeze(ndarray(tensor.__getitem__(index)))
  486. def __setitem__(self, index, value):
  487. index = _util.ndarrays_to_tensors(index)
  488. index = _upcast_int_indices(index)
  489. # Apply NumPy-compatible indexing conversion
  490. index = _numpy_compatible_indexing(index)
  491. # Apply NumPy-compatible empty ellipsis behavior
  492. index, _, maybe_unsqueeze = _numpy_empty_ellipsis_patch(index, self.tensor.ndim)
  493. if not _dtypes_impl.is_scalar(value):
  494. value = normalize_array_like(value)
  495. value = _util.cast_if_needed(value, self.tensor.dtype)
  496. return self.tensor.__setitem__(index, maybe_unsqueeze(value))
  497. take = _funcs.take
  498. put = _funcs.put
  499. def __dlpack__(self, *, stream=None):
  500. return self.tensor.__dlpack__(stream=stream)
  501. def __dlpack_device__(self):
  502. return self.tensor.__dlpack_device__()
  503. def _tolist(obj):
  504. """Recursively convert tensors into lists."""
  505. a1 = []
  506. for elem in obj:
  507. if isinstance(elem, (list, tuple)):
  508. elem = _tolist(elem)
  509. if isinstance(elem, ndarray):
  510. a1.append(elem.tensor.tolist())
  511. else:
  512. a1.append(elem)
  513. return a1
  514. # This is the ideally the only place which talks to ndarray directly.
  515. # The rest goes through asarray (preferred) or array.
  516. def array(obj, dtype=None, *, copy=True, order="K", subok=False, ndmin=0, like=None):
  517. if subok is not False:
  518. raise NotImplementedError("'subok' parameter is not supported.")
  519. if like is not None:
  520. raise NotImplementedError("'like' parameter is not supported.")
  521. if order != "K":
  522. raise NotImplementedError
  523. # a happy path
  524. if (
  525. isinstance(obj, ndarray)
  526. and copy is False
  527. and dtype is None
  528. and ndmin <= obj.ndim
  529. ):
  530. return obj
  531. if isinstance(obj, (list, tuple)):
  532. # FIXME and they have the same dtype, device, etc
  533. if obj and all(isinstance(x, torch.Tensor) for x in obj):
  534. # list of arrays: *under torch.Dynamo* these are FakeTensors
  535. obj = torch.stack(obj)
  536. else:
  537. # XXX: remove tolist
  538. # lists of ndarrays: [1, [2, 3], ndarray(4)] convert to lists of lists
  539. obj = _tolist(obj)
  540. # is obj an ndarray already?
  541. if isinstance(obj, ndarray):
  542. obj = obj.tensor
  543. # is a specific dtype requested?
  544. torch_dtype = None
  545. if dtype is not None:
  546. torch_dtype = _dtypes.dtype(dtype).torch_dtype
  547. tensor = _util._coerce_to_tensor(obj, torch_dtype, copy, ndmin)
  548. return ndarray(tensor)
  549. def asarray(a, dtype=None, order="K", *, like=None):
  550. return array(a, dtype=dtype, order=order, like=like, copy=False, ndmin=0)
  551. def ascontiguousarray(a, dtype=None, *, like=None):
  552. arr = asarray(a, dtype=dtype, like=like)
  553. if not arr.tensor.is_contiguous():
  554. arr.tensor = arr.tensor.contiguous()
  555. return arr
  556. def from_dlpack(x, /):
  557. t = torch.from_dlpack(x)
  558. return ndarray(t)
  559. def _extract_dtype(entry):
  560. try:
  561. dty = _dtypes.dtype(entry)
  562. except Exception:
  563. dty = asarray(entry).dtype
  564. return dty
  565. def can_cast(from_, to, casting="safe"):
  566. from_ = _extract_dtype(from_)
  567. to_ = _extract_dtype(to)
  568. return _dtypes_impl.can_cast_impl(from_.torch_dtype, to_.torch_dtype, casting)
  569. def result_type(*arrays_and_dtypes):
  570. tensors = []
  571. for entry in arrays_and_dtypes:
  572. try:
  573. t = asarray(entry).tensor
  574. except (RuntimeError, ValueError, TypeError):
  575. dty = _dtypes.dtype(entry)
  576. t = torch.empty(1, dtype=dty.torch_dtype)
  577. tensors.append(t)
  578. torch_dtype = _dtypes_impl.result_type_impl(*tensors)
  579. return _dtypes.dtype(torch_dtype)