convert_to_tensorrt.py 6.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180
  1. #!/usr/bin/env python3
  2. """
Convert LightGlue and SuperPoint models to TensorRT engines.
Usage:
    python convert_to_tensorrt.py --max_keypoints 128 --precision fp16
  6. """
  7. import torch
  8. import torch.onnx
  9. import argparse
  10. import os
  11. from pathlib import Path
  12. from lightglue import LightGlue, SuperPoint
  13. def convert_superpoint_to_onnx(model, output_path, input_shape=(1, 1, 480, 640), device="cuda"):
  14. """转换 SuperPoint 模型到 ONNX"""
  15. print(f"Converting SuperPoint to ONNX...")
  16. model.eval()
  17. # 创建虚拟输入
  18. dummy_input = torch.randn(*input_shape).to(device)
  19. # 注意:SuperPoint 的输入是一个字典
  20. # 我们需要创建一个包装器来处理这个
  21. class SuperPointWrapper(torch.nn.Module):
  22. def __init__(self, model):
  23. super().__init__()
  24. self.model = model
  25. def forward(self, image):
  26. return self.model({"image": image})
  27. wrapped_model = SuperPointWrapper(model)
  28. try:
  29. torch.onnx.export(
  30. wrapped_model,
  31. dummy_input,
  32. output_path,
  33. input_names=["image"],
  34. output_names=["keypoints", "descriptors", "scores"],
  35. dynamic_axes={
  36. "image": {0: "batch_size", 2: "height", 3: "width"},
  37. },
  38. opset_version=13,
  39. do_constant_folding=True,
  40. verbose=False,
  41. )
  42. print(f"✓ SuperPoint ONNX model saved to {output_path}")
  43. return True
  44. except Exception as e:
  45. print(f"✗ Failed to convert SuperPoint: {e}")
  46. return False
  47. def convert_onnx_to_tensorrt(onnx_path, engine_path, precision="fp16", max_batch_size=1):
  48. """将 ONNX 模型转换为 TensorRT 引擎"""
  49. try:
  50. import tensorrt as trt
  51. except ImportError:
  52. print("✗ TensorRT not installed. Please install: pip install nvidia-tensorrt")
  53. return False
  54. print(f"Converting {onnx_path} to TensorRT engine...")
  55. TRT_LOGGER = trt.Logger(trt.Logger.WARNING)
  56. builder = trt.Builder(TRT_LOGGER)
  57. network = builder.create_network(1 << int(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH))
  58. parser = trt.OnnxParser(network, TRT_LOGGER)
  59. # 解析 ONNX 文件
  60. with open(onnx_path, 'rb') as model:
  61. if not parser.parse(model.read()):
  62. print("✗ Failed to parse ONNX file:")
  63. for error in range(parser.num_errors):
  64. print(f" {parser.get_error(error)}")
  65. return False
  66. # 配置构建器
  67. config = builder.create_builder_config()
  68. config.set_memory_pool_limit(trt.MemoryPoolType.WORKSPACE, 1 << 30) # 1GB
  69. # 设置精度
  70. if precision == "fp16":
  71. if builder.platform_has_fast_fp16:
  72. config.set_flag(trt.BuilderFlag.FP16)
  73. print(" Using FP16 precision")
  74. else:
  75. print(" Warning: FP16 not supported on this platform, using FP32")
  76. elif precision == "int8":
  77. if builder.platform_has_fast_int8:
  78. config.set_flag(trt.BuilderFlag.INT8)
  79. print(" Using INT8 precision")
  80. else:
  81. print(" Warning: INT8 not supported on this platform, using FP32")
  82. # 构建引擎
  83. print(" Building TensorRT engine... This may take several minutes...")
  84. try:
  85. engine = builder.build_engine(network, config)
  86. except Exception as e:
  87. print(f"✗ Failed to build engine: {e}")
  88. return False
  89. if engine is None:
  90. print("✗ Failed to build engine")
  91. return False
  92. # 保存引擎
  93. with open(engine_path, 'wb') as f:
  94. f.write(engine.serialize())
  95. print(f"✓ TensorRT engine saved to {engine_path}")
  96. return True
  97. def main():
  98. parser = argparse.ArgumentParser(description="Convert LightGlue models to TensorRT")
  99. parser.add_argument("--max_keypoints", type=int, default=128, help="Maximum number of keypoints")
  100. parser.add_argument("--precision", type=str, default="fp16", choices=["fp32", "fp16", "int8"],
  101. help="TensorRT precision")
  102. parser.add_argument("--input_shape", type=int, nargs=2, default=[480, 640],
  103. help="Input image shape (height, width)")
  104. parser.add_argument("--output_dir", type=str, default="./models",
  105. help="Output directory for models")
  106. parser.add_argument("--skip_onnx", action="store_true",
  107. help="Skip ONNX conversion if ONNX file exists")
  108. parser.add_argument("--skip_trt", action="store_true",
  109. help="Skip TensorRT conversion if engine file exists")
  110. args = parser.parse_args()
  111. device = "cuda" if torch.cuda.is_available() else "cpu"
  112. if device == "cpu":
  113. print("Warning: CUDA not available, conversion may fail")
  114. # 创建输出目录
  115. output_dir = Path(args.output_dir)
  116. output_dir.mkdir(parents=True, exist_ok=True)
  117. # 创建模型
  118. print("Loading SuperPoint model...")
  119. extractor = SuperPoint(
  120. max_num_keypoints=args.max_keypoints,
  121. detection_threshold=0.01,
  122. nms_radius=4,
  123. ).eval().to(device)
  124. # 转换 SuperPoint
  125. onnx_path = output_dir / "superpoint.onnx"
  126. engine_path = output_dir / f"superpoint_{args.precision}.engine"
  127. # 转换为 ONNX
  128. if not args.skip_onnx or not onnx_path.exists():
  129. input_shape = (1, 1, args.input_shape[0], args.input_shape[1])
  130. if not convert_superpoint_to_onnx(extractor, str(onnx_path), input_shape, device):
  131. print("Failed to convert SuperPoint to ONNX")
  132. return
  133. else:
  134. print(f"✓ ONNX file already exists: {onnx_path}")
  135. # 转换为 TensorRT
  136. if not args.skip_trt or not engine_path.exists():
  137. if not convert_onnx_to_tensorrt(str(onnx_path), str(engine_path), args.precision):
  138. print("Failed to convert ONNX to TensorRT")
  139. return
  140. else:
  141. print(f"✓ TensorRT engine already exists: {engine_path}")
  142. print("\n" + "="*60)
  143. print("Conversion completed!")
  144. print(f"ONNX model: {onnx_path}")
  145. print(f"TensorRT engine: {engine_path}")
  146. print("="*60)
  147. print("\nNext steps:")
  148. print("1. Test the TensorRT engine with: python test_tensorrt.py")
  149. print("2. Integrate into your application")
# Run the conversion only when executed as a script, not when imported.
if __name__ == "__main__":
    main()