_distribution_infrastructure.py 229 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156115711581159116011611162116311641165116611671168116911701171117211731174117511761177117811791180118111821183118411851186118711881189119011911192119311941195119611971198119912001201120212031204120512061207120812091210121112121213121412151216121712181219122012211222122312241225122612271228122912301231123212331234123512361237123812391240124112421243124412451246124712481249125012511252125312541255125612571258125912601261126212631264126512661267126812691270127112721273127412751276127
71278127912801281128212831284128512861287128812891290129112921293129412951296129712981299130013011302130313041305130613071308130913101311131213131314131513161317131813191320132113221323132413251326132713281329133013311332133313341335133613371338133913401341134213431344134513461347134813491350135113521353135413551356135713581359136013611362136313641365136613671368136913701371137213731374137513761377137813791380138113821383138413851386138713881389139013911392139313941395139613971398139914001401140214031404140514061407140814091410141114121413141414151416141714181419142014211422142314241425142614271428142914301431143214331434143514361437143814391440144114421443144414451446144714481449145014511452145314541455145614571458145914601461146214631464146514661467146814691470147114721473147414751476147714781479148014811482148314841485148614871488148914901491149214931494149514961497149814991500150115021503150415051506150715081509151015111512151315141515151615171518151915201521152215231524152515261527152815291530153115321533153415351536153715381539154015411542154315441545154615471548154915501551155215531554155515561557155815591560156115621563156415651566156715681569157015711572157315741575157615771578157915801581158215831584158515861587158815891590159115921593159415951596159715981599160016011602160316041605160616071608160916101611161216131614161516161617161816191620162116221623162416251626162716281629163016311632163316341635163616371638163916401641164216431644164516461647164816491650165116521653165416551656165716581659166016611662166316641665166616671668166916701671167216731674167516761677167816791680168116821683168416851686168716881689169016911692169316941695169616971698169917001701170217031704170517061707170817091710171117121713171417151716171717181719172017211722172317241725172617271728172917301731173217331734173517361737173817391740174117421743174417451746174717481749175017511752175317541755175617571758175917601761176217631764176517661767176817691770177117721773177417751776177
71778177917801781178217831784178517861787178817891790179117921793179417951796179717981799180018011802180318041805180618071808180918101811181218131814181518161817181818191820182118221823182418251826182718281829183018311832183318341835183618371838183918401841184218431844184518461847184818491850185118521853185418551856185718581859186018611862186318641865186618671868186918701871187218731874187518761877187818791880188118821883188418851886188718881889189018911892189318941895189618971898189919001901190219031904190519061907190819091910191119121913191419151916191719181919192019211922192319241925192619271928192919301931193219331934193519361937193819391940194119421943194419451946194719481949195019511952195319541955195619571958195919601961196219631964196519661967196819691970197119721973197419751976197719781979198019811982198319841985198619871988198919901991199219931994199519961997199819992000200120022003200420052006200720082009201020112012201320142015201620172018201920202021202220232024202520262027202820292030203120322033203420352036203720382039204020412042204320442045204620472048204920502051205220532054205520562057205820592060206120622063206420652066206720682069207020712072207320742075207620772078207920802081208220832084208520862087208820892090209120922093209420952096209720982099210021012102210321042105210621072108210921102111211221132114211521162117211821192120212121222123212421252126212721282129213021312132213321342135213621372138213921402141214221432144214521462147214821492150215121522153215421552156215721582159216021612162216321642165216621672168216921702171217221732174217521762177217821792180218121822183218421852186218721882189219021912192219321942195219621972198219922002201220222032204220522062207220822092210221122122213221422152216221722182219222022212222222322242225222622272228222922302231223222332234223522362237223822392240224122422243224422452246224722482249225022512252225322542255225622572258225922602261226222632264226522662267226822692270227122722273227422752276227
72278227922802281228222832284228522862287228822892290229122922293229422952296229722982299230023012302230323042305230623072308230923102311231223132314231523162317231823192320232123222323232423252326232723282329233023312332233323342335233623372338233923402341234223432344234523462347234823492350235123522353235423552356235723582359236023612362236323642365236623672368236923702371237223732374237523762377237823792380238123822383238423852386238723882389239023912392239323942395239623972398239924002401240224032404240524062407240824092410241124122413241424152416241724182419242024212422242324242425242624272428242924302431243224332434243524362437243824392440244124422443244424452446244724482449245024512452245324542455245624572458245924602461246224632464246524662467246824692470247124722473247424752476247724782479248024812482248324842485248624872488248924902491249224932494249524962497249824992500250125022503250425052506250725082509251025112512251325142515251625172518251925202521252225232524252525262527252825292530253125322533253425352536253725382539254025412542254325442545254625472548254925502551255225532554255525562557255825592560256125622563256425652566256725682569257025712572257325742575257625772578257925802581258225832584258525862587258825892590259125922593259425952596259725982599260026012602260326042605260626072608260926102611261226132614261526162617261826192620262126222623262426252626262726282629263026312632263326342635263626372638263926402641264226432644264526462647264826492650265126522653265426552656265726582659266026612662266326642665266626672668266926702671267226732674267526762677267826792680268126822683268426852686268726882689269026912692269326942695269626972698269927002701270227032704270527062707270827092710271127122713271427152716271727182719272027212722272327242725272627272728272927302731273227332734273527362737273827392740274127422743274427452746274727482749275027512752275327542755275627572758275927602761276227632764276527662767276827692770277127722773277427752776277
72778277927802781278227832784278527862787278827892790279127922793279427952796279727982799280028012802280328042805280628072808280928102811281228132814281528162817281828192820282128222823282428252826282728282829283028312832283328342835283628372838283928402841284228432844284528462847284828492850285128522853285428552856285728582859286028612862286328642865286628672868286928702871287228732874287528762877287828792880288128822883288428852886288728882889289028912892289328942895289628972898289929002901290229032904290529062907290829092910291129122913291429152916291729182919292029212922292329242925292629272928292929302931293229332934293529362937293829392940294129422943294429452946294729482949295029512952295329542955295629572958295929602961296229632964296529662967296829692970297129722973297429752976297729782979298029812982298329842985298629872988298929902991299229932994299529962997299829993000300130023003300430053006300730083009301030113012301330143015301630173018301930203021302230233024302530263027302830293030303130323033303430353036303730383039304030413042304330443045304630473048304930503051305230533054305530563057305830593060306130623063306430653066306730683069307030713072307330743075307630773078307930803081308230833084308530863087308830893090309130923093309430953096309730983099310031013102310331043105310631073108310931103111311231133114311531163117311831193120312131223123312431253126312731283129313031313132313331343135313631373138313931403141314231433144314531463147314831493150315131523153315431553156315731583159316031613162316331643165316631673168316931703171317231733174317531763177317831793180318131823183318431853186318731883189319031913192319331943195319631973198319932003201320232033204320532063207320832093210321132123213321432153216321732183219322032213222322332243225322632273228322932303231323232333234323532363237323832393240324132423243324432453246324732483249325032513252325332543255325632573258325932603261326232633264326532663267326832693270327132723273327432753276327
73278327932803281328232833284328532863287328832893290329132923293329432953296329732983299330033013302330333043305330633073308330933103311331233133314331533163317331833193320332133223323332433253326332733283329333033313332333333343335333633373338333933403341334233433344334533463347334833493350335133523353335433553356335733583359336033613362336333643365336633673368336933703371337233733374337533763377337833793380338133823383338433853386338733883389339033913392339333943395339633973398339934003401340234033404340534063407340834093410341134123413341434153416341734183419342034213422342334243425342634273428342934303431343234333434343534363437343834393440344134423443344434453446344734483449345034513452345334543455345634573458345934603461346234633464346534663467346834693470347134723473347434753476347734783479348034813482348334843485348634873488348934903491349234933494349534963497349834993500350135023503350435053506350735083509351035113512351335143515351635173518351935203521352235233524352535263527352835293530353135323533353435353536353735383539354035413542354335443545354635473548354935503551355235533554355535563557355835593560356135623563356435653566356735683569357035713572357335743575357635773578357935803581358235833584358535863587358835893590359135923593359435953596359735983599360036013602360336043605360636073608360936103611361236133614361536163617361836193620362136223623362436253626362736283629363036313632363336343635363636373638363936403641364236433644364536463647364836493650365136523653365436553656365736583659366036613662366336643665366636673668366936703671367236733674367536763677367836793680368136823683368436853686368736883689369036913692369336943695369636973698369937003701370237033704370537063707370837093710371137123713371437153716371737183719372037213722372337243725372637273728372937303731373237333734373537363737373837393740374137423743374437453746374737483749375037513752375337543755375637573758375937603761376237633764376537663767376837693770377137723773377437753776377
73778377937803781378237833784378537863787378837893790379137923793379437953796379737983799380038013802380338043805380638073808380938103811381238133814381538163817381838193820382138223823382438253826382738283829383038313832383338343835383638373838383938403841384238433844384538463847384838493850385138523853385438553856385738583859386038613862386338643865386638673868386938703871387238733874387538763877387838793880388138823883388438853886388738883889389038913892389338943895389638973898389939003901390239033904390539063907390839093910391139123913391439153916391739183919392039213922392339243925392639273928392939303931393239333934393539363937393839393940394139423943394439453946394739483949395039513952395339543955395639573958395939603961396239633964396539663967396839693970397139723973397439753976397739783979398039813982398339843985398639873988398939903991399239933994399539963997399839994000400140024003400440054006400740084009401040114012401340144015401640174018401940204021402240234024402540264027402840294030403140324033403440354036403740384039404040414042404340444045404640474048404940504051405240534054405540564057405840594060406140624063406440654066406740684069407040714072407340744075407640774078407940804081408240834084408540864087408840894090409140924093409440954096409740984099410041014102410341044105410641074108410941104111411241134114411541164117411841194120412141224123412441254126412741284129413041314132413341344135413641374138413941404141414241434144414541464147414841494150415141524153415441554156415741584159416041614162416341644165416641674168416941704171417241734174417541764177417841794180418141824183418441854186418741884189419041914192419341944195419641974198419942004201420242034204420542064207420842094210421142124213421442154216421742184219422042214222422342244225422642274228422942304231423242334234423542364237423842394240424142424243424442454246424742484249425042514252425342544255425642574258425942604261426242634264426542664267426842694270427142724273427442754276427
74278427942804281428242834284428542864287428842894290429142924293429442954296429742984299430043014302430343044305430643074308430943104311431243134314431543164317431843194320432143224323432443254326432743284329433043314332433343344335433643374338433943404341434243434344434543464347434843494350435143524353435443554356435743584359436043614362436343644365436643674368436943704371437243734374437543764377437843794380438143824383438443854386438743884389439043914392439343944395439643974398439944004401440244034404440544064407440844094410441144124413441444154416441744184419442044214422442344244425442644274428442944304431443244334434443544364437443844394440444144424443444444454446444744484449445044514452445344544455445644574458445944604461446244634464446544664467446844694470447144724473447444754476447744784479448044814482448344844485448644874488448944904491449244934494449544964497449844994500450145024503450445054506450745084509451045114512451345144515451645174518451945204521452245234524452545264527452845294530453145324533453445354536453745384539454045414542454345444545454645474548454945504551455245534554455545564557455845594560456145624563456445654566456745684569457045714572457345744575457645774578457945804581458245834584458545864587458845894590459145924593459445954596459745984599460046014602460346044605460646074608460946104611461246134614461546164617461846194620462146224623462446254626462746284629463046314632463346344635463646374638463946404641464246434644464546464647464846494650465146524653465446554656465746584659466046614662466346644665466646674668466946704671467246734674467546764677467846794680468146824683468446854686468746884689469046914692469346944695469646974698469947004701470247034704470547064707470847094710471147124713471447154716471747184719472047214722472347244725472647274728472947304731473247334734473547364737473847394740474147424743474447454746474747484749475047514752475347544755475647574758475947604761476247634764476547664767476847694770477147724773477447754776477
74778477947804781478247834784478547864787478847894790479147924793479447954796479747984799480048014802480348044805480648074808480948104811481248134814481548164817481848194820482148224823482448254826482748284829483048314832483348344835483648374838483948404841484248434844484548464847484848494850485148524853485448554856485748584859486048614862486348644865486648674868486948704871487248734874487548764877487848794880488148824883488448854886488748884889489048914892489348944895489648974898489949004901490249034904490549064907490849094910491149124913491449154916491749184919492049214922492349244925492649274928492949304931493249334934493549364937493849394940494149424943494449454946494749484949495049514952495349544955495649574958495949604961496249634964496549664967496849694970497149724973497449754976497749784979498049814982498349844985498649874988498949904991499249934994499549964997499849995000500150025003500450055006500750085009501050115012501350145015501650175018501950205021502250235024502550265027502850295030503150325033503450355036503750385039504050415042504350445045504650475048504950505051505250535054505550565057505850595060506150625063506450655066506750685069507050715072507350745075507650775078507950805081508250835084508550865087508850895090509150925093509450955096509750985099510051015102510351045105510651075108510951105111511251135114511551165117511851195120512151225123512451255126512751285129513051315132513351345135513651375138513951405141514251435144514551465147514851495150515151525153515451555156515751585159516051615162516351645165516651675168516951705171517251735174517551765177517851795180518151825183518451855186518751885189519051915192519351945195519651975198519952005201520252035204520552065207520852095210521152125213521452155216521752185219522052215222522352245225522652275228522952305231523252335234523552365237523852395240524152425243524452455246524752485249525052515252525352545255525652575258525952605261526252635264526552665267526852695270527152725273527452755276527
7527852795280528152825283528452855286528752885289529052915292529352945295529652975298529953005301530253035304530553065307530853095310531153125313531453155316531753185319532053215322532353245325532653275328532953305331533253335334533553365337533853395340534153425343534453455346534753485349535053515352535353545355535653575358535953605361536253635364536553665367536853695370537153725373537453755376537753785379538053815382538353845385538653875388538953905391539253935394539553965397539853995400540154025403540454055406540754085409541054115412541354145415541654175418541954205421542254235424542554265427542854295430543154325433543454355436543754385439544054415442544354445445544654475448544954505451545254535454545554565457545854595460546154625463546454655466546754685469547054715472547354745475547654775478547954805481548254835484548554865487548854895490549154925493549454955496549754985499550055015502550355045505550655075508550955105511551255135514551555165517551855195520552155225523552455255526552755285529553055315532553355345535553655375538553955405541554255435544554555465547554855495550555155525553555455555556555755585559556055615562556355645565556655675568556955705571557255735574557555765577557855795580558155825583558455855586558755885589559055915592559355945595559655975598559956005601560256035604560556065607560856095610561156125613561456155616561756185619562056215622562356245625562656275628562956305631563256335634563556365637563856395640564156425643564456455646564756485649565056515652565356545655565656575658565956605661566256635664566556665667566856695670567156725673567456755676567756785679568056815682568356845685568656875688568956905691569256935694569556965697569856995700570157025703570457055706570757085709571057115712571357145715571657175718571957205721572257235724572557265727572857295730573157325733573457355736573757385739574057415742574357445745574657475748574957505751575257535754575557565757575857595760576157625763576457655766
  1. import functools
  2. from abc import ABC, abstractmethod
  3. from functools import cached_property
  4. from types import GenericAlias
  5. import inspect
  6. import math
  7. import numpy as np
  8. from numpy import inf
  9. from scipy._lib._array_api import xp_capabilities, xp_promote
  10. from scipy._lib._util import _rng_spawn, _RichResult
  11. from scipy._lib._docscrape import ClassDoc, NumpyDocString
  12. from scipy import special, stats
  13. from scipy.special._ufuncs import _log1mexp
  14. from scipy.integrate import tanhsinh as _tanhsinh, nsum
  15. from scipy.optimize._bracket import _bracket_root, _bracket_minimum
  16. from scipy.optimize._chandrupatla import _chandrupatla, _chandrupatla_minimize
  17. from scipy.stats._probability_distribution import _ProbabilityDistribution
  18. from scipy.stats import qmc
  19. # in case we need to distinguish between None and not specified
  20. # Typically this is used to determine whether the tolerance has been set by the
  21. # user and make a decision about which method to use to evaluate a distribution
  22. # function. Sometimes, the logic does not consider the value of the tolerance,
  23. # only whether this has been defined or not. This is not intended to be the
  24. # best possible logic; the intent is to establish the structure, which can
  25. # be refined in follow-up work.
  26. # See https://github.com/scipy/scipy/pull/21050#discussion_r1714195433.
  27. _null = object()
  28. def _isnull(x):
  29. return type(x) is object or x is None
  30. __all__ = ['make_distribution', 'Mixture', 'order_statistic',
  31. 'truncate', 'abs', 'exp', 'log']
  32. # Could add other policies for broadcasting and edge/out-of-bounds case handling
  33. # For instance, when edge case handling is known not to be needed, it's much
  34. # faster to turn it off, but it might still be nice to have array conversion
  35. # and shaping done so the user doesn't need to be so careful.
  36. _SKIP_ALL = "skip_all"
  37. # Other cache policies would be useful, too.
  38. _NO_CACHE = "no_cache"
  39. # TODO:
  40. # Test sample dtypes
  41. # Add dtype kwarg (especially for distributions with no parameters)
  42. # When drawing endpoint/out-of-bounds values of a parameter, draw them from
  43. # the endpoints/out-of-bounds region of the full `domain`, not `typical`.
  44. # Distributions without shape parameters probably need to accept a `dtype` parameter;
  45. # right now they default to float64. If we have them default to float16, they will
  46. # need to determine result_type when input is not float16 (overhead).
  47. # Test _solve_bounded bracket logic, and decide what to do about warnings
  48. # Get test coverage to 100%
  49. # Raise when distribution method returns wrong shape/dtype?
  50. # Consider ensuring everything is at least 1D for calculations? Would avoid needing
  51. # to sprinkle `np.asarray` throughout due to indiscriminate conversion of 0D arrays
  52. # to scalars
  53. # Break up `test_basic`: test each method separately
  54. # Fix `sample` for QMCEngine (implementation does not match documentation)
  55. # When a parameter is invalid, set only the offending parameter to NaN (if possible)?
  56. # `_tanhsinh` special case when there are no abscissae between the limits
  57. # example: cdf of uniform between 1.0 and np.nextafter(1.0, np.inf)
  58. # check behavior of moment methods when moments are undefined/infinite -
  59. # basically OK but needs tests
  60. # investigate use of median
  61. # implement symmetric distribution
  62. # implement composite distribution
  63. # implement wrapped distribution
  64. # profile/optimize
  65. # general cleanup (choose keyword-only parameters)
  66. # compare old/new distribution timing
  67. # make video
  68. # add array API support
  69. # why does dist.ilogcdf(-100) not converge to bound? Check solver response to inf
  70. # _chandrupatla_minimize should not report xm = fm = NaN when it fails
  71. # integrate `logmoment` into `moment`? (Not hard, but enough time and code
  72. # complexity to wait for reviewer feedback before adding.)
  73. # Eliminate bracket_root error "`min <= a < b <= max` must be True"
  74. # Test repr?
  75. # use `median` information to improve integration? In some cases this will
  76. # speed things up. If it's not needed, it may be about twice as slow. I think
  77. # it should depend on the accuracy setting.
  78. # in tests, check reference value against that produced using np.vectorize?
  79. # add `axis` to `ks_1samp`
  80. # User tips for faster execution:
  81. # - pass NumPy arrays
  82. # - pass inputs of floating point type (not integers)
  83. # - prefer NumPy scalars or 0d arrays over other size 1 arrays
  84. # - pass no invalid parameters and disable invalid parameter checks with iv_profile
  85. # - provide a Generator if you're going to do sampling
  86. # add options for drawing parameters: log-spacing
  87. # accuracy benchmark suite
  88. # Should caches be attributes so we can more easily ensure that they are not
  89. # modified when caching is turned off?
  90. # Make ShiftedScaledDistribution more efficient - only process underlying
  91. # distribution parameters as necessary.
  92. # Reconsider `all_inclusive`
  93. # Should process_parameters update kwargs rather than returning? Should we
  94. # update parameters rather than setting to what process_parameters returns?
  95. # Questions:
  96. # 1. I override `__getattr__` so that distribution parameters can be read as
  97. # attributes. We don't want users to try to change them.
  98. # - To prevent replacements (dist.a = b), I could override `__setattr__`.
  99. # - To prevent in-place modifications, `__getattr__` could return a copy,
  100. # or it could set the WRITEABLE flag of the array to false.
  101. # Which should I do?
  102. # 2. `cache_policy` is supported in several methods where I imagine it being
  103. # useful, but it needs to be tested. Before doing that:
  104. # - What should the default value be?
  105. # - What should the other values be?
  106. # Or should we just eliminate this policy?
  107. # 3. `validation_policy` is supported in a few places, but it should be checked for
  108. # consistency. I have the same questions as for `cache_policy`.
  109. # 4. `tol` is currently notional. I think there needs to be a way to set
  110. # separate `atol` and `rtol`. Some ways I imagine it being used:
  111. # - Values can be passed to iterative functions (quadrature, root-finder).
  112. # - To control which "method" of a distribution function is used. For
  113. # example, if `atol` is set to `1e-12`, it may be acceptable to compute
  114. # the complementary CDF as 1 - CDF even when CDF is nearly 1; otherwise,
  115. # a (potentially more time-consuming) method would need to be used.
  116. # I'm looking for unified suggestions for the interface, not ad hoc ideas
  117. # for using tolerances. Suppose the user wants to have more control over
  118. # the tolerances used for each method - how do they specify it? It would
  119. # probably be easiest for the user if they could pass tolerances into each
  120. # method, but it's easiest for us if they can only set it as a property of
  121. # the class. Perhaps a dictionary of tolerance settings?
  122. # 5. I also envision that accuracy estimates should be reported to the user
  123. # somehow. I think my preference would be to return a subclass of an array
  124. # with an `error` attribute - yes, really. But this is unlikely to be
  125. # popular, so what are other ideas? Again, we need a unified vision here,
  126. # not just pointing out difficulties (not all errors are known or easy
  127. # to estimate, what to do when errors could compound, etc.).
  128. # 6. The term "method" is used to refer to public instance functions,
  129. # private instance functions, the "method" string argument, and the means
  130. # of calculating the desired quantity (represented by the string argument).
  131. # For the sake of disambiguation, shall I rename the "method" string to
  132. # "strategy" and refer to the means of calculating the quantity as the
  133. # "strategy"?
  134. # Originally, I planned to filter out invalid distribution parameters;
  135. # distribution implementation functions would always work with "compressed",
  136. # 1D arrays containing only valid distribution parameters. There are two
  137. # problems with this:
  138. # - This essentially requires copying all arrays, even if there is only a
  139. # single invalid parameter combination. This is expensive. Then, to output
  140. # the original size data to the user, we need to "decompress" the arrays
  141. # and fill in the NaNs, so more copying. Unless we branch the code when
  142. # there are no invalid data, these copies happen even in the normal case,
  143. # where there are no invalid parameter combinations. We should not incur
  144. # all this overhead in the normal case.
  145. # - For methods that accept arguments other than distribution parameters, the
  146. # user will pass in arrays that are broadcastable with the original arrays,
  147. # not the compressed arrays. This means that this same sort of invalid
  148. # value detection needs to be repeated every time one of these methods is
  149. # called.
  150. # The much simpler solution is to keep the data uncompressed but to replace
  151. # the invalid parameters and arguments with NaNs (and only if some are
  152. # invalid). With this approach, the copying happens only if/when it is
  153. # needed. Most functions involved in stats distribution calculations don't
  154. # mind NaNs; they just return NaN. The behavior "If x_i is NaN, the result
  155. # is NaN" is explicit in the array API. So this should be fine.
  156. #
  157. # Currently, I am still leaving the parameters and function arguments
  158. # in their broadcasted shapes rather than, say, raveling. The intent
  159. # is to avoid back and forth reshaping. If authors of distributions have
  160. # trouble dealing with N-D arrays, we can reconsider this.
  161. #
  162. # Another important decision is that the *private* methods must accept
  163. # the distribution parameters as inputs rather than relying on these
  164. # cached properties directly (although the public methods typically pass
  165. # the cached values to the private methods). This is because the elementwise
  166. # algorithms for quadrature, differentiation, root-finding, and minimization
  167. # prefer that the input functions are strictly elementwise in the sense
  168. # that the value output for a given input element does not depend on the
  169. # shape of the input or that element's location within the input array.
  170. # When the computation has converged for an element, it is removed from
  171. # the computation entirely. As a result, the shape of the arrays passed to
  172. # the function will almost never be broadcastable with the shape of the
  173. # cached parameter arrays.
  174. #
  175. # I've sprinkled in some optimizations for scalars and same-shape/type arrays
  176. # throughout. The biggest time sinks before were:
  177. # - broadcast_arrays
  178. # - result_dtype
  179. # - is_subdtype
  180. # It is much faster to check whether these are necessary than to do them.
  181. class _Domain(ABC):
  182. r""" Representation of the applicable domain of a parameter or variable.
  183. A `_Domain` object is responsible for storing information about the
  184. domain of a parameter or variable, determining whether a value is within
  185. the domain (`contains`), and providing a text/mathematical representation
  186. of itself (`__str__`). Because the domain of a parameter/variable can have
  187. a complicated relationship with other parameters and variables of a
  188. distribution, `_Domain` itself does not try to represent all possibilities;
  189. in fact, it has no implementation and is meant for subclassing.
  190. Attributes
  191. ----------
  192. symbols : dict
  193. A map from special numerical values to symbols for use in `__str__`
  194. Methods
  195. -------
  196. contains(x)
  197. Determine whether the argument is contained within the domain (True)
  198. or not (False). Used for input validation.
  199. get_numerical_endpoints()
  200. Gets the numerical values of the domain endpoints, which may have been
  201. defined symbolically or through a callable.
  202. __str__()
  203. Returns a text representation of the domain (e.g. ``[0, b)``).
  204. Used for generating documentation.
  205. """
  206. symbols = {np.inf: r"\infty", -np.inf: r"-\infty", np.pi: r"\pi", -np.pi: r"-\pi"}
  207. # generic type compatibility with scipy-stubs
  208. __class_getitem__ = classmethod(GenericAlias)
  209. @abstractmethod
  210. def contains(self, x):
  211. raise NotImplementedError()
  212. @abstractmethod
  213. def draw(self, n):
  214. raise NotImplementedError()
  215. @abstractmethod
  216. def get_numerical_endpoints(self, x):
  217. raise NotImplementedError()
  218. @abstractmethod
  219. def __str__(self):
  220. raise NotImplementedError()
  221. class _Interval(_Domain):
  222. r""" Representation of an interval defined by two endpoints.
  223. Each endpoint may be a finite scalar, positive or negative infinity, or
  224. be given by a single parameter. The domain may include the endpoints or
  225. not.
  226. This class still does not provide an implementation of the __str__ method,
  227. so it is meant for subclassing (e.g. a subclass for domains on the real
  228. line).
  229. Attributes
  230. ----------
  231. symbols : dict
  232. Inherited. A map from special values to symbols for use in `__str__`.
  233. endpoints : 2-tuple of float(s) and/or str(s) and/or callable(s).
  234. A tuple with two values. Each may be either a float (the numerical
  235. value of the endpoints of the domain), a string (the name of the
  236. parameters that will define the endpoint), or a callable taking the
  237. parameters used to define the endpoints of the domain as keyword only
  238. arguments and returning a numerical value for the endpoint.
  239. inclusive : 2-tuple of bools
  240. A tuple with two boolean values; each indicates whether the
  241. corresponding endpoint is included within the domain or not.
  242. Methods
  243. -------
  244. define_parameters(*parameters)
  245. Records any parameters used to define the endpoints of the domain
  246. get_numerical_endpoints(parameter_values)
  247. Gets the numerical values of the domain endpoints, which may have been
  248. defined symbolically or through a callable.
  249. contains(item, parameter_values)
  250. Determines whether the argument is contained within the domain
  251. draw(size, rng, proportions, parameter_values)
  252. Draws random values based on the domain.
  253. """
  254. def __init__(self, endpoints=(-inf, inf), inclusive=(False, False)):
  255. self.symbols = super().symbols.copy()
  256. a, b = endpoints
  257. self.endpoints = np.asarray(a)[()], np.asarray(b)[()]
  258. self.inclusive = inclusive
  259. def define_parameters(self, *parameters):
  260. r""" Records any parameters used to define the endpoints of the domain.
  261. Adds the keyword name of each parameter and its text representation
  262. to the `symbols` attribute as key:value pairs.
  263. For instance, a parameter may be passed into to a distribution's
  264. initializer using the keyword `log_a`, and the corresponding
  265. string representation may be '\log(a)'. To form the text
  266. representation of the domain for use in documentation, the
  267. _Domain object needs to map from the keyword name used in the code
  268. to the string representation.
  269. Returns None, but updates the `symbols` attribute.
  270. Parameters
  271. ----------
  272. *parameters : _Parameter objects
  273. Parameters that may define the endpoints of the domain.
  274. """
  275. new_symbols = {param.name: param.symbol for param in parameters}
  276. self.symbols.update(new_symbols)
  277. def get_numerical_endpoints(self, parameter_values):
  278. r""" Get the numerical values of the domain endpoints.
  279. Domain endpoints may be defined symbolically or through a callable.
  280. This returns numerical values of the endpoints given numerical values for
  281. any variables.
  282. Parameters
  283. ----------
  284. parameter_values : dict
  285. A dictionary that maps between string variable names and numerical
  286. values of parameters, which may define the endpoints.
  287. Returns
  288. -------
  289. a, b : ndarray
  290. Numerical values of the endpoints
  291. """
  292. a, b = self.endpoints
  293. # If `a` (`b`) is a string - the name of the parameter that defines
  294. # the endpoint of the domain - then corresponding numerical values
  295. # will be found in the `parameter_values` dictionary.
  296. # If a callable, it will be executed with `parameter_values` passed as
  297. # keyword arguments, and it will return the numerical values.
  298. # Otherwise, it is itself the array of numerical values of the endpoint.
  299. try:
  300. if callable(a):
  301. a = a(**parameter_values)
  302. else:
  303. a = np.asarray(parameter_values.get(a, a))
  304. if callable(b):
  305. b = b(**parameter_values)
  306. else:
  307. b = np.asarray(parameter_values.get(b, b))
  308. except TypeError as e:
  309. message = ("The endpoints of the distribution are defined by "
  310. "parameters, but their values were not provided. When "
  311. f"using a private method of {self.__class__}, pass "
  312. "all required distribution parameters as keyword "
  313. "arguments.")
  314. raise TypeError(message) from e
  315. # Floating point types are used for even integer parameters.
  316. # Convert to float here to ensure consistency throughout framework.
  317. a, b = xp_promote(a, b, force_floating=True, xp=np)
  318. return a, b
  319. def contains(self, item, parameter_values=None):
  320. r"""Determine whether the argument is contained within the domain.
  321. Parameters
  322. ----------
  323. item : ndarray
  324. The argument
  325. parameter_values : dict
  326. A dictionary that maps between string variable names and numerical
  327. values of parameters, which may define the endpoints.
  328. Returns
  329. -------
  330. out : bool
  331. True if `item` is within the domain; False otherwise.
  332. """
  333. parameter_values = parameter_values or {}
  334. # if self.all_inclusive:
  335. # # Returning a 0d value here makes things much faster.
  336. # # I'm not sure if it's safe, though. If it causes a bug someday,
  337. # # I guess it wasn't.
  338. # # Even if there is no bug because of the shape, it is incorrect for
  339. # # `contains` to return True when there are invalid (e.g. NaN)
  340. # # parameters.
  341. # return np.asarray(True)
  342. a, b = self.get_numerical_endpoints(parameter_values)
  343. left_inclusive, right_inclusive = self.inclusive
  344. in_left = item >= a if left_inclusive else item > a
  345. in_right = item <= b if right_inclusive else item < b
  346. return in_left & in_right
  347. def draw(self, n, type_, min, max, squeezed_base_shape, rng=None):
  348. r""" Draw random values from the domain.
  349. Parameters
  350. ----------
  351. n : int
  352. The number of values to be drawn from the domain.
  353. type_ : str
  354. A string indicating whether the values are
  355. - strictly within the domain ('in'),
  356. - at one of the two endpoints ('on'),
  357. - strictly outside the domain ('out'), or
  358. - NaN ('nan').
  359. min, max : ndarray
  360. The endpoints of the domain.
  361. squeezed_based_shape : tuple of ints
  362. See _RealParameter.draw.
  363. rng : np.Generator
  364. The Generator used for drawing random values.
  365. """
  366. rng = np.random.default_rng(rng)
  367. def ints(*args, **kwargs): return rng.integers(*args, **kwargs, endpoint=True)
  368. uniform = rng.uniform if isinstance(self, _RealInterval) else ints
  369. # get copies of min and max with no nans so that uniform doesn't fail
  370. min_nn, max_nn = min.copy(), max.copy()
  371. i = np.isnan(min_nn) | np.isnan(max_nn)
  372. min_nn[i] = 0
  373. max_nn[i] = 1
  374. shape = (n,) + squeezed_base_shape
  375. if type_ == 'in':
  376. z = uniform(min_nn, max_nn, size=shape)
  377. elif type_ == 'on':
  378. z_on_shape = shape
  379. z = np.ones(z_on_shape)
  380. i = rng.random(size=n) < 0.5
  381. z[i] = min
  382. z[~i] = max
  383. elif type_ == 'out':
  384. z = min_nn - uniform(1, 5, size=shape) # 1, 5 is arbitary; we just want
  385. zr = max_nn + uniform(1, 5, size=shape) # some numbers outside domain
  386. i = rng.random(size=n) < 0.5
  387. z[i] = zr[i]
  388. elif type_ == 'nan':
  389. z = np.full(shape, np.nan)
  390. return z
  391. class _RealInterval(_Interval):
  392. r""" Represents a simply-connected subset of the real line; i.e., an interval
  393. Completes the implementation of the `_Interval` class for intervals
  394. on the real line.
  395. Methods
  396. -------
  397. define_parameters(*parameters)
  398. (Inherited) Records any parameters used to define the endpoints of the
  399. domain.
  400. get_numerical_endpoints(parameter_values)
  401. (Inherited) Gets the numerical values of the domain endpoints, which
  402. may have been defined symbolically.
  403. contains(item, parameter_values)
  404. (Inherited) Determines whether the argument is contained within the
  405. domain
  406. __str__()
  407. Returns a string representation of the domain, e.g. "[a, b)".
  408. """
  409. def __str__(self):
  410. a, b = self.endpoints
  411. a, b = self._get_endpoint_str(a, "f1"), self._get_endpoint_str(b, "f2")
  412. left_inclusive, right_inclusive = self.inclusive
  413. left = "[" if left_inclusive else "("
  414. right = "]" if right_inclusive else ")"
  415. return f"{left}{a}, {b}{right}"
  416. def _get_endpoint_str(self, endpoint, funcname):
  417. if callable(endpoint):
  418. if endpoint.__doc__ is not None:
  419. return endpoint.__doc__
  420. params = inspect.signature(endpoint).parameters.values()
  421. params = [
  422. p.name for p in params if p.kind == inspect.Parameter.KEYWORD_ONLY
  423. ]
  424. return f"{funcname}({','.join(params)})"
  425. return self.symbols.get(endpoint, f"{endpoint}")
  426. class _IntegerInterval(_Interval):
  427. r""" Represents an interval of integers
  428. Completes the implementation of the `_Interval` class for simple
  429. domains on the integers.
  430. Methods
  431. -------
  432. define_parameters(*parameters)
  433. (Inherited) Records any parameters used to define the endpoints of the
  434. domain.
  435. get_numerical_endpoints(parameter_values)
  436. (Inherited) Gets the numerical values of the domain endpoints, which
  437. may have been defined symbolically.
  438. contains(item, parameter_values)
  439. (Overridden) Determines whether the argument is contained within the
  440. domain
  441. draw(n, type_, min, max, squeezed_base_shape, rng=None)
  442. (Inherited) Draws random values based on the domain.
  443. __str__()
  444. Returns a string representation of the domain, e.g. "{a, a+1, ..., b-1, b}".
  445. """
  446. def contains(self, item, parameter_values=None):
  447. super_contains = super().contains(item, parameter_values)
  448. integral = (item == np.round(item))
  449. return super_contains & integral
  450. def __str__(self):
  451. a, b = self.endpoints
  452. a = self.symbols.get(a, a)
  453. b = self.symbols.get(b, b)
  454. a_str, b_str = isinstance(a, str), isinstance(b, str)
  455. a_inf = a == r"-\infty" if a_str else np.isinf(a)
  456. b_inf = b == r"\infty" if b_str else np.isinf(b)
  457. # This doesn't work well for cases where ``a`` is floating point
  458. # number large enough that ``nextafter(a, inf) > a + 1``, and
  459. # similarly for ``b`` and nextafter(b, -inf). There may not be any
  460. # distributions fit for SciPy where we would actually need to handle these
  461. # cases though.
  462. ap1 = f"{a} + 1" if a_str else f"{a + 1}"
  463. bm1 = f"{b} - 1" if b_str else f"{b - 1}"
  464. if not a_str and not b_str:
  465. gap = b - a
  466. if gap == 3:
  467. return f"\\{{{a}, {ap1}, {bm1}, {b}\\}}"
  468. if gap == 2:
  469. return f"\\{{{a}, {ap1}, {b}\\}}"
  470. if gap == 1:
  471. return f"\\{{{a}, {b}\\}}"
  472. if gap == 0:
  473. return f"\\{{{a}\\}}"
  474. if not a_inf and b_inf:
  475. ap2 = f"{a} + 2" if a_str else f"{a + 2}"
  476. return f"\\{{{a}, {ap1}, {ap2}, ...\\}}"
  477. if a_inf and not b_inf:
  478. bm2 = f"{b} - 2" if b_str else f"{b - 2}"
  479. return f"\\{{{b}, {bm1}, {bm2}, ...\\}}"
  480. if a_inf and b_inf:
  481. return "\\{..., -2, -1, 0, 1, 2, ...\\}"
  482. return f"\\{{{a}, {ap1}, ..., {bm1}, {b}\\}}"
  483. class _Parameter(ABC):
  484. r""" Representation of a distribution parameter or variable.
  485. A `_Parameter` object is responsible for storing information about a
  486. parameter or variable, providing input validation/standardization of
  487. values passed for that parameter, providing a text/mathematical
  488. representation of the parameter for the documentation (`__str__`), and
  489. drawing random values of itself for testing and benchmarking. It does
  490. not provide a complete implementation of this functionality and is meant
  491. for subclassing.
  492. Attributes
  493. ----------
  494. name : str
  495. The keyword used to pass numerical values of the parameter into the
  496. initializer of the distribution
  497. symbol : str
  498. The text representation of the variable in the documentation. May
  499. include LaTeX.
  500. domain : _Domain
  501. The domain of the parameter for which the distribution is valid.
  502. typical : 2-tuple of floats or strings (consider making a _Domain)
  503. Defines the endpoints of a typical range of values of the parameter.
  504. Used for sampling.
  505. Methods
  506. -------
  507. __str__():
  508. Returns a string description of the variable for use in documentation,
  509. including the keyword used to represent it in code, the symbol used to
  510. represent it mathemtatically, and a description of the valid domain.
  511. draw(size, *, rng, domain, proportions)
  512. Draws random values of the parameter. Proportions of values within
  513. the valid domain, on the endpoints of the domain, outside the domain,
  514. and having value NaN are specified by `proportions`.
  515. validate(x):
  516. Validates and standardizes the argument for use as numerical values
  517. of the parameter.
  518. """
  519. # generic type compatibility with scipy-stubs
  520. __class_getitem__ = classmethod(GenericAlias)
  521. def __init__(self, name, *, domain, symbol=None, typical=None):
  522. self.name = name
  523. self.symbol = symbol or name
  524. self.domain = domain
  525. if typical is not None and not isinstance(typical, _Domain):
  526. typical = domain.__class__(typical)
  527. self.typical = typical or domain
  528. def __str__(self):
  529. r""" String representation of the parameter for use in documentation."""
  530. return f"`{self.name}` for :math:`{self.symbol} \\in {str(self.domain)}`"
  531. def draw(self, size=None, *, rng=None, region='domain', proportions=None,
  532. parameter_values=None):
  533. r""" Draw random values of the parameter for use in testing.
  534. Parameters
  535. ----------
  536. size : tuple of ints
  537. The shape of the array of valid values to be drawn.
  538. rng : np.Generator
  539. The Generator used for drawing random values.
  540. region : str
  541. The region of the `_Parameter` from which to draw. Default is
  542. "domain" (the *full* domain); alternative is "typical". An
  543. enhancement would give a way to interpolate between the two.
  544. proportions : tuple of numbers
  545. A tuple of four non-negative numbers that indicate the expected
  546. relative proportion of elements that:
  547. - are strictly within the domain,
  548. - are at one of the two endpoints,
  549. - are strictly outside the domain, and
  550. - are NaN,
  551. respectively. Default is (1, 0, 0, 0). The number of elements in
  552. each category is drawn from the multinomial distribution with
  553. `np.prod(size)` as the number of trials and `proportions` as the
  554. event probabilities. The values in `proportions` are automatically
  555. normalized to sum to 1.
  556. parameter_values : dict
  557. Map between the names of parameters (that define the endpoints of
  558. `typical`) and numerical values (arrays).
  559. """
  560. parameter_values = parameter_values or {}
  561. domain = self.domain
  562. proportions = (1, 0, 0, 0) if proportions is None else proportions
  563. pvals = proportions / np.sum(proportions)
  564. a, b = domain.get_numerical_endpoints(parameter_values)
  565. a, b = np.broadcast_arrays(a, b)
  566. base_shape = a.shape
  567. extended_shape = np.broadcast_shapes(size, base_shape)
  568. n_extended = np.prod(extended_shape)
  569. n_base = np.prod(base_shape)
  570. n = int(n_extended / n_base) if n_extended else 0
  571. rng = np.random.default_rng(rng)
  572. n_in, n_on, n_out, n_nan = rng.multinomial(n, pvals)
  573. # `min` and `max` can have singleton dimensions that correspond with
  574. # non-singleton dimensions in `size`. We need to be careful to avoid
  575. # shuffling results (e.g. a value that was generated for the domain
  576. # [min[i], max[i]] ends up at index j). To avoid this:
  577. # - Squeeze the singleton dimensions out of `min`/`max`. Squeezing is
  578. # often not the right thing to do, but here is equivalent to moving
  579. # all the dimensions that are singleton in `min`/`max` (which may be
  580. # non-singleton in the result) to the left. This is what we want.
  581. # - Now all the non-singleton dimensions of the result are on the left.
  582. # Ravel them to a single dimension of length `n`, which is now along
  583. # the 0th axis.
  584. # - Reshape the 0th axis back to the required dimensions, and move
  585. # these axes back to their original places.
  586. base_shape_padded = ((1,)*(len(extended_shape) - len(base_shape))
  587. + base_shape)
  588. base_singletons = np.where(np.asarray(base_shape_padded)==1)[0]
  589. new_base_singletons = tuple(range(len(base_singletons)))
  590. # Base singleton dimensions are going to get expanded to these lengths
  591. shape_expansion = np.asarray(extended_shape)[base_singletons]
  592. # assert(np.prod(shape_expansion) == n) # check understanding
  593. # min = np.reshape(min, base_shape_padded)
  594. # max = np.reshape(max, base_shape_padded)
  595. # min = np.moveaxis(min, base_singletons, new_base_singletons)
  596. # max = np.moveaxis(max, base_singletons, new_base_singletons)
  597. # squeezed_base_shape = max.shape[len(base_singletons):]
  598. # assert np.all(min.reshape(squeezed_base_shape) == min.squeeze())
  599. # assert np.all(max.reshape(squeezed_base_shape) == max.squeeze())
  600. # min = np.maximum(a, _fiinfo(a).min/10) if np.any(np.isinf(a)) else a
  601. # max = np.minimum(b, _fiinfo(b).max/10) if np.any(np.isinf(b)) else b
  602. min = np.asarray(a.squeeze())
  603. max = np.asarray(b.squeeze())
  604. squeezed_base_shape = max.shape
  605. if region == 'typical':
  606. typical = self.typical
  607. a, b = typical.get_numerical_endpoints(parameter_values)
  608. a, b = np.broadcast_arrays(a, b)
  609. min_here = np.asarray(a.squeeze())
  610. max_here = np.asarray(b.squeeze())
  611. z_in = typical.draw(n_in, 'in', min_here, max_here, squeezed_base_shape,
  612. rng=rng)
  613. else:
  614. z_in = domain.draw(n_in, 'in', min, max, squeezed_base_shape, rng=rng)
  615. z_on = domain.draw(n_on, 'on', min, max, squeezed_base_shape, rng=rng)
  616. z_out = domain.draw(n_out, 'out', min, max, squeezed_base_shape, rng=rng)
  617. z_nan= domain.draw(n_nan, 'nan', min, max, squeezed_base_shape, rng=rng)
  618. z = np.concatenate((z_in, z_on, z_out, z_nan), axis=0)
  619. z = rng.permuted(z, axis=0)
  620. z = np.reshape(z, tuple(shape_expansion) + squeezed_base_shape)
  621. z = np.moveaxis(z, new_base_singletons, base_singletons)
  622. return z
  623. @abstractmethod
  624. def validate(self, arr):
  625. raise NotImplementedError()
  626. class _RealParameter(_Parameter):
  627. r""" Represents a real-valued parameter.
  628. Implements the remaining methods of _Parameter for real parameters.
  629. All attributes are inherited.
  630. """
  631. def validate(self, arr, parameter_values):
  632. r""" Input validation/standardization of numerical values of a parameter.
  633. Checks whether elements of the argument `arr` are reals, ensuring that
  634. the dtype reflects this. Also produces a logical array that indicates
  635. which elements meet the requirements.
  636. Parameters
  637. ----------
  638. arr : ndarray
  639. The argument array to be validated and standardized.
  640. parameter_values : dict
  641. Map of parameter names to parameter value arrays.
  642. Returns
  643. -------
  644. arr : ndarray
  645. The argument array that has been validated and standardized
  646. (converted to an appropriate dtype, if necessary).
  647. dtype : NumPy dtype
  648. The appropriate floating point dtype of the parameter.
  649. valid : boolean ndarray
  650. Logical array indicating which elements are valid (True) and
  651. which are not (False). The arrays of all distribution parameters
  652. will be broadcasted, and elements for which any parameter value
  653. does not meet the requirements will be replaced with NaN.
  654. """
  655. arr = np.asarray(arr)
  656. valid_dtype = None
  657. # minor optimization - fast track the most common types to avoid
  658. # overhead of np.issubdtype. Checking for `in {...}` doesn't work : /
  659. if arr.dtype == np.float64 or arr.dtype == np.float32:
  660. pass
  661. elif arr.dtype == np.int32 or arr.dtype == np.int64:
  662. arr = np.asarray(arr, dtype=np.float64)
  663. elif np.issubdtype(arr.dtype, np.floating):
  664. pass
  665. elif np.issubdtype(arr.dtype, np.integer):
  666. arr = np.asarray(arr, dtype=np.float64)
  667. else:
  668. message = f"Parameter `{self.name}` must be of real dtype."
  669. raise TypeError(message)
  670. valid = self.domain.contains(arr, parameter_values)
  671. valid = valid & valid_dtype if valid_dtype is not None else valid
  672. return arr[()], arr.dtype, valid
  673. class _Parameterization:
  674. r""" Represents a parameterization of a distribution.
  675. Distributions can have multiple parameterizations. A `_Parameterization`
  676. object is responsible for recording the parameters used by the
  677. parameterization, checking whether keyword arguments passed to the
  678. distribution match the parameterization, and performing input validation
  679. of the numerical values of these parameters.
  680. Attributes
  681. ----------
  682. parameters : dict
  683. String names (of keyword arguments) and the corresponding _Parameters.
  684. Methods
  685. -------
  686. __len__()
  687. Returns the number of parameters in the parameterization.
  688. __str__()
  689. Returns a string representation of the parameterization.
  690. copy
  691. Returns a copy of the parameterization. This is needed for transformed
  692. distributions that add parameters to the parameterization.
  693. matches(parameters)
  694. Checks whether the keyword arguments match the parameterization.
  695. validation(parameter_values)
  696. Input validation / standardization of parameterization. Validates the
  697. numerical values of all parameters.
  698. draw(sizes, rng, proportions)
  699. Draw random values of all parameters of the parameterization for use
  700. in testing.
  701. """
  702. def __init__(self, *parameters):
  703. self.parameters = {param.name: param for param in parameters}
  704. def __len__(self):
  705. return len(self.parameters)
  706. def copy(self):
  707. return _Parameterization(*self.parameters.values())
  708. def matches(self, parameters):
  709. r""" Checks whether the keyword arguments match the parameterization.
  710. Parameters
  711. ----------
  712. parameters : set
  713. Set of names of parameters passed into the distribution as keyword
  714. arguments.
  715. Returns
  716. -------
  717. out : bool
  718. True if the keyword arguments names match the names of the
  719. parameters of this parameterization.
  720. """
  721. return parameters == set(self.parameters.keys())
  722. def validation(self, parameter_values):
  723. r""" Input validation / standardization of parameterization.
  724. Parameters
  725. ----------
  726. parameter_values : dict
  727. The keyword arguments passed as parameter values to the
  728. distribution.
  729. Returns
  730. -------
  731. all_valid : ndarray
  732. Logical array indicating the elements of the broadcasted arrays
  733. for which all parameter values are valid.
  734. dtype : dtype
  735. The common dtype of the parameter arrays. This will determine
  736. the dtype of the output of distribution methods.
  737. """
  738. all_valid = True
  739. dtypes = set() # avoid np.result_type if there's only one type
  740. for name, arr in parameter_values.items():
  741. parameter = self.parameters[name]
  742. arr, dtype, valid = parameter.validate(arr, parameter_values)
  743. dtypes.add(dtype)
  744. all_valid = all_valid & valid
  745. parameter_values[name] = arr
  746. dtype = arr.dtype if len(dtypes)==1 else np.result_type(*list(dtypes))
  747. return all_valid, dtype
  748. def __str__(self):
  749. r"""Returns a string representation of the parameterization."""
  750. messages = [str(param) for name, param in self.parameters.items()]
  751. return ", ".join(messages)
  752. def draw(self, sizes=None, rng=None, proportions=None, region='domain'):
  753. r"""Draw random values of all parameters for use in testing.
  754. Parameters
  755. ----------
  756. sizes : iterable of shape tuples
  757. The size of the array to be generated for each parameter in the
  758. parameterization. Note that the order of sizes is arbitary; the
  759. size of the array generated for a specific parameter is not
  760. controlled individually as written.
  761. rng : NumPy Generator
  762. The generator used to draw random values.
  763. proportions : tuple
  764. A tuple of four non-negative numbers that indicate the expected
  765. relative proportion of elements that are within the parameter's
  766. domain, are on the boundary of the parameter's domain, are outside
  767. the parameter's domain, and have value NaN. For more information,
  768. see the `draw` method of the _Parameter subclasses.
  769. domain : str
  770. The domain of the `_Parameter` from which to draw. Default is
  771. "domain" (the *full* domain); alternative is "typical".
  772. Returns
  773. -------
  774. parameter_values : dict (string: array)
  775. A dictionary of parameter name/value pairs.
  776. """
  777. # ENH: be smart about the order. The domains of some parameters
  778. # depend on others. If the relationshp is simple (e.g. a < b < c),
  779. # we can draw values in order a, b, c.
  780. parameter_values = {}
  781. if sizes is None or not len(sizes) or not np.iterable(sizes[0]):
  782. sizes = [sizes]*len(self.parameters)
  783. for size, param in zip(sizes, self.parameters.values()):
  784. parameter_values[param.name] = param.draw(
  785. size, rng=rng, proportions=proportions,
  786. parameter_values=parameter_values,
  787. region=region
  788. )
  789. return parameter_values
  790. def _set_invalid_nan(f):
  791. # Wrapper for input / output validation and standardization of distribution
  792. # functions that accept either the quantile or percentile as an argument:
  793. # logpdf, pdf
  794. # logpmf, pmf
  795. # logcdf, cdf
  796. # logccdf, ccdf
  797. # ilogcdf, icdf
  798. # ilogccdf, iccdf
  799. # Arguments that are outside the required range are replaced by NaN before
  800. # passing them into the underlying function. The corresponding outputs
  801. # are replaced by the appropriate value before being returned to the user.
  802. # For example, when the argument of `cdf` exceeds the right end of the
  803. # distribution's support, the wrapper replaces the argument with NaN,
  804. # ignores the output of the underlying function, and returns 1.0. It also
  805. # ensures that output is of the appropriate shape and dtype.
  806. endpoints = {'icdf': (0, 1), 'iccdf': (0, 1),
  807. 'ilogcdf': (-np.inf, 0), 'ilogccdf': (-np.inf, 0)}
  808. replacements = {'logpdf': (-inf, -inf), 'pdf': (0, 0),
  809. 'logpmf': (-inf, -inf), 'pmf': (0, 0),
  810. '_logcdf1': (-inf, 0), '_logccdf1': (0, -inf),
  811. '_cdf1': (0, 1), '_ccdf1': (1, 0)}
  812. replace_strict = {'pdf', 'logpdf', 'pmf', 'logpmf'}
  813. replace_exact = {'icdf', 'iccdf', 'ilogcdf', 'ilogccdf'}
  814. clip = {'_cdf1', '_ccdf1'}
  815. clip_log = {'_logcdf1', '_logccdf1'}
  816. # relevant to discrete distributions only
  817. replace_non_integral = {'pmf', 'logpmf', 'pdf', 'logpdf'}
  818. @functools.wraps(f)
  819. def filtered(self, x, *args, **kwargs):
  820. if self.validation_policy == _SKIP_ALL:
  821. return f(self, x, *args, **kwargs)
  822. method_name = f.__name__
  823. x = np.asarray(x)
  824. dtype = self._dtype
  825. shape = self._shape
  826. discrete = isinstance(self, DiscreteDistribution)
  827. keep_low_endpoint = discrete and method_name in {'_cdf1', '_logcdf1',
  828. '_ccdf1', '_logccdf1'}
  829. # Ensure that argument is at least as precise as distribution
  830. # parameters, which are already at least floats. This will avoid issues
  831. # with raising integers to negative integer powers and failure to replace
  832. # invalid integers with NaNs.
  833. if x.dtype != dtype:
  834. dtype = np.result_type(x.dtype, dtype)
  835. x = np.asarray(x, dtype=dtype)
  836. # Broadcasting is slow. Do it only if necessary.
  837. if not x.shape == shape:
  838. try:
  839. shape = np.broadcast_shapes(x.shape, shape)
  840. x = np.broadcast_to(x, shape)
  841. # Should we broadcast the distribution parameters to this shape, too?
  842. except ValueError as e:
  843. message = (
  844. f"The argument provided to `{self.__class__.__name__}"
  845. f".{method_name}` cannot be be broadcast to the same "
  846. "shape as the distribution parameters.")
  847. raise ValueError(message) from e
  848. low, high = endpoints.get(method_name, self.support())
  849. # Check for arguments outside of domain. They'll be replaced with NaNs,
  850. # and the result will be set to the appropriate value.
  851. left_inc, right_inc = self._variable.domain.inclusive
  852. mask_low = (x < low if (method_name in replace_strict and left_inc)
  853. or keep_low_endpoint else x <= low)
  854. mask_high = (x > high if (method_name in replace_strict and right_inc)
  855. else x >= high)
  856. mask_invalid = (mask_low | mask_high)
  857. any_invalid = (mask_invalid if mask_invalid.shape == ()
  858. else np.any(mask_invalid))
  859. # Check for arguments at domain endpoints, whether they
  860. # are part of the domain or not.
  861. any_endpoint = False
  862. if method_name in replace_exact:
  863. mask_low_endpoint = (x == low)
  864. mask_high_endpoint = (x == high)
  865. mask_endpoint = (mask_low_endpoint | mask_high_endpoint)
  866. any_endpoint = (mask_endpoint if mask_endpoint.shape == ()
  867. else np.any(mask_endpoint))
  868. # Check for non-integral arguments to PMF method
  869. # or PDF of a discrete distribution.
  870. any_non_integral = False
  871. if discrete and method_name in replace_non_integral:
  872. mask_non_integral = (x != np.floor(x))
  873. any_non_integral = (mask_non_integral if mask_non_integral.shape == ()
  874. else np.any(mask_non_integral))
  875. # Set out-of-domain arguments to NaN. The result will be set to the
  876. # appropriate value later.
  877. if any_invalid:
  878. x = np.array(x, dtype=dtype, copy=True)
  879. x[mask_invalid] = np.nan
  880. res = np.asarray(f(self, x, *args, **kwargs))
  881. # Ensure that the result is the correct dtype and shape,
  882. # copying (only once) if necessary.
  883. res_needs_copy = False
  884. if res.dtype != dtype:
  885. dtype = np.result_type(dtype, self._dtype)
  886. res_needs_copy = True
  887. if res.shape != shape: # faster to check first
  888. res = np.broadcast_to(res, self._shape)
  889. res_needs_copy = (res_needs_copy or any_invalid
  890. or any_endpoint or any_non_integral)
  891. if res_needs_copy:
  892. res = np.array(res, dtype=dtype, copy=True)
  893. # For non-integral arguments to PMF (and PDF of discrete distribution)
  894. # replace with zero.
  895. if any_non_integral:
  896. zero = -np.inf if method_name in {'logpmf', 'logpdf'} else 0
  897. res[mask_non_integral & ~np.isnan(res)] = zero
  898. # For arguments outside the function domain, replace results
  899. if any_invalid:
  900. replace_low, replace_high = (
  901. replacements.get(method_name, (np.nan, np.nan)))
  902. res[mask_low] = replace_low
  903. res[mask_high] = replace_high
  904. # For arguments at the endpoints of the domain, replace results
  905. if any_endpoint:
  906. a, b = self.support()
  907. if a.shape != shape:
  908. a = np.array(np.broadcast_to(a, shape), copy=True)
  909. b = np.array(np.broadcast_to(b, shape), copy=True)
  910. replace_low_endpoint = (
  911. b[mask_low_endpoint] if method_name.endswith('ccdf')
  912. else a[mask_low_endpoint])
  913. replace_high_endpoint = (
  914. a[mask_high_endpoint] if method_name.endswith('ccdf')
  915. else b[mask_high_endpoint])
  916. if not keep_low_endpoint:
  917. res[mask_low_endpoint] = replace_low_endpoint
  918. res[mask_high_endpoint] = replace_high_endpoint
  919. # Clip probabilities to [0, 1]
  920. if method_name in clip:
  921. res = np.clip(res, 0., 1.)
  922. elif method_name in clip_log:
  923. res = res.real # exp(res) > 0
  924. res = np.clip(res, None, 0.) # exp(res) < 1
  925. return res[()]
  926. return filtered
  927. def _set_invalid_nan_property(f):
  928. # Wrapper for input / output validation and standardization of distribution
  929. # functions that represent properties of the distribution itself:
  930. # logentropy, entropy
  931. # median, mode
  932. # moment
  933. # It ensures that the output is of the correct shape and dtype and that
  934. # there are NaNs wherever the distribution parameters were invalid.
  935. @functools.wraps(f)
  936. def filtered(self, *args, **kwargs):
  937. if self.validation_policy == _SKIP_ALL:
  938. return f(self, *args, **kwargs)
  939. res = f(self, *args, **kwargs)
  940. if res is None:
  941. # message could be more appropriate
  942. raise NotImplementedError(self._not_implemented)
  943. res = np.asarray(res)
  944. needs_copy = False
  945. dtype = res.dtype
  946. if dtype != self._dtype: # this won't work for logmoments (complex)
  947. dtype = np.result_type(dtype, self._dtype)
  948. needs_copy = True
  949. if res.shape != self._shape: # faster to check first
  950. res = np.broadcast_to(res, self._shape)
  951. needs_copy = needs_copy or self._any_invalid
  952. if needs_copy:
  953. res = res.astype(dtype=dtype, copy=True)
  954. if self._any_invalid:
  955. # may be redundant when quadrature is used, but not necessarily
  956. # when formulas are used.
  957. res[self._invalid] = np.nan
  958. return res[()]
  959. return filtered
  960. def _dispatch(f):
  961. # For each public method (instance function) of a distribution (e.g. ccdf),
  962. # there may be several ways ("method"s) that it can be computed (e.g. a
  963. # formula, as the complement of the CDF, or via numerical integration).
  964. # Each "method" is implemented by a different private method (instance
  965. # function).
  966. # This wrapper calls the appropriate private method based on the public
  967. # method and any specified `method` keyword option.
  968. # - If `method` is specified as a string (by the user), the appropriate
  969. # private method is called.
  970. # - If `method` is None:
  971. # - The appropriate private method for the public method is looked up
  972. # in a cache.
  973. # - If the cache does not have an entry for the public method, the
  974. # appropriate "dispatch " function is called to determine which method
  975. # is most appropriate given the available private methods and
  976. # settings (e.g. tolerance).
  977. @functools.wraps(f)
  978. def wrapped(self, *args, method=None, **kwargs):
  979. func_name = f.__name__
  980. method = method or self._method_cache.get(func_name, None)
  981. if callable(method):
  982. pass
  983. elif method is not None:
  984. method = 'logexp' if method == 'log/exp' else method
  985. method_name = func_name.replace('dispatch', method)
  986. method = getattr(self, method_name)
  987. else:
  988. method = f(self, *args, method=method, **kwargs)
  989. if func_name != '_sample_dispatch' and self.cache_policy != _NO_CACHE:
  990. self._method_cache[func_name] = method
  991. try:
  992. return method(*args, **kwargs)
  993. except KeyError as e:
  994. raise NotImplementedError(self._not_implemented) from e
  995. return wrapped
  996. def _cdf2_input_validation(f):
  997. # Wrapper that does the job of `_set_invalid_nan` when `cdf` or `logcdf`
  998. # is called with two quantile arguments.
  999. # Let's keep it simple; no special cases for speed right now.
  1000. # The strategy is a bit different than for 1-arg `cdf` (and other methods
  1001. # covered by `_set_invalid_nan`). For 1-arg `cdf`, elements of `x` that
  1002. # are outside (or at the edge of) the support get replaced by `nan`,
  1003. # and then the results get replaced by the appropriate value (0 or 1).
  1004. # We *could* do something similar, dispatching to `_cdf1` in these
  1005. # cases. That would be a bit more robust, but it would also be quite
  1006. # a bit more complex, since we'd have to do different things when
  1007. # `x` and `y` are both out of bounds, when just `x` is out of bounds,
  1008. # when just `y` is out of bounds, and when both are out of bounds.
  1009. # I'm not going to do that right now. Instead, simply replace values
  1010. # outside the support by those at the edge of the support. Here, we also
  1011. # omit some of the optimizations that make `_set_invalid_nan` faster for
  1012. # simple arguments (e.g. float64 scalars).
  1013. @functools.wraps(f)
  1014. def wrapped(self, x, y, *args, **kwargs):
  1015. func_name = f.__name__
  1016. low, high = self.support()
  1017. x, y, low, high = np.broadcast_arrays(x, y, low, high)
  1018. dtype = np.result_type(x.dtype, y.dtype, self._dtype)
  1019. # yes, copy to avoid modifying input arrays
  1020. x, y = x.astype(dtype, copy=True), y.astype(dtype, copy=True)
  1021. # Swap arguments to ensure that x < y, and replace
  1022. # out-of domain arguments with domain endpoints. We'll
  1023. # transform the result later.
  1024. i_swap = y < x
  1025. x[i_swap], y[i_swap] = y[i_swap], x[i_swap]
  1026. i = x < low
  1027. x[i] = low[i]
  1028. i = y < low
  1029. y[i] = low[i]
  1030. i = x > high
  1031. x[i] = high[i]
  1032. i = y > high
  1033. y[i] = high[i]
  1034. res = f(self, x, y, *args, **kwargs)
  1035. # Clipping probability to [0, 1]
  1036. if func_name in {'_cdf2', '_ccdf2'}:
  1037. res = np.clip(res, 0., 1.)
  1038. else:
  1039. res = np.clip(res, None, 0.) # exp(res) < 1
  1040. # Transform the result to account for swapped argument order
  1041. res = np.asarray(res)
  1042. if func_name == '_cdf2':
  1043. res[i_swap] *= -1.
  1044. elif func_name == '_ccdf2':
  1045. res[i_swap] *= -1
  1046. res[i_swap] += 2.
  1047. elif func_name == '_logcdf2':
  1048. res = np.asarray(res + 0j) if np.any(i_swap) else res
  1049. res[i_swap] = res[i_swap] + np.pi*1j
  1050. else:
  1051. # res[i_swap] is always positive and less than 1, so it's
  1052. # safe to ensure that the result is real
  1053. res[i_swap] = _logexpxmexpy(np.log(2), res[i_swap]).real
  1054. return res[()]
  1055. return wrapped
  1056. def _fiinfo(x):
  1057. if np.issubdtype(x.dtype, np.inexact):
  1058. return np.finfo(x.dtype)
  1059. else:
  1060. return np.iinfo(x)
  1061. def _kwargs2args(f, args=None, kwargs=None):
  1062. # Wraps a function that accepts a primary argument `x`, secondary
  1063. # arguments `args`, and secondary keyward arguments `kwargs` such that the
  1064. # wrapper accepts only `x` and `args`. The keyword arguments are extracted
  1065. # from `args` passed into the wrapper, and these are passed to the
  1066. # underlying function as `kwargs`.
  1067. # This is a temporary workaround until the scalar algorithms `_tanhsinh`,
  1068. # `_chandrupatla`, etc., support `kwargs` or can operate with compressing
  1069. # arguments to the callable.
  1070. args = args or []
  1071. kwargs = kwargs or {}
  1072. names = list(kwargs.keys())
  1073. n_args = len(args)
  1074. def wrapped(x, *args):
  1075. return f(x, *args[:n_args], **dict(zip(names, args[n_args:])))
  1076. args = tuple(args) + tuple(kwargs.values())
  1077. return wrapped, args
  1078. def _logexpxmexpy(x, y):
  1079. """ Compute the log of the difference of the exponentials of two arguments.
  1080. Avoids over/underflow, but does not prevent loss of precision otherwise.
  1081. """
  1082. # TODO: properly avoid NaN when y is negative infinity
  1083. # TODO: silence warning with taking log of complex nan
  1084. # TODO: deal with x == y better
  1085. i = np.isneginf(np.real(y))
  1086. if np.any(i):
  1087. y = np.asarray(y.copy())
  1088. y[i] = np.finfo(y.dtype).min
  1089. x, y = np.broadcast_arrays(x, y)
  1090. res = np.asarray(special.logsumexp([x, y+np.pi*1j], axis=0))
  1091. i = (x == y)
  1092. res[i] = -np.inf
  1093. return res
  1094. def _guess_bracket(xmin, xmax):
  1095. a = np.full_like(xmin, -1.0)
  1096. b = np.ones_like(xmax)
  1097. i = np.isfinite(xmin) & np.isfinite(xmax)
  1098. a[i] = xmin[i]
  1099. b[i] = xmax[i]
  1100. i = np.isfinite(xmin) & ~np.isfinite(xmax)
  1101. a[i] = xmin[i]
  1102. b[i] = xmin[i] + 1
  1103. i = np.isfinite(xmax) & ~np.isfinite(xmin)
  1104. a[i] = xmax[i] - 1
  1105. b[i] = xmax[i]
  1106. return a, b
  1107. def _log_real_standardize(x):
  1108. """Standardizes the (complex) logarithm of a real number.
  1109. The logarithm of a real number may be represented by a complex number with
  1110. imaginary part that is a multiple of pi*1j. Even multiples correspond with
  1111. a positive real and odd multiples correspond with a negative real.
  1112. Given a logarithm of a real number `x`, this function returns an equivalent
  1113. representation in a standard form: the log of a positive real has imaginary
  1114. part `0` and the log of a negative real has imaginary part `pi`.
  1115. """
  1116. shape = x.shape
  1117. x = np.atleast_1d(x)
  1118. real = np.real(x).astype(x.dtype)
  1119. complex = np.imag(x)
  1120. y = real
  1121. negative = np.exp(complex*1j) < 0.5
  1122. y[negative] = y[negative] + np.pi * 1j
  1123. return y.reshape(shape)[()]
  1124. def _combine_docs(dist_family, *, include_examples=True):
  1125. fields = set(NumpyDocString.sections)
  1126. fields.remove('index')
  1127. if not include_examples:
  1128. fields.remove('Examples')
  1129. doc = ClassDoc(dist_family)
  1130. superdoc = ClassDoc(UnivariateDistribution)
  1131. for field in fields:
  1132. if field in {"Methods", "Attributes"}:
  1133. doc[field] = superdoc[field]
  1134. elif field in {"Summary"}:
  1135. pass
  1136. elif field == "Extended Summary":
  1137. doc[field].append(_generate_domain_support(dist_family))
  1138. elif field == 'Examples':
  1139. doc[field] = [_generate_example(dist_family)]
  1140. else:
  1141. doc[field] += superdoc[field]
  1142. return str(doc)
  1143. def _generate_domain_support(dist_family):
  1144. n_parameterizations = len(dist_family._parameterizations)
  1145. domain = f"\nfor :math:`x \\in {dist_family._variable.domain}`.\n"
  1146. if n_parameterizations == 0:
  1147. support = """
  1148. This class accepts no distribution parameters.
  1149. """
  1150. elif n_parameterizations == 1:
  1151. support = f"""
  1152. This class accepts one parameterization:
  1153. {str(dist_family._parameterizations[0])}.
  1154. """
  1155. else:
  1156. number = {2: 'two', 3: 'three', 4: 'four', 5: 'five'}[
  1157. n_parameterizations]
  1158. parameterizations = [f"- {str(p)}" for p in
  1159. dist_family._parameterizations]
  1160. parameterizations = "\n".join(parameterizations)
  1161. support = f"""
  1162. This class accepts {number} parameterizations:
  1163. {parameterizations}
  1164. """
  1165. support = "\n".join([line.lstrip() for line in support.split("\n")][1:])
  1166. return domain + support
