#!/usr/bin/env python3
"""Check that the Torch-TensorRT integration works.

The script imports ``torch_tensorrt``, compiles a one-layer convolutional
model with it on the GPU, runs a single forward pass through the compiled
model, and reports the outcome on stdout.  The process exit status is 0 on
success and 1 on any failure.
"""

import sys
import traceback

import torch


class SimpleModel(torch.nn.Module):
    """Single 3x3 convolution (1 -> 64 channels) used as the compile target."""

    def __init__(self):
        super().__init__()
        # padding=1 with a 3x3 kernel keeps the spatial size unchanged.
        self.conv = torch.nn.Conv2d(1, 64, 3, padding=1)

    def forward(self, x):
        """Apply the convolution to ``x`` and return the result."""
        return self.conv(x)


def main():
    """Run the TensorRT smoke test.

    Returns:
        int: 0 if the import, compilation and inference all succeeded,
        1 otherwise.  A message describing the outcome is printed in
        either case.
    """
    # Step 1: the import itself. Done here rather than at module level so
    # that merely importing this file never terminates the interpreter.
    try:
        import torch_tensorrt
    except ImportError as e:
        print(f"✗ Failed to import torch_tensorrt: {e}")
        return 1
    print("✓ torch_tensorrt imported successfully")

    # Step 2: the model and input are moved to the GPU below, so fail with
    # an explicit message instead of a generic traceback when none exists.
    if not torch.cuda.is_available():
        print("✗ TensorRT test failed: CUDA is not available")
        return 1

    # Step 3: compile and run. This is the top-level boundary of the
    # script, so any exception is reported with its traceback and turned
    # into a failing exit status.
    try:
        model = SimpleModel().eval().cuda()
        example_input = torch.randn(1, 1, 480, 640).cuda()

        print("Compiling model with TensorRT...")
        trt_model = torch_tensorrt.compile(
            model,
            inputs=[example_input],
            enabled_precisions={torch.float, torch.half},
            workspace_size=1 << 30,  # 1 GiB
        )

        # Inference only: no autograd bookkeeping needed.
        with torch.no_grad():
            output = trt_model(example_input)

        print("✓ TensorRT compilation and inference successful!")
        print(f" Input shape: {example_input.shape}")
        print(f" Output shape: {output.shape}")
    except Exception as e:
        print(f"✗ TensorRT test failed: {e}")
        traceback.print_exc()
        return 1

    print("\n" + "="*60)
    print("TensorRT integration test passed!")
    print("You can now use --use_tensorrt flag in your demo script")
    print("="*60)
    return 0


if __name__ == "__main__":
    # sys.exit rather than the site-provided exit(), which is not
    # guaranteed to exist (e.g. under ``python -S``).
    sys.exit(main())