pg_process.py 40 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862863864865866867868869870871872873874875876877878879880881882883884885886887888889890891892893894895896897898899900901902903904905906907908909910911912913914915916917918919920921922923924925926927928929930931932933934935936937938939940941942943944945946947948949950951952953954955956957958959960961962963964965966967968969970971972973974975976977978979980981982983984985986987988989990991992993994995996997998999100010011002100310041005100610071008100910101011101210131014101510161017101810191020102110221023102410251026102710281029103010311032103310341035103610371038103910401041104210431044104510461047104810491050105110521053105410551056105710581059106010611062106310641065106610671068106910701071107210731074107510761077107810791080108110821083108410851086108710881089109010911092109310941095109610971098109911001101110211031104110511061107110811091110111111121113111411151116
  1. # copyright (c) 2021 PaddlePaddle Authors. All Rights Reserve.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. import math
  15. import cv2
  16. import numpy as np
  17. from skimage.morphology._skeletonize import thin
  18. from ppocr.utils.e2e_utils.extract_textpoint_fast import (
  19. sort_and_expand_with_direction_v2,
  20. )
  21. __all__ = ["PGProcessTrain"]
  22. class PGProcessTrain(object):
  23. def __init__(
  24. self,
  25. character_dict_path,
  26. max_text_length,
  27. max_text_nums,
  28. tcl_len,
  29. batch_size=14,
  30. use_resize=True,
  31. use_random_crop=False,
  32. min_crop_size=24,
  33. min_text_size=4,
  34. max_text_size=512,
  35. point_gather_mode=None,
  36. **kwargs,
  37. ):
  38. self.tcl_len = tcl_len
  39. self.max_text_length = max_text_length
  40. self.max_text_nums = max_text_nums
  41. self.batch_size = batch_size
  42. if use_random_crop is True:
  43. self.min_crop_size = min_crop_size
  44. self.use_random_crop = use_random_crop
  45. self.min_text_size = min_text_size
  46. self.max_text_size = max_text_size
  47. self.use_resize = use_resize
  48. self.point_gather_mode = point_gather_mode
  49. self.Lexicon_Table = self.get_dict(character_dict_path)
  50. self.pad_num = len(self.Lexicon_Table)
  51. self.img_id = 0
  def get_dict(self, character_dict_path):
      """
      Load the character dictionary as a flat list of single characters.

      The file is read in binary mode; every line is decoded as UTF-8, has
      its leading / trailing line-break characters removed, and the remaining
      text of all lines is concatenated before being split into characters.

      :param character_dict_path: path of the dictionary file
      :return: list of one-character strings in file order
      """
      with open(character_dict_path, "rb") as fin:
          pieces = [
              raw.decode("utf-8").strip("\n").strip("\r\n")
              for raw in fin.readlines()
          ]
      return list("".join(pieces))
  def quad_area(self, poly):
      """
      Signed area of a four-point polygon (trapezoid formula).

      The sign depends on the order in which the four vertices are listed;
      ``check_and_validate_polys`` uses it to detect reversed polygons.

      :param poly: indexable of four (x, y) points
      :return: signed area
      """
      terms = []
      for cur in range(4):
          nxt = (cur + 1) % 4
          terms.append(
              (poly[nxt][0] - poly[cur][0]) * (poly[nxt][1] + poly[cur][1])
          )
      return np.sum(terms) / 2.0
  def gen_quad_from_poly(self, poly):
      """
      Generate min area quad from poly.

      The minimum-area rectangle of ``poly`` (coordinates truncated to int32)
      is computed with OpenCV. Its four corners are then cyclically shifted so
      that corners 0..3 lie closest, in total distance, to the first point,
      the two middle points and the last point of ``poly`` respectively.

      :param poly: (N, 2) array of polygon points
      :return: (4, 2) float32 array of rectangle corners
      """
      point_num = poly.shape[0]
      rect = cv2.minAreaRect(
          poly.astype(np.int32)
      )  # (center (x,y), (width, height), angle of rotation)
      box = np.array(cv2.boxPoints(rect))

      first_point_idx = 0
      # Start from infinity rather than a fixed 1e4: with large coordinates
      # every candidate distance can exceed 1e4, in which case the old
      # sentinel silently kept index 0 instead of the true minimum.
      min_dist = np.inf
      for i in range(4):
          dist = (
              np.linalg.norm(box[(i + 0) % 4] - poly[0])
              + np.linalg.norm(box[(i + 1) % 4] - poly[point_num // 2 - 1])
              + np.linalg.norm(box[(i + 2) % 4] - poly[point_num // 2])
              + np.linalg.norm(box[(i + 3) % 4] - poly[-1])
          )
          if dist < min_dist:
              min_dist = dist
              first_point_idx = i

      min_area_quad = np.zeros((4, 2), dtype=np.float32)
      for i in range(4):
          min_area_quad[i] = box[(first_point_idx + i) % 4]
      return min_area_quad
  def check_and_validate_polys(self, polys, tags, im_size):
      """
      check so that the text poly is in the same direction,
      and also filter some invalid polygons

      ``polys`` is clipped to the image bounds in place. Polygons whose
      min-area quad has an absolute area below 1 are dropped. Polygons whose
      quad area is positive are reversed (first point kept, remaining points
      in opposite order) and their tag is forced to True.

      :param polys: (n, point_num, 2) array of polygons, modified in place
      :param tags: per-polygon ignore flags
      :param im_size: (height, width) of the image
      :return: (validated polys, validated tags, hv tags) as arrays; an hv tag
          is 0 when the quad is more than twice as tall as it is wide, else 1
      """
      (h, w) = im_size
      if polys.shape[0] == 0:
          return polys, np.array([]), np.array([])
      polys[:, :, 0] = np.clip(polys[:, :, 0], 0, w - 1)
      polys[:, :, 1] = np.clip(polys[:, :, 1], 0, h - 1)

      validated_polys = []
      validated_tags = []
      hv_tags = []
      for poly, tag in zip(polys, tags):
          quad = self.gen_quad_from_poly(poly)
          p_area = self.quad_area(quad)
          if abs(p_area) < 1:
              print("invalid poly")
              continue
          if p_area > 0:
              if not tag:
                  print("poly in wrong direction")
                  tag = True  # reversed cases should be ignored
              # Keep point 0 and walk the remaining points backwards. The
              # order is derived from the polygon length so that polygons
              # with any number of points work (a fixed 16-point index raised
              # IndexError for everything else).
              reversed_order = [0] + list(range(poly.shape[0] - 1, 0, -1))
              poly = poly[reversed_order, :]
              quad = quad[(0, 3, 2, 1), :]

          len_w = np.linalg.norm(quad[0] - quad[1]) + np.linalg.norm(
              quad[3] - quad[2]
          )
          len_h = np.linalg.norm(quad[0] - quad[3]) + np.linalg.norm(
              quad[1] - quad[2]
          )
          hv_tag = 0 if len_w * 2.0 < len_h else 1

          validated_polys.append(poly)
          validated_tags.append(tag)
          hv_tags.append(hv_tag)
      return np.array(validated_polys), np.array(validated_tags), np.array(hv_tags)
  140. def crop_area(
  141. self, im, polys, tags, hv_tags, txts, crop_background=False, max_tries=25
  142. ):
  143. """
  144. make random crop from the input image
  145. :param im:
  146. :param polys: [b,4,2]
  147. :param tags:
  148. :param crop_background:
  149. :param max_tries: 50 -> 25
  150. :return:
  151. """
  152. h, w, _ = im.shape
  153. pad_h = h // 10
  154. pad_w = w // 10
  155. h_array = np.zeros((h + pad_h * 2), dtype=np.int32)
  156. w_array = np.zeros((w + pad_w * 2), dtype=np.int32)
  157. for poly in polys:
  158. poly = np.round(poly, decimals=0).astype(np.int32)
  159. minx = np.min(poly[:, 0])
  160. maxx = np.max(poly[:, 0])
  161. w_array[minx + pad_w : maxx + pad_w] = 1
  162. miny = np.min(poly[:, 1])
  163. maxy = np.max(poly[:, 1])
  164. h_array[miny + pad_h : maxy + pad_h] = 1
  165. # ensure the cropped area not across a text
  166. h_axis = np.where(h_array == 0)[0]
  167. w_axis = np.where(w_array == 0)[0]
  168. if len(h_axis) == 0 or len(w_axis) == 0:
  169. return im, polys, tags, hv_tags, txts
  170. for i in range(max_tries):
  171. xx = np.random.choice(w_axis, size=2)
  172. xmin = np.min(xx) - pad_w
  173. xmax = np.max(xx) - pad_w
  174. xmin = np.clip(xmin, 0, w - 1)
  175. xmax = np.clip(xmax, 0, w - 1)
  176. yy = np.random.choice(h_axis, size=2)
  177. ymin = np.min(yy) - pad_h
  178. ymax = np.max(yy) - pad_h
  179. ymin = np.clip(ymin, 0, h - 1)
  180. ymax = np.clip(ymax, 0, h - 1)
  181. if xmax - xmin < self.min_crop_size or ymax - ymin < self.min_crop_size:
  182. continue
  183. if polys.shape[0] != 0:
  184. poly_axis_in_area = (
  185. (polys[:, :, 0] >= xmin)
  186. & (polys[:, :, 0] <= xmax)
  187. & (polys[:, :, 1] >= ymin)
  188. & (polys[:, :, 1] <= ymax)
  189. )
  190. selected_polys = np.where(np.sum(poly_axis_in_area, axis=1) == 4)[0]
  191. else:
  192. selected_polys = []
  193. if len(selected_polys) == 0:
  194. # no text in this area
  195. if crop_background:
  196. txts_tmp = []
  197. for selected_poly in selected_polys:
  198. txts_tmp.append(txts[selected_poly])
  199. txts = txts_tmp
  200. return (
  201. im[ymin : ymax + 1, xmin : xmax + 1, :],
  202. polys[selected_polys],
  203. tags[selected_polys],
  204. hv_tags[selected_polys],
  205. txts,
  206. )
  207. else:
  208. continue
  209. im = im[ymin : ymax + 1, xmin : xmax + 1, :]
  210. polys = polys[selected_polys]
  211. tags = tags[selected_polys]
  212. hv_tags = hv_tags[selected_polys]
  213. txts_tmp = []
  214. for selected_poly in selected_polys:
  215. txts_tmp.append(txts[selected_poly])
  216. txts = txts_tmp
  217. polys[:, :, 0] -= xmin
  218. polys[:, :, 1] -= ymin
  219. return im, polys, tags, hv_tags, txts
  220. return im, polys, tags, hv_tags, txts
  def fit_and_gather_tcl_points_v2(
      self,
      min_area_quad,
      poly,
      max_h,
      max_w,
      fixed_point_num=64,
      img_id=0,
      reference_height=3,
  ):
      """
      Find the center point of poly as key_points, then fit and gather.

      The midpoints of opposite polygon points are joined into a polyline that
      is rasterised on a (max_h, max_w) canvas. The drawn pixels are ordered by
      their projection on the quad's left-center -> right-center axis, thinned
      to at most ``fixed_point_num`` points and, with probability 0.2 when
      ``reference_height >= 3``, jittered along y.

      :return: ``(pos_l, pos_m)``: ``pos_l`` is a (tcl_len, 3) int32 array of
          (img_id, y, x) rows and ``pos_m`` a (tcl_len, 1) float32 mask that is
          1.0 on the rows that hold a gathered point
      """
      total = poly.shape[0]
      center_line = [
          (poly[k] + poly[total - 1 - k]) / 2.0 for k in range(total // 2)
      ]

      canvas = np.zeros(shape=(max_h, max_w), dtype="float32")
      cv2.polylines(canvas, [np.array(center_line).astype("int32")], False, 1.0)
      rows, cols = np.where(canvas > 0)
      xy_text = np.array(list(zip(cols, rows)), dtype="float32")

      # Unit vector of the text axis (left edge center -> right edge center).
      left_center_pt = ((min_area_quad[0] + min_area_quad[3]) / 2.0).reshape(1, 2)
      right_center_pt = ((min_area_quad[1] + min_area_quad[2]) / 2.0).reshape(1, 2)
      axis_vec = right_center_pt - left_center_pt
      proj_unit_vec = axis_vec / (np.linalg.norm(axis_vec) + 1e-6)

      n_pixels = xy_text.shape[0]
      unit_tiled = np.tile(proj_unit_vec, (n_pixels, 1))  # (n, 2)
      origin_tiled = np.tile(left_center_pt, (n_pixels, 1))  # (n, 2)
      proj_value = np.sum((xy_text - origin_tiled) * unit_tiled, axis=1)
      xy_text = xy_text[np.argsort(proj_value)]

      # keep the number of points not greater than fixed_point_num
      pos_info = np.array(xy_text).reshape(-1, 2)[:, ::-1]  # xy -> yx
      n_points = len(pos_info)
      if n_points > fixed_point_num:
          step = n_points * 1.0 / fixed_point_num
          pos_info = pos_info[[int(step * x) for x in range(fixed_point_num)], :]

      keep = int(min(len(pos_info), fixed_point_num))
      if np.random.rand() < 0.2 and reference_height >= 3:
          dl = (np.random.rand(keep) - 0.5) * reference_height * 0.3
          jitter = np.array([1, 0]).reshape([1, 2]) * dl.reshape([keep, 1])
          pos_info += jitter
          pos_info[:, 0] = np.clip(pos_info[:, 0], 0, max_h - 1)
          pos_info[:, 1] = np.clip(pos_info[:, 1], 0, max_w - 1)

      # padding to fixed length
      pos_l = np.zeros((self.tcl_len, 3), dtype=np.int32)
      pos_l[:, 0] = np.ones((self.tcl_len,)) * img_id
      pos_m = np.zeros((self.tcl_len, 1), dtype=np.float32)
      pos_l[:keep, 1:] = np.round(pos_info).astype(np.int32)
      pos_m[:keep] = 1.0
      return pos_l, pos_m
  def fit_and_gather_tcl_points_v3(
      self,
      min_area_quad,
      poly,
      max_h,
      max_w,
      fixed_point_num=64,
      img_id=0,
      reference_height=3,
  ):
      """
      Gather center-line points from the skeleton of the filled text mask.

      ``poly`` (an array of quads, as passed by ``generate_tcl_ctc_label``) is
      filled on a mask at 1/ds_ratio scale, resized back, thinned to a
      skeleton and ordered with ``sort_and_expand_with_direction_v2`` using
      ``self.f_direction``. Gaps between consecutive points are filled by
      linear interpolation, the result is thinned to at most
      ``fixed_point_num`` points and randomly jittered.

      :return: ``None`` when fewer than three skeleton pixels are found,
          otherwise ``(pos_l, pos_m)``: a (tcl_len, 3) int32 array of
          (img_id, y, x) rows and a (tcl_len, 1) float32 validity mask
      """
      ratio = self.ds_ratio
      det_mask = np.zeros((int(max_h / ratio), int(max_w / ratio))).astype(
          np.float32
      )
      # score_big_map
      cv2.fillPoly(det_mask, np.round(poly / ratio).astype(np.int32), 1.0)
      det_mask = cv2.resize(det_mask, dsize=None, fx=ratio, fy=ratio)
      det_mask = np.array(det_mask > 1e-3, dtype="float32")

      f_direction = self.f_direction
      skeleton_map = thin(det_mask.astype(np.uint8))
      _, instance_label_map = cv2.connectedComponents(
          skeleton_map.astype(np.uint8), connectivity=8
      )

      # Only the component labelled 1 is used.
      rows, cols = np.where(instance_label_map == 1)
      pos_list = list(zip(rows, cols))
      if len(pos_list) < 3:
          return None
      ordered = np.array(
          sort_and_expand_with_direction_v2(pos_list, f_direction, det_mask)
      )

      # Fill the gap between each pair of consecutive points so that
      # neighbouring points differ by at most one step per axis.
      inserted = 0
      for index in range(len(ordered) - 1):
          start_pt = ordered[index + inserted]
          end_pt = ordered[index + 1 + inserted]
          stride_y = np.abs(start_pt[0] - end_pt[0])
          stride_x = np.abs(start_pt[1] - end_pt[1])
          max_points = int(max(stride_x, stride_y))

          stride = (start_pt - end_pt) / (max_points)
          extra = max_points - 1
          for i in range(int(extra)):
              ordered = np.insert(
                  ordered,
                  index + i + 1 + inserted,
                  start_pt - (i + 1) * stride,
                  axis=0,
              )
          inserted += extra

      pos_info = np.array(ordered).reshape(-1, 2).astype(np.float32)  # (y, x) rows
      n_points = len(pos_info)
      if n_points > fixed_point_num:
          step = n_points * 1.0 / fixed_point_num
          pos_info = pos_info[[int(step * x) for x in range(fixed_point_num)], :]

      keep = int(min(len(pos_info), fixed_point_num))
      reference_width = (
          np.abs(poly[0, 0, 0] - poly[-1, 1, 0])
          + np.abs(poly[0, 3, 0] - poly[-1, 2, 0])
      ) // 2
      # The draw is always below 1, so the jitter is always applied; the call
      # is kept so the random stream is consumed as before.
      if np.random.rand() < 1:
          dh = (np.random.rand(keep) - 0.5) * reference_height
          offset = np.random.rand() - 0.5
          dw = np.array([[0, offset * reference_width * 0.2]])
          jitter_h = np.array([1, 0]).reshape([1, 2]) * dh.reshape([keep, 1])
          jitter_w = dw.repeat(keep, axis=0)
          pos_info += jitter_h
          pos_info += jitter_w
          pos_info[:, 0] = np.clip(pos_info[:, 0], 0, max_h - 1)
          pos_info[:, 1] = np.clip(pos_info[:, 1], 0, max_w - 1)

      # padding to fixed length
      pos_l = np.zeros((self.tcl_len, 3), dtype=np.int32)
      pos_l[:, 0] = np.ones((self.tcl_len,)) * img_id
      pos_m = np.zeros((self.tcl_len, 1), dtype=np.float32)
      pos_l[:keep, 1:] = np.round(pos_info).astype(np.int32)
      pos_m[:keep] = 1.0
      return pos_l, pos_m
  def generate_direction_map(self, poly_quads, n_char, direction_map):
      """
      Paint a per-quad direction label into ``direction_map``.

      Each quad is filled with a 3-value label: the x and y components of its
      left-to-right unit vector scaled by the average width per character
      (at least 1.0), and the reciprocal of the average quad height
      (height at least 1.0).

      :param poly_quads: iterable of (4, 2) quads of one text instance
      :param n_char: number of characters in the text instance
      :param direction_map: (h, w, 3) map, filled in place
      :return: ``direction_map``
      """
      width_list = [
          (np.linalg.norm(q[0] - q[1]) + np.linalg.norm(q[2] - q[3])) / 2.0
          for q in poly_quads
      ]
      height_list = [
          (np.linalg.norm(q[0] - q[3]) + np.linalg.norm(q[2] - q[1])) / 2.0
          for q in poly_quads
      ]
      norm_width = max(sum(width_list) / n_char, 1.0)
      average_height = max(sum(height_list) / len(height_list), 1.0)

      for quad in poly_quads:
          # Vector from the middle of the left edge to the middle of the right edge.
          direct_vector_full = ((quad[1] + quad[2]) - (quad[0] + quad[3])) / 2.0
          direct_vector = (
              direct_vector_full
              / (np.linalg.norm(direct_vector_full) + 1e-6)
              * norm_width
          )
          direction_label = tuple(
              map(float, [direct_vector[0], direct_vector[1], 1.0 / average_height])
          )
          cv2.fillPoly(
              direction_map,
              quad.round().astype(np.int32)[np.newaxis, :, :],
              direction_label,
          )
      return direction_map
  def calculate_average_height(self, poly_quads):
      """
      Average height of the given quads, never less than 1.0.

      A quad's height is the mean length of its left (0-3) and right (2-1)
      edges.

      :param poly_quads: non-empty iterable of (4, 2) quads
      :return: mean quad height, floored at 1.0
      """
      heights = [
          (np.linalg.norm(q[0] - q[3]) + np.linalg.norm(q[2] - q[1])) / 2.0
          for q in poly_quads
      ]
      return max(sum(heights) / len(heights), 1.0)
  def generate_tcl_ctc_label(
      self,
      h,
      w,
      polys,
      tags,
      text_strs,
      ds_ratio,
      tcl_ratio=0.3,
      shrink_ratio_of_width=0.15,
  ):
      """
      Build the training label maps and point lists for one image.

      All maps except the intermediate full-resolution score map are built at
      ``ds_ratio`` scale. Instances whose min-area quad short side is outside
      ``[min_text_size, max_text_size] * ds_ratio`` are skipped; instances with
      a true tag only down-weight ``training_mask`` to 0.15.

      :param h: image height
      :param w: image width
      :param polys: array of text polygons in image coordinates
      :param tags: per-polygon ignore flags
      :param text_strs: per-polygon transcriptions
      :param ds_ratio: down-sampling ratio applied to sizes and polygons
      :param tcl_ratio: fraction of the text height kept as center-line band
      :param shrink_ratio_of_width: shrink ratio passed to
          ``shrink_poly_along_width``
      :return: tuple ``(score_map, score_label_map, tbo_map, direction_map,
          training_mask, pos_list, pos_mask, label_list,
          score_label_map_text_label_list)``
      """
      self.ds_ratio = ds_ratio
      score_map_big = np.zeros((h, w), dtype=np.float32)
      h, w = int(h * ds_ratio), int(w * ds_ratio)
      polys = polys * ds_ratio

      score_map = np.zeros((h, w), dtype=np.float32)
      score_label_map = np.zeros((h, w), dtype=np.float32)
      tbo_map = np.zeros((h, w, 5), dtype=np.float32)
      training_mask = np.ones((h, w), dtype=np.float32)
      direction_map = np.ones((h, w, 3)) * np.array([0, 0, 1]).reshape(
          [1, 1, 3]
      ).astype(np.float32)

      label_idx = 0
      score_label_map_text_label_list = []
      pos_list, pos_mask, label_list = [], [], []
      for poly_idx, (poly, tag) in enumerate(zip(polys, tags)):
          # generate min_area_quad
          min_area_quad, _ = self.gen_min_area_quad_from_poly(poly)
          min_area_quad_h = 0.5 * (
              np.linalg.norm(min_area_quad[0] - min_area_quad[3])
              + np.linalg.norm(min_area_quad[1] - min_area_quad[2])
          )
          min_area_quad_w = 0.5 * (
              np.linalg.norm(min_area_quad[0] - min_area_quad[1])
              + np.linalg.norm(min_area_quad[2] - min_area_quad[3])
          )

          short_side = min(min_area_quad_h, min_area_quad_w)
          if (
              short_side < self.min_text_size * ds_ratio
              or short_side > self.max_text_size * ds_ratio
          ):
              continue

          if tag:
              # Ignored instance: only lower its weight in the training mask.
              cv2.fillPoly(
                  training_mask, poly.astype(np.int32)[np.newaxis, :, :], 0.15
              )
              continue

          text_label = self.prepare_text_label(
              text_strs[poly_idx], self.Lexicon_Table
          )
          # Characters missing from the dictionary are dropped.
          text_label_index_list = [
              [self.Lexicon_Table.index(ch)]
              for ch in text_label
              if ch in self.Lexicon_Table
          ]
          if len(text_label_index_list) < 1:
              continue

          tcl_poly = self.poly2tcl(poly, tcl_ratio)
          tcl_quads = self.poly2quads(tcl_poly)
          poly_quads = self.poly2quads(poly)

          stcl_quads, quad_index = self.shrink_poly_along_width(
              tcl_quads,
              shrink_ratio_of_width=shrink_ratio_of_width,
              expand_height_ratio=1.0 / tcl_ratio,
          )

          cv2.fillPoly(score_map, np.round(stcl_quads).astype(np.int32), 1.0)
          cv2.fillPoly(
              score_map_big, np.round(stcl_quads / ds_ratio).astype(np.int32), 1.0
          )

          for idx, quad in enumerate(stcl_quads):
              quad_mask = np.zeros((h, w), dtype=np.float32)
              quad_mask = cv2.fillPoly(
                  quad_mask,
                  np.round(quad[np.newaxis, :, :]).astype(np.int32),
                  1.0,
              )
              tbo_map = self.gen_quad_tbo(
                  poly_quads[quad_index[idx]], quad_mask, tbo_map
              )

          # score label map and score_label_map_text_label_list for refine
          if label_idx == 0:
              # Entry 0 holds the index one past the dictionary.
              score_label_map_text_label_list.append([[len(self.Lexicon_Table)]])

          label_idx += 1
          cv2.fillPoly(
              score_label_map, np.round(poly_quads).astype(np.int32), label_idx
          )
          score_label_map_text_label_list.append(text_label_index_list)

          # direction info, fix-me
          n_char = len(text_label_index_list)
          direction_map = self.generate_direction_map(
              poly_quads, n_char, direction_map
          )

          # pos info
          average_shrink_height = self.calculate_average_height(stcl_quads)

          if self.point_gather_mode == "align":
              self.f_direction = direction_map[:, :, :-1].copy()
              pos_res = self.fit_and_gather_tcl_points_v3(
                  min_area_quad,
                  stcl_quads,
                  max_h=h,
                  max_w=w,
                  fixed_point_num=64,
                  img_id=self.img_id,
                  reference_height=average_shrink_height,
              )
              if pos_res is None:
                  continue
              pos_l, pos_m = pos_res[0], pos_res[1]
          else:
              pos_l, pos_m = self.fit_and_gather_tcl_points_v2(
                  min_area_quad,
                  poly,
                  max_h=h,
                  max_w=w,
                  fixed_point_num=64,
                  img_id=self.img_id,
                  reference_height=average_shrink_height,
              )

          # Single-character instances contribute to the maps above but not
          # to the point / label lists.
          if len(text_label_index_list) < 2:
              continue

          pos_list.append(pos_l)
          pos_mask.append(pos_m)
          label_list.append(text_label_index_list)

      # use big score_map for smooth tcl lines
      score_map_big_resized = cv2.resize(
          score_map_big, dsize=None, fx=ds_ratio, fy=ds_ratio
      )
      score_map = np.array(score_map_big_resized > 1e-3, dtype="float32")

      return (
          score_map,
          score_label_map,
          tbo_map,
          direction_map,
          training_mask,
          pos_list,
          pos_mask,
          label_list,
          score_label_map_text_label_list,
      )
  def adjust_point(self, poly):
      """
      adjust point order.

      The points are rotated by one position (point 1 becomes point 0) when,
      for a 4-point poly, edges 1-2 and 3-0 together are more than 1.5 times
      as long as edges 0-1 and 2-3, or, for a longer poly, when the angle
      between the first two edge vectors exceeds 70 degrees.

      :param poly: (N, 2) array of points
      :return: the input poly or a rotated copy
      """
      point_num = poly.shape[0]
      if point_num == 4:
          side_lens = [
              np.linalg.norm(poly[k] - poly[(k + 1) % 4]) for k in range(4)
          ]
          if (side_lens[0] + side_lens[2]) * 1.5 < (side_lens[1] + side_lens[3]):
              poly = np.roll(poly, -1, axis=0)
          return poly

      if point_num > 4:
          vector_1 = poly[0] - poly[1]
          vector_2 = poly[1] - poly[2]
          cos_theta = np.dot(vector_1, vector_2) / (
              np.linalg.norm(vector_1) * np.linalg.norm(vector_2) + 1e-6
          )
          theta = np.arccos(np.round(cos_theta, decimals=4))
          if abs(theta) > (70 / 180 * math.pi):
              poly = np.roll(poly, -1, axis=0)
      return poly
  def gen_min_area_quad_from_poly(self, poly):
      """
      Generate min area quad from poly.

      A 4-point poly is returned as-is (the same array, not a copy) together
      with the mean of its points. Otherwise the minimum-area rectangle of the
      poly (coordinates truncated to int32) is computed with OpenCV and its
      corners are cyclically shifted so that corners 0..3 lie closest, in
      total distance, to the first point, the two middle points and the last
      point of ``poly``.

      :param poly: (N, 2) array of polygon points
      :return: ``(min_area_quad, center_point)``
      """
      point_num = poly.shape[0]
      if point_num == 4:
          return poly, np.sum(poly, axis=0) / 4

      rect = cv2.minAreaRect(
          poly.astype(np.int32)
      )  # (center (x,y), (width, height), angle of rotation)
      center_point = rect[0]
      box = np.array(cv2.boxPoints(rect))

      first_point_idx = 0
      # Start from infinity rather than a fixed 1e4: with large coordinates
      # every candidate distance can exceed 1e4, in which case the old
      # sentinel silently kept index 0 instead of the true minimum.
      min_dist = np.inf
      for i in range(4):
          dist = (
              np.linalg.norm(box[(i + 0) % 4] - poly[0])
              + np.linalg.norm(box[(i + 1) % 4] - poly[point_num // 2 - 1])
              + np.linalg.norm(box[(i + 2) % 4] - poly[point_num // 2])
              + np.linalg.norm(box[(i + 3) % 4] - poly[-1])
          )
          if dist < min_dist:
              min_dist = dist
              first_point_idx = i

      min_area_quad = np.zeros((4, 2), dtype=np.float32)
      for i in range(4):
          min_area_quad[i] = box[(first_point_idx + i) % 4]
      return min_area_quad, center_point
  def shrink_quad_along_width(self, quad, begin_width_ratio=0.0, end_width_ratio=1.0):
      """
      Cut a quad along its width.

      The upper edge (0 -> 1) and the lower edge (3 -> 2) are both sampled at
      ``begin_width_ratio`` and ``end_width_ratio``; the four sampled points
      form the returned quad.

      :param quad: (4, 2) array of corner points
      :param begin_width_ratio: position of the new left side along the width
      :param end_width_ratio: position of the new right side along the width
      :return: (4, 2) array in upper-left, upper-right, lower-right,
          lower-left order
      """
      ratios = np.array(
          [[begin_width_ratio], [end_width_ratio]], dtype=np.float32
      )
      upper_begin, upper_end = quad[0] + (quad[1] - quad[0]) * ratios
      lower_begin, lower_end = quad[3] + (quad[2] - quad[3]) * ratios
      return np.array([upper_begin, upper_end, lower_end, lower_begin])
  647. def shrink_poly_along_width(
  648. self, quads, shrink_ratio_of_width, expand_height_ratio=1.0
  649. ):
  650. """
  651. shrink poly with given length.
  652. """
  653. upper_edge_list = []
  654. def get_cut_info(edge_len_list, cut_len):
  655. for idx, edge_len in enumerate(edge_len_list):
  656. cut_len -= edge_len
  657. if cut_len <= 0.000001:
  658. ratio = (cut_len + edge_len_list[idx]) / edge_len_list[idx]
  659. return idx, ratio
  660. for quad in quads:
  661. upper_edge_len = np.linalg.norm(quad[0] - quad[1])
  662. upper_edge_list.append(upper_edge_len)
  663. # length of left edge and right edge.
  664. left_length = np.linalg.norm(quads[0][0] - quads[0][3]) * expand_height_ratio
  665. right_length = np.linalg.norm(quads[-1][1] - quads[-1][2]) * expand_height_ratio
  666. shrink_length = (
  667. min(left_length, right_length, sum(upper_edge_list)) * shrink_ratio_of_width
  668. )
  669. # shrinking length
  670. upper_len_left = shrink_length
  671. upper_len_right = sum(upper_edge_list) - shrink_length
  672. left_idx, left_ratio = get_cut_info(upper_edge_list, upper_len_left)
  673. left_quad = self.shrink_quad_along_width(
  674. quads[left_idx], begin_width_ratio=left_ratio, end_width_ratio=1
  675. )
  676. right_idx, right_ratio = get_cut_info(upper_edge_list, upper_len_right)
  677. right_quad = self.shrink_quad_along_width(
  678. quads[right_idx], begin_width_ratio=0, end_width_ratio=right_ratio
  679. )
  680. out_quad_list = []
  681. if left_idx == right_idx:
  682. out_quad_list.append(
  683. [left_quad[0], right_quad[1], right_quad[2], left_quad[3]]
  684. )
  685. else:
  686. out_quad_list.append(left_quad)
  687. for idx in range(left_idx + 1, right_idx):
  688. out_quad_list.append(quads[idx])
  689. out_quad_list.append(right_quad)
  690. return np.array(out_quad_list), list(range(left_idx, right_idx + 1))
  def prepare_text_label(self, label_str, Lexicon_Table):
      """
      Normalise a transcription for the given Lexicon_Table.

      :param label_str: transcription string
      :param Lexicon_Table: character dictionary
      :return: ``label_str`` lower-cased when the dictionary has exactly 36
          entries, otherwise ``label_str`` unchanged
      """
      return label_str.lower() if len(Lexicon_Table) == 36 else label_str
  def vector_angle(self, A, B):
      """
      Angle between vector AB and the positive x-axis.

      :param A: start point (x, y)
      :param B: end point (x, y)
      :return: ``arctan2(dy, dx)`` of the vector from A to B, in radians
      """
      delta_y, delta_x = np.array([B[1] - A[1], B[0] - A[0]])
      return np.arctan2(delta_y, delta_x)
  def theta_line_cross_point(self, theta, point):
      """
      Line through ``point`` with direction angle ``theta``.

      :param theta: angle of the line in radians
      :param point: (x, y) point on the line
      :return: coefficients ``[a, b, c]`` of the line ``ax + by + c = 0``
      """
      px, py = point
      sin_t = np.sin(theta)
      cos_t = np.cos(theta)
      return [sin_t, -cos_t, cos_t * py - sin_t * px]
  713. def line_cross_two_point(self, A, B):
  714. """
  715. Calculate the line through given point A and B in ax + by + c =0 form.
  716. """
  717. angle = self.vector_angle(A, B)
  718. return self.theta_line_cross_point(angle, A)
  719. def average_angle(self, poly):
  720. """
  721. Calculate the average angle between left and right edge in given poly.
  722. """
  723. p0, p1, p2, p3 = poly
  724. angle30 = self.vector_angle(p3, p0)
  725. angle21 = self.vector_angle(p2, p1)
  726. return (angle30 + angle21) / 2
  def line_cross_point(self, line1, line2):
      """
      Intersection of two lines given as ``(a, b, c)`` of ``ax + by + c = 0``.

      :param line1: coefficients of the first line
      :param line2: coefficients of the second line
      :return: float32 array ``[x, y]``; ``[0, 0]`` (after printing a
          message) when the determinant of the two lines is exactly zero
      """
      a1, b1, c1 = line1
      a2, b2, c2 = line2
      det = a1 * b2 - a2 * b1
      if det != 0:
          cross_x = (b1 * c2 - b2 * c1) / det
          cross_y = (a2 * c1 - a1 * c2) / det
          return np.array([cross_x, cross_y], dtype=np.float32)
      print("Cross point does not exist")
      return np.array([0, 0], dtype=np.float32)
  def quad2tcl(self, poly, ratio):
      """
      Center-line band of a quad given as four clock-wise points. (4, 2)

      The left edge (0 -> 3) and the right edge (1 -> 2) are sampled at
      ``0.5 - ratio / 2`` and ``0.5 + ratio / 2``.

      :param poly: (4, 2) array of corner points
      :param ratio: fraction of the height kept around the center
      :return: (4, 2) array in upper-left, upper-right, lower-right,
          lower-left order
      """
      band = np.array([[0.5 - ratio / 2], [0.5 + ratio / 2]], dtype=np.float32)
      left_top, left_bottom = poly[0] + (poly[3] - poly[0]) * band
      right_top, right_bottom = poly[1] + (poly[2] - poly[1]) * band
      return np.array([left_top, right_top, right_bottom, left_bottom])
  def poly2tcl(self, poly, ratio):
      """
      Center-line band of a poly given as clock-wise points.

      Point ``i`` is paired with point ``N - 1 - i``; each pair is replaced by
      the two points at ``0.5 - ratio / 2`` and ``0.5 + ratio / 2`` of the
      segment between them.

      :param poly: (N, 2) array of points
      :param ratio: fraction of the height kept around the center
      :return: array with the shape and dtype of ``poly``
      """
      band = np.array([[0.5 - ratio / 2], [0.5 + ratio / 2]], dtype=np.float32)
      tcl_poly = np.zeros_like(poly)
      last = poly.shape[0] - 1
      for upper in range(poly.shape[0] // 2):
          lower = last - upper
          near_upper, near_lower = poly[upper] + (poly[lower] - poly[upper]) * band
          tcl_poly[upper] = near_upper
          tcl_poly[lower] = near_lower
      return tcl_poly
  763. def gen_quad_tbo(self, quad, tcl_mask, tbo_map):
  764. """
  765. Generate tbo_map for give quad.
  766. """
  767. # upper and lower line function: ax + by + c = 0;
  768. up_line = self.line_cross_two_point(quad[0], quad[1])
  769. lower_line = self.line_cross_two_point(quad[3], quad[2])
  770. quad_h = 0.5 * (
  771. np.linalg.norm(quad[0] - quad[3]) + np.linalg.norm(quad[1] - quad[2])
  772. )
  773. quad_w = 0.5 * (
  774. np.linalg.norm(quad[0] - quad[1]) + np.linalg.norm(quad[2] - quad[3])
  775. )
  776. # average angle of left and right line.
  777. angle = self.average_angle(quad)
  778. xy_in_poly = np.argwhere(tcl_mask == 1)
  779. for y, x in xy_in_poly:
  780. point = (x, y)
  781. line = self.theta_line_cross_point(angle, point)
  782. cross_point_upper = self.line_cross_point(up_line, line)
  783. cross_point_lower = self.line_cross_point(lower_line, line)
  784. ##FIX, offset reverse
  785. upper_offset_x, upper_offset_y = cross_point_upper - point
  786. lower_offset_x, lower_offset_y = cross_point_lower - point
  787. tbo_map[y, x, 0] = upper_offset_y
  788. tbo_map[y, x, 1] = upper_offset_x
  789. tbo_map[y, x, 2] = lower_offset_y
  790. tbo_map[y, x, 3] = lower_offset_x
  791. tbo_map[y, x, 4] = 1.0 / max(min(quad_h, quad_w), 1.0) * 2
  792. return tbo_map
  def poly2quads(self, poly):
      """
      Split poly into quads.

      Point ``i`` is paired with point ``N - 1 - i``; every two neighbouring
      pairs form one quad ordered as (upper-left, upper-right, lower-right,
      lower-left), i.e. points ``i, i + 1, N - 2 - i, N - 1 - i``.

      :param poly: (N, 2) array of points
      :return: array of ``N // 2 - 1`` quads, each (4, 2)
      """
      point_num = poly.shape[0]
      last = point_num - 1
      quad_list = [
          np.array(
              [poly[left], poly[left + 1], poly[last - left - 1], poly[last - left]]
          )
          for left in range(point_num // 2 - 1)
      ]
      return np.array(quad_list)
  def rotate_im_poly(self, im, text_polys):
      """
      Rotate an image by one or three quarter turns and move its polygons.

      A single random draw selects the number of ``np.rot90`` turns: three
      when the draw exceeds 0.5, otherwise one. The first four points of
      each polygon are rotated about the source image center by
      ``-90 * turns`` degrees and re-centered on the rotated image.

      Args:
          im: image array; axis 0 is height, axis 1 is width.
          text_polys: array of polygons indexed ``[poly, point, xy]``.

      Returns:
          Tuple ``(rotated_image, polys)`` where ``polys`` is a float32
          array holding four rotated points per polygon.
      """
      src_h, src_w = im.shape[0], im.shape[1]
      quarter_turns = 3 if np.random.rand() > 0.5 else 1
      dst_im = np.rot90(im.copy(), k=quarter_turns)

      rot_angle = (-90 * quarter_turns) * math.pi / 180.0
      cos_a, sin_a = math.cos(rot_angle), math.sin(rot_angle)
      cx, cy = 0.5 * src_w, 0.5 * src_h
      ncx, ncy = 0.5 * dst_im.shape[1], 0.5 * dst_im.shape[0]

      dst_polys = []
      for poly_idx in range(text_polys.shape[0]):
          word_box = text_polys[poly_idx]
          rotated = []
          # Only the first four points of each polygon are used.
          for corner in range(4):
              sx, sy = word_box[corner][0], word_box[corner][1]
              rotated.append(
                  [
                      cos_a * (sx - cx) - sin_a * (sy - cy) + ncx,
                      sin_a * (sx - cx) + cos_a * (sy - cy) + ncy,
                  ]
              )
          dst_polys.append(rotated)
      return dst_im, np.array(dst_polys, dtype=np.float32)
  def __call__(self, data):
      """
      Turn one raw sample into an augmented, padded training sample.

      Reads ``data["image"]``, ``data["polys"]``, ``data["ignore_tags"]``
      and ``data["texts"]``; applies random aspect jitter, an optional
      resize or random crop, random rescaling, optional blur / brighten /
      darken, and pads onto a 512 x 512 canvas at a random offset. Label
      maps are then produced by ``generate_tcl_ctc_label``.

      Args:
          data: mapping holding the four input keys listed above.

      Returns:
          ``data`` with the keys ``images``, ``tcl_maps``,
          ``tcl_label_maps``, ``border_maps``, ``direction_maps``,
          ``training_masks``, ``label_list``, ``pos_list`` and ``pos_mask``
          added, or ``None`` when the sample is rejected (no valid polygons,
          image too small, every polygon ignored, no labels, or a text count
          outside ``1..self.max_text_nums``).
      """
      input_size = 512
      im = data["image"]
      text_polys = data["polys"]
      text_tags = data["ignore_tags"]
      text_strs = data["texts"]

      src_h, src_w, _ = im.shape
      text_polys, text_tags, hv_tags = self.check_and_validate_polys(
          text_polys, text_tags, (src_h, src_w)
      )
      if text_polys.shape[0] <= 0:
          return None

      # Aspect-ratio jitter: x is scaled by s and y by 1 / s, so the area
      # stays the same.
      asp_scale = np.random.choice(np.arange(1.0, 1.55, 0.1))
      if np.random.rand() < 0.5:
          asp_scale = 1.0 / asp_scale
      asp_wx = math.sqrt(asp_scale)
      asp_hy = 1.0 / asp_wx
      im = cv2.resize(im, dsize=None, fx=asp_wx, fy=asp_hy)
      text_polys[:, :, 0] *= asp_wx
      text_polys[:, :, 1] *= asp_hy

      if self.use_resize is True:
          # Bring the longer side into the range [200, 512].
          ori_h, ori_w, _ = im.shape
          longest = max(ori_h, ori_w)
          if longest < 200:
              target = 200
          elif longest > 512:
              target = 512
          else:
              target = None
          if target is not None:
              ratio = target / longest
              im = cv2.resize(im, (int(ori_w * ratio), int(ori_h * ratio)))
              text_polys[:, :, 0] *= ratio
              text_polys[:, :, 1] *= ratio
      elif self.use_random_crop is True:
          cur_h, cur_w, _ = im.shape
          if max(cur_h, cur_w) > 2048:
              rd_scale = 2048.0 / max(cur_h, cur_w)
              im = cv2.resize(im, dsize=None, fx=rd_scale, fy=rd_scale)
              text_polys *= rd_scale
          cur_h, cur_w, _ = im.shape
          if min(cur_h, cur_w) < 16:
              return None
          # Crop without allowing a background-only result.
          im, text_polys, text_tags, hv_tags, text_strs = self.crop_area(
              im, text_polys, text_tags, hv_tags, text_strs, crop_background=False
          )

      if text_polys.shape[0] == 0:
          return None
      # Reject the sample when every remaining polygon is flagged ignore.
      if np.sum((text_tags * 1.0)) >= text_tags.size:
          return None
      new_h, new_w, _ = im.shape
      # Defensive guard carried over unchanged.
      if new_h is None or new_w is None:
          return None

      # Scale the longer side to input_size times a random factor.
      std_ratio = float(input_size) / max(new_w, new_h)
      rand_scales = np.array(
          [0.25, 0.375, 0.5, 0.625, 0.75, 0.875, 1.0, 1.0, 1.0, 1.0, 1.0]
      )
      rz_scale = std_ratio * np.random.choice(rand_scales)
      im = cv2.resize(im, dsize=None, fx=rz_scale, fy=rz_scale)
      text_polys[:, :, 0] *= rz_scale
      text_polys[:, :, 1] *= rz_scale

      # Gaussian blur with an odd kernel size of 1, 3 or 5.
      if np.random.rand() < 0.1 * 0.5:
          ks = np.random.permutation(5)[0] + 1
          ks = int(ks / 2) * 2 + 1
          im = cv2.GaussianBlur(im, ksize=(ks, ks), sigmaX=0, sigmaY=0)
      # Brighten by up to 50 percent.
      if np.random.rand() < 0.1 * 0.5:
          im = np.clip(im * (1.0 + np.random.rand() * 0.5), 0.0, 255.0)
      # Darken by up to 50 percent.
      if np.random.rand() < 0.1 * 0.5:
          im = np.clip(im * (1.0 - np.random.rand() * 0.5), 0.0, 255.0)

      new_h, new_w, _ = im.shape
      if min(new_w, new_h) < input_size * 0.5:
          return None

      # Rows of (channel index, mean, std); presumably ImageNet statistics
      # for a BGR image - confirm against the data loader.
      channel_stats = ((2, 0.485, 0.229), (1, 0.456, 0.224), (0, 0.406, 0.225))

      # Canvas pre-filled with the per-channel mean, so padding becomes
      # zero after normalisation.
      im_padded = np.ones((input_size, input_size, 3), dtype=np.float32)
      for ch, mean, _ in channel_stats:
          im_padded[:, :, ch] = mean * 255

      # Random top-left offset of the image inside the canvas.
      del_h = input_size - new_h
      del_w = input_size - new_w
      sh = int(np.random.rand() * del_h) if del_h > 1 else 0
      sw = int(np.random.rand() * del_w) if del_w > 1 else 0
      im_padded[sh : sh + new_h, sw : sw + new_w, :] = im.copy()
      text_polys[:, :, 0] += sw
      text_polys[:, :, 1] += sh

      labels = self.generate_tcl_ctc_label(
          input_size, input_size, text_polys, text_tags, text_strs, 0.25
      )
      (
          score_map,
          score_label_map,
          border_map,
          direction_map,
          training_mask,
          pos_list,
          pos_mask,
          label_list,
          score_label_map_text_label,
      ) = labels
      if len(label_list) == 0:  # drop negative samples
          return None

      # Filler entries used to pad the per-text lists to a fixed length.
      pos_list_temp = np.zeros([64, 3])
      pos_mask_temp = np.zeros([64, 1])
      label_list_temp = np.zeros([self.max_text_length, 1]) + self.pad_num

      # Truncate over-long labels; pad short ones in place with pad_num.
      for i, label in enumerate(label_list):
          missing = self.max_text_length - len(label)
          if missing < 0:
              label_list[i] = label[: self.max_text_length]
          else:
              label.extend([self.pad_num] for _ in range(missing))
      label_list[:] = [np.array(label) for label in label_list]

      if len(pos_list) <= 0 or len(pos_list) > self.max_text_nums:
          return None
      for _ in range(self.max_text_nums - len(pos_list)):
          pos_list.append(pos_list_temp)
          pos_mask.append(pos_mask_temp)
          label_list.append(label_list_temp)

      # Advance the per-batch image counter, wrapping at batch_size.
      self.img_id = 0 if self.img_id == self.batch_size - 1 else self.img_id + 1

      for ch, mean, std in channel_stats:
          im_padded[:, :, ch] -= mean * 255
          im_padded[:, :, ch] /= 255.0 * std
      # HWC -> CHW, then reverse the channel axis.
      images = im_padded.transpose((2, 0, 1))[::-1, :, :]

      data.update(
          {
              "images": images,
              "tcl_maps": score_map[np.newaxis, :, :],
              "tcl_label_maps": score_label_map[np.newaxis, :, :],
              "border_maps": border_map.transpose((2, 0, 1)),
              "direction_maps": direction_map.transpose((2, 0, 1)),
              "training_masks": training_mask[np.newaxis, :, :],
              "label_list": np.array(label_list),
              "pos_list": np.array(pos_list),
              "pos_mask": np.array(pos_mask),
          }
      )
      return data