CUDAStream.h 9.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268
  1. #pragma once
#include <cuda_runtime_api.h>

#include <c10/core/DeviceGuard.h>
#include <c10/core/Stream.h>
#include <c10/cuda/CUDAFunctions.h>
#include <c10/util/Exception.h>

#include <algorithm>
#include <cstddef>
#include <functional>
#include <ostream>
#include <tuple>
  7. /*
  8. * Stream pool note.
  9. *
  10. * A CUDAStream is an abstraction of an actual cuStream on the GPU. CUDAStreams
  11. * are backed by cuStreams, but they use several pools to minimize the costs
  12. * associated with creating, retaining, and destroying cuStreams.
  13. *
  14. * There are three pools per device, and a device's pools are lazily created.
  15. *
  16. * The first pool contains only the default stream. When the default stream
  17. * is requested it's returned.
  18. *
  19. * The second pool is the "low priority" or "default priority" streams. In
  20. * HIP builds there is no distinction between streams in this pool and streams
  21. * in the third pool (below). There are 32 of these streams per device, and
  22. * when a stream is requested one of these streams is returned round-robin.
  23. * That is, the first stream requested is at index 0, the second at index 1...
  24. * to index 31, then index 0 again.
  25. *
  26. * This means that if 33 low priority streams are requested, the first and
  27. * last streams requested are actually the same stream (under the covers)
  28. * and kernels enqueued on them cannot run concurrently.
  29. *
  30. * The third pool is the "high priority" streams. The third pool acts like
  31. * the second pool except the streams are created with a higher priority.
  32. *
  33. * These pools suggest that stream users should prefer many short-lived streams,
  34. * as the cost of acquiring and releasing streams is effectively zero. If
  35. * many longer-lived streams are required in performance critical scenarios
  36. * then the functionality here may need to be extended to allow, for example,
  37. * "reserving" a subset of the pool so that other streams do not accidentally
  38. * overlap the performance critical streams.
  39. *
  40. * Note: although the notion of "current stream for device" is thread local
  41. * (every OS thread has a separate current stream, as one might expect),
  42. * the stream pool is global across all threads; stream 0 is always stream 0
  43. * no matter which thread you use it on. Multiple threads can synchronize
  44. * on the same stream. Although the CUDA documentation is not very clear
  45. * on the matter, streams are thread safe; e.g., it is safe to enqueue
  46. * a kernel on the same stream from two different threads.
  47. */
  48. namespace c10::cuda {
  49. static constexpr int max_compile_time_stream_priorities = 4;
  50. // Value object representing a CUDA stream. This is just a wrapper
  51. // around c10::Stream, but it comes with a little extra CUDA-specific
  52. // functionality (conversion to cudaStream_t), and a guarantee that
  53. // the wrapped c10::Stream really is a CUDA stream.
  54. class C10_CUDA_API CUDAStream {
  55. public:
  56. enum Unchecked { UNCHECKED };
  57. /// Construct a CUDAStream from a Stream. This construction is checked,
  58. /// and will raise an error if the Stream is not, in fact, a CUDA stream.
  59. explicit CUDAStream(Stream stream) : stream_(stream) {
  60. TORCH_CHECK(stream_.device_type() == DeviceType::CUDA);
  61. }
  62. /// Construct a CUDAStream from a Stream with no error checking.
  63. /// This constructor uses the "named" constructor idiom, and can
  64. /// be invoked as: CUDAStream(CUDAStream::UNCHECKED, stream)
  65. explicit CUDAStream(Unchecked, Stream stream) : stream_(stream) {}
  66. bool operator==(const CUDAStream& other) const noexcept {
  67. return unwrap() == other.unwrap();
  68. }
  69. bool operator!=(const CUDAStream& other) const noexcept {
  70. return unwrap() != other.unwrap();
  71. }
  72. /// Implicit conversion to cudaStream_t.
  73. operator cudaStream_t() const {
  74. return stream();
  75. }
  76. /// Implicit conversion to Stream (a.k.a., forget that the stream is a
  77. /// CUDA stream).
  78. operator Stream() const {
  79. return unwrap();
  80. }
  81. /// Used to avoid baking in device type explicitly to Python-side API.
  82. DeviceType device_type() const {
  83. return DeviceType::CUDA;
  84. }
  85. /// Get the CUDA device index that this stream is associated with.
  86. DeviceIndex device_index() const {
  87. return stream_.device_index();
  88. }
  89. /// Get the full Device that this stream is associated with. The Device
  90. /// is guaranteed to be a CUDA device.
  91. Device device() const {
  92. return Device(DeviceType::CUDA, device_index());
  93. }
  94. /// Return the stream ID corresponding to this particular stream.
  95. StreamId id() const {
  96. return stream_.id();
  97. }
  98. bool query() const {
  99. DeviceGuard guard{stream_.device()};
  100. cudaError_t err = C10_CUDA_ERROR_HANDLED(cudaStreamQuery(stream()));
  101. if (err == cudaSuccess) {
  102. return true;
  103. } else if (err != cudaErrorNotReady) {
  104. C10_CUDA_CHECK(err);
  105. } else {
  106. // ignore and clear the error if not ready
  107. (void)cudaGetLastError();
  108. }
  109. return false;
  110. }
  111. void synchronize() const {
  112. DeviceGuard guard{stream_.device()};
  113. c10::cuda::stream_synchronize(stream());
  114. }
  115. int priority() const {
  116. DeviceGuard guard{stream_.device()};
  117. int priority = 0;
  118. C10_CUDA_CHECK(cudaStreamGetPriority(stream(), &priority));
  119. return priority;
  120. }
  121. /// Explicit conversion to cudaStream_t.
  122. cudaStream_t stream() const;
  123. /// Explicit conversion to Stream.
  124. Stream unwrap() const {
  125. return stream_;
  126. }
  127. /// Reversibly pack a CUDAStream into a struct representation.
  128. /// Previously the stream's data was packed into a single int64_t,
  129. /// as it was assumed the fields would not require more than
  130. /// 64 bits of storage in total.
  131. /// See https://github.com/pytorch/pytorch/issues/75854
  132. /// for more information regarding newer platforms that may violate
  133. /// this assumption.
  134. ///
  135. /// The CUDAStream can be unpacked using unpack().
  136. struct c10::StreamData3 pack3() const {
  137. return stream_.pack3();
  138. }
  139. // Unpack a CUDAStream from the 3 fields generated by pack().
  140. static CUDAStream unpack3(
  141. StreamId stream_id,
  142. DeviceIndex device_index,
  143. DeviceType device_type) {
  144. return CUDAStream(Stream::unpack3(stream_id, device_index, device_type));
  145. }
  146. static std::tuple<int, int> priority_range() {
  147. // Note: this returns the range of priority **supported by PyTorch**, not
  148. // the range of priority **supported by CUDA**. The former is a subset of
  149. // the latter.
  150. int least_priority = 0, greatest_priority = 0;
  151. C10_CUDA_CHECK(
  152. cudaDeviceGetStreamPriorityRange(&least_priority, &greatest_priority));
  153. #ifdef USE_ROCM
  154. // See Note [HIP stream priorities]
  155. TORCH_INTERNAL_ASSERT(
  156. least_priority == 1, "Unexpected HIP stream priority range");
  157. least_priority = 0;
  158. #else
  159. TORCH_INTERNAL_ASSERT(
  160. least_priority == 0, "Unexpected CUDA stream priority range");
  161. #endif
  162. TORCH_INTERNAL_ASSERT(
  163. greatest_priority <= -1, "Unexpected CUDA stream priority range");
  164. greatest_priority = std::max(
  165. -c10::cuda::max_compile_time_stream_priorities + 1, greatest_priority);
  166. return std::make_tuple(least_priority, greatest_priority);
  167. }
  168. // Deleted for now; use CUDAEvent::block instead
  169. // void synchronize_with(const CUDAEvent& event) const;
  170. private:
  171. Stream stream_;
  172. };
  173. /**
  174. * Get a new stream from the CUDA stream pool. You can think of this
  175. * as "creating" a new stream, but no such creation actually happens;
  176. * instead, streams are preallocated from the pool and returned in a
  177. * round-robin fashion.
  178. *
  179. * You can request a stream from the high priority pool by setting
  180. * isHighPriority to true, or a stream for a specific device by setting device
  181. * (defaulting to the current CUDA stream.)
  182. */
  183. C10_API CUDAStream
  184. getStreamFromPool(const bool isHighPriority = false, DeviceIndex device = -1);
  185. // no default priority to disambiguate overloads
  186. C10_API CUDAStream
  187. getStreamFromPool(const int priority, DeviceIndex device = -1);
/**
 * Get a CUDAStream from an externally allocated one.
 *
 * This is mainly for interoperability with different libraries where we
 * want to operate on a non-torch allocated stream for data exchange or similar
 * purposes.
 *
 * @param ext_stream   the externally allocated cudaStream_t to wrap
 * @param device_index device index to associate with the returned CUDAStream;
 *                     presumably the device ext_stream was created on
 *                     (TODO confirm against the implementation)
 */
C10_API CUDAStream
getStreamFromExternal(cudaStream_t ext_stream, DeviceIndex device_index);
/**
 * Get the default CUDA stream, for the passed CUDA device, or for the
 * current device if no device index is passed.  The default stream is
 * where most computation occurs when you aren't explicitly using
 * streams.
 *
 * @param device_index index of the CUDA device to query; the default
 *                     argument of -1 stands for "no device index passed",
 *                     i.e. the current device
 */
C10_API CUDAStream getDefaultCUDAStream(DeviceIndex device_index = -1);
/**
 * Get the current CUDA stream, for the passed CUDA device, or for the
 * current device if no device index is passed.  The current CUDA stream
 * will usually be the default CUDA stream for the device, but it may
 * be different if someone called 'setCurrentCUDAStream' or used 'StreamGuard'
 * or 'CUDAStreamGuard'.
 *
 * @param device_index index of the CUDA device to query; the default
 *                     argument of -1 stands for "no device index passed",
 *                     i.e. the current device
 */
C10_API CUDAStream getCurrentCUDAStream(DeviceIndex device_index = -1);
/**
 * Set the current stream on the device of the passed in stream to be
 * the passed in stream.  Yes, you read that right: this function
 * has *nothing* to do with the current device: it toggles the current
 * stream of the device of the passed stream.
 *
 * Confused?  Avoid using this function; prefer using 'CUDAStreamGuard' instead
 * (which will switch both your current device and current stream in the way you
 * expect, and reset it back to its original state afterwards).
 *
 * @param stream the stream to install as the current stream of its own device
 */
C10_API void setCurrentCUDAStream(CUDAStream stream);
/**
 * Stream insertion operator for CUDAStream: writes `s` to `stream`.
 * The output format is defined by the implementation, which is not part of
 * this header.
 */
C10_API std::ostream& operator<<(std::ostream& stream, const CUDAStream& s);
  224. } // namespace c10::cuda
  225. namespace std {
  226. template <>
  227. struct hash<c10::cuda::CUDAStream> {
  228. size_t operator()(c10::cuda::CUDAStream s) const noexcept {
  229. return std::hash<c10::Stream>{}(s.unwrap());
  230. }
  231. };
  232. } // namespace std