infer.py 10 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315
  1. # Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. import os
  15. import sys
  16. import pathlib
  17. __dir__ = pathlib.Path(os.path.abspath(__file__))
  18. sys.path.append(str(__dir__))
  19. sys.path.append(str(__dir__.parent.parent))
  20. import cv2
  21. import paddle
  22. from paddle import inference
  23. import numpy as np
  24. from PIL import Image
  25. from paddle.vision import transforms
  26. from tools.predict import resize_image
  27. from post_processing import get_post_processing
  28. from utils.util import draw_bbox, save_result
  29. class InferenceEngine(object):
  30. """InferenceEngine
  31. Inference engine class which contains preprocess, run, postprocess
  32. """
  33. def __init__(self, args):
  34. """
  35. Args:
  36. args: Parameters generated using argparser.
  37. Returns: None
  38. """
  39. super().__init__()
  40. self.args = args
  41. # init inference engine
  42. (
  43. self.predictor,
  44. self.config,
  45. self.input_tensor,
  46. self.output_tensor,
  47. ) = self.load_predictor(
  48. os.path.join(args.model_dir, "inference.pdmodel"),
  49. os.path.join(args.model_dir, "inference.pdiparams"),
  50. )
  51. # build transforms
  52. self.transforms = transforms.Compose(
  53. [
  54. transforms.ToTensor(),
  55. transforms.Normalize(
  56. mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]
  57. ),
  58. ]
  59. )
  60. # wamrup
  61. if self.args.warmup > 0:
  62. for idx in range(args.warmup):
  63. print(idx)
  64. x = np.random.rand(
  65. 1, 3, self.args.crop_size, self.args.crop_size
  66. ).astype("float32")
  67. self.input_tensor.copy_from_cpu(x)
  68. self.predictor.run()
  69. self.output_tensor.copy_to_cpu()
  70. self.post_process = get_post_processing(
  71. {
  72. "type": "SegDetectorRepresenter",
  73. "args": {
  74. "thresh": 0.3,
  75. "box_thresh": 0.7,
  76. "max_candidates": 1000,
  77. "unclip_ratio": 1.5,
  78. },
  79. }
  80. )
  81. def load_predictor(self, model_file_path, params_file_path):
  82. """load_predictor
  83. initialize the inference engine
  84. Args:
  85. model_file_path: inference model path (*.pdmodel)
  86. model_file_path: inference parameter path (*.pdiparams)
  87. Return:
  88. predictor: Predictor created using Paddle Inference.
  89. config: Configuration of the predictor.
  90. input_tensor: Input tensor of the predictor.
  91. output_tensor: Output tensor of the predictor.
  92. """
  93. args = self.args
  94. config = inference.Config(model_file_path, params_file_path)
  95. if args.use_gpu:
  96. config.enable_use_gpu(1000, 0)
  97. if args.use_tensorrt:
  98. config.enable_tensorrt_engine(
  99. workspace_size=1 << 30,
  100. precision_mode=precision,
  101. max_batch_size=args.max_batch_size,
  102. min_subgraph_size=args.min_subgraph_size, # skip the minimum trt subgraph
  103. use_calib_mode=False,
  104. )
  105. # collect shape
  106. trt_shape_f = os.path.join(model_dir, "_trt_dynamic_shape.txt")
  107. if not os.path.exists(trt_shape_f):
  108. config.collect_shape_range_info(trt_shape_f)
  109. logger.info(f"collect dynamic shape info into : {trt_shape_f}")
  110. try:
  111. config.enable_tuned_tensorrt_dynamic_shape(trt_shape_f, True)
  112. except Exception as E:
  113. logger.info(E)
  114. logger.info("Please keep your paddlepaddle-gpu >= 2.3.0!")
  115. else:
  116. config.disable_gpu()
  117. # The thread num should not be greater than the number of cores in the CPU.
  118. if args.enable_mkldnn:
  119. # cache 10 different shapes for mkldnn to avoid memory leak
  120. config.set_mkldnn_cache_capacity(10)
  121. config.enable_mkldnn()
  122. if args.precision == "fp16":
  123. config.enable_mkldnn_bfloat16()
  124. if hasattr(args, "cpu_threads"):
  125. config.set_cpu_math_library_num_threads(args.cpu_threads)
  126. else:
  127. # default cpu threads as 10
  128. config.set_cpu_math_library_num_threads(10)
  129. # enable memory optim
  130. config.enable_memory_optim()
  131. config.disable_glog_info()
  132. config.switch_use_feed_fetch_ops(False)
  133. config.switch_ir_optim(True)
  134. # create predictor
  135. predictor = inference.create_predictor(config)
  136. # get input and output tensor property
  137. input_names = predictor.get_input_names()
  138. input_tensor = predictor.get_input_handle(input_names[0])
  139. output_names = predictor.get_output_names()
  140. output_tensor = predictor.get_output_handle(output_names[0])
  141. return predictor, config, input_tensor, output_tensor
  142. def preprocess(self, img_path, short_size):
  143. """preprocess
  144. Preprocess to the input.
  145. Args:
  146. img_path: Image path.
  147. Returns: Input data after preprocess.
  148. """
  149. img = cv2.imread(img_path, 1)
  150. img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
  151. h, w = img.shape[:2]
  152. img = resize_image(img, short_size)
  153. img = self.transforms(img)
  154. img = np.expand_dims(img, axis=0)
  155. shape_info = {"shape": [(h, w)]}
  156. return img, shape_info
  157. def postprocess(self, x, shape_info, is_output_polygon):
  158. """postprocess
  159. Postprocess to the inference engine output.
  160. Args:
  161. x: Inference engine output.
  162. Returns: Output data after argmax.
  163. """
  164. box_list, score_list = self.post_process(
  165. shape_info, x, is_output_polygon=is_output_polygon
  166. )
  167. box_list, score_list = box_list[0], score_list[0]
  168. if len(box_list) > 0:
  169. if is_output_polygon:
  170. idx = [x.sum() > 0 for x in box_list]
  171. box_list = [box_list[i] for i, v in enumerate(idx) if v]
  172. score_list = [score_list[i] for i, v in enumerate(idx) if v]
  173. else:
  174. idx = (
  175. box_list.reshape(box_list.shape[0], -1).sum(axis=1) > 0
  176. ) # 去掉全为0的框
  177. box_list, score_list = box_list[idx], score_list[idx]
  178. else:
  179. box_list, score_list = [], []
  180. return box_list, score_list
  181. def run(self, x):
  182. """run
  183. Inference process using inference engine.
  184. Args:
  185. x: Input data after preprocess.
  186. Returns: Inference engine output
  187. """
  188. self.input_tensor.copy_from_cpu(x)
  189. self.predictor.run()
  190. output = self.output_tensor.copy_to_cpu()
  191. return output
  192. def get_args(add_help=True):
  193. """
  194. parse args
  195. """
  196. import argparse
  197. def str2bool(v):
  198. return v.lower() in ("true", "t", "1")
  199. parser = argparse.ArgumentParser(
  200. description="PaddlePaddle Classification Training", add_help=add_help
  201. )
  202. parser.add_argument("--model_dir", default=None, help="inference model dir")
  203. parser.add_argument("--batch_size", type=int, default=1)
  204. parser.add_argument("--short_size", default=1024, type=int, help="short size")
  205. parser.add_argument("--img_path", default="./images/demo.jpg")
  206. parser.add_argument("--benchmark", default=False, type=str2bool, help="benchmark")
  207. parser.add_argument("--warmup", default=0, type=int, help="warmup iter")
  208. parser.add_argument("--polygon", action="store_true", help="output polygon or box")
  209. parser.add_argument("--use_gpu", type=str2bool, default=True)
  210. parser.add_argument("--use_tensorrt", type=str2bool, default=False)
  211. parser.add_argument("--precision", type=str, default="fp32")
  212. parser.add_argument("--gpu_mem", type=int, default=500)
  213. parser.add_argument("--gpu_id", type=int, default=0)
  214. parser.add_argument("--enable_mkldnn", type=str2bool, default=False)
  215. parser.add_argument("--cpu_threads", type=int, default=10)
  216. args = parser.parse_args()
  217. return args
  218. def main(args):
  219. """
  220. Main inference function.
  221. Args:
  222. args: Parameters generated using argparser.
  223. Returns:
  224. class_id: Class index of the input.
  225. prob: : Probability of the input.
  226. """
  227. inference_engine = InferenceEngine(args)
  228. # init benchmark
  229. if args.benchmark:
  230. import auto_log
  231. autolog = auto_log.AutoLogger(
  232. model_name="db",
  233. batch_size=args.batch_size,
  234. inference_config=inference_engine.config,
  235. gpu_ids="auto" if args.use_gpu else None,
  236. )
  237. # enable benchmark
  238. if args.benchmark:
  239. autolog.times.start()
  240. # preprocess
  241. img, shape_info = inference_engine.preprocess(args.img_path, args.short_size)
  242. if args.benchmark:
  243. autolog.times.stamp()
  244. output = inference_engine.run(img)
  245. if args.benchmark:
  246. autolog.times.stamp()
  247. # postprocess
  248. box_list, score_list = inference_engine.postprocess(
  249. output, shape_info, args.polygon
  250. )
  251. if args.benchmark:
  252. autolog.times.stamp()
  253. autolog.times.end(stamp=True)
  254. autolog.report()
  255. img = draw_bbox(cv2.imread(args.img_path)[:, :, ::-1], box_list)
  256. # 保存结果到路径
  257. os.makedirs("output", exist_ok=True)
  258. img_path = pathlib.Path(args.img_path)
  259. output_path = os.path.join("output", img_path.stem + "_infer_result.jpg")
  260. cv2.imwrite(output_path, img[:, :, ::-1])
  261. save_result(
  262. output_path.replace("_infer_result.jpg", ".txt"),
  263. box_list,
  264. score_list,
  265. args.polygon,
  266. )
  267. if __name__ == "__main__":
  268. args = get_args()
  269. main(args)