// DTensorState.h
#pragma once

#include <c10/macros/Macros.h>
  3. namespace at {
  4. TORCH_API bool get_dtensor_allow_implicit_replication();
  5. TORCH_API void set_dtensor_allow_implicit_replication(bool enabled);
  6. struct DTensorAllowImplicitReplication {
  7. DTensorAllowImplicitReplication()
  8. : prev_dtensor_allow_implicit_replication_(
  9. get_dtensor_allow_implicit_replication()) {
  10. set_dtensor_allow_implicit_replication(true);
  11. }
  12. DTensorAllowImplicitReplication(const DTensorAllowImplicitReplication&) =
  13. delete;
  14. DTensorAllowImplicitReplication& operator=(
  15. const DTensorAllowImplicitReplication&) = delete;
  16. DTensorAllowImplicitReplication(DTensorAllowImplicitReplication&&) = delete;
  17. DTensorAllowImplicitReplication& operator=(
  18. DTensorAllowImplicitReplication&&) = delete;
  19. ~DTensorAllowImplicitReplication() {
  20. set_dtensor_allow_implicit_replication(
  21. prev_dtensor_allow_implicit_replication_);
  22. }
  23. private:
  24. bool prev_dtensor_allow_implicit_replication_;
  25. };
  26. } // namespace at