linalg.py 18 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466
  1. from __future__ import annotations
  2. from ._dtypes import (
  3. _floating_dtypes,
  4. _numeric_dtypes,
  5. float32,
  6. float64,
  7. complex64,
  8. complex128
  9. )
  10. from ._manipulation_functions import reshape
  11. from ._elementwise_functions import conj
  12. from ._array_object import Array
  13. from ..core.numeric import normalize_axis_tuple
  14. from typing import TYPE_CHECKING
  15. if TYPE_CHECKING:
  16. from ._typing import Literal, Optional, Sequence, Tuple, Union, Dtype
  17. from typing import NamedTuple
  18. import numpy.linalg
  19. import numpy as np
  20. class EighResult(NamedTuple):
  21. eigenvalues: Array
  22. eigenvectors: Array
  23. class QRResult(NamedTuple):
  24. Q: Array
  25. R: Array
  26. class SlogdetResult(NamedTuple):
  27. sign: Array
  28. logabsdet: Array
  29. class SVDResult(NamedTuple):
  30. U: Array
  31. S: Array
  32. Vh: Array
  33. # Note: the inclusion of the upper keyword is different from
  34. # np.linalg.cholesky, which does not have it.
  35. def cholesky(x: Array, /, *, upper: bool = False) -> Array:
  36. """
  37. Array API compatible wrapper for :py:func:`np.linalg.cholesky <numpy.linalg.cholesky>`.
  38. See its docstring for more information.
  39. """
  40. # Note: the restriction to floating-point dtypes only is different from
  41. # np.linalg.cholesky.
  42. if x.dtype not in _floating_dtypes:
  43. raise TypeError('Only floating-point dtypes are allowed in cholesky')
  44. L = np.linalg.cholesky(x._array)
  45. if upper:
  46. U = Array._new(L).mT
  47. if U.dtype in [complex64, complex128]:
  48. U = conj(U)
  49. return U
  50. return Array._new(L)
  51. # Note: cross is the numpy top-level namespace, not np.linalg
  52. def cross(x1: Array, x2: Array, /, *, axis: int = -1) -> Array:
  53. """
  54. Array API compatible wrapper for :py:func:`np.cross <numpy.cross>`.
  55. See its docstring for more information.
  56. """
  57. if x1.dtype not in _numeric_dtypes or x2.dtype not in _numeric_dtypes:
  58. raise TypeError('Only numeric dtypes are allowed in cross')
  59. # Note: this is different from np.cross(), which broadcasts
  60. if x1.shape != x2.shape:
  61. raise ValueError('x1 and x2 must have the same shape')
  62. if x1.ndim == 0:
  63. raise ValueError('cross() requires arrays of dimension at least 1')
  64. # Note: this is different from np.cross(), which allows dimension 2
  65. if x1.shape[axis] != 3:
  66. raise ValueError('cross() dimension must equal 3')
  67. return Array._new(np.cross(x1._array, x2._array, axis=axis))
  68. def det(x: Array, /) -> Array:
  69. """
  70. Array API compatible wrapper for :py:func:`np.linalg.det <numpy.linalg.det>`.
  71. See its docstring for more information.
  72. """
  73. # Note: the restriction to floating-point dtypes only is different from
  74. # np.linalg.det.
  75. if x.dtype not in _floating_dtypes:
  76. raise TypeError('Only floating-point dtypes are allowed in det')
  77. return Array._new(np.linalg.det(x._array))
  78. # Note: diagonal is the numpy top-level namespace, not np.linalg
  79. def diagonal(x: Array, /, *, offset: int = 0) -> Array:
  80. """
  81. Array API compatible wrapper for :py:func:`np.diagonal <numpy.diagonal>`.
  82. See its docstring for more information.
  83. """
  84. # Note: diagonal always operates on the last two axes, whereas np.diagonal
  85. # operates on the first two axes by default
  86. return Array._new(np.diagonal(x._array, offset=offset, axis1=-2, axis2=-1))
  87. def eigh(x: Array, /) -> EighResult:
  88. """
  89. Array API compatible wrapper for :py:func:`np.linalg.eigh <numpy.linalg.eigh>`.
  90. See its docstring for more information.
  91. """
  92. # Note: the restriction to floating-point dtypes only is different from
  93. # np.linalg.eigh.
  94. if x.dtype not in _floating_dtypes:
  95. raise TypeError('Only floating-point dtypes are allowed in eigh')
  96. # Note: the return type here is a namedtuple, which is different from
  97. # np.eigh, which only returns a tuple.
  98. return EighResult(*map(Array._new, np.linalg.eigh(x._array)))
  99. def eigvalsh(x: Array, /) -> Array:
  100. """
  101. Array API compatible wrapper for :py:func:`np.linalg.eigvalsh <numpy.linalg.eigvalsh>`.
  102. See its docstring for more information.
  103. """
  104. # Note: the restriction to floating-point dtypes only is different from
  105. # np.linalg.eigvalsh.
  106. if x.dtype not in _floating_dtypes:
  107. raise TypeError('Only floating-point dtypes are allowed in eigvalsh')
  108. return Array._new(np.linalg.eigvalsh(x._array))
  109. def inv(x: Array, /) -> Array:
  110. """
  111. Array API compatible wrapper for :py:func:`np.linalg.inv <numpy.linalg.inv>`.
  112. See its docstring for more information.
  113. """
  114. # Note: the restriction to floating-point dtypes only is different from
  115. # np.linalg.inv.
  116. if x.dtype not in _floating_dtypes:
  117. raise TypeError('Only floating-point dtypes are allowed in inv')
  118. return Array._new(np.linalg.inv(x._array))
  119. # Note: matmul is the numpy top-level namespace but not in np.linalg
  120. def matmul(x1: Array, x2: Array, /) -> Array:
  121. """
  122. Array API compatible wrapper for :py:func:`np.matmul <numpy.matmul>`.
  123. See its docstring for more information.
  124. """
  125. # Note: the restriction to numeric dtypes only is different from
  126. # np.matmul.
  127. if x1.dtype not in _numeric_dtypes or x2.dtype not in _numeric_dtypes:
  128. raise TypeError('Only numeric dtypes are allowed in matmul')
  129. return Array._new(np.matmul(x1._array, x2._array))
  130. # Note: the name here is different from norm(). The array API norm is split
  131. # into matrix_norm and vector_norm().
  132. # The type for ord should be Optional[Union[int, float, Literal[np.inf,
  133. # -np.inf, 'fro', 'nuc']]], but Literal does not support floating-point
  134. # literals.
  135. def matrix_norm(x: Array, /, *, keepdims: bool = False, ord: Optional[Union[int, float, Literal['fro', 'nuc']]] = 'fro') -> Array:
  136. """
  137. Array API compatible wrapper for :py:func:`np.linalg.norm <numpy.linalg.norm>`.
  138. See its docstring for more information.
  139. """
  140. # Note: the restriction to floating-point dtypes only is different from
  141. # np.linalg.norm.
  142. if x.dtype not in _floating_dtypes:
  143. raise TypeError('Only floating-point dtypes are allowed in matrix_norm')
  144. return Array._new(np.linalg.norm(x._array, axis=(-2, -1), keepdims=keepdims, ord=ord))
  145. def matrix_power(x: Array, n: int, /) -> Array:
  146. """
  147. Array API compatible wrapper for :py:func:`np.matrix_power <numpy.matrix_power>`.
  148. See its docstring for more information.
  149. """
  150. # Note: the restriction to floating-point dtypes only is different from
  151. # np.linalg.matrix_power.
  152. if x.dtype not in _floating_dtypes:
  153. raise TypeError('Only floating-point dtypes are allowed for the first argument of matrix_power')
  154. # np.matrix_power already checks if n is an integer
  155. return Array._new(np.linalg.matrix_power(x._array, n))
  156. # Note: the keyword argument name rtol is different from np.linalg.matrix_rank
  157. def matrix_rank(x: Array, /, *, rtol: Optional[Union[float, Array]] = None) -> Array:
  158. """
  159. Array API compatible wrapper for :py:func:`np.matrix_rank <numpy.matrix_rank>`.
  160. See its docstring for more information.
  161. """
  162. # Note: this is different from np.linalg.matrix_rank, which supports 1
  163. # dimensional arrays.
  164. if x.ndim < 2:
  165. raise np.linalg.LinAlgError("1-dimensional array given. Array must be at least two-dimensional")
  166. S = np.linalg.svd(x._array, compute_uv=False)
  167. if rtol is None:
  168. tol = S.max(axis=-1, keepdims=True) * max(x.shape[-2:]) * np.finfo(S.dtype).eps
  169. else:
  170. if isinstance(rtol, Array):
  171. rtol = rtol._array
  172. # Note: this is different from np.linalg.matrix_rank, which does not multiply
  173. # the tolerance by the largest singular value.
  174. tol = S.max(axis=-1, keepdims=True)*np.asarray(rtol)[..., np.newaxis]
  175. return Array._new(np.count_nonzero(S > tol, axis=-1))
  176. # Note: this function is new in the array API spec. Unlike transpose, it only
  177. # transposes the last two axes.
  178. def matrix_transpose(x: Array, /) -> Array:
  179. if x.ndim < 2:
  180. raise ValueError("x must be at least 2-dimensional for matrix_transpose")
  181. return Array._new(np.swapaxes(x._array, -1, -2))
  182. # Note: outer is the numpy top-level namespace, not np.linalg
  183. def outer(x1: Array, x2: Array, /) -> Array:
  184. """
  185. Array API compatible wrapper for :py:func:`np.outer <numpy.outer>`.
  186. See its docstring for more information.
  187. """
  188. # Note: the restriction to numeric dtypes only is different from
  189. # np.outer.
  190. if x1.dtype not in _numeric_dtypes or x2.dtype not in _numeric_dtypes:
  191. raise TypeError('Only numeric dtypes are allowed in outer')
  192. # Note: the restriction to only 1-dim arrays is different from np.outer
  193. if x1.ndim != 1 or x2.ndim != 1:
  194. raise ValueError('The input arrays to outer must be 1-dimensional')
  195. return Array._new(np.outer(x1._array, x2._array))
  196. # Note: the keyword argument name rtol is different from np.linalg.pinv
  197. def pinv(x: Array, /, *, rtol: Optional[Union[float, Array]] = None) -> Array:
  198. """
  199. Array API compatible wrapper for :py:func:`np.linalg.pinv <numpy.linalg.pinv>`.
  200. See its docstring for more information.
  201. """
  202. # Note: the restriction to floating-point dtypes only is different from
  203. # np.linalg.pinv.
  204. if x.dtype not in _floating_dtypes:
  205. raise TypeError('Only floating-point dtypes are allowed in pinv')
  206. # Note: this is different from np.linalg.pinv, which does not multiply the
  207. # default tolerance by max(M, N).
  208. if rtol is None:
  209. rtol = max(x.shape[-2:]) * np.finfo(x.dtype).eps
  210. return Array._new(np.linalg.pinv(x._array, rcond=rtol))
  211. def qr(x: Array, /, *, mode: Literal['reduced', 'complete'] = 'reduced') -> QRResult:
  212. """
  213. Array API compatible wrapper for :py:func:`np.linalg.qr <numpy.linalg.qr>`.
  214. See its docstring for more information.
  215. """
  216. # Note: the restriction to floating-point dtypes only is different from
  217. # np.linalg.qr.
  218. if x.dtype not in _floating_dtypes:
  219. raise TypeError('Only floating-point dtypes are allowed in qr')
  220. # Note: the return type here is a namedtuple, which is different from
  221. # np.linalg.qr, which only returns a tuple.
  222. return QRResult(*map(Array._new, np.linalg.qr(x._array, mode=mode)))
  223. def slogdet(x: Array, /) -> SlogdetResult:
  224. """
  225. Array API compatible wrapper for :py:func:`np.linalg.slogdet <numpy.linalg.slogdet>`.
  226. See its docstring for more information.
  227. """
  228. # Note: the restriction to floating-point dtypes only is different from
  229. # np.linalg.slogdet.
  230. if x.dtype not in _floating_dtypes:
  231. raise TypeError('Only floating-point dtypes are allowed in slogdet')
  232. # Note: the return type here is a namedtuple, which is different from
  233. # np.linalg.slogdet, which only returns a tuple.
  234. return SlogdetResult(*map(Array._new, np.linalg.slogdet(x._array)))
  235. # Note: unlike np.linalg.solve, the array API solve() only accepts x2 as a
  236. # vector when it is exactly 1-dimensional. All other cases treat x2 as a stack
  237. # of matrices. The np.linalg.solve behavior of allowing stacks of both
  238. # matrices and vectors is ambiguous c.f.
  239. # https://github.com/numpy/numpy/issues/15349 and
  240. # https://github.com/data-apis/array-api/issues/285.
  241. # To workaround this, the below is the code from np.linalg.solve except
  242. # only calling solve1 in the exactly 1D case.
  243. def _solve(a, b):
  244. from ..linalg.linalg import (_makearray, _assert_stacked_2d,
  245. _assert_stacked_square, _commonType,
  246. isComplexType, get_linalg_error_extobj,
  247. _raise_linalgerror_singular)
  248. from ..linalg import _umath_linalg
  249. a, _ = _makearray(a)
  250. _assert_stacked_2d(a)
  251. _assert_stacked_square(a)
  252. b, wrap = _makearray(b)
  253. t, result_t = _commonType(a, b)
  254. # This part is different from np.linalg.solve
  255. if b.ndim == 1:
  256. gufunc = _umath_linalg.solve1
  257. else:
  258. gufunc = _umath_linalg.solve
  259. # This does nothing currently but is left in because it will be relevant
  260. # when complex dtype support is added to the spec in 2022.
  261. signature = 'DD->D' if isComplexType(t) else 'dd->d'
  262. with np.errstate(call=_raise_linalgerror_singular, invalid='call',
  263. over='ignore', divide='ignore', under='ignore'):
  264. r = gufunc(a, b, signature=signature)
  265. return wrap(r.astype(result_t, copy=False))
  266. def solve(x1: Array, x2: Array, /) -> Array:
  267. """
  268. Array API compatible wrapper for :py:func:`np.linalg.solve <numpy.linalg.solve>`.
  269. See its docstring for more information.
  270. """
  271. # Note: the restriction to floating-point dtypes only is different from
  272. # np.linalg.solve.
  273. if x1.dtype not in _floating_dtypes or x2.dtype not in _floating_dtypes:
  274. raise TypeError('Only floating-point dtypes are allowed in solve')
  275. return Array._new(_solve(x1._array, x2._array))
  276. def svd(x: Array, /, *, full_matrices: bool = True) -> SVDResult:
  277. """
  278. Array API compatible wrapper for :py:func:`np.linalg.svd <numpy.linalg.svd>`.
  279. See its docstring for more information.
  280. """
  281. # Note: the restriction to floating-point dtypes only is different from
  282. # np.linalg.svd.
  283. if x.dtype not in _floating_dtypes:
  284. raise TypeError('Only floating-point dtypes are allowed in svd')
  285. # Note: the return type here is a namedtuple, which is different from
  286. # np.svd, which only returns a tuple.
  287. return SVDResult(*map(Array._new, np.linalg.svd(x._array, full_matrices=full_matrices)))
  288. # Note: svdvals is not in NumPy (but it is in SciPy). It is equivalent to
  289. # np.linalg.svd(compute_uv=False).
  290. def svdvals(x: Array, /) -> Union[Array, Tuple[Array, ...]]:
  291. if x.dtype not in _floating_dtypes:
  292. raise TypeError('Only floating-point dtypes are allowed in svdvals')
  293. return Array._new(np.linalg.svd(x._array, compute_uv=False))
  294. # Note: tensordot is the numpy top-level namespace but not in np.linalg
  295. # Note: axes must be a tuple, unlike np.tensordot where it can be an array or array-like.
  296. def tensordot(x1: Array, x2: Array, /, *, axes: Union[int, Tuple[Sequence[int], Sequence[int]]] = 2) -> Array:
  297. # Note: the restriction to numeric dtypes only is different from
  298. # np.tensordot.
  299. if x1.dtype not in _numeric_dtypes or x2.dtype not in _numeric_dtypes:
  300. raise TypeError('Only numeric dtypes are allowed in tensordot')
  301. return Array._new(np.tensordot(x1._array, x2._array, axes=axes))
  302. # Note: trace is the numpy top-level namespace, not np.linalg
  303. def trace(x: Array, /, *, offset: int = 0, dtype: Optional[Dtype] = None) -> Array:
  304. """
  305. Array API compatible wrapper for :py:func:`np.trace <numpy.trace>`.
  306. See its docstring for more information.
  307. """
  308. if x.dtype not in _numeric_dtypes:
  309. raise TypeError('Only numeric dtypes are allowed in trace')
  310. # Note: trace() works the same as sum() and prod() (see
  311. # _statistical_functions.py)
  312. if dtype is None:
  313. if x.dtype == float32:
  314. dtype = float64
  315. elif x.dtype == complex64:
  316. dtype = complex128
  317. # Note: trace always operates on the last two axes, whereas np.trace
  318. # operates on the first two axes by default
  319. return Array._new(np.asarray(np.trace(x._array, offset=offset, axis1=-2, axis2=-1, dtype=dtype)))
  320. # Note: vecdot is not in NumPy
  321. def vecdot(x1: Array, x2: Array, /, *, axis: int = -1) -> Array:
  322. if x1.dtype not in _numeric_dtypes or x2.dtype not in _numeric_dtypes:
  323. raise TypeError('Only numeric dtypes are allowed in vecdot')
  324. ndim = max(x1.ndim, x2.ndim)
  325. x1_shape = (1,)*(ndim - x1.ndim) + tuple(x1.shape)
  326. x2_shape = (1,)*(ndim - x2.ndim) + tuple(x2.shape)
  327. if x1_shape[axis] != x2_shape[axis]:
  328. raise ValueError("x1 and x2 must have the same size along the given axis")
  329. x1_, x2_ = np.broadcast_arrays(x1._array, x2._array)
  330. x1_ = np.moveaxis(x1_, axis, -1)
  331. x2_ = np.moveaxis(x2_, axis, -1)
  332. res = x1_[..., None, :] @ x2_[..., None]
  333. return Array._new(res[..., 0, 0])
  334. # Note: the name here is different from norm(). The array API norm is split
  335. # into matrix_norm and vector_norm().
  336. # The type for ord should be Optional[Union[int, float, Literal[np.inf,
  337. # -np.inf]]] but Literal does not support floating-point literals.
  338. def vector_norm(x: Array, /, *, axis: Optional[Union[int, Tuple[int, ...]]] = None, keepdims: bool = False, ord: Optional[Union[int, float]] = 2) -> Array:
  339. """
  340. Array API compatible wrapper for :py:func:`np.linalg.norm <numpy.linalg.norm>`.
  341. See its docstring for more information.
  342. """
  343. # Note: the restriction to floating-point dtypes only is different from
  344. # np.linalg.norm.
  345. if x.dtype not in _floating_dtypes:
  346. raise TypeError('Only floating-point dtypes are allowed in norm')
  347. # np.linalg.norm tries to do a matrix norm whenever axis is a 2-tuple or
  348. # when axis=None and the input is 2-D, so to force a vector norm, we make
  349. # it so the input is 1-D (for axis=None), or reshape so that norm is done
  350. # on a single dimension.
  351. a = x._array
  352. if axis is None:
  353. # Note: np.linalg.norm() doesn't handle 0-D arrays
  354. a = a.ravel()
  355. _axis = 0
  356. elif isinstance(axis, tuple):
  357. # Note: The axis argument supports any number of axes, whereas
  358. # np.linalg.norm() only supports a single axis for vector norm.
  359. normalized_axis = normalize_axis_tuple(axis, x.ndim)
  360. rest = tuple(i for i in range(a.ndim) if i not in normalized_axis)
  361. newshape = axis + rest
  362. a = np.transpose(a, newshape).reshape(
  363. (np.prod([a.shape[i] for i in axis], dtype=int), *[a.shape[i] for i in rest]))
  364. _axis = 0
  365. else:
  366. _axis = axis
  367. res = Array._new(np.linalg.norm(a, axis=_axis, ord=ord))
  368. if keepdims:
  369. # We can't reuse np.linalg.norm(keepdims) because of the reshape hacks
  370. # above to avoid matrix norm logic.
  371. shape = list(x.shape)
  372. _axis = normalize_axis_tuple(range(x.ndim) if axis is None else axis, x.ndim)
  373. for i in _axis:
  374. shape[i] = 1
  375. res = reshape(res, tuple(shape))
  376. return res
  377. __all__ = ['cholesky', 'cross', 'det', 'diagonal', 'eigh', 'eigvalsh', 'inv', 'matmul', 'matrix_norm', 'matrix_power', 'matrix_rank', 'matrix_transpose', 'outer', 'pinv', 'qr', 'slogdet', 'solve', 'svd', 'svdvals', 'tensordot', 'trace', 'vecdot', 'vector_norm']