multinomial.py 6.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207
  1. # Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. from collections.abc import Iterable
  15. import paddle
  16. from paddle.distribution import categorical, distribution
  17. class Multinomial(distribution.Distribution):
  18. r"""
  19. Multinomial distribution parameterized by :attr:`total_count` and
  20. :attr:`probs`.
  21. In probability theory, the multinomial distribution is a generalization of
  22. the binomial distribution, it models the probability of counts for each side
  23. of a k-sided die rolled n times. When k is 2 and n is 1, the multinomial is
  24. the bernoulli distribution, when k is 2 and n is grater than 1, it is the
  25. binomial distribution, when k is grater than 2 and n is 1, it is the
  26. categorical distribution.
  27. The probability mass function (PMF) for multinomial is
  28. .. math::
  29. f(x_1, ..., x_k; n, p_1,...,p_k) = \frac{n!}{x_1!...x_k!}p_1^{x_1}...p_k^{x_k}
  30. where, :math:`n` is number of trials, k is the number of categories,
  31. :math:`p_i` denote probability of a trial falling into each category,
  32. :math:`{\textstyle \sum_{i=1}^{k}p_i=1}, p_i \ge 0`, and :math:`x_i` denote
  33. count of each category.
  34. Args:
  35. total_count (int): Number of trials.
  36. probs (Tensor): Probability of a trial falling into each category. Last
  37. axis of probs indexes over categories, other axes index over batches.
  38. Probs value should between [0, 1], and sum to 1 along last axis. If
  39. the value over 1, it will be normalized to sum to 1 along the last
  40. axis.
  41. Examples:
  42. .. code-block:: python
  43. >>> import paddle
  44. >>> paddle.seed(2023)
  45. >>> multinomial = paddle.distribution.Multinomial(10, paddle.to_tensor([0.2, 0.3, 0.5]))
  46. >>> print(multinomial.sample((2, 3)))
  47. Tensor(shape=[2, 3, 3], dtype=float32, place=Place(cpu), stop_gradient=True,
  48. [[[1., 5., 4.],
  49. [0., 4., 6.],
  50. [1., 3., 6.]],
  51. [[2., 2., 6.],
  52. [0., 6., 4.],
  53. [3., 3., 4.]]])
  54. """
  55. def __init__(self, total_count, probs):
  56. if not isinstance(total_count, int) or total_count < 1:
  57. raise ValueError(
  58. 'input parameter total_count must be int type and grater than zero.'
  59. )
  60. if probs.dim() < 1:
  61. raise ValueError(
  62. 'probs parameter should not be none and over one dimension'
  63. )
  64. self.probs = probs / probs.sum(-1, keepdim=True)
  65. self.total_count = total_count
  66. self._categorical = categorical.Categorical(
  67. logits=self._probs_to_logits(probs)
  68. )
  69. super().__init__(probs.shape[:-1], probs.shape[-1:])
  70. @property
  71. def mean(self):
  72. """mean of multinomial distribution.
  73. Returns:
  74. Tensor: mean value.
  75. """
  76. return self.probs * self.total_count
  77. @property
  78. def variance(self):
  79. """variance of multinomial distribution.
  80. Returns:
  81. Tensor: variance value.
  82. """
  83. return self.total_count * self.probs * (1 - self.probs)
  84. def prob(self, value):
  85. """probability mass function evaluated at value.
  86. Args:
  87. value (Tensor): value to be evaluated.
  88. Returns:
  89. Tensor: probability of value.
  90. """
  91. return paddle.exp(self.log_prob(value))
  92. def log_prob(self, value):
  93. """probability mass function evaluated at value.
  94. Args:
  95. value (Tensor): value to be evaluated.
  96. Returns:
  97. Tensor: probability of value.
  98. """
  99. if paddle.is_integer(value):
  100. value = paddle.cast(value, self.probs.dtype)
  101. logits, value = paddle.broadcast_tensors(
  102. [paddle.log(self.probs), value]
  103. )
  104. if paddle.in_dynamic_mode():
  105. logits[(value == 0) & (paddle.isinf(logits))] = 0
  106. else:
  107. logits = paddle.static.setitem(
  108. logits, (value == 0) & (paddle.isinf(logits)), 0
  109. )
  110. return (
  111. paddle.lgamma(value.sum(-1) + 1)
  112. - paddle.lgamma(value + 1).sum(-1)
  113. + (value * logits).sum(-1)
  114. )
  115. def sample(self, shape=()):
  116. """draw sample data from multinomial distribution
  117. Args:
  118. sample_shape (tuple, optional): [description]. Defaults to ().
  119. """
  120. if not isinstance(shape, Iterable):
  121. raise TypeError('sample shape must be Iterable object.')
  122. samples = self._categorical.sample(
  123. [
  124. self.total_count,
  125. ]
  126. + list(shape)
  127. )
  128. return (
  129. paddle.nn.functional.one_hot(samples, self.probs.shape[-1])
  130. .cast(self.probs.dtype)
  131. .sum(0)
  132. )
  133. def entropy(self):
  134. """entropy of multinomial distribution
  135. Returns:
  136. Tensor: entropy value
  137. """
  138. n = paddle.full(
  139. shape=[], fill_value=self.total_count, dtype=self.probs.dtype
  140. )
  141. support = paddle.arange(
  142. self.total_count + 1, dtype=self.probs.dtype
  143. ).reshape((-1,) + (1,) * len(self.probs.shape))[1:]
  144. binomial_pmf = paddle.exp(self._binomial_logpmf(n, support))
  145. return (n * self._categorical.entropy() - paddle.lgamma(n + 1)) + (
  146. (binomial_pmf * paddle.lgamma(support + 1)).sum([0, -1])
  147. )
  148. def _binomial_logpmf(self, count, value):
  149. logits = self._probs_to_logits(self.probs, is_binary=True)
  150. factor_n = paddle.lgamma(count + 1)
  151. factor_k = paddle.lgamma(value + 1)
  152. factor_nmk = paddle.lgamma(count - value + 1)
  153. norm = (
  154. count * _clip_by_zero(logits)
  155. + count * paddle.log1p(paddle.exp(-paddle.abs(logits)))
  156. - factor_n
  157. )
  158. return value * logits - factor_k - factor_nmk - norm
  159. def _binomial_support(count, dtype):
  160. return paddle.arange(count + 1, dtype=dtype)
  161. def _clip_by_zero(x):
  162. # like clip(x, min=0) but grad at 0 is 0.5
  163. return (x.clip(min=0) + x - x.clip(max=0)) / 2