studentT.py 4.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128
  1. # mypy: allow-untyped-defs
  2. import math
  3. from typing import Optional, Union
  4. import torch
  5. from torch import inf, nan, Tensor
  6. from torch.distributions import Chi2, constraints
  7. from torch.distributions.distribution import Distribution
  8. from torch.distributions.utils import _standard_normal, broadcast_all
  9. from torch.types import _size
  10. __all__ = ["StudentT"]
  class StudentT(Distribution):
      r"""
      Creates a Student's t-distribution parameterized by degree of
      freedom :attr:`df`, mean :attr:`loc` and scale :attr:`scale`.

      :attr:`loc` is reported as the mean only where ``df > 1``; see
      :attr:`mean` and :attr:`variance` for the values returned elsewhere.

      Example::

          >>> # xdoctest: +IGNORE_WANT("non-deterministic")
          >>> m = StudentT(torch.tensor([2.0]))
          >>> m.sample()  # Student's t-distributed with degrees of freedom=2
          tensor([ 0.1046])

      Args:
          df (float or Tensor): degrees of freedom
          loc (float or Tensor): mean of the distribution
          scale (float or Tensor): scale of the distribution
          validate_args (bool, optional): forwarded both to the base
              ``Distribution`` and to the internal ``Chi2`` helper.
      """

      # pyrefly: ignore [bad-override]
      arg_constraints = {
          "df": constraints.positive,
          "loc": constraints.real,
          "scale": constraints.positive,
      }
      support = constraints.real
      has_rsample = True

      @property
      def mean(self) -> Tensor:
          """Return ``loc``, with entries where ``df <= 1`` replaced by ``nan``."""
          # Clone so the masked write below never touches ``self.loc`` (which
          # may be an expanded, non-contiguous view).
          m = self.loc.clone(memory_format=torch.contiguous_format)
          m[self.df <= 1] = nan
          return m

      @property
      def mode(self) -> Tensor:
          """Return ``loc``."""
          return self.loc

      @property
      def variance(self) -> Tensor:
          """Return the variance per batch element.

          * ``df > 2``: ``scale**2 * df / (df - 2)``
          * ``1 < df <= 2``: ``inf``
          * ``df <= 1``: ``nan``
          """
          df = self.df
          # Compute the mask and the gathered ``df`` values once instead of
          # re-evaluating ``self.df > 2`` for every use.
          finite = df > 2
          df_finite = df[finite]
          m = df.clone(memory_format=torch.contiguous_format)
          m[finite] = self.scale[finite].pow(2) * df_finite / (df_finite - 2)
          m[(df <= 2) & (df > 1)] = inf
          m[df <= 1] = nan
          return m

      def __init__(
          self,
          df: Union[Tensor, float],
          loc: Union[Tensor, float] = 0.0,
          scale: Union[Tensor, float] = 1.0,
          validate_args: Optional[bool] = None,
      ) -> None:
          """Broadcast the parameters together and build the Chi2 helper.

          The batch shape is the broadcast shape of ``df``, ``loc`` and ``scale``.
          """
          self.df, self.loc, self.scale = broadcast_all(df, loc, scale)
          # Forward ``validate_args`` so that ``validate_args=False`` also
          # disables validation inside the helper distribution; previously the
          # helper always used the default validation setting.
          self._chi2 = Chi2(self.df, validate_args=validate_args)
          batch_shape = self.df.size()
          super().__init__(batch_shape, validate_args=validate_args)

      def expand(self, batch_shape, _instance=None):
          """Return a new instance with parameters expanded to ``batch_shape``.

          The new instance is initialised without validation and then inherits
          this instance's ``_validate_args`` flag.
          """
          new = self._get_checked_instance(StudentT, _instance)
          batch_shape = torch.Size(batch_shape)
          new.df = self.df.expand(batch_shape)
          new.loc = self.loc.expand(batch_shape)
          new.scale = self.scale.expand(batch_shape)
          new._chi2 = self._chi2.expand(batch_shape)
          super(StudentT, new).__init__(batch_shape, validate_args=False)
          new._validate_args = self._validate_args
          return new

      def rsample(self, sample_shape: _size = torch.Size()) -> Tensor:
          """Draw a reparameterized sample of shape ``sample_shape + batch_shape``."""
          # NOTE: This does not agree with scipy implementation as much as other distributions.
          # (see https://github.com/fritzo/notebooks/blob/master/debug-student-t.ipynb). Using DoubleTensor
          # parameters seems to help.

          #   X ~ Normal(0, 1)
          #   Z ~ Chi2(df)
          #   Y = X / sqrt(Z / df) ~ StudentT(df)
          shape = self._extended_shape(sample_shape)
          X = _standard_normal(shape, dtype=self.df.dtype, device=self.df.device)
          Z = self._chi2.rsample(sample_shape)
          Y = X * torch.rsqrt(Z / self.df)
          return self.loc + self.scale * Y

      def log_prob(self, value):
          """Return the log density evaluated at ``value``.

          ``value`` is checked with ``_validate_sample`` when validation is
          enabled on this instance.
          """
          if self._validate_args:
              self._validate_sample(value)
          # Standardize, then subtract the log normalizing constant ``Z``.
          y = (value - self.loc) / self.scale
          Z = (
              self.scale.log()
              + 0.5 * self.df.log()
              + 0.5 * math.log(math.pi)
              + torch.lgamma(0.5 * self.df)
              - torch.lgamma(0.5 * (self.df + 1.0))
          )
          return -0.5 * (self.df + 1.0) * torch.log1p(y**2.0 / self.df) - Z

      def entropy(self):
          """Return the differential entropy per batch element."""
          # log B(df / 2, 1 / 2) written with log-gamma functions.
          lbeta = (
              torch.lgamma(0.5 * self.df)
              + math.lgamma(0.5)
              - torch.lgamma(0.5 * (self.df + 1))
          )
          return (
              self.scale.log()
              + 0.5
              * (self.df + 1)
              * (torch.digamma(0.5 * (self.df + 1)) - torch.digamma(0.5 * self.df))
              + 0.5 * self.df.log()
              + lbeta
          )