#!/usr/bin/env python3 """ Convert exported ONNX graphs into TensorRT engines. Requires a local TensorRT installation with the Python bindings available. """ from __future__ import annotations import argparse from pathlib import Path import tensorrt as trt def build_engine(onnx_path: Path, engine_path: Path, fp16: bool = True, workspace_gb: int = 2): logger = trt.Logger(trt.Logger.WARNING) builder = trt.Builder(logger) network_flags = trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH network = builder.create_network(network_flags) parser = trt.OnnxParser(network, logger) onnx_bytes = onnx_path.read_bytes() if not parser.parse(onnx_bytes): for idx in range(parser.num_errors): print(parser.get_error(idx)) raise RuntimeError(f"Failed to parse ONNX: {onnx_path}") config = builder.create_builder_config() config.max_workspace_size = workspace_gb * (1 << 30) if fp16 and builder.platform_has_fast_fp16: config.set_flag(trt.BuilderFlag.FP16) profile = builder.create_optimization_profile() for idx in range(network.num_inputs): tensor = network.get_input(idx) shape = tuple(tensor.shape) profile.set_shape(tensor.name, shape, shape, shape) config.add_optimization_profile(profile) engine = builder.build_engine(network, config) if engine is None: raise RuntimeError(f"Failed to build TensorRT engine for {onnx_path}") engine_path.parent.mkdir(parents=True, exist_ok=True) with engine_path.open("wb") as f: f.write(engine.serialize()) print(f"[OK] TensorRT engine saved to: {engine_path}") def main(): parser = argparse.ArgumentParser(description="Build TensorRT engines from ONNX.") parser.add_argument( "--onnx-dir", type=Path, default=Path("models"), help="Directory containing ONNX files.", ) parser.add_argument( "--output-dir", type=Path, default=Path("models"), help="Directory to store TensorRT engines.", ) parser.add_argument( "--fp16", action="store_true", help="Enable FP16 precision if supported by the device.", ) parser.add_argument( "--workspace-gb", type=int, default=2, help="TensorRT workspace size in GB.", ) args = parser.parse_args() onnx_files = [ args.onnx_dir / "superpoint.onnx", args.onnx_dir / "lightglue.onnx", ] for onnx_file in onnx_files: if not onnx_file.exists(): raise FileNotFoundError(f"Missing ONNX file: {onnx_file}") for onnx_file in onnx_files: engine_file = args.output_dir / (onnx_file.stem + ".plan") build_engine(onnx_file, engine_file, fp16=args.fp16, workspace_gb=args.workspace_gb) if __name__ == "__main__": main()