def _generate_example(dist_family):
    # Build the text of a documentation example for `dist_family`.
    #
    # An instance of the family is drawn with scalar parameters using a fixed
    # seed, its parameters are rounded to two decimals, and the outputs of
    # its methods are interpolated into a doctest-style template. The text
    # is returned with the leading whitespace of every line removed and the
    # blank first line dropped.
    n_parameters = dist_family._num_parameters(0)
    shapes = [()] * n_parameters
    rng = np.random.default_rng(615681484984984)
    i = 0
    dist = dist_family._draw(shapes, rng=rng, i_parameterization=i)

    # NOTE(review): this second generator is not referenced again in the
    # code visible here - confirm whether it is still needed.
    rng = np.random.default_rng(2354873452)
    name = dist_family.__name__
    if n_parameters:
        # Re-instantiate the family from rounded parameter values so that
        # the instantiation shown in the example matches the object whose
        # outputs are printed.
        parameter_names = list(dist._parameterizations[i].parameters)
        parameter_values = [round(getattr(dist, name), 2) for name in
                            parameter_names]
        name_values = [f"{name}={value}" for name, value in
                       zip(parameter_names, parameter_values)]
        instantiation = f"{name}({', '.join(name_values)})"
        attributes = ", ".join([f"X.{param}" for param in dist._parameters])
        X = dist_family(**dict(zip(parameter_names, parameter_values)))
    else:
        instantiation = f"{name}()"
        X = dist

    # Arguments at which the distribution functions are evaluated in the
    # example: the (rounded) quantiles at probabilities `p` and `2*p`.
    p = 0.32
    x = round(X.icdf(p), 2)
    # `y` is only interpolated in the continuous-only section below.
    y = round(X.icdf(2 * p), 2) # noqa: F841

    example = f"""
    To use the distribution class, it must be instantiated using keyword
    parameters corresponding with one of the accepted parameterizations.
    >>> import numpy as np
    >>> import matplotlib.pyplot as plt
    >>> from scipy import stats
    >>> from scipy.stats import {name}
    >>> X = {instantiation}
    For convenience, the ``plot`` method can be used to visualize the density
    and other functions of the distribution.
    >>> X.plot()
    >>> plt.show()
    The support of the underlying distribution is available using the ``support``
    method.
    >>> X.support()
    {X.support()}
    """

    # `attributes` is only defined when the family has parameters.
    if n_parameters:
        example += f"""
        The numerical values of parameters associated with all parameterizations
        are available as attributes.
        >>> {attributes}
        {tuple(X._parameters.values())}
        """

    example += f"""
    To evaluate the probability density/mass function of the underlying distribution
    at argument ``x={x}``:
    >>> x = {x}
    >>> X.pdf(x), X.pmf(x)
    {X.pdf(x), X.pmf(x)}
    The cumulative distribution function, its complement, and the logarithm
    of these functions are evaluated similarly.
    >>> np.allclose(np.exp(X.logccdf(x)), 1 - X.cdf(x))
    True
    """

    # When two-arg CDF is implemented for DiscreteDistribution, consider removing
    # the special-casing here.
    if issubclass(dist_family, ContinuousDistribution):
        example_continuous = f"""
        The inverse of these functions with respect to the argument ``x`` is also
        available.
        >>> logp = np.log(1 - X.ccdf(x))
        >>> np.allclose(X.ilogcdf(logp), x)
        True
        Note that distribution functions and their logarithms also have two-argument
        versions for working with the probability mass between two arguments. The
        result tends to be more accurate than the naive implementation because it avoids
        subtractive cancellation.
        >>> y = {y}
        >>> np.allclose(X.ccdf(x, y), 1 - (X.cdf(y) - X.cdf(x)))
        True
        """
        example += example_continuous

    example += f"""
    There are methods for computing measures of central tendency,
    dispersion, higher moments, and entropy.
    >>> X.mean(), X.median(), X.mode()
    {X.mean(), X.median(), X.mode()}
    >>> X.variance(), X.standard_deviation()
    {X.variance(), X.standard_deviation()}
    >>> X.skewness(), X.kurtosis()
    {X.skewness(), X.kurtosis()}
    >>> np.allclose(X.moment(order=6, kind='standardized'),
    ... X.moment(order=6, kind='central') / X.variance()**3)
    True
    """

    # When logentropy is implemented for DiscreteDistribution, remove special-casing
    if issubclass(dist_family, ContinuousDistribution):
        example += """
        >>> np.allclose(np.exp(X.logentropy()), X.entropy())
        True
        """
    else:
        example += f"""
        >>> X.entropy()
        {X.entropy()}
        """

    example += f"""
    Pseudo-random samples can be drawn from
    the underlying distribution using ``sample``.
    >>> X.sample(shape=(4,))
    {repr(X.sample(shape=(4,)))} # may vary
    """
    # remove the indentation due to use of block quote within function;
    # eliminate blank first line
    example = "\n".join([line.lstrip() for line in example.split("\n")][1:])
    return example
  1277. class UnivariateDistribution(_ProbabilityDistribution):
  1278. r""" Class that represents a continuous statistical distribution.
  1279. Parameters
  1280. ----------
  1281. tol : positive float, optional
  1282. The desired relative tolerance of calculations. Left unspecified,
  1283. calculations may be faster; when provided, calculations may be
  1284. more likely to meet the desired accuracy.
  1285. validation_policy : {None, "skip_all"}
  1286. Specifies the level of input validation to perform. Left unspecified,
  1287. input validation is performed to ensure appropriate behavior in edge
  1288. case (e.g. parameters out of domain, argument outside of distribution
  1289. support, etc.) and improve consistency of output dtype, shape, etc.
  1290. Pass ``'skip_all'`` to avoid the computational overhead of these
  1291. checks when rough edges are acceptable.
  1292. cache_policy : {None, "no_cache"}
  1293. Specifies the extent to which intermediate results are cached. Left
  1294. unspecified, intermediate results of some calculations (e.g. distribution
  1295. support, moments, etc.) are cached to improve performance of future
  1296. calculations. Pass ``'no_cache'`` to reduce memory reserved by the class
  1297. instance.
  1298. Attributes
  1299. ----------
  1300. All parameters are available as attributes.
  1301. Methods
  1302. -------
  1303. support
  1304. plot
  1305. sample
  1306. moment
  1307. mean
  1308. median
  1309. mode
  1310. variance
  1311. standard_deviation
  1312. skewness
  1313. kurtosis
  1314. pdf
  1315. logpdf
  1316. cdf
  1317. icdf
  1318. ccdf
  1319. iccdf
  1320. logcdf
  1321. ilogcdf
  1322. logccdf
  1323. ilogccdf
  1324. entropy
  1325. logentropy
  1326. See Also
  1327. --------
  1328. :ref:`rv_infrastructure` : Tutorial
  1329. Notes
  1330. -----
  1331. The following abbreviations are used throughout the documentation.
  1332. - PDF: probability density function
  1333. - CDF: cumulative distribution function
  1334. - CCDF: complementary CDF
  1335. - entropy: differential entropy
  1336. - log-*F*: logarithm of *F* (e.g. log-CDF)
  1337. - inverse *F*: inverse function of *F* (e.g. inverse CDF)
  1338. The API documentation is written to describe the API, not to serve as
  1339. a statistical reference. Effort is made to be correct at the level
  1340. required to use the functionality, not to be mathematically rigorous.
  1341. For example, continuity and differentiability may be implicitly assumed.
  1342. For precise mathematical definitions, consult your preferred mathematical
  1343. text.
  1344. """
  1345. __array_priority__ = 1
  1346. _parameterizations = [] # type: ignore[var-annotated]
  1347. ### Initialization
  1348. def __init__(self, *, tol=_null, validation_policy=None, cache_policy=None,
  1349. **parameters):
  1350. self.tol = tol
  1351. self.validation_policy = validation_policy
  1352. self.cache_policy = cache_policy
  1353. self._not_implemented = (
  1354. f"`{self.__class__.__name__}` does not provide an accurate "
  1355. "implementation of the required method. Consider leaving "
  1356. "`method` and `tol` unspecified to use another implementation."
  1357. )
  1358. self._original_parameters = {}
  1359. # We may want to override the `__init__` method with parameters so
  1360. # IDEs can suggest parameter names. If there are multiple parameterizations,
  1361. # we'll need the default values of parameters to be None; this will
  1362. # filter out the parameters that were not actually specified by the user.
  1363. parameters = {key: val for key, val in
  1364. sorted(parameters.items()) if val is not None}
  1365. self._update_parameters(**parameters)
    def _update_parameters(self, *, validation_policy=None, **params):
        r""" Update the numerical values of distribution parameters.

        The values in `params` are merged into the parameters previously
        stored on the instance, the merged set is (unless validation is
        skipped) matched against a parameterization, broadcast and validated,
        the caches are reset, and a read-only property is defined on the
        class for every parameter name that is not already a class attribute.

        Parameters
        ----------
        **params : array_like
            Desired numerical values of the distribution parameters. Any or all
            of the parameters initially used to instantiate the distribution
            may be modified. Parameters used in alternative parameterizations
            are not accepted.
        validation_policy : str
            If this argument (or, when it is not given, the instance's
            ``validation_policy``) equals ``_SKIP_ALL``, the parameters are
            only passed through `_process_parameters`; parameterization
            identification, broadcasting and validation are skipped.

        Raises
        ------
        ValueError
            If parameters are provided to a family that defines no
            parameterizations.

        """
        # `parameters` and `original_parameters` refer to the same dict
        # here, so the update below is reflected in both names.
        parameters = original_parameters = self._original_parameters.copy()
        parameters.update(**params)
        parameterization = None
        # Defaults used when validation is skipped or the family accepts no
        # parameters: nothing invalid, scalar shape, float64.
        self._invalid = np.asarray(False)
        self._any_invalid = False
        self._shape = tuple()
        self._ndim = 0
        self._size = 1
        self._dtype = np.float64

        if (validation_policy or self.validation_policy) == _SKIP_ALL:
            parameters = self._process_parameters(**parameters)
        elif not len(self._parameterizations):
            if parameters:
                message = (f"The `{self.__class__.__name__}` distribution "
                           "family does not accept parameters, but parameters "
                           f"`{set(parameters)}` were provided.")
                raise ValueError(message)
        else:
            # This is default behavior, which re-runs all parameter validations
            # even when only a single parameter is modified. For many
            # distributions, the domain of a parameter doesn't depend on other
            # parameters, so parameters could safely be modified without
            # re-validating all other parameters. To handle these cases more
            # efficiently, we could allow the developer to override this
            # behavior.

            # Currently the user can only update the original parameterization.
            # Even though that parameterization is already known,
            # `_identify_parameterization` is called to produce a nice error
            # message if the user passes other values. To be a little more
            # efficient, we could detect whether the values passed are
            # consistent with the original parameterization rather than finding
            # it from scratch. However, we might want other parameterizations
            # to be accepted, which would require other changes, so I didn't
            # optimize this.

            parameterization = self._identify_parameterization(parameters)
            parameters, shape, size, ndim = self._broadcast(parameters)
            parameters, invalid, any_invalid, dtype = (
                self._validate(parameterization, parameters))
            parameters = self._process_parameters(**parameters)

            self._invalid = invalid
            self._any_invalid = any_invalid
            self._shape = shape
            self._size = size
            self._ndim = ndim
            self._dtype = dtype

        # Cached values were computed for the previous parameters.
        self.reset_cache()
        self._parameters = parameters
        self._parameterization = parameterization
        self._original_parameters = original_parameters
        for name in self._parameters.keys():
            # Make parameters properties of the class; return values from the instance
            if hasattr(self.__class__, name):
                continue
            # `name_=name` binds the current name as a default argument so
            # each property looks up its own parameter; the property returns
            # a copy of the stored value.
            setattr(self.__class__, name, property(lambda self_, name_=name:
                                                   self_._parameters[name_].copy()[()]))
  1433. def reset_cache(self):
  1434. r""" Clear all cached values.
  1435. To improve the speed of some calculations, the distribution's support
  1436. and moments are cached.
  1437. This function is called automatically whenever the distribution
  1438. parameters are updated.
  1439. """
  1440. # We could offer finer control over what is cleared.
  1441. # For simplicity, these will still exist even if cache_policy is
  1442. # NO_CACHE; they just won't be populated. This allows caching to be
  1443. # turned on and off easily.
  1444. self._moment_raw_cache = {}
  1445. self._moment_central_cache = {}
  1446. self._moment_standardized_cache = {}
  1447. self._support_cache = None
  1448. self._method_cache = {}
  1449. self._constant_cache = None
  1450. def _identify_parameterization(self, parameters):
  1451. # Determine whether a `parameters` dictionary matches is consistent
  1452. # with one of the parameterizations of the distribution. If so,
  1453. # return that parameterization object; if not, raise an error.
  1454. #
  1455. # I've come back to this a few times wanting to avoid this explicit
  1456. # loop. I've considered several possibilities, but they've all been a
  1457. # little unusual. For example, we could override `_eq_` so we can
  1458. # use _parameterizations.index() to retrieve the parameterization,
  1459. # or the user could put the parameterizations in a dictionary so we
  1460. # could look them up with a key (e.g. frozenset of parameter names).
  1461. # I haven't been sure enough of these approaches to implement them.
  1462. parameter_names_set = set(parameters)
  1463. for parameterization in self._parameterizations:
  1464. if parameterization.matches(parameter_names_set):
  1465. break
  1466. else:
  1467. if not parameter_names_set:
  1468. message = (f"The `{self.__class__.__name__}` distribution "
  1469. "family requires parameters, but none were "
  1470. "provided.")
  1471. else:
  1472. parameter_names = self._get_parameter_str(parameters)
  1473. message = (f"The provided parameters `{parameter_names}` "
  1474. "do not match a supported parameterization of the "
  1475. f"`{self.__class__.__name__}` distribution family.")
  1476. raise ValueError(message)
  1477. return parameterization
  1478. def _broadcast(self, parameters):
  1479. # Broadcast the distribution parameters to the same shape. If the
  1480. # arrays are not broadcastable, raise a meaningful error.
  1481. #
  1482. # We always make sure that the parameters *are* the same shape
  1483. # and not just broadcastable. Users can access parameters as
  1484. # attributes, and I think they should see the arrays as the same shape.
  1485. # More importantly, arrays should be the same shape before logical
  1486. # indexing operations, which are needed in infrastructure code when
  1487. # there are invalid parameters, and may be needed in
  1488. # distribution-specific code. We don't want developers to need to
  1489. # broadcast in implementation functions.
  1490. # It's much faster to check whether broadcasting is necessary than to
  1491. # broadcast when it's not necessary.
  1492. parameter_vals = [np.asarray(parameter)
  1493. for parameter in parameters.values()]
  1494. parameter_shapes = set(parameter.shape for parameter in parameter_vals)
  1495. if len(parameter_shapes) == 1:
  1496. return (parameters, parameter_vals[0].shape,
  1497. parameter_vals[0].size, parameter_vals[0].ndim)
  1498. try:
  1499. parameter_vals = np.broadcast_arrays(*parameter_vals)
  1500. except ValueError as e:
  1501. parameter_names = self._get_parameter_str(parameters)
  1502. message = (f"The parameters `{parameter_names}` provided to the "
  1503. f"`{self.__class__.__name__}` distribution family "
  1504. "cannot be broadcast to the same shape.")
  1505. raise ValueError(message) from e
  1506. return (dict(zip(parameters.keys(), parameter_vals)),
  1507. parameter_vals[0].shape,
  1508. parameter_vals[0].size,
  1509. parameter_vals[0].ndim)
  1510. def _validate(self, parameterization, parameters):
  1511. # Broadcasts distribution parameter arrays and converts them to a
  1512. # consistent dtype. Replaces invalid parameters with `np.nan`.
  1513. # Returns the validated parameters, a boolean mask indicated *which*
  1514. # elements are invalid, a boolean scalar indicating whether *any*
  1515. # are invalid (to skip special treatments if none are invalid), and
  1516. # the common dtype.
  1517. valid, dtype = parameterization.validation(parameters)
  1518. invalid = ~valid
  1519. any_invalid = invalid if invalid.shape == () else np.any(invalid)
  1520. # If necessary, make the arrays contiguous and replace invalid with NaN
  1521. if any_invalid:
  1522. for parameter_name in parameters:
  1523. parameters[parameter_name] = np.copy(
  1524. parameters[parameter_name])
  1525. parameters[parameter_name][invalid] = np.nan
  1526. return parameters, invalid, any_invalid, dtype
  1527. def _process_parameters(self, **params):
  1528. r""" Process and cache distribution parameters for reuse.
  1529. This is intended to be overridden by subclasses. It allows distribution
  1530. authors to pre-process parameters for re-use. For instance, when a user
  1531. parameterizes a LogUniform distribution with `a` and `b`, it makes
  1532. sense to calculate `log(a)` and `log(b)` because these values will be
  1533. used in almost all distribution methods. The dictionary returned by
  1534. this method is passed to all private methods that calculate functions
  1535. of the distribution.
  1536. """
  1537. return params
  1538. def _get_parameter_str(self, parameters):
  1539. # Get a string representation of the parameters like "{a, b, c}".
  1540. return f"{{{', '.join(parameters.keys())}}}"
  1541. def _copy_parameterization(self):
  1542. self._parameterizations = self._parameterizations.copy()
  1543. for i in range(len(self._parameterizations)):
  1544. self._parameterizations[i] = self._parameterizations[i].copy()
  1545. ### Attributes
  1546. # `tol` attribute is just notional right now. See Question 4 above.
  @property
  def tol(self):
      r"""positive float:
      The desired relative tolerance of calculations. When left unspecified,
      calculations may be faster; when provided, calculations may be more
      likely to meet the desired accuracy.
      """
      return self._tol

  @tol.setter
  def tol(self, tol):
      # A value that `_isnull` reports as null is stored untouched.
      if _isnull(tol):
          self._tol = tol
          return
      tol = np.asarray(tol)
      # `tol > 0` is False for NaN, so NaN is rejected along with
      # non-positive values; non-floating dtypes are rejected too.
      acceptable = (tol.shape == () and tol > 0
                    and np.issubdtype(tol.dtype, np.floating))
      if not acceptable:
          message = (f"Attribute `tol` of `{self.__class__.__name__}` must "
                     "be a positive float, if specified.")
          raise ValueError(message)
      # Store a scalar rather than a 0-d array.
      self._tol = tol[()]
  @property
  def cache_policy(self):
      r"""{None, "no_cache"}:
      Specifies the extent to which intermediate results are cached. Left
      unspecified, intermediate results of some calculations (e.g.
      distribution support, moments, etc.) are cached to improve the
      performance of future calculations. Pass ``'no_cache'`` to reduce the
      memory reserved by the class instance.
      """
      return self._cache_policy

  @cache_policy.setter
  def cache_policy(self, cache_policy):
      # Anything other than None is normalized to a lower-case string first.
      if cache_policy is not None:
          cache_policy = str(cache_policy).lower()
      cache_policies = {None, 'no_cache'}
      if cache_policy in cache_policies:
          self._cache_policy = cache_policy
          return
      message = (f"Attribute `cache_policy` of `{self.__class__.__name__}` "
                 f"must be one of {cache_policies}, if specified.")
      raise ValueError(message)
  @property
  def validation_policy(self):
      r"""{None, "skip_all"}:
      Specifies the level of input validation to perform. Left unspecified,
      input validation is performed to ensure appropriate behavior in edge
      cases (e.g. parameters out of domain, argument outside of distribution
      support, etc.) and to improve consistency of output dtype, shape, etc.
      Use ``'skip_all'`` to avoid the computational overhead of these
      checks when rough edges are acceptable.
      """
      return self._validation_policy

  @validation_policy.setter
  def validation_policy(self, validation_policy):
      # Anything other than None is normalized to a lower-case string first.
      if validation_policy is not None:
          validation_policy = str(validation_policy).lower()
      iv_policies = {None, 'skip_all'}
      if validation_policy in iv_policies:
          self._validation_policy = validation_policy
          return
      message = (f"Attribute `validation_policy` of `{self.__class__.__name__}` "
                 f"must be one of {iv_policies}, if specified.")
      raise ValueError(message)
  1607. ### Other magic methods
  1608. def __repr__(self):
  1609. r""" Returns a string representation of the distribution.
  1610. Includes the name of the distribution family, the names of the
  1611. parameters and the `repr` of each of their values.
  1612. """
  1613. class_name = self.__class__.__name__
  1614. parameters = list(self._original_parameters.items())
  1615. info = []
  1616. with np.printoptions(threshold=10):
  1617. str_parameters = [f"{symbol}={repr(value)}" for symbol, value in parameters]
  1618. str_parameters = f"{', '.join(str_parameters)}"
  1619. info.append(str_parameters)
  1620. return f"{class_name}({', '.join(info)})"
  1621. def __str__(self):
  1622. class_name = self.__class__.__name__
  1623. parameters = list(self._original_parameters.items())
  1624. info = []
  1625. with np.printoptions(threshold=10):
  1626. str_parameters = [f"{symbol}={str(value)}" for symbol, value in parameters]
  1627. str_parameters = f"{', '.join(str_parameters)}"
  1628. info.append(str_parameters)
  1629. return f"{class_name}({', '.join(info)})"
  def __add__(self, loc):
      # `X + loc`: wrap this distribution in a `ShiftedScaledDistribution`
      # with the given `loc`.
      shifted = ShiftedScaledDistribution(self, loc=loc)
      return shifted
  def __sub__(self, loc):
      # `X - loc`: wrap this distribution in a `ShiftedScaledDistribution`
      # whose `loc` is the negated argument.
      negated_loc = -loc
      return ShiftedScaledDistribution(self, loc=negated_loc)
  def __mul__(self, scale):
      # `X * scale`: wrap this distribution in a `ShiftedScaledDistribution`
      # with the given `scale`.
      scaled = ShiftedScaledDistribution(self, scale=scale)
      return scaled
  def __truediv__(self, scale):
      # `X / scale`: wrap this distribution in a `ShiftedScaledDistribution`
      # whose `scale` is the reciprocal of the argument.
      reciprocal = 1/scale
      return ShiftedScaledDistribution(self, scale=reciprocal)
  1638. def __pow__(self, other):
  1639. if not np.isscalar(other) or other <= 0 or other != int(other):
  1640. message = ("Raising a random variable to the power of an argument is only "
  1641. "implemented when the argument is a positive integer.")
  1642. raise NotImplementedError(message)
  1643. # Fill in repr_pattern with the repr of self before taking abs.
  1644. # Avoids having unnecessary abs in the repr.
  1645. with np.printoptions(threshold=10):
  1646. repr_pattern = f"({repr(self)})**{repr(other)}"
  1647. str_pattern = f"({str(self)})**{str(other)}"
  1648. X = abs(self) if other % 2 == 0 else self
  1649. funcs = dict(g=lambda u: u**other, repr_pattern=repr_pattern,
  1650. str_pattern=str_pattern,
  1651. h=lambda u: np.sign(u) * np.abs(u)**(1 / other),
  1652. dh=lambda u: 1/other * np.abs(u)**(1/other - 1))
  1653. return MonotonicTransformedDistribution(X, **funcs, increasing=True)
  1654. def __radd__(self, other):
  1655. return self.__add__(other)
  1656. def __rsub__(self, other):
  1657. return self.__neg__().__add__(other)
  1658. def __rmul__(self, other):
  1659. return self.__mul__(other)
  1660. def __rtruediv__(self, other):
  1661. a, b = self.support()
  1662. with np.printoptions(threshold=10):
  1663. funcs = dict(g=lambda u: 1 / u,
  1664. repr_pattern=f"{repr(other)}/({repr(self)})",
  1665. str_pattern=f"{str(other)}/({str(self)})",
  1666. h=lambda u: 1 / u, dh=lambda u: 1 / u ** 2)
  1667. if np.all(a >= 0) or np.all(b <= 0):
  1668. out = MonotonicTransformedDistribution(self, **funcs, increasing=False)
  1669. else:
  1670. message = ("Division by a random variable is only implemented "
  1671. "when the support is either non-negative or non-positive.")
  1672. raise NotImplementedError(message)
  1673. if np.all(other == 1):
  1674. return out
  1675. else:
  1676. return out * other
  1677. def __rpow__(self, other):
  1678. with np.printoptions(threshold=10):
  1679. funcs = dict(g=lambda u: other**u,
  1680. h=lambda u: np.log(u) / np.log(other),
  1681. dh=lambda u: 1 / np.abs(u * np.log(other)),
  1682. repr_pattern=f"{repr(other)}**({repr(self)})",
  1683. str_pattern=f"{str(other)}**({str(self)})",)
  1684. if not np.isscalar(other) or other <= 0 or other == 1:
  1685. message = ("Raising an argument to the power of a random variable is only "
  1686. "implemented when the argument is a positive scalar other than "
  1687. "1.")
  1688. raise NotImplementedError(message)
  1689. if other > 1:
  1690. return MonotonicTransformedDistribution(self, **funcs, increasing=True)
  1691. else:
  1692. return MonotonicTransformedDistribution(self, **funcs, increasing=False)
  1693. def __neg__(self):
  1694. return self * -1
  def __abs__(self):
      # `abs(X)`: wrap this distribution in a `FoldedDistribution`.
      folded = FoldedDistribution(self)
      return folded
  1697. ### Utilities
  1698. ## Input validation
  def _validate_order_kind(self, order, kind, kinds):
      # Validates the `order` and `kind` arguments of `moment`.
      # Unlike other integer-validating functions in SciPy, this one is
      # quite flexible about what is allowed as an integer, and it raises a
      # distribution-specific error message to make the source of the
      # error easy to identify.
      # Returns `order` (converted to the distribution dtype unless
      # validation is skipped); raises `ValueError` on invalid input.
      if self.validation_policy == _SKIP_ALL:
          return order

      family = self.__class__.__name__
      order = np.asarray(order, dtype=self._dtype)[()]
      order_message = (f"Argument `order` of `{family}.moment` "
                       "must be a finite, positive integer.")
      try:
          rounded = round(order.item())
      # If this fails for any reason (e.g. it's an array, it's infinite)
      # it's not a valid `order`.
      except Exception as e:
          raise ValueError(order_message) from e

      # NOTE(review): an order of 0 passes this check even though the
      # message says "positive".
      if rounded < 0 or rounded != order:
          raise ValueError(order_message)

      kind_message = (f"Argument `kind` of `{family}.moment` "
                      f"must be one of {set(kinds)}.")
      if kind.lower() not in kinds:
          raise ValueError(kind_message)

      return order
  1722. def _preserve_type(self, x):
  1723. x = np.asarray(x)
  1724. if x.dtype != self._dtype:
  1725. x = x.astype(self._dtype)
  1726. return x[()]
  1727. ## Testing
  1728. @classmethod
  1729. def _draw(cls, sizes=None, rng=None, i_parameterization=None,
  1730. proportions=None):
  1731. r""" Draw a specific (fully-defined) distribution from the family.
  1732. See _Parameterization.draw for documentation details.
  1733. """
  1734. rng = np.random.default_rng(rng)
  1735. if len(cls._parameterizations) == 0:
  1736. return cls()
  1737. if i_parameterization is None:
  1738. n = cls._num_parameterizations()
  1739. i_parameterization = rng.integers(0, max(0, n - 1), endpoint=True)
  1740. parameterization = cls._parameterizations[i_parameterization]
  1741. parameters = parameterization.draw(sizes, rng, proportions=proportions,
  1742. region='typical')
  1743. return cls(**parameters)
  1744. @classmethod
  1745. def _num_parameterizations(cls):
  1746. # Returns the number of parameterizations accepted by the family.
  1747. return len(cls._parameterizations)
  1748. @classmethod
  1749. def _num_parameters(cls, i_parameterization=0):
  1750. # Returns the number of parameters used in the specified
  1751. # parameterization.
  1752. return (0 if not cls._num_parameterizations()
  1753. else len(cls._parameterizations[i_parameterization]))
  1754. ## Algorithms
  1755. def _quadrature(self, integrand, limits=None, args=None,
  1756. params=None, log=False):
  1757. # Performs numerical integration of an integrand between limits.
  1758. # Much of this should be added to `_tanhsinh`.
  1759. a, b = self._support(**params) if limits is None else limits
  1760. a, b = np.broadcast_arrays(a, b)
  1761. if not a.size:
  1762. # maybe need to figure out result type from a, b
  1763. return np.empty(a.shape, dtype=self._dtype)
  1764. args = [] if args is None else args
  1765. params = {} if params is None else params
  1766. f, args = _kwargs2args(integrand, args=args, kwargs=params)
  1767. args = np.broadcast_arrays(*args)
  1768. # If we know the median or mean, consider breaking up the interval
  1769. rtol = None if _isnull(self.tol) else self.tol
  1770. # For now, we ignore the status, but I want to return the error
  1771. # estimate - see question 5 at the top.
  1772. if isinstance(self, ContinuousDistribution):
  1773. res = _tanhsinh(f, a, b, args=args, log=log, rtol=rtol)
  1774. return res.integral
  1775. else:
  1776. res = nsum(f, a, b, args=args, log=log, tolerances=dict(rtol=rtol)).sum
  1777. res = np.asarray(res)
  1778. # The result should be nan when parameters are nan, so need to special
  1779. # case this.
  1780. cond = np.isnan(params.popitem()[1]) if params else np.True_
  1781. cond = np.broadcast_to(cond, a.shape)
  1782. res[(a > b)] = -np.inf if log else 0 # fix in nsum?
  1783. res[cond] = np.nan
  1784. return res[()]
  1785. def _solve_bounded(self, f, p, *, bounds=None, params=None, xatol=None):
  1786. # Finds the argument of a function that produces the desired output.
  1787. # Much of this should be added to _bracket_root / _chandrupatla.
  1788. xmin, xmax = self._support(**params) if bounds is None else bounds
  1789. params = {} if params is None else params
  1790. p, xmin, xmax = np.broadcast_arrays(p, xmin, xmax)
  1791. if not p.size:
  1792. # might need to figure out result type based on p
  1793. res = _RichResult()
  1794. empty = np.empty(p.shape, dtype=self._dtype)
  1795. res.xl, res.x, res.xr = empty, empty, empty
  1796. res.fl, res.fr = empty, empty
  1797. def f2(x, _p, **kwargs): # named `_p` to avoid conflict with shape `p`
  1798. return f(x, **kwargs) - _p
  1799. f3, args = _kwargs2args(f2, args=[p], kwargs=params)
  1800. # If we know the median or mean, should use it
  1801. # Any operations between 0d array and a scalar produces a scalar, so...
  1802. shape = xmin.shape
  1803. xmin, xmax = np.atleast_1d(xmin, xmax)
  1804. xl0, xr0 = _guess_bracket(xmin, xmax)
  1805. xmin = xmin.reshape(shape)
  1806. xmax = xmax.reshape(shape)
  1807. xl0 = xl0.reshape(shape)
  1808. xr0 = xr0.reshape(shape)
  1809. res = _bracket_root(f3, xl0=xl0, xr0=xr0, xmin=xmin, xmax=xmax, args=args)
  1810. # For now, we ignore the status, but I want to use the bracket width
  1811. # as an error estimate - see question 5 at the top.
  1812. xrtol = None if _isnull(self.tol) else self.tol
  1813. xatol = None if xatol is None else xatol
  1814. tolerances = dict(xrtol=xrtol, xatol=xatol, fatol=0, frtol=0)
  1815. return _chandrupatla(f3, a=res.xl, b=res.xr, args=args, **tolerances)
  1816. ## Other
  def _overrides(self, method_name):
      # Determines whether a class overrides a specified method.
      # Returns True if the attribute named `method_name` found on this
      # object's class is *not* the same object as the one found on
      # `UnivariateDistribution` (i.e. the method has been overridden);
      # otherwise returns False. A name missing from both resolves to None
      # on each side and therefore yields False.

      # Sometimes we use `_overrides` to check whether a certain method is
      # overridden and, if so, call it. This raises the question of why we
      # don't do the more obvious thing: restructure so that if the private
      # method is overridden, Python will call it instead of the inherited
      # version automatically. The short answer is that there are multiple
      # ways a user might wish to evaluate a method, and simply overriding
      # the method with a formula is not always the best option.
      # For more complete discussion of the considerations, see:
      # https://github.com/scipy/scipy/pull/21050#discussion_r1707798901
      own_implementation = getattr(self.__class__, method_name, None)
      base_implementation = getattr(UnivariateDistribution, method_name, None)
      return own_implementation is not base_implementation
  1832. ### Distribution properties
  1833. # The following "distribution properties" are exposed via a public method
  1834. # that accepts only options (not distribution parameters or quantile/
  1835. # percentile argument).
  1836. # support
  1837. # logentropy, entropy,
  1838. # median, mode, mean,
  1839. # variance, standard_deviation
  1840. # skewness, kurtosis
  1841. # Common options are:
  1842. # method - a string that indicates which method should be used to compute
  1843. # the quantity (e.g. a formula or numerical integration).
  1844. # Input/output validation is provided by the `_set_invalid_nan_property`
  1845. # decorator. These are the methods meant to be called by users.
  1846. #
  1847. # Each public method calls a private "dispatch" method that
  1848. # determines which "method" (strategy for calculating the desired quantity)
  1849. # to use by default and, via the `@_dispatch` decorator, calls the
  1850. # method and computes the result.
  1851. # Dispatch methods always accept:
  1852. # method - as passed from the public method
  1853. # params - a dictionary of distribution shape parameters passed by
  1854. # the public method.
  1855. # Dispatch methods accept `params` rather than relying on the state of the
  1856. # object because iterative algorithms like `_tanhsinh` and `_chandrupatla`
  1857. # need their callable to follow a strict elementwise protocol: each element
  1858. # of the output is determined solely by the values of the inputs at the
  1859. # corresponding location. The public methods do not satisfy this protocol
  1860. # because they do not accept the parameters as arguments, producing an
  1861. # output that generally has a different shape than that of the input. Also,
  1862. # by calling "dispatch" methods rather than the public methods, the
  1863. # iterative algorithms avoid the overhead of input validation.
  1864. #
  1865. # Each dispatch method can designate the responsibility of computing
  1866. # the required value to any of several "implementation" methods. These
  1867. # methods accept only `**params`, the parameter dictionary passed from
  1868. # the public method via the dispatch method. We separate the implementation
  1869. # methods from the dispatch methods for the sake of simplicity (via
  1870. # compartmentalization) and to allow subclasses to override certain
  1871. # implementation methods (typically only the "formula" methods). The names
  1872. # of implementation methods are combinations of the public method name and
  1873. # the name of the "method" (strategy for calculating the desired quantity)
  1874. # string. (In fact, the name of the implementation method is calculated
  1875. # from these two strings in the `_dispatch` decorator.) Common method
  1876. # strings are:
  1877. # formula - distribution-specific analytical expressions to be implemented
  1878. # by subclasses.
  1879. # log/exp - Compute the log of a number and then exponentiate it or vice
  1880. # versa.
  1881. # quadrature - Compute the value via numerical integration.
  1882. #
  1883. # The default method (strategy) is determined based on what implementation
  1884. # methods are available and the error tolerance of the user. Typically,
  1885. # a formula is always used if available. We fall back to "log/exp" if a
  1886. # formula for the logarithm or exponential of the quantity is available,
  1887. # and we use quadrature otherwise.
  def support(self):
      # Returns the pair of support endpoints `(a, b)`, each broadcast to
      # the distribution shape, with NaN wherever the parameters are invalid.
      #
      # If this were a `cached_property`, we couldn't update the value
      # when the distribution parameters change.
      # Caching is important, though, because calls to _support take a few
      # microseconds even when `a` and `b` are already the same shape.
      cached = self._support_cache
      if cached is not None:
          return cached

      lower, upper = self._support(**self._parameters)
      if lower.shape != self._shape:
          lower = np.broadcast_to(lower, self._shape)
      if upper.shape != self._shape:
          upper = np.broadcast_to(upper, self._shape)

      if self._any_invalid:
          # Copy before writing NaN into the invalid positions.
          lower, upper = np.asarray(lower).copy(), np.asarray(upper).copy()
          lower[self._invalid] = np.nan
          upper[self._invalid] = np.nan
          lower, upper = lower[()], upper[()]

      endpoints = (lower, upper)

      # Remember the result unless caching has been disabled.
      if self.cache_policy != _NO_CACHE:
          self._support_cache = endpoints

      return endpoints
  1908. def _support(self, **params):
  1909. # Computes the support given distribution parameters
  1910. a, b = self._variable.domain.get_numerical_endpoints(params)
  1911. if len(params):
  1912. # the parameters should all be of the same dtype and shape at this point
  1913. vals = list(params.values())
  1914. shape = vals[0].shape
  1915. a = np.broadcast_to(a, shape) if a.shape != shape else a
  1916. b = np.broadcast_to(b, shape) if b.shape != shape else b
  1917. return self._preserve_type(a), self._preserve_type(b)
  @_set_invalid_nan_property
  def logentropy(self, *, method=None):
      # Public entry point: dispatch with the stored parameters and add
      # `0j` to the result.
      res = self._logentropy_dispatch(method=method, **self._parameters)
      return res + 0j
  @_dispatch
  def _logentropy_dispatch(self, method=None, **params):
      # Chooses the default implementation for the log-entropy: the formula
      # if the subclass provides one, otherwise the log of the entropy
      # formula (with a quadrature fallback), otherwise quadrature.
      if self._overrides('_logentropy_formula'):
          return self._logentropy_formula
      if self._overrides('_entropy_formula'):
          return self._logentropy_logexp_safe
      return self._logentropy_quadrature
  1930. def _logentropy_formula(self, **params):
  1931. raise NotImplementedError(self._not_implemented)
  def _logentropy_logexp(self, **params):
      # Complex log of the entropy, passed through `_log_real_standardize`.
      entropy = self._entropy_dispatch(**params)
      log_entropy = np.log(entropy+0j)
      return _log_real_standardize(log_entropy)
  1935. def _logentropy_logexp_safe(self, **params):
  1936. out = self._logentropy_logexp(**params)
  1937. mask = np.isinf(out.real)
  1938. if np.any(mask):
  1939. params_mask = {key:val[mask] for key, val in params.items()}
  1940. out = np.asarray(out)
  1941. out[mask] = self._logentropy_quadrature(**params_mask)
  1942. return out[()]
  def _logentropy_quadrature(self, **params):
      # Log-entropy by log-space quadrature of `logpxf + log(logpxf)`; `pi*1j`
      # is added to the result before standardizing.
      def logintegrand(x, **params):
          log_density = self._logpxf_dispatch(x, **params)
          return log_density + np.log(0j+log_density)

      log_integral = self._quadrature(logintegrand, params=params, log=True)
      return _log_real_standardize(log_integral + np.pi*1j)
  @_set_invalid_nan_property
  def entropy(self, *, method=None):
      # Public entry point: dispatch with the stored parameters.
      res = self._entropy_dispatch(method=method, **self._parameters)
      return res
  @_dispatch
  def _entropy_dispatch(self, method=None, **params):
      # Chooses the default implementation for the entropy: the formula if
      # the subclass provides one, otherwise the exponential of the
      # log-entropy formula, otherwise quadrature.
      if self._overrides('_entropy_formula'):
          return self._entropy_formula
      if self._overrides('_logentropy_formula'):
          return self._entropy_logexp
      return self._entropy_quadrature
  1961. def _entropy_formula(self, **params):
  1962. raise NotImplementedError(self._not_implemented)
  1963. def _entropy_logexp(self, **params):
  1964. return np.real(np.exp(self._logentropy_dispatch(**params)))
  1965. def _entropy_quadrature(self, **params):
  1966. def integrand(x, **params):
  1967. pxf = self._pxf_dispatch(x, **params)
  1968. logpxf = self._logpxf_dispatch(x, **params)
  1969. temp = np.asarray(pxf)
  1970. i = (pxf != 0) # 0 * inf -> nan; should be 0
  1971. temp[i] = pxf[i]*logpxf[i]
  1972. return temp
  1973. return -self._quadrature(integrand, params=params)
  @_set_invalid_nan_property
  def median(self, *, method=None):
      # Public entry point: dispatch with the stored parameters.
      res = self._median_dispatch(method=method, **self._parameters)
      return res
  @_dispatch
  def _median_dispatch(self, method=None, **params):
      # Chooses the default implementation for the median: the formula if
      # the subclass provides one, otherwise the inverse CDF at 0.5.
      if self._overrides('_median_formula'):
          return self._median_formula
      return self._median_icdf
  1984. def _median_formula(self, **params):
  1985. raise NotImplementedError(self._not_implemented)
  1986. def _median_icdf(self, **params):
  1987. return self._icdf_dispatch(np.asarray(0.5, dtype=self._dtype), **params)
  @_set_invalid_nan_property
  def mode(self, *, method=None):
      # Public entry point: dispatch with the stored parameters.
      res = self._mode_dispatch(method=method, **self._parameters)
      return res
  @_dispatch
  def _mode_dispatch(self, method=None, **params):
      # Chooses the default implementation for the mode: the formula if the
      # subclass provides one, otherwise numerical optimization.
      # We could add a method that looks for a critical point with
      # differentiation and the root finder
      if self._overrides('_mode_formula'):
          return self._mode_formula
      return self._mode_optimization
  2000. def _mode_formula(self, **params):
  2001. raise NotImplementedError(self._not_implemented)
  2002. def _mode_optimization(self, xatol=None, **params):
  2003. if not self._size:
  2004. return np.empty(self._shape, dtype=self._dtype)
  2005. a, b = self._support(**params)
  2006. m = self._median_dispatch(**params)
  2007. f, args = _kwargs2args(lambda x, **params: -self._pxf_dispatch(x, **params),
  2008. args=(), kwargs=params)
  2009. res_b = _bracket_minimum(f, m, xmin=a, xmax=b, args=args)
  2010. res = _chandrupatla_minimize(f, res_b.xl, res_b.xm, res_b.xr,
  2011. args=args, xatol=xatol)
  2012. mode = np.asarray(res.x)
  2013. mode_at_boundary = res_b.status == -1
  2014. mode_at_left = mode_at_boundary & (res_b.fl <= res_b.fm)
  2015. mode_at_right = mode_at_boundary & (res_b.fr < res_b.fm)
  2016. mode[mode_at_left] = a[mode_at_left]
  2017. mode[mode_at_right] = b[mode_at_right]
  2018. return mode[()]
  def mean(self, *, method=None):
      # The mean is the first raw moment.
      first_raw_moment = self.moment(1, kind='raw', method=method)
      return first_raw_moment
  def variance(self, *, method=None):
      # The variance is the second central moment.
      second_central_moment = self.moment(2, kind='central', method=method)
      return second_central_moment
  def standard_deviation(self, *, method=None):
      # Square root of the variance computed with the same `method`.
      variance = self.variance(method=method)
      return np.sqrt(variance)
  def skewness(self, *, method=None):
      # The skewness is the third standardized moment.
      third_standardized_moment = self.moment(3, kind='standardized',
                                              method=method)
      return third_standardized_moment
  def kurtosis(self, *, method=None, convention='non-excess'):
      # Fourth standardized moment; with the 'excess' convention (matched
      # case-insensitively), 3 is subtracted from it.
      conventions = {'non-excess', 'excess'}
      message = (f'Parameter `convention` of `{self.__class__.__name__}.kurtosis` '
                 f"must be one of {conventions}.")
      convention = convention.lower()
      if convention not in conventions:
          raise ValueError(message)
      fourth_standardized_moment = self.moment(4, kind='standardized',
                                               method=method)
      if convention == 'excess':
          return fourth_standardized_moment - 3
      return fourth_standardized_moment
  2036. ### Distribution functions
  2037. # The following functions related to the distribution PDF and CDF are
  2038. # exposed via a public method that accepts one positional argument - the
  2039. # quantile - and keyword options (but not distribution parameters).
  2040. # logpdf, pdf
  2041. # logcdf, cdf
  2042. # logccdf, ccdf
  2043. # The `logcdf` and `cdf` functions can also be called with two positional
  2044. # arguments - lower and upper quantiles - and they return the probability
  2045. # mass (integral of the PDF) between them. The 2-arg versions of `logccdf`
  2046. # and `ccdf` return the complement of this quantity.
  2047. # All the (1-arg) cumulative distribution functions have inverse
  2048. # functions, which accept one positional argument - the percentile.
  2049. # ilogcdf, icdf
  2050. # ilogccdf, iccdf
  2051. # Common keyword options include:
  2052. # method - a string that indicates which method should be used to compute
  2053. # the quantity (e.g. a formula or numerical integration).
  2054. # Tolerance options should be added.
  2055. # Input/output validation is provided by the `_set_invalid_nan`
  2056. # decorator. These are the methods meant to be called by users.
  2057. #
  2058. # Each public method calls a private "dispatch" method that
  2059. # determines which "method" (strategy for calculating the desired quantity)
  2060. # to use by default and, via the `@_dispatch` decorator, calls the
  2061. # method and computes the result.
  2062. # Each dispatch method can designate the responsibility of computing
  2063. # the required value to any of several "implementation" methods. These
  2064. # methods accept only `**params`, the parameter dictionary passed from
  2065. # the public method via the dispatch method.
  2066. # See the note accompanying the "Distribution properties" section above
  2067. # for more information.
  2068. ## Probability Density/Mass Functions
  @_set_invalid_nan
  def logpdf(self, x, /, *, method=None):
      # Public entry point: dispatch with the stored parameters.
      res = self._logpdf_dispatch(x, method=method, **self._parameters)
      return res
  @_dispatch
  def _logpdf_dispatch(self, x, *, method=None, **params):
      # Chooses the default implementation for the log-PDF: the formula if
      # the subclass provides one; otherwise, only when no tolerance is
      # set, the log of the PDF. In the remaining case `method` is returned
      # as received.
      if self._overrides('_logpdf_formula'):
          return self._logpdf_formula
      if _isnull(self.tol):  # ensure that developers override _logpdf
          return self._logpdf_logexp
      return method
  2079. def _logpdf_formula(self, x, **params):
  2080. raise NotImplementedError(self._not_implemented)
  2081. def _logpdf_logexp(self, x, **params):
  2082. return np.log(self._pdf_dispatch(x, **params))
  @_set_invalid_nan
  def pdf(self, x, /, *, method=None):
      # Public entry point: dispatch with the stored parameters.
      res = self._pdf_dispatch(x, method=method, **self._parameters)
      return res
  @_dispatch
  def _pdf_dispatch(self, x, *, method=None, **params):
      # Chooses the default implementation for the PDF: the formula if the
      # subclass provides one, otherwise the exponential of the log-PDF.
      if self._overrides('_pdf_formula'):
          return self._pdf_formula
      return self._pdf_logexp
  2093. def _pdf_formula(self, x, **params):
  2094. raise NotImplementedError(self._not_implemented)
  2095. def _pdf_logexp(self, x, **params):
  2096. return np.exp(self._logpdf_dispatch(x, **params))
  @_set_invalid_nan
  def logpmf(self, x, /, *, method=None):
      # Public entry point: dispatch with the stored parameters.
      res = self._logpmf_dispatch(x, method=method, **self._parameters)
      return res
  @_dispatch
  def _logpmf_dispatch(self, x, *, method=None, **params):
      # Chooses the default implementation for the log-PMF: the formula if
      # the subclass provides one; otherwise, only when no tolerance is
      # set, the log of the PMF. In the remaining case `method` is returned
      # as received.
      if self._overrides('_logpmf_formula'):
          return self._logpmf_formula
      if _isnull(self.tol):  # ensure that developers override _logpmf
          return self._logpmf_logexp
      return method
  2107. def _logpmf_formula(self, x, **params):
  2108. raise NotImplementedError(self._not_implemented)
  2109. def _logpmf_logexp(self, x, **params):
  2110. with np.errstate(divide='ignore'):
  2111. return np.log(self._pmf_dispatch(x, **params))
  @_set_invalid_nan
  def pmf(self, x, /, *, method=None):
      # Public entry point: dispatch with the stored parameters.
      res = self._pmf_dispatch(x, method=method, **self._parameters)
      return res
  @_dispatch
  def _pmf_dispatch(self, x, *, method=None, **params):
      # Chooses the default implementation for the PMF: the formula if the
      # subclass provides one, otherwise the exponential of the log-PMF.
      if self._overrides('_pmf_formula'):
          return self._pmf_formula
      return self._pmf_logexp
  2122. def _pmf_formula(self, x, **params):
  2123. raise NotImplementedError(self._not_implemented)
  2124. def _pmf_logexp(self, x, **params):
  2125. return np.exp(self._logpmf_dispatch(x, **params))
  2126. ## Cumulative Distribution Functions
  def logcdf(self, x, y=None, /, *, method=None):
      # With two quantiles, use the two-argument implementation; with one,
      # the single-argument implementation.
      if y is not None:
          return self._logcdf2(x, y, method=method)
      return self._logcdf1(x, method=method)
  @_cdf2_input_validation
  def _logcdf2(self, x, y, *, method):
      # Two-argument log-CDF; the result is always returned with a complex
      # dtype (`0j` is added when the dispatch result is not complex).
      out = self._logcdf2_dispatch(x, y, method=method, **self._parameters)
      if np.issubdtype(out.dtype, np.complexfloating):
          return out
      return out + 0j
  @_dispatch
  def _logcdf2_dispatch(self, x, y, *, method=None, **params):
      # Chooses the default implementation for the two-argument log-CDF, in
      # order of preference: its own formula; subtraction in log-space when
      # a logcdf/logccdf formula exists; log of the two-argument CDF (with a
      # quadrature fallback) when a cdf/ccdf formula exists; quadrature.
      # dtype is complex if any x > y, else real
      # Should revisit this logic.
      if self._overrides('_logcdf2_formula'):
          return self._logcdf2_formula
      has_log_formula = (self._overrides('_logcdf_formula')
                         or self._overrides('_logccdf_formula'))
      if has_log_formula:
          return self._logcdf2_subtraction
      has_formula = (self._overrides('_cdf_formula')
                     or self._overrides('_ccdf_formula'))
      if has_formula:
          return self._logcdf2_logexp_safe
      return self._logcdf2_quadrature
  2151. def _logcdf2_formula(self, x, y, **params):
  2152. raise NotImplementedError(self._not_implemented)
  def _logcdf2_subtraction(self, x, y, **params):
      # Log of the probability mass between `x` and `y`, computed from the
      # log-CDF and log-CCDF at both points. Which pair is subtracted
      # depends on whether both log-CDF values, or both log-CCDF values,
      # fall below -1.
      swapped = x > y  # some results will be negative
      lo, hi = np.minimum(x, y), np.maximum(x, y)
      logcdf_lo = self._logcdf_dispatch(lo, **params)
      logcdf_hi = self._logcdf_dispatch(hi, **params)
      logccdf_lo = self._logccdf_dispatch(lo, **params)
      logccdf_hi = self._logccdf_dispatch(hi, **params)
      in_left_tail = (logcdf_lo < -1) & (logcdf_hi < -1)
      in_right_tail = (logccdf_lo < -1) & (logccdf_hi < -1)
      in_center = ~(in_left_tail | in_right_tail)
      # Default: difference of CDFs; replaced where another case applies.
      log_mass = _logexpxmexpy(logcdf_hi, logcdf_lo)
      log_mass[in_right_tail] = _logexpxmexpy(logccdf_lo,
                                              logccdf_hi)[in_right_tail]
      # Otherwise: one minus the mass outside the interval.
      log_tail = np.logaddexp(logcdf_lo, logccdf_hi)[in_center]
      log_mass[in_center] = _log1mexp(log_tail)
      # Where the arguments were swapped, add `pi*1j` (log of a negative).
      log_mass[swapped] += np.pi * 1j
      if np.any(swapped):
          return log_mass[()]
      return log_mass.real[()]
  2169. def _logcdf2_logexp(self, x, y, **params):
  2170. expres = self._cdf2_dispatch(x, y, **params)
  2171. expres = expres + 0j if np.any(x > y) else expres
  2172. return np.log(expres)
  2173. def _logcdf2_logexp_safe(self, x, y, **params):
  2174. out = self._logcdf2_logexp(x, y, **params)
  2175. mask = np.isinf(out.real)
  2176. if np.any(mask):
  2177. params_mask = {key: np.broadcast_to(val, mask.shape)[mask]
  2178. for key, val in params.items()}
  2179. out = np.asarray(out)
  2180. out[mask] = self._logcdf2_quadrature(x[mask], y[mask], **params_mask)
  2181. return out[()]
  2182. def _logcdf2_quadrature(self, x, y, **params):
  2183. logres = self._quadrature(self._logpxf_dispatch, limits=(x, y),
  2184. log=True, params=params)
  2185. return logres
  @_set_invalid_nan
  def _logcdf1(self, x, *, method=None):
      """One-argument log-CDF: forward `x` and the stored parameters to the dispatcher."""
      parameters = self._parameters
      return self._logcdf_dispatch(x, method=method, **parameters)
  @_dispatch
  def _logcdf_dispatch(self, x, *, method=None, **params):
      """Choose the default strategy for the log-CDF and return that method.

      Preference order: own formula, complement of the log-CCDF formula,
      log of the CDF formula (with fallback), then quadrature.
      """
      if self._overrides('_logcdf_formula'):
          return self._logcdf_formula
      if self._overrides('_logccdf_formula'):
          return self._logcdf_complement
      if self._overrides('_cdf_formula'):
          return self._logcdf_logexp_safe
      return self._logcdf_quadrature
  2200. def _logcdf_formula(self, x, **params):
  2201. raise NotImplementedError(self._not_implemented)
  def _logcdf_complement(self, x, **params):
      """Log-CDF obtained by applying `_log1mexp` to the log-CCDF."""
      logccdf = self._logccdf_dispatch(x, **params)
      return _log1mexp(logccdf)
  2204. def _logcdf_logexp(self, x, **params):
  2205. return np.log(self._cdf_dispatch(x, **params))
  2206. def _logcdf_logexp_safe(self, x, **params):
  2207. out = self._logcdf_logexp(x, **params)
  2208. mask = np.isinf(out)
  2209. if np.any(mask):
  2210. params_mask = {key:np.broadcast_to(val, mask.shape)[mask]
  2211. for key, val in params.items()}
  2212. out = np.asarray(out)
  2213. out[mask] = self._logcdf_quadrature(x[mask], **params_mask)
  2214. return out[()]
  2215. def _logcdf_quadrature(self, x, **params):
  2216. a, _ = self._support(**params)
  2217. return self._quadrature(self._logpxf_dispatch, limits=(a, x),
  2218. params=params, log=True)
  def cdf(self, x, y=None, /, *, method=None):
      """Cumulative distribution function.

      With only `x`, delegates to `_cdf1`; when `y` is also given,
      delegates to the two-argument `_cdf2`.
      """
      if y is not None:
          return self._cdf2(x, y, method=method)
      return self._cdf1(x, method=method)
  @_cdf2_input_validation
  def _cdf2(self, x, y, *, method):
      """Two-argument CDF: forward `x`, `y` and the stored parameters to the dispatcher."""
      parameters = self._parameters
      return self._cdf2_dispatch(x, y, method=method, **parameters)
  @_dispatch
  def _cdf2_dispatch(self, x, y, *, method=None, **params):
      """Choose the default strategy for the two-argument CDF and return it."""
      # Should revisit this logic.
      if self._overrides('_cdf2_formula'):
          return self._cdf2_formula
      has_log_formula = (self._overrides('_logcdf_formula')
                         or self._overrides('_logccdf_formula'))
      if has_log_formula:
          return self._cdf2_logexp
      if self._overrides('_cdf_formula') or self._overrides('_ccdf_formula'):
          return self._cdf2_subtraction_safe
      return self._cdf2_quadrature
  2240. def _cdf2_formula(self, x, y, **params):
  2241. raise NotImplementedError(self._not_implemented)
  2242. def _cdf2_logexp(self, x, y, **params):
  2243. return np.real(np.exp(self._logcdf2_dispatch(x, y, **params)))
  2244. def _cdf2_subtraction(self, x, y, **params):
  2245. # Improvements:
  2246. # Lazy evaluation of cdf/ccdf only where needed
  2247. # Stack x and y to reduce function calls?
  2248. cdf_x = self._cdf_dispatch(x, **params)
  2249. cdf_y = self._cdf_dispatch(y, **params)
  2250. ccdf_x = self._ccdf_dispatch(x, **params)
  2251. ccdf_y = self._ccdf_dispatch(y, **params)
  2252. i = (ccdf_x < 0.5) & (ccdf_y < 0.5)
  2253. return np.where(i, ccdf_x-ccdf_y, cdf_y-cdf_x)
  def _cdf2_subtraction_safe(self, x, y, **params):
      """`_cdf2_subtraction` with a quadrature fallback.

      Elements for which ``|tol * result|`` is below the floating-point
      spacing of the larger operand are recomputed with `_cdf2_quadrature`.
      `tol` is ``self.tol`` unless that is null, in which case the square
      root of machine epsilon for ``self._dtype`` is used.
      """
      cdf_at_x = self._cdf_dispatch(x, **params)
      cdf_at_y = self._cdf_dispatch(y, **params)
      ccdf_at_x = self._ccdf_dispatch(x, **params)
      ccdf_at_y = self._ccdf_dispatch(y, **params)
      use_ccdf = (ccdf_at_x < 0.5) & (ccdf_at_y < 0.5)
      out = np.where(use_ccdf, ccdf_at_x-ccdf_at_y, cdf_at_y-cdf_at_x)

      eps = np.finfo(self._dtype).eps
      if _isnull(self.tol):
          tol = np.sqrt(eps)
      else:
          tol = self.tol

      # Spacing of the larger of the two values that were subtracted.
      larger_cdf = np.maximum(cdf_at_x, cdf_at_y)
      larger_ccdf = np.maximum(ccdf_at_x, ccdf_at_y)
      operand_spacing = np.spacing(np.where(use_ccdf, larger_ccdf, larger_cdf))
      needs_fallback = np.abs(tol * out) < operand_spacing
      if not np.any(needs_fallback):
          return out[()]

      masked_params = {}
      for name, value in params.items():
          full = np.broadcast_to(value, needs_fallback.shape)
          masked_params[name] = full[needs_fallback]
      out = np.asarray(out)
      out[needs_fallback] = self._cdf2_quadrature(
          x[needs_fallback], y[needs_fallback], **masked_params)
      return out[()]
  2273. def _cdf2_quadrature(self, x, y, **params):
  2274. return self._quadrature(self._pxf_dispatch, limits=(x, y), params=params)
  @_set_invalid_nan
  def _cdf1(self, x, *, method):
      """One-argument CDF: forward `x` and the stored parameters to the dispatcher."""
      parameters = self._parameters
      return self._cdf_dispatch(x, method=method, **parameters)
  @_dispatch
  def _cdf_dispatch(self, x, *, method=None, **params):
      """Choose the default strategy for the CDF and return that method.

      Preference order: own formula, exponential of the log-CDF formula,
      complement of the CCDF formula (with fallback), then quadrature.
      """
      if self._overrides('_cdf_formula'):
          return self._cdf_formula
      if self._overrides('_logcdf_formula'):
          return self._cdf_logexp
      if self._overrides('_ccdf_formula'):
          return self._cdf_complement_safe
      return self._cdf_quadrature
  2289. def _cdf_formula(self, x, **params):
  2290. raise NotImplementedError(self._not_implemented)
  2291. def _cdf_logexp(self, x, **params):
  2292. return np.exp(self._logcdf_dispatch(x, **params))
  2293. def _cdf_complement(self, x, **params):
  2294. return 1 - self._ccdf_dispatch(x, **params)
  def _cdf_complement_safe(self, x, **params):
      """CDF as ``1 - ccdf`` with a quadrature fallback.

      Elements for which ``tol * result`` is below the floating-point
      spacing of the CCDF value are recomputed with `_cdf_quadrature`.
      `tol` is ``self.tol`` unless that is null, in which case the square
      root of machine epsilon for ``self._dtype`` is used.

      Parameters
      ----------
      x : array
          Points at which to evaluate the CDF.
      **params
          Distribution parameters, broadcastable against `x`.

      Returns
      -------
      out : array or scalar
          The CDF values; zero-dimensional results are unwrapped.
      """
      ccdf = self._ccdf_dispatch(x, **params)
      out = 1 - ccdf
      eps = np.finfo(self._dtype).eps
      tol = self.tol if not _isnull(self.tol) else np.sqrt(eps)
      mask = tol * out < np.spacing(ccdf)
      if np.any(mask):
          params_mask = {key: np.broadcast_to(val, mask.shape)[mask]
                         for key, val in params.items()}
          out = np.asarray(out)
          # The masked parameters must be forwarded by keyword:
          # `_cdf_quadrature` accepts only `x` positionally, so unpacking
          # the dict with a single `*` would pass its keys as extra
          # positional arguments and raise `TypeError`.
          out[mask] = self._cdf_quadrature(x[mask], **params_mask)
      return out[()]
  2307. def _cdf_quadrature(self, x, **params):
  2308. a, _ = self._support(**params)
  2309. return self._quadrature(self._pxf_dispatch, limits=(a, x),
  2310. params=params)
  def logccdf(self, x, y=None, /, *, method=None):
      """Log of the complementary cumulative distribution function.

      With only `x`, delegates to `_logccdf1`; when `y` is also given,
      delegates to the two-argument `_logccdf2`.
      """
      if y is not None:
          return self._logccdf2(x, y, method=method)
      return self._logccdf1(x, method=method)
  @_cdf2_input_validation
  def _logccdf2(self, x, y, *, method):
      """Two-argument log-CCDF: forward `x`, `y` and the stored parameters to the dispatcher."""
      parameters = self._parameters
      return self._logccdf2_dispatch(x, y, method=method, **parameters)
  @_dispatch
  def _logccdf2_dispatch(self, x, y, *, method=None, **params):
      """Choose the default strategy for the two-argument log-CCDF and return it."""
      # if _logccdf2_formula exists, we could use the complement
      # if _ccdf2_formula exists, we could use log/exp
      if self._overrides('_logccdf2_formula'):
          return self._logccdf2_formula
      return self._logccdf2_addition
  2328. def _logccdf2_formula(self, x, y, **params):
  2329. raise NotImplementedError(self._not_implemented)
  2330. def _logccdf2_addition(self, x, y, **params):
  2331. logcdf_x = self._logcdf_dispatch(x, **params)
  2332. logccdf_y = self._logccdf_dispatch(y, **params)
  2333. return special.logsumexp([logcdf_x, logccdf_y], axis=0)
  @_set_invalid_nan
  def _logccdf1(self, x, *, method=None):
      """One-argument log-CCDF: forward `x` and the stored parameters to the dispatcher."""
      parameters = self._parameters
      return self._logccdf_dispatch(x, method=method, **parameters)
  @_dispatch
  def _logccdf_dispatch(self, x, method=None, **params):
      """Choose the default strategy for the log-CCDF and return that method.

      Preference order: own formula, complement of the log-CDF formula,
      log of the CCDF formula (with fallback), then quadrature.
      """
      if self._overrides('_logccdf_formula'):
          return self._logccdf_formula
      if self._overrides('_logcdf_formula'):
          return self._logccdf_complement
      if self._overrides('_ccdf_formula'):
          return self._logccdf_logexp_safe
      return self._logccdf_quadrature
  2348. def _logccdf_formula(self, x, **params):
  2349. raise NotImplementedError(self._not_implemented)
  def _logccdf_complement(self, x, **params):
      """Log-CCDF obtained by applying `_log1mexp` to the log-CDF."""
      logcdf = self._logcdf_dispatch(x, **params)
      return _log1mexp(logcdf)
  2352. def _logccdf_logexp(self, x, **params):
  2353. return np.log(self._ccdf_dispatch(x, **params))
  2354. def _logccdf_logexp_safe(self, x, **params):
  2355. out = self._logccdf_logexp(x, **params)
  2356. mask = np.isinf(out)
  2357. if np.any(mask):
  2358. params_mask = {key: np.broadcast_to(val, mask.shape)[mask]
  2359. for key, val in params.items()}
  2360. out = np.asarray(out)
  2361. out[mask] = self._logccdf_quadrature(x[mask], **params_mask)
  2362. return out[()]
  2363. def _logccdf_quadrature(self, x, **params):
  2364. _, b = self._support(**params)
  2365. return self._quadrature(self._logpxf_dispatch, limits=(x, b),
  2366. params=params, log=True)
  def ccdf(self, x, y=None, /, *, method=None):
      """Complementary cumulative distribution function.

      With only `x`, delegates to `_ccdf1`; when `y` is also given,
      delegates to the two-argument `_ccdf2`.
      """
      if y is not None:
          return self._ccdf2(x, y, method=method)
      return self._ccdf1(x, method=method)
  @_cdf2_input_validation
  def _ccdf2(self, x, y, *, method):
      """Two-argument CCDF: forward `x`, `y` and the stored parameters to the dispatcher."""
      parameters = self._parameters
      return self._ccdf2_dispatch(x, y, method=method, **parameters)
  @_dispatch
  def _ccdf2_dispatch(self, x, y, *, method=None, **params):
      """Choose the default strategy for the two-argument CCDF and return it."""
      if self._overrides('_ccdf2_formula'):
          return self._ccdf2_formula
      return self._ccdf2_addition
  2382. def _ccdf2_formula(self, x, y, **params):
  2383. raise NotImplementedError(self._not_implemented)
  2384. def _ccdf2_addition(self, x, y, **params):
  2385. cdf_x = self._cdf_dispatch(x, **params)
  2386. ccdf_y = self._ccdf_dispatch(y, **params)
  2387. # even if x > y, cdf(x, y) + ccdf(x,y) sums to 1
  2388. return cdf_x + ccdf_y
  @_set_invalid_nan
  def _ccdf1(self, x, *, method):
      """One-argument CCDF: forward `x` and the stored parameters to the dispatcher."""
      parameters = self._parameters
      return self._ccdf_dispatch(x, method=method, **parameters)
  @_dispatch
  def _ccdf_dispatch(self, x, method=None, **params):
      """Choose the default strategy for the CCDF and return that method.

      Preference order: own formula, exponential of the log-CCDF formula,
      complement of the CDF formula (with fallback), then quadrature.
      """
      if self._overrides('_ccdf_formula'):
          return self._ccdf_formula
      if self._overrides('_logccdf_formula'):
          return self._ccdf_logexp
      if self._overrides('_cdf_formula'):
          return self._ccdf_complement_safe
      return self._ccdf_quadrature
  2403. def _ccdf_formula(self, x, **params):
  2404. raise NotImplementedError(self._not_implemented)
  2405. def _ccdf_logexp(self, x, **params):
  2406. return np.exp(self._logccdf_dispatch(x, **params))
  2407. def _ccdf_complement(self, x, **params):
  2408. return 1 - self._cdf_dispatch(x, **params)
  def _ccdf_complement_safe(self, x, **params):
      """CCDF as ``1 - cdf`` with a quadrature fallback.

      Elements for which ``tol * result`` is below the floating-point
      spacing of the CDF value are recomputed with `_ccdf_quadrature`.
      `tol` is ``self.tol`` unless that is null, in which case the square
      root of machine epsilon for ``self._dtype`` is used.
      """
      cdf = self._cdf_dispatch(x, **params)
      out = 1 - cdf

      eps = np.finfo(self._dtype).eps
      if _isnull(self.tol):
          tol = np.sqrt(eps)
      else:
          tol = self.tol

      needs_fallback = tol * out < np.spacing(cdf)
      if not np.any(needs_fallback):
          return out[()]

      masked_params = {}
      for name, value in params.items():
          full = np.broadcast_to(value, needs_fallback.shape)
          masked_params[name] = full[needs_fallback]
      out = np.asarray(out)
      out[needs_fallback] = self._ccdf_quadrature(x[needs_fallback],
                                                  **masked_params)
      return out[()]
  2421. def _ccdf_quadrature(self, x, **params):
  2422. _, b = self._support(**params)
  2423. return self._quadrature(self._pxf_dispatch, limits=(x, b),
  2424. params=params)
  2425. ## Inverse cumulative distribution functions
  @_set_invalid_nan
  def ilogcdf(self, logp, /, *, method=None):
      """Inverse of the log-CDF: forward `logp` and the stored parameters to the dispatcher."""
      parameters = self._parameters
      return self._ilogcdf_dispatch(logp, method=method, **parameters)
  @_dispatch
  def _ilogcdf_dispatch(self, x, method=None, **params):
      """Choose the default strategy for the inverse log-CDF and return it.

      Preference order: own formula, complement via the inverse log-CCDF
      formula, then numerical inversion.
      """
      if self._overrides('_ilogcdf_formula'):
          return self._ilogcdf_formula
      if self._overrides('_ilogccdf_formula'):
          return self._ilogcdf_complement
      return self._ilogcdf_inversion
  2438. def _ilogcdf_formula(self, x, **params):
  2439. raise NotImplementedError(self._not_implemented)
  def _ilogcdf_complement(self, x, **params):
      """Inverse log-CDF via the inverse log-CCDF evaluated at ``_log1mexp(x)``."""
      log_complement = _log1mexp(x)
      return self._ilogccdf_dispatch(log_complement, **params)
  2442. def _ilogcdf_inversion(self, x, **params):
  2443. return self._solve_bounded_continuous(self._logcdf_dispatch, x, params=params)
  @_set_invalid_nan
  def icdf(self, p, /, *, method=None):
      """Inverse of the CDF: forward `p` and the stored parameters to the dispatcher."""
      parameters = self._parameters
      return self._icdf_dispatch(p, method=method, **parameters)
  @_dispatch
  def _icdf_dispatch(self, x, method=None, **params):
      """Choose the default strategy for the inverse CDF and return it.

      Preference order: own formula, complement via the inverse CCDF
      formula (with fallback), then numerical inversion.
      """
      if self._overrides('_icdf_formula'):
          return self._icdf_formula
      if self._overrides('_iccdf_formula'):
          return self._icdf_complement_safe
      return self._icdf_inversion
  2456. def _icdf_formula(self, x, **params):
  2457. raise NotImplementedError(self._not_implemented)
  2458. def _icdf_complement(self, x, **params):
  2459. return self._iccdf_dispatch(1 - x, **params)
  def _icdf_complement_safe(self, x, **params):
      """Inverse CDF via the inverse CCDF, with an inversion fallback.

      Elements for which ``tol * x`` is below the floating-point spacing of
      ``1 - x`` are recomputed with `_icdf_inversion`. `tol` is ``self.tol``
      unless that is null, in which case the square root of machine epsilon
      for ``self._dtype`` is used.

      Parameters
      ----------
      x : array
          Probabilities at which to evaluate the inverse CDF.
      **params
          Distribution parameters, broadcastable against `x`.

      Returns
      -------
      out : array or scalar
          The inverse CDF values; zero-dimensional results are unwrapped.
      """
      out = self._icdf_complement(x, **params)
      eps = np.finfo(self._dtype).eps
      tol = self.tol if not _isnull(self.tol) else np.sqrt(eps)
      mask = tol * x < np.spacing(1 - x)
      if np.any(mask):
          params_mask = {key: np.broadcast_to(val, mask.shape)[mask]
                         for key, val in params.items()}
          out = np.asarray(out)
          # The masked parameters must be forwarded by keyword:
          # `_icdf_inversion` accepts only `x` positionally, so unpacking
          # the dict with a single `*` would pass its keys as extra
          # positional arguments and raise `TypeError`.
          out[mask] = self._icdf_inversion(x[mask], **params_mask)
      return out[()]
  2471. def _icdf_inversion(self, x, **params):
  2472. return self._solve_bounded_continuous(self._cdf_dispatch, x, params=params)
  @_set_invalid_nan
  def ilogccdf(self, logp, /, *, method=None):
      """Inverse of the log-CCDF: forward `logp` and the stored parameters to the dispatcher."""
      parameters = self._parameters
      return self._ilogccdf_dispatch(logp, method=method, **parameters)
  @_dispatch
  def _ilogccdf_dispatch(self, x, method=None, **params):
      """Choose the default strategy for the inverse log-CCDF and return it.

      Preference order: own formula, complement via the inverse log-CDF
      formula, then numerical inversion.
      """
      if self._overrides('_ilogccdf_formula'):
          return self._ilogccdf_formula
      if self._overrides('_ilogcdf_formula'):
          return self._ilogccdf_complement
      return self._ilogccdf_inversion
  2485. def _ilogccdf_formula(self, x, **params):
  2486. raise NotImplementedError(self._not_implemented)
  def _ilogccdf_complement(self, x, **params):
      """Inverse log-CCDF via the inverse log-CDF evaluated at ``_log1mexp(x)``."""
      log_complement = _log1mexp(x)
      return self._ilogcdf_dispatch(log_complement, **params)
  2489. def _ilogccdf_inversion(self, x, **params):
  2490. return self._solve_bounded_continuous(self._logccdf_dispatch, x, params=params)
  @_set_invalid_nan
  def iccdf(self, p, /, *, method=None):
      """Inverse of the CCDF: forward `p` and the stored parameters to the dispatcher."""
      parameters = self._parameters
      return self._iccdf_dispatch(p, method=method, **parameters)
  @_dispatch
  def _iccdf_dispatch(self, x, method=None, **params):
      """Choose the default strategy for the inverse CCDF and return it.

      Preference order: own formula, complement via the inverse CDF
      formula (with fallback), then numerical inversion.
      """
      if self._overrides('_iccdf_formula'):
          return self._iccdf_formula
      if self._overrides('_icdf_formula'):
          return self._iccdf_complement_safe
      return self._iccdf_inversion
  2503. def _iccdf_formula(self, x, **params):
  2504. raise NotImplementedError(self._not_implemented)
  2505. def _iccdf_complement(self, x, **params):
  2506. return self._icdf_dispatch(1 - x, **params)
  def _iccdf_complement_safe(self, x, **params):
      """Inverse CCDF via the inverse CDF, with an inversion fallback.

      Elements for which ``tol * x`` is below the floating-point spacing of
      ``1 - x`` are recomputed with `_iccdf_inversion`. `tol` is
      ``self.tol`` unless that is null, in which case the square root of
      machine epsilon for ``self._dtype`` is used.

      Parameters
      ----------
      x : array
          Probabilities at which to evaluate the inverse CCDF.
      **params
          Distribution parameters, broadcastable against `x`.

      Returns
      -------
      out : array or scalar
          The inverse CCDF values; zero-dimensional results are unwrapped.
      """
      out = self._iccdf_complement(x, **params)
      eps = np.finfo(self._dtype).eps
      tol = self.tol if not _isnull(self.tol) else np.sqrt(eps)
      mask = tol * x < np.spacing(1 - x)
      if np.any(mask):
          params_mask = {key: np.broadcast_to(val, mask.shape)[mask]
                         for key, val in params.items()}
          out = np.asarray(out)
          # The masked parameters must be forwarded by keyword:
          # `_iccdf_inversion` accepts only `x` positionally, so unpacking
          # the dict with a single `*` would pass its keys as extra
          # positional arguments and raise `TypeError`.
          out[mask] = self._iccdf_inversion(x[mask], **params_mask)
      return out[()]
  2518. def _iccdf_inversion(self, x, **params):
  2519. return self._solve_bounded_continuous(self._ccdf_dispatch, x, params=params)
  2520. ### Sampling Functions
  2521. # The following functions for drawing samples from the distribution are
  2522. # exposed via a public method that accepts one positional argument - the
  2523. # shape of the sample - and keyword options (but not distribution
  2524. # parameters).
  2525. # sample
  2526. # ~~qmc_sample~~ built into sample now
  2527. #
  2528. # Common keyword options include:
  2529. # method - a string that indicates which method should be used to compute
  2530. # the quantity (e.g. a formula or numerical integration).
  2531. # rng - the NumPy Generator/SciPy QMCEngine object to use for drawing numbers.
  2532. #
  2533. # Input/output validation is included in each function, since there is
  2534. # little code to be shared.
  2535. # These are the methods meant to be called by users.
  2536. #
  2537. # Each public method calls a private "dispatch" method that
  2538. # determines which "method" (strategy for calculating the desired quantity)
  2539. # to use by default and, via the `@_dispatch` decorator, calls the
  2540. # method and computes the result.
  2541. # Each dispatch method can designate the responsibility of sampling to any
  2542. # of several "implementation" methods. These methods accept only
  2543. # `**params`, the parameter dictionary passed from the public method via
  2544. # the "dispatch" method.
  2545. # See the note corresponding with the "Distribution Parameters" for more
  2546. # information.
  2547. # TODO:
  2548. # - should we accept a QRNG with `d != 1`?
  def sample(self, shape=(), *, method=None, rng=None):
      """Draw a sample from the distribution.

      Parameters
      ----------
      shape : int or iterable of int, optional
          Shape of the sample; the distribution's own shape is appended.
      method : optional
          Passed through to `_sample_dispatch`.
      rng : optional
          Used as is when it is a `qmc.QMCEngine`; otherwise it is passed to
          `numpy.random.default_rng`.

      Returns
      -------
      The dispatcher's result, cast to ``self._dtype`` without copying when
      the dtype already matches.
      """
      # needs output validation to ensure that developer returns correct
      # dtype and shape
      if np.iterable(shape):
          sample_shape = tuple(shape)
      else:
          sample_shape = (shape,)
      full_shape = sample_shape + self._shape

      if not isinstance(rng, qmc.QMCEngine):
          rng = np.random.default_rng(rng)

      draws = self._sample_dispatch(full_shape, method=method, rng=rng,
                                    **self._parameters)
      return draws.astype(self._dtype, copy=False)
  @_dispatch
  def _sample_dispatch(self, full_shape, *, method, rng, **params):
      """Choose the default sampling strategy and return that method.

      The sampling formula is used only when one is overridden and `rng` is
      not a `qmc.QMCEngine`; otherwise inverse transform sampling is used.
      """
      # make sure that tests catch if sample is 0d array
      use_formula = (self._overrides('_sample_formula')
                     and not isinstance(rng, qmc.QMCEngine))
      if use_formula:
          return self._sample_formula
      return self._sample_inverse_transform
  2566. def _sample_formula(self, full_shape, *, rng, **params):
  2567. raise NotImplementedError(self._not_implemented)
  2568. def _sample_inverse_transform(self, full_shape, *, rng, **params):
  2569. if isinstance(rng, qmc.QMCEngine):
  2570. uniform = self._qmc_uniform(full_shape, qrng=rng, **params)
  2571. else:
  2572. uniform = rng.random(size=full_shape, dtype=self._dtype)
  2573. return self._icdf_dispatch(uniform, **params)
  def _qmc_uniform(self, full_shape, *, qrng, **params):
      """Generate QMC uniform sample(s) on the unit interval with shape `full_shape`.

      If the sample shape is not ``()``, each slice along axis 0 is an
      independent low-discrepancy sequence.
      """
      n_sample_dims = len(full_shape) - len(self._shape)
      sample_shape = full_shape[:n_sample_dims]

      # Determine the number of independent sequences and the length of each.
      if sample_shape:
          n_low_discrepancy = sample_shape[0]
          n_independent = math.prod(full_shape[1:])
          sequence_shape = n_low_discrepancy
      else:
          n_low_discrepancy = 1
          n_independent = math.prod(full_shape)
          sequence_shape = ()

      # Each independent sequence needs a new QRNG of the same class with its
      # own RNG. (If scramble=False the separate RNGs are not really needed,
      # but there is no special code path for that.)
      child_rngs = _rng_spawn(qrng.rng, n_independent)
      engine_class = qrng.__class__
      engine_kwargs = {'d': 1, 'scramble': qrng.scramble,
                       'optimization': qrng._optimization}
      if isinstance(qrng, qmc.Sobol):
          engine_kwargs['bits'] = qrng.bits

      # Draw uniform low-discrepancy sequences scrambled with each RNG
      sequences = []
      for child_rng in child_rngs:
          engine = engine_class(seed=child_rng, **engine_kwargs)
          draws = engine.random(n_low_discrepancy)
          sequences.append(draws.reshape(sequence_shape)[()])

      # Reorder the axes and ensure that the shape is correct
      if sequences:
          stacked = np.moveaxis(np.stack(sequences), -1, 0)
      else:
          stacked = np.asarray([])
      return stacked.reshape(full_shape)
  2599. ### Moments
  2600. # The `moment` method accepts two positional arguments - the order and kind
  2601. # (raw, central, or standard) of the moment - and a keyword option:
  2602. # method - a string that indicates which method should be used to compute
  2603. # the quantity (e.g. a formula or numerical integration).
  2604. # Like the distribution properties, input/output validation is provided by
  2605. # the `_set_invalid_nan_property` decorator.
  2606. #
  2607. # Unlike most public methods above, `moment` dispatches to one of three
  2608. # private methods - one for each 'kind'. Like most *public* methods above,
  2609. # each of these private methods calls a private "dispatch" method that
  2610. # determines which "method" (strategy for calculating the desired quantity)
  2611. # to use. Also, each dispatch method can designate the responsibility
  2612. # of computing the moment to one of several "implementation" methods.
  2613. # Unlike the dispatch methods above, however, the `@_dispatch` decorator
  2614. # is not used, and both logic and method calls are included in the function
  2615. # itself.
  2616. # Instead of determining which method will be used based solely on the
  2617. # implementation methods available and calling only the corresponding
  2618. # implementation method, *all* the implementation methods are called
  2619. # in sequence until one returns the desired information. When an
  2620. # implementation method cannot provide the requested information, it
  2621. # returns the object None (which is distinct from arrays with NaNs or infs,
  2622. # which are valid values of moments).
  2623. # The reason for this approach is that although formulae for the first
  2624. # few moments of a distribution may be found, general formulae that work
  2625. # for all orders are not always easy to find. This approach allows the
  2626. # developer to write "formula" implementation functions that return the
  2627. # desired moment when it is available and None otherwise.
  2628. #
  2629. # Note that the first implementation method called is a cache. This is
  2630. # important because lower-order moments are often needed to compute
  2631. # higher moments from formulae, so we eliminate redundant calculations
  2632. # when moments of several orders are needed.
  2633. @cached_property
  2634. def _moment_methods(self):
  2635. return {'cache', 'formula', 'transform',
  2636. 'normalize', 'general', 'quadrature'}
  2637. @property
  2638. def _zero(self):
  2639. return self._constants()[0]
  2640. @property
  2641. def _one(self):
  2642. return self._constants()[1]
  2643. def _constants(self):
  2644. if self._constant_cache is not None:
  2645. return self._constant_cache
  2646. constants = self._preserve_type([0, 1])
  2647. if self.cache_policy != _NO_CACHE:
  2648. self._constant_cache = constants
  2649. return constants
  @_set_invalid_nan_property
  def moment(self, order=1, kind='raw', *, method=None):
      """Moment of the distribution.

      `kind` selects the raw, central, or standardized moment; `order` is
      validated with `_validate_order_kind` before the matching private
      method is called.
      """
      handlers = {'raw': self._moment_raw,
                  'central': self._moment_central,
                  'standardized': self._moment_standardized}
      order = self._validate_order_kind(order, kind, handlers)
      handler = handlers[kind]
      return handler(order, method=method)
  2658. def _moment_raw(self, order=1, *, method=None):
  2659. """Raw distribution moment about the origin."""
  2660. # Consider exposing the point about which moments are taken as an
  2661. # option. This is easy to support, since `_moment_transform_center`
  2662. # does all the work.
  2663. methods = self._moment_methods if method is None else {method}
  2664. return self._moment_raw_dispatch(order, methods=methods, **self._parameters)
  2665. def _moment_raw_dispatch(self, order, *, methods, **params):
  2666. moment = None
  2667. if 'cache' in methods:
  2668. moment = self._moment_raw_cache.get(order, None)
  2669. if moment is None and 'formula' in methods:
  2670. moment = self._moment_raw_formula(order, **params)
  2671. if moment is None and 'transform' in methods and order > 1:
  2672. moment = self._moment_raw_transform(order, **params)
  2673. if moment is None and 'general' in methods:
  2674. moment = self._moment_raw_general(order, **params)
  2675. if moment is None and 'quadrature' in methods:
  2676. moment = self._moment_from_pxf(order, center=self._zero, **params)
  2677. if moment is None and 'quadrature_icdf' in methods:
  2678. moment = self._moment_integrate_icdf(order, center=self._zero, **params)
  2679. if moment is not None and self.cache_policy != _NO_CACHE:
  2680. self._moment_raw_cache[order] = moment
  2681. return moment
  2682. def _moment_raw_formula(self, order, **params):
  2683. return None
  2684. def _moment_raw_transform(self, order, **params):
  2685. central_moments = []
  2686. for i in range(int(order) + 1):
  2687. methods = {'cache', 'formula', 'normalize', 'general'}
  2688. moment_i = self._moment_central_dispatch(order=i, methods=methods, **params)
  2689. if moment_i is None:
  2690. return None
  2691. central_moments.append(moment_i)
  2692. # Doesn't make sense to get the mean by "transform", since that's
  2693. # how we got here. Questionable whether 'quadrature' should be here.
  2694. mean_methods = {'cache', 'formula', 'quadrature'}
  2695. mean = self._moment_raw_dispatch(self._one, methods=mean_methods, **params)
  2696. if mean is None:
  2697. return None
  2698. moment = self._moment_transform_center(order, central_moments, mean, self._zero)
  2699. return moment
  2700. def _moment_raw_general(self, order, **params):
  2701. # This is the only general formula for a raw moment of a probability
  2702. # distribution
  2703. return self._one if order == 0 else None
  2704. def _moment_central(self, order=1, *, method=None):
  2705. """Distribution moment about the mean."""
  2706. methods = self._moment_methods if method is None else {method}
  2707. return self._moment_central_dispatch(order, methods=methods, **self._parameters)
  2708. def _moment_central_dispatch(self, order, *, methods, **params):
  2709. moment = None
  2710. if 'cache' in methods:
  2711. moment = self._moment_central_cache.get(order, None)
  2712. if moment is None and 'formula' in methods:
  2713. moment = self._moment_central_formula(order, **params)
  2714. if moment is None and 'transform' in methods:
  2715. moment = self._moment_central_transform(order, **params)
  2716. if moment is None and 'normalize' in methods and order > 2:
  2717. moment = self._moment_central_normalize(order, **params)
  2718. if moment is None and 'general' in methods:
  2719. moment = self._moment_central_general(order, **params)
  2720. if moment is None and 'quadrature' in methods:
  2721. mean = self._moment_raw_dispatch(self._one, **params,
  2722. methods=self._moment_methods)
  2723. moment = self._moment_from_pxf(order, center=mean, **params)
  2724. if moment is None and 'quadrature_icdf' in methods:
  2725. mean = self._moment_raw_dispatch(self._one, **params,
  2726. methods=self._moment_methods)
  2727. moment = self._moment_integrate_icdf(order, center=mean, **params)
  2728. if moment is not None and self.cache_policy != _NO_CACHE:
  2729. self._moment_central_cache[order] = moment
  2730. return moment
  2731. def _moment_central_formula(self, order, **params):
  2732. return None
  2733. def _moment_central_transform(self, order, **params):
  2734. raw_moments = []
  2735. for i in range(int(order) + 1):
  2736. methods = {'cache', 'formula', 'general'}
  2737. moment_i = self._moment_raw_dispatch(order=i, methods=methods, **params)
  2738. if moment_i is None:
  2739. return None
  2740. raw_moments.append(moment_i)
  2741. mean_methods = self._moment_methods
  2742. mean = self._moment_raw_dispatch(self._one, methods=mean_methods, **params)
  2743. moment = self._moment_transform_center(order, raw_moments, self._zero, mean)
  2744. return moment
  2745. def _moment_central_normalize(self, order, **params):
  2746. methods = {'cache', 'formula', 'general'}
  2747. standard_moment = self._moment_standardized_dispatch(order, **params,
  2748. methods=methods)
  2749. if standard_moment is None:
  2750. return None
  2751. var = self._moment_central_dispatch(2, methods=self._moment_methods, **params)
  2752. return standard_moment*var**(order/2)
  2753. def _moment_central_general(self, order, **params):
  2754. general_central_moments = {0: self._one, 1: self._zero}
  2755. return general_central_moments.get(order, None)
  2756. def _moment_standardized(self, order=1, *, method=None):
  2757. """Standardized distribution moment."""
  2758. methods = self._moment_methods if method is None else {method}
  2759. return self._moment_standardized_dispatch(order, methods=methods,
  2760. **self._parameters)
  2761. def _moment_standardized_dispatch(self, order, *, methods, **params):
  2762. moment = None
  2763. if 'cache' in methods:
  2764. moment = self._moment_standardized_cache.get(order, None)
  2765. if moment is None and 'formula' in methods:
  2766. moment = self._moment_standardized_formula(order, **params)
  2767. if moment is None and 'normalize' in methods:
  2768. moment = self._moment_standardized_normalize(order, False, **params)
  2769. if moment is None and 'general' in methods:
  2770. moment = self._moment_standardized_general(order, **params)
  2771. if moment is None and 'normalize' in methods:
  2772. moment = self._moment_standardized_normalize(order, True, **params)
  2773. if moment is not None and self.cache_policy != _NO_CACHE:
  2774. self._moment_standardized_cache[order] = moment
  2775. return moment
  2776. def _moment_standardized_formula(self, order, **params):
  2777. return None
  2778. def _moment_standardized_normalize(self, order, use_quadrature, **params):
  2779. methods = ({'quadrature'} if use_quadrature
  2780. else {'cache', 'formula', 'transform'})
  2781. central_moment = self._moment_central_dispatch(order, **params,
  2782. methods=methods)
  2783. if central_moment is None:
  2784. return None
  2785. var = self._moment_central_dispatch(2, methods=self._moment_methods,
  2786. **params)
  2787. return central_moment/var**(order/2)
  2788. def _moment_standardized_general(self, order, **params):
  2789. general_standard_moments = {0: self._one, 1: self._zero, 2: self._one}
  2790. return general_standard_moments.get(order, None)
  2791. def _moment_from_pxf(self, order, center, **params):
  2792. def integrand(x, order, center, **params):
  2793. pxf = self._pxf_dispatch(x, **params)
  2794. return pxf*(x-center)**order
  2795. return self._quadrature(integrand, args=(order, center), params=params)
  2796. def _moment_integrate_icdf(self, order, center, **params):
  2797. def integrand(x, order, center, **params):
  2798. x = self._icdf_dispatch(x, **params)
  2799. return (x-center)**order
  2800. return self._quadrature(integrand, limits=(0., 1.),
  2801. args=(order, center), params=params)
  2802. def _moment_transform_center(self, order, moment_as, a, b):
  2803. a, b, *moment_as = np.broadcast_arrays(a, b, *moment_as)
  2804. n = order
  2805. i = np.arange(n+1).reshape([-1]+[1]*a.ndim) # orthogonal to other axes
  2806. i = self._preserve_type(i)
  2807. n_choose_i = special.binom(n, i)
  2808. with np.errstate(invalid='ignore'): # can happen with infinite moment
  2809. moment_b = np.sum(n_choose_i*moment_as*(a-b)**(n-i), axis=0)
  2810. return moment_b
  2811. def _logmoment(self, order=1, *, logcenter=None, standardized=False):
  2812. # make this private until it is worked into moment
  2813. if logcenter is None or standardized is True:
  2814. logmean = self._logmoment_quad(self._one, -np.inf, **self._parameters)
  2815. else:
  2816. logmean = None
  2817. logcenter = logmean if logcenter is None else logcenter
  2818. res = self._logmoment_quad(order, logcenter, **self._parameters)
  2819. if standardized:
  2820. logvar = self._logmoment_quad(2, logmean, **self._parameters)
  2821. res = res - logvar * (order/2)
  2822. return res
  2823. def _logmoment_quad(self, order, logcenter, **params):
  2824. def logintegrand(x, order, logcenter, **params):
  2825. logpdf = self._logpxf_dispatch(x, **params)
  2826. return logpdf + order * _logexpxmexpy(np.log(x + 0j), logcenter)
  2827. ## if logx == logcenter, `_logexpxmexpy` returns (-inf + 0j)
  2828. ## multiplying by order produces (-inf + nan j) - bad
  2829. ## We're skipping logmoment tests, so we might don't need to fix
  2830. ## now, but if we ever do use run them, this might help:
  2831. # logx = np.log(x+0j)
  2832. # out = np.asarray(logpdf + order*_logexpxmexpy(logx, logcenter))
  2833. # i = (logx == logcenter)
  2834. # out[i] = logpdf[i]
  2835. # return out
  2836. return self._quadrature(logintegrand, args=(order, logcenter),
  2837. params=params, log=True)
  2838. ### Convenience
