__init__.py 25 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710
  1. # mypy: allow-untyped-defs
  2. # The Tensor classes are added to this module by python_tensor.cpp
  3. # A workaround to support both TorchScript and MyPy:
  4. from typing import Any, Optional, TYPE_CHECKING, Union
  5. import torch
  6. from torch import Tensor
  7. from torch._C import _add_docstr, _sparse # type: ignore[attr-defined]
  8. # Semi structured sparsity support
  9. from .semi_structured import (
  10. SparseSemiStructuredTensor,
  11. SparseSemiStructuredTensorCUSPARSELT,
  12. SparseSemiStructuredTensorCUTLASS,
  13. to_sparse_semi_structured,
  14. )
# Type aliases for the annotations used in this module. Static type checkers
# (where ``TYPE_CHECKING`` is True) see the precise types in the first branch;
# at runtime the simplified fallbacks in the ``else`` branch are bound instead,
# so that the same annotations work for both TorchScript and MyPy.
if TYPE_CHECKING:
    from torch.types import _dtype as DType

    # A ``dim`` argument may be None, a single int, a tuple of ints, or a list of ints.
    DimOrDims = Optional[Union[int, tuple[int, ...], list[int]]]
else:
    # The JIT doesn't understand Union, nor torch.dtype here
    DType = int
    DimOrDims = Optional[tuple[int]]
  22. __all__ = [
  23. "addmm",
  24. "check_sparse_tensor_invariants",
  25. "mm",
  26. "sum",
  27. "softmax",
  28. "solve",
  29. "log_softmax",
  30. "SparseSemiStructuredTensor",
  31. "SparseSemiStructuredTensorCUTLASS",
  32. "SparseSemiStructuredTensorCUSPARSELT",
  33. "to_sparse_semi_structured",
  34. "as_sparse_gradcheck",
  35. ]
  36. addmm = _add_docstr(
  37. _sparse._sparse_addmm,
  38. r"""
  39. sparse.addmm(mat, mat1, mat2, *, beta=1., alpha=1.) -> Tensor
  40. This function does exact same thing as :func:`torch.addmm` in the forward,
  41. except that it supports backward for sparse COO matrix :attr:`mat1`.
  42. When :attr:`mat1` is a COO tensor it must have `sparse_dim = 2`.
  43. When inputs are COO tensors, this function also supports backward for both inputs.
  44. Supports both CSR and COO storage formats.
  45. .. note::
  46. This function doesn't support computing derivatives with respect to CSR matrices.
  47. Args:
  48. mat (Tensor): a dense matrix to be added
  49. mat1 (Tensor): a sparse matrix to be multiplied
  50. mat2 (Tensor): a dense matrix to be multiplied
  51. beta (Number, optional): multiplier for :attr:`mat` (:math:`\beta`)
  52. alpha (Number, optional): multiplier for :math:`mat1 @ mat2` (:math:`\alpha`)
  53. """,
  54. )
  55. mm = _add_docstr(
  56. _sparse._sparse_mm,
  57. r"""
  58. Performs a matrix multiplication of the sparse matrix :attr:`mat1`
  59. and the (sparse or strided) matrix :attr:`mat2`. Similar to :func:`torch.mm`, if :attr:`mat1` is a
  60. :math:`(n \times m)` tensor, :attr:`mat2` is a :math:`(m \times p)` tensor, out will be a
  61. :math:`(n \times p)` tensor.
  62. When :attr:`mat1` is a COO tensor it must have `sparse_dim = 2`.
  63. When inputs are COO tensors, this function also supports backward for both inputs.
  64. Supports both CSR and COO storage formats.
  65. .. note::
  66. This function doesn't support computing derivatives with respect to CSR matrices.
  67. This function also additionally accepts an optional :attr:`reduce` argument that allows
  68. specification of an optional reduction operation, mathematically performs the following operation:
  69. .. math::
  70. z_{ij} = \bigoplus_{k = 0}^{K - 1} x_{ik} y_{kj}
  71. where :math:`\bigoplus` defines the reduce operator. :attr:`reduce` is implemented only for
  72. CSR storage format on CPU device.
  73. Args:
  74. mat1 (Tensor): the first sparse matrix to be multiplied
  75. mat2 (Tensor): the second matrix to be multiplied, which could be sparse or dense
  76. reduce (str, optional): the reduction operation to apply for non-unique indices
  77. (:obj:`"sum"`, :obj:`"mean"`, :obj:`"amax"`, :obj:`"amin"`). Default :obj:`"sum"`.
  78. Shape:
  79. The format of the output tensor of this function follows:
  80. - sparse x sparse -> sparse
  81. - sparse x dense -> dense
  82. Example::
  83. >>> a = torch.tensor([[1., 0, 2], [0, 3, 0]]).to_sparse().requires_grad_()
  84. >>> a
  85. tensor(indices=tensor([[0, 0, 1],
  86. [0, 2, 1]]),
  87. values=tensor([1., 2., 3.]),
  88. size=(2, 3), nnz=3, layout=torch.sparse_coo, requires_grad=True)
  89. >>> b = torch.tensor([[0, 1.], [2, 0], [0, 0]], requires_grad=True)
  90. >>> b
  91. tensor([[0., 1.],
  92. [2., 0.],
  93. [0., 0.]], requires_grad=True)
  94. >>> y = torch.sparse.mm(a, b)
  95. >>> y
  96. tensor([[0., 1.],
  97. [6., 0.]], grad_fn=<SparseAddmmBackward0>)
  98. >>> y.sum().backward()
  99. >>> a.grad
  100. tensor(indices=tensor([[0, 0, 1],
  101. [0, 2, 1]]),
  102. values=tensor([1., 0., 2.]),
  103. size=(2, 3), nnz=3, layout=torch.sparse_coo)
  104. >>> c = a.detach().to_sparse_csr()
  105. >>> c
  106. tensor(crow_indices=tensor([0, 2, 3]),
  107. col_indices=tensor([0, 2, 1]),
  108. values=tensor([1., 2., 3.]), size=(2, 3), nnz=3,
  109. layout=torch.sparse_csr)
  110. >>> y1 = torch.sparse.mm(c, b, 'sum')
  111. >>> y1
  112. tensor([[0., 1.],
  113. [6., 0.]], grad_fn=<SparseMmReduceImplBackward0>)
  114. >>> y2 = torch.sparse.mm(c, b, 'max')
  115. >>> y2
  116. tensor([[0., 1.],
  117. [6., 0.]], grad_fn=<SparseMmReduceImplBackward0>)
  118. """,
  119. )
  120. sampled_addmm = _add_docstr(
  121. _sparse.sparse_sampled_addmm,
  122. r"""
  123. sparse.sampled_addmm(input, mat1, mat2, *, beta=1., alpha=1., out=None) -> Tensor
  124. Performs a matrix multiplication of the dense matrices :attr:`mat1` and :attr:`mat2` at the locations
  125. specified by the sparsity pattern of :attr:`input`. The matrix :attr:`input` is added to the final result.
  126. Mathematically this performs the following operation:
  127. .. math::
  128. \text{out} = \alpha\ (\text{mat1} \mathbin{@} \text{mat2})*\text{spy}(\text{input}) + \beta\ \text{input}
  129. where :math:`\text{spy}(\text{input})` is the sparsity pattern matrix of :attr:`input`, :attr:`alpha`
  130. and :attr:`beta` are the scaling factors.
  131. :math:`\text{spy}(\text{input})` has value 1 at the positions where :attr:`input` has non-zero values, and 0 elsewhere.
  132. .. note::
  133. :attr:`input` must be a sparse CSR tensor. :attr:`mat1` and :attr:`mat2` must be dense tensors.
  134. Args:
  135. input (Tensor): a sparse CSR matrix of shape `(m, n)` to be added and used to compute
  136. the sampled matrix multiplication
  137. mat1 (Tensor): a dense matrix of shape `(m, k)` to be multiplied
  138. mat2 (Tensor): a dense matrix of shape `(k, n)` to be multiplied
  139. Keyword args:
  140. beta (Number, optional): multiplier for :attr:`input` (:math:`\beta`)
  141. alpha (Number, optional): multiplier for :math:`mat1 @ mat2` (:math:`\alpha`)
  142. out (Tensor, optional): output tensor. Ignored if `None`. Default: `None`.
  143. Examples::
  144. >>> input = torch.eye(3, device='cuda').to_sparse_csr()
  145. >>> mat1 = torch.randn(3, 5, device='cuda')
  146. >>> mat2 = torch.randn(5, 3, device='cuda')
  147. >>> torch.sparse.sampled_addmm(input, mat1, mat2)
  148. tensor(crow_indices=tensor([0, 1, 2, 3]),
  149. col_indices=tensor([0, 1, 2]),
  150. values=tensor([ 0.2847, -0.7805, -0.1900]), device='cuda:0',
  151. size=(3, 3), nnz=3, layout=torch.sparse_csr)
  152. >>> torch.sparse.sampled_addmm(input, mat1, mat2).to_dense()
  153. tensor([[ 0.2847, 0.0000, 0.0000],
  154. [ 0.0000, -0.7805, 0.0000],
  155. [ 0.0000, 0.0000, -0.1900]], device='cuda:0')
  156. >>> torch.sparse.sampled_addmm(input, mat1, mat2, beta=0.5, alpha=0.5)
  157. tensor(crow_indices=tensor([0, 1, 2, 3]),
  158. col_indices=tensor([0, 1, 2]),
  159. values=tensor([ 0.1423, -0.3903, -0.0950]), device='cuda:0',
  160. size=(3, 3), nnz=3, layout=torch.sparse_csr)
  161. """,
  162. )
  163. def sum(input: Tensor, dim: DimOrDims = None, dtype: Optional[DType] = None) -> Tensor:
  164. r"""Return the sum of each row of the given sparse tensor.
  165. Returns the sum of each row of the sparse tensor :attr:`input` in the given
  166. dimensions :attr:`dim`. If :attr:`dim` is a list of dimensions,
  167. reduce over all of them. When sum over all ``sparse_dim``, this method
  168. returns a dense tensor instead of a sparse tensor.
  169. All summed :attr:`dim` are squeezed (see :func:`torch.squeeze`), resulting an output
  170. tensor having :attr:`dim` fewer dimensions than :attr:`input`.
  171. During backward, only gradients at ``nnz`` locations of :attr:`input`
  172. will propagate back. Note that the gradients of :attr:`input` is coalesced.
  173. Args:
  174. input (Tensor): the input sparse tensor
  175. dim (int or tuple of ints): a dimension or a list of dimensions to reduce. Default: reduce
  176. over all dims.
  177. dtype (:class:`torch.dtype`, optional): the desired data type of returned Tensor.
  178. Default: dtype of :attr:`input`.
  179. Example::
  180. >>> nnz = 3
  181. >>> dims = [5, 5, 2, 3]
  182. >>> I = torch.cat([torch.randint(0, dims[0], size=(nnz,)),
  183. torch.randint(0, dims[1], size=(nnz,))], 0).reshape(2, nnz)
  184. >>> V = torch.randn(nnz, dims[2], dims[3])
  185. >>> size = torch.Size(dims)
  186. >>> # xdoctest: +IGNORE_WANT("non-deterministic")
  187. >>> S = torch.sparse_coo_tensor(I, V, size)
  188. >>> S
  189. tensor(indices=tensor([[2, 0, 3],
  190. [2, 4, 1]]),
  191. values=tensor([[[-0.6438, -1.6467, 1.4004],
  192. [ 0.3411, 0.0918, -0.2312]],
  193. [[ 0.5348, 0.0634, -2.0494],
  194. [-0.7125, -1.0646, 2.1844]],
  195. [[ 0.1276, 0.1874, -0.6334],
  196. [-1.9682, -0.5340, 0.7483]]]),
  197. size=(5, 5, 2, 3), nnz=3, layout=torch.sparse_coo)
  198. # when sum over only part of sparse_dims, return a sparse tensor
  199. >>> torch.sparse.sum(S, [1, 3])
  200. tensor(indices=tensor([[0, 2, 3]]),
  201. values=tensor([[-1.4512, 0.4073],
  202. [-0.8901, 0.2017],
  203. [-0.3183, -1.7539]]),
  204. size=(5, 2), nnz=3, layout=torch.sparse_coo)
  205. # when sum over all sparse dim, return a dense tensor
  206. # with summed dims squeezed
  207. >>> torch.sparse.sum(S, [0, 1, 3])
  208. tensor([-2.6596, -1.1450])
  209. """
  210. if dtype is None:
  211. if dim is not None:
  212. return torch._sparse_sum(input, dim)
  213. else:
  214. return torch._sparse_sum(input)
  215. else:
  216. if dim is not None:
  217. return torch._sparse_sum(input, dim, dtype=dtype)
  218. else:
  219. return torch._sparse_sum(input, dtype=dtype)
  220. softmax = _add_docstr(
  221. _sparse._sparse_softmax,
  222. r"""
  223. sparse.softmax(input, dim, *, dtype=None) -> Tensor
  224. Applies a softmax function.
  225. Softmax is defined as:
  226. :math:`\text{Softmax}(x_{i}) = \frac{exp(x_i)}{\sum_j exp(x_j)}`
  227. where :math:`i, j` run over sparse tensor indices and unspecified
  228. entries are ignores. This is equivalent to defining unspecified
  229. entries as negative infinity so that :math:`exp(x_k) = 0` when the
  230. entry with index :math:`k` has not specified.
  231. It is applied to all slices along `dim`, and will re-scale them so
  232. that the elements lie in the range `[0, 1]` and sum to 1.
  233. Args:
  234. input (Tensor): input
  235. dim (int): A dimension along which softmax will be computed.
  236. dtype (:class:`torch.dtype`, optional): the desired data type
  237. of returned tensor. If specified, the input tensor is
  238. casted to :attr:`dtype` before the operation is
  239. performed. This is useful for preventing data type
  240. overflows. Default: None
  241. """,
  242. )
  243. spsolve = _add_docstr(
  244. _sparse._spsolve,
  245. r"""
  246. sparse.spsolve(input, other, *, left=True) -> Tensor
  247. Computes the solution of a square system of linear equations with
  248. a unique solution. Its purpose is similar to :func:`torch.linalg.solve`,
  249. except that the system is defined by a sparse CSR matrix with layout
  250. `sparse_csr`.
  251. Args:
  252. input (Tensor): a sparse CSR matrix of shape `(n, n)` representing the
  253. coefficients of the linear system.
  254. other (Tensor): a dense matrix of shape `(n, )` representing the right-hand
  255. side of the linear system.
  256. left (bool, optional): whether to solve the system for `input @ out = other`
  257. (default) or `out @ input = other`. Only `left=True` is supported.
  258. """,
  259. )
  260. log_softmax = _add_docstr(
  261. _sparse._sparse_log_softmax,
  262. r"""
  263. sparse.log_softmax(input, dim, *, dtype=None) -> Tensor
  264. Applies a softmax function followed by logarithm.
  265. See :class:`~torch.sparse.softmax` for more details.
  266. Args:
  267. input (Tensor): input
  268. dim (int): A dimension along which softmax will be computed.
  269. dtype (:class:`torch.dtype`, optional): the desired data type
  270. of returned tensor. If specified, the input tensor is
  271. casted to :attr:`dtype` before the operation is
  272. performed. This is useful for preventing data type
  273. overflows. Default: None
  274. """,
  275. )
  276. spdiags = _add_docstr(
  277. _sparse._spdiags,
  278. r"""
  279. sparse.spdiags(diagonals, offsets, shape, layout=None) -> Tensor
  280. Creates a sparse 2D tensor by placing the values from rows of
  281. :attr:`diagonals` along specified diagonals of the output
  282. The :attr:`offsets` tensor controls which diagonals are set.
  283. - If :attr:`offsets[i]` = 0, it is the main diagonal
  284. - If :attr:`offsets[i]` < 0, it is below the main diagonal
  285. - If :attr:`offsets[i]` > 0, it is above the main diagonal
  286. The number of rows in :attr:`diagonals` must match the length of :attr:`offsets`,
  287. and an offset may not be repeated.
  288. Args:
  289. diagonals (Tensor): Matrix storing diagonals row-wise
  290. offsets (Tensor): The diagonals to be set, stored as a vector
  291. shape (2-tuple of ints): The desired shape of the result
  292. Keyword args:
  293. layout (:class:`torch.layout`, optional): The desired layout of the
  294. returned tensor. ``torch.sparse_coo``, ``torch.sparse_csc`` and ``torch.sparse_csr``
  295. are supported. Default: ``torch.sparse_coo``
  296. Examples:
  297. Set the main and first two lower diagonals of a matrix::
  298. >>> diags = torch.arange(9).reshape(3, 3)
  299. >>> diags
  300. tensor([[0, 1, 2],
  301. [3, 4, 5],
  302. [6, 7, 8]])
  303. >>> s = torch.sparse.spdiags(diags, torch.tensor([0, -1, -2]), (3, 3))
  304. >>> s
  305. tensor(indices=tensor([[0, 1, 2, 1, 2, 2],
  306. [0, 1, 2, 0, 1, 0]]),
  307. values=tensor([0, 1, 2, 3, 4, 6]),
  308. size=(3, 3), nnz=6, layout=torch.sparse_coo)
  309. >>> s.to_dense()
  310. tensor([[0, 0, 0],
  311. [3, 1, 0],
  312. [6, 4, 2]])
  313. Change the output layout::
  314. >>> diags = torch.arange(9).reshape(3, 3)
  315. >>> diags
  316. tensor([[0, 1, 2],[3, 4, 5], [6, 7, 8])
  317. >>> s = torch.sparse.spdiags(diags, torch.tensor([0, -1, -2]), (3, 3), layout=torch.sparse_csr)
  318. >>> s
  319. tensor(crow_indices=tensor([0, 1, 3, 6]),
  320. col_indices=tensor([0, 0, 1, 0, 1, 2]),
  321. values=tensor([0, 3, 1, 6, 4, 2]), size=(3, 3), nnz=6,
  322. layout=torch.sparse_csr)
  323. >>> s.to_dense()
  324. tensor([[0, 0, 0],
  325. [3, 1, 0],
  326. [6, 4, 2]])
  327. Set partial diagonals of a large output::
  328. >>> diags = torch.tensor([[1, 2], [3, 4]])
  329. >>> offsets = torch.tensor([0, -1])
  330. >>> torch.sparse.spdiags(diags, offsets, (5, 5)).to_dense()
  331. tensor([[1, 0, 0, 0, 0],
  332. [3, 2, 0, 0, 0],
  333. [0, 4, 0, 0, 0],
  334. [0, 0, 0, 0, 0],
  335. [0, 0, 0, 0, 0]])
  336. .. note::
  337. When setting the values along a given diagonal the index into the diagonal
  338. and the index into the row of :attr:`diagonals` is taken as the
  339. column index in the output. This has the effect that when setting a diagonal
  340. with a positive offset `k` the first value along that diagonal will be
  341. the value in position `k` of the row of :attr:`diagonals`
  342. Specifying a positive offset::
  343. >>> diags = torch.tensor([[1, 2, 3], [1, 2, 3], [1, 2, 3]])
  344. >>> torch.sparse.spdiags(diags, torch.tensor([0, 1, 2]), (5, 5)).to_dense()
  345. tensor([[1, 2, 3, 0, 0],
  346. [0, 2, 3, 0, 0],
  347. [0, 0, 3, 0, 0],
  348. [0, 0, 0, 0, 0],
  349. [0, 0, 0, 0, 0]])
  350. """,
  351. )
  352. class check_sparse_tensor_invariants:
  353. """A tool to control checking sparse tensor invariants.
  354. The following options exists to manage sparsr tensor invariants
  355. checking in sparse tensor construction:
  356. 1. Using a context manager:
  357. .. code:: python
  358. with torch.sparse.check_sparse_tensor_invariants():
  359. run_my_model()
  360. 2. Using a procedural approach:
  361. .. code:: python
  362. prev_checks_enabled = torch.sparse.check_sparse_tensor_invariants.is_enabled()
  363. torch.sparse.check_sparse_tensor_invariants.enable()
  364. run_my_model()
  365. if not prev_checks_enabled:
  366. torch.sparse.check_sparse_tensor_invariants.disable()
  367. 3. Using function decoration:
  368. .. code:: python
  369. @torch.sparse.check_sparse_tensor_invariants()
  370. def run_my_model():
  371. ...
  372. run_my_model()
  373. 4. Using ``check_invariants`` keyword argument in sparse tensor constructor call.
  374. For example:
  375. >>> torch.sparse_csr_tensor([0, 1, 3], [0, 1], [1, 2], check_invariants=True)
  376. Traceback (most recent call last):
  377. File "<stdin>", line 1, in <module>
  378. RuntimeError: `crow_indices[..., -1] == nnz` is not satisfied.
  379. """
  380. @staticmethod
  381. def is_enabled():
  382. r"""Return True if the sparse tensor invariants checking is enabled.
  383. .. note::
  384. Use :func:`torch.sparse.check_sparse_tensor_invariants.enable` or
  385. :func:`torch.sparse.check_sparse_tensor_invariants.disable` to
  386. manage the state of the sparse tensor invariants checks.
  387. """
  388. return torch._C._check_sparse_tensor_invariants()
  389. @staticmethod
  390. def enable():
  391. r"""Enable sparse tensor invariants checking in sparse tensor constructors.
  392. .. note::
  393. By default, the sparse tensor invariants checks are disabled. Use
  394. :func:`torch.sparse.check_sparse_tensor_invariants.is_enabled` to
  395. retrieve the current state of sparse tensor invariants checking.
  396. .. note::
  397. The sparse tensor invariants check flag is effective to all sparse
  398. tensor constructors, both in Python and ATen.
  399. The flag can be locally overridden by the ``check_invariants``
  400. optional argument of the sparse tensor constructor functions.
  401. """
  402. torch._C._set_check_sparse_tensor_invariants(True)
  403. @staticmethod
  404. def disable():
  405. r"""Disable sparse tensor invariants checking in sparse tensor constructors.
  406. See :func:`torch.sparse.check_sparse_tensor_invariants.enable` for more information.
  407. """
  408. torch._C._set_check_sparse_tensor_invariants(False)
  409. # context manager support
  410. def __init__(self, enable=True):
  411. self.state = enable
  412. self.saved_state: Optional[bool] = None
  413. def __enter__(self):
  414. if self.saved_state is not None:
  415. raise RuntimeError(
  416. "This context manager instance is already activated."
  417. " Use a different context manager instance for context nesting."
  418. )
  419. self.saved_state = self.is_enabled()
  420. torch._C._set_check_sparse_tensor_invariants(self.state)
  421. def __exit__(self, type, value, traceback):
  422. assert self.saved_state is not None
  423. torch._C._set_check_sparse_tensor_invariants(self.saved_state)
  424. self.saved_state = None
  425. # decorator support
  426. def __call__(self, mth):
  427. def test_mth(*args, **kwargs):
  428. with type(self)(self.state):
  429. return mth(*args, **kwargs)
  430. return test_mth
  431. def as_sparse_gradcheck(gradcheck):
  432. """Decorate function, to extend gradcheck for sparse tensors.
  433. Decorator for torch.autograd.gradcheck or its functools.partial
  434. variants that extends the gradcheck function with support to input
  435. functions that operate on or/and return sparse tensors.
  436. The specified gradcheck function itself is guaranteed to operate
  437. on strided tensors only.
  438. For example:
  439. >>> gradcheck = torch.sparse.as_sparse_gradcheck(torch.autograd.gradcheck)
  440. >>> x = (
  441. ... torch.tensor([[0, 1], [2, 3]], dtype=torch.float64)
  442. ... .to_sparse_coo()
  443. ... .requires_grad_(True)
  444. ... )
  445. >>> gradcheck(lambda x: x.to_sparse_csr(), x)
  446. True
  447. """
  448. def gradcheck_with_sparse_support(func, inputs, **kwargs):
  449. """
  450. Create gradcheck with support for sparse tensors.
  451. Same as :func:`torch.autograd.gradcheck` but with sparse tensors inputs and outputs support.
  452. """
  453. masked = kwargs.pop("masked", False)
  454. sparse_layouts = {
  455. torch.sparse_coo,
  456. torch.sparse_csr,
  457. torch.sparse_csc,
  458. torch.sparse_bsr,
  459. torch.sparse_bsc,
  460. }
  461. sparse_compressed_layouts = {
  462. torch.sparse_csr,
  463. torch.sparse_csc,
  464. torch.sparse_bsr,
  465. torch.sparse_bsc,
  466. }
  467. sparse_block_layouts = {torch.sparse_bsr, torch.sparse_bsc}
  468. STRIDED_REPRESENTATION = "__STRIDED_REPRESENTATION__"
  469. def convert_to_strided_representation(args):
  470. """Convert differentiable non-strided tensors to a representation containing differentiable strided tensors."""
  471. if not isinstance(args, (list, tuple)):
  472. args = (args,)
  473. new_args: list[Any] = []
  474. for obj in args:
  475. if (
  476. isinstance(obj, torch.Tensor)
  477. and obj.requires_grad
  478. and obj.layout in sparse_layouts
  479. ):
  480. d = {
  481. "layout": obj.layout,
  482. "shape": obj.shape,
  483. }
  484. if not masked:
  485. # Materialize unspecified elements with zero values
  486. batch_dim = obj.ndim - obj.dense_dim() - obj.sparse_dim()
  487. blocksize = (
  488. obj.values().shape[batch_dim + 1 : batch_dim + 3]
  489. if obj.layout in sparse_block_layouts
  490. else None
  491. )
  492. full_mask = torch.ones(
  493. obj.shape, device=obj.device, dtype=torch.bool
  494. ).to_sparse(
  495. layout=obj.layout,
  496. blocksize=blocksize,
  497. dense_dim=obj.dense_dim(),
  498. )
  499. obj = obj.to_dense().sparse_mask(full_mask)
  500. if obj.layout is torch.sparse_coo:
  501. d.update(
  502. indices=obj._indices(), is_coalesced=obj.is_coalesced()
  503. )
  504. values = obj._values()
  505. elif obj.layout in {torch.sparse_csr, torch.sparse_bsr}:
  506. d.update(
  507. compressed_indices=obj.crow_indices(),
  508. plain_indices=obj.col_indices(),
  509. )
  510. values = obj.values()
  511. else:
  512. d.update(
  513. compressed_indices=obj.ccol_indices(),
  514. plain_indices=obj.row_indices(),
  515. )
  516. values = obj.values()
  517. new_args.extend(
  518. (STRIDED_REPRESENTATION, d, values.requires_grad_(True))
  519. )
  520. else:
  521. new_args.append(obj)
  522. return tuple(new_args)
  523. def restore_from_strided_representation(args):
  524. """Restore non-strided differentiable tensosr from their strided representations."""
  525. new_args = []
  526. args = list(args)
  527. while args:
  528. a = args.pop(0)
  529. if a == STRIDED_REPRESENTATION:
  530. d, values = args.pop(0), args.pop(0)
  531. if d["layout"] is torch.sparse_coo:
  532. a = torch.sparse_coo_tensor(
  533. d["indices"],
  534. values,
  535. size=d["shape"],
  536. is_coalesced=d["is_coalesced"],
  537. )
  538. elif d["layout"] in sparse_compressed_layouts:
  539. a = torch.sparse_compressed_tensor(
  540. d["compressed_indices"],
  541. d["plain_indices"],
  542. values,
  543. size=d["shape"],
  544. layout=d["layout"],
  545. )
  546. else:
  547. raise NotImplementedError(
  548. f"conversion of {d['layout']} strided representation to tensor"
  549. )
  550. new_args.append(a)
  551. return tuple(new_args)
  552. def func_wrapper(*args, **kwargs):
  553. restored_args = restore_from_strided_representation(args)
  554. # convert differentiable output sparse tensors to strided
  555. # tensors:
  556. outputs = func(*restored_args, **kwargs)
  557. strided_outputs = (
  558. tuple(outputs) if isinstance(outputs, (list, tuple)) else (outputs,)
  559. )
  560. strided_outputs = tuple(
  561. (
  562. o.to_dense(masked_grad=masked)
  563. if isinstance(o, torch.Tensor)
  564. and o.requires_grad
  565. and o.layout in sparse_layouts
  566. else o
  567. )
  568. for o in strided_outputs
  569. )
  570. return (
  571. strided_outputs
  572. if isinstance(outputs, (list, tuple))
  573. else strided_outputs[0]
  574. )
  575. args = (func_wrapper, convert_to_strided_representation(inputs))
  576. return gradcheck(*args, **kwargs)
  577. return gradcheck_with_sparse_support