benchmark.py 22 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610
  1. # -------------------------------------------------------------------------
  2. # Copyright (c) Microsoft Corporation. All rights reserved.
  3. # Licensed under the MIT License. See License.txt in the project root for
  4. # license information.
  5. # --------------------------------------------------------------------------
  6. import argparse
  7. import ast
  8. import datetime
  9. import gc
  10. import logging
  11. import os
  12. import sys
  13. import time
  14. import numpy as np
  15. import psutil
  16. import torch
  17. import whisper
  18. from benchmark_helper import measure_memory, setup_logger
  19. from onnxruntime_extensions import get_library_path
  20. from optimum.onnxruntime import ORTModelForSpeechSeq2Seq
  21. from torch.profiler import ProfilerActivity, profile, record_function
  22. from tqdm import trange
  23. from transformers import AutoModelForSpeechSeq2Seq, WhisperConfig, WhisperProcessor
  24. import onnxruntime as ort
  25. logger = logging.getLogger(__name__)
  26. def get_inputs(args: argparse.Namespace):
  27. if args.benchmark_type not in {"hf-pt-eager", "hf-pt-compile", "hf-ort", "ort"}:
  28. raise Exception("Unable to auto-detect inputs for provided model")
  29. def load_via_ffmpeg():
  30. audio = whisper.load_audio(args.audio_path)
  31. audio = whisper.pad_or_trim(audio)
  32. return audio
  33. def load_via_numpy():
  34. with open(args.audio_path, "rb") as f:
  35. audio = np.asarray(list(f.read()), dtype=np.uint8)
  36. audio = np.array([audio])
  37. return audio
  38. inputs = {
  39. "max_length": args.max_length,
  40. "min_length": args.min_length,
  41. "num_beams": args.num_beams,
  42. "num_return_sequences": args.num_return_sequences,
  43. "length_penalty": args.length_penalty,
  44. "repetition_penalty": args.repetition_penalty,
  45. }
  46. if args.benchmark_type == "ort":
  47. # convert_to_onnx export or ONNX E2E solution created by Olive
  48. for k, v in inputs.items():
  49. inputs[k] = np.array([v], dtype=np.float32 if "penalty" in k else np.int32)
  50. if args.has_decoder_input_ids:
  51. inputs["decoder_input_ids"] = np.array([args.decoder_input_ids], dtype=np.int32)
  52. if args.has_logits_processor:
  53. inputs["logits_processor"] = np.array([args.logits_processor], dtype=np.int32)
  54. if args.has_temperature:
  55. inputs["temperature"] = np.array([args.temperature], dtype=np.float32)
  56. # Measure time taken to load audio file
  57. logger.info(f"Load audio: {args.audio_path}")
  58. load_audio_fn = lambda onnx_e2e: load_via_numpy() if onnx_e2e else load_via_ffmpeg() # noqa: E731
  59. time_fn(args, load_audio_fn, args.has_audio_stream)
  60. audio_data = load_audio_fn(args.has_audio_stream)
  61. if args.has_audio_stream:
  62. # ONNX E2E solution created by Olive
  63. inputs["audio_stream"] = audio_data
  64. return inputs
  65. # Measure time taken to get input features
  66. logger.info("Feature extraction: ")
  67. return_type = "np" if args.benchmark_type == "ort" else "pt"
  68. processor_fn = lambda audio: args.processor.feature_extractor( # noqa: E731
  69. [audio], return_tensors=return_type, sampling_rate=args.sampling_rate
  70. ).input_features
  71. time_fn(args, processor_fn, audio_data)
  72. input_features = processor_fn(audio_data)
  73. if args.benchmark_type == "ort":
  74. # convert_to_onnx export
  75. inputs["input_features"] = input_features
  76. return inputs
  77. inputs["inputs"] = input_features.to(
  78. dtype=torch.float16 if args.use_fp16 else torch.float32, device=args.target_device
  79. )
  80. inputs["no_repeat_ngram_size"] = args.no_repeat_ngram_size
  81. inputs["early_stopping"] = True
  82. inputs["use_cache"] = True
  83. if args.decoder_input_ids:
  84. inputs["forced_decoder_ids"] = args.decoder_input_ids
  85. return inputs
  86. def get_model(args: argparse.Namespace):
  87. model, sess_options = None, None
  88. start_time, end_time = None, None
  89. # There are multiple sources that the model could come from:
  90. # 1) Benchmark Whisper from Hugging Face
  91. # 2) Benchmark Whisper ONNX model from Optimum export (without pre/post processing)
  92. # 3) Benchmark Whisper ONNX E2E model from Olive (with pre/post processing)
  93. if args.benchmark_type in {"hf-pt-eager", "hf-pt-compile"}:
  94. source = args.hf_pt_model_path if args.hf_pt_model_path else args.model_name
  95. start_time = time.time()
  96. model = AutoModelForSpeechSeq2Seq.from_pretrained(
  97. source,
  98. torch_dtype=torch.float16 if args.use_fp16 else torch.float32,
  99. use_cache=True,
  100. ).to(args.target_device)
  101. end_time = time.time()
  102. if args.benchmark_type == "hf-pt-compile":
  103. model = torch.compile(model)
  104. elif args.benchmark_type in {"hf-ort", "ort"}:
  105. sess_options = ort.SessionOptions()
  106. sess_options.enable_profiling = args.profile
  107. sess_options.register_custom_ops_library(get_library_path())
  108. if args.verbose:
  109. sess_options.log_verbosity_level = 1
  110. sess_options.log_severity_level = 1
  111. if args.tune:
  112. ort.set_default_logger_severity(0)
  113. ort.set_default_logger_verbosity(0)
  114. else:
  115. raise Exception(f"Cannot recognize {args.benchmark_type}")
  116. if args.benchmark_type == "hf-ort":
  117. # Optimum export
  118. provider = args.execution_provider[0] if type(args.execution_provider) is tuple else args.execution_provider
  119. provider_options = args.execution_provider[1] if type(args.execution_provider) is tuple else None
  120. start_time = time.time()
  121. model = ORTModelForSpeechSeq2Seq.from_pretrained(
  122. args.hf_ort_dir_path,
  123. provider=provider,
  124. provider_options=provider_options,
  125. session_options=sess_options,
  126. use_io_binding=True, # Avoid memory copy overhead
  127. )
  128. end_time = time.time()
  129. if args.benchmark_type == "ort":
  130. # convert_to_onnx.py export
  131. logger.info(f"Loading model from {args.ort_model_path}")
  132. start_time = time.time()
  133. model = ort.InferenceSession(
  134. args.ort_model_path,
  135. sess_options,
  136. providers=[args.execution_provider],
  137. )
  138. end_time = time.time()
  139. logger.info(f"Loaded model in {end_time - start_time} s")
  140. return model
  141. def time_fn(args, fn, inputs):
  142. warmup_inputs = inputs[0] if type(inputs) is tuple else inputs
  143. benchmark_inputs = inputs[1] if type(inputs) is tuple else inputs
  144. torch_device = torch.device(args.target_device)
  145. # Warm up
  146. warmup_range = (
  147. range(args.warmup_runs)
  148. if args.benchmark_type == "ort"
  149. else trange(args.warmup_runs, file=sys.stdout, desc="Warm up")
  150. )
  151. if args.verbose:
  152. outputs = fn(warmup_inputs)
  153. logger.info(outputs)
  154. for _ in warmup_range:
  155. fn(warmup_inputs)
  156. # Benchmark
  157. if args.device != "cpu":
  158. torch.cuda.synchronize(torch_device)
  159. start_time = time.time()
  160. bench_range = (
  161. range(args.num_runs)
  162. if args.benchmark_type == "ort"
  163. else trange(args.num_runs, file=sys.stdout, desc="Benchmark")
  164. )
  165. for _ in bench_range:
  166. fn(benchmark_inputs)
  167. if args.device != "cpu":
  168. torch.cuda.synchronize(torch_device)
  169. end_time = time.time()
  170. # Newline print after trange in order to print metrics on new lines without progress bar on same line
  171. if args.benchmark_type != "ort":
  172. logger.info("")
  173. batch_size = 1
  174. latency = (end_time - start_time) / args.num_runs
  175. throughput = batch_size / latency
  176. logger.info(f"Latency: {latency} s")
  177. logger.info(f"Throughput: {throughput} qps")
  178. return
  179. def profile_fn(args, fn, inputs, inputs_type):
  180. # Filename prefix format:
  181. # "<benchmark-type>-<precision>-<device>_<inference-step>_<inputs-type>_<current-time>"
  182. prefix = f"{args.benchmark_type.lower()}-{args.precision}-{args.device}_{fn.__name__.replace('_', '-')}_{inputs_type}_{datetime.datetime.now():%Y-%m-%d_%H:%M:%S}"
  183. filename = None
  184. if args.benchmark_type in {"hf-pt-eager", "hf-pt-compile"}:
  185. # Profile PyTorch kernels
  186. with profile( # noqa: SIM117
  187. activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA], record_shapes=True, profile_memory=True
  188. ) as prof:
  189. with record_function("model_inference"):
  190. fn(inputs)
  191. prof_data = prof.key_averages(group_by_stack_n=5).table(sort_by=args.pt_filter_by, row_limit=args.pt_num_rows)
  192. filename = os.path.join(args.log_folder, f"{prefix}.log")
  193. with open(filename, "w") as f:
  194. f.write(prof_data)
  195. else:
  196. # Profile ORT kernels
  197. fn(inputs)
  198. # Set new log name for ORT profile log generated
  199. filename = f"{prefix}.json"
  200. return filename
  201. def measure_fn(args, fn, inputs):
  202. # Measure CPU usage
  203. pid = os.getpid()
  204. process = psutil.Process(pid)
  205. process.cpu_percent(interval=0.1)
  206. fn(inputs)
  207. logger.info(f"CPU usage: {process.cpu_percent(interval=None)}%")
  208. # Measure memory usage
  209. gc.collect()
  210. torch.cuda.empty_cache()
  211. measure_memory(is_gpu=(args.device != "cpu"), func=lambda: fn(inputs), monitor_type=args.monitor_type)
  212. # Flush output so memory usage is printed
  213. sys.stdout.flush()
  214. def run_hf_inference(args, inputs, model):
  215. # Inference steps to measure
  216. def get_pred_ids(inputs):
  217. # Inference pass with predicted token ids generation
  218. predicted_ids = model.generate(**inputs)
  219. return predicted_ids
  220. def gen_and_dec(inputs):
  221. # Inference pass with generation and decoding
  222. predicted_ids = get_pred_ids(inputs)
  223. transcription = []
  224. for _ in range(args.num_return_sequences):
  225. transcription.append(args.processor.batch_decode(predicted_ids, skip_special_tokens=True)[0])
  226. return predicted_ids, transcription
  227. # Examples of other inference steps that can be measured:
  228. # To use, uncomment the function and assign it to `generate_fn`
  229. # def get_logits(inputs):
  230. # # Inference pass without decoding
  231. # outputs = model(**inputs)
  232. # return outputs
  233. generate_fn = gen_and_dec
  234. if args.benchmark_type == "hf-pt-compile":
  235. # Run forward pass once with each set of inputs to process through Dynamo
  236. generate_fn(inputs)
  237. if args.profile:
  238. new_logname = profile_fn(args, generate_fn, inputs, "gen-and-dec")
  239. if args.benchmark_type == "hf-ort":
  240. # Rename log files per model component and turn profiling off to stop appending to log
  241. new_prefix = new_logname[: -len(".json")]
  242. old_logname = model.encoder.session.end_profiling()
  243. new_logname = new_prefix + "-encoder.json"
  244. if os.path.isfile(old_logname):
  245. logger.warning(f"Renaming {old_logname} to {new_logname}")
  246. os.rename(old_logname, os.path.join(args.log_folder, new_logname))
  247. old_logname = model.decoder.session.end_profiling()
  248. new_logname = new_prefix + "-decoder.json"
  249. if os.path.isfile(old_logname):
  250. logger.warning(f"Renaming {old_logname} to {new_logname}")
  251. os.rename(old_logname, os.path.join(args.log_folder, new_logname))
  252. old_logname = model.decoder_with_past.session.end_profiling()
  253. new_logname = new_prefix + "-decoder-with-past.json"
  254. if os.path.isfile(old_logname):
  255. logger.warning(f"Renaming {old_logname} to {new_logname}")
  256. os.rename(old_logname, os.path.join(args.log_folder, new_logname))
  257. return
  258. # PyTorch evaluations
  259. logger.info("\nEvaluating PyTorch...")
  260. time_fn(args, generate_fn, inputs)
  261. predicted_ids, transcription = generate_fn(inputs)
  262. logger.info(f"Generated token length: {len(predicted_ids[0])} tokens")
  263. logger.info(f"Transcription: {transcription[0]}")
  264. measure_fn(args, generate_fn, inputs)
  265. def run_ort_inference(args, inputs, model):
  266. def prepare_ort_inputs(inputs, warmup=False):
  267. # Check that all model inputs will be provided
  268. model_inputs = {model_input.name for model_input in model.get_inputs()}
  269. user_inputs = set(inputs.keys())
  270. missing_inputs = model_inputs - user_inputs
  271. if len(missing_inputs):
  272. logger.error(f"The following model inputs are missing: {missing_inputs}")
  273. raise Exception("There are missing inputs to the model. Please add them and try again.")
  274. if warmup and args.tune:
  275. inputs["min_length"] = inputs["max_length"]
  276. # Remove unnecessary inputs from model inputs
  277. unnecessary_inputs = user_inputs - model_inputs
  278. if len(unnecessary_inputs):
  279. for unnecessary_input in unnecessary_inputs:
  280. logger.info(f"Removing unnecessary input '{unnecessary_input}' from user provided inputs")
  281. del inputs[unnecessary_input]
  282. # Add IO bindings for non-CPU execution providers
  283. if args.device != "cpu":
  284. io_binding = model.io_binding()
  285. for k, v in inputs.items():
  286. io_binding.bind_cpu_input(k, v)
  287. for output in model.get_outputs():
  288. io_binding.bind_output(output.name, device_type=args.device, device_id=args.device_id)
  289. return io_binding
  290. return inputs
  291. def with_io_binding(io_binding):
  292. # Inference pass with IO binding
  293. model.run_with_iobinding(io_binding)
  294. return io_binding
  295. def without_io_binding(inputs):
  296. # Inference pass without IO binding
  297. outputs = model.run(None, inputs)
  298. return outputs
  299. def handle_output(output):
  300. if args.eos_token_id in output:
  301. first_end = np.where(output == args.eos_token_id)[0][0]
  302. return output[: first_end + 1]
  303. return output
  304. generate_fn = with_io_binding if args.device != "cpu" else without_io_binding
  305. ort_inputs = prepare_ort_inputs(inputs)
  306. if args.profile:
  307. new_logname = profile_fn(args, generate_fn, ort_inputs, "e2e")
  308. # Turn profiling off to stop appending to log file
  309. old_logname = model.end_profiling()
  310. logger.warning(f"Renaming {old_logname} to {new_logname}")
  311. os.rename(old_logname, os.path.join(args.log_folder, new_logname))
  312. return
  313. # ORT evaluation
  314. logger.info("\nEvaluating ONNX Runtime...")
  315. ort_evaluate_inputs = ort_inputs
  316. if args.tune:
  317. ort_warmup_inputs = prepare_ort_inputs(inputs, warmup=True)
  318. ort_evaluate_inputs = (ort_warmup_inputs, ort_inputs)
  319. time_fn(args, generate_fn, ort_evaluate_inputs)
  320. ort_outputs = generate_fn(ort_inputs)
  321. if args.device != "cpu":
  322. ort_outputs = ort_outputs.copy_outputs_to_cpu()
  323. ort_outputs = ort_outputs[0]
  324. if args.has_audio_stream:
  325. # ONNX E2E model from Olive produces transcribed output
  326. logger.info(f"Transcription: {ort_outputs[0][0]}")
  327. else:
  328. # convert_to_onnx model produces generated ids
  329. actual_output = handle_output(ort_outputs[0][0])
  330. logger.info(f"Generated token length: {len(actual_output)} tokens")
  331. transcription = args.processor.batch_decode(ort_outputs[0], skip_special_tokens=True)[0]
  332. # print to stdout as the output for comparison
  333. print(f"{transcription}")
  334. measure_fn(args, generate_fn, ort_inputs)
  335. def run_inference(args, inputs, model):
  336. if args.benchmark_type in {"hf-pt-eager", "hf-pt-compile", "hf-ort"}:
  337. run_hf_inference(args, inputs, model)
  338. elif args.benchmark_type == "ort":
  339. run_ort_inference(args, inputs, model)
  340. else:
  341. raise Exception(f"Cannot recognize {args.benchmark_type}")
  342. def parse_args():
  343. parser = argparse.ArgumentParser()
  344. parser.add_argument(
  345. "-bt",
  346. "--benchmark-type",
  347. type=str,
  348. required=True,
  349. choices=["hf-pt-eager", "hf-pt-compile", "hf-ort", "ort"],
  350. )
  351. parser.add_argument(
  352. "-m",
  353. "--model-name",
  354. type=str,
  355. required=True,
  356. help="Hugging Face name of model (e.g. 'openai/whisper-large-v2')",
  357. )
  358. parser.add_argument(
  359. "-p",
  360. "--precision",
  361. type=str,
  362. required=True,
  363. default="fp32",
  364. choices=["int8", "fp16", "fp32"],
  365. help="Precision for model. For ONNX models, the model's precision should be set before running this script.",
  366. )
  367. parser.add_argument(
  368. "--hf-pt-model-path",
  369. type=str,
  370. default="",
  371. help="Path to directory containing all PyTorch files (e.g. tokenizer, PyTorch model)",
  372. )
  373. parser.add_argument(
  374. "--hf-ort-dir-path",
  375. type=str,
  376. default="",
  377. help="Path to directory containing all ONNX files (e.g. tokenizer, encoder, decoder, decoder_with_past)",
  378. )
  379. parser.add_argument(
  380. "--ort-model-path",
  381. type=str,
  382. default="",
  383. help="Path to ONNX model",
  384. )
  385. # Args for running and evaluating the model
  386. parser.add_argument("-a", "--audio-path", type=str, required=True, help="Path to audio file for E2E evaluation")
  387. parser.add_argument(
  388. "-d",
  389. "--device",
  390. type=str,
  391. default="cuda" if torch.cuda.is_available() else "cpu",
  392. choices=["cpu", "cuda", "rocm"],
  393. )
  394. parser.add_argument("-id", "--device-id", type=int, default=0)
  395. parser.add_argument("-w", "--warmup-runs", type=int, default=5)
  396. parser.add_argument("-n", "--num-runs", type=int, default=10)
  397. parser.add_argument("--seed", type=int, default=2)
  398. # Optional args:
  399. parser.add_argument("--sampling-rate", type=int, default=16000, help="Sampling rate for audio (in Hz)")
  400. # Args for decoding logic
  401. # Required args:
  402. parser.add_argument("--max-length", type=int, default=448)
  403. parser.add_argument("--min-length", type=int, default=0)
  404. parser.add_argument("--num-beams", type=int, default=1)
  405. parser.add_argument("--num-return-sequences", type=int, default=1)
  406. parser.add_argument("--length-penalty", type=float, default=1.0)
  407. parser.add_argument("--repetition-penalty", type=float, default=1.0)
  408. parser.add_argument("--no-repeat-ngram-size", type=int, default=3)
  409. # Optional args for E2E solution:
  410. parser.add_argument(
  411. "--decoder-input-ids",
  412. type=str,
  413. default="[]",
  414. help="The forced decoder ids for generation. Format is [start token, timestamp token, language token, task token]. Default is [start token]. See `decoder_input_ids` in https://github.com/microsoft/Olive/tree/main/examples/whisper for details.",
  415. )
  416. parser.add_argument(
  417. "--logits-processor",
  418. type=int,
  419. default=1,
  420. help="Whether to use timestamps logits processor or not (0 for false, 1 for true).",
  421. )
  422. parser.add_argument(
  423. "--temperature",
  424. type=float,
  425. default=1.0,
  426. help="Temperature value for generation.",
  427. )
  428. # Args for accessing detailed info
  429. parser.add_argument("--profile", default=False, action="store_true")
  430. parser.add_argument(
  431. "--pt-filter-by", type=str, default="self_cpu_time_total", help="What to filter PyTorch profiler by"
  432. )
  433. parser.add_argument("--pt-num-rows", type=int, default=1000, help="Number of rows for PyTorch profiler to display")
  434. parser.add_argument("--verbose", default=False, action="store_true")
  435. parser.add_argument("--log-folder", type=str, default=os.path.join("."), help="Folder to cache log files")
  436. parser.add_argument(
  437. "--tune",
  438. default=False,
  439. action="store_true",
  440. help="Only used by ROCm EP, enable TunableOp tuning to select fastest kernel",
  441. )
  442. args = parser.parse_args()
  443. # Set seed properties
  444. np.random.seed(args.seed)
  445. torch.manual_seed(args.seed)
  446. args.monitor_type = args.device
  447. # Set runtime properties
  448. if "ort" in args.benchmark_type:
  449. args.execution_provider = f"{args.device.upper()}ExecutionProvider"
  450. if args.execution_provider == "CUDAExecutionProvider":
  451. args.execution_provider = (args.execution_provider, {"device_id": args.device_id})
  452. elif args.execution_provider == "ROCMExecutionProvider":
  453. args.execution_provider = (
  454. args.execution_provider,
  455. {
  456. "device_id": args.device_id,
  457. "tunable_op_enable": 1,
  458. "tunable_op_tuning_enable": 1 if args.tune else 0,
  459. },
  460. )
  461. args.device = "cuda"
  462. # Check that model paths have been specified for any benchmarking with ORT
  463. if args.benchmark_type == "hf-ort":
  464. assert args.hf_ort_dir_path, "Please specify a path to `--hf-ort-dir-path`"
  465. if args.benchmark_type == "ort":
  466. assert args.ort_model_path, "Please specify a path to `--ort-model-path`"
  467. # Convert decoder_input_ids string to list of ids
  468. # (e.g. "[1, 50257]" for Hugging Face or "[50257]" for ORT)
  469. args.decoder_input_ids = ast.literal_eval(args.decoder_input_ids)
  470. return args
  471. def main():
  472. args = parse_args()
  473. setup_logger(args.verbose)
  474. logger.info(args.__dict__)
  475. torch.backends.cudnn.benchmark = True
  476. config = WhisperConfig.from_pretrained(args.model_name)
  477. processor = WhisperProcessor.from_pretrained(args.model_name)
  478. target_device = f"cuda:{args.device_id}" if args.device != "cpu" else args.device
  479. use_fp16 = args.precision == "fp16"
  480. setattr(args, "processor", processor) # noqa: B010
  481. setattr(args, "target_device", target_device) # noqa: B010
  482. setattr(args, "use_fp16", use_fp16) # noqa: B010
  483. setattr(args, "has_audio_stream", False) # noqa: B010
  484. setattr(args, "eos_token_id", config.eos_token_id) # noqa: B010
  485. logger.info(f"Forced decoder prompt ids: {args.decoder_input_ids}")
  486. # Measure cost to transcribe audio
  487. model = get_model(args)
  488. if args.benchmark_type == "ort":
  489. # Check for optional inputs that could have been added during export
  490. ort_model_inputs = {model_input.name for model_input in model.get_inputs()}
  491. args.has_audio_stream = "audio_stream" in ort_model_inputs
  492. setattr(args, "has_decoder_input_ids", "decoder_input_ids" in ort_model_inputs) # noqa: B010
  493. setattr(args, "has_logits_processor", "logits_processor" in ort_model_inputs) # noqa: B010
  494. setattr(args, "has_temperature", "temperature" in ort_model_inputs) # noqa: B010
  495. if args.decoder_input_ids == []:
  496. args.decoder_input_ids = [config.decoder_start_token_id]
  497. inputs = get_inputs(args)
  498. run_inference(args, inputs, model)
  499. if __name__ == "__main__":
  500. main()