predict.py 6.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175
  1. # -*- coding: utf-8 -*-
  2. # @Time : 2019/8/24 12:06
  3. # @Author : zhoujun
  4. import os
  5. import sys
  6. import pathlib
  7. __dir__ = pathlib.Path(os.path.abspath(__file__))
  8. sys.path.append(str(__dir__))
  9. sys.path.append(str(__dir__.parent.parent))
  10. import time
  11. import cv2
  12. import paddle
  13. from data_loader import get_transforms
  14. from models import build_model
  15. from post_processing import get_post_processing
  16. def resize_image(img, short_size):
  17. height, width, _ = img.shape
  18. if height < width:
  19. new_height = short_size
  20. new_width = new_height / height * width
  21. else:
  22. new_width = short_size
  23. new_height = new_width / width * height
  24. new_height = int(round(new_height / 32) * 32)
  25. new_width = int(round(new_width / 32) * 32)
  26. resized_img = cv2.resize(img, (new_width, new_height))
  27. return resized_img
  28. class PaddleModel:
  29. def __init__(self, model_path, post_p_thre=0.7, gpu_id=None):
  30. """
  31. 初始化模型
  32. :param model_path: 模型地址(可以是模型的参数或者参数和计算图一起保存的文件)
  33. :param gpu_id: 在哪一块gpu上运行
  34. """
  35. self.gpu_id = gpu_id
  36. if (
  37. self.gpu_id is not None
  38. and isinstance(self.gpu_id, int)
  39. and paddle.device.is_compiled_with_cuda()
  40. ):
  41. paddle.device.set_device("gpu:{}".format(self.gpu_id))
  42. else:
  43. paddle.device.set_device("cpu")
  44. checkpoint = paddle.load(model_path)
  45. config = checkpoint["config"]
  46. config["arch"]["backbone"]["pretrained"] = False
  47. self.model = build_model(config["arch"])
  48. self.post_process = get_post_processing(config["post_processing"])
  49. self.post_process.box_thresh = post_p_thre
  50. self.img_mode = config["dataset"]["train"]["dataset"]["args"]["img_mode"]
  51. self.model.set_state_dict(checkpoint["state_dict"])
  52. self.model.eval()
  53. self.transform = []
  54. for t in config["dataset"]["train"]["dataset"]["args"]["transforms"]:
  55. if t["type"] in ["ToTensor", "Normalize"]:
  56. self.transform.append(t)
  57. self.transform = get_transforms(self.transform)
  58. def predict(self, img_path: str, is_output_polygon=False, short_size: int = 1024):
  59. """
  60. 对传入的图像进行预测,支持图像地址,opencv 读取图片,偏慢
  61. :param img_path: 图像地址
  62. :param is_numpy:
  63. :return:
  64. """
  65. assert os.path.exists(img_path), "file is not exists"
  66. img = cv2.imread(img_path, 1 if self.img_mode != "GRAY" else 0)
  67. if self.img_mode == "RGB":
  68. img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
  69. h, w = img.shape[:2]
  70. img = resize_image(img, short_size)
  71. # 将图片由(w,h)变为(1,img_channel,h,w)
  72. tensor = self.transform(img)
  73. tensor = tensor.unsqueeze_(0)
  74. batch = {"shape": [(h, w)]}
  75. with paddle.no_grad():
  76. start = time.time()
  77. preds = self.model(tensor)
  78. box_list, score_list = self.post_process(
  79. batch, preds, is_output_polygon=is_output_polygon
  80. )
  81. box_list, score_list = box_list[0], score_list[0]
  82. if len(box_list) > 0:
  83. if is_output_polygon:
  84. idx = [x.sum() > 0 for x in box_list]
  85. box_list = [box_list[i] for i, v in enumerate(idx) if v]
  86. score_list = [score_list[i] for i, v in enumerate(idx) if v]
  87. else:
  88. idx = (
  89. box_list.reshape(box_list.shape[0], -1).sum(axis=1) > 0
  90. ) # 去掉全为0的框
  91. box_list, score_list = box_list[idx], score_list[idx]
  92. else:
  93. box_list, score_list = [], []
  94. t = time.time() - start
  95. return preds[0, 0, :, :].detach().cpu().numpy(), box_list, score_list, t
  96. def save_depoly(net, input, save_path):
  97. input_spec = [paddle.static.InputSpec(shape=[None, 3, None, None], dtype="float32")]
  98. net = paddle.jit.to_static(net, input_spec=input_spec)
  99. # save static model for inference directly
  100. paddle.jit.save(net, save_path)
  101. def init_args():
  102. import argparse
  103. parser = argparse.ArgumentParser(description="DBNet.paddle")
  104. parser.add_argument("--model_path", default=r"model_best.pth", type=str)
  105. parser.add_argument(
  106. "--input_folder", default="./test/input", type=str, help="img path for predict"
  107. )
  108. parser.add_argument(
  109. "--output_folder", default="./test/output", type=str, help="img path for output"
  110. )
  111. parser.add_argument("--gpu", default=0, type=int, help="gpu for inference")
  112. parser.add_argument(
  113. "--thre", default=0.3, type=float, help="the thresh of post_processing"
  114. )
  115. parser.add_argument("--polygon", action="store_true", help="output polygon or box")
  116. parser.add_argument("--show", action="store_true", help="show result")
  117. parser.add_argument(
  118. "--save_result", action="store_true", help="save box and score to txt file"
  119. )
  120. args = parser.parse_args()
  121. return args
  122. if __name__ == "__main__":
  123. import pathlib
  124. from tqdm import tqdm
  125. import matplotlib.pyplot as plt
  126. from utils.util import show_img, draw_bbox, save_result, get_image_file_list
  127. args = init_args()
  128. print(args)
  129. # 初始化网络
  130. model = PaddleModel(args.model_path, post_p_thre=args.thre, gpu_id=args.gpu)
  131. img_folder = pathlib.Path(args.input_folder)
  132. for img_path in tqdm(get_image_file_list(args.input_folder)):
  133. preds, boxes_list, score_list, t = model.predict(
  134. img_path, is_output_polygon=args.polygon
  135. )
  136. img = draw_bbox(cv2.imread(img_path)[:, :, ::-1], boxes_list)
  137. if args.show:
  138. show_img(preds)
  139. show_img(img, title=os.path.basename(img_path))
  140. plt.show()
  141. # 保存结果到路径
  142. os.makedirs(args.output_folder, exist_ok=True)
  143. img_path = pathlib.Path(img_path)
  144. output_path = os.path.join(args.output_folder, img_path.stem + "_result.jpg")
  145. pred_path = os.path.join(args.output_folder, img_path.stem + "_pred.jpg")
  146. cv2.imwrite(output_path, img[:, :, ::-1])
  147. cv2.imwrite(pred_path, preds * 255)
  148. save_result(
  149. output_path.replace("_result.jpg", ".txt"),
  150. boxes_list,
  151. score_list,
  152. args.polygon,
  153. )