# llama_inputs.py
# -------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License. See License.txt in the project root for
# license information.
# --------------------------------------------------------------------------
  6. from __future__ import annotations
  7. import numpy as np
  8. import torch
  9. from transformers import AutoConfig, AutoTokenizer
  10. from transformers.cache_utils import DynamicCache
  11. from onnxruntime import InferenceSession, OrtValue
  12. # Get position_ids from attention_mask
  13. def get_position_ids(attention_mask: torch.Tensor, use_past_kv: bool):
  14. position_ids = attention_mask.long().cumsum(-1) - 1
  15. position_ids.masked_fill_(attention_mask == 0, 1)
  16. if use_past_kv:
  17. # Shape: (batch_size, 1)
  18. position_ids = position_ids[:, -1].unsqueeze(-1)
  19. # Shape: (batch_size, sequence_length)
  20. return position_ids
  21. # Inputs for first pass to get initial past_key_values
  22. # input_ids: (batch_size, sequence_length)
  23. # attention_mask: (batch_size, sequence_length)
  24. # position_ids: (batch_size, sequence_length)
  25. def get_sample_inputs(
  26. config: AutoConfig,
  27. device: torch.device,
  28. batch_size: int,
  29. seq_len: int,
  30. engine: str = "pt",
  31. return_dict: bool = False,
  32. ):
  33. input_ids = torch.randint(low=0, high=config.vocab_size, size=(batch_size, seq_len), dtype=torch.int64)
  34. attention_mask = torch.ones(batch_size, seq_len, dtype=torch.int64)
  35. position_ids = get_position_ids(attention_mask, use_past_kv=False)
  36. # Convert inputs to NumPy (for ORT) or send to device (for PyTorch)
  37. input_ids = input_ids.numpy() if engine == "ort" else input_ids.to(device)
  38. attention_mask = attention_mask.numpy() if engine == "ort" else attention_mask.to(device)
  39. position_ids = position_ids.numpy() if engine == "ort" else position_ids.to(device)
  40. if not return_dict:
  41. # For export
  42. return (input_ids, attention_mask, position_ids)
  43. inputs = {
  44. "input_ids": input_ids,
  45. "attention_mask": attention_mask,
  46. "position_ids": position_ids,
  47. }
  48. return inputs
  49. # Inputs for subsequent passes with past_key_values
  50. # input_ids: (batch_size, 1)
  51. # attention_mask: (batch_size, past_sequence_length + 1)
  52. # position_ids: (batch_size, 1)
  53. # past_key: (batch_size, num_heads, past_sequence_length, head_size)
  54. # past_value: (batch_size, num_heads, past_sequence_length, head_size)
  55. def get_sample_with_past_kv_inputs(
  56. config: AutoConfig,
  57. device: torch.device,
  58. batch_size: int,
  59. past_seq_len: int,
  60. use_fp16: bool = False,
  61. engine: str = "pt",
  62. return_dict: bool = False,
  63. world_size: int = 1,
  64. ):
  65. input_ids = torch.randint(low=0, high=config.vocab_size, size=(batch_size, 1), dtype=torch.int64)
  66. attention_mask = torch.ones(batch_size, past_seq_len + 1, dtype=torch.int64)
  67. # position_ids is of shape (batch_size, 1)
  68. position_ids = get_position_ids(attention_mask, use_past_kv=True)
  69. past_kv = get_past_kv_inputs(config, batch_size, past_seq_len, use_fp16, world_size=world_size)
  70. # Convert inputs to NumPy (for ORT) or send to device (for PyTorch)
  71. input_ids = input_ids.numpy() if engine == "ort" else input_ids.to(device)
  72. attention_mask = attention_mask.numpy() if engine == "ort" else attention_mask.to(device)
  73. position_ids = position_ids.numpy() if engine == "ort" else position_ids.to(device)
  74. past_kv = (
  75. flatten_past_kv_inputs(past_kv) if engine == "ort" else [(kv[0].to(device), kv[1].to(device)) for kv in past_kv]
  76. )
  77. if not return_dict:
  78. # For export
  79. assert isinstance(past_kv, list)
  80. return (input_ids, attention_mask, position_ids, past_kv)
  81. inputs = {
  82. "input_ids": input_ids,
  83. "attention_mask": attention_mask,
  84. "position_ids": position_ids,
  85. }
  86. if engine == "ort":
  87. assert isinstance(past_kv, dict)
  88. inputs.update(past_kv)
  89. else:
  90. assert isinstance(past_kv, list)
  91. inputs["past_key_values"] = past_kv
  92. return inputs
  93. # Inputs for all passes with past_key_values
  94. # input_ids: (batch_size, sequence_length)
  95. # attention_mask: (batch_size, past_sequence_length + sequence_length)
  96. # position_ids: (batch_size, sequence_length)
  97. # past_key: (batch_size, num_heads, kv_sequence_length, head_size)
  98. # For models with GQA, kv_sequence_length = max_sequence_length
  99. # For models without GQA, kv_sequence_length = past_sequence_length
  100. # past_value: (batch_size, num_heads, kv_sequence_length, head_size)
  101. # For models with GQA, kv_sequence_length = max_sequence_length
  102. # For models without GQA, kv_sequence_length = past_sequence_length
  103. def get_merged_sample_with_past_kv_inputs(
  104. config: AutoConfig,
  105. device: torch.device,
  106. batch_size: int,
  107. seq_len: int,
  108. past_seq_len: int,
  109. max_seq_len: int,
  110. use_fp16: bool = False,
  111. use_buffer_share: bool = False,
  112. engine: str = "pt",
  113. return_dict: bool = False,
  114. world_size: int = 1,
  115. ):
  116. input_ids = torch.randint(low=0, high=config.vocab_size, size=(batch_size, seq_len), dtype=torch.int64)
  117. attention_mask = torch.ones(batch_size, past_seq_len + seq_len, dtype=torch.int64)
  118. # position_ids is of shape (batch_size, seq_len) for prompt generation, (batch_size, 1) for token generation
  119. position_ids = get_position_ids(attention_mask, use_past_kv=(past_seq_len != 0))
  120. past_kv = get_past_kv_inputs(config, batch_size, past_seq_len, use_fp16, world_size=world_size)
  121. # Convert inputs to NumPy (for ORT) or send to device (for PyTorch)
  122. input_ids = input_ids.numpy() if engine == "ort" else input_ids.to(device)
  123. attention_mask = attention_mask.numpy() if engine == "ort" else attention_mask.to(device)
  124. position_ids = position_ids.numpy() if engine == "ort" else position_ids.to(device)
  125. past_kv = (
  126. flatten_past_kv_inputs(past_kv) if engine == "ort" else [(kv[0].to(device), kv[1].to(device)) for kv in past_kv]
  127. )
  128. if not return_dict:
  129. # For export
  130. assert isinstance(past_kv, list)
  131. return (input_ids, attention_mask, position_ids, past_kv)
  132. inputs = {
  133. "input_ids": input_ids,
  134. "attention_mask": attention_mask,
  135. "position_ids": position_ids,
  136. }
  137. if engine == "ort":
  138. assert isinstance(past_kv, dict)
  139. inputs.update(past_kv)
  140. if use_buffer_share:
  141. inputs = enable_past_present_share_buffer(inputs, past_seq_len, max_seq_len)
  142. else:
  143. assert isinstance(past_kv, list)
  144. inputs["past_key_values"] = past_kv
  145. return inputs
  146. # Inputs for Microsoft export from https://github.com/microsoft/Llama-2-Onnx
  147. def get_msft_sample_inputs(
  148. config: AutoConfig,
  149. batch_size: int,
  150. past_seq_len: int,
  151. seq_len: int,
  152. max_seq_len: int,
  153. use_fp16: bool,
  154. use_buffer_share: bool,
  155. split_kv: bool,
  156. ):
  157. np_dtype = np.float16 if use_fp16 else np.float32
  158. head_size = config.hidden_size // config.num_attention_heads
  159. if not split_kv:
  160. ort_inputs = {
  161. "x": np.random.rand(batch_size, seq_len, config.hidden_size).astype(np_dtype),
  162. "attn_mask": (-10000.0 * np.triu(np.ones((batch_size, max_seq_len, max_seq_len)), k=1)).astype(np_dtype),
  163. "k_cache": np.random.rand(
  164. batch_size, config.num_hidden_layers, past_seq_len, config.num_attention_heads, head_size
  165. ).astype(np_dtype),
  166. "v_cache": np.random.rand(
  167. batch_size, config.num_hidden_layers, past_seq_len, config.num_attention_heads, head_size
  168. ).astype(np_dtype),
  169. "pos": np.array(past_seq_len, dtype=np.int64),
  170. }
  171. else:
  172. ort_inputs = {
  173. "x": np.random.rand(batch_size, seq_len, config.hidden_size).astype(np_dtype),
  174. "attn_mask": (np.triu(np.ones((batch_size, max_seq_len, max_seq_len), dtype=np.int32), k=1) - 1).astype(
  175. np.int32
  176. ),
  177. "pos": np.array(past_seq_len, dtype=np.int64),
  178. }
  179. for i in range(config.num_hidden_layers):
  180. ort_inputs.update(
  181. {
  182. f"k_{i}_cache": np.random.rand(
  183. batch_size, config.num_attention_heads, past_seq_len, head_size
  184. ).astype(np_dtype),
  185. f"v_{i}_cache": np.random.rand(
  186. batch_size, config.num_attention_heads, past_seq_len, head_size
  187. ).astype(np_dtype),
  188. }
  189. )
  190. if use_buffer_share:
  191. ort_inputs = enable_past_present_share_buffer(ort_inputs, past_seq_len, max_seq_len)
  192. return ort_inputs
  193. # Create past_key_values
  194. # Each is of shape (batch_size, num_heads, past_sequence_length, head_size)
  195. def get_past_kv_inputs(config: AutoConfig, batch_size: int, past_seq_len: int, use_fp16: bool, world_size: int = 1):
  196. num_heads = config.num_key_value_heads // world_size
  197. head_size = config.head_dim if hasattr(config, "head_dim") else config.hidden_size // config.num_attention_heads
  198. torch_dtype = torch.float16 if use_fp16 else torch.float32
  199. past_kv = [
  200. (
  201. torch.rand(batch_size, num_heads, past_seq_len, head_size, dtype=torch_dtype),
  202. torch.rand(batch_size, num_heads, past_seq_len, head_size, dtype=torch_dtype),
  203. )
  204. for _ in range(config.num_hidden_layers)
  205. ]
  206. return past_kv
  207. # Convert list of past_key_values to dict of past_key and past_value
  208. def flatten_past_kv_inputs(past_key_values: list[tuple[torch.Tensor, torch.Tensor]]):
  209. past_kv = {}
  210. for i, (past_k, past_v) in enumerate(past_key_values):
  211. if isinstance(past_key_values, DynamicCache):
  212. past_kv[f"past_key_values_key_cache_{i}"] = past_k.detach().cpu().numpy()
  213. past_kv[f"past_key_values_value_cache_{i}"] = past_v.detach().cpu().numpy()
  214. else:
  215. past_kv[f"past_key_values.{i}.key"] = past_k.detach().cpu().numpy()
  216. past_kv[f"past_key_values.{i}.value"] = past_v.detach().cpu().numpy()
  217. return past_kv
  218. # Format PyTorch inputs to ONNX Runtime inputs
  219. def convert_inputs_for_ort(
  220. pt_inputs: dict,
  221. use_buffer_share: bool = False,
  222. past_seq_len: int = 0,
  223. max_seq_len: int = 2048,
  224. ):
  225. ort_inputs = {}
  226. for k, v in pt_inputs.items():
  227. if isinstance(v, np.ndarray):
  228. ort_inputs[k] = v
  229. elif k == "past_key_values":
  230. ort_inputs.update(flatten_past_kv_inputs(v))
  231. else:
  232. ort_inputs[k] = v.detach().cpu().numpy()
  233. # Reshape KV caches if using past-present-share-buffer
  234. if use_buffer_share:
  235. ort_inputs = enable_past_present_share_buffer(ort_inputs, past_seq_len, max_seq_len)
  236. return ort_inputs
  237. # Re-allocate KV caches from (batch_size, num_heads, past_sequence_length, head_size) to
  238. # (batch_size, num_heads, max_sequence_length, head_size) for past-present buffer sharing
  239. def enable_past_present_share_buffer(ort_inputs: dict, past_seq_len: int, max_seq_len: int):
  240. for k, v in ort_inputs.items():
  241. # Allocate new buffers with max_sequence_length for GQA
  242. if "cache" in k or "past_key_values" in k:
  243. # Copy v (BxSxPxH) into new_v (BxSxMxH)
  244. batch_size, num_heads, _, head_size = v.shape
  245. new_v = np.zeros((batch_size, num_heads, max_seq_len, head_size), dtype=v.dtype)
  246. new_v[:batch_size, :num_heads, :past_seq_len, :head_size] = v
  247. ort_inputs[k] = new_v
  248. return ort_inputs
  249. # Verify ONNX Runtime inputs with model
  250. def verify_ort_inputs(model: InferenceSession, ort_inputs: dict):
  251. # Check that all model inputs will be provided
  252. model_inputs = {model_input.name for model_input in model.get_inputs()}
  253. user_inputs = set(ort_inputs.keys())
  254. missing_inputs = model_inputs - user_inputs
  255. if len(missing_inputs):
  256. print(f"The following model inputs are missing: {missing_inputs}")
  257. raise Exception("There are missing inputs to the model. Please add them and try again.")
  258. # Remove unnecessary inputs from model inputs
  259. unnecessary_inputs = user_inputs - model_inputs
  260. if len(unnecessary_inputs):
  261. for unnecessary_input in unnecessary_inputs:
  262. del ort_inputs[unnecessary_input]
  263. return ort_inputs
  264. # Add IO bindings for execution providers using OrtValue
  265. # Use when you need to run inference once or twice to save memory
  266. def add_io_bindings_as_ortvalues(
  267. model: InferenceSession,
  268. ort_inputs: dict,
  269. device: str,
  270. device_id: int,
  271. use_buffer_share: bool,
  272. kv_cache_ortvalues: dict,
  273. ):
  274. io_binding = model.io_binding()
  275. model_inputs = {i.name for i in model.get_inputs()}
  276. for k, v in ort_inputs.items():
  277. # Use this check to handle scenarios such as INT4 CUDA and FP16 CUDA models with
  278. # GQA + RotaryEmbedding fusion where `position_ids` is removed as an ONNX model input
  279. # but `position_ids` is used as a PyTorch model input
  280. if k not in model_inputs:
  281. continue
  282. # Bind OrtValue inputs to device
  283. if use_buffer_share and ("cache" in k or "past_key_values" in k):
  284. if k not in kv_cache_ortvalues:
  285. v_device = OrtValue.ortvalue_from_numpy(v, device_type=device, device_id=device_id)
  286. io_binding.bind_ortvalue_input(k, v_device)
  287. kv_cache_ortvalues[k] = v_device
  288. else:
  289. kv_cache_ortvalues[k].update_inplace(v)
  290. io_binding.bind_ortvalue_input(k, kv_cache_ortvalues[k])
  291. else:
  292. v_device = OrtValue.ortvalue_from_numpy(v, device_type=device, device_id=device_id)
  293. io_binding.bind_ortvalue_input(k, v_device)
  294. for output in model.get_outputs():
  295. name = output.name
  296. if use_buffer_share and ("out" in name or "present" in name):
  297. # Bind present KV cache outputs to past KV cache inputs in order to buffer share
  298. input_name = name.replace("out", "cache").replace("present", "past_key_values")
  299. io_binding.bind_ortvalue_output(name, kv_cache_ortvalues[input_name])
  300. else:
  301. io_binding.bind_output(name, device_type=device, device_id=device_id)
  302. return io_binding, kv_cache_ortvalues
  303. # Add IO bindings for execution providers using PyTorch tensors
  304. # Use when you need to run inference many times
  305. def add_io_bindings_as_tensors(
  306. model: InferenceSession, inputs: dict, outputs: dict, use_fp16: bool, use_buffer_share: bool
  307. ):
  308. # Verify model inputs
  309. inputs = verify_ort_inputs(model, inputs)
  310. device = None
  311. pt_to_np = {
  312. "torch.int32": np.int32,
  313. "torch.int64": np.int64,
  314. "torch.float16": np.float16,
  315. "torch.float32": np.float32,
  316. }
  317. # Bind inputs/outputs to IO binding
  318. io_binding = model.io_binding()
  319. for k, v in inputs.items():
  320. io_binding.bind_input(
  321. name=k,
  322. device_type=v.device.type,
  323. device_id=0 if v.device.type == "cpu" else v.device.index,
  324. element_type=pt_to_np[repr(v.dtype)],
  325. shape=tuple(v.shape),
  326. buffer_ptr=v.data_ptr(),
  327. )
  328. device = v.device
  329. for output in model.get_outputs():
  330. name = output.name
  331. # Bind KV cache outputs to KV cache inputs
  332. v = (
  333. inputs[name.replace("present", "past_key_values")]
  334. if use_buffer_share and "present" in name
  335. else outputs[name]
  336. )
  337. io_binding.bind_output(
  338. name=name,
  339. device_type=device.type,
  340. device_id=0 if device.type == "cpu" else device.index,
  341. element_type=(np.float16 if use_fp16 else np.float32),
  342. shape=tuple(v.shape),
  343. buffer_ptr=v.data_ptr(),
  344. )
  345. return io_binding
  346. # Get actual inputs when using real data (instead of sample data) and initialize outputs
  347. def get_initial_inputs_and_outputs(
  348. config: AutoConfig,
  349. tokenizer: AutoTokenizer,
  350. requested_length: int,
  351. prompt: list[str],
  352. device: torch.device,
  353. use_fp16: bool,
  354. use_buffer_share: bool,
  355. engine: str,
  356. ):
  357. tokenizer.pad_token = tokenizer.eos_token
  358. encodings_dict = tokenizer.batch_encode_plus(prompt, padding=True)
  359. torch_dtype = torch.float16 if use_fp16 else torch.float32
  360. # input_ids: pad token id is 0
  361. # attention_mask: pad token id is 0
  362. # position_ids: pad token id is 1
  363. input_ids = torch.tensor(encodings_dict["input_ids"], device=device, dtype=torch.int64)
  364. attention_mask = torch.tensor(encodings_dict["attention_mask"], device=device, dtype=torch.int64)
  365. position_ids = get_position_ids(attention_mask, use_past_kv=False)
  366. # Check if tokenized prompt length matches the requested prompt length
  367. tokenized_length = input_ids.shape[-1]
  368. if tokenized_length > requested_length:
  369. # Shorten the inputs from (batch_size, tokenized_length) to (batch_size, requested_length)
  370. input_ids = input_ids[:, :requested_length]
  371. attention_mask = attention_mask[:, :requested_length]
  372. position_ids = get_position_ids(attention_mask, use_past_kv=False)
  373. elif tokenized_length < requested_length:
  374. # Lengthen the inputs from (batch_size, tokenized_length) to (batch_size, requested_length)
  375. input_ids_first_col = input_ids[:, 0].unsqueeze(0).T
  376. attention_mask_first_col = attention_mask[:, 0].unsqueeze(0).T
  377. for _ in range(requested_length - tokenized_length):
  378. input_ids = torch.hstack((input_ids_first_col, input_ids))
  379. attention_mask = torch.hstack((attention_mask_first_col, attention_mask))
  380. position_ids = get_position_ids(attention_mask, use_past_kv=False)
  381. tokenized_length = input_ids.shape[-1]
  382. assert tokenized_length == requested_length
  383. # Create inputs
  384. inputs = {
  385. "input_ids": input_ids.contiguous() if engine == "ort" else input_ids,
  386. "attention_mask": attention_mask.contiguous() if engine == "ort" else attention_mask,
  387. "position_ids": position_ids.contiguous() if engine == "ort" else position_ids,
  388. }
  389. if engine != "ort":
  390. inputs["past_key_values"] = []
  391. # Get shape of KV cache inputs
  392. batch_size, sequence_length = input_ids.shape
  393. max_sequence_length = config.max_position_embeddings
  394. num_heads = config.num_key_value_heads
  395. head_size = config.head_dim if hasattr(config, "head_dim") else config.hidden_size // config.num_attention_heads
  396. # Create KV cache inputs
  397. for i in range(config.num_hidden_layers):
  398. past_key = torch.zeros(
  399. batch_size,
  400. num_heads,
  401. max_sequence_length if use_buffer_share else 0,
  402. head_size,
  403. device=device,
  404. dtype=torch_dtype,
  405. )
  406. past_value = torch.zeros(
  407. batch_size,
  408. num_heads,
  409. max_sequence_length if use_buffer_share else 0,
  410. head_size,
  411. device=device,
  412. dtype=torch_dtype,
  413. )
  414. if engine == "ort":
  415. inputs.update(
  416. {
  417. f"past_key_values.{i}.key": past_key.contiguous(),
  418. f"past_key_values.{i}.value": past_value.contiguous(),
  419. }
  420. )
  421. else:
  422. inputs["past_key_values"].append((past_key, past_value))
  423. outputs = None
  424. if engine == "ort":
  425. # Create outputs
  426. logits = torch.zeros(batch_size, sequence_length, config.vocab_size, device=device, dtype=torch_dtype)
  427. outputs = {"logits": logits.contiguous()}
  428. if not use_buffer_share:
  429. for i in range(config.num_hidden_layers):
  430. present_key = torch.zeros(
  431. batch_size, num_heads, sequence_length, head_size, device=device, dtype=torch_dtype
  432. )
  433. present_value = torch.zeros(
  434. batch_size, num_heads, sequence_length, head_size, device=device, dtype=torch_dtype
  435. )
  436. outputs.update(
  437. {f"present.{i}.key": present_key.contiguous(), f"present.{i}.value": present_value.contiguous()}
  438. )
  439. return inputs, outputs