make_pse_gt.py 3.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104
  1. # copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. from __future__ import absolute_import
  15. from __future__ import division
  16. from __future__ import print_function
  17. from __future__ import unicode_literals
  18. import cv2
  19. import numpy as np
  20. import pyclipper
  21. from shapely.geometry import Polygon
  22. __all__ = ["MakePseGt"]
  23. class MakePseGt(object):
  24. def __init__(self, kernel_num=7, size=640, min_shrink_ratio=0.4, **kwargs):
  25. self.kernel_num = kernel_num
  26. self.min_shrink_ratio = min_shrink_ratio
  27. self.size = size
  28. def __call__(self, data):
  29. image = data["image"]
  30. text_polys = data["polys"]
  31. ignore_tags = data["ignore_tags"]
  32. h, w, _ = image.shape
  33. short_edge = min(h, w)
  34. if short_edge < self.size:
  35. # keep short_size >= self.size
  36. scale = self.size / short_edge
  37. image = cv2.resize(image, dsize=None, fx=scale, fy=scale)
  38. text_polys *= scale
  39. gt_kernels = []
  40. for i in range(1, self.kernel_num + 1):
  41. # s1->sn, from big to small
  42. rate = 1.0 - (1.0 - self.min_shrink_ratio) / (self.kernel_num - 1) * i
  43. text_kernel, ignore_tags = self.generate_kernel(
  44. image.shape[0:2], rate, text_polys, ignore_tags
  45. )
  46. gt_kernels.append(text_kernel)
  47. training_mask = np.ones(image.shape[0:2], dtype="uint8")
  48. for i in range(text_polys.shape[0]):
  49. if ignore_tags[i]:
  50. cv2.fillPoly(
  51. training_mask, text_polys[i].astype(np.int32)[np.newaxis, :, :], 0
  52. )
  53. gt_kernels = np.array(gt_kernels)
  54. gt_kernels[gt_kernels > 0] = 1
  55. data["image"] = image
  56. data["polys"] = text_polys
  57. data["gt_kernels"] = gt_kernels[0:]
  58. data["gt_text"] = gt_kernels[0]
  59. data["mask"] = training_mask.astype("float32")
  60. return data
  61. def generate_kernel(self, img_size, shrink_ratio, text_polys, ignore_tags=None):
  62. """
  63. Refer to part of the code:
  64. https://github.com/open-mmlab/mmocr/blob/main/mmocr/datasets/pipelines/textdet_targets/base_textdet_targets.py
  65. """
  66. h, w = img_size
  67. text_kernel = np.zeros((h, w), dtype=np.float32)
  68. for i, poly in enumerate(text_polys):
  69. polygon = Polygon(poly)
  70. distance = (
  71. polygon.area
  72. * (1 - shrink_ratio * shrink_ratio)
  73. / (polygon.length + 1e-6)
  74. )
  75. subject = [tuple(l) for l in poly]
  76. pco = pyclipper.PyclipperOffset()
  77. pco.AddPath(subject, pyclipper.JT_ROUND, pyclipper.ET_CLOSEDPOLYGON)
  78. shrunk = np.array(pco.Execute(-distance))
  79. if len(shrunk) == 0 or shrunk.size == 0:
  80. if ignore_tags is not None:
  81. ignore_tags[i] = True
  82. continue
  83. try:
  84. shrunk = np.array(shrunk[0]).reshape(-1, 2)
  85. except:
  86. if ignore_tags is not None:
  87. ignore_tags[i] = True
  88. continue
  89. cv2.fillPoly(text_kernel, [shrunk.astype(np.int32)], i + 1)
  90. return text_kernel, ignore_tags