#!/usr/bin/env python3
"""
TensorRT integration example (tensorrt_integration_example.py).

Shows how to integrate TensorRT into demo_lightglue_camera_position_async.py.
"""
# Add the following code to demo_lightglue_camera_position_async.py:

# ===== 1. Add the import at the top of the file =====
"""
try:
    from tensorrt_wrapper import create_tensorrt_models
    TENSORRT_AVAILABLE = True
except ImportError:
    TENSORRT_AVAILABLE = False
    print("TensorRT not available, using PyTorch models")
"""
# ===== 2. Add the arguments to argparse =====
"""
parser.add_argument(
    "--use_tensorrt",
    action="store_true",
    help="Use TensorRT optimized models (requires torch-tensorrt)"
)
parser.add_argument(
    "--tensorrt_precision",
    type=str,
    default="fp16",
    choices=["fp32", "fp16", "int8"],
    help="TensorRT precision mode"
)
"""
# ===== 3. Add the TensorRT optimization after the models are loaded =====
"""
# Original code:
extractor = SuperPoint(...).eval().to(device)
matcher = LightGlue(...).eval().to(device)

# Add the TensorRT optimization:
if opt.use_tensorrt and TENSORRT_AVAILABLE and device == "cuda":
    try:
        print("Compiling models with TensorRT...")
        print(f"Precision: {opt.tensorrt_precision}")
        print("This may take several minutes on first run...")
        extractor, matcher = create_tensorrt_models(
            extractor,
            matcher,
            precision=opt.tensorrt_precision
        )
        print("✓ TensorRT models compiled successfully")
        print("Note: Compiled models are cached for faster startup next time")
    except Exception as e:
        print(f"✗ Failed to compile with TensorRT: {e}")
        print("Falling back to PyTorch models")
        import traceback
        traceback.print_exc()
else:
    if opt.use_tensorrt:
        print("TensorRT requested but not available, using PyTorch models")
"""
# ===== 4. Complete integration example =====
  59. def integrate_tensorrt_into_demo():
  60. """
  61. 完整的集成代码片段
  62. 将以下代码添加到 demo_lightglue_camera_position_async.py 的相应位置
  63. """
  64. integration_code = '''
  65. # ===== 在导入部分添加 =====
  66. try:
  67. from tensorrt_wrapper import create_tensorrt_models
  68. TENSORRT_AVAILABLE = True
  69. except ImportError:
  70. TENSORRT_AVAILABLE = False
  71. # ===== 在 argparse 部分添加 =====
  72. parser.add_argument(
  73. "--use_tensorrt",
  74. action="store_true",
  75. help="Use TensorRT optimized models (requires torch-tensorrt)"
  76. )
  77. parser.add_argument(
  78. "--tensorrt_precision",
  79. type=str,
  80. default="fp16",
  81. choices=["fp32", "fp16", "int8"],
  82. help="TensorRT precision mode"
  83. )
  84. # ===== 在模型加载后(约第338行)添加 =====
  85. print("Loaded SuperPoint and LightGlue models")
  86. # TensorRT 优化
  87. if opt.use_tensorrt and TENSORRT_AVAILABLE and device == "cuda":
  88. try:
  89. print("="*60)
  90. print("Compiling models with TensorRT...")
  91. print(f"Precision: {opt.tensorrt_precision}")
  92. print("This may take several minutes on first run...")
  93. print("="*60)
  94. extractor, matcher = create_tensorrt_models(
  95. extractor,
  96. matcher,
  97. precision=opt.tensorrt_precision
  98. )
  99. print("="*60)
  100. print("✓ TensorRT models compiled successfully")
  101. print("Note: Compiled models are cached for faster startup next time")
  102. print("="*60)
  103. except Exception as e:
  104. print(f"✗ Failed to compile with TensorRT: {e}")
  105. print("Falling back to PyTorch models")
  106. import traceback
  107. traceback.print_exc()
  108. elif opt.use_tensorrt:
  109. if not TENSORRT_AVAILABLE:
  110. print("Warning: TensorRT requested but torch-tensorrt not installed")
  111. print("Install with: pip install torch-tensorrt")
  112. elif device != "cuda":
  113. print("Warning: TensorRT requires CUDA, but running on CPU")
  114. '''
  115. return integration_code
# ===== Usage =====
  117. usage_instructions = """
  118. 使用方法:
  119. 1. 安装依赖:
  120. pip install torch-tensorrt
  121. 2. 运行程序(首次运行会编译模型,需要几分钟):
  122. python demo_lightglue_camera_position_async.py \\
  123. --input "udp://0.0.0.0:12346" \\
  124. --max_keypoints 128 \\
  125. --use_fp16 \\
  126. --use_tensorrt \\
  127. --tensorrt_precision fp16
  128. 3. 第二次运行会直接加载编译好的模型(很快)
  129. 4. 性能对比:
  130. - PyTorch FP16: ~22 FPS
  131. - TensorRT FP16: ~35-45 FPS (预期)
  132. - TensorRT INT8: ~50-60 FPS (预期,但精度可能下降)
  133. 注意事项:
  134. - 首次编译需要较长时间(5-15分钟)
  135. - 编译后的模型会保存在当前目录(superpoint_fp16.ts, lightglue_fp16.ts)
  136. - 如果模型结构改变,需要删除缓存文件重新编译
  137. - INT8 量化可能需要校准数据
  138. """
  139. if __name__ == "__main__":
  140. print("="*60)
  141. print("TensorRT 集成指南")
  142. print("="*60)
  143. print(integrate_tensorrt_into_demo())
  144. print("\n" + "="*60)
  145. print("使用说明")
  146. print("="*60)
  147. print(usage_instructions)