// MPSHooks.h
// Copyright © 2022 Apple Inc.
#pragma once

#include <ATen/Generator.h>
#include <ATen/detail/MPSHooksInterface.h>
#include <ATen/mps/MPSEvent.h>

#include <cstddef>
#include <cstdint>
#include <optional>
#include <string>
  7. namespace at::mps {
  8. // The real implementation of MPSHooksInterface
  9. struct MPSHooks : public at::MPSHooksInterface {
  10. MPSHooks(at::MPSHooksArgs) {}
  11. void init() const override;
  12. // MPSDevice interface
  13. bool hasMPS() const override;
  14. bool isOnMacOSorNewer(unsigned major, unsigned minor) const override;
  15. Device getDeviceFromPtr(void* data) const override;
  16. // MPSGeneratorImpl interface
  17. const Generator& getDefaultGenerator(
  18. DeviceIndex device_index = -1) const override;
  19. Generator getNewGenerator(DeviceIndex device_index = -1) const override;
  20. // MPSStream interface
  21. void deviceSynchronize() const override;
  22. void commitStream() const override;
  23. void* getCommandBuffer() const override;
  24. void* getDispatchQueue() const override;
  25. // MPSAllocator interface
  26. Allocator* getMPSDeviceAllocator() const override;
  27. void emptyCache() const override;
  28. size_t getCurrentAllocatedMemory() const override;
  29. size_t getDriverAllocatedMemory() const override;
  30. size_t getRecommendedMaxMemory() const override;
  31. void setMemoryFraction(double ratio) const override;
  32. bool isPinnedPtr(const void* data) const override;
  33. Allocator* getPinnedMemoryAllocator() const override;
  34. // MPSProfiler interface
  35. void profilerStartTrace(const std::string& mode, bool waitUntilCompleted)
  36. const override;
  37. void profilerStopTrace() const override;
  38. // MPSEvent interface
  39. uint32_t acquireEvent(bool enable_timing) const override;
  40. void releaseEvent(uint32_t event_id) const override;
  41. void recordEvent(uint32_t event_id) const override;
  42. void waitForEvent(uint32_t event_id) const override;
  43. void synchronizeEvent(uint32_t event_id) const override;
  44. bool queryEvent(uint32_t event_id) const override;
  45. double elapsedTimeOfEvents(uint32_t start_event_id, uint32_t end_event_id)
  46. const override;
  47. bool isBuilt() const override {
  48. return true;
  49. }
  50. bool isAvailable() const override {
  51. return hasMPS();
  52. }
  53. bool hasPrimaryContext(DeviceIndex device_index) const override {
  54. // When MPS is available, it is always in use for the one device.
  55. return true;
  56. }
  57. };
  58. } // namespace at::mps