preprocess_image.py 6.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165
  1. # -*- coding: utf-8 -*-
  2. """
  3. 图像预处理:提高OCR准确率
  4. 包括:对比度增强、去噪、锐化、二值化
  5. """
  6. import sys
  7. import cv2
  8. import numpy as np
  9. from pathlib import Path
  10. # Windows编码修复
  11. if sys.platform == 'win32':
  12. import io
  13. sys.stdout = io.TextIOWrapper(sys.stdout.buffer, encoding='utf-8', errors='replace')
  14. sys.stderr = io.TextIOWrapper(sys.stderr.buffer, encoding='utf-8', errors='replace')
  15. def preprocess_image_for_ocr(input_path, output_path):
  16. """
  17. 对图像进行预处理以提高OCR准确率(针对黑底白字漫画优化)
  18. 步骤:
  19. 1. 检测背景类型(黑底白字 or 白底黑字)
  20. 2. 颜色反转(如果是黑底白字,转换为白底黑字,OCR模型通常训练在白底黑字上)
  21. 3. 提高对比度:使用 CLAHE 自适应直方图均衡化
  22. 4. 去噪:使用 cv2.fastNlMeansDenoising
  23. 5. 锐化:使用锐化核增强文字边缘
  24. 6. 二值化:使用 OTSU 或自适应阈值,确保文字清晰
  25. 7. 形态学操作:去除小噪点,填充空洞
  26. 参数:
  27. input_path: 输入图片路径
  28. output_path: 输出图片路径
  29. 返回:
  30. 处理后的图片(numpy数组)
  31. """
  32. # 读取图片(处理中文路径)
  33. img_array = np.fromfile(str(input_path), dtype=np.uint8)
  34. img = cv2.imdecode(img_array, cv2.IMREAD_COLOR)
  35. if img is None:
  36. raise ValueError(f"无法读取图片: {input_path}")
  37. print(f"[INFO] 读取图片: {Path(input_path).name}")
  38. print(f"[INFO] 原始图片尺寸: {img.shape[1]}x{img.shape[0]}")
  39. # 转换为灰度图(如果还不是)
  40. if len(img.shape) == 3:
  41. gray = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
  42. else:
  43. gray = img.copy()
  44. # 步骤1: 检测背景类型(黑底白字 or 白底黑字)
  45. mean_brightness = np.mean(gray)
  46. is_dark_background = mean_brightness < 127 # 平均亮度小于127认为是黑底
  47. print(f"[INFO] 步骤1: 检测背景类型...")
  48. print(f" 平均亮度: {mean_brightness:.1f} ({'黑底白字' if is_dark_background else '白底黑字'})")
  49. # 步骤2: 颜色反转(如果是黑底白字,转换为白底黑字)
  50. # OCR模型通常训练在白底黑字上,所以需要反转
  51. if is_dark_background:
  52. print("[INFO] 步骤2: 颜色反转(黑底白字 -> 白底黑字)...")
  53. gray = cv2.bitwise_not(gray)
  54. else:
  55. print("[INFO] 步骤2: 跳过反转(已是白底黑字)")
  56. # 步骤3: 提高对比度 - 使用 CLAHE 自适应直方图均衡化
  57. # 对于黑底白字反转后的图片,适度增强对比度
  58. print("[INFO] 步骤3: 提高对比度(CLAHE)...")
  59. clahe = cv2.createCLAHE(clipLimit=2.0, tileGridSize=(8, 8))
  60. enhanced = clahe.apply(gray)
  61. # 步骤4: 去噪 - 使用 fastNlMeansDenoising
  62. # 对于反转后的图片,去噪参数可以稍微降低
  63. print("[INFO] 步骤4: 去噪...")
  64. denoised = cv2.fastNlMeansDenoising(enhanced, None, h=10, templateWindowSize=7, searchWindowSize=21)
  65. # 步骤5: 锐化 - 使用锐化核增强文字边缘
  66. print("[INFO] 步骤5: 锐化...")
  67. # 创建锐化核
  68. sharpen_kernel = np.array([
  69. [0, -1, 0],
  70. [-1, 5, -1],
  71. [0, -1, 0]
  72. ])
  73. sharpened = cv2.filter2D(denoised, -1, sharpen_kernel)
  74. # 步骤6: 二值化 - 使用 OTSU 阈值,然后进行更严格的二值化
  75. print("[INFO] 步骤6: 二值化(OTSU + 自适应阈值)...")
  76. # 使用OTSU自动阈值(反转后应该是白底黑字,OTSU效果会很好)
  77. otsu_thresh, binary_otsu = cv2.threshold(sharpened, 0, 255, cv2.THRESH_BINARY + cv2.THRESH_OTSU)
  78. # 如果OTSU阈值太低,使用更严格的阈值确保文字区域更干净
  79. # 对于白底黑字,我们希望文字(黑色)更纯,背景(白色)更干净
  80. if otsu_thresh < 127:
  81. # 使用更严格的阈值,确保文字区域更黑(更干净)
  82. _, binary_strict = cv2.threshold(sharpened, otsu_thresh + 10, 255, cv2.THRESH_BINARY)
  83. binary_otsu = binary_strict
  84. # 步骤7: 形态学操作,去除小噪点,填充文字内部空洞,清理文字内部
  85. print("[INFO] 步骤7: 形态学操作(去噪点、填充空洞、清理文字内部)...")
  86. # 先开运算:去除小的噪点(在文字外部)
  87. kernel_open_small = np.ones((2, 2), np.uint8)
  88. cleaned = cv2.morphologyEx(binary_otsu, cv2.MORPH_OPEN, kernel_open_small, iterations=1)
  89. # 闭运算:填充文字内部的小空洞
  90. kernel_close = np.ones((3, 3), np.uint8)
  91. filled = cv2.morphologyEx(cleaned, cv2.MORPH_CLOSE, kernel_close, iterations=2)
  92. # 再次开运算:清理文字边缘的小突起
  93. kernel_open_edge = np.ones((2, 2), np.uint8)
  94. result = cv2.morphologyEx(filled, cv2.MORPH_OPEN, kernel_open_edge, iterations=1)
  95. # 额外的清理步骤:使用连通域分析去除小的噪点
  96. # 找到所有连通域
  97. num_labels, labels, stats, centroids = cv2.connectedComponentsWithStats(255 - result, connectivity=8)
  98. # 计算平均连通域面积(排除背景)
  99. if num_labels > 1:
  100. areas = stats[1:, cv2.CC_STAT_AREA]
  101. if len(areas) > 0:
  102. # 计算中位数面积,作为阈值
  103. median_area = np.median(areas)
  104. min_area = max(10, median_area * 0.1) # 至少保留10像素,或中位数的10%
  105. # 创建清理后的mask
  106. cleaned_mask = np.zeros_like(result)
  107. for i in range(1, num_labels):
  108. if stats[i, cv2.CC_STAT_AREA] >= min_area:
  109. # 保留这个连通域
  110. cleaned_mask[labels == i] = 255
  111. # 反转回来(因为连通域分析是在反转图像上做的)
  112. result = 255 - cleaned_mask
  113. # 保存结果(处理中文路径)
  114. success, encoded_img = cv2.imencode('.png', result)
  115. if success:
  116. encoded_img.tofile(str(output_path))
  117. print(f"[OK] 预处理完成,已保存: {Path(output_path).name}")
  118. else:
  119. raise ValueError(f"保存图片失败: {output_path}")
  120. return result
  121. if __name__ == '__main__':
  122. if len(sys.argv) < 3:
  123. print("用法: python preprocess_image.py <输入图片路径> <输出图片路径>")
  124. sys.exit(1)
  125. input_path = sys.argv[1]
  126. output_path = sys.argv[2]
  127. try:
  128. preprocess_image_for_ocr(input_path, output_path)
  129. except Exception as e:
  130. print(f"[ERROR] 预处理失败: {e}")
  131. import traceback
  132. traceback.print_exc()
  133. sys.exit(1)