kl.py 8.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292
  1. # Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. import functools
  15. import warnings
  16. import paddle
  17. from paddle.distribution.bernoulli import Bernoulli
  18. from paddle.distribution.beta import Beta
  19. from paddle.distribution.binomial import Binomial
  20. from paddle.distribution.categorical import Categorical
  21. from paddle.distribution.cauchy import Cauchy
  22. from paddle.distribution.continuous_bernoulli import ContinuousBernoulli
  23. from paddle.distribution.dirichlet import Dirichlet
  24. from paddle.distribution.distribution import Distribution
  25. from paddle.distribution.exponential import Exponential
  26. from paddle.distribution.exponential_family import ExponentialFamily
  27. from paddle.distribution.gamma import Gamma
  28. from paddle.distribution.geometric import Geometric
  29. from paddle.distribution.laplace import Laplace
  30. from paddle.distribution.lognormal import LogNormal
  31. from paddle.distribution.multivariate_normal import MultivariateNormal
  32. from paddle.distribution.normal import Normal
  33. from paddle.distribution.poisson import Poisson
  34. from paddle.distribution.uniform import Uniform
  35. from paddle.framework import in_dynamic_mode
# Names exported by a star import of this module.
__all__ = ["register_kl", "kl_divergence"]

# Registry of KL implementations keyed by a ``(cls_p, cls_q)`` pair of
# distribution classes; entries are written by ``register_kl``.
_REGISTER_TABLE = {}
  38. def kl_divergence(p, q):
  39. r"""
  40. Kullback-Leibler divergence between distribution p and q.
  41. .. math::
  42. KL(p||q) = \int p(x)log\frac{p(x)}{q(x)} \mathrm{d}x
  43. Args:
  44. p (Distribution): ``Distribution`` object. Inherits from the Distribution Base class.
  45. q (Distribution): ``Distribution`` object. Inherits from the Distribution Base class.
  46. Returns:
  47. Tensor, Batchwise KL-divergence between distribution p and q.
  48. Examples:
  49. .. code-block:: python
  50. >>> import paddle
  51. >>> p = paddle.distribution.Beta(alpha=0.5, beta=0.5)
  52. >>> q = paddle.distribution.Beta(alpha=0.3, beta=0.7)
  53. >>> print(paddle.distribution.kl_divergence(p, q))
  54. Tensor(shape=[], dtype=float32, place=Place(cpu), stop_gradient=True,
  55. 0.21193528)
  56. """
  57. return _dispatch(type(p), type(q))(p, q)
  58. def register_kl(cls_p, cls_q):
  59. """Decorator for register a KL divergence implementation function.
  60. The ``kl_divergence(p, q)`` function will search concrete implementation
  61. functions registered by ``register_kl``, according to multi-dispatch pattern.
  62. If an implementation function is found, it will return the result, otherwise,
  63. it will raise ``NotImplementError`` exception. Users can register
  64. implementation function by the decorator.
  65. Args:
  66. cls_p (Distribution): The Distribution type of Instance p. Subclass derived from ``Distribution``.
  67. cls_q (Distribution): The Distribution type of Instance q. Subclass derived from ``Distribution``.
  68. Examples:
  69. .. code-block:: python
  70. >>> import paddle
  71. >>> @paddle.distribution.register_kl(paddle.distribution.Beta, paddle.distribution.Beta)
  72. >>> def kl_beta_beta():
  73. ... pass # insert implementation here
  74. """
  75. if not issubclass(cls_p, Distribution) or not issubclass(
  76. cls_q, Distribution
  77. ):
  78. raise TypeError('cls_p and cls_q must be subclass of Distribution')
  79. def decorator(f):
  80. _REGISTER_TABLE[cls_p, cls_q] = f
  81. return f
  82. return decorator
  83. def _dispatch(cls_p, cls_q):
  84. """Multiple dispatch into concrete implement function."""
  85. # find all matched super class pair of p and q
  86. matches = [
  87. (super_p, super_q)
  88. for super_p, super_q in _REGISTER_TABLE
  89. if issubclass(cls_p, super_p) and issubclass(cls_q, super_q)
  90. ]
  91. if not matches:
  92. raise NotImplementedError
  93. left_p, left_q = min(_Compare(*m) for m in matches).classes
  94. right_p, right_q = min(_Compare(*reversed(m)) for m in matches).classes
  95. if _REGISTER_TABLE[left_p, left_q] is not _REGISTER_TABLE[right_p, right_q]:
  96. warnings.warn(
  97. f'Ambiguous kl_divergence({cls_p.__name__}, {cls_q.__name__}). Please register_kl({left_p.__name__}, {right_q.__name__})',
  98. RuntimeWarning,
  99. )
  100. return _REGISTER_TABLE[left_p, left_q]
  101. @functools.total_ordering
  102. class _Compare:
  103. def __init__(self, *classes):
  104. self.classes = classes
  105. def __eq__(self, other):
  106. return self.classes == other.classes
  107. def __le__(self, other):
  108. for cls_x, cls_y in zip(self.classes, other.classes):
  109. if not issubclass(cls_x, cls_y):
  110. return False
  111. if cls_x is not cls_y:
  112. break
  113. return True
  114. @register_kl(Bernoulli, Bernoulli)
  115. def _kl_bernoulli_bernoulli(p, q):
  116. return p.kl_divergence(q)
  117. @register_kl(Beta, Beta)
  118. def _kl_beta_beta(p, q):
  119. return (
  120. (q.alpha.lgamma() + q.beta.lgamma() + (p.alpha + p.beta).lgamma())
  121. - (p.alpha.lgamma() + p.beta.lgamma() + (q.alpha + q.beta).lgamma())
  122. + ((p.alpha - q.alpha) * p.alpha.digamma())
  123. + ((p.beta - q.beta) * p.beta.digamma())
  124. + (
  125. ((q.alpha + q.beta) - (p.alpha + p.beta))
  126. * (p.alpha + p.beta).digamma()
  127. )
  128. )
  129. @register_kl(Binomial, Binomial)
  130. def _kl_binomial_binomial(p, q):
  131. return p.kl_divergence(q)
  132. @register_kl(Dirichlet, Dirichlet)
  133. def _kl_dirichlet_dirichlet(p, q):
  134. return (
  135. (p.concentration.sum(-1).lgamma() - q.concentration.sum(-1).lgamma())
  136. - ((p.concentration.lgamma() - q.concentration.lgamma()).sum(-1))
  137. + (
  138. (
  139. (p.concentration - q.concentration)
  140. * (
  141. p.concentration.digamma()
  142. - p.concentration.sum(-1).digamma().unsqueeze(-1)
  143. )
  144. ).sum(-1)
  145. )
  146. )
  147. @register_kl(Categorical, Categorical)
  148. def _kl_categorical_categorical(p, q):
  149. return p.kl_divergence(q)
  150. @register_kl(Cauchy, Cauchy)
  151. def _kl_cauchy_cauchy(p, q):
  152. return p.kl_divergence(q)
  153. @register_kl(ContinuousBernoulli, ContinuousBernoulli)
  154. def _kl_continuousbernoulli_continuousbernoulli(p, q):
  155. return p.kl_divergence(q)
  156. @register_kl(Normal, Normal)
  157. def _kl_normal_normal(p, q):
  158. return p.kl_divergence(q)
  159. @register_kl(MultivariateNormal, MultivariateNormal)
  160. def _kl_mvn_mvn(p, q):
  161. return p.kl_divergence(q)
  162. @register_kl(Uniform, Uniform)
  163. def _kl_uniform_uniform(p, q):
  164. return p.kl_divergence(q)
  165. @register_kl(Laplace, Laplace)
  166. def _kl_laplace_laplace(p, q):
  167. return p.kl_divergence(q)
  168. @register_kl(Geometric, Geometric)
  169. def _kl_geometric_geometric(p, q):
  170. return p.kl_divergence(q)
  171. @register_kl(ExponentialFamily, ExponentialFamily)
  172. def _kl_expfamily_expfamily(p, q):
  173. """Compute kl-divergence using `Bregman divergences <https://www.lix.polytechnique.fr/~nielsen/EntropyEF-ICIP2010.pdf>`_"""
  174. if not type(p) == type(q):
  175. raise NotImplementedError
  176. p_natural_params = []
  177. for param in p._natural_parameters:
  178. param = param.detach()
  179. param.stop_gradient = False
  180. p_natural_params.append(param)
  181. q_natural_params = q._natural_parameters
  182. p_log_norm = p._log_normalizer(*p_natural_params)
  183. try:
  184. if in_dynamic_mode():
  185. p_grads = paddle.grad(
  186. p_log_norm, p_natural_params, create_graph=True
  187. )
  188. else:
  189. p_grads = paddle.static.gradients(p_log_norm, p_natural_params)
  190. except RuntimeError as e:
  191. raise TypeError(
  192. "Cann't compute kl_divergence({cls_p}, {cls_q}) use bregman divergence. Please register_kl({cls_p}, {cls_q}).".format(
  193. cls_p=type(p).__name__, cls_q=type(q).__name__
  194. )
  195. ) from e
  196. kl = q._log_normalizer(*q_natural_params) - p_log_norm
  197. for p_param, q_param, p_grad in zip(
  198. p_natural_params, q_natural_params, p_grads
  199. ):
  200. term = (q_param - p_param) * p_grad
  201. kl -= _sum_rightmost(term, len(q.event_shape))
  202. return kl
  203. @register_kl(Exponential, Exponential)
  204. def _kl_exponential_exponential(p, q):
  205. return p.kl_divergence(q)
  206. @register_kl(Gamma, Gamma)
  207. def _kl_gamma_gamma(p, q):
  208. return p.kl_divergence(q)
  209. @register_kl(LogNormal, LogNormal)
  210. def _kl_lognormal_lognormal(p, q):
  211. return p._base.kl_divergence(q._base)
  212. @register_kl(Poisson, Poisson)
  213. def _kl_poisson_poisson(p, q):
  214. return p.kl_divergence(q)
  215. def _sum_rightmost(value, n):
  216. return value.sum(list(range(-n, 0))) if n > 0 else value