convert_npz_to_onnx_adapter.py 1.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748
  1. # Copyright (c) Microsoft Corporation. All rights reserved.
  2. # Licensed under the MIT License.
  3. # This script converts .npz files to .onnx_adapter files
  4. import argparse
  5. import os
  6. import sys
  7. import numpy as np
  8. import onnxruntime as ort
  9. def get_args() -> argparse:
  10. parser = argparse.ArgumentParser()
  11. parser.add_argument("--npz_file_path", type=str, required=True)
  12. parser.add_argument("--output_file_path", type=str, required=True)
  13. parser.add_argument("--adapter_version", type=int, required=True)
  14. parser.add_argument("--model_version", type=int, required=True)
  15. return parser.parse_args()
  16. def export_lora_parameters(
  17. npz_file_path: os.PathLike, adapter_version: int, model_version: int, output_file_path: os.PathLike
  18. ):
  19. """The function converts lora parameters in npz to onnx_adapter format"""
  20. adapter_format = ort.AdapterFormat()
  21. adapter_format.set_adapter_version(adapter_version)
  22. adapter_format.set_model_version(model_version)
  23. name_to_ort_value = {}
  24. with np.load(npz_file_path) as data:
  25. for name, np_arr in data.items():
  26. ort_value = ort.OrtValue.ortvalue_from_numpy(np_arr)
  27. name_to_ort_value[name] = ort_value
  28. adapter_format.set_parameters(name_to_ort_value)
  29. adapter_format.export_adapter(output_file_path)
  30. def main() -> int:
  31. args = get_args()
  32. export_lora_parameters(args.npz_file_path, args.adapter_version, args.model_version, args.output_file_path)
  33. return 0
  34. if __name__ == "__main__":
  35. sys.exit(main())