#!/usr/bin/env python3
"""
Convert the SuperPoint keypoint extractor to a TensorRT engine (via ONNX).

Only SuperPoint is exported by this script; LightGlue is imported but not
converted here.

Usage:
    python convert_to_tensorrt.py --max_keypoints 128 --precision fp16
"""
import argparse
import os
import sys
from pathlib import Path

import torch
import torch.onnx

from lightglue import LightGlue, SuperPoint


def convert_superpoint_to_onnx(model, output_path, input_shape=(1, 1, 480, 640), device="cuda"):
    """Export a SuperPoint model to ONNX.

    Args:
        model: SuperPoint module. It is called with a dict ``{"image": tensor}``.
        output_path: Destination ``.onnx`` file path.
        input_shape: Shape of the dummy image used for tracing. The batch axis (0)
            and spatial axes (2, 3) are exported as dynamic.
        device: Device on which the dummy input is created. It should match the
            model's device.

    Returns:
        True if the export succeeded, False otherwise (the error is printed).
    """
    print("Converting SuperPoint to ONNX...")
    model.eval()

    # Dummy input used only to trace the graph.
    dummy_input = torch.randn(*input_shape, device=device)

    # SuperPoint takes a dict as input and returns a dict.
    # This wrapper exposes a plain tensor-in / tuple-out interface for the exporter.
    class SuperPointWrapper(torch.nn.Module):
        def __init__(self, model):
            super().__init__()
            self.model = model

        def forward(self, image):
            pred = self.model({"image": image})
            # Return the outputs in exactly the order of `output_names` below.
            # Exporting the raw dict flattens it in key order, which assigns
            # the wrong names to the scores/descriptors tensors.
            # NOTE(review): key names follow lightglue's SuperPoint output dict.
            return pred["keypoints"], pred["descriptors"], pred["keypoint_scores"]

    wrapped_model = SuperPointWrapper(model)

    try:
        torch.onnx.export(
            wrapped_model,
            dummy_input,
            output_path,
            input_names=["image"],
            output_names=["keypoints", "descriptors", "scores"],
            # NOTE(review): only the input is declared dynamic; output shapes
            # may be fixed to the traced values — confirm for your use case.
            dynamic_axes={
                "image": {0: "batch_size", 2: "height", 3: "width"},
            },
            opset_version=13,
            do_constant_folding=True,
            verbose=False,
        )
        print(f"✓ SuperPoint ONNX model saved to {output_path}")
        return True
    except Exception as e:
        print(f"✗ Failed to convert SuperPoint: {e}")
        return False


def _build_optimization_profile(builder, network, max_batch_size, input_shape):
    """Return an optimization profile covering all dynamic network inputs, or None if every input is static.

    For each dynamic (-1) axis:
      - axis 0 (batch) ranges over [1, max_batch_size], tuned for max_batch_size;
      - every other axis is pinned to ``input_shape[axis]``.

    Raises:
        ValueError: if a dynamic non-batch axis has no value in ``input_shape``.
    """
    profile = builder.create_optimization_profile()
    has_dynamic = False
    for i in range(network.num_inputs):
        tensor = network.get_input(i)
        dims = tuple(tensor.shape)
        if all(d >= 0 for d in dims):
            continue
        has_dynamic = True
        min_shape, opt_shape, max_shape = [], [], []
        for axis, dim in enumerate(dims):
            if dim >= 0:
                lo = opt = hi = dim
            elif axis == 0:
                lo, opt, hi = 1, max_batch_size, max_batch_size
            elif input_shape is not None and axis < len(input_shape):
                lo = opt = hi = int(input_shape[axis])
            else:
                raise ValueError(
                    f"input '{tensor.name}' has dynamic axis {axis} but no "
                    f"input_shape value was given to fix it"
                )
            min_shape.append(lo)
            opt_shape.append(opt)
            max_shape.append(hi)
        profile.set_shape(tensor.name, min_shape, opt_shape, max_shape)
    return profile if has_dynamic else None


