CPUBlas.h 8.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314
  1. #pragma once
  2. #include <ATen/OpMathType.h>
  3. #include <ATen/native/DispatchStub.h>
  4. #include <ATen/native/TransposeType.h>
  5. #include <c10/util/complex.h>
  6. #include <c10/core/ScalarType.h>
  7. #include <c10/core/Scalar.h>
  8. namespace at::native::cpublas {
  9. namespace internal {
  10. void normalize_last_dims(
  11. TransposeType transa, TransposeType transb,
  12. int64_t m, int64_t n, int64_t k,
  13. int64_t *lda, int64_t *ldb, int64_t *ldc);
  14. } // namespace internal
  15. using gemm_fn = void(*)(
  16. at::ScalarType type,
  17. TransposeType transa, TransposeType transb,
  18. int64_t m, int64_t n, int64_t k,
  19. const Scalar& alpha,
  20. const void *a, int64_t lda,
  21. const void *b, int64_t ldb,
  22. const Scalar& beta,
  23. void *c, int64_t ldc);
  24. DECLARE_DISPATCH(gemm_fn, gemm_stub)
  25. using gemm_no_downcast_fn = void(*)(
  26. at::ScalarType type,
  27. TransposeType transa, TransposeType transb,
  28. int64_t m, int64_t n, int64_t k,
  29. const Scalar& alpha,
  30. const void *a, int64_t lda,
  31. const void *b, int64_t ldb,
  32. const Scalar& beta,
  33. void *c, int64_t ldc);
  34. DECLARE_DISPATCH(gemm_no_downcast_fn, gemm_no_downcast_stub)
  35. template <typename scalar_t>
  36. void gemm(
  37. TransposeType transa, TransposeType transb,
  38. int64_t m, int64_t n, int64_t k,
  39. at::opmath_type<scalar_t> alpha,
  40. const scalar_t *a, int64_t lda,
  41. const scalar_t *b, int64_t ldb,
  42. at::opmath_type<scalar_t> beta,
  43. scalar_t *c, int64_t ldc) {
  44. internal::normalize_last_dims(transa, transb, m, n, k, &lda, &ldb, &ldc);
  45. gemm_stub(
  46. kCPU, c10::CppTypeToScalarType<scalar_t>::value,
  47. transa, transb, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc);
  48. }
  49. void gemm(
  50. TransposeType transa, TransposeType transb,
  51. int64_t m, int64_t n, int64_t k,
  52. double alpha,
  53. const double *a, int64_t lda,
  54. const double *b, int64_t ldb,
  55. double beta,
  56. double *c, int64_t ldc);
  57. void gemm(
  58. TransposeType transa, TransposeType transb,
  59. int64_t m, int64_t n, int64_t k,
  60. float alpha,
  61. const float *a, int64_t lda,
  62. const float *b, int64_t ldb,
  63. float beta,
  64. float *c, int64_t ldc);
  65. void gemm(
  66. TransposeType transa, TransposeType transb,
  67. int64_t m, int64_t n, int64_t k,
  68. float alpha,
  69. const at::BFloat16 *a, int64_t lda,
  70. const at::BFloat16 *b, int64_t ldb,
  71. float beta,
  72. at::BFloat16 *c, int64_t ldc);
  73. void gemm(
  74. TransposeType transa, TransposeType transb,
  75. int64_t m, int64_t n, int64_t k,
  76. const float alpha,
  77. const at::BFloat16 *a, int64_t lda,
  78. const at::BFloat16 *b, int64_t ldb,
  79. const float beta,
  80. float *c, int64_t ldc);
  81. void gemm(
  82. TransposeType transa, TransposeType transb,
  83. int64_t m, int64_t n, int64_t k,
  84. float alpha,
  85. const at::Half *a, int64_t lda,
  86. const at::Half *b, int64_t ldb,
  87. float beta,
  88. at::Half *c, int64_t ldc);
  89. void gemm(
  90. TransposeType transa, TransposeType transb,
  91. int64_t m, int64_t n, int64_t k,
  92. const float alpha,
  93. const at::Half *a, int64_t lda,
  94. const at::Half *b, int64_t ldb,
  95. const float beta,
  96. float *c, int64_t ldc);
  97. void gemm(
  98. TransposeType transa, TransposeType transb,
  99. int64_t m, int64_t n, int64_t k,
  100. c10::complex<double> alpha,
  101. const c10::complex<double> *a, int64_t lda,
  102. const c10::complex<double> *b, int64_t ldb,
  103. c10::complex<double> beta,
  104. c10::complex<double> *c, int64_t ldc);
  105. void gemm(
  106. TransposeType transa, TransposeType transb,
  107. int64_t m, int64_t n, int64_t k,
  108. c10::complex<float> alpha,
  109. const c10::complex<float> *a, int64_t lda,
  110. const c10::complex<float> *b, int64_t ldb,
  111. c10::complex<float> beta,
  112. c10::complex<float> *c, int64_t ldc);
  113. void gemm(
  114. TransposeType transa, TransposeType transb,
  115. int64_t m, int64_t n, int64_t k,
  116. int64_t alpha,
  117. const int64_t *a, int64_t lda,
  118. const int64_t *b, int64_t ldb,
  119. int64_t beta,
  120. int64_t *c, int64_t ldc);
  121. template <typename scalar_t>
  122. void gemm_batched(
  123. TransposeType transa, TransposeType transb,
  124. int64_t batch_size, int64_t m, int64_t n, int64_t k,
  125. scalar_t alpha,
  126. const scalar_t * const *a, int64_t lda,
  127. const scalar_t * const *b, int64_t ldb,
  128. const scalar_t beta,
  129. scalar_t * const *c, int64_t ldc);
  130. template <typename scalar_t>
  131. void gemm_batched_with_stride(
  132. TransposeType transa, TransposeType transb,
  133. int64_t batch_size, int64_t m, int64_t n, int64_t k,
  134. scalar_t alpha,
  135. const scalar_t *a, int64_t lda, int64_t batch_stride_a,
  136. const scalar_t *b, int64_t ldb, int64_t batch_stride_b,
  137. scalar_t beta,
  138. scalar_t *c, int64_t ldc, int64_t batch_stride_c);
  139. using axpy_fn = void(*)(at::ScalarType type, int64_t n, const Scalar& a, const void *x, int64_t incx, void *y, int64_t incy);
  140. DECLARE_DISPATCH(axpy_fn, axpy_stub)
  141. template<typename scalar_t>
  142. void axpy(int64_t n, scalar_t a, const scalar_t *x, int64_t incx, scalar_t *y, int64_t incy){
  143. if(n == 1)
  144. {
  145. incx = 1;
  146. incy = 1;
  147. }
  148. axpy_stub(
  149. kCPU, c10::CppTypeToScalarType<scalar_t>::value,
  150. n, a, x, incx, y, incy);
  151. }
  152. void axpy(int64_t n, double a, const double *x, int64_t incx, double *y, int64_t incy);
  153. void axpy(int64_t n, float a, const float *x, int64_t incx, float *y, int64_t incy);
  154. void axpy(int64_t n, c10::complex<double> a, const c10::complex<double> *x, int64_t incx, c10::complex<double> *y, int64_t incy);
  155. void axpy(int64_t n, c10::complex<float> a, const c10::complex<float> *x, int64_t incx, c10::complex<float> *y, int64_t incy);
  156. using copy_fn = void(*)(at::ScalarType type, int64_t n, const void *x, int64_t incx, void *y, int64_t incy);
  157. DECLARE_DISPATCH(copy_fn, copy_stub)
  158. template<typename scalar_t>
  159. void copy(int64_t n, const scalar_t *x, int64_t incx, scalar_t *y, int64_t incy) {
  160. if(n == 1)
  161. {
  162. incx = 1;
  163. incy = 1;
  164. }
  165. copy_stub(
  166. kCPU, c10::CppTypeToScalarType<scalar_t>::value,
  167. n, x, incx, y, incy);
  168. }
  169. void copy(int64_t n, const double *x, int64_t incx, double *y, int64_t incy);
  170. void copy(int64_t n, const float *x, int64_t incx, float *y, int64_t incy);
  171. void copy(int64_t n, const c10::complex<double> *x, int64_t incx, c10::complex<double> *y, int64_t incy);
  172. void copy(int64_t n, const c10::complex<float> *x, int64_t incx, c10::complex<float> *y, int64_t incy);
  173. // Batch-reduce GEMM
  174. // Operates by the following formula:
  175. // C = SUM(A[i] x B[i]) + C if add_C is true, i = 0 to batch size
  176. // A Base pointer to a tensor A.
  177. // B Base pointer to a tensor B.
  178. // C Pointer to a tensor C (accumulation buffer).
  179. // Note only batch size 1 is used currently
  // Feature-test macros, one per brgemm overload declared in this header.
  // Callers can check them with #ifdef to find out which brgemm APIs exist.
  // Each macro is defined with an empty replacement list.
  #define CPUBLAS_BRGEMM_F16F16F32   // at::Half       x at::Half       -> float
  #define CPUBLAS_BRGEMM_BF16BF16F32 // at::BFloat16   x at::BFloat16   -> float
  #define CPUBLAS_BRGEMM_F32F32F32   // float          x float          -> float
  #define CPUBLAS_BRGEMM_U8U8I32     // unsigned char  x unsigned char  -> int32
  #define CPUBLAS_BRGEMM_U8I8I32     // unsigned char  x signed char    -> int32
  #define CPUBLAS_BRGEMM_I8I8I32     // signed char    x signed char    -> int32
  188. TORCH_API void brgemm(
  189. int64_t M,
  190. int64_t N,
  191. int64_t K,
  192. int64_t ld_a,
  193. int64_t ld_b,
  194. int64_t ld_c,
  195. const bool add_C,
  196. const at::Half* A,
  197. const at::Half* B,
  198. float* C,
  199. bool is_vnni = true);
  200. TORCH_API void brgemm(
  201. int64_t M,
  202. int64_t N,
  203. int64_t K,
  204. int64_t ld_a,
  205. int64_t ld_b,
  206. int64_t ld_c,
  207. const bool add_C,
  208. const at::BFloat16* A,
  209. const at::BFloat16* B,
  210. float* C,
  211. bool is_vnni = true);
  212. TORCH_API void brgemm(
  213. int64_t M,
  214. int64_t N,
  215. int64_t K,
  216. int64_t ld_a,
  217. int64_t ld_b,
  218. int64_t ld_c,
  219. const bool add_C,
  220. const float* A,
  221. const float* B,
  222. float* C,
  223. bool is_vnni = false);
  224. TORCH_API void brgemm(
  225. int64_t M,
  226. int64_t N,
  227. int64_t K,
  228. int64_t ld_a,
  229. int64_t ld_b,
  230. int64_t ld_c,
  231. const bool add_C,
  232. const unsigned char* A,
  233. const unsigned char* B,
  234. int32_t* C,
  235. bool is_vnni = true);
  236. TORCH_API void brgemm(
  237. int64_t M,
  238. int64_t N,
  239. int64_t K,
  240. int64_t ld_a,
  241. int64_t ld_b,
  242. int64_t ld_c,
  243. const bool add_C,
  244. const unsigned char* A,
  245. const signed char* B,
  246. int32_t* C,
  247. bool is_vnni = true);
  248. TORCH_API void brgemm(
  249. int64_t M,
  250. int64_t N,
  251. int64_t K,
  252. int64_t ld_a,
  253. int64_t ld_b,
  254. int64_t ld_c,
  255. const bool add_C,
  256. const signed char* A,
  257. const signed char* B,
  258. int32_t* C,
  259. bool is_vnni = true);
  260. // Release brgemm hardware context
  261. TORCH_API void brgemm_release(bool is_vnni = true);
  262. // Pack B matrix to get better performance if needed
  263. TORCH_API void pack(
  264. int64_t K,
  265. int64_t N,
  266. int64_t ld_in,
  267. int64_t ld_out,
  268. ScalarType dt_in,
  269. ScalarType dt_out,
  270. const void* in,
  271. void* out);
  272. // Whether pack is supported in the platform.
  273. TORCH_API bool could_pack(ScalarType dt_in);
  274. } // namespace at::native::cpublas