picodet_postprocess.py 11 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297
  1. # Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. import numpy as np
  15. from scipy.special import softmax
  16. def hard_nms(box_scores, iou_threshold, top_k=-1, candidate_size=200):
  17. """
  18. Args:
  19. box_scores (N, 5): boxes in corner-form and probabilities.
  20. iou_threshold: intersection over union threshold.
  21. top_k: keep top_k results. If k <= 0, keep all the results.
  22. candidate_size: only consider the candidates with the highest scores.
  23. Returns:
  24. picked: a list of indexes of the kept boxes
  25. """
  26. scores = box_scores[:, -1]
  27. boxes = box_scores[:, :-1]
  28. picked = []
  29. indexes = np.argsort(scores)
  30. indexes = indexes[-candidate_size:]
  31. while len(indexes) > 0:
  32. current = indexes[-1]
  33. picked.append(current)
  34. if 0 < top_k == len(picked) or len(indexes) == 1:
  35. break
  36. current_box = boxes[current, :]
  37. indexes = indexes[:-1]
  38. rest_boxes = boxes[indexes, :]
  39. iou = iou_of(
  40. rest_boxes,
  41. np.expand_dims(current_box, axis=0),
  42. )
  43. indexes = indexes[iou <= iou_threshold]
  44. return box_scores[picked, :]
  45. def iou_of(boxes0, boxes1, eps=1e-5):
  46. """Return intersection-over-union (Jaccard index) of boxes.
  47. Args:
  48. boxes0 (N, 4): ground truth boxes.
  49. boxes1 (N or 1, 4): predicted boxes.
  50. eps: a small number to avoid 0 as denominator.
  51. Returns:
  52. iou (N): IoU values.
  53. """
  54. overlap_left_top = np.maximum(boxes0[..., :2], boxes1[..., :2])
  55. overlap_right_bottom = np.minimum(boxes0[..., 2:], boxes1[..., 2:])
  56. overlap_area = area_of(overlap_left_top, overlap_right_bottom)
  57. area0 = area_of(boxes0[..., :2], boxes0[..., 2:])
  58. area1 = area_of(boxes1[..., :2], boxes1[..., 2:])
  59. return overlap_area / (area0 + area1 - overlap_area + eps)
  60. def area_of(left_top, right_bottom):
  61. """Compute the areas of rectangles given two corners.
  62. Args:
  63. left_top (N, 2): left top corner.
  64. right_bottom (N, 2): right bottom corner.
  65. Returns:
  66. area (N): return the area.
  67. """
  68. hw = np.clip(right_bottom - left_top, 0.0, None)
  69. return hw[..., 0] * hw[..., 1]
  70. def calculate_containment(boxes0, boxes1):
  71. """
  72. Calculate the containment of the boxes.
  73. Args:
  74. boxes0 (N, 4): ground truth boxes.
  75. boxes1 (N or 1, 4): predicted boxes.
  76. Returns:
  77. containment (N): containment values.
  78. """
  79. overlap_left_top = np.maximum(boxes0[..., :2], boxes1[..., :2])
  80. overlap_right_bottom = np.minimum(boxes0[..., 2:], boxes1[..., 2:])
  81. overlap_area = area_of(overlap_left_top, overlap_right_bottom)
  82. area0 = area_of(boxes0[..., :2], boxes0[..., 2:])
  83. area1 = area_of(boxes1[..., :2], boxes1[..., 2:])
  84. return overlap_area / np.minimum(area0, np.expand_dims(area1, axis=0))
  85. class PicoDetPostProcess(object):
  86. """
  87. Args:
  88. input_shape (int): network input image size
  89. ori_shape (int): ori image shape of before padding
  90. scale_factor (float): scale factor of ori image
  91. enable_mkldnn (bool): whether to open MKLDNN
  92. """
  93. def __init__(
  94. self,
  95. layout_dict_path,
  96. strides=[8, 16, 32, 64],
  97. score_threshold=0.4,
  98. nms_threshold=0.5,
  99. nms_top_k=1000,
  100. keep_top_k=100,
  101. ):
  102. self.labels = self.load_layout_dict(layout_dict_path)
  103. self.strides = strides
  104. self.score_threshold = score_threshold
  105. self.nms_threshold = nms_threshold
  106. self.nms_top_k = nms_top_k
  107. self.keep_top_k = keep_top_k
  108. def load_layout_dict(self, layout_dict_path):
  109. with open(layout_dict_path, "r", encoding="utf-8") as fp:
  110. labels = fp.readlines()
  111. return [label.strip("\n") for label in labels]
  112. def warp_boxes(self, boxes, ori_shape):
  113. """Apply transform to boxes"""
  114. width, height = ori_shape[1], ori_shape[0]
  115. n = len(boxes)
  116. if n:
  117. # warp points
  118. xy = np.ones((n * 4, 3))
  119. xy[:, :2] = boxes[:, [0, 1, 2, 3, 0, 3, 2, 1]].reshape(
  120. n * 4, 2
  121. ) # x1y1, x2y2, x1y2, x2y1
  122. # xy = xy @ M.T # transform
  123. xy = (xy[:, :2] / xy[:, 2:3]).reshape(n, 8) # rescale
  124. # create new boxes
  125. x = xy[:, [0, 2, 4, 6]]
  126. y = xy[:, [1, 3, 5, 7]]
  127. xy = (
  128. np.concatenate((x.min(1), y.min(1), x.max(1), y.max(1))).reshape(4, n).T
  129. )
  130. # clip boxes
  131. xy[:, [0, 2]] = xy[:, [0, 2]].clip(0, width)
  132. xy[:, [1, 3]] = xy[:, [1, 3]].clip(0, height)
  133. return xy.astype(np.float32)
  134. else:
  135. return boxes
  136. def img_info(self, ori_img, img):
  137. origin_shape = ori_img.shape
  138. resize_shape = img.shape
  139. im_scale_y = resize_shape[2] / float(origin_shape[0])
  140. im_scale_x = resize_shape[3] / float(origin_shape[1])
  141. scale_factor = np.array([im_scale_y, im_scale_x], dtype=np.float32)
  142. img_shape = np.array(img.shape[2:], dtype=np.float32)
  143. input_shape = np.array(img).astype("float32").shape[2:]
  144. ori_shape = np.array((img_shape,)).astype("float32")
  145. scale_factor = np.array((scale_factor,)).astype("float32")
  146. return ori_shape, input_shape, scale_factor
  147. def __call__(self, ori_img, img, preds):
  148. scores, raw_boxes = preds["boxes"], preds["boxes_num"]
  149. batch_size = raw_boxes[0].shape[0]
  150. reg_max = int(raw_boxes[0].shape[-1] / 4 - 1)
  151. out_boxes_num = []
  152. out_boxes_list = []
  153. results = []
  154. ori_shape, input_shape, scale_factor = self.img_info(ori_img, img)
  155. for batch_id in range(batch_size):
  156. # generate centers
  157. decode_boxes = []
  158. select_scores = []
  159. for stride, box_distribute, score in zip(self.strides, raw_boxes, scores):
  160. box_distribute = box_distribute[batch_id]
  161. score = score[batch_id]
  162. # centers
  163. fm_h = input_shape[0] / stride
  164. fm_w = input_shape[1] / stride
  165. h_range = np.arange(fm_h)
  166. w_range = np.arange(fm_w)
  167. ww, hh = np.meshgrid(w_range, h_range)
  168. ct_row = (hh.flatten() + 0.5) * stride
  169. ct_col = (ww.flatten() + 0.5) * stride
  170. center = np.stack((ct_col, ct_row, ct_col, ct_row), axis=1)
  171. # box distribution to distance
  172. reg_range = np.arange(reg_max + 1)
  173. box_distance = box_distribute.reshape((-1, reg_max + 1))
  174. box_distance = softmax(box_distance, axis=1)
  175. box_distance = box_distance * np.expand_dims(reg_range, axis=0)
  176. box_distance = np.sum(box_distance, axis=1).reshape((-1, 4))
  177. box_distance = box_distance * stride
  178. # top K candidate
  179. topk_idx = np.argsort(score.max(axis=1))[::-1]
  180. topk_idx = topk_idx[: self.nms_top_k]
  181. center = center[topk_idx]
  182. score = score[topk_idx]
  183. box_distance = box_distance[topk_idx]
  184. # decode box
  185. decode_box = center + [-1, -1, 1, 1] * box_distance
  186. select_scores.append(score)
  187. decode_boxes.append(decode_box)
  188. # nms
  189. bboxes = np.concatenate(decode_boxes, axis=0)
  190. confidences = np.concatenate(select_scores, axis=0)
  191. picked_box_probs = []
  192. picked_labels = []
  193. for class_index in range(0, confidences.shape[1]):
  194. probs = confidences[:, class_index]
  195. mask = probs > self.score_threshold
  196. probs = probs[mask]
  197. if probs.shape[0] == 0:
  198. continue
  199. subset_boxes = bboxes[mask, :]
  200. box_probs = np.concatenate([subset_boxes, probs.reshape(-1, 1)], axis=1)
  201. box_probs = hard_nms(
  202. box_probs,
  203. iou_threshold=self.nms_threshold,
  204. top_k=self.keep_top_k,
  205. )
  206. picked_box_probs.append(box_probs)
  207. picked_labels.extend([class_index] * box_probs.shape[0])
  208. if len(picked_box_probs) == 0:
  209. out_boxes_list.append(np.empty((0, 4)))
  210. out_boxes_num.append(0)
  211. else:
  212. picked_box_probs = np.concatenate(picked_box_probs)
  213. # resize output boxes
  214. picked_box_probs[:, :4] = self.warp_boxes(
  215. picked_box_probs[:, :4], ori_shape[batch_id]
  216. )
  217. im_scale = np.concatenate(
  218. [scale_factor[batch_id][::-1], scale_factor[batch_id][::-1]]
  219. )
  220. picked_box_probs[:, :4] /= im_scale
  221. # clas score box
  222. out_boxes_list.append(
  223. np.concatenate(
  224. [
  225. np.expand_dims(np.array(picked_labels), axis=-1),
  226. np.expand_dims(picked_box_probs[:, 4], axis=-1),
  227. picked_box_probs[:, :4],
  228. ],
  229. axis=1,
  230. )
  231. )
  232. out_boxes_num.append(len(picked_labels))
  233. out_boxes_list = np.concatenate(out_boxes_list, axis=0)
  234. out_boxes_num = np.asarray(out_boxes_num).astype(np.int32)
  235. for dt in out_boxes_list:
  236. clsid, bbox, score = int(dt[0]), dt[2:], dt[1]
  237. label = self.labels[clsid]
  238. result = {"bbox": bbox, "label": label, "score": score}
  239. results.append(result)
  240. # Handle conflict where a box is simultaneously recognized as multiple labels.
  241. # Use IoU to find similar boxes. Prioritize labels as table, text, and others when deduplicate similar boxes.
  242. bboxes = np.array([x["bbox"] for x in results])
  243. duplicate_idx = list()
  244. for i in range(len(results)):
  245. if i in duplicate_idx:
  246. continue
  247. containments = calculate_containment(bboxes, bboxes[i, ...])
  248. overlaps = np.where(containments > 0.5)[0]
  249. if len(overlaps) > 1:
  250. table_box = [x for x in overlaps if results[x]["label"] == "table"]
  251. if len(table_box) > 0:
  252. keep = sorted(
  253. [(x, results[x]) for x in table_box],
  254. key=lambda x: x[1]["score"],
  255. reverse=True,
  256. )[0][0]
  257. else:
  258. keep = sorted(
  259. [(x, results[x]) for x in overlaps],
  260. key=lambda x: x[1]["score"],
  261. reverse=True,
  262. )[0][0]
  263. duplicate_idx.extend([x for x in overlaps if x != keep])
  264. results = [x for i, x in enumerate(results) if i not in duplicate_idx]
  265. return results