_lobpcg.py 43 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
370470570670770870971071171271371471571671771871972072172272372472572672772872973073173273373473573673773873974074174274374474574674774874975075175275375475575675775875976076176276376476576676776876977077177277377477577677777877978078178278378478578678778878979079179279379479579679779879980080180280380480580680780880981081181281381481581681781881982082182282382482582682782882983083183283383483583683783883984084184284384484584684784884985085185285385485585685785885986086186286386486586686786886987087187287387487587687787887988088188288388488588688788888989089189289389489589689789889990090190290390490590690790890991091191291391491591691791891992092192292392492592692792892993093193293393493593693793893994094194294394494594694794894995095195295395495595695795895996096196296396496596696796896997097197297397497597697797897998098198298398498598698798898999099199299399499599699799899910001001100210031004100510061007100810091010101110121013101410151016101710181019102010211022102310241025102610271028102910301031103210331034103510361037103810391040104110421043104410451046104710481049105010511052105310541055105610571058105910601061106210631064106510661067106810691070107110721073107410751076107710781079108010811082108310841085108610871088108910901091109210931094109510961097109810991100110111021103110411051106110711081109111011111112111311141115111611171118111911201121112211231124112511261127112811291130113111321133113411351136113711381139114011411142114311441145114611471148114911501151115211531154115511561157
  1. # mypy: allow-untyped-defs
  2. """Locally Optimal Block Preconditioned Conjugate Gradient methods."""
  3. # Author: Pearu Peterson
  4. # Created: February 2020
  5. from typing import Optional
  6. import torch
  7. from torch import _linalg_utils as _utils, Tensor
  8. from torch.overrides import handle_torch_function, has_torch_function
  9. __all__ = ["lobpcg"]
  10. def _symeig_backward_complete_eigenspace(D_grad, U_grad, A, D, U):
  11. # compute F, such that F_ij = (d_j - d_i)^{-1} for i != j, F_ii = 0
  12. F = D.unsqueeze(-2) - D.unsqueeze(-1)
  13. F.diagonal(dim1=-2, dim2=-1).fill_(float("inf"))
  14. F.pow_(-1)
  15. # A.grad = U (D.grad + (U^T U.grad * F)) U^T
  16. Ut = U.mT.contiguous()
  17. res = torch.matmul(
  18. U, torch.matmul(torch.diag_embed(D_grad) + torch.matmul(Ut, U_grad) * F, Ut)
  19. )
  20. return res
  21. def _polynomial_coefficients_given_roots(roots):
  22. """
  23. Given the `roots` of a polynomial, find the polynomial's coefficients.
  24. If roots = (r_1, ..., r_n), then the method returns
  25. coefficients (a_0, a_1, ..., a_n (== 1)) so that
  26. p(x) = (x - r_1) * ... * (x - r_n)
  27. = x^n + a_{n-1} * x^{n-1} + ... a_1 * x_1 + a_0
  28. Note: for better performance requires writing a low-level kernel
  29. """
  30. poly_order = roots.shape[-1]
  31. poly_coeffs_shape = list(roots.shape)
  32. # we assume p(x) = x^n + a_{n-1} * x^{n-1} + ... + a_1 * x + a_0,
  33. # so poly_coeffs = {a_0, ..., a_n, a_{n+1}(== 1)},
  34. # but we insert one extra coefficient to enable better vectorization below
  35. poly_coeffs_shape[-1] += 2
  36. poly_coeffs = roots.new_zeros(poly_coeffs_shape)
  37. poly_coeffs[..., 0] = 1
  38. poly_coeffs[..., -1] = 1
  39. # perform the Horner's rule
  40. for i in range(1, poly_order + 1):
  41. # note that it is computationally hard to compute backward for this method,
  42. # because then given the coefficients it would require finding the roots and/or
  43. # calculating the sensitivity based on the Vieta's theorem.
  44. # So the code below tries to circumvent the explicit root finding by series
  45. # of operations on memory copies imitating the Horner's method.
  46. # The memory copies are required to construct nodes in the computational graph
  47. # by exploiting the explicit (not in-place, separate node for each step)
  48. # recursion of the Horner's method.
  49. # Needs more memory, O(... * k^2), but with only O(... * k^2) complexity.
  50. poly_coeffs_new = poly_coeffs.clone() if roots.requires_grad else poly_coeffs
  51. out = poly_coeffs_new.narrow(-1, poly_order - i, i + 1)
  52. out -= roots.narrow(-1, i - 1, 1) * poly_coeffs.narrow(
  53. -1, poly_order - i + 1, i + 1
  54. )
  55. poly_coeffs = poly_coeffs_new
  56. return poly_coeffs.narrow(-1, 1, poly_order + 1)
  57. def _polynomial_value(poly, x, zero_power, transition):
  58. """
  59. A generic method for computing poly(x) using the Horner's rule.
  60. Args:
  61. poly (Tensor): the (possibly batched) 1D Tensor representing
  62. polynomial coefficients such that
  63. poly[..., i] = (a_{i_0}, ..., a{i_n} (==1)), and
  64. poly(x) = poly[..., 0] * zero_power + ... + poly[..., n] * x^n
  65. x (Tensor): the value (possible batched) to evaluate the polynomial `poly` at.
  66. zero_power (Tensor): the representation of `x^0`. It is application-specific.
  67. transition (Callable): the function that accepts some intermediate result `int_val`,
  68. the `x` and a specific polynomial coefficient
  69. `poly[..., k]` for some iteration `k`.
  70. It basically performs one iteration of the Horner's rule
  71. defined as `x * int_val + poly[..., k] * zero_power`.
  72. Note that `zero_power` is not a parameter,
  73. because the step `+ poly[..., k] * zero_power` depends on `x`,
  74. whether it is a vector, a matrix, or something else, so this
  75. functionality is delegated to the user.
  76. """
  77. res = zero_power.clone()
  78. for k in range(poly.size(-1) - 2, -1, -1):
  79. res = transition(res, x, poly[..., k])
  80. return res
  81. def _matrix_polynomial_value(poly, x, zero_power=None):
  82. """
  83. Evaluates `poly(x)` for the (batched) matrix input `x`.
  84. Check out `_polynomial_value` function for more details.
  85. """
  86. # matrix-aware Horner's rule iteration
  87. def transition(curr_poly_val, x, poly_coeff):
  88. res = x.matmul(curr_poly_val)
  89. res.diagonal(dim1=-2, dim2=-1).add_(poly_coeff.unsqueeze(-1))
  90. return res
  91. if zero_power is None:
  92. zero_power = torch.eye(
  93. x.size(-1), x.size(-1), dtype=x.dtype, device=x.device
  94. ).view(*([1] * len(list(x.shape[:-2]))), x.size(-1), x.size(-1))
  95. return _polynomial_value(poly, x, zero_power, transition)
  96. def _vector_polynomial_value(poly, x, zero_power=None):
  97. """
  98. Evaluates `poly(x)` for the (batched) vector input `x`.
  99. Check out `_polynomial_value` function for more details.
  100. """
  101. # vector-aware Horner's rule iteration
  102. def transition(curr_poly_val, x, poly_coeff):
  103. res = torch.addcmul(poly_coeff.unsqueeze(-1), x, curr_poly_val)
  104. return res
  105. if zero_power is None:
  106. zero_power = x.new_ones(1).expand(x.shape)
  107. return _polynomial_value(poly, x, zero_power, transition)
  108. def _symeig_backward_partial_eigenspace(D_grad, U_grad, A, D, U, largest):
  109. # compute a projection operator onto an orthogonal subspace spanned by the
  110. # columns of U defined as (I - UU^T)
  111. Ut = U.mT.contiguous()
  112. proj_U_ortho = -U.matmul(Ut)
  113. proj_U_ortho.diagonal(dim1=-2, dim2=-1).add_(1)
  114. # compute U_ortho, a basis for the orthogonal complement to the span(U),
  115. # by projecting a random [..., m, m - k] matrix onto the subspace spanned
  116. # by the columns of U.
  117. #
  118. # fix generator for determinism
  119. gen = torch.Generator(A.device)
  120. # orthogonal complement to the span(U)
  121. U_ortho = proj_U_ortho.matmul(
  122. torch.randn(
  123. (*A.shape[:-1], A.size(-1) - D.size(-1)),
  124. dtype=A.dtype,
  125. device=A.device,
  126. generator=gen,
  127. )
  128. )
  129. U_ortho_t = U_ortho.mT.contiguous()
  130. # compute the coefficients of the characteristic polynomial of the tensor D.
  131. # Note that D is diagonal, so the diagonal elements are exactly the roots
  132. # of the characteristic polynomial.
  133. chr_poly_D = _polynomial_coefficients_given_roots(D)
  134. # the code below finds the explicit solution to the Sylvester equation
  135. # U_ortho^T A U_ortho dX - dX D = -U_ortho^T A U
  136. # and incorporates it into the whole gradient stored in the `res` variable.
  137. #
  138. # Equivalent to the following naive implementation:
  139. # res = A.new_zeros(A.shape)
  140. # p_res = A.new_zeros(*A.shape[:-1], D.size(-1))
  141. # for k in range(1, chr_poly_D.size(-1)):
  142. # p_res.zero_()
  143. # for i in range(0, k):
  144. # p_res += (A.matrix_power(k - 1 - i) @ U_grad) * D.pow(i).unsqueeze(-2)
  145. # res -= chr_poly_D[k] * (U_ortho @ poly_D_at_A.inverse() @ U_ortho_t @ p_res @ U.t())
  146. #
  147. # Note that dX is a differential, so the gradient contribution comes from the backward sensitivity
  148. # Tr(f(U_grad, D_grad, A, U, D)^T dX) = Tr(g(U_grad, A, U, D)^T dA) for some functions f and g,
  149. # and we need to compute g(U_grad, A, U, D)
  150. #
  151. # The naive implementation is based on the paper
  152. # Hu, Qingxi, and Daizhan Cheng.
  153. # "The polynomial solution to the Sylvester matrix equation."
  154. # Applied mathematics letters 19.9 (2006): 859-864.
  155. #
  156. # We can modify the computation of `p_res` from above in a more efficient way
  157. # p_res = U_grad * (chr_poly_D[1] * D.pow(0) + ... + chr_poly_D[k] * D.pow(k)).unsqueeze(-2)
  158. # + A U_grad * (chr_poly_D[2] * D.pow(0) + ... + chr_poly_D[k] * D.pow(k - 1)).unsqueeze(-2)
  159. # + ...
  160. # + A.matrix_power(k - 1) U_grad * chr_poly_D[k]
  161. # Note that this saves us from redundant matrix products with A (elimination of matrix_power)
  162. U_grad_projected = U_grad
  163. series_acc = U_grad_projected.new_zeros(U_grad_projected.shape)
  164. for k in range(1, chr_poly_D.size(-1)):
  165. poly_D = _vector_polynomial_value(chr_poly_D[..., k:], D)
  166. series_acc += U_grad_projected * poly_D.unsqueeze(-2)
  167. U_grad_projected = A.matmul(U_grad_projected)
  168. # compute chr_poly_D(A) which essentially is:
  169. #
  170. # chr_poly_D_at_A = A.new_zeros(A.shape)
  171. # for k in range(chr_poly_D.size(-1)):
  172. # chr_poly_D_at_A += chr_poly_D[k] * A.matrix_power(k)
  173. #
  174. # Note, however, for better performance we use the Horner's rule
  175. chr_poly_D_at_A = _matrix_polynomial_value(chr_poly_D, A)
  176. # compute the action of `chr_poly_D_at_A` restricted to U_ortho_t
  177. chr_poly_D_at_A_to_U_ortho = torch.matmul(
  178. U_ortho_t, torch.matmul(chr_poly_D_at_A, U_ortho)
  179. )
  180. # we need to invert 'chr_poly_D_at_A_to_U_ortho`, for that we compute its
  181. # Cholesky decomposition and then use `torch.cholesky_solve` for better stability.
  182. # Cholesky decomposition requires the input to be positive-definite.
  183. # Note that `chr_poly_D_at_A_to_U_ortho` is positive-definite if
  184. # 1. `largest` == False, or
  185. # 2. `largest` == True and `k` is even
  186. # under the assumption that `A` has distinct eigenvalues.
  187. #
  188. # check if `chr_poly_D_at_A_to_U_ortho` is positive-definite or negative-definite
  189. chr_poly_D_at_A_to_U_ortho_sign = -1 if (largest and (k % 2 == 1)) else +1
  190. chr_poly_D_at_A_to_U_ortho_L = torch.linalg.cholesky(
  191. chr_poly_D_at_A_to_U_ortho_sign * chr_poly_D_at_A_to_U_ortho
  192. )
  193. # compute the gradient part in span(U)
  194. res = _symeig_backward_complete_eigenspace(D_grad, U_grad, A, D, U)
  195. # incorporate the Sylvester equation solution into the full gradient
  196. # it resides in span(U_ortho)
  197. res -= U_ortho.matmul(
  198. chr_poly_D_at_A_to_U_ortho_sign
  199. * torch.cholesky_solve(
  200. U_ortho_t.matmul(series_acc), chr_poly_D_at_A_to_U_ortho_L
  201. )
  202. ).matmul(Ut)
  203. return res
  204. def _symeig_backward(D_grad, U_grad, A, D, U, largest):
  205. # if `U` is square, then the columns of `U` is a complete eigenspace
  206. if U.size(-1) == U.size(-2):
  207. return _symeig_backward_complete_eigenspace(D_grad, U_grad, A, D, U)
  208. else:
  209. return _symeig_backward_partial_eigenspace(D_grad, U_grad, A, D, U, largest)
  210. class LOBPCGAutogradFunction(torch.autograd.Function):
  211. @staticmethod
  212. def forward( # type: ignore[override]
  213. ctx,
  214. A: Tensor,
  215. k: Optional[int] = None,
  216. B: Optional[Tensor] = None,
  217. X: Optional[Tensor] = None,
  218. n: Optional[int] = None,
  219. iK: Optional[Tensor] = None,
  220. niter: Optional[int] = None,
  221. tol: Optional[float] = None,
  222. largest: Optional[bool] = None,
  223. method: Optional[str] = None,
  224. tracker: None = None,
  225. ortho_iparams: Optional[dict[str, int]] = None,
  226. ortho_fparams: Optional[dict[str, float]] = None,
  227. ortho_bparams: Optional[dict[str, bool]] = None,
  228. ) -> tuple[Tensor, Tensor]:
  229. # makes sure that input is contiguous for efficiency.
  230. # Note: autograd does not support dense gradients for sparse input yet.
  231. A = A.contiguous() if (not A.is_sparse) else A
  232. if B is not None:
  233. B = B.contiguous() if (not B.is_sparse) else B
  234. D, U = _lobpcg(
  235. A,
  236. k,
  237. B,
  238. X,
  239. n,
  240. iK,
  241. niter,
  242. tol,
  243. largest,
  244. method,
  245. tracker,
  246. ortho_iparams,
  247. ortho_fparams,
  248. ortho_bparams,
  249. )
  250. ctx.save_for_backward(A, B, D, U)
  251. ctx.largest = largest
  252. return D, U
  253. @staticmethod
  254. def backward(ctx, D_grad, U_grad):
  255. A_grad = B_grad = None
  256. grads = [None] * 14
  257. A, B, D, U = ctx.saved_tensors
  258. largest = ctx.largest
  259. # lobpcg.backward has some limitations. Checks for unsupported input
  260. if A.is_sparse or (B is not None and B.is_sparse and ctx.needs_input_grad[2]):
  261. raise ValueError(
  262. "lobpcg.backward does not support sparse input yet."
  263. "Note that lobpcg.forward does though."
  264. )
  265. if (
  266. A.dtype in (torch.complex64, torch.complex128)
  267. or B is not None
  268. and B.dtype in (torch.complex64, torch.complex128)
  269. ):
  270. raise ValueError(
  271. "lobpcg.backward does not support complex input yet."
  272. "Note that lobpcg.forward does though."
  273. )
  274. if B is not None:
  275. raise ValueError(
  276. "lobpcg.backward does not support backward with B != I yet."
  277. )
  278. if largest is None:
  279. largest = True
  280. # symeig backward
  281. if B is None:
  282. A_grad = _symeig_backward(D_grad, U_grad, A, D, U, largest)
  283. # A has index 0
  284. grads[0] = A_grad
  285. # B has index 2
  286. grads[2] = B_grad
  287. return tuple(grads)
  288. def lobpcg(
  289. A: Tensor,
  290. k: Optional[int] = None,
  291. B: Optional[Tensor] = None,
  292. X: Optional[Tensor] = None,
  293. n: Optional[int] = None,
  294. iK: Optional[Tensor] = None,
  295. niter: Optional[int] = None,
  296. tol: Optional[float] = None,
  297. largest: Optional[bool] = None,
  298. method: Optional[str] = None,
  299. tracker: None = None,
  300. ortho_iparams: Optional[dict[str, int]] = None,
  301. ortho_fparams: Optional[dict[str, float]] = None,
  302. ortho_bparams: Optional[dict[str, bool]] = None,
  303. ) -> tuple[Tensor, Tensor]:
  304. """Find the k largest (or smallest) eigenvalues and the corresponding
  305. eigenvectors of a symmetric positive definite generalized
  306. eigenvalue problem using matrix-free LOBPCG methods.
  307. This function is a front-end to the following LOBPCG algorithms
  308. selectable via `method` argument:
  309. `method="basic"` - the LOBPCG method introduced by Andrew
  310. Knyazev, see [Knyazev2001]. A less robust method, may fail when
  311. Cholesky is applied to singular input.
  312. `method="ortho"` - the LOBPCG method with orthogonal basis
  313. selection [StathopoulosEtal2002]. A robust method.
  314. Supported inputs are dense, sparse, and batches of dense matrices.
  315. .. note:: In general, the basic method spends least time per
  316. iteration. However, the robust methods converge much faster and
  317. are more stable. So, the usage of the basic method is generally
  318. not recommended but there exist cases where the usage of the
  319. basic method may be preferred.
  320. .. warning:: The backward method does not support sparse and complex inputs.
  321. It works only when `B` is not provided (i.e. `B == None`).
  322. We are actively working on extensions, and the details of
  323. the algorithms are going to be published promptly.
  324. .. warning:: While it is assumed that `A` is symmetric, `A.grad` is not.
  325. To make sure that `A.grad` is symmetric, so that `A - t * A.grad` is symmetric
  326. in first-order optimization routines, prior to running `lobpcg`
  327. we do the following symmetrization map: `A -> (A + A.t()) / 2`.
  328. The map is performed only when the `A` requires gradients.
  329. .. warning:: LOBPCG algorithm is not applicable when the number of `A`'s rows
  330. is smaller than 3x the number of requested eigenpairs `n`.
  331. Args:
  332. A (Tensor): the input tensor of size :math:`(*, m, m)`
  333. k (integer, optional): the number of requested
  334. eigenpairs. Default is the number of :math:`X`
  335. columns (when specified) or `1`.
  336. B (Tensor, optional): the input tensor of size :math:`(*, m,
  337. m)`. When not specified, `B` is interpreted as
  338. identity matrix.
  339. X (tensor, optional): the input tensor of size :math:`(*, m, n)`
  340. where `k <= n <= m`. When specified, it is used as
  341. initial approximation of eigenvectors. X must be a
  342. dense tensor.
  343. n (integer, optional): if :math:`X` is not specified then `n`
  344. specifies the size of the generated random
  345. approximation of eigenvectors. Default value for `n`
  346. is `k`. If :math:`X` is specified, any provided value of `n` is
  347. ignored and `n` is automatically set to the number of
  348. columns in :math:`X`.
  349. iK (tensor, optional): the input tensor of size :math:`(*, m,
  350. m)`. When specified, it will be used as preconditioner.
  351. niter (int, optional): maximum number of iterations. When
  352. reached, the iteration process is hard-stopped and
  353. the current approximation of eigenpairs is returned.
  354. For infinite iteration but until convergence criteria
  355. is met, use `-1`.
  356. tol (float, optional): residual tolerance for stopping
  357. criterion. Default is `feps ** 0.5` where `feps` is
  358. smallest non-zero floating-point number of the given
  359. input tensor `A` data type.
  360. largest (bool, optional): when True, solve the eigenproblem for
  361. the largest eigenvalues. Otherwise, solve the
  362. eigenproblem for smallest eigenvalues. Default is
  363. `True`.
  364. method (str, optional): select LOBPCG method. See the
  365. description of the function above. Default is
  366. "ortho".
  367. tracker (callable, optional) : a function for tracing the
  368. iteration process. When specified, it is called at
  369. each iteration step with LOBPCG instance as an
  370. argument. The LOBPCG instance holds the full state of
  371. the iteration process in the following attributes:
  372. `iparams`, `fparams`, `bparams` - dictionaries of
  373. integer, float, and boolean valued input
  374. parameters, respectively
  375. `ivars`, `fvars`, `bvars`, `tvars` - dictionaries
  376. of integer, float, boolean, and Tensor valued
  377. iteration variables, respectively.
  378. `A`, `B`, `iK` - input Tensor arguments.
  379. `E`, `X`, `S`, `R` - iteration Tensor variables.
  380. For instance:
  381. `ivars["istep"]` - the current iteration step
  382. `X` - the current approximation of eigenvectors
  383. `E` - the current approximation of eigenvalues
  384. `R` - the current residual
  385. `ivars["converged_count"]` - the current number of converged eigenpairs
  386. `tvars["rerr"]` - the current state of convergence criteria
  387. Note that when `tracker` stores Tensor objects from
  388. the LOBPCG instance, it must make copies of these.
  389. If `tracker` sets `bvars["force_stop"] = True`, the
  390. iteration process will be hard-stopped.
  391. ortho_iparams, ortho_fparams, ortho_bparams (dict, optional):
  392. various parameters to LOBPCG algorithm when using
  393. `method="ortho"`.
  394. Returns:
  395. E (Tensor): tensor of eigenvalues of size :math:`(*, k)`
  396. X (Tensor): tensor of eigenvectors of size :math:`(*, m, k)`
  397. References:
  398. [Knyazev2001] Andrew V. Knyazev. (2001) Toward the Optimal
  399. Preconditioned Eigensolver: Locally Optimal Block Preconditioned
  400. Conjugate Gradient Method. SIAM J. Sci. Comput., 23(2),
  401. 517-541. (25 pages)
  402. https://epubs.siam.org/doi/abs/10.1137/S1064827500366124
  403. [StathopoulosEtal2002] Andreas Stathopoulos and Kesheng
  404. Wu. (2002) A Block Orthogonalization Procedure with Constant
  405. Synchronization Requirements. SIAM J. Sci. Comput., 23(6),
  406. 2165-2182. (18 pages)
  407. https://epubs.siam.org/doi/10.1137/S1064827500370883
  408. [DuerschEtal2018] Jed A. Duersch, Meiyue Shao, Chao Yang, Ming
  409. Gu. (2018) A Robust and Efficient Implementation of LOBPCG.
  410. SIAM J. Sci. Comput., 40(5), C655-C676. (22 pages)
  411. https://arxiv.org/abs/1704.07458
  412. """
  413. if not torch.jit.is_scripting():
  414. tensor_ops = (A, B, X, iK)
  415. if not set(map(type, tensor_ops)).issubset(
  416. (torch.Tensor, type(None))
  417. ) and has_torch_function(tensor_ops):
  418. return handle_torch_function(
  419. lobpcg,
  420. tensor_ops,
  421. A,
  422. k=k,
  423. B=B,
  424. X=X,
  425. n=n,
  426. iK=iK,
  427. niter=niter,
  428. tol=tol,
  429. largest=largest,
  430. method=method,
  431. tracker=tracker,
  432. ortho_iparams=ortho_iparams,
  433. ortho_fparams=ortho_fparams,
  434. ortho_bparams=ortho_bparams,
  435. )
  436. if not torch._jit_internal.is_scripting():
  437. if A.requires_grad or (B is not None and B.requires_grad):
  438. # While it is expected that `A` is symmetric,
  439. # the `A_grad` might be not. Therefore we perform the trick below,
  440. # so that `A_grad` becomes symmetric.
  441. # The symmetrization is important for first-order optimization methods,
  442. # so that (A - alpha * A_grad) is still a symmetric matrix.
  443. # Same holds for `B`.
  444. A_sym = (A + A.mT) / 2
  445. B_sym = (B + B.mT) / 2 if (B is not None) else None
  446. return LOBPCGAutogradFunction.apply(
  447. A_sym,
  448. k,
  449. B_sym,
  450. X,
  451. n,
  452. iK,
  453. niter,
  454. tol,
  455. largest,
  456. method,
  457. tracker,
  458. ortho_iparams,
  459. ortho_fparams,
  460. ortho_bparams,
  461. )
  462. else:
  463. if A.requires_grad or (B is not None and B.requires_grad):
  464. raise RuntimeError(
  465. "Script and require grads is not supported atm."
  466. "If you just want to do the forward, use .detach()"
  467. "on A and B before calling into lobpcg"
  468. )
  469. return _lobpcg(
  470. A,
  471. k,
  472. B,
  473. X,
  474. n,
  475. iK,
  476. niter,
  477. tol,
  478. largest,
  479. method,
  480. tracker,
  481. ortho_iparams,
  482. ortho_fparams,
  483. ortho_bparams,
  484. )
  485. def _lobpcg(
  486. A: Tensor,
  487. k: Optional[int] = None,
  488. B: Optional[Tensor] = None,
  489. X: Optional[Tensor] = None,
  490. n: Optional[int] = None,
  491. iK: Optional[Tensor] = None,
  492. niter: Optional[int] = None,
  493. tol: Optional[float] = None,
  494. largest: Optional[bool] = None,
  495. method: Optional[str] = None,
  496. tracker: None = None,
  497. ortho_iparams: Optional[dict[str, int]] = None,
  498. ortho_fparams: Optional[dict[str, float]] = None,
  499. ortho_bparams: Optional[dict[str, bool]] = None,
  500. ) -> tuple[Tensor, Tensor]:
  501. # A must be square:
  502. assert A.shape[-2] == A.shape[-1], A.shape
  503. if B is not None:
  504. # A and B must have the same shapes:
  505. assert A.shape == B.shape, (A.shape, B.shape)
  506. dtype = _utils.get_floating_dtype(A)
  507. device = A.device
  508. if tol is None:
  509. feps = {torch.float32: 1.2e-07, torch.float64: 2.23e-16}[dtype]
  510. tol = feps**0.5
  511. m = A.shape[-1]
  512. k = (1 if X is None else X.shape[-1]) if k is None else k
  513. n = (k if n is None else n) if X is None else X.shape[-1]
  514. if m < 3 * n:
  515. raise ValueError(
  516. f"LPBPCG algorithm is not applicable when the number of A rows (={m})"
  517. f" is smaller than 3 x the number of requested eigenpairs (={n})"
  518. )
  519. method = "ortho" if method is None else method
  520. iparams = {
  521. "m": m,
  522. "n": n,
  523. "k": k,
  524. "niter": 1000 if niter is None else niter,
  525. }
  526. fparams = {
  527. "tol": tol,
  528. }
  529. bparams = {"largest": True if largest is None else largest}
  530. if method == "ortho":
  531. if ortho_iparams is not None:
  532. iparams.update(ortho_iparams)
  533. if ortho_fparams is not None:
  534. fparams.update(ortho_fparams)
  535. if ortho_bparams is not None:
  536. bparams.update(ortho_bparams)
  537. iparams["ortho_i_max"] = iparams.get("ortho_i_max", 3)
  538. iparams["ortho_j_max"] = iparams.get("ortho_j_max", 3)
  539. fparams["ortho_tol"] = fparams.get("ortho_tol", tol)
  540. fparams["ortho_tol_drop"] = fparams.get("ortho_tol_drop", tol)
  541. fparams["ortho_tol_replace"] = fparams.get("ortho_tol_replace", tol)
  542. bparams["ortho_use_drop"] = bparams.get("ortho_use_drop", False)
  543. if not torch.jit.is_scripting():
  544. LOBPCG.call_tracker = LOBPCG_call_tracker # type: ignore[method-assign]
  545. if len(A.shape) > 2:
  546. N = int(torch.prod(torch.tensor(A.shape[:-2])))
  547. bA = A.reshape((N,) + A.shape[-2:])
  548. bB = B.reshape((N,) + A.shape[-2:]) if B is not None else None
  549. bX = X.reshape((N,) + X.shape[-2:]) if X is not None else None
  550. bE = torch.empty((N, k), dtype=dtype, device=device)
  551. bXret = torch.empty((N, m, k), dtype=dtype, device=device)
  552. for i in range(N):
  553. A_ = bA[i]
  554. B_ = bB[i] if bB is not None else None
  555. X_ = (
  556. torch.randn((m, n), dtype=dtype, device=device) if bX is None else bX[i]
  557. )
  558. assert len(X_.shape) == 2 and X_.shape == (m, n), (X_.shape, (m, n))
  559. iparams["batch_index"] = i
  560. worker = LOBPCG(A_, B_, X_, iK, iparams, fparams, bparams, method, tracker)
  561. worker.run()
  562. bE[i] = worker.E[:k]
  563. bXret[i] = worker.X[:, :k]
  564. if not torch.jit.is_scripting():
  565. LOBPCG.call_tracker = LOBPCG_call_tracker_orig # type: ignore[method-assign]
  566. return bE.reshape(A.shape[:-2] + (k,)), bXret.reshape(A.shape[:-2] + (m, k))
  567. X = torch.randn((m, n), dtype=dtype, device=device) if X is None else X
  568. assert len(X.shape) == 2 and X.shape == (m, n), (X.shape, (m, n))
  569. worker = LOBPCG(A, B, X, iK, iparams, fparams, bparams, method, tracker)
  570. worker.run()
  571. if not torch.jit.is_scripting():
  572. LOBPCG.call_tracker = LOBPCG_call_tracker_orig # type: ignore[method-assign]
  573. return worker.E[:k], worker.X[:, :k]
  574. class LOBPCG:
  575. """Worker class of LOBPCG methods."""
  576. def __init__(
  577. self,
  578. A: Optional[Tensor],
  579. B: Optional[Tensor],
  580. X: Tensor,
  581. iK: Optional[Tensor],
  582. iparams: dict[str, int],
  583. fparams: dict[str, float],
  584. bparams: dict[str, bool],
  585. method: str,
  586. tracker: None,
  587. ) -> None:
  588. # constant parameters
  589. self.A = A
  590. self.B = B
  591. self.iK = iK
  592. self.iparams = iparams
  593. self.fparams = fparams
  594. self.bparams = bparams
  595. self.method = method
  596. self.tracker = tracker
  597. m = iparams["m"]
  598. n = iparams["n"]
  599. # variable parameters
  600. self.X = X
  601. self.E = torch.zeros((n,), dtype=X.dtype, device=X.device)
  602. self.R = torch.zeros((m, n), dtype=X.dtype, device=X.device)
  603. self.S = torch.zeros((m, 3 * n), dtype=X.dtype, device=X.device)
  604. self.tvars: dict[str, Tensor] = {}
  605. self.ivars: dict[str, int] = {"istep": 0}
  606. self.fvars: dict[str, float] = {"_": 0.0}
  607. self.bvars: dict[str, bool] = {"_": False}
  608. def __str__(self):
  609. lines = ["LOPBCG:"]
  610. lines += [f" iparams={self.iparams}"]
  611. lines += [f" fparams={self.fparams}"]
  612. lines += [f" bparams={self.bparams}"]
  613. lines += [f" ivars={self.ivars}"]
  614. lines += [f" fvars={self.fvars}"]
  615. lines += [f" bvars={self.bvars}"]
  616. lines += [f" tvars={self.tvars}"]
  617. lines += [f" A={self.A}"]
  618. lines += [f" B={self.B}"]
  619. lines += [f" iK={self.iK}"]
  620. lines += [f" X={self.X}"]
  621. lines += [f" E={self.E}"]
  622. r = ""
  623. for line in lines:
  624. r += line + "\n"
  625. return r
  626. def update(self):
  627. """Set and update iteration variables."""
  628. if self.ivars["istep"] == 0:
  629. X_norm = float(torch.norm(self.X))
  630. iX_norm = X_norm**-1
  631. A_norm = float(torch.norm(_utils.matmul(self.A, self.X))) * iX_norm
  632. B_norm = float(torch.norm(_utils.matmul(self.B, self.X))) * iX_norm
  633. self.fvars["X_norm"] = X_norm
  634. self.fvars["A_norm"] = A_norm
  635. self.fvars["B_norm"] = B_norm
  636. self.ivars["iterations_left"] = self.iparams["niter"]
  637. self.ivars["converged_count"] = 0
  638. self.ivars["converged_end"] = 0
  639. if self.method == "ortho":
  640. self._update_ortho()
  641. else:
  642. self._update_basic()
  643. self.ivars["iterations_left"] = self.ivars["iterations_left"] - 1
  644. self.ivars["istep"] = self.ivars["istep"] + 1
  645. def update_residual(self):
  646. """Update residual R from A, B, X, E."""
  647. mm = _utils.matmul
  648. self.R = mm(self.A, self.X) - mm(self.B, self.X) * self.E
  649. def update_converged_count(self):
  650. """Determine the number of converged eigenpairs using backward stable
  651. convergence criterion, see discussion in Sec 4.3 of [DuerschEtal2018].
  652. Users may redefine this method for custom convergence criteria.
  653. """
  654. # (...) -> int
  655. prev_count = self.ivars["converged_count"]
  656. tol = self.fparams["tol"]
  657. A_norm = self.fvars["A_norm"]
  658. B_norm = self.fvars["B_norm"]
  659. E, X, R = self.E, self.X, self.R
  660. rerr = torch.norm(R, 2, (0,)) / (
  661. torch.norm(X, 2, (0,)) * (A_norm + torch.abs(E[: X.shape[-1]]) * B_norm)
  662. )
  663. converged = rerr < tol
  664. count = 0
  665. for b in converged:
  666. if not b:
  667. # ignore convergence of following pairs to ensure
  668. # strict ordering of eigenpairs
  669. break
  670. count += 1
  671. assert count >= prev_count, (
  672. f"the number of converged eigenpairs (was {prev_count}, got {count}) cannot decrease"
  673. )
  674. self.ivars["converged_count"] = count
  675. self.tvars["rerr"] = rerr
  676. return count
  677. def stop_iteration(self):
  678. """Return True to stop iterations.
  679. Note that tracker (if defined) can force-stop iterations by
  680. setting ``worker.bvars['force_stop'] = True``.
  681. """
  682. return (
  683. self.bvars.get("force_stop", False)
  684. or self.ivars["iterations_left"] == 0
  685. or self.ivars["converged_count"] >= self.iparams["k"]
  686. )
  687. def run(self):
  688. """Run LOBPCG iterations.
  689. Use this method as a template for implementing LOBPCG
  690. iteration scheme with custom tracker that is compatible with
  691. TorchScript.
  692. """
  693. self.update()
  694. if not torch.jit.is_scripting() and self.tracker is not None:
  695. self.call_tracker()
  696. while not self.stop_iteration():
  697. self.update()
  698. if not torch.jit.is_scripting() and self.tracker is not None:
  699. self.call_tracker()
  700. @torch.jit.unused
  701. def call_tracker(self):
  702. """Interface for tracking iteration process in Python mode.
  703. Tracking the iteration process is disabled in TorchScript
  704. mode. In fact, one should specify tracker=None when JIT
  705. compiling functions using lobpcg.
  706. """
  707. # do nothing when in TorchScript mode
  708. # Internal methods
    def _update_basic(self):
        """
        Update or initialize iteration variables when `method == "basic"`.

        First step (``ivars["istep"] == 0``): a Rayleigh-Ritz procedure on
        ``X`` alone updates ``X`` and ``E`` in place. The residual and the
        converged count are then refreshed, and the search basis is filled
        as ``S = [X, W]`` with ``W = iK R``.

        Later steps: the Rayleigh-Ritz procedure is applied to the active
        columns ``S[:, nc:ns]``, which excludes the leading converged
        columns. This yields the unconverged part of ``X`` and ``E`` and a
        new block ``P``. After refreshing the residual and converged count,
        the basis becomes ``S = [X, P, W]`` with ``W = iK R[:, nc:]``.

        In both branches ``ivars["converged_end"]`` is set to the number of
        filled columns of ``S``.
        """
        mm = torch.matmul
        # number of columns of S filled by the previous step
        ns = self.ivars["converged_end"]
        # number of leading eigenpairs already counted as converged
        nc = self.ivars["converged_count"]
        n = self.iparams["n"]
        largest = self.bparams["largest"]
        if self.ivars["istep"] == 0:
            # Rayleigh-Ritz on X (see _get_rayleigh_ritz_transform)
            Ri = self._get_rayleigh_ritz_transform(self.X)
            M = _utils.qform(_utils.qform(self.A, self.X), Ri)
            E, Z = _utils.symeig(M, largest)
            # X and E are overwritten in place (slice assignment)
            self.X[:] = mm(self.X, mm(Ri, Z))
            self.E[:] = E
            # there is no P block on the first step
            np = 0
            self.update_residual()
            nc = self.update_converged_count()
            self.S[..., :n] = self.X
            # iK applied to all residual columns
            W = _utils.matmul(self.iK, self.R)
            self.ivars["converged_end"] = ns = n + np + W.shape[-1]
            self.S[:, n + np : ns] = W
        else:
            # active part of the search basis (converged leading columns skipped)
            S_ = self.S[:, nc:ns]
            Ri = self._get_rayleigh_ritz_transform(S_)
            M = _utils.qform(_utils.qform(self.A, S_), Ri)
            E_, Z = _utils.symeig(M, largest)
            # the first n - nc Ritz pairs replace the unconverged part of X and E
            self.X[:, nc:] = mm(S_, mm(Ri, Z[:, : n - nc]))
            self.E[nc:] = E_[: n - nc]
            # P is built from columns n .. 2n-nc-1 of Z
            P = mm(S_, mm(Ri, Z[:, n : 2 * n - nc]))
            np = P.shape[-1]
            self.update_residual()
            nc = self.update_converged_count()
            self.S[..., :n] = self.X
            self.S[:, n : n + np] = P
            # only residuals of not-yet-converged pairs enter the new basis
            W = _utils.matmul(self.iK, self.R[:, nc:])
            self.ivars["converged_end"] = ns = n + np + W.shape[-1]
            self.S[:, n + np : ns] = W
  747. def _update_ortho(self):
  748. """
  749. Update or initialize iteration variables when `method == "ortho"`.
  750. """
  751. mm = torch.matmul
  752. ns = self.ivars["converged_end"]
  753. nc = self.ivars["converged_count"]
  754. n = self.iparams["n"]
  755. largest = self.bparams["largest"]
  756. if self.ivars["istep"] == 0:
  757. Ri = self._get_rayleigh_ritz_transform(self.X)
  758. M = _utils.qform(_utils.qform(self.A, self.X), Ri)
  759. _E, Z = _utils.symeig(M, largest)
  760. self.X = mm(self.X, mm(Ri, Z))
  761. self.update_residual()
  762. np = 0
  763. nc = self.update_converged_count()
  764. self.S[:, :n] = self.X
  765. W = self._get_ortho(self.R, self.X)
  766. ns = self.ivars["converged_end"] = n + np + W.shape[-1]
  767. self.S[:, n + np : ns] = W
  768. else:
  769. S_ = self.S[:, nc:ns]
  770. # Rayleigh-Ritz procedure
  771. E_, Z = _utils.symeig(_utils.qform(self.A, S_), largest)
  772. # Update E, X, P
  773. self.X[:, nc:] = mm(S_, Z[:, : n - nc])
  774. self.E[nc:] = E_[: n - nc]
  775. P = mm(S_, mm(Z[:, n - nc :], _utils.basis(Z[: n - nc, n - nc :].mT)))
  776. np = P.shape[-1]
  777. # check convergence
  778. self.update_residual()
  779. nc = self.update_converged_count()
  780. # update S
  781. self.S[:, :n] = self.X
  782. self.S[:, n : n + np] = P
  783. W = self._get_ortho(self.R[:, nc:], self.S[:, : n + np])
  784. ns = self.ivars["converged_end"] = n + np + W.shape[-1]
  785. self.S[:, n + np : ns] = W
  786. def _get_rayleigh_ritz_transform(self, S):
  787. """Return a transformation matrix that is used in Rayleigh-Ritz
  788. procedure for reducing a general eigenvalue problem :math:`(S^TAS)
  789. C = (S^TBS) C E` to a standard eigenvalue problem :math: `(Ri^T
  790. S^TAS Ri) Z = Z E` where `C = Ri Z`.
  791. .. note:: In the original Rayleight-Ritz procedure in
  792. [DuerschEtal2018], the problem is formulated as follows::
  793. SAS = S^T A S
  794. SBS = S^T B S
  795. D = (<diagonal matrix of SBS>) ** -1/2
  796. R^T R = Cholesky(D SBS D)
  797. Ri = D R^-1
  798. solve symeig problem Ri^T SAS Ri Z = Theta Z
  799. C = Ri Z
  800. To reduce the number of matrix products (denoted by empty
  801. space between matrices), here we introduce element-wise
  802. products (denoted by symbol `*`) so that the Rayleight-Ritz
  803. procedure becomes::
  804. SAS = S^T A S
  805. SBS = S^T B S
  806. d = (<diagonal of SBS>) ** -1/2 # this is 1-d column vector
  807. dd = d d^T # this is 2-d matrix
  808. R^T R = Cholesky(dd * SBS)
  809. Ri = R^-1 * d # broadcasting
  810. solve symeig problem Ri^T SAS Ri Z = Theta Z
  811. C = Ri Z
  812. where `dd` is 2-d matrix that replaces matrix products `D M
  813. D` with one element-wise product `M * dd`; and `d` replaces
  814. matrix product `D M` with element-wise product `M *
  815. d`. Also, creating the diagonal matrix `D` is avoided.
  816. Args:
  817. S (Tensor): the matrix basis for the search subspace, size is
  818. :math:`(m, n)`.
  819. Returns:
  820. Ri (tensor): upper-triangular transformation matrix of size
  821. :math:`(n, n)`.
  822. """
  823. B = self.B
  824. SBS = _utils.qform(B, S)
  825. d_row = SBS.diagonal(0, -2, -1) ** -0.5
  826. d_col = d_row.reshape(d_row.shape[0], 1)
  827. # TODO use torch.linalg.cholesky_solve once it is implemented
  828. R = torch.linalg.cholesky((SBS * d_row) * d_col, upper=True)
  829. return torch.linalg.solve_triangular(
  830. R, d_row.diag_embed(), upper=True, left=False
  831. )
  832. def _get_svqb(self, U: Tensor, drop: bool, tau: float) -> Tensor:
  833. """Return B-orthonormal U.
  834. .. note:: When `drop` is `False` then `svqb` is based on the
  835. Algorithm 4 from [DuerschPhD2015] that is a slight
  836. modification of the corresponding algorithm
  837. introduced in [StathopolousWu2002].
  838. Args:
  839. U (Tensor) : initial approximation, size is (m, n)
  840. drop (bool) : when True, drop columns that
  841. contribution to the `span([U])` is small.
  842. tau (float) : positive tolerance
  843. Returns:
  844. U (Tensor) : B-orthonormal columns (:math:`U^T B U = I`), size
  845. is (m, n1), where `n1 = n` if `drop` is `False,
  846. otherwise `n1 <= n`.
  847. """
  848. if torch.numel(U) == 0:
  849. return U
  850. UBU = _utils.qform(self.B, U)
  851. d = UBU.diagonal(0, -2, -1)
  852. # Detect and drop exact zero columns from U. While the test
  853. # `abs(d) == 0` is unlikely to be True for random data, it is
  854. # possible to construct input data to lobpcg where it will be
  855. # True leading to a failure (notice the `d ** -0.5` operation
  856. # in the original algorithm). To prevent the failure, we drop
  857. # the exact zero columns here and then continue with the
  858. # original algorithm below.
  859. nz = torch.where(abs(d) != 0.0)
  860. assert len(nz) == 1, nz
  861. if len(nz[0]) < len(d):
  862. U = U[:, nz[0]]
  863. if torch.numel(U) == 0:
  864. return U
  865. UBU = _utils.qform(self.B, U)
  866. d = UBU.diagonal(0, -2, -1)
  867. nz = torch.where(abs(d) != 0.0)
  868. assert len(nz[0]) == len(d)
  869. # The original algorithm 4 from [DuerschPhD2015].
  870. d_col = (d**-0.5).reshape(d.shape[0], 1)
  871. DUBUD = (UBU * d_col) * d_col.mT
  872. E, Z = _utils.symeig(DUBUD)
  873. t = tau * abs(E).max()
  874. if drop:
  875. keep = torch.where(E > t)
  876. assert len(keep) == 1, keep
  877. E = E[keep[0]]
  878. Z = Z[:, keep[0]]
  879. d_col = d_col[keep[0]]
  880. else:
  881. E[(torch.where(E < t))[0]] = t
  882. return torch.matmul(U * d_col.mT, Z * E**-0.5)
  883. def _get_ortho(self, U, V):
  884. """Return B-orthonormal U with columns are B-orthogonal to V.
  885. .. note:: When `bparams["ortho_use_drop"] == False` then
  886. `_get_ortho` is based on the Algorithm 3 from
  887. [DuerschPhD2015] that is a slight modification of
  888. the corresponding algorithm introduced in
  889. [StathopolousWu2002]. Otherwise, the method
  890. implements Algorithm 6 from [DuerschPhD2015]
  891. .. note:: If all U columns are B-collinear to V then the
  892. returned tensor U will be empty.
  893. Args:
  894. U (Tensor) : initial approximation, size is (m, n)
  895. V (Tensor) : B-orthogonal external basis, size is (m, k)
  896. Returns:
  897. U (Tensor) : B-orthonormal columns (:math:`U^T B U = I`)
  898. such that :math:`V^T B U=0`, size is (m, n1),
  899. where `n1 = n` if `drop` is `False, otherwise
  900. `n1 <= n`.
  901. """
  902. mm = torch.matmul
  903. mm_B = _utils.matmul
  904. m = self.iparams["m"]
  905. tau_ortho = self.fparams["ortho_tol"]
  906. tau_drop = self.fparams["ortho_tol_drop"]
  907. tau_replace = self.fparams["ortho_tol_replace"]
  908. i_max = self.iparams["ortho_i_max"]
  909. j_max = self.iparams["ortho_j_max"]
  910. # when use_drop==True, enable dropping U columns that have
  911. # small contribution to the `span([U, V])`.
  912. use_drop = self.bparams["ortho_use_drop"]
  913. # clean up variables from the previous call
  914. for vkey in list(self.fvars.keys()):
  915. if vkey.startswith("ortho_") and vkey.endswith("_rerr"):
  916. self.fvars.pop(vkey)
  917. self.ivars.pop("ortho_i", 0)
  918. self.ivars.pop("ortho_j", 0)
  919. BV_norm = torch.norm(mm_B(self.B, V))
  920. BU = mm_B(self.B, U)
  921. VBU = mm(V.mT, BU)
  922. i = j = 0
  923. for i in range(i_max):
  924. U = U - mm(V, VBU)
  925. drop = False
  926. tau_svqb = tau_drop
  927. for j in range(j_max):
  928. if use_drop:
  929. U = self._get_svqb(U, drop, tau_svqb)
  930. drop = True
  931. tau_svqb = tau_replace
  932. else:
  933. U = self._get_svqb(U, False, tau_replace)
  934. if torch.numel(U) == 0:
  935. # all initial U columns are B-collinear to V
  936. self.ivars["ortho_i"] = i
  937. self.ivars["ortho_j"] = j
  938. return U
  939. BU = mm_B(self.B, U)
  940. UBU = mm(U.mT, BU)
  941. U_norm = torch.norm(U)
  942. BU_norm = torch.norm(BU)
  943. R = UBU - torch.eye(UBU.shape[-1], device=UBU.device, dtype=UBU.dtype)
  944. R_norm = torch.norm(R)
  945. # https://github.com/pytorch/pytorch/issues/33810 workaround:
  946. rerr = float(R_norm) * float(BU_norm * U_norm) ** -1
  947. vkey = f"ortho_UBUmI_rerr[{i}, {j}]"
  948. self.fvars[vkey] = rerr
  949. if rerr < tau_ortho:
  950. break
  951. VBU = mm(V.mT, BU)
  952. VBU_norm = torch.norm(VBU)
  953. U_norm = torch.norm(U)
  954. rerr = float(VBU_norm) * float(BV_norm * U_norm) ** -1
  955. vkey = f"ortho_VBU_rerr[{i}]"
  956. self.fvars[vkey] = rerr
  957. if rerr < tau_ortho:
  958. break
  959. if m < U.shape[-1] + V.shape[-1]:
  960. # TorchScript needs the class var to be assigned to a local to
  961. # do optional type refinement
  962. B = self.B
  963. assert B is not None
  964. raise ValueError(
  965. "Overdetermined shape of U:"
  966. f" #B-cols(={B.shape[-1]}) >= #U-cols(={U.shape[-1]}) + #V-cols(={V.shape[-1]}) must hold"
  967. )
  968. self.ivars["ortho_i"] = i
  969. self.ivars["ortho_j"] = j
  970. return U
# Calling tracker is separated from LOBPCG definitions because
# TorchScript does not support user-defined callback arguments.
# Keep a reference to the original (no-op) LOBPCG.call_tracker method.
LOBPCG_call_tracker_orig = LOBPCG.call_tracker
  974. def LOBPCG_call_tracker(self):
  975. self.tracker(self)