export_model.py 7.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189
  1. # Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. import os
  15. import sys
  16. __dir__ = os.path.dirname(os.path.abspath(__file__))
  17. sys.path.append(__dir__)
  18. sys.path.insert(0, os.path.abspath(os.path.join(__dir__, "..", "..", "..")))
  19. sys.path.insert(0, os.path.abspath(os.path.join(__dir__, "..", "..", "..", "tools")))
  20. import argparse
  21. import paddle
  22. from paddle.jit import to_static
  23. from ppocr.modeling.architectures import build_model
  24. from ppocr.postprocess import build_post_process
  25. from ppocr.utils.save_load import load_model
  26. from ppocr.utils.logging import get_logger
  27. from tools.program import load_config, merge_config, ArgsParser
  28. from ppocr.metrics import build_metric
  29. import tools.program as program
  30. from paddleslim.dygraph.quant import QAT
  31. from ppocr.data import build_dataloader, set_signal_handlers
  32. from ppocr.utils.export_model import export_single_model
  33. def main():
  34. ############################################################################################################
  35. # 1. quantization configs
  36. ############################################################################################################
  37. quant_config = {
  38. # weight preprocess type, default is None and no preprocessing is performed.
  39. "weight_preprocess_type": None,
  40. # activation preprocess type, default is None and no preprocessing is performed.
  41. "activation_preprocess_type": None,
  42. # weight quantize type, default is 'channel_wise_abs_max'
  43. "weight_quantize_type": "channel_wise_abs_max",
  44. # activation quantize type, default is 'moving_average_abs_max'
  45. "activation_quantize_type": "moving_average_abs_max",
  46. # weight quantize bit num, default is 8
  47. "weight_bits": 8,
  48. # activation quantize bit num, default is 8
  49. "activation_bits": 8,
  50. # data type after quantization, such as 'uint8', 'int8', etc. default is 'int8'
  51. "dtype": "int8",
  52. # window size for 'range_abs_max' quantization. default is 10000
  53. "window_size": 10000,
  54. # The decay coefficient of moving average, default is 0.9
  55. "moving_rate": 0.9,
  56. # for dygraph quantization, layers of type in quantizable_layer_type will be quantized
  57. "quantizable_layer_type": ["Conv2D", "Linear"],
  58. }
  59. FLAGS = ArgsParser().parse_args()
  60. config = load_config(FLAGS.config)
  61. config = merge_config(config, FLAGS.opt)
  62. logger = get_logger()
  63. # build post process
  64. post_process_class = build_post_process(config["PostProcess"], config["Global"])
  65. # build model
  66. if hasattr(post_process_class, "character"):
  67. char_num = len(getattr(post_process_class, "character"))
  68. if config["Architecture"]["algorithm"] in [
  69. "Distillation",
  70. ]: # distillation model
  71. for key in config["Architecture"]["Models"]:
  72. if (
  73. config["Architecture"]["Models"][key]["Head"]["name"] == "MultiHead"
  74. ): # for multi head
  75. if config["PostProcess"]["name"] == "DistillationSARLabelDecode":
  76. char_num = char_num - 2
  77. # update SARLoss params
  78. assert (
  79. list(config["Loss"]["loss_config_list"][-1].keys())[0]
  80. == "DistillationSARLoss"
  81. )
  82. config["Loss"]["loss_config_list"][-1]["DistillationSARLoss"][
  83. "ignore_index"
  84. ] = (char_num + 1)
  85. out_channels_list = {}
  86. out_channels_list["CTCLabelDecode"] = char_num
  87. out_channels_list["SARLabelDecode"] = char_num + 2
  88. config["Architecture"]["Models"][key]["Head"][
  89. "out_channels_list"
  90. ] = out_channels_list
  91. else:
  92. config["Architecture"]["Models"][key]["Head"][
  93. "out_channels"
  94. ] = char_num
  95. elif config["Architecture"]["Head"]["name"] == "MultiHead": # for multi head
  96. if config["PostProcess"]["name"] == "SARLabelDecode":
  97. char_num = char_num - 2
  98. # update SARLoss params
  99. assert list(config["Loss"]["loss_config_list"][1].keys())[0] == "SARLoss"
  100. if config["Loss"]["loss_config_list"][1]["SARLoss"] is None:
  101. config["Loss"]["loss_config_list"][1]["SARLoss"] = {
  102. "ignore_index": char_num + 1
  103. }
  104. else:
  105. config["Loss"]["loss_config_list"][1]["SARLoss"]["ignore_index"] = (
  106. char_num + 1
  107. )
  108. out_channels_list = {}
  109. out_channels_list["CTCLabelDecode"] = char_num
  110. out_channels_list["SARLabelDecode"] = char_num + 2
  111. config["Architecture"]["Head"]["out_channels_list"] = out_channels_list
  112. else: # base rec model
  113. config["Architecture"]["Head"]["out_channels"] = char_num
  114. if config["PostProcess"]["name"] == "SARLabelDecode": # for SAR model
  115. config["Loss"]["ignore_index"] = char_num - 1
  116. model = build_model(config["Architecture"])
  117. # get QAT model
  118. quanter = QAT(config=quant_config)
  119. quanter.quantize(model)
  120. load_model(config, model)
  121. # build metric
  122. eval_class = build_metric(config["Metric"])
  123. # build dataloader
  124. set_signal_handlers()
  125. valid_dataloader = build_dataloader(config, "Eval", device, logger)
  126. use_srn = config["Architecture"]["algorithm"] == "SRN"
  127. model_type = config["Architecture"].get("model_type", None)
  128. # start eval
  129. metric = program.eval(
  130. model, valid_dataloader, post_process_class, eval_class, model_type, use_srn
  131. )
  132. model.eval()
  133. logger.info("metric eval ***************")
  134. for k, v in metric.items():
  135. logger.info("{}:{}".format(k, v))
  136. save_path = config["Global"]["save_inference_dir"]
  137. arch_config = config["Architecture"]
  138. if (
  139. arch_config["algorithm"] == "SVTR"
  140. and arch_config["Head"]["name"] != "MultiHead"
  141. ):
  142. input_shape = config["Eval"]["dataset"]["transforms"][-2]["SVTRRecResizeImg"][
  143. "image_shape"
  144. ]
  145. else:
  146. input_shape = None
  147. if arch_config["algorithm"] in [
  148. "Distillation",
  149. ]: # distillation model
  150. archs = list(arch_config["Models"].values())
  151. for idx, name in enumerate(model.model_name_list):
  152. sub_model_save_path = os.path.join(save_path, name, "inference")
  153. export_single_model(
  154. model.model_list[idx],
  155. archs[idx],
  156. sub_model_save_path,
  157. logger,
  158. input_shape,
  159. quanter,
  160. )
  161. else:
  162. save_path = os.path.join(save_path, "inference")
  163. export_single_model(model, arch_config, save_path, logger, input_shape, quanter)
  164. if __name__ == "__main__":
  165. config, device, logger, vdl_writer = program.preprocess()
  166. main()