make_border_map.py 6.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179
  1. # copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. """
  15. This code is adapted from:
  16. https://github.com/WenmuZhou/DBNet.pytorch/blob/master/data_loader/modules/make_border_map.py
  17. """
  18. from __future__ import absolute_import
  19. from __future__ import division
  20. from __future__ import print_function
  21. from __future__ import unicode_literals
  22. import numpy as np
  23. import cv2
  24. np.seterr(divide="ignore", invalid="ignore")
  25. import pyclipper
  26. from shapely.geometry import Polygon
  27. import sys
  28. import warnings
  29. warnings.simplefilter("ignore")
  30. __all__ = ["MakeBorderMap"]
  31. class MakeBorderMap(object):
  32. def __init__(self, shrink_ratio=0.4, thresh_min=0.3, thresh_max=0.7, **kwargs):
  33. self.shrink_ratio = shrink_ratio
  34. self.thresh_min = thresh_min
  35. self.thresh_max = thresh_max
  36. if "total_epoch" in kwargs and "epoch" in kwargs and kwargs["epoch"] != "None":
  37. self.shrink_ratio = self.shrink_ratio + 0.2 * kwargs["epoch"] / float(
  38. kwargs["total_epoch"]
  39. )
  40. def __call__(self, data):
  41. img = data["image"]
  42. text_polys = data["polys"]
  43. ignore_tags = data["ignore_tags"]
  44. canvas = np.zeros(img.shape[:2], dtype=np.float32)
  45. mask = np.zeros(img.shape[:2], dtype=np.float32)
  46. for i in range(len(text_polys)):
  47. if ignore_tags[i]:
  48. continue
  49. self.draw_border_map(text_polys[i], canvas, mask=mask)
  50. canvas = canvas * (self.thresh_max - self.thresh_min) + self.thresh_min
  51. data["threshold_map"] = canvas
  52. data["threshold_mask"] = mask
  53. return data
  54. def draw_border_map(self, polygon, canvas, mask):
  55. polygon = np.array(polygon)
  56. assert polygon.ndim == 2
  57. assert polygon.shape[1] == 2
  58. polygon_shape = Polygon(polygon)
  59. if polygon_shape.area <= 0:
  60. return
  61. distance = (
  62. polygon_shape.area
  63. * (1 - np.power(self.shrink_ratio, 2))
  64. / polygon_shape.length
  65. )
  66. subject = [tuple(l) for l in polygon]
  67. padding = pyclipper.PyclipperOffset()
  68. padding.AddPath(subject, pyclipper.JT_ROUND, pyclipper.ET_CLOSEDPOLYGON)
  69. padded_polygon = np.array(padding.Execute(distance)[0])
  70. cv2.fillPoly(mask, [padded_polygon.astype(np.int32)], 1.0)
  71. xmin = padded_polygon[:, 0].min()
  72. xmax = padded_polygon[:, 0].max()
  73. ymin = padded_polygon[:, 1].min()
  74. ymax = padded_polygon[:, 1].max()
  75. width = xmax - xmin + 1
  76. height = ymax - ymin + 1
  77. polygon[:, 0] = polygon[:, 0] - xmin
  78. polygon[:, 1] = polygon[:, 1] - ymin
  79. xs = np.broadcast_to(
  80. np.linspace(0, width - 1, num=width).reshape(1, width), (height, width)
  81. )
  82. ys = np.broadcast_to(
  83. np.linspace(0, height - 1, num=height).reshape(height, 1), (height, width)
  84. )
  85. distance_map = np.zeros((polygon.shape[0], height, width), dtype=np.float32)
  86. for i in range(polygon.shape[0]):
  87. j = (i + 1) % polygon.shape[0]
  88. absolute_distance = self._distance(xs, ys, polygon[i], polygon[j])
  89. distance_map[i] = np.clip(absolute_distance / distance, 0, 1)
  90. distance_map = distance_map.min(axis=0)
  91. xmin_valid = min(max(0, xmin), canvas.shape[1] - 1)
  92. xmax_valid = min(max(0, xmax), canvas.shape[1] - 1)
  93. ymin_valid = min(max(0, ymin), canvas.shape[0] - 1)
  94. ymax_valid = min(max(0, ymax), canvas.shape[0] - 1)
  95. canvas[ymin_valid : ymax_valid + 1, xmin_valid : xmax_valid + 1] = np.fmax(
  96. 1
  97. - distance_map[
  98. ymin_valid - ymin : ymax_valid - ymax + height,
  99. xmin_valid - xmin : xmax_valid - xmax + width,
  100. ],
  101. canvas[ymin_valid : ymax_valid + 1, xmin_valid : xmax_valid + 1],
  102. )
  103. def _distance(self, xs, ys, point_1, point_2):
  104. """
  105. compute the distance from point to a line
  106. ys: coordinates in the first axis
  107. xs: coordinates in the second axis
  108. point_1, point_2: (x, y), the end of the line
  109. """
  110. height, width = xs.shape[:2]
  111. square_distance_1 = np.square(xs - point_1[0]) + np.square(ys - point_1[1])
  112. square_distance_2 = np.square(xs - point_2[0]) + np.square(ys - point_2[1])
  113. square_distance = np.square(point_1[0] - point_2[0]) + np.square(
  114. point_1[1] - point_2[1]
  115. )
  116. cosin = (square_distance - square_distance_1 - square_distance_2) / (
  117. 2 * np.sqrt(square_distance_1 * square_distance_2)
  118. )
  119. square_sin = 1 - np.square(cosin)
  120. square_sin = np.nan_to_num(square_sin)
  121. result = np.sqrt(
  122. square_distance_1 * square_distance_2 * square_sin / square_distance
  123. )
  124. result[cosin < 0] = np.sqrt(np.fmin(square_distance_1, square_distance_2))[
  125. cosin < 0
  126. ]
  127. # self.extend_line(point_1, point_2, result)
  128. return result
  129. def extend_line(self, point_1, point_2, result, shrink_ratio):
  130. ex_point_1 = (
  131. int(round(point_1[0] + (point_1[0] - point_2[0]) * (1 + shrink_ratio))),
  132. int(round(point_1[1] + (point_1[1] - point_2[1]) * (1 + shrink_ratio))),
  133. )
  134. cv2.line(
  135. result,
  136. tuple(ex_point_1),
  137. tuple(point_1),
  138. 4096.0,
  139. 1,
  140. lineType=cv2.LINE_AA,
  141. shift=0,
  142. )
  143. ex_point_2 = (
  144. int(round(point_2[0] + (point_2[0] - point_1[0]) * (1 + shrink_ratio))),
  145. int(round(point_2[1] + (point_2[1] - point_1[1]) * (1 + shrink_ratio))),
  146. )
  147. cv2.line(
  148. result,
  149. tuple(ex_point_2),
  150. tuple(point_2),
  151. 4096.0,
  152. 1,
  153. lineType=cv2.LINE_AA,
  154. shift=0,
  155. )
  156. return ex_point_1, ex_point_2