semi_structured.py 28 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687
  1. # mypy: allow-untyped-defs
  2. import warnings
  3. from collections import namedtuple
  4. from collections.abc import Callable
  5. from typing import Any
  6. import torch
  7. from torch.sparse._semi_structured_conversions import (
  8. sparse_semi_structured_from_dense_cutlass,
  9. sparse_semi_structured_to_dense_cutlass,
  10. )
  11. from torch.sparse._semi_structured_ops import (
  12. fallback_dispatcher,
  13. semi_sparse_addmm,
  14. semi_sparse_detach,
  15. semi_sparse_indices,
  16. semi_sparse_linear,
  17. semi_sparse_mm,
  18. semi_sparse_scaled_mm,
  19. semi_sparse_t,
  20. semi_sparse_values,
  21. semi_sparse_view,
  22. )
  23. __all__ = [
  24. "SparseSemiStructuredTensor",
  25. "SparseSemiStructuredTensorCUTLASS",
  26. "SparseSemiStructuredTensorCUSPARSELT",
  27. "to_sparse_semi_structured",
  28. ]
  # Minimum-shape record used by the per-backend `_DTYPE_SHAPE_CONSTRAINTS` tables:
  # the `sparse_*` fields bound the tensor being compressed, the `dense_*` fields
  # bound the dense operand that gets padded before a matmul.
  _SEMI_STRUCTURED_SPARSE_CONFIG = namedtuple(
      typename="_SEMI_STRUCTURED_SPARSE_CONFIG",
      field_names="sparse_min_rows sparse_min_cols dense_min_rows dense_min_cols",
  )
  33. class SparseSemiStructuredTensor(torch.Tensor):
  34. """
  35. This class implements semi-structured sparsity as a Tensor subclass.
  36. Semi-structured sparsity describes a sparsity pattern where n in every 2n elements are sparse,
  37. depending on the datatype. It is also referred to as 2:4 sparsity or fine-grained
  38. structured sparsity.
  39. There are two backends available for semi_structred sparsity, either cuSPARSELt or CUTLASS.
  40. This class is meant to serve as a base class for both implementations. SparseSemiStructuredCUTLASS
  41. and SparseSemiStructuredCUSPARSELT both inherit from this class and define three backend-specific items.
  42. Note that as such, this class cannot be instantiated directly.
  43. -`_DTYPE_SHAPE_CONSTRAINTS` - A dictionary holding backend specific dense/sparse min shape constraints
  44. - `def from_dense()` - backend specific compression routines
  45. - `def _mm()` - backend specific mm op (either torch._cslt_sparse_mm or torch._sparse_semi_structured_(mm|addmm))
  46. """
  47. _DEFAULT_ALG_ID: int = 0
  48. _DTYPE_SHAPE_CONSTRAINTS: dict[torch.dtype, _SEMI_STRUCTURED_SPARSE_CONFIG]
  49. _FORCE_CUTLASS: bool = False
  50. _FUSE_TRANSPOSE: bool = False
  51. _PROTOTYPE_WARNING_SHOWN: bool = False
  52. BACKEND: str
  53. SPARSE_DISPATCH: dict[Callable, Callable]
  54. packed: torch.Tensor | None
  55. meta: torch.Tensor | None
  56. packed_t: torch.Tensor | None
  57. meta_t: torch.Tensor | None
  58. compressed_swizzled_bitmask: torch.Tensor | None
  59. fuse_transpose_cusparselt: bool
  60. alg_id_cusparselt: int
  61. __slots__ = ["packed", "meta", "packed_t", "meta_t", "compressed_swizzled_bitmask"]
  62. @staticmethod
  63. def __new__( # noqa: PYI034
  64. cls,
  65. shape: torch.Size,
  66. packed: torch.Tensor | None,
  67. meta: torch.Tensor | None,
  68. packed_t: torch.Tensor | None,
  69. meta_t: torch.Tensor | None,
  70. compressed_swizzled_bitmask: torch.Tensor | None,
  71. fuse_transpose_cusparselt: bool = False,
  72. alg_id_cusparselt: int = 0,
  73. requires_grad: bool = False,
  74. ):
  75. """
  76. Create a new instance of the tensor subclass from the compressed sparse representation.
  77. We have the option to create the subclass with the compressed representations of both X and X', for training.
  78. For inference, we only need a single representation (either X or X'), while the corresponding other set will be None.
  79. Depending on the backend selected, certain fields will be set to None. (CUSPARSELT vs CUTLASS)
  80. Args:
  81. shape: The shape of the original dense tensor
  82. packed: The compressed representation of the original dense tensor
  83. meta: The metadata of the original dense tensor, if it is stored separately
  84. packed_t: The compressed representation of the transposed original dense tensor
  85. meta_t: The metadata of the transposed original dense tensor, if it is stored separately
  86. compressed_swizzled_bitmask: The masks used by the CUTLASS backend to determine which threads should
  87. participate in the computation. Used for pointwise ops.
  88. fuse_transpose_cusparselt: When running with cuSPARSELt, we have the option to fuse a transposition
  89. with a matmul, which is useful in the case of 2:4 sparse training.
  90. alg_id_cusparselt: The algorithm id to use when using cuSPARSELT, will have effect on performance
  91. Returns:
  92. torch.Tensor: A torch.Tensor wrapper subclass.
  93. Raises:
  94. ValueError: If all of the tensor arguments are None.
  95. """
  96. if not cls._PROTOTYPE_WARNING_SHOWN:
  97. warnings.warn(
  98. (
  99. "The PyTorch API of SparseSemiStructuredTensor is in prototype stage "
  100. "and will change in the near future. Please open a Github issue "
  101. "for features requests and see our documentation on the torch.sparse "
  102. "module for further information about the project."
  103. ),
  104. UserWarning,
  105. stacklevel=2,
  106. )
  107. cls._PROTOTYPE_WARNING_SHOWN = True
  108. # Because this only runs once, we also load the dispatch table here as well.
  109. # We can't define the dispatch table explicitly because of torch.ops import errors, so we do this instead
  110. # But this is useful since it allows users to overload the dispatch table for debugging / testing.
  111. cls._load_dispatch_table()
  112. # we can also register the classes with dynamo when the warning is shown.
  113. torch._dynamo.allow_in_graph(cls)
  114. if packed is not None:
  115. previous_tensor = packed
  116. elif packed_t is not None:
  117. previous_tensor = packed_t
  118. else:
  119. raise ValueError("At least one of packed or packed_t must be provided")
  120. tensor = torch.Tensor._make_wrapper_subclass(
  121. cls,
  122. shape,
  123. device=previous_tensor.device,
  124. dtype=previous_tensor.dtype,
  125. layout=previous_tensor.layout,
  126. requires_grad=requires_grad,
  127. )
  128. tensor.packed = packed
  129. tensor.meta = meta
  130. tensor.packed_t = packed_t
  131. tensor.meta_t = meta_t
  132. tensor.compressed_swizzled_bitmask = compressed_swizzled_bitmask
  133. tensor.fuse_transpose_cusparselt = fuse_transpose_cusparselt
  134. tensor.alg_id_cusparselt = alg_id_cusparselt
  135. return tensor
  136. def __repr__(self) -> str: # type: ignore[override]
  137. assert hasattr(self, "shape")
  138. return f"{self.__class__.__name__}(shape={self.shape})"
  139. def __tensor_flatten__(
  140. self,
  141. ) -> tuple[list[str], tuple[torch.Size, bool, int, bool]]:
  142. inner_tensors = list(
  143. filter(lambda x: getattr(self, x) is not None, self.__slots__)
  144. )
  145. tensor_meta = (
  146. self.shape,
  147. self.fuse_transpose_cusparselt,
  148. self.alg_id_cusparselt,
  149. self.requires_grad,
  150. )
  151. return inner_tensors, tensor_meta
  152. @classmethod
  153. def __tensor_unflatten__(
  154. cls,
  155. inner_tensors,
  156. tensor_meta: tuple[torch.Size, bool, int, bool],
  157. outer_size,
  158. outer_stride,
  159. ) -> torch.Tensor:
  160. shape, fuse_transpose_cusparselt, alg_id_cusparselt, requires_grad = tensor_meta
  161. # pyrefly: ignore [no-matching-overload]
  162. return cls(
  163. shape=shape,
  164. packed=inner_tensors.get("packed", None),
  165. meta=inner_tensors.get("meta", None),
  166. packed_t=inner_tensors.get("packed_t", None),
  167. meta_t=inner_tensors.get("meta_t", None),
  168. compressed_swizzled_bitmask=inner_tensors.get(
  169. "compressed_swizzled_bitmask", None
  170. ),
  171. fuse_transpose_cusparselt=fuse_transpose_cusparselt,
  172. alg_id_cusparselt=alg_id_cusparselt,
  173. requires_grad=requires_grad,
  174. )
  175. __torch_function__ = torch._C._disabled_torch_function_impl # type: ignore[assignment]
  176. @classmethod
  177. def __torch_dispatch__(cls, func, types, args, kwargs) -> Any: # type: ignore[override]
  178. if func._overloadpacket not in cls.SPARSE_DISPATCH:
  179. raise NotImplementedError(
  180. f"{cls.__name__} only supports a specific set of operations, "
  181. f"can't perform requested op ({func.__name__})"
  182. )
  183. return cls.SPARSE_DISPATCH[func._overloadpacket](func, types, args, kwargs)
  184. @classmethod
  185. def _load_dispatch_table(cls, custom_dispatch_table=None) -> None:
  186. """
  187. Loads the op overload sparse dispatch table for the current class.
  188. """
  189. if getattr(cls, "SPARSE_DISPATCH", None) is None:
  190. cls.SPARSE_DISPATCH = {
  191. torch.ops.aten.values: semi_sparse_values,
  192. torch.ops.aten.indices: semi_sparse_indices,
  193. torch.ops.aten.is_same_size: fallback_dispatcher,
  194. torch.ops.aten.detach_: fallback_dispatcher,
  195. torch.ops.aten.detach: semi_sparse_detach,
  196. torch.ops.aten.t: semi_sparse_t,
  197. torch.ops.aten.view: semi_sparse_view,
  198. torch.ops.aten.mm: semi_sparse_mm,
  199. torch.ops.aten.matmul: semi_sparse_mm,
  200. torch.ops.aten.addmm: semi_sparse_addmm,
  201. torch.ops.aten.linear: semi_sparse_linear,
  202. torch.ops.aten._to_copy: fallback_dispatcher,
  203. torch.ops.aten._scaled_mm: semi_sparse_scaled_mm,
  204. }
  205. if custom_dispatch_table is not None:
  206. cls.SPARSE_DISPATCH.update(custom_dispatch_table)
  207. @classmethod
  208. def _validate_device_dim_dtype_shape(cls, original_tensor: torch.Tensor) -> None:
  209. """
  210. Assert that the given tensor is valid for semi-structured sparse compression.
  211. """
  212. # check device
  213. if not original_tensor.is_cuda:
  214. raise RuntimeError(
  215. f"Error original_tensor.device= {original_tensor.device} is not supported! "
  216. "Only CUDA tensors are currently supported."
  217. )
  218. # check dim
  219. if original_tensor.dim() != 2:
  220. raise RuntimeError(
  221. f"Error original_tensor.dim = {original_tensor.dim()} is not supported! "
  222. "Only 2d tensors are currently supported."
  223. )
  224. # check contiguous
  225. if not original_tensor.is_contiguous():
  226. raise RuntimeError(
  227. "Error original_tensor is not contiguous!"
  228. "Only contiguous tensors are currently supported."
  229. )
  230. # check dtype
  231. if original_tensor.dtype not in cls._DTYPE_SHAPE_CONSTRAINTS:
  232. raise RuntimeError(
  233. f"Error original_tensor.dtype {original_tensor.dtype} is not a supported dtype for {cls}!"
  234. )
  235. # check shape
  236. m, n = original_tensor.shape
  237. min_rows = cls._DTYPE_SHAPE_CONSTRAINTS[original_tensor.dtype].sparse_min_rows
  238. min_cols = cls._DTYPE_SHAPE_CONSTRAINTS[original_tensor.dtype].sparse_min_cols
  239. if m < min_rows or m % min_rows or n < min_cols or n % min_cols:
  240. # TODO in the future we can add in padding to support sparse dimensions that aren't perfect multiples
  241. raise RuntimeError(
  242. f"Error original_tensor.shape {original_tensor.shape} is not supported! "
  243. f"Both dimensions must be larger or equal than and a multiple of ({min_rows}, {min_cols})"
  244. )
  245. @classmethod
  246. def _pad_dense_input(cls, dense_input: torch.Tensor) -> torch.Tensor:
  247. """
  248. Calculates padding for dense tensor and pads tensor if necessary.
  249. If padding is not required, this function returns the original tensor.
  250. """
  251. # only 2d matmul
  252. assert dense_input.dim() == 2
  253. # check shape
  254. m, n = dense_input.shape
  255. min_rows = cls._DTYPE_SHAPE_CONSTRAINTS[dense_input.dtype].dense_min_rows
  256. min_cols = cls._DTYPE_SHAPE_CONSTRAINTS[dense_input.dtype].dense_min_cols
  257. # calculate padding
  258. to_pad_m = -m % min_rows if m < min_rows or m % min_rows else 0
  259. to_pad_n = -n % min_cols if n < min_cols or n % min_rows else 0
  260. if to_pad_m or to_pad_n:
  261. return torch.nn.functional.pad(dense_input, (0, to_pad_n, 0, to_pad_m))
  262. else:
  263. return dense_input
  264. def to_dense(self): # type:ignore[override]
  265. col = self.shape[-1]
  266. return torch.mm(self, torch.eye(col, dtype=self.dtype, device=self.device))
  267. @classmethod
  268. def from_dense(cls, original_tensor: torch.Tensor) -> "SparseSemiStructuredTensor":
  269. raise NotImplementedError
  270. def _mm(
  271. self,
  272. B: torch.Tensor,
  273. *,
  274. bias: torch.Tensor | None = None,
  275. **kwargs,
  276. ) -> torch.Tensor:
  277. raise NotImplementedError
  278. def to_sparse_semi_structured(
  279. original_tensor: torch.Tensor,
  280. transposed: bool = False,
  281. ) -> SparseSemiStructuredTensor:
  282. """
  283. This function converts a dense tensor into a sparse semi-structured tensor.
  284. It will return a SparseSemiStructuredTensor, a subclass of torch.Tensor.
  285. This function will check to ensure the dense tensor has the right dtype, size, dims, and device.
  286. We currently only support semi-structured sparse tensors for 2d CUDA tensors.
  287. Additionally, your tensor must be a positive multiple of the minimum sparse block size, given in
  288. `_DTYPE_TO_SHAPE_CONSTRAINTS` for each dtype (float32, float16, bfloat16, int8).
  289. Args:
  290. original_tensor (Tensor): the dense tensor to convert
  291. transposed (bool, optional): deprecated arg to be removed in another release. Do not use.
  292. Returns:
  293. SparseSemiStructuredTensor: A sparse semi-structured tensor created from the given original_tensor
  294. Raises:
  295. None
  296. Example:
  297. >>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_CUDA)
  298. >>> A = torch.Tensor([0, 0, 1, 1]).tile((128, 32)).half().cuda()
  299. tensor([[0., 0., 1., ..., 0., 1., 1.],
  300. [0., 0., 1., ..., 0., 1., 1.],
  301. [0., 0., 1., ..., 0., 1., 1.],
  302. ...,
  303. [0., 0., 1., ..., 0., 1., 1.],
  304. [0., 0., 1., ..., 0., 1., 1.],
  305. [0., 0., 1., ..., 0., 1., 1.]], device='cuda:0', dtype=torch.float16)
  306. >>> A_sparse = to_sparse_semi_structured(A)
  307. SparseSemiStructuredTensor(shape=torch.Size([128, 128]))
  308. >>> A_sparse.values()
  309. tensor([[1., 1., 1., ..., 1., 1., 1.],
  310. [1., 1., 1., ..., 1., 1., 1.],
  311. [1., 1., 1., ..., 1., 1., 1.],
  312. ...,
  313. [1., 1., 1., ..., 1., 1., 1.],
  314. [1., 1., 1., ..., 1., 1., 1.],
  315. [1., 1., 1., ..., 1., 1., 1.]], device='cuda:0', dtype=torch.float16),
  316. >>> A_sparse.indices()
  317. tensor([[-4370, -4370, -4370, ..., -4370, -4370, -4370],
  318. [-4370, -4370, -4370, ..., -4370, -4370, -4370],
  319. [-4370, -4370, -4370, ..., -4370, -4370, -4370],
  320. ...,
  321. [-4370, -4370, -4370, ..., -4370, -4370, -4370],
  322. [-4370, -4370, -4370, ..., -4370, -4370, -4370],
  323. [-4370, -4370, -4370, ..., -4370, -4370, -4370]], device='cuda:0', dtype=torch.int16))
  324. """
  325. if transposed:
  326. warnings.warn(
  327. "Setting transpose from `to_sparse_semi_structured` is deprecated "
  328. "and will be removed in a future release. "
  329. "`SparseSemiStructuredTensor` only support contiguous input tensors.",
  330. FutureWarning,
  331. stacklevel=2,
  332. )
  333. # set from _FORCE_CUTLASS flag
  334. SPARSE_SUBCLASS = (
  335. torch.sparse.SparseSemiStructuredTensorCUTLASS
  336. if SparseSemiStructuredTensor._FORCE_CUTLASS
  337. else torch.sparse.SparseSemiStructuredTensorCUSPARSELT
  338. )
  339. return SPARSE_SUBCLASS.from_dense(original_tensor)
  340. class SparseSemiStructuredTensorCUTLASS(SparseSemiStructuredTensor):
  341. """
  342. This class implements semi-structured sparsity for the CUTLASS backend.
  343. In this implementation, the specified elements and metadata are stored separately,
  344. in packed and meta respectively.
  345. When _FORCE_CUTLASS is set, or when cuSPARSELt is not available, this subclass calls into _sparse_semi_structured_(mm|addmm) and
  346. sparse_semi_structured_from_dense for conversion to the compressed format.
  347. """
  348. BACKEND = "cutlass"
  349. _DTYPE_SHAPE_CONSTRAINTS = {
  350. torch.int8: _SEMI_STRUCTURED_SPARSE_CONFIG(16, 128, 16, 16),
  351. torch.float16: _SEMI_STRUCTURED_SPARSE_CONFIG(32, 64, 8, 8),
  352. torch.bfloat16: _SEMI_STRUCTURED_SPARSE_CONFIG(32, 64, 8, 8),
  353. torch.float32: _SEMI_STRUCTURED_SPARSE_CONFIG(32, 32, 4, 4),
  354. }
  355. @classmethod
  356. def from_dense(
  357. cls, original_tensor: torch.Tensor
  358. ) -> "SparseSemiStructuredTensorCUTLASS":
  359. cls._validate_device_dim_dtype_shape(original_tensor)
  360. (
  361. sparse_tensor_cutlass,
  362. meta_tensor_cutlass,
  363. ) = sparse_semi_structured_from_dense_cutlass(original_tensor)
  364. # pyrefly: ignore [no-matching-overload]
  365. return cls(
  366. original_tensor.shape,
  367. packed=sparse_tensor_cutlass,
  368. meta=meta_tensor_cutlass,
  369. packed_t=None,
  370. meta_t=None,
  371. compressed_swizzled_bitmask=None,
  372. requires_grad=original_tensor.requires_grad,
  373. )
  374. def to_dense(self): # type: ignore[override]
  375. assert self.meta is not None and self.packed is not None
  376. return (
  377. sparse_semi_structured_to_dense_cutlass(
  378. self.packed,
  379. self.meta,
  380. )
  381. if self.meta.ndim == 2
  382. else super().to_dense()
  383. )
  384. @classmethod
  385. def prune_dense_static_sort(
  386. cls, original_tensor: torch.Tensor, algorithm=""
  387. ) -> "SparseSemiStructuredTensor":
  388. """
  389. This function takes in a unpruned dense tensor and runs a (branchless) static sort across a 4x4 tile.
  390. It greedily picks the largest values in the tile, upholding the 2:4 sparsity constraint across both rows and columns.
  391. The algorithm used to prune the matrix is implemented in `_sparse_semi_structured_tile`.
  392. Then it creates the packed and meta tensors for the compressed sparse representation of the pruned dense tensor.
  393. It also calculates the packed_t and meta_t tensors for the compressed sparse representation of the transposed
  394. pruned dense tensor.
  395. Since we cannot transpose the compressed representations, we store both for the fw/bw pass respectively.
  396. Finally, this function also computes a compressed swizzled bitmask that encodes the sparsity pattern
  397. This can be used in the backward pass to mask the gradients.
  398. [9 1 7 4] [9 0 7 0]
  399. [1 2 3 0] [0 2 0 0]
  400. [8 3 5 4] -> prune 4x4 tile -> [8 0 0 4] -> pack to CUTLASS semi-structured -> packed
  401. [1 2 6 2] [0 0 6 2] -> metadata
  402. -> pack to transposed CUTLASS -> packed_t
  403. semi-structured representation -> metadata_t
  404. -> compute swizzled bitmask -> compressed_swizzled_bitmask
  405. The equivalent PyTorch code to create the same five outputs from the dense tensor can be found below:
  406. ```
  407. from torch.sparse import SparseSemiStructuredTensorCUTLASS
  408. from torch.sparse._semi_structured_conversions import (
  409. _sparse_semi_structured_tile,
  410. _compute_compressed_swizzled_bitmask,
  411. )
  412. pruned = _sparse_semi_structured_tile(dense)
  413. packed_cutlass, meta_cutlass = sparse_semi_structured_from_dense_cutlass(pruned)
  414. packed_t_cutlass, meta_t_cutlass = sparse_semi_structured_from_dense_cutlass(
  415. pruned.t().contiguous()
  416. )
  417. bitmask = _compute_compressed_swizzled_bitmask(pruned)
  418. SparseSemiStructuredTensorCUTLASS(
  419. dense.shape,
  420. packed_cutlass,
  421. meta_cutlass,
  422. packed_t_cutlass,
  423. meta_t_cutlass,
  424. bitmask,
  425. )
  426. ```
  427. """
  428. # We can either pack to the CUTLASS or cuSPARSELt representation, depending on the use_cutlass flag.
  429. (
  430. packed,
  431. meta,
  432. packed_t,
  433. meta_t,
  434. compressed_swizzled_bitmask,
  435. ) = torch._sparse_semi_structured_tile(
  436. original_tensor, algorithm=algorithm, use_cutlass=True
  437. )
  438. # pyrefly: ignore [no-matching-overload]
  439. return cls(
  440. original_tensor.shape,
  441. packed=packed,
  442. meta=meta,
  443. packed_t=packed_t,
  444. meta_t=meta_t,
  445. compressed_swizzled_bitmask=compressed_swizzled_bitmask,
  446. requires_grad=False,
  447. )
  448. def _mm(
  449. self, B: torch.Tensor, *, bias: torch.Tensor | None = None, **kwargs
  450. ) -> torch.Tensor:
  451. if isinstance(B, SparseSemiStructuredTensor):
  452. raise ValueError(
  453. "`SparseSemiStructuredTensor @ SparseSemiStructuredTensor` is not supported by the hardware"
  454. )
  455. cls_name = self.__class__.__name__
  456. if self.ndim != 2 or B.ndim != 2:
  457. raise NotImplementedError(
  458. f"`{cls_name}` matmul: Broadcasting is not implemented"
  459. )
  460. if self.packed is None or self.meta is None:
  461. raise NotImplementedError(
  462. f"`{cls_name}` matmul: operation is not supported"
  463. )
  464. else:
  465. if bias is None:
  466. res = torch._sparse_semi_structured_mm(self.packed, self.meta, B)
  467. else:
  468. res = torch._sparse_semi_structured_addmm(
  469. bias, self.packed, self.meta, B
  470. )
  471. return res[: self.shape[0]]
  472. class SparseSemiStructuredTensorCUSPARSELT(SparseSemiStructuredTensor):
  473. """
  474. The cuSPARSELt backend expects the specified elements and the metadata to be stored in a single tensor:
  475. packed = [ specified elements of original tensor | metadata ]
  476. For an original tensor of size (m, k) we expect the first m * k // 2 elements to be the kept elements
  477. The rest of the tensor is metadata. Since there is only one tensor, we only use the packed and packed_t
  478. attributes respectively.
  479. cuSPARSELt also supports transposition fusion, which is necessary for performant 2:4 sparse training, as well
  480. as specifying alg_id, a config that affects the performance of the matmul depending on matmul sizes.
  481. """
  482. BACKEND = "cusparselt"
  483. _DTYPE_SHAPE_CONSTRAINTS = {
  484. torch.float8_e4m3fn: _SEMI_STRUCTURED_SPARSE_CONFIG(32, 32, 16, 16),
  485. torch.int8: _SEMI_STRUCTURED_SPARSE_CONFIG(32, 32, 16, 16),
  486. torch.float16: _SEMI_STRUCTURED_SPARSE_CONFIG(16, 16, 8, 8),
  487. torch.bfloat16: _SEMI_STRUCTURED_SPARSE_CONFIG(16, 16, 8, 8),
  488. }
  489. @classmethod
  490. def from_dense(
  491. cls, original_tensor: torch.Tensor
  492. ) -> "SparseSemiStructuredTensorCUSPARSELT":
  493. cls._validate_device_dim_dtype_shape(original_tensor)
  494. # pyrefly: ignore [no-matching-overload]
  495. return cls(
  496. shape=original_tensor.shape,
  497. packed=torch._cslt_compress(original_tensor),
  498. meta=None,
  499. packed_t=None,
  500. meta_t=None,
  501. compressed_swizzled_bitmask=None,
  502. fuse_transpose_cusparselt=SparseSemiStructuredTensor._FUSE_TRANSPOSE,
  503. alg_id_cusparselt=SparseSemiStructuredTensor._DEFAULT_ALG_ID,
  504. requires_grad=original_tensor.requires_grad,
  505. )
  506. @classmethod
  507. def prune_dense_static_sort(
  508. cls, original_tensor: torch.Tensor, algorithm=""
  509. ) -> "SparseSemiStructuredTensor":
  510. """
  511. This function does the same thing as described in SparseSemiStructuredCUTLASS, but uses the cuSPARSELt metadata
  512. layout and sparse matmul.
  513. The only functional difference is that cuSPARSELt stores `metadata` and `packed` together into a single tensor.
  514. [9 1 7 4] [9 0 7 0]
  515. [1 2 3 0] [0 2 0 0]
  516. [8 3 5 4] -> prune 4x4 tile -> [8 0 0 4] -> pack to cuSPARSELT semi-structured -> packed
  517. [1 2 6 2] [0 0 6 2]
  518. -> pack to transposed cuSPARSELt -> packed_t
  519. semi-structured representation
  520. -> compute swizzled bitmask -> compressed_swizzled_bitmask
  521. The equivalent PyTorch code to create the same three outputs from the dense tensor can be found below:
  522. ```
  523. from torch.sparse import SparseSemiStructuredTensorCUSPARSELT
  524. from torch.sparse._semi_structured_conversions import (
  525. _sparse_semi_structured_tile,
  526. _compute_compressed_swizzled_bitmask,
  527. )
  528. pruned = _sparse_semi_structured_tile(dense)
  529. packed_cusparselt = torch._cslt_compress(pruned)
  530. packed_t_cusparselt = torch._cslt_compress(pruned.t().contiguous())
  531. bitmask = _compute_compressed_swizzled_bitmask(pruned)
  532. SparseSemiStructuredTensorCUSPARSELT(
  533. dense.shape, packed_cutlass, None, packed_t_cutlass, None, bitmask
  534. )
  535. ```
  536. """
  537. (
  538. packed,
  539. meta,
  540. packed_t,
  541. meta_t,
  542. compressed_swizzled_bitmask,
  543. ) = torch._sparse_semi_structured_tile(
  544. original_tensor, algorithm=algorithm, use_cutlass=False
  545. )
  546. # Map this two 2-dim view of packed data.
  547. # TODO: is this proper cuSPARSELt metadata?
  548. packed = packed.view(original_tensor.shape[0], -1)
  549. packed_t = packed_t.view(original_tensor.shape[1], -1)
  550. # pyrefly: ignore [no-matching-overload]
  551. return cls(
  552. original_tensor.shape,
  553. packed=packed,
  554. meta=meta,
  555. packed_t=packed_t,
  556. meta_t=meta_t,
  557. compressed_swizzled_bitmask=compressed_swizzled_bitmask,
  558. requires_grad=False,
  559. )
  560. def _mm(
  561. self, B: torch.Tensor, *, bias: torch.Tensor | None = None, **kwargs
  562. ) -> torch.Tensor:
  563. if isinstance(B, SparseSemiStructuredTensor):
  564. raise ValueError(
  565. "`SparseSemiStructuredTensor @ SparseSemiStructuredTensor` is not supported by the hardware"
  566. )
  567. if self.ndim != 2 or B.ndim != 2:
  568. raise NotImplementedError(
  569. f"`{self.__class__.__name__}` matmul: Broadcasting is not implemented"
  570. )
  571. if B.dtype != self.dtype:
  572. raise NotImplementedError(
  573. f"`{self.__class__.__name__}` matmul: trying to do `A={tuple(self.shape)} @ B={tuple(B.shape)}`, "
  574. f"with A.dtype={self.dtype} and B.dtype={B.dtype}. "
  575. "This operation is only supported when A and B have the same data type."
  576. )
  577. if bias is not None and bias.dtype != self.dtype:
  578. raise NotImplementedError(
  579. f"`{self.__class__.__name__}` matmul: trying to do `A={tuple(self.shape)} @ B={tuple(B.shape)} + C`, "
  580. f"with A.dtype=B.dtype={self.dtype} and C.dtype={B.dtype}. "
  581. "This operation is only supported when A, B and C have the same data type."
  582. )
  583. # Force fp8 mm to error to be consistent with torch
  584. if self.dtype == torch.float8_e4m3fn:
  585. raise NotImplementedError(
  586. f"`{self.__class__.__name__}` matmul: trying to do `A={tuple(self.shape)} @ B={tuple(B.shape)}`, "
  587. f"with A.dtype=B.dtype={self.dtype}. "
  588. "mm is not supported for float8_e4m3fn, please use `torch._scaled_mm` instead."
  589. )
  590. if self.packed is None:
  591. raise NotImplementedError(
  592. f"`{self.__class__.__name__}` matmul: operation is not supported"
  593. )
  594. else:
  595. res = torch._cslt_sparse_mm(
  596. self.packed,
  597. B,
  598. bias=bias,
  599. transpose_result=self.fuse_transpose_cusparselt,
  600. alg_id=self.alg_id_cusparselt,
  601. )
  602. return res.t() if self.fuse_transpose_cusparselt else res