transforms.py 41 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156115711581159116011611162116311641165116611671168116911701171117211731174117511761177117811791180118111821183118411851186118711881189119011911192119311941195119611971198119912001201120212031204120512061207120812091210121112121213121412151216121712181219122012211222122312241225122612271228122912301231123212331234123512361237123812391240124112421243124412451246124712481249125012511252125312541255125612571258125912601261126212631264126512661267126812691270127112721273127412751276127
71278127912801281128212831284128512861287
  1. # mypy: allow-untyped-defs
  2. import functools
  3. import math
  4. import operator
  5. import weakref
  6. from collections.abc import Sequence
  7. from typing import Optional, Union
  8. import torch
  9. import torch.nn.functional as F
  10. from torch import Tensor
  11. from torch.distributions import constraints
  12. from torch.distributions.distribution import Distribution
  13. from torch.distributions.utils import (
  14. _sum_rightmost,
  15. broadcast_all,
  16. lazy_property,
  17. tril_matrix_to_vec,
  18. vec_to_tril_matrix,
  19. )
  20. from torch.nn.functional import pad, softplus
  21. from torch.types import _Number
# Public API of this module: the names exported by ``from <module> import *``.
# NOTE(review): several of these classes (e.g. CatTransform, StackTransform)
# are not defined in the visible part of the file; presumably defined below.
__all__ = [
    "AbsTransform",
    "AffineTransform",
    "CatTransform",
    "ComposeTransform",
    "CorrCholeskyTransform",
    "CumulativeDistributionTransform",
    "ExpTransform",
    "IndependentTransform",
    "LowerCholeskyTransform",
    "PositiveDefiniteTransform",
    "PowerTransform",
    "ReshapeTransform",
    "SigmoidTransform",
    "SoftplusTransform",
    "TanhTransform",
    "SoftmaxTransform",
    "StackTransform",
    "StickBreakingTransform",
    "Transform",
    "identity_transform",
]
  44. class Transform:
  45. """
  46. Abstract class for invertable transformations with computable log
  47. det jacobians. They are primarily used in
  48. :class:`torch.distributions.TransformedDistribution`.
  49. Caching is useful for transforms whose inverses are either expensive or
  50. numerically unstable. Note that care must be taken with memoized values
  51. since the autograd graph may be reversed. For example while the following
  52. works with or without caching::
  53. y = t(x)
  54. t.log_abs_det_jacobian(x, y).backward() # x will receive gradients.
  55. However the following will error when caching due to dependency reversal::
  56. y = t(x)
  57. z = t.inv(y)
  58. grad(z.sum(), [y]) # error because z is x
  59. Derived classes should implement one or both of :meth:`_call` or
  60. :meth:`_inverse`. Derived classes that set `bijective=True` should also
  61. implement :meth:`log_abs_det_jacobian`.
  62. Args:
  63. cache_size (int): Size of cache. If zero, no caching is done. If one,
  64. the latest single value is cached. Only 0 and 1 are supported.
  65. Attributes:
  66. domain (:class:`~torch.distributions.constraints.Constraint`):
  67. The constraint representing valid inputs to this transform.
  68. codomain (:class:`~torch.distributions.constraints.Constraint`):
  69. The constraint representing valid outputs to this transform
  70. which are inputs to the inverse transform.
  71. bijective (bool): Whether this transform is bijective. A transform
  72. ``t`` is bijective iff ``t.inv(t(x)) == x`` and
  73. ``t(t.inv(y)) == y`` for every ``x`` in the domain and ``y`` in
  74. the codomain. Transforms that are not bijective should at least
  75. maintain the weaker pseudoinverse properties
  76. ``t(t.inv(t(x)) == t(x)`` and ``t.inv(t(t.inv(y))) == t.inv(y)``.
  77. sign (int or Tensor): For bijective univariate transforms, this
  78. should be +1 or -1 depending on whether transform is monotone
  79. increasing or decreasing.
  80. """
  81. bijective = False
  82. domain: constraints.Constraint
  83. codomain: constraints.Constraint
  84. def __init__(self, cache_size: int = 0) -> None:
  85. self._cache_size = cache_size
  86. self._inv: Optional[weakref.ReferenceType[Transform]] = None
  87. if cache_size == 0:
  88. pass # default behavior
  89. elif cache_size == 1:
  90. self._cached_x_y = None, None
  91. else:
  92. raise ValueError("cache_size must be 0 or 1")
  93. super().__init__()
  94. def __getstate__(self):
  95. state = self.__dict__.copy()
  96. state["_inv"] = None
  97. return state
  98. @property
  99. def event_dim(self) -> int:
  100. if self.domain.event_dim == self.codomain.event_dim:
  101. return self.domain.event_dim
  102. raise ValueError("Please use either .domain.event_dim or .codomain.event_dim")
  103. @property
  104. def inv(self) -> "Transform":
  105. """
  106. Returns the inverse :class:`Transform` of this transform.
  107. This should satisfy ``t.inv.inv is t``.
  108. """
  109. inv = None
  110. if self._inv is not None:
  111. inv = self._inv()
  112. if inv is None:
  113. inv = _InverseTransform(self)
  114. self._inv = weakref.ref(inv)
  115. return inv
  116. @property
  117. def sign(self) -> int:
  118. """
  119. Returns the sign of the determinant of the Jacobian, if applicable.
  120. In general this only makes sense for bijective transforms.
  121. """
  122. raise NotImplementedError
  123. def with_cache(self, cache_size=1):
  124. if self._cache_size == cache_size:
  125. return self
  126. if type(self).__init__ is Transform.__init__:
  127. return type(self)(cache_size=cache_size)
  128. raise NotImplementedError(f"{type(self)}.with_cache is not implemented")
  129. def __eq__(self, other):
  130. return self is other
  131. def __ne__(self, other):
  132. # Necessary for Python2
  133. return not self.__eq__(other)
  134. def __call__(self, x):
  135. """
  136. Computes the transform `x => y`.
  137. """
  138. if self._cache_size == 0:
  139. return self._call(x)
  140. x_old, y_old = self._cached_x_y
  141. if x is x_old:
  142. return y_old
  143. y = self._call(x)
  144. self._cached_x_y = x, y
  145. return y
  146. def _inv_call(self, y):
  147. """
  148. Inverts the transform `y => x`.
  149. """
  150. if self._cache_size == 0:
  151. return self._inverse(y)
  152. x_old, y_old = self._cached_x_y
  153. if y is y_old:
  154. return x_old
  155. x = self._inverse(y)
  156. self._cached_x_y = x, y
  157. return x
  158. def _call(self, x):
  159. """
  160. Abstract method to compute forward transformation.
  161. """
  162. raise NotImplementedError
  163. def _inverse(self, y):
  164. """
  165. Abstract method to compute inverse transformation.
  166. """
  167. raise NotImplementedError
  168. def log_abs_det_jacobian(self, x, y):
  169. """
  170. Computes the log det jacobian `log |dy/dx|` given input and output.
  171. """
  172. raise NotImplementedError
  173. def __repr__(self):
  174. return self.__class__.__name__ + "()"
  175. def forward_shape(self, shape):
  176. """
  177. Infers the shape of the forward computation, given the input shape.
  178. Defaults to preserving shape.
  179. """
  180. return shape
  181. def inverse_shape(self, shape):
  182. """
  183. Infers the shapes of the inverse computation, given the output shape.
  184. Defaults to preserving shape.
  185. """
  186. return shape
  187. class _InverseTransform(Transform):
  188. """
  189. Inverts a single :class:`Transform`.
  190. This class is private; please instead use the ``Transform.inv`` property.
  191. """
  192. def __init__(self, transform: Transform) -> None:
  193. super().__init__(cache_size=transform._cache_size)
  194. self._inv: Transform = transform # type: ignore[assignment]
  195. @constraints.dependent_property(is_discrete=False)
  196. def domain(self):
  197. assert self._inv is not None
  198. return self._inv.codomain
  199. @constraints.dependent_property(is_discrete=False)
  200. def codomain(self):
  201. assert self._inv is not None
  202. return self._inv.domain
  203. @property
  204. def bijective(self) -> bool: # type: ignore[override]
  205. assert self._inv is not None
  206. return self._inv.bijective
  207. @property
  208. def sign(self) -> int:
  209. assert self._inv is not None
  210. return self._inv.sign
  211. @property
  212. def inv(self) -> Transform:
  213. return self._inv
  214. def with_cache(self, cache_size=1):
  215. assert self._inv is not None
  216. return self.inv.with_cache(cache_size).inv
  217. def __eq__(self, other):
  218. if not isinstance(other, _InverseTransform):
  219. return False
  220. assert self._inv is not None
  221. return self._inv == other._inv
  222. def __repr__(self):
  223. return f"{self.__class__.__name__}({repr(self._inv)})"
  224. def __call__(self, x):
  225. assert self._inv is not None
  226. return self._inv._inv_call(x)
  227. def log_abs_det_jacobian(self, x, y):
  228. assert self._inv is not None
  229. return -self._inv.log_abs_det_jacobian(y, x)
  230. def forward_shape(self, shape):
  231. return self._inv.inverse_shape(shape)
  232. def inverse_shape(self, shape):
  233. return self._inv.forward_shape(shape)
  234. class ComposeTransform(Transform):
  235. """
  236. Composes multiple transforms in a chain.
  237. The transforms being composed are responsible for caching.
  238. Args:
  239. parts (list of :class:`Transform`): A list of transforms to compose.
  240. cache_size (int): Size of cache. If zero, no caching is done. If one,
  241. the latest single value is cached. Only 0 and 1 are supported.
  242. """
  243. def __init__(self, parts: list[Transform], cache_size: int = 0) -> None:
  244. if cache_size:
  245. parts = [part.with_cache(cache_size) for part in parts]
  246. super().__init__(cache_size=cache_size)
  247. self.parts = parts
  248. def __eq__(self, other):
  249. if not isinstance(other, ComposeTransform):
  250. return False
  251. return self.parts == other.parts
  252. @constraints.dependent_property(is_discrete=False)
  253. def domain(self):
  254. if not self.parts:
  255. return constraints.real
  256. domain = self.parts[0].domain
  257. # Adjust event_dim to be maximum among all parts.
  258. event_dim = self.parts[-1].codomain.event_dim
  259. for part in reversed(self.parts):
  260. event_dim += part.domain.event_dim - part.codomain.event_dim
  261. event_dim = max(event_dim, part.domain.event_dim)
  262. assert event_dim >= domain.event_dim
  263. if event_dim > domain.event_dim:
  264. domain = constraints.independent(domain, event_dim - domain.event_dim)
  265. return domain
  266. @constraints.dependent_property(is_discrete=False)
  267. def codomain(self):
  268. if not self.parts:
  269. return constraints.real
  270. codomain = self.parts[-1].codomain
  271. # Adjust event_dim to be maximum among all parts.
  272. event_dim = self.parts[0].domain.event_dim
  273. for part in self.parts:
  274. event_dim += part.codomain.event_dim - part.domain.event_dim
  275. event_dim = max(event_dim, part.codomain.event_dim)
  276. assert event_dim >= codomain.event_dim
  277. if event_dim > codomain.event_dim:
  278. codomain = constraints.independent(codomain, event_dim - codomain.event_dim)
  279. return codomain
  280. @lazy_property
  281. def bijective(self) -> bool: # type: ignore[override]
  282. return all(p.bijective for p in self.parts)
  283. @lazy_property
  284. def sign(self) -> int: # type: ignore[override]
  285. sign = 1
  286. for p in self.parts:
  287. sign = sign * p.sign
  288. return sign
  289. @property
  290. def inv(self) -> Transform:
  291. inv = None
  292. if self._inv is not None:
  293. inv = self._inv()
  294. if inv is None:
  295. inv = ComposeTransform([p.inv for p in reversed(self.parts)])
  296. self._inv = weakref.ref(inv)
  297. inv._inv = weakref.ref(self)
  298. return inv
  299. def with_cache(self, cache_size=1):
  300. if self._cache_size == cache_size:
  301. return self
  302. return ComposeTransform(self.parts, cache_size=cache_size)
  303. def __call__(self, x):
  304. for part in self.parts:
  305. x = part(x)
  306. return x
  307. def log_abs_det_jacobian(self, x, y):
  308. if not self.parts:
  309. return torch.zeros_like(x)
  310. # Compute intermediates. This will be free if parts[:-1] are all cached.
  311. xs = [x]
  312. for part in self.parts[:-1]:
  313. xs.append(part(xs[-1]))
  314. xs.append(y)
  315. terms = []
  316. event_dim = self.domain.event_dim
  317. for part, x, y in zip(self.parts, xs[:-1], xs[1:]):
  318. terms.append(
  319. _sum_rightmost(
  320. part.log_abs_det_jacobian(x, y), event_dim - part.domain.event_dim
  321. )
  322. )
  323. event_dim += part.codomain.event_dim - part.domain.event_dim
  324. return functools.reduce(operator.add, terms)
  325. def forward_shape(self, shape):
  326. for part in self.parts:
  327. shape = part.forward_shape(shape)
  328. return shape
  329. def inverse_shape(self, shape):
  330. for part in reversed(self.parts):
  331. shape = part.inverse_shape(shape)
  332. return shape
  333. def __repr__(self):
  334. fmt_string = self.__class__.__name__ + "(\n "
  335. fmt_string += ",\n ".join([p.__repr__() for p in self.parts])
  336. fmt_string += "\n)"
  337. return fmt_string
  338. identity_transform = ComposeTransform([])
  339. class IndependentTransform(Transform):
  340. """
  341. Wrapper around another transform to treat
  342. ``reinterpreted_batch_ndims``-many extra of the right most dimensions as
  343. dependent. This has no effect on the forward or backward transforms, but
  344. does sum out ``reinterpreted_batch_ndims``-many of the rightmost dimensions
  345. in :meth:`log_abs_det_jacobian`.
  346. Args:
  347. base_transform (:class:`Transform`): A base transform.
  348. reinterpreted_batch_ndims (int): The number of extra rightmost
  349. dimensions to treat as dependent.
  350. """
  351. def __init__(
  352. self,
  353. base_transform: Transform,
  354. reinterpreted_batch_ndims: int,
  355. cache_size: int = 0,
  356. ) -> None:
  357. super().__init__(cache_size=cache_size)
  358. self.base_transform = base_transform.with_cache(cache_size)
  359. self.reinterpreted_batch_ndims = reinterpreted_batch_ndims
  360. def with_cache(self, cache_size=1):
  361. if self._cache_size == cache_size:
  362. return self
  363. return IndependentTransform(
  364. self.base_transform, self.reinterpreted_batch_ndims, cache_size=cache_size
  365. )
  366. @constraints.dependent_property(is_discrete=False)
  367. def domain(self):
  368. return constraints.independent(
  369. self.base_transform.domain, self.reinterpreted_batch_ndims
  370. )
  371. @constraints.dependent_property(is_discrete=False)
  372. def codomain(self):
  373. return constraints.independent(
  374. self.base_transform.codomain, self.reinterpreted_batch_ndims
  375. )
  376. @property
  377. def bijective(self) -> bool: # type: ignore[override]
  378. return self.base_transform.bijective
  379. @property
  380. def sign(self) -> int:
  381. return self.base_transform.sign
  382. def _call(self, x):
  383. if x.dim() < self.domain.event_dim:
  384. raise ValueError("Too few dimensions on input")
  385. return self.base_transform(x)
  386. def _inverse(self, y):
  387. if y.dim() < self.codomain.event_dim:
  388. raise ValueError("Too few dimensions on input")
  389. return self.base_transform.inv(y)
  390. def log_abs_det_jacobian(self, x, y):
  391. result = self.base_transform.log_abs_det_jacobian(x, y)
  392. result = _sum_rightmost(result, self.reinterpreted_batch_ndims)
  393. return result
  394. def __repr__(self):
  395. return f"{self.__class__.__name__}({repr(self.base_transform)}, {self.reinterpreted_batch_ndims})"
  396. def forward_shape(self, shape):
  397. return self.base_transform.forward_shape(shape)
  398. def inverse_shape(self, shape):
  399. return self.base_transform.inverse_shape(shape)
  400. class ReshapeTransform(Transform):
  401. """
  402. Unit Jacobian transform to reshape the rightmost part of a tensor.
  403. Note that ``in_shape`` and ``out_shape`` must have the same number of
  404. elements, just as for :meth:`torch.Tensor.reshape`.
  405. Arguments:
  406. in_shape (torch.Size): The input event shape.
  407. out_shape (torch.Size): The output event shape.
  408. cache_size (int): Size of cache. If zero, no caching is done. If one,
  409. the latest single value is cached. Only 0 and 1 are supported. (Default 0.)
  410. """
  411. bijective = True
  412. def __init__(
  413. self,
  414. in_shape: torch.Size,
  415. out_shape: torch.Size,
  416. cache_size: int = 0,
  417. ) -> None:
  418. self.in_shape = torch.Size(in_shape)
  419. self.out_shape = torch.Size(out_shape)
  420. if self.in_shape.numel() != self.out_shape.numel():
  421. raise ValueError("in_shape, out_shape have different numbers of elements")
  422. super().__init__(cache_size=cache_size)
  423. @constraints.dependent_property
  424. def domain(self):
  425. return constraints.independent(constraints.real, len(self.in_shape))
  426. @constraints.dependent_property
  427. def codomain(self):
  428. return constraints.independent(constraints.real, len(self.out_shape))
  429. def with_cache(self, cache_size=1):
  430. if self._cache_size == cache_size:
  431. return self
  432. return ReshapeTransform(self.in_shape, self.out_shape, cache_size=cache_size)
  433. def _call(self, x):
  434. batch_shape = x.shape[: x.dim() - len(self.in_shape)]
  435. return x.reshape(batch_shape + self.out_shape)
  436. def _inverse(self, y):
  437. batch_shape = y.shape[: y.dim() - len(self.out_shape)]
  438. return y.reshape(batch_shape + self.in_shape)
  439. def log_abs_det_jacobian(self, x, y):
  440. batch_shape = x.shape[: x.dim() - len(self.in_shape)]
  441. return x.new_zeros(batch_shape)
  442. def forward_shape(self, shape):
  443. if len(shape) < len(self.in_shape):
  444. raise ValueError("Too few dimensions on input")
  445. cut = len(shape) - len(self.in_shape)
  446. if shape[cut:] != self.in_shape:
  447. raise ValueError(
  448. f"Shape mismatch: expected {shape[cut:]} but got {self.in_shape}"
  449. )
  450. return shape[:cut] + self.out_shape
  451. def inverse_shape(self, shape):
  452. if len(shape) < len(self.out_shape):
  453. raise ValueError("Too few dimensions on input")
  454. cut = len(shape) - len(self.out_shape)
  455. if shape[cut:] != self.out_shape:
  456. raise ValueError(
  457. f"Shape mismatch: expected {shape[cut:]} but got {self.out_shape}"
  458. )
  459. return shape[:cut] + self.in_shape
  460. class ExpTransform(Transform):
  461. r"""
  462. Transform via the mapping :math:`y = \exp(x)`.
  463. """
  464. domain = constraints.real
  465. codomain = constraints.positive
  466. bijective = True
  467. sign = +1
  468. def __eq__(self, other):
  469. return isinstance(other, ExpTransform)
  470. def _call(self, x):
  471. return x.exp()
  472. def _inverse(self, y):
  473. return y.log()
  474. def log_abs_det_jacobian(self, x, y):
  475. return x
  476. class PowerTransform(Transform):
  477. r"""
  478. Transform via the mapping :math:`y = x^{\text{exponent}}`.
  479. """
  480. domain = constraints.positive
  481. codomain = constraints.positive
  482. bijective = True
  483. def __init__(self, exponent: Tensor, cache_size: int = 0) -> None:
  484. super().__init__(cache_size=cache_size)
  485. (self.exponent,) = broadcast_all(exponent)
  486. def with_cache(self, cache_size=1):
  487. if self._cache_size == cache_size:
  488. return self
  489. return PowerTransform(self.exponent, cache_size=cache_size)
  490. @lazy_property
  491. def sign(self) -> int: # type: ignore[override]
  492. return self.exponent.sign() # type: ignore[return-value]
  493. def __eq__(self, other):
  494. if not isinstance(other, PowerTransform):
  495. return False
  496. return self.exponent.eq(other.exponent).all().item()
  497. def _call(self, x):
  498. return x.pow(self.exponent)
  499. def _inverse(self, y):
  500. return y.pow(1 / self.exponent)
  501. def log_abs_det_jacobian(self, x, y):
  502. return (self.exponent * y / x).abs().log()
  503. def forward_shape(self, shape):
  504. return torch.broadcast_shapes(shape, getattr(self.exponent, "shape", ()))
  505. def inverse_shape(self, shape):
  506. return torch.broadcast_shapes(shape, getattr(self.exponent, "shape", ()))
  507. def _clipped_sigmoid(x):
  508. finfo = torch.finfo(x.dtype)
  509. return torch.clamp(torch.sigmoid(x), min=finfo.tiny, max=1.0 - finfo.eps)
  510. class SigmoidTransform(Transform):
  511. r"""
  512. Transform via the mapping :math:`y = \frac{1}{1 + \exp(-x)}` and :math:`x = \text{logit}(y)`.
  513. """
  514. domain = constraints.real
  515. codomain = constraints.unit_interval
  516. bijective = True
  517. sign = +1
  518. def __eq__(self, other):
  519. return isinstance(other, SigmoidTransform)
  520. def _call(self, x):
  521. return _clipped_sigmoid(x)
  522. def _inverse(self, y):
  523. finfo = torch.finfo(y.dtype)
  524. y = y.clamp(min=finfo.tiny, max=1.0 - finfo.eps)
  525. return y.log() - (-y).log1p()
  526. def log_abs_det_jacobian(self, x, y):
  527. return -F.softplus(-x) - F.softplus(x)
  528. class SoftplusTransform(Transform):
  529. r"""
  530. Transform via the mapping :math:`\text{Softplus}(x) = \log(1 + \exp(x))`.
  531. The implementation reverts to the linear function when :math:`x > 20`.
  532. """
  533. domain = constraints.real
  534. codomain = constraints.positive
  535. bijective = True
  536. sign = +1
  537. def __eq__(self, other):
  538. return isinstance(other, SoftplusTransform)
  539. def _call(self, x):
  540. return softplus(x)
  541. def _inverse(self, y):
  542. return (-y).expm1().neg().log() + y
  543. def log_abs_det_jacobian(self, x, y):
  544. return -softplus(-x)
  545. class TanhTransform(Transform):
  546. r"""
  547. Transform via the mapping :math:`y = \tanh(x)`.
  548. It is equivalent to
  549. .. code-block:: python
  550. ComposeTransform(
  551. [
  552. AffineTransform(0.0, 2.0),
  553. SigmoidTransform(),
  554. AffineTransform(-1.0, 2.0),
  555. ]
  556. )
  557. However this might not be numerically stable, thus it is recommended to use `TanhTransform`
  558. instead.
  559. Note that one should use `cache_size=1` when it comes to `NaN/Inf` values.
  560. """
  561. domain = constraints.real
  562. codomain = constraints.interval(-1.0, 1.0)
  563. bijective = True
  564. sign = +1
  565. def __eq__(self, other):
  566. return isinstance(other, TanhTransform)
  567. def _call(self, x):
  568. return x.tanh()
  569. def _inverse(self, y):
  570. # We do not clamp to the boundary here as it may degrade the performance of certain algorithms.
  571. # one should use `cache_size=1` instead
  572. return torch.atanh(y)
  573. def log_abs_det_jacobian(self, x, y):
  574. # We use a formula that is more numerically stable, see details in the following link
  575. # https://github.com/tensorflow/probability/blob/master/tensorflow_probability/python/bijectors/tanh.py#L69-L80
  576. return 2.0 * (math.log(2.0) - x - softplus(-2.0 * x))
  577. class AbsTransform(Transform):
  578. r"""Transform via the mapping :math:`y = |x|`."""
  579. domain = constraints.real
  580. codomain = constraints.positive
  581. def __eq__(self, other):
  582. return isinstance(other, AbsTransform)
  583. def _call(self, x):
  584. return x.abs()
  585. def _inverse(self, y):
  586. return y
  587. class AffineTransform(Transform):
  588. r"""
  589. Transform via the pointwise affine mapping :math:`y = \text{loc} + \text{scale} \times x`.
  590. Args:
  591. loc (Tensor or float): Location parameter.
  592. scale (Tensor or float): Scale parameter.
  593. event_dim (int): Optional size of `event_shape`. This should be zero
  594. for univariate random variables, 1 for distributions over vectors,
  595. 2 for distributions over matrices, etc.
  596. """
  597. bijective = True
  598. def __init__(
  599. self,
  600. loc: Union[Tensor, float],
  601. scale: Union[Tensor, float],
  602. event_dim: int = 0,
  603. cache_size: int = 0,
  604. ) -> None:
  605. super().__init__(cache_size=cache_size)
  606. self.loc = loc
  607. self.scale = scale
  608. self._event_dim = event_dim
  609. @property
  610. def event_dim(self) -> int:
  611. return self._event_dim
  612. @constraints.dependent_property(is_discrete=False)
  613. def domain(self):
  614. if self.event_dim == 0:
  615. return constraints.real
  616. return constraints.independent(constraints.real, self.event_dim)
  617. @constraints.dependent_property(is_discrete=False)
  618. def codomain(self):
  619. if self.event_dim == 0:
  620. return constraints.real
  621. return constraints.independent(constraints.real, self.event_dim)
  622. def with_cache(self, cache_size=1):
  623. if self._cache_size == cache_size:
  624. return self
  625. return AffineTransform(
  626. self.loc, self.scale, self.event_dim, cache_size=cache_size
  627. )
  628. def __eq__(self, other):
  629. if not isinstance(other, AffineTransform):
  630. return False
  631. if isinstance(self.loc, _Number) and isinstance(other.loc, _Number):
  632. if self.loc != other.loc:
  633. return False
  634. else:
  635. if not (self.loc == other.loc).all().item(): # type: ignore[union-attr]
  636. return False
  637. if isinstance(self.scale, _Number) and isinstance(other.scale, _Number):
  638. if self.scale != other.scale:
  639. return False
  640. else:
  641. if not (self.scale == other.scale).all().item(): # type: ignore[union-attr]
  642. return False
  643. return True
  644. @property
  645. def sign(self) -> Union[Tensor, int]: # type: ignore[override]
  646. if isinstance(self.scale, _Number):
  647. return 1 if float(self.scale) > 0 else -1 if float(self.scale) < 0 else 0
  648. return self.scale.sign()
  649. def _call(self, x):
  650. return self.loc + self.scale * x
  651. def _inverse(self, y):
  652. return (y - self.loc) / self.scale
  653. def log_abs_det_jacobian(self, x, y):
  654. shape = x.shape
  655. scale = self.scale
  656. if isinstance(scale, _Number):
  657. result = torch.full_like(x, math.log(abs(scale)))
  658. else:
  659. result = torch.abs(scale).log()
  660. if self.event_dim:
  661. result_size = result.size()[: -self.event_dim] + (-1,)
  662. result = result.view(result_size).sum(-1)
  663. shape = shape[: -self.event_dim]
  664. return result.expand(shape)
  665. def forward_shape(self, shape):
  666. return torch.broadcast_shapes(
  667. shape, getattr(self.loc, "shape", ()), getattr(self.scale, "shape", ())
  668. )
  669. def inverse_shape(self, shape):
  670. return torch.broadcast_shapes(
  671. shape, getattr(self.loc, "shape", ()), getattr(self.scale, "shape", ())
  672. )
