benchmark.py 32 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862863864865866867868869870871872873874875876877878879880881882883884885886887888889890891892893894895896897898899900901902903904905906907908909910911912913914915916917918919920921922923924925926927928929930931932933934935936937938939940941942943944945
  1. # Copyright (c) Microsoft Corporation. All rights reserved.
  2. # Copyright 2018 The HuggingFace Inc. team.
  3. # Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
  4. #
  5. # Licensed under the Apache License, Version 2.0 (the "License");
  6. # you may not use this file except in compliance with the License.
  7. # You may obtain a copy of the License at
  8. #
  9. # http://www.apache.org/licenses/LICENSE-2.0
  10. #
  11. # Unless required by applicable law or agreed to in writing, software
  12. # distributed under the License is distributed on an "AS IS" BASIS,
  13. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  14. # See the License for the specific language governing permissions and
  15. # limitations under the License.
  16. """Benchmarking the inference of pretrained transformer models.
  17. PyTorch/TorchScript benchmark is based on https://github.com/huggingface/transformers/blob/master/examples/benchmarks.py.
  18. One difference is that random input_ids are generated in this benchmark.
  19. For onnxruntime, this script will convert a pretrained model to ONNX, and optimize it as selected by the -o/--optimizer_info parameter (by_script when omitted).
  20. Example commands:
  21. Export all models to ONNX, optimize and validate them:
  22. python benchmark.py -b 0 -o by_script -v -i 1 2 3
  23. Run OnnxRuntime on GPU for all models:
  24. python benchmark.py -g
  25. Run OnnxRuntime on GPU for all models with fp32 optimization:
  26. python benchmark.py -g -o by_script
  27. Run OnnxRuntime on GPU with fp16 optimization:
  28. python benchmark.py -g -o by_script -p "fp16"
  29. Run TorchScript on GPU for all models:
  30. python benchmark.py -e torchscript -g
  31. Run TorchScript on GPU for all models with fp16:
  32. python benchmark.py -e torchscript -g -p "fp16"
  33. Run ONNXRuntime and TorchScript on CPU for all models with quantization:
  34. python benchmark.py -e torchscript onnxruntime -p "int8" -o by_script
  35. Run OnnxRuntime with the ROCM provider and graph optimization script:
  36. python benchmark.py -g -m bert-base-cased --provider rocm --optimizer_info by_script --disable_embed_layer_norm
  37. Run OnnxRuntime with bfloat16 fastmath mode kernels on aarch64 platforms with bfloat16 support:
  38. python benchmark.py --enable_arm64_bfloat16_fastmath_mlas_gemm
  39. It is recommended to use run_benchmark.sh to launch benchmark.
  40. """
  41. import argparse
  42. import logging
  43. import os
  44. import timeit
  45. from datetime import datetime
  46. import numpy
  47. import psutil
  48. from benchmark_helper import (
  49. ConfigModifier,
  50. OptimizerInfo,
  51. Precision,
  52. create_onnxruntime_session,
  53. get_latency_result,
  54. inference_ort,
  55. inference_ort_with_io_binding,
  56. output_details,
  57. output_fusion_statistics,
  58. output_summary,
  59. setup_logger,
  60. )
  61. from fusion_options import FusionOptions
  62. from huggingface_models import MODEL_CLASSES, MODELS
  63. from onnx_exporter import (
  64. create_onnxruntime_input,
  65. export_onnx_model_from_pt,
  66. export_onnx_model_from_tf,
  67. load_pretrained_model,
  68. )
  69. from packaging import version
  70. from quantize_helper import QuantizeHelper
  71. logger = logging.getLogger("")
  72. cpu_count = psutil.cpu_count(logical=False)
  73. # Set OMP environment variable before importing onnxruntime or torch.
  74. if "OMP_NUM_THREADS" not in os.environ:
  75. os.environ["OMP_NUM_THREADS"] = str(cpu_count)
  76. import torch # noqa: E402
  77. from transformers import AutoConfig, AutoTokenizer, LxmertConfig # noqa: E402
  78. def run_onnxruntime(
  79. use_gpu,
  80. provider,
  81. model_names,
  82. model_class,
  83. config_modifier,
  84. precision,
  85. num_threads,
  86. batch_sizes,
  87. sequence_lengths,
  88. repeat_times,
  89. input_counts,
  90. optimizer_info,
  91. validate_onnx,
  92. cache_dir,
  93. onnx_dir,
  94. verbose,
  95. overwrite,
  96. disable_ort_io_binding,
  97. use_raw_attention_mask,
  98. model_fusion_statistics,
  99. model_source,
  100. enable_arm64_bfloat16_fastmath_mlas_gemm,
  101. args,
  102. ):
  103. import onnxruntime # noqa: PLC0415
  104. results = []
  105. if (
  106. use_gpu
  107. and ("CUDAExecutionProvider" not in onnxruntime.get_available_providers())
  108. and ("MIGraphXExecutionProvider" not in onnxruntime.get_available_providers())
  109. and ("ROCMExecutionProvider" not in onnxruntime.get_available_providers())
  110. and ("DmlExecutionProvider" not in onnxruntime.get_available_providers())
  111. ):
  112. logger.error(
  113. "Please install onnxruntime-gpu or onnxruntime-directml package instead of onnxruntime, and use a machine with GPU for testing gpu performance."
  114. )
  115. return results
  116. warm_up_repeat = 0
  117. if provider == "tensorrt":
  118. optimizer_info = OptimizerInfo.NOOPT
  119. warm_up_repeat = 5
  120. if "TensorrtExecutionProvider" not in onnxruntime.get_available_providers():
  121. logger.error(
  122. "Please install onnxruntime-gpu-tensorrt package, and use a machine with GPU for testing gpu performance."
  123. )
  124. return results
  125. if optimizer_info == OptimizerInfo.NOOPT:
  126. logger.warning(
  127. f"OptimizerInfo is set to {optimizer_info}, graph optimizations specified in FusionOptions are not applied."
  128. )
  129. for model_name in model_names:
  130. all_input_names = MODELS[model_name][0]
  131. for num_inputs in input_counts:
  132. if num_inputs > len(all_input_names):
  133. break
  134. input_names = all_input_names[:num_inputs]
  135. args.model_type = MODELS[model_name][3]
  136. fusion_options = FusionOptions.parse(args)
  137. if "pt" in model_source:
  138. with torch.no_grad():
  139. (
  140. onnx_model_file,
  141. is_valid_onnx_model,
  142. vocab_size,
  143. max_sequence_length,
  144. ) = export_onnx_model_from_pt(
  145. model_name,
  146. MODELS[model_name][1],
  147. MODELS[model_name][2],
  148. MODELS[model_name][3],
  149. model_class,
  150. config_modifier,
  151. cache_dir,
  152. onnx_dir,
  153. input_names,
  154. use_gpu,
  155. precision,
  156. optimizer_info,
  157. validate_onnx,
  158. use_raw_attention_mask,
  159. overwrite,
  160. model_fusion_statistics,
  161. fusion_options,
  162. )
  163. if "tf" in model_source:
  164. (
  165. onnx_model_file,
  166. is_valid_onnx_model,
  167. vocab_size,
  168. max_sequence_length,
  169. ) = export_onnx_model_from_tf(
  170. model_name,
  171. MODELS[model_name][1],
  172. MODELS[model_name][2],
  173. MODELS[model_name][3],
  174. model_class,
  175. config_modifier,
  176. cache_dir,
  177. onnx_dir,
  178. input_names,
  179. use_gpu,
  180. precision,
  181. optimizer_info,
  182. validate_onnx,
  183. use_raw_attention_mask,
  184. overwrite,
  185. model_fusion_statistics,
  186. fusion_options,
  187. )
  188. if not is_valid_onnx_model:
  189. continue
  190. ort_session = create_onnxruntime_session(
  191. onnx_model_file,
  192. use_gpu,
  193. provider,
  194. enable_all_optimization=True,
  195. num_threads=num_threads,
  196. verbose=verbose,
  197. enable_mlas_gemm_fastmath_arm64_bfloat16=enable_arm64_bfloat16_fastmath_mlas_gemm,
  198. )
  199. if ort_session is None:
  200. continue
  201. ort_output_names = [node_arg.name for node_arg in ort_session.get_outputs()]
  202. output_buffers = []
  203. device = "cuda" if use_gpu else "cpu"
  204. config = AutoConfig.from_pretrained(model_name, cache_dir=cache_dir)
  205. max_last_state_size = numpy.prod(
  206. [
  207. max(batch_sizes),
  208. max(sequence_lengths),
  209. max(vocab_size, config.hidden_size),
  210. ]
  211. )
  212. max_pooler_size = numpy.prod([max(batch_sizes), config.hidden_size])
  213. for batch_size in batch_sizes:
  214. if batch_size <= 0:
  215. continue
  216. for sequence_length in sequence_lengths:
  217. if max_sequence_length is not None and sequence_length > max_sequence_length:
  218. continue
  219. input_value_type = numpy.int64 if "pt" in model_source else numpy.int32
  220. ort_inputs = create_onnxruntime_input(
  221. vocab_size,
  222. batch_size,
  223. sequence_length,
  224. input_names,
  225. config,
  226. input_value_type,
  227. )
  228. result_template = {
  229. "engine": "onnxruntime",
  230. "version": onnxruntime.__version__,
  231. "providers": provider,
  232. "device": device,
  233. "optimizer": optimizer_info,
  234. "precision": precision,
  235. "io_binding": not disable_ort_io_binding,
  236. "model_name": model_name,
  237. "inputs": num_inputs,
  238. "threads": num_threads,
  239. "batch_size": batch_size,
  240. "sequence_length": sequence_length,
  241. "custom_layer_num": config_modifier.get_layer_num(),
  242. "datetime": str(datetime.now()),
  243. }
  244. if config.model_type in ["vit", "swin"]:
  245. logger.info(
  246. f"Run onnxruntime on {model_name} with input shape {[batch_size, 3, config.image_size, config.image_size]}"
  247. )
  248. else:
  249. logger.info(f"Run onnxruntime on {model_name} with input shape {[batch_size, sequence_length]}")
  250. if disable_ort_io_binding:
  251. result = inference_ort(
  252. ort_session,
  253. ort_inputs,
  254. result_template,
  255. repeat_times,
  256. batch_size,
  257. warm_up_repeat,
  258. )
  259. else:
  260. # Get output sizes from a dummy ort run
  261. ort_outputs = ort_session.run(ort_output_names, ort_inputs)
  262. output_buffer_max_sizes = [max_last_state_size]
  263. for i in range(len(ort_outputs)):
  264. if i == 2 and MODELS[model_name][3] == "gpt":
  265. # past state output max size
  266. output_buffer_max_sizes.append(max_pooler_size)
  267. else:
  268. output_buffer_max_sizes.append(max_last_state_size)
  269. data_type = numpy.longlong if "pt" in model_source else numpy.intc
  270. result = inference_ort_with_io_binding(
  271. ort_session,
  272. ort_inputs,
  273. result_template,
  274. repeat_times,
  275. ort_output_names,
  276. ort_outputs,
  277. output_buffers,
  278. output_buffer_max_sizes,
  279. batch_size,
  280. device,
  281. data_type,
  282. warm_up_repeat,
  283. )
  284. logger.info(result)
  285. results.append(result)
  286. return results
  287. def run_pytorch(
  288. use_gpu,
  289. model_names,
  290. model_class,
  291. config_modifier,
  292. precision,
  293. num_threads,
  294. batch_sizes,
  295. sequence_lengths,
  296. repeat_times,
  297. torchscript,
  298. torch2,
  299. cache_dir,
  300. verbose,
  301. ):
  302. results = []
  303. if use_gpu and not torch.cuda.is_available():
  304. logger.error("Please install PyTorch with Cuda, and use a machine with GPU for testing gpu performance.")
  305. return results
  306. torch.set_grad_enabled(False)
  307. for model_name in model_names:
  308. config = AutoConfig.from_pretrained(model_name, torchscript=torchscript, cache_dir=cache_dir)
  309. config_modifier.modify(config)
  310. model = load_pretrained_model(
  311. model_name,
  312. config=config,
  313. cache_dir=cache_dir,
  314. custom_model_class=model_class,
  315. )
  316. if config.model_type in ["vit", "swin"]:
  317. # These models don't use sequence lengths, so just pick the first sequence length so that the summary still works
  318. sequence_lengths = [sequence_lengths[0]]
  319. else:
  320. tokenizer = AutoTokenizer.from_pretrained(model_name, cache_dir=cache_dir)
  321. max_input_size = tokenizer.model_max_length
  322. logger.debug(f"Model {model}")
  323. logger.debug(f"Number of parameters {model.num_parameters()}")
  324. if precision == Precision.FLOAT16:
  325. model.half()
  326. device = torch.device("cuda:0" if use_gpu else "cpu")
  327. model.to(device)
  328. if precision == Precision.INT8:
  329. model = QuantizeHelper.quantize_torch_model(model)
  330. for batch_size in batch_sizes:
  331. if batch_size <= 0:
  332. continue
  333. for sequence_length in sequence_lengths:
  334. if config.model_type in ["vit", "swin"]:
  335. logger.info(
  336. f"Run PyTorch on {model_name} with input shape {[batch_size, 3, config.image_size, config.image_size]}"
  337. )
  338. input_ids = torch.randn(
  339. size=(batch_size, 3, config.image_size, config.image_size),
  340. dtype=torch.float16 if precision == Precision.FLOAT16 else torch.float32,
  341. device=device,
  342. )
  343. else:
  344. if max_input_size is not None and sequence_length > max_input_size:
  345. continue
  346. logger.info(f"Run PyTorch on {model_name} with input shape {[batch_size, sequence_length]}")
  347. input_ids = torch.randint(
  348. low=0,
  349. high=config.vocab_size - 1,
  350. size=(batch_size, sequence_length),
  351. dtype=torch.long,
  352. device=device,
  353. )
  354. try:
  355. inference = (
  356. torch.jit.trace(model, input_ids) if torchscript else torch.compile(model) if torch2 else model
  357. )
  358. inference(input_ids)
  359. runtimes = timeit.repeat(lambda: inference(input_ids), repeat=repeat_times, number=1) # noqa: B023
  360. result = {
  361. "engine": "torchscript" if torchscript else "torch2" if torch2 else "torch",
  362. "version": torch.__version__,
  363. "providers": "NA",
  364. "device": "cuda" if use_gpu else "cpu",
  365. "optimizer": "",
  366. "precision": precision,
  367. "io_binding": "",
  368. "model_name": model_name,
  369. "inputs": 1,
  370. "threads": num_threads,
  371. "batch_size": batch_size,
  372. "sequence_length": sequence_length,
  373. "custom_layer_num": config_modifier.get_layer_num(),
  374. "datetime": str(datetime.now()),
  375. }
  376. result.update(get_latency_result(runtimes, batch_size))
  377. logger.info(result)
  378. results.append(result)
  379. except RuntimeError as e:
  380. logger.exception(e)
  381. torch.cuda.empty_cache()
  382. return results
  383. def run_with_tf_optimizations(do_eager_mode: bool, use_xla: bool):
  384. from functools import wraps # noqa: PLC0415
  385. import tensorflow as tf # noqa: PLC0415
  386. def run_func(func):
  387. @wraps(func)
  388. def run_in_eager_mode(*args, **kwargs):
  389. return func(*args, **kwargs)
  390. @wraps(func)
  391. @tf.function(experimental_compile=use_xla)
  392. def run_in_graph_mode(*args, **kwargs):
  393. return func(*args, **kwargs)
  394. if do_eager_mode is True:
  395. assert use_xla is False, (
  396. "Cannot run model in XLA, if `args.eager_mode` is set to `True`. Please set `args.eager_mode=False`."
  397. )
  398. return run_in_eager_mode
  399. else:
  400. return run_in_graph_mode
  401. return run_func
  402. def run_tensorflow(
  403. use_gpu,
  404. model_names,
  405. model_class,
  406. config_modifier,
  407. precision,
  408. num_threads,
  409. batch_sizes,
  410. sequence_lengths,
  411. repeat_times,
  412. cache_dir,
  413. verbose,
  414. ):
  415. results = []
  416. import tensorflow as tf # noqa: PLC0415
  417. tf.config.threading.set_intra_op_parallelism_threads(num_threads)
  418. if not use_gpu:
  419. tf.config.set_visible_devices([], "GPU")
  420. if use_gpu and not tf.test.is_built_with_cuda():
  421. logger.error("Please install Tensorflow-gpu, and use a machine with GPU for testing gpu performance.")
  422. return results
  423. if use_gpu: # Restrict TensorFlow to only use the first GPU
  424. physical_devices = tf.config.list_physical_devices("GPU")
  425. try:
  426. tf.config.set_visible_devices(physical_devices[0], "GPU")
  427. tf.config.experimental.set_memory_growth(physical_devices[0], True)
  428. tf.distribute.OneDeviceStrategy(device="/gpu:0")
  429. except RuntimeError as e:
  430. logger.exception(e)
  431. if precision == Precision.FLOAT16 or precision == Precision.INT8:
  432. raise NotImplementedError("Mixed precision is currently not supported.")
  433. for model_name in model_names:
  434. config = AutoConfig.from_pretrained(model_name, cache_dir=cache_dir)
  435. config_modifier.modify(config)
  436. model = load_pretrained_model(
  437. model_name,
  438. config=config,
  439. cache_dir=cache_dir,
  440. custom_model_class=model_class,
  441. is_tf_model=True,
  442. )
  443. tokenizer = AutoTokenizer.from_pretrained(model_name, cache_dir=cache_dir)
  444. max_input_size = tokenizer.model_max_length
  445. for batch_size in batch_sizes:
  446. if batch_size <= 0:
  447. continue
  448. for sequence_length in sequence_lengths:
  449. if max_input_size is not None and sequence_length > max_input_size:
  450. continue
  451. logger.info(f"Run Tensorflow on {model_name} with input shape {[batch_size, sequence_length]}")
  452. import random # noqa: PLC0415
  453. rng = random.Random()
  454. values = [rng.randint(0, config.vocab_size - 1) for i in range(batch_size * sequence_length)]
  455. input_ids = tf.constant(values, shape=(batch_size, sequence_length), dtype=tf.int32)
  456. try:
  457. # Disable both for better inference perf
  458. @run_with_tf_optimizations(do_eager_mode=False, use_xla=False)
  459. def encoder_forward():
  460. return model(input_ids, training=False) # noqa: B023
  461. @run_with_tf_optimizations(do_eager_mode=False, use_xla=False)
  462. def encoder_decoder_forward():
  463. return model(input_ids, decoder_input_ids=input_ids, training=False) # noqa: B023
  464. @run_with_tf_optimizations(do_eager_mode=False, use_xla=False)
  465. def lxmert_forward():
  466. feats = tf.random.normal([1, 1, config.visual_feat_dim]) # noqa: B023
  467. pos = tf.random.normal([1, 1, config.visual_pos_dim]) # noqa: B023
  468. return model( # noqa: B023
  469. input_ids, # noqa: B023
  470. visual_feats=feats,
  471. visual_pos=pos,
  472. training=False,
  473. )
  474. inference = encoder_forward
  475. if config.is_encoder_decoder:
  476. inference = encoder_decoder_forward
  477. elif isinstance(config, LxmertConfig):
  478. inference = lxmert_forward
  479. inference()
  480. runtimes = timeit.repeat(lambda: inference(), repeat=repeat_times, number=1) # noqa: B023
  481. result = {
  482. "engine": "tensorflow",
  483. "version": tf.__version__,
  484. "providers": "NA",
  485. "device": "cuda" if use_gpu else "cpu",
  486. "optimizer": "",
  487. "precision": precision,
  488. "io_binding": "",
  489. "model_name": model_name,
  490. "inputs": 1,
  491. "threads": num_threads,
  492. "batch_size": batch_size,
  493. "sequence_length": sequence_length,
  494. "custom_layer_num": config_modifier.get_layer_num(),
  495. "datetime": str(datetime.now()),
  496. }
  497. result.update(get_latency_result(runtimes, batch_size))
  498. logger.info(result)
  499. results.append(result)
  500. except RuntimeError as e:
  501. logger.exception(e)
  502. from numba import cuda # noqa: PLC0415
  503. device = cuda.get_current_device()
  504. device.reset()
  505. return results
  506. def parse_arguments():
  507. parser = argparse.ArgumentParser()
  508. parser.add_argument(
  509. "-m",
  510. "--models",
  511. required=False,
  512. nargs="+",
  513. type=str,
  514. default=["bert-base-cased", "roberta-base", "gpt2"],
  515. choices=list(MODELS.keys()),
  516. help="Pre-trained models in the list: " + ", ".join(MODELS.keys()),
  517. )
  518. parser.add_argument(
  519. "--model_source",
  520. required=False,
  521. nargs=1,
  522. type=str,
  523. default="pt",
  524. choices=["pt", "tf"],
  525. help="Export onnx from pt or tf",
  526. )
  527. parser.add_argument(
  528. "--model_class",
  529. required=False,
  530. type=str,
  531. default=None,
  532. choices=list(MODEL_CLASSES),
  533. help="Model type selected in the list: " + ", ".join(MODEL_CLASSES),
  534. )
  535. parser.add_argument(
  536. "-e",
  537. "--engines",
  538. required=False,
  539. nargs="+",
  540. type=str,
  541. default=["onnxruntime"],
  542. choices=["onnxruntime", "torch", "torch2", "torchscript", "tensorflow"],
  543. help="Engines to benchmark",
  544. )
  545. parser.add_argument(
  546. "-c",
  547. "--cache_dir",
  548. required=False,
  549. type=str,
  550. default=os.path.join(".", "cache_models"),
  551. help="Directory to cache pre-trained models",
  552. )
  553. parser.add_argument(
  554. "--onnx_dir",
  555. required=False,
  556. type=str,
  557. default=os.path.join(".", "onnx_models"),
  558. help="Directory to store onnx models",
  559. )
  560. parser.add_argument("-g", "--use_gpu", required=False, action="store_true", help="Run on gpu device")
  561. parser.add_argument(
  562. "--provider",
  563. required=False,
  564. type=str,
  565. default=None,
  566. help="Execution provider to use",
  567. )
  568. parser.add_argument(
  569. "-p",
  570. "--precision",
  571. type=Precision,
  572. default=Precision.FLOAT32,
  573. choices=list(Precision),
  574. help="Precision of model to run. fp32 for full precision, fp16 for half precision, and int8 for quantization",
  575. )
  576. parser.add_argument("--verbose", required=False, action="store_true", help="Print more information")
  577. parser.add_argument(
  578. "--overwrite",
  579. required=False,
  580. action="store_true",
  581. help="Overwrite existing models",
  582. )
  583. parser.add_argument(
  584. "-o",
  585. "--optimizer_info",
  586. type=OptimizerInfo,
  587. default=OptimizerInfo.BYSCRIPT,
  588. choices=list(OptimizerInfo),
  589. help="Optimizer info: Use optimizer.py to optimize onnx model as default. Can also choose from by_ort and no_opt",
  590. )
  591. parser.add_argument(
  592. "-v",
  593. "--validate_onnx",
  594. required=False,
  595. action="store_true",
  596. help="Validate ONNX model",
  597. )
  598. parser.add_argument(
  599. "-f",
  600. "--fusion_csv",
  601. required=False,
  602. default=None,
  603. help="CSV file for saving summary results of graph optimization.",
  604. )
  605. parser.add_argument(
  606. "-d",
  607. "--detail_csv",
  608. required=False,
  609. default=None,
  610. help="CSV file for saving detail results.",
  611. )
  612. parser.add_argument(
  613. "-r",
  614. "--result_csv",
  615. required=False,
  616. default=None,
  617. help="CSV file for saving summary results.",
  618. )
  619. parser.add_argument(
  620. "-i",
  621. "--input_counts",
  622. required=False,
  623. nargs="+",
  624. default=[1],
  625. type=int,
  626. choices=[1, 2, 3],
  627. help="Number of ONNX model inputs. Please use 1 for fair comparison with Torch or TorchScript.",
  628. )
  629. parser.add_argument(
  630. "-t",
  631. "--test_times",
  632. required=False,
  633. default=100,
  634. type=int,
  635. help="Number of repeat times to get average inference latency.",
  636. )
  637. parser.add_argument("-b", "--batch_sizes", nargs="+", type=int, default=[1])
  638. parser.add_argument(
  639. "-s",
  640. "--sequence_lengths",
  641. nargs="+",
  642. type=int,
  643. default=[4, 8, 16, 32, 64, 128, 256],
  644. )
  645. parser.add_argument(
  646. "--disable_ort_io_binding",
  647. required=False,
  648. action="store_true",
  649. help="Disable running ONNX Runtime with binded inputs and outputs. ",
  650. )
  651. parser.set_defaults(disable_ort_io_binding=False)
  652. parser.add_argument(
  653. "-n",
  654. "--num_threads",
  655. required=False,
  656. nargs="+",
  657. type=int,
  658. default=[0],
  659. help="Threads to use",
  660. )
  661. parser.add_argument(
  662. "--force_num_layers",
  663. required=False,
  664. type=int,
  665. default=None,
  666. help="Manually set the model's layer number",
  667. )
  668. parser.add_argument(
  669. "--enable_arm64_bfloat16_fastmath_mlas_gemm",
  670. required=False,
  671. action="store_true",
  672. help="Enable bfloat16 mlas gemm kernels on aarch64. Supported only for CPU EP ",
  673. )
  674. parser.set_defaults(enable_arm64_bfloat16_fastmath_mlas_gemm=False)
  675. FusionOptions.add_arguments(parser)
  676. args = parser.parse_args()
  677. return args
  678. def main():
  679. args = parse_arguments()
  680. setup_logger(args.verbose)
  681. if args.precision == Precision.FLOAT16 and not args.use_gpu:
  682. logger.error("fp16 is for GPU only")
  683. return
  684. if args.precision == Precision.INT8 and args.use_gpu and args.provider not in ["migraphx", "rocm"]:
  685. logger.error("int8 is for CPU only")
  686. return
  687. if len(args.models) == 1 and MODELS[args.models[0]][3] in ["vit", "swim"]:
  688. args.sequence_lengths = [""]
  689. args.num_threads = sorted({cpu_count if x <= 0 else x for x in args.num_threads})
  690. logger.info(f"Arguments: {args}")
  691. if not os.path.exists(args.cache_dir):
  692. try:
  693. os.mkdir(args.cache_dir)
  694. except OSError:
  695. logger.error("Creation of the directory %s failed", args.cache_dir)
  696. enable_torch = "torch" in args.engines
  697. enable_torch2 = "torch2" in args.engines
  698. enable_torchscript = "torchscript" in args.engines
  699. enable_onnxruntime = "onnxruntime" in args.engines
  700. enable_tensorflow = "tensorflow" in args.engines
  701. if enable_torch2 and version.parse(torch.__version__) < version.parse("2.0.0"):
  702. logger.error(f"PyTorch version must be >=2.0.0 and you are using {torch.__version__}")
  703. return
  704. config_modifier = ConfigModifier(args.force_num_layers)
  705. results = []
  706. for num_threads in args.num_threads:
  707. torch.set_num_threads(num_threads)
  708. logger.debug(torch.__config__.parallel_info())
  709. if enable_torch or enable_torch2 or enable_torchscript:
  710. if args.input_counts != [1]:
  711. logger.warning("--input_counts is not implemented for torch or torchscript engine.")
  712. if enable_torchscript:
  713. results += run_pytorch(
  714. args.use_gpu,
  715. args.models,
  716. args.model_class,
  717. config_modifier,
  718. args.precision,
  719. num_threads,
  720. args.batch_sizes,
  721. args.sequence_lengths,
  722. args.test_times,
  723. True,
  724. False,
  725. args.cache_dir,
  726. args.verbose,
  727. )
  728. if enable_torch:
  729. results += run_pytorch(
  730. args.use_gpu,
  731. args.models,
  732. args.model_class,
  733. config_modifier,
  734. args.precision,
  735. num_threads,
  736. args.batch_sizes,
  737. args.sequence_lengths,
  738. args.test_times,
  739. False,
  740. False,
  741. args.cache_dir,
  742. args.verbose,
  743. )
  744. if enable_torch2:
  745. results += run_pytorch(
  746. args.use_gpu,
  747. args.models,
  748. args.model_class,
  749. config_modifier,
  750. args.precision,
  751. num_threads,
  752. args.batch_sizes,
  753. args.sequence_lengths,
  754. args.test_times,
  755. False,
  756. True,
  757. args.cache_dir,
  758. args.verbose,
  759. )
  760. if enable_tensorflow:
  761. results += run_tensorflow(
  762. args.use_gpu,
  763. args.models,
  764. args.model_class,
  765. config_modifier,
  766. args.precision,
  767. num_threads,
  768. args.batch_sizes,
  769. args.sequence_lengths,
  770. args.test_times,
  771. args.cache_dir,
  772. args.verbose,
  773. )
  774. model_fusion_statistics = {}
  775. if enable_onnxruntime:
  776. try:
  777. use_raw_attention_mask = not args.use_mask_index
  778. results += run_onnxruntime(
  779. args.use_gpu,
  780. args.provider,
  781. args.models,
  782. args.model_class,
  783. config_modifier,
  784. args.precision,
  785. num_threads,
  786. args.batch_sizes,
  787. args.sequence_lengths,
  788. args.test_times,
  789. args.input_counts,
  790. args.optimizer_info,
  791. args.validate_onnx,
  792. args.cache_dir,
  793. args.onnx_dir,
  794. args.verbose,
  795. args.overwrite,
  796. args.disable_ort_io_binding,
  797. use_raw_attention_mask,
  798. model_fusion_statistics,
  799. args.model_source,
  800. args.enable_arm64_bfloat16_fastmath_mlas_gemm,
  801. args,
  802. )
  803. except Exception:
  804. logger.exception("Exception")
  805. time_stamp = datetime.now().strftime("%Y%m%d-%H%M%S")
  806. if model_fusion_statistics:
  807. csv_filename = args.fusion_csv or f"benchmark_fusion_{time_stamp}.csv"
  808. output_fusion_statistics(model_fusion_statistics, csv_filename)
  809. if len(results) == 0:
  810. if args.batch_sizes != [0]:
  811. logger.warning("No any result available.")
  812. return
  813. csv_filename = args.detail_csv or f"benchmark_detail_{time_stamp}.csv"
  814. output_details(results, csv_filename)
  815. csv_filename = args.result_csv or f"benchmark_summary_{time_stamp}.csv"
  816. output_summary(results, csv_filename, args)
# Run the benchmark only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()