| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414 |
- # -------------------------------------------------------------------------
- # Copyright (c) Microsoft Corporation. All rights reserved.
- # Licensed under the MIT License.
- # --------------------------------------------------------------------------
- import time
- import numpy as np
- import torch
- from transformers import AutoTokenizer
- import onnxruntime as ort
# Lookup table from ``repr(torch_dtype)`` strings (e.g. "torch.float16") to the
# numpy element types passed to ONNX Runtime IOBinding as ``element_type``.
pt_to_np = {
    repr(torch_dtype): numpy_dtype
    for torch_dtype, numpy_dtype in (
        (torch.int32, np.int32),
        (torch.int64, np.int64),
        (torch.float32, np.float32),
        (torch.float16, np.float16),
    )
}
- def cuda_memcpy(dst, src):
- from cuda import cudart # noqa: PLC0415
- cudart.cudaMemcpy(
- dst.data_ptr(),
- src.data_ptr(),
- src.element_size() * src.nelement(),
- cudart.cudaMemcpyKind.cudaMemcpyDeviceToDevice,
- )
- class ORTGenerator:
- def __init__(self, decoder_path):
- self.onnx_decoder_path = decoder_path
- self.num_heads = 32
- self.head_size = 80
- self.num_layers = 32
- self.max_sequence_length = 2048
- self.device_id = 0
- self.use_cuda_graph = False
- self.use_traced_inputs = False
- self.static_inputs_map = {}
- def append_static_inputs(self, batch_size):
- # Only use this function with GQA and with use_cuda_graph=True
- if batch_size in self.static_inputs_map:
- return
- cpu_device = torch.device("cpu")
- cuda_device = torch.device("cuda", self.device_id)
- static_io = {}
- static_io["input_ids"] = torch.zeros((batch_size, 1), dtype=torch.int32, device=cuda_device)
- static_io["step"] = torch.tensor([0], dtype=torch.int64, device=cuda_device)
- static_io["seqlens_k"] = torch.tensor(batch_size * [0], dtype=torch.int32, device=cuda_device)
- static_io["total_sequence_length"] = torch.tensor([0], dtype=torch.int32, device=cpu_device)
- cache_shape = (batch_size, self.num_heads, self.max_sequence_length, self.head_size)
- for i in range(self.num_layers):
- cache = torch.zeros(cache_shape, device=cuda_device, dtype=torch.float16)
- static_io.update({f"past_key_{i}": cache.contiguous(), f"past_value_{i}": cache.clone().contiguous()})
- static_io["logits"] = torch.zeros((batch_size, 1, 51200), dtype=torch.float16, device=cuda_device)
- self.static_inputs_map[batch_size] = static_io
- def get_initial_inputs_and_outputs(self, encodings_dict):
- self.torch_dtype = torch.float16 if self.use_fp16 else torch.float32
- input_ids = torch.tensor(encodings_dict["input_ids"], device=self.device, dtype=torch.int32)
- attention_mask = torch.tensor(encodings_dict["attention_mask"], device=self.device, dtype=torch.int32)
- batch_size, sequence_length = input_ids.shape
- self.use_traced_inputs = (
- self.use_cuda_graph
- and (batch_size in self.static_inputs_map)
- and self.use_buffer_share
- and not self.packed_kv
- )
- step = (
- torch.tensor([0], device=self.device, dtype=torch.int64)
- if not self.use_traced_inputs
- else self.static_inputs_map[batch_size]["step"]
- )
- seqlens_k = (
- torch.tensor(batch_size * [0], device=self.device, dtype=torch.int32)
- if not self.use_traced_inputs
- else self.static_inputs_map[batch_size]["seqlens_k"]
- )
- cuda_memcpy(seqlens_k, attention_mask.sum(1).sub(1).to(torch.int32))
- total_seq_length = (
- torch.tensor([0], device=torch.device("cpu"), dtype=torch.int32)
- if not self.use_traced_inputs
- else self.static_inputs_map[batch_size]["total_sequence_length"]
- )
- total_seq_length[0] = sequence_length
- inputs = {
- "input_ids": input_ids.contiguous(),
- "attention_mask": attention_mask.contiguous(),
- }
- if self.use_step:
- inputs["step"] = step.contiguous()
- if self.use_cuda_graph:
- inputs["seqlens_k"] = seqlens_k.contiguous()
- inputs["total_sequence_length"] = total_seq_length.contiguous()
- del inputs["attention_mask"]
- past_seq_length = self.max_sequence_length if self.use_buffer_share else 0
- past_shape = (
- (2, batch_size, self.num_heads, past_seq_length, self.head_size)
- if self.packed_kv
- else (batch_size, self.num_heads, past_seq_length, self.head_size)
- )
- if not self.use_traced_inputs:
- for i in range(self.num_layers):
- past = torch.zeros(past_shape, device=self.device, dtype=self.torch_dtype)
- (
- inputs.update({f"past_key_{i}": past.contiguous(), f"past_value_{i}": past.clone().contiguous()})
- if not self.packed_kv
- else inputs.update({f"past_{i}": past.contiguous()})
- )
- else:
- for i in range(self.num_layers):
- inputs.update(
- {
- f"past_key_{i}": self.static_inputs_map[batch_size][f"past_key_{i}"].contiguous(),
- f"past_value_{i}": self.static_inputs_map[batch_size][f"past_value_{i}"].contiguous(),
- }
- )
- logits = torch.zeros(batch_size, sequence_length, 51200, device=self.device, dtype=self.torch_dtype)
- outputs = {"logits": logits.contiguous()}
- if not self.use_buffer_share:
- present_shape = (
- (2, batch_size, self.num_heads, sequence_length, self.head_size)
- if self.packed_kv
- else (batch_size, self.num_heads, sequence_length, self.head_size)
- )
- for i in range(self.num_layers):
- present = torch.zeros(present_shape, device=self.device, dtype=self.torch_dtype)
- (
- outputs.update(
- {f"present_key_{i}": present.contiguous(), f"present_value_{i}": present.contiguous()}
- )
- if not self.packed_kv
- else outputs.update({f"present_{i}": present.contiguous()})
- )
- return inputs, outputs
- def apply_io_binding(self, model: ort.InferenceSession, inputs: dict, outputs: dict):
- io_binding = model.io_binding()
- device = None
- for k, v in inputs.items():
- io_binding.bind_input(
- name=k,
- device_type=v.device.type,
- device_id=0 if v.device.type == "cpu" else v.device.index,
- element_type=pt_to_np[repr(v.dtype)],
- shape=tuple(v.shape),
- buffer_ptr=v.data_ptr(),
- )
- device = v.device
- for output in model.get_outputs():
- name = output.name
- if self.use_buffer_share and "present" in name:
- v = inputs[name.replace("present", "past")]
- io_binding.bind_output(
- name=name,
- device_type=v.device.type,
- device_id=v.device.index,
- element_type=(np.float16 if self.use_fp16 else np.float32),
- shape=tuple(v.shape),
- buffer_ptr=v.data_ptr(),
- )
- else:
- v = outputs[name]
- io_binding.bind_output(
- name=name,
- device_type=device.type,
- device_id=0 if device.type == "cpu" else device.index,
- element_type=(np.float16 if self.use_fp16 else np.float32),
- shape=tuple(v.shape),
- buffer_ptr=v.data_ptr(),
- )
- return io_binding
- def create_session(
- self, device_id, use_fp16=True, use_buffer_share=True, packed_kv=False, use_step=False, use_cuda_graph=False
- ):
- self.device_id = device_id
- sess_options = ort.SessionOptions()
- sess_options.log_verbosity_level = 4
- sess_options.log_severity_level = 4
- self.use_cuda_graph = use_cuda_graph
- ep = (
- ("CUDAExecutionProvider", {"device_id": self.device_id, "enable_cuda_graph": self.use_cuda_graph})
- if self.device_id >= 0
- else "CPUExecutionProvider"
- )
- self.sess = ort.InferenceSession(self.onnx_decoder_path, sess_options=sess_options, providers=[ep])
- self.ro = ort.RunOptions()
- self.device = torch.device("cuda", self.device_id) if torch.cuda.is_available() else torch.device("cpu")
- self.use_fp16 = use_fp16
- self.use_buffer_share = use_buffer_share
- self.packed_kv = packed_kv
- self.use_step = use_step
- self.tokenizer = AutoTokenizer.from_pretrained("microsoft/phi-2", trust_remote_code=True)
- self.tokenizer.pad_token = "[PAD]"
- def generate_impl(self, encodings_dict, max_length, cuda_graph_annotation, benchmark=False):
- inputs, outputs = self.get_initial_inputs_and_outputs(encodings_dict)
- all_token_ids = inputs["input_ids"].clone()
- batch_size, sequence_length = all_token_ids.shape
- current_length = sequence_length
- has_eos = torch.zeros(batch_size, device=self.device, dtype=torch.bool)
- if benchmark:
- latency = []
- prompt_run = True
- while current_length < max_length:
- io_binding = self.apply_io_binding(self.sess, inputs, outputs)
- if benchmark:
- start = time.time()
- io_binding.synchronize_inputs()
- if prompt_run:
- if self.use_cuda_graph:
- # Disable CUDA graph for the prompt run
- self.ro.add_run_config_entry("gpu_graph_id", "-1")
- self.sess.run_with_iobinding(io_binding, self.ro)
- if self.use_cuda_graph:
- # Enable CUDA graph for the decoding run
- self.ro.add_run_config_entry(
- "gpu_graph_id", str(cuda_graph_annotation) if self.use_traced_inputs else "-1"
- )
- prompt_run = False
- else:
- self.sess.run_with_iobinding(io_binding, self.ro)
- io_binding.synchronize_outputs()
- if benchmark:
- end = time.time()
- latency.append(end - start)
- # Sample with argmax (greedy search)
- next_token_logits = outputs["logits"][:, -1, :]
- next_tokens = torch.argmax(next_token_logits, dim=-1)
- # Check if we previously reached EOS token id or if generated token id is EOS token id
- has_eos = has_eos | next_tokens == self.tokenizer.eos_token_id
- # Determine which new tokens to add to list of all token ids
- # Add EOS token ids for batch entries that ended early (ragged batching scenario where some batch entries ended early and some haven't)
- tokens_to_add = next_tokens.masked_fill(has_eos, self.tokenizer.eos_token_id).reshape([batch_size, 1])
- all_token_ids = torch.cat([all_token_ids, tokens_to_add], dim=-1)
- # Return early if all batch entries have reached EOS token id
- if torch.all(has_eos):
- break
- # Update inputs for next inference run
- current_length += 1
- inputs["input_ids"] = tokens_to_add.to(torch.int32)
- if self.use_traced_inputs:
- cuda_memcpy(self.static_inputs_map[batch_size]["input_ids"], inputs["input_ids"])
- inputs["input_ids"] = self.static_inputs_map[batch_size]["input_ids"]
- if self.use_step:
- inputs["step"] = torch.tensor([current_length - 1], device=self.device, dtype=torch.int64)
- if self.use_traced_inputs:
- cuda_memcpy(self.static_inputs_map[batch_size]["step"], inputs["step"])
- inputs["step"] = self.static_inputs_map[batch_size]["step"]
- if self.use_cuda_graph:
- previous_seqlens_k = inputs["seqlens_k"]
- inputs["seqlens_k"] = (previous_seqlens_k + (~has_eos).reshape(batch_size, 1)).to(torch.int32)
- inputs["total_sequence_length"][0] = current_length
- if self.use_traced_inputs:
- cuda_memcpy(self.static_inputs_map[batch_size]["seqlens_k"], inputs["seqlens_k"])
- inputs["seqlens_k"] = self.static_inputs_map[batch_size]["seqlens_k"]
- self.static_inputs_map[batch_size]["total_sequence_length"][0] = inputs["total_sequence_length"][0]
- inputs["total_sequence_length"] = self.static_inputs_map[batch_size]["total_sequence_length"]
- else:
- inputs["attention_mask"] = torch.cat(
- [inputs["attention_mask"], (~has_eos).reshape(batch_size, 1)], 1
- ).to(torch.int32)
- # Set logits to zeros for next inference run and re-use memory buffer
- if outputs["logits"].shape[1] != 1:
- outputs["logits"] = outputs["logits"][:, :1, :].contiguous()
- if self.use_traced_inputs:
- outputs["logits"] = self.static_inputs_map[batch_size]["logits"]
- outputs["logits"].zero_()
- if not self.use_buffer_share:
- for i in range(self.num_layers):
- if not self.packed_kv:
- inputs[f"past_key_{i}"] = outputs[f"present_key_{i}"]
- inputs[f"past_value_{i}"] = outputs[f"present_value_{i}"]
- else:
- inputs[f"past_{i}"] = outputs[f"present_{i}"]
- new_sequence_length = inputs["attention_mask"].shape[1]
- present_shape = (
- (2, batch_size, self.num_heads, new_sequence_length, self.head_size)
- if self.packed_kv
- else (batch_size, self.num_heads, new_sequence_length, self.head_size)
- )
- for i in range(self.num_layers):
- present = torch.zeros(present_shape, device=self.device, dtype=self.torch_dtype)
- (
- outputs.update(
- {
- f"present_key_{i}": present.contiguous(),
- f"present_value_{i}": present.clone().contiguous(),
- }
- )
- if not self.packed_kv
- else outputs.update({f"present_{i}": present.contiguous()})
- )
- if benchmark:
- print(
- f"Batch size: {batch_size}, Sequence length: {sequence_length}, Token num: {max_length - sequence_length}"
- )
- print(f"Prompt letency: {1000 * latency[0]}ms, Token latency: {1000 * np.mean(latency[1:])}ms")
- return
- texts = self.tokenizer.batch_decode(all_token_ids, skip_special_tokens=True)
- return texts
- def generate(self, prompt, max_length, cuda_graph_annotation):
- encodings_dict = self.tokenizer.batch_encode_plus(prompt, padding=True)
- return self.generate_impl(encodings_dict, max_length, cuda_graph_annotation)
- def generate_benchmark(self, prompt_shape, token_num, cuda_graph_annotation):
- batch_size, sequence_length = prompt_shape
- max_length = sequence_length + token_num
- encodings_dict = {}
- encodings_dict["input_ids"] = torch.randint(0, 50264, (batch_size, sequence_length), dtype=torch.int32).tolist()
- encodings_dict["attention_mask"] = torch.ones((batch_size, sequence_length), dtype=torch.int32).tolist()
- # Warm up run
- self.generate_impl(encodings_dict, max_length, cuda_graph_annotation, benchmark=False)
- # Benchmark run
- self.generate_impl(encodings_dict, max_length, cuda_graph_annotation, benchmark=True)
- def run_phi2(
- onnx_model_path,
- use_buffer_share,
- device_id,
- packed_kv=False,
- use_fp16=True,
- use_step=False,
- use_cuda_graph=False,
- run_benchmark=False,
- ):
- generator = ORTGenerator(onnx_model_path)
- generator.create_session(device_id, use_fp16, use_buffer_share, packed_kv, use_step, use_cuda_graph)
- def simple_run(prompt):
- example_batch_size = len(prompt)
- if use_cuda_graph:
- generator.append_static_inputs(batch_size=example_batch_size)
- texts = generator.generate(prompt, max_length=210, cuda_graph_annotation=example_batch_size)
- for i in range(len(texts)):
- print("Prompt: ", prompt[i])
- print("Texts: ", texts[i])
- prompt = [
- '''```python
- def print_prime(n):
- """
- Print all primes between 1 and n
- """'''
- ]
- if not run_benchmark:
- simple_run(prompt)
- # Run simple benchmark. Time the decoder only.
- if run_benchmark:
- token_num = 32
- for batch_size in [1, 2, 4, 8]:
- generator.append_static_inputs(batch_size)
- for sequence_length in [16, 512]:
- prompt_shape = (batch_size, sequence_length)
- generator.generate_benchmark(prompt_shape, token_num, cuda_graph_annotation=batch_size)
|