test_continuous.py 91 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156115711581159116011611162116311641165116611671168116911701171117211731174117511761177117811791180118111821183118411851186118711881189119011911192119311941195119611971198119912001201120212031204120512061207120812091210121112121213121412151216121712181219122012211222122312241225122612271228122912301231123212331234123512361237123812391240124112421243124412451246124712481249125012511252125312541255125612571258125912601261126212631264126512661267126812691270127112721273127412751276127
71278127912801281128212831284128512861287128812891290129112921293129412951296129712981299130013011302130313041305130613071308130913101311131213131314131513161317131813191320132113221323132413251326132713281329133013311332133313341335133613371338133913401341134213431344134513461347134813491350135113521353135413551356135713581359136013611362136313641365136613671368136913701371137213731374137513761377137813791380138113821383138413851386138713881389139013911392139313941395139613971398139914001401140214031404140514061407140814091410141114121413141414151416141714181419142014211422142314241425142614271428142914301431143214331434143514361437143814391440144114421443144414451446144714481449145014511452145314541455145614571458145914601461146214631464146514661467146814691470147114721473147414751476147714781479148014811482148314841485148614871488148914901491149214931494149514961497149814991500150115021503150415051506150715081509151015111512151315141515151615171518151915201521152215231524152515261527152815291530153115321533153415351536153715381539154015411542154315441545154615471548154915501551155215531554155515561557155815591560156115621563156415651566156715681569157015711572157315741575157615771578157915801581158215831584158515861587158815891590159115921593159415951596159715981599160016011602160316041605160616071608160916101611161216131614161516161617161816191620162116221623162416251626162716281629163016311632163316341635163616371638163916401641164216431644164516461647164816491650165116521653165416551656165716581659166016611662166316641665166616671668166916701671167216731674167516761677167816791680168116821683168416851686168716881689169016911692169316941695169616971698169917001701170217031704170517061707170817091710171117121713171417151716171717181719172017211722172317241725172617271728172917301731173217331734173517361737173817391740174117421743174417451746174717481749175017511752175317541755175617571758175917601761176217631764176517661767176817691770177117721773177417751776177
71778177917801781178217831784178517861787178817891790179117921793179417951796179717981799180018011802180318041805180618071808180918101811181218131814181518161817181818191820182118221823182418251826182718281829183018311832183318341835183618371838183918401841184218431844184518461847184818491850185118521853185418551856185718581859186018611862186318641865186618671868186918701871187218731874187518761877187818791880188118821883188418851886188718881889189018911892189318941895189618971898189919001901190219031904190519061907190819091910191119121913191419151916191719181919192019211922192319241925192619271928192919301931193219331934193519361937193819391940194119421943194419451946194719481949195019511952195319541955195619571958195919601961196219631964196519661967196819691970197119721973197419751976197719781979198019811982198319841985198619871988198919901991199219931994199519961997199819992000200120022003200420052006200720082009201020112012201320142015201620172018201920202021202220232024202520262027202820292030203120322033203420352036203720382039204020412042204320442045204620472048204920502051205220532054205520562057205820592060206120622063206420652066206720682069207020712072207320742075207620772078207920802081208220832084208520862087208820892090209120922093209420952096209720982099210021012102210321042105210621072108210921102111211221132114211521162117211821192120212121222123212421252126212721282129213021312132213321342135213621372138213921402141214221432144214521462147214821492150215121522153215421552156215721582159216021612162216321642165216621672168216921702171217221732174217521762177217821792180218121822183218421852186218721882189219021912192219321942195219621972198219922002201220222032204220522062207
  1. import itertools as it
  2. import os
  3. import pickle
  4. from copy import deepcopy
  5. import numpy as np
  6. from numpy import inf
  7. import pytest
  8. from numpy.testing import assert_allclose, assert_equal
  9. from hypothesis import strategies, given, reproduce_failure, settings # noqa: F401
  10. import hypothesis.extra.numpy as npst
  11. from scipy import special
  12. from scipy import stats
  13. from scipy.stats._fit import _kolmogorov_smirnov
  14. from scipy.stats._ksstats import kolmogn
  15. from scipy.stats import qmc
  16. from scipy.stats._distr_params import distcont, distdiscrete
  17. from scipy.stats._distribution_infrastructure import (
  18. _Domain, _RealInterval, _Parameter, _Parameterization, _RealParameter,
  19. ContinuousDistribution, ShiftedScaledDistribution, _fiinfo,
  20. _generate_domain_support, Mixture)
  21. from scipy.stats._new_distributions import StandardNormal, _LogUniform, _Gamma
  22. from scipy.stats._new_distributions import DiscreteDistribution
  23. from scipy.stats import Normal, Logistic, Uniform, Binomial
  24. class Test_RealInterval:
  25. rng = np.random.default_rng(349849812549824)
  26. def test_iv(self):
  27. domain = _RealInterval(endpoints=('a', 'b'))
  28. message = "The endpoints of the distribution are defined..."
  29. with pytest.raises(TypeError, match=message):
  30. domain.get_numerical_endpoints(dict)
  31. @pytest.mark.parametrize('x', [rng.uniform(10, 10, size=(2, 3, 4)),
  32. -np.inf, np.pi])
  33. def test_contains_simple(self, x):
  34. # Test `contains` when endpoints are defined by constants
  35. a, b = -np.inf, np.pi
  36. domain = _RealInterval(endpoints=(a, b), inclusive=(False, True))
  37. assert_equal(domain.contains(x), (a < x) & (x <= b))
  38. @pytest.mark.slow
  39. @given(shapes=npst.mutually_broadcastable_shapes(num_shapes=3, min_side=0),
  40. inclusive_a=strategies.booleans(),
  41. inclusive_b=strategies.booleans(),
  42. data=strategies.data())
  43. def test_contains(self, shapes, inclusive_a, inclusive_b, data):
  44. # Test `contains` when endpoints are defined by parameters
  45. input_shapes, result_shape = shapes
  46. shape_a, shape_b, shape_x = input_shapes
  47. # Without defining min and max values, I spent forever trying to set
  48. # up a valid test without overflows or similar just drawing arrays.
  49. a_elements = dict(allow_nan=False, allow_infinity=False,
  50. min_value=-1e3, max_value=1)
  51. b_elements = dict(allow_nan=False, allow_infinity=False,
  52. min_value=2, max_value=1e3)
  53. a = data.draw(npst.arrays(npst.floating_dtypes(),
  54. shape_a, elements=a_elements))
  55. b = data.draw(npst.arrays(npst.floating_dtypes(),
  56. shape_b, elements=b_elements))
  57. # ensure some points are to the left, some to the right, and some
  58. # are exactly on the boundary
  59. d = b - a
  60. x = np.concatenate([np.linspace(a-d, a, 10),
  61. np.linspace(a, b, 10),
  62. np.linspace(b, b+d, 10)])
  63. # Domain is defined by two parameters, 'a' and 'b'
  64. domain = _RealInterval(endpoints=('a', 'b'),
  65. inclusive=(inclusive_a, inclusive_b))
  66. domain.define_parameters(_RealParameter('a', domain=_RealInterval()),
  67. _RealParameter('b', domain=_RealInterval()))
  68. # Check that domain and string evaluation give the same result
  69. res = domain.contains(x, dict(a=a, b=b))
  70. # Apparently, `np.float16([2]) < np.float32(2.0009766)` is False
  71. # but `np.float16([2]) < np.float32([2.0009766])` is True
  72. # dtype = np.result_type(a.dtype, b.dtype, x.dtype)
  73. # a, b, x = a.astype(dtype), b.astype(dtype), x.astype(dtype)
  74. # unclear whether we should be careful about this, since it will be
  75. # fixed with NEP50. Just do what makes the test pass.
  76. left_comparison = '<=' if inclusive_a else '<'
  77. right_comparison = '<=' if inclusive_b else '<'
  78. ref = eval(f'(a {left_comparison} x) & (x {right_comparison} b)')
  79. assert_equal(res, ref)
  80. @pytest.mark.parametrize("inclusive", list(it.product([True, False], repeat=2)))
  81. @pytest.mark.parametrize("a,b", [(0, 1), (3, 1)])
  82. def test_contains_function_endpoints(self, inclusive, a, b):
  83. # Test `contains` when endpoints are defined by functions.
  84. endpoints = (lambda a, b: (a - b) / 2, lambda a, b: (a + b) / 2)
  85. domain = _RealInterval(endpoints=endpoints, inclusive=inclusive)
  86. x = np.asarray([(a - 2*b)/2, (a - b)/2, a/2, (a + b)/2, (a + 2*b)/2])
  87. res = domain.contains(x, dict(a=a, b=b))
  88. numerical_endpoints = ((a - b) / 2, (a + b) / 2)
  89. assert numerical_endpoints == domain.get_numerical_endpoints(dict(a=a, b=b))
  90. alpha, beta = numerical_endpoints
  91. above_left = alpha <= x if inclusive[0] else alpha < x
  92. below_right = x <= beta if inclusive[1] else x < beta
  93. ref = above_left & below_right
  94. assert_equal(res, ref)
  95. @pytest.mark.parametrize('case', [
  96. (-np.inf, np.pi, False, True, r"(-\infty, \pi]"),
  97. ('a', 5, True, False, "[a, 5)")
  98. ])
  99. def test_str(self, case):
  100. domain = _RealInterval(endpoints=case[:2], inclusive=case[2:4])
  101. assert str(domain) == case[4]
  102. @pytest.mark.slow
  103. @given(a=strategies.one_of(
  104. strategies.decimals(allow_nan=False),
  105. strategies.characters(whitelist_categories="L"), # type: ignore[arg-type]
  106. strategies.sampled_from(list(_Domain.symbols))),
  107. b=strategies.one_of(
  108. strategies.decimals(allow_nan=False),
  109. strategies.characters(whitelist_categories="L"), # type: ignore[arg-type]
  110. strategies.sampled_from(list(_Domain.symbols))),
  111. inclusive_a=strategies.booleans(),
  112. inclusive_b=strategies.booleans(),
  113. )
  114. def test_str2(self, a, b, inclusive_a, inclusive_b):
  115. # I wrote this independently from the implementation of __str__, but
  116. # I imagine it looks pretty similar to __str__.
  117. a = _Domain.symbols.get(a, a)
  118. b = _Domain.symbols.get(b, b)
  119. left_bracket = '[' if inclusive_a else '('
  120. right_bracket = ']' if inclusive_b else ')'
  121. domain = _RealInterval(endpoints=(a, b),
  122. inclusive=(inclusive_a, inclusive_b))
  123. ref = f"{left_bracket}{a}, {b}{right_bracket}"
  124. assert str(domain) == ref
  125. def test_symbols_gh22137(self):
  126. # `symbols` was accidentally shared between instances originally
  127. # Check that this is no longer the case
  128. domain1 = _RealInterval(endpoints=(0, 1))
  129. domain2 = _RealInterval(endpoints=(0, 1))
  130. assert domain1.symbols is not domain2.symbols
  131. def draw_distribution_from_family(family, data, rng, proportions, min_side=0):
  132. # If the distribution has parameters, choose a parameterization and
  133. # draw broadcastable shapes for the parameter arrays.
  134. n_parameterizations = family._num_parameterizations()
  135. if n_parameterizations > 0:
  136. i = data.draw(strategies.integers(0, max_value=n_parameterizations-1))
  137. n_parameters = family._num_parameters(i)
  138. shapes, result_shape = data.draw(
  139. npst.mutually_broadcastable_shapes(num_shapes=n_parameters,
  140. min_side=min_side))
  141. dist = family._draw(shapes, rng=rng, proportions=proportions,
  142. i_parameterization=i)
  143. else:
  144. dist = family._draw(rng=rng)
  145. result_shape = tuple()
  146. # Draw a broadcastable shape for the arguments, and draw values for the
  147. # arguments.
  148. x_shape = data.draw(npst.broadcastable_shapes(result_shape,
  149. min_side=min_side))
  150. x = dist._variable.draw(x_shape, parameter_values=dist._parameters,
  151. proportions=proportions, rng=rng, region='typical')
  152. x_result_shape = np.broadcast_shapes(x_shape, result_shape)
  153. y_shape = data.draw(npst.broadcastable_shapes(x_result_shape,
  154. min_side=min_side))
  155. y = dist._variable.draw(y_shape, parameter_values=dist._parameters,
  156. proportions=proportions, rng=rng, region='typical')
  157. xy_result_shape = np.broadcast_shapes(y_shape, x_result_shape)
  158. p_domain = _RealInterval((0, 1), (True, True))
  159. p_var = _RealParameter('p', domain=p_domain)
  160. p = p_var.draw(x_shape, proportions=proportions, rng=rng)
  161. with np.errstate(divide='ignore', invalid='ignore'):
  162. logp = np.log(p)
  163. return dist, x, y, p, logp, result_shape, x_result_shape, xy_result_shape
  164. continuous_families = [
  165. StandardNormal,
  166. Normal,
  167. Logistic,
  168. Uniform,
  169. _LogUniform
  170. ]
  171. discrete_families = [
  172. Binomial,
  173. ]
  174. families = continuous_families + discrete_families
  175. class TestDistributions:
  176. @pytest.mark.fail_slow(60) # need to break up check_moment_funcs
  177. @settings(max_examples=20)
  178. @pytest.mark.parametrize('family', families)
  179. @given(data=strategies.data(), seed=strategies.integers(min_value=0))
  180. def test_support_moments_sample(self, family, data, seed):
  181. rng = np.random.default_rng(seed)
  182. # relative proportions of valid, endpoint, out of bounds, and NaN params
  183. proportions = (0.7, 0.1, 0.1, 0.1)
  184. tmp = draw_distribution_from_family(family, data, rng, proportions)
  185. dist, x, y, p, logp, result_shape, x_result_shape, xy_result_shape = tmp
  186. sample_shape = data.draw(npst.array_shapes(min_dims=0, min_side=0,
  187. max_side=20))
  188. with np.errstate(invalid='ignore', divide='ignore'):
  189. check_support(dist)
  190. check_moment_funcs(dist, result_shape) # this needs to get split up
  191. check_sample_shape_NaNs(dist, 'sample', sample_shape, result_shape, rng)
  192. qrng = qmc.Halton(d=1, seed=rng)
  193. check_sample_shape_NaNs(dist, 'sample', sample_shape, result_shape, qrng)
  194. @pytest.mark.fail_slow(10)
  195. @pytest.mark.parametrize('family', families)
  196. @pytest.mark.parametrize('func, methods, arg',
  197. [('entropy', {'log/exp', 'quadrature'}, None),
  198. ('logentropy', {'log/exp', 'quadrature'}, None),
  199. ('median', {'icdf'}, None),
  200. ('mode', {'optimization'}, None),
  201. ('mean', {'cache'}, None),
  202. ('variance', {'cache'}, None),
  203. ('skewness', {'cache'}, None),
  204. ('kurtosis', {'cache'}, None),
  205. ('pdf', {'log/exp'}, 'x'),
  206. ('logpdf', {'log/exp'}, 'x'),
  207. ('logcdf', {'log/exp', 'complement', 'quadrature'}, 'x'),
  208. ('cdf', {'log/exp', 'complement', 'quadrature'}, 'x'),
  209. ('logccdf', {'log/exp', 'complement', 'quadrature'}, 'x'),
  210. ('ccdf', {'log/exp', 'complement', 'quadrature'}, 'x'),
  211. ('ilogccdf', {'complement', 'inversion'}, 'logp'),
  212. ('iccdf', {'complement', 'inversion'}, 'p'),
  213. ])
  214. @settings(max_examples=20)
  215. @given(data=strategies.data(), seed=strategies.integers(min_value=0))
  216. def test_funcs(self, family, data, seed, func, methods, arg):
  217. if family == Uniform and func == 'mode':
  218. pytest.skip("Mode is not unique; `method`s disagree.")
  219. rng = np.random.default_rng(seed)
  220. # relative proportions of valid, endpoint, out of bounds, and NaN params
  221. proportions = (0.7, 0.1, 0.1, 0.1)
  222. tmp = draw_distribution_from_family(family, data, rng, proportions)
  223. dist, x, y, p, logp, result_shape, x_result_shape, xy_result_shape = tmp
  224. args = {'x': x, 'p': p, 'logp': p}
  225. with np.errstate(invalid='ignore', divide='ignore', over='ignore'):
  226. if arg is None:
  227. check_dist_func(dist, func, None, result_shape, methods)
  228. elif arg in args:
  229. check_dist_func(dist, func, args[arg], x_result_shape, methods)
  230. if func == 'variance':
  231. assert_allclose(dist.standard_deviation()**2, dist.variance())
  232. # invalid and divide are to be expected; maybe look into over
  233. with np.errstate(invalid='ignore', divide='ignore', over='ignore'):
  234. if not isinstance(dist, ShiftedScaledDistribution):
  235. if func == 'cdf':
  236. methods = {'quadrature'}
  237. check_cdf2(dist, False, x, y, xy_result_shape, methods)
  238. check_cdf2(dist, True, x, y, xy_result_shape, methods)
  239. elif func == 'ccdf':
  240. methods = {'addition'}
  241. check_ccdf2(dist, False, x, y, xy_result_shape, methods)
  242. check_ccdf2(dist, True, x, y, xy_result_shape, methods)
  243. def test_plot(self):
  244. try:
  245. import matplotlib.pyplot as plt
  246. except ImportError:
  247. return
  248. X = Uniform(a=0., b=1.)
  249. ax = X.plot()
  250. assert ax == plt.gca()
  251. @pytest.mark.parametrize('method_name', ['cdf', 'ccdf'])
  252. def test_complement_safe(self, method_name):
  253. X = stats.Normal()
  254. X.tol = 1e-12
  255. p = np.asarray([1e-4, 1e-3])
  256. func = getattr(X, method_name)
  257. ifunc = getattr(X, 'i'+method_name)
  258. x = ifunc(p, method='formula')
  259. p1 = func(x, method='complement_safe')
  260. p2 = func(x, method='complement')
  261. assert_equal(p1[1], p2[1])
  262. assert p1[0] != p2[0]
  263. assert_allclose(p1[0], p[0], rtol=X.tol)
  264. @pytest.mark.parametrize('method_name', ['cdf', 'ccdf'])
  265. def test_icomplement_safe(self, method_name):
  266. X = stats.Normal()
  267. X.tol = 1e-12
  268. p = np.asarray([1e-4, 1e-3])
  269. func = getattr(X, method_name)
  270. ifunc = getattr(X, 'i'+method_name)
  271. x1 = ifunc(p, method='complement_safe')
  272. x2 = ifunc(p, method='complement')
  273. assert_equal(x1[1], x2[1])
  274. assert x1[0] != x2[0]
  275. assert_allclose(func(x1[0]), p[0], rtol=X.tol)
  276. def test_subtraction_safe(self):
  277. X = stats.Normal()
  278. X.tol = 1e-12
  279. # Regular subtraction is fine in either tail (and of course, across tails)
  280. x = [-11, -10, 10, 11]
  281. y = [-10, -11, 11, 10]
  282. p0 = X.cdf(x, y, method='quadrature')
  283. p1 = X.cdf(x, y, method='subtraction_safe')
  284. p2 = X.cdf(x, y, method='subtraction')
  285. assert_equal(p2, p1)
  286. assert_allclose(p1, p0, rtol=X.tol)
  287. # Safe subtraction is needed in special cases
  288. x = np.asarray([-1e-20, -1e-21, 1e-20, 1e-21, -1e-20])
  289. y = np.asarray([-1e-21, -1e-20, 1e-21, 1e-20, 1e-20])
  290. p0 = X.pdf(0)*(y-x)
  291. p1 = X.cdf(x, y, method='subtraction_safe')
  292. p2 = X.cdf(x, y, method='subtraction')
  293. assert_equal(p2, 0)
  294. assert_allclose(p1, p0, rtol=X.tol)
  295. def test_logentropy_safe(self):
  296. # simulate an `entropy` calculation over/underflowing with extreme parameters
  297. class _Normal(stats.Normal):
  298. def _entropy_formula(self, **params):
  299. out = np.asarray(super()._entropy_formula(**params))
  300. out[0] = 0
  301. out[-1] = np.inf
  302. return out
  303. X = _Normal(sigma=[1, 2, 3])
  304. with np.errstate(divide='ignore'):
  305. res1 = X.logentropy(method='logexp_safe')
  306. res2 = X.logentropy(method='logexp')
  307. ref = X.logentropy(method='quadrature')
  308. i_fl = [0, -1] # first and last
  309. assert np.isinf(res2[i_fl]).all()
  310. assert res1[1] == res2[1]
  311. # quadrature happens to be perfectly accurate on some platforms
  312. # assert res1[1] != ref[1]
  313. assert_equal(res1[i_fl], ref[i_fl])
  314. def test_logcdf2_safe(self):
  315. # test what happens when 2-arg `cdf` underflows
  316. X = stats.Normal(sigma=[1, 2, 3])
  317. x = [-301, 1, 300]
  318. y = [-300, 2, 301]
  319. with np.errstate(divide='ignore'):
  320. res1 = X.logcdf(x, y, method='logexp_safe')
  321. res2 = X.logcdf(x, y, method='logexp')
  322. ref = X.logcdf(x, y, method='quadrature')
  323. i_fl = [0, -1] # first and last
  324. assert np.isinf(res2[i_fl]).all()
  325. assert res1[1] == res2[1]
  326. # quadrature happens to be perfectly accurate on some platforms
  327. # assert res1[1] != ref[1]
  328. assert_equal(res1[i_fl], ref[i_fl])
  329. @pytest.mark.parametrize('method_name', ['logcdf', 'logccdf'])
  330. def test_logexp_safe(self, method_name):
  331. # test what happens when `cdf`/`ccdf` underflows
  332. X = stats.Normal(sigma=2)
  333. x = [-301, 1] if method_name == 'logcdf' else [301, 1]
  334. func = getattr(X, method_name)
  335. with np.errstate(divide='ignore'):
  336. res1 = func(x, method='logexp_safe')
  337. res2 = func(x, method='logexp')
  338. ref = func(x, method='quadrature')
  339. assert res1[0] == ref[0]
  340. assert res1[0] != res2[0]
  341. assert res1[1] == res2[1]
  342. assert res1[1] != ref[1]
  343. def check_sample_shape_NaNs(dist, fname, sample_shape, result_shape, rng):
  344. full_shape = sample_shape + result_shape
  345. if fname == 'sample':
  346. sample_method = dist.sample
  347. methods = {'inverse_transform'}
  348. if dist._overrides(f'_{fname}_formula') and not isinstance(rng, qmc.QMCEngine):
  349. methods.add('formula')
  350. for method in methods:
  351. res = sample_method(sample_shape, method=method, rng=rng)
  352. valid_parameters = np.broadcast_to(get_valid_parameters(dist),
  353. res.shape)
  354. assert_equal(res.shape, full_shape)
  355. np.testing.assert_equal(res.dtype, dist._dtype)
  356. if full_shape == ():
  357. # NumPy random makes a distinction between a 0d array and a scalar.
  358. # In stats, we consistently turn 0d arrays into scalars, so
  359. # maintain that behavior here. (With Array API arrays, this will
  360. # change.)
  361. assert np.isscalar(res)
  362. assert np.all(np.isfinite(res[valid_parameters]))
  363. assert_equal(res[~valid_parameters], np.nan)
  364. sample1 = sample_method(sample_shape, method=method, rng=42)
  365. sample2 = sample_method(sample_shape, method=method, rng=42)
  366. if not isinstance(dist, DiscreteDistribution):
  367. # The idea is that it's very unlikely that the random sample
  368. # for a randomly chosen seed will match that for seed 42,
  369. # but it is not so unlikely if `dist` is a discrete distribution.
  370. assert not np.any(np.equal(res, sample1))
  371. assert_equal(sample1, sample2)
  372. def check_support(dist):
  373. a, b = dist.support()
  374. check_nans_and_edges(dist, 'support', None, a)
  375. check_nans_and_edges(dist, 'support', None, b)
  376. assert a.shape == dist._shape
  377. assert b.shape == dist._shape
  378. assert a.dtype == dist._dtype
  379. assert b.dtype == dist._dtype
  380. def check_dist_func(dist, fname, arg, result_shape, methods):
  381. # Check that all computation methods of all distribution functions agree
  382. # with one another, effectively testing the correctness of the generic
  383. # computation methods and confirming the consistency of specific
  384. # distributions with their pdf/logpdf.
  385. args = tuple() if arg is None else (arg,)
  386. methods = methods.copy()
  387. if "cache" in methods:
  388. # If "cache" is specified before the value has been evaluated, it
  389. # raises an error. After the value is evaluated, it will succeed.
  390. with pytest.raises(NotImplementedError):
  391. getattr(dist, fname)(*args, method="cache")
  392. ref = getattr(dist, fname)(*args)
  393. check_nans_and_edges(dist, fname, arg, ref)
  394. # Remove this after fixing `draw`
  395. tol_override = {'atol': 1e-15}
  396. # Mean can be 0, which makes logmean -inf.
  397. if fname in {'logmean', 'mean', 'logskewness', 'skewness'}:
  398. tol_override = {'atol': 1e-15}
  399. elif fname in {'mode'}:
  400. # can only expect about half of machine precision for optimization
  401. # because math
  402. tol_override = {'atol': 1e-6}
  403. elif fname in {'logcdf'}: # gh-22276
  404. tol_override = {'rtol': 2e-7}
  405. if dist._overrides(f'_{fname}_formula'):
  406. methods.add('formula')
  407. np.testing.assert_equal(ref.shape, result_shape)
  408. # Until we convert to array API, let's do the familiar thing:
  409. # 0d things are scalars, not arrays
  410. if result_shape == tuple():
  411. assert np.isscalar(ref)
  412. for method in methods:
  413. res = getattr(dist, fname)(*args, method=method)
  414. if 'log' in fname:
  415. np.testing.assert_allclose(np.exp(res), np.exp(ref),
  416. **tol_override)
  417. else:
  418. np.testing.assert_allclose(res, ref, **tol_override)
  419. # for now, make sure dtypes are consistent; later, we can check whether
  420. # they are correct.
  421. np.testing.assert_equal(res.dtype, ref.dtype)
  422. np.testing.assert_equal(res.shape, result_shape)
  423. if result_shape == tuple():
  424. assert np.isscalar(res)
  425. def check_cdf2(dist, log, x, y, result_shape, methods):
  426. # Specialized test for 2-arg cdf since the interface is a bit different
  427. # from the other methods. Here, we'll use 1-arg cdf as a reference, and
  428. # since we have already checked 1-arg cdf in `check_nans_and_edges`, this
  429. # checks the equivalent of both `check_dist_func` and
  430. # `check_nans_and_edges`.
  431. methods = methods.copy()
  432. if log:
  433. if dist._overrides('_logcdf2_formula'):
  434. methods.add('formula')
  435. if dist._overrides('_logcdf_formula') or dist._overrides('_logccdf_formula'):
  436. methods.add('subtraction')
  437. if (dist._overrides('_cdf_formula')
  438. or dist._overrides('_ccdf_formula')):
  439. methods.add('log/exp')
  440. else:
  441. if dist._overrides('_cdf2_formula'):
  442. methods.add('formula')
  443. if dist._overrides('_cdf_formula') or dist._overrides('_ccdf_formula'):
  444. methods.add('subtraction')
  445. if (dist._overrides('_logcdf_formula')
  446. or dist._overrides('_logccdf_formula')):
  447. methods.add('log/exp')
  448. ref = dist.cdf(y) - dist.cdf(x)
  449. np.testing.assert_equal(ref.shape, result_shape)
  450. if result_shape == tuple():
  451. assert np.isscalar(ref)
  452. for method in methods:
  453. if isinstance(dist, DiscreteDistribution):
  454. message = ("Two argument cdf functions are currently only supported for "
  455. "continuous distributions.")
  456. with pytest.raises(NotImplementedError, match=message):
  457. res = (np.exp(dist.logcdf(x, y, method=method)) if log
  458. else dist.cdf(x, y, method=method))
  459. continue
  460. res = (np.exp(dist.logcdf(x, y, method=method)) if log
  461. else dist.cdf(x, y, method=method))
  462. np.testing.assert_allclose(res, ref, atol=1e-14)
  463. if log:
  464. np.testing.assert_equal(res.dtype, (ref + 0j).dtype)
  465. else:
  466. np.testing.assert_equal(res.dtype, ref.dtype)
  467. np.testing.assert_equal(res.shape, result_shape)
  468. if result_shape == tuple():
  469. assert np.isscalar(res)
  470. def check_ccdf2(dist, log, x, y, result_shape, methods):
  471. # Specialized test for 2-arg ccdf since the interface is a bit different
  472. # from the other methods. Could be combined with check_cdf2 above, but
  473. # writing it separately is simpler.
  474. methods = methods.copy()
  475. if dist._overrides(f'_{"log" if log else ""}ccdf2_formula'):
  476. methods.add('formula')
  477. ref = dist.cdf(x) + dist.ccdf(y)
  478. np.testing.assert_equal(ref.shape, result_shape)
  479. if result_shape == tuple():
  480. assert np.isscalar(ref)
  481. for method in methods:
  482. message = ("Two argument cdf functions are currently only supported for "
  483. "continuous distributions.")
  484. if isinstance(dist, DiscreteDistribution):
  485. with pytest.raises(NotImplementedError, match=message):
  486. res = (np.exp(dist.logccdf(x, y, method=method)) if log
  487. else dist.ccdf(x, y, method=method))
  488. continue
  489. res = (np.exp(dist.logccdf(x, y, method=method)) if log
  490. else dist.ccdf(x, y, method=method))
  491. np.testing.assert_allclose(res, ref, atol=1e-14)
  492. np.testing.assert_equal(res.dtype, ref.dtype)
  493. np.testing.assert_equal(res.shape, result_shape)
  494. if result_shape == tuple():
  495. assert np.isscalar(res)
