distribution.py 9.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280
  1. # Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. # TODO: define the distribution functions
  15. # __all__ = ['Categorical',
  16. # 'MultivariateNormalDiag',
  17. # 'Normal',
  18. # 'Uniform']
  19. import warnings
  20. import numpy as np
  21. import paddle
  22. from paddle import _C_ops
  23. from paddle.base.data_feeder import check_variable_and_dtype, convert_dtype
  24. from paddle.base.framework import Variable
  25. from paddle.framework import (
  26. in_dynamic_or_pir_mode,
  27. in_pir_mode,
  28. )
  29. class Distribution:
  30. """
  31. The abstract base class for probability distributions. Functions are
  32. implemented in specific distributions.
  33. Args:
  34. batch_shape(Sequence[int], optional): independent, not identically
  35. distributed draws, aka a "collection" or "bunch" of distributions.
  36. event_shape(Sequence[int], optional): the shape of a single
  37. draw from the distribution; it may be dependent across dimensions.
  38. For scalar distributions, the event shape is []. For n-dimension
  39. multivariate distribution, the event shape is [n].
  40. """
  41. def __init__(self, batch_shape=(), event_shape=()):
  42. self._batch_shape = (
  43. batch_shape
  44. if isinstance(batch_shape, tuple)
  45. else tuple(batch_shape)
  46. )
  47. self._event_shape = (
  48. event_shape
  49. if isinstance(event_shape, tuple)
  50. else tuple(event_shape)
  51. )
  52. super().__init__()
  53. @property
  54. def batch_shape(self):
  55. """Returns batch shape of distribution
  56. Returns:
  57. Sequence[int]: batch shape
  58. """
  59. return self._batch_shape
  60. @property
  61. def event_shape(self):
  62. """Returns event shape of distribution
  63. Returns:
  64. Sequence[int]: event shape
  65. """
  66. return self._event_shape
  67. @property
  68. def mean(self):
  69. """Mean of distribution"""
  70. raise NotImplementedError
  71. @property
  72. def variance(self):
  73. """Variance of distribution"""
  74. raise NotImplementedError
  75. def sample(self, shape=()):
  76. """Sampling from the distribution."""
  77. raise NotImplementedError
  78. def rsample(self, shape=()):
  79. """reparameterized sample"""
  80. raise NotImplementedError
  81. def entropy(self):
  82. """The entropy of the distribution."""
  83. raise NotImplementedError
  84. def kl_divergence(self, other):
  85. """The KL-divergence between self distributions and other."""
  86. raise NotImplementedError
  87. def prob(self, value):
  88. """Probability density/mass function evaluated at value.
  89. Args:
  90. value (Tensor): value which will be evaluated
  91. """
  92. return self.log_prob(value).exp()
  93. def log_prob(self, value):
  94. """Log probability density/mass function."""
  95. raise NotImplementedError
  96. def probs(self, value):
  97. """Probability density/mass function.
  98. Note:
  99. This method will be deprecated in the future, please use `prob`
  100. instead.
  101. """
  102. raise NotImplementedError
  103. def _extend_shape(self, sample_shape):
  104. """compute shape of the sample
  105. Args:
  106. sample_shape (Tensor): sample shape
  107. Returns:
  108. Tensor: generated sample data shape
  109. """
  110. return (
  111. tuple(sample_shape)
  112. + tuple(self._batch_shape)
  113. + tuple(self._event_shape)
  114. )
  115. def _validate_args(self, *args):
  116. """
  117. Argument validation for distribution args
  118. Args:
  119. value (float, list, numpy.ndarray, Tensor)
  120. Raises
  121. ValueError: if one argument is Tensor, all arguments should be Tensor
  122. """
  123. is_variable = False
  124. is_number = False
  125. for arg in args:
  126. if isinstance(arg, (Variable, paddle.pir.Value)):
  127. is_variable = True
  128. else:
  129. is_number = True
  130. if is_variable and is_number:
  131. raise ValueError(
  132. 'if one argument is Tensor, all arguments should be Tensor'
  133. )
  134. return is_variable
  135. def _to_tensor(self, *args):
  136. """
  137. Argument convert args to Tensor
  138. Args:
  139. value (float, list, numpy.ndarray, Tensor)
  140. Returns:
  141. Tensor of args.
  142. """
  143. numpy_args = []
  144. variable_args = []
  145. tmp = 0.0
  146. for arg in args:
  147. if not isinstance(
  148. arg,
  149. (float, list, tuple, np.ndarray, Variable, paddle.pir.Value),
  150. ):
  151. raise TypeError(
  152. f"Type of input args must be float, list, tuple, numpy.ndarray or Tensor, but received type {type(arg)}"
  153. )
  154. if isinstance(arg, paddle.pir.Value):
  155. # pir.Value does not need to be converted to numpy.ndarray, so we skip here
  156. numpy_args.append(arg)
  157. continue
  158. arg_np = np.array(arg)
  159. arg_dtype = arg_np.dtype
  160. if str(arg_dtype) != 'float32':
  161. if str(arg_dtype) != 'float64':
  162. # "assign" op doesn't support float64. if dtype is float64, float32 variable will be generated
  163. # and converted to float64 later using "cast".
  164. warnings.warn(
  165. "data type of argument only support float32 and float64, your argument will be convert to float32."
  166. )
  167. arg_np = arg_np.astype('float32')
  168. # tmp is used to support broadcast, it summarizes shapes of all the args and get the mixed shape.
  169. tmp = tmp + arg_np
  170. numpy_args.append(arg_np)
  171. dtype = tmp.dtype
  172. for arg in numpy_args:
  173. if isinstance(arg, paddle.pir.Value):
  174. # pir.Value does not need to be converted to numpy.ndarray, so we skip here
  175. variable_args.append(arg)
  176. continue
  177. arg_broadcasted, _ = np.broadcast_arrays(arg, tmp)
  178. if in_pir_mode():
  179. arg_variable = paddle.zeros(arg_broadcasted.shape)
  180. else:
  181. arg_variable = paddle.tensor.create_tensor(dtype=dtype)
  182. paddle.assign(arg_broadcasted, arg_variable)
  183. variable_args.append(arg_variable)
  184. return tuple(variable_args)
  185. def _check_values_dtype_in_probs(self, param, value):
  186. """
  187. Log_prob and probs methods have input ``value``, if value's dtype is different from param,
  188. convert value's dtype to be consistent with param's dtype.
  189. Args:
  190. param (Tensor): low and high in Uniform class, loc and scale in Normal class.
  191. value (Tensor): The input tensor.
  192. Returns:
  193. value (Tensor): Change value's dtype if value's dtype is different from param.
  194. """
  195. if in_dynamic_or_pir_mode():
  196. if value.dtype != param.dtype and convert_dtype(value.dtype) in [
  197. 'float32',
  198. 'float64',
  199. ]:
  200. warnings.warn(
  201. "dtype of input 'value' needs to be the same as parameters of distribution class. dtype of 'value' will be converted."
  202. )
  203. return _C_ops.cast(value, param.dtype)
  204. return value
  205. check_variable_and_dtype(
  206. value, 'value', ['float32', 'float64'], 'log_prob'
  207. )
  208. if value.dtype != param.dtype:
  209. warnings.warn(
  210. "dtype of input 'value' needs to be the same as parameters of distribution class. dtype of 'value' will be converted."
  211. )
  212. return paddle.cast(value, dtype=param.dtype)
  213. return value
  214. def _probs_to_logits(self, probs, is_binary=False):
  215. r"""
  216. Converts probabilities into logits. For the binary, probs denotes the
  217. probability of occurrence of the event indexed by `1`. For the
  218. multi-dimensional, values of last axis denote the probabilities of
  219. occurrence of each of the events.
  220. """
  221. return (
  222. (paddle.log(probs) - paddle.log1p(-probs))
  223. if is_binary
  224. else paddle.log(probs)
  225. )
  226. def _logits_to_probs(self, logits, is_binary=False):
  227. r"""
  228. Converts logits into probabilities. For the binary, each value denotes
  229. log odds, whereas for the multi-dimensional case, the values along the
  230. last dimension denote the log probabilities of the events.
  231. """
  232. return (
  233. paddle.nn.functional.sigmoid(logits)
  234. if is_binary
  235. else paddle.nn.functional.softmax(logits, axis=-1)
  236. )