test_basic.py 93 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156115711581159116011611162116311641165116611671168116911701171117211731174117511761177117811791180118111821183118411851186118711881189119011911192119311941195119611971198119912001201120212031204120512061207120812091210121112121213121412151216121712181219122012211222122312241225122612271228122912301231123212331234123512361237123812391240124112421243124412451246124712481249125012511252125312541255125612571258125912601261126212631264126512661267126812691270127112721273127412751276127
71278127912801281128212831284128512861287128812891290129112921293129412951296129712981299130013011302130313041305130613071308130913101311131213131314131513161317131813191320132113221323132413251326132713281329133013311332133313341335133613371338133913401341134213431344134513461347134813491350135113521353135413551356135713581359136013611362136313641365136613671368136913701371137213731374137513761377137813791380138113821383138413851386138713881389139013911392139313941395139613971398139914001401140214031404140514061407140814091410141114121413141414151416141714181419142014211422142314241425142614271428142914301431143214331434143514361437143814391440144114421443144414451446144714481449145014511452145314541455145614571458145914601461146214631464146514661467146814691470147114721473147414751476147714781479148014811482148314841485148614871488148914901491149214931494149514961497149814991500150115021503150415051506150715081509151015111512151315141515151615171518151915201521152215231524152515261527152815291530153115321533153415351536153715381539154015411542154315441545154615471548154915501551155215531554155515561557155815591560156115621563156415651566156715681569157015711572157315741575157615771578157915801581158215831584158515861587158815891590159115921593159415951596159715981599160016011602160316041605160616071608160916101611161216131614161516161617161816191620162116221623162416251626162716281629163016311632163316341635163616371638163916401641164216431644164516461647164816491650165116521653165416551656165716581659166016611662166316641665166616671668166916701671167216731674167516761677167816791680168116821683168416851686168716881689169016911692169316941695169616971698169917001701170217031704170517061707170817091710171117121713171417151716171717181719172017211722172317241725172617271728172917301731173217331734173517361737173817391740174117421743174417451746174717481749175017511752175317541755175617571758175917601761176217631764176517661767176817691770177117721773177417751776177
71778177917801781178217831784178517861787178817891790179117921793179417951796179717981799180018011802180318041805180618071808180918101811181218131814181518161817181818191820182118221823182418251826182718281829183018311832183318341835183618371838183918401841184218431844184518461847184818491850185118521853185418551856185718581859186018611862186318641865186618671868186918701871187218731874187518761877187818791880188118821883188418851886188718881889189018911892189318941895189618971898189919001901190219031904190519061907190819091910191119121913191419151916191719181919192019211922192319241925192619271928192919301931193219331934193519361937193819391940194119421943194419451946194719481949195019511952195319541955195619571958195919601961196219631964196519661967196819691970197119721973197419751976197719781979198019811982198319841985198619871988198919901991199219931994199519961997199819992000200120022003200420052006200720082009201020112012201320142015201620172018201920202021202220232024202520262027202820292030203120322033203420352036203720382039204020412042204320442045204620472048204920502051205220532054205520562057205820592060206120622063206420652066206720682069207020712072207320742075207620772078207920802081208220832084208520862087208820892090209120922093209420952096209720982099210021012102210321042105210621072108210921102111211221132114211521162117211821192120212121222123212421252126212721282129213021312132213321342135213621372138213921402141214221432144214521462147214821492150215121522153215421552156215721582159216021612162216321642165216621672168216921702171217221732174217521762177217821792180218121822183218421852186218721882189219021912192219321942195219621972198219922002201220222032204220522062207220822092210221122122213221422152216221722182219222022212222222322242225222622272228222922302231223222332234223522362237223822392240224122422243224422452246224722482249225022512252225322542255225622572258225922602261226222632264226522662267226822692270227122722273227422752276227
722782279228022812282228322842285228622872288228922902291229222932294229522962297229822992300230123022303230423052306230723082309231023112312231323142315231623172318231923202321232223232324232523262327232823292330233123322333233423352336233723382339234023412342234323442345234623472348234923502351235223532354235523562357235823592360236123622363236423652366236723682369237023712372237323742375237623772378237923802381238223832384238523862387238823892390239123922393239423952396239723982399240024012402240324042405240624072408240924102411241224132414241524162417241824192420242124222423242424252426242724282429243024312432243324342435243624372438243924402441244224432444244524462447244824492450245124522453245424552456245724582459246024612462246324642465246624672468246924702471247224732474247524762477247824792480248124822483248424852486
  1. import itertools
  2. import warnings
  3. import numpy as np
  4. from numpy import (arange, array, dot, zeros, identity, conjugate, transpose,
  5. float32)
  6. from numpy.testing import (assert_equal, assert_almost_equal, assert_,
  7. assert_array_almost_equal, assert_allclose,
  8. assert_array_equal)
  9. import pytest
  10. from pytest import raises as assert_raises
  11. from scipy.linalg import (solve, inv, det, lstsq, pinv, pinvh, norm,
  12. solve_banded, solveh_banded, solve_triangular,
  13. solve_circulant, circulant, LinAlgError, block_diag,
  14. matrix_balance, qr, LinAlgWarning)
  15. from scipy.linalg._testutils import assert_no_overwrite
  16. from scipy._lib._testutils import check_free_memory, IS_MUSL
  17. from scipy.linalg.blas import HAS_ILP64
  18. from scipy.conftest import skip_xp_invalid_arg
  # Floating-point dtypes used by the parametrized tests: the real ones first,
  # then their complex counterparts.
  REAL_DTYPES = (np.float32, np.float64, np.longdouble)
  COMPLEX_DTYPES = (np.complex64, np.complex128, np.clongdouble)
  DTYPES = (*REAL_DTYPES, *COMPLEX_DTYPES)

  # Runs a test three times: with overwrite_a=True, with overwrite_a=False and
  # with the keyword left out altogether.
  parametrize_overwrite_arg = pytest.mark.parametrize(
      "overwrite_kw",
      [{"overwrite_a": flag} for flag in (True, False)] + [{}],
  )
  25. def _eps_cast(dtyp):
  26. """Get the epsilon for dtype, possibly downcast to BLAS types."""
  27. dt = dtyp
  28. if dt == np.longdouble:
  29. dt = np.float64
  30. elif dt == np.clongdouble:
  31. dt = np.complex128
  32. return np.finfo(dt).eps
  33. class TestSolveBanded:
  34. def test_real(self):
  35. a = array([[1.0, 20, 0, 0],
  36. [-30, 4, 6, 0],
  37. [2, 1, 20, 2],
  38. [0, -1, 7, 14]])
  39. ab = array([[0.0, 20, 6, 2],
  40. [1, 4, 20, 14],
  41. [-30, 1, 7, 0],
  42. [2, -1, 0, 0]])
  43. l, u = 2, 1
  44. b4 = array([10.0, 0.0, 2.0, 14.0])
  45. b4by1 = b4.reshape(-1, 1)
  46. b4by2 = array([[2, 1],
  47. [-30, 4],
  48. [2, 3],
  49. [1, 3]])
  50. b4by4 = array([[1, 0, 0, 0],
  51. [0, 0, 0, 1],
  52. [0, 1, 0, 0],
  53. [0, 1, 0, 0]])
  54. for b in [b4, b4by1, b4by2, b4by4]:
  55. x = solve_banded((l, u), ab, b)
  56. assert_array_almost_equal(dot(a, x), b)
  57. def test_complex(self):
  58. a = array([[1.0, 20, 0, 0],
  59. [-30, 4, 6, 0],
  60. [2j, 1, 20, 2j],
  61. [0, -1, 7, 14]])
  62. ab = array([[0.0, 20, 6, 2j],
  63. [1, 4, 20, 14],
  64. [-30, 1, 7, 0],
  65. [2j, -1, 0, 0]])
  66. l, u = 2, 1
  67. b4 = array([10.0, 0.0, 2.0, 14.0j])
  68. b4by1 = b4.reshape(-1, 1)
  69. b4by2 = array([[2, 1],
  70. [-30, 4],
  71. [2, 3],
  72. [1, 3]])
  73. b4by4 = array([[1, 0, 0, 0],
  74. [0, 0, 0, 1j],
  75. [0, 1, 0, 0],
  76. [0, 1, 0, 0]])
  77. for b in [b4, b4by1, b4by2, b4by4]:
  78. x = solve_banded((l, u), ab, b)
  79. assert_array_almost_equal(dot(a, x), b)
  80. def test_tridiag_real(self):
  81. ab = array([[0.0, 20, 6, 2],
  82. [1, 4, 20, 14],
  83. [-30, 1, 7, 0]])
  84. a = np.diag(ab[0, 1:], 1) + np.diag(ab[1, :], 0) + np.diag(
  85. ab[2, :-1], -1)
  86. b4 = array([10.0, 0.0, 2.0, 14.0])
  87. b4by1 = b4.reshape(-1, 1)
  88. b4by2 = array([[2, 1],
  89. [-30, 4],
  90. [2, 3],
  91. [1, 3]])
  92. b4by4 = array([[1, 0, 0, 0],
  93. [0, 0, 0, 1],
  94. [0, 1, 0, 0],
  95. [0, 1, 0, 0]])
  96. for b in [b4, b4by1, b4by2, b4by4]:
  97. x = solve_banded((1, 1), ab, b)
  98. assert_array_almost_equal(dot(a, x), b)
  99. def test_tridiag_complex(self):
  100. ab = array([[0.0, 20, 6, 2j],
  101. [1, 4, 20, 14],
  102. [-30, 1, 7, 0]])
  103. a = np.diag(ab[0, 1:], 1) + np.diag(ab[1, :], 0) + np.diag(
  104. ab[2, :-1], -1)
  105. b4 = array([10.0, 0.0, 2.0, 14.0j])
  106. b4by1 = b4.reshape(-1, 1)
  107. b4by2 = array([[2, 1],
  108. [-30, 4],
  109. [2, 3],
  110. [1, 3]])
  111. b4by4 = array([[1, 0, 0, 0],
  112. [0, 0, 0, 1],
  113. [0, 1, 0, 0],
  114. [0, 1, 0, 0]])
  115. for b in [b4, b4by1, b4by2, b4by4]:
  116. x = solve_banded((1, 1), ab, b)
  117. assert_array_almost_equal(dot(a, x), b)
  118. def test_check_finite(self):
  119. a = array([[1.0, 20, 0, 0],
  120. [-30, 4, 6, 0],
  121. [2, 1, 20, 2],
  122. [0, -1, 7, 14]])
  123. ab = array([[0.0, 20, 6, 2],
  124. [1, 4, 20, 14],
  125. [-30, 1, 7, 0],
  126. [2, -1, 0, 0]])
  127. l, u = 2, 1
  128. b4 = array([10.0, 0.0, 2.0, 14.0])
  129. x = solve_banded((l, u), ab, b4, check_finite=False)
  130. assert_array_almost_equal(dot(a, x), b4)
  131. def test_bad_shape(self):
  132. ab = array([[0.0, 20, 6, 2],
  133. [1, 4, 20, 14],
  134. [-30, 1, 7, 0],
  135. [2, -1, 0, 0]])
  136. l, u = 2, 1
  137. bad = array([1.0, 2.0, 3.0, 4.0]).reshape(-1, 4)
  138. assert_raises(ValueError, solve_banded, (l, u), ab, bad)
  139. assert_raises(ValueError, solve_banded, (l, u), ab, [1.0, 2.0])
  140. # Values of (l,u) are not compatible with ab.
  141. assert_raises(ValueError, solve_banded, (1, 1), ab, [1.0, 2.0])
  142. def test_1x1(self):
  143. # gh-8906 noted that the case of A@x = b with 1x1 A was handled
  144. # incorrectly; check that this is resolved. Typical case:
  145. # nupper == nlower == 0
  146. # A = [[2]]
  147. b = array([[1., 2., 3.]])
  148. ref = array([[0.5, 1.0, 1.5]])
  149. x = solve_banded((0, 0), [[2]], b)
  150. assert_allclose(x, ref, rtol=1e-15)
  151. # However, the user *can* represent the same system with garbage rows
  152. # in `ab`. Test the case with `nupper == 1, nlower == 1`.
  153. x = solve_banded((1, 1), [[0], [2], [0]], b)
  154. assert_allclose(x, ref, rtol=1e-15)
  155. assert_equal(x.dtype, np.dtype('f8'))
  156. assert_array_equal(b, [[1.0, 2.0, 3.0]])
  157. def test_native_list_arguments(self):
  158. a = [[1.0, 20, 0, 0],
  159. [-30, 4, 6, 0],
  160. [2, 1, 20, 2],
  161. [0, -1, 7, 14]]
  162. ab = [[0.0, 20, 6, 2],
  163. [1, 4, 20, 14],
  164. [-30, 1, 7, 0],
  165. [2, -1, 0, 0]]
  166. l, u = 2, 1
  167. b = [10.0, 0.0, 2.0, 14.0]
  168. x = solve_banded((l, u), ab, b)
  169. assert_array_almost_equal(dot(a, x), b)
  170. @pytest.mark.parametrize('dt_ab', [int, float, np.float32, complex, np.complex64])
  171. @pytest.mark.parametrize('dt_b', [int, float, np.float32, complex, np.complex64])
  172. def test_empty(self, dt_ab, dt_b):
  173. # ab contains one empty row corresponding to the diagonal
  174. ab = np.array([[]], dtype=dt_ab)
  175. b = np.array([], dtype=dt_b)
  176. x = solve_banded((0, 0), ab, b)
  177. assert x.shape == (0,)
  178. assert x.dtype == solve(np.eye(1, dtype=dt_ab), np.ones(1, dtype=dt_b)).dtype
  179. b = np.empty((0, 0), dtype=dt_b)
  180. x = solve_banded((0, 0), ab, b)
  181. assert x.shape == (0, 0)
  182. assert x.dtype == solve(np.eye(1, dtype=dt_ab), np.ones(1, dtype=dt_b)).dtype
  183. class TestSolveHBanded:
  184. def test_01_upper(self):
  185. # Solve
  186. # [ 4 1 2 0] [1]
  187. # [ 1 4 1 2] X = [4]
  188. # [ 2 1 4 1] [1]
  189. # [ 0 2 1 4] [2]
  190. # with the RHS as a 1D array.
  191. ab = array([[0.0, 0.0, 2.0, 2.0],
  192. [-99, 1.0, 1.0, 1.0],
  193. [4.0, 4.0, 4.0, 4.0]])
  194. b = array([1.0, 4.0, 1.0, 2.0])
  195. x = solveh_banded(ab, b)
  196. assert_array_almost_equal(x, [0.0, 1.0, 0.0, 0.0])
  197. def test_02_upper(self):
  198. # Solve
  199. # [ 4 1 2 0] [1 6]
  200. # [ 1 4 1 2] X = [4 2]
  201. # [ 2 1 4 1] [1 6]
  202. # [ 0 2 1 4] [2 1]
  203. #
  204. ab = array([[0.0, 0.0, 2.0, 2.0],
  205. [-99, 1.0, 1.0, 1.0],
  206. [4.0, 4.0, 4.0, 4.0]])
  207. b = array([[1.0, 6.0],
  208. [4.0, 2.0],
  209. [1.0, 6.0],
  210. [2.0, 1.0]])
  211. x = solveh_banded(ab, b)
  212. expected = array([[0.0, 1.0],
  213. [1.0, 0.0],
  214. [0.0, 1.0],
  215. [0.0, 0.0]])
  216. assert_array_almost_equal(x, expected)
  217. def test_03_upper(self):
  218. # Solve
  219. # [ 4 1 2 0] [1]
  220. # [ 1 4 1 2] X = [4]
  221. # [ 2 1 4 1] [1]
  222. # [ 0 2 1 4] [2]
  223. # with the RHS as a 2D array with shape (3,1).
  224. ab = array([[0.0, 0.0, 2.0, 2.0],
  225. [-99, 1.0, 1.0, 1.0],
  226. [4.0, 4.0, 4.0, 4.0]])
  227. b = array([1.0, 4.0, 1.0, 2.0]).reshape(-1, 1)
  228. x = solveh_banded(ab, b)
  229. assert_array_almost_equal(x, array([0., 1., 0., 0.]).reshape(-1, 1))
  230. def test_01_lower(self):
  231. # Solve
  232. # [ 4 1 2 0] [1]
  233. # [ 1 4 1 2] X = [4]
  234. # [ 2 1 4 1] [1]
  235. # [ 0 2 1 4] [2]
  236. #
  237. ab = array([[4.0, 4.0, 4.0, 4.0],
  238. [1.0, 1.0, 1.0, -99],
  239. [2.0, 2.0, 0.0, 0.0]])
  240. b = array([1.0, 4.0, 1.0, 2.0])
  241. x = solveh_banded(ab, b, lower=True)
  242. assert_array_almost_equal(x, [0.0, 1.0, 0.0, 0.0])
  243. def test_02_lower(self):
  244. # Solve
  245. # [ 4 1 2 0] [1 6]
  246. # [ 1 4 1 2] X = [4 2]
  247. # [ 2 1 4 1] [1 6]
  248. # [ 0 2 1 4] [2 1]
  249. #
  250. ab = array([[4.0, 4.0, 4.0, 4.0],
  251. [1.0, 1.0, 1.0, -99],
  252. [2.0, 2.0, 0.0, 0.0]])
  253. b = array([[1.0, 6.0],
  254. [4.0, 2.0],
  255. [1.0, 6.0],
  256. [2.0, 1.0]])
  257. x = solveh_banded(ab, b, lower=True)
  258. expected = array([[0.0, 1.0],
  259. [1.0, 0.0],
  260. [0.0, 1.0],
  261. [0.0, 0.0]])
  262. assert_array_almost_equal(x, expected)
  263. def test_01_float32(self):
  264. # Solve
  265. # [ 4 1 2 0] [1]
  266. # [ 1 4 1 2] X = [4]
  267. # [ 2 1 4 1] [1]
  268. # [ 0 2 1 4] [2]
  269. #
  270. ab = array([[0.0, 0.0, 2.0, 2.0],
  271. [-99, 1.0, 1.0, 1.0],
  272. [4.0, 4.0, 4.0, 4.0]], dtype=float32)
  273. b = array([1.0, 4.0, 1.0, 2.0], dtype=float32)
  274. x = solveh_banded(ab, b)
  275. assert_array_almost_equal(x, [0.0, 1.0, 0.0, 0.0])
  276. def test_02_float32(self):
  277. # Solve
  278. # [ 4 1 2 0] [1 6]
  279. # [ 1 4 1 2] X = [4 2]
  280. # [ 2 1 4 1] [1 6]
  281. # [ 0 2 1 4] [2 1]
  282. #
  283. ab = array([[0.0, 0.0, 2.0, 2.0],
  284. [-99, 1.0, 1.0, 1.0],
  285. [4.0, 4.0, 4.0, 4.0]], dtype=float32)
  286. b = array([[1.0, 6.0],
  287. [4.0, 2.0],
  288. [1.0, 6.0],
  289. [2.0, 1.0]], dtype=float32)
  290. x = solveh_banded(ab, b)
  291. expected = array([[0.0, 1.0],
  292. [1.0, 0.0],
  293. [0.0, 1.0],
  294. [0.0, 0.0]])
  295. assert_array_almost_equal(x, expected)
  296. def test_01_complex(self):
  297. # Solve
  298. # [ 4 -j 2 0] [2-j]
  299. # [ j 4 -j 2] X = [4-j]
  300. # [ 2 j 4 -j] [4+j]
  301. # [ 0 2 j 4] [2+j]
  302. #
  303. ab = array([[0.0, 0.0, 2.0, 2.0],
  304. [-99, -1.0j, -1.0j, -1.0j],
  305. [4.0, 4.0, 4.0, 4.0]])
  306. b = array([2-1.0j, 4.0-1j, 4+1j, 2+1j])
  307. x = solveh_banded(ab, b)
  308. assert_array_almost_equal(x, [0.0, 1.0, 1.0, 0.0])
  309. def test_02_complex(self):
  310. # Solve
  311. # [ 4 -j 2 0] [2-j 2+4j]
  312. # [ j 4 -j 2] X = [4-j -1-j]
  313. # [ 2 j 4 -j] [4+j 4+2j]
  314. # [ 0 2 j 4] [2+j j]
  315. #
  316. ab = array([[0.0, 0.0, 2.0, 2.0],
  317. [-99, -1.0j, -1.0j, -1.0j],
  318. [4.0, 4.0, 4.0, 4.0]])
  319. b = array([[2-1j, 2+4j],
  320. [4.0-1j, -1-1j],
  321. [4.0+1j, 4+2j],
  322. [2+1j, 1j]])
  323. x = solveh_banded(ab, b)
  324. expected = array([[0.0, 1.0j],
  325. [1.0, 0.0],
  326. [1.0, 1.0],
  327. [0.0, 0.0]])
  328. assert_array_almost_equal(x, expected)
  329. def test_tridiag_01_upper(self):
  330. # Solve
  331. # [ 4 1 0] [1]
  332. # [ 1 4 1] X = [4]
  333. # [ 0 1 4] [1]
  334. # with the RHS as a 1D array.
  335. ab = array([[-99, 1.0, 1.0], [4.0, 4.0, 4.0]])
  336. b = array([1.0, 4.0, 1.0])
  337. x = solveh_banded(ab, b)
  338. assert_array_almost_equal(x, [0.0, 1.0, 0.0])
  339. def test_tridiag_02_upper(self):
  340. # Solve
  341. # [ 4 1 0] [1 4]
  342. # [ 1 4 1] X = [4 2]
  343. # [ 0 1 4] [1 4]
  344. #
  345. ab = array([[-99, 1.0, 1.0],
  346. [4.0, 4.0, 4.0]])
  347. b = array([[1.0, 4.0],
  348. [4.0, 2.0],
  349. [1.0, 4.0]])
  350. x = solveh_banded(ab, b)
  351. expected = array([[0.0, 1.0],
  352. [1.0, 0.0],
  353. [0.0, 1.0]])
  354. assert_array_almost_equal(x, expected)
  355. def test_tridiag_03_upper(self):
  356. # Solve
  357. # [ 4 1 0] [1]
  358. # [ 1 4 1] X = [4]
  359. # [ 0 1 4] [1]
  360. # with the RHS as a 2D array with shape (3,1).
  361. ab = array([[-99, 1.0, 1.0], [4.0, 4.0, 4.0]])
  362. b = array([1.0, 4.0, 1.0]).reshape(-1, 1)
  363. x = solveh_banded(ab, b)
  364. assert_array_almost_equal(x, array([0.0, 1.0, 0.0]).reshape(-1, 1))
  365. def test_tridiag_01_lower(self):
  366. # Solve
  367. # [ 4 1 0] [1]
  368. # [ 1 4 1] X = [4]
  369. # [ 0 1 4] [1]
  370. #
  371. ab = array([[4.0, 4.0, 4.0],
  372. [1.0, 1.0, -99]])
  373. b = array([1.0, 4.0, 1.0])
  374. x = solveh_banded(ab, b, lower=True)
  375. assert_array_almost_equal(x, [0.0, 1.0, 0.0])
  376. def test_tridiag_02_lower(self):
  377. # Solve
  378. # [ 4 1 0] [1 4]
  379. # [ 1 4 1] X = [4 2]
  380. # [ 0 1 4] [1 4]
  381. #
  382. ab = array([[4.0, 4.0, 4.0],
  383. [1.0, 1.0, -99]])
  384. b = array([[1.0, 4.0],
  385. [4.0, 2.0],
  386. [1.0, 4.0]])
  387. x = solveh_banded(ab, b, lower=True)
  388. expected = array([[0.0, 1.0],
  389. [1.0, 0.0],
  390. [0.0, 1.0]])
  391. assert_array_almost_equal(x, expected)
  392. def test_tridiag_01_float32(self):
  393. # Solve
  394. # [ 4 1 0] [1]
  395. # [ 1 4 1] X = [4]
  396. # [ 0 1 4] [1]
  397. #
  398. ab = array([[-99, 1.0, 1.0], [4.0, 4.0, 4.0]], dtype=float32)
  399. b = array([1.0, 4.0, 1.0], dtype=float32)
  400. x = solveh_banded(ab, b)
  401. assert_array_almost_equal(x, [0.0, 1.0, 0.0])
  402. def test_tridiag_02_float32(self):
  403. # Solve
  404. # [ 4 1 0] [1 4]
  405. # [ 1 4 1] X = [4 2]
  406. # [ 0 1 4] [1 4]
  407. #
  408. ab = array([[-99, 1.0, 1.0],
  409. [4.0, 4.0, 4.0]], dtype=float32)
  410. b = array([[1.0, 4.0],
  411. [4.0, 2.0],
  412. [1.0, 4.0]], dtype=float32)
  413. x = solveh_banded(ab, b)
  414. expected = array([[0.0, 1.0],
  415. [1.0, 0.0],
  416. [0.0, 1.0]])
  417. assert_array_almost_equal(x, expected)
  418. def test_tridiag_01_complex(self):
  419. # Solve
  420. # [ 4 -j 0] [ -j]
  421. # [ j 4 -j] X = [4-j]
  422. # [ 0 j 4] [4+j]
  423. #
  424. ab = array([[-99, -1.0j, -1.0j], [4.0, 4.0, 4.0]])
  425. b = array([-1.0j, 4.0-1j, 4+1j])
  426. x = solveh_banded(ab, b)
  427. assert_array_almost_equal(x, [0.0, 1.0, 1.0])
  428. def test_tridiag_02_complex(self):
  429. # Solve
  430. # [ 4 -j 0] [ -j 4j]
  431. # [ j 4 -j] X = [4-j -1-j]
  432. # [ 0 j 4] [4+j 4 ]
  433. #
  434. ab = array([[-99, -1.0j, -1.0j],
  435. [4.0, 4.0, 4.0]])
  436. b = array([[-1j, 4.0j],
  437. [4.0-1j, -1.0-1j],
  438. [4.0+1j, 4.0]])
  439. x = solveh_banded(ab, b)
  440. expected = array([[0.0, 1.0j],
  441. [1.0, 0.0],
  442. [1.0, 1.0]])
  443. assert_array_almost_equal(x, expected)
  444. def test_check_finite(self):
  445. # Solve
  446. # [ 4 1 0] [1]
  447. # [ 1 4 1] X = [4]
  448. # [ 0 1 4] [1]
  449. # with the RHS as a 1D array.
  450. ab = array([[-99, 1.0, 1.0], [4.0, 4.0, 4.0]])
  451. b = array([1.0, 4.0, 1.0])
  452. x = solveh_banded(ab, b, check_finite=False)
  453. assert_array_almost_equal(x, [0.0, 1.0, 0.0])
  454. def test_bad_shapes(self):
  455. ab = array([[-99, 1.0, 1.0],
  456. [4.0, 4.0, 4.0]])
  457. b = array([[1.0, 4.0],
  458. [4.0, 2.0]])
  459. assert_raises(ValueError, solveh_banded, ab, b)
  460. assert_raises(ValueError, solveh_banded, ab, [1.0, 2.0])
  461. assert_raises(ValueError, solveh_banded, ab, [1.0])
  462. def test_1x1(self):
  463. x = solveh_banded([[1]], [[1, 2, 3]])
  464. assert_array_equal(x, [[1.0, 2.0, 3.0]])
  465. assert_equal(x.dtype, np.dtype('f8'))
  466. def test_native_list_arguments(self):
  467. # Same as test_01_upper, using python's native list.
  468. ab = [[0.0, 0.0, 2.0, 2.0],
  469. [-99, 1.0, 1.0, 1.0],
  470. [4.0, 4.0, 4.0, 4.0]]
  471. b = [1.0, 4.0, 1.0, 2.0]
  472. x = solveh_banded(ab, b)
  473. assert_array_almost_equal(x, [0.0, 1.0, 0.0, 0.0])
  474. @pytest.mark.parametrize('dt_ab', [int, float, np.float32, complex, np.complex64])
  475. @pytest.mark.parametrize('dt_b', [int, float, np.float32, complex, np.complex64])
  476. def test_empty(self, dt_ab, dt_b):
  477. # ab contains one empty row corresponding to the diagonal
  478. ab = np.array([[]], dtype=dt_ab)
  479. b = np.array([], dtype=dt_b)
  480. x = solveh_banded(ab, b)
  481. assert x.shape == (0,)
  482. assert x.dtype == solve(np.eye(1, dtype=dt_ab), np.ones(1, dtype=dt_b)).dtype
  483. b = np.empty((0, 0), dtype=dt_b)
  484. x = solveh_banded(ab, b)
  485. assert x.shape == (0, 0)
  486. assert x.dtype == solve(np.eye(1, dtype=dt_ab), np.ones(1, dtype=dt_b)).dtype
  487. class TestSolve:
  488. def test_20Feb04_bug(self):
  489. a = [[1, 1], [1.0, 0]] # ok
  490. x0 = solve(a, [1, 0j])
  491. assert_array_almost_equal(dot(a, x0), [1, 0])
  492. # gives failure with clapack.zgesv(..,rowmajor=0)
  493. a = [[1, 1], [1.2, 0]]
  494. b = [1, 0j]
  495. x0 = solve(a, b)
  496. assert_array_almost_equal(dot(a, x0), [1, 0])
  497. def test_simple(self):
  498. a = [[1, 20], [-30, 4]]
  499. for b in ([[1, 0], [0, 1]],
  500. [1, 0],
  501. [[2, 1], [-30, 4]]
  502. ):
  503. x = solve(a, b)
  504. assert_array_almost_equal(dot(a, x), b)
  505. def test_simple_complex(self):
  506. a = array([[5, 2], [2j, 4]], 'D')
  507. for b in ([1j, 0],
  508. [[1j, 1j], [0, 2]],
  509. [1, 0j],
  510. array([1, 0], 'D'),
  511. ):
  512. x = solve(a, b)
  513. assert_array_almost_equal(dot(a, x), b)
  514. def test_simple_pos(self):
  515. a = [[2, 3], [3, 5]]
  516. for lower in [0, 1]:
  517. for b in ([[1, 0], [0, 1]],
  518. [1, 0]
  519. ):
  520. x = solve(a, b, assume_a='pos', lower=lower)
  521. assert_array_almost_equal(dot(a, x), b)
  522. def test_simple_pos_complexb(self):
  523. a = [[5, 2], [2, 4]]
  524. for b in ([1j, 0],
  525. [[1j, 1j], [0, 2]],
  526. ):
  527. x = solve(a, b, assume_a='pos')
  528. assert_array_almost_equal(dot(a, x), b)
  529. def test_simple_sym(self):
  530. a = [[2, 3], [3, -5]]
  531. for lower in [0, 1]:
  532. for b in ([[1, 0], [0, 1]],
  533. [1, 0]
  534. ):
  535. x = solve(a, b, assume_a='sym', lower=lower)
  536. assert_array_almost_equal(dot(a, x), b)
  537. def test_simple_sym_complexb(self):
  538. a = [[5, 2], [2, -4]]
  539. for b in ([1j, 0],
  540. [[1j, 1j], [0, 2]]
  541. ):
  542. x = solve(a, b, assume_a='sym')
  543. assert_array_almost_equal(dot(a, x), b)
  544. def test_simple_sym_complex(self):
  545. a = [[5, 2+1j], [2+1j, -4]]
  546. for b in ([1j, 0],
  547. [1, 0],
  548. [[1j, 1j], [0, 2]]
  549. ):
  550. x = solve(a, b, assume_a='sym')
  551. assert_array_almost_equal(dot(a, x), b)
  552. def test_simple_her_actuallysym(self):
  553. a = [[2, 3], [3, -5]]
  554. for lower in [0, 1]:
  555. for b in ([[1, 0], [0, 1]],
  556. [1, 0],
  557. [1j, 0],
  558. ):
  559. x = solve(a, b, assume_a='her', lower=lower)
  560. assert_array_almost_equal(dot(a, x), b)
  561. def test_simple_her(self):
  562. a = [[5, 2+1j], [2-1j, -4]]
  563. for b in ([1j, 0],
  564. [1, 0],
  565. [[1j, 1j], [0, 2]]
  566. ):
  567. x = solve(a, b, assume_a='her')
  568. assert_array_almost_equal(dot(a, x), b)
  569. def test_nils_20Feb04(self):
  570. rng = np.random.default_rng(1234)
  571. n = 2
  572. A = rng.random([n, n])+rng.random([n, n])*1j
  573. X = zeros((n, n), 'D')
  574. Ainv = inv(A)
  575. R = identity(n)+identity(n)*0j
  576. for i in arange(0, n):
  577. r = R[:, i]
  578. X[:, i] = solve(A, r)
  579. assert_array_almost_equal(X, Ainv)
  580. def test_random(self):
  581. rng = np.random.default_rng(1234)
  582. n = 20
  583. a = rng.random([n, n])
  584. for i in range(n):
  585. a[i, i] = 20*(.1+a[i, i])
  586. for i in range(4):
  587. b = rng.random([n, 3])
  588. x = solve(a, b)
  589. assert_array_almost_equal(dot(a, x), b)
  590. def test_random_complex(self):
  591. rng = np.random.default_rng(1234)
  592. n = 20
  593. a = rng.random([n, n]) + 1j * rng.random([n, n])
  594. for i in range(n):
  595. a[i, i] = 20*(.1+a[i, i])
  596. for i in range(2):
  597. b = rng.random([n, 3])
  598. x = solve(a, b)
  599. assert_array_almost_equal(dot(a, x), b)
  600. def test_random_sym(self):
  601. rng = np.random.default_rng(1234)
  602. n = 20
  603. a = rng.random([n, n])
  604. for i in range(n):
  605. a[i, i] = abs(20*(.1+a[i, i]))
  606. for j in range(i):
  607. a[i, j] = a[j, i]
  608. for i in range(4):
  609. b = rng.random([n])
  610. x = solve(a, b, assume_a="pos")
  611. assert_array_almost_equal(dot(a, x), b)
  612. def test_random_sym_complex(self):
  613. rng = np.random.default_rng(1234)
  614. n = 20
  615. a = rng.random([n, n])
  616. a = a + 1j*rng.random([n, n])
  617. for i in range(n):
  618. a[i, i] = abs(20*(.1+a[i, i]))
  619. for j in range(i):
  620. a[i, j] = conjugate(a[j, i])
  621. b = rng.random([n])+2j*rng.random([n])
  622. for i in range(2):
  623. x = solve(a, b, assume_a="pos")
  624. assert_array_almost_equal(dot(a, x), b)
  625. def test_check_finite(self):
  626. a = [[1, 20], [-30, 4]]
  627. for b in ([[1, 0], [0, 1]], [1, 0],
  628. [[2, 1], [-30, 4]]):
  629. x = solve(a, b, check_finite=False)
  630. assert_array_almost_equal(dot(a, x), b)
  631. def test_scalar_a_and_1D_b(self):
  632. a = 1
  633. b = [1, 2, 3]
  634. x = solve(a, b)
  635. assert_array_almost_equal(x.ravel(), b)
  636. assert_(x.shape == (3,), 'Scalar_a_1D_b test returned wrong shape')
  637. def test_simple2(self):
  638. a = np.array([[1.80, 2.88, 2.05, -0.89],
  639. [525.00, -295.00, -95.00, -380.00],
  640. [1.58, -2.69, -2.90, -1.04],
  641. [-1.11, -0.66, -0.59, 0.80]])
  642. b = np.array([[9.52, 18.47],
  643. [2435.00, 225.00],
  644. [0.77, -13.28],
  645. [-6.22, -6.21]])
  646. x = solve(a, b)
  647. assert_array_almost_equal(x, np.array([[1., -1, 3, -5],
  648. [3, 2, 4, 1]]).T)
  649. def test_simple_complex2(self):
  650. a = np.array([[-1.34+2.55j, 0.28+3.17j, -6.39-2.20j, 0.72-0.92j],
  651. [-1.70-14.10j, 33.10-1.50j, -1.50+13.40j, 12.90+13.80j],
  652. [-3.29-2.39j, -1.91+4.42j, -0.14-1.35j, 1.72+1.35j],
  653. [2.41+0.39j, -0.56+1.47j, -0.83-0.69j, -1.96+0.67j]])
  654. b = np.array([[26.26+51.78j, 31.32-6.70j],
  655. [64.30-86.80j, 158.60-14.20j],
  656. [-5.75+25.31j, -2.15+30.19j],
  657. [1.16+2.57j, -2.56+7.55j]])
  658. x = solve(a, b)
  659. assert_array_almost_equal(x, np. array([[1+1.j, -1-2.j],
  660. [2-3.j, 5+1.j],
  661. [-4-5.j, -3+4.j],
  662. [6.j, 2-3.j]]))
  663. @pytest.mark.parametrize("assume_a", ['her', 'sym'])
  664. def test_symmetric_hermitian(self, assume_a):
  665. # An upper triangular matrix will be used for symmetric/hermitian matrix a
  666. a = np.array([[-1.84, 0.11-0.11j, -1.78-1.18j, 3.91-1.50j],
  667. [0, -4.63, -1.84+0.03j, 2.21+0.21j],
  668. [0, 0, -8.87, 1.58-0.90j],
  669. [0, 0, 0, -1.36]])
  670. b = np.array([[2.98-10.18j, 28.68-39.89j],
  671. [-9.58+3.88j, -24.79-8.40j],
  672. [-0.77-16.05j, 4.23-70.02j],
  673. [7.79+5.48j, -35.39+18.01j]])
  674. a2 = a.T if assume_a == 'sym' else a.conj().T # for testing `lower`
  675. a3 = a + a2 # for reference solution
  676. a3[np.arange(4), np.arange(4)] = np.diag(a)
  677. ref = solve(a3, b, assume_a='general')
  678. x = solve(a, b, assume_a=assume_a)
  679. assert_array_almost_equal(x, ref)
  680. # Also transpose(/conjugate) `a` and test for lower triangular data
  681. # This also tests gh-22265 resolution; otherwise, a warning would be emitted
  682. x = solve(a2, b, assume_a=assume_a, lower=True)
  683. assert_array_almost_equal(x, ref)
  684. def test_pos_and_sym(self):
  685. A = np.arange(1, 10).reshape(3, 3)
  686. x = solve(np.tril(A)/9, np.ones(3), assume_a='pos')
  687. assert_array_almost_equal(x, [9., 1.8, 1.])
  688. x = solve(np.tril(A)/9, np.ones(3), assume_a='sym')
  689. assert_array_almost_equal(x, [9., 1.8, 1.])
  690. def test_singularity(self):
  691. a = np.array([[1, 0, 0, 0, 0, 0, 1, 0, 1],
  692. [1, 1, 1, 0, 0, 0, 1, 0, 1],
  693. [0, 1, 1, 0, 0, 0, 1, 0, 1],
  694. [1, 0, 1, 1, 1, 1, 0, 0, 0],
  695. [1, 0, 1, 1, 1, 1, 0, 0, 0],
  696. [1, 0, 1, 1, 1, 1, 0, 0, 0],
  697. [1, 0, 1, 1, 1, 1, 0, 0, 0],
  698. [1, 1, 1, 1, 1, 1, 1, 1, 1],
  699. [1, 1, 1, 1, 1, 1, 1, 1, 1]])
  700. b = np.arange(9)[:, None]
  701. assert_raises(LinAlgError, solve, a, b)
  702. @pytest.mark.parametrize('structure',
  703. ('diagonal', 'tridiagonal', 'lower triangular',
  704. 'upper triangular', 'symmetric', 'hermitian',
  705. 'positive definite', 'general', 'banded', None))
  706. def test_ill_condition_warning(self, structure):
  707. rng = np.random.default_rng(234859349452)
  708. n = 10
  709. d = np.logspace(0, 50, n)
  710. A = np.diag(d)
  711. b = rng.random(size=n)
  712. message = "(Ill-conditioned matrix|An ill-conditioned matrix)"
  713. with pytest.warns(LinAlgWarning, match=message):
  714. solve(A, b, assume_a=structure)
  715. @pytest.mark.parametrize('structure',
  716. ('diagonal', 'tridiagonal', 'lower triangular',
  717. 'upper triangular', 'symmetric', 'hermitian',
  718. 'positive definite', 'general', None))
  719. def test_exactly_singular_gh22263(self, structure):
  720. n = 10
  721. A = np.zeros((n, n))
  722. b = np.ones(n)
  723. with (pytest.raises(LinAlgError, match="singular"), np.errstate(all='ignore')):
  724. solve(A, b, assume_a=structure)
  725. def test_multiple_rhs(self):
  726. a = np.eye(2)
  727. rng = np.random.default_rng(1234)
  728. b = rng.random((2, 12))
  729. x = solve(a, b)
  730. assert_array_almost_equal(x, b)
  731. def test_transposed_keyword(self):
  732. A = np.arange(9).reshape(3, 3) + 1
  733. x = solve(np.tril(A)/9, np.ones(3), transposed=True)
  734. assert_array_almost_equal(x, [1.2, 0.2, 1])
  735. x = solve(np.tril(A)/9, np.ones(3), transposed=False)
  736. assert_array_almost_equal(x, [9, -5.4, -1.2])
  737. @pytest.mark.skip(reason="1. why? 2. deprecate the kwarg altogether?")
  738. def test_transposed_notimplemented(self):
  739. a = np.eye(3).astype(complex)
  740. with assert_raises(NotImplementedError):
  741. solve(a, a, transposed=True)
  742. def test_nonsquare_a(self):
  743. assert_raises(ValueError, solve, [1, 2], 1)
  744. def test_size_mismatch_with_1D_b(self):
  745. assert_array_almost_equal(solve(np.eye(3), np.ones(3)), np.ones(3))
  746. assert_raises(ValueError, solve, np.eye(3), np.ones(4))
  747. def test_assume_a_keyword(self):
  748. assert_raises(ValueError, solve, 1, 1, assume_a='zxcv')
  749. @pytest.mark.parametrize("size", [10, 100])
  750. @pytest.mark.parametrize("assume_a", ['gen', 'sym', 'pos', 'her', 'tridiagonal'])
  751. @pytest.mark.parametrize(
  752. "dtype", [np.float32, np.float64, np.complex64, np.complex128]
  753. )
  754. def test_all_type_size_routine_combinations(self, size, dtype, assume_a):
  755. rng = np.random.default_rng(1234)
  756. is_complex = dtype in (np.complex64, np.complex128)
  757. a = rng.standard_normal((size, size)).astype(dtype)
  758. b = rng.standard_normal(size).astype(dtype)
  759. if is_complex:
  760. a += (1j*rng.standard_normal((size, size))).astype(dtype)
  761. if assume_a == 'sym': # Can still be complex but only symmetric
  762. a = a + a.T
  763. elif assume_a == 'her': # Handle hermitian matrices here instead
  764. a = a + a.T.conj()
  765. elif assume_a == 'pos':
  766. a = a.T.conj() @ a + 0.1*np.eye(size)
  767. elif assume_a == 'tridiagonal':
  768. a = (np.diag(np.diag(a)) +
  769. np.diag(np.diag(a, 1), 1) +
  770. np.diag(np.diag(a, -1), -1)
  771. )
  772. tol = 1e-12 if dtype in (np.float64, np.complex128) else 1e-6
  773. if assume_a in ['gen', 'sym', 'her']:
  774. # We revert the tolerance from before
  775. # 4b4a6e7c34fa4060533db38f9a819b98fa81476c
  776. if dtype in (np.float32, np.complex64):
  777. tol *= 10
  778. x = solve(a, b, assume_a=assume_a)
  779. assert_allclose(a @ x, b, atol=tol * size, rtol=tol * size)
  780. if assume_a == 'sym' and not is_complex:
  781. x = solve(a, b, assume_a=assume_a, transposed=True)
  782. assert_allclose(a @ x, b, atol=tol * size, rtol=tol * size)
  783. @pytest.mark.parametrize('dt_a', [int, float, np.float32, complex, np.complex64])
  784. @pytest.mark.parametrize('dt_b', [int, float, np.float32, complex, np.complex64])
  785. def test_empty(self, dt_a, dt_b):
  786. a = np.empty((0, 0), dtype=dt_a)
  787. b = np.empty(0, dtype=dt_b)
  788. x = solve(a, b)
  789. assert x.size == 0
  790. dt_nonempty = solve(np.eye(2, dtype=dt_a), np.ones(2, dtype=dt_b)).dtype
  791. assert x.dtype == dt_nonempty
  792. assert x.shape == np.linalg.solve(a, b).shape
  793. a = np.ones((3, 0, 2, 2), dtype=dt_a)
  794. b = np.ones((2, 4), dtype=dt_b)
  795. x = solve(a, b)
  796. assert x.shape == (3, 0, 2, 4)
  797. assert x.dtype == dt_nonempty
  798. def test_empty_rhs(self):
  799. a = np.eye(2)
  800. b = [[], []]
  801. x = solve(a, b)
  802. assert_(x.size == 0, 'Returned array is not empty')
  803. assert_(x.shape == (2, 0), 'Returned empty array shape is wrong')
  804. @pytest.mark.parametrize('dtype', [np.float64, np.complex128])
  805. @pytest.mark.parametrize('assume_a', ['diagonal', 'tridiagonal', 'banded',
  806. 'lower triangular', 'upper triangular',
  807. 'pos', 'positive definite',
  808. 'symmetric', 'hermitian', 'banded',
  809. 'general', 'sym', 'her', 'gen'])
  810. @pytest.mark.parametrize('nrhs', [(), (5,)])
  811. @pytest.mark.parametrize('transposed', [True, False])
  812. @pytest.mark.parametrize('overwrite', [True, False])
  813. @pytest.mark.parametrize('fortran', [True, False])
  814. def test_structure_detection(self, dtype, assume_a, nrhs, transposed,
  815. overwrite, fortran):
  816. rng = np.random.default_rng(982345982439826)
  817. n = 5 if not assume_a == 'banded' else 20
  818. b = rng.random(size=(n,) + nrhs)
  819. A = rng.random(size=(n, n))
  820. if np.issubdtype(dtype, np.complexfloating):
  821. b = b + rng.random(size=(n,) + nrhs) * 1j
  822. A = A + rng.random(size=(n, n)) * 1j
  823. if assume_a == 'diagonal':
  824. A = np.diag(np.diag(A))
  825. elif assume_a == 'lower triangular':
  826. A = np.tril(A)
  827. elif assume_a == 'upper triangular':
  828. A = np.triu(A)
  829. elif assume_a == 'tridiagonal':
  830. A = (np.diag(np.diag(A))
  831. + np.diag(np.diag(A, -1), -1)
  832. + np.diag(np.diag(A, 1), 1))
  833. elif assume_a == 'banded':
  834. A = np.triu(np.tril(A, 2), -1)
  835. elif assume_a in {'symmetric', 'sym'}:
  836. A = A + A.T
  837. elif assume_a in {'hermitian', 'her'}:
  838. A = A + A.conj().T
  839. elif assume_a in {'positive definite', 'pos'}:
  840. A = A @ A.T.conj()
  841. if fortran:
  842. A = np.asfortranarray(A)
  843. A_copy = A.copy(order='A')
  844. b_copy = b.copy()
  845. if np.issubdtype(dtype, np.complexfloating) and transposed:
  846. message = "scipy.linalg.solve can currently..."
  847. with pytest.raises(NotImplementedError, match=message):
  848. solve(A, b, overwrite_a=overwrite, overwrite_b=overwrite,
  849. transposed=transposed)
  850. return
  851. res = solve(A, b, overwrite_a=overwrite, overwrite_b=overwrite,
  852. transposed=transposed, assume_a=assume_a)
  853. # Check that solution this solution is *correct*
  854. ref = np.linalg.solve(A_copy.T if transposed else A_copy, b_copy)
  855. assert_allclose(res, ref)
  856. # Check that `solve` correctly identifies the structure and returns
  857. # *exactly* the same solution whether `assume_a` is specified or not
  858. if assume_a != 'banded': # structure detection removed for banded
  859. assert_allclose(
  860. solve(A_copy, b_copy, transposed=transposed), res, atol=1e-15
  861. )
  862. # Check that overwrite was respected
  863. if not overwrite:
  864. assert_equal(A, A_copy)
  865. assert_equal(b, b_copy)
  866. @pytest.mark.skipif(
  867. np.__version__ < '2', reason="solve chokes on b.ndim == 1 in numpy < 2"
  868. )
  869. @pytest.mark.parametrize(
  870. "assume_a",
  871. [
  872. None, "diagonal", "general", "upper triangular", "lower triangular", "pos",
  873. ]
  874. )
  875. def test_vs_np_solve(self, assume_a):
  876. e = np.eye(2)
  877. a = np.arange(1, 4*3*2 + 1).reshape((4, 3, 2, 1, 1)) * e
  878. b = np.ones(2)
  879. assert_allclose(solve(a, b, assume_a=assume_a), np.linalg.solve(a, b))
  880. b = np.ones((2, 1))
  881. assert_allclose(solve(a, b, assume_a=assume_a), np.linalg.solve(a, b))
  882. b = np.ones((2, 2)) * [1, 2]
  883. assert_allclose(solve(a, b, assume_a=assume_a), np.linalg.solve(a, b))
  884. def test_pos_lower(self):
  885. # regression test for
  886. # https://github.com/scipy/scipy/pull/23071#issuecomment-3085826112
  887. rng = np.random.default_rng(0)
  888. a = rng.normal(size=(4, 4))
  889. a = np.tril(np.matmul(a, np.conj(a.T))) # lower triangle of hermitian array
  890. b = rng.normal(size=(4, 2))
  891. out = solve(a, b, assume_a='pos', lower=True)
  892. aa = a + a.T - np.diag(np.diag(a)) # the full hermitian array
  893. result_np = np.linalg.solve(aa, b)
  894. assert_allclose(out, result_np, atol=1e-15)
  895. # repeat with uplo='U'
  896. out = solve(a.T, b, assume_a='pos', lower=False)
  897. assert_allclose(out, result_np, atol=1e-15)
  898. def test_readonly(self):
  899. a = np.eye(3)
  900. a.flags.writeable = False
  901. b = np.ones(3)
  902. x = solve(a, b)
  903. assert_allclose(x, b, atol=1e-14)
  904. @parametrize_overwrite_arg
  905. def test_batch_negative_stride(self, overwrite_kw):
  906. a = np.arange(3*8).reshape(2, 3, 2, 2)
  907. a = a[:, ::-1, :, :]
  908. b = np.ones(2)
  909. x = solve(a, b, **overwrite_kw)
  910. assert x.shape == a.shape[:-1]
  911. assert_allclose(a @ x[..., None] - b, 0, atol=1e-14)
  912. # use b with a negative stride now
  913. b = np.ones((2, 4))[:, ::-1]
  914. x = solve(a, b, **overwrite_kw)
  915. assert x.shape == a.shape[:-1] + (b.shape[-1],)
  916. assert_allclose(a @ x - b, 0, atol=1e-14)
  917. @parametrize_overwrite_arg
  918. def test_core_negative_stride(self, overwrite_kw):
  919. a = np.arange(3*8).reshape(2, 3, 2, 2)
  920. a = a[:, :, ::-1, :]
  921. b = np.ones(2)
  922. x = solve(a, b, **overwrite_kw)
  923. assert x.shape == a.shape[:-1]
  924. assert_allclose(a @ x[..., None] - b, 0, atol=1e-14)
  925. # use b with a negative stride now
  926. b = np.ones((2, 4))[::-1, :]
  927. x = solve(a, b, **overwrite_kw)
  928. assert x.shape == a.shape[:-1] + (b.shape[-1],)
  929. assert_allclose(a @ x - b, 0, atol=1e-14)
  930. @parametrize_overwrite_arg
  931. def test_core_non_contiguous(self, overwrite_kw):
  932. a = np.arange(3*8*2).reshape(2, 3, 2, 4)
  933. a = a[..., ::2]
  934. b = np.ones(2)
  935. x = solve(a, b, **overwrite_kw)
  936. assert x.shape == a.shape[:-1]
  937. assert_allclose(a @ x[..., None] - b, 0, atol=1e-14)
  938. # use strided b now
  939. b = np.ones(4)[::2]
  940. x = solve(a, b, **overwrite_kw)
  941. assert x.shape == a.shape[:-1]
  942. assert_allclose(a @ x[..., None] - b, 0, atol=1e-14)
  943. @parametrize_overwrite_arg
  944. def test_batch_non_contiguous(self, overwrite_kw):
  945. a = np.arange(3*8*2).reshape(2, 6, 2, 2)
  946. a = a[:, ::2, ...]
  947. b = np.ones(2)
  948. x = solve(a, b, **overwrite_kw)
  949. assert x.shape == a.shape[:-1]
  950. assert_allclose(a @ x[..., None] - b, 0, atol=1e-14)
  951. # use strided b now
  952. b = np.ones((2, 6))[:, ::2]
  953. x = solve(a, b, **overwrite_kw)
  954. assert x.shape == a.shape[:-1] + (b.shape[-1],)
  955. assert_allclose(a @ x - b, 0, atol=1e-14)
  956. @parametrize_overwrite_arg
  957. def test_batch_weird_strides(self, overwrite_kw):
  958. a = np.arange(3*8*2).reshape(2, 3, 2, 2, 2)
  959. a = a.transpose(1, 3, 4, 0, 2)
  960. b = np.ones(2)
  961. x = solve(a, b, **overwrite_kw)
  962. assert x.shape == a.shape[:-1]
  963. assert_allclose(a @ x[..., None] - b, 0, atol=1e-14)
  964. def test_posdef_not_posdef(self):
  965. # the `b` matrix is invertible but not positive definite
  966. a = np.arange(9).reshape(3, 3)
  967. A = a + a.T + np.eye(3)
  968. b = np.ones(3)
  969. # cholesky solver fails, and the routine falls back to the general inverse
  970. x0 = solve(A, b)
  971. assert_allclose(A @ x0, b, atol=1e-14)
  972. # but it does not fall back if `assume_a` is given
  973. with assert_raises(LinAlgError):
  974. solve(A, b, assume_a='pos')
  975. def test_diagonal(self):
  976. a = np.stack([np.triu(np.ones((3, 3))), np.diag(np.arange(1, 4))])
  977. b = np.ones(3)
  978. x = solve(a, b)
  979. # basic diagonal solve
  980. assert_allclose(x[1, ...], 1 / np.arange(1, 4), atol=1e-14)
  981. # ill-conditioned inputs warn
  982. a = np.asarray([[1e30, 0], [0, 1]])
  983. b = np.ones(2)
  984. with pytest.warns(LinAlgWarning):
  985. solve(a, b, assume_a="diagonal")
  986. # singular input raises
  987. a = np.asarray([[0, 0], [0, 1]])
  988. b = np.ones(2)
  989. with pytest.raises(LinAlgError):
  990. solve(a, b, assume_a="diagonal")
  991. def test_tridiagonal(self):
  992. n = 4
  993. a = -2*np.diag(np.ones(n)) + np.diag(np.ones(3), 1) + np.diag(np.ones(3), -1)
  994. a = np.stack([np.triu(np.ones((n, n))), a])
  995. b = np.ones(4)
  996. x = solve(a, b)
  997. # basic tridiag solve
  998. assert_allclose(x[1, ...], np.asarray([-2., -3., -3., -2.]), atol=1e-15)
  999. # ill-conditioned inputs warn
  1000. a[1, 0, 0] = 1e20
  1001. with pytest.warns(LinAlgWarning):
  1002. solve(a, b, assume_a="tridiagonal")
  1003. # singular inputss raise
  1004. a[1, 0, 0] = a[1, 0, 1] = 0
  1005. with pytest.raises(LinAlgError):
  1006. solve(a, b, assume_a="tridiagonal")
  1007. class TestSolveTriangular:
  1008. def test_simple(self):
  1009. """
  1010. solve_triangular on a simple 2x2 matrix.
  1011. """
  1012. A = array([[1, 0], [1, 2]])
  1013. b = [1, 1]
  1014. sol = solve_triangular(A, b, lower=True)
  1015. assert_array_almost_equal(sol, [1, 0])
  1016. # check that it works also for non-contiguous matrices
  1017. sol = solve_triangular(A.T, b, lower=False)
  1018. assert_array_almost_equal(sol, [.5, .5])
  1019. # and that it gives the same result as trans=1
  1020. sol = solve_triangular(A, b, lower=True, trans=1)
  1021. assert_array_almost_equal(sol, [.5, .5])
  1022. b = identity(2)
  1023. sol = solve_triangular(A, b, lower=True, trans=1)
  1024. assert_array_almost_equal(sol, [[1., -.5], [0, 0.5]])
  1025. def test_simple_complex(self):
  1026. """
  1027. solve_triangular on a simple 2x2 complex matrix
  1028. """
  1029. A = array([[1+1j, 0], [1j, 2]])
  1030. b = identity(2)
  1031. sol = solve_triangular(A, b, lower=True, trans=1)
  1032. assert_array_almost_equal(sol, [[.5-.5j, -.25-.25j], [0, 0.5]])
  1033. # check other option combinations with complex rhs
  1034. b = np.diag([1+1j, 1+2j])
  1035. sol = solve_triangular(A, b, lower=True, trans=0)
  1036. assert_array_almost_equal(sol, [[1, 0], [-0.5j, 0.5+1j]])
  1037. sol = solve_triangular(A, b, lower=True, trans=1)
  1038. assert_array_almost_equal(sol, [[1, 0.25-0.75j], [0, 0.5+1j]])
  1039. sol = solve_triangular(A, b, lower=True, trans=2)
  1040. assert_array_almost_equal(sol, [[1j, -0.75-0.25j], [0, 0.5+1j]])
  1041. sol = solve_triangular(A.T, b, lower=False, trans=0)
  1042. assert_array_almost_equal(sol, [[1, 0.25-0.75j], [0, 0.5+1j]])
  1043. sol = solve_triangular(A.T, b, lower=False, trans=1)
  1044. assert_array_almost_equal(sol, [[1, 0], [-0.5j, 0.5+1j]])
  1045. sol = solve_triangular(A.T, b, lower=False, trans=2)
  1046. assert_array_almost_equal(sol, [[1j, 0], [-0.5, 0.5+1j]])
  1047. def test_check_finite(self):
  1048. """
  1049. solve_triangular on a simple 2x2 matrix.
  1050. """
  1051. A = array([[1, 0], [1, 2]])
  1052. b = [1, 1]
  1053. sol = solve_triangular(A, b, lower=True, check_finite=False)
  1054. assert_array_almost_equal(sol, [1, 0])
  1055. @pytest.mark.parametrize('dt_a', [int, float, np.float32, complex, np.complex64])
  1056. @pytest.mark.parametrize('dt_b', [int, float, np.float32, complex, np.complex64])
  1057. def test_empty(self, dt_a, dt_b):
  1058. a = np.empty((0, 0), dtype=dt_a)
  1059. b = np.empty(0, dtype=dt_b)
  1060. x = solve_triangular(a, b)
  1061. assert x.size == 0
  1062. dt_nonempty = solve_triangular(
  1063. np.eye(2, dtype=dt_a), np.ones(2, dtype=dt_b)
  1064. ).dtype
  1065. assert x.dtype == dt_nonempty
  1066. def test_empty_rhs(self):
  1067. a = np.eye(2)
  1068. b = [[], []]
  1069. x = solve_triangular(a, b)
  1070. assert_(x.size == 0, 'Returned array is not empty')
  1071. assert_(x.shape == (2, 0), 'Returned empty array shape is wrong')
  1072. class TestInv:
  1073. def test_simple(self):
  1074. a = [[1, 2], [3, 4]]
  1075. a_inv = inv(a)
  1076. assert_array_almost_equal(dot(a, a_inv), np.eye(2))
  1077. a = [[1, 2, 3], [4, 5, 6], [7, 8, 10]]
  1078. a_inv = inv(a)
  1079. assert_array_almost_equal(dot(a, a_inv), np.eye(3))
  1080. def test_random(self):
  1081. rng = np.random.default_rng(1234)
  1082. n = 20
  1083. for i in range(4):
  1084. a = rng.random([n, n])
  1085. for i in range(n):
  1086. a[i, i] = 20*(.1+a[i, i])
  1087. a_inv = inv(a)
  1088. assert_array_almost_equal(dot(a, a_inv),
  1089. identity(n))
  1090. def test_simple_complex(self):
  1091. a = [[1, 2], [3, 4j]]
  1092. a_inv = inv(a)
  1093. assert_array_almost_equal(dot(a, a_inv), [[1, 0], [0, 1]])
  1094. def test_random_complex(self):
  1095. rng = np.random.default_rng(1234)
  1096. n = 20
  1097. for i in range(4):
  1098. a = rng.random([n, n])+2j*rng.random([n, n])
  1099. for i in range(n):
  1100. a[i, i] = 20*(.1+a[i, i])
  1101. a_inv = inv(a)
  1102. assert_array_almost_equal(dot(a, a_inv),
  1103. identity(n))
  1104. def test_check_finite(self):
  1105. a = [[1, 2], [3, 4]]
  1106. a_inv = inv(a, check_finite=False)
  1107. assert_array_almost_equal(dot(a, a_inv), [[1, 0], [0, 1]])
  1108. @pytest.mark.parametrize('dt', [int, float, np.float32, complex, np.complex64])
  1109. def test_empty(self, dt):
  1110. a = np.empty((0, 0), dtype=dt)
  1111. a_inv = inv(a)
  1112. assert a_inv.size == 0
  1113. assert a_inv.dtype == inv(np.eye(2, dtype=dt)).dtype
  1114. a = np.ones((3, 0, 2, 2), dtype=dt)
  1115. a_inv = inv(a)
  1116. assert a_inv.shape == (3, 0, 2, 2)
  1117. a = np.ones((3, 1, 0, 0), dtype=dt)
  1118. a_inv = inv(a)
  1119. assert a_inv.shape == (3, 1, 0, 0)
  1120. @pytest.mark.xfail(reason="TODO: re-enable overwrite_a")
  1121. def test_overwrite_a(self):
  1122. a = np.arange(1, 5).reshape(2, 2)
  1123. a_inv = inv(a, overwrite_a=True)
  1124. assert_allclose(a_inv @ a, np.eye(2), atol=1e-14)
  1125. assert not np.shares_memory(a, a_inv) # int arrays are copied internally
  1126. # 2D F-ordered arrays of LAPACK-compatible dtypes: works inplace
  1127. a = a.astype(float).copy(order='F')
  1128. a_inv = inv(a, overwrite_a=True)
  1129. assert np.shares_memory(a, a_inv)
  1130. def test_readonly(self):
  1131. a = np.eye(3)
  1132. a.flags.writeable = False
  1133. a_inv = inv(a)
  1134. assert_allclose(a_inv, a, atol=1e-14)
  1135. @pytest.mark.parametrize('dt', [int, float, np.float32, complex, np.complex64])
  1136. def test_batch_core_1x1(self, dt):
  1137. a = np.arange(3*2, dtype=dt).reshape(3, 2, 1, 1) + 1
  1138. a_inv = inv(a)
  1139. assert a_inv.shape == a.shape
  1140. assert_allclose(a @ a_inv, 1.)
  1141. @parametrize_overwrite_arg
  1142. def test_batch_zero_stride(self, overwrite_kw):
  1143. a = np.arange(3*2*2, dtype=float).reshape(3, 2, 2)
  1144. aa = a[None, ...]
  1145. a_inv = inv(aa, **overwrite_kw)
  1146. assert a_inv.shape == aa.shape
  1147. assert_allclose(aa @ a_inv, np.broadcast_to(np.eye(2), aa.shape), atol=2e-14)
  1148. aa = a[:, None, ...]
  1149. a_inv = inv(aa, **overwrite_kw)
  1150. assert a_inv.shape == aa.shape
  1151. assert_allclose(aa @ a_inv, np.broadcast_to(np.eye(2), aa.shape), atol=2e-14)
  1152. @parametrize_overwrite_arg
  1153. def test_batch_negative_stride(self, overwrite_kw):
  1154. a = np.arange(3*8).reshape(2, 3, 2, 2)
  1155. a = a[:, ::-1, :, :]
  1156. a_inv = inv(a, **overwrite_kw)
  1157. assert a_inv.shape == a.shape
  1158. assert_allclose(a @ a_inv, np.broadcast_to(np.eye(2), a.shape), atol=5e-14)
  1159. @parametrize_overwrite_arg
  1160. def test_core_negative_stride(self, overwrite_kw):
  1161. a = np.arange(3*8).reshape(2, 3, 2, 2)
  1162. a = a[:, :, ::-1, :]
  1163. a_inv = inv(a, **overwrite_kw)
  1164. assert a_inv.shape == a.shape
  1165. assert_allclose(a @ a_inv, np.broadcast_to(np.eye(2), a.shape), atol=5e-14)
  1166. @parametrize_overwrite_arg
  1167. def test_core_non_contiguous(self, overwrite_kw):
  1168. a = np.arange(3*8*2).reshape(2, 3, 2, 4)
  1169. a = a[..., ::2]
  1170. a_inv = inv(a, **overwrite_kw)
  1171. assert a_inv.shape == (2, 3, 2, 2)
  1172. assert_allclose(a @ a_inv, np.broadcast_to(np.eye(2), a.shape), atol=5e-14)
  1173. @parametrize_overwrite_arg
  1174. def test_batch_non_contiguous(self, overwrite_kw):
  1175. a = np.arange(3*8*2).reshape(2, 6, 2, 2)
  1176. a = a[:, ::2, ...]
  1177. a_inv = inv(a, **overwrite_kw)
  1178. assert a_inv.shape == (2, 3, 2, 2)
  1179. assert_allclose(a @ a_inv, np.broadcast_to(np.eye(2), a.shape), atol=2e-13)
  1180. @parametrize_overwrite_arg
  1181. def test_singular(self, overwrite_kw):
  1182. # 2D case: A singular matrix: raise
  1183. with assert_raises(LinAlgError):
  1184. inv(np.ones((2, 2)))
  1185. # batched case: If all slices are singlar, raise
  1186. with assert_raises(LinAlgError):
  1187. inv(np.ones((3, 2, 2)))
  1188. # XXX: shall we make this behavior configurable somehow?
  1189. # A "keep-going" option would be this:
  1190. # if some of the slices are singular and some are not,
  1191. # - singular slices are filled with nans
  1192. # - non-singular slices are inverted
  1193. # - there is no error
  1194. a = np.stack((np.ones((2, 2), dtype=complex), np.arange(4).reshape(2, 2)))
  1195. with assert_raises(LinAlgError):
  1196. inv(a)
  1197. # this would be true for a "keep-going" option
  1198. # assert np.isnan(a_inv[0, ...]).all()
  1199. # assert_allclose(a_inv[1, ...] @ a[1, ...], np.eye(2), atol=1e-14)
  1200. def test_ill_cond(self):
  1201. a = np.diag([1., 1e-20])
  1202. with pytest.warns(LinAlgWarning):
  1203. inv(a)
  1204. a2 = np.stack([np.diag([1., 1e-20]), np.diag([1, 1]), np.diag([1, 1e-20])])
  1205. with pytest.warns(LinAlgWarning):
  1206. inv(a2)
  1207. def test_wrong_assume_a(self):
  1208. with assert_raises(KeyError):
  1209. inv(np.eye(2), assume_a="kaboom")
  1210. def test_posdef(self):
  1211. x = np.arange(25, dtype=float).reshape(5, 5)
  1212. y = x + x.T
  1213. y += 21*np.eye(5)
  1214. y_inv0 = inv(y)
  1215. y_inv1 = inv(y, assume_a="pos")
  1216. assert_allclose(y_inv1, y_inv0, atol=1e-15)
  1217. # check that the lower triangle is not referenced for `lower=False`
  1218. mask = np.where(1 - np.tri(*y.shape, -1) == 0, np.nan, 1)
  1219. y_inv2 = inv(y*mask, check_finite=False, assume_a="pos", lower=False)
  1220. assert_allclose(y_inv2, y_inv0, atol=1e-15)
  1221. # repeat with the upper triangle
  1222. y_inv3 = inv(y*mask.T, check_finite=False, assume_a="pos", lower=True)
  1223. assert_allclose(y_inv3, y_inv0, atol=1e-15)
  1224. @pytest.mark.parametrize('complex_', [False, True])
  1225. def test_posdef_not_posdef(self, complex_):
  1226. # the `b` matrix is invertible but not pos definite: test the "sym" fallback
  1227. a = np.arange(9).reshape(3, 3)
  1228. b = a + a.T + np.eye(3)
  1229. if complex_:
  1230. b = b + 1j*b
  1231. # cholesky solver fails, and the routine falls back to the symmetric inverse
  1232. b_inv0 = inv(b)
  1233. assert_allclose(b_inv0 @ b, np.eye(3), atol=3e-15)
  1234. # but it does not fall back if `assume_a` is given
  1235. with assert_raises(LinAlgError):
  1236. inv(b, assume_a='pos')
  1237. # test posdef fallback to the hermitian solver, too
  1238. if complex_:
  1239. a = np.arange(9).reshape(3, 3)
  1240. a = a + 1j*a
  1241. b = a + a.T.conj() + np.eye(3)
  1242. assert_allclose(inv(b) @ b, np.eye(3), atol=3e-15)
  1243. @pytest.mark.parametrize('complex_', [False, True])
  1244. @pytest.mark.parametrize('sym_herm', ['sym', 'her'])
  1245. def test_sym_her(self, complex_, sym_herm):
  1246. # test "sym" and "her" modes
  1247. a = np.arange(9).reshape(3, 3)
  1248. if complex_:
  1249. a = a + 1j*a
  1250. if sym_herm == "sym":
  1251. b = a + a.T
  1252. else: # sym_herm == "herm":
  1253. b = a + a.T.conj()
  1254. b = b + np.eye(3)
  1255. b_inv0 = np.linalg.inv(b)
  1256. assert_allclose(b_inv0 @ b, np.eye(3), atol=1e-14)
  1257. b_inv1 = inv(b, assume_a=sym_herm)
  1258. assert_allclose(b_inv0, b_inv1, atol=1e-15)
  1259. # check that the "other" triangle is not referenced
  1260. mask = np.where(1 - np.tri(*a.shape, -1) == 0, np.nan, 1)
  1261. b_inv2 = inv(b*mask, check_finite=False, assume_a=sym_herm, lower=False)
  1262. assert_allclose(b_inv2, b_inv0, atol=1e-15)
  1263. # repeat with the upper triangle
  1264. b_inv3 = inv(b*mask.T, check_finite=False, assume_a=sym_herm, lower=True)
  1265. assert_allclose(b_inv3, b_inv0, atol=1e-15)
  1266. def test_triangular_1(self):
  1267. x = np.arange(25, dtype=float).reshape(5, 5)
  1268. y = x + x.T
  1269. y += 21*np.eye(5)
  1270. y_inv0 = inv(y, assume_a='upper triangular')
  1271. # check that upper triangular differs from posdef
  1272. y_inv_posdef = inv(y, assume_a='pos')
  1273. assert not np.allclose(y_inv0, y_inv_posdef)
  1274. def test_triangular_2(self):
  1275. y = np.ones(25, dtype=float).reshape(5, 5)
  1276. y_inv_0_u = inv(np.triu(y))
  1277. assert_allclose(y_inv_0_u @ np.triu(y), np.eye(5), atol=1e-15)
  1278. y_inv_1_u = inv(y, assume_a='upper triangular')
  1279. assert_allclose(y_inv_1_u @ np.triu(y), np.eye(5), atol=1e-15)
  1280. # check that the lower triangle is not referenced for "upper triangular"
  1281. mask = np.where(1 - np.tri(*y.shape, -1) == 0, np.nan, 1)
  1282. y_inv_2_u = inv(y*mask, check_finite=False, assume_a='upper triangular')
  1283. assert_allclose(y_inv_2_u @ np.triu(y), np.eye(5), atol=1e-15)
  1284. # repeat for the lower traingular matrix
  1285. y_inv_0_l = inv(np.tril(y))
  1286. assert_allclose(y_inv_0_l @ np.tril(y), np.eye(5), atol=1e-15)
  1287. y_inv_1_l = inv(y, assume_a='lower triangular')
  1288. assert_allclose(y_inv_1_l @ np.tril(y), np.eye(5), atol=1e-15)
  1289. # check that the lower triangle is not referenced for "lower triangular"
  1290. mask = np.where(1 - np.tri(*y.shape, -1) == 0, np.nan, 1)
  1291. y_inv_2_l = inv(y*mask.T, check_finite=False, assume_a='lower triangular')
  1292. assert_allclose(y_inv_2_l @ np.tril(y), np.eye(5), atol=1e-15)
  1293. def test_diagonal(self):
  1294. a = np.stack([np.triu(np.ones((3, 3))), np.diag(np.arange(1, 4))])
  1295. inv_a = inv(a)
  1296. # basic diagonal invert
  1297. assert_allclose(inv_a[1], np.diag(1 / np.arange(1, 4)), atol=1e-14)
  1298. # ill-conditioned inputs warn
  1299. a = np.asarray([[1e30, 0], [0, 1]])
  1300. with pytest.warns(LinAlgWarning):
  1301. inv(a, assume_a="diagonal")
  1302. # singular input raises
  1303. a = np.asarray([[0, 0], [0, 1]])
  1304. with pytest.raises(LinAlgError):
  1305. inv(a, assume_a="diagonal")
  1306. class TestDet:
  1307. def test_1x1_all_singleton_dims(self):
  1308. a = np.array([[1]])
  1309. deta = det(a)
  1310. assert deta.dtype.char == 'd'
  1311. assert np.isscalar(deta)
  1312. assert deta == 1.
  1313. a = np.array([[[[1]]]], dtype='f')
  1314. deta = det(a)
  1315. assert deta.dtype.char == 'd'
  1316. assert deta.shape == (1, 1)
  1317. assert_equal(deta, [[1.0]])
  1318. a = np.array([[[1 + 3.j]]], dtype=np.complex64)
  1319. deta = det(a)
  1320. assert deta.dtype.char == 'D'
  1321. assert deta.shape == (1,)
  1322. assert_equal(deta, [1.+3.j])
  1323. def test_1by1_stacked_input_output(self):
  1324. rng = np.random.default_rng(1680305949878959)
  1325. a = rng.random([4, 5, 1, 1], dtype=np.float32)
  1326. deta = det(a)
  1327. assert deta.dtype.char == 'd'
  1328. assert deta.shape == (4, 5)
  1329. assert_allclose(deta, np.squeeze(a))
  1330. a = rng.random([4, 5, 1, 1], dtype=np.float32)*np.complex64(1.j)
  1331. deta = det(a)
  1332. assert deta.dtype.char == 'D'
  1333. assert deta.shape == (4, 5)
  1334. assert_allclose(deta, np.squeeze(a))
  1335. @pytest.mark.parametrize('shape', [[2, 2], [20, 20], [3, 2, 20, 20]])
  1336. def test_simple_det_shapes_real_complex(self, shape):
  1337. rng = np.random.default_rng(1680305949878959)
  1338. a = rng.uniform(-1., 1., size=shape)
  1339. d1, d2 = det(a), np.linalg.det(a)
  1340. assert_allclose(d1, d2)
  1341. b = rng.uniform(-1., 1., size=shape)*1j
  1342. b += rng.uniform(-0.5, 0.5, size=shape)
  1343. d3, d4 = det(b), np.linalg.det(b)
  1344. assert_allclose(d3, d4)
  1345. def test_for_known_det_values(self):
  1346. # Hadamard8
  1347. a = np.array([[1, 1, 1, 1, 1, 1, 1, 1],
  1348. [1, -1, 1, -1, 1, -1, 1, -1],
  1349. [1, 1, -1, -1, 1, 1, -1, -1],
  1350. [1, -1, -1, 1, 1, -1, -1, 1],
  1351. [1, 1, 1, 1, -1, -1, -1, -1],
  1352. [1, -1, 1, -1, -1, 1, -1, 1],
  1353. [1, 1, -1, -1, -1, -1, 1, 1],
  1354. [1, -1, -1, 1, -1, 1, 1, -1]])
  1355. assert_allclose(det(a), 4096.)
  1356. # consecutive number array always singular
  1357. assert_allclose(det(np.arange(25).reshape(5, 5)), 0.)
  1358. # simple anti-diagonal block array
  1359. # Upper right has det (-2+1j) and lower right has (-2-1j)
  1360. # det(a) = - (-2+1j) (-2-1j) = 5.
  1361. a = np.array([[0.+0.j, 0.+0.j, 0.-1.j, 1.-1.j],
  1362. [0.+0.j, 0.+0.j, 1.+0.j, 0.-1.j],
  1363. [0.+1.j, 1.+1.j, 0.+0.j, 0.+0.j],
  1364. [1.+0.j, 0.+1.j, 0.+0.j, 0.+0.j]], dtype=np.complex64)
  1365. assert_allclose(det(a), 5.+0.j)
  1366. # Fiedler companion complexified
  1367. # >>> a = scipy.linalg.fiedler_companion(np.arange(1, 10))
  1368. a = np.array([[-2., -3., 1., 0., 0., 0., 0., 0.],
  1369. [1., 0., 0., 0., 0., 0., 0., 0.],
  1370. [0., -4., 0., -5., 1., 0., 0., 0.],
  1371. [0., 1., 0., 0., 0., 0., 0., 0.],
  1372. [0., 0., 0., -6., 0., -7., 1., 0.],
  1373. [0., 0., 0., 1., 0., 0., 0., 0.],
  1374. [0., 0., 0., 0., 0., -8., 0., -9.],
  1375. [0., 0., 0., 0., 0., 1., 0., 0.]])*1.j
  1376. assert_allclose(det(a), 9.)
  1377. # g and G dtypes are handled differently in windows and other platforms
  1378. @pytest.mark.parametrize('typ', [x for x in np.typecodes['All'][:20]
  1379. if x not in 'gG'])
  1380. def test_sample_compatible_dtype_input(self, typ):
  1381. rng = np.random.default_rng(1680305949878959)
  1382. n = 4
  1383. a = rng.random([n, n]).astype(typ) # value is not important
  1384. assert isinstance(det(a), (np.float64 | np.complex128))
  1385. def test_incompatible_dtype_input(self):
  1386. # Double backslashes needed for escaping pytest regex.
  1387. msg = 'cannot be cast to float\\(32, 64\\)'
  1388. for c, t in zip('SUO', ['bytes8', 'str32', 'object']):
  1389. with assert_raises(TypeError, match=msg):
  1390. det(np.array([['a', 'b']]*2, dtype=c))
  1391. with assert_raises(TypeError, match=msg):
  1392. det(np.array([[b'a', b'b']]*2, dtype='V'))
  1393. with assert_raises(TypeError, match=msg):
  1394. det(np.array([[100, 200]]*2, dtype='datetime64[s]'))
  1395. with assert_raises(TypeError, match=msg):
  1396. det(np.array([[100, 200]]*2, dtype='timedelta64[s]'))
  1397. def test_empty_edge_cases(self):
  1398. assert_allclose(det(np.empty([0, 0])), 1.)
  1399. assert_allclose(det(np.empty([0, 0, 0])), np.array([]))
  1400. assert_allclose(det(np.empty([3, 0, 0])), np.array([1., 1., 1.]))
  1401. with assert_raises(ValueError, match='Last 2 dimensions'):
  1402. det(np.empty([0, 0, 3]))
  1403. with assert_raises(ValueError, match='at least two-dimensional'):
  1404. det(np.array([]))
  1405. with assert_raises(ValueError, match='Last 2 dimensions'):
  1406. det(np.array([[]]))
  1407. with assert_raises(ValueError, match='Last 2 dimensions'):
  1408. det(np.array([[[]]]))
  1409. @pytest.mark.parametrize('dt', [int, float, np.float32, complex, np.complex64])
  1410. def test_empty_dtype(self, dt):
  1411. a = np.empty((0, 0), dtype=dt)
  1412. d = det(a)
  1413. assert d.shape == ()
  1414. assert d.dtype == det(np.eye(2, dtype=dt)).dtype
  1415. a = np.empty((3, 0, 0), dtype=dt)
  1416. d = det(a)
  1417. assert d.shape == (3,)
  1418. assert d.dtype == det(np.zeros((3, 1, 1), dtype=dt)).dtype
  1419. def test_overwrite_a(self):
  1420. # If all conditions are met then input should be overwritten;
  1421. # - dtype is one of 'fdFD'
  1422. # - C-contiguous
  1423. # - writeable
  1424. a = np.arange(9).reshape(3, 3).astype(np.float32)
  1425. ac = a.copy()
  1426. deta = det(ac, overwrite_a=True)
  1427. assert_allclose(deta, 0.)
  1428. assert not (a == ac).all()
  1429. def test_readonly_array(self):
  1430. a = np.array([[2., 0., 1.], [5., 3., -1.], [1., 1., 1.]])
  1431. a.setflags(write=False)
  1432. # overwrite_a will be overridden
  1433. assert_allclose(det(a, overwrite_a=True), 10.)
  1434. def test_simple_check_finite(self):
  1435. a = [[1, 2], [3, np.inf]]
  1436. with assert_raises(ValueError, match='array must not contain'):
  1437. det(a)
  1438. def direct_lstsq(a, b, cmplx=0):
  1439. at = transpose(a)
  1440. if cmplx:
  1441. at = conjugate(at)
  1442. a1 = dot(at, a)
  1443. b1 = dot(at, b)
  1444. return solve(a1, b1)
