| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504 |
- # -------------------------------------------------------------------------
- # Copyright (c) Microsoft Corporation. All rights reserved.
- # Licensed under the MIT License. See License.txt in the project root for
- # license information.
- # --------------------------------------------------------------------------
- from __future__ import annotations
- import numpy as np
- import torch
- from transformers import AutoConfig, AutoTokenizer
- from transformers.cache_utils import DynamicCache
- from onnxruntime import InferenceSession, OrtValue
def get_position_ids(attention_mask: torch.Tensor, use_past_kv: bool):
    """Derive position_ids from an attention mask.

    Positions count attended tokens starting at 0 (cumulative sum of the mask minus one);
    slots where the mask is 0 are given position 1.

    Args:
        attention_mask: (batch_size, sequence_length) mask; nonzero entries are attended.
        use_past_kv: when True, keep only the last position of every row.

    Returns:
        int64 tensor of shape (batch_size, 1) when ``use_past_kv`` is True,
        otherwise (batch_size, sequence_length).
    """
    positions = torch.cumsum(attention_mask.long(), dim=-1) - 1
    # Padded slots get a dummy position of 1
    positions = positions.masked_fill(attention_mask == 0, 1)
    if not use_past_kv:
        return positions
    # Slice (rather than index) keeps the trailing axis: (batch_size, 1)
    return positions[:, -1:]
def get_sample_inputs(
    config: AutoConfig,
    device: torch.device,
    batch_size: int,
    seq_len: int,
    engine: str = "pt",
    return_dict: bool = False,
):
    """Build random first-pass (prompt) inputs used to obtain the initial past_key_values.

    Shapes:
        input_ids:      (batch_size, sequence_length), random ids in [0, vocab_size)
        attention_mask: (batch_size, sequence_length), all ones
        position_ids:   (batch_size, sequence_length)

    Args:
        config: model config; only ``vocab_size`` is read.
        device: target device for PyTorch tensors (unused when ``engine == "ort"``).
        batch_size: number of sequences.
        seq_len: prompt length.
        engine: ``"ort"`` yields NumPy arrays; any other value yields tensors moved to ``device``.
        return_dict: False returns a positional tuple (used for export); True returns a dict.

    Returns:
        ``(input_ids, attention_mask, position_ids)`` or a dict with those keys.
    """
    as_numpy = engine == "ort"

    def _materialize(tensor):
        # NumPy for ONNX Runtime, device tensors for PyTorch
        return tensor.numpy() if as_numpy else tensor.to(device)

    token_ids = torch.randint(low=0, high=config.vocab_size, size=(batch_size, seq_len), dtype=torch.int64)
    mask = torch.ones(batch_size, seq_len, dtype=torch.int64)
    positions = get_position_ids(mask, use_past_kv=False)

    input_ids, attention_mask, position_ids = (_materialize(t) for t in (token_ids, mask, positions))

    if return_dict:
        return {
            "input_ids": input_ids,
            "attention_mask": attention_mask,
            "position_ids": position_ids,
        }
    # Positional tuple for export
    return (input_ids, attention_mask, position_ids)
