exponential.py 2.7 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677787980818283848586878889909192939495
  1. # mypy: allow-untyped-defs
  2. from typing import Optional, Union
  3. import torch
  4. from torch import Tensor
  5. from torch.distributions import constraints
  6. from torch.distributions.exp_family import ExponentialFamily
  7. from torch.distributions.utils import broadcast_all
  8. from torch.types import _Number, _size
  9. __all__ = ["Exponential"]
  10. class Exponential(ExponentialFamily):
  11. r"""
  12. Creates a Exponential distribution parameterized by :attr:`rate`.
  13. Example::
  14. >>> # xdoctest: +IGNORE_WANT("non-deterministic")
  15. >>> m = Exponential(torch.tensor([1.0]))
  16. >>> m.sample() # Exponential distributed with rate=1
  17. tensor([ 0.1046])
  18. Args:
  19. rate (float or Tensor): rate = 1 / scale of the distribution
  20. """
  21. # pyrefly: ignore [bad-override]
  22. arg_constraints = {"rate": constraints.positive}
  23. support = constraints.nonnegative
  24. has_rsample = True
  25. _mean_carrier_measure = 0
  26. @property
  27. def mean(self) -> Tensor:
  28. return self.rate.reciprocal()
  29. @property
  30. def mode(self) -> Tensor:
  31. return torch.zeros_like(self.rate)
  32. @property
  33. def stddev(self) -> Tensor:
  34. return self.rate.reciprocal()
  35. @property
  36. def variance(self) -> Tensor:
  37. return self.rate.pow(-2)
  38. def __init__(
  39. self,
  40. rate: Union[Tensor, float],
  41. validate_args: Optional[bool] = None,
  42. ) -> None:
  43. (self.rate,) = broadcast_all(rate)
  44. batch_shape = torch.Size() if isinstance(rate, _Number) else self.rate.size()
  45. super().__init__(batch_shape, validate_args=validate_args)
  46. def expand(self, batch_shape, _instance=None):
  47. new = self._get_checked_instance(Exponential, _instance)
  48. batch_shape = torch.Size(batch_shape)
  49. new.rate = self.rate.expand(batch_shape)
  50. super(Exponential, new).__init__(batch_shape, validate_args=False)
  51. new._validate_args = self._validate_args
  52. return new
  53. def rsample(self, sample_shape: _size = torch.Size()) -> Tensor:
  54. shape = self._extended_shape(sample_shape)
  55. return self.rate.new(shape).exponential_() / self.rate
  56. def log_prob(self, value):
  57. if self._validate_args:
  58. self._validate_sample(value)
  59. return self.rate.log() - self.rate * value
  60. def cdf(self, value):
  61. if self._validate_args:
  62. self._validate_sample(value)
  63. return 1 - torch.exp(-self.rate * value)
  64. def icdf(self, value):
  65. return -torch.log1p(-value) / self.rate
  66. def entropy(self):
  67. return 1.0 - torch.log(self.rate)
  68. @property
  69. def _natural_params(self) -> tuple[Tensor]:
  70. return (-self.rate,)
  71. # pyrefly: ignore [bad-override]
  72. def _log_normalizer(self, x):
  73. return -torch.log(-x)