test_tensorrt_integration.py 1.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657
  1. #!/usr/bin/env python3
  2. """
  3. 测试 TensorRT 集成是否正常工作
  4. """
  5. import torch
  6. # 测试导入
  7. try:
  8. import torch_tensorrt
  9. print("✓ torch_tensorrt imported successfully")
  10. except ImportError as e:
  11. print(f"✗ Failed to import torch_tensorrt: {e}")
  12. exit(1)
  13. # 测试基本功能
  14. try:
  15. # 创建一个简单的模型
  16. class SimpleModel(torch.nn.Module):
  17. def __init__(self):
  18. super().__init__()
  19. self.conv = torch.nn.Conv2d(1, 64, 3, padding=1)
  20. def forward(self, x):
  21. return self.conv(x)
  22. model = SimpleModel().eval().cuda()
  23. example_input = torch.randn(1, 1, 480, 640).cuda()
  24. # 编译为 TensorRT
  25. print("Compiling model with TensorRT...")
  26. trt_model = torch_tensorrt.compile(
  27. model,
  28. inputs=[example_input],
  29. enabled_precisions={torch.float, torch.half},
  30. workspace_size=1 << 30,
  31. )
  32. # 测试推理
  33. with torch.no_grad():
  34. output = trt_model(example_input)
  35. print("✓ TensorRT compilation and inference successful!")
  36. print(f" Input shape: {example_input.shape}")
  37. print(f" Output shape: {output.shape}")
  38. except Exception as e:
  39. print(f"✗ TensorRT test failed: {e}")
  40. import traceback
  41. traceback.print_exc()
  42. exit(1)
  43. print("\n" + "="*60)
  44. print("TensorRT integration test passed!")
  45. print("You can now use --use_tensorrt flag in your demo script")
  46. print("="*60)