def get_sample_with_past_kv_inputs(
    config: AutoConfig,
    device: torch.device,
    batch_size: int,
    past_seq_len: int,
    use_fp16: bool = False,
    engine: str = "pt",
    return_dict: bool = False,
    world_size: int = 1,
):
    """Build random single-token decoding inputs together with a populated KV cache.

    Shapes:
        input_ids:      (batch_size, 1)
        attention_mask: (batch_size, past_sequence_length + 1), all ones
        position_ids:   (batch_size, 1)
        past key/value: (batch_size, num_heads, past_sequence_length, head_size) per layer

    For ``engine == "ort"`` every tensor is converted to NumPy and the KV cache is flattened
    with ``flatten_past_kv_inputs`` into a dict; otherwise tensors are moved to ``device`` and
    the KV cache stays a list of ``(key, value)`` tuples.

    Returns:
        A positional tuple ``(input_ids, attention_mask, position_ids, past_kv)`` when
        ``return_dict`` is False (export path; an assert requires the KV cache to be a list,
        i.e. a non-"ort" engine), otherwise a dict of model inputs.
    """
    on_ort = engine == "ort"

    token_ids = torch.randint(low=0, high=config.vocab_size, size=(batch_size, 1), dtype=torch.int64)
    mask = torch.ones(batch_size, past_seq_len + 1, dtype=torch.int64)
    positions = get_position_ids(mask, use_past_kv=True)  # (batch_size, 1)
    cache = get_past_kv_inputs(config, batch_size, past_seq_len, use_fp16, world_size=world_size)

    if on_ort:
        input_ids, attention_mask, position_ids = token_ids.numpy(), mask.numpy(), positions.numpy()
        past_kv = flatten_past_kv_inputs(cache)
    else:
        input_ids, attention_mask, position_ids = token_ids.to(device), mask.to(device), positions.to(device)
        past_kv = [(key.to(device), value.to(device)) for key, value in cache]

    if not return_dict:
        # For export
        assert isinstance(past_kv, list)
        return (input_ids, attention_mask, position_ids, past_kv)

    inputs = {
        "input_ids": input_ids,
        "attention_mask": attention_mask,
        "position_ids": position_ids,
    }
    if not on_ort:
        assert isinstance(past_kv, list)
        inputs["past_key_values"] = past_kv
        return inputs

    assert isinstance(past_kv, dict)
    inputs.update(past_kv)
    return inputs
def get_merged_sample_with_past_kv_inputs(
    config: AutoConfig,
    device: torch.device,
    batch_size: int,
    seq_len: int,
    past_seq_len: int,
    max_seq_len: int,
    use_fp16: bool = False,
    use_buffer_share: bool = False,
    engine: str = "pt",
    return_dict: bool = False,
    world_size: int = 1,
):
    """Build random inputs for a merged model that serves both prompt processing and token generation.

    Shapes:
        input_ids:      (batch_size, seq_len)
        attention_mask: (batch_size, past_seq_len + seq_len), all ones
        position_ids:   (batch_size, seq_len) when past_seq_len == 0, otherwise (batch_size, 1)
        past key/value: (batch_size, num_heads, past_seq_len, head_size) per layer. For ORT with
                        ``use_buffer_share`` (past-present buffer sharing, as used for models with
                        GQA) the sequence axis is re-allocated to ``max_seq_len``.

    For ``engine == "ort"`` tensors become NumPy arrays and the KV cache is flattened into
    per-layer entries; otherwise tensors are moved to ``device`` and the KV cache is a list
    of ``(key, value)`` tuples.

    Returns:
        A positional tuple when ``return_dict`` is False (export path; an assert requires a
        list KV cache, i.e. a non-"ort" engine), otherwise a dict of model inputs.
    """
    for_ort = engine == "ort"

    token_ids = torch.randint(low=0, high=config.vocab_size, size=(batch_size, seq_len), dtype=torch.int64)
    mask = torch.ones(batch_size, past_seq_len + seq_len, dtype=torch.int64)
    # Prompt processing (no past) keeps every position; token generation keeps only the last one
    positions = get_position_ids(mask, use_past_kv=(past_seq_len != 0))
    cache = get_past_kv_inputs(config, batch_size, past_seq_len, use_fp16, world_size=world_size)

    if for_ort:
        input_ids, attention_mask, position_ids = token_ids.numpy(), mask.numpy(), positions.numpy()
        past_kv = flatten_past_kv_inputs(cache)
    else:
        input_ids, attention_mask, position_ids = token_ids.to(device), mask.to(device), positions.to(device)
        past_kv = [(key.to(device), value.to(device)) for key, value in cache]

    if not return_dict:
        # For export
        assert isinstance(past_kv, list)
        return (input_ids, attention_mask, position_ids, past_kv)

    inputs = {
        "input_ids": input_ids,
        "attention_mask": attention_mask,
        "position_ids": position_ids,
    }
    if not for_ort:
        assert isinstance(past_kv, list)
        inputs["past_key_values"] = past_kv
        return inputs

    assert isinstance(past_kv, dict)
    inputs.update(past_kv)
    if use_buffer_share:
        inputs = enable_past_present_share_buffer(inputs, past_seq_len, max_seq_len)
    return inputs
