CUDABlas.h 15 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400
  1. #pragma once
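/*
Provides a subset of CUDA BLAS functions as templates, for example:

  gemm<Dtype>(transa, transb, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc)
  gemv<Dtype>(trans, m, n, alpha, a, lda, x, incx, beta, y, incy)
  dot<Dtype>(handle, n, x, incx, y, incy, result)

Dtype is one of double, float, c10::complex<double>, c10::complex<float>,
at::Half or at::BFloat16. The explicit specializations declared in this header
list exactly which types each function supports (for example, vdot is only
declared for the complex types). Using an unsupported type fails at compile
time. gemm, gemm_internal, bgemm and bgemm_internal additionally accept
at::Half or at::BFloat16 inputs with a float output.

All functions live in the at::cuda::blas namespace.
*/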
  11. #include <ATen/cuda/CUDAContext.h>
  12. #include <ATen/OpMathType.h>
  13. namespace at::cuda::blas {
  14. // RAII guard that sets the CuBLAS pointer mode and restores it to
  15. // its previous value when the guard is destroyed
  16. class PointerModeGuard {
  17. public:
  18. PointerModeGuard(cublasHandle_t handle, cublasPointerMode_t mode) :
  19. handle(handle) {
  20. TORCH_CUDABLAS_CHECK(cublasGetPointerMode(handle, &previous_mode));
  21. TORCH_CUDABLAS_CHECK(cublasSetPointerMode(handle, mode));
  22. }
  23. ~PointerModeGuard() {
  24. cublasSetPointerMode(handle, previous_mode);
  25. }
  26. private:
  27. cublasHandle_t handle;
  28. cublasPointerMode_t previous_mode{};
  29. };
  30. /* LEVEL 3 BLAS FUNCTIONS */
  31. #define CUDABLAS_GEMM_ARGTYPES(Dtype) CUDABLAS_GEMM_ARGTYPES_AND_C_DTYPE(Dtype, Dtype)
  32. #define CUDABLAS_GEMM_ARGTYPES_AND_C_DTYPE(Dtype, C_Dtype) \
  33. char transa, char transb, int64_t m, int64_t n, int64_t k, at::opmath_type<Dtype> alpha, \
  34. const Dtype *a, int64_t lda, const Dtype *b, int64_t ldb, at::opmath_type<Dtype> beta,\
  35. C_Dtype *c, int64_t ldc
  36. #define CUDABLAS_GEMM_ARGS(Dtype) transa, transb, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc
  37. #define CUDABLAS_GEMM_DTYPE_IS_FLOAT_TYPE_AND_C_DTYPE_IS_FLOAT \
  38. ((std::is_same<Dtype, at::Half>::value || std::is_same<Dtype, at::BFloat16>::value) && std::is_same<C_Dtype, float>::value)
  39. template <typename Dtype, typename C_Dtype = Dtype, typename std::enable_if<!CUDABLAS_GEMM_DTYPE_IS_FLOAT_TYPE_AND_C_DTYPE_IS_FLOAT, Dtype>::type* = nullptr>
  40. inline void gemm(CUDABLAS_GEMM_ARGTYPES_AND_C_DTYPE(Dtype, C_Dtype)) {
  41. static_assert(false&&sizeof(Dtype),"at::cuda::blas::gemm: not implemented");
  42. }
  43. template <typename Dtype, typename C_Dtype, typename std::enable_if<CUDABLAS_GEMM_DTYPE_IS_FLOAT_TYPE_AND_C_DTYPE_IS_FLOAT, Dtype>::type* = nullptr>
  44. void gemm(CUDABLAS_GEMM_ARGTYPES_AND_C_DTYPE(Dtype, C_Dtype));
  45. template <>
  46. void gemm<double>(CUDABLAS_GEMM_ARGTYPES(double));
  47. template <>
  48. void gemm<float>(CUDABLAS_GEMM_ARGTYPES(float));
  49. template <>
  50. void gemm<c10::complex<double>>(CUDABLAS_GEMM_ARGTYPES(c10::complex<double>));
  51. template <>
  52. void gemm<c10::complex<float>>(CUDABLAS_GEMM_ARGTYPES(c10::complex<float>));
  53. template <>
  54. void gemm<at::Half>(CUDABLAS_GEMM_ARGTYPES(at::Half));
  55. template <>
  56. void gemm<at::BFloat16>(CUDABLAS_GEMM_ARGTYPES(at::BFloat16));
  57. template<>
  58. void gemm<at::Half, float>(CUDABLAS_GEMM_ARGTYPES_AND_C_DTYPE(at::Half, float));
  59. template<>
  60. void gemm<at::BFloat16, float>(CUDABLAS_GEMM_ARGTYPES_AND_C_DTYPE(at::BFloat16, float));
  61. template <typename Dtype, typename C_Dtype = Dtype>
  62. inline void gemm_internal(CUDABLAS_GEMM_ARGTYPES_AND_C_DTYPE(Dtype, C_Dtype)) {
  63. static_assert(false&&sizeof(Dtype),"at::cuda::blas::gemm_internal: not implemented");
  64. }
  65. template <>
  66. void gemm_internal<double>(CUDABLAS_GEMM_ARGTYPES(double));
  67. template <>
  68. void gemm_internal<float>(CUDABLAS_GEMM_ARGTYPES(float));
  69. template <>
  70. void gemm_internal<c10::complex<double>>(CUDABLAS_GEMM_ARGTYPES(c10::complex<double>));
  71. template <>
  72. void gemm_internal<c10::complex<float>>(CUDABLAS_GEMM_ARGTYPES(c10::complex<float>));
  73. template <>
  74. void gemm_internal<at::Half>(CUDABLAS_GEMM_ARGTYPES(at::Half));
  75. template <>
  76. void gemm_internal<at::BFloat16>(CUDABLAS_GEMM_ARGTYPES(at::BFloat16));
  77. template<>
  78. void gemm_internal<at::Half, float>(CUDABLAS_GEMM_ARGTYPES_AND_C_DTYPE(at::Half, float));
  79. template<>
  80. void gemm_internal<at::BFloat16, float>(CUDABLAS_GEMM_ARGTYPES_AND_C_DTYPE(at::BFloat16, float));
  81. enum GEMMAndBiasActivationEpilogue {
  82. None,
  83. RELU,
  84. GELU,
  85. };
  86. // NOTE: GELU activation is not supported prior to CUDA 11.4 and will
  87. // do nothing if passed in that case.
  88. template <typename Dtype, typename C_Dtype = Dtype>
  89. bool gemm_and_bias(
  90. bool transpose_mat1,
  91. bool transpose_mat2,
  92. int64_t m,
  93. int64_t n,
  94. int64_t k,
  95. at::opmath_type<Dtype> alpha_val,
  96. const Dtype* mat1_ptr,
  97. int64_t mat1_ld,
  98. const Dtype* mat2_ptr,
  99. int64_t mat2_ld,
  100. const Dtype* bias,
  101. C_Dtype* result_ptr,
  102. int64_t result_ld,
  103. GEMMAndBiasActivationEpilogue activation = GEMMAndBiasActivationEpilogue::None);
  104. void int8_gemm(
  105. bool transpose_mat1,
  106. bool transpose_mat2,
  107. int64_t m,
  108. int64_t n,
  109. int64_t k,
  110. const int8_t* mat1_ptr,
  111. int64_t mat1_ld,
  112. const int8_t* mat2_ptr,
  113. int64_t mat2_ld,
  114. int32_t* result_ptr,
  115. int64_t result_ld);
  116. enum class ScalingType : std::uint8_t {
  117. TensorWise, // fp32 scales
  118. RowWise, // fp32 scales
  119. BlockWise1x16, // fp8_e4m3fn scales
  120. BlockWise1x32, // fp8_e8m0fnu scales
  121. BlockWise1x128, // fp32 scales
  122. BlockWise128x128, // fp32 scales
  123. };
  124. void scaled_gemm(
  125. char transa,
  126. char transb,
  127. int64_t m,
  128. int64_t n,
  129. int64_t k,
  130. const void* mat1_ptr,
  131. const void* mat1_scale_ptr,
  132. int64_t mat1_ld,
  133. ScalarType mat1_dtype,
  134. ScalarType mat1_scale_dtype,
  135. ScalingType mat1_scaling_type,
  136. const void* mat2_ptr,
  137. const void* mat2_scale_ptr,
  138. int64_t mat2_ld,
  139. ScalarType mat2_dtype,
  140. ScalarType mat2_scale_dtype,
  141. ScalingType mat2_scaling_type,
  142. const void* bias_ptr,
  143. ScalarType bias_dtype,
  144. void* result_ptr,
  145. const void* result_scale_ptr,
  146. int64_t result_ld,
  147. ScalarType result_dtype,
  148. bool use_fast_accum);
  149. #define CUDABLAS_BGEMM_ARGTYPES(Dtype) CUDABLAS_BGEMM_ARGTYPES_AND_C_DTYPE(Dtype, Dtype)
  150. #define CUDABLAS_BGEMM_ARGTYPES_AND_C_DTYPE(Dtype, C_Dtype) \
  151. char transa, char transb, int64_t m, int64_t n, int64_t k, at::opmath_type<Dtype> alpha, \
  152. const Dtype *a, int64_t lda, int64_t stridea, \
  153. const Dtype *b, int64_t ldb, int64_t strideb, \
  154. at::opmath_type<Dtype> beta, C_Dtype *c, int64_t ldc, int64_t stridec, int64_t num_batches
  155. #define CUDABLAS_BGEMM_ARGS(Dtype) \
  156. transa, transb, m, n, k, alpha, a, lda, stridea, b, ldb, strideb, beta, c, ldc, stridec, num_batches
  157. template <typename Dtype, typename C_Dtype = Dtype, typename std::enable_if<!CUDABLAS_GEMM_DTYPE_IS_FLOAT_TYPE_AND_C_DTYPE_IS_FLOAT, Dtype>::type* = nullptr>
  158. inline void bgemm(CUDABLAS_BGEMM_ARGTYPES_AND_C_DTYPE(Dtype, C_Dtype)) {
  159. static_assert(false&&sizeof(Dtype),"at::cuda::blas::bgemm: not implemented");
  160. }
  161. template <typename Dtype, typename C_Dtype, typename std::enable_if<CUDABLAS_GEMM_DTYPE_IS_FLOAT_TYPE_AND_C_DTYPE_IS_FLOAT, Dtype>::type* = nullptr>
  162. void bgemm(CUDABLAS_BGEMM_ARGTYPES_AND_C_DTYPE(Dtype, C_Dtype));
  163. template <>
  164. void bgemm<double>(CUDABLAS_BGEMM_ARGTYPES(double));
  165. template <>
  166. void bgemm<float>(CUDABLAS_BGEMM_ARGTYPES(float));
  167. template <>
  168. void bgemm<c10::complex<double>>(CUDABLAS_BGEMM_ARGTYPES(c10::complex<double>));
  169. template <>
  170. void bgemm<c10::complex<float>>(CUDABLAS_BGEMM_ARGTYPES(c10::complex<float>));
  171. template <>
  172. void bgemm<at::Half>(CUDABLAS_BGEMM_ARGTYPES(at::Half));
  173. template <>
  174. void bgemm<at::BFloat16>(CUDABLAS_BGEMM_ARGTYPES(at::BFloat16));
  175. template<>
  176. void bgemm<at::Half, float>(CUDABLAS_BGEMM_ARGTYPES_AND_C_DTYPE(at::Half, float));
  177. template<>
  178. void bgemm<at::BFloat16, float>(CUDABLAS_BGEMM_ARGTYPES_AND_C_DTYPE(at::BFloat16, float));
  179. template <typename Dtype, typename C_Dtype = Dtype>
  180. inline void bgemm_internal(CUDABLAS_BGEMM_ARGTYPES_AND_C_DTYPE(Dtype, C_Dtype)) {
  181. static_assert(false&&sizeof(Dtype),"at::cuda::blas::bgemm_internal: not implemented");
  182. }
  183. template <>
  184. void bgemm_internal<double>(CUDABLAS_BGEMM_ARGTYPES(double));
  185. template <>
  186. void bgemm_internal<float>(CUDABLAS_BGEMM_ARGTYPES(float));
  187. template <>
  188. void bgemm_internal<c10::complex<double>>(CUDABLAS_BGEMM_ARGTYPES(c10::complex<double>));
  189. template <>
  190. void bgemm_internal<c10::complex<float>>(CUDABLAS_BGEMM_ARGTYPES(c10::complex<float>));
  191. template <>
  192. void bgemm_internal<at::Half>(CUDABLAS_BGEMM_ARGTYPES(at::Half));
  193. template <>
  194. void bgemm_internal<at::BFloat16>(CUDABLAS_BGEMM_ARGTYPES(at::BFloat16));
  195. template<>
  196. void bgemm_internal<at::Half, float>(CUDABLAS_BGEMM_ARGTYPES_AND_C_DTYPE(at::Half, float));
  197. template<>
  198. void bgemm_internal<at::BFloat16, float>(CUDABLAS_BGEMM_ARGTYPES_AND_C_DTYPE(at::BFloat16, float));
  199. #define CUDABLAS_TRSM_ARGTYPES(Dtype) \
  200. cublasHandle_t handle, cublasSideMode_t side, cublasFillMode_t uplo, \
  201. cublasOperation_t trans, cublasDiagType_t diag, int m, int n, \
  202. const Dtype *alpha, const Dtype *A, int lda, Dtype *B, int ldb
  203. template <typename Dtype>
  204. inline void trsm(CUDABLAS_TRSM_ARGTYPES(Dtype)) {
  205. static_assert(false&&sizeof(Dtype), "at::cuda::blas::trsm: not implemented");
  206. }
  207. template <>
  208. TORCH_CUDA_CU_API void trsm<float>(CUDABLAS_TRSM_ARGTYPES(float));
  209. template <>
  210. TORCH_CUDA_CU_API void trsm<double>(CUDABLAS_TRSM_ARGTYPES(double));
  211. template <>
  212. TORCH_CUDA_CU_API void trsm<c10::complex<float>>(CUDABLAS_TRSM_ARGTYPES(c10::complex<float>));
  213. template <>
  214. TORCH_CUDA_CU_API void trsm<c10::complex<double>>(CUDABLAS_TRSM_ARGTYPES(c10::complex<double>));
  215. #define CUDABLAS_TRSM_BATCHED_ARGTYPES(Dtype) \
  216. cublasHandle_t handle, cublasSideMode_t side, cublasFillMode_t uplo, \
  217. cublasOperation_t trans, cublasDiagType_t diag, int m, int n, \
  218. const Dtype *alpha, Dtype *A[], int lda, Dtype *B[], int ldb, \
  219. int batchCount
  220. template <typename Dtype>
  221. inline void trsmBatched(CUDABLAS_TRSM_BATCHED_ARGTYPES(Dtype)) {
  222. static_assert(false&&sizeof(Dtype), "at::cuda::blas::trsmBatched: not implemented");
  223. }
  224. template <>
  225. TORCH_CUDA_CU_API void trsmBatched<float>(CUDABLAS_TRSM_BATCHED_ARGTYPES(float));
  226. template <>
  227. TORCH_CUDA_CU_API void trsmBatched<double>(CUDABLAS_TRSM_BATCHED_ARGTYPES(double));
  228. template <>
  229. TORCH_CUDA_CU_API void trsmBatched<c10::complex<float>>(CUDABLAS_TRSM_BATCHED_ARGTYPES(c10::complex<float>));
  230. template <>
  231. TORCH_CUDA_CU_API void trsmBatched<c10::complex<double>>(CUDABLAS_TRSM_BATCHED_ARGTYPES(c10::complex<double>));
  232. /* LEVEL 2 BLAS FUNCTIONS */
  233. #define CUDABLAS_GEMV_ARGTYPES(Dtype) \
  234. char trans, int64_t m, int64_t n, Dtype alpha, const Dtype *a, int64_t lda, \
  235. const Dtype *x, int64_t incx, Dtype beta, Dtype *y, int64_t incy
  236. template <typename Dtype>
  237. inline void gemv(CUDABLAS_GEMV_ARGTYPES(Dtype)) {
  238. static_assert(false&&sizeof(Dtype), "at::cuda::blas::gemv: not implemented");
  239. }
  240. template <>
  241. void gemv<double>(CUDABLAS_GEMV_ARGTYPES(double));
  242. template <>
  243. void gemv<float>(CUDABLAS_GEMV_ARGTYPES(float));
  244. template <>
  245. void gemv<c10::complex<double>>(CUDABLAS_GEMV_ARGTYPES(c10::complex<double>));
  246. template <>
  247. void gemv<c10::complex<float>>(CUDABLAS_GEMV_ARGTYPES(c10::complex<float>));
  248. template <>
  249. void gemv<at::Half>(CUDABLAS_GEMV_ARGTYPES(at::Half));
  250. template <>
  251. void gemv<at::BFloat16>(CUDABLAS_GEMV_ARGTYPES(at::BFloat16));
  252. /* LEVEL 1 BLAS FUNCTIONS */
  253. #define CUDABLAS_DOT_ARGTYPES(Dtype) \
  254. cublasHandle_t handle, int n, const Dtype *x, int incx, const Dtype *y, \
  255. int incy, Dtype *result
  256. template <typename Dtype>
  257. inline void dot(CUDABLAS_DOT_ARGTYPES(Dtype)) {
  258. static_assert(false&&sizeof(Dtype),"at::cuda::blas::dot: not implemented");
  259. }
  260. template <>
  261. void dot<double>(CUDABLAS_DOT_ARGTYPES(double));
  262. template <>
  263. void dot<float>(CUDABLAS_DOT_ARGTYPES(float));
  264. template <>
  265. void dot<at::Half>(CUDABLAS_DOT_ARGTYPES(at::Half));
  266. template <>
  267. void dot<at::BFloat16>(CUDABLAS_DOT_ARGTYPES(at::BFloat16));
  268. template <>
  269. void dot<c10::complex<double>>(CUDABLAS_DOT_ARGTYPES(c10::complex<double>));
  270. template <>
  271. void dot<c10::complex<float>>(CUDABLAS_DOT_ARGTYPES(c10::complex<float>));
  272. template <typename Dtype>
  273. inline void vdot(CUDABLAS_DOT_ARGTYPES(Dtype)) {
  274. static_assert(false&&sizeof(Dtype),"at::cuda::blas::vdot: not implemented");
  275. }
  276. template <>
  277. void vdot<c10::complex<float>>(CUDABLAS_DOT_ARGTYPES(c10::complex<float>));
  278. template <>
  279. void vdot<c10::complex<double>>(CUDABLAS_DOT_ARGTYPES(c10::complex<double>));
  280. #define CUDABLAS_GETRS_ARGTYPES(Dtype) \
  281. cublasHandle_t handle, cublasOperation_t trans, \
  282. int n, int nrhs, Dtype** dA_array, int lda, int* ipiv_array, \
  283. Dtype** dB_array, int ldb, int* info_array, int batchsize
  284. #define CUDABLAS_GEQRF_BATCHED_ARGTYPES(Dtype) \
  285. cublasHandle_t handle, int m, int n, Dtype **A_array, int lda, \
  286. Dtype **tau_array, int *info, int batchsize
  287. #define CUDABLAS_GETRF_ARGTYPES(Dtype) \
  288. int n, Dtype** dA_array, int ldda, int* ipiv_array, int* info_array, int batchsize
  289. #define CUDABLAS_GELS_BATCHED_ARGTYPES(Dtype) \
  290. cublasHandle_t handle, cublasOperation_t trans, \
  291. int m, int n, int nrhs, Dtype** dA_array, int ldda, \
  292. Dtype** dC_array, int lddc, int* info, int *devInfoArray, int batchSize
  293. template<class Dtype>
  294. void getrsBatched(CUDABLAS_GETRS_ARGTYPES(Dtype)) {
  295. static_assert(false&&sizeof(Dtype),"at::cuda::blas::getrsBatched: not implemented");
  296. }
  297. template<>
  298. TORCH_CUDA_CU_API void getrsBatched<float>(CUDABLAS_GETRS_ARGTYPES(float));
  299. template<>
  300. TORCH_CUDA_CU_API void getrsBatched<double>(CUDABLAS_GETRS_ARGTYPES(double));
  301. template<>
  302. TORCH_CUDA_CU_API void getrsBatched<c10::complex<float>>(CUDABLAS_GETRS_ARGTYPES(c10::complex<float>));
  303. template<>
  304. TORCH_CUDA_CU_API void getrsBatched<c10::complex<double>>(CUDABLAS_GETRS_ARGTYPES(c10::complex<double>));
  305. template <class Dtype>
  306. void geqrfBatched(CUDABLAS_GEQRF_BATCHED_ARGTYPES(Dtype)) {
  307. static_assert(false&&sizeof(Dtype), "at::cuda::blas::geqrfBatched: not implemented");
  308. }
  309. template <>
  310. TORCH_CUDA_CU_API void geqrfBatched<float>(CUDABLAS_GEQRF_BATCHED_ARGTYPES(float));
  311. template <>
  312. TORCH_CUDA_CU_API void geqrfBatched<double>(CUDABLAS_GEQRF_BATCHED_ARGTYPES(double));
  313. template <>
  314. TORCH_CUDA_CU_API void geqrfBatched<c10::complex<double>>(
  315. CUDABLAS_GEQRF_BATCHED_ARGTYPES(c10::complex<double>));
  316. template <>
  317. TORCH_CUDA_CU_API void geqrfBatched<c10::complex<float>>(
  318. CUDABLAS_GEQRF_BATCHED_ARGTYPES(c10::complex<float>));
  319. template<class Dtype>
  320. void getrfBatched(CUDABLAS_GETRF_ARGTYPES(Dtype)) {
  321. static_assert(false&&sizeof(Dtype), "at::cuda::blas::getrfBatched: not implemented");
  322. }
  323. template<>
  324. TORCH_CUDA_CU_API void getrfBatched<float>(CUDABLAS_GETRF_ARGTYPES(float));
  325. template<>
  326. TORCH_CUDA_CU_API void getrfBatched<double>(CUDABLAS_GETRF_ARGTYPES(double));
  327. template<>
  328. TORCH_CUDA_CU_API void getrfBatched<c10::complex<double>>(CUDABLAS_GETRF_ARGTYPES(c10::complex<double>));
  329. template<>
  330. TORCH_CUDA_CU_API void getrfBatched<c10::complex<float>>(CUDABLAS_GETRF_ARGTYPES(c10::complex<float>));
  331. template <class Dtype>
  332. void gelsBatched(CUDABLAS_GELS_BATCHED_ARGTYPES(Dtype)) {
  333. static_assert(false&&sizeof(Dtype), "at::cuda::blas::gelsBatched: not implemented");
  334. }
  335. template<>
  336. TORCH_CUDA_CU_API void gelsBatched<double>(CUDABLAS_GELS_BATCHED_ARGTYPES(double));
  337. template<>
  338. TORCH_CUDA_CU_API void gelsBatched<float>(CUDABLAS_GELS_BATCHED_ARGTYPES(float));
  339. template<>
  340. TORCH_CUDA_CU_API void gelsBatched<c10::complex<double>>(CUDABLAS_GELS_BATCHED_ARGTYPES(c10::complex<double>));
  341. template<>
  342. TORCH_CUDA_CU_API void gelsBatched<c10::complex<float>>(CUDABLAS_GELS_BATCHED_ARGTYPES(c10::complex<float>));
  343. } // namespace at::cuda::blas