llama_parity.py 11 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343
  1. # -------------------------------------------------------------------------
  2. # Copyright (c) Microsoft Corporation. All rights reserved.
  3. # Licensed under the MIT License. See License.txt in the project root for
  4. # license information.
  5. # --------------------------------------------------------------------------
  6. from __future__ import annotations
  7. import argparse
  8. import logging
  9. import os
  10. import time
  11. import numpy as np
  12. import packaging.version as pv
  13. import torch
  14. from benchmark_helper import setup_logger
  15. from dist_settings import get_rank, get_size
  16. from llama_inputs import (
  17. add_io_bindings_as_ortvalues,
  18. convert_inputs_for_ort,
  19. get_merged_sample_with_past_kv_inputs,
  20. get_sample_inputs,
  21. get_sample_with_past_kv_inputs,
  22. verify_ort_inputs,
  23. )
  24. from llama_torch import setup_torch_model
  25. from models.torch_export_patches.cache_helper import make_dynamic_cache
  26. from transformers import AutoConfig
  27. from transformers import __version__ as transformers_version
  28. from transformers.cache_utils import DynamicCache
  29. import onnxruntime as ort
  30. logger = logging.getLogger("")
  31. def get_sequence_lengths(args: argparse.Namespace, config: AutoConfig):
  32. past_sequence_length, curr_sequence_length = (8, 1) if args.use_past_kv else (0, 8)
  33. max_sequence_length = config.max_position_embeddings
  34. return past_sequence_length, curr_sequence_length, max_sequence_length
  35. def get_inputs(args: argparse.Namespace, config: AutoConfig):
  36. # Dummy values for parity
  37. world_size = get_size()
  38. batch_size = 2
  39. past_sequence_length, sequence_length, max_sequence_length = get_sequence_lengths(args, config)
  40. if args.merged:
  41. inputs = get_merged_sample_with_past_kv_inputs(
  42. config,
  43. args.device,
  44. batch_size,
  45. seq_len=sequence_length,
  46. past_seq_len=past_sequence_length,
  47. max_seq_len=max_sequence_length,
  48. use_fp16=args.use_fp16,
  49. use_buffer_share=args.use_buffer_share,
  50. return_dict=True,
  51. world_size=world_size,
  52. )
  53. elif args.use_past_kv:
  54. inputs = get_sample_with_past_kv_inputs(
  55. config,
  56. args.device,
  57. batch_size,
  58. sequence_length,
  59. use_fp16=args.use_fp16,
  60. return_dict=True,
  61. world_size=world_size,
  62. )
  63. else:
  64. inputs = get_sample_inputs(config, args.device, batch_size, sequence_length, return_dict=True)
  65. return inputs
  66. def torch_deepcopy(value):
  67. if isinstance(value, (int, float, str)):
  68. return value
  69. if isinstance(value, tuple):
  70. return tuple(torch_deepcopy(v) for v in value)
  71. if isinstance(value, list):
  72. return [torch_deepcopy(v) for v in value]
  73. if isinstance(value, set):
  74. return {torch_deepcopy(v) for v in value}
  75. if isinstance(value, dict):
  76. return {k: torch_deepcopy(v) for k, v in value.items()}
  77. if isinstance(value, np.ndarray):
  78. return value.copy()
  79. if hasattr(value, "clone"):
  80. return value.clone()
  81. if isinstance(value, DynamicCache):
  82. return make_dynamic_cache(torch_deepcopy(list(zip(value.key_cache, value.value_cache, strict=False))))
  83. # We should have a code using serialization, deserialization assuming a model
  84. # cannot be exported without them.
  85. raise NotImplementedError(f"torch_deepcopy not implemented for type {type(value)}")
  86. def verify_parity(
  87. args: argparse.Namespace,
  88. location: str,
  89. use_auth_token: bool,
  90. kv_cache_ortvalues: dict,
  91. pytorch_model: None | torch.nn.Module = None,
  92. config: None | AutoConfig = None,
  93. ):
  94. # If it's running in a machine where GPU memory < 36GB, it should unload the model in GPU in time and free the GPU memory for ORT.
  95. py_model = pytorch_model
  96. if py_model is None:
  97. config, py_model = setup_torch_model(
  98. args,
  99. location,
  100. use_auth_token,
  101. torch_dtype=(torch.float16 if args.use_fp16 else torch.float32),
  102. device=args.device,
  103. )
  104. inputs = get_inputs(args, config)
  105. if "past_key_values" in inputs and pv.Version(transformers_version) >= pv.Version("4.45"):
  106. # Using DynamicCache
  107. inputs["past_key_values"] = make_dynamic_cache(inputs["past_key_values"])
  108. # Run inference with PyTorch
  109. inputs_after_deepcopy = torch_deepcopy(inputs)
  110. if args.execution_provider != "cpu":
  111. torch.cuda.synchronize()
  112. start_time = time.time()
  113. # If there is a cache in the inputs, we need to make a copy as the model modifies them inplace.
  114. # DynamicCache inherits from torch.nn.Module in some version of transformers.
  115. # We need to make the copy manually.
  116. pt_outputs = py_model(**inputs_after_deepcopy).logits.detach().cpu().numpy()
  117. if args.execution_provider != "cpu":
  118. torch.cuda.synchronize()
  119. end_time = time.time()
  120. logger.info(f"PyTorch took {end_time - start_time} s")
  121. if args.small_gpu and py_model is not None:
  122. del py_model
  123. torch.cuda.empty_cache()
  124. # Run inference with ORT
  125. past_sequence_length, _, max_sequence_length = get_sequence_lengths(args, config)
  126. inputs = convert_inputs_for_ort(
  127. inputs,
  128. use_buffer_share=args.use_buffer_share,
  129. past_seq_len=past_sequence_length,
  130. max_seq_len=max_sequence_length,
  131. )
  132. ep = f"{args.execution_provider.upper()}ExecutionProvider"
  133. if ep == "CUDAExecutionProvider":
  134. ep = (ep, {"device_id": args.rank})
  135. ort_model = ort.InferenceSession(
  136. args.onnx_model_path,
  137. sess_options=ort.SessionOptions(),
  138. providers=[ep],
  139. )
  140. inputs = verify_ort_inputs(ort_model, inputs)
  141. # Add IO bindings for non-CPU execution providers
  142. if args.execution_provider != "cpu":
  143. io_binding, kv_cache_ortvalues = add_io_bindings_as_ortvalues(
  144. ort_model,
  145. ort_inputs=inputs,
  146. device=args.execution_provider,
  147. device_id=int(args.rank),
  148. use_buffer_share=args.use_buffer_share,
  149. kv_cache_ortvalues=kv_cache_ortvalues,
  150. )
  151. io_binding.synchronize_inputs()
  152. start_time = time.time()
  153. ort_model.run_with_iobinding(io_binding)
  154. io_binding.synchronize_outputs()
  155. end_time = time.time()
  156. ort_outputs = io_binding.copy_outputs_to_cpu()[0] # Get logits
  157. del ort_model
  158. else:
  159. start_time = time.time()
  160. ort_outputs = ort_model.run(None, inputs)
  161. end_time = time.time()
  162. ort_outputs = ort_outputs[0] # Get logits
  163. logger.info(f"ONNX Runtime took {end_time - start_time} s")
  164. # Compare PyTorch and ONNX Runtime accuracy
  165. tol = 2e1 if "int4" in args.onnx_model_path or "int8" in args.onnx_model_path else 5e-1
  166. parity = np.allclose(pt_outputs, ort_outputs, rtol=tol, atol=tol)
  167. logger.warning(f"Are PyTorch and ONNX Runtime results close? {parity}")
  168. if not parity:
  169. logger.warning(f"Max diff: {np.max(pt_outputs - ort_outputs)}")
  170. return kv_cache_ortvalues
  171. def get_args(argv: list[str]):
  172. parser = argparse.ArgumentParser()
  173. parser.add_argument(
  174. "-m",
  175. "--model_name",
  176. required=False,
  177. help="Model name in Hugging Face",
  178. )
  179. parser.add_argument(
  180. "-t",
  181. "--torch_model_directory",
  182. required=False,
  183. default=os.path.join("."),
  184. help="Path to folder containing PyTorch model and associated files if saved on disk",
  185. )
  186. parser.add_argument(
  187. "-o",
  188. "--onnx_model_path",
  189. required=True,
  190. default=os.path.join("."),
  191. help="Path to ONNX model (with external data files saved in the same folder as the model)",
  192. )
  193. parser.add_argument(
  194. "-ep",
  195. "--execution_provider",
  196. required=False,
  197. default="cpu",
  198. choices=["cpu", "cuda", "rocm"],
  199. help="Execution provider to verify parity with",
  200. )
  201. parser.add_argument(
  202. "-v",
  203. "--verbose",
  204. action="store_true",
  205. help="Print verbose logs",
  206. )
  207. parser.set_defaults(verbose=False)
  208. parser.add_argument(
  209. "-p",
  210. "--use_past_kv",
  211. action="store_true",
  212. help="Use past key and past value as inputs to the model. Necessary for decoder_with_past_model.onnx models.",
  213. )
  214. parser.set_defaults(use_past_kv=False)
  215. parser.add_argument(
  216. "-g",
  217. "--use_buffer_share",
  218. action="store_true",
  219. help="Use if model has GroupQueryAttention and you want to enable past-present buffer sharing",
  220. )
  221. parser.set_defaults(use_buffer_share=False)
  222. parser.add_argument(
  223. "--merged",
  224. action="store_true",
  225. help="Use merged model (i.e. decoder_merged_model.onnx).",
  226. )
  227. parser.set_defaults(merged=False)
  228. parser.add_argument(
  229. "-fp",
  230. "--precision",
  231. required=True,
  232. choices=["int4", "int8", "fp16", "fp32"],
  233. help="Precision of model",
  234. )
  235. parser.add_argument(
  236. "--cache_dir",
  237. required=False,
  238. type=str,
  239. default="./model_cache",
  240. help="model cache dir to override default HF cache dir to avoid overflood the /home dir",
  241. )
  242. # The argument is used for CI mainly, because the CI machine has 24G GPU memory at most.
  243. parser.add_argument(
  244. "--small_gpu",
  245. action="store_true",
  246. help="Load the llama in GPU every time for parity_check if it's running in a machine which GPU memory < 36GB. ",
  247. )
  248. args = parser.parse_args() if argv == [] else parser.parse_args(argv)
  249. # Use FP32 precision for FP32, INT8, INT4 CPU models, use FP16 precision for FP16 and INT4 GPU models
  250. args.precision = (
  251. "fp32"
  252. if args.precision in {"int8", "fp32"} or (args.precision == "int4" and args.execution_provider == "cpu")
  253. else "fp16"
  254. )
  255. return args
  256. def main(argv: list[str] = []): # noqa: B006
  257. args = get_args(argv)
  258. setup_logger(args.verbose)
  259. logger.info(f"Arguments: {args}")
  260. rank = get_rank()
  261. # Load model and config
  262. setattr(args, "use_fp16", args.precision == "fp16") # noqa: B010
  263. args.rank = rank
  264. setattr(args, "device_name", "cpu" if args.execution_provider == "cpu" else f"cuda:{rank}") # noqa: B010
  265. setattr(args, "device", torch.device(args.device_name)) # noqa: B010
  266. use_auth_token = args.torch_model_directory == os.path.join(".")
  267. location = args.model_name if use_auth_token else args.torch_model_directory
  268. kv_cache_ortvalues = {}
  269. if not args.merged:
  270. verify_parity(args, location, use_auth_token, kv_cache_ortvalues)
  271. else:
  272. config = llama = None
  273. if not args.small_gpu:
  274. config, llama = setup_torch_model(
  275. args,
  276. location,
  277. use_auth_token,
  278. torch_dtype=(torch.float16 if args.use_fp16 else torch.float32),
  279. device=args.device,
  280. )
  281. # Verify prompt processing in merged model (decoder_model.onnx)
  282. args.use_past_kv = False
  283. kv_cache_ortvalues = verify_parity(
  284. args, location, use_auth_token, kv_cache_ortvalues, pytorch_model=llama, config=config
  285. )
  286. # Verify token generation in merged model (decoder_with_past_model.onnx)
  287. args.use_past_kv = True
  288. verify_parity(args, location, use_auth_token, kv_cache_ortvalues, pytorch_model=llama, config=config)
  289. if __name__ == "__main__":
  290. seed = 2
  291. np.random.seed(seed)
  292. torch.manual_seed(seed)
  293. main()