copy_paste.py 6.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179
  1. # copyright (c) 2021 PaddlePaddle Authors. All Rights Reserve.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. import copy
  15. import cv2
  16. import random
  17. import numpy as np
  18. from PIL import Image
  19. from shapely.geometry import Polygon
  20. from ppocr.data.imaug.iaa_augment import IaaAugment
  21. from ppocr.data.imaug.random_crop_data import is_poly_outside_rect
  22. from tools.infer.utility import get_rotate_crop_image
  23. class CopyPaste(object):
  24. def __init__(self, objects_paste_ratio=0.2, limit_paste=True, **kwargs):
  25. self.ext_data_num = 1
  26. self.objects_paste_ratio = objects_paste_ratio
  27. self.limit_paste = limit_paste
  28. augmenter_args = [{"type": "Resize", "args": {"size": [0.5, 3]}}]
  29. self.aug = IaaAugment(augmenter_args)
  30. def __call__(self, data):
  31. point_num = data["polys"].shape[1]
  32. src_img = data["image"]
  33. src_polys = data["polys"].tolist()
  34. src_texts = data["texts"]
  35. src_ignores = data["ignore_tags"].tolist()
  36. ext_data = data["ext_data"][0]
  37. ext_image = ext_data["image"]
  38. ext_polys = ext_data["polys"]
  39. ext_texts = ext_data["texts"]
  40. ext_ignores = ext_data["ignore_tags"]
  41. indexes = [i for i in range(len(ext_ignores)) if not ext_ignores[i]]
  42. select_num = max(1, min(int(self.objects_paste_ratio * len(ext_polys)), 30))
  43. random.shuffle(indexes)
  44. select_idxs = indexes[:select_num]
  45. select_polys = ext_polys[select_idxs]
  46. select_ignores = ext_ignores[select_idxs]
  47. src_img = cv2.cvtColor(src_img, cv2.COLOR_BGR2RGB)
  48. ext_image = cv2.cvtColor(ext_image, cv2.COLOR_BGR2RGB)
  49. src_img = Image.fromarray(src_img).convert("RGBA")
  50. for idx, poly, tag in zip(select_idxs, select_polys, select_ignores):
  51. box_img = get_rotate_crop_image(ext_image, poly)
  52. src_img, box = self.paste_img(src_img, box_img, src_polys)
  53. if box is not None:
  54. box = box.tolist()
  55. for _ in range(len(box), point_num):
  56. box.append(box[-1])
  57. src_polys.append(box)
  58. src_texts.append(ext_texts[idx])
  59. src_ignores.append(tag)
  60. src_img = cv2.cvtColor(np.array(src_img), cv2.COLOR_RGB2BGR)
  61. h, w = src_img.shape[:2]
  62. src_polys = np.array(src_polys)
  63. src_polys[:, :, 0] = np.clip(src_polys[:, :, 0], 0, w)
  64. src_polys[:, :, 1] = np.clip(src_polys[:, :, 1], 0, h)
  65. data["image"] = src_img
  66. data["polys"] = src_polys
  67. data["texts"] = src_texts
  68. data["ignore_tags"] = np.array(src_ignores)
  69. return data
  70. def paste_img(self, src_img, box_img, src_polys):
  71. box_img_pil = Image.fromarray(box_img).convert("RGBA")
  72. src_w, src_h = src_img.size
  73. box_w, box_h = box_img_pil.size
  74. angle = np.random.randint(0, 360)
  75. box = np.array([[[0, 0], [box_w, 0], [box_w, box_h], [0, box_h]]])
  76. box = rotate_bbox(box_img, box, angle)[0]
  77. box_img_pil = box_img_pil.rotate(angle, expand=1)
  78. box_w, box_h = box_img_pil.width, box_img_pil.height
  79. if src_w - box_w < 0 or src_h - box_h < 0:
  80. return src_img, None
  81. paste_x, paste_y = self.select_coord(
  82. src_polys, box, src_w - box_w, src_h - box_h
  83. )
  84. if paste_x is None:
  85. return src_img, None
  86. box[:, 0] += paste_x
  87. box[:, 1] += paste_y
  88. r, g, b, A = box_img_pil.split()
  89. src_img.paste(box_img_pil, (paste_x, paste_y), mask=A)
  90. return src_img, box
  91. def select_coord(self, src_polys, box, endx, endy):
  92. if self.limit_paste:
  93. xmin, ymin, xmax, ymax = (
  94. box[:, 0].min(),
  95. box[:, 1].min(),
  96. box[:, 0].max(),
  97. box[:, 1].max(),
  98. )
  99. for _ in range(50):
  100. paste_x = random.randint(0, endx)
  101. paste_y = random.randint(0, endy)
  102. xmin1 = xmin + paste_x
  103. xmax1 = xmax + paste_x
  104. ymin1 = ymin + paste_y
  105. ymax1 = ymax + paste_y
  106. num_poly_in_rect = 0
  107. for poly in src_polys:
  108. if not is_poly_outside_rect(
  109. poly, xmin1, ymin1, xmax1 - xmin1, ymax1 - ymin1
  110. ):
  111. num_poly_in_rect += 1
  112. break
  113. if num_poly_in_rect == 0:
  114. return paste_x, paste_y
  115. return None, None
  116. else:
  117. paste_x = random.randint(0, endx)
  118. paste_y = random.randint(0, endy)
  119. return paste_x, paste_y
  120. def get_union(pD, pG):
  121. return Polygon(pD).union(Polygon(pG)).area
  122. def get_intersection_over_union(pD, pG):
  123. return get_intersection(pD, pG) / get_union(pD, pG)
  124. def get_intersection(pD, pG):
  125. return Polygon(pD).intersection(Polygon(pG)).area
  126. def rotate_bbox(img, text_polys, angle, scale=1):
  127. """
  128. from https://github.com/WenmuZhou/DBNet.pytorch/blob/master/data_loader/modules/augment.py
  129. Args:
  130. img: np.ndarray
  131. text_polys: np.ndarray N*4*2
  132. angle: int
  133. scale: int
  134. Returns:
  135. """
  136. w = img.shape[1]
  137. h = img.shape[0]
  138. rangle = np.deg2rad(angle)
  139. nw = abs(np.sin(rangle) * h) + abs(np.cos(rangle) * w)
  140. nh = abs(np.cos(rangle) * h) + abs(np.sin(rangle) * w)
  141. rot_mat = cv2.getRotationMatrix2D((nw * 0.5, nh * 0.5), angle, scale)
  142. rot_move = np.dot(rot_mat, np.array([(nw - w) * 0.5, (nh - h) * 0.5, 0]))
  143. rot_mat[0, 2] += rot_move[0]
  144. rot_mat[1, 2] += rot_move[1]
  145. # ---------------------- rotate box ----------------------
  146. rot_text_polys = list()
  147. for bbox in text_polys:
  148. point1 = np.dot(rot_mat, np.array([bbox[0, 0], bbox[0, 1], 1]))
  149. point2 = np.dot(rot_mat, np.array([bbox[1, 0], bbox[1, 1], 1]))
  150. point3 = np.dot(rot_mat, np.array([bbox[2, 0], bbox[2, 1], 1]))
  151. point4 = np.dot(rot_mat, np.array([bbox[3, 0], bbox[3, 1], 1]))
  152. rot_text_polys.append([point1, point2, point3, point4])
  153. return np.array(rot_text_polys, dtype=np.float32)