detect_comic_text_with_boxes.py 12 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320
  1. # -*- coding: utf-8 -*-
  2. """
  3. 使用 comic-text-detector-master 检测漫画页面的文字和位置
  4. 并绘制每个对话气泡的范围框
  5. """
  6. import sys
  7. import os
  8. import json
  9. from pathlib import Path
  10. import cv2
  11. import numpy as np
  12. import warnings
  13. # 抑制 pkg_resources 的弃用警告
  14. warnings.filterwarnings('ignore', category=DeprecationWarning, message='.*pkg_resources.*')
  15. # Windows编码修复
  16. if sys.platform == 'win32':
  17. import io
  18. sys.stdout = io.TextIOWrapper(sys.stdout.buffer, encoding='utf-8', errors='replace')
  19. sys.stderr = io.TextIOWrapper(sys.stderr.buffer, encoding='utf-8', errors='replace')
  20. # 添加comic-text-detector路径
  21. project_root = Path(__file__).parent.parent
  22. # 尝试两个可能的路径
  23. comic_detector_path1 = project_root / 'comic-text-detector-master' / 'comic-text-detector-master'
  24. comic_detector_path2 = project_root / 'comic-text-detector-master'
  25. if comic_detector_path1.exists() and (comic_detector_path1 / 'inference.py').exists():
  26. comic_detector_path = comic_detector_path1
  27. elif comic_detector_path2.exists() and (comic_detector_path2 / 'inference.py').exists():
  28. comic_detector_path = comic_detector_path2
  29. else:
  30. comic_detector_path = comic_detector_path1 # 默认路径
  31. sys.path.insert(0, str(comic_detector_path))
  32. # 处理 wandb 可选依赖(comic-text-detector 需要但推理时不需要)
  33. try:
  34. import wandb
  35. except ImportError:
  36. # 创建一个假的 wandb 模块,避免导入错误
  37. class FakeWandb:
  38. @staticmethod
  39. def init(*args, **kwargs):
  40. pass
  41. @staticmethod
  42. def log(*args, **kwargs):
  43. pass
  44. @staticmethod
  45. def log_model(*args, **kwargs):
  46. pass
  47. sys.modules['wandb'] = FakeWandb()
# Import the detector API from the comic-text-detector checkout (expected to
# be importable through the sys.path entry inserted earlier in this file).
# If any of these imports fails, print the error together with an install
# hint and exit with status 1.
try:
    from inference import TextDetector, REFINEMASK_INPAINT, REFINEMASK_ANNOTATION
    from utils.textblock import TextBlock
    from utils.io_utils import imread, imwrite