def get_msft_sample_inputs(
    config: AutoConfig,
    batch_size: int,
    past_seq_len: int,
    seq_len: int,
    max_seq_len: int,
    use_fp16: bool,
    use_buffer_share: bool,
    split_kv: bool,
):
    """Build random ONNX Runtime inputs for the Microsoft export from https://github.com/microsoft/Llama-2-Onnx.

    ``split_kv`` False:
        ``k_cache`` / ``v_cache`` of shape
        (batch_size, num_hidden_layers, past_seq_len, num_attention_heads, head_size) and a float
        additive ``attn_mask`` of shape (batch_size, max_seq_len, max_seq_len) holding -10000
        above the diagonal.
    ``split_kv`` True:
        per-layer ``k_{i}_cache`` / ``v_{i}_cache`` of shape
        (batch_size, num_attention_heads, past_seq_len, head_size) and an int32 ``attn_mask``
        holding 0 above the diagonal and -1 elsewhere.

    Both variants include ``x`` (batch_size, seq_len, hidden_size) and a scalar int64 ``pos``
    equal to ``past_seq_len``. Float arrays are float16 when ``use_fp16`` else float32. When
    ``use_buffer_share`` is set, the cache entries are passed through
    ``enable_past_present_share_buffer``.
    """
    dtype = np.float16 if use_fp16 else np.float32
    n_heads = config.num_attention_heads
    n_layers = config.num_hidden_layers
    head_size = config.hidden_size // n_heads

    def _rand(*shape):
        return np.random.rand(*shape).astype(dtype)

    if split_kv:
        upper = np.triu(np.ones((batch_size, max_seq_len, max_seq_len), dtype=np.int32), k=1)
        ort_inputs = {
            "x": _rand(batch_size, seq_len, config.hidden_size),
            "attn_mask": (upper - 1).astype(np.int32),
            "pos": np.array(past_seq_len, dtype=np.int64),
        }
        for layer in range(n_layers):
            ort_inputs[f"k_{layer}_cache"] = _rand(batch_size, n_heads, past_seq_len, head_size)
            ort_inputs[f"v_{layer}_cache"] = _rand(batch_size, n_heads, past_seq_len, head_size)
    else:
        cache_shape = (batch_size, n_layers, past_seq_len, n_heads, head_size)
        causal = np.triu(np.ones((batch_size, max_seq_len, max_seq_len)), k=1)
        ort_inputs = {
            "x": _rand(batch_size, seq_len, config.hidden_size),
            "attn_mask": (-10000.0 * causal).astype(dtype),
            "k_cache": _rand(*cache_shape),
            "v_cache": _rand(*cache_shape),
            "pos": np.array(past_seq_len, dtype=np.int64),
        }

    if use_buffer_share:
        ort_inputs = enable_past_present_share_buffer(ort_inputs, past_seq_len, max_seq_len)
    return ort_inputs
def get_past_kv_inputs(config: AutoConfig, batch_size: int, past_seq_len: int, use_fp16: bool, world_size: int = 1):
    """Create random past key/value tensors for every decoder layer.

    Args:
        config: model config; reads ``num_key_value_heads``, ``num_hidden_layers`` and either
            ``head_dim`` or ``hidden_size // num_attention_heads``.
        batch_size: batch dimension.
        past_seq_len: length of the cached sequence axis (may be 0).
        use_fp16: create float16 tensors instead of float32.
        world_size: number of ranks the KV heads are split across; each rank gets
            ``num_key_value_heads // world_size`` heads.

    Returns:
        A list with one ``(past_key, past_value)`` tuple per layer, each tensor of shape
        (batch_size, num_heads, past_seq_len, head_size), filled by ``torch.rand``.
    """
    num_heads = config.num_key_value_heads // world_size
    # A config may define head_dim yet leave it None (or 0); treat that like a missing attribute
    # instead of passing None to torch.rand.
    head_size = getattr(config, "head_dim", None) or config.hidden_size // config.num_attention_heads
    torch_dtype = torch.float16 if use_fp16 else torch.float32
    shape = (batch_size, num_heads, past_seq_len, head_size)
    return [
        (torch.rand(*shape, dtype=torch_dtype), torch.rand(*shape, dtype=torch_dtype))
        for _ in range(config.num_hidden_layers)
    ]
