test_ocr.py 9.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292
  1. # Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. import argparse
  15. import time
  16. import os
  17. import sys
  18. import cv2
  19. import numpy as np
  20. import paddle
  21. import logging
  22. import numpy as np
  23. import argparse
  24. from tqdm import tqdm
  25. import paddle
  26. from paddleslim.common import load_config as load_slim_config
  27. from paddleslim.common import get_logger
  28. import sys
  29. sys.path.append("../../../")
  30. from ppocr.data import build_dataloader
  31. from ppocr.postprocess import build_post_process
  32. from ppocr.metrics import build_metric
  33. from paddle.inference import create_predictor, PrecisionType
  34. from paddle.inference import Config as PredictConfig
  35. logger = get_logger(__name__, level=logging.INFO)
  36. def find_images_with_bounding_size(dataset: paddle.io.Dataset):
  37. max_length_index = -1
  38. max_width_index = -1
  39. min_length_index = -1
  40. min_width_index = -1
  41. max_length = float("-inf")
  42. max_width = float("-inf")
  43. min_length = float("inf")
  44. min_width = float("inf")
  45. for idx, data in enumerate(dataset):
  46. image = np.array(data[0])
  47. h, w = image.shape[-2:]
  48. if h > max_length:
  49. max_length = h
  50. max_length_index = idx
  51. if w > max_width:
  52. max_width = w
  53. max_width_index = idx
  54. if h < min_length:
  55. min_length = h
  56. min_length_index = idx
  57. if w < min_width:
  58. min_width = w
  59. min_width_index = idx
  60. print(f"Found max image length: {max_length}, index: {max_length_index}")
  61. print(f"Found max image width: {max_width}, index: {max_width_index}")
  62. print(f"Found min image length: {min_length}, index: {min_length_index}")
  63. print(f"Found min image width: {min_width}, index: {min_width_index}")
  64. return paddle.io.Subset(
  65. dataset, [max_width_index, max_length_index, min_width_index, min_length_index]
  66. )
  67. def load_predictor(args):
  68. """
  69. load predictor func
  70. """
  71. rerun_flag = False
  72. model_file = os.path.join(args.model_path, args.model_filename)
  73. params_file = os.path.join(args.model_path, args.params_filename)
  74. pred_cfg = PredictConfig(model_file, params_file)
  75. pred_cfg.enable_memory_optim()
  76. pred_cfg.switch_ir_optim(True)
  77. if args.device == "GPU":
  78. pred_cfg.enable_use_gpu(100, 0)
  79. else:
  80. pred_cfg.disable_gpu()
  81. pred_cfg.set_cpu_math_library_num_threads(args.cpu_threads)
  82. if args.use_mkldnn:
  83. pred_cfg.enable_mkldnn()
  84. if args.precision == "int8":
  85. pred_cfg.enable_mkldnn_int8({"conv2d"})
  86. if global_config["model_type"] == "rec":
  87. # delete pass which influence the accuracy, please refer to https://github.com/PaddlePaddle/Paddle/issues/55290
  88. pred_cfg.delete_pass("fc_mkldnn_pass")
  89. pred_cfg.delete_pass("fc_act_mkldnn_fuse_pass")
  90. if args.use_trt:
  91. # To collect the dynamic shapes of inputs for TensorRT engine
  92. dynamic_shape_file = os.path.join(args.model_path, "dynamic_shape.txt")
  93. if os.path.exists(dynamic_shape_file):
  94. pred_cfg.enable_tuned_tensorrt_dynamic_shape(dynamic_shape_file, True)
  95. print("trt set dynamic shape done!")
  96. precision_map = {
  97. "fp16": PrecisionType.Half,
  98. "fp32": PrecisionType.Float32,
  99. "int8": PrecisionType.Int8,
  100. }
  101. if (
  102. args.precision == "int8"
  103. and "ppocrv4_det_server_qat_dist.yaml" in args.config_path
  104. ):
  105. # Use the following settings only when the hardware is a Tesla V100. If you are using
  106. # a RTX 3090, use the settings in the else branch.
  107. pred_cfg.enable_tensorrt_engine(
  108. workspace_size=1 << 30,
  109. max_batch_size=1,
  110. min_subgraph_size=30,
  111. precision_mode=precision_map[args.precision],
  112. use_static=True,
  113. use_calib_mode=False,
  114. )
  115. pred_cfg.exp_disable_tensorrt_ops(["elementwise_add"])
  116. else:
  117. pred_cfg.enable_tensorrt_engine(
  118. workspace_size=1 << 30,
  119. max_batch_size=1,
  120. min_subgraph_size=4,
  121. precision_mode=precision_map[args.precision],
  122. use_static=True,
  123. use_calib_mode=False,
  124. )
  125. else:
  126. # pred_cfg.disable_gpu()
  127. # pred_cfg.set_cpu_math_library_num_threads(24)
  128. pred_cfg.collect_shape_range_info(dynamic_shape_file)
  129. print("Start collect dynamic shape...")
  130. rerun_flag = True
  131. predictor = create_predictor(pred_cfg)
  132. return predictor, rerun_flag
  133. def eval(args):
  134. """
  135. eval mIoU func
  136. """
  137. # DataLoader need run on cpu
  138. paddle.set_device("cpu")
  139. devices = paddle.device.get_device().split(":")[0]
  140. val_loader = build_dataloader(all_config, "Eval", devices, logger)
  141. post_process_class = build_post_process(all_config["PostProcess"], global_config)
  142. eval_class = build_metric(all_config["Metric"])
  143. model_type = global_config["model_type"]
  144. predictor, rerun_flag = load_predictor(args)
  145. if rerun_flag:
  146. eval_dataset = find_images_with_bounding_size(val_loader.dataset)
  147. batch_sampler = paddle.io.BatchSampler(
  148. eval_dataset, batch_size=1, shuffle=False, drop_last=False
  149. )
  150. val_loader = paddle.io.DataLoader(
  151. eval_dataset, batch_sampler=batch_sampler, num_workers=4, return_list=True
  152. )
  153. input_names = predictor.get_input_names()
  154. input_handle = predictor.get_input_handle(input_names[0])
  155. output_names = predictor.get_output_names()
  156. output_handle = predictor.get_output_handle(output_names[0])
  157. sample_nums = len(val_loader)
  158. predict_time = 0.0
  159. time_min = float("inf")
  160. time_max = float("-inf")
  161. print("Start evaluating ( total_iters: {}).".format(sample_nums))
  162. for batch_id, batch in enumerate(val_loader):
  163. images = np.array(batch[0])
  164. batch_numpy = []
  165. for item in batch:
  166. batch_numpy.append(np.array(item))
  167. # ori_shape = np.array(batch_numpy).shape[-2:]
  168. input_handle.reshape(images.shape)
  169. input_handle.copy_from_cpu(images)
  170. start_time = time.time()
  171. predictor.run()
  172. preds = output_handle.copy_to_cpu()
  173. end_time = time.time()
  174. timed = end_time - start_time
  175. time_min = min(time_min, timed)
  176. time_max = max(time_max, timed)
  177. predict_time += timed
  178. if model_type == "det":
  179. preds_map = {"maps": preds}
  180. post_result = post_process_class(preds_map, batch_numpy[1])
  181. eval_class(post_result, batch_numpy)
  182. elif model_type == "rec":
  183. post_result = post_process_class(preds, batch_numpy[1])
  184. eval_class(post_result, batch_numpy)
  185. if rerun_flag:
  186. if batch_id == 3:
  187. print(
  188. "***** Collect dynamic shape done, Please rerun the program to get correct results. *****"
  189. )
  190. return
  191. if batch_id % 100 == 0:
  192. print("Eval iter:", batch_id)
  193. sys.stdout.flush()
  194. metric = eval_class.get_metric()
  195. time_avg = predict_time / sample_nums
  196. print(
  197. "[Benchmark] Inference time(ms): min={}, max={}, avg={}".format(
  198. round(time_min * 1000, 2),
  199. round(time_max * 1000, 1),
  200. round(time_avg * 1000, 1),
  201. )
  202. )
  203. for k, v in metric.items():
  204. print("{}:{}".format(k, v))
  205. sys.stdout.flush()
  206. def main():
  207. global all_config, global_config
  208. all_config = load_slim_config(args.config_path)
  209. global_config = all_config["Global"]
  210. eval(args)
  211. if __name__ == "__main__":
  212. paddle.enable_static()
  213. parser = argparse.ArgumentParser()
  214. parser.add_argument("--model_path", type=str, help="inference model filepath")
  215. parser.add_argument(
  216. "--config_path",
  217. type=str,
  218. default="./configs/ppocrv3_det_qat_dist.yaml",
  219. help="path of compression strategy config.",
  220. )
  221. parser.add_argument(
  222. "--model_filename",
  223. type=str,
  224. default="inference.pdmodel",
  225. help="model file name",
  226. )
  227. parser.add_argument(
  228. "--params_filename",
  229. type=str,
  230. default="inference.pdiparams",
  231. help="params file name",
  232. )
  233. parser.add_argument(
  234. "--device",
  235. type=str,
  236. default="GPU",
  237. choices=["CPU", "GPU"],
  238. help="Choose the device you want to run, it can be: CPU/GPU, default is GPU",
  239. )
  240. parser.add_argument(
  241. "--precision",
  242. type=str,
  243. default="fp32",
  244. choices=["fp32", "fp16", "int8"],
  245. help="The precision of inference. It can be 'fp32', 'fp16' or 'int8'. Default is 'fp16'.",
  246. )
  247. parser.add_argument(
  248. "--use_trt",
  249. type=bool,
  250. default=False,
  251. help="Whether to use tensorrt engine or not.",
  252. )
  253. parser.add_argument(
  254. "--use_mkldnn", type=bool, default=False, help="Whether use mkldnn or not."
  255. )
  256. parser.add_argument(
  257. "--cpu_threads", type=int, default=10, help="Num of cpu threads."
  258. )
  259. args = parser.parse_args()
  260. main()