proposal_local_graph.py 18 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476
  1. # copyright (c) 2022 PaddlePaddle Authors. All Rights Reserve.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. """
  15. This code is adapted from:
  16. https://github.com/open-mmlab/mmocr/blob/main/mmocr/models/textdet/modules/proposal_local_graph.py
  17. """
  18. from __future__ import absolute_import
  19. from __future__ import division
  20. from __future__ import print_function
  21. import cv2
  22. import numpy as np
  23. import paddle
  24. import paddle.nn as nn
  25. import paddle.nn.functional as F
  26. from lanms import merge_quadrangle_n9 as la_nms
  27. from ppocr.ext_op import RoIAlignRotated
  28. from .local_graph import (
  29. euclidean_distance_matrix,
  30. feature_embedding,
  31. normalize_adjacent_matrix,
  32. )
  33. def fill_hole(input_mask):
  34. h, w = input_mask.shape
  35. canvas = np.zeros((h + 2, w + 2), np.uint8)
  36. canvas[1 : h + 1, 1 : w + 1] = input_mask.copy()
  37. mask = np.zeros((h + 4, w + 4), np.uint8)
  38. cv2.floodFill(canvas, mask, (0, 0), 1)
  39. canvas = canvas[1 : h + 1, 1 : w + 1].astype(np.bool_)
  40. return ~canvas | input_mask
  41. class ProposalLocalGraphs:
  42. def __init__(
  43. self,
  44. k_at_hops,
  45. num_adjacent_linkages,
  46. node_geo_feat_len,
  47. pooling_scale,
  48. pooling_output_size,
  49. nms_thr,
  50. min_width,
  51. max_width,
  52. comp_shrink_ratio,
  53. comp_w_h_ratio,
  54. comp_score_thr,
  55. text_region_thr,
  56. center_region_thr,
  57. center_region_area_thr,
  58. ):
  59. assert len(k_at_hops) == 2
  60. assert isinstance(k_at_hops, tuple)
  61. assert isinstance(num_adjacent_linkages, int)
  62. assert isinstance(node_geo_feat_len, int)
  63. assert isinstance(pooling_scale, float)
  64. assert isinstance(pooling_output_size, tuple)
  65. assert isinstance(nms_thr, float)
  66. assert isinstance(min_width, float)
  67. assert isinstance(max_width, float)
  68. assert isinstance(comp_shrink_ratio, float)
  69. assert isinstance(comp_w_h_ratio, float)
  70. assert isinstance(comp_score_thr, float)
  71. assert isinstance(text_region_thr, float)
  72. assert isinstance(center_region_thr, float)
  73. assert isinstance(center_region_area_thr, int)
  74. self.k_at_hops = k_at_hops
  75. self.active_connection = num_adjacent_linkages
  76. self.local_graph_depth = len(self.k_at_hops)
  77. self.node_geo_feat_dim = node_geo_feat_len
  78. self.pooling = RoIAlignRotated(pooling_output_size, pooling_scale)
  79. self.nms_thr = nms_thr
  80. self.min_width = min_width
  81. self.max_width = max_width
  82. self.comp_shrink_ratio = comp_shrink_ratio
  83. self.comp_w_h_ratio = comp_w_h_ratio
  84. self.comp_score_thr = comp_score_thr
  85. self.text_region_thr = text_region_thr
  86. self.center_region_thr = center_region_thr
  87. self.center_region_area_thr = center_region_area_thr
  88. def propose_comps(
  89. self,
  90. score_map,
  91. top_height_map,
  92. bot_height_map,
  93. sin_map,
  94. cos_map,
  95. comp_score_thr,
  96. min_width,
  97. max_width,
  98. comp_shrink_ratio,
  99. comp_w_h_ratio,
  100. ):
  101. """Propose text components.
  102. Args:
  103. score_map (ndarray): The score map for NMS.
  104. top_height_map (ndarray): The predicted text height map from each
  105. pixel in text center region to top sideline.
  106. bot_height_map (ndarray): The predicted text height map from each
  107. pixel in text center region to bottom sideline.
  108. sin_map (ndarray): The predicted sin(theta) map.
  109. cos_map (ndarray): The predicted cos(theta) map.
  110. comp_score_thr (float): The score threshold of text component.
  111. min_width (float): The minimum width of text components.
  112. max_width (float): The maximum width of text components.
  113. comp_shrink_ratio (float): The shrink ratio of text components.
  114. comp_w_h_ratio (float): The width to height ratio of text
  115. components.
  116. Returns:
  117. text_comps (ndarray): The text components.
  118. """
  119. comp_centers = np.argwhere(score_map > comp_score_thr)
  120. comp_centers = comp_centers[np.argsort(comp_centers[:, 0])]
  121. y = comp_centers[:, 0]
  122. x = comp_centers[:, 1]
  123. top_height = top_height_map[y, x].reshape((-1, 1)) * comp_shrink_ratio
  124. bot_height = bot_height_map[y, x].reshape((-1, 1)) * comp_shrink_ratio
  125. sin = sin_map[y, x].reshape((-1, 1))
  126. cos = cos_map[y, x].reshape((-1, 1))
  127. top_mid_pts = comp_centers + np.hstack([top_height * sin, top_height * cos])
  128. bot_mid_pts = comp_centers - np.hstack([bot_height * sin, bot_height * cos])
  129. width = (top_height + bot_height) * comp_w_h_ratio
  130. width = np.clip(width, min_width, max_width)
  131. r = width / 2
  132. tl = top_mid_pts[:, ::-1] - np.hstack([-r * sin, r * cos])
  133. tr = top_mid_pts[:, ::-1] + np.hstack([-r * sin, r * cos])
  134. br = bot_mid_pts[:, ::-1] + np.hstack([-r * sin, r * cos])
  135. bl = bot_mid_pts[:, ::-1] - np.hstack([-r * sin, r * cos])
  136. text_comps = np.hstack([tl, tr, br, bl]).astype(np.float32)
  137. score = score_map[y, x].reshape((-1, 1))
  138. text_comps = np.hstack([text_comps, score])
  139. return text_comps
  140. def propose_comps_and_attribs(
  141. self,
  142. text_region_map,
  143. center_region_map,
  144. top_height_map,
  145. bot_height_map,
  146. sin_map,
  147. cos_map,
  148. ):
  149. """Generate text components and attributes.
  150. Args:
  151. text_region_map (ndarray): The predicted text region probability
  152. map.
  153. center_region_map (ndarray): The predicted text center region
  154. probability map.
  155. top_height_map (ndarray): The predicted text height map from each
  156. pixel in text center region to top sideline.
  157. bot_height_map (ndarray): The predicted text height map from each
  158. pixel in text center region to bottom sideline.
  159. sin_map (ndarray): The predicted sin(theta) map.
  160. cos_map (ndarray): The predicted cos(theta) map.
  161. Returns:
  162. comp_attribs (ndarray): The text component attributes.
  163. text_comps (ndarray): The text components.
  164. """
  165. assert (
  166. text_region_map.shape
  167. == center_region_map.shape
  168. == top_height_map.shape
  169. == bot_height_map.shape
  170. == sin_map.shape
  171. == cos_map.shape
  172. )
  173. text_mask = text_region_map > self.text_region_thr
  174. center_region_mask = (center_region_map > self.center_region_thr) * text_mask
  175. scale = np.sqrt(1.0 / (sin_map**2 + cos_map**2 + 1e-8))
  176. sin_map, cos_map = sin_map * scale, cos_map * scale
  177. center_region_mask = fill_hole(center_region_mask)
  178. center_region_contours, _ = cv2.findContours(
  179. center_region_mask.astype(np.uint8), cv2.RETR_TREE, cv2.CHAIN_APPROX_SIMPLE
  180. )
  181. mask_sz = center_region_map.shape
  182. comp_list = []
  183. for contour in center_region_contours:
  184. current_center_mask = np.zeros(mask_sz)
  185. cv2.drawContours(current_center_mask, [contour], -1, 1, -1)
  186. if current_center_mask.sum() <= self.center_region_area_thr:
  187. continue
  188. score_map = text_region_map * current_center_mask
  189. text_comps = self.propose_comps(
  190. score_map,
  191. top_height_map,
  192. bot_height_map,
  193. sin_map,
  194. cos_map,
  195. self.comp_score_thr,
  196. self.min_width,
  197. self.max_width,
  198. self.comp_shrink_ratio,
  199. self.comp_w_h_ratio,
  200. )
  201. text_comps = la_nms(text_comps, self.nms_thr)
  202. text_comp_mask = np.zeros(mask_sz)
  203. text_comp_boxes = text_comps[:, :8].reshape((-1, 4, 2)).astype(np.int32)
  204. cv2.drawContours(text_comp_mask, text_comp_boxes, -1, 1, -1)
  205. if (text_comp_mask * text_mask).sum() < text_comp_mask.sum() * 0.5:
  206. continue
  207. if text_comps.shape[-1] > 0:
  208. comp_list.append(text_comps)
  209. if len(comp_list) <= 0:
  210. return None, None
  211. text_comps = np.vstack(comp_list)
  212. text_comp_boxes = text_comps[:, :8].reshape((-1, 4, 2))
  213. centers = np.mean(text_comp_boxes, axis=1).astype(np.int32)
  214. x = centers[:, 0]
  215. y = centers[:, 1]
  216. scores = []
  217. for text_comp_box in text_comp_boxes:
  218. text_comp_box[:, 0] = np.clip(text_comp_box[:, 0], 0, mask_sz[1] - 1)
  219. text_comp_box[:, 1] = np.clip(text_comp_box[:, 1], 0, mask_sz[0] - 1)
  220. min_coord = np.min(text_comp_box, axis=0).astype(np.int32)
  221. max_coord = np.max(text_comp_box, axis=0).astype(np.int32)
  222. text_comp_box = text_comp_box - min_coord
  223. box_sz = max_coord - min_coord + 1
  224. temp_comp_mask = np.zeros((box_sz[1], box_sz[0]), dtype=np.uint8)
  225. cv2.fillPoly(temp_comp_mask, [text_comp_box.astype(np.int32)], 1)
  226. temp_region_patch = text_region_map[
  227. min_coord[1] : (max_coord[1] + 1), min_coord[0] : (max_coord[0] + 1)
  228. ]
  229. score = cv2.mean(temp_region_patch, temp_comp_mask)[0]
  230. scores.append(score)
  231. scores = np.array(scores).reshape((-1, 1))
  232. text_comps = np.hstack([text_comps[:, :-1], scores])
  233. h = top_height_map[y, x].reshape((-1, 1)) + bot_height_map[y, x].reshape(
  234. (-1, 1)
  235. )
  236. w = np.clip(h * self.comp_w_h_ratio, self.min_width, self.max_width)
  237. sin = sin_map[y, x].reshape((-1, 1))
  238. cos = cos_map[y, x].reshape((-1, 1))
  239. x = x.reshape((-1, 1))
  240. y = y.reshape((-1, 1))
  241. comp_attribs = np.hstack([x, y, h, w, cos, sin])
  242. return comp_attribs, text_comps
  243. def generate_local_graphs(self, sorted_dist_inds, node_feats):
  244. """Generate local graphs and graph convolution network input data.
  245. Args:
  246. sorted_dist_inds (ndarray): The node indices sorted according to
  247. the Euclidean distance.
  248. node_feats (tensor): The features of nodes in graph.
  249. Returns:
  250. local_graphs_node_feats (tensor): The features of nodes in local
  251. graphs.
  252. adjacent_matrices (tensor): The adjacent matrices.
  253. pivots_knn_inds (tensor): The k-nearest neighbor indices in
  254. local graphs.
  255. pivots_local_graphs (tensor): The indices of nodes in local
  256. graphs.
  257. """
  258. assert sorted_dist_inds.ndim == 2
  259. assert (
  260. sorted_dist_inds.shape[0]
  261. == sorted_dist_inds.shape[1]
  262. == node_feats.shape[0]
  263. )
  264. knn_graph = sorted_dist_inds[:, 1 : self.k_at_hops[0] + 1]
  265. pivot_local_graphs = []
  266. pivot_knns = []
  267. for pivot_ind, knn in enumerate(knn_graph):
  268. local_graph_neighbors = set(knn)
  269. for neighbor_ind in knn:
  270. local_graph_neighbors.update(
  271. set(sorted_dist_inds[neighbor_ind, 1 : self.k_at_hops[1] + 1])
  272. )
  273. local_graph_neighbors.discard(pivot_ind)
  274. pivot_local_graph = list(local_graph_neighbors)
  275. pivot_local_graph.insert(0, pivot_ind)
  276. pivot_knn = [pivot_ind] + list(knn)
  277. pivot_local_graphs.append(pivot_local_graph)
  278. pivot_knns.append(pivot_knn)
  279. num_max_nodes = max(
  280. [len(pivot_local_graph) for pivot_local_graph in pivot_local_graphs]
  281. )
  282. local_graphs_node_feat = []
  283. adjacent_matrices = []
  284. pivots_knn_inds = []
  285. pivots_local_graphs = []
  286. for graph_ind, pivot_knn in enumerate(pivot_knns):
  287. pivot_local_graph = pivot_local_graphs[graph_ind]
  288. num_nodes = len(pivot_local_graph)
  289. pivot_ind = pivot_local_graph[0]
  290. node2ind_map = {j: i for i, j in enumerate(pivot_local_graph)}
  291. knn_inds = paddle.cast(
  292. paddle.to_tensor([node2ind_map[i] for i in pivot_knn[1:]]), "int64"
  293. )
  294. pivot_feats = node_feats[pivot_ind]
  295. normalized_feats = (
  296. node_feats[paddle.to_tensor(pivot_local_graph)] - pivot_feats
  297. )
  298. adjacent_matrix = np.zeros((num_nodes, num_nodes), dtype=np.float32)
  299. for node in pivot_local_graph:
  300. neighbors = sorted_dist_inds[node, 1 : self.active_connection + 1]
  301. for neighbor in neighbors:
  302. if neighbor in pivot_local_graph:
  303. adjacent_matrix[node2ind_map[node], node2ind_map[neighbor]] = 1
  304. adjacent_matrix[node2ind_map[neighbor], node2ind_map[node]] = 1
  305. adjacent_matrix = normalize_adjacent_matrix(adjacent_matrix)
  306. pad_adjacent_matrix = paddle.zeros(
  307. (num_max_nodes, num_max_nodes),
  308. )
  309. pad_adjacent_matrix[:num_nodes, :num_nodes] = paddle.cast(
  310. paddle.to_tensor(adjacent_matrix), "float32"
  311. )
  312. pad_normalized_feats = paddle.concat(
  313. [
  314. normalized_feats,
  315. paddle.zeros(
  316. (num_max_nodes - num_nodes, normalized_feats.shape[1]),
  317. ),
  318. ],
  319. axis=0,
  320. )
  321. local_graph_nodes = paddle.to_tensor(pivot_local_graph)
  322. local_graph_nodes = paddle.concat(
  323. [
  324. local_graph_nodes,
  325. paddle.zeros([num_max_nodes - num_nodes], dtype="int64"),
  326. ],
  327. axis=-1,
  328. )
  329. local_graphs_node_feat.append(pad_normalized_feats)
  330. adjacent_matrices.append(pad_adjacent_matrix)
  331. pivots_knn_inds.append(knn_inds)
  332. pivots_local_graphs.append(local_graph_nodes)
  333. local_graphs_node_feat = paddle.stack(local_graphs_node_feat, 0)
  334. adjacent_matrices = paddle.stack(adjacent_matrices, 0)
  335. pivots_knn_inds = paddle.stack(pivots_knn_inds, 0)
  336. pivots_local_graphs = paddle.stack(pivots_local_graphs, 0)
  337. return (
  338. local_graphs_node_feat,
  339. adjacent_matrices,
  340. pivots_knn_inds,
  341. pivots_local_graphs,
  342. )
  343. def __call__(self, preds, feat_maps):
  344. """Generate local graphs and graph convolutional network input data.
  345. Args:
  346. preds (tensor): The predicted maps.
  347. feat_maps (tensor): The feature maps to extract content feature of
  348. text components.
  349. Returns:
  350. none_flag (bool): The flag showing whether the number of proposed
  351. text components is 0.
  352. local_graphs_node_feats (tensor): The features of nodes in local
  353. graphs.
  354. adjacent_matrices (tensor): The adjacent matrices.
  355. pivots_knn_inds (tensor): The k-nearest neighbor indices in
  356. local graphs.
  357. pivots_local_graphs (tensor): The indices of nodes in local
  358. graphs.
  359. text_comps (ndarray): The predicted text components.
  360. """
  361. if preds.ndim == 4:
  362. assert preds.shape[0] == 1
  363. preds = paddle.squeeze(preds)
  364. pred_text_region = F.sigmoid(preds[0]).numpy()
  365. pred_center_region = F.sigmoid(preds[1]).numpy()
  366. pred_sin_map = preds[2].numpy()
  367. pred_cos_map = preds[3].numpy()
  368. pred_top_height_map = preds[4].numpy()
  369. pred_bot_height_map = preds[5].numpy()
  370. comp_attribs, text_comps = self.propose_comps_and_attribs(
  371. pred_text_region,
  372. pred_center_region,
  373. pred_top_height_map,
  374. pred_bot_height_map,
  375. pred_sin_map,
  376. pred_cos_map,
  377. )
  378. if comp_attribs is None or len(comp_attribs) < 2:
  379. none_flag = True
  380. return none_flag, (0, 0, 0, 0, 0)
  381. comp_centers = comp_attribs[:, 0:2]
  382. distance_matrix = euclidean_distance_matrix(comp_centers, comp_centers)
  383. geo_feats = feature_embedding(comp_attribs, self.node_geo_feat_dim)
  384. geo_feats = paddle.to_tensor(geo_feats)
  385. batch_id = np.zeros((comp_attribs.shape[0], 1), dtype=np.float32)
  386. comp_attribs = comp_attribs.astype(np.float32)
  387. angle = np.arccos(comp_attribs[:, -2]) * np.sign(comp_attribs[:, -1])
  388. angle = angle.reshape((-1, 1))
  389. rotated_rois = np.hstack([batch_id, comp_attribs[:, :-2], angle])
  390. rois = paddle.to_tensor(rotated_rois)
  391. content_feats = self.pooling(feat_maps, rois)
  392. content_feats = content_feats.reshape([content_feats.shape[0], -1])
  393. node_feats = paddle.concat([content_feats, geo_feats], axis=-1)
  394. sorted_dist_inds = np.argsort(distance_matrix, axis=1)
  395. (
  396. local_graphs_node_feat,
  397. adjacent_matrices,
  398. pivots_knn_inds,
  399. pivots_local_graphs,
  400. ) = self.generate_local_graphs(sorted_dist_inds, node_feats)
  401. none_flag = False
  402. return none_flag, (
  403. local_graphs_node_feat,
  404. adjacent_matrices,
  405. pivots_knn_inds,
  406. pivots_local_graphs,
  407. text_comps,
  408. )