profiler.py 5.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177
  1. """
  2. Dynamo profiling implementation.
  3. This module provides profiling functionality for Dynamo, including:
  4. - ProfileMetrics: Class for collecting and aggregating performance metrics like
  5. execution time, operator counts, and fusion statistics
  6. - ProfileResult: Class for analyzing and reporting profiling results
  7. - Utilities for tracking missed/uncaptured operations
  8. - Functions for instrumenting FX graphs with profiling capabilities
  9. The profiler helps measure and optimize the performance of Dynamo-compiled code
  10. by tracking both captured and total operations, timing, and graph statistics.
  11. """
  12. from __future__ import annotations
  13. import dataclasses
  14. import os
  15. from typing import Any
  16. from typing_extensions import Self
  17. import torch
  18. from .utils import print_once
  19. @dataclasses.dataclass
  20. class ProfileMetrics:
  21. microseconds: float = 0.0
  22. operators: int = 0
  23. fusions: int = 0
  24. graphs: int = 0
  25. def __iadd__(self, other: Self) -> Self:
  26. self.microseconds += other.microseconds
  27. self.operators += other.operators
  28. self.fusions += other.fusions
  29. return self
  30. def __add__(self, other: ProfileMetrics) -> ProfileMetrics:
  31. assert isinstance(other, ProfileMetrics)
  32. return ProfileMetrics(
  33. self.microseconds + other.microseconds,
  34. self.operators + other.operators,
  35. self.fusions + other.fusions,
  36. )
  37. def __truediv__(self, other: Any) -> ProfileMetrics:
  38. if isinstance(other, int):
  39. other = ProfileMetrics(other, other, other)
  40. return ProfileMetrics(
  41. # pyrefly: ignore [no-matching-overload]
  42. self.microseconds / max(1, other.microseconds),
  43. # pyrefly: ignore [bad-argument-type]
  44. self.operators / max(1, other.operators),
  45. # pyrefly: ignore [bad-argument-type]
  46. self.fusions / max(1, other.fusions),
  47. )
  48. def __str__(self) -> str:
  49. return f"{self.operators:4.0%} ops {self.microseconds:4.0%} time"
  50. def tocsv(self) -> list[float]:
  51. return [self.operators, self.microseconds]
  52. class ProfileResult:
  53. def __init__(
  54. self, captured: ProfileMetrics, total: ProfileMetrics, unique_graphs: int
  55. ) -> None:
  56. self.captured: ProfileMetrics = captured or ProfileMetrics()
  57. self.total: ProfileMetrics = total or ProfileMetrics()
  58. self.unique_graphs: int = unique_graphs
  59. def __iadd__(self, other: Self) -> Self:
  60. self.captured += other.captured
  61. self.total += other.total
  62. self.unique_graphs += other.unique_graphs
  63. return self
  64. def percent(self) -> ProfileMetrics:
  65. return self.captured / self.total
  66. def __str__(self) -> str:
  67. return (
  68. f"{self.unique_graphs:2} graphs {self.captured.graphs:2} graph calls "
  69. f"{self.captured.operators:4}/{self.total.operators:4} = "
  70. + str(self.percent())
  71. )
  72. def tocsv(self) -> list[Any]:
  73. return [
  74. self.unique_graphs,
  75. self.captured.graphs,
  76. self.captured.operators,
  77. self.total.operators,
  78. ] + self.percent().tocsv()
  79. def should_print_missing() -> bool:
  80. return os.environ.get("TORCHDYNAMO_PRINT_MISSING") == "1"
  81. def print_missing(stack: list[str]) -> None:
  82. if any("/torch/autograd/profiler.py" in x for x in stack):
  83. return
  84. stack = [
  85. x for x in stack if ("<built-in" not in x and "site-packages/torch/" not in x)
  86. ]
  87. print_once("MISSING", " >> ".join(stack[-3:]))
  88. class Profiler:
  89. unique_graphs: int = 0
  90. def __init__(self) -> None:
  91. self.prof = torch.profiler.profile(
  92. activities=[torch.profiler.ProfilerActivity.CPU],
  93. with_stack=should_print_missing(),
  94. )
  95. def results(self) -> ProfileResult:
  96. captured_regions = 0
  97. captured_ops = 0
  98. captured_microseconds = 0
  99. total_ops = 0
  100. total_microseconds = 0
  101. last_op_end_time = -1
  102. captured_region_end_time = -1
  103. events = sorted(self.prof.events(), key=lambda x: x.time_range.start)
  104. for e in events:
  105. if e.name == "TORCHDYNAMO":
  106. captured_region_end_time = e.time_range.end
  107. captured_regions += 1
  108. # ignore `handle = torch.zeros(1)` in record_function.__init__()
  109. total_ops -= 1
  110. elif e.time_range.start >= last_op_end_time:
  111. last_op_end_time = e.time_range.end
  112. if e.time_range.end <= captured_region_end_time:
  113. captured_ops += 1
  114. captured_microseconds += e.time_range.elapsed_us()
  115. elif should_print_missing():
  116. print_missing(e.stack)
  117. total_ops += 1
  118. total_microseconds += e.time_range.elapsed_us()
  119. else:
  120. pass # ops recursively called from other ops (ignored)
  121. unique_graphs = Profiler.unique_graphs
  122. Profiler.unique_graphs = 0
  123. # we counted one extra op that is part of the profiler setup code
  124. total_ops -= 1
  125. return ProfileResult(
  126. captured=ProfileMetrics(
  127. microseconds=captured_microseconds,
  128. operators=captured_ops,
  129. fusions=captured_ops - captured_regions,
  130. graphs=captured_regions,
  131. ),
  132. total=ProfileMetrics(
  133. microseconds=total_microseconds,
  134. operators=total_ops,
  135. fusions=total_ops - 1,
  136. ),
  137. unique_graphs=unique_graphs,
  138. )
  139. def fx_insert_profiling(gm: torch.fx.GraphModule, example_inputs: list[Any]) -> Any:
  140. def _wrapped(*args: Any) -> Any:
  141. with torch.profiler.record_function("TORCHDYNAMO"):
  142. return gm.forward(*args)
  143. Profiler.unique_graphs += 1
  144. return _wrapped