tensorrt_wrapper.py 7.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215
  1. #!/usr/bin/env python3
"""
TensorRT inference wrapper.

Two approaches are supported:
1. Torch-TensorRT (recommended, simpler)
2. Native TensorRT (requires ONNX conversion)
"""
  8. import torch
  9. import torch.nn as nn
  10. from typing import Dict, Optional
  11. import os
  12. class TensorRTWrapper:
  13. """TensorRT 推理包装器"""
  14. def __init__(self, model, model_name="model", precision="fp16", use_torch_tensorrt=True):
  15. """
  16. 初始化 TensorRT 包装器
  17. Args:
  18. model: PyTorch 模型
  19. model_name: 模型名称(用于保存/加载)
  20. precision: 精度 ("fp32", "fp16", "int8")
  21. use_torch_tensorrt: 是否使用 Torch-TensorRT(推荐)
  22. """
  23. self.model = model
  24. self.model_name = model_name
  25. self.precision = precision
  26. self.use_torch_tensorrt = use_torch_tensorrt
  27. self.trt_model = None
  28. self.engine_path = f"{model_name}_{precision}.ts"
  29. if use_torch_tensorrt:
  30. self._compile_with_torch_tensorrt()
  31. else:
  32. self._compile_with_onnx_tensorrt()
  33. def _compile_with_torch_tensorrt(self):
  34. """使用 Torch-TensorRT 编译模型"""
  35. try:
  36. import torch_tensorrt
  37. except ImportError:
  38. print(f"Warning: torch_tensorrt not installed. Install with: pip install torch-tensorrt")
  39. print("Falling back to PyTorch model")
  40. return
  41. # 检查是否已有编译好的模型
  42. if os.path.exists(self.engine_path):
  43. print(f"Loading compiled TensorRT model from {self.engine_path}")
  44. try:
  45. self.trt_model = torch.jit.load(self.engine_path)
  46. self.trt_model.eval()
  47. print("✓ TensorRT model loaded successfully")
  48. return
  49. except Exception as e:
  50. print(f"Failed to load TensorRT model: {e}")
  51. print("Will recompile...")
  52. print(f"Compiling {self.model_name} with Torch-TensorRT ({self.precision})...")
  53. print("This may take several minutes...")
  54. try:
  55. # 创建示例输入
  56. # 注意:需要根据实际模型调整输入形状
  57. if "superpoint" in self.model_name.lower():
  58. example_input = torch.randn(1, 1, 480, 640).cuda()
  59. else:
  60. # LightGlue 需要两个输入
  61. example_input = [
  62. torch.randn(1, 128, 2).cuda(), # keypoints0
  63. torch.randn(1, 128, 256).cuda(), # descriptors0
  64. torch.randn(1, 128, 2).cuda(), # keypoints1
  65. torch.randn(1, 128, 256).cuda(), # descriptors1
  66. ]
  67. # 设置精度
  68. enabled_precisions = {torch.float}
  69. if self.precision == "fp16":
  70. enabled_precisions.add(torch.half)
  71. elif self.precision == "int8":
  72. enabled_precisions.add(torch.int8)
  73. # 编译模型
  74. self.trt_model = torch_tensorrt.compile(
  75. self.model,
  76. inputs=example_input if isinstance(example_input, torch.Tensor) else example_input,
  77. enabled_precisions=enabled_precisions,
  78. workspace_size=1 << 30, # 1GB
  79. min_block_size=7,
  80. torch_executed_ops=[],
  81. )
  82. # 保存编译后的模型
  83. torch.jit.save(self.trt_model, self.engine_path)
  84. print(f"✓ TensorRT model compiled and saved to {self.engine_path}")
  85. except Exception as e:
  86. print(f"✗ Failed to compile with Torch-TensorRT: {e}")
  87. print("Falling back to PyTorch model")
  88. import traceback
  89. traceback.print_exc()
  90. def _compile_with_onnx_tensorrt(self):
  91. """通过 ONNX 转换为 TensorRT(更复杂,但更稳定)"""
  92. print("ONNX → TensorRT conversion not implemented in this wrapper")
  93. print("Please use convert_to_tensorrt.py script instead")
  94. def __call__(self, *args, **kwargs):
  95. """调用模型"""
  96. if self.trt_model is not None:
  97. with torch.no_grad():
  98. return self.trt_model(*args, **kwargs)
  99. else:
  100. # 回退到原始 PyTorch 模型
  101. return self.model(*args, **kwargs)
  102. def eval(self):
  103. """设置为评估模式"""
  104. if self.trt_model is not None:
  105. self.trt_model.eval()
  106. else:
  107. self.model.eval()
  108. return self
  109. class SuperPointTensorRT(nn.Module):
  110. """SuperPoint TensorRT 包装器"""
  111. def __init__(self, model, precision="fp16"):
  112. super().__init__()
  113. self.wrapper = TensorRTWrapper(
  114. model,
  115. model_name="superpoint",
  116. precision=precision,
  117. use_torch_tensorrt=True
  118. )
  119. def forward(self, inputs):
  120. """前向传播"""
  121. if isinstance(inputs, dict):
  122. image = inputs["image"]
  123. else:
  124. image = inputs
  125. result = self.wrapper(image)
  126. # 转换为字典格式(如果返回的是元组)
  127. if isinstance(result, (list, tuple)):
  128. return {
  129. "keypoints": result[0],
  130. "descriptors": result[1],
  131. "scores": result[2] if len(result) > 2 else None,
  132. }
  133. return result
  134. class LightGlueTensorRT(nn.Module):
  135. """LightGlue TensorRT 包装器"""
  136. def __init__(self, model, precision="fp16"):
  137. super().__init__()
  138. self.wrapper = TensorRTWrapper(
  139. model,
  140. model_name="lightglue",
  141. precision=precision,
  142. use_torch_tensorrt=True
  143. )
  144. def forward(self, inputs):
  145. """前向传播"""
  146. if isinstance(inputs, dict):
  147. image0 = inputs["image0"]
  148. image1 = inputs["image1"]
  149. # 提取特征
  150. kpts0 = image0["keypoints"]
  151. desc0 = image0["descriptors"]
  152. kpts1 = image1["keypoints"]
  153. desc1 = image1["descriptors"]
  154. # 调用 TensorRT 模型
  155. result = self.wrapper(kpts0, desc0, kpts1, desc1)
  156. # 转换为字典格式
  157. if isinstance(result, (list, tuple)):
  158. return {
  159. "matches0": result[0],
  160. "matches1": result[1] if len(result) > 1 else None,
  161. "matching_scores0": result[2] if len(result) > 2 else None,
  162. }
  163. return result
  164. return self.wrapper(inputs)
  165. def create_tensorrt_models(extractor, matcher, precision="fp16"):
  166. """
  167. 创建 TensorRT 优化的模型
  168. Args:
  169. extractor: SuperPoint 模型
  170. matcher: LightGlue 模型
  171. precision: 精度 ("fp32", "fp16", "int8")
  172. Returns:
  173. (extractor_trt, matcher_trt): TensorRT 优化的模型
  174. """
  175. print("Creating TensorRT optimized models...")
  176. extractor_trt = SuperPointTensorRT(extractor, precision=precision)
  177. matcher_trt = LightGlueTensorRT(matcher, precision=precision)
  178. return extractor_trt, matcher_trt