# ocr_with_onnxocr_optimized.py (8.0 KB)
  1. import sys
  2. import os
  3. from pathlib import Path
  4. import cv2
  5. import numpy as np
  6. import json
  7. import time
  8. import argparse
  9. # 确保输出编码为UTF-8
  10. sys.stdout.reconfigure(encoding='utf-8')
  11. sys.stderr.reconfigure(encoding='utf-8')
  12. # 添加OnnxOCR路径
  13. project_root = Path(__file__).parent.parent.parent
  14. onnxocr_path = project_root / 'python' / 'OnnxOCR-main'
  15. if onnxocr_path.exists():
  16. sys.path.insert(0, str(onnxocr_path))
  17. print(f"[INFO] 使用本地OnnxOCR路径: {onnxocr_path}")
  18. else:
  19. print(f"[ERROR] 未找到本地OnnxOCR路径: {onnxocr_path}")
  20. sys.exit(1)
  21. try:
  22. from onnxocr.onnx_paddleocr import ONNXPaddleOcr
  23. ONNXOCR_AVAILABLE = True
  24. except ImportError as e:
  25. print(f"[ERROR] 无法导入OnnxOCR模块: {e}")
  26. ONNXOCR_AVAILABLE = False
  27. sys.exit(1)
  28. def ocr_with_onnxocr_modes(image_path, text_mask_path, output_dir, mode="full"):
  29. """
  30. 使用OnnxOCR进行OCR识别,支持不同模式
  31. Args:
  32. image_path: 输入图片路径
  33. text_mask_path: 文字遮罩路径(可以为空)
  34. output_dir: 输出目录
  35. mode: OCR模式 - "full"(完整), "detect"(仅检测), "fast"(快速检测)
  36. """
  37. if not ONNXOCR_AVAILABLE:
  38. print("[ERROR] OnnxOCR 不可用")
  39. return None
  40. # 创建输出目录
  41. output_dir = Path(output_dir)
  42. output_dir.mkdir(parents=True, exist_ok=True)
  43. print(f"[INFO] OCR模式: {mode}")
  44. print(f"[INFO] 输入图片: {image_path}")
  45. print(f"[INFO] 输出目录: {output_dir}")
  46. try:
  47. # 初始化OnnxOCR
  48. print("[INFO] 初始化OnnxOCR...")
  49. start_init = time.time()
  50. onnxocr_instance = ONNXPaddleOcr(use_angle_cls=True, use_gpu=False)
  51. print(f"[INFO] OnnxOCR 初始化完成 ({time.time()-start_init:.2f}秒)")
  52. # 读取图片
  53. print(f"[INFO] 读取图片: {image_path}")
  54. img_array = np.fromfile(str(image_path), dtype=np.uint8)
  55. img = cv2.imdecode(img_array, cv2.IMREAD_COLOR)
  56. if img is None:
  57. raise ValueError(f"无法读取图片: {image_path}")
  58. print(f"[INFO] 图片读取成功,尺寸: {img.shape}")
  59. # 根据模式执行不同的OCR操作
  60. start_ocr = time.time()
  61. if mode == "full":
  62. # 完整OCR模式:检测+识别+角度分类
  63. print("[INFO] 执行完整OCR识别(检测+识别+角度分类)...")
  64. ocr_result = onnxocr_instance.ocr(img, det=True, rec=True, cls=True)
  65. dialogues = []
  66. if ocr_result and ocr_result[0]:
  67. for detection in ocr_result[0]:
  68. bbox, (text, confidence) = detection
  69. dialogues.append({
  70. "bbox": bbox,
  71. "text": text,
  72. "confidence": float(confidence),
  73. "mode": "full_ocr"
  74. })
  75. elif mode == "detect":
  76. # 仅检测模式:只检测文字区域,不识别文字
  77. print("[INFO] 执行文字区域检测(仅坐标,不识别文字)...")
  78. detection_result = onnxocr_instance.ocr(img, det=True, rec=False, cls=False)
  79. dialogues = []
  80. if detection_result and detection_result[0]:
  81. for i, bbox in enumerate(detection_result[0]):
  82. dialogues.append({
  83. "bbox": bbox,
  84. "text": f"[区域{i+1}]", # 占位符文字
  85. "confidence": 1.0,
  86. "mode": "detection_only"
  87. })
  88. elif mode == "fast":
  89. # 快速检测模式:直接使用检测器
  90. print("[INFO] 执行快速文字检测(直接检测器)...")
  91. dt_boxes = onnxocr_instance.text_detector(img)
  92. dialogues = []
  93. if dt_boxes is not None and len(dt_boxes) > 0:
  94. for i, bbox in enumerate(dt_boxes):
  95. dialogues.append({
  96. "bbox": bbox.tolist(), # 转换numpy数组为列表
  97. "text": f"[快速检测{i+1}]", # 占位符文字
  98. "confidence": 1.0,
  99. "mode": "fast_detection"
  100. })
  101. else:
  102. raise ValueError(f"不支持的模式: {mode}")
  103. ocr_elapsed = time.time() - start_ocr
  104. print(f"[INFO] OCR处理完成 ({ocr_elapsed:.2f}秒)")
  105. print(f"[INFO] 检测到 {len(dialogues)} 个文字区域")
  106. # 保存结果到JSON文件
  107. image_name = Path(image_path).stem
  108. output_json_path = output_dir / f"{image_name}_dialogues_{mode}.json"
  109. result_data = {
  110. "dialogues": dialogues,
  111. "total_dialogues": len(dialogues),
  112. "image_path": str(image_path),
  113. "ocr_engine": "OnnxOCR",
  114. "ocr_mode": mode,
  115. "processing_time": {
  116. "initialization": f"{start_init:.2f}s",
  117. "ocr_processing": f"{ocr_elapsed:.2f}s",
  118. "total": f"{time.time()-start_init:.2f}s"
  119. },
  120. "performance_info": {
  121. "detected_regions": len(dialogues),
  122. "mode_description": {
  123. "full": "完整OCR:检测+识别+角度分类",
  124. "detect": "仅检测:只检测区域坐标,不识别文字",
  125. "fast": "快速检测:直接使用检测器,最快速度"
  126. }.get(mode, "未知模式")
  127. }
  128. }
  129. with open(output_json_path, 'w', encoding='utf-8') as f:
  130. json.dump(result_data, f, ensure_ascii=False, indent=2)
  131. print(f"[INFO] 结果已保存到: {output_json_path}")
  132. # 打印识别结果预览
  133. print("[INFO] 识别结果预览:")
  134. for i, d in enumerate(dialogues[:5]):
  135. if mode == "full":
  136. print(f" {i+1}. '{d['text']}' (置信度: {d['confidence']:.3f})")
  137. else:
  138. bbox = d['bbox']
  139. if isinstance(bbox[0], list): # 多边形格式
  140. bbox_array = np.array(bbox)
  141. center_x = np.mean(bbox_array[:, 0])
  142. center_y = np.mean(bbox_array[:, 1])
  143. width = np.max(bbox_array[:, 0]) - np.min(bbox_array[:, 0])
  144. height = np.max(bbox_array[:, 1]) - np.min(bbox_array[:, 1])
  145. print(f" {i+1}. 区域中心({center_x:.0f},{center_y:.0f}) 尺寸({width:.0f}x{height:.0f})")
  146. if len(dialogues) > 5:
  147. print(f" ... 还有 {len(dialogues) - 5} 个区域")
  148. print(f"[SUCCESS] OCR识别完成,共处理 {len(dialogues)} 个区域")
  149. return {
  150. "json_path": str(output_json_path),
  151. "total_count": len(dialogues),
  152. "mode": mode,
  153. "processing_time": ocr_elapsed
  154. }
  155. except Exception as e:
  156. print(f"[ERROR] OCR处理失败: {e}")
  157. import traceback
  158. traceback.print_exc()
  159. return None
  160. if __name__ == '__main__':
  161. parser = argparse.ArgumentParser(description='OnnxOCR多模式文字识别')
  162. parser.add_argument('image_path', help='输入图片路径')
  163. parser.add_argument('text_mask_path', nargs='?', default='', help='文字遮罩路径(可选)')
  164. parser.add_argument('output_dir', help='输出目录')
  165. parser.add_argument('--mode', choices=['full', 'detect', 'fast'], default='full',
  166. help='OCR模式:full(完整OCR), detect(仅检测), fast(快速检测)')
  167. args = parser.parse_args()
  168. print(f"[DEBUG] 开始OCR处理...")
  169. print(f"[DEBUG] 参数: 图片={args.image_path}, 模式={args.mode}")
  170. result = ocr_with_onnxocr_modes(args.image_path, args.text_mask_path, args.output_dir, args.mode)
  171. if result:
  172. print(f"[SUCCESS] 处理完成: {result}")
  173. else:
  174. print("[ERROR] 处理失败")
  175. sys.exit(1)