independent.py 4.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137
  1. # mypy: allow-untyped-defs
  2. from typing import Generic, Optional, TypeVar
  3. import torch
  4. from torch import Size, Tensor
  5. from torch.distributions import constraints
  6. from torch.distributions.distribution import Distribution
  7. from torch.distributions.utils import _sum_rightmost
  8. from torch.types import _size
  9. __all__ = ["Independent"]
  10. D = TypeVar("D", bound=Distribution)
  11. class Independent(Distribution, Generic[D]):
  12. r"""
  13. Reinterprets some of the batch dims of a distribution as event dims.
  14. This is mainly useful for changing the shape of the result of
  15. :meth:`log_prob`. For example to create a diagonal Normal distribution with
  16. the same shape as a Multivariate Normal distribution (so they are
  17. interchangeable), you can::
  18. >>> from torch.distributions.multivariate_normal import MultivariateNormal
  19. >>> from torch.distributions.normal import Normal
  20. >>> loc = torch.zeros(3)
  21. >>> scale = torch.ones(3)
  22. >>> mvn = MultivariateNormal(loc, scale_tril=torch.diag(scale))
  23. >>> [mvn.batch_shape, mvn.event_shape]
  24. [torch.Size([]), torch.Size([3])]
  25. >>> normal = Normal(loc, scale)
  26. >>> [normal.batch_shape, normal.event_shape]
  27. [torch.Size([3]), torch.Size([])]
  28. >>> diagn = Independent(normal, 1)
  29. >>> [diagn.batch_shape, diagn.event_shape]
  30. [torch.Size([]), torch.Size([3])]
  31. Args:
  32. base_distribution (torch.distributions.distribution.Distribution): a
  33. base distribution
  34. reinterpreted_batch_ndims (int): the number of batch dims to
  35. reinterpret as event dims
  36. """
  37. arg_constraints: dict[str, constraints.Constraint] = {}
  38. base_dist: D
  39. def __init__(
  40. self,
  41. base_distribution: D,
  42. reinterpreted_batch_ndims: int,
  43. validate_args: Optional[bool] = None,
  44. ) -> None:
  45. if reinterpreted_batch_ndims > len(base_distribution.batch_shape):
  46. raise ValueError(
  47. "Expected reinterpreted_batch_ndims <= len(base_distribution.batch_shape), "
  48. f"actual {reinterpreted_batch_ndims} vs {len(base_distribution.batch_shape)}"
  49. )
  50. shape: Size = base_distribution.batch_shape + base_distribution.event_shape
  51. event_dim: int = reinterpreted_batch_ndims + len(base_distribution.event_shape)
  52. batch_shape = shape[: len(shape) - event_dim]
  53. event_shape = shape[len(shape) - event_dim :]
  54. self.base_dist = base_distribution
  55. self.reinterpreted_batch_ndims = reinterpreted_batch_ndims
  56. super().__init__(batch_shape, event_shape, validate_args=validate_args)
  57. def expand(self, batch_shape, _instance=None):
  58. new = self._get_checked_instance(Independent, _instance)
  59. batch_shape = torch.Size(batch_shape)
  60. new.base_dist = self.base_dist.expand(
  61. batch_shape + self.event_shape[: self.reinterpreted_batch_ndims]
  62. )
  63. new.reinterpreted_batch_ndims = self.reinterpreted_batch_ndims
  64. super(Independent, new).__init__(
  65. batch_shape, self.event_shape, validate_args=False
  66. )
  67. new._validate_args = self._validate_args
  68. return new
  69. @property
  70. def has_rsample(self) -> bool: # type: ignore[override]
  71. return self.base_dist.has_rsample
  72. @property
  73. def has_enumerate_support(self) -> bool: # type: ignore[override]
  74. if self.reinterpreted_batch_ndims > 0:
  75. return False
  76. return self.base_dist.has_enumerate_support
  77. @constraints.dependent_property
  78. def support(self):
  79. result = self.base_dist.support
  80. if self.reinterpreted_batch_ndims:
  81. result = constraints.independent(result, self.reinterpreted_batch_ndims)
  82. return result
  83. @property
  84. def mean(self) -> Tensor:
  85. return self.base_dist.mean
  86. @property
  87. def mode(self) -> Tensor:
  88. return self.base_dist.mode
  89. @property
  90. def variance(self) -> Tensor:
  91. return self.base_dist.variance
  92. def sample(self, sample_shape=torch.Size()) -> Tensor:
  93. return self.base_dist.sample(sample_shape)
  94. def rsample(self, sample_shape: _size = torch.Size()) -> Tensor:
  95. return self.base_dist.rsample(sample_shape)
  96. def log_prob(self, value):
  97. log_prob = self.base_dist.log_prob(value)
  98. return _sum_rightmost(log_prob, self.reinterpreted_batch_ndims)
  99. def entropy(self):
  100. entropy = self.base_dist.entropy()
  101. return _sum_rightmost(entropy, self.reinterpreted_batch_ndims)
  102. def enumerate_support(self, expand=True):
  103. if self.reinterpreted_batch_ndims > 0:
  104. raise NotImplementedError(
  105. "Enumeration over cartesian product is not implemented"
  106. )
  107. return self.base_dist.enumerate_support(expand=expand)
  108. def __repr__(self):
  109. return (
  110. self.__class__.__name__
  111. + f"({self.base_dist}, {self.reinterpreted_batch_ndims})"
  112. )