variable.py 3.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118
  1. # Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. import paddle
  15. from paddle.distribution import constraint
  16. class Variable:
  17. """Random variable of probability distribution.
  18. Args:
  19. is_discrete (bool): Is the variable discrete or continuous.
  20. event_rank (int): The rank of event dimensions.
  21. """
  22. def __init__(self, is_discrete=False, event_rank=0, constraint=None):
  23. self._is_discrete = is_discrete
  24. self._event_rank = event_rank
  25. self._constraint = constraint
  26. @property
  27. def is_discrete(self):
  28. return self._is_discrete
  29. @property
  30. def event_rank(self):
  31. return self._event_rank
  32. def constraint(self, value):
  33. """Check whether the 'value' meet the constraint conditions of this
  34. random variable."""
  35. return self._constraint(value)
  36. class Real(Variable):
  37. def __init__(self, event_rank=0):
  38. super().__init__(False, event_rank, constraint.real)
  39. class Positive(Variable):
  40. def __init__(self, event_rank=0):
  41. super().__init__(False, event_rank, constraint.positive)
  42. class Independent(Variable):
  43. """Reinterprets some of the batch axes of variable as event axes.
  44. Args:
  45. base (Variable): Base variable.
  46. reinterpreted_batch_rank (int): The rightmost batch rank to be
  47. reinterpreted.
  48. """
  49. def __init__(self, base, reinterpreted_batch_rank):
  50. self._base = base
  51. self._reinterpreted_batch_rank = reinterpreted_batch_rank
  52. super().__init__(
  53. base.is_discrete, base.event_rank + reinterpreted_batch_rank
  54. )
  55. def constraint(self, value):
  56. ret = self._base.constraint(value)
  57. if ret.dim() < self._reinterpreted_batch_rank:
  58. raise ValueError(
  59. f"Input dimensions must be equal or grater than {self._reinterpreted_batch_rank}"
  60. )
  61. return ret.reshape(
  62. ret.shape[: ret.dim() - self.reinterpreted_batch_rank] + (-1,)
  63. ).all(-1)
  64. class Stack(Variable):
  65. def __init__(self, vars, axis=0):
  66. self._vars = vars
  67. self._axis = axis
  68. @property
  69. def is_discrete(self):
  70. return any(var.is_discrete for var in self._vars)
  71. @property
  72. def event_rank(self):
  73. rank = max(var.event_rank for var in self._vars)
  74. if self._axis + rank < 0:
  75. rank += 1
  76. return rank
  77. def constraint(self, value):
  78. if not (-value.dim() <= self._axis < value.dim()):
  79. raise ValueError(
  80. f'Input dimensions {value.dim()} should be grater than stack '
  81. f'constraint axis {self._axis}.'
  82. )
  83. return paddle.stack(
  84. [
  85. var.check(value)
  86. for var, value in zip(
  87. self._vars, paddle.unstack(value, self._axis)
  88. )
  89. ],
  90. self._axis,
  91. )
  92. real = Real()
  93. positive = Positive()