CUDAFunctions.h 4.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126
  1. #pragma once
  2. // This header provides C++ wrappers around commonly used CUDA API functions.
  3. // The benefit of using C++ here is that we can raise an exception in the
  4. // event of an error, rather than explicitly pass around error codes. This
  5. // leads to more natural APIs.
  6. //
  7. // The naming convention used here matches the naming convention of torch.cuda
  8. #include <c10/core/Device.h>
  9. #include <c10/core/impl/GPUTrace.h>
  10. #include <c10/cuda/CUDAException.h>
  11. #include <c10/cuda/CUDAMacros.h>
  12. #include <cuda_runtime_api.h>
  13. namespace c10::cuda {
  14. // NB: In the past, we were inconsistent about whether or not this reported
  15. // an error if there were driver problems are not. Based on experience
  16. // interacting with users, it seems that people basically ~never want this
  17. // function to fail; it should just return zero if things are not working.
  18. // Oblige them.
  19. // It still might log a warning for user first time it's invoked
  20. C10_CUDA_API DeviceIndex device_count() noexcept;
  21. // Version of device_count that throws is no devices are detected
  22. C10_CUDA_API DeviceIndex device_count_ensure_non_zero();
  23. C10_CUDA_API DeviceIndex current_device();
  24. C10_CUDA_API void set_device(DeviceIndex device, const bool force = false);
  25. C10_CUDA_API void device_synchronize();
  26. C10_CUDA_API void warn_or_error_on_sync();
  27. // Raw CUDA device management functions
  28. C10_CUDA_API cudaError_t GetDeviceCount(int* dev_count);
  29. C10_CUDA_API cudaError_t GetDevice(DeviceIndex* device);
  30. C10_CUDA_API cudaError_t
  31. SetDevice(DeviceIndex device, const bool force = false);
  32. C10_CUDA_API cudaError_t MaybeSetDevice(DeviceIndex device);
  33. C10_CUDA_API DeviceIndex ExchangeDevice(DeviceIndex device);
  34. C10_CUDA_API DeviceIndex MaybeExchangeDevice(DeviceIndex device);
  35. C10_CUDA_API void SetTargetDevice();
  36. enum class SyncDebugMode { L_DISABLED = 0, L_WARN, L_ERROR };
  37. // this is a holder for c10 global state (similar to at GlobalContext)
  38. // currently it's used to store cuda synchronization warning state,
  39. // but can be expanded to hold other related global state, e.g. to
  40. // record stream usage
  41. class WarningState {
  42. public:
  43. void set_sync_debug_mode(SyncDebugMode l) {
  44. sync_debug_mode = l;
  45. }
  46. SyncDebugMode get_sync_debug_mode() {
  47. return sync_debug_mode;
  48. }
  49. private:
  50. SyncDebugMode sync_debug_mode = SyncDebugMode::L_DISABLED;
  51. };
  52. C10_CUDA_API __inline__ WarningState& warning_state() {
  53. static WarningState warning_state_;
  54. return warning_state_;
  55. }
  56. // the subsequent functions are defined in the header because for performance
  57. // reasons we want them to be inline
  58. C10_CUDA_API void __inline__ memcpy_and_sync(
  59. void* dst,
  60. const void* src,
  61. int64_t nbytes,
  62. cudaMemcpyKind kind,
  63. cudaStream_t stream) {
  64. if (C10_UNLIKELY(
  65. warning_state().get_sync_debug_mode() != SyncDebugMode::L_DISABLED)) {
  66. warn_or_error_on_sync();
  67. }
  68. const c10::impl::PyInterpreter* interp = c10::impl::GPUTrace::get_trace();
  69. if (C10_UNLIKELY(interp)) {
  70. (*interp)->trace_gpu_stream_synchronization(
  71. c10::kCUDA, reinterpret_cast<uintptr_t>(stream));
  72. }
  73. #if defined(USE_ROCM) && USE_ROCM
  74. // As of ROCm 6.4.1, HIP runtime does not raise an error during capture of
  75. // hipMemcpyWithStream which is a synchronous call. Thus, we add a check
  76. // here explicitly.
  77. hipStreamCaptureStatus captureStatus;
  78. C10_CUDA_CHECK(hipStreamGetCaptureInfo(stream, &captureStatus, nullptr));
  79. if (C10_LIKELY(captureStatus == hipStreamCaptureStatusNone)) {
  80. C10_CUDA_CHECK(hipMemcpyWithStream(dst, src, nbytes, kind, stream));
  81. } else {
  82. C10_CUDA_CHECK(hipErrorStreamCaptureUnsupported);
  83. }
  84. #else
  85. C10_CUDA_CHECK(cudaMemcpyAsync(dst, src, nbytes, kind, stream));
  86. C10_CUDA_CHECK(cudaStreamSynchronize(stream));
  87. #endif
  88. }
  89. C10_CUDA_API void __inline__ stream_synchronize(cudaStream_t stream) {
  90. if (C10_UNLIKELY(
  91. warning_state().get_sync_debug_mode() != SyncDebugMode::L_DISABLED)) {
  92. warn_or_error_on_sync();
  93. }
  94. const c10::impl::PyInterpreter* interp = c10::impl::GPUTrace::get_trace();
  95. if (C10_UNLIKELY(interp)) {
  96. (*interp)->trace_gpu_stream_synchronization(
  97. c10::kCUDA, reinterpret_cast<uintptr_t>(stream));
  98. }
  99. C10_CUDA_CHECK(cudaStreamSynchronize(stream));
  100. }
  101. C10_CUDA_API bool hasPrimaryContext(DeviceIndex device_index);
  102. C10_CUDA_API std::optional<DeviceIndex> getDeviceIndexWithPrimaryContext();
  103. } // namespace c10::cuda