| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215 |
- #!/usr/bin/env python3
- """
- TensorRT 推理包装器
- 支持两种方式:
- 1. Torch-TensorRT(推荐,更简单)
- 2. 原生 TensorRT(需要 ONNX 转换)
- """
- import torch
- import torch.nn as nn
- from typing import Dict, Optional
- import os
class TensorRTWrapper:
    """Inference wrapper that tries to run a PyTorch model through TensorRT.

    On construction the wrapper attempts to obtain a compiled model: it
    loads a cached TorchScript file when one exists, otherwise it compiles
    the model with Torch-TensorRT.  Whenever no compiled model is available
    (package missing, compilation failed, ONNX route requested), calls are
    forwarded to the original PyTorch model instead.
    """

    def __init__(self, model, model_name="model", precision="fp16", use_torch_tensorrt=True):
        """Store the configuration and immediately try to compile the model.

        Args:
            model: PyTorch model to accelerate.
                NOTE(review): presumably expected to already be on the GPU
                and in eval mode, since the example inputs are CUDA tensors
                -- confirm with callers.
            model_name: Name used to build the cache file name and to pick
                the example input shapes (names containing "superpoint"
                get a single image tensor, everything else gets four
                tensors).
            precision: "fp32", "fp16" or "int8".  Any other value enables
                fp32 kernels only.
            use_torch_tensorrt: Compile through Torch-TensorRT when True;
                otherwise only print a notice that the ONNX route is not
                implemented in this wrapper.
        """
        self.model = model
        self.model_name = model_name
        self.precision = precision
        self.use_torch_tensorrt = use_torch_tensorrt
        # Compiled model; stays None while the PyTorch fallback is in use.
        self.trt_model = None
        # Cache file for the compiled TorchScript module (relative path).
        self.engine_path = f"{model_name}_{precision}.ts"

        if use_torch_tensorrt:
            self._compile_with_torch_tensorrt()
        else:
            self._compile_with_onnx_tensorrt()

    def _example_inputs(self):
        """Return the list of CUDA example tensors used to trace the model.

        Raises whatever ``Tensor.cuda()`` raises when CUDA is unavailable;
        the caller treats that as a compilation failure.
        """
        # NOTE: the shapes are hard-coded and must be adapted to the real model.
        if "superpoint" in self.model_name.lower():
            # Always hand back a list so ``inputs=`` receives a sequence,
            # even for the single-tensor case.
            return [torch.randn(1, 1, 480, 640).cuda()]
        # Any other model is treated as LightGlue, which takes four tensors.
        return [
            torch.randn(1, 128, 2).cuda(),  # keypoints0
            torch.randn(1, 128, 256).cuda(),  # descriptors0
            torch.randn(1, 128, 2).cuda(),  # keypoints1
            torch.randn(1, 128, 256).cuda(),  # descriptors1
        ]

    def _enabled_precisions(self):
        """Return the set of dtypes TensorRT may use for ``self.precision``."""
        precisions = {torch.float}
        if self.precision == "fp16":
            precisions.add(torch.half)
        elif self.precision == "int8":
            precisions.add(torch.int8)
        return precisions

    def _compile_with_torch_tensorrt(self):
        """Load a cached compiled model or compile one with Torch-TensorRT.

        Never raises for the handled failure modes: a missing package, an
        unreadable cache file, or a failed compilation all leave
        ``self.trt_model`` as None so that calls fall back to the PyTorch
        model.  A failure to *save* a successfully compiled model only
        disables caching; the compiled model is still used.
        """
        try:
            import torch_tensorrt
        except ImportError:
            print("Warning: torch_tensorrt not installed. Install with: pip install torch-tensorrt")
            print("Falling back to PyTorch model")
            return

        # Reuse a previously compiled model when one is on disk.
        if os.path.exists(self.engine_path):
            print(f"Loading compiled TensorRT model from {self.engine_path}")
            try:
                cached = torch.jit.load(self.engine_path)
                cached.eval()
            except Exception as e:
                print(f"Failed to load TensorRT model: {e}")
                print("Will recompile...")
            else:
                # Only publish the model once it is fully loaded.
                self.trt_model = cached
                print("✓ TensorRT model loaded successfully")
                return

        print(f"Compiling {self.model_name} with Torch-TensorRT ({self.precision})...")
        print("This may take several minutes...")

        try:
            compiled = torch_tensorrt.compile(
                self.model,
                inputs=self._example_inputs(),
                enabled_precisions=self._enabled_precisions(),
                workspace_size=1 << 30,  # 1GB
                min_block_size=7,
                torch_executed_ops=[],
            )
        except Exception as e:
            self.trt_model = None
            print(f"✗ Failed to compile with Torch-TensorRT: {e}")
            print("Falling back to PyTorch model")
            import traceback
            traceback.print_exc()
            return

        self.trt_model = compiled

        # Saving is kept apart from compiling: a cache-write problem must
        # not be reported as a fallback, because the compiled model is
        # still valid and will be used.
        try:
            torch.jit.save(compiled, self.engine_path)
        except Exception as e:
            print(f"Warning: compiled TensorRT model could not be saved to {self.engine_path}: {e}")
            print("Using the compiled model without caching")
        else:
            print(f"✓ TensorRT model compiled and saved to {self.engine_path}")

    def _compile_with_onnx_tensorrt(self):
        """Placeholder for the ONNX -> TensorRT route.

        Only prints a notice; ``self.trt_model`` stays None, so calls use
        the PyTorch model.
        """
        print("ONNX → TensorRT conversion not implemented in this wrapper")
        print("Please use convert_to_tensorrt.py script instead")

    def __call__(self, *args, **kwargs):
        """Run inference, preferring the compiled model when available.

        Both paths run under ``torch.no_grad()`` so that the PyTorch
        fallback returns the same kind of (non-grad-tracking) tensors as
        the TensorRT path.
        """
        target = self.trt_model if self.trt_model is not None else self.model
        with torch.no_grad():
            return target(*args, **kwargs)

    def eval(self):
        """Put the active model (compiled or fallback) in eval mode.

        Returns:
            This wrapper, to allow call chaining.
        """
        if self.trt_model is not None:
            self.trt_model.eval()
        else:
            self.model.eval()
        return self
