# ------------------------------------------------------------------------- # Copyright (c) Microsoft Corporation. All rights reserved. # Licensed under the MIT License. See License.txt in the project root for # license information. # -------------------------------------------------------------------------- from __future__ import annotations import argparse import logging import os import time import numpy as np import packaging.version as pv import torch from benchmark_helper import setup_logger from dist_settings import get_rank, get_size from llama_inputs import ( add_io_bindings_as_ortvalues, convert_inputs_for_ort, get_merged_sample_with_past_kv_inputs, get_sample_inputs, get_sample_with_past_kv_inputs, verify_ort_inputs, ) from llama_torch import setup_torch_model from models.torch_export_patches.cache_helper import make_dynamic_cache from transformers import AutoConfig from transformers import __version__ as transformers_version from transformers.cache_utils import DynamicCache import onnxruntime as ort logger = logging.getLogger("") def get_sequence_lengths(args: argparse.Namespace, config: AutoConfig): past_sequence_length, curr_sequence_length = (8, 1) if args.use_past_kv else (0, 8) max_sequence_length = config.max_position_embeddings return past_sequence_length, curr_sequence_length, max_sequence_length def get_inputs(args: argparse.Namespace, config: AutoConfig): # Dummy values for parity world_size = get_size() batch_size = 2 past_sequence_length, sequence_length, max_sequence_length = get_sequence_lengths(args, config) if args.merged: inputs = get_merged_sample_with_past_kv_inputs( config, args.device, batch_size, seq_len=sequence_length, past_seq_len=past_sequence_length, max_seq_len=max_sequence_length, use_fp16=args.use_fp16, use_buffer_share=args.use_buffer_share, return_dict=True, world_size=world_size, ) elif args.use_past_kv: inputs = get_sample_with_past_kv_inputs( config, args.device, batch_size, sequence_length, use_fp16=args.use_fp16, return_dict=True, world_size=world_size, ) else: inputs = get_sample_inputs(config, args.device, batch_size, sequence_length, return_dict=True) return inputs def torch_deepcopy(value): if isinstance(value, (int, float, str)): return value if isinstance(value, tuple): return tuple(torch_deepcopy(v) for v in value) if isinstance(value, list): return [torch_deepcopy(v) for v in value] if isinstance(value, set): return {torch_deepcopy(v) for v in value} if isinstance(value, dict): return {k: torch_deepcopy(v) for k, v in value.items()} if isinstance(value, np.ndarray): return value.copy() if hasattr(value, "clone"): return value.clone() if isinstance(value, DynamicCache): return make_dynamic_cache(torch_deepcopy(list(zip(value.key_cache, value.value_cache, strict=False)))) # We should have a code using serialization, deserialization assuming a model # cannot be exported without them. raise NotImplementedError(f"torch_deepcopy not implemented for type {type(value)}") def verify_parity( args: argparse.Namespace, location: str, use_auth_token: bool, kv_cache_ortvalues: dict, pytorch_model: None | torch.nn.Module = None, config: None | AutoConfig = None, ): # If it's running in a machine where GPU memory < 36GB, it should unload the model in GPU in time and free the GPU memory for ORT. py_model = pytorch_model if py_model is None: config, py_model = setup_torch_model( args, location, use_auth_token, torch_dtype=(torch.float16 if args.use_fp16 else torch.float32), device=args.device, ) inputs = get_inputs(args, config) if "past_key_values" in inputs and pv.Version(transformers_version) >= pv.Version("4.45"): # Using DynamicCache inputs["past_key_values"] = make_dynamic_cache(inputs["past_key_values"]) # Run inference with PyTorch inputs_after_deepcopy = torch_deepcopy(inputs) if args.execution_provider != "cpu": torch.cuda.synchronize() start_time = time.time() # If there is a cache in the inputs, we need to make a copy as the model modifies them inplace. # DynamicCache inherits from torch.nn.Module in some version of transformers. # We need to make the copy manually. pt_outputs = py_model(**inputs_after_deepcopy).logits.detach().cpu().numpy() if args.execution_provider != "cpu": torch.cuda.synchronize() end_time = time.time() logger.info(f"PyTorch took {end_time - start_time} s") if args.small_gpu and py_model is not None: del py_model torch.cuda.empty_cache() # Run inference with ORT past_sequence_length, _, max_sequence_length = get_sequence_lengths(args, config) inputs = convert_inputs_for_ort( inputs, use_buffer_share=args.use_buffer_share, past_seq_len=past_sequence_length, max_seq_len=max_sequence_length, ) ep = f"{args.execution_provider.upper()}ExecutionProvider" if ep == "CUDAExecutionProvider": ep = (ep, {"device_id": args.rank}) ort_model = ort.InferenceSession( args.onnx_model_path, sess_options=ort.SessionOptions(), providers=[ep], ) inputs = verify_ort_inputs(ort_model, inputs) # Add IO bindings for non-CPU execution providers if args.execution_provider != "cpu": io_binding, kv_cache_ortvalues = add_io_bindings_as_ortvalues( ort_model, ort_inputs=inputs, device=args.execution_provider, device_id=int(args.rank), use_buffer_share=args.use_buffer_share, kv_cache_ortvalues=kv_cache_ortvalues, ) io_binding.synchronize_inputs() start_time = time.time() ort_model.run_with_iobinding(io_binding) io_binding.synchronize_outputs() end_time = time.time() ort_outputs = io_binding.copy_outputs_to_cpu()[0] # Get logits del ort_model else: start_time = time.time() ort_outputs = ort_model.run(None, inputs) end_time = time.time() ort_outputs = ort_outputs[0] # Get logits logger.info(f"ONNX Runtime took {end_time - start_time} s") # Compare PyTorch and ONNX Runtime accuracy tol = 2e1 if "int4" in args.onnx_model_path or "int8" in args.onnx_model_path else 5e-1 parity = np.allclose(pt_outputs, ort_outputs, rtol=tol, atol=tol) logger.warning(f"Are PyTorch and ONNX Runtime results close? {parity}") if not parity: logger.warning(f"Max diff: {np.max(pt_outputs - ort_outputs)}") return kv_cache_ortvalues def get_args(argv: list[str]): parser = argparse.ArgumentParser() parser.add_argument( "-m", "--model_name", required=False, help="Model name in Hugging Face", ) parser.add_argument( "-t", "--torch_model_directory", required=False, default=os.path.join("."), help="Path to folder containing PyTorch model and associated files if saved on disk", ) parser.add_argument( "-o", "--onnx_model_path", required=True, default=os.path.join("."), help="Path to ONNX model (with external data files saved in the same folder as the model)", ) parser.add_argument( "-ep", "--execution_provider", required=False, default="cpu", choices=["cpu", "cuda", "rocm"], help="Execution provider to verify parity with", ) parser.add_argument( "-v", "--verbose", action="store_true", help="Print verbose logs", ) parser.set_defaults(verbose=False) parser.add_argument( "-p", "--use_past_kv", action="store_true", help="Use past key and past value as inputs to the model. Necessary for decoder_with_past_model.onnx models.", ) parser.set_defaults(use_past_kv=False) parser.add_argument( "-g", "--use_buffer_share", action="store_true", help="Use if model has GroupQueryAttention and you want to enable past-present buffer sharing", ) parser.set_defaults(use_buffer_share=False) parser.add_argument( "--merged", action="store_true", help="Use merged model (i.e. decoder_merged_model.onnx).", ) parser.set_defaults(merged=False) parser.add_argument( "-fp", "--precision", required=True, choices=["int4", "int8", "fp16", "fp32"], help="Precision of model", ) parser.add_argument( "--cache_dir", required=False, type=str, default="./model_cache", help="model cache dir to override default HF cache dir to avoid overflood the /home dir", ) # The argument is used for CI mainly, because the CI machine has 24G GPU memory at most. parser.add_argument( "--small_gpu", action="store_true", help="Load the llama in GPU every time for parity_check if it's running in a machine which GPU memory < 36GB. ", ) args = parser.parse_args() if argv == [] else parser.parse_args(argv) # Use FP32 precision for FP32, INT8, INT4 CPU models, use FP16 precision for FP16 and INT4 GPU models args.precision = ( "fp32" if args.precision in {"int8", "fp32"} or (args.precision == "int4" and args.execution_provider == "cpu") else "fp16" ) return args def main(argv: list[str] = []): # noqa: B006 args = get_args(argv) setup_logger(args.verbose) logger.info(f"Arguments: {args}") rank = get_rank() # Load model and config setattr(args, "use_fp16", args.precision == "fp16") # noqa: B010 args.rank = rank setattr(args, "device_name", "cpu" if args.execution_provider == "cpu" else f"cuda:{rank}") # noqa: B010 setattr(args, "device", torch.device(args.device_name)) # noqa: B010 use_auth_token = args.torch_model_directory == os.path.join(".") location = args.model_name if use_auth_token else args.torch_model_directory kv_cache_ortvalues = {} if not args.merged: verify_parity(args, location, use_auth_token, kv_cache_ortvalues) else: config = llama = None if not args.small_gpu: config, llama = setup_torch_model( args, location, use_auth_token, torch_dtype=(torch.float16 if args.use_fp16 else torch.float32), device=args.device, ) # Verify prompt processing in merged model (decoder_model.onnx) args.use_past_kv = False kv_cache_ortvalues = verify_parity( args, location, use_auth_token, kv_cache_ortvalues, pytorch_model=llama, config=config ) # Verify token generation in merged model (decoder_with_past_model.onnx) args.use_past_kv = True verify_parity(args, location, use_auth_token, kv_cache_ortvalues, pytorch_model=llama, config=config) if __name__ == "__main__": seed = 2 np.random.seed(seed) torch.manual_seed(seed) main()