benchmark_helper.py 22 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647
  1. # -------------------------------------------------------------------------
  2. # Copyright (c) Microsoft Corporation. All rights reserved.
  3. # Licensed under the MIT License. See License.txt in the project root for
  4. # license information.
  5. # --------------------------------------------------------------------------
  6. import csv
  7. import logging
  8. import os
  9. import random
  10. import sys
  11. import time
  12. import timeit
  13. from abc import ABC, abstractmethod
  14. from concurrent.futures import ThreadPoolExecutor
  15. from datetime import datetime
  16. from enum import Enum
  17. from time import sleep
  18. from typing import Any
  19. import coloredlogs
  20. import numpy
  21. import torch
  22. import transformers
  23. from packaging import version
  24. import onnxruntime
  25. logger = logging.getLogger(__name__)
  26. class Precision(Enum):
  27. FLOAT32 = "fp32"
  28. FLOAT16 = "fp16"
  29. INT8 = "int8"
  30. INT4 = "int4"
  31. def __str__(self):
  32. return self.value
  33. class OptimizerInfo(Enum):
  34. # no_opt means using the raw ONNX model, but OnnxRuntime might still apply optimization as long as
  35. # graph optimization level is not 0 (disable all).
  36. NOOPT = "no_opt"
  37. BYORT = "by_ort"
  38. BYSCRIPT = "by_script"
  39. def __str__(self):
  40. return self.value
  41. class ConfigModifier:
  42. def __init__(self, num_layers):
  43. self.num_layers = num_layers
  44. def modify(self, config):
  45. if self.num_layers is None:
  46. return
  47. if hasattr(config, "num_hidden_layers"):
  48. config.num_hidden_layers = self.num_layers
  49. logger.info(f"Modifying pytorch model's number of hidden layers to: {self.num_layers}")
  50. if hasattr(config, "encoder_layers"):
  51. config.encoder_layers = self.num_layers
  52. logger.info(f"Modifying pytorch model's number of encoder layers to: {self.num_layers}")
  53. if hasattr(config, "decoder_layers "):
  54. config.decoder_layers = self.num_layers
  55. logger.info(f"Modifying pytorch model's number of decoder layers to: {self.num_layers}")
  56. def get_layer_num(self):
  57. return self.num_layers
  58. IO_BINDING_DATA_TYPE_MAP = {
  59. "float32": numpy.float32,
  60. # TODO: Add more.
  61. }
  62. def create_onnxruntime_session(
  63. onnx_model_path,
  64. use_gpu,
  65. provider=None,
  66. enable_all_optimization=True,
  67. num_threads=-1,
  68. enable_profiling=False,
  69. verbose=False,
  70. enable_mlas_gemm_fastmath_arm64_bfloat16=False,
  71. provider_options={}, # map execution provider name to its option # noqa: B006
  72. ):
  73. sess_options = onnxruntime.SessionOptions()
  74. if enable_all_optimization:
  75. sess_options.graph_optimization_level = onnxruntime.GraphOptimizationLevel.ORT_ENABLE_ALL
  76. else:
  77. sess_options.graph_optimization_level = onnxruntime.GraphOptimizationLevel.ORT_ENABLE_BASIC
  78. if enable_profiling:
  79. sess_options.enable_profiling = True
  80. if num_threads > 0:
  81. sess_options.intra_op_num_threads = num_threads
  82. logger.debug(f"Session option: intra_op_num_threads={sess_options.intra_op_num_threads}")
  83. if verbose:
  84. sess_options.log_severity_level = 0
  85. else:
  86. sess_options.log_severity_level = 4
  87. if provider in onnxruntime.get_available_providers():
  88. providers = [provider]
  89. elif use_gpu:
  90. if provider == "dml":
  91. providers = ["DmlExecutionProvider", "CPUExecutionProvider"]
  92. elif provider == "rocm":
  93. providers = ["ROCMExecutionProvider", "CPUExecutionProvider"]
  94. elif provider == "migraphx":
  95. providers = [
  96. "MIGraphXExecutionProvider",
  97. "ROCMExecutionProvider",
  98. "CPUExecutionProvider",
  99. ]
  100. elif provider == "cuda" or provider is None:
  101. providers = ["CUDAExecutionProvider", "CPUExecutionProvider"]
  102. elif provider == "tensorrt":
  103. providers = [
  104. "TensorrtExecutionProvider",
  105. "CUDAExecutionProvider",
  106. "CPUExecutionProvider",
  107. ]
  108. else:
  109. raise RuntimeError(f"The execution provider is not supported: {provider}")
  110. else:
  111. providers = ["CPUExecutionProvider"]
  112. if provider_options:
  113. providers = [(name, provider_options[name]) if name in provider_options else name for name in providers]
  114. if enable_mlas_gemm_fastmath_arm64_bfloat16:
  115. sess_options.add_session_config_entry("mlas.enable_gemm_fastmath_arm64_bfloat16", "1")
  116. session = None
  117. try:
  118. session = onnxruntime.InferenceSession(onnx_model_path, sess_options, providers=providers)
  119. except Exception:
  120. logger.exception(f"Failed to create session for {onnx_model_path} with providers={providers}")
  121. return session
  122. def setup_logger(verbose=True):
  123. if verbose:
  124. coloredlogs.install(
  125. level="DEBUG",
  126. fmt="[%(filename)s:%(lineno)s - %(funcName)20s()] %(message)s",
  127. )
  128. else:
  129. coloredlogs.install(fmt="%(message)s")
  130. logging.getLogger("transformers").setLevel(logging.WARNING)
  131. def prepare_environment(cache_dir, output_dir, use_gpu, provider=None):
  132. if cache_dir and not os.path.exists(cache_dir):
  133. os.makedirs(cache_dir)
  134. if output_dir and not os.path.exists(output_dir):
  135. os.makedirs(output_dir)
  136. if use_gpu:
  137. if provider == "dml":
  138. assert "DmlExecutionProvider" in onnxruntime.get_available_providers(), (
  139. "Please install onnxruntime-directml package to test GPU inference."
  140. )
  141. else:
  142. assert not set(onnxruntime.get_available_providers()).isdisjoint(
  143. ["CUDAExecutionProvider", "ROCMExecutionProvider", "MIGraphXExecutionProvider"]
  144. ), "Please install onnxruntime-gpu package, or install ROCm support, to test GPU inference."
  145. logger.info(f"PyTorch Version:{torch.__version__}")
  146. logger.info(f"Transformers Version:{transformers.__version__}")
  147. logger.info(f"OnnxRuntime Version:{onnxruntime.__version__}")
  148. # Support three major versions of PyTorch and OnnxRuntime, and up to 9 months of transformers.
  149. assert version.parse(torch.__version__) >= version.parse("1.10.0")
  150. assert version.parse(transformers.__version__) >= version.parse("4.12.0")
  151. assert version.parse(onnxruntime.__version__) >= version.parse("1.10.0")
  152. def get_latency_result(latency_list, batch_size):
  153. latency_ms = sum(latency_list) / float(len(latency_list)) * 1000.0
  154. latency_variance = numpy.var(latency_list, dtype=numpy.float64) * 1000.0
  155. throughput = batch_size * (1000.0 / latency_ms)
  156. return {
  157. "test_times": len(latency_list),
  158. "latency_variance": f"{latency_variance:.2f}",
  159. "latency_90_percentile": f"{numpy.percentile(latency_list, 90) * 1000.0:.2f}",
  160. "latency_95_percentile": f"{numpy.percentile(latency_list, 95) * 1000.0:.2f}",
  161. "latency_99_percentile": f"{numpy.percentile(latency_list, 99) * 1000.0:.2f}",
  162. "average_latency_ms": f"{latency_ms:.2f}",
  163. "QPS": f"{throughput:.2f}",
  164. }
  165. def output_details(results, csv_filename):
  166. with open(csv_filename, mode="a", newline="", encoding="ascii") as csv_file:
  167. column_names = [
  168. "engine",
  169. "version",
  170. "providers",
  171. "device",
  172. "precision",
  173. "optimizer",
  174. "io_binding",
  175. "model_name",
  176. "inputs",
  177. "threads",
  178. "batch_size",
  179. "sequence_length",
  180. "custom_layer_num",
  181. "datetime",
  182. "test_times",
  183. "QPS",
  184. "average_latency_ms",
  185. "latency_variance",
  186. "latency_90_percentile",
  187. "latency_95_percentile",
  188. "latency_99_percentile",
  189. ]
  190. csv_writer = csv.DictWriter(csv_file, fieldnames=column_names)
  191. csv_writer.writeheader()
  192. for result in results:
  193. csv_writer.writerow(result)
  194. logger.info(f"Detail results are saved to csv file: {csv_filename}")
  195. def output_summary(results, csv_filename, args):
  196. with open(csv_filename, mode="a", newline="", encoding="ascii") as csv_file:
  197. header_names = [
  198. "model_name",
  199. "inputs",
  200. "custom_layer_num",
  201. "engine",
  202. "version",
  203. "providers",
  204. "device",
  205. "precision",
  206. "optimizer",
  207. "io_binding",
  208. "threads",
  209. ]
  210. data_names = []
  211. for batch_size in args.batch_sizes:
  212. if args.sequence_lengths == [""]:
  213. data_names.append(f"b{batch_size}")
  214. else:
  215. for sequence_length in args.sequence_lengths:
  216. data_names.append(f"b{batch_size}_s{sequence_length}")
  217. csv_writer = csv.DictWriter(csv_file, fieldnames=header_names + data_names)
  218. csv_writer.writeheader()
  219. for model_name in args.models:
  220. for input_count in [1, 2, 3]:
  221. for engine_name in args.engines:
  222. for io_binding in [True, False, ""]:
  223. for threads in args.num_threads:
  224. row = {}
  225. for result in results:
  226. if (
  227. result["model_name"] == model_name
  228. and result["inputs"] == input_count
  229. and result["engine"] == engine_name
  230. and result["io_binding"] == io_binding
  231. and result["threads"] == threads
  232. ):
  233. headers = {k: v for k, v in result.items() if k in header_names}
  234. if not row:
  235. row.update(headers)
  236. row.update(dict.fromkeys(data_names, ""))
  237. else:
  238. for k in header_names:
  239. assert row[k] == headers[k]
  240. b = result["batch_size"]
  241. s = result["sequence_length"]
  242. if s:
  243. row[f"b{b}_s{s}"] = result["average_latency_ms"]
  244. else:
  245. row[f"b{b}"] = result["average_latency_ms"]
  246. if row:
  247. csv_writer.writerow(row)
  248. logger.info(f"Summary results are saved to csv file: {csv_filename}")
  249. def output_fusion_statistics(model_fusion_statistics, csv_filename):
  250. with open(csv_filename, mode="a", newline="", encoding="ascii") as csv_file:
  251. column_names = [
  252. "model_filename",
  253. "datetime",
  254. "transformers",
  255. "torch",
  256. *list(next(iter(model_fusion_statistics.values())).keys()),
  257. ]
  258. csv_writer = csv.DictWriter(csv_file, fieldnames=column_names)
  259. csv_writer.writeheader()
  260. for key in model_fusion_statistics:
  261. model_fusion_statistics[key]["datetime"] = str(datetime.now())
  262. model_fusion_statistics[key]["transformers"] = transformers.__version__
  263. model_fusion_statistics[key]["torch"] = torch.__version__
  264. model_fusion_statistics[key]["model_filename"] = key
  265. csv_writer.writerow(model_fusion_statistics[key])
  266. logger.info(f"Fusion statistics is saved to csv file: {csv_filename}")
  267. def inference_ort(ort_session, ort_inputs, result_template, repeat_times, batch_size, warm_up_repeat=0):
  268. result = {}
  269. timeit.repeat(lambda: ort_session.run(None, ort_inputs), number=1, repeat=warm_up_repeat) # Dry run
  270. latency_list = timeit.repeat(lambda: ort_session.run(None, ort_inputs), number=1, repeat=repeat_times)
  271. result.update(result_template)
  272. result.update({"io_binding": False})
  273. result.update(get_latency_result(latency_list, batch_size))
  274. return result
  275. def inference_ort_with_io_binding(
  276. ort_session,
  277. ort_inputs,
  278. result_template,
  279. repeat_times,
  280. ort_output_names,
  281. ort_outputs,
  282. output_buffers,
  283. output_buffer_max_sizes,
  284. batch_size,
  285. device,
  286. data_type=numpy.longlong,
  287. warm_up_repeat=0,
  288. ):
  289. result = {}
  290. # Bind inputs and outputs to onnxruntime session
  291. io_binding = ort_session.io_binding()
  292. # Bind inputs to device
  293. for name in ort_inputs:
  294. np_input = torch.from_numpy(ort_inputs[name]).to(device)
  295. input_type = IO_BINDING_DATA_TYPE_MAP.get(str(ort_inputs[name].dtype), data_type)
  296. io_binding.bind_input(
  297. name,
  298. np_input.device.type,
  299. 0,
  300. input_type,
  301. np_input.shape,
  302. np_input.data_ptr(),
  303. )
  304. # Bind outputs buffers with the sizes needed if not allocated already
  305. if len(output_buffers) == 0:
  306. allocateOutputBuffers(output_buffers, output_buffer_max_sizes, device)
  307. for i, ort_output_name in enumerate(ort_output_names):
  308. io_binding.bind_output(
  309. ort_output_name,
  310. output_buffers[i].device.type,
  311. 0,
  312. numpy.float32,
  313. ort_outputs[i].shape,
  314. output_buffers[i].data_ptr(),
  315. )
  316. timeit.repeat(
  317. lambda: ort_session.run_with_iobinding(io_binding),
  318. number=1,
  319. repeat=warm_up_repeat,
  320. ) # Dry run
  321. latency_list = timeit.repeat(
  322. lambda: ort_session.run_with_iobinding(io_binding),
  323. number=1,
  324. repeat=repeat_times,
  325. )
  326. result.update(result_template)
  327. result.update({"io_binding": True})
  328. result.update(get_latency_result(latency_list, batch_size))
  329. return result
  330. def allocateOutputBuffers(output_buffers, output_buffer_max_sizes, device): # noqa: N802
  331. # Allocate output tensors with the largest test size needed. So the allocated memory can be reused
  332. # for each test run.
  333. for i in output_buffer_max_sizes:
  334. output_buffers.append(torch.empty(i, dtype=torch.float32, device=device))
  335. def set_random_seed(seed=123):
  336. """Set random seed manually to get deterministic results"""
  337. random.seed(seed)
  338. numpy.random.seed(seed)
  339. torch.manual_seed(seed)
  340. torch.cuda.manual_seed(seed)
  341. torch.cuda.manual_seed_all(seed)
  342. # torch.backends.cudnn.enabled = False
  343. # torch.backends.cudnn.benchmark = False
  344. # torch.backends.cudnn.deterministic = True
  345. def get_gpu_info() -> list[dict[str, Any]] | None:
  346. from py3nvml.py3nvml import ( # noqa: PLC0415
  347. NVMLError,
  348. nvmlDeviceGetCount,
  349. nvmlDeviceGetHandleByIndex,
  350. nvmlDeviceGetMemoryInfo,
  351. nvmlDeviceGetName,
  352. nvmlInit,
  353. nvmlShutdown,
  354. )
  355. try:
  356. nvmlInit()
  357. result = []
  358. device_count = nvmlDeviceGetCount()
  359. if not isinstance(device_count, int):
  360. return None
  361. for i in range(device_count):
  362. info = nvmlDeviceGetMemoryInfo(nvmlDeviceGetHandleByIndex(i))
  363. if isinstance(info, str):
  364. return None
  365. result.append(
  366. {
  367. "id": i,
  368. "name": nvmlDeviceGetName(nvmlDeviceGetHandleByIndex(i)),
  369. "total": info.total,
  370. "free": info.free,
  371. "used": info.used,
  372. }
  373. )
  374. nvmlShutdown()
  375. return result
  376. except NVMLError as error:
  377. print("Error fetching GPU information using nvml: %s", error)
  378. return None
  379. class MemoryMonitor(ABC):
  380. def __init__(self, keep_measuring=True):
  381. self.keep_measuring = keep_measuring
  382. def measure_cpu_usage(self):
  383. import psutil # noqa: PLC0415
  384. max_usage = 0
  385. while True:
  386. max_usage = max(max_usage, psutil.Process(os.getpid()).memory_info().rss / 1024**2)
  387. sleep(0.005) # 5ms
  388. if not self.keep_measuring:
  389. break
  390. return max_usage
  391. @abstractmethod
  392. def measure_gpu_usage(self) -> list[dict[str, Any]] | None:
  393. raise NotImplementedError()
  394. class CudaMemoryMonitor(MemoryMonitor):
  395. def __init__(self, keep_measuring=True):
  396. super().__init__(keep_measuring)
  397. def measure_gpu_usage(self) -> list[dict[str, Any]] | None:
  398. from py3nvml.py3nvml import ( # noqa: PLC0415
  399. NVMLError,
  400. nvmlDeviceGetCount,
  401. nvmlDeviceGetHandleByIndex,
  402. nvmlDeviceGetMemoryInfo,
  403. nvmlDeviceGetName,
  404. nvmlInit,
  405. nvmlShutdown,
  406. )
  407. max_gpu_usage = []
  408. gpu_name = []
  409. try:
  410. nvmlInit()
  411. device_count = nvmlDeviceGetCount()
  412. if not isinstance(device_count, int):
  413. logger.error(f"nvmlDeviceGetCount result is not integer: {device_count}")
  414. return None
  415. max_gpu_usage = [0 for i in range(device_count)]
  416. gpu_name = [nvmlDeviceGetName(nvmlDeviceGetHandleByIndex(i)) for i in range(device_count)]
  417. while True:
  418. for i in range(device_count):
  419. info = nvmlDeviceGetMemoryInfo(nvmlDeviceGetHandleByIndex(i))
  420. if isinstance(info, str):
  421. logger.error(f"nvmlDeviceGetMemoryInfo returns str: {info}")
  422. return None
  423. max_gpu_usage[i] = max(max_gpu_usage[i], info.used / 1024**2)
  424. sleep(0.005) # 5ms
  425. if not self.keep_measuring:
  426. break
  427. nvmlShutdown()
  428. return [
  429. {
  430. "device_id": i,
  431. "name": gpu_name[i],
  432. "max_used_MB": max_gpu_usage[i],
  433. }
  434. for i in range(device_count)
  435. ]
  436. except NVMLError as error:
  437. logger.error("Error fetching GPU information using nvml: %s", error)
  438. return None
  439. class RocmMemoryMonitor(MemoryMonitor):
  440. def __init__(self, keep_measuring=True):
  441. super().__init__(keep_measuring)
  442. rocm_smi_path = "/opt/rocm/libexec/rocm_smi"
  443. if os.path.exists(rocm_smi_path):
  444. if rocm_smi_path not in sys.path:
  445. sys.path.append(rocm_smi_path)
  446. try:
  447. import rocm_smi # noqa: PLC0415
  448. self.rocm_smi = rocm_smi
  449. self.rocm_smi.initializeRsmi()
  450. except ImportError:
  451. self.rocm_smi = None
  452. def get_used_memory(self, dev):
  453. if self.rocm_smi is None:
  454. return -1
  455. return self.rocm_smi.getMemInfo(dev, "VRAM")[0] / 1024 / 1024
  456. def measure_gpu_usage(self):
  457. if self.rocm_smi is None:
  458. return None
  459. device_count = len(self.rocm_smi.listDevices()) if self.rocm_smi is not None else 0
  460. max_gpu_usage = [0 for i in range(device_count)]
  461. gpu_name = [f"GPU{i}" for i in range(device_count)]
  462. while True:
  463. for i in range(device_count):
  464. max_gpu_usage[i] = max(max_gpu_usage[i], self.get_used_memory(i))
  465. time.sleep(0.005) # 5ms
  466. if not self.keep_measuring:
  467. break
  468. return [
  469. {
  470. "device_id": i,
  471. "name": gpu_name[i],
  472. "max_used_MB": max_gpu_usage[i],
  473. }
  474. for i in range(device_count)
  475. ]
  476. def measure_memory(is_gpu, func, monitor_type="cuda", start_memory=None):
  477. memory_monitor_type = None
  478. if monitor_type == "rocm":
  479. memory_monitor_type = RocmMemoryMonitor
  480. else:
  481. memory_monitor_type = CudaMemoryMonitor
  482. monitor = memory_monitor_type(False)
  483. if is_gpu:
  484. if start_memory is not None:
  485. memory_before_test = start_memory
  486. else:
  487. memory_before_test = monitor.measure_gpu_usage()
  488. if memory_before_test is None:
  489. return None
  490. if func is None:
  491. return memory_before_test
  492. with ThreadPoolExecutor() as executor:
  493. monitor = memory_monitor_type()
  494. mem_thread = executor.submit(monitor.measure_gpu_usage)
  495. try:
  496. fn_thread = executor.submit(func)
  497. _ = fn_thread.result()
  498. finally:
  499. monitor.keep_measuring = False
  500. max_usage = mem_thread.result()
  501. if max_usage is None:
  502. return None
  503. logger.info(f"GPU memory usage: before={memory_before_test} peak={max_usage}")
  504. if len(memory_before_test) >= 1 and len(max_usage) >= 1 and len(memory_before_test) == len(max_usage):
  505. # When there are multiple GPUs, we will check the one with maximum usage.
  506. max_used = 0
  507. for i, memory_before in enumerate(memory_before_test):
  508. before = memory_before["max_used_MB"]
  509. after = max_usage[i]["max_used_MB"]
  510. used = after - before
  511. max_used = max(max_used, used)
  512. return max_used
  513. return None
  514. # CPU memory
  515. if start_memory is not None:
  516. memory_before_test = start_memory
  517. else:
  518. memory_before_test = monitor.measure_cpu_usage()
  519. if func is None:
  520. return memory_before_test
  521. with ThreadPoolExecutor() as executor:
  522. monitor = memory_monitor_type()
  523. mem_thread = executor.submit(monitor.measure_cpu_usage)
  524. try:
  525. fn_thread = executor.submit(func)
  526. _ = fn_thread.result()
  527. finally:
  528. monitor.keep_measuring = False
  529. max_usage = mem_thread.result()
  530. logger.info(f"CPU memory usage: before={memory_before_test:.1f} MB, peak={max_usage:.1f} MB")
  531. return max_usage - memory_before_test
  532. def get_ort_environment_variables():
  533. # Environment variables might impact ORT performance on transformer models. Note that they are for testing only.
  534. env_names = [
  535. "ORT_DISABLE_FUSED_ATTENTION",
  536. "ORT_ENABLE_FUSED_CAUSAL_ATTENTION",
  537. "ORT_DISABLE_FUSED_CROSS_ATTENTION",
  538. "ORT_DISABLE_TRT_FLASH_ATTENTION",
  539. "ORT_DISABLE_MEMORY_EFFICIENT_ATTENTION",
  540. "ORT_TRANSFORMER_OPTIONS",
  541. "ORT_CUDA_GEMM_OPTIONS",
  542. ]
  543. env = ""
  544. for name in env_names:
  545. value = os.getenv(name)
  546. if value is None:
  547. continue
  548. if env:
  549. env += ","
  550. env += f"{name}={value}"
  551. return env