pgnet_pp_utils.py 6.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179
  1. # Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. from __future__ import absolute_import
  15. from __future__ import division
  16. from __future__ import print_function
  17. import paddle
  18. import os
  19. import sys
  20. __dir__ = os.path.dirname(__file__)
  21. sys.path.append(__dir__)
  22. sys.path.append(os.path.join(__dir__, ".."))
  23. from extract_textpoint_slow import *
  24. from extract_textpoint_fast import generate_pivot_list_fast, restore_poly
  25. class PGNet_PostProcess(object):
  26. # two different post-process
  27. def __init__(
  28. self,
  29. character_dict_path,
  30. valid_set,
  31. score_thresh,
  32. outs_dict,
  33. shape_list,
  34. point_gather_mode=None,
  35. ):
  36. self.Lexicon_Table = get_dict(character_dict_path)
  37. self.valid_set = valid_set
  38. self.score_thresh = score_thresh
  39. self.outs_dict = outs_dict
  40. self.shape_list = shape_list
  41. self.point_gather_mode = point_gather_mode
  42. def pg_postprocess_fast(self):
  43. p_score = self.outs_dict["f_score"]
  44. p_border = self.outs_dict["f_border"]
  45. p_char = self.outs_dict["f_char"]
  46. p_direction = self.outs_dict["f_direction"]
  47. if isinstance(p_score, paddle.Tensor):
  48. p_score = p_score[0].numpy()
  49. p_border = p_border[0].numpy()
  50. p_direction = p_direction[0].numpy()
  51. p_char = p_char[0].numpy()
  52. else:
  53. p_score = p_score[0]
  54. p_border = p_border[0]
  55. p_direction = p_direction[0]
  56. p_char = p_char[0]
  57. src_h, src_w, ratio_h, ratio_w = self.shape_list[0]
  58. instance_yxs_list, seq_strs = generate_pivot_list_fast(
  59. p_score,
  60. p_char,
  61. p_direction,
  62. self.Lexicon_Table,
  63. score_thresh=self.score_thresh,
  64. point_gather_mode=self.point_gather_mode,
  65. )
  66. poly_list, keep_str_list = restore_poly(
  67. instance_yxs_list,
  68. seq_strs,
  69. p_border,
  70. ratio_w,
  71. ratio_h,
  72. src_w,
  73. src_h,
  74. self.valid_set,
  75. )
  76. data = {
  77. "points": poly_list,
  78. "texts": keep_str_list,
  79. }
  80. return data
  81. def pg_postprocess_slow(self):
  82. p_score = self.outs_dict["f_score"]
  83. p_border = self.outs_dict["f_border"]
  84. p_char = self.outs_dict["f_char"]
  85. p_direction = self.outs_dict["f_direction"]
  86. if isinstance(p_score, paddle.Tensor):
  87. p_score = p_score[0].numpy()
  88. p_border = p_border[0].numpy()
  89. p_direction = p_direction[0].numpy()
  90. p_char = p_char[0].numpy()
  91. else:
  92. p_score = p_score[0]
  93. p_border = p_border[0]
  94. p_direction = p_direction[0]
  95. p_char = p_char[0]
  96. src_h, src_w, ratio_h, ratio_w = self.shape_list[0]
  97. is_curved = self.valid_set == "totaltext"
  98. char_seq_idx_set, instance_yxs_list = generate_pivot_list_slow(
  99. p_score,
  100. p_char,
  101. p_direction,
  102. score_thresh=self.score_thresh,
  103. is_backbone=True,
  104. is_curved=is_curved,
  105. )
  106. seq_strs = []
  107. for char_idx_set in char_seq_idx_set:
  108. pr_str = "".join([self.Lexicon_Table[pos] for pos in char_idx_set])
  109. seq_strs.append(pr_str)
  110. poly_list = []
  111. keep_str_list = []
  112. all_point_list = []
  113. all_point_pair_list = []
  114. for yx_center_line, keep_str in zip(instance_yxs_list, seq_strs):
  115. if len(yx_center_line) == 1:
  116. yx_center_line.append(yx_center_line[-1])
  117. offset_expand = 1.0
  118. if self.valid_set == "totaltext":
  119. offset_expand = 1.2
  120. point_pair_list = []
  121. for batch_id, y, x in yx_center_line:
  122. offset = p_border[:, y, x].reshape(2, 2)
  123. if offset_expand != 1.0:
  124. offset_length = np.linalg.norm(offset, axis=1, keepdims=True)
  125. expand_length = np.clip(
  126. offset_length * (offset_expand - 1), a_min=0.5, a_max=3.0
  127. )
  128. offset_detal = offset / offset_length * expand_length
  129. offset = offset + offset_detal
  130. ori_yx = np.array([y, x], dtype=np.float32)
  131. point_pair = (
  132. (ori_yx + offset)[:, ::-1]
  133. * 4.0
  134. / np.array([ratio_w, ratio_h]).reshape(-1, 2)
  135. )
  136. point_pair_list.append(point_pair)
  137. all_point_list.append(
  138. [int(round(x * 4.0 / ratio_w)), int(round(y * 4.0 / ratio_h))]
  139. )
  140. all_point_pair_list.append(point_pair.round().astype(np.int32).tolist())
  141. detected_poly, pair_length_info = point_pair2poly(point_pair_list)
  142. detected_poly = expand_poly_along_width(
  143. detected_poly, shrink_ratio_of_width=0.2
  144. )
  145. detected_poly[:, 0] = np.clip(detected_poly[:, 0], a_min=0, a_max=src_w)
  146. detected_poly[:, 1] = np.clip(detected_poly[:, 1], a_min=0, a_max=src_h)
  147. if len(keep_str) < 2:
  148. continue
  149. keep_str_list.append(keep_str)
  150. detected_poly = np.round(detected_poly).astype("int32")
  151. if self.valid_set == "partvgg":
  152. middle_point = len(detected_poly) // 2
  153. detected_poly = detected_poly[
  154. [0, middle_point - 1, middle_point, -1], :
  155. ]
  156. poly_list.append(detected_poly)
  157. elif self.valid_set == "totaltext":
  158. poly_list.append(detected_poly)
  159. else:
  160. print("--> Not supported format.")
  161. exit(-1)
  162. data = {
  163. "points": poly_list,
  164. "texts": keep_str_list,
  165. }
  166. return data