XPUEvent.h 5.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191
#pragma once
#include <ATen/xpu/XPUContext.h>

#include <cstdint>
#include <memory>
#include <optional>
#include <utility>
#include <vector>
  4. namespace at::xpu {
  5. /*
  6. * XPUEvent are movable not copyable wrappers around SYCL event. XPUEvent are
  7. * constructed lazily when first recorded. It has a device, and this device is
  8. * acquired from the first recording stream. Later streams that record the event
  9. * must match the same device.
  10. *
  11. * Currently, XPUEvent does NOT support to export an inter-process event from
  12. * another process via inter-process communication(IPC). So it means that
  13. * inter-process communication for event handles between different processes is
  14. * not available. This could impact some applications that rely on cross-process
  15. * synchronization and communication.
  16. */
  17. struct TORCH_XPU_API XPUEvent {
  18. // Constructors
  19. XPUEvent(bool enable_timing = false) noexcept
  20. : enable_timing_{enable_timing} {}
  21. ~XPUEvent() {
  22. if (isCreated()) {
  23. const c10::impl::PyInterpreter* interp = c10::impl::GPUTrace::get_trace();
  24. if (C10_UNLIKELY(interp)) {
  25. (*interp)->trace_gpu_event_deletion(
  26. at::kXPU, reinterpret_cast<uintptr_t>(event_.get()));
  27. }
  28. }
  29. }
  30. XPUEvent(const XPUEvent&) = delete;
  31. XPUEvent& operator=(const XPUEvent&) = delete;
  32. XPUEvent(XPUEvent&& other) = default;
  33. XPUEvent& operator=(XPUEvent&& other) = default;
  34. operator sycl::event&() const {
  35. return event();
  36. }
  37. std::optional<at::Device> device() const {
  38. if (isCreated()) {
  39. return at::Device(at::kXPU, device_index_);
  40. } else {
  41. return std::nullopt;
  42. }
  43. }
  44. inline bool isCreated() const {
  45. return (event_.get() != nullptr);
  46. }
  47. DeviceIndex device_index() const {
  48. return device_index_;
  49. }
  50. sycl::event& event() const {
  51. return *event_;
  52. }
  53. bool query() const {
  54. using namespace sycl::info;
  55. if (!isCreated()) {
  56. return true;
  57. }
  58. return event().get_info<event::command_execution_status>() ==
  59. event_command_status::complete;
  60. }
  61. void record() {
  62. record(getCurrentXPUStream());
  63. }
  64. void recordOnce(const XPUStream& stream) {
  65. if (!isCreated()) {
  66. record(stream);
  67. }
  68. }
  69. void record(const XPUStream& stream) {
  70. if (!isCreated()) {
  71. device_index_ = stream.device_index();
  72. assignEvent(stream.queue());
  73. const c10::impl::PyInterpreter* interp = c10::impl::GPUTrace::get_trace();
  74. if (C10_UNLIKELY(interp)) {
  75. (*interp)->trace_gpu_event_creation(
  76. at::kXPU, reinterpret_cast<uintptr_t>(event_.get()));
  77. }
  78. } else {
  79. TORCH_CHECK(
  80. device_index_ == stream.device_index(),
  81. "Event device ",
  82. device_index_,
  83. " does not match recording stream's device ",
  84. stream.device_index(),
  85. ".");
  86. reassignEvent(stream.queue());
  87. }
  88. const c10::impl::PyInterpreter* interp = c10::impl::GPUTrace::get_trace();
  89. if (C10_UNLIKELY(interp)) {
  90. (*interp)->trace_gpu_event_record(
  91. at::kXPU,
  92. reinterpret_cast<uintptr_t>(event_.get()),
  93. reinterpret_cast<uintptr_t>(&stream.queue()));
  94. }
  95. }
  96. void block(const XPUStream& stream) {
  97. if (isCreated()) {
  98. std::vector<sycl::event> event_list{event()};
  99. // Make this stream wait until event_ is completed.
  100. stream.queue().ext_oneapi_submit_barrier(event_list);
  101. const c10::impl::PyInterpreter* interp = c10::impl::GPUTrace::get_trace();
  102. if (C10_UNLIKELY(interp)) {
  103. (*interp)->trace_gpu_event_wait(
  104. at::kXPU,
  105. reinterpret_cast<uintptr_t>(event_.get()),
  106. reinterpret_cast<uintptr_t>(&stream.queue()));
  107. }
  108. }
  109. }
  110. double elapsed_time(const XPUEvent& other) const {
  111. TORCH_CHECK(
  112. isCreated() && other.isCreated(),
  113. "Both events must be recorded before calculating elapsed time.");
  114. TORCH_CHECK(
  115. query() && other.query(),
  116. "Both events must be completed before calculating elapsed time.");
  117. TORCH_CHECK(
  118. enable_timing_ && other.enable_timing_,
  119. "Both events must be created with argument 'enable_timing=True'.");
  120. #if SYCL_COMPILER_VERSION < 20250000
  121. TORCH_CHECK_NOT_IMPLEMENTED(
  122. false,
  123. "elapsed_time of XPUEvent requires PyTorch to be built with SYCL compiler version 2025.0.0 or newer.");
  124. #endif
  125. using namespace sycl::info::event_profiling;
  126. // Block until both of the recorded events are completed.
  127. uint64_t end_time_ns = other.event().get_profiling_info<command_end>();
  128. uint64_t start_time_ns = event().get_profiling_info<command_end>();
  129. // Return the eplased time in milliseconds.
  130. return 1e-6 *
  131. (static_cast<double>(end_time_ns) - static_cast<double>(start_time_ns));
  132. }
  133. void synchronize() const {
  134. if (isCreated()) {
  135. const c10::impl::PyInterpreter* interp = c10::impl::GPUTrace::get_trace();
  136. if (C10_UNLIKELY(interp)) {
  137. (*interp)->trace_gpu_event_synchronization(
  138. at::kXPU, reinterpret_cast<uintptr_t>(event_.get()));
  139. }
  140. event().wait_and_throw();
  141. }
  142. }
  143. private:
  144. void assignEvent(sycl::queue& queue) {
  145. #if SYCL_COMPILER_VERSION >= 20250000
  146. if (enable_timing_) {
  147. event_ = std::make_unique<sycl::event>(
  148. sycl::ext::oneapi::experimental::submit_profiling_tag(queue));
  149. } else {
  150. event_ = std::make_unique<sycl::event>(queue.ext_oneapi_submit_barrier());
  151. }
  152. #else
  153. event_ = std::make_unique<sycl::event>(queue.ext_oneapi_submit_barrier());
  154. #endif
  155. }
  156. void reassignEvent(sycl::queue& queue) {
  157. event_.reset();
  158. assignEvent(queue);
  159. }
  160. bool enable_timing_ = false;
  161. DeviceIndex device_index_ = -1;
  162. // Only need to track the last event, as events in an in-order queue are
  163. // executed sequentially.
  164. std::unique_ptr<sycl::event> event_;
  165. };
  166. } // namespace at::xpu