except ImportError as e:
    print(f"[ERROR] 无法导入comic-text-detector模块: {e}")
    print(f"[INFO] 请确保已安装依赖: pip install torch torchvision opencv-python numpy tqdm")
    sys.exit(1)
  56. def detect_comic_text_with_boxes(image_path, output_dir, project_root=None,
  57. input_size=1536, conf_thresh=0.4, nms_thresh=0.35,
  58. mask_thresh=0.3, act='leaky', refine_mode=0,
  59. keep_undetected_mask=False, erode_iterations=0,
  60. invert_mask=False):
  61. """
  62. 使用comic-text-detector检测漫画图片中的文字和位置
  63. 只生成文字区域坐标JSON文件,不生成mask图片
  64. 参数:
  65. image_path: 图片路径
  66. output_dir: 输出目录
  67. project_root: 项目根目录(可选)
  68. input_size: 输入尺寸(默认1536,越大精度越高但速度越慢,建议1024/1536/2048)
  69. conf_thresh: 置信度阈值(默认0.4,0-1,越高越严格,过滤低质量检测)
  70. nms_thresh: NMS阈值(默认0.35,0-1,越高保留越多重叠框)
  71. mask_thresh: Mask阈值(默认0.3,0-1,用于分割网络,影响mask精度)
  72. act: 激活函数(默认'leaky',可选'leaky'或'relu')
  73. refine_mode: 精炼模式(默认0=INPAINT填充模式,1=ANNOTATION标注模式)
  74. keep_undetected_mask: 是否保留未检测区域(默认False,True会保留更多区域但可能有噪声)
  75. erode_iterations: 腐蚀迭代次数(已废弃,不再使用)
  76. invert_mask: 是否反转mask(已废弃,不再使用)
  77. 返回:
  78. 检测结果字典,包含文字块的位置信息(不包含mask图片)
  79. """
  80. if project_root is None:
  81. project_root = Path(__file__).parent.parent
  82. else:
  83. project_root = Path(project_root)
  84. image_path = Path(image_path)
  85. output_dir = Path(output_dir)
  86. if not image_path.exists():
  87. raise FileNotFoundError(f"图片文件不存在: {image_path}")
  88. # 确保输出目录存在
  89. output_dir.mkdir(parents=True, exist_ok=True)
  90. print(f"📖 正在检测图片中的文字区域: {image_path.name}")
  91. # 设置模型路径
  92. possible_paths = [
  93. comic_detector_path / 'data' / 'comictextdetector.pt',
  94. comic_detector_path / 'data' / 'comictextdetector.pt.onnx',
  95. project_root / 'models' / 'comictextdetector.pt',
  96. project_root / 'models' / 'comictextdetector.pt.onnx',
  97. ]
  98. model_path = None
  99. for path in possible_paths:
  100. if path.exists():
  101. model_path = path
  102. break
  103. if model_path is None:
  104. raise FileNotFoundError(
  105. f"未找到模型文件。请下载模型并放到以下位置之一:\n" +
  106. "\n".join([f" - {p}" for p in possible_paths])
  107. )
  108. # 初始化检测器
  109. device = 'cuda' if __import__('torch').cuda.is_available() else 'cpu'
  110. print(f"[INFO] 使用设备: {device}")
  111. try:
  112. detector = TextDetector(
  113. model_path=str(model_path),
  114. input_size=input_size, # 可配置的输入尺寸
  115. device=device,
  116. conf_thresh=conf_thresh, # 置信度阈值
  117. nms_thresh=nms_thresh, # NMS阈值
  118. mask_thresh=mask_thresh, # Mask阈值
  119. act=act # 激活函数
  120. )
  121. except Exception as e:
  122. print(f"[ERROR] 初始化检测器失败: {e}")
  123. raise
  124. # 读取图片(处理中文路径)
  125. img_array = np.fromfile(str(image_path), dtype=np.uint8)
  126. img = cv2.imdecode(img_array, cv2.IMREAD_COLOR)
  127. if img is None:
  128. raise ValueError(f"无法读取图片文件: {image_path}")
  129. im_h, im_w = img.shape[:2]
  130. print(f"[INFO] 图片尺寸: {im_w}x{im_h}")
  131. # 执行检测
  132. print("[INFO] 正在检测文字区域...")
  133. print(f"[INFO] 检测参数: input_size={input_size}, conf_thresh={conf_thresh}, nms_thresh={nms_thresh}, refine_mode={refine_mode}")
  134. try:
  135. # 根据refine_mode选择模式
  136. refine_mode_enum = REFINEMASK_INPAINT if refine_mode == 0 else REFINEMASK_ANNOTATION
  137. mask, mask_refined, blk_list = detector(
  138. img,
  139. refine_mode=refine_mode_enum, # 可配置的精炼模式
  140. keep_undetected_mask=keep_undetected_mask # 可配置的未检测区域保留
  141. )
  142. except Exception as e:
  143. print(f"[ERROR] 检测失败: {e}")
  144. raise
  145. # 提取文字块信息(不绘制轮廓框,只提取数据)
  146. text_blocks = []
  147. for i, blk in enumerate(blk_list):
  148. # 获取边界框坐标
  149. x1, y1, x2, y2 = blk.xyxy
  150. x1, y1, x2, y2 = int(x1), int(y1), int(x2), int(y2)
  151. # 获取文字行坐标(多边形轮廓)
  152. lines = blk.lines if hasattr(blk, 'lines') else []
  153. # 计算中心点
  154. center_x = (x1 + x2) / 2
  155. center_y = (y1 + y2) / 2
  156. # 计算宽度和高度
  157. width = x2 - x1
  158. height = y2 - y1
  159. # 处理lines,确保所有坐标都是可序列化的
  160. lines_data = []
  161. if lines:
  162. for line in lines:
  163. if isinstance(line, (list, tuple)):
  164. line_points = []
  165. for p in line:
  166. if isinstance(p, (list, tuple)) and len(p) >= 2:
  167. line_points.append([int(float(p[0])), int(float(p[1]))])
  168. if line_points:
  169. lines_data.append(line_points)
  170. text_block = {
  171. 'index': i + 1,
  172. 'bbox': {
  173. 'x1': x1,
  174. 'y1': y1,
  175. 'x2': x2,
  176. 'y2': y2,
  177. 'width': int(width),
  178. 'height': int(height),
  179. 'center_x': int(center_x),
  180. 'center_y': int(center_y)
  181. },
  182. 'lines': lines_data,
  183. 'language': str(getattr(blk, 'language', 'unknown')),
  184. 'vertical': bool(getattr(blk, 'vertical', False))
  185. }
  186. text_blocks.append(text_block)
  187. print(f"[OK] 检测到 {len(text_blocks)} 个文字区域")
  188. # 保存结果
  189. image_name = image_path.stem
  190. result = {
  191. 'image_file': image_path.name,
  192. 'image_size': {
  193. 'width': im_w,
  194. 'height': im_h
  195. },
  196. 'text_blocks': text_blocks,
  197. 'total_count': len(text_blocks)
  198. }
  199. # 保存文字区域坐标JSON文件
  200. json_path = output_dir / f"{image_name}_text_regions.json"
  201. with open(json_path, 'w', encoding='utf-8') as f:
  202. json.dump(result, f, ensure_ascii=False, indent=2)
  203. print(f"[OK] 已保存文字区域坐标: {json_path}")
  204. # 同时保存简化版坐标文件(兼容OCR格式)
  205. ocr_compatible_result = {
  206. 'dialogues': []
  207. }
  208. for block in text_blocks:
  209. # 转换为OCR兼容的格式
  210. bbox = block['bbox']
  211. # 构造四个角点坐标(左上、右上、右下、左下)
  212. bbox_points = [
  213. [bbox['x1'], bbox['y1']], # 左上
  214. [bbox['x2'], bbox['y1']], # 右上
  215. [bbox['x2'], bbox['y2']], # 右下
  216. [bbox['x1'], bbox['y2']] # 左下
  217. ]
  218. ocr_compatible_result['dialogues'].append({
  219. 'bbox': bbox_points,
  220. 'text': f'[文字区域{block["index"]}]', # 占位符文字
  221. 'confidence': 0.95, # 高置信度,因为是专门的检测器
  222. 'source': 'comic-text-detector',
  223. 'region_info': {
  224. 'width': bbox['width'],
  225. 'height': bbox['height'],
  226. 'center_x': bbox['center_x'],
  227. 'center_y': bbox['center_y'],
  228. 'vertical': block['vertical'],
  229. 'language': block['language']
  230. }
  231. })
  232. # 不再生成mask图片,只生成JSON文件
  233. print(f"[INFO] OCR兼容数据将由Node.js处理,不生成中间文件")
  234. # 添加文件路径到返回结果(只包含JSON文件)
  235. result['output_files'] = {
  236. 'text_regions_json': str(json_path)
  237. }
  238. # 添加OCR兼容数据供Node.js使用
  239. result['ocr_compatible_data'] = ocr_compatible_result
  240. return result
  241. if __name__ == '__main__':
  242. if len(sys.argv) < 3:
  243. print("用法: python detect_comic_text_with_boxes.py <图片路径> <输出目录> [项目根目录] [input_size] [conf_thresh] [nms_thresh] [mask_thresh] [act] [refine_mode] [keep_undetected_mask] [erode_iterations]")
  244. print("可选参数:")
  245. print(" input_size: 输入尺寸(默认1536)")
  246. print(" conf_thresh: 置信度阈值(默认0.4)")
  247. print(" nms_thresh: NMS阈值(默认0.35)")
  248. print(" mask_thresh: Mask阈值(默认0.3)")
  249. print(" act: 激活函数(默认'leaky')")
  250. print(" refine_mode: 精炼模式(默认0=INPAINT,1=ANNOTATION)")
  251. print(" keep_undetected_mask: 是否保留未检测区域(默认False,传入1为True)")
  252. print(" erode_iterations: 腐蚀迭代次数(默认0,值越大文字越细,建议0-3)")
  253. print(" invert_mask: 是否反转mask(默认False,传入1为True,输出黑字白底)")
  254. sys.exit(1)
  255. image_path = sys.argv[1]
  256. output_dir = sys.argv[2]
  257. project_root = sys.argv[3] if len(sys.argv) > 3 else None
  258. input_size = int(sys.argv[4]) if len(sys.argv) > 4 else 1536
  259. conf_thresh = float(sys.argv[5]) if len(sys.argv) > 5 else 0.4
  260. nms_thresh = float(sys.argv[6]) if len(sys.argv) > 6 else 0.35
  261. mask_thresh = float(sys.argv[7]) if len(sys.argv) > 7 else 0.3
  262. act = sys.argv[8] if len(sys.argv) > 8 else 'leaky'
  263. refine_mode = int(sys.argv[9]) if len(sys.argv) > 9 else 0
  264. keep_undetected_mask = bool(int(sys.argv[10])) if len(sys.argv) > 10 else False
  265. erode_iterations = int(sys.argv[11]) if len(sys.argv) > 11 else 0
  266. invert_mask = bool(int(sys.argv[12])) if len(sys.argv) > 12 else False
  267. try:
  268. detect_comic_text_with_boxes(
  269. image_path, output_dir, project_root,
  270. input_size, conf_thresh, nms_thresh, mask_thresh,
  271. act, refine_mode, keep_undetected_mask, erode_iterations, invert_mask
  272. )
  273. except Exception as e:
  274. print(f"[ERROR] 处理失败: {e}")
  275. import traceback
  276. traceback.print_exc()
  277. sys.exit(1)