multinomial.py 5.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146
  1. # mypy: allow-untyped-defs
  2. from typing import Optional
  3. import torch
  4. from torch import inf, Tensor
  5. from torch.distributions import Categorical, constraints
  6. from torch.distributions.binomial import Binomial
  7. from torch.distributions.distribution import Distribution
  8. from torch.distributions.utils import broadcast_all
# Public API of this module: the only name exported by ``from ... import *``.
__all__ = ["Multinomial"]
  10. class Multinomial(Distribution):
  11. r"""
  12. Creates a Multinomial distribution parameterized by :attr:`total_count` and
  13. either :attr:`probs` or :attr:`logits` (but not both). The innermost dimension of
  14. :attr:`probs` indexes over categories. All other dimensions index over batches.
  15. Note that :attr:`total_count` need not be specified if only :meth:`log_prob` is
  16. called (see example below)
  17. .. note:: The `probs` argument must be non-negative, finite and have a non-zero sum,
  18. and it will be normalized to sum to 1 along the last dimension. :attr:`probs`
  19. will return this normalized value.
  20. The `logits` argument will be interpreted as unnormalized log probabilities
  21. and can therefore be any real number. It will likewise be normalized so that
  22. the resulting probabilities sum to 1 along the last dimension. :attr:`logits`
  23. will return this normalized value.
  24. - :meth:`sample` requires a single shared `total_count` for all
  25. parameters and samples.
  26. - :meth:`log_prob` allows different `total_count` for each parameter and
  27. sample.
  28. Example::
  29. >>> # xdoctest: +SKIP("FIXME: found invalid values")
  30. >>> m = Multinomial(100, torch.tensor([ 1., 1., 1., 1.]))
  31. >>> x = m.sample() # equal probability of 0, 1, 2, 3
  32. tensor([ 21., 24., 30., 25.])
  33. >>> Multinomial(probs=torch.tensor([1., 1., 1., 1.])).log_prob(x)
  34. tensor([-4.1338])
  35. Args:
  36. total_count (int): number of trials
  37. probs (Tensor): event probabilities
  38. logits (Tensor): event log probabilities (unnormalized)
  39. """
  40. arg_constraints = {"probs": constraints.simplex, "logits": constraints.real_vector}
  41. total_count: int
  42. @property
  43. def mean(self) -> Tensor:
  44. return self.probs * self.total_count
  45. @property
  46. def variance(self) -> Tensor:
  47. return self.total_count * self.probs * (1 - self.probs)
  48. def __init__(
  49. self,
  50. total_count: int = 1,
  51. probs: Optional[Tensor] = None,
  52. logits: Optional[Tensor] = None,
  53. validate_args: Optional[bool] = None,
  54. ) -> None:
  55. if not isinstance(total_count, int):
  56. raise NotImplementedError("inhomogeneous total_count is not supported")
  57. self.total_count = total_count
  58. self._categorical = Categorical(probs=probs, logits=logits)
  59. self._binomial = Binomial(total_count=total_count, probs=self.probs)
  60. batch_shape = self._categorical.batch_shape
  61. event_shape = self._categorical.param_shape[-1:]
  62. super().__init__(batch_shape, event_shape, validate_args=validate_args)
  63. def expand(self, batch_shape, _instance=None):
  64. new = self._get_checked_instance(Multinomial, _instance)
  65. batch_shape = torch.Size(batch_shape)
  66. new.total_count = self.total_count
  67. new._categorical = self._categorical.expand(batch_shape)
  68. super(Multinomial, new).__init__(
  69. batch_shape, self.event_shape, validate_args=False
  70. )
  71. new._validate_args = self._validate_args
  72. return new
  73. def _new(self, *args, **kwargs):
  74. return self._categorical._new(*args, **kwargs)
  75. @constraints.dependent_property(is_discrete=True, event_dim=1)
  76. def support(self):
  77. return constraints.multinomial(self.total_count)
  78. @property
  79. def logits(self) -> Tensor:
  80. return self._categorical.logits
  81. @property
  82. def probs(self) -> Tensor:
  83. return self._categorical.probs
  84. @property
  85. def param_shape(self) -> torch.Size:
  86. return self._categorical.param_shape
  87. def sample(self, sample_shape=torch.Size()):
  88. sample_shape = torch.Size(sample_shape)
  89. samples = self._categorical.sample(
  90. torch.Size((self.total_count,)) + sample_shape
  91. )
  92. # samples.shape is (total_count, sample_shape, batch_shape), need to change it to
  93. # (sample_shape, batch_shape, total_count)
  94. shifted_idx = list(range(samples.dim()))
  95. shifted_idx.append(shifted_idx.pop(0))
  96. samples = samples.permute(*shifted_idx)
  97. counts = samples.new(self._extended_shape(sample_shape)).zero_()
  98. counts.scatter_add_(-1, samples, torch.ones_like(samples))
  99. return counts.type_as(self.probs)
  100. def entropy(self):
  101. n = torch.tensor(self.total_count)
  102. cat_entropy = self._categorical.entropy()
  103. term1 = n * cat_entropy - torch.lgamma(n + 1)
  104. support = self._binomial.enumerate_support(expand=False)[1:]
  105. binomial_probs = torch.exp(self._binomial.log_prob(support))
  106. weights = torch.lgamma(support + 1)
  107. term2 = (binomial_probs * weights).sum([0, -1])
  108. return term1 + term2
  109. def log_prob(self, value):
  110. if self._validate_args:
  111. self._validate_sample(value)
  112. logits, value = broadcast_all(self.logits, value)
  113. logits = logits.clone(memory_format=torch.contiguous_format)
  114. log_factorial_n = torch.lgamma(value.sum(-1) + 1)
  115. log_factorial_xs = torch.lgamma(value + 1).sum(-1)
  116. logits[(value == 0) & (logits == -inf)] = 0
  117. log_powers = (logits * value).sum(-1)
  118. return log_factorial_n - log_factorial_xs + log_powers