def flatten_past_kv_inputs(past_key_values: list[tuple[torch.Tensor, torch.Tensor]]):
    """Flatten per-layer ``(key, value)`` pairs into a dict of NumPy arrays keyed by input name.

    A ``DynamicCache`` yields names ``past_key_values_key_cache_{i}`` /
    ``past_key_values_value_cache_{i}``; any other iterable of pairs yields
    ``past_key_values.{i}.key`` / ``past_key_values.{i}.value``. Each tensor is detached and
    copied to CPU before conversion.
    """
    # The container type does not change per layer, so check it once
    is_dynamic = isinstance(past_key_values, DynamicCache)
    flat = {}
    for layer, (key, value) in enumerate(past_key_values):
        key_np = key.detach().cpu().numpy()
        value_np = value.detach().cpu().numpy()
        if is_dynamic:
            flat[f"past_key_values_key_cache_{layer}"] = key_np
            flat[f"past_key_values_value_cache_{layer}"] = value_np
        else:
            flat[f"past_key_values.{layer}.key"] = key_np
            flat[f"past_key_values.{layer}.value"] = value_np
    return flat
def convert_inputs_for_ort(
    pt_inputs: dict,
    use_buffer_share: bool = False,
    past_seq_len: int = 0,
    max_seq_len: int = 2048,
):
    """Turn a PyTorch input dict into an ONNX Runtime input dict of NumPy arrays.

    NumPy values pass through unchanged. A non-NumPy ``"past_key_values"`` entry is expanded via
    ``flatten_past_kv_inputs``, and every other value is detached, copied to CPU and converted.
    With ``use_buffer_share`` the KV caches are re-allocated to ``max_seq_len`` via
    ``enable_past_present_share_buffer``.
    """
    ort_inputs = {}
    for name, value in pt_inputs.items():
        already_numpy = isinstance(value, np.ndarray)
        if name == "past_key_values" and not already_numpy:
            ort_inputs.update(flatten_past_kv_inputs(value))
            continue
        ort_inputs[name] = value if already_numpy else value.detach().cpu().numpy()

    if not use_buffer_share:
        return ort_inputs
    # Reshape KV caches for past-present buffer sharing
    return enable_past_present_share_buffer(ort_inputs, past_seq_len, max_seq_len)
def enable_past_present_share_buffer(ort_inputs: dict, past_seq_len: int, max_seq_len: int):
    """Re-allocate KV cache inputs so their sequence axis (axis 2) holds ``max_seq_len`` entries.

    Past-present buffer sharing binds each present KV output to the buffer of its past KV input,
    so that buffer must already be sized for the maximum sequence length. Every entry whose key
    contains ``"cache"`` or ``"past_key_values"`` is replaced by a zero-filled array of the same
    dtype whose axis 2 has length ``max_seq_len``. The original values are copied into the first
    ``past_seq_len`` slots of that axis.

    Axis 2 is the sequence axis both for the usual 4-D layout
    (batch_size, num_heads, past_sequence_length, head_size) and for the 5-D layout
    (batch_size, num_layers, past_sequence_length, num_heads, head_size) used by the
    non-split Microsoft export inputs. Previously only 4-D arrays were accepted, so 5-D inputs
    raised.

    Args:
        ort_inputs: dict of NumPy arrays; modified in place.
        past_seq_len: number of valid cached positions in each KV array.
        max_seq_len: new length of the sequence axis.

    Returns:
        ``ort_inputs`` (the same dict object).

    Raises:
        ValueError: if a KV cache array has fewer than 3 dimensions.
    """
    for name, value in ort_inputs.items():
        if "cache" not in name and "past_key_values" not in name:
            continue
        if value.ndim < 3:
            raise ValueError(f"KV cache input '{name}' must have at least 3 dims, got shape {value.shape}")
        # Same shape as the input except the sequence axis grows to max_seq_len
        new_shape = (*value.shape[:2], max_seq_len, *value.shape[3:])
        new_value = np.zeros(new_shape, dtype=value.dtype)
        new_value[:, :, :past_seq_len] = value
        # Re-assigning existing keys while iterating is safe: the dict's size does not change
        ort_inputs[name] = new_value
    return ort_inputs
