#!/usr/bin/env python3
"""TensorRT inference wrappers.

Two approaches are anticipated:
1. Torch-TensorRT (recommended, simpler)
2. Native TensorRT (requires ONNX conversion; not implemented here)
"""

import os
import traceback
from typing import Dict, Optional

import torch
import torch.nn as nn


class TensorRTWrapper:
    """Wrap a PyTorch model and, when possible, run it through Torch-TensorRT.

    If compilation or loading fails for any reason, ``trt_model`` stays
    ``None`` and calls are forwarded to the original PyTorch model.
    """

    def __init__(self, model, model_name="model", precision="fp16", use_torch_tensorrt=True):
        """Initialise the wrapper and immediately try to build the TensorRT model.

        Args:
            model: The PyTorch model (any callable works for the fallback path).
            model_name: Model name; used for the cache file name and to pick
                the example-input layout (names containing "superpoint" get a
                single image input, everything else gets four inputs).
            precision: One of "fp32", "fp16", "int8". Any other value behaves
                like "fp32" for compilation purposes.
            use_torch_tensorrt: Use Torch-TensorRT (recommended). When False,
                the ONNX path is selected, which only prints a notice.
        """
        self.model = model
        self.model_name = model_name
        self.precision = precision
        self.use_torch_tensorrt = use_torch_tensorrt
        self.trt_model = None
        self.engine_path = f"{model_name}_{precision}.ts"

        if use_torch_tensorrt:
            self._compile_with_torch_tensorrt()
        else:
            self._compile_with_onnx_tensorrt()

    def _make_example_inputs(self, device="cuda"):
        """Return the example inputs used for compilation, always as a list.

        NOTE: the shapes are hard-coded and must be adjusted to the real model.

        Args:
            device: Device the example tensors are moved to.

        Returns:
            A list with one tensor for SuperPoint-named models, otherwise a
            list with four tensors (keypoints0, descriptors0, keypoints1,
            descriptors1).
        """
        if "superpoint" in self.model_name.lower():
            return [torch.randn(1, 1, 480, 640).to(device)]

        # LightGlue takes two keypoint/descriptor pairs.
        return [
            torch.randn(1, 128, 2).to(device),  # keypoints0
            torch.randn(1, 128, 256).to(device),  # descriptors0
            torch.randn(1, 128, 2).to(device),  # keypoints1
            torch.randn(1, 128, 256).to(device),  # descriptors1
        ]

    def _enabled_precisions(self):
        """Return the set of dtypes to enable for the configured precision.

        ``torch.float`` is always enabled; "fp16" adds ``torch.half`` and
        "int8" adds ``torch.int8``.
        """
        precisions = {torch.float}
        extra = {"fp16": torch.half, "int8": torch.int8}.get(self.precision)
        if extra is not None:
            precisions.add(extra)
        return precisions

    def _compile_with_torch_tensorrt(self):
        """Load a cached compiled model, or compile one with Torch-TensorRT.

        On success ``self.trt_model`` is set. On any failure a message is
        printed and ``self.trt_model`` is left as ``None`` so that calls fall
        back to the PyTorch model.
        """
        try:
            import torch_tensorrt
        except ImportError:
            print("Warning: torch_tensorrt not installed. Install with: pip install torch-tensorrt")
            print("Falling back to PyTorch model")
            return

        # Reuse a previously compiled model if one exists on disk.
        if os.path.exists(self.engine_path):
            print(f"Loading compiled TensorRT model from {self.engine_path}")
            try:
                self.trt_model = torch.jit.load(self.engine_path)
                self.trt_model.eval()
                print("✓ TensorRT model loaded successfully")
                return
            except Exception as e:
                print(f"Failed to load TensorRT model: {e}")
                print("Will recompile...")

        print(f"Compiling {self.model_name} with Torch-TensorRT ({self.precision})...")
        print("This may take several minutes...")

        try:
            # Inputs are always passed as a list, including the single-tensor
            # SuperPoint case.
            self.trt_model = torch_tensorrt.compile(
                self.model,
                inputs=self._make_example_inputs(),
                enabled_precisions=self._enabled_precisions(),
                workspace_size=1 << 30,  # 1 GiB
                min_block_size=7,
                torch_executed_ops=[],
            )
        except Exception as e:
            print(f"✗ Failed to compile with Torch-TensorRT: {e}")
            print("Falling back to PyTorch model")
            traceback.print_exc()
            return

        # Saving is kept separate from compilation: a failed save must not be
        # reported as a failed compile, because the compiled model is still
        # held in self.trt_model and will be used for inference.
        try:
            torch.jit.save(self.trt_model, self.engine_path)
            print(f"✓ TensorRT model compiled and saved to {self.engine_path}")
        except Exception as e:
            print(f"Warning: compiled model could not be saved to {self.engine_path}: {e}")
            traceback.print_exc()

    def _compile_with_onnx_tensorrt(self):
        """Placeholder for the ONNX -> TensorRT path; only prints a notice."""
        print("ONNX → TensorRT conversion not implemented in this wrapper")
        print("Please use convert_to_tensorrt.py script instead")

    def __call__(self, *args, **kwargs):
        """Run the TensorRT model if available, otherwise the PyTorch model.

        The TensorRT path runs under ``torch.no_grad()``. The fallback path
        calls the original model directly (autograd state is left untouched,
        as in the original implementation).
        """
        if self.trt_model is not None:
            with torch.no_grad():
                return self.trt_model(*args, **kwargs)
        # Fall back to the original PyTorch model.
        return self.model(*args, **kwargs)

    def eval(self):
        """Switch the active model to evaluation mode and return ``self``."""
        if self.trt_model is not None:
            self.trt_model.eval()
        else:
            self.model.eval()
        return self


