gpt2_helper.py 39 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
370470570670770870971071171271371471571671771871972072172272372472572672772872973073173273373473573673773873974074174274374474574674774874975075175275375475575675775875976076176276376476576676776876977077177277377477577677777877978078178278378478578678778878979079179279379479579679779879980080180280380480580680780880981081181281381481581681781881982082182282382482582682782882983083183283383483583683783883984084184284384484584684784884985085185285385485585685785885986086186286386486586686786886987087187287387487587687787887988088188288388488588688788888989089189289389489589689789889990090190290390490590690790890991091191291391491591691791891992092192292392492592692792892993093193293393493593693793893994094194294394494594694794894995095195295395495595695795895996096196296396496596696796896997097197297397497597697797897998098198298398498598698798898999099199299399499599699799899910001001100210031004100510061007100810091010101110121013101410151016101710181019102010211022102310241025102610271028102910301031
  1. # -------------------------------------------------------------------------
  2. # Copyright (c) Microsoft Corporation. All rights reserved.
  3. # Licensed under the MIT License. See License.txt in the project root for
  4. # license information.
  5. # --------------------------------------------------------------------------
  6. # This script helps onnx conversion and validation for GPT2 model with past state.
  7. import logging
  8. import os
  9. import pickle
  10. import random
  11. import shutil
  12. import tempfile
  13. import time
  14. from pathlib import Path
  15. import numpy
  16. import onnx
  17. import torch
  18. from benchmark_helper import Precision
  19. from float16 import float_to_float16_max_diff
  20. from fusion_options import FusionOptions
  21. from io_binding_helper import IOBindingHelper
  22. from onnx_model import OnnxModel
  23. from optimizer import optimize_model
  24. from torch_onnx_export_helper import torch_onnx_export
  25. from transformers import GPT2Config, GPT2LMHeadModel, GPT2Model, TFGPT2Model
# Module-level logger shared by every helper in this file.
logger = logging.getLogger(__name__)

# Names of the pretrained GPT-2 checkpoints listed by this helper.
PRETRAINED_GPT2_MODELS = ["distilgpt2", "gpt2", "gpt2-medium", "gpt2-large", "gpt2-xl"]

# Tolerance value per precision. Lower precision gets a looser tolerance.
# NOTE(review): presumably consumed as the default `atol` of parity tests - confirm against callers.
DEFAULT_TOLERANCE = {
    Precision.FLOAT32: 0.0005,
    Precision.FLOAT16: 0.2,
    Precision.INT8: 3.0,
}
  33. class GPT2ModelNoPastState(GPT2Model):
  34. """Here we wrap a class to disable past state output."""
  35. def __init__(self, config):
  36. super().__init__(config)
  37. def forward(self, input_ids):
  38. return super().forward(input_ids, use_cache=False, return_dict=False)
  39. class TFGPT2ModelNoPastState(TFGPT2Model):
  40. """Here we wrap a class to disable past state output."""
  41. def __init__(self, config):
  42. config.use_cache = False
  43. super().__init__(config)
  44. def forward(self, input_ids):
  45. return super().call(input_ids, use_cache=False)
  46. class MyGPT2Model(GPT2Model):
  47. """Here we wrap a class for Onnx model conversion for GPT2Model with past state."""
  48. def __init__(self, config):
  49. super().__init__(config)
  50. @staticmethod
  51. def post_process(result, num_layer):
  52. if isinstance(result[1][0], (tuple, list)):
  53. assert len(result[1]) == num_layer and len(result[1][0]) == 2
  54. # assert len(result[1][0][0].shape) == 4 and result[1][0][0].shape == result[1][0][1].shape
  55. present = []
  56. for i in range(num_layer):
  57. # Since transformers v4.*, past key and values are separated outputs.
  58. # Here we concate them into one tensor to be compatible with Attention operator.
  59. present.append(
  60. torch.cat(
  61. (result[1][i][0].unsqueeze(0), result[1][i][1].unsqueeze(0)),
  62. dim=0,
  63. )
  64. )
  65. return (result[0], tuple(present))
  66. return result
  67. def forward(self, input_ids, position_ids, attention_mask, *past):
  68. result = super().forward(
  69. input_ids,
  70. position_ids=position_ids,
  71. attention_mask=attention_mask,
  72. past_key_values=past,
  73. return_dict=False,
  74. )
  75. return MyGPT2Model.post_process(result, self.config.n_layer)
  76. class MyGPT2LMHeadModel(GPT2LMHeadModel):
  77. """Here we wrap a class for Onnx model conversion for GPT2LMHeadModel with past state."""
  78. def __init__(self, config):
  79. super().__init__(config)
  80. def forward(self, input_ids, position_ids, attention_mask, *past):
  81. result = super().forward(
  82. input_ids,
  83. position_ids=position_ids,
  84. attention_mask=attention_mask,
  85. past_key_values=past,
  86. return_dict=False,
  87. )
  88. return MyGPT2Model.post_process(result, self.config.n_layer)
  89. class MyGPT2LMHeadModel_NoPadding(GPT2LMHeadModel): # noqa: N801
  90. """Here we wrap a class for Onnx model conversion for GPT2LMHeadModel with past state and no padding.
  91. When you always use batch_size=1 in inference, there is no padding in inputs. In such case, position_ids
  92. and attention_mask need no be in inputs.
  93. """
  94. def __init__(self, config):
  95. super().__init__(config)
  96. def forward(self, input_ids, *past):
  97. result = super().forward(input_ids, past_key_values=past, return_dict=False)
  98. return MyGPT2Model.post_process(result, self.config.n_layer)
