export_model.py 1.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657
  1. import os
  2. import sys
  3. __dir__ = os.path.dirname(os.path.abspath(__file__))
  4. sys.path.append(__dir__)
  5. sys.path.insert(0, os.path.abspath(os.path.join(__dir__, "..")))
  6. import argparse
  7. import paddle
  8. from paddle.jit import to_static
  9. from models import build_model
  10. from utils import Config, ArgsParser
  11. def init_args():
  12. parser = ArgsParser()
  13. args = parser.parse_args()
  14. return args
  15. def load_checkpoint(model, checkpoint_path):
  16. """
  17. load checkpoints
  18. :param checkpoint_path: Checkpoint path to be loaded
  19. """
  20. checkpoint = paddle.load(checkpoint_path)
  21. model.set_state_dict(checkpoint["state_dict"])
  22. print("load checkpoint from {}".format(checkpoint_path))
  23. def main(config):
  24. model = build_model(config["arch"])
  25. load_checkpoint(model, config["trainer"]["resume_checkpoint"])
  26. model.eval()
  27. save_path = config["trainer"]["output_dir"]
  28. save_path = os.path.join(save_path, "inference")
  29. infer_shape = [3, -1, -1]
  30. model = to_static(
  31. model,
  32. input_spec=[
  33. paddle.static.InputSpec(shape=[None] + infer_shape, dtype="float32")
  34. ],
  35. )
  36. paddle.jit.save(model, save_path)
  37. print("inference model is saved to {}".format(save_path))
  38. if __name__ == "__main__":
  39. args = init_args()
  40. assert os.path.exists(args.config_file)
  41. config = Config(args.config_file)
  42. config.merge_dict(args.opt)
  43. main(config.cfg)