# ------------------------------------------------------------------------- # Copyright (c) Microsoft Corporation. All rights reserved. # Licensed under the MIT License. # -------------------------------------------------------------------------- import time import numpy as np import torch from transformers import AutoTokenizer import onnxruntime as ort pt_to_np = { "torch.int32": np.int32, "torch.int64": np.int64, "torch.float32": np.float32, "torch.float16": np.float16, } def cuda_memcpy(dst, src): from cuda import cudart # noqa: PLC0415 cudart.cudaMemcpy( dst.data_ptr(), src.data_ptr(), src.element_size() * src.nelement(), cudart.cudaMemcpyKind.cudaMemcpyDeviceToDevice, ) class ORTGenerator: def __init__(self, decoder_path): self.onnx_decoder_path = decoder_path self.num_heads = 32 self.head_size = 80 self.num_layers = 32 self.max_sequence_length = 2048 self.device_id = 0 self.use_cuda_graph = False self.use_traced_inputs = False self.static_inputs_map = {} def append_static_inputs(self, batch_size): # Only use this function with GQA and with use_cuda_graph=True if batch_size in self.static_inputs_map: return cpu_device = torch.device("cpu") cuda_device = torch.device("cuda", self.device_id) static_io = {} static_io["input_ids"] = torch.zeros((batch_size, 1), dtype=torch.int32, device=cuda_device) static_io["step"] = torch.tensor([0], dtype=torch.int64, device=cuda_device) static_io["seqlens_k"] = torch.tensor(batch_size * [0], dtype=torch.int32, device=cuda_device) static_io["total_sequence_length"] = torch.tensor([0], dtype=torch.int32, device=cpu_device) cache_shape = (batch_size, self.num_heads, self.max_sequence_length, self.head_size) for i in range(self.num_layers): cache = torch.zeros(cache_shape, device=cuda_device, dtype=torch.float16) static_io.update({f"past_key_{i}": cache.contiguous(), f"past_value_{i}": cache.clone().contiguous()}) static_io["logits"] = torch.zeros((batch_size, 1, 51200), dtype=torch.float16, device=cuda_device) self.static_inputs_map[batch_size] = static_io def get_initial_inputs_and_outputs(self, encodings_dict): self.torch_dtype = torch.float16 if self.use_fp16 else torch.float32 input_ids = torch.tensor(encodings_dict["input_ids"], device=self.device, dtype=torch.int32) attention_mask = torch.tensor(encodings_dict["attention_mask"], device=self.device, dtype=torch.int32) batch_size, sequence_length = input_ids.shape self.use_traced_inputs = ( self.use_cuda_graph and (batch_size in self.static_inputs_map) and self.use_buffer_share and not self.packed_kv ) step = ( torch.tensor([0], device=self.device, dtype=torch.int64) if not self.use_traced_inputs else self.static_inputs_map[batch_size]["step"] ) seqlens_k = ( torch.tensor(batch_size * [0], device=self.device, dtype=torch.int32) if not self.use_traced_inputs else self.static_inputs_map[batch_size]["seqlens_k"] ) cuda_memcpy(seqlens_k, attention_mask.sum(1).sub(1).to(torch.int32)) total_seq_length = ( torch.tensor([0], device=torch.device("cpu"), dtype=torch.int32) if not self.use_traced_inputs else self.static_inputs_map[batch_size]["total_sequence_length"] ) total_seq_length[0] = sequence_length inputs = { "input_ids": input_ids.contiguous(), "attention_mask": attention_mask.contiguous(), } if self.use_step: inputs["step"] = step.contiguous() if self.use_cuda_graph: inputs["seqlens_k"] = seqlens_k.contiguous() inputs["total_sequence_length"] = total_seq_length.contiguous() del inputs["attention_mask"] past_seq_length = self.max_sequence_length if self.use_buffer_share else 0 past_shape = ( (2, batch_size, self.num_heads, past_seq_length, self.head_size) if self.packed_kv else (batch_size, self.num_heads, past_seq_length, self.head_size) ) if not self.use_traced_inputs: for i in range(self.num_layers): past = torch.zeros(past_shape, device=self.device, dtype=self.torch_dtype) ( inputs.update({f"past_key_{i}": past.contiguous(), f"past_value_{i}": past.clone().contiguous()}) if not self.packed_kv else inputs.update({f"past_{i}": past.contiguous()}) ) else: for i in range(self.num_layers): inputs.update( { f"past_key_{i}": self.static_inputs_map[batch_size][f"past_key_{i}"].contiguous(), f"past_value_{i}": self.static_inputs_map[batch_size][f"past_value_{i}"].contiguous(), } ) logits = torch.zeros(batch_size, sequence_length, 51200, device=self.device, dtype=self.torch_dtype) outputs = {"logits": logits.contiguous()} if not self.use_buffer_share: present_shape = ( (2, batch_size, self.num_heads, sequence_length, self.head_size) if self.packed_kv else (batch_size, self.num_heads, sequence_length, self.head_size) ) for i in range(self.num_layers): present = torch.zeros(present_shape, device=self.device, dtype=self.torch_dtype) ( outputs.update( {f"present_key_{i}": present.contiguous(), f"present_value_{i}": present.contiguous()} ) if not self.packed_kv else outputs.update({f"present_{i}": present.contiguous()}) ) return inputs, outputs def apply_io_binding(self, model: ort.InferenceSession, inputs: dict, outputs: dict): io_binding = model.io_binding() device = None for k, v in inputs.items(): io_binding.bind_input( name=k, device_type=v.device.type, device_id=0 if v.device.type == "cpu" else v.device.index, element_type=pt_to_np[repr(v.dtype)], shape=tuple(v.shape), buffer_ptr=v.data_ptr(), ) device = v.device for output in model.get_outputs(): name = output.name if self.use_buffer_share and "present" in name: v = inputs[name.replace("present", "past")] io_binding.bind_output( name=name, device_type=v.device.type, device_id=v.device.index, element_type=(np.float16 if self.use_fp16 else np.float32), shape=tuple(v.shape), buffer_ptr=v.data_ptr(), ) else: v = outputs[name] io_binding.bind_output( name=name, device_type=device.type, device_id=0 if device.type == "cpu" else device.index, element_type=(np.float16 if self.use_fp16 else np.float32), shape=tuple(v.shape), buffer_ptr=v.data_ptr(), ) return io_binding def create_session( self, device_id, use_fp16=True, use_buffer_share=True, packed_kv=False, use_step=False, use_cuda_graph=False ): self.device_id = device_id sess_options = ort.SessionOptions() sess_options.log_verbosity_level = 4 sess_options.log_severity_level = 4 self.use_cuda_graph = use_cuda_graph ep = ( ("CUDAExecutionProvider", {"device_id": self.device_id, "enable_cuda_graph": self.use_cuda_graph}) if self.device_id >= 0 else "CPUExecutionProvider" ) self.sess = ort.InferenceSession(self.onnx_decoder_path, sess_options=sess_options, providers=[ep]) self.ro = ort.RunOptions() self.device = torch.device("cuda", self.device_id) if torch.cuda.is_available() else torch.device("cpu") self.use_fp16 = use_fp16 self.use_buffer_share = use_buffer_share self.packed_kv = packed_kv self.use_step = use_step self.tokenizer = AutoTokenizer.from_pretrained("microsoft/phi-2", trust_remote_code=True) self.tokenizer.pad_token = "[PAD]" def generate_impl(self, encodings_dict, max_length, cuda_graph_annotation, benchmark=False): inputs, outputs = self.get_initial_inputs_and_outputs(encodings_dict) all_token_ids = inputs["input_ids"].clone() batch_size, sequence_length = all_token_ids.shape current_length = sequence_length has_eos = torch.zeros(batch_size, device=self.device, dtype=torch.bool) if benchmark: latency = [] prompt_run = True while current_length < max_length: io_binding = self.apply_io_binding(self.sess, inputs, outputs) if benchmark: start = time.time() io_binding.synchronize_inputs() if prompt_run: if self.use_cuda_graph: # Disable CUDA graph for the prompt run self.ro.add_run_config_entry("gpu_graph_id", "-1") self.sess.run_with_iobinding(io_binding, self.ro) if self.use_cuda_graph: # Enable CUDA graph for the decoding run self.ro.add_run_config_entry( "gpu_graph_id", str(cuda_graph_annotation) if self.use_traced_inputs else "-1" ) prompt_run = False else: self.sess.run_with_iobinding(io_binding, self.ro) io_binding.synchronize_outputs() if benchmark: end = time.time() latency.append(end - start) # Sample with argmax (greedy search) next_token_logits = outputs["logits"][:, -1, :] next_tokens = torch.argmax(next_token_logits, dim=-1) # Check if we previously reached EOS token id or if generated token id is EOS token id has_eos = has_eos | next_tokens == self.tokenizer.eos_token_id # Determine which new tokens to add to list of all token ids # Add EOS token ids for batch entries that ended early (ragged batching scenario where some batch entries ended early and some haven't) tokens_to_add = next_tokens.masked_fill(has_eos, self.tokenizer.eos_token_id).reshape([batch_size, 1]) all_token_ids = torch.cat([all_token_ids, tokens_to_add], dim=-1) # Return early if all batch entries have reached EOS token id if torch.all(has_eos): break # Update inputs for next inference run current_length += 1 inputs["input_ids"] = tokens_to_add.to(torch.int32) if self.use_traced_inputs: cuda_memcpy(self.static_inputs_map[batch_size]["input_ids"], inputs["input_ids"]) inputs["input_ids"] = self.static_inputs_map[batch_size]["input_ids"] if self.use_step: inputs["step"] = torch.tensor([current_length - 1], device=self.device, dtype=torch.int64) if self.use_traced_inputs: cuda_memcpy(self.static_inputs_map[batch_size]["step"], inputs["step"]) inputs["step"] = self.static_inputs_map[batch_size]["step"] if self.use_cuda_graph: previous_seqlens_k = inputs["seqlens_k"] inputs["seqlens_k"] = (previous_seqlens_k + (~has_eos).reshape(batch_size, 1)).to(torch.int32) inputs["total_sequence_length"][0] = current_length if self.use_traced_inputs: cuda_memcpy(self.static_inputs_map[batch_size]["seqlens_k"], inputs["seqlens_k"]) inputs["seqlens_k"] = self.static_inputs_map[batch_size]["seqlens_k"] self.static_inputs_map[batch_size]["total_sequence_length"][0] = inputs["total_sequence_length"][0] inputs["total_sequence_length"] = self.static_inputs_map[batch_size]["total_sequence_length"] else: inputs["attention_mask"] = torch.cat( [inputs["attention_mask"], (~has_eos).reshape(batch_size, 1)], 1 ).to(torch.int32) # Set logits to zeros for next inference run and re-use memory buffer if outputs["logits"].shape[1] != 1: outputs["logits"] = outputs["logits"][:, :1, :].contiguous() if self.use_traced_inputs: outputs["logits"] = self.static_inputs_map[batch_size]["logits"] outputs["logits"].zero_() if not self.use_buffer_share: for i in range(self.num_layers): if not self.packed_kv: inputs[f"past_key_{i}"] = outputs[f"present_key_{i}"] inputs[f"past_value_{i}"] = outputs[f"present_value_{i}"] else: inputs[f"past_{i}"] = outputs[f"present_{i}"] new_sequence_length = inputs["attention_mask"].shape[1] present_shape = ( (2, batch_size, self.num_heads, new_sequence_length, self.head_size) if self.packed_kv else (batch_size, self.num_heads, new_sequence_length, self.head_size) ) for i in range(self.num_layers): present = torch.zeros(present_shape, device=self.device, dtype=self.torch_dtype) ( outputs.update( { f"present_key_{i}": present.contiguous(), f"present_value_{i}": present.clone().contiguous(), } ) if not self.packed_kv else outputs.update({f"present_{i}": present.contiguous()}) ) if benchmark: print( f"Batch size: {batch_size}, Sequence length: {sequence_length}, Token num: {max_length - sequence_length}" ) print(f"Prompt letency: {1000 * latency[0]}ms, Token latency: {1000 * np.mean(latency[1:])}ms") return texts = self.tokenizer.batch_decode(all_token_ids, skip_special_tokens=True) return texts def generate(self, prompt, max_length, cuda_graph_annotation): encodings_dict = self.tokenizer.batch_encode_plus(prompt, padding=True) return self.generate_impl(encodings_dict, max_length, cuda_graph_annotation) def generate_benchmark(self, prompt_shape, token_num, cuda_graph_annotation): batch_size, sequence_length = prompt_shape max_length = sequence_length + token_num encodings_dict = {} encodings_dict["input_ids"] = torch.randint(0, 50264, (batch_size, sequence_length), dtype=torch.int32).tolist() encodings_dict["attention_mask"] = torch.ones((batch_size, sequence_length), dtype=torch.int32).tolist() # Warm up run self.generate_impl(encodings_dict, max_length, cuda_graph_annotation, benchmark=False) # Benchmark run self.generate_impl(encodings_dict, max_length, cuda_graph_annotation, benchmark=True) def run_phi2( onnx_model_path, use_buffer_share, device_id, packed_kv=False, use_fp16=True, use_step=False, use_cuda_graph=False, run_benchmark=False, ): generator = ORTGenerator(onnx_model_path) generator.create_session(device_id, use_fp16, use_buffer_share, packed_kv, use_step, use_cuda_graph) def simple_run(prompt): example_batch_size = len(prompt) if use_cuda_graph: generator.append_static_inputs(batch_size=example_batch_size) texts = generator.generate(prompt, max_length=210, cuda_graph_annotation=example_batch_size) for i in range(len(texts)): print("Prompt: ", prompt[i]) print("Texts: ", texts[i]) prompt = [ '''```python def print_prime(n): """ Print all primes between 1 and n """''' ] if not run_benchmark: simple_run(prompt) # Run simple benchmark. Time the decoder only. if run_benchmark: token_num = 32 for batch_size in [1, 2, 4, 8]: generator.append_static_inputs(batch_size) for sequence_length in [16, 512]: prompt_shape = (batch_size, sequence_length) generator.generate_benchmark(prompt_shape, token_num, cuda_graph_annotation=batch_size)