# sast_process.py
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Parts of this code are adapted from:
https://github.com/songdejia/EAST/blob/master/data_utils.py
"""
import json
import math
import os
import sys

import cv2
import numpy as np

__all__ = ["SASTProcessTrain"]
  25. class SASTProcessTrain(object):
  26. def __init__(
  27. self,
  28. image_shape=[512, 512],
  29. min_crop_size=24,
  30. min_crop_side_ratio=0.3,
  31. min_text_size=10,
  32. max_text_size=512,
  33. **kwargs,
  34. ):
  35. self.input_size = image_shape[1]
  36. self.min_crop_size = min_crop_size
  37. self.min_crop_side_ratio = min_crop_side_ratio
  38. self.min_text_size = min_text_size
  39. self.max_text_size = max_text_size
  40. def quad_area(self, poly):
  41. """
  42. compute area of a polygon
  43. :param poly:
  44. :return:
  45. """
  46. edge = [
  47. (poly[1][0] - poly[0][0]) * (poly[1][1] + poly[0][1]),
  48. (poly[2][0] - poly[1][0]) * (poly[2][1] + poly[1][1]),
  49. (poly[3][0] - poly[2][0]) * (poly[3][1] + poly[2][1]),
  50. (poly[0][0] - poly[3][0]) * (poly[0][1] + poly[3][1]),
  51. ]
  52. return np.sum(edge) / 2.0
  53. def gen_quad_from_poly(self, poly):
  54. """
  55. Generate min area quad from poly.
  56. """
  57. point_num = poly.shape[0]
  58. min_area_quad = np.zeros((4, 2), dtype=np.float32)
  59. if True:
  60. rect = cv2.minAreaRect(
  61. poly.astype(np.int32)
  62. ) # (center (x,y), (width, height), angle of rotation)
  63. center_point = rect[0]
  64. box = np.array(cv2.boxPoints(rect))
  65. first_point_idx = 0
  66. min_dist = 1e4
  67. for i in range(4):
  68. dist = (
  69. np.linalg.norm(box[(i + 0) % 4] - poly[0])
  70. + np.linalg.norm(box[(i + 1) % 4] - poly[point_num // 2 - 1])
  71. + np.linalg.norm(box[(i + 2) % 4] - poly[point_num // 2])
  72. + np.linalg.norm(box[(i + 3) % 4] - poly[-1])
  73. )
  74. if dist < min_dist:
  75. min_dist = dist
  76. first_point_idx = i
  77. for i in range(4):
  78. min_area_quad[i] = box[(first_point_idx + i) % 4]
  79. return min_area_quad
  80. def check_and_validate_polys(self, polys, tags, xxx_todo_changeme):
  81. """
  82. check so that the text poly is in the same direction,
  83. and also filter some invalid polygons
  84. :param polys:
  85. :param tags:
  86. :return:
  87. """
  88. (h, w) = xxx_todo_changeme
  89. if polys.shape[0] == 0:
  90. return polys, np.array([]), np.array([])
  91. polys[:, :, 0] = np.clip(polys[:, :, 0], 0, w - 1)
  92. polys[:, :, 1] = np.clip(polys[:, :, 1], 0, h - 1)
  93. validated_polys = []
  94. validated_tags = []
  95. hv_tags = []
  96. for poly, tag in zip(polys, tags):
  97. quad = self.gen_quad_from_poly(poly)
  98. p_area = self.quad_area(quad)
  99. if abs(p_area) < 1:
  100. print("invalid poly")
  101. continue
  102. if p_area > 0:
  103. if tag == False:
  104. print("poly in wrong direction")
  105. tag = True # reversed cases should be ignore
  106. poly = poly[(0, 15, 14, 13, 12, 11, 10, 9, 8, 7, 6, 5, 4, 3, 2, 1), :]
  107. quad = quad[(0, 3, 2, 1), :]
  108. len_w = np.linalg.norm(quad[0] - quad[1]) + np.linalg.norm(
  109. quad[3] - quad[2]
  110. )
  111. len_h = np.linalg.norm(quad[0] - quad[3]) + np.linalg.norm(
  112. quad[1] - quad[2]
  113. )
  114. hv_tag = 1
  115. if len_w * 2.0 < len_h:
  116. hv_tag = 0
  117. validated_polys.append(poly)
  118. validated_tags.append(tag)
  119. hv_tags.append(hv_tag)
  120. return np.array(validated_polys), np.array(validated_tags), np.array(hv_tags)
  121. def crop_area(self, im, polys, tags, hv_tags, crop_background=False, max_tries=25):
  122. """
  123. make random crop from the input image
  124. :param im:
  125. :param polys:
  126. :param tags:
  127. :param crop_background:
  128. :param max_tries: 50 -> 25
  129. :return:
  130. """
  131. h, w, _ = im.shape
  132. pad_h = h // 10
  133. pad_w = w // 10
  134. h_array = np.zeros((h + pad_h * 2), dtype=np.int32)
  135. w_array = np.zeros((w + pad_w * 2), dtype=np.int32)
  136. for poly in polys:
  137. poly = np.round(poly, decimals=0).astype(np.int32)
  138. minx = np.min(poly[:, 0])
  139. maxx = np.max(poly[:, 0])
  140. w_array[minx + pad_w : maxx + pad_w] = 1
  141. miny = np.min(poly[:, 1])
  142. maxy = np.max(poly[:, 1])
  143. h_array[miny + pad_h : maxy + pad_h] = 1
  144. # ensure the cropped area not across a text
  145. h_axis = np.where(h_array == 0)[0]
  146. w_axis = np.where(w_array == 0)[0]
  147. if len(h_axis) == 0 or len(w_axis) == 0:
  148. return im, polys, tags, hv_tags
  149. for i in range(max_tries):
  150. xx = np.random.choice(w_axis, size=2)
  151. xmin = np.min(xx) - pad_w
  152. xmax = np.max(xx) - pad_w
  153. xmin = np.clip(xmin, 0, w - 1)
  154. xmax = np.clip(xmax, 0, w - 1)
  155. yy = np.random.choice(h_axis, size=2)
  156. ymin = np.min(yy) - pad_h
  157. ymax = np.max(yy) - pad_h
  158. ymin = np.clip(ymin, 0, h - 1)
  159. ymax = np.clip(ymax, 0, h - 1)
  160. # if xmax - xmin < ARGS.min_crop_side_ratio * w or \
  161. # ymax - ymin < ARGS.min_crop_side_ratio * h:
  162. if xmax - xmin < self.min_crop_size or ymax - ymin < self.min_crop_size:
  163. # area too small
  164. continue
  165. if polys.shape[0] != 0:
  166. poly_axis_in_area = (
  167. (polys[:, :, 0] >= xmin)
  168. & (polys[:, :, 0] <= xmax)
  169. & (polys[:, :, 1] >= ymin)
  170. & (polys[:, :, 1] <= ymax)
  171. )
  172. selected_polys = np.where(np.sum(poly_axis_in_area, axis=1) == 4)[0]
  173. else:
  174. selected_polys = []
  175. if len(selected_polys) == 0:
  176. # no text in this area
  177. if crop_background:
  178. return (
  179. im[ymin : ymax + 1, xmin : xmax + 1, :],
  180. polys[selected_polys],
  181. tags[selected_polys],
  182. hv_tags[selected_polys],
  183. )
  184. else:
  185. continue
  186. im = im[ymin : ymax + 1, xmin : xmax + 1, :]
  187. polys = polys[selected_polys]
  188. tags = tags[selected_polys]
  189. hv_tags = hv_tags[selected_polys]
  190. polys[:, :, 0] -= xmin
  191. polys[:, :, 1] -= ymin
  192. return im, polys, tags, hv_tags
  193. return im, polys, tags, hv_tags
  194. def generate_direction_map(self, poly_quads, direction_map):
  195. """ """
  196. width_list = []
  197. height_list = []
  198. for quad in poly_quads:
  199. quad_w = (
  200. np.linalg.norm(quad[0] - quad[1]) + np.linalg.norm(quad[2] - quad[3])
  201. ) / 2.0
  202. quad_h = (
  203. np.linalg.norm(quad[0] - quad[3]) + np.linalg.norm(quad[2] - quad[1])
  204. ) / 2.0
  205. width_list.append(quad_w)
  206. height_list.append(quad_h)
  207. norm_width = max(sum(width_list) / (len(width_list) + 1e-6), 1.0)
  208. average_height = max(sum(height_list) / (len(height_list) + 1e-6), 1.0)
  209. for quad in poly_quads:
  210. direct_vector_full = ((quad[1] + quad[2]) - (quad[0] + quad[3])) / 2.0
  211. direct_vector = (
  212. direct_vector_full
  213. / (np.linalg.norm(direct_vector_full) + 1e-6)
  214. * norm_width
  215. )
  216. direction_label = tuple(
  217. map(
  218. float,
  219. [direct_vector[0], direct_vector[1], 1.0 / (average_height + 1e-6)],
  220. )
  221. )
  222. cv2.fillPoly(
  223. direction_map,
  224. quad.round().astype(np.int32)[np.newaxis, :, :],
  225. direction_label,
  226. )
  227. return direction_map
  228. def calculate_average_height(self, poly_quads):
  229. """ """
  230. height_list = []
  231. for quad in poly_quads:
  232. quad_h = (
  233. np.linalg.norm(quad[0] - quad[3]) + np.linalg.norm(quad[2] - quad[1])
  234. ) / 2.0
  235. height_list.append(quad_h)
  236. average_height = max(sum(height_list) / len(height_list), 1.0)
  237. return average_height
  238. def generate_tcl_label(
  239. self, hw, polys, tags, ds_ratio, tcl_ratio=0.3, shrink_ratio_of_width=0.15
  240. ):
  241. """
  242. Generate polygon.
  243. """
  244. h, w = hw
  245. h, w = int(h * ds_ratio), int(w * ds_ratio)
  246. polys = polys * ds_ratio
  247. score_map = np.zeros(
  248. (
  249. h,
  250. w,
  251. ),
  252. dtype=np.float32,
  253. )
  254. tbo_map = np.zeros((h, w, 5), dtype=np.float32)
  255. training_mask = np.ones(
  256. (
  257. h,
  258. w,
  259. ),
  260. dtype=np.float32,
  261. )
  262. direction_map = np.ones((h, w, 3)) * np.array([0, 0, 1]).reshape(
  263. [1, 1, 3]
  264. ).astype(np.float32)
  265. for poly_idx, poly_tag in enumerate(zip(polys, tags)):
  266. poly = poly_tag[0]
  267. tag = poly_tag[1]
  268. # generate min_area_quad
  269. min_area_quad, center_point = self.gen_min_area_quad_from_poly(poly)
  270. min_area_quad_h = 0.5 * (
  271. np.linalg.norm(min_area_quad[0] - min_area_quad[3])
  272. + np.linalg.norm(min_area_quad[1] - min_area_quad[2])
  273. )
  274. min_area_quad_w = 0.5 * (
  275. np.linalg.norm(min_area_quad[0] - min_area_quad[1])
  276. + np.linalg.norm(min_area_quad[2] - min_area_quad[3])
  277. )
  278. if (
  279. min(min_area_quad_h, min_area_quad_w) < self.min_text_size * ds_ratio
  280. or min(min_area_quad_h, min_area_quad_w) > self.max_text_size * ds_ratio
  281. ):
  282. continue
  283. if tag:
  284. # continue
  285. cv2.fillPoly(
  286. training_mask, poly.astype(np.int32)[np.newaxis, :, :], 0.15
  287. )
  288. else:
  289. tcl_poly = self.poly2tcl(poly, tcl_ratio)
  290. tcl_quads = self.poly2quads(tcl_poly)
  291. poly_quads = self.poly2quads(poly)
  292. # stcl map
  293. stcl_quads, quad_index = self.shrink_poly_along_width(
  294. tcl_quads,
  295. shrink_ratio_of_width=shrink_ratio_of_width,
  296. expand_height_ratio=1.0 / tcl_ratio,
  297. )
  298. # generate tcl map
  299. cv2.fillPoly(score_map, np.round(stcl_quads).astype(np.int32), 1.0)
  300. # generate tbo map
  301. for idx, quad in enumerate(stcl_quads):
  302. quad_mask = np.zeros((h, w), dtype=np.float32)
  303. quad_mask = cv2.fillPoly(
  304. quad_mask,
  305. np.round(quad[np.newaxis, :, :]).astype(np.int32),
  306. 1.0,
  307. )
  308. tbo_map = self.gen_quad_tbo(
  309. poly_quads[quad_index[idx]], quad_mask, tbo_map
  310. )
  311. return score_map, tbo_map, training_mask
  312. def generate_tvo_and_tco(self, hw, polys, tags, tcl_ratio=0.3, ds_ratio=0.25):
  313. """
  314. Generate tcl map, tvo map and tbo map.
  315. """
  316. h, w = hw
  317. h, w = int(h * ds_ratio), int(w * ds_ratio)
  318. polys = polys * ds_ratio
  319. poly_mask = np.zeros((h, w), dtype=np.float32)
  320. tvo_map = np.ones((9, h, w), dtype=np.float32)
  321. tvo_map[0:-1:2] = np.tile(np.arange(0, w), (h, 1))
  322. tvo_map[1:-1:2] = np.tile(np.arange(0, w), (h, 1)).T
  323. poly_tv_xy_map = np.zeros((8, h, w), dtype=np.float32)
  324. # tco map
  325. tco_map = np.ones((3, h, w), dtype=np.float32)
  326. tco_map[0] = np.tile(np.arange(0, w), (h, 1))
  327. tco_map[1] = np.tile(np.arange(0, w), (h, 1)).T
  328. poly_tc_xy_map = np.zeros((2, h, w), dtype=np.float32)
  329. poly_short_edge_map = np.ones((h, w), dtype=np.float32)
  330. for poly, poly_tag in zip(polys, tags):
  331. if poly_tag == True:
  332. continue
  333. # adjust point order for vertical poly
  334. poly = self.adjust_point(poly)
  335. # generate min_area_quad
  336. min_area_quad, center_point = self.gen_min_area_quad_from_poly(poly)
  337. min_area_quad_h = 0.5 * (
  338. np.linalg.norm(min_area_quad[0] - min_area_quad[3])
  339. + np.linalg.norm(min_area_quad[1] - min_area_quad[2])
  340. )
  341. min_area_quad_w = 0.5 * (
  342. np.linalg.norm(min_area_quad[0] - min_area_quad[1])
  343. + np.linalg.norm(min_area_quad[2] - min_area_quad[3])
  344. )
  345. # generate tcl map and text, 128 * 128
  346. tcl_poly = self.poly2tcl(poly, tcl_ratio)
  347. # generate poly_tv_xy_map
  348. for idx in range(4):
  349. cv2.fillPoly(
  350. poly_tv_xy_map[2 * idx],
  351. np.round(tcl_poly[np.newaxis, :, :]).astype(np.int32),
  352. float(min(max(min_area_quad[idx, 0], 0), w)),
  353. )
  354. cv2.fillPoly(
  355. poly_tv_xy_map[2 * idx + 1],
  356. np.round(tcl_poly[np.newaxis, :, :]).astype(np.int32),
  357. float(min(max(min_area_quad[idx, 1], 0), h)),
  358. )
  359. # generate poly_tc_xy_map
  360. for idx in range(2):
  361. cv2.fillPoly(
  362. poly_tc_xy_map[idx],
  363. np.round(tcl_poly[np.newaxis, :, :]).astype(np.int32),
  364. float(center_point[idx]),
  365. )
  366. # generate poly_short_edge_map
  367. cv2.fillPoly(
  368. poly_short_edge_map,
  369. np.round(tcl_poly[np.newaxis, :, :]).astype(np.int32),
  370. float(max(min(min_area_quad_h, min_area_quad_w), 1.0)),
  371. )
  372. # generate poly_mask and training_mask
  373. cv2.fillPoly(
  374. poly_mask, np.round(tcl_poly[np.newaxis, :, :]).astype(np.int32), 1
  375. )
  376. tvo_map *= poly_mask
  377. tvo_map[:8] -= poly_tv_xy_map
  378. tvo_map[-1] /= poly_short_edge_map
  379. tvo_map = tvo_map.transpose((1, 2, 0))
  380. tco_map *= poly_mask
  381. tco_map[:2] -= poly_tc_xy_map
  382. tco_map[-1] /= poly_short_edge_map
  383. tco_map = tco_map.transpose((1, 2, 0))
  384. return tvo_map, tco_map
  385. def adjust_point(self, poly):
  386. """
  387. adjust point order.
  388. """
  389. point_num = poly.shape[0]
  390. if point_num == 4:
  391. len_1 = np.linalg.norm(poly[0] - poly[1])
  392. len_2 = np.linalg.norm(poly[1] - poly[2])
  393. len_3 = np.linalg.norm(poly[2] - poly[3])
  394. len_4 = np.linalg.norm(poly[3] - poly[0])
  395. if (len_1 + len_3) * 1.5 < (len_2 + len_4):
  396. poly = poly[[1, 2, 3, 0], :]
  397. elif point_num > 4:
  398. vector_1 = poly[0] - poly[1]
  399. vector_2 = poly[1] - poly[2]
  400. cos_theta = np.dot(vector_1, vector_2) / (
  401. np.linalg.norm(vector_1) * np.linalg.norm(vector_2) + 1e-6
  402. )
  403. theta = np.arccos(np.round(cos_theta, decimals=4))
  404. if abs(theta) > (70 / 180 * math.pi):
  405. index = list(range(1, point_num)) + [0]
  406. poly = poly[np.array(index), :]
  407. return poly
  408. def gen_min_area_quad_from_poly(self, poly):
  409. """
  410. Generate min area quad from poly.
  411. """
  412. point_num = poly.shape[0]
  413. min_area_quad = np.zeros((4, 2), dtype=np.float32)
  414. if point_num == 4:
  415. min_area_quad = poly
  416. center_point = np.sum(poly, axis=0) / 4
  417. else:
  418. rect = cv2.minAreaRect(
  419. poly.astype(np.int32)
  420. ) # (center (x,y), (width, height), angle of rotation)
  421. center_point = rect[0]
  422. box = np.array(cv2.boxPoints(rect))
  423. first_point_idx = 0
  424. min_dist = 1e4
  425. for i in range(4):
  426. dist = (
  427. np.linalg.norm(box[(i + 0) % 4] - poly[0])
  428. + np.linalg.norm(box[(i + 1) % 4] - poly[point_num // 2 - 1])
  429. + np.linalg.norm(box[(i + 2) % 4] - poly[point_num // 2])
  430. + np.linalg.norm(box[(i + 3) % 4] - poly[-1])
  431. )
  432. if dist < min_dist:
  433. min_dist = dist
  434. first_point_idx = i
  435. for i in range(4):
  436. min_area_quad[i] = box[(first_point_idx + i) % 4]
  437. return min_area_quad, center_point
  438. def shrink_quad_along_width(self, quad, begin_width_ratio=0.0, end_width_ratio=1.0):
  439. """
  440. Generate shrink_quad_along_width.
  441. """
  442. ratio_pair = np.array(
  443. [[begin_width_ratio], [end_width_ratio]], dtype=np.float32
  444. )
  445. p0_1 = quad[0] + (quad[1] - quad[0]) * ratio_pair
  446. p3_2 = quad[3] + (quad[2] - quad[3]) * ratio_pair
  447. return np.array([p0_1[0], p0_1[1], p3_2[1], p3_2[0]])
  448. def shrink_poly_along_width(
  449. self, quads, shrink_ratio_of_width, expand_height_ratio=1.0
  450. ):
  451. """
  452. shrink poly with given length.
  453. """
  454. upper_edge_list = []
  455. def get_cut_info(edge_len_list, cut_len):
  456. for idx, edge_len in enumerate(edge_len_list):
  457. cut_len -= edge_len
  458. if cut_len <= 0.000001:
  459. ratio = (cut_len + edge_len_list[idx]) / edge_len_list[idx]
  460. return idx, ratio
  461. for quad in quads:
  462. upper_edge_len = np.linalg.norm(quad[0] - quad[1])
  463. upper_edge_list.append(upper_edge_len)
  464. # length of left edge and right edge.
  465. left_length = np.linalg.norm(quads[0][0] - quads[0][3]) * expand_height_ratio
  466. right_length = np.linalg.norm(quads[-1][1] - quads[-1][2]) * expand_height_ratio
  467. shrink_length = (
  468. min(left_length, right_length, sum(upper_edge_list)) * shrink_ratio_of_width
  469. )
  470. # shrinking length
  471. upper_len_left = shrink_length
  472. upper_len_right = sum(upper_edge_list) - shrink_length
  473. left_idx, left_ratio = get_cut_info(upper_edge_list, upper_len_left)
  474. left_quad = self.shrink_quad_along_width(
  475. quads[left_idx], begin_width_ratio=left_ratio, end_width_ratio=1
  476. )
  477. right_idx, right_ratio = get_cut_info(upper_edge_list, upper_len_right)
  478. right_quad = self.shrink_quad_along_width(
  479. quads[right_idx], begin_width_ratio=0, end_width_ratio=right_ratio
  480. )
  481. out_quad_list = []
  482. if left_idx == right_idx:
  483. out_quad_list.append(
  484. [left_quad[0], right_quad[1], right_quad[2], left_quad[3]]
  485. )
  486. else:
  487. out_quad_list.append(left_quad)
  488. for idx in range(left_idx + 1, right_idx):
  489. out_quad_list.append(quads[idx])
  490. out_quad_list.append(right_quad)
  491. return np.array(out_quad_list), list(range(left_idx, right_idx + 1))
  492. def vector_angle(self, A, B):
  493. """
  494. Calculate the angle between vector AB and x-axis positive direction.
  495. """
  496. AB = np.array([B[1] - A[1], B[0] - A[0]])
  497. return np.arctan2(*AB)
  498. def theta_line_cross_point(self, theta, point):
  499. """
  500. Calculate the line through given point and angle in ax + by + c =0 form.
  501. """
  502. x, y = point
  503. cos = np.cos(theta)
  504. sin = np.sin(theta)
  505. return [sin, -cos, cos * y - sin * x]
  506. def line_cross_two_point(self, A, B):
  507. """
  508. Calculate the line through given point A and B in ax + by + c =0 form.
  509. """
  510. angle = self.vector_angle(A, B)
  511. return self.theta_line_cross_point(angle, A)
  512. def average_angle(self, poly):
  513. """
  514. Calculate the average angle between left and right edge in given poly.
  515. """
  516. p0, p1, p2, p3 = poly
  517. angle30 = self.vector_angle(p3, p0)
  518. angle21 = self.vector_angle(p2, p1)
  519. return (angle30 + angle21) / 2
  520. def line_cross_point(self, line1, line2):
  521. """
  522. line1 and line2 in 0=ax+by+c form, compute the cross point of line1 and line2
  523. """
  524. a1, b1, c1 = line1
  525. a2, b2, c2 = line2
  526. d = a1 * b2 - a2 * b1
  527. if d == 0:
  528. # print("line1", line1)
  529. # print("line2", line2)
  530. print("Cross point does not exist")
  531. return np.array([0, 0], dtype=np.float32)
  532. else:
  533. x = (b1 * c2 - b2 * c1) / d
  534. y = (a2 * c1 - a1 * c2) / d
  535. return np.array([x, y], dtype=np.float32)
  536. def quad2tcl(self, poly, ratio):
  537. """
  538. Generate center line by poly clock-wise point. (4, 2)
  539. """
  540. ratio_pair = np.array([[0.5 - ratio / 2], [0.5 + ratio / 2]], dtype=np.float32)
  541. p0_3 = poly[0] + (poly[3] - poly[0]) * ratio_pair
  542. p1_2 = poly[1] + (poly[2] - poly[1]) * ratio_pair
  543. return np.array([p0_3[0], p1_2[0], p1_2[1], p0_3[1]])
  544. def poly2tcl(self, poly, ratio):
  545. """
  546. Generate center line by poly clock-wise point.
  547. """
  548. ratio_pair = np.array([[0.5 - ratio / 2], [0.5 + ratio / 2]], dtype=np.float32)
  549. tcl_poly = np.zeros_like(poly)
  550. point_num = poly.shape[0]
  551. for idx in range(point_num // 2):
  552. point_pair = (
  553. poly[idx] + (poly[point_num - 1 - idx] - poly[idx]) * ratio_pair
  554. )
  555. tcl_poly[idx] = point_pair[0]
  556. tcl_poly[point_num - 1 - idx] = point_pair[1]
  557. return tcl_poly
  558. def gen_quad_tbo(self, quad, tcl_mask, tbo_map):
  559. """
  560. Generate tbo_map for give quad.
  561. """
  562. # upper and lower line function: ax + by + c = 0;
  563. up_line = self.line_cross_two_point(quad[0], quad[1])
  564. lower_line = self.line_cross_two_point(quad[3], quad[2])
  565. quad_h = 0.5 * (
  566. np.linalg.norm(quad[0] - quad[3]) + np.linalg.norm(quad[1] - quad[2])
  567. )
  568. quad_w = 0.5 * (
  569. np.linalg.norm(quad[0] - quad[1]) + np.linalg.norm(quad[2] - quad[3])
  570. )
  571. # average angle of left and right line.
  572. angle = self.average_angle(quad)
  573. xy_in_poly = np.argwhere(tcl_mask == 1)
  574. for y, x in xy_in_poly:
  575. point = (x, y)
  576. line = self.theta_line_cross_point(angle, point)
  577. cross_point_upper = self.line_cross_point(up_line, line)
  578. cross_point_lower = self.line_cross_point(lower_line, line)
  579. ##FIX, offset reverse
  580. upper_offset_x, upper_offset_y = cross_point_upper - point
  581. lower_offset_x, lower_offset_y = cross_point_lower - point
  582. tbo_map[y, x, 0] = upper_offset_y
  583. tbo_map[y, x, 1] = upper_offset_x
  584. tbo_map[y, x, 2] = lower_offset_y
  585. tbo_map[y, x, 3] = lower_offset_x
  586. tbo_map[y, x, 4] = 1.0 / max(min(quad_h, quad_w), 1.0) * 2
  587. return tbo_map
  588. def poly2quads(self, poly):
  589. """
  590. Split poly into quads.
  591. """
  592. quad_list = []
  593. point_num = poly.shape[0]
  594. # point pair
  595. point_pair_list = []
  596. for idx in range(point_num // 2):
  597. point_pair = [poly[idx], poly[point_num - 1 - idx]]
  598. point_pair_list.append(point_pair)
  599. quad_num = point_num // 2 - 1
  600. for idx in range(quad_num):
  601. # reshape and adjust to clock-wise
  602. quad_list.append(
  603. (np.array(point_pair_list)[[idx, idx + 1]]).reshape(4, 2)[[0, 2, 3, 1]]
  604. )
  605. return np.array(quad_list)
  606. def __call__(self, data):
  607. im = data["image"]
  608. text_polys = data["polys"]
  609. text_tags = data["ignore_tags"]
  610. if im is None:
  611. return None
  612. if text_polys.shape[0] == 0:
  613. return None
  614. h, w, _ = im.shape
  615. text_polys, text_tags, hv_tags = self.check_and_validate_polys(
  616. text_polys, text_tags, (h, w)
  617. )
  618. if text_polys.shape[0] == 0:
  619. return None
  620. # set aspect ratio and keep area fix
  621. asp_scales = np.arange(1.0, 1.55, 0.1)
  622. asp_scale = np.random.choice(asp_scales)
  623. if np.random.rand() < 0.5:
  624. asp_scale = 1.0 / asp_scale
  625. asp_scale = math.sqrt(asp_scale)
  626. asp_wx = asp_scale
  627. asp_hy = 1.0 / asp_scale
  628. im = cv2.resize(im, dsize=None, fx=asp_wx, fy=asp_hy)
  629. text_polys[:, :, 0] *= asp_wx
  630. text_polys[:, :, 1] *= asp_hy
  631. h, w, _ = im.shape
  632. if max(h, w) > 2048:
  633. rd_scale = 2048.0 / max(h, w)
  634. im = cv2.resize(im, dsize=None, fx=rd_scale, fy=rd_scale)
  635. text_polys *= rd_scale
  636. h, w, _ = im.shape
  637. if min(h, w) < 16:
  638. return None
  639. # no background
  640. im, text_polys, text_tags, hv_tags = self.crop_area(
  641. im, text_polys, text_tags, hv_tags, crop_background=False
  642. )
  643. if text_polys.shape[0] == 0:
  644. return None
  645. # continue for all ignore case
  646. if np.sum((text_tags * 1.0)) >= text_tags.size:
  647. return None
  648. new_h, new_w, _ = im.shape
  649. if (new_h is None) or (new_w is None):
  650. return None
  651. # resize image
  652. std_ratio = float(self.input_size) / max(new_w, new_h)
  653. rand_scales = np.array(
  654. [0.25, 0.375, 0.5, 0.625, 0.75, 0.875, 1.0, 1.0, 1.0, 1.0, 1.0]
  655. )
  656. rz_scale = std_ratio * np.random.choice(rand_scales)
  657. im = cv2.resize(im, dsize=None, fx=rz_scale, fy=rz_scale)
  658. text_polys[:, :, 0] *= rz_scale
  659. text_polys[:, :, 1] *= rz_scale
  660. # add gaussian blur
  661. if np.random.rand() < 0.1 * 0.5:
  662. ks = np.random.permutation(5)[0] + 1
  663. ks = int(ks / 2) * 2 + 1
  664. im = cv2.GaussianBlur(im, ksize=(ks, ks), sigmaX=0, sigmaY=0)
  665. # add brighter
  666. if np.random.rand() < 0.1 * 0.5:
  667. im = im * (1.0 + np.random.rand() * 0.5)
  668. im = np.clip(im, 0.0, 255.0)
  669. # add darker
  670. if np.random.rand() < 0.1 * 0.5:
  671. im = im * (1.0 - np.random.rand() * 0.5)
  672. im = np.clip(im, 0.0, 255.0)
  673. # Padding the im to [input_size, input_size]
  674. new_h, new_w, _ = im.shape
  675. if min(new_w, new_h) < self.input_size * 0.5:
  676. return None
  677. im_padded = np.ones((self.input_size, self.input_size, 3), dtype=np.float32)
  678. im_padded[:, :, 2] = 0.485 * 255
  679. im_padded[:, :, 1] = 0.456 * 255
  680. im_padded[:, :, 0] = 0.406 * 255
  681. # Random the start position
  682. del_h = self.input_size - new_h
  683. del_w = self.input_size - new_w
  684. sh, sw = 0, 0
  685. if del_h > 1:
  686. sh = int(np.random.rand() * del_h)
  687. if del_w > 1:
  688. sw = int(np.random.rand() * del_w)
  689. # Padding
  690. im_padded[sh : sh + new_h, sw : sw + new_w, :] = im.copy()
  691. text_polys[:, :, 0] += sw
  692. text_polys[:, :, 1] += sh
  693. score_map, border_map, training_mask = self.generate_tcl_label(
  694. (self.input_size, self.input_size), text_polys, text_tags, 0.25
  695. )
  696. # SAST head
  697. tvo_map, tco_map = self.generate_tvo_and_tco(
  698. (self.input_size, self.input_size),
  699. text_polys,
  700. text_tags,
  701. tcl_ratio=0.3,
  702. ds_ratio=0.25,
  703. )
  704. # print("test--------tvo_map shape:", tvo_map.shape)
  705. im_padded[:, :, 2] -= 0.485 * 255
  706. im_padded[:, :, 1] -= 0.456 * 255
  707. im_padded[:, :, 0] -= 0.406 * 255
  708. im_padded[:, :, 2] /= 255.0 * 0.229
  709. im_padded[:, :, 1] /= 255.0 * 0.224
  710. im_padded[:, :, 0] /= 255.0 * 0.225
  711. im_padded = im_padded.transpose((2, 0, 1))
  712. data["image"] = im_padded[::-1, :, :]
  713. data["score_map"] = score_map[np.newaxis, :, :]
  714. data["border_map"] = border_map.transpose((2, 0, 1))
  715. data["training_mask"] = training_mask[np.newaxis, :, :]
  716. data["tvo_map"] = tvo_map.transpose((2, 0, 1))
  717. data["tco_map"] = tco_map.transpose((2, 0, 1))
  718. return data