cauchy.py 3.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100
  1. # mypy: allow-untyped-defs
  2. import math
  3. from typing import Optional, Union
  4. import torch
  5. from torch import inf, nan, Tensor
  6. from torch.distributions import constraints
  7. from torch.distributions.distribution import Distribution
  8. from torch.distributions.utils import broadcast_all
  9. from torch.types import _Number, _size
  10. __all__ = ["Cauchy"]
  11. class Cauchy(Distribution):
  12. r"""
  13. Samples from a Cauchy (Lorentz) distribution. The distribution of the ratio of
  14. independent normally distributed random variables with means `0` follows a
  15. Cauchy distribution.
  16. Example::
  17. >>> # xdoctest: +IGNORE_WANT("non-deterministic")
  18. >>> m = Cauchy(torch.tensor([0.0]), torch.tensor([1.0]))
  19. >>> m.sample() # sample from a Cauchy distribution with loc=0 and scale=1
  20. tensor([ 2.3214])
  21. Args:
  22. loc (float or Tensor): mode or median of the distribution.
  23. scale (float or Tensor): half width at half maximum.
  24. """
  25. # pyrefly: ignore [bad-override]
  26. arg_constraints = {"loc": constraints.real, "scale": constraints.positive}
  27. support = constraints.real
  28. has_rsample = True
  29. def __init__(
  30. self,
  31. loc: Union[Tensor, float],
  32. scale: Union[Tensor, float],
  33. validate_args: Optional[bool] = None,
  34. ) -> None:
  35. self.loc, self.scale = broadcast_all(loc, scale)
  36. if isinstance(loc, _Number) and isinstance(scale, _Number):
  37. batch_shape = torch.Size()
  38. else:
  39. batch_shape = self.loc.size()
  40. super().__init__(batch_shape, validate_args=validate_args)
  41. def expand(self, batch_shape, _instance=None):
  42. new = self._get_checked_instance(Cauchy, _instance)
  43. batch_shape = torch.Size(batch_shape)
  44. new.loc = self.loc.expand(batch_shape)
  45. new.scale = self.scale.expand(batch_shape)
  46. super(Cauchy, new).__init__(batch_shape, validate_args=False)
  47. new._validate_args = self._validate_args
  48. return new
  49. @property
  50. def mean(self) -> Tensor:
  51. return torch.full(
  52. self._extended_shape(), nan, dtype=self.loc.dtype, device=self.loc.device
  53. )
  54. @property
  55. def mode(self) -> Tensor:
  56. return self.loc
  57. @property
  58. def variance(self) -> Tensor:
  59. return torch.full(
  60. self._extended_shape(), inf, dtype=self.loc.dtype, device=self.loc.device
  61. )
  62. def rsample(self, sample_shape: _size = torch.Size()) -> Tensor:
  63. shape = self._extended_shape(sample_shape)
  64. eps = self.loc.new(shape).cauchy_()
  65. return self.loc + eps * self.scale
  66. def log_prob(self, value):
  67. if self._validate_args:
  68. self._validate_sample(value)
  69. return (
  70. -math.log(math.pi)
  71. - self.scale.log()
  72. - (((value - self.loc) / self.scale) ** 2).log1p()
  73. )
  74. def cdf(self, value):
  75. if self._validate_args:
  76. self._validate_sample(value)
  77. return torch.atan((value - self.loc) / self.scale) / math.pi + 0.5
  78. def icdf(self, value):
  79. return torch.tan(math.pi * (value - 0.5)) * self.scale + self.loc
  80. def entropy(self):
  81. return math.log(4 * math.pi) + self.scale.log()