| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180 |
- #!/usr/bin/env python3
- """
- 将 LightGlue 和 SuperPoint 模型转换为 TensorRT 引擎
- 使用方法:
- python convert_to_tensorrt.py --max_keypoints 128 --precision fp16
- """
- import torch
- import torch.onnx
- import argparse
- import os
- from pathlib import Path
- from lightglue import LightGlue, SuperPoint
class SuperPointWrapper(torch.nn.Module):
    """Adapt a dict-in/dict-out SuperPoint to a tensor-in/tuple-out module for ONNX export.

    ``torch.onnx.export`` works on plain tensors, and the order of the tuple
    returned here must line up with the ``output_names`` given to the exporter:
    ``("keypoints", "descriptors", "scores")``.
    """

    def __init__(self, model):
        super().__init__()
        self.model = model

    def forward(self, image):
        out = self.model({"image": image})
        # NOTE(review): LightGlue's SuperPoint is believed to name its score
        # tensor "keypoint_scores"; "scores" is accepted as a fallback.
        # Confirm against the installed lightglue version.
        scores = out["keypoint_scores"] if "keypoint_scores" in out else out["scores"]
        # Return an explicit tuple instead of the dict so the output order is
        # fixed here rather than by the model's dict key order.
        return out["keypoints"], out["descriptors"], scores


def convert_superpoint_to_onnx(model, output_path, input_shape=(1, 1, 480, 640), device="cuda"):
    """Export a SuperPoint model to an ONNX file.

    Args:
        model: SuperPoint module whose forward takes ``{"image": tensor}`` and
            returns a dict of tensors.
        output_path: Destination path of the ``.onnx`` file.
        input_shape: Shape of the dummy image used for tracing, (N, C, H, W).
        device: Device on which the dummy input is created; should match the
            device ``model`` lives on.

    Returns:
        True if the export succeeded, False otherwise (the error is printed).
    """
    print("Converting SuperPoint to ONNX...")
    model.eval()

    # Tracing records data-dependent ops (thresholding, top-k) for this concrete
    # input, so use image-like values in [0, 1) rather than a standard normal.
    # NOTE(review): assumes the model is fed images normalized to [0, 1].
    dummy_input = torch.rand(*input_shape, device=device)

    wrapped_model = SuperPointWrapper(model).eval()

    try:
        torch.onnx.export(
            wrapped_model,
            dummy_input,
            output_path,
            input_names=["image"],
            output_names=["keypoints", "descriptors", "scores"],
            dynamic_axes={
                "image": {0: "batch_size", 2: "height", 3: "width"},
            },
            opset_version=13,
            do_constant_folding=True,
            verbose=False,
        )
    except Exception as e:  # conversion boundary: report and let the caller decide
        print(f"✗ Failed to convert SuperPoint: {e}")
        return False

    print(f"✓ SuperPoint ONNX model saved to {output_path}")
    return True
def compute_profile_shapes(dims, input_shape, max_batch_size=1):
    """Return ``(min_shape, opt_shape, max_shape)`` for a TensorRT optimization profile.

    Static dimensions (``>= 0``) are kept unchanged. Dynamic dimensions
    (``< 0``) are filled from ``input_shape``. A dynamic leading dimension is
    treated as the batch axis and may range from 1 to
    ``max(max_batch_size, input_shape[0])``. Every other dynamic dimension is
    pinned to its ``input_shape`` value (min == opt == max). The exported graph
    was traced at a single resolution, so other sizes are not assumed to work.

    Args:
        dims: Input dimensions reported by the parsed network; -1 marks a
            dynamic axis.
        input_shape: Concrete shape with the same rank as ``dims`` and
            positive entries.
        max_batch_size: Upper bound for a dynamic leading (batch) dimension.

    Returns:
        Three tuples of ints: the min, opt and max shapes.

    Raises:
        ValueError: If ``input_shape`` and ``dims`` differ in rank.
    """
    dims = [int(d) for d in dims]
    shape = [int(s) for s in input_shape]
    if len(dims) != len(shape):
        raise ValueError(
            f"input_shape {tuple(shape)} has rank {len(shape)}, "
            f"but the network input has rank {len(dims)}"
        )
    opt = [s if d < 0 else d for d, s in zip(dims, shape)]
    lo, hi = list(opt), list(opt)
    if dims and dims[0] < 0:
        lo[0] = 1
        hi[0] = max(int(max_batch_size), opt[0])
    return tuple(lo), tuple(opt), tuple(hi)