def verify_ort_inputs(model: InferenceSession, ort_inputs: dict):
    """Check ``ort_inputs`` against the session's declared inputs and drop any extras.

    If a declared model input is absent, the missing names are printed and a generic
    ``Exception`` is raised. Keys the model does not declare are deleted from ``ort_inputs``
    in place, and the same dict is returned.
    """
    expected = {graph_input.name for graph_input in model.get_inputs()}
    provided = set(ort_inputs)

    missing = expected - provided
    if missing:
        print(f"The following model inputs are missing: {missing}")
        raise Exception("There are missing inputs to the model. Please add them and try again.")

    for extra in provided - expected:
        del ort_inputs[extra]
    return ort_inputs
def add_io_bindings_as_ortvalues(
    model: InferenceSession,
    ort_inputs: dict,
    device: str,
    device_id: int,
    use_buffer_share: bool,
    kv_cache_ortvalues: dict,
):
    """Create an IO binding that copies NumPy inputs into OrtValues on the target device.

    Suited to running inference once or twice, to save memory. Entries of ``ort_inputs`` that
    are not model inputs are skipped. For example, ``position_ids`` is still a PyTorch input but
    is removed from the ONNX graph for INT4 CUDA and FP16 CUDA models with GQA + RotaryEmbedding
    fusion.

    With ``use_buffer_share``, KV cache inputs (names containing ``"cache"`` or
    ``"past_key_values"``) are kept in ``kv_cache_ortvalues``. Each one is created on first use
    and updated in place afterwards. Every output whose name contains ``"out"`` or ``"present"``
    is bound to the OrtValue of the corresponding input name, so past and present share a buffer.
    All other outputs are bound to ``device``.

    Returns:
        ``(io_binding, kv_cache_ortvalues)``.
    """
    io_binding = model.io_binding()
    graph_inputs = {graph_input.name for graph_input in model.get_inputs()}

    def _is_kv_cache(name):
        return "cache" in name or "past_key_values" in name

    for name, array in ort_inputs.items():
        if name not in graph_inputs:
            continue
        if not (use_buffer_share and _is_kv_cache(name)):
            io_binding.bind_ortvalue_input(
                name, OrtValue.ortvalue_from_numpy(array, device_type=device, device_id=device_id)
            )
            continue
        # Reuse the device buffer across calls so the bound present output keeps pointing at it
        if name in kv_cache_ortvalues:
            kv_cache_ortvalues[name].update_inplace(array)
        else:
            kv_cache_ortvalues[name] = OrtValue.ortvalue_from_numpy(array, device_type=device, device_id=device_id)
        io_binding.bind_ortvalue_input(name, kv_cache_ortvalues[name])

    for graph_output in model.get_outputs():
        out_name = graph_output.name
        if use_buffer_share and ("out" in out_name or "present" in out_name):
            # Bind present KV cache outputs to past KV cache inputs in order to buffer share
            past_name = out_name.replace("out", "cache").replace("present", "past_key_values")
            io_binding.bind_ortvalue_output(out_name, kv_cache_ortvalues[past_name])
        else:
            io_binding.bind_output(out_name, device_type=device, device_id=device_id)
    return io_binding, kv_cache_ortvalues
