MTIAHooksInterface.h 4.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166
  1. #pragma once
  2. #include <c10/core/Device.h>
  3. #include <c10/util/Exception.h>
  4. #include <c10/core/Stream.h>
  5. #include <c10/util/Registry.h>
  6. #include <c10/core/Allocator.h>
  7. #include <c10/util/python_stub.h>
  8. #include <ATen/detail/AcceleratorHooksInterface.h>
  9. #include <string>
  // Forward declaration of at::Context (the class itself is defined elsewhere).
  namespace at {

  class Context;

  } // namespace at
  13. namespace at {
  /// Help text describing why an MTIA feature is unavailable.
  ///
  /// The message is assembled from adjacent string literals, so every fragment
  /// that is followed by another one must carry its own separating space.
  constexpr const char* MTIA_HELP =
      "The MTIA backend requires MTIA extension for PyTorch; "
      "this error has occurred because you are trying "
      "to use some MTIA's functionality without MTIA extension included.";
  18. struct TORCH_API MTIAHooksInterface : AcceleratorHooksInterface {
  19. // this fails the implementation if MTIAHooks functions are called, but
  20. // MTIA backend is not present.
  21. #define FAIL_MTIAHOOKS_FUNC(func) \
  22. TORCH_CHECK(false, "Cannot execute ", func, "() without MTIA backend.");
  23. ~MTIAHooksInterface() override = default;
  24. void init() const override {
  25. // Avoid logging here, since MTIA needs init devices first then it will know
  26. // how many devices are available. Make it as no-op if mtia extension is not
  27. // dynamically loaded.
  28. return;
  29. }
  30. virtual bool hasMTIA() const {
  31. return false;
  32. }
  33. DeviceIndex deviceCount() const override {
  34. return 0;
  35. }
  36. virtual void deviceSynchronize(c10::DeviceIndex /*device_index*/) const {
  37. FAIL_MTIAHOOKS_FUNC(__func__);
  38. }
  39. virtual std::string showConfig() const {
  40. FAIL_MTIAHOOKS_FUNC(__func__);
  41. }
  42. bool hasPrimaryContext(DeviceIndex /*device_index*/) const override {
  43. return false;
  44. }
  45. void setCurrentDevice(DeviceIndex /*device*/) const override {
  46. FAIL_MTIAHOOKS_FUNC(__func__);
  47. }
  48. DeviceIndex getCurrentDevice() const override {
  49. FAIL_MTIAHOOKS_FUNC(__func__);
  50. return -1;
  51. }
  52. DeviceIndex exchangeDevice(DeviceIndex /*device*/) const override {
  53. FAIL_MTIAHOOKS_FUNC(__func__);
  54. return -1;
  55. }
  56. DeviceIndex maybeExchangeDevice(DeviceIndex /*device*/) const override {
  57. FAIL_MTIAHOOKS_FUNC(__func__);
  58. return -1;
  59. }
  60. virtual c10::Stream getCurrentStream(DeviceIndex /*device*/) const {
  61. FAIL_MTIAHOOKS_FUNC(__func__);
  62. return c10::Stream::unpack3(-1, 0, c10::DeviceType::MTIA);
  63. }
  64. virtual int64_t getCurrentRawStream(DeviceIndex /*device*/) const {
  65. FAIL_MTIAHOOKS_FUNC(__func__);
  66. return -1;
  67. }
  68. virtual c10::Stream getDefaultStream(DeviceIndex /*device*/) const {
  69. FAIL_MTIAHOOKS_FUNC(__func__);
  70. return c10::Stream::unpack3(-1, 0, c10::DeviceType::MTIA);
  71. }
  72. virtual void setCurrentStream(const c10::Stream& /*stream*/ ) const {
  73. FAIL_MTIAHOOKS_FUNC(__func__);
  74. }
  75. bool isPinnedPtr(const void* /*data*/) const override {
  76. return false;
  77. }
  78. Allocator* getPinnedMemoryAllocator() const override {
  79. FAIL_MTIAHOOKS_FUNC(__func__);
  80. return nullptr;
  81. }
  82. virtual PyObject* memoryStats(DeviceIndex /*device*/) const {
  83. FAIL_MTIAHOOKS_FUNC(__func__);
  84. return nullptr;
  85. }
  86. virtual PyObject* getDeviceCapability(DeviceIndex /*device*/) const {
  87. FAIL_MTIAHOOKS_FUNC(__func__);
  88. return nullptr;
  89. }
  90. virtual PyObject* getDeviceProperties(DeviceIndex device) const {
  91. FAIL_MTIAHOOKS_FUNC(__func__);
  92. return nullptr;
  93. }
  94. virtual void emptyCache() const {
  95. FAIL_MTIAHOOKS_FUNC(__func__);
  96. }
  97. virtual void recordMemoryHistory(
  98. const std::optional<std::string>& /*enabled*/,
  99. const std::string& /*stacks*/,
  100. size_t /*max_entries*/) const {
  101. FAIL_MTIAHOOKS_FUNC(__func__);
  102. }
  103. virtual PyObject* memorySnapshot(const std::optional<std::string>& local_path) const {
  104. FAIL_MTIAHOOKS_FUNC(__func__);
  105. return nullptr;
  106. }
  107. virtual DeviceIndex getDeviceCount() const {
  108. FAIL_MTIAHOOKS_FUNC(__func__);
  109. return 0;
  110. }
  111. virtual void resetPeakMemoryStats(DeviceIndex /*device*/) const {
  112. FAIL_MTIAHOOKS_FUNC(__func__);
  113. }
  114. virtual void attachOutOfMemoryObserver(PyObject* observer) const {
  115. FAIL_MTIAHOOKS_FUNC(__func__);
  116. return;
  117. }
  118. virtual bool isAvailable() const override;
  119. };
  120. struct TORCH_API MTIAHooksArgs {};
  121. TORCH_DECLARE_REGISTRY(MTIAHooksRegistry, MTIAHooksInterface, MTIAHooksArgs);
  122. #define REGISTER_MTIA_HOOKS(clsname) \
  123. C10_REGISTER_CLASS(MTIAHooksRegistry, clsname, clsname)
  124. namespace detail {
  125. TORCH_API const MTIAHooksInterface& getMTIAHooks();
  126. TORCH_API bool isMTIAHooksBuilt();
  127. } // namespace detail
  128. } // namespace at