t5_helper.py 12 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302
# -------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.
# -------------------------------------------------------------------------
  5. import logging
  6. import os
  7. from pathlib import Path
  8. import torch
  9. from float16 import float_to_float16_max_diff
  10. from onnx_model import OnnxModel
  11. from optimizer import optimize_model
  12. from t5_decoder import T5Decoder, T5DecoderHelper
  13. from t5_encoder_decoder_init import T5EncoderDecoderInit, T5EncoderDecoderInitHelper
  14. from transformers import MT5ForConditionalGeneration, T5ForConditionalGeneration
  15. from onnxruntime import InferenceSession
  16. logger = logging.getLogger(__name__)
  17. PRETRAINED_T5_MODELS = ["t5-small", "t5-base", "t5-large", "t5-3b", "t5-11b"]
  18. PRETRAINED_MT5_MODELS = [
  19. "google/mt5-small",
  20. "google/mt5-base",
  21. "google/mt5-large",
  22. "google/mt5-xl",
  23. "google/mt5-xxl",
  24. ]
  25. class T5Helper:
  26. @staticmethod
  27. def get_onnx_path(
  28. output_dir: str,
  29. model_name_or_path: str,
  30. suffix: str = "",
  31. new_folder: bool = False,
  32. ) -> str:
  33. """Build onnx path
  34. Args:
  35. output_dir (str): output directory
  36. model_name_or_path (str): pretrained model name, or path to the model checkpoint
  37. suffix (str, optional): suffix like "_encoder" or "_decoder_fp16" will be appended to file name. Defaults to None.
  38. new_folder (bool, optional): create a new directory for the model. Defaults to False.
  39. Returns:
  40. str: path of onnx model
  41. """
  42. model_name = model_name_or_path
  43. if os.path.isdir(model_name_or_path):
  44. model_name = Path(model_name_or_path).parts[-1]
  45. else:
  46. model_name.split("/")[-1]
  47. model_name += suffix
  48. directory = os.path.join(output_dir, model_name) if new_folder else output_dir
  49. return os.path.join(directory, model_name + ".onnx")
  50. @staticmethod
  51. def load_model(
  52. model_name_or_path: str,
  53. cache_dir: str,
  54. device: torch.device,
  55. model_type: str = "t5",
  56. state_dict_path: str = "",
  57. encoder_decoder_init: bool = False,
  58. ) -> dict[str, T5EncoderDecoderInit | T5Decoder]:
  59. """Load model given a pretrained name or path, then build models for ONNX conversion.
  60. Args:
  61. model_name_or_path (str): pretrained model name or path
  62. cache_dir (str): cache directory
  63. device (torch.device): device to run the model
  64. model_type (str, optional): model type "t5" or "mt5"
  65. state_dict_path(str, optional): state dictionary path
  66. encoder_decoder_init (bool, optional): combine encoder and decoder kv cache initialization into one model.
  67. Returns:
  68. Dict[str, torch.nn.Module]: mapping from name to modules for ONNX conversion.
  69. """
  70. if model_type == "t5":
  71. model = T5ForConditionalGeneration.from_pretrained(model_name_or_path, cache_dir=cache_dir)
  72. elif model_type == "mt5":
  73. model = MT5ForConditionalGeneration.from_pretrained(model_name_or_path, cache_dir=cache_dir)
  74. else:
  75. raise ValueError("only support mode_type=t5 or mt5")
  76. if state_dict_path:
  77. model.load_state_dict(torch.load(state_dict_path))
  78. decoder = T5Decoder(model.decoder, model.lm_head, model.config)
  79. decoder.eval().to(device)
  80. encoder = T5EncoderDecoderInit(
  81. model.encoder,
  82. model.decoder,
  83. model.lm_head,
  84. model.config,
  85. decoder_start_token_id=None,
  86. output_cross_only=not encoder_decoder_init,
  87. )
  88. encoder_name = "encoder_decoder_init" if encoder_decoder_init else "encoder"
  89. return {encoder_name: encoder, "decoder": decoder}
  90. @staticmethod
  91. def export_onnx(
  92. model: T5Decoder | T5EncoderDecoderInit,
  93. device: torch.device,
  94. onnx_model_path: str,
  95. verbose: bool = True,
  96. use_external_data_format: bool = False,
  97. use_decoder_input_ids: bool = True,
  98. use_int32_inputs: bool = False,
  99. ):
  100. if isinstance(model, T5EncoderDecoderInit):
  101. T5EncoderDecoderInitHelper.export_onnx(
  102. model,
  103. device,
  104. onnx_model_path,
  105. use_decoder_input_ids,
  106. verbose,
  107. use_external_data_format,
  108. use_int32_inputs,
  109. )
  110. else:
  111. T5DecoderHelper.export_onnx(
  112. model,
  113. device,
  114. onnx_model_path,
  115. verbose,
  116. use_external_data_format,
  117. use_int32_inputs,
  118. )
  119. @staticmethod
  120. def auto_mixed_precision(
  121. onnx_model: OnnxModel,
  122. op_block_list: list[str] | None = None,
  123. force_fp16_logits: bool = False,
  124. use_symbolic_shape_infer: bool = True,
  125. ):
  126. """Convert model to mixed precision.
  127. It detects whether original model has fp16 precision weights, and set parameters for float16 conversion automatically.
  128. Args:
  129. onnx_model (OnnxModel): optimized ONNX model
  130. op_block_list (List[str], optional): operators need to run in fp32.
  131. force_fp16_logits (bool, optional): force logits and last MatMul node to be in float16. Defaults to False.
  132. use_symbolic_shape_infer (bool, optional): use symbolic shape inference to convert float to float16. Defaults to True.
  133. Returns:
  134. parameters(dict): a dictionary of parameters used in float16 conversion
  135. """
  136. if op_block_list is None:
  137. op_block_list = [
  138. "SimplifiedLayerNormalization",
  139. "SkipSimplifiedLayerNormalization",
  140. "Relu",
  141. "Add",
  142. ]
  143. op_full_set = {node.op_type for node in onnx_model.nodes()}
  144. fp32_op_set = set(op_block_list)
  145. fp16_op_set = op_full_set.difference(fp32_op_set)
  146. logger.info(f"fp32 op: {fp32_op_set} fp16 op: {fp16_op_set}")
  147. # logits is the first output
  148. logits_output_name = onnx_model.graph().output[0].name
  149. # We use the weight in last MatMul node to detect whether the model is stored with float16 weights from training.
  150. is_weight_fp16_precision = False
  151. output_name_to_node = onnx_model.output_name_to_node()
  152. assert logits_output_name in output_name_to_node
  153. node = output_name_to_node[logits_output_name]
  154. last_matmul_node = None
  155. if node.op_type == "MatMul":
  156. last_matmul_node = node
  157. logger.info(f"Found last MatMul node for logits: {node.name}")
  158. initializer = None
  159. for input in node.input:
  160. initializer = onnx_model.get_initializer(input)
  161. if initializer is not None:
  162. break
  163. # when the max difference of value after converting float to float16 is lower than a threshold (1e-6),
  164. # we can deduce that the weights are stored in float16 precision.
  165. max_diff = float_to_float16_max_diff(initializer)
  166. logger.debug(f"max diff of converting weights in last MatMul node {node.name}: {max_diff}")
  167. is_weight_fp16_precision = max_diff < 1e-6
  168. else:
  169. logger.warning(f"Failed to find MatMul node for logits. Found {node.op_type} of node {node.name}")
  170. keep_io_types = []
  171. node_block_list = []
  172. if (not is_weight_fp16_precision) and (last_matmul_node is not None) and not force_fp16_logits:
  173. # When original weight is float32 precision, keep logits and last MatMul in float32 could get better precision.
  174. keep_io_types = [logits_output_name]
  175. node_block_list = [last_matmul_node.name]
  176. if "Add" not in op_block_list:
  177. input_name_to_nodes = onnx_model.input_name_to_nodes()
  178. fp32_add = 0
  179. changed = True
  180. add_nodes = onnx_model.get_nodes_by_op_type("Add")
  181. while changed:
  182. changed = False
  183. for node in add_nodes:
  184. if node.name not in node_block_list:
  185. parents = onnx_model.get_parents(node, output_name_to_node)
  186. children = onnx_model.get_children(node, input_name_to_nodes)
  187. blocked_children = [
  188. child for child in children if child.op_type in op_block_list or child in node_block_list
  189. ]
  190. blocked_parents = [
  191. parent for parent in parents if parent.op_type in op_block_list or parent in node_block_list
  192. ]
  193. # If any child or parent is in fp32, we place the Add node to fp32.
  194. if (len(blocked_children) + len(blocked_parents)) > 0:
  195. node_block_list.append(node.name)
  196. fp32_add += 1
  197. changed = True
  198. fp16_add = len(add_nodes) - fp32_add
  199. logger.info(f"node counter of Add operator: fp32={fp32_add} fp16={fp16_add}")
  200. logger.info(f"node_block_list: {node_block_list}")
  201. parameters = {
  202. "keep_io_types": keep_io_types,
  203. "op_block_list": op_block_list,
  204. "node_block_list": node_block_list,
  205. "force_fp16_initializers": is_weight_fp16_precision,
  206. }
  207. logger.info(f"auto_mixed_precision parameters: {parameters}")
  208. if use_symbolic_shape_infer:
  209. onnx_model.convert_float_to_float16(use_symbolic_shape_infer=True, **parameters)
  210. else:
  211. # Workaround when symbolic shape inference fails.
  212. # Need enable shape_infer_before_optimization in convert_to_onnx.py as well.
  213. from float16 import convert_float_to_float16 # noqa: PLC0415
  214. convert_float_to_float16(
  215. onnx_model.model,
  216. disable_shape_infer=True,
  217. **parameters,
  218. )
  219. return parameters
  220. @staticmethod
  221. def optimize_onnx(
  222. onnx_model_path: str,
  223. optimized_model_path: str,
  224. is_float16: bool,
  225. num_attention_heads: int,
  226. hidden_size: int,
  227. use_external_data_format: bool = False,
  228. auto_mixed_precision: bool = True,
  229. use_gpu: bool = False,
  230. force_fp16_io: bool = False,
  231. ):
  232. """Optimize ONNX model with an option to convert it to use mixed precision."""
  233. from fusion_options import FusionOptions # noqa: PLC0415
  234. optimization_options = None
  235. if is_float16:
  236. optimization_options = FusionOptions("t5")
  237. # SkipLayerNormalization is faster but might bring accuracy drop since it uses fp16 accumulation.
  238. optimization_options.enable_skip_layer_norm = not auto_mixed_precision
  239. m = optimize_model(
  240. onnx_model_path,
  241. model_type="t5",
  242. num_heads=num_attention_heads,
  243. hidden_size=hidden_size,
  244. opt_level=0,
  245. optimization_options=optimization_options,
  246. use_gpu=use_gpu,
  247. )
  248. if is_float16:
  249. if auto_mixed_precision:
  250. T5Helper.auto_mixed_precision(m, force_fp16_logits=force_fp16_io)
  251. else:
  252. m.convert_model_float32_to_float16(cast_input_output=force_fp16_io)
  253. m.save_model_to_file(optimized_model_path, use_external_data_format, all_tensors_to_one_file=True)
  254. @staticmethod
  255. def verify_onnx(
  256. model: T5Decoder | T5EncoderDecoderInit,
  257. ort_session: InferenceSession,
  258. device: torch.device,
  259. use_int32_inputs: bool,
  260. ):
  261. """Compare the result from PyTorch and OnnxRuntime to verify the ONNX model is good."""
  262. if isinstance(model, T5EncoderDecoderInit):
  263. return T5EncoderDecoderInitHelper.verify_onnx(model, ort_session, device, use_int32_inputs)
  264. return T5DecoderHelper.verify_onnx(model, ort_session, device, use_int32_inputs)