CUDASparse.h 2.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475
  1. #pragma once
  2. #include <ATen/cuda/CUDAContext.h>
  3. #if defined(USE_ROCM)
  4. #include <hipsparse/hipsparse-version.h>
  5. #define HIPSPARSE_VERSION ((hipsparseVersionMajor*100000) + (hipsparseVersionMinor*100) + hipsparseVersionPatch)
  6. #endif
  7. // cuSparse Generic API added in CUDA 10.1
  8. // Windows support added in CUDA 11.0
  9. #if defined(CUDART_VERSION) && defined(CUSPARSE_VERSION) && ((CUSPARSE_VERSION >= 10300) || (CUSPARSE_VERSION >= 11000 && defined(_WIN32)))
  10. #define AT_USE_CUSPARSE_GENERIC_API() 1
  11. #else
  12. #define AT_USE_CUSPARSE_GENERIC_API() 0
  13. #endif
  14. // cuSparse Generic API descriptor pointers were changed to const in CUDA 12.0
  15. #if defined(CUDART_VERSION) && defined(CUSPARSE_VERSION) && \
  16. (CUSPARSE_VERSION < 12000)
  17. #define AT_USE_CUSPARSE_NON_CONST_DESCRIPTORS() 1
  18. #else
  19. #define AT_USE_CUSPARSE_NON_CONST_DESCRIPTORS() 0
  20. #endif
  21. #if defined(CUDART_VERSION) && defined(CUSPARSE_VERSION) && \
  22. (CUSPARSE_VERSION >= 12000)
  23. #define AT_USE_CUSPARSE_CONST_DESCRIPTORS() 1
  24. #else
  25. #define AT_USE_CUSPARSE_CONST_DESCRIPTORS() 0
  26. #endif
  27. #if defined(USE_ROCM)
  28. // hipSparse const API added in v2.4.0
  29. #if HIPSPARSE_VERSION >= 200400
  30. #define AT_USE_HIPSPARSE_CONST_DESCRIPTORS() 1
  31. #define AT_USE_HIPSPARSE_NON_CONST_DESCRIPTORS() 0
  32. #define AT_USE_HIPSPARSE_GENERIC_API() 1
  33. #else
  34. #define AT_USE_HIPSPARSE_CONST_DESCRIPTORS() 0
  35. #define AT_USE_HIPSPARSE_NON_CONST_DESCRIPTORS() 1
  36. #define AT_USE_HIPSPARSE_GENERIC_API() 1
  37. #endif
  38. #else // USE_ROCM
  39. #define AT_USE_HIPSPARSE_CONST_DESCRIPTORS() 0
  40. #define AT_USE_HIPSPARSE_NON_CONST_DESCRIPTORS() 0
  41. #define AT_USE_HIPSPARSE_GENERIC_API() 0
  42. #endif // USE_ROCM
  43. // cuSparse Generic API spsv function was added in CUDA 11.3.0
  44. #if defined(CUDART_VERSION) && defined(CUSPARSE_VERSION) && (CUSPARSE_VERSION >= 11500)
  45. #define AT_USE_CUSPARSE_GENERIC_SPSV() 1
  46. #else
  47. #define AT_USE_CUSPARSE_GENERIC_SPSV() 0
  48. #endif
  49. // cuSparse Generic API spsm function was added in CUDA 11.3.1
  50. #if defined(CUDART_VERSION) && defined(CUSPARSE_VERSION) && (CUSPARSE_VERSION >= 11600)
  51. #define AT_USE_CUSPARSE_GENERIC_SPSM() 1
  52. #else
  53. #define AT_USE_CUSPARSE_GENERIC_SPSM() 0
  54. #endif
  55. // cuSparse Generic API sddmm function was added in CUDA 11.2.1 (cuSparse version 11400)
  56. #if defined(CUDART_VERSION) && defined(CUSPARSE_VERSION) && (CUSPARSE_VERSION >= 11400)
  57. #define AT_USE_CUSPARSE_GENERIC_SDDMM() 1
  58. #else
  59. #define AT_USE_CUSPARSE_GENERIC_SDDMM() 0
  60. #endif
  61. // BSR triangular solve functions were added in hipSPARSE 1.11.2 (ROCm 4.5.0)
  62. #if defined(CUDART_VERSION) || defined(USE_ROCM)
  63. #define AT_USE_HIPSPARSE_TRIANGULAR_SOLVE() 1
  64. #else
  65. #define AT_USE_HIPSPARSE_TRIANGULAR_SOLVE() 0
  66. #endif