dirichlet.py 5.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179
  1. # Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. import paddle
  15. from paddle.base.data_feeder import check_variable_and_dtype
  16. from paddle.base.layer_helper import LayerHelper
  17. from paddle.distribution import exponential_family
  18. from paddle.framework import in_dynamic_mode
  19. class Dirichlet(exponential_family.ExponentialFamily):
  20. r"""
  21. Dirichlet distribution with parameter "concentration".
  22. The Dirichlet distribution is defined over the `(k-1)-simplex` using a
  23. positive, length-k vector concentration(`k > 1`).
  24. The Dirichlet is identically the Beta distribution when `k = 2`.
  25. For independent and identically distributed continuous random variable
  26. :math:`\boldsymbol X \in R_k` , and support
  27. :math:`\boldsymbol X \in (0,1), ||\boldsymbol X|| = 1` ,
  28. The probability density function (pdf) is
  29. .. math::
  30. f(\boldsymbol X; \boldsymbol \alpha) = \frac{1}{B(\boldsymbol \alpha)} \prod_{i=1}^{k}x_i^{\alpha_i-1}
  31. where :math:`\boldsymbol \alpha = {\alpha_1,...,\alpha_k}, k \ge 2` is
  32. parameter, the normalizing constant is the multivariate beta function.
  33. .. math::
  34. B(\boldsymbol \alpha) = \frac{\prod_{i=1}^{k} \Gamma(\alpha_i)}{\Gamma(\alpha_0)}
  35. :math:`\alpha_0=\sum_{i=1}^{k} \alpha_i` is the sum of parameters,
  36. :math:`\Gamma(\alpha)` is gamma function.
  37. Args:
  38. concentration (Tensor): "Concentration" parameter of dirichlet
  39. distribution, also called :math:`\alpha`. When it's over one
  40. dimension, the last axis denotes the parameter of distribution,
  41. ``event_shape=concentration.shape[-1:]`` , axes other than last are
  42. consider batch dimensions with ``batch_shape=concentration.shape[:-1]`` .
  43. Examples:
  44. .. code-block:: python
  45. >>> import paddle
  46. >>> dirichlet = paddle.distribution.Dirichlet(paddle.to_tensor([1., 2., 3.]))
  47. >>> print(dirichlet.entropy())
  48. Tensor(shape=[], dtype=float32, place=Place(cpu), stop_gradient=True,
  49. -1.24434423)
  50. >>> print(dirichlet.prob(paddle.to_tensor([.3, .5, .6])))
  51. Tensor(shape=[], dtype=float32, place=Place(cpu), stop_gradient=True,
  52. 10.80000019)
  53. """
  54. def __init__(self, concentration):
  55. if concentration.dim() < 1:
  56. raise ValueError(
  57. "`concentration` parameter must be at least one dimensional"
  58. )
  59. self.concentration = concentration
  60. super().__init__(concentration.shape[:-1], concentration.shape[-1:])
  61. @property
  62. def mean(self):
  63. """Mean of Dirichlet distribution.
  64. Returns:
  65. Mean value of distribution.
  66. """
  67. return self.concentration / self.concentration.sum(-1, keepdim=True)
  68. @property
  69. def variance(self):
  70. """Variance of Dirichlet distribution.
  71. Returns:
  72. Variance value of distribution.
  73. """
  74. concentration0 = self.concentration.sum(-1, keepdim=True)
  75. return (self.concentration * (concentration0 - self.concentration)) / (
  76. concentration0.pow(2) * (concentration0 + 1)
  77. )
  78. def sample(self, shape=()):
  79. """Sample from dirichlet distribution.
  80. Args:
  81. shape (Sequence[int], optional): Sample shape. Defaults to empty tuple.
  82. """
  83. shape = shape if isinstance(shape, tuple) else tuple(shape)
  84. return _dirichlet(self.concentration.expand(self._extend_shape(shape)))
  85. def prob(self, value):
  86. """Probability density function(PDF) evaluated at value.
  87. Args:
  88. value (Tensor): Value to be evaluated.
  89. Returns:
  90. PDF evaluated at value.
  91. """
  92. return paddle.exp(self.log_prob(value))
  93. def log_prob(self, value):
  94. """Log of probability density function.
  95. Args:
  96. value (Tensor): Value to be evaluated.
  97. """
  98. return (
  99. (paddle.log(value) * (self.concentration - 1.0)).sum(-1)
  100. + paddle.lgamma(self.concentration.sum(-1))
  101. - paddle.lgamma(self.concentration).sum(-1)
  102. )
  103. def entropy(self):
  104. """Entropy of Dirichlet distribution.
  105. Returns:
  106. Entropy of distribution.
  107. """
  108. concentration0 = self.concentration.sum(-1)
  109. k = self.concentration.shape[-1]
  110. return (
  111. paddle.lgamma(self.concentration).sum(-1)
  112. - paddle.lgamma(concentration0)
  113. - (k - concentration0) * paddle.digamma(concentration0)
  114. - (
  115. (self.concentration - 1.0) * paddle.digamma(self.concentration)
  116. ).sum(-1)
  117. )
  118. @property
  119. def _natural_parameters(self):
  120. return (self.concentration,)
  121. def _log_normalizer(self, x):
  122. return x.lgamma().sum(-1) - paddle.lgamma(x.sum(-1))
  123. def _dirichlet(concentration, name=None):
  124. if in_dynamic_mode():
  125. return paddle._C_ops.dirichlet(concentration)
  126. else:
  127. op_type = 'dirichlet'
  128. check_variable_and_dtype(
  129. concentration,
  130. 'concentration',
  131. ['float16', 'float32', 'float64', 'uint16'],
  132. op_type,
  133. )
  134. helper = LayerHelper(op_type, **locals())
  135. out = helper.create_variable_for_type_inference(
  136. dtype=concentration.dtype
  137. )
  138. helper.append_op(
  139. type=op_type,
  140. inputs={"Alpha": concentration},
  141. outputs={'Out': out},
  142. attrs={},
  143. )
  144. return out