| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292 |
- # Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
- #
- # Licensed under the Apache License, Version 2.0 (the "License");
- # you may not use this file except in compliance with the License.
- # You may obtain a copy of the License at
- #
- # http://www.apache.org/licenses/LICENSE-2.0
- #
- # Unless required by applicable law or agreed to in writing, software
- # distributed under the License is distributed on an "AS IS" BASIS,
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- # See the License for the specific language governing permissions and
- # limitations under the License.
- import functools
- import warnings
- import paddle
- from paddle.distribution.bernoulli import Bernoulli
- from paddle.distribution.beta import Beta
- from paddle.distribution.binomial import Binomial
- from paddle.distribution.categorical import Categorical
- from paddle.distribution.cauchy import Cauchy
- from paddle.distribution.continuous_bernoulli import ContinuousBernoulli
- from paddle.distribution.dirichlet import Dirichlet
- from paddle.distribution.distribution import Distribution
- from paddle.distribution.exponential import Exponential
- from paddle.distribution.exponential_family import ExponentialFamily
- from paddle.distribution.gamma import Gamma
- from paddle.distribution.geometric import Geometric
- from paddle.distribution.laplace import Laplace
- from paddle.distribution.lognormal import LogNormal
- from paddle.distribution.multivariate_normal import MultivariateNormal
- from paddle.distribution.normal import Normal
- from paddle.distribution.poisson import Poisson
- from paddle.distribution.uniform import Uniform
- from paddle.framework import in_dynamic_mode
# Public names exported by ``from paddle.distribution.kl import *``.
__all__ = ['register_kl', 'kl_divergence']

# Registry mapping a ``(cls_p, cls_q)`` pair of Distribution classes to the
# KL implementation registered for it through ``register_kl``; read by
# ``_dispatch``.
_REGISTER_TABLE = dict()
def kl_divergence(p, q):
    r"""
    Compute the Kullback-Leibler divergence KL(p || q) between two distributions.

    .. math::

        KL(p||q) = \int p(x)log\frac{p(x)}{q(x)} \mathrm{d}x

    The concrete implementation is chosen from the functions registered with
    ``register_kl``, matching on the classes of ``p`` and ``q`` (base-class
    registrations also match).

    Args:
        p (Distribution): ``Distribution`` object. Inherits from the Distribution Base class.
        q (Distribution): ``Distribution`` object. Inherits from the Distribution Base class.

    Returns:
        Tensor, Batchwise KL-divergence between distribution p and q.

    Raises:
        NotImplementedError: If no implementation is registered for the
            classes of ``p`` and ``q`` or any of their base classes.

    Examples:
        .. code-block:: python

            >>> import paddle

            >>> p = paddle.distribution.Beta(alpha=0.5, beta=0.5)
            >>> q = paddle.distribution.Beta(alpha=0.3, beta=0.7)
            >>> print(paddle.distribution.kl_divergence(p, q))
            Tensor(shape=[], dtype=float32, place=Place(cpu), stop_gradient=True,
            0.21193528)
    """
    # Resolve the implementation from the concrete classes, then apply it.
    implementation = _dispatch(type(p), type(q))
    return implementation(p, q)
def register_kl(cls_p, cls_q):
    """Decorator for registering a KL divergence implementation function.

    ``kl_divergence(p, q)`` searches the implementation functions registered
    through ``register_kl`` using multiple dispatch on the classes of ``p``
    and ``q``. If an implementation is found, its result is returned;
    otherwise a ``NotImplementedError`` is raised. Users can register their
    own implementation with this decorator.

    Args:
        cls_p (Distribution): The Distribution type of Instance p. Subclass derived from ``Distribution``.
        cls_q (Distribution): The Distribution type of Instance q. Subclass derived from ``Distribution``.

    Returns:
        callable: A decorator that records the decorated function as the
        implementation for ``(cls_p, cls_q)`` and returns it unchanged.
        Registering the same pair again replaces the previous function.

    Raises:
        TypeError: If ``cls_p`` or ``cls_q`` is not a class derived from
            ``Distribution``.

    Examples:
        .. code-block:: python

            >>> import paddle

            >>> @paddle.distribution.register_kl(paddle.distribution.Beta, paddle.distribution.Beta)
            ... def kl_beta_beta(p, q):
            ...     pass  # insert implementation here
    """
    # issubclass() raises its own, less helpful TypeError when given a
    # non-class, so reject non-classes up front with the intended message.
    if not (
        isinstance(cls_p, type)
        and isinstance(cls_q, type)
        and issubclass(cls_p, Distribution)
        and issubclass(cls_q, Distribution)
    ):
        raise TypeError('cls_p and cls_q must be subclass of Distribution')

    def decorator(f):
        _REGISTER_TABLE[cls_p, cls_q] = f
        return f

    return decorator
