predict_kie_token_ser.py 6.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191
  1. # Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. import os
  15. import sys
  16. __dir__ = os.path.dirname(os.path.abspath(__file__))
  17. sys.path.append(__dir__)
  18. sys.path.insert(0, os.path.abspath(os.path.join(__dir__, "../..")))
  19. os.environ["FLAGS_allocator_strategy"] = "auto_growth"
  20. import cv2
  21. import json
  22. import numpy as np
  23. import time
  24. import tools.infer.utility as utility
  25. from ppocr.data import create_operators, transform
  26. from ppocr.postprocess import build_post_process
  27. from ppocr.utils.logging import get_logger
  28. from ppocr.utils.visual import draw_ser_results
  29. from ppocr.utils.utility import get_image_file_list, check_and_read
  30. from ppstructure.utility import parse_args
  31. from paddleocr import PaddleOCR
  32. logger = get_logger()
  33. class SerPredictor(object):
  34. def __init__(self, args):
  35. self.args = args
  36. self.ocr_engine = PaddleOCR(
  37. use_angle_cls=args.use_angle_cls,
  38. det_model_dir=args.det_model_dir,
  39. rec_model_dir=args.rec_model_dir,
  40. show_log=False,
  41. use_gpu=args.use_gpu,
  42. )
  43. pre_process_list = [
  44. {
  45. "VQATokenLabelEncode": {
  46. "algorithm": args.kie_algorithm,
  47. "class_path": args.ser_dict_path,
  48. "contains_re": False,
  49. "ocr_engine": self.ocr_engine,
  50. "order_method": args.ocr_order_method,
  51. }
  52. },
  53. {"VQATokenPad": {"max_seq_len": 512, "return_attention_mask": True}},
  54. {"VQASerTokenChunk": {"max_seq_len": 512, "return_attention_mask": True}},
  55. {"Resize": {"size": [224, 224]}},
  56. {
  57. "NormalizeImage": {
  58. "std": [58.395, 57.12, 57.375],
  59. "mean": [123.675, 116.28, 103.53],
  60. "scale": "1",
  61. "order": "hwc",
  62. }
  63. },
  64. {"ToCHWImage": None},
  65. {
  66. "KeepKeys": {
  67. "keep_keys": [
  68. "input_ids",
  69. "bbox",
  70. "attention_mask",
  71. "token_type_ids",
  72. "image",
  73. "labels",
  74. "segment_offset_id",
  75. "ocr_info",
  76. "entities",
  77. ]
  78. }
  79. },
  80. ]
  81. postprocess_params = {
  82. "name": "VQASerTokenLayoutLMPostProcess",
  83. "class_path": args.ser_dict_path,
  84. }
  85. self.preprocess_op = create_operators(pre_process_list, {"infer_mode": True})
  86. self.postprocess_op = build_post_process(postprocess_params)
  87. (
  88. self.predictor,
  89. self.input_tensor,
  90. self.output_tensors,
  91. self.config,
  92. ) = utility.create_predictor(args, "ser", logger)
  93. def __call__(self, img):
  94. ori_im = img.copy()
  95. data = {"image": img}
  96. data = transform(data, self.preprocess_op)
  97. if data[0] is None:
  98. return None, 0
  99. starttime = time.time()
  100. for idx in range(len(data)):
  101. if isinstance(data[idx], np.ndarray):
  102. data[idx] = np.expand_dims(data[idx], axis=0)
  103. else:
  104. data[idx] = [data[idx]]
  105. if self.args.use_onnx:
  106. input_tensor = {
  107. name: data[idx] for idx, name in enumerate(self.input_tensor)
  108. }
  109. self.output_tensors = self.predictor.run(None, input_tensor)
  110. else:
  111. for idx in range(len(self.input_tensor)):
  112. self.input_tensor[idx].copy_from_cpu(data[idx])
  113. self.predictor.run()
  114. outputs = []
  115. for output_tensor in self.output_tensors:
  116. output = (
  117. output_tensor if self.args.use_onnx else output_tensor.copy_to_cpu()
  118. )
  119. outputs.append(output)
  120. preds = outputs[0]
  121. post_result = self.postprocess_op(
  122. preds, segment_offset_ids=data[6], ocr_infos=data[7]
  123. )
  124. elapse = time.time() - starttime
  125. return post_result, data, elapse
  126. def main(args):
  127. image_file_list = get_image_file_list(args.image_dir)
  128. ser_predictor = SerPredictor(args)
  129. count = 0
  130. total_time = 0
  131. os.makedirs(args.output, exist_ok=True)
  132. with open(
  133. os.path.join(args.output, "infer.txt"), mode="w", encoding="utf-8"
  134. ) as f_w:
  135. for image_file in image_file_list:
  136. img, flag, _ = check_and_read(image_file)
  137. if not flag:
  138. img = cv2.imread(image_file)
  139. img = img[:, :, ::-1]
  140. if img is None:
  141. logger.info("error in loading image:{}".format(image_file))
  142. continue
  143. ser_res, _, elapse = ser_predictor(img)
  144. ser_res = ser_res[0]
  145. res_str = "{}\t{}\n".format(
  146. image_file,
  147. json.dumps(
  148. {
  149. "ocr_info": ser_res,
  150. },
  151. ensure_ascii=False,
  152. ),
  153. )
  154. f_w.write(res_str)
  155. img_res = draw_ser_results(
  156. image_file,
  157. ser_res,
  158. font_path=args.vis_font_path,
  159. )
  160. img_save_path = os.path.join(args.output, os.path.basename(image_file))
  161. cv2.imwrite(img_save_path, img_res)
  162. logger.info("save vis result to {}".format(img_save_path))
  163. if count > 0:
  164. total_time += elapse
  165. count += 1
  166. logger.info("Predict time of {}: {}".format(image_file, elapse))
  167. if __name__ == "__main__":
  168. main(parse_args())