| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320 |
# -*- coding: utf-8 -*-
"""
Detect text regions and their positions on manga/comic pages with
comic-text-detector-master.

No boxes or masks are drawn: ``detect_comic_text_with_boxes`` writes the
text-region coordinates to ``<image stem>_text_regions.json`` and returns them,
together with an OCR-compatible ``dialogues`` structure, to the caller.
"""
import sys
import os
import json
from pathlib import Path
import cv2
import numpy as np
import warnings
# Silence DeprecationWarnings whose message mentions pkg_resources (presumably
# emitted while importing torch / comic-text-detector dependencies - TODO confirm).
warnings.filterwarnings('ignore', category=DeprecationWarning, message='.*pkg_resources.*')
import io


def ensure_utf8_text_stream(stream):
    """Make a text stream encode as UTF-8, replacing unencodable characters.

    Prefers ``TextIOWrapper.reconfigure`` (Python 3.7+), which changes the
    encoding of the existing stream in place. That keeps the stream's
    buffering mode (e.g. line buffering on a console). It also avoids having
    two independent text wrappers writing into the same underlying buffer,
    because the original wrapper stays reachable as ``sys.__stdout__`` /
    ``sys.__stderr__``. Streams without ``reconfigure`` fall back to wrapping
    ``stream.buffer`` in a new ``io.TextIOWrapper``.

    Args:
        stream: A text stream, or ``None`` (standard streams can be ``None``,
            e.g. when running without a console).

    Returns:
        The stream to use from now on. This is the same object when it was
        reconfigured in place, a new wrapper in the fallback case, or ``None``
        when *stream* was ``None``.
    """
    if stream is None:
        return None
    reconfigure = getattr(stream, 'reconfigure', None)
    if callable(reconfigure):
        reconfigure(encoding='utf-8', errors='replace')
        return stream
    return io.TextIOWrapper(stream.buffer, encoding='utf-8', errors='replace')


# Windows encoding fix: the console/pipe encoding may not be UTF-8 there, so
# force UTF-8 to keep the non-ASCII log messages in this script writable.
if sys.platform == 'win32':
    sys.stdout = ensure_utf8_text_stream(sys.stdout)
    sys.stderr = ensure_utf8_text_stream(sys.stderr)
# Locate the comic-text-detector checkout and make its modules importable.
# Two layouts are accepted: a nested "comic-text-detector-master/
# comic-text-detector-master" directory, or a single directory. A candidate
# only qualifies if it contains inference.py; if neither does, the nested
# layout is used as the default.
project_root = Path(__file__).parent.parent
comic_detector_path1 = project_root / 'comic-text-detector-master' / 'comic-text-detector-master'
comic_detector_path2 = project_root / 'comic-text-detector-master'
comic_detector_path = next(
    (
        candidate
        for candidate in (comic_detector_path1, comic_detector_path2)
        if candidate.exists() and (candidate / 'inference.py').exists()
    ),
    comic_detector_path1,  # default when no candidate has inference.py
)
sys.path.insert(0, str(comic_detector_path))
# Optional wandb dependency: comic-text-detector imports it, but inference does
# not need it (per the original note). When wandb is not installed, register a
# no-op stand-in under sys.modules['wandb'] so that `import wandb` succeeds.
try:
    import wandb
except ImportError:
    import types

    class FakeWandb(types.ModuleType):
        """No-op stand-in for the ``wandb`` module.

        Subclasses ``types.ModuleType`` so the object stored in ``sys.modules``
        is a genuine module (it has ``__name__``, ``__doc__``, ``__spec__``
        and module attribute semantics) rather than a bare class instance.
        Only ``init``, ``log`` and ``log_model`` are provided. Each accepts
        any arguments and returns ``None``.
        """

        @staticmethod
        def init(*args, **kwargs):
            pass

        @staticmethod
        def log(*args, **kwargs):
            pass

        @staticmethod
        def log_model(*args, **kwargs):
            pass

    sys.modules['wandb'] = FakeWandb('wandb')
