poisson.py 2.5 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788
  1. # mypy: allow-untyped-defs
  2. from typing import Optional, Union
  3. import torch
  4. from torch import Tensor
  5. from torch.distributions import constraints
  6. from torch.distributions.exp_family import ExponentialFamily
  7. from torch.distributions.utils import broadcast_all
  8. from torch.types import _Number, Number
  9. __all__ = ["Poisson"]
  10. class Poisson(ExponentialFamily):
  11. r"""
  12. Creates a Poisson distribution parameterized by :attr:`rate`, the rate parameter.
  13. Samples are nonnegative integers, with a pmf given by
  14. .. math::
  15. \mathrm{rate}^k \frac{e^{-\mathrm{rate}}}{k!}
  16. Example::
  17. >>> # xdoctest: +SKIP("poisson_cpu not implemented for 'Long'")
  18. >>> m = Poisson(torch.tensor([4]))
  19. >>> m.sample()
  20. tensor([ 3.])
  21. Args:
  22. rate (Number, Tensor): the rate parameter
  23. """
  24. # pyrefly: ignore [bad-override]
  25. arg_constraints = {"rate": constraints.nonnegative}
  26. support = constraints.nonnegative_integer
  27. @property
  28. def mean(self) -> Tensor:
  29. return self.rate
  30. @property
  31. def mode(self) -> Tensor:
  32. return self.rate.floor()
  33. @property
  34. def variance(self) -> Tensor:
  35. return self.rate
  36. def __init__(
  37. self,
  38. rate: Union[Tensor, Number],
  39. validate_args: Optional[bool] = None,
  40. ) -> None:
  41. (self.rate,) = broadcast_all(rate)
  42. if isinstance(rate, _Number):
  43. batch_shape = torch.Size()
  44. else:
  45. batch_shape = self.rate.size()
  46. super().__init__(batch_shape, validate_args=validate_args)
  47. def expand(self, batch_shape, _instance=None):
  48. new = self._get_checked_instance(Poisson, _instance)
  49. batch_shape = torch.Size(batch_shape)
  50. new.rate = self.rate.expand(batch_shape)
  51. super(Poisson, new).__init__(batch_shape, validate_args=False)
  52. new._validate_args = self._validate_args
  53. return new
  54. def sample(self, sample_shape=torch.Size()):
  55. shape = self._extended_shape(sample_shape)
  56. with torch.no_grad():
  57. return torch.poisson(self.rate.expand(shape))
  58. def log_prob(self, value):
  59. if self._validate_args:
  60. self._validate_sample(value)
  61. rate, value = broadcast_all(self.rate, value)
  62. return value.xlogy(rate) - rate - (value + 1).lgamma()
  63. @property
  64. def _natural_params(self) -> tuple[Tensor]:
  65. return (torch.log(self.rate),)
  66. # pyrefly: ignore [bad-override]
  67. def _log_normalizer(self, x):
  68. return torch.exp(x)