benchmark_sam2.py 21 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638
  1. # -------------------------------------------------------------------------
  2. # Copyright (c) Microsoft Corporation. All rights reserved.
  3. # Licensed under the MIT License.
  4. # --------------------------------------------------------------------------
  5. """
  6. Benchmark performance of the SAM2 image encoder or image decoder with ONNX Runtime (ORT) or PyTorch. See benchmark_sam2.sh for usage.
  7. """
  8. import argparse
  9. import csv
  10. import statistics
  11. import time
  12. from collections.abc import Mapping
  13. from datetime import datetime
  14. import torch
  15. from image_decoder import SAM2ImageDecoder
  16. from image_encoder import SAM2ImageEncoder
  17. from sam2_utils import decoder_shape_dict, encoder_shape_dict, load_sam2_model
  18. from onnxruntime import InferenceSession, SessionOptions, get_available_providers
  19. from onnxruntime.transformers.io_binding_helper import CudaSession
  20. class TestConfig:
  21. def __init__(
  22. self,
  23. model_type: str,
  24. onnx_path: str,
  25. sam2_dir: str,
  26. device: torch.device,
  27. component: str = "image_encoder",
  28. provider="CPUExecutionProvider",
  29. torch_compile_mode="max-autotune",
  30. batch_size: int = 1,
  31. height: int = 1024,
  32. width: int = 1024,
  33. num_labels: int = 1,
  34. num_points: int = 1,
  35. num_masks: int = 1,
  36. multi_mask_output: bool = False,
  37. use_tf32: bool = True,
  38. enable_cuda_graph: bool = False,
  39. dtype=torch.float32,
  40. prefer_nhwc: bool = False,
  41. warm_up: int = 5,
  42. enable_nvtx_profile: bool = False,
  43. enable_ort_profile: bool = False,
  44. enable_torch_profile: bool = False,
  45. repeats: int = 1000,
  46. verbose: bool = False,
  47. ):
  48. assert model_type in ["sam2_hiera_tiny", "sam2_hiera_small", "sam2_hiera_large", "sam2_hiera_base_plus"]
  49. assert height >= 160 and height <= 4096
  50. assert width >= 160 and width <= 4096
  51. self.model_type = model_type
  52. self.onnx_path = onnx_path
  53. self.sam2_dir = sam2_dir
  54. self.component = component
  55. self.provider = provider
  56. self.torch_compile_mode = torch_compile_mode
  57. self.batch_size = batch_size
  58. self.height = height
  59. self.width = width
  60. self.num_labels = num_labels
  61. self.num_points = num_points
  62. self.num_masks = num_masks
  63. self.multi_mask_output = multi_mask_output
  64. self.device = device
  65. self.use_tf32 = use_tf32
  66. self.enable_cuda_graph = enable_cuda_graph
  67. self.dtype = dtype
  68. self.prefer_nhwc = prefer_nhwc
  69. self.warm_up = warm_up
  70. self.enable_nvtx_profile = enable_nvtx_profile
  71. self.enable_ort_profile = enable_ort_profile
  72. self.enable_torch_profile = enable_torch_profile
  73. self.repeats = repeats
  74. self.verbose = verbose
  75. if self.component == "image_encoder":
  76. assert self.height == 1024 and self.width == 1024, "Only image size 1024x1024 is allowed for image encoder."
  77. def __repr__(self):
  78. return f"{vars(self)}"
  79. def shape_dict(self) -> Mapping[str, list[int]]:
  80. if self.component == "image_encoder":
  81. return encoder_shape_dict(self.batch_size, self.height, self.width)
  82. else:
  83. return decoder_shape_dict(self.height, self.width, self.num_labels, self.num_points, self.num_masks)
  84. def random_inputs(self) -> Mapping[str, torch.Tensor]:
  85. dtype = self.dtype
  86. if self.component == "image_encoder":
  87. return {"image": torch.randn(self.batch_size, 3, self.height, self.width, dtype=dtype, device=self.device)}
  88. else:
  89. return {
  90. "image_features_0": torch.rand(1, 32, 256, 256, dtype=dtype, device=self.device),
  91. "image_features_1": torch.rand(1, 64, 128, 128, dtype=dtype, device=self.device),
  92. "image_embeddings": torch.rand(1, 256, 64, 64, dtype=dtype, device=self.device),
  93. "point_coords": torch.randint(
  94. 0, 1024, (self.num_labels, self.num_points, 2), dtype=dtype, device=self.device
  95. ),
  96. "point_labels": torch.randint(
  97. 0, 1, (self.num_labels, self.num_points), dtype=torch.int32, device=self.device
  98. ),
  99. "input_masks": torch.zeros(self.num_labels, 1, 256, 256, dtype=dtype, device=self.device),
  100. "has_input_masks": torch.ones(self.num_labels, dtype=dtype, device=self.device),
  101. "original_image_size": torch.tensor([self.height, self.width], dtype=torch.int32, device=self.device),
  102. }
  103. def create_ort_session(config: TestConfig, session_options=None) -> InferenceSession:
  104. if config.verbose:
  105. print(f"create session for {vars(config)}")
  106. if config.provider == "CUDAExecutionProvider":
  107. device_id = torch.cuda.current_device() if isinstance(config.device, str) else config.device.index
  108. provider_options = CudaSession.get_cuda_provider_options(device_id, config.enable_cuda_graph)
  109. provider_options["use_tf32"] = int(config.use_tf32)
  110. if config.prefer_nhwc:
  111. provider_options["prefer_nhwc"] = 1
  112. providers = [(config.provider, provider_options), "CPUExecutionProvider"]
  113. else:
  114. providers = ["CPUExecutionProvider"]
  115. ort_session = InferenceSession(config.onnx_path, session_options, providers=providers)
  116. return ort_session
  117. def create_session(config: TestConfig, session_options=None) -> CudaSession:
  118. ort_session = create_ort_session(config, session_options)
  119. cuda_session = CudaSession(ort_session, config.device, config.enable_cuda_graph)
  120. cuda_session.allocate_buffers(config.shape_dict())
  121. return cuda_session
  122. class OrtTestSession:
  123. """A wrapper of ORT session to test relevance and performance."""
  124. def __init__(self, config: TestConfig, session_options=None):
  125. self.ort_session = create_session(config, session_options)
  126. self.feed_dict = config.random_inputs()
  127. def infer(self):
  128. return self.ort_session.infer(self.feed_dict)
  129. def measure_latency(cuda_session: CudaSession, input_dict):
  130. start = time.time()
  131. _ = cuda_session.infer(input_dict)
  132. end = time.time()
  133. return end - start
  134. def run_torch(config: TestConfig):
  135. device_type = config.device.type
  136. is_cuda = device_type == "cuda"
  137. # Turn on TF32 for Ampere GPUs which could help when data type is float32.
  138. if is_cuda and torch.cuda.get_device_properties(0).major >= 8 and config.use_tf32:
  139. torch.backends.cuda.matmul.allow_tf32 = True
  140. torch.backends.cudnn.allow_tf32 = True
  141. enabled_auto_cast = is_cuda and config.dtype != torch.float32
  142. ort_inputs = config.random_inputs()
  143. with torch.inference_mode(), torch.autocast(device_type=device_type, dtype=config.dtype, enabled=enabled_auto_cast):
  144. sam2_model = load_sam2_model(config.sam2_dir, config.model_type, device=config.device)
  145. if config.component == "image_encoder":
  146. if is_cuda and config.torch_compile_mode != "none":
  147. sam2_model.image_encoder.forward = torch.compile(
  148. sam2_model.image_encoder.forward,
  149. mode=config.torch_compile_mode, # "reduce-overhead" if you want to reduce latency of first run.
  150. fullgraph=True,
  151. dynamic=False,
  152. )
  153. image_shape = config.shape_dict()["image"]
  154. img = torch.randn(image_shape).to(device=config.device, dtype=config.dtype)
  155. sam2_encoder = SAM2ImageEncoder(sam2_model)
  156. if is_cuda and config.torch_compile_mode != "none":
  157. print(f"Running warm up. It will take a while since torch compile mode is {config.torch_compile_mode}.")
  158. for _ in range(config.warm_up):
  159. _image_features_0, _image_features_1, _image_embeddings = sam2_encoder(img)
  160. if is_cuda and config.enable_nvtx_profile:
  161. import nvtx # noqa: PLC0415
  162. from cuda import cudart # noqa: PLC0415
  163. cudart.cudaProfilerStart()
  164. print("Start nvtx profiling on encoder ...")
  165. with nvtx.annotate("one_run"):
  166. sam2_encoder(img, enable_nvtx_profile=True)
  167. cudart.cudaProfilerStop()
  168. if is_cuda and config.enable_torch_profile:
  169. with torch.profiler.profile(
  170. activities=[torch.profiler.ProfilerActivity.CPU, torch.profiler.ProfilerActivity.CUDA],
  171. record_shapes=True,
  172. ) as prof:
  173. print("Start torch profiling on encoder ...")
  174. with torch.profiler.record_function("encoder"):
  175. sam2_encoder(img)
  176. print(prof.key_averages().table(sort_by="cuda_time_total", row_limit=10))
  177. prof.export_chrome_trace("torch_image_encoder.json")
  178. if config.repeats == 0:
  179. return
  180. print(f"Start {config.repeats} runs of performance tests...")
  181. start = time.time()
  182. for _ in range(config.repeats):
  183. _image_features_0, _image_features_1, _image_embeddings = sam2_encoder(img)
  184. if is_cuda:
  185. torch.cuda.synchronize()
  186. else:
  187. torch_inputs = (
  188. ort_inputs["image_features_0"],
  189. ort_inputs["image_features_1"],
  190. ort_inputs["image_embeddings"],
  191. ort_inputs["point_coords"],
  192. ort_inputs["point_labels"],
  193. ort_inputs["input_masks"],
  194. ort_inputs["has_input_masks"],
  195. ort_inputs["original_image_size"],
  196. )
  197. sam2_decoder = SAM2ImageDecoder(
  198. sam2_model,
  199. multimask_output=config.multi_mask_output,
  200. )
  201. if is_cuda and config.torch_compile_mode != "none":
  202. sam2_decoder.forward = torch.compile(
  203. sam2_decoder.forward,
  204. mode=config.torch_compile_mode,
  205. fullgraph=True,
  206. dynamic=False,
  207. )
  208. # warm up
  209. for _ in range(config.warm_up):
  210. _masks, _iou_predictions, _low_res_masks = sam2_decoder(*torch_inputs)
  211. if is_cuda and config.enable_nvtx_profile:
  212. import nvtx # noqa: PLC0415
  213. from cuda import cudart # noqa: PLC0415
  214. cudart.cudaProfilerStart()
  215. print("Start nvtx profiling on decoder...")
  216. with nvtx.annotate("one_run"):
  217. sam2_decoder(*torch_inputs, enable_nvtx_profile=True)
  218. cudart.cudaProfilerStop()
  219. if is_cuda and config.enable_torch_profile:
  220. with torch.profiler.profile(
  221. activities=[torch.profiler.ProfilerActivity.CPU, torch.profiler.ProfilerActivity.CUDA],
  222. record_shapes=True,
  223. ) as prof:
  224. print("Start torch profiling on decoder ...")
  225. with torch.profiler.record_function("decoder"):
  226. sam2_decoder(*torch_inputs)
  227. print(prof.key_averages().table(sort_by="cuda_time_total", row_limit=10))
  228. prof.export_chrome_trace("torch_image_decoder.json")
  229. if config.repeats == 0:
  230. return
  231. print(f"Start {config.repeats} runs of performance tests...")
  232. start = time.time()
  233. for _ in range(config.repeats):
  234. _masks, _iou_predictions, _low_res_masks = sam2_decoder(*torch_inputs)
  235. if is_cuda:
  236. torch.cuda.synchronize()
  237. end = time.time()
  238. return (end - start) / config.repeats
  239. def run_test(
  240. args: argparse.Namespace,
  241. csv_writer: csv.DictWriter | None = None,
  242. ):
  243. use_gpu: bool = args.use_gpu
  244. enable_cuda_graph: bool = args.use_cuda_graph
  245. repeats: int = args.repeats
  246. if use_gpu:
  247. device_id = torch.cuda.current_device()
  248. device = torch.device("cuda", device_id)
  249. provider = "CUDAExecutionProvider"
  250. else:
  251. device_id = 0
  252. device = torch.device("cpu")
  253. enable_cuda_graph = False
  254. provider = "CPUExecutionProvider"
  255. dtypes = {"fp32": torch.float32, "fp16": torch.float16, "bf16": torch.bfloat16}
  256. config = TestConfig(
  257. model_type=args.model_type,
  258. onnx_path=args.onnx_path,
  259. sam2_dir=args.sam2_dir,
  260. component=args.component,
  261. provider=provider,
  262. batch_size=args.batch_size,
  263. height=args.height,
  264. width=args.width,
  265. device=device,
  266. use_tf32=True,
  267. enable_cuda_graph=enable_cuda_graph,
  268. dtype=dtypes[args.dtype],
  269. prefer_nhwc=args.prefer_nhwc,
  270. repeats=args.repeats,
  271. warm_up=args.warm_up,
  272. enable_nvtx_profile=args.enable_nvtx_profile,
  273. enable_ort_profile=args.enable_ort_profile,
  274. enable_torch_profile=args.enable_torch_profile,
  275. torch_compile_mode=args.torch_compile_mode,
  276. verbose=False,
  277. )
  278. if args.engine == "ort":
  279. sess_options = SessionOptions()
  280. sess_options.intra_op_num_threads = args.intra_op_num_threads
  281. if config.enable_ort_profile:
  282. sess_options.enable_profiling = True
  283. sess_options.log_severity_level = 4
  284. sess_options.log_verbosity_level = 0
  285. session = create_session(config, sess_options)
  286. input_dict = config.random_inputs()
  287. # warm up session
  288. try:
  289. for _ in range(config.warm_up):
  290. _ = measure_latency(session, input_dict)
  291. except Exception as e:
  292. print(f"Failed to run {config=}. Exception: {e}")
  293. return
  294. if config.enable_nvtx_profile:
  295. import nvtx # noqa: PLC0415
  296. from cuda import cudart # noqa: PLC0415
  297. cudart.cudaProfilerStart()
  298. with nvtx.annotate("one_run"):
  299. _ = session.infer(input_dict)
  300. cudart.cudaProfilerStop()
  301. if config.enable_ort_profile:
  302. session.ort_session.end_profiling()
  303. if repeats == 0:
  304. return
  305. latency_list = []
  306. for _ in range(repeats):
  307. latency = measure_latency(session, input_dict)
  308. latency_list.append(latency)
  309. average_latency = statistics.mean(latency_list)
  310. del session
  311. else: # torch
  312. with torch.no_grad():
  313. try:
  314. average_latency = run_torch(config)
  315. except Exception as e:
  316. print(f"Failed to run {config=}. Exception: {e}")
  317. return
  318. if repeats == 0:
  319. return
  320. engine = args.engine + ":" + ("cuda" if use_gpu else "cpu")
  321. row = {
  322. "model_type": args.model_type,
  323. "component": args.component,
  324. "dtype": args.dtype,
  325. "use_gpu": use_gpu,
  326. "enable_cuda_graph": enable_cuda_graph,
  327. "prefer_nhwc": config.prefer_nhwc,
  328. "use_tf32": config.use_tf32,
  329. "batch_size": args.batch_size,
  330. "height": args.height,
  331. "width": args.width,
  332. "multi_mask_output": args.multimask_output,
  333. "num_labels": config.num_labels,
  334. "num_points": config.num_points,
  335. "num_masks": config.num_masks,
  336. "intra_op_num_threads": args.intra_op_num_threads,
  337. "warm_up": config.warm_up,
  338. "repeats": repeats,
  339. "enable_nvtx_profile": args.enable_nvtx_profile,
  340. "torch_compile_mode": args.torch_compile_mode,
  341. "engine": engine,
  342. "average_latency": average_latency,
  343. }
  344. if csv_writer is not None:
  345. csv_writer.writerow(row)
  346. print(f"{vars(config)}")
  347. print(f"{row}")
  348. def run_perf_test(args):
  349. features = "gpu" if args.use_gpu else "cpu"
  350. csv_filename = "benchmark_sam_{}_{}_{}.csv".format(
  351. features,
  352. args.engine,
  353. datetime.now().strftime("%Y%m%d-%H%M%S"),
  354. )
  355. with open(csv_filename, mode="a", newline="") as csv_file:
  356. column_names = [
  357. "model_type",
  358. "component",
  359. "dtype",
  360. "use_gpu",
  361. "enable_cuda_graph",
  362. "prefer_nhwc",
  363. "use_tf32",
  364. "batch_size",
  365. "height",
  366. "width",
  367. "multi_mask_output",
  368. "num_labels",
  369. "num_points",
  370. "num_masks",
  371. "intra_op_num_threads",
  372. "warm_up",
  373. "repeats",
  374. "enable_nvtx_profile",
  375. "torch_compile_mode",
  376. "engine",
  377. "average_latency",
  378. ]
  379. csv_writer = csv.DictWriter(csv_file, fieldnames=column_names)
  380. csv_writer.writeheader()
  381. run_test(args, csv_writer)
  382. def _parse_arguments():
  383. parser = argparse.ArgumentParser(description="Benchmark SMA2 for ONNX Runtime and PyTorch.")
  384. parser.add_argument(
  385. "--component",
  386. required=False,
  387. choices=["image_encoder", "image_decoder"],
  388. default="image_encoder",
  389. help="component to benchmark. Choices are image_encoder and image_decoder.",
  390. )
  391. parser.add_argument(
  392. "--dtype", required=False, choices=["fp32", "fp16", "bf16"], default="fp32", help="Data type for inference."
  393. )
  394. parser.add_argument(
  395. "--use_gpu",
  396. required=False,
  397. action="store_true",
  398. help="Use GPU for inference.",
  399. )
  400. parser.set_defaults(use_gpu=False)
  401. parser.add_argument(
  402. "--use_cuda_graph",
  403. required=False,
  404. action="store_true",
  405. help="Use cuda graph in onnxruntime.",
  406. )
  407. parser.set_defaults(use_cuda_graph=False)
  408. parser.add_argument(
  409. "--intra_op_num_threads",
  410. required=False,
  411. type=int,
  412. choices=[0, 1, 2, 4, 8, 16],
  413. default=0,
  414. help="intra_op_num_threads for onnxruntime. ",
  415. )
  416. parser.add_argument(
  417. "--batch_size",
  418. required=False,
  419. type=int,
  420. default=1,
  421. help="batch size",
  422. )
  423. parser.add_argument(
  424. "--height",
  425. required=False,
  426. type=int,
  427. default=1024,
  428. help="image height",
  429. )
  430. parser.add_argument(
  431. "--width",
  432. required=False,
  433. type=int,
  434. default=1024,
  435. help="image width",
  436. )
  437. parser.add_argument(
  438. "--repeats",
  439. required=False,
  440. type=int,
  441. default=1000,
  442. help="number of repeats for performance test. Default is 1000.",
  443. )
  444. parser.add_argument(
  445. "--warm_up",
  446. required=False,
  447. type=int,
  448. default=5,
  449. help="number of runs for warm up. Default is 5.",
  450. )
  451. parser.add_argument(
  452. "--engine",
  453. required=False,
  454. type=str,
  455. default="ort",
  456. choices=["ort", "torch"],
  457. help="engine for inference",
  458. )
  459. parser.add_argument(
  460. "--multimask_output",
  461. required=False,
  462. default=False,
  463. action="store_true",
  464. help="Export mask_decoder or image_decoder with multimask_output",
  465. )
  466. parser.add_argument(
  467. "--prefer_nhwc",
  468. required=False,
  469. default=False,
  470. action="store_true",
  471. help="Use prefer_nhwc=1 provider option for CUDAExecutionProvider",
  472. )
  473. parser.add_argument(
  474. "--enable_nvtx_profile",
  475. required=False,
  476. default=False,
  477. action="store_true",
  478. help="Enable nvtx profiling. It will add an extra run for profiling before performance test.",
  479. )
  480. parser.add_argument(
  481. "--enable_ort_profile",
  482. required=False,
  483. default=False,
  484. action="store_true",
  485. help="Enable ORT profiling.",
  486. )
  487. parser.add_argument(
  488. "--enable_torch_profile",
  489. required=False,
  490. default=False,
  491. action="store_true",
  492. help="Enable PyTorch profiling. It will add an extra run for profiling before performance test.",
  493. )
  494. parser.add_argument(
  495. "--model_type",
  496. required=False,
  497. type=str,
  498. default="sam2_hiera_large",
  499. choices=["sam2_hiera_tiny", "sam2_hiera_small", "sam2_hiera_large", "sam2_hiera_base_plus"],
  500. help="sam2 model name",
  501. )
  502. parser.add_argument(
  503. "--sam2_dir",
  504. required=False,
  505. type=str,
  506. default="./segment-anything-2",
  507. help="The directory of segment-anything-2 git root directory",
  508. )
  509. parser.add_argument(
  510. "--onnx_path",
  511. required=False,
  512. type=str,
  513. default="./sam2_onnx_models/sam2_hiera_large_image_encoder.onnx",
  514. help="path of onnx model",
  515. )
  516. parser.add_argument(
  517. "--torch_compile_mode",
  518. required=False,
  519. type=str,
  520. default=None,
  521. choices=["reduce-overhead", "max-autotune", "max-autotune-no-cudagraphs", "none"],
  522. help="torch compile mode. none will disable torch compile.",
  523. )
  524. args = parser.parse_args()
  525. return args
  526. if __name__ == "__main__":
  527. args = _parse_arguments()
  528. print(f"arguments:{args}")
  529. if args.torch_compile_mode is None:
  530. # image decoder will fail with compile modes other than "none".
  531. args.torch_compile_mode = "max-autotune" if args.component == "image_encoder" else "none"
  532. if args.use_gpu:
  533. assert torch.cuda.is_available()
  534. if args.engine == "ort":
  535. assert "CUDAExecutionProvider" in get_available_providers()
  536. args.enable_torch_profile = False
  537. else:
  538. # Only support cuda profiling for now.
  539. assert not args.enable_nvtx_profile
  540. assert not args.enable_torch_profile
  541. if args.enable_nvtx_profile or args.enable_torch_profile:
  542. run_test(args)
  543. else:
  544. run_perf_test(args)