continuous_bernoulli.py 16 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454
  1. # Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. from collections.abc import Sequence
  15. import paddle
  16. from paddle.distribution import distribution
class ContinuousBernoulli(distribution.Distribution):
    r"""The Continuous Bernoulli distribution with parameter: `probs` characterizing the shape of the density function.
    The Continuous Bernoulli distribution is defined on [0, 1], and it can be viewed as a continuous version of the Bernoulli distribution.

    `The continuous Bernoulli: fixing a pervasive error in variational autoencoders. <https://arxiv.org/abs/1907.06845>`_

    Mathematical details

    The probability density function (pdf) is

    .. math::

        p(x;\lambda) = C(\lambda)\lambda^x (1-\lambda)^{1-x}

    In the above equation:

    * :math:`x`: is continuous between 0 and 1
    * :math:`probs = \lambda`: is the probability.
    * :math:`C(\lambda)`: is the normalizing constant factor

    .. math::

        C(\lambda) =
        \left\{
        \begin{aligned}
        &2 & \text{ if $\lambda = \frac{1}{2}$} \\
        &\frac{2\tanh^{-1}(1-2\lambda)}{1 - 2\lambda} & \text{ otherwise}
        \end{aligned}
        \right.

    Args:
        probs(int|float|Tensor): The probability of the Continuous Bernoulli distribution, in [0, 1],
            which characterizes the shape of the pdf. If the input data type is int or float, `probs`
            will be converted to a 1-D Tensor with the paddle global default dtype. The stored value
            is clipped to [eps, 1 - eps], where eps is the machine epsilon of its dtype.
        lims(tuple): Specify the unstable calculation region near 0.5, where the calculation is approximated
            by a Taylor expansion. The default value is (0.499, 0.501).

    Examples:
        .. code-block:: python

            >>> import paddle
            >>> from paddle.distribution import ContinuousBernoulli
            >>> paddle.set_device("cpu")
            >>> paddle.seed(100)

            >>> rv = ContinuousBernoulli(paddle.to_tensor([0.2, 0.5]))

            >>> print(rv.sample([2]))
            Tensor(shape=[2, 2], dtype=float32, place=Place(cpu), stop_gradient=True,
            [[0.38694882, 0.20714243],
             [0.00631948, 0.51577556]])

            >>> print(rv.mean)
            Tensor(shape=[2], dtype=float32, place=Place(cpu), stop_gradient=True,
            [0.38801414, 0.50000000])

            >>> print(rv.variance)
            Tensor(shape=[2], dtype=float32, place=Place(cpu), stop_gradient=True,
            [0.07589778, 0.08333334])

            >>> print(rv.entropy())
            Tensor(shape=[2], dtype=float32, place=Place(cpu), stop_gradient=True,
            [-0.07641457,  0.        ])

            >>> print(rv.cdf(paddle.to_tensor(0.1)))
            Tensor(shape=[2], dtype=float32, place=Place(cpu), stop_gradient=True,
            [0.17259926, 0.10000000])

            >>> print(rv.icdf(paddle.to_tensor(0.1)))
            Tensor(shape=[2], dtype=float32, place=Place(cpu), stop_gradient=True,
            [0.05623737, 0.10000000])

            >>> rv1 = ContinuousBernoulli(paddle.to_tensor([0.2, 0.8]))
            >>> rv2 = ContinuousBernoulli(paddle.to_tensor([0.7, 0.5]))
            >>> print(rv1.kl_divergence(rv2))
            Tensor(shape=[2], dtype=float32, place=Place(cpu), stop_gradient=True,
            [0.20103608, 0.07641447])
    """
  75. def __init__(self, probs, lims=(0.499, 0.501)):
  76. self.dtype = paddle.get_default_dtype()
  77. self.probs = self._to_tensor(probs)
  78. self.lims = paddle.to_tensor(lims, dtype=self.dtype)
  79. if not self._check_constraint(self.probs):
  80. raise ValueError(
  81. 'Every element of input parameter `probs` should be nonnegative.'
  82. )
  83. # eps_prob is used to clip the input `probs` in the range of [eps_prob, 1-eps_prob]
  84. eps_prob = paddle.finfo(self.probs.dtype).eps
  85. self.probs = paddle.clip(self.probs, min=eps_prob, max=1 - eps_prob)
  86. if self.probs.shape == []:
  87. batch_shape = (1,)
  88. else:
  89. batch_shape = self.probs.shape
  90. super().__init__(batch_shape)
  91. def _to_tensor(self, probs):
  92. """Convert the input parameters into tensors
  93. Returns:
  94. Tensor: converted probability.
  95. """
  96. # convert type
  97. if isinstance(probs, (float, int)):
  98. probs = paddle.to_tensor([probs], dtype=self.dtype)
  99. else:
  100. self.dtype = probs.dtype
  101. return probs
  102. def _check_constraint(self, value):
  103. """Check the constraint for input parameters
  104. Args:
  105. value (Tensor)
  106. Returns:
  107. bool: pass or not.
  108. """
  109. return (value >= 0).all() and (value <= 1).all()
  110. def _cut_support_region(self):
  111. """Generate stable support region indicator (prob < self.lims[0] && prob >= self.lims[1] )
  112. Returns:
  113. Tensor: the element of the returned indicator tensor corresponding to stable region is True, and False otherwise
  114. """
  115. return paddle.logical_or(
  116. paddle.less_equal(self.probs, self.lims[0]),
  117. paddle.greater_than(self.probs, self.lims[1]),
  118. )
  119. def _cut_probs(self):
  120. """Cut the probability parameter with stable support region
  121. Returns:
  122. Tensor: the element of the returned probability tensor corresponding to unstable region is set to be self.lims[0], and unchanged otherwise
  123. """
  124. return paddle.where(
  125. self._cut_support_region(),
  126. self.probs,
  127. self.lims[0] * paddle.ones_like(self.probs),
  128. )
  129. def _tanh_inverse(self, value):
  130. """Calculate the tanh inverse of value
  131. Args:
  132. value (Tensor)
  133. Returns:
  134. Tensor: tanh inverse of value
  135. """
  136. return 0.5 * (paddle.log1p(value) - paddle.log1p(-value))
  137. def _log_constant(self):
  138. """Calculate the logarithm of the constant factor :math:`C(lambda)` in the pdf of the Continuous Bernoulli distribution
  139. Returns:
  140. Tensor: logarithm of the constant factor
  141. """
  142. cut_probs = self._cut_probs()
  143. half = paddle.to_tensor(0.5, dtype=self.dtype)
  144. cut_probs_below_half = paddle.where(
  145. paddle.less_equal(cut_probs, half),
  146. cut_probs,
  147. paddle.zeros_like(cut_probs),
  148. )
  149. cut_probs_above_half = paddle.where(
  150. paddle.greater_equal(cut_probs, half),
  151. cut_probs,
  152. paddle.ones_like(cut_probs),
  153. )
  154. log_constant_propose = paddle.log(
  155. 2.0 * paddle.abs(self._tanh_inverse(1.0 - 2.0 * cut_probs))
  156. ) - paddle.where(
  157. paddle.less_equal(cut_probs, half),
  158. paddle.log1p(-2.0 * cut_probs_below_half),
  159. paddle.log(2.0 * cut_probs_above_half - 1.0),
  160. )
  161. x = paddle.square(self.probs - 0.5)
  162. taylor_expansion = (
  163. paddle.log(paddle.to_tensor(2.0, dtype=self.dtype))
  164. + (4.0 / 3.0 + 104.0 / 45.0 * x) * x
  165. )
  166. return paddle.where(
  167. self._cut_support_region(), log_constant_propose, taylor_expansion
  168. )
  169. @property
  170. def mean(self):
  171. """Mean of Continuous Bernoulli distribution.
  172. Returns:
  173. Tensor: mean value.
  174. """
  175. cut_probs = self._cut_probs()
  176. tmp = paddle.divide(cut_probs, 2.0 * cut_probs - 1.0)
  177. propose = tmp + paddle.divide(
  178. paddle.to_tensor(1.0, dtype=self.dtype),
  179. 2.0 * self._tanh_inverse(1.0 - 2.0 * cut_probs),
  180. )
  181. x = self.probs - 0.5
  182. taylor_expansion = (
  183. 0.5 + (1.0 / 3.0 + 16.0 / 45.0 * paddle.square(x)) * x
  184. )
  185. return paddle.where(
  186. self._cut_support_region(), propose, taylor_expansion
  187. )
  188. @property
  189. def variance(self):
  190. """Variance of Continuous Bernoulli distribution.
  191. Returns:
  192. Tensor: variance value.
  193. """
  194. cut_probs = self._cut_probs()
  195. tmp = paddle.divide(
  196. cut_probs * (cut_probs - 1.0),
  197. paddle.square(1.0 - 2.0 * cut_probs),
  198. )
  199. propose = tmp + paddle.divide(
  200. paddle.to_tensor(1.0, dtype=self.dtype),
  201. paddle.square(paddle.log1p(-cut_probs) - paddle.log(cut_probs)),
  202. )
  203. x = paddle.square(self.probs - 0.5)
  204. taylor_expansion = 1.0 / 12.0 - (1.0 / 15.0 - 128.0 / 945.0 * x) * x
  205. return paddle.where(
  206. self._cut_support_region(), propose, taylor_expansion
  207. )
  208. def sample(self, shape=()):
  209. """Generate Continuous Bernoulli samples of the specified shape. The final shape would be ``sample_shape + batch_shape``.
  210. Args:
  211. shape (Sequence[int], optional): Prepended shape of the generated samples.
  212. Returns:
  213. Tensor, Sampled data with shape `sample_shape` + `batch_shape`.
  214. """
  215. with paddle.no_grad():
  216. return self.rsample(shape)
  217. def rsample(self, shape=()):
  218. """Generate Continuous Bernoulli samples of the specified shape. The final shape would be ``sample_shape + batch_shape``.
  219. Args:
  220. shape (Sequence[int], optional): Prepended shape of the generated samples.
  221. Returns:
  222. Tensor, Sampled data with shape `sample_shape` + `batch_shape`.
  223. """
  224. if not isinstance(shape, Sequence):
  225. raise TypeError('sample shape must be Sequence object.')
  226. shape = tuple(shape)
  227. batch_shape = tuple(self.batch_shape)
  228. output_shape = tuple(shape + batch_shape)
  229. u = paddle.uniform(shape=output_shape, dtype=self.dtype, min=0, max=1)
  230. return self.icdf(u)
  231. def log_prob(self, value):
  232. """Log probability density function.
  233. Args:
  234. value (Tensor): The input tensor.
  235. Returns:
  236. Tensor: log probability. The data type is the same as `self.probs`.
  237. """
  238. value = paddle.cast(value, dtype=self.dtype)
  239. if not self._check_constraint(value):
  240. raise ValueError(
  241. 'Every element of input parameter `value` should be >= 0.0 and <= 1.0.'
  242. )
  243. eps = paddle.finfo(self.probs.dtype).eps
  244. cross_entropy = paddle.nan_to_num(
  245. value * paddle.log(self.probs)
  246. + (1.0 - value) * paddle.log(1 - self.probs),
  247. neginf=-eps,
  248. )
  249. return self._log_constant() + cross_entropy
  250. def prob(self, value):
  251. """Probability density function.
  252. Args:
  253. value (Tensor): The input tensor.
  254. Returns:
  255. Tensor: probability. The data type is the same as `self.probs`.
  256. """
  257. return paddle.exp(self.log_prob(value))
  258. def entropy(self):
  259. r"""Shannon entropy in nats.
  260. The entropy is
  261. .. math::
  262. \mathcal{H}(X) = -\log C + \left[ \log (1 - \lambda) -\log \lambda \right] \mathbb{E}(X) - \log(1 - \lambda)
  263. In the above equation:
  264. * :math:`\Omega`: is the support of the distribution.
  265. Returns:
  266. Tensor, Shannon entropy of Continuous Bernoulli distribution.
  267. """
  268. log_p = paddle.log(self.probs)
  269. log_1_minus_p = paddle.log1p(-self.probs)
  270. return paddle.where(
  271. paddle.equal(self.probs, paddle.to_tensor(0.5, dtype=self.dtype)),
  272. paddle.full_like(self.probs, 0.0),
  273. (
  274. -self._log_constant()
  275. + self.mean * (log_1_minus_p - log_p)
  276. - log_1_minus_p
  277. ),
  278. )
  279. def cdf(self, value):
  280. r"""Cumulative distribution function
  281. .. math::
  282. { P(X \le t; \lambda) =
  283. F(t;\lambda) =
  284. \left\{
  285. \begin{aligned}
  286. &t & \text{ if $\lambda = \frac{1}{2}$} \\
  287. &\frac{\lambda^t (1 - \lambda)^{1 - t} + \lambda - 1}{2\lambda - 1} & \text{ otherwise}
  288. \end{aligned}
  289. \right. }
  290. Args:
  291. value (Tensor): The input tensor.
  292. Returns:
  293. Tensor: quantile of :attr:`value`. The data type is the same as `self.probs`.
  294. """
  295. value = paddle.cast(value, dtype=self.dtype)
  296. if not self._check_constraint(value):
  297. raise ValueError(
  298. 'Every element of input parameter `value` should be >= 0.0 and <= 1.0.'
  299. )
  300. cut_probs = self._cut_probs()
  301. cdfs = (
  302. paddle.pow(cut_probs, value)
  303. * paddle.pow(1.0 - cut_probs, 1.0 - value)
  304. + cut_probs
  305. - 1.0
  306. ) / (2.0 * cut_probs - 1.0)
  307. unbounded_cdfs = paddle.where(self._cut_support_region(), cdfs, value)
  308. return paddle.where(
  309. paddle.less_equal(value, paddle.to_tensor(0.0, dtype=self.dtype)),
  310. paddle.zeros_like(value),
  311. paddle.where(
  312. paddle.greater_equal(
  313. value, paddle.to_tensor(1.0, dtype=self.dtype)
  314. ),
  315. paddle.ones_like(value),
  316. unbounded_cdfs,
  317. ),
  318. )
  319. def icdf(self, value):
  320. r"""Inverse cumulative distribution function
  321. .. math::
  322. { F^{-1}(x;\lambda) =
  323. \left\{
  324. \begin{aligned}
  325. &x & \text{ if $\lambda = \frac{1}{2}$} \\
  326. &\frac{\log(1+(\frac{2\lambda - 1}{1 - \lambda})x)}{\log(\frac{\lambda}{1-\lambda})} & \text{ otherwise}
  327. \end{aligned}
  328. \right. }
  329. Args:
  330. value (Tensor): The input tensor, meaning the quantile.
  331. Returns:
  332. Tensor: the value of the r.v. corresponding to the quantile. The data type is the same as `self.probs`.
  333. """
  334. value = paddle.cast(value, dtype=self.dtype)
  335. if not self._check_constraint(value):
  336. raise ValueError(
  337. 'Every element of input parameter `value` should be >= 0.0 and <= 1.0.'
  338. )
  339. cut_probs = self._cut_probs()
  340. return paddle.where(
  341. self._cut_support_region(),
  342. (
  343. paddle.log1p(-cut_probs + value * (2.0 * cut_probs - 1.0))
  344. - paddle.log1p(-cut_probs)
  345. )
  346. / (paddle.log(cut_probs) - paddle.log1p(-cut_probs)),
  347. value,
  348. )
  349. def kl_divergence(self, other):
  350. r"""The KL-divergence between two Continuous Bernoulli distributions with the same `batch_shape`.
  351. The probability density function (pdf) is
  352. .. math::
  353. KL\_divergence(\lambda_1, \lambda_2) = - H - \{\log C_2 + [\log \lambda_2 - \log (1-\lambda_2)] \mathbb{E}_1(X) + \log (1-\lambda_2) \}
  354. Args:
  355. other (ContinuousBernoulli): instance of Continuous Bernoulli.
  356. Returns:
  357. Tensor, kl-divergence between two Continuous Bernoulli distributions.
  358. """
  359. if self.batch_shape != other.batch_shape:
  360. raise ValueError(
  361. "KL divergence of two Continuous Bernoulli distributions should share the same `batch_shape`."
  362. )
  363. part1 = -self.entropy()
  364. log_q = paddle.log(other.probs)
  365. log_1_minus_q = paddle.log1p(-other.probs)
  366. part2 = -(
  367. other._log_constant()
  368. + self.mean * (log_q - log_1_minus_q)
  369. + log_1_minus_q
  370. )
  371. return part1 + part2