CUDASparseBlas.h 12 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320
  1. #pragma once
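  /*
  Provides a subset of cuSPARSE functions as templates over scalar_t:

    csrgeam2_bufferSizeExt<scalar_t>(...), csrgeam2Nnz<scalar_t>(...),
    csrgeam2<scalar_t>(...), bsrmm<scalar_t>(...), bsrmv<scalar_t>(...),
    bsrsv2_bufferSize / bsrsv2_analysis / bsrsv2_solve<scalar_t>(...),
    bsrsm2_bufferSize / bsrsm2_analysis / bsrsm2_solve<scalar_t>(...)

  where scalar_t is double, float, c10::complex<double> or c10::complex<float>.
  The bsrsv2_* and bsrsm2_* templates are only declared when
  AT_USE_HIPSPARSE_TRIANGULAR_SOLVE() is true.
  The functions are available in the at::cuda::sparse namespace.
  */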
  8. #include <ATen/cuda/CUDAContext.h>
  9. #include <ATen/cuda/CUDASparse.h>
  10. // NOLINTBEGIN(misc-misplaced-const)
  11. namespace at::cuda::sparse {
  12. #define CUSPARSE_CSRGEAM2_BUFFERSIZE_ARGTYPES(scalar_t) \
  13. cusparseHandle_t handle, int m, int n, const scalar_t *alpha, \
  14. const cusparseMatDescr_t descrA, int nnzA, \
  15. const scalar_t *csrSortedValA, const int *csrSortedRowPtrA, \
  16. const int *csrSortedColIndA, const scalar_t *beta, \
  17. const cusparseMatDescr_t descrB, int nnzB, \
  18. const scalar_t *csrSortedValB, const int *csrSortedRowPtrB, \
  19. const int *csrSortedColIndB, const cusparseMatDescr_t descrC, \
  20. const scalar_t *csrSortedValC, const int *csrSortedRowPtrC, \
  21. const int *csrSortedColIndC, size_t *pBufferSizeInBytes
  22. template <typename scalar_t>
  23. inline void csrgeam2_bufferSizeExt(
  24. CUSPARSE_CSRGEAM2_BUFFERSIZE_ARGTYPES(scalar_t)) {
  25. TORCH_INTERNAL_ASSERT(
  26. false,
  27. "at::cuda::sparse::csrgeam2_bufferSizeExt: not implemented for ",
  28. typeid(scalar_t).name());
  29. }
  30. template <>
  31. void csrgeam2_bufferSizeExt<float>(
  32. CUSPARSE_CSRGEAM2_BUFFERSIZE_ARGTYPES(float));
  33. template <>
  34. void csrgeam2_bufferSizeExt<double>(
  35. CUSPARSE_CSRGEAM2_BUFFERSIZE_ARGTYPES(double));
  36. template <>
  37. void csrgeam2_bufferSizeExt<c10::complex<float>>(
  38. CUSPARSE_CSRGEAM2_BUFFERSIZE_ARGTYPES(c10::complex<float>));
  39. template <>
  40. void csrgeam2_bufferSizeExt<c10::complex<double>>(
  41. CUSPARSE_CSRGEAM2_BUFFERSIZE_ARGTYPES(c10::complex<double>));
  42. #define CUSPARSE_CSRGEAM2_NNZ_ARGTYPES() \
  43. cusparseHandle_t handle, int m, int n, const cusparseMatDescr_t descrA, \
  44. int nnzA, const int *csrSortedRowPtrA, const int *csrSortedColIndA, \
  45. const cusparseMatDescr_t descrB, int nnzB, const int *csrSortedRowPtrB, \
  46. const int *csrSortedColIndB, const cusparseMatDescr_t descrC, \
  47. int *csrSortedRowPtrC, int *nnzTotalDevHostPtr, void *workspace
  48. template <typename scalar_t>
  49. inline void csrgeam2Nnz(CUSPARSE_CSRGEAM2_NNZ_ARGTYPES()) {
  50. TORCH_CUDASPARSE_CHECK(cusparseXcsrgeam2Nnz(
  51. handle,
  52. m,
  53. n,
  54. descrA,
  55. nnzA,
  56. csrSortedRowPtrA,
  57. csrSortedColIndA,
  58. descrB,
  59. nnzB,
  60. csrSortedRowPtrB,
  61. csrSortedColIndB,
  62. descrC,
  63. csrSortedRowPtrC,
  64. nnzTotalDevHostPtr,
  65. workspace));
  66. }
  67. #define CUSPARSE_CSRGEAM2_ARGTYPES(scalar_t) \
  68. cusparseHandle_t handle, int m, int n, const scalar_t *alpha, \
  69. const cusparseMatDescr_t descrA, int nnzA, \
  70. const scalar_t *csrSortedValA, const int *csrSortedRowPtrA, \
  71. const int *csrSortedColIndA, const scalar_t *beta, \
  72. const cusparseMatDescr_t descrB, int nnzB, \
  73. const scalar_t *csrSortedValB, const int *csrSortedRowPtrB, \
  74. const int *csrSortedColIndB, const cusparseMatDescr_t descrC, \
  75. scalar_t *csrSortedValC, int *csrSortedRowPtrC, int *csrSortedColIndC, \
  76. void *pBuffer
  77. template <typename scalar_t>
  78. inline void csrgeam2(CUSPARSE_CSRGEAM2_ARGTYPES(scalar_t)) {
  79. TORCH_INTERNAL_ASSERT(
  80. false,
  81. "at::cuda::sparse::csrgeam2: not implemented for ",
  82. typeid(scalar_t).name());
  83. }
  84. template <>
  85. void csrgeam2<float>(CUSPARSE_CSRGEAM2_ARGTYPES(float));
  86. template <>
  87. void csrgeam2<double>(CUSPARSE_CSRGEAM2_ARGTYPES(double));
  88. template <>
  89. void csrgeam2<c10::complex<float>>(
  90. CUSPARSE_CSRGEAM2_ARGTYPES(c10::complex<float>));
  91. template <>
  92. void csrgeam2<c10::complex<double>>(
  93. CUSPARSE_CSRGEAM2_ARGTYPES(c10::complex<double>));
  94. #define CUSPARSE_BSRMM_ARGTYPES(scalar_t) \
  95. cusparseHandle_t handle, cusparseDirection_t dirA, \
  96. cusparseOperation_t transA, cusparseOperation_t transB, int mb, int n, \
  97. int kb, int nnzb, const scalar_t *alpha, \
  98. const cusparseMatDescr_t descrA, const scalar_t *bsrValA, \
  99. const int *bsrRowPtrA, const int *bsrColIndA, int blockDim, \
  100. const scalar_t *B, int ldb, const scalar_t *beta, scalar_t *C, int ldc
  101. template <typename scalar_t>
  102. inline void bsrmm(CUSPARSE_BSRMM_ARGTYPES(scalar_t)) {
  103. TORCH_INTERNAL_ASSERT(
  104. false,
  105. "at::cuda::sparse::bsrmm: not implemented for ",
  106. typeid(scalar_t).name());
  107. }
  108. template <>
  109. void bsrmm<float>(CUSPARSE_BSRMM_ARGTYPES(float));
  110. template <>
  111. void bsrmm<double>(CUSPARSE_BSRMM_ARGTYPES(double));
  112. template <>
  113. void bsrmm<c10::complex<float>>(CUSPARSE_BSRMM_ARGTYPES(c10::complex<float>));
  114. template <>
  115. void bsrmm<c10::complex<double>>(CUSPARSE_BSRMM_ARGTYPES(c10::complex<double>));
  116. #define CUSPARSE_BSRMV_ARGTYPES(scalar_t) \
  117. cusparseHandle_t handle, cusparseDirection_t dirA, \
  118. cusparseOperation_t transA, int mb, int nb, int nnzb, \
  119. const scalar_t *alpha, const cusparseMatDescr_t descrA, \
  120. const scalar_t *bsrValA, const int *bsrRowPtrA, const int *bsrColIndA, \
  121. int blockDim, const scalar_t *x, const scalar_t *beta, scalar_t *y
  122. template <typename scalar_t>
  123. inline void bsrmv(CUSPARSE_BSRMV_ARGTYPES(scalar_t)) {
  124. TORCH_INTERNAL_ASSERT(
  125. false,
  126. "at::cuda::sparse::bsrmv: not implemented for ",
  127. typeid(scalar_t).name());
  128. }
  129. template <>
  130. void bsrmv<float>(CUSPARSE_BSRMV_ARGTYPES(float));
  131. template <>
  132. void bsrmv<double>(CUSPARSE_BSRMV_ARGTYPES(double));
  133. template <>
  134. void bsrmv<c10::complex<float>>(CUSPARSE_BSRMV_ARGTYPES(c10::complex<float>));
  135. template <>
  136. void bsrmv<c10::complex<double>>(CUSPARSE_BSRMV_ARGTYPES(c10::complex<double>));
  137. #if AT_USE_HIPSPARSE_TRIANGULAR_SOLVE()
  138. #define CUSPARSE_BSRSV2_BUFFER_ARGTYPES(scalar_t) \
  139. cusparseHandle_t handle, cusparseDirection_t dirA, \
  140. cusparseOperation_t transA, int mb, int nnzb, \
  141. const cusparseMatDescr_t descrA, scalar_t *bsrValA, \
  142. const int *bsrRowPtrA, const int *bsrColIndA, int blockDim, \
  143. bsrsv2Info_t info, int *pBufferSizeInBytes
  144. template <typename scalar_t>
  145. inline void bsrsv2_bufferSize(CUSPARSE_BSRSV2_BUFFER_ARGTYPES(scalar_t)) {
  146. TORCH_INTERNAL_ASSERT(
  147. false,
  148. "at::cuda::sparse::bsrsv2_bufferSize: not implemented for ",
  149. typeid(scalar_t).name());
  150. }
  151. template <>
  152. void bsrsv2_bufferSize<float>(CUSPARSE_BSRSV2_BUFFER_ARGTYPES(float));
  153. template <>
  154. void bsrsv2_bufferSize<double>(CUSPARSE_BSRSV2_BUFFER_ARGTYPES(double));
  155. template <>
  156. void bsrsv2_bufferSize<c10::complex<float>>(
  157. CUSPARSE_BSRSV2_BUFFER_ARGTYPES(c10::complex<float>));
  158. template <>
  159. void bsrsv2_bufferSize<c10::complex<double>>(
  160. CUSPARSE_BSRSV2_BUFFER_ARGTYPES(c10::complex<double>));
  161. #define CUSPARSE_BSRSV2_ANALYSIS_ARGTYPES(scalar_t) \
  162. cusparseHandle_t handle, cusparseDirection_t dirA, \
  163. cusparseOperation_t transA, int mb, int nnzb, \
  164. const cusparseMatDescr_t descrA, const scalar_t *bsrValA, \
  165. const int *bsrRowPtrA, const int *bsrColIndA, int blockDim, \
  166. bsrsv2Info_t info, cusparseSolvePolicy_t policy, void *pBuffer
  167. template <typename scalar_t>
  168. inline void bsrsv2_analysis(CUSPARSE_BSRSV2_ANALYSIS_ARGTYPES(scalar_t)) {
  169. TORCH_INTERNAL_ASSERT(
  170. false,
  171. "at::cuda::sparse::bsrsv2_analysis: not implemented for ",
  172. typeid(scalar_t).name());
  173. }
  174. template <>
  175. void bsrsv2_analysis<float>(CUSPARSE_BSRSV2_ANALYSIS_ARGTYPES(float));
  176. template <>
  177. void bsrsv2_analysis<double>(CUSPARSE_BSRSV2_ANALYSIS_ARGTYPES(double));
  178. template <>
  179. void bsrsv2_analysis<c10::complex<float>>(
  180. CUSPARSE_BSRSV2_ANALYSIS_ARGTYPES(c10::complex<float>));
  181. template <>
  182. void bsrsv2_analysis<c10::complex<double>>(
  183. CUSPARSE_BSRSV2_ANALYSIS_ARGTYPES(c10::complex<double>));
  184. #define CUSPARSE_BSRSV2_SOLVE_ARGTYPES(scalar_t) \
  185. cusparseHandle_t handle, cusparseDirection_t dirA, \
  186. cusparseOperation_t transA, int mb, int nnzb, const scalar_t *alpha, \
  187. const cusparseMatDescr_t descrA, const scalar_t *bsrValA, \
  188. const int *bsrRowPtrA, const int *bsrColIndA, int blockDim, \
  189. bsrsv2Info_t info, const scalar_t *x, scalar_t *y, \
  190. cusparseSolvePolicy_t policy, void *pBuffer
  191. template <typename scalar_t>
  192. inline void bsrsv2_solve(CUSPARSE_BSRSV2_SOLVE_ARGTYPES(scalar_t)) {
  193. TORCH_INTERNAL_ASSERT(
  194. false,
  195. "at::cuda::sparse::bsrsv2_solve: not implemented for ",
  196. typeid(scalar_t).name());
  197. }
  198. template <>
  199. void bsrsv2_solve<float>(CUSPARSE_BSRSV2_SOLVE_ARGTYPES(float));
  200. template <>
  201. void bsrsv2_solve<double>(CUSPARSE_BSRSV2_SOLVE_ARGTYPES(double));
  202. template <>
  203. void bsrsv2_solve<c10::complex<float>>(
  204. CUSPARSE_BSRSV2_SOLVE_ARGTYPES(c10::complex<float>));
  205. template <>
  206. void bsrsv2_solve<c10::complex<double>>(
  207. CUSPARSE_BSRSV2_SOLVE_ARGTYPES(c10::complex<double>));
  208. #define CUSPARSE_BSRSM2_BUFFER_ARGTYPES(scalar_t) \
  209. cusparseHandle_t handle, cusparseDirection_t dirA, \
  210. cusparseOperation_t transA, cusparseOperation_t transX, int mb, int n, \
  211. int nnzb, const cusparseMatDescr_t descrA, scalar_t *bsrValA, \
  212. const int *bsrRowPtrA, const int *bsrColIndA, int blockDim, \
  213. bsrsm2Info_t info, int *pBufferSizeInBytes
  214. template <typename scalar_t>
  215. inline void bsrsm2_bufferSize(CUSPARSE_BSRSM2_BUFFER_ARGTYPES(scalar_t)) {
  216. TORCH_INTERNAL_ASSERT(
  217. false,
  218. "at::cuda::sparse::bsrsm2_bufferSize: not implemented for ",
  219. typeid(scalar_t).name());
  220. }
  221. template <>
  222. void bsrsm2_bufferSize<float>(CUSPARSE_BSRSM2_BUFFER_ARGTYPES(float));
  223. template <>
  224. void bsrsm2_bufferSize<double>(CUSPARSE_BSRSM2_BUFFER_ARGTYPES(double));
  225. template <>
  226. void bsrsm2_bufferSize<c10::complex<float>>(
  227. CUSPARSE_BSRSM2_BUFFER_ARGTYPES(c10::complex<float>));
  228. template <>
  229. void bsrsm2_bufferSize<c10::complex<double>>(
  230. CUSPARSE_BSRSM2_BUFFER_ARGTYPES(c10::complex<double>));
  231. #define CUSPARSE_BSRSM2_ANALYSIS_ARGTYPES(scalar_t) \
  232. cusparseHandle_t handle, cusparseDirection_t dirA, \
  233. cusparseOperation_t transA, cusparseOperation_t transX, int mb, int n, \
  234. int nnzb, const cusparseMatDescr_t descrA, const scalar_t *bsrValA, \
  235. const int *bsrRowPtrA, const int *bsrColIndA, int blockDim, \
  236. bsrsm2Info_t info, cusparseSolvePolicy_t policy, void *pBuffer
  237. template <typename scalar_t>
  238. inline void bsrsm2_analysis(CUSPARSE_BSRSM2_ANALYSIS_ARGTYPES(scalar_t)) {
  239. TORCH_INTERNAL_ASSERT(
  240. false,
  241. "at::cuda::sparse::bsrsm2_analysis: not implemented for ",
  242. typeid(scalar_t).name());
  243. }
  244. template <>
  245. void bsrsm2_analysis<float>(CUSPARSE_BSRSM2_ANALYSIS_ARGTYPES(float));
  246. template <>
  247. void bsrsm2_analysis<double>(CUSPARSE_BSRSM2_ANALYSIS_ARGTYPES(double));
  248. template <>
  249. void bsrsm2_analysis<c10::complex<float>>(
  250. CUSPARSE_BSRSM2_ANALYSIS_ARGTYPES(c10::complex<float>));
  251. template <>
  252. void bsrsm2_analysis<c10::complex<double>>(
  253. CUSPARSE_BSRSM2_ANALYSIS_ARGTYPES(c10::complex<double>));
  254. #define CUSPARSE_BSRSM2_SOLVE_ARGTYPES(scalar_t) \
  255. cusparseHandle_t handle, cusparseDirection_t dirA, \
  256. cusparseOperation_t transA, cusparseOperation_t transX, int mb, int n, \
  257. int nnzb, const scalar_t *alpha, const cusparseMatDescr_t descrA, \
  258. const scalar_t *bsrValA, const int *bsrRowPtrA, const int *bsrColIndA, \
  259. int blockDim, bsrsm2Info_t info, const scalar_t *B, int ldb, \
  260. scalar_t *X, int ldx, cusparseSolvePolicy_t policy, void *pBuffer
  261. template <typename scalar_t>
  262. inline void bsrsm2_solve(CUSPARSE_BSRSM2_SOLVE_ARGTYPES(scalar_t)) {
  263. TORCH_INTERNAL_ASSERT(
  264. false,
  265. "at::cuda::sparse::bsrsm2_solve: not implemented for ",
  266. typeid(scalar_t).name());
  267. }
  268. template <>
  269. void bsrsm2_solve<float>(CUSPARSE_BSRSM2_SOLVE_ARGTYPES(float));
  270. template <>
  271. void bsrsm2_solve<double>(CUSPARSE_BSRSM2_SOLVE_ARGTYPES(double));
  272. template <>
  273. void bsrsm2_solve<c10::complex<float>>(
  274. CUSPARSE_BSRSM2_SOLVE_ARGTYPES(c10::complex<float>));
  275. template <>
  276. void bsrsm2_solve<c10::complex<double>>(
  277. CUSPARSE_BSRSM2_SOLVE_ARGTYPES(c10::complex<double>));
  278. #endif // AT_USE_HIPSPARSE_TRIANGULAR_SOLVE
  279. } // namespace at::cuda::sparse
  280. // NOLINTEND(misc-misplaced-const)