_array_object.py 43 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
370470570670770870971071171271371471571671771871972072172272372472572672772872973073173273373473573673773873974074174274374474574674774874975075175275375475575675775875976076176276376476576676776876977077177277377477577677777877978078178278378478578678778878979079179279379479579679779879980080180280380480580680780880981081181281381481581681781881982082182282382482582682782882983083183283383483583683783883984084184284384484584684784884985085185285385485585685785885986086186286386486586686786886987087187287387487587687787887988088188288388488588688788888989089189289389489589689789889990090190290390490590690790890991091191291391491591691791891992092192292392492592692792892993093193293393493593693793893994094194294394494594694794894995095195295395495595695795895996096196296396496596696796896997097197297397497597697797897998098198298398498598698798898999099199299399499599699799899910001001100210031004100510061007100810091010101110121013101410151016101710181019102010211022102310241025102610271028102910301031103210331034103510361037103810391040104110421043104410451046104710481049105010511052105310541055105610571058105910601061106210631064106510661067106810691070107110721073107410751076107710781079108010811082108310841085108610871088108910901091109210931094109510961097109810991100110111021103110411051106110711081109111011111112111311141115111611171118111911201121112211231124112511261127112811291130113111321133
  1. """
  2. Wrapper class around the ndarray object for the array API standard.
  3. The array API standard defines some behaviors differently than ndarray, in
  4. particular, type promotion rules are different (the standard has no
  5. value-based casting). The standard also specifies a more limited subset of
  6. array methods and functionalities than are implemented on ndarray. Since the
  7. goal of the array_api namespace is to be a minimal implementation of the array
  8. API standard, we need to define a separate wrapper class for the array_api
  9. namespace.
  10. The standard compliant class is only a wrapper class. It is *not* a subclass
  11. of ndarray.
  12. """
  13. from __future__ import annotations
  14. import operator
  15. from enum import IntEnum
  16. from ._creation_functions import asarray
  17. from ._dtypes import (
  18. _all_dtypes,
  19. _boolean_dtypes,
  20. _integer_dtypes,
  21. _integer_or_boolean_dtypes,
  22. _floating_dtypes,
  23. _complex_floating_dtypes,
  24. _numeric_dtypes,
  25. _result_type,
  26. _dtype_categories,
  27. )
  28. from typing import TYPE_CHECKING, Optional, Tuple, Union, Any, SupportsIndex
  29. import types
  30. if TYPE_CHECKING:
  31. from ._typing import Any, PyCapsule, Device, Dtype
  32. import numpy.typing as npt
  33. import numpy as np
  34. from numpy import array_api
  35. class Array:
  36. """
  37. n-d array object for the array API namespace.
  38. See the docstring of :py:obj:`np.ndarray <numpy.ndarray>` for more
  39. information.
  40. This is a wrapper around numpy.ndarray that restricts the usage to only
  41. those things that are required by the array API namespace. Note,
  42. attributes on this object that start with a single underscore are not part
  43. of the API specification and should only be used internally. This object
  44. should not be constructed directly. Rather, use one of the creation
  45. functions, such as asarray().
  46. """
  47. _array: np.ndarray[Any, Any]
  48. # Use a custom constructor instead of __init__, as manually initializing
  49. # this class is not supported API.
  @classmethod
  def _new(cls, x, /):
      """
      Private constructor that wraps ``x`` in an array API Array object.

      ``x`` is stored on the ``_array`` attribute. NumPy scalars
      (``np.generic``) are converted to 0-D arrays first, because the spec
      has no array scalars, only 0-D arrays.

      Functions outside of the array_api submodule should not call this
      method; they should use a creation function such as ``asarray``.

      Raises ``TypeError`` when the dtype of ``x`` is not in ``_all_dtypes``.
      """
      # Array scalars are wrapped into 0-D arrays before anything else.
      data = np.asarray(x) if isinstance(x, np.generic) else x
      if data.dtype not in _all_dtypes:
          raise TypeError(
              f"The array_api namespace does not support the dtype '{data.dtype}'"
          )
      # Go through super() because Array.__new__ itself always raises.
      instance = super().__new__(cls)
      instance._array = data
      return instance
  70. # Prevent Array() from working
  71. def __new__(cls, *args, **kwargs):
  72. raise TypeError(
  73. "The array_api Array object should not be instantiated directly. Use an array creation function, such as asarray(), instead."
  74. )
  75. # These functions are not required by the spec, but are implemented for
  76. # the sake of usability.
  77. def __str__(self: Array, /) -> str:
  78. """
  79. Performs the operation __str__.
  80. """
  81. return self._array.__str__().replace("array", "Array")
  82. def __repr__(self: Array, /) -> str:
  83. """
  84. Performs the operation __repr__.
  85. """
  86. suffix = f", dtype={self.dtype.name})"
  87. if 0 in self.shape:
  88. prefix = "empty("
  89. mid = str(self.shape)
  90. else:
  91. prefix = "Array("
  92. mid = np.array2string(self._array, separator=', ', prefix=prefix, suffix=suffix)
  93. return prefix + mid + suffix
  94. # This function is not required by the spec, but we implement it here for
  95. # convenience so that np.asarray(np.array_api.Array) will work.
  96. def __array__(self, dtype: None | np.dtype[Any] = None) -> npt.NDArray[Any]:
  97. """
  98. Warning: this method is NOT part of the array API spec. Implementers
  99. of other libraries need not include it, and users should not assume it
  100. will be present in other implementations.
  101. """
  102. return np.asarray(self._array, dtype=dtype)
  103. # These are various helper functions to make the array behavior match the
  104. # spec in places where it either deviates from or is more strict than
  105. # NumPy behavior
  def _check_allowed_dtypes(self, other: bool | int | float | Array, dtype_category: str, op: str) -> Array:
      """
      Validate the operands of operator ``op`` against ``dtype_category``.

      Python scalars are promoted with ``_promote_scalar``; Array operands
      must have a dtype in the category; anything else yields
      ``NotImplemented``. Typical use:

          other = self._check_allowed_dtypes(other, 'numeric', '__add__')
          if other is NotImplemented:
              return other

      Raises ``TypeError`` when either dtype is outside the category, when
      ``_result_type`` rejects the pair, or when ``op`` is an in-place
      operator whose promoted dtype differs from ``self.dtype``.
      """
      allowed = _dtype_categories[dtype_category]
      if self.dtype not in allowed:
          raise TypeError(f"Only {dtype_category} dtypes are allowed in {op}")

      if isinstance(other, (int, complex, float, bool)):
          other = self._promote_scalar(other)
      elif not isinstance(other, Array):
          return NotImplemented
      elif other.dtype not in allowed:
          raise TypeError(f"Only {dtype_category} dtypes are allowed in {op}")

      # _result_type raises TypeError for combinations the spec does not
      # allow to promote, even where the NumPy operator would promote them.
      promoted = _result_type(self.dtype, other.dtype)

      # NumPy accepts some in-place operations whose promoted dtype differs
      # from the left-hand operand, e.g.
      #     >>> a = np.array(1, dtype=np.int8)
      #     >>> a += np.array(1, dtype=np.int16)
      # The spec explicitly disallows this.
      if op.startswith("__i") and promoted != self.dtype:
          raise TypeError(
              f"Cannot perform {op} with dtypes {self.dtype} and {other.dtype}"
          )
      return other
  139. # Helper function to match the type promotion rules in the spec
  def _promote_scalar(self, scalar):
      """
      Wrap the Python scalar ``scalar`` in a 0-D Array with the dtype of
      ``self``.

      Only scalar types that match the kind of ``self.dtype`` are accepted:
      bool with boolean arrays, int with any non-boolean array (it must fit
      the dtype's bounds when the array is integer), float with
      floating-point arrays, and complex with complex floating-point arrays.

      Raises ``TypeError`` for a mismatched or non-scalar input and
      ``OverflowError`` for an int outside the bounds of an integer dtype.
      """
      if isinstance(scalar, bool):
          # bool is tested before int because bool is a subclass of int.
          if self.dtype not in _boolean_dtypes:
              raise TypeError(
                  "Python bool scalars can only be promoted with bool arrays"
              )
      elif isinstance(scalar, int):
          if self.dtype in _boolean_dtypes:
              raise TypeError(
                  "Python int scalars cannot be promoted with bool arrays"
              )
          if self.dtype in _integer_dtypes:
              bounds = np.iinfo(self.dtype)
              if scalar < bounds.min or scalar > bounds.max:
                  raise OverflowError(
                      "Python int scalars must be within the bounds of the dtype for integer arrays"
                  )
          # An int combined with a floating-point array is allowed.
      elif isinstance(scalar, float):
          if self.dtype not in _floating_dtypes:
              raise TypeError(
                  "Python float scalars can only be promoted with floating-point arrays."
              )
      elif isinstance(scalar, complex):
          if self.dtype not in _complex_floating_dtypes:
              raise TypeError(
                  "Python complex scalars can only be promoted with complex floating-point arrays."
              )
      else:
          raise TypeError("'scalar' must be a Python scalar")

      # The scalar is always cast to the dtype of the array.
      promoted = np.array(scalar, self.dtype)
      return Array._new(promoted)
  @staticmethod
  def _normalize_two_args(x1, x2) -> Tuple[Array, Array]:
      """
      Prepare the operands of a two-argument function so NumPy follows the
      spec's type promotion rules.

      NumPy deviates from the spec when exactly one operand is
      0-dimensional:

      >>> import numpy as np
      >>> a = np.array([1.0], dtype=np.float32)
      >>> b = np.array(1.0, dtype=np.float64)
      >>> np.add(a, b) # The spec says this should be float64
      array([2.], dtype=float32)

      When exactly one operand is 0-D, it is re-wrapped with one extra
      leading dimension. Broadcasting would add that dimension anyway, so
      the result shape is the same, but NumPy then promotes the dtype.
      Otherwise both operands are returned unchanged.
      """
      # Passing signature=(x1.dtype, x2.dtype, None) would be an
      # alternative, but it only works for ufuncs. This trick would not work
      # for functions without normal broadcasting (e.g. searchsorted), but
      # the array API namespace has none.
      lhs_is_0d = x1.ndim == 0
      rhs_is_0d = x2.ndim == 0
      if lhs_is_0d and not rhs_is_0d:
          # _array[None] was chosen because it is relatively performant;
          # broadcast_to(x1._array, x2.shape) is much slower.
          x1 = Array._new(x1._array[None])
      elif rhs_is_0d and not lhs_is_0d:
          x2 = Array._new(x2._array[None])
      return (x1, x2)
  217. # Note: A large fraction of allowed indices are disallowed here (see the
  218. # docstring below)
  def _validate_index(self, key):
      """
      Validate an index according to the array API.

      The array API specification only requires a subset of indices that are
      supported by NumPy. This function will reject any index that is
      allowed by NumPy but not required by the array API specification. We
      always raise ``IndexError`` on such indices (the spec does not require
      any specific behavior on them, but this makes the NumPy array API
      namespace a minimal implementation of the spec). See
      https://data-apis.org/array-api/latest/API_specification/indexing.html
      for the full list of required indexing behavior

      This function raises IndexError if the index ``key`` is invalid. It
      only raises ``IndexError`` on indices that are not already rejected by
      NumPy, as NumPy will already raise the appropriate error on such
      indices. It returns ``None`` when the index passes.

      The following cases are allowed by NumPy, but not specified by the array
      API specification:

      - Indices that do not include an implicit ellipsis at the end. That
        is, every axis of an array must be explicitly indexed or an ellipsis
        included. This behaviour is sometimes referred to as flat indexing.

      - The start and stop of a slice may not be out of bounds. In
        particular, for a slice ``i:j:k`` on an axis of size ``n``, only the
        following are allowed:

        - ``i`` or ``j`` omitted (``None``).
        - ``-n <= i <= max(0, n - 1)``.
        - For ``k > 0`` or ``k`` omitted (``None``), ``-n <= j <= n``.
        - For ``k < 0``, ``-n - 1 <= j <= max(0, n - 1)``.

      - Boolean array indices are not allowed as part of a larger tuple
        index.

      - Integer array indices are not allowed (with the exception of 0-D
        arrays, which are treated the same as scalars).

      Additionally, it should be noted that indices that would return a
      scalar in NumPy will return a 0-D array. Array scalars are not allowed
      in the specification, only 0-D arrays. This is done in the
      ``Array._new`` constructor, not this function.
      """
      # Treat a non-tuple key as a 1-tuple so both forms share one code path.
      _key = key if isinstance(key, tuple) else (key,)

      # Pass 1: reject entries whose type is not specified at all.
      for i in _key:
          if isinstance(i, bool) or not (
              isinstance(i, SupportsIndex)  # i.e. ints
              or isinstance(i, slice)
              or i == Ellipsis
              or i is None
              or isinstance(i, Array)
              or isinstance(i, np.ndarray)
          ):
              raise IndexError(
                  f"Single-axes index {i} has {type(i)=}, but only "
                  "integers, slices (:), ellipsis (...), newaxis (None), "
                  "zero-dimensional integer arrays and boolean arrays "
                  "are specified in the Array API."
              )

      # Pass 2: classify the entries.
      #   nonexpanding_key - every entry except None (newaxis)
      #   single_axes      - entries other than None and Ellipsis
      #   n_ellipsis       - number of Ellipsis entries
      #   key_has_mask     - whether any array entry has a boolean dtype
      nonexpanding_key = []
      single_axes = []
      n_ellipsis = 0
      key_has_mask = False
      for i in _key:
          if i is not None:
              nonexpanding_key.append(i)
              if isinstance(i, Array) or isinstance(i, np.ndarray):
                  if i.dtype in _boolean_dtypes:
                      key_has_mask = True
                  single_axes.append(i)
              else:
                  # i must not be an array here, to avoid elementwise equals
                  if i == Ellipsis:
                      n_ellipsis += 1
                  else:
                      single_axes.append(i)

      n_single_axes = len(single_axes)
      if n_ellipsis > 1:
          return  # handled by ndarray
      elif n_ellipsis == 0:
          # Note boolean masks must be the sole index, which we check for
          # later on.
          if not key_has_mask and n_single_axes < self.ndim:
              raise IndexError(
                  f"{self.ndim=}, but the multi-axes index only specifies "
                  f"{n_single_axes} dimensions. If this was intentional, "
                  "add a trailing ellipsis (...) which expands into as many "
                  "slices (:) as necessary - this is what np.ndarray arrays "
                  "implicitly do, but such flat indexing behaviour is not "
                  "specified in the Array API."
              )

      # Build the sequence of axis sizes that lines up with single_axes:
      # the full shape when there is no ellipsis, otherwise the shape with
      # the axes covered by the ellipsis removed.
      if n_ellipsis == 0:
          indexed_shape = self.shape
      else:
          ellipsis_start = None
          for pos, i in enumerate(nonexpanding_key):
              # Arrays are skipped so that == is never applied elementwise.
              if not (isinstance(i, Array) or isinstance(i, np.ndarray)):
                  if i == Ellipsis:
                      ellipsis_start = pos
                      break
          assert ellipsis_start is not None  # sanity check
          ellipsis_end = self.ndim - (n_single_axes - ellipsis_start)
          indexed_shape = (
              self.shape[:ellipsis_start] + self.shape[ellipsis_end:]
          )

      # Pass 3: check each single-axis entry against the size of its axis.
      for i, side in zip(single_axes, indexed_shape):
          if isinstance(i, slice):
              # f_range is only used to word the error messages below.
              if side == 0:
                  f_range = "0 (or None)"
              else:
                  f_range = f"between -{side} and {side - 1} (or None)"
              # NOTE(review): both checks below accept -side..side inclusive
              # and ignore the slice step, whereas the docstring and f_range
              # describe a start bound of side - 1 and a separate stop range
              # for negative steps - confirm which is intended.
              if i.start is not None:
                  try:
                      start = operator.index(i.start)
                  except TypeError:
                      pass  # handled by ndarray
                  else:
                      if not (-side <= start <= side):
                          raise IndexError(
                              f"Slice {i} contains {start=}, but should be "
                              f"{f_range} for an axis of size {side} "
                              "(out-of-bounds starts are not specified in "
                              "the Array API)"
                          )
              if i.stop is not None:
                  try:
                      stop = operator.index(i.stop)
                  except TypeError:
                      pass  # handled by ndarray
                  else:
                      if not (-side <= stop <= side):
                          raise IndexError(
                              f"Slice {i} contains {stop=}, but should be "
                              f"{f_range} for an axis of size {side} "
                              "(out-of-bounds stops are not specified in "
                              "the Array API)"
                          )
          # NOTE(review): only Array entries are checked here; np.ndarray
          # entries were accepted by the earlier passes but are not covered
          # by this branch - confirm that is intended.
          elif isinstance(i, Array):
              if i.dtype in _boolean_dtypes and len(_key) != 1:
                  assert isinstance(key, tuple)  # sanity check
                  raise IndexError(
                      f"Single-axes index {i} is a boolean array and "
                      f"{len(key)=}, but masking is only specified in the "
                      "Array API when the array is the sole index."
                  )
              elif i.dtype in _integer_dtypes and i.ndim != 0:
                  raise IndexError(
                      f"Single-axes index {i} is a non-zero-dimensional "
                      "integer array, but advanced integer indexing is not "
                      "specified in the Array API."
                  )
          elif isinstance(i, tuple):
              raise IndexError(
                  f"Single-axes index {i} is a tuple, but nested tuple "
                  "indices are not specified in the Array API."
              )
  369. # Everything below this line is required by the spec.
  def __abs__(self: Array, /) -> Array:
      """
      Implement ``abs(self)``; only numeric dtypes are accepted.
      """
      if self.dtype not in _numeric_dtypes:
          raise TypeError("Only numeric dtypes are allowed in __abs__")
      return self.__class__._new(self._array.__abs__())
  def __add__(self: Array, other: Union[int, float, Array], /) -> Array:
      """
      Implement ``self + other`` for operands in the "numeric" category.
      """
      rhs = self._check_allowed_dtypes(other, "numeric", "__add__")
      if rhs is NotImplemented:
          return NotImplemented
      lhs, rhs = self._normalize_two_args(self, rhs)
      return lhs.__class__._new(lhs._array.__add__(rhs._array))
  def __and__(self: Array, other: Union[int, bool, Array], /) -> Array:
      """
      Implement ``self & other`` for operands in the "integer or boolean"
      category.
      """
      rhs = self._check_allowed_dtypes(other, "integer or boolean", "__and__")
      if rhs is NotImplemented:
          return NotImplemented
      lhs, rhs = self._normalize_two_args(self, rhs)
      return lhs.__class__._new(lhs._array.__and__(rhs._array))
  398. def __array_namespace__(
  399. self: Array, /, *, api_version: Optional[str] = None
  400. ) -> types.ModuleType:
  401. if api_version is not None and not api_version.startswith("2021."):
  402. raise ValueError(f"Unrecognized array API version: {api_version!r}")
  403. return array_api
  404. def __bool__(self: Array, /) -> bool:
  405. """
  406. Performs the operation __bool__.
  407. """
  408. # Note: This is an error here.
  409. if self._array.ndim != 0:
  410. raise TypeError("bool is only allowed on arrays with 0 dimensions")
  411. res = self._array.__bool__()
  412. return res
  413. def __complex__(self: Array, /) -> complex:
  414. """
  415. Performs the operation __complex__.
  416. """
  417. # Note: This is an error here.
  418. if self._array.ndim != 0:
  419. raise TypeError("complex is only allowed on arrays with 0 dimensions")
  420. res = self._array.__complex__()
  421. return res
  422. def __dlpack__(self: Array, /, *, stream: None = None) -> PyCapsule:
  423. """
  424. Performs the operation __dlpack__.
  425. """
  426. return self._array.__dlpack__(stream=stream)
  427. def __dlpack_device__(self: Array, /) -> Tuple[IntEnum, int]:
  428. """
  429. Performs the operation __dlpack_device__.
  430. """
  431. # Note: device support is required for this
  432. return self._array.__dlpack_device__()
  def __eq__(self: Array, other: Union[int, float, bool, Array], /) -> Array:
      """
      Implement ``self == other`` for operands in the "all" category.
      """
      # Every dtype is allowed, but the two dtypes must still be promotable
      # with each other; _check_allowed_dtypes enforces that.
      rhs = self._check_allowed_dtypes(other, "all", "__eq__")
      if rhs is NotImplemented:
          return NotImplemented
      lhs, rhs = self._normalize_two_args(self, rhs)
      return lhs.__class__._new(lhs._array.__eq__(rhs._array))
  def __float__(self: Array, /) -> float:
      """
      Implement ``float(self)``.

      Raises ``TypeError`` unless the array is 0-D, or when its dtype is a
      complex floating-point dtype.
      """
      # Note: a non-0-D array is an error here.
      if self._array.ndim != 0:
          raise TypeError("float is only allowed on arrays with 0 dimensions")
      if self.dtype in _complex_floating_dtypes:
          raise TypeError("float is not allowed on complex floating-point arrays")
      return self._array.__float__()
  def __floordiv__(self: Array, other: Union[int, float, Array], /) -> Array:
      """
      Implement ``self // other`` for operands in the "real numeric"
      category.
      """
      rhs = self._check_allowed_dtypes(other, "real numeric", "__floordiv__")
      if rhs is NotImplemented:
          return NotImplemented
      lhs, rhs = self._normalize_two_args(self, rhs)
      return lhs.__class__._new(lhs._array.__floordiv__(rhs._array))
  def __ge__(self: Array, other: Union[int, float, Array], /) -> Array:
      """
      Implement ``self >= other`` for operands in the "real numeric"
      category.
      """
      rhs = self._check_allowed_dtypes(other, "real numeric", "__ge__")
      if rhs is NotImplemented:
          return NotImplemented
      lhs, rhs = self._normalize_two_args(self, rhs)
      return lhs.__class__._new(lhs._array.__ge__(rhs._array))
  def __getitem__(
      self: Array,
      key: Union[
          int,
          slice,
          ellipsis,
          Tuple[Union[int, slice, ellipsis, None], ...],
          Array,
      ],
      /,
  ) -> Array:
      """
      Implement ``self[key]``.

      ``key`` is first checked by ``_validate_index``, which only admits the
      indices required by the spec (see its docstring).
      """
      self._validate_index(key)
      # Indexing self._array with array_api arrays can be erroneous, so a
      # sole Array key is unwrapped to its ndarray first.
      np_key = key._array if isinstance(key, Array) else key
      return self._new(self._array.__getitem__(np_key))
  def __gt__(self: Array, other: Union[int, float, Array], /) -> Array:
      """
      Implement ``self > other`` for operands in the "real numeric"
      category.
      """
      rhs = self._check_allowed_dtypes(other, "real numeric", "__gt__")
      if rhs is NotImplemented:
          return NotImplemented
      lhs, rhs = self._normalize_two_args(self, rhs)
      return lhs.__class__._new(lhs._array.__gt__(rhs._array))
  def __int__(self: Array, /) -> int:
      """
      Implement ``int(self)``.

      Raises ``TypeError`` unless the array is 0-D, or when its dtype is a
      complex floating-point dtype.
      """
      # Note: a non-0-D array is an error here.
      if self._array.ndim != 0:
          raise TypeError("int is only allowed on arrays with 0 dimensions")
      if self.dtype in _complex_floating_dtypes:
          raise TypeError("int is not allowed on complex floating-point arrays")
      return self._array.__int__()
  519. def __index__(self: Array, /) -> int:
  520. """
  521. Performs the operation __index__.
  522. """
  523. res = self._array.__index__()
  524. return res
  def __invert__(self: Array, /) -> Array:
      """
      Implement ``~self``; only integer or boolean dtypes are accepted.
      """
      if self.dtype not in _integer_or_boolean_dtypes:
          raise TypeError("Only integer or boolean dtypes are allowed in __invert__")
      return self.__class__._new(self._array.__invert__())
  def __le__(self: Array, other: Union[int, float, Array], /) -> Array:
      """
      Implement ``self <= other`` for operands in the "real numeric"
      category.
      """
      rhs = self._check_allowed_dtypes(other, "real numeric", "__le__")
      if rhs is NotImplemented:
          return NotImplemented
      lhs, rhs = self._normalize_two_args(self, rhs)
      return lhs.__class__._new(lhs._array.__le__(rhs._array))
  def __lshift__(self: Array, other: Union[int, Array], /) -> Array:
      """
      Implement ``self << other`` for operands in the "integer" category.
      """
      rhs = self._check_allowed_dtypes(other, "integer", "__lshift__")
      if rhs is NotImplemented:
          return NotImplemented
      lhs, rhs = self._normalize_two_args(self, rhs)
      return lhs.__class__._new(lhs._array.__lshift__(rhs._array))
  def __lt__(self: Array, other: Union[int, float, Array], /) -> Array:
      """
      Implement ``self < other`` for operands in the "real numeric"
      category.
      """
      rhs = self._check_allowed_dtypes(other, "real numeric", "__lt__")
      if rhs is NotImplemented:
          return NotImplemented
      lhs, rhs = self._normalize_two_args(self, rhs)
      return lhs.__class__._new(lhs._array.__lt__(rhs._array))
  def __matmul__(self: Array, other: Array, /) -> Array:
      """
      Implement ``self @ other`` for operands in the "numeric" category.
      """
      # matmul is not defined for scalars, but without this check we may
      # get the wrong error message from asarray.
      rhs = self._check_allowed_dtypes(other, "numeric", "__matmul__")
      if rhs is NotImplemented:
          return NotImplemented
      # Unlike the elementwise operators, the operands are not passed
      # through _normalize_two_args here.
      return self.__class__._new(self._array.__matmul__(rhs._array))
  def __mod__(self: Array, other: Union[int, float, Array], /) -> Array:
      """
      Implement ``self % other`` for operands in the "real numeric"
      category.
      """
      rhs = self._check_allowed_dtypes(other, "real numeric", "__mod__")
      if rhs is NotImplemented:
          return NotImplemented
      lhs, rhs = self._normalize_two_args(self, rhs)
      return lhs.__class__._new(lhs._array.__mod__(rhs._array))
  def __mul__(self: Array, other: Union[int, float, Array], /) -> Array:
      """
      Implement ``self * other`` for operands in the "numeric" category.
      """
      rhs = self._check_allowed_dtypes(other, "numeric", "__mul__")
      if rhs is NotImplemented:
          return NotImplemented
      lhs, rhs = self._normalize_two_args(self, rhs)
      return lhs.__class__._new(lhs._array.__mul__(rhs._array))
  def __ne__(self: Array, other: Union[int, float, bool, Array], /) -> Array:
      """
      Implement ``self != other`` for operands in the "all" category.
      """
      rhs = self._check_allowed_dtypes(other, "all", "__ne__")
      if rhs is NotImplemented:
          return NotImplemented
      lhs, rhs = self._normalize_two_args(self, rhs)
      return lhs.__class__._new(lhs._array.__ne__(rhs._array))
  def __neg__(self: Array, /) -> Array:
      """
      Implement ``-self``; only numeric dtypes are accepted.
      """
      if self.dtype not in _numeric_dtypes:
          raise TypeError("Only numeric dtypes are allowed in __neg__")
      return self.__class__._new(self._array.__neg__())
  def __or__(self: Array, other: Union[int, bool, Array], /) -> Array:
      """
      Implement ``self | other`` for operands in the "integer or boolean"
      category.
      """
      rhs = self._check_allowed_dtypes(other, "integer or boolean", "__or__")
      if rhs is NotImplemented:
          return NotImplemented
      lhs, rhs = self._normalize_two_args(self, rhs)
      return lhs.__class__._new(lhs._array.__or__(rhs._array))
  def __pos__(self: Array, /) -> Array:
      """
      Implement ``+self``; only numeric dtypes are accepted.
      """
      if self.dtype not in _numeric_dtypes:
          raise TypeError("Only numeric dtypes are allowed in __pos__")
      return self.__class__._new(self._array.__pos__())
  def __pow__(self: Array, other: Union[int, float, Array], /) -> Array:
      """
      Implement ``self ** other`` for operands in the "numeric" category.
      """
      # Imported inside the method, as in the original code.
      from ._elementwise_functions import pow

      exponent = self._check_allowed_dtypes(other, "numeric", "__pow__")
      if exponent is NotImplemented:
          return NotImplemented
      # Note: NumPy's __pow__ does not follow type promotion rules for 0-d
      # arrays, so the namespace-level pow() is used instead.
      return pow(self, exponent)
  def __rshift__(self: Array, other: Union[int, Array], /) -> Array:
      """
      Implement ``self >> other`` for operands in the "integer" category.
      """
      rhs = self._check_allowed_dtypes(other, "integer", "__rshift__")
      if rhs is NotImplemented:
          return NotImplemented
      lhs, rhs = self._normalize_two_args(self, rhs)
      return lhs.__class__._new(lhs._array.__rshift__(rhs._array))
  def __setitem__(
      self,
      key: Union[
          int, slice, ellipsis, Tuple[Union[int, slice, ellipsis], ...], Array
      ],
      value: Union[int, float, bool, Array],
      /,
  ) -> None:
      """
      Implement ``self[key] = value``.

      ``key`` is first checked by ``_validate_index``, which only admits the
      indices required by the spec (see its docstring). ``value`` is
      converted with ``asarray`` before being stored into the wrapped
      ndarray.
      """
      self._validate_index(key)
      # Indexing self._array with array_api arrays can be erroneous, so a
      # sole Array key is unwrapped to its ndarray first.
      np_key = key._array if isinstance(key, Array) else key
      self._array.__setitem__(np_key, asarray(value)._array)
  def __sub__(self: Array, other: Union[int, float, Array], /) -> Array:
      """
      Implement ``self - other`` for operands in the "numeric" category.
      """
      rhs = self._check_allowed_dtypes(other, "numeric", "__sub__")
      if rhs is NotImplemented:
          return NotImplemented
      lhs, rhs = self._normalize_two_args(self, rhs)
      return lhs.__class__._new(lhs._array.__sub__(rhs._array))
  679. # PEP 484 requires int to be a subtype of float, but __truediv__ should
  680. # not accept int.
  def __truediv__(self: Array, other: Union[float, Array], /) -> Array:
      """
      Implement ``self / other`` for operands in the "floating-point"
      category.
      """
      rhs = self._check_allowed_dtypes(other, "floating-point", "__truediv__")
      if rhs is NotImplemented:
          return NotImplemented
      lhs, rhs = self._normalize_two_args(self, rhs)
      return lhs.__class__._new(lhs._array.__truediv__(rhs._array))
  691. def __xor__(self: Array, other: Union[int, bool, Array], /) -> Array:
  692. """
  693. Performs the operation __xor__.
  694. """
  695. other = self._check_allowed_dtypes(other, "integer or boolean", "__xor__")
  696. if other is NotImplemented:
  697. return other
  698. self, other = self._normalize_two_args(self, other)
  699. res = self._array.__xor__(other._array)
  700. return self.__class__._new(res)
  701. def __iadd__(self: Array, other: Union[int, float, Array], /) -> Array:
  702. """
  703. Performs the operation __iadd__.
  704. """
  705. other = self._check_allowed_dtypes(other, "numeric", "__iadd__")
  706. if other is NotImplemented:
  707. return other
  708. self._array.__iadd__(other._array)
  709. return self
  710. def __radd__(self: Array, other: Union[int, float, Array], /) -> Array:
  711. """
  712. Performs the operation __radd__.
  713. """
  714. other = self._check_allowed_dtypes(other, "numeric", "__radd__")
  715. if other is NotImplemented:
  716. return other
  717. self, other = self._normalize_two_args(self, other)
  718. res = self._array.__radd__(other._array)
  719. return self.__class__._new(res)
  720. def __iand__(self: Array, other: Union[int, bool, Array], /) -> Array:
  721. """
  722. Performs the operation __iand__.
  723. """
  724. other = self._check_allowed_dtypes(other, "integer or boolean", "__iand__")
  725. if other is NotImplemented:
  726. return other
  727. self._array.__iand__(other._array)
  728. return self
  729. def __rand__(self: Array, other: Union[int, bool, Array], /) -> Array:
  730. """
  731. Performs the operation __rand__.
  732. """
  733. other = self._check_allowed_dtypes(other, "integer or boolean", "__rand__")
  734. if other is NotImplemented:
  735. return other
  736. self, other = self._normalize_two_args(self, other)
  737. res = self._array.__rand__(other._array)
  738. return self.__class__._new(res)
  739. def __ifloordiv__(self: Array, other: Union[int, float, Array], /) -> Array:
  740. """
  741. Performs the operation __ifloordiv__.
  742. """
  743. other = self._check_allowed_dtypes(other, "real numeric", "__ifloordiv__")
  744. if other is NotImplemented:
  745. return other
  746. self._array.__ifloordiv__(other._array)
  747. return self
  748. def __rfloordiv__(self: Array, other: Union[int, float, Array], /) -> Array:
  749. """
  750. Performs the operation __rfloordiv__.
  751. """
  752. other = self._check_allowed_dtypes(other, "real numeric", "__rfloordiv__")
  753. if other is NotImplemented:
  754. return other
  755. self, other = self._normalize_two_args(self, other)
  756. res = self._array.__rfloordiv__(other._array)
  757. return self.__class__._new(res)
  758. def __ilshift__(self: Array, other: Union[int, Array], /) -> Array:
  759. """
  760. Performs the operation __ilshift__.
  761. """
  762. other = self._check_allowed_dtypes(other, "integer", "__ilshift__")
  763. if other is NotImplemented:
  764. return other
  765. self._array.__ilshift__(other._array)
  766. return self
  767. def __rlshift__(self: Array, other: Union[int, Array], /) -> Array:
  768. """
  769. Performs the operation __rlshift__.
  770. """
  771. other = self._check_allowed_dtypes(other, "integer", "__rlshift__")
  772. if other is NotImplemented:
  773. return other
  774. self, other = self._normalize_two_args(self, other)
  775. res = self._array.__rlshift__(other._array)
  776. return self.__class__._new(res)
  777. def __imatmul__(self: Array, other: Array, /) -> Array:
  778. """
  779. Performs the operation __imatmul__.
  780. """
  781. # matmul is not defined for scalars, but without this, we may get
  782. # the wrong error message from asarray.
  783. other = self._check_allowed_dtypes(other, "numeric", "__imatmul__")
  784. if other is NotImplemented:
  785. return other
  786. res = self._array.__imatmul__(other._array)
  787. return self.__class__._new(res)
  788. def __rmatmul__(self: Array, other: Array, /) -> Array:
  789. """
  790. Performs the operation __rmatmul__.
  791. """
  792. # matmul is not defined for scalars, but without this, we may get
  793. # the wrong error message from asarray.
  794. other = self._check_allowed_dtypes(other, "numeric", "__rmatmul__")
  795. if other is NotImplemented:
  796. return other
  797. res = self._array.__rmatmul__(other._array)
  798. return self.__class__._new(res)
  799. def __imod__(self: Array, other: Union[int, float, Array], /) -> Array:
  800. """
  801. Performs the operation __imod__.
  802. """
  803. other = self._check_allowed_dtypes(other, "real numeric", "__imod__")
  804. if other is NotImplemented:
  805. return other
  806. self._array.__imod__(other._array)
  807. return self
  808. def __rmod__(self: Array, other: Union[int, float, Array], /) -> Array:
  809. """
  810. Performs the operation __rmod__.
  811. """
  812. other = self._check_allowed_dtypes(other, "real numeric", "__rmod__")
  813. if other is NotImplemented:
  814. return other
  815. self, other = self._normalize_two_args(self, other)
  816. res = self._array.__rmod__(other._array)
  817. return self.__class__._new(res)
  818. def __imul__(self: Array, other: Union[int, float, Array], /) -> Array:
  819. """
  820. Performs the operation __imul__.
  821. """
  822. other = self._check_allowed_dtypes(other, "numeric", "__imul__")
  823. if other is NotImplemented:
  824. return other
  825. self._array.__imul__(other._array)
  826. return self
  827. def __rmul__(self: Array, other: Union[int, float, Array], /) -> Array:
  828. """
  829. Performs the operation __rmul__.
  830. """
  831. other = self._check_allowed_dtypes(other, "numeric", "__rmul__")
  832. if other is NotImplemented:
  833. return other
  834. self, other = self._normalize_two_args(self, other)
  835. res = self._array.__rmul__(other._array)
  836. return self.__class__._new(res)
  837. def __ior__(self: Array, other: Union[int, bool, Array], /) -> Array:
  838. """
  839. Performs the operation __ior__.
  840. """
  841. other = self._check_allowed_dtypes(other, "integer or boolean", "__ior__")
  842. if other is NotImplemented:
  843. return other
  844. self._array.__ior__(other._array)
  845. return self
  846. def __ror__(self: Array, other: Union[int, bool, Array], /) -> Array:
  847. """
  848. Performs the operation __ror__.
  849. """
  850. other = self._check_allowed_dtypes(other, "integer or boolean", "__ror__")
  851. if other is NotImplemented:
  852. return other
  853. self, other = self._normalize_two_args(self, other)
  854. res = self._array.__ror__(other._array)
  855. return self.__class__._new(res)
  856. def __ipow__(self: Array, other: Union[int, float, Array], /) -> Array:
  857. """
  858. Performs the operation __ipow__.
  859. """
  860. other = self._check_allowed_dtypes(other, "numeric", "__ipow__")
  861. if other is NotImplemented:
  862. return other
  863. self._array.__ipow__(other._array)
  864. return self
  865. def __rpow__(self: Array, other: Union[int, float, Array], /) -> Array:
  866. """
  867. Performs the operation __rpow__.
  868. """
  869. from ._elementwise_functions import pow
  870. other = self._check_allowed_dtypes(other, "numeric", "__rpow__")
  871. if other is NotImplemented:
  872. return other
  873. # Note: NumPy's __pow__ does not follow the spec type promotion rules
  874. # for 0-d arrays, so we use pow() here instead.
  875. return pow(other, self)
  876. def __irshift__(self: Array, other: Union[int, Array], /) -> Array:
  877. """
  878. Performs the operation __irshift__.
  879. """
  880. other = self._check_allowed_dtypes(other, "integer", "__irshift__")
  881. if other is NotImplemented:
  882. return other
  883. self._array.__irshift__(other._array)
  884. return self
  885. def __rrshift__(self: Array, other: Union[int, Array], /) -> Array:
  886. """
  887. Performs the operation __rrshift__.
  888. """
  889. other = self._check_allowed_dtypes(other, "integer", "__rrshift__")
  890. if other is NotImplemented:
  891. return other
  892. self, other = self._normalize_two_args(self, other)
  893. res = self._array.__rrshift__(other._array)
  894. return self.__class__._new(res)
  895. def __isub__(self: Array, other: Union[int, float, Array], /) -> Array:
  896. """
  897. Performs the operation __isub__.
  898. """
  899. other = self._check_allowed_dtypes(other, "numeric", "__isub__")
  900. if other is NotImplemented:
  901. return other
  902. self._array.__isub__(other._array)
  903. return self
  904. def __rsub__(self: Array, other: Union[int, float, Array], /) -> Array:
  905. """
  906. Performs the operation __rsub__.
  907. """
  908. other = self._check_allowed_dtypes(other, "numeric", "__rsub__")
  909. if other is NotImplemented:
  910. return other
  911. self, other = self._normalize_two_args(self, other)
  912. res = self._array.__rsub__(other._array)
  913. return self.__class__._new(res)
  914. def __itruediv__(self: Array, other: Union[float, Array], /) -> Array:
  915. """
  916. Performs the operation __itruediv__.
  917. """
  918. other = self._check_allowed_dtypes(other, "floating-point", "__itruediv__")
  919. if other is NotImplemented:
  920. return other
  921. self._array.__itruediv__(other._array)
  922. return self
  923. def __rtruediv__(self: Array, other: Union[float, Array], /) -> Array:
  924. """
  925. Performs the operation __rtruediv__.
  926. """
  927. other = self._check_allowed_dtypes(other, "floating-point", "__rtruediv__")
  928. if other is NotImplemented:
  929. return other
  930. self, other = self._normalize_two_args(self, other)
  931. res = self._array.__rtruediv__(other._array)
  932. return self.__class__._new(res)
  933. def __ixor__(self: Array, other: Union[int, bool, Array], /) -> Array:
  934. """
  935. Performs the operation __ixor__.
  936. """
  937. other = self._check_allowed_dtypes(other, "integer or boolean", "__ixor__")
  938. if other is NotImplemented:
  939. return other
  940. self._array.__ixor__(other._array)
  941. return self
  942. def __rxor__(self: Array, other: Union[int, bool, Array], /) -> Array:
  943. """
  944. Performs the operation __rxor__.
  945. """
  946. other = self._check_allowed_dtypes(other, "integer or boolean", "__rxor__")
  947. if other is NotImplemented:
  948. return other
  949. self, other = self._normalize_two_args(self, other)
  950. res = self._array.__rxor__(other._array)
  951. return self.__class__._new(res)
  952. def to_device(self: Array, device: Device, /, stream: None = None) -> Array:
  953. if stream is not None:
  954. raise ValueError("The stream argument to to_device() is not supported")
  955. if device == 'cpu':
  956. return self
  957. raise ValueError(f"Unsupported device {device!r}")
  958. @property
  959. def dtype(self) -> Dtype:
  960. """
  961. Array API compatible wrapper for :py:meth:`np.ndarray.dtype <numpy.ndarray.dtype>`.
  962. See its docstring for more information.
  963. """
  964. return self._array.dtype
  965. @property
  966. def device(self) -> Device:
  967. return "cpu"
  968. # Note: mT is new in array API spec (see matrix_transpose)
  969. @property
  970. def mT(self) -> Array:
  971. from .linalg import matrix_transpose
  972. return matrix_transpose(self)
  973. @property
  974. def ndim(self) -> int:
  975. """
  976. Array API compatible wrapper for :py:meth:`np.ndarray.ndim <numpy.ndarray.ndim>`.
  977. See its docstring for more information.
  978. """
  979. return self._array.ndim
  980. @property
  981. def shape(self) -> Tuple[int, ...]:
  982. """
  983. Array API compatible wrapper for :py:meth:`np.ndarray.shape <numpy.ndarray.shape>`.
  984. See its docstring for more information.
  985. """
  986. return self._array.shape
  987. @property
  988. def size(self) -> int:
  989. """
  990. Array API compatible wrapper for :py:meth:`np.ndarray.size <numpy.ndarray.size>`.
  991. See its docstring for more information.
  992. """
  993. return self._array.size
  994. @property
  995. def T(self) -> Array:
  996. """
  997. Array API compatible wrapper for :py:meth:`np.ndarray.T <numpy.ndarray.T>`.
  998. See its docstring for more information.
  999. """
  1000. # Note: T only works on 2-dimensional arrays. See the corresponding
  1001. # note in the specification:
  1002. # https://data-apis.org/array-api/latest/API_specification/array_object.html#t
  1003. if self.ndim != 2:
  1004. raise ValueError("x.T requires x to have 2 dimensions. Use x.mT to transpose stacks of matrices and permute_dims() to permute dimensions.")
  1005. return self.__class__._new(self._array.T)