def convert_onnx_to_tensorrt(onnx_path, engine_path, precision="fp16", max_batch_size=1, input_shape=None):
    """Build a serialized TensorRT engine from an ONNX model file.

    Args:
        onnx_path: Path of the ONNX model to parse.
        engine_path: Destination path for the serialized engine.
        precision: "fp32", "fp16" or "int8". FP16/INT8 are enabled only when the
            platform reports fast support; otherwise the build stays FP32.
        max_batch_size: Upper bound for a dynamic batch (first) axis in the
            optimization profile.
        input_shape: Optional full input shape (e.g. (B, C, H, W)). It pins dynamic
            non-batch axes in the optimization profile, and is required when
            the ONNX input has such axes.

    Returns:
        True on success, False on any failure (errors are printed).
    """
    try:
        import tensorrt as trt
    except ImportError:
        print("✗ TensorRT not installed. Please install: pip install nvidia-tensorrt")
        return False

    if not os.path.isfile(onnx_path):
        print(f"✗ ONNX file not found: {onnx_path}")
        return False

    print(f"Converting {onnx_path} to TensorRT engine...")

    trt_logger = trt.Logger(trt.Logger.WARNING)
    builder = trt.Builder(trt_logger)
    # Explicit-batch network definition, as used by the ONNX parser.
    network = builder.create_network(1 << int(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH))
    parser = trt.OnnxParser(network, trt_logger)

    # Parse the ONNX file and report every parser error on failure.
    with open(onnx_path, "rb") as onnx_file:
        if not parser.parse(onnx_file.read()):
            print("✗ Failed to parse ONNX file:")
            for idx in range(parser.num_errors):
                print(f" {parser.get_error(idx)}")
            return False

    # Builder configuration.
    config = builder.create_builder_config()
    config.set_memory_pool_limit(trt.MemoryPoolType.WORKSPACE, 1 << 30)  # 1 GiB

    # Inputs with dynamic axes (batch/height/width from the ONNX export)
    # cannot be built without an optimization profile.
    try:
        profile = _build_optimization_profile(builder, network, max_batch_size, input_shape)
    except ValueError as e:
        print(f"✗ {e}")
        return False
    if profile is not None:
        config.add_optimization_profile(profile)

    # Precision selection; unsupported reduced precision falls back to FP32.
    if precision == "fp16":
        if builder.platform_has_fast_fp16:
            config.set_flag(trt.BuilderFlag.FP16)
            print(" Using FP16 precision")
        else:
            print(" Warning: FP16 not supported on this platform, using FP32")
    elif precision == "int8":
        if builder.platform_has_fast_int8:
            config.set_flag(trt.BuilderFlag.INT8)
            print(" Using INT8 precision")
            # No calibrator or per-tensor dynamic ranges are configured here.
            print(" Warning: no INT8 calibrator is configured; TensorRT needs "
                  "calibration data or dynamic ranges for accurate INT8")
        else:
            print(" Warning: INT8 not supported on this platform, using FP32")

    # Build the engine. build_serialized_network replaces build_engine,
    # which is deprecated and removed in TensorRT 10.
    print(" Building TensorRT engine... This may take several minutes...")
    try:
        serialized_engine = builder.build_serialized_network(network, config)
    except Exception as e:
        print(f"✗ Failed to build engine: {e}")
        return False

    if serialized_engine is None:
        print("✗ Failed to build engine")
        return False

    # Save the serialized engine.
    with open(engine_path, "wb") as f:
        f.write(serialized_engine)

    print(f"✓ TensorRT engine saved to {engine_path}")
    return True


def main():
    """Parse CLI arguments, export SuperPoint to ONNX, then build a TensorRT engine.

    Returns:
        Process exit status: 0 on success, 1 if either conversion step fails.
    """
    parser = argparse.ArgumentParser(description="Convert LightGlue models to TensorRT")
    parser.add_argument("--max_keypoints", type=int, default=128,
                        help="Maximum number of keypoints")
    parser.add_argument("--precision", type=str, default="fp16",
                        choices=["fp32", "fp16", "int8"],
                        help="TensorRT precision")
    parser.add_argument("--input_shape", type=int, nargs=2, default=[480, 640],
                        help="Input image shape (height, width)")
    parser.add_argument("--output_dir", type=str, default="./models",
                        help="Output directory for models")
    parser.add_argument("--skip_onnx", action="store_true",
                        help="Skip ONNX conversion if ONNX file exists")
    parser.add_argument("--skip_trt", action="store_true",
                        help="Skip TensorRT conversion if engine file exists")

    args = parser.parse_args()

    device = "cuda" if torch.cuda.is_available() else "cpu"
    if device == "cpu":
        print("Warning: CUDA not available, conversion may fail")

    # Create the output directory.
    output_dir = Path(args.output_dir)
    output_dir.mkdir(parents=True, exist_ok=True)

    # Build the model.
    print("Loading SuperPoint model...")
    extractor = SuperPoint(
        max_num_keypoints=args.max_keypoints,
        detection_threshold=0.01,
        nms_radius=4,
    ).eval().to(device)

    onnx_path = output_dir / "superpoint.onnx"
    engine_path = output_dir / f"superpoint_{args.precision}.engine"
    # (batch, channel, height, width): one single-channel image.
    # This shape is shared by the ONNX trace and the TensorRT optimization profile.
    input_shape = (1, 1, args.input_shape[0], args.input_shape[1])

    # Step 1: ONNX export. --skip_onnx only skips when the file already exists.
    if args.skip_onnx and onnx_path.exists():
        print(f"✓ ONNX file already exists: {onnx_path}")
    elif not convert_superpoint_to_onnx(extractor, str(onnx_path), input_shape, device):
        print("Failed to convert SuperPoint to ONNX")
        return 1

    # Step 2: TensorRT build. --skip_trt only skips when the engine already exists.
    if args.skip_trt and engine_path.exists():
        print(f"✓ TensorRT engine already exists: {engine_path}")
    elif not convert_onnx_to_tensorrt(str(onnx_path), str(engine_path), args.precision,
                                      input_shape=input_shape):
        print("Failed to convert ONNX to TensorRT")
        return 1

    print("\n" + "=" * 60)
    print("Conversion completed!")
    print(f"ONNX model: {onnx_path}")
    print(f"TensorRT engine: {engine_path}")
    print("=" * 60)
    print("\nNext steps:")
    print("1. Test the TensorRT engine with: python test_tensorrt.py")
    print("2. Integrate into your application")
    return 0


if __name__ == "__main__":
    sys.exit(main())