# -*- coding: utf-8 -*-
"""
Detect text regions in comic pages with comic-text-detector-master.

For one input image the script runs the detector, collects the bounding box
and text-line polygons of every detected text block, and writes them to a
``<image stem>_text_regions.json`` file. No mask image is written.
"""

import sys
import os
import json
from pathlib import Path
import cv2
import numpy as np
import warnings

# Suppress the pkg_resources deprecation warning.
warnings.filterwarnings('ignore', category=DeprecationWarning, message='.*pkg_resources.*')

# Windows console encoding fix: re-wrap stdout/stderr as UTF-8 so the
# non-ASCII log messages below can always be printed.
if sys.platform == 'win32':
    import io
    sys.stdout = io.TextIOWrapper(sys.stdout.buffer, encoding='utf-8', errors='replace')
    sys.stderr = io.TextIOWrapper(sys.stderr.buffer, encoding='utf-8', errors='replace')

# Put the comic-text-detector checkout on sys.path.
project_root = Path(__file__).parent.parent
# Two layouts are tried: a nested folder (archive extracted into a folder of
# the same name) and a flat one.
comic_detector_path1 = project_root / 'comic-text-detector-master' / 'comic-text-detector-master'
comic_detector_path2 = project_root / 'comic-text-detector-master'

if comic_detector_path1.exists() and (comic_detector_path1 / 'inference.py').exists():
    comic_detector_path = comic_detector_path1
elif comic_detector_path2.exists() and (comic_detector_path2 / 'inference.py').exists():
    comic_detector_path = comic_detector_path2
else:
    comic_detector_path = comic_detector_path1  # default path

sys.path.insert(0, str(comic_detector_path))

# wandb is an optional dependency: comic-text-detector imports it, but it is
# not needed for inference.
try:
    import wandb
except ImportError:
    # Register a stand-in object so that a later ``import wandb`` succeeds.
    class FakeWandb:
        @staticmethod
        def init(*args, **kwargs):
            pass

        @staticmethod
        def log(*args, **kwargs):
            pass

        @staticmethod
        def log_model(*args, **kwargs):
            pass

    sys.modules['wandb'] = FakeWandb()

try:
    from inference import TextDetector, REFINEMASK_INPAINT, REFINEMASK_ANNOTATION
    from utils.textblock import TextBlock
    from utils.io_utils import imread, imwrite
except ImportError as e:
    print(f"[ERROR] 无法导入comic-text-detector模块: {e}")
    print(f"[INFO] 请确保已安装依赖: pip install torch torchvision opencv-python numpy tqdm")
    sys.exit(1)


def _find_model_path(candidates):
    """Return the first path in *candidates* that exists, or None."""
    return next((candidate for candidate in candidates if candidate.exists()), None)


def _serialize_lines(lines):
    """Convert text-line polygons into JSON-serializable lists of [x, y] ints.

    Lists, tuples and numpy arrays are all accepted, both for a polygon and
    for a single point; previously numpy-array polygons were silently
    dropped. Points with fewer than two coordinates and polygons without any
    valid point are skipped.
    """
    if lines is None:
        return []

    sequence_types = (list, tuple, np.ndarray)
    lines_data = []
    for line in lines:
        if not isinstance(line, sequence_types):
            continue
        points = [
            [int(float(p[0])), int(float(p[1]))]
            for p in line
            if isinstance(p, sequence_types) and len(p) >= 2
        ]
        if points:
            lines_data.append(points)
    return lines_data


def _text_block_to_dict(index, blk):
    """Build the JSON record for one detected block (*index* is 1-based)."""
    x1, y1, x2, y2 = (int(v) for v in blk.xyxy)
    return {
        'index': index,
        'bbox': {
            'x1': x1,
            'y1': y1,
            'x2': x2,
            'y2': y2,
            'width': int(x2 - x1),
            'height': int(y2 - y1),
            # int() truncates the half-pixel of an odd-sized box.
            'center_x': int((x1 + x2) / 2),
            'center_y': int((y1 + y2) / 2)
        },
        'lines': _serialize_lines(getattr(blk, 'lines', None)),
        'language': str(getattr(blk, 'language', 'unknown')),
        'vertical': bool(getattr(blk, 'vertical', False))
    }