def plot(self, x='x', y=None, *, t=None, ax=None):
    r"""Plot a function of the distribution.

    Convenience function for quick visualization of the distribution
    underlying the random variable.

    Parameters
    ----------
    x, y : str, optional
        String indicating the quantities to be used as the abscissa and
        ordinate (horizontal and vertical coordinates), respectively.
        Defaults are ``'x'`` (the domain of the random variable) and either
        ``'pdf'`` (the probability density function) (continuous) or
        ``'pmf'`` (the probability mass function) (discrete).
        Valid values are:
        'x', 'pdf', 'pmf', 'cdf', 'ccdf', 'icdf', 'iccdf', 'logpdf', 'logpmf',
        'logcdf', 'logccdf', 'ilogcdf', 'ilogccdf'.
    t : 3-tuple of (str, float, float), optional
        Tuple indicating the limits within which the quantities are plotted.
        The default is ``('cdf', 0.0005, 0.9995)`` if the domain is infinite,
        indicating that the central 99.9% of the distribution is to be shown;
        otherwise, endpoints of the support are used where they are finite.
        Valid values are:
        'x', 'cdf', 'ccdf', 'icdf', 'iccdf', 'logcdf', 'logccdf',
        'ilogcdf', 'ilogccdf'.
    ax : `matplotlib.axes`, optional
        Axes on which to generate the plot. If not provided, use the
        current axes.

    Returns
    -------
    ax : `matplotlib.axes`
        Axes on which the plot was generated.
        The plot can be customized by manipulating this object.

    Raises
    ------
    ValueError
        If `x`, `y`, or the first element of `t` is not one of the valid
        names, if any distribution parameter is invalid, if the parameters
        have more than one dimension, or if the limits given by `t` do not
        map to finite quantiles.
    ModuleNotFoundError
        If `matplotlib` cannot be imported.

    Examples
    --------
    Instantiate a distribution with the desired parameters:

    >>> import numpy as np
    >>> import matplotlib.pyplot as plt
    >>> from scipy import stats
    >>> X = stats.Normal(mu=1., sigma=2.)

    Plot the PDF over the central 99.9% of the distribution.
    Compare against a histogram of a random sample.

    >>> ax = X.plot()
    >>> sample = X.sample(10000)
    >>> ax.hist(sample, density=True, bins=50, alpha=0.5)
    >>> plt.show()

    Plot ``logpdf(x)`` as a function of ``x`` in the left tail,
    where the log of the CDF is between -10 and ``np.log(0.5)``.

    >>> X.plot('x', 'logpdf', t=('logcdf', -10, np.log(0.5)))
    >>> plt.show()

    Plot the PDF of the normal distribution as a function of the
    CDF for various values of the scale parameter.

    >>> X = stats.Normal(mu=0., sigma=[0.5, 1., 2])
    >>> X.plot('cdf', 'pdf')
    >>> plt.show()

    """

    # Strategy: given t limits, get quantile limits. Form grid of
    # quantiles, compute requested x and y at quantiles, and plot.
    # Currently, the grid of quantiles is always linearly spaced.
    # Instead of always computing linearly-spaced quantiles, it
    # would be better to choose:
    # a) quantiles or probabilities
    # b) linearly or logarithmically spaced
    # based on the specified `t`.
    # TODO:
    # - smart spacing of points
    # - when the parameters of the distribution are an array,
    #   use the full range of abscissae for all curves

    discrete = isinstance(self, DiscreteDistribution)
    # Names whose values are quantiles vs. names whose values are
    # probabilities (or log-probabilities).
    t_is_quantile = {'x', 'icdf', 'iccdf', 'ilogcdf', 'ilogccdf'}
    t_is_probability = {'cdf', 'ccdf', 'logcdf', 'logccdf'}
    valid_t = t_is_quantile.union(t_is_probability)
    valid_xy = valid_t.union({'pdf', 'logpdf', 'pmf', 'logpmf'})
    y_default = 'pmf' if discrete else 'pdf'
    y = y_default if y is None else y

    ndim = self._ndim
    x_name, y_name = x, y
    t_name = 'cdf' if t is None else t[0]

    # Default limits: the full probability range [0, 1] on a side where the
    # support endpoint is finite for all parameters, otherwise trim 0.05%.
    a, b = self.support()
    tliml_default = 0 if np.all(np.isfinite(a)) else 0.0005
    tliml = tliml_default if t is None else t[1]
    tlimr_default = 1 if np.all(np.isfinite(b)) else 0.9995
    tlimr = tlimr_default if t is None else t[2]
    tlim = np.asarray([tliml, tlimr])
    # With array parameters, add a trailing axis (presumably so the limits
    # broadcast against the parameter axis).
    tlim = tlim[:, np.newaxis] if ndim else tlim

    # pdf/logpdf are not valid for `t` because we can't easily invert them
    # NOTE(review): the three messages below contain a stray, unmatched `"`
    # after `.plot`; left as-is here because callers may match on the text.
    message = (f'Argument `t` of `{self.__class__.__name__}.plot` "'
               f'must be one of {valid_t}')
    if t_name not in valid_t:
        raise ValueError(message)

    message = (f'Argument `x` of `{self.__class__.__name__}.plot` "'
               f'must be one of {valid_xy}')
    if x_name not in valid_xy:
        raise ValueError(message)

    message = (f'Argument `y` of `{self.__class__.__name__}.plot` "'
               f'must be one of {valid_xy}')
    if y_name not in valid_xy:
        raise ValueError(message)

    # This could just be a warning
    message = (f'`{self.__class__.__name__}.plot` was called on a random '
               'variable with at least one invalid shape parameters. When '
               'a parameter is invalid, no plot can be shown.')
    if self._any_invalid:
        raise ValueError(message)

    # We could automatically ravel, but do we want to? For now, raise.
    message = ("To use `plot`, distribution parameters must be "
               "scalars or arrays with one or fewer dimensions.")
    if ndim > 1:
        raise ValueError(message)

    # matplotlib is imported only here, after argument validation.
    try:
        import matplotlib.pyplot as plt  # noqa: F401, E402
    except ModuleNotFoundError as exc:
        message = ("`matplotlib` must be installed to use "
                   f"`{self.__class__.__name__}.plot`.")
        raise ModuleNotFoundError(message) from exc
    ax = plt.gca() if ax is None else ax

    # get quantile limits given t limits
    # (a probability-type `t` is inverted by calling the method named
    # 'i' + t_name, e.g. 'cdf' -> `icdf`)
    qlim = tlim if t_name in t_is_quantile else getattr(self, 'i'+t_name)(tlim)

    message = (f"`{self.__class__.__name__}.plot` received invalid input for `t`: "
               f"calling {'i'+t_name}({tlim}) produced {qlim}.")
    if not np.all(np.isfinite(qlim)):
        raise ValueError(message)

    # form quantile grid
    if discrete and x_name in t_is_quantile:
        # should probably aggregate for large ranges
        # unit-spaced grid covering the widest quantile range of any curve
        q = np.arange(np.min(qlim[0]), np.max(qlim[1]) + 1)
        q = q[:, np.newaxis] if ndim else q
    else:
        # 300 linearly spaced points between the quantile limits
        grid = np.linspace(0, 1, 300)
        grid = grid[:, np.newaxis] if ndim else grid
        q = qlim[0] + (qlim[1] - qlim[0]) * grid
        q = np.round(q) if discrete else q

    # compute requested x and y at quantile grid
    x = q if x_name in t_is_quantile else getattr(self, x_name)(q)
    y = q if y_name in t_is_quantile else getattr(self, y_name)(q)

    # make plot
    # Transpose so that iterating yields one (x, y) curve at a time.
    x, y = np.broadcast_arrays(x.T, np.atleast_2d(y.T))
    for xi, yi in zip(x, y):  # plot is vectorized, but bar/step don't seem to be
        if discrete and x_name in t_is_quantile and y_name == 'pmf':
            # should this just be a step plot, too?
            ax.bar(xi, yi, alpha=np.sqrt(1/y.shape[0]))  # alpha heuristic
        elif discrete and x_name in t_is_quantile:
            values = yi
            # `stairs` needs one more edge than values; close the last step
            edges = np.concatenate((xi, [xi[-1]+1]))
            ax.stairs(values, edges, baseline=None)
        else:
            ax.plot(xi, yi)
    ax.set_xlabel(f"${x_name}$")
    ax.set_ylabel(f"${y_name}$")
    ax.set_title(str(self))

    # only need a legend if distribution has parameters
    if len(self._parameters):
        label = []
        parameters = self._parameterization.parameters
        param_names = list(parameters)
        param_arrays = [np.atleast_1d(self._parameters[pname])
                        for pname in param_names]
        # one legend entry per tuple of parameter values
        for param_vals in zip(*param_arrays):
            assignments = [f"${parameters[name].symbol}$ = {val:.4g}"
                           for name, val in zip(param_names, param_vals)]
            label.append(", ".join(assignments))
        ax.legend(label)

    return ax
  3000. ### Fitting
  3001. # All methods above treat the distribution parameters as fixed, and the
  3002. # variable argument may be a quantile or probability. The fitting functions
  3003. # are fundamentally different because the quantiles (often observations)
  3004. # are considered to be fixed, and the distribution parameters are the
  3005. # variables. In a sense, they are like an inverse of the sampling
  3006. # functions.
  3007. #
  3008. # At first glance, it would seem ideal for `fit` to be a classmethod,
  3009. # called like `LogUniform.fit(sample=sample)`.
  3010. # I tried this. I insisted on it for a while. But if `fit` is a
  3011. # classmethod, it cannot call instance methods. If we want to support MLE,
  3012. # MPS, MoM, MoLM, then we end up with most of the distribution functions
  3013. # above needing to be classmethods, too. All state information, such as
  3014. # tolerances and the underlying distribution of `ShiftedScaledDistribution`
  3015. # and `OrderStatisticDistribution`, would need to be passed into all
  3016. # methods. And I'm not really sure how we would call `fit` as a
  3017. # classmethod of a transformed distribution - maybe
  3018. # ShiftedScaledDistribution.fit would accept the class of the
  3019. # shifted/scaled distribution as an argument?
  3020. #
  3021. # In any case, it was a conscious decision for the infrastructure to
  3022. # treat the parameters as "fixed" and the quantile/percentile arguments
  3023. # as "variable". There are a lot of advantages to this structure, and I
  3024. # don't think the fact that a few methods reverse the fixed and variable
