| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638 |
- # -------------------------------------------------------------------------
- # Copyright (c) Microsoft Corporation. All rights reserved.
- # Licensed under the MIT License.
- # --------------------------------------------------------------------------
- """
- Benchmark performance of SAM2 encoder with ORT or PyTorch. See benchmark_sam2.sh for usage.
- """
- import argparse
- import csv
- import statistics
- import time
- from collections.abc import Mapping
- from datetime import datetime
- import torch
- from image_decoder import SAM2ImageDecoder
- from image_encoder import SAM2ImageEncoder
- from sam2_utils import decoder_shape_dict, encoder_shape_dict, load_sam2_model
- from onnxruntime import InferenceSession, SessionOptions, get_available_providers
- from onnxruntime.transformers.io_binding_helper import CudaSession
- class TestConfig:
- def __init__(
- self,
- model_type: str,
- onnx_path: str,
- sam2_dir: str,
- device: torch.device,
- component: str = "image_encoder",
- provider="CPUExecutionProvider",
- torch_compile_mode="max-autotune",
- batch_size: int = 1,
- height: int = 1024,
- width: int = 1024,
- num_labels: int = 1,
- num_points: int = 1,
- num_masks: int = 1,
- multi_mask_output: bool = False,
- use_tf32: bool = True,
- enable_cuda_graph: bool = False,
- dtype=torch.float32,
- prefer_nhwc: bool = False,
- warm_up: int = 5,
- enable_nvtx_profile: bool = False,
- enable_ort_profile: bool = False,
- enable_torch_profile: bool = False,
- repeats: int = 1000,
- verbose: bool = False,
- ):
- assert model_type in ["sam2_hiera_tiny", "sam2_hiera_small", "sam2_hiera_large", "sam2_hiera_base_plus"]
- assert height >= 160 and height <= 4096
- assert width >= 160 and width <= 4096
- self.model_type = model_type
- self.onnx_path = onnx_path
- self.sam2_dir = sam2_dir
- self.component = component
- self.provider = provider
- self.torch_compile_mode = torch_compile_mode
- self.batch_size = batch_size
- self.height = height
- self.width = width
- self.num_labels = num_labels
- self.num_points = num_points
- self.num_masks = num_masks
- self.multi_mask_output = multi_mask_output
- self.device = device
- self.use_tf32 = use_tf32
- self.enable_cuda_graph = enable_cuda_graph
- self.dtype = dtype
- self.prefer_nhwc = prefer_nhwc
- self.warm_up = warm_up
- self.enable_nvtx_profile = enable_nvtx_profile
- self.enable_ort_profile = enable_ort_profile
- self.enable_torch_profile = enable_torch_profile
- self.repeats = repeats
- self.verbose = verbose
- if self.component == "image_encoder":
- assert self.height == 1024 and self.width == 1024, "Only image size 1024x1024 is allowed for image encoder."
- def __repr__(self):
- return f"{vars(self)}"
- def shape_dict(self) -> Mapping[str, list[int]]:
- if self.component == "image_encoder":
- return encoder_shape_dict(self.batch_size, self.height, self.width)
- else:
- return decoder_shape_dict(self.height, self.width, self.num_labels, self.num_points, self.num_masks)
- def random_inputs(self) -> Mapping[str, torch.Tensor]:
- dtype = self.dtype
- if self.component == "image_encoder":
- return {"image": torch.randn(self.batch_size, 3, self.height, self.width, dtype=dtype, device=self.device)}
- else:
- return {
- "image_features_0": torch.rand(1, 32, 256, 256, dtype=dtype, device=self.device),
- "image_features_1": torch.rand(1, 64, 128, 128, dtype=dtype, device=self.device),
- "image_embeddings": torch.rand(1, 256, 64, 64, dtype=dtype, device=self.device),
- "point_coords": torch.randint(
- 0, 1024, (self.num_labels, self.num_points, 2), dtype=dtype, device=self.device
- ),
- "point_labels": torch.randint(
- 0, 1, (self.num_labels, self.num_points), dtype=torch.int32, device=self.device
- ),
- "input_masks": torch.zeros(self.num_labels, 1, 256, 256, dtype=dtype, device=self.device),
- "has_input_masks": torch.ones(self.num_labels, dtype=dtype, device=self.device),
- "original_image_size": torch.tensor([self.height, self.width], dtype=torch.int32, device=self.device),
- }
- def create_ort_session(config: TestConfig, session_options=None) -> InferenceSession:
- if config.verbose:
- print(f"create session for {vars(config)}")
- if config.provider == "CUDAExecutionProvider":
- device_id = torch.cuda.current_device() if isinstance(config.device, str) else config.device.index
- provider_options = CudaSession.get_cuda_provider_options(device_id, config.enable_cuda_graph)
- provider_options["use_tf32"] = int(config.use_tf32)
- if config.prefer_nhwc:
- provider_options["prefer_nhwc"] = 1
- providers = [(config.provider, provider_options), "CPUExecutionProvider"]
- else:
- providers = ["CPUExecutionProvider"]
- ort_session = InferenceSession(config.onnx_path, session_options, providers=providers)
- return ort_session
- def create_session(config: TestConfig, session_options=None) -> CudaSession:
- ort_session = create_ort_session(config, session_options)
- cuda_session = CudaSession(ort_session, config.device, config.enable_cuda_graph)
- cuda_session.allocate_buffers(config.shape_dict())
- return cuda_session
- class OrtTestSession:
- """A wrapper of ORT session to test relevance and performance."""
- def __init__(self, config: TestConfig, session_options=None):
- self.ort_session = create_session(config, session_options)
- self.feed_dict = config.random_inputs()
- def infer(self):
- return self.ort_session.infer(self.feed_dict)
- def measure_latency(cuda_session: CudaSession, input_dict):
- start = time.time()
- _ = cuda_session.infer(input_dict)
- end = time.time()
- return end - start
- def run_torch(config: TestConfig):
- device_type = config.device.type
- is_cuda = device_type == "cuda"
- # Turn on TF32 for Ampere GPUs which could help when data type is float32.
- if is_cuda and torch.cuda.get_device_properties(0).major >= 8 and config.use_tf32:
- torch.backends.cuda.matmul.allow_tf32 = True
- torch.backends.cudnn.allow_tf32 = True
- enabled_auto_cast = is_cuda and config.dtype != torch.float32
- ort_inputs = config.random_inputs()
- with torch.inference_mode(), torch.autocast(device_type=device_type, dtype=config.dtype, enabled=enabled_auto_cast):
- sam2_model = load_sam2_model(config.sam2_dir, config.model_type, device=config.device)
- if config.component == "image_encoder":
- if is_cuda and config.torch_compile_mode != "none":
- sam2_model.image_encoder.forward = torch.compile(
- sam2_model.image_encoder.forward,
- mode=config.torch_compile_mode, # "reduce-overhead" if you want to reduce latency of first run.
- fullgraph=True,
- dynamic=False,
- )
- image_shape = config.shape_dict()["image"]
- img = torch.randn(image_shape).to(device=config.device, dtype=config.dtype)
- sam2_encoder = SAM2ImageEncoder(sam2_model)
- if is_cuda and config.torch_compile_mode != "none":
- print(f"Running warm up. It will take a while since torch compile mode is {config.torch_compile_mode}.")
- for _ in range(config.warm_up):
- _image_features_0, _image_features_1, _image_embeddings = sam2_encoder(img)
- if is_cuda and config.enable_nvtx_profile:
- import nvtx # noqa: PLC0415
- from cuda import cudart # noqa: PLC0415
- cudart.cudaProfilerStart()
- print("Start nvtx profiling on encoder ...")
- with nvtx.annotate("one_run"):
- sam2_encoder(img, enable_nvtx_profile=True)
- cudart.cudaProfilerStop()
- if is_cuda and config.enable_torch_profile:
- with torch.profiler.profile(
- activities=[torch.profiler.ProfilerActivity.CPU, torch.profiler.ProfilerActivity.CUDA],
- record_shapes=True,
- ) as prof:
- print("Start torch profiling on encoder ...")
- with torch.profiler.record_function("encoder"):
- sam2_encoder(img)
- print(prof.key_averages().table(sort_by="cuda_time_total", row_limit=10))
- prof.export_chrome_trace("torch_image_encoder.json")
- if config.repeats == 0:
- return
- print(f"Start {config.repeats} runs of performance tests...")
- start = time.time()
- for _ in range(config.repeats):
- _image_features_0, _image_features_1, _image_embeddings = sam2_encoder(img)
- if is_cuda:
- torch.cuda.synchronize()
- else:
- torch_inputs = (
- ort_inputs["image_features_0"],
- ort_inputs["image_features_1"],
- ort_inputs["image_embeddings"],
- ort_inputs["point_coords"],
- ort_inputs["point_labels"],
- ort_inputs["input_masks"],
- ort_inputs["has_input_masks"],
- ort_inputs["original_image_size"],
- )
- sam2_decoder = SAM2ImageDecoder(
- sam2_model,
- multimask_output=config.multi_mask_output,
- )
- if is_cuda and config.torch_compile_mode != "none":
- sam2_decoder.forward = torch.compile(
- sam2_decoder.forward,
- mode=config.torch_compile_mode,
- fullgraph=True,
- dynamic=False,
- )
- # warm up
- for _ in range(config.warm_up):
- _masks, _iou_predictions, _low_res_masks = sam2_decoder(*torch_inputs)
- if is_cuda and config.enable_nvtx_profile:
- import nvtx # noqa: PLC0415
- from cuda import cudart # noqa: PLC0415
- cudart.cudaProfilerStart()
- print("Start nvtx profiling on decoder...")
- with nvtx.annotate("one_run"):
- sam2_decoder(*torch_inputs, enable_nvtx_profile=True)
- cudart.cudaProfilerStop()
- if is_cuda and config.enable_torch_profile:
- with torch.profiler.profile(
- activities=[torch.profiler.ProfilerActivity.CPU, torch.profiler.ProfilerActivity.CUDA],
- record_shapes=True,
- ) as prof:
- print("Start torch profiling on decoder ...")
- with torch.profiler.record_function("decoder"):
- sam2_decoder(*torch_inputs)
- print(prof.key_averages().table(sort_by="cuda_time_total", row_limit=10))
- prof.export_chrome_trace("torch_image_decoder.json")
- if config.repeats == 0:
- return
- print(f"Start {config.repeats} runs of performance tests...")
- start = time.time()
- for _ in range(config.repeats):
- _masks, _iou_predictions, _low_res_masks = sam2_decoder(*torch_inputs)
- if is_cuda:
- torch.cuda.synchronize()
- end = time.time()
- return (end - start) / config.repeats
- def run_test(
- args: argparse.Namespace,
- csv_writer: csv.DictWriter | None = None,
- ):
- use_gpu: bool = args.use_gpu
- enable_cuda_graph: bool = args.use_cuda_graph
- repeats: int = args.repeats
- if use_gpu:
- device_id = torch.cuda.current_device()
- device = torch.device("cuda", device_id)
- provider = "CUDAExecutionProvider"
- else:
- device_id = 0
- device = torch.device("cpu")
- enable_cuda_graph = False
- provider = "CPUExecutionProvider"
- dtypes = {"fp32": torch.float32, "fp16": torch.float16, "bf16": torch.bfloat16}
- config = TestConfig(
- model_type=args.model_type,
- onnx_path=args.onnx_path,
- sam2_dir=args.sam2_dir,
- component=args.component,
- provider=provider,
- batch_size=args.batch_size,
- height=args.height,
- width=args.width,
- device=device,
- use_tf32=True,
- enable_cuda_graph=enable_cuda_graph,
- dtype=dtypes[args.dtype],
- prefer_nhwc=args.prefer_nhwc,
- repeats=args.repeats,
- warm_up=args.warm_up,
- enable_nvtx_profile=args.enable_nvtx_profile,
- enable_ort_profile=args.enable_ort_profile,
- enable_torch_profile=args.enable_torch_profile,
- torch_compile_mode=args.torch_compile_mode,
- verbose=False,
- )
- if args.engine == "ort":
- sess_options = SessionOptions()
- sess_options.intra_op_num_threads = args.intra_op_num_threads
- if config.enable_ort_profile:
- sess_options.enable_profiling = True
- sess_options.log_severity_level = 4
- sess_options.log_verbosity_level = 0
- session = create_session(config, sess_options)
- input_dict = config.random_inputs()
- # warm up session
- try:
- for _ in range(config.warm_up):
- _ = measure_latency(session, input_dict)
- except Exception as e:
- print(f"Failed to run {config=}. Exception: {e}")
- return
- if config.enable_nvtx_profile:
- import nvtx # noqa: PLC0415
- from cuda import cudart # noqa: PLC0415
- cudart.cudaProfilerStart()
- with nvtx.annotate("one_run"):
- _ = session.infer(input_dict)
- cudart.cudaProfilerStop()
- if config.enable_ort_profile:
- session.ort_session.end_profiling()
- if repeats == 0:
- return
- latency_list = []
- for _ in range(repeats):
- latency = measure_latency(session, input_dict)
- latency_list.append(latency)
- average_latency = statistics.mean(latency_list)
- del session
- else: # torch
- with torch.no_grad():
- try:
- average_latency = run_torch(config)
- except Exception as e:
- print(f"Failed to run {config=}. Exception: {e}")
- return
- if repeats == 0:
- return
- engine = args.engine + ":" + ("cuda" if use_gpu else "cpu")
- row = {
- "model_type": args.model_type,
- "component": args.component,
- "dtype": args.dtype,
- "use_gpu": use_gpu,
- "enable_cuda_graph": enable_cuda_graph,
- "prefer_nhwc": config.prefer_nhwc,
- "use_tf32": config.use_tf32,
- "batch_size": args.batch_size,
- "height": args.height,
- "width": args.width,
- "multi_mask_output": args.multimask_output,
- "num_labels": config.num_labels,
- "num_points": config.num_points,
- "num_masks": config.num_masks,
- "intra_op_num_threads": args.intra_op_num_threads,
- "warm_up": config.warm_up,
- "repeats": repeats,
- "enable_nvtx_profile": args.enable_nvtx_profile,
- "torch_compile_mode": args.torch_compile_mode,
- "engine": engine,
- "average_latency": average_latency,
- }
- if csv_writer is not None:
- csv_writer.writerow(row)
- print(f"{vars(config)}")
- print(f"{row}")
- def run_perf_test(args):
- features = "gpu" if args.use_gpu else "cpu"
- csv_filename = "benchmark_sam_{}_{}_{}.csv".format(
- features,
- args.engine,
- datetime.now().strftime("%Y%m%d-%H%M%S"),
- )
- with open(csv_filename, mode="a", newline="") as csv_file:
- column_names = [
- "model_type",
- "component",
- "dtype",
- "use_gpu",
- "enable_cuda_graph",
- "prefer_nhwc",
- "use_tf32",
- "batch_size",
- "height",
- "width",
- "multi_mask_output",
- "num_labels",
- "num_points",
- "num_masks",
- "intra_op_num_threads",
- "warm_up",
- "repeats",
- "enable_nvtx_profile",
- "torch_compile_mode",
- "engine",
- "average_latency",
- ]
- csv_writer = csv.DictWriter(csv_file, fieldnames=column_names)
- csv_writer.writeheader()
- run_test(args, csv_writer)
- def _parse_arguments():
- parser = argparse.ArgumentParser(description="Benchmark SMA2 for ONNX Runtime and PyTorch.")
- parser.add_argument(
- "--component",
- required=False,
- choices=["image_encoder", "image_decoder"],
- default="image_encoder",
- help="component to benchmark. Choices are image_encoder and image_decoder.",
- )
- parser.add_argument(
- "--dtype", required=False, choices=["fp32", "fp16", "bf16"], default="fp32", help="Data type for inference."
- )
- parser.add_argument(
- "--use_gpu",
- required=False,
- action="store_true",
- help="Use GPU for inference.",
- )
- parser.set_defaults(use_gpu=False)
- parser.add_argument(
- "--use_cuda_graph",
- required=False,
- action="store_true",
- help="Use cuda graph in onnxruntime.",
- )
- parser.set_defaults(use_cuda_graph=False)
- parser.add_argument(
- "--intra_op_num_threads",
- required=False,
- type=int,
- choices=[0, 1, 2, 4, 8, 16],
- default=0,
- help="intra_op_num_threads for onnxruntime. ",
- )
- parser.add_argument(
- "--batch_size",
- required=False,
- type=int,
- default=1,
- help="batch size",
- )
- parser.add_argument(
- "--height",
- required=False,
- type=int,
- default=1024,
- help="image height",
- )
- parser.add_argument(
- "--width",
- required=False,
- type=int,
- default=1024,
- help="image width",
- )
- parser.add_argument(
- "--repeats",
- required=False,
- type=int,
- default=1000,
- help="number of repeats for performance test. Default is 1000.",
- )
- parser.add_argument(
- "--warm_up",
- required=False,
- type=int,
- default=5,
- help="number of runs for warm up. Default is 5.",
- )
- parser.add_argument(
- "--engine",
- required=False,
- type=str,
- default="ort",
- choices=["ort", "torch"],
- help="engine for inference",
- )
- parser.add_argument(
- "--multimask_output",
- required=False,
- default=False,
- action="store_true",
- help="Export mask_decoder or image_decoder with multimask_output",
- )
- parser.add_argument(
- "--prefer_nhwc",
- required=False,
- default=False,
- action="store_true",
- help="Use prefer_nhwc=1 provider option for CUDAExecutionProvider",
- )
- parser.add_argument(
- "--enable_nvtx_profile",
- required=False,
- default=False,
- action="store_true",
- help="Enable nvtx profiling. It will add an extra run for profiling before performance test.",
- )
- parser.add_argument(
- "--enable_ort_profile",
- required=False,
- default=False,
- action="store_true",
- help="Enable ORT profiling.",
- )
- parser.add_argument(
- "--enable_torch_profile",
- required=False,
- default=False,
- action="store_true",
- help="Enable PyTorch profiling. It will add an extra run for profiling before performance test.",
- )
- parser.add_argument(
- "--model_type",
- required=False,
- type=str,
- default="sam2_hiera_large",
- choices=["sam2_hiera_tiny", "sam2_hiera_small", "sam2_hiera_large", "sam2_hiera_base_plus"],
- help="sam2 model name",
- )
- parser.add_argument(
- "--sam2_dir",
- required=False,
- type=str,
- default="./segment-anything-2",
- help="The directory of segment-anything-2 git root directory",
- )
- parser.add_argument(
- "--onnx_path",
- required=False,
- type=str,
- default="./sam2_onnx_models/sam2_hiera_large_image_encoder.onnx",
- help="path of onnx model",
- )
- parser.add_argument(
- "--torch_compile_mode",
- required=False,
- type=str,
- default=None,
- choices=["reduce-overhead", "max-autotune", "max-autotune-no-cudagraphs", "none"],
- help="torch compile mode. none will disable torch compile.",
- )
- args = parser.parse_args()
- return args
- if __name__ == "__main__":
- args = _parse_arguments()
- print(f"arguments:{args}")
- if args.torch_compile_mode is None:
- # image decoder will fail with compile modes other than "none".
- args.torch_compile_mode = "max-autotune" if args.component == "image_encoder" else "none"
- if args.use_gpu:
- assert torch.cuda.is_available()
- if args.engine == "ort":
- assert "CUDAExecutionProvider" in get_available_providers()
- args.enable_torch_profile = False
- else:
- # Only support cuda profiling for now.
- assert not args.enable_nvtx_profile
- assert not args.enable_torch_profile
- if args.enable_nvtx_profile or args.enable_torch_profile:
- run_test(args)
- else:
- run_perf_test(args)
|