bernoulli.py 15 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484
  1. # Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. import numpy as np
  15. import paddle
  16. from paddle.base.data_feeder import check_type, convert_dtype
  17. from paddle.base.framework import Variable
  18. from paddle.distribution import exponential_family
  19. from paddle.framework import in_dynamic_mode
  20. from paddle.nn.functional import (
  21. binary_cross_entropy_with_logits,
  22. sigmoid,
  23. softplus,
  24. )
  25. # Smallest representable number
  26. EPS = {
  27. 'float32': paddle.finfo(paddle.float32).eps,
  28. 'float64': paddle.finfo(paddle.float64).eps,
  29. }
  30. def _clip_probs(probs, dtype):
  31. """Clip probs from [0, 1] to (0, 1) with ``eps``.
  32. Args:
  33. probs (Tensor): probs of Bernoulli.
  34. dtype (str): data type.
  35. Returns:
  36. Tensor: Clipped probs.
  37. """
  38. eps = EPS.get(dtype)
  39. return paddle.clip(probs, min=eps, max=1 - eps).astype(dtype)
  40. class Bernoulli(exponential_family.ExponentialFamily):
  41. r"""Bernoulli distribution parameterized by ``probs``, which is the probability of value 1.
  42. In probability theory and statistics, the Bernoulli distribution, named after Swiss
  43. mathematician Jacob Bernoulli, is the discrete probability distribution of a random
  44. variable which takes the value 1 with probability ``p`` and the value 0 with
  45. probability ``q=1-p``.
  46. The probability mass function of this distribution, over possible outcomes ``k``, is
  47. .. math::
  48. {\begin{cases}
  49. q=1-p & \text{if }value=0 \\
  50. p & \text{if }value=1
  51. \end{cases}}
  52. Args:
  53. probs (float|Tensor): The ``probs`` input of Bernoulli distribution. The data type is float32 or float64. The range must be in [0, 1].
  54. name (str, optional): Name for the operation (optional, default is None). For more information, please refer to :ref:`api_guide_Name`.
  55. Examples:
  56. .. code-block:: python
  57. >>> import paddle
  58. >>> from paddle.distribution import Bernoulli
  59. >>> # init `probs` with a float
  60. >>> rv = Bernoulli(probs=0.3)
  61. >>> print(rv.mean)
  62. Tensor(shape=[], dtype=float32, place=Place(cpu), stop_gradient=True,
  63. 0.30000001)
  64. >>> print(rv.variance)
  65. Tensor(shape=[], dtype=float32, place=Place(cpu), stop_gradient=True,
  66. 0.21000001)
  67. >>> print(rv.entropy())
  68. Tensor(shape=[], dtype=float32, place=Place(cpu), stop_gradient=True,
  69. 0.61086434)
  70. """
  71. def __init__(self, probs, name=None):
  72. self.name = name or 'Bernoulli'
  73. if not in_dynamic_mode():
  74. check_type(
  75. probs,
  76. 'probs',
  77. (float, Variable),
  78. self.name,
  79. )
  80. # Get/convert probs to tensor.
  81. if self._validate_args(probs):
  82. self.probs = probs
  83. self.dtype = convert_dtype(probs.dtype)
  84. else:
  85. [self.probs] = self._to_tensor(probs)
  86. self.dtype = paddle.get_default_dtype()
  87. # Check probs range [0, 1].
  88. if in_dynamic_mode():
  89. """Not use `paddle.any` in static mode, which always be `True`."""
  90. if (
  91. paddle.any(self.probs < 0)
  92. or paddle.any(self.probs > 1)
  93. or paddle.any(paddle.isnan(self.probs))
  94. ):
  95. raise ValueError("The arg of `probs` must be in range [0, 1].")
  96. # Clip probs from [0, 1] to (0, 1) with smallest representable number `eps`.
  97. self.probs = _clip_probs(self.probs, self.dtype)
  98. self.logits = self._probs_to_logits(self.probs, is_binary=True)
  99. super().__init__(batch_shape=self.probs.shape, event_shape=())
  100. @property
  101. def mean(self):
  102. """Mean of Bernoulli distribution.
  103. Returns:
  104. Tensor: Mean value of distribution.
  105. """
  106. return self.probs
  107. @property
  108. def variance(self):
  109. """Variance of Bernoulli distribution.
  110. Returns:
  111. Tensor: Variance value of distribution.
  112. """
  113. return paddle.multiply(self.probs, (1 - self.probs))
  114. def sample(self, shape):
  115. """Sample from Bernoulli distribution.
  116. Args:
  117. shape (Sequence[int]): Sample shape.
  118. Returns:
  119. Tensor: Sampled data with shape `sample_shape` + `batch_shape` + `event_shape`.
  120. Examples:
  121. .. code-block:: python
  122. >>> import paddle
  123. >>> from paddle.distribution import Bernoulli
  124. >>> rv = Bernoulli(paddle.full((1), 0.3))
  125. >>> print(rv.sample([100]).shape)
  126. [100, 1]
  127. >>> rv = Bernoulli(paddle.to_tensor(0.3))
  128. >>> print(rv.sample([100]).shape)
  129. [100]
  130. >>> rv = Bernoulli(paddle.to_tensor([0.3, 0.5]))
  131. >>> print(rv.sample([100]).shape)
  132. [100, 2]
  133. >>> rv = Bernoulli(paddle.to_tensor([0.3, 0.5]))
  134. >>> print(rv.sample([100, 2]).shape)
  135. [100, 2, 2]
  136. """
  137. name = self.name + '_sample'
  138. if not in_dynamic_mode():
  139. check_type(
  140. shape,
  141. 'shape',
  142. (np.ndarray, Variable, list, tuple),
  143. name,
  144. )
  145. shape = shape if isinstance(shape, tuple) else tuple(shape)
  146. shape = self._extend_shape(shape)
  147. with paddle.no_grad():
  148. return paddle.bernoulli(self.probs.expand(shape), name=name)
  149. def rsample(self, shape, temperature=1.0):
  150. """Sample from Bernoulli distribution (reparameterized).
  151. The `rsample` is a continuously approximate of Bernoulli distribution reparameterized sample method.
  152. [1] Chris J. Maddison, Andriy Mnih, and Yee Whye Teh. The Concrete Distribution: A Continuous Relaxation of Discrete Random Variables. 2016.
  153. [2] Eric Jang, Shixiang Gu, and Ben Poole. Categorical Reparameterization with Gumbel-Softmax. 2016.
  154. Note:
  155. `rsample` need to be followed by a `sigmoid`, which converts samples' value to unit interval (0, 1).
  156. Args:
  157. shape (Sequence[int]): Sample shape.
  158. temperature (float): temperature for rsample, must be positive.
  159. Returns:
  160. Tensor: Sampled data with shape `sample_shape` + `batch_shape` + `event_shape`.
  161. Examples:
  162. .. code-block:: python
  163. >>> import paddle
  164. >>> paddle.seed(1)
  165. >>> from paddle.distribution import Bernoulli
  166. >>> rv = Bernoulli(paddle.full((1), 0.3))
  167. >>> print(rv.sample([100]).shape)
  168. [100, 1]
  169. >>> rv = Bernoulli(0.3)
  170. >>> print(rv.rsample([100]).shape)
  171. [100]
  172. >>> rv = Bernoulli(paddle.to_tensor([0.3, 0.5]))
  173. >>> print(rv.rsample([100]).shape)
  174. [100, 2]
  175. >>> rv = Bernoulli(paddle.to_tensor([0.3, 0.5]))
  176. >>> print(rv.rsample([100, 2]).shape)
  177. [100, 2, 2]
  178. >>> # `rsample` has to be followed by a `sigmoid`
  179. >>> rv = Bernoulli(0.3)
  180. >>> rsample = rv.rsample([3, ])
  181. >>> rsample_sigmoid = paddle.nn.functional.sigmoid(rsample)
  182. >>> print(rsample)
  183. Tensor(shape=[3], dtype=float32, place=Place(cpu), stop_gradient=True,
  184. [-1.46112013, -0.01239836, -1.32765460])
  185. >>> print(rsample_sigmoid)
  186. Tensor(shape=[3], dtype=float32, place=Place(cpu), stop_gradient=True,
  187. [0.18829606, 0.49690047, 0.20954758])
  188. >>> # The smaller the `temperature`, the distribution of `rsample` closer to `sample`, with `probs` of 0.3.
  189. >>> print(paddle.nn.functional.sigmoid(rv.rsample([1000, ], temperature=1.0)).sum())
  190. >>> # doctest: +SKIP('output will be different')
  191. Tensor(shape=[], dtype=float32, place=Place(cpu), stop_gradient=True,
  192. 365.63122559)
  193. >>> # doctest: -SKIP
  194. >>> print(paddle.nn.functional.sigmoid(rv.rsample([1000, ], temperature=0.1)).sum())
  195. Tensor(shape=[], dtype=float32, place=Place(cpu), stop_gradient=True,
  196. 320.15057373)
  197. """
  198. name = self.name + '_rsample'
  199. if not in_dynamic_mode():
  200. check_type(
  201. shape,
  202. 'shape',
  203. (np.ndarray, Variable, list, tuple),
  204. name,
  205. )
  206. check_type(
  207. temperature,
  208. 'temperature',
  209. (float,),
  210. name,
  211. )
  212. shape = shape if isinstance(shape, tuple) else tuple(shape)
  213. shape = self._extend_shape(shape)
  214. temperature = paddle.full(
  215. shape=(), fill_value=temperature, dtype=self.dtype
  216. )
  217. probs = self.probs.expand(shape)
  218. uniforms = paddle.rand(shape, dtype=self.dtype)
  219. return paddle.divide(
  220. paddle.add(
  221. paddle.subtract(uniforms.log(), (-uniforms).log1p()),
  222. paddle.subtract(probs.log(), (-probs).log1p()),
  223. ),
  224. temperature,
  225. )
  226. def cdf(self, value):
  227. r"""Cumulative distribution function(CDF) evaluated at value.
  228. .. math::
  229. { \begin{cases}
  230. 0 & \text{if } value \lt 0 \\
  231. 1 - p & \text{if } 0 \leq value \lt 1 \\
  232. 1 & \text{if } value \geq 1
  233. \end{cases}
  234. }
  235. Args:
  236. value (Tensor): Value to be evaluated.
  237. Returns:
  238. Tensor: CDF evaluated at value.
  239. Examples:
  240. .. code-block:: python
  241. >>> import paddle
  242. >>> from paddle.distribution import Bernoulli
  243. >>> rv = Bernoulli(0.3)
  244. >>> print(rv.cdf(paddle.to_tensor([1.0])))
  245. Tensor(shape=[1], dtype=float32, place=Place(cpu), stop_gradient=True,
  246. [1.])
  247. """
  248. name = self.name + '_cdf'
  249. if not in_dynamic_mode():
  250. check_type(value, 'value', Variable, name)
  251. value = self._check_values_dtype_in_probs(self.probs, value)
  252. probs, value = paddle.broadcast_tensors([self.probs, value])
  253. zeros = paddle.zeros_like(probs)
  254. ones = paddle.ones_like(probs)
  255. return paddle.where(
  256. value < 0,
  257. zeros,
  258. paddle.where(value < 1, paddle.subtract(ones, probs), ones),
  259. name=name,
  260. )
  261. def log_prob(self, value):
  262. """Log of probability density function.
  263. Args:
  264. value (Tensor): Value to be evaluated.
  265. Returns:
  266. Tensor: Log of probability density evaluated at value.
  267. Examples:
  268. .. code-block:: python
  269. >>> import paddle
  270. >>> from paddle.distribution import Bernoulli
  271. >>> rv = Bernoulli(0.3)
  272. >>> print(rv.log_prob(paddle.to_tensor([1.0])))
  273. Tensor(shape=[1], dtype=float32, place=Place(cpu), stop_gradient=True,
  274. [-1.20397282])
  275. """
  276. name = self.name + '_log_prob'
  277. if not in_dynamic_mode():
  278. check_type(value, 'value', Variable, name)
  279. value = self._check_values_dtype_in_probs(self.probs, value)
  280. logits, value = paddle.broadcast_tensors([self.logits, value])
  281. return -binary_cross_entropy_with_logits(
  282. logits, value, reduction='none', name=name
  283. )
  284. def prob(self, value):
  285. r"""Probability density function(PDF) evaluated at value.
  286. .. math::
  287. { \begin{cases}
  288. q=1-p & \text{if }value=0 \\
  289. p & \text{if }value=1
  290. \end{cases}
  291. }
  292. Args:
  293. value (Tensor): Value to be evaluated.
  294. Returns:
  295. Tensor: PDF evaluated at value.
  296. Examples:
  297. .. code-block:: python
  298. >>> import paddle
  299. >>> from paddle.distribution import Bernoulli
  300. >>> rv = Bernoulli(0.3)
  301. >>> print(rv.prob(paddle.to_tensor([1.0])))
  302. Tensor(shape=[1], dtype=float32, place=Place(cpu), stop_gradient=True,
  303. [0.29999998])
  304. """
  305. name = self.name + '_prob'
  306. if not in_dynamic_mode():
  307. check_type(value, 'value', Variable, name)
  308. return self.log_prob(value).exp(name=name)
  309. def entropy(self):
  310. r"""Entropy of Bernoulli distribution.
  311. .. math::
  312. {
  313. entropy = -(q \log q + p \log p)
  314. }
  315. Returns:
  316. Tensor: Entropy of distribution.
  317. Examples:
  318. .. code-block:: python
  319. >>> import paddle
  320. >>> from paddle.distribution import Bernoulli
  321. >>> rv = Bernoulli(0.3)
  322. >>> print(rv.entropy())
  323. Tensor(shape=[], dtype=float32, place=Place(cpu), stop_gradient=True,
  324. 0.61086434)
  325. """
  326. name = self.name + '_entropy'
  327. return binary_cross_entropy_with_logits(
  328. self.logits, self.probs, reduction='none', name=name
  329. )
  330. def kl_divergence(self, other):
  331. r"""The KL-divergence between two Bernoulli distributions.
  332. .. math::
  333. {
  334. KL(a || b) = p_a \log(p_a / p_b) + (1 - p_a) \log((1 - p_a) / (1 - p_b))
  335. }
  336. Args:
  337. other (Bernoulli): instance of Bernoulli.
  338. Returns:
  339. Tensor: kl-divergence between two Bernoulli distributions.
  340. Examples:
  341. .. code-block:: python
  342. >>> import paddle
  343. >>> from paddle.distribution import Bernoulli
  344. >>> rv = Bernoulli(0.3)
  345. >>> rv_other = Bernoulli(0.7)
  346. >>> print(rv.kl_divergence(rv_other))
  347. Tensor(shape=[], dtype=float32, place=Place(cpu), stop_gradient=True,
  348. 0.33891910)
  349. """
  350. name = self.name + '_kl_divergence'
  351. if not in_dynamic_mode():
  352. check_type(other, 'other', Bernoulli, name)
  353. a_logits = self.logits
  354. b_logits = other.logits
  355. log_pa = -softplus(-a_logits)
  356. log_pb = -softplus(-b_logits)
  357. pa = sigmoid(a_logits)
  358. one_minus_pa = sigmoid(-a_logits)
  359. log_one_minus_pa = -softplus(a_logits)
  360. log_one_minus_pb = -softplus(b_logits)
  361. return paddle.add(
  362. paddle.subtract(
  363. paddle.multiply(log_pa, pa), paddle.multiply(log_pb, pa)
  364. ),
  365. paddle.subtract(
  366. paddle.multiply(log_one_minus_pa, one_minus_pa),
  367. paddle.multiply(log_one_minus_pb, one_minus_pa),
  368. ),
  369. )