eval.py 2.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293
  1. # -*- coding: utf-8 -*-
  2. # @Time : 2018/6/11 15:54
  3. # @Author : zhoujun
  4. import os
  5. import sys
  6. import pathlib
  7. __dir__ = pathlib.Path(os.path.abspath(__file__))
  8. sys.path.append(str(__dir__))
  9. sys.path.append(str(__dir__.parent.parent))
  10. import argparse
  11. import time
  12. import paddle
  13. from tqdm.auto import tqdm
  14. class EVAL:
  15. def __init__(self, model_path, gpu_id=0):
  16. from models import build_model
  17. from data_loader import get_dataloader
  18. from post_processing import get_post_processing
  19. from utils import get_metric
  20. self.gpu_id = gpu_id
  21. if (
  22. self.gpu_id is not None
  23. and isinstance(self.gpu_id, int)
  24. and paddle.device.is_compiled_with_cuda()
  25. ):
  26. paddle.device.set_device("gpu:{}".format(self.gpu_id))
  27. else:
  28. paddle.device.set_device("cpu")
  29. checkpoint = paddle.load(model_path)
  30. config = checkpoint["config"]
  31. config["arch"]["backbone"]["pretrained"] = False
  32. self.validate_loader = get_dataloader(
  33. config["dataset"]["validate"], config["distributed"]
  34. )
  35. self.model = build_model(config["arch"])
  36. self.model.set_state_dict(checkpoint["state_dict"])
  37. self.post_process = get_post_processing(config["post_processing"])
  38. self.metric_cls = get_metric(config["metric"])
  39. def eval(self):
  40. self.model.eval()
  41. raw_metrics = []
  42. total_frame = 0.0
  43. total_time = 0.0
  44. for i, batch in tqdm(
  45. enumerate(self.validate_loader),
  46. total=len(self.validate_loader),
  47. desc="test model",
  48. ):
  49. with paddle.no_grad():
  50. start = time.time()
  51. preds = self.model(batch["img"])
  52. boxes, scores = self.post_process(
  53. batch, preds, is_output_polygon=self.metric_cls.is_output_polygon
  54. )
  55. total_frame += batch["img"].shape[0]
  56. total_time += time.time() - start
  57. raw_metric = self.metric_cls.validate_measure(batch, (boxes, scores))
  58. raw_metrics.append(raw_metric)
  59. metrics = self.metric_cls.gather_measure(raw_metrics)
  60. print("FPS:{}".format(total_frame / total_time))
  61. return {
  62. "recall": metrics["recall"].avg,
  63. "precision": metrics["precision"].avg,
  64. "fmeasure": metrics["fmeasure"].avg,
  65. }
  66. def init_args():
  67. parser = argparse.ArgumentParser(description="DBNet.paddle")
  68. parser.add_argument(
  69. "--model_path",
  70. required=False,
  71. default="output/DBNet_resnet18_FPN_DBHead/checkpoint/1.pth",
  72. type=str,
  73. )
  74. args = parser.parse_args()
  75. return args
  76. if __name__ == "__main__":
  77. args = init_args()
  78. eval = EVAL(args.model_path)
  79. result = eval.eval()
  80. print(result)