| 12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091 |
- #!/usr/bin/env python3
- """
- 测试 INT8 编译是否正常工作
- """
- import torch
- import sys
print("Testing INT8 compilation...\n" + "=" * 60)

# Check CUDA: the rest of the script moves the model and inputs onto the GPU,
# so stop immediately when no CUDA device is usable.
cuda_available = torch.cuda.is_available()
if not cuda_available:
    print("[ERROR] CUDA not available")
    sys.exit(1)
print(f"CUDA available: {cuda_available}")
print(f"GPU: {torch.cuda.get_device_name(0)}")
# Check TensorRT: torch-tensorrt is required for the compilation test below;
# exit with status 1 when it cannot be imported.
try:
    import torch_tensorrt
except ImportError:
    print("[ERROR] torch-tensorrt not installed")
    sys.exit(1)
else:
    print(f"torch-tensorrt version: {torch_tensorrt.__version__}")
# Create a small test model to exercise INT8 compilation.
print("\nCreating test model...")


class SimpleModel(torch.nn.Module):
    """Tiny CNN used as the compilation target: conv -> ReLU -> pool -> conv -> ReLU.

    Expects single-channel input (e.g. shape (N, 1, H, W)) and produces 64
    feature maps. The 3x3 convolutions use padding=1, so they keep the
    spatial size; the 2x2 max pool halves it (floor division).
    """

    def __init__(self):
        super().__init__()
        # Registration order (conv1, conv2, pool) determines the order in
        # which parameter initialization draws from the RNG and the
        # state_dict key order, so it is kept fixed.
        self.conv1 = torch.nn.Conv2d(in_channels=1, out_channels=32, kernel_size=3, padding=1)
        self.conv2 = torch.nn.Conv2d(in_channels=32, out_channels=64, kernel_size=3, padding=1)
        self.pool = torch.nn.MaxPool2d(kernel_size=2)

    def forward(self, x):
        """Run ``x`` through the conv/ReLU/pool/conv/ReLU stack."""
        pooled = self.pool(self.conv1(x).relu())
        return self.conv2(pooled).relu()
# Instantiate the model in inference mode on the GPU, plus one dummy input.
# Module.eval() and Module.cuda() both modify the module in place.
model = SimpleModel()
model.eval()
model.cuda()
# Sampled on the CPU and then copied to the GPU.
example_input = torch.randn(1, 1, 480, 640).cuda()
print("Model created successfully")
print("Input shape:", example_input.shape)
# Test INT8 compilation of the model with Torch-TensorRT, then run the
# compiled module and sanity-check its output against eager PyTorch.
print("\n" + "=" * 60)
print("Attempting INT8 compilation...")
print("WARNING: This may take 5-10 minutes")
print("=" * 60, flush=True)

# Eager-mode PyTorch output, computed before compilation. It is the reference
# the TensorRT result is checked against below.
with torch.no_grad():
    reference_output = model(example_input)

try:
    print("Starting compilation...", flush=True)

    # NOTE(review): no INT8 calibrator or per-tensor dynamic ranges are
    # supplied. Without them TensorRT generally cannot execute layers in INT8
    # and keeps them at a higher precision. A successful compile therefore
    # does not by itself prove that INT8 kernels were selected; confirm for
    # the installed torch-tensorrt version.
    # NOTE(review): SimpleModel has only ~5 ops (2x conv, 2x relu, max pool).
    # If the graph is ever partitioned, segments smaller than min_block_size
    # stay in PyTorch; confirm that 7 is intended.
    trt_model = torch_tensorrt.compile(
        model,
        inputs=[example_input],
        enabled_precisions={torch.float, torch.int8},
        workspace_size=1 << 30,  # 1 GiB TensorRT builder workspace
        min_block_size=7,
        ir="torchscript",
    )

    print("\n[OK] INT8 compilation successful!")
    print("=" * 60)

    # Test inference with the compiled module.
    print("Testing inference...")
    with torch.no_grad():
        output = trt_model(example_input)
    # CUDA work is asynchronous. Synchronize so runtime errors surface here,
    # inside the try block, instead of at some later CUDA call.
    torch.cuda.synchronize()
    print(f"Output shape: {output.shape}")

    # A compile that "succeeds" but yields wrong-shaped or non-finite output
    # is still a failure.
    if output.shape != reference_output.shape:
        raise RuntimeError(
            f"Output shape {tuple(output.shape)} does not match the PyTorch "
            f"reference shape {tuple(reference_output.shape)}"
        )
    if not torch.isfinite(output).all():
        raise RuntimeError("TensorRT output contains NaN or Inf values")
    # The difference is reported, not asserted: reduced precision
    # legitimately deviates from FP32, and no tolerance is known to be
    # correct for this model.
    max_abs_diff = (output.float() - reference_output.float()).abs().max().item()
    print(f"Max abs diff vs PyTorch: {max_abs_diff:.6f}")
    print("[OK] Inference successful!")

except Exception as e:
    # Top-level boundary of the script: report any failure (compile,
    # inference or output check) with a full traceback and exit non-zero.
    import traceback

    print(f"\n[ERROR] INT8 compilation or inference failed: {e}")
    print("=" * 60)
    print("Full traceback:")
    traceback.print_exc()
    print("\nRecommendation: Use FP16 instead of INT8")
    sys.exit(1)
# Final summary banner; reached only if none of the earlier checks exited.
print(
    "\n".join(
        (
            "\n" + "=" * 60,
            "INT8 compilation test PASSED!",
            "You can use --tensorrt_precision int8 in your demo script",
            "=" * 60,
        )
    )
)
|