latex_ocr_aug.py 6.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183
  1. # copyright (c) 2024 PaddlePaddle Authors. All Rights Reserve.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. """
  15. This code is adapted from:
  16. https://github.com/lukas-blecher/LaTeX-OCR/blob/main/pix2tex/dataset/transforms.py
  17. """
  18. from __future__ import absolute_import
  19. from __future__ import division
  20. from __future__ import print_function
  21. from __future__ import unicode_literals
  22. import os
  23. os.environ["NO_ALBUMENTATIONS_UPDATE"] = "1"
  24. import math
  25. import cv2
  26. import numpy as np
  27. import albumentations as A
  28. from PIL import Image
  29. class LatexTrainTransform:
  30. def __init__(self, bitmap_prob=0.04, **kwargs):
  31. # your init code
  32. self.bitmap_prob = bitmap_prob
  33. self.train_transform = A.Compose(
  34. [
  35. A.Compose(
  36. [
  37. A.ShiftScaleRotate(
  38. shift_limit=0,
  39. scale_limit=(-0.15, 0),
  40. rotate_limit=1,
  41. border_mode=0,
  42. interpolation=3,
  43. value=[255, 255, 255],
  44. p=1,
  45. ),
  46. A.GridDistortion(
  47. distort_limit=0.1,
  48. border_mode=0,
  49. interpolation=3,
  50. value=[255, 255, 255],
  51. p=0.5,
  52. ),
  53. ],
  54. p=0.15,
  55. ),
  56. A.RGBShift(r_shift_limit=15, g_shift_limit=15, b_shift_limit=15, p=0.3),
  57. A.GaussNoise(10, p=0.2),
  58. A.RandomBrightnessContrast(0.05, (-0.2, 0), True, p=0.2),
  59. A.ImageCompression(95, p=0.3),
  60. A.ToGray(always_apply=True),
  61. ]
  62. )
  63. def __call__(self, data):
  64. img = data["image"]
  65. if np.random.random() < self.bitmap_prob:
  66. img[img != 255] = 0
  67. img = self.train_transform(image=img)["image"]
  68. data["image"] = img
  69. return data
  70. class LatexTestTransform:
  71. def __init__(self, **kwargs):
  72. # your init code
  73. self.test_transform = A.Compose(
  74. [
  75. A.ToGray(always_apply=True),
  76. ]
  77. )
  78. def __call__(self, data):
  79. img = data["image"]
  80. img = self.test_transform(image=img)["image"]
  81. data["image"] = img
  82. return data
  83. class MinMaxResize:
  84. def __init__(self, min_dimensions=[32, 32], max_dimensions=[672, 192], **kwargs):
  85. # your init code
  86. self.min_dimensions = min_dimensions
  87. self.max_dimensions = max_dimensions
  88. # pass
  89. def pad_(self, img, divable=32):
  90. threshold = 128
  91. data = np.array(img.convert("LA"))
  92. if data[..., -1].var() == 0:
  93. data = (data[..., 0]).astype(np.uint8)
  94. else:
  95. data = (255 - data[..., -1]).astype(np.uint8)
  96. data = (data - data.min()) / (data.max() - data.min()) * 255
  97. if data.mean() > threshold:
  98. # To invert the text to white
  99. gray = 255 * (data < threshold).astype(np.uint8)
  100. else:
  101. gray = 255 * (data > threshold).astype(np.uint8)
  102. data = 255 - data
  103. coords = cv2.findNonZero(gray) # Find all non-zero points (text)
  104. a, b, w, h = cv2.boundingRect(coords) # Find minimum spanning bounding box
  105. rect = data[b : b + h, a : a + w]
  106. im = Image.fromarray(rect).convert("L")
  107. dims = []
  108. for x in [w, h]:
  109. div, mod = divmod(x, divable)
  110. dims.append(divable * (div + (1 if mod > 0 else 0)))
  111. padded = Image.new("L", dims, 255)
  112. padded.paste(im, (0, 0, im.size[0], im.size[1]))
  113. return padded
  114. def minmax_size_(self, img, max_dimensions, min_dimensions):
  115. if max_dimensions is not None:
  116. ratios = [a / b for a, b in zip(img.size, max_dimensions)]
  117. if any([r > 1 for r in ratios]):
  118. size = np.array(img.size) // max(ratios)
  119. img = img.resize(tuple(size.astype(int)), Image.BILINEAR)
  120. if min_dimensions is not None:
  121. # hypothesis: there is a dim in img smaller than min_dimensions, and return a proper dim >= min_dimensions
  122. padded_size = [
  123. max(img_dim, min_dim)
  124. for img_dim, min_dim in zip(img.size, min_dimensions)
  125. ]
  126. if padded_size != list(img.size): # assert hypothesis
  127. padded_im = Image.new("L", padded_size, 255)
  128. padded_im.paste(img, img.getbbox())
  129. img = padded_im
  130. return img
  131. def __call__(self, data):
  132. img = data["image"]
  133. h, w = img.shape[:2]
  134. if (
  135. self.min_dimensions[0] <= w <= self.max_dimensions[0]
  136. and self.min_dimensions[1] <= h <= self.max_dimensions[1]
  137. ):
  138. return data
  139. else:
  140. im = Image.fromarray(np.uint8(img))
  141. im = self.minmax_size_(
  142. self.pad_(im), self.max_dimensions, self.min_dimensions
  143. )
  144. im = np.array(im)
  145. im = np.dstack((im, im, im))
  146. data["image"] = im
  147. return data
  148. class LatexImageFormat:
  149. def __init__(self, **kwargs):
  150. # your init code
  151. pass
  152. def __call__(self, data):
  153. img = data["image"]
  154. im_h, im_w = img.shape[:2]
  155. divide_h = math.ceil(im_h / 16) * 16
  156. divide_w = math.ceil(im_w / 16) * 16
  157. img = img[:, :, 0]
  158. img = np.pad(
  159. img, ((0, divide_h - im_h), (0, divide_w - im_w)), constant_values=(1, 1)
  160. )
  161. img_expanded = img[:, :, np.newaxis].transpose(2, 0, 1)
  162. data["image"] = img_expanded
  163. return data