infer_rec.py 9.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232
  1. # Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. from __future__ import absolute_import
  15. from __future__ import division
  16. from __future__ import print_function
  17. import numpy as np
  18. import os
  19. import sys
  20. import json
  21. __dir__ = os.path.dirname(os.path.abspath(__file__))
  22. sys.path.append(__dir__)
  23. sys.path.insert(0, os.path.abspath(os.path.join(__dir__, "..")))
  24. os.environ["FLAGS_allocator_strategy"] = "auto_growth"
  25. import paddle
  26. from ppocr.data import create_operators, transform
  27. from ppocr.modeling.architectures import build_model
  28. from ppocr.postprocess import build_post_process
  29. from ppocr.utils.save_load import load_model
  30. from ppocr.utils.utility import get_image_file_list
  31. import tools.program as program
  32. def main():
  33. global_config = config["Global"]
  34. if config["Architecture"].get("algorithm") in [
  35. "UniMERNet",
  36. "PP-FormulaNet-S",
  37. "PP-FormulaNet-L",
  38. "PP-FormulaNet_plus-S",
  39. "PP-FormulaNet_plus-M",
  40. "PP-FormulaNet_plus-L",
  41. ]:
  42. config["PostProcess"]["is_infer"] = True
  43. # build post process
  44. post_process_class = build_post_process(config["PostProcess"], global_config)
  45. # build model
  46. if hasattr(post_process_class, "character"):
  47. char_num = len(getattr(post_process_class, "character"))
  48. if config["Architecture"]["algorithm"] in [
  49. "Distillation",
  50. ]: # distillation model
  51. for key in config["Architecture"]["Models"]:
  52. if (
  53. config["Architecture"]["Models"][key]["Head"]["name"] == "MultiHead"
  54. ): # multi head
  55. out_channels_list = {}
  56. if config["PostProcess"]["name"] == "DistillationSARLabelDecode":
  57. char_num = char_num - 2
  58. if config["PostProcess"]["name"] == "DistillationNRTRLabelDecode":
  59. char_num = char_num - 3
  60. out_channels_list["CTCLabelDecode"] = char_num
  61. out_channels_list["SARLabelDecode"] = char_num + 2
  62. out_channels_list["NRTRLabelDecode"] = char_num + 3
  63. config["Architecture"]["Models"][key]["Head"][
  64. "out_channels_list"
  65. ] = out_channels_list
  66. else:
  67. config["Architecture"]["Models"][key]["Head"][
  68. "out_channels"
  69. ] = char_num
  70. elif config["Architecture"]["Head"]["name"] == "MultiHead": # multi head
  71. out_channels_list = {}
  72. char_num = len(getattr(post_process_class, "character"))
  73. if config["PostProcess"]["name"] == "SARLabelDecode":
  74. char_num = char_num - 2
  75. if config["PostProcess"]["name"] == "NRTRLabelDecode":
  76. char_num = char_num - 3
  77. out_channels_list["CTCLabelDecode"] = char_num
  78. out_channels_list["SARLabelDecode"] = char_num + 2
  79. out_channels_list["NRTRLabelDecode"] = char_num + 3
  80. config["Architecture"]["Head"]["out_channels_list"] = out_channels_list
  81. else: # base rec model
  82. config["Architecture"]["Head"]["out_channels"] = char_num
  83. if config["Architecture"].get("algorithm") in ["LaTeXOCR"]:
  84. config["Architecture"]["Backbone"]["is_predict"] = True
  85. config["Architecture"]["Backbone"]["is_export"] = True
  86. config["Architecture"]["Head"]["is_export"] = True
  87. model = build_model(config["Architecture"])
  88. load_model(config, model)
  89. # create data ops
  90. transforms = []
  91. for op in config["Eval"]["dataset"]["transforms"]:
  92. op_name = list(op)[0]
  93. if "Label" in op_name:
  94. continue
  95. elif op_name in ["RecResizeImg"]:
  96. op[op_name]["infer_mode"] = True
  97. elif op_name == "KeepKeys":
  98. if config["Architecture"]["algorithm"] == "SRN":
  99. op[op_name]["keep_keys"] = [
  100. "image",
  101. "encoder_word_pos",
  102. "gsrm_word_pos",
  103. "gsrm_slf_attn_bias1",
  104. "gsrm_slf_attn_bias2",
  105. ]
  106. elif config["Architecture"]["algorithm"] == "SAR":
  107. op[op_name]["keep_keys"] = ["image", "valid_ratio"]
  108. elif config["Architecture"]["algorithm"] == "RobustScanner":
  109. op[op_name]["keep_keys"] = ["image", "valid_ratio", "word_positons"]
  110. else:
  111. op[op_name]["keep_keys"] = ["image"]
  112. transforms.append(op)
  113. global_config["infer_mode"] = True
  114. ops = create_operators(transforms, global_config)
  115. save_res_path = config["Global"].get(
  116. "save_res_path", "./output/rec/predicts_rec.txt"
  117. )
  118. if not os.path.exists(os.path.dirname(save_res_path)):
  119. os.makedirs(os.path.dirname(save_res_path))
  120. model.eval()
  121. infer_imgs = config["Global"]["infer_img"]
  122. infer_list = config["Global"].get("infer_list", None)
  123. with open(save_res_path, "w") as fout:
  124. for file in get_image_file_list(infer_imgs, infer_list=infer_list):
  125. logger.info("infer_img: {}".format(file))
  126. with open(file, "rb") as f:
  127. img = f.read()
  128. if config["Architecture"]["algorithm"] in [
  129. "UniMERNet",
  130. "PP-FormulaNet-S",
  131. "PP-FormulaNet-L",
  132. "PP-FormulaNet_plus-S",
  133. "PP-FormulaNet_plus-M",
  134. "PP-FormulaNet_plus-L",
  135. ]:
  136. data = {"image": img, "filename": file}
  137. else:
  138. data = {"image": img}
  139. batch = transform(data, ops)
  140. if config["Architecture"]["algorithm"] == "SRN":
  141. encoder_word_pos_list = np.expand_dims(batch[1], axis=0)
  142. gsrm_word_pos_list = np.expand_dims(batch[2], axis=0)
  143. gsrm_slf_attn_bias1_list = np.expand_dims(batch[3], axis=0)
  144. gsrm_slf_attn_bias2_list = np.expand_dims(batch[4], axis=0)
  145. others = [
  146. paddle.to_tensor(encoder_word_pos_list),
  147. paddle.to_tensor(gsrm_word_pos_list),
  148. paddle.to_tensor(gsrm_slf_attn_bias1_list),
  149. paddle.to_tensor(gsrm_slf_attn_bias2_list),
  150. ]
  151. if config["Architecture"]["algorithm"] == "SAR":
  152. valid_ratio = np.expand_dims(batch[-1], axis=0)
  153. img_metas = [paddle.to_tensor(valid_ratio)]
  154. if config["Architecture"]["algorithm"] == "RobustScanner":
  155. valid_ratio = np.expand_dims(batch[1], axis=0)
  156. word_positons = np.expand_dims(batch[2], axis=0)
  157. img_metas = [
  158. paddle.to_tensor(valid_ratio),
  159. paddle.to_tensor(word_positons),
  160. ]
  161. if config["Architecture"]["algorithm"] == "CAN":
  162. image_mask = paddle.ones(
  163. (np.expand_dims(batch[0], axis=0).shape), dtype="float32"
  164. )
  165. label = paddle.ones((1, 36), dtype="int64")
  166. images = np.expand_dims(batch[0], axis=0)
  167. images = paddle.to_tensor(images)
  168. if config["Architecture"]["algorithm"] == "SRN":
  169. preds = model(images, others)
  170. elif config["Architecture"]["algorithm"] == "SAR":
  171. preds = model(images, img_metas)
  172. elif config["Architecture"]["algorithm"] == "RobustScanner":
  173. preds = model(images, img_metas)
  174. elif config["Architecture"]["algorithm"] == "CAN":
  175. preds = model([images, image_mask, label])
  176. else:
  177. preds = model(images)
  178. post_result = post_process_class(preds)
  179. info = None
  180. if isinstance(post_result, dict):
  181. rec_info = dict()
  182. for key in post_result:
  183. if len(post_result[key][0]) >= 2:
  184. rec_info[key] = {
  185. "label": post_result[key][0][0],
  186. "score": float(post_result[key][0][1]),
  187. }
  188. info = json.dumps(rec_info, ensure_ascii=False)
  189. elif isinstance(post_result, list) and isinstance(post_result[0], int):
  190. # for RFLearning CNT branch
  191. info = str(post_result[0])
  192. elif config["Architecture"]["algorithm"] in [
  193. "LaTeXOCR",
  194. "UniMERNet",
  195. "PP-FormulaNet-S",
  196. "PP-FormulaNet-L",
  197. "PP-FormulaNet_plus-S",
  198. "PP-FormulaNet_plus-M",
  199. "PP-FormulaNet_plus-L",
  200. ]:
  201. info = str(post_result[0])
  202. else:
  203. if len(post_result[0]) >= 2:
  204. info = post_result[0][0] + "\t" + str(post_result[0][1])
  205. if info is not None:
  206. logger.info("\t result: {}".format(info))
  207. fout.write(file + "\t" + info + "\n")
  208. logger.info("success!")
  209. if __name__ == "__main__":
  210. config, device, logger, vdl_writer = program.preprocess()
  211. main()