def _dispatch(cls_p, cls_q):
    """Multiple dispatch into the registered KL implementation for two classes.

    Every registered key ``(super_p, super_q)`` with ``cls_p`` a subclass of
    ``super_p`` and ``cls_q`` a subclass of ``super_q`` is a candidate. The
    most specific candidate is chosen twice: once ranking by the ``p`` class
    first, once ranking by the ``q`` class first. If the two rankings select
    different functions, the lookup is ambiguous. In that case a
    ``RuntimeWarning`` is emitted and the ``p``-first choice is used.

    Args:
        cls_p (type): Class of the first distribution.
        cls_q (type): Class of the second distribution.

    Returns:
        callable: The registered implementation, called as ``f(p, q)``.

    Raises:
        NotImplementedError: If no registered key matches ``(cls_p, cls_q)``.
    """
    # Find all registered (super_p, super_q) pairs that cover (cls_p, cls_q).
    matches = [
        (super_p, super_q)
        for super_p, super_q in _REGISTER_TABLE
        if issubclass(cls_p, super_p) and issubclass(cls_q, super_q)
    ]
    if not matches:
        raise NotImplementedError(
            f'No KL divergence is registered for '
            f'({cls_p.__name__}, {cls_q.__name__}); '
            f'use register_kl to add one.'
        )

    left_p, left_q = min(_Compare(*m) for m in matches).classes
    # The second ranking operates on reversed (q, p) tuples, so the result
    # must be unpacked as (q, p). Unpacking it as (p, q) would look up a
    # swapped key that may not be registered (KeyError for asymmetric pairs).
    right_q, right_p = min(_Compare(*reversed(m)) for m in matches).classes

    if _REGISTER_TABLE[left_p, left_q] is not _REGISTER_TABLE[right_p, right_q]:
        warnings.warn(
            f'Ambiguous kl_divergence({cls_p.__name__}, {cls_q.__name__}). Please register_kl({left_p.__name__}, {right_q.__name__})',
            RuntimeWarning,
        )
    return _REGISTER_TABLE[left_p, left_q]
@functools.total_ordering
class _Compare:
    """Orderable wrapper around a tuple of classes, used to rank dispatch candidates.

    ``x <= y`` is decided by walking ``x.classes`` and ``y.classes`` in step.
    At the first position where the two classes differ, ``x <= y`` holds iff
    ``x``'s class is a subclass of ``y``'s. If no position differs, ``x <= y``
    holds. ``min`` over a collection of ``_Compare`` objects therefore prefers
    the more specific class tuple. ``functools.total_ordering`` derives the
    remaining comparison operators from ``__eq__`` and ``__le__``.
    """

    def __init__(self, *classes):
        # Tuple of classes being ranked, e.g. ``(cls_p, cls_q)``.
        self.classes = classes

    def __eq__(self, other):
        if not isinstance(other, _Compare):
            # Defer to Python's default handling rather than failing with
            # AttributeError on ``other.classes``.
            return NotImplemented
        return self.classes == other.classes

    # Defining __eq__ already makes instances unhashable; state it explicitly.
    __hash__ = None

    def __le__(self, other):
        if not isinstance(other, _Compare):
            return NotImplemented
        for cls_x, cls_y in zip(self.classes, other.classes):
            if not issubclass(cls_x, cls_y):
                return False
            if cls_x is not cls_y:
                # The first differing position decides the ordering.
                break
        return True

    def __repr__(self):
        names = ', '.join(cls.__name__ for cls in self.classes)
        return f'{type(self).__name__}({names})'
@register_kl(Bernoulli, Bernoulli)
def _kl_bernoulli_bernoulli(p, q):
    """KL(p || q) for two Bernoulli distributions; delegates to ``p.kl_divergence(q)``."""
    divergence = p.kl_divergence(q)
    return divergence
