| 12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697 |
- import os
- import cv2
- import copy
- from . import predict_det
- from . import predict_cls
- from . import predict_rec
- from .utils import get_rotate_crop_image, get_minarea_rect_crop
class TextSystem:
    """Detect -> crop -> (optionally classify orientation) -> recognize pipeline.

    The three stages are delegated to ``predict_det.TextDetector``,
    ``predict_cls.TextClassifier`` and ``predict_rec.TextRecognizer``; this
    class only wires them together and filters the recognition results by
    score.
    """

    def __init__(self, args):
        """Build the stage objects from ``args``.

        Attributes read from ``args`` by this class: ``use_angle_cls``,
        ``drop_score``, ``det_box_type``, ``save_crop_res`` and
        ``crop_res_save_dir``. The same ``args`` object is also handed to
        every stage constructor.
        """
        self.text_detector = predict_det.TextDetector(args)
        self.text_recognizer = predict_rec.TextRecognizer(args)
        self.use_angle_cls = args.use_angle_cls
        self.drop_score = args.drop_score
        # The classifier is only constructed (and the attribute only exists)
        # when orientation classification is enabled.
        if self.use_angle_cls:
            self.text_classifier = predict_cls.TextClassifier(args)
        self.args = args
        # Running counter so crops saved by successive calls get unique names.
        self.crop_image_res_index = 0

    def draw_crop_rec_res(self, output_dir, img_crop_list, rec_res):
        """Write every crop in ``img_crop_list`` to ``output_dir`` as a JPEG.

        Files are numbered continuously across calls via
        ``self.crop_image_res_index``. ``rec_res`` is accepted but not used.
        """
        os.makedirs(output_dir, exist_ok=True)
        first_index = self.crop_image_res_index
        for file_no, crop in enumerate(img_crop_list, start=first_index):
            # NOTE(review): the "mg_crop_" prefix looks like a typo for
            # "img_crop_", but it is kept because it determines file names.
            target = os.path.join(output_dir, f"mg_crop_{file_no}.jpg")
            cv2.imwrite(target, crop)
        self.crop_image_res_index += len(img_crop_list)

    def __call__(self, img, cls=True):
        """Run the pipeline on one image.

        Args:
            img: input image; must provide ``.copy()`` (presumably a NumPy
                image array -- confirm against callers).
            cls: when true and ``use_angle_cls`` is enabled, crops are passed
                through the orientation classifier before recognition.

        Returns:
            ``(None, None)`` when the detector returns ``None``; otherwise
            ``(boxes, rec_results)`` holding only the entries whose
            recognition score is at least ``self.drop_score``.
        """
        # Crops are taken from a private copy made before detection runs.
        ori_im = img.copy()

        # Text detection
        dt_boxes = self.text_detector(img)
        if dt_boxes is None:
            return None, None

        img_crop_list = []
        dt_boxes = sorted_boxes(dt_boxes)

        # Image cropping
        for box in dt_boxes:
            # The crop helper receives a copy of the box, not the original.
            box_copy = copy.deepcopy(box)
            crop_fn = (
                get_rotate_crop_image
                if self.args.det_box_type == "quad"
                else get_minarea_rect_crop
            )
            img_crop_list.append(crop_fn(ori_im, box_copy))

        # Orientation classification (second return value is not needed here)
        if self.use_angle_cls and cls:
            img_crop_list, _ = self.text_classifier(img_crop_list)

        # Text recognition
        rec_res = self.text_recognizer(img_crop_list)

        if self.args.save_crop_res:
            self.draw_crop_rec_res(
                self.args.crop_res_save_dir, img_crop_list, rec_res
            )

        # Keep only results whose score reaches the configured threshold.
        filter_boxes = []
        filter_rec_res = []
        for box, rec_result in zip(dt_boxes, rec_res):
            _text, score = rec_result
            if score >= self.drop_score:
                filter_boxes.append(box)
                filter_rec_res.append(rec_result)
        return filter_boxes, filter_rec_res
def sorted_boxes(dt_boxes, y_tolerance=10):
    """
    Sort text boxes in reading order: top to bottom, then left to right.

    Boxes are first sorted by the (y, x) of their first vertex. A second
    pass then reorders neighbours whose first-vertex y values differ by less
    than ``y_tolerance`` so that, within such a "line", the box with the
    smaller x comes first.

    args:
        dt_boxes: sequence of N boxes (NumPy array of shape [N, 4, 2], or a
            plain list/tuple of boxes); only ``box[0][0]`` (x) and
            ``box[0][1]`` (y) of each box are inspected.
        y_tolerance: maximum difference in first-vertex y for two boxes to
            be treated as being on the same line (default 10).
    return:
        list of the N input boxes in sorted order
    """
    # Coarse ordering by the first vertex: y first, x as tie-breaker.
    ordered = sorted(dt_boxes, key=lambda box: (box[0][1], box[0][0]))

    # Insertion-style fix-up: move each box left past neighbours that are on
    # the same line but start further to the right.
    for i in range(len(ordered) - 1):
        for j in range(i, -1, -1):
            upper, lower = ordered[j], ordered[j + 1]
            same_line = abs(lower[0][1] - upper[0][1]) < y_tolerance
            if same_line and lower[0][0] < upper[0][0]:
                ordered[j], ordered[j + 1] = lower, upper
            else:
                # Everything further left is already in order.
                break
    return ordered
|