def convert_onnx_to_tensorrt(onnx_path, engine_path, precision="fp16", max_batch_size=1, input_shape=None):
    """Build a serialized TensorRT engine from an ONNX model and write it to disk.

    Args:
        onnx_path: Path of the ONNX model to parse.
        engine_path: Destination path of the serialized engine.
        precision: "fp32", "fp16" or "int8". FP16/INT8 are enabled only when
            the platform reports fast support for them; otherwise FP32 is used.
        max_batch_size: Upper bound of a dynamic batch axis in the
            optimization profile.
        input_shape: Concrete shape used to fill the dynamic dimensions of the
            network inputs in the optimization profile. Required when the ONNX
            model has dynamic dimensions.

    Returns:
        True on success, False on any failure (the reason is printed).
    """
    try:
        import tensorrt as trt
    except ImportError:
        print("✗ TensorRT not installed. Please install: pip install nvidia-tensorrt")
        return False

    print(f"Converting {onnx_path} to TensorRT engine...")
    TRT_LOGGER = trt.Logger(trt.Logger.WARNING)

    builder = trt.Builder(TRT_LOGGER)
    network = builder.create_network(1 << int(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH))
    parser = trt.OnnxParser(network, TRT_LOGGER)

    # Parse the ONNX file into the network definition.
    with open(onnx_path, "rb") as onnx_file:
        if not parser.parse(onnx_file.read()):
            print("✗ Failed to parse ONNX file:")
            for idx in range(parser.num_errors):
                print(f" {parser.get_error(idx)}")
            return False

    # Configure the builder.
    config = builder.create_builder_config()
    config.set_memory_pool_limit(trt.MemoryPoolType.WORKSPACE, 1 << 30)  # 1 GiB

    # A network whose inputs have dynamic (-1) dimensions cannot be built
    # without an optimization profile that gives those dimensions concrete ranges.
    dynamic_inputs = [
        network.get_input(i)
        for i in range(network.num_inputs)
        if any(d < 0 for d in network.get_input(i).shape)
    ]
    if dynamic_inputs:
        if input_shape is None:
            print("✗ ONNX model has dynamic input dimensions; "
                  "pass input_shape to build an optimization profile")
            return False
        profile = builder.create_optimization_profile()
        for tensor in dynamic_inputs:
            try:
                min_shape, opt_shape, max_shape = compute_profile_shapes(
                    tensor.shape, input_shape, max_batch_size
                )
            except ValueError as e:
                print(f"✗ {e}")
                return False
            profile.set_shape(tensor.name, min_shape, opt_shape, max_shape)
            print(f" Profile for '{tensor.name}': min={min_shape} opt={opt_shape} max={max_shape}")
        config.add_optimization_profile(profile)

    # Select the precision.
    if precision == "fp16":
        if builder.platform_has_fast_fp16:
            config.set_flag(trt.BuilderFlag.FP16)
            print(" Using FP16 precision")
        else:
            print(" Warning: FP16 not supported on this platform, using FP32")
    elif precision == "int8":
        if builder.platform_has_fast_int8:
            config.set_flag(trt.BuilderFlag.INT8)
            print(" Using INT8 precision")
            # No calibrator or per-tensor dynamic ranges are configured here.
            print(" Warning: no INT8 calibrator configured; the build may fail or lose accuracy")
        else:
            print(" Warning: INT8 not supported on this platform, using FP32")

    # Build the engine. build_serialized_network replaces the deprecated
    # build_engine, which was removed in TensorRT 10.
    print(" Building TensorRT engine... This may take several minutes...")
    try:
        serialized_engine = builder.build_serialized_network(network, config)
    except Exception as e:  # builder boundary: report and let the caller decide
        print(f"✗ Failed to build engine: {e}")
        return False

    if serialized_engine is None:
        print("✗ Failed to build engine")
        return False

    # Save the engine.
    with open(engine_path, "wb") as engine_file:
        engine_file.write(serialized_engine)

    print(f"✓ TensorRT engine saved to {engine_path}")
    return True


