uniform.py 3.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108
  1. # mypy: allow-untyped-defs
  2. from typing import Optional, Union
  3. import torch
  4. from torch import nan, Tensor
  5. from torch.distributions import constraints
  6. from torch.distributions.distribution import Distribution
  7. from torch.distributions.utils import broadcast_all
  8. from torch.types import _Number, _size
  9. __all__ = ["Uniform"]
  10. class Uniform(Distribution):
  11. r"""
  12. Generates uniformly distributed random samples from the half-open interval
  13. ``[low, high)``.
  14. Example::
  15. >>> m = Uniform(torch.tensor([0.0]), torch.tensor([5.0]))
  16. >>> m.sample() # uniformly distributed in the range [0.0, 5.0)
  17. >>> # xdoctest: +SKIP
  18. tensor([ 2.3418])
  19. Args:
  20. low (float or Tensor): lower range (inclusive).
  21. high (float or Tensor): upper range (exclusive).
  22. """
  23. has_rsample = True
  24. @property
  25. def arg_constraints(self):
  26. # TODO allow (loc,scale) parameterization to allow independent constraints.
  27. return {
  28. "low": constraints.less_than(self.high),
  29. "high": constraints.greater_than(self.low),
  30. }
  31. @property
  32. def mean(self) -> Tensor:
  33. return (self.high + self.low) / 2
  34. @property
  35. def mode(self) -> Tensor:
  36. return nan * self.high
  37. @property
  38. def stddev(self) -> Tensor:
  39. return (self.high - self.low) / 12**0.5
  40. @property
  41. def variance(self) -> Tensor:
  42. return (self.high - self.low).pow(2) / 12
  43. def __init__(
  44. self,
  45. low: Union[Tensor, float],
  46. high: Union[Tensor, float],
  47. validate_args: Optional[bool] = None,
  48. ) -> None:
  49. self.low, self.high = broadcast_all(low, high)
  50. if isinstance(low, _Number) and isinstance(high, _Number):
  51. batch_shape = torch.Size()
  52. else:
  53. batch_shape = self.low.size()
  54. super().__init__(batch_shape, validate_args=validate_args)
  55. def expand(self, batch_shape, _instance=None):
  56. new = self._get_checked_instance(Uniform, _instance)
  57. batch_shape = torch.Size(batch_shape)
  58. new.low = self.low.expand(batch_shape)
  59. new.high = self.high.expand(batch_shape)
  60. super(Uniform, new).__init__(batch_shape, validate_args=False)
  61. new._validate_args = self._validate_args
  62. return new
  63. @constraints.dependent_property(is_discrete=False, event_dim=0)
  64. def support(self):
  65. return constraints.interval(self.low, self.high)
  66. def rsample(self, sample_shape: _size = torch.Size()) -> Tensor:
  67. shape = self._extended_shape(sample_shape)
  68. rand = torch.rand(shape, dtype=self.low.dtype, device=self.low.device)
  69. return self.low + rand * (self.high - self.low)
  70. def log_prob(self, value):
  71. if self._validate_args:
  72. self._validate_sample(value)
  73. lb = self.low.le(value).type_as(self.low)
  74. ub = self.high.gt(value).type_as(self.low)
  75. return torch.log(lb.mul(ub)) - torch.log(self.high - self.low)
  76. def cdf(self, value):
  77. if self._validate_args:
  78. self._validate_sample(value)
  79. result = (value - self.low) / (self.high - self.low)
  80. return result.clamp(min=0, max=1)
  81. def icdf(self, value):
  82. result = value * (self.high - self.low) + self.low
  83. return result
  84. def entropy(self):
  85. return torch.log(self.high - self.low)