ThreadLocalState.h 4.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126
  1. #pragma once
  2. #include <c10/core/InferenceMode.h>
  3. #include <c10/core/impl/LocalDispatchKeySet.h>
  4. #include <c10/util/Exception.h>
  5. #include <c10/util/ThreadLocalDebugInfo.h>
  6. #include <ATen/FuncTorchTLS.h>
  7. #include <ATen/PythonTorchFunctionTLS.h>
  8. #include <ATen/SavedTensorHooks.h>
  9. #include <ATen/ThreadLocalPythonObjects.h>
  10. #include <ATen/record_function.h>
  11. #include <c10/core/impl/PythonDispatcherTLS.h>
  12. #include <c10/core/impl/TorchDispatchModeTLS.h>
  13. namespace at {
  14. // Thread local state contains values that are preserved across
  15. // thread boundaries (e.g. at::launch/JIT fork, autograd).
  16. // Note at::parallel_for doesn't preserve TLS across thread boundaries.
  17. class TORCH_API ThreadLocalState {
  18. public:
  19. // Saves the thread local variables' values and
  20. // returns them as a ThreadLocalState
  21. ThreadLocalState();
  22. // set_grad_mode - force the value of the grad mode TLS in
  23. // the current state object. This is used for example in the
  24. // autograd engine.
  25. void set_grad_mode(bool enabled);
  26. // set_multithreading_enabled - force the value of the multithreadinmaximum
  27. // threads TLS in
  28. // the current state object. This is used for example in the
  29. // autograd engine.
  30. void set_multithreading_enabled(bool enabled);
  31. // Sets thread local variables in the current thread,
  32. // according to the thread boundary specified
  33. static void setThreadLocalState(const ThreadLocalState& state);
  34. private:
  35. c10::impl::LocalDispatchKeySet dispatch_key_;
  36. // ThreadLocalDebugInfo does not change after being created
  37. // with DebugInfoGuard
  38. std::shared_ptr<c10::ThreadLocalDebugInfo> debug_info_;
  39. // RecordFunction TLS
  40. RecordFunctionTLS rf_tls_;
  41. // TLS for out-of-tree functorch
  42. // See NOTE [functorch TLS in pytorch/pytorch] for why this needs to be a
  43. // pointer (spoiler alert: it's due to the indirection)
  44. // This needs to be a shared_ptr instead of a unique_ptr because
  45. // ThreadLocalState is copy-able and does indeed get copied. Maybe we can
  46. // consider adding an explicit copy constructor for ThreadLocalState in the
  47. // future but I didn't want to add one just for this.
  48. std::shared_ptr<const functorch::FuncTorchTLSBase> functorch_tls_;
  49. // TLS for AutogradModes
  50. AutogradState autograd_tls_;
  51. // TLS for enable_torch_dispatch_mode
  52. c10::impl::TorchDispatchModeTLS torch_dispatch_mode_state_;
  53. // TLS for enable_python_dispatcher
  54. c10::impl::PyInterpreter* python_dispatcher_state_;
  55. // TLS for __torch_function__ (mode and disable_torch_function)
  56. at::impl::PythonTorchFunctionTLS python_torch_function_state_;
  57. // TLS for saved tensors default hooks
  58. at::impl::SavedTensorDefaultHooksTLS saved_tensors_default_hooks_state_;
  59. bool functionalization_reapply_views_state_;
  60. bool dtensor_allow_implicit_replication_;
  61. // TLS for arbitrary python objects that is registered via hooks
  62. at::impl::ThreadLocalPythonObjects saved_objects_;
  63. #if !defined(CAFFE2_IS_XPLAT_BUILD) && !defined(C10_MOBILE) && \
  64. !defined(BUILD_LITE_INTERPRETER)
  65. // TLS for autocast dtypes
  66. std::array<at::ScalarType, at::COMPILE_TIME_MAX_DEVICE_TYPES>
  67. autocast_dtypes_{};
  68. #endif
  69. friend class ThreadLocalStateGuard;
  70. };
  71. // Guard to set and reset the thread local state
  72. class TORCH_API ThreadLocalStateGuard {
  73. public:
  74. explicit ThreadLocalStateGuard(const ThreadLocalState& state)
  75. : prev_state_(ThreadLocalState()) {
  76. // set the given state across the thread boundary
  77. ThreadLocalState::setThreadLocalState(state);
  78. }
  79. ThreadLocalStateGuard(ThreadLocalStateGuard&& other) = delete;
  80. ThreadLocalStateGuard(const ThreadLocalStateGuard&) = delete;
  81. ThreadLocalStateGuard& operator=(const ThreadLocalStateGuard&) = delete;
  82. ThreadLocalStateGuard& operator=(ThreadLocalStateGuard&&) = delete;
  83. ~ThreadLocalStateGuard() {
  84. // restore previously set variables
  85. ThreadLocalState::setThreadLocalState(prev_state_);
  86. }
  87. private:
  88. // NOLINTNEXTLINE(cppcoreguidelines-avoid-const-or-ref-data-members)
  89. const ThreadLocalState prev_state_;
  90. };
  91. template <typename T>
  92. auto wrapPropagateTLSState(T callback) {
  93. return [tls_state = ThreadLocalState(),
  94. callback = std::move(callback)](auto&&... args) {
  95. ThreadLocalStateGuard g(tls_state);
  96. // Propagate value returned by callback().
  97. return callback(std::forward<decltype(args)>(args)...);
  98. };
  99. }
  100. } // namespace at