mixture_same_family.py 8.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220
  1. # mypy: allow-untyped-defs
  2. from typing import Optional
  3. import torch
  4. from torch import Tensor
  5. from torch.distributions import Categorical, constraints
  6. from torch.distributions.constraints import MixtureSameFamilyConstraint
  7. from torch.distributions.distribution import Distribution
# Public names exported by ``from <this module> import *``.
__all__ = ["MixtureSameFamily"]
  9. class MixtureSameFamily(Distribution):
  10. r"""
  11. The `MixtureSameFamily` distribution implements a (batch of) mixture
  12. distribution where all component are from different parameterizations of
  13. the same distribution type. It is parameterized by a `Categorical`
  14. "selecting distribution" (over `k` component) and a component
  15. distribution, i.e., a `Distribution` with a rightmost batch shape
  16. (equal to `[k]`) which indexes each (batch of) component.
  17. Examples::
  18. >>> # xdoctest: +SKIP("undefined vars")
  19. >>> # Construct Gaussian Mixture Model in 1D consisting of 5 equally
  20. >>> # weighted normal distributions
  21. >>> mix = D.Categorical(torch.ones(5,))
  22. >>> comp = D.Normal(torch.randn(5,), torch.rand(5,))
  23. >>> gmm = MixtureSameFamily(mix, comp)
  24. >>> # Construct Gaussian Mixture Model in 2D consisting of 5 equally
  25. >>> # weighted bivariate normal distributions
  26. >>> mix = D.Categorical(torch.ones(5,))
  27. >>> comp = D.Independent(D.Normal(
  28. ... torch.randn(5,2), torch.rand(5,2)), 1)
  29. >>> gmm = MixtureSameFamily(mix, comp)
  30. >>> # Construct a batch of 3 Gaussian Mixture Models in 2D each
  31. >>> # consisting of 5 random weighted bivariate normal distributions
  32. >>> mix = D.Categorical(torch.rand(3,5))
  33. >>> comp = D.Independent(D.Normal(
  34. ... torch.randn(3,5,2), torch.rand(3,5,2)), 1)
  35. >>> gmm = MixtureSameFamily(mix, comp)
  36. Args:
  37. mixture_distribution: `torch.distributions.Categorical`-like
  38. instance. Manages the probability of selecting component.
  39. The number of categories must match the rightmost batch
  40. dimension of the `component_distribution`. Must have either
  41. scalar `batch_shape` or `batch_shape` matching
  42. `component_distribution.batch_shape[:-1]`
  43. component_distribution: `torch.distributions.Distribution`-like
  44. instance. Right-most batch dimension indexes component.
  45. """
  46. arg_constraints: dict[str, constraints.Constraint] = {}
  47. has_rsample = False
  48. def __init__(
  49. self,
  50. mixture_distribution: Categorical,
  51. component_distribution: Distribution,
  52. validate_args: Optional[bool] = None,
  53. ) -> None:
  54. self._mixture_distribution = mixture_distribution
  55. self._component_distribution = component_distribution
  56. if not isinstance(self._mixture_distribution, Categorical):
  57. raise ValueError(
  58. " The Mixture distribution needs to be an "
  59. " instance of torch.distributions.Categorical"
  60. )
  61. if not isinstance(self._component_distribution, Distribution):
  62. raise ValueError(
  63. "The Component distribution need to be an "
  64. "instance of torch.distributions.Distribution"
  65. )
  66. # Check that batch size matches
  67. mdbs = self._mixture_distribution.batch_shape
  68. cdbs = self._component_distribution.batch_shape[:-1]
  69. for size1, size2 in zip(reversed(mdbs), reversed(cdbs)):
  70. if size1 != 1 and size2 != 1 and size1 != size2:
  71. raise ValueError(
  72. f"`mixture_distribution.batch_shape` ({mdbs}) is not "
  73. "compatible with `component_distribution."
  74. f"batch_shape`({cdbs})"
  75. )
  76. # Check that the number of mixture component matches
  77. km = self._mixture_distribution.logits.shape[-1]
  78. kc = self._component_distribution.batch_shape[-1]
  79. if km is not None and kc is not None and km != kc:
  80. raise ValueError(
  81. f"`mixture_distribution component` ({km}) does not"
  82. " equal `component_distribution.batch_shape[-1]`"
  83. f" ({kc})"
  84. )
  85. self._num_component = km
  86. event_shape = self._component_distribution.event_shape
  87. self._event_ndims = len(event_shape)
  88. super().__init__(
  89. batch_shape=cdbs, event_shape=event_shape, validate_args=validate_args
  90. )
  91. def expand(self, batch_shape, _instance=None):
  92. batch_shape = torch.Size(batch_shape)
  93. batch_shape_comp = batch_shape + (self._num_component,)
  94. new = self._get_checked_instance(MixtureSameFamily, _instance)
  95. new._component_distribution = self._component_distribution.expand(
  96. batch_shape_comp
  97. )
  98. new._mixture_distribution = self._mixture_distribution.expand(batch_shape)
  99. new._num_component = self._num_component
  100. new._event_ndims = self._event_ndims
  101. event_shape = new._component_distribution.event_shape
  102. super(MixtureSameFamily, new).__init__(
  103. batch_shape=batch_shape, event_shape=event_shape, validate_args=False
  104. )
  105. new._validate_args = self._validate_args
  106. return new
  107. @constraints.dependent_property
  108. def support(self):
  109. return MixtureSameFamilyConstraint(self._component_distribution.support)
  110. @property
  111. def mixture_distribution(self) -> Categorical:
  112. return self._mixture_distribution
  113. @property
  114. def component_distribution(self) -> Distribution:
  115. return self._component_distribution
  116. @property
  117. def mean(self) -> Tensor:
  118. probs = self._pad_mixture_dimensions(self.mixture_distribution.probs)
  119. return torch.sum(
  120. probs * self.component_distribution.mean, dim=-1 - self._event_ndims
  121. ) # [B, E]
  122. @property
  123. def variance(self) -> Tensor:
  124. # Law of total variance: Var(Y) = E[Var(Y|X)] + Var(E[Y|X])
  125. probs = self._pad_mixture_dimensions(self.mixture_distribution.probs)
  126. mean_cond_var = torch.sum(
  127. probs * self.component_distribution.variance, dim=-1 - self._event_ndims
  128. )
  129. var_cond_mean = torch.sum(
  130. probs * (self.component_distribution.mean - self._pad(self.mean)).pow(2.0),
  131. dim=-1 - self._event_ndims,
  132. )
  133. return mean_cond_var + var_cond_mean
  134. def cdf(self, x):
  135. x = self._pad(x)
  136. cdf_x = self.component_distribution.cdf(x)
  137. mix_prob = self.mixture_distribution.probs
  138. return torch.sum(cdf_x * mix_prob, dim=-1)
  139. def log_prob(self, x):
  140. if self._validate_args:
  141. self._validate_sample(x)
  142. x = self._pad(x)
  143. log_prob_x = self.component_distribution.log_prob(x) # [S, B, k]
  144. log_mix_prob = torch.log_softmax(
  145. self.mixture_distribution.logits, dim=-1
  146. ) # [B, k]
  147. return torch.logsumexp(log_prob_x + log_mix_prob, dim=-1) # [S, B]
  148. def sample(self, sample_shape=torch.Size()):
  149. with torch.no_grad():
  150. sample_len = len(sample_shape)
  151. batch_len = len(self.batch_shape)
  152. gather_dim = sample_len + batch_len
  153. es = self.event_shape
  154. # mixture samples [n, B]
  155. mix_sample = self.mixture_distribution.sample(sample_shape)
  156. mix_shape = mix_sample.shape
  157. # component samples [n, B, k, E]
  158. comp_samples = self.component_distribution.sample(sample_shape)
  159. # Gather along the k dimension
  160. mix_sample_r = mix_sample.reshape(
  161. mix_shape + torch.Size([1] * (len(es) + 1))
  162. )
  163. mix_sample_r = mix_sample_r.repeat(
  164. torch.Size([1] * len(mix_shape)) + torch.Size([1]) + es
  165. )
  166. samples = torch.gather(comp_samples, gather_dim, mix_sample_r)
  167. return samples.squeeze(gather_dim)
  168. def _pad(self, x):
  169. return x.unsqueeze(-1 - self._event_ndims)
  170. def _pad_mixture_dimensions(self, x):
  171. dist_batch_ndims = len(self.batch_shape)
  172. cat_batch_ndims = len(self.mixture_distribution.batch_shape)
  173. pad_ndims = 0 if cat_batch_ndims == 1 else dist_batch_ndims - cat_batch_ndims
  174. xs = x.shape
  175. x = x.reshape(
  176. xs[:-1]
  177. + torch.Size(pad_ndims * [1])
  178. + xs[-1:]
  179. + torch.Size(self._event_ndims * [1])
  180. )
  181. return x
  182. def __repr__(self):
  183. args_string = (
  184. f"\n {self.mixture_distribution},\n {self.component_distribution}"
  185. )
  186. return "MixtureSameFamily" + "(" + args_string + ")"