MPSHooksInterface.h 3.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125
  1. // Copyright © 2022 Apple Inc.
  2. #pragma once
  #include <ATen/detail/AcceleratorHooksInterface.h>
  #include <c10/core/Allocator.h>
  #include <c10/util/Exception.h>
  #include <c10/util/Registry.h>
  #include <cstddef>
  #include <cstdint>
  #include <string>
  8. C10_DIAGNOSTIC_PUSH_AND_IGNORED_IF_DEFINED("-Wunused-parameter")
  9. namespace at {
  10. struct TORCH_API MPSHooksInterface : AcceleratorHooksInterface {
  11. // this fails the implementation if MPSHooks functions are called, but
  12. // MPS backend is not present.
  13. #define FAIL_MPSHOOKS_FUNC(func) \
  14. TORCH_CHECK(false, "Cannot execute ", func, "() without MPS backend.");
  15. ~MPSHooksInterface() override = default;
  16. // Initialize the MPS library state
  17. void init() const override {
  18. FAIL_MPSHOOKS_FUNC(__func__);
  19. }
  20. virtual bool hasMPS() const {
  21. return false;
  22. }
  23. virtual bool isOnMacOSorNewer(unsigned major = 13, unsigned minor = 0) const {
  24. FAIL_MPSHOOKS_FUNC(__func__);
  25. }
  26. const Generator& getDefaultGenerator(
  27. [[maybe_unused]] DeviceIndex device_index = -1) const override {
  28. FAIL_MPSHOOKS_FUNC(__func__);
  29. }
  30. Generator getNewGenerator(
  31. [[maybe_unused]] DeviceIndex device_index) const override {
  32. FAIL_MPSHOOKS_FUNC(__func__);
  33. }
  34. virtual Allocator* getMPSDeviceAllocator() const {
  35. FAIL_MPSHOOKS_FUNC(__func__);
  36. }
  37. virtual void deviceSynchronize() const {
  38. FAIL_MPSHOOKS_FUNC(__func__);
  39. }
  40. virtual void commitStream() const {
  41. FAIL_MPSHOOKS_FUNC(__func__);
  42. }
  43. virtual void* getCommandBuffer() const {
  44. FAIL_MPSHOOKS_FUNC(__func__);
  45. }
  46. virtual void* getDispatchQueue() const {
  47. FAIL_MPSHOOKS_FUNC(__func__);
  48. }
  49. virtual void emptyCache() const {
  50. FAIL_MPSHOOKS_FUNC(__func__);
  51. }
  52. virtual size_t getCurrentAllocatedMemory() const {
  53. FAIL_MPSHOOKS_FUNC(__func__);
  54. }
  55. virtual size_t getDriverAllocatedMemory() const {
  56. FAIL_MPSHOOKS_FUNC(__func__);
  57. }
  58. virtual size_t getRecommendedMaxMemory() const {
  59. FAIL_MPSHOOKS_FUNC(__func__);
  60. }
  61. virtual void setMemoryFraction(double /*ratio*/) const {
  62. FAIL_MPSHOOKS_FUNC(__func__);
  63. }
  64. virtual void profilerStartTrace(const std::string& mode, bool waitUntilCompleted) const {
  65. FAIL_MPSHOOKS_FUNC(__func__);
  66. }
  67. virtual void profilerStopTrace() const {
  68. FAIL_MPSHOOKS_FUNC(__func__);
  69. }
  70. virtual uint32_t acquireEvent(bool enable_timing) const {
  71. FAIL_MPSHOOKS_FUNC(__func__);
  72. }
  73. Device getDeviceFromPtr(void* data) const override {
  74. TORCH_CHECK(false, "Cannot get device of pointer on MPS without ATen_mps library. ");
  75. }
  76. virtual void releaseEvent(uint32_t event_id) const {
  77. FAIL_MPSHOOKS_FUNC(__func__);
  78. }
  79. virtual void recordEvent(uint32_t event_id) const {
  80. FAIL_MPSHOOKS_FUNC(__func__);
  81. }
  82. virtual void waitForEvent(uint32_t event_id) const {
  83. FAIL_MPSHOOKS_FUNC(__func__);
  84. }
  85. virtual void synchronizeEvent(uint32_t event_id) const {
  86. FAIL_MPSHOOKS_FUNC(__func__);
  87. }
  88. virtual bool queryEvent(uint32_t event_id) const {
  89. FAIL_MPSHOOKS_FUNC(__func__);
  90. }
  91. virtual double elapsedTimeOfEvents(uint32_t start_event_id, uint32_t end_event_id) const {
  92. FAIL_MPSHOOKS_FUNC(__func__);
  93. }
  94. bool hasPrimaryContext(DeviceIndex device_index) const override {
  95. FAIL_MPSHOOKS_FUNC(__func__);
  96. }
  97. bool isPinnedPtr(const void* data) const override {
  98. return false;
  99. }
  100. Allocator* getPinnedMemoryAllocator() const override {
  101. FAIL_MPSHOOKS_FUNC(__func__);
  102. }
  103. #undef FAIL_MPSHOOKS_FUNC
  104. };
  105. struct TORCH_API MPSHooksArgs {};
  106. TORCH_DECLARE_REGISTRY(MPSHooksRegistry, MPSHooksInterface, MPSHooksArgs);
  107. #define REGISTER_MPS_HOOKS(clsname) \
  108. C10_REGISTER_CLASS(MPSHooksRegistry, clsname, clsname)
  109. namespace detail {
  110. TORCH_API const MPSHooksInterface& getMPSHooks();
  111. } // namespace detail
  112. } // namespace at
  113. C10_DIAGNOSTIC_POP()