random.py 53 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156115711581159116011611162116311641165116611671168116911701171117211731174117511761177117811791180118111821183118411851186118711881189119011911192119311941195119611971198119912001201120212031204120512061207120812091210121112121213121412151216121712181219122012211222122312241225122612271228122912301231123212331234123512361237123812391240124112421243124412451246124712481249125012511252125312541255125612571258125912601261126212631264126512661267126812691270127112721273127412751276127
71278127912801281128212831284128512861287128812891290129112921293129412951296129712981299130013011302130313041305130613071308130913101311131213131314131513161317131813191320132113221323132413251326132713281329133013311332133313341335133613371338133913401341134213431344134513461347134813491350135113521353135413551356135713581359136013611362136313641365136613671368136913701371137213731374137513761377137813791380138113821383138413851386138713881389139013911392139313941395139613971398139914001401140214031404140514061407140814091410141114121413141414151416141714181419142014211422142314241425142614271428142914301431143214331434143514361437143814391440144114421443144414451446144714481449145014511452145314541455145614571458145914601461146214631464146514661467146814691470147114721473147414751476147714781479148014811482148314841485148614871488148914901491149214931494149514961497149814991500150115021503150415051506150715081509151015111512151315141515151615171518151915201521152215231524152515261527152815291530153115321533153415351536153715381539154015411542154315441545154615471548154915501551155215531554155515561557155815591560156115621563156415651566156715681569157015711572157315741575157615771578157915801581158215831584158515861587158815891590159115921593159415951596
  1. """Classes and functions related to pseudo-random number generation.
  2. This module deals with the generation of pseudo-random numbers.
  3. It provides the :class:`~imgaug.random.RNG` class, which is the primary
  4. random number generator in ``imgaug``. It also provides various utility
  5. functions related to random number generation, such as copying random number
  6. generators or setting their state.
  7. The main benefit of this module is to hide the actually used random number
  8. generation classes and methods behind imgaug-specific classes and methods.
  9. This makes it possible to support both of numpy's random number interfaces
  10. (the old one in numpy <=1.16 and the new one in numpy 1.17+). It also makes
  11. it possible to switch to a different framework/library in the future.
  12. Definitions
  13. -----------
  14. - *numpy generator* or *numpy random number generator*: Usually an instance
  15. of :class:`numpy.random.Generator`. Can often also denote an instance
  16. of :class:`numpy.random.RandomState` as both have almost the same interface.
  17. - *RandomState*: An instance of `numpy.random.RandomState`.
  18. Note that outside of this module, the term "random state" often roughly
  19. translates to "any random number generator with numpy-like interface
  20. in a given state", i.e. it can then include instances of
  21. :class:`numpy.random.Generator` or :class:`~imgaug.random.RNG`.
  22. - *RNG*: An instance of :class:`~imgaug.random.RNG`.
  23. Examples
  24. --------
  25. >>> import imgaug.random as iarandom
  26. >>> rng = iarandom.RNG(1234)
  27. >>> rng.integers(0, 1000)
  28. Initialize a random number generator with seed ``1234``, then sample
  29. a single integer from the discrete interval ``[0, 1000)``.
  30. This will use a :class:`numpy.random.Generator` in numpy 1.17+ and
  31. automatically fall back to :class:`numpy.random.RandomState` in numpy <=1.16.
  32. """
  33. from __future__ import print_function, division, absolute_import
  34. import copy as copylib
  35. import numpy as np
  36. import six.moves as sm
  37. # Check if numpy is version 1.17 or later. In that version, the new random
  38. # number interface was added.
  39. # Note that a valid version number can also be "1.18.0.dev0+285ab1d",
  40. # in which the last component cannot easily be converted to an int. Hence we
  41. # only pick the first two components.
  42. SUPPORTS_NEW_NP_RNG_STYLE = False
  43. BIT_GENERATOR = None
  44. _NP_VERSION = list(map(int, np.__version__.split(".")[0:2]))
  45. if _NP_VERSION[0] > 1 or _NP_VERSION[1] >= 17:
  46. SUPPORTS_NEW_NP_RNG_STYLE = True
  47. BIT_GENERATOR = np.random.SFC64 # pylint: disable=invalid-name
  48. # interface of BitGenerator
  49. # in 1.17 this was at numpy.random.bit_generator.BitGenerator
  50. # in 1.18 this was moved to numpy.random.BitGenerator
  51. # pylint: disable=invalid-name, no-member
  52. if _NP_VERSION[1] == 17:
  53. # Added in 0.4.0.
  54. _BIT_GENERATOR_INTERFACE = np.random.bit_generator.BitGenerator
  55. else:
  56. # Added in 0.4.0.
  57. _BIT_GENERATOR_INTERFACE = np.random.BitGenerator
  58. # pylint: enable=invalid-name, no-member
  59. # We instantiate a current/global random state here once.
  60. GLOBAL_RNG = None
  61. # use 2**31 instead of 2**32 as the maximum here, because 2**31 errored on
  62. # some systems
  63. SEED_MIN_VALUE = 0
  64. SEED_MAX_VALUE = 2**31-1
  65. # TODO decrease pool_size in SeedSequence to 2 or 1?
  66. # TODO add 'with resetted_rng(...)'
  67. # TODO change random_state to rng or seed
  68. class RNG(object):
  69. """
  70. Random number generator for imgaug.
  71. This class is a wrapper around ``numpy.random.Generator`` and
  72. automatically falls back to ``numpy.random.RandomState`` in case of
  73. numpy version 1.16 or lower. It allows to use numpy 1.17's sampling
  74. functions in 1.16 too and supports a variety of useful functions on
  75. the wrapped sampler, e.g. gettings its state or copying it.
  76. Not supported sampling functions of numpy <=1.16:
  77. * :func:`numpy.random.RandomState.rand`
  78. * :func:`numpy.random.RandomState.randn`
  79. * :func:`numpy.random.RandomState.randint`
  80. * :func:`numpy.random.RandomState.random_integers`
  81. * :func:`numpy.random.RandomState.random_sample`
  82. * :func:`numpy.random.RandomState.ranf`
  83. * :func:`numpy.random.RandomState.sample`
  84. * :func:`numpy.random.RandomState.seed`
  85. * :func:`numpy.random.RandomState.get_state`
  86. * :func:`numpy.random.RandomState.set_state`
  87. In :func:`~imgaug.random.RNG.choice`, the `axis` argument is not yet
  88. supported.
  89. Parameters
  90. ----------
  91. generator : None or int or RNG or numpy.random.Generator or numpy.random.BitGenerator or numpy.random.SeedSequence or numpy.random.RandomState
  92. The numpy random number generator to use. In case of numpy
  93. version 1.17 or later, this shouldn't be a ``RandomState`` as that
  94. class is outdated.
  95. Behaviour for different datatypes:
  96. * If ``None``: The global RNG is wrapped by this RNG (they are then
  97. effectively identical, any sampling on this RNG will affect the
  98. global RNG).
  99. * If ``int``: In numpy 1.17+, the value is used as a seed for a
  100. ``Generator`` wrapped by this RNG. I.e. it will be provided as the
  101. entropy to a ``SeedSequence``, which will then be used for an
  102. ``SFC64`` bit generator and wrapped by a ``Generator``.
  103. In numpy <=1.16, the value is used as a seed for a ``RandomState``,
  104. which is then wrapped by this RNG.
  105. * If :class:`RNG`: That RNG's ``generator`` attribute will be used
  106. as the generator for this RNG, i.e. the same as
  107. ``RNG(other_rng.generator)``.
  108. * If :class:`numpy.random.Generator`: That generator will be wrapped.
  109. * If :class:`numpy.random.BitGenerator`: A numpy
  110. generator will be created (and wrapped by this RNG) that contains
  111. the bit generator.
  112. * If :class:`numpy.random.SeedSequence`: A numpy
  113. generator will be created (and wrapped by this RNG) that contains
  114. an ``SFC64`` bit generator initialized with the given
  115. ``SeedSequence``.
  116. * If :class:`numpy.random.RandomState`: In numpy <=1.16, this
  117. ``RandomState`` will be wrapped and used to sample random values.
  118. In numpy 1.17+, a seed will be derived from this ``RandomState``
  119. and a new ``numpy.generator.Generator`` based on an ``SFC64``
  120. bit generator will be created and wrapped by this RNG.
  121. """
  122. # TODO add maybe a __new__ here that feeds-through an RNG input without
  123. # wrapping it in RNG(rng_input)?
  124. def __init__(self, generator):
  125. if isinstance(generator, RNG):
  126. self.generator = generator.generator
  127. else:
  128. self.generator = normalize_generator_(generator)
  129. self._is_new_rng_style = (
  130. not isinstance(self.generator, np.random.RandomState))
  131. @property
  132. def state(self):
  133. """Get the state of this RNG.
  134. Returns
  135. -------
  136. tuple or dict
  137. The state of the RNG.
  138. In numpy 1.17+, the bit generator's state will be returned.
  139. In numpy <=1.16, the ``RandomState`` 's state is returned.
  140. In both cases the state is a copy. In-place changes will not affect
  141. the RNG.
  142. """
  143. return get_generator_state(self.generator)
  144. @state.setter
  145. def state(self, value):
  146. """Set the state if the RNG in-place.
  147. Parameters
  148. ----------
  149. value : tuple or dict
  150. The new state of the RNG.
  151. Should correspond to the output of the ``state`` property.
  152. """
  153. self.set_state_(value)
  154. def set_state_(self, value):
  155. """Set the state if the RNG in-place.
  156. Parameters
  157. ----------
  158. value : tuple or dict
  159. The new state of the RNG.
  160. Should correspond to the output of the ``state`` property.
  161. Returns
  162. -------
  163. RNG
  164. The RNG itself.
  165. """
  166. set_generator_state_(self.generator, value)
  167. return self
  168. def use_state_of_(self, other):
  169. """Copy and use (in-place) the state of another RNG.
  170. .. note::
  171. It is often sensible to first verify that neither this RNG nor
  172. `other` are identical to the global RNG.
  173. Parameters
  174. ----------
  175. other : RNG
  176. The other RNG, which's state will be copied.
  177. Returns
  178. -------
  179. RNG
  180. The RNG itself.
  181. """
  182. return self.set_state_(other.state)
  183. def is_global_rng(self):
  184. """Estimate whether this RNG is identical to the global RNG.
  185. Returns
  186. -------
  187. bool
  188. ``True`` is this RNG's underlying generator is identical to the
  189. global RNG's underlying generator. The RNGs themselves may
  190. be different, only the wrapped generator matters.
  191. ``False`` otherwise.
  192. """
  193. # We use .generator here, because otherwise RNG(global_rng) would be
  194. # viewed as not-identical to the global RNG, even though its generator
  195. # and bit generator are identical.
  196. return get_global_rng().generator is self.generator
  197. def equals_global_rng(self):
  198. """Estimate whether this RNG has the same state as the global RNG.
  199. Returns
  200. -------
  201. bool
  202. ``True`` is this RNG has the same state as the global RNG, i.e.
  203. it will lead to the same sampled values given the same sampling
  204. method calls. The RNGs *don't* have to be identical object
  205. instances, which protects against e.g. copy effects.
  206. ``False`` otherwise.
  207. """
  208. return get_global_rng().equals(self)
  209. def generate_seed_(self):
  210. """Sample a random seed.
  211. This advances the underlying generator's state.
  212. See ``SEED_MIN_VALUE`` and ``SEED_MAX_VALUE`` for the seed's value
  213. range.
  214. Returns
  215. -------
  216. int
  217. The sampled seed.
  218. """
  219. return generate_seed_(self.generator)
  220. def generate_seeds_(self, n):
  221. """Generate `n` random seed values.
  222. This advances the underlying generator's state.
  223. See ``SEED_MIN_VALUE`` and ``SEED_MAX_VALUE`` for the seed's value
  224. range.
  225. Parameters
  226. ----------
  227. n : int
  228. Number of seeds to sample.
  229. Returns
  230. -------
  231. ndarray
  232. 1D-array of ``int32`` seeds.
  233. """
  234. return generate_seeds_(self.generator, n)
  235. def reset_cache_(self):
  236. """Reset all cache of this RNG.
  237. Returns
  238. -------
  239. RNG
  240. The RNG itself.
  241. """
  242. reset_generator_cache_(self.generator)
  243. return self
  244. def derive_rng_(self):
  245. """Create a child RNG.
  246. This advances the underlying generator's state.
  247. Returns
  248. -------
  249. RNG
  250. A child RNG.
  251. """
  252. return self.derive_rngs_(1)[0]
  253. def derive_rngs_(self, n):
  254. """Create `n` child RNGs.
  255. This advances the underlying generator's state.
  256. Parameters
  257. ----------
  258. n : int
  259. Number of child RNGs to derive.
  260. Returns
  261. -------
  262. list of RNG
  263. Child RNGs.
  264. """
  265. return [RNG(gen) for gen in derive_generators_(self.generator, n)]
  266. def equals(self, other):
  267. """Estimate whether this RNG and `other` have the same state.
  268. Returns
  269. -------
  270. bool
  271. ``True`` if this RNG's generator and the generator of `other`
  272. have equal internal states. ``False`` otherwise.
  273. """
  274. assert isinstance(other, RNG), (
  275. "Expected 'other' to be an RNG, got type %s. "
  276. "Use imgaug.random.is_generator_equal_to() to compare "
  277. "numpy generators or RandomStates." % (type(other),))
  278. return is_generator_equal_to(self.generator, other.generator)
  279. def advance_(self):
  280. """Advance the RNG's internal state in-place by one step.
  281. This advances the underlying generator's state.
  282. .. note::
  283. This simply samples one or more random values. This means that
  284. a call of this method will not completely change the outputs of
  285. the next called sampling method. To achieve more drastic output
  286. changes, call :func:`~imgaug.random.RNG.derive_rng_`.
  287. Returns
  288. -------
  289. RNG
  290. The RNG itself.
  291. """
  292. advance_generator_(self.generator)
  293. return self
  294. def copy(self):
  295. """Create a copy of this RNG.
  296. Returns
  297. -------
  298. RNG
  299. Copy of this RNG. The copy will produce the same random samples.
  300. """
  301. return RNG(copy_generator(self.generator))
  302. def copy_unless_global_rng(self):
  303. """Create a copy of this RNG unless it is the global RNG.
  304. Returns
  305. -------
  306. RNG
  307. Copy of this RNG unless it is the global RNG. In the latter case
  308. the RNG instance itself will be returned without any changes.
  309. """
  310. if self.is_global_rng():
  311. return self
  312. return self.copy()
  313. def duplicate(self, n):
  314. """Create a list containing `n` times this RNG.
  315. This method was mainly introduced as a replacement for previous
  316. calls of :func:`~imgaug.random.RNG.derive_rngs_`. These calls
  317. turned out to be very slow in numpy 1.17+ and were hence replaced
  318. by simple duplication (except for the cases where child RNGs
  319. absolutely *had* to be created).
  320. This RNG duplication method doesn't help very much against code
  321. repetition, but it does *mark* the points where it would be desirable
  322. to create child RNGs for various reasons. Once deriving child RNGs
  323. is somehow sped up in the future, these calls can again be
  324. easily found and replaced.
  325. Parameters
  326. ----------
  327. n : int
  328. Length of the output list.
  329. Returns
  330. -------
  331. list of RNG
  332. List containing `n` times this RNG (same instances, no copies).
  333. """
  334. return [self for _ in sm.xrange(n)]
  335. @classmethod
  336. def create_fully_random(cls):
  337. """Create a new RNG, based on entropy provided from the OS.
  338. Returns
  339. -------
  340. RNG
  341. A new RNG. It is not derived from any other previously created
  342. RNG, nor does it depend on the seeding of imgaug or numpy.
  343. """
  344. return RNG(create_fully_random_generator())
  345. @classmethod
  346. def create_pseudo_random_(cls):
  347. """Create a new RNG in pseudo-random fashion.
  348. A seed will be sampled from the current global RNG and used to
  349. initialize the new RNG.
  350. This advandes the global RNG's state.
  351. Returns
  352. -------
  353. RNG
  354. A new RNG, derived from the current global RNG.
  355. """
  356. return get_global_rng().derive_rng_()
  357. ###########################################################################
  358. # Below:
  359. # Aliases for methods of numpy.random.Generator functions
  360. #
  361. # The methods below could also be handled with less code using some magic
  362. # methods. Explicitly writing things down here has the advantage that
  363. # the methods actually appear in the autogenerated API.
  364. ###########################################################################
  365. def integers(self, low, high=None, size=None, dtype="int32",
  366. endpoint=False):
  367. """Call numpy's ``integers()`` or ``randint()``.
  368. .. note::
  369. Changed `dtype` argument default value from numpy's ``int64`` to
  370. ``int32``.
  371. """
  372. return polyfill_integers(
  373. self.generator, low=low, high=high, size=size, dtype=dtype,
  374. endpoint=endpoint)
  375. def random(self, size, dtype="float32", out=None):
  376. """Call numpy's ``random()`` or ``random_sample()``.
  377. .. note::
  378. Changed `dtype` argument default value from numpy's ``d`` to
  379. ``float32``.
  380. """
  381. return polyfill_random(
  382. self.generator, size=size, dtype=dtype, out=out)
  383. # TODO add support for Generator's 'axis' argument
  384. def choice(self, a, size=None, replace=True, p=None):
  385. """Call :func:`numpy.random.Generator.choice`."""
  386. # pylint: disable=invalid-name
  387. return self.generator.choice(a=a, size=size, replace=replace, p=p)
  388. def bytes(self, length):
  389. """Call :func:`numpy.random.Generator.bytes`."""
  390. return self.generator.bytes(length=length)
  391. # TODO mark in-place
  392. def shuffle(self, x):
  393. """Call :func:`numpy.random.Generator.shuffle`."""
  394. # note that shuffle() does not allow keyword arguments
  395. # note that shuffle() works in-place
  396. self.generator.shuffle(x)
  397. def permutation(self, x):
  398. """Call :func:`numpy.random.Generator.permutation`."""
  399. # note that permutation() does not allow keyword arguments
  400. return self.generator.permutation(x)
  401. def beta(self, a, b, size=None):
  402. """Call :func:`numpy.random.Generator.beta`."""
  403. # pylint: disable=invalid-name
  404. return self.generator.beta(a=a, b=b, size=size)
  405. def binomial(self, n, p, size=None):
  406. """Call :func:`numpy.random.Generator.binomial`."""
  407. return self.generator.binomial(n=n, p=p, size=size)
  408. def chisquare(self, df, size=None):
  409. """Call :func:`numpy.random.Generator.chisquare`."""
  410. # pylint: disable=invalid-name
  411. return self.generator.chisquare(df=df, size=size)
  412. def dirichlet(self, alpha, size=None):
  413. """Call :func:`numpy.random.Generator.dirichlet`."""
  414. return self.generator.dirichlet(alpha=alpha, size=size)
  415. def exponential(self, scale=1.0, size=None):
  416. """Call :func:`numpy.random.Generator.exponential`."""
  417. return self.generator.exponential(scale=scale, size=size)
  418. def f(self, dfnum, dfden, size=None):
  419. """Call :func:`numpy.random.Generator.f`."""
  420. return self.generator.f(dfnum=dfnum, dfden=dfden, size=size)
  421. def gamma(self, shape, scale=1.0, size=None):
  422. """Call :func:`numpy.random.Generator.gamma`."""
  423. return self.generator.gamma(shape=shape, scale=scale, size=size)
  424. def geometric(self, p, size=None):
  425. """Call :func:`numpy.random.Generator.geometric`."""
  426. return self.generator.geometric(p=p, size=size)
  427. def gumbel(self, loc=0.0, scale=1.0, size=None):
  428. """Call :func:`numpy.random.Generator.gumbel`."""
  429. return self.generator.gumbel(loc=loc, scale=scale, size=size)
  430. def hypergeometric(self, ngood, nbad, nsample, size=None):
  431. """Call :func:`numpy.random.Generator.hypergeometric`."""
  432. return self.generator.hypergeometric(
  433. ngood=ngood, nbad=nbad, nsample=nsample, size=size)
  434. def laplace(self, loc=0.0, scale=1.0, size=None):
  435. """Call :func:`numpy.random.Generator.laplace`."""
  436. return self.generator.laplace(loc=loc, scale=scale, size=size)
  437. def logistic(self, loc=0.0, scale=1.0, size=None):
  438. """Call :func:`numpy.random.Generator.logistic`."""
  439. return self.generator.logistic(loc=loc, scale=scale, size=size)
  440. def lognormal(self, mean=0.0, sigma=1.0, size=None):
  441. """Call :func:`numpy.random.Generator.lognormal`."""
  442. return self.generator.lognormal(mean=mean, sigma=sigma, size=size)
  443. def logseries(self, p, size=None):
  444. """Call :func:`numpy.random.Generator.logseries`."""
  445. return self.generator.logseries(p=p, size=size)
  446. def multinomial(self, n, pvals, size=None):
  447. """Call :func:`numpy.random.Generator.multinomial`."""
  448. return self.generator.multinomial(n=n, pvals=pvals, size=size)
  449. def multivariate_normal(self, mean, cov, size=None, check_valid="warn",
  450. tol=1e-8):
  451. """Call :func:`numpy.random.Generator.multivariate_normal`."""
  452. return self.generator.multivariate_normal(
  453. mean=mean, cov=cov, size=size, check_valid=check_valid, tol=tol)
  454. def negative_binomial(self, n, p, size=None):
  455. """Call :func:`numpy.random.Generator.negative_binomial`."""
  456. return self.generator.negative_binomial(n=n, p=p, size=size)
  457. def noncentral_chisquare(self, df, nonc, size=None):
  458. """Call :func:`numpy.random.Generator.noncentral_chisquare`."""
  459. # pylint: disable=invalid-name
  460. return self.generator.noncentral_chisquare(df=df, nonc=nonc, size=size)
  461. def noncentral_f(self, dfnum, dfden, nonc, size=None):
  462. """Call :func:`numpy.random.Generator.noncentral_f`."""
  463. return self.generator.noncentral_f(
  464. dfnum=dfnum, dfden=dfden, nonc=nonc, size=size)
  465. def normal(self, loc=0.0, scale=1.0, size=None):
  466. """Call :func:`numpy.random.Generator.normal`."""
  467. return self.generator.normal(loc=loc, scale=scale, size=size)
  468. def pareto(self, a, size=None):
  469. """Call :func:`numpy.random.Generator.pareto`."""
  470. # pylint: disable=invalid-name
  471. return self.generator.pareto(a=a, size=size)
  472. def poisson(self, lam=1.0, size=None):
  473. """Call :func:`numpy.random.Generator.poisson`."""
  474. return self.generator.poisson(lam=lam, size=size)
  475. def power(self, a, size=None):
  476. """Call :func:`numpy.random.Generator.power`."""
  477. # pylint: disable=invalid-name
  478. return self.generator.power(a=a, size=size)
  479. def rayleigh(self, scale=1.0, size=None):
  480. """Call :func:`numpy.random.Generator.rayleigh`."""
  481. return self.generator.rayleigh(scale=scale, size=size)
  482. def standard_cauchy(self, size=None):
  483. """Call :func:`numpy.random.Generator.standard_cauchy`."""
  484. return self.generator.standard_cauchy(size=size)
  485. def standard_exponential(self, size=None, dtype="float32", method="zig",
  486. out=None):
  487. """Call :func:`numpy.random.Generator.standard_exponential`.
  488. .. note::
  489. Changed `dtype` argument default value from numpy's ``d`` to
  490. ``float32``.
  491. """
  492. if self._is_new_rng_style:
  493. return self.generator.standard_exponential(
  494. size=size, dtype=dtype, method=method, out=out)
  495. result = self.generator.standard_exponential(size=size).astype(dtype)
  496. if out is not None:
  497. assert out.dtype.name == result.dtype.name, (
  498. "Expected out array to have the same dtype as "
  499. "standard_exponential()'s result array. Got %s (out) and "
  500. "%s (result) instead." % (out.dtype.name, result.dtype.name))
  501. out[...] = result
  502. return result
  503. def standard_gamma(self, shape, size=None, dtype="float32", out=None):
  504. """Call :func:`numpy.random.Generator.standard_gamma`.
  505. .. note::
  506. Changed `dtype` argument default value from numpy's ``d`` to
  507. ``float32``.
  508. """
  509. if self._is_new_rng_style:
  510. return self.generator.standard_gamma(
  511. shape=shape, size=size, dtype=dtype, out=out)
  512. result = self.generator.standard_gamma(
  513. shape=shape, size=size).astype(dtype)
  514. if out is not None:
  515. assert out.dtype.name == result.dtype.name, (
  516. "Expected out array to have the same dtype as "
  517. "standard_gamma()'s result array. Got %s (out) and "
  518. "%s (result) instead." % (out.dtype.name, result.dtype.name))
  519. out[...] = result
  520. return result
  521. def standard_normal(self, size=None, dtype="float32", out=None):
  522. """Call :func:`numpy.random.Generator.standard_normal`.
  523. .. note::
  524. Changed `dtype` argument default value from numpy's ``d`` to
  525. ``float32``.
  526. """
  527. if self._is_new_rng_style:
  528. return self.generator.standard_normal(
  529. size=size, dtype=dtype, out=out)
  530. result = self.generator.standard_normal(size=size).astype(dtype)
  531. if out is not None:
  532. assert out.dtype.name == result.dtype.name, (
  533. "Expected out array to have the same dtype as "
  534. "standard_normal()'s result array. Got %s (out) and "
  535. "%s (result) instead." % (out.dtype.name, result.dtype.name))
  536. out[...] = result
  537. return result
  538. def standard_t(self, df, size=None):
  539. """Call :func:`numpy.random.Generator.standard_t`."""
  540. # pylint: disable=invalid-name
  541. return self.generator.standard_t(df=df, size=size)
  542. def triangular(self, left, mode, right, size=None):
  543. """Call :func:`numpy.random.Generator.triangular`."""
  544. return self.generator.triangular(
  545. left=left, mode=mode, right=right, size=size)
  546. def uniform(self, low=0.0, high=1.0, size=None):
  547. """Call :func:`numpy.random.Generator.uniform`."""
  548. return self.generator.uniform(low=low, high=high, size=size)
  549. def vonmises(self, mu, kappa, size=None):
  550. """Call :func:`numpy.random.Generator.vonmises`."""
  551. # pylint: disable=invalid-name
  552. return self.generator.vonmises(mu=mu, kappa=kappa, size=size)
  553. def wald(self, mean, scale, size=None):
  554. """Call :func:`numpy.random.Generator.wald`."""
  555. return self.generator.wald(mean=mean, scale=scale, size=size)
  556. def weibull(self, a, size=None):
  557. """Call :func:`numpy.random.Generator.weibull`."""
  558. # pylint: disable=invalid-name
  559. return self.generator.weibull(a=a, size=size)
  560. def zipf(self, a, size=None):
  561. """Call :func:`numpy.random.Generator.zipf`."""
  562. # pylint: disable=invalid-name
  563. return self.generator.zipf(a=a, size=size)
  564. ##################################################################
  565. # Outdated methods from RandomState
  566. # These are added here for backwards compatibility in case of old
  567. # custom augmenters and Lambda calls that rely on the RandomState
  568. # API.
  569. ##################################################################
  570. def rand(self, *args):
  571. """Call :func:`numpy.random.RandomState.rand`.
  572. .. warning::
  573. This method is outdated in numpy. Use :func:`RNG.random` instead.
  574. Added in 0.4.0.
  575. """
  576. return self.random(size=args)
  577. def randint(self, low, high=None, size=None, dtype="int32"):
  578. """Call :func:`numpy.random.RandomState.randint`.
  579. .. note::
  580. Changed `dtype` argument default value from numpy's ``I`` to
  581. ``int32``.
  582. .. warning::
  583. This method is outdated in numpy. Use :func:`RNG.integers`
  584. instead.
  585. Added in 0.4.0.
  586. """
  587. return self.integers(low=low, high=high, size=size, dtype=dtype,
  588. endpoint=False)
  589. def randn(self, *args):
  590. """Call :func:`numpy.random.RandomState.randn`.
  591. .. warning::
  592. This method is outdated in numpy. Use :func:`RNG.standard_normal`
  593. instead.
  594. Added in 0.4.0.
  595. """
  596. return self.standard_normal(size=args)
  597. def random_integers(self, low, high=None, size=None):
  598. """Call :func:`numpy.random.RandomState.random_integers`.
  599. .. warning::
  600. This method is outdated in numpy. Use :func:`RNG.integers`
  601. instead.
  602. Added in 0.4.0.
  603. """
  604. if high is None:
  605. return self.integers(low=1, high=low, size=size, endpoint=True)
  606. return self.integers(low=low, high=high, size=size, endpoint=True)
  607. def random_sample(self, size):
  608. """Call :func:`numpy.random.RandomState.random_sample`.
  609. .. warning::
  610. This method is outdated in numpy. Use :func:`RNG.uniform`
  611. instead.
  612. Added in 0.4.0.
  613. """
  614. return self.uniform(0.0, 1.0, size=size)
  615. def tomaxint(self, size=None):
  616. """Call :func:`numpy.random.RandomState.tomaxint`.
  617. .. warning::
  618. This method is outdated in numpy. Use :func:`RNG.integers`
  619. instead.
  620. Added in 0.4.0.
  621. """
  622. import sys
  623. maxint = sys.maxsize
  624. int32max = np.iinfo(np.int32).max
  625. return self.integers(0, min(maxint, int32max), size=size,
  626. endpoint=True)
  627. def supports_new_numpy_rng_style():
  628. """
  629. Determine whether numpy supports the new ``random`` interface (v1.17+).
  630. Returns
  631. -------
  632. bool
  633. ``True`` if the new ``random`` interface is supported by numpy, i.e.
  634. if numpy has version 1.17 or later. Otherwise ``False``, i.e.
  635. numpy has version 1.16 or older and ``numpy.random.RandomState``
  636. should be used instead.
  637. """
  638. return SUPPORTS_NEW_NP_RNG_STYLE
  639. def get_global_rng():
  640. """
  641. Get or create the current global RNG of imgaug.
  642. Note that the first call to this function will create a global RNG.
  643. Returns
  644. -------
  645. RNG
  646. The global RNG to use.
  647. """
  648. # TODO change global_rng to singleton
  649. # pylint: disable=global-statement, redefined-outer-name
  650. global GLOBAL_RNG
  651. if GLOBAL_RNG is None:
  652. # This uses numpy's random state to sample a seed.
  653. # Alternatively, `secrets.randbits(n_bits)` (3.6+) and
  654. # `os.urandom(n_bytes)` could be used.
  655. # See https://stackoverflow.com/a/27286733/3760780
  656. # for an explanation how random.seed() picks a random seed value.
  657. seed = generate_seed_(np.random)
  658. GLOBAL_RNG = RNG(convert_seed_to_generator(seed))
  659. return GLOBAL_RNG
