table_postprocess.py 6.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191
  1. # Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. import numpy as np
  15. import paddle
  16. from .rec_postprocess import AttnLabelDecode
  17. class TableLabelDecode(AttnLabelDecode):
  18. """ """
  19. def __init__(self, character_dict_path, merge_no_span_structure=False, **kwargs):
  20. dict_character = []
  21. with open(character_dict_path, "rb") as fin:
  22. lines = fin.readlines()
  23. for line in lines:
  24. line = line.decode("utf-8").strip("\n").strip("\r\n")
  25. dict_character.append(line)
  26. if merge_no_span_structure:
  27. if "<td></td>" not in dict_character:
  28. dict_character.append("<td></td>")
  29. if "<td>" in dict_character:
  30. dict_character.remove("<td>")
  31. dict_character = self.add_special_char(dict_character)
  32. self.dict = {}
  33. for i, char in enumerate(dict_character):
  34. self.dict[char] = i
  35. self.character = dict_character
  36. self.td_token = ["<td>", "<td", "<td></td>"]
  37. def __call__(self, preds, batch=None):
  38. structure_probs = preds["structure_probs"]
  39. bbox_preds = preds["loc_preds"]
  40. if isinstance(structure_probs, paddle.Tensor):
  41. structure_probs = structure_probs.numpy()
  42. if isinstance(bbox_preds, paddle.Tensor):
  43. bbox_preds = bbox_preds.numpy()
  44. shape_list = batch[-1]
  45. result = self.decode(structure_probs, bbox_preds, shape_list)
  46. if len(batch) == 1: # only contains shape
  47. return result
  48. label_decode_result = self.decode_label(batch)
  49. return result, label_decode_result
  50. def decode(self, structure_probs, bbox_preds, shape_list):
  51. """convert text-label into text-index."""
  52. ignored_tokens = self.get_ignored_tokens()
  53. end_idx = self.dict[self.end_str]
  54. structure_idx = structure_probs.argmax(axis=2)
  55. structure_probs = structure_probs.max(axis=2)
  56. structure_batch_list = []
  57. bbox_batch_list = []
  58. batch_size = len(structure_idx)
  59. for batch_idx in range(batch_size):
  60. structure_list = []
  61. bbox_list = []
  62. score_list = []
  63. for idx in range(len(structure_idx[batch_idx])):
  64. char_idx = int(structure_idx[batch_idx][idx])
  65. if idx > 0 and char_idx == end_idx:
  66. break
  67. if char_idx in ignored_tokens:
  68. continue
  69. text = self.character[char_idx]
  70. if text in self.td_token:
  71. bbox = bbox_preds[batch_idx, idx]
  72. bbox = self._bbox_decode(bbox, shape_list[batch_idx])
  73. bbox_list.append(bbox)
  74. structure_list.append(text)
  75. score_list.append(structure_probs[batch_idx, idx])
  76. structure_batch_list.append([structure_list, np.mean(score_list)])
  77. bbox_batch_list.append(np.array(bbox_list))
  78. result = {
  79. "bbox_batch_list": bbox_batch_list,
  80. "structure_batch_list": structure_batch_list,
  81. }
  82. return result
  83. def decode_label(self, batch):
  84. """convert text-label into text-index."""
  85. structure_idx = batch[1]
  86. gt_bbox_list = batch[2]
  87. shape_list = batch[-1]
  88. ignored_tokens = self.get_ignored_tokens()
  89. end_idx = self.dict[self.end_str]
  90. structure_batch_list = []
  91. bbox_batch_list = []
  92. batch_size = len(structure_idx)
  93. for batch_idx in range(batch_size):
  94. structure_list = []
  95. bbox_list = []
  96. for idx in range(len(structure_idx[batch_idx])):
  97. char_idx = int(structure_idx[batch_idx][idx])
  98. if idx > 0 and char_idx == end_idx:
  99. break
  100. if char_idx in ignored_tokens:
  101. continue
  102. structure_list.append(self.character[char_idx])
  103. bbox = gt_bbox_list[batch_idx][idx]
  104. if bbox.sum() != 0:
  105. bbox = self._bbox_decode(bbox, shape_list[batch_idx])
  106. bbox_list.append(bbox)
  107. structure_batch_list.append(structure_list)
  108. bbox_batch_list.append(bbox_list)
  109. result = {
  110. "bbox_batch_list": bbox_batch_list,
  111. "structure_batch_list": structure_batch_list,
  112. }
  113. return result
  114. def _bbox_decode(self, bbox, shape):
  115. h, w, ratio_h, ratio_w, pad_h, pad_w = shape
  116. h, w = pad_h, pad_w
  117. bbox[0::2] *= w
  118. bbox[1::2] *= h
  119. bbox[0::2] /= ratio_w
  120. bbox[1::2] /= ratio_h
  121. return bbox
  122. class TableMasterLabelDecode(TableLabelDecode):
  123. """ """
  124. def __init__(
  125. self,
  126. character_dict_path,
  127. box_shape="ori",
  128. merge_no_span_structure=True,
  129. **kwargs,
  130. ):
  131. super(TableMasterLabelDecode, self).__init__(
  132. character_dict_path, merge_no_span_structure
  133. )
  134. self.box_shape = box_shape
  135. assert box_shape in [
  136. "ori",
  137. "pad",
  138. ], "The shape used for box normalization must be ori or pad"
  139. def add_special_char(self, dict_character):
  140. self.beg_str = "<SOS>"
  141. self.end_str = "<EOS>"
  142. self.unknown_str = "<UKN>"
  143. self.pad_str = "<PAD>"
  144. dict_character = dict_character
  145. dict_character = dict_character + [
  146. self.unknown_str,
  147. self.beg_str,
  148. self.end_str,
  149. self.pad_str,
  150. ]
  151. return dict_character
  152. def get_ignored_tokens(self):
  153. pad_idx = self.dict[self.pad_str]
  154. start_idx = self.dict[self.beg_str]
  155. end_idx = self.dict[self.end_str]
  156. unknown_idx = self.dict[self.unknown_str]
  157. return [start_idx, end_idx, pad_idx, unknown_idx]
  158. def _bbox_decode(self, bbox, shape):
  159. h, w, ratio_h, ratio_w, pad_h, pad_w = shape
  160. if self.box_shape == "pad":
  161. h, w = pad_h, pad_w
  162. bbox[0::2] *= w
  163. bbox[1::2] *= h
  164. bbox[0::2] /= ratio_w
  165. bbox[1::2] /= ratio_h
  166. x, y, w, h = bbox
  167. x1, y1, x2, y2 = x - w // 2, y - h // 2, x + w // 2, y + h // 2
  168. bbox = np.array([x1, y1, x2, y2])
  169. return bbox