| 123456789101112131415161718192021222324252627282930313233 |
- # -------------------------------------------------------------------------
- # Copyright (c) Microsoft Corporation. All rights reserved.
- # Licensed under the MIT License.
- # --------------------------------------------------------------------------
- import nvtx
- from cuda import cudart
class NvtxHelper:
    """Profile named stages with NVTX ranges plus CUDA start/stop events.

    For every stage name given to the constructor, two CUDA events named
    ``"<stage>-start"`` and ``"<stage>-stop"`` are created up front.
    ``start_profile``/``stop_profile`` open/close an NVTX range for a stage
    and, when the stage was registered, record the matching event on stream 0.
    ``print_latency`` prints the elapsed time between the two events of every
    registered stage.
    """

    def __init__(self, stages):
        """Create a start and a stop CUDA event for each stage.

        Args:
            stages: stage names. The object is stored as-is and iterated again
                by ``print_latency``, so pass a re-iterable collection (e.g. a
                list), not a one-shot generator.

        Raises:
            RuntimeError: if ``cudaEventCreate`` reports an error.
        """
        self.stages = stages
        self.events = {}
        for stage in stages:
            for marker in ("start", "stop"):
                err, event = cudart.cudaEventCreate()
                # Fail loudly here instead of storing an invalid handle that
                # would only break later inside cudaEventRecord/ElapsedTime.
                if err != cudart.cudaError_t.cudaSuccess:
                    raise RuntimeError(f"cudaEventCreate failed for {stage}-{marker}: {err}")
                self.events[f"{stage}-{marker}"] = event
        # NVTX range handles of currently open stages, keyed by stage name.
        self.markers = {}

    def start_profile(self, stage, color="blue"):
        """Open an NVTX range for ``stage`` and record its start event if registered."""
        self.markers[stage] = nvtx.start_range(message=stage, color=color)
        event_name = stage + "-start"
        if event_name in self.events:
            cudart.cudaEventRecord(self.events[event_name], 0)

    def stop_profile(self, stage):
        """Record the stop event for ``stage`` (if registered) and close its NVTX range.

        Raises:
            KeyError: if ``start_profile`` was not called for ``stage`` first.
        """
        event_name = stage + "-stop"
        if event_name in self.events:
            cudart.cudaEventRecord(self.events[event_name], 0)
        # pop() so a finished range's handle is not retained in self.markers.
        nvtx.end_range(self.markers.pop(stage))

    def print_latency(self):
        """Print the start-to-stop elapsed time, in ms, of every registered stage.

        A stage whose time cannot be measured (e.g. it was never profiled)
        is reported as ``n/a`` together with the CUDA error code.
        """
        for stage in self.stages:
            start = self.events[f"{stage}-start"]
            stop = self.events[f"{stage}-stop"]
            # Event recording is asynchronous: wait for the stop event to
            # complete, otherwise cudaEventElapsedTime returns cudaErrorNotReady.
            # (Returns immediately if the event was never recorded.)
            cudart.cudaEventSynchronize(stop)
            err, latency = cudart.cudaEventElapsedTime(start, stop)
            if err != cudart.cudaError_t.cudaSuccess:
                # latency is not meaningful here (may even be None), so do not
                # format it as a number.
                print(f"{stage}: n/a ({err})")
                continue
            print(f"{stage}: {latency:.2f} ms")
|