# quantities should make us question that choice. It can still accommodate
  3026. # these methods reasonably efficiently.
  3027. class ContinuousDistribution(UnivariateDistribution):
  3028. def _overrides(self, method_name):
  3029. if method_name in {'_logpmf_formula', '_pmf_formula'}:
  3030. return True
  3031. return super()._overrides(method_name)
  3032. def _pmf_formula(self, x, **params):
  3033. return np.zeros_like(x)
  3034. def _logpmf_formula(self, x, **params):
  3035. return np.full_like(x, -np.inf)
  3036. def _pxf_dispatch(self, x, *, method=None, **params):
  3037. return self._pdf_dispatch(x, method=method, **params)
  3038. def _logpxf_dispatch(self, x, *, method=None, **params):
  3039. return self._logpdf_dispatch(x, method=method, **params)
  3040. def _solve_bounded_continuous(self, func, p, params, xatol=None):
  3041. return self._solve_bounded(func, p, params=params, xatol=xatol).x
  3042. class DiscreteDistribution(UnivariateDistribution):
  3043. def _overrides(self, method_name):
  3044. if method_name in {'_logpdf_formula', '_pdf_formula'}:
  3045. return True
  3046. return super()._overrides(method_name)
  3047. def _logpdf_formula(self, x, **params):
  3048. if params:
  3049. p = next(iter(params.values()))
  3050. nan_result = np.isnan(x) | np.isnan(p)
  3051. else:
  3052. nan_result = np.isnan(x)
  3053. return np.where(nan_result, np.nan, np.inf)
  3054. def _pdf_formula(self, x, **params):
  3055. if params:
  3056. p = next(iter(params.values()))
  3057. nan_result = np.isnan(x) | np.isnan(p)
  3058. else:
  3059. nan_result = np.isnan(x)
  3060. return np.where(nan_result, np.nan, np.inf)
  3061. def _pxf_dispatch(self, x, *, method=None, **params):
  3062. return self._pmf_dispatch(x, method=method, **params)
  3063. def _logpxf_dispatch(self, x, *, method=None, **params):
  3064. return self._logpmf_dispatch(x, method=method, **params)
  3065. def _cdf_quadrature(self, x, **params):
  3066. return super()._cdf_quadrature(np.floor(x), **params)
  3067. def _logcdf_quadrature(self, x, **params):
  3068. return super()._logcdf_quadrature(np.floor(x), **params)
  3069. def _ccdf_quadrature(self, x, **params):
  3070. return super()._ccdf_quadrature(np.floor(x + 1), **params)
  3071. def _logccdf_quadrature(self, x, **params):
  3072. return super()._logccdf_quadrature(np.floor(x + 1), **params)
  3073. def _cdf2(self, x, y, *, method):
  3074. raise NotImplementedError(
  3075. "Two argument cdf functions are currently only supported for "
  3076. "continuous distributions.")
  3077. def _ccdf2(self, x, y, *, method):
  3078. raise NotImplementedError(
  3079. "Two argument cdf functions are currently only supported for "
  3080. "continuous distributions.")
  3081. def _logcdf2(self, x, y, *, method):
  3082. raise NotImplementedError(
  3083. "Two argument cdf functions are currently only supported for "
  3084. "continuous distributions.")
  3085. def _logccdf2(self, x, y, *, method):
  3086. raise NotImplementedError(
  3087. "Two argument cdf functions are currently only supported for "
  3088. "continuous distributions.")
  3089. def _solve_bounded_discrete(self, func, p, params, comp):
  3090. res = self._solve_bounded(func, p, params=params, xatol=0.9)
  3091. x = np.asarray(np.floor(res.xr))
  3092. # if _chandrupatla finds exact inverse, the bracket may not have been reduced
  3093. # enough for `np.floor(res.x)` to be the appropriate value of `x`.
  3094. mask = res.fun == 0
  3095. x[mask] = np.floor(res.x[mask])
  3096. xmin, xmax = self._support(**params)
  3097. p, xmin, xmax = np.broadcast_arrays(p, xmin, xmax)
  3098. mask = comp(func(xmin, **params), p)
  3099. x[mask] = xmin[mask]
  3100. return x
  3101. def _base_discrete_inversion(self, p, func, comp, /, **params):
  3102. # For discrete distributions, icdf(p) is defined as the minimum n
  3103. # such that cdf(n) >= p. iccdf(p) is defined as the minimum n such
  3104. # that ccdf(n) <= p, or equivalently as iccdf(p) = icdf(1 - p).
  3105. # First try to find where cdf(x) == p for the continuous extension of the
  3106. # cdf. res.xl and res.xr will be a bracket for this root. The parameter
  3107. # xatol in solve_bounded controls the bracket width. We thus know that
  3108. # know cdf(res.xr) >= p, cdf(res.xl) <= p, and |res.xr - res.xl| <= 0.9.
  3109. # This means the minimum integer n such that cdf(n) >= p is either floor(x)
  3110. # or floor(x) + 1.
  3111. x = self._solve_bounded_discrete(func, p, params=params, comp=comp)
  3112. # comp should be <= for ccdf, >= for cdf.
  3113. f = func(x, **params)
  3114. res = np.where(comp(f, p), x, x + 1.0)
  3115. # xr is a bracket endpoint, and will usually be a finite value even when
  3116. # the computed result should be nan. We need to explicitly handle this
  3117. # case.
  3118. res[np.isnan(f) | np.isnan(p)] = np.nan
  3119. return res[()]
  3120. def _icdf_inversion(self, x, **params):
  3121. return self._base_discrete_inversion(x, self._cdf_dispatch,
  3122. np.greater_equal, **params)
  3123. def _ilogcdf_inversion(self, x, **params):
  3124. return self._base_discrete_inversion(x, self._logcdf_dispatch,
  3125. np.greater_equal, **params)
  3126. def _iccdf_inversion(self, x, **params):
  3127. return self._base_discrete_inversion(x, self._ccdf_dispatch,
  3128. np.less_equal, **params)
  3129. def _ilogccdf_inversion(self, x, **params):
  3130. return self._base_discrete_inversion(x, self._logccdf_dispatch,
  3131. np.less_equal, **params)
  3132. def _mode_optimization(self, **params):
  3133. # If `x` is the true mode of a unimodal continuous function, we can find
  3134. # the mode among integers by rounding in each direction and checking
  3135. # which is better. If the difference between `x` and the nearest integer
  3136. # is less than `xatol`, the computed value of `x` may end up on the wrong
  3137. # side of the nearest integer. Setting `xatol=0.5` guarantees that at most
  3138. # three integers need to be checked, the two nearest integers, ``floor(x)``
  3139. # and ``round(x)`` and the nearest integer other than these.
  3140. x = super()._mode_optimization(xatol=0.5, **params)
  3141. low, high = self.support()
  3142. xl, xr = np.floor(x), np.ceil(x)
  3143. nearest = np.round(x)
  3144. # Clip to stay within support. There will be redundant calculation
  3145. # when clipping since `xo` will be one of `xl` or `xr`, but let's
  3146. # keep the implementation simple for now.
  3147. xo = np.clip(nearest + np.copysign(1, nearest - x), low, high)
  3148. x = np.stack([xl, xo, xr])
  3149. idx = np.argmax(self._pmf_dispatch(x, **params), axis=0)
  3150. return np.choose(idx, [xl, xo, xr])
  3151. def _logentropy_quadrature(self, **params):
  3152. def logintegrand(x, **params):
  3153. logpmf = self._logpmf_dispatch(x, **params)
  3154. # Entropy summand is -pmf*log(pmf), so log-entropy summand is
  3155. # logpmf + log(logpmf) + pi*j. But pmf is always between 0 and 1,
  3156. # so logpmf is always negative, and so log(logpmf) = log(-logpmf) + pi*j.
  3157. # The two imaginary components "cancel" each other out (which we would
  3158. # expect because each term of the entropy summand is positive).
  3159. return np.where(np.isfinite(logpmf), logpmf + np.log(-logpmf), -np.inf)
  3160. return self._quadrature(logintegrand, params=params, log=True)
  3161. # Special case the names of some new-style distributions in `make_distribution`
  3162. _distribution_names = {
  3163. # Continuous
  3164. 'argus': 'ARGUS',
  3165. 'betaprime': 'BetaPrime',
  3166. 'chi2': 'ChiSquared',
  3167. 'crystalball': 'CrystalBall',
  3168. 'dgamma': 'DoubleGamma',
  3169. 'dweibull': 'DoubleWeibull',
  3170. 'expon': 'Exponential',
  3171. 'exponnorm': 'ExponentiallyModifiedNormal',
  3172. 'exponweib': 'ExponentialWeibull',
  3173. 'exponpow': 'ExponentialPower',
  3174. 'fatiguelife': 'FatigueLife',
  3175. 'foldcauchy': 'FoldedCauchy',
  3176. 'foldnorm': 'FoldedNormal',
  3177. 'genlogistic': 'GeneralizedLogistic',
  3178. 'gennorm': 'GeneralizedNormal',
  3179. 'genpareto': 'GeneralizedPareto',
  3180. 'genexpon': 'GeneralizedExponential',
  3181. 'genextreme': 'GeneralizedExtremeValue',
  3182. 'gausshyper': 'GaussHypergeometric',
  3183. 'gengamma': 'GeneralizedGamma',
  3184. 'genhalflogistic': 'GeneralizedHalfLogistic',
  3185. 'geninvgauss': 'GeneralizedInverseGaussian',
  3186. 'gumbel_r': 'Gumbel',
  3187. 'gumbel_l': 'ReflectedGumbel',
  3188. 'halfcauchy': 'HalfCauchy',
  3189. 'halflogistic': 'HalfLogistic',
  3190. 'halfnorm': 'HalfNormal',
  3191. 'halfgennorm': 'HalfGeneralizedNormal',
  3192. 'hypsecant': 'HyperbolicSecant',
  3193. 'invgamma': 'InverseGammma',
  3194. 'invgauss': 'InverseGaussian',
  3195. 'invweibull': 'InverseWeibull',
  3196. 'irwinhall': 'IrwinHall',
  3197. 'jf_skew_t': 'JonesFaddySkewT',
  3198. 'johnsonsb': 'JohnsonSB',
  3199. 'johnsonsu': 'JohnsonSU',
  3200. 'ksone': 'KSOneSided',
  3201. 'kstwo': 'KSTwoSided',
  3202. 'kstwobign': 'KSTwoSidedAsymptotic',
  3203. 'laplace_asymmetric': 'LaplaceAsymmetric',
  3204. 'levy_l': 'LevyLeft',
  3205. 'levy_stable': 'LevyStable',
  3206. 'loggamma': 'ExpGamma', # really the Exponential Gamma Distribution
  3207. 'loglaplace': 'LogLaplace',
  3208. 'lognorm': 'LogNormal',
  3209. 'loguniform': 'LogUniform',
  3210. 'ncx2': 'NoncentralChiSquared',
  3211. 'nct': 'NoncentralT',
  3212. 'norm': 'Normal',
  3213. 'norminvgauss': 'NormalInverseGaussian',
  3214. 'powerlaw': 'PowerLaw',
  3215. 'powernorm': 'PowerNormal',
  3216. 'rdist': 'R',
  3217. 'rel_breitwigner': 'RelativisticBreitWigner',
  3218. 'recipinvgauss': 'ReciprocalInverseGaussian',
  3219. 'reciprocal': 'LogUniform',
  3220. 'semicircular': 'SemiCircular',
  3221. 'skewcauchy': 'SkewCauchy',
  3222. 'skewnorm': 'SkewNormal',
  3223. 'studentized_range': 'StudentizedRange',
  3224. 't': 'StudentT',
  3225. 'trapezoid': 'Trapezoidal',
  3226. 'triang': 'Triangular',
  3227. 'truncexpon': 'TruncatedExponential',
  3228. 'truncnorm': 'TruncatedNormal',
  3229. 'truncpareto': 'TruncatedPareto',
  3230. 'truncweibull_min': 'TruncatedWeibull',
  3231. 'tukeylambda': 'TukeyLambda',
  3232. 'vonmises_line': 'VonMisesLine',
  3233. 'weibull_min': 'Weibull',
  3234. 'weibull_max': 'ReflectedWeibull',
  3235. 'wrapcauchy': 'WrappedCauchyLine',
  3236. # Discrete
  3237. 'betabinom': 'BetaBinomial',
  3238. 'betanbinom': 'BetaNegativeBinomial',
  3239. 'dlaplace': 'LaplaceDiscrete',
  3240. 'geom': 'Geometric',
  3241. 'hypergeom': 'Hypergeometric',
  3242. 'logser': 'LogarithmicSeries',
  3243. 'nbinom': 'NegativeBinomial',
  3244. 'nchypergeom_fisher': 'NoncentralHypergeometricFisher',
  3245. 'nchypergeom_wallenius': 'NoncentralHypergeometricWallenius',
  3246. 'nhypergeom': 'NegativeHypergeometric',
  3247. 'poisson_binom': 'PoissonBinomial',
  3248. 'randint': 'UniformDiscrete',
  3249. 'yulesimon': 'YuleSimon',
  3250. 'zipf': 'Zeta',
  3251. }
  3252. # beta, genextreme, gengamma, t, tukeylambda need work for 1D arrays
  3253. @xp_capabilities(np_only=True)
  3254. def make_distribution(dist):
  3255. """Generate a `UnivariateDistribution` class from a compatible object
  3256. The argument may be an instance of `rv_continuous` or an instance of
  3257. another class that satisfies the interface described below.
  3258. The returned value is a `ContinuousDistribution` subclass if the input is an
  3259. instance of `rv_continuous` or a `DiscreteDistribution` subclass if the input
  3260. is an instance of `rv_discrete`. Like any subclass of `UnivariateDistribution`,
  3261. it must be instantiated (i.e. by passing all shape parameters as keyword
  3262. arguments) before use. Once instantiated, the resulting object will have the
  3263. same interface as any other instance of `UnivariateDistribution`; e.g.,
  3264. `scipy.stats.Normal`, `scipy.stats.Binomial`.
  3265. .. note::
  3266. `make_distribution` does not work perfectly with all instances of
  3267. `rv_continuous`. Known failures include `levy_stable`, `vonmises`,
  3268. `hypergeom`, 'nchypergeom_fisher', 'nchypergeom_wallenius', and
  3269. `poisson_binom`. Some methods of some distributions will not support
  3270. array shape parameters.
  3271. Parameters
  3272. ----------
  3273. dist : `rv_continuous`
  3274. Instance of `rv_continuous`, `rv_discrete`, or an instance of any class with
  3275. the following attributes:
  3276. __make_distribution_version__ : str
  3277. A string containing the version number of SciPy in which this interface
  3278. is defined. The preferred interface may change in future SciPy versions,
  3279. in which case support for an old interface version may be deprecated
  3280. and eventually removed.
  3281. parameters : dict or tuple
  3282. If a dictionary, each key is the name of a parameter,
  3283. and the corresponding value is either a dictionary or tuple.
  3284. If the value is a dictionary, it may have the following items, with default
  3285. values used for entries which aren't present.
  3286. endpoints : tuple, default: (-inf, inf)
  3287. A tuple defining the lower and upper endpoints of the domain of the
  3288. parameter; allowable values are floats, the name (string) of another
  3289. parameter, or a callable taking parameters as keyword only
  3290. arguments and returning the numerical value of an endpoint for
  3291. given parameter values.
  3292. inclusive : tuple of bool, default: (False, False)
  3293. A tuple specifying whether the endpoints are included within the domain
  3294. of the parameter.
  3295. typical : tuple, default: ``endpoints``
  3296. Defining endpoints of a typical range of values of a parameter. Can be
  3297. used for sampling parameter values for testing. Behaves like the
  3298. ``endpoints`` tuple above, and should define a subinterval of the
  3299. domain given by ``endpoints``.
  3300. A tuple value ``(a, b)`` associated to a key in the ``parameters``
  3301. dictionary is equivalent to ``{endpoints: (a, b)}``.
  3302. Custom distributions with multiple parameterizations can be defined by
  3303. having the ``parameters`` attribute be a tuple of dictionaries with
  3304. the structure described above. In this case, ``dist``\'s class must also
  3305. define a method ``process_parameters`` to map between the different
  3306. parameterizations. It must take all parameters from all parameterizations
  3307. as optional keyword arguments and return a dictionary mapping parameters to
  3308. values, filling in values from other parameterizations using values from
  3309. the supplied parameterization. See example.
  3310. support : dict or tuple
  3311. A dictionary describing the support of the distribution or a tuple
  3312. describing the endpoints of the support. This behaves identically to
  3313. the values of the parameters dict described above, except that the key
  3314. ``typical`` is ignored.
  3315. The class **must** also define a ``pdf`` method and **may** define methods
  3316. ``logentropy``, ``entropy``, ``median``, ``mode``, ``logpdf``,
  3317. ``logcdf``, ``cdf``, ``logccdf``, ``ccdf``,
  3318. ``ilogcdf``, ``icdf``, ``ilogccdf``, ``iccdf``,
  3319. ``moment``, and ``sample``.
  3320. If defined, these methods must accept the parameters of the distribution as
  3321. keyword arguments and also accept any positional-only arguments accepted by
  3322. the corresponding method of `ContinuousDistribution`.
  3323. When multiple parameterizations are defined, these methods must accept
  3324. all parameters from all parameterizations. The ``moment`` method
  3325. must accept the ``order`` and ``kind`` arguments by position or keyword, but
  3326. may return ``None`` if a formula is not available for the arguments; in this
  3327. case, the infrastructure will fall back to a default implementation. The
  3328. ``sample`` method must accept ``shape`` by position or keyword, but contrary
  3329. to the public method of the same name, the argument it receives will be the
  3330. *full* shape of the output array - that is, the shape passed to the public
  3331. method prepended to the broadcasted shape of random variable parameters.
  3332. Returns
  3333. -------
  3334. CustomDistribution : `UnivariateDistribution`
  3335. A subclass of `UnivariateDistribution` corresponding with `dist`. The
  3336. initializer requires all shape parameters to be passed as keyword arguments
  3337. (using the same names as the instance of `rv_continuous`/`rv_discrete`).
  3338. Notes
  3339. -----
  3340. The documentation of `UnivariateDistribution` is not rendered. See below for
  3341. an example of how to instantiate the class (i.e. pass all shape parameters of
  3342. `dist` to the initializer as keyword arguments). Documentation of all methods
  3343. is identical to that of `scipy.stats.Normal`. Use ``help`` on the returned
  3344. class or its methods for more information.
  3345. Examples
  3346. --------
  3347. >>> import numpy as np
  3348. >>> import matplotlib.pyplot as plt
  3349. >>> from scipy import stats
  3350. >>> from scipy import special
  3351. Create a `ContinuousDistribution` from `scipy.stats.loguniform`.
  3352. >>> LogUniform = stats.make_distribution(stats.loguniform)
  3353. >>> X = LogUniform(a=1.0, b=3.0)
  3354. >>> np.isclose((X + 0.25).median(), stats.loguniform.ppf(0.5, 1, 3, loc=0.25))
  3355. np.True_
  3356. >>> X.plot()
  3357. >>> sample = X.sample(10000, rng=np.random.default_rng())
  3358. >>> plt.hist(sample, density=True, bins=30)
  3359. >>> plt.legend(('pdf', 'histogram'))
  3360. >>> plt.show()
  3361. Create a custom distribution.
  3362. >>> class MyLogUniform:
  3363. ... @property
  3364. ... def __make_distribution_version__(self):
  3365. ... return "1.16.0"
  3366. ...
  3367. ... @property
  3368. ... def parameters(self):
  3369. ... return {'a': {'endpoints': (0, np.inf),
  3370. ... 'inclusive': (False, False)},
  3371. ... 'b': {'endpoints': ('a', np.inf),
  3372. ... 'inclusive': (False, False)}}
  3373. ...
  3374. ... @property
  3375. ... def support(self):
  3376. ... return {'endpoints': ('a', 'b'), 'inclusive': (True, True)}
  3377. ...
  3378. ... def pdf(self, x, a, b):
  3379. ... return 1 / (x * (np.log(b)- np.log(a)))
  3380. >>>
  3381. >>> MyLogUniform = stats.make_distribution(MyLogUniform())
  3382. >>> Y = MyLogUniform(a=1.0, b=3.0)
  3383. >>> np.isclose(Y.cdf(2.), X.cdf(2.))
  3384. np.True_
  3385. Create a custom distribution with variable support.
  3386. >>> class MyUniformCube:
  3387. ... @property
  3388. ... def __make_distribution_version__(self):
  3389. ... return "1.16.0"
  3390. ...
  3391. ... @property
  3392. ... def parameters(self):
  3393. ... return {"a": (-np.inf, np.inf),
  3394. ... "b": {'endpoints':('a', np.inf), 'inclusive':(True, False)}}
  3395. ...
  3396. ... @property
  3397. ... def support(self):
  3398. ... def left(*, a, b):
  3399. ... return a**3
  3400. ...
  3401. ... def right(*, a, b):
  3402. ... return b**3
  3403. ... return (left, right)
  3404. ...
  3405. ... def pdf(self, x, *, a, b):
  3406. ... return 1 / (3*(b - a)*np.cbrt(x)**2)
  3407. ...
  3408. ... def cdf(self, x, *, a, b):
  3409. ... return (np.cbrt(x) - a) / (b - a)
  3410. >>>
  3411. >>> MyUniformCube = stats.make_distribution(MyUniformCube())
  3412. >>> X = MyUniformCube(a=-2, b=2)
  3413. >>> Y = stats.Uniform(a=-2, b=2)**3
  3414. >>> X.support()
  3415. (-8.0, 8.0)
  3416. >>> np.isclose(X.cdf(2.1), Y.cdf(2.1))
  3417. np.True_
  3418. Create a custom distribution with multiple parameterizations. Here we create a
  3419. custom version of the beta distribution that has an alternative parameterization
  3420. in terms of the mean ``mu`` and a dispersion parameter ``nu``.
  3421. >>> class MyBeta:
  3422. ... @property
  3423. ... def __make_distribution_version__(self):
  3424. ... return "1.16.0"
  3425. ...
  3426. ... @property
  3427. ... def parameters(self):
  3428. ... return ({"a": (0, np.inf), "b": (0, np.inf)},
  3429. ... {"mu": (0, 1), "nu": (0, np.inf)})
  3430. ...
  3431. ... def process_parameters(self, a=None, b=None, mu=None, nu=None):
  3432. ... if a is not None and b is not None:
  3433. ... nu = a + b
  3434. ... mu = a / nu
  3435. ... else:
  3436. ... a = mu * nu
  3437. ... b = nu - a
  3438. ... return dict(a=a, b=b, mu=mu, nu=nu)
  3439. ...
  3440. ... @property
  3441. ... def support(self):
  3442. ... return {'endpoints': (0, 1)}
  3443. ...
  3444. ... def pdf(self, x, a, b, mu, nu):
  3445. ... return special._ufuncs._beta_pdf(x, a, b)
  3446. ...
  3447. ... def cdf(self, x, a, b, mu, nu):
  3448. ... return special.betainc(a, b, x)
  3449. >>>
  3450. >>> MyBeta = stats.make_distribution(MyBeta())
  3451. >>> X = MyBeta(a=2.0, b=2.0)
  3452. >>> Y = MyBeta(mu=0.5, nu=4.0)
  3453. >>> np.isclose(X.pdf(0.3), Y.pdf(0.3))
  3454. np.True_
  3455. """
  3456. if dist in {stats.levy_stable, stats.vonmises, stats.hypergeom,
  3457. stats.nchypergeom_fisher, stats.nchypergeom_wallenius,
  3458. stats.poisson_binom}:
  3459. raise NotImplementedError(f"`{dist.name}` is not supported.")
  3460. if isinstance(dist, stats.rv_continuous | stats.rv_discrete):
  3461. return _make_distribution_rv_generic(dist)
  3462. elif getattr(dist, "__make_distribution_version__", "0.0.0") >= "1.16.0":
  3463. return _make_distribution_custom(dist)
  3464. else:
  3465. message = ("The argument must be an instance of `rv_continuous`, "
  3466. "`rv_discrete`, or an instance of a class with attribute "
  3467. "`__make_distribution_version__ >= 1.16`.")
  3468. raise ValueError(message)
  3469. def _make_distribution_rv_generic(dist):
  3470. parameters = []
  3471. names = []
  3472. support = getattr(dist, '_support', (dist.a, dist.b))
  3473. for shape_info in dist._shape_info():
  3474. domain = _RealInterval(endpoints=shape_info.endpoints,
  3475. inclusive=shape_info.inclusive)
  3476. param = _RealParameter(shape_info.name, domain=domain)
  3477. parameters.append(param)
  3478. names.append(shape_info.name)
  3479. repr_str = _distribution_names.get(dist.name, dist.name.capitalize())
  3480. if isinstance(dist, stats.rv_continuous):
  3481. old_class, new_class = stats.rv_continuous, ContinuousDistribution
  3482. else:
  3483. old_class, new_class = stats.rv_discrete, DiscreteDistribution
  3484. def _overrides(method_name):
  3485. return (getattr(dist.__class__, method_name, None)
  3486. is not getattr(old_class, method_name, None))
  3487. if _overrides("_get_support"):
  3488. def left(**parameter_values):
  3489. a, _ = dist._get_support(**parameter_values)
  3490. return np.asarray(a)[()]
  3491. def right(**parameter_values):
  3492. _, b = dist._get_support(**parameter_values)
  3493. return np.asarray(b)[()]
  3494. endpoints = (left, right)
  3495. else:
  3496. endpoints = support
  3497. _x_support = _RealInterval(endpoints=endpoints, inclusive=(True, True))
  3498. _x_param = _RealParameter('x', domain=_x_support, typical=(-1, 1))
  3499. class CustomDistribution(new_class):
  3500. _parameterizations = ([_Parameterization(*parameters)] if parameters
  3501. else [])
  3502. _variable = _x_param
  3503. __class_getitem__ = None
  3504. def __repr__(self):
  3505. s = super().__repr__()
  3506. return s.replace('CustomDistribution', repr_str)
  3507. def __str__(self):
  3508. s = super().__str__()
  3509. return s.replace('CustomDistribution', repr_str)
  3510. def _sample_formula(self, full_shape=(), *, rng=None, **kwargs):
  3511. return dist._rvs(size=full_shape, random_state=rng, **kwargs)
  3512. def _moment_raw_formula(self, order, **kwargs):
  3513. return dist._munp(int(order), **kwargs)
  3514. def _moment_raw_formula_1(self, order, **kwargs):
  3515. if order != 1:
  3516. return None
  3517. return dist._stats(**kwargs)[0]
  3518. def _moment_central_formula(self, order, **kwargs):
  3519. if order != 2:
  3520. return None
  3521. return dist._stats(**kwargs)[1]
  3522. def _moment_standard_formula(self, order, **kwargs):
  3523. if order == 3:
  3524. if dist._stats_has_moments:
  3525. kwargs['moments'] = 's'
  3526. return dist._stats(**kwargs)[int(order - 1)]
  3527. elif order == 4:
  3528. if dist._stats_has_moments:
  3529. kwargs['moments'] = 'k'
  3530. k = dist._stats(**kwargs)[int(order - 1)]
  3531. return k if k is None else k + 3
  3532. else:
  3533. return None
  3534. methods = {'_logpdf': '_logpdf_formula',
  3535. '_pdf': '_pdf_formula',
  3536. '_logpmf': '_logpmf_formula',
  3537. '_pmf': '_pmf_formula',
  3538. '_logcdf': '_logcdf_formula',
  3539. '_cdf': '_cdf_formula',
  3540. '_logsf': '_logccdf_formula',
  3541. '_sf': '_ccdf_formula',
  3542. '_ppf': '_icdf_formula',
  3543. '_isf': '_iccdf_formula',
  3544. '_entropy': '_entropy_formula',
  3545. '_median': '_median_formula'}
  3546. # These are not desirable overrides for the new infrastructure
  3547. skip_override = {'norminvgauss': {'_sf', '_isf'}}
  3548. for old_method, new_method in methods.items():
  3549. if dist.name in skip_override and old_method in skip_override[dist.name]:
  3550. continue
  3551. # If method of old distribution overrides generic implementation...
  3552. method = getattr(dist.__class__, old_method, None)
  3553. super_method = getattr(old_class, old_method, None)
  3554. if method is not super_method:
  3555. # Make it an attribute of the new object with the new name
  3556. setattr(CustomDistribution, new_method, getattr(dist, old_method))
  3557. if _overrides('_munp'):
  3558. CustomDistribution._moment_raw_formula = _moment_raw_formula
  3559. if _overrides('_rvs'):
  3560. CustomDistribution._sample_formula = _sample_formula
  3561. if _overrides('_stats'):
  3562. CustomDistribution._moment_standardized_formula = _moment_standard_formula
  3563. if not _overrides('_munp'):
  3564. CustomDistribution._moment_raw_formula = _moment_raw_formula_1
  3565. CustomDistribution._moment_central_formula = _moment_central_formula
  3566. support_etc = _combine_docs(CustomDistribution, include_examples=False).lstrip()
  3567. docs = [
  3568. f"This class represents `scipy.stats.{dist.name}` as a subclass of "
  3569. f"`{new_class}`.",
  3570. f"The `repr`/`str` of class instances is `{repr_str}`.",
  3571. f"The PDF of the distribution is defined {support_etc}"
  3572. ]
  3573. CustomDistribution.__doc__ = ("\n".join(docs))
  3574. return CustomDistribution
  3575. def _get_domain_info(info):
  3576. domain_info = {"endpoints": info} if isinstance(info, tuple) else info
  3577. typical = domain_info.pop("typical", None)
  3578. return domain_info, typical
  3579. def _make_distribution_custom(dist):
  3580. dist_parameters = (
  3581. dist.parameters if isinstance(dist.parameters, tuple) else (dist.parameters, )
  3582. )
  3583. parameterizations = []
  3584. for parameterization in dist_parameters:
  3585. # The attribute name ``parameters`` appears reasonable from a user facing
  3586. # perspective, but there is a little tension here with the internal. It's
  3587. # important to keep in mind that the ``parameters`` attribute in a
  3588. # user-created custom distribution specifies ``_parameterizations`` within
  3589. # the infrastructure.
  3590. parameters = []
  3591. for name, info in parameterization.items():
  3592. domain_info, typical = _get_domain_info(info)
  3593. domain = _RealInterval(**domain_info)
  3594. param = _RealParameter(name, domain=domain, typical=typical)
  3595. parameters.append(param)
  3596. parameterizations.append(_Parameterization(*parameters) if parameters else [])
  3597. domain_info, _ = _get_domain_info(dist.support)
  3598. _x_support = _RealInterval(**domain_info)
  3599. _x_param = _RealParameter('x', domain=_x_support)
  3600. repr_str = dist.__class__.__name__
  3601. class CustomDistribution(ContinuousDistribution):
  3602. _parameterizations = parameterizations
  3603. _variable = _x_param
  3604. def __repr__(self):
  3605. s = super().__repr__()
  3606. return s.replace('CustomDistribution', repr_str)
  3607. def __str__(self):
  3608. s = super().__str__()
  3609. return s.replace('CustomDistribution', repr_str)
  3610. methods = {'sample', 'logentropy', 'entropy',
  3611. 'median', 'mode', 'logpdf', 'pdf',
  3612. 'logcdf2', 'logcdf', 'cdf2', 'cdf',
  3613. 'logccdf2', 'logccdf', 'ccdf2', 'ccdf',
  3614. 'ilogcdf', 'icdf', 'ilogccdf', 'iccdf'}
  3615. for method in methods:
  3616. if hasattr(dist, method):
  3617. # Make it an attribute of the new object with the new name
  3618. new_method = f"_{method}_formula"
  3619. setattr(CustomDistribution, new_method, getattr(dist, method))
  3620. if hasattr(dist, 'moment'):
  3621. def _moment_raw_formula(self, order, **kwargs):
  3622. return dist.moment(order, kind='raw', **kwargs)
  3623. def _moment_central_formula(self, order, **kwargs):
  3624. return dist.moment(order, kind='central', **kwargs)
  3625. def _moment_standardized_formula(self, order, **kwargs):
  3626. return dist.moment(order, kind='standardized', **kwargs)
  3627. CustomDistribution._moment_raw_formula = _moment_raw_formula
  3628. CustomDistribution._moment_central_formula = _moment_central_formula
  3629. CustomDistribution._moment_standardized_formula = _moment_standardized_formula
  3630. if hasattr(dist, 'process_parameters'):
  3631. setattr(
  3632. CustomDistribution,
  3633. "_process_parameters",
  3634. getattr(dist, "process_parameters")
  3635. )
  3636. support_etc = _combine_docs(CustomDistribution, include_examples=False).lstrip()
  3637. docs = [
  3638. f"This class represents `{repr_str}` as a subclass of "
  3639. "`ContinuousDistribution`.",
  3640. f"The PDF of the distribution is defined {support_etc}"
  3641. ]
  3642. CustomDistribution.__doc__ = ("\n".join(docs))
  3643. return CustomDistribution
  3644. # Rough sketch of how we might shift/scale distributions. The purpose of
  3645. # making it a separate class is for
  3646. # a) simplicity of the ContinuousDistribution class and
  3647. # b) avoiding the requirement that every distribution accept loc/scale.
  3648. # The simplicity of ContinuousDistribution is important, because there are
  3649. # several other distribution transformations to be supported; e.g., truncation,
  3650. # wrapping, folding, and doubling. We wouldn't want to cram all of this
  3651. # into the `ContinuousDistribution` class. Also, the order of the composition
  3652. # matters (e.g. truncate then shift/scale or vice versa). It's easier to
  3653. # accommodate different orders if the transformation is built up from
  3654. # components rather than all built into `ContinuousDistribution`.
  3655. def _shift_scale_distribution_function_2arg(func):
  3656. def wrapped(self, x, y, *args, loc, scale, sign, **kwargs):
  3657. item = func.__name__
  3658. f = getattr(self._dist, item)
  3659. # Obviously it's possible to get away with half of the work here.
  3660. # Let's focus on correct results first and optimize later.
  3661. xt = self._transform(x, loc, scale)
  3662. yt = self._transform(y, loc, scale)
  3663. fxy = f(xt, yt, *args, **kwargs)
  3664. fyx = f(yt, xt, *args, **kwargs)
  3665. return np.real_if_close(np.where(sign, fxy, fyx))[()]
  3666. return wrapped
  3667. def _shift_scale_distribution_function(func):
  3668. # c is for complementary
  3669. citem = {'_logcdf_dispatch': '_logccdf_dispatch',
  3670. '_cdf_dispatch': '_ccdf_dispatch',
  3671. '_logccdf_dispatch': '_logcdf_dispatch',
  3672. '_ccdf_dispatch': '_cdf_dispatch'}
  3673. def wrapped(self, x, *args, loc, scale, sign, **kwargs):
  3674. item = func.__name__
  3675. f = getattr(self._dist, item)
  3676. cf = getattr(self._dist, citem[item])
  3677. # Obviously it's possible to get away with half of the work here.
  3678. # Let's focus on correct results first and optimize later.
  3679. xt = self._transform(x, loc, scale)
  3680. fx = f(xt, *args, **kwargs)
  3681. cfx = cf(xt, *args, **kwargs)
  3682. return np.where(sign, fx, cfx)[()]
  3683. return wrapped
  3684. def _shift_scale_inverse_function(func):
  3685. citem = {'_ilogcdf_dispatch': '_ilogccdf_dispatch',
  3686. '_icdf_dispatch': '_iccdf_dispatch',
  3687. '_ilogccdf_dispatch': '_ilogcdf_dispatch',
  3688. '_iccdf_dispatch': '_icdf_dispatch'}
  3689. def wrapped(self, p, *args, loc, scale, sign, **kwargs):
  3690. item = func.__name__
  3691. f = getattr(self._dist, item)
  3692. cf = getattr(self._dist, citem[item])
  3693. # Obviously it's possible to get away with half of the work here.
  3694. # Let's focus on correct results first and optimize later.
  3695. fx = self._itransform(f(p, *args, **kwargs), loc, scale)
  3696. cfx = self._itransform(cf(p, *args, **kwargs), loc, scale)
  3697. return np.where(sign, fx, cfx)[()]
  3698. return wrapped
  3699. class TransformedDistribution(ContinuousDistribution):
  3700. def __init__(self, X, /, *args, **kwargs):
  3701. if not isinstance(X, ContinuousDistribution):
  3702. message = "Transformations are currently only supported for continuous RVs."
  3703. raise NotImplementedError(message)
  3704. self._copy_parameterization()
  3705. self._variable = X._variable
  3706. self._dist = X
  3707. if X._parameterization:
  3708. # Add standard distribution parameters to our parameterization
  3709. dist_parameters = X._parameterization.parameters
  3710. set_params = set(dist_parameters)
  3711. if not self._parameterizations:
  3712. self._parameterizations.append(_Parameterization())
  3713. for parameterization in self._parameterizations:
  3714. if set_params.intersection(parameterization.parameters):
  3715. message = (f"One or more of the parameters of {X} has "
  3716. "the same name as a parameter of "
  3717. f"{self.__class__.__name__}. Name collisions "
  3718. "create ambiguities and are not supported.")
  3719. raise ValueError(message)
  3720. parameterization.parameters.update(dist_parameters)
  3721. super().__init__(*args, **kwargs)
  3722. def _overrides(self, method_name):
  3723. return (self._dist._overrides(method_name)
  3724. or super()._overrides(method_name))
  3725. def reset_cache(self):
  3726. self._dist.reset_cache()
  3727. super().reset_cache()
  3728. def _update_parameters(self, *, validation_policy=None, **params):
  3729. # maybe broadcast everything before processing?
  3730. parameters = {}
  3731. # There may be some issues with _original_parameters
  3732. # We only want to update with _dist._original_parameters during
  3733. # initialization. Afterward that, we want to start with
  3734. # self._original_parameters.
  3735. parameters.update(self._dist._original_parameters)
  3736. parameters.update(params)
  3737. super()._update_parameters(validation_policy=validation_policy, **parameters)
  3738. def _process_parameters(self, **params):
  3739. return self._dist._process_parameters(**params)
  3740. def __repr__(self):
  3741. raise NotImplementedError()
  3742. def __str__(self):
  3743. raise NotImplementedError()
  3744. class TruncatedDistribution(TransformedDistribution):
  3745. """Truncated distribution."""
  3746. # TODO:
  3747. # - consider avoiding catastropic cancellation by using appropriate tail
  3748. # - if the mode of `_dist` is within the support, it's still the mode
  3749. # - rejection sampling might be more efficient than inverse transform
  3750. _lb_domain = _RealInterval(endpoints=(-inf, 'ub'), inclusive=(True, False))
  3751. _lb_param = _RealParameter('lb', symbol=r'b_l',
  3752. domain=_lb_domain, typical=(0.1, 0.2))
  3753. _ub_domain = _RealInterval(endpoints=('lb', inf), inclusive=(False, True))
  3754. _ub_param = _RealParameter('ub', symbol=r'b_u',
  3755. domain=_ub_domain, typical=(0.8, 0.9))
  3756. _parameterizations = [_Parameterization(_lb_param, _ub_param),
  3757. _Parameterization(_lb_param),
  3758. _Parameterization(_ub_param)]
  3759. def __init__(self, X, /, *args, lb=-np.inf, ub=np.inf, **kwargs):
  3760. return super().__init__(X, *args, lb=lb, ub=ub, **kwargs)
  3761. def _process_parameters(self, lb=None, ub=None, **params):
  3762. lb = lb if lb is not None else np.full_like(lb, -np.inf)[()]
  3763. ub = ub if ub is not None else np.full_like(ub, np.inf)[()]
  3764. parameters = self._dist._process_parameters(**params)
  3765. a, b = self._support(lb=lb, ub=ub, **parameters)
  3766. logmass = self._dist._logcdf2_dispatch(a, b, **parameters)
  3767. parameters.update(dict(lb=lb, ub=ub, _a=a, _b=b, logmass=logmass))
  3768. return parameters
  3769. def _support(self, lb, ub, **params):
  3770. a, b = self._dist._support(**params)
  3771. return np.maximum(a, lb), np.minimum(b, ub)
  3772. def _overrides(self, method_name):
  3773. return False
  3774. def _logpdf_dispatch(self, x, *args, lb, ub, _a, _b, logmass, **params):
  3775. logpdf = self._dist._logpdf_dispatch(x, *args, **params)
  3776. return logpdf - logmass
  3777. def _logcdf_dispatch(self, x, *args, lb, ub, _a, _b, logmass, **params):
  3778. logcdf = self._dist._logcdf2_dispatch(_a, x, *args, **params)
  3779. # of course, if this result is small we could compute with the other tail
  3780. return logcdf - logmass
  3781. def _logccdf_dispatch(self, x, *args, lb, ub, _a, _b, logmass, **params):
  3782. logccdf = self._dist._logcdf2_dispatch(x, _b, *args, **params)
  3783. return logccdf - logmass
  3784. def _logcdf2_dispatch(self, x, y, *args, lb, ub, _a, _b, logmass, **params):
  3785. logcdf2 = self._dist._logcdf2_dispatch(x, y, *args, **params)
  3786. return logcdf2 - logmass
  3787. def _ilogcdf_dispatch(self, logp, *args, lb, ub, _a, _b, logmass, **params):
  3788. log_Fa = self._dist._logcdf_dispatch(_a, *args, **params)
  3789. logp_adjusted = np.logaddexp(log_Fa, logp + logmass)
  3790. return self._dist._ilogcdf_dispatch(logp_adjusted, *args, **params)
  3791. def _ilogccdf_dispatch(self, logp, *args, lb, ub, _a, _b, logmass, **params):
  3792. log_cFb = self._dist._logccdf_dispatch(_b, *args, **params)
  3793. logp_adjusted = np.logaddexp(log_cFb, logp + logmass)
  3794. return self._dist._ilogccdf_dispatch(logp_adjusted, *args, **params)
  3795. def _icdf_dispatch(self, p, *args, lb, ub, _a, _b, logmass, **params):
  3796. Fa = self._dist._cdf_dispatch(_a, *args, **params)
  3797. p_adjusted = Fa + p*np.exp(logmass)
  3798. return self._dist._icdf_dispatch(p_adjusted, *args, **params)
  3799. def _iccdf_dispatch(self, p, *args, lb, ub, _a, _b, logmass, **params):
  3800. cFb = self._dist._ccdf_dispatch(_b, *args, **params)
  3801. p_adjusted = cFb + p*np.exp(logmass)
  3802. return self._dist._iccdf_dispatch(p_adjusted, *args, **params)
  3803. def __repr__(self):
  3804. with np.printoptions(threshold=10):
  3805. return (f"truncate({repr(self._dist)}, "
  3806. f"lb={repr(self.lb)}, ub={repr(self.ub)})")
  3807. def __str__(self):
  3808. with np.printoptions(threshold=10):
  3809. return (f"truncate({str(self._dist)}, "
  3810. f"lb={str(self.lb)}, ub={str(self.ub)})")
