_lowrank.py 10 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294
"""Implement various linear algebra algorithms for low rank matrices."""

# Names exported by ``from <module> import *``: the low-rank SVD entry point
# and the PCA routine that delegates to the same low-rank SVD implementation.
__all__ = ["svd_lowrank", "pca_lowrank"]
  3. from typing import Optional
  4. import torch
  5. from torch import _linalg_utils as _utils, Tensor
  6. from torch.overrides import handle_torch_function, has_torch_function
  7. def get_approximate_basis(
  8. A: Tensor,
  9. q: int,
  10. niter: Optional[int] = 2,
  11. M: Optional[Tensor] = None,
  12. ) -> Tensor:
  13. """Return tensor :math:`Q` with :math:`q` orthonormal columns such
  14. that :math:`Q Q^H A` approximates :math:`A`. If :math:`M` is
  15. specified, then :math:`Q` is such that :math:`Q Q^H (A - M)`
  16. approximates :math:`A - M`. without instantiating any tensors
  17. of the size of :math:`A` or :math:`M`.
  18. .. note:: The implementation is based on the Algorithm 4.4 from
  19. Halko et al., 2009.
  20. .. note:: For an adequate approximation of a k-rank matrix
  21. :math:`A`, where k is not known in advance but could be
  22. estimated, the number of :math:`Q` columns, q, can be
  23. chosen according to the following criteria: in general,
  24. :math:`k <= q <= min(2*k, m, n)`. For large low-rank
  25. matrices, take :math:`q = k + 5..10`. If k is
  26. relatively small compared to :math:`min(m, n)`, choosing
  27. :math:`q = k + 0..2` may be sufficient.
  28. .. note:: To obtain repeatable results, reset the seed for the
  29. pseudorandom number generator
  30. Args::
  31. A (Tensor): the input tensor of size :math:`(*, m, n)`
  32. q (int): the dimension of subspace spanned by :math:`Q`
  33. columns.
  34. niter (int, optional): the number of subspace iterations to
  35. conduct; ``niter`` must be a
  36. nonnegative integer. In most cases, the
  37. default value 2 is more than enough.
  38. M (Tensor, optional): the input tensor's mean of size
  39. :math:`(*, m, n)`.
  40. References::
  41. - Nathan Halko, Per-Gunnar Martinsson, and Joel Tropp, Finding
  42. structure with randomness: probabilistic algorithms for
  43. constructing approximate matrix decompositions,
  44. arXiv:0909.4061 [math.NA; math.PR], 2009 (available at
  45. `arXiv <http://arxiv.org/abs/0909.4061>`_).
  46. """
  47. niter = 2 if niter is None else niter
  48. dtype = _utils.get_floating_dtype(A) if not A.is_complex() else A.dtype
  49. matmul = _utils.matmul
  50. R = torch.randn(A.shape[-1], q, dtype=dtype, device=A.device)
  51. # The following code could be made faster using torch.geqrf + torch.ormqr
  52. # but geqrf is not differentiable
  53. X = matmul(A, R)
  54. if M is not None:
  55. X = X - matmul(M, R)
  56. Q = torch.linalg.qr(X).Q
  57. for _ in range(niter):
  58. X = matmul(A.mH, Q)
  59. if M is not None:
  60. X = X - matmul(M.mH, Q)
  61. Q = torch.linalg.qr(X).Q
  62. X = matmul(A, Q)
  63. if M is not None:
  64. X = X - matmul(M, Q)
  65. Q = torch.linalg.qr(X).Q
  66. return Q
  67. def svd_lowrank(
  68. A: Tensor,
  69. q: Optional[int] = 6,
  70. niter: Optional[int] = 2,
  71. M: Optional[Tensor] = None,
  72. ) -> tuple[Tensor, Tensor, Tensor]:
  73. r"""Return the singular value decomposition ``(U, S, V)`` of a matrix,
  74. batches of matrices, or a sparse matrix :math:`A` such that
  75. :math:`A \approx U \operatorname{diag}(S) V^{\text{H}}`. In case :math:`M` is given, then
  76. SVD is computed for the matrix :math:`A - M`.
  77. .. note:: The implementation is based on the Algorithm 5.1 from
  78. Halko et al., 2009.
  79. .. note:: For an adequate approximation of a k-rank matrix
  80. :math:`A`, where k is not known in advance but could be
  81. estimated, the number of :math:`Q` columns, q, can be
  82. chosen according to the following criteria: in general,
  83. :math:`k <= q <= min(2*k, m, n)`. For large low-rank
  84. matrices, take :math:`q = k + 5..10`. If k is
  85. relatively small compared to :math:`min(m, n)`, choosing
  86. :math:`q = k + 0..2` may be sufficient.
  87. .. note:: This is a randomized method. To obtain repeatable results,
  88. set the seed for the pseudorandom number generator
  89. .. note:: In general, use the full-rank SVD implementation
  90. :func:`torch.linalg.svd` for dense matrices due to its 10x
  91. higher performance characteristics. The low-rank SVD
  92. will be useful for huge sparse matrices that
  93. :func:`torch.linalg.svd` cannot handle.
  94. Args::
  95. A (Tensor): the input tensor of size :math:`(*, m, n)`
  96. q (int, optional): a slightly overestimated rank of A.
  97. niter (int, optional): the number of subspace iterations to
  98. conduct; niter must be a nonnegative
  99. integer, and defaults to 2
  100. M (Tensor, optional): the input tensor's mean of size
  101. :math:`(*, m, n)`, which will be broadcasted
  102. to the size of A in this function.
  103. References::
  104. - Nathan Halko, Per-Gunnar Martinsson, and Joel Tropp, Finding
  105. structure with randomness: probabilistic algorithms for
  106. constructing approximate matrix decompositions,
  107. arXiv:0909.4061 [math.NA; math.PR], 2009 (available at
  108. `arXiv <https://arxiv.org/abs/0909.4061>`_).
  109. """
  110. if not torch.jit.is_scripting():
  111. tensor_ops = (A, M)
  112. if not set(map(type, tensor_ops)).issubset(
  113. (torch.Tensor, type(None))
  114. ) and has_torch_function(tensor_ops):
  115. return handle_torch_function(
  116. svd_lowrank, tensor_ops, A, q=q, niter=niter, M=M
  117. )
  118. return _svd_lowrank(A, q=q, niter=niter, M=M)
  119. def _svd_lowrank(
  120. A: Tensor,
  121. q: Optional[int] = 6,
  122. niter: Optional[int] = 2,
  123. M: Optional[Tensor] = None,
  124. ) -> tuple[Tensor, Tensor, Tensor]:
  125. # Algorithm 5.1 in Halko et al., 2009
  126. q = 6 if q is None else q
  127. m, n = A.shape[-2:]
  128. matmul = _utils.matmul
  129. if M is not None:
  130. M = M.broadcast_to(A.size())
  131. # Assume that A is tall
  132. if m < n:
  133. A = A.mH
  134. if M is not None:
  135. M = M.mH
  136. Q = get_approximate_basis(A, q, niter=niter, M=M)
  137. B = matmul(Q.mH, A)
  138. if M is not None:
  139. B = B - matmul(Q.mH, M)
  140. U, S, Vh = torch.linalg.svd(B, full_matrices=False)
  141. V = Vh.mH
  142. U = Q.matmul(U)
  143. if m < n:
  144. U, V = V, U
  145. return U, S, V
  146. def pca_lowrank(
  147. A: Tensor,
  148. q: Optional[int] = None,
  149. center: bool = True,
  150. niter: int = 2,
  151. ) -> tuple[Tensor, Tensor, Tensor]:
  152. r"""Performs linear Principal Component Analysis (PCA) on a low-rank
  153. matrix, batches of such matrices, or sparse matrix.
  154. This function returns a namedtuple ``(U, S, V)`` which is the
  155. nearly optimal approximation of a singular value decomposition of
  156. a centered matrix :math:`A` such that :math:`A \approx U \operatorname{diag}(S) V^{\text{H}}`
  157. .. note:: The relation of ``(U, S, V)`` to PCA is as follows:
  158. - :math:`A` is a data matrix with ``m`` samples and
  159. ``n`` features
  160. - the :math:`V` columns represent the principal directions
  161. - :math:`S ** 2 / (m - 1)` contains the eigenvalues of
  162. :math:`A^T A / (m - 1)` which is the covariance of
  163. ``A`` when ``center=True`` is provided.
  164. - ``matmul(A, V[:, :k])`` projects data to the first k
  165. principal components
  166. .. note:: Different from the standard SVD, the size of returned
  167. matrices depend on the specified rank and q
  168. values as follows:
  169. - :math:`U` is m x q matrix
  170. - :math:`S` is q-vector
  171. - :math:`V` is n x q matrix
  172. .. note:: To obtain repeatable results, reset the seed for the
  173. pseudorandom number generator
  174. Args:
  175. A (Tensor): the input tensor of size :math:`(*, m, n)`
  176. q (int, optional): a slightly overestimated rank of
  177. :math:`A`. By default, ``q = min(6, m,
  178. n)``.
  179. center (bool, optional): if True, center the input tensor,
  180. otherwise, assume that the input is
  181. centered.
  182. niter (int, optional): the number of subspace iterations to
  183. conduct; niter must be a nonnegative
  184. integer, and defaults to 2.
  185. References::
  186. - Nathan Halko, Per-Gunnar Martinsson, and Joel Tropp, Finding
  187. structure with randomness: probabilistic algorithms for
  188. constructing approximate matrix decompositions,
  189. arXiv:0909.4061 [math.NA; math.PR], 2009 (available at
  190. `arXiv <http://arxiv.org/abs/0909.4061>`_).
  191. """
  192. if not torch.jit.is_scripting():
  193. if type(A) is not torch.Tensor and has_torch_function((A,)):
  194. return handle_torch_function(
  195. pca_lowrank, (A,), A, q=q, center=center, niter=niter
  196. )
  197. (m, n) = A.shape[-2:]
  198. if q is None:
  199. q = min(6, m, n)
  200. elif not (q >= 0 and q <= min(m, n)):
  201. raise ValueError(
  202. f"q(={q}) must be non-negative integer and not greater than min(m, n)={min(m, n)}"
  203. )
  204. if not (niter >= 0):
  205. raise ValueError(f"niter(={niter}) must be non-negative integer")
  206. dtype = _utils.get_floating_dtype(A)
  207. if not center:
  208. return _svd_lowrank(A, q, niter=niter, M=None)
  209. if _utils.is_sparse(A):
  210. if len(A.shape) != 2:
  211. raise ValueError("pca_lowrank input is expected to be 2-dimensional tensor")
  212. c = torch.sparse.sum(A, dim=(-2,)) / m
  213. # reshape c
  214. column_indices = c.indices()[0]
  215. indices = torch.zeros(
  216. 2,
  217. len(column_indices),
  218. dtype=column_indices.dtype,
  219. device=column_indices.device,
  220. )
  221. indices[0] = column_indices
  222. C_t = torch.sparse_coo_tensor(
  223. indices, c.values(), (n, 1), dtype=dtype, device=A.device
  224. )
  225. ones_m1_t = torch.ones(A.shape[:-2] + (1, m), dtype=dtype, device=A.device)
  226. M = torch.sparse.mm(C_t, ones_m1_t).mT
  227. return _svd_lowrank(A, q, niter=niter, M=M)
  228. else:
  229. C = A.mean(dim=(-2,), keepdim=True)
  230. return _svd_lowrank(A - C, q, niter=niter, M=None)