# Maps a model class name to a 3-tuple:
#   (wrapper class used for export, name of the first output, whether padding is used).
MODEL_CLASSES = {
    "GPT2LMHeadModel": (MyGPT2LMHeadModel, "logits", True),
    "GPT2LMHeadModel_NoPadding": (MyGPT2LMHeadModel_NoPadding, "logits", False),
    "GPT2Model": (MyGPT2Model, "last_state", True),
}
  105. class Gpt2Inputs:
  106. def __init__(self, input_ids, position_ids, attention_mask, past):
  107. self.input_ids: torch.LongTensor = input_ids
  108. self.position_ids: torch.LongTensor = position_ids
  109. self.attention_mask: torch.LongTensor | torch.FloatTensor | torch.HalfTensor = attention_mask
  110. self.past: list[torch.FloatTensor] | list[torch.HalfTensor] = past
  111. def to_list(self) -> list:
  112. input_list = [v for v in [self.input_ids, self.position_ids, self.attention_mask] if v is not None]
  113. if self.past:
  114. input_list.extend(self.past)
  115. return input_list
  116. def to_tuple(self) -> tuple:
  117. return tuple(v for v in [self.input_ids, self.position_ids, self.attention_mask, self.past] if v is not None)
  118. def to_fp32(self):
  119. # For attention mask, only convert fp16 to fp32, and keep the original type if it is integer.
  120. attention_mask = None
  121. if self.attention_mask is not None:
  122. attention_mask = (
  123. self.attention_mask.to(dtype=torch.float32)
  124. if (self.attention_mask.dtype == torch.float16)
  125. else self.attention_mask
  126. )
  127. past = [p.to(dtype=torch.float32) for p in self.past]
  128. return Gpt2Inputs(self.input_ids, self.position_ids, attention_mask, past)
  129. class Gpt2Helper:
  130. """A helper class for Gpt2 model conversion, inference and verification."""
  131. @staticmethod
  132. def get_dummy_inputs(
  133. batch_size: int,
  134. past_sequence_length: int,
  135. sequence_length: int,
  136. num_attention_heads: int,
  137. hidden_size: int,
  138. num_layer: int,
  139. vocab_size: int,
  140. device: torch.device,
  141. float16: bool = False,
  142. has_position_ids: bool = True,
  143. has_attention_mask: bool = True,
  144. input_ids_dtype: torch.dtype = torch.int32,
  145. position_ids_dtype: torch.dtype = torch.int32,
  146. attention_mask_dtype: torch.dtype = torch.int32,
  147. left_side_padding: bool = True,
  148. ) -> Gpt2Inputs:
  149. """Create random inputs for GPT2 model.
  150. Returns torch tensors of input_ids, position_ids, attention_mask and a list of past state tensors.
  151. """
  152. float_type = torch.float16 if float16 else torch.float32
  153. past_shape = [
  154. 2,
  155. batch_size,
  156. num_attention_heads,
  157. past_sequence_length,
  158. int(hidden_size / num_attention_heads),
  159. ]
  160. past = [(torch.rand(past_shape, dtype=float_type, device=device) * 2.0 - 1.0) for _ in range(num_layer)]
  161. input_ids = torch.randint(
  162. low=0,
  163. high=vocab_size - 1,
  164. size=(batch_size, sequence_length),
  165. dtype=input_ids_dtype,
  166. device=device,
  167. )
  168. attention_mask = None
  169. if has_attention_mask:
  170. total_sequence_length = past_sequence_length + sequence_length
  171. attention_mask = torch.ones(
  172. [batch_size, total_sequence_length],
  173. dtype=attention_mask_dtype,
  174. device=device,
  175. )
  176. if total_sequence_length >= 2:
  177. for i in range(batch_size):
  178. padding_length = random.randint(0, total_sequence_length - 1)
  179. if left_side_padding:
  180. attention_mask[i, :padding_length] = 0
  181. else: # right side padding
  182. attention_mask[i, total_sequence_length - padding_length :] = 0
  183. # Deduce position_ids from attention mask
  184. position_ids = None
  185. if has_position_ids:
  186. position_ids = attention_mask.long().cumsum(-1) - 1
  187. position_ids.masked_fill_(position_ids < 0, 0)
  188. position_ids = position_ids[:, past_sequence_length:].to(position_ids_dtype)
  189. return Gpt2Inputs(input_ids, position_ids, attention_mask, past)
  190. @staticmethod
  191. def get_output_shapes(
  192. batch_size: int,
  193. past_sequence_length: int,
  194. sequence_length: int,
  195. config: GPT2Config,
  196. model_class: str = "GPT2LMHeadModel",
  197. ) -> dict[str, list[int]]:
  198. """Returns a dictionary with output name as key, and shape as value."""
  199. num_attention_heads = config.num_attention_heads
  200. hidden_size = config.hidden_size
  201. num_layer = config.num_hidden_layers
  202. vocab_size = config.vocab_size
  203. output_name = MODEL_CLASSES[model_class][1]
  204. last_state_shape = [
  205. batch_size,
  206. sequence_length,
  207. vocab_size if output_name == "logits" else hidden_size,
  208. ]
  209. present_state_shape = [
  210. 2,
  211. batch_size,
  212. num_attention_heads,
  213. past_sequence_length + sequence_length,
  214. int(hidden_size / num_attention_heads),
  215. ]
  216. output_shapes = {output_name: last_state_shape}
  217. for i in range(num_layer):
  218. output_shapes["present_" + str(i)] = present_state_shape
  219. return output_shapes
  220. @staticmethod
  221. def auto_increase_buffer_size(output_buffers, output_shapes):
  222. for key in output_shapes:
  223. assert key in output_buffers
  224. buffer = output_buffers[key]
  225. if numpy.prod(output_shapes[key]) > buffer.nelement():
  226. output_buffers[key] = torch.empty(
  227. numpy.prod(output_shapes[key]),
  228. dtype=buffer.dtype,
  229. device=buffer.device,
  230. )
  231. @staticmethod
  232. def get_output_buffers(output_shapes, device, is_float16=False):
  233. """Returns a dictionary of output name as key, and 1D tensor as value. The tensor has enough space for given shape."""
  234. data_type = torch.float16 if is_float16 else torch.float32
  235. output_buffers = {}
  236. for name, shape in output_shapes.items():
  237. output_buffers[name] = torch.empty(numpy.prod(shape), dtype=data_type, device=device)
  238. return output_buffers
  239. @staticmethod
  240. def diff_outputs(torch_outputs, ort_outputs, relative=False):
  241. """Returns the maximum difference between PyTorch and OnnxRuntime outputs."""
  242. expected_outputs = torch_outputs[0].cpu().numpy()
  243. diff = numpy.abs(expected_outputs - ort_outputs[0])
  244. if relative:
  245. return numpy.amax(diff / (numpy.abs(expected_outputs) + 1e-6))
  246. else:
  247. return numpy.amax(diff)
  248. @staticmethod
  249. def compare_outputs(torch_outputs, ort_outputs, rtol=1e-03, atol=1e-03, **kwargs):
  250. """Returns True if torch and ORT outputs are close for given thresholds, and False otherwise.
  251. Note: need kwargs since Gpt2BeamSearchHelper.compare_outputs has an extra parameter model_class
  252. """
  253. is_close = numpy.allclose(ort_outputs[0], torch_outputs[0].cpu().numpy(), rtol=rtol, atol=atol)
  254. logger.debug(f"PyTorch and OnnxRuntime output 0 (last_state) are close: {is_close}")
  255. is_all_close = is_close
  256. num_layers = len(ort_outputs) - 1
  257. for layer in range(num_layers):
  258. is_close = numpy.allclose(
  259. ort_outputs[1 + layer],
  260. torch_outputs[1][layer].cpu().numpy(),
  261. rtol=rtol,
  262. atol=atol,
  263. )
  264. logger.debug(f"PyTorch and OnnxRuntime layer {layer} state (present_{layer}) are close:{is_close}")
  265. is_all_close = is_all_close and is_close
  266. if not is_all_close:
  267. max_abs_diff = Gpt2Helper.diff_outputs(torch_outputs, ort_outputs)
  268. logger.info(f"PyTorch and OnnxRuntime results are not all close: max_abs_diff={max_abs_diff:.5f}")
  269. return is_all_close
  270. @staticmethod
  271. def compare_outputs_v2(torch_outputs, ort_outputs, atol=1e-06):
  272. """Compare outputs from PyTorch and OnnxRuntime
  273. Args:
  274. torch_outputs (Tuple[Torch.Tensor]): PyTorch model output
  275. ort_outputs (List[numpy.ndarray]): OnnxRuntime output
  276. atol (float, optional): Absolute tollerance. Defaults to 1e-06.
  277. Returns:
  278. is_all_close(bool): whether all elements are close.
  279. max_abs_diff(float): maximum absolute difference.
  280. messages(str): a list of debug message for each output
  281. """
  282. is_all_close = True
  283. is_top1_matched = False
  284. max_diffs = []
  285. messages = []
  286. for i in range(len(ort_outputs)):
  287. ort_output = ort_outputs[i]
  288. torch_output = (torch_outputs[0] if i == 0 else torch_outputs[1][i - 1]).cpu().numpy()
  289. is_close = numpy.allclose(ort_output, torch_output, atol=atol, rtol=0)
  290. max_diffs.append(numpy.amax(numpy.abs(torch_output - ort_output)))
  291. is_all_close = is_all_close and is_close
  292. if numpy.isnan(torch_output).any():
  293. logger.debug(f"PyTorch output {i} has nan")
  294. if numpy.isinf(torch_output).any():
  295. logger.debug(f"PyTorch output {i} has inf")
  296. if numpy.isnan(ort_output).any():
  297. logger.debug(f"ORT output {i} has nan")
  298. if numpy.isinf(ort_output).any():
  299. logger.debug(f"ORT output {i} has inf")
  300. diff = numpy.fabs(ort_output - torch_output)
  301. idx = numpy.unravel_index(diff.argmax(), diff.shape)
  302. messages.append(
  303. f"diff={diff[idx]:.9f} index={idx} ort={ort_output[idx]:.9f} torch={float(torch_output[idx]):.9f}"
  304. )
  305. if i == 0: # logits
  306. ort_max_index = numpy.unravel_index(numpy.argmax(ort_output, axis=None), ort_output.shape)
  307. torch_max_index = numpy.unravel_index(numpy.argmax(torch_output, axis=None), torch_output.shape)
  308. is_top1_matched = numpy.array_equal(ort_max_index, torch_max_index)
  309. max_diff_output_index = max_diffs.index(max(max_diffs))
  310. return (
  311. is_all_close,
  312. max(max_diffs),
  313. max_diff_output_index,
  314. messages,
  315. is_top1_matched,
  316. )
  317. @staticmethod
  318. def export_onnx(
  319. model,
  320. device,
  321. onnx_model_path: str,
  322. verbose: bool = False,
  323. use_external_data_format: bool = False,
  324. has_position_ids: bool = True,
  325. has_attention_mask: bool = True,
  326. input_ids_dtype: torch.dtype = torch.int32,
  327. position_ids_dtype: torch.dtype = torch.int32,
  328. attention_mask_dtype: torch.dtype = torch.int32,
  329. ):
  330. """Export GPT-2 model with past state to ONNX model."""
  331. config: GPT2Config = model.config
  332. num_layer = config.n_layer
  333. dummy_inputs = Gpt2Helper.get_dummy_inputs(
  334. batch_size=1,
  335. past_sequence_length=1,
  336. sequence_length=1,
  337. num_attention_heads=config.num_attention_heads,
  338. hidden_size=config.hidden_size,
  339. num_layer=num_layer,
  340. vocab_size=config.vocab_size,
  341. device=device,
  342. float16=False,
  343. has_position_ids=has_position_ids,
  344. has_attention_mask=has_attention_mask,
  345. input_ids_dtype=input_ids_dtype,
  346. position_ids_dtype=position_ids_dtype,
  347. attention_mask_dtype=attention_mask_dtype,
  348. )
  349. input_list = dummy_inputs.to_list()
  350. with torch.no_grad():
  351. outputs = model(*input_list)
  352. past_names = [f"past_{i}" for i in range(num_layer)]
  353. present_names = [f"present_{i}" for i in range(num_layer)]
  354. # GPT2Model outputs last_state; GPT2LMHeadModel outputs logits (prediction_scores)
  355. assert outputs[0].shape[2] == config.vocab_size or outputs[0].shape[2] == config.hidden_size
  356. output_names = ["logits" if outputs[0].shape[2] == config.vocab_size else "last_state", *present_names]
  357. # Shape of input tensors:
  358. # input_ids: (batch_size, seq_len)
  359. # past_{i}: (2, batch_size, num_heads, past_seq_len, hidden_size/num_heads)
  360. # attention_mask: (batch_size, past_seq_len + seq_len)
  361. # Shape of output tensors:
  362. # last_state: (batch_size, seq_len, hidden_size)
  363. # or logits: (batch_size, seq_len, vocab_size)
  364. # present_{i}: (2, batch_size, num_heads, past_seq_len + seq_len, hidden_size/num_heads)
  365. dynamic_axes = {
  366. "input_ids": {0: "batch_size", 1: "seq_len"},
  367. output_names[0]: {0: "batch_size", 1: "seq_len"},
  368. }
  369. for name in past_names:
  370. dynamic_axes[name] = {1: "batch_size", 3: "past_seq_len"}
  371. for name in present_names:
  372. dynamic_axes[name] = {1: "batch_size", 3: "total_seq_len"}
  373. input_names = ["input_ids"]
  374. if has_position_ids:
  375. dynamic_axes["position_ids"] = {0: "batch_size", 1: "seq_len"}
  376. input_names.append("position_ids")
  377. if has_attention_mask:
  378. dynamic_axes["attention_mask"] = {0: "batch_size", 1: "total_seq_len"}
  379. input_names.append("attention_mask")
  380. input_names.extend(past_names)
  381. assert len(outputs) == 2 and len(outputs[1]) == num_layer
  382. logger.info(
  383. f"Shapes: input_ids={dummy_inputs.input_ids.shape} past={dummy_inputs.past[0].shape} output={outputs[0].shape} present={outputs[1][0].shape}"
  384. )
  385. Path(onnx_model_path).parent.mkdir(parents=True, exist_ok=True)
  386. if use_external_data_format:
  387. # We let PyTorch export onnx to a temp directory first, then convert external data to one file.
  388. with tempfile.TemporaryDirectory() as tmp_dir_name:
  389. temp_onnx_model_path = os.path.join(tmp_dir_name, "gpt2.onnx")
  390. Path(temp_onnx_model_path).parent.mkdir(parents=True, exist_ok=True)
  391. torch_onnx_export(
  392. model,
  393. args=tuple(input_list),
  394. f=temp_onnx_model_path,
  395. export_params=True,
  396. input_names=input_names,
  397. output_names=output_names,
  398. dynamic_axes=dynamic_axes,
  399. opset_version=11,
  400. do_constant_folding=True,
  401. use_external_data_format=True,
  402. verbose=verbose,
  403. )
  404. model = onnx.load_model(temp_onnx_model_path, load_external_data=True)
  405. OnnxModel.save(
  406. model,
  407. onnx_model_path,
  408. save_as_external_data=True,
  409. all_tensors_to_one_file=True,
  410. )
  411. else:
  412. torch_onnx_export(
  413. model,
  414. args=tuple(input_list),
  415. f=onnx_model_path,
  416. export_params=True,
  417. input_names=input_names,
  418. output_names=output_names,
  419. dynamic_axes=dynamic_axes,
  420. opset_version=11,
  421. do_constant_folding=True,
  422. use_external_data_format=False,
  423. verbose=verbose,
  424. )
  425. @staticmethod
  426. def optimize_onnx(
  427. onnx_model_path,
  428. optimized_model_path,
  429. is_float16,
  430. num_attention_heads,
  431. hidden_size,
  432. use_external_data_format=False,
  433. auto_mixed_precision=False,
  434. stage=0,
  435. **kwargs,
  436. ):
  437. """Optimize ONNX model with an option to convert it to use mixed precision."""
  438. optimization_options = FusionOptions("gpt2")
  439. m = optimize_model(
  440. onnx_model_path,
  441. model_type="gpt2",
  442. num_heads=num_attention_heads,
  443. hidden_size=hidden_size,
  444. opt_level=0,
  445. optimization_options=optimization_options,
  446. use_gpu=False,
  447. )
  448. if is_float16:
  449. if auto_mixed_precision:
  450. Gpt2Helper.auto_mixed_precision(m)
  451. else:
  452. if "keep_io_types" not in kwargs:
  453. kwargs["keep_io_types"] = False
  454. m.convert_float_to_float16(use_symbolic_shape_infer=True, **kwargs)
  455. m.save_model_to_file(optimized_model_path, use_external_data_format)
  456. return m
  457. @staticmethod
  458. def auto_mixed_precision(
  459. onnx_model: OnnxModel,
  460. op_block_list: list[str] = [ # noqa: B006
  461. "Add",
  462. "LayerNormalization",
  463. "SkipLayerNormalization",
  464. "FastGelu",
  465. "EmbedLayerNormalization",
  466. ],
  467. ):
  468. """Convert GPT-2 model to mixed precision.
  469. It detects whether original model has fp16 weights, and set parameters for float16 conversion automatically.
  470. Args:
  471. onnx_model (OnnxModel): optimized ONNX model
  472. op_block_list (List[str], optional): operators to compute in fp32. Defaults to ["Add", "LayerNormalization",
  473. "SkipLayerNormalization", "FastGelu", "EmbedLayerNormalization"]
  474. Returns:
  475. parameters(dict): a dictionary of parameters used in float16 conversion
  476. """
  477. op_full_set = {node.op_type for node in onnx_model.nodes()}
  478. fp32_op_set = set(op_block_list)
  479. fp16_op_set = op_full_set.difference(fp32_op_set)
  480. logger.info(f"fp32 op: {fp32_op_set} fp16 op: {fp16_op_set}")
  481. # logits is the first output
  482. logits_output_name = onnx_model.graph().output[0].name
  483. # We use the weight in last MatMul node to detect whether the model is stored with float16 weights from training.
  484. is_weight_fp16_precision = False
  485. output_name_to_node = onnx_model.output_name_to_node()
  486. assert logits_output_name in output_name_to_node
  487. node = output_name_to_node[logits_output_name]
  488. last_matmul_node = None
  489. if node.op_type == "MatMul":
  490. last_matmul_node = node
  491. logger.info(f"Found last MatMul node for logits: {node.name}")
  492. initializer = None
  493. for input in node.input:
  494. initializer = onnx_model.get_initializer(input)
  495. if initializer is not None:
  496. break
  497. # when the max difference of value after converting float to float16 is lower than a threshold (1e-6),
  498. # we can deduce that the weights are stored in float16 precision.
  499. max_diff = float_to_float16_max_diff(initializer)
  500. logger.debug(f"max diff of converting weights in last MatMul node {node.name}: {max_diff}")
  501. is_weight_fp16_precision = max_diff < 1e-6
  502. else:
  503. logger.warning(f"Failed to find MatMul node for logits. Found {node.op_type} of node {node.name}")
  504. keep_io_types = []
  505. node_block_list = []
  506. if (not is_weight_fp16_precision) and (last_matmul_node is not None):
  507. # When original weight is float32 precision, keep logits and last MatMul in float32 could get better precision.
  508. keep_io_types = [logits_output_name]
  509. node_block_list = [last_matmul_node.name]
  510. parameters = {
  511. "keep_io_types": keep_io_types,
  512. "op_block_list": op_block_list,
  513. "node_block_list": node_block_list,
  514. "force_fp16_initializers": is_weight_fp16_precision,
  515. }
  516. logger.info(f"auto_mixed_precision parameters: {parameters}")
  517. onnx_model.convert_float_to_float16(use_symbolic_shape_infer=True, **parameters)
  518. return parameters
  519. @staticmethod
  520. def pytorch_inference(model, inputs: Gpt2Inputs, total_runs: int = 0):
  521. """Run inference of PyTorch model, and returns average latency in ms when total_runs > 0 besides outputs."""
  522. logger.debug("start pytorch_inference")
  523. # Convert it to fp32 as the PyTroch model cannot deal with half input.
  524. input_list = inputs.to_fp32().to_list()
  525. with torch.no_grad():
  526. outputs = model(*input_list)
  527. if total_runs == 0:
  528. return outputs
  529. latency = []
  530. with torch.no_grad():
  531. for _ in range(total_runs):
  532. start = time.time()
  533. outputs = model(*input_list)
  534. latency.append(time.time() - start)
  535. average_latency = sum(latency) * 1000 / len(latency)
  536. logger.debug("PyTorch inference time = {} ms".format(format(average_latency, ".2f"))) # noqa: G001
  537. return outputs, average_latency
  538. @staticmethod
  539. def onnxruntime_inference(ort_session, inputs: Gpt2Inputs, total_runs: int = 0):
  540. """Run inference of ONNX model, and returns average latency in ms when total_runs > 0 besides outputs."""
  541. logger.debug("start onnxruntime_inference")
  542. ort_inputs = {"input_ids": numpy.ascontiguousarray(inputs.input_ids.cpu().numpy())}
  543. if inputs.past is not None:
  544. for i, past_i in enumerate(inputs.past):
  545. ort_inputs[f"past_{i}"] = numpy.ascontiguousarray(past_i.cpu().numpy())
  546. if inputs.attention_mask is not None:
  547. ort_inputs["attention_mask"] = numpy.ascontiguousarray(inputs.attention_mask.cpu().numpy())
  548. if inputs.position_ids is not None:
  549. ort_inputs["position_ids"] = numpy.ascontiguousarray(inputs.position_ids.cpu().numpy())
  550. ort_outputs = ort_session.run(None, ort_inputs)
  551. if total_runs == 0:
  552. return ort_outputs
  553. latency = []
  554. for _ in range(total_runs):
  555. start = time.time()
  556. ort_outputs = ort_session.run(None, ort_inputs)
  557. latency.append(time.time() - start)
  558. average_latency = sum(latency) * 1000 / len(latency)
  559. logger.debug("OnnxRuntime Inference time = {} ms".format(format(average_latency, ".2f"))) # noqa: G001
  560. return ort_outputs, average_latency
  561. @staticmethod
  562. def prepare_io_binding(
  563. ort_session,
  564. input_ids,
  565. position_ids,
  566. attention_mask,
  567. past,
  568. output_buffers,
  569. output_shapes,
  570. ):
  571. """Returnas IO binding object for a session."""
  572. return IOBindingHelper.prepare_io_binding(
  573. ort_session,
  574. input_ids,
  575. position_ids,
  576. attention_mask,
  577. past,
  578. output_buffers,
  579. output_shapes,
  580. )
  581. @staticmethod
  582. def get_outputs_from_io_binding_buffer(ort_session, output_buffers, output_shapes, return_numpy=True):
  583. """Copy results to cpu. Returns a list of numpy array."""
  584. return IOBindingHelper.get_outputs_from_io_binding_buffer(
  585. ort_session, output_buffers, output_shapes, return_numpy
  586. )
  587. @staticmethod
  588. def onnxruntime_inference_with_binded_io(
  589. ort_session,
  590. inputs: Gpt2Inputs,
  591. output_buffers: dict[str, torch.Tensor],
  592. output_shapes: dict[str, list[int]],
  593. total_runs: int = 0,
  594. return_numpy: bool = True,
  595. include_copy_output_latency: bool = False,
  596. ):
  597. """Inference with IO binding. Returns outputs, and optional latency when total_runs > 0."""
  598. logger.debug("start onnxruntime_inference_with_binded_io")
  599. # Bind inputs and outputs to onnxruntime session
  600. io_binding = Gpt2Helper.prepare_io_binding(
  601. ort_session,
  602. inputs.input_ids,
  603. inputs.position_ids,
  604. inputs.attention_mask,
  605. inputs.past,
  606. output_buffers,
  607. output_shapes,
  608. )
  609. # Run onnxruntime with io binding
  610. ort_session.run_with_iobinding(io_binding)
  611. # Copy results to cpu for verification
  612. ort_outputs = Gpt2Helper.get_outputs_from_io_binding_buffer(
  613. ort_session, output_buffers, output_shapes, return_numpy
  614. )
  615. if total_runs == 0:
  616. return ort_outputs
  617. latency = []
  618. for _ in range(total_runs):
  619. start = time.time()
  620. # Run onnxruntime with io binding
  621. ort_session.run_with_iobinding(io_binding)
  622. if include_copy_output_latency:
  623. _ = Gpt2Helper.get_outputs_from_io_binding_buffer(
  624. ort_session, output_buffers, output_shapes, return_numpy
  625. )
  626. latency.append(time.time() - start)
  627. average_latency = sum(latency) * 1000 / len(latency)
  628. logger.debug("OnnxRuntime with IO binding inference time = %.2f ms", average_latency)
  629. return ort_outputs, average_latency
  630. @staticmethod
  631. def save_outputs(i, ort_outputs, torch_outputs):
  632. with open(f"ort_outputs_{i}.pickle", "wb") as f:
  633. pickle.dump(ort_outputs, f)
  634. logger.info(f"ORT output are saved to ort_outputs_{i}.pickle")
  635. with open(f"torch_outputs_{i}.pickle", "wb") as f:
  636. pickle.dump(torch_outputs, f)
  637. logger.info(f"Torch output are saved to torch_outputs_{i}.pickle")
  638. @staticmethod
  639. def save_inputs(i, dummy_inputs, ort_outputs, torch_outputs):
  640. with open(f"dummy_inputs_{i}.pickle", "wb") as f:
  641. pickle.dump(dummy_inputs, f)
  642. logger.info(f"inputs are saved to dummy_inputs_{i}.pickle")
  @staticmethod
  def test_parity(
      ort_session,
      model,
      device,
      is_float16=False,
      rtol=5e-4,
      atol=5e-4,
      test_cases_per_run=10000,
      total_runs=1,
      use_io_binding=True,
      model_class="GPT2LMHeadModel",
      has_position_ids=True,
      has_attention_mask=True,
      input_ids_dtype=torch.int32,
      position_ids_dtype=torch.int32,
      attention_mask_dtype=torch.int32,
      stage=0,
      verbose=False,
      enable_pickle_output=False,
  ):
      """Generate random inputs and compare the results of PyTorch and Onnx Runtime.

      Runs ``test_cases_per_run * total_runs`` test cases. Each case draws a
      random batch size in [1, 8], sequence length in [1, 2] and past sequence
      length in [0, 4] (always 0 when ``stage == 1``), runs both PyTorch and
      ONNX Runtime on the same dummy inputs and compares the outputs with
      ``Gpt2Helper.compare_outputs_v2``.

      Args:
          ort_session: ONNX Runtime session to evaluate.
          model: PyTorch model; its ``config`` supplies the model dimensions.
          device: Device passed to the dummy-input and output-buffer helpers.
          is_float16: Forwarded to the dummy-input and output-buffer helpers.
          rtol: Not used by this function; kept for interface compatibility.
          atol: Absolute tolerance forwarded to ``compare_outputs_v2``; also
              scales the threshold (``100 * atol``) for pickling failures.
          test_cases_per_run: Number of test cases in each run.
          total_runs: Number of runs; top-1 match counts are kept per run.
          use_io_binding: When true, run ONNX Runtime through
              ``onnxruntime_inference_with_binded_io`` using pre-allocated
              output buffers; otherwise use ``onnxruntime_inference``.
          model_class: Model class name forwarded to ``get_output_shapes``.
          has_position_ids: Forwarded to ``get_dummy_inputs``.
          has_attention_mask: Forwarded to ``get_dummy_inputs``.
          input_ids_dtype: Forwarded to ``get_dummy_inputs``.
          position_ids_dtype: Forwarded to ``get_dummy_inputs``.
          attention_mask_dtype: Forwarded to ``get_dummy_inputs``.
          stage: When equal to 1, every test case uses an empty past.
          verbose: Log the per-output comparison messages of failing cases.
          enable_pickle_output: Pickle inputs and outputs of cases whose max
              absolute difference is NaN or greater than ``100 * atol``.

      Returns:
          dict with ``max_diff_percentile_{50,90,95,99}`` (strings formatted to
          five decimals, or ``"nan"`` when no finite difference was recorded),
          ``top1_match_rate``, ``top1_match_rate_per_run``, ``diff_pass_rate``
          and ``nan_rate``.
      """
      config: GPT2Config = model.config

      logger.info(
          f"Running parity test (atol={atol}, test_cases={test_cases_per_run}, runs={total_runs}, use_io_binding={use_io_binding}, model_class={model_class}, is_float16={is_float16}) ..."
      )

      max_batch_size = 8
      max_past_seq_len = 4  # Do not use large number here for higher chance of hitting empty past (past_seq_len=0)
      max_seq_len = 2

      # Buffers are sized for the largest shapes any test case can draw so
      # they can be reused across all iterations.
      output_buffers = None
      if use_io_binding:
          max_output_shapes = Gpt2Helper.get_output_shapes(
              max_batch_size, max_past_seq_len, max_seq_len, config, model_class
          )
          output_buffers = Gpt2Helper.get_output_buffers(max_output_shapes, device, is_float16)

      passed_test_cases = 0
      top1_matched_cases = 0
      max_abs_diff_list = []
      top1_matched_cases_per_run = [0] * total_runs
      total_test_cases = test_cases_per_run * total_runs
      for i in range(total_test_cases):
          run_id = i // test_cases_per_run
          sequence_length = random.randint(1, max_seq_len)
          past_sequence_length = 0 if (stage == 1) else random.randint(0, max_past_seq_len)
          batch_size = random.randint(1, max_batch_size)

          logger.debug(
              f"Running parity test for batch_size={batch_size} past_sequence_length={past_sequence_length}..."
          )
          dummy_inputs = Gpt2Helper.get_dummy_inputs(
              batch_size,
              past_sequence_length,
              sequence_length,
              config.num_attention_heads,
              config.hidden_size,
              config.n_layer,
              config.vocab_size,
              device,
              is_float16,
              has_position_ids,
              has_attention_mask,
              input_ids_dtype=input_ids_dtype,
              position_ids_dtype=position_ids_dtype,
              attention_mask_dtype=attention_mask_dtype,
              left_side_padding=True,
          )
          outputs = Gpt2Helper.pytorch_inference(model, dummy_inputs)

          # The bound-I/O path needs the buffers allocated above, so it must
          # only be taken when use_io_binding is true.
          if use_io_binding:
              output_shapes = Gpt2Helper.get_output_shapes(
                  batch_size,
                  past_sequence_length,
                  sequence_length,
                  config,
                  model_class,
              )
              ort_outputs = Gpt2Helper.onnxruntime_inference_with_binded_io(
                  ort_session, dummy_inputs, output_buffers, output_shapes
              )
          else:
              ort_outputs = Gpt2Helper.onnxruntime_inference(ort_session, dummy_inputs)

          (
              is_all_close,
              max_abs_diff,
              max_diff_output_index,
              messages,
              is_top1_matched,
          ) = Gpt2Helper.compare_outputs_v2(outputs, ort_outputs, atol=atol)
          if not numpy.isnan(max_abs_diff):
              max_abs_diff_list.append(max_abs_diff)
          if is_all_close:
              passed_test_cases += 1

          if is_top1_matched:
              top1_matched_cases += 1
              top1_matched_cases_per_run[run_id] += 1

          if verbose and not is_all_close:
              logger.info(
                  f"test_case={i} batch_size={batch_size} past_sequence_length={past_sequence_length} sequence_length={sequence_length} MaxDiff={max_abs_diff}"
              )
              # Use a separate index so the test-case counter `i` is not
              # overwritten before it is used for the pickle file names below.
              for output_index, message in enumerate(messages):
                  logger.info(f"\t{output_index}: Name={ort_session.get_outputs()[output_index].name}, {message}")

          # Collect data for debugging
          if enable_pickle_output and (numpy.isnan(max_abs_diff) or max_abs_diff > 100 * atol):
              Gpt2Helper.save_inputs(i, dummy_inputs, ort_outputs, outputs)
              Gpt2Helper.save_outputs(i, ort_outputs, outputs)

      if max_abs_diff_list:
          result = {
              f"max_diff_percentile_{p}": f"{numpy.percentile(max_abs_diff_list, p):.5f}" for p in [50, 90, 95, 99]
          }
      else:
          result = {f"max_diff_percentile_{p}": "nan" for p in [50, 90, 95, 99]}

      result["top1_match_rate"] = top1_matched_cases * 1.0 / total_test_cases
      result["top1_match_rate_per_run"] = [x * 1.0 / test_cases_per_run for x in top1_matched_cases_per_run]
      result["diff_pass_rate"] = passed_test_cases * 1.0 / total_test_cases
      result["nan_rate"] = (total_test_cases - len(max_abs_diff_list)) * 1.0 / total_test_cases

      logger.info(
          f"Parity Test Cases={total_test_cases}; Passed={passed_test_cases}; Nan={total_test_cases - len(max_abs_diff_list)}; Top1_Matched={top1_matched_cases}"
      )

      if passed_test_cases > 0.95 * total_test_cases:
          logger.info(f"Parity is good: passed rate={int(passed_test_cases * 100 / total_test_cases):.0f}%")

      return result
  @staticmethod
  def test_performance(
      ort_session,
      model,
      device,
      is_float16=False,
      total_runs=100,
      use_io_binding=True,
      model_class="GPT2LMHeadModel",
      has_position_ids=True,
      has_attention_mask=True,
      input_ids_dtype=torch.int32,
      position_ids_dtype=torch.int32,
      attention_mask_dtype=torch.int32,
      batch_size=8,
      sequence_length=1,
      past_sequence_length=32,
  ):
      """Generate random inputs and measure average latency of Onnx Runtime.

      Args:
          ort_session: ONNX Runtime session to benchmark.
          model: PyTorch model; only its ``config`` is read here.
          device: Device passed to the dummy-input and output-buffer helpers.
          is_float16: Forwarded to the dummy-input and output-buffer helpers.
          total_runs: Number of runs forwarded to the inference helper.
          use_io_binding: When true, benchmark through
              ``onnxruntime_inference_with_binded_io`` with pre-allocated
              output buffers; otherwise use ``onnxruntime_inference``.
          model_class: Model class name forwarded to ``get_output_shapes``.
          has_position_ids: Forwarded to ``get_dummy_inputs``.
          has_attention_mask: Forwarded to ``get_dummy_inputs``.
          input_ids_dtype: Forwarded to ``get_dummy_inputs``.
          position_ids_dtype: Forwarded to ``get_dummy_inputs``.
          attention_mask_dtype: Forwarded to ``get_dummy_inputs``.
          batch_size: Batch size of the dummy inputs.
          sequence_length: Sequence length of the dummy inputs.
          past_sequence_length: Past sequence length of the dummy inputs.

      Returns:
          The latency value returned by the selected inference helper.
      """
      config: GPT2Config = model.config

      output_buffers = None
      output_shapes = None
      if use_io_binding:
          output_shapes = Gpt2Helper.get_output_shapes(
              batch_size, past_sequence_length, sequence_length, config, model_class
          )
          output_buffers = Gpt2Helper.get_output_buffers(output_shapes, device, is_float16)

      dummy_inputs = Gpt2Helper.get_dummy_inputs(
          batch_size,
          past_sequence_length,
          sequence_length,
          config.num_attention_heads,
          config.hidden_size,
          config.n_layer,
          config.vocab_size,
          device,
          is_float16,
          has_position_ids,
          has_attention_mask,
          input_ids_dtype=input_ids_dtype,
          position_ids_dtype=position_ids_dtype,
          attention_mask_dtype=attention_mask_dtype,
      )

      # The bound-I/O helper needs the shapes and buffers prepared above, which
      # only exist when use_io_binding is true.
      if use_io_binding:
          _, latency = Gpt2Helper.onnxruntime_inference_with_binded_io(
              ort_session, dummy_inputs, output_buffers, output_shapes, total_runs
          )
      else:
          _, latency = Gpt2Helper.onnxruntime_inference(ort_session, dummy_inputs, total_runs)

      return latency
  @staticmethod
  def torchscript(model, config, device, has_position_ids=True, has_attention_mask=True):
      """Trace ``model`` into TorchScript using minimal dummy inputs.

      The example inputs are built by ``Gpt2Helper.get_dummy_inputs`` with
      batch size, past sequence length and sequence length all set to 1 and
      ``float16=False``, then flattened with ``to_list()`` before being handed
      to ``torch.jit.trace``.

      Args:
          model: Model to trace.
          config: Configuration supplying ``num_attention_heads``,
              ``hidden_size``, ``n_layer`` and ``vocab_size``.
          device: Device passed to ``get_dummy_inputs``.
          has_position_ids: Forwarded to ``get_dummy_inputs``.
          has_attention_mask: Forwarded to ``get_dummy_inputs``.

      Returns:
          The result of ``torch.jit.trace(model, example_inputs)``.
      """
      dummy_kwargs = {
          "batch_size": 1,
          "past_sequence_length": 1,
          "sequence_length": 1,
          "num_attention_heads": config.num_attention_heads,
          "hidden_size": config.hidden_size,
          "num_layer": config.n_layer,
          "vocab_size": config.vocab_size,
          "device": device,
          "float16": False,
          "has_position_ids": has_position_ids,
          "has_attention_mask": has_attention_mask,
      }
      example_inputs = Gpt2Helper.get_dummy_inputs(**dummy_kwargs).to_list()
      traced_model = torch.jit.trace(model, example_inputs)
      return traced_model
  828. @staticmethod
  829. def get_onnx_paths(
  830. output_dir,
  831. model_name_or_path,
  832. model_class: str = "GPT2LMHeadModel",
  833. has_past=True,
  834. new_folder=False,
  835. remove_existing=["raw", "fp32", "fp16", "int8"], # noqa: B006
  836. ):
  837. """Build a path name for given model based on given attributes."""
  838. model_name = model_name_or_path
  839. if os.path.isdir(model_name_or_path):
  840. model_name = Path(model_name_or_path).parts[-1]
  841. else:
  842. model_name.split("/")[-1]
  843. if model_class != "GPT2LMHeadModel":
  844. model_name += "_" + model_class
  845. if has_past:
  846. model_name += "_past"
  847. if new_folder:
  848. suffix = {"raw": "", "fp32": "_fp32", "fp16": "_fp16", "int8": "_int8"}
  849. # Remove the directories if existed.
  850. for model_type in ["raw", "fp32", "fp16", "int8"]:
  851. new_dir = os.path.join(output_dir, model_name + suffix[model_type])
  852. if os.path.exists(new_dir):
  853. if model_type in remove_existing:
  854. try:
  855. shutil.rmtree(new_dir)
  856. logger.info(f"Removed the existed directory: {new_dir}")
  857. except OSError as e:
  858. logger.info(f"Failed to remove the directory {new_dir}: {e.strerror}")
  859. else:
  860. logger.info(f"Directory for {model_type} existed: {new_dir}")
  861. # store each model to its own directory (for external data format).
  862. return {
  863. "raw": os.path.join(os.path.join(output_dir, model_name), model_name + ".onnx"),
  864. "fp32": os.path.join(
  865. os.path.join(output_dir, model_name + "_fp32"),
  866. model_name + "_fp32.onnx",
  867. ),
  868. "fp16": os.path.join(
  869. os.path.join(output_dir, model_name + "_fp16"),
  870. model_name + "_fp16.onnx",
  871. ),
  872. "int8": os.path.join(
  873. os.path.join(output_dir, model_name + "_int8"),
  874. model_name + "_int8.onnx",
  875. ),
  876. }
  877. return {
  878. "raw": os.path.join(output_dir, model_name + ".onnx"),
  879. "fp32": os.path.join(output_dir, model_name + "_fp32.onnx"),
  880. "fp16": os.path.join(output_dir, model_name + "_fp16.onnx"),
  881. "int8": os.path.join(output_dir, model_name + "_int8.onnx"),
  882. }