predict_det.py 19 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501
  1. # Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. import os
  15. import sys
  16. __dir__ = os.path.dirname(os.path.abspath(__file__))
  17. sys.path.append(__dir__)
  18. sys.path.insert(0, os.path.abspath(os.path.join(__dir__, "../..")))
  19. os.environ["FLAGS_allocator_strategy"] = "auto_growth"
  20. import cv2
  21. import numpy as np
  22. import time
  23. import sys
  24. import tools.infer.utility as utility
  25. from ppocr.utils.logging import get_logger
  26. from ppocr.utils.utility import get_image_file_list, check_and_read
  27. from ppocr.data import create_operators, transform
  28. from ppocr.postprocess import build_post_process
  29. import json
  30. class TextDetector(object):
  31. def __init__(self, args, logger=None):
  32. if os.path.exists(f"{args.det_model_dir}/inference.yml"):
  33. model_config = utility.load_config(f"{args.det_model_dir}/inference.yml")
  34. model_name = model_config.get("Global", {}).get("model_name", "")
  35. if model_name and model_name not in [
  36. "PP-OCRv5_mobile_det",
  37. "PP-OCRv5_server_det",
  38. ]:
  39. raise ValueError(
  40. f"{model_name} is not supported. Please check if the model is supported by the PaddleOCR wheel."
  41. )
  42. if logger is None:
  43. logger = get_logger()
  44. self.args = args
  45. self.det_algorithm = args.det_algorithm
  46. self.use_onnx = args.use_onnx
  47. pre_process_list = [
  48. {
  49. "DetResizeForTest": {
  50. "limit_side_len": args.det_limit_side_len,
  51. "limit_type": args.det_limit_type,
  52. }
  53. },
  54. {
  55. "NormalizeImage": {
  56. "std": [0.229, 0.224, 0.225],
  57. "mean": [0.485, 0.456, 0.406],
  58. "scale": "1./255.",
  59. "order": "hwc",
  60. }
  61. },
  62. {"ToCHWImage": None},
  63. {"KeepKeys": {"keep_keys": ["image", "shape"]}},
  64. ]
  65. postprocess_params = {}
  66. if self.det_algorithm == "DB":
  67. postprocess_params["name"] = "DBPostProcess"
  68. postprocess_params["thresh"] = args.det_db_thresh
  69. postprocess_params["box_thresh"] = args.det_db_box_thresh
  70. postprocess_params["max_candidates"] = 1000
  71. postprocess_params["unclip_ratio"] = args.det_db_unclip_ratio
  72. postprocess_params["use_dilation"] = args.use_dilation
  73. postprocess_params["score_mode"] = args.det_db_score_mode
  74. postprocess_params["box_type"] = args.det_box_type
  75. elif self.det_algorithm == "DB++":
  76. postprocess_params["name"] = "DBPostProcess"
  77. postprocess_params["thresh"] = args.det_db_thresh
  78. postprocess_params["box_thresh"] = args.det_db_box_thresh
  79. postprocess_params["max_candidates"] = 1000
  80. postprocess_params["unclip_ratio"] = args.det_db_unclip_ratio
  81. postprocess_params["use_dilation"] = args.use_dilation
  82. postprocess_params["score_mode"] = args.det_db_score_mode
  83. postprocess_params["box_type"] = args.det_box_type
  84. pre_process_list[1] = {
  85. "NormalizeImage": {
  86. "std": [1.0, 1.0, 1.0],
  87. "mean": [0.48109378172549, 0.45752457890196, 0.40787054090196],
  88. "scale": "1./255.",
  89. "order": "hwc",
  90. }
  91. }
  92. elif self.det_algorithm == "EAST":
  93. postprocess_params["name"] = "EASTPostProcess"
  94. postprocess_params["score_thresh"] = args.det_east_score_thresh
  95. postprocess_params["cover_thresh"] = args.det_east_cover_thresh
  96. postprocess_params["nms_thresh"] = args.det_east_nms_thresh
  97. elif self.det_algorithm == "SAST":
  98. pre_process_list[0] = {
  99. "DetResizeForTest": {"resize_long": args.det_limit_side_len}
  100. }
  101. postprocess_params["name"] = "SASTPostProcess"
  102. postprocess_params["score_thresh"] = args.det_sast_score_thresh
  103. postprocess_params["nms_thresh"] = args.det_sast_nms_thresh
  104. if args.det_box_type == "poly":
  105. postprocess_params["sample_pts_num"] = 6
  106. postprocess_params["expand_scale"] = 1.2
  107. postprocess_params["shrink_ratio_of_width"] = 0.2
  108. else:
  109. postprocess_params["sample_pts_num"] = 2
  110. postprocess_params["expand_scale"] = 1.0
  111. postprocess_params["shrink_ratio_of_width"] = 0.3
  112. elif self.det_algorithm == "PSE":
  113. postprocess_params["name"] = "PSEPostProcess"
  114. postprocess_params["thresh"] = args.det_pse_thresh
  115. postprocess_params["box_thresh"] = args.det_pse_box_thresh
  116. postprocess_params["min_area"] = args.det_pse_min_area
  117. postprocess_params["box_type"] = args.det_box_type
  118. postprocess_params["scale"] = args.det_pse_scale
  119. elif self.det_algorithm == "FCE":
  120. pre_process_list[0] = {"DetResizeForTest": {"rescale_img": [1080, 736]}}
  121. postprocess_params["name"] = "FCEPostProcess"
  122. postprocess_params["scales"] = args.scales
  123. postprocess_params["alpha"] = args.alpha
  124. postprocess_params["beta"] = args.beta
  125. postprocess_params["fourier_degree"] = args.fourier_degree
  126. postprocess_params["box_type"] = args.det_box_type
  127. elif self.det_algorithm == "CT":
  128. pre_process_list[0] = {"ScaleAlignedShort": {"short_size": 640}}
  129. postprocess_params["name"] = "CTPostProcess"
  130. else:
  131. logger.info("unknown det_algorithm:{}".format(self.det_algorithm))
  132. sys.exit(0)
  133. self.preprocess_op = create_operators(pre_process_list)
  134. self.postprocess_op = build_post_process(postprocess_params)
  135. (
  136. self.predictor,
  137. self.input_tensor,
  138. self.output_tensors,
  139. self.config,
  140. ) = utility.create_predictor(args, "det", logger)
  141. if self.use_onnx:
  142. img_h, img_w = self.input_tensor.shape[2:]
  143. if isinstance(img_h, str) or isinstance(img_w, str):
  144. pass
  145. elif img_h is not None and img_w is not None and img_h > 0 and img_w > 0:
  146. pre_process_list[0] = {
  147. "DetResizeForTest": {"image_shape": [img_h, img_w]}
  148. }
  149. self.preprocess_op = create_operators(pre_process_list)
  150. if args.benchmark:
  151. import auto_log
  152. pid = os.getpid()
  153. gpu_id = utility.get_infer_gpuid()
  154. self.autolog = auto_log.AutoLogger(
  155. model_name="det",
  156. model_precision=args.precision,
  157. batch_size=1,
  158. data_shape="dynamic",
  159. save_path=None, # not used if logger is not None
  160. inference_config=self.config,
  161. pids=pid,
  162. process_name=None,
  163. gpu_ids=gpu_id if args.use_gpu else None,
  164. time_keys=["preprocess_time", "inference_time", "postprocess_time"],
  165. warmup=2,
  166. logger=logger,
  167. )
  168. def order_points_clockwise(self, pts):
  169. rect = np.zeros((4, 2), dtype="float32")
  170. s = pts.sum(axis=1)
  171. rect[0] = pts[np.argmin(s)]
  172. rect[2] = pts[np.argmax(s)]
  173. tmp = np.delete(pts, (np.argmin(s), np.argmax(s)), axis=0)
  174. diff = np.diff(np.array(tmp), axis=1)
  175. rect[1] = tmp[np.argmin(diff)]
  176. rect[3] = tmp[np.argmax(diff)]
  177. return rect
  178. def pad_polygons(self, polygon, max_points):
  179. padding_size = max_points - len(polygon)
  180. if padding_size == 0:
  181. return polygon
  182. last_point = polygon[-1]
  183. padding = np.repeat([last_point], padding_size, axis=0)
  184. return np.vstack([polygon, padding])
  185. def clip_det_res(self, points, img_height, img_width):
  186. for pno in range(points.shape[0]):
  187. points[pno, 0] = int(min(max(points[pno, 0], 0), img_width - 1))
  188. points[pno, 1] = int(min(max(points[pno, 1], 0), img_height - 1))
  189. return points
  190. def filter_tag_det_res(self, dt_boxes, image_shape):
  191. img_height, img_width = image_shape[0:2]
  192. dt_boxes_new = []
  193. for box in dt_boxes:
  194. if type(box) is list:
  195. box = np.array(box)
  196. box = self.order_points_clockwise(box)
  197. box = self.clip_det_res(box, img_height, img_width)
  198. rect_width = int(np.linalg.norm(box[0] - box[1]))
  199. rect_height = int(np.linalg.norm(box[0] - box[3]))
  200. if rect_width <= 3 or rect_height <= 3:
  201. continue
  202. dt_boxes_new.append(box)
  203. dt_boxes = np.array(dt_boxes_new)
  204. return dt_boxes
  205. def filter_tag_det_res_only_clip(self, dt_boxes, image_shape):
  206. img_height, img_width = image_shape[0:2]
  207. dt_boxes_new = []
  208. for box in dt_boxes:
  209. if type(box) is list:
  210. box = np.array(box)
  211. box = self.clip_det_res(box, img_height, img_width)
  212. dt_boxes_new.append(box)
  213. if len(dt_boxes_new) > 0:
  214. max_points = max(len(polygon) for polygon in dt_boxes_new)
  215. dt_boxes_new = [
  216. self.pad_polygons(polygon, max_points) for polygon in dt_boxes_new
  217. ]
  218. dt_boxes = np.array(dt_boxes_new)
  219. return dt_boxes
  220. def predict(self, img):
  221. ori_im = img.copy()
  222. data = {"image": img}
  223. st = time.time()
  224. if self.args.benchmark:
  225. self.autolog.times.start()
  226. data = transform(data, self.preprocess_op)
  227. img, shape_list = data
  228. if img is None:
  229. return None, 0
  230. img = np.expand_dims(img, axis=0)
  231. shape_list = np.expand_dims(shape_list, axis=0)
  232. img = img.copy()
  233. if self.args.benchmark:
  234. self.autolog.times.stamp()
  235. if self.use_onnx:
  236. input_dict = {}
  237. input_dict[self.input_tensor.name] = img
  238. outputs = self.predictor.run(self.output_tensors, input_dict)
  239. else:
  240. self.input_tensor.copy_from_cpu(img)
  241. self.predictor.run()
  242. outputs = []
  243. for output_tensor in self.output_tensors:
  244. output = output_tensor.copy_to_cpu()
  245. outputs.append(output)
  246. if self.args.benchmark:
  247. self.autolog.times.stamp()
  248. preds = {}
  249. if self.det_algorithm == "EAST":
  250. preds["f_geo"] = outputs[0]
  251. preds["f_score"] = outputs[1]
  252. elif self.det_algorithm == "SAST":
  253. preds["f_border"] = outputs[0]
  254. preds["f_score"] = outputs[1]
  255. preds["f_tco"] = outputs[2]
  256. preds["f_tvo"] = outputs[3]
  257. elif self.det_algorithm in ["DB", "PSE", "DB++"]:
  258. preds["maps"] = outputs[0]
  259. elif self.det_algorithm == "FCE":
  260. for i, output in enumerate(outputs):
  261. preds["level_{}".format(i)] = output
  262. elif self.det_algorithm == "CT":
  263. preds["maps"] = outputs[0]
  264. preds["score"] = outputs[1]
  265. else:
  266. raise NotImplementedError
  267. post_result = self.postprocess_op(preds, shape_list)
  268. dt_boxes = post_result[0]["points"]
  269. if self.args.det_box_type == "poly":
  270. dt_boxes = self.filter_tag_det_res_only_clip(dt_boxes, ori_im.shape)
  271. else:
  272. dt_boxes = self.filter_tag_det_res(dt_boxes, ori_im.shape)
  273. if self.args.benchmark:
  274. self.autolog.times.end(stamp=True)
  275. et = time.time()
  276. return dt_boxes, et - st
  277. def __call__(self, img, use_slice=False):
  278. # For image like poster with one side much greater than the other side,
  279. # splitting recursively and processing with overlap to enhance performance.
  280. MIN_BOUND_DISTANCE = 50
  281. dt_boxes = np.zeros((0, 4, 2), dtype=np.float32)
  282. elapse = 0
  283. if (
  284. img.shape[0] / img.shape[1] > 2
  285. and img.shape[0] > self.args.det_limit_side_len
  286. and use_slice
  287. ):
  288. start_h = 0
  289. end_h = 0
  290. while end_h <= img.shape[0]:
  291. end_h = start_h + img.shape[1] * 3 // 4
  292. subimg = img[start_h:end_h, :]
  293. if len(subimg) == 0:
  294. break
  295. sub_dt_boxes, sub_elapse = self.predict(subimg)
  296. offset = start_h
  297. # To prevent text blocks from being cut off, roll back a certain buffer area.
  298. if (
  299. len(sub_dt_boxes) == 0
  300. or img.shape[1] - max([x[-1][1] for x in sub_dt_boxes])
  301. > MIN_BOUND_DISTANCE
  302. ):
  303. start_h = end_h
  304. else:
  305. sorted_indices = np.argsort(sub_dt_boxes[:, 2, 1])
  306. sub_dt_boxes = sub_dt_boxes[sorted_indices]
  307. bottom_line = (
  308. 0
  309. if len(sub_dt_boxes) <= 1
  310. else int(np.max(sub_dt_boxes[:-1, 2, 1]))
  311. )
  312. if bottom_line > 0:
  313. start_h += bottom_line
  314. sub_dt_boxes = sub_dt_boxes[
  315. sub_dt_boxes[:, 2, 1] <= bottom_line
  316. ]
  317. else:
  318. start_h = end_h
  319. if len(sub_dt_boxes) > 0:
  320. if dt_boxes.shape[0] == 0:
  321. dt_boxes = sub_dt_boxes + np.array(
  322. [0, offset], dtype=np.float32
  323. )
  324. else:
  325. dt_boxes = np.append(
  326. dt_boxes,
  327. sub_dt_boxes + np.array([0, offset], dtype=np.float32),
  328. axis=0,
  329. )
  330. elapse += sub_elapse
  331. elif (
  332. img.shape[1] / img.shape[0] > 3
  333. and img.shape[1] > self.args.det_limit_side_len * 3
  334. and use_slice
  335. ):
  336. start_w = 0
  337. end_w = 0
  338. while end_w <= img.shape[1]:
  339. end_w = start_w + img.shape[0] * 3 // 4
  340. subimg = img[:, start_w:end_w]
  341. if len(subimg) == 0:
  342. break
  343. sub_dt_boxes, sub_elapse = self.predict(subimg)
  344. offset = start_w
  345. if (
  346. len(sub_dt_boxes) == 0
  347. or img.shape[0] - max([x[-1][0] for x in sub_dt_boxes])
  348. > MIN_BOUND_DISTANCE
  349. ):
  350. start_w = end_w
  351. else:
  352. sorted_indices = np.argsort(sub_dt_boxes[:, 2, 0])
  353. sub_dt_boxes = sub_dt_boxes[sorted_indices]
  354. right_line = (
  355. 0
  356. if len(sub_dt_boxes) <= 1
  357. else int(np.max(sub_dt_boxes[:-1, 1, 0]))
  358. )
  359. if right_line > 0:
  360. start_w += right_line
  361. sub_dt_boxes = sub_dt_boxes[sub_dt_boxes[:, 1, 0] <= right_line]
  362. else:
  363. start_w = end_w
  364. if len(sub_dt_boxes) > 0:
  365. if dt_boxes.shape[0] == 0:
  366. dt_boxes = sub_dt_boxes + np.array(
  367. [offset, 0], dtype=np.float32
  368. )
  369. else:
  370. dt_boxes = np.append(
  371. dt_boxes,
  372. sub_dt_boxes + np.array([offset, 0], dtype=np.float32),
  373. axis=0,
  374. )
  375. elapse += sub_elapse
  376. else:
  377. dt_boxes, elapse = self.predict(img)
  378. return dt_boxes, elapse
  379. if __name__ == "__main__":
  380. args = utility.parse_args()
  381. image_file_list = get_image_file_list(args.image_dir)
  382. total_time = 0
  383. draw_img_save_dir = args.draw_img_save_dir
  384. os.makedirs(draw_img_save_dir, exist_ok=True)
  385. # logger
  386. log_file = args.save_log_path
  387. if os.path.isdir(args.save_log_path) or (
  388. not os.path.exists(args.save_log_path) and args.save_log_path.endswith("/")
  389. ):
  390. log_file = os.path.join(log_file, "benchmark_detection.log")
  391. logger = get_logger(log_file=log_file)
  392. # create text detector
  393. text_detector = TextDetector(args, logger)
  394. if args.warmup:
  395. img = np.random.uniform(0, 255, [640, 640, 3]).astype(np.uint8)
  396. for i in range(2):
  397. res = text_detector(img)
  398. save_results = []
  399. for idx, image_file in enumerate(image_file_list):
  400. img, flag_gif, flag_pdf = check_and_read(image_file)
  401. if not flag_gif and not flag_pdf:
  402. img = cv2.imread(image_file)
  403. if not flag_pdf:
  404. if img is None:
  405. logger.debug("error in loading image:{}".format(image_file))
  406. continue
  407. imgs = [img]
  408. else:
  409. page_num = args.page_num
  410. if page_num > len(img) or page_num == 0:
  411. page_num = len(img)
  412. imgs = img[:page_num]
  413. for index, img in enumerate(imgs):
  414. st = time.time()
  415. dt_boxes, _ = text_detector(img)
  416. elapse = time.time() - st
  417. total_time += elapse
  418. if len(imgs) > 1:
  419. save_pred = (
  420. os.path.basename(image_file)
  421. + "_"
  422. + str(index)
  423. + "\t"
  424. + str(json.dumps([x.tolist() for x in dt_boxes]))
  425. + "\n"
  426. )
  427. else:
  428. save_pred = (
  429. os.path.basename(image_file)
  430. + "\t"
  431. + str(json.dumps([x.tolist() for x in dt_boxes]))
  432. + "\n"
  433. )
  434. save_results.append(save_pred)
  435. logger.info(save_pred)
  436. if len(imgs) > 1:
  437. logger.info(
  438. "{}_{} The predict time of {}: {}".format(
  439. idx, index, image_file, elapse
  440. )
  441. )
  442. else:
  443. logger.info(
  444. "{} The predict time of {}: {}".format(idx, image_file, elapse)
  445. )
  446. src_im = utility.draw_text_det_res(dt_boxes, img)
  447. if flag_gif:
  448. save_file = image_file[:-3] + "png"
  449. elif flag_pdf:
  450. save_file = image_file.replace(".pdf", "_" + str(index) + ".png")
  451. else:
  452. save_file = image_file
  453. img_path = os.path.join(
  454. draw_img_save_dir, "det_res_{}".format(os.path.basename(save_file))
  455. )
  456. cv2.imwrite(img_path, src_im)
  457. logger.info("The visualized image saved in {}".format(img_path))
  458. with open(os.path.join(draw_img_save_dir, "det_results.txt"), "w") as f:
  459. f.writelines(save_results)
  460. f.close()
  461. if args.benchmark:
  462. text_detector.autolog.report()