math_txt2pkl.py 2.5 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374
  1. # Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. import pickle
  15. from tqdm import tqdm
  16. import os
  17. import math
  18. from paddle.utils import try_import
  19. from collections import defaultdict
  20. import glob
  21. from os.path import join
  22. import argparse
  23. def txt2pickle(images, equations, save_dir):
  24. imagesize = try_import("imagesize")
  25. save_p = os.path.join(save_dir, "latexocr_{}.pkl".format(images.split("/")[-1]))
  26. min_dimensions = (32, 32)
  27. max_dimensions = (672, 192)
  28. max_length = 512
  29. data = defaultdict(lambda: [])
  30. if images is not None and equations is not None:
  31. images_list = [
  32. path.replace("\\", "/") for path in glob.glob(join(images, "*.png"))
  33. ]
  34. indices = [int(os.path.basename(img).split(".")[0]) for img in images_list]
  35. eqs = open(equations, "r").read().split("\n")
  36. for i, im in tqdm(enumerate(images_list), total=len(images_list)):
  37. width, height = imagesize.get(im)
  38. if (
  39. min_dimensions[0] <= width <= max_dimensions[0]
  40. and min_dimensions[1] <= height <= max_dimensions[1]
  41. ):
  42. divide_h = math.ceil(height / 16) * 16
  43. divide_w = math.ceil(width / 16) * 16
  44. im = os.path.basename(im)
  45. data[(divide_w, divide_h)].append((eqs[indices[i]], im))
  46. data = dict(data)
  47. with open(save_p, "wb") as file:
  48. pickle.dump(data, file)
  49. if __name__ == "__main__":
  50. parser = argparse.ArgumentParser()
  51. parser.add_argument(
  52. "--image_dir",
  53. type=str,
  54. default=".",
  55. help="Input_label or input path to be converted",
  56. )
  57. parser.add_argument(
  58. "--mathtxt_path",
  59. type=str,
  60. default=".",
  61. help="Input_label or input path to be converted",
  62. )
  63. parser.add_argument(
  64. "--output_dir", type=str, default="out_label.txt", help="Output file name"
  65. )
  66. args = parser.parse_args()
  67. txt2pickle(args.image_dir, args.mathtxt_path, args.output_dir)