class SuperPointTensorRT(nn.Module):
    """SuperPoint wrapper backed by ``TensorRTWrapper``."""

    def __init__(self, model, precision="fp16"):
        """Build the wrapper for ``model`` under the name "superpoint"."""
        super().__init__()
        self.wrapper = TensorRTWrapper(
            model,
            model_name="superpoint",
            precision=precision,
            use_torch_tensorrt=True,
        )

    def forward(self, inputs):
        """Run feature extraction.

        Args:
            inputs: Either a dict with an "image" entry or the image itself.

        Returns:
            If the wrapped model returns a list/tuple, a dict with
            "keypoints", "descriptors" and "scores" (``None`` when the result
            has fewer than three elements). Otherwise the raw result.
        """
        image = inputs["image"] if isinstance(inputs, dict) else inputs
        result = self.wrapper(image)

        # Convert positional outputs into the dict format.
        if isinstance(result, (list, tuple)):
            output: Dict[str, Optional[object]] = {
                "keypoints": result[0],
                "descriptors": result[1],
                "scores": result[2] if len(result) > 2 else None,
            }
            return output
        return result


class LightGlueTensorRT(nn.Module):
    """LightGlue wrapper backed by ``TensorRTWrapper``."""

    def __init__(self, model, precision="fp16"):
        """Build the wrapper for ``model`` under the name "lightglue"."""
        super().__init__()
        self.wrapper = TensorRTWrapper(
            model,
            model_name="lightglue",
            precision=precision,
            use_torch_tensorrt=True,
        )

    def forward(self, inputs):
        """Run matching.

        Args:
            inputs: Either a dict with "image0" and "image1" entries, each a
                mapping with "keypoints" and "descriptors", or any other
                object, which is passed to the wrapped model unchanged.

        Returns:
            For dict input where the wrapped model returns a list/tuple, a
            dict with "matches0", "matches1" and "matching_scores0" (missing
            positions become ``None``). Otherwise the raw result.
        """
        if not isinstance(inputs, dict):
            return self.wrapper(inputs)

        image0 = inputs["image0"]
        image1 = inputs["image1"]

        # Call the model with the features of both images.
        result = self.wrapper(
            image0["keypoints"],
            image0["descriptors"],
            image1["keypoints"],
            image1["descriptors"],
        )

        # Convert positional outputs into the dict format.
        if isinstance(result, (list, tuple)):
            output: Dict[str, Optional[object]] = {
                "matches0": result[0],
                "matches1": result[1] if len(result) > 1 else None,
                "matching_scores0": result[2] if len(result) > 2 else None,
            }
            return output
        return result


def create_tensorrt_models(extractor, matcher, precision="fp16"):
    """Create TensorRT-optimised wrappers for an extractor and a matcher.

    Args:
        extractor: SuperPoint model.
        matcher: LightGlue model.
        precision: Precision ("fp32", "fp16", "int8").

    Returns:
        (extractor_trt, matcher_trt): the wrapped models.
    """
    print("Creating TensorRT optimized models...")

    extractor_trt = SuperPointTensorRT(extractor, precision=precision)
    matcher_trt = LightGlueTensorRT(matcher, precision=precision)

    return extractor_trt, matcher_trt