drrg_postprocess.py 12 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338
  1. # Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. """
  15. This code is adapted from:
  16. https://github.com/open-mmlab/mmocr/blob/main/mmocr/models/textdet/postprocess/drrg_postprocessor.py
  17. """
  18. import functools
  19. import operator
  20. import numpy as np
  21. import paddle
  22. from numpy.linalg import norm
  23. import cv2
  class Node:
      """A vertex of the text-component graph.

      A node remembers the index of the text component it stands for and the
      set of nodes it is linked to. Links are symmetric: ``add_link`` records
      the connection on both endpoints.
      """

      def __init__(self, ind):
          self.__ind = ind
          self.__links = set()

      @property
      def ind(self):
          """The text-component index this node represents."""
          return self.__ind

      @property
      def links(self):
          """A fresh copy of the linked nodes; mutating it leaves the node intact."""
          return self.__links.copy()

      def add_link(self, link_node):
          """Link this node and ``link_node`` to each other."""
          for owner, other in ((self, link_node), (link_node, self)):
              owner.__links.add(other)
  def graph_propagation(edges, scores, text_comps, edge_len_thr=50.0):
      """Build the text-component graph and aggregate its edge scores.

      Args:
          edges (ndarray): Integer array of shape (N, 2); each row is a pair of
              text-component indices forming an edge.
          scores (ndarray): Edge scores, ``scores.shape[0] == N``. The caller's
              array is not modified.
          text_comps (ndarray): 2-D array whose first 8 columns are the four
              (x, y) corners of each text component.
          edge_len_thr (float): Edges whose component centers are farther apart
              than this get score 0.

      Returns:
          tuple(list[Node], dict): The graph vertices (one per distinct index
          in ``edges``, in ascending index order) and a dict mapping each
          sorted index pair ``(i, j)`` to its aggregated score.
      """
      assert edges.ndim == 2
      assert edges.shape[1] == 2
      assert edges.shape[0] == scores.shape[0]
      assert text_comps.ndim == 2
      assert isinstance(edge_len_thr, float)

      edges = np.sort(edges, axis=1)
      # Work on a copy so zeroing long edges never leaks into the caller's array.
      scores = np.array(scores, copy=True)
      if edges.shape[0] == 0:
          # Nothing to link: no vertices and no edge scores.
          return [], {}

      # Edges whose component centers are too far apart get score 0.
      centers = text_comps[:, :8].reshape(-1, 4, 2).mean(axis=1)
      lengths = norm(centers[edges[:, 0]] - centers[edges[:, 1]], axis=1)
      scores[lengths > edge_len_thr] = 0

      # Repeated edges are merged by averaging the running value with the new
      # score (note: not a true mean when an edge occurs more than twice).
      score_dict = {}
      for (i, j), score in zip(edges, scores):
          if (i, j) in score_dict:
              score_dict[i, j] = 0.5 * (score_dict[i, j] + score)
          else:
              score_dict[i, j] = score

      # Map the sorted, unique component indices to consecutive vertex ids.
      nodes, inverse = np.unique(edges, return_inverse=True)
      order_inds = inverse.reshape(edges.shape)
      vertices = [Node(node) for node in nodes]
      for a, b in order_inds:
          vertices[a].add_link(vertices[b])
      return vertices, score_dict
  def connected_components(nodes, score_dict, link_thr):
      """Group graph nodes into clusters joined by sufficiently strong links.

      Two linked nodes end up in the same cluster when the score of the edge
      between them is at least ``link_thr``. Clusters are found by
      breadth-first search.

      Args:
          nodes (list[Node]): The graph vertices.
          score_dict (dict): Maps a sorted ``(ind_a, ind_b)`` pair to the score
              of the edge between those nodes.
          link_thr (float): Minimum edge score for two nodes to be joined.

      Returns:
          list[list[Node]]: The clusters. Their order (and the order of nodes
          inside each one) follows set iteration and is not specified.
      """
      from collections import deque

      assert isinstance(nodes, list)
      assert all([isinstance(node, Node) for node in nodes])
      assert isinstance(score_dict, dict)
      assert isinstance(link_thr, float)

      clusters = []
      remaining = set(nodes)
      while remaining:
          seed = remaining.pop()
          cluster = {seed}
          # deque.popleft is O(1); list.pop(0) made the BFS quadratic.
          queue = deque([seed])
          while queue:
              node = queue.popleft()
              neighbors = {
                  neighbor
                  for neighbor in node.links
                  if score_dict[tuple(sorted([node.ind, neighbor.ind]))] >= link_thr
              }
              neighbors.difference_update(cluster)
              remaining.difference_update(neighbors)
              cluster.update(neighbors)
              queue.extend(neighbors)
          clusters.append(list(cluster))
      return clusters
  def clusters2labels(clusters, num_nodes):
      """Turn node clusters into a per-node label array.

      Args:
          clusters (list[list[Node]]): Clusters of graph nodes.
          num_nodes (int): Length of the returned label array.

      Returns:
          ndarray: Float array of length ``num_nodes``; entry ``node.ind`` holds
          the position of that node's cluster within ``clusters``.
      """
      assert isinstance(clusters, list)
      assert all(isinstance(cluster, list) for cluster in clusters)
      assert all(isinstance(node, Node) for cluster in clusters for node in cluster)
      assert isinstance(num_nodes, int)

      labels = np.zeros(num_nodes)
      # NOTE(review): indices that belong to no cluster keep 0, the same value
      # as the first cluster's label.
      assignments = (
          (node.ind, label)
          for label, cluster in enumerate(clusters)
          for node in cluster
      )
      for node_ind, label in assignments:
          labels[node_ind] = label
      return labels
  def remove_single(text_comps, comp_pred_labels):
      """Drop text components whose label is shared with no other component.

      Args:
          text_comps (ndarray): 2-D array with one row per text component.
          comp_pred_labels (ndarray): Cluster label of each component.

      Returns:
          tuple(ndarray, ndarray): The surviving rows of ``text_comps`` and their
          labels, in their original order.
      """
      assert text_comps.ndim == 2
      assert text_comps.shape[0] == comp_pred_labels.shape[0]

      labels, counts = np.unique(comp_pred_labels, return_counts=True)
      singleton_labels = labels[counts == 1]
      keep_mask = ~np.isin(comp_pred_labels, singleton_labels)
      return text_comps[keep_mask, :], comp_pred_labels[keep_mask]
  def norm2(point1, point2):
      """Return the Euclidean distance between two 2-D points."""
      dx = point1[0] - point2[0]
      dy = point1[1] - point2[1]
      return (dx ** 2 + dy ** 2) ** 0.5
  def min_connect_path(points):
      """Order points along a short open path using a greedy two-ended search.

      The path starts as ``[0]`` and grows one point at a time. The remaining
      point closest to either end of the current path is attached to that end.
      Ties between ends favor the head. Among equally close candidates, the
      one listed last in ``points`` wins.

      Indices are tracked directly, so duplicate points are handled correctly:
      every index appears exactly once in the result.

      Args:
          points (list[list[int]]): The points, e.g. text-component centers.

      Returns:
          list[int]: Indices into ``points`` in path order (``[]`` for no
          points, ``[0]`` for a single point).
      """
      assert isinstance(points, list)
      assert all([isinstance(point, list) for point in points])
      assert all([isinstance(coord, int) for point in points for coord in point])
      if not points:
          return []

      def _nearest(candidates, anchor):
          # Closest candidate to points[anchor]; ``<=`` lets the last of several
          # equally close candidates win, matching the original tie-breaking.
          ax, ay = points[anchor][0], points[anchor][1]
          best_dist, best_idx = None, None
          for idx in candidates:
              dist = ((points[idx][0] - ax) ** 2 + (points[idx][1] - ay) ** 2) ** 0.5
              if best_dist is None or dist <= best_dist:
                  best_dist, best_idx = dist, idx
          return best_dist, best_idx

      remaining = list(range(1, len(points)))
      head = tail = 0
      head_side = []  # indices attached at the head, most recent last
      tail_side = [0]  # the seed, then indices attached at the tail
      while remaining:
          head_dist, head_idx = _nearest(remaining, head)
          tail_dist, tail_idx = _nearest(remaining, tail)
          if head_dist <= tail_dist:
              head = head_idx
              head_side.append(head_idx)
              remaining.remove(head_idx)
          else:
              tail = tail_idx
              tail_side.append(tail_idx)
              remaining.remove(tail_idx)
      return head_side[::-1] + tail_side
  def in_contour(cont, point):
      """Return whether ``point`` lies strictly inside the polygon ``cont``.

      The point is truncated to integer coordinates first. With
      ``measureDist=False``, ``cv2.pointPolygonTest`` returns +1 (inside),
      -1 (outside) or 0 (on an edge). Comparing against 0.5 therefore keeps
      only strictly interior points.
      """
      x, y = point
      query = (int(x), int(y))
      return cv2.pointPolygonTest(cont, query, False) > 0.5
  def fix_corner(top_line, bot_line, start_box, end_box):
      """Extend a boundary to the outer corners of its first and last boxes.

      The polygon ``top_line + bot_line[::-1]`` is built once, before any
      change. Then, for ``start_box``:

      - if its left-edge midpoint (between corners 0 and 3) is outside the
        polygon, corners 0 and 3 are prepended to ``top_line`` and
        ``bot_line``;
      - otherwise, if its right-edge midpoint (corners 1 and 2) is outside,
        corners 1 and 2 are prepended.

      ``end_box`` is handled the same way, but its corners are appended.

      Args:
          top_line (list[list]): Top boundary points; modified in place.
          bot_line (list[list]): Bottom boundary points; modified in place.
          start_box (ndarray): (4, 2) corners of the first box on the path.
          end_box (ndarray): (4, 2) corners of the last box on the path.

      Returns:
          tuple(list, list): ``top_line`` and ``bot_line`` (the same objects).
      """
      assert isinstance(top_line, list)
      assert all(isinstance(point, list) for point in top_line)
      assert isinstance(bot_line, list)
      assert all(isinstance(point, list) for point in bot_line)
      assert start_box.shape == end_box.shape == (4, 2)

      contour = np.array(top_line + bot_line[::-1])

      def _outer_corners(box):
          # (top corner, bottom corner) to add for this box, or None.
          left_mid = (box[0] + box[3]) / 2
          right_mid = (box[1] + box[2]) / 2
          if not in_contour(contour, left_mid):
              return box[0].tolist(), box[3].tolist()
          if not in_contour(contour, right_mid):
              return box[1].tolist(), box[2].tolist()
          return None

      start_corners = _outer_corners(start_box)
      if start_corners is not None:
          top_line.insert(0, start_corners[0])
          bot_line.insert(0, start_corners[1])

      end_corners = _outer_corners(end_box)
      if end_corners is not None:
          top_line.append(end_corners[0])
          bot_line.append(end_corners[1])

      return top_line, bot_line
  def comps2boundaries(text_comps, comp_pred_labels):
      """Build one text-instance boundary per non-empty cluster of components.

      Args:
          text_comps (ndarray): 2-D array with one row per text component. The
              first 8 columns are the four (x, y) corners and the last column
              is the component score.
          comp_pred_labels (ndarray): Cluster label of each component.

      Returns:
          list[list]: For each non-empty cluster with label in
          ``0..max(label)``, the flattened (x, y) boundary points (truncated to
          int32) followed by the mean component score.
      """
      assert text_comps.ndim == 2
      assert len(text_comps) == len(comp_pred_labels)

      boundaries = []
      if len(text_comps) < 1:
          return boundaries

      for cluster_ind in range(0, int(np.max(comp_pred_labels)) + 1):
          in_cluster = comp_pred_labels == cluster_ind
          # Labels need not be contiguous (e.g. after singletons are removed).
          # Skip empty clusters before averaging, to avoid NaN and warnings.
          if not np.any(in_cluster):
              continue
          text_comp_boxes = (
              text_comps[in_cluster, :8].reshape((-1, 4, 2)).astype(np.int32)
          )
          score = np.mean(text_comps[in_cluster, -1])

          if text_comp_boxes.shape[0] > 1:
              # Order components along a path, then trace the polygon along the
              # top-edge midpoints and back along the bottom-edge midpoints.
              centers = np.mean(text_comp_boxes, axis=1).astype(np.int32).tolist()
              shortest_path = min_connect_path(centers)
              text_comp_boxes = text_comp_boxes[shortest_path]
              top_line = (
                  np.mean(text_comp_boxes[:, 0:2, :], axis=1).astype(np.int32).tolist()
              )
              bot_line = (
                  np.mean(text_comp_boxes[:, 2:4, :], axis=1).astype(np.int32).tolist()
              )
              top_line, bot_line = fix_corner(
                  top_line, bot_line, text_comp_boxes[0], text_comp_boxes[-1]
              )
              boundary_points = top_line + bot_line[::-1]
          else:
              # A single component is its own polygon: top edge (corners 0, 1),
              # then the bottom edge back (corners 2, 3). The previous slice
              # ``2:4:-1`` was always empty and dropped the bottom corners.
              boundary_points = text_comp_boxes[0].tolist()
          boundary = [p for coord in boundary_points for p in coord] + [score]
          boundaries.append(boundary)
      return boundaries
  class DRRGPostprocess(object):
      """Merge text components and construct boundaries of text instances.

      Args:
          link_thr (float): The edge score threshold.
      """

      def __init__(self, link_thr, **kwargs):
          assert isinstance(link_thr, float)
          self.link_thr = link_thr

      def __call__(self, preds, shape_list):
          """Turn DRRG graph predictions into text-instance polygons.

          Args:
              preds (tuple): ``(edges, scores, text_comps)``. Each entry may be a
                  numpy array or a ``paddle.Tensor``.
                  edges: shape (N, 2); each row is a pair of node indices that
                      forms an edge in the graph. ``None`` yields no boundaries.
                  scores: edge scores, one per edge.
                  text_comps: text components, shape (M, 9).
              shape_list (ndarray): ``shape_list[0, 2:]`` is inverted and
                  reversed to give the (x, y) scale applied to every boundary.
                  Presumably these are the (h, w) resize ratios of the input
                  image; TODO confirm against the preprocessing.

          Returns:
              list[dict]: A single dict with ``points`` (list of (K, 2) arrays)
              and ``scores`` (list of boundary scores).
          """
          edges, scores, text_comps = preds
          boundaries = []
          if edges is not None:
              if isinstance(edges, paddle.Tensor):
                  edges = edges.numpy()
              if isinstance(scores, paddle.Tensor):
                  scores = scores.numpy()
              if isinstance(text_comps, paddle.Tensor):
                  text_comps = text_comps.numpy()
              assert len(edges) == len(scores)
              # With no edges there is nothing to link, and the graph helpers
              # cannot build a graph from an empty edge set.
              if len(edges) > 0:
                  assert text_comps.ndim == 2
                  assert text_comps.shape[1] == 9
                  vertices, score_dict = graph_propagation(edges, scores, text_comps)
                  clusters = connected_components(vertices, score_dict, self.link_thr)
                  pred_labels = clusters2labels(clusters, text_comps.shape[0])
                  text_comps, pred_labels = remove_single(text_comps, pred_labels)
                  boundaries = comps2boundaries(text_comps, pred_labels)
          boundaries, scores = self.resize_boundary(
              boundaries, (1 / shape_list[0, 2:]).tolist()[::-1]
          )
          return [dict(points=boundaries, scores=scores)]

      def resize_boundary(self, boundaries, scale_factor):
          """Rescale boundaries via scale_factor.

          Args:
              boundaries (list[list[float]]): Each boundary is 2k coordinates
                  ``x0, y0, x1, y1, ...`` followed by its score.
              scale_factor (Sequence[float]): Only the first two entries are
                  used, as the x and y multipliers.

          Returns:
              tuple(list[ndarray], list): The scaled boundaries as (k, 2) point
              arrays, and the boundary scores in the same order.
          """
          xy_scale = np.asarray(scale_factor[:2])
          boxes = []
          scores = []
          for b in boundaries:
              scores.append(b[-1])
              points = np.array(b[:-1]).reshape([-1, 2])
              # Broadcasting multiplies every x by xy_scale[0], every y by [1].
              boxes.append(points * xy_scale)
          return boxes, scores