@xp_capabilities(np_only=True)
def truncate(X, lb=-np.inf, ub=np.inf):
    """Truncate the support of a random variable.

    Given a random variable `X`, `truncate` returns a random variable with
    support truncated to the interval between `lb` and `ub`. The underlying
    probability density function is normalized accordingly.

    Parameters
    ----------
    X : `ContinuousDistribution`
        The random variable to be truncated.
    lb, ub : float array-like
        The lower and upper truncation points, respectively. Must be
        broadcastable with one another and the shape of `X`. The defaults,
        ``-inf`` and ``inf``, leave the corresponding side untruncated.

    Returns
    -------
    X : `ContinuousDistribution`
        The truncated random variable.

    References
    ----------
    .. [1] "Truncated Distribution". *Wikipedia*.
           https://en.wikipedia.org/wiki/Truncated_distribution

    Examples
    --------
    Compare against `scipy.stats.truncnorm`, which truncates a standard normal,
    *then* shifts and scales it.

    >>> import numpy as np
    >>> import matplotlib.pyplot as plt
    >>> from scipy import stats
    >>> loc, scale, lb, ub = 1, 2, -2, 2
    >>> X = stats.truncnorm(lb, ub, loc, scale)
    >>> Y = scale * stats.truncate(stats.Normal(), lb, ub) + loc
    >>> x = np.linspace(-3, 5, 300)
    >>> plt.plot(x, X.pdf(x), '-', label='X')
    >>> plt.plot(x, Y.pdf(x), '--', label='Y')
    >>> plt.xlabel('x')
    >>> plt.ylabel('PDF')
    >>> plt.title('Truncated, then Shifted/Scaled Normal')
    >>> plt.legend()
    >>> plt.show()

    However, suppose we wish to shift and scale a normal random variable,
    then truncate its support to given values. This is straightforward with
    `truncate`.

    >>> Z = stats.truncate(scale * stats.Normal() + loc, lb, ub)
    >>> Z.plot()
    >>> plt.show()

    Furthermore, `truncate` can be applied to any random variable:

    >>> Rayleigh = stats.make_distribution(stats.rayleigh)
    >>> W = stats.truncate(Rayleigh(), lb=0.5, ub=3)
    >>> W.plot()
    >>> plt.show()

    """
    return TruncatedDistribution(X, lb=lb, ub=ub)
  3863. class ShiftedScaledDistribution(TransformedDistribution):
  3864. """Distribution with a standard shift/scale transformation."""
  3865. # Unclear whether infinite loc/scale will work reasonably in all cases
  3866. _loc_domain = _RealInterval(endpoints=(-inf, inf), inclusive=(True, True))
  3867. _loc_param = _RealParameter('loc', symbol=r'\mu',
  3868. domain=_loc_domain, typical=(1, 2))
  3869. _scale_domain = _RealInterval(endpoints=(-inf, inf), inclusive=(True, True))
  3870. _scale_param = _RealParameter('scale', symbol=r'\sigma',
  3871. domain=_scale_domain, typical=(0.1, 10))
  3872. _parameterizations = [_Parameterization(_loc_param, _scale_param),
  3873. _Parameterization(_loc_param),
  3874. _Parameterization(_scale_param)]
  3875. def _process_parameters(self, loc=None, scale=None, **params):
  3876. loc = loc if loc is not None else np.zeros_like(scale)[()]
  3877. scale = scale if scale is not None else np.ones_like(loc)[()]
  3878. sign = scale > 0
  3879. parameters = self._dist._process_parameters(**params)
  3880. parameters.update(dict(loc=loc, scale=scale, sign=sign))
  3881. return parameters
  3882. def _transform(self, x, loc, scale, **kwargs):
  3883. return (x - loc)/scale
  3884. def _itransform(self, x, loc, scale, **kwargs):
  3885. return x * scale + loc
  3886. def _support(self, loc, scale, sign, **params):
  3887. # Add shortcut for infinite support?
  3888. a, b = self._dist._support(**params)
  3889. a, b = self._itransform(a, loc, scale), self._itransform(b, loc, scale)
  3890. return np.where(sign, a, b)[()], np.where(sign, b, a)[()]
  3891. def __repr__(self):
  3892. with np.printoptions(threshold=10):
  3893. result = f"{repr(self.scale)}*{repr(self._dist)}"
  3894. if not self.loc.ndim and self.loc < 0:
  3895. result += f" - {repr(-self.loc)}"
  3896. elif (np.any(self.loc != 0)
  3897. or not np.can_cast(self.loc.dtype, self.scale.dtype)):
  3898. # We don't want to hide a zero array loc if it can cause
  3899. # a type promotion.
  3900. result += f" + {repr(self.loc)}"
  3901. return result
  3902. def __str__(self):
  3903. with np.printoptions(threshold=10):
  3904. result = f"{str(self.scale)}*{str(self._dist)}"
  3905. if not self.loc.ndim and self.loc < 0:
  3906. result += f" - {str(-self.loc)}"
  3907. elif (np.any(self.loc != 0)
  3908. or not np.can_cast(self.loc.dtype, self.scale.dtype)):
  3909. # We don't want to hide a zero array loc if it can cause
  3910. # a type promotion.
  3911. result += f" + {str(self.loc)}"
  3912. return result
  3913. # Here, we override all the `_dispatch` methods rather than the public
  3914. # methods or _function methods. Why not the public methods?
  3915. # If we were to override the public methods, then other
  3916. # TransformedDistribution classes (which could transform a
  3917. # ShiftedScaledDistribution) would need to call the public methods of
  3918. # ShiftedScaledDistribution, which would run the input validation again.
  3919. # Why not the _function methods? For distributions that rely on the
  3920. # default implementation of methods (e.g. `quadrature`, `inversion`),
  3921. # the implementation would "see" the location and scale like other
  3922. # distribution parameters, so they could affect the accuracy of the
  3923. # calculations. I think it is cleaner if `loc` and `scale` do not affect
  3924. # the underlying calculations at all.
  3925. def _entropy_dispatch(self, *args, loc, scale, sign, **params):
  3926. return (self._dist._entropy_dispatch(*args, **params)
  3927. + np.log(np.abs(scale)))
  3928. def _logentropy_dispatch(self, *args, loc, scale, sign, **params):
  3929. lH0 = self._dist._logentropy_dispatch(*args, **params)
  3930. lls = np.log(np.log(np.abs(scale))+0j)
  3931. return special.logsumexp(np.broadcast_arrays(lH0, lls), axis=0)
  3932. def _median_dispatch(self, *, method, loc, scale, sign, **params):
  3933. raw = self._dist._median_dispatch(method=method, **params)
  3934. return self._itransform(raw, loc, scale)
  3935. def _mode_dispatch(self, *, method, loc, scale, sign, **params):
  3936. raw = self._dist._mode_dispatch(method=method, **params)
  3937. return self._itransform(raw, loc, scale)
  3938. def _logpdf_dispatch(self, x, *args, loc, scale, sign, **params):
  3939. x = self._transform(x, loc, scale)
  3940. logpdf = self._dist._logpdf_dispatch(x, *args, **params)
  3941. return logpdf - np.log(np.abs(scale))
  3942. def _pdf_dispatch(self, x, *args, loc, scale, sign, **params):
  3943. x = self._transform(x, loc, scale)
  3944. pdf = self._dist._pdf_dispatch(x, *args, **params)
  3945. return pdf / np.abs(scale)
  3946. def _logpmf_dispatch(self, x, *args, loc, scale, sign, **params):
  3947. x = self._transform(x, loc, scale)
  3948. logpmf = self._dist._logpmf_dispatch(x, *args, **params)
  3949. return logpmf - np.log(np.abs(scale))
  3950. def _pmf_dispatch(self, x, *args, loc, scale, sign, **params):
  3951. x = self._transform(x, loc, scale)
  3952. pmf = self._dist._pmf_dispatch(x, *args, **params)
  3953. return pmf / np.abs(scale)
  3954. def _logpxf_dispatch(self, x, *args, loc, scale, sign, **params):
  3955. x = self._transform(x, loc, scale)
  3956. logpxf = self._dist._logpxf_dispatch(x, *args, **params)
  3957. return logpxf - np.log(np.abs(scale))
  3958. def _pxf_dispatch(self, x, *args, loc, scale, sign, **params):
  3959. x = self._transform(x, loc, scale)
  3960. pxf = self._dist._pxf_dispatch(x, *args, **params)
  3961. return pxf / np.abs(scale)
  3962. # Sorry about the magic. This is just a draft to show the behavior.
  3963. @_shift_scale_distribution_function
  3964. def _logcdf_dispatch(self, x, *, method=None, **params):
  3965. pass
  3966. @_shift_scale_distribution_function
  3967. def _cdf_dispatch(self, x, *, method=None, **params):
  3968. pass
  3969. @_shift_scale_distribution_function
  3970. def _logccdf_dispatch(self, x, *, method=None, **params):
  3971. pass
  3972. @_shift_scale_distribution_function
  3973. def _ccdf_dispatch(self, x, *, method=None, **params):
  3974. pass
  3975. @_shift_scale_distribution_function_2arg
  3976. def _logcdf2_dispatch(self, x, y, *, method=None, **params):
  3977. pass
  3978. @_shift_scale_distribution_function_2arg
  3979. def _cdf2_dispatch(self, x, y, *, method=None, **params):
  3980. pass
  3981. @_shift_scale_distribution_function_2arg
  3982. def _logccdf2_dispatch(self, x, y, *, method=None, **params):
  3983. pass
  3984. @_shift_scale_distribution_function_2arg
  3985. def _ccdf2_dispatch(self, x, y, *, method=None, **params):
  3986. pass
  3987. @_shift_scale_inverse_function
  3988. def _ilogcdf_dispatch(self, x, *, method=None, **params):
  3989. pass
  3990. @_shift_scale_inverse_function
  3991. def _icdf_dispatch(self, x, *, method=None, **params):
  3992. pass
  3993. @_shift_scale_inverse_function
  3994. def _ilogccdf_dispatch(self, x, *, method=None, **params):
  3995. pass
  3996. @_shift_scale_inverse_function
  3997. def _iccdf_dispatch(self, x, *, method=None, **params):
  3998. pass
  3999. def _moment_standardized_dispatch(self, order, *, loc, scale, sign, methods,
  4000. **params):
  4001. res = (self._dist._moment_standardized_dispatch(
  4002. order, methods=methods, **params))
  4003. return None if res is None else res * np.sign(scale)**order
  4004. def _moment_central_dispatch(self, order, *, loc, scale, sign, methods,
  4005. **params):
  4006. res = (self._dist._moment_central_dispatch(
  4007. order, methods=methods, **params))
  4008. return None if res is None else res * scale**order
  4009. def _moment_raw_dispatch(self, order, *, loc, scale, sign, methods,
  4010. **params):
  4011. raw_moments = []
  4012. methods_highest_order = methods
  4013. for i in range(int(order) + 1):
  4014. methods = (self._moment_methods if i < order
  4015. else methods_highest_order)
  4016. raw = self._dist._moment_raw_dispatch(i, methods=methods, **params)
  4017. if raw is None:
  4018. return None
  4019. moment_i = raw * scale**i
  4020. raw_moments.append(moment_i)
  4021. return self._moment_transform_center(
  4022. order, raw_moments, loc, self._zero)
  4023. def _sample_dispatch(self, full_shape, *,
  4024. rng, loc, scale, sign, method, **params):
  4025. rvs = self._dist._sample_dispatch(full_shape, method=method, rng=rng, **params)
  4026. return self._itransform(rvs, loc=loc, scale=scale, sign=sign, **params)
  4027. def __add__(self, loc):
  4028. return ShiftedScaledDistribution(self._dist, loc=self.loc + loc,
  4029. scale=self.scale)
  4030. def __sub__(self, loc):
  4031. return ShiftedScaledDistribution(self._dist, loc=self.loc - loc,
  4032. scale=self.scale)
  4033. def __mul__(self, scale):
  4034. return ShiftedScaledDistribution(self._dist,
  4035. loc=self.loc * scale,
  4036. scale=self.scale * scale)
  4037. def __truediv__(self, scale):
  4038. return ShiftedScaledDistribution(self._dist,
  4039. loc=self.loc / scale,
  4040. scale=self.scale / scale)
  4041. class OrderStatisticDistribution(TransformedDistribution):
  4042. r"""Probability distribution of an order statistic
  4043. An instance of this class represents a random variable that follows the
  4044. distribution underlying the :math:`r^{\text{th}}` order statistic of a
  4045. sample of :math:`n` observations of a random variable :math:`X`.
  4046. Parameters
  4047. ----------
  4048. dist : `ContinuousDistribution`
  4049. The random variable :math:`X`
  4050. n : array_like
  4051. The (integer) sample size :math:`n`
  4052. r : array_like
  4053. The (integer) rank of the order statistic :math:`r`
  4054. Notes
  4055. -----
  4056. If we make :math:`n` observations of a continuous random variable
  4057. :math:`X` and sort them in increasing order
  4058. :math:`X_{(1)}, \dots, X_{(r)}, \dots, X_{(n)}`,
  4059. :math:`X_{(r)}` is known as the :math:`r^{\text{th}}` order statistic.
  4060. If the PDF, CDF, and CCDF underlying math:`X` are denoted :math:`f`,
  4061. :math:`F`, and :math:`F'`, respectively, then the PDF underlying
  4062. math:`X_{(r)}` is given by:
  4063. .. math::
  4064. f_r(x) = \frac{n!}{(r-1)! (n-r)!} f(x) F(x)^{r-1} F'(x)^{n - r}
  4065. The CDF and other methods of the distribution underlying :math:`X_{(r)}`
  4066. are calculated using the fact that :math:`X = F^{-1}(U)`, where :math:`U` is
  4067. a standard uniform random variable, and that the order statistics of
  4068. observations of `U` follow a beta distribution, :math:`B(r, n - r + 1)`.
  4069. References
  4070. ----------
  4071. .. [1] Order statistic. *Wikipedia*. https://en.wikipedia.org/wiki/Order_statistic
  4072. Examples
  4073. --------
  4074. Suppose we are interested in order statistics of samples of size five drawn
  4075. from the standard normal distribution. Plot the PDF underlying the fourth
  4076. order statistic and compare with a normalized histogram from simulation.
  4077. >>> import numpy as np
  4078. >>> import matplotlib.pyplot as plt
  4079. >>> from scipy import stats
  4080. >>> from scipy.stats._distribution_infrastructure import OrderStatisticDistribution
  4081. >>>
  4082. >>> X = stats.Normal()
  4083. >>> data = X.sample(shape=(10000, 5))
  4084. >>> ranks = np.sort(data, axis=1)
  4085. >>> Y = OrderStatisticDistribution(X, r=4, n=5)
  4086. >>>
  4087. >>> ax = plt.gca()
  4088. >>> Y.plot(ax=ax)
  4089. >>> ax.hist(ranks[:, 3], density=True, bins=30)
  4090. >>> plt.show()
  4091. """
  4092. # These can be restricted to _IntegerInterval/_IntegerParameter in a separate
  4093. # PR if desired.
  4094. _r_domain = _RealInterval(endpoints=(1, 'n'), inclusive=(True, True))
  4095. _r_param = _RealParameter('r', domain=_r_domain, typical=(1, 2))
  4096. _n_domain = _RealInterval(endpoints=(1, np.inf), inclusive=(True, True))
  4097. _n_param = _RealParameter('n', domain=_n_domain, typical=(1, 4))
  4098. _r_domain.define_parameters(_n_param)
  4099. _parameterizations = [_Parameterization(_r_param, _n_param)]
  4100. def __init__(self, dist, /, *args, r, n, **kwargs):
  4101. super().__init__(dist, *args, r=r, n=n, **kwargs)
  4102. def _support(self, *args, r, n, **kwargs):
  4103. return self._dist._support(*args, **kwargs)
  4104. def _process_parameters(self, r=None, n=None, **params):
  4105. parameters = self._dist._process_parameters(**params)
  4106. parameters.update(dict(r=r, n=n))
  4107. return parameters
  4108. def _overrides(self, method_name):
  4109. return method_name in {'_logpdf_formula', '_pdf_formula',
  4110. '_cdf_formula', '_ccdf_formula',
  4111. '_icdf_formula', '_iccdf_formula'}
  4112. def _logpdf_formula(self, x, r, n, **kwargs):
  4113. log_factor = special.betaln(r, n - r + 1)
  4114. log_fX = self._dist._logpdf_dispatch(x, **kwargs)
  4115. # log-methods sometimes use complex dtype with 0 imaginary component,
  4116. # but `_tanhsinh` doesn't accept complex limits of integration; take `real`.
  4117. log_FX = self._dist._logcdf_dispatch(x.real, **kwargs)
  4118. log_cFX = self._dist._logccdf_dispatch(x.real, **kwargs)
  4119. # This can be problematic when (r - 1)|(n-r) = 0 and `log_FX`|log_cFX = -inf
  4120. # The PDF in these cases is 0^0, so these should be replaced with log(1)=0
  4121. # return log_fX + (r-1)*log_FX + (n-r)*log_cFX - log_factor
  4122. rm1_log_FX = np.where((r - 1 == 0) & np.isneginf(log_FX), 0, (r-1)*log_FX)
  4123. nmr_log_cFX = np.where((n - r == 0) & np.isneginf(log_cFX), 0, (n-r)*log_cFX)
  4124. return log_fX + rm1_log_FX + nmr_log_cFX - log_factor
  4125. def _pdf_formula(self, x, r, n, **kwargs):
  4126. # 1 / factor = factorial(n) / (factorial(r-1) * factorial(n-r))
  4127. factor = special.beta(r, n - r + 1)
  4128. fX = self._dist._pdf_dispatch(x, **kwargs)
  4129. FX = self._dist._cdf_dispatch(x, **kwargs)
  4130. cFX = self._dist._ccdf_dispatch(x, **kwargs)
  4131. return fX * FX**(r-1) * cFX**(n-r) / factor
  4132. def _cdf_formula(self, x, r, n, **kwargs):
  4133. x_ = self._dist._cdf_dispatch(x, **kwargs)
  4134. return special.betainc(r, n-r+1, x_)
  4135. def _ccdf_formula(self, x, r, n, **kwargs):
  4136. x_ = self._dist._cdf_dispatch(x, **kwargs)
  4137. return special.betaincc(r, n-r+1, x_)
  4138. def _icdf_formula(self, p, r, n, **kwargs):
  4139. p_ = special.betaincinv(r, n-r+1, p)
  4140. return self._dist._icdf_dispatch(p_, **kwargs)
  4141. def _iccdf_formula(self, p, r, n, **kwargs):
  4142. p_ = special.betainccinv(r, n-r+1, p)
  4143. return self._dist._icdf_dispatch(p_, **kwargs)
  4144. def __repr__(self):
  4145. with np.printoptions(threshold=10):
  4146. return (f"order_statistic({repr(self._dist)}, r={repr(self.r)}, "
  4147. f"n={repr(self.n)})")
  4148. def __str__(self):
  4149. with np.printoptions(threshold=10):
  4150. return (f"order_statistic({str(self._dist)}, r={str(self.r)}, "
  4151. f"n={str(self.n)})")
  @xp_capabilities(np_only=True)
  def order_statistic(X, /, *, r, n):
      r"""Probability distribution of an order statistic

      Returns a random variable that follows the distribution underlying the
      :math:`r^{\text{th}}` order statistic of a sample of :math:`n`
      observations of a random variable :math:`X`.

      Parameters
      ----------
      X : `ContinuousDistribution`
          The random variable :math:`X`
      r : array_like
          The (positive integer) rank of the order statistic :math:`r`
      n : array_like
          The (positive integer) sample size :math:`n`

      Returns
      -------
      Y : `ContinuousDistribution`
          A random variable that follows the distribution of the prescribed
          order statistic.

      Raises
      ------
      ValueError
          If any element of `r` or `n` is not a positive integer.

      Notes
      -----
      If we make :math:`n` observations of a continuous random variable
      :math:`X` and sort them in increasing order
      :math:`X_{(1)}, \dots, X_{(r)}, \dots, X_{(n)}`,
      :math:`X_{(r)}` is known as the :math:`r^{\text{th}}` order statistic.

      If the PDF, CDF, and CCDF underlying :math:`X` are denoted :math:`f`,
      :math:`F`, and :math:`F'`, respectively, then the PDF underlying
      :math:`X_{(r)}` is given by:

      .. math::

          f_r(x) = \frac{n!}{(r-1)! (n-r)!} f(x) F(x)^{r-1} F'(x)^{n - r}

      The CDF and other methods of the distribution underlying :math:`X_{(r)}`
      are calculated using the fact that :math:`X = F^{-1}(U)`, where :math:`U` is
      a standard uniform random variable, and that the order statistics of
      observations of `U` follow a beta distribution, :math:`B(r, n - r + 1)`.

      References
      ----------
      .. [1] Order statistic. *Wikipedia*. https://en.wikipedia.org/wiki/Order_statistic

      Examples
      --------
      Suppose we are interested in order statistics of samples of size five drawn
      from the standard normal distribution. Plot the PDF underlying each
      order statistic and compare with a normalized histogram from simulation.

      >>> import numpy as np
      >>> import matplotlib.pyplot as plt
      >>> from scipy import stats
      >>>
      >>> X = stats.Normal()
      >>> data = X.sample(shape=(10000, 5))
      >>> ranks = np.sort(data, axis=1)
      >>> Y = stats.order_statistic(X, r=[1, 2, 3, 4, 5], n=5)
      >>>
      >>> ax = plt.gca()
      >>> colors = plt.rcParams['axes.prop_cycle'].by_key()['color']
      >>> for i in range(5):
      ...     y = ranks[:, i]
      ...     ax.hist(y, density=True, bins=30, alpha=0.1, color=colors[i])
      >>> Y.plot(ax=ax)
      >>> plt.show()

      """
      r, n = np.asarray(r), np.asarray(n)
      # Ranks and sample sizes start at 1, so zero is rejected together with
      # negative and non-integral values (the comparison bound is 1, not 0).
      if np.any((r != np.floor(r)) | (r < 1)) or np.any((n != np.floor(n)) | (n < 1)):
          message = "`r` and `n` must contain only positive integers."
          raise ValueError(message)
      return OrderStatisticDistribution(X, r=r, n=n)
  class Mixture(_ProbabilityDistribution):
      r"""Representation of a mixture distribution.

      A mixture distribution is the distribution of a random variable
      defined in the following way: first, a random variable is selected
      from `components` according to the probabilities given by `weights`, then
      the selected random variable is realized.

      Parameters
      ----------
      components : sequence of `ContinuousDistribution`
          The underlying instances of `ContinuousDistribution`.
          All must have scalar shape parameters (if any); e.g., the `pdf` evaluated
          at a scalar argument must return a scalar.
      weights : sequence of floats, optional
          The corresponding probabilities of selecting each random variable.
          Must be non-negative and sum to one. The default behavior is to weight
          all components equally.

      Attributes
      ----------
      components : sequence of `ContinuousDistribution`
          The underlying instances of `ContinuousDistribution`.
      weights : ndarray
          The corresponding probabilities of selecting each random variable.

      Methods
      -------
      support

      sample

      moment

      mean
      median
      mode

      variance
      standard_deviation

      skewness
      kurtosis

      pdf
      logpdf

      cdf
      icdf
      ccdf
      iccdf

      logcdf
      ilogcdf
      logccdf
      ilogccdf

      entropy

      Notes
      -----
      The following abbreviations are used throughout the documentation.

      - PDF: probability density function
      - CDF: cumulative distribution function
      - CCDF: complementary CDF
      - entropy: differential entropy
      - log-*F*: logarithm of *F* (e.g. log-CDF)
      - inverse *F*: inverse function of *F* (e.g. inverse CDF)

      References
      ----------
      .. [1] Mixture distribution, *Wikipedia*,
             https://en.wikipedia.org/wiki/Mixture_distribution

      Examples
      --------
      A mixture of normal distributions:

      >>> import numpy as np
      >>> from scipy import stats
      >>> import matplotlib.pyplot as plt
      >>> X1 = stats.Normal(mu=-2, sigma=1)
      >>> X2 = stats.Normal(mu=2, sigma=1)
      >>> mixture = stats.Mixture([X1, X2], weights=[0.4, 0.6])
      >>> print(f'mean: {mixture.mean():.2f}, '
      ...       f'median: {mixture.median():.2f}, '
      ...       f'mode: {mixture.mode():.2f}')
      mean: 0.40, median: 1.04, mode: 2.00
      >>> x = np.linspace(-10, 10, 300)
      >>> plt.plot(x, mixture.pdf(x))
      >>> plt.title('PDF of normal distribution mixture')
      >>> plt.show()

      """
      # Todo:
      # Add support for array shapes, weights

      def _input_validation(self, components, weights):
          """Validate `components` and `weights`; return them (weights as ndarray).

          Raises ``ValueError`` if `components` is empty, contains anything other
          than scalar-shaped `ContinuousDistribution` instances, or if `weights`
          has the wrong length or dtype, does not sum to one, or has a negative
          entry. ``weights=None`` is passed through unchanged.
          """
          if len(components) == 0:
              message = ("`components` must contain at least one random variable.")
              raise ValueError(message)

          for var in components:
              # will generalize to other kinds of distributions when there
              # *are* other kinds of distributions
              if not isinstance(var, ContinuousDistribution):
                  message = ("Each element of `components` must be an instance of "
                             "`ContinuousDistribution`.")
                  raise ValueError(message)
              if not var._shape == ():
                  message = "All elements of `components` must have scalar shapes."
                  raise ValueError(message)

          if weights is None:
              return components, weights

          weights = np.asarray(weights)
          if weights.shape != (len(components),):
              message = "`components` and `weights` must have the same length."
              raise ValueError(message)

          if not np.issubdtype(weights.dtype, np.inexact):
              message = "`weights` must have floating point dtype."
              raise ValueError(message)

          if not np.isclose(np.sum(weights), 1.0):
              message = "`weights` must sum to 1.0."
              raise ValueError(message)

          if not np.all(weights >= 0):
              message = "All `weights` must be non-negative."
              raise ValueError(message)

          return components, weights

      def __init__(self, components, *, weights=None):
          components, weights = self._input_validation(components, weights)
          n = len(components)
          dtype = np.result_type(*(var._dtype for var in components))
          self._shape = np.broadcast_shapes(*(var._shape for var in components))
          self._dtype, self._components = dtype, components
          # Equal weighting is the default when no weights are supplied
          self._weights = np.full(n, 1/n, dtype=dtype) if weights is None else weights
          self.validation_policy = None

      @property
      def components(self):
          """A new list containing the component distributions."""
          return list(self._components)

      @property
      def weights(self):
          """A copy of the array of component weights."""
          return self._weights.copy()

      def _full(self, val, *args):
          """Array filled with `val`, with shape/dtype broadcast against `args`."""
          args = [np.asarray(arg) for arg in args]
          dtype = np.result_type(self._dtype, *(arg.dtype for arg in args))
          shape = np.broadcast_shapes(self._shape, *(arg.shape for arg in args))
          return np.full(shape, val, dtype=dtype)

      def _sum(self, fun, *args):
          """Weighted sum over components of the method named `fun` called on `args`."""
          out = self._full(0, *args)
          for var, weight in zip(self._components, self._weights):
              out += getattr(var, fun)(*args) * weight
          return out[()]

      def _logsum(self, fun, *args):
          """Log of the weighted sum of ``exp`` of the log-method named `fun`."""
          out = self._full(-np.inf, *args)
          for var, log_weight in zip(self._components, np.log(self._weights)):
              np.logaddexp(out, getattr(var, fun)(*args) + log_weight, out=out)
          return out[()]

      def support(self):
          """Smallest lower and largest upper support endpoint over all components."""
          a = self._full(np.inf)
          b = self._full(-np.inf)
          for var in self._components:
              a = np.minimum(a, var.support()[0])
              b = np.maximum(b, var.support()[1])
          return a, b

      def _raise_if_method(self, method):
          """Raise ``NotImplementedError`` unless `method` is None."""
          if method is not None:
              raise NotImplementedError("`method` not implemented for this distribution.")

      def logentropy(self, *, method=None):
          self._raise_if_method(method)

          def log_integrand(x):
              # `x` passed by `_tanhsinh` will be of complex dtype because
              # `log_integrand` returns complex values, but the imaginary
              # component is always zero. Extract the real part because
              # `logpdf` uses `logaddexp`, which fails for complex input.
              # The log-PDF is evaluated once and reused for both terms.
              logpdf = self.logpdf(x.real)
              return logpdf + np.log(logpdf + 0j)

          res = _tanhsinh(log_integrand, *self.support(), log=True).integral
          return _log_real_standardize(res + np.pi*1j)

      def entropy(self, *, method=None):
          self._raise_if_method(method)
          return _tanhsinh(lambda x: -self.pdf(x) * self.logpdf(x),
                           *self.support()).integral

      def mode(self, *, method=None):
          self._raise_if_method(method)
          a, b = self.support()

          # The mode maximizes the PDF, i.e. minimizes its negative
          def f(x): return -self.pdf(x)

          res = _bracket_minimum(f, 1., xmin=a, xmax=b)
          res = _chandrupatla_minimize(f, res.xl, res.xm, res.xr)
          return res.x

      def median(self, *, method=None):
          self._raise_if_method(method)
          return self.icdf(0.5)

      def mean(self, *, method=None):
          self._raise_if_method(method)
          return self._sum('mean')

      def variance(self, *, method=None):
          self._raise_if_method(method)
          return self._moment_central(2)

      def standard_deviation(self, *, method=None):
          self._raise_if_method(method)
          return self.variance()**0.5

      def skewness(self, *, method=None):
          self._raise_if_method(method)
          return self._moment_standardized(3)

      def kurtosis(self, *, method=None):
          self._raise_if_method(method)
          return self._moment_standardized(4)

      def moment(self, order=1, kind='raw', *, method=None):
          self._raise_if_method(method)
          kinds = {'raw': self._moment_raw,
                   'central': self._moment_central,
                   'standardized': self._moment_standardized}
          order = ContinuousDistribution._validate_order_kind(self, order, kind, kinds)
          moment_kind = kinds[kind]
          return moment_kind(order)

      def _moment_raw(self, order):
          """Weighted sum of the components' raw moments of the given order."""
          out = self._full(0)
          for var, weight in zip(self._components, self._weights):
              out += var.moment(order, kind='raw') * weight
          return out[()]

      def _moment_central(self, order):
          """Central moment of the mixture about the mixture mean.

          Each component's central moments of orders ``0..order`` are passed,
          with the component mean and the mixture mean, to the component's
          ``_moment_transform_center``; the results are combined by weight.
          """
          order = int(order)
          out = self._full(0)
          # The mixture mean does not depend on the component; compute it once
          mixture_mean = self.mean()
          for var, weight in zip(self._components, self._weights):
              moment_as = [var.moment(k, kind='central')
                           for k in range(order + 1)]
              a, b = var.mean(), mixture_mean
              moment = var._moment_transform_center(order, moment_as, a, b)
              out += moment * weight
          return out[()]

      def _moment_standardized(self, order):
          """Central moment divided by the standard deviation raised to `order`."""
          return self._moment_central(order) / self.standard_deviation()**order

      def pdf(self, x, /, *, method=None):
          self._raise_if_method(method)
          return self._sum('pdf', x)

      def logpdf(self, x, /, *, method=None):
          self._raise_if_method(method)
          return self._logsum('logpdf', x)

      def pmf(self, x, /, *, method=None):
          self._raise_if_method(method)
          return self._sum('pmf', x)

      def logpmf(self, x, /, *, method=None):
          self._raise_if_method(method)
          return self._logsum('logpmf', x)

      def cdf(self, x, y=None, /, *, method=None):
          self._raise_if_method(method)
          args = (x,) if y is None else (x, y)
          return self._sum('cdf', *args)

      def logcdf(self, x, y=None, /, *, method=None):
          self._raise_if_method(method)
          args = (x,) if y is None else (x, y)
          return self._logsum('logcdf', *args)

      def ccdf(self, x, y=None, /, *, method=None):
          self._raise_if_method(method)
          args = (x,) if y is None else (x, y)
          return self._sum('ccdf', *args)

      def logccdf(self, x, y=None, /, *, method=None):
          self._raise_if_method(method)
          args = (x,) if y is None else (x, y)
          return self._logsum('logccdf', *args)

      def _invert(self, fun, p):
          """Numerically solve ``getattr(self, fun)(x) == p`` for ``x``."""
          xmin, xmax = self.support()
          fun = getattr(self, fun)
          f = lambda x, p: fun(x) - p  # noqa: E731 is silly
          xl0, xr0 = _guess_bracket(xmin, xmax)
          res = _bracket_root(f, xl0=xl0, xr0=xr0, xmin=xmin, xmax=xmax, args=(p,))
          return _chandrupatla(f, a=res.xl, b=res.xr, args=(p,)).x

      def icdf(self, p, /, *, method=None):
          self._raise_if_method(method)
          return self._invert('cdf', p)

      def iccdf(self, p, /, *, method=None):
          self._raise_if_method(method)
          return self._invert('ccdf', p)

      def ilogcdf(self, p, /, *, method=None):
          self._raise_if_method(method)
          return self._invert('logcdf', p)

      def ilogccdf(self, p, /, *, method=None):
          self._raise_if_method(method)
          return self._invert('logccdf', p)

      def sample(self, shape=(), *, rng=None, method=None):
          self._raise_if_method(method)
          rng = np.random.default_rng(rng)
          # `np.prod` of an empty shape (the default) is the float 1.0; the
          # number of multinomial trials must be an integer, so coerce it.
          size = int(np.prod(np.atleast_1d(shape)))
          # Draw how many observations come from each component, sample each
          # component separately, then shuffle so the order is not grouped.
          ns = rng.multinomial(size, self._weights)
          x = [var.sample(shape=n, rng=rng) for n, var in zip(ns, self._components)]
          x = np.reshape(rng.permuted(np.concatenate(x)), shape)
          return x[()]

      def __repr__(self):
          result = "Mixture(\n"
          result += " [\n"
          with np.printoptions(threshold=10):
              for component in self.components:
                  result += f" {repr(component)},\n"
              result += " ],\n"
              result += f" weights={repr(self.weights)},\n"
          result += ")"
          return result

      def __str__(self):
          result = "Mixture(\n"
          result += " [\n"
          with np.printoptions(threshold=10):
              for component in self.components:
                  result += f" {str(component)},\n"
              result += " ],\n"
              result += f" weights={str(self.weights)},\n"
          result += ")"
          return result
  class MonotonicTransformedDistribution(TransformedDistribution):
      r"""Distribution underlying a strictly monotonic function of a random variable

      Given a random variable :math:`X`; a strictly monotonic function
      :math:`g(u)`, its inverse :math:`h(u) = g^{-1}(u)`, and the derivative magnitude
      :math:`|h'(u)| = \left| \frac{dh(u)}{du} \right|`, define the distribution
      underlying the random variable :math:`Y = g(X)`.

      Parameters
      ----------
      X : `ContinuousDistribution`
          The random variable :math:`X`.
      g, h, dh : callable
          Elementwise functions representing the mathematical functions
          :math:`g(u)`, :math:`h(u)`, and :math:`|h'(u)|`
      logdh : callable, optional
          Elementwise function representing :math:`\log(|h'(u)|)`.
          The default is ``lambda u: np.log(dh(u))``, but providing
          a custom implementation may avoid over/underflow.
      increasing : bool, optional
          Whether the function is strictly increasing (True, default)
          or strictly decreasing (False).
      repr_pattern : str, optional
          A string pattern for determining the __repr__. The __repr__
          for X will be substituted into the position where `***` appears.
          For example:
              ``"exp(***)"`` for the repr of an exponentially transformed
              distribution
          The default is ``f"{g.__name__}(***)"``.
      str_pattern : str, optional
          A string pattern for determining `__str__`. The `__str__`
          for X will be substituted into the position where `***` appears.
          For example:
              ``"exp(***)"`` for the str of an exponentially transformed
              distribution
          The default is the value `repr_pattern` takes.

      """

      def __init__(self, X, /, *args, g, h, dh, logdh=None,
                   increasing=True, repr_pattern=None,
                   str_pattern=None, **kwargs):
          super().__init__(X, *args, **kwargs)
          self._g = g
          self._h = h
          self._dh = dh
          if logdh is None:
              logdh = lambda u: np.log(dh(u))  # noqa: E731
          self._logdh = logdh

          # Collect the base distribution's CDF-family and CCDF-family
          # dispatchers. An increasing `g` maps CDF to CDF; a decreasing `g`
          # exchanges the roles of the two families.
          base = self._dist
          lower_tail = (base._cdf_dispatch, base._icdf_dispatch,
                        base._logcdf_dispatch, base._ilogcdf_dispatch)
          upper_tail = (base._ccdf_dispatch, base._iccdf_dispatch,
                        base._logccdf_dispatch, base._ilogccdf_dispatch)
          if not increasing:
              lower_tail, upper_tail = upper_tail, lower_tail
          self._xdf, self._ixdf, self._logxdf, self._ilogxdf = lower_tail
          self._cxdf, self._icxdf, self._logcxdf, self._ilogcxdf = upper_tail

          self._increasing = increasing
          if not repr_pattern:
              repr_pattern = f"{g.__name__}(***)"
          self._repr_pattern = repr_pattern
          self._str_pattern = str_pattern if str_pattern else repr_pattern

      def __repr__(self):
          with np.printoptions(threshold=10):
              inner = repr(self._dist)
              return self._repr_pattern.replace("***", inner)

      def __str__(self):
          with np.printoptions(threshold=10):
              inner = str(self._dist)
              return self._str_pattern.replace("***", inner)

      def _overrides(self, method_name):
          # Do not use the generic overrides of TransformedDistribution
          return False

      def _support(self, **params):
          lower, upper = self._dist._support(**params)
          # For reciprocal transformation, we want this zero to become -inf,
          # so an upper endpoint of 0 is replaced by negative zero.
          negative_zero = np.asarray("-0", dtype=upper.dtype)
          upper = np.where(upper == 0, negative_zero, upper)
          with np.errstate(divide='ignore'):
              if not self._increasing:
                  # A decreasing `g` swaps the order of the endpoints
                  return self._g(upper), self._g(lower)
              return self._g(lower), self._g(upper)

      def _logpdf_dispatch(self, x, *args, **params):
          base_logpdf = self._dist._logpdf_dispatch(self._h(x), *args, **params)
          return base_logpdf + self._logdh(x)

      def _pdf_dispatch(self, x, *args, **params):
          base_pdf = self._dist._pdf_dispatch(self._h(x), *args, **params)
          return base_pdf * self._dh(x)

      def _logcdf_dispatch(self, x, *args, **params):
          u = self._h(x)
          return self._logxdf(u, *args, **params)

      def _cdf_dispatch(self, x, *args, **params):
          u = self._h(x)
          return self._xdf(u, *args, **params)

      def _logccdf_dispatch(self, x, *args, **params):
          u = self._h(x)
          return self._logcxdf(u, *args, **params)

      def _ccdf_dispatch(self, x, *args, **params):
          u = self._h(x)
          return self._cxdf(u, *args, **params)

      def _ilogcdf_dispatch(self, p, *args, **params):
          u = self._ilogxdf(p, *args, **params)
          return self._g(u)

      def _icdf_dispatch(self, p, *args, **params):
          u = self._ixdf(p, *args, **params)
          return self._g(u)

      def _ilogccdf_dispatch(self, p, *args, **params):
          u = self._ilogcxdf(p, *args, **params)
          return self._g(u)

      def _iccdf_dispatch(self, p, *args, **params):
          u = self._icxdf(p, *args, **params)
          return self._g(u)

      def _sample_dispatch(self, full_shape, *, method, rng, **params):
          base_sample = self._dist._sample_dispatch(
              full_shape, method=method, rng=rng, **params)
          return self._g(base_sample)
  class FoldedDistribution(TransformedDistribution):
      r"""Distribution underlying the absolute value of a random variable

      Given a random variable :math:`X`; define the distribution
      underlying the random variable :math:`Y = |X|`.

      Parameters
      ----------
      X : `ContinuousDistribution`
          The random variable :math:`X`.

      Returns
      -------
      Y : `ContinuousDistribution`
          The random variable :math:`Y = |X|`

      """
      # Many enhancements are possible if distribution is symmetric. Start
      # with the general case; enhance later.

      def __init__(self, X, /, *args, **kwargs):
          super().__init__(X, *args, **kwargs)
          # I think we need to allow `_support` to define whether the endpoints
          # are inclusive or not. In the meantime, it's best to ensure that the lower
          # endpoint (typically 0 for folded distribution) is inclusive so PDF evaluates
          # correctly at that point.
          domain = self._variable.domain
          domain.inclusive = (True, domain.inclusive[1])

      def _overrides(self, method_name):
          # Do not use the generic overrides of TransformedDistribution
          return False

      def _fold_interval(self, x, params):
          """Return limits ``(max(-|x|, a), min(|x|, b))`` for base support ``(a, b)``."""
          x = np.abs(x)
          a, b = self._dist._support(**params)
          return np.maximum(-x, a), np.minimum(x, b)

      def _support(self, **params):
          lower, upper = self._dist._support(**params)
          abs_lower, abs_upper = np.abs(lower), np.abs(upper)
          folded_lower = np.asarray(np.minimum(abs_lower, abs_upper))
          folded_upper = np.maximum(abs_lower, abs_upper)
          # Where the base support straddles zero, the folded support starts at 0
          straddles_zero = (lower < 0) & (upper > 0)
          folded_lower[straddles_zero] = 0
          return folded_lower[()], folded_upper[()]

      def _logpdf_dispatch(self, x, *args, method=None, **params):
          x = np.abs(x)
          base = self._dist
          right = np.asarray(base._logpdf_dispatch(x, *args, method=method, **params))
          left = np.asarray(base._logpdf_dispatch(-x, *args, method=method, **params))
          a, b = base._support(**params)
          # Points outside the base support contribute no density
          left[-x < a] = -np.inf
          right[x > b] = -np.inf
          return special.logsumexp(np.stack([left, right]), axis=0)

      def _pdf_dispatch(self, x, *args, method=None, **params):
          x = np.abs(x)
          base = self._dist
          right = np.asarray(base._pdf_dispatch(x, *args, method=method, **params))
          left = np.asarray(base._pdf_dispatch(-x, *args, method=method, **params))
          a, b = base._support(**params)
          # Points outside the base support contribute no density
          left[-x < a] = 0
          right[x > b] = 0
          return left + right

      def _logcdf_dispatch(self, x, *args, method=None, **params):
          xl, xr = self._fold_interval(x, params)
          res = self._dist._logcdf2_dispatch(xl, xr, *args, method=method, **params)
          return res.real

      def _cdf_dispatch(self, x, *args, method=None, **params):
          xl, xr = self._fold_interval(x, params)
          # NOTE(review): `method` is accepted but not forwarded here, unlike
          # the sibling dispatch methods - confirm whether this is intended.
          return self._dist._cdf2_dispatch(xl, xr, *args, **params)

      def _logccdf_dispatch(self, x, *args, method=None, **params):
          xl, xr = self._fold_interval(x, params)
          res = self._dist._logccdf2_dispatch(xl, xr, *args, method=method, **params)
          return res.real

      def _ccdf_dispatch(self, x, *args, method=None, **params):
          xl, xr = self._fold_interval(x, params)
          return self._dist._ccdf2_dispatch(xl, xr, *args, method=method, **params)

      def _sample_dispatch(self, full_shape, *, method, rng, **params):
          base_sample = self._dist._sample_dispatch(
              full_shape, method=method, rng=rng, **params)
          return np.abs(base_sample)

      def __repr__(self):
          with np.printoptions(threshold=10):
              inner = repr(self._dist)
              return f"abs({inner})"

      def __str__(self):
          with np.printoptions(threshold=10):
              inner = str(self._dist)
              return f"abs({inner})"
  @xp_capabilities(np_only=True)
  def abs(X, /):
      r"""Absolute value of a random variable

      Parameters
      ----------
      X : `ContinuousDistribution`
          The random variable :math:`X`.

      Returns
      -------
      Y : `ContinuousDistribution`
          A random variable :math:`Y = |X|`.

      Examples
      --------
      Suppose we have a normally distributed random variable :math:`X`:

      >>> import numpy as np
      >>> from scipy import stats
      >>> X = stats.Normal()

      We wish to have a random variable :math:`Y` distributed according to
      the folded normal distribution; that is, a random variable :math:`|X|`.

      >>> Y = stats.abs(X)

      The PDF of the distribution in the left half plane is "folded" over to
      the right half plane. Because the normal PDF is symmetric, the resulting
      PDF is zero for negative arguments and doubled for positive arguments.

      >>> import matplotlib.pyplot as plt
      >>> x = np.linspace(0, 5, 300)
      >>> ax = plt.gca()
      >>> Y.plot(x='x', y='pdf', t=('x', -1, 5), ax=ax)
      >>> plt.plot(x, 2 * X.pdf(x), '--')
      >>> plt.legend(('PDF of `Y`', 'Doubled PDF of `X`'))
      >>> plt.show()

      """
      folded = FoldedDistribution(X)
      return folded
  @xp_capabilities(np_only=True)
  def exp(X, /):
      r"""Natural exponential of a random variable

      Parameters
      ----------
      X : `ContinuousDistribution`
          The random variable :math:`X`.

      Returns
      -------
      Y : `ContinuousDistribution`
          A random variable :math:`Y = \exp(X)`.

      Examples
      --------
      Suppose we have a normally distributed random variable :math:`X`:

      >>> import numpy as np
      >>> from scipy import stats
      >>> X = stats.Normal()

      We wish to have a lognormally distributed random variable :math:`Y`,
      a random variable whose natural logarithm is :math:`X`.
      If :math:`X` is to be the natural logarithm of :math:`Y`, then we
      must take :math:`Y` to be the natural exponential of :math:`X`.

      >>> Y = stats.exp(X)

      To demonstrate that ``X`` represents the logarithm of ``Y``,
      we plot a normalized histogram of the logarithm of observations of
      ``Y`` against the PDF underlying ``X``.

      >>> import matplotlib.pyplot as plt
      >>> rng = np.random.default_rng(435383595582522)
      >>> y = Y.sample(shape=10000, rng=rng)
      >>> ax = plt.gca()
      >>> ax.hist(np.log(y), bins=50, density=True)
      >>> X.plot(ax=ax)
      >>> plt.legend(('PDF of `X`', 'histogram of `log(y)`'))
      >>> plt.show()

      """
      # Forward map g = exp, inverse h = log, |h'(u)| = 1/u, log|h'(u)| = -log(u)
      transform = dict(g=np.exp, h=np.log, dh=lambda u: 1 / u,
                       logdh=lambda u: -np.log(u))
      return MonotonicTransformedDistribution(X, **transform)
  @xp_capabilities(np_only=True)
  def log(X, /):
      r"""Natural logarithm of a non-negative random variable

      Parameters
      ----------
      X : `ContinuousDistribution`
          The random variable :math:`X` with positive support.

      Returns
      -------
      Y : `ContinuousDistribution`
          A random variable :math:`Y = \log(X)`.

      Raises
      ------
      NotImplementedError
          If the lower endpoint of the support of `X` is negative.

      Examples
      --------
      Suppose we have a gamma distributed random variable :math:`X`:

      >>> import numpy as np
      >>> from scipy import stats
      >>> Gamma = stats.make_distribution(stats.gamma)
      >>> X = Gamma(a=1.0)

      We wish to have an exp-gamma distributed random variable :math:`Y`,
      a random variable whose natural exponential is :math:`X`.
      If :math:`X` is to be the natural exponential of :math:`Y`, then we
      must take :math:`Y` to be the natural logarithm of :math:`X`.

      >>> Y = stats.log(X)

      To demonstrate that ``X`` represents the exponential of ``Y``,
      we plot a normalized histogram of the exponential of observations of
      ``Y`` against the PDF underlying ``X``.

      >>> import matplotlib.pyplot as plt
      >>> rng = np.random.default_rng(435383595582522)
      >>> y = Y.sample(shape=10000, rng=rng)
      >>> ax = plt.gca()
      >>> ax.hist(np.exp(y), bins=50, density=True)
      >>> X.plot(ax=ax)
      >>> plt.legend(('PDF of `X`', 'histogram of `exp(y)`'))
      >>> plt.show()

      """
      lower_endpoint = X.support()[0]
      if np.any(lower_endpoint < 0):
          raise NotImplementedError(
              "The logarithm of a random variable is only implemented when the "
              "support is non-negative."
          )
      # Forward map g = log, inverse h = exp, |h'(u)| = exp(u), log|h'(u)| = u
      transform = dict(g=np.log, h=np.exp, dh=np.exp, logdh=lambda u: u)
      return MonotonicTransformedDistribution(X, **transform)