predict_system.py 11 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326
  1. # Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. import os
  15. import sys
  16. import subprocess
  17. __dir__ = os.path.dirname(os.path.abspath(__file__))
  18. sys.path.append(__dir__)
  19. sys.path.insert(0, os.path.abspath(os.path.join(__dir__, "../..")))
  20. os.environ["FLAGS_allocator_strategy"] = "auto_growth"
  21. import cv2
  22. import copy
  23. import numpy as np
  24. import json
  25. import time
  26. import logging
  27. from PIL import Image
  28. import tools.infer.utility as utility
  29. import tools.infer.predict_rec as predict_rec
  30. import tools.infer.predict_det as predict_det
  31. import tools.infer.predict_cls as predict_cls
  32. from ppocr.utils.utility import get_image_file_list, check_and_read
  33. from ppocr.utils.logging import get_logger
  34. from tools.infer.utility import (
  35. draw_ocr_box_txt,
  36. get_rotate_crop_image,
  37. get_minarea_rect_crop,
  38. slice_generator,
  39. merge_fragmented,
  40. )
  41. logger = get_logger()
  42. class TextSystem(object):
  43. def __init__(self, args):
  44. if not args.show_log:
  45. logger.setLevel(logging.INFO)
  46. self.text_detector = predict_det.TextDetector(args)
  47. self.text_recognizer = predict_rec.TextRecognizer(args)
  48. self.use_angle_cls = args.use_angle_cls
  49. self.drop_score = args.drop_score
  50. if self.use_angle_cls:
  51. self.text_classifier = predict_cls.TextClassifier(args)
  52. self.args = args
  53. self.crop_image_res_index = 0
  54. def draw_crop_rec_res(self, output_dir, img_crop_list, rec_res):
  55. os.makedirs(output_dir, exist_ok=True)
  56. bbox_num = len(img_crop_list)
  57. for bno in range(bbox_num):
  58. cv2.imwrite(
  59. os.path.join(
  60. output_dir, f"mg_crop_{bno+self.crop_image_res_index}.jpg"
  61. ),
  62. img_crop_list[bno],
  63. )
  64. logger.debug(f"{bno}, {rec_res[bno]}")
  65. self.crop_image_res_index += bbox_num
  66. def __call__(self, img, cls=True, slice={}):
  67. time_dict = {"det": 0, "rec": 0, "cls": 0, "all": 0}
  68. if img is None:
  69. logger.debug("no valid image provided")
  70. return None, None, time_dict
  71. start = time.time()
  72. ori_im = img.copy()
  73. if slice:
  74. slice_gen = slice_generator(
  75. img,
  76. horizontal_stride=slice["horizontal_stride"],
  77. vertical_stride=slice["vertical_stride"],
  78. )
  79. elapsed = []
  80. dt_slice_boxes = []
  81. for slice_crop, v_start, h_start in slice_gen:
  82. dt_boxes, elapse = self.text_detector(slice_crop, use_slice=True)
  83. if dt_boxes.size:
  84. dt_boxes[:, :, 0] += h_start
  85. dt_boxes[:, :, 1] += v_start
  86. dt_slice_boxes.append(dt_boxes)
  87. elapsed.append(elapse)
  88. dt_boxes = np.concatenate(dt_slice_boxes)
  89. dt_boxes = merge_fragmented(
  90. boxes=dt_boxes,
  91. x_threshold=slice["merge_x_thres"],
  92. y_threshold=slice["merge_y_thres"],
  93. )
  94. elapse = sum(elapsed)
  95. else:
  96. dt_boxes, elapse = self.text_detector(img)
  97. time_dict["det"] = elapse
  98. if dt_boxes is None:
  99. logger.debug("no dt_boxes found, elapsed : {}".format(elapse))
  100. end = time.time()
  101. time_dict["all"] = end - start
  102. return None, None, time_dict
  103. else:
  104. logger.debug(
  105. "dt_boxes num : {}, elapsed : {}".format(len(dt_boxes), elapse)
  106. )
  107. img_crop_list = []
  108. dt_boxes = sorted_boxes(dt_boxes)
  109. for bno in range(len(dt_boxes)):
  110. tmp_box = copy.deepcopy(dt_boxes[bno])
  111. if self.args.det_box_type == "quad":
  112. img_crop = get_rotate_crop_image(ori_im, tmp_box)
  113. else:
  114. img_crop = get_minarea_rect_crop(ori_im, tmp_box)
  115. img_crop_list.append(img_crop)
  116. if self.use_angle_cls and cls:
  117. img_crop_list, angle_list, elapse = self.text_classifier(img_crop_list)
  118. time_dict["cls"] = elapse
  119. logger.debug(
  120. "cls num : {}, elapsed : {}".format(len(img_crop_list), elapse)
  121. )
  122. if len(img_crop_list) > 1000:
  123. logger.debug(
  124. f"rec crops num: {len(img_crop_list)}, time and memory cost may be large."
  125. )
  126. rec_res, elapse = self.text_recognizer(img_crop_list)
  127. time_dict["rec"] = elapse
  128. logger.debug("rec_res num : {}, elapsed : {}".format(len(rec_res), elapse))
  129. if self.args.save_crop_res:
  130. self.draw_crop_rec_res(self.args.crop_res_save_dir, img_crop_list, rec_res)
  131. filter_boxes, filter_rec_res = [], []
  132. for box, rec_result in zip(dt_boxes, rec_res):
  133. text, score = rec_result[0], rec_result[1]
  134. if score >= self.drop_score:
  135. filter_boxes.append(box)
  136. filter_rec_res.append(rec_result)
  137. end = time.time()
  138. time_dict["all"] = end - start
  139. return filter_boxes, filter_rec_res, time_dict
  140. def sorted_boxes(dt_boxes):
  141. """
  142. Sort text boxes in order from top to bottom, left to right
  143. args:
  144. dt_boxes(array):detected text boxes with shape [4, 2]
  145. return:
  146. sorted boxes(array) with shape [4, 2]
  147. """
  148. num_boxes = dt_boxes.shape[0]
  149. sorted_boxes = sorted(dt_boxes, key=lambda x: (x[0][1], x[0][0]))
  150. _boxes = list(sorted_boxes)
  151. for i in range(num_boxes - 1):
  152. for j in range(i, -1, -1):
  153. if abs(_boxes[j + 1][0][1] - _boxes[j][0][1]) < 10 and (
  154. _boxes[j + 1][0][0] < _boxes[j][0][0]
  155. ):
  156. tmp = _boxes[j]
  157. _boxes[j] = _boxes[j + 1]
  158. _boxes[j + 1] = tmp
  159. else:
  160. break
  161. return _boxes
  162. def main(args):
  163. image_file_list = get_image_file_list(args.image_dir)
  164. image_file_list = image_file_list[args.process_id :: args.total_process_num]
  165. text_sys = TextSystem(args)
  166. is_visualize = True
  167. font_path = args.vis_font_path
  168. drop_score = args.drop_score
  169. draw_img_save_dir = args.draw_img_save_dir
  170. os.makedirs(draw_img_save_dir, exist_ok=True)
  171. save_results = []
  172. logger.info(
  173. "In PP-OCRv3, rec_image_shape parameter defaults to '3, 48, 320', "
  174. "if you are using recognition model with PP-OCRv2 or an older version, please set --rec_image_shape='3,32,320"
  175. )
  176. # warm up 10 times
  177. if args.warmup:
  178. img = np.random.uniform(0, 255, [640, 640, 3]).astype(np.uint8)
  179. for i in range(10):
  180. res = text_sys(img)
  181. total_time = 0
  182. cpu_mem, gpu_mem, gpu_util = 0, 0, 0
  183. _st = time.time()
  184. count = 0
  185. for idx, image_file in enumerate(image_file_list):
  186. img, flag_gif, flag_pdf = check_and_read(image_file)
  187. if not flag_gif and not flag_pdf:
  188. img = cv2.imread(image_file)
  189. if not flag_pdf:
  190. if img is None:
  191. logger.debug("error in loading image:{}".format(image_file))
  192. continue
  193. imgs = [img]
  194. else:
  195. page_num = args.page_num
  196. if page_num > len(img) or page_num == 0:
  197. page_num = len(img)
  198. imgs = img[:page_num]
  199. for index, img in enumerate(imgs):
  200. starttime = time.time()
  201. dt_boxes, rec_res, time_dict = text_sys(img)
  202. elapse = time.time() - starttime
  203. total_time += elapse
  204. if len(imgs) > 1:
  205. logger.debug(
  206. str(idx)
  207. + "_"
  208. + str(index)
  209. + " Predict time of %s: %.3fs" % (image_file, elapse)
  210. )
  211. else:
  212. logger.debug(
  213. str(idx) + " Predict time of %s: %.3fs" % (image_file, elapse)
  214. )
  215. for text, score in rec_res:
  216. logger.debug("{}, {:.3f}".format(text, score))
  217. res = [
  218. {
  219. "transcription": rec_res[i][0],
  220. "points": np.array(dt_boxes[i]).astype(np.int32).tolist(),
  221. }
  222. for i in range(len(dt_boxes))
  223. ]
  224. if len(imgs) > 1:
  225. save_pred = (
  226. os.path.basename(image_file)
  227. + "_"
  228. + str(index)
  229. + "\t"
  230. + json.dumps(res, ensure_ascii=False)
  231. + "\n"
  232. )
  233. else:
  234. save_pred = (
  235. os.path.basename(image_file)
  236. + "\t"
  237. + json.dumps(res, ensure_ascii=False)
  238. + "\n"
  239. )
  240. save_results.append(save_pred)
  241. if is_visualize:
  242. image = Image.fromarray(cv2.cvtColor(img, cv2.COLOR_BGR2RGB))
  243. boxes = dt_boxes
  244. txts = [rec_res[i][0] for i in range(len(rec_res))]
  245. scores = [rec_res[i][1] for i in range(len(rec_res))]
  246. draw_img = draw_ocr_box_txt(
  247. image,
  248. boxes,
  249. txts,
  250. scores,
  251. drop_score=drop_score,
  252. font_path=font_path,
  253. )
  254. if flag_gif:
  255. save_file = image_file[:-3] + "png"
  256. elif flag_pdf:
  257. save_file = image_file.replace(".pdf", "_" + str(index) + ".png")
  258. else:
  259. save_file = image_file
  260. cv2.imwrite(
  261. os.path.join(draw_img_save_dir, os.path.basename(save_file)),
  262. draw_img[:, :, ::-1],
  263. )
  264. logger.debug(
  265. "The visualized image saved in {}".format(
  266. os.path.join(draw_img_save_dir, os.path.basename(save_file))
  267. )
  268. )
  269. logger.info("The predict total time is {}".format(time.time() - _st))
  270. if args.benchmark:
  271. text_sys.text_detector.autolog.report()
  272. text_sys.text_recognizer.autolog.report()
  273. with open(
  274. os.path.join(draw_img_save_dir, "system_results.txt"), "w", encoding="utf-8"
  275. ) as f:
  276. f.writelines(save_results)
  277. if __name__ == "__main__":
  278. args = utility.parse_args()
  279. if args.use_mp:
  280. p_list = []
  281. total_process_num = args.total_process_num
  282. for process_id in range(total_process_num):
  283. cmd = (
  284. [sys.executable, "-u"]
  285. + sys.argv
  286. + ["--process_id={}".format(process_id), "--use_mp={}".format(False)]
  287. )
  288. p = subprocess.Popen(cmd, stdout=sys.stdout, stderr=sys.stdout)
  289. p_list.append(p)
  290. for p in p_list:
  291. p.wait()
  292. else:
  293. main(args)