CallOnce.h 2.0 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970
  1. #pragma once
  2. #include <c10/macros/Macros.h>
  3. #include <c10/util/C++17.h>
  4. #include <atomic>
  5. #include <functional>
  6. #include <mutex>
  7. #include <utility>
  namespace c10 {

// Custom c10 call_once implementation that avoids the deadlock in
// std::call_once. It is a simplified version of folly's call_once and likely
// has a much higher memory footprint, since every once_flag embeds a
// std::mutex.
//
// Semantics:
//  - If `flag` is already set, this returns immediately. The check is a single
//    acquire load with no locking.
//  - Otherwise the caller takes the flag's mutex and re-checks the flag. If the
//    flag is still unset, it calls `std::invoke(f, args...)` while holding the
//    lock. Other callers that reach this path at the same time wait on the
//    mutex and re-check the flag once they acquire it.
//  - If `f` throws, the flag stays unset and the exception propagates to the
//    caller. A later call (or a waiting caller) will try `f` again.
//
// `Flag` must provide `test_once()` and a `call_once_slow(F&&, Args&&...)`
// accessible from here. c10::once_flag below grants that access via friendship.
template <typename Flag, typename F, typename... Args>
inline void call_once(Flag& flag, F&& f, Args&&... args) {
  if (C10_LIKELY(flag.test_once())) {
    return;
  }
  flag.call_once_slow(std::forward<F>(f), std::forward<Args>(args)...);
}

// Flag object used with c10::call_once. It is neither copyable nor movable.
class once_flag {
 public:
#ifndef _WIN32
  // Running into a build error on MSVC. Can't seem to get a repro locally, so
  // constexpr is avoided on Windows:
  //
  //   C:/actions-runner/_work/pytorch/pytorch\c10/util/CallOnce.h(26): error:
  //   defaulted default constructor cannot be constexpr because the
  //   corresponding implicitly declared default constructor would not be
  //   constexpr 1 error detected in the compilation of
  //   "C:/actions-runner/_work/pytorch/pytorch/aten/src/ATen/cuda/cub.cu".
  constexpr
#endif
      once_flag() noexcept = default;
  once_flag(const once_flag&) = delete;
  once_flag& operator=(const once_flag&) = delete;
  once_flag(once_flag&&) = delete;
  once_flag& operator=(once_flag&&) = delete;
  ~once_flag() = default;

  // Returns true once a callable passed to call_once on this flag has returned
  // normally, and false before that. Only reset_once clears it again.
  // The acquire load pairs with the release store in call_once_slow. As a
  // result, a thread that observes `true` also sees the callable's side effects.
  bool test_once() const noexcept {
    return init_.load(std::memory_order_acquire);
  }

 private:
  template <typename Flag, typename F, typename... Args>
  friend void call_once(Flag& flag, F&& f, Args&&... args);

  // Slow path of call_once: re-checks the flag while holding mutex_ and runs
  // the callable only if no previous lock holder has completed it.
  template <typename F, typename... Args>
  void call_once_slow(F&& f, Args&&... args) {
    std::lock_guard<std::mutex> guard(mutex_);
    // Relaxed is sufficient here. Acquiring mutex_ already orders this load
    // after the store made by any earlier holder of the lock.
    if (init_.load(std::memory_order_relaxed)) {
      return;
    }
    std::invoke(std::forward<F>(f), std::forward<Args>(args)...);
    // This point is reached only if the callable returned normally. On an
    // exception, the lock_guard unlocks and the flag stays unset so a later
    // call can retry.
    init_.store(true, std::memory_order_release);
  }

  // Clears the flag so that a later call_once runs its callable again.
  // Note that this does not take mutex_.
  // NOTE(review): no caller is visible in this header; confirm it is never
  // used while another thread may be inside call_once on the same flag.
  void reset_once() noexcept {
    init_.store(false, std::memory_order_release);
  }

  std::mutex mutex_;
  std::atomic<bool> init_{false};
};

} // namespace c10