def _to_ocr_dialogue(block):
    """Convert one text-block record into the OCR-compatible entry format."""
    bbox = block['bbox']
    return {
        # Four corner points: top-left, top-right, bottom-right, bottom-left.
        'bbox': [
            [bbox['x1'], bbox['y1']],
            [bbox['x2'], bbox['y1']],
            [bbox['x2'], bbox['y2']],
            [bbox['x1'], bbox['y2']]
        ],
        'text': f'[文字区域{block["index"]}]',  # placeholder text
        'confidence': 0.95,  # fixed value, not a score produced by the detector
        'source': 'comic-text-detector',
        'region_info': {
            'width': bbox['width'],
            'height': bbox['height'],
            'center_x': bbox['center_x'],
            'center_y': bbox['center_y'],
            'vertical': block['vertical'],
            'language': block['language']
        }
    }


def detect_comic_text_with_boxes(image_path, output_dir, project_root=None,
                                 input_size=1536, conf_thresh=0.4, nms_thresh=0.35,
                                 mask_thresh=0.3, act='leaky', refine_mode=0,
                                 keep_undetected_mask=False, erode_iterations=0,
                                 invert_mask=False):
    """
    Detect text regions in a comic image with comic-text-detector.

    Only a JSON file with the text-region coordinates is written; no mask
    image is produced.

    Args:
        image_path: Path of the image to analyse.
        output_dir: Directory for the JSON output (created if missing).
        project_root: Project root used to look for a model under
            ``models/`` (optional; defaults to the parent of this file's
            directory).
        input_size: Detector input size (default 1536).
        conf_thresh: Confidence threshold passed to the detector (default 0.4).
        nms_thresh: NMS threshold passed to the detector (default 0.35).
        mask_thresh: Mask threshold passed to the detector (default 0.3).
        act: Activation function name passed to the detector
            (default 'leaky').
        refine_mode: 0 selects REFINEMASK_INPAINT, any other value selects
            REFINEMASK_ANNOTATION.
        keep_undetected_mask: Forwarded unchanged to the detector call.
        erode_iterations: Deprecated, ignored.
        invert_mask: Deprecated, ignored.

    Returns:
        A dict with ``image_file``, ``image_size``, ``text_blocks``,
        ``total_count``, ``output_files`` (path of the JSON file) and
        ``ocr_compatible_data`` (the same blocks in the OCR entry format).

    Raises:
        FileNotFoundError: The image or every candidate model file is missing.
        ValueError: The image file could not be decoded.
    """
    if project_root is None:
        project_root = Path(__file__).parent.parent
    else:
        project_root = Path(project_root)

    image_path = Path(image_path)
    output_dir = Path(output_dir)

    if not image_path.exists():
        raise FileNotFoundError(f"图片文件不存在: {image_path}")

    # Make sure the output directory exists.
    output_dir.mkdir(parents=True, exist_ok=True)

    print(f"📖 正在检测图片中的文字区域: {image_path.name}")

    # Candidate model locations, in priority order.
    possible_paths = [
        comic_detector_path / 'data' / 'comictextdetector.pt',
        comic_detector_path / 'data' / 'comictextdetector.pt.onnx',
        project_root / 'models' / 'comictextdetector.pt',
        project_root / 'models' / 'comictextdetector.pt.onnx',
    ]
    model_path = _find_model_path(possible_paths)
    if model_path is None:
        raise FileNotFoundError(
            f"未找到模型文件。请下载模型并放到以下位置之一:\n" +
            "\n".join([f" - {p}" for p in possible_paths])
        )

    # Initialise the detector; torch is imported lazily, only when needed.
    import torch
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    print(f"[INFO] 使用设备: {device}")

    try:
        detector = TextDetector(
            model_path=str(model_path),
            input_size=input_size,
            device=device,
            conf_thresh=conf_thresh,
            nms_thresh=nms_thresh,
            mask_thresh=mask_thresh,
            act=act
        )
    except Exception as e:
        print(f"[ERROR] 初始化检测器失败: {e}")
        raise

    # Read the raw bytes and decode in memory instead of cv2.imread, so that
    # paths with non-ASCII characters work.
    img_array = np.fromfile(str(image_path), dtype=np.uint8)
    img = cv2.imdecode(img_array, cv2.IMREAD_COLOR)
    if img is None:
        raise ValueError(f"无法读取图片文件: {image_path}")

    im_h, im_w = img.shape[:2]
    print(f"[INFO] 图片尺寸: {im_w}x{im_h}")

    # Run detection.
    print("[INFO] 正在检测文字区域...")
    print(f"[INFO] 检测参数: input_size={input_size}, conf_thresh={conf_thresh}, nms_thresh={nms_thresh}, refine_mode={refine_mode}")

    try:
        refine_mode_enum = REFINEMASK_INPAINT if refine_mode == 0 else REFINEMASK_ANNOTATION
        # Only the block list is used below; both masks are discarded.
        mask, mask_refined, blk_list = detector(
            img,
            refine_mode=refine_mode_enum,
            keep_undetected_mask=keep_undetected_mask
        )
    except Exception as e:
        print(f"[ERROR] 检测失败: {e}")
        raise

    # Extract the block data (nothing is drawn).
    text_blocks = [
        _text_block_to_dict(position, blk)
        for position, blk in enumerate(blk_list, start=1)
    ]

    print(f"[OK] 检测到 {len(text_blocks)} 个文字区域")

    result = {
        'image_file': image_path.name,
        'image_size': {
            'width': im_w,
            'height': im_h
        },
        'text_blocks': text_blocks,
        'total_count': len(text_blocks)
    }

    # Save the text-region coordinates as JSON.
    json_path = output_dir / f"{image_path.stem}_text_regions.json"
    with open(json_path, 'w', encoding='utf-8') as f:
        json.dump(result, f, ensure_ascii=False, indent=2)

    print(f"[OK] 已保存文字区域坐标: {json_path}")

    # Simplified, OCR-compatible view of the same blocks. It is returned to
    # the caller only and is not written to disk.
    ocr_compatible_result = {
        'dialogues': [_to_ocr_dialogue(block) for block in text_blocks]
    }

    print(f"[INFO] OCR兼容数据将由Node.js处理,不生成中间文件")

    # Output file paths (JSON only).
    result['output_files'] = {
        'text_regions_json': str(json_path)
    }
    result['ocr_compatible_data'] = ocr_compatible_result

    return result