- try:
- from inference import TextDetector, REFINEMASK_INPAINT, REFINEMASK_ANNOTATION
- from utils.textblock import TextBlock
- from utils.io_utils import imread, imwrite
- except ImportError as e:
- print(f"[ERROR] 无法导入comic-text-detector模块: {e}")
- print(f"[INFO] 请确保已安装依赖: pip install torch torchvision opencv-python numpy tqdm")
- sys.exit(1)
def _find_model_path(project_root):
    """Return the first existing comic-text-detector weight file.

    Searches the detector checkout's ``data`` directory, then
    ``<project_root>/models``. In each location the ``.pt`` file is tried
    before the ``.pt.onnx`` export.

    Raises:
        FileNotFoundError: If no candidate exists. The message lists every
            location searched.
    """
    possible_paths = [
        comic_detector_path / 'data' / 'comictextdetector.pt',
        comic_detector_path / 'data' / 'comictextdetector.pt.onnx',
        project_root / 'models' / 'comictextdetector.pt',
        project_root / 'models' / 'comictextdetector.pt.onnx',
    ]
    for path in possible_paths:
        if path.exists():
            return path
    raise FileNotFoundError(
        f"未找到模型文件。请下载模型并放到以下位置之一:\n" +
        "\n".join([f" - {p}" for p in possible_paths])
    )


def _read_image(image_path):
    """Decode an image file into a colour ndarray.

    The bytes are read with NumPy and decoded in memory, rather than passing
    the path to OpenCV, so that non-ASCII (e.g. Chinese) paths work.

    Raises:
        ValueError: If the file is empty or cannot be decoded.
    """
    img_array = np.fromfile(str(image_path), dtype=np.uint8)
    # cv2.imdecode raises cv2.error (not ValueError) on an empty buffer, so
    # empty files are reported the same way as undecodable ones.
    img = cv2.imdecode(img_array, cv2.IMREAD_COLOR) if img_array.size else None
    if img is None:
        raise ValueError(f"无法读取图片文件: {image_path}")
    return img


def _serialize_lines(lines):
    """Convert text-line polygons into JSON-safe ``[[x, y], ...]`` int lists.

    *lines* may be ``None``, or a list/tuple/ndarray of polygons. Each polygon
    may be a list, tuple or ndarray of points. NumPy arrays are converted with
    ``tolist()`` so they are not silently dropped. Coordinates are truncated
    via ``int(float(v))``. Points with fewer than two values are skipped, and
    so are polygons that have no valid point.
    """
    serialized = []
    if lines is None:
        return serialized
    for line in lines:
        if isinstance(line, np.ndarray):
            line = line.tolist()
        if not isinstance(line, (list, tuple)):
            continue
        points = []
        for p in line:
            if isinstance(p, np.ndarray):
                p = p.tolist()
            if isinstance(p, (list, tuple)) and len(p) >= 2:
                points.append([int(float(p[0])), int(float(p[1]))])
        if points:
            serialized.append(points)
    return serialized


def _text_block_to_dict(index, blk):
    """Build the JSON-serializable record for one detected text block.

    Args:
        index: 1-based position of the block in the detector output.
        blk: Detector block exposing ``xyxy`` (x1, y1, x2, y2). ``lines``,
            ``language`` and ``vertical`` are optional attributes.
    """
    x1, y1, x2, y2 = (int(v) for v in blk.xyxy)
    return {
        'index': index,
        'bbox': {
            'x1': x1,
            'y1': y1,
            'x2': x2,
            'y2': y2,
            'width': x2 - x1,
            'height': y2 - y1,
            # Midpoint truncated toward zero by int(), as before.
            'center_x': int((x1 + x2) / 2),
            'center_y': int((y1 + y2) / 2),
        },
        'lines': _serialize_lines(getattr(blk, 'lines', None)),
        'language': str(getattr(blk, 'language', 'unknown')),
        'vertical': bool(getattr(blk, 'vertical', False)),
    }


