semi_structured.py 28 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675
  1. # mypy: allow-untyped-defs
  2. import warnings
  3. from collections import namedtuple
  4. from typing import Any, Callable, Optional
  5. import torch
  6. from torch.sparse._semi_structured_conversions import (
  7. sparse_semi_structured_from_dense_cutlass,
  8. sparse_semi_structured_to_dense_cutlass,
  9. )
  10. from torch.sparse._semi_structured_ops import (
  11. fallback_dispatcher,
  12. semi_sparse_addmm,
  13. semi_sparse_detach,
  14. semi_sparse_indices,
  15. semi_sparse_linear,
  16. semi_sparse_mm,
  17. semi_sparse_scaled_mm,
  18. semi_sparse_t,
  19. semi_sparse_values,
  20. semi_sparse_view,
  21. )
  22. __all__ = [
  23. "SparseSemiStructuredTensor",
  24. "SparseSemiStructuredTensorCUTLASS",
  25. "SparseSemiStructuredTensorCUSPARSELT",
  26. "to_sparse_semi_structured",
  27. ]
  28. _SEMI_STRUCTURED_SPARSE_CONFIG = namedtuple(
  29. "_SEMI_STRUCTURED_SPARSE_CONFIG",
  30. "sparse_min_rows sparse_min_cols dense_min_rows dense_min_cols",
  31. )
  32. class SparseSemiStructuredTensor(torch.Tensor):
  33. """
  34. This class implements semi-structured sparsity as a Tensor subclass.
  35. Semi-structured sparsity describes a sparsity pattern where n in every 2n elements are sparse,
  36. depending on the datatype. It is also referred to as 2:4 sparsity or fine-grained
  37. structured sparsity.
  38. There are two backends available for semi_structred sparsity, either cuSPARSELt or CUTLASS.
  39. This class is meant to serve as a base class for both implementations. SparseSemiStructuredCUTLASS
  40. and SparseSemiStructuredCUSPARSELT both inherit from this class and define three backend-specific items.
  41. Note that as such, this class cannot be instantiated directly.
  42. -`_DTYPE_SHAPE_CONSTRAINTS` - A dictionary holding backend specific dense/sparse min shape constraints
  43. - `def from_dense()` - backend specific compression routines
  44. - `def _mm()` - backend specific mm op (either torch._cslt_sparse_mm or torch._sparse_semi_structured_(mm|addmm))
  45. """
  46. _DEFAULT_ALG_ID: int = 0
  47. _DTYPE_SHAPE_CONSTRAINTS: dict[torch.dtype, _SEMI_STRUCTURED_SPARSE_CONFIG]
  48. _FORCE_CUTLASS: bool = False
  49. _FUSE_TRANSPOSE: bool = False
  50. _PROTOTYPE_WARNING_SHOWN: bool = False
  51. BACKEND: str
  52. SPARSE_DISPATCH: dict[Callable, Callable]
  53. packed: Optional[torch.Tensor]
  54. meta: Optional[torch.Tensor]
  55. packed_t: Optional[torch.Tensor]
  56. meta_t: Optional[torch.Tensor]
  57. compressed_swizzled_bitmask: Optional[torch.Tensor]
  58. fuse_transpose_cusparselt: bool
  59. alg_id_cusparselt: int
  60. __slots__ = ["packed", "meta", "packed_t", "meta_t", "compressed_swizzled_bitmask"]
  61. @staticmethod
  62. def __new__( # noqa: PYI034
  63. cls,
  64. shape: torch.Size,
  65. packed: Optional[torch.Tensor],
  66. meta: Optional[torch.Tensor],
  67. packed_t: Optional[torch.Tensor],
  68. meta_t: Optional[torch.Tensor],
  69. compressed_swizzled_bitmask: Optional[torch.Tensor],
  70. fuse_transpose_cusparselt: bool = False,
  71. alg_id_cusparselt: int = 0,
  72. requires_grad: bool = False,
  73. ):
  74. """
  75. Create a new instance of the tensor subclass from the compressed sparse representation.
  76. We have the option to create the subclass with the compressed representations of both X and X', for training.
  77. For inference, we only need a single representation (either X or X'), while the corresponding other set will be None.
  78. Depending on the backend selected, certain fields will be set to None. (CUSPARSELT vs CUTLASS)
  79. Args:
  80. shape: The shape of the original dense tensor
  81. packed: The compressed representation of the original dense tensor
  82. meta: The metadata of the original dense tensor, if it is stored separately
  83. packed_t: The compressed representation of the transposed original dense tensor
  84. meta_t: The metadata of the transposed original dense tensor, if it is stored separately
  85. compressed_swizzled_bitmask: The masks used by the CUTLASS backend to determine which threads should
  86. participate in the computation. Used for pointwise ops.
  87. fuse_transpose_cusparselt: When running with cuSPARSELt, we have the option to fuse a transposition
  88. with a matmul, which is useful in the case of 2:4 sparse training.
  89. alg_id_cusparselt: The algorithm id to use when using cuSPARSELT, will have effect on performance
  90. Returns:
  91. torch.Tensor: A torch.Tensor wrapper subclass.
  92. Raises:
  93. ValueError: If all of the tensor arguments are None.
  94. """
  95. if not cls._PROTOTYPE_WARNING_SHOWN:
  96. warnings.warn(
  97. (
  98. "The PyTorch API of SparseSemiStructuredTensor is in prototype stage "
  99. "and will change in the near future. Please open a Github issue "
  100. "for features requests and see our documentation on the torch.sparse "
  101. "module for further information about the project."
  102. ),
  103. UserWarning,
  104. )
  105. cls._PROTOTYPE_WARNING_SHOWN = True
  106. # Because this only runs once, we also load the dispatch table here as well.
  107. # We can't define the dispatch table explicitly because of torch.ops import errors, so we do this instead
  108. # But this is useful since it allows users to overload the dispatch table for debugging / testing.
  109. cls._load_dispatch_table()
  110. # we can also register the classes with dynamo when the warning is shown.
  111. torch._dynamo.allow_in_graph(cls)
  112. if packed is not None:
  113. previous_tensor = packed
  114. elif packed_t is not None:
  115. previous_tensor = packed_t
  116. else:
  117. raise ValueError("At least one of packed or packed_t must be provided")
  118. tensor = torch.Tensor._make_wrapper_subclass(
  119. cls,
  120. shape,
  121. device=previous_tensor.device,
  122. dtype=previous_tensor.dtype,
  123. layout=previous_tensor.layout,
  124. requires_grad=requires_grad,
  125. )
  126. tensor.packed = packed
  127. tensor.meta = meta
  128. tensor.packed_t = packed_t
  129. tensor.meta_t = meta_t
  130. tensor.compressed_swizzled_bitmask = compressed_swizzled_bitmask
  131. tensor.fuse_transpose_cusparselt = fuse_transpose_cusparselt
  132. tensor.alg_id_cusparselt = alg_id_cusparselt
  133. return tensor
  134. def __repr__(self) -> str: # type: ignore[override]
  135. assert hasattr(self, "shape")
  136. return f"{self.__class__.__name__}(shape={self.shape})"
  137. def __tensor_flatten__(
  138. self,
  139. ) -> tuple[list[str], tuple[torch.Size, bool, int, bool]]:
  140. inner_tensors = list(
  141. filter(lambda x: getattr(self, x) is not None, self.__slots__)
  142. )
  143. tensor_meta = (
  144. self.shape,
  145. self.fuse_transpose_cusparselt,
  146. self.alg_id_cusparselt,
  147. self.requires_grad,
  148. )
  149. return inner_tensors, tensor_meta
  150. @classmethod
  151. def __tensor_unflatten__(
  152. cls,
  153. inner_tensors,
  154. tensor_meta: tuple[torch.Size, bool, int, bool],
  155. outer_size,
  156. outer_stride,
  157. ) -> torch.Tensor:
  158. shape, fuse_transpose_cusparselt, alg_id_cusparselt, requires_grad = tensor_meta
  159. return cls(
  160. shape=shape,
  161. packed=inner_tensors.get("packed", None),
  162. meta=inner_tensors.get("meta", None),
  163. packed_t=inner_tensors.get("packed_t", None),
  164. meta_t=inner_tensors.get("meta_t", None),
  165. compressed_swizzled_bitmask=inner_tensors.get(
  166. "compressed_swizzled_bitmask", None
  167. ),
  168. fuse_transpose_cusparselt=fuse_transpose_cusparselt,
  169. alg_id_cusparselt=alg_id_cusparselt,
  170. requires_grad=requires_grad,
  171. )
  172. __torch_function__ = torch._C._disabled_torch_function_impl # type: ignore[assignment]
  173. @classmethod
  174. def __torch_dispatch__(cls, func, types, args, kwargs) -> Any: # type: ignore[override]
  175. if func._overloadpacket not in cls.SPARSE_DISPATCH:
  176. raise NotImplementedError(
  177. f"{cls.__name__} only supports a specific set of operations, "
  178. f"can't perform requested op ({func.__name__})"
  179. )
  180. return cls.SPARSE_DISPATCH[func._overloadpacket](func, types, args, kwargs)
  181. @classmethod
  182. def _load_dispatch_table(cls, custom_dispatch_table=None) -> None:
  183. """
  184. Loads the op overload sparse dispatch table for the current class.
  185. """
  186. if getattr(cls, "SPARSE_DISPATCH", None) is None:
  187. cls.SPARSE_DISPATCH = {
  188. torch.ops.aten.values: semi_sparse_values,
  189. torch.ops.aten.indices: semi_sparse_indices,
  190. torch.ops.aten.is_same_size: fallback_dispatcher,
  191. torch.ops.aten.detach_: fallback_dispatcher,
  192. torch.ops.aten.detach: semi_sparse_detach,
  193. torch.ops.aten.t: semi_sparse_t,
  194. torch.ops.aten.view: semi_sparse_view,
  195. torch.ops.aten.mm: semi_sparse_mm,
  196. torch.ops.aten.matmul: semi_sparse_mm,
  197. torch.ops.aten.addmm: semi_sparse_addmm,
  198. torch.ops.aten.linear: semi_sparse_linear,
  199. torch.ops.aten._to_copy: fallback_dispatcher,
  200. torch.ops.aten._scaled_mm: semi_sparse_scaled_mm,
  201. }
  202. if custom_dispatch_table is not None:
  203. cls.SPARSE_DISPATCH.update(custom_dispatch_table)
  204. @classmethod
  205. def _validate_device_dim_dtype_shape(cls, original_tensor: torch.Tensor) -> None:
  206. """
  207. Assert that the given tensor is valid for semi-structured sparse compression.
  208. """
  209. # check device
  210. if not original_tensor.is_cuda:
  211. raise RuntimeError(
  212. f"Error original_tensor.device= {original_tensor.device} is not supported! "
  213. "Only CUDA tensors are currently supported."
  214. )
  215. # check dim
  216. if original_tensor.dim() != 2:
  217. raise RuntimeError(
  218. f"Error original_tensor.dim = {original_tensor.dim()} is not supported! "
  219. "Only 2d tensors are currently supported."
  220. )
  221. # check contiguous
  222. if not original_tensor.is_contiguous():
  223. raise RuntimeError(
  224. "Error original_tensor is not contiguous!"
  225. "Only contiguous tensors are currently supported."
  226. )
  227. # check dtype
  228. if original_tensor.dtype not in cls._DTYPE_SHAPE_CONSTRAINTS:
  229. raise RuntimeError(
  230. f"Error original_tensor.dtype {original_tensor.dtype} is not a supported dtype for {cls}!"
  231. )
  232. # check shape
  233. m, n = original_tensor.shape
  234. min_rows = cls._DTYPE_SHAPE_CONSTRAINTS[original_tensor.dtype].sparse_min_rows
  235. min_cols = cls._DTYPE_SHAPE_CONSTRAINTS[original_tensor.dtype].sparse_min_cols
  236. if m < min_rows or m % min_rows or n < min_cols or n % min_cols:
  237. # TODO in the future we can add in padding to support sparse dimensions that aren't perfect multiples
  238. raise RuntimeError(
  239. f"Error original_tensor.shape {original_tensor.shape} is not supported! "
  240. f"Both dimensions must be larger or equal than and a multiple of ({min_rows}, {min_cols})"
  241. )
  242. @classmethod
  243. def _pad_dense_input(cls, dense_input: torch.Tensor) -> torch.Tensor:
  244. """
  245. Calculates padding for dense tensor and pads tensor if necessary.
  246. If padding is not required, this function returns the original tensor.
  247. """
  248. # only 2d matmul
  249. assert dense_input.dim() == 2
  250. # check shape
  251. m, n = dense_input.shape
  252. min_rows = cls._DTYPE_SHAPE_CONSTRAINTS[dense_input.dtype].dense_min_rows
  253. min_cols = cls._DTYPE_SHAPE_CONSTRAINTS[dense_input.dtype].dense_min_cols
  254. # calculate padding
  255. to_pad_m = -m % min_rows if m < min_rows or m % min_rows else 0
  256. to_pad_n = -n % min_cols if n < min_cols or n % min_rows else 0
  257. if to_pad_m or to_pad_n:
  258. return torch.nn.functional.pad(dense_input, (0, to_pad_n, 0, to_pad_m))
  259. else:
  260. return dense_input
  261. def to_dense(self): # type:ignore[override]
  262. col = self.shape[-1]
  263. return torch.mm(self, torch.eye(col, dtype=self.dtype, device=self.device))
  264. @classmethod
  265. def from_dense(cls, original_tensor: torch.Tensor) -> "SparseSemiStructuredTensor":
  266. raise NotImplementedError
  267. def _mm(
  268. self,
  269. B: torch.Tensor,
  270. *,
  271. bias: Optional[torch.Tensor] = None,
  272. **kwargs,
  273. ) -> torch.Tensor:
  274. raise NotImplementedError
  275. def to_sparse_semi_structured(
  276. original_tensor: torch.Tensor,
  277. transposed: bool = False,
  278. ) -> SparseSemiStructuredTensor:
  279. """
  280. This function converts a dense tensor into a sparse semi-structured tensor.
  281. It will return a SparseSemiStructuredTensor, a subclass of torch.Tensor.
  282. This function will check to ensure the dense tensor has the right dtype, size, dims, and device.
  283. We currently only support semi-structured sparse tensors for 2d CUDA tensors.
  284. Additionally, your tensor must be a positive multiple of the minimum sparse block size, given in
  285. `_DTYPE_TO_SHAPE_CONSTRAINTS` for each dtype (float32, float16, bfloat16, int8).
  286. Args:
  287. original_tensor (Tensor): the dense tensor to convert
  288. transposed (bool, optional): deprecated arg to be removed in another release. Do not use.
  289. Returns:
  290. SparseSemiStructuredTensor: A sparse semi-structured tensor created from the given original_tensor
  291. Raises:
  292. None
  293. Example:
  294. >>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_CUDA)
  295. >>> A = torch.Tensor([0, 0, 1, 1]).tile((128, 32)).half().cuda()
  296. tensor([[0., 0., 1., ..., 0., 1., 1.],
  297. [0., 0., 1., ..., 0., 1., 1.],
  298. [0., 0., 1., ..., 0., 1., 1.],
  299. ...,
  300. [0., 0., 1., ..., 0., 1., 1.],
  301. [0., 0., 1., ..., 0., 1., 1.],
  302. [0., 0., 1., ..., 0., 1., 1.]], device='cuda:0', dtype=torch.float16)
  303. >>> A_sparse = to_sparse_semi_structured(A)
  304. SparseSemiStructuredTensor(shape=torch.Size([128, 128]))
  305. >>> A_sparse.values()
  306. tensor([[1., 1., 1., ..., 1., 1., 1.],
  307. [1., 1., 1., ..., 1., 1., 1.],
  308. [1., 1., 1., ..., 1., 1., 1.],
  309. ...,
  310. [1., 1., 1., ..., 1., 1., 1.],
  311. [1., 1., 1., ..., 1., 1., 1.],
  312. [1., 1., 1., ..., 1., 1., 1.]], device='cuda:0', dtype=torch.float16),
  313. >>> A_sparse.indices()
  314. tensor([[-4370, -4370, -4370, ..., -4370, -4370, -4370],
  315. [-4370, -4370, -4370, ..., -4370, -4370, -4370],
  316. [-4370, -4370, -4370, ..., -4370, -4370, -4370],
  317. ...,
  318. [-4370, -4370, -4370, ..., -4370, -4370, -4370],
  319. [-4370, -4370, -4370, ..., -4370, -4370, -4370],
  320. [-4370, -4370, -4370, ..., -4370, -4370, -4370]], device='cuda:0', dtype=torch.int16))
  321. """
  322. if transposed:
  323. warnings.warn(
  324. "Setting transpose from `to_sparse_semi_structured` is deprecated "
  325. "and will be removed in a future release. "
  326. "`SparseSemiStructuredTensor` only support contiguous input tensors.",
  327. FutureWarning,
  328. stacklevel=2,
  329. )
  330. # set from _FORCE_CUTLASS flag
  331. SPARSE_SUBCLASS = (
  332. torch.sparse.SparseSemiStructuredTensorCUTLASS
  333. if SparseSemiStructuredTensor._FORCE_CUTLASS
  334. else torch.sparse.SparseSemiStructuredTensorCUSPARSELT
  335. )
  336. return SPARSE_SUBCLASS.from_dense(original_tensor)
  337. class SparseSemiStructuredTensorCUTLASS(SparseSemiStructuredTensor):
  338. """
  339. This class implements semi-structured sparsity for the CUTLASS backend.
  340. In this implementation, the specified elements and metadata are stored separately,
  341. in packed and meta respectively.
  342. When _FORCE_CUTLASS is set, or when cuSPARSELt is not available, this subclass calls into _sparse_semi_structured_(mm|addmm) and
  343. sparse_semi_structured_from_dense for conversion to the compressed format.
  344. """
  345. BACKEND = "cutlass"
  346. _DTYPE_SHAPE_CONSTRAINTS = {
  347. torch.int8: _SEMI_STRUCTURED_SPARSE_CONFIG(16, 128, 16, 16),
  348. torch.float16: _SEMI_STRUCTURED_SPARSE_CONFIG(32, 64, 8, 8),
  349. torch.bfloat16: _SEMI_STRUCTURED_SPARSE_CONFIG(32, 64, 8, 8),
  350. torch.float32: _SEMI_STRUCTURED_SPARSE_CONFIG(32, 32, 4, 4),
  351. }
  352. @classmethod
  353. def from_dense(
  354. cls, original_tensor: torch.Tensor
  355. ) -> "SparseSemiStructuredTensorCUTLASS":
  356. cls._validate_device_dim_dtype_shape(original_tensor)
  357. (
  358. sparse_tensor_cutlass,
  359. meta_tensor_cutlass,
  360. ) = sparse_semi_structured_from_dense_cutlass(original_tensor)
  361. return cls(
  362. original_tensor.shape,
  363. packed=sparse_tensor_cutlass,
  364. meta=meta_tensor_cutlass,
  365. packed_t=None,
  366. meta_t=None,
  367. compressed_swizzled_bitmask=None,
  368. requires_grad=original_tensor.requires_grad,
  369. )
  370. def to_dense(self): # type: ignore[override]
  371. assert self.meta is not None and self.packed is not None
  372. return (
  373. sparse_semi_structured_to_dense_cutlass(
  374. self.packed,
  375. self.meta,
  376. )
  377. if self.meta.ndim == 2
  378. else super().to_dense()
  379. )
  380. @classmethod
  381. def prune_dense_static_sort(
  382. cls, original_tensor: torch.Tensor, algorithm=""
  383. ) -> "SparseSemiStructuredTensor":
  384. """
  385. This function takes in a unpruned dense tensor and runs a (branchless) static sort across a 4x4 tile.
  386. It greedily picks the largest values in the tile, upholding the 2:4 sparsity constraint across both rows and columns.
  387. The algorithm used to prune the matrix is implemented in `_sparse_semi_structured_tile`.
  388. Then it creates the packed and meta tensors for the compressed sparse representation of the pruned dense tensor.
  389. It also calculates the packed_t and meta_t tensors for the compressed sparse representation of the transposed
  390. pruned dense tensor.
  391. Since we cannot transpose the compressed representations, we store both for the fw/bw pass respectively.
  392. Finally, this function also computes a compressed swizzled bitmask that encodes the sparsity pattern
  393. This can be used in the backward pass to mask the gradients.
  394. [9 1 7 4] [9 0 7 0]
  395. [1 2 3 0] [0 2 0 0]
  396. [8 3 5 4] -> prune 4x4 tile -> [8 0 0 4] -> pack to CUTLASS semi-structured -> packed
  397. [1 2 6 2] [0 0 6 2] -> metadata
  398. -> pack to transposed CUTLASS -> packed_t
  399. semi-structured representation -> metadata_t
  400. -> compute swizzled bitmask -> compressed_swizzled_bitmask
  401. The equivalent PyTorch code to create the same five outputs from the dense tensor can be found below:
  402. ```
  403. from torch.sparse import SparseSemiStructuredTensorCUTLASS
  404. from torch.sparse._semi_structured_conversions import (
  405. _sparse_semi_structured_tile,
  406. _compute_compressed_swizzled_bitmask,
  407. )
  408. pruned = _sparse_semi_structured_tile(dense)
  409. packed_cutlass, meta_cutlass = sparse_semi_structured_from_dense_cutlass(pruned)
  410. packed_t_cutlass, meta_t_cutlass = sparse_semi_structured_from_dense_cutlass(
  411. pruned.t().contiguous()
  412. )
  413. bitmask = _compute_compressed_swizzled_bitmask(pruned)
  414. SparseSemiStructuredTensorCUTLASS(
  415. dense.shape,
  416. packed_cutlass,
  417. meta_cutlass,
  418. packed_t_cutlass,
  419. meta_t_cutlass,
  420. bitmask,
  421. )
  422. ```
  423. """
  424. # We can either pack to the CUTLASS or cuSPARSELt representation, depending on the use_cutlass flag.
  425. (
  426. packed,
  427. meta,
  428. packed_t,
  429. meta_t,
  430. compressed_swizzled_bitmask,
  431. ) = torch._sparse_semi_structured_tile(
  432. original_tensor, algorithm=algorithm, use_cutlass=True
  433. )
  434. return cls(
  435. original_tensor.shape,
  436. packed=packed,
  437. meta=meta,
  438. packed_t=packed_t,
  439. meta_t=meta_t,
  440. compressed_swizzled_bitmask=compressed_swizzled_bitmask,
  441. requires_grad=False,
  442. )
  443. def _mm(
  444. self, B: torch.Tensor, *, bias: Optional[torch.Tensor] = None, **kwargs
  445. ) -> torch.Tensor:
  446. if isinstance(B, SparseSemiStructuredTensor):
  447. raise ValueError(
  448. "`SparseSemiStructuredTensor @ SparseSemiStructuredTensor` is not supported by the hardware"
  449. )
  450. cls_name = self.__class__.__name__
  451. if self.ndim != 2 or B.ndim != 2:
  452. raise NotImplementedError(
  453. f"`{cls_name}` matmul: Broadcasting is not implemented"
  454. )
  455. if self.packed is None or self.meta is None:
  456. raise NotImplementedError(
  457. f"`{cls_name}` matmul: operation is not supported"
  458. )
  459. else:
  460. if bias is None:
  461. res = torch._sparse_semi_structured_mm(self.packed, self.meta, B)
  462. else:
  463. res = torch._sparse_semi_structured_addmm(
  464. bias, self.packed, self.meta, B
  465. )
  466. return res[: self.shape[0]]
  467. class SparseSemiStructuredTensorCUSPARSELT(SparseSemiStructuredTensor):
  468. """
  469. The cuSPARSELt backend expects the specified elements and the metadata to be stored in a single tensor:
  470. packed = [ specified elements of original tensor | metadata ]
  471. For an original tensor of size (m, k) we expect the first m * k // 2 elements to be the kept elements
  472. The rest of the tensor is metadata. Since there is only one tensor, we only use the packed and packed_t
  473. attributes respectively.
  474. cuSPARSELt also supports transposition fusion, which is necessary for performant 2:4 sparse training, as well
  475. as specifying alg_id, a config that affects the performance of the matmul depending on matmul sizes.
  476. """
  477. BACKEND = "cusparselt"
  478. _DTYPE_SHAPE_CONSTRAINTS = {
  479. torch.float8_e4m3fn: _SEMI_STRUCTURED_SPARSE_CONFIG(32, 32, 16, 16),
  480. torch.int8: _SEMI_STRUCTURED_SPARSE_CONFIG(32, 32, 16, 16),
  481. torch.float16: _SEMI_STRUCTURED_SPARSE_CONFIG(16, 16, 8, 8),
  482. torch.bfloat16: _SEMI_STRUCTURED_SPARSE_CONFIG(16, 16, 8, 8),
  483. }
  484. @classmethod
  485. def from_dense(
  486. cls, original_tensor: torch.Tensor
  487. ) -> "SparseSemiStructuredTensorCUSPARSELT":
  488. cls._validate_device_dim_dtype_shape(original_tensor)
  489. return cls(
  490. shape=original_tensor.shape,
  491. packed=torch._cslt_compress(original_tensor),
  492. meta=None,
  493. packed_t=None,
  494. meta_t=None,
  495. compressed_swizzled_bitmask=None,
  496. fuse_transpose_cusparselt=SparseSemiStructuredTensor._FUSE_TRANSPOSE,
  497. alg_id_cusparselt=SparseSemiStructuredTensor._DEFAULT_ALG_ID,
  498. requires_grad=original_tensor.requires_grad,
  499. )
  500. @classmethod
  501. def prune_dense_static_sort(
  502. cls, original_tensor: torch.Tensor, algorithm=""
  503. ) -> "SparseSemiStructuredTensor":
  504. """
  505. This function does the same thing as described in SparseSemiStructuredCUTLASS, but uses the cuSPASRELt metadata
  506. layout and sparse matmul.
  507. The only functional difference is that cuSPARSELt stores `metadata` and `packed` together into a single tensor.
  508. [9 1 7 4] [9 0 7 0]
  509. [1 2 3 0] [0 2 0 0]
  510. [8 3 5 4] -> prune 4x4 tile -> [8 0 0 4] -> pack to cuSPARSELT semi-structured -> packed
  511. [1 2 6 2] [0 0 6 2]
  512. -> pack to transposed cuSPARSELt -> packed_t
  513. semi-structured representation
  514. -> compute swizzled bitmask -> compressed_swizzled_bitmask
  515. The equivalent PyTorch code to create the same three outputs from the dense tensor can be found below:
  516. ```
  517. from torch.sparse import SparseSemiStructuredTensorCUSPARSELT
  518. from torch.sparse._semi_structured_conversions import (
  519. _sparse_semi_structured_tile,
  520. _compute_compressed_swizzled_bitmask,
  521. )
  522. pruned = _sparse_semi_structured_tile(dense)
  523. packed_cusparselt = torch._cslt_compress(pruned)
  524. packed_t_cusparselt = torch._cslt_compress(pruned.t().contiguous())
  525. bitmask = _compute_compressed_swizzled_bitmask(pruned)
  526. SparseSemiStructuredTensorCUSPARSELT(
  527. dense.shape, packed_cutlass, None, packed_t_cutlass, None, bitmask
  528. )
  529. ```
  530. """
  531. (
  532. packed,
  533. meta,
  534. packed_t,
  535. meta_t,
  536. compressed_swizzled_bitmask,
  537. ) = torch._sparse_semi_structured_tile(
  538. original_tensor, algorithm=algorithm, use_cutlass=False
  539. )
  540. return cls(
  541. original_tensor.shape,
  542. packed=packed,
  543. meta=meta,
  544. packed_t=packed_t,
  545. meta_t=meta_t,
  546. compressed_swizzled_bitmask=compressed_swizzled_bitmask,
  547. requires_grad=False,
  548. )
  549. def _mm(
  550. self, B: torch.Tensor, *, bias: Optional[torch.Tensor] = None, **kwargs
  551. ) -> torch.Tensor:
  552. if isinstance(B, SparseSemiStructuredTensor):
  553. raise ValueError(
  554. "`SparseSemiStructuredTensor @ SparseSemiStructuredTensor` is not supported by the hardware"
  555. )
  556. if self.ndim != 2 or B.ndim != 2:
  557. raise NotImplementedError(
  558. f"`{self.__class__.__name__}` matmul: Broadcasting is not implemented"
  559. )
  560. if B.dtype != self.dtype:
  561. raise NotImplementedError(
  562. f"`{self.__class__.__name__}` matmul: trying to do `A={tuple(self.shape)} @ B={tuple(B.shape)}`, "
  563. f"with A.dtype={self.dtype} and B.dtype={B.dtype}. "
  564. "This operation is only supported when A and B have the same data type."
  565. )
  566. if bias is not None and bias.dtype != self.dtype:
  567. raise NotImplementedError(
  568. f"`{self.__class__.__name__}` matmul: trying to do `A={tuple(self.shape)} @ B={tuple(B.shape)} + C`, "
  569. f"with A.dtype=B.dtype={self.dtype} and C.dtype={B.dtype}. "
  570. "This operation is only supported when A, B and C have the same data type."
  571. )
  572. # Force fp8 mm to error to be consistent with torch
  573. if self.dtype == torch.float8_e4m3fn:
  574. raise NotImplementedError(
  575. f"`{self.__class__.__name__}` matmul: trying to do `A={tuple(self.shape)} @ B={tuple(B.shape)}`, "
  576. f"with A.dtype=B.dtype={self.dtype}. "
  577. "mm is not supported for float8_e4m3fn, please use `torch._scaled_mm` instead."
  578. )
  579. if self.packed is None:
  580. raise NotImplementedError(
  581. f"`{self.__class__.__name__}` matmul: operation is not supported"
  582. )
  583. else:
  584. res = torch._cslt_sparse_mm(
  585. self.packed,
  586. B,
  587. bias=bias,
  588. transpose_result=self.fuse_transpose_cusparselt,
  589. alg_id=self.alg_id_cusparselt,
  590. )
  591. return res.t() if self.fuse_transpose_cusparselt else res