sast_postprocess.py 13 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371
  1. # Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. from __future__ import absolute_import
  15. from __future__ import division
  16. from __future__ import print_function
  17. import os
  18. import sys
  19. __dir__ = os.path.dirname(__file__)
  20. sys.path.append(__dir__)
  21. sys.path.append(os.path.join(__dir__, ".."))
  22. import numpy as np
  23. from .locality_aware_nms import nms_locality
  24. import paddle
  25. import cv2
  26. import time
  27. class SASTPostProcess(object):
  28. """
  29. The post process for SAST.
  30. """
  31. def __init__(
  32. self,
  33. score_thresh=0.5,
  34. nms_thresh=0.2,
  35. sample_pts_num=2,
  36. shrink_ratio_of_width=0.3,
  37. expand_scale=1.0,
  38. tcl_map_thresh=0.5,
  39. **kwargs,
  40. ):
  41. self.score_thresh = score_thresh
  42. self.nms_thresh = nms_thresh
  43. self.sample_pts_num = sample_pts_num
  44. self.shrink_ratio_of_width = shrink_ratio_of_width
  45. self.expand_scale = expand_scale
  46. self.tcl_map_thresh = tcl_map_thresh
  47. # c++ la-nms is faster, but only support python 3.5
  48. self.is_python35 = False
  49. if sys.version_info.major == 3 and sys.version_info.minor == 5:
  50. self.is_python35 = True
  51. def point_pair2poly(self, point_pair_list):
  52. """
  53. Transfer vertical point_pairs into poly point in clockwise.
  54. """
  55. # construct poly
  56. point_num = len(point_pair_list) * 2
  57. point_list = [0] * point_num
  58. for idx, point_pair in enumerate(point_pair_list):
  59. point_list[idx] = point_pair[0]
  60. point_list[point_num - 1 - idx] = point_pair[1]
  61. return np.array(point_list).reshape(-1, 2)
  62. def shrink_quad_along_width(self, quad, begin_width_ratio=0.0, end_width_ratio=1.0):
  63. """
  64. Generate shrink_quad_along_width.
  65. """
  66. ratio_pair = np.array(
  67. [[begin_width_ratio], [end_width_ratio]], dtype=np.float32
  68. )
  69. p0_1 = quad[0] + (quad[1] - quad[0]) * ratio_pair
  70. p3_2 = quad[3] + (quad[2] - quad[3]) * ratio_pair
  71. return np.array([p0_1[0], p0_1[1], p3_2[1], p3_2[0]])
  72. def expand_poly_along_width(self, poly, shrink_ratio_of_width=0.3):
  73. """
  74. expand poly along width.
  75. """
  76. point_num = poly.shape[0]
  77. left_quad = np.array([poly[0], poly[1], poly[-2], poly[-1]], dtype=np.float32)
  78. left_ratio = (
  79. -shrink_ratio_of_width
  80. * np.linalg.norm(left_quad[0] - left_quad[3])
  81. / (np.linalg.norm(left_quad[0] - left_quad[1]) + 1e-6)
  82. )
  83. left_quad_expand = self.shrink_quad_along_width(left_quad, left_ratio, 1.0)
  84. right_quad = np.array(
  85. [
  86. poly[point_num // 2 - 2],
  87. poly[point_num // 2 - 1],
  88. poly[point_num // 2],
  89. poly[point_num // 2 + 1],
  90. ],
  91. dtype=np.float32,
  92. )
  93. right_ratio = 1.0 + shrink_ratio_of_width * np.linalg.norm(
  94. right_quad[0] - right_quad[3]
  95. ) / (np.linalg.norm(right_quad[0] - right_quad[1]) + 1e-6)
  96. right_quad_expand = self.shrink_quad_along_width(right_quad, 0.0, right_ratio)
  97. poly[0] = left_quad_expand[0]
  98. poly[-1] = left_quad_expand[-1]
  99. poly[point_num // 2 - 1] = right_quad_expand[1]
  100. poly[point_num // 2] = right_quad_expand[2]
  101. return poly
  102. def restore_quad(self, tcl_map, tcl_map_thresh, tvo_map):
  103. """Restore quad."""
  104. xy_text = np.argwhere(tcl_map[:, :, 0] > tcl_map_thresh)
  105. xy_text = xy_text[:, ::-1] # (n, 2)
  106. # Sort the text boxes via the y axis
  107. xy_text = xy_text[np.argsort(xy_text[:, 1])]
  108. scores = tcl_map[xy_text[:, 1], xy_text[:, 0], 0]
  109. scores = scores[:, np.newaxis]
  110. # Restore
  111. point_num = int(tvo_map.shape[-1] / 2)
  112. assert point_num == 4
  113. tvo_map = tvo_map[xy_text[:, 1], xy_text[:, 0], :]
  114. xy_text_tile = np.tile(xy_text, (1, point_num)) # (n, point_num * 2)
  115. quads = xy_text_tile - tvo_map
  116. return scores, quads, xy_text
  117. def quad_area(self, quad):
  118. """
  119. compute area of a quad.
  120. """
  121. edge = [
  122. (quad[1][0] - quad[0][0]) * (quad[1][1] + quad[0][1]),
  123. (quad[2][0] - quad[1][0]) * (quad[2][1] + quad[1][1]),
  124. (quad[3][0] - quad[2][0]) * (quad[3][1] + quad[2][1]),
  125. (quad[0][0] - quad[3][0]) * (quad[0][1] + quad[3][1]),
  126. ]
  127. return np.sum(edge) / 2.0
  128. def nms(self, dets):
  129. if self.is_python35:
  130. from ppocr.utils.utility import check_install
  131. check_install("lanms", "lanms-nova")
  132. import lanms
  133. dets = lanms.merge_quadrangle_n9(dets, self.nms_thresh)
  134. else:
  135. dets = nms_locality(dets, self.nms_thresh)
  136. return dets
  137. def cluster_by_quads_tco(self, tcl_map, tcl_map_thresh, quads, tco_map):
  138. """
  139. Cluster pixels in tcl_map based on quads.
  140. """
  141. instance_count = quads.shape[0] + 1 # contain background
  142. instance_label_map = np.zeros(tcl_map.shape[:2], dtype=np.int32)
  143. if instance_count == 1:
  144. return instance_count, instance_label_map
  145. # predict text center
  146. xy_text = np.argwhere(tcl_map[:, :, 0] > tcl_map_thresh)
  147. n = xy_text.shape[0]
  148. xy_text = xy_text[:, ::-1] # (n, 2)
  149. tco = tco_map[xy_text[:, 1], xy_text[:, 0], :] # (n, 2)
  150. pred_tc = xy_text - tco
  151. # get gt text center
  152. m = quads.shape[0]
  153. gt_tc = np.mean(quads, axis=1) # (m, 2)
  154. pred_tc_tile = np.tile(pred_tc[:, np.newaxis, :], (1, m, 1)) # (n, m, 2)
  155. gt_tc_tile = np.tile(gt_tc[np.newaxis, :, :], (n, 1, 1)) # (n, m, 2)
  156. dist_mat = np.linalg.norm(pred_tc_tile - gt_tc_tile, axis=2) # (n, m)
  157. xy_text_assign = np.argmin(dist_mat, axis=1) + 1 # (n,)
  158. instance_label_map[xy_text[:, 1], xy_text[:, 0]] = xy_text_assign
  159. return instance_count, instance_label_map
  160. def estimate_sample_pts_num(self, quad, xy_text):
  161. """
  162. Estimate sample points number.
  163. """
  164. eh = (
  165. np.linalg.norm(quad[0] - quad[3]) + np.linalg.norm(quad[1] - quad[2])
  166. ) / 2.0
  167. ew = (
  168. np.linalg.norm(quad[0] - quad[1]) + np.linalg.norm(quad[2] - quad[3])
  169. ) / 2.0
  170. dense_sample_pts_num = max(2, int(ew))
  171. dense_xy_center_line = xy_text[
  172. np.linspace(
  173. 0,
  174. xy_text.shape[0] - 1,
  175. dense_sample_pts_num,
  176. endpoint=True,
  177. dtype=np.float32,
  178. ).astype(np.int32)
  179. ]
  180. dense_xy_center_line_diff = dense_xy_center_line[1:] - dense_xy_center_line[:-1]
  181. estimate_arc_len = np.sum(np.linalg.norm(dense_xy_center_line_diff, axis=1))
  182. sample_pts_num = max(2, int(estimate_arc_len / eh))
  183. return sample_pts_num
  184. def detect_sast(
  185. self,
  186. tcl_map,
  187. tvo_map,
  188. tbo_map,
  189. tco_map,
  190. ratio_w,
  191. ratio_h,
  192. src_w,
  193. src_h,
  194. shrink_ratio_of_width=0.3,
  195. tcl_map_thresh=0.5,
  196. offset_expand=1.0,
  197. out_strid=4.0,
  198. ):
  199. """
  200. first resize the tcl_map, tvo_map and tbo_map to the input_size, then restore the polys
  201. """
  202. # restore quad
  203. scores, quads, xy_text = self.restore_quad(tcl_map, tcl_map_thresh, tvo_map)
  204. dets = np.hstack((quads, scores)).astype(np.float32, copy=False)
  205. dets = self.nms(dets)
  206. if dets.shape[0] == 0:
  207. return []
  208. quads = dets[:, :-1].reshape(-1, 4, 2)
  209. # Compute quad area
  210. quad_areas = []
  211. for quad in quads:
  212. quad_areas.append(-self.quad_area(quad))
  213. # instance segmentation
  214. # instance_count, instance_label_map = cv2.connectedComponents(tcl_map.astype(np.uint8), connectivity=8)
  215. instance_count, instance_label_map = self.cluster_by_quads_tco(
  216. tcl_map, tcl_map_thresh, quads, tco_map
  217. )
  218. # restore single poly with tcl instance.
  219. poly_list = []
  220. for instance_idx in range(1, instance_count):
  221. xy_text = np.argwhere(instance_label_map == instance_idx)[:, ::-1]
  222. quad = quads[instance_idx - 1]
  223. q_area = quad_areas[instance_idx - 1]
  224. if q_area < 5:
  225. continue
  226. #
  227. len1 = float(np.linalg.norm(quad[0] - quad[1]))
  228. len2 = float(np.linalg.norm(quad[1] - quad[2]))
  229. min_len = min(len1, len2)
  230. if min_len < 3:
  231. continue
  232. # filter small CC
  233. if xy_text.shape[0] <= 0:
  234. continue
  235. # filter low confidence instance
  236. xy_text_scores = tcl_map[xy_text[:, 1], xy_text[:, 0], 0]
  237. if np.sum(xy_text_scores) / quad_areas[instance_idx - 1] < 0.1:
  238. # if np.sum(xy_text_scores) / quad_areas[instance_idx - 1] < 0.05:
  239. continue
  240. # sort xy_text
  241. left_center_pt = np.array(
  242. [[(quad[0, 0] + quad[-1, 0]) / 2.0, (quad[0, 1] + quad[-1, 1]) / 2.0]]
  243. ) # (1, 2)
  244. right_center_pt = np.array(
  245. [[(quad[1, 0] + quad[2, 0]) / 2.0, (quad[1, 1] + quad[2, 1]) / 2.0]]
  246. ) # (1, 2)
  247. proj_unit_vec = (right_center_pt - left_center_pt) / (
  248. np.linalg.norm(right_center_pt - left_center_pt) + 1e-6
  249. )
  250. proj_value = np.sum(xy_text * proj_unit_vec, axis=1)
  251. xy_text = xy_text[np.argsort(proj_value)]
  252. # Sample pts in tcl map
  253. if self.sample_pts_num == 0:
  254. sample_pts_num = self.estimate_sample_pts_num(quad, xy_text)
  255. else:
  256. sample_pts_num = self.sample_pts_num
  257. xy_center_line = xy_text[
  258. np.linspace(
  259. 0,
  260. xy_text.shape[0] - 1,
  261. sample_pts_num,
  262. endpoint=True,
  263. dtype=np.float32,
  264. ).astype(np.int32)
  265. ]
  266. point_pair_list = []
  267. for x, y in xy_center_line:
  268. # get corresponding offset
  269. offset = tbo_map[y, x, :].reshape(2, 2)
  270. if offset_expand != 1.0:
  271. offset_length = np.linalg.norm(offset, axis=1, keepdims=True)
  272. expand_length = np.clip(
  273. offset_length * (offset_expand - 1), a_min=0.5, a_max=3.0
  274. )
  275. offset_detal = offset / offset_length * expand_length
  276. offset = offset + offset_detal
  277. # original point
  278. ori_yx = np.array([y, x], dtype=np.float32)
  279. point_pair = (
  280. (ori_yx + offset)[:, ::-1]
  281. * out_strid
  282. / np.array([ratio_w, ratio_h]).reshape(-1, 2)
  283. )
  284. point_pair_list.append(point_pair)
  285. # ndarry: (x, 2), expand poly along width
  286. detected_poly = self.point_pair2poly(point_pair_list)
  287. detected_poly = self.expand_poly_along_width(
  288. detected_poly, shrink_ratio_of_width
  289. )
  290. detected_poly[:, 0] = np.clip(detected_poly[:, 0], a_min=0, a_max=src_w)
  291. detected_poly[:, 1] = np.clip(detected_poly[:, 1], a_min=0, a_max=src_h)
  292. poly_list.append(detected_poly)
  293. return poly_list
  294. def __call__(self, outs_dict, shape_list):
  295. score_list = outs_dict["f_score"]
  296. border_list = outs_dict["f_border"]
  297. tvo_list = outs_dict["f_tvo"]
  298. tco_list = outs_dict["f_tco"]
  299. if isinstance(score_list, paddle.Tensor):
  300. score_list = score_list.numpy()
  301. border_list = border_list.numpy()
  302. tvo_list = tvo_list.numpy()
  303. tco_list = tco_list.numpy()
  304. img_num = len(shape_list)
  305. poly_lists = []
  306. for ino in range(img_num):
  307. p_score = score_list[ino].transpose((1, 2, 0))
  308. p_border = border_list[ino].transpose((1, 2, 0))
  309. p_tvo = tvo_list[ino].transpose((1, 2, 0))
  310. p_tco = tco_list[ino].transpose((1, 2, 0))
  311. src_h, src_w, ratio_h, ratio_w = shape_list[ino]
  312. poly_list = self.detect_sast(
  313. p_score,
  314. p_tvo,
  315. p_border,
  316. p_tco,
  317. ratio_w,
  318. ratio_h,
  319. src_w,
  320. src_h,
  321. shrink_ratio_of_width=self.shrink_ratio_of_width,
  322. tcl_map_thresh=self.tcl_map_thresh,
  323. offset_expand=self.expand_scale,
  324. )
  325. poly_lists.append({"points": np.array(poly_list)})
  326. return poly_lists