east_postprocess.py 5.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141
  1. # Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. from __future__ import absolute_import
  15. from __future__ import division
  16. from __future__ import print_function
  17. import numpy as np
  18. from .locality_aware_nms import nms_locality
  19. import cv2
  20. import paddle
  21. import os
  22. from ppocr.utils.utility import check_install
  23. import sys
  24. class EASTPostProcess(object):
  25. """
  26. The post process for EAST.
  27. """
  28. def __init__(self, score_thresh=0.8, cover_thresh=0.1, nms_thresh=0.2, **kwargs):
  29. self.score_thresh = score_thresh
  30. self.cover_thresh = cover_thresh
  31. self.nms_thresh = nms_thresh
  32. def restore_rectangle_quad(self, origin, geometry):
  33. """
  34. Restore rectangle from quadrangle.
  35. """
  36. # quad
  37. origin_concat = np.concatenate(
  38. (origin, origin, origin, origin), axis=1
  39. ) # (n, 8)
  40. pred_quads = origin_concat - geometry
  41. pred_quads = pred_quads.reshape((-1, 4, 2)) # (n, 4, 2)
  42. return pred_quads
  43. def detect(
  44. self, score_map, geo_map, score_thresh=0.8, cover_thresh=0.1, nms_thresh=0.2
  45. ):
  46. """
  47. restore text boxes from score map and geo map
  48. """
  49. score_map = score_map[0]
  50. geo_map = np.swapaxes(geo_map, 1, 0)
  51. geo_map = np.swapaxes(geo_map, 1, 2)
  52. # filter the score map
  53. xy_text = np.argwhere(score_map > score_thresh)
  54. if len(xy_text) == 0:
  55. return []
  56. # sort the text boxes via the y axis
  57. xy_text = xy_text[np.argsort(xy_text[:, 0])]
  58. # restore quad proposals
  59. text_box_restored = self.restore_rectangle_quad(
  60. xy_text[:, ::-1] * 4, geo_map[xy_text[:, 0], xy_text[:, 1], :]
  61. )
  62. boxes = np.zeros((text_box_restored.shape[0], 9), dtype=np.float32)
  63. boxes[:, :8] = text_box_restored.reshape((-1, 8))
  64. boxes[:, 8] = score_map[xy_text[:, 0], xy_text[:, 1]]
  65. try:
  66. check_install("lanms", "lanms-nova")
  67. import lanms
  68. boxes = lanms.merge_quadrangle_n9(boxes, nms_thresh)
  69. except:
  70. print(
  71. "You should install lanms by pip3 install lanms-nova to speed up nms_locality"
  72. )
  73. boxes = nms_locality(boxes.astype(np.float64), nms_thresh)
  74. if boxes.shape[0] == 0:
  75. return []
  76. # Here we filter some low score boxes by the average score map,
  77. # this is different from the original paper.
  78. for i, box in enumerate(boxes):
  79. mask = np.zeros_like(score_map, dtype=np.uint8)
  80. cv2.fillPoly(mask, box[:8].reshape((-1, 4, 2)).astype(np.int32) // 4, 1)
  81. boxes[i, 8] = cv2.mean(score_map, mask)[0]
  82. boxes = boxes[boxes[:, 8] > cover_thresh]
  83. return boxes
  84. def sort_poly(self, p):
  85. """
  86. Sort polygons.
  87. """
  88. min_axis = np.argmin(np.sum(p, axis=1))
  89. p = p[[min_axis, (min_axis + 1) % 4, (min_axis + 2) % 4, (min_axis + 3) % 4]]
  90. if abs(p[0, 0] - p[1, 0]) > abs(p[0, 1] - p[1, 1]):
  91. return p
  92. else:
  93. return p[[0, 3, 2, 1]]
  94. def __call__(self, outs_dict, shape_list):
  95. score_list = outs_dict["f_score"]
  96. geo_list = outs_dict["f_geo"]
  97. if isinstance(score_list, paddle.Tensor):
  98. score_list = score_list.numpy()
  99. geo_list = geo_list.numpy()
  100. img_num = len(shape_list)
  101. dt_boxes_list = []
  102. for ino in range(img_num):
  103. score = score_list[ino]
  104. geo = geo_list[ino]
  105. boxes = self.detect(
  106. score_map=score,
  107. geo_map=geo,
  108. score_thresh=self.score_thresh,
  109. cover_thresh=self.cover_thresh,
  110. nms_thresh=self.nms_thresh,
  111. )
  112. boxes_norm = []
  113. if len(boxes) > 0:
  114. h, w = score.shape[1:]
  115. src_h, src_w, ratio_h, ratio_w = shape_list[ino]
  116. boxes = boxes[:, :8].reshape((-1, 4, 2))
  117. boxes[:, :, 0] /= ratio_w
  118. boxes[:, :, 1] /= ratio_h
  119. for i_box, box in enumerate(boxes):
  120. box = self.sort_poly(box.astype(np.int32))
  121. if (
  122. np.linalg.norm(box[0] - box[1]) < 5
  123. or np.linalg.norm(box[3] - box[0]) < 5
  124. ):
  125. continue
  126. boxes_norm.append(box)
  127. dt_boxes_list.append({"points": np.array(boxes_norm)})
  128. return dt_boxes_list