inference_example.py 17 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414
  1. # -------------------------------------------------------------------------
  2. # Copyright (c) Microsoft Corporation. All rights reserved.
  3. # Licensed under the MIT License.
  4. # --------------------------------------------------------------------------
  5. import time
  6. import numpy as np
  7. import torch
  8. from transformers import AutoTokenizer
  9. import onnxruntime as ort
  10. pt_to_np = {
  11. "torch.int32": np.int32,
  12. "torch.int64": np.int64,
  13. "torch.float32": np.float32,
  14. "torch.float16": np.float16,
  15. }
  16. def cuda_memcpy(dst, src):
  17. from cuda import cudart # noqa: PLC0415
  18. cudart.cudaMemcpy(
  19. dst.data_ptr(),
  20. src.data_ptr(),
  21. src.element_size() * src.nelement(),
  22. cudart.cudaMemcpyKind.cudaMemcpyDeviceToDevice,
  23. )
  24. class ORTGenerator:
  25. def __init__(self, decoder_path):
  26. self.onnx_decoder_path = decoder_path
  27. self.num_heads = 32
  28. self.head_size = 80
  29. self.num_layers = 32
  30. self.max_sequence_length = 2048
  31. self.device_id = 0
  32. self.use_cuda_graph = False
  33. self.use_traced_inputs = False
  34. self.static_inputs_map = {}
  35. def append_static_inputs(self, batch_size):
  36. # Only use this function with GQA and with use_cuda_graph=True
  37. if batch_size in self.static_inputs_map:
  38. return
  39. cpu_device = torch.device("cpu")
  40. cuda_device = torch.device("cuda", self.device_id)
  41. static_io = {}
  42. static_io["input_ids"] = torch.zeros((batch_size, 1), dtype=torch.int32, device=cuda_device)
  43. static_io["step"] = torch.tensor([0], dtype=torch.int64, device=cuda_device)
  44. static_io["seqlens_k"] = torch.tensor(batch_size * [0], dtype=torch.int32, device=cuda_device)
  45. static_io["total_sequence_length"] = torch.tensor([0], dtype=torch.int32, device=cpu_device)
  46. cache_shape = (batch_size, self.num_heads, self.max_sequence_length, self.head_size)
  47. for i in range(self.num_layers):
  48. cache = torch.zeros(cache_shape, device=cuda_device, dtype=torch.float16)
  49. static_io.update({f"past_key_{i}": cache.contiguous(), f"past_value_{i}": cache.clone().contiguous()})
  50. static_io["logits"] = torch.zeros((batch_size, 1, 51200), dtype=torch.float16, device=cuda_device)
  51. self.static_inputs_map[batch_size] = static_io
  52. def get_initial_inputs_and_outputs(self, encodings_dict):
  53. self.torch_dtype = torch.float16 if self.use_fp16 else torch.float32
  54. input_ids = torch.tensor(encodings_dict["input_ids"], device=self.device, dtype=torch.int32)
  55. attention_mask = torch.tensor(encodings_dict["attention_mask"], device=self.device, dtype=torch.int32)
  56. batch_size, sequence_length = input_ids.shape
  57. self.use_traced_inputs = (
  58. self.use_cuda_graph
  59. and (batch_size in self.static_inputs_map)
  60. and self.use_buffer_share
  61. and not self.packed_kv
  62. )
  63. step = (
  64. torch.tensor([0], device=self.device, dtype=torch.int64)
  65. if not self.use_traced_inputs
  66. else self.static_inputs_map[batch_size]["step"]
  67. )
  68. seqlens_k = (
  69. torch.tensor(batch_size * [0], device=self.device, dtype=torch.int32)
  70. if not self.use_traced_inputs
  71. else self.static_inputs_map[batch_size]["seqlens_k"]
  72. )
  73. cuda_memcpy(seqlens_k, attention_mask.sum(1).sub(1).to(torch.int32))
  74. total_seq_length = (
  75. torch.tensor([0], device=torch.device("cpu"), dtype=torch.int32)
  76. if not self.use_traced_inputs
  77. else self.static_inputs_map[batch_size]["total_sequence_length"]
  78. )
  79. total_seq_length[0] = sequence_length
  80. inputs = {
  81. "input_ids": input_ids.contiguous(),
  82. "attention_mask": attention_mask.contiguous(),
  83. }
  84. if self.use_step:
  85. inputs["step"] = step.contiguous()
  86. if self.use_cuda_graph:
  87. inputs["seqlens_k"] = seqlens_k.contiguous()
  88. inputs["total_sequence_length"] = total_seq_length.contiguous()
  89. del inputs["attention_mask"]
  90. past_seq_length = self.max_sequence_length if self.use_buffer_share else 0
  91. past_shape = (
  92. (2, batch_size, self.num_heads, past_seq_length, self.head_size)
  93. if self.packed_kv
  94. else (batch_size, self.num_heads, past_seq_length, self.head_size)
  95. )
  96. if not self.use_traced_inputs:
  97. for i in range(self.num_layers):
  98. past = torch.zeros(past_shape, device=self.device, dtype=self.torch_dtype)
  99. (
  100. inputs.update({f"past_key_{i}": past.contiguous(), f"past_value_{i}": past.clone().contiguous()})
  101. if not self.packed_kv
  102. else inputs.update({f"past_{i}": past.contiguous()})
  103. )
  104. else:
  105. for i in range(self.num_layers):
  106. inputs.update(
  107. {
  108. f"past_key_{i}": self.static_inputs_map[batch_size][f"past_key_{i}"].contiguous(),
  109. f"past_value_{i}": self.static_inputs_map[batch_size][f"past_value_{i}"].contiguous(),
  110. }
  111. )
  112. logits = torch.zeros(batch_size, sequence_length, 51200, device=self.device, dtype=self.torch_dtype)
  113. outputs = {"logits": logits.contiguous()}
  114. if not self.use_buffer_share:
  115. present_shape = (
  116. (2, batch_size, self.num_heads, sequence_length, self.head_size)
  117. if self.packed_kv
  118. else (batch_size, self.num_heads, sequence_length, self.head_size)
  119. )
  120. for i in range(self.num_layers):
  121. present = torch.zeros(present_shape, device=self.device, dtype=self.torch_dtype)
  122. (
  123. outputs.update(
  124. {f"present_key_{i}": present.contiguous(), f"present_value_{i}": present.contiguous()}
  125. )
  126. if not self.packed_kv
  127. else outputs.update({f"present_{i}": present.contiguous()})
  128. )
  129. return inputs, outputs
  130. def apply_io_binding(self, model: ort.InferenceSession, inputs: dict, outputs: dict):
  131. io_binding = model.io_binding()
  132. device = None
  133. for k, v in inputs.items():
  134. io_binding.bind_input(
  135. name=k,
  136. device_type=v.device.type,
  137. device_id=0 if v.device.type == "cpu" else v.device.index,
  138. element_type=pt_to_np[repr(v.dtype)],
  139. shape=tuple(v.shape),
  140. buffer_ptr=v.data_ptr(),
  141. )
  142. device = v.device
  143. for output in model.get_outputs():
  144. name = output.name
  145. if self.use_buffer_share and "present" in name:
  146. v = inputs[name.replace("present", "past")]
  147. io_binding.bind_output(
  148. name=name,
  149. device_type=v.device.type,
  150. device_id=v.device.index,
  151. element_type=(np.float16 if self.use_fp16 else np.float32),
  152. shape=tuple(v.shape),
  153. buffer_ptr=v.data_ptr(),
  154. )
  155. else:
  156. v = outputs[name]
  157. io_binding.bind_output(
  158. name=name,
  159. device_type=device.type,
  160. device_id=0 if device.type == "cpu" else device.index,
  161. element_type=(np.float16 if self.use_fp16 else np.float32),
  162. shape=tuple(v.shape),
  163. buffer_ptr=v.data_ptr(),
  164. )
  165. return io_binding
  166. def create_session(
  167. self, device_id, use_fp16=True, use_buffer_share=True, packed_kv=False, use_step=False, use_cuda_graph=False
  168. ):
  169. self.device_id = device_id
  170. sess_options = ort.SessionOptions()
  171. sess_options.log_verbosity_level = 4
  172. sess_options.log_severity_level = 4
  173. self.use_cuda_graph = use_cuda_graph
  174. ep = (
  175. ("CUDAExecutionProvider", {"device_id": self.device_id, "enable_cuda_graph": self.use_cuda_graph})
  176. if self.device_id >= 0
  177. else "CPUExecutionProvider"
  178. )
  179. self.sess = ort.InferenceSession(self.onnx_decoder_path, sess_options=sess_options, providers=[ep])
  180. self.ro = ort.RunOptions()
  181. self.device = torch.device("cuda", self.device_id) if torch.cuda.is_available() else torch.device("cpu")
  182. self.use_fp16 = use_fp16
  183. self.use_buffer_share = use_buffer_share
  184. self.packed_kv = packed_kv
  185. self.use_step = use_step
  186. self.tokenizer = AutoTokenizer.from_pretrained("microsoft/phi-2", trust_remote_code=True)
  187. self.tokenizer.pad_token = "[PAD]"
  188. def generate_impl(self, encodings_dict, max_length, cuda_graph_annotation, benchmark=False):
  189. inputs, outputs = self.get_initial_inputs_and_outputs(encodings_dict)
  190. all_token_ids = inputs["input_ids"].clone()
  191. batch_size, sequence_length = all_token_ids.shape
  192. current_length = sequence_length
  193. has_eos = torch.zeros(batch_size, device=self.device, dtype=torch.bool)
  194. if benchmark:
  195. latency = []
  196. prompt_run = True
  197. while current_length < max_length:
  198. io_binding = self.apply_io_binding(self.sess, inputs, outputs)
  199. if benchmark:
  200. start = time.time()
  201. io_binding.synchronize_inputs()
  202. if prompt_run:
  203. if self.use_cuda_graph:
  204. # Disable CUDA graph for the prompt run
  205. self.ro.add_run_config_entry("gpu_graph_id", "-1")
  206. self.sess.run_with_iobinding(io_binding, self.ro)
  207. if self.use_cuda_graph:
  208. # Enable CUDA graph for the decoding run
  209. self.ro.add_run_config_entry(
  210. "gpu_graph_id", str(cuda_graph_annotation) if self.use_traced_inputs else "-1"
  211. )
  212. prompt_run = False
  213. else:
  214. self.sess.run_with_iobinding(io_binding, self.ro)
  215. io_binding.synchronize_outputs()
  216. if benchmark:
  217. end = time.time()
  218. latency.append(end - start)
  219. # Sample with argmax (greedy search)
  220. next_token_logits = outputs["logits"][:, -1, :]
  221. next_tokens = torch.argmax(next_token_logits, dim=-1)
  222. # Check if we previously reached EOS token id or if generated token id is EOS token id
  223. has_eos = has_eos | next_tokens == self.tokenizer.eos_token_id
  224. # Determine which new tokens to add to list of all token ids
  225. # Add EOS token ids for batch entries that ended early (ragged batching scenario where some batch entries ended early and some haven't)
  226. tokens_to_add = next_tokens.masked_fill(has_eos, self.tokenizer.eos_token_id).reshape([batch_size, 1])
  227. all_token_ids = torch.cat([all_token_ids, tokens_to_add], dim=-1)
  228. # Return early if all batch entries have reached EOS token id
  229. if torch.all(has_eos):
  230. break
  231. # Update inputs for next inference run
  232. current_length += 1
  233. inputs["input_ids"] = tokens_to_add.to(torch.int32)
  234. if self.use_traced_inputs:
  235. cuda_memcpy(self.static_inputs_map[batch_size]["input_ids"], inputs["input_ids"])
  236. inputs["input_ids"] = self.static_inputs_map[batch_size]["input_ids"]
  237. if self.use_step:
  238. inputs["step"] = torch.tensor([current_length - 1], device=self.device, dtype=torch.int64)
  239. if self.use_traced_inputs:
  240. cuda_memcpy(self.static_inputs_map[batch_size]["step"], inputs["step"])
  241. inputs["step"] = self.static_inputs_map[batch_size]["step"]
  242. if self.use_cuda_graph:
  243. previous_seqlens_k = inputs["seqlens_k"]
  244. inputs["seqlens_k"] = (previous_seqlens_k + (~has_eos).reshape(batch_size, 1)).to(torch.int32)
  245. inputs["total_sequence_length"][0] = current_length
  246. if self.use_traced_inputs:
  247. cuda_memcpy(self.static_inputs_map[batch_size]["seqlens_k"], inputs["seqlens_k"])
  248. inputs["seqlens_k"] = self.static_inputs_map[batch_size]["seqlens_k"]
  249. self.static_inputs_map[batch_size]["total_sequence_length"][0] = inputs["total_sequence_length"][0]
  250. inputs["total_sequence_length"] = self.static_inputs_map[batch_size]["total_sequence_length"]
  251. else:
  252. inputs["attention_mask"] = torch.cat(
  253. [inputs["attention_mask"], (~has_eos).reshape(batch_size, 1)], 1
  254. ).to(torch.int32)
  255. # Set logits to zeros for next inference run and re-use memory buffer
  256. if outputs["logits"].shape[1] != 1:
  257. outputs["logits"] = outputs["logits"][:, :1, :].contiguous()
  258. if self.use_traced_inputs:
  259. outputs["logits"] = self.static_inputs_map[batch_size]["logits"]
  260. outputs["logits"].zero_()
  261. if not self.use_buffer_share:
  262. for i in range(self.num_layers):
  263. if not self.packed_kv:
  264. inputs[f"past_key_{i}"] = outputs[f"present_key_{i}"]
  265. inputs[f"past_value_{i}"] = outputs[f"present_value_{i}"]
  266. else:
  267. inputs[f"past_{i}"] = outputs[f"present_{i}"]
  268. new_sequence_length = inputs["attention_mask"].shape[1]
  269. present_shape = (
  270. (2, batch_size, self.num_heads, new_sequence_length, self.head_size)
  271. if self.packed_kv
  272. else (batch_size, self.num_heads, new_sequence_length, self.head_size)
  273. )
  274. for i in range(self.num_layers):
  275. present = torch.zeros(present_shape, device=self.device, dtype=self.torch_dtype)
  276. (
  277. outputs.update(
  278. {
  279. f"present_key_{i}": present.contiguous(),
  280. f"present_value_{i}": present.clone().contiguous(),
  281. }
  282. )
  283. if not self.packed_kv
  284. else outputs.update({f"present_{i}": present.contiguous()})
  285. )
  286. if benchmark:
  287. print(
  288. f"Batch size: {batch_size}, Sequence length: {sequence_length}, Token num: {max_length - sequence_length}"
  289. )
  290. print(f"Prompt letency: {1000 * latency[0]}ms, Token latency: {1000 * np.mean(latency[1:])}ms")
  291. return
  292. texts = self.tokenizer.batch_decode(all_token_ids, skip_special_tokens=True)
  293. return texts
  294. def generate(self, prompt, max_length, cuda_graph_annotation):
  295. encodings_dict = self.tokenizer.batch_encode_plus(prompt, padding=True)
  296. return self.generate_impl(encodings_dict, max_length, cuda_graph_annotation)
  297. def generate_benchmark(self, prompt_shape, token_num, cuda_graph_annotation):
  298. batch_size, sequence_length = prompt_shape
  299. max_length = sequence_length + token_num
  300. encodings_dict = {}
  301. encodings_dict["input_ids"] = torch.randint(0, 50264, (batch_size, sequence_length), dtype=torch.int32).tolist()
  302. encodings_dict["attention_mask"] = torch.ones((batch_size, sequence_length), dtype=torch.int32).tolist()
  303. # Warm up run
  304. self.generate_impl(encodings_dict, max_length, cuda_graph_annotation, benchmark=False)
  305. # Benchmark run
  306. self.generate_impl(encodings_dict, max_length, cuda_graph_annotation, benchmark=True)
  307. def run_phi2(
  308. onnx_model_path,
  309. use_buffer_share,
  310. device_id,
  311. packed_kv=False,
  312. use_fp16=True,
  313. use_step=False,
  314. use_cuda_graph=False,
  315. run_benchmark=False,
  316. ):
  317. generator = ORTGenerator(onnx_model_path)
  318. generator.create_session(device_id, use_fp16, use_buffer_share, packed_kv, use_step, use_cuda_graph)
  319. def simple_run(prompt):
  320. example_batch_size = len(prompt)
  321. if use_cuda_graph:
  322. generator.append_static_inputs(batch_size=example_batch_size)
  323. texts = generator.generate(prompt, max_length=210, cuda_graph_annotation=example_batch_size)
  324. for i in range(len(texts)):
  325. print("Prompt: ", prompt[i])
  326. print("Texts: ", texts[i])
  327. prompt = [
  328. '''```python
  329. def print_prime(n):
  330. """
  331. Print all primes between 1 and n
  332. """'''
  333. ]
  334. if not run_benchmark:
  335. simple_run(prompt)
  336. # Run simple benchmark. Time the decoder only.
  337. if run_benchmark:
  338. token_num = 32
  339. for batch_size in [1, 2, 4, 8]:
  340. generator.append_static_inputs(batch_size)
  341. for sequence_length in [16, 512]:
  342. prompt_shape = (batch_size, sequence_length)
  343. generator.generate_benchmark(prompt_shape, token_num, cuda_graph_annotation=batch_size)