# This is an in-place operation, but it does not use a trailing underscore to
# indicate that, in order to match the interface of `random` and `numpy.random`.
  662. def seed(entropy):
  663. """Set the seed of imgaug's global RNG (in-place).
  664. The global RNG controls most of the "randomness" in imgaug.
  665. The global RNG is the default one used by all augmenters. Under special
  666. circumstances (e.g. when an augmenter is switched to deterministic mode),
  667. the global RNG is replaced with a local one. The state of that replacement
  668. may be dependent on the global RNG's state at the time of creating the
  669. child RNG.
  670. Parameters
  671. ----------
  672. entropy : int
  673. The seed value to use.
  674. """
  675. if SUPPORTS_NEW_NP_RNG_STYLE:
  676. _seed_np117_(entropy)
  677. else:
  678. _seed_np116_(entropy)
  679. def _seed_np117_(entropy):
  680. # We can't easily seed a BitGenerator in-place, nor can we easily modify
  681. # a Generator's bit_generator in-place. So instead we create a new
  682. # bit generator and set the current global RNG's internal bit generator
  683. # state to a copy of the new bit generator's state.
  684. get_global_rng().state = BIT_GENERATOR(entropy).state
  685. def _seed_np116_(entropy):
  686. get_global_rng().generator.seed(entropy)
  687. def normalize_generator(generator):
  688. """Normalize various inputs to a numpy (random number) generator.
  689. This function will first copy the provided argument, i.e. it never returns
  690. a provided instance itself.
  691. Parameters
  692. ----------
  693. generator : None or int or numpy.random.Generator or numpy.random.BitGenerator or numpy.random.SeedSequence or numpy.random.RandomState
  694. The numpy random number generator to normalize. In case of numpy
  695. version 1.17 or later, this shouldn't be a ``RandomState`` as that
  696. class is outdated.
  697. Behaviour for different datatypes:
  698. * If ``None``: The global RNG's generator is returned.
  699. * If ``int``: In numpy 1.17+, the value is used as a seed for a
  700. ``Generator``, i.e. it will be provided as the entropy to a
  701. ``SeedSequence``, which will then be used for an ``SFC64`` bit
  702. generator and wrapped by a ``Generator``, which is then returned.
  703. In numpy <=1.16, the value is used as a seed for a ``RandomState``,
  704. which will then be returned.
  705. * If :class:`numpy.random.Generator`: That generator will be
  706. returned.
  707. * If :class:`numpy.random.BitGenerator`: A numpy
  708. generator will be created and returned that contains the bit
  709. generator.
  710. * If :class:`numpy.random.SeedSequence`: A numpy
  711. generator will be created and returned that contains an ``SFC64``
  712. bit generator initialized with the given ``SeedSequence``.
  713. * If :class:`numpy.random.RandomState`: In numpy <=1.16, this
  714. ``RandomState`` will be returned. In numpy 1.17+, a seed will be
  715. derived from this ``RandomState`` and a new
  716. ``numpy.generator.Generator`` based on an ``SFC64`` bit generator
  717. will be created and returned.
  718. Returns
  719. -------
  720. numpy.random.Generator or numpy.random.RandomState
  721. In numpy <=1.16 a ``RandomState``, in 1.17+ a ``Generator`` (even if
  722. the input was a ``RandomState``).
  723. """
  724. return normalize_generator_(copylib.deepcopy(generator))
  725. def normalize_generator_(generator):
  726. """Normalize in-place various inputs to a numpy (random number) generator.
  727. This function will try to return the provided instance itself.
  728. Parameters
  729. ----------
  730. generator : None or int or numpy.random.Generator or numpy.random.BitGenerator or numpy.random.SeedSequence or numpy.random.RandomState
  731. See :func:`~imgaug.random.normalize_generator`.
  732. Returns
  733. -------
  734. numpy.random.Generator or numpy.random.RandomState
  735. In numpy <=1.16 a ``RandomState``, in 1.17+ a ``Generator`` (even if
  736. the input was a ``RandomState``).
  737. """
  738. if not SUPPORTS_NEW_NP_RNG_STYLE:
  739. return _normalize_generator_np116_(generator)
  740. return _normalize_generator_np117_(generator)
  741. def _normalize_generator_np117_(generator):
  742. if generator is None:
  743. return get_global_rng().generator
  744. if isinstance(generator, np.random.SeedSequence):
  745. return np.random.Generator(
  746. BIT_GENERATOR(generator)
  747. )
  748. if isinstance(generator, _BIT_GENERATOR_INTERFACE):
  749. generator = np.random.Generator(generator)
  750. # TODO is it necessary/sensible here to reset the cache?
  751. reset_generator_cache_(generator)
  752. return generator
  753. if isinstance(generator, np.random.Generator):
  754. # TODO is it necessary/sensible here to reset the cache?
  755. reset_generator_cache_(generator)
  756. return generator
  757. if isinstance(generator, np.random.RandomState):
  758. # TODO warn
  759. # TODO reset the cache here too?
  760. return convert_seed_to_generator(generate_seed_(generator))
  761. # seed given
  762. seed_ = generator
  763. return convert_seed_to_generator(seed_)
  764. def _normalize_generator_np116_(random_state):
  765. if random_state is None:
  766. return get_global_rng().generator
  767. if isinstance(random_state, np.random.RandomState):
  768. # TODO reset the cache here, like in np117?
  769. return random_state
  770. # seed given
  771. seed_ = random_state
  772. return convert_seed_to_generator(seed_)
  773. def convert_seed_to_generator(entropy):
  774. """Convert a seed value to a numpy (random number) generator.
  775. Parameters
  776. ----------
  777. entropy : int
  778. The seed value to use.
  779. Returns
  780. -------
  781. numpy.random.Generator or numpy.random.RandomState
  782. In numpy <=1.16 a ``RandomState``, in 1.17+ a ``Generator``.
  783. Both are initialized with the provided seed.
  784. """
  785. if not SUPPORTS_NEW_NP_RNG_STYLE:
  786. return _convert_seed_to_generator_np116(entropy)
  787. return _convert_seed_to_generator_np117(entropy)
  788. def _convert_seed_to_generator_np117(entropy):
  789. seed_sequence = np.random.SeedSequence(entropy)
  790. return convert_seed_sequence_to_generator(seed_sequence)
  791. def _convert_seed_to_generator_np116(entropy):
  792. return np.random.RandomState(entropy)
  793. def convert_seed_sequence_to_generator(seed_sequence):
  794. """Convert a seed sequence to a numpy (random number) generator.
  795. Parameters
  796. ----------
  797. seed_sequence : numpy.random.SeedSequence
  798. The seed value to use.
  799. Returns
  800. -------
  801. numpy.random.Generator
  802. Generator initialized with the provided seed sequence.
  803. """
  804. return np.random.Generator(BIT_GENERATOR(seed_sequence))
  805. def create_pseudo_random_generator_():
  806. """Create a new numpy (random) generator, derived from the global RNG.
  807. This function advances the global RNG's state.
  808. Returns
  809. -------
  810. numpy.random.Generator or numpy.random.RandomState
  811. In numpy <=1.16 a ``RandomState``, in 1.17+ a ``Generator``.
  812. Both are initialized with a seed sampled from the global RNG.
  813. """
  814. # could also use derive_rng(get_global_rng()) here
  815. random_seed = generate_seed_(get_global_rng().generator)
  816. return convert_seed_to_generator(random_seed)
  817. def create_fully_random_generator():
  818. """Create a new numpy (random) generator, derived from OS's entropy.
  819. Returns
  820. -------
  821. numpy.random.Generator or numpy.random.RandomState
  822. In numpy <=1.16 a ``RandomState``, in 1.17+ a ``Generator``.
  823. Both are initialized with entropy requested from the OS. They are
  824. hence independent of entered seeds or the library's global RNG.
  825. """
  826. if not SUPPORTS_NEW_NP_RNG_STYLE:
  827. return _create_fully_random_generator_np116()
  828. return _create_fully_random_generator_np117()
  829. def _create_fully_random_generator_np117():
  830. # TODO need entropy here?
  831. return np.random.Generator(np.random.SFC64())
  832. def _create_fully_random_generator_np116():
  833. return np.random.RandomState()
  834. def generate_seed_(generator):
  835. """Sample a seed from the provided generator.
  836. This function advances the generator's state.
  837. See ``SEED_MIN_VALUE`` and ``SEED_MAX_VALUE`` for the seed's value
  838. range.
  839. Parameters
  840. ----------
  841. generator : numpy.random.Generator or numpy.random.RandomState
  842. The generator from which to sample the seed.
  843. Returns
  844. -------
  845. int
  846. The sampled seed.
  847. """
  848. return generate_seeds_(generator, 1)[0]
  849. def generate_seeds_(generator, n):
  850. """Sample `n` seeds from the provided generator.
  851. This function advances the generator's state.
  852. Parameters
  853. ----------
  854. generator : numpy.random.Generator or numpy.random.RandomState
  855. The generator from which to sample the seed.
  856. n : int
  857. Number of seeds to sample.
  858. Returns
  859. -------
  860. ndarray
  861. 1D-array of ``int32`` seeds.
  862. """
  863. return polyfill_integers(generator, SEED_MIN_VALUE, SEED_MAX_VALUE,
  864. size=(n,))
  865. def copy_generator(generator):
  866. """Copy an existing numpy (random number) generator.
  867. Parameters
  868. ----------
  869. generator : numpy.random.Generator or numpy.random.RandomState
  870. The generator to copy.
  871. Returns
  872. -------
  873. numpy.random.Generator or numpy.random.RandomState
  874. In numpy <=1.16 a ``RandomState``, in 1.17+ a ``Generator``.
  875. Both are copies of the input argument.
  876. """
  877. if isinstance(generator, np.random.RandomState):
  878. return _copy_generator_np116(generator)
  879. return _copy_generator_np117(generator)
  880. def _copy_generator_np117(generator):
  881. # TODO not sure if it is enough to only copy the state
  882. # TODO initializing a bit gen and then copying the state might be slower
  883. # then just deepcopying the whole thing
  884. old_bit_gen = generator.bit_generator
  885. new_bit_gen = old_bit_gen.__class__(1)
  886. new_bit_gen.state = copylib.deepcopy(old_bit_gen.state)
  887. return np.random.Generator(new_bit_gen)
  888. def _copy_generator_np116(random_state):
  889. rs_copy = np.random.RandomState(1)
  890. state = random_state.get_state()
  891. rs_copy.set_state(state)
  892. return rs_copy
  893. def copy_generator_unless_global_generator(generator):
  894. """Copy a numpy generator unless it is the current global generator.
  895. "global generator" here denotes the generator contained in the
  896. global RNG's ``.generator`` attribute.
  897. Parameters
  898. ----------
  899. generator : numpy.random.Generator or numpy.random.RandomState
  900. The generator to copy.
  901. Returns
  902. -------
  903. numpy.random.Generator or numpy.random.RandomState
  904. In numpy <=1.16 a ``RandomState``, in 1.17+ a ``Generator``.
  905. Both are copies of the input argument, unless that input is
  906. identical to the global generator. If it is identical, the
  907. instance itself will be returned without copying it.
  908. """
  909. if generator is get_global_rng().generator:
  910. return generator
  911. return copy_generator(generator)
  912. def reset_generator_cache_(generator):
  913. """Reset a numpy (random number) generator's internal cache.
  914. This function modifies the generator's state in-place.
  915. Parameters
  916. ----------
  917. generator : numpy.random.Generator or numpy.random.RandomState
  918. The generator of which to reset the cache.
  919. Returns
  920. -------
  921. numpy.random.Generator or numpy.random.RandomState
  922. In numpy <=1.16 a ``RandomState``, in 1.17+ a ``Generator``.
  923. In both cases the input argument itself.
  924. """
  925. if isinstance(generator, np.random.RandomState):
  926. return _reset_generator_cache_np116_(generator)
  927. return _reset_generator_cache_np117_(generator)
  928. def _reset_generator_cache_np117_(generator):
  929. # This deactivates usage of the cache. We could also remove the cached
  930. # value itself in "uinteger", but setting the RNG to ignore the cached
  931. # value should be enough.
  932. state = _get_generator_state_np117(generator)
  933. state["has_uint32"] = 0
  934. _set_generator_state_np117_(generator, state)
  935. return generator
  936. def _reset_generator_cache_np116_(random_state):
  937. # State tuple content:
  938. # 'MT19937', array of ints, unknown int, cache flag, cached value
  939. # The cache flag only affects the standard_normal() method.
  940. state = list(random_state.get_state())
  941. state[-2] = 0
  942. random_state.set_state(tuple(state))
  943. return random_state
  944. def derive_generator_(generator):
  945. """Create a child numpy (random number) generator from an existing one.
  946. This advances the generator's state.
  947. Parameters
  948. ----------
  949. generator : numpy.random.Generator or numpy.random.RandomState
  950. The generator from which to derive a new child generator.
  951. Returns
  952. -------
  953. numpy.random.Generator or numpy.random.RandomState
  954. In numpy <=1.16 a ``RandomState``, in 1.17+ a ``Generator``.
  955. In both cases a derived child generator.
  956. """
  957. return derive_generators_(generator, n=1)[0]
  958. # TODO does this advance the RNG in 1.17? It should advance it for security
  959. # reasons
  960. def derive_generators_(generator, n):
  961. """Create child numpy (random number) generators from an existing one.
  962. Parameters
  963. ----------
  964. generator : numpy.random.Generator or numpy.random.RandomState
  965. The generator from which to derive new child generators.
  966. n : int
  967. Number of child generators to derive.
  968. Returns
  969. -------
  970. list of numpy.random.Generator or list of numpy.random.RandomState
  971. In numpy <=1.16 a list of ``RandomState`` s,
  972. in 1.17+ a list of ``Generator`` s.
  973. In both cases lists of derived child generators.
  974. """
  975. if isinstance(generator, np.random.RandomState):
  976. return _derive_generators_np116_(generator, n=n)
  977. return _derive_generators_np117_(generator, n=n)
  978. def _derive_generators_np117_(generator, n):
  979. # TODO possible to get the SeedSequence from 'rng'?
  980. """
  981. advance_rng_(rng)
  982. rng = copylib.deepcopy(rng)
  983. reset_rng_cache_(rng)
  984. state = rng.bit_generator.state
  985. rngs = []
  986. for i in sm.xrange(n):
  987. state["state"]["state"] += (i * 100003 + 17)
  988. rng.bit_generator.state = state
  989. rngs.append(rng)
  990. rng = copylib.deepcopy(rng)
  991. return rngs
  992. """
  993. # We generate here two integers instead of one, because the internal state
  994. # of the RNG might have one 32bit integer still cached up, which would
  995. # then be returned first when calling integers(). This should usually be
  996. # fine, but there is some risk involved that this will lead to sampling
  997. # many times the same seed in loop constructions (if the internal state
  998. # is not properly advanced and the cache is then also not reset). Adding
  999. # 'size=(2,)' decreases that risk. (It is then enough to e.g. call once
  1000. # random() to advance the internal state. No resetting of caches is
  1001. # needed.)
  1002. seed_ = generator.integers(SEED_MIN_VALUE, SEED_MAX_VALUE, dtype="int32",
  1003. size=(2,))[-1]
  1004. seed_seq = np.random.SeedSequence(seed_)
  1005. seed_seqs = seed_seq.spawn(n)
  1006. return [convert_seed_sequence_to_generator(seed_seq)
  1007. for seed_seq in seed_seqs]
  1008. def _derive_generators_np116_(random_state, n):
  1009. seed_ = random_state.randint(SEED_MIN_VALUE, SEED_MAX_VALUE)
  1010. return [_convert_seed_to_generator_np116(seed_ + i) for i in sm.xrange(n)]
  1011. def get_generator_state(generator):
  1012. """Get the state of this provided generator.
  1013. Parameters
  1014. ----------
  1015. generator : numpy.random.Generator or numpy.random.RandomState
  1016. The generator, which's state is supposed to be extracted.
  1017. Returns
  1018. -------
  1019. tuple or dict
  1020. The state of the generator.
  1021. In numpy 1.17+, the bit generator's state will be returned.
  1022. In numpy <=1.16, the ``RandomState`` 's state is returned.
  1023. In both cases the state is a copy. In-place changes will not affect
  1024. the RNG.
  1025. """
  1026. if isinstance(generator, np.random.RandomState):
  1027. return _get_generator_state_np116(generator)
  1028. return _get_generator_state_np117(generator)
  1029. def _get_generator_state_np117(generator):
  1030. return generator.bit_generator.state
  1031. def _get_generator_state_np116(random_state):
  1032. return random_state.get_state()
  1033. def set_generator_state_(generator, state):
  1034. """Set the state of a numpy (random number) generator in-place.
  1035. Parameters
  1036. ----------
  1037. generator : numpy.random.Generator or numpy.random.RandomState
  1038. The generator, which's state is supposed to be modified.
  1039. state : tuple or dict
  1040. The new state of the generator.
  1041. Should correspond to the output of
  1042. :func:`~imgaug.random.get_generator_state`.
  1043. """
  1044. if isinstance(generator, np.random.RandomState):
  1045. _set_generator_state_np116_(generator, state)
  1046. else:
  1047. _set_generator_state_np117_(generator, state)
  1048. def _set_generator_state_np117_(generator, state):
  1049. generator.bit_generator.state = state
  1050. def _set_generator_state_np116_(random_state, state):
  1051. random_state.set_state(state)
  1052. def is_generator_equal_to(generator, other_generator):
  1053. """Estimate whether two generator have the same class and state.
  1054. Parameters
  1055. ----------
  1056. generator : numpy.random.Generator or numpy.random.RandomState
  1057. First generator used in the comparison.
  1058. other_generator : numpy.random.Generator or numpy.random.RandomState
  1059. Second generator used in the comparison.
  1060. Returns
  1061. -------
  1062. bool
  1063. ``True`` if `generator` 's class and state are the same as the
  1064. class and state of `other_generator`. ``False`` otherwise.
  1065. """
  1066. if isinstance(generator, np.random.RandomState):
  1067. return _is_generator_equal_to_np116(generator, other_generator)
  1068. return _is_generator_equal_to_np117(generator, other_generator)
  1069. def _is_generator_equal_to_np117(generator, other_generator):
  1070. assert generator.__class__ is other_generator.__class__, (
  1071. "Expected both rngs to have the same class, "
  1072. "got types '%s' and '%s'." % (type(generator), type(other_generator)))
  1073. state1 = get_generator_state(generator)
  1074. state2 = get_generator_state(other_generator)
  1075. assert state1["bit_generator"] == "SFC64", (
  1076. "Can currently only compare the states of numpy.random.SFC64 bit "
  1077. "generators, got %s." % (state1["bit_generator"],))
  1078. assert state2["bit_generator"] == "SFC64", (
  1079. "Can currently only compare the states of numpy.random.SFC64 bit "
  1080. "generators, got %s." % (state2["bit_generator"],))
  1081. if state1["has_uint32"] != state2["has_uint32"]:
  1082. return False
  1083. if state1["has_uint32"] == state2["has_uint32"] == 1:
  1084. if state1["uinteger"] != state2["uinteger"]:
  1085. return False
  1086. return np.array_equal(state1["state"]["state"], state2["state"]["state"])
  1087. def _is_generator_equal_to_np116(random_state, other_random_state):
  1088. state1 = _get_generator_state_np116(random_state)
  1089. state2 = _get_generator_state_np116(other_random_state)
  1090. # Note that state1 and state2 are tuples with the value at index 1 being
  1091. # a numpy array and the values at 2-4 being ints/floats, so we can't just
  1092. # apply array_equal to state1[1:4+1] and state2[1:4+1]. We need a loop
  1093. # here.
  1094. for i in sm.xrange(1, 4+1):
  1095. if not np.array_equal(state1[i], state2[i]):
  1096. return False
  1097. return True
  1098. def advance_generator_(generator):
  1099. """Advance a numpy random generator's internal state in-place by one step.
  1100. This advances the generator's state.
  1101. .. note::
  1102. This simply samples one or more random values. This means that
  1103. a call of this method will not completely change the outputs of
  1104. the next called sampling method. To achieve more drastic output
  1105. changes, call :func:`~imgaug.random.derive_generator_`.
  1106. Parameters
  1107. ----------
  1108. generator : numpy.random.Generator or numpy.random.RandomState
  1109. Generator of which to advance the internal state.
  1110. """
  1111. if isinstance(generator, np.random.RandomState):
  1112. _advance_generator_np116_(generator)
  1113. else:
  1114. _advance_generator_np117_(generator)
  1115. def _advance_generator_np117_(generator):
  1116. _reset_generator_cache_np117_(generator)
  1117. generator.random()
  1118. def _advance_generator_np116_(generator):
  1119. _reset_generator_cache_np116_(generator)
  1120. generator.uniform()
  1121. def polyfill_integers(generator, low, high=None, size=None, dtype="int32",
  1122. endpoint=False):
  1123. """Sample integers from a generator in different numpy versions.
  1124. Parameters
  1125. ----------
  1126. generator : numpy.random.Generator or numpy.random.RandomState
  1127. The generator to sample from. If it is a ``RandomState``,
  1128. :func:`numpy.random.RandomState.randint` will be called,
  1129. otherwise :func:`numpy.random.Generator.integers`.
  1130. low : int or array-like of ints
  1131. See :func:`numpy.random.Generator.integers`.
  1132. high : int or array-like of ints, optional
  1133. See :func:`numpy.random.Generator.integers`.
  1134. size : int or tuple of ints, optional
  1135. See :func:`numpy.random.Generator.integers`.
  1136. dtype : {str, dtype}, optional
  1137. See :func:`numpy.random.Generator.integers`.
  1138. endpoint : bool, optional
  1139. See :func:`numpy.random.Generator.integers`.
  1140. Returns
  1141. -------
  1142. int or ndarray of ints
  1143. See :func:`numpy.random.Generator.integers`.
  1144. """
  1145. if hasattr(generator, "randint"):
  1146. if endpoint:
  1147. if high is None:
  1148. high = low + 1
  1149. low = 0
  1150. else:
  1151. high = high + 1
  1152. return generator.randint(low=low, high=high, size=size, dtype=dtype)
  1153. return generator.integers(low=low, high=high, size=size, dtype=dtype,
  1154. endpoint=endpoint)
  1155. def polyfill_random(generator, size, dtype="float32", out=None):
  1156. """Sample random floats from a generator in different numpy versions.
  1157. Parameters
  1158. ----------
  1159. generator : numpy.random.Generator or numpy.random.RandomState
  1160. The generator to sample from. Both ``RandomState`` and ``Generator``
  1161. support ``random()``, but with different interfaces.
  1162. size : int or tuple of ints, optional
  1163. See :func:`numpy.random.Generator.random`.
  1164. dtype : {str, dtype}, optional
  1165. See :func:`numpy.random.Generator.random`.
  1166. out : ndarray, optional
  1167. See :func:`numpy.random.Generator.random`.
  1168. Returns
  1169. -------
  1170. float or ndarray of floats
  1171. See :func:`numpy.random.Generator.random`.
  1172. """
  1173. if hasattr(generator, "random_sample"):
  1174. # note that numpy.random in <=1.16 supports random(), but
  1175. # numpy.random.RandomState does not
  1176. result = generator.random_sample(size=size).astype(dtype)
  1177. if out is not None:
  1178. assert out.dtype.name == result.dtype.name, (
  1179. "Expected out array to have the same dtype as "
  1180. "random_sample()'s result array. Got %s (out) and %s (result) "
  1181. "instead." % (out.dtype.name, result.dtype.name))
  1182. out[...] = result
  1183. return result
  1184. return generator.random(size=size, dtype=dtype, out=out)
  1185. # TODO add tests
  1186. class temporary_numpy_seed(object):
  1187. """Context to temporarily alter the random state of ``numpy.random``.
  1188. The random state's internal state will be set back to the original one
  1189. once the context finishes.
  1190. Added in 0.4.0.
  1191. Parameters
  1192. ----------
  1193. entropy : None or int
  1194. The seed value to use.
  1195. If `None` then the seed will not be altered and the internal state
  1196. of ``numpy.random`` will not be reset back upon context exit (i.e.
  1197. this context will do nothing).
  1198. """
  1199. # pylint complains about class name
  1200. # pylint: disable=invalid-name
  1201. def __init__(self, entropy=None):
  1202. self.old_state = None
  1203. self.entropy = entropy
  1204. def __enter__(self):
  1205. if self.entropy is not None:
  1206. self.old_state = np.random.get_state()
  1207. np.random.seed(self.entropy)
  1208. def __exit__(self, exc_type, exc_val, exc_tb):
  1209. if self.entropy is not None:
  1210. np.random.set_state(self.old_state)