# profiler.py
  1. """
  2. Dynamo profiling implementation.
  3. This module provides profiling functionality for Dynamo, including:
  4. - ProfileMetrics: Class for collecting and aggregating performance metrics like
  5. execution time, operator counts, and fusion statistics
  6. - ProfileResult: Class for analyzing and reporting profiling results
  7. - Utilities for tracking missed/uncaptured operations
  8. - Functions for instrumenting FX graphs with profiling capabilities
  9. The profiler helps measure and optimize the performance of Dynamo-compiled code
  10. by tracking both captured and total operations, timing, and graph statistics.
  11. """
  12. from __future__ import annotations
  13. import dataclasses
  14. import os
  15. from typing import Any
  16. from typing_extensions import Self
  17. import torch
  18. from .utils import print_once
  19. @dataclasses.dataclass
  20. class ProfileMetrics:
  21. microseconds: float = 0.0
  22. operators: int = 0
  23. fusions: int = 0
  24. graphs: int = 0
  25. def __iadd__(self, other: Self) -> Self:
  26. self.microseconds += other.microseconds
  27. self.operators += other.operators
  28. self.fusions += other.fusions
  29. return self
  30. def __add__(self, other: ProfileMetrics) -> ProfileMetrics:
  31. assert isinstance(other, ProfileMetrics)
  32. return ProfileMetrics(
  33. self.microseconds + other.microseconds,
  34. self.operators + other.operators,
  35. self.fusions + other.fusions,
  36. )
  37. def __truediv__(self, other: Any) -> ProfileMetrics:
  38. if isinstance(other, int):
  39. other = ProfileMetrics(other, other, other)
  40. return ProfileMetrics(
  41. self.microseconds / max(1, other.microseconds),
  42. self.operators / max(1, other.operators),
  43. self.fusions / max(1, other.fusions),
  44. )
  45. def __str__(self) -> str:
  46. return f"{self.operators:4.0%} ops {self.microseconds:4.0%} time"
  47. def tocsv(self) -> list[float]:
  48. return [self.operators, self.microseconds]
  49. class ProfileResult:
  50. def __init__(
  51. self, captured: ProfileMetrics, total: ProfileMetrics, unique_graphs: int
  52. ) -> None:
  53. self.captured: ProfileMetrics = captured or ProfileMetrics()
  54. self.total: ProfileMetrics = total or ProfileMetrics()
  55. self.unique_graphs: int = unique_graphs
  56. def __iadd__(self, other: Self) -> Self:
  57. self.captured += other.captured
  58. self.total += other.total
  59. self.unique_graphs += other.unique_graphs
  60. return self
  61. def percent(self) -> ProfileMetrics:
  62. return self.captured / self.total
  63. def __str__(self) -> str:
  64. return (
  65. f"{self.unique_graphs:2} graphs {self.captured.graphs:2} graph calls "
  66. f"{self.captured.operators:4}/{self.total.operators:4} = "
  67. + str(self.percent())
  68. )
  69. def tocsv(self) -> list[Any]:
  70. return [
  71. self.unique_graphs,
  72. self.captured.graphs,
  73. self.captured.operators,
  74. self.total.operators,
  75. ] + self.percent().tocsv()
  76. def should_print_missing() -> bool:
  77. return os.environ.get("TORCHDYNAMO_PRINT_MISSING") == "1"
  78. def print_missing(stack: list[str]) -> None:
  79. if any("/torch/autograd/profiler.py" in x for x in stack):
  80. return
  81. stack = [
  82. x for x in stack if ("<built-in" not in x and "site-packages/torch/" not in x)
  83. ]
  84. print_once("MISSING", " >> ".join(stack[-3:]))
  85. class Profiler:
  86. unique_graphs: int = 0
  87. def __init__(self) -> None:
  88. self.prof = torch.profiler.profile(
  89. activities=[torch.profiler.ProfilerActivity.CPU],
  90. with_stack=should_print_missing(),
  91. )
  92. def results(self) -> ProfileResult:
  93. captured_regions = 0
  94. captured_ops = 0
  95. captured_microseconds = 0
  96. total_ops = 0
  97. total_microseconds = 0
  98. last_op_end_time = -1
  99. captured_region_end_time = -1
  100. events = sorted(self.prof.events(), key=lambda x: x.time_range.start)
  101. for e in events:
  102. if e.name == "TORCHDYNAMO":
  103. captured_region_end_time = e.time_range.end
  104. captured_regions += 1
  105. # ignore `handle = torch.zeros(1)` in record_function.__init__()
  106. total_ops -= 1
  107. elif e.time_range.start >= last_op_end_time:
  108. last_op_end_time = e.time_range.end
  109. if e.time_range.end <= captured_region_end_time:
  110. captured_ops += 1
  111. captured_microseconds += e.time_range.elapsed_us()
  112. elif should_print_missing():
  113. print_missing(e.stack)
  114. total_ops += 1
  115. total_microseconds += e.time_range.elapsed_us()
  116. else:
  117. pass # ops recursively called from other ops (ignored)
  118. unique_graphs = Profiler.unique_graphs
  119. Profiler.unique_graphs = 0
  120. # we counted one extra op that is part of the profiler setup code
  121. total_ops -= 1
  122. return ProfileResult(
  123. captured=ProfileMetrics(
  124. microseconds=captured_microseconds,
  125. operators=captured_ops,
  126. fusions=captured_ops - captured_regions,
  127. graphs=captured_regions,
  128. ),
  129. total=ProfileMetrics(
  130. microseconds=total_microseconds,
  131. operators=total_ops,
  132. fusions=total_ops - 1,
  133. ),
  134. unique_graphs=unique_graphs,
  135. )
  136. def fx_insert_profiling(gm: torch.fx.GraphModule, example_inputs: list[Any]) -> Any:
  137. def _wrapped(*args: Any) -> Any:
  138. with torch.profiler.record_function("TORCHDYNAMO"):
  139. return gm.forward(*args)
  140. Profiler.unique_graphs += 1
  141. return _wrapped