utils.py 13 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387
  1. import numpy as np
  2. import cv2
  3. import argparse
  4. import math
  5. from PIL import Image, ImageDraw, ImageFont
  6. from pathlib import Path
# Directory containing this file (absolute, symlinks resolved). Used below as
# the base for the default font, model, dictionary and output paths.
module_dir = Path(__file__).resolve().parent
  9. def get_rotate_crop_image(img, points):
  10. """
  11. img_height, img_width = img.shape[0:2]
  12. left = int(np.min(points[:, 0]))
  13. right = int(np.max(points[:, 0]))
  14. top = int(np.min(points[:, 1]))
  15. bottom = int(np.max(points[:, 1]))
  16. img_crop = img[top:bottom, left:right, :].copy()
  17. points[:, 0] = points[:, 0] - left
  18. points[:, 1] = points[:, 1] - top
  19. """
  20. assert len(points) == 4, "shape of points must be 4*2"
  21. img_crop_width = int(
  22. max(
  23. np.linalg.norm(points[0] - points[1]), np.linalg.norm(points[2] - points[3])
  24. )
  25. )
  26. img_crop_height = int(
  27. max(
  28. np.linalg.norm(points[0] - points[3]), np.linalg.norm(points[1] - points[2])
  29. )
  30. )
  31. pts_std = np.float32(
  32. [
  33. [0, 0],
  34. [img_crop_width, 0],
  35. [img_crop_width, img_crop_height],
  36. [0, img_crop_height],
  37. ]
  38. )
  39. M = cv2.getPerspectiveTransform(points, pts_std)
  40. dst_img = cv2.warpPerspective(
  41. img,
  42. M,
  43. (img_crop_width, img_crop_height),
  44. borderMode=cv2.BORDER_REPLICATE,
  45. flags=cv2.INTER_CUBIC,
  46. )
  47. dst_img_height, dst_img_width = dst_img.shape[0:2]
  48. if dst_img_height * 1.0 / dst_img_width >= 1.5:
  49. dst_img = np.rot90(dst_img)
  50. return dst_img
  51. def get_minarea_rect_crop(img, points):
  52. bounding_box = cv2.minAreaRect(np.array(points).astype(np.int32))
  53. points = sorted(list(cv2.boxPoints(bounding_box)), key=lambda x: x[0])
  54. index_a, index_b, index_c, index_d = 0, 1, 2, 3
  55. if points[1][1] > points[0][1]:
  56. index_a = 0
  57. index_d = 1
  58. else:
  59. index_a = 1
  60. index_d = 0
  61. if points[3][1] > points[2][1]:
  62. index_b = 2
  63. index_c = 3
  64. else:
  65. index_b = 3
  66. index_c = 2
  67. box = [points[index_a], points[index_b], points[index_c], points[index_d]]
  68. crop_img = get_rotate_crop_image(img, np.array(box))
  69. return crop_img
  70. def resize_img(img, input_size=600):
  71. """
  72. resize img and limit the longest side of the image to input_size
  73. """
  74. img = np.array(img)
  75. im_shape = img.shape
  76. im_size_max = np.max(im_shape[0:2])
  77. im_scale = float(input_size) / float(im_size_max)
  78. img = cv2.resize(img, None, None, fx=im_scale, fy=im_scale)
  79. return img
  80. def str_count(s):
  81. """
  82. Count the number of Chinese characters,
  83. a single English character and a single number
  84. equal to half the length of Chinese characters.
  85. args:
  86. s(string): the input of string
  87. return(int):
  88. the number of Chinese characters
  89. """
  90. import string
  91. count_zh = count_pu = 0
  92. s_len = len(str(s))
  93. en_dg_count = 0
  94. for c in str(s):
  95. if c in string.ascii_letters or c.isdigit() or c.isspace():
  96. en_dg_count += 1
  97. elif c.isalpha():
  98. count_zh += 1
  99. else:
  100. count_pu += 1
  101. return s_len - math.ceil(en_dg_count / 2)
  102. def text_visual(
  103. texts,
  104. scores,
  105. img_h=400,
  106. img_w=600,
  107. threshold=0.0,
  108. font_path=str(module_dir / "fonts/simfang.ttf"),
  109. ):
  110. """
  111. create new blank img and draw txt on it
  112. args:
  113. texts(list): the text will be draw
  114. scores(list|None): corresponding score of each txt
  115. img_h(int): the height of blank img
  116. img_w(int): the width of blank img
  117. font_path: the path of font which is used to draw text
  118. return(array):
  119. """
  120. if scores is not None:
  121. assert len(texts) == len(
  122. scores
  123. ), "The number of txts and corresponding scores must match"
  124. def create_blank_img():
  125. blank_img = np.ones(shape=[img_h, img_w], dtype=np.int8) * 255
  126. blank_img[:, img_w - 1 :] = 0
  127. blank_img = Image.fromarray(blank_img).convert("RGB")
  128. draw_txt = ImageDraw.Draw(blank_img)
  129. return blank_img, draw_txt
  130. blank_img, draw_txt = create_blank_img()
  131. font_size = 20
  132. txt_color = (0, 0, 0)
  133. font = ImageFont.truetype(font_path, font_size, encoding="utf-8")
  134. gap = font_size + 5
  135. txt_img_list = []
  136. count, index = 1, 0
  137. for idx, txt in enumerate(texts):
  138. index += 1
  139. if scores[idx] < threshold or math.isnan(scores[idx]):
  140. index -= 1
  141. continue
  142. first_line = True
  143. while str_count(txt) >= img_w // font_size - 4:
  144. tmp = txt
  145. txt = tmp[: img_w // font_size - 4]
  146. if first_line:
  147. new_txt = str(index) + ": " + txt
  148. first_line = False
  149. else:
  150. new_txt = " " + txt
  151. draw_txt.text((0, gap * count), new_txt, txt_color, font=font)
  152. txt = tmp[img_w // font_size - 4 :]
  153. if count >= img_h // gap - 1:
  154. txt_img_list.append(np.array(blank_img))
  155. blank_img, draw_txt = create_blank_img()
  156. count = 0
  157. count += 1
  158. if first_line:
  159. new_txt = str(index) + ": " + txt + " " + "%.3f" % (scores[idx])
  160. else:
  161. new_txt = " " + txt + " " + "%.3f" % (scores[idx])
  162. draw_txt.text((0, gap * count), new_txt, txt_color, font=font)
  163. # whether add new blank img or not
  164. if count >= img_h // gap - 1 and idx + 1 < len(texts):
  165. txt_img_list.append(np.array(blank_img))
  166. blank_img, draw_txt = create_blank_img()
  167. count = 0
  168. count += 1
  169. txt_img_list.append(np.array(blank_img))
  170. if len(txt_img_list) == 1:
  171. blank_img = np.array(txt_img_list[0])
  172. else:
  173. blank_img = np.concatenate(txt_img_list, axis=1)
  174. return np.array(blank_img)
  175. def draw_ocr(
  176. image,
  177. boxes,
  178. txts=None,
  179. scores=None,
  180. drop_score=0.5,
  181. font_path=str(module_dir / "fonts/simfang.ttf"),
  182. ):
  183. """
  184. Visualize the results of OCR detection and recognition
  185. args:
  186. image(Image|array): RGB image
  187. boxes(list): boxes with shape(N, 4, 2)
  188. txts(list): the texts
  189. scores(list): txxs corresponding scores
  190. drop_score(float): only scores greater than drop_threshold will be visualized
  191. font_path: the path of font which is used to draw text
  192. return(array):
  193. the visualized img
  194. """
  195. if scores is None:
  196. scores = [1] * len(boxes)
  197. box_num = len(boxes)
  198. for i in range(box_num):
  199. if scores is not None and (scores[i] < drop_score or math.isnan(scores[i])):
  200. continue
  201. box = np.reshape(np.array(boxes[i]), [-1, 1, 2]).astype(np.int64)
  202. image = cv2.polylines(np.array(image), [box], True, (255, 0, 0), 2)
  203. if txts is not None:
  204. img = np.array(resize_img(image, input_size=600))
  205. txt_img = text_visual(
  206. txts,
  207. scores,
  208. img_h=img.shape[0],
  209. img_w=600,
  210. threshold=drop_score,
  211. font_path=font_path,
  212. )
  213. img = np.concatenate([np.array(img), np.array(txt_img)], axis=1)
  214. return img
  215. return image
  216. def base64_to_cv2(b64str):
  217. import base64
  218. data = base64.b64decode(b64str.encode("utf8"))
  219. data = np.frombuffer(data, np.uint8)
  220. data = cv2.imdecode(data, cv2.IMREAD_COLOR)
  221. return data
  222. def str2bool(v):
  223. return v.lower() in ("true", "t", "1")
  224. def infer_args():
  225. parser = argparse.ArgumentParser()
  226. # params for prediction engine
  227. parser.add_argument("--use_gpu", type=str2bool, default=True)
  228. parser.add_argument("--use_xpu", type=str2bool, default=False)
  229. parser.add_argument("--use_npu", type=str2bool, default=False)
  230. parser.add_argument("--ir_optim", type=str2bool, default=True)
  231. parser.add_argument("--use_tensorrt", type=str2bool, default=False)
  232. parser.add_argument("--min_subgraph_size", type=int, default=15)
  233. parser.add_argument("--precision", type=str, default="fp32")
  234. parser.add_argument("--gpu_mem", type=int, default=500)
  235. parser.add_argument("--gpu_id", type=int, default=0)
  236. # params for text detector
  237. parser.add_argument("--image_dir", type=str)
  238. parser.add_argument("--page_num", type=int, default=0)
  239. parser.add_argument("--det_algorithm", type=str, default="DB")
  240. parser.add_argument(
  241. "--det_model_dir",
  242. type=str,
  243. default=str(module_dir / "models/ppocrv5/det/det.onnx"),
  244. )
  245. parser.add_argument("--det_limit_side_len", type=float, default=960)
  246. parser.add_argument("--det_limit_type", type=str, default="max")
  247. parser.add_argument("--det_box_type", type=str, default="quad")
  248. # DB parmas
  249. parser.add_argument("--det_db_thresh", type=float, default=0.3)
  250. parser.add_argument("--det_db_box_thresh", type=float, default=0.6)
  251. parser.add_argument("--det_db_unclip_ratio", type=float, default=1.5)
  252. parser.add_argument("--max_batch_size", type=int, default=10)
  253. parser.add_argument("--use_dilation", type=str2bool, default=False)
  254. parser.add_argument("--det_db_score_mode", type=str, default="fast")
  255. # EAST parmas
  256. parser.add_argument("--det_east_score_thresh", type=float, default=0.8)
  257. parser.add_argument("--det_east_cover_thresh", type=float, default=0.1)
  258. parser.add_argument("--det_east_nms_thresh", type=float, default=0.2)
  259. # SAST parmas
  260. parser.add_argument("--det_sast_score_thresh", type=float, default=0.5)
  261. parser.add_argument("--det_sast_nms_thresh", type=float, default=0.2)
  262. # PSE parmas
  263. parser.add_argument("--det_pse_thresh", type=float, default=0)
  264. parser.add_argument("--det_pse_box_thresh", type=float, default=0.85)
  265. parser.add_argument("--det_pse_min_area", type=float, default=16)
  266. parser.add_argument("--det_pse_scale", type=int, default=1)
  267. # FCE parmas
  268. parser.add_argument("--scales", type=list, default=[8, 16, 32])
  269. parser.add_argument("--alpha", type=float, default=1.0)
  270. parser.add_argument("--beta", type=float, default=1.0)
  271. parser.add_argument("--fourier_degree", type=int, default=5)
  272. # params for text recognizer
  273. parser.add_argument("--rec_algorithm", type=str, default="SVTR_LCNet")
  274. parser.add_argument(
  275. "--rec_model_dir",
  276. type=str,
  277. default=str(module_dir / "models/ppocrv5/rec/rec.onnx"),
  278. )
  279. parser.add_argument("--rec_image_inverse", type=str2bool, default=True)
  280. parser.add_argument("--rec_image_shape", type=str, default="3, 48, 320")
  281. parser.add_argument("--rec_batch_num", type=int, default=6)
  282. parser.add_argument("--max_text_length", type=int, default=25)
  283. parser.add_argument(
  284. "--rec_char_dict_path",
  285. type=str,
  286. default=str(module_dir / "models/ppocrv5/ppocrv5_dict.txt"),
  287. )
  288. parser.add_argument("--use_space_char", type=str2bool, default=True)
  289. parser.add_argument(
  290. "--vis_font_path", type=str, default=str(module_dir / "fonts/simfang.ttf")
  291. )
  292. parser.add_argument("--drop_score", type=float, default=0.5)
  293. # params for e2e
  294. parser.add_argument("--e2e_algorithm", type=str, default="PGNet")
  295. parser.add_argument("--e2e_model_dir", type=str)
  296. parser.add_argument("--e2e_limit_side_len", type=float, default=768)
  297. parser.add_argument("--e2e_limit_type", type=str, default="max")
  298. # PGNet parmas
  299. parser.add_argument("--e2e_pgnet_score_thresh", type=float, default=0.5)
  300. parser.add_argument(
  301. "--e2e_char_dict_path",
  302. type=str,
  303. default=str(module_dir / "ppocr/utils/ic15_dict.txt"),
  304. )
  305. parser.add_argument("--e2e_pgnet_valid_set", type=str, default="totaltext")
  306. parser.add_argument("--e2e_pgnet_mode", type=str, default="fast")
  307. # params for text classifier
  308. parser.add_argument("--use_angle_cls", type=str2bool, default=False)
  309. parser.add_argument(
  310. "--cls_model_dir",
  311. type=str,
  312. default=str(module_dir / "models/ppocrv4/cls/cls.onnx"),
  313. )
  314. parser.add_argument("--cls_image_shape", type=str, default="3, 48, 192")
  315. parser.add_argument("--label_list", type=list, default=["0", "180"])
  316. parser.add_argument("--cls_batch_num", type=int, default=6)
  317. parser.add_argument("--cls_thresh", type=float, default=0.9)
  318. parser.add_argument("--enable_mkldnn", type=str2bool, default=False)
  319. parser.add_argument("--cpu_threads", type=int, default=10)
  320. parser.add_argument("--use_pdserving", type=str2bool, default=False)
  321. parser.add_argument("--warmup", type=str2bool, default=False)
  322. # SR parmas
  323. parser.add_argument("--sr_model_dir", type=str)
  324. parser.add_argument("--sr_image_shape", type=str, default="3, 32, 128")
  325. parser.add_argument("--sr_batch_num", type=int, default=1)
  326. #
  327. parser.add_argument(
  328. "--draw_img_save_dir", type=str, default=str(module_dir / "inference_results")
  329. )
  330. parser.add_argument("--save_crop_res", type=str2bool, default=False)
  331. parser.add_argument(
  332. "--crop_res_save_dir", type=str, default=str(module_dir / "output")
  333. )
  334. # multi-process
  335. parser.add_argument("--use_mp", type=str2bool, default=False)
  336. parser.add_argument("--total_process_num", type=int, default=1)
  337. parser.add_argument("--process_id", type=int, default=0)
  338. parser.add_argument("--benchmark", type=str2bool, default=False)
  339. parser.add_argument(
  340. "--save_log_path", type=str, default=str(module_dir / "log_output/")
  341. )
  342. parser.add_argument("--show_log", type=str2bool, default=True)
  343. parser.add_argument("--use_onnx", type=str2bool, default=False)
  344. return parser