convert_ppocr_label.py 3.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103
  1. # Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. import numpy as np
  15. import json
  16. import os
  17. def poly_to_string(poly):
  18. if len(poly.shape) > 1:
  19. poly = np.array(poly).flatten()
  20. string = "\t".join(str(i) for i in poly)
  21. return string
  22. def convert_label(label_dir, mode="gt", save_dir="./save_results/"):
  23. if not os.path.exists(label_dir):
  24. raise ValueError(f"The file {label_dir} does not exist!")
  25. assert label_dir != save_dir, "hahahhaha"
  26. label_file = open(label_dir, "r")
  27. data = label_file.readlines()
  28. gt_dict = {}
  29. for line in data:
  30. try:
  31. tmp = line.split("\t")
  32. assert len(tmp) == 2, ""
  33. except:
  34. tmp = line.strip().split(" ")
  35. gt_lists = []
  36. if tmp[0].split("/")[0] is not None:
  37. img_path = tmp[0]
  38. anno = json.loads(tmp[1])
  39. gt_collect = []
  40. for dic in anno:
  41. # txt = dic['transcription'].replace(' ', '') # ignore blank
  42. txt = dic["transcription"]
  43. if "score" in dic and float(dic["score"]) < 0.5:
  44. continue
  45. if "\u3000" in txt:
  46. txt = txt.replace("\u3000", " ")
  47. # while ' ' in txt:
  48. # txt = txt.replace(' ', '')
  49. poly = np.array(dic["points"]).flatten()
  50. if txt == "###":
  51. txt_tag = 1 ## ignore 1
  52. else:
  53. txt_tag = 0
  54. if mode == "gt":
  55. gt_label = (
  56. poly_to_string(poly) + "\t" + str(txt_tag) + "\t" + txt + "\n"
  57. )
  58. else:
  59. gt_label = poly_to_string(poly) + "\t" + txt + "\n"
  60. gt_lists.append(gt_label)
  61. gt_dict[img_path] = gt_lists
  62. else:
  63. continue
  64. if not os.path.exists(save_dir):
  65. os.makedirs(save_dir)
  66. for img_name in gt_dict.keys():
  67. save_name = img_name.split("/")[-1]
  68. save_file = os.path.join(save_dir, save_name + ".txt")
  69. with open(save_file, "w") as f:
  70. f.writelines(gt_dict[img_name])
  71. print("The convert label saved in {}".format(save_dir))
  72. def parse_args():
  73. import argparse
  74. parser = argparse.ArgumentParser(description="args")
  75. parser.add_argument("--label_path", type=str, required=True)
  76. parser.add_argument("--save_folder", type=str, required=True)
  77. parser.add_argument("--mode", type=str, default=False)
  78. args = parser.parse_args()
  79. return args
  80. if __name__ == "__main__":
  81. args = parse_args()
  82. convert_label(args.label_path, args.mode, args.save_folder)