| 12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576 |
- #pragma once
- #include <c10/cuda/CUDAStream.h>
- #include <iostream>
- #include <utility>
- // CUDA Graphs utils used by c10 and aten.
- // aten/cuda/CUDAGraphsUtils.cuh adds utils used by aten only.
- namespace c10::cuda {
- // RAII guard for "cudaStreamCaptureMode", a thread-local value
- // that controls the error-checking strictness of a capture.
- struct C10_CUDA_API CUDAStreamCaptureModeGuard {
- CUDAStreamCaptureModeGuard(cudaStreamCaptureMode desired)
- : strictness_(desired) {
- C10_CUDA_CHECK(cudaThreadExchangeStreamCaptureMode(&strictness_));
- }
- CUDAStreamCaptureModeGuard(const CUDAStreamCaptureModeGuard&) = delete;
- CUDAStreamCaptureModeGuard(CUDAStreamCaptureModeGuard&&) = delete;
- CUDAStreamCaptureModeGuard& operator=(const CUDAStreamCaptureModeGuard&) =
- delete;
- CUDAStreamCaptureModeGuard& operator=(CUDAStreamCaptureModeGuard&&) = delete;
- ~CUDAStreamCaptureModeGuard() {
- C10_CUDA_CHECK_WARN(cudaThreadExchangeStreamCaptureMode(&strictness_));
- }
- private:
- cudaStreamCaptureMode strictness_;
- };
- // Protects against enum cudaStreamCaptureStatus implementation changes.
- // Some compilers seem not to like static_assert without the messages.
- static_assert(
- int(cudaStreamCaptureStatus::cudaStreamCaptureStatusNone) == 0,
- "unexpected int(cudaStreamCaptureStatusNone) value");
- static_assert(
- int(cudaStreamCaptureStatus::cudaStreamCaptureStatusActive) == 1,
- "unexpected int(cudaStreamCaptureStatusActive) value");
- static_assert(
- int(cudaStreamCaptureStatus::cudaStreamCaptureStatusInvalidated) == 2,
- "unexpected int(cudaStreamCaptureStatusInvalidated) value");
- enum class CaptureStatus : int {
- None = int(cudaStreamCaptureStatus::cudaStreamCaptureStatusNone),
- Active = int(cudaStreamCaptureStatus::cudaStreamCaptureStatusActive),
- Invalidated = int(cudaStreamCaptureStatus::cudaStreamCaptureStatusInvalidated)
- };
- inline std::ostream& operator<<(std::ostream& os, CaptureStatus status) {
- switch (status) {
- case CaptureStatus::None:
- os << "cudaStreamCaptureStatusNone";
- break;
- case CaptureStatus::Active:
- os << "cudaStreamCaptureStatusActive";
- break;
- case CaptureStatus::Invalidated:
- os << "cudaStreamCaptureStatusInvalidated";
- break;
- default:
- TORCH_INTERNAL_ASSERT(
- false, "Unknown CUDA graph CaptureStatus", int(status));
- }
- return os;
- }
- // Use this version where you're sure a CUDA context exists already.
- inline CaptureStatus currentStreamCaptureStatusMayInitCtx() {
- cudaStreamCaptureStatus is_capturing{cudaStreamCaptureStatusNone};
- C10_CUDA_CHECK(
- cudaStreamIsCapturing(c10::cuda::getCurrentCUDAStream(), &is_capturing));
- return CaptureStatus(is_capturing);
- }
- } // namespace c10::cuda
|