nvtx_helper.py 1.2 KB

123456789101112131415161718192021222324252627282930313233
  1. # -------------------------------------------------------------------------
  2. # Copyright (R) Microsoft Corporation. All rights reserved.
  3. # Licensed under the MIT License.
  4. # --------------------------------------------------------------------------
  5. import nvtx
  6. from cuda import cudart
  class NvtxHelper:
      """Profile named stages with NVTX ranges and CUDA start/stop events.

      Each stage passed to the constructor gets a pair of CUDA events. Calling
      ``start_profile`` / ``stop_profile`` around a stage opens and closes an
      NVTX range and records the matching events; ``print_latency`` then prints
      the elapsed time between each stage's two events.
      """

      def __init__(self, stages):
          """Create one start and one stop CUDA event per stage.

          Args:
              stages: Stage names. The object is iterated here and again by
                  ``print_latency``, so it must support repeated iteration
                  (for example a list or tuple).
          """
          self.stages = stages
          # Events are keyed "<stage>-start" / "<stage>-stop". Element 1 of
          # cudaEventCreate()'s result is kept as the event handle.
          self.events = {}
          for stage in stages:
              for marker in ("start", "stop"):
                  self.events[f"{stage}-{marker}"] = cudart.cudaEventCreate()[1]
          # Handles of NVTX ranges that are currently open, keyed by stage name.
          self.markers = {}

      def start_profile(self, stage, color="blue"):
          """Open an NVTX range for ``stage`` and record its start event.

          Args:
              stage: Stage name; also used as the NVTX range message.
              color: Color forwarded to ``nvtx.start_range``.
          """
          self.markers[stage] = nvtx.start_range(message=stage, color=color)
          # A stage that was not given to the constructor has no events; it
          # still gets an NVTX range but no CUDA timing.
          event_name = stage + "-start"
          if event_name in self.events:
              cudart.cudaEventRecord(self.events[event_name], 0)

      def stop_profile(self, stage):
          """Record the stop event for ``stage`` and close its NVTX range.

          Calling this for a stage that has no open range (never started, or
          already stopped) skips the NVTX call instead of raising ``KeyError``.

          Args:
              stage: Stage name previously passed to ``start_profile``.
          """
          event_name = stage + "-stop"
          if event_name in self.events:
              cudart.cudaEventRecord(self.events[event_name], 0)
          # Pop the handle so a finished range cannot be ended a second time.
          marker = self.markers.pop(stage, None)
          if marker is not None:
              nvtx.end_range(marker)

      def print_latency(self):
          """Print the elapsed time of every constructor-supplied stage.

          Waits for each stage's stop event to complete before reading the
          elapsed time, so the printed value does not depend on the caller
          having synchronized the device first. A stage whose timing cannot be
          read (for example because it was never started or stopped) is
          reported as unavailable together with the CUDA status.
          """
          for stage in self.stages:
              start_event = self.events[f"{stage}-start"]
              stop_event = self.events[f"{stage}-stop"]
              # cudaEventElapsedTime fails with cudaErrorNotReady while the
              # recorded work is still in flight, so block on the stop event.
              cudart.cudaEventSynchronize(stop_event)
              result = cudart.cudaEventElapsedTime(start_event, stop_event)
              status, latency = result[0], result[1]
              if status != cudart.cudaError_t.cudaSuccess:
                  print(f"{stage}: latency unavailable ({status})")
                  continue
              print(f"{stage}: {latency:.2f} ms")