von_mises.py 6.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218
  1. # mypy: allow-untyped-defs
  2. import math
  3. from typing import Optional
  4. import torch
  5. import torch.jit
  6. from torch import Tensor
  7. from torch.distributions import constraints
  8. from torch.distributions.distribution import Distribution
  9. from torch.distributions.utils import broadcast_all, lazy_property
# Public API of this module: only the distribution class; the underscore-prefixed
# Bessel-function and sampling helpers below are private.
__all__ = ["VonMises"]
  11. def _eval_poly(y, coef):
  12. coef = list(coef)
  13. result = coef.pop()
  14. while coef:
  15. result = coef.pop() + y * result
  16. return result
# Polynomial coefficients (constant term first, as consumed by ``_eval_poly``)
# used by ``_log_modified_bessel_fn`` to approximate the modified Bessel
# functions I_0 and I_1.
# NOTE(review): these appear to be the classic Abramowitz & Stegun 9.8.1-9.8.4
# approximations -- confirm against that reference before editing any value.

# I_0, small-argument branch (x < 3.75): polynomial in t = (x / 3.75) ** 2.
_I0_COEF_SMALL = [
    1.0,
    3.5156229,
    3.0899424,
    1.2067492,
    0.2659732,
    0.360768e-1,
    0.45813e-2,
]
# I_0, large-argument branch: polynomial in 3.75 / x; its log is combined with
# ``x - 0.5 * log(x)`` by ``_log_modified_bessel_fn``.
_I0_COEF_LARGE = [
    0.39894228,
    0.1328592e-1,
    0.225319e-2,
    -0.157565e-2,
    0.916281e-2,
    -0.2057706e-1,
    0.2635537e-1,
    -0.1647633e-1,
    0.392377e-2,
]
# I_1, small-argument branch: polynomial in t = (x / 3.75) ** 2; the caller
# multiplies the result by |x| before taking the log.
_I1_COEF_SMALL = [
    0.5,
    0.87890594,
    0.51498869,
    0.15084934,
    0.2658733e-1,
    0.301532e-2,
    0.32411e-3,
]
# I_1, large-argument branch: polynomial in 3.75 / x (same log-space
# combination as the I_0 large branch).
_I1_COEF_LARGE = [
    0.39894228,
    -0.3988024e-1,
    -0.362018e-2,
    0.163801e-2,
    -0.1031555e-1,
    0.2282967e-1,
    -0.2895312e-1,
    0.1787654e-1,
    -0.420059e-2,
]
# Lookup tables indexed by Bessel order (0 or 1).
_COEF_SMALL = [_I0_COEF_SMALL, _I1_COEF_SMALL]
_COEF_LARGE = [_I0_COEF_LARGE, _I1_COEF_LARGE]
  59. def _log_modified_bessel_fn(x, order=0):
  60. """
  61. Returns ``log(I_order(x))`` for ``x > 0``,
  62. where `order` is either 0 or 1.
  63. """
  64. assert order == 0 or order == 1
  65. # compute small solution
  66. y = x / 3.75
  67. y = y * y
  68. small = _eval_poly(y, _COEF_SMALL[order])
  69. if order == 1:
  70. small = x.abs() * small
  71. small = small.log()
  72. # compute large solution
  73. y = 3.75 / x
  74. large = x - 0.5 * x.log() + _eval_poly(y, _COEF_LARGE[order]).log()
  75. result = torch.where(x < 3.75, small, large)
  76. return result
