# transformed_distribution.py
  1. # mypy: allow-untyped-defs
  2. from typing import Optional, Union
  3. import torch
  4. from torch import Tensor
  5. from torch.distributions import constraints
  6. from torch.distributions.distribution import Distribution
  7. from torch.distributions.independent import Independent
  8. from torch.distributions.transforms import ComposeTransform, Transform
  9. from torch.distributions.utils import _sum_rightmost
  10. from torch.types import _size
  11. __all__ = ["TransformedDistribution"]
  12. class TransformedDistribution(Distribution):
  13. r"""
  14. Extension of the Distribution class, which applies a sequence of Transforms
  15. to a base distribution. Let f be the composition of transforms applied::
  16. X ~ BaseDistribution
  17. Y = f(X) ~ TransformedDistribution(BaseDistribution, f)
  18. log p(Y) = log p(X) + log |det (dX/dY)|
  19. Note that the ``.event_shape`` of a :class:`TransformedDistribution` is the
  20. maximum shape of its base distribution and its transforms, since transforms
  21. can introduce correlations among events.
  22. An example for the usage of :class:`TransformedDistribution` would be::
  23. # Building a Logistic Distribution
  24. # X ~ Uniform(0, 1)
  25. # f = a + b * logit(X)
  26. # Y ~ f(X) ~ Logistic(a, b)
  27. base_distribution = Uniform(0, 1)
  28. transforms = [SigmoidTransform().inv, AffineTransform(loc=a, scale=b)]
  29. logistic = TransformedDistribution(base_distribution, transforms)
  30. For more examples, please look at the implementations of
  31. :class:`~torch.distributions.gumbel.Gumbel`,
  32. :class:`~torch.distributions.half_cauchy.HalfCauchy`,
  33. :class:`~torch.distributions.half_normal.HalfNormal`,
  34. :class:`~torch.distributions.log_normal.LogNormal`,
  35. :class:`~torch.distributions.pareto.Pareto`,
  36. :class:`~torch.distributions.weibull.Weibull`,
  37. :class:`~torch.distributions.relaxed_bernoulli.RelaxedBernoulli` and
  38. :class:`~torch.distributions.relaxed_categorical.RelaxedOneHotCategorical`
  39. """
  40. arg_constraints: dict[str, constraints.Constraint] = {}
  41. def __init__(
  42. self,
  43. base_distribution: Distribution,
  44. transforms: Union[Transform, list[Transform]],
  45. validate_args: Optional[bool] = None,
  46. ) -> None:
  47. if isinstance(transforms, Transform):
  48. self.transforms = [
  49. transforms,
  50. ]
  51. elif isinstance(transforms, list):
  52. if not all(isinstance(t, Transform) for t in transforms):
  53. raise ValueError(
  54. "transforms must be a Transform or a list of Transforms"
  55. )
  56. self.transforms = transforms
  57. else:
  58. raise ValueError(
  59. f"transforms must be a Transform or list, but was {transforms}"
  60. )
  61. # Reshape base_distribution according to transforms.
  62. base_shape = base_distribution.batch_shape + base_distribution.event_shape
  63. base_event_dim = len(base_distribution.event_shape)
  64. transform = ComposeTransform(self.transforms)
  65. if len(base_shape) < transform.domain.event_dim:
  66. raise ValueError(
  67. f"base_distribution needs to have shape with size at least {transform.domain.event_dim}, but got {base_shape}."
  68. )
  69. forward_shape = transform.forward_shape(base_shape)
  70. expanded_base_shape = transform.inverse_shape(forward_shape)
  71. if base_shape != expanded_base_shape:
  72. base_batch_shape = expanded_base_shape[
  73. : len(expanded_base_shape) - base_event_dim
  74. ]
  75. base_distribution = base_distribution.expand(base_batch_shape)
  76. reinterpreted_batch_ndims = transform.domain.event_dim - base_event_dim
  77. if reinterpreted_batch_ndims > 0:
  78. base_distribution = Independent(
  79. base_distribution, reinterpreted_batch_ndims
  80. )
  81. self.base_dist = base_distribution
  82. # Compute shapes.
  83. transform_change_in_event_dim = (
  84. transform.codomain.event_dim - transform.domain.event_dim
  85. )
  86. event_dim = max(
  87. transform.codomain.event_dim, # the transform is coupled
  88. base_event_dim + transform_change_in_event_dim, # the base dist is coupled
  89. )
  90. assert len(forward_shape) >= event_dim
  91. cut = len(forward_shape) - event_dim
  92. batch_shape = forward_shape[:cut]
  93. event_shape = forward_shape[cut:]
  94. super().__init__(batch_shape, event_shape, validate_args=validate_args)
  95. def expand(self, batch_shape, _instance=None):
  96. new = self._get_checked_instance(TransformedDistribution, _instance)
  97. batch_shape = torch.Size(batch_shape)
  98. shape = batch_shape + self.event_shape
  99. for t in reversed(self.transforms):
  100. shape = t.inverse_shape(shape)
  101. base_batch_shape = shape[: len(shape) - len(self.base_dist.event_shape)]
  102. new.base_dist = self.base_dist.expand(base_batch_shape)
  103. new.transforms = self.transforms
  104. super(TransformedDistribution, new).__init__(
  105. batch_shape, self.event_shape, validate_args=False
  106. )
  107. new._validate_args = self._validate_args
  108. return new
  109. @constraints.dependent_property(is_discrete=False)
  110. def support(self):
  111. if not self.transforms:
  112. return self.base_dist.support
  113. support = self.transforms[-1].codomain
  114. if len(self.event_shape) > support.event_dim:
  115. support = constraints.independent(
  116. support, len(self.event_shape) - support.event_dim
  117. )
  118. return support
  119. @property
  120. def has_rsample(self) -> bool: # type: ignore[override]
  121. return self.base_dist.has_rsample
  122. def sample(self, sample_shape=torch.Size()):
  123. """
  124. Generates a sample_shape shaped sample or sample_shape shaped batch of
  125. samples if the distribution parameters are batched. Samples first from
  126. base distribution and applies `transform()` for every transform in the
  127. list.
  128. """
  129. with torch.no_grad():
  130. x = self.base_dist.sample(sample_shape)
  131. for transform in self.transforms:
  132. x = transform(x)
  133. return x
  134. def rsample(self, sample_shape: _size = torch.Size()) -> Tensor:
  135. """
  136. Generates a sample_shape shaped reparameterized sample or sample_shape
  137. shaped batch of reparameterized samples if the distribution parameters
  138. are batched. Samples first from base distribution and applies
  139. `transform()` for every transform in the list.
  140. """
  141. x = self.base_dist.rsample(sample_shape)
  142. for transform in self.transforms:
  143. x = transform(x)
  144. return x
  145. def log_prob(self, value):
  146. """
  147. Scores the sample by inverting the transform(s) and computing the score
  148. using the score of the base distribution and the log abs det jacobian.
  149. """
  150. if self._validate_args:
  151. self._validate_sample(value)
  152. event_dim = len(self.event_shape)
  153. log_prob: Union[Tensor, float] = 0.0
  154. y = value
  155. for transform in reversed(self.transforms):
  156. x = transform.inv(y)
  157. event_dim += transform.domain.event_dim - transform.codomain.event_dim
  158. log_prob = log_prob - _sum_rightmost(
  159. transform.log_abs_det_jacobian(x, y),
  160. event_dim - transform.domain.event_dim,
  161. )
  162. y = x
  163. log_prob = log_prob + _sum_rightmost(
  164. self.base_dist.log_prob(y), event_dim - len(self.base_dist.event_shape)
  165. )
  166. return log_prob
  167. def _monotonize_cdf(self, value):
  168. """
  169. This conditionally flips ``value -> 1-value`` to ensure :meth:`cdf` is
  170. monotone increasing.
  171. """
  172. sign = 1
  173. for transform in self.transforms:
  174. sign = sign * transform.sign
  175. if isinstance(sign, int) and sign == 1:
  176. return value
  177. return sign * (value - 0.5) + 0.5
  178. def cdf(self, value):
  179. """
  180. Computes the cumulative distribution function by inverting the
  181. transform(s) and computing the score of the base distribution.
  182. """
  183. for transform in self.transforms[::-1]:
  184. value = transform.inv(value)
  185. if self._validate_args:
  186. self.base_dist._validate_sample(value)
  187. value = self.base_dist.cdf(value)
  188. value = self._monotonize_cdf(value)
  189. return value
  190. def icdf(self, value):
  191. """
  192. Computes the inverse cumulative distribution function using
  193. transform(s) and computing the score of the base distribution.
  194. """
  195. value = self._monotonize_cdf(value)
  196. value = self.base_dist.icdf(value)
  197. for transform in self.transforms:
  198. value = transform(value)
  199. return value