@register_kl(Beta, Beta)
def _kl_beta_beta(p, q):
    """Closed-form KL(p || q) between two Beta distributions.

    With ``p = Beta(a1, b1)`` and ``q = Beta(a2, b2)``::

        KL = log B(a2, b2) - log B(a1, b1)
             + (a1 - a2) * digamma(a1)
             + (b1 - b2) * digamma(b1)
             + ((a2 + b2) - (a1 + b1)) * digamma(a1 + b1)

    where ``log B(a, b) = lgamma(a) + lgamma(b) - lgamma(a + b)``.
    """
    p_total = p.alpha + p.beta
    q_total = q.alpha + q.beta
    # log B(q) - log B(p), expanded into lgamma terms.
    log_norm_diff = (
        q.alpha.lgamma() + q.beta.lgamma() + p_total.lgamma()
    ) - (p.alpha.lgamma() + p.beta.lgamma() + q_total.lgamma())
    alpha_term = (p.alpha - q.alpha) * p.alpha.digamma()
    beta_term = (p.beta - q.beta) * p.beta.digamma()
    total_term = (q_total - p_total) * p_total.digamma()
    return log_norm_diff + alpha_term + beta_term + total_term
@register_kl(Binomial, Binomial)
def _kl_binomial_binomial(p, q):
    """KL(p || q) for two Binomial distributions; delegates to ``p.kl_divergence(q)``."""
    divergence = p.kl_divergence(q)
    return divergence
@register_kl(Dirichlet, Dirichlet)
def _kl_dirichlet_dirichlet(p, q):
    """Closed-form KL(p || q) between two Dirichlet distributions.

    With concentrations ``a`` (of ``p``) and ``b`` (of ``q``), and
    ``a0 = sum(a)``, ``b0 = sum(b)`` over the last axis::

        KL = lgamma(a0) - lgamma(b0)
             - sum(lgamma(a) - lgamma(b))
             + sum((a - b) * (digamma(a) - digamma(a0)))

    All sums reduce the last axis.
    """
    alpha_p = p.concentration
    alpha_q = q.concentration
    total_p = alpha_p.sum(-1)
    log_norm_diff = total_p.lgamma() - alpha_q.sum(-1).lgamma()
    lgamma_diff = (alpha_p.lgamma() - alpha_q.lgamma()).sum(-1)
    # digamma(a0) is re-expanded on the last axis to broadcast against a.
    expected_log = alpha_p.digamma() - total_p.digamma().unsqueeze(-1)
    cross_term = ((alpha_p - alpha_q) * expected_log).sum(-1)
    return log_norm_diff - lgamma_diff + cross_term
@register_kl(Categorical, Categorical)
def _kl_categorical_categorical(p, q):
    """KL(p || q) for two Categorical distributions; delegates to ``p.kl_divergence(q)``."""
    divergence = p.kl_divergence(q)
    return divergence
@register_kl(Cauchy, Cauchy)
def _kl_cauchy_cauchy(p, q):
    """KL(p || q) for two Cauchy distributions; delegates to ``p.kl_divergence(q)``."""
    divergence = p.kl_divergence(q)
    return divergence
@register_kl(ContinuousBernoulli, ContinuousBernoulli)
def _kl_continuousbernoulli_continuousbernoulli(p, q):
    """KL(p || q) for two ContinuousBernoulli distributions; delegates to ``p.kl_divergence(q)``."""
    divergence = p.kl_divergence(q)
    return divergence
@register_kl(Normal, Normal)
def _kl_normal_normal(p, q):
    """KL(p || q) for two Normal distributions; delegates to ``p.kl_divergence(q)``."""
    divergence = p.kl_divergence(q)
    return divergence
@register_kl(MultivariateNormal, MultivariateNormal)
def _kl_mvn_mvn(p, q):
    """KL(p || q) for two MultivariateNormal distributions; delegates to ``p.kl_divergence(q)``."""
    divergence = p.kl_divergence(q)
    return divergence
@register_kl(Uniform, Uniform)
def _kl_uniform_uniform(p, q):
    """KL(p || q) for two Uniform distributions; delegates to ``p.kl_divergence(q)``."""
    divergence = p.kl_divergence(q)
    return divergence
@register_kl(Laplace, Laplace)
def _kl_laplace_laplace(p, q):
    """KL(p || q) for two Laplace distributions; delegates to ``p.kl_divergence(q)``."""
    divergence = p.kl_divergence(q)
    return divergence
