east_process.py 17 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446
  1. # Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. """
  15. This code is referred from:
  16. https://github.com/songdejia/EAST/blob/master/data_utils.py
  17. """
  18. import math
  19. import cv2
  20. import numpy as np
  21. import json
  22. import sys
  23. import os
  24. __all__ = ["EASTProcessTrain"]
  25. class EASTProcessTrain(object):
  26. def __init__(
  27. self,
  28. image_shape=[512, 512],
  29. background_ratio=0.125,
  30. min_crop_side_ratio=0.1,
  31. min_text_size=10,
  32. **kwargs,
  33. ):
  34. self.input_size = image_shape[1]
  35. self.random_scale = np.array([0.5, 1, 2.0, 3.0])
  36. self.background_ratio = background_ratio
  37. self.min_crop_side_ratio = min_crop_side_ratio
  38. self.min_text_size = min_text_size
  39. def preprocess(self, im):
  40. input_size = self.input_size
  41. im_shape = im.shape
  42. im_size_min = np.min(im_shape[0:2])
  43. im_size_max = np.max(im_shape[0:2])
  44. im_scale = float(input_size) / float(im_size_max)
  45. im = cv2.resize(im, None, None, fx=im_scale, fy=im_scale)
  46. img_mean = [0.485, 0.456, 0.406]
  47. img_std = [0.229, 0.224, 0.225]
  48. # im = im[:, :, ::-1].astype(np.float32)
  49. im = im / 255
  50. im -= img_mean
  51. im /= img_std
  52. new_h, new_w, _ = im.shape
  53. im_padded = np.zeros((input_size, input_size, 3), dtype=np.float32)
  54. im_padded[:new_h, :new_w, :] = im
  55. im_padded = im_padded.transpose((2, 0, 1))
  56. im_padded = im_padded[np.newaxis, :]
  57. return im_padded, im_scale
  58. def rotate_im_poly(self, im, text_polys):
  59. """
  60. rotate image with 90 / 180 / 270 degre
  61. """
  62. im_w, im_h = im.shape[1], im.shape[0]
  63. dst_im = im.copy()
  64. dst_polys = []
  65. rand_degree_ratio = np.random.rand()
  66. rand_degree_cnt = 1
  67. if 0.333 < rand_degree_ratio < 0.666:
  68. rand_degree_cnt = 2
  69. elif rand_degree_ratio > 0.666:
  70. rand_degree_cnt = 3
  71. for i in range(rand_degree_cnt):
  72. dst_im = np.rot90(dst_im)
  73. rot_degree = -90 * rand_degree_cnt
  74. rot_angle = rot_degree * math.pi / 180.0
  75. n_poly = text_polys.shape[0]
  76. cx, cy = 0.5 * im_w, 0.5 * im_h
  77. ncx, ncy = 0.5 * dst_im.shape[1], 0.5 * dst_im.shape[0]
  78. for i in range(n_poly):
  79. wordBB = text_polys[i]
  80. poly = []
  81. for j in range(4):
  82. sx, sy = wordBB[j][0], wordBB[j][1]
  83. dx = (
  84. math.cos(rot_angle) * (sx - cx)
  85. - math.sin(rot_angle) * (sy - cy)
  86. + ncx
  87. )
  88. dy = (
  89. math.sin(rot_angle) * (sx - cx)
  90. + math.cos(rot_angle) * (sy - cy)
  91. + ncy
  92. )
  93. poly.append([dx, dy])
  94. dst_polys.append(poly)
  95. dst_polys = np.array(dst_polys, dtype=np.float32)
  96. return dst_im, dst_polys
  97. def polygon_area(self, poly):
  98. """
  99. compute area of a polygon
  100. :param poly:
  101. :return:
  102. """
  103. edge = [
  104. (poly[1][0] - poly[0][0]) * (poly[1][1] + poly[0][1]),
  105. (poly[2][0] - poly[1][0]) * (poly[2][1] + poly[1][1]),
  106. (poly[3][0] - poly[2][0]) * (poly[3][1] + poly[2][1]),
  107. (poly[0][0] - poly[3][0]) * (poly[0][1] + poly[3][1]),
  108. ]
  109. return np.sum(edge) / 2.0
  110. def check_and_validate_polys(self, polys, tags, img_height, img_width):
  111. """
  112. check so that the text poly is in the same direction,
  113. and also filter some invalid polygons
  114. :param polys:
  115. :param tags:
  116. :return:
  117. """
  118. h, w = img_height, img_width
  119. if polys.shape[0] == 0:
  120. return polys
  121. polys[:, :, 0] = np.clip(polys[:, :, 0], 0, w - 1)
  122. polys[:, :, 1] = np.clip(polys[:, :, 1], 0, h - 1)
  123. validated_polys = []
  124. validated_tags = []
  125. for poly, tag in zip(polys, tags):
  126. p_area = self.polygon_area(poly)
  127. # invalid poly
  128. if abs(p_area) < 1:
  129. continue
  130. if p_area > 0:
  131. #'poly in wrong direction'
  132. if not tag:
  133. tag = True # reversed cases should be ignore
  134. poly = poly[(0, 3, 2, 1), :]
  135. validated_polys.append(poly)
  136. validated_tags.append(tag)
  137. return np.array(validated_polys), np.array(validated_tags)
  138. def draw_img_polys(self, img, polys):
  139. if len(img.shape) == 4:
  140. img = np.squeeze(img, axis=0)
  141. if img.shape[0] == 3:
  142. img = img.transpose((1, 2, 0))
  143. img[:, :, 2] += 123.68
  144. img[:, :, 1] += 116.78
  145. img[:, :, 0] += 103.94
  146. cv2.imwrite("tmp.jpg", img)
  147. img = cv2.imread("tmp.jpg")
  148. for box in polys:
  149. box = box.astype(np.int32).reshape((-1, 1, 2))
  150. cv2.polylines(img, [box], True, color=(255, 255, 0), thickness=2)
  151. import random
  152. ino = random.randint(0, 100)
  153. cv2.imwrite("tmp_%d.jpg" % ino, img)
  154. return
  155. def shrink_poly(self, poly, r):
  156. """
  157. fit a poly inside the origin poly, maybe bugs here...
  158. used for generate the score map
  159. :param poly: the text poly
  160. :param r: r in the paper
  161. :return: the shrunk poly
  162. """
  163. # shrink ratio
  164. R = 0.3
  165. # find the longer pair
  166. dist0 = np.linalg.norm(poly[0] - poly[1])
  167. dist1 = np.linalg.norm(poly[2] - poly[3])
  168. dist2 = np.linalg.norm(poly[0] - poly[3])
  169. dist3 = np.linalg.norm(poly[1] - poly[2])
  170. if dist0 + dist1 > dist2 + dist3:
  171. # first move (p0, p1), (p2, p3), then (p0, p3), (p1, p2)
  172. ## p0, p1
  173. theta = np.arctan2((poly[1][1] - poly[0][1]), (poly[1][0] - poly[0][0]))
  174. poly[0][0] += R * r[0] * np.cos(theta)
  175. poly[0][1] += R * r[0] * np.sin(theta)
  176. poly[1][0] -= R * r[1] * np.cos(theta)
  177. poly[1][1] -= R * r[1] * np.sin(theta)
  178. ## p2, p3
  179. theta = np.arctan2((poly[2][1] - poly[3][1]), (poly[2][0] - poly[3][0]))
  180. poly[3][0] += R * r[3] * np.cos(theta)
  181. poly[3][1] += R * r[3] * np.sin(theta)
  182. poly[2][0] -= R * r[2] * np.cos(theta)
  183. poly[2][1] -= R * r[2] * np.sin(theta)
  184. ## p0, p3
  185. theta = np.arctan2((poly[3][0] - poly[0][0]), (poly[3][1] - poly[0][1]))
  186. poly[0][0] += R * r[0] * np.sin(theta)
  187. poly[0][1] += R * r[0] * np.cos(theta)
  188. poly[3][0] -= R * r[3] * np.sin(theta)
  189. poly[3][1] -= R * r[3] * np.cos(theta)
  190. ## p1, p2
  191. theta = np.arctan2((poly[2][0] - poly[1][0]), (poly[2][1] - poly[1][1]))
  192. poly[1][0] += R * r[1] * np.sin(theta)
  193. poly[1][1] += R * r[1] * np.cos(theta)
  194. poly[2][0] -= R * r[2] * np.sin(theta)
  195. poly[2][1] -= R * r[2] * np.cos(theta)
  196. else:
  197. ## p0, p3
  198. # print poly
  199. theta = np.arctan2((poly[3][0] - poly[0][0]), (poly[3][1] - poly[0][1]))
  200. poly[0][0] += R * r[0] * np.sin(theta)
  201. poly[0][1] += R * r[0] * np.cos(theta)
  202. poly[3][0] -= R * r[3] * np.sin(theta)
  203. poly[3][1] -= R * r[3] * np.cos(theta)
  204. ## p1, p2
  205. theta = np.arctan2((poly[2][0] - poly[1][0]), (poly[2][1] - poly[1][1]))
  206. poly[1][0] += R * r[1] * np.sin(theta)
  207. poly[1][1] += R * r[1] * np.cos(theta)
  208. poly[2][0] -= R * r[2] * np.sin(theta)
  209. poly[2][1] -= R * r[2] * np.cos(theta)
  210. ## p0, p1
  211. theta = np.arctan2((poly[1][1] - poly[0][1]), (poly[1][0] - poly[0][0]))
  212. poly[0][0] += R * r[0] * np.cos(theta)
  213. poly[0][1] += R * r[0] * np.sin(theta)
  214. poly[1][0] -= R * r[1] * np.cos(theta)
  215. poly[1][1] -= R * r[1] * np.sin(theta)
  216. ## p2, p3
  217. theta = np.arctan2((poly[2][1] - poly[3][1]), (poly[2][0] - poly[3][0]))
  218. poly[3][0] += R * r[3] * np.cos(theta)
  219. poly[3][1] += R * r[3] * np.sin(theta)
  220. poly[2][0] -= R * r[2] * np.cos(theta)
  221. poly[2][1] -= R * r[2] * np.sin(theta)
  222. return poly
  223. def generate_quad(self, im_size, polys, tags):
  224. """
  225. Generate quadrangle.
  226. """
  227. h, w = im_size
  228. poly_mask = np.zeros((h, w), dtype=np.uint8)
  229. score_map = np.zeros((h, w), dtype=np.uint8)
  230. # (x1, y1, ..., x4, y4, short_edge_norm)
  231. geo_map = np.zeros((h, w, 9), dtype=np.float32)
  232. # mask used during training, to ignore some hard areas
  233. training_mask = np.ones((h, w), dtype=np.uint8)
  234. for poly_idx, poly_tag in enumerate(zip(polys, tags)):
  235. poly = poly_tag[0]
  236. tag = poly_tag[1]
  237. r = [None, None, None, None]
  238. for i in range(4):
  239. dist1 = np.linalg.norm(poly[i] - poly[(i + 1) % 4])
  240. dist2 = np.linalg.norm(poly[i] - poly[(i - 1) % 4])
  241. r[i] = min(dist1, dist2)
  242. # score map
  243. shrinked_poly = self.shrink_poly(poly.copy(), r).astype(np.int32)[
  244. np.newaxis, :, :
  245. ]
  246. cv2.fillPoly(score_map, shrinked_poly, 1)
  247. cv2.fillPoly(poly_mask, shrinked_poly, poly_idx + 1)
  248. # if the poly is too small, then ignore it during training
  249. poly_h = min(
  250. np.linalg.norm(poly[0] - poly[3]), np.linalg.norm(poly[1] - poly[2])
  251. )
  252. poly_w = min(
  253. np.linalg.norm(poly[0] - poly[1]), np.linalg.norm(poly[2] - poly[3])
  254. )
  255. if min(poly_h, poly_w) < self.min_text_size:
  256. cv2.fillPoly(training_mask, poly.astype(np.int32)[np.newaxis, :, :], 0)
  257. if tag:
  258. cv2.fillPoly(training_mask, poly.astype(np.int32)[np.newaxis, :, :], 0)
  259. xy_in_poly = np.argwhere(poly_mask == (poly_idx + 1))
  260. # geo map.
  261. y_in_poly = xy_in_poly[:, 0]
  262. x_in_poly = xy_in_poly[:, 1]
  263. poly[:, 0] = np.minimum(np.maximum(poly[:, 0], 0), w)
  264. poly[:, 1] = np.minimum(np.maximum(poly[:, 1], 0), h)
  265. for pno in range(4):
  266. geo_channel_beg = pno * 2
  267. geo_map[y_in_poly, x_in_poly, geo_channel_beg] = (
  268. x_in_poly - poly[pno, 0]
  269. )
  270. geo_map[y_in_poly, x_in_poly, geo_channel_beg + 1] = (
  271. y_in_poly - poly[pno, 1]
  272. )
  273. geo_map[y_in_poly, x_in_poly, 8] = 1.0 / max(min(poly_h, poly_w), 1.0)
  274. return score_map, geo_map, training_mask
  275. def crop_area(self, im, polys, tags, crop_background=False, max_tries=50):
  276. """
  277. make random crop from the input image
  278. :param im:
  279. :param polys:
  280. :param tags:
  281. :param crop_background:
  282. :param max_tries:
  283. :return:
  284. """
  285. h, w, _ = im.shape
  286. pad_h = h // 10
  287. pad_w = w // 10
  288. h_array = np.zeros((h + pad_h * 2), dtype=np.int32)
  289. w_array = np.zeros((w + pad_w * 2), dtype=np.int32)
  290. for poly in polys:
  291. poly = np.round(poly, decimals=0).astype(np.int32)
  292. minx = np.min(poly[:, 0])
  293. maxx = np.max(poly[:, 0])
  294. w_array[minx + pad_w : maxx + pad_w] = 1
  295. miny = np.min(poly[:, 1])
  296. maxy = np.max(poly[:, 1])
  297. h_array[miny + pad_h : maxy + pad_h] = 1
  298. # ensure the cropped area not across a text
  299. h_axis = np.where(h_array == 0)[0]
  300. w_axis = np.where(w_array == 0)[0]
  301. if len(h_axis) == 0 or len(w_axis) == 0:
  302. return im, polys, tags
  303. for i in range(max_tries):
  304. xx = np.random.choice(w_axis, size=2)
  305. xmin = np.min(xx) - pad_w
  306. xmax = np.max(xx) - pad_w
  307. xmin = np.clip(xmin, 0, w - 1)
  308. xmax = np.clip(xmax, 0, w - 1)
  309. yy = np.random.choice(h_axis, size=2)
  310. ymin = np.min(yy) - pad_h
  311. ymax = np.max(yy) - pad_h
  312. ymin = np.clip(ymin, 0, h - 1)
  313. ymax = np.clip(ymax, 0, h - 1)
  314. if (
  315. xmax - xmin < self.min_crop_side_ratio * w
  316. or ymax - ymin < self.min_crop_side_ratio * h
  317. ):
  318. # area too small
  319. continue
  320. if polys.shape[0] != 0:
  321. poly_axis_in_area = (
  322. (polys[:, :, 0] >= xmin)
  323. & (polys[:, :, 0] <= xmax)
  324. & (polys[:, :, 1] >= ymin)
  325. & (polys[:, :, 1] <= ymax)
  326. )
  327. selected_polys = np.where(np.sum(poly_axis_in_area, axis=1) == 4)[0]
  328. else:
  329. selected_polys = []
  330. if len(selected_polys) == 0:
  331. # no text in this area
  332. if crop_background:
  333. im = im[ymin : ymax + 1, xmin : xmax + 1, :]
  334. polys = []
  335. tags = []
  336. return im, polys, tags
  337. else:
  338. continue
  339. im = im[ymin : ymax + 1, xmin : xmax + 1, :]
  340. polys = polys[selected_polys]
  341. tags = tags[selected_polys]
  342. polys[:, :, 0] -= xmin
  343. polys[:, :, 1] -= ymin
  344. return im, polys, tags
  345. return im, polys, tags
  346. def crop_background_infor(self, im, text_polys, text_tags):
  347. im, text_polys, text_tags = self.crop_area(
  348. im, text_polys, text_tags, crop_background=True
  349. )
  350. if len(text_polys) > 0:
  351. return None
  352. # pad and resize image
  353. input_size = self.input_size
  354. im, ratio = self.preprocess(im)
  355. score_map = np.zeros((input_size, input_size), dtype=np.float32)
  356. geo_map = np.zeros((input_size, input_size, 9), dtype=np.float32)
  357. training_mask = np.ones((input_size, input_size), dtype=np.float32)
  358. return im, score_map, geo_map, training_mask
  359. def crop_foreground_infor(self, im, text_polys, text_tags):
  360. im, text_polys, text_tags = self.crop_area(
  361. im, text_polys, text_tags, crop_background=False
  362. )
  363. if text_polys.shape[0] == 0:
  364. return None
  365. # continue for all ignore case
  366. if np.sum((text_tags * 1.0)) >= text_tags.size:
  367. return None
  368. # pad and resize image
  369. input_size = self.input_size
  370. im, ratio = self.preprocess(im)
  371. text_polys[:, :, 0] *= ratio
  372. text_polys[:, :, 1] *= ratio
  373. _, _, new_h, new_w = im.shape
  374. # print(im.shape)
  375. # self.draw_img_polys(im, text_polys)
  376. score_map, geo_map, training_mask = self.generate_quad(
  377. (new_h, new_w), text_polys, text_tags
  378. )
  379. return im, score_map, geo_map, training_mask
  380. def __call__(self, data):
  381. im = data["image"]
  382. text_polys = data["polys"]
  383. text_tags = data["ignore_tags"]
  384. if im is None:
  385. return None
  386. if text_polys.shape[0] == 0:
  387. return None
  388. # add rotate cases
  389. if np.random.rand() < 0.5:
  390. im, text_polys = self.rotate_im_poly(im, text_polys)
  391. h, w, _ = im.shape
  392. text_polys, text_tags = self.check_and_validate_polys(
  393. text_polys, text_tags, h, w
  394. )
  395. if text_polys.shape[0] == 0:
  396. return None
  397. # random scale this image
  398. rd_scale = np.random.choice(self.random_scale)
  399. im = cv2.resize(im, dsize=None, fx=rd_scale, fy=rd_scale)
  400. text_polys *= rd_scale
  401. if np.random.rand() < self.background_ratio:
  402. outs = self.crop_background_infor(im, text_polys, text_tags)
  403. else:
  404. outs = self.crop_foreground_infor(im, text_polys, text_tags)
  405. if outs is None:
  406. return None
  407. im, score_map, geo_map, training_mask = outs
  408. score_map = score_map[np.newaxis, ::4, ::4].astype(np.float32)
  409. geo_map = np.swapaxes(geo_map, 1, 2)
  410. geo_map = np.swapaxes(geo_map, 1, 0)
  411. geo_map = geo_map[:, ::4, ::4].astype(np.float32)
  412. training_mask = training_mask[np.newaxis, ::4, ::4]
  413. training_mask = training_mask.astype(np.float32)
  414. data["image"] = im[0]
  415. data["score_map"] = score_map
  416. data["geo_map"] = geo_map
  417. data["training_mask"] = training_mask
  418. return data