_basic.py 86 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270370470570670770870971071171271371471571671771871972072172272372472572672772872973073173273373473573673773873974074174274374474574674774874975075175275375475575675775875976076176276376476576676776876977077177277377477577677777877978078178278378478578678778878979079179279379479579679779879980080180280380480580680780880981081181281381481581681781881982082182282382482582682782882983083183283383483583683783883984084184284384484584684784884985085185285385485585685785885986086186286386486586686786886987087187287387487587687787887988088188288388488588688788888989089189289389489589689789889990090190290390490590690790890991091191291391491591691791891992092192292392492592692792892993093193293393493593693793893994094194294394494594694794894995095195295395495595695795895996096196296396496596696796896997097197297397497597697797897998098198298398498598698798898999099199299399499599699799899910001001100210031004100510061007100810091010101110121013101410151016101710181019102010211022102310241025102610271028102910301031103210331034103510361037103810391040104110421043104410451046104710481049105010511052105310541055105610571058105910601061106210631064106510661067106810691070107110721073107410751076107710781079108010811082108310841085108610871088108910901091109210931094109510961097109810991100110111021103110411051106110711081109111011111112111311141115111611171118111911201121112211231124112511261127112811291130113111321133113411351136113711381139114011411142114311441145114611471148114911501151115211531154115511561157115811591160116111621163116411651166116711681169117011711172117311741175117611771178117911801181118211831184118511861187118811891190119111921193119411951196119711981199120012011202120312041205120612071208120912101211121212131214121512161217121812191220122112221223122412251226122712281229123012311232123312341235123612371238123912401241124212431244124512461247124812491250125112521253125412551256125712581259126012611262126312641265126612671268126912701271127212731274127512761277127812791280128112821283128412851286128712881289129012911292129312941295129612971298129913001301130213031304130513061307130813091310131113121313131413151316131713181319132013211322132313241325132613271328132913301331133213331334133513361337133813391340134113421343134413451346134713481349135013511352135313541355135613571358135913601361136213631364136513661367136813691370137113721373137413751376137713781379138013811382138313841385138613871388138913901391139213931394139513961397139813991400140114021403140414051406140714081409141014111412141314141415141614171418141914201421142214231424142514261427142814291430143114321433143414351436143714381439144014411442144314441445144614471448144914501451145214531454145514561457145814591460146114621463146414651466146714681469147014711472147314741475147614771478147914801481148214831484148514861487148814891490149114921493149414951496149714981499150015011502150315041505150615071508150915101511151215131514151515161517151815191520152115221523152415251526152715281529153015311532153315341535153615371538153915401541154215431544154515461547154815491550155115521553155415551556155715581559156015611562156315641565156615671568156915701571157215731574157515761577157815791580158115821583158415851586158715881589159015911592159315941595159615971598159916001601160216031604160516061607160816091610161116121613161416151616161716181619162016211622162316241625162616271628162916301631163216331634163516361637163816391640164116421643164416451646164716481649165016511652165316541655165616571658165916601661166216631664166516661667166816691670167116721673167416751676167716781679168016811682168316841685168616871688168916901691169216931694169516961697169816991700170117021703170417051706170717081709171017111712171317141715171617171718171917201721172217231724172517261727172817291730173117321733173417351736173717381739174017411742174317441745174617471748174917501751175217531754175517561757175817591760176117621763176417651766176717681769177017711772177317741775177617771778177917801781178217831784178517861787178817891790179117921793179417951796179717981799180018011802180318041805180618071808180918101811181218131814181518161817181818191820182118221823182418251826182718281829183018311832183318341835183618371838183918401841184218431844184518461847184818491850185118521853185418551856185718581859186018611862186318641865186618671868186918701871187218731874187518761877187818791880188118821883188418851886188718881889189018911892189318941895189618971898189919001901190219031904190519061907190819091910191119121913191419151916191719181919192019211922192319241925192619271928192919301931193219331934193519361937193819391940194119421943194419451946194719481949195019511952195319541955195619571958195919601961196219631964196519661967196819691970197119721973197419751976197719781979198019811982198319841985198619871988198919901991199219931994199519961997199819992000200120022003200420052006200720082009201020112012201320142015201620172018201920202021202220232024202520262027202820292030203120322033203420352036203720382039204020412042204320442045204620472048204920502051205220532054205520562057205820592060206120622063206420652066206720682069207020712072207320742075207620772078207920802081208220832084208520862087208820892090209120922093209420952096209720982099210021012102210321042105210621072108210921102111211221132114211521162117211821192120212121222123212421252126212721282129213021312132213321342135213621372138213921402141214221432144214521462147214821492150215121522153215421552156215721582159216021612162216321642165216621672168216921702171217221732174217521762177217821792180218121822183218421852186218721882189219021912192219321942195219621972198219922002201220222032204220522062207220822092210221122122213221422152216221722182219222022212222222322242225222622272228222922302231223222332234223522362237223822392240224122422243224422452246224722482249225022512252225322542255225622572258225922602261226222632264226522662267226822692270227122722273227422752276227722782279228022812282228322842285228622872288228922902291229222932294229522962297229822992300230123022303230423052306230723082309231023112312231323142315231623172318231923202321232223232324232523262327232823292330233123322333233423352336233723382339234023412342234323442345234623472348234923502351235223532354235523562357235823592360236123622363236423652366236723682369237023712372237323742375237623772378237923802381238223832384238523862387238823892390239123922393239423952396239723982399240024012402240324042405240624072408
  1. #
  2. # Author: Pearu Peterson, March 2002
  3. #
  4. # w/ additions by Travis Oliphant, March 2002
  5. # and Jake Vanderplas, August 2012
  6. import warnings
  7. from warnings import warn
  8. from itertools import product
  9. import numpy as np
  10. from numpy import atleast_1d, atleast_2d
  11. from scipy._lib._util import _apply_over_batch
  12. from .lapack import get_lapack_funcs, _compute_lwork, _normalize_lapack_dtype
  13. from ._misc import LinAlgError, _datacopied, LinAlgWarning
  14. from ._decomp import _asarray_validated
  15. from . import _decomp, _decomp_svd
  16. from ._solve_toeplitz import levinson
  17. from ._cythonized_array_utils import (find_det_from_lu, bandwidth, issymmetric,
  18. ishermitian)
  19. from . import _batched_linalg
  20. __all__ = ['solve', 'solve_triangular', 'solveh_banded', 'solve_banded',
  21. 'solve_toeplitz', 'solve_circulant', 'inv', 'det', 'lstsq',
  22. 'pinv', 'pinvh', 'matrix_balance', 'matmul_toeplitz']
  23. # Linear equations
  24. def _solve_check(n, info, lamch=None, rcond=None):
  25. """ Check arguments during the different steps of the solution phase """
  26. if info < 0:
  27. raise ValueError(f'LAPACK reported an illegal value in {-info}-th argument.')
  28. elif 0 < info or rcond == 0:
  29. raise LinAlgError('Matrix is singular.')
  30. if lamch is None:
  31. return
  32. E = lamch('E')
  33. if not (rcond >= E): # `rcond < E` doesn't handle NaN
  34. warn(f'Ill-conditioned matrix (rcond={rcond:.6g}): '
  35. 'result may not be accurate.',
  36. LinAlgWarning, stacklevel=3)
  37. def _find_matrix_structure(a):
  38. n = a.shape[0]
  39. n_below, n_above = bandwidth(a)
  40. if n_below == n_above == 0:
  41. kind = 'diagonal'
  42. elif n_above == 0:
  43. kind = 'lower triangular'
  44. elif n_below == 0:
  45. kind = 'upper triangular'
  46. elif n_above <= 1 and n_below <= 1 and n > 3:
  47. kind = 'tridiagonal'
  48. elif np.issubdtype(a.dtype, np.complexfloating) and ishermitian(a):
  49. kind = 'hermitian'
  50. elif issymmetric(a):
  51. kind = 'symmetric'
  52. else:
  53. kind = 'general'
  54. return kind, n_below, n_above
  55. def _format_emit_errors_warnings(err_lst):
  56. """Format/emit errors/warnings from a lowlevel batched routine.
  57. See inv, solve.
  58. """
  59. singular, lapack_err, ill_cond = [], [], []
  60. for i, dct in enumerate(err_lst):
  61. if dct["is_singular"]:
  62. singular.append(i)
  63. if dct["lapack_info"] < 0:
  64. lapack_err.append(f"slice {i} emits lapack info={dct['lapack_info']}")
  65. if dct["is_ill_conditioned"]:
  66. ill_cond.append(f"slice {i} has rcond = {dct['rcond']}")
  67. if singular:
  68. raise LinAlgError(
  69. f"A singular matrix detected: slice(s) {singular} are singular."
  70. )
  71. if lapack_err:
  72. raise ValueError(f"Internal LAPACK errors: {','.join(lapack_err)}.")
  73. if ill_cond:
  74. warnings.warn(
  75. f"An ill-conditioned matrix detected: {','.join(ill_cond)}.",
  76. LinAlgWarning,
  77. stacklevel=3
  78. )
  79. def solve(a, b, lower=False, overwrite_a=False,
  80. overwrite_b=False, check_finite=True, assume_a=None,
  81. transposed=False):
  82. """
  83. Solve the equation ``a @ x = b`` for ``x``,
  84. where `a` is a square matrix.
  85. If the data matrix is known to be a particular type then supplying the
  86. corresponding string to ``assume_a`` key chooses the dedicated solver.
  87. The available options are
  88. ============================= ================================
  89. diagonal 'diagonal'
  90. tridiagonal 'tridiagonal'
  91. banded 'banded'
  92. upper triangular 'upper triangular'
  93. lower triangular 'lower triangular'
  94. symmetric 'symmetric' (or 'sym')
  95. hermitian 'hermitian' (or 'her')
  96. symmetric positive definite 'positive definite' (or 'pos')
  97. general 'general' (or 'gen')
  98. ============================= ================================
  99. Array argument(s) of this function may have additional
  100. "batch" dimensions prepended to the core shape. In this case, the array is treated
  101. as a batch of lower-dimensional slices; see :ref:`linalg_batch` for details.
  102. Parameters
  103. ----------
  104. a : array_like, shape (..., N, N)
  105. Square left-hand side matrix or a batch of matrices.
  106. b : (..., N, NRHS) array_like
  107. Input data for the right hand side or a batch of right-hand sides.
  108. lower : bool, default: False
  109. Ignored unless ``assume_a`` is one of ``'sym'``, ``'her'``, or ``'pos'``.
  110. If True, the calculation uses only the data in the lower triangle of `a`;
  111. entries above the diagonal are ignored. If False (default), the
  112. calculation uses only the data in the upper triangle of `a`; entries
  113. below the diagonal are ignored.
  114. overwrite_a : bool, default: False
  115. Allow overwriting data in `a` (may enhance performance).
  116. overwrite_b : bool, default: False
  117. Allow overwriting data in `b` (may enhance performance).
  118. check_finite : bool, default: True
  119. Whether to check that the input matrices contain only finite numbers.
  120. Disabling may give a performance gain, but may result in problems
  121. (crashes, non-termination) if the inputs do contain infinities or NaNs.
  122. assume_a : str, optional
  123. Valid entries are described above.
  124. If omitted or ``None``, checks are performed to identify structure so the
  125. appropriate solver can be called.
  126. transposed : bool, default: False
  127. If True, solve ``a.T @ x == b``. Raises `NotImplementedError`
  128. for complex `a`.
  129. Returns
  130. -------
  131. x : ndarray, shape (N, NRHS) or (..., N)
  132. The solution array.
  133. Raises
  134. ------
  135. ValueError
  136. If size mismatches detected or input a is not square.
  137. LinAlgError
  138. If the computation fails because of matrix singularity.
  139. LinAlgWarning
  140. If an ill-conditioned input a is detected.
  141. NotImplementedError
  142. If transposed is True and input a is a complex matrix.
  143. Notes
  144. -----
  145. If the input b matrix is a 1-D array with N elements, when supplied
  146. together with an NxN input a, it is assumed as a valid column vector
  147. despite the apparent size mismatch. This is compatible with the
  148. numpy.dot() behavior and the returned result is still 1-D array.
  149. The general, symmetric, Hermitian and positive definite solutions are
  150. obtained via calling ?GETRF/?GETRS, ?SYSV, ?HESV, and ?POTRF/?POTRS routines of
  151. LAPACK respectively.
  152. The datatype of the arrays define which solver is called regardless
  153. of the values. In other words, even when the complex array entries have
  154. precisely zero imaginary parts, the complex solver will be called based
  155. on the data type of the array.
  156. Examples
  157. --------
  158. Given `a` and `b`, solve for `x`:
  159. >>> import numpy as np
  160. >>> a = np.array([[3, 2, 0], [1, -1, 0], [0, 5, 1]])
  161. >>> b = np.array([2, 4, -1])
  162. >>> from scipy.linalg import solve
  163. >>> x = solve(a, b)
  164. >>> x
  165. array([ 2., -2., 9.])
  166. >>> a @ x == b
  167. array([ True, True, True], dtype=bool)
  168. Batches of matrices are supported, with and without structure detection:
  169. >>> a = np.arange(12).reshape(3, 2, 2) # a batch of 3 2x2 matrices
  170. >>> A = a.transpose(0, 2, 1) @ a # A is a batch of 3 positive definite matrices
  171. >>> b = np.ones(2)
  172. >>> solve(A, b) # this automatically detects that A is pos.def.
  173. array([[ 1. , -0.5],
  174. [ 3. , -2.5],
  175. [ 5. , -4.5]])
  176. >>> solve(A, b, assume_a='pos') # bypass structucture detection
  177. array([[ 1. , -0.5],
  178. [ 3. , -2.5],
  179. [ 5. , -4.5]])
  180. """
  181. if assume_a in ['banded']:
  182. # TODO: handle these structures in this function
  183. return solve0(
  184. a, b, lower=lower, overwrite_a=overwrite_a, overwrite_b=overwrite_b,
  185. check_finite=check_finite, assume_a=assume_a, transposed=transposed
  186. )
  187. # keep the numbers in sync with C
  188. structure = {
  189. None: -1,
  190. 'general': 0, 'gen': 0,
  191. 'diagonal': 11,
  192. 'tridiagonal': 31,
  193. 'upper triangular': 21,
  194. 'lower triangular': 22,
  195. 'pos' : 101, 'positive definite': 101,
  196. 'sym' : 201, 'symmetric': 201,
  197. 'her' : 211, 'hermitian': 211,
  198. }.get(assume_a, 'unknown')
  199. if structure == 'unknown':
  200. raise ValueError(f'{assume_a} is not a recognized matrix structure')
  201. a1 = np.atleast_2d(_asarray_validated(a, check_finite=check_finite))
  202. b1 = np.atleast_1d(_asarray_validated(b, check_finite=check_finite))
  203. a1, b1 = _ensure_dtype_cdsz(a1, b1) # XXX; b upcasts a?
  204. a1, overwrite_a = _normalize_lapack_dtype(a1, overwrite_a)
  205. if a1.ndim < 2:
  206. raise ValueError(f"Expected at least ndim=2, got {a1.ndim=}")
  207. if a1.shape[-1] != a1.shape[-2]:
  208. raise ValueError(f"Expected square matrix, got {a1.shape=}")
  209. # backwards compatibility
  210. if np.issubdtype(a1.dtype, np.complexfloating) and transposed:
  211. raise NotImplementedError('scipy.linalg.solve can currently '
  212. 'not solve a^T x = b or a^H x = b '
  213. 'for complex matrices.')
  214. if not (a1.flags['ALIGNED'] or a1.dtype.byteorder == '='):
  215. overwrite_a = True
  216. a1 = a1.copy()
  217. if not (b1.flags['ALIGNED'] or b1.dtype.byteorder == '='):
  218. overwrite_a = True
  219. b1 = b1.copy()
  220. # align the shape of b with a: 1. make b1 at least 2D
  221. b_is_1D = b1.ndim == 1
  222. if b_is_1D:
  223. b1 = b1[:, None]
  224. a_is_scalar = a1.size == 1
  225. if b1.shape[-2] != a1.shape[-1] and not a_is_scalar:
  226. raise ValueError(f"incompatible shapes: {a1.shape=} and {b1.shape=}")
  227. # 2. broadcast the batch dimensions of b1 and a1
  228. batch_shape = np.broadcast_shapes(a1.shape[:-2], b1.shape[:-2])
  229. a1 = np.broadcast_to(a1, batch_shape + a1.shape[-2:])
  230. b1 = np.broadcast_to(b1, batch_shape + b1.shape[-2:])
  231. # catch empty inputs
  232. if a1.size == 0 or b1.size == 0:
  233. x = np.empty_like(b1)
  234. if b_is_1D:
  235. x = x[..., 0]
  236. return x
  237. if a_is_scalar:
  238. out = b1 / a1
  239. return out[..., 0] if b_is_1D else out
  240. # heavy lifting
  241. x, err_lst = _batched_linalg._solve(a1, b1, structure, lower, transposed)
  242. if err_lst:
  243. _format_emit_errors_warnings(err_lst)
  244. if b_is_1D:
  245. x = x[..., 0]
  246. return x
  247. @_apply_over_batch(('a', 2), ('b', '1|2'))
  248. def solve0(a, b, lower=False, overwrite_a=False,
  249. overwrite_b=False, check_finite=True, assume_a=None,
  250. transposed=False):
  251. """
  252. Solve the equation ``a @ x = b`` for ``x``,
  253. where `a` is a square matrix.
  254. If the data matrix is known to be a particular type then supplying the
  255. corresponding string to ``assume_a`` key chooses the dedicated solver.
  256. The available options are
  257. ============================= ================================
  258. diagonal 'diagonal'
  259. tridiagonal 'tridiagonal'
  260. banded 'banded'
  261. upper triangular 'upper triangular'
  262. lower triangular 'lower triangular'
  263. symmetric 'symmetric' (or 'sym')
  264. hermitian 'hermitian' (or 'her')
  265. symmetric positive definite 'positive definite' (or 'pos')
  266. general 'general' (or 'gen')
  267. ============================= ================================
  268. Parameters
  269. ----------
  270. a : (N, N) array_like
  271. Square input data
  272. b : (N, NRHS) array_like
  273. Input data for the right hand side.
  274. lower : bool, default: False
  275. Ignored unless ``assume_a`` is one of ``'sym'``, ``'her'``, or ``'pos'``.
  276. If True, the calculation uses only the data in the lower triangle of `a`;
  277. entries above the diagonal are ignored. If False (default), the
  278. calculation uses only the data in the upper triangle of `a`; entries
  279. below the diagonal are ignored.
  280. overwrite_a : bool, default: False
  281. Allow overwriting data in `a` (may enhance performance).
  282. overwrite_b : bool, default: False
  283. Allow overwriting data in `b` (may enhance performance).
  284. check_finite : bool, default: True
  285. Whether to check that the input matrices contain only finite numbers.
  286. Disabling may give a performance gain, but may result in problems
  287. (crashes, non-termination) if the inputs do contain infinities or NaNs.
  288. assume_a : str, optional
  289. Valid entries are described above.
  290. If omitted or ``None``, checks are performed to identify structure so the
  291. appropriate solver can be called.
  292. transposed : bool, default: False
  293. If True, solve ``a.T @ x == b``. Raises `NotImplementedError`
  294. for complex `a`.
  295. Returns
  296. -------
  297. x : (N, NRHS) ndarray
  298. The solution array.
  299. Raises
  300. ------
  301. ValueError
  302. If size mismatches detected or input a is not square.
  303. LinAlgError
  304. If the computation fails because of matrix singularity.
  305. LinAlgWarning
  306. If an ill-conditioned input a is detected.
  307. NotImplementedError
  308. If transposed is True and input a is a complex matrix.
  309. Notes
  310. -----
  311. If the input b matrix is a 1-D array with N elements, when supplied
  312. together with an NxN input a, it is assumed as a valid column vector
  313. despite the apparent size mismatch. This is compatible with the
  314. numpy.dot() behavior and the returned result is still 1-D array.
  315. The general, symmetric, Hermitian and positive definite solutions are
  316. obtained via calling ?GESV, ?SYSV, ?HESV, and ?POSV routines of
  317. LAPACK respectively.
  318. The datatype of the arrays define which solver is called regardless
  319. of the values. In other words, even when the complex array entries have
  320. precisely zero imaginary parts, the complex solver will be called based
  321. on the data type of the array.
  322. Examples
  323. --------
  324. Given `a` and `b`, solve for `x`:
  325. >>> import numpy as np
  326. >>> a = np.array([[3, 2, 0], [1, -1, 0], [0, 5, 1]])
  327. >>> b = np.array([2, 4, -1])
  328. >>> from scipy import linalg
  329. >>> x = linalg.solve(a, b)
  330. >>> x
  331. array([ 2., -2., 9.])
  332. >>> np.dot(a, x) == b
  333. array([ True, True, True], dtype=bool)
  334. """
  335. # Flags for 1-D or N-D right-hand side
  336. b_is_1D = False
  337. # check finite after determining structure
  338. a1 = atleast_2d(_asarray_validated(a, check_finite=False))
  339. b1 = atleast_1d(_asarray_validated(b, check_finite=False))
  340. a1, b1 = _ensure_dtype_cdsz(a1, b1)
  341. n = a1.shape[0]
  342. overwrite_a = overwrite_a or _datacopied(a1, a)
  343. overwrite_b = overwrite_b or _datacopied(b1, b)
  344. if a1.shape[0] != a1.shape[1]:
  345. raise ValueError('Input a needs to be a square matrix.')
  346. if n != b1.shape[0]:
  347. # Last chance to catch 1x1 scalar a and 1-D b arrays
  348. if not (n == 1 and b1.size != 0):
  349. raise ValueError('Input b has to have same number of rows as '
  350. 'input a')
  351. # accommodate empty arrays
  352. if b1.size == 0:
  353. dt = solve(np.eye(2, dtype=a1.dtype), np.ones(2, dtype=b1.dtype)).dtype
  354. return np.empty_like(b1, dtype=dt)
  355. # regularize 1-D b arrays to 2D
  356. if b1.ndim == 1:
  357. if n == 1:
  358. b1 = b1[None, :]
  359. else:
  360. b1 = b1[:, None]
  361. b_is_1D = True
  362. if assume_a not in {None, 'diagonal', 'tridiagonal', 'banded', 'lower triangular',
  363. 'upper triangular', 'symmetric', 'hermitian',
  364. 'positive definite', 'general', 'sym', 'her', 'pos', 'gen'}:
  365. raise ValueError(f'{assume_a} is not a recognized matrix structure')
  366. # for a real matrix, describe it as "symmetric", not "hermitian"
  367. # (lapack doesn't know what to do with real hermitian matrices)
  368. if assume_a in {'hermitian', 'her'} and not np.iscomplexobj(a1):
  369. assume_a = 'symmetric'
  370. n_below, n_above = None, None
  371. if assume_a is None:
  372. assume_a, n_below, n_above = _find_matrix_structure(a1)
  373. # Get the correct lamch function.
  374. # The LAMCH functions only exists for S and D
  375. # So for complex values we have to convert to real/double.
  376. if a1.dtype.char in 'fF': # single precision
  377. lamch = get_lapack_funcs('lamch', dtype='f')
  378. else:
  379. lamch = get_lapack_funcs('lamch', dtype='d')
  380. # Since the I-norm and 1-norm are the same for symmetric matrices
  381. # we can collect them all in this one call
  382. # Note however, that when issuing 'gen' and form!='none', then
  383. # the I-norm should be used
  384. if transposed:
  385. trans = 1
  386. norm = 'I'
  387. if np.iscomplexobj(a1):
  388. raise NotImplementedError('scipy.linalg.solve can currently '
  389. 'not solve a^T x = b or a^H x = b '
  390. 'for complex matrices.')
  391. else:
  392. trans = 0
  393. norm = '1'
  394. # Currently we do not have the other forms of the norm calculators
  395. # lansy, lanpo, lanhe.
  396. # However, in any case they only reduce computations slightly...
  397. if assume_a == 'diagonal':
  398. anorm = _matrix_norm_diagonal(a1, check_finite)
  399. elif assume_a == 'tridiagonal':
  400. anorm = _matrix_norm_tridiagonal(norm, a1, check_finite)
  401. elif assume_a == 'banded':
  402. n_below, n_above = bandwidth(a1) if n_below is None else (n_below, n_above)
  403. a2, n_below, n_above = ((a1.T, n_above, n_below) if transposed
  404. else (a1, n_below, n_above))
  405. ab = _to_banded(n_below, n_above, a2)
  406. anorm = _matrix_norm_banded(n_below, n_above, norm, ab, check_finite)
  407. elif assume_a in {'lower triangular', 'upper triangular'}:
  408. anorm = _matrix_norm_triangular(assume_a, norm, a1, check_finite)
  409. else:
  410. anorm = _matrix_norm_general(norm, a1, check_finite)
  411. info, rcond = 0, np.inf
  412. # Generalized case 'gesv'
  413. if assume_a in {'general', 'gen'}:
  414. gecon, getrf, getrs = get_lapack_funcs(('gecon', 'getrf', 'getrs'),
  415. (a1, b1))
  416. lu, ipvt, info = getrf(a1, overwrite_a=overwrite_a)
  417. _solve_check(n, info)
  418. x, info = getrs(lu, ipvt, b1,
  419. trans=trans, overwrite_b=overwrite_b)
  420. _solve_check(n, info)
  421. rcond, info = gecon(lu, anorm, norm=norm)
  422. # Hermitian case 'hesv'
  423. elif assume_a in {'hermitian', 'her'}:
  424. hecon, hesv, hesv_lw = get_lapack_funcs(('hecon', 'hesv',
  425. 'hesv_lwork'), (a1, b1))
  426. lwork = _compute_lwork(hesv_lw, n, lower)
  427. lu, ipvt, x, info = hesv(a1, b1, lwork=lwork,
  428. lower=lower,
  429. overwrite_a=overwrite_a,
  430. overwrite_b=overwrite_b)
  431. _solve_check(n, info)
  432. rcond, info = hecon(lu, ipvt, anorm, lower=lower)
  433. # Symmetric case 'sysv'
  434. elif assume_a in {'symmetric', 'sym'}:
  435. sycon, sysv, sysv_lw = get_lapack_funcs(('sycon', 'sysv',
  436. 'sysv_lwork'), (a1, b1))
  437. lwork = _compute_lwork(sysv_lw, n, lower)
  438. lu, ipvt, x, info = sysv(a1, b1, lwork=lwork,
  439. lower=lower,
  440. overwrite_a=overwrite_a,
  441. overwrite_b=overwrite_b)
  442. _solve_check(n, info)
  443. rcond, info = sycon(lu, ipvt, anorm, lower=lower)
  444. # Diagonal case
  445. elif assume_a == 'diagonal':
  446. diag_a = np.diag(a1)
  447. x = (b1.T / diag_a).T
  448. abs_diag_a = np.abs(diag_a)
  449. diag_min = abs_diag_a.min()
  450. rcond = diag_min if diag_min == 0 else diag_min / abs_diag_a.max()
  451. # Tri-diagonal case
  452. elif assume_a == 'tridiagonal':
  453. a1 = a1.T if transposed else a1
  454. dl, d, du = np.diag(a1, -1), np.diag(a1, 0), np.diag(a1, 1)
  455. _gttrf, _gttrs, _gtcon = get_lapack_funcs(('gttrf', 'gttrs', 'gtcon'), (a1, b1))
  456. dl, d, du, du2, ipiv, info = _gttrf(dl, d, du)
  457. _solve_check(n, info)
  458. x, info = _gttrs(dl, d, du, du2, ipiv, b1, overwrite_b=overwrite_b)
  459. _solve_check(n, info)
  460. rcond, info = _gtcon(dl, d, du, du2, ipiv, anorm)
  461. # Banded case
  462. elif assume_a == 'banded':
  463. gbsv, gbcon = get_lapack_funcs(('gbsv', 'gbcon'), (a1, b1))
  464. # Next two lines copied from `solve_banded`
  465. a2 = np.zeros((2*n_below + n_above + 1, ab.shape[1]), dtype=gbsv.dtype)
  466. a2[n_below:, :] = ab
  467. lu, piv, x, info = gbsv(n_below, n_above, a2, b1,
  468. overwrite_ab=True, overwrite_b=overwrite_b)
  469. _solve_check(n, info)
  470. rcond, info = gbcon(n_below, n_above, lu, piv, anorm)
  471. # Triangular case
  472. elif assume_a in {'lower triangular', 'upper triangular'}:
  473. lower = assume_a == 'lower triangular'
  474. x, info = _solve_triangular(a1, b1, lower=lower, overwrite_b=overwrite_b,
  475. trans=transposed)
  476. _solve_check(n, info)
  477. _trcon = get_lapack_funcs(('trcon'), (a1, b1))
  478. rcond, info = _trcon(a1, uplo='L' if lower else 'U')
  479. # Positive definite case 'posv'
  480. else:
  481. pocon, posv = get_lapack_funcs(('pocon', 'posv'),
  482. (a1, b1))
  483. lu, x, info = posv(a1, b1, lower=lower,
  484. overwrite_a=overwrite_a,
  485. overwrite_b=overwrite_b)
  486. _solve_check(n, info)
  487. rcond, info = pocon(lu, anorm)
  488. _solve_check(n, info, lamch, rcond)
  489. if b_is_1D:
  490. x = x.ravel()
  491. return x
  492. def _matrix_norm_diagonal(a, check_finite):
  493. # Equivalent of dlange for diagonal matrix, assuming
  494. # norm is either 'I' or '1' (really just not the Frobenius norm)
  495. d = np.diag(a)
  496. d = np.asarray_chkfinite(d) if check_finite else d
  497. return np.abs(d).max()
  498. def _matrix_norm_tridiagonal(norm, a, check_finite):
  499. # Equivalent of dlange for tridiagonal matrix, assuming
  500. # norm is either 'I' or '1'
  501. if norm == 'I':
  502. a = a.T
  503. # Context to avoid warning before error in cases like -inf + inf
  504. with np.errstate(invalid='ignore'):
  505. d = np.abs(np.diag(a))
  506. d[1:] += np.abs(np.diag(a, 1))
  507. d[:-1] += np.abs(np.diag(a, -1))
  508. d = np.asarray_chkfinite(d) if check_finite else d
  509. return d.max()
  510. def _matrix_norm_triangular(structure, norm, a, check_finite):
  511. a = np.asarray_chkfinite(a) if check_finite else a
  512. lantr = get_lapack_funcs('lantr', (a,))
  513. return lantr(norm, a, 'L' if structure == 'lower triangular' else 'U' )
  514. def _matrix_norm_banded(kl, ku, norm, ab, check_finite):
  515. ab = np.asarray_chkfinite(ab) if check_finite else ab
  516. langb = get_lapack_funcs('langb', (ab,))
  517. return langb(norm, kl, ku, ab)
  518. def _matrix_norm_general(norm, a, check_finite):
  519. a = np.asarray_chkfinite(a) if check_finite else a
  520. lange = get_lapack_funcs('lange', (a,))
  521. return lange(norm, a)
  522. def _to_banded(n_below, n_above, a):
  523. n = a.shape[0]
  524. rows = n_above + n_below + 1
  525. ab = np.zeros((rows, n), dtype=a.dtype)
  526. ab[n_above] = np.diag(a)
  527. for i in range(1, n_above + 1):
  528. ab[n_above - i, i:] = np.diag(a, i)
  529. for i in range(1, n_below + 1):
  530. ab[n_above + i, :-i] = np.diag(a, -i)
  531. return ab
  532. def _ensure_dtype_cdsz(*arrays):
  533. # Ensure that the dtype of arrays is one of the standard types
  534. # compatible with LAPACK functions (single or double precision
  535. # real or complex).
  536. dtype = np.result_type(*arrays)
  537. if not np.issubdtype(dtype, np.inexact):
  538. return (array.astype(np.float64) for array in arrays)
  539. complex = np.issubdtype(dtype, np.complexfloating)
  540. if np.finfo(dtype).bits <= 32:
  541. dtype = np.complex64 if complex else np.float32
  542. elif np.finfo(dtype).bits >= 64:
  543. dtype = np.complex128 if complex else np.float64
  544. return (array.astype(dtype, copy=False) for array in arrays)
  545. @_apply_over_batch(('a', 2), ('b', '1|2'))
  546. def solve_triangular(a, b, trans=0, lower=False, unit_diagonal=False,
  547. overwrite_b=False, check_finite=True):
  548. """
  549. Solve the equation ``a @ x = b`` for ``x``, where `a` is a triangular matrix.
  550. Parameters
  551. ----------
  552. a : (M, M) array_like
  553. A triangular matrix
  554. b : (M,) or (M, N) array_like
  555. Right-hand side matrix in ``a x = b``
  556. lower : bool, optional
  557. Use only data contained in the lower triangle of `a`.
  558. Default is to use upper triangle.
  559. trans : {0, 1, 2, 'N', 'T', 'C'}, optional
  560. Type of system to solve:
  561. ======== =========
  562. trans system
  563. ======== =========
  564. 0 or 'N' a x = b
  565. 1 or 'T' a^T x = b
  566. 2 or 'C' a^H x = b
  567. ======== =========
  568. unit_diagonal : bool, optional
  569. If True, diagonal elements of `a` are assumed to be 1 and
  570. will not be referenced.
  571. overwrite_b : bool, optional
  572. Allow overwriting data in `b` (may enhance performance)
  573. check_finite : bool, optional
  574. Whether to check that the input matrices contain only finite numbers.
  575. Disabling may give a performance gain, but may result in problems
  576. (crashes, non-termination) if the inputs do contain infinities or NaNs.
  577. Returns
  578. -------
  579. x : (M,) or (M, N) ndarray
  580. Solution to the system ``a x = b``. Shape of return matches `b`.
  581. Raises
  582. ------
  583. LinAlgError
  584. If `a` is singular
  585. Notes
  586. -----
  587. .. versionadded:: 0.9.0
  588. Examples
  589. --------
  590. Solve the lower triangular system a x = b, where::
  591. [3 0 0 0] [4]
  592. a = [2 1 0 0] b = [2]
  593. [1 0 1 0] [4]
  594. [1 1 1 1] [2]
  595. >>> import numpy as np
  596. >>> from scipy.linalg import solve_triangular
  597. >>> a = np.array([[3, 0, 0, 0], [2, 1, 0, 0], [1, 0, 1, 0], [1, 1, 1, 1]])
  598. >>> b = np.array([4, 2, 4, 2])
  599. >>> x = solve_triangular(a, b, lower=True)
  600. >>> x
  601. array([ 1.33333333, -0.66666667, 2.66666667, -1.33333333])
  602. >>> a.dot(x) # Check the result
  603. array([ 4., 2., 4., 2.])
  604. """
  605. a1 = _asarray_validated(a, check_finite=check_finite)
  606. b1 = _asarray_validated(b, check_finite=check_finite)
  607. if len(a1.shape) != 2 or a1.shape[0] != a1.shape[1]:
  608. raise ValueError('expected square matrix')
  609. if a1.shape[0] != b1.shape[0]:
  610. raise ValueError(f'shapes of a {a1.shape} and b {b1.shape} are incompatible')
  611. # accommodate empty arrays
  612. if b1.size == 0:
  613. dt_nonempty = solve_triangular(
  614. np.eye(2, dtype=a1.dtype), np.ones(2, dtype=b1.dtype)
  615. ).dtype
  616. return np.empty_like(b1, dtype=dt_nonempty)
  617. overwrite_b = overwrite_b or _datacopied(b1, b)
  618. x, _ = _solve_triangular(a1, b1, trans, lower, unit_diagonal, overwrite_b)
  619. return x
  620. # solve_triangular without the input validation
  621. def _solve_triangular(a1, b1, trans=0, lower=False, unit_diagonal=False,
  622. overwrite_b=False):
  623. trans = {'N': 0, 'T': 1, 'C': 2}.get(trans, trans)
  624. trtrs, = get_lapack_funcs(('trtrs',), (a1, b1))
  625. if a1.flags.f_contiguous or trans == 2:
  626. x, info = trtrs(a1, b1, overwrite_b=overwrite_b, lower=lower,
  627. trans=trans, unitdiag=unit_diagonal)
  628. else:
  629. # transposed system is solved since trtrs expects Fortran ordering
  630. x, info = trtrs(a1.T, b1, overwrite_b=overwrite_b, lower=not lower,
  631. trans=not trans, unitdiag=unit_diagonal)
  632. if info == 0:
  633. return x, info
  634. if info > 0:
  635. raise LinAlgError(f"singular matrix: resolution failed at diagonal {info-1}")
  636. raise ValueError(f'illegal value in {-info}-th argument of internal trtrs')
  637. def solve_banded(l_and_u, ab, b, overwrite_ab=False, overwrite_b=False,
  638. check_finite=True):
  639. """
  640. Solve the equation ``a @ x = b`` for ``x``, where ``a`` is the banded matrix
  641. defined by `ab`.
  642. The matrix a is stored in `ab` using the matrix diagonal ordered form::
  643. ab[u + i - j, j] == a[i,j]
  644. Example of `ab` (shape of a is (6,6), `u` =1, `l` =2)::
  645. * a01 a12 a23 a34 a45
  646. a00 a11 a22 a33 a44 a55
  647. a10 a21 a32 a43 a54 *
  648. a20 a31 a42 a53 * *
  649. The documentation is written assuming array arguments are of specified
  650. "core" shapes. However, array argument(s) of this function may have additional
  651. "batch" dimensions prepended to the core shape. In this case, the array is treated
  652. as a batch of lower-dimensional slices; see :ref:`linalg_batch` for details.
  653. Parameters
  654. ----------
  655. (l, u) : (integer, integer)
  656. Number of non-zero lower and upper diagonals
  657. ab : (`l` + `u` + 1, M) array_like
  658. Banded matrix
  659. b : (M,) or (M, K) array_like
  660. Right-hand side
  661. overwrite_ab : bool, optional
  662. Discard data in `ab` (may enhance performance)
  663. overwrite_b : bool, optional
  664. Discard data in `b` (may enhance performance)
  665. check_finite : bool, optional
  666. Whether to check that the input matrices contain only finite numbers.
  667. Disabling may give a performance gain, but may result in problems
  668. (crashes, non-termination) if the inputs do contain infinities or NaNs.
  669. Returns
  670. -------
  671. x : (M,) or (M, K) ndarray
  672. The solution to the system a x = b. Returned shape depends on the
  673. shape of `b`.
  674. Examples
  675. --------
  676. Solve the banded system a x = b, where::
  677. [5 2 -1 0 0] [0]
  678. [1 4 2 -1 0] [1]
  679. a = [0 1 3 2 -1] b = [2]
  680. [0 0 1 2 2] [2]
  681. [0 0 0 1 1] [3]
  682. There is one nonzero diagonal below the main diagonal (l = 1), and
  683. two above (u = 2). The diagonal banded form of the matrix is::
  684. [* * -1 -1 -1]
  685. ab = [* 2 2 2 2]
  686. [5 4 3 2 1]
  687. [1 1 1 1 *]
  688. >>> import numpy as np
  689. >>> from scipy.linalg import solve_banded
  690. >>> ab = np.array([[0, 0, -1, -1, -1],
  691. ... [0, 2, 2, 2, 2],
  692. ... [5, 4, 3, 2, 1],
  693. ... [1, 1, 1, 1, 0]])
  694. >>> b = np.array([0, 1, 2, 2, 3])
  695. >>> x = solve_banded((1, 2), ab, b)
  696. >>> x
  697. array([-2.37288136, 3.93220339, -4. , 4.3559322 , -1.3559322 ])
  698. """
  699. (nlower, nupper) = l_and_u
  700. return _solve_banded(nlower, nupper, ab, b, overwrite_ab=overwrite_ab,
  701. overwrite_b=overwrite_b, check_finite=check_finite)
  702. @_apply_over_batch(('nlower', 0), ('nupper', 0), ('ab', 2), ('b', '1|2'))
  703. def _solve_banded(nlower, nupper, ab, b, overwrite_ab, overwrite_b, check_finite):
  704. a1 = _asarray_validated(ab, check_finite=check_finite, as_inexact=True)
  705. b1 = _asarray_validated(b, check_finite=check_finite, as_inexact=True)
  706. # Validate shapes.
  707. if a1.shape[-1] != b1.shape[0]:
  708. raise ValueError("shapes of ab and b are not compatible.")
  709. if nlower + nupper + 1 != a1.shape[0]:
  710. raise ValueError(
  711. f"invalid values for the number of lower and upper diagonals: l+u+1 "
  712. f"({nlower + nupper + 1}) does not equal ab.shape[0] ({ab.shape[0]})"
  713. )
  714. # accommodate empty arrays
  715. if b1.size == 0:
  716. dt = solve(np.eye(1, dtype=a1.dtype), np.ones(1, dtype=b1.dtype)).dtype
  717. return np.empty_like(b1, dtype=dt)
  718. overwrite_b = overwrite_b or _datacopied(b1, b)
  719. if a1.shape[-1] == 1:
  720. b2 = np.array(b1, copy=(not overwrite_b))
  721. # a1.shape[-1] == 1 -> original matrix is 1x1. Typically, the user
  722. # will pass u = l = 0 and `a1` will be 1x1. However, the rest of the
  723. # function works with unnecessary rows in `a1` as long as
  724. # `a1[u + i - j, j] == a[i,j]`. In the 1x1 case, we want i = j = 0,
  725. # so the diagonal is in row `u` of `a1`. See gh-8906.
  726. b2 /= a1[nupper, 0]
  727. return b2
  728. if nlower == nupper == 1:
  729. overwrite_ab = overwrite_ab or _datacopied(a1, ab)
  730. gtsv, = get_lapack_funcs(('gtsv',), (a1, b1))
  731. du = a1[0, 1:]
  732. d = a1[1, :]
  733. dl = a1[2, :-1]
  734. du2, d, du, x, info = gtsv(dl, d, du, b1, overwrite_ab, overwrite_ab,
  735. overwrite_ab, overwrite_b)
  736. else:
  737. gbsv, = get_lapack_funcs(('gbsv',), (a1, b1))
  738. a2 = np.zeros((2*nlower + nupper + 1, a1.shape[1]), dtype=gbsv.dtype)
  739. a2[nlower:, :] = a1
  740. lu, piv, x, info = gbsv(nlower, nupper, a2, b1, overwrite_ab=True,
  741. overwrite_b=overwrite_b)
  742. if info == 0:
  743. return x
  744. if info > 0:
  745. raise LinAlgError("singular matrix")
  746. raise ValueError(f'illegal value in {-info}-th argument of internal gbsv/gtsv')
  747. @_apply_over_batch(('a', 2), ('b', '1|2'))
  748. def solveh_banded(ab, b, overwrite_ab=False, overwrite_b=False, lower=False,
  749. check_finite=True):
  750. """
  751. Solve the equation ``a @ x = b`` for ``x``, where ``a`` is the
  752. Hermitian positive-definite banded matrix defined by `ab`.
  753. Uses Thomas' Algorithm, which is more efficient than standard LU
  754. factorization, but should only be used for Hermitian positive-definite
  755. matrices.
  756. The matrix ``a`` is stored in `ab` either in lower diagonal or upper
  757. diagonal ordered form:
  758. ab[u + i - j, j] == a[i,j] (if upper form; i <= j)
  759. ab[ i - j, j] == a[i,j] (if lower form; i >= j)
  760. Example of `ab` (shape of ``a`` is (6, 6), number of upper diagonals,
  761. ``u`` =2)::
  762. upper form:
  763. * * a02 a13 a24 a35
  764. * a01 a12 a23 a34 a45
  765. a00 a11 a22 a33 a44 a55
  766. lower form:
  767. a00 a11 a22 a33 a44 a55
  768. a10 a21 a32 a43 a54 *
  769. a20 a31 a42 a53 * *
  770. Cells marked with * are not used.
  771. Parameters
  772. ----------
  773. ab : (``u`` + 1, M) array_like
  774. Banded matrix
  775. b : (M,) or (M, K) array_like
  776. Right-hand side
  777. overwrite_ab : bool, optional
  778. Discard data in `ab` (may enhance performance)
  779. overwrite_b : bool, optional
  780. Discard data in `b` (may enhance performance)
  781. lower : bool, optional
  782. Is the matrix in the lower form. (Default is upper form)
  783. check_finite : bool, optional
  784. Whether to check that the input matrices contain only finite numbers.
  785. Disabling may give a performance gain, but may result in problems
  786. (crashes, non-termination) if the inputs do contain infinities or NaNs.
  787. Returns
  788. -------
  789. x : (M,) or (M, K) ndarray
  790. The solution to the system ``a x = b``. Shape of return matches shape
  791. of `b`.
  792. Notes
  793. -----
  794. In the case of a non-positive definite matrix ``a``, the solver
  795. `solve_banded` may be used.
  796. Examples
  797. --------
  798. Solve the banded system ``A x = b``, where::
  799. [ 4 2 -1 0 0 0] [1]
  800. [ 2 5 2 -1 0 0] [2]
  801. A = [-1 2 6 2 -1 0] b = [2]
  802. [ 0 -1 2 7 2 -1] [3]
  803. [ 0 0 -1 2 8 2] [3]
  804. [ 0 0 0 -1 2 9] [3]
  805. >>> import numpy as np
  806. >>> from scipy.linalg import solveh_banded
  807. ``ab`` contains the main diagonal and the nonzero diagonals below the
  808. main diagonal. That is, we use the lower form:
  809. >>> ab = np.array([[ 4, 5, 6, 7, 8, 9],
  810. ... [ 2, 2, 2, 2, 2, 0],
  811. ... [-1, -1, -1, -1, 0, 0]])
  812. >>> b = np.array([1, 2, 2, 3, 3, 3])
  813. >>> x = solveh_banded(ab, b, lower=True)
  814. >>> x
  815. array([ 0.03431373, 0.45938375, 0.05602241, 0.47759104, 0.17577031,
  816. 0.34733894])
  817. Solve the Hermitian banded system ``H x = b``, where::
  818. [ 8 2-1j 0 0 ] [ 1 ]
  819. H = [2+1j 5 1j 0 ] b = [1+1j]
  820. [ 0 -1j 9 -2-1j] [1-2j]
  821. [ 0 0 -2+1j 6 ] [ 0 ]
  822. In this example, we put the upper diagonals in the array ``hb``:
  823. >>> hb = np.array([[0, 2-1j, 1j, -2-1j],
  824. ... [8, 5, 9, 6 ]])
  825. >>> b = np.array([1, 1+1j, 1-2j, 0])
  826. >>> x = solveh_banded(hb, b)
  827. >>> x
  828. array([ 0.07318536-0.02939412j, 0.11877624+0.17696461j,
  829. 0.10077984-0.23035393j, -0.00479904-0.09358128j])
  830. """
  831. a1 = _asarray_validated(ab, check_finite=check_finite)
  832. b1 = _asarray_validated(b, check_finite=check_finite)
  833. # Validate shapes.
  834. if a1.shape[-1] != b1.shape[0]:
  835. raise ValueError("shapes of ab and b are not compatible.")
  836. # accommodate empty arrays
  837. if b1.size == 0:
  838. dt = solve(np.eye(1, dtype=a1.dtype), np.ones(1, dtype=b1.dtype)).dtype
  839. return np.empty_like(b1, dtype=dt)
  840. overwrite_b = overwrite_b or _datacopied(b1, b)
  841. overwrite_ab = overwrite_ab or _datacopied(a1, ab)
  842. if a1.shape[0] == 2:
  843. ptsv, = get_lapack_funcs(('ptsv',), (a1, b1))
  844. if lower:
  845. d = a1[0, :].real
  846. e = a1[1, :-1]
  847. else:
  848. d = a1[1, :].real
  849. e = a1[0, 1:].conj()
  850. d, du, x, info = ptsv(d, e, b1, overwrite_ab, overwrite_ab,
  851. overwrite_b)
  852. else:
  853. pbsv, = get_lapack_funcs(('pbsv',), (a1, b1))
  854. c, x, info = pbsv(a1, b1, lower=lower, overwrite_ab=overwrite_ab,
  855. overwrite_b=overwrite_b)
  856. if info > 0:
  857. raise LinAlgError(f"{info}th leading minor not positive definite")
  858. if info < 0:
  859. raise ValueError(f'illegal value in {-info}th argument of internal pbsv')
  860. return x
  861. def solve_toeplitz(c_or_cr, b, check_finite=True):
  862. r"""Solve the equation ``T @ x = b`` for ``x``, where ``T`` is a Toeplitz
  863. matrix defined by `c_or_cr`.
  864. The Toeplitz matrix has constant diagonals, with ``c`` as its first column
  865. and ``r`` as its first row. If ``r`` is not given, ``r == conjugate(c)`` is
  866. assumed.
  867. The documentation is written assuming array arguments are of specified
  868. "core" shapes. However, array argument(s) of this function may have additional
  869. "batch" dimensions prepended to the core shape. In this case, the array is treated
  870. as a batch of lower-dimensional slices; see :ref:`linalg_batch` for details.
  871. Parameters
  872. ----------
  873. c_or_cr : array_like or tuple of (array_like, array_like)
  874. The vector ``c``, or a tuple of arrays (``c``, ``r``). If not
  875. supplied, ``r = conjugate(c)`` is assumed; in this case, if c[0] is
  876. real, the Toeplitz matrix is Hermitian. r[0] is ignored; the first row
  877. of the Toeplitz matrix is ``[c[0], r[1:]]``.
  878. b : (M,) or (M, K) array_like
  879. Right-hand side in ``T x = b``.
  880. check_finite : bool, optional
  881. Whether to check that the input matrices contain only finite numbers.
  882. Disabling may give a performance gain, but may result in problems
  883. (result entirely NaNs) if the inputs do contain infinities or NaNs.
  884. Returns
  885. -------
  886. x : (M,) or (M, K) ndarray
  887. The solution to the system ``T @ x = b``. Shape of return matches shape
  888. of `b`.
  889. See Also
  890. --------
  891. toeplitz : Toeplitz matrix
  892. Notes
  893. -----
  894. The solution is computed using Levinson-Durbin recursion, which is faster
  895. than generic least-squares methods, but can be less numerically stable.
  896. Examples
  897. --------
  898. Solve the Toeplitz system ``T @ x = b``, where::
  899. [ 1 -1 -2 -3] [1]
  900. T = [ 3 1 -1 -2] b = [2]
  901. [ 6 3 1 -1] [2]
  902. [10 6 3 1] [5]
  903. To specify the Toeplitz matrix, only the first column and the first
  904. row are needed.
  905. >>> import numpy as np
  906. >>> c = np.array([1, 3, 6, 10]) # First column of T
  907. >>> r = np.array([1, -1, -2, -3]) # First row of T
  908. >>> b = np.array([1, 2, 2, 5])
  909. >>> from scipy.linalg import solve_toeplitz, toeplitz
  910. >>> x = solve_toeplitz((c, r), b)
  911. >>> x
  912. array([ 1.66666667, -1. , -2.66666667, 2.33333333])
  913. Check the result by creating the full Toeplitz matrix and
  914. multiplying it by ``x``. We should get `b`.
  915. >>> T = toeplitz(c, r)
  916. >>> T.dot(x)
  917. array([ 1., 2., 2., 5.])
  918. """
  919. # If numerical stability of this algorithm is a problem, a future
  920. # developer might consider implementing other O(N^2) Toeplitz solvers,
  921. # such as GKO (https://www.jstor.org/stable/2153371) or Bareiss.
  922. c, r = c_or_cr if isinstance(c_or_cr, tuple) else (c_or_cr, np.conjugate(c_or_cr))
  923. return _solve_toeplitz(c, r, b, check_finite)
  924. @_apply_over_batch(('c', 1), ('r', 1), ('b', '1|2'))
  925. def _solve_toeplitz(c, r, b, check_finite):
  926. r, c, b, dtype, b_shape = _validate_args_for_toeplitz_ops(
  927. (c, r), b, check_finite, keep_b_shape=True)
  928. # accommodate empty arrays
  929. if b.size == 0:
  930. return np.empty_like(b)
  931. # Form a 1-D array of values to be used in the matrix, containing a
  932. # reversed copy of r[1:], followed by c.
  933. vals = np.concatenate((r[-1:0:-1], c))
  934. if b is None:
  935. raise ValueError('illegal value, `b` is a required argument')
  936. if b.ndim == 1:
  937. x, _ = levinson(vals, np.ascontiguousarray(b))
  938. else:
  939. x = np.column_stack([levinson(vals, np.ascontiguousarray(b[:, i]))[0]
  940. for i in range(b.shape[1])])
  941. x = x.reshape(*b_shape)
  942. return x
  943. def _get_axis_len(aname, a, axis):
  944. ax = axis
  945. if ax < 0:
  946. ax += a.ndim
  947. if 0 <= ax < a.ndim:
  948. return a.shape[ax]
  949. raise ValueError(f"'{aname}axis' entry is out of bounds")
  950. def solve_circulant(c, b, singular='raise', tol=None,
  951. caxis=-1, baxis=0, outaxis=0):
  952. """Solve the equation ``C @ x = b`` for ``x``, where ``C`` is a
  953. circulant matrix defined by `c`.
  954. `C` is the circulant matrix associated with the vector `c`.
  955. The system is solved by doing division in Fourier space. The
  956. calculation is::
  957. x = ifft(fft(b) / fft(c))
  958. where `fft` and `ifft` are the fast Fourier transform and its inverse,
  959. respectively. For a large vector `c`, this is *much* faster than
  960. solving the system with the full circulant matrix.
  961. Parameters
  962. ----------
  963. c : array_like
  964. The coefficients of the circulant matrix.
  965. b : array_like
  966. Right-hand side matrix in ``a x = b``.
  967. singular : str, optional
  968. This argument controls how a near singular circulant matrix is
  969. handled. If `singular` is "raise" and the circulant matrix is
  970. near singular, a `LinAlgError` is raised. If `singular` is
  971. "lstsq", the least squares solution is returned. Default is "raise".
  972. tol : float, optional
  973. If any eigenvalue of the circulant matrix has an absolute value
  974. that is less than or equal to `tol`, the matrix is considered to be
  975. near singular. If not given, `tol` is set to::
  976. tol = abs_eigs.max() * abs_eigs.size * np.finfo(np.float64).eps
  977. where `abs_eigs` is the array of absolute values of the eigenvalues
  978. of the circulant matrix.
  979. caxis : int
  980. When `c` has dimension greater than 1, it is viewed as a collection
  981. of circulant vectors. In this case, `caxis` is the axis of `c` that
  982. holds the vectors of circulant coefficients.
  983. baxis : int
  984. When `b` has dimension greater than 1, it is viewed as a collection
  985. of vectors. In this case, `baxis` is the axis of `b` that holds the
  986. right-hand side vectors.
  987. outaxis : int
  988. When `c` or `b` are multidimensional, the value returned by
  989. `solve_circulant` is multidimensional. In this case, `outaxis` is
  990. the axis of the result that holds the solution vectors.
  991. Returns
  992. -------
  993. x : ndarray
  994. Solution to the system ``C x = b``.
  995. Raises
  996. ------
  997. LinAlgError
  998. If the circulant matrix associated with `c` is near singular.
  999. See Also
  1000. --------
  1001. circulant : circulant matrix
  1002. Notes
  1003. -----
  1004. For a 1-D vector `c` with length `m`, and an array `b`
  1005. with shape ``(m, ...)``,
  1006. solve_circulant(c, b)
  1007. returns the same result as
  1008. solve(circulant(c), b)
  1009. where `solve` and `circulant` are from `scipy.linalg`.
  1010. .. versionadded:: 0.16.0
  1011. Examples
  1012. --------
  1013. >>> import numpy as np
  1014. >>> from scipy.linalg import solve_circulant, solve, circulant, lstsq
  1015. >>> c = np.array([2, 2, 4])
  1016. >>> b = np.array([1, 2, 3])
  1017. >>> solve_circulant(c, b)
  1018. array([ 0.75, -0.25, 0.25])
  1019. Compare that result to solving the system with `scipy.linalg.solve`:
  1020. >>> solve(circulant(c), b)
  1021. array([ 0.75, -0.25, 0.25])
  1022. A singular example:
  1023. >>> c = np.array([1, 1, 0, 0])
  1024. >>> b = np.array([1, 2, 3, 4])
  1025. Calling ``solve_circulant(c, b)`` will raise a `LinAlgError`. For the
  1026. least square solution, use the option ``singular='lstsq'``:
  1027. >>> solve_circulant(c, b, singular='lstsq')
  1028. array([ 0.25, 1.25, 2.25, 1.25])
  1029. Compare to `scipy.linalg.lstsq`:
  1030. >>> x, resid, rnk, s = lstsq(circulant(c), b)
  1031. >>> x
  1032. array([ 0.25, 1.25, 2.25, 1.25])
  1033. A broadcasting example:
  1034. Suppose we have the vectors of two circulant matrices stored in an array
  1035. with shape (2, 5), and three `b` vectors stored in an array with shape
  1036. (3, 5). For example,
  1037. >>> c = np.array([[1.5, 2, 3, 0, 0], [1, 1, 4, 3, 2]])
  1038. >>> b = np.arange(15).reshape(-1, 5)
  1039. We want to solve all combinations of circulant matrices and `b` vectors,
  1040. with the result stored in an array with shape (2, 3, 5). When we
  1041. disregard the axes of `c` and `b` that hold the vectors of coefficients,
  1042. the shapes of the collections are (2,) and (3,), respectively, which are
  1043. not compatible for broadcasting. To have a broadcast result with shape
  1044. (2, 3), we add a trivial dimension to `c`: ``c[:, np.newaxis, :]`` has
  1045. shape (2, 1, 5). The last dimension holds the coefficients of the
  1046. circulant matrices, so when we call `solve_circulant`, we can use the
  1047. default ``caxis=-1``. The coefficients of the `b` vectors are in the last
  1048. dimension of the array `b`, so we use ``baxis=-1``. If we use the
  1049. default `outaxis`, the result will have shape (5, 2, 3), so we'll use
  1050. ``outaxis=-1`` to put the solution vectors in the last dimension.
  1051. >>> x = solve_circulant(c[:, np.newaxis, :], b, baxis=-1, outaxis=-1)
  1052. >>> x.shape
  1053. (2, 3, 5)
  1054. >>> np.set_printoptions(precision=3) # For compact output of numbers.
  1055. >>> x
  1056. array([[[-0.118, 0.22 , 1.277, -0.142, 0.302],
  1057. [ 0.651, 0.989, 2.046, 0.627, 1.072],
  1058. [ 1.42 , 1.758, 2.816, 1.396, 1.841]],
  1059. [[ 0.401, 0.304, 0.694, -0.867, 0.377],
  1060. [ 0.856, 0.758, 1.149, -0.412, 0.831],
  1061. [ 1.31 , 1.213, 1.603, 0.042, 1.286]]])
  1062. Check by solving one pair of `c` and `b` vectors (cf. ``x[1, 1, :]``):
  1063. >>> solve_circulant(c[1], b[1, :])
  1064. array([ 0.856, 0.758, 1.149, -0.412, 0.831])
  1065. """
  1066. c = np.atleast_1d(c)
  1067. nc = _get_axis_len("c", c, caxis)
  1068. b = np.atleast_1d(b)
  1069. nb = _get_axis_len("b", b, baxis)
  1070. if nc != nb:
  1071. raise ValueError(f'Shapes of c {c.shape} and b {b.shape} are incompatible')
  1072. # accommodate empty arrays
  1073. if b.size == 0:
  1074. dt = solve_circulant(np.arange(3, dtype=c.dtype),
  1075. np.ones(3, dtype=b.dtype)).dtype
  1076. return np.empty_like(b, dtype=dt)
  1077. fc = np.fft.fft(np.moveaxis(c, caxis, -1), axis=-1)
  1078. abs_fc = np.abs(fc)
  1079. if tol is None:
  1080. # This is the same tolerance as used in np.linalg.matrix_rank.
  1081. tol = abs_fc.max(axis=-1) * nc * np.finfo(np.float64).eps
  1082. if tol.shape != ():
  1083. tol = tol.reshape(tol.shape + (1,))
  1084. else:
  1085. tol = np.atleast_1d(tol)
  1086. near_zeros = abs_fc <= tol
  1087. is_near_singular = np.any(near_zeros)
  1088. if is_near_singular:
  1089. if singular == 'raise':
  1090. raise LinAlgError("near singular circulant matrix.")
  1091. else:
  1092. # Replace the small values with 1 to avoid errors in the
  1093. # division fb/fc below.
  1094. fc[near_zeros] = 1
  1095. fb = np.fft.fft(np.moveaxis(b, baxis, -1), axis=-1)
  1096. q = fb / fc
  1097. if is_near_singular:
  1098. # `near_zeros` is a boolean array, same shape as `c`, that is
  1099. # True where `fc` is (near) zero. `q` is the broadcasted result
  1100. # of fb / fc, so to set the values of `q` to 0 where `fc` is near
  1101. # zero, we use a mask that is the broadcast result of an array
  1102. # of True values shaped like `b` with `near_zeros`.
  1103. mask = np.ones_like(b, dtype=bool) & near_zeros
  1104. q[mask] = 0
  1105. x = np.fft.ifft(q, axis=-1)
  1106. if not (np.iscomplexobj(c) or np.iscomplexobj(b)):
  1107. x = x.real
  1108. if outaxis != -1:
  1109. x = np.moveaxis(x, -1, outaxis)
  1110. return x
  1111. # matrix inversion
  1112. def inv(a, overwrite_a=False, check_finite=True, *, assume_a=None, lower=False):
  1113. r"""
  1114. Compute the inverse of a matrix.
  1115. If the data matrix is known to be a particular type then supplying the
  1116. corresponding string to ``assume_a`` key chooses the dedicated solver.
  1117. The available options are
  1118. ============================= ================================
  1119. general 'general' (or 'gen')
  1120. diagonal 'diagonal'
  1121. upper triangular 'upper triangular'
  1122. lower triangular 'lower triangular'
  1123. symmetric positive definite 'pos'
  1124. symmetric 'sym'
  1125. Hermitian 'her'
  1126. ============================= ================================
  1127. For the 'pos' option, only the triangle of the input matrix specified in
  1128. the `lower` argument is used, and the other triangle is not referenced.
  1129. Likewise, an explicit `assume_a='diagonal'` means that off-diagonal elements
  1130. are not referenced.
  1131. Array argument(s) of this function may have additional
  1132. "batch" dimensions prepended to the core shape. In this case, the array is treated
  1133. as a batch of lower-dimensional slices; see :ref:`linalg_batch` for details.
  1134. Parameters
  1135. ----------
  1136. a : array_like, shape (..., M, M)
  1137. Square matrix (or a batch of matrices) to be inverted.
  1138. overwrite_a : bool, optional
  1139. Discard data in `a` (may improve performance). Default is False.
  1140. check_finite : bool, optional
  1141. Whether to check that the input matrix contains only finite numbers.
  1142. Disabling may give a performance gain, but may result in problems
  1143. (crashes, non-termination) if the inputs do contain infinities or NaNs.
  1144. assume_a : str, optional
  1145. Valid entries are described above.
  1146. If omitted or ``None``, checks are performed to identify structure so the
  1147. appropriate solver can be called.
  1148. lower : bool, optional
  1149. Ignored unless `assume_a` is one of 'sym', 'her', or 'pos'. If True, the
  1150. calculation uses only the data in the lower triangle of `a`; entries above the
  1151. diagonal are ignored. If False (default), the calculation uses only the data in
  1152. the upper triangle of `a`; entries below the diagonal are ignored.
  1153. Returns
  1154. -------
  1155. ainv : ndarray
  1156. Inverse of the matrix `a`.
  1157. Raises
  1158. ------
  1159. LinAlgError
  1160. If `a` is singular.
  1161. ValueError
  1162. If `a` is not square, or not 2D.
  1163. Examples
  1164. --------
  1165. >>> import numpy as np
  1166. >>> from scipy import linalg
  1167. >>> a = np.array([[1., 2.], [3., 4.]])
  1168. >>> linalg.inv(a)
  1169. array([[-2. , 1. ],
  1170. [ 1.5, -0.5]])
  1171. >>> np.dot(a, linalg.inv(a))
  1172. array([[ 1., 0.],
  1173. [ 0., 1.]])
  1174. Notes
  1175. -----
  1176. The input array ``a`` may represent a single matrix or a collection (a.k.a.
  1177. a "batch") of square matrices. For example, if ``a.shape == (4, 3, 2, 2)``, it is
  1178. interpreted as a ``(4, 3)``-shaped batch of :math:`2\times 2` matrices.
  1179. This routine checks the condition number of the `a` matrix and emits a
  1180. `LinAlgWarning` for ill-conditioned inputs.
  1181. """
  1182. a1 = _asarray_validated(a, check_finite=check_finite)
  1183. if a1.ndim < 2:
  1184. raise ValueError(f"Expected at least ndim=2, got {a1.ndim=}")
  1185. if a1.shape[-1] != a1.shape[-2]:
  1186. raise ValueError(f"Expected square matrix, got {a1.shape=}")
  1187. # accommodate empty matrices
  1188. if a1.size == 0:
  1189. dt = inv(np.eye(2, dtype=a1.dtype)).dtype
  1190. return np.empty_like(a1, dtype=dt)
  1191. # Also check if dtype is LAPACK compatible
  1192. a1, overwrite_a = _normalize_lapack_dtype(a1, overwrite_a)
  1193. if not (a1.flags['ALIGNED'] or a1.dtype.byteorder == '='):
  1194. overwrite_a = True
  1195. a1 = a1.copy()
  1196. # keep the numbers in sync with C at `linalg/src/_common_array_utils.hh`
  1197. structure = {
  1198. None: -1,
  1199. 'general': 0, 'gen': 0,
  1200. 'diagonal': 11,
  1201. 'upper triangular': 21,
  1202. 'lower triangular': 22,
  1203. 'pos' : 101,
  1204. 'sym' : 201,
  1205. 'her' : 211,
  1206. }[assume_a]
  1207. # a1 is well behaved, invert it.
  1208. inv_a, err_lst = _batched_linalg._inv(a1, structure, overwrite_a, lower)
  1209. if err_lst:
  1210. _format_emit_errors_warnings(err_lst)
  1211. return inv_a
  1212. # Determinant
  1213. def det(a, overwrite_a=False, check_finite=True):
  1214. """
  1215. Compute the determinant of a matrix
  1216. The determinant is a scalar that is a function of the associated square
  1217. matrix coefficients. The determinant value is zero for singular matrices.
  1218. Array argument(s) of this function may have additional
  1219. "batch" dimensions prepended to the core shape. In this case, the array is treated
  1220. as a batch of lower-dimensional slices; see :ref:`linalg_batch` for details.
  1221. Parameters
  1222. ----------
  1223. a : (..., M, M) array_like
  1224. Input array to compute determinants for.
  1225. overwrite_a : bool, optional
  1226. Allow overwriting data in a (may enhance performance).
  1227. check_finite : bool, optional
  1228. Whether to check that the input matrix contains only finite numbers.
  1229. Disabling may give a performance gain, but may result in problems
  1230. (crashes, non-termination) if the inputs do contain infinities or NaNs.
  1231. Returns
  1232. -------
  1233. det : (...) float or complex
  1234. Determinant of `a`. For stacked arrays, a scalar is returned for each
  1235. (m, m) slice in the last two dimensions of the input. For example, an
  1236. input of shape (p, q, m, m) will produce a result of shape (p, q). If
  1237. all dimensions are 1 a scalar is returned regardless of ndim.
  1238. Notes
  1239. -----
  1240. The determinant is computed by performing an LU factorization of the
  1241. input with LAPACK routine 'getrf', and then calculating the product of
  1242. diagonal entries of the U factor.
  1243. Even if the input array is single precision (float32 or complex64), the
  1244. result will be returned in double precision (float64 or complex128) to
  1245. prevent overflows.
  1246. Examples
  1247. --------
  1248. >>> import numpy as np
  1249. >>> from scipy import linalg
  1250. >>> a = np.array([[1,2,3], [4,5,6], [7,8,9]]) # A singular matrix
  1251. >>> linalg.det(a)
  1252. 0.0
  1253. >>> b = np.array([[0,2,3], [4,5,6], [7,8,9]])
  1254. >>> linalg.det(b)
  1255. 3.0
  1256. >>> # An array with the shape (3, 2, 2, 2)
  1257. >>> c = np.array([[[[1., 2.], [3., 4.]],
  1258. ... [[5., 6.], [7., 8.]]],
  1259. ... [[[9., 10.], [11., 12.]],
  1260. ... [[13., 14.], [15., 16.]]],
  1261. ... [[[17., 18.], [19., 20.]],
  1262. ... [[21., 22.], [23., 24.]]]])
  1263. >>> linalg.det(c) # The resulting shape is (3, 2)
  1264. array([[-2., -2.],
  1265. [-2., -2.],
  1266. [-2., -2.]])
  1267. >>> linalg.det(c[0, 0]) # Confirm the (0, 0) slice, [[1, 2], [3, 4]]
  1268. -2.0
  1269. """
  1270. # The goal is to end up with a writable contiguous array to pass to Cython
  1271. # First we check and make arrays.
  1272. a1 = np.asarray_chkfinite(a) if check_finite else np.asarray(a)
  1273. if a1.ndim < 2:
  1274. raise ValueError('The input array must be at least two-dimensional.')
  1275. if a1.shape[-1] != a1.shape[-2]:
  1276. raise ValueError('Last 2 dimensions of the array must be square'
  1277. f' but received shape {a1.shape}.')
  1278. # Also check if dtype is LAPACK compatible
  1279. a1, overwrite_a = _normalize_lapack_dtype(a1, overwrite_a)
  1280. # Empty array has determinant 1 because math.
  1281. if min(*a1.shape) == 0:
  1282. dtyp = np.float64 if a1.dtype.char not in 'FD' else np.complex128
  1283. if a1.ndim == 2:
  1284. return dtyp(1.0)
  1285. else:
  1286. return np.ones(shape=a1.shape[:-2], dtype=dtyp)
  1287. # Scalar case
  1288. if a1.shape[-2:] == (1, 1):
  1289. a1 = a1[..., 0, 0]
  1290. if a1.ndim == 0:
  1291. a1 = a1[()]
  1292. # Convert float32 to float64, and complex64 to complex128.
  1293. if a1.dtype.char in 'dD':
  1294. return a1
  1295. return a1.astype('d') if a1.dtype.char == 'f' else a1.astype('D')
  1296. # Then check overwrite permission
  1297. if not _datacopied(a1, a): # "a" still alive through "a1"
  1298. if not overwrite_a:
  1299. # Data belongs to "a" so make a copy
  1300. a1 = a1.copy(order='C')
  1301. # else: Do nothing we'll use "a" if possible
  1302. # else: a1 has its own data thus free to scratch
  1303. # Then layout checks, might happen that overwrite is allowed but original
  1304. # array was read-only or non-C-contiguous.
  1305. if not (a1.flags['C_CONTIGUOUS'] and a1.flags['WRITEABLE']):
  1306. a1 = a1.copy(order='C')
  1307. if a1.ndim == 2:
  1308. det = find_det_from_lu(a1)
  1309. # Convert float, complex to NumPy scalars
  1310. return (np.float64(det) if np.isrealobj(det) else np.complex128(det))
  1311. # loop over the stacked array, and avoid overflows for single precision
  1312. # Cf. np.linalg.det(np.diag([1e+38, 1e+38]).astype(np.float32))
  1313. dtype_char = a1.dtype.char
  1314. if dtype_char in 'fF':
  1315. dtype_char = 'd' if dtype_char.islower() else 'D'
  1316. det = np.empty(a1.shape[:-2], dtype=dtype_char)
  1317. for ind in product(*[range(x) for x in a1.shape[:-2]]):
  1318. det[ind] = find_det_from_lu(a1[ind])
  1319. return det
  1320. # Linear Least Squares
  1321. @_apply_over_batch(('a', 2), ('b', '1|2'))
  1322. def lstsq(a, b, cond=None, overwrite_a=False, overwrite_b=False,
  1323. check_finite=True, lapack_driver=None):
  1324. """
  1325. Compute least-squares solution to the equation ``a @ x = b``.
  1326. Compute a vector x such that the 2-norm ``|b - A x|`` is minimized.
  1327. Parameters
  1328. ----------
  1329. a : (M, N) array_like
  1330. Left-hand side array
  1331. b : (M,) or (M, K) array_like
  1332. Right hand side array
  1333. cond : float, optional
  1334. Cutoff for 'small' singular values; used to determine effective
  1335. rank of a. Singular values smaller than
  1336. ``cond * largest_singular_value`` are considered zero.
  1337. overwrite_a : bool, optional
  1338. Discard data in `a` (may enhance performance). Default is False.
  1339. overwrite_b : bool, optional
  1340. Discard data in `b` (may enhance performance). Default is False.
  1341. check_finite : bool, optional
  1342. Whether to check that the input matrices contain only finite numbers.
  1343. Disabling may give a performance gain, but may result in problems
  1344. (crashes, non-termination) if the inputs do contain infinities or NaNs.
  1345. lapack_driver : str, optional
  1346. Which LAPACK driver is used to solve the least-squares problem.
  1347. Options are ``'gelsd'``, ``'gelsy'``, ``'gelss'``. Default
  1348. (``'gelsd'``) is a good choice. However, ``'gelsy'`` can be slightly
  1349. faster on many problems. ``'gelss'`` was used historically. It is
  1350. generally slow but uses less memory.
  1351. .. versionadded:: 0.17.0
  1352. Returns
  1353. -------
  1354. x : (N,) or (N, K) ndarray
  1355. Least-squares solution.
  1356. residues : (K,) ndarray or float
  1357. Square of the 2-norm for each column in ``b - a x``, if ``M > N`` and
  1358. ``rank(A) == n`` (returns a scalar if ``b`` is 1-D). Otherwise a
  1359. (0,)-shaped array is returned.
  1360. rank : int
  1361. Effective rank of `a`.
  1362. s : (min(M, N),) ndarray or None
  1363. Singular values of `a`. The condition number of ``a`` is
  1364. ``s[0] / s[-1]``.
  1365. Raises
  1366. ------
  1367. LinAlgError
  1368. If computation does not converge.
  1369. ValueError
  1370. When parameters are not compatible.
  1371. See Also
  1372. --------
  1373. scipy.optimize.nnls : linear least squares with non-negativity constraint
  1374. Notes
  1375. -----
  1376. When ``'gelsy'`` is used as a driver, `residues` is set to a (0,)-shaped
  1377. array and `s` is always ``None``.
  1378. Examples
  1379. --------
  1380. >>> import numpy as np
  1381. >>> from scipy.linalg import lstsq
  1382. >>> import matplotlib.pyplot as plt
  1383. Suppose we have the following data:
  1384. >>> x = np.array([1, 2.5, 3.5, 4, 5, 7, 8.5])
  1385. >>> y = np.array([0.3, 1.1, 1.5, 2.0, 3.2, 6.6, 8.6])
  1386. We want to fit a quadratic polynomial of the form ``y = a + b*x**2``
  1387. to this data. We first form the "design matrix" M, with a constant
  1388. column of 1s and a column containing ``x**2``:
  1389. >>> M = x[:, np.newaxis]**[0, 2]
  1390. >>> M
  1391. array([[ 1. , 1. ],
  1392. [ 1. , 6.25],
  1393. [ 1. , 12.25],
  1394. [ 1. , 16. ],
  1395. [ 1. , 25. ],
  1396. [ 1. , 49. ],
  1397. [ 1. , 72.25]])
  1398. We want to find the least-squares solution to ``M.dot(p) = y``,
  1399. where ``p`` is a vector with length 2 that holds the parameters
  1400. ``a`` and ``b``.
  1401. >>> p, res, rnk, s = lstsq(M, y)
  1402. >>> p
  1403. array([ 0.20925829, 0.12013861])
  1404. Plot the data and the fitted curve.
  1405. >>> plt.plot(x, y, 'o', label='data')
  1406. >>> xx = np.linspace(0, 9, 101)
  1407. >>> yy = p[0] + p[1]*xx**2
  1408. >>> plt.plot(xx, yy, label='least squares fit, $y = a + bx^2$')
  1409. >>> plt.xlabel('x')
  1410. >>> plt.ylabel('y')
  1411. >>> plt.legend(framealpha=1, shadow=True)
  1412. >>> plt.grid(alpha=0.25)
  1413. >>> plt.show()
  1414. """
  1415. a1 = _asarray_validated(a, check_finite=check_finite)
  1416. b1 = _asarray_validated(b, check_finite=check_finite)
  1417. if len(a1.shape) != 2:
  1418. raise ValueError('Input array a should be 2D')
  1419. m, n = a1.shape
  1420. if len(b1.shape) == 2:
  1421. nrhs = b1.shape[1]
  1422. else:
  1423. nrhs = 1
  1424. if m != b1.shape[0]:
  1425. raise ValueError('Shape mismatch: a and b should have the same number'
  1426. f' of rows ({m} != {b1.shape[0]}).')
  1427. if m == 0 or n == 0: # Zero-sized problem, confuses LAPACK
  1428. x = np.zeros((n,) + b1.shape[1:], dtype=np.common_type(a1, b1))
  1429. if n == 0:
  1430. residues = np.linalg.norm(b1, axis=0)**2
  1431. else:
  1432. residues = np.empty((0,))
  1433. return x, residues, 0, np.empty((0,))
  1434. driver = lapack_driver
  1435. if driver is None:
  1436. driver = lstsq.default_lapack_driver
  1437. if driver not in ('gelsd', 'gelsy', 'gelss'):
  1438. raise ValueError(f'LAPACK driver "{driver}" is not found')
  1439. lapack_func, lapack_lwork = get_lapack_funcs((driver,
  1440. f'{driver}_lwork'),
  1441. (a1, b1))
  1442. real_data = True if (lapack_func.dtype.kind == 'f') else False
  1443. if m < n:
  1444. # need to extend b matrix as it will be filled with
  1445. # a larger solution matrix
  1446. if len(b1.shape) == 2:
  1447. b2 = np.zeros((n, nrhs), dtype=lapack_func.dtype)
  1448. b2[:m, :] = b1
  1449. else:
  1450. b2 = np.zeros(n, dtype=lapack_func.dtype)
  1451. b2[:m] = b1
  1452. b1 = b2
  1453. overwrite_a = overwrite_a or _datacopied(a1, a)
  1454. overwrite_b = overwrite_b or _datacopied(b1, b)
  1455. if cond is None:
  1456. cond = np.finfo(lapack_func.dtype).eps
  1457. if driver in ('gelss', 'gelsd'):
  1458. if driver == 'gelss':
  1459. lwork = _compute_lwork(lapack_lwork, m, n, nrhs, cond)
  1460. v, x, s, rank, work, info = lapack_func(a1, b1, cond, lwork,
  1461. overwrite_a=overwrite_a,
  1462. overwrite_b=overwrite_b)
  1463. elif driver == 'gelsd':
  1464. if real_data:
  1465. lwork, iwork = _compute_lwork(lapack_lwork, m, n, nrhs, cond)
  1466. x, s, rank, info = lapack_func(a1, b1, lwork,
  1467. iwork, cond, False, False)
  1468. else: # complex data
  1469. lwork, rwork, iwork = _compute_lwork(lapack_lwork, m, n,
  1470. nrhs, cond)
  1471. x, s, rank, info = lapack_func(a1, b1, lwork, rwork, iwork,
  1472. cond, False, False)
  1473. if info > 0:
  1474. raise LinAlgError("SVD did not converge in Linear Least Squares")
  1475. if info < 0:
  1476. raise ValueError(
  1477. f'illegal value in {-info}-th argument of internal {lapack_driver}'
  1478. )
  1479. resids = np.asarray([], dtype=x.dtype)
  1480. if m > n:
  1481. x1 = x[:n]
  1482. if rank == n:
  1483. resids = np.sum(np.abs(x[n:])**2, axis=0)
  1484. x = x1
  1485. return x, resids, rank, s
  1486. elif driver == 'gelsy':
  1487. lwork = _compute_lwork(lapack_lwork, m, n, nrhs, cond)
  1488. jptv = np.zeros((a1.shape[1], 1), dtype=np.int32)
  1489. v, x, j, rank, info = lapack_func(a1, b1, jptv, cond,
  1490. lwork, False, False)
  1491. if info < 0:
  1492. raise ValueError(f'illegal value in {-info}-th argument of internal gelsy')
  1493. if m > n:
  1494. x1 = x[:n]
  1495. x = x1
  1496. return x, np.array([], x.dtype), rank, None
  1497. lstsq.default_lapack_driver = 'gelsd'
  1498. @_apply_over_batch(('a', 2))
  1499. def pinv(a, *, atol=None, rtol=None, return_rank=False, check_finite=True):
  1500. """
  1501. Compute the (Moore-Penrose) pseudo-inverse of a matrix.
  1502. Calculate a generalized inverse of a matrix using its
  1503. singular-value decomposition ``U @ S @ V`` in the economy mode and picking
  1504. up only the columns/rows that are associated with significant singular
  1505. values.
  1506. If ``s`` is the maximum singular value of ``a``, then the
  1507. significance cut-off value is determined by ``atol + rtol * s``. Any
  1508. singular value below this value is assumed insignificant.
  1509. Parameters
  1510. ----------
  1511. a : (M, N) array_like
  1512. Matrix to be pseudo-inverted.
  1513. atol : float, optional
  1514. Absolute threshold term, default value is 0.
  1515. .. versionadded:: 1.7.0
  1516. rtol : float, optional
  1517. Relative threshold term, default value is ``max(M, N) * eps`` where
  1518. ``eps`` is the machine precision value of the datatype of ``a``.
  1519. .. versionadded:: 1.7.0
  1520. return_rank : bool, optional
  1521. If True, return the effective rank of the matrix.
  1522. check_finite : bool, optional
  1523. Whether to check that the input matrix contains only finite numbers.
  1524. Disabling may give a performance gain, but may result in problems
  1525. (crashes, non-termination) if the inputs do contain infinities or NaNs.
  1526. Returns
  1527. -------
  1528. B : (N, M) ndarray
  1529. The pseudo-inverse of matrix `a`.
  1530. rank : int
  1531. The effective rank of the matrix. Returned if `return_rank` is True.
  1532. Raises
  1533. ------
  1534. LinAlgError
  1535. If SVD computation does not converge.
  1536. See Also
  1537. --------
  1538. pinvh : Moore-Penrose pseudoinverse of a hermitian matrix.
  1539. Notes
  1540. -----
  1541. If ``A`` is invertible then the Moore-Penrose pseudoinverse is exactly
  1542. the inverse of ``A`` [1]_. If ``A`` is not invertible then the
  1543. Moore-Penrose pseudoinverse computes the ``x`` solution to ``Ax = b`` such
  1544. that ``||Ax - b||`` is minimized [1]_.
  1545. References
  1546. ----------
  1547. .. [1] Penrose, R. (1956). On best approximate solutions of linear matrix
  1548. equations. Mathematical Proceedings of the Cambridge Philosophical
  1549. Society, 52(1), 17-19. doi:10.1017/S0305004100030929
  1550. Examples
  1551. --------
  1552. Given an ``m x n`` matrix ``A`` and an ``n x m`` matrix ``B`` the four
  1553. Moore-Penrose conditions are:
  1554. 1. ``ABA = A`` (``B`` is a generalized inverse of ``A``),
  1555. 2. ``BAB = B`` (``A`` is a generalized inverse of ``B``),
  1556. 3. ``(AB)* = AB`` (``AB`` is hermitian),
  1557. 4. ``(BA)* = BA`` (``BA`` is hermitian) [1]_.
  1558. Here, ``A*`` denotes the conjugate transpose. The Moore-Penrose
  1559. pseudoinverse is a unique ``B`` that satisfies all four of these
  1560. conditions and exists for any ``A``. Note that, unlike the standard
  1561. matrix inverse, ``A`` does not have to be a square matrix or have
  1562. linearly independent columns/rows.
  1563. As an example, we can calculate the Moore-Penrose pseudoinverse of a
  1564. random non-square matrix and verify it satisfies the four conditions.
  1565. >>> import numpy as np
  1566. >>> from scipy import linalg
  1567. >>> rng = np.random.default_rng()
  1568. >>> A = rng.standard_normal((9, 6))
  1569. >>> B = linalg.pinv(A)
  1570. >>> np.allclose(A @ B @ A, A) # Condition 1
  1571. True
  1572. >>> np.allclose(B @ A @ B, B) # Condition 2
  1573. True
  1574. >>> np.allclose((A @ B).conj().T, A @ B) # Condition 3
  1575. True
  1576. >>> np.allclose((B @ A).conj().T, B @ A) # Condition 4
  1577. True
  1578. """
  1579. a = _asarray_validated(a, check_finite=check_finite)
  1580. u, s, vh = _decomp_svd.svd(a, full_matrices=False, check_finite=False)
  1581. t = u.dtype.char.lower()
  1582. maxS = np.max(s, initial=0.)
  1583. atol = 0. if atol is None else atol
  1584. rtol = max(a.shape) * np.finfo(t).eps if (rtol is None) else rtol
  1585. if (atol < 0.) or (rtol < 0.):
  1586. raise ValueError("atol and rtol values must be positive.")
  1587. val = atol + maxS * rtol
  1588. rank = np.sum(s > val)
  1589. u = u[:, :rank]
  1590. u /= s[:rank]
  1591. B = (u @ vh[:rank]).conj().T
  1592. if return_rank:
  1593. return B, rank
  1594. else:
  1595. return B
  1596. @_apply_over_batch(('a', 2))
  1597. def pinvh(a, atol=None, rtol=None, lower=True, return_rank=False,
  1598. check_finite=True):
  1599. """
  1600. Compute the (Moore-Penrose) pseudo-inverse of a Hermitian matrix.
  1601. Calculate a generalized inverse of a complex Hermitian/real symmetric
  1602. matrix using its eigenvalue decomposition and including all eigenvalues
  1603. with 'large' absolute value.
  1604. Parameters
  1605. ----------
  1606. a : (N, N) array_like
  1607. Real symmetric or complex hermetian matrix to be pseudo-inverted
  1608. atol : float, optional
  1609. Absolute threshold term, default value is 0.
  1610. .. versionadded:: 1.7.0
  1611. rtol : float, optional
  1612. Relative threshold term, default value is ``N * eps`` where
  1613. ``eps`` is the machine precision value of the datatype of ``a``.
  1614. .. versionadded:: 1.7.0
  1615. lower : bool, optional
  1616. Whether the pertinent array data is taken from the lower or upper
  1617. triangle of `a`. (Default: lower)
  1618. return_rank : bool, optional
  1619. If True, return the effective rank of the matrix.
  1620. check_finite : bool, optional
  1621. Whether to check that the input matrix contains only finite numbers.
  1622. Disabling may give a performance gain, but may result in problems
  1623. (crashes, non-termination) if the inputs do contain infinities or NaNs.
  1624. Returns
  1625. -------
  1626. B : (N, N) ndarray
  1627. The pseudo-inverse of matrix `a`.
  1628. rank : int
  1629. The effective rank of the matrix. Returned if `return_rank` is True.
  1630. Raises
  1631. ------
  1632. LinAlgError
  1633. If eigenvalue algorithm does not converge.
  1634. See Also
  1635. --------
  1636. pinv : Moore-Penrose pseudoinverse of a matrix.
  1637. Examples
  1638. --------
  1639. For a more detailed example see `pinv`.
  1640. >>> import numpy as np
  1641. >>> from scipy.linalg import pinvh
  1642. >>> rng = np.random.default_rng()
  1643. >>> a = rng.standard_normal((9, 6))
  1644. >>> a = np.dot(a, a.T)
  1645. >>> B = pinvh(a)
  1646. >>> np.allclose(a, a @ B @ a)
  1647. True
  1648. >>> np.allclose(B, B @ a @ B)
  1649. True
  1650. """
  1651. a = _asarray_validated(a, check_finite=check_finite)
  1652. s, u = _decomp.eigh(a, lower=lower, check_finite=False, driver='ev')
  1653. t = u.dtype.char.lower()
  1654. maxS = np.max(np.abs(s), initial=0.)
  1655. atol = 0. if atol is None else atol
  1656. rtol = max(a.shape) * np.finfo(t).eps if (rtol is None) else rtol
  1657. if (atol < 0.) or (rtol < 0.):
  1658. raise ValueError("atol and rtol values must be positive.")
  1659. val = atol + maxS * rtol
  1660. above_cutoff = (abs(s) > val)
  1661. psigma_diag = 1.0 / s[above_cutoff]
  1662. u = u[:, above_cutoff]
  1663. B = (u * psigma_diag) @ u.conj().T
  1664. if return_rank:
  1665. return B, len(psigma_diag)
  1666. else:
  1667. return B
  1668. @_apply_over_batch(('A', 2))
  1669. def matrix_balance(A, permute=True, scale=True, separate=False,
  1670. overwrite_a=False):
  1671. """
  1672. Compute a diagonal similarity transformation for row/column balancing.
  1673. The balancing tries to equalize the row and column 1-norms by applying
  1674. a similarity transformation such that the magnitude variation of the
  1675. matrix entries is reflected to the scaling matrices.
  1676. Moreover, if enabled, the matrix is first permuted to isolate the upper
  1677. triangular parts of the matrix and, again if scaling is also enabled,
  1678. only the remaining subblocks are subjected to scaling.
  1679. Parameters
  1680. ----------
  1681. A : (n, n) array_like
  1682. Square data matrix for the balancing.
  1683. permute : bool, optional
  1684. The selector to define whether permutation of A is also performed
  1685. prior to scaling.
  1686. scale : bool, optional
  1687. The selector to turn on and off the scaling. If False, the matrix
  1688. will not be scaled.
  1689. separate : bool, optional
  1690. This switches from returning a full matrix of the transformation
  1691. to a tuple of two separate 1-D permutation and scaling arrays.
  1692. overwrite_a : bool, optional
  1693. This is passed to xGEBAL directly. Essentially, overwrites the result
  1694. to the data. It might increase the space efficiency. See LAPACK manual
  1695. for details. This is False by default.
  1696. Returns
  1697. -------
  1698. B : (n, n) ndarray
  1699. Balanced matrix
  1700. T : (n, n) ndarray
  1701. A possibly permuted diagonal matrix whose nonzero entries are
  1702. integer powers of 2 to avoid numerical truncation errors.
  1703. scale, perm : (n,) ndarray
  1704. If ``separate`` keyword is set to True then instead of the array
  1705. ``T`` above, the scaling and the permutation vectors are given
  1706. separately as a tuple without allocating the full array ``T``.
  1707. Notes
  1708. -----
  1709. The balanced matrix satisfies the following equality
  1710. .. math::
  1711. B = T^{-1} A T
  1712. The scaling coefficients are approximated to the nearest power of 2
  1713. to avoid round-off errors.
  1714. This algorithm is particularly useful for eigenvalue and matrix
  1715. decompositions and in many cases it is already called by various
  1716. LAPACK routines.
  1717. The algorithm is based on the well-known technique of [1]_ and has
  1718. been modified to account for special cases. See [2]_ for details
  1719. which have been implemented since LAPACK v3.5.0. Before this version
  1720. there are corner cases where balancing can actually worsen the
  1721. conditioning. See [3]_ for such examples.
  1722. The code is a wrapper around LAPACK's xGEBAL routine family for matrix
  1723. balancing.
  1724. .. versionadded:: 0.19.0
  1725. References
  1726. ----------
  1727. .. [1] B.N. Parlett and C. Reinsch, "Balancing a Matrix for
  1728. Calculation of Eigenvalues and Eigenvectors", Numerische Mathematik,
  1729. Vol.13(4), 1969, :doi:`10.1007/BF02165404`
  1730. .. [2] R. James, J. Langou, B.R. Lowery, "On matrix balancing and
  1731. eigenvector computation", 2014, :arxiv:`1401.5766`
  1732. .. [3] D.S. Watkins. A case where balancing is harmful.
  1733. Electron. Trans. Numer. Anal, Vol.23, 2006.
  1734. Examples
  1735. --------
  1736. >>> import numpy as np
  1737. >>> from scipy import linalg
  1738. >>> x = np.array([[1,2,0], [9,1,0.01], [1,2,10*np.pi]])
  1739. >>> y, permscale = linalg.matrix_balance(x)
  1740. >>> np.abs(x).sum(axis=0) / np.abs(x).sum(axis=1)
  1741. array([ 3.66666667, 0.4995005 , 0.91312162])
  1742. >>> np.abs(y).sum(axis=0) / np.abs(y).sum(axis=1)
  1743. array([ 1.2 , 1.27041742, 0.92658316]) # may vary
  1744. >>> permscale # only powers of 2 (0.5 == 2^(-1))
  1745. array([[ 0.5, 0. , 0. ], # may vary
  1746. [ 0. , 1. , 0. ],
  1747. [ 0. , 0. , 1. ]])
  1748. """
  1749. A = np.atleast_2d(_asarray_validated(A, check_finite=True))
  1750. if not np.equal(*A.shape):
  1751. raise ValueError('The data matrix for balancing should be square.')
  1752. # accommodate empty arrays
  1753. if A.size == 0:
  1754. b_n, t_n = matrix_balance(np.eye(2, dtype=A.dtype))
  1755. B = np.empty_like(A, dtype=b_n.dtype)
  1756. if separate:
  1757. scaling = np.ones_like(A, shape=len(A))
  1758. perm = np.arange(len(A))
  1759. return B, (scaling, perm)
  1760. return B, np.empty_like(A, dtype=t_n.dtype)
  1761. gebal = get_lapack_funcs(('gebal'), (A,))
  1762. B, lo, hi, ps, info = gebal(A, scale=scale, permute=permute,
  1763. overwrite_a=overwrite_a)
  1764. if info < 0:
  1765. raise ValueError('xGEBAL exited with the internal error '
  1766. f'"illegal value in argument number {-info}.". See '
  1767. 'LAPACK documentation for the xGEBAL error codes.')
  1768. # Separate the permutations from the scalings and then convert to int
  1769. scaling = np.ones_like(ps, dtype=float)
  1770. scaling[lo:hi+1] = ps[lo:hi+1]
  1771. # gebal uses 1-indexing
  1772. ps = ps.astype(int, copy=False) - 1
  1773. n = A.shape[0]
  1774. perm = np.arange(n)
  1775. # LAPACK permutes with the ordering n --> hi, then 0--> lo
  1776. if hi < n:
  1777. for ind, x in enumerate(ps[hi+1:][::-1], 1):
  1778. if n-ind == x:
  1779. continue
  1780. perm[[x, n-ind]] = perm[[n-ind, x]]
  1781. if lo > 0:
  1782. for ind, x in enumerate(ps[:lo]):
  1783. if ind == x:
  1784. continue
  1785. perm[[x, ind]] = perm[[ind, x]]
  1786. if separate:
  1787. return B, (scaling, perm)
  1788. # get the inverse permutation
  1789. iperm = np.empty_like(perm)
  1790. iperm[perm] = np.arange(n)
  1791. return B, np.diag(scaling)[iperm, :]
  1792. def _validate_args_for_toeplitz_ops(c_or_cr, b, check_finite, keep_b_shape,
  1793. enforce_square=True):
  1794. """Validate arguments and format inputs for toeplitz functions
  1795. Parameters
  1796. ----------
  1797. c_or_cr : array_like or tuple of (array_like, array_like)
  1798. The vector ``c``, or a tuple of arrays (``c``, ``r``). Whatever the
  1799. actual shape of ``c``, it will be converted to a 1-D array. If not
  1800. supplied, ``r = conjugate(c)`` is assumed; in this case, if c[0] is
  1801. real, the Toeplitz matrix is Hermitian. r[0] is ignored; the first row
  1802. of the Toeplitz matrix is ``[c[0], r[1:]]``. Whatever the actual shape
  1803. of ``r``, it will be converted to a 1-D array.
  1804. b : (M,) or (M, K) array_like
  1805. Right-hand side in ``T x = b``.
  1806. check_finite : bool
  1807. Whether to check that the input matrices contain only finite numbers.
  1808. Disabling may give a performance gain, but may result in problems
  1809. (result entirely NaNs) if the inputs do contain infinities or NaNs.
  1810. keep_b_shape : bool
  1811. Whether to convert a (M,) dimensional b into a (M, 1) dimensional
  1812. matrix.
  1813. enforce_square : bool, optional
  1814. If True (default), this verifies that the Toeplitz matrix is square.
  1815. Returns
  1816. -------
  1817. r : array
  1818. 1d array corresponding to the first row of the Toeplitz matrix.
  1819. c: array
  1820. 1d array corresponding to the first column of the Toeplitz matrix.
  1821. b: array
  1822. (M,), (M, 1) or (M, K) dimensional array, post validation,
  1823. corresponding to ``b``.
  1824. dtype: numpy datatype
  1825. ``dtype`` stores the datatype of ``r``, ``c`` and ``b``. If any of
  1826. ``r``, ``c`` or ``b`` are complex, ``dtype`` is ``np.complex128``,
  1827. otherwise, it is ``np.float``.
  1828. b_shape: tuple
  1829. Shape of ``b`` after passing it through ``_asarray_validated``.
  1830. """
  1831. if isinstance(c_or_cr, tuple):
  1832. c, r = c_or_cr
  1833. c = _asarray_validated(c, check_finite=check_finite)
  1834. r = _asarray_validated(r, check_finite=check_finite)
  1835. else:
  1836. c = _asarray_validated(c_or_cr, check_finite=check_finite)
  1837. r = c.conjugate()
  1838. if b is None:
  1839. raise ValueError('`b` must be an array, not None.')
  1840. b = _asarray_validated(b, check_finite=check_finite)
  1841. b_shape = b.shape
  1842. is_not_square = r.shape[0] != c.shape[0]
  1843. if (enforce_square and is_not_square) or b.shape[0] != r.shape[0]:
  1844. raise ValueError('Incompatible dimensions.')
  1845. is_cmplx = np.iscomplexobj(r) or np.iscomplexobj(c) or np.iscomplexobj(b)
  1846. dtype = np.complex128 if is_cmplx else np.float64
  1847. r, c, b = (np.asarray(i, dtype=dtype) for i in (r, c, b))
  1848. if b.ndim == 1 and not keep_b_shape:
  1849. b = b.reshape(-1, 1)
  1850. elif b.ndim != 1:
  1851. b = b.reshape(b.shape[0], -1 if b.size > 0 else 0)
  1852. return r, c, b, dtype, b_shape
  1853. def matmul_toeplitz(c_or_cr, x, check_finite=False, workers=None):
  1854. r"""Efficient Toeplitz Matrix-Matrix Multiplication using FFT
  1855. This function returns the matrix multiplication between a Toeplitz
  1856. matrix and a dense matrix.
  1857. The Toeplitz matrix has constant diagonals, with c as its first column
  1858. and r as its first row. If r is not given, ``r == conjugate(c)`` is
  1859. assumed.
  1860. The documentation is written assuming array arguments are of specified
  1861. "core" shapes. However, array argument(s) of this function may have additional
  1862. "batch" dimensions prepended to the core shape. In this case, the array is treated
  1863. as a batch of lower-dimensional slices; see :ref:`linalg_batch` for details.
  1864. Parameters
  1865. ----------
  1866. c_or_cr : array_like or tuple of (array_like, array_like)
  1867. The vector ``c``, or a tuple of arrays (``c``, ``r``). If not
  1868. supplied, ``r = conjugate(c)`` is assumed; in this case, if c[0] is
  1869. real, the Toeplitz matrix is Hermitian. r[0] is ignored; the first row
  1870. of the Toeplitz matrix is ``[c[0], r[1:]]``.
  1871. x : (M,) or (M, K) array_like
  1872. Matrix with which to multiply.
  1873. check_finite : bool, optional
  1874. Whether to check that the input matrices contain only finite numbers.
  1875. Disabling may give a performance gain, but may result in problems
  1876. (result entirely NaNs) if the inputs do contain infinities or NaNs.
  1877. workers : int, optional
  1878. To pass to scipy.fft.fft and ifft. Maximum number of workers to use
  1879. for parallel computation. If negative, the value wraps around from
  1880. ``os.cpu_count()``. See scipy.fft.fft for more details.
  1881. Returns
  1882. -------
  1883. T @ x : (M,) or (M, K) ndarray
  1884. The result of the matrix multiplication ``T @ x``. Shape of return
  1885. matches shape of `x`.
  1886. See Also
  1887. --------
  1888. toeplitz : Toeplitz matrix
  1889. solve_toeplitz : Solve a Toeplitz system using Levinson Recursion
  1890. Notes
  1891. -----
  1892. The Toeplitz matrix is embedded in a circulant matrix and the FFT is used
  1893. to efficiently calculate the matrix-matrix product.
  1894. Because the computation is based on the FFT, integer inputs will
  1895. result in floating point outputs. This is unlike NumPy's `matmul`,
  1896. which preserves the data type of the input.
  1897. This is partly based on the implementation that can be found in [1]_,
  1898. licensed under the MIT license. More information about the method can be
  1899. found in reference [2]_. References [3]_ and [4]_ have more reference
  1900. implementations in Python.
  1901. .. versionadded:: 1.6.0
  1902. References
  1903. ----------
  1904. .. [1] Jacob R Gardner, Geoff Pleiss, David Bindel, Kilian
  1905. Q Weinberger, Andrew Gordon Wilson, "GPyTorch: Blackbox Matrix-Matrix
  1906. Gaussian Process Inference with GPU Acceleration" with contributions
  1907. from Max Balandat and Ruihan Wu. Available online:
  1908. https://github.com/cornellius-gp/gpytorch
  1909. .. [2] J. Demmel, P. Koev, and X. Li, "A Brief Survey of Direct Linear
  1910. Solvers". In Z. Bai, J. Demmel, J. Dongarra, A. Ruhe, and H. van der
  1911. Vorst, editors. Templates for the Solution of Algebraic Eigenvalue
  1912. Problems: A Practical Guide. SIAM, Philadelphia, 2000. Available at:
  1913. http://www.netlib.org/utk/people/JackDongarra/etemplates/node384.html
  1914. .. [3] R. Scheibler, E. Bezzam, I. Dokmanic, Pyroomacoustics: A Python
  1915. package for audio room simulations and array processing algorithms,
  1916. Proc. IEEE ICASSP, Calgary, CA, 2018.
  1917. https://github.com/LCAV/pyroomacoustics/blob/pypi-release/
  1918. pyroomacoustics/adaptive/util.py
  1919. .. [4] Marano S, Edwards B, Ferrari G and Fah D (2017), "Fitting
  1920. Earthquake Spectra: Colored Noise and Incomplete Data", Bulletin of
  1921. the Seismological Society of America., January, 2017. Vol. 107(1),
  1922. pp. 276-291.
  1923. Examples
  1924. --------
  1925. Multiply the Toeplitz matrix T with matrix x::
  1926. [ 1 -1 -2 -3] [1 10]
  1927. T = [ 3 1 -1 -2] x = [2 11]
  1928. [ 6 3 1 -1] [2 11]
  1929. [10 6 3 1] [5 19]
  1930. To specify the Toeplitz matrix, only the first column and the first
  1931. row are needed.
  1932. >>> import numpy as np
  1933. >>> c = np.array([1, 3, 6, 10]) # First column of T
  1934. >>> r = np.array([1, -1, -2, -3]) # First row of T
  1935. >>> x = np.array([[1, 10], [2, 11], [2, 11], [5, 19]])
  1936. >>> from scipy.linalg import toeplitz, matmul_toeplitz
  1937. >>> matmul_toeplitz((c, r), x)
  1938. array([[-20., -80.],
  1939. [ -7., -8.],
  1940. [ 9., 85.],
  1941. [ 33., 218.]])
  1942. Check the result by creating the full Toeplitz matrix and
  1943. multiplying it by ``x``.
  1944. >>> toeplitz(c, r) @ x
  1945. array([[-20, -80],
  1946. [ -7, -8],
  1947. [ 9, 85],
  1948. [ 33, 218]])
  1949. The full matrix is never formed explicitly, so this routine
  1950. is suitable for very large Toeplitz matrices.
  1951. >>> n = 1000000
  1952. >>> matmul_toeplitz([1] + [0]*(n-1), np.ones(n))
  1953. array([1., 1., 1., ..., 1., 1., 1.], shape=(1000000,))
  1954. """
  1955. from ..fft import fft, ifft, rfft, irfft
  1956. c, r = c_or_cr if isinstance(c_or_cr, tuple) else (c_or_cr, np.conjugate(c_or_cr))
  1957. return _matmul_toepltiz(r, c, x, workers, check_finite, fft, ifft, rfft, irfft)
  1958. @_apply_over_batch(('r', 1), ('c', 1), ('x', '1|2'))
  1959. def _matmul_toepltiz(r, c, x, workers, check_finite, fft, ifft, rfft, irfft):
  1960. r, c, x, dtype, x_shape = _validate_args_for_toeplitz_ops((c, r), x, check_finite,
  1961. keep_b_shape=False,
  1962. enforce_square=False)
  1963. n, m = x.shape
  1964. T_nrows = len(c)
  1965. T_ncols = len(r)
  1966. p = T_nrows + T_ncols - 1 # equivalent to len(embedded_col)
  1967. return_shape = (T_nrows,) if len(x_shape) == 1 else (T_nrows, m)
  1968. # accommodate empty arrays
  1969. if x.size == 0:
  1970. return np.empty_like(x, shape=return_shape)
  1971. embedded_col = np.concatenate((c, r[-1:0:-1]))
  1972. if np.iscomplexobj(embedded_col) or np.iscomplexobj(x):
  1973. fft_mat = fft(embedded_col, axis=0, workers=workers).reshape(-1, 1)
  1974. fft_x = fft(x, n=p, axis=0, workers=workers)
  1975. mat_times_x = ifft(fft_mat*fft_x, axis=0,
  1976. workers=workers)[:T_nrows, :]
  1977. else:
  1978. # Real inputs; using rfft is faster
  1979. fft_mat = rfft(embedded_col, axis=0, workers=workers).reshape(-1, 1)
  1980. fft_x = rfft(x, n=p, axis=0, workers=workers)
  1981. mat_times_x = irfft(fft_mat*fft_x, axis=0,
  1982. workers=workers, n=p)[:T_nrows, :]
  1983. return mat_times_x.reshape(*return_shape)