class TestLstsq:
    """Tests for ``lstsq`` across LAPACK drivers, dtypes and overwrite flags."""

    # Drivers exercised by the tests; each value (including None) is passed
    # to lstsq unchanged as ``lapack_driver``.
    lapack_drivers = ('gelsd', 'gelss', 'gelsy', None)

    def test_simple_exact(self):
        """Square 2x2 systems: reported rank is 2 and ``a @ x`` matches ``b``."""
        for dtype in REAL_DTYPES:
            a = np.array([[1, 20], [-30, 4]], dtype=dtype)
            for lapack_driver in TestLstsq.lapack_drivers:
                for overwrite in (True, False):
                    # Right-hand sides: a 2x2 matrix, a vector, another matrix.
                    for bt in (((1, 0), (0, 1)), (1, 0),
                               ((2, 1), (-30, 4))):
                        # Store values in case they are overwritten
                        # later
                        a1 = a.copy()
                        b = np.array(bt, dtype=dtype)
                        b1 = b.copy()
                        out = lstsq(a1, b1,
                                    lapack_driver=lapack_driver,
                                    overwrite_a=overwrite,
                                    overwrite_b=overwrite)
                        x = out[0]
                        r = out[2]
                        assert_(r == 2,
                                f'expected efficient rank 2, got {r}')
                        assert_allclose(dot(a, x), b,
                                        atol=25 * _eps_cast(a1.dtype),
                                        rtol=25 * _eps_cast(a1.dtype),
                                        err_msg=f"driver: {lapack_driver}")

    def test_simple_overdet(self):
        """Real 3x2 system: check rank, residuals and the solution values."""
        for dtype in REAL_DTYPES:
            a = np.array([[1, 2], [4, 5], [3, 4]], dtype=dtype)
            b = np.array([1, 2, 3], dtype=dtype)
            for lapack_driver in TestLstsq.lapack_drivers:
                for overwrite in (True, False):
                    # Store values in case they are overwritten later
                    a1 = a.copy()
                    b1 = b.copy()
                    out = lstsq(a1, b1, lapack_driver=lapack_driver,
                                overwrite_a=overwrite,
                                overwrite_b=overwrite)
                    x = out[0]
                    # For 'gelsy' the residuals are computed by hand instead
                    # of being read from out[1].
                    if lapack_driver == 'gelsy':
                        residuals = np.sum((b - a.dot(x))**2)
                    else:
                        residuals = out[1]
                    r = out[2]
                    assert_(r == 2, f'expected efficient rank 2, got {r}')
                    assert_allclose(abs((dot(a, x) - b)**2).sum(axis=0),
                                    residuals,
                                    rtol=25 * _eps_cast(a1.dtype),
                                    atol=25 * _eps_cast(a1.dtype),
                                    err_msg=f"driver: {lapack_driver}")
                    assert_allclose(x, (-0.428571428571429, 0.85714285714285),
                                    rtol=25 * _eps_cast(a1.dtype),
                                    atol=25 * _eps_cast(a1.dtype),
                                    err_msg=f"driver: {lapack_driver}")

    def test_simple_overdet_complex(self):
        """Complex 3x2 system: check rank, residuals and the solution values."""
        for dtype in COMPLEX_DTYPES:
            a = np.array([[1+2j, 2], [4, 5], [3, 4]], dtype=dtype)
            b = np.array([1, 2+4j, 3], dtype=dtype)
            for lapack_driver in TestLstsq.lapack_drivers:
                for overwrite in (True, False):
                    # Store values in case they are overwritten later
                    a1 = a.copy()
                    b1 = b.copy()
                    out = lstsq(a1, b1, lapack_driver=lapack_driver,
                                overwrite_a=overwrite,
                                overwrite_b=overwrite)
                    x = out[0]
                    # For 'gelsy' the residuals are computed by hand instead
                    # of being read from out[1].
                    if lapack_driver == 'gelsy':
                        res = b - a.dot(x)
                        residuals = np.sum(res * res.conj())
                    else:
                        residuals = out[1]
                    r = out[2]
                    assert_(r == 2, f'expected efficient rank 2, got {r}')
                    assert_allclose(abs((dot(a, x) - b)**2).sum(axis=0),
                                    residuals,
                                    rtol=25 * _eps_cast(a1.dtype),
                                    atol=25 * _eps_cast(a1.dtype),
                                    err_msg=f"driver: {lapack_driver}")
                    assert_allclose(
                        x, (-0.4831460674157303 + 0.258426966292135j,
                            0.921348314606741 + 0.292134831460674j),
                        rtol=25 * _eps_cast(a1.dtype),
                        atol=25 * _eps_cast(a1.dtype),
                        err_msg=f"driver: {lapack_driver}")

    def test_simple_underdet(self):
        """Real 2x3 system: check rank and the returned solution values."""
        for dtype in REAL_DTYPES:
            a = np.array([[1, 2, 3], [4, 5, 6]], dtype=dtype)
            b = np.array([1, 2], dtype=dtype)
            for lapack_driver in TestLstsq.lapack_drivers:
                for overwrite in (True, False):
                    # Store values in case they are overwritten later
                    a1 = a.copy()
                    b1 = b.copy()
                    out = lstsq(a1, b1, lapack_driver=lapack_driver,
                                overwrite_a=overwrite,
                                overwrite_b=overwrite)
                    x = out[0]
                    r = out[2]
                    assert_(r == 2, f'expected efficient rank 2, got {r}')
                    assert_allclose(x, (-0.055555555555555, 0.111111111111111,
                                        0.277777777777777),
                                    rtol=25 * _eps_cast(a1.dtype),
                                    atol=25 * _eps_cast(a1.dtype),
                                    err_msg=f"driver: {lapack_driver}")

    @pytest.mark.parametrize("dtype", REAL_DTYPES)
    @pytest.mark.parametrize("n", (20, 200))
    @pytest.mark.parametrize("lapack_driver", lapack_drivers)
    @pytest.mark.parametrize("overwrite", (True, False))
    def test_random_exact(self, dtype, n, lapack_driver, overwrite):
        """Random real n x n systems with an enlarged diagonal are solved."""
        rng = np.random.RandomState(1234)

        a = np.asarray(rng.random([n, n]), dtype=dtype)
        # Enlarge the diagonal entries of the random matrix.
        for i in range(n):
            a[i, i] = 20 * (0.1 + a[i, i])
        # Four independent right-hand sides with three columns each.
        for i in range(4):
            b = np.asarray(rng.random([n, 3]), dtype=dtype)
            # Store values in case they are overwritten later
            a1 = a.copy()
            b1 = b.copy()
            out = lstsq(a1, b1,
                        lapack_driver=lapack_driver,
                        overwrite_a=overwrite,
                        overwrite_b=overwrite)
            x = out[0]
            r = out[2]
            assert_(r == n, f'expected efficient rank {n}, '
                    f'got {r}')
            # float32 is checked at 500*eps, any other dtype at 1000*eps.
            if dtype is np.float32:
                assert_allclose(
                          dot(a, x), b,
                          rtol=500 * _eps_cast(a1.dtype),
                          atol=500 * _eps_cast(a1.dtype),
                          err_msg=f"driver: {lapack_driver}")
            else:
                assert_allclose(
                          dot(a, x), b,
                          rtol=1000 * _eps_cast(a1.dtype),
                          atol=1000 * _eps_cast(a1.dtype),
                          err_msg=f"driver: {lapack_driver}")

    @pytest.mark.skipif(IS_MUSL, reason="may segfault on Alpine, see gh-17630")
    @pytest.mark.parametrize("dtype", COMPLEX_DTYPES)
    @pytest.mark.parametrize("n", (20, 200))
    @pytest.mark.parametrize("lapack_driver", lapack_drivers)
    @pytest.mark.parametrize("overwrite", (True, False))
    def test_random_complex_exact(self, dtype, n, lapack_driver, overwrite):
        """Random complex n x n systems with an enlarged diagonal are solved."""
        rng = np.random.RandomState(1234)

        a = np.asarray(rng.random([n, n]) + 1j*rng.random([n, n]),
                       dtype=dtype)
        # Enlarge the diagonal entries of the random matrix.
        for i in range(n):
            a[i, i] = 20 * (0.1 + a[i, i])
        # Two independent right-hand sides with three columns each.
        for i in range(2):
            b = np.asarray(rng.random([n, 3]), dtype=dtype)
            # Store values in case they are overwritten later
            a1 = a.copy()
            b1 = b.copy()
            out = lstsq(a1, b1, lapack_driver=lapack_driver,
                        overwrite_a=overwrite,
                        overwrite_b=overwrite)
            x = out[0]
            r = out[2]
            assert_(r == n, f'expected efficient rank {n}, '
                    f'got {r}')
            # complex64 is checked at 400*eps, any other dtype at 1000*eps.
            if dtype is np.complex64:
                assert_allclose(
                          dot(a, x), b,
                          rtol=400 * _eps_cast(a1.dtype),
                          atol=400 * _eps_cast(a1.dtype),
                          err_msg=f"driver: {lapack_driver}")
            else:
                assert_allclose(
                          dot(a, x), b,
                          rtol=1000 * _eps_cast(a1.dtype),
                          atol=1000 * _eps_cast(a1.dtype),
                          err_msg=f"driver: {lapack_driver}")

    def test_random_overdet(self):
        """Random real tall systems agree with the normal-equation solution."""
        rng = np.random.RandomState(1234)
        for dtype in REAL_DTYPES:
            for (n, m) in ((20, 15), (200, 2)):
                for lapack_driver in TestLstsq.lapack_drivers:
                    for overwrite in (True, False):
                        a = np.asarray(rng.random([n, m]), dtype=dtype)
                        # Enlarge the leading diagonal entries.
                        for i in range(m):
                            a[i, i] = 20 * (0.1 + a[i, i])
                        for i in range(4):
                            b = np.asarray(rng.random([n, 3]), dtype=dtype)
                            # Store values in case they are overwritten later
                            a1 = a.copy()
                            b1 = b.copy()
                            out = lstsq(a1, b1,
                                        lapack_driver=lapack_driver,
                                        overwrite_a=overwrite,
                                        overwrite_b=overwrite)
                            x = out[0]
                            r = out[2]
                            assert_(r == m, f'expected efficient rank {m}, '
                                    f'got {r}')
                            assert_allclose(
                                          x, direct_lstsq(a, b, cmplx=0),
                                          rtol=25 * _eps_cast(a1.dtype),
                                          atol=25 * _eps_cast(a1.dtype),
                                          err_msg=f"driver: {lapack_driver}")

    def test_random_complex_overdet(self):
        """Random complex tall systems agree with the normal-equation solution."""
        rng = np.random.RandomState(1234)
        for dtype in COMPLEX_DTYPES:
            for (n, m) in ((20, 15), (200, 2)):
                for lapack_driver in TestLstsq.lapack_drivers:
                    for overwrite in (True, False):
                        a = np.asarray(rng.random([n, m]) + 1j*rng.random([n, m]),
                                       dtype=dtype)
                        # Enlarge the leading diagonal entries.
                        for i in range(m):
                            a[i, i] = 20 * (0.1 + a[i, i])
                        for i in range(2):
                            b = np.asarray(rng.random([n, 3]), dtype=dtype)
                            # Store values in case they are overwritten
                            # later
                            a1 = a.copy()
                            b1 = b.copy()
                            out = lstsq(a1, b1,
                                        lapack_driver=lapack_driver,
                                        overwrite_a=overwrite,
                                        overwrite_b=overwrite)
                            x = out[0]
                            r = out[2]
                            assert_(r == m, f'expected efficient rank {m}, '
                                    f'got {r}')
                            assert_allclose(
                                      x, direct_lstsq(a, b, cmplx=1),
                                      rtol=25 * _eps_cast(a1.dtype),
                                      atol=25 * _eps_cast(a1.dtype),
                                      err_msg=f"driver: {lapack_driver}")

    def test_check_finite(self):
        """Exact 2x2 solves succeed with ``check_finite`` both on and off."""
        with warnings.catch_warnings():
            # On (some) OSX this tests triggers a warning (gh-7538)
            warnings.filterwarnings("ignore",
                                    "internal gelsd driver lwork query error,.*"
                                    "Falling back to 'gelss' driver.", RuntimeWarning)

            # The solves below run inside the catch_warnings block so that
            # the filter installed above applies to them.
            at = np.array(((1, 20), (-30, 4)))
            for dtype, bt, lapack_driver, overwrite, check_finite in \
                itertools.product(REAL_DTYPES,
                                  (((1, 0), (0, 1)), (1, 0), ((2, 1), (-30, 4))),
                                  TestLstsq.lapack_drivers,
                                  (True, False),
                                  (True, False)):

                a = at.astype(dtype)
                b = np.array(bt, dtype=dtype)
                # Store values in case they are overwritten
                # later
                a1 = a.copy()
                b1 = b.copy()
                out = lstsq(a1, b1, lapack_driver=lapack_driver,
                            check_finite=check_finite, overwrite_a=overwrite,
                            overwrite_b=overwrite)
                x = out[0]
                r = out[2]
                assert_(r == 2, f'expected efficient rank 2, got {r}')
                assert_allclose(dot(a, x), b,
                                rtol=25 * _eps_cast(a.dtype),
                                atol=25 * _eps_cast(a.dtype),
                                err_msg=f"driver: {lapack_driver}")

    def test_empty(self):
        """Empty systems give zero solutions, rank 0 and no singular values."""
        for a_shape, b_shape in (((0, 2), (0,)),
                                 ((0, 4), (0, 2)),
                                 ((4, 0), (4,)),
                                 ((4, 0), (4, 2))):
            b = np.ones(b_shape)
            x, residues, rank, s = lstsq(np.zeros(a_shape), b)
            assert_equal(x, np.zeros((a_shape[1],) + b_shape[1:]))
            # When ``a`` has columns the residues are expected to be empty;
            # when it has none they equal the squared column norms of ``b``.
            residues_should_be = (np.empty((0,)) if a_shape[1]
                                  else np.linalg.norm(b, axis=0)**2)
            assert_equal(residues, residues_should_be)
            assert_(rank == 0, 'expected rank 0')
            assert_equal(s, np.empty((0,)))

    @pytest.mark.parametrize('dt_a', [int, float, np.float32, complex, np.complex64])
    @pytest.mark.parametrize('dt_b', [int, float, np.float32, complex, np.complex64])
    def test_empty_dtype(self, dt_a, dt_b):
        """Empty inputs give the same solution dtype as non-empty inputs."""
        a = np.empty((0, 0), dtype=dt_a)
        b = np.empty(0, dtype=dt_b)
        x, residues, rank, s = lstsq(a, b)

        assert x.size == 0
        # Reference dtype taken from a 2x2 identity system of the same dtypes.
        dt_nonempty = lstsq(np.eye(2, dtype=dt_a), np.ones(2, dtype=dt_b))[0].dtype
        assert x.dtype == dt_nonempty
  1726. class TestPinv:
  1727. def test_simple_real(self):
  1728. a = array([[1, 2, 3], [4, 5, 6], [7, 8, 10]], dtype=float)
  1729. a_pinv = pinv(a)
  1730. assert_array_almost_equal(dot(a, a_pinv), np.eye(3))
  1731. def test_simple_complex(self):
  1732. a = (array([[1, 2, 3], [4, 5, 6], [7, 8, 10]],
  1733. dtype=float) + 1j * array([[10, 8, 7], [6, 5, 4], [3, 2, 1]],
  1734. dtype=float))
  1735. a_pinv = pinv(a)
  1736. assert_array_almost_equal(dot(a, a_pinv), np.eye(3))
  1737. def test_simple_singular(self):
  1738. a = array([[1, 2, 3], [4, 5, 6], [7, 8, 9]], dtype=float)
  1739. a_pinv = pinv(a)
  1740. expected = array([[-6.38888889e-01, -1.66666667e-01, 3.05555556e-01],
  1741. [-5.55555556e-02, 1.30136518e-16, 5.55555556e-02],
  1742. [5.27777778e-01, 1.66666667e-01, -1.94444444e-01]])
  1743. assert_array_almost_equal(a_pinv, expected)
  1744. def test_simple_cols(self):
  1745. a = array([[1, 2, 3], [4, 5, 6]], dtype=float)
  1746. a_pinv = pinv(a)
  1747. expected = array([[-0.94444444, 0.44444444],
  1748. [-0.11111111, 0.11111111],
  1749. [0.72222222, -0.22222222]])
  1750. assert_array_almost_equal(a_pinv, expected)
  1751. def test_simple_rows(self):
  1752. a = array([[1, 2], [3, 4], [5, 6]], dtype=float)
  1753. a_pinv = pinv(a)
  1754. expected = array([[-1.33333333, -0.33333333, 0.66666667],
  1755. [1.08333333, 0.33333333, -0.41666667]])
  1756. assert_array_almost_equal(a_pinv, expected)
  1757. def test_check_finite(self):
  1758. a = array([[1, 2, 3], [4, 5, 6.], [7, 8, 10]])
  1759. a_pinv = pinv(a, check_finite=False)
  1760. assert_array_almost_equal(dot(a, a_pinv), np.eye(3))
  1761. def test_native_list_argument(self):
  1762. a = [[1, 2, 3], [4, 5, 6], [7, 8, 9]]
  1763. a_pinv = pinv(a)
  1764. expected = array([[-6.38888889e-01, -1.66666667e-01, 3.05555556e-01],
  1765. [-5.55555556e-02, 1.30136518e-16, 5.55555556e-02],
  1766. [5.27777778e-01, 1.66666667e-01, -1.94444444e-01]])
  1767. assert_array_almost_equal(a_pinv, expected)
  1768. def test_atol_rtol(self):
  1769. rng = np.random.default_rng(1234)
  1770. n = 12
  1771. # get a random ortho matrix for shuffling
  1772. q, _ = qr(rng.random((n, n)))
  1773. a_m = np.arange(35.0).reshape(7, 5)
  1774. a = a_m.copy()
  1775. a[0, 0] = 0.001
  1776. atol = 1e-5
  1777. rtol = 0.05
  1778. # svds of a_m is ~ [116.906, 4.234, tiny, tiny, tiny]
  1779. # svds of a is ~ [116.906, 4.234, 4.62959e-04, tiny, tiny]
  1780. # Just abs cutoff such that we arrive at a_modified
  1781. a_p = pinv(a_m, atol=atol, rtol=0.)
  1782. adiff1 = a @ a_p @ a - a
  1783. adiff2 = a_m @ a_p @ a_m - a_m
  1784. # Now adiff1 should be around atol value while adiff2 should be
  1785. # relatively tiny
  1786. assert_allclose(np.linalg.norm(adiff1), 5e-4, atol=5.e-4)
  1787. assert_allclose(np.linalg.norm(adiff2), 5e-14, atol=5.e-14)
  1788. # Now do the same but remove another sv ~4.234 via rtol
  1789. a_p = pinv(a_m, atol=atol, rtol=rtol)
  1790. adiff1 = a @ a_p @ a - a
  1791. adiff2 = a_m @ a_p @ a_m - a_m
  1792. assert_allclose(np.linalg.norm(adiff1), 4.233, rtol=0.01)
  1793. assert_allclose(np.linalg.norm(adiff2), 4.233, rtol=0.01)
  1794. @pytest.mark.parametrize('dt', [float, np.float32, complex, np.complex64])
  1795. def test_empty(self, dt):
  1796. a = np.empty((0, 0), dtype=dt)
  1797. a_pinv = pinv(a)
  1798. assert a_pinv.size == 0
  1799. assert a_pinv.dtype == pinv(np.eye(2, dtype=dt)).dtype
  1800. class TestPinvSymmetric:
  1801. def test_simple_real(self):
  1802. a = array([[1, 2, 3], [4, 5, 6], [7, 8, 10]], dtype=float)
  1803. a = np.dot(a, a.T)
  1804. a_pinv = pinvh(a)
  1805. assert_array_almost_equal(np.dot(a, a_pinv), np.eye(3))
  1806. def test_nonpositive(self):
  1807. a = array([[1, 2, 3], [4, 5, 6], [7, 8, 9]], dtype=float)
  1808. a = np.dot(a, a.T)
  1809. u, s, vt = np.linalg.svd(a)
  1810. s[0] *= -1
  1811. a = np.dot(u * s, vt) # a is now symmetric non-positive and singular
  1812. a_pinv = pinv(a)
  1813. a_pinvh = pinvh(a)
  1814. assert_array_almost_equal(a_pinv, a_pinvh)
  1815. def test_simple_complex(self):
  1816. a = (array([[1, 2, 3], [4, 5, 6], [7, 8, 10]],
  1817. dtype=float) + 1j * array([[10, 8, 7], [6, 5, 4], [3, 2, 1]],
  1818. dtype=float))
  1819. a = np.dot(a, a.conj().T)
  1820. a_pinv = pinvh(a)
  1821. assert_array_almost_equal(np.dot(a, a_pinv), np.eye(3))
  1822. def test_native_list_argument(self):
  1823. a = array([[1, 2, 3], [4, 5, 6], [7, 8, 10]], dtype=float)
  1824. a = np.dot(a, a.T)
  1825. a_pinv = pinvh(a.tolist())
  1826. assert_array_almost_equal(np.dot(a, a_pinv), np.eye(3))
  1827. def test_zero_eigenvalue(self):
  1828. # https://github.com/scipy/scipy/issues/12515
  1829. # the SYEVR eigh driver may give the zero eigenvalue > eps
  1830. a = np.array([[1, -1, 0], [-1, 2, -1], [0, -1, 1]])
  1831. p = pinvh(a)
  1832. assert_allclose(p @ a @ p, p, atol=1e-15)
  1833. assert_allclose(a @ p @ a, a, atol=1e-15)
  1834. def test_atol_rtol(self):
  1835. rng = np.random.default_rng(1234)
  1836. n = 12
  1837. # get a random ortho matrix for shuffling
  1838. q, _ = qr(rng.random((n, n)))
  1839. a = np.diag([4, 3, 2, 1, 0.99e-4, 0.99e-5] + [0.99e-6]*(n-6))
  1840. a = q.T @ a @ q
  1841. a_m = np.diag([4, 3, 2, 1, 0.99e-4, 0.] + [0.]*(n-6))
  1842. a_m = q.T @ a_m @ q
  1843. atol = 1e-5
  1844. rtol = (4.01e-4 - 4e-5)/4
  1845. # Just abs cutoff such that we arrive at a_modified
  1846. a_p = pinvh(a, atol=atol, rtol=0.)
  1847. adiff1 = a @ a_p @ a - a
  1848. adiff2 = a_m @ a_p @ a_m - a_m
  1849. # Now adiff1 should dance around atol value since truncation
  1850. # while adiff2 should be relatively tiny
  1851. assert_allclose(norm(adiff1), atol, rtol=0.1)
  1852. assert_allclose(norm(adiff2), 1e-12, atol=1e-11)
  1853. # Now do the same but through rtol cancelling atol value
  1854. a_p = pinvh(a, atol=atol, rtol=rtol)
  1855. adiff1 = a @ a_p @ a - a
  1856. adiff2 = a_m @ a_p @ a_m - a_m
  1857. # adiff1 and adiff2 should be elevated to ~1e-4 due to mismatch
  1858. assert_allclose(norm(adiff1), 1e-4, rtol=0.1)
  1859. assert_allclose(norm(adiff2), 1e-4, rtol=0.1)
  1860. @pytest.mark.parametrize('dt', [float, np.float32, complex, np.complex64])
  1861. def test_empty(self, dt):
  1862. a = np.empty((0, 0), dtype=dt)
  1863. a_pinv = pinvh(a)
  1864. assert a_pinv.size == 0
  1865. assert a_pinv.dtype == pinv(np.eye(2, dtype=dt)).dtype
  1866. @pytest.mark.parametrize('scale', (1e-20, 1., 1e20))
  1867. @pytest.mark.parametrize('pinv_', (pinv, pinvh))
  1868. def test_auto_rcond(scale, pinv_):
  1869. x = np.array([[1, 0], [0, 1e-10]]) * scale
  1870. expected = np.diag(1. / np.diag(x))
  1871. x_inv = pinv_(x)
  1872. assert_allclose(x_inv, expected)
  1873. class TestVectorNorms:
  1874. def test_types(self):
  1875. for dtype in np.typecodes['AllFloat']:
  1876. x = np.array([1, 2, 3], dtype=dtype)
  1877. tol = max(1e-15, np.finfo(dtype).eps.real * 20)
  1878. assert_allclose(norm(x), np.sqrt(14), rtol=tol)
  1879. assert_allclose(norm(x, 2), np.sqrt(14), rtol=tol)
  1880. for dtype in np.typecodes['Complex']:
  1881. x = np.array([1j, 2j, 3j], dtype=dtype)
  1882. tol = max(1e-15, np.finfo(dtype).eps.real * 20)
  1883. assert_allclose(norm(x), np.sqrt(14), rtol=tol)
  1884. assert_allclose(norm(x, 2), np.sqrt(14), rtol=tol)
  1885. def test_overflow(self):
  1886. # unlike numpy's norm, this one is
  1887. # safer on overflow
  1888. a = array([1e20], dtype=float32)
  1889. assert_almost_equal(norm(a), a)
  1890. def test_stable(self):
  1891. # more stable than numpy's norm
  1892. a = array([1e4] + [1]*10000, dtype=float32)
  1893. try:
  1894. # snrm in double precision; we obtain the same as for float64
  1895. # -- large atol needed due to varying blas implementations
  1896. assert_allclose(norm(a) - 1e4, 0.5, atol=1e-2)
  1897. except AssertionError:
  1898. # snrm implemented in single precision, == np.linalg.norm result
  1899. msg = ": Result should equal either 0.0 or 0.5 (depending on " \
  1900. "implementation of snrm2)."
  1901. assert_almost_equal(norm(a) - 1e4, 0.0, err_msg=msg)
  1902. def test_zero_norm(self):
  1903. assert_equal(norm([1, 0, 3], 0), 2)
  1904. assert_equal(norm([1, 2, 3], 0), 3)
  1905. def test_axis_kwd(self):
  1906. a = np.array([[[2, 1], [3, 4]]] * 2, 'd')
  1907. assert_allclose(norm(a, axis=1), [[3.60555128, 4.12310563]] * 2)
  1908. assert_allclose(norm(a, 1, axis=1), [[5.] * 2] * 2)
  1909. def test_keepdims_kwd(self):
  1910. a = np.array([[[2, 1], [3, 4]]] * 2, 'd')
  1911. b = norm(a, axis=1, keepdims=True)
  1912. assert_allclose(b, [[[3.60555128, 4.12310563]]] * 2)
  1913. assert_(b.shape == (2, 1, 2))
  1914. assert_allclose(norm(a, 1, axis=2, keepdims=True), [[[3.], [7.]]] * 2)
  1915. @pytest.mark.skipif(not HAS_ILP64, reason="64-bit BLAS required")
  1916. def test_large_vector(self):
  1917. check_free_memory(free_mb=17000)
  1918. x = np.zeros([2**31], dtype=np.float64)
  1919. x[-1] = 1
  1920. res = norm(x)
  1921. del x
  1922. assert_allclose(res, 1.0)
  1923. class TestMatrixNorms:
  1924. def test_matrix_norms(self):
  1925. # Not all of these are matrix norms in the most technical sense.
  1926. rng = np.random.default_rng(1234)
  1927. for n, m in (1, 1), (1, 3), (3, 1), (4, 4), (4, 5), (5, 4):
  1928. for t in np.float32, np.float64, np.complex64, np.complex128, np.int64:
  1929. A = 10 * rng.standard_normal((n, m)).astype(t)
  1930. if np.issubdtype(A.dtype, np.complexfloating):
  1931. A += 10j * rng.standard_normal((n, m))
  1932. t_high = np.complex128
  1933. else:
  1934. t_high = np.float64
  1935. for order in (None, 'fro', 1, -1, 2, -2, np.inf, -np.inf):
  1936. actual = norm(A, ord=order)
  1937. desired = np.linalg.norm(A, ord=order)
  1938. # SciPy may return higher precision matrix norms.
  1939. # This is a consequence of using LAPACK.
  1940. if not np.allclose(actual, desired):
  1941. desired = np.linalg.norm(A.astype(t_high), ord=order)
  1942. assert_allclose(actual, desired)
  1943. def test_axis_kwd(self):
  1944. a = np.array([[[2, 1], [3, 4]]] * 2, 'd')
  1945. b = norm(a, ord=np.inf, axis=(1, 0))
  1946. c = norm(np.swapaxes(a, 0, 1), ord=np.inf, axis=(0, 1))
  1947. d = norm(a, ord=1, axis=(0, 1))
  1948. assert_allclose(b, c)
  1949. assert_allclose(c, d)
  1950. assert_allclose(b, d)
  1951. assert_(b.shape == c.shape == d.shape)
  1952. b = norm(a, ord=1, axis=(1, 0))
  1953. c = norm(np.swapaxes(a, 0, 1), ord=1, axis=(0, 1))
  1954. d = norm(a, ord=np.inf, axis=(0, 1))
  1955. assert_allclose(b, c)
  1956. assert_allclose(c, d)
  1957. assert_allclose(b, d)
  1958. assert_(b.shape == c.shape == d.shape)
  1959. def test_keepdims_kwd(self):
  1960. a = np.arange(120, dtype='d').reshape(2, 3, 4, 5)
  1961. b = norm(a, ord=np.inf, axis=(1, 0), keepdims=True)
  1962. c = norm(a, ord=1, axis=(0, 1), keepdims=True)
  1963. assert_allclose(b, c)
  1964. assert_(b.shape == c.shape)
  1965. def test_empty(self):
  1966. a = np.empty((0, 0))
  1967. assert_allclose(norm(a), 0.)
  1968. assert_allclose(norm(a, axis=0), np.zeros((0,)))
  1969. assert_allclose(norm(a, keepdims=True), np.zeros((1, 1)))
  1970. a = np.empty((0, 3))
  1971. assert_allclose(norm(a), 0.)
  1972. assert_allclose(norm(a, axis=0), np.zeros((3,)))
  1973. assert_allclose(norm(a, keepdims=True), np.zeros((1, 1)))
  1974. class TestOverwrite:
  1975. def test_solve(self):
  1976. assert_no_overwrite(solve, [(3, 3), (3,)])
  1977. def test_solve_triangular(self):
  1978. assert_no_overwrite(solve_triangular, [(3, 3), (3,)])
  1979. def test_solve_banded(self):
  1980. assert_no_overwrite(lambda ab, b: solve_banded((2, 1), ab, b),
  1981. [(4, 6), (6,)])
  1982. def test_solveh_banded(self):
  1983. assert_no_overwrite(solveh_banded, [(2, 6), (6,)])
  1984. def test_inv(self):
  1985. assert_no_overwrite(inv, [(3, 3)])
  1986. def test_det(self):
  1987. assert_no_overwrite(det, [(3, 3)])
  1988. def test_lstsq(self):
  1989. assert_no_overwrite(lstsq, [(3, 2), (3,)])
  1990. def test_pinv(self):
  1991. assert_no_overwrite(pinv, [(3, 3)])
  1992. def test_pinvh(self):
  1993. assert_no_overwrite(pinvh, [(3, 3)])
  1994. class TestSolveCirculant:
  1995. def test_basic1(self):
  1996. c = np.array([1, 2, 3, 5])
  1997. b = np.array([1, -1, 1, 0])
  1998. x = solve_circulant(c, b)
  1999. y = solve(circulant(c), b)
  2000. assert_allclose(x, y)
  2001. def test_basic2(self):
  2002. # b is a 2-d matrix.
  2003. c = np.array([1, 2, -3, -5])
  2004. b = np.arange(12).reshape(4, 3)
  2005. x = solve_circulant(c, b)
  2006. y = solve(circulant(c), b)
  2007. assert_allclose(x, y)
  2008. def test_basic3(self):
  2009. # b is a 3-d matrix.
  2010. c = np.array([1, 2, -3, -5])
  2011. b = np.arange(24).reshape(4, 3, 2)
  2012. x = solve_circulant(c, b)
  2013. y = solve(circulant(c), b.reshape(4, -1)).reshape(b.shape)
  2014. assert_allclose(x, y)
  2015. def test_complex(self):
  2016. # Complex b and c
  2017. c = np.array([1+2j, -3, 4j, 5])
  2018. b = np.arange(8).reshape(4, 2) + 0.5j
  2019. x = solve_circulant(c, b)
  2020. y = solve(circulant(c), b)
  2021. assert_allclose(x, y)
  2022. def test_random_b_and_c(self):
  2023. # Random b and c
  2024. rng = np.random.RandomState(54321)
  2025. c = rng.standard_normal(50)
  2026. b = rng.standard_normal(50)
  2027. x = solve_circulant(c, b)
  2028. y = solve(circulant(c), b)
  2029. assert_allclose(x, y)
  2030. def test_singular(self):
  2031. # c gives a singular circulant matrix.
  2032. c = np.array([1, 1, 0, 0])
  2033. b = np.array([1, 2, 3, 4])
  2034. x = solve_circulant(c, b, singular='lstsq')
  2035. y, res, rnk, s = lstsq(circulant(c), b)
  2036. assert_allclose(x, y)
  2037. assert_raises(LinAlgError, solve_circulant, x, y)
  2038. def test_axis_args(self):
  2039. # Test use of caxis, baxis and outaxis.
  2040. # c has shape (2, 1, 4)
  2041. c = np.array([[[-1, 2.5, 3, 3.5]], [[1, 6, 6, 6.5]]])
  2042. # b has shape (3, 4)
  2043. b = np.array([[0, 0, 1, 1], [1, 1, 0, 0], [1, -1, 0, 0]])
  2044. x = solve_circulant(c, b, baxis=1)
  2045. assert_equal(x.shape, (4, 2, 3))
  2046. expected = np.empty_like(x)
  2047. expected[:, 0, :] = solve(circulant(c[0].ravel()), b.T)
  2048. expected[:, 1, :] = solve(circulant(c[1].ravel()), b.T)
  2049. assert_allclose(x, expected)
  2050. x = solve_circulant(c, b, baxis=1, outaxis=-1)
  2051. assert_equal(x.shape, (2, 3, 4))
  2052. assert_allclose(np.moveaxis(x, -1, 0), expected)
  2053. # np.swapaxes(c, 1, 2) has shape (2, 4, 1); b.T has shape (4, 3).
  2054. x = solve_circulant(np.swapaxes(c, 1, 2), b.T, caxis=1)
  2055. assert_equal(x.shape, (4, 2, 3))
  2056. assert_allclose(x, expected)
  2057. def test_native_list_arguments(self):
  2058. # Same as test_basic1 using python's native list.
  2059. c = [1, 2, 3, 5]
  2060. b = [1, -1, 1, 0]
  2061. x = solve_circulant(c, b)
  2062. y = solve(circulant(c), b)
  2063. assert_allclose(x, y)
  2064. @pytest.mark.parametrize('dt_c', [int, float, np.float32, complex, np.complex64])
  2065. @pytest.mark.parametrize('dt_b', [int, float, np.float32, complex, np.complex64])
  2066. def test_empty(self, dt_c, dt_b):
  2067. c = np.array([], dtype=dt_c)
  2068. b = np.array([], dtype=dt_b)
  2069. x = solve_circulant(c, b)
  2070. assert x.shape == (0,)
  2071. assert x.dtype == solve_circulant(np.arange(3, dtype=dt_c),
  2072. np.ones(3, dtype=dt_b)).dtype
  2073. b = np.empty((0, 0), dtype=dt_b)
  2074. x1 = solve_circulant(c, b)
  2075. assert x1.shape == (0, 0)
  2076. assert x1.dtype == x.dtype
  2077. class TestMatrix_Balance:
  2078. @skip_xp_invalid_arg
  2079. def test_string_arg(self):
  2080. assert_raises(ValueError, matrix_balance, 'Some string for fail')
  2081. def test_infnan_arg(self):
  2082. assert_raises(ValueError, matrix_balance,
  2083. np.array([[1, 2], [3, np.inf]]))
  2084. assert_raises(ValueError, matrix_balance,
  2085. np.array([[1, 2], [3, np.nan]]))
  2086. def test_scaling(self):
  2087. _, y = matrix_balance(np.array([[1000, 1], [1000, 0]]))
  2088. # Pre/post LAPACK 3.5.0 gives the same result up to an offset
  2089. # since in each case col norm is x1000 greater and
  2090. # 1000 / 32 ~= 1 * 32 hence balanced with 2 ** 5.
  2091. assert_allclose(np.diff(np.log2(np.diag(y))), [5])
  2092. def test_scaling_order(self):
  2093. A = np.array([[1, 0, 1e-4], [1, 1, 1e-2], [1e4, 1e2, 1]])
  2094. x, y = matrix_balance(A)
  2095. assert_allclose(solve(y, A).dot(y), x)
  2096. def test_separate(self):
  2097. _, (y, z) = matrix_balance(np.array([[1000, 1], [1000, 0]]),
  2098. separate=1)
  2099. assert_equal(np.diff(np.log2(y)), [5])
  2100. assert_allclose(z, np.arange(2))
  2101. def test_permutation(self):
  2102. A = block_diag(np.ones((2, 2)), np.tril(np.ones((2, 2))),
  2103. np.ones((3, 3)))
  2104. x, (y, z) = matrix_balance(A, separate=1)
  2105. assert_allclose(y, np.ones_like(y))
  2106. assert_allclose(z, np.array([0, 1, 6, 5, 4, 3, 2]))
  2107. def test_perm_and_scaling(self):
  2108. # Matrix with its diagonal removed
  2109. cases = ( # Case 0
  2110. np.array([[0., 0., 0., 0., 0.000002],
  2111. [0., 0., 0., 0., 0.],
  2112. [2., 2., 0., 0., 0.],
  2113. [2., 2., 0., 0., 0.],
  2114. [0., 0., 0.000002, 0., 0.]]),
  2115. # Case 1 user reported GH-7258
  2116. np.array([[-0.5, 0., 0., 0.],
  2117. [0., -1., 0., 0.],
  2118. [1., 0., -0.5, 0.],
  2119. [0., 1., 0., -1.]]),
  2120. # Case 2 user reported GH-7258
  2121. np.array([[-3., 0., 1., 0.],
  2122. [-1., -1., -0., 1.],
  2123. [-3., -0., -0., 0.],
  2124. [-1., -0., 1., -1.]])
  2125. )
  2126. for A in cases:
  2127. x, y = matrix_balance(A)
  2128. x, (s, p) = matrix_balance(A, separate=1)
  2129. ip = np.empty_like(p)
  2130. ip[p] = np.arange(A.shape[0])
  2131. assert_allclose(y, np.diag(s)[ip, :])
  2132. assert_allclose(solve(y, A).dot(y), x)
  2133. @pytest.mark.parametrize('dt', [int, float, np.float32, complex, np.complex64])
  2134. def test_empty(self, dt):
  2135. a = np.empty((0, 0), dtype=dt)
  2136. b, t = matrix_balance(a)
  2137. assert b.size == 0
  2138. assert t.size == 0
  2139. b_n, t_n = matrix_balance(np.eye(2, dtype=dt))
  2140. assert b.dtype == b_n.dtype
  2141. assert t.dtype == t_n.dtype
  2142. b, (scale, perm) = matrix_balance(a, separate=True)
  2143. assert b.size == 0
  2144. assert scale.size == 0
  2145. assert perm.size == 0
  2146. b_n, (scale_n, perm_n) = matrix_balance(a, separate=True)
  2147. assert b.dtype == b_n.dtype
  2148. assert scale.dtype == scale_n.dtype
  2149. assert perm.dtype == perm_n.dtype