gpt2_tester.py 19 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501
  1. # -------------------------------------------------------------------------
  2. # Copyright (c) Microsoft Corporation. All rights reserved.
  3. # Licensed under the MIT License. See License.txt in the project root for
  4. # license information.
  5. # --------------------------------------------------------------------------
  6. # This script helps evaluation of GPT-2 model.
  7. import logging
  8. import math
  9. import os
  10. import statistics
  11. import timeit
  12. import numpy
  13. import torch
  14. from benchmark_helper import Precision
  15. from gpt2_helper import Gpt2Helper, Gpt2Inputs
  16. logger = logging.getLogger(__name__)
  17. class Gpt2Metric:
  18. def __init__(self, treatment_name, baseline_name="Torch", top_k=20):
  19. assert top_k > 1 and top_k <= 100
  20. self.baseline = baseline_name
  21. self.treatment = treatment_name
  22. self.name: str = f"{treatment_name} vs {baseline_name}"
  23. self.top_k = top_k
  24. self.top_1_error: int = 0
  25. self.top_k_error: int = 0
  26. self.total_samples: int = 0
  27. self.max_logits_diff: float = 0 # for non-empty past state
  28. self.max_logits_diff_no_past: float = 0 # for empty past state
  29. self.batch_top1_error: torch.FloatTensor = None # top 1 error for current batch
  30. self.batch_topk_error: torch.FloatTensor = None # top k error for current batch
  31. self.seq_len_latency = {}
  32. def print(self):
  33. if self.baseline != self.treatment:
  34. print("---")
  35. print(f"Metrics for {self.treatment} (baseline={self.baseline}):")
  36. if self.total_samples > 0:
  37. top_1_error_rate = 100.0 * self.top_1_error / self.total_samples
  38. top_k_error_rate = 100.0 * self.top_k_error / self.total_samples
  39. print(
  40. f"Total={self.total_samples} Top1Error={self.top_1_error} ({top_1_error_rate:.2f}%) Top{self.top_k}Error={self.top_k_error} ({top_k_error_rate:.2f}%)"
  41. )
  42. print("Max logits diffs:")
  43. print(f"\twith past = {self.max_logits_diff:.6f}")
  44. print(f"\tempty past = {self.max_logits_diff_no_past:.6f}")
  45. else:
  46. print(f"Metrics for {self.treatment} (baseline):")
  47. if self.seq_len_latency:
  48. print("Past sequence length range and average latency:")
  49. total = 0
  50. count = 0
  51. for key in sorted(self.seq_len_latency.keys()):
  52. average = statistics.mean(self.seq_len_latency[key]) * 1000.0
  53. if key == 0:
  54. print(f"\t{key}: \t{average:.2f} ms")
  55. else:
  56. print(f"\t[{2**key}, {2 ** (key + 1) - 1}]:\t{average:.2f} ms")
  57. total += average * len(self.seq_len_latency[key])
  58. count += len(self.seq_len_latency[key])
  59. print(f"Average Latency: {total / count:.2f} ms")
  60. def diff_logits(self, baseline_logits, treatment_logits, is_empty_past: bool):
  61. diff = (baseline_logits - treatment_logits).abs().max()
  62. if is_empty_past:
  63. self.max_logits_diff_no_past = max(self.max_logits_diff_no_past, diff)
  64. else:
  65. self.max_logits_diff = max(self.max_logits_diff, diff)
  66. return diff
  67. def start_batch(self, batch_size: int):
  68. self.total_samples += batch_size
  69. self.batch_top1_error = torch.zeros((batch_size, 1), dtype=torch.bool)
  70. self.batch_topk_error = torch.zeros((batch_size, 1), dtype=torch.bool)
  71. def eval_batch(self, baseline, treatment, past_seq_len, verbose=True):
  72. self._eval_topk(baseline.top_1_tokens, treatment.top_1_tokens, 1, verbose)
  73. self._eval_topk(baseline.top_k_tokens, treatment.top_k_tokens, self.top_k, verbose)
  74. max_diff = self.diff_logits(baseline.logits, treatment.logits, past_seq_len == 0)
  75. if verbose:
  76. print(f"Max logits diffs of {self.name}: {max_diff}")
  77. def _eval_topk(self, baseline_topk, treatment_topk, top_k, verbose=True):
  78. if not torch.all(torch.eq(baseline_topk, treatment_topk)):
  79. if top_k == 1:
  80. if verbose:
  81. print(f"Generated tokens not matched for {self.name}")
  82. self.batch_top1_error |= torch.eq(baseline_topk, treatment_topk).logical_not()
  83. else:
  84. if verbose:
  85. print(
  86. f"Top {top_k} tokens not matched for {self.name}. This will lead to wrong beam search results"
  87. )
  88. self.batch_topk_error |= (
  89. torch.eq(baseline_topk, treatment_topk).logical_not().sum(1).unsqueeze(dim=1) > 0
  90. )
  91. def end_batch(self):
  92. self.top_1_error += self.batch_top1_error.sum()
  93. self.top_k_error += self.batch_topk_error.sum()
  94. def add_latency(self, past_seq_len, latency):
  95. key = int(math.log2(past_seq_len)) + 1 if past_seq_len > 0 else 0
  96. if key not in self.seq_len_latency:
  97. self.seq_len_latency[key] = []
  98. self.seq_len_latency[key].append(latency)
  99. class Gpt2Tester:
  100. def __init__(
  101. self,
  102. input_ids,
  103. position_ids,
  104. attention_mask,
  105. num_attention_heads,
  106. hidden_size,
  107. num_layer,
  108. device,
  109. is_fp16=False,
  110. top_k=20,
  111. top_k_required_order=False,
  112. ):
  113. self.batch_size = input_ids.shape[0]
  114. self.input_length = input_ids.shape[1]
  115. self.n_layer = num_layer
  116. self.input_ids = input_ids
  117. self.position_ids = position_ids
  118. self.attention_mask = attention_mask
  119. self.has_position_ids = position_ids is not None
  120. self.has_attention_mask = attention_mask is not None
  121. # Empty past state for first inference
  122. self.past = []
  123. past_shape = [
  124. 2,
  125. self.batch_size,
  126. num_attention_heads,
  127. 0,
  128. hidden_size // num_attention_heads,
  129. ]
  130. for _i in range(num_layer):
  131. empty_past = torch.empty(past_shape).type(torch.float16 if is_fp16 else torch.float32)
  132. self.past.append(empty_past.to(device))
  133. self.logits = None
  134. self.top_1_tokens = None
  135. self.top_k_tokens = None
  136. self.top_k = top_k
  137. self.top_k_required_order = top_k_required_order
  138. def get_inputs(self) -> Gpt2Inputs:
  139. return Gpt2Inputs(self.input_ids, self.position_ids, self.attention_mask, self.past)
  140. def save_test_data(self, session, output, save_test_data_dir, test_case_id):
  141. from onnx import numpy_helper # noqa: PLC0415
  142. path = os.path.join(save_test_data_dir, "test_data_set_" + str(test_case_id))
  143. if os.path.exists(path):
  144. print(f"Directory {path} existed. Skip saving test data")
  145. return
  146. os.makedirs(path, exist_ok=True)
  147. def add_tensor(input_tensors, torch_tensor, name):
  148. input_tensors.append(numpy_helper.from_array(torch_tensor.clone().cpu().numpy(), name))
  149. input_tensors = []
  150. add_tensor(input_tensors, self.input_ids, "input_ids")
  151. if self.has_position_ids:
  152. add_tensor(input_tensors, self.position_ids, "position_ids")
  153. if self.has_attention_mask:
  154. add_tensor(input_tensors, self.attention_mask, "attention_mask")
  155. for i in range(self.n_layer):
  156. add_tensor(input_tensors, self.past[i], "past_" + str(i))
  157. for i, tensor in enumerate(input_tensors):
  158. with open(os.path.join(path, f"input_{i}.pb"), "wb") as f:
  159. f.write(tensor.SerializeToString())
  160. output_names = [output.name for output in session.get_outputs()]
  161. for i, _name in enumerate(output_names):
  162. tensor = numpy_helper.from_array(
  163. output[i] if isinstance(output[i], numpy.ndarray) else output[i].clone().cpu().numpy()
  164. )
  165. with open(os.path.join(path, f"output_{i}.pb"), "wb") as f:
  166. f.write(tensor.SerializeToString())
  167. print(f"Test data saved to directory {path}")
  168. def update(self, output, step, device):
  169. """
  170. Update the inputs for next inference.
  171. """
  172. self.logits = (
  173. torch.from_numpy(output[0]) if isinstance(output[0], numpy.ndarray) else output[0].clone().detach().cpu()
  174. )
  175. self.top_1_tokens = Gpt2Tester.predict_next_token(self.logits)
  176. self.top_k_tokens = Gpt2Tester.predict_next_token(self.logits, self.top_k, self.top_k_required_order)
  177. self.input_ids = self.top_1_tokens.clone().detach().reshape([self.batch_size, 1]).to(device)
  178. if self.has_position_ids:
  179. self.position_ids = (
  180. torch.tensor([self.input_length + step - 1]).unsqueeze(0).repeat(self.batch_size, 1).to(device)
  181. )
  182. if self.has_attention_mask:
  183. self.attention_mask = torch.cat(
  184. [
  185. self.attention_mask,
  186. torch.ones([self.batch_size, 1]).type_as(self.attention_mask),
  187. ],
  188. 1,
  189. ).to(device)
  190. self.past = []
  191. if isinstance(output[1], tuple): # past in torch output is tuple
  192. self.past = list(output[1])
  193. else:
  194. for i in range(self.n_layer):
  195. past_i = (
  196. torch.from_numpy(output[i + 1])
  197. if isinstance(output[i + 1], numpy.ndarray)
  198. else output[i + 1].clone().detach()
  199. )
  200. self.past.append(past_i.to(device))
  201. def diff(self, baseline):
  202. """
  203. Compare inputs and logits output.
  204. """
  205. print("start diff...")
  206. if self.logits is not None:
  207. max_io_diff = (self.logits - baseline.logits).abs().max()
  208. if max_io_diff > 1e-4:
  209. print(f"Max logits difference is too large: {max_io_diff}")
  210. if not torch.all(self.input_ids == baseline.input_ids):
  211. print("Input_ids is different", self.input_ids, baseline.input_ids)
  212. if self.has_position_ids:
  213. if not torch.all(self.position_ids == baseline.position_ids):
  214. print(
  215. "position_ids is different",
  216. self.position_ids,
  217. baseline.position_ids,
  218. )
  219. if self.has_attention_mask:
  220. if not torch.all(self.attention_mask == baseline.attention_mask):
  221. print(
  222. "attention_mask is different",
  223. self.attention_mask,
  224. baseline.attention_mask,
  225. )
  226. assert len(self.past) == len(baseline.past)
  227. for i, past_i in enumerate(self.past):
  228. assert past_i.shape == baseline.past[i].shape
  229. if past_i.nelement() > 0:
  230. max_past_diff = (past_i - baseline.past[i]).abs().max()
  231. if max_past_diff > 1e-4:
  232. print(f"max_past_diff[{i}]={max_past_diff}")
  233. @staticmethod
  234. def predict_next_token(logits, top_k=1, required_order=False):
  235. """
  236. Get top k topkens based on logits.
  237. """
  238. # logits has shape (batch_size, seq_len, vocab_size)
  239. # last token logits has shape (batch_size, vocab_size)
  240. lastTokenLogits = logits[:, -1] # noqa: N806
  241. if top_k == 1:
  242. generatedTokens = torch.argmax(lastTokenLogits, 1, True) # noqa: N806
  243. return generatedTokens
  244. else:
  245. topk = torch.argsort(lastTokenLogits, -1, descending=True)[:, :top_k]
  246. if not required_order:
  247. sorted_topk, _ = topk.sort()
  248. return sorted_topk
  249. return topk
  250. @staticmethod
  251. def diff_present(onnx_output, onnx_io_output, n_layer):
  252. """
  253. Compare the present outputs of two outputs from ONNX Runtime.
  254. """
  255. present_diff_max = []
  256. for i in range(n_layer):
  257. onnx_present_i = (
  258. torch.from_numpy(onnx_output[i + 1])
  259. if isinstance(onnx_output[i + 1], numpy.ndarray)
  260. else onnx_output[i + 1]
  261. )
  262. onnx_io_present_i = (
  263. torch.from_numpy(onnx_io_output[i + 1])
  264. if isinstance(onnx_io_output[i + 1], numpy.ndarray)
  265. else onnx_io_output[i + 1]
  266. )
  267. max_diff = (onnx_present_i - onnx_io_present_i).abs().max()
  268. present_diff_max.append(max_diff)
  269. print(f"present_diff_max={present_diff_max}")
  270. @staticmethod
  271. def is_quantized_onnx_model(onnx_model_path):
  272. """
  273. Returns True if the ONNX model is quantized.
  274. """
  275. from onnx import load # noqa: PLC0415
  276. model = load(onnx_model_path)
  277. from onnxruntime.quantization.quantize import __producer__ as quantize_producer # noqa: PLC0415
  278. return model.producer_name == quantize_producer
  279. @staticmethod
  280. def test_generation(
  281. session,
  282. model,
  283. device,
  284. test_inputs,
  285. precision=Precision.FLOAT32,
  286. model_class="Gpt2LMHeadModel",
  287. top_k=20,
  288. top_k_no_order=True,
  289. max_steps=24,
  290. max_inputs=0,
  291. verbose=False,
  292. save_test_data=0,
  293. save_test_data_dir=".",
  294. ):
  295. """
  296. Test Generation using greedy beam search (without sampling) to compare PyTorch and ONNX model.
  297. It will print top 1 and top k errors on the given test inputs.
  298. """
  299. print(
  300. f"start test generation: (top_k={top_k} top_k_no_order={top_k_no_order} max_steps={max_steps} test_inputs={len(test_inputs)} max_inputs={max_inputs})"
  301. )
  302. n_layer = model.config.n_layer
  303. n_head = model.config.n_head
  304. n_embd = model.config.n_embd
  305. eos_token_id = model.config.eos_token_id
  306. test_data_saved = 0
  307. is_float16 = precision == Precision.FLOAT16
  308. if is_float16:
  309. assert "float16" in session.get_outputs()[0].type
  310. # We will still use fp32 torch model as baseline when onnx model if fp16
  311. model.eval().to(device)
  312. # Allocate initial buffers for IO Binding of ONNX Runtimne. The buffer size will automatically increase later.
  313. init_output_shapes = Gpt2Helper.get_output_shapes(
  314. batch_size=4,
  315. past_sequence_length=128,
  316. sequence_length=32,
  317. config=model.config,
  318. model_class=model_class,
  319. )
  320. output_buffers = Gpt2Helper.get_output_buffers(init_output_shapes, device, is_float16=is_float16)
  321. baseline_name = "Torch"
  322. treatment_name = "Quantized Onnx" if precision == Precision.INT8 else "Onnx"
  323. torch_metric = Gpt2Metric(baseline_name, baseline_name, top_k)
  324. onnx_metric = Gpt2Metric(treatment_name, baseline_name, top_k)
  325. onnx_io_metric = Gpt2Metric(treatment_name + " with IO Binding", baseline_name, top_k)
  326. for i, inputs in enumerate(test_inputs):
  327. if max_inputs > 0 and i == max_inputs:
  328. break
  329. if i % 10 == 0:
  330. print(f"{i}")
  331. input_ids = inputs["input_ids"]
  332. position_ids = inputs.get("position_ids", None)
  333. attention_mask = inputs.get("attention_mask", None)
  334. onnx_runner = Gpt2Tester(
  335. input_ids,
  336. position_ids,
  337. attention_mask,
  338. n_head,
  339. n_embd,
  340. n_layer,
  341. device,
  342. is_float16,
  343. top_k,
  344. not top_k_no_order,
  345. )
  346. onnx_io_runner = Gpt2Tester(
  347. input_ids,
  348. position_ids,
  349. attention_mask,
  350. n_head,
  351. n_embd,
  352. n_layer,
  353. device,
  354. is_float16,
  355. top_k,
  356. not top_k_no_order,
  357. )
  358. torch_runner = Gpt2Tester(
  359. input_ids,
  360. position_ids,
  361. attention_mask,
  362. n_head,
  363. n_embd,
  364. n_layer,
  365. device,
  366. False,
  367. top_k,
  368. not top_k_no_order,
  369. ) # Torch model baseline is fp32
  370. batch_size = torch_runner.batch_size
  371. onnx_metric.start_batch(batch_size)
  372. onnx_io_metric.start_batch(batch_size)
  373. with torch.no_grad():
  374. done = torch.zeros(batch_size, dtype=torch.bool)
  375. for step in range(max_steps):
  376. seq_len = list(onnx_runner.input_ids.size())[1]
  377. past_seq_len = list(onnx_runner.past[0].size())[3]
  378. start_time = timeit.default_timer()
  379. pytorch_output = Gpt2Helper.pytorch_inference(model, torch_runner.get_inputs())
  380. torch_metric.add_latency(past_seq_len, timeit.default_timer() - start_time)
  381. torch_runner.update(pytorch_output, step, device)
  382. onnx_output, avg_latency_ms = Gpt2Helper.onnxruntime_inference(
  383. session, onnx_runner.get_inputs(), total_runs=1
  384. )
  385. onnx_metric.add_latency(past_seq_len, avg_latency_ms / 1000.0)
  386. onnx_runner.update(onnx_output, step, device)
  387. output_shapes = Gpt2Helper.get_output_shapes(
  388. batch_size,
  389. past_seq_len,
  390. seq_len,
  391. model.config,
  392. model_class=model_class,
  393. )
  394. Gpt2Helper.auto_increase_buffer_size(output_buffers, output_shapes)
  395. (
  396. onnx_io_output,
  397. avg_latency_ms,
  398. ) = Gpt2Helper.onnxruntime_inference_with_binded_io(
  399. session,
  400. onnx_io_runner.get_inputs(),
  401. output_buffers,
  402. output_shapes,
  403. total_runs=1,
  404. return_numpy=False,
  405. include_copy_output_latency=True,
  406. )
  407. onnx_io_metric.add_latency(past_seq_len, avg_latency_ms / 1000.0)
  408. if test_data_saved < save_test_data:
  409. onnx_io_runner.save_test_data(session, onnx_io_output, save_test_data_dir, test_data_saved)
  410. test_data_saved += 1
  411. onnx_io_runner.update(onnx_io_output, step, device)
  412. if verbose:
  413. onnx_runner.diff(onnx_io_runner)
  414. Gpt2Tester.diff_present(onnx_output, onnx_io_output, n_layer)
  415. print("Top 1 tokens:")
  416. print("\tTorch", torch_runner.top_1_tokens)
  417. print("\tONNX", onnx_runner.top_1_tokens)
  418. print("\tONNX with IO binding", onnx_io_runner.top_1_tokens)
  419. onnx_metric.eval_batch(torch_runner, onnx_runner, past_seq_len, verbose=verbose)
  420. onnx_io_metric.eval_batch(torch_runner, onnx_io_runner, past_seq_len, verbose=verbose)
  421. done = done | (torch_runner.top_1_tokens == eos_token_id).any()
  422. if torch.all(done):
  423. break
  424. onnx_metric.end_batch()
  425. onnx_io_metric.end_batch()
  426. torch_metric.print()
  427. onnx_metric.print()
  428. onnx_io_metric.print()