# quasirandom.py
# mypy: allow-untyped-defs
from typing import Optional

import torch
  4. class SobolEngine:
  5. r"""
  6. The :class:`torch.quasirandom.SobolEngine` is an engine for generating
  7. (scrambled) Sobol sequences. Sobol sequences are an example of low
  8. discrepancy quasi-random sequences.
  9. This implementation of an engine for Sobol sequences is capable of
  10. sampling sequences up to a maximum dimension of 21201. It uses direction
  11. numbers from https://web.maths.unsw.edu.au/~fkuo/sobol/ obtained using the
  12. search criterion D(6) up to the dimension 21201. This is the recommended
  13. choice by the authors.
  14. References:
  15. - Art B. Owen. Scrambling Sobol and Niederreiter-Xing points.
  16. Journal of Complexity, 14(4):466-489, December 1998.
  17. - I. M. Sobol. The distribution of points in a cube and the accurate
  18. evaluation of integrals.
  19. Zh. Vychisl. Mat. i Mat. Phys., 7:784-802, 1967.
  20. Args:
  21. dimension (Int): The dimensionality of the sequence to be drawn
  22. scramble (bool, optional): Setting this to ``True`` will produce
  23. scrambled Sobol sequences. Scrambling is
  24. capable of producing better Sobol
  25. sequences. Default: ``False``.
  26. seed (Int, optional): This is the seed for the scrambling. The seed
  27. of the random number generator is set to this,
  28. if specified. Otherwise, it uses a random seed.
  29. Default: ``None``
  30. Examples::
  31. >>> # xdoctest: +SKIP("unseeded random state")
  32. >>> soboleng = torch.quasirandom.SobolEngine(dimension=5)
  33. >>> soboleng.draw(3)
  34. tensor([[0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
  35. [0.5000, 0.5000, 0.5000, 0.5000, 0.5000],
  36. [0.7500, 0.2500, 0.2500, 0.2500, 0.7500]])
  37. """
  38. MAXBIT = 30
  39. MAXDIM = 21201
  40. def __init__(self, dimension, scramble=False, seed=None):
  41. if dimension > self.MAXDIM or dimension < 1:
  42. raise ValueError(
  43. "Supported range of dimensionality "
  44. f"for SobolEngine is [1, {self.MAXDIM}]"
  45. )
  46. self.seed = seed
  47. self.scramble = scramble
  48. self.dimension = dimension
  49. cpu = torch.device("cpu")
  50. self.sobolstate = torch.zeros(
  51. dimension, self.MAXBIT, device=cpu, dtype=torch.long
  52. )
  53. torch._sobol_engine_initialize_state_(self.sobolstate, self.dimension)
  54. if not self.scramble:
  55. self.shift = torch.zeros(self.dimension, device=cpu, dtype=torch.long)
  56. else:
  57. self._scramble()
  58. self.quasi = self.shift.clone(memory_format=torch.contiguous_format)
  59. self._first_point = (self.quasi / 2**self.MAXBIT).reshape(1, -1)
  60. self.num_generated = 0
  61. def draw(
  62. self,
  63. n: int = 1,
  64. out: Optional[torch.Tensor] = None,
  65. dtype: Optional[torch.dtype] = None,
  66. ) -> torch.Tensor:
  67. r"""
  68. Function to draw a sequence of :attr:`n` points from a Sobol sequence.
  69. Note that the samples are dependent on the previous samples. The size
  70. of the result is :math:`(n, dimension)`.
  71. Args:
  72. n (Int, optional): The length of sequence of points to draw.
  73. Default: 1
  74. out (Tensor, optional): The output tensor
  75. dtype (:class:`torch.dtype`, optional): the desired data type of the
  76. returned tensor.
  77. Default: ``None``
  78. """
  79. if dtype is None:
  80. dtype = torch.get_default_dtype()
  81. if self.num_generated == 0:
  82. if n == 1:
  83. result = self._first_point.to(dtype)
  84. else:
  85. result, self.quasi = torch._sobol_engine_draw(
  86. self.quasi,
  87. n - 1,
  88. self.sobolstate,
  89. self.dimension,
  90. self.num_generated,
  91. dtype=dtype,
  92. )
  93. result = torch.cat((self._first_point.to(dtype), result), dim=-2)
  94. else:
  95. result, self.quasi = torch._sobol_engine_draw(
  96. self.quasi,
  97. n,
  98. self.sobolstate,
  99. self.dimension,
  100. self.num_generated - 1,
  101. dtype=dtype,
  102. )
  103. self.num_generated += n
  104. if out is not None:
  105. out.resize_as_(result).copy_(result)
  106. return out
  107. return result
  108. def draw_base2(
  109. self,
  110. m: int,
  111. out: Optional[torch.Tensor] = None,
  112. dtype: Optional[torch.dtype] = None,
  113. ) -> torch.Tensor:
  114. r"""
  115. Function to draw a sequence of :attr:`2**m` points from a Sobol sequence.
  116. Note that the samples are dependent on the previous samples. The size
  117. of the result is :math:`(2**m, dimension)`.
  118. Args:
  119. m (Int): The (base2) exponent of the number of points to draw.
  120. out (Tensor, optional): The output tensor
  121. dtype (:class:`torch.dtype`, optional): the desired data type of the
  122. returned tensor.
  123. Default: ``None``
  124. """
  125. n = 2**m
  126. total_n = self.num_generated + n
  127. if not (total_n & (total_n - 1) == 0):
  128. raise ValueError(
  129. "The balance properties of Sobol' points require "
  130. f"n to be a power of 2. {self.num_generated} points have been "
  131. f"previously generated, then: n={self.num_generated}+2**{m}={total_n}. "
  132. "If you still want to do this, please use "
  133. "'SobolEngine.draw()' instead."
  134. )
  135. return self.draw(n=n, out=out, dtype=dtype)
  136. def reset(self):
  137. r"""
  138. Function to reset the ``SobolEngine`` to base state.
  139. """
  140. self.quasi.copy_(self.shift)
  141. self.num_generated = 0
  142. return self
  143. def fast_forward(self, n):
  144. r"""
  145. Function to fast-forward the state of the ``SobolEngine`` by
  146. :attr:`n` steps. This is equivalent to drawing :attr:`n` samples
  147. without using the samples.
  148. Args:
  149. n (Int): The number of steps to fast-forward by.
  150. """
  151. if self.num_generated == 0:
  152. torch._sobol_engine_ff_(
  153. self.quasi, n - 1, self.sobolstate, self.dimension, self.num_generated
  154. )
  155. else:
  156. torch._sobol_engine_ff_(
  157. self.quasi, n, self.sobolstate, self.dimension, self.num_generated - 1
  158. )
  159. self.num_generated += n
  160. return self
  161. def _scramble(self):
  162. g: Optional[torch.Generator] = None
  163. if self.seed is not None:
  164. g = torch.Generator()
  165. g.manual_seed(self.seed)
  166. cpu = torch.device("cpu")
  167. # Generate shift vector
  168. shift_ints = torch.randint(
  169. 2, (self.dimension, self.MAXBIT), device=cpu, generator=g
  170. )
  171. self.shift = torch.mv(
  172. shift_ints, torch.pow(2, torch.arange(0, self.MAXBIT, device=cpu))
  173. )
  174. # Generate lower triangular matrices (stacked across dimensions)
  175. ltm_dims = (self.dimension, self.MAXBIT, self.MAXBIT)
  176. ltm = torch.randint(2, ltm_dims, device=cpu, generator=g).tril()
  177. torch._sobol_engine_scramble_(self.sobolstate, ltm, self.dimension)
  178. def __repr__(self):
  179. fmt_string = [f"dimension={self.dimension}"]
  180. if self.scramble:
  181. fmt_string += ["scramble=True"]
  182. if self.seed is not None:
  183. fmt_string += [f"seed={self.seed}"]
  184. return self.__class__.__name__ + "(" + ", ".join(fmt_string) + ")"