# t5_decoder.py (16 KB)
# -------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License. See License.txt in the project root for
# license information.
# --------------------------------------------------------------------------
  6. import logging
  7. import os
  8. import tempfile
  9. from pathlib import Path
  10. import numpy
  11. import onnx
  12. import torch
  13. from io_binding_helper import TypeHelper
  14. from onnx_model import OnnxModel
  15. from past_helper import PastKeyValuesHelper
  16. from t5_encoder import T5EncoderInputs
  17. from torch_onnx_export_helper import torch_onnx_export
  18. from transformers import MT5Config, T5Config
  19. from onnxruntime import InferenceSession
  20. logger = logging.getLogger(__name__)
  21. class T5DecoderInit(torch.nn.Module):
  22. """A T5 decoder with LM head to create initial past key values.
  23. This model is only called once during starting decoding.
  24. """
  25. def __init__(
  26. self,
  27. decoder: torch.nn.Module,
  28. lm_head: torch.nn.Module,
  29. config: T5Config | MT5Config,
  30. decoder_start_token_id: int | None = None,
  31. ):
  32. super().__init__()
  33. self.decoder = decoder
  34. self.lm_head = lm_head
  35. self.config = config
  36. self.decoder_start_token_id = (
  37. decoder_start_token_id if decoder_start_token_id is not None else self.config.decoder_start_token_id
  38. )
  39. self.tie_word_embeddings = (
  40. self.config.tie_word_embeddings if hasattr(self.config, "tie_word_embeddings") else True
  41. )
  42. def forward(
  43. self,
  44. decoder_input_ids: torch.Tensor,
  45. encoder_attention_mask: torch.Tensor,
  46. encoder_hidden_states: torch.FloatTensor,
  47. ):
  48. if decoder_input_ids is None:
  49. batch_size = encoder_attention_mask.shape[0]
  50. decoder_input_ids = (
  51. torch.ones(
  52. (batch_size, 1),
  53. dtype=torch.long,
  54. device=encoder_attention_mask.device,
  55. )
  56. * self.decoder_start_token_id
  57. )
  58. decoder_outputs = self.decoder(
  59. input_ids=decoder_input_ids,
  60. encoder_hidden_states=encoder_hidden_states,
  61. encoder_attention_mask=encoder_attention_mask,
  62. use_cache=True,
  63. return_dict=True,
  64. )
  65. sequence_output = decoder_outputs.last_hidden_state
  66. present_key_values = decoder_outputs.past_key_values
  67. if self.tie_word_embeddings:
  68. sequence_output = sequence_output * (self.config.d_model**-0.5)
  69. lm_logits = self.lm_head(sequence_output)
  70. past_self, past_cross = PastKeyValuesHelper.group_by_self_or_cross(present_key_values)
  71. return lm_logits, past_self, past_cross
  72. class T5Decoder(torch.nn.Module):
  73. """A T5 decoder with LM head and past key values"""
  74. def __init__(self, decoder, lm_head, config):
  75. super().__init__()
  76. self.decoder = decoder
  77. self.lm_head = lm_head
  78. self.config = config
  79. self.tie_word_embeddings = (
  80. self.config.tie_word_embeddings if hasattr(self.config, "tie_word_embeddings") else True
  81. )
  82. def forward(self, decoder_input_ids, encoder_attention_mask, *past):
  83. num_decoder_layers = self.config.num_decoder_layers
  84. past_key_values = PastKeyValuesHelper.group_by_layer(past, num_decoder_layers)
  85. # This is a hack since only the third dimension of encoder_hidden_states is used here
  86. dummy_encoder_hidden_states = encoder_attention_mask.unsqueeze(2)
  87. decoder_outputs = self.decoder(
  88. input_ids=decoder_input_ids,
  89. past_key_values=past_key_values,
  90. encoder_hidden_states=dummy_encoder_hidden_states,
  91. encoder_attention_mask=encoder_attention_mask,
  92. use_cache=True,
  93. return_dict=True,
  94. )
  95. sequence_output = decoder_outputs.last_hidden_state
  96. present_key_values = decoder_outputs.past_key_values
  97. if self.tie_word_embeddings:
  98. sequence_output = sequence_output * (self.config.d_model**-0.5)
  99. lm_logits = self.lm_head(sequence_output)
  100. present_self, _ = PastKeyValuesHelper.group_by_self_or_cross(present_key_values)
  101. # Do not return present_cross since they are identical to corresponding past_cross input
  102. return lm_logits, present_self
  103. class T5DecoderInputs:
  104. def __init__(
  105. self,
  106. decoder_input_ids,
  107. encoder_attention_mask,
  108. past_key_values=None,
  109. ):
  110. self.decoder_input_ids: torch.LongTensor = decoder_input_ids
  111. self.encoder_attention_mask: torch.LongTensor = encoder_attention_mask
  112. self.past_key_values: list[torch.FloatTensor] | list[torch.HalfTensor] | None = past_key_values
  113. @staticmethod
  114. def create_dummy(
  115. config: T5Config | MT5Config,
  116. batch_size: int,
  117. encode_sequence_length: int,
  118. past_decode_sequence_length: int,
  119. device: torch.device,
  120. float16: bool = False,
  121. use_int32_inputs: bool = False,
  122. ): # -> T5DecoderInputs:
  123. """Create dummy inputs for T5Decoder.
  124. Args:
  125. decoder: decoder
  126. batch_size (int): batch size
  127. encode_sequence_length (int): sequence length of input_ids for encoder
  128. past_decode_sequence_length (int): past sequence length of input_ids for decoder
  129. device (torch.device): device of output tensors
  130. float16 (bool): whether the model uses float32 or float16 in input
  131. use_int32_inputs(bool): whether use int32 instead of int64 for some inputs
  132. Returns:
  133. T5DecoderInputs: dummy inputs for decoder
  134. """
  135. num_attention_heads: int = config.num_heads
  136. num_layers: int = config.num_decoder_layers
  137. vocab_size: int = config.vocab_size
  138. # Do not use head_size = hidden_size / num_attention_heads here.
  139. # For example, mt5-small, d_model=512 and num_heads=6
  140. head_size: int = config.d_kv
  141. sequence_length: int = 1 # fixed for decoding
  142. decoder_input_ids = torch.randint(
  143. low=0,
  144. high=vocab_size - 1,
  145. size=(batch_size, sequence_length),
  146. dtype=(torch.int32 if use_int32_inputs else torch.int64),
  147. device=device,
  148. )
  149. encoder_inputs = T5EncoderInputs.create_dummy(
  150. batch_size,
  151. encode_sequence_length,
  152. vocab_size,
  153. device,
  154. use_int32_inputs=use_int32_inputs,
  155. )
  156. float_type = torch.float16 if float16 else torch.float32
  157. if past_decode_sequence_length > 0:
  158. self_attention_past_shape = [
  159. batch_size,
  160. num_attention_heads,
  161. past_decode_sequence_length,
  162. head_size,
  163. ]
  164. cross_attention_past_shape = [
  165. batch_size,
  166. num_attention_heads,
  167. encode_sequence_length,
  168. head_size,
  169. ]
  170. past = []
  171. for _ in range(2 * num_layers):
  172. past.append(torch.rand(self_attention_past_shape, dtype=float_type, device=device))
  173. for _ in range(2 * num_layers):
  174. past.append(torch.rand(cross_attention_past_shape, dtype=float_type, device=device))
  175. else:
  176. past = None
  177. return T5DecoderInputs(decoder_input_ids, encoder_inputs.attention_mask, past)
  178. def to_list(self) -> list:
  179. input_list = [
  180. self.decoder_input_ids,
  181. self.encoder_attention_mask,
  182. ]
  183. if self.past_key_values:
  184. input_list.extend(self.past_key_values)
  185. return input_list
  186. def to_fp32(self):
  187. past = [p.to(dtype=torch.float32) for p in self.past_key_values] if self.past_key_values else None
  188. return T5DecoderInputs(
  189. self.decoder_input_ids.clone(),
  190. self.encoder_attention_mask.clone(),
  191. past,
  192. )
  193. class T5DecoderHelper:
  194. @staticmethod
  195. def export_onnx(
  196. decoder: T5Decoder | T5DecoderInit,
  197. device: torch.device,
  198. onnx_model_path: str,
  199. verbose: bool = True,
  200. use_external_data_format: bool = False,
  201. use_int32_inputs: bool = False,
  202. ):
  203. """Export decoder to ONNX
  204. Args:
  205. decoder (Union[T5Decoder, T5DecoderNoPastState]): decoder object
  206. device (torch.device): device of decoder object
  207. onnx_model_path (str): onnx path
  208. verbose (bool, optional): print verbose information. Defaults to True.
  209. use_external_data_format (bool, optional): use external data format or not. Defaults to False.
  210. use_int32_inputs (bool, optional): use int32 inputs
  211. """
  212. assert isinstance(decoder, (T5Decoder, T5DecoderInit))
  213. inputs = T5DecoderInputs.create_dummy(
  214. decoder.config,
  215. batch_size=2,
  216. encode_sequence_length=3,
  217. past_decode_sequence_length=5 if isinstance(decoder, T5Decoder) else 0,
  218. device=device,
  219. use_int32_inputs=use_int32_inputs,
  220. )
  221. input_list = inputs.to_list()
  222. num_decoder_layers = decoder.config.num_decoder_layers
  223. past_names = PastKeyValuesHelper.get_past_names(num_decoder_layers, present=False)
  224. present_names = PastKeyValuesHelper.get_past_names(num_decoder_layers, present=True)
  225. present_self_names = present_names[: 2 * num_decoder_layers]
  226. input_past_names = past_names if isinstance(decoder, T5Decoder) else []
  227. output_present_names = present_self_names if isinstance(decoder, T5Decoder) else present_names
  228. output_names = ["logits", *output_present_names]
  229. # Shape of input tensors (sequence_length==1):
  230. # input_ids: (batch_size, sequence_length)
  231. # encoder_attention_mask: (batch_size, encode_sequence_length)
  232. # past_self_*: (batch_size, num_heads, past_decode_sequence_length, head_size)
  233. # past_cross_*: (batch_size, num_heads, encode_sequence_length, head_size)
  234. # Shape of output tensors:
  235. # logits: (batch_size, sequence_length, vocab_size)
  236. # past_self_*: (batch_size, num_heads, past_decode_sequence_length + sequence_length, head_size)
  237. # past_cross_*: (batch_size, num_heads, encode_sequence_length, head_size)
  238. input_names = ["input_ids"]
  239. input_names.append("encoder_attention_mask")
  240. input_names.extend(input_past_names)
  241. dynamic_axes = {
  242. "input_ids": {
  243. 0: "batch_size",
  244. # 1: 'sequence_length'
  245. },
  246. "encoder_attention_mask": {0: "batch_size", 1: "encode_sequence_length"},
  247. "encoder_hidden_states": {0: "batch_size", 1: "encode_sequence_length"},
  248. "logits": {
  249. 0: "batch_size",
  250. # 1: 'sequence_length'
  251. },
  252. }
  253. for name in input_past_names:
  254. dynamic_axes[name] = {
  255. 0: "batch_size",
  256. 2: "past_decode_sequence_length" if "self" in name else "encode_sequence_length",
  257. }
  258. for name in output_present_names:
  259. if "cross" in name:
  260. dynamic_axes[name] = {0: "batch_size", 2: "encode_sequence_length"}
  261. else: # self attention past state
  262. if isinstance(decoder, T5Decoder):
  263. dynamic_axes[name] = {
  264. 0: "batch_size",
  265. 2: "past_decode_sequence_length + 1",
  266. }
  267. else:
  268. dynamic_axes[name] = {
  269. 0: "batch_size",
  270. # 2: 'sequence_length'
  271. }
  272. Path(onnx_model_path).parent.mkdir(parents=True, exist_ok=True)
  273. with tempfile.TemporaryDirectory() as tmp_dir_name:
  274. temp_onnx_model_path = os.path.join(tmp_dir_name, "decoder.onnx")
  275. Path(temp_onnx_model_path).parent.mkdir(parents=True, exist_ok=True)
  276. torch_onnx_export(
  277. decoder,
  278. args=tuple(input_list),
  279. f=temp_onnx_model_path if use_external_data_format else onnx_model_path,
  280. export_params=True,
  281. input_names=input_names,
  282. output_names=output_names,
  283. dynamic_axes=dynamic_axes,
  284. opset_version=12,
  285. do_constant_folding=True,
  286. use_external_data_format=use_external_data_format,
  287. verbose=verbose,
  288. )
  289. if use_external_data_format:
  290. model = onnx.load_model(temp_onnx_model_path, load_external_data=True)
  291. OnnxModel.save(
  292. model,
  293. onnx_model_path,
  294. save_as_external_data=True,
  295. all_tensors_to_one_file=True,
  296. )
  297. @staticmethod
  298. def onnxruntime_inference(ort_session, inputs: T5DecoderInputs):
  299. """Run inference of ONNX model."""
  300. logger.debug("start onnxruntime_inference")
  301. ort_inputs = {
  302. "input_ids": numpy.ascontiguousarray(inputs.decoder_input_ids.cpu().numpy()),
  303. "encoder_attention_mask": numpy.ascontiguousarray(inputs.encoder_attention_mask.cpu().numpy()),
  304. }
  305. if inputs.past_key_values:
  306. assert len(inputs.past_key_values) % 4 == 0
  307. num_layers = int(len(inputs.past_key_values) / 4)
  308. past_names = PastKeyValuesHelper.get_past_names(num_layers)
  309. for i, past_tensor in enumerate(inputs.past_key_values):
  310. ort_inputs[past_names[i]] = numpy.ascontiguousarray(past_tensor.cpu().numpy())
  311. ort_outputs = ort_session.run(None, ort_inputs)
  312. return ort_outputs
  313. @staticmethod
  314. def verify_onnx(
  315. model: T5Decoder | T5DecoderInit,
  316. ort_session: InferenceSession,
  317. device: torch.device,
  318. use_int32_inputs: bool,
  319. max_cases: int = 4,
  320. ):
  321. """Compare the result from PyTorch and OnnxRuntime to verify the ONNX model is good."""
  322. float16: bool = TypeHelper.get_input_type(ort_session, "past_key_self_0") == "tensor(float16)"
  323. test_cases = [(4, 11, 3), (1, 2, 5), (3, 1, 1), (8, 5, 2)]
  324. test_cases_max_diff = []
  325. for (
  326. batch_size,
  327. encode_sequence_length,
  328. past_decode_sequence_length,
  329. ) in test_cases[:max_cases]:
  330. if isinstance(model, T5DecoderInit):
  331. past_decode_sequence_length = 0 # noqa: PLW2901
  332. inputs = T5DecoderInputs.create_dummy(
  333. model.config,
  334. batch_size,
  335. encode_sequence_length,
  336. past_decode_sequence_length,
  337. device=device,
  338. float16=float16,
  339. use_int32_inputs=use_int32_inputs,
  340. )
  341. # We use fp32 PyTroch model as baseline even when ONNX model is fp16
  342. input_list = inputs.to_fp32().to_list()
  343. # Run inference of PyTorch model
  344. with torch.no_grad():
  345. torch_outputs = model(*input_list)
  346. ort_outputs = T5DecoderHelper.onnxruntime_inference(ort_session, inputs)
  347. num_decoder_layers = model.config.num_decoder_layers
  348. max_diff = numpy.amax(numpy.abs(torch_outputs[0].cpu().numpy() - ort_outputs[0]))
  349. max_diff_all = max_diff
  350. logger.debug(f"logits max_diff={max_diff}")
  351. for i in range(2 * num_decoder_layers):
  352. max_diff = numpy.amax(numpy.abs(torch_outputs[1][i].cpu().numpy() - ort_outputs[1 + i]))
  353. logger.debug(f"self attention past state {i} max_diff={max_diff}")
  354. max_diff_all = max(max_diff_all, max_diff)
  355. if isinstance(model, T5DecoderInit):
  356. for i in range(2 * num_decoder_layers):
  357. max_diff = numpy.amax(
  358. numpy.abs(torch_outputs[2][i].cpu().numpy() - ort_outputs[1 + 2 * num_decoder_layers + i])
  359. )
  360. logger.debug(f"cross attention past state {i} max_diff={max_diff}")
  361. max_diff_all = max(max_diff_all, max_diff)
  362. test_cases_max_diff.append(max_diff_all)
  363. logger.info(
  364. "batch_size=%s, encode_sequence_length=%s, past_decode_sequence_length=%s, max_diff=%s",
  365. batch_size,
  366. encode_sequence_length,
  367. past_decode_sequence_length,
  368. max_diff_all,
  369. )
  370. return max_diff_all