table_ops.py 8.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232
  1. """
  2. # Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved
  3. #
  4. # Licensed under the Apache License, Version 2.0 (the "License");
  5. # you may not use this file except in compliance with the License.
  6. # You may obtain a copy of the License at
  7. #
  8. # http://www.apache.org/licenses/LICENSE-2.0
  9. #
  10. # Unless required by applicable law or agreed to in writing, software
  11. # distributed under the License is distributed on an "AS IS" BASIS,
  12. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. # See the License for the specific language governing permissions and
  14. # limitations under the License.
  15. """
  16. from __future__ import absolute_import
  17. from __future__ import division
  18. from __future__ import print_function
  19. from __future__ import unicode_literals
  20. import sys
  21. import cv2
  22. import numpy as np
  23. class GenTableMask(object):
  24. """gen table mask"""
  25. def __init__(self, shrink_h_max, shrink_w_max, mask_type=0, **kwargs):
  26. self.shrink_h_max = 5
  27. self.shrink_w_max = 5
  28. self.mask_type = mask_type
  29. def projection(self, erosion, h, w, spilt_threshold=0):
  30. # 水平投影
  31. projection_map = np.ones_like(erosion)
  32. project_val_array = [0 for _ in range(0, h)]
  33. for j in range(0, h):
  34. for i in range(0, w):
  35. if erosion[j, i] == 255:
  36. project_val_array[j] += 1
  37. # 根据数组,获取切割点
  38. start_idx = 0 # 记录进入字符区的索引
  39. end_idx = 0 # 记录进入空白区域的索引
  40. in_text = False # 是否遍历到了字符区内
  41. box_list = []
  42. for i in range(len(project_val_array)):
  43. if (
  44. in_text == False and project_val_array[i] > spilt_threshold
  45. ): # 进入字符区了
  46. in_text = True
  47. start_idx = i
  48. elif (
  49. project_val_array[i] <= spilt_threshold and in_text == True
  50. ): # 进入空白区了
  51. end_idx = i
  52. in_text = False
  53. if end_idx - start_idx <= 2:
  54. continue
  55. box_list.append((start_idx, end_idx + 1))
  56. if in_text:
  57. box_list.append((start_idx, h - 1))
  58. # 绘制投影直方图
  59. for j in range(0, h):
  60. for i in range(0, project_val_array[j]):
  61. projection_map[j, i] = 0
  62. return box_list, projection_map
  63. def projection_cx(self, box_img):
  64. box_gray_img = cv2.cvtColor(box_img, cv2.COLOR_BGR2GRAY)
  65. h, w = box_gray_img.shape
  66. # 灰度图片进行二值化处理
  67. ret, thresh1 = cv2.threshold(box_gray_img, 200, 255, cv2.THRESH_BINARY_INV)
  68. # 纵向腐蚀
  69. if h < w:
  70. kernel = np.ones((2, 1), np.uint8)
  71. erode = cv2.erode(thresh1, kernel, iterations=1)
  72. else:
  73. erode = thresh1
  74. # 水平膨胀
  75. kernel = np.ones((1, 5), np.uint8)
  76. erosion = cv2.dilate(erode, kernel, iterations=1)
  77. # 水平投影
  78. projection_map = np.ones_like(erosion)
  79. project_val_array = [0 for _ in range(0, h)]
  80. for j in range(0, h):
  81. for i in range(0, w):
  82. if erosion[j, i] == 255:
  83. project_val_array[j] += 1
  84. # 根据数组,获取切割点
  85. start_idx = 0 # 记录进入字符区的索引
  86. end_idx = 0 # 记录进入空白区域的索引
  87. in_text = False # 是否遍历到了字符区内
  88. box_list = []
  89. spilt_threshold = 0
  90. for i in range(len(project_val_array)):
  91. if (
  92. in_text == False and project_val_array[i] > spilt_threshold
  93. ): # 进入字符区了
  94. in_text = True
  95. start_idx = i
  96. elif (
  97. project_val_array[i] <= spilt_threshold and in_text == True
  98. ): # 进入空白区了
  99. end_idx = i
  100. in_text = False
  101. if end_idx - start_idx <= 2:
  102. continue
  103. box_list.append((start_idx, end_idx + 1))
  104. if in_text:
  105. box_list.append((start_idx, h - 1))
  106. # 绘制投影直方图
  107. for j in range(0, h):
  108. for i in range(0, project_val_array[j]):
  109. projection_map[j, i] = 0
  110. split_bbox_list = []
  111. if len(box_list) > 1:
  112. for i, (h_start, h_end) in enumerate(box_list):
  113. if i == 0:
  114. h_start = 0
  115. if i == len(box_list):
  116. h_end = h
  117. word_img = erosion[h_start : h_end + 1, :]
  118. word_h, word_w = word_img.shape
  119. w_split_list, w_projection_map = self.projection(
  120. word_img.T, word_w, word_h
  121. )
  122. w_start, w_end = w_split_list[0][0], w_split_list[-1][1]
  123. if h_start > 0:
  124. h_start -= 1
  125. h_end += 1
  126. word_img = box_img[h_start : h_end + 1 :, w_start : w_end + 1, :]
  127. split_bbox_list.append([w_start, h_start, w_end, h_end])
  128. else:
  129. split_bbox_list.append([0, 0, w, h])
  130. return split_bbox_list
  131. def shrink_bbox(self, bbox):
  132. left, top, right, bottom = bbox
  133. sh_h = min(max(int((bottom - top) * 0.1), 1), self.shrink_h_max)
  134. sh_w = min(max(int((right - left) * 0.1), 1), self.shrink_w_max)
  135. left_new = left + sh_w
  136. right_new = right - sh_w
  137. top_new = top + sh_h
  138. bottom_new = bottom - sh_h
  139. if left_new >= right_new:
  140. left_new = left
  141. right_new = right
  142. if top_new >= bottom_new:
  143. top_new = top
  144. bottom_new = bottom
  145. return [left_new, top_new, right_new, bottom_new]
  146. def __call__(self, data):
  147. img = data["image"]
  148. cells = data["cells"]
  149. height, width = img.shape[0:2]
  150. if self.mask_type == 1:
  151. mask_img = np.zeros((height, width), dtype=np.float32)
  152. else:
  153. mask_img = np.zeros((height, width, 3), dtype=np.float32)
  154. cell_num = len(cells)
  155. for cno in range(cell_num):
  156. if "bbox" in cells[cno]:
  157. bbox = cells[cno]["bbox"]
  158. left, top, right, bottom = bbox
  159. box_img = img[top:bottom, left:right, :].copy()
  160. split_bbox_list = self.projection_cx(box_img)
  161. for sno in range(len(split_bbox_list)):
  162. split_bbox_list[sno][0] += left
  163. split_bbox_list[sno][1] += top
  164. split_bbox_list[sno][2] += left
  165. split_bbox_list[sno][3] += top
  166. for sno in range(len(split_bbox_list)):
  167. left, top, right, bottom = split_bbox_list[sno]
  168. left, top, right, bottom = self.shrink_bbox(
  169. [left, top, right, bottom]
  170. )
  171. if self.mask_type == 1:
  172. mask_img[top:bottom, left:right] = 1.0
  173. data["mask_img"] = mask_img
  174. else:
  175. mask_img[top:bottom, left:right, :] = (255, 255, 255)
  176. data["image"] = mask_img
  177. return data
  178. class ResizeTableImage(object):
  179. def __init__(self, max_len, resize_bboxes=False, infer_mode=False, **kwargs):
  180. super(ResizeTableImage, self).__init__()
  181. self.max_len = max_len
  182. self.resize_bboxes = resize_bboxes
  183. self.infer_mode = infer_mode
  184. def __call__(self, data):
  185. img = data["image"]
  186. height, width = img.shape[0:2]
  187. ratio = self.max_len / (max(height, width) * 1.0)
  188. resize_h = int(height * ratio)
  189. resize_w = int(width * ratio)
  190. resize_img = cv2.resize(img, (resize_w, resize_h))
  191. if self.resize_bboxes and not self.infer_mode:
  192. data["bboxes"] = data["bboxes"] * ratio
  193. data["image"] = resize_img
  194. data["src_img"] = img
  195. data["shape"] = np.array([height, width, ratio, ratio])
  196. data["max_len"] = self.max_len
  197. return data
  198. class PaddingTableImage(object):
  199. def __init__(self, size, **kwargs):
  200. super(PaddingTableImage, self).__init__()
  201. self.size = size
  202. def __call__(self, data):
  203. img = data["image"]
  204. pad_h, pad_w = self.size
  205. padding_img = np.zeros((pad_h, pad_w, 3), dtype=np.float32)
  206. height, width = img.shape[0:2]
  207. padding_img[0:height, 0:width, :] = img.copy()
  208. data["image"] = padding_img
  209. shape = data["shape"].tolist()
  210. shape.extend([pad_h, pad_w])
  211. data["shape"] = np.array(shape)
  212. return data