MPSProfiler.h 16 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467
  1. // Copyright © 2022 Apple Inc.
  2. #pragma once
  3. #include <ATen/Tensor.h>
  4. #include <ATen/mps/MPSAllocatorInterface.h>
  5. #include <ATen/mps/MPSStream.h>
  6. #include <os/log.h>
  7. #include <os/signpost.h>
  8. #include <atomic>
  9. #include <ctime>
  10. #include <sstream>
  11. #include <string>
  12. #include <unordered_map>
  13. #include <utility>
  #ifndef __OBJC__
  // In plain C++ translation units the Objective-C MTLCaptureManager class is
  // not available, so it is aliased to an opaque pointer type. This lets
  // declarations such as `MTLCaptureManager* captureManager` still compile.
  using MTLCaptureManager = void*;
  #endif
  17. namespace at::mps {
  18. namespace Profiler {
  19. struct BaseInfo {
  20. // profiling info types
  21. enum class Type {
  22. GRAPH,
  23. KERNEL,
  24. COPY,
  25. CPU_FALLBACK,
  26. };
  27. BaseInfo(Type infoType, uint64_t Id, const uintptr_t Handle)
  28. : type(infoType), profileId(Id), handle(Handle) {}
  29. virtual ~BaseInfo() = default;
  30. // type of profiling info
  31. Type type;
  32. // unique profile ID for execution instances of operations or copies
  33. uint64_t profileId;
  34. // ID generated by os_signpost
  35. // since it's possible to use event and interval-based signposts at the
  36. // same time, we need separate IDs for each.
  37. os_signpost_id_t eventSignpostId = 0, intervalSignpostId = 0;
  38. // accumulated GPU time in ms (obtained from CompletionHandler's "GPUEndTime -
  39. // GPUStartTime")
  40. std::atomic<double> totalGpuTime{0.0};
  41. // accumulated Scheduling time in ms (obtained from CompletionHandler's
  42. // "KernelEndTime - KernelStartTime")
  43. std::atomic<double> totalSchedulingTime{0.0};
  44. // indicates if the operation or copy execution has completed
  45. std::atomic_bool completed{false};
  46. // handle used to identify the profile info's instance (usually the pointer)
  47. const uintptr_t handle;
  48. virtual const std::string toString(
  49. double gpuTime = 0,
  50. double schedulingTime = 0) const;
  51. // builds a string for a tensor (format: Device:ScalarType[tensor.sizes()])
  52. static std::string buildTensorString(
  53. const Tensor& tensor,
  54. bool includeBufferId = false);
  55. static uint64_t getTime() {
  56. return clock_gettime_nsec_np(CLOCK_MONOTONIC_RAW);
  57. }
  58. };
  59. struct OperationInfo : BaseInfo {
  60. OperationInfo(
  61. const void* Handle,
  62. bool IsGraph,
  63. uint64_t Id,
  64. const std::string& StrKey)
  65. : BaseInfo(IsGraph ? Type::GRAPH : Type::KERNEL, Id, uintptr_t(Handle)),
  66. strKey(StrKey) {}
  67. uint64_t runCount = 0;
  68. std::string strKey;
  69. const std::string toString(double gpuTime = 0, double schedulingTime = 0)
  70. const override;
  71. // builds a string for a kernel
  72. static std::string buildKernelString(
  73. const std::string& kernelName,
  74. const TensorList& tensors,
  75. bool includeBufferId = false) {
  76. std::stringstream kernelStr;
  77. kernelStr << kernelName;
  78. for (const Tensor& tensor : tensors) {
  79. kernelStr << ":" << BaseInfo::buildTensorString(tensor, includeBufferId);
  80. }
  81. return kernelStr.str();
  82. }
  83. };
  84. struct CpuFbInfo : BaseInfo {
  85. CpuFbInfo(uint64_t Id, const std::string& OpName)
  86. : BaseInfo(Type::CPU_FALLBACK, Id, 0), opName(OpName) {}
  87. uint64_t runCount = 0;
  88. // the current and total overhead of copies in bytes required to convert the
  89. // Op's input tensors from MPS to CPU and then output from CPU back to MPS
  90. size_t currentCopyOverhead = 0;
  91. size_t totalCopyOverhead = 0;
  92. std::string opName;
  93. std::string strKey;
  94. uint64_t startTime = 0;
  95. const std::string toString(double gpuTime = 0, double schedulingTime = 0)
  96. const override;
  97. void updateCopyOverhead(const TensorList& tensors) {
  98. currentCopyOverhead = 0;
  99. for (const Tensor& tensor : tensors) {
  100. if (tensor.defined()) {
  101. currentCopyOverhead += tensor.nbytes();
  102. }
  103. }
  104. totalCopyOverhead += currentCopyOverhead;
  105. }
  106. };
  107. struct CopyInfo : BaseInfo {
  108. enum class Kind {
  109. MPS_TO_MPS,
  110. MPS_TO_CPU,
  111. CPU_TO_MPS,
  112. };
  113. CopyInfo(
  114. const void* Handle,
  115. size_t Length,
  116. uint64_t Id,
  117. bool IsNonBlocking,
  118. bool UsesBlitter)
  119. : BaseInfo(Type::COPY, Id, uintptr_t(Handle)),
  120. kind(Kind::MPS_TO_MPS),
  121. length(Length),
  122. isNonBlocking(IsNonBlocking),
  123. usesBlitter(UsesBlitter) {}
  124. Kind kind;
  125. size_t length;
  126. bool isNonBlocking;
  127. bool usesBlitter;
  128. std::string srcStrKey;
  129. std::string dstStrKey;
  130. // for copies that don't use blitters, we measure CPU time
  131. uint64_t startTime = 0;
  132. const std::string toString(double gpuTime = 0, double schedulingTime = 0)
  133. const override;
  134. static std::string buildTensorString(
  135. const void* buffer,
  136. const OptionalTensorRef tensor,
  137. bool includeBufferId = false);
  138. static bool isStorageOnMPS(
  139. const void* buffer,
  140. const OptionalTensorRef tensor) {
  141. if (tensor.has_value()) {
  142. return tensor->device().type() == at::kMPS;
  143. }
  144. TORCH_INTERNAL_ASSERT_DEBUG_ONLY(buffer);
  145. // getUnalignedBufferSize() returns -1 if input buffer is not on MPS device
  146. return getIMPSAllocator()->getUnalignedBufferSize(buffer) >= 0;
  147. }
  148. static Kind getCopyKind(
  149. const void* srcBuffer,
  150. const void* dstBuffer,
  151. const OptionalTensorRef srcTensor,
  152. const OptionalTensorRef dstTensor) {
  153. const bool isSrcOnMPS = isStorageOnMPS(srcBuffer, srcTensor);
  154. const bool isDstOnMPS = isStorageOnMPS(dstBuffer, dstTensor);
  155. TORCH_INTERNAL_ASSERT_DEBUG_ONLY(isSrcOnMPS || isDstOnMPS);
  156. if (isSrcOnMPS && !isDstOnMPS) {
  157. return Kind::MPS_TO_CPU;
  158. } else if (!isSrcOnMPS && isDstOnMPS) {
  159. return Kind::CPU_TO_MPS;
  160. }
  161. return Kind::MPS_TO_MPS;
  162. }
  163. };
  164. struct CopyStat : CopyInfo {
  165. explicit CopyStat(std::string CopyKindStr)
  166. : CopyInfo(nullptr, 0, 0, false, false),
  167. kindStr(std::move(CopyKindStr)) {}
  168. // total number of copies
  169. size_t totalCount = 0;
  170. // number of Scalar copies (i.e., less than sizeof(int64))
  171. size_t scalarsCount = 0;
  172. // number of blocking copies (i.e., require syncing to GPU)
  173. size_t blockingCount = 0;
  174. // number of copies that used memcpy(), instead of Metal Blit Encoder
  175. size_t memcpyCount = 0;
  176. // accumulated GPU time in ms for the scalar copies
  177. std::atomic<double> scalarsGpuTime{0.0};
  178. // copy kind in string type
  179. std::string kindStr;
  180. };
  181. class MPSProfiler {
  182. public:
  183. // lower 16 bits used for profiler options
  184. enum ProfileOptions : uint32_t {
  185. OPTIONS_NONE = 0,
  186. // ALL_* means, all signpost types (RUN_OPERATION|BLIT_COPY|CPU_FALLBACK,
  187. // etc.) (used for convenience to not compute bit flags by OR-ing manually)
  188. // trace all signpost types using events
  189. ALL_SIGNPOST_EVENTS = (1 << 0),
  190. // trace all signpost types using intervals
  191. ALL_SIGNPOST_INTERVALS = (1 << 1),
  192. // always wait for command buffer to finish executing after each commit
  193. WAIT_UNTIL_COMPLETED = (1 << 2),
  194. // for interval-based signposts, include the scheduling portion of
  195. // Graph/Kernel/Copy executions as well.
  196. // if flag is disable, only "GPU run time" is included in interval,
  197. // and not schedule time.
  198. INCLUDE_SCHEDULE_INTERVAL = (1 << 3),
  199. // use these if you need to trace signposts types individually (rarely
  200. // required) trace signpost using intervals
  201. USE_INTERVALS = (1 << 4),
  202. // trace signpost by emitting events
  203. USE_EVENTS = (1 << 5),
  204. // used for sanity check (Change this when new option added)
  205. OPTIONS_COUNT = (USE_EVENTS << 1) - 1,
  206. };
  207. // when adding new types, #define the type string in MPSProfiler.mm as well.
  208. // upper 16 bits used for event types
  209. enum SignpostTypes : uint32_t {
  210. SIGNPOST_NONE = 0,
  211. // trace signposts for PyTorch operation executions
  212. RUN_OPERATION = (1 << 16),
  213. // trace signposts for blitter copies
  214. BLIT_COPY = (1 << 17),
  215. // trace signposts for ops that fall back on CPU
  216. CPU_FALLBACK = (1 << 18),
  217. // used for sanity check (Change this when new type added)
  218. SIGNPOST_COUNT = (CPU_FALLBACK << 1) - 1,
  219. };
  220. enum LogOptions : uint32_t {
  221. LOG_NONE = 0,
  222. // Info logging options during execution
  223. // -------------------------------------
  224. // prints operation info (id/key/run_count) during execution
  225. OPERATION_INFO = (1 << 0),
  226. // prints copy info (src/dst tensors/buffers, size, etc.) during execution
  227. COPY_INFO = (1 << 1),
  228. // prints CPU Fallback info (id/runCount/opName/copyOverhead) during
  229. // execution
  230. CPU_FALLBACK_INFO = (1 << 2),
  231. // Profiling Statistics logging options when process terminates
  232. // ------------------------------------------------------------
  233. // prints all stats (OPERATION_STATS, COPY_STATS, CPU_FALLBACK_STATS) before
  234. // process terminates this is convenient to not combine following stats bit
  235. // flags manually
  236. ALL_STATS = (1 << 3),
  237. // prints operation stats (GPU times, run count, etc.) before process
  238. // terminates
  239. OPERATION_STATS = (1 << 4),
  240. // prints copies stats (GPU times, copy kinds, sizes, etc.) before process
  241. // terminates
  242. COPY_STATS = (1 << 5),
  243. // prints CPU Fallback stats (CPU times, run times, size of MPS<->CPU copies
  244. // for tensors, etc.) before process terminates
  245. CPU_FALLBACK_STATS = (1 << 6),
  246. // Metadata format options when logging the info
  247. // ---------------------------------------------
  248. // if enabled, includes GPU run time in metadata (i.e.,
  249. // GPUEndTime-GPUStartTime from Metal Command Buffers) (e.g., [GPU=0.324
  250. // ms])
  251. INCLUDE_GPU_TIME = (1 << 7),
  252. // if enabled, includes GPU scheduling time in metadata separately
  253. // (i.e., KernelEndTime-KernelStartTime from Metal Command Buffers)
  254. // e.g., [GPU=0.324 ms, KRNL=0.036 ms]
  255. INCLUDE_KERNEL_TIME = (1 << 8),
  256. // if enabled, includes the unique buffer ID in metadata for the storage
  257. // of a tensor that was allocated on MPSAllocator. This is useful (along
  258. // with the EV "PYTORCH_DEBUG_MPS_ALLOCATOR") to identify buffers that are
  259. // involved with various operations.
  260. INCLUDE_BUFFER_ID = (1 << 9),
  261. // used for sanity check (Change this when new option added)
  262. LOG_COUNT = (INCLUDE_BUFFER_ID << 1) - 1,
  263. };
  264. explicit MPSProfiler();
  265. ~MPSProfiler();
  266. // the handle is either "MPSGraph*" or "id<MTLComputePipelineState>" for Metal
  267. // Kernels the beginProfile*() functions return a profileId which is unique
  268. // per graph/kernel/copy
  269. uint64_t beginProfileKernel(
  270. const void* handle,
  271. const std::string& strKey,
  272. bool isGraph);
  273. uint64_t beginProfileKernel(
  274. const void* handle,
  275. const std::string& kernelName,
  276. const TensorList& tensors);
  277. uint64_t beginProfileCopy(
  278. const void* srcBuffer,
  279. const void* dstBuffer,
  280. const OptionalTensorRef srcTensor,
  281. const OptionalTensorRef dstTensor,
  282. size_t length,
  283. bool isNonBlocking,
  284. bool usesBlitter = true);
  285. uint64_t beginProfileCPUFallback(
  286. const std::string& opName,
  287. const TensorList& tensors);
  288. void beginProfileGPUInterval(const void* handle);
  289. void endProfileCopy(uint64_t profileId, SyncType syncType);
  290. void endProfileKernel(const void* handle, SyncType syncType = SyncType::NONE);
  291. void endProfileCPUFallback(const std::string& opName);
  292. // these are used to hook into Python bindings for torch.mps.profiler module.
  293. // this enables generating OS Signpost traces from MPSProfiler on-demand
  294. // during runtime (instead of environment variables).
  295. // The "mode" could be either "interval", "event", or both "interval,event"
  296. // for interval-based and/or event-based signpost tracing.
  297. void StartTrace(const std::string& mode, bool waitUntilCompleted);
  298. void StopTrace();
  299. // Abstractions for GPU trace capturing
  300. bool isCaptureEnabled() const;
  301. bool isCapturing() const;
  302. void startCapture(const std::string& name, MPSStream* stream = nullptr);
  303. void stopCapture(MPSStream* stream = nullptr);
  304. // convenience functions to indicate whether signpost tracing or
  305. // logging are enabled for the SignpostTypes
  306. bool isOperationProfilingEnabled() const {
  307. return (m_signpost_types & SignpostTypes::RUN_OPERATION) ||
  308. (m_log_options &
  309. (LogOptions::OPERATION_INFO | LogOptions::OPERATION_STATS));
  310. }
  311. bool isCopyProfilingEnabled() const {
  312. return (m_signpost_types & SignpostTypes::BLIT_COPY) ||
  313. (m_log_options & (LogOptions::COPY_INFO | LogOptions::COPY_STATS));
  314. }
  315. bool isCPUFallbackProfilingEnabled() const {
  316. return (m_signpost_types & SignpostTypes::CPU_FALLBACK) ||
  317. (m_log_options &
  318. (LogOptions::CPU_FALLBACK_INFO | LogOptions::CPU_FALLBACK_STATS));
  319. }
  320. bool isSignpostTracingEnabled() const {
  321. return (m_signpost_types != SignpostTypes::SIGNPOST_NONE);
  322. }
  323. private:
  324. // indicates what type of signpost types are enabled and traced by MPS
  325. // profiler.
  326. uint32_t m_signpost_types = 0;
  327. uint32_t m_profile_options = 0;
  328. uint32_t m_log_options = 0;
  329. uint64_t m_kernel_counter = 0;
  330. uint64_t m_graph_counter = 0;
  331. uint64_t m_cpu_fb_counter = 0;
  332. uint64_t m_copy_counter = 0;
  333. // technically, it's possible to trace both events and intervals at the same
  334. // time so we use separate os_log categories for them
  335. os_log_t m_os_log_events;
  336. os_log_t m_os_log_intervals;
  337. // stats logging could run either from destructor or signal handler
  338. // so this is used to check if logging has already started.
  339. std::atomic_bool hasLoggedStats{false};
  340. // indicates there are pending completionHandler callbacks that haven't been
  341. // called yet.
  342. std::atomic_bool hasPendingCompletionHandlers{false};
  343. // used to capture sigint signal to log profiling stats
  344. static struct sigaction currentSigint, previousSigint;
  345. // We use the following lists for two reasons:
  346. // 1- for interval-based signposts the "begin" point won't be in same function
  347. // as the "end" point where we need to be able to retrieve signpost's info
  348. // 2- if Operations info need to be logged when process ends using
  349. // LogOptions::OPERATION_INFO.
  350. // the pointer key for this map is either "MPSGraph*" or
  351. // "id<MTLComputePipelineState>" for Metal Kernels this list is retained and
  352. // could be logged along with aggregate profiling numbers when the process
  353. // ends.
  354. std::unordered_map<uintptr_t, std::unique_ptr<OperationInfo>>
  355. m_op_info_list{};
  356. // the string key for this map is the op name that we fall back to execute on
  357. // CPU this list is retained and could be logged along with aggregate
  358. // profiling numbers when the process ends.
  359. std::unordered_map<std::string, std::unique_ptr<CpuFbInfo>>
  360. m_cpu_fb_info_list{};
  361. // this list contains the info for copies, and its key is the unique profileId
  362. // which is generated from m_copy_counter
  363. // The copyInfo list is not retained.
  364. std::unordered_map<uint64_t, std::unique_ptr<CopyInfo>> m_copy_info_list{};
  365. // a short list that contains copy stats
  366. std::unordered_map<CopyInfo::Kind, std::unique_ptr<CopyStat>>
  367. m_copy_stat_list{};
  368. mutable MTLCaptureManager* captureManager = nil;
  369. unsigned captureCount = 0;
  370. void initialize();
  371. void beginProfileExecution(BaseInfo& info, bool cpuExecution = false);
  372. void endProfileExecution(
  373. BaseInfo& info,
  374. os_signpost_id_t event_signpost_id,
  375. os_signpost_id_t interval_signpost_id,
  376. double gpuTime,
  377. double schedulingTime);
  378. void addProfilerScheduledHandler(BaseInfo& info);
  379. void addProfilerCompletedHandler(BaseInfo& info, SyncType syncType);
  380. void emitSignpostEvent(
  381. SignpostTypes signpost_type,
  382. os_signpost_id_t signpost_id,
  383. const std::string& msg) const;
  384. void beginSignpostInterval(
  385. SignpostTypes signpost_type,
  386. os_signpost_id_t signpost_id,
  387. const std::string& msg) const;
  388. void endSignpostInterval(
  389. SignpostTypes signpost_type,
  390. os_signpost_id_t signpost_id) const;
  391. void updateCopyStats(
  392. const CopyInfo& copyInfo,
  393. double gpuTime,
  394. double schedulingTime);
  395. // returns true if logging the profiling info "during the execution" is
  396. // enabled
  397. bool isProfileInfoLoggingEnabled(
  398. BaseInfo::Type infoType,
  399. bool isExecutionEnded);
  400. // logs all the profiling stats that are enabled
  401. void logProfilingStats();
  402. // logs kernel profiling stats when the process ends.
  403. void logOperationsProfilingStats(std::FILE* f) const;
  404. // logs CPU Fallback profiling stats when the process ends.
  405. void logCPUFallbackProfilingStats(std::FILE* f) const;
  406. // logs copy profiling stats when the process ends.
  407. void logCopyProfilingStats(std::FILE* f) const;
  408. os_signpost_id_t generateSignpostId(
  409. os_signpost_type_t signpostType,
  410. const void* ptr = nullptr);
  411. static SignpostTypes getSignpostType(BaseInfo::Type infoType);
  412. static void handleIntSignal(int signal);
  413. };
  414. } // namespace Profiler
  415. Profiler::MPSProfiler& getMPSProfiler();
  416. } // namespace at::mps