SavedTensorHooks.h 2.4 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768
  1. #pragma once
  2. #include <c10/core/SafePyObject.h>
  3. #include <c10/macros/Export.h>
  4. #include <c10/util/python_stub.h>
  5. #include <optional>
  6. #include <stack>
  7. #include <string>
  8. #include <utility>
  9. namespace at {
  10. namespace impl {
  11. struct TORCH_API SavedTensorDefaultHooksTLS {
  12. // PyObject is defined in c10/util/python_stub.h
  13. std::stack<std::pair<c10::SafePyObject, c10::SafePyObject>> stack;
  14. // See NOTE: [Disabling SavedTensorDefaultHooks] for context
  15. // NOTE: [disabled_error_message invariant]
  16. // disabled_error_message is nullopt IFF Saved Tensor hooks is enabled
  17. // We did this for efficiency (so we didn't have to keep a separate bool
  18. // around)
  19. std::optional<std::string> disabled_error_message;
  20. // See NOTE: [Deferring tensor pack/unpack hooks until runtime]
  21. bool is_tracing = false;
  22. };
  23. } // namespace impl
  24. struct TORCH_API SavedTensorDefaultHooks {
  25. static void push_hooks(
  26. c10::SafePyObject pack_hook,
  27. c10::SafePyObject unpack_hook);
  28. static std::pair<c10::SafePyObject, c10::SafePyObject> pop_hooks();
  29. static std::optional<std::pair<c10::SafePyObject, c10::SafePyObject>>
  30. get_hooks(bool ignore_is_tracing = false);
  31. static void lazy_initialize();
  32. static const impl::SavedTensorDefaultHooksTLS& get_tls_state();
  33. static void set_tls_state(const impl::SavedTensorDefaultHooksTLS& tls);
  34. // NOTE: [Disabling SavedTensorDefaultHooks]
  35. // A developer of a PyTorch feature may choose to disable SavedTensorDefault
  36. // hooks, especially if their feature does not work with it. If they are
  37. // disabled, then the following will raise an error:
  38. // - Attempting to push_hooks
  39. // - calling disable(message) with a non-zero stack (hooks) size
  40. static void disable(
  41. const std::string& error_message,
  42. const bool fail_if_non_empty = true);
  43. static void enable();
  44. static bool is_enabled();
  45. static const std::optional<std::string>& get_disabled_error_message();
  46. // NOTE: [Deferring tensor pack/unpack hooks until runtime]
  47. // To preserve eager semantics of pack/unpack hooks firing only once per saved
  48. // variable, Dynamo/AOTAutograd need to defer hook firing until runtime. Using
  49. // disable() would loud error at trace time, and pushing a no-op hook would
  50. // fail when the traced code is wrapped in a disable_saved_tensors_hooks ctx.
  51. // To do so, we disable these hooks during tracing. See
  52. // https://github.com/pytorch/pytorch/issues/113263.
  53. static bool set_tracing(bool is_tracing);
  54. };
  55. } // namespace at