test_utils.py 54 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156115711581159116011611162116311641165116611671168116911701171117211731174117511761177117811791180118111821183118411851186118711881189119011911192119311941195119611971198119912001201120212031204120512061207120812091210121112121213121412151216121712181219122012211222122312241225122612271228122912301231123212331234123512361237123812391240124112421243124412451246124712481249125012511252125312541255125612571258125912601261126212631264126512661267126812691270127112721273127412751276127
71278127912801281128212831284128512861287128812891290129112921293129412951296129712981299130013011302130313041305130613071308130913101311131213131314131513161317131813191320132113221323132413251326132713281329133013311332133313341335133613371338133913401341134213431344134513461347134813491350135113521353135413551356135713581359136013611362136313641365136613671368136913701371137213731374137513761377137813791380138113821383138413851386138713881389139013911392139313941395139613971398139914001401140214031404140514061407140814091410141114121413141414151416141714181419142014211422142314241425142614271428142914301431143214331434143514361437143814391440144114421443144414451446144714481449145014511452145314541455145614571458145914601461146214631464146514661467146814691470147114721473147414751476147714781479148014811482148314841485148614871488148914901491149214931494149514961497149814991500150115021503150415051506150715081509151015111512151315141515151615171518151915201521152215231524152515261527152815291530153115321533153415351536153715381539154015411542154315441545154615471548154915501551155215531554155515561557155815591560156115621563156415651566156715681569157015711572157315741575157615771578157915801581158215831584158515861587158815891590159115921593159415951596159715981599160016011602160316041605160616071608160916101611161216131614161516161617161816191620162116221623162416251626
  1. import warnings
  2. import sys
  3. import os
  4. import itertools
  5. import pytest
  6. import weakref
  7. import numpy as np
  8. from numpy.testing import (
  9. assert_equal, assert_array_equal, assert_almost_equal,
  10. assert_array_almost_equal, assert_array_less, build_err_msg,
  11. assert_raises, assert_warns, assert_no_warnings, assert_allclose,
  12. assert_approx_equal, assert_array_almost_equal_nulp, assert_array_max_ulp,
  13. clear_and_catch_warnings, suppress_warnings, assert_string_equal, assert_,
  14. tempdir, temppath, assert_no_gc_cycles, HAS_REFCOUNT
  15. )
  class _GenericTest:
      """Checks shared by every element-wise comparison assertion under test.

      Subclasses must bind ``self._assert_func`` (normally in
      ``setup_method``) to the assertion being exercised; the two helper
      methods below route every check through it.
      """

      def _test_equal(self, a, b):
          """Call the assertion under test and expect it to succeed."""
          self._assert_func(a, b)

      def _test_not_equal(self, a, b):
          """Call the assertion under test and expect an AssertionError."""
          with assert_raises(AssertionError):
              self._assert_func(a, b)

      def test_array_rank1_eq(self):
          """Identical rank-1 arrays compare equal."""
          self._test_equal(np.array([1, 2]), np.array([1, 2]))

      def test_array_rank1_noteq(self):
          """Rank-1 arrays that differ in one element compare unequal."""
          self._test_not_equal(np.array([1, 2]), np.array([2, 2]))

      def test_array_rank2_eq(self):
          """Identical rank-2 arrays compare equal."""
          self._test_equal(np.array([[1, 2], [3, 4]]),
                           np.array([[1, 2], [3, 4]]))

      def test_array_diffshape(self):
          """Arrays whose shapes differ are reported as unequal."""
          self._test_not_equal(np.array([1, 2]),
                               np.array([[1, 2], [1, 2]]))

      def test_objarray(self):
          """An object array of ones compares equal to the scalar 1."""
          self._test_equal(np.array([1, 1], dtype=object), 1)

      def test_array_likes(self):
          """A list and a tuple holding the same values compare equal."""
          self._test_equal([1, 2, 3], (1, 2, 3))
  class TestArrayEqual(_GenericTest):
      """Tests for ``assert_array_equal`` on top of the shared generic checks."""

      def setup_method(self):
          self._assert_func = assert_array_equal

      def _check_all_dtypes(self, shape):
          """Compare ones/zeros-filled arrays of ``shape`` for many dtypes.

          For every dtype, a ones-filled array must equal its copy, and a
          zero-filled copy must be reported as unequal.
          """
          # Boolean, integer, float and complex type codes (object dtype is
          # covered separately by ``test_objarray``), followed by
          # one-character byte and unicode string dtypes.
          for t in [*'?bhilqpBHILQPfdgFDG', 'S1', 'U1']:
              a = np.empty(shape, t)
              a.fill(1)
              b = a.copy()
              c = a.copy()
              c.fill(0)
              self._test_equal(a, b)
              self._test_not_equal(c, b)

      def test_generic_rank1(self):
          """Test rank 1 array for all dtypes."""
          self._check_all_dtypes(2)

      def test_0_ndim_array(self):
          """Test 0-d arrays, including integers too large for fixed-width dtypes."""
          x = np.array(473963742225900817127911193656584771)
          y = np.array(18535119325151578301457182298393896)
          assert_raises(AssertionError, self._assert_func, x, y)

          y = x
          self._assert_func(x, y)

          x = np.array(43)
          y = np.array(10)
          assert_raises(AssertionError, self._assert_func, x, y)

          y = x
          self._assert_func(x, y)

      def test_generic_rank3(self):
          """Test rank 3 array for all dtypes."""
          self._check_all_dtypes((4, 2, 3))

      def test_nan_array(self):
          """Test arrays with nan values in them."""
          a = np.array([1, 2, np.nan])
          b = np.array([1, 2, np.nan])
          self._test_equal(a, b)

          c = np.array([1, 2, 3])
          self._test_not_equal(c, b)

      def test_string_arrays(self):
          """Test string arrays are compared element by element."""
          a = np.array(['floupi', 'floupa'])
          b = np.array(['floupi', 'floupa'])
          self._test_equal(a, b)

          c = np.array(['floupipi', 'floupa'])
          self._test_not_equal(c, b)

      def test_recarrays(self):
          """Test record arrays."""
          a = np.empty(2, [('floupi', float), ('floupa', float)])
          a['floupi'] = [1, 2]
          a['floupa'] = [1, 2]
          b = a.copy()

          self._test_equal(a, b)

          # ``c`` has an extra leading field (its 'floupi' field is never
          # initialised); comparing structured arrays with different field
          # layouts is expected to raise TypeError, not AssertionError.
          c = np.empty(2, [('floupipi', float),
                           ('floupi', float), ('floupa', float)])
          c['floupipi'] = a['floupi'].copy()
          c['floupa'] = a['floupa'].copy()

          with pytest.raises(TypeError):
              self._test_not_equal(c, b)

      def test_masked_nan_inf(self):
          # Regression test for gh-11121: masked entries are ignored even
          # when the unmasked counterpart is nan or inf.
          a = np.ma.MaskedArray([3., 4., 6.5], mask=[False, True, False])
          b = np.array([3., np.nan, 6.5])
          self._test_equal(a, b)
          self._test_equal(b, a)

          a = np.ma.MaskedArray([3., 4., 6.5], mask=[True, False, False])
          b = np.array([np.inf, 4., 6.5])
          self._test_equal(a, b)
          self._test_equal(b, a)

      def test_subclass_that_overrides_eq(self):
          # While we cannot guarantee testing functions will always work for
          # subclasses, the tests should ideally rely only on subclasses having
          # comparison operators, not on them being able to store booleans
          # (which, e.g., astropy Quantity cannot usefully do). See gh-8452.
          class MyArray(np.ndarray):
              def __eq__(self, other):
                  return bool(np.equal(self, other).all())

              def __ne__(self, other):
                  return not self == other

          a = np.array([1., 2.]).view(MyArray)
          b = np.array([2., 3.]).view(MyArray)
          # The comparison must really collapse to a plain Python bool.
          # (Previously ``assert_(type(a == a), bool)`` passed ``bool`` as the
          # failure message, so the type was never actually checked.)
          assert_(type(a == a) is bool)
          assert_(a == a)
          assert_(a != b)
          self._test_equal(a, a)
          self._test_not_equal(a, b)
          self._test_not_equal(b, a)

      def test_subclass_that_does_not_implement_npall(self):
          class MyArray(np.ndarray):
              def __array_function__(self, *args, **kwargs):
                  return NotImplemented

          a = np.array([1., 2.]).view(MyArray)
          b = np.array([2., 3.]).view(MyArray)
          with assert_raises(TypeError):
              np.all(a)
          self._test_equal(a, a)
          self._test_not_equal(a, b)
          self._test_not_equal(b, a)

      def test_suppress_overflow_warnings(self):
          # Based on issue #18992
          with pytest.raises(AssertionError):
              with np.errstate(all="raise"):
                  np.testing.assert_array_equal(
                      np.array([1, 2, 3], np.float32),
                      np.array([1, 1e-40, 3], np.float32))

      def test_array_vs_scalar_is_equal(self):
          """Test comparing an array with a scalar when all values are equal."""
          a = np.array([1., 1., 1.])
          b = 1.

          self._test_equal(a, b)

      def test_array_vs_scalar_not_equal(self):
          """Test comparing an array with a scalar when not all values equal."""
          a = np.array([1., 2., 3.])
          b = 1.

          self._test_not_equal(a, b)

      def test_array_vs_scalar_strict(self):
          """Test comparing an array with a scalar with strict option."""
          a = np.array([1., 1., 1.])
          b = 1.

          with pytest.raises(AssertionError):
              assert_array_equal(a, b, strict=True)

      def test_array_vs_array_strict(self):
          """Test two identical arrays of the same dtype pass with strict option."""
          a = np.array([1., 1., 1.])
          b = np.array([1., 1., 1.])

          assert_array_equal(a, b, strict=True)

      def test_array_vs_float_array_strict(self):
          """Test arrays that differ only in dtype fail with strict option."""
          a = np.array([1, 1, 1])
          b = np.array([1., 1., 1.])

          with pytest.raises(AssertionError):
              assert_array_equal(a, b, strict=True)
  class TestBuildErrorMessage:
      """Tests for the exact text produced by ``build_err_msg``."""

      def test_build_err_msg_defaults(self):
          x = np.array([1.00001, 2.00002, 3.00003])
          y = np.array([1.00002, 2.00003, 3.00004])
          err_msg = 'There is a mismatch'

          a = build_err_msg([x, y], err_msg)
          b = ('\nItems are not equal: There is a mismatch\n ACTUAL: array(['
               '1.00001, 2.00002, 3.00003])\n DESIRED: array([1.00002, '
               '2.00003, 3.00004])')
          assert_equal(a, b)

      def test_build_err_msg_no_verbose(self):
          x = np.array([1.00001, 2.00002, 3.00003])
          y = np.array([1.00002, 2.00003, 3.00004])
          err_msg = 'There is a mismatch'

          # verbose=False drops the array reprs, leaving only the header.
          a = build_err_msg([x, y], err_msg, verbose=False)
          b = '\nItems are not equal: There is a mismatch'
          assert_equal(a, b)

      def test_build_err_msg_custom_names(self):
          x = np.array([1.00001, 2.00002, 3.00003])
          y = np.array([1.00002, 2.00003, 3.00004])
          err_msg = 'There is a mismatch'

          a = build_err_msg([x, y], err_msg, names=('FOO', 'BAR'))
          b = ('\nItems are not equal: There is a mismatch\n FOO: array(['
               '1.00001, 2.00002, 3.00003])\n BAR: array([1.00002, 2.00003, '
               '3.00004])')
          assert_equal(a, b)

      def test_build_err_msg_custom_precision(self):
          x = np.array([1.000000001, 2.00002, 3.00003])
          y = np.array([1.000000002, 2.00003, 3.00004])
          err_msg = 'There is a mismatch'

          a = build_err_msg([x, y], err_msg, precision=10)
          # The first element needs 9 fractional digits, so the array repr
          # right-pads the shorter elements with spaces to align them.
          b = ('\nItems are not equal: There is a mismatch\n ACTUAL: array(['
               '1.000000001, 2.00002    , 3.00003    ])\n DESIRED: array(['
               '1.000000002, 2.00003    , 3.00004    ])')
          assert_equal(a, b)
  class TestEqual(TestArrayEqual):
      """Re-run the ``TestArrayEqual`` suite against ``assert_equal``.

      Also adds scalar, list, NaN/NaT, datetime and complex checks.
      """

      def setup_method(self):
          self._assert_func = assert_equal

      def test_nan_items(self):
          self._assert_func(np.nan, np.nan)
          self._assert_func([np.nan], [np.nan])
          self._test_not_equal(np.nan, [np.nan])
          self._test_not_equal(np.nan, 1)

      def test_inf_items(self):
          self._assert_func(np.inf, np.inf)
          self._assert_func([np.inf], [np.inf])
          self._test_not_equal(np.inf, [np.inf])

      def test_datetime(self):
          base = np.datetime64("2017-01-01", "s")
          # The same instant compares equal whether stored in s or min.
          for same in (np.datetime64("2017-01-01", "s"),
                       np.datetime64("2017-01-01", "m")):
              self._test_equal(base, same)

          # gh-10081: a different day is detected regardless of unit.
          for other in (np.datetime64("2017-01-02", "s"),
                        np.datetime64("2017-01-02", "m")):
              self._test_not_equal(base, other)

      def test_nat_items(self):
          # NaT datetimes: unit-less, second and nanosecond resolution.
          nat_datetimes = [np.datetime64("NaT"),
                           np.datetime64("NaT", "s"),
                           np.datetime64("NaT", "ns")]
          # NaT timedeltas with the same set of units.
          nat_timedeltas = [np.timedelta64("NaT"),
                            np.timedelta64("NaT", "s"),
                            np.timedelta64("NaT", "ns")]

          # NaTs of the same kind compare equal whatever their units, but
          # wrapping only one side in a list must be reported as unequal.
          for group in (nat_datetimes, nat_timedeltas):
              for a, b in itertools.product(group, repeat=2):
                  self._assert_func(a, b)
                  self._assert_func([a], [b])
                  self._test_not_equal([a], b)

          # A timedelta NaT never equals a datetime NaT, and neither equals
          # an ordinary (non-NaT) datetime or timedelta.
          for a, b in itertools.product(nat_timedeltas, nat_datetimes):
              self._test_not_equal(a, b)
              self._test_not_equal(a, [b])
              self._test_not_equal([a], [b])
              self._test_not_equal([a], np.datetime64("2017-01-01", "s"))
              self._test_not_equal([b], np.datetime64("2017-01-01", "s"))
              self._test_not_equal([a], np.timedelta64(123, "s"))
              self._test_not_equal([b], np.timedelta64(123, "s"))

      def test_non_numeric(self):
          self._assert_func('ab', 'ab')
          self._test_not_equal('ab', 'abb')

      def test_complex_item(self):
          self._assert_func(complex(1, 2), complex(1, 2))
          self._assert_func(complex(1, np.nan), complex(1, np.nan))
          self._test_not_equal(complex(1, np.nan), complex(1, 2))
          self._test_not_equal(complex(np.nan, 1), complex(1, np.nan))
          self._test_not_equal(complex(np.nan, np.inf), complex(np.nan, 2))

      def test_negative_zero(self):
          # Signed zeros compare equal with ``==`` but must be reported as
          # different. Plain float literals are used because the
          # ``np.PZERO`` / ``np.NZERO`` aliases were removed in NumPy 2.0.
          self._test_not_equal(0.0, -0.0)

      def test_complex(self):
          x = np.array([complex(1, 2), complex(1, np.nan)])
          y = np.array([complex(1, 2), complex(1, 2)])
          self._assert_func(x, x)
          self._test_not_equal(x, y)

      def test_object(self):
          # gh-12942: an object array of datetimes must not compare equal
          # to its own reversal.
          import datetime
          a = np.array([datetime.datetime(2000, 1, 1),
                        datetime.datetime(2000, 1, 2)])
          self._test_not_equal(a, a[::-1])
  class TestArrayAlmostEqual(_GenericTest):
      """Tests for ``assert_array_almost_equal`` plus the generic checks."""

      def setup_method(self):
          self._assert_func = assert_array_almost_equal

      def test_closeness(self):
          # Over time the effective tolerance became
          # `abs(x - y) < 1.5 * 10**(-decimal)` rather than the previously
          # documented `abs(x - y) < 0.5 * 10**(-decimal)`; this test pins
          # that historical behaviour, first for scalars, then for arrays.
          for wrap in (lambda v: v, lambda v: [v]):
              self._assert_func(wrap(1.499999), wrap(0.0), decimal=0)
              with assert_raises(AssertionError):
                  self._assert_func(wrap(1.5), wrap(0.0), decimal=0)

      def test_simple(self):
          lhs = np.array([1234.2222])
          rhs = np.array([1234.2223])
          # The values differ by about 1e-4: fine at 3 or 4 decimals,
          # rejected at 5.
          for places in (3, 4):
              self._assert_func(lhs, rhs, decimal=places)
          with assert_raises(AssertionError):
              self._assert_func(lhs, rhs, decimal=5)

      def test_nan(self):
          nan_arr, one_arr, inf_arr = (np.array([v]) for v in (np.nan, 1, np.inf))
          # NaN matches only NaN.
          self._assert_func(nan_arr, nan_arr)
          for lhs, rhs in ((nan_arr, one_arr), (nan_arr, inf_arr),
                           (inf_arr, nan_arr)):
              with assert_raises(AssertionError):
                  self._assert_func(lhs, rhs)

      def test_inf(self):
          lhs = np.array([[1., 2.], [3., 4.]])
          rhs = lhs.copy()
          # inf against a finite value ...
          lhs[0, 0] = np.inf
          with assert_raises(AssertionError):
              self._assert_func(lhs, rhs)
          # ... and inf against -inf are both mismatches.
          rhs[0, 0] = -np.inf
          with assert_raises(AssertionError):
              self._assert_func(lhs, rhs)

      def test_subclass(self):
          plain = np.array([[1., 2.], [3., 4.]])
          masked = np.ma.masked_array([[1., 2.], [0., 4.]],
                                      [[False, False], [True, False]])
          # The masked mismatch at [1, 0] is ignored in every direction.
          self._assert_func(plain, masked)
          self._assert_func(masked, plain)
          self._assert_func(masked, masked)

          # Fully masked inputs as well (see gh-11123): each pair must be
          # accepted in both argument orders.
          pairs = [
              (np.ma.MaskedArray(3.5, mask=True), np.array([3., 4., 6.5])),
              (np.ma.masked, np.array([3., 4., 6.5])),
              (np.ma.MaskedArray([3., 4., 6.5], mask=[True, True, True]),
               np.array([1., 2., 3.])),
              (np.ma.MaskedArray([3., 4., 6.5], mask=[True, True, True]),
               np.array(1.)),
          ]
          for fully_masked, other in pairs:
              self._test_equal(fully_masked, other)
              self._test_equal(other, fully_masked)

      def test_subclass_that_cannot_be_bool(self):
          # While we cannot guarantee testing functions will always work for
          # subclasses, the tests should ideally rely only on subclasses having
          # comparison operators, not on them being able to store booleans
          # (which, e.g., astropy Quantity cannot usefully do). See gh-8452.
          class NoBoolArray(np.ndarray):
              def __eq__(self, other):
                  return super().__eq__(other).view(np.ndarray)

              def __lt__(self, other):
                  return super().__lt__(other).view(np.ndarray)

              def all(self, *args, **kwargs):
                  raise NotImplementedError

          arr = np.array([1., 2.]).view(NoBoolArray)
          self._assert_func(arr, arr)
  class TestAlmostEqual(_GenericTest):
      """Tests for ``assert_almost_equal`` plus the generic checks."""

      def setup_method(self):
          self._assert_func = assert_almost_equal

      def test_closeness(self):
          # Note that in the course of time we ended up with
          #     `abs(x - y) < 1.5 * 10**(-decimal)`
          # instead of the previously documented
          #     `abs(x - y) < 0.5 * 10**(-decimal)`
          # so this check serves to preserve the wrongness.

          # test scalars
          self._assert_func(1.499999, 0.0, decimal=0)
          assert_raises(AssertionError,
                        lambda: self._assert_func(1.5, 0.0, decimal=0))

          # test arrays
          self._assert_func([1.499999], [0.0], decimal=0)
          assert_raises(AssertionError,
                        lambda: self._assert_func([1.5], [0.0], decimal=0))

      def test_nan_item(self):
          self._assert_func(np.nan, np.nan)
          assert_raises(AssertionError,
                        lambda: self._assert_func(np.nan, 1))
          assert_raises(AssertionError,
                        lambda: self._assert_func(np.nan, np.inf))
          assert_raises(AssertionError,
                        lambda: self._assert_func(np.inf, np.nan))

      def test_inf_item(self):
          self._assert_func(np.inf, np.inf)
          self._assert_func(-np.inf, -np.inf)
          assert_raises(AssertionError,
                        lambda: self._assert_func(np.inf, 1))
          assert_raises(AssertionError,
                        lambda: self._assert_func(-np.inf, np.inf))

      def test_simple_item(self):
          self._test_not_equal(1, 2)

      def test_complex_item(self):
          self._assert_func(complex(1, 2), complex(1, 2))
          self._assert_func(complex(1, np.nan), complex(1, np.nan))
          self._assert_func(complex(np.inf, np.nan), complex(np.inf, np.nan))
          self._test_not_equal(complex(1, np.nan), complex(1, 2))
          self._test_not_equal(complex(np.nan, 1), complex(1, np.nan))
          self._test_not_equal(complex(np.nan, np.inf), complex(np.nan, 2))

      def test_complex(self):
          x = np.array([complex(1, 2), complex(1, np.nan)])
          z = np.array([complex(1, 2), complex(np.nan, 1)])
          y = np.array([complex(1, 2), complex(1, 2)])
          self._assert_func(x, x)
          self._test_not_equal(x, y)
          self._test_not_equal(x, z)

      def test_error_message(self):
          """Check the message is formatted correctly for the decimal value.
          Also check the message when input includes inf or nan (gh12200)"""
          x = np.array([1.00000000001, 2.00000000002, 3.00003])
          y = np.array([1.00000000002, 2.00000000003, 3.00004])

          # Test with a different amount of decimal digits
          with pytest.raises(AssertionError) as exc_info:
              self._assert_func(x, y, decimal=12)
          msgs = str(exc_info.value).split('\n')
          assert_equal(msgs[3], 'Mismatched elements: 3 / 3 (100%)')
          assert_equal(msgs[4], 'Max absolute difference: 1.e-05')
          assert_equal(msgs[5], 'Max relative difference: 3.33328889e-06')
          # The array repr right-pads shorter elements so that all of them
          # show as many fractional digits as the longest one (11 here).
          assert_equal(
              msgs[6],
              ' x: array([1.00000000001, 2.00000000002, 3.00003      ])')
          assert_equal(
              msgs[7],
              ' y: array([1.00000000002, 2.00000000003, 3.00004      ])')

          # With the default value of decimal digits, only the 3rd element
          # differs. Note that we only check for the formatting of the arrays
          # themselves.
          with pytest.raises(AssertionError) as exc_info:
              self._assert_func(x, y)
          msgs = str(exc_info.value).split('\n')
          assert_equal(msgs[3], 'Mismatched elements: 1 / 3 (33.3%)')
          assert_equal(msgs[4], 'Max absolute difference: 1.e-05')
          assert_equal(msgs[5], 'Max relative difference: 3.33328889e-06')
          assert_equal(msgs[6], ' x: array([1.     , 2.     , 3.00003])')
          assert_equal(msgs[7], ' y: array([1.     , 2.     , 3.00004])')

          # Check the error message when input includes inf; the finite
          # entries are left-padded to the width of 'inf'.
          x = np.array([np.inf, 0])
          y = np.array([np.inf, 1])
          with pytest.raises(AssertionError) as exc_info:
              self._assert_func(x, y)
          msgs = str(exc_info.value).split('\n')
          assert_equal(msgs[3], 'Mismatched elements: 1 / 2 (50%)')
          assert_equal(msgs[4], 'Max absolute difference: 1.')
          assert_equal(msgs[5], 'Max relative difference: 1.')
          assert_equal(msgs[6], ' x: array([inf,  0.])')
          assert_equal(msgs[7], ' y: array([inf,  1.])')

          # Check the error message when dividing by zero
          x = np.array([1, 2])
          y = np.array([0, 0])
          with pytest.raises(AssertionError) as exc_info:
              self._assert_func(x, y)
          msgs = str(exc_info.value).split('\n')
          assert_equal(msgs[3], 'Mismatched elements: 2 / 2 (100%)')
          assert_equal(msgs[4], 'Max absolute difference: 2')
          assert_equal(msgs[5], 'Max relative difference: inf')

      def test_error_message_2(self):
          """Check the message is formatted correctly when either x or y is a scalar."""
          x = 2
          y = np.ones(20)
          with pytest.raises(AssertionError) as exc_info:
              self._assert_func(x, y)
          msgs = str(exc_info.value).split('\n')
          assert_equal(msgs[3], 'Mismatched elements: 20 / 20 (100%)')
          assert_equal(msgs[4], 'Max absolute difference: 1.')
          assert_equal(msgs[5], 'Max relative difference: 1.')

          y = 2
          x = np.ones(20)
          with pytest.raises(AssertionError) as exc_info:
              self._assert_func(x, y)
          msgs = str(exc_info.value).split('\n')
          assert_equal(msgs[3], 'Mismatched elements: 20 / 20 (100%)')
          assert_equal(msgs[4], 'Max absolute difference: 1.')
          assert_equal(msgs[5], 'Max relative difference: 0.5')

      def test_subclass_that_cannot_be_bool(self):
          # While we cannot guarantee testing functions will always work for
          # subclasses, the tests should ideally rely only on subclasses having
          # comparison operators, not on them being able to store booleans
          # (which, e.g., astropy Quantity cannot usefully do). See gh-8452.
          class MyArray(np.ndarray):
              def __eq__(self, other):
                  return super().__eq__(other).view(np.ndarray)

              def __lt__(self, other):
                  return super().__lt__(other).view(np.ndarray)

              def all(self, *args, **kwargs):
                  raise NotImplementedError

          a = np.array([1., 2.]).view(MyArray)
          self._assert_func(a, a)
  class TestApproxEqual:
      """Tests for ``assert_approx_equal`` (significant-digit comparison)."""

      def setup_method(self):
          self._assert_func = assert_approx_equal

      def test_simple_0d_arrays(self):
          x = np.array(1234.22)
          y = np.array(1234.23)

          self._assert_func(x, y, significant=5)
          self._assert_func(x, y, significant=6)
          assert_raises(AssertionError,
                        lambda: self._assert_func(x, y, significant=7))

      def test_simple_items(self):
          x = 1234.22
          y = 1234.23

          self._assert_func(x, y, significant=4)
          self._assert_func(x, y, significant=5)
          self._assert_func(x, y, significant=6)
          assert_raises(AssertionError,
                        lambda: self._assert_func(x, y, significant=7))

      def _check_nan_handling(self, nan, one, inf):
          """NaN matches only NaN; never a finite value or an infinity."""
          self._assert_func(nan, nan)
          assert_raises(AssertionError, lambda: self._assert_func(nan, one))
          assert_raises(AssertionError, lambda: self._assert_func(nan, inf))
          assert_raises(AssertionError, lambda: self._assert_func(inf, nan))

      def test_nan_array(self):
          # 0-d arrays.
          self._check_nan_handling(np.array(np.nan), np.array(1),
                                   np.array(np.inf))

      def test_nan_items(self):
          # Plain Python scalars. This previously duplicated
          # ``test_nan_array`` verbatim and never exercised scalar items.
          self._check_nan_handling(np.nan, 1, np.inf)
  class TestArrayAssertLess:
      """Tests for ``assert_array_less`` (strict element-wise ``x < y``)."""

      def setup_method(self):
          self._assert_func = assert_array_less

      def _check_fails(self, x, y):
          """Expect ``x < y`` to be reported as violated."""
          assert_raises(AssertionError, lambda: self._assert_func(x, y))

      def test_simple_arrays(self):
          lo = np.array([1.1, 2.2])
          hi = np.array([1.2, 2.3])
          self._assert_func(lo, hi)
          self._check_fails(hi, lo)

          # Each array is smaller in only one position: both orders fail.
          mixed = np.array([1.0, 2.3])
          self._check_fails(lo, mixed)
          self._check_fails(mixed, lo)

      def test_rank2(self):
          lo = np.array([[1.1, 2.2], [3.3, 4.4]])
          hi = np.array([[1.2, 2.3], [3.4, 4.5]])
          self._assert_func(lo, hi)
          self._check_fails(hi, lo)

          mixed = np.array([[1.0, 2.3], [3.4, 4.5]])
          self._check_fails(lo, mixed)
          self._check_fails(mixed, lo)

      def test_rank3(self):
          lo = np.ones(shape=(2, 2, 2))
          hi = np.ones(shape=(2, 2, 2)) + 1
          self._assert_func(lo, hi)
          self._check_fails(hi, lo)

          # Pushing one element of ``hi`` below ``lo`` breaks both orders.
          hi[0, 0, 0] = 0
          self._check_fails(lo, hi)
          self._check_fails(hi, lo)

      def test_simple_items(self):
          scalar = 1.1
          self._assert_func(scalar, 2.2)
          self._check_fails(2.2, scalar)

          # A scalar is compared against every element of an array.
          self._assert_func(scalar, np.array([2.2, 3.3]))
          self._check_fails(np.array([2.2, 3.3]), scalar)

          self._check_fails(scalar, np.array([1.0, 3.3]))

      def test_nan_noncompare(self):
          anan, aone, ainf = np.array(np.nan), np.array(1), np.array(np.inf)
          # NaN against NaN is accepted; NaN against anything else is not.
          self._assert_func(anan, anan)
          for lhs, rhs in [(aone, anan), (anan, aone),
                           (anan, ainf), (ainf, anan)]:
              self._check_fails(lhs, rhs)

      def test_nan_noncompare_array(self):
          anan = np.array(np.nan)

          finite = np.array([1.1, 2.2, 3.3])
          self._check_fails(finite, anan)
          self._check_fails(anan, finite)

          with_nan = np.array([1.1, 2.2, np.nan])
          self._check_fails(with_nan, anan)
          self._check_fails(anan, with_nan)

          # NaNs in the same position are accepted when the rest is ordered.
          smaller = np.array([1.0, 2.0, np.nan])
          self._assert_func(smaller, with_nan)
          self._check_fails(with_nan, smaller)

      def test_inf_compare(self):
          aone = np.array(1)
          ainf = np.array(np.inf)

          for lhs, rhs in [(aone, ainf), (-ainf, aone), (-ainf, ainf)]:
              self._assert_func(lhs, rhs)
          # Infinities are not exempt: inf < inf and -inf < -inf both fail.
          for lhs, rhs in [(ainf, aone), (aone, -ainf), (ainf, ainf),
                           (ainf, -ainf), (-ainf, -ainf)]:
              self._check_fails(lhs, rhs)

      def test_inf_compare_array(self):
          values = np.array([1.1, 2.2, np.inf])
          ainf = np.array(np.inf)

          for lhs, rhs in [(values, ainf), (ainf, values), (values, -ainf),
                           (-values, -ainf), (-ainf, -values)]:
              self._check_fails(lhs, rhs)
          self._assert_func(-ainf, values)
  630. class TestWarns:
  631. def test_warn(self):
  632. def f():
  633. warnings.warn("yo")
  634. return 3
  635. before_filters = sys.modules['warnings'].filters[:]
  636. assert_equal(assert_warns(UserWarning, f), 3)
  637. after_filters = sys.modules['warnings'].filters
  638. assert_raises(AssertionError, assert_no_warnings, f)
  639. assert_equal(assert_no_warnings(lambda x: x, 1), 1)
  640. # Check that the warnings state is unchanged
  641. assert_equal(before_filters, after_filters,
  642. "assert_warns does not preserver warnings state")
  643. def test_context_manager(self):
  644. before_filters = sys.modules['warnings'].filters[:]
  645. with assert_warns(UserWarning):
  646. warnings.warn("yo")
  647. after_filters = sys.modules['warnings'].filters
  648. def no_warnings():
  649. with assert_no_warnings():
  650. warnings.warn("yo")
  651. assert_raises(AssertionError, no_warnings)
  652. assert_equal(before_filters, after_filters,
  653. "assert_warns does not preserver warnings state")
  654. def test_warn_wrong_warning(self):
  655. def f():
  656. warnings.warn("yo", DeprecationWarning)
  657. failed = False
  658. with warnings.catch_warnings():
  659. warnings.simplefilter("error", DeprecationWarning)
  660. try:
  661. # Should raise a DeprecationWarning
  662. assert_warns(UserWarning, f)
  663. failed = True
  664. except DeprecationWarning:
  665. pass
  666. if failed:
  667. raise AssertionError("wrong warning caught by assert_warn")
  668. class TestAssertAllclose:
  669. def test_simple(self):
  670. x = 1e-3
  671. y = 1e-9
  672. assert_allclose(x, y, atol=1)
  673. assert_raises(AssertionError, assert_allclose, x, y)
  674. a = np.array([x, y, x, y])
  675. b = np.array([x, y, x, x])
  676. assert_allclose(a, b, atol=1)
  677. assert_raises(AssertionError, assert_allclose, a, b)
  678. b[-1] = y * (1 + 1e-8)
  679. assert_allclose(a, b)
  680. assert_raises(AssertionError, assert_allclose, a, b, rtol=1e-9)
  681. assert_allclose(6, 10, rtol=0.5)
  682. assert_raises(AssertionError, assert_allclose, 10, 6, rtol=0.5)
  683. def test_min_int(self):
  684. a = np.array([np.iinfo(np.int_).min], dtype=np.int_)
  685. # Should not raise:
  686. assert_allclose(a, a)
  687. def test_report_fail_percentage(self):
  688. a = np.array([1, 1, 1, 1])
  689. b = np.array([1, 1, 1, 2])
  690. with pytest.raises(AssertionError) as exc_info:
  691. assert_allclose(a, b)
  692. msg = str(exc_info.value)
  693. assert_('Mismatched elements: 1 / 4 (25%)\n'
  694. 'Max absolute difference: 1\n'
  695. 'Max relative difference: 0.5' in msg)
  696. def test_equal_nan(self):
  697. a = np.array([np.nan])
  698. b = np.array([np.nan])
  699. # Should not raise:
  700. assert_allclose(a, b, equal_nan=True)
  701. def test_not_equal_nan(self):
  702. a = np.array([np.nan])
  703. b = np.array([np.nan])
  704. assert_raises(AssertionError, assert_allclose, a, b, equal_nan=False)
  705. def test_equal_nan_default(self):
  706. # Make sure equal_nan default behavior remains unchanged. (All
  707. # of these functions use assert_array_compare under the hood.)
  708. # None of these should raise.
  709. a = np.array([np.nan])
  710. b = np.array([np.nan])
  711. assert_array_equal(a, b)
  712. assert_array_almost_equal(a, b)
  713. assert_array_less(a, b)
  714. assert_allclose(a, b)
  715. def test_report_max_relative_error(self):
  716. a = np.array([0, 1])
  717. b = np.array([0, 2])
  718. with pytest.raises(AssertionError) as exc_info:
  719. assert_allclose(a, b)
  720. msg = str(exc_info.value)
  721. assert_('Max relative difference: 0.5' in msg)
  722. def test_timedelta(self):
  723. # see gh-18286
  724. a = np.array([[1, 2, 3, "NaT"]], dtype="m8[ns]")
  725. assert_allclose(a, a)
  726. def test_error_message_unsigned(self):
  727. """Check the the message is formatted correctly when overflow can occur
  728. (gh21768)"""
  729. # Ensure to test for potential overflow in the case of:
  730. # x - y
  731. # and
  732. # y - x
  733. x = np.asarray([0, 1, 8], dtype='uint8')
  734. y = np.asarray([4, 4, 4], dtype='uint8')
  735. with pytest.raises(AssertionError) as exc_info:
  736. assert_allclose(x, y, atol=3)
  737. msgs = str(exc_info.value).split('\n')
  738. assert_equal(msgs[4], 'Max absolute difference: 4')
  739. class TestArrayAlmostEqualNulp:
  740. def test_float64_pass(self):
  741. # The number of units of least precision
  742. # In this case, use a few places above the lowest level (ie nulp=1)
  743. nulp = 5
  744. x = np.linspace(-20, 20, 50, dtype=np.float64)
  745. x = 10**x
  746. x = np.r_[-x, x]
  747. # Addition
  748. eps = np.finfo(x.dtype).eps
  749. y = x + x*eps*nulp/2.
  750. assert_array_almost_equal_nulp(x, y, nulp)
  751. # Subtraction
  752. epsneg = np.finfo(x.dtype).epsneg
  753. y = x - x*epsneg*nulp/2.
  754. assert_array_almost_equal_nulp(x, y, nulp)
  755. def test_float64_fail(self):
  756. nulp = 5
  757. x = np.linspace(-20, 20, 50, dtype=np.float64)
  758. x = 10**x
  759. x = np.r_[-x, x]
  760. eps = np.finfo(x.dtype).eps
  761. y = x + x*eps*nulp*2.
  762. assert_raises(AssertionError, assert_array_almost_equal_nulp,
  763. x, y, nulp)
  764. epsneg = np.finfo(x.dtype).epsneg
  765. y = x - x*epsneg*nulp*2.
  766. assert_raises(AssertionError, assert_array_almost_equal_nulp,
  767. x, y, nulp)
  768. def test_float64_ignore_nan(self):
  769. # Ignore ULP differences between various NAN's
  770. # Note that MIPS may reverse quiet and signaling nans
  771. # so we use the builtin version as a base.
  772. offset = np.uint64(0xffffffff)
  773. nan1_i64 = np.array(np.nan, dtype=np.float64).view(np.uint64)
  774. nan2_i64 = nan1_i64 ^ offset # nan payload on MIPS is all ones.
  775. nan1_f64 = nan1_i64.view(np.float64)
  776. nan2_f64 = nan2_i64.view(np.float64)
  777. assert_array_max_ulp(nan1_f64, nan2_f64, 0)
  778. def test_float32_pass(self):
  779. nulp = 5
  780. x = np.linspace(-20, 20, 50, dtype=np.float32)
  781. x = 10**x
  782. x = np.r_[-x, x]
  783. eps = np.finfo(x.dtype).eps
  784. y = x + x*eps*nulp/2.
  785. assert_array_almost_equal_nulp(x, y, nulp)
  786. epsneg = np.finfo(x.dtype).epsneg
  787. y = x - x*epsneg*nulp/2.
  788. assert_array_almost_equal_nulp(x, y, nulp)
  789. def test_float32_fail(self):
  790. nulp = 5
  791. x = np.linspace(-20, 20, 50, dtype=np.float32)
  792. x = 10**x
  793. x = np.r_[-x, x]
  794. eps = np.finfo(x.dtype).eps
  795. y = x + x*eps*nulp*2.
  796. assert_raises(AssertionError, assert_array_almost_equal_nulp,
  797. x, y, nulp)
  798. epsneg = np.finfo(x.dtype).epsneg
  799. y = x - x*epsneg*nulp*2.
  800. assert_raises(AssertionError, assert_array_almost_equal_nulp,
  801. x, y, nulp)
  802. def test_float32_ignore_nan(self):
  803. # Ignore ULP differences between various NAN's
  804. # Note that MIPS may reverse quiet and signaling nans
  805. # so we use the builtin version as a base.
  806. offset = np.uint32(0xffff)
  807. nan1_i32 = np.array(np.nan, dtype=np.float32).view(np.uint32)
  808. nan2_i32 = nan1_i32 ^ offset # nan payload on MIPS is all ones.
  809. nan1_f32 = nan1_i32.view(np.float32)
  810. nan2_f32 = nan2_i32.view(np.float32)
  811. assert_array_max_ulp(nan1_f32, nan2_f32, 0)
  812. def test_float16_pass(self):
  813. nulp = 5
  814. x = np.linspace(-4, 4, 10, dtype=np.float16)
  815. x = 10**x
  816. x = np.r_[-x, x]
  817. eps = np.finfo(x.dtype).eps
  818. y = x + x*eps*nulp/2.
  819. assert_array_almost_equal_nulp(x, y, nulp)
  820. epsneg = np.finfo(x.dtype).epsneg
  821. y = x - x*epsneg*nulp/2.
  822. assert_array_almost_equal_nulp(x, y, nulp)
  823. def test_float16_fail(self):
  824. nulp = 5
  825. x = np.linspace(-4, 4, 10, dtype=np.float16)
  826. x = 10**x
  827. x = np.r_[-x, x]
  828. eps = np.finfo(x.dtype).eps
  829. y = x + x*eps*nulp*2.
  830. assert_raises(AssertionError, assert_array_almost_equal_nulp,
  831. x, y, nulp)
  832. epsneg = np.finfo(x.dtype).epsneg
  833. y = x - x*epsneg*nulp*2.
  834. assert_raises(AssertionError, assert_array_almost_equal_nulp,
  835. x, y, nulp)
  836. def test_float16_ignore_nan(self):
  837. # Ignore ULP differences between various NAN's
  838. # Note that MIPS may reverse quiet and signaling nans
  839. # so we use the builtin version as a base.
  840. offset = np.uint16(0xff)
  841. nan1_i16 = np.array(np.nan, dtype=np.float16).view(np.uint16)
  842. nan2_i16 = nan1_i16 ^ offset # nan payload on MIPS is all ones.
  843. nan1_f16 = nan1_i16.view(np.float16)
  844. nan2_f16 = nan2_i16.view(np.float16)
  845. assert_array_max_ulp(nan1_f16, nan2_f16, 0)
  846. def test_complex128_pass(self):
  847. nulp = 5
  848. x = np.linspace(-20, 20, 50, dtype=np.float64)
  849. x = 10**x
  850. x = np.r_[-x, x]
  851. xi = x + x*1j
  852. eps = np.finfo(x.dtype).eps
  853. y = x + x*eps*nulp/2.
  854. assert_array_almost_equal_nulp(xi, x + y*1j, nulp)
  855. assert_array_almost_equal_nulp(xi, y + x*1j, nulp)
  856. # The test condition needs to be at least a factor of sqrt(2) smaller
  857. # because the real and imaginary parts both change
  858. y = x + x*eps*nulp/4.
  859. assert_array_almost_equal_nulp(xi, y + y*1j, nulp)
  860. epsneg = np.finfo(x.dtype).epsneg
  861. y = x - x*epsneg*nulp/2.
  862. assert_array_almost_equal_nulp(xi, x + y*1j, nulp)
  863. assert_array_almost_equal_nulp(xi, y + x*1j, nulp)
  864. y = x - x*epsneg*nulp/4.
  865. assert_array_almost_equal_nulp(xi, y + y*1j, nulp)
  866. def test_complex128_fail(self):
  867. nulp = 5
  868. x = np.linspace(-20, 20, 50, dtype=np.float64)
  869. x = 10**x
  870. x = np.r_[-x, x]
  871. xi = x + x*1j
  872. eps = np.finfo(x.dtype).eps
  873. y = x + x*eps*nulp*2.
  874. assert_raises(AssertionError, assert_array_almost_equal_nulp,
  875. xi, x + y*1j, nulp)
  876. assert_raises(AssertionError, assert_array_almost_equal_nulp,
  877. xi, y + x*1j, nulp)
  878. # The test condition needs to be at least a factor of sqrt(2) smaller
  879. # because the real and imaginary parts both change
  880. y = x + x*eps*nulp
  881. assert_raises(AssertionError, assert_array_almost_equal_nulp,
  882. xi, y + y*1j, nulp)
  883. epsneg = np.finfo(x.dtype).epsneg
  884. y = x - x*epsneg*nulp*2.
  885. assert_raises(AssertionError, assert_array_almost_equal_nulp,
  886. xi, x + y*1j, nulp)
  887. assert_raises(AssertionError, assert_array_almost_equal_nulp,
  888. xi, y + x*1j, nulp)
  889. y = x - x*epsneg*nulp
  890. assert_raises(AssertionError, assert_array_almost_equal_nulp,
  891. xi, y + y*1j, nulp)
  892. def test_complex64_pass(self):
  893. nulp = 5
  894. x = np.linspace(-20, 20, 50, dtype=np.float32)
  895. x = 10**x
  896. x = np.r_[-x, x]
  897. xi = x + x*1j
  898. eps = np.finfo(x.dtype).eps
  899. y = x + x*eps*nulp/2.
  900. assert_array_almost_equal_nulp(xi, x + y*1j, nulp)
  901. assert_array_almost_equal_nulp(xi, y + x*1j, nulp)
  902. y = x + x*eps*nulp/4.
  903. assert_array_almost_equal_nulp(xi, y + y*1j, nulp)
  904. epsneg = np.finfo(x.dtype).epsneg
  905. y = x - x*epsneg*nulp/2.
  906. assert_array_almost_equal_nulp(xi, x + y*1j, nulp)
  907. assert_array_almost_equal_nulp(xi, y + x*1j, nulp)
  908. y = x - x*epsneg*nulp/4.
  909. assert_array_almost_equal_nulp(xi, y + y*1j, nulp)
  910. def test_complex64_fail(self):
  911. nulp = 5
  912. x = np.linspace(-20, 20, 50, dtype=np.float32)
  913. x = 10**x
  914. x = np.r_[-x, x]
  915. xi = x + x*1j
  916. eps = np.finfo(x.dtype).eps
  917. y = x + x*eps*nulp*2.
  918. assert_raises(AssertionError, assert_array_almost_equal_nulp,
  919. xi, x + y*1j, nulp)
  920. assert_raises(AssertionError, assert_array_almost_equal_nulp,
  921. xi, y + x*1j, nulp)
  922. y = x + x*eps*nulp
  923. assert_raises(AssertionError, assert_array_almost_equal_nulp,
  924. xi, y + y*1j, nulp)
  925. epsneg = np.finfo(x.dtype).epsneg
  926. y = x - x*epsneg*nulp*2.
  927. assert_raises(AssertionError, assert_array_almost_equal_nulp,
  928. xi, x + y*1j, nulp)
  929. assert_raises(AssertionError, assert_array_almost_equal_nulp,
  930. xi, y + x*1j, nulp)
  931. y = x - x*epsneg*nulp
  932. assert_raises(AssertionError, assert_array_almost_equal_nulp,
  933. xi, y + y*1j, nulp)
  934. class TestULP:
  935. def test_equal(self):
  936. x = np.random.randn(10)
  937. assert_array_max_ulp(x, x, maxulp=0)
  938. def test_single(self):
  939. # Generate 1 + small deviation, check that adding eps gives a few UNL
  940. x = np.ones(10).astype(np.float32)
  941. x += 0.01 * np.random.randn(10).astype(np.float32)
  942. eps = np.finfo(np.float32).eps
  943. assert_array_max_ulp(x, x+eps, maxulp=20)
  944. def test_double(self):
  945. # Generate 1 + small deviation, check that adding eps gives a few UNL
  946. x = np.ones(10).astype(np.float64)
  947. x += 0.01 * np.random.randn(10).astype(np.float64)
  948. eps = np.finfo(np.float64).eps
  949. assert_array_max_ulp(x, x+eps, maxulp=200)
  950. def test_inf(self):
  951. for dt in [np.float32, np.float64]:
  952. inf = np.array([np.inf]).astype(dt)
  953. big = np.array([np.finfo(dt).max])
  954. assert_array_max_ulp(inf, big, maxulp=200)
  955. def test_nan(self):
  956. # Test that nan is 'far' from small, tiny, inf, max and min
  957. for dt in [np.float32, np.float64]:
  958. if dt == np.float32:
  959. maxulp = 1e6
  960. else:
  961. maxulp = 1e12
  962. inf = np.array([np.inf]).astype(dt)
  963. nan = np.array([np.nan]).astype(dt)
  964. big = np.array([np.finfo(dt).max])
  965. tiny = np.array([np.finfo(dt).tiny])
  966. zero = np.array([np.PZERO]).astype(dt)
  967. nzero = np.array([np.NZERO]).astype(dt)
  968. assert_raises(AssertionError,
  969. lambda: assert_array_max_ulp(nan, inf,
  970. maxulp=maxulp))
  971. assert_raises(AssertionError,
  972. lambda: assert_array_max_ulp(nan, big,
  973. maxulp=maxulp))
  974. assert_raises(AssertionError,
  975. lambda: assert_array_max_ulp(nan, tiny,
  976. maxulp=maxulp))
  977. assert_raises(AssertionError,
  978. lambda: assert_array_max_ulp(nan, zero,
  979. maxulp=maxulp))
  980. assert_raises(AssertionError,
  981. lambda: assert_array_max_ulp(nan, nzero,
  982. maxulp=maxulp))
  983. class TestStringEqual:
  984. def test_simple(self):
  985. assert_string_equal("hello", "hello")
  986. assert_string_equal("hello\nmultiline", "hello\nmultiline")
  987. with pytest.raises(AssertionError) as exc_info:
  988. assert_string_equal("foo\nbar", "hello\nbar")
  989. msg = str(exc_info.value)
  990. assert_equal(msg, "Differences in strings:\n- foo\n+ hello")
  991. assert_raises(AssertionError,
  992. lambda: assert_string_equal("foo", "hello"))
  993. def test_regex(self):
  994. assert_string_equal("a+*b", "a+*b")
  995. assert_raises(AssertionError,
  996. lambda: assert_string_equal("aaa", "a+b"))
  997. def assert_warn_len_equal(mod, n_in_context):
  998. try:
  999. mod_warns = mod.__warningregistry__
  1000. except AttributeError:
  1001. # the lack of a __warningregistry__
  1002. # attribute means that no warning has
  1003. # occurred; this can be triggered in
  1004. # a parallel test scenario, while in
  1005. # a serial test scenario an initial
  1006. # warning (and therefore the attribute)
  1007. # are always created first
  1008. mod_warns = {}
  1009. num_warns = len(mod_warns)
  1010. if 'version' in mod_warns:
  1011. # Python 3 adds a 'version' entry to the registry,
  1012. # do not count it.
  1013. num_warns -= 1
  1014. assert_equal(num_warns, n_in_context)
  1015. def test_warn_len_equal_call_scenarios():
  1016. # assert_warn_len_equal is called under
  1017. # varying circumstances depending on serial
  1018. # vs. parallel test scenarios; this test
  1019. # simply aims to probe both code paths and
  1020. # check that no assertion is uncaught
  1021. # parallel scenario -- no warning issued yet
  1022. class mod:
  1023. pass
  1024. mod_inst = mod()
  1025. assert_warn_len_equal(mod=mod_inst,
  1026. n_in_context=0)
  1027. # serial test scenario -- the __warningregistry__
  1028. # attribute should be present
  1029. class mod:
  1030. def __init__(self):
  1031. self.__warningregistry__ = {'warning1':1,
  1032. 'warning2':2}
  1033. mod_inst = mod()
  1034. assert_warn_len_equal(mod=mod_inst,
  1035. n_in_context=2)
  1036. def _get_fresh_mod():
  1037. # Get this module, with warning registry empty
  1038. my_mod = sys.modules[__name__]
  1039. try:
  1040. my_mod.__warningregistry__.clear()
  1041. except AttributeError:
  1042. # will not have a __warningregistry__ unless warning has been
  1043. # raised in the module at some point
  1044. pass
  1045. return my_mod
  1046. def test_clear_and_catch_warnings():
  1047. # Initial state of module, no warnings
  1048. my_mod = _get_fresh_mod()
  1049. assert_equal(getattr(my_mod, '__warningregistry__', {}), {})
  1050. with clear_and_catch_warnings(modules=[my_mod]):
  1051. warnings.simplefilter('ignore')
  1052. warnings.warn('Some warning')
  1053. assert_equal(my_mod.__warningregistry__, {})
  1054. # Without specified modules, don't clear warnings during context.
  1055. # catch_warnings doesn't make an entry for 'ignore'.
  1056. with clear_and_catch_warnings():
  1057. warnings.simplefilter('ignore')
  1058. warnings.warn('Some warning')
  1059. assert_warn_len_equal(my_mod, 0)
  1060. # Manually adding two warnings to the registry:
  1061. my_mod.__warningregistry__ = {'warning1': 1,
  1062. 'warning2': 2}
  1063. # Confirm that specifying module keeps old warning, does not add new
  1064. with clear_and_catch_warnings(modules=[my_mod]):
  1065. warnings.simplefilter('ignore')
  1066. warnings.warn('Another warning')
  1067. assert_warn_len_equal(my_mod, 2)
  1068. # Another warning, no module spec it clears up registry
  1069. with clear_and_catch_warnings():
  1070. warnings.simplefilter('ignore')
  1071. warnings.warn('Another warning')
  1072. assert_warn_len_equal(my_mod, 0)
  1073. def test_suppress_warnings_module():
  1074. # Initial state of module, no warnings
  1075. my_mod = _get_fresh_mod()
  1076. assert_equal(getattr(my_mod, '__warningregistry__', {}), {})
  1077. def warn_other_module():
  1078. # Apply along axis is implemented in python; stacklevel=2 means
  1079. # we end up inside its module, not ours.
  1080. def warn(arr):
  1081. warnings.warn("Some warning 2", stacklevel=2)
  1082. return arr
  1083. np.apply_along_axis(warn, 0, [0])
  1084. # Test module based warning suppression:
  1085. assert_warn_len_equal(my_mod, 0)
  1086. with suppress_warnings() as sup:
  1087. sup.record(UserWarning)
  1088. # suppress warning from other module (may have .pyc ending),
  1089. # if apply_along_axis is moved, had to be changed.
  1090. sup.filter(module=np.lib.shape_base)
  1091. warnings.warn("Some warning")
  1092. warn_other_module()
  1093. # Check that the suppression did test the file correctly (this module
  1094. # got filtered)
  1095. assert_equal(len(sup.log), 1)
  1096. assert_equal(sup.log[0].message.args[0], "Some warning")
  1097. assert_warn_len_equal(my_mod, 0)
  1098. sup = suppress_warnings()
  1099. # Will have to be changed if apply_along_axis is moved:
  1100. sup.filter(module=my_mod)
  1101. with sup:
  1102. warnings.warn('Some warning')
  1103. assert_warn_len_equal(my_mod, 0)
  1104. # And test repeat works:
  1105. sup.filter(module=my_mod)
  1106. with sup:
  1107. warnings.warn('Some warning')
  1108. assert_warn_len_equal(my_mod, 0)
  1109. # Without specified modules
  1110. with suppress_warnings():
  1111. warnings.simplefilter('ignore')
  1112. warnings.warn('Some warning')
  1113. assert_warn_len_equal(my_mod, 0)
  1114. def test_suppress_warnings_type():
  1115. # Initial state of module, no warnings
  1116. my_mod = _get_fresh_mod()
  1117. assert_equal(getattr(my_mod, '__warningregistry__', {}), {})
  1118. # Test module based warning suppression:
  1119. with suppress_warnings() as sup:
  1120. sup.filter(UserWarning)
  1121. warnings.warn('Some warning')
  1122. assert_warn_len_equal(my_mod, 0)
  1123. sup = suppress_warnings()
  1124. sup.filter(UserWarning)
  1125. with sup:
  1126. warnings.warn('Some warning')
  1127. assert_warn_len_equal(my_mod, 0)
  1128. # And test repeat works:
  1129. sup.filter(module=my_mod)
  1130. with sup:
  1131. warnings.warn('Some warning')
  1132. assert_warn_len_equal(my_mod, 0)
  1133. # Without specified modules
  1134. with suppress_warnings():
  1135. warnings.simplefilter('ignore')
  1136. warnings.warn('Some warning')
  1137. assert_warn_len_equal(my_mod, 0)
  1138. def test_suppress_warnings_decorate_no_record():
  1139. sup = suppress_warnings()
  1140. sup.filter(UserWarning)
  1141. @sup
  1142. def warn(category):
  1143. warnings.warn('Some warning', category)
  1144. with warnings.catch_warnings(record=True) as w:
  1145. warnings.simplefilter("always")
  1146. warn(UserWarning) # should be supppressed
  1147. warn(RuntimeWarning)
  1148. assert_equal(len(w), 1)
  1149. def test_suppress_warnings_record():
  1150. sup = suppress_warnings()
  1151. log1 = sup.record()
  1152. with sup:
  1153. log2 = sup.record(message='Some other warning 2')
  1154. sup.filter(message='Some warning')
  1155. warnings.warn('Some warning')
  1156. warnings.warn('Some other warning')
  1157. warnings.warn('Some other warning 2')
  1158. assert_equal(len(sup.log), 2)
  1159. assert_equal(len(log1), 1)
  1160. assert_equal(len(log2),1)
  1161. assert_equal(log2[0].message.args[0], 'Some other warning 2')
  1162. # Do it again, with the same context to see if some warnings survived:
  1163. with sup:
  1164. log2 = sup.record(message='Some other warning 2')
  1165. sup.filter(message='Some warning')
  1166. warnings.warn('Some warning')
  1167. warnings.warn('Some other warning')
  1168. warnings.warn('Some other warning 2')
  1169. assert_equal(len(sup.log), 2)
  1170. assert_equal(len(log1), 1)
  1171. assert_equal(len(log2), 1)
  1172. assert_equal(log2[0].message.args[0], 'Some other warning 2')
  1173. # Test nested:
  1174. with suppress_warnings() as sup:
  1175. sup.record()
  1176. with suppress_warnings() as sup2:
  1177. sup2.record(message='Some warning')
  1178. warnings.warn('Some warning')
  1179. warnings.warn('Some other warning')
  1180. assert_equal(len(sup2.log), 1)
  1181. assert_equal(len(sup.log), 1)
  1182. def test_suppress_warnings_forwarding():
  1183. def warn_other_module():
  1184. # Apply along axis is implemented in python; stacklevel=2 means
  1185. # we end up inside its module, not ours.
  1186. def warn(arr):
  1187. warnings.warn("Some warning", stacklevel=2)
  1188. return arr
  1189. np.apply_along_axis(warn, 0, [0])
  1190. with suppress_warnings() as sup:
  1191. sup.record()
  1192. with suppress_warnings("always"):
  1193. for i in range(2):
  1194. warnings.warn("Some warning")
  1195. assert_equal(len(sup.log), 2)
  1196. with suppress_warnings() as sup:
  1197. sup.record()
  1198. with suppress_warnings("location"):
  1199. for i in range(2):
  1200. warnings.warn("Some warning")
  1201. warnings.warn("Some warning")
  1202. assert_equal(len(sup.log), 2)
  1203. with suppress_warnings() as sup:
  1204. sup.record()
  1205. with suppress_warnings("module"):
  1206. for i in range(2):
  1207. warnings.warn("Some warning")
  1208. warnings.warn("Some warning")
  1209. warn_other_module()
  1210. assert_equal(len(sup.log), 2)
  1211. with suppress_warnings() as sup:
  1212. sup.record()
  1213. with suppress_warnings("once"):
  1214. for i in range(2):
  1215. warnings.warn("Some warning")
  1216. warnings.warn("Some other warning")
  1217. warn_other_module()
  1218. assert_equal(len(sup.log), 2)
  1219. def test_tempdir():
  1220. with tempdir() as tdir:
  1221. fpath = os.path.join(tdir, 'tmp')
  1222. with open(fpath, 'w'):
  1223. pass
  1224. assert_(not os.path.isdir(tdir))
  1225. raised = False
  1226. try:
  1227. with tempdir() as tdir:
  1228. raise ValueError()
  1229. except ValueError:
  1230. raised = True
  1231. assert_(raised)
  1232. assert_(not os.path.isdir(tdir))
  1233. def test_temppath():
  1234. with temppath() as fpath:
  1235. with open(fpath, 'w'):
  1236. pass
  1237. assert_(not os.path.isfile(fpath))
  1238. raised = False
  1239. try:
  1240. with temppath() as fpath:
  1241. raise ValueError()
  1242. except ValueError:
  1243. raised = True
  1244. assert_(raised)
  1245. assert_(not os.path.isfile(fpath))
  1246. class my_cacw(clear_and_catch_warnings):
  1247. class_modules = (sys.modules[__name__],)
  1248. def test_clear_and_catch_warnings_inherit():
  1249. # Test can subclass and add default modules
  1250. my_mod = _get_fresh_mod()
  1251. with my_cacw():
  1252. warnings.simplefilter('ignore')
  1253. warnings.warn('Some warning')
  1254. assert_equal(my_mod.__warningregistry__, {})
  1255. @pytest.mark.skipif(not HAS_REFCOUNT, reason="Python lacks refcounts")
  1256. class TestAssertNoGcCycles:
  1257. """ Test assert_no_gc_cycles """
  1258. def test_passes(self):
  1259. def no_cycle():
  1260. b = []
  1261. b.append([])
  1262. return b
  1263. with assert_no_gc_cycles():
  1264. no_cycle()
  1265. assert_no_gc_cycles(no_cycle)
  1266. def test_asserts(self):
  1267. def make_cycle():
  1268. a = []
  1269. a.append(a)
  1270. a.append(a)
  1271. return a
  1272. with assert_raises(AssertionError):
  1273. with assert_no_gc_cycles():
  1274. make_cycle()
  1275. with assert_raises(AssertionError):
  1276. assert_no_gc_cycles(make_cycle)
  1277. @pytest.mark.slow
  1278. def test_fails(self):
  1279. """
  1280. Test that in cases where the garbage cannot be collected, we raise an
  1281. error, instead of hanging forever trying to clear it.
  1282. """
  1283. class ReferenceCycleInDel:
  1284. """
  1285. An object that not only contains a reference cycle, but creates new
  1286. cycles whenever it's garbage-collected and its __del__ runs
  1287. """
  1288. make_cycle = True
  1289. def __init__(self):
  1290. self.cycle = self
  1291. def __del__(self):
  1292. # break the current cycle so that `self` can be freed
  1293. self.cycle = None
  1294. if ReferenceCycleInDel.make_cycle:
  1295. # but create a new one so that the garbage collector has more
  1296. # work to do.
  1297. ReferenceCycleInDel()
  1298. try:
  1299. w = weakref.ref(ReferenceCycleInDel())
  1300. try:
  1301. with assert_raises(RuntimeError):
  1302. # this will be unable to get a baseline empty garbage
  1303. assert_no_gc_cycles(lambda: None)
  1304. except AssertionError:
  1305. # the above test is only necessary if the GC actually tried to free
  1306. # our object anyway, which python 2.7 does not.
  1307. if w() is not None:
  1308. pytest.skip("GC does not call __del__ on cyclic objects")
  1309. raise
  1310. finally:
  1311. # make sure that we stop creating reference cycles
  1312. ReferenceCycleInDel.make_cycle = False