studentT.py 4.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127
  1. # mypy: allow-untyped-defs
  2. import math
  3. from typing import Optional, Union
  4. import torch
  5. from torch import inf, nan, Tensor
  6. from torch.distributions import Chi2, constraints
  7. from torch.distributions.distribution import Distribution
  8. from torch.distributions.utils import _standard_normal, broadcast_all
  9. from torch.types import _size
  10. __all__ = ["StudentT"]
  11. class StudentT(Distribution):
  12. r"""
  13. Creates a Student's t-distribution parameterized by degree of
  14. freedom :attr:`df`, mean :attr:`loc` and scale :attr:`scale`.
  15. Example::
  16. >>> # xdoctest: +IGNORE_WANT("non-deterministic")
  17. >>> m = StudentT(torch.tensor([2.0]))
  18. >>> m.sample() # Student's t-distributed with degrees of freedom=2
  19. tensor([ 0.1046])
  20. Args:
  21. df (float or Tensor): degrees of freedom
  22. loc (float or Tensor): mean of the distribution
  23. scale (float or Tensor): scale of the distribution
  24. """
  25. arg_constraints = {
  26. "df": constraints.positive,
  27. "loc": constraints.real,
  28. "scale": constraints.positive,
  29. }
  30. support = constraints.real
  31. has_rsample = True
  32. @property
  33. def mean(self) -> Tensor:
  34. m = self.loc.clone(memory_format=torch.contiguous_format)
  35. m[self.df <= 1] = nan
  36. return m
  37. @property
  38. def mode(self) -> Tensor:
  39. return self.loc
  40. @property
  41. def variance(self) -> Tensor:
  42. m = self.df.clone(memory_format=torch.contiguous_format)
  43. m[self.df > 2] = (
  44. self.scale[self.df > 2].pow(2)
  45. * self.df[self.df > 2]
  46. / (self.df[self.df > 2] - 2)
  47. )
  48. m[(self.df <= 2) & (self.df > 1)] = inf
  49. m[self.df <= 1] = nan
  50. return m
  51. def __init__(
  52. self,
  53. df: Union[Tensor, float],
  54. loc: Union[Tensor, float] = 0.0,
  55. scale: Union[Tensor, float] = 1.0,
  56. validate_args: Optional[bool] = None,
  57. ) -> None:
  58. self.df, self.loc, self.scale = broadcast_all(df, loc, scale)
  59. self._chi2 = Chi2(self.df)
  60. batch_shape = self.df.size()
  61. super().__init__(batch_shape, validate_args=validate_args)
  62. def expand(self, batch_shape, _instance=None):
  63. new = self._get_checked_instance(StudentT, _instance)
  64. batch_shape = torch.Size(batch_shape)
  65. new.df = self.df.expand(batch_shape)
  66. new.loc = self.loc.expand(batch_shape)
  67. new.scale = self.scale.expand(batch_shape)
  68. new._chi2 = self._chi2.expand(batch_shape)
  69. super(StudentT, new).__init__(batch_shape, validate_args=False)
  70. new._validate_args = self._validate_args
  71. return new
  72. def rsample(self, sample_shape: _size = torch.Size()) -> Tensor:
  73. # NOTE: This does not agree with scipy implementation as much as other distributions.
  74. # (see https://github.com/fritzo/notebooks/blob/master/debug-student-t.ipynb). Using DoubleTensor
  75. # parameters seems to help.
  76. # X ~ Normal(0, 1)
  77. # Z ~ Chi2(df)
  78. # Y = X / sqrt(Z / df) ~ StudentT(df)
  79. shape = self._extended_shape(sample_shape)
  80. X = _standard_normal(shape, dtype=self.df.dtype, device=self.df.device)
  81. Z = self._chi2.rsample(sample_shape)
  82. Y = X * torch.rsqrt(Z / self.df)
  83. return self.loc + self.scale * Y
  84. def log_prob(self, value):
  85. if self._validate_args:
  86. self._validate_sample(value)
  87. y = (value - self.loc) / self.scale
  88. Z = (
  89. self.scale.log()
  90. + 0.5 * self.df.log()
  91. + 0.5 * math.log(math.pi)
  92. + torch.lgamma(0.5 * self.df)
  93. - torch.lgamma(0.5 * (self.df + 1.0))
  94. )
  95. return -0.5 * (self.df + 1.0) * torch.log1p(y**2.0 / self.df) - Z
  96. def entropy(self):
  97. lbeta = (
  98. torch.lgamma(0.5 * self.df)
  99. + math.lgamma(0.5)
  100. - torch.lgamma(0.5 * (self.df + 1))
  101. )
  102. return (
  103. self.scale.log()
  104. + 0.5
  105. * (self.df + 1)
  106. * (torch.digamma(0.5 * (self.df + 1)) - torch.digamma(0.5 * self.df))
  107. + 0.5 * self.df.log()
  108. + lbeta
  109. )