unimernet_aug.py 27 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813
# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. from __future__ import absolute_import
  15. from __future__ import division
  16. from __future__ import print_function
  17. from __future__ import unicode_literals
  18. import os
  19. os.environ["NO_ALBUMENTATIONS_UPDATE"] = "1"
  20. import cv2
  21. import math
  22. import numpy as np
  23. from io import BytesIO
  24. import albumentations as A
  25. from PIL import Image, ImageOps, ImageDraw
  26. from scipy.ndimage import zoom as scizoom
  27. class Erosion(A.ImageOnlyTransform):
  28. def __init__(self, scale, always_apply=False, p=0.5):
  29. super().__init__(always_apply=always_apply, p=p)
  30. if type(scale) is tuple or type(scale) is list:
  31. assert len(scale) == 2
  32. self.scale = scale
  33. else:
  34. self.scale = (scale, scale)
  35. def apply(self, img, **params):
  36. kernel = cv2.getStructuringElement(
  37. cv2.MORPH_ELLIPSE, tuple(np.random.randint(self.scale[0], self.scale[1], 2))
  38. )
  39. img = cv2.erode(img, kernel, iterations=1)
  40. return img
  41. class Dilation(A.ImageOnlyTransform):
  42. def __init__(self, scale, always_apply=False, p=0.5):
  43. super().__init__(always_apply=always_apply, p=p)
  44. if type(scale) is tuple or type(scale) is list:
  45. assert len(scale) == 2
  46. self.scale = scale
  47. else:
  48. self.scale = (scale, scale)
  49. def apply(self, img, **params):
  50. kernel = cv2.getStructuringElement(
  51. cv2.MORPH_ELLIPSE, tuple(np.random.randint(self.scale[0], self.scale[1], 2))
  52. )
  53. img = cv2.dilate(img, kernel, iterations=1)
  54. return img
  55. class Bitmap(A.ImageOnlyTransform):
  56. def __init__(self, value=0, lower=200, always_apply=False, p=0.5):
  57. super().__init__(always_apply=always_apply, p=p)
  58. self.lower = lower
  59. self.value = value
  60. def apply(self, img, **params):
  61. img = img.copy()
  62. img[img < self.lower] = self.value
  63. return img
  64. def clipped_zoom(img, zoom_factor):
  65. h = img.shape[1]
  66. ch = int(np.ceil(h / float(zoom_factor)))
  67. top = (h - ch) // 2
  68. img = scizoom(
  69. img[top : top + ch, top : top + ch], (zoom_factor, zoom_factor, 1), order=1
  70. )
  71. trim_top = (img.shape[0] - h) // 2
  72. return img[trim_top : trim_top + h, trim_top : trim_top + h]
  73. def disk(radius, alias_blur=0.1, dtype=np.float32):
  74. if radius <= 8:
  75. coords = np.arange(-8, 8 + 1)
  76. ksize = (3, 3)
  77. else:
  78. coords = np.arange(-radius, radius + 1)
  79. ksize = (5, 5)
  80. x, y = np.meshgrid(coords, coords)
  81. aliased_disk = np.asarray((x**2 + y**2) <= radius**2, dtype=dtype)
  82. aliased_disk /= np.sum(aliased_disk)
  83. return cv2.GaussianBlur(aliased_disk, ksize=ksize, sigmaX=alias_blur)
  84. def plasma_fractal(mapsize=256, wibbledecay=3, rng=None):
  85. """
  86. Generate a heightmap using diamond-square algorithm.
  87. Return square 2d array, side length 'mapsize', of floats in range 0-255.
  88. 'mapsize' must be a power of two.
  89. """
  90. assert mapsize & (mapsize - 1) == 0
  91. maparray = np.empty((mapsize, mapsize), dtype=np.float64)
  92. maparray[0, 0] = 0
  93. stepsize = mapsize
  94. wibble = 100
  95. if rng is None:
  96. rng = np.random.default_rng()
  97. def wibbledmean(array):
  98. return array / 4 + wibble * rng.uniform(-wibble, wibble, array.shape)
  99. def fillsquares():
  100. """For each square of points stepsize apart,
  101. calculate middle value as mean of points + wibble"""
  102. cornerref = maparray[0:mapsize:stepsize, 0:mapsize:stepsize]
  103. squareaccum = cornerref + np.roll(cornerref, shift=-1, axis=0)
  104. squareaccum += np.roll(squareaccum, shift=-1, axis=1)
  105. maparray[
  106. stepsize // 2 : mapsize : stepsize, stepsize // 2 : mapsize : stepsize
  107. ] = wibbledmean(squareaccum)
  108. def filldiamonds():
  109. """For each diamond of points stepsize apart,
  110. calculate middle value as mean of points + wibble"""
  111. drgrid = maparray[
  112. stepsize // 2 : mapsize : stepsize, stepsize // 2 : mapsize : stepsize
  113. ]
  114. ulgrid = maparray[0:mapsize:stepsize, 0:mapsize:stepsize]
  115. ldrsum = drgrid + np.roll(drgrid, 1, axis=0)
  116. lulsum = ulgrid + np.roll(ulgrid, -1, axis=1)
  117. ltsum = ldrsum + lulsum
  118. maparray[0:mapsize:stepsize, stepsize // 2 : mapsize : stepsize] = wibbledmean(
  119. ltsum
  120. )
  121. tdrsum = drgrid + np.roll(drgrid, 1, axis=1)
  122. tulsum = ulgrid + np.roll(ulgrid, -1, axis=0)
  123. ttsum = tdrsum + tulsum
  124. maparray[stepsize // 2 : mapsize : stepsize, 0:mapsize:stepsize] = wibbledmean(
  125. ttsum
  126. )
  127. while stepsize >= 2:
  128. fillsquares()
  129. filldiamonds()
  130. stepsize //= 2
  131. wibble /= wibbledecay
  132. maparray -= maparray.min()
  133. return maparray / maparray.max()
  134. class Fog(A.ImageOnlyTransform):
  135. def __init__(self, mag=-1, always_apply=False, p=1.0):
  136. super().__init__(always_apply=always_apply, p=p)
  137. self.rng = np.random.default_rng()
  138. self.mag = mag
  139. def apply(self, img, **params):
  140. img = Image.fromarray(img.astype(np.uint8))
  141. w, h = img.size
  142. c = [(1.5, 2), (2.0, 2), (2.5, 1.7)]
  143. if self.mag < 0 or self.mag >= len(c):
  144. index = self.rng.integers(0, len(c))
  145. else:
  146. index = self.mag
  147. c = c[index]
  148. n_channels = len(img.getbands())
  149. isgray = n_channels == 1
  150. img = np.asarray(img) / 255.0
  151. max_val = img.max()
  152. max_size = 2 ** math.ceil(math.log2(max(w, h)) + 1)
  153. fog = (
  154. c[0]
  155. * plasma_fractal(mapsize=max_size, wibbledecay=c[1], rng=self.rng)[:h, :w][
  156. ..., np.newaxis
  157. ]
  158. )
  159. if isgray:
  160. fog = np.squeeze(fog)
  161. else:
  162. fog = np.repeat(fog, 3, axis=2)
  163. img += fog
  164. img = np.clip(img * max_val / (max_val + c[0]), 0, 1) * 255
  165. return img.astype(np.uint8)
  166. class Frost(A.ImageOnlyTransform):
  167. def __init__(self, mag=-1, always_apply=False, p=1.0):
  168. super().__init__(always_apply=always_apply, p=p)
  169. self.rng = np.random.default_rng()
  170. self.mag = mag
  171. def apply(self, img, **params):
  172. img = Image.fromarray(img.astype(np.uint8))
  173. w, h = img.size
  174. c = [(0.78, 0.22), (0.64, 0.36), (0.5, 0.5)]
  175. if self.mag < 0 or self.mag >= len(c):
  176. index = self.rng.integers(0, len(c))
  177. else:
  178. index = self.mag
  179. c = c[index]
  180. file_dir = os.path.dirname(__file__)
  181. filename = [
  182. os.path.join(file_dir, "frost_img", "frost1.jpg"),
  183. os.path.join(file_dir, "frost_img", "frost2.png"),
  184. os.path.join(file_dir, "frost_img", "frost3.png"),
  185. os.path.join(file_dir, "frost_img", "frost4.jpg"),
  186. os.path.join(file_dir, "frost_img", "frost5.jpg"),
  187. os.path.join(file_dir, "frost_img", "frost6.jpg"),
  188. ]
  189. index = self.rng.integers(0, len(filename))
  190. filename = filename[index]
  191. frost = Image.open(filename).convert("RGB")
  192. f_w, f_h = frost.size
  193. if w / h > f_w / f_h:
  194. f_h = round(f_h * w / f_w)
  195. f_w = w
  196. else:
  197. f_w = round(f_w * h / f_h)
  198. f_h = h
  199. frost = np.asarray(frost.resize((f_w, f_h)))
  200. # randomly crop
  201. y_start, x_start = self.rng.integers(0, f_h - h + 1), self.rng.integers(
  202. 0, f_w - w + 1
  203. )
  204. frost = frost[y_start : y_start + h, x_start : x_start + w]
  205. n_channels = len(img.getbands())
  206. isgray = n_channels == 1
  207. img = np.asarray(img)
  208. if isgray:
  209. img = np.expand_dims(img, axis=2)
  210. img = np.repeat(img, 3, axis=2)
  211. img = np.clip(np.round(c[0] * img + c[1] * frost), 0, 255)
  212. img = img.astype(np.uint8)
  213. if isgray:
  214. img = np.squeeze(img)
  215. return img
  216. class Snow(A.ImageOnlyTransform):
  217. def __init__(self, mag=-1, always_apply=False, p=1.0):
  218. super().__init__(always_apply=always_apply, p=p)
  219. self.rng = np.random.default_rng()
  220. self.mag = mag
  221. def apply(self, img, **params):
  222. from wand.image import Image as WandImage
  223. img = Image.fromarray(img.astype(np.uint8))
  224. w, h = img.size
  225. c = [
  226. (0.1, 0.3, 3, 0.5, 10, 4, 0.8),
  227. (0.2, 0.3, 2, 0.5, 12, 4, 0.7),
  228. (0.55, 0.3, 4, 0.9, 12, 8, 0.7),
  229. ]
  230. if self.mag < 0 or self.mag >= len(c):
  231. index = self.rng.integers(0, len(c))
  232. else:
  233. index = self.mag
  234. c = c[index]
  235. n_channels = len(img.getbands())
  236. isgray = n_channels == 1
  237. img = np.asarray(img, dtype=np.float32) / 255.0
  238. if isgray:
  239. img = np.expand_dims(img, axis=2)
  240. img = np.repeat(img, 3, axis=2)
  241. snow_layer = self.rng.normal(size=img.shape[:2], loc=c[0], scale=c[1])
  242. snow_layer[snow_layer < c[3]] = 0
  243. snow_layer = Image.fromarray(
  244. (np.clip(snow_layer.squeeze(), 0, 1) * 255).astype(np.uint8), mode="L"
  245. )
  246. output = BytesIO()
  247. snow_layer.save(output, format="PNG")
  248. snow_layer = WandImage(blob=output.getvalue())
  249. snow_layer.motion_blur(
  250. radius=c[4], sigma=c[5], angle=self.rng.uniform(-135, -45)
  251. )
  252. snow_layer = (
  253. cv2.imdecode(
  254. np.frombuffer(snow_layer.make_blob(), np.uint8), cv2.IMREAD_UNCHANGED
  255. )
  256. / 255.0
  257. )
  258. snow_layer = snow_layer[..., np.newaxis]
  259. img = c[6] * img
  260. gray_img = (1 - c[6]) * np.maximum(
  261. img, cv2.cvtColor(img, cv2.COLOR_RGB2GRAY).reshape(h, w, 1) * 1.5 + 0.5
  262. )
  263. img += gray_img
  264. img = np.clip(img + snow_layer + np.rot90(snow_layer, k=2), 0, 1) * 255
  265. img = img.astype(np.uint8)
  266. if isgray:
  267. img = np.squeeze(img)
  268. return img
  269. class Rain(A.ImageOnlyTransform):
  270. def __init__(self, mag=-1, always_apply=False, p=1.0):
  271. super().__init__(always_apply=always_apply, p=p)
  272. self.rng = np.random.default_rng()
  273. self.mag = mag
  274. def apply(self, img, **params):
  275. img = Image.fromarray(img.astype(np.uint8))
  276. img = img.copy()
  277. w, h = img.size
  278. n_channels = len(img.getbands())
  279. isgray = n_channels == 1
  280. line_width = self.rng.integers(1, 2)
  281. c = [50, 70, 90]
  282. if self.mag < 0 or self.mag >= len(c):
  283. index = 0
  284. else:
  285. index = self.mag
  286. c = c[index]
  287. n_rains = self.rng.integers(c, c + 20)
  288. slant = self.rng.integers(-60, 60)
  289. fillcolor = 200 if isgray else (200, 200, 200)
  290. draw = ImageDraw.Draw(img)
  291. max_length = min(w, h, 10)
  292. for i in range(1, n_rains):
  293. length = self.rng.integers(5, max_length)
  294. x1 = self.rng.integers(0, w - length)
  295. y1 = self.rng.integers(0, h - length)
  296. x2 = x1 + length * math.sin(slant * math.pi / 180.0)
  297. y2 = y1 + length * math.cos(slant * math.pi / 180.0)
  298. x2 = int(x2)
  299. y2 = int(y2)
  300. draw.line([(x1, y1), (x2, y2)], width=line_width, fill=fillcolor)
  301. img = np.asarray(img).astype(np.uint8)
  302. return img
  303. class Shadow(A.ImageOnlyTransform):
  304. def __init__(self, mag=-1, always_apply=False, p=1.0):
  305. super().__init__(always_apply=always_apply, p=p)
  306. self.rng = np.random.default_rng()
  307. self.mag = mag
  308. def apply(self, img, **params):
  309. img = Image.fromarray(img.astype(np.uint8))
  310. w, h = img.size
  311. n_channels = len(img.getbands())
  312. isgray = n_channels == 1
  313. c = [64, 96, 128]
  314. if self.mag < 0 or self.mag >= len(c):
  315. index = 0
  316. else:
  317. index = self.mag
  318. c = c[index]
  319. img = img.convert("RGBA")
  320. overlay = Image.new("RGBA", img.size, (255, 255, 255, 0))
  321. draw = ImageDraw.Draw(overlay)
  322. transparency = self.rng.integers(c, c + 32)
  323. x1 = self.rng.integers(0, w // 2)
  324. y1 = 0
  325. x2 = self.rng.integers(w // 2, w)
  326. y2 = 0
  327. x3 = self.rng.integers(w // 2, w)
  328. y3 = h - 1
  329. x4 = self.rng.integers(0, w // 2)
  330. y4 = h - 1
  331. draw.polygon(
  332. [(x1, y1), (x2, y2), (x3, y3), (x4, y4)], fill=(0, 0, 0, transparency)
  333. )
  334. img = Image.alpha_composite(img, overlay)
  335. img = img.convert("RGB")
  336. if isgray:
  337. img = ImageOps.grayscale(img)
  338. img = np.asarray(img).astype(np.uint8)
  339. return img
  340. class UniMERNetTrainTransform:
  341. def __init__(self, bitmap_prob=0.04, **kwargs):
  342. self.bitmap_prob = bitmap_prob
  343. if tuple(map(int, A.__version__.split("."))) >= (2, 0, 0):
  344. new_val = (0, (10 / 255) ** 0.5)
  345. GaussNoise = A.GaussNoise(new_val, p=0.2)
  346. ImageCompression = A.ImageCompression(quality_range=(95, 100), p=0.3)
  347. else:
  348. GaussNoise = A.GaussNoise(10, p=0.2)
  349. ImageCompression = A.ImageCompression(95, p=0.3)
  350. self.train_transform = A.Compose(
  351. [
  352. A.Compose(
  353. [
  354. Bitmap(p=0.05),
  355. A.OneOf([Fog(), Frost(), Snow(), Rain(), Shadow()], p=0.2),
  356. A.OneOf([Erosion((2, 3)), Dilation((2, 3))], p=0.2),
  357. A.ShiftScaleRotate(
  358. shift_limit=0,
  359. scale_limit=(-0.15, 0),
  360. rotate_limit=1,
  361. border_mode=0,
  362. interpolation=3,
  363. value=[255, 255, 255],
  364. p=1,
  365. ),
  366. A.GridDistortion(
  367. distort_limit=0.1,
  368. border_mode=0,
  369. interpolation=3,
  370. value=[255, 255, 255],
  371. p=0.5,
  372. ),
  373. ],
  374. p=0.15,
  375. ),
  376. A.RGBShift(r_shift_limit=15, g_shift_limit=15, b_shift_limit=15, p=0.3),
  377. GaussNoise,
  378. A.RandomBrightnessContrast(0.05, (-0.2, 0), True, p=0.2),
  379. ImageCompression,
  380. A.ToGray(always_apply=True),
  381. A.Normalize((0.7931, 0.7931, 0.7931), (0.1738, 0.1738, 0.1738)),
  382. ]
  383. )
  384. def __call__(self, data):
  385. img = data["image"]
  386. if np.random.random() < self.bitmap_prob:
  387. img[img != 255] = 0
  388. img = self.train_transform(image=img)["image"]
  389. data["image"] = img
  390. return data
  391. class UniMERNetTestTransform:
  392. def __init__(self, **kwargs):
  393. self.test_transform = A.Compose(
  394. [
  395. A.ToGray(always_apply=True),
  396. A.Normalize((0.7931, 0.7931, 0.7931), (0.1738, 0.1738, 0.1738)),
  397. ]
  398. )
  399. def __call__(self, data):
  400. img = data["image"]
  401. img = self.test_transform(image=img)["image"]
  402. data["image"] = img
  403. return data
  404. class GoTImgDecode:
  405. def __init__(self, input_size, random_padding=False, **kwargs):
  406. self.input_size = input_size
  407. self.random_padding = random_padding
  408. def crop_margin(self, img):
  409. data = np.array(img.convert("L"))
  410. data = data.astype(np.uint8)
  411. max_val = data.max()
  412. min_val = data.min()
  413. if max_val == min_val:
  414. return img
  415. data = (data - min_val) / (max_val - min_val) * 255
  416. gray = 255 * (data < 200).astype(np.uint8)
  417. coords = cv2.findNonZero(gray) # Find all non-zero points (text)
  418. a, b, w, h = cv2.boundingRect(coords) # Find minimum spanning bounding box
  419. return img.crop((a, b, w + a, h + b))
  420. def get_dimensions(self, img):
  421. if hasattr(img, "getbands"):
  422. channels = len(img.getbands())
  423. else:
  424. channels = img.channels
  425. width, height = img.size
  426. return [channels, height, width]
  427. def _compute_resized_output_size(self, image_size, size, max_size=None):
  428. if len(size) == 1: # specified size only for the smallest edge
  429. h, w = image_size
  430. short, long = (w, h) if w <= h else (h, w)
  431. requested_new_short = size if isinstance(size, int) else size[0]
  432. new_short, new_long = requested_new_short, int(
  433. requested_new_short * long / short
  434. )
  435. if max_size is not None:
  436. if max_size <= requested_new_short:
  437. raise ValueError(
  438. f"max_size = {max_size} must be strictly greater than the requested "
  439. f"size for the smaller edge size = {size}"
  440. )
  441. if new_long > max_size:
  442. new_short, new_long = int(max_size * new_short / new_long), max_size
  443. new_w, new_h = (new_short, new_long) if w <= h else (new_long, new_short)
  444. else: # specified both h and w
  445. new_w, new_h = size[1], size[0]
  446. return [new_h, new_w]
  447. def resize(self, img, size):
  448. _, image_height, image_width = self.get_dimensions(img)
  449. if isinstance(size, int):
  450. size = [size]
  451. max_size = None
  452. output_size = self._compute_resized_output_size(
  453. (image_height, image_width), size, max_size
  454. )
  455. img = img.resize(tuple(output_size[::-1]), resample=2)
  456. return img
  457. def __call__(self, data):
  458. filename = data["filename"]
  459. img = Image.open(filename)
  460. try:
  461. img = self.crop_margin(img.convert("RGB"))
  462. except OSError:
  463. return
  464. if img.height == 0 or img.width == 0:
  465. return
  466. img = self.resize(img, min(self.input_size))
  467. img.thumbnail((self.input_size[1], self.input_size[0]))
  468. delta_width = self.input_size[1] - img.width
  469. delta_height = self.input_size[0] - img.height
  470. if self.random_padding:
  471. pad_width = np.random.randint(low=0, high=delta_width + 1)
  472. pad_height = np.random.randint(low=0, high=delta_height + 1)
  473. else:
  474. pad_width = delta_width // 2
  475. pad_height = delta_height // 2
  476. padding = (
  477. pad_width,
  478. pad_height,
  479. delta_width - pad_width,
  480. delta_height - pad_height,
  481. )
  482. data["image"] = np.array(ImageOps.expand(img, padding))
  483. return data
  484. class UniMERNetImgDecode:
  485. def __init__(
  486. self,
  487. input_size,
  488. random_padding=False,
  489. random_resize=False,
  490. random_crop=False,
  491. **kwargs,
  492. ):
  493. self.input_size = input_size
  494. self.is_random_padding = random_padding
  495. self.is_random_resize = random_resize
  496. self.is_random_crop = random_crop
  497. def crop_margin(self, img):
  498. data = np.array(img.convert("L"))
  499. data = data.astype(np.uint8)
  500. max_val = data.max()
  501. min_val = data.min()
  502. if max_val == min_val:
  503. return img
  504. data = (data - min_val) / (max_val - min_val) * 255
  505. gray = 255 * (data < 200).astype(np.uint8)
  506. coords = cv2.findNonZero(gray) # Find all non-zero points (text)
  507. a, b, w, h = cv2.boundingRect(coords) # Find minimum spanning bounding box
  508. return img.crop((a, b, w + a, h + b))
  509. def get_dimensions(self, img):
  510. if hasattr(img, "getbands"):
  511. channels = len(img.getbands())
  512. else:
  513. channels = img.channels
  514. width, height = img.size
  515. return [channels, height, width]
  516. def _compute_resized_output_size(self, image_size, size, max_size=None):
  517. if len(size) == 1: # specified size only for the smallest edge
  518. h, w = image_size
  519. short, long = (w, h) if w <= h else (h, w)
  520. requested_new_short = size if isinstance(size, int) else size[0]
  521. new_short, new_long = requested_new_short, int(
  522. requested_new_short * long / short
  523. )
  524. if max_size is not None:
  525. if max_size <= requested_new_short:
  526. raise ValueError(
  527. f"max_size = {max_size} must be strictly greater than the requested "
  528. f"size for the smaller edge size = {size}"
  529. )
  530. if new_long > max_size:
  531. new_short, new_long = int(max_size * new_short / new_long), max_size
  532. new_w, new_h = (new_short, new_long) if w <= h else (new_long, new_short)
  533. else: # specified both h and w
  534. new_w, new_h = size[1], size[0]
  535. return [new_h, new_w]
  536. def resize(self, img, size):
  537. _, image_height, image_width = self.get_dimensions(img)
  538. if isinstance(size, int):
  539. size = [size]
  540. max_size = None
  541. output_size = self._compute_resized_output_size(
  542. (image_height, image_width), size, max_size
  543. )
  544. img = img.resize(tuple(output_size[::-1]), resample=2)
  545. return img
  546. def random_resize(self, img):
  547. scale = np.random.uniform(0.5, 1)
  548. img = img.resize([int(scale * s) for s in img.size])
  549. return img
  550. def random_crop(self, img, crop_ratio):
  551. width, height = img.width, img.height
  552. max_crop_pixel = min(width, height) * crop_ratio
  553. crop_left = np.random.uniform(0, max_crop_pixel)
  554. crop_right = np.random.uniform(0, max_crop_pixel)
  555. crop_top = np.random.uniform(0, max_crop_pixel)
  556. crop_bottom = np.random.uniform(0, max_crop_pixel)
  557. # 计算裁剪后的边界
  558. left = crop_left
  559. top = crop_top
  560. right = width - crop_right
  561. bottom = height - crop_bottom
  562. # 裁剪图像
  563. img = img.crop((left, top, right, bottom))
  564. return img
  565. def __call__(self, data):
  566. filename = data["filename"]
  567. img = Image.open(filename)
  568. try:
  569. if self.is_random_resize:
  570. img = self.random_resize(img)
  571. img = self.crop_margin(img.convert("RGB"))
  572. if "label" in data and self.is_random_crop:
  573. label = data["label"]
  574. equation_length = len(label)
  575. if equation_length < 256:
  576. img = self.random_crop(img, crop_ratio=0.1)
  577. elif 256 < equation_length <= 512:
  578. img = self.random_crop(img, crop_ratio=0.05)
  579. else:
  580. img = self.random_crop(img, crop_ratio=0.03)
  581. except OSError:
  582. return
  583. if img.height == 0 or img.width == 0:
  584. return
  585. img = self.resize(img, min(self.input_size))
  586. img.thumbnail((self.input_size[1], self.input_size[0]))
  587. delta_width = self.input_size[1] - img.width
  588. delta_height = self.input_size[0] - img.height
  589. if self.is_random_padding:
  590. pad_width = np.random.randint(low=0, high=delta_width + 1)
  591. pad_height = np.random.randint(low=0, high=delta_height + 1)
  592. else:
  593. pad_width = delta_width // 2
  594. pad_height = delta_height // 2
  595. padding = (
  596. pad_width,
  597. pad_height,
  598. delta_width - pad_width,
  599. delta_height - pad_height,
  600. )
  601. data["image"] = np.array(ImageOps.expand(img, padding))
  602. return data
  603. class UniMERNetResize:
  604. def __init__(self, input_size, random_padding=False, **kwargs):
  605. self.input_size = input_size
  606. self.random_padding = random_padding
  607. def crop_margin(self, img):
  608. data = np.array(img.convert("L"))
  609. data = data.astype(np.uint8)
  610. max_val = data.max()
  611. min_val = data.min()
  612. if max_val == min_val:
  613. return img
  614. data = (data - min_val) / (max_val - min_val) * 255
  615. gray = 255 * (data < 200).astype(np.uint8)
  616. coords = cv2.findNonZero(gray) # Find all non-zero points (text)
  617. a, b, w, h = cv2.boundingRect(coords) # Find minimum spanning bounding box
  618. return img.crop((a, b, w + a, h + b))
  619. def get_dimensions(self, img):
  620. if hasattr(img, "getbands"):
  621. channels = len(img.getbands())
  622. else:
  623. channels = img.channels
  624. width, height = img.size
  625. return [channels, height, width]
  626. def _compute_resized_output_size(self, image_size, size, max_size=None):
  627. if len(size) == 1: # specified size only for the smallest edge
  628. h, w = image_size
  629. short, long = (w, h) if w <= h else (h, w)
  630. requested_new_short = size if isinstance(size, int) else size[0]
  631. new_short, new_long = requested_new_short, int(
  632. requested_new_short * long / short
  633. )
  634. if max_size is not None:
  635. if max_size <= requested_new_short:
  636. raise ValueError(
  637. f"max_size = {max_size} must be strictly greater than the requested "
  638. f"size for the smaller edge size = {size}"
  639. )
  640. if new_long > max_size:
  641. new_short, new_long = int(max_size * new_short / new_long), max_size
  642. new_w, new_h = (new_short, new_long) if w <= h else (new_long, new_short)
  643. else: # specified both h and w
  644. new_w, new_h = size[1], size[0]
  645. return [new_h, new_w]
  646. def resize(self, img, size):
  647. _, image_height, image_width = self.get_dimensions(img)
  648. if isinstance(size, int):
  649. size = [size]
  650. max_size = None
  651. output_size = self._compute_resized_output_size(
  652. (image_height, image_width), size, max_size
  653. )
  654. img.resize(tuple(output_size[::-1]), resample=2)
  655. return img
  656. def __call__(self, data):
  657. img = data["image"]
  658. img = Image.fromarray(img)
  659. try:
  660. img = self.crop_margin(img)
  661. except OSError:
  662. return
  663. if img.height == 0 or img.width == 0:
  664. return
  665. img = self.resize(img, min(self.input_size))
  666. img.thumbnail((self.input_size[1], self.input_size[0]))
  667. delta_width = self.input_size[1] - img.width
  668. delta_height = self.input_size[0] - img.height
  669. if self.random_padding:
  670. pad_width = np.random.randint(low=0, high=delta_width + 1)
  671. pad_height = np.random.randint(low=0, high=delta_height + 1)
  672. else:
  673. pad_width = delta_width // 2
  674. pad_height = delta_height // 2
  675. padding = (
  676. pad_width,
  677. pad_height,
  678. delta_width - pad_width,
  679. delta_height - pad_height,
  680. )
  681. data["image"] = np.array(ImageOps.expand(img, padding))
  682. return data
  683. class UniMERNetImageFormat:
  684. def __init__(self, **kwargs):
  685. pass
  686. def __call__(self, data):
  687. img = data["image"]
  688. im_h, im_w = img.shape[:2]
  689. divide_h = math.ceil(im_h / 32) * 32
  690. divide_w = math.ceil(im_w / 32) * 32
  691. img = img[:, :, 0]
  692. img = np.pad(
  693. img, ((0, divide_h - im_h), (0, divide_w - im_w)), constant_values=(1, 1)
  694. )
  695. img_expanded = img[:, :, np.newaxis].transpose(2, 0, 1)
  696. data["image"] = img_expanded
  697. return data