if __name__ == '__main__':
    if len(sys.argv) < 3:
        print("用法: python detect_comic_text_with_boxes.py <图片路径> <输出目录> [项目根目录] [input_size] [conf_thresh] [nms_thresh] [mask_thresh] [act] [refine_mode] [keep_undetected_mask] [erode_iterations] [invert_mask]")
        print("可选参数:")
        print(" input_size: 输入尺寸(默认1536)")
        print(" conf_thresh: 置信度阈值(默认0.4)")
        print(" nms_thresh: NMS阈值(默认0.35)")
        print(" mask_thresh: Mask阈值(默认0.3)")
        print(" act: 激活函数(默认'leaky')")
        print(" refine_mode: 精炼模式(默认0=INPAINT,1=ANNOTATION)")
        print(" keep_undetected_mask: 是否保留未检测区域(默认False,传入1为True)")
        print(" erode_iterations: 腐蚀迭代次数(默认0,值越大文字越细,建议0-3)")
        print(" invert_mask: 是否反转mask(默认False,传入1为True,输出黑字白底)")
        sys.exit(1)

    def _cli_arg(position, convert, default):
        """Return argv[position] passed through *convert*, or *default* if absent."""
        return convert(sys.argv[position]) if len(sys.argv) > position else default

    def _flag(value):
        """Interpret an integer-like CLI string as a boolean ("1" -> True)."""
        return bool(int(value))

    image_path = sys.argv[1]
    output_dir = sys.argv[2]
    project_root = _cli_arg(3, str, None)
    input_size = _cli_arg(4, int, 1536)
    conf_thresh = _cli_arg(5, float, 0.4)
    nms_thresh = _cli_arg(6, float, 0.35)
    mask_thresh = _cli_arg(7, float, 0.3)
    act = _cli_arg(8, str, 'leaky')
    refine_mode = _cli_arg(9, int, 0)
    keep_undetected_mask = _cli_arg(10, _flag, False)
    erode_iterations = _cli_arg(11, int, 0)
    invert_mask = _cli_arg(12, _flag, False)

    try:
        detect_comic_text_with_boxes(
            image_path, output_dir, project_root,
            input_size, conf_thresh, nms_thresh, mask_thresh,
            act, refine_mode, keep_undetected_mask, erode_iterations, invert_mask
        )
    except Exception as e:
        print(f"[ERROR] 处理失败: {e}")
        import traceback
        traceback.print_exc()
        sys.exit(1)