predict_system.py 3.0 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697
  1. import os
  2. import cv2
  3. import copy
  4. from . import predict_det
  5. from . import predict_cls
  6. from . import predict_rec
  7. from .utils import get_rotate_crop_image, get_minarea_rect_crop
  8. class TextSystem(object):
  9. def __init__(self, args):
  10. self.text_detector = predict_det.TextDetector(args)
  11. self.text_recognizer = predict_rec.TextRecognizer(args)
  12. self.use_angle_cls = args.use_angle_cls
  13. self.drop_score = args.drop_score
  14. if self.use_angle_cls:
  15. self.text_classifier = predict_cls.TextClassifier(args)
  16. self.args = args
  17. self.crop_image_res_index = 0
  18. def draw_crop_rec_res(self, output_dir, img_crop_list, rec_res):
  19. os.makedirs(output_dir, exist_ok=True)
  20. bbox_num = len(img_crop_list)
  21. for bno in range(bbox_num):
  22. cv2.imwrite(
  23. os.path.join(
  24. output_dir, f"mg_crop_{bno+self.crop_image_res_index}.jpg"
  25. ),
  26. img_crop_list[bno],
  27. )
  28. self.crop_image_res_index += bbox_num
  29. def __call__(self, img, cls=True):
  30. ori_im = img.copy()
  31. # 文字检测
  32. dt_boxes = self.text_detector(img)
  33. if dt_boxes is None:
  34. return None, None
  35. img_crop_list = []
  36. dt_boxes = sorted_boxes(dt_boxes)
  37. # 图片裁剪
  38. for bno in range(len(dt_boxes)):
  39. tmp_box = copy.deepcopy(dt_boxes[bno])
  40. if self.args.det_box_type == "quad":
  41. img_crop = get_rotate_crop_image(ori_im, tmp_box)
  42. else:
  43. img_crop = get_minarea_rect_crop(ori_im, tmp_box)
  44. img_crop_list.append(img_crop)
  45. # 方向分类
  46. if self.use_angle_cls and cls:
  47. img_crop_list, angle_list = self.text_classifier(img_crop_list)
  48. # 图像识别
  49. rec_res = self.text_recognizer(img_crop_list)
  50. if self.args.save_crop_res:
  51. self.draw_crop_rec_res(self.args.crop_res_save_dir, img_crop_list, rec_res)
  52. filter_boxes, filter_rec_res = [], []
  53. for box, rec_result in zip(dt_boxes, rec_res):
  54. text, score = rec_result
  55. if score >= self.drop_score:
  56. filter_boxes.append(box)
  57. filter_rec_res.append(rec_result)
  58. return filter_boxes, filter_rec_res
  59. def sorted_boxes(dt_boxes):
  60. """
  61. Sort text boxes in order from top to bottom, left to right
  62. args:
  63. dt_boxes(array):detected text boxes with shape [4, 2]
  64. return:
  65. sorted boxes(array) with shape [4, 2]
  66. """
  67. num_boxes = dt_boxes.shape[0]
  68. sorted_boxes = sorted(dt_boxes, key=lambda x: (x[0][1], x[0][0]))
  69. _boxes = list(sorted_boxes)
  70. for i in range(num_boxes - 1):
  71. for j in range(i, -1, -1):
  72. if abs(_boxes[j + 1][0][1] - _boxes[j][0][1]) < 10 and (
  73. _boxes[j + 1][0][0] < _boxes[j][0][0]
  74. ):
  75. tmp = _boxes[j]
  76. _boxes[j] = _boxes[j + 1]
  77. _boxes[j + 1] = tmp
  78. else:
  79. break
  80. return _boxes