ocr_with_paddleocr.py 12 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300
  1. # -*- coding: utf-8 -*-
  2. """
  3. 使用PaddleOCR识别图片中的文字
  4. """
  5. import sys
  6. import json
  7. import cv2
  8. import numpy as np
  9. import os
  10. from pathlib import Path
  11. # ========== 必须在所有导入之前设置环境变量 ==========
  12. # 跳过模型源检查,加快启动速度(必须在导入 PaddleOCR 之前设置)
  13. # 注意:正确的环境变量名是 PADDLE_PDX_DISABLE_MODEL_SOURCE_CHECK
  14. # 设置为 'True' 会跳过连接检查,设置为 'False' 或不设置会进行连接检查
  15. # 由于我们已经直接指定了本地模型路径,可以禁用这个检查
  16. os.environ['PADDLE_PDX_DISABLE_MODEL_SOURCE_CHECK'] = 'True'
  17. # 禁用 oneDNN 以避免 NotImplementedError(PaddlePaddle 3.3.0 的已知问题)
  18. # 必须在导入 PaddlePaddle 之前设置
  19. os.environ['FLAGS_onednn'] = '0'
  20. os.environ['FLAGS_use_mkldnn'] = '0'
  21. os.environ['FLAGS_enable_onednn_layout_fusion'] = '0'
  22. os.environ['FLAGS_use_mkldnn'] = 'false'
  23. os.environ['FLAGS_onednn'] = 'false'
  24. # 禁用 oneDNN 的更多选项
  25. os.environ['FLAGS_use_mkldnn'] = 'OFF'
  26. os.environ['FLAGS_onednn'] = 'OFF'
  27. # 设置日志级别,减少不必要的日志输出
  28. # 注意:必须在导入 logging 相关模块之前设置
  29. import logging
  30. import warnings
  31. # 设置 paddlex 的日志级别为 WARNING,减少不必要的日志输出
  32. logging.getLogger('paddlex').setLevel(logging.WARNING)
  33. logging.getLogger('paddlex.inference').setLevel(logging.WARNING)
  34. logging.getLogger('paddlex.inference.utils').setLevel(logging.WARNING)
  35. logging.getLogger('paddlex.inference.utils.official_models').setLevel(logging.WARNING)
  36. # 抑制 pkg_resources 的弃用警告
  37. warnings.filterwarnings('ignore', category=UserWarning, message='.*pkg_resources.*')
  38. warnings.filterwarnings('ignore', category=DeprecationWarning, message='.*pkg_resources.*')
  39. # 抑制 ccache 警告(这是 PaddlePaddle 的警告,不影响功能)
  40. warnings.filterwarnings('ignore', message='.*ccache.*')
  41. # ==================================================
  42. # Windows编码修复
  43. if sys.platform == 'win32':
  44. import io
  45. sys.stdout = io.TextIOWrapper(sys.stdout.buffer, encoding='utf-8', errors='replace')
  46. sys.stderr = io.TextIOWrapper(sys.stderr.buffer, encoding='utf-8', errors='replace')
  47. # 添加PaddleOCR路径
  48. project_root = Path(__file__).parent.parent
  49. # 尝试多个可能的路径
  50. paddleocr_paths = [
  51. project_root / 'PaddleOCR-main', # 直接路径
  52. project_root / 'PaddleOCR-main' / 'PaddleOCR-main', # 嵌套路径
  53. ]
  54. paddleocr_path = None
  55. for path in paddleocr_paths:
  56. if path.exists() and (path / 'paddleocr').exists():
  57. paddleocr_path = path
  58. break
  59. if paddleocr_path:
  60. sys.path.insert(0, str(paddleocr_path))
  61. print(f"[INFO] 使用本地PaddleOCR路径: {paddleocr_path}")
  62. else:
  63. print(f"[WARN] 未找到本地PaddleOCR,尝试使用pip安装的版本")
  64. try:
  65. from paddleocr import PaddleOCR
  66. PADDLEOCR_AVAILABLE = True
  67. except ImportError as e:
  68. print(f"[ERROR] 无法导入PaddleOCR模块: {e}")
  69. print("[ERROR] PaddleOCR 是必需的,请确保已正确安装")
  70. PADDLEOCR_AVAILABLE = False
  71. sys.exit(1)
  72. def ocr_with_paddleocr(image_path, text_mask_path, output_dir):
  73. """
  74. 使用PaddleOCR识别图片中的文字
  75. 参数:
  76. image_path: 原始图片路径
  77. text_mask_path: 文字遮罩图路径(用于参考,可选,当前未使用)
  78. output_dir: 输出目录
  79. """
  80. image_path = Path(image_path)
  81. # 处理空字符串的情况
  82. text_mask_path = Path(text_mask_path) if text_mask_path and text_mask_path.strip() else None
  83. output_dir = Path(output_dir)
  84. # 使用 Path.mkdir 处理中文路径,比 os.makedirs 更可靠
  85. output_dir.mkdir(parents=True, exist_ok=True)
  86. print(f"📖 读取原始图片: {image_path.name}")
  87. # 读取原始图片(处理中文路径)
  88. img_array = np.fromfile(str(image_path), dtype=np.uint8)
  89. img = cv2.imdecode(img_array, cv2.IMREAD_COLOR)
  90. if img is None:
  91. raise ValueError(f"无法读取图片: {image_path}")
  92. # 确保是RGB格式(3通道)
  93. if len(img.shape) == 2:
  94. # 如果是灰度图,转换为RGB
  95. img = cv2.cvtColor(img, cv2.COLOR_GRAY2RGB)
  96. elif img.shape[2] == 4:
  97. # 如果是RGBA,转换为RGB
  98. img = cv2.cvtColor(img, cv2.COLOR_RGBA2RGB)
  99. elif img.shape[2] == 3:
  100. # BGR转RGB
  101. img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
  102. img_height, img_width = img.shape[:2]
  103. print(f"[INFO] 图片尺寸: {img_width}x{img_height}")
  104. # 初始化PaddleOCR
  105. print("[INFO] 初始化PaddleOCR...")
  106. try:
  107. # 使用简单的初始化方式,让 PaddleOCR 自动使用本地已下载的模型
  108. # 这样更稳定,避免直接指定模型路径可能导致的初始化问题
  109. paddleocr_instance = PaddleOCR(
  110. lang='ch', # 中文
  111. enable_mkldnn=False # 明确禁用 MKL-DNN/oneDNN
  112. )
  113. print("[INFO] PaddleOCR 初始化成功")
  114. except Exception as e:
  115. print(f"[ERROR] PaddleOCR初始化失败: {e}")
  116. import traceback
  117. traceback.print_exc()
  118. raise RuntimeError(f"PaddleOCR 初始化失败: {e}")
  119. # 执行OCR识别
  120. print("[INFO] 正在识别文字...")
  121. print(f"[DEBUG] 图片数组信息: shape={img.shape}, dtype={img.dtype}, min={img.min()}, max={img.max()}")
  122. try:
  123. # 使用已读取的图片数组,传递给 PaddleOCR(避免中文路径问题)
  124. print(f"[DEBUG] 准备调用 paddleocr_instance.predict...")
  125. import sys
  126. sys.stdout.flush() # 确保输出被刷新
  127. ocr_result = paddleocr_instance.predict(img)
  128. print(f"[DEBUG] OCR结果类型: {type(ocr_result)}, 长度: {len(ocr_result) if ocr_result else 0}")
  129. sys.stdout.flush()
  130. except Exception as e:
  131. print(f"[ERROR] OCR识别失败: {e}")
  132. import traceback
  133. traceback.print_exc()
  134. sys.stdout.flush()
  135. raise
  136. if not ocr_result or len(ocr_result) == 0:
  137. print("[WARN] 未识别到任何文字")
  138. dialogues = []
  139. else:
  140. # 解析PaddleOCR结果
  141. result_item = ocr_result[0]
  142. # PaddleOCR 3.x 返回的是 OCRResult 对象,通过 .json 属性获取数据
  143. try:
  144. result_json = result_item.json
  145. res_data = result_json.get('res', {}) if isinstance(result_json, dict) else {}
  146. # 提取文本、置信度、坐标
  147. rec_texts = res_data.get('rec_texts', [])
  148. rec_scores = res_data.get('rec_scores', [])
  149. rec_polys = res_data.get('rec_polys', []) # 多边形坐标 [[[x1,y1],[x2,y2],[x3,y3],[x4,y4]], ...]
  150. rec_boxes = res_data.get('rec_boxes', []) # 边界框 [[x1,y1,x2,y2], ...]
  151. print(f"[OK] 识别到 {len(rec_texts)} 个文本区域")
  152. # 提取对话文本
  153. dialogues = []
  154. for idx, text in enumerate(rec_texts):
  155. if not text or not text.strip():
  156. continue
  157. # 获取置信度
  158. confidence = float(rec_scores[idx]) if idx < len(rec_scores) else 0.9
  159. # 获取坐标(优先使用多边形坐标,如果没有则使用边界框)
  160. if idx < len(rec_polys) and rec_polys[idx]:
  161. bbox_coords = rec_polys[idx] # [[x1,y1],[x2,y2],[x3,y3],[x4,y4]]
  162. elif idx < len(rec_boxes) and rec_boxes[idx]:
  163. # 将边界框转换为多边形格式
  164. box = rec_boxes[idx] # [x1, y1, x2, y2]
  165. bbox_coords = [
  166. [box[0], box[1]], # 左上
  167. [box[2], box[1]], # 右上
  168. [box[2], box[3]], # 右下
  169. [box[0], box[3]] # 左下
  170. ]
  171. else:
  172. print(f" [WARN] 第 {idx} 个文本没有坐标信息,跳过")
  173. continue
  174. # 计算边界框
  175. if not isinstance(bbox_coords, (list, tuple)) or len(bbox_coords) < 4:
  176. print(f" [WARN] 第 {idx} 个文本坐标格式不正确,跳过")
  177. continue
  178. try:
  179. x_coords = []
  180. y_coords = []
  181. for coord in bbox_coords:
  182. if isinstance(coord, (list, tuple)) and len(coord) >= 2:
  183. x_coords.append(coord[0])
  184. y_coords.append(coord[1])
  185. if not x_coords or not y_coords or len(x_coords) < 4:
  186. print(f" [WARN] 第 {idx} 个文本无法提取足够的坐标点,跳过")
  187. continue
  188. x1 = int(min(x_coords))
  189. y1 = int(min(y_coords))
  190. x2 = int(max(x_coords))
  191. y2 = int(max(y_coords))
  192. dialogues.append({
  193. 'order': len(dialogues) + 1,
  194. 'text': text.strip(),
  195. 'bbox': {
  196. 'x1': x1,
  197. 'y1': y1,
  198. 'x2': x2,
  199. 'y2': y2,
  200. 'width': x2 - x1,
  201. 'height': y2 - y1,
  202. 'center_x': float((x1 + x2) / 2),
  203. 'center_y': float((y1 + y2) / 2)
  204. },
  205. 'confidence': confidence
  206. })
  207. print(f" [{len(dialogues)}/{len(rec_texts)}] {text[:50]}...")
  208. except (TypeError, IndexError, ValueError) as e:
  209. print(f" [WARN] 第 {idx} 个文本解析坐标失败: {e},跳过")
  210. continue
  211. except Exception as e:
  212. print(f"[ERROR] 解析PaddleOCR结果失败: {e}")
  213. import traceback
  214. traceback.print_exc()
  215. dialogues = []
  216. # 保存结果
  217. image_name = image_path.stem
  218. output_json = {
  219. 'image_file': f"{image_name}{image_path.suffix}",
  220. 'reading_order': '从右到左、从上到下(日式漫画阅读顺序)',
  221. 'dialogues': dialogues,
  222. 'total_count': len(dialogues)
  223. }
  224. output_file = output_dir / f"{image_name}_dialogues.json"
  225. with open(output_file, 'w', encoding='utf-8') as f:
  226. json.dump(output_json, f, ensure_ascii=False, indent=2)
  227. print(f"\n✅ 结果已保存到: {output_file}")
  228. return output_file
  229. if __name__ == '__main__':
  230. try:
  231. print(f"[DEBUG] sys.argv: {sys.argv}")
  232. print(f"[DEBUG] sys.argv长度: {len(sys.argv)}")
  233. # 至少需要3个参数:脚本名、图片路径、输出目录
  234. # text_mask_path 是可选的,可以为空(空字符串会被shell忽略)
  235. if len(sys.argv) < 3:
  236. print("用法: python ocr_with_paddleocr.py <原始图片路径> [文字遮罩图路径] <输出目录>")
  237. sys.exit(1)
  238. image_path = sys.argv[1]
  239. # 如果只有3个参数,说明没有 text_mask_path(空字符串被忽略),output_dir 是第二个参数
  240. if len(sys.argv) == 3:
  241. text_mask_path = ""
  242. output_dir = sys.argv[2]
  243. elif len(sys.argv) >= 4:
  244. # 有4个或更多参数:脚本名、图片路径、text_mask_path、输出目录
  245. text_mask_path = sys.argv[2]
  246. output_dir = sys.argv[3]
  247. else:
  248. # 不应该到这里
  249. raise ValueError("参数数量不正确")
  250. print(f"[DEBUG] 参数: image_path={image_path}, text_mask_path={text_mask_path}, output_dir={output_dir}")
  251. # 验证图片路径是否存在
  252. if not Path(image_path).exists():
  253. raise FileNotFoundError(f"图片文件不存在: {image_path}")
  254. ocr_with_paddleocr(image_path, text_mask_path, output_dir)
  255. except KeyboardInterrupt:
  256. print("[INFO] 用户中断")
  257. sys.exit(1)
  258. except Exception as e:
  259. print(f"[ERROR] OCR识别失败: {e}")
  260. import traceback
  261. traceback.print_exc()
  262. sys.exit(1)