def check_nans_and_edges(dist, fname, arg, res):
    """Check NaN propagation and edge-case values of a distribution result.

    Parameters
    ----------
    dist : distribution instance
        The distribution whose method produced `res`.
    fname : str
        Name of the method that produced `res`, e.g. ``'pdf'``, ``'logcdf'``,
        ``'icdf'`` or ``'moment'``.
    arg : array_like or None
        The argument that was passed to the method, or ``None`` for methods
        without an argument (``check_moment_funcs`` passes ``None``).
    res : ndarray
        The result to check.

    `res` must be NaN wherever the parameters are invalid or `arg` is NaN.
    For the methods handled explicitly below, `res` must also take the
    expected value where `arg` lies outside of, or on the boundary of, its
    domain. Finally, `res` must be finite for valid parameters and arguments
    strictly inside the domain (except for the methods listed in
    ``exclude``).
    """
    valid_parameters = get_valid_parameters(dist)
    # Domain of the argument: probabilities for the inverse functions,
    # log-probabilities for the inverse log functions, and the domain of the
    # random variable otherwise.
    if fname in {'icdf', 'iccdf'}:
        arg_domain = _RealInterval(endpoints=(0, 1), inclusive=(True, True))
    elif fname in {'ilogcdf', 'ilogccdf'}:
        arg_domain = _RealInterval(endpoints=(-inf, 0), inclusive=(True, True))
    else:
        arg_domain = dist._variable.domain

    # Classify every argument (inside / on an endpoint / outside / NaN) and
    # broadcast those masks against the parameter-validity mask.
    classified_args = classify_arg(dist, arg, arg_domain)
    valid_parameters, *classified_args = np.broadcast_arrays(valid_parameters,
                                                             *classified_args)
    valid_arg, endpoint_arg, outside_arg, nan_arg = classified_args
    all_valid = valid_arg & valid_parameters

    # Check NaN pattern and edge cases
    assert_equal(res[~valid_parameters], np.nan)
    assert_equal(res[nan_arg], np.nan)

    # Support endpoints, broadcast so they can be indexed by the masks below.
    a, b = dist.support()
    a = np.broadcast_to(a, res.shape)
    b = np.broadcast_to(b, res.shape)

    # Outside/endpoint masks restricted to valid parameters; -1 refers to the
    # left side of the domain and +1 to the right (see `classify_arg`).
    outside_arg_minus = (outside_arg == -1) & valid_parameters
    outside_arg_plus = (outside_arg == 1) & valid_parameters
    endpoint_arg_minus = (endpoint_arg == -1) & valid_parameters
    endpoint_arg_plus = (endpoint_arg == 1) & valid_parameters

    is_discrete = isinstance(dist, DiscreteDistribution)
    # Writing this independently of how they are set in the distribution
    # infrastructure. That is very compact; this is very verbose.
    if fname in {'logpdf'}:
        assert_equal(res[outside_arg_minus], -np.inf)
        assert_equal(res[outside_arg_plus], -np.inf)
        # Endpoints that are excluded from the domain (`~valid_arg`).
        # NOTE(review): discrete distributions expect +inf here rather than
        # -inf; confirm this is intended.
        ref = -np.inf if not is_discrete else np.inf
        assert_equal(res[endpoint_arg_minus & ~valid_arg], ref)
        assert_equal(res[endpoint_arg_plus & ~valid_arg], ref)
    elif fname in {'pdf'}:
        assert_equal(res[outside_arg_minus], 0)
        assert_equal(res[outside_arg_plus], 0)
        # Endpoints excluded from the domain; see the NOTE in the logpdf case.
        ref = 0 if not is_discrete else np.inf
        assert_equal(res[endpoint_arg_minus & ~valid_arg], ref)
        assert_equal(res[endpoint_arg_plus & ~valid_arg], ref)
    elif fname in {'logcdf'} and not is_discrete:
        assert_equal(res[outside_arg_minus], -inf)
        assert_equal(res[outside_arg_plus], 0)
        assert_equal(res[endpoint_arg_minus], -inf)
        assert_equal(res[endpoint_arg_plus], 0)
    elif fname in {'cdf'} and not is_discrete:
        assert_equal(res[outside_arg_minus], 0)
        assert_equal(res[outside_arg_plus], 1)
        assert_equal(res[endpoint_arg_minus], 0)
        assert_equal(res[endpoint_arg_plus], 1)
    elif fname in {'logccdf'} and not is_discrete:
        assert_equal(res[outside_arg_minus], 0)
        assert_equal(res[outside_arg_plus], -inf)
        assert_equal(res[endpoint_arg_minus], 0)
        assert_equal(res[endpoint_arg_plus], -inf)
    elif fname in {'ccdf'} and not is_discrete:
        assert_equal(res[outside_arg_minus], 1)
        assert_equal(res[outside_arg_plus], 0)
        assert_equal(res[endpoint_arg_minus], 1)
        assert_equal(res[endpoint_arg_plus], 0)
    elif fname in {'ilogcdf', 'icdf'} and not is_discrete:
        # Arguments outside the probability domain give NaN; the domain's
        # endpoints map to the left/right support endpoints, respectively.
        assert_equal(res[outside_arg == -1], np.nan)
        assert_equal(res[outside_arg == 1], np.nan)
        assert_equal(res[endpoint_arg == -1], a[endpoint_arg == -1])
        assert_equal(res[endpoint_arg == 1], b[endpoint_arg == 1])
    elif fname in {'ilogccdf', 'iccdf'} and not is_discrete:
        # Same as above, but the complementary functions map the domain's
        # endpoints to the support endpoints in reverse order.
        assert_equal(res[outside_arg == -1], np.nan)
        assert_equal(res[outside_arg == 1], np.nan)
        assert_equal(res[endpoint_arg == -1], b[endpoint_arg == -1])
        assert_equal(res[endpoint_arg == 1], a[endpoint_arg == 1])

    # Methods whose interior values are not required to be finite by the
    # check below.
    exclude = {'logmean', 'mean', 'logskewness', 'skewness', 'support'}
    if isinstance(dist, DiscreteDistribution):
        exclude.update({'pdf', 'logpdf'})

    if (
            fname not in exclude
            and not (isinstance(dist, Binomial)
                     and np.any((dist.n == 0) | (dist.p == 0) | (dist.p == 1)))):
        # This can fail in degenerate case where Binomial distribution is a point
        # distribution. Further on, we could factor out an is_degenerate function
        # for the tests, or think about storing info about degeneracy in the
        # instances.
        assert np.isfinite(res[all_valid & (endpoint_arg == 0)]).all()
