inference.py 8.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209
  1. import json
  2. from basemodel import TextDetBase, TextDetBaseDNN
  3. import os.path as osp
  4. from tqdm import tqdm
  5. import numpy as np
  6. import cv2
  7. import torch
  8. from pathlib import Path
  9. import torch
  10. from utils.yolov5_utils import non_max_suppression
  11. from utils.db_utils import SegDetectorRepresenter
  12. from utils.io_utils import imread, imwrite, find_all_imgs, NumpyEncoder
  13. from utils.imgproc_utils import letterbox, xyxy2yolo, get_yololabel_strings
  14. from utils.textblock import TextBlock, group_output, visualize_textblocks
  15. from utils.textmask import refine_mask, refine_undetected_mask, REFINEMASK_INPAINT, REFINEMASK_ANNOTATION
  16. from pathlib import Path
  17. from typing import Union
  18. def model2annotations(model_path, img_dir_list, save_dir, save_json=False):
  19. if isinstance(img_dir_list, str):
  20. img_dir_list = [img_dir_list]
  21. cuda = torch.cuda.is_available()
  22. device = 'cuda' if cuda else 'cpu'
  23. model = TextDetector(model_path=model_path, input_size=1024, device=device, act='leaky')
  24. imglist = []
  25. for img_dir in img_dir_list:
  26. imglist += find_all_imgs(img_dir, abs_path=True)
  27. for img_path in tqdm(imglist):
  28. imgname = osp.basename(img_path)
  29. img = imread(img_path)
  30. im_h, im_w = img.shape[:2]
  31. imname = imgname.replace(Path(imgname).suffix, '')
  32. maskname = 'mask-'+imname+'.png'
  33. poly_save_path = osp.join(save_dir, 'line-' + imname + '.txt')
  34. mask, mask_refined, blk_list = model(img, refine_mode=REFINEMASK_ANNOTATION, keep_undetected_mask=True)
  35. polys = []
  36. blk_xyxy = []
  37. blk_dict_list = []
  38. for blk in blk_list:
  39. polys += blk.lines
  40. blk_xyxy.append(blk.xyxy)
  41. blk_dict_list.append(blk.to_dict())
  42. blk_xyxy = xyxy2yolo(blk_xyxy, im_w, im_h)
  43. if blk_xyxy is not None:
  44. cls_list = [1] * len(blk_xyxy)
  45. yolo_label = get_yololabel_strings(cls_list, blk_xyxy)
  46. else:
  47. yolo_label = ''
  48. with open(osp.join(save_dir, imname+'.txt'), 'w', encoding='utf8') as f:
  49. f.write(yolo_label)
  50. # num_labels, labels, stats, centroids = cv2.connectedComponentsWithStats(mask)
  51. # _, mask = cv2.threshold(mask, 50, 255, cv2.THRESH_BINARY)
  52. # draw_connected_labels(num_labels, labels, stats, centroids)
  53. # visualize_textblocks(img, blk_list)
  54. # cv2.imshow('rst', img)
  55. # cv2.imshow('mask', mask)
  56. # cv2.imshow('mask_refined', mask_refined)
  57. # cv2.waitKey(0)
  58. if len(polys) != 0:
  59. if isinstance(polys, list):
  60. polys = np.array(polys)
  61. polys = polys.reshape(-1, 8)
  62. np.savetxt(poly_save_path, polys, fmt='%d')
  63. if save_json:
  64. with open(osp.join(save_dir, imname+'.json'), 'w', encoding='utf8') as f:
  65. f.write(json.dumps(blk_dict_list, ensure_ascii=False, cls=NumpyEncoder))
  66. imwrite(osp.join(save_dir, imgname), img)
  67. imwrite(osp.join(save_dir, maskname), mask_refined)
  68. def preprocess_img(img, input_size=(1024, 1024), device='cpu', bgr2rgb=True, half=False, to_tensor=True):
  69. if bgr2rgb:
  70. img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
  71. img_in, ratio, (dw, dh) = letterbox(img, new_shape=input_size, auto=False, stride=64)
  72. if to_tensor:
  73. img_in = img_in.transpose((2, 0, 1))[::-1] # HWC to CHW, BGR to RGB
  74. img_in = np.array([np.ascontiguousarray(img_in)]).astype(np.float32) / 255
  75. if to_tensor:
  76. img_in = torch.from_numpy(img_in).to(device)
  77. if half:
  78. img_in = img_in.half()
  79. return img_in, ratio, int(dw), int(dh)
  80. def postprocess_mask(img: Union[torch.Tensor, np.ndarray], thresh=None):
  81. # img = img.permute(1, 2, 0)
  82. if isinstance(img, torch.Tensor):
  83. img = img.squeeze_()
  84. if img.device != 'cpu':
  85. img = img.detach_().cpu()
  86. img = img.numpy()
  87. else:
  88. img = img.squeeze()
  89. if thresh is not None:
  90. img = img > thresh
  91. img = img * 255
  92. # if isinstance(img, torch.Tensor):
  93. return img.astype(np.uint8)
  94. def postprocess_yolo(det, conf_thresh, nms_thresh, resize_ratio, sort_func=None):
  95. det = non_max_suppression(det, conf_thresh, nms_thresh)[0]
  96. # bbox = det[..., 0:4]
  97. if det.device != 'cpu':
  98. det = det.detach_().cpu().numpy()
  99. det[..., [0, 2]] = det[..., [0, 2]] * resize_ratio[0]
  100. det[..., [1, 3]] = det[..., [1, 3]] * resize_ratio[1]
  101. if sort_func is not None:
  102. det = sort_func(det)
  103. blines = det[..., 0:4].astype(np.int32)
  104. confs = np.round(det[..., 4], 3)
  105. cls = det[..., 5].astype(np.int32)
  106. return blines, cls, confs
  107. class TextDetector:
  108. lang_list = ['eng', 'ja', 'unknown']
  109. langcls2idx = {'eng': 0, 'ja': 1, 'unknown': 2}
  110. def __init__(self, model_path, input_size=1024, device='cpu', half=False, nms_thresh=0.35, conf_thresh=0.4, mask_thresh=0.3, act='leaky'):
  111. super(TextDetector, self).__init__()
  112. cuda = device == 'cuda'
  113. if Path(model_path).suffix == '.onnx':
  114. self.model = cv2.dnn.readNetFromONNX(model_path)
  115. self.net = TextDetBaseDNN(input_size, model_path)
  116. self.backend = 'opencv'
  117. else:
  118. self.net = TextDetBase(model_path, device=device, act=act)
  119. self.backend = 'torch'
  120. if isinstance(input_size, int):
  121. input_size = (input_size, input_size)
  122. self.input_size = input_size
  123. self.device = device
  124. self.half = half
  125. self.conf_thresh = conf_thresh
  126. self.nms_thresh = nms_thresh
  127. self.seg_rep = SegDetectorRepresenter(thresh=0.3)
  128. @torch.no_grad()
  129. def __call__(self, img, refine_mode=REFINEMASK_INPAINT, keep_undetected_mask=False):
  130. img_in, ratio, dw, dh = preprocess_img(img, input_size=self.input_size, device=self.device, half=self.half, to_tensor=self.backend=='torch')
  131. im_h, im_w = img.shape[:2]
  132. blks, mask, lines_map = self.net(img_in)
  133. resize_ratio = (im_w / (self.input_size[0] - dw), im_h / (self.input_size[1] - dh))
  134. blks = postprocess_yolo(blks, self.conf_thresh, self.nms_thresh, resize_ratio)
  135. if self.backend == 'opencv':
  136. if mask.shape[1] == 2: # some version of opencv spit out reversed result
  137. tmp = mask
  138. mask = lines_map
  139. lines_map = tmp
  140. mask = postprocess_mask(mask)
  141. lines, scores = self.seg_rep(self.input_size, lines_map)
  142. box_thresh = 0.6
  143. idx = np.where(scores[0] > box_thresh)
  144. lines, scores = lines[0][idx], scores[0][idx]
  145. # map output to input img
  146. mask = mask[: mask.shape[0]-dh, : mask.shape[1]-dw]
  147. mask = cv2.resize(mask, (im_w, im_h), interpolation=cv2.INTER_LINEAR)
  148. if lines.size == 0 :
  149. lines = []
  150. else :
  151. lines = lines.astype(np.float64)
  152. lines[..., 0] *= resize_ratio[0]
  153. lines[..., 1] *= resize_ratio[1]
  154. lines = lines.astype(np.int32)
  155. blk_list = group_output(blks, lines, im_w, im_h, mask)
  156. mask_refined = refine_mask(img, mask, blk_list, refine_mode=refine_mode)
  157. if keep_undetected_mask:
  158. mask_refined = refine_undetected_mask(img, mask, mask_refined, blk_list, refine_mode=refine_mode)
  159. return mask, mask_refined, blk_list
  160. def traverse_by_dict(img_dir_list, dict_dir):
  161. if isinstance(img_dir_list, str):
  162. img_dir_list = [img_dir_list]
  163. imglist = []
  164. for img_dir in img_dir_list:
  165. imglist += find_all_imgs(img_dir, abs_path=True)
  166. for img_path in tqdm(imglist):
  167. imgname = osp.basename(img_path)
  168. imname = imgname.replace(Path(imgname).suffix, '')
  169. mask_path = osp.join(dict_dir, 'mask-'+imname+'.png')
  170. with open(osp.join(dict_dir, imname+'.json'), 'r', encoding='utf8') as f:
  171. blk_dict_list = json.loads(f.read())
  172. blk_list = [TextBlock(**blk_dict) for blk_dict in blk_dict_list]
  173. img = cv2.imread(img_path)
  174. mask = cv2.imread(mask_path, cv2.IMREAD_GRAYSCALE)
  175. mask = refine_mask(img, mask, blk_list)
  176. visualize_textblocks(img, blk_list)
  177. cv2.imshow('im', img)
  178. cv2.imshow('mask', mask)
  179. cv2.waitKey(0)
  180. if __name__ == '__main__':
  181. device = 'cpu'
  182. model_path = 'data/comictextdetector.pt'
  183. model_path = 'data/comictextdetector.pt.onnx'
  184. img_dir = r'data/examples'
  185. save_dir = r'data/backup'
  186. model2annotations(model_path, img_dir, save_dir, save_json=True)
  187. traverse_by_dict(img_dir, save_dir)