| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596 |
- #!/usr/bin/env python3
- """
- Convert exported ONNX graphs into TensorRT engines.
- Requires a local TensorRT installation with the Python bindings available.
- """
- from __future__ import annotations
- import argparse
- from pathlib import Path
- import tensorrt as trt
def build_engine(
    onnx_path: Path,
    engine_path: Path,
    fp16: bool = True,
    workspace_gb: int = 2,
    *,
    input_shapes: dict[str, tuple[tuple[int, ...], ...]] | None = None,
) -> None:
    """Parse an ONNX model and write a serialized TensorRT engine for it.

    Newer builder APIs are used when the installed bindings provide them
    (``set_memory_pool_limit``, ``build_serialized_network``). Otherwise the
    older attributes are used, so the function runs on old and new TensorRT.

    Args:
        onnx_path: ONNX file to parse.
        engine_path: Destination of the serialized engine. Missing parent
            directories are created.
        fp16: Request FP16 kernels. The flag is only set when the builder
            reports fast FP16 support on this platform.
        workspace_gb: Builder workspace memory limit, in GiB.
        input_shapes: Optional mapping ``{input_name: (min, opt, max)}`` used
            for the optimization profile. An entry is required for every
            input whose parsed shape contains a dynamic (negative)
            dimension. Static inputs without an entry use their fixed shape
            for min/opt/max. Keys that do not name an input of this network
            are ignored, so one mapping can be shared across models.

    Raises:
        RuntimeError: The ONNX file cannot be parsed, a dynamic input has no
            entry in ``input_shapes``, or the engine build fails.
    """
    logger = trt.Logger(trt.Logger.WARNING)
    builder = trt.Builder(logger)

    # create_network() expects a bitmask of flags, not the enum member itself.
    # EXPLICIT_BATCH's value is 0, so passing the member directly yields
    # flags == 0 (implicit batch) on TensorRT 7/8, which the ONNX parser
    # rejects. Newer releases treat the flag as a no-op, so keep it when present.
    network_flags = 0
    explicit_batch = getattr(trt.NetworkDefinitionCreationFlag, "EXPLICIT_BATCH", None)
    if explicit_batch is not None:
        network_flags |= 1 << int(explicit_batch)
    network = builder.create_network(network_flags)
    parser = trt.OnnxParser(network, logger)

    if not parser.parse(onnx_path.read_bytes()):
        details = "\n".join(str(parser.get_error(i)) for i in range(parser.num_errors))
        raise RuntimeError(f"Failed to parse ONNX: {onnx_path}\n{details}")

    config = builder.create_builder_config()
    workspace_bytes = workspace_gb * (1 << 30)
    if hasattr(config, "set_memory_pool_limit"):
        # max_workspace_size is deprecated in newer releases and later removed.
        config.set_memory_pool_limit(trt.MemoryPoolType.WORKSPACE, workspace_bytes)
    else:
        config.max_workspace_size = workspace_bytes

    if fp16 and builder.platform_has_fast_fp16:
        config.set_flag(trt.BuilderFlag.FP16)

    overrides = input_shapes or {}
    profile = builder.create_optimization_profile()
    for idx in range(network.num_inputs):
        tensor = network.get_input(idx)
        shape = tuple(tensor.shape)
        if tensor.name in overrides:
            min_shape, opt_shape, max_shape = overrides[tensor.name]
        elif all(dim >= 0 for dim in shape):
            min_shape = opt_shape = max_shape = shape
        else:
            # A negative dimension is dynamic. A profile of (-1, ...) is
            # invalid, so a concrete range must come from the caller.
            raise RuntimeError(
                f"Input '{tensor.name}' of {onnx_path} has dynamic shape {shape}; "
                "pass input_shapes={name: (min, opt, max)} to build_engine()"
            )
        profile.set_shape(tensor.name, min_shape, opt_shape, max_shape)
    config.add_optimization_profile(profile)

    if hasattr(builder, "build_serialized_network"):
        serialized = builder.build_serialized_network(network, config)
    else:
        # Legacy path for bindings that only offer build_engine().
        engine = builder.build_engine(network, config)
        serialized = engine.serialize() if engine is not None else None
    if serialized is None:
        raise RuntimeError(f"Failed to build TensorRT engine for {onnx_path}")

    engine_path.parent.mkdir(parents=True, exist_ok=True)
    with engine_path.open("wb") as f:
        f.write(serialized)
    print(f"[OK] TensorRT engine saved to: {engine_path}")
def main():
    """Command-line entry point: turn each bundled ONNX model into a ``.plan`` engine.

    Every expected model (``superpoint.onnx`` and ``lightglue.onnx``) is
    checked for existence in ``--onnx-dir`` before any build starts.
    Engines are then built in that order and written to ``--output-dir``
    as ``<stem>.plan``.

    Raises:
        FileNotFoundError: The first expected ONNX file that does not exist.
    """
    cli = argparse.ArgumentParser(description="Build TensorRT engines from ONNX.")
    cli.add_argument("--onnx-dir", type=Path, default=Path("models"),
                     help="Directory containing ONNX files.")
    cli.add_argument("--output-dir", type=Path, default=Path("models"),
                     help="Directory to store TensorRT engines.")
    cli.add_argument("--fp16", action="store_true",
                     help="Enable FP16 precision if supported by the device.")
    cli.add_argument("--workspace-gb", type=int, default=2,
                     help="TensorRT workspace size in GB.")
    opts = cli.parse_args()

    model_stems = ("superpoint", "lightglue")
    sources = [opts.onnx_dir / f"{stem}.onnx" for stem in model_stems]

    # Fail fast, before any (slow) build, on the first missing input.
    absent = next((src for src in sources if not src.exists()), None)
    if absent is not None:
        raise FileNotFoundError(f"Missing ONNX file: {absent}")

    for src in sources:
        target = opts.output_dir / f"{src.stem}.plan"
        build_engine(src, target, fp16=opts.fp16, workspace_gb=opts.workspace_gb)
if __name__ == "__main__":
    # Run the CLI only when executed as a script, never on import.
    main()
|