extract_textpoint_fast.py 18 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523
  1. # Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. """Contains various CTC decoders."""
  15. from __future__ import absolute_import
  16. from __future__ import division
  17. from __future__ import print_function
  18. import cv2
  19. import math
  20. import numpy as np
  21. from itertools import groupby
  22. from skimage.morphology._skeletonize import thin
  23. def get_dict(character_dict_path):
  24. character_str = ""
  25. with open(character_dict_path, "rb") as fin:
  26. lines = fin.readlines()
  27. for line in lines:
  28. line = line.decode("utf-8").strip("\n").strip("\r\n")
  29. character_str += line
  30. dict_character = list(character_str)
  31. return dict_character
  32. def softmax(logits):
  33. """
  34. logits: N x d
  35. """
  36. max_value = np.max(logits, axis=1, keepdims=True)
  37. exp = np.exp(logits - max_value)
  38. exp_sum = np.sum(exp, axis=1, keepdims=True)
  39. dist = exp / exp_sum
  40. return dist
  41. def get_keep_pos_idxs(labels, remove_blank=None):
  42. """
  43. Remove duplicate and get pos idxs of keep items.
  44. The value of keep_blank should be [None, 95].
  45. """
  46. duplicate_len_list = []
  47. keep_pos_idx_list = []
  48. keep_char_idx_list = []
  49. for k, v_ in groupby(labels):
  50. current_len = len(list(v_))
  51. if k != remove_blank:
  52. current_idx = int(sum(duplicate_len_list) + current_len // 2)
  53. keep_pos_idx_list.append(current_idx)
  54. keep_char_idx_list.append(k)
  55. duplicate_len_list.append(current_len)
  56. return keep_char_idx_list, keep_pos_idx_list
  57. def remove_blank(labels, blank=0):
  58. new_labels = [x for x in labels if x != blank]
  59. return new_labels
  60. def insert_blank(labels, blank=0):
  61. new_labels = [blank]
  62. for l in labels:
  63. new_labels += [l, blank]
  64. return new_labels
  65. def ctc_greedy_decoder(probs_seq, blank=95, keep_blank_in_idxs=True):
  66. """
  67. CTC greedy (best path) decoder.
  68. """
  69. raw_str = np.argmax(np.array(probs_seq), axis=1)
  70. remove_blank_in_pos = None if keep_blank_in_idxs else blank
  71. dedup_str, keep_idx_list = get_keep_pos_idxs(
  72. raw_str, remove_blank=remove_blank_in_pos
  73. )
  74. dst_str = remove_blank(dedup_str, blank=blank)
  75. return dst_str, keep_idx_list
  76. def instance_ctc_greedy_decoder(
  77. gather_info, logits_map, pts_num=4, point_gather_mode=None
  78. ):
  79. _, _, C = logits_map.shape
  80. if point_gather_mode == "align":
  81. insert_num = 0
  82. gather_info = np.array(gather_info)
  83. length = len(gather_info) - 1
  84. for index in range(length):
  85. stride_y = np.abs(
  86. gather_info[index + insert_num][0]
  87. - gather_info[index + 1 + insert_num][0]
  88. )
  89. stride_x = np.abs(
  90. gather_info[index + insert_num][1]
  91. - gather_info[index + 1 + insert_num][1]
  92. )
  93. max_points = int(max(stride_x, stride_y))
  94. stride = (
  95. gather_info[index + insert_num] - gather_info[index + 1 + insert_num]
  96. ) / (max_points)
  97. insert_num_temp = max_points - 1
  98. for i in range(int(insert_num_temp)):
  99. insert_value = gather_info[index + insert_num] - (i + 1) * stride
  100. insert_index = index + i + 1 + insert_num
  101. gather_info = np.insert(gather_info, insert_index, insert_value, axis=0)
  102. insert_num += insert_num_temp
  103. gather_info = gather_info.tolist()
  104. else:
  105. pass
  106. ys, xs = zip(*gather_info)
  107. logits_seq = logits_map[list(ys), list(xs)]
  108. probs_seq = logits_seq
  109. labels = np.argmax(probs_seq, axis=1)
  110. dst_str = [k for k, v_ in groupby(labels) if k != C - 1]
  111. detal = len(gather_info) // (pts_num - 1)
  112. keep_idx_list = [0] + [detal * (i + 1) for i in range(pts_num - 2)] + [-1]
  113. keep_gather_list = [gather_info[idx] for idx in keep_idx_list]
  114. return dst_str, keep_gather_list
  115. def ctc_decoder_for_image(
  116. gather_info_list, logits_map, Lexicon_Table, pts_num=6, point_gather_mode=None
  117. ):
  118. """
  119. CTC decoder using multiple processes.
  120. """
  121. decoder_str = []
  122. decoder_xys = []
  123. for gather_info in gather_info_list:
  124. if len(gather_info) < pts_num:
  125. continue
  126. dst_str, xys_list = instance_ctc_greedy_decoder(
  127. gather_info,
  128. logits_map,
  129. pts_num=pts_num,
  130. point_gather_mode=point_gather_mode,
  131. )
  132. dst_str_readable = "".join([Lexicon_Table[idx] for idx in dst_str])
  133. if len(dst_str_readable) < 2:
  134. continue
  135. decoder_str.append(dst_str_readable)
  136. decoder_xys.append(xys_list)
  137. return decoder_str, decoder_xys
  138. def sort_with_direction(pos_list, f_direction):
  139. """
  140. f_direction: h x w x 2
  141. pos_list: [[y, x], [y, x], [y, x] ...]
  142. """
  143. def sort_part_with_direction(pos_list, point_direction):
  144. pos_list = np.array(pos_list).reshape(-1, 2)
  145. point_direction = np.array(point_direction).reshape(-1, 2)
  146. average_direction = np.mean(point_direction, axis=0, keepdims=True)
  147. pos_proj_leng = np.sum(pos_list * average_direction, axis=1)
  148. sorted_list = pos_list[np.argsort(pos_proj_leng)].tolist()
  149. sorted_direction = point_direction[np.argsort(pos_proj_leng)].tolist()
  150. return sorted_list, sorted_direction
  151. pos_list = np.array(pos_list).reshape(-1, 2)
  152. point_direction = f_direction[pos_list[:, 0], pos_list[:, 1]] # x, y
  153. point_direction = point_direction[:, ::-1] # x, y -> y, x
  154. sorted_point, sorted_direction = sort_part_with_direction(pos_list, point_direction)
  155. point_num = len(sorted_point)
  156. if point_num >= 16:
  157. middle_num = point_num // 2
  158. first_part_point = sorted_point[:middle_num]
  159. first_point_direction = sorted_direction[:middle_num]
  160. sorted_fist_part_point, sorted_fist_part_direction = sort_part_with_direction(
  161. first_part_point, first_point_direction
  162. )
  163. last_part_point = sorted_point[middle_num:]
  164. last_point_direction = sorted_direction[middle_num:]
  165. sorted_last_part_point, sorted_last_part_direction = sort_part_with_direction(
  166. last_part_point, last_point_direction
  167. )
  168. sorted_point = sorted_fist_part_point + sorted_last_part_point
  169. sorted_direction = sorted_fist_part_direction + sorted_last_part_direction
  170. return sorted_point, np.array(sorted_direction)
  171. def add_id(pos_list, image_id=0):
  172. """
  173. Add id for gather feature, for inference.
  174. """
  175. new_list = []
  176. for item in pos_list:
  177. new_list.append((image_id, item[0], item[1]))
  178. return new_list
  179. def sort_and_expand_with_direction(pos_list, f_direction):
  180. """
  181. f_direction: h x w x 2
  182. pos_list: [[y, x], [y, x], [y, x] ...]
  183. """
  184. h, w, _ = f_direction.shape
  185. sorted_list, point_direction = sort_with_direction(pos_list, f_direction)
  186. point_num = len(sorted_list)
  187. sub_direction_len = max(point_num // 3, 2)
  188. left_direction = point_direction[:sub_direction_len, :]
  189. right_dirction = point_direction[point_num - sub_direction_len :, :]
  190. left_average_direction = -np.mean(left_direction, axis=0, keepdims=True)
  191. left_average_len = np.linalg.norm(left_average_direction)
  192. left_start = np.array(sorted_list[0])
  193. left_step = left_average_direction / (left_average_len + 1e-6)
  194. right_average_direction = np.mean(right_dirction, axis=0, keepdims=True)
  195. right_average_len = np.linalg.norm(right_average_direction)
  196. right_step = right_average_direction / (right_average_len + 1e-6)
  197. right_start = np.array(sorted_list[-1])
  198. append_num = max(int((left_average_len + right_average_len) / 2.0 * 0.15), 1)
  199. left_list = []
  200. right_list = []
  201. for i in range(append_num):
  202. ly, lx = (
  203. np.round(left_start + left_step * (i + 1))
  204. .flatten()
  205. .astype("int32")
  206. .tolist()
  207. )
  208. if ly < h and lx < w and (ly, lx) not in left_list:
  209. left_list.append((ly, lx))
  210. ry, rx = (
  211. np.round(right_start + right_step * (i + 1))
  212. .flatten()
  213. .astype("int32")
  214. .tolist()
  215. )
  216. if ry < h and rx < w and (ry, rx) not in right_list:
  217. right_list.append((ry, rx))
  218. all_list = left_list[::-1] + sorted_list + right_list
  219. return all_list
  220. def sort_and_expand_with_direction_v2(pos_list, f_direction, binary_tcl_map):
  221. """
  222. f_direction: h x w x 2
  223. pos_list: [[y, x], [y, x], [y, x] ...]
  224. binary_tcl_map: h x w
  225. """
  226. h, w, _ = f_direction.shape
  227. sorted_list, point_direction = sort_with_direction(pos_list, f_direction)
  228. point_num = len(sorted_list)
  229. sub_direction_len = max(point_num // 3, 2)
  230. left_direction = point_direction[:sub_direction_len, :]
  231. right_dirction = point_direction[point_num - sub_direction_len :, :]
  232. left_average_direction = -np.mean(left_direction, axis=0, keepdims=True)
  233. left_average_len = np.linalg.norm(left_average_direction)
  234. left_start = np.array(sorted_list[0])
  235. left_step = left_average_direction / (left_average_len + 1e-6)
  236. right_average_direction = np.mean(right_dirction, axis=0, keepdims=True)
  237. right_average_len = np.linalg.norm(right_average_direction)
  238. right_step = right_average_direction / (right_average_len + 1e-6)
  239. right_start = np.array(sorted_list[-1])
  240. append_num = max(int((left_average_len + right_average_len) / 2.0 * 0.15), 1)
  241. max_append_num = 2 * append_num
  242. left_list = []
  243. right_list = []
  244. for i in range(max_append_num):
  245. ly, lx = (
  246. np.round(left_start + left_step * (i + 1))
  247. .flatten()
  248. .astype("int32")
  249. .tolist()
  250. )
  251. if ly < h and lx < w and (ly, lx) not in left_list:
  252. if binary_tcl_map[ly, lx] > 0.5:
  253. left_list.append((ly, lx))
  254. else:
  255. break
  256. for i in range(max_append_num):
  257. ry, rx = (
  258. np.round(right_start + right_step * (i + 1))
  259. .flatten()
  260. .astype("int32")
  261. .tolist()
  262. )
  263. if ry < h and rx < w and (ry, rx) not in right_list:
  264. if binary_tcl_map[ry, rx] > 0.5:
  265. right_list.append((ry, rx))
  266. else:
  267. break
  268. all_list = left_list[::-1] + sorted_list + right_list
  269. return all_list
  270. def point_pair2poly(point_pair_list):
  271. """
  272. Transfer vertical point_pairs into poly point in clockwise.
  273. """
  274. point_num = len(point_pair_list) * 2
  275. point_list = [0] * point_num
  276. for idx, point_pair in enumerate(point_pair_list):
  277. point_list[idx] = point_pair[0]
  278. point_list[point_num - 1 - idx] = point_pair[1]
  279. return np.array(point_list).reshape(-1, 2)
  280. def shrink_quad_along_width(quad, begin_width_ratio=0.0, end_width_ratio=1.0):
  281. ratio_pair = np.array([[begin_width_ratio], [end_width_ratio]], dtype=np.float32)
  282. p0_1 = quad[0] + (quad[1] - quad[0]) * ratio_pair
  283. p3_2 = quad[3] + (quad[2] - quad[3]) * ratio_pair
  284. return np.array([p0_1[0], p0_1[1], p3_2[1], p3_2[0]])
  285. def expand_poly_along_width(poly, shrink_ratio_of_width=0.3):
  286. """
  287. expand poly along width.
  288. """
  289. point_num = poly.shape[0]
  290. left_quad = np.array([poly[0], poly[1], poly[-2], poly[-1]], dtype=np.float32)
  291. left_ratio = (
  292. -shrink_ratio_of_width
  293. * np.linalg.norm(left_quad[0] - left_quad[3])
  294. / (np.linalg.norm(left_quad[0] - left_quad[1]) + 1e-6)
  295. )
  296. left_quad_expand = shrink_quad_along_width(left_quad, left_ratio, 1.0)
  297. right_quad = np.array(
  298. [
  299. poly[point_num // 2 - 2],
  300. poly[point_num // 2 - 1],
  301. poly[point_num // 2],
  302. poly[point_num // 2 + 1],
  303. ],
  304. dtype=np.float32,
  305. )
  306. right_ratio = 1.0 + shrink_ratio_of_width * np.linalg.norm(
  307. right_quad[0] - right_quad[3]
  308. ) / (np.linalg.norm(right_quad[0] - right_quad[1]) + 1e-6)
  309. right_quad_expand = shrink_quad_along_width(right_quad, 0.0, right_ratio)
  310. poly[0] = left_quad_expand[0]
  311. poly[-1] = left_quad_expand[-1]
  312. poly[point_num // 2 - 1] = right_quad_expand[1]
  313. poly[point_num // 2] = right_quad_expand[2]
  314. return poly
  315. def restore_poly(
  316. instance_yxs_list, seq_strs, p_border, ratio_w, ratio_h, src_w, src_h, valid_set
  317. ):
  318. poly_list = []
  319. keep_str_list = []
  320. for yx_center_line, keep_str in zip(instance_yxs_list, seq_strs):
  321. if len(keep_str) < 2:
  322. print("--> too short, {}".format(keep_str))
  323. continue
  324. offset_expand = 1.0
  325. if valid_set == "totaltext":
  326. offset_expand = 1.2
  327. point_pair_list = []
  328. for y, x in yx_center_line:
  329. offset = p_border[:, y, x].reshape(2, 2) * offset_expand
  330. ori_yx = np.array([y, x], dtype=np.float32)
  331. point_pair = (
  332. (ori_yx + offset)[:, ::-1]
  333. * 4.0
  334. / np.array([ratio_w, ratio_h]).reshape(-1, 2)
  335. )
  336. point_pair_list.append(point_pair)
  337. detected_poly = point_pair2poly(point_pair_list)
  338. detected_poly = expand_poly_along_width(
  339. detected_poly, shrink_ratio_of_width=0.2
  340. )
  341. detected_poly[:, 0] = np.clip(detected_poly[:, 0], a_min=0, a_max=src_w)
  342. detected_poly[:, 1] = np.clip(detected_poly[:, 1], a_min=0, a_max=src_h)
  343. keep_str_list.append(keep_str)
  344. if valid_set == "partvgg":
  345. middle_point = len(detected_poly) // 2
  346. detected_poly = detected_poly[[0, middle_point - 1, middle_point, -1], :]
  347. poly_list.append(detected_poly)
  348. elif valid_set == "totaltext":
  349. poly_list.append(detected_poly)
  350. else:
  351. print("--> Not supported format.")
  352. exit(-1)
  353. return poly_list, keep_str_list
  354. def generate_pivot_list_fast(
  355. p_score,
  356. p_char_maps,
  357. f_direction,
  358. Lexicon_Table,
  359. score_thresh=0.5,
  360. point_gather_mode=None,
  361. ):
  362. """
  363. return center point and end point of TCL instance; filter with the char maps;
  364. """
  365. p_score = p_score[0]
  366. f_direction = f_direction.transpose(1, 2, 0)
  367. p_tcl_map = (p_score > score_thresh) * 1.0
  368. skeleton_map = thin(p_tcl_map.astype(np.uint8))
  369. instance_count, instance_label_map = cv2.connectedComponents(
  370. skeleton_map.astype(np.uint8), connectivity=8
  371. )
  372. # get TCL Instance
  373. all_pos_yxs = []
  374. if instance_count > 0:
  375. for instance_id in range(1, instance_count):
  376. pos_list = []
  377. ys, xs = np.where(instance_label_map == instance_id)
  378. pos_list = list(zip(ys, xs))
  379. if len(pos_list) < 3:
  380. continue
  381. pos_list_sorted = sort_and_expand_with_direction_v2(
  382. pos_list, f_direction, p_tcl_map
  383. )
  384. all_pos_yxs.append(pos_list_sorted)
  385. p_char_maps = p_char_maps.transpose([1, 2, 0])
  386. decoded_str, keep_yxs_list = ctc_decoder_for_image(
  387. all_pos_yxs,
  388. logits_map=p_char_maps,
  389. Lexicon_Table=Lexicon_Table,
  390. point_gather_mode=point_gather_mode,
  391. )
  392. return keep_yxs_list, decoded_str
  393. def extract_main_direction(pos_list, f_direction):
  394. """
  395. f_direction: h x w x 2
  396. pos_list: [[y, x], [y, x], [y, x] ...]
  397. """
  398. pos_list = np.array(pos_list)
  399. point_direction = f_direction[pos_list[:, 0], pos_list[:, 1]]
  400. point_direction = point_direction[:, ::-1] # x, y -> y, x
  401. average_direction = np.mean(point_direction, axis=0, keepdims=True)
  402. average_direction = average_direction / (np.linalg.norm(average_direction) + 1e-6)
  403. return average_direction
  404. def sort_by_direction_with_image_id_deprecated(pos_list, f_direction):
  405. """
  406. f_direction: h x w x 2
  407. pos_list: [[id, y, x], [id, y, x], [id, y, x] ...]
  408. """
  409. pos_list_full = np.array(pos_list).reshape(-1, 3)
  410. pos_list = pos_list_full[:, 1:]
  411. point_direction = f_direction[pos_list[:, 0], pos_list[:, 1]] # x, y
  412. point_direction = point_direction[:, ::-1] # x, y -> y, x
  413. average_direction = np.mean(point_direction, axis=0, keepdims=True)
  414. pos_proj_leng = np.sum(pos_list * average_direction, axis=1)
  415. sorted_list = pos_list_full[np.argsort(pos_proj_leng)].tolist()
  416. return sorted_list
  417. def sort_by_direction_with_image_id(pos_list, f_direction):
  418. """
  419. f_direction: h x w x 2
  420. pos_list: [[y, x], [y, x], [y, x] ...]
  421. """
  422. def sort_part_with_direction(pos_list_full, point_direction):
  423. pos_list_full = np.array(pos_list_full).reshape(-1, 3)
  424. pos_list = pos_list_full[:, 1:]
  425. point_direction = np.array(point_direction).reshape(-1, 2)
  426. average_direction = np.mean(point_direction, axis=0, keepdims=True)
  427. pos_proj_leng = np.sum(pos_list * average_direction, axis=1)
  428. sorted_list = pos_list_full[np.argsort(pos_proj_leng)].tolist()
  429. sorted_direction = point_direction[np.argsort(pos_proj_leng)].tolist()
  430. return sorted_list, sorted_direction
  431. pos_list = np.array(pos_list).reshape(-1, 3)
  432. point_direction = f_direction[pos_list[:, 1], pos_list[:, 2]] # x, y
  433. point_direction = point_direction[:, ::-1] # x, y -> y, x
  434. sorted_point, sorted_direction = sort_part_with_direction(pos_list, point_direction)
  435. point_num = len(sorted_point)
  436. if point_num >= 16:
  437. middle_num = point_num // 2
  438. first_part_point = sorted_point[:middle_num]
  439. first_point_direction = sorted_direction[:middle_num]
  440. sorted_fist_part_point, sorted_fist_part_direction = sort_part_with_direction(
  441. first_part_point, first_point_direction
  442. )
  443. last_part_point = sorted_point[middle_num:]
  444. last_point_direction = sorted_direction[middle_num:]
  445. sorted_last_part_point, sorted_last_part_direction = sort_part_with_direction(
  446. last_part_point, last_point_direction
  447. )
  448. sorted_point = sorted_fist_part_point + sorted_last_part_point
  449. sorted_direction = sorted_fist_part_direction + sorted_last_part_direction
  450. return sorted_point