def main() -> int:
    """Command-line entry point: export SuperPoint to ONNX, then build a TensorRT engine.

    Returns:
        Process exit status: 0 on success, 1 if a conversion step failed.
    """
    parser = argparse.ArgumentParser(description="Convert LightGlue models to TensorRT")
    parser.add_argument("--max_keypoints", type=int, default=128, help="Maximum number of keypoints")
    parser.add_argument("--precision", type=str, default="fp16", choices=["fp32", "fp16", "int8"],
                        help="TensorRT precision")
    parser.add_argument("--input_shape", type=int, nargs=2, default=[480, 640],
                        help="Input image shape (height, width)")
    parser.add_argument("--output_dir", type=str, default="./models",
                        help="Output directory for models")
    parser.add_argument("--skip_onnx", action="store_true",
                        help="Skip ONNX conversion if ONNX file exists")
    parser.add_argument("--skip_trt", action="store_true",
                        help="Skip TensorRT conversion if engine file exists")

    args = parser.parse_args()

    device = "cuda" if torch.cuda.is_available() else "cpu"
    if device == "cpu":
        print("Warning: CUDA not available, conversion may fail")

    # Create the output directory.
    output_dir = Path(args.output_dir)
    output_dir.mkdir(parents=True, exist_ok=True)

    onnx_path = output_dir / "superpoint.onnx"
    engine_path = output_dir / f"superpoint_{args.precision}.engine"
    # (N, C, H, W): one single-channel image at the requested resolution. The
    # same shape is used for tracing and for the TensorRT optimization profile.
    input_shape = (1, 1, args.input_shape[0], args.input_shape[1])

    # Export SuperPoint to ONNX unless skipping was requested and the file exists.
    if args.skip_onnx and onnx_path.exists():
        print(f"✓ ONNX file already exists: {onnx_path}")
    else:
        # Load the model only when it is actually needed for the export.
        print("Loading SuperPoint model...")
        extractor = SuperPoint(
            max_num_keypoints=args.max_keypoints,
            detection_threshold=0.01,
            nms_radius=4,
        ).eval().to(device)
        if not convert_superpoint_to_onnx(extractor, str(onnx_path), input_shape, device):
            print("Failed to convert SuperPoint to ONNX")
            return 1

    # Build the TensorRT engine unless skipping was requested and it exists.
    if args.skip_trt and engine_path.exists():
        print(f"✓ TensorRT engine already exists: {engine_path}")
    elif not convert_onnx_to_tensorrt(str(onnx_path), str(engine_path), args.precision,
                                      input_shape=input_shape):
        print("Failed to convert ONNX to TensorRT")
        return 1

    print("\n" + "="*60)
    print("Conversion completed!")
    print(f"ONNX model: {onnx_path}")
    print(f"TensorRT engine: {engine_path}")
    print("="*60)
    print("\nNext steps:")
    print("1. Test the TensorRT engine with: python test_tensorrt.py")
    print("2. Integrate into your application")
    return 0
if __name__ == "__main__":
    # Propagate main()'s return value as the process exit status;
    # SystemExit(None) exits with status 0.
    raise SystemExit(main())
|