class SuperPointTensorRT(nn.Module):
    """``nn.Module`` facade that runs a SuperPoint model via TensorRTWrapper."""

    def __init__(self, model, precision="fp16"):
        """Wrap ``model`` in a TensorRTWrapper registered as "superpoint".

        Args:
            model: The model to wrap.
            precision: Precision string forwarded to TensorRTWrapper.
        """
        super().__init__()
        wrapper_options = {
            "model_name": "superpoint",
            "precision": precision,
            "use_torch_tensorrt": True,
        }
        self.wrapper = TensorRTWrapper(model, **wrapper_options)

    def forward(self, inputs):
        """Run the wrapped model on an image.

        Args:
            inputs: Either a dict holding the image under the key "image",
                or the image itself.

        Returns:
            When the wrapped model returns a list or tuple, a dict with
            its first two items as "keypoints" and "descriptors" and its
            third item (or None if absent) as "scores".  Any other result
            is returned unchanged.
        """
        image = inputs["image"] if isinstance(inputs, dict) else inputs
        outputs = self.wrapper(image)

        # Non-sequence results are already in their final form.
        if not isinstance(outputs, (list, tuple)):
            return outputs

        keypoints = outputs[0]
        descriptors = outputs[1]
        scores = outputs[2] if len(outputs) > 2 else None
        return {
            "keypoints": keypoints,
            "descriptors": descriptors,
            "scores": scores,
        }
class LightGlueTensorRT(nn.Module):
    """``nn.Module`` facade that runs a LightGlue model via TensorRTWrapper."""

    def __init__(self, model, precision="fp16"):
        """Wrap ``model`` in a TensorRTWrapper registered as "lightglue".

        Args:
            model: The model to wrap.
            precision: Precision string forwarded to TensorRTWrapper.
        """
        super().__init__()
        wrapper_options = {
            "model_name": "lightglue",
            "precision": precision,
            "use_torch_tensorrt": True,
        }
        self.wrapper = TensorRTWrapper(model, **wrapper_options)

    def forward(self, inputs):
        """Run the wrapped matcher.

        Args:
            inputs: Either a dict with "image0" and "image1" entries, each
                a mapping providing "keypoints" and "descriptors", or any
                other object, which is passed to the wrapper unchanged.

        Returns:
            For dict input whose wrapped result is a list or tuple: a dict
            with "matches0" (first item), "matches1" (second item or None)
            and "matching_scores0" (third item or None).  In every other
            case the wrapper's result is returned as-is.
        """
        # Non-dict input is forwarded untouched, with no post-processing.
        if not isinstance(inputs, dict):
            return self.wrapper(inputs)

        first = inputs["image0"]
        second = inputs["image1"]

        # Positional order: keypoints0, descriptors0, keypoints1, descriptors1.
        features = [
            side[key]
            for side in (first, second)
            for key in ("keypoints", "descriptors")
        ]
        outputs = self.wrapper(*features)

        if not isinstance(outputs, (list, tuple)):
            return outputs

        count = len(outputs)
        return {
            "matches0": outputs[0],
            "matches1": outputs[1] if count > 1 else None,
            "matching_scores0": outputs[2] if count > 2 else None,
        }
def create_tensorrt_models(extractor, matcher, precision="fp16"):
    """Build the TensorRT-backed extractor and matcher wrappers.

    Args:
        extractor: Model wrapped as SuperPointTensorRT.
        matcher: Model wrapped as LightGlueTensorRT.
        precision: Precision string ("fp32", "fp16", "int8") forwarded to
            both wrappers.

    Returns:
        A ``(extractor_trt, matcher_trt)`` tuple; the extractor wrapper is
        constructed first.
    """
    print("Creating TensorRT optimized models...")

    # Tuple elements are evaluated left to right: extractor, then matcher.
    return (
        SuperPointTensorRT(extractor, precision=precision),
        LightGlueTensorRT(matcher, precision=precision),
    )
|