| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657 |
#!/usr/bin/env python3
"""Smoke-test that the Torch-TensorRT integration works on this machine.

Run directly (``python <this file>``). The script:

1. checks that ``torch_tensorrt`` can be imported;
2. checks that a CUDA device is available;
3. compiles a one-layer convolution model with ``torch_tensorrt.compile``;
4. runs one inference pass and prints shapes plus the max absolute
   difference from the eager PyTorch model.

The process exits with status 0 when every step succeeds and 1 otherwise.
Importing this module has no side effects beyond defining ``SimpleModel``
and ``main``.
"""
import sys
import traceback

import torch


class SimpleModel(torch.nn.Module):
    """Minimal compile target: a single Conv2d (1 -> 64 channels, 3x3 kernel)."""

    def __init__(self):
        super().__init__()
        # padding=1 with a 3x3 kernel keeps the spatial size of the input.
        self.conv = torch.nn.Conv2d(1, 64, 3, padding=1)

    def forward(self, x):
        return self.conv(x)


def main() -> int:
    """Run the TensorRT smoke test.

    Returns:
        Process exit status: 0 if compilation and inference succeeded,
        1 on any failure (missing package, no CUDA device, compile/run error).
    """
    # Import check: fail fast with a clear message if the package is missing.
    try:
        import torch_tensorrt
    except ImportError as e:
        print(f"✗ Failed to import torch_tensorrt: {e}")
        return 1
    print("✓ torch_tensorrt imported successfully")

    # TensorRT needs a CUDA device. Without this check, a CPU-only machine
    # fails later with a less obvious error from `.cuda()`.
    if not torch.cuda.is_available():
        print("✗ TensorRT test failed: CUDA is not available")
        return 1

    try:
        model = SimpleModel().eval().cuda()
        example_input = torch.randn(1, 1, 480, 640).cuda()

        print("Compiling model with TensorRT...")
        trt_model = torch_tensorrt.compile(
            model,
            inputs=[example_input],
            # Let TensorRT pick FP32 or FP16 kernels per layer.
            enabled_precisions={torch.float, torch.half},
            workspace_size=1 << 30,  # 1 GiB builder workspace (value in bytes)
        )

        with torch.no_grad():
            output = trt_model(example_input)
            # Eager result, used only to report numerical agreement.
            reference = model(example_input)

        max_abs_diff = (output.float() - reference.float()).abs().max().item()

        print("✓ TensorRT compilation and inference successful!")
        print(f" Input shape: {example_input.shape}")
        print(f" Output shape: {output.shape}")
        print(f" Max abs diff vs. PyTorch: {max_abs_diff:.3e}")
    except Exception as e:  # top-level boundary: report any failure, exit non-zero
        print(f"✗ TensorRT test failed: {e}")
        traceback.print_exc()
        return 1

    print("\n" + "=" * 60)
    print("TensorRT integration test passed!")
    print("You can now use --use_tensorrt flag in your demo script")
    print("=" * 60)
    return 0


if __name__ == "__main__":
    # sys.exit instead of the site-provided `exit`, which is not guaranteed to exist.
    sys.exit(main())
|