def add_io_bindings_as_tensors(
    model: InferenceSession, inputs: dict, outputs: dict, use_fp16: bool, use_buffer_share: bool
):
    """Bind pre-allocated PyTorch tensors (by data pointer) as ONNX Runtime inputs and outputs.

    Suited to running inference many times. ``inputs`` is first filtered by
    ``verify_ort_inputs``, which raises on missing inputs and drops extras. Inputs are bound with
    their own device and dtype. Every output is bound with a float16/float32 element type
    (per ``use_fp16``) on the device of the last bound input. With ``use_buffer_share`` each
    ``present*`` output reuses the buffer of the matching ``past_key_values*`` input; all other
    outputs use the tensor of the same name in ``outputs``.

    Returns:
        The configured IO binding.
    """
    inputs = verify_ort_inputs(model, inputs)
    np_dtype_of = {
        "torch.int32": np.int32,
        "torch.int64": np.int64,
        "torch.float16": np.float16,
        "torch.float32": np.float32,
    }

    def _device_id(dev):
        # ORT expects id 0 for CPU; CUDA tensors carry their own index
        return 0 if dev.type == "cpu" else dev.index

    io_binding = model.io_binding()
    device = None
    for name, tensor in inputs.items():
        io_binding.bind_input(
            name=name,
            device_type=tensor.device.type,
            device_id=_device_id(tensor.device),
            element_type=np_dtype_of[repr(tensor.dtype)],
            shape=tuple(tensor.shape),
            buffer_ptr=tensor.data_ptr(),
        )
        device = tensor.device

    output_dtype = np.float16 if use_fp16 else np.float32
    for graph_output in model.get_outputs():
        name = graph_output.name
        if use_buffer_share and "present" in name:
            # Bind KV cache outputs to KV cache inputs
            buffer = inputs[name.replace("present", "past_key_values")]
        else:
            buffer = outputs[name]
        io_binding.bind_output(
            name=name,
            device_type=device.type,
            device_id=_device_id(device),
            element_type=output_dtype,
            shape=tuple(buffer.shape),
            buffer_ptr=buffer.data_ptr(),
        )
    return io_binding
