drrg_targets.py 29 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
This code is adapted from:
https://github.com/open-mmlab/mmocr/blob/main/mmocr/datasets/pipelines/textdet_targets/drrg_targets.py
"""
  18. import cv2
  19. import numpy as np
  20. from ppocr.utils.utility import check_install
  21. from numpy.linalg import norm
  22. class DRRGTargets(object):
  23. def __init__(
  24. self,
  25. orientation_thr=2.0,
  26. resample_step=8.0,
  27. num_min_comps=9,
  28. num_max_comps=600,
  29. min_width=8.0,
  30. max_width=24.0,
  31. center_region_shrink_ratio=0.3,
  32. comp_shrink_ratio=1.0,
  33. comp_w_h_ratio=0.3,
  34. text_comp_nms_thr=0.25,
  35. min_rand_half_height=8.0,
  36. max_rand_half_height=24.0,
  37. jitter_level=0.2,
  38. **kwargs,
  39. ):
  40. super().__init__()
  41. self.orientation_thr = orientation_thr
  42. self.resample_step = resample_step
  43. self.num_max_comps = num_max_comps
  44. self.num_min_comps = num_min_comps
  45. self.min_width = min_width
  46. self.max_width = max_width
  47. self.center_region_shrink_ratio = center_region_shrink_ratio
  48. self.comp_shrink_ratio = comp_shrink_ratio
  49. self.comp_w_h_ratio = comp_w_h_ratio
  50. self.text_comp_nms_thr = text_comp_nms_thr
  51. self.min_rand_half_height = min_rand_half_height
  52. self.max_rand_half_height = max_rand_half_height
  53. self.jitter_level = jitter_level
  54. self.eps = 1e-8
  55. def vector_angle(self, vec1, vec2):
  56. if vec1.ndim > 1:
  57. unit_vec1 = vec1 / (norm(vec1, axis=-1) + self.eps).reshape((-1, 1))
  58. else:
  59. unit_vec1 = vec1 / (norm(vec1, axis=-1) + self.eps)
  60. if vec2.ndim > 1:
  61. unit_vec2 = vec2 / (norm(vec2, axis=-1) + self.eps).reshape((-1, 1))
  62. else:
  63. unit_vec2 = vec2 / (norm(vec2, axis=-1) + self.eps)
  64. return np.arccos(np.clip(np.sum(unit_vec1 * unit_vec2, axis=-1), -1.0, 1.0))
  65. def vector_slope(self, vec):
  66. assert len(vec) == 2
  67. return abs(vec[1] / (vec[0] + self.eps))
  68. def vector_sin(self, vec):
  69. assert len(vec) == 2
  70. return vec[1] / (norm(vec) + self.eps)
  71. def vector_cos(self, vec):
  72. assert len(vec) == 2
  73. return vec[0] / (norm(vec) + self.eps)
  74. def find_head_tail(self, points, orientation_thr):
  75. assert points.ndim == 2
  76. assert points.shape[0] >= 4
  77. assert points.shape[1] == 2
  78. assert isinstance(orientation_thr, float)
  79. if len(points) > 4:
  80. pad_points = np.vstack([points, points[0]])
  81. edge_vec = pad_points[1:] - pad_points[:-1]
  82. theta_sum = []
  83. adjacent_vec_theta = []
  84. for i, edge_vec1 in enumerate(edge_vec):
  85. adjacent_ind = [x % len(edge_vec) for x in [i - 1, i + 1]]
  86. adjacent_edge_vec = edge_vec[adjacent_ind]
  87. temp_theta_sum = np.sum(self.vector_angle(edge_vec1, adjacent_edge_vec))
  88. temp_adjacent_theta = self.vector_angle(
  89. adjacent_edge_vec[0], adjacent_edge_vec[1]
  90. )
  91. theta_sum.append(temp_theta_sum)
  92. adjacent_vec_theta.append(temp_adjacent_theta)
  93. theta_sum_score = np.array(theta_sum) / np.pi
  94. adjacent_theta_score = np.array(adjacent_vec_theta) / np.pi
  95. poly_center = np.mean(points, axis=0)
  96. edge_dist = np.maximum(
  97. norm(pad_points[1:] - poly_center, axis=-1),
  98. norm(pad_points[:-1] - poly_center, axis=-1),
  99. )
  100. dist_score = edge_dist / (np.max(edge_dist) + self.eps)
  101. position_score = np.zeros(len(edge_vec))
  102. score = 0.5 * theta_sum_score + 0.15 * adjacent_theta_score
  103. score += 0.35 * dist_score
  104. if len(points) % 2 == 0:
  105. position_score[(len(score) // 2 - 1)] += 1
  106. position_score[-1] += 1
  107. score += 0.1 * position_score
  108. pad_score = np.concatenate([score, score])
  109. score_matrix = np.zeros((len(score), len(score) - 3))
  110. x = np.arange(len(score) - 3) / float(len(score) - 4)
  111. gaussian = (
  112. 1.0
  113. / (np.sqrt(2.0 * np.pi) * 0.5)
  114. * np.exp(-np.power((x - 0.5) / 0.5, 2.0) / 2)
  115. )
  116. gaussian = gaussian / np.max(gaussian)
  117. for i in range(len(score)):
  118. score_matrix[i, :] = (
  119. score[i]
  120. + pad_score[(i + 2) : (i + len(score) - 1)] * gaussian * 0.3
  121. )
  122. head_start, tail_increment = np.unravel_index(
  123. score_matrix.argmax(), score_matrix.shape
  124. )
  125. tail_start = (head_start + tail_increment + 2) % len(points)
  126. head_end = (head_start + 1) % len(points)
  127. tail_end = (tail_start + 1) % len(points)
  128. if head_end > tail_end:
  129. head_start, tail_start = tail_start, head_start
  130. head_end, tail_end = tail_end, head_end
  131. head_inds = [head_start, head_end]
  132. tail_inds = [tail_start, tail_end]
  133. else:
  134. if self.vector_slope(points[1] - points[0]) + self.vector_slope(
  135. points[3] - points[2]
  136. ) < self.vector_slope(points[2] - points[1]) + self.vector_slope(
  137. points[0] - points[3]
  138. ):
  139. horizontal_edge_inds = [[0, 1], [2, 3]]
  140. vertical_edge_inds = [[3, 0], [1, 2]]
  141. else:
  142. horizontal_edge_inds = [[3, 0], [1, 2]]
  143. vertical_edge_inds = [[0, 1], [2, 3]]
  144. vertical_len_sum = norm(
  145. points[vertical_edge_inds[0][0]] - points[vertical_edge_inds[0][1]]
  146. ) + norm(
  147. points[vertical_edge_inds[1][0]] - points[vertical_edge_inds[1][1]]
  148. )
  149. horizontal_len_sum = norm(
  150. points[horizontal_edge_inds[0][0]] - points[horizontal_edge_inds[0][1]]
  151. ) + norm(
  152. points[horizontal_edge_inds[1][0]] - points[horizontal_edge_inds[1][1]]
  153. )
  154. if vertical_len_sum > horizontal_len_sum * orientation_thr:
  155. head_inds = horizontal_edge_inds[0]
  156. tail_inds = horizontal_edge_inds[1]
  157. else:
  158. head_inds = vertical_edge_inds[0]
  159. tail_inds = vertical_edge_inds[1]
  160. return head_inds, tail_inds
  161. def reorder_poly_edge(self, points):
  162. assert points.ndim == 2
  163. assert points.shape[0] >= 4
  164. assert points.shape[1] == 2
  165. head_inds, tail_inds = self.find_head_tail(points, self.orientation_thr)
  166. head_edge, tail_edge = points[head_inds], points[tail_inds]
  167. pad_points = np.vstack([points, points])
  168. if tail_inds[1] < 1:
  169. tail_inds[1] = len(points)
  170. sideline1 = pad_points[head_inds[1] : tail_inds[1]]
  171. sideline2 = pad_points[tail_inds[1] : (head_inds[1] + len(points))]
  172. sideline_mean_shift = np.mean(sideline1, axis=0) - np.mean(sideline2, axis=0)
  173. if sideline_mean_shift[1] > 0:
  174. top_sideline, bot_sideline = sideline2, sideline1
  175. else:
  176. top_sideline, bot_sideline = sideline1, sideline2
  177. return head_edge, tail_edge, top_sideline, bot_sideline
  178. def cal_curve_length(self, line):
  179. assert line.ndim == 2
  180. assert len(line) >= 2
  181. edges_length = np.sqrt(
  182. (line[1:, 0] - line[:-1, 0]) ** 2 + (line[1:, 1] - line[:-1, 1]) ** 2
  183. )
  184. total_length = np.sum(edges_length)
  185. return edges_length, total_length
  186. def resample_line(self, line, n):
  187. assert line.ndim == 2
  188. assert line.shape[0] >= 2
  189. assert line.shape[1] == 2
  190. assert isinstance(n, int)
  191. assert n > 2
  192. edges_length, total_length = self.cal_curve_length(line)
  193. t_org = np.insert(np.cumsum(edges_length), 0, 0)
  194. unit_t = total_length / (n - 1)
  195. t_equidistant = np.arange(1, n - 1, dtype=np.float32) * unit_t
  196. edge_ind = 0
  197. points = [line[0]]
  198. for t in t_equidistant:
  199. while edge_ind < len(edges_length) - 1 and t > t_org[edge_ind + 1]:
  200. edge_ind += 1
  201. t_l, t_r = t_org[edge_ind], t_org[edge_ind + 1]
  202. weight = np.array([t_r - t, t - t_l], dtype=np.float32) / (
  203. t_r - t_l + self.eps
  204. )
  205. p_coords = np.dot(weight, line[[edge_ind, edge_ind + 1]])
  206. points.append(p_coords)
  207. points.append(line[-1])
  208. resampled_line = np.vstack(points)
  209. return resampled_line
  210. def resample_sidelines(self, sideline1, sideline2, resample_step):
  211. assert sideline1.ndim == sideline2.ndim == 2
  212. assert sideline1.shape[1] == sideline2.shape[1] == 2
  213. assert sideline1.shape[0] >= 2
  214. assert sideline2.shape[0] >= 2
  215. assert isinstance(resample_step, float)
  216. _, length1 = self.cal_curve_length(sideline1)
  217. _, length2 = self.cal_curve_length(sideline2)
  218. avg_length = (length1 + length2) / 2
  219. resample_point_num = max(int(float(avg_length) / resample_step) + 1, 3)
  220. resampled_line1 = self.resample_line(sideline1, resample_point_num)
  221. resampled_line2 = self.resample_line(sideline2, resample_point_num)
  222. return resampled_line1, resampled_line2
  223. def dist_point2line(self, point, line):
  224. assert isinstance(line, tuple)
  225. point1, point2 = line
  226. d = abs(np.cross(point2 - point1, point - point1)) / (
  227. norm(point2 - point1) + 1e-8
  228. )
  229. return d
  230. def draw_center_region_maps(
  231. self,
  232. top_line,
  233. bot_line,
  234. center_line,
  235. center_region_mask,
  236. top_height_map,
  237. bot_height_map,
  238. sin_map,
  239. cos_map,
  240. region_shrink_ratio,
  241. ):
  242. assert top_line.shape == bot_line.shape == center_line.shape
  243. assert (
  244. center_region_mask.shape
  245. == top_height_map.shape
  246. == bot_height_map.shape
  247. == sin_map.shape
  248. == cos_map.shape
  249. )
  250. assert isinstance(region_shrink_ratio, float)
  251. h, w = center_region_mask.shape
  252. for i in range(0, len(center_line) - 1):
  253. top_mid_point = (top_line[i] + top_line[i + 1]) / 2
  254. bot_mid_point = (bot_line[i] + bot_line[i + 1]) / 2
  255. sin_theta = self.vector_sin(top_mid_point - bot_mid_point)
  256. cos_theta = self.vector_cos(top_mid_point - bot_mid_point)
  257. tl = center_line[i] + (top_line[i] - center_line[i]) * region_shrink_ratio
  258. tr = (
  259. center_line[i + 1]
  260. + (top_line[i + 1] - center_line[i + 1]) * region_shrink_ratio
  261. )
  262. br = (
  263. center_line[i + 1]
  264. + (bot_line[i + 1] - center_line[i + 1]) * region_shrink_ratio
  265. )
  266. bl = center_line[i] + (bot_line[i] - center_line[i]) * region_shrink_ratio
  267. current_center_box = np.vstack([tl, tr, br, bl]).astype(np.int32)
  268. cv2.fillPoly(center_region_mask, [current_center_box], color=1)
  269. cv2.fillPoly(sin_map, [current_center_box], color=sin_theta)
  270. cv2.fillPoly(cos_map, [current_center_box], color=cos_theta)
  271. current_center_box[:, 0] = np.clip(current_center_box[:, 0], 0, w - 1)
  272. current_center_box[:, 1] = np.clip(current_center_box[:, 1], 0, h - 1)
  273. min_coord = np.min(current_center_box, axis=0).astype(np.int32)
  274. max_coord = np.max(current_center_box, axis=0).astype(np.int32)
  275. current_center_box = current_center_box - min_coord
  276. box_sz = max_coord - min_coord + 1
  277. center_box_mask = np.zeros((box_sz[1], box_sz[0]), dtype=np.uint8)
  278. cv2.fillPoly(center_box_mask, [current_center_box], color=1)
  279. inds = np.argwhere(center_box_mask > 0)
  280. inds = inds + (min_coord[1], min_coord[0])
  281. inds_xy = np.fliplr(inds)
  282. top_height_map[(inds[:, 0], inds[:, 1])] = self.dist_point2line(
  283. inds_xy, (top_line[i], top_line[i + 1])
  284. )
  285. bot_height_map[(inds[:, 0], inds[:, 1])] = self.dist_point2line(
  286. inds_xy, (bot_line[i], bot_line[i + 1])
  287. )
  288. def generate_center_mask_attrib_maps(self, img_size, text_polys):
  289. assert isinstance(img_size, tuple)
  290. h, w = img_size
  291. center_lines = []
  292. center_region_mask = np.zeros((h, w), np.uint8)
  293. top_height_map = np.zeros((h, w), dtype=np.float32)
  294. bot_height_map = np.zeros((h, w), dtype=np.float32)
  295. sin_map = np.zeros((h, w), dtype=np.float32)
  296. cos_map = np.zeros((h, w), dtype=np.float32)
  297. for poly in text_polys:
  298. polygon_points = poly
  299. _, _, top_line, bot_line = self.reorder_poly_edge(polygon_points)
  300. resampled_top_line, resampled_bot_line = self.resample_sidelines(
  301. top_line, bot_line, self.resample_step
  302. )
  303. resampled_bot_line = resampled_bot_line[::-1]
  304. center_line = (resampled_top_line + resampled_bot_line) / 2
  305. if self.vector_slope(center_line[-1] - center_line[0]) > 2:
  306. if (center_line[-1] - center_line[0])[1] < 0:
  307. center_line = center_line[::-1]
  308. resampled_top_line = resampled_top_line[::-1]
  309. resampled_bot_line = resampled_bot_line[::-1]
  310. else:
  311. if (center_line[-1] - center_line[0])[0] < 0:
  312. center_line = center_line[::-1]
  313. resampled_top_line = resampled_top_line[::-1]
  314. resampled_bot_line = resampled_bot_line[::-1]
  315. line_head_shrink_len = (
  316. np.clip(
  317. (norm(top_line[0] - bot_line[0]) * self.comp_w_h_ratio),
  318. self.min_width,
  319. self.max_width,
  320. )
  321. / 2
  322. )
  323. line_tail_shrink_len = (
  324. np.clip(
  325. (norm(top_line[-1] - bot_line[-1]) * self.comp_w_h_ratio),
  326. self.min_width,
  327. self.max_width,
  328. )
  329. / 2
  330. )
  331. num_head_shrink = int(line_head_shrink_len // self.resample_step)
  332. num_tail_shrink = int(line_tail_shrink_len // self.resample_step)
  333. if len(center_line) > num_head_shrink + num_tail_shrink + 2:
  334. center_line = center_line[
  335. num_head_shrink : len(center_line) - num_tail_shrink
  336. ]
  337. resampled_top_line = resampled_top_line[
  338. num_head_shrink : len(resampled_top_line) - num_tail_shrink
  339. ]
  340. resampled_bot_line = resampled_bot_line[
  341. num_head_shrink : len(resampled_bot_line) - num_tail_shrink
  342. ]
  343. center_lines.append(center_line.astype(np.int32))
  344. self.draw_center_region_maps(
  345. resampled_top_line,
  346. resampled_bot_line,
  347. center_line,
  348. center_region_mask,
  349. top_height_map,
  350. bot_height_map,
  351. sin_map,
  352. cos_map,
  353. self.center_region_shrink_ratio,
  354. )
  355. return (
  356. center_lines,
  357. center_region_mask,
  358. top_height_map,
  359. bot_height_map,
  360. sin_map,
  361. cos_map,
  362. )
  363. def generate_rand_comp_attribs(self, num_rand_comps, center_sample_mask):
  364. assert isinstance(num_rand_comps, int)
  365. assert num_rand_comps > 0
  366. assert center_sample_mask.ndim == 2
  367. h, w = center_sample_mask.shape
  368. max_rand_half_height = self.max_rand_half_height
  369. min_rand_half_height = self.min_rand_half_height
  370. max_rand_height = max_rand_half_height * 2
  371. max_rand_width = np.clip(
  372. max_rand_height * self.comp_w_h_ratio, self.min_width, self.max_width
  373. )
  374. margin = (
  375. int(np.sqrt((max_rand_height / 2) ** 2 + (max_rand_width / 2) ** 2)) + 1
  376. )
  377. if 2 * margin + 1 > min(h, w):
  378. assert min(h, w) > (np.sqrt(2) * (self.min_width + 1))
  379. max_rand_half_height = max(min(h, w) / 4, self.min_width / 2 + 1)
  380. min_rand_half_height = max(max_rand_half_height / 4, self.min_width / 2)
  381. max_rand_height = max_rand_half_height * 2
  382. max_rand_width = np.clip(
  383. max_rand_height * self.comp_w_h_ratio, self.min_width, self.max_width
  384. )
  385. margin = (
  386. int(np.sqrt((max_rand_height / 2) ** 2 + (max_rand_width / 2) ** 2)) + 1
  387. )
  388. inner_center_sample_mask = np.zeros_like(center_sample_mask)
  389. inner_center_sample_mask[margin : h - margin, margin : w - margin] = (
  390. center_sample_mask[margin : h - margin, margin : w - margin]
  391. )
  392. kernel_size = int(np.clip(max_rand_half_height, 7, 21))
  393. inner_center_sample_mask = cv2.erode(
  394. inner_center_sample_mask, np.ones((kernel_size, kernel_size), np.uint8)
  395. )
  396. center_candidates = np.argwhere(inner_center_sample_mask > 0)
  397. num_center_candidates = len(center_candidates)
  398. sample_inds = np.random.choice(num_center_candidates, num_rand_comps)
  399. rand_centers = center_candidates[sample_inds]
  400. rand_top_height = np.random.randint(
  401. min_rand_half_height, max_rand_half_height, size=(len(rand_centers), 1)
  402. )
  403. rand_bot_height = np.random.randint(
  404. min_rand_half_height, max_rand_half_height, size=(len(rand_centers), 1)
  405. )
  406. rand_cos = 2 * np.random.random(size=(len(rand_centers), 1)) - 1
  407. rand_sin = 2 * np.random.random(size=(len(rand_centers), 1)) - 1
  408. scale = np.sqrt(1.0 / (rand_cos**2 + rand_sin**2 + 1e-8))
  409. rand_cos = rand_cos * scale
  410. rand_sin = rand_sin * scale
  411. height = rand_top_height + rand_bot_height
  412. width = np.clip(height * self.comp_w_h_ratio, self.min_width, self.max_width)
  413. rand_comp_attribs = np.hstack(
  414. [
  415. rand_centers[:, ::-1],
  416. height,
  417. width,
  418. rand_cos,
  419. rand_sin,
  420. np.zeros_like(rand_sin),
  421. ]
  422. ).astype(np.float32)
  423. return rand_comp_attribs
  424. def jitter_comp_attribs(self, comp_attribs, jitter_level):
  425. """Jitter text components attributes.
  426. Args:
  427. comp_attribs (ndarray): The text component attributes.
  428. jitter_level (float): The jitter level of text components
  429. attributes.
  430. Returns:
  431. jittered_comp_attribs (ndarray): The jittered text component
  432. attributes (x, y, h, w, cos, sin, comp_label).
  433. """
  434. assert comp_attribs.shape[1] == 7
  435. assert comp_attribs.shape[0] > 0
  436. assert isinstance(jitter_level, float)
  437. x = comp_attribs[:, 0].reshape((-1, 1))
  438. y = comp_attribs[:, 1].reshape((-1, 1))
  439. h = comp_attribs[:, 2].reshape((-1, 1))
  440. w = comp_attribs[:, 3].reshape((-1, 1))
  441. cos = comp_attribs[:, 4].reshape((-1, 1))
  442. sin = comp_attribs[:, 5].reshape((-1, 1))
  443. comp_labels = comp_attribs[:, 6].reshape((-1, 1))
  444. x += (
  445. (np.random.random(size=(len(comp_attribs), 1)) - 0.5)
  446. * (h * np.abs(cos) + w * np.abs(sin))
  447. * jitter_level
  448. )
  449. y += (
  450. (np.random.random(size=(len(comp_attribs), 1)) - 0.5)
  451. * (h * np.abs(sin) + w * np.abs(cos))
  452. * jitter_level
  453. )
  454. h += (np.random.random(size=(len(comp_attribs), 1)) - 0.5) * h * jitter_level
  455. w += (np.random.random(size=(len(comp_attribs), 1)) - 0.5) * w * jitter_level
  456. cos += (np.random.random(size=(len(comp_attribs), 1)) - 0.5) * 2 * jitter_level
  457. sin += (np.random.random(size=(len(comp_attribs), 1)) - 0.5) * 2 * jitter_level
  458. scale = np.sqrt(1.0 / (cos**2 + sin**2 + 1e-8))
  459. cos = cos * scale
  460. sin = sin * scale
  461. jittered_comp_attribs = np.hstack([x, y, h, w, cos, sin, comp_labels])
  462. return jittered_comp_attribs
  463. def generate_comp_attribs(
  464. self,
  465. center_lines,
  466. text_mask,
  467. center_region_mask,
  468. top_height_map,
  469. bot_height_map,
  470. sin_map,
  471. cos_map,
  472. ):
  473. """Generate text component attributes.
  474. Args:
  475. center_lines (list[ndarray]): The list of text center lines .
  476. text_mask (ndarray): The text region mask.
  477. center_region_mask (ndarray): The text center region mask.
  478. top_height_map (ndarray): The map on which the distance from points
  479. to top side lines will be drawn for each pixel in text center
  480. regions.
  481. bot_height_map (ndarray): The map on which the distance from points
  482. to bottom side lines will be drawn for each pixel in text
  483. center regions.
  484. sin_map (ndarray): The sin(theta) map where theta is the angle
  485. between vector (top point - bottom point) and vector (1, 0).
  486. cos_map (ndarray): The cos(theta) map where theta is the angle
  487. between vector (top point - bottom point) and vector (1, 0).
  488. Returns:
  489. pad_comp_attribs (ndarray): The padded text component attributes
  490. of a fixed size.
  491. """
  492. assert isinstance(center_lines, list)
  493. assert (
  494. text_mask.shape
  495. == center_region_mask.shape
  496. == top_height_map.shape
  497. == bot_height_map.shape
  498. == sin_map.shape
  499. == cos_map.shape
  500. )
  501. center_lines_mask = np.zeros_like(center_region_mask)
  502. cv2.polylines(center_lines_mask, center_lines, 0, 1, 1)
  503. center_lines_mask = center_lines_mask * center_region_mask
  504. comp_centers = np.argwhere(center_lines_mask > 0)
  505. y = comp_centers[:, 0]
  506. x = comp_centers[:, 1]
  507. top_height = top_height_map[y, x].reshape((-1, 1)) * self.comp_shrink_ratio
  508. bot_height = bot_height_map[y, x].reshape((-1, 1)) * self.comp_shrink_ratio
  509. sin = sin_map[y, x].reshape((-1, 1))
  510. cos = cos_map[y, x].reshape((-1, 1))
  511. top_mid_points = comp_centers + np.hstack([top_height * sin, top_height * cos])
  512. bot_mid_points = comp_centers - np.hstack([bot_height * sin, bot_height * cos])
  513. width = (top_height + bot_height) * self.comp_w_h_ratio
  514. width = np.clip(width, self.min_width, self.max_width)
  515. r = width / 2
  516. tl = top_mid_points[:, ::-1] - np.hstack([-r * sin, r * cos])
  517. tr = top_mid_points[:, ::-1] + np.hstack([-r * sin, r * cos])
  518. br = bot_mid_points[:, ::-1] + np.hstack([-r * sin, r * cos])
  519. bl = bot_mid_points[:, ::-1] - np.hstack([-r * sin, r * cos])
  520. text_comps = np.hstack([tl, tr, br, bl]).astype(np.float32)
  521. score = np.ones((text_comps.shape[0], 1), dtype=np.float32)
  522. text_comps = np.hstack([text_comps, score])
  523. check_install("lanms", "lanms-neo")
  524. from lanms import merge_quadrangle_n9 as la_nms
  525. text_comps = la_nms(text_comps, self.text_comp_nms_thr)
  526. if text_comps.shape[0] >= 1:
  527. img_h, img_w = center_region_mask.shape
  528. text_comps[:, 0:8:2] = np.clip(text_comps[:, 0:8:2], 0, img_w - 1)
  529. text_comps[:, 1:8:2] = np.clip(text_comps[:, 1:8:2], 0, img_h - 1)
  530. comp_centers = np.mean(
  531. text_comps[:, 0:8].reshape((-1, 4, 2)), axis=1
  532. ).astype(np.int32)
  533. x = comp_centers[:, 0]
  534. y = comp_centers[:, 1]
  535. height = (top_height_map[y, x] + bot_height_map[y, x]).reshape((-1, 1))
  536. width = np.clip(
  537. height * self.comp_w_h_ratio, self.min_width, self.max_width
  538. )
  539. cos = cos_map[y, x].reshape((-1, 1))
  540. sin = sin_map[y, x].reshape((-1, 1))
  541. _, comp_label_mask = cv2.connectedComponents(
  542. center_region_mask, connectivity=8
  543. )
  544. comp_labels = comp_label_mask[y, x].reshape((-1, 1)).astype(np.float32)
  545. x = x.reshape((-1, 1)).astype(np.float32)
  546. y = y.reshape((-1, 1)).astype(np.float32)
  547. comp_attribs = np.hstack([x, y, height, width, cos, sin, comp_labels])
  548. comp_attribs = self.jitter_comp_attribs(comp_attribs, self.jitter_level)
  549. if comp_attribs.shape[0] < self.num_min_comps:
  550. num_rand_comps = self.num_min_comps - comp_attribs.shape[0]
  551. rand_comp_attribs = self.generate_rand_comp_attribs(
  552. num_rand_comps, 1 - text_mask
  553. )
  554. comp_attribs = np.vstack([comp_attribs, rand_comp_attribs])
  555. else:
  556. comp_attribs = self.generate_rand_comp_attribs(
  557. self.num_min_comps, 1 - text_mask
  558. )
  559. num_comps = (
  560. np.ones((comp_attribs.shape[0], 1), dtype=np.float32)
  561. * comp_attribs.shape[0]
  562. )
  563. comp_attribs = np.hstack([num_comps, comp_attribs])
  564. if comp_attribs.shape[0] > self.num_max_comps:
  565. comp_attribs = comp_attribs[: self.num_max_comps, :]
  566. comp_attribs[:, 0] = self.num_max_comps
  567. pad_comp_attribs = np.zeros(
  568. (self.num_max_comps, comp_attribs.shape[1]), dtype=np.float32
  569. )
  570. pad_comp_attribs[: comp_attribs.shape[0], :] = comp_attribs
  571. return pad_comp_attribs
  572. def generate_text_region_mask(self, img_size, text_polys):
  573. """Generate text center region mask and geometry attribute maps.
  574. Args:
  575. img_size (tuple): The image size (height, width).
  576. text_polys (list[list[ndarray]]): The list of text polygons.
  577. Returns:
  578. text_region_mask (ndarray): The text region mask.
  579. """
  580. assert isinstance(img_size, tuple)
  581. h, w = img_size
  582. text_region_mask = np.zeros((h, w), dtype=np.uint8)
  583. for poly in text_polys:
  584. polygon = np.array(poly, dtype=np.int32).reshape((1, -1, 2))
  585. cv2.fillPoly(text_region_mask, polygon, 1)
  586. return text_region_mask
  587. def generate_effective_mask(self, mask_size: tuple, polygons_ignore):
  588. """Generate effective mask by setting the ineffective regions to 0 and
  589. effective regions to 1.
  590. Args:
  591. mask_size (tuple): The mask size.
  592. polygons_ignore (list[[ndarray]]: The list of ignored text
  593. polygons.
  594. Returns:
  595. mask (ndarray): The effective mask of (height, width).
  596. """
  597. mask = np.ones(mask_size, dtype=np.uint8)
  598. for poly in polygons_ignore:
  599. instance = poly.astype(np.int32).reshape(1, -1, 2)
  600. cv2.fillPoly(mask, instance, 0)
  601. return mask
  602. def generate_targets(self, data):
  603. """Generate the gt targets for DRRG.
  604. Args:
  605. data (dict): The input result dictionary.
  606. Returns:
  607. data (dict): The output result dictionary.
  608. """
  609. assert isinstance(data, dict)
  610. image = data["image"]
  611. polygons = data["polys"]
  612. ignore_tags = data["ignore_tags"]
  613. h, w, _ = image.shape
  614. polygon_masks = []
  615. polygon_masks_ignore = []
  616. for tag, polygon in zip(ignore_tags, polygons):
  617. if tag is True:
  618. polygon_masks_ignore.append(polygon)
  619. else:
  620. polygon_masks.append(polygon)
  621. gt_text_mask = self.generate_text_region_mask((h, w), polygon_masks)
  622. gt_mask = self.generate_effective_mask((h, w), polygon_masks_ignore)
  623. (
  624. center_lines,
  625. gt_center_region_mask,
  626. gt_top_height_map,
  627. gt_bot_height_map,
  628. gt_sin_map,
  629. gt_cos_map,
  630. ) = self.generate_center_mask_attrib_maps((h, w), polygon_masks)
  631. gt_comp_attribs = self.generate_comp_attribs(
  632. center_lines,
  633. gt_text_mask,
  634. gt_center_region_mask,
  635. gt_top_height_map,
  636. gt_bot_height_map,
  637. gt_sin_map,
  638. gt_cos_map,
  639. )
  640. mapping = {
  641. "gt_text_mask": gt_text_mask,
  642. "gt_center_region_mask": gt_center_region_mask,
  643. "gt_mask": gt_mask,
  644. "gt_top_height_map": gt_top_height_map,
  645. "gt_bot_height_map": gt_bot_height_map,
  646. "gt_sin_map": gt_sin_map,
  647. "gt_cos_map": gt_cos_map,
  648. }
  649. data.update(mapping)
  650. data["gt_comp_attribs"] = gt_comp_attribs
  651. return data
  652. def __call__(self, data):
  653. data = self.generate_targets(data)
  654. return data