MPSStream.h 4.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158
  1. // Copyright © 2022 Apple Inc.
  2. #pragma once
  3. #include <cstdint>
  4. #include <utility>
  5. #include <ATen/mps/MPSDevice.h>
  6. #include <c10/core/DeviceGuard.h>
  7. #include <c10/core/Stream.h>
  8. #include <c10/util/Exception.h>
  9. #ifdef __OBJC__
  10. #include <Foundation/Foundation.h>
  11. #include <Metal/Metal.h>
  12. #include <MetalPerformanceShaders/MetalPerformanceShaders.h>
  13. #include <MetalPerformanceShadersGraph/MetalPerformanceShadersGraph.h>
  14. typedef MPSCommandBuffer* MPSCommandBuffer_t;
  15. typedef id<MTLCommandQueue> MTLCommandQueue_t;
  16. typedef id<MTLComputeCommandEncoder> MTLComputeCommandEncoder_t;
  17. typedef id<MTLSharedEvent> MTLSharedEvent_t;
  18. typedef id<MTLDevice> MTLDevice_t;
  19. typedef id<MTLBuffer> MTLBuffer_t;
  20. #else
  21. #include <dispatch/dispatch.h>
  22. typedef void* MPSCommandBuffer_t;
  23. typedef void* MPSGraph;
  24. typedef void* MPSGraphExecutionDescriptor;
  25. typedef void* MPSGraphCompilationDescriptor;
  26. typedef void* MTLCommandQueue_t;
  27. typedef void* MTLComputeCommandEncoder_t;
  28. typedef void* MTLSharedEvent_t;
  29. typedef void* MTLDevice_t;
  30. typedef void* MTLBuffer_t;
  31. typedef void* MTLCommandBufferHandler;
  32. typedef void* NSDictionary;
  33. #define nil NULL
  34. #endif
  35. namespace at::mps {
  36. //-----------------------------------------------------------------
  37. // MPSStream
  38. //-----------------------------------------------------------------
  39. enum class SyncType {
  40. NONE, // no commit to command buffer
  41. COMMIT, // commit and flush the command buffer
  42. COMMIT_AND_WAIT, // flush and wait for command buffer execution to finish
  43. COMMIT_AND_CONTINUE, // commit and continue with a new underlying command buffer
  44. COMMIT_ADAPTIVE, // commit adaptively based on available memory
  45. };
  46. class TORCH_API MPSStream {
  47. public:
  48. enum Unchecked { UNCHECKED };
  49. /// Construct a MPSStream from a Stream. This construction is checked,
  50. /// and will raise an error if the Stream is not, in fact, a MPS stream.
  51. explicit MPSStream(Stream stream);
  52. ~MPSStream();
  53. MTLCommandQueue_t commandQueue() const {
  54. return _commandQueue;
  55. }
  56. dispatch_queue_t queue() const {
  57. return _serialQueue;
  58. }
  59. MPSCommandBuffer_t commandBuffer();
  60. MTLComputeCommandEncoder_t commandEncoder();
  61. void endKernelCoalescing();
  62. void synchronize(SyncType syncType);
  63. void fill(MTLBuffer_t buffer, uint8_t value, size_t length, size_t offset, SyncType syncType = SyncType::NONE);
  64. void copy(MTLBuffer_t srcBuffer,
  65. MTLBuffer_t dstBuffer,
  66. size_t length,
  67. size_t srcOffset,
  68. size_t dstOffset,
  69. uint64_t profileId,
  70. SyncType syncType = SyncType::NONE);
  71. void copy_and_sync(MTLBuffer_t srcBuffer,
  72. MTLBuffer_t dstBuffer,
  73. size_t length,
  74. size_t srcOffset,
  75. size_t dstOffset,
  76. bool non_blocking,
  77. uint64_t profileId);
  78. void executeMPSGraph(MPSGraph* mpsGraph,
  79. NSDictionary* feeds,
  80. NSDictionary* results,
  81. SyncType syncType = SyncType::NONE);
  82. void addCompletedHandler(MTLCommandBufferHandler block);
  83. /// Get the MPS device index that this stream is associated with.
  84. c10::DeviceIndex device_index() const {
  85. return _stream.device_index();
  86. }
  87. MTLCommandQueue_t stream() const {
  88. return _commandQueue;
  89. }
  90. MTLDevice_t device() const;
  91. /// Explicit conversion to Stream.
  92. Stream unwrap() const {
  93. return _stream;
  94. }
  95. private:
  96. Stream _stream;
  97. MTLCommandQueue_t _commandQueue = nil;
  98. MPSCommandBuffer_t _commandBuffer = nil;
  99. MPSCommandBuffer_t _prevCommandBuffer = nil;
  100. MTLComputeCommandEncoder_t _commandEncoder = nil;
  101. MPSGraphExecutionDescriptor* _executionDescriptor = nil;
  102. MPSGraphCompilationDescriptor* _compilationDescriptor = nil;
  103. dispatch_queue_t _serialQueue = nullptr;
  104. // CommitAndContinue is enabled by default
  105. bool _enableCommitAndContinue = true;
  106. // use synchronize() to access any of these commit functions outside MPSStream
  107. void commit();
  108. void commitAndWait();
  109. void commitAndContinue();
  110. void flush();
  111. };
  112. /**
  113. * Get the current MPS stream
  114. */
  115. TORCH_API MPSStream* getCurrentMPSStream();
  116. /**
  117. * Get the default MPS stream
  118. */
  119. TORCH_API MPSStream* getDefaultMPSStream();
  120. //-----------------------------------------------------------------
  121. // MPSStreamImpl
  122. //-----------------------------------------------------------------
  123. class TORCH_API MPSStreamImpl {
  124. public:
  125. /**
  126. * Gets single instance of the MPSStream.
  127. */
  128. static MPSStream* getInstance();
  129. private:
  130. static MPSStream* _stream;
  131. MPSStreamImpl();
  132. };
  133. } // namespace at::mps