# det_drrg_head.py
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
This code is adapted from:
https://github.com/open-mmlab/mmocr/blob/main/mmocr/models/textdet/dense_heads/drrg_head.py
"""
  18. from __future__ import absolute_import
  19. from __future__ import division
  20. from __future__ import print_function
  21. import warnings
  22. import cv2
  23. import numpy as np
  24. import paddle
  25. import paddle.nn as nn
  26. import paddle.nn.functional as F
  27. from .gcn import GCN
  28. from .local_graph import LocalGraphs
  29. from .proposal_local_graph import ProposalLocalGraphs
  30. class DRRGHead(nn.Layer):
  31. def __init__(
  32. self,
  33. in_channels,
  34. k_at_hops=(8, 4),
  35. num_adjacent_linkages=3,
  36. node_geo_feat_len=120,
  37. pooling_scale=1.0,
  38. pooling_output_size=(4, 3),
  39. nms_thr=0.3,
  40. min_width=8.0,
  41. max_width=24.0,
  42. comp_shrink_ratio=1.03,
  43. comp_ratio=0.4,
  44. comp_score_thr=0.3,
  45. text_region_thr=0.2,
  46. center_region_thr=0.2,
  47. center_region_area_thr=50,
  48. local_graph_thr=0.7,
  49. **kwargs,
  50. ):
  51. super().__init__()
  52. assert isinstance(in_channels, int)
  53. assert isinstance(k_at_hops, tuple)
  54. assert isinstance(num_adjacent_linkages, int)
  55. assert isinstance(node_geo_feat_len, int)
  56. assert isinstance(pooling_scale, float)
  57. assert isinstance(pooling_output_size, tuple)
  58. assert isinstance(comp_shrink_ratio, float)
  59. assert isinstance(nms_thr, float)
  60. assert isinstance(min_width, float)
  61. assert isinstance(max_width, float)
  62. assert isinstance(comp_ratio, float)
  63. assert isinstance(comp_score_thr, float)
  64. assert isinstance(text_region_thr, float)
  65. assert isinstance(center_region_thr, float)
  66. assert isinstance(center_region_area_thr, int)
  67. assert isinstance(local_graph_thr, float)
  68. self.in_channels = in_channels
  69. self.out_channels = 6
  70. self.downsample_ratio = 1.0
  71. self.k_at_hops = k_at_hops
  72. self.num_adjacent_linkages = num_adjacent_linkages
  73. self.node_geo_feat_len = node_geo_feat_len
  74. self.pooling_scale = pooling_scale
  75. self.pooling_output_size = pooling_output_size
  76. self.comp_shrink_ratio = comp_shrink_ratio
  77. self.nms_thr = nms_thr
  78. self.min_width = min_width
  79. self.max_width = max_width
  80. self.comp_ratio = comp_ratio
  81. self.comp_score_thr = comp_score_thr
  82. self.text_region_thr = text_region_thr
  83. self.center_region_thr = center_region_thr
  84. self.center_region_area_thr = center_region_area_thr
  85. self.local_graph_thr = local_graph_thr
  86. self.out_conv = nn.Conv2D(
  87. in_channels=self.in_channels,
  88. out_channels=self.out_channels,
  89. kernel_size=1,
  90. stride=1,
  91. padding=0,
  92. )
  93. self.graph_train = LocalGraphs(
  94. self.k_at_hops,
  95. self.num_adjacent_linkages,
  96. self.node_geo_feat_len,
  97. self.pooling_scale,
  98. self.pooling_output_size,
  99. self.local_graph_thr,
  100. )
  101. self.graph_test = ProposalLocalGraphs(
  102. self.k_at_hops,
  103. self.num_adjacent_linkages,
  104. self.node_geo_feat_len,
  105. self.pooling_scale,
  106. self.pooling_output_size,
  107. self.nms_thr,
  108. self.min_width,
  109. self.max_width,
  110. self.comp_shrink_ratio,
  111. self.comp_ratio,
  112. self.comp_score_thr,
  113. self.text_region_thr,
  114. self.center_region_thr,
  115. self.center_region_area_thr,
  116. )
  117. pool_w, pool_h = self.pooling_output_size
  118. node_feat_len = (pool_w * pool_h) * (
  119. self.in_channels + self.out_channels
  120. ) + self.node_geo_feat_len
  121. self.gcn = GCN(node_feat_len)
  122. def forward(self, inputs, targets=None):
  123. """
  124. Args:
  125. inputs (Tensor): Shape of :math:`(N, C, H, W)`.
  126. gt_comp_attribs (list[ndarray]): The padded text component
  127. attributes. Shape: (num_component, 8).
  128. Returns:
  129. tuple: Returns (pred_maps, (gcn_pred, gt_labels)).
  130. - | pred_maps (Tensor): Prediction map with shape
  131. :math:`(N, C_{out}, H, W)`.
  132. - | gcn_pred (Tensor): Prediction from GCN module, with
  133. shape :math:`(N, 2)`.
  134. - | gt_labels (Tensor): Ground-truth label with shape
  135. :math:`(N, 8)`.
  136. """
  137. if self.training:
  138. assert targets is not None
  139. gt_comp_attribs = targets[7]
  140. pred_maps = self.out_conv(inputs)
  141. feat_maps = paddle.concat([inputs, pred_maps], axis=1)
  142. node_feats, adjacent_matrices, knn_inds, gt_labels = self.graph_train(
  143. feat_maps, np.stack(gt_comp_attribs)
  144. )
  145. gcn_pred = self.gcn(node_feats, adjacent_matrices, knn_inds)
  146. return pred_maps, (gcn_pred, gt_labels)
  147. else:
  148. return self.single_test(inputs)
  149. def single_test(self, feat_maps):
  150. r"""
  151. Args:
  152. feat_maps (Tensor): Shape of :math:`(N, C, H, W)`.
  153. Returns:
  154. tuple: Returns (edge, score, text_comps).
  155. - | edge (ndarray): The edge array of shape :math:`(N, 2)`
  156. where each row is a pair of text component indices
  157. that makes up an edge in graph.
  158. - | score (ndarray): The score array of shape :math:`(N,)`,
  159. corresponding to the edge above.
  160. - | text_comps (ndarray): The text components of shape
  161. :math:`(N, 9)` where each row corresponds to one box and
  162. its score: (x1, y1, x2, y2, x3, y3, x4, y4, score).
  163. """
  164. pred_maps = self.out_conv(feat_maps)
  165. feat_maps = paddle.concat([feat_maps, pred_maps], axis=1)
  166. none_flag, graph_data = self.graph_test(pred_maps, feat_maps)
  167. (
  168. local_graphs_node_feat,
  169. adjacent_matrices,
  170. pivots_knn_inds,
  171. pivot_local_graphs,
  172. text_comps,
  173. ) = graph_data
  174. if none_flag:
  175. return None, None, None
  176. gcn_pred = self.gcn(local_graphs_node_feat, adjacent_matrices, pivots_knn_inds)
  177. pred_labels = F.softmax(gcn_pred, axis=1)
  178. edges = []
  179. scores = []
  180. pivot_local_graphs = pivot_local_graphs.squeeze().numpy()
  181. for pivot_ind, pivot_local_graph in enumerate(pivot_local_graphs):
  182. pivot = pivot_local_graph[0]
  183. for k_ind, neighbor_ind in enumerate(pivots_knn_inds[pivot_ind]):
  184. neighbor = pivot_local_graph[neighbor_ind.item()]
  185. edges.append([pivot, neighbor])
  186. scores.append(
  187. pred_labels[pivot_ind * pivots_knn_inds.shape[1] + k_ind, 1].item()
  188. )
  189. edges = np.asarray(edges)
  190. scores = np.asarray(scores)
  191. return edges, scores, text_comps