half_normal.py 2.3 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677787980818283
  1. # mypy: allow-untyped-defs
  2. import math
  3. from typing import Optional, Union
  4. import torch
  5. from torch import inf, Tensor
  6. from torch.distributions import constraints
  7. from torch.distributions.normal import Normal
  8. from torch.distributions.transformed_distribution import TransformedDistribution
  9. from torch.distributions.transforms import AbsTransform
# Names exported by ``from ... import *``; the distribution class is the only public name.
__all__ = ["HalfNormal"]
  11. class HalfNormal(TransformedDistribution):
  12. r"""
  13. Creates a half-normal distribution parameterized by `scale` where::
  14. X ~ Normal(0, scale)
  15. Y = |X| ~ HalfNormal(scale)
  16. Example::
  17. >>> # xdoctest: +IGNORE_WANT("non-deterministic")
  18. >>> m = HalfNormal(torch.tensor([1.0]))
  19. >>> m.sample() # half-normal distributed with scale=1
  20. tensor([ 0.1046])
  21. Args:
  22. scale (float or Tensor): scale of the full Normal distribution
  23. """
  24. arg_constraints = {"scale": constraints.positive}
  25. support = constraints.nonnegative
  26. has_rsample = True
  27. base_dist: Normal
  28. def __init__(
  29. self,
  30. scale: Union[Tensor, float],
  31. validate_args: Optional[bool] = None,
  32. ) -> None:
  33. base_dist = Normal(0, scale, validate_args=False)
  34. super().__init__(base_dist, AbsTransform(), validate_args=validate_args)
  35. def expand(self, batch_shape, _instance=None):
  36. new = self._get_checked_instance(HalfNormal, _instance)
  37. return super().expand(batch_shape, _instance=new)
  38. @property
  39. def scale(self) -> Tensor:
  40. return self.base_dist.scale
  41. @property
  42. def mean(self) -> Tensor:
  43. return self.scale * math.sqrt(2 / math.pi)
  44. @property
  45. def mode(self) -> Tensor:
  46. return torch.zeros_like(self.scale)
  47. @property
  48. def variance(self) -> Tensor:
  49. return self.scale.pow(2) * (1 - 2 / math.pi)
  50. def log_prob(self, value):
  51. if self._validate_args:
  52. self._validate_sample(value)
  53. log_prob = self.base_dist.log_prob(value) + math.log(2)
  54. log_prob = torch.where(value >= 0, log_prob, -inf)
  55. return log_prob
  56. def cdf(self, value):
  57. if self._validate_args:
  58. self._validate_sample(value)
  59. return 2 * self.base_dist.cdf(value) - 1
  60. def icdf(self, prob):
  61. return self.base_dist.icdf((prob + 1) / 2)
  62. def entropy(self):
  63. return self.base_dist.entropy() - math.log(2)