class CorrCholeskyTransform(Transform):
    r"""
    Transforms an unconstrained real vector :math:`x` with length :math:`D*(D-1)/2` into the
    Cholesky factor of a D-dimension correlation matrix. This Cholesky factor is a lower
    triangular matrix with positive diagonals and unit Euclidean norm for each row.
    The transform is processed as follows:

        1. First we convert x into a lower triangular matrix in row order.
        2. For each row :math:`X_i` of the lower triangular part, we apply a *signed* version of
           class :class:`StickBreakingTransform` to transform :math:`X_i` into a
           unit Euclidean length vector using the following steps:

           - Scales into the interval :math:`(-1, 1)` domain: :math:`r_i = \tanh(X_i)`.
           - Transforms into an unsigned domain: :math:`z_i = r_i^2`.
           - Applies :math:`s_i = StickBreakingTransform(z_i)`.
           - Transforms back into signed domain: :math:`y_i = sign(r_i) * \sqrt{s_i}`.
    """

    domain = constraints.real_vector
    codomain = constraints.corr_cholesky
    bijective = True

    def _call(self, x):
        # Squash into (-1, 1), then stay eps away from +/-1 so 1 - r**2 > 0 below.
        x = torch.tanh(x)
        eps = torch.finfo(x.dtype).eps
        x = x.clamp(min=-1 + eps, max=1 - eps)
        # Place the vector below the main diagonal (diag=-1) of a square matrix.
        r = vec_to_tril_matrix(x, diag=-1)
        # apply stick-breaking on the squared values
        # Note that y = sign(r) * sqrt(z * z1m_cumprod)
        # = (sign(r) * sqrt(z)) * sqrt(z1m_cumprod) = r * sqrt(z1m_cumprod)
        z = r**2
        z1m_cumprod_sqrt = (1 - z).sqrt().cumprod(-1)
        # Diagonal elements must be 1.
        r = r + torch.eye(r.shape[-1], dtype=r.dtype, device=r.device)
        # Shift the cumulative product right by one column (leading 1), so
        # column j is scaled by the product over columns before j.
        y = r * pad(z1m_cumprod_sqrt[..., :-1], [1, 0], value=1)
        return y

    def _inverse(self, y):
        # inverse stick-breaking
        # See: https://mc-stan.org/docs/2_18/reference-manual/cholesky-factors-of-correlation-matrices-1.html
        # Remaining squared row norm after each column, shifted right by one.
        y_cumsum = 1 - torch.cumsum(y * y, dim=-1)
        y_cumsum_shifted = pad(y_cumsum[..., :-1], [1, 0], value=1)
        y_vec = tril_matrix_to_vec(y, diag=-1)
        y_cumsum_vec = tril_matrix_to_vec(y_cumsum_shifted, diag=-1)
        # Undo the stick-breaking scaling to recover the tanh values.
        t = y_vec / (y_cumsum_vec).sqrt()
        # inverse of tanh
        x = (t.log1p() - t.neg().log1p()) / 2
        return x

    def log_abs_det_jacobian(self, x, y, intermediates=None):
        # Because domain and codomain are two spaces with different dimensions, determinant of
        # Jacobian is not well-defined. We return `log_abs_det_jacobian` of `x` and the
        # flattened lower triangular part of `y`.
        # See: https://mc-stan.org/docs/2_18/reference-manual/cholesky-factors-of-correlation-matrices-1.html
        # NOTE: ``intermediates`` is accepted but unused.
        y1m_cumsum = 1 - (y * y).cumsum(dim=-1)
        # by taking diagonal=-2, we don't need to shift z_cumprod to the right
        # also works for 2 x 2 matrix
        y1m_cumsum_tril = tril_matrix_to_vec(y1m_cumsum, diag=-2)
        stick_breaking_logdet = 0.5 * (y1m_cumsum_tril).log().sum(-1)
        # Sum over x of log(1 - tanh(x)**2), written in a numerically stable form.
        tanh_logdet = -2 * (x + softplus(-2 * x) - math.log(2.0)).sum(dim=-1)
        return stick_breaking_logdet + tanh_logdet

    def forward_shape(self, shape):
        # Reshape from (..., N) to (..., D, D).
        if len(shape) < 1:
            raise ValueError("Too few dimensions on input")
        N = shape[-1]
        # Positive root of D*(D-1)/2 == N, rounded; verified on the next line.
        D = round((0.25 + 2 * N) ** 0.5 + 0.5)
        if D * (D - 1) // 2 != N:
            raise ValueError("Input is not a flattened lower-diagonal number")
        return shape[:-1] + (D, D)

    def inverse_shape(self, shape):
        # Reshape from (..., D, D) to (..., N).
        if len(shape) < 2:
            raise ValueError("Too few dimensions on input")
        if shape[-2] != shape[-1]:
            raise ValueError("Input is not square")
        D = shape[-1]
        N = D * (D - 1) // 2
        return shape[:-2] + (N,)
  746. class SoftmaxTransform(Transform):
  747. r"""
  748. Transform from unconstrained space to the simplex via :math:`y = \exp(x)` then
  749. normalizing.
  750. This is not bijective and cannot be used for HMC. However this acts mostly
  751. coordinate-wise (except for the final normalization), and thus is
  752. appropriate for coordinate-wise optimization algorithms.
  753. """
  754. domain = constraints.real_vector
  755. codomain = constraints.simplex
  756. def __eq__(self, other):
  757. return isinstance(other, SoftmaxTransform)
  758. def _call(self, x):
  759. logprobs = x
  760. probs = (logprobs - logprobs.max(-1, True)[0]).exp()
  761. return probs / probs.sum(-1, True)
  762. def _inverse(self, y):
  763. probs = y
  764. return probs.log()
  765. def forward_shape(self, shape):
  766. if len(shape) < 1:
  767. raise ValueError("Too few dimensions on input")
  768. return shape
  769. def inverse_shape(self, shape):
  770. if len(shape) < 1:
  771. raise ValueError("Too few dimensions on input")
  772. return shape
  773. class StickBreakingTransform(Transform):
  774. """
  775. Transform from unconstrained space to the simplex of one additional
  776. dimension via a stick-breaking process.
  777. This transform arises as an iterated sigmoid transform in a stick-breaking
  778. construction of the `Dirichlet` distribution: the first logit is
  779. transformed via sigmoid to the first probability and the probability of
  780. everything else, and then the process recurses.
  781. This is bijective and appropriate for use in HMC; however it mixes
  782. coordinates together and is less appropriate for optimization.
  783. """
  784. domain = constraints.real_vector
  785. codomain = constraints.simplex
  786. bijective = True
  787. def __eq__(self, other):
  788. return isinstance(other, StickBreakingTransform)
  789. def _call(self, x):
  790. offset = x.shape[-1] + 1 - x.new_ones(x.shape[-1]).cumsum(-1)
  791. z = _clipped_sigmoid(x - offset.log())
  792. z_cumprod = (1 - z).cumprod(-1)
  793. y = pad(z, [0, 1], value=1) * pad(z_cumprod, [1, 0], value=1)
  794. return y
  795. def _inverse(self, y):
  796. y_crop = y[..., :-1]
  797. offset = y.shape[-1] - y.new_ones(y_crop.shape[-1]).cumsum(-1)
  798. sf = 1 - y_crop.cumsum(-1)
  799. # we clamp to make sure that sf is positive which sometimes does not
  800. # happen when y[-1] ~ 0 or y[:-1].sum() ~ 1
  801. sf = torch.clamp(sf, min=torch.finfo(y.dtype).tiny)
  802. x = y_crop.log() - sf.log() + offset.log()
  803. return x
  804. def log_abs_det_jacobian(self, x, y):
  805. offset = x.shape[-1] + 1 - x.new_ones(x.shape[-1]).cumsum(-1)
  806. x = x - offset.log()
  807. # use the identity 1 - sigmoid(x) = exp(-x) * sigmoid(x)
  808. detJ = (-x + F.logsigmoid(x) + y[..., :-1].log()).sum(-1)
  809. return detJ
  810. def forward_shape(self, shape):
  811. if len(shape) < 1:
  812. raise ValueError("Too few dimensions on input")
  813. return shape[:-1] + (shape[-1] + 1,)
  814. def inverse_shape(self, shape):
  815. if len(shape) < 1:
  816. raise ValueError("Too few dimensions on input")
  817. return shape[:-1] + (shape[-1] - 1,)
  818. class LowerCholeskyTransform(Transform):
  819. """
  820. Transform from unconstrained matrices to lower-triangular matrices with
  821. nonnegative diagonal entries.
  822. This is useful for parameterizing positive definite matrices in terms of
  823. their Cholesky factorization.
  824. """
  825. domain = constraints.independent(constraints.real, 2)
  826. codomain = constraints.lower_cholesky
  827. def __eq__(self, other):
  828. return isinstance(other, LowerCholeskyTransform)
  829. def _call(self, x):
  830. return x.tril(-1) + x.diagonal(dim1=-2, dim2=-1).exp().diag_embed()
  831. def _inverse(self, y):
  832. return y.tril(-1) + y.diagonal(dim1=-2, dim2=-1).log().diag_embed()
  833. class PositiveDefiniteTransform(Transform):
  834. """
  835. Transform from unconstrained matrices to positive-definite matrices.
  836. """
  837. domain = constraints.independent(constraints.real, 2)
  838. codomain = constraints.positive_definite
  839. def __eq__(self, other):
  840. return isinstance(other, PositiveDefiniteTransform)
  841. def _call(self, x):
  842. x = LowerCholeskyTransform()(x)
  843. return x @ x.mT
  844. def _inverse(self, y):
  845. y = torch.linalg.cholesky(y)
  846. return LowerCholeskyTransform().inv(y)
  847. class CatTransform(Transform):
  848. """
  849. Transform functor that applies a sequence of transforms `tseq`
  850. component-wise to each submatrix at `dim`, of length `lengths[dim]`,
  851. in a way compatible with :func:`torch.cat`.
  852. Example::
  853. x0 = torch.cat([torch.range(1, 10), torch.range(1, 10)], dim=0)
  854. x = torch.cat([x0, x0], dim=0)
  855. t0 = CatTransform([ExpTransform(), identity_transform], dim=0, lengths=[10, 10])
  856. t = CatTransform([t0, t0], dim=0, lengths=[20, 20])
  857. y = t(x)
  858. """
  859. transforms: list[Transform]
  860. def __init__(
  861. self,
  862. tseq: Sequence[Transform],
  863. dim: int = 0,
  864. lengths: Optional[Sequence[int]] = None,
  865. cache_size: int = 0,
  866. ) -> None:
  867. assert all(isinstance(t, Transform) for t in tseq)
  868. if cache_size:
  869. tseq = [t.with_cache(cache_size) for t in tseq]
  870. super().__init__(cache_size=cache_size)
  871. self.transforms = list(tseq)
  872. if lengths is None:
  873. lengths = [1] * len(self.transforms)
  874. self.lengths = list(lengths)
  875. assert len(self.lengths) == len(self.transforms)
  876. self.dim = dim
  877. @lazy_property
  878. def event_dim(self) -> int: # type: ignore[override]
  879. return max(t.event_dim for t in self.transforms)
  880. @lazy_property
  881. def length(self) -> int:
  882. return sum(self.lengths)
  883. def with_cache(self, cache_size=1):
  884. if self._cache_size == cache_size:
  885. return self
  886. return CatTransform(self.transforms, self.dim, self.lengths, cache_size)
  887. def _call(self, x):
  888. assert -x.dim() <= self.dim < x.dim()
  889. assert x.size(self.dim) == self.length
  890. yslices = []
  891. start = 0
  892. for trans, length in zip(self.transforms, self.lengths):
  893. xslice = x.narrow(self.dim, start, length)
  894. yslices.append(trans(xslice))
  895. start = start + length # avoid += for jit compat
  896. return torch.cat(yslices, dim=self.dim)
  897. def _inverse(self, y):
  898. assert -y.dim() <= self.dim < y.dim()
  899. assert y.size(self.dim) == self.length
  900. xslices = []
  901. start = 0
  902. for trans, length in zip(self.transforms, self.lengths):
  903. yslice = y.narrow(self.dim, start, length)
  904. xslices.append(trans.inv(yslice))
  905. start = start + length # avoid += for jit compat
  906. return torch.cat(xslices, dim=self.dim)
  907. def log_abs_det_jacobian(self, x, y):
  908. assert -x.dim() <= self.dim < x.dim()
  909. assert x.size(self.dim) == self.length
  910. assert -y.dim() <= self.dim < y.dim()
  911. assert y.size(self.dim) == self.length
  912. logdetjacs = []
  913. start = 0
  914. for trans, length in zip(self.transforms, self.lengths):
  915. xslice = x.narrow(self.dim, start, length)
  916. yslice = y.narrow(self.dim, start, length)
  917. logdetjac = trans.log_abs_det_jacobian(xslice, yslice)
  918. if trans.event_dim < self.event_dim:
  919. logdetjac = _sum_rightmost(logdetjac, self.event_dim - trans.event_dim)
  920. logdetjacs.append(logdetjac)
  921. start = start + length # avoid += for jit compat
  922. # Decide whether to concatenate or sum.
  923. dim = self.dim
  924. if dim >= 0:
  925. dim = dim - x.dim()
  926. dim = dim + self.event_dim
  927. if dim < 0:
  928. return torch.cat(logdetjacs, dim=dim)
  929. else:
  930. return sum(logdetjacs)
  931. @property
  932. def bijective(self) -> bool: # type: ignore[override]
  933. return all(t.bijective for t in self.transforms)
  934. @constraints.dependent_property
  935. def domain(self):
  936. return constraints.cat(
  937. [t.domain for t in self.transforms], self.dim, self.lengths
  938. )
  939. @constraints.dependent_property
  940. def codomain(self):
  941. return constraints.cat(
  942. [t.codomain for t in self.transforms], self.dim, self.lengths
  943. )
  944. class StackTransform(Transform):
  945. """
  946. Transform functor that applies a sequence of transforms `tseq`
  947. component-wise to each submatrix at `dim`
  948. in a way compatible with :func:`torch.stack`.
  949. Example::
  950. x = torch.stack([torch.range(1, 10), torch.range(1, 10)], dim=1)
  951. t = StackTransform([ExpTransform(), identity_transform], dim=1)
  952. y = t(x)
  953. """
  954. transforms: list[Transform]
  955. def __init__(
  956. self, tseq: Sequence[Transform], dim: int = 0, cache_size: int = 0
  957. ) -> None:
  958. assert all(isinstance(t, Transform) for t in tseq)
  959. if cache_size:
  960. tseq = [t.with_cache(cache_size) for t in tseq]
  961. super().__init__(cache_size=cache_size)
  962. self.transforms = list(tseq)
  963. self.dim = dim
  964. def with_cache(self, cache_size=1):
  965. if self._cache_size == cache_size:
  966. return self
  967. return StackTransform(self.transforms, self.dim, cache_size)
  968. def _slice(self, z):
  969. return [z.select(self.dim, i) for i in range(z.size(self.dim))]
  970. def _call(self, x):
  971. assert -x.dim() <= self.dim < x.dim()
  972. assert x.size(self.dim) == len(self.transforms)
  973. yslices = []
  974. for xslice, trans in zip(self._slice(x), self.transforms):
  975. yslices.append(trans(xslice))
  976. return torch.stack(yslices, dim=self.dim)
  977. def _inverse(self, y):
  978. assert -y.dim() <= self.dim < y.dim()
  979. assert y.size(self.dim) == len(self.transforms)
  980. xslices = []
  981. for yslice, trans in zip(self._slice(y), self.transforms):
  982. xslices.append(trans.inv(yslice))
  983. return torch.stack(xslices, dim=self.dim)
  984. def log_abs_det_jacobian(self, x, y):
  985. assert -x.dim() <= self.dim < x.dim()
  986. assert x.size(self.dim) == len(self.transforms)
  987. assert -y.dim() <= self.dim < y.dim()
  988. assert y.size(self.dim) == len(self.transforms)
  989. logdetjacs = []
  990. yslices = self._slice(y)
  991. xslices = self._slice(x)
  992. for xslice, yslice, trans in zip(xslices, yslices, self.transforms):
  993. logdetjacs.append(trans.log_abs_det_jacobian(xslice, yslice))
  994. return torch.stack(logdetjacs, dim=self.dim)
  995. @property
  996. def bijective(self) -> bool: # type: ignore[override]
  997. return all(t.bijective for t in self.transforms)
  998. @constraints.dependent_property
  999. def domain(self):
  1000. return constraints.stack([t.domain for t in self.transforms], self.dim)
  1001. @constraints.dependent_property
  1002. def codomain(self):
  1003. return constraints.stack([t.codomain for t in self.transforms], self.dim)
  1004. class CumulativeDistributionTransform(Transform):
  1005. """
  1006. Transform via the cumulative distribution function of a probability distribution.
  1007. Args:
  1008. distribution (Distribution): Distribution whose cumulative distribution function to use for
  1009. the transformation.
  1010. Example::
  1011. # Construct a Gaussian copula from a multivariate normal.
  1012. base_dist = MultivariateNormal(
  1013. loc=torch.zeros(2),
  1014. scale_tril=LKJCholesky(2).sample(),
  1015. )
  1016. transform = CumulativeDistributionTransform(Normal(0, 1))
  1017. copula = TransformedDistribution(base_dist, [transform])
  1018. """
  1019. bijective = True
  1020. codomain = constraints.unit_interval
  1021. sign = +1
  1022. def __init__(self, distribution: Distribution, cache_size: int = 0) -> None:
  1023. super().__init__(cache_size=cache_size)
  1024. self.distribution = distribution
  1025. @property
  1026. def domain(self) -> Optional[constraints.Constraint]: # type: ignore[override]
  1027. return self.distribution.support
  1028. def _call(self, x):
  1029. return self.distribution.cdf(x)
  1030. def _inverse(self, y):
  1031. return self.distribution.icdf(y)
  1032. def log_abs_det_jacobian(self, x, y):
  1033. return self.distribution.log_prob(x)
  1034. def with_cache(self, cache_size=1):
  1035. if self._cache_size == cache_size:
  1036. return self
  1037. return CumulativeDistributionTransform(self.distribution, cache_size=cache_size)