CUDAGraphsC10Utils.h 2.7 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576
  1. #pragma once
  2. #include <c10/cuda/CUDAStream.h>
  3. #include <iostream>
  4. #include <utility>
  5. // CUDA Graphs utils used by c10 and aten.
  6. // aten/cuda/CUDAGraphsUtils.cuh adds utils used by aten only.
  7. namespace c10::cuda {
  8. // RAII guard for "cudaStreamCaptureMode", a thread-local value
  9. // that controls the error-checking strictness of a capture.
  10. struct C10_CUDA_API CUDAStreamCaptureModeGuard {
  11. CUDAStreamCaptureModeGuard(cudaStreamCaptureMode desired)
  12. : strictness_(desired) {
  13. C10_CUDA_CHECK(cudaThreadExchangeStreamCaptureMode(&strictness_));
  14. }
  15. CUDAStreamCaptureModeGuard(const CUDAStreamCaptureModeGuard&) = delete;
  16. CUDAStreamCaptureModeGuard(CUDAStreamCaptureModeGuard&&) = delete;
  17. CUDAStreamCaptureModeGuard& operator=(const CUDAStreamCaptureModeGuard&) =
  18. delete;
  19. CUDAStreamCaptureModeGuard& operator=(CUDAStreamCaptureModeGuard&&) = delete;
  20. ~CUDAStreamCaptureModeGuard() {
  21. C10_CUDA_CHECK_WARN(cudaThreadExchangeStreamCaptureMode(&strictness_));
  22. }
  23. private:
  24. cudaStreamCaptureMode strictness_;
  25. };
  26. // Protects against enum cudaStreamCaptureStatus implementation changes.
  27. // Some compilers seem not to like static_assert without the messages.
  28. static_assert(
  29. int(cudaStreamCaptureStatus::cudaStreamCaptureStatusNone) == 0,
  30. "unexpected int(cudaStreamCaptureStatusNone) value");
  31. static_assert(
  32. int(cudaStreamCaptureStatus::cudaStreamCaptureStatusActive) == 1,
  33. "unexpected int(cudaStreamCaptureStatusActive) value");
  34. static_assert(
  35. int(cudaStreamCaptureStatus::cudaStreamCaptureStatusInvalidated) == 2,
  36. "unexpected int(cudaStreamCaptureStatusInvalidated) value");
  37. enum class CaptureStatus : int {
  38. None = int(cudaStreamCaptureStatus::cudaStreamCaptureStatusNone),
  39. Active = int(cudaStreamCaptureStatus::cudaStreamCaptureStatusActive),
  40. Invalidated = int(cudaStreamCaptureStatus::cudaStreamCaptureStatusInvalidated)
  41. };
  42. inline std::ostream& operator<<(std::ostream& os, CaptureStatus status) {
  43. switch (status) {
  44. case CaptureStatus::None:
  45. os << "cudaStreamCaptureStatusNone";
  46. break;
  47. case CaptureStatus::Active:
  48. os << "cudaStreamCaptureStatusActive";
  49. break;
  50. case CaptureStatus::Invalidated:
  51. os << "cudaStreamCaptureStatusInvalidated";
  52. break;
  53. default:
  54. TORCH_INTERNAL_ASSERT(
  55. false, "Unknown CUDA graph CaptureStatus", int(status));
  56. }
  57. return os;
  58. }
  59. // Use this version where you're sure a CUDA context exists already.
  60. inline CaptureStatus currentStreamCaptureStatusMayInitCtx() {
  61. cudaStreamCaptureStatus is_capturing{cudaStreamCaptureStatusNone};
  62. C10_CUDA_CHECK(
  63. cudaStreamIsCapturing(c10::cuda::getCurrentCUDAStream(), &is_capturing));
  64. return CaptureStatus(is_capturing);
  65. }
  66. } // namespace c10::cuda