def _build_ocr_compatible(text_blocks):
    """Convert text-block records into the OCR-style ``{'dialogues': [...]}`` format.

    Each bbox becomes four corner points (top-left, top-right, bottom-right,
    bottom-left). The text is a ``[文字区域N]`` placeholder, and the confidence
    is a fixed constant rather than a detector score.
    """
    dialogues = []
    for block in text_blocks:
        bbox = block['bbox']
        dialogues.append({
            'bbox': [
                [bbox['x1'], bbox['y1']],  # top-left
                [bbox['x2'], bbox['y1']],  # top-right
                [bbox['x2'], bbox['y2']],  # bottom-right
                [bbox['x1'], bbox['y2']],  # bottom-left
            ],
            'text': f'[文字区域{block["index"]}]',  # placeholder, no OCR is run here
            'confidence': 0.95,  # constant (original rationale: dedicated text detector)
            'source': 'comic-text-detector',
            'region_info': {
                'width': bbox['width'],
                'height': bbox['height'],
                'center_x': bbox['center_x'],
                'center_y': bbox['center_y'],
                'vertical': block['vertical'],
                'language': block['language'],
            },
        })
    return {'dialogues': dialogues}


def detect_comic_text_with_boxes(image_path, output_dir, project_root=None,
                                 input_size=1536, conf_thresh=0.4, nms_thresh=0.35,
                                 mask_thresh=0.3, act='leaky', refine_mode=0,
                                 keep_undetected_mask=False, erode_iterations=0,
                                 invert_mask=False):
    """Detect text regions in a comic image and save their coordinates as JSON.

    Only a coordinates JSON file is written. No mask images are produced, and
    the detector's mask outputs are discarded.

    Args:
        image_path: Path of the image to analyse.
        output_dir: Directory for ``<image stem>_text_regions.json``. It is
            created if missing (after the image is confirmed to exist).
        project_root: Project root. Defaults to the parent of this script's
            directory, and is used to look for ``models/comictextdetector.pt``
            / ``.pt.onnx``.
        input_size: Detector input size (default 1536). Larger is more
            accurate but slower; 1024/1536/2048 are the suggested values.
        conf_thresh: Confidence threshold in 0-1. Higher values are stricter.
        nms_thresh: NMS threshold in 0-1. Higher values keep more overlapping
            boxes.
        mask_thresh: Segmentation mask threshold in 0-1.
        act: Activation function, ``'leaky'`` or ``'relu'``.
        refine_mode: 0 selects ``REFINEMASK_INPAINT``; any other value selects
            ``REFINEMASK_ANNOTATION``.
        keep_undetected_mask: Forwarded to the detector. True keeps more
            regions but may add noise.
        erode_iterations: Deprecated and ignored; kept for backward
            compatibility.
        invert_mask: Deprecated and ignored; kept for backward compatibility.

    Returns:
        A dict with ``image_file``, ``image_size``, ``text_blocks``,
        ``total_count`` (exactly what is saved to JSON), plus
        ``output_files`` and ``ocr_compatible_data``.

    Raises:
        FileNotFoundError: If the image or the model weights are missing.
        ValueError: If the image is empty or cannot be decoded.
        Exception: Errors from detector construction or inference are printed
            and re-raised unchanged.
    """
    project_root = Path(__file__).parent.parent if project_root is None else Path(project_root)
    image_path = Path(image_path)
    output_dir = Path(output_dir)

    if not image_path.exists():
        raise FileNotFoundError(f"图片文件不存在: {image_path}")

    output_dir.mkdir(parents=True, exist_ok=True)
    print(f"📖 正在检测图片中的文字区域: {image_path.name}")

    model_path = _find_model_path(project_root)

    import torch  # deferred to call time, as in the original
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    print(f"[INFO] 使用设备: {device}")

    try:
        detector = TextDetector(
            model_path=str(model_path),
            input_size=input_size,
            device=device,
            conf_thresh=conf_thresh,
            nms_thresh=nms_thresh,
            mask_thresh=mask_thresh,
            act=act
        )
    except Exception as e:
        print(f"[ERROR] 初始化检测器失败: {e}")
        raise

    img = _read_image(image_path)
    im_h, im_w = img.shape[:2]
    print(f"[INFO] 图片尺寸: {im_w}x{im_h}")

    print("[INFO] 正在检测文字区域...")
    print(f"[INFO] 检测参数: input_size={input_size}, conf_thresh={conf_thresh}, nms_thresh={nms_thresh}, refine_mode={refine_mode}")
    refine_mode_enum = REFINEMASK_INPAINT if refine_mode == 0 else REFINEMASK_ANNOTATION
    try:
        # The two mask outputs are not used: this script only emits coordinates.
        _mask, _mask_refined, blk_list = detector(
            img,
            refine_mode=refine_mode_enum,
            keep_undetected_mask=keep_undetected_mask
        )
    except Exception as e:
        print(f"[ERROR] 检测失败: {e}")
        raise

    text_blocks = [_text_block_to_dict(i, blk) for i, blk in enumerate(blk_list, start=1)]
    print(f"[OK] 检测到 {len(text_blocks)} 个文字区域")

    result = {
        'image_file': image_path.name,
        'image_size': {
            'width': im_w,
            'height': im_h
        },
        'text_blocks': text_blocks,
        'total_count': len(text_blocks)
    }

    # Written before output_files / ocr_compatible_data are attached, so the
    # JSON file contains only the detection data.
    json_path = output_dir / f"{image_path.stem}_text_regions.json"
    with open(json_path, 'w', encoding='utf-8') as f:
        json.dump(result, f, ensure_ascii=False, indent=2)
    print(f"[OK] 已保存文字区域坐标: {json_path}")

    ocr_compatible_result = _build_ocr_compatible(text_blocks)
    print(f"[INFO] OCR兼容数据将由Node.js处理,不生成中间文件")

    result['output_files'] = {
        'text_regions_json': str(json_path)
    }
    result['ocr_compatible_data'] = ocr_compatible_result
    return result
