test_int8_compilation.py 2.3 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091
  1. #!/usr/bin/env python3
  2. """
  3. 测试 INT8 编译是否正常工作
  4. """
  5. import torch
  6. import sys
  7. print("Testing INT8 compilation...")
  8. print("="*60)
  9. # 检查CUDA
  10. if not torch.cuda.is_available():
  11. print("[ERROR] CUDA not available")
  12. sys.exit(1)
  13. print(f"CUDA available: {torch.cuda.is_available()}")
  14. print(f"GPU: {torch.cuda.get_device_name(0)}")
  15. # 检查TensorRT
  16. try:
  17. import torch_tensorrt
  18. print(f"torch-tensorrt version: {torch_tensorrt.__version__}")
  19. except ImportError:
  20. print("[ERROR] torch-tensorrt not installed")
  21. sys.exit(1)
  22. # 创建简单模型测试INT8编译
  23. print("\nCreating test model...")
  24. class SimpleModel(torch.nn.Module):
  25. def __init__(self):
  26. super().__init__()
  27. self.conv1 = torch.nn.Conv2d(1, 32, 3, padding=1)
  28. self.conv2 = torch.nn.Conv2d(32, 64, 3, padding=1)
  29. self.pool = torch.nn.MaxPool2d(2)
  30. def forward(self, x):
  31. x = torch.relu(self.conv1(x))
  32. x = self.pool(x)
  33. x = torch.relu(self.conv2(x))
  34. return x
  35. model = SimpleModel().eval().cuda()
  36. example_input = torch.randn(1, 1, 480, 640).cuda()
  37. print("Model created successfully")
  38. print(f"Input shape: {example_input.shape}")
  39. # 测试INT8编译
  40. print("\n" + "="*60)
  41. print("Attempting INT8 compilation...")
  42. print("WARNING: This may take 5-10 minutes")
  43. print("="*60)
  44. sys.stdout.flush()
  45. try:
  46. print("Starting compilation...")
  47. sys.stdout.flush()
  48. trt_model = torch_tensorrt.compile(
  49. model,
  50. inputs=[example_input],
  51. enabled_precisions={torch.float, torch.int8},
  52. workspace_size=1 << 30, # 1GB
  53. min_block_size=7,
  54. ir="torchscript",
  55. )
  56. print("\n[OK] INT8 compilation successful!")
  57. print("="*60)
  58. # 测试推理
  59. print("Testing inference...")
  60. with torch.no_grad():
  61. output = trt_model(example_input)
  62. print(f"Output shape: {output.shape}")
  63. print("[OK] Inference successful!")
  64. except Exception as e:
  65. print(f"\n[ERROR] INT8 compilation failed: {e}")
  66. print("="*60)
  67. print("Full traceback:")
  68. import traceback
  69. traceback.print_exc()
  70. print("\nRecommendation: Use FP16 instead of INT8")
  71. sys.exit(1)
  72. print("\n" + "="*60)
  73. print("INT8 compilation test PASSED!")
  74. print("You can use --tensorrt_precision int8 in your demo script")
  75. print("="*60)