Exceptions.h 12 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230
  1. #pragma once
  2. #include <cublas_v2.h>
  3. #include <cusparse.h>
  4. #include <c10/macros/Export.h>
  5. #if !defined(USE_ROCM)
  6. #include <cusolver_common.h>
  7. #else
  8. #include <hipsolver/hipsolver.h>
  9. #endif
  10. #if defined(USE_CUDSS)
  11. #include <cudss.h>
  12. #endif
  13. #include <ATen/Context.h>
  14. #include <c10/util/Exception.h>
  15. #include <c10/cuda/CUDAException.h>
  16. namespace c10 {
  17. class CuDNNError : public c10::Error {
  18. using Error::Error;
  19. };
  20. } // namespace c10
  // Checks the result object produced by a cuDNN Frontend call.
  //
  // EXPR is evaluated exactly once and must yield an object that provides
  // `is_good()` and `get_message()`. If `is_good()` returns false, a
  // CuDNNError is raised through TORCH_CHECK_WITH with the message
  // "cuDNN Frontend error: " followed by `get_message()`.
  // Additional variadic arguments are accepted but are not used.
  //
  // NOTE: the closing `while (0)` line must NOT end with a backslash. A
  // trailing line continuation there splices whatever follows (here, the
  // next #define) into this macro's replacement list.
  #define AT_CUDNN_FRONTEND_CHECK(EXPR, ...)                         \
    do {                                                             \
      auto error_object = EXPR;                                      \
      if (!error_object.is_good()) {                                 \
        TORCH_CHECK_WITH(CuDNNError, false,                          \
            "cuDNN Frontend error: ", error_object.get_message());   \
      }                                                              \
    } while (0)

  // Forwards to AT_CUDNN_CHECK, inserting a newline ("\n") ahead of any
  // caller-supplied message arguments so they start on their own line.
  #define AT_CUDNN_CHECK_WITH_SHAPES(EXPR, ...) \
    AT_CUDNN_CHECK(EXPR, "\n", ##__VA_ARGS__)
  // See Note [CHECK macro]
  //
  // Evaluates EXPR once as a cudnnStatus_t. Any status other than
  // CUDNN_STATUS_SUCCESS raises a CuDNNError through TORCH_CHECK_WITH. The
  // message is "cuDNN error: " followed by cudnnGetErrorString(status) and
  // then the optional variadic arguments. For CUDNN_STATUS_NOT_SUPPORTED a
  // hint about non-contiguous inputs is inserted before the variadic part.
  #define AT_CUDNN_CHECK(EXPR, ...)                                              \
    do {                                                                         \
      cudnnStatus_t status = EXPR;                                               \
      if (status == CUDNN_STATUS_NOT_SUPPORTED) {                                \
        TORCH_CHECK_WITH(CuDNNError, false,                                      \
            "cuDNN error: ",                                                     \
            cudnnGetErrorString(status),                                         \
            ". This error may appear if you passed in a non-contiguous input.", ##__VA_ARGS__); \
      } else if (status != CUDNN_STATUS_SUCCESS) {                               \
        TORCH_CHECK_WITH(CuDNNError, false,                                      \
            "cuDNN error: ", cudnnGetErrorString(status), ##__VA_ARGS__);        \
      }                                                                          \
    } while (0)
  46. namespace at::cuda::blas {
  47. C10_EXPORT const char* _cublasGetErrorEnum(cublasStatus_t error);
  48. } // namespace at::cuda::blas
  49. #define TORCH_CUDABLAS_CHECK(EXPR) \
  50. do { \
  51. cublasStatus_t __err = EXPR; \
  52. TORCH_CHECK(__err == CUBLAS_STATUS_SUCCESS, \
  53. "CUDA error: ", \
  54. at::cuda::blas::_cublasGetErrorEnum(__err), \
  55. " when calling `" #EXPR "`"); \
  56. } while (0)
  57. const char *cusparseGetErrorString(cusparseStatus_t status);
  58. #define TORCH_CUDASPARSE_CHECK(EXPR) \
  59. do { \
  60. cusparseStatus_t __err = EXPR; \
  61. TORCH_CHECK(__err == CUSPARSE_STATUS_SUCCESS, \
  62. "CUDA error: ", \
  63. cusparseGetErrorString(__err), \
  64. " when calling `" #EXPR "`"); \
  65. } while (0)
  #if defined(USE_CUDSS)

  namespace at::cuda::cudss {

  // Status-to-string helper; its result is embedded in the failure messages
  // built by TORCH_CUDSS_CHECK. Only declared here.
  C10_EXPORT const char* cudssGetErrorMessage(cudssStatus_t error);

  } // namespace at::cuda::cudss

  // Evaluates EXPR once as a cudssStatus_t.
  //  - CUDSS_STATUS_EXECUTION_FAILED is reported through TORCH_CHECK_LINALG and
  //    carries an extra hint that the input matrix may contain NaN.
  //  - Every other status must equal CUDSS_STATUS_SUCCESS, otherwise
  //    TORCH_CHECK fails with the status string and the stringified call.
  #define TORCH_CUDSS_CHECK(EXPR)                                            \
    do {                                                                     \
      cudssStatus_t __err = EXPR;                                            \
      if (__err == CUDSS_STATUS_EXECUTION_FAILED) {                          \
        TORCH_CHECK_LINALG(                                                  \
            false,                                                           \
            "cudss error: ",                                                 \
            at::cuda::cudss::cudssGetErrorMessage(__err),                    \
            ", when calling `" #EXPR "`",                                    \
            ". This error may appear if the input matrix contains NaN. ");   \
      } else {                                                               \
        TORCH_CHECK(                                                         \
            __err == CUDSS_STATUS_SUCCESS,                                   \
            "cudss error: ",                                                 \
            at::cuda::cudss::cudssGetErrorMessage(__err),                    \
            ", when calling `" #EXPR "`. ");                                 \
      }                                                                      \
    } while (0)

  #else

  // Built without USE_CUDSS: the macro expands to the bare expression, with no
  // status checking.
  #define TORCH_CUDSS_CHECK(EXPR) EXPR

  #endif
  91. namespace at::cuda::solver {
  92. #if !defined(USE_ROCM)
  93. C10_EXPORT const char* cusolverGetErrorMessage(cusolverStatus_t status);
  94. constexpr const char* _cusolver_backend_suggestion = \
  95. "If you keep seeing this error, you may use " \
  96. "`torch.backends.cuda.preferred_linalg_library()` to try " \
  97. "linear algebra operators with other supported backends. " \
  98. "See https://pytorch.org/docs/stable/backends.html#torch.backends.cuda.preferred_linalg_library";
  99. // When cuda >= 11.5, cusolver normally finishes execution and sets info array indicating convergence issue.
  100. #define TORCH_CUSOLVER_CHECK(EXPR) \
  101. do { \
  102. cusolverStatus_t __err = EXPR; \
  103. if (__err == CUSOLVER_STATUS_INVALID_VALUE) { \
  104. TORCH_CHECK_LINALG( \
  105. false, \
  106. "cusolver error: ", \
  107. at::cuda::solver::cusolverGetErrorMessage(__err), \
  108. ", when calling `" #EXPR "`", \
  109. ". This error may appear if the input matrix contains NaN. ", \
  110. at::cuda::solver::_cusolver_backend_suggestion); \
  111. } else { \
  112. TORCH_CHECK( \
  113. __err == CUSOLVER_STATUS_SUCCESS, \
  114. "cusolver error: ", \
  115. at::cuda::solver::cusolverGetErrorMessage(__err), \
  116. ", when calling `" #EXPR "`. ", \
  117. at::cuda::solver::_cusolver_backend_suggestion); \
  118. } \
  119. } while (0)
  120. #else // defined(USE_ROCM)
  121. C10_EXPORT const char* hipsolverGetErrorMessage(hipsolverStatus_t status);
  122. constexpr const char* _hipsolver_backend_suggestion = \
  123. "If you keep seeing this error, you may use " \
  124. "`torch.backends.cuda.preferred_linalg_library()` to try " \
  125. "linear algebra operators with other supported backends. " \
  126. "See https://pytorch.org/docs/stable/backends.html#torch.backends.cuda.preferred_linalg_library";
  127. #define TORCH_CUSOLVER_CHECK(EXPR) \
  128. do { \
  129. hipsolverStatus_t __err = EXPR; \
  130. if (__err == HIPSOLVER_STATUS_INVALID_VALUE) { \
  131. TORCH_CHECK_LINALG( \
  132. false, \
  133. "hipsolver error: ", \
  134. at::cuda::solver::hipsolverGetErrorMessage(__err), \
  135. ", when calling `" #EXPR "`", \
  136. ". This error may appear if the input matrix contains NaN. ", \
  137. at::cuda::solver::_hipsolver_backend_suggestion); \
  138. } else { \
  139. TORCH_CHECK( \
  140. __err == HIPSOLVER_STATUS_SUCCESS, \
  141. "hipsolver error: ", \
  142. at::cuda::solver::hipsolverGetErrorMessage(__err), \
  143. ", when calling `" #EXPR "`. ", \
  144. at::cuda::solver::_hipsolver_backend_suggestion); \
  145. } \
  146. } while (0)
  147. #endif
  148. } // namespace at::cuda::solver
  // ATen-prefixed alias: expands directly to C10_CUDA_CHECK(EXPR).
  #define AT_CUDA_CHECK(EXPR) \
    C10_CUDA_CHECK(EXPR)
  // For CUDA Driver API
  //
  // This is here instead of in c10 because NVRTC is loaded dynamically via a stub
  // in ATen, and the error-string lookup used below (cuGetErrorString) is
  // reached through that stub.
  // See NOTE [ USE OF NVRTC AND DRIVER API ].
  //
  // Evaluates EXPR once as a CUresult and fails through TORCH_CHECK when it is
  // not CUDA_SUCCESS.
  #if !defined(USE_ROCM)

  // Non-ROCm variant: asks the driver for a readable error string. If that
  // lookup itself fails, the message falls back to "unknown error".
  #define AT_CUDA_DRIVER_CHECK(EXPR)                                             \
    do {                                                                         \
      CUresult __err = EXPR;                                                     \
      if (__err != CUDA_SUCCESS) {                                               \
        const char* err_str;                                                     \
        [[maybe_unused]] CUresult get_error_str_err =                            \
            at::globalContext().getNVRTC().cuGetErrorString(__err, &err_str);    \
        if (get_error_str_err == CUDA_SUCCESS) {                                 \
          TORCH_CHECK(false, "CUDA driver error: ", err_str);                    \
        } else {                                                                 \
          TORCH_CHECK(false, "CUDA driver error: unknown error");                \
        }                                                                        \
      }                                                                          \
    } while (0)

  #else

  // ROCm variant: no string lookup; the numeric error code is reported instead.
  #define AT_CUDA_DRIVER_CHECK(EXPR)                                             \
    do {                                                                         \
      CUresult __err = EXPR;                                                     \
      if (__err != CUDA_SUCCESS) {                                               \
        TORCH_CHECK(false, "CUDA driver error: ", static_cast<int>(__err));      \
      }                                                                          \
    } while (0)

  #endif
  // For CUDA NVRTC
  //
  // Note: As of CUDA 10, nvrtc error code 7, NVRTC_ERROR_BUILTIN_OPERATION_FAILURE,
  // incorrectly produces the error string "NVRTC unknown error."
  // The macro below special-cases code 7 and reports the correct name itself.
  //
  // This is here instead of in c10 because NVRTC is loaded dynamically via a stub
  // in ATen, and we need to use its nvrtcGetErrorString.
  // See NOTE [ USE OF NVRTC AND DRIVER API ].
  //
  // Evaluates EXPR once as an nvrtcResult and fails through TORCH_CHECK when it
  // is not NVRTC_SUCCESS.
  #define AT_CUDA_NVRTC_CHECK(EXPR)                                                                          \
    do {                                                                                                     \
      nvrtcResult __err = EXPR;                                                                              \
      if (__err != NVRTC_SUCCESS) {                                                                          \
        if (static_cast<int>(__err) == 7) {                                                                  \
          TORCH_CHECK(false, "CUDA NVRTC error: NVRTC_ERROR_BUILTIN_OPERATION_FAILURE");                     \
        } else {                                                                                             \
          TORCH_CHECK(false, "CUDA NVRTC error: ", at::globalContext().getNVRTC().nvrtcGetErrorString(__err)); \
        }                                                                                                    \
      }                                                                                                      \
    } while (0)