| 12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091 |
- # mypy: allow-untyped-defs
- from typing import Optional, Union
- import torch
- from torch import Tensor
- from torch.distributions import constraints
- from torch.distributions.gamma import Gamma
- from torch.distributions.transformed_distribution import TransformedDistribution
- from torch.distributions.transforms import PowerTransform
- __all__ = ["InverseGamma"]
class InverseGamma(TransformedDistribution):
    r"""
    Inverse gamma distribution parameterized by :attr:`concentration` and
    :attr:`rate`.

    The distribution is built as the reciprocal of a Gamma random variable::

        X ~ Gamma(concentration, rate)
        Y = 1 / X ~ InverseGamma(concentration, rate)

    Example::

        >>> # xdoctest: +IGNORE_WANT("non-deterministic")
        >>> m = InverseGamma(torch.tensor([2.0]), torch.tensor([3.0]))
        >>> m.sample()
        tensor([ 1.2953])

    Args:
        concentration (float or Tensor): shape parameter of the distribution
            (often referred to as alpha)
        rate (float or Tensor): rate of the underlying Gamma distribution
            (often referred to as beta); because Y = 1 / X, it acts as the
            scale of the inverse gamma variable Y
    """

    # Both parameters must be strictly positive, as are all samples.
    arg_constraints = {
        "concentration": constraints.positive,
        "rate": constraints.positive,
    }
    support = constraints.positive
    has_rsample = True
    base_dist: Gamma

    def __init__(
        self,
        concentration: Union[Tensor, float],
        rate: Union[Tensor, float],
        validate_args: Optional[bool] = None,
    ) -> None:
        """Wrap ``Gamma(concentration, rate)`` in a power transform with exponent -1."""
        gamma = Gamma(concentration, rate, validate_args=validate_args)
        # Build the scalar exponent from ``rate`` so that it inherits the
        # dtype and device of the distribution parameters.
        exponent = gamma.rate.new_ones(()).neg()
        reciprocal = PowerTransform(exponent)
        super().__init__(gamma, reciprocal, validate_args=validate_args)

    def expand(self, batch_shape, _instance=None):
        """Expand to ``batch_shape`` by delegating to the parent class.

        The instance to populate is resolved with ``_get_checked_instance``
        so that the result is an :class:`InverseGamma`.
        """
        target = self._get_checked_instance(InverseGamma, _instance)
        return super().expand(batch_shape, _instance=target)

    @property
    def concentration(self) -> Tensor:
        """Concentration (alpha) of the underlying Gamma distribution."""
        return self.base_dist.concentration

    @property
    def rate(self) -> Tensor:
        """Rate (beta) of the underlying Gamma distribution."""
        return self.base_dist.rate

    @property
    def mean(self) -> Tensor:
        """``rate / (concentration - 1)`` where ``concentration > 1``, else ``inf``."""
        alpha = self.concentration
        finite_mean = self.rate / (alpha - 1)
        return torch.where(alpha > 1, finite_mean, torch.inf)

    @property
    def mode(self) -> Tensor:
        """``rate / (concentration + 1)``."""
        return self.rate / (self.concentration + 1)

    @property
    def variance(self) -> Tensor:
        """``rate^2 / ((concentration - 1)^2 * (concentration - 2))`` where
        ``concentration > 2``, else ``inf``."""
        alpha = self.concentration
        denominator = (alpha - 1).square() * (alpha - 2)
        finite_variance = self.rate.square() / denominator
        return torch.where(alpha > 2, finite_variance, torch.inf)

    def entropy(self):
        """Return ``alpha + log(rate) + lgamma(alpha) - (1 + alpha) * digamma(alpha)``."""
        alpha = self.concentration
        log_rate = torch.log(self.rate)
        return (
            alpha
            + log_rate
            + torch.lgamma(alpha)
            - (1 + alpha) * torch.digamma(alpha)
        )
|