# benchmark_gpt2.py (15 KB)

  1. # -------------------------------------------------------------------------
  2. # Copyright (c) Microsoft Corporation. All rights reserved.
  3. # Licensed under the MIT License. See License.txt in the project root for
  4. # license information.
  5. # --------------------------------------------------------------------------
  6. # This script benchmarks gpt2 model with past state.
  7. # For gpt2 model without past state, use benchmark.py to measure performance.
  8. import argparse
  9. import csv
  10. import logging
  11. import os
  12. from datetime import datetime
  13. import psutil
  14. import torch
  15. from benchmark_helper import (
  16. Precision,
  17. create_onnxruntime_session,
  18. get_ort_environment_variables,
  19. prepare_environment,
  20. setup_logger,
  21. )
  22. from gpt2_helper import DEFAULT_TOLERANCE, MODEL_CLASSES, PRETRAINED_GPT2_MODELS, Gpt2Helper
  23. from packaging import version
  24. from quantize_helper import QuantizeHelper
  25. from transformers import AutoConfig
  26. from transformers import __version__ as transformers_version
  27. logger = logging.getLogger("")
  28. def parse_arguments(argv=None):
  29. parser = argparse.ArgumentParser()
  30. parser.add_argument(
  31. "-m",
  32. "--model_name_or_path",
  33. required=True,
  34. type=str,
  35. help="Model path, or pretrained model name selected in the list: " + ", ".join(PRETRAINED_GPT2_MODELS),
  36. )
  37. parser.add_argument(
  38. "--model_class",
  39. required=False,
  40. type=str,
  41. default="GPT2LMHeadModel",
  42. choices=list(MODEL_CLASSES.keys()),
  43. help="Model type selected in the list: " + ", ".join(MODEL_CLASSES.keys()),
  44. )
  45. parser.add_argument(
  46. "--cache_dir",
  47. required=False,
  48. type=str,
  49. default=os.path.join(".", "cache_models"),
  50. help="Directory to cache pre-trained models",
  51. )
  52. parser.add_argument(
  53. "--onnx_dir",
  54. required=False,
  55. type=str,
  56. default=os.path.join(".", "onnx_models"),
  57. help="Directory to store onnx models",
  58. )
  59. parser.add_argument(
  60. "--test_times",
  61. required=False,
  62. default=100,
  63. type=int,
  64. help="Number of repeat times to get average inference latency.",
  65. )
  66. parser.add_argument(
  67. "-v",
  68. "--validate_onnx",
  69. required=False,
  70. action="store_true",
  71. help="Validate ONNX model",
  72. )
  73. parser.add_argument(
  74. "-o",
  75. "--optimize_onnx",
  76. required=False,
  77. action="store_true",
  78. help="Use optimizer.py to optimize onnx model",
  79. )
  80. parser.set_defaults(optimize_onnx=False)
  81. parser.add_argument(
  82. "--stage",
  83. type=int,
  84. default=0,
  85. required=False,
  86. choices=[0, 1, 2],
  87. help="Stage in generation: 1 (initial decoder), 2 (decoder), 0 (both). "
  88. "1 - decode the first token when past_sequence_length is zero; "
  89. "2 - decode the remaining tokens when past_sequence_length is not zero; "
  90. "0 - one onnx model for both stages 1 and 2. "
  91. "Note that we will optimize 1 and 2 differently for best performance.",
  92. )
  93. parser.add_argument("--use_gpu", required=False, action="store_true", help="use GPU for inference")
  94. parser.set_defaults(use_gpu=False)
  95. parser.add_argument(
  96. "-p",
  97. "--precision",
  98. type=Precision,
  99. default=Precision.FLOAT32,
  100. choices=list(Precision),
  101. help="Precision of model to run. fp32 for full precision, fp16 for half precision, and int8 for quantization",
  102. )
  103. parser.add_argument("--torchscript", required=False, action="store_true", help="use Torchscript")
  104. parser.set_defaults(torchscript=False)
  105. parser.add_argument("-b", "--batch_sizes", nargs="+", type=int, default=[1], help="batch size")
  106. parser.add_argument(
  107. "--sequence_lengths",
  108. nargs="+",
  109. type=int,
  110. default=[1],
  111. help="sequence lengths (excluding past)",
  112. )
  113. parser.add_argument(
  114. "-s",
  115. "--past_sequence_lengths",
  116. nargs="+",
  117. type=int,
  118. default=[8, 16, 32, 64, 128, 256],
  119. help="past sequence lengths",
  120. )
  121. parser.add_argument(
  122. "-r",
  123. "--result_csv",
  124. required=False,
  125. default=None,
  126. help="CSV file for saving summary results.",
  127. )
  128. parser.add_argument("--thread_num", required=False, type=int, default=-1, help="Threads to use")
  129. parser.add_argument("--include_copy_output_latency", required=False, action="store_true")
  130. parser.set_defaults(include_copy_output_latency=False)
  131. parser.add_argument("--verbose", required=False, action="store_true")
  132. parser.set_defaults(verbose=False)
  133. parser.add_argument("--output_torch_latency", required=False, action="store_true")
  134. parser.set_defaults(output_torch_latency=False)
  135. parser.add_argument("--disable_io_binding", required=False, action="store_true")
  136. parser.set_defaults(disable_io_binding=False)
  137. args = parser.parse_args(argv)
  138. return args
  139. def main(args):
  140. if version.parse(transformers_version) < version.parse(
  141. "3.1.0"
  142. ): # past_key_values name does not exist in 3.0.2 or older
  143. raise RuntimeError("This tool requires transformers 3.1.0 or later.")
  144. logger.info(f"Arguments:{args}")
  145. if args.precision == Precision.FLOAT16:
  146. assert args.optimize_onnx and args.use_gpu, "fp16 requires --optimize_onnx --use_gpu"
  147. if args.precision == Precision.INT8:
  148. assert not args.use_gpu, "quantization only supports CPU"
  149. if args.stage == 1:
  150. assert args.past_sequence_lengths == [0], "past_sequence_lengths shall be 0 for stage==1 (init decoder)"
  151. torch.set_num_threads(psutil.cpu_count(logical=True) if args.thread_num <= 0 else args.thread_num)
  152. print(torch.__config__.parallel_info())
  153. cache_dir = args.cache_dir
  154. output_dir = args.onnx_dir
  155. prepare_environment(cache_dir, output_dir, args.use_gpu)
  156. model_class = MODEL_CLASSES[args.model_class][0]
  157. gpt2helper = Gpt2Helper
  158. config = AutoConfig.from_pretrained(args.model_name_or_path, torchscript=args.torchscript, cache_dir=cache_dir)
  159. model = model_class.from_pretrained(args.model_name_or_path, config=config, cache_dir=cache_dir)
  160. # This script does not support float16 for PyTorch.
  161. # if args.float16:
  162. # model.half()
  163. device = torch.device("cuda:0" if args.use_gpu else "cpu")
  164. model.to(device)
  165. use_external_data_format = config.n_layer > 24 # TODO: find a way to check model size > 2GB
  166. onnx_model_paths = gpt2helper.get_onnx_paths(
  167. output_dir,
  168. args.model_name_or_path,
  169. args.model_class,
  170. has_past=True,
  171. new_folder=use_external_data_format,
  172. )
  173. onnx_model_path = onnx_model_paths["raw"]
  174. use_padding = MODEL_CLASSES[args.model_class][2]
  175. gpt2helper.export_onnx(
  176. model,
  177. device,
  178. onnx_model_path,
  179. args.verbose,
  180. use_external_data_format,
  181. has_position_ids=use_padding,
  182. has_attention_mask=use_padding,
  183. )
  184. if args.optimize_onnx or args.precision != Precision.FLOAT32:
  185. onnx_model_path = onnx_model_paths[str(args.precision) if args.precision != Precision.INT8 else "fp32"]
  186. gpt2helper.optimize_onnx(
  187. onnx_model_paths["raw"],
  188. onnx_model_path,
  189. args.precision == Precision.FLOAT16,
  190. model.config.num_attention_heads,
  191. model.config.hidden_size,
  192. use_external_data_format,
  193. auto_mixed_precision=True,
  194. stage=args.stage,
  195. )
  196. if args.precision == Precision.INT8:
  197. logger.info("quantizing model...")
  198. QuantizeHelper.quantize_onnx_model(onnx_model_path, onnx_model_paths["int8"], use_external_data_format)
  199. model = QuantizeHelper.quantize_torch_model(model)
  200. logger.info("finished quantizing model")
  201. onnx_model_path = onnx_model_paths["int8"]
  202. if args.torchscript:
  203. model = gpt2helper.torchscript(
  204. model,
  205. config,
  206. device,
  207. has_position_ids=use_padding,
  208. has_attention_mask=use_padding,
  209. )
  210. session = create_onnxruntime_session(
  211. onnx_model_path,
  212. args.use_gpu,
  213. enable_all_optimization=False,
  214. num_threads=args.thread_num,
  215. verbose=args.verbose,
  216. )
  217. if session is None:
  218. return
  219. # Allocate output buffers for IO Binding
  220. max_output_shapes = gpt2helper.get_output_shapes(
  221. max(args.batch_sizes),
  222. max(args.past_sequence_lengths),
  223. max(args.sequence_lengths),
  224. config,
  225. args.model_class,
  226. )
  227. output_buffers = gpt2helper.get_output_buffers(max_output_shapes, device, args.precision == Precision.FLOAT16)
  228. csv_filename = args.result_csv or "benchmark_result_{}.csv".format(datetime.now().strftime("%Y%m%d-%H%M%S"))
  229. with open(csv_filename, mode="a", newline="") as csv_file:
  230. column_names = [
  231. "model_name",
  232. "model_class",
  233. "stage",
  234. "environment_variables",
  235. "gpu",
  236. "precision",
  237. "optimizer",
  238. "torchscript",
  239. "batch_size",
  240. "sequence_length",
  241. "past_sequence_length",
  242. "disable_io_binding",
  243. "torch_latency",
  244. "onnxruntime_latency",
  245. ]
  246. csv_writer = csv.DictWriter(csv_file, fieldnames=column_names)
  247. csv_writer.writeheader()
  248. for batch_size in args.batch_sizes:
  249. for sequence_length in args.sequence_lengths:
  250. for past_sequence_length in args.past_sequence_lengths:
  251. assert batch_size > 0 and sequence_length > 0 and past_sequence_length >= 0
  252. logger.debug(
  253. "Running test for batch_size=%d sequence_length=%d past_sequence_length=%d ...",
  254. batch_size,
  255. sequence_length,
  256. past_sequence_length,
  257. )
  258. dummy_inputs = gpt2helper.get_dummy_inputs(
  259. batch_size,
  260. past_sequence_length,
  261. sequence_length,
  262. config.num_attention_heads,
  263. config.hidden_size,
  264. config.n_layer,
  265. config.vocab_size,
  266. device,
  267. float16=(args.precision == Precision.FLOAT16),
  268. has_position_ids=use_padding,
  269. has_attention_mask=use_padding,
  270. )
  271. output_shapes = gpt2helper.get_output_shapes(
  272. batch_size,
  273. past_sequence_length,
  274. sequence_length,
  275. config,
  276. args.model_class,
  277. )
  278. try:
  279. if args.validate_onnx or args.output_torch_latency:
  280. outputs, torch_latency = gpt2helper.pytorch_inference(model, dummy_inputs, args.test_times)
  281. # Dump Torch output shape
  282. for i, value in enumerate(outputs):
  283. if isinstance(value, tuple):
  284. logger.debug(
  285. f"torch output {i} is tuple of size {len(value)}, shape {value[0].shape}"
  286. )
  287. else:
  288. logger.debug(f"torch output {i} shape {value.shape}")
  289. else:
  290. outputs = None
  291. torch_latency = None
  292. if args.disable_io_binding:
  293. ort_outputs, ort_latency = gpt2helper.onnxruntime_inference(
  294. session, dummy_inputs, args.test_times
  295. )
  296. else:
  297. ort_outputs, ort_latency = gpt2helper.onnxruntime_inference_with_binded_io(
  298. session,
  299. dummy_inputs,
  300. output_buffers,
  301. output_shapes,
  302. args.test_times,
  303. return_numpy=False,
  304. include_copy_output_latency=args.include_copy_output_latency,
  305. )
  306. if args.validate_onnx:
  307. copy_outputs = ort_outputs
  308. if not args.disable_io_binding:
  309. # Results of IO binding might be in GPU. Copy outputs to CPU for comparison.
  310. copy_outputs = []
  311. for output in ort_outputs:
  312. copy_outputs.append(output.cpu().numpy())
  313. if gpt2helper.compare_outputs(
  314. outputs,
  315. copy_outputs,
  316. model_class=args.model_class,
  317. rtol=DEFAULT_TOLERANCE[args.precision],
  318. atol=DEFAULT_TOLERANCE[args.precision],
  319. ):
  320. logger.info(
  321. f"Pytorch and ONNX Runtime outputs are all close (tolerance={DEFAULT_TOLERANCE[args.precision]})."
  322. )
  323. logger.info(
  324. "batch_size=%d, sequence_length=%d, past_sequence_length=%d, onnxruntime_latency=%.2f %s %s",
  325. batch_size,
  326. sequence_length,
  327. past_sequence_length,
  328. ort_latency,
  329. "(disable_io_binding)" if args.disable_io_binding else "",
  330. ", torch_latency={torch_latency}" if torch_latency else "",
  331. )
  332. row = {
  333. "model_name": args.model_name_or_path,
  334. "model_class": args.model_class,
  335. "stage": args.stage,
  336. "environment_variables": get_ort_environment_variables(),
  337. "gpu": args.use_gpu,
  338. "precision": args.precision,
  339. "optimizer": args.optimize_onnx,
  340. "torchscript": args.torchscript,
  341. "batch_size": batch_size,
  342. "sequence_length": sequence_length,
  343. "past_sequence_length": past_sequence_length,
  344. "disable_io_binding": args.disable_io_binding,
  345. "torch_latency": f"{torch_latency:.2f}" if torch_latency else "None",
  346. "onnxruntime_latency": f"{ort_latency:.2f}",
  347. }
  348. csv_writer.writerow(row)
  349. except Exception:
  350. logger.error("Exception", exc_info=True) # noqa: G201
  351. return None
  352. logger.info(f"Results are saved to file {csv_filename}")
  353. return csv_filename
  354. if __name__ == "__main__":
  355. args = parse_arguments()
  356. setup_logger(args.verbose)
  357. main(args)