continuous_bernoulli.py 8.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245
  1. # mypy: allow-untyped-defs
  2. import math
  3. from typing import Optional, Union
  4. import torch
  5. from torch import Tensor
  6. from torch.distributions import constraints
  7. from torch.distributions.exp_family import ExponentialFamily
  8. from torch.distributions.utils import (
  9. broadcast_all,
  10. clamp_probs,
  11. lazy_property,
  12. logits_to_probs,
  13. probs_to_logits,
  14. )
  15. from torch.nn.functional import binary_cross_entropy_with_logits
  16. from torch.types import _Number, _size, Number
  17. __all__ = ["ContinuousBernoulli"]
  18. class ContinuousBernoulli(ExponentialFamily):
  19. r"""
  20. Creates a continuous Bernoulli distribution parameterized by :attr:`probs`
  21. or :attr:`logits` (but not both).
  22. The distribution is supported in [0, 1] and parameterized by 'probs' (in
  23. (0,1)) or 'logits' (real-valued). Note that, unlike the Bernoulli, 'probs'
  24. does not correspond to a probability and 'logits' does not correspond to
  25. log-odds, but the same names are used due to the similarity with the
  26. Bernoulli. See [1] for more details.
  27. Example::
  28. >>> # xdoctest: +IGNORE_WANT("non-deterministic")
  29. >>> m = ContinuousBernoulli(torch.tensor([0.3]))
  30. >>> m.sample()
  31. tensor([ 0.2538])
  32. Args:
  33. probs (Number, Tensor): (0,1) valued parameters
  34. logits (Number, Tensor): real valued parameters whose sigmoid matches 'probs'
  35. [1] The continuous Bernoulli: fixing a pervasive error in variational
  36. autoencoders, Loaiza-Ganem G and Cunningham JP, NeurIPS 2019.
  37. https://arxiv.org/abs/1907.06845
  38. """
  39. arg_constraints = {"probs": constraints.unit_interval, "logits": constraints.real}
  40. support = constraints.unit_interval
  41. _mean_carrier_measure = 0
  42. has_rsample = True
  43. def __init__(
  44. self,
  45. probs: Optional[Union[Tensor, Number]] = None,
  46. logits: Optional[Union[Tensor, Number]] = None,
  47. lims: tuple[float, float] = (0.499, 0.501),
  48. validate_args: Optional[bool] = None,
  49. ) -> None:
  50. if (probs is None) == (logits is None):
  51. raise ValueError(
  52. "Either `probs` or `logits` must be specified, but not both."
  53. )
  54. if probs is not None:
  55. is_scalar = isinstance(probs, _Number)
  56. (self.probs,) = broadcast_all(probs)
  57. # validate 'probs' here if necessary as it is later clamped for numerical stability
  58. # close to 0 and 1, later on; otherwise the clamped 'probs' would always pass
  59. if validate_args is not None:
  60. if not self.arg_constraints["probs"].check(self.probs).all():
  61. raise ValueError("The parameter probs has invalid values")
  62. self.probs = clamp_probs(self.probs)
  63. else:
  64. assert logits is not None # helps mypy
  65. is_scalar = isinstance(logits, _Number)
  66. (self.logits,) = broadcast_all(logits)
  67. self._param = self.probs if probs is not None else self.logits
  68. if is_scalar:
  69. batch_shape = torch.Size()
  70. else:
  71. batch_shape = self._param.size()
  72. self._lims = lims
  73. super().__init__(batch_shape, validate_args=validate_args)
  74. def expand(self, batch_shape, _instance=None):
  75. new = self._get_checked_instance(ContinuousBernoulli, _instance)
  76. new._lims = self._lims
  77. batch_shape = torch.Size(batch_shape)
  78. if "probs" in self.__dict__:
  79. new.probs = self.probs.expand(batch_shape)
  80. new._param = new.probs
  81. if "logits" in self.__dict__:
  82. new.logits = self.logits.expand(batch_shape)
  83. new._param = new.logits
  84. super(ContinuousBernoulli, new).__init__(batch_shape, validate_args=False)
  85. new._validate_args = self._validate_args
  86. return new
  87. def _new(self, *args, **kwargs):
  88. return self._param.new(*args, **kwargs)
  89. def _outside_unstable_region(self):
  90. return torch.max(
  91. torch.le(self.probs, self._lims[0]), torch.gt(self.probs, self._lims[1])
  92. )
  93. def _cut_probs(self):
  94. return torch.where(
  95. self._outside_unstable_region(),
  96. self.probs,
  97. self._lims[0] * torch.ones_like(self.probs),
  98. )
  99. def _cont_bern_log_norm(self):
  100. """computes the log normalizing constant as a function of the 'probs' parameter"""
  101. cut_probs = self._cut_probs()
  102. cut_probs_below_half = torch.where(
  103. torch.le(cut_probs, 0.5), cut_probs, torch.zeros_like(cut_probs)
  104. )
  105. cut_probs_above_half = torch.where(
  106. torch.ge(cut_probs, 0.5), cut_probs, torch.ones_like(cut_probs)
  107. )
  108. log_norm = torch.log(
  109. torch.abs(torch.log1p(-cut_probs) - torch.log(cut_probs))
  110. ) - torch.where(
  111. torch.le(cut_probs, 0.5),
  112. torch.log1p(-2.0 * cut_probs_below_half),
  113. torch.log(2.0 * cut_probs_above_half - 1.0),
  114. )
  115. x = torch.pow(self.probs - 0.5, 2)
  116. taylor = math.log(2.0) + (4.0 / 3.0 + 104.0 / 45.0 * x) * x
  117. return torch.where(self._outside_unstable_region(), log_norm, taylor)
  118. @property
  119. def mean(self) -> Tensor:
  120. cut_probs = self._cut_probs()
  121. mus = cut_probs / (2.0 * cut_probs - 1.0) + 1.0 / (
  122. torch.log1p(-cut_probs) - torch.log(cut_probs)
  123. )
  124. x = self.probs - 0.5
  125. taylor = 0.5 + (1.0 / 3.0 + 16.0 / 45.0 * torch.pow(x, 2)) * x
  126. return torch.where(self._outside_unstable_region(), mus, taylor)
  127. @property
  128. def stddev(self) -> Tensor:
  129. return torch.sqrt(self.variance)
  130. @property
  131. def variance(self) -> Tensor:
  132. cut_probs = self._cut_probs()
  133. vars = cut_probs * (cut_probs - 1.0) / torch.pow(
  134. 1.0 - 2.0 * cut_probs, 2
  135. ) + 1.0 / torch.pow(torch.log1p(-cut_probs) - torch.log(cut_probs), 2)
  136. x = torch.pow(self.probs - 0.5, 2)
  137. taylor = 1.0 / 12.0 - (1.0 / 15.0 - 128.0 / 945.0 * x) * x
  138. return torch.where(self._outside_unstable_region(), vars, taylor)
  139. @lazy_property
  140. def logits(self) -> Tensor:
  141. return probs_to_logits(self.probs, is_binary=True)
  142. @lazy_property
  143. def probs(self) -> Tensor:
  144. return clamp_probs(logits_to_probs(self.logits, is_binary=True))
  145. @property
  146. def param_shape(self) -> torch.Size:
  147. return self._param.size()
  148. def sample(self, sample_shape=torch.Size()):
  149. shape = self._extended_shape(sample_shape)
  150. u = torch.rand(shape, dtype=self.probs.dtype, device=self.probs.device)
  151. with torch.no_grad():
  152. return self.icdf(u)
  153. def rsample(self, sample_shape: _size = torch.Size()) -> Tensor:
  154. shape = self._extended_shape(sample_shape)
  155. u = torch.rand(shape, dtype=self.probs.dtype, device=self.probs.device)
  156. return self.icdf(u)
  157. def log_prob(self, value):
  158. if self._validate_args:
  159. self._validate_sample(value)
  160. logits, value = broadcast_all(self.logits, value)
  161. return (
  162. -binary_cross_entropy_with_logits(logits, value, reduction="none")
  163. + self._cont_bern_log_norm()
  164. )
  165. def cdf(self, value):
  166. if self._validate_args:
  167. self._validate_sample(value)
  168. cut_probs = self._cut_probs()
  169. cdfs = (
  170. torch.pow(cut_probs, value) * torch.pow(1.0 - cut_probs, 1.0 - value)
  171. + cut_probs
  172. - 1.0
  173. ) / (2.0 * cut_probs - 1.0)
  174. unbounded_cdfs = torch.where(self._outside_unstable_region(), cdfs, value)
  175. return torch.where(
  176. torch.le(value, 0.0),
  177. torch.zeros_like(value),
  178. torch.where(torch.ge(value, 1.0), torch.ones_like(value), unbounded_cdfs),
  179. )
  180. def icdf(self, value):
  181. cut_probs = self._cut_probs()
  182. return torch.where(
  183. self._outside_unstable_region(),
  184. (
  185. torch.log1p(-cut_probs + value * (2.0 * cut_probs - 1.0))
  186. - torch.log1p(-cut_probs)
  187. )
  188. / (torch.log(cut_probs) - torch.log1p(-cut_probs)),
  189. value,
  190. )
  191. def entropy(self):
  192. log_probs0 = torch.log1p(-self.probs)
  193. log_probs1 = torch.log(self.probs)
  194. return (
  195. self.mean * (log_probs0 - log_probs1)
  196. - self._cont_bern_log_norm()
  197. - log_probs0
  198. )
  199. @property
  200. def _natural_params(self) -> tuple[Tensor]:
  201. return (self.logits,)
  202. def _log_normalizer(self, x):
  203. """computes the log normalizing constant as a function of the natural parameter"""
  204. out_unst_reg = torch.max(
  205. torch.le(x, self._lims[0] - 0.5), torch.gt(x, self._lims[1] - 0.5)
  206. )
  207. cut_nat_params = torch.where(
  208. out_unst_reg, x, (self._lims[0] - 0.5) * torch.ones_like(x)
  209. )
  210. log_norm = torch.log(
  211. torch.abs(torch.special.expm1(cut_nat_params))
  212. ) - torch.log(torch.abs(cut_nat_params))
  213. taylor = 0.5 * x + torch.pow(x, 2) / 24.0 - torch.pow(x, 4) / 2880.0
  214. return torch.where(out_unst_reg, log_norm, taylor)