shape_inference.py 9.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209
  1. # --------------------------------------------------------------------------
  2. # Copyright (c) Microsoft, Intel Corporation. All rights reserved.
  3. # Licensed under the MIT License. See License.txt in the project root for
  4. # license information.
  5. # --------------------------------------------------------------------------
  6. import logging
  7. import tempfile
  8. import traceback
  9. from pathlib import Path
  10. import onnx
  11. import onnxruntime
  12. from onnxruntime.tools.symbolic_shape_infer import SymbolicShapeInference
  13. from onnxruntime.transformers.onnx_utils import extract_raw_data_from_model, has_external_data
  14. from .fusions import ReplaceUpsampleWithResize
  15. from .onnx_model import ONNXModel
  16. from .quant_utils import add_pre_process_metadata, save_and_reload_model_with_shape_infer
  17. logger = logging.getLogger(__name__)
  18. def quant_pre_process(
  19. input_model: str | Path | onnx.ModelProto | None = None,
  20. output_model_path: str | Path | None = None,
  21. skip_optimization: bool = False,
  22. skip_onnx_shape: bool = False,
  23. skip_symbolic_shape: bool = False,
  24. auto_merge: bool = False,
  25. int_max: int = 2**31 - 1,
  26. guess_output_rank: bool = False,
  27. verbose: int = 0,
  28. save_as_external_data: bool = False,
  29. all_tensors_to_one_file: bool = False,
  30. external_data_location: str | None = None,
  31. external_data_size_threshold: int = 1024,
  32. **deprecated_kwargs,
  33. ) -> None:
  34. """Shape inference and model optimization, in preparation for quantization.
  35. Args:
  36. input_model: Path to the input model file or ModelProto
  37. output_model_path: Path to the output model file
  38. skip_optimization: Skip model optimization step if true. This may result in ONNX shape
  39. inference failure for some models.
  40. skip_onnx_shape: Skip ONNX shape inference. Symbolic shape inference is most effective
  41. with transformer based models. Skipping all shape inferences may
  42. reduce the effectiveness of quantization, as a tensor with unknown
  43. shape can not be quantized.
  44. skip_symbolic_shape: Skip symbolic shape inference. Symbolic shape inference is most
  45. effective with transformer based models. Skipping all shape
  46. inferences may reduce the effectiveness of quantization, as a tensor
  47. with unknown shape can not be quantized.
  48. auto_merge: For symbolic shape inference, automatically merge symbolic dims when
  49. conflict happens.
  50. int_max: For symbolic shape inference, specify the maximum value for integer to be
  51. treated as boundless for ops like slice
  52. guess_output_rank: Guess output rank to be the same as input 0 for unknown ops
  53. verbose: Logs detailed info of inference, 0: turn off, 1: warnings, 3: detailed
  54. save_as_external_data: Saving an ONNX model to external data
  55. all_tensors_to_one_file: Saving all the external data to one file
  56. external_data_location: The file location to save the external file
  57. external_data_size_threshold: The size threshold for external data
  58. """
  59. if input_model is None:
  60. input_model = deprecated_kwargs.pop("input_model_path", None)
  61. assert input_model is not None
  62. assert output_model_path is not None, "output_model_path is required."
  63. with tempfile.TemporaryDirectory(prefix="pre.quant.") as quant_tmp_dir:
  64. temp_path = Path(quant_tmp_dir)
  65. model = None
  66. if not skip_symbolic_shape:
  67. logger.info("Performing symbolic shape inference...")
  68. loaded_model = input_model if isinstance(input_model, onnx.ModelProto) else onnx.load(input_model)
  69. model = SymbolicShapeInference.infer_shapes(
  70. loaded_model,
  71. int_max,
  72. auto_merge,
  73. guess_output_rank,
  74. verbose,
  75. )
  76. # Since Upsample is deprecated after opset v10, and the model's opset will
  77. # be upgraded to at least v11 during quantization, we need to replace Upsample
  78. # with Resize first to avoid generating an invalid model.
  79. if model:
  80. ai_onnx_domain = [opset for opset in model.opset_import if not opset.domain or opset.domain == "ai.onnx"]
  81. if len(ai_onnx_domain) == 1:
  82. opset_version = ai_onnx_domain[0].version
  83. if opset_version < 10:
  84. ReplaceUpsampleWithResize(ONNXModel(model), opset_version).apply()
  85. model.opset_import.remove(ai_onnx_domain[0])
  86. opset_version = 11
  87. model.opset_import.extend([onnx.helper.make_opsetid("", opset_version)])
  88. model = onnx.version_converter.convert_version(model, opset_version)
  89. model = save_and_reload_model_with_shape_infer(model)
  90. if not skip_optimization:
  91. # Use ORT optimizers (native code) to optimize model
  92. if not skip_symbolic_shape:
  93. # Need to save the inferenced model to file so as to run the optimizer
  94. input_model = str(temp_path / "symbolic_shape_inferred.onnx")
  95. if save_as_external_data:
  96. onnx.save_model(
  97. model,
  98. input_model,
  99. save_as_external_data=True,
  100. all_tensors_to_one_file=all_tensors_to_one_file,
  101. size_threshold=external_data_size_threshold,
  102. convert_attribute=False,
  103. )
  104. else:
  105. onnx.save(model, input_model)
  106. model = None
  107. opt_model_path = str(temp_path / "optimized.onnx")
  108. try:
  109. sess_option = onnxruntime.SessionOptions()
  110. sess_option.optimized_model_filepath = opt_model_path
  111. sess_option.graph_optimization_level = onnxruntime.GraphOptimizationLevel.ORT_ENABLE_BASIC
  112. # For large model, extract external data from model and add to session options
  113. if isinstance(input_model, onnx.ModelProto):
  114. if has_external_data(input_model):
  115. raise ValueError(
  116. "ModelProto has external data not loaded into memory, ORT cannot create session. "
  117. "Please load external data before calling this function. "
  118. "See https://onnx.ai/onnx/repo-docs/ExternalData.html for more information."
  119. )
  120. external_names, external_values = extract_raw_data_from_model(input_model)
  121. sess_option.add_external_initializers(list(external_names), list(external_values))
  122. input_model = input_model.SerializeToString()
  123. # the saved optimized model otherwise points to the original external data file name
  124. # which is not available relative to the optimized model file
  125. elif skip_symbolic_shape and save_as_external_data:
  126. sess_option.add_session_config_entry(
  127. "session.optimized_model_external_initializers_file_name", "optimized.onnx.data"
  128. )
  129. sess = onnxruntime.InferenceSession(input_model, sess_option, providers=["CPUExecutionProvider"])
  130. # Close the session to avoid the cleanup error on Windows for temp folders
  131. # https://github.com/microsoft/onnxruntime/issues/17627
  132. del sess
  133. except Exception:
  134. logger.error(
  135. "ONNX Runtime Model Optimization Failed! Consider rerun with option `--skip_optimization'."
  136. )
  137. logger.error(traceback.format_exc())
  138. input_model = opt_model_path
  139. if not skip_onnx_shape:
  140. # ONNX shape inference.
  141. # According to docs, infer_shapes_path should be used for 2G+ models.
  142. # If the skip optimization is specified, we could be dealing with a
  143. # large model. So be on the safe side, save the model
  144. if model is not None:
  145. input_model = str(temp_path / "symbolic_shape_inferred.onnx")
  146. if save_as_external_data:
  147. onnx.save_model(
  148. model,
  149. input_model,
  150. save_as_external_data=True,
  151. all_tensors_to_one_file=all_tensors_to_one_file,
  152. size_threshold=external_data_size_threshold,
  153. convert_attribute=False,
  154. )
  155. else:
  156. onnx.save(model, input_model)
  157. model = None
  158. if isinstance(input_model, onnx.ModelProto):
  159. input_model = str(Path(quant_tmp_dir) / "model_input.onnx")
  160. onnx.save_model(
  161. model,
  162. input_model,
  163. save_as_external_data=True,
  164. all_tensors_to_one_file=all_tensors_to_one_file,
  165. size_threshold=external_data_size_threshold,
  166. convert_attribute=False,
  167. )
  168. inferred_model_path = str(temp_path / "onnx_shape_inferred.onnx")
  169. onnx.shape_inference.infer_shapes_path(input_model, inferred_model_path)
  170. model = onnx.load(inferred_model_path)
  171. if model is None:
  172. model = input_model if isinstance(input_model, onnx.ModelProto) else onnx.load(input_model)
  173. add_pre_process_metadata(model)
  174. if save_as_external_data:
  175. onnx.save_model(
  176. model,
  177. output_model_path,
  178. save_as_external_data=True,
  179. all_tensors_to_one_file=all_tensors_to_one_file,
  180. location=external_data_location,
  181. size_threshold=external_data_size_threshold,
  182. convert_attribute=False,
  183. )
  184. else:
  185. onnx.save(model, output_model_path)