@torch.jit.script_if_tracing
def _rejection_sample(loc, concentration, proposal_r, x):
    """
    Draw von Mises samples by rejection, following the Best & Fisher (1979)
    scheme cited in ``VonMises.sample``.

    Args:
        loc: location angle(s); added to the centred draws at the end.
        concentration: concentration ``kappa``.
        proposal_r: precomputed proposal parameter ``r``
            (see ``VonMises._proposal_r``).
        x: buffer whose shape is the requested sample shape. Every element is
            overwritten by an accepted draw before the loop exits, so its
            initial contents do not affect the result.

    Returns:
        The accepted draws shifted by ``loc`` and wrapped into ``[-pi, pi)``.

    The ``while`` loop exits only once every element has been accepted at least
    once; there is no iteration cap. ``VonMises.sample`` calls this in double
    precision because single precision can hang for small concentrations.
    """
    done = torch.zeros(x.shape, dtype=torch.bool, device=loc.device)
    while not done.all():
        # Three uniforms per element per round: u1 -> proposal angle,
        # u2 -> acceptance test, u3 -> random sign of the accepted angle.
        u = torch.rand((3,) + x.shape, dtype=loc.dtype, device=loc.device)
        u1, u2, u3 = u.unbind()
        z = torch.cos(math.pi * u1)
        f = (1 + proposal_r * z) / (proposal_r + z)
        c = concentration * (proposal_r - f)
        # Quick acceptance test (no log) OR'ed with the exact log-based test.
        accept = ((c * (2 - c) - u2) > 0) | ((c / u2).log() + 1 - c >= 0)
        if accept.any():
            # Elements accepted in an earlier round are simply replaced by the
            # newer accepted draw; only never-accepted elements keep looping.
            x = torch.where(accept, (u3 - 0.5).sign() * f.acos(), x)
            done = done | accept
    # Shift by loc and wrap the angle into [-pi, pi).
    return (x + math.pi + loc) % (2 * math.pi) - math.pi
  91. class VonMises(Distribution):
  92. """
  93. A circular von Mises distribution.
  94. This implementation uses polar coordinates. The ``loc`` and ``value`` args
  95. can be any real number (to facilitate unconstrained optimization), but are
  96. interpreted as angles modulo 2 pi.
  97. Example::
  98. >>> # xdoctest: +IGNORE_WANT("non-deterministic")
  99. >>> m = VonMises(torch.tensor([1.0]), torch.tensor([1.0]))
  100. >>> m.sample() # von Mises distributed with loc=1 and concentration=1
  101. tensor([1.9777])
  102. :param torch.Tensor loc: an angle in radians.
  103. :param torch.Tensor concentration: concentration parameter
  104. """
  105. arg_constraints = {"loc": constraints.real, "concentration": constraints.positive}
  106. support = constraints.real
  107. has_rsample = False
  108. def __init__(
  109. self,
  110. loc: Tensor,
  111. concentration: Tensor,
  112. validate_args: Optional[bool] = None,
  113. ) -> None:
  114. self.loc, self.concentration = broadcast_all(loc, concentration)
  115. batch_shape = self.loc.shape
  116. event_shape = torch.Size()
  117. super().__init__(batch_shape, event_shape, validate_args)
  118. def log_prob(self, value):
  119. if self._validate_args:
  120. self._validate_sample(value)
  121. log_prob = self.concentration * torch.cos(value - self.loc)
  122. log_prob = (
  123. log_prob
  124. - math.log(2 * math.pi)
  125. - _log_modified_bessel_fn(self.concentration, order=0)
  126. )
  127. return log_prob
  128. @lazy_property
  129. def _loc(self) -> Tensor:
  130. return self.loc.to(torch.double)
  131. @lazy_property
  132. def _concentration(self) -> Tensor:
  133. return self.concentration.to(torch.double)
  134. @lazy_property
  135. def _proposal_r(self) -> Tensor:
  136. kappa = self._concentration
  137. tau = 1 + (1 + 4 * kappa**2).sqrt()
  138. rho = (tau - (2 * tau).sqrt()) / (2 * kappa)
  139. _proposal_r = (1 + rho**2) / (2 * rho)
  140. # second order Taylor expansion around 0 for small kappa
  141. _proposal_r_taylor = 1 / kappa + kappa
  142. return torch.where(kappa < 1e-5, _proposal_r_taylor, _proposal_r)
  143. @torch.no_grad()
  144. def sample(self, sample_shape=torch.Size()):
  145. """
  146. The sampling algorithm for the von Mises distribution is based on the
  147. following paper: D.J. Best and N.I. Fisher, "Efficient simulation of the
  148. von Mises distribution." Applied Statistics (1979): 152-157.
  149. Sampling is always done in double precision internally to avoid a hang
  150. in _rejection_sample() for small values of the concentration, which
  151. starts to happen for single precision around 1e-4 (see issue #88443).
  152. """
  153. shape = self._extended_shape(sample_shape)
  154. x = torch.empty(shape, dtype=self._loc.dtype, device=self.loc.device)
  155. return _rejection_sample(
  156. self._loc, self._concentration, self._proposal_r, x
  157. ).to(self.loc.dtype)
  158. def expand(self, batch_shape, _instance=None):
  159. try:
  160. return super().expand(batch_shape)
  161. except NotImplementedError:
  162. validate_args = self.__dict__.get("_validate_args")
  163. loc = self.loc.expand(batch_shape)
  164. concentration = self.concentration.expand(batch_shape)
  165. return type(self)(loc, concentration, validate_args=validate_args)
  166. @property
  167. def mean(self) -> Tensor:
  168. """
  169. The provided mean is the circular one.
  170. """
  171. return self.loc
  172. @property
  173. def mode(self) -> Tensor:
  174. return self.loc
  175. @lazy_property
  176. def variance(self) -> Tensor: # type: ignore[override]
  177. """
  178. The provided variance is the circular one.
  179. """
  180. return (
  181. 1
  182. - (
  183. _log_modified_bessel_fn(self.concentration, order=1)
  184. - _log_modified_bessel_fn(self.concentration, order=0)
  185. ).exp()
  186. )