seg_dataset.py 10 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253
  1. import glob
  2. import os
  3. import os.path as osp
  4. import random
  5. from itertools import repeat
  6. from multiprocessing.pool import Pool, ThreadPool
  7. from pathlib import Path
  8. from threading import Thread
  9. from zipfile import ZipFile
  10. import cv2
  11. import numpy as np
  12. from numpy.lib.npyio import load
  13. from numpy.random import rand
  14. import torch
  15. import torch.nn.functional as F
  16. from torch.utils import data
  17. from torchvision.transforms.transforms import Compose
  18. from torch.utils.data import Dataset
  19. from tqdm import tqdm
  20. from pathlib import Path
  21. from tqdm import tqdm
  22. from torchvision import transforms
  23. import random
  24. from torch.utils.data import DataLoader, Dataset
  25. from utils.general import LOGGER, Loggers, CUDA, DEVICE
  26. from utils.imgproc_utils import resize_keepasp, letterbox
  27. from utils.io_utils import imread, imwrite
  28. WORLD_SIZE = int(os.getenv('WORLD_SIZE', 1)) # DPP
  29. NUM_THREADS = min(8, max(1, os.cpu_count() - 1)) # number of multiprocessing threads
  30. IMG_EXT = ['.bmp', '.jpg', '.png', '.jpeg']
  31. def augment_hsv(im, hgain=0.5, sgain=0.5, vgain=0.5):
  32. # HSV color-space augmentation
  33. if hgain or sgain or vgain:
  34. r = np.random.uniform(-1, 1, 3) * [hgain, sgain, vgain] + 1 # random gains
  35. hue, sat, val = cv2.split(cv2.cvtColor(im, cv2.COLOR_BGR2HSV))
  36. dtype = im.dtype # uint8
  37. x = np.arange(0, 256, dtype=r.dtype)
  38. lut_hue = ((x * r[0]) % 180).astype(dtype)
  39. lut_sat = np.clip(x * r[1], 0, 255).astype(dtype)
  40. lut_val = np.clip(x * r[2], 0, 255).astype(dtype)
  41. im_hsv = cv2.merge((cv2.LUT(hue, lut_hue), cv2.LUT(sat, lut_sat), cv2.LUT(val, lut_val)))
  42. cv2.cvtColor(im_hsv, cv2.COLOR_HSV2BGR, dst=im) # no return needed
  43. def load_image_mask(self, i, max_size=None):
  44. # loads 1 image from dataset index 'i', returns im, original hw, resized hw
  45. img, mask = self.imgs[i], self.masks[i]
  46. imp, maskp = self.img_mask_list[i]
  47. if img is None:
  48. img = cv2.imread(imp)
  49. if mask is None:
  50. mask = cv2.imread(maskp, cv2.IMREAD_GRAYSCALE)
  51. if max_size is not None:
  52. if isinstance(max_size, tuple):
  53. max_size = max_size[0]
  54. try:
  55. img = resize_keepasp(img, max_size)
  56. mask = resize_keepasp(mask, max_size, interpolation=cv2.INTER_AREA)
  57. except:
  58. pass
  59. return img, mask
  60. def mini_mosaic(self, img, mask):
  61. im_h, im_w = img.shape[0], img.shape[1]
  62. idx = random.randint(0, len(self)-1)
  63. img2, mask2 = load_image_mask(self, idx, self.img_size)
  64. img2_h, img2_w = img2.shape[0], img2.shape[1]
  65. ratio = img2_h / im_h
  66. if img2_h > img2_w and ratio > 0.4 and ratio < 1.6:
  67. im_h = max(im_h, img2_h)
  68. im_w = im_w + img2_w
  69. im_tmp = np.zeros((im_h, im_w, 3), np.uint8)
  70. im_tmp[:img.shape[0], :img.shape[1]] = img
  71. im_tmp[:img2_h, img.shape[1]:] = img2
  72. mask_tmp = np.zeros((im_h, im_w), np.uint8)
  73. mask_tmp[:img.shape[0], :img.shape[1]] = mask
  74. mask_tmp[:img2_h, img.shape[1]:] = mask2
  75. img = np.ascontiguousarray(im_tmp)
  76. mask = np.ascontiguousarray(mask_tmp)
  77. return img, mask
  78. class LoadImageAndMask(Dataset):
  79. def __init__(self, img_dir, mask_dir=None, img_size=640, augment=False, aug_param=None, cache=False, stride=128, cache_mask_only=True):
  80. if isinstance(img_dir, str):
  81. self.img_dir = [img_dir]
  82. elif isinstance(img_dir, list):
  83. self.img_dir = img_dir
  84. else:
  85. raise Exception('unknown img_dir format')
  86. if mask_dir is None or mask_dir == '':
  87. self.mask_dir = self.img_dir
  88. else:
  89. if isinstance(mask_dir, str):
  90. self.mask_dir = [mask_dir]
  91. elif isinstance(mask_dir, list):
  92. self.mask_dir = mask_dir
  93. self.img_mask_list = []
  94. self.img_size = (img_size, img_size)
  95. self.stride = stride
  96. self._augment = augment
  97. if self._augment:
  98. self._mini_mosaic = aug_param['mini_mosaic']
  99. self._augment_hsv = aug_param['hsv']
  100. self._flip_lr = aug_param['flip_lr']
  101. self._neg = aug_param['neg']
  102. size_range = aug_param['size_range']
  103. if size_range[0] != -1:
  104. min_size = round(img_size * size_range[0] / stride ) * stride
  105. max_size = round(img_size * size_range[1] / stride ) * stride
  106. self.valid_size = np.arange(min_size, max_size+1, stride)
  107. self.multi_size = True
  108. else:
  109. self.valid_size = None
  110. self.multi_size = False
  111. for img_dir in self.img_dir:
  112. for filep in glob.glob(osp.join(img_dir, "*")):
  113. filename = osp.basename(filep)
  114. file_suffix = Path(filename).suffix
  115. if file_suffix.lower() not in IMG_EXT:
  116. continue
  117. maskname = 'mask-' + filename.replace(file_suffix, '.png')
  118. for mask_dir in self.mask_dir:
  119. maskp = osp.join(mask_dir, maskname)
  120. if osp.exists(maskp):
  121. self.img_mask_list.append((filep, maskp))
  122. self._img_transform = transforms.Compose([transforms.ToTensor()])
  123. self._mask_transform = transforms.Compose([transforms.ToTensor()])
  124. n = len(self.img_mask_list)
  125. self.imgs, self.masks = [None] * n, [None] * n
  126. gb = 0
  127. if cache:
  128. results = ThreadPool(NUM_THREADS).imap(lambda x: load_image_mask(*x, max_size=img_size), zip(repeat(self), range(n)))
  129. pbar = tqdm(enumerate(results), total=n)
  130. for i, x in pbar:
  131. im, self.masks[i] = x # im, hw_orig, hw_resized = load_image_mask(self, i)
  132. if not cache_mask_only:
  133. self.imgs[i] = im
  134. gb += self.imgs[i].nbytes
  135. gb += self.masks[i].nbytes
  136. if gb / 1E9 > 7:
  137. break
  138. pbar.desc = f'Caching images ({gb / 1E9:.1f}GB )'
  139. pbar.close()
  140. def initialize(self):
  141. if self.augment:
  142. if self.multi_size:
  143. self.img_size = random.choice(self.valid_size)
  144. def transform(self, img, mask):
  145. cv2.cvtColor(img, cv2.COLOR_BGR2RGB, img)
  146. img = img.astype(np.float32) / 255
  147. mask = (mask > 30).astype(np.float32)
  148. # mask = mask / 255
  149. img = self._img_transform(img)
  150. mask = self._mask_transform(mask)
  151. return img, mask
  152. def augment(self, img, mask):
  153. im_h, im_w = img.shape[0], img.shape[1]
  154. if im_h > im_w and random.random() < self._mini_mosaic:
  155. # imp2, maskp2 = random.choice(self.img_mask_list)
  156. img, mask = mini_mosaic(self, img, mask)
  157. img, ratio, (dw, dh) = letterbox(img, new_shape=self.img_size, auto=False)
  158. mask, ratio, (dw, dh) = letterbox(mask, new_shape=self.img_size, auto=False)
  159. if random.random() < self._augment_hsv:
  160. augment_hsv(img)
  161. if random.random() < self._flip_lr:
  162. cv2.flip(img, 1, img)
  163. cv2.flip(mask, 1, mask)
  164. if random.random() < self._neg:
  165. img = 255 - img
  166. return img, mask
  167. def inverse_transform(self, img: torch.Tensor):
  168. img = img.permute(1, 2, 0)
  169. img = img * 255
  170. img = img.cpu().numpy().astype(np.uint8)
  171. return img
  172. def __len__(self):
  173. return len(self.img_mask_list)
  174. def __getitem__(self, idx):
  175. img, mask = load_image_mask(self, idx, self.img_size)
  176. if self._augment:
  177. img, mask = self.augment(img, mask)
  178. else:
  179. img, ratio, (dw, dh) = letterbox(img, new_shape=self.img_size, auto=False)
  180. mask, ratio, (dw, dh) = letterbox(mask, new_shape=self.img_size, auto=False)
  181. return self.transform(img, mask)
  182. def create_dataloader(img_dir, mask_dir, imgsz, batch_size, augment=False, aug_param=None, cache=False, workers=8, shuffle=False):
  183. dataset = LoadImageAndMask(img_dir, mask_dir, imgsz, augment, aug_param, cache)
  184. batch_size = min(batch_size, len(dataset))
  185. nw = min([os.cpu_count() // WORLD_SIZE, batch_size if batch_size > 1 else 0, workers]) # number of workers
  186. loader = DataLoader(dataset, batch_size=batch_size, shuffle=shuffle, pin_memory=True, num_workers=nw)
  187. return dataset, loader
  188. if __name__ == '__main__':
  189. random.seed(42)
  190. torch.random.manual_seed(42)
  191. np.random.seed(42)
  192. import yaml
  193. hyp_p = r'data/train_hyp.yaml'
  194. with open(hyp_p, 'r', encoding='utf8') as f:
  195. hyp = yaml.safe_load(f.read())
  196. hyp['data']['train_img_dir'] = [r'D:/neonbub/datasets/codat_manga_v3/images/train', r'D:/neonbub/datasets/ComicErased/processed']
  197. hyp['data']['val_img_dir'] = [r'D:/neonbub/datasets/codat_manga_v3/images/val']
  198. hyp['data']['train_mask_dir'] = r'D:/neonbub/datasets/ComicSegV2'
  199. hyp['data']['val_mask_dir'] = r'D:/neonbub/datasets/ComicSegV2'
  200. hyp['data']['cache'] = False
  201. hyp_train, hyp_data, hyp_model, hyp_logger, hyp_resume = hyp['train'], hyp['data'], hyp['model'], hyp['logger'], hyp['resume']
  202. batch_size = hyp_train['batch_size']
  203. batch_size = 4
  204. num_workers = 2
  205. train_img_dir, train_mask_dir, imgsz, augment, aug_param = hyp_data['train_img_dir'], hyp_data['train_mask_dir'], hyp_data['imgsz'], hyp_data['augment'], hyp_data['aug_param']
  206. val_img_dir, val_mask_dir = hyp_data['val_img_dir'], hyp_data['val_mask_dir']
  207. train_dataset, train_loader = create_dataloader(train_img_dir, train_mask_dir, imgsz, batch_size, augment, aug_param, shuffle=True, workers=num_workers, cache=hyp_data['cache'])
  208. val_dataset, val_loader = create_dataloader(val_img_dir, val_mask_dir, imgsz, batch_size, augment=False, shuffle=False, workers=num_workers, cache=hyp_data['cache'])
  209. LOGGER.info(f'num training imgs: {len(train_dataset)}, num val imgs: {len(val_dataset)}')
  210. for epoch in range(0, 4): # epoch ------------------------------------------------------------------
  211. train_dataset.initialize()
  212. pbar = enumerate(train_loader)
  213. pbar = tqdm(pbar, total=len(train_loader), bar_format='{l_bar}{bar:10}{r_bar}{bar:-10b}') # progress bar
  214. pbar.set_description(f' training size: {train_dataset.img_size}')
  215. for i, (imgs, masks) in pbar:
  216. img, mask = imgs[0], masks[0]
  217. imgs = imgs
  218. masks = masks
  219. img = train_dataset.inverse_transform(img)
  220. mask = train_dataset.inverse_transform(mask)
  221. cv2.imshow('img', img)
  222. cv2.imshow('mask', mask)
  223. cv2.waitKey(0)
  224. pbar.close()