# profiler.py (2.3 KB)
  1. # mypy: allow-untyped-defs
  2. import contextlib
  3. import tempfile
  4. import torch
  5. from . import check_error, cudart
  6. __all__ = ["init", "start", "stop", "profile"]
  7. DEFAULT_FLAGS = [
  8. "gpustarttimestamp",
  9. "gpuendtimestamp",
  10. "gridsize3d",
  11. "threadblocksize",
  12. "streamid",
  13. "enableonstart 0",
  14. "conckerneltrace",
  15. ]
  16. def init(output_file, flags=None, output_mode="key_value"):
  17. rt = cudart()
  18. if not hasattr(rt, "cudaOutputMode"):
  19. raise AssertionError("HIP does not support profiler initialization!")
  20. if (
  21. hasattr(torch.version, "cuda")
  22. and torch.version.cuda is not None
  23. and int(torch.version.cuda.split(".")[0]) >= 12
  24. ):
  25. # Check https://github.com/pytorch/pytorch/pull/91118
  26. # cudaProfilerInitialize is no longer needed after CUDA 12
  27. raise AssertionError("CUDA12+ does not need profiler initialization!")
  28. flags = DEFAULT_FLAGS if flags is None else flags
  29. if output_mode == "key_value":
  30. output_mode_enum = rt.cudaOutputMode.KeyValuePair
  31. elif output_mode == "csv":
  32. output_mode_enum = rt.cudaOutputMode.CSV
  33. else:
  34. raise RuntimeError(
  35. "supported CUDA profiler output modes are: key_value and csv"
  36. )
  37. with tempfile.NamedTemporaryFile(delete=True) as f:
  38. f.write(b"\n".join(f.encode("ascii") for f in flags))
  39. f.flush()
  40. check_error(rt.cudaProfilerInitialize(f.name, output_file, output_mode_enum))
  41. def start():
  42. r"""Starts cuda profiler data collection.
  43. .. warning::
  44. Raises CudaError in case of it is unable to start the profiler.
  45. """
  46. check_error(cudart().cudaProfilerStart())
  47. def stop():
  48. r"""Stops cuda profiler data collection.
  49. .. warning::
  50. Raises CudaError in case of it is unable to stop the profiler.
  51. """
  52. check_error(cudart().cudaProfilerStop())
  53. @contextlib.contextmanager
  54. def profile():
  55. """
  56. Enable profiling.
  57. Context Manager to enabling profile collection by the active profiling tool from CUDA backend.
  58. Example:
  59. >>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_CUDA)
  60. >>> import torch
  61. >>> model = torch.nn.Linear(20, 30).cuda()
  62. >>> inputs = torch.randn(128, 20).cuda()
  63. >>> with torch.cuda.profiler.profile() as prof:
  64. ... model(inputs)
  65. """
  66. try:
  67. start()
  68. yield
  69. finally:
  70. stop()