cauchy.py 3.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899
  1. # mypy: allow-untyped-defs
  2. import math
  3. from typing import Optional, Union
  4. import torch
  5. from torch import inf, nan, Tensor
  6. from torch.distributions import constraints
  7. from torch.distributions.distribution import Distribution
  8. from torch.distributions.utils import broadcast_all
  9. from torch.types import _Number, _size
  10. __all__ = ["Cauchy"]
  11. class Cauchy(Distribution):
  12. r"""
  13. Samples from a Cauchy (Lorentz) distribution. The distribution of the ratio of
  14. independent normally distributed random variables with means `0` follows a
  15. Cauchy distribution.
  16. Example::
  17. >>> # xdoctest: +IGNORE_WANT("non-deterministic")
  18. >>> m = Cauchy(torch.tensor([0.0]), torch.tensor([1.0]))
  19. >>> m.sample() # sample from a Cauchy distribution with loc=0 and scale=1
  20. tensor([ 2.3214])
  21. Args:
  22. loc (float or Tensor): mode or median of the distribution.
  23. scale (float or Tensor): half width at half maximum.
  24. """
  25. arg_constraints = {"loc": constraints.real, "scale": constraints.positive}
  26. support = constraints.real
  27. has_rsample = True
  28. def __init__(
  29. self,
  30. loc: Union[Tensor, float],
  31. scale: Union[Tensor, float],
  32. validate_args: Optional[bool] = None,
  33. ) -> None:
  34. self.loc, self.scale = broadcast_all(loc, scale)
  35. if isinstance(loc, _Number) and isinstance(scale, _Number):
  36. batch_shape = torch.Size()
  37. else:
  38. batch_shape = self.loc.size()
  39. super().__init__(batch_shape, validate_args=validate_args)
  40. def expand(self, batch_shape, _instance=None):
  41. new = self._get_checked_instance(Cauchy, _instance)
  42. batch_shape = torch.Size(batch_shape)
  43. new.loc = self.loc.expand(batch_shape)
  44. new.scale = self.scale.expand(batch_shape)
  45. super(Cauchy, new).__init__(batch_shape, validate_args=False)
  46. new._validate_args = self._validate_args
  47. return new
  48. @property
  49. def mean(self) -> Tensor:
  50. return torch.full(
  51. self._extended_shape(), nan, dtype=self.loc.dtype, device=self.loc.device
  52. )
  53. @property
  54. def mode(self) -> Tensor:
  55. return self.loc
  56. @property
  57. def variance(self) -> Tensor:
  58. return torch.full(
  59. self._extended_shape(), inf, dtype=self.loc.dtype, device=self.loc.device
  60. )
  61. def rsample(self, sample_shape: _size = torch.Size()) -> Tensor:
  62. shape = self._extended_shape(sample_shape)
  63. eps = self.loc.new(shape).cauchy_()
  64. return self.loc + eps * self.scale
  65. def log_prob(self, value):
  66. if self._validate_args:
  67. self._validate_sample(value)
  68. return (
  69. -math.log(math.pi)
  70. - self.scale.log()
  71. - (((value - self.loc) / self.scale) ** 2).log1p()
  72. )
  73. def cdf(self, value):
  74. if self._validate_args:
  75. self._validate_sample(value)
  76. return torch.atan((value - self.loc) / self.scale) / math.pi + 0.5
  77. def icdf(self, value):
  78. return torch.tan(math.pi * (value - 0.5)) * self.scale + self.loc
  79. def entropy(self):
  80. return math.log(4 * math.pi) + self.scale.log()