profiler.py 44 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862863864865866867868869870871872873874875876877878879880881882883884885886887888889890891892893894895896897898899900901902903904905906907908909910911912913914915916917918919920921922923924925926927928929930931932933934935936937938939940941942943944945946947948949950951952953954955956957958959960961962963964965966967968969970971972973974975976977978979980981982983984985986987988989990991992993994995996997998999100010011002100310041005100610071008100910101011101210131014101510161017101810191020102110221023102410251026102710281029103010311032103310341035103610371038103910401041104210431044104510461047104810491050105110521053105410551056105710581059106010611062106310641065106610671068106910701071107210731074107510761077107810791080108110821083108410851086108710881089109010911092109310941095109610971098109911001101110211031104110511061107110811091110111111121113111411151116111711181119112011211122112311241125112611271128
  1. # mypy: allow-untyped-defs
  2. import gzip
  3. import json
  4. import os
  5. import shutil
  6. import tempfile
  7. from abc import ABC, abstractmethod
  8. from collections.abc import Iterable
  9. from enum import Enum
  10. from functools import partial
  11. from typing import Any, Callable, Optional
  12. from typing_extensions import Self
  13. from warnings import warn
  14. import torch
  15. import torch.autograd.profiler as prof
  16. from torch._C import _get_privateuse1_backend_name
  17. from torch._C._profiler import (
  18. _add_execution_trace_observer,
  19. _disable_execution_trace_observer,
  20. _enable_execution_trace_observer,
  21. _ExperimentalConfig,
  22. _remove_execution_trace_observer,
  23. )
  24. from torch._environment import is_fbcode
  25. from torch._utils_internal import profiler_allow_cudagraph_cupti_lazy_reinit_cuda12
  26. from torch.autograd import kineto_available, ProfilerActivity
  27. from torch.profiler._memory_profiler import MemoryProfile, MemoryProfileTimeline
  28. __all__ = [
  29. "supported_activities",
  30. "ProfilerAction",
  31. "schedule",
  32. "tensorboard_trace_handler",
  33. "profile",
  34. "ExecutionTraceObserver",
  35. ]
  36. PROFILER_STEP_NAME = "ProfilerStep"
  37. class _NumpyEncoder(json.JSONEncoder):
  38. """
  39. Json encoder for numpy types (np.int, np.float, np.array etc.)
  40. Returns default encoder if numpy is not available
  41. """
  42. def default(self, obj):
  43. """Encode NumPy types to JSON"""
  44. try:
  45. import numpy as np
  46. except ImportError:
  47. return json.JSONEncoder.default(self, obj)
  48. if isinstance(obj, np.integer):
  49. return int(obj)
  50. elif isinstance(obj, np.floating):
  51. return float(obj)
  52. elif isinstance(obj, np.ndarray):
  53. return obj.tolist()
  54. else:
  55. return json.JSONEncoder.default(self, obj)
  56. def supported_activities():
  57. """
  58. Returns a set of supported profiler tracing activities.
  59. Note: profiler uses CUPTI library to trace on-device CUDA kernels.
  60. In case when CUDA is enabled but CUPTI is not available, passing
  61. ``ProfilerActivity.CUDA`` to profiler results in using the legacy CUDA
  62. profiling code (same as in the legacy ``torch.autograd.profiler``).
  63. This, in turn, results in including CUDA time in the profiler table output,
  64. but not in the JSON trace.
  65. """
  66. return torch.autograd._supported_activities()
  67. class _ITraceObserver(ABC):
  68. """Abstract interface for a Trace observer.
  69. This satisfies 3 methods: start, stop and cleanup"""
  70. @abstractmethod
  71. def start(self):
  72. pass
  73. @abstractmethod
  74. def stop(self):
  75. pass
  76. @abstractmethod
  77. def cleanup(self):
  78. pass
  79. class _KinetoProfile:
  80. """Low-level profiler wrap the autograd profile
  81. Args:
  82. activities (iterable): list of activity groups (CPU, CUDA) to use in profiling, supported values:
  83. ``torch.profiler.ProfilerActivity.CPU``, ``torch.profiler.ProfilerActivity.CUDA``,
  84. ``torch.profiler.ProfilerActivity.XPU``.
  85. Default value: ProfilerActivity.CPU and (when available) ProfilerActivity.CUDA
  86. or (when available) ProfilerActivity.XPU.
  87. record_shapes (bool): save information about operator's input shapes.
  88. profile_memory (bool): track tensor memory allocation/deallocation (see ``export_memory_timeline``
  89. for more details).
  90. with_stack (bool): record source information (file and line number) for the ops.
  91. with_flops (bool): use formula to estimate the FLOPS of specific operators
  92. (matrix multiplication and 2D convolution).
  93. with_modules (bool): record module hierarchy (including function names)
  94. corresponding to the callstack of the op. e.g. If module A's forward call's
  95. module B's forward which contains an aten::add op,
  96. then aten::add's module hierarchy is A.B
  97. Note that this support exist, at the moment, only for TorchScript models
  98. and not eager mode models.
  99. experimental_config (_ExperimentalConfig) : A set of experimental options
  100. used by profiler libraries like Kineto. Note, backward compatibility is not guaranteed.
  101. execution_trace_observer (ExecutionTraceObserver) : A PyTorch Execution Trace Observer object.
  102. `PyTorch Execution Traces <https://arxiv.org/pdf/2305.14516.pdf>`__ offer a graph based
  103. representation of AI/ML workloads and enable replay benchmarks, simulators, and emulators.
  104. When this argument is included the observer start() and stop() will be called for the
  105. same time window as PyTorch profiler.
  106. acc_events (bool): Enable the accumulation of FunctionEvents across multiple profiling cycles
  107. .. note::
  108. This API is experimental and subject to change in the future.
  109. Enabling shape and stack tracing results in additional overhead.
  110. When record_shapes=True is specified, profiler will temporarily hold references to the tensors;
  111. that may further prevent certain optimizations that depend on the reference count and introduce
  112. extra tensor copies.
  113. """
  114. def __init__(
  115. self,
  116. *,
  117. activities: Optional[Iterable[ProfilerActivity]] = None,
  118. record_shapes: bool = False,
  119. profile_memory: bool = False,
  120. with_stack: bool = False,
  121. with_flops: bool = False,
  122. with_modules: bool = False,
  123. experimental_config: Optional[_ExperimentalConfig] = None,
  124. execution_trace_observer: Optional[_ITraceObserver] = None,
  125. acc_events: bool = False,
  126. custom_trace_id_callback: Optional[Callable[[], str]] = None,
  127. ) -> None:
  128. self.activities = set(activities) if activities else supported_activities()
  129. self.record_shapes = record_shapes
  130. self.with_flops = with_flops
  131. self.profile_memory = profile_memory
  132. self.with_stack = with_stack
  133. self.with_modules = with_modules
  134. self.experimental_config = experimental_config
  135. self.execution_trace_observer = execution_trace_observer
  136. self.acc_events = acc_events
  137. self.custom_trace_id_callback = custom_trace_id_callback
  138. self.profiler: Optional[prof.profile] = None
  139. self.has_cudagraphs = False
  140. self.mem_tl: Optional[MemoryProfileTimeline] = None
  141. self.use_device = None
  142. if ProfilerActivity.CUDA in self.activities:
  143. self.use_device = "cuda"
  144. elif ProfilerActivity.XPU in self.activities:
  145. self.use_device = "xpu"
  146. elif ProfilerActivity.MTIA in self.activities:
  147. self.use_device = "mtia"
  148. elif ProfilerActivity.HPU in self.activities:
  149. self.use_device = "hpu"
  150. elif ProfilerActivity.PrivateUse1 in self.activities:
  151. self.use_device = _get_privateuse1_backend_name()
  152. # user-defined metadata to be amended to the trace
  153. self.preset_metadata: dict[str, str] = {}
  154. def start(self) -> None:
  155. self.prepare_trace()
  156. self.start_trace()
  157. def stop(self) -> None:
  158. self.stop_trace()
  159. def prepare_trace(self) -> None:
  160. if hasattr(torch, "_inductor"):
  161. import torch._inductor.config as inductor_config
  162. self.has_cudagraphs = inductor_config.triton.cudagraphs
  163. if (self.profiler is None) or (not self.acc_events):
  164. self.profiler = prof.profile(
  165. use_cpu=(ProfilerActivity.CPU in self.activities),
  166. use_device=self.use_device,
  167. record_shapes=self.record_shapes,
  168. with_flops=self.with_flops,
  169. profile_memory=self.profile_memory,
  170. with_stack=self.with_stack,
  171. with_modules=self.with_modules,
  172. use_kineto=True,
  173. experimental_config=self.experimental_config,
  174. acc_events=self.acc_events,
  175. custom_trace_id_callback=self.custom_trace_id_callback,
  176. )
  177. self.profiler._prepare_trace()
  178. def start_trace(self) -> None:
  179. if self.execution_trace_observer:
  180. self.execution_trace_observer.start()
  181. assert self.profiler is not None
  182. self.profiler._start_trace()
  183. if self.profile_memory:
  184. self.add_metadata_json("profile_memory", "1")
  185. if self.with_stack:
  186. self.add_metadata_json("with_stack", "1")
  187. if self.record_shapes:
  188. self.add_metadata_json("record_shapes", "1")
  189. if self.with_modules:
  190. self.add_metadata_json("with_modules", "1")
  191. if self.with_flops:
  192. self.add_metadata_json("with_flops", "1")
  193. if kineto_available():
  194. dist_info = self._get_distributed_info()
  195. if dist_info:
  196. self.add_metadata_json(
  197. "distributedInfo", json.dumps(dist_info, cls=_NumpyEncoder)
  198. )
  199. cuda_version = None
  200. if hasattr(torch, "version"):
  201. from torch.torch_version import TorchVersion
  202. cuda_version = TorchVersion(getattr(torch.version, "cuda", "0.0"))
  203. if self.has_cudagraphs and (
  204. (cuda_version and cuda_version < "12.6")
  205. or not profiler_allow_cudagraph_cupti_lazy_reinit_cuda12()
  206. ):
  207. os.environ["DISABLE_CUPTI_LAZY_REINIT"] = "1"
  208. self.add_metadata_json("DISABLE_CUPTI_LAZY_REINIT", "1")
  209. # FIXME: CUDA Graph does not work well with CUPTI teardown.
  210. # 1) crashes on 1st lazy CUPTI re-init after teardown (CUDA 11)
  211. # 2) crashes on 2nd non-lazy CUPTI re-init after teardown (CUDA 12)
  212. # Workaround: turn off CUPTI teardown when using CUDA Graphs.
  213. os.environ["TEARDOWN_CUPTI"] = "0"
  214. # Insert the preset user metadata to the trace
  215. for k, v in self.preset_metadata.items():
  216. self.add_metadata_json(k, v)
  217. def stop_trace(self) -> None:
  218. if self.execution_trace_observer:
  219. self.execution_trace_observer.stop()
  220. assert self.profiler is not None
  221. self.profiler.__exit__(None, None, None)
  222. def export_chrome_trace(self, path: str):
  223. """
  224. Exports the collected trace in Chrome JSON format. If kineto is enabled, only
  225. last cycle in schedule is exported.
  226. """
  227. assert self.profiler
  228. if path.endswith(".gz"):
  229. fp = tempfile.NamedTemporaryFile("w+b", suffix=".json", delete=False)
  230. fp.close()
  231. retvalue = self.profiler.export_chrome_trace(fp.name)
  232. with open(fp.name, "rb") as fin:
  233. with gzip.open(path, "wb") as fout:
  234. fout.writelines(fin)
  235. os.remove(fp.name)
  236. return retvalue
  237. else:
  238. return self.profiler.export_chrome_trace(path)
  239. def export_stacks(self, path: str, metric: str = "self_cpu_time_total"):
  240. """Save stack traces to a file
  241. Args:
  242. path (str): save stacks file to this location;
  243. metric (str): metric to use: "self_cpu_time_total" or "self_cuda_time_total"
  244. """
  245. assert self.profiler
  246. return self.profiler.export_stacks(path, metric)
  247. def toggle_collection_dynamic(
  248. self, enable: bool, activities: Iterable[ProfilerActivity]
  249. ) -> None:
  250. """Toggle collection of activities on/off at any point of collection. Currently supports toggling Torch Ops
  251. (CPU) and CUDA activity supported in Kineto
  252. Args:
  253. activities (iterable): list of activity groups to use in profiling, supported values:
  254. ``torch.profiler.ProfilerActivity.CPU``, ``torch.profiler.ProfilerActivity.CUDA``
  255. Examples:
  256. .. code-block:: python
  257. with torch.profiler.profile(
  258. activities=[
  259. torch.profiler.ProfilerActivity.CPU,
  260. torch.profiler.ProfilerActivity.CUDA,
  261. ]
  262. ) as p:
  263. code_to_profile_0()
  264. // turn off collection of all CUDA activity
  265. p.toggle_collection_dynamic(False, [torch.profiler.ProfilerActivity.CUDA])
  266. code_to_profile_1()
  267. // turn on collection of all CUDA activity
  268. p.toggle_collection_dynamic(True, [torch.profiler.ProfilerActivity.CUDA])
  269. code_to_profile_2()
  270. print(p.key_averages().table(
  271. sort_by="self_cuda_time_total", row_limit=-1))
  272. """
  273. if not self.profiler:
  274. return
  275. self.profiler.toggle_collection_dynamic(enable, activities)
  276. def key_averages(
  277. self,
  278. group_by_input_shape: bool = False,
  279. group_by_stack_n: int = 0,
  280. group_by_overload_name: bool = False,
  281. ):
  282. """Averages events, grouping them by operator name and (optionally) input shapes, stack
  283. and overload name.
  284. .. note::
  285. To use shape/stack functionality make sure to set record_shapes/with_stack
  286. when creating profiler context manager.
  287. """
  288. assert self.profiler
  289. return self.profiler.key_averages(
  290. group_by_input_shape, group_by_stack_n, group_by_overload_name
  291. )
  292. def events(self):
  293. """
  294. Returns the list of unaggregated profiler events,
  295. to be used in the trace callback or after the profiling is finished
  296. """
  297. assert self.profiler
  298. return self.profiler.function_events
  299. def add_metadata(self, key: str, value: str) -> None:
  300. """
  301. Adds a user defined metadata with a string key and a string value
  302. into the trace file
  303. """
  304. wrapped_value = '"' + value.replace('"', '\\"') + '"'
  305. torch.autograd._add_metadata_json(key, wrapped_value)
  306. def add_metadata_json(self, key: str, value: str) -> None:
  307. """
  308. Adds a user defined metadata with a string key and a valid json value
  309. into the trace file
  310. """
  311. torch.autograd._add_metadata_json(key, value)
  312. def preset_metadata_json(self, key: str, value: str) -> None:
  313. """
  314. Preset a user defined metadata when the profiler is not started
  315. and added into the trace file later.
  316. Metadata is in the format of a string key and a valid json value
  317. """
  318. self.preset_metadata[key] = value
  319. def _get_distributed_info(self):
  320. import torch.distributed as dist
  321. if not dist.is_available() or not dist.is_initialized():
  322. return None
  323. backend = dist.get_backend()
  324. dist_info = {
  325. "backend": backend,
  326. "rank": dist.get_rank(),
  327. "world_size": dist.get_world_size(),
  328. "pg_count": dist.get_pg_count(),
  329. "pg_config": dist.distributed_c10d._get_all_pg_configs(),
  330. }
  331. if backend == "nccl":
  332. nccl_version = torch.cuda.nccl.version()
  333. dist_info["nccl_version"] = ".".join(str(v) for v in nccl_version)
  334. return dist_info
  335. def _memory_profile(self) -> MemoryProfile:
  336. required = ("record_shapes", "profile_memory", "with_stack")
  337. missing = [f"{i}=True" for i in required if not getattr(self, i)]
  338. if missing:
  339. raise ValueError(f"{', '.join(missing)} required for memory profiling.")
  340. assert self.profiler is not None and self.profiler.kineto_results is not None
  341. return MemoryProfile(self.profiler.kineto_results)
  342. def export_memory_timeline(self, path: str, device: Optional[str] = None) -> None:
  343. """Export memory event information from the profiler collected
  344. tree for a given device, and export a timeline plot. There are 3
  345. exportable files using ``export_memory_timeline``, each controlled by the
  346. ``path``'s suffix.
  347. - For an HTML compatible plot, use the suffix ``.html``, and a memory timeline
  348. plot will be embedded as a PNG file in the HTML file.
  349. - For plot points consisting of ``[times, [sizes by category]]``, where
  350. ``times`` are timestamps and ``sizes`` are memory usage for each category.
  351. The memory timeline plot will be saved a JSON (``.json``) or gzipped JSON
  352. (``.json.gz``) depending on the suffix.
  353. - For raw memory points, use the suffix ``.raw.json.gz``. Each raw memory
  354. event will consist of ``(timestamp, action, numbytes, category)``, where
  355. ``action`` is one of ``[PREEXISTING, CREATE, INCREMENT_VERSION, DESTROY]``,
  356. and ``category`` is one of the enums from
  357. ``torch.profiler._memory_profiler.Category``.
  358. Output: Memory timeline written as gzipped JSON, JSON, or HTML.
  359. """
  360. # Default to device 0, if unset. Fallback on cpu.
  361. if device is None:
  362. if self.use_device and self.use_device != "cuda":
  363. device = self.use_device + ":0"
  364. else:
  365. device = "cuda:0" if torch.cuda.is_available() else "cpu"
  366. # Construct the memory timeline plot data
  367. self.mem_tl = MemoryProfileTimeline(self._memory_profile())
  368. # Depending on the file suffix, save the data as json.gz or json.
  369. # For html, we can embed the image into an HTML file.
  370. if path.endswith(".html"):
  371. self.mem_tl.export_memory_timeline_html(path, device)
  372. elif path.endswith(".gz"):
  373. fp = tempfile.NamedTemporaryFile("w+t", suffix=".json", delete=False)
  374. fp.close()
  375. if path.endswith("raw.json.gz"):
  376. self.mem_tl.export_memory_timeline_raw(fp.name, device)
  377. else:
  378. self.mem_tl.export_memory_timeline(fp.name, device)
  379. with open(fp.name) as fin:
  380. with gzip.open(path, "wt") as fout:
  381. fout.writelines(fin)
  382. os.remove(fp.name)
  383. else:
  384. self.mem_tl.export_memory_timeline(path, device)
  385. class ProfilerAction(Enum):
  386. """
  387. Profiler actions that can be taken at the specified intervals
  388. """
  389. NONE = 0
  390. WARMUP = 1
  391. RECORD = 2
  392. RECORD_AND_SAVE = 3
  393. def schedule(
  394. *,
  395. wait: int,
  396. warmup: int,
  397. active: int,
  398. repeat: int = 0,
  399. skip_first: int = 0,
  400. skip_first_wait: int = 0,
  401. ) -> Callable:
  402. """
  403. Returns a callable that can be used as profiler ``schedule`` argument. The profiler will skip
  404. the first ``skip_first`` steps, then wait for ``wait`` steps, then do the warmup for the next ``warmup`` steps,
  405. then do the active recording for the next ``active`` steps and then repeat the cycle starting with ``wait`` steps.
  406. The optional number of cycles is specified with the ``repeat`` parameter, the zero value means that
  407. the cycles will continue until the profiling is finished.
  408. The ``skip_first_wait`` parameter controls whether the first ``wait`` stage should be skipped.
  409. This can be useful if a user wants to wait longer than ``skip_first`` between cycles, but not
  410. for the first profile. For example, if ``skip_first`` is 10 and ``wait`` is 20, the first cycle will
  411. wait 10 + 20 = 30 steps before warmup if ``skip_first_wait`` is zero, but will wait only 10
  412. steps if ``skip_first_wait`` is non-zero. All subsequent cycles will then wait 20 steps between the
  413. last active and warmup.
  414. """
  415. def schedule_fn(step: int) -> ProfilerAction:
  416. assert step >= 0
  417. if step < skip_first:
  418. return ProfilerAction.NONE
  419. else:
  420. step -= skip_first
  421. # If wait >> skip_first and we want to grab profiling early, shift left by wait if skip_first_wait is True
  422. if skip_first_wait != 0:
  423. step += wait
  424. num_steps = wait + warmup + active
  425. if repeat > 0 and step / num_steps >= repeat:
  426. return ProfilerAction.NONE
  427. mod_step = step % num_steps
  428. if mod_step < wait:
  429. return ProfilerAction.NONE
  430. elif mod_step < wait + warmup:
  431. return ProfilerAction.WARMUP
  432. else:
  433. return (
  434. ProfilerAction.RECORD
  435. if mod_step < num_steps - 1
  436. else ProfilerAction.RECORD_AND_SAVE
  437. )
  438. assert (
  439. wait >= 0 and warmup >= 0 and active > 0 and repeat >= 0 and skip_first >= 0
  440. ), "Invalid profiler schedule arguments"
  441. if warmup == 0:
  442. warn("Profiler won't be using warmup, this can skew profiler results")
  443. return schedule_fn
  444. def _default_schedule_fn(_: int) -> ProfilerAction:
  445. """
  446. Default profiler behavior - immediately starts recording the events,
  447. keeps doing it on every profiler step.
  448. """
  449. return ProfilerAction.RECORD
  450. def tensorboard_trace_handler(
  451. dir_name: str, worker_name: Optional[str] = None, use_gzip: bool = False
  452. ):
  453. """
  454. Outputs tracing files to directory of ``dir_name``, then that directory can be
  455. directly delivered to tensorboard as logdir.
  456. ``worker_name`` should be unique for each worker in distributed scenario,
  457. it will be set to '[hostname]_[pid]' by default.
  458. """
  459. import socket
  460. import time
  461. def handler_fn(prof) -> None:
  462. nonlocal worker_name
  463. if not os.path.isdir(dir_name):
  464. try:
  465. os.makedirs(dir_name, exist_ok=True)
  466. except Exception as e:
  467. raise RuntimeError("Can't create directory: " + dir_name) from e
  468. if not worker_name:
  469. worker_name = f"{socket.gethostname()}_{os.getpid()}"
  470. # Use nanosecond here to avoid naming clash when exporting the trace
  471. file_name = f"{worker_name}.{time.time_ns()}.pt.trace.json"
  472. if use_gzip:
  473. file_name = file_name + ".gz"
  474. prof.export_chrome_trace(os.path.join(dir_name, file_name))
  475. return handler_fn
  476. class profile(_KinetoProfile):
  477. """Profiler context manager.
  478. Args:
  479. activities (iterable): list of activity groups (CPU, CUDA) to use in profiling, supported values:
  480. ``torch.profiler.ProfilerActivity.CPU``, ``torch.profiler.ProfilerActivity.CUDA``,
  481. ``torch.profiler.ProfilerActivity.XPU``.
  482. Default value: ProfilerActivity.CPU and (when available) ProfilerActivity.CUDA
  483. or (when available) ProfilerActivity.XPU.
  484. schedule (Callable): callable that takes step (int) as a single parameter and returns
  485. ``ProfilerAction`` value that specifies the profiler action to perform at each step.
  486. on_trace_ready (Callable): callable that is called at each step when ``schedule``
  487. returns ``ProfilerAction.RECORD_AND_SAVE`` during the profiling.
  488. record_shapes (bool): save information about operator's input shapes.
  489. profile_memory (bool): track tensor memory allocation/deallocation.
  490. with_stack (bool): record source information (file and line number) for the ops.
  491. with_flops (bool): use formula to estimate the FLOPs (floating point operations) of specific operators
  492. (matrix multiplication and 2D convolution).
  493. with_modules (bool): record module hierarchy (including function names)
  494. corresponding to the callstack of the op. e.g. If module A's forward call's
  495. module B's forward which contains an aten::add op,
  496. then aten::add's module hierarchy is A.B
  497. Note that this support exist, at the moment, only for TorchScript models
  498. and not eager mode models.
  499. experimental_config (_ExperimentalConfig) : A set of experimental options
  500. used for Kineto library features. Note, backward compatibility is not guaranteed.
  501. execution_trace_observer (ExecutionTraceObserver) : A PyTorch Execution Trace Observer object.
  502. `PyTorch Execution Traces <https://arxiv.org/pdf/2305.14516.pdf>`__ offer a graph based
  503. representation of AI/ML workloads and enable replay benchmarks, simulators, and emulators.
  504. When this argument is included the observer start() and stop() will be called for the
  505. same time window as PyTorch profiler. See the examples section below for a code sample.
  506. acc_events (bool): Enable the accumulation of FunctionEvents across multiple profiling cycles
  507. use_cuda (bool):
  508. .. deprecated:: 1.8.1
  509. use ``activities`` instead.
  510. .. note::
  511. Use :func:`~torch.profiler.schedule` to generate the callable schedule.
  512. Non-default schedules are useful when profiling long training jobs
  513. and allow the user to obtain multiple traces at the different iterations
  514. of the training process.
  515. The default schedule simply records all the events continuously for the
  516. duration of the context manager.
  517. .. note::
  518. Use :func:`~torch.profiler.tensorboard_trace_handler` to generate result files for TensorBoard:
  519. ``on_trace_ready=torch.profiler.tensorboard_trace_handler(dir_name)``
  520. After profiling, result files can be found in the specified directory. Use the command:
  521. ``tensorboard --logdir dir_name``
  522. to see the results in TensorBoard.
  523. For more information, see
  524. `PyTorch Profiler TensorBoard Plugin <https://github.com/pytorch/kineto/tree/master/tb_plugin>`__
  525. .. note::
  526. Enabling shape and stack tracing results in additional overhead.
  527. When record_shapes=True is specified, profiler will temporarily hold references to the tensors;
  528. that may further prevent certain optimizations that depend on the reference count and introduce
  529. extra tensor copies.
  530. Examples:
  531. .. code-block:: python
  532. with torch.profiler.profile(
  533. activities=[
  534. torch.profiler.ProfilerActivity.CPU,
  535. torch.profiler.ProfilerActivity.CUDA,
  536. ]
  537. ) as p:
  538. code_to_profile()
  539. print(p.key_averages().table(sort_by="self_cuda_time_total", row_limit=-1))
  540. Using the profiler's ``schedule``, ``on_trace_ready`` and ``step`` functions:
  541. .. code-block:: python
  542. # Non-default profiler schedule allows user to turn profiler on and off
  543. # on different iterations of the training loop;
  544. # trace_handler is called every time a new trace becomes available
  545. def trace_handler(prof):
  546. print(
  547. prof.key_averages().table(sort_by="self_cuda_time_total", row_limit=-1)
  548. )
  549. # prof.export_chrome_trace("/tmp/test_trace_" + str(prof.step_num) + ".json")
  550. with torch.profiler.profile(
  551. activities=[
  552. torch.profiler.ProfilerActivity.CPU,
  553. torch.profiler.ProfilerActivity.CUDA,
  554. ],
  555. # In this example with wait=1, warmup=1, active=2, repeat=1,
  556. # profiler will skip the first step/iteration,
  557. # start warming up on the second, record
  558. # the third and the forth iterations,
  559. # after which the trace will become available
  560. # and on_trace_ready (when set) is called;
  561. # the cycle repeats starting with the next step
  562. schedule=torch.profiler.schedule(wait=1, warmup=1, active=2, repeat=1),
  563. on_trace_ready=trace_handler,
  564. # on_trace_ready=torch.profiler.tensorboard_trace_handler('./log')
  565. # used when outputting for tensorboard
  566. ) as p:
  567. for iter in range(N):
  568. code_iteration_to_profile(iter)
  569. # send a signal to the profiler that the next iteration has started
  570. p.step()
  571. The following sample shows how to setup up an Execution Trace Observer (`execution_trace_observer`)
  572. .. code-block:: python
  573. with torch.profiler.profile(
  574. ...
  575. execution_trace_observer=(
  576. ExecutionTraceObserver().register_callback("./execution_trace.json")
  577. ),
  578. ) as p:
  579. for iter in range(N):
  580. code_iteration_to_profile(iter)
  581. p.step()
  582. You can also refer to test_execution_trace_with_kineto() in tests/profiler/test_profiler.py.
  583. Note: One can also pass any object satisfying the _ITraceObserver interface.
  584. """
  def __init__(
      self,
      *,
      activities: Optional[Iterable[ProfilerActivity]] = None,
      schedule: Optional[Callable[[int], ProfilerAction]] = None,
      on_trace_ready: Optional[Callable[..., Any]] = None,
      record_shapes: bool = False,
      profile_memory: bool = False,
      with_stack: bool = False,
      with_flops: bool = False,
      with_modules: bool = False,
      experimental_config: Optional[_ExperimentalConfig] = None,
      execution_trace_observer: Optional[_ITraceObserver] = None,
      acc_events: bool = False,
      # deprecated:
      use_cuda: Optional[bool] = None,
      custom_trace_id_callback: Optional[Callable[[], str]] = None,
  ) -> None:
      """
      Resolves the activity set, initializes the parent profiler and builds
      the schedule-driven state machine.

      The set of activities is ``activities`` (or ``supported_activities()``
      when it is empty/None), adjusted by the deprecated ``use_cuda`` flag:
      ``True`` adds ``ProfilerActivity.CUDA`` and ``False`` removes it.
      When no ``execution_trace_observer`` is given, one is built from the
      environment via
      ``ExecutionTraceObserver.build_execution_trace_obs_from_env()``.

      Raises:
          AssertionError: if the resolved activity set is empty.
      """
      activities_set = set(activities) if activities else supported_activities()
      if use_cuda is not None:
          warn(
              "`use_cuda` is deprecated, use `activities` argument instead",
              FutureWarning,
              stacklevel=2,
          )
          if use_cuda:
              activities_set.add(ProfilerActivity.CUDA)
          elif ProfilerActivity.CUDA in activities_set:
              activities_set.remove(ProfilerActivity.CUDA)
      assert len(activities_set) > 0, "No valid profiler activities found"

      super().__init__(
          # Pass the resolved set rather than the raw argument; otherwise the
          # `use_cuda` adjustment made above would be silently discarded.
          activities=activities_set,
          record_shapes=record_shapes,
          profile_memory=profile_memory,
          with_stack=with_stack,
          with_flops=with_flops,
          with_modules=with_modules,
          experimental_config=experimental_config,
          execution_trace_observer=execution_trace_observer
          if execution_trace_observer
          else ExecutionTraceObserver.build_execution_trace_obs_from_env(),
          acc_events=acc_events,
          custom_trace_id_callback=custom_trace_id_callback,
      )

      if schedule:
          self.schedule = schedule
          # add step markers into the trace and table view
          self.record_steps = True
      else:
          self.schedule = _default_schedule_fn
          self.record_steps = False
      self.on_trace_ready = on_trace_ready
      self.step_num = 0
      self.current_action = self.schedule(self.step_num)
      # record_function marker for the step currently in flight, if any
      self.step_rec_fn: Optional[prof.record_function] = None

      self.action_map: dict[
          tuple[ProfilerAction, Optional[ProfilerAction]], list[Any]
      ] = {
          # key is (prev_action, current_action), value is action list corresponding to the state pair.
          (ProfilerAction.NONE, ProfilerAction.NONE): [],
          (ProfilerAction.NONE, ProfilerAction.WARMUP): [self.prepare_trace],
          (ProfilerAction.NONE, ProfilerAction.RECORD): [
              self.prepare_trace,
              self.start_trace,
          ],
          (ProfilerAction.NONE, ProfilerAction.RECORD_AND_SAVE): [
              self.prepare_trace,
              self.start_trace,
          ],
          (ProfilerAction.WARMUP, ProfilerAction.NONE): [
              partial(warn, "Incorrect schedule: WARMUP followed by NONE"),
              self.start_trace,
              self.stop_trace,
          ],
          (ProfilerAction.WARMUP, ProfilerAction.WARMUP): [],
          (ProfilerAction.WARMUP, ProfilerAction.RECORD): [self.start_trace],
          (ProfilerAction.WARMUP, ProfilerAction.RECORD_AND_SAVE): [self.start_trace],
          (ProfilerAction.RECORD, ProfilerAction.NONE): [
              partial(warn, "Incorrect schedule: RECORD followed by NONE"),
              self.stop_trace,
          ],
          (ProfilerAction.RECORD, ProfilerAction.WARMUP): [
              partial(warn, "Incorrect schedule: RECORD followed by WARMUP"),
              self.stop_trace,
          ],
          (ProfilerAction.RECORD, ProfilerAction.RECORD): [],
          (ProfilerAction.RECORD, ProfilerAction.RECORD_AND_SAVE): [],
          (ProfilerAction.RECORD_AND_SAVE, ProfilerAction.NONE): [
              self.stop_trace,
              self._trace_ready,
          ],
          (ProfilerAction.RECORD_AND_SAVE, ProfilerAction.WARMUP): [
              self.stop_trace,
              self._trace_ready,
              self.prepare_trace,
          ],
          (ProfilerAction.RECORD_AND_SAVE, ProfilerAction.RECORD): [
              self.stop_trace,
              self._trace_ready,
              self.prepare_trace,
              self.start_trace,
          ],
          (ProfilerAction.RECORD_AND_SAVE, ProfilerAction.RECORD_AND_SAVE): [
              self.stop_trace,
              self._trace_ready,
              self.prepare_trace,
              self.start_trace,
          ],
          # used for exit action
          (ProfilerAction.WARMUP, None): [self.start_trace, self.stop_trace],
          (ProfilerAction.RECORD, None): [self.stop_trace, self._trace_ready],
          (ProfilerAction.RECORD_AND_SAVE, None): [
              self.stop_trace,
              self._trace_ready,
          ],
      }
      # Start tracking increments to profiler step, this will be used
      # by Kineto
      prof.KinetoStepTracker.init_step_count(PROFILER_STEP_NAME)
  704. def __enter__(self):
  705. self.start()
  706. return self
  def __exit__(self, exc_type, exc_val, exc_tb):
      """
      Context-manager exit: stops profiling, erases the Kineto step count
      registered under ``PROFILER_STEP_NAME`` and cleans up the execution
      trace observer when one is attached.

      The exception arguments are not inspected and nothing is returned, so
      an exception raised inside the ``with`` body is not suppressed.
      """
      try:
          self.stop()
      finally:
          # Run the teardown even when stop() raises, so a failed stop does
          # not leave the step counter registered or the observer attached.
          prof.KinetoStepTracker.erase_step_count(PROFILER_STEP_NAME)
          if self.execution_trace_observer:
              self.execution_trace_observer.cleanup()
  def start(self) -> None:
      """
      Runs the transition from ``ProfilerAction.NONE`` to the current action
      and, when step recording is enabled, opens the ``ProfilerStep#<n>``
      record_function marker for the current step.
      """
      self._transit_action(ProfilerAction.NONE, self.current_action)
      if not self.record_steps:
          return
      step_marker = prof.record_function(f"ProfilerStep#{self.step_num}")
      self.step_rec_fn = step_marker
      step_marker.__enter__()
  def stop(self) -> None:
      """
      Closes the open step marker (if step recording is on and a marker
      exists), then runs the exit transition from the current action to
      ``None``.
      """
      if self.record_steps:
          open_marker = self.step_rec_fn
          if open_marker:
              open_marker.__exit__(None, None, None)
      self._transit_action(self.current_action, None)
  def step(self) -> None:
      """
      Signals the profiler that the next profiling step has started.

      Closes the previous step marker, advances ``step_num``, asks the
      schedule for the new action and runs the matching transition. The
      Kineto step tracker is incremented only when ``KINETO_USE_DAEMON`` is
      set, or when running in fbcode with ``KINETO_FORCE_STEP_HOOK`` set.
      A new step marker is opened afterwards when step recording is on.
      """
      if self.record_steps:
          finished_marker = self.step_rec_fn
          if finished_marker:
              finished_marker.__exit__(None, None, None)

      previous = self.current_action
      self.step_num += 1
      upcoming = self.schedule(self.step_num)
      self.current_action = upcoming
      self._transit_action(previous, upcoming)

      # Same short-circuit order as a single `or`/`and` expression:
      # is_fbcode() is consulted only when the daemon variable is unset/empty.
      notify_kineto = os.environ.get("KINETO_USE_DAEMON", "")
      if not notify_kineto:
          notify_kineto = is_fbcode() and os.environ.get(
              "KINETO_FORCE_STEP_HOOK", ""
          )
      if notify_kineto:
          prof.KinetoStepTracker.increment_step(PROFILER_STEP_NAME)

      if self.record_steps:
          next_marker = prof.record_function(f"ProfilerStep#{self.step_num}")
          self.step_rec_fn = next_marker
          next_marker.__enter__()
  def set_custom_trace_id_callback(self, callback) -> None:
      """
      Sets a callback to be called when a new trace ID is generated.

      The callback is stored on ``self.custom_trace_id_callback``, replacing
      any previously stored one.
      """
      self.custom_trace_id_callback = callback
  def get_trace_id(self):
      """
      Returns the current trace ID.

      Yields ``None`` while ``self.profiler`` is ``None``; otherwise the
      ``trace_id`` attribute of that profiler object.
      """
      active = self.profiler
      return None if active is None else active.trace_id
  754. def _trace_ready(self) -> None:
  755. if self.on_trace_ready:
  756. self.on_trace_ready(self)
  757. def _transit_action(self, prev_action, current_action) -> None:
  758. action_list = self.action_map.get((prev_action, current_action))
  759. if action_list:
  760. for action in action_list:
  761. action()
  762. def _stats(self) -> Optional[prof._ProfilerStats]:
  763. if self.profiler is None:
  764. return None
  765. return self.profiler._stats
  class ExecutionTraceObserver(_ITraceObserver):
      """Execution Trace Observer

      Each process can have a single ExecutionTraceObserver instance. The observer
      can be added to record function callbacks via calling register_callback()
      explicitly. Without calling unregister_callback(), repeated calls to
      register_callback() will not add additional observers to record function
      callbacks. Once an ExecutionTraceObserver is created, the start() and stop()
      methods control when the event data is recorded.

      Deleting or calling unregister_callback() will remove the observer from the
      record function callbacks, finalize the output file, and will stop
      incurring any overheads.
      """

      def __init__(self) -> None:
          """
          Initializes the default states.
          """
          # True once register_callback() has added the observer.
          self._registered = False
          # True between start() and stop().
          self._execution_trace_running = False
          # Whether extra resources (see set_extra_resource_collection) are collected.
          self.extra_resources_collection = False
          # Cached resources directory; empty until get_resources_dir() resolves it.
          self.resources_dir: str = ""
          # Path requested by the caller (may end in ".gz").
          self.output_file_path: str = ""
          # Path handed to the observer; a temporary uncompressed file when the
          # requested path ends in ".gz", otherwise the requested path itself.
          self.output_file_path_observer: str = ""

      def __del__(self) -> None:
          """
          Calls unregister_callback() to make sure to finalize outputs.
          """
          self.unregister_callback()

      @staticmethod
      def build_execution_trace_obs_from_env() -> Optional["ExecutionTraceObserver"]:
          """
          Returns an ExecutionTraceObserver instance if the environment variable
          ENABLE_PYTORCH_EXECUTION_TRACE is set to 1, otherwise returns None.

          Configures the observer to also collect extra resources if the environment variable
          ``ENABLE_PYTORCH_EXECUTION_TRACE_EXTRAS=1``. These are resources such as generated kernels,
          index tensor data etc. that are required to make the Execution Trace replayable.

          The trace is written to a freshly created temporary ``*.et.json`` file;
          if that file cannot be created, a warning is issued and None is returned.
          """
          if os.environ.get("ENABLE_PYTORCH_EXECUTION_TRACE", "0") != "1":
              return None

          try:
              fp = tempfile.NamedTemporaryFile("w+t", suffix=".et.json", delete=False)
          except Exception as e:
              warn(
                  f"Execution trace will not be recorded. Exception on creating default temporary file: {e}"
              )
              return None
          # Only the name is needed; the observer writes to the path itself.
          fp.close()

          et = ExecutionTraceObserver()
          et.register_callback(fp.name)
          # additionally, check if the env requires us to collect extra resources
          et.set_extra_resource_collection(
              os.environ.get("ENABLE_PYTORCH_EXECUTION_TRACE_EXTRAS", "0") == "1"
          )
          return et

      def set_extra_resource_collection(self, val) -> None:
          """
          Collects extra resources such as generated kernels, index tensor data, and any other
          metadata that is required to complete the Execution Trace content.

          The caller should call this method with val=True after calling register_callback() if they want
          to collect the extra resources.
          """
          self.extra_resources_collection = val
          if self.extra_resources_collection:
              # Resolve (and create if needed) the resources directory up front.
              self.get_resources_dir(can_create=True)

      def register_callback(self, output_file_path: str) -> Self:
          """
          Adds ET observer to record function callbacks. The data will be
          written to output_file_path.

          When output_file_path ends with ".gz" the observer writes to a temporary
          uncompressed file, which unregister_callback() compresses into
          output_file_path. If the observer is already registered the call does
          nothing. Returns self.
          """

          def get_temp_uncompressed_file() -> str:
              fp = tempfile.NamedTemporaryFile("w+b", suffix=".json", delete=False)
              fp.close()
              return fp.name

          if not self._registered:
              self.output_file_path = output_file_path
              if output_file_path.endswith(".gz"):
                  output_file_path = get_temp_uncompressed_file()
              self.output_file_path_observer = output_file_path
              self._registered = _add_execution_trace_observer(output_file_path)
          return self

      def get_resources_dir(self, can_create=False) -> Optional[str]:
          """
          Generates the resources directory for the generated kernels,
          or index tensor data or any other metadata that is required
          to complete the Execution Trace content.

          The directory is created right where the ET file is being output.

          Only works if the observer has called set_extra_resource_collection(val=True).

          Returns None if the observer is not configured with extra resource collection.
          """
          if not self.extra_resources_collection:
              return None
          if self.resources_dir:
              # already created
              return self.resources_dir
          generated_path = ExecutionTraceObserver.get_resources_dir_for_et_path(
              self.output_file_path, create_dir=can_create
          )
          if not generated_path:
              # could not find or create the resources dir
              return None
          self.resources_dir = generated_path
          return self.resources_dir

      @staticmethod
      def get_resources_dir_for_et_path(
          trace_path, create_dir: bool = False
      ) -> Optional[str]:
          """
          Returns the ``<trace file stem>_resources`` directory located next to
          trace_path.

          If the directory does not exist it is created only when create_dir is
          True. Returns None when the directory is missing and not created, or
          when creating it fails (a warning is issued in that case).
          """
          work_dir, file_name = os.path.split(trace_path)
          resource_dir = os.path.join(
              work_dir, os.path.splitext(file_name)[0] + "_resources"
          )
          if os.path.exists(resource_dir):
              return resource_dir
          if not create_dir:
              return None
          try:
              os.mkdir(resource_dir)
          except Exception:
              warn(f"Execution trace exception when creating {resource_dir}")
              return None
          return resource_dir

      def unregister_callback(self) -> None:
          """
          Removes ET observer from record function callbacks.

          When registered, this stops capturing, copies generated kernel files
          into the resources directory on a best-effort basis (failures only
          warn), removes the observer and, for a ".gz" output path, compresses
          the temporary trace file into the requested output file. Does nothing
          when the observer is not registered.
          """

          def _save_triton_kernels() -> None:
              try:
                  resource_dir = self.get_resources_dir()
              except Exception as e:
                  warn(
                      f"Execution trace exception when generating resource directory: {e}"
                  )
                  return
              if not resource_dir:
                  return

              # Save the kernel paths for the generated kernels
              from torch._inductor.codecache import PyCodeCache as PyCodeCache

              kernel_files = [
                  v.__file__
                  for v in PyCodeCache.modules
                  if getattr(v, "__file__", None) is not None
              ]

              for kernel_file in kernel_files:
                  # The comprehension already drops None entries; this guard is
                  # redundant at runtime.
                  if kernel_file is None:
                      continue

                  name = os.path.basename(kernel_file)
                  dst = os.path.join(resource_dir, name)
                  shutil.copyfile(kernel_file, dst)

          def _save_gz_file(uncompressed_file: str, output_file: str) -> None:
              print(f"Execution Trace: compressing {uncompressed_file} to {output_file}")
              with open(uncompressed_file, "rb") as fin:
                  with gzip.open(output_file, "wb") as fout:
                      # Chunked copy: does not depend on where line breaks fall
                      # in the trace file.
                      shutil.copyfileobj(fin, fout)
              os.remove(uncompressed_file)

          if self._registered:
              self.stop()

              try:
                  _save_triton_kernels()
              except Exception as e:
                  warn(f"Execution trace failed to save kernels: {e}")

              _remove_execution_trace_observer()
              # Must be the same ".gz" test register_callback() uses: only then
              # is output_file_path_observer a separate temporary file. A looser
              # test would compress the trace onto itself and then delete it.
              if self.output_file_path.endswith(".gz"):
                  _save_gz_file(self.output_file_path_observer, self.output_file_path)

              self._registered = False

      @property
      def is_registered(self):
          """
          Returns True if the execution trace observer is registered, otherwise False.
          """
          return self._registered

      def is_running(self):
          """
          Returns True if the observer is running, otherwise False.
          """
          return self._execution_trace_running

      def start(self) -> None:
          """
          Starts to capture.

          Has no effect unless the observer is registered and not already
          running. After enabling, the process-group config is recorded.
          """
          if self._registered and not self._execution_trace_running:
              _enable_execution_trace_observer()
              self._execution_trace_running = True
              self._record_pg_config()

      def stop(self) -> None:
          """
          Stops to capture.
          """
          if self._execution_trace_running:
              _disable_execution_trace_observer()
              self._execution_trace_running = False

      def cleanup(self) -> None:
          """
          Calls unregister_callback() to make sure to finalize outputs.
          """
          self.unregister_callback()

      def get_output_file_path(self) -> Optional[str]:
          """
          Returns the output file name or None.
          """
          return self.output_file_path or None

      def _record_pg_config(self) -> None:
          # Records the PG config info to the trace as node:
          #  ## process_group:init ##
          if (
              self.is_registered
              and torch.distributed.is_available()
              and torch.distributed.is_initialized()
          ):
              pg_config_info = torch.distributed.distributed_c10d._world.pg_config_info
              torch.autograd._record_function_with_args_enter(
                  "## process_group:init ##",
                  json.dumps(pg_config_info, cls=_NumpyEncoder),
              )