geometric.py 4.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140
  1. # mypy: allow-untyped-defs
  2. from typing import Optional, Union
  3. import torch
  4. from torch import Tensor
  5. from torch.distributions import constraints
  6. from torch.distributions.distribution import Distribution
  7. from torch.distributions.utils import (
  8. broadcast_all,
  9. lazy_property,
  10. logits_to_probs,
  11. probs_to_logits,
  12. )
  13. from torch.nn.functional import binary_cross_entropy_with_logits
  14. from torch.types import _Number, Number
  15. __all__ = ["Geometric"]
  16. class Geometric(Distribution):
  17. r"""
  18. Creates a Geometric distribution parameterized by :attr:`probs`,
  19. where :attr:`probs` is the probability of success of Bernoulli trials.
  20. .. math::
  21. P(X=k) = (1-p)^{k} p, k = 0, 1, ...
  22. .. note::
  23. :func:`torch.distributions.geometric.Geometric` :math:`(k+1)`-th trial is the first success
  24. hence draws samples in :math:`\{0, 1, \ldots\}`, whereas
  25. :func:`torch.Tensor.geometric_` `k`-th trial is the first success hence draws samples in :math:`\{1, 2, \ldots\}`.
  26. Example::
  27. >>> # xdoctest: +IGNORE_WANT("non-deterministic")
  28. >>> m = Geometric(torch.tensor([0.3]))
  29. >>> m.sample() # underlying Bernoulli has 30% chance 1; 70% chance 0
  30. tensor([ 2.])
  31. Args:
  32. probs (Number, Tensor): the probability of sampling `1`. Must be in range (0, 1]
  33. logits (Number, Tensor): the log-odds of sampling `1`.
  34. """
  35. arg_constraints = {"probs": constraints.unit_interval, "logits": constraints.real}
  36. support = constraints.nonnegative_integer
  37. def __init__(
  38. self,
  39. probs: Optional[Union[Tensor, Number]] = None,
  40. logits: Optional[Union[Tensor, Number]] = None,
  41. validate_args: Optional[bool] = None,
  42. ) -> None:
  43. if (probs is None) == (logits is None):
  44. raise ValueError(
  45. "Either `probs` or `logits` must be specified, but not both."
  46. )
  47. if probs is not None:
  48. (self.probs,) = broadcast_all(probs)
  49. else:
  50. assert logits is not None # helps mypy
  51. (self.logits,) = broadcast_all(logits)
  52. probs_or_logits = probs if probs is not None else logits
  53. if isinstance(probs_or_logits, _Number):
  54. batch_shape = torch.Size()
  55. else:
  56. assert probs_or_logits is not None # helps mypy
  57. batch_shape = probs_or_logits.size()
  58. super().__init__(batch_shape, validate_args=validate_args)
  59. if self._validate_args and probs is not None:
  60. # Add an extra check beyond unit_interval
  61. value = self.probs
  62. valid = value > 0
  63. if not valid.all():
  64. invalid_value = value.data[~valid]
  65. raise ValueError(
  66. "Expected parameter probs "
  67. f"({type(value).__name__} of shape {tuple(value.shape)}) "
  68. f"of distribution {repr(self)} "
  69. f"to be positive but found invalid values:\n{invalid_value}"
  70. )
  71. def expand(self, batch_shape, _instance=None):
  72. new = self._get_checked_instance(Geometric, _instance)
  73. batch_shape = torch.Size(batch_shape)
  74. if "probs" in self.__dict__:
  75. new.probs = self.probs.expand(batch_shape)
  76. if "logits" in self.__dict__:
  77. new.logits = self.logits.expand(batch_shape)
  78. super(Geometric, new).__init__(batch_shape, validate_args=False)
  79. new._validate_args = self._validate_args
  80. return new
  81. @property
  82. def mean(self) -> Tensor:
  83. return 1.0 / self.probs - 1.0
  84. @property
  85. def mode(self) -> Tensor:
  86. return torch.zeros_like(self.probs)
  87. @property
  88. def variance(self) -> Tensor:
  89. return (1.0 / self.probs - 1.0) / self.probs
  90. @lazy_property
  91. def logits(self) -> Tensor:
  92. return probs_to_logits(self.probs, is_binary=True)
  93. @lazy_property
  94. def probs(self) -> Tensor:
  95. return logits_to_probs(self.logits, is_binary=True)
  96. def sample(self, sample_shape=torch.Size()):
  97. shape = self._extended_shape(sample_shape)
  98. tiny = torch.finfo(self.probs.dtype).tiny
  99. with torch.no_grad():
  100. if torch._C._get_tracing_state():
  101. # [JIT WORKAROUND] lack of support for .uniform_()
  102. u = torch.rand(shape, dtype=self.probs.dtype, device=self.probs.device)
  103. u = u.clamp(min=tiny)
  104. else:
  105. u = self.probs.new(shape).uniform_(tiny, 1)
  106. return (u.log() / (-self.probs).log1p()).floor()
  107. def log_prob(self, value):
  108. if self._validate_args:
  109. self._validate_sample(value)
  110. value, probs = broadcast_all(value, self.probs)
  111. probs = probs.clone(memory_format=torch.contiguous_format)
  112. probs[(probs == 1) & (value == 0)] = 0
  113. return value * (-probs).log1p() + self.probs.log()
  114. def entropy(self):
  115. return (
  116. binary_cross_entropy_with_logits(self.logits, self.probs, reduction="none")
  117. / self.probs
  118. )