def main(argv=None):
    """Command-line entry point; returns the process exit status.

    Positional arguments after the script name: image path, output dir, then
    optionally project root, input_size, conf_thresh, nms_thresh, mask_thresh,
    act, refine_mode, keep_undetected_mask (0/1), erode_iterations and
    invert_mask (0/1). The last two are accepted but ignored by the detector.

    Returns:
        1 after printing usage when fewer than two positional arguments are
        given. 1 after printing the error and traceback when detection fails.
        0 on success. Malformed numeric arguments raise ValueError before
        detection starts.
    """
    if argv is None:
        argv = sys.argv
    if len(argv) < 3:
        print("用法: python detect_comic_text_with_boxes.py <图片路径> <输出目录> [项目根目录] [input_size] [conf_thresh] [nms_thresh] [mask_thresh] [act] [refine_mode] [keep_undetected_mask] [erode_iterations] [invert_mask]")
        print("可选参数:")
        print(" input_size: 输入尺寸(默认1536)")
        print(" conf_thresh: 置信度阈值(默认0.4)")
        print(" nms_thresh: NMS阈值(默认0.35)")
        print(" mask_thresh: Mask阈值(默认0.3)")
        print(" act: 激活函数(默认'leaky')")
        print(" refine_mode: 精炼模式(默认0=INPAINT,1=ANNOTATION)")
        print(" keep_undetected_mask: 是否保留未检测区域(默认False,传入1为True)")
        print(" erode_iterations: 腐蚀迭代次数(已废弃,不再使用)")
        print(" invert_mask: 是否反转mask(已废弃,不再使用)")
        return 1

    def optional(position, convert, default):
        """Return convert(argv[position]) if that argument was given, else default."""
        return convert(argv[position]) if len(argv) > position else default

    def flag(text):
        """Parse a 0/1 style command-line flag."""
        return bool(int(text))

    image_path = argv[1]
    output_dir = argv[2]
    project_root = optional(3, str, None)
    input_size = optional(4, int, 1536)
    conf_thresh = optional(5, float, 0.4)
    nms_thresh = optional(6, float, 0.35)
    mask_thresh = optional(7, float, 0.3)
    act = optional(8, str, 'leaky')
    refine_mode = optional(9, int, 0)
    keep_undetected_mask = optional(10, flag, False)
    erode_iterations = optional(11, int, 0)
    invert_mask = optional(12, flag, False)

    try:
        detect_comic_text_with_boxes(
            image_path, output_dir, project_root,
            input_size, conf_thresh, nms_thresh, mask_thresh,
            act, refine_mode, keep_undetected_mask, erode_iterations, invert_mask
        )
    except Exception as e:
        print(f"[ERROR] 处理失败: {e}")
        import traceback
        traceback.print_exc()
        return 1
    return 0


if __name__ == '__main__':
    sys.exit(main())
|