textmask.py 13 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267
  1. from os import stat
  2. from typing import List
  3. import cv2
  4. import numpy as np
  5. from .textblock import TextBlock
  6. from .imgproc_utils import draw_connected_labels, expand_textwindow, union_area
  7. WHITE = (255, 255, 255)
  8. BLACK = (0, 0, 0)
  9. LANG_ENG = 0
  10. LANG_JPN = 1
  11. REFINEMASK_INPAINT = 0
  12. REFINEMASK_ANNOTATION = 1
  13. def get_topk_color(color_list, bins, k=3, color_var=10, bin_tol=0.001):
  14. idx = np.argsort(bins * -1)
  15. color_list, bins = color_list[idx], bins[idx]
  16. top_colors = [color_list[0]]
  17. bin_tol = np.sum(bins) * bin_tol
  18. if len(color_list) > 1:
  19. for color, bin in zip(color_list[1:], bins[1:]):
  20. if np.abs(np.array(top_colors) - color).min() > color_var:
  21. top_colors.append(color)
  22. if len(top_colors) >= k or bin < bin_tol:
  23. break
  24. return top_colors
  25. def minxor_thresh(threshed, mask, dilate=False):
  26. neg_threshed = 255 - threshed
  27. e_size = 1
  28. if dilate:
  29. element = cv2.getStructuringElement(cv2.MORPH_RECT, (2 * e_size + 1, 2 * e_size + 1),(e_size, e_size))
  30. neg_threshed = cv2.dilate(neg_threshed, element, iterations=1)
  31. threshed = cv2.dilate(threshed, element, iterations=1)
  32. neg_xor_sum = cv2.bitwise_xor(neg_threshed, mask).sum()
  33. xor_sum = cv2.bitwise_xor(threshed, mask).sum()
  34. if neg_xor_sum < xor_sum:
  35. return neg_threshed, neg_xor_sum
  36. else:
  37. return threshed, xor_sum
  38. def get_otsuthresh_masklist(img, pred_mask, per_channel=False) -> List[np.ndarray]:
  39. channels = [img[..., 0], img[..., 1], img[..., 2]]
  40. mask_list = []
  41. for c in channels:
  42. _, threshed = cv2.threshold(c, 1, 255, cv2.THRESH_OTSU+cv2.THRESH_BINARY)
  43. threshed, xor_sum = minxor_thresh(threshed, pred_mask, dilate=False)
  44. mask_list.append([threshed, xor_sum])
  45. mask_list.sort(key=lambda x: x[1])
  46. if per_channel:
  47. return mask_list
  48. else:
  49. return [mask_list[0]]
  50. def get_topk_masklist(im_grey, pred_mask):
  51. if len(im_grey.shape) == 3 and im_grey.shape[-1] == 3:
  52. im_grey = cv2.cvtColor(im_grey, cv2.COLOR_BGR2GRAY)
  53. msk = np.ascontiguousarray(pred_mask)
  54. candidate_grey_px = im_grey[np.where(cv2.erode(msk, np.ones((3,3), np.uint8), iterations=1) > 127)]
  55. bin, his = np.histogram(candidate_grey_px, bins=255)
  56. topk_color = get_topk_color(his, bin, color_var=10, k=3)
  57. color_range = 30
  58. mask_list = list()
  59. for ii, color in enumerate(topk_color):
  60. c_top = min(color+color_range, 255)
  61. c_bottom = c_top - 2 * color_range
  62. threshed = cv2.inRange(im_grey, c_bottom, c_top)
  63. threshed, xor_sum = minxor_thresh(threshed, msk)
  64. mask_list.append([threshed, xor_sum])
  65. return mask_list
  66. def merge_mask_list(mask_list, pred_mask, blk: TextBlock = None, pred_thresh=30, text_window=None, filter_with_lines=False, refine_mode=REFINEMASK_INPAINT):
  67. mask_list.sort(key=lambda x: x[1])
  68. linemask = None
  69. if blk is not None and filter_with_lines:
  70. linemask = np.zeros_like(pred_mask)
  71. lines = blk.lines_array(dtype=np.int64)
  72. for line in lines:
  73. line[..., 0] -= text_window[0]
  74. line[..., 1] -= text_window[1]
  75. cv2.fillPoly(linemask, [line], 255)
  76. linemask = cv2.dilate(linemask, np.ones((3, 3), np.uint8), iterations=3)
  77. if pred_thresh > 0:
  78. e_size = 1
  79. element = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (2 * e_size + 1, 2 * e_size + 1),(e_size, e_size))
  80. pred_mask = cv2.erode(pred_mask, element, iterations=1)
  81. _, pred_mask = cv2.threshold(pred_mask, 60, 255, cv2.THRESH_BINARY)
  82. connectivity = 8
  83. mask_merged = np.zeros_like(pred_mask)
  84. for ii, (candidate_mask, xor_sum) in enumerate(mask_list):
  85. num_labels, labels, stats, centroids = cv2.connectedComponentsWithStats(candidate_mask, connectivity, cv2.CV_16U)
  86. for label_index, stat, centroid in zip(range(num_labels), stats, centroids):
  87. if label_index != 0: # skip background label
  88. x, y, w, h, area = stat
  89. if w * h < 3:
  90. continue
  91. x1, y1, x2, y2 = x, y, x+w, y+h
  92. label_local = labels[y1: y2, x1: x2]
  93. label_coordinates = np.where(label_local==label_index)
  94. tmp_merged = np.zeros_like(label_local, np.uint8)
  95. tmp_merged[label_coordinates] = 255
  96. tmp_merged = cv2.bitwise_or(mask_merged[y1: y2, x1: x2], tmp_merged)
  97. xor_merged = cv2.bitwise_xor(tmp_merged, pred_mask[y1: y2, x1: x2]).sum()
  98. xor_origin = cv2.bitwise_xor(mask_merged[y1: y2, x1: x2], pred_mask[y1: y2, x1: x2]).sum()
  99. if xor_merged < xor_origin:
  100. mask_merged[y1: y2, x1: x2] = tmp_merged
  101. if refine_mode == REFINEMASK_INPAINT:
  102. mask_merged = cv2.dilate(mask_merged, np.ones((3, 3), np.uint8), iterations=1)
  103. # fill holes
  104. num_labels, labels, stats, centroids = cv2.connectedComponentsWithStats(255-mask_merged, connectivity, cv2.CV_16U)
  105. sorted_area = np.sort(stats[:, -1])
  106. if len(sorted_area) > 1:
  107. area_thresh = sorted_area[-2]
  108. else:
  109. area_thresh = sorted_area[-1]
  110. for label_index, stat, centroid in zip(range(num_labels), stats, centroids):
  111. x, y, w, h, area = stat
  112. if area < area_thresh:
  113. x1, y1, x2, y2 = x, y, x+w, y+h
  114. label_local = labels[y1: y2, x1: x2]
  115. label_coordinates = np.where(label_local==label_index)
  116. tmp_merged = np.zeros_like(label_local, np.uint8)
  117. tmp_merged[label_coordinates] = 255
  118. tmp_merged = cv2.bitwise_or(mask_merged[y1: y2, x1: x2], tmp_merged)
  119. xor_merged = cv2.bitwise_xor(tmp_merged, pred_mask[y1: y2, x1: x2]).sum()
  120. xor_origin = cv2.bitwise_xor(mask_merged[y1: y2, x1: x2], pred_mask[y1: y2, x1: x2]).sum()
  121. if xor_merged < xor_origin:
  122. mask_merged[y1: y2, x1: x2] = tmp_merged
  123. return mask_merged
  124. def refine_undetected_mask(img: np.ndarray, mask_pred: np.ndarray, mask_refined: np.ndarray, blk_list: List[TextBlock], refine_mode=REFINEMASK_INPAINT):
  125. mask_pred[np.where(mask_refined > 30)] = 0
  126. _, pred_mask_t = cv2.threshold(mask_pred, 30, 255, cv2.THRESH_BINARY)
  127. num_labels, labels, stats, centroids = cv2.connectedComponentsWithStats(pred_mask_t, 4, cv2.CV_16U)
  128. valid_labels = np.where(stats[:, -1] > 50)[0]
  129. seg_blk_list = []
  130. if len(valid_labels) > 0:
  131. for lab_index in valid_labels[1:]:
  132. x, y, w, h, area = stats[lab_index]
  133. bx1, by1 = x, y
  134. bx2, by2 = x+w, y+h
  135. bbox = [bx1, by1, bx2, by2]
  136. bbox_score = -1
  137. for blk in blk_list:
  138. bbox_s = union_area(blk.xyxy, bbox)
  139. if bbox_s > bbox_score:
  140. bbox_score = bbox_s
  141. if bbox_score / w / h < 0.5:
  142. seg_blk_list.append(TextBlock(bbox))
  143. if len(seg_blk_list) > 0:
  144. mask_refined = cv2.bitwise_or(mask_refined, refine_mask(img, mask_pred, seg_blk_list, refine_mode=refine_mode))
  145. return mask_refined
  146. def refine_mask(img: np.ndarray, pred_mask: np.ndarray, blk_list: List[TextBlock], refine_mode: int = REFINEMASK_INPAINT) -> np.ndarray:
  147. mask_refined = np.zeros_like(pred_mask)
  148. for blk in blk_list:
  149. bx1, by1, bx2, by2 = expand_textwindow(img.shape, blk.xyxy, expand_r=16)
  150. im = np.ascontiguousarray(img[by1: by2, bx1: bx2])
  151. msk = np.ascontiguousarray(pred_mask[by1: by2, bx1: bx2])
  152. mask_list = get_topk_masklist(im, msk)
  153. mask_list += get_otsuthresh_masklist(im, msk, per_channel=False)
  154. mask_merged = merge_mask_list(mask_list, msk, blk=blk, text_window=[bx1, by1, bx2, by2], refine_mode=refine_mode)
  155. mask_refined[by1: by2, bx1: bx2] = cv2.bitwise_or(mask_refined[by1: by2, bx1: bx2], mask_merged)
  156. return mask_refined
  157. # def extract_textballoon(img, pred_textmsk=None, global_mask=None):
  158. # if len(img.shape) > 2 and img.shape[2] == 3:
  159. # img = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
  160. # im_h, im_w = img.shape[0], img.shape[1]
  161. # hyp_textmsk = np.zeros((im_h, im_w), np.uint8)
  162. # thresh_val, threshed = cv2.threshold(img, 1, 255, cv2.THRESH_OTSU+cv2.THRESH_BINARY)
  163. # xormap_sum = cv2.bitwise_xor(threshed, pred_textmsk).sum()
  164. # neg_threshed = 255 - threshed
  165. # neg_xormap_sum = cv2.bitwise_xor(neg_threshed, pred_textmsk).sum()
  166. # neg_thresh = neg_xormap_sum < xormap_sum
  167. # if neg_thresh:
  168. # threshed = neg_threshed
  169. # thresh_info = {'thresh_val': thresh_val,'neg_thresh': neg_thresh}
  170. # connectivity = 8
  171. # num_labels, labels, stats, centroids = cv2.connectedComponentsWithStats(threshed, connectivity, cv2.CV_16U)
  172. # label_unchanged = np.copy(labels)
  173. # if global_mask is not None:
  174. # labels[np.where(global_mask==0)] = 0
  175. # text_labels = []
  176. # if pred_textmsk is not None:
  177. # text_score_thresh = 0.5
  178. # textbbox_map = np.zeros_like(pred_textmsk)
  179. # for label_index, stat, centroid in zip(range(num_labels), stats, centroids):
  180. # if label_index != 0: # skip background label
  181. # x, y, w, h, area = stat
  182. # area *= 255
  183. # x1, y1, x2, y2 = x, y, x+w, y+h
  184. # label_local = labels[y1: y2, x1: x2]
  185. # label_coordinates = np.where(label_local==label_index)
  186. # tmp_merged = np.zeros((h, w), np.uint8)
  187. # tmp_merged[label_coordinates] = 255
  188. # andmap = cv2.bitwise_and(tmp_merged, pred_textmsk[y1: y2, x1: x2])
  189. # text_score = andmap.sum() / area
  190. # if text_score > text_score_thresh:
  191. # text_labels.append(label_index)
  192. # hyp_textmsk[y1: y2, x1: x2][label_coordinates] = 255
  193. # labels = label_unchanged
  194. # bubble_msk = np.zeros((img.shape[0], img.shape[1]), np.uint8)
  195. # bubble_msk[np.where(labels==0)] = 255
  196. # # if lang == LANG_JPN:
  197. # bubble_msk = cv2.erode(bubble_msk, (3, 3), iterations=1)
  198. # line_thickness = 2
  199. # cv2.rectangle(bubble_msk, (0, 0), (im_w, im_h), BLACK, line_thickness, cv2.LINE_8)
  200. # contours, hiers = cv2.findContours(bubble_msk, cv2.RETR_CCOMP, cv2.CHAIN_APPROX_NONE)
  201. # brect_area_thresh = im_h * im_w * 0.4
  202. # min_brect_area = np.inf
  203. # ballon_index = -1
  204. # maximum_pixsum = -1
  205. # for ii, contour in enumerate(contours):
  206. # brect = cv2.boundingRect(contours[ii])
  207. # brect_area = brect[2] * brect[3]
  208. # if brect_area > brect_area_thresh and brect_area < min_brect_area:
  209. # tmp_ballonmsk = np.zeros_like(bubble_msk)
  210. # tmp_ballonmsk = cv2.drawContours(tmp_ballonmsk, contours, ii, WHITE, cv2.FILLED)
  211. # andmap_sum = cv2.bitwise_and(tmp_ballonmsk, hyp_textmsk).sum()
  212. # if andmap_sum > maximum_pixsum:
  213. # maximum_pixsum = andmap_sum
  214. # min_brect_area = brect_area
  215. # ballon_index = ii
  216. # if ballon_index != -1:
  217. # bubble_msk = np.zeros_like(bubble_msk)
  218. # bubble_msk = cv2.drawContours(bubble_msk, contours, ballon_index, WHITE, cv2.FILLED)
  219. # hyp_textmsk = cv2.bitwise_and(hyp_textmsk, bubble_msk)
  220. # return hyp_textmsk, bubble_msk, thresh_info, (num_labels, label_unchanged, stats, centroids, text_labels)
  221. # def extract_textballoon_channelwise(img, pred_textmsk, test_grey=True, global_mask=None):
  222. # c_list = [img[:, :, i] for i in range(3)]
  223. # if test_grey:
  224. # c_list.append(cv2.cvtColor(img, cv2.COLOR_BGR2GRAY))
  225. # best_xorpix_sum = np.inf
  226. # best_cindex = best_hyptextmsk = best_bubblemsk = best_thresh_info = best_component_stats = None
  227. # for c_index, channel in enumerate(c_list):
  228. # hyp_textmsk, bubble_msk, thresh_info, component_stats = extract_textballoon(channel, pred_textmsk, global_mask=global_mask)
  229. # pixor_sum = cv2.bitwise_xor(hyp_textmsk, pred_textmsk).sum()
  230. # if pixor_sum < best_xorpix_sum:
  231. # best_xorpix_sum = pixor_sum
  232. # best_cindex = c_index
  233. # best_hyptextmsk, best_bubblemsk, best_thresh_info, best_component_stats = hyp_textmsk, bubble_msk, thresh_info, component_stats
  234. # return best_hyptextmsk, best_bubblemsk, best_component_stats
  235. # def refine_textmask(img, pred_mask, channel_wise=True, find_leaveouts=True, global_mask=None):
  236. # hyp_textmsk, bubble_msk, component_stats = extract_textballoon_channelwise(img, pred_mask, global_mask=global_mask)
  237. # num_labels, labels, stats, centroids, text_labels = component_stats
  238. # stats = np.array(stats)
  239. # text_stats = stats[text_labels]
  240. # if find_leaveouts and len(text_stats) > 0:
  241. # median_h = np.median(text_stats[:, 3])
  242. # for label, label_h in zip(range(num_labels), stats[:, 3]):
  243. # if label == 0 or label in text_labels:
  244. # continue
  245. # if label_h > 0.5 * median_h and label_h < 1.5 * median_h:
  246. # hyp_textmsk[np.where(labels==label)] = 255
  247. # hyp_textmsk = cv2.bitwise_and(hyp_textmsk, bubble_msk)
  248. # if global_mask is not None:
  249. # hyp_textmsk = cv2.bitwise_and(hyp_textmsk, global_mask)
  250. # return hyp_textmsk, bubble_msk