utility.py 38 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039
  1. # Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. import argparse
  15. import os
  16. import sys
  17. import cv2
  18. import numpy as np
  19. import paddle
  20. import PIL
  21. from PIL import Image, ImageDraw, ImageFont
  22. import math
  23. from paddle import inference
  24. import random
  25. import yaml
  26. from ppocr.utils.logging import get_logger
  27. def str2bool(v):
  28. return v.lower() in ("true", "yes", "t", "y", "1")
  29. def str2int_tuple(v):
  30. return tuple([int(i.strip()) for i in v.split(",")])
def init_args():
    """Build the command-line parser shared by the inference tools.

    The parser is returned un-parsed so that callers can register further
    options before calling ``parse_args()`` on it.

    Returns:
        argparse.ArgumentParser: parser with every option below registered.
    """
    parser = argparse.ArgumentParser()
    # params for prediction engine
    parser.add_argument("--use_gpu", type=str2bool, default=True)
    parser.add_argument("--use_xpu", type=str2bool, default=False)
    parser.add_argument("--use_npu", type=str2bool, default=False)
    parser.add_argument("--use_mlu", type=str2bool, default=False)
    parser.add_argument("--use_metax_gpu", type=str2bool, default=False)
    parser.add_argument(
        "--use_gcu",
        type=str2bool,
        default=False,
        help="Use Enflame GCU(General Compute Unit)",
    )
    parser.add_argument("--ir_optim", type=str2bool, default=True)
    parser.add_argument("--use_tensorrt", type=str2bool, default=False)
    parser.add_argument("--min_subgraph_size", type=int, default=15)
    parser.add_argument("--precision", type=str, default="fp32")
    parser.add_argument("--gpu_mem", type=int, default=500)
    parser.add_argument("--gpu_id", type=int, default=0)
    # params for text detector
    parser.add_argument("--image_dir", type=str)
    parser.add_argument("--page_num", type=int, default=0)
    parser.add_argument("--det_algorithm", type=str, default="DB")
    parser.add_argument("--det_model_dir", type=str)
    parser.add_argument("--det_limit_side_len", type=float, default=960)
    parser.add_argument("--det_limit_type", type=str, default="max")
    parser.add_argument("--det_box_type", type=str, default="quad")
    # DB params
    parser.add_argument("--det_db_thresh", type=float, default=0.3)
    parser.add_argument("--det_db_box_thresh", type=float, default=0.6)
    parser.add_argument("--det_db_unclip_ratio", type=float, default=1.5)
    parser.add_argument("--max_batch_size", type=int, default=10)
    parser.add_argument("--use_dilation", type=str2bool, default=False)
    parser.add_argument("--det_db_score_mode", type=str, default="fast")
    # EAST params
    parser.add_argument("--det_east_score_thresh", type=float, default=0.8)
    parser.add_argument("--det_east_cover_thresh", type=float, default=0.1)
    parser.add_argument("--det_east_nms_thresh", type=float, default=0.2)
    # SAST params
    parser.add_argument("--det_sast_score_thresh", type=float, default=0.5)
    parser.add_argument("--det_sast_nms_thresh", type=float, default=0.2)
    # PSE params
    parser.add_argument("--det_pse_thresh", type=float, default=0)
    parser.add_argument("--det_pse_box_thresh", type=float, default=0.85)
    parser.add_argument("--det_pse_min_area", type=float, default=16)
    parser.add_argument("--det_pse_scale", type=int, default=1)
    # FCE params
    # NOTE(review): type=list makes argparse call list() on the raw string,
    # which yields a list of single characters when the flag is supplied
    # (the default below is unaffected) -- confirm this is intended.
    parser.add_argument("--scales", type=list, default=[8, 16, 32])
    parser.add_argument("--alpha", type=float, default=1.0)
    parser.add_argument("--beta", type=float, default=1.0)
    parser.add_argument("--fourier_degree", type=int, default=5)
    # params for text recognizer
    parser.add_argument("--rec_algorithm", type=str, default="SVTR_LCNet")
    parser.add_argument("--rec_model_dir", type=str)
    parser.add_argument("--rec_image_inverse", type=str2bool, default=True)
    parser.add_argument("--rec_image_shape", type=str, default="3, 48, 320")
    parser.add_argument("--rec_batch_num", type=int, default=6)
    parser.add_argument("--max_text_length", type=int, default=25)
    parser.add_argument(
        "--rec_char_dict_path", type=str, default="./ppocr/utils/ppocr_keys_v1.txt"
    )
    parser.add_argument("--use_space_char", type=str2bool, default=True)
    parser.add_argument("--vis_font_path", type=str, default="./doc/fonts/simfang.ttf")
    parser.add_argument("--drop_score", type=float, default=0.5)
    # params for e2e
    parser.add_argument("--e2e_algorithm", type=str, default="PGNet")
    parser.add_argument("--e2e_model_dir", type=str)
    parser.add_argument("--e2e_limit_side_len", type=float, default=768)
    parser.add_argument("--e2e_limit_type", type=str, default="max")
    # PGNet params
    parser.add_argument("--e2e_pgnet_score_thresh", type=float, default=0.5)
    parser.add_argument(
        "--e2e_char_dict_path", type=str, default="./ppocr/utils/ic15_dict.txt"
    )
    parser.add_argument("--e2e_pgnet_valid_set", type=str, default="totaltext")
    parser.add_argument("--e2e_pgnet_mode", type=str, default="fast")
    # params for text classifier
    parser.add_argument("--use_angle_cls", type=str2bool, default=False)
    parser.add_argument("--cls_model_dir", type=str)
    parser.add_argument("--cls_image_shape", type=str, default="3, 48, 192")
    # NOTE(review): same type=list caveat as --scales above.
    parser.add_argument("--label_list", type=list, default=["0", "180"])
    parser.add_argument("--cls_batch_num", type=int, default=6)
    parser.add_argument("--cls_thresh", type=float, default=0.9)
    # The default is None rather than False: create_predictor only touches the
    # MKLDNN switches when this option is not None.
    parser.add_argument("--enable_mkldnn", type=str2bool, default=None)
    parser.add_argument("--cpu_threads", type=int, default=10)
    parser.add_argument("--use_pdserving", type=str2bool, default=False)
    parser.add_argument("--warmup", type=str2bool, default=False)
    # SR params
    parser.add_argument("--sr_model_dir", type=str)
    parser.add_argument("--sr_image_shape", type=str, default="3, 32, 128")
    parser.add_argument("--sr_batch_num", type=int, default=1)
    # output directories for drawn images and cropped results
    parser.add_argument("--draw_img_save_dir", type=str, default="./inference_results")
    parser.add_argument("--save_crop_res", type=str2bool, default=False)
    parser.add_argument("--crop_res_save_dir", type=str, default="./output")
    # multi-process
    parser.add_argument("--use_mp", type=str2bool, default=False)
    parser.add_argument("--total_process_num", type=int, default=1)
    parser.add_argument("--process_id", type=int, default=0)
    parser.add_argument("--benchmark", type=str2bool, default=False)
    parser.add_argument("--save_log_path", type=str, default="./log_output/")
    parser.add_argument("--show_log", type=str2bool, default=True)
    parser.add_argument("--use_onnx", type=str2bool, default=False)
    # Both ONNX options default to False; create_predictor treats a falsy
    # value as "not provided".
    parser.add_argument("--onnx_providers", nargs="+", type=str, default=False)
    # NOTE(review): same type=list caveat as --scales above.
    parser.add_argument("--onnx_sess_options", type=list, default=False)
    # extended function
    parser.add_argument(
        "--return_word_box",
        type=str2bool,
        default=False,
        help="Whether return the bbox of each word (split by space) or chinese character. Only used in ppstructure for layout recovery",
    )
    return parser
  145. def parse_args():
  146. parser = init_args()
  147. return parser.parse_args()
  148. def create_predictor(args, mode, logger):
  149. if mode == "det":
  150. model_dir = args.det_model_dir
  151. elif mode == "cls":
  152. model_dir = args.cls_model_dir
  153. elif mode == "rec":
  154. model_dir = args.rec_model_dir
  155. elif mode == "table":
  156. model_dir = args.table_model_dir
  157. elif mode == "ser":
  158. model_dir = args.ser_model_dir
  159. elif mode == "re":
  160. model_dir = args.re_model_dir
  161. elif mode == "sr":
  162. model_dir = args.sr_model_dir
  163. elif mode == "layout":
  164. model_dir = args.layout_model_dir
  165. else:
  166. model_dir = args.e2e_model_dir
  167. if model_dir is None:
  168. logger.info("not find {} model file path {}".format(mode, model_dir))
  169. sys.exit(0)
  170. if args.use_onnx:
  171. import onnxruntime as ort
  172. model_file_path = model_dir
  173. if not os.path.exists(model_file_path):
  174. raise ValueError("not find model file path {}".format(model_file_path))
  175. sess_options = args.onnx_sess_options or None
  176. if args.onnx_providers and len(args.onnx_providers) > 0:
  177. sess = ort.InferenceSession(
  178. model_file_path,
  179. providers=args.onnx_providers,
  180. sess_options=sess_options,
  181. )
  182. elif args.use_gpu:
  183. sess = ort.InferenceSession(
  184. model_file_path,
  185. providers=[
  186. (
  187. "CUDAExecutionProvider",
  188. {"device_id": args.gpu_id, "cudnn_conv_algo_search": "DEFAULT"},
  189. )
  190. ],
  191. sess_options=sess_options,
  192. )
  193. else:
  194. sess = ort.InferenceSession(
  195. model_file_path,
  196. providers=["CPUExecutionProvider"],
  197. sess_options=sess_options,
  198. )
  199. inputs = sess.get_inputs()
  200. return (
  201. sess,
  202. inputs[0] if len(inputs) == 1 else [vo.name for vo in inputs],
  203. None,
  204. None,
  205. )
  206. else:
  207. file_names = ["model", "inference"]
  208. for file_name in file_names:
  209. params_file_path = f"{model_dir}/{file_name}.pdiparams"
  210. if os.path.exists(params_file_path):
  211. break
  212. if not os.path.exists(params_file_path):
  213. raise ValueError(f"not find {file_name}.pdiparams in {model_dir}")
  214. if not (
  215. os.path.exists(f"{model_dir}/{file_name}.pdmodel")
  216. or os.path.exists(f"{model_dir}/{file_name}.json")
  217. ):
  218. raise ValueError(
  219. f"neither {file_name}.json nor {file_name}.pdmodel was found in {model_dir}."
  220. )
  221. if os.path.exists(f"{model_dir}/{file_name}.json"):
  222. model_file_path = f"{model_dir}/{file_name}.json"
  223. else:
  224. model_file_path = f"{model_dir}/{file_name}.pdmodel"
  225. config = inference.Config(model_file_path, params_file_path)
  226. if hasattr(args, "precision"):
  227. if args.precision == "fp16" and args.use_tensorrt:
  228. precision = inference.PrecisionType.Half
  229. elif args.precision == "int8":
  230. precision = inference.PrecisionType.Int8
  231. else:
  232. precision = inference.PrecisionType.Float32
  233. else:
  234. precision = inference.PrecisionType.Float32
  235. if args.use_gpu:
  236. gpu_id = get_infer_gpuid()
  237. if gpu_id is None:
  238. logger.warning(
  239. "GPU is not found in current device by nvidia-smi. Please check your device or ignore it if run on jetson."
  240. )
  241. config.enable_use_gpu(args.gpu_mem, args.gpu_id)
  242. if args.use_tensorrt:
  243. if ".json" in model_file_path:
  244. trt_dynamic_shapes = {}
  245. trt_dynamic_shape_input_data = {}
  246. if os.path.exists(f"{model_dir}/inference.yml"):
  247. model_config = load_config(f"{model_dir}/inference.yml")
  248. trt_dynamic_shapes = (
  249. model_config.get("Hpi", {})
  250. .get("backend_configs", {})
  251. .get("paddle_infer", {})
  252. .get("trt_dynamic_shapes", {})
  253. )
  254. trt_dynamic_shape_input_data = (
  255. model_config.get("Hpi", {})
  256. .get("backend_configs", {})
  257. .get("paddle_infer", {})
  258. .get("trt_dynamic_shapes_input_data", {})
  259. )
  260. if not trt_dynamic_shapes:
  261. raise RuntimeError(
  262. "Configuration Error: 'trt_dynamic_shapes' must be defined in 'inference.yml' for Paddle Inference TensorRT."
  263. )
  264. trt_save_path = f"{model_dir}/.cache/trt/{file_name}"
  265. trt_model_file_path = trt_save_path + ".json"
  266. trt_params_file_path = trt_save_path + ".pdiparams"
  267. if not os.path.exists(trt_model_file_path) or not os.path.exists(
  268. trt_params_file_path
  269. ):
  270. _convert_trt(
  271. {},
  272. model_file_path,
  273. params_file_path,
  274. trt_save_path,
  275. args.gpu_id,
  276. trt_dynamic_shapes,
  277. trt_dynamic_shape_input_data,
  278. )
  279. config = inference.Config(model_file_path, params_file_path)
  280. config.exp_disable_mixed_precision_ops({"feed", "fetch"})
  281. config.enable_use_gpu(args.gpu_mem, args.gpu_id)
  282. else:
  283. config.enable_tensorrt_engine(
  284. workspace_size=1 << 30,
  285. precision_mode=precision,
  286. max_batch_size=args.max_batch_size,
  287. min_subgraph_size=args.min_subgraph_size, # skip the minimum trt subgraph
  288. use_calib_mode=False,
  289. )
  290. # collect shape
  291. trt_shape_f = os.path.join(
  292. model_dir, f"{mode}_trt_dynamic_shape.txt"
  293. )
  294. if not os.path.exists(trt_shape_f):
  295. config.collect_shape_range_info(trt_shape_f)
  296. logger.info(f"collect dynamic shape info into : {trt_shape_f}")
  297. try:
  298. config.enable_tuned_tensorrt_dynamic_shape(trt_shape_f, True)
  299. except Exception as E:
  300. logger.info(E)
  301. logger.info("Please keep your paddlepaddle-gpu >= 2.3.0!")
  302. elif args.use_npu:
  303. config.enable_custom_device("npu")
  304. elif args.use_mlu:
  305. config.enable_custom_device("mlu")
  306. elif args.use_metax_gpu:
  307. if args.precision == "fp16":
  308. config.enable_custom_device(
  309. "metax_gpu", 0, paddle.inference.PrecisionType.Half
  310. )
  311. else:
  312. config.enable_custom_device("metax_gpu")
  313. elif args.use_xpu:
  314. config.enable_xpu(10 * 1024 * 1024)
  315. elif args.use_gcu: # for Enflame GCU(General Compute Unit)
  316. assert paddle.device.is_compiled_with_custom_device("gcu"), (
  317. "Args use_gcu cannot be set as True while your paddle "
  318. "is not compiled with gcu! \nPlease try: \n"
  319. "\t1. Install paddle-custom-gcu to run model on GCU. \n"
  320. "\t2. Set use_gcu as False in args to run model on CPU."
  321. )
  322. import paddle_custom_device.gcu.passes as gcu_passes
  323. gcu_passes.setUp()
  324. if args.precision == "fp16":
  325. config.enable_custom_device(
  326. "gcu", 0, paddle.inference.PrecisionType.Half
  327. )
  328. gcu_passes.set_exp_enable_mixed_precision_ops(config)
  329. else:
  330. config.enable_custom_device("gcu")
  331. if paddle.framework.use_pir_api():
  332. config.enable_new_ir(True)
  333. config.enable_new_executor(True)
  334. else:
  335. pass_builder = config.pass_builder()
  336. gcu_passes.append_passes_for_legacy_ir(pass_builder, "PaddleOCR")
  337. else:
  338. config.disable_gpu()
  339. if args.enable_mkldnn is not None:
  340. if args.enable_mkldnn:
  341. # cache 10 different shapes for mkldnn to avoid memory leak
  342. config.set_mkldnn_cache_capacity(10)
  343. config.enable_mkldnn()
  344. if args.precision == "fp16":
  345. config.enable_mkldnn_bfloat16()
  346. else:
  347. if hasattr(config, "disable_mkldnn"):
  348. config.disable_mkldnn()
  349. if hasattr(args, "cpu_threads"):
  350. config.set_cpu_math_library_num_threads(args.cpu_threads)
  351. else:
  352. # default cpu threads as 10
  353. config.set_cpu_math_library_num_threads(10)
  354. if hasattr(config, "enable_new_ir"):
  355. config.enable_new_ir()
  356. if hasattr(config, "enable_new_executor"):
  357. config.enable_new_executor()
  358. # enable memory optim
  359. config.enable_memory_optim()
  360. config.disable_glog_info()
  361. if not args.use_gcu: # for Enflame GCU(General Compute Unit)
  362. config.delete_pass("conv_transpose_eltwiseadd_bn_fuse_pass")
  363. config.delete_pass("matmul_transpose_reshape_fuse_pass")
  364. if mode == "rec" and args.rec_algorithm == "SRN":
  365. config.delete_pass("gpu_cpu_map_matmul_v2_to_matmul_pass")
  366. if mode == "re":
  367. config.delete_pass("simplify_with_basic_ops_pass")
  368. if mode == "table":
  369. config.delete_pass("fc_fuse_pass") # not supported for table
  370. config.switch_use_feed_fetch_ops(False)
  371. config.switch_ir_optim(True)
  372. # create predictor
  373. predictor = inference.create_predictor(config)
  374. input_names = predictor.get_input_names()
  375. if mode in ["ser", "re"]:
  376. input_tensor = []
  377. for name in input_names:
  378. input_tensor.append(predictor.get_input_handle(name))
  379. else:
  380. for name in input_names:
  381. input_tensor = predictor.get_input_handle(name)
  382. output_tensors = get_output_tensors(args, mode, predictor)
  383. return predictor, input_tensor, output_tensors, config
  384. def _convert_trt(
  385. trt_cfg_setting,
  386. pp_model_file,
  387. pp_params_file,
  388. trt_save_path,
  389. device_id,
  390. dynamic_shapes,
  391. dynamic_shape_input_data,
  392. ):
  393. from paddle.tensorrt.export import Input, TensorRTConfig, convert
  394. def _set_trt_config():
  395. for attr_name in trt_cfg_setting:
  396. assert hasattr(
  397. trt_config, attr_name
  398. ), f"The `{type(trt_config)}` don't have the attribute `{attr_name}`!"
  399. setattr(trt_config, attr_name, trt_cfg_setting[attr_name])
  400. def _get_predictor(model_file, params_file):
  401. # HACK
  402. config = inference.Config(str(model_file), str(params_file))
  403. config.enable_use_gpu(100, device_id)
  404. # NOTE: Disable oneDNN to circumvent a bug in Paddle Inference
  405. config.disable_mkldnn()
  406. config.disable_glog_info()
  407. return inference.create_predictor(config)
  408. dynamic_shape_input_data = dynamic_shape_input_data or {}
  409. predictor = _get_predictor(pp_model_file, pp_params_file)
  410. input_names = predictor.get_input_names()
  411. for name in dynamic_shapes:
  412. if name not in input_names:
  413. raise ValueError(
  414. f"Invalid input name {repr(name)} found in `dynamic_shapes`"
  415. )
  416. for name in input_names:
  417. if name not in dynamic_shapes:
  418. raise ValueError(f"Input name {repr(name)} not found in `dynamic_shapes`")
  419. for name in dynamic_shape_input_data:
  420. if name not in input_names:
  421. raise ValueError(
  422. f"Invalid input name {repr(name)} found in `dynamic_shape_input_data`"
  423. )
  424. trt_inputs = []
  425. for name, candidate_shapes in dynamic_shapes.items():
  426. # XXX: Currently we have no way to get the data type of the tensor
  427. # without creating an input handle.
  428. handle = predictor.get_input_handle(name)
  429. dtype = _pd_dtype_to_np_dtype(handle.type())
  430. min_shape, opt_shape, max_shape = candidate_shapes
  431. if name in dynamic_shape_input_data:
  432. min_arr = np.array(dynamic_shape_input_data[name][0], dtype=dtype).reshape(
  433. min_shape
  434. )
  435. opt_arr = np.array(dynamic_shape_input_data[name][1], dtype=dtype).reshape(
  436. opt_shape
  437. )
  438. max_arr = np.array(dynamic_shape_input_data[name][2], dtype=dtype).reshape(
  439. max_shape
  440. )
  441. else:
  442. min_arr = np.ones(min_shape, dtype=dtype)
  443. opt_arr = np.ones(opt_shape, dtype=dtype)
  444. max_arr = np.ones(max_shape, dtype=dtype)
  445. # refer to: https://github.com/PolaKuma/Paddle/blob/3347f225bc09f2ec09802a2090432dd5cb5b6739/test/tensorrt/test_converter_model_resnet50.py
  446. trt_input = Input((min_arr, opt_arr, max_arr))
  447. trt_inputs.append(trt_input)
  448. # Create TensorRTConfig
  449. trt_config = TensorRTConfig(inputs=trt_inputs)
  450. _set_trt_config()
  451. trt_config.save_model_dir = trt_save_path
  452. pp_model_path = pp_model_file.split(".")[0]
  453. convert(pp_model_path, trt_config)
  454. def _pd_dtype_to_np_dtype(pd_dtype):
  455. if pd_dtype == inference.DataType.FLOAT64:
  456. return np.float64
  457. elif pd_dtype == inference.DataType.FLOAT32:
  458. return np.float32
  459. elif pd_dtype == inference.DataType.INT64:
  460. return np.int64
  461. elif pd_dtype == inference.DataType.INT32:
  462. return np.int32
  463. elif pd_dtype == inference.DataType.UINT8:
  464. return np.uint8
  465. elif pd_dtype == inference.DataType.INT8:
  466. return np.int8
  467. else:
  468. raise TypeError(f"Unsupported data type: {pd_dtype}")
  469. def load_config(file_path):
  470. _, ext = os.path.splitext(file_path)
  471. if ext not in [".yml", ".yaml"]:
  472. raise ValueError(f"only support yaml files for now, got {file_path}")
  473. with open(file_path, "rb") as file:
  474. config = yaml.load(file, Loader=yaml.SafeLoader)
  475. return config
  476. def get_output_tensors(args, mode, predictor):
  477. output_names = predictor.get_output_names()
  478. output_tensors = []
  479. if mode == "rec" and args.rec_algorithm in ["CRNN", "SVTR_LCNet", "SVTR_HGNet"]:
  480. output_name = "softmax_0.tmp_0"
  481. if output_name in output_names:
  482. return [predictor.get_output_handle(output_name)]
  483. else:
  484. for output_name in output_names:
  485. output_tensor = predictor.get_output_handle(output_name)
  486. output_tensors.append(output_tensor)
  487. else:
  488. for output_name in output_names:
  489. output_tensor = predictor.get_output_handle(output_name)
  490. output_tensors.append(output_tensor)
  491. return output_tensors
  492. def get_infer_gpuid():
  493. """
  494. Get the GPU ID to be used for inference.
  495. Returns:
  496. int: The GPU ID to be used for inference.
  497. """
  498. logger = get_logger()
  499. if not paddle.device.is_compiled_with_rocm:
  500. gpu_id_str = os.environ.get("CUDA_VISIBLE_DEVICES", "0")
  501. else:
  502. gpu_id_str = os.environ.get("HIP_VISIBLE_DEVICES", "0")
  503. gpu_ids = gpu_id_str.split(",")
  504. logger.warning(
  505. "The first GPU is used for inference by default, GPU ID: {}".format(gpu_ids[0])
  506. )
  507. return int(gpu_ids[0])
  508. def draw_e2e_res(dt_boxes, strs, img_path):
  509. src_im = cv2.imread(img_path)
  510. for box, str in zip(dt_boxes, strs):
  511. box = box.astype(np.int32).reshape((-1, 1, 2))
  512. cv2.polylines(src_im, [box], True, color=(255, 255, 0), thickness=2)
  513. cv2.putText(
  514. src_im,
  515. str,
  516. org=(int(box[0, 0, 0]), int(box[0, 0, 1])),
  517. fontFace=cv2.FONT_HERSHEY_COMPLEX,
  518. fontScale=0.7,
  519. color=(0, 255, 0),
  520. thickness=1,
  521. )
  522. return src_im
  523. def draw_text_det_res(dt_boxes, img):
  524. for box in dt_boxes:
  525. box = np.array(box).astype(np.int32).reshape(-1, 2)
  526. cv2.polylines(img, [box], True, color=(255, 255, 0), thickness=2)
  527. return img
  528. def resize_img(img, input_size=600):
  529. """
  530. resize img and limit the longest side of the image to input_size
  531. """
  532. img = np.array(img)
  533. im_shape = img.shape
  534. im_size_max = np.max(im_shape[0:2])
  535. im_scale = float(input_size) / float(im_size_max)
  536. img = cv2.resize(img, None, None, fx=im_scale, fy=im_scale)
  537. return img
  538. def draw_ocr(
  539. image,
  540. boxes,
  541. txts=None,
  542. scores=None,
  543. drop_score=0.5,
  544. font_path="./doc/fonts/simfang.ttf",
  545. ):
  546. """
  547. Visualize the results of OCR detection and recognition
  548. args:
  549. image(Image|array): RGB image
  550. boxes(list): boxes with shape(N, 4, 2)
  551. txts(list): the texts
  552. scores(list): txxs corresponding scores
  553. drop_score(float): only scores greater than drop_threshold will be visualized
  554. font_path: the path of font which is used to draw text
  555. return(array):
  556. the visualized img
  557. """
  558. if scores is None:
  559. scores = [1] * len(boxes)
  560. box_num = len(boxes)
  561. for i in range(box_num):
  562. if scores is not None and (scores[i] < drop_score or math.isnan(scores[i])):
  563. continue
  564. box = np.reshape(np.array(boxes[i]), [-1, 1, 2]).astype(np.int64)
  565. image = cv2.polylines(np.array(image), [box], True, (255, 0, 0), 2)
  566. if txts is not None:
  567. img = np.array(resize_img(image, input_size=600))
  568. txt_img = text_visual(
  569. txts,
  570. scores,
  571. img_h=img.shape[0],
  572. img_w=600,
  573. threshold=drop_score,
  574. font_path=font_path,
  575. )
  576. img = np.concatenate([np.array(img), np.array(txt_img)], axis=1)
  577. return img
  578. return image
  579. def draw_ocr_box_txt(
  580. image,
  581. boxes,
  582. txts=None,
  583. scores=None,
  584. drop_score=0.5,
  585. font_path="./doc/fonts/simfang.ttf",
  586. ):
  587. h, w = image.height, image.width
  588. img_left = image.copy()
  589. img_right = np.ones((h, w, 3), dtype=np.uint8) * 255
  590. random.seed(0)
  591. draw_left = ImageDraw.Draw(img_left)
  592. if txts is None or len(txts) != len(boxes):
  593. txts = [None] * len(boxes)
  594. for idx, (box, txt) in enumerate(zip(boxes, txts)):
  595. if scores is not None and scores[idx] < drop_score:
  596. continue
  597. color = (random.randint(0, 255), random.randint(0, 255), random.randint(0, 255))
  598. draw_left.polygon(box, fill=color)
  599. img_right_text = draw_box_txt_fine((w, h), box, txt, font_path)
  600. pts = np.array(box, np.int32).reshape((-1, 1, 2))
  601. cv2.polylines(img_right_text, [pts], True, color, 1)
  602. img_right = cv2.bitwise_and(img_right, img_right_text)
  603. img_left = Image.blend(image, img_left, 0.5)
  604. img_show = Image.new("RGB", (w * 2, h), (255, 255, 255))
  605. img_show.paste(img_left, (0, 0, w, h))
  606. img_show.paste(Image.fromarray(img_right), (w, 0, w * 2, h))
  607. return np.array(img_show)
  608. def draw_box_txt_fine(img_size, box, txt, font_path="./doc/fonts/simfang.ttf"):
  609. box_height = int(
  610. math.sqrt((box[0][0] - box[3][0]) ** 2 + (box[0][1] - box[3][1]) ** 2)
  611. )
  612. box_width = int(
  613. math.sqrt((box[0][0] - box[1][0]) ** 2 + (box[0][1] - box[1][1]) ** 2)
  614. )
  615. if box_height > 2 * box_width and box_height > 30:
  616. img_text = Image.new("RGB", (box_height, box_width), (255, 255, 255))
  617. draw_text = ImageDraw.Draw(img_text)
  618. if txt:
  619. font = create_font(txt, (box_height, box_width), font_path)
  620. draw_text.text([0, 0], txt, fill=(0, 0, 0), font=font)
  621. img_text = img_text.transpose(Image.ROTATE_270)
  622. else:
  623. img_text = Image.new("RGB", (box_width, box_height), (255, 255, 255))
  624. draw_text = ImageDraw.Draw(img_text)
  625. if txt:
  626. font = create_font(txt, (box_width, box_height), font_path)
  627. draw_text.text([0, 0], txt, fill=(0, 0, 0), font=font)
  628. pts1 = np.float32(
  629. [[0, 0], [box_width, 0], [box_width, box_height], [0, box_height]]
  630. )
  631. pts2 = np.array(box, dtype=np.float32)
  632. M = cv2.getPerspectiveTransform(pts1, pts2)
  633. img_text = np.array(img_text, dtype=np.uint8)
  634. img_right_text = cv2.warpPerspective(
  635. img_text,
  636. M,
  637. img_size,
  638. flags=cv2.INTER_NEAREST,
  639. borderMode=cv2.BORDER_CONSTANT,
  640. borderValue=(255, 255, 255),
  641. )
  642. return img_right_text
  643. def create_font(txt, sz, font_path="./doc/fonts/simfang.ttf"):
  644. font_size = int(sz[1] * 0.99)
  645. font = ImageFont.truetype(font_path, font_size, encoding="utf-8")
  646. if int(PIL.__version__.split(".")[0]) < 10:
  647. length = font.getsize(txt)[0]
  648. else:
  649. length = font.getlength(txt)
  650. if length > sz[0]:
  651. font_size = int(font_size * sz[0] / length)
  652. font = ImageFont.truetype(font_path, font_size, encoding="utf-8")
  653. return font
  654. def str_count(s):
  655. """
  656. Count the number of Chinese characters,
  657. a single English character and a single number
  658. equal to half the length of Chinese characters.
  659. args:
  660. s(string): the input of string
  661. return(int):
  662. the number of Chinese characters
  663. """
  664. import string
  665. count_zh = count_pu = 0
  666. s_len = len(s)
  667. en_dg_count = 0
  668. for c in s:
  669. if c in string.ascii_letters or c.isdigit() or c.isspace():
  670. en_dg_count += 1
  671. elif c.isalpha():
  672. count_zh += 1
  673. else:
  674. count_pu += 1
  675. return s_len - math.ceil(en_dg_count / 2)
  676. def text_visual(
  677. texts, scores, img_h=400, img_w=600, threshold=0.0, font_path="./doc/simfang.ttf"
  678. ):
  679. """
  680. create new blank img and draw txt on it
  681. args:
  682. texts(list): the text will be draw
  683. scores(list|None): corresponding score of each txt
  684. img_h(int): the height of blank img
  685. img_w(int): the width of blank img
  686. font_path: the path of font which is used to draw text
  687. return(array):
  688. """
  689. if scores is not None:
  690. assert len(texts) == len(
  691. scores
  692. ), "The number of txts and corresponding scores must match"
  693. def create_blank_img():
  694. blank_img = np.ones(shape=[img_h, img_w], dtype=np.uint8) * 255
  695. blank_img[:, img_w - 1 :] = 0
  696. blank_img = Image.fromarray(blank_img).convert("RGB")
  697. draw_txt = ImageDraw.Draw(blank_img)
  698. return blank_img, draw_txt
  699. blank_img, draw_txt = create_blank_img()
  700. font_size = 20
  701. txt_color = (0, 0, 0)
  702. font = ImageFont.truetype(font_path, font_size, encoding="utf-8")
  703. gap = font_size + 5
  704. txt_img_list = []
  705. count, index = 1, 0
  706. for idx, txt in enumerate(texts):
  707. index += 1
  708. if scores[idx] < threshold or math.isnan(scores[idx]):
  709. index -= 1
  710. continue
  711. first_line = True
  712. while str_count(txt) >= img_w // font_size - 4:
  713. tmp = txt
  714. txt = tmp[: img_w // font_size - 4]
  715. if first_line:
  716. new_txt = str(index) + ": " + txt
  717. first_line = False
  718. else:
  719. new_txt = " " + txt
  720. draw_txt.text((0, gap * count), new_txt, txt_color, font=font)
  721. txt = tmp[img_w // font_size - 4 :]
  722. if count >= img_h // gap - 1:
  723. txt_img_list.append(np.array(blank_img))
  724. blank_img, draw_txt = create_blank_img()
  725. count = 0
  726. count += 1
  727. if first_line:
  728. new_txt = str(index) + ": " + txt + " " + "%.3f" % (scores[idx])
  729. else:
  730. new_txt = " " + txt + " " + "%.3f" % (scores[idx])
  731. draw_txt.text((0, gap * count), new_txt, txt_color, font=font)
  732. # whether add new blank img or not
  733. if count >= img_h // gap - 1 and idx + 1 < len(texts):
  734. txt_img_list.append(np.array(blank_img))
  735. blank_img, draw_txt = create_blank_img()
  736. count = 0
  737. count += 1
  738. txt_img_list.append(np.array(blank_img))
  739. if len(txt_img_list) == 1:
  740. blank_img = np.array(txt_img_list[0])
  741. else:
  742. blank_img = np.concatenate(txt_img_list, axis=1)
  743. return np.array(blank_img)
  744. def base64_to_cv2(b64str):
  745. import base64
  746. data = base64.b64decode(b64str.encode("utf8"))
  747. data = np.frombuffer(data, np.uint8)
  748. data = cv2.imdecode(data, cv2.IMREAD_COLOR)
  749. return data
  750. def draw_boxes(image, boxes, scores=None, drop_score=0.5):
  751. if scores is None:
  752. scores = [1] * len(boxes)
  753. for box, score in zip(boxes, scores):
  754. if score < drop_score:
  755. continue
  756. box = np.reshape(np.array(box), [-1, 1, 2]).astype(np.int64)
  757. image = cv2.polylines(np.array(image), [box], True, (255, 0, 0), 2)
  758. return image
  759. def get_rotate_crop_image(img, points):
  760. """
  761. img_height, img_width = img.shape[0:2]
  762. left = int(np.min(points[:, 0]))
  763. right = int(np.max(points[:, 0]))
  764. top = int(np.min(points[:, 1]))
  765. bottom = int(np.max(points[:, 1]))
  766. img_crop = img[top:bottom, left:right, :].copy()
  767. points[:, 0] = points[:, 0] - left
  768. points[:, 1] = points[:, 1] - top
  769. """
  770. assert len(points) == 4, "shape of points must be 4*2"
  771. img_crop_width = int(
  772. max(
  773. np.linalg.norm(points[0] - points[1]), np.linalg.norm(points[2] - points[3])
  774. )
  775. )
  776. img_crop_height = int(
  777. max(
  778. np.linalg.norm(points[0] - points[3]), np.linalg.norm(points[1] - points[2])
  779. )
  780. )
  781. pts_std = np.float32(
  782. [
  783. [0, 0],
  784. [img_crop_width, 0],
  785. [img_crop_width, img_crop_height],
  786. [0, img_crop_height],
  787. ]
  788. )
  789. M = cv2.getPerspectiveTransform(points, pts_std)
  790. dst_img = cv2.warpPerspective(
  791. img,
  792. M,
  793. (img_crop_width, img_crop_height),
  794. borderMode=cv2.BORDER_REPLICATE,
  795. flags=cv2.INTER_CUBIC,
  796. )
  797. dst_img_height, dst_img_width = dst_img.shape[0:2]
  798. if dst_img_height * 1.0 / dst_img_width >= 1.5:
  799. dst_img = np.rot90(dst_img)
  800. return dst_img
  801. def get_minarea_rect_crop(img, points):
  802. bounding_box = cv2.minAreaRect(np.array(points).astype(np.int32))
  803. points = sorted(list(cv2.boxPoints(bounding_box)), key=lambda x: x[0])
  804. index_a, index_b, index_c, index_d = 0, 1, 2, 3
  805. if points[1][1] > points[0][1]:
  806. index_a = 0
  807. index_d = 1
  808. else:
  809. index_a = 1
  810. index_d = 0
  811. if points[3][1] > points[2][1]:
  812. index_b = 2
  813. index_c = 3
  814. else:
  815. index_b = 3
  816. index_c = 2
  817. box = [points[index_a], points[index_b], points[index_c], points[index_d]]
  818. crop_img = get_rotate_crop_image(img, np.array(box))
  819. return crop_img
  820. def slice_generator(image, horizontal_stride, vertical_stride, maximum_slices=500):
  821. if not isinstance(image, np.ndarray):
  822. image = np.array(image)
  823. image_h, image_w = image.shape[:2]
  824. vertical_num_slices = (image_h + vertical_stride - 1) // vertical_stride
  825. horizontal_num_slices = (image_w + horizontal_stride - 1) // horizontal_stride
  826. assert (
  827. vertical_num_slices > 0
  828. ), f"Invalid number ({vertical_num_slices}) of vertical slices"
  829. assert (
  830. horizontal_num_slices > 0
  831. ), f"Invalid number ({horizontal_num_slices}) of horizontal slices"
  832. if vertical_num_slices >= maximum_slices:
  833. recommended_vertical_stride = max(1, image_h // maximum_slices) + 1
  834. assert (
  835. False
  836. ), f"Too computationally expensive with {vertical_num_slices} slices, try a higher vertical stride (recommended minimum: {recommended_vertical_stride})"
  837. if horizontal_num_slices >= maximum_slices:
  838. recommended_horizontal_stride = max(1, image_w // maximum_slices) + 1
  839. assert (
  840. False
  841. ), f"Too computationally expensive with {horizontal_num_slices} slices, try a higher horizontal stride (recommended minimum: {recommended_horizontal_stride})"
  842. for v_slice_idx in range(vertical_num_slices):
  843. v_start = max(0, (v_slice_idx * vertical_stride))
  844. v_end = min(((v_slice_idx + 1) * vertical_stride), image_h)
  845. vertical_slice = image[v_start:v_end, :]
  846. for h_slice_idx in range(horizontal_num_slices):
  847. h_start = max(0, (h_slice_idx * horizontal_stride))
  848. h_end = min(((h_slice_idx + 1) * horizontal_stride), image_w)
  849. horizontal_slice = vertical_slice[:, h_start:h_end]
  850. yield (horizontal_slice, v_start, h_start)
  851. def calculate_box_extents(box):
  852. min_x = box[0][0]
  853. max_x = box[1][0]
  854. min_y = box[0][1]
  855. max_y = box[2][1]
  856. return min_x, max_x, min_y, max_y
  857. def merge_boxes(box1, box2, x_threshold, y_threshold):
  858. min_x1, max_x1, min_y1, max_y1 = calculate_box_extents(box1)
  859. min_x2, max_x2, min_y2, max_y2 = calculate_box_extents(box2)
  860. if (
  861. abs(min_y1 - min_y2) <= y_threshold
  862. and abs(max_y1 - max_y2) <= y_threshold
  863. and abs(max_x1 - min_x2) <= x_threshold
  864. ):
  865. new_xmin = min(min_x1, min_x2)
  866. new_xmax = max(max_x1, max_x2)
  867. new_ymin = min(min_y1, min_y2)
  868. new_ymax = max(max_y1, max_y2)
  869. return [
  870. [new_xmin, new_ymin],
  871. [new_xmax, new_ymin],
  872. [new_xmax, new_ymax],
  873. [new_xmin, new_ymax],
  874. ]
  875. else:
  876. return None
  877. def merge_fragmented(boxes, x_threshold=10, y_threshold=10):
  878. merged_boxes = []
  879. visited = set()
  880. for i, box1 in enumerate(boxes):
  881. if i in visited:
  882. continue
  883. merged_box = [point[:] for point in box1]
  884. for j, box2 in enumerate(boxes[i + 1 :], start=i + 1):
  885. if j not in visited:
  886. merged_result = merge_boxes(
  887. merged_box, box2, x_threshold=x_threshold, y_threshold=y_threshold
  888. )
  889. if merged_result:
  890. merged_box = merged_result
  891. visited.add(j)
  892. merged_boxes.append(merged_box)
  893. if len(merged_boxes) == len(boxes):
  894. return np.array(merged_boxes)
  895. else:
  896. return merge_fragmented(merged_boxes, x_threshold, y_threshold)
  897. def check_gpu(use_gpu):
  898. if use_gpu and (
  899. not paddle.is_compiled_with_cuda() or paddle.device.get_device() == "cpu"
  900. ):
  901. use_gpu = False
  902. return use_gpu
  903. if __name__ == "__main__":
  904. pass