profiler_legacy.py 11 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312
  1. # mypy: allow-untyped-defs
  2. import itertools
  3. import warnings
  4. from typing_extensions import deprecated
  5. import torch
  6. import torch.cuda
  7. from torch.autograd import (
  8. _disable_profiler_legacy,
  9. _enable_profiler_legacy,
  10. DeviceType,
  11. ProfilerConfig,
  12. ProfilerState,
  13. )
  14. from torch.autograd.profiler_util import (
  15. _filter_name,
  16. _filter_stack_entry,
  17. _rewrite_name,
  18. EventList,
  19. FunctionEvent,
  20. MEMORY_EVENT_NAME,
  21. )
  22. __all__ = ["profile"]
  23. @deprecated(
  24. "`torch.autograd.profiler_legacy.profile` is deprecated and will be removed in a future release. "
  25. "Please use `torch.profiler` instead.",
  26. category=None, # TODO: change to `FutureWarning`
  27. )
  28. class profile:
  29. """DEPRECATED: use torch.profiler instead."""
  30. def __init__(
  31. self,
  32. enabled=True,
  33. *,
  34. use_cuda=False,
  35. record_shapes=False,
  36. with_flops=False,
  37. profile_memory=False,
  38. with_stack=False,
  39. with_modules=False,
  40. ):
  41. self.enabled: bool = enabled
  42. if not self.enabled:
  43. return
  44. self.use_cuda = use_cuda
  45. self.function_events = None
  46. self.entered = False
  47. self.record_shapes = record_shapes
  48. self.with_flops = with_flops
  49. self.record_shapes |= self.with_flops
  50. self.profile_memory = profile_memory
  51. self.with_stack = with_stack
  52. self.with_modules = with_modules
  53. if self.use_cuda and not torch.cuda.is_available():
  54. warnings.warn(
  55. "CUDA is not available, disabling CUDA profiling",
  56. stacklevel=2,
  57. )
  58. self.use_cuda = False
  59. if self.use_cuda:
  60. self.profiler_kind = ProfilerState.CUDA
  61. else:
  62. self.profiler_kind = ProfilerState.CPU
  63. def config(self):
  64. return ProfilerConfig(
  65. self.profiler_kind,
  66. self.record_shapes,
  67. self.profile_memory,
  68. self.with_stack,
  69. self.with_flops,
  70. self.with_modules,
  71. # avoid exposing _ExperimentalConfig this in legacy public API
  72. torch._C._profiler._ExperimentalConfig(),
  73. )
  74. def __enter__(self):
  75. if not self.enabled:
  76. return
  77. if self.entered:
  78. raise RuntimeError("Profiler context manager is not reentrant")
  79. self.entered = True
  80. self._start_trace()
  81. return self
  82. def _start_trace(self):
  83. _enable_profiler_legacy(self.config())
  84. def __exit__(self, exc_type, exc_val, exc_tb):
  85. if not self.enabled:
  86. return
  87. if self.use_cuda:
  88. torch.cuda.synchronize()
  89. records = _disable_profiler_legacy()
  90. parsed_results = _parse_legacy_records(records)
  91. self.function_events = EventList(
  92. parsed_results,
  93. use_device="cuda" if self.use_cuda else None,
  94. profile_memory=self.profile_memory,
  95. with_flops=self.with_flops,
  96. )
  97. self.function_events._build_tree()
  98. return False
  99. def __repr__(self):
  100. if self.function_events is None:
  101. return "<unfinished profiler_legacy.profile>"
  102. return repr(self.function_events)
  103. def __str__(self):
  104. if self.function_events is None:
  105. return "<unfinished profile.profiler_legacy.profile>"
  106. return str(self.function_events)
  107. def _check_finish(self):
  108. if self.function_events is None:
  109. raise RuntimeError("Profiler didn't finish running")
  110. def table(
  111. self,
  112. sort_by=None,
  113. row_limit=100,
  114. max_src_column_width=75,
  115. max_name_column_width=55,
  116. max_shapes_column_width=80,
  117. header=None,
  118. top_level_events_only=False,
  119. ):
  120. self._check_finish()
  121. assert self.function_events is not None
  122. return self.function_events.table(
  123. sort_by=sort_by,
  124. row_limit=row_limit,
  125. max_src_column_width=max_src_column_width,
  126. max_name_column_width=max_name_column_width,
  127. max_shapes_column_width=max_shapes_column_width,
  128. header=header,
  129. top_level_events_only=top_level_events_only,
  130. )
  131. table.__doc__ = EventList.table.__doc__
  132. def export_chrome_trace(self, path):
  133. self._check_finish()
  134. assert self.function_events is not None
  135. return self.function_events.export_chrome_trace(path)
  136. export_chrome_trace.__doc__ = EventList.export_chrome_trace.__doc__
  137. def export_stacks(self, path: str, metric: str = "self_cpu_time_total"):
  138. self._check_finish()
  139. assert self.function_events is not None, "Expected profiling results"
  140. assert self.with_stack, "export_stacks() requires with_stack=True"
  141. return self.function_events.export_stacks(path, metric)
  142. def key_averages(self, group_by_input_shape=False, group_by_stack_n=0):
  143. self._check_finish()
  144. assert self.function_events is not None, "Expected profiling results"
  145. return self.function_events.key_averages(group_by_input_shape, group_by_stack_n)
  146. key_averages.__doc__ = EventList.key_averages.__doc__
  147. def total_average(self):
  148. self._check_finish()
  149. assert self.function_events is not None, "Expected profiling results"
  150. return self.function_events.total_average()
  151. total_average.__doc__ = EventList.total_average.__doc__
  152. @property
  153. def self_cpu_time_total(self):
  154. """Return CPU time as the sum of self times across all events."""
  155. self._check_finish()
  156. assert self.function_events is not None
  157. return self.function_events.self_cpu_time_total
  158. def _parse_legacy_records(thread_records):
  159. def _get_record_key(record):
  160. """Return a tuple for correlating start and end records in `_parse_legacy_records`."""
  161. return (record.handle(), record.node_id())
  162. start_record = None
  163. functions = []
  164. # '__start_profile' is not guaranteed to be first, so we must find it here
  165. for record in itertools.chain.from_iterable(thread_records):
  166. name = record.name()
  167. if start_record is None and name == "__start_profile":
  168. start_record = record
  169. assert start_record is not None and not start_record.is_remote()
  170. for thread_record_list in thread_records:
  171. # accumulated memory allocations per handle
  172. cpu_memory_allocs = {}
  173. cuda_memory_allocs = {}
  174. # ranges per handle
  175. range_starts = {}
  176. filtered_handles = set()
  177. prev_record = None
  178. for record in thread_record_list:
  179. record_key = _get_record_key(record)
  180. if _filter_name(record.name()) or record_key in filtered_handles:
  181. filtered_handles.add(record_key)
  182. continue
  183. if record.kind() == "push":
  184. # workaround to reduce double logging from operator
  185. # wrappers and redispatch
  186. if prev_record is not None:
  187. duplicate = (
  188. prev_record.name() == record.name()
  189. and prev_record.kind() == record.kind()
  190. and prev_record.node_id() == record.node_id()
  191. )
  192. if duplicate:
  193. filtered_handles.add(record_key)
  194. continue
  195. range_starts[record_key] = record
  196. cpu_memory_allocs[record_key] = 0
  197. cuda_memory_allocs[record_key] = 0
  198. elif record.kind() == "pop":
  199. assert (
  200. record_key in range_starts
  201. ), f"""Expected record with key {record_key} to exist in range_starts.
  202. This means that the pop event did not have a corresponding push."""
  203. start = range_starts[record_key]
  204. cpu_memory_usage = cpu_memory_allocs[record_key]
  205. cuda_memory_usage = cuda_memory_allocs[record_key]
  206. is_async = start.is_async() or (start.thread_id() != record.thread_id())
  207. is_remote_event = record.is_remote()
  208. start_flops = start.flops()
  209. fe = FunctionEvent(
  210. id=record.handle(),
  211. node_id=record.node_id(),
  212. name=_rewrite_name(name=start.name(), with_wildcard=True),
  213. trace_name=_rewrite_name(name=start.name(), with_wildcard=False),
  214. thread=start.thread_id(),
  215. start_us=start_record.cpu_elapsed_us(start),
  216. end_us=start_record.cpu_elapsed_us(record),
  217. fwd_thread=start.fwd_thread_id(),
  218. input_shapes=start.shapes(),
  219. stack=[
  220. entry for entry in start.stack() if _filter_stack_entry(entry)
  221. ],
  222. scope=start.scope(),
  223. use_device="cuda" if start.has_cuda() else None,
  224. cpu_memory_usage=cpu_memory_usage,
  225. device_memory_usage=cuda_memory_usage,
  226. is_async=is_async,
  227. is_remote=is_remote_event,
  228. sequence_nr=start.sequence_nr(),
  229. device_type=DeviceType.CPU,
  230. is_legacy=True,
  231. flops=start_flops,
  232. )
  233. # note: async events have only cpu total time
  234. if not is_async and start.has_cuda():
  235. duration = start.cuda_elapsed_us(record)
  236. if duration > 0:
  237. fe.append_kernel(start.name(), start.device(), duration)
  238. functions.append(fe)
  239. del range_starts[record_key]
  240. del cpu_memory_allocs[record_key]
  241. del cuda_memory_allocs[record_key]
  242. elif record.kind() == "memory_alloc":
  243. num_open_handles_cpu = len(cpu_memory_allocs)
  244. num_open_handles_cuda = len(cuda_memory_allocs)
  245. assert num_open_handles_cpu == num_open_handles_cuda
  246. for handle in cpu_memory_allocs.keys():
  247. cpu_memory_allocs[handle] += record.cpu_memory_usage()
  248. for handle in cuda_memory_allocs.keys():
  249. cuda_memory_allocs[handle] += record.cuda_memory_usage()
  250. if num_open_handles_cpu == 0:
  251. # output event as a top-level memory event
  252. fe = FunctionEvent(
  253. id=0,
  254. name=MEMORY_EVENT_NAME,
  255. trace_name=None,
  256. thread=0,
  257. start_us=0,
  258. end_us=0,
  259. stack=[],
  260. cpu_memory_usage=record.cpu_memory_usage(),
  261. device_memory_usage=record.cuda_memory_usage(),
  262. is_legacy=True,
  263. )
  264. functions.append(fe)
  265. prev_record = record
  266. # Sort functions by start time then by end time ascending.
  267. # This ensures that--in the case of nested events which
  268. # have the same start time (which may happen due to the
  269. # granularity of the given clock tick)--we always show
  270. # the outermost nested call first. This adds stability
  271. # in how FunctionEvents appear
  272. functions.sort(key=lambda evt: [evt.time_range.start, -evt.time_range.end])
  273. return functions