def check_moment_funcs(dist, result_shape):
    """Cross-check all computation methods of ``dist.moment``.

    Raw, central and standardized moments of orders 0-5 are computed by
    quadrature and used as references for the ``'cache'``, ``'formula'``,
    ``'general'``, ``'transform'`` and ``'normalize'`` methods (and
    ``'quadrature_icdf'`` for `Normal`). A method that is expected to be
    unavailable in the current cache state must raise
    ``NotImplementedError``. Each reference must have shape `result_shape`.
    """
    # Check that all computation methods of all distribution functions agree
    # with one another, effectively testing the correctness of the generic
    # computation methods and confirming the consistency of specific
    # distributions with their pdf/logpdf.

    atol = 1e-9  # make this tighter (e.g. 1e-13) after fixing `draw`

    def check(order, kind, method=None, ref=None, success=True):
        # If `success`, the requested method must agree with `ref` (absolute
        # tolerance scaled by 10**order) and match its shape; otherwise it
        # must raise NotImplementedError.
        if success:
            res = dist.moment(order, kind, method=method)
            assert_allclose(res, ref, atol=atol*10**order)
            assert res.shape == ref.shape
        else:
            with pytest.raises(NotImplementedError):
                dist.moment(order, kind, method=method)

    def has_formula(order, kind):
        # True if `dist` overrides `_moment_{kind}_formula` and that formula
        # handles `order` (per its optional `orders` attribute; orders 0-5
        # are assumed when the attribute is absent).
        formula_name = f'_moment_{kind}_formula'
        overrides = dist._overrides(formula_name)
        if not overrides:
            return False
        formula = getattr(dist, formula_name)
        orders = getattr(formula, 'orders', set(range(6)))
        return order in orders

    # The checks below depend on cache state, so the statement order and the
    # `reset_cache` calls are significant.
    dist.reset_cache()

    ### Check Raw Moments ###

    for i in range(6):
        check(i, 'raw', 'cache', success=False)  # not cached yet
        ref = dist.moment(i, 'raw', method='quadrature')
        check_nans_and_edges(dist, 'moment', None, ref)
        assert ref.shape == result_shape
        check(i, 'raw','cache', ref, success=True)  # cached now
        check(i, 'raw', 'formula', ref, success=has_formula(i, 'raw'))
        check(i, 'raw', 'general', ref, success=(i == 0))
        if dist.__class__ == stats.Normal:
            check(i, 'raw', 'quadrature_icdf', ref, success=True)

    # Clearing caches to better check their behavior
    dist.reset_cache()

    # If we have central or standard moment formulas, or if there are
    # values in their cache, we can use method='transform'
    dist.moment(0, 'central')  # build up the cache
    dist.moment(1, 'central')
    for i in range(2, 6):
        ref = dist.moment(i, 'raw', method='quadrature')
        check(i, 'raw', 'transform', ref,
              success=has_formula(i, 'central') or has_formula(i, 'standardized'))
        dist.moment(i, 'central')  # build up the cache
        check(i, 'raw', 'transform', ref)

    dist.reset_cache()

    ### Check Central Moments ###

    for i in range(6):
        check(i, 'central', 'cache', success=False)
        ref = dist.moment(i, 'central', method='quadrature')
        assert ref.shape == result_shape
        check(i, 'central', 'cache', ref, success=True)
        check(i, 'central', 'formula', ref, success=has_formula(i, 'central'))
        check(i, 'central', 'general', ref, success=i <= 1)
        if dist.__class__ == stats.Normal:
            check(i, 'central', 'quadrature_icdf', ref, success=True)
        if not (dist.__class__ == stats.Uniform and i == 5):
            # Quadrature is not super accurate for 5th central moment when the
            # support is really big. Skip this one failing test. We need to come
            # up with a better system of skipping individual failures w/ hypothesis.
            check(i, 'central', 'transform', ref,
                  success=has_formula(i, 'raw') or (i <= 1))
        if not has_formula(i, 'raw'):
            # Caching the raw moment makes 'transform' available.
            dist.moment(i, 'raw')
            check(i, 'central', 'transform', ref)

    # 'normalize' is expected to fail wherever the variance is zero.
    variance = dist.variance()
    dist.reset_cache()

    # If we have standard moment formulas, or if there are
    # values in their cache, we can use method='normalize'
    dist.moment(0, 'standardized')  # build up the cache
    dist.moment(1, 'standardized')
    dist.moment(2, 'standardized')
    for i in range(3, 6):
        ref = dist.moment(i, 'central', method='quadrature')
        check(i, 'central', 'normalize', ref,
              success=has_formula(i, 'standardized') and not np.any(variance == 0))
        dist.moment(i, 'standardized')  # build up the cache
        check(i, 'central', 'normalize', ref, success=not np.any(variance == 0))

    ### Check Standardized Moments ###

    var = dist.moment(2, 'central', method='quadrature')
    dist.reset_cache()

    for i in range(6):
        check(i, 'standardized', 'cache', success=False)
        # Reference: central moment divided by var**(i/2).
        ref = dist.moment(i, 'central', method='quadrature') / var ** (i / 2)
        assert ref.shape == result_shape
        check(i, 'standardized', 'formula', ref,
              success=has_formula(i, 'standardized'))
        if not (
                isinstance(dist, Binomial)
                and np.any((dist.n == 0) | (dist.p == 0) | (dist.p == 1))
        ):
            # This test will fail for degenerate case where binomial distribution
            # is a point distribution.
            check(i, 'standardized', 'general', ref, success=i <= 2)
        check(i, 'standardized', 'normalize', ref)

    if isinstance(dist, ShiftedScaledDistribution):
        # logmoment is not fully fleshed out; no need to test
        # ShiftedScaledDistribution here
        return

    # logmoment is not very accurate, and it's not public, so skip for now
    # ### Check Against _logmoment ###
    # logmean = dist._logmoment(1, logcenter=-np.inf)
    # for i in range(6):
    #     ref = np.exp(dist._logmoment(i, logcenter=-np.inf))
    #     assert_allclose(dist.moment(i, 'raw'), ref, atol=atol*10**i)
    #
    #     ref = np.exp(dist._logmoment(i, logcenter=logmean))
    #     assert_allclose(dist.moment(i, 'central'), ref, atol=atol*10**i)
    #
    #     ref = np.exp(dist._logmoment(i, logcenter=logmean, standardized=True))
    #     assert_allclose(dist.moment(i, 'standardized'), ref, atol=atol*10**i)
  688. @pytest.mark.parametrize('family', (Normal,))
  689. @pytest.mark.parametrize('x_shape', [tuple(), (2, 3)])
  690. @pytest.mark.parametrize('dist_shape', [tuple(), (4, 1)])
  691. @pytest.mark.parametrize('fname', ['sample'])
  692. @pytest.mark.parametrize('rng_type', [np.random.Generator, qmc.Halton, qmc.Sobol])
  693. def test_sample_against_cdf(family, dist_shape, x_shape, fname, rng_type):
  694. rng = np.random.default_rng(842582438235635)
  695. num_parameters = family._num_parameters()
  696. if dist_shape and num_parameters == 0:
  697. pytest.skip("Distribution can't have a shape without parameters.")
  698. dist = family._draw(dist_shape, rng)
  699. n = 1024
  700. sample_size = (n,) + x_shape
  701. sample_array_shape = sample_size + dist_shape
  702. if fname == 'sample':
  703. sample_method = dist.sample
  704. if rng_type != np.random.Generator:
  705. rng = rng_type(d=1, seed=rng)
  706. x = sample_method(sample_size, rng=rng)
  707. assert x.shape == sample_array_shape
  708. # probably should give `axis` argument to ks_1samp, review that separately
  709. statistic = _kolmogorov_smirnov(dist, x, axis=0)
  710. pvalue = kolmogn(x.shape[0], statistic, cdf=False)
  711. p_threshold = 0.01
  712. num_pvalues = pvalue.size
  713. num_small_pvalues = np.sum(pvalue < p_threshold)
  714. assert num_small_pvalues < p_threshold * num_pvalues
  715. def get_valid_parameters(dist):
  716. # Given a distribution, return a logical array that is true where all
  717. # distribution parameters are within their respective domains. The code
  718. # here is probably quite similar to that used to form the `_invalid`
  719. # attribute of the distribution, but this was written about a week later
  720. # without referring to that code, so it is a somewhat independent check.
  721. # Get all parameter values and `_Parameter` objects
  722. parameter_values = dist._parameters
  723. parameters = {}
  724. for parameterization in dist._parameterizations:
  725. parameters.update(parameterization.parameters)
  726. all_valid = np.ones(dist._shape, dtype=bool)
  727. for name, value in parameter_values.items():
  728. if name not in parameters: # cached value not part of parameterization
  729. continue
  730. parameter = parameters[name]
  731. # Check that the numerical endpoints and inclusivity attribute
  732. # agree with the `contains` method about which parameter values are
  733. # within the domain.
  734. a, b = parameter.domain.get_numerical_endpoints(
  735. parameter_values=parameter_values)
  736. a_included, b_included = parameter.domain.inclusive
  737. valid = (a <= value) if a_included else a < value
  738. valid &= (value <= b) if b_included else value < b
  739. assert_equal(valid, parameter.domain.contains(
  740. value, parameter_values=parameter_values))
  741. # Form `all_valid` mask that is True where *all* parameters are valid
  742. all_valid &= valid
  743. # Check that the `all_valid` mask formed here is the complement of the
  744. # `dist._invalid` mask stored by the infrastructure
  745. assert_equal(~all_valid, dist._invalid)
  746. return all_valid
  747. def classify_arg(dist, arg, arg_domain):
  748. if arg is None:
  749. valid_args = np.ones(dist._shape, dtype=bool)
  750. endpoint_args = np.zeros(dist._shape, dtype=bool)
  751. outside_args = np.zeros(dist._shape, dtype=bool)
  752. nan_args = np.zeros(dist._shape, dtype=bool)
  753. return valid_args, endpoint_args, outside_args, nan_args
  754. a, b = arg_domain.get_numerical_endpoints(
  755. parameter_values=dist._parameters)
  756. a, b, arg = np.broadcast_arrays(a, b, arg)
  757. a_included, b_included = arg_domain.inclusive
  758. inside = (a <= arg) if a_included else a < arg
  759. inside &= (arg <= b) if b_included else arg < b
  760. # TODO: add `supported` method and check here
  761. on = np.zeros(a.shape, dtype=int)
  762. on[a == arg] = -1
  763. on[b == arg] = 1
  764. outside = np.zeros(a.shape, dtype=int)
  765. outside[(arg < a) if a_included else arg <= a] = -1
  766. outside[(b < arg) if b_included else b <= arg] = 1
  767. nan = np.isnan(arg)
  768. return inside, on, outside, nan
  769. def test_input_validation():
  770. class Test(ContinuousDistribution):
  771. _variable = _RealParameter('x', domain=_RealInterval())
  772. message = ("The `Test` distribution family does not accept parameters, "
  773. "but parameters `{'a'}` were provided.")
  774. with pytest.raises(ValueError, match=message):
  775. Test(a=1, )
  776. message = "Attribute `tol` of `Test` must be a positive float, if specified."
  777. with pytest.raises(ValueError, match=message):
  778. Test(tol=np.asarray([]))
  779. with pytest.raises(ValueError, match=message):
  780. Test(tol=[1, 2, 3])
  781. with pytest.raises(ValueError, match=message):
  782. Test(tol=np.nan)
  783. with pytest.raises(ValueError, match=message):
  784. Test(tol=-1)
  785. message = ("Argument `order` of `Test.moment` must be a "
  786. "finite, positive integer.")
  787. with pytest.raises(ValueError, match=message):
  788. Test().moment(-1)
  789. with pytest.raises(ValueError, match=message):
  790. Test().moment(np.inf)
  791. message = "Argument `kind` of `Test.moment` must be one of..."
  792. with pytest.raises(ValueError, match=message):
  793. Test().moment(2, kind='coconut')
  794. class Test2(ContinuousDistribution):
  795. _p1 = _RealParameter('c', domain=_RealInterval())
  796. _p2 = _RealParameter('d', domain=_RealInterval())
  797. _parameterizations = [_Parameterization(_p1, _p2)]
  798. _variable = _RealParameter('x', domain=_RealInterval())
  799. message = ("The provided parameters `{a}` do not match a supported "
  800. "parameterization of the `Test2` distribution family.")
  801. with pytest.raises(ValueError, match=message):
  802. Test2(a=1)
  803. message = ("The `Test2` distribution family requires parameters, but none "
  804. "were provided.")
  805. with pytest.raises(ValueError, match=message):
  806. Test2()
  807. message = ("The parameters `{c, d}` provided to the `Test2` "
  808. "distribution family cannot be broadcast to the same shape.")
  809. with pytest.raises(ValueError, match=message):
  810. Test2(c=[1, 2], d=[1, 2, 3])
  811. message = ("The argument provided to `Test2.pdf` cannot be be broadcast to "
  812. "the same shape as the distribution parameters.")
  813. with pytest.raises(ValueError, match=message):
  814. dist = Test2(c=[1, 2, 3], d=[1, 2, 3])
  815. dist.pdf([1, 2])
  816. message = "Parameter `c` must be of real dtype."
  817. with pytest.raises(TypeError, match=message):
  818. Test2(c=[1, object()], d=[1, 2])
  819. message = "Parameter `convention` of `Test2.kurtosis` must be one of..."
  820. with pytest.raises(ValueError, match=message):
  821. dist = Test2(c=[1, 2, 3], d=[1, 2, 3])
  822. dist.kurtosis(convention='coconut')
  823. def test_rng_deepcopy_pickle():
  824. # test behavior of `rng` attribute and copy behavior
  825. kwargs = dict(a=[-1, 2], b=10)
  826. dist1 = Uniform(**kwargs)
  827. dist2 = deepcopy(dist1)
  828. dist3 = pickle.loads(pickle.dumps(dist1))
  829. res1, res2, res3 = dist1.sample(), dist2.sample(), dist3.sample()
  830. assert np.all(res2 != res1)
  831. assert np.all(res3 != res1)
  832. res1, res2, res3 = dist1.sample(rng=42), dist2.sample(rng=42), dist3.sample(rng=42)
  833. assert np.all(res2 == res1)
  834. assert np.all(res3 == res1)
