half_cauchy.py 2.5 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091
  1. # mypy: allow-untyped-defs
  2. import math
  3. from typing import Optional, Union
  4. import torch
  5. from torch import inf, Tensor
  6. from torch.distributions import constraints
  7. from torch.distributions.cauchy import Cauchy
  8. from torch.distributions.transformed_distribution import TransformedDistribution
  9. from torch.distributions.transforms import AbsTransform
  10. __all__ = ["HalfCauchy"]
  11. class HalfCauchy(TransformedDistribution):
  12. r"""
  13. Creates a half-Cauchy distribution parameterized by `scale` where::
  14. X ~ Cauchy(0, scale)
  15. Y = |X| ~ HalfCauchy(scale)
  16. Example::
  17. >>> # xdoctest: +IGNORE_WANT("non-deterministic")
  18. >>> m = HalfCauchy(torch.tensor([1.0]))
  19. >>> m.sample() # half-cauchy distributed with scale=1
  20. tensor([ 2.3214])
  21. Args:
  22. scale (float or Tensor): scale of the full Cauchy distribution
  23. """
  24. arg_constraints = {"scale": constraints.positive}
  25. support = constraints.nonnegative
  26. has_rsample = True
  27. base_dist: Cauchy
  28. def __init__(
  29. self,
  30. scale: Union[Tensor, float],
  31. validate_args: Optional[bool] = None,
  32. ) -> None:
  33. base_dist = Cauchy(0, scale, validate_args=False)
  34. super().__init__(base_dist, AbsTransform(), validate_args=validate_args)
  35. def expand(self, batch_shape, _instance=None):
  36. new = self._get_checked_instance(HalfCauchy, _instance)
  37. return super().expand(batch_shape, _instance=new)
  38. @property
  39. def scale(self) -> Tensor:
  40. return self.base_dist.scale
  41. @property
  42. def mean(self) -> Tensor:
  43. return torch.full(
  44. self._extended_shape(),
  45. math.inf,
  46. dtype=self.scale.dtype,
  47. device=self.scale.device,
  48. )
  49. @property
  50. def mode(self) -> Tensor:
  51. return torch.zeros_like(self.scale)
  52. @property
  53. def variance(self) -> Tensor:
  54. return self.base_dist.variance
  55. def log_prob(self, value):
  56. if self._validate_args:
  57. self._validate_sample(value)
  58. value = torch.as_tensor(
  59. value, dtype=self.base_dist.scale.dtype, device=self.base_dist.scale.device
  60. )
  61. log_prob = self.base_dist.log_prob(value) + math.log(2)
  62. log_prob = torch.where(value >= 0, log_prob, -inf)
  63. return log_prob
  64. def cdf(self, value):
  65. if self._validate_args:
  66. self._validate_sample(value)
  67. return 2 * self.base_dist.cdf(value) - 1
  68. def icdf(self, prob):
  69. return self.base_dist.icdf((prob + 1) / 2)
  70. def entropy(self):
  71. return self.base_dist.entropy() - math.log(2)