optimize_qdq_model.py 1.2 KB

12345678910111213141516171819202122232425262728293031323334353637
#!/usr/bin/env python3
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.
import argparse
import os
import pathlib

import onnx
  8. def optimize_qdq_model():
  9. parser = argparse.ArgumentParser(
  10. os.path.basename(__file__),
  11. description="Update a QDQ format ONNX model to ensure optimal performance when executed using ONNX Runtime.",
  12. )
  13. parser.add_argument("input_model", type=pathlib.Path, help="Provide path to ONNX model to update.")
  14. parser.add_argument("output_model", type=pathlib.Path, help="Provide path to write updated ONNX model to.")
  15. args = parser.parse_args()
  16. model = onnx.load(str(args.input_model.resolve(strict=True)))
  17. # run QDQ model optimizations here
  18. # Originally, the fixing up of DQ nodes with multiple consumers was implemented as one such optimization.
  19. # That was moved to an ORT graph transformer.
  20. print("As of ORT 1.15, the fixing up of DQ nodes with multiple consumers is done by an ORT graph transformer.")
  21. # There are no optimizations being run currently but we expect that there may be in the future.
  22. onnx.save(model, str(args.output_model.resolve()))
  23. if __name__ == "__main__":
  24. optimize_qdq_model()