class TestAttributes:
    """Tests of the configurable attributes of distribution instances."""

    def test_cache_policy(self):
        """`cache_policy` controls whether results are stored for method='cache'.

        The statements below form a sequence of cache states, so their order
        matters.
        """
        dist = StandardNormal(cache_policy="no_cache")
        # TODO: make error message more appropriate
        message = "`StandardNormal` does not provide an accurate implementation of the "
        # With "no_cache", nothing is cached before or after an evaluation.
        with pytest.raises(NotImplementedError, match=message):
            dist.mean(method='cache')
        mean = dist.mean()
        with pytest.raises(NotImplementedError, match=message):
            dist.mean(method='cache')

        # TODO: add to enum
        # With `cache_policy = None`, evaluations populate the cache.
        dist.cache_policy = None
        with pytest.raises(NotImplementedError, match=message):
            dist.mean(method='cache')
        mean = dist.mean()  # method is 'formula' by default
        cached_mean = dist.mean(method='cache')
        assert_equal(cached_mean, mean)

        # cache is overridden by latest evaluation
        quadrature_mean = dist.mean(method='quadrature')
        cached_mean = dist.mean(method='cache')
        assert_equal(cached_mean, quadrature_mean)
        assert not np.all(mean == quadrature_mean)

        # We can turn the cache off, and it won't change, but the old cache is
        # still available
        dist.cache_policy = "no_cache"
        mean = dist.mean(method='formula')
        cached_mean = dist.mean(method='cache')
        assert_equal(cached_mean, quadrature_mean)
        assert not np.all(mean == quadrature_mean)

        # `reset_cache` discards the stored value.
        dist.reset_cache()
        with pytest.raises(NotImplementedError, match=message):
            dist.mean(method='cache')

        message = "Attribute `cache_policy` of `StandardNormal`..."
        with pytest.raises(ValueError, match=message):
            dist.cache_policy = "invalid"

    def test_tol(self):
        """`tol` is validated and controls the accuracy of numerical methods."""
        x = 3.
        X = stats.Normal()

        message = "Attribute `tol` of `StandardNormal` must..."
        with pytest.raises(ValueError, match=message):
            X.tol = -1.
        with pytest.raises(ValueError, match=message):
            X.tol = (0.1,)
        with pytest.raises(ValueError, match=message):
            X.tol = np.nan

        # A tighter tolerance gives a quadrature result closer to the reference.
        X1 = stats.Normal(tol=1e-1)
        X2 = stats.Normal(tol=1e-12)
        ref = X.cdf(x)
        res1 = X1.cdf(x, method='quadrature')
        res2 = X2.cdf(x, method='quadrature')
        assert_allclose(res1, ref, rtol=X1.tol)
        assert_allclose(res2, ref, rtol=X2.tol)
        assert abs(res1 - ref) > abs(res2 - ref)

        # Swap the tolerances; numerical inversion must now favor X1.
        p = 0.99
        X1.tol, X2.tol = X2.tol, X1.tol
        ref = X.icdf(p)
        res1 = X1.icdf(p, method='inversion')
        res2 = X2.icdf(p, method='inversion')
        assert_allclose(res1, ref, rtol=X1.tol)
        assert_allclose(res2, ref, rtol=X2.tol)
        assert abs(res2 - ref) > abs(res1 - ref)

    def test_iv_policy(self):
        """`validation_policy='skip_all'` bypasses input/parameter validation."""
        X = Uniform(a=0, b=1)
        assert X.pdf(2) == 0

        # With validation skipped, the out-of-support point no longer gives 0.
        X.validation_policy = 'skip_all'
        assert X.pdf(np.asarray(2.)) == 1

        # Tests _set_invalid_nan
        a, b = np.asarray(1.), np.asarray(0.)  # invalid parameters
        X = Uniform(a=a, b=b, validation_policy='skip_all')
        # Invalid parameters are not replaced by NaN; presumably the raw
        # formula 1/(b - a) is returned.
        assert X.pdf(np.asarray(2.)) == -1

        # Tests _set_invalid_nan_property
        class MyUniform(Uniform):
            def _entropy_formula(self, *args, **kwargs):
                return 'incorrect'

            def _moment_raw_formula(self, order, **params):
                return 'incorrect'

        X = MyUniform(a=a, b=b, validation_policy='skip_all')
        assert X.entropy() == 'incorrect'

        # Tests _validate_order_kind
        assert X.moment(kind='raw', order=-1) == 'incorrect'

        # Test input validation
        message = "Attribute `validation_policy` of `MyUniform`..."
        with pytest.raises(ValueError, match=message):
            X.validation_policy = "invalid"

    def test_shapes(self):
        """Parameter attributes are readable but cannot be modified in place."""
        X = stats.Normal(mu=1, sigma=2)
        Y = stats.Normal(mu=[2], sigma=3)

        # Check that attributes are available as expected
        assert X.mu == 1
        assert X.sigma == 2
        assert Y.mu[0] == 2
        assert Y.sigma[0] == 3

        # Trying to set an attribute raises
        # message depends on Python version
        with pytest.raises(AttributeError):
            X.mu = 2

        # Trying to mutate an attribute really mutates a copy
        Y.mu[0] = 10
        assert Y.mu[0] == 2
  934. class TestMakeDistribution:
    @pytest.mark.parametrize('i, distdata', enumerate(distcont + distdiscrete))
    def test_rv_generic(self, i, distdata):
        """`make_distribution` applied to a legacy distribution agrees with it.

        For every entry in `distcont` and `distdiscrete`, the new-style
        distribution `X` is compared with the frozen legacy distribution `Y`:
        support, entropy, median, moments, density/mass, (log) CDF/CCDF,
        inverse functions and (for continuous distributions) seeded samples.
        Individual comparisons are skipped for distributions in the
        tolerance-specific skip sets below.
        """
        distname = distdata[0]

        slow = {'argus', 'exponpow', 'exponweib', 'genexpon', 'gompertz', 'halfgennorm',
                'johnsonsb', 'kappa4', 'ksone', 'kstwo', 'kstwobign', 'norminvgauss',
                'powerlognorm', 'powernorm', 'recipinvgauss', 'studentized_range',
                'vonmises_line',  # continuous
                'betanbinom', 'logser', 'skellam', 'zipf'}  # discrete
        if not int(os.environ.get('SCIPY_XSLOW', '0')) and distname in slow:
            pytest.skip('Skipping as XSLOW')

        # Note: these return early, so they are reported as passed, not skipped.
        if distname in {  # skip these distributions
            'levy_stable',  # private methods seem to require >= 1d args
            'vonmises',  # circular distribution; shouldn't work
            'poisson_binom',  # vector shape parameter
            'hypergeom',  # distribution functions need interpolation
            'nchypergeom_fisher',  # distribution functions need interpolation
            'nchypergeom_wallenius',  # distribution functions need interpolation
        }:
            return

        # skip single test, mostly due to slight disagreement
        custom_tolerances = {'ksone': 1e-5, 'kstwo': 1e-5}  # discontinuous PDF
        skip_entropy = {'kstwobign', 'pearson3'}  # tolerance issue
        skip_skewness = {'exponpow', 'ksone', 'nchypergeom_wallenius'}  # tolerance
        skip_kurtosis = {'chi', 'exponpow', 'invgamma',  # tolerance
                         'johnsonsb', 'ksone', 'kstwo',  # tolerance
                         'nchypergeom_wallenius'}  # tolerance
        skip_logccdf = {'arcsine', 'skewcauchy', 'trapezoid', 'triang'}  # tolerance
        # Raw-moment order -> distributions whose comparison is skipped.
        skip_raw = {2: {'alpha', 'foldcauchy', 'halfcauchy', 'levy', 'levy_l'},
                    3: {'pareto'},  # stats.pareto is just wrong
                    4: {'invgamma'}}  # tolerance issue
        skip_standardized = {'exponpow', 'ksone'}  # tolerances

        dist = getattr(stats, distname)
        # Map shape-parameter names to the example values in `distdata[1]`.
        params = dict(zip(dist.shapes.split(', '), distdata[1])) if dist.shapes else {}
        rng = np.random.default_rng(7548723590230982)

        CustomDistribution = stats.make_distribution(dist)
        X = CustomDistribution(**params)
        Y = dist(**params)
        x = X.sample(shape=10, rng=rng)
        p = X.cdf(x)
        rtol = custom_tolerances.get(distname, 1e-7)
        atol = 1e-12

        # Divide-by-zero / invalid-value floating point warnings are silenced
        # for the comparisons below.
        with np.errstate(divide='ignore', invalid='ignore'):
            m, v, s, k = Y.stats('mvsk')
            assert_allclose(X.support(), Y.support())
            if distname not in skip_entropy:
                assert_allclose(X.entropy(), Y.entropy(), rtol=rtol)
            if isinstance(Y, stats.rv_discrete):
                # some continuous distributions have trouble with `logentropy` because
                # it uses complex numbers
                assert_allclose(np.exp(X.logentropy()), Y.entropy(), rtol=rtol)
            assert_allclose(X.median(), Y.median(), rtol=rtol)
            assert_allclose(X.mean(), m, rtol=rtol, atol=atol)
            assert_allclose(X.variance(), v, rtol=rtol, atol=atol)
            if distname not in skip_skewness:
                assert_allclose(X.skewness(), s, rtol=rtol, atol=atol)
            if distname not in skip_kurtosis:
                assert_allclose(X.kurtosis(convention='excess'), k,
                                rtol=rtol, atol=atol)
            # Continuous distributions expose pdf; discrete ones expose pmf.
            if isinstance(dist, stats.rv_continuous):
                assert_allclose(X.logpdf(x), Y.logpdf(x), rtol=rtol)
                assert_allclose(X.pdf(x), Y.pdf(x), rtol=rtol)
            else:
                assert_allclose(X.logpmf(x), Y.logpmf(x), rtol=rtol)
                assert_allclose(X.pmf(x), Y.pmf(x), rtol=rtol)
            assert_allclose(X.logcdf(x), Y.logcdf(x), rtol=rtol)
            assert_allclose(X.cdf(x), Y.cdf(x), rtol=rtol)
            if distname not in skip_logccdf:
                assert_allclose(X.logccdf(x), Y.logsf(x), rtol=rtol)
            assert_allclose(X.ccdf(x), Y.sf(x), rtol=rtol)

            # old infrastructure convention for ppf(p=0) and isf(p=1) is different than
            # new infrastructure. Adjust reference values accordingly.
            a, _ = Y.support()
            ref_ppf = Y.ppf(p)
            ref_ppf[p == 0] = a
            ref_isf = Y.isf(p)
            ref_isf[p == 1] = a
            assert_allclose(X.icdf(p), ref_ppf, rtol=rtol)
            assert_allclose(X.iccdf(p), ref_isf, rtol=rtol)

            for order in range(5):
                if distname not in skip_raw.get(order, {}):
                    assert_allclose(X.moment(order, kind='raw'),
                                    Y.moment(order), rtol=rtol, atol=atol)
            # Only order 3 (skewness, 'mvsk'[2]) is compared here.
            for order in range(3, 4):
                if distname not in skip_standardized:
                    assert_allclose(X.moment(order, kind='standardized'),
                                    Y.stats('mvsk'[order-1]), rtol=rtol, atol=atol)
            if isinstance(dist, stats.rv_continuous):
                # For discrete distributions, these won't agree at the far left end
                # of the support, and the new infrastructure is slow there (for now).
                seed = 845298245687345
                assert_allclose(X.sample(shape=10, rng=seed),
                                Y.rvs(size=10,
                                      random_state=np.random.default_rng(seed)),
                                rtol=rtol)
  1029. def test_custom(self):
  1030. rng = np.random.default_rng(7548723590230982)
  1031. class MyLogUniform:
  1032. @property
  1033. def __make_distribution_version__(self):
  1034. return "1.16.0"
  1035. @property
  1036. def parameters(self):
  1037. return {'a': {'endpoints': (0, np.inf), 'inclusive': (False, False)},
  1038. 'b': {'endpoints': ('a', np.inf), 'inclusive': (False, False)}}
  1039. @property
  1040. def support(self):
  1041. return {'endpoints': ('a', 'b')}
  1042. def pdf(self, x, a, b):
  1043. return 1 / (x * (np.log(b) - np.log(a)))
  1044. def sample(self, shape, *, a, b, rng=None):
  1045. p = rng.uniform(size=shape)
  1046. return np.exp(np.log(a) + p * (np.log(b) - np.log(a)))
  1047. def moment(self, order, kind='raw', *, a, b):
  1048. if order == 1 and kind == 'raw':
  1049. # quadrature is perfectly accurate here; add 1e-10 error so we
  1050. # can tell the difference between the two
  1051. return (b - a) / np.log(b/a) + 1e-10
  1052. LogUniform = stats.make_distribution(MyLogUniform())
  1053. X = LogUniform(a=1., b=np.e)
  1054. Y = stats.exp(Uniform(a=0., b=1.))
  1055. # pre-2.0 support is not needed for much longer, so let's just test with 2.0+
  1056. if np.__version__ >= "2.0":
  1057. assert str(X) == f"MyLogUniform(a=1.0, b={np.e})"
  1058. assert repr(X) == f"MyLogUniform(a=np.float64(1.0), b=np.float64({np.e}))"
  1059. x = X.sample(shape=10, rng=rng)
  1060. p = X.cdf(x)
  1061. assert_allclose(X.support(), Y.support())
  1062. assert_allclose(X.entropy(), Y.entropy())
  1063. assert_allclose(X.median(), Y.median())
  1064. assert_allclose(X.logpdf(x), Y.logpdf(x))
  1065. assert_allclose(X.pdf(x), Y.pdf(x))
  1066. assert_allclose(X.logcdf(x), Y.logcdf(x))
  1067. assert_allclose(X.cdf(x), Y.cdf(x))
  1068. assert_allclose(X.logccdf(x), Y.logccdf(x))
  1069. assert_allclose(X.ccdf(x), Y.ccdf(x))
  1070. assert_allclose(X.icdf(p), Y.icdf(p))
  1071. assert_allclose(X.iccdf(p), Y.iccdf(p))
  1072. for kind in ['raw', 'central', 'standardized']:
  1073. for order in range(5):
  1074. assert_allclose(X.moment(order, kind=kind),
  1075. Y.moment(order, kind=kind))
  1076. # Confirm that the `sample` and `moment` methods are overriden as expected
  1077. sample_formula = X.sample(shape=10, rng=0, method='formula')
  1078. sample_inverse = X.sample(shape=10, rng=0, method='inverse_transform')
  1079. assert_allclose(sample_formula, sample_inverse)
  1080. assert not np.all(sample_formula == sample_inverse)
  1081. assert_allclose(X.mean(method='formula'), X.mean(method='quadrature'))
  1082. assert not X.mean(method='formula') == X.mean(method='quadrature')
    # pdf and cdf formulas below can warn on boundary of support in some cases.
    # See https://github.com/scipy/scipy/pull/22560#discussion_r1962763840.
    @pytest.mark.slow
    @pytest.mark.filterwarnings("ignore::RuntimeWarning")
    @pytest.mark.parametrize("c", [-1, 0, 1, np.asarray([-2.1, -1., 0., 1., 2.1])])
    def test_custom_variable_support(self, c):
        """A custom distribution whose support depends on its parameters.

        A hand-written generalized extreme value distribution (support
        endpoints given as callables of the parameters) is compared with
        ``make_distribution(stats.genextreme)`` for scalar and array `c`.
        """
        rng = np.random.default_rng(7548723590230982)

        class MyGenExtreme:
            @property
            def __make_distribution_version__(self):
                return "1.16.0"

            @property
            def parameters(self):
                return {
                    'c': {'endpoints': (-np.inf, np.inf), 'inclusive': (False, False)},
                    'mu': {'endpoints': (-np.inf, np.inf), 'inclusive': (False, False)},
                    'sigma': {'endpoints': (0, np.inf), 'inclusive': (False, False)}
                }

            @property
            def support(self):
                # Endpoints are functions of the parameters: for c < 0 the
                # left endpoint is mu + sigma/c; for c > 0 the right endpoint
                # is mu + sigma/c; otherwise the support is unbounded there.
                # NOTE(review): `np.empty_like(c)` inherits c's dtype; this
                # assumes the infrastructure passes floating-point arrays (an
                # integer array could not hold -inf/inf). TODO confirm.
                def left(*, c, mu, sigma):
                    c, mu, sigma = np.broadcast_arrays(c, mu, sigma)
                    result = np.empty_like(c)
                    result[c >= 0] = -np.inf
                    result[c < 0] = mu[c < 0] + sigma[c < 0] / c[c < 0]
                    return result[()]

                def right(*, c, mu, sigma):
                    c, mu, sigma = np.broadcast_arrays(c, mu, sigma)
                    result = np.empty_like(c)
                    result[c <= 0] = np.inf
                    result[c > 0] = mu[c > 0] + sigma[c > 0] / c[c > 0]
                    return result[()]

                return {"endpoints": (left, right), "inclusive": (False, False)}

            def pdf(self, x, *, c, mu, sigma):
                # t = exp(-(x - mu)/sigma) if c == 0,
                # else (1 - c (x - mu)/sigma)**(1/c); pdf = t**(1 - c) exp(-t) / sigma
                x, c, mu, sigma = np.broadcast_arrays(x, c, mu, sigma)
                t = np.empty_like(x)
                mask = (c == 0)
                t[mask] = np.exp(-(x[mask] - mu[mask])/sigma[mask])
                t[~mask] = (
                    1 - c[~mask]*(x[~mask] - mu[~mask])/sigma[~mask]
                )**(1/c[~mask])
                result = 1/sigma * t**(1 - c)*np.exp(-t)
                return result[()]

            def cdf(self, x, *, c, mu, sigma):
                # Same `t` as in `pdf`; cdf = exp(-t).
                x, c, mu, sigma = np.broadcast_arrays(x, c, mu, sigma)
                t = np.empty_like(x)
                mask = (c == 0)
                t[mask] = np.exp(-(x[mask] - mu[mask])/sigma[mask])
                t[~mask] = (
                    1 - c[~mask]*(x[~mask] - mu[~mask])/sigma[~mask]
                )**(1/c[~mask])
                return np.exp(-t)[()]

        GenExtreme1 = stats.make_distribution(MyGenExtreme())
        GenExtreme2 = stats.make_distribution(stats.genextreme)

        # `stats.genextreme` has no `mu`/`sigma`, so compare with mu=0, sigma=1.
        X1 = GenExtreme1(c=c, mu=0, sigma=1)
        X2 = GenExtreme2(c=c)

        x = X1.sample(shape=10, rng=rng)
        p = X1.cdf(x)

        assert_allclose(X1.support(), X2.support())
        assert_allclose(X1.entropy(), X2.entropy(), rtol=5e-6)
        assert_allclose(X1.median(), X2.median())
        assert_allclose(X1.logpdf(x), X2.logpdf(x))
        assert_allclose(X1.pdf(x), X2.pdf(x))
        assert_allclose(X1.logcdf(x), X2.logcdf(x))
        assert_allclose(X1.cdf(x), X2.cdf(x))
        assert_allclose(X1.logccdf(x), X2.logccdf(x))
        assert_allclose(X1.ccdf(x), X2.ccdf(x))
        assert_allclose(X1.icdf(p), X2.icdf(p))
        assert_allclose(X1.iccdf(p), X2.iccdf(p))
  1152. @pytest.mark.slow
  1153. @pytest.mark.parametrize("a", [0.5, np.asarray([0.5, 1.0, 2.0, 4.0, 8.0])])
  1154. @pytest.mark.parametrize("b", [0.5, np.asarray([0.5, 1.0, 2.0, 4.0, 8.0])])
  1155. def test_custom_multiple_parameterizations(self, a, b):
  1156. rng = np.random.default_rng(7548723590230982)
  1157. class MyBeta:
  1158. @property
  1159. def __make_distribution_version__(self):
  1160. return "1.16.0"
  1161. @property
  1162. def parameters(self):
  1163. return (
  1164. {"a": (0, np.inf), "b": (0, np.inf)},
  1165. {"mu": (0, 1), "nu": (0, np.inf)},
  1166. )
  1167. def process_parameters(self, a=None, b=None, mu=None, nu=None):
  1168. if a is not None and b is not None and mu is None and nu is None:
  1169. nu = a + b
  1170. mu = a / nu
  1171. else:
  1172. a = mu * nu
  1173. b = nu - a
  1174. return {"a": a, "b": b, "mu": mu, "nu": nu}
  1175. @property
  1176. def support(self):
  1177. return {'endpoints': (0, 1)}
  1178. def pdf(self, x, a, b, mu, nu):
  1179. return special._ufuncs._beta_pdf(x, a, b)
  1180. def cdf(self, x, a, b, mu, nu):
  1181. return special.betainc(a, b, x)
  1182. Beta = stats.make_distribution(stats.beta)
  1183. MyBeta = stats.make_distribution(MyBeta())
  1184. mu = a / (a + b)
  1185. nu = a + b
  1186. X = MyBeta(a=a, b=b)
  1187. Y = MyBeta(mu=mu, nu=nu)
  1188. Z = Beta(a=a, b=b)
  1189. x = Z.sample(shape=10, rng=rng)
  1190. p = Z.cdf(x)
  1191. assert_allclose(X.support(), Z.support())
  1192. assert_allclose(X.median(), Z.median())
  1193. assert_allclose(X.pdf(x), Z.pdf(x))
  1194. assert_allclose(X.cdf(x), Z.cdf(x))
  1195. assert_allclose(X.ccdf(x), Z.ccdf(x))
  1196. assert_allclose(X.icdf(p), Z.icdf(p))
  1197. assert_allclose(X.iccdf(p), Z.iccdf(p))
  1198. assert_allclose(Y.support(), Z.support())
  1199. assert_allclose(Y.median(), Z.median())
  1200. assert_allclose(Y.pdf(x), Z.pdf(x))
  1201. assert_allclose(Y.cdf(x), Z.cdf(x))
  1202. assert_allclose(Y.ccdf(x), Z.ccdf(x))
  1203. assert_allclose(Y.icdf(p), Z.icdf(p))
  1204. assert_allclose(Y.iccdf(p), Z.iccdf(p))
  1205. def test_input_validation(self):
  1206. message = '`levy_stable` is not supported.'
  1207. with pytest.raises(NotImplementedError, match=message):
  1208. stats.make_distribution(stats.levy_stable)
  1209. message = '`vonmises` is not supported.'
  1210. with pytest.raises(NotImplementedError, match=message):
  1211. stats.make_distribution(stats.vonmises)
  1212. message = "The argument must be an instance of..."
  1213. with pytest.raises(ValueError, match=message):
  1214. stats.make_distribution(object())
  1215. def test_repr_str_docs(self):
  1216. from scipy.stats._distribution_infrastructure import _distribution_names
  1217. for dist in _distribution_names.keys():
  1218. assert hasattr(stats, dist)
  1219. dist = stats.make_distribution(stats.gamma)
  1220. assert str(dist(a=2)) == "Gamma(a=2.0)"
  1221. if np.__version__ >= "2":
  1222. assert repr(dist(a=2)) == "Gamma(a=np.float64(2.0))"
  1223. assert 'Gamma' in dist.__doc__
  1224. dist = stats.make_distribution(stats.halfgennorm)
  1225. assert str(dist(beta=2)) == "HalfGeneralizedNormal(beta=2.0)"
  1226. if np.__version__ >= "2":
  1227. assert repr(dist(beta=2)) == "HalfGeneralizedNormal(beta=np.float64(2.0))"
  1228. assert 'HalfGeneralizedNormal' in dist.__doc__
  1229. class TestTransforms:
  1230. def test_ContinuousDistribution_only(self):
  1231. X = stats.Binomial(n=10, p=0.5)
  1232. # This is applied at the top level TransformedDistribution,
  1233. # so testing one subclass is enough
  1234. message = "Transformations are currently only supported for continuous RVs."
  1235. with pytest.raises(NotImplementedError, match=message):
  1236. stats.exp(X)
  1237. def test_truncate(self):
  1238. rng = np.random.default_rng(81345982345826)
  1239. lb = rng.random((3, 1))
  1240. ub = rng.random((3, 1))
  1241. lb, ub = np.minimum(lb, ub), np.maximum(lb, ub)
  1242. Y = stats.truncate(Normal(), lb=lb, ub=ub)
  1243. Y0 = stats.truncnorm(lb, ub)
  1244. y = Y0.rvs((3, 10), random_state=rng)
  1245. p = Y0.cdf(y)
  1246. assert_allclose(Y.logentropy(), np.log(Y0.entropy() + 0j))
  1247. assert_allclose(Y.entropy(), Y0.entropy())
  1248. assert_allclose(Y.median(), Y0.ppf(0.5))
  1249. assert_allclose(Y.mean(), Y0.mean())
  1250. assert_allclose(Y.variance(), Y0.var())
  1251. assert_allclose(Y.standard_deviation(), np.sqrt(Y0.var()))
  1252. assert_allclose(Y.skewness(), Y0.stats('s'))
  1253. assert_allclose(Y.kurtosis(), Y0.stats('k') + 3)
  1254. assert_allclose(Y.support(), Y0.support())
  1255. assert_allclose(Y.pdf(y), Y0.pdf(y))
  1256. assert_allclose(Y.cdf(y), Y0.cdf(y))
  1257. assert_allclose(Y.ccdf(y), Y0.sf(y))
  1258. assert_allclose(Y.icdf(p), Y0.ppf(p))
  1259. assert_allclose(Y.iccdf(p), Y0.isf(p))
  1260. assert_allclose(Y.logpdf(y), Y0.logpdf(y))
  1261. assert_allclose(Y.logcdf(y), Y0.logcdf(y))
  1262. assert_allclose(Y.logccdf(y), Y0.logsf(y))
  1263. assert_allclose(Y.ilogcdf(np.log(p)), Y0.ppf(p))
  1264. assert_allclose(Y.ilogccdf(np.log(p)), Y0.isf(p))
  1265. sample = Y.sample(10)
  1266. assert np.all((sample > lb) & (sample < ub))
  1267. @pytest.mark.fail_slow(10)
  1268. @given(data=strategies.data(), seed=strategies.integers(min_value=0))
  1269. def test_loc_scale(self, data, seed):
  1270. # Need tests with negative scale
  1271. rng = np.random.default_rng(seed)
  1272. class TransformedNormal(ShiftedScaledDistribution):
  1273. def __init__(self, *args, **kwargs):
  1274. super().__init__(StandardNormal(), *args, **kwargs)
  1275. tmp = draw_distribution_from_family(
  1276. TransformedNormal, data, rng, proportions=(1, 0, 0, 0), min_side=1)
  1277. dist, x, y, p, logp, result_shape, x_result_shape, xy_result_shape = tmp
  1278. loc = dist.loc
  1279. scale = dist.scale
  1280. dist0 = StandardNormal()
  1281. dist_ref = stats.norm(loc=loc, scale=scale)
  1282. x0 = (x - loc) / scale
  1283. y0 = (y - loc) / scale
  1284. a, b = dist.support()
  1285. a0, b0 = dist0.support()
  1286. assert_allclose(a, a0 + loc)
  1287. assert_allclose(b, b0 + loc)
  1288. with np.errstate(invalid='ignore', divide='ignore'):
  1289. assert_allclose(np.exp(dist.logentropy()), dist.entropy())
  1290. assert_allclose(dist.entropy(), dist_ref.entropy())
  1291. assert_allclose(dist.median(), dist0.median() + loc)
  1292. assert_allclose(dist.mode(), dist0.mode() + loc)
  1293. assert_allclose(dist.mean(), dist0.mean() + loc)
  1294. assert_allclose(dist.variance(), dist0.variance() * scale**2)
  1295. assert_allclose(dist.standard_deviation(), dist.variance()**0.5)
  1296. assert_allclose(dist.skewness(), dist0.skewness() * np.sign(scale))
  1297. assert_allclose(dist.kurtosis(), dist0.kurtosis())
  1298. assert_allclose(dist.logpdf(x), dist0.logpdf(x0) - np.log(scale))
  1299. assert_allclose(dist.pdf(x), dist0.pdf(x0) / scale)
  1300. assert_allclose(dist.logcdf(x), dist0.logcdf(x0))
  1301. assert_allclose(dist.cdf(x), dist0.cdf(x0))
  1302. assert_allclose(dist.logccdf(x), dist0.logccdf(x0))
  1303. assert_allclose(dist.ccdf(x), dist0.ccdf(x0))
  1304. assert_allclose(dist.logcdf(x, y), dist0.logcdf(x0, y0))
  1305. assert_allclose(dist.cdf(x, y), dist0.cdf(x0, y0))
  1306. assert_allclose(dist.logccdf(x, y), dist0.logccdf(x0, y0))
  1307. assert_allclose(dist.ccdf(x, y), dist0.ccdf(x0, y0))
  1308. assert_allclose(dist.ilogcdf(logp), dist0.ilogcdf(logp)*scale + loc)
  1309. assert_allclose(dist.icdf(p), dist0.icdf(p)*scale + loc)
  1310. assert_allclose(dist.ilogccdf(logp), dist0.ilogccdf(logp)*scale + loc)
  1311. assert_allclose(dist.iccdf(p), dist0.iccdf(p)*scale + loc)
  1312. for i in range(1, 5):
  1313. assert_allclose(dist.moment(i, 'raw'), dist_ref.moment(i))
  1314. assert_allclose(dist.moment(i, 'central'),
  1315. dist0.moment(i, 'central') * scale**i)
  1316. assert_allclose(dist.moment(i, 'standardized'),
  1317. dist0.moment(i, 'standardized') * np.sign(scale)**i)
  1318. # Transform back to the original distribution using all arithmetic
  1319. # operations; check that it behaves as expected.
  1320. dist = (dist - 2*loc) + loc
  1321. dist = dist/scale**2 * scale
  1322. z = np.zeros(dist._shape) # compact broadcasting
  1323. a, b = dist.support()
  1324. a0, b0 = dist0.support()
  1325. assert_allclose(a, a0 + z)
  1326. assert_allclose(b, b0 + z)
  1327. with np.errstate(invalid='ignore', divide='ignore'):
  1328. assert_allclose(dist.logentropy(), dist0.logentropy() + z)
  1329. assert_allclose(dist.entropy(), dist0.entropy() + z)
  1330. assert_allclose(dist.median(), dist0.median() + z)
  1331. assert_allclose(dist.mode(), dist0.mode() + z)
  1332. assert_allclose(dist.mean(), dist0.mean() + z)
  1333. assert_allclose(dist.variance(), dist0.variance() + z)
  1334. assert_allclose(dist.standard_deviation(), dist0.standard_deviation() + z)
  1335. assert_allclose(dist.skewness(), dist0.skewness() + z)
  1336. assert_allclose(dist.kurtosis(), dist0.kurtosis() + z)
  1337. assert_allclose(dist.logpdf(x), dist0.logpdf(x)+z)
  1338. assert_allclose(dist.pdf(x), dist0.pdf(x) + z)
  1339. assert_allclose(dist.logcdf(x), dist0.logcdf(x) + z)
  1340. assert_allclose(dist.cdf(x), dist0.cdf(x) + z)
  1341. assert_allclose(dist.logccdf(x), dist0.logccdf(x) + z)
  1342. assert_allclose(dist.ccdf(x), dist0.ccdf(x) + z)
  1343. assert_allclose(dist.ilogcdf(logp), dist0.ilogcdf(logp) + z)
  1344. assert_allclose(dist.icdf(p), dist0.icdf(p) + z)
  1345. assert_allclose(dist.ilogccdf(logp), dist0.ilogccdf(logp) + z)
  1346. assert_allclose(dist.iccdf(p), dist0.iccdf(p) + z)
  1347. for i in range(1, 5):
  1348. assert_allclose(dist.moment(i, 'raw'), dist0.moment(i, 'raw'))
  1349. assert_allclose(dist.moment(i, 'central'), dist0.moment(i, 'central'))
  1350. assert_allclose(dist.moment(i, 'standardized'),
  1351. dist0.moment(i, 'standardized'))
  1352. # These are tough to compare because of the way the shape works
  1353. # rng = np.random.default_rng(seed)
  1354. # rng0 = np.random.default_rng(seed)
  1355. # assert_allclose(dist.sample(x_result_shape, rng=rng),
  1356. # dist0.sample(x_result_shape, rng=rng0) * scale + loc)
  1357. # Should also try to test fit, plot?
  1358. @pytest.mark.fail_slow(5)
  1359. @pytest.mark.parametrize('exp_pow', ['exp', 'pow'])
  1360. def test_exp_pow(self, exp_pow):
  1361. rng = np.random.default_rng(81345982345826)
  1362. mu = rng.random((3, 1))
  1363. sigma = rng.random((3, 1))
  1364. X = Normal()*sigma + mu
  1365. if exp_pow == 'exp':
  1366. Y = stats.exp(X)
  1367. else:
  1368. Y = np.e ** X
  1369. Y0 = stats.lognorm(sigma, scale=np.exp(mu))
  1370. y = Y0.rvs((3, 10), random_state=rng)
  1371. p = Y0.cdf(y)
  1372. assert_allclose(Y.logentropy(), np.log(Y0.entropy()))
  1373. assert_allclose(Y.entropy(), Y0.entropy())
  1374. assert_allclose(Y.median(), Y0.ppf(0.5))
  1375. assert_allclose(Y.mean(), Y0.mean())
  1376. assert_allclose(Y.variance(), Y0.var())
  1377. assert_allclose(Y.standard_deviation(), np.sqrt(Y0.var()))
  1378. assert_allclose(Y.skewness(), Y0.stats('s'))
  1379. assert_allclose(Y.kurtosis(), Y0.stats('k') + 3)
  1380. assert_allclose(Y.support(), Y0.support())
  1381. assert_allclose(Y.pdf(y), Y0.pdf(y))
  1382. assert_allclose(Y.cdf(y), Y0.cdf(y))
  1383. assert_allclose(Y.ccdf(y), Y0.sf(y))
  1384. assert_allclose(Y.icdf(p), Y0.ppf(p))
  1385. assert_allclose(Y.iccdf(p), Y0.isf(p))
  1386. assert_allclose(Y.logpdf(y), Y0.logpdf(y))
  1387. assert_allclose(Y.logcdf(y), Y0.logcdf(y))
  1388. assert_allclose(Y.logccdf(y), Y0.logsf(y))
  1389. assert_allclose(Y.ilogcdf(np.log(p)), Y0.ppf(p))
  1390. assert_allclose(Y.ilogccdf(np.log(p)), Y0.isf(p))
  1391. seed = 3984593485
  1392. assert_allclose(Y.sample(rng=seed), np.exp(X.sample(rng=seed)))
  1393. @pytest.mark.fail_slow(10)
  1394. @pytest.mark.parametrize('scale', [1, 2, -1])
  1395. @pytest.mark.xfail_on_32bit("`scale=-1` fails on 32-bit; needs investigation")
  1396. def test_reciprocal(self, scale):
  1397. rng = np.random.default_rng(81345982345826)
  1398. a = rng.random((3, 1))
  1399. # Separate sign from scale. It's easy to scale the resulting
  1400. # RV with negative scale; we want to test the ability to divide
  1401. # by a RV with negative support
  1402. sign, scale = np.sign(scale), abs(scale)
  1403. # Reference distribution
  1404. InvGamma = stats.make_distribution(stats.invgamma)
  1405. Y0 = sign * scale * InvGamma(a=a)
  1406. # Test distribution
  1407. X = _Gamma(a=a) if sign > 0 else -_Gamma(a=a)
  1408. Y = scale / X
  1409. y = Y0.sample(shape=(3, 10), rng=rng)
  1410. p = Y0.cdf(y)
  1411. logp = np.log(p)
  1412. assert_allclose(Y.logentropy(), np.log(Y0.entropy()))
  1413. assert_allclose(Y.entropy(), Y0.entropy())
  1414. assert_allclose(Y.median(), Y0.median())
  1415. # moments are not finite
  1416. assert_allclose(Y.support(), Y0.support())
  1417. assert_allclose(Y.pdf(y), Y0.pdf(y))
  1418. assert_allclose(Y.cdf(y), Y0.cdf(y))
  1419. assert_allclose(Y.ccdf(y), Y0.ccdf(y))
  1420. assert_allclose(Y.icdf(p), Y0.icdf(p))
  1421. assert_allclose(Y.iccdf(p), Y0.iccdf(p))
  1422. assert_allclose(Y.logpdf(y), Y0.logpdf(y))
  1423. assert_allclose(Y.logcdf(y), Y0.logcdf(y))
  1424. assert_allclose(Y.logccdf(y), Y0.logccdf(y))
  1425. with np.errstate(divide='ignore', invalid='ignore'):
  1426. assert_allclose(Y.ilogcdf(logp), Y0.ilogcdf(logp))
  1427. assert_allclose(Y.ilogccdf(logp), Y0.ilogccdf(logp))
  1428. seed = 3984593485
  1429. assert_allclose(Y.sample(rng=seed), scale/(X.sample(rng=seed)))
  1430. @pytest.mark.fail_slow(5)
  1431. def test_log(self):
  1432. rng = np.random.default_rng(81345982345826)
  1433. a = rng.random((3, 1))
  1434. X = _Gamma(a=a)
  1435. Y0 = stats.loggamma(a)
  1436. Y = stats.log(X)
  1437. y = Y0.rvs((3, 10), random_state=rng)
  1438. p = Y0.cdf(y)
  1439. assert_allclose(Y.logentropy(), np.log(Y0.entropy()))
  1440. assert_allclose(Y.entropy(), Y0.entropy())
  1441. assert_allclose(Y.median(), Y0.ppf(0.5))
  1442. assert_allclose(Y.mean(), Y0.mean())
  1443. assert_allclose(Y.variance(), Y0.var())
  1444. assert_allclose(Y.standard_deviation(), np.sqrt(Y0.var()))
  1445. assert_allclose(Y.skewness(), Y0.stats('s'))
  1446. assert_allclose(Y.kurtosis(), Y0.stats('k') + 3)
  1447. assert_allclose(Y.support(), Y0.support())
  1448. assert_allclose(Y.pdf(y), Y0.pdf(y))
  1449. assert_allclose(Y.cdf(y), Y0.cdf(y))
  1450. assert_allclose(Y.ccdf(y), Y0.sf(y))
  1451. assert_allclose(Y.icdf(p), Y0.ppf(p))
  1452. assert_allclose(Y.iccdf(p), Y0.isf(p))
  1453. assert_allclose(Y.logpdf(y), Y0.logpdf(y))
  1454. assert_allclose(Y.logcdf(y), Y0.logcdf(y))
  1455. assert_allclose(Y.logccdf(y), Y0.logsf(y))
  1456. with np.errstate(invalid='ignore'):
  1457. assert_allclose(Y.ilogcdf(np.log(p)), Y0.ppf(p))
  1458. assert_allclose(Y.ilogccdf(np.log(p)), Y0.isf(p))
  1459. seed = 3984593485
  1460. assert_allclose(Y.sample(rng=seed), np.log(X.sample(rng=seed)))
  1461. def test_monotonic_transforms(self):
  1462. # Some tests of monotonic transforms that are better to be grouped or
  1463. # don't fit well above
  1464. X = Uniform(a=1, b=2)
  1465. X_str = "Uniform(a=1.0, b=2.0)"
  1466. assert str(stats.log(X)) == f"log({X_str})"
  1467. assert str(1 / X) == f"1/({X_str})"
  1468. assert str(stats.exp(X)) == f"exp({X_str})"
  1469. X = Uniform(a=-1, b=2)
  1470. message = "Division by a random variable is only implemented when the..."
  1471. with pytest.raises(NotImplementedError, match=message):
  1472. 1 / X
  1473. message = "The logarithm of a random variable is only implemented when the..."
  1474. with pytest.raises(NotImplementedError, match=message):
  1475. stats.log(X)
  1476. message = "Raising an argument to the power of a random variable is only..."
  1477. with pytest.raises(NotImplementedError, match=message):
  1478. (-2) ** X
  1479. with pytest.raises(NotImplementedError, match=message):
  1480. 1 ** X
  1481. with pytest.raises(NotImplementedError, match=message):
  1482. [0.5, 1.5] ** X
  1483. message = "Raising a random variable to the power of an argument is only"
  1484. with pytest.raises(NotImplementedError, match=message):
  1485. X ** (-2)
  1486. with pytest.raises(NotImplementedError, match=message):
  1487. X ** 0
  1488. with pytest.raises(NotImplementedError, match=message):
  1489. X ** [0.5, 1.5]
  1490. def test_arithmetic_operators(self):
  1491. rng = np.random.default_rng(2348923495832349834)
  1492. a, b, loc, scale = 0.294, 1.34, 0.57, 1.16
  1493. x = rng.uniform(-3, 3, 100)
  1494. Y = _LogUniform(a=a, b=b)
  1495. X = scale*Y + loc
  1496. assert_allclose(X.cdf(x), Y.cdf((x - loc) / scale))
  1497. X = loc + Y*scale
  1498. assert_allclose(X.cdf(x), Y.cdf((x - loc) / scale))
  1499. X = Y/scale - loc
  1500. assert_allclose(X.cdf(x), Y.cdf((x + loc) * scale))
  1501. X = loc -_LogUniform(a=a, b=b)/scale
  1502. assert_allclose(X.cdf(x), Y.ccdf((-x + loc)*scale))
  1503. def test_abs(self):
  1504. rng = np.random.default_rng(81345982345826)
  1505. loc = rng.random((3, 1))
  1506. Y = stats.abs(Normal() + loc)
  1507. Y0 = stats.foldnorm(loc)
  1508. y = Y0.rvs((3, 10), random_state=rng)
  1509. p = Y0.cdf(y)
  1510. assert_allclose(Y.logentropy(), np.log(Y0.entropy() + 0j))
  1511. assert_allclose(Y.entropy(), Y0.entropy())
  1512. assert_allclose(Y.median(), Y0.ppf(0.5))
  1513. assert_allclose(Y.mean(), Y0.mean())
  1514. assert_allclose(Y.variance(), Y0.var())
  1515. assert_allclose(Y.standard_deviation(), np.sqrt(Y0.var()))
  1516. assert_allclose(Y.skewness(), Y0.stats('s'))
  1517. assert_allclose(Y.kurtosis(), Y0.stats('k') + 3)
  1518. assert_allclose(Y.support(), Y0.support())
  1519. assert_allclose(Y.pdf(y), Y0.pdf(y))
  1520. assert_allclose(Y.cdf(y), Y0.cdf(y))
  1521. assert_allclose(Y.ccdf(y), Y0.sf(y))
  1522. assert_allclose(Y.icdf(p), Y0.ppf(p))
  1523. assert_allclose(Y.iccdf(p), Y0.isf(p))
  1524. assert_allclose(Y.logpdf(y), Y0.logpdf(y))
  1525. assert_allclose(Y.logcdf(y), Y0.logcdf(y))
  1526. assert_allclose(Y.logccdf(y), Y0.logsf(y))
  1527. assert_allclose(Y.ilogcdf(np.log(p)), Y0.ppf(p))
  1528. assert_allclose(Y.ilogccdf(np.log(p)), Y0.isf(p))
  1529. sample = Y.sample(10)
  1530. assert np.all(sample > 0)
  1531. def test_abs_finite_support(self):
  1532. # The original implementation of `FoldedDistribution` might evaluate
  1533. # the private distribution methods outside the support. Check that this
  1534. # is resolved.
  1535. Weibull = stats.make_distribution(stats.weibull_min)
  1536. X = Weibull(c=2)
  1537. Y = abs(-X)
  1538. assert_equal(X.logpdf(1), Y.logpdf(1))
  1539. assert_equal(X.pdf(1), Y.pdf(1))
  1540. assert_equal(X.logcdf(1), Y.logcdf(1))
  1541. assert_equal(X.cdf(1), Y.cdf(1))
  1542. assert_equal(X.logccdf(1), Y.logccdf(1))
  1543. assert_equal(X.ccdf(1), Y.ccdf(1))
  1544. def test_pow(self):
  1545. rng = np.random.default_rng(81345982345826)
  1546. Y = Normal()**2
  1547. Y0 = stats.chi2(df=1)
  1548. y = Y0.rvs(10, random_state=rng)
  1549. p = Y0.cdf(y)
  1550. assert_allclose(Y.logentropy(), np.log(Y0.entropy() + 0j), rtol=1e-6)
  1551. assert_allclose(Y.entropy(), Y0.entropy(), rtol=1e-6)
  1552. assert_allclose(Y.median(), Y0.median())
  1553. assert_allclose(Y.mean(), Y0.mean())
  1554. assert_allclose(Y.variance(), Y0.var())
  1555. assert_allclose(Y.standard_deviation(), np.sqrt(Y0.var()))
  1556. assert_allclose(Y.skewness(), Y0.stats('s'))
  1557. assert_allclose(Y.kurtosis(), Y0.stats('k') + 3)
  1558. assert_allclose(Y.support(), Y0.support())
  1559. assert_allclose(Y.pdf(y), Y0.pdf(y))
  1560. assert_allclose(Y.cdf(y), Y0.cdf(y))
  1561. assert_allclose(Y.ccdf(y), Y0.sf(y))
  1562. assert_allclose(Y.icdf(p), Y0.ppf(p))
  1563. assert_allclose(Y.iccdf(p), Y0.isf(p))
  1564. assert_allclose(Y.logpdf(y), Y0.logpdf(y))
  1565. assert_allclose(Y.logcdf(y), Y0.logcdf(y))
  1566. assert_allclose(Y.logccdf(y), Y0.logsf(y))
  1567. assert_allclose(Y.ilogcdf(np.log(p)), Y0.ppf(p))
  1568. assert_allclose(Y.ilogccdf(np.log(p)), Y0.isf(p))
  1569. sample = Y.sample(10)
  1570. assert np.all(sample > 0)
  1571. class TestOrderStatistic:
  1572. @pytest.mark.fail_slow(20) # Moments require integration
  1573. def test_order_statistic(self):
  1574. rng = np.random.default_rng(7546349802439582)
  1575. X = Uniform(a=0, b=1)
  1576. n = 5
  1577. r = np.asarray([[1], [3], [5]])
  1578. Y = stats.order_statistic(X, n=n, r=r)
  1579. Y0 = stats.beta(r, n + 1 - r)
  1580. y = Y0.rvs((3, 10), random_state=rng)
  1581. p = Y0.cdf(y)
  1582. # log methods need some attention before merge
  1583. assert_allclose(np.exp(Y.logentropy()), Y0.entropy())
  1584. assert_allclose(Y.entropy(), Y0.entropy())
  1585. assert_allclose(Y.mean(), Y0.mean())
  1586. assert_allclose(Y.variance(), Y0.var())
  1587. assert_allclose(Y.skewness(), Y0.stats('s'), atol=1e-15)
  1588. assert_allclose(Y.kurtosis(), Y0.stats('k') + 3, atol=1e-15)
  1589. assert_allclose(Y.median(), Y0.ppf(0.5))
  1590. assert_allclose(Y.support(), Y0.support())
  1591. assert_allclose(Y.pdf(y), Y0.pdf(y))
  1592. assert_allclose(Y.cdf(y, method='formula'), Y.cdf(y, method='quadrature'))
  1593. assert_allclose(Y.ccdf(y, method='formula'), Y.ccdf(y, method='quadrature'))
  1594. assert_allclose(Y.icdf(p, method='formula'), Y.icdf(p, method='inversion'))
  1595. assert_allclose(Y.iccdf(p, method='formula'), Y.iccdf(p, method='inversion'))
  1596. assert_allclose(Y.logpdf(y), Y0.logpdf(y))
  1597. assert_allclose(Y.logcdf(y), Y0.logcdf(y))
  1598. assert_allclose(Y.logccdf(y), Y0.logsf(y))
  1599. with np.errstate(invalid='ignore', divide='ignore'):
  1600. assert_allclose(Y.ilogcdf(np.log(p),), Y0.ppf(p))
  1601. assert_allclose(Y.ilogccdf(np.log(p)), Y0.isf(p))
  1602. message = "`r` and `n` must contain only positive integers."
  1603. with pytest.raises(ValueError, match=message):
  1604. stats.order_statistic(X, n=n, r=-1)
  1605. with pytest.raises(ValueError, match=message):
  1606. stats.order_statistic(X, n=-1, r=r)
  1607. with pytest.raises(ValueError, match=message):
  1608. stats.order_statistic(X, n=n, r=1.5)
  1609. with pytest.raises(ValueError, match=message):
  1610. stats.order_statistic(X, n=1.5, r=r)
  1611. def test_support_gh22037(self):
  1612. # During review of gh-22037, it was noted that the `support` of
  1613. # an `OrderStatisticDistribution` returned incorrect results;
  1614. # this was resolved by overriding `_support`.
  1615. Uniform = stats.make_distribution(stats.uniform)
  1616. X = Uniform()
  1617. Y = X*5 + 2
  1618. Z = stats.order_statistic(Y, r=3, n=5)
  1619. assert_allclose(Z.support(), Y.support())
  1620. def test_composition_gh22037(self):
  1621. # During review of gh-22037, it was noted that an error was
  1622. # raised when creating an `OrderStatisticDistribution` from
  1623. # a `TruncatedDistribution`. This was resolved by overriding
  1624. # `_update_parameters`.
  1625. Normal = stats.make_distribution(stats.norm)
  1626. TruncatedNormal = stats.make_distribution(stats.truncnorm)
  1627. a, b = [-2, -1], 1
  1628. r, n = 3, [[4], [5]]
  1629. x = [[[-0.3]], [[0.1]]]
  1630. X1 = Normal()
  1631. Y1 = stats.truncate(X1, a, b)
  1632. Z1 = stats.order_statistic(Y1, r=r, n=n)
  1633. X2 = TruncatedNormal(a=a, b=b)
  1634. Z2 = stats.order_statistic(X2, r=r, n=n)
  1635. np.testing.assert_allclose(Z1.cdf(x), Z2.cdf(x))
  1636. class TestFullCoverage:
  1637. # Adds tests just to get to 100% test coverage; this way it's more obvious
  1638. # if new lines are untested.
  1639. def test_Domain(self):
  1640. with pytest.raises(NotImplementedError):
  1641. _Domain.contains(None, 1.)
  1642. with pytest.raises(NotImplementedError):
  1643. _Domain.get_numerical_endpoints(None, 1.)
  1644. with pytest.raises(NotImplementedError):
  1645. _Domain.__str__(None)
  1646. def test_Parameter(self):
  1647. with pytest.raises(NotImplementedError):
  1648. _Parameter.validate(None, 1.)
  1649. @pytest.mark.parametrize(("dtype_in", "dtype_out"),
  1650. [(np.float16, np.float16),
  1651. (np.int16, np.float64)])
  1652. def test_RealParameter_uncommon_dtypes(self, dtype_in, dtype_out):
  1653. domain = _RealInterval((-1, 1))
  1654. parameter = _RealParameter('x', domain=domain)
  1655. x = np.asarray([0.5, 2.5], dtype=dtype_in)
  1656. arr, dtype, valid = parameter.validate(x, parameter_values={})
  1657. assert_equal(arr, x)
  1658. assert dtype == dtype_out
  1659. assert_equal(valid, [True, False])
  1660. def test_ContinuousDistribution_set_invalid_nan(self):
  1661. # Exercise code paths when formula returns wrong shape and dtype
  1662. # We could consider making this raise an error to force authors
  1663. # to return the right shape and dytpe, but this would need to be
  1664. # configurable.
  1665. class TestDist(ContinuousDistribution):
  1666. _variable = _RealParameter('x', domain=_RealInterval(endpoints=(0., 1.)))
  1667. def _logpdf_formula(self, x, *args, **kwargs):
  1668. return 0
  1669. X = TestDist()
  1670. dtype = np.float32
  1671. X._dtype = dtype
  1672. x = np.asarray([0.5], dtype=dtype)
  1673. assert X.logpdf(x).dtype == dtype
  1674. def test_fiinfo(self):
  1675. assert _fiinfo(np.float64(1.)).max == np.finfo(np.float64).max
  1676. assert _fiinfo(np.int64(1)).max == np.iinfo(np.int64).max
  1677. def test_generate_domain_support(self):
  1678. msg = _generate_domain_support(StandardNormal)
  1679. assert "accepts no distribution parameters" in msg
  1680. msg = _generate_domain_support(Normal)
  1681. assert "accepts one parameterization" in msg
  1682. msg = _generate_domain_support(_LogUniform)
  1683. assert "accepts two parameterizations" in msg
  1684. def test_ContinuousDistribution__repr__(self):
  1685. X = Uniform(a=0, b=1)
  1686. if np.__version__ < "2":
  1687. assert repr(X) == "Uniform(a=0.0, b=1.0)"
  1688. else:
  1689. assert repr(X) == "Uniform(a=np.float64(0.0), b=np.float64(1.0))"
  1690. if np.__version__ < "2":
  1691. assert repr(X*3 + 2) == "3.0*Uniform(a=0.0, b=1.0) + 2.0"
  1692. else:
  1693. assert repr(X*3 + 2) == (
  1694. "np.float64(3.0)*Uniform(a=np.float64(0.0), b=np.float64(1.0))"
  1695. " + np.float64(2.0)"
  1696. )
  1697. X = Uniform(a=np.zeros(4), b=1)
  1698. assert repr(X) == "Uniform(a=array([0., 0., 0., 0.]), b=1)"
  1699. X = Uniform(a=np.zeros(4, dtype=np.float32), b=np.ones(4, dtype=np.float32))
  1700. assert repr(X) == (
  1701. "Uniform(a=array([0., 0., 0., 0.], dtype=float32),"
  1702. " b=array([1., 1., 1., 1.], dtype=float32))"
  1703. )
  1704. class TestReprs:
  1705. U = Uniform(a=0, b=1)
  1706. V = Uniform(a=np.float32(0.0), b=np.float32(1.0))
  1707. X = Normal(mu=-1, sigma=1)
  1708. Y = Normal(mu=1, sigma=1)
  1709. Z = Normal(mu=np.zeros(1000), sigma=1)
  1710. @pytest.mark.parametrize(
  1711. "dist",
  1712. [
  1713. U,
  1714. U - np.array([1.0, 2.0]),
  1715. pytest.param(
  1716. V,
  1717. marks=pytest.mark.skipif(
  1718. np.__version__ < "2",
  1719. reason="numpy 1.x didn't have dtype in repr",
  1720. )
  1721. ),
  1722. pytest.param(
  1723. np.ones(2, dtype=np.float32)*V + np.zeros(2, dtype=np.float64),
  1724. marks=pytest.mark.skipif(
  1725. np.__version__ < "2",
  1726. reason="numpy 1.x didn't have dtype in repr",
  1727. )
  1728. ),
  1729. 3*U + 2,
  1730. U**4,
  1731. (3*U + 2)**4,
  1732. (3*U + 2)**3,
  1733. 2**U,
  1734. 2**(3*U + 1),
  1735. 1 / (1 + U),
  1736. stats.order_statistic(U, r=3, n=5),
  1737. stats.truncate(U, 0.2, 0.8),
  1738. stats.Mixture([X, Y], weights=[0.3, 0.7]),
  1739. abs(U),
  1740. stats.exp(U),
  1741. stats.log(1 + U),
  1742. np.array([1.0, 2.0])*U + np.array([2.0, 3.0]),
  1743. ]
  1744. )
  1745. def test_executable(self, dist):
  1746. # Test that reprs actually evaluate to proper distribution
  1747. # provided relevant imports are made.
  1748. from numpy import array # noqa: F401
  1749. from numpy import float32 # noqa: F401
  1750. from scipy.stats import abs, exp, log, order_statistic, truncate # noqa: F401
  1751. from scipy.stats import Mixture, Normal # noqa: F401
  1752. from scipy.stats._new_distributions import Uniform # noqa: F401
  1753. new_dist = eval(repr(dist))
  1754. # A basic check that the distributions are the same
  1755. sample1 = dist.sample(shape=10, rng=1234)
  1756. sample2 = new_dist.sample(shape=10, rng=1234)
  1757. assert_equal(sample1, sample2)
  1758. assert sample1.dtype is sample2.dtype
  1759. @pytest.mark.parametrize(
  1760. "dist",
  1761. [
  1762. Z,
  1763. np.full(1000, 2.0) * X + 1.0,
  1764. 2.0 * X + np.full(1000, 1.0),
  1765. np.full(1000, 2.0) * X + 1.0,
  1766. stats.truncate(Z, -1, 1),
  1767. stats.truncate(Z, -np.ones(1000), np.ones(1000)),
  1768. stats.order_statistic(X, r=np.arange(1, 1000), n=1000),
  1769. Z**2,
  1770. 1.0 / (1 + stats.exp(Z)),
  1771. 2**Z,
  1772. ]
  1773. )
  1774. def test_not_too_long(self, dist):
  1775. # Tests that array summarization is working to ensure reprs aren't too long.
  1776. # None of the reprs above will be executable.
  1777. assert len(repr(dist)) < 250
  1778. class MixedDist(ContinuousDistribution):
  1779. _variable = _RealParameter('x', domain=_RealInterval(endpoints=(-np.inf, np.inf)))
  1780. def _pdf_formula(self, x, *args, **kwargs):
  1781. return (0.4 * 1/(1.1 * np.sqrt(2*np.pi)) * np.exp(-0.5*((x+0.25)/1.1)**2)
  1782. + 0.6 * 1/(0.9 * np.sqrt(2*np.pi)) * np.exp(-0.5*((x-0.5)/0.9)**2))
  1783. class TestMixture:
  1784. def test_input_validation(self):
  1785. message = "`components` must contain at least one random variable."
  1786. with pytest.raises(ValueError, match=message):
  1787. Mixture([])
  1788. message = "Each element of `components` must be an instance..."
  1789. with pytest.raises(ValueError, match=message):
  1790. Mixture((1, 2, 3))
  1791. message = "All elements of `components` must have scalar shapes."
  1792. with pytest.raises(ValueError, match=message):
  1793. Mixture([Normal(mu=[1, 2]), Normal()])
  1794. message = "`components` and `weights` must have the same length."
  1795. with pytest.raises(ValueError, match=message):
  1796. Mixture([Normal()], weights=[0.5, 0.5])
  1797. message = "`weights` must have floating point dtype."
  1798. with pytest.raises(ValueError, match=message):
  1799. Mixture([Normal()], weights=[1])
  1800. message = "`weights` must have floating point dtype."
  1801. with pytest.raises(ValueError, match=message):
  1802. Mixture([Normal()], weights=[1])
  1803. message = "`weights` must sum to 1.0."
  1804. with pytest.raises(ValueError, match=message):
  1805. Mixture([Normal(), Normal()], weights=[0.5, 1.0])
  1806. message = "All `weights` must be non-negative."
  1807. with pytest.raises(ValueError, match=message):
  1808. Mixture([Normal(), Normal()], weights=[1.5, -0.5])
  1809. @pytest.mark.parametrize('shape', [(), (10,)])
  1810. def test_basic(self, shape):
  1811. rng = np.random.default_rng(582348972387243524)
  1812. X = Mixture((Normal(mu=-0.25, sigma=1.1), Normal(mu=0.5, sigma=0.9)),
  1813. weights=(0.4, 0.6))
  1814. Y = MixedDist()
  1815. x = rng.random(shape)
  1816. def assert_allclose(res, ref, **kwargs):
  1817. if shape == ():
  1818. assert np.isscalar(res)
  1819. np.testing.assert_allclose(res, ref, **kwargs)
  1820. assert_allclose(X.logentropy(), Y.logentropy())
  1821. assert_allclose(X.entropy(), Y.entropy())
  1822. assert_allclose(X.mode(), Y.mode())
  1823. assert_allclose(X.median(), Y.median())
  1824. assert_allclose(X.mean(), Y.mean())
  1825. assert_allclose(X.variance(), Y.variance())
  1826. assert_allclose(X.standard_deviation(), Y.standard_deviation())
  1827. assert_allclose(X.skewness(), Y.skewness())
  1828. assert_allclose(X.kurtosis(), Y.kurtosis())
  1829. assert_allclose(X.logpdf(x), Y.logpdf(x))
  1830. assert_allclose(X.pdf(x), Y.pdf(x))
  1831. assert_allclose(X.logcdf(x), Y.logcdf(x))
  1832. assert_allclose(X.cdf(x), Y.cdf(x))
  1833. assert_allclose(X.logccdf(x), Y.logccdf(x))
  1834. assert_allclose(X.ccdf(x), Y.ccdf(x))
  1835. assert_allclose(X.ilogcdf(x), Y.ilogcdf(x))
  1836. assert_allclose(X.icdf(x), Y.icdf(x))
  1837. assert_allclose(X.ilogccdf(x), Y.ilogccdf(x))
  1838. assert_allclose(X.iccdf(x), Y.iccdf(x))
  1839. for kind in ['raw', 'central', 'standardized']:
  1840. for order in range(5):
  1841. assert_allclose(X.moment(order, kind=kind),
  1842. Y.moment(order, kind=kind),
  1843. atol=1e-15)
  1844. # weak test of `sample`
  1845. shape = (10, 20, 5)
  1846. y = X.sample(shape, rng=rng)
  1847. assert y.shape == shape
  1848. assert stats.ks_1samp(y.ravel(), X.cdf).pvalue > 0.05
  1849. def test_default_weights(self):
  1850. a = 1.1
  1851. Gamma = stats.make_distribution(stats.gamma)
  1852. X = Gamma(a=a)
  1853. Y = stats.Mixture((X, -X))
  1854. x = np.linspace(-4, 4, 300)
  1855. assert_allclose(Y.pdf(x), stats.dgamma(a=a).pdf(x))
  1856. def test_properties(self):
  1857. components = [Normal(mu=-0.25, sigma=1.1), Normal(mu=0.5, sigma=0.9)]
  1858. weights = (0.4, 0.6)
  1859. X = Mixture(components, weights=weights)
  1860. # Replacing properties doesn't work
  1861. # Different version of Python have different messages
  1862. with pytest.raises(AttributeError):
  1863. X.components = 10
  1864. with pytest.raises(AttributeError):
  1865. X.weights = 10
  1866. # Mutation doesn't work
  1867. X.components[0] = components[1]
  1868. assert X.components[0] == components[0]
  1869. X.weights[0] = weights[1]
  1870. assert X.weights[0] == weights[0]
  1871. def test_inverse(self):
  1872. # Originally, inverse relied on the mean to start the bracket search.
  1873. # This didn't work for distributions with non-finite mean. Check that
  1874. # this is resolved.
  1875. rng = np.random.default_rng(24358934657854237863456)
  1876. Cauchy = stats.make_distribution(stats.cauchy)
  1877. X0 = Cauchy()
  1878. X = stats.Mixture([X0, X0])
  1879. p = rng.random(size=10)
  1880. np.testing.assert_allclose(X.icdf(p), X0.icdf(p))
  1881. np.testing.assert_allclose(X.iccdf(p), X0.iccdf(p))
  1882. np.testing.assert_allclose(X.ilogcdf(p), X0.ilogcdf(p))
  1883. np.testing.assert_allclose(X.ilogccdf(p), X0.ilogccdf(p))
  1884. def test_zipfian_distribution_wrapper():
  1885. # Regression test for gh-23678: calling the cdf method at the end
  1886. # point of the Zipfian distribution would generate a warning.
  1887. Zipfian = stats.make_distribution(stats.zipfian)
  1888. zdist = Zipfian(a=0.75, n=15)
  1889. # This should not generate any warnings.
  1890. assert_equal(zdist.cdf(15), 1.0)