@register_kl(Geometric, Geometric)
def _kl_geometric_geometric(p, q):
    """KL(p || q) for two Geometric distributions; delegates to ``p.kl_divergence(q)``."""
    divergence = p.kl_divergence(q)
    return divergence
@register_kl(ExponentialFamily, ExponentialFamily)
def _kl_expfamily_expfamily(p, q):
    """Compute kl-divergence using `Bregman divergences <https://www.lix.polytechnique.fr/~nielsen/EntropyEF-ICIP2010.pdf>`_

    For two members of the same exponential family, with natural parameters
    ``eta_p`` / ``eta_q`` and log-normalizer ``A``::

        KL(p || q) = A(eta_q) - A(eta_p) - <eta_q - eta_p, grad A(eta_p)>

    ``grad A(eta_p)`` is obtained by automatic differentiation. Each
    inner-product term is summed over the rightmost ``len(q.event_shape)``
    axes.

    Args:
        p (ExponentialFamily): First distribution.
        q (ExponentialFamily): Second distribution, of the same type as ``p``.

    Returns:
        Tensor: KL(p || q).

    Raises:
        NotImplementedError: If ``p`` and ``q`` are not of the same type.
        TypeError: If the gradient of ``p``'s log-normalizer cannot be
            computed (wraps the underlying ``RuntimeError``).
    """
    cls_p = type(p).__name__
    cls_q = type(q).__name__
    if type(p) is not type(q):
        raise NotImplementedError(
            f'Bregman-divergence kl_divergence requires p and q of the same '
            f'type, got {cls_p} and {cls_q}.'
        )

    # Detached copies of p's natural parameters with stop_gradient=False, so
    # that the log-normalizer can be differentiated with respect to them.
    p_natural_params = []
    for param in p._natural_parameters:
        param = param.detach()
        param.stop_gradient = False
        p_natural_params.append(param)
    q_natural_params = q._natural_parameters

    p_log_norm = p._log_normalizer(*p_natural_params)
    try:
        if in_dynamic_mode():
            p_grads = paddle.grad(
                p_log_norm, p_natural_params, create_graph=True
            )
        else:
            p_grads = paddle.static.gradients(p_log_norm, p_natural_params)
    except RuntimeError as e:
        raise TypeError(
            f"Can't compute kl_divergence({cls_p}, {cls_q}) using Bregman "
            f"divergence. Please register_kl({cls_p}, {cls_q})."
        ) from e

    # A(eta_q) - A(eta_p) - sum_i <eta_q_i - eta_p_i, dA/deta_p_i>
    kl = q._log_normalizer(*q_natural_params) - p_log_norm
    for p_param, q_param, p_grad in zip(
        p_natural_params, q_natural_params, p_grads
    ):
        term = (q_param - p_param) * p_grad
        kl -= _sum_rightmost(term, len(q.event_shape))
    return kl
@register_kl(Exponential, Exponential)
def _kl_exponential_exponential(p, q):
    """KL(p || q) for two Exponential distributions; delegates to ``p.kl_divergence(q)``."""
    divergence = p.kl_divergence(q)
    return divergence
@register_kl(Gamma, Gamma)
def _kl_gamma_gamma(p, q):
    """KL(p || q) for two Gamma distributions; delegates to ``p.kl_divergence(q)``."""
    divergence = p.kl_divergence(q)
    return divergence
@register_kl(LogNormal, LogNormal)
def _kl_lognormal_lognormal(p, q):
    """KL(p || q) for two LogNormal distributions, computed on their ``_base`` distributions.

    NOTE(review): assumes ``_base`` is the untransformed distribution
    underlying each LogNormal (confirm in lognormal.py). KL divergence is
    invariant when the same bijection is applied to both arguments, which is
    what makes this delegation valid.
    """
    base_p, base_q = p._base, q._base
    return base_p.kl_divergence(base_q)
@register_kl(Poisson, Poisson)
def _kl_poisson_poisson(p, q):
    """KL(p || q) for two Poisson distributions; delegates to ``p.kl_divergence(q)``."""
    divergence = p.kl_divergence(q)
    return divergence
def _sum_rightmost(value, n):
    """Sum ``value`` over its ``n`` rightmost axes.

    When ``n <= 0``, ``value`` is returned unchanged. Otherwise
    ``value.sum`` is called with the axis list ``[-n, ..., -1]``.
    """
    if n <= 0:
        return value
    trailing_axes = [axis - n for axis in range(n)]
    return value.sum(trailing_axes)
|