eval.py 6.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181
  1. # Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. from __future__ import absolute_import
  15. from __future__ import division
  16. from __future__ import print_function
  17. import os
  18. import sys
  19. __dir__ = os.path.dirname(os.path.abspath(__file__))
  20. sys.path.insert(0, __dir__)
  21. sys.path.insert(0, os.path.abspath(os.path.join(__dir__, "..")))
  22. import paddle
  23. from ppocr.data import build_dataloader, set_signal_handlers
  24. from ppocr.modeling.architectures import build_model
  25. from ppocr.postprocess import build_post_process
  26. from ppocr.metrics import build_metric
  27. from ppocr.utils.save_load import load_model
  28. import tools.program as program
  29. def main():
  30. global_config = config["Global"]
  31. # build dataloader
  32. set_signal_handlers()
  33. valid_dataloader = build_dataloader(config, "Eval", device, logger)
  34. # build post process
  35. post_process_class = build_post_process(config["PostProcess"], global_config)
  36. # build model
  37. # for rec algorithm
  38. if hasattr(post_process_class, "character"):
  39. char_num = len(getattr(post_process_class, "character"))
  40. if config["Architecture"]["algorithm"] in [
  41. "Distillation",
  42. ]: # distillation model
  43. for key in config["Architecture"]["Models"]:
  44. if (
  45. config["Architecture"]["Models"][key]["Head"]["name"] == "MultiHead"
  46. ): # for multi head
  47. out_channels_list = {}
  48. if config["PostProcess"]["name"] == "DistillationSARLabelDecode":
  49. char_num = char_num - 2
  50. if config["PostProcess"]["name"] == "DistillationNRTRLabelDecode":
  51. char_num = char_num - 3
  52. out_channels_list["CTCLabelDecode"] = char_num
  53. out_channels_list["SARLabelDecode"] = char_num + 2
  54. out_channels_list["NRTRLabelDecode"] = char_num + 3
  55. config["Architecture"]["Models"][key]["Head"][
  56. "out_channels_list"
  57. ] = out_channels_list
  58. else:
  59. config["Architecture"]["Models"][key]["Head"][
  60. "out_channels"
  61. ] = char_num
  62. elif config["Architecture"]["Head"]["name"] == "MultiHead": # for multi head
  63. out_channels_list = {}
  64. if config["PostProcess"]["name"] == "SARLabelDecode":
  65. char_num = char_num - 2
  66. if config["PostProcess"]["name"] == "NRTRLabelDecode":
  67. char_num = char_num - 3
  68. out_channels_list["CTCLabelDecode"] = char_num
  69. out_channels_list["SARLabelDecode"] = char_num + 2
  70. out_channels_list["NRTRLabelDecode"] = char_num + 3
  71. config["Architecture"]["Head"]["out_channels_list"] = out_channels_list
  72. else: # base rec model
  73. config["Architecture"]["Head"]["out_channels"] = char_num
  74. model = build_model(config["Architecture"])
  75. extra_input_models = [
  76. "SRN",
  77. "NRTR",
  78. "SAR",
  79. "SEED",
  80. "SVTR",
  81. "SVTR_LCNet",
  82. "VisionLAN",
  83. "RobustScanner",
  84. "SVTR_HGNet",
  85. ]
  86. extra_input = False
  87. if config["Architecture"]["algorithm"] == "Distillation":
  88. for key in config["Architecture"]["Models"]:
  89. extra_input = (
  90. extra_input
  91. or config["Architecture"]["Models"][key]["algorithm"]
  92. in extra_input_models
  93. )
  94. else:
  95. extra_input = config["Architecture"]["algorithm"] in extra_input_models
  96. if "model_type" in config["Architecture"].keys():
  97. if config["Architecture"]["algorithm"] == "CAN":
  98. model_type = "can"
  99. elif config["Architecture"]["algorithm"] == "LaTeXOCR":
  100. model_type = "latexocr"
  101. config["Metric"]["cal_bleu_score"] = True
  102. elif config["Architecture"]["algorithm"] == "UniMERNet":
  103. model_type = "unimernet"
  104. config["Metric"]["cal_bleu_score"] = True
  105. elif config["Architecture"]["algorithm"] in [
  106. "PP-FormulaNet-S",
  107. "PP-FormulaNet-L",
  108. "PP-FormulaNet_plus-S",
  109. "PP-FormulaNet_plus-M",
  110. "PP-FormulaNet_plus-L",
  111. ]:
  112. model_type = "pp_formulanet"
  113. config["Metric"]["cal_bleu_score"] = True
  114. else:
  115. model_type = config["Architecture"]["model_type"]
  116. else:
  117. model_type = None
  118. # build metric
  119. eval_class = build_metric(config["Metric"])
  120. # amp
  121. use_amp = config["Global"].get("use_amp", False)
  122. amp_level = config["Global"].get("amp_level", "O2")
  123. amp_custom_black_list = config["Global"].get("amp_custom_black_list", [])
  124. if use_amp:
  125. AMP_RELATED_FLAGS_SETTING = {
  126. "FLAGS_cudnn_batchnorm_spatial_persistent": 1,
  127. }
  128. paddle.set_flags(AMP_RELATED_FLAGS_SETTING)
  129. scale_loss = config["Global"].get("scale_loss", 1.0)
  130. use_dynamic_loss_scaling = config["Global"].get(
  131. "use_dynamic_loss_scaling", False
  132. )
  133. scaler = paddle.amp.GradScaler(
  134. init_loss_scaling=scale_loss,
  135. use_dynamic_loss_scaling=use_dynamic_loss_scaling,
  136. )
  137. if amp_level == "O2":
  138. model = paddle.amp.decorate(
  139. models=model, level=amp_level, master_weight=True
  140. )
  141. else:
  142. scaler = None
  143. best_model_dict = load_model(
  144. config, model, model_type=config["Architecture"]["model_type"]
  145. )
  146. if len(best_model_dict):
  147. logger.info("metric in ckpt ***************")
  148. for k, v in best_model_dict.items():
  149. logger.info("{}:{}".format(k, v))
  150. # start eval
  151. metric = program.eval(
  152. model,
  153. valid_dataloader,
  154. post_process_class,
  155. eval_class,
  156. model_type,
  157. extra_input,
  158. scaler,
  159. amp_level,
  160. amp_custom_black_list,
  161. )
  162. logger.info("metric eval ***************")
  163. for k, v in metric.items():
  164. logger.info("{}:{}".format(k, v))
  165. if __name__ == "__main__":
  166. config, device, logger, vdl_writer = program.preprocess()
  167. main()