local_graph.py 16 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425
  1. # Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. """
  15. This code is adapted from:
  16. https://github.com/open-mmlab/mmocr/blob/main/mmocr/models/textdet/modules/local_graph.py
  17. """
  18. from __future__ import absolute_import
  19. from __future__ import division
  20. from __future__ import print_function
  21. import numpy as np
  22. import paddle
  23. import paddle.nn as nn
  24. from ppocr.ext_op import RoIAlignRotated
  25. def normalize_adjacent_matrix(A):
  26. assert A.ndim == 2
  27. assert A.shape[0] == A.shape[1]
  28. A = A + np.eye(A.shape[0])
  29. d = np.sum(A, axis=0)
  30. d = np.clip(d, 0, None)
  31. d_inv = np.power(d, -0.5).flatten()
  32. d_inv[np.isinf(d_inv)] = 0.0
  33. d_inv = np.diag(d_inv)
  34. G = A.dot(d_inv).transpose().dot(d_inv)
  35. return G
  36. def euclidean_distance_matrix(A, B):
  37. """Calculate the Euclidean distance matrix.
  38. Args:
  39. A (ndarray): The point sequence.
  40. B (ndarray): The point sequence with the same dimensions as A.
  41. returns:
  42. D (ndarray): The Euclidean distance matrix.
  43. """
  44. assert A.ndim == 2
  45. assert B.ndim == 2
  46. assert A.shape[1] == B.shape[1]
  47. m = A.shape[0]
  48. n = B.shape[0]
  49. A_dots = (A * A).sum(axis=1).reshape((m, 1)) * np.ones(shape=(1, n))
  50. B_dots = (B * B).sum(axis=1) * np.ones(shape=(m, 1))
  51. D_squared = A_dots + B_dots - 2 * A.dot(B.T)
  52. zero_mask = np.less(D_squared, 0.0)
  53. D_squared[zero_mask] = 0.0
  54. D = np.sqrt(D_squared)
  55. return D
  56. def feature_embedding(input_feats, out_feat_len):
  57. """Embed features. This code was partially adapted from
  58. https://github.com/GXYM/DRRG licensed under the MIT license.
  59. Args:
  60. input_feats (ndarray): The input features of shape (N, d), where N is
  61. the number of nodes in graph, d is the input feature vector length.
  62. out_feat_len (int): The length of output feature vector.
  63. Returns:
  64. embedded_feats (ndarray): The embedded features.
  65. """
  66. assert input_feats.ndim == 2
  67. assert isinstance(out_feat_len, int)
  68. assert out_feat_len >= input_feats.shape[1]
  69. num_nodes = input_feats.shape[0]
  70. feat_dim = input_feats.shape[1]
  71. feat_repeat_times = out_feat_len // feat_dim
  72. residue_dim = out_feat_len % feat_dim
  73. if residue_dim > 0:
  74. embed_wave = np.array(
  75. [
  76. np.power(1000, 2.0 * (j // 2) / feat_repeat_times + 1)
  77. for j in range(feat_repeat_times + 1)
  78. ]
  79. ).reshape((feat_repeat_times + 1, 1, 1))
  80. repeat_feats = np.repeat(
  81. np.expand_dims(input_feats, axis=0), feat_repeat_times, axis=0
  82. )
  83. residue_feats = np.hstack(
  84. [
  85. input_feats[:, 0:residue_dim],
  86. np.zeros((num_nodes, feat_dim - residue_dim)),
  87. ]
  88. )
  89. residue_feats = np.expand_dims(residue_feats, axis=0)
  90. repeat_feats = np.concatenate([repeat_feats, residue_feats], axis=0)
  91. embedded_feats = repeat_feats / embed_wave
  92. embedded_feats[:, 0::2] = np.sin(embedded_feats[:, 0::2])
  93. embedded_feats[:, 1::2] = np.cos(embedded_feats[:, 1::2])
  94. embedded_feats = np.transpose(embedded_feats, (1, 0, 2)).reshape(
  95. (num_nodes, -1)
  96. )[:, 0:out_feat_len]
  97. else:
  98. embed_wave = np.array(
  99. [
  100. np.power(1000, 2.0 * (j // 2) / feat_repeat_times)
  101. for j in range(feat_repeat_times)
  102. ]
  103. ).reshape((feat_repeat_times, 1, 1))
  104. repeat_feats = np.repeat(
  105. np.expand_dims(input_feats, axis=0), feat_repeat_times, axis=0
  106. )
  107. embedded_feats = repeat_feats / embed_wave
  108. embedded_feats[:, 0::2] = np.sin(embedded_feats[:, 0::2])
  109. embedded_feats[:, 1::2] = np.cos(embedded_feats[:, 1::2])
  110. embedded_feats = (
  111. np.transpose(embedded_feats, (1, 0, 2))
  112. .reshape((num_nodes, -1))
  113. .astype(np.float32)
  114. )
  115. return embedded_feats
  116. class LocalGraphs:
  117. def __init__(
  118. self,
  119. k_at_hops,
  120. num_adjacent_linkages,
  121. node_geo_feat_len,
  122. pooling_scale,
  123. pooling_output_size,
  124. local_graph_thr,
  125. ):
  126. assert len(k_at_hops) == 2
  127. assert all(isinstance(n, int) for n in k_at_hops)
  128. assert isinstance(num_adjacent_linkages, int)
  129. assert isinstance(node_geo_feat_len, int)
  130. assert isinstance(pooling_scale, float)
  131. assert all(isinstance(n, int) for n in pooling_output_size)
  132. assert isinstance(local_graph_thr, float)
  133. self.k_at_hops = k_at_hops
  134. self.num_adjacent_linkages = num_adjacent_linkages
  135. self.node_geo_feat_dim = node_geo_feat_len
  136. self.pooling = RoIAlignRotated(pooling_output_size, pooling_scale)
  137. self.local_graph_thr = local_graph_thr
  138. def generate_local_graphs(self, sorted_dist_inds, gt_comp_labels):
  139. """Generate local graphs for GCN to predict which instance a text
  140. component belongs to.
  141. Args:
  142. sorted_dist_inds (ndarray): The complete graph node indices, which
  143. is sorted according to the Euclidean distance.
  144. gt_comp_labels(ndarray): The ground truth labels define the
  145. instance to which the text components (nodes in graphs) belong.
  146. Returns:
  147. pivot_local_graphs(list[list[int]]): The list of local graph
  148. neighbor indices of pivots.
  149. pivot_knns(list[list[int]]): The list of k-nearest neighbor indices
  150. of pivots.
  151. """
  152. assert sorted_dist_inds.ndim == 2
  153. assert (
  154. sorted_dist_inds.shape[0]
  155. == sorted_dist_inds.shape[1]
  156. == gt_comp_labels.shape[0]
  157. )
  158. knn_graph = sorted_dist_inds[:, 1 : self.k_at_hops[0] + 1]
  159. pivot_local_graphs = []
  160. pivot_knns = []
  161. for pivot_ind, knn in enumerate(knn_graph):
  162. local_graph_neighbors = set(knn)
  163. for neighbor_ind in knn:
  164. local_graph_neighbors.update(
  165. set(sorted_dist_inds[neighbor_ind, 1 : self.k_at_hops[1] + 1])
  166. )
  167. local_graph_neighbors.discard(pivot_ind)
  168. pivot_local_graph = list(local_graph_neighbors)
  169. pivot_local_graph.insert(0, pivot_ind)
  170. pivot_knn = [pivot_ind] + list(knn)
  171. if pivot_ind < 1:
  172. pivot_local_graphs.append(pivot_local_graph)
  173. pivot_knns.append(pivot_knn)
  174. else:
  175. add_flag = True
  176. for graph_ind, added_knn in enumerate(pivot_knns):
  177. added_pivot_ind = added_knn[0]
  178. added_local_graph = pivot_local_graphs[graph_ind]
  179. union = len(
  180. set(pivot_local_graph[1:]).union(set(added_local_graph[1:]))
  181. )
  182. intersect = len(
  183. set(pivot_local_graph[1:]).intersection(
  184. set(added_local_graph[1:])
  185. )
  186. )
  187. local_graph_iou = intersect / (union + 1e-8)
  188. if (
  189. local_graph_iou > self.local_graph_thr
  190. and pivot_ind in added_knn
  191. and gt_comp_labels[added_pivot_ind] == gt_comp_labels[pivot_ind]
  192. and gt_comp_labels[pivot_ind] != 0
  193. ):
  194. add_flag = False
  195. break
  196. if add_flag:
  197. pivot_local_graphs.append(pivot_local_graph)
  198. pivot_knns.append(pivot_knn)
  199. return pivot_local_graphs, pivot_knns
  200. def generate_gcn_input(
  201. self,
  202. node_feat_batch,
  203. node_label_batch,
  204. local_graph_batch,
  205. knn_batch,
  206. sorted_dist_ind_batch,
  207. ):
  208. """Generate graph convolution network input data.
  209. Args:
  210. node_feat_batch (List[Tensor]): The batched graph node features.
  211. node_label_batch (List[ndarray]): The batched text component
  212. labels.
  213. local_graph_batch (List[List[list[int]]]): The local graph node
  214. indices of image batch.
  215. knn_batch (List[List[list[int]]]): The knn graph node indices of
  216. image batch.
  217. sorted_dist_ind_batch (list[ndarray]): The node indices sorted
  218. according to the Euclidean distance.
  219. Returns:
  220. local_graphs_node_feat (Tensor): The node features of graph.
  221. adjacent_matrices (Tensor): The adjacent matrices of local graphs.
  222. pivots_knn_inds (Tensor): The k-nearest neighbor indices in
  223. local graph.
  224. gt_linkage (Tensor): The surpervision signal of GCN for linkage
  225. prediction.
  226. """
  227. assert isinstance(node_feat_batch, list)
  228. assert isinstance(node_label_batch, list)
  229. assert isinstance(local_graph_batch, list)
  230. assert isinstance(knn_batch, list)
  231. assert isinstance(sorted_dist_ind_batch, list)
  232. num_max_nodes = max(
  233. [
  234. len(pivot_local_graph)
  235. for pivot_local_graphs in local_graph_batch
  236. for pivot_local_graph in pivot_local_graphs
  237. ]
  238. )
  239. local_graphs_node_feat = []
  240. adjacent_matrices = []
  241. pivots_knn_inds = []
  242. pivots_gt_linkage = []
  243. for batch_ind, sorted_dist_inds in enumerate(sorted_dist_ind_batch):
  244. node_feats = node_feat_batch[batch_ind]
  245. pivot_local_graphs = local_graph_batch[batch_ind]
  246. pivot_knns = knn_batch[batch_ind]
  247. node_labels = node_label_batch[batch_ind]
  248. for graph_ind, pivot_knn in enumerate(pivot_knns):
  249. pivot_local_graph = pivot_local_graphs[graph_ind]
  250. num_nodes = len(pivot_local_graph)
  251. pivot_ind = pivot_local_graph[0]
  252. node2ind_map = {j: i for i, j in enumerate(pivot_local_graph)}
  253. knn_inds = paddle.to_tensor([node2ind_map[i] for i in pivot_knn[1:]])
  254. pivot_feats = node_feats[pivot_ind]
  255. normalized_feats = (
  256. node_feats[paddle.to_tensor(pivot_local_graph)] - pivot_feats
  257. )
  258. adjacent_matrix = np.zeros((num_nodes, num_nodes), dtype=np.float32)
  259. for node in pivot_local_graph:
  260. neighbors = sorted_dist_inds[
  261. node, 1 : self.num_adjacent_linkages + 1
  262. ]
  263. for neighbor in neighbors:
  264. if neighbor in pivot_local_graph:
  265. adjacent_matrix[
  266. node2ind_map[node], node2ind_map[neighbor]
  267. ] = 1
  268. adjacent_matrix[
  269. node2ind_map[neighbor], node2ind_map[node]
  270. ] = 1
  271. adjacent_matrix = normalize_adjacent_matrix(adjacent_matrix)
  272. pad_adjacent_matrix = paddle.zeros((num_max_nodes, num_max_nodes))
  273. pad_adjacent_matrix[:num_nodes, :num_nodes] = paddle.cast(
  274. paddle.to_tensor(adjacent_matrix), "float32"
  275. )
  276. pad_normalized_feats = paddle.concat(
  277. [
  278. normalized_feats,
  279. paddle.zeros(
  280. (num_max_nodes - num_nodes, normalized_feats.shape[1])
  281. ),
  282. ],
  283. axis=0,
  284. )
  285. local_graph_labels = node_labels[pivot_local_graph]
  286. knn_labels = local_graph_labels[knn_inds.numpy()]
  287. link_labels = (
  288. (node_labels[pivot_ind] == knn_labels)
  289. & (node_labels[pivot_ind] > 0)
  290. ).astype(np.int64)
  291. link_labels = paddle.to_tensor(link_labels)
  292. local_graphs_node_feat.append(pad_normalized_feats)
  293. adjacent_matrices.append(pad_adjacent_matrix)
  294. pivots_knn_inds.append(knn_inds)
  295. pivots_gt_linkage.append(link_labels)
  296. local_graphs_node_feat = paddle.stack(local_graphs_node_feat, 0)
  297. adjacent_matrices = paddle.stack(adjacent_matrices, 0)
  298. pivots_knn_inds = paddle.stack(pivots_knn_inds, 0)
  299. pivots_gt_linkage = paddle.stack(pivots_gt_linkage, 0)
  300. return (
  301. local_graphs_node_feat,
  302. adjacent_matrices,
  303. pivots_knn_inds,
  304. pivots_gt_linkage,
  305. )
  306. def __call__(self, feat_maps, comp_attribs):
  307. """Generate local graphs as GCN input.
  308. Args:
  309. feat_maps (Tensor): The feature maps to extract the content
  310. features of text components.
  311. comp_attribs (ndarray): The text component attributes.
  312. Returns:
  313. local_graphs_node_feat (Tensor): The node features of graph.
  314. adjacent_matrices (Tensor): The adjacent matrices of local graphs.
  315. pivots_knn_inds (Tensor): The k-nearest neighbor indices in local
  316. graph.
  317. gt_linkage (Tensor): The surpervision signal of GCN for linkage
  318. prediction.
  319. """
  320. assert isinstance(feat_maps, paddle.Tensor)
  321. assert comp_attribs.ndim == 3
  322. assert comp_attribs.shape[2] == 8
  323. sorted_dist_inds_batch = []
  324. local_graph_batch = []
  325. knn_batch = []
  326. node_feat_batch = []
  327. node_label_batch = []
  328. for batch_ind in range(comp_attribs.shape[0]):
  329. num_comps = int(comp_attribs[batch_ind, 0, 0])
  330. comp_geo_attribs = comp_attribs[batch_ind, :num_comps, 1:7]
  331. node_labels = comp_attribs[batch_ind, :num_comps, 7].astype(np.int32)
  332. comp_centers = comp_geo_attribs[:, 0:2]
  333. distance_matrix = euclidean_distance_matrix(comp_centers, comp_centers)
  334. batch_id = (
  335. np.zeros((comp_geo_attribs.shape[0], 1), dtype=np.float32) * batch_ind
  336. )
  337. comp_geo_attribs[:, -2] = np.clip(comp_geo_attribs[:, -2], -1, 1)
  338. angle = np.arccos(comp_geo_attribs[:, -2]) * np.sign(
  339. comp_geo_attribs[:, -1]
  340. )
  341. angle = angle.reshape((-1, 1))
  342. rotated_rois = np.hstack([batch_id, comp_geo_attribs[:, :-2], angle])
  343. rois = paddle.to_tensor(rotated_rois)
  344. content_feats = self.pooling(feat_maps[batch_ind].unsqueeze(0), rois)
  345. content_feats = content_feats.reshape([content_feats.shape[0], -1])
  346. geo_feats = feature_embedding(comp_geo_attribs, self.node_geo_feat_dim)
  347. geo_feats = paddle.to_tensor(geo_feats)
  348. node_feats = paddle.concat([content_feats, geo_feats], axis=-1)
  349. sorted_dist_inds = np.argsort(distance_matrix, axis=1)
  350. pivot_local_graphs, pivot_knns = self.generate_local_graphs(
  351. sorted_dist_inds, node_labels
  352. )
  353. node_feat_batch.append(node_feats)
  354. node_label_batch.append(node_labels)
  355. local_graph_batch.append(pivot_local_graphs)
  356. knn_batch.append(pivot_knns)
  357. sorted_dist_inds_batch.append(sorted_dist_inds)
  358. (node_feats, adjacent_matrices, knn_inds, gt_linkage) = self.generate_gcn_input(
  359. node_feat_batch,
  360. node_label_batch,
  361. local_graph_batch,
  362. knn_batch,
  363. sorted_dist_inds_batch,
  364. )
  365. return node_feats, adjacent_matrices, knn_inds, gt_linkage