# test_onnxocr_detection_modes.py (6.5 KB)
  1. import sys
  2. import os
  3. from pathlib import Path
  4. import cv2
  5. import numpy as np
  6. import json
  7. import time
  8. # 确保输出编码为UTF-8
  9. sys.stdout.reconfigure(encoding='utf-8')
  10. sys.stderr.reconfigure(encoding='utf-8')
  11. # 添加OnnxOCR路径
  12. project_root = Path(__file__).parent.parent.parent
  13. onnxocr_path = project_root / 'python' / 'OnnxOCR-main'
  14. if onnxocr_path.exists():
  15. sys.path.insert(0, str(onnxocr_path))
  16. print(f"[INFO] 使用本地OnnxOCR路径: {onnxocr_path}")
  17. else:
  18. print(f"[ERROR] 未找到本地OnnxOCR路径: {onnxocr_path}")
  19. sys.exit(1)
  20. try:
  21. from onnxocr.onnx_paddleocr import ONNXPaddleOcr
  22. print("[INFO] OnnxOCR 导入成功")
  23. except ImportError as e:
  24. print(f"[ERROR] 无法导入OnnxOCR模块: {e}")
  25. sys.exit(1)
  26. def test_different_ocr_modes(image_path):
  27. """测试OnnxOCR的不同模式"""
  28. print(f"\n🧪 测试图片: {image_path}")
  29. # 读取图片
  30. img_array = np.fromfile(str(image_path), dtype=np.uint8)
  31. img = cv2.imdecode(img_array, cv2.IMREAD_COLOR)
  32. if img is None:
  33. print(f"[ERROR] 无法读取图片: {image_path}")
  34. return
  35. print(f"[INFO] 图片尺寸: {img.shape[:2][::-1]} (宽x高)")
  36. # 初始化OnnxOCR
  37. print("\n[INFO] 初始化OnnxOCR...")
  38. ocr_model = ONNXPaddleOcr(use_angle_cls=True, use_gpu=False)
  39. print("\n" + "="*60)
  40. # 模式1: 完整OCR(检测+识别+角度分类)
  41. print("🔍 模式1: 完整OCR(检测+识别+角度分类)")
  42. start_time = time.time()
  43. full_result = ocr_model.ocr(img, det=True, rec=True, cls=True)
  44. elapsed = time.time() - start_time
  45. print(f"⏱️ 耗时: {elapsed:.2f}秒")
  46. if full_result and full_result[0]:
  47. print(f"📊 检测到 {len(full_result[0])} 个文字区域")
  48. for i, item in enumerate(full_result[0][:3]): # 只显示前3个
  49. bbox, (text, confidence) = item
  50. print(f" {i+1}. '{text}' (置信度: {confidence:.3f})")
  51. if len(full_result[0]) > 3:
  52. print(f" ... 还有 {len(full_result[0]) - 3} 个区域")
  53. print("\n" + "="*60)
  54. # 模式2: 只检测文字区域(不识别文字内容)
  55. print("🎯 模式2: 只检测文字区域(不识别文字内容)")
  56. start_time = time.time()
  57. detection_only = ocr_model.ocr(img, det=True, rec=False, cls=False)
  58. elapsed = time.time() - start_time
  59. print(f"⏱️ 耗时: {elapsed:.2f}秒")
  60. if detection_only and detection_only[0]:
  61. print(f"📍 检测到 {len(detection_only[0])} 个文字区域(仅坐标)")
  62. for i, bbox in enumerate(detection_only[0][:3]): # 只显示前3个
  63. # 计算区域中心和大小
  64. bbox_array = np.array(bbox)
  65. center_x = np.mean(bbox_array[:, 0])
  66. center_y = np.mean(bbox_array[:, 1])
  67. width = np.max(bbox_array[:, 0]) - np.min(bbox_array[:, 0])
  68. height = np.max(bbox_array[:, 1]) - np.min(bbox_array[:, 1])
  69. print(f" {i+1}. 中心({center_x:.0f},{center_y:.0f}) 尺寸({width:.0f}x{height:.0f})")
  70. if len(detection_only[0]) > 3:
  71. print(f" ... 还有 {len(detection_only[0]) - 3} 个区域")
  72. print("\n" + "="*60)
  73. # 模式3: 直接调用文字检测器
  74. print("🔧 模式3: 直接调用文字检测器")
  75. start_time = time.time()
  76. detector_result = ocr_model.text_detector(img)
  77. elapsed = time.time() - start_time
  78. print(f"⏱️ 耗时: {elapsed:.2f}秒")
  79. if detector_result is not None and len(detector_result) > 0:
  80. print(f"🎪 检测到 {len(detector_result)} 个文字区域(原始检测器输出)")
  81. for i, bbox in enumerate(detector_result[:3]): # 只显示前3个
  82. bbox_array = np.array(bbox)
  83. center_x = np.mean(bbox_array[:, 0])
  84. center_y = np.mean(bbox_array[:, 1])
  85. width = np.max(bbox_array[:, 0]) - np.min(bbox_array[:, 0])
  86. height = np.max(bbox_array[:, 1]) - np.min(bbox_array[:, 1])
  87. print(f" {i+1}. 中心({center_x:.0f},{center_y:.0f}) 尺寸({width:.0f}x{height:.0f})")
  88. if len(detector_result) > 3:
  89. print(f" ... 还有 {len(detector_result) - 3} 个区域")
  90. print("\n" + "="*60)
  91. # 性能对比总结
  92. print("📈 性能对比总结:")
  93. print(" 模式1 (完整OCR): 最慢,但提供完整的文字内容和坐标")
  94. print(" 模式2 (仅检测): 较快,只提供文字区域坐标")
  95. print(" 模式3 (检测器): 最快,提供原始检测结果")
  96. print("\n💡 推荐使用场景:")
  97. print(" - 需要文字内容: 使用模式1")
  98. print(" - 只需要区域位置: 使用模式2或3")
  99. print(" - 批量处理/实时应用: 使用模式2或3,然后选择性识别")
  100. if __name__ == '__main__':
  101. # 如果没有参数,使用默认的测试图片
  102. if len(sys.argv) < 2:
  103. # 尝试找到一个可用的测试图片
  104. project_root = Path(__file__).parent.parent.parent
  105. test_paths = [
  106. project_root / "static/漫画/image/鬼-巷第001卷/第一章/test/tmp/0004_鬼-巷第001卷_text_mask.png",
  107. project_root / "static/漫画/image/鬼-巷第001卷/第一章/0004_鬼-巷第001卷.jpeg",
  108. ]
  109. image_path = None
  110. for test_path in test_paths:
  111. if test_path.exists():
  112. image_path = str(test_path)
  113. print(f"[INFO] 使用默认测试图片: {test_path.name}")
  114. break
  115. if image_path is None:
  116. print("Usage: python test_onnxocr_detection_modes.py <image_path>")
  117. print("No test image found. Please provide an image path.")
  118. sys.exit(1)
  119. else:
  120. image_path = sys.argv[1]
  121. # 处理Windows编码问题
  122. try:
  123. # 尝试使用Path对象处理路径
  124. path_obj = Path(image_path)
  125. if not path_obj.exists():
  126. print(f"[ERROR] 图片文件不存在: {image_path}")
  127. # 尝试相对路径
  128. project_root = Path(__file__).parent.parent.parent
  129. path_obj = project_root / image_path
  130. if path_obj.exists():
  131. image_path = str(path_obj)
  132. print(f"[INFO] 使用相对路径: {path_obj}")
  133. else:
  134. sys.exit(1)
  135. else:
  136. image_path = str(path_obj)
  137. except Exception as e:
  138. print(f"[ERROR] 路径处理错误: {e}")
  139. sys.exit(1)
  140. test_different_ocr_modes(image_path)