run.py 5.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171
  1. # Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. import logging
  15. from tqdm import tqdm
  16. import numpy as np
  17. import argparse
  18. import paddle
  19. from paddleslim.common import load_config as load_slim_config
  20. from paddleslim.common import get_logger
  21. from paddleslim.auto_compression import AutoCompression
  22. from paddleslim.common.dataloader import get_feed_vars
import sys

# Extend the import search path with "../../../" before importing ppocr below.
# The path is relative, so it is resolved against the current working directory.
# NOTE(review): presumably this points at the repository root when the script
# is launched from its own directory — confirm.
sys.path.append("../../../")
from ppocr.data import build_dataloader
from ppocr.postprocess import build_post_process
from ppocr.metrics import build_metric

# Module-wide logger at INFO level, used by eval_function and main.
logger = get_logger(__name__, level=logging.INFO)
  29. def argsparser():
  30. parser = argparse.ArgumentParser(description=__doc__)
  31. parser.add_argument(
  32. "--config_path",
  33. type=str,
  34. default=None,
  35. help="path of compression strategy config.",
  36. required=True,
  37. )
  38. parser.add_argument(
  39. "--save_dir",
  40. type=str,
  41. default="output",
  42. help="directory to save compressed model.",
  43. )
  44. parser.add_argument(
  45. "--devices", type=str, default="gpu", help="which device used to compress."
  46. )
  47. return parser
  48. def reader_wrapper(reader, input_name):
  49. if isinstance(input_name, list) and len(input_name) == 1:
  50. input_name = input_name[0]
  51. def gen(): # 形成一个字典输入
  52. for i, batch in enumerate(reader()):
  53. yield {input_name: batch[0]}
  54. return gen
  55. def eval_function(exe, compiled_test_program, test_feed_names, test_fetch_list):
  56. post_process_class = build_post_process(all_config["PostProcess"], global_config)
  57. eval_class = build_metric(all_config["Metric"])
  58. model_type = global_config["model_type"]
  59. with tqdm(
  60. total=len(val_loader),
  61. bar_format="Evaluation stage, Run batch:|{bar}| {n_fmt}/{total_fmt}",
  62. ncols=80,
  63. ) as t:
  64. for batch_id, batch in enumerate(val_loader):
  65. images = batch[0]
  66. try:
  67. (preds,) = exe.run(
  68. compiled_test_program,
  69. feed={test_feed_names[0]: images},
  70. fetch_list=test_fetch_list,
  71. )
  72. except:
  73. preds, _ = exe.run(
  74. compiled_test_program,
  75. feed={test_feed_names[0]: images},
  76. fetch_list=test_fetch_list,
  77. )
  78. batch_numpy = []
  79. for item in batch:
  80. batch_numpy.append(np.array(item))
  81. if model_type == "det":
  82. preds_map = {"maps": preds}
  83. post_result = post_process_class(preds_map, batch_numpy[1])
  84. eval_class(post_result, batch_numpy)
  85. elif model_type == "rec":
  86. post_result = post_process_class(preds, batch_numpy[1])
  87. eval_class(post_result, batch_numpy)
  88. t.update()
  89. metric = eval_class.get_metric()
  90. logger.info("metric eval ***************")
  91. for k, v in metric.items():
  92. logger.info("{}:{}".format(k, v))
  93. if model_type == "det":
  94. return metric["hmean"]
  95. elif model_type == "rec":
  96. return metric["acc"]
  97. return metric
  98. def main():
  99. rank_id = paddle.distributed.get_rank()
  100. if args.devices == "gpu":
  101. place = paddle.CUDAPlace(rank_id)
  102. paddle.set_device("gpu")
  103. else:
  104. place = paddle.CPUPlace()
  105. paddle.set_device("cpu")
  106. global all_config, global_config
  107. all_config = load_slim_config(args.config_path)
  108. if "Global" not in all_config:
  109. raise KeyError(f"Key 'Global' not found in config file. \n{all_config}")
  110. global_config = all_config["Global"]
  111. gpu_num = paddle.distributed.get_world_size()
  112. train_dataloader = build_dataloader(all_config, "Train", args.devices, logger)
  113. global val_loader
  114. val_loader = build_dataloader(all_config, "Eval", args.devices, logger)
  115. if (
  116. isinstance(all_config["TrainConfig"]["learning_rate"], dict)
  117. and all_config["TrainConfig"]["learning_rate"]["type"] == "CosineAnnealingDecay"
  118. ):
  119. steps = len(train_dataloader) * all_config["TrainConfig"]["epochs"]
  120. all_config["TrainConfig"]["learning_rate"]["T_max"] = steps
  121. print("total training steps:", steps)
  122. global_config["input_name"] = get_feed_vars(
  123. global_config["model_dir"],
  124. global_config["model_filename"],
  125. global_config["params_filename"],
  126. )
  127. ac = AutoCompression(
  128. model_dir=global_config["model_dir"],
  129. model_filename=global_config["model_filename"],
  130. params_filename=global_config["params_filename"],
  131. save_dir=args.save_dir,
  132. config=all_config,
  133. train_dataloader=reader_wrapper(train_dataloader, global_config["input_name"]),
  134. eval_callback=eval_function if rank_id == 0 else None,
  135. eval_dataloader=reader_wrapper(val_loader, global_config["input_name"]),
  136. )
  137. ac.compress()
if __name__ == "__main__":
    # Switch Paddle to static-graph mode before anything else runs.
    paddle.enable_static()
    parser = argsparser()
    # `args` is deliberately module-level: main() reads it as a global.
    args = parser.parse_args()
    main()