CUDAEvent.h 8.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249
  1. #pragma once
  2. #include <ATen/cuda/ATenCUDAGeneral.h>
  3. #include <ATen/cuda/CUDAContext.h>
  4. #include <c10/core/impl/GPUTrace.h>
  5. #include <c10/cuda/CUDAStream.h>
  6. #include <c10/cuda/CUDAGuard.h>
  7. #include <ATen/cuda/Exceptions.h>
  8. #include <c10/util/Exception.h>
  9. #include <cuda_runtime_api.h>
  10. #include <cstdint>
  11. #include <utility>
  /*
   * cudaEventExternal
   * -----------------
   * A torch-specific event flag; it is not one of the CUDA runtime's own
   * cudaEvent* creation flags. Passing it in a CUDAEvent's flags states that
   * the event will be used only to synchronize with work *outside* of a CUDA
   * graph, rather than to create cross-stream dependencies *inside* a graph
   * that is being captured.
   *
   * Since the runtime does not understand this bit, CUDAEvent removes it from
   * its flags before calling cudaEventCreateWithFlags and tracks it
   * separately. While stream capture is active, such an event is recorded and
   * waited on with cudaEventRecordExternal / cudaEventWaitExternal.
   *
   * Background reading:
   *   https://docs.nvidia.com/cuda/archive/12.9.0/cuda-c-programming-guide/index.html#cross-stream-dependencies-and-events
   *   https://docs.nvidia.com/cuda/archive/12.9.0/cuda-runtime-api/group__CUDART__TYPES.html#group__CUDART__TYPES_1g3457b81d1d32c6a00f6132fbc2693d47
   *   https://docs.nvidia.com/cuda/archive/12.9.0/cuda-runtime-api/group__CUDART__TYPES.html#group__CUDART__TYPES_1g0c23426b7252eaa9cef695859991304e
   */
  #define cudaEventExternal 0x08
  22. namespace at::cuda {
  23. /*
  24. * CUDAEvents are movable not copyable wrappers around CUDA's events.
  25. *
  26. * CUDAEvents are constructed lazily when first recorded unless it is
  27. * reconstructed from a cudaIpcEventHandle_t. The event has a device, and this
  28. * device is acquired from the first recording stream. However, if reconstructed
  29. * from a handle, the device should be explicitly specified; or if ipc_handle() is
  30. * called before the event is ever recorded, it will use the current device.
  31. * Later streams that record the event must match this device.
  32. */
  33. struct TORCH_CUDA_CPP_API CUDAEvent {
  34. // Constructors
  35. // Default value for `flags` is specified below - it's cudaEventDisableTiming
  36. CUDAEvent() noexcept = default;
  37. CUDAEvent(unsigned int flags) noexcept : flags_{flags} {}
  38. CUDAEvent(
  39. DeviceIndex device_index, const cudaIpcEventHandle_t* handle) : device_index_(device_index) {
  40. CUDAGuard guard(device_index_);
  41. AT_CUDA_CHECK(cudaIpcOpenEventHandle(&event_, *handle));
  42. is_created_ = true;
  43. }
  44. // Note: event destruction done on creating device to avoid creating a
  45. // CUDA context on other devices.
  46. ~CUDAEvent() {
  47. try {
  48. if (is_created_) {
  49. CUDAGuard guard(device_index_);
  50. const c10::impl::PyInterpreter* interp = c10::impl::GPUTrace::get_trace();
  51. if (C10_UNLIKELY(interp)) {
  52. (*interp)->trace_gpu_event_deletion(at::kCUDA, reinterpret_cast<uintptr_t>(event_));
  53. }
  54. AT_CUDA_CHECK(cudaEventDestroy(event_));
  55. }
  56. } catch (...) { /* No throw */ }
  57. }
  58. CUDAEvent(const CUDAEvent&) = delete;
  59. CUDAEvent& operator=(const CUDAEvent&) = delete;
  60. CUDAEvent(CUDAEvent&& other) noexcept { moveHelper(std::move(other)); }
  61. CUDAEvent& operator=(CUDAEvent&& other) noexcept {
  62. if (this != &other) {
  63. moveHelper(std::move(other));
  64. }
  65. return *this;
  66. }
  67. operator cudaEvent_t() const { return event(); }
  68. // Less than operator (to allow use in sets)
  69. friend bool operator<(const CUDAEvent& left, const CUDAEvent& right) {
  70. return left.event_ < right.event_;
  71. }
  72. std::optional<at::Device> device() const {
  73. if (is_created_) {
  74. return at::Device(at::kCUDA, device_index_);
  75. } else {
  76. return {};
  77. }
  78. }
  79. bool isCreated() const { return is_created_; }
  80. DeviceIndex device_index() const {return device_index_;}
  81. cudaEvent_t event() const { return event_; }
  82. // Note: cudaEventQuery can be safely called from any device
  83. bool query() const {
  84. if (!is_created_) {
  85. return true;
  86. }
  87. cudaError_t err = cudaEventQuery(event_);
  88. if (err == cudaSuccess) {
  89. return true;
  90. } else if (err != cudaErrorNotReady) {
  91. C10_CUDA_CHECK(err);
  92. } else {
  93. // ignore and clear the error if not ready
  94. (void)cudaGetLastError();
  95. }
  96. return false;
  97. }
  98. void record() { record(getCurrentCUDAStream()); }
  99. void recordOnce(const CUDAStream& stream) {
  100. if (!was_recorded_) record(stream);
  101. }
  102. // Note: cudaEventRecord must be called on the same device as the event.
  103. void record(const CUDAStream& stream) {
  104. if (!is_created_) {
  105. createEvent(stream.device_index());
  106. }
  107. TORCH_CHECK(device_index_ == stream.device_index(), "Event device ", device_index_,
  108. " does not match recording stream's device ", stream.device_index(), ".");
  109. CUDAGuard guard(device_index_);
  110. #ifndef USE_ROCM
  111. // it is an error to use cudaEventRecordExternal when not doing stream capture
  112. unsigned int flags = (c10::cuda::currentStreamCaptureStatusMayInitCtx() != c10::cuda::CaptureStatus::None && external_) ? cudaEventRecordExternal : cudaEventRecordDefault;
  113. AT_CUDA_CHECK(cudaEventRecordWithFlags(event_, stream, flags));
  114. #else
  115. AT_CUDA_CHECK(cudaEventRecord(event_, stream));
  116. #endif
  117. const c10::impl::PyInterpreter* interp = c10::impl::GPUTrace::get_trace();
  118. if (C10_UNLIKELY(interp)) {
  119. (*interp)->trace_gpu_event_record(at::kCUDA,
  120. reinterpret_cast<uintptr_t>(event_),
  121. reinterpret_cast<uintptr_t>(stream.stream())
  122. );
  123. }
  124. was_recorded_ = true;
  125. }
  126. // Note: cudaStreamWaitEvent must be called on the same device as the stream.
  127. // The event has no actual GPU resources associated with it.
  128. void block(const CUDAStream& stream) {
  129. if (is_created_) {
  130. CUDAGuard guard(stream.device_index());
  131. #ifndef USE_ROCM
  132. // it is an error to use cudaEventWaitExternal when not doing stream capture
  133. unsigned int flags = (c10::cuda::currentStreamCaptureStatusMayInitCtx() != c10::cuda::CaptureStatus::None && external_) ? cudaEventWaitExternal : cudaEventWaitDefault;
  134. AT_CUDA_CHECK(cudaStreamWaitEvent(stream, event_, flags));
  135. #else
  136. AT_CUDA_CHECK(cudaStreamWaitEvent(stream, event_));
  137. #endif
  138. const c10::impl::PyInterpreter* interp = c10::impl::GPUTrace::get_trace();
  139. if (C10_UNLIKELY(interp)) {
  140. (*interp)->trace_gpu_event_wait(at::kCUDA,
  141. reinterpret_cast<uintptr_t>(event_),
  142. reinterpret_cast<uintptr_t>(stream.stream())
  143. );
  144. }
  145. }
  146. }
  147. // Note: cudaEventElapsedTime can be safely called from any device
  148. float elapsed_time(const CUDAEvent& other) const {
  149. TORCH_CHECK_VALUE(
  150. !(flags_ & cudaEventDisableTiming) && !(other.flags_ & cudaEventDisableTiming),
  151. "Both events must be created with argument 'enable_timing=True'.");
  152. TORCH_CHECK_VALUE(
  153. is_created_ && other.isCreated(),
  154. "Both events must be recorded before calculating elapsed time.");
  155. TORCH_CHECK(
  156. query() && other.query(),
  157. "Both events must be completed before calculating elapsed time.");
  158. float time_ms = 0;
  159. // We do not strictly have to set the device index to the same as our event,
  160. // but if we don't and the current device is not initialized, it will
  161. // create a new cuda context, which will consume a lot of memory.
  162. CUDAGuard guard(device_index_);
  163. // raise cudaErrorNotReady if either event is recorded but not yet completed
  164. AT_CUDA_CHECK(cudaEventElapsedTime(&time_ms, event_, other.event_));
  165. return time_ms;
  166. }
  167. // Note: cudaEventSynchronize can be safely called from any device
  168. void synchronize() const {
  169. if (is_created_) {
  170. const c10::impl::PyInterpreter* interp = c10::impl::GPUTrace::get_trace();
  171. if (C10_UNLIKELY(interp)) {
  172. (*interp)->trace_gpu_event_synchronization(at::kCUDA, reinterpret_cast<uintptr_t>(event_));
  173. }
  174. AT_CUDA_CHECK(cudaEventSynchronize(event_));
  175. }
  176. }
  177. // Note: cudaIpcGetEventHandle must be called on the same device as the event
  178. void ipc_handle(cudaIpcEventHandle_t * handle) {
  179. if (!is_created_) {
  180. // this CUDAEvent object was initially constructed from flags but event_
  181. // is not created yet.
  182. createEvent(getCurrentCUDAStream().device_index());
  183. }
  184. CUDAGuard guard(device_index_);
  185. AT_CUDA_CHECK(cudaIpcGetEventHandle(handle, event_));
  186. }
  187. private:
  188. unsigned int flags_ = cudaEventDisableTiming;
  189. bool is_created_ = false;
  190. bool was_recorded_ = false;
  191. bool external_ = false;
  192. DeviceIndex device_index_ = -1;
  193. cudaEvent_t event_{};
  194. void createEvent(DeviceIndex device_index) {
  195. external_ = (flags_ & cudaEventExternal) != 0;
  196. #ifdef USE_ROCM
  197. TORCH_CHECK(!external_, "External events are disallowed in rocm");
  198. #endif
  199. flags_ &= ~cudaEventExternal;
  200. device_index_ = device_index;
  201. CUDAGuard guard(device_index_);
  202. AT_CUDA_CHECK(cudaEventCreateWithFlags(&event_, flags_));
  203. const c10::impl::PyInterpreter* interp = c10::impl::GPUTrace::get_trace();
  204. if (C10_UNLIKELY(interp)) {
  205. (*interp)->trace_gpu_event_creation(at::kCUDA, reinterpret_cast<uintptr_t>(event_));
  206. }
  207. is_created_ = true;
  208. }
  209. void moveHelper(CUDAEvent&& other) {
  210. std::swap(flags_, other.flags_);
  211. std::swap(is_created_, other.is_created_);
  212. std::swap(was_recorded_, other.was_recorded_);
  213. std::swap(device_index_, other.device_index_);
  214. std::swap(event_, other.event_);
  215. }
  216. };
  217. } // namespace at::cuda