def get_initial_inputs_and_outputs(
    config: AutoConfig,
    tokenizer: AutoTokenizer,
    requested_length: int,
    prompt: list[str],
    device: torch.device,
    use_fp16: bool,
    use_buffer_share: bool,
    engine: str,
):
    """Tokenize real prompts into first-pass inputs and, for ONNX Runtime, pre-allocate outputs.

    The tokenizer's pad token is set to its EOS token (this mutates ``tokenizer``). The prompts
    are then batch-encoded with padding. The encoded sequences are truncated, or extended on the
    left by repeating their first column, so that their length equals ``requested_length``.

    Zero-filled KV caches are created for every layer with shape
    (batch_size, num_key_value_heads, L, head_size), where L is ``config.max_position_embeddings``
    when ``use_buffer_share`` is set and 0 otherwise.

    Args:
        config: model config (reads num_key_value_heads, head_dim or hidden_size /
            num_attention_heads, num_hidden_layers, max_position_embeddings, vocab_size).
        tokenizer: tokenizer used to encode ``prompt``.
        requested_length: target prompt length in tokens.
        prompt: batch of prompt strings.
        device: device on which all tensors are created.
        use_fp16: create KV caches and outputs as float16 instead of float32.
        use_buffer_share: allocate KV caches at max length and skip separate ``present.*`` outputs.
        engine: ``"ort"`` produces contiguous tensors, flattened ``past_key_values.{i}.key/value``
            inputs and an outputs dict. Any other value puts a list of ``(key, value)`` tuples in
            ``inputs["past_key_values"]``.

    Returns:
        ``(inputs, outputs)``. ``outputs`` is None unless ``engine == "ort"``, in which case it holds
        ``logits`` (batch_size, sequence_length, vocab_size) and, without buffer sharing,
        ``present.{i}.key/value`` of shape (batch_size, num_heads, sequence_length, head_size).
    """
    tokenizer.pad_token = tokenizer.eos_token
    encodings_dict = tokenizer.batch_encode_plus(prompt, padding=True)
    torch_dtype = torch.float16 if use_fp16 else torch.float32

    # Padded slots hold the EOS id in input_ids and 0 in attention_mask;
    # get_position_ids assigns them position 1.
    input_ids = torch.tensor(encodings_dict["input_ids"], device=device, dtype=torch.int64)
    attention_mask = torch.tensor(encodings_dict["attention_mask"], device=device, dtype=torch.int64)
    position_ids = get_position_ids(attention_mask, use_past_kv=False)

    # Make the tokenized prompt length match the requested prompt length
    tokenized_length = input_ids.shape[-1]
    if tokenized_length > requested_length:
        # Shorten from (batch_size, tokenized_length) to (batch_size, requested_length)
        input_ids = input_ids[:, :requested_length]
        attention_mask = attention_mask[:, :requested_length]
        position_ids = get_position_ids(attention_mask, use_past_kv=False)
    elif tokenized_length < requested_length:
        # Lengthen from (batch_size, tokenized_length) to (batch_size, requested_length) by prepending
        # copies of the first column. Use one concatenation: prepending a single column per iteration
        # re-copies the whole tensor each time, which is quadratic in the number of added columns.
        num_extra = requested_length - tokenized_length
        input_ids = torch.hstack((input_ids[:, :1].repeat(1, num_extra), input_ids))
        attention_mask = torch.hstack((attention_mask[:, :1].repeat(1, num_extra), attention_mask))
        position_ids = get_position_ids(attention_mask, use_past_kv=False)
    tokenized_length = input_ids.shape[-1]
    assert tokenized_length == requested_length

    is_ort = engine == "ort"
    # ORT binds raw data pointers, so its tensors must be contiguous
    inputs = {
        "input_ids": input_ids.contiguous() if is_ort else input_ids,
        "attention_mask": attention_mask.contiguous() if is_ort else attention_mask,
        "position_ids": position_ids.contiguous() if is_ort else position_ids,
    }
    if not is_ort:
        inputs["past_key_values"] = []

    # Shape of KV cache inputs
    batch_size, sequence_length = input_ids.shape
    max_sequence_length = config.max_position_embeddings
    num_heads = config.num_key_value_heads
    # A config may define head_dim yet leave it None (or 0); fall back to hidden_size // num_attention_heads
    head_size = getattr(config, "head_dim", None) or config.hidden_size // config.num_attention_heads
    # Buffer sharing needs caches pre-sized to the maximum length; otherwise start empty
    kv_length = max_sequence_length if use_buffer_share else 0

    for i in range(config.num_hidden_layers):
        past_key = torch.zeros(batch_size, num_heads, kv_length, head_size, device=device, dtype=torch_dtype)
        past_value = torch.zeros(batch_size, num_heads, kv_length, head_size, device=device, dtype=torch_dtype)
        if is_ort:
            inputs.update(
                {
                    f"past_key_values.{i}.key": past_key.contiguous(),
                    f"past_key_values.{i}.value": past_value.contiguous(),
                }
            )
        else:
            inputs["past_key_values"].append((past_key, past_value))

    outputs = None
    if is_ort:
        # Pre-allocate outputs
        logits = torch.zeros(batch_size, sequence_length, config.vocab_size, device=device, dtype=torch_dtype)
        outputs = {"logits": logits.contiguous()}
        if not use_buffer_share:
            for i in range(config.num_hidden_layers):
                present_key = torch.zeros(
                    batch_size, num_heads, sequence_length, head_size, device=device, dtype=torch_dtype
                )
                present_value = torch.zeros(
                    batch_size, num_heads, sequence_length, head_size, device=device, dtype=torch_dtype
                )
                outputs.update(
                    {f"present.{i}.key": present_key.contiguous(), f"present.{i}.value": present_value.contiguous()}
                )
    return inputs, outputs
|