build_tensorrt.py 2.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596
  1. #!/usr/bin/env python3
  2. """
  3. Convert exported ONNX graphs into TensorRT engines.
  4. Requires a local TensorRT installation with the Python bindings available.
  5. """
  6. from __future__ import annotations
  7. import argparse
  8. from pathlib import Path
  9. import tensorrt as trt
  10. def build_engine(onnx_path: Path, engine_path: Path, fp16: bool = True, workspace_gb: int = 2):
  11. logger = trt.Logger(trt.Logger.WARNING)
  12. builder = trt.Builder(logger)
  13. network_flags = trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH
  14. network = builder.create_network(network_flags)
  15. parser = trt.OnnxParser(network, logger)
  16. onnx_bytes = onnx_path.read_bytes()
  17. if not parser.parse(onnx_bytes):
  18. for idx in range(parser.num_errors):
  19. print(parser.get_error(idx))
  20. raise RuntimeError(f"Failed to parse ONNX: {onnx_path}")
  21. config = builder.create_builder_config()
  22. config.max_workspace_size = workspace_gb * (1 << 30)
  23. if fp16 and builder.platform_has_fast_fp16:
  24. config.set_flag(trt.BuilderFlag.FP16)
  25. profile = builder.create_optimization_profile()
  26. for idx in range(network.num_inputs):
  27. tensor = network.get_input(idx)
  28. shape = tuple(tensor.shape)
  29. profile.set_shape(tensor.name, shape, shape, shape)
  30. config.add_optimization_profile(profile)
  31. engine = builder.build_engine(network, config)
  32. if engine is None:
  33. raise RuntimeError(f"Failed to build TensorRT engine for {onnx_path}")
  34. engine_path.parent.mkdir(parents=True, exist_ok=True)
  35. with engine_path.open("wb") as f:
  36. f.write(engine.serialize())
  37. print(f"[OK] TensorRT engine saved to: {engine_path}")
  38. def main():
  39. parser = argparse.ArgumentParser(description="Build TensorRT engines from ONNX.")
  40. parser.add_argument(
  41. "--onnx-dir",
  42. type=Path,
  43. default=Path("models"),
  44. help="Directory containing ONNX files.",
  45. )
  46. parser.add_argument(
  47. "--output-dir",
  48. type=Path,
  49. default=Path("models"),
  50. help="Directory to store TensorRT engines.",
  51. )
  52. parser.add_argument(
  53. "--fp16",
  54. action="store_true",
  55. help="Enable FP16 precision if supported by the device.",
  56. )
  57. parser.add_argument(
  58. "--workspace-gb",
  59. type=int,
  60. default=2,
  61. help="TensorRT workspace size in GB.",
  62. )
  63. args = parser.parse_args()
  64. onnx_files = [
  65. args.onnx_dir / "superpoint.onnx",
  66. args.onnx_dir / "lightglue.onnx",
  67. ]
  68. for onnx_file in onnx_files:
  69. if not onnx_file.exists():
  70. raise FileNotFoundError(f"Missing ONNX file: {onnx_file}")
  71. for onnx_file in onnx_files:
  72. engine_file = args.output_dir / (onnx_file.stem + ".plan")
  73. build_engine(onnx_file, engine_file, fp16=args.fp16, workspace_gb=args.workspace_gb)
# Run the command-line entry point only when this file is executed as a
# script, not when it is imported as a module.
if __name__ == "__main__":
    main()