multivariate_normal.py 11 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269
  1. # mypy: allow-untyped-defs
  2. import math
  3. from typing import Optional
  4. import torch
  5. from torch import Tensor
  6. from torch.distributions import constraints
  7. from torch.distributions.distribution import Distribution
  8. from torch.distributions.utils import _standard_normal, lazy_property
  9. from torch.types import _size
# Public API of this module: only the distribution class is exported; the
# underscore-prefixed helpers defined below are internal.
__all__ = ["MultivariateNormal"]
  11. def _batch_mv(bmat, bvec):
  12. r"""
  13. Performs a batched matrix-vector product, with compatible but different batch shapes.
  14. This function takes as input `bmat`, containing :math:`n \times n` matrices, and
  15. `bvec`, containing length :math:`n` vectors.
  16. Both `bmat` and `bvec` may have any number of leading dimensions, which correspond
  17. to a batch shape. They are not necessarily assumed to have the same batch shape,
  18. just ones which can be broadcasted.
  19. """
  20. return torch.matmul(bmat, bvec.unsqueeze(-1)).squeeze(-1)
  21. def _batch_mahalanobis(bL, bx):
  22. r"""
  23. Computes the squared Mahalanobis distance :math:`\mathbf{x}^\top\mathbf{M}^{-1}\mathbf{x}`
  24. for a factored :math:`\mathbf{M} = \mathbf{L}\mathbf{L}^\top`.
  25. Accepts batches for both bL and bx. They are not necessarily assumed to have the same batch
  26. shape, but `bL` one should be able to broadcasted to `bx` one.
  27. """
  28. n = bx.size(-1)
  29. bx_batch_shape = bx.shape[:-1]
  30. # Assume that bL.shape = (i, 1, n, n), bx.shape = (..., i, j, n),
  31. # we are going to make bx have shape (..., 1, j, i, 1, n) to apply batched tri.solve
  32. bx_batch_dims = len(bx_batch_shape)
  33. bL_batch_dims = bL.dim() - 2
  34. outer_batch_dims = bx_batch_dims - bL_batch_dims
  35. old_batch_dims = outer_batch_dims + bL_batch_dims
  36. new_batch_dims = outer_batch_dims + 2 * bL_batch_dims
  37. # Reshape bx with the shape (..., 1, i, j, 1, n)
  38. bx_new_shape = bx.shape[:outer_batch_dims]
  39. for sL, sx in zip(bL.shape[:-2], bx.shape[outer_batch_dims:-1]):
  40. bx_new_shape += (sx // sL, sL)
  41. bx_new_shape += (n,)
  42. bx = bx.reshape(bx_new_shape)
  43. # Permute bx to make it have shape (..., 1, j, i, 1, n)
  44. permute_dims = (
  45. list(range(outer_batch_dims))
  46. + list(range(outer_batch_dims, new_batch_dims, 2))
  47. + list(range(outer_batch_dims + 1, new_batch_dims, 2))
  48. + [new_batch_dims]
  49. )
  50. bx = bx.permute(permute_dims)
  51. flat_L = bL.reshape(-1, n, n) # shape = b x n x n
  52. flat_x = bx.reshape(-1, flat_L.size(0), n) # shape = c x b x n
  53. flat_x_swap = flat_x.permute(1, 2, 0) # shape = b x n x c
  54. M_swap = (
  55. torch.linalg.solve_triangular(flat_L, flat_x_swap, upper=False).pow(2).sum(-2)
  56. ) # shape = b x c
  57. M = M_swap.t() # shape = c x b
  58. # Now we revert the above reshape and permute operators.
  59. permuted_M = M.reshape(bx.shape[:-1]) # shape = (..., 1, j, i, 1)
  60. permute_inv_dims = list(range(outer_batch_dims))
  61. for i in range(bL_batch_dims):
  62. permute_inv_dims += [outer_batch_dims + i, old_batch_dims + i]
  63. reshaped_M = permuted_M.permute(permute_inv_dims) # shape = (..., 1, i, j, 1)
  64. return reshaped_M.reshape(bx_batch_shape)
  65. def _precision_to_scale_tril(P):
  66. # Ref: https://nbviewer.jupyter.org/gist/fehiepsi/5ef8e09e61604f10607380467eb82006#Precision-to-scale_tril
  67. Lf = torch.linalg.cholesky(torch.flip(P, (-2, -1)))
  68. L_inv = torch.transpose(torch.flip(Lf, (-2, -1)), -2, -1)
  69. Id = torch.eye(P.shape[-1], dtype=P.dtype, device=P.device)
  70. L = torch.linalg.solve_triangular(L_inv, Id, upper=False)
  71. return L
  72. class MultivariateNormal(Distribution):
  73. r"""
  74. Creates a multivariate normal (also called Gaussian) distribution
  75. parameterized by a mean vector and a covariance matrix.
  76. The multivariate normal distribution can be parameterized either
  77. in terms of a positive definite covariance matrix :math:`\mathbf{\Sigma}`
  78. or a positive definite precision matrix :math:`\mathbf{\Sigma}^{-1}`
  79. or a lower-triangular matrix :math:`\mathbf{L}` with positive-valued
  80. diagonal entries, such that
  81. :math:`\mathbf{\Sigma} = \mathbf{L}\mathbf{L}^\top`. This triangular matrix
  82. can be obtained via e.g. Cholesky decomposition of the covariance.
  83. Example:
  84. >>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_LAPACK)
  85. >>> # xdoctest: +IGNORE_WANT("non-deterministic")
  86. >>> m = MultivariateNormal(torch.zeros(2), torch.eye(2))
  87. >>> m.sample() # normally distributed with mean=`[0,0]` and covariance_matrix=`I`
  88. tensor([-0.2102, -0.5429])
  89. Args:
  90. loc (Tensor): mean of the distribution
  91. covariance_matrix (Tensor): positive-definite covariance matrix
  92. precision_matrix (Tensor): positive-definite precision matrix
  93. scale_tril (Tensor): lower-triangular factor of covariance, with positive-valued diagonal
  94. Note:
  95. Only one of :attr:`covariance_matrix` or :attr:`precision_matrix` or
  96. :attr:`scale_tril` can be specified.
  97. Using :attr:`scale_tril` will be more efficient: all computations internally
  98. are based on :attr:`scale_tril`. If :attr:`covariance_matrix` or
  99. :attr:`precision_matrix` is passed instead, it is only used to compute
  100. the corresponding lower triangular matrices using a Cholesky decomposition.
  101. """
  102. arg_constraints = {
  103. "loc": constraints.real_vector,
  104. "covariance_matrix": constraints.positive_definite,
  105. "precision_matrix": constraints.positive_definite,
  106. "scale_tril": constraints.lower_cholesky,
  107. }
  108. support = constraints.real_vector
  109. has_rsample = True
  110. def __init__(
  111. self,
  112. loc: Tensor,
  113. covariance_matrix: Optional[Tensor] = None,
  114. precision_matrix: Optional[Tensor] = None,
  115. scale_tril: Optional[Tensor] = None,
  116. validate_args: Optional[bool] = None,
  117. ) -> None:
  118. if loc.dim() < 1:
  119. raise ValueError("loc must be at least one-dimensional.")
  120. if (covariance_matrix is not None) + (scale_tril is not None) + (
  121. precision_matrix is not None
  122. ) != 1:
  123. raise ValueError(
  124. "Exactly one of covariance_matrix or precision_matrix or scale_tril may be specified."
  125. )
  126. if scale_tril is not None:
  127. if scale_tril.dim() < 2:
  128. raise ValueError(
  129. "scale_tril matrix must be at least two-dimensional, "
  130. "with optional leading batch dimensions"
  131. )
  132. batch_shape = torch.broadcast_shapes(scale_tril.shape[:-2], loc.shape[:-1])
  133. self.scale_tril = scale_tril.expand(batch_shape + (-1, -1))
  134. elif covariance_matrix is not None:
  135. if covariance_matrix.dim() < 2:
  136. raise ValueError(
  137. "covariance_matrix must be at least two-dimensional, "
  138. "with optional leading batch dimensions"
  139. )
  140. batch_shape = torch.broadcast_shapes(
  141. covariance_matrix.shape[:-2], loc.shape[:-1]
  142. )
  143. self.covariance_matrix = covariance_matrix.expand(batch_shape + (-1, -1))
  144. else:
  145. assert precision_matrix is not None # helps mypy
  146. if precision_matrix.dim() < 2:
  147. raise ValueError(
  148. "precision_matrix must be at least two-dimensional, "
  149. "with optional leading batch dimensions"
  150. )
  151. batch_shape = torch.broadcast_shapes(
  152. precision_matrix.shape[:-2], loc.shape[:-1]
  153. )
  154. self.precision_matrix = precision_matrix.expand(batch_shape + (-1, -1))
  155. self.loc = loc.expand(batch_shape + (-1,))
  156. event_shape = self.loc.shape[-1:]
  157. super().__init__(batch_shape, event_shape, validate_args=validate_args)
  158. if scale_tril is not None:
  159. self._unbroadcasted_scale_tril = scale_tril
  160. elif covariance_matrix is not None:
  161. self._unbroadcasted_scale_tril = torch.linalg.cholesky(covariance_matrix)
  162. else: # precision_matrix is not None
  163. self._unbroadcasted_scale_tril = _precision_to_scale_tril(precision_matrix)
  164. def expand(self, batch_shape, _instance=None):
  165. new = self._get_checked_instance(MultivariateNormal, _instance)
  166. batch_shape = torch.Size(batch_shape)
  167. loc_shape = batch_shape + self.event_shape
  168. cov_shape = batch_shape + self.event_shape + self.event_shape
  169. new.loc = self.loc.expand(loc_shape)
  170. new._unbroadcasted_scale_tril = self._unbroadcasted_scale_tril
  171. if "covariance_matrix" in self.__dict__:
  172. new.covariance_matrix = self.covariance_matrix.expand(cov_shape)
  173. if "scale_tril" in self.__dict__:
  174. new.scale_tril = self.scale_tril.expand(cov_shape)
  175. if "precision_matrix" in self.__dict__:
  176. new.precision_matrix = self.precision_matrix.expand(cov_shape)
  177. super(MultivariateNormal, new).__init__(
  178. batch_shape, self.event_shape, validate_args=False
  179. )
  180. new._validate_args = self._validate_args
  181. return new
  182. @lazy_property
  183. def scale_tril(self) -> Tensor:
  184. return self._unbroadcasted_scale_tril.expand(
  185. self._batch_shape + self._event_shape + self._event_shape
  186. )
  187. @lazy_property
  188. def covariance_matrix(self) -> Tensor:
  189. return torch.matmul(
  190. self._unbroadcasted_scale_tril, self._unbroadcasted_scale_tril.mT
  191. ).expand(self._batch_shape + self._event_shape + self._event_shape)
  192. @lazy_property
  193. def precision_matrix(self) -> Tensor:
  194. return torch.cholesky_inverse(self._unbroadcasted_scale_tril).expand(
  195. self._batch_shape + self._event_shape + self._event_shape
  196. )
  197. @property
  198. def mean(self) -> Tensor:
  199. return self.loc
  200. @property
  201. def mode(self) -> Tensor:
  202. return self.loc
  203. @property
  204. def variance(self) -> Tensor:
  205. return (
  206. self._unbroadcasted_scale_tril.pow(2)
  207. .sum(-1)
  208. .expand(self._batch_shape + self._event_shape)
  209. )
  210. def rsample(self, sample_shape: _size = torch.Size()) -> Tensor:
  211. shape = self._extended_shape(sample_shape)
  212. eps = _standard_normal(shape, dtype=self.loc.dtype, device=self.loc.device)
  213. return self.loc + _batch_mv(self._unbroadcasted_scale_tril, eps)
  214. def log_prob(self, value):
  215. if self._validate_args:
  216. self._validate_sample(value)
  217. diff = value - self.loc
  218. M = _batch_mahalanobis(self._unbroadcasted_scale_tril, diff)
  219. half_log_det = (
  220. self._unbroadcasted_scale_tril.diagonal(dim1=-2, dim2=-1).log().sum(-1)
  221. )
  222. return -0.5 * (self._event_shape[0] * math.log(2 * math.pi) + M) - half_log_det
  223. def entropy(self):
  224. half_log_det = (
  225. self._unbroadcasted_scale_tril.diagonal(dim1=-2, dim2=-1).log().sum(-1)
  226. )
  227. H = 0.5 * self._event_shape[0] * (1.0 + math.log(2 * math.pi)) + half_log_det
  228. if len(self._batch_shape) == 0:
  229. return H
  230. else:
  231. return H.expand(self._batch_shape)