benchmark.py 49 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677787980818283848586878889909192939495969798991001011021031041051061071081091101111121131141151161171181191201211221231241251261271281291301311321331341351361371381391401411421431441451461471481491501511521531541551561571581591601611621631641651661671681691701711721731741751761771781791801811821831841851861871881891901911921931941951961971981992002012022032042052062072082092102112122132142152162172182192202212222232242252262272282292302312322332342352362372382392402412422432442452462472482492502512522532542552562572582592602612622632642652662672682692702712722732742752762772782792802812822832842852862872882892902912922932942952962972982993003013023033043053063073083093103113123133143153163173183193203213223233243253263273283293303313323333343353363373383393403413423433443453463473483493503513523533543553563573583593603613623633643653663673683693703713723733743753763773783793803813823833843853863873883893903913923933943953963973983994004014024034044054064074084094104114124134144154164174184194204214224234244254264274284294304314324334344354364374384394404414424434444454464474484494504514524534544554564574584594604614624634644654664674684694704714724734744754764774784794804814824834844854864874884894904914924934944954964974984995005015025035045055065075085095105115125135145155165175185195205215225235245255265275285295305315325335345355365375385395405415425435445455465475485495505515525535545555565575585595605615625635645655665675685695705715725735745755765775785795805815825835845855865875885895905915925935945955965975985996006016026036046056066076086096106116126136146156166176186196206216226236246256266276286296306316326336346356366376386396406416426436446456466476486496506516526536546556566576586596606616626636646656666676686696706716726736746756766776786796806816826836846856866876886896906916926936946956966976986997007017027037047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156115711581159116011611162116311641165116611671168116911701171117211731174117511761177117811791180118111821183118411851186118711881189119011911192119311941195119611971198119912001201120212031204120512061207120812091210121112121213121412151216121712181219122012211222122312241225122612271228122912301231123212331234123512361237123812391240124112421243124412451246124712481249125012511252125312541255125612571258125912601261126212631264126512661267126812691270127112721273127412751276127712781279128012811282128312841285128612871288128912901291129212931294129512961297129812991300130113021303130413051306130713081309131013111312131313141315131613171318131913201321132213231324132513261327132813291330133113321333133413351336133713381339134013411342134313441345134613471348134913501351135213531354135513561357135813591360136113621363136413651366136713681369137013711372137313741375137613771378137913801381138213831384138513861387138813891390139113921393139413951396139713981399140014011402140314041405140614071408140914101411141214131414141514161417141814191420142114221423142414251426142714281429143014311432143314341435143614371438143914401441144214431444144514461447144814491450145114521453145414551456145714581459146014611462146314641465146614671468146914701471147214731474147514761477147814791480148114821483148414851486148714881489149014911492149314941495149614971498149915001501150215031504150515061507150815091510151115121513151415151516151715181519152015211522
  1. # -------------------------------------------------------------------------
  2. # Copyright (c) Microsoft Corporation. All rights reserved.
  3. # Licensed under the MIT License.
  4. # --------------------------------------------------------------------------
  5. import argparse
  6. import csv
  7. import os
  8. import statistics
  9. import sys
  10. import time
  11. from pathlib import Path
  12. import coloredlogs
  13. # import torch before onnxruntime so that onnxruntime uses the cuDNN in the torch package.
  14. import torch
  15. from benchmark_helper import measure_memory
  16. SD_MODELS = {
  17. "1.5": "runwayml/stable-diffusion-v1-5",
  18. "2.0": "stabilityai/stable-diffusion-2",
  19. "2.1": "stabilityai/stable-diffusion-2-1",
  20. "xl-1.0": "stabilityai/stable-diffusion-xl-refiner-1.0",
  21. "3.0M": "stabilityai/stable-diffusion-3-medium-diffusers",
  22. "3.5M": "stabilityai/stable-diffusion-3.5-medium",
  23. "3.5L": "stabilityai/stable-diffusion-3.5-large",
  24. "Flux.1S": "black-forest-labs/FLUX.1-schnell",
  25. "Flux.1D": "black-forest-labs/FLUX.1-dev",
  26. }
  27. PROVIDERS = {
  28. "cuda": "CUDAExecutionProvider",
  29. "rocm": "ROCMExecutionProvider",
  30. "migraphx": "MIGraphXExecutionProvider",
  31. "tensorrt": "TensorrtExecutionProvider",
  32. }
  33. def example_prompts():
  34. prompts = [
  35. "a photo of an astronaut riding a horse on mars",
  36. "cute grey cat with blue eyes, wearing a bowtie, acrylic painting",
  37. "a cute magical flying dog, fantasy art drawn by disney concept artists, highly detailed, digital painting",
  38. "an illustration of a house with large barn with many cute flower pots and beautiful blue sky scenery",
  39. "one apple sitting on a table, still life, reflective, full color photograph, centered, close-up product",
  40. "background texture of stones, masterpiece, artistic, stunning photo, award winner photo",
  41. "new international organic style house, tropical surroundings, architecture, 8k, hdr",
  42. "beautiful Renaissance Revival Estate, Hobbit-House, detailed painting, warm colors, 8k, trending on Artstation",
  43. "blue owl, big green eyes, portrait, intricate metal design, unreal engine, octane render, realistic",
  44. "delicate elvish moonstone necklace on a velvet background, symmetrical intricate motifs, leaves, flowers, 8k",
  45. ]
  46. negative_prompt = "bad composition, ugly, abnormal, malformed"
  47. return prompts, negative_prompt
  48. def warmup_prompts():
  49. return "warm up", "bad"
  50. def measure_gpu_memory(monitor_type, func, start_memory=None):
  51. return measure_memory(is_gpu=True, func=func, monitor_type=monitor_type, start_memory=start_memory)
  52. def get_ort_pipeline(model_name: str, directory: str, provider, disable_safety_checker: bool):
  53. from diffusers import DDIMScheduler, OnnxStableDiffusionPipeline # noqa: PLC0415
  54. import onnxruntime # noqa: PLC0415
  55. if directory is not None:
  56. assert os.path.exists(directory)
  57. session_options = onnxruntime.SessionOptions()
  58. pipe = OnnxStableDiffusionPipeline.from_pretrained(
  59. directory,
  60. provider=provider,
  61. sess_options=session_options,
  62. )
  63. else:
  64. pipe = OnnxStableDiffusionPipeline.from_pretrained(
  65. model_name,
  66. revision="onnx",
  67. provider=provider,
  68. use_auth_token=True,
  69. )
  70. pipe.scheduler = DDIMScheduler.from_config(pipe.scheduler.config)
  71. pipe.set_progress_bar_config(disable=True)
  72. if disable_safety_checker:
  73. pipe.safety_checker = None
  74. pipe.feature_extractor = None
  75. return pipe
  76. def get_torch_pipeline(model_name: str, disable_safety_checker: bool, enable_torch_compile: bool, use_xformers: bool):
  77. if "FLUX" in model_name:
  78. from diffusers import FluxPipeline # noqa: PLC0415
  79. pipe = FluxPipeline.from_pretrained(model_name, torch_dtype=torch.bfloat16).to("cuda")
  80. if enable_torch_compile:
  81. pipe.transformer.to(memory_format=torch.channels_last)
  82. pipe.transformer = torch.compile(pipe.transformer, mode="max-autotune", fullgraph=True)
  83. return pipe
  84. if "stable-diffusion-3" in model_name:
  85. from diffusers import StableDiffusion3Pipeline # noqa: PLC0415
  86. pipe = StableDiffusion3Pipeline.from_pretrained(model_name, torch_dtype=torch.bfloat16).to("cuda")
  87. if enable_torch_compile:
  88. pipe.transformer.to(memory_format=torch.channels_last)
  89. pipe.transformer = torch.compile(pipe.transformer, mode="max-autotune", fullgraph=True)
  90. return pipe
  91. from diffusers import DDIMScheduler, StableDiffusionPipeline # noqa: PLC0415
  92. from torch import channels_last, float16 # noqa: PLC0415
  93. pipe = StableDiffusionPipeline.from_pretrained(model_name, torch_dtype=float16).to("cuda")
  94. pipe.unet.to(memory_format=channels_last) # in-place operation
  95. if use_xformers:
  96. pipe.enable_xformers_memory_efficient_attention()
  97. if enable_torch_compile:
  98. pipe.unet = torch.compile(pipe.unet)
  99. pipe.vae = torch.compile(pipe.vae)
  100. pipe.text_encoder = torch.compile(pipe.text_encoder)
  101. print("Torch compiled unet, vae and text_encoder")
  102. pipe.scheduler = DDIMScheduler.from_config(pipe.scheduler.config)
  103. pipe.set_progress_bar_config(disable=True)
  104. if disable_safety_checker:
  105. pipe.safety_checker = None
  106. pipe.feature_extractor = None
  107. return pipe
  108. def get_image_filename_prefix(engine: str, model_name: str, batch_size: int, steps: int, disable_safety_checker: bool):
  109. short_model_name = model_name.split("/")[-1].replace("stable-diffusion-", "sd")
  110. return f"{engine}_{short_model_name}_b{batch_size}_s{steps}" + ("" if disable_safety_checker else "_safe")
  111. def run_ort_pipeline(
  112. pipe,
  113. batch_size: int,
  114. image_filename_prefix: str,
  115. height,
  116. width,
  117. steps,
  118. num_prompts,
  119. batch_count,
  120. start_memory,
  121. memory_monitor_type,
  122. skip_warmup: bool = False,
  123. ):
  124. from diffusers import OnnxStableDiffusionPipeline # noqa: PLC0415
  125. assert isinstance(pipe, OnnxStableDiffusionPipeline)
  126. prompts, negative_prompt = example_prompts()
  127. def warmup():
  128. if skip_warmup:
  129. return
  130. prompt, negative = warmup_prompts()
  131. pipe(
  132. prompt=[prompt] * batch_size,
  133. height=height,
  134. width=width,
  135. num_inference_steps=steps,
  136. negative_prompt=[negative] * batch_size,
  137. )
  138. # Run warm up, and measure GPU memory of two runs
  139. # cuDNN/MIOpen The first run has algo search so it might need more memory)
  140. first_run_memory = measure_gpu_memory(memory_monitor_type, warmup, start_memory)
  141. second_run_memory = measure_gpu_memory(memory_monitor_type, warmup, start_memory)
  142. warmup()
  143. latency_list = []
  144. for i, prompt in enumerate(prompts):
  145. if i >= num_prompts:
  146. break
  147. inference_start = time.time()
  148. images = pipe(
  149. prompt=[prompt] * batch_size,
  150. height=height,
  151. width=width,
  152. num_inference_steps=steps,
  153. negative_prompt=[negative_prompt] * batch_size,
  154. ).images
  155. inference_end = time.time()
  156. latency = inference_end - inference_start
  157. latency_list.append(latency)
  158. print(f"Inference took {latency:.3f} seconds")
  159. for k, image in enumerate(images):
  160. image.save(f"{image_filename_prefix}_{i}_{k}.jpg")
  161. from onnxruntime import __version__ as ort_version # noqa: PLC0415
  162. return {
  163. "engine": "onnxruntime",
  164. "version": ort_version,
  165. "height": height,
  166. "width": width,
  167. "steps": steps,
  168. "batch_size": batch_size,
  169. "batch_count": batch_count,
  170. "num_prompts": num_prompts,
  171. "average_latency": sum(latency_list) / len(latency_list),
  172. "median_latency": statistics.median(latency_list),
  173. "first_run_memory_MB": first_run_memory,
  174. "second_run_memory_MB": second_run_memory,
  175. }
  176. def get_negative_prompt_kwargs(negative_prompt, use_num_images_per_prompt, is_flux, batch_size) -> dict:
  177. # Flux does not support negative prompt
  178. kwargs = (
  179. (
  180. {"negative_prompt": negative_prompt}
  181. if use_num_images_per_prompt
  182. else {"negative_prompt": [negative_prompt] * batch_size}
  183. )
  184. if not is_flux
  185. else {}
  186. )
  187. # Fix the random seed so that we can inspect the output quality easily.
  188. if torch.cuda.is_available():
  189. kwargs["generator"] = torch.Generator(device="cuda").manual_seed(123)
  190. return kwargs
  191. def run_torch_pipeline(
  192. pipe,
  193. batch_size: int,
  194. image_filename_prefix: str,
  195. height,
  196. width,
  197. steps,
  198. num_prompts,
  199. batch_count,
  200. start_memory,
  201. memory_monitor_type,
  202. skip_warmup=False,
  203. ):
  204. prompts, negative_prompt = example_prompts()
  205. import diffusers # noqa: PLC0415
  206. is_flux = isinstance(pipe, diffusers.FluxPipeline)
  207. def warmup():
  208. if skip_warmup:
  209. return
  210. prompt, negative = warmup_prompts()
  211. extra_kwargs = get_negative_prompt_kwargs(negative, False, is_flux, batch_size)
  212. pipe(prompt=[prompt] * batch_size, height=height, width=width, num_inference_steps=steps, **extra_kwargs)
  213. # Run warm up, and measure GPU memory of two runs (The first run has cuDNN algo search so it might need more memory)
  214. first_run_memory = measure_gpu_memory(memory_monitor_type, warmup, start_memory)
  215. second_run_memory = measure_gpu_memory(memory_monitor_type, warmup, start_memory)
  216. warmup()
  217. torch.set_grad_enabled(False)
  218. latency_list = []
  219. for i, prompt in enumerate(prompts):
  220. if i >= num_prompts:
  221. break
  222. torch.cuda.synchronize()
  223. inference_start = time.time()
  224. extra_kwargs = get_negative_prompt_kwargs(negative_prompt, False, is_flux, batch_size)
  225. images = pipe(
  226. prompt=[prompt] * batch_size,
  227. height=height,
  228. width=width,
  229. num_inference_steps=steps,
  230. **extra_kwargs,
  231. ).images
  232. torch.cuda.synchronize()
  233. inference_end = time.time()
  234. latency = inference_end - inference_start
  235. latency_list.append(latency)
  236. print(f"Inference took {latency:.3f} seconds")
  237. for k, image in enumerate(images):
  238. image.save(f"{image_filename_prefix}_{i}_{k}.jpg")
  239. return {
  240. "engine": "torch",
  241. "version": torch.__version__,
  242. "height": height,
  243. "width": width,
  244. "steps": steps,
  245. "batch_size": batch_size,
  246. "batch_count": batch_count,
  247. "num_prompts": num_prompts,
  248. "average_latency": sum(latency_list) / len(latency_list),
  249. "median_latency": statistics.median(latency_list),
  250. "first_run_memory_MB": first_run_memory,
  251. "second_run_memory_MB": second_run_memory,
  252. }
  253. def run_ort(
  254. model_name: str,
  255. directory: str,
  256. provider: str,
  257. batch_size: int,
  258. disable_safety_checker: bool,
  259. height: int,
  260. width: int,
  261. steps: int,
  262. num_prompts: int,
  263. batch_count: int,
  264. start_memory,
  265. memory_monitor_type,
  266. tuning: bool,
  267. skip_warmup: bool = False,
  268. ):
  269. provider_and_options = provider
  270. if tuning and provider in ["CUDAExecutionProvider", "ROCMExecutionProvider"]:
  271. provider_and_options = (provider, {"tunable_op_enable": 1, "tunable_op_tuning_enable": 1})
  272. load_start = time.time()
  273. pipe = get_ort_pipeline(model_name, directory, provider_and_options, disable_safety_checker)
  274. load_end = time.time()
  275. print(f"Model loading took {load_end - load_start} seconds")
  276. image_filename_prefix = get_image_filename_prefix("ort", model_name, batch_size, steps, disable_safety_checker)
  277. result = run_ort_pipeline(
  278. pipe,
  279. batch_size,
  280. image_filename_prefix,
  281. height,
  282. width,
  283. steps,
  284. num_prompts,
  285. batch_count,
  286. start_memory,
  287. memory_monitor_type,
  288. skip_warmup=skip_warmup,
  289. )
  290. result.update(
  291. {
  292. "model_name": model_name,
  293. "directory": directory,
  294. "provider": provider.replace("ExecutionProvider", ""),
  295. "disable_safety_checker": disable_safety_checker,
  296. "enable_cuda_graph": False,
  297. }
  298. )
  299. return result
  300. def get_optimum_ort_pipeline(
  301. model_name: str,
  302. directory: str,
  303. provider="CUDAExecutionProvider",
  304. disable_safety_checker: bool = True,
  305. use_io_binding: bool = False,
  306. ):
  307. from optimum.onnxruntime import ORTPipelineForText2Image # noqa: PLC0415
  308. if directory is not None and os.path.exists(directory):
  309. pipeline = ORTPipelineForText2Image.from_pretrained(directory, provider=provider, use_io_binding=use_io_binding)
  310. else:
  311. pipeline = ORTPipelineForText2Image.from_pretrained(
  312. model_name,
  313. export=True,
  314. provider=provider,
  315. use_io_binding=use_io_binding,
  316. )
  317. pipeline.save_pretrained(directory)
  318. if disable_safety_checker:
  319. pipeline.safety_checker = None
  320. pipeline.feature_extractor = None
  321. return pipeline
  322. def run_optimum_ort_pipeline(
  323. pipe,
  324. batch_size: int,
  325. image_filename_prefix: str,
  326. height,
  327. width,
  328. steps,
  329. num_prompts,
  330. batch_count,
  331. start_memory,
  332. memory_monitor_type,
  333. use_num_images_per_prompt=False,
  334. skip_warmup=False,
  335. ):
  336. print("Pipeline type", type(pipe))
  337. from optimum.onnxruntime.modeling_diffusion import ORTFluxPipeline # noqa: PLC0415
  338. is_flux = isinstance(pipe, ORTFluxPipeline)
  339. prompts, negative_prompt = example_prompts()
  340. def warmup():
  341. if skip_warmup:
  342. return
  343. prompt, negative = warmup_prompts()
  344. extra_kwargs = get_negative_prompt_kwargs(negative, use_num_images_per_prompt, is_flux, batch_size)
  345. if use_num_images_per_prompt:
  346. pipe(
  347. prompt=prompt,
  348. height=height,
  349. width=width,
  350. num_inference_steps=steps,
  351. num_images_per_prompt=batch_count,
  352. **extra_kwargs,
  353. )
  354. else:
  355. pipe(prompt=[prompt] * batch_size, height=height, width=width, num_inference_steps=steps, **extra_kwargs)
  356. # Run warm up, and measure GPU memory of two runs.
  357. # The first run has algo search for cuDNN/MIOpen, so it might need more memory.
  358. first_run_memory = measure_gpu_memory(memory_monitor_type, warmup, start_memory)
  359. second_run_memory = measure_gpu_memory(memory_monitor_type, warmup, start_memory)
  360. warmup()
  361. extra_kwargs = get_negative_prompt_kwargs(negative_prompt, use_num_images_per_prompt, is_flux, batch_size)
  362. latency_list = []
  363. for i, prompt in enumerate(prompts):
  364. if i >= num_prompts:
  365. break
  366. inference_start = time.time()
  367. if use_num_images_per_prompt:
  368. images = pipe(
  369. prompt=prompt,
  370. height=height,
  371. width=width,
  372. num_inference_steps=steps,
  373. num_images_per_prompt=batch_size,
  374. **extra_kwargs,
  375. ).images
  376. else:
  377. images = pipe(
  378. prompt=[prompt] * batch_size, height=height, width=width, num_inference_steps=steps, **extra_kwargs
  379. ).images
  380. inference_end = time.time()
  381. latency = inference_end - inference_start
  382. latency_list.append(latency)
  383. print(f"Inference took {latency:.3f} seconds")
  384. for k, image in enumerate(images):
  385. image.save(f"{image_filename_prefix}_{i}_{k}.jpg")
  386. from onnxruntime import __version__ as ort_version # noqa: PLC0415
  387. return {
  388. "engine": "optimum_ort",
  389. "version": ort_version,
  390. "height": height,
  391. "width": width,
  392. "steps": steps,
  393. "batch_size": batch_size,
  394. "batch_count": batch_count,
  395. "num_prompts": num_prompts,
  396. "average_latency": sum(latency_list) / len(latency_list),
  397. "median_latency": statistics.median(latency_list),
  398. "first_run_memory_MB": first_run_memory,
  399. "second_run_memory_MB": second_run_memory,
  400. }
  401. def run_optimum_ort(
  402. model_name: str,
  403. directory: str,
  404. provider: str,
  405. batch_size: int,
  406. disable_safety_checker: bool,
  407. height: int,
  408. width: int,
  409. steps: int,
  410. num_prompts: int,
  411. batch_count: int,
  412. start_memory,
  413. memory_monitor_type,
  414. use_io_binding: bool = False,
  415. skip_warmup: bool = False,
  416. ):
  417. load_start = time.time()
  418. pipe = get_optimum_ort_pipeline(
  419. model_name, directory, provider, disable_safety_checker, use_io_binding=use_io_binding
  420. )
  421. load_end = time.time()
  422. print(f"Model loading took {load_end - load_start} seconds")
  423. full_model_name = model_name + "_" + Path(directory).name if directory else model_name
  424. image_filename_prefix = get_image_filename_prefix(
  425. "optimum", full_model_name, batch_size, steps, disable_safety_checker
  426. )
  427. result = run_optimum_ort_pipeline(
  428. pipe,
  429. batch_size,
  430. image_filename_prefix,
  431. height,
  432. width,
  433. steps,
  434. num_prompts,
  435. batch_count,
  436. start_memory,
  437. memory_monitor_type,
  438. skip_warmup=skip_warmup,
  439. )
  440. result.update(
  441. {
  442. "model_name": model_name,
  443. "directory": directory,
  444. "provider": provider.replace("ExecutionProvider", ""),
  445. "disable_safety_checker": disable_safety_checker,
  446. "enable_cuda_graph": False,
  447. }
  448. )
  449. return result
  450. def run_ort_trt_static(
  451. work_dir: str,
  452. version: str,
  453. batch_size: int,
  454. disable_safety_checker: bool,
  455. height: int,
  456. width: int,
  457. steps: int,
  458. num_prompts: int,
  459. batch_count: int,
  460. start_memory,
  461. memory_monitor_type,
  462. max_batch_size: int,
  463. nvtx_profile: bool = False,
  464. use_cuda_graph: bool = True,
  465. ):
  466. print("[I] Initializing ORT TensorRT EP accelerated StableDiffusionXL txt2img pipeline (static input shape)")
  467. # Register TensorRT plugins
  468. from trt_utilities import init_trt_plugins # noqa: PLC0415
  469. init_trt_plugins()
  470. assert batch_size <= max_batch_size
  471. from diffusion_models import PipelineInfo # noqa: PLC0415
  472. pipeline_info = PipelineInfo(version)
  473. short_name = pipeline_info.short_name()
  474. from engine_builder import EngineType, get_engine_paths # noqa: PLC0415
  475. from pipeline_stable_diffusion import StableDiffusionPipeline # noqa: PLC0415
  476. engine_type = EngineType.ORT_TRT
  477. onnx_dir, engine_dir, output_dir, framework_model_dir, _ = get_engine_paths(work_dir, pipeline_info, engine_type)
  478. # Initialize pipeline
  479. pipeline = StableDiffusionPipeline(
  480. pipeline_info,
  481. scheduler="DDIM",
  482. output_dir=output_dir,
  483. verbose=False,
  484. nvtx_profile=nvtx_profile,
  485. max_batch_size=max_batch_size,
  486. use_cuda_graph=use_cuda_graph,
  487. framework_model_dir=framework_model_dir,
  488. engine_type=engine_type,
  489. )
  490. # Load TensorRT engines and pytorch modules
  491. pipeline.backend.build_engines(
  492. engine_dir,
  493. framework_model_dir,
  494. onnx_dir,
  495. 17,
  496. opt_image_height=height,
  497. opt_image_width=width,
  498. opt_batch_size=batch_size,
  499. static_batch=True,
  500. static_image_shape=True,
  501. max_workspace_size=0,
  502. device_id=torch.cuda.current_device(),
  503. )
  504. # Here we use static batch and image size, so the resource allocation only need done once.
  505. # For dynamic batch and image size, some cost (like memory allocation) shall be included in latency.
  506. pipeline.load_resources(height, width, batch_size)
  507. def warmup():
  508. prompt, negative = warmup_prompts()
  509. pipeline.run([prompt] * batch_size, [negative] * batch_size, height, width, denoising_steps=steps)
  510. # Run warm up, and measure GPU memory of two runs
  511. # The first run has algo search so it might need more memory
  512. first_run_memory = measure_gpu_memory(memory_monitor_type, warmup, start_memory)
  513. second_run_memory = measure_gpu_memory(memory_monitor_type, warmup, start_memory)
  514. warmup()
  515. image_filename_prefix = get_image_filename_prefix("ort_trt", short_name, batch_size, steps, disable_safety_checker)
  516. latency_list = []
  517. prompts, negative_prompt = example_prompts()
  518. for i, prompt in enumerate(prompts):
  519. if i >= num_prompts:
  520. break
  521. inference_start = time.time()
  522. # Use warmup mode here since non-warmup mode will save image to disk.
  523. images, pipeline_time = pipeline.run(
  524. [prompt] * batch_size,
  525. [negative_prompt] * batch_size,
  526. height,
  527. width,
  528. denoising_steps=steps,
  529. guidance=7.5,
  530. seed=123,
  531. )
  532. inference_end = time.time()
  533. latency = inference_end - inference_start
  534. latency_list.append(latency)
  535. print(f"End2End took {latency:.3f} seconds. Inference latency: {pipeline_time}")
  536. for k, image in enumerate(images):
  537. image.save(f"{image_filename_prefix}_{i}_{k}.jpg")
  538. pipeline.teardown()
  539. from tensorrt import __version__ as trt_version # noqa: PLC0415
  540. from onnxruntime import __version__ as ort_version # noqa: PLC0415
  541. return {
  542. "model_name": pipeline_info.name(),
  543. "engine": "onnxruntime",
  544. "version": ort_version,
  545. "provider": f"tensorrt({trt_version})",
  546. "directory": engine_dir,
  547. "height": height,
  548. "width": width,
  549. "steps": steps,
  550. "batch_size": batch_size,
  551. "batch_count": batch_count,
  552. "num_prompts": num_prompts,
  553. "average_latency": sum(latency_list) / len(latency_list),
  554. "median_latency": statistics.median(latency_list),
  555. "first_run_memory_MB": first_run_memory,
  556. "second_run_memory_MB": second_run_memory,
  557. "disable_safety_checker": disable_safety_checker,
  558. "enable_cuda_graph": use_cuda_graph,
  559. }
  560. def run_tensorrt_static(
  561. work_dir: str,
  562. version: str,
  563. model_name: str,
  564. batch_size: int,
  565. disable_safety_checker: bool,
  566. height: int,
  567. width: int,
  568. steps: int,
  569. num_prompts: int,
  570. batch_count: int,
  571. start_memory,
  572. memory_monitor_type,
  573. max_batch_size: int,
  574. nvtx_profile: bool = False,
  575. use_cuda_graph: bool = True,
  576. skip_warmup: bool = False,
  577. ):
  578. print("[I] Initializing TensorRT accelerated StableDiffusionXL txt2img pipeline (static input shape)")
  579. from cuda import cudart # noqa: PLC0415
  580. # Register TensorRT plugins
  581. from trt_utilities import init_trt_plugins # noqa: PLC0415
  582. init_trt_plugins()
  583. assert batch_size <= max_batch_size
  584. from diffusion_models import PipelineInfo # noqa: PLC0415
  585. pipeline_info = PipelineInfo(version)
  586. from engine_builder import EngineType, get_engine_paths # noqa: PLC0415
  587. from pipeline_stable_diffusion import StableDiffusionPipeline # noqa: PLC0415
  588. engine_type = EngineType.TRT
  589. onnx_dir, engine_dir, output_dir, framework_model_dir, timing_cache = get_engine_paths(
  590. work_dir, pipeline_info, engine_type
  591. )
  592. # Initialize pipeline
  593. pipeline = StableDiffusionPipeline(
  594. pipeline_info,
  595. scheduler="DDIM",
  596. output_dir=output_dir,
  597. verbose=False,
  598. nvtx_profile=nvtx_profile,
  599. max_batch_size=max_batch_size,
  600. use_cuda_graph=True,
  601. engine_type=engine_type,
  602. )
  603. # Load TensorRT engines and pytorch modules
  604. pipeline.backend.load_engines(
  605. engine_dir=engine_dir,
  606. framework_model_dir=framework_model_dir,
  607. onnx_dir=onnx_dir,
  608. onnx_opset=17,
  609. opt_batch_size=batch_size,
  610. opt_image_height=height,
  611. opt_image_width=width,
  612. static_batch=True,
  613. static_shape=True,
  614. enable_all_tactics=False,
  615. timing_cache=timing_cache,
  616. )
  617. # activate engines
  618. max_device_memory = max(pipeline.backend.max_device_memory(), pipeline.backend.max_device_memory())
  619. _, shared_device_memory = cudart.cudaMalloc(max_device_memory)
  620. pipeline.backend.activate_engines(shared_device_memory)
  621. # Here we use static batch and image size, so the resource allocation only need done once.
  622. # For dynamic batch and image size, some cost (like memory allocation) shall be included in latency.
  623. pipeline.load_resources(height, width, batch_size)
  624. def warmup():
  625. if skip_warmup:
  626. return
  627. prompt, negative = warmup_prompts()
  628. pipeline.run([prompt] * batch_size, [negative] * batch_size, height, width, denoising_steps=steps)
  629. # Run warm up, and measure GPU memory of two runs
  630. # The first run has algo search so it might need more memory
  631. first_run_memory = measure_gpu_memory(memory_monitor_type, warmup, start_memory)
  632. second_run_memory = measure_gpu_memory(memory_monitor_type, warmup, start_memory)
  633. warmup()
  634. image_filename_prefix = get_image_filename_prefix("trt", model_name, batch_size, steps, disable_safety_checker)
  635. latency_list = []
  636. prompts, negative_prompt = example_prompts()
  637. for i, prompt in enumerate(prompts):
  638. if i >= num_prompts:
  639. break
  640. inference_start = time.time()
  641. # Use warmup mode here since non-warmup mode will save image to disk.
  642. images, pipeline_time = pipeline.run(
  643. [prompt] * batch_size,
  644. [negative_prompt] * batch_size,
  645. height,
  646. width,
  647. denoising_steps=steps,
  648. seed=123,
  649. )
  650. inference_end = time.time()
  651. latency = inference_end - inference_start
  652. latency_list.append(latency)
  653. print(f"End2End took {latency:.3f} seconds. Inference latency: {pipeline_time}")
  654. for k, image in enumerate(images):
  655. image.save(f"{image_filename_prefix}_{i}_{k}.jpg")
  656. pipeline.teardown()
  657. import tensorrt as trt # noqa: PLC0415
  658. return {
  659. "engine": "tensorrt",
  660. "version": trt.__version__,
  661. "provider": "default",
  662. "height": height,
  663. "width": width,
  664. "steps": steps,
  665. "batch_size": batch_size,
  666. "batch_count": batch_count,
  667. "num_prompts": num_prompts,
  668. "average_latency": sum(latency_list) / len(latency_list),
  669. "median_latency": statistics.median(latency_list),
  670. "first_run_memory_MB": first_run_memory,
  671. "second_run_memory_MB": second_run_memory,
  672. "enable_cuda_graph": use_cuda_graph,
  673. }
  674. def run_tensorrt_static_xl(
  675. work_dir: str,
  676. version: str,
  677. batch_size: int,
  678. disable_safety_checker: bool,
  679. height: int,
  680. width: int,
  681. steps: int,
  682. num_prompts: int,
  683. batch_count: int,
  684. start_memory,
  685. memory_monitor_type,
  686. max_batch_size: int,
  687. nvtx_profile: bool = False,
  688. use_cuda_graph=True,
  689. skip_warmup: bool = False,
  690. ):
  691. print("[I] Initializing TensorRT accelerated StableDiffusionXL txt2img pipeline (static input shape)")
  692. import tensorrt as trt # noqa: PLC0415
  693. from cuda import cudart # noqa: PLC0415
  694. from trt_utilities import init_trt_plugins # noqa: PLC0415
  695. # Validate image dimensions
  696. image_height = height
  697. image_width = width
  698. if image_height % 8 != 0 or image_width % 8 != 0:
  699. raise ValueError(
  700. f"Image height and width have to be divisible by 8 but specified as: {image_height} and {image_width}."
  701. )
  702. # Register TensorRT plugins
  703. init_trt_plugins()
  704. assert batch_size <= max_batch_size
  705. from diffusion_models import PipelineInfo # noqa: PLC0415
  706. from engine_builder import EngineType, get_engine_paths # noqa: PLC0415
  707. def init_pipeline(pipeline_class, pipeline_info):
  708. engine_type = EngineType.TRT
  709. onnx_dir, engine_dir, output_dir, framework_model_dir, timing_cache = get_engine_paths(
  710. work_dir, pipeline_info, engine_type
  711. )
  712. # Initialize pipeline
  713. pipeline = pipeline_class(
  714. pipeline_info,
  715. scheduler="DDIM",
  716. output_dir=output_dir,
  717. verbose=False,
  718. nvtx_profile=nvtx_profile,
  719. max_batch_size=max_batch_size,
  720. use_cuda_graph=use_cuda_graph,
  721. framework_model_dir=framework_model_dir,
  722. engine_type=engine_type,
  723. )
  724. pipeline.backend.load_engines(
  725. engine_dir=engine_dir,
  726. framework_model_dir=framework_model_dir,
  727. onnx_dir=onnx_dir,
  728. onnx_opset=17,
  729. opt_batch_size=batch_size,
  730. opt_image_height=height,
  731. opt_image_width=width,
  732. static_batch=True,
  733. static_shape=True,
  734. enable_all_tactics=False,
  735. timing_cache=timing_cache,
  736. )
  737. return pipeline
  738. from pipeline_stable_diffusion import StableDiffusionPipeline # noqa: PLC0415
  739. pipeline_info = PipelineInfo(version)
  740. pipeline = init_pipeline(StableDiffusionPipeline, pipeline_info)
  741. max_device_memory = max(pipeline.backend.max_device_memory(), pipeline.backend.max_device_memory())
  742. _, shared_device_memory = cudart.cudaMalloc(max_device_memory)
  743. pipeline.backend.activate_engines(shared_device_memory)
  744. # Here we use static batch and image size, so the resource allocation only need done once.
  745. # For dynamic batch and image size, some cost (like memory allocation) shall be included in latency.
  746. pipeline.load_resources(image_height, image_width, batch_size)
  747. def run_sd_xl_inference(prompt, negative_prompt, seed=None):
  748. return pipeline.run(
  749. prompt,
  750. negative_prompt,
  751. image_height,
  752. image_width,
  753. denoising_steps=steps,
  754. guidance=5.0,
  755. seed=seed,
  756. )
  757. def warmup():
  758. if skip_warmup:
  759. return
  760. prompt, negative = warmup_prompts()
  761. run_sd_xl_inference([prompt] * batch_size, [negative] * batch_size)
  762. # Run warm up, and measure GPU memory of two runs
  763. # The first run has algo search so it might need more memory
  764. first_run_memory = measure_gpu_memory(memory_monitor_type, warmup, start_memory)
  765. second_run_memory = measure_gpu_memory(memory_monitor_type, warmup, start_memory)
  766. warmup()
  767. model_name = pipeline_info.name()
  768. image_filename_prefix = get_image_filename_prefix("trt", model_name, batch_size, steps, disable_safety_checker)
  769. latency_list = []
  770. prompts, negative_prompt = example_prompts()
  771. for i, prompt in enumerate(prompts):
  772. if i >= num_prompts:
  773. break
  774. inference_start = time.time()
  775. # Use warmup mode here since non-warmup mode will save image to disk.
  776. images, pipeline_time = run_sd_xl_inference([prompt] * batch_size, [negative_prompt] * batch_size, seed=123)
  777. inference_end = time.time()
  778. latency = inference_end - inference_start
  779. latency_list.append(latency)
  780. print(f"End2End took {latency:.3f} seconds. Inference latency: {pipeline_time}")
  781. for k, image in enumerate(images):
  782. image.save(f"{image_filename_prefix}_{i}_{k}.png")
  783. pipeline.teardown()
  784. return {
  785. "model_name": model_name,
  786. "engine": "tensorrt",
  787. "version": trt.__version__,
  788. "provider": "default",
  789. "height": height,
  790. "width": width,
  791. "steps": steps,
  792. "batch_size": batch_size,
  793. "batch_count": batch_count,
  794. "num_prompts": num_prompts,
  795. "average_latency": sum(latency_list) / len(latency_list),
  796. "median_latency": statistics.median(latency_list),
  797. "first_run_memory_MB": first_run_memory,
  798. "second_run_memory_MB": second_run_memory,
  799. "enable_cuda_graph": use_cuda_graph,
  800. }
  801. def run_ort_trt_xl(
  802. work_dir: str,
  803. version: str,
  804. batch_size: int,
  805. disable_safety_checker: bool,
  806. height: int,
  807. width: int,
  808. steps: int,
  809. num_prompts: int,
  810. batch_count: int,
  811. start_memory,
  812. memory_monitor_type,
  813. max_batch_size: int,
  814. nvtx_profile: bool = False,
  815. use_cuda_graph=True,
  816. skip_warmup: bool = False,
  817. ):
  818. from demo_utils import initialize_pipeline # noqa: PLC0415
  819. from engine_builder import EngineType # noqa: PLC0415
  820. pipeline = initialize_pipeline(
  821. version=version,
  822. engine_type=EngineType.ORT_TRT,
  823. work_dir=work_dir,
  824. height=height,
  825. width=width,
  826. use_cuda_graph=use_cuda_graph,
  827. max_batch_size=max_batch_size,
  828. opt_batch_size=batch_size,
  829. )
  830. assert batch_size <= max_batch_size
  831. pipeline.load_resources(height, width, batch_size)
  832. def run_sd_xl_inference(prompt, negative_prompt, seed=None):
  833. return pipeline.run(
  834. prompt,
  835. negative_prompt,
  836. height,
  837. width,
  838. denoising_steps=steps,
  839. guidance=5.0,
  840. seed=seed,
  841. )
  842. def warmup():
  843. if skip_warmup:
  844. return
  845. prompt, negative = warmup_prompts()
  846. run_sd_xl_inference([prompt] * batch_size, [negative] * batch_size)
  847. # Run warm up, and measure GPU memory of two runs
  848. # The first run has algo search so it might need more memory
  849. first_run_memory = measure_gpu_memory(memory_monitor_type, warmup, start_memory)
  850. second_run_memory = measure_gpu_memory(memory_monitor_type, warmup, start_memory)
  851. warmup()
  852. model_name = pipeline.pipeline_info.name()
  853. image_filename_prefix = get_image_filename_prefix("ort_trt", model_name, batch_size, steps, disable_safety_checker)
  854. latency_list = []
  855. prompts, negative_prompt = example_prompts()
  856. for i, prompt in enumerate(prompts):
  857. if i >= num_prompts:
  858. break
  859. inference_start = time.time()
  860. # Use warmup mode here since non-warmup mode will save image to disk.
  861. images, pipeline_time = run_sd_xl_inference([prompt] * batch_size, [negative_prompt] * batch_size, seed=123)
  862. inference_end = time.time()
  863. latency = inference_end - inference_start
  864. latency_list.append(latency)
  865. print(f"End2End took {latency:.3f} seconds. Inference latency: {pipeline_time}")
  866. for k, image in enumerate(images):
  867. filename = f"{image_filename_prefix}_{i}_{k}.png"
  868. image.save(filename)
  869. print("Image saved to", filename)
  870. pipeline.teardown()
  871. from tensorrt import __version__ as trt_version # noqa: PLC0415
  872. from onnxruntime import __version__ as ort_version # noqa: PLC0415
  873. return {
  874. "model_name": model_name,
  875. "engine": "onnxruntime",
  876. "version": ort_version,
  877. "provider": f"tensorrt{trt_version})",
  878. "height": height,
  879. "width": width,
  880. "steps": steps,
  881. "batch_size": batch_size,
  882. "batch_count": batch_count,
  883. "num_prompts": num_prompts,
  884. "average_latency": sum(latency_list) / len(latency_list),
  885. "median_latency": statistics.median(latency_list),
  886. "first_run_memory_MB": first_run_memory,
  887. "second_run_memory_MB": second_run_memory,
  888. "enable_cuda_graph": use_cuda_graph,
  889. }
  890. def run_torch(
  891. model_name: str,
  892. batch_size: int,
  893. disable_safety_checker: bool,
  894. enable_torch_compile: bool,
  895. use_xformers: bool,
  896. height: int,
  897. width: int,
  898. steps: int,
  899. num_prompts: int,
  900. batch_count: int,
  901. start_memory,
  902. memory_monitor_type,
  903. skip_warmup: bool = True,
  904. ):
  905. torch.backends.cudnn.enabled = True
  906. torch.backends.cudnn.benchmark = True
  907. torch.set_grad_enabled(False)
  908. load_start = time.time()
  909. pipe = get_torch_pipeline(model_name, disable_safety_checker, enable_torch_compile, use_xformers)
  910. load_end = time.time()
  911. print(f"Model loading took {load_end - load_start} seconds")
  912. image_filename_prefix = get_image_filename_prefix("torch", model_name, batch_size, steps, disable_safety_checker)
  913. if not enable_torch_compile:
  914. with torch.inference_mode():
  915. result = run_torch_pipeline(
  916. pipe,
  917. batch_size,
  918. image_filename_prefix,
  919. height,
  920. width,
  921. steps,
  922. num_prompts,
  923. batch_count,
  924. start_memory,
  925. memory_monitor_type,
  926. skip_warmup=skip_warmup,
  927. )
  928. else:
  929. result = run_torch_pipeline(
  930. pipe,
  931. batch_size,
  932. image_filename_prefix,
  933. height,
  934. width,
  935. steps,
  936. num_prompts,
  937. batch_count,
  938. start_memory,
  939. memory_monitor_type,
  940. skip_warmup=skip_warmup,
  941. )
  942. result.update(
  943. {
  944. "model_name": model_name,
  945. "directory": None,
  946. "provider": "compile" if enable_torch_compile else "xformers" if use_xformers else "default",
  947. "disable_safety_checker": disable_safety_checker,
  948. "enable_cuda_graph": False,
  949. }
  950. )
  951. return result
  952. def parse_arguments():
  953. parser = argparse.ArgumentParser()
  954. parser.add_argument(
  955. "-e",
  956. "--engine",
  957. required=False,
  958. type=str,
  959. default="onnxruntime",
  960. choices=["onnxruntime", "optimum", "torch", "tensorrt"],
  961. help="Engines to benchmark. Default is onnxruntime.",
  962. )
  963. parser.add_argument(
  964. "-r",
  965. "--provider",
  966. required=False,
  967. type=str,
  968. default="cuda",
  969. choices=list(PROVIDERS.keys()),
  970. help="Provider to benchmark. Default is CUDAExecutionProvider.",
  971. )
  972. parser.add_argument(
  973. "-t",
  974. "--tuning",
  975. action="store_true",
  976. help="Enable TunableOp and tuning. "
  977. "This will incur longer warmup latency, and is mandatory for some operators of ROCm EP.",
  978. )
  979. parser.add_argument(
  980. "-v",
  981. "--version",
  982. required=False,
  983. type=str,
  984. choices=list(SD_MODELS.keys()),
  985. default="1.5",
  986. help="Stable diffusion version like 1.5, 2.0 or 2.1. Default is 1.5.",
  987. )
  988. parser.add_argument(
  989. "-p",
  990. "--pipeline",
  991. required=False,
  992. type=str,
  993. default=None,
  994. help="Directory of saved onnx pipeline. It could be the output directory of optimize_pipeline.py.",
  995. )
  996. parser.add_argument(
  997. "-w",
  998. "--work_dir",
  999. required=False,
  1000. type=str,
  1001. default=".",
  1002. help="Root directory to save exported onnx models, built engines etc.",
  1003. )
  1004. parser.add_argument(
  1005. "--enable_safety_checker",
  1006. required=False,
  1007. action="store_true",
  1008. help="Enable safety checker",
  1009. )
  1010. parser.set_defaults(enable_safety_checker=False)
  1011. parser.add_argument(
  1012. "--enable_torch_compile",
  1013. required=False,
  1014. action="store_true",
  1015. help="Enable compile unet for PyTorch 2.0",
  1016. )
  1017. parser.set_defaults(enable_torch_compile=False)
  1018. parser.add_argument(
  1019. "--use_xformers",
  1020. required=False,
  1021. action="store_true",
  1022. help="Use xformers for PyTorch",
  1023. )
  1024. parser.set_defaults(use_xformers=False)
  1025. parser.add_argument(
  1026. "--use_io_binding",
  1027. required=False,
  1028. action="store_true",
  1029. help="Use I/O Binding for Optimum.",
  1030. )
  1031. parser.set_defaults(use_io_binding=False)
  1032. parser.add_argument(
  1033. "--skip_warmup",
  1034. required=False,
  1035. action="store_true",
  1036. help="No warmup.",
  1037. )
  1038. parser.set_defaults(skip_warmup=False)
  1039. parser.add_argument(
  1040. "-b",
  1041. "--batch_size",
  1042. type=int,
  1043. default=1,
  1044. choices=[1, 2, 3, 4, 8, 10, 16, 32],
  1045. help="Number of images per batch. Default is 1.",
  1046. )
  1047. parser.add_argument(
  1048. "--height",
  1049. required=False,
  1050. type=int,
  1051. default=512,
  1052. help="Output image height. Default is 512.",
  1053. )
  1054. parser.add_argument(
  1055. "--width",
  1056. required=False,
  1057. type=int,
  1058. default=512,
  1059. help="Output image width. Default is 512.",
  1060. )
  1061. parser.add_argument(
  1062. "-s",
  1063. "--steps",
  1064. required=False,
  1065. type=int,
  1066. default=50,
  1067. help="Number of steps. Default is 50.",
  1068. )
  1069. parser.add_argument(
  1070. "-n",
  1071. "--num_prompts",
  1072. required=False,
  1073. type=int,
  1074. default=10,
  1075. help="Number of prompts. Default is 10.",
  1076. )
  1077. parser.add_argument(
  1078. "-c",
  1079. "--batch_count",
  1080. required=False,
  1081. type=int,
  1082. choices=range(1, 11),
  1083. default=5,
  1084. help="Number of batches to test. Default is 5.",
  1085. )
  1086. parser.add_argument(
  1087. "-m",
  1088. "--max_trt_batch_size",
  1089. required=False,
  1090. type=int,
  1091. choices=range(1, 16),
  1092. default=4,
  1093. help="Maximum batch size for TensorRT. Change the value may trigger TensorRT engine rebuild. Default is 4.",
  1094. )
  1095. parser.add_argument(
  1096. "-g",
  1097. "--enable_cuda_graph",
  1098. required=False,
  1099. action="store_true",
  1100. help="Enable Cuda Graph. Requires onnxruntime >= 1.16",
  1101. )
  1102. parser.set_defaults(enable_cuda_graph=False)
  1103. args = parser.parse_args()
  1104. return args
  1105. def print_loaded_libraries(cuda_related_only=True):
  1106. import psutil # noqa: PLC0415
  1107. p = psutil.Process(os.getpid())
  1108. for lib in p.memory_maps():
  1109. if (not cuda_related_only) or any(x in lib.path for x in ("libcu", "libnv", "tensorrt")):
  1110. print(lib.path)
  1111. def main():
  1112. args = parse_arguments()
  1113. print(args)
  1114. if args.engine == "onnxruntime":
  1115. if args.version in ["2.1"]:
  1116. # Set a flag to avoid overflow in attention, which causes black image output in SD 2.1 model.
  1117. # The environment variables shall be set before the first run of Attention or MultiHeadAttention operator.
  1118. os.environ["ORT_DISABLE_TRT_FLASH_ATTENTION"] = "1"
  1119. from packaging import version # noqa: PLC0415
  1120. from onnxruntime import __version__ as ort_version # noqa: PLC0415
  1121. if version.parse(ort_version) == version.parse("1.16.0"):
  1122. # ORT 1.16 has a bug that might trigger Attention RuntimeError when latest fusion script is applied on clip model.
  1123. # The walkaround is to enable fused causal attention, or disable Attention fusion for clip model.
  1124. os.environ["ORT_ENABLE_FUSED_CAUSAL_ATTENTION"] = "1"
  1125. if args.enable_cuda_graph:
  1126. if not (args.engine == "onnxruntime" and args.provider in ["cuda", "tensorrt"] and args.pipeline is None):
  1127. raise ValueError("The stable diffusion pipeline does not support CUDA graph.")
  1128. if version.parse(ort_version) < version.parse("1.16"):
  1129. raise ValueError("CUDA graph requires ONNX Runtime 1.16 or later")
  1130. coloredlogs.install(fmt="%(funcName)20s: %(message)s")
  1131. memory_monitor_type = "rocm" if args.provider == "rocm" else "cuda"
  1132. start_memory = measure_gpu_memory(memory_monitor_type, None)
  1133. print("GPU memory used before loading models:", start_memory)
  1134. sd_model = SD_MODELS[args.version]
  1135. provider = PROVIDERS[args.provider]
  1136. if args.engine == "onnxruntime" and args.provider == "tensorrt":
  1137. if "xl" in args.version:
  1138. print("Testing Txt2ImgXLPipeline with static input shape. Backend is ORT TensorRT EP.")
  1139. result = run_ort_trt_xl(
  1140. work_dir=args.work_dir,
  1141. version=args.version,
  1142. batch_size=args.batch_size,
  1143. disable_safety_checker=True,
  1144. height=args.height,
  1145. width=args.width,
  1146. steps=args.steps,
  1147. num_prompts=args.num_prompts,
  1148. batch_count=args.batch_count,
  1149. start_memory=start_memory,
  1150. memory_monitor_type=memory_monitor_type,
  1151. max_batch_size=args.max_trt_batch_size,
  1152. nvtx_profile=False,
  1153. use_cuda_graph=args.enable_cuda_graph,
  1154. skip_warmup=args.skip_warmup,
  1155. )
  1156. else:
  1157. print("Testing Txt2ImgPipeline with static input shape. Backend is ORT TensorRT EP.")
  1158. result = run_ort_trt_static(
  1159. work_dir=args.work_dir,
  1160. version=args.version,
  1161. batch_size=args.batch_size,
  1162. disable_safety_checker=not args.enable_safety_checker,
  1163. height=args.height,
  1164. width=args.width,
  1165. steps=args.steps,
  1166. num_prompts=args.num_prompts,
  1167. batch_count=args.batch_count,
  1168. start_memory=start_memory,
  1169. memory_monitor_type=memory_monitor_type,
  1170. max_batch_size=args.max_trt_batch_size,
  1171. nvtx_profile=False,
  1172. use_cuda_graph=args.enable_cuda_graph,
  1173. skip_warmup=args.skip_warmup,
  1174. )
  1175. elif args.engine == "optimum" and provider == "CUDAExecutionProvider":
  1176. if "xl" in args.version:
  1177. os.environ["ORT_ENABLE_FUSED_CAUSAL_ATTENTION"] = "1"
  1178. result = run_optimum_ort(
  1179. model_name=sd_model,
  1180. directory=args.pipeline,
  1181. provider=provider,
  1182. batch_size=args.batch_size,
  1183. disable_safety_checker=not args.enable_safety_checker,
  1184. height=args.height,
  1185. width=args.width,
  1186. steps=args.steps,
  1187. num_prompts=args.num_prompts,
  1188. batch_count=args.batch_count,
  1189. start_memory=start_memory,
  1190. memory_monitor_type=memory_monitor_type,
  1191. use_io_binding=args.use_io_binding,
  1192. skip_warmup=args.skip_warmup,
  1193. )
  1194. elif args.engine == "onnxruntime":
  1195. assert args.pipeline and os.path.isdir(args.pipeline), (
  1196. "--pipeline should be specified for the directory of ONNX models"
  1197. )
  1198. print(f"Testing diffusers StableDiffusionPipeline with {provider} provider and tuning={args.tuning}")
  1199. result = run_ort(
  1200. model_name=sd_model,
  1201. directory=args.pipeline,
  1202. provider=provider,
  1203. batch_size=args.batch_size,
  1204. disable_safety_checker=not args.enable_safety_checker,
  1205. height=args.height,
  1206. width=args.width,
  1207. steps=args.steps,
  1208. num_prompts=args.num_prompts,
  1209. batch_count=args.batch_count,
  1210. start_memory=start_memory,
  1211. memory_monitor_type=memory_monitor_type,
  1212. tuning=args.tuning,
  1213. skip_warmup=args.skip_warmup,
  1214. )
  1215. elif args.engine == "tensorrt" and "xl" in args.version:
  1216. print("Testing Txt2ImgXLPipeline with static input shape. Backend is TensorRT.")
  1217. result = run_tensorrt_static_xl(
  1218. work_dir=args.work_dir,
  1219. version=args.version,
  1220. batch_size=args.batch_size,
  1221. disable_safety_checker=True,
  1222. height=args.height,
  1223. width=args.width,
  1224. steps=args.steps,
  1225. num_prompts=args.num_prompts,
  1226. batch_count=args.batch_count,
  1227. start_memory=start_memory,
  1228. memory_monitor_type=memory_monitor_type,
  1229. max_batch_size=args.max_trt_batch_size,
  1230. nvtx_profile=False,
  1231. use_cuda_graph=args.enable_cuda_graph,
  1232. skip_warmup=args.skip_warmup,
  1233. )
  1234. elif args.engine == "tensorrt":
  1235. print("Testing Txt2ImgPipeline with static input shape. Backend is TensorRT.")
  1236. result = run_tensorrt_static(
  1237. work_dir=args.work_dir,
  1238. version=args.version,
  1239. model_name=sd_model,
  1240. batch_size=args.batch_size,
  1241. disable_safety_checker=True,
  1242. height=args.height,
  1243. width=args.width,
  1244. steps=args.steps,
  1245. num_prompts=args.num_prompts,
  1246. batch_count=args.batch_count,
  1247. start_memory=start_memory,
  1248. memory_monitor_type=memory_monitor_type,
  1249. max_batch_size=args.max_trt_batch_size,
  1250. nvtx_profile=False,
  1251. use_cuda_graph=args.enable_cuda_graph,
  1252. skip_warmup=args.skip_warmup,
  1253. )
  1254. else:
  1255. print(
  1256. f"Testing Txt2ImgPipeline with dynamic input shape. Backend is PyTorch: compile={args.enable_torch_compile}, xformers={args.use_xformers}."
  1257. )
  1258. result = run_torch(
  1259. model_name=sd_model,
  1260. batch_size=args.batch_size,
  1261. disable_safety_checker=not args.enable_safety_checker,
  1262. enable_torch_compile=args.enable_torch_compile,
  1263. use_xformers=args.use_xformers,
  1264. height=args.height,
  1265. width=args.width,
  1266. steps=args.steps,
  1267. num_prompts=args.num_prompts,
  1268. batch_count=args.batch_count,
  1269. start_memory=start_memory,
  1270. memory_monitor_type=memory_monitor_type,
  1271. skip_warmup=args.skip_warmup,
  1272. )
  1273. print(result)
  1274. with open("benchmark_result.csv", mode="a", newline="") as csv_file:
  1275. column_names = [
  1276. "model_name",
  1277. "directory",
  1278. "engine",
  1279. "version",
  1280. "provider",
  1281. "disable_safety_checker",
  1282. "height",
  1283. "width",
  1284. "steps",
  1285. "batch_size",
  1286. "batch_count",
  1287. "num_prompts",
  1288. "average_latency",
  1289. "median_latency",
  1290. "first_run_memory_MB",
  1291. "second_run_memory_MB",
  1292. "enable_cuda_graph",
  1293. ]
  1294. csv_writer = csv.DictWriter(csv_file, fieldnames=column_names)
  1295. csv_writer.writeheader()
  1296. csv_writer.writerow(result)
  1297. # Show loaded DLLs when steps == 1 for debugging purpose.
  1298. if args.steps == 1:
  1299. print_loaded_libraries(args.provider in ["cuda", "tensorrt"])
  1300. if __name__ == "__main__":
  1301. import traceback
  1302. try:
  1303. main()
  1304. except Exception:
  1305. traceback.print_exception(*sys.exc_info())