db_dataset.py 11 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285
  1. import numpy as np
  2. import yaml
  3. import torch
  4. import glob
  5. import os
  6. import os.path as osp
  7. import random
  8. from itertools import repeat
  9. from multiprocessing.pool import Pool, ThreadPool
  10. from pathlib import Path
  11. from threading import Thread
  12. import cv2
  13. from torch.utils.data import Dataset
  14. from tqdm import tqdm
  15. from pathlib import Path
  16. from torchvision import transforms
  17. from torch.utils.data import DataLoader, Dataset, dataloader
  18. from utils.general import LOGGER, Loggers, CUDA, DEVICE
  19. from utils.db_utils import MakeBorderMap, MakeShrinkMap
  20. from seg_dataset import augment_hsv
  21. from utils.imgproc_utils import rotate_polygons, letterbox, resize_keepasp
  22. from PIL import Image
  # Number of distributed (DDP) processes, read from the WORLD_SIZE environment
  # variable; 1 when unset.
  WORLD_SIZE = int(os.getenv('WORLD_SIZE', 1))
  # Threads used for dataset caching. os.cpu_count() returns None when the CPU
  # count cannot be determined, so fall back to 1 instead of crashing at import.
  NUM_THREADS = min(8, max(1, (os.cpu_count() or 1) - 1))
  # Image file extensions the dataset accepts (compared against lower-cased suffixes).
  IMG_EXT = ['.bmp', '.jpg', '.png', '.jpeg']
  def db_val_collate_fn(batchs):
      """Collate DB samples into a batch dict without modifying the input samples.

      For every key of the first sample, the values of all samples are gathered.
      NumPy arrays are converted with ``torch.from_numpy``. ``'text_polys'`` and
      ``'ignore_tags'`` can differ in length between images, so they are kept as
      per-sample lists; every other key is ``torch.stack``-ed along a new leading
      batch dimension.

      Args:
          batchs: non-empty list of sample dicts sharing the same keys.

      Returns:
          dict mapping each key to a stacked tensor or a list of per-sample values.
      """
      cat_list = ('text_polys', 'ignore_tags')
      collated = {}
      for key in batchs[0]:
          # Build a fresh list so the caller's sample dicts are left untouched.
          values = [
              torch.from_numpy(sample[key]) if isinstance(sample[key], np.ndarray) else sample[key]
              for sample in batchs
          ]
          collated[key] = values if key in cat_list else torch.stack(values, 0)
      return collated
  class LoadImageAndAnnotations(Dataset):
      """Dataset of images paired with ``line-<stem>.txt`` polygon annotations.

      Images are collected from ``img_dir`` (one directory or a list), keeping files
      whose lower-cased extension is in ``IMG_EXT``. For an image ``foo.jpg`` the
      annotation ``line-foo.txt`` is searched for in ``ann_dir``. When ``ann_dir`` is
      omitted, the search is in the image's own directory. Images without an
      annotation are skipped.

      ``__getitem__`` loads a sample with relative polygon coordinates and optionally
      augments it. It then letterboxes it to ``self.img_size`` and returns the dict
      produced by ``make_border_map`` (after ``make_shrink_map``), with ``'imgs'``
      converted by ``transform``.
      """

      def __init__(self, img_dir, ann_dir=None, img_size=640, augment=False, aug_param=None,
                   cache=False, stride=128, cache_ann_only=True, with_ann=False):
          """
          Args:
              img_dir: image directory or list of directories.
              ann_dir: directory or list of directories holding ``line-*.txt`` files;
                  None or '' means annotations sit next to their images.
              img_size: side length of the square training size.
              augment: enable random augmentation (requires ``aug_param``).
              aug_param: dict with keys 'mini_mosaic', 'hsv', 'flip_lr', 'neg', 'rotate'
                  (probabilities compared against ``random.random()``), 'rotate_range'
                  ([min_deg, max_deg]) and 'size_range' ([min, max] scale of ``img_size``
                  for multi-scale training; disabled unless a list with min > 0).
              cache: preload annotations (and images unless ``cache_ann_only``),
                  stopping once roughly 7 GB have been cached.
              stride: granularity of multi-scale sizes.
              cache_ann_only: when caching, keep only annotations in memory.
              with_ann: also return 'text_polys' / 'ignore_tags' tensors per sample.

          Raises:
              Exception: if ``img_dir`` or ``ann_dir`` is neither a str nor a list.
          """
          if isinstance(img_dir, str):
              self.img_dir = [img_dir]
          elif isinstance(img_dir, list):
              self.img_dir = img_dir
          else:
              raise Exception('unknown img_dir format')
          ann_next_to_img = ann_dir is None or ann_dir == ''
          if ann_next_to_img:
              self.ann_dir = self.img_dir
          elif isinstance(ann_dir, str):
              self.ann_dir = [ann_dir]
          elif isinstance(ann_dir, list):
              self.ann_dir = ann_dir
          else:
              # Previously fell through and left self.ann_dir unset.
              raise Exception('unknown ann_dir format')
          self.with_ann = with_ann
          self.make_border_map = MakeBorderMap(shrink_ratio=0.4)
          self.make_shrink_map = MakeShrinkMap(shrink_ratio=0.4)
          self.img_size = (img_size, img_size)
          self.stride = stride
          self._augment = augment
          # Multi-scale defaults so these attributes always exist (initialize reads them).
          self.valid_size = None
          self.multi_size = False
          if self._augment:
              self._mini_mosaic = aug_param['mini_mosaic']
              self._augment_hsv = aug_param['hsv']
              self._flip_lr = aug_param['flip_lr']
              self._neg = aug_param['neg']
              self._rotate = aug_param['rotate']
              self.rotate_range = aug_param['rotate_range']
              size_range = aug_param['size_range']
              if isinstance(size_range, list) and size_range[0] > 0:
                  min_size = round(img_size * size_range[0] / stride) * stride
                  max_size = round(img_size * size_range[1] / stride) * stride
                  self.valid_size = np.arange(min_size, max_size + 1, stride)
                  self.multi_size = True
          self.img_ann_list = self._collect_pairs(ann_next_to_img)
          self._img_transform = transforms.Compose([transforms.ToTensor()])
          n = len(self.img_ann_list)
          self.imgs, self.anns = [None] * n, [None] * n
          if cache:
              self._cache(img_size, cache_ann_only)

      def _collect_pairs(self, ann_next_to_img):
          """Return one ``(image_path, annotation_path)`` pair per annotated image."""
          pairs = []
          for img_dir in self.img_dir:
              for filep in glob.glob(osp.join(img_dir, "*")):
                  path = Path(filep)
                  # Case-insensitive, so '.JPG' files are no longer silently skipped.
                  if path.suffix.lower() not in IMG_EXT:
                      continue
                  # Use the stem: str.replace(suffix, ...) also rewrote earlier matches.
                  annname = f'line-{path.stem}.txt'
                  # With implicit annotation dirs, only the image's own directory can
                  # hold its annotation; searching all image dirs mis-paired same-named files.
                  search_dirs = [img_dir] if ann_next_to_img else self.ann_dir
                  for ann_dir in search_dirs:
                      annp = osp.join(ann_dir, annname)
                      if osp.exists(annp):
                          pairs.append((filep, annp))
                          break  # one annotation per image; first matching dir wins
          return pairs

      def _cache(self, img_size, cache_ann_only):
          """Preload annotations (and optionally images), stopping after ~7 GB."""
          n = len(self.img_ann_list)
          gb = 0
          # The context manager shuts the worker threads down (the pool used to leak).
          with ThreadPool(NUM_THREADS) as pool:
              results = pool.imap(lambda x: load_image_annotations(*x, max_size=img_size),
                                  zip(repeat(self), range(n)))
              pbar = tqdm(enumerate(results), total=n)
              for i, (im, ann) in pbar:
                  self.anns[i] = ann
                  gb += ann.nbytes
                  if not cache_ann_only:
                      self.imgs[i] = im
                      gb += im.nbytes
                  if gb / 1E9 > 7:
                      break
                  pbar.desc = f'Caching images ({gb / 1E9:.1f}GB )'
              pbar.close()

      def initialize(self):
          """Draw a new random square ``img_size`` from ``valid_size`` (multi-scale training).

          No-op unless augmentation with multi-scale sizes is enabled.
          """
          # The old check used `self.augment` (a bound method, always truthy) and then
          # crashed on the missing `multi_size` attribute for non-augmented datasets.
          if self._augment and self.multi_size:
              size = int(random.choice(self.valid_size))
              # Keep the (h, w) tuple form that __init__ uses.
              self.img_size = (size, size)

      def transform(self, img):
          """Convert an OpenCV (BGR) HWC image into an RGB float CHW tensor scaled by 1/255."""
          # Not in place: the input may be read-only or shared with a caller.
          img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
          img = img.astype(np.float32) / 255
          return self._img_transform(img)

      def mini_mosaic(self, img, ann):
          """Paste a random second image to the right of ``img`` when that image is portrait.

          ``ann`` and the second image's annotations hold coordinates relative to their
          own image. Both are rescaled to the combined canvas. If the drawn image is not
          portrait (h > w), the inputs are returned unchanged.
          """
          im_h, im_w = img.shape[:2]
          idx = random.randint(0, len(self) - 1)
          img2, ann2 = load_image_annotations(self, idx, self.img_size)
          img2_h, img2_w = img2.shape[:2]
          if img2_h <= img2_w:
              return img, ann
          canvas_h = max(im_h, img2_h)
          canvas_w = im_w + img2_w
          canvas = np.zeros((canvas_h, canvas_w, 3), np.uint8)
          canvas[:im_h, :im_w] = img
          canvas[:img2_h, im_w:] = img2
          ann[:, :, 0] = ann[:, :, 0] * im_w / canvas_w
          ann[:, :, 1] = ann[:, :, 1] * im_h / canvas_h
          # shape[1] == 0 means an annotation without any points.
          if ann2.shape[1] > 0:
              ann2[:, :, 0] = ann2[:, :, 0] * img2_w / canvas_w + im_w / canvas_w
              ann2[:, :, 1] = ann2[:, :, 1] * img2_h / canvas_h
              # Concatenating an empty (k, 0, 2) array with points would fail.
              ann = ann2 if ann.shape[1] == 0 else np.concatenate((ann, ann2))
          return canvas, ann

      def augment(self, img, ann):
          """Randomly apply the configured augmentations to ``img`` and its relative ``ann``.

          In order: mini-mosaic (portrait inputs only), HSV jitter, horizontal flip,
          negation, and rotation. Rotation is only applied when the drawn angle exceeds
          15 degrees. Each step fires with its probability from ``aug_param``.
          """
          im_h, im_w = img.shape[:2]
          if im_h > im_w and random.random() < self._mini_mosaic:
              img, ann = self.mini_mosaic(img, ann)
          if random.random() < self._augment_hsv:
              augment_hsv(img)
          if random.random() < self._flip_lr:
              cv2.flip(img, 1, img)
              ann[:, :, 0] = 1 - ann[:, :, 0]
          if random.random() < self._neg:
              img = 255 - img
          if random.random() < self._rotate:
              degrees = random.uniform(self.rotate_range[0], self.rotate_range[1])
              if abs(degrees) > 15:
                  img, ann = self._rotate_sample(img, ann, degrees)
          return img, ann

      @staticmethod
      def _rotate_sample(img, ann, degrees):
          """Rotate ``img`` by ``degrees`` on an expanded canvas and rotate ``ann`` to match."""
          pil_img = Image.fromarray(img)
          center = (pil_img.width / 2, pil_img.height / 2)
          ann[:, :, 0] *= pil_img.width
          ann[:, :, 1] *= pil_img.height
          ann = ann.reshape(len(ann), -1)
          pil_img = pil_img.rotate(degrees, resample=Image.BILINEAR, expand=1)
          new_center = (pil_img.width / 2, pil_img.height / 2)
          ann = rotate_polygons(center, ann, degrees, new_center, to_int=False)
          ann = ann.reshape(len(ann), -1, 2)
          ann[:, :, 0] /= pil_img.width
          ann[:, :, 1] /= pil_img.height
          # np.asarray(PIL image) is read-only; return a writable copy so later
          # in-place OpenCV/NumPy operations cannot fail on it.
          return np.array(pil_img), ann

      def inverse_transform(self, img: torch.Tensor, scale=255, to_uint8=True):
          """Convert a CHW tensor back into an HWC numpy image multiplied by ``scale``.

          Channel order is preserved (RGB for outputs of ``transform``).
          """
          img = img.permute(1, 2, 0)
          img = img * scale
          img = img.cpu().numpy()
          if to_uint8:
              img = np.ascontiguousarray(img, np.uint8)
          return img

      def __len__(self):
          return len(self.img_ann_list)

      def __getitem__(self, idx):
          """Build the training sample dict for image ``idx``."""
          img, ann = load_image_annotations(self, idx, self.img_size)
          if self._augment:
              img, ann = self.augment(img, ann)
          ignore_tags = [False] * ann.shape[0]
          img, _, (dw, dh) = letterbox(img, new_shape=self.img_size, auto=False)
          im_h, im_w = img.shape[:2]
          # Relative coords -> pixels of the resized content. Assumes letterbox pads
          # only the right/bottom by (dw, dh) — TODO confirm against imgproc_utils.
          ann[:, :, 0] *= (im_w - dw)
          ann[:, :, 1] *= (im_h - dh)
          ann = ann.astype(np.int64)
          data_dict = {'imgs': img, 'text_polys': ann, 'ignore_tags': ignore_tags}
          # The shrink-map result is not read here; presumably it fills data_dict in place.
          self.make_shrink_map(data_dict)
          thresh_map = self.make_border_map(data_dict)
          tp = thresh_map.pop('text_polys')
          it = thresh_map.pop('ignore_tags')
          if self.with_ann:
              thresh_map['text_polys'] = torch.from_numpy(np.array(tp))
              thresh_map['ignore_tags'] = torch.from_numpy(np.array(it))
          thresh_map['imgs'] = self.transform(thresh_map['imgs'])
          return thresh_map
  def load_image_annotations(self, i, max_size=None, ann_abs2rel=True):
      """Load image ``i`` of dataset ``self`` together with its polygon annotations.

      Cached entries in ``self.imgs`` / ``self.anns`` are used when present, and copies
      are returned, so callers may modify the results in place. Otherwise the image is
      read with ``cv2.imread`` and the annotation text file with ``np.loadtxt``.

      Args:
          self: dataset exposing ``imgs``, ``anns`` and ``img_ann_list`` (list of
              ``(image_path, annotation_path)`` pairs).
          i: sample index.
          max_size: if given (int or ``(size, size)`` tuple), passed to
              ``resize_keepasp`` to bound the image size.
          ann_abs2rel: for freshly loaded annotations, divide x (even columns) by the
              image width and y (odd columns) by its height.

      Returns:
          ``(img, ann)`` where ``ann`` has shape ``(num_polygons, num_points, 2)``.

      Raises:
          FileNotFoundError: if the image file cannot be read.
      """
      imp, ann_path = self.img_ann_list[i]
      img, ann = self.imgs[i], self.anns[i]
      if img is None:
          img = cv2.imread(imp)
          if img is None:
              # cv2.imread signals failure by returning None instead of raising.
              raise FileNotFoundError(f'failed to read image: {imp}')
      else:
          # Augmentations flip / jitter the image in place; never hand out the cache itself.
          img = img.copy()
      im_h, im_w = img.shape[:2]
      if ann is None:
          ann = np.loadtxt(ann_path)
          if ann.ndim == 1:
              # A single polygon line loads as 1-D; promote to one row.
              ann = ann[np.newaxis]
          if ann_abs2rel:
              ann[:, ::2] /= im_w
              ann[:, 1::2] /= im_h
          ann = ann.reshape(len(ann), -1, 2)
      else:
          ann = np.copy(ann)
      if max_size is not None:
          if isinstance(max_size, tuple):
              max_size = max_size[0]
          img = resize_keepasp(img, max_size)
      return img, ann
  def create_dataloader(img_dir, ann_dir, imgsz, batch_size, augment=False, aug_param=None, cache=False, workers=8, shuffle=False, with_ann=False):
      """Build a ``LoadImageAndAnnotations`` dataset and a ``DataLoader`` over it.

      The batch size is clamped to the dataset length. The worker count is the minimum
      of CPUs per distributed process, the batch size (0 for batch size 1) and
      ``workers``. With ``with_ann`` the polygon-aware ``db_val_collate_fn`` is used.

      Returns:
          ``(dataset, loader)``.

      Raises:
          ValueError: if no annotated image is found.
      """
      dataset = LoadImageAndAnnotations(img_dir, ann_dir, imgsz, augment, aug_param, cache, with_ann=with_ann)
      if len(dataset) == 0:
          # Otherwise batch_size would become 0 and DataLoader fails with an obscure error.
          raise ValueError(f'no annotated images found in {img_dir}')
      batch_size = min(batch_size, len(dataset))
      cpu_count = os.cpu_count() or 1  # os.cpu_count() may return None
      nw = min(cpu_count // WORLD_SIZE, batch_size if batch_size > 1 else 0, workers)  # number of workers
      collate_fn = db_val_collate_fn if with_ann else None
      loader = DataLoader(dataset, batch_size=batch_size, shuffle=shuffle, pin_memory=True, num_workers=nw, collate_fn=collate_fn)
      return dataset, loader
  if __name__ == '__main__':
      # Visual sanity check: iterate the training loader and display each image with
      # its annotation polygons and the DB target maps. Press a key to advance.
      img_dir = 'data/dataset/db_sub'
      hyp_p = r'data/train_db_hyp.yaml'
      with open(hyp_p, 'r', encoding='utf8') as f:
          hyp = yaml.safe_load(f.read())
      hyp['data']['train_img_dir'] = img_dir
      hyp['data']['cache'] = False
      hyp_data = hyp['data']
      batch_size = 1
      num_workers = 0
      train_dataset, train_loader = create_dataloader(
          hyp_data['train_img_dir'], hyp_data['train_mask_dir'], hyp_data['imgsz'], batch_size,
          hyp_data['augment'], hyp_data['aug_param'], shuffle=True, workers=num_workers,
          cache=hyp_data['cache'], with_ann=True)
      for _ in range(10):
          for batchs in train_loader:
              train_dataset.initialize()
              print(train_dataset.img_size)
              # inverse_transform yields RGB (transform converted BGR->RGB), but
              # OpenCV windows expect BGR.
              img = cv2.cvtColor(train_dataset.inverse_transform(batchs['imgs'][0]), cv2.COLOR_RGB2BGR)
              polys = batchs['text_polys'][0].numpy().astype(np.int32)
              for p in polys:
                  cv2.polylines(img, [p], True, (255, 0, 0), thickness=2)
              cv2.imshow('imgs', img)
              for map_name in ('threshold_map', 'threshold_mask', 'shrink_map', 'shrink_mask'):
                  cv2.imshow(map_name, batchs[map_name][0].numpy())
              cv2.waitKey(0)