| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253 |
- import glob
- import os
- import os.path as osp
- import random
- from itertools import repeat
- from multiprocessing.pool import Pool, ThreadPool
- from pathlib import Path
- from threading import Thread
- from zipfile import ZipFile
- import cv2
- import numpy as np
- from numpy.lib.npyio import load
- from numpy.random import rand
- import torch
- import torch.nn.functional as F
- from torch.utils import data
- from torchvision.transforms.transforms import Compose
- from torch.utils.data import Dataset
- from tqdm import tqdm
- from pathlib import Path
- from tqdm import tqdm
- from torchvision import transforms
- import random
- from torch.utils.data import DataLoader, Dataset
- from utils.general import LOGGER, Loggers, CUDA, DEVICE
- from utils.imgproc_utils import resize_keepasp, letterbox
- from utils.io_utils import imread, imwrite
# Number of cooperating training processes, read from the WORLD_SIZE environment
# variable (presumably set by the distributed/DDP launcher); defaults to 1.
WORLD_SIZE = int(os.getenv('WORLD_SIZE', 1))
# Threads used to pre-load samples when caching: leave one core free, clamp to [1, 8].
# os.cpu_count() returns None when the count cannot be determined, hence the fallback.
NUM_THREADS = min(8, max(1, (os.cpu_count() or 1) - 1))
# Lower-case file suffixes accepted as images when scanning a directory.
IMG_EXT = ['.bmp', '.jpg', '.png', '.jpeg']
def augment_hsv(im, hgain=0.5, sgain=0.5, vgain=0.5):
    """Apply a random HSV colour jitter to a BGR image, in place.

    One random gain per channel is drawn uniformly from ``1 - gain`` to ``1 + gain``
    and applied through 256-entry lookup tables (hue wraps modulo 180,
    saturation and value are clipped to 0..255). The result is written back
    into ``im``; nothing is returned.

    Args:
        im: BGR image array, modified in place (the original notes it is uint8).
        hgain: maximum relative change of the hue channel.
        sgain: maximum relative change of the saturation channel.
        vgain: maximum relative change of the value channel.
    """
    # All gains zero: nothing to do, and no random numbers are consumed.
    if not (hgain or sgain or vgain):
        return

    gains = np.random.uniform(-1, 1, 3) * [hgain, sgain, vgain] + 1
    channels = cv2.split(cv2.cvtColor(im, cv2.COLOR_BGR2HSV))
    out_dtype = im.dtype

    levels = np.arange(0, 256, dtype=gains.dtype)
    tables = (
        ((levels * gains[0]) % 180).astype(out_dtype),          # hue
        np.clip(levels * gains[1], 0, 255).astype(out_dtype),   # saturation
        np.clip(levels * gains[2], 0, 255).astype(out_dtype),   # value
    )

    jittered = cv2.merge(tuple(cv2.LUT(ch, lut) for ch, lut in zip(channels, tables)))
    # Writing into dst=im is what makes the augmentation in-place.
    cv2.cvtColor(jittered, cv2.COLOR_HSV2BGR, dst=im)
def load_image_mask(self, i, max_size=None):
    """Return the ``(img, mask)`` pair stored at dataset index ``i``.

    Cached arrays in ``self.imgs`` / ``self.masks`` are used when present;
    otherwise the pair is read from the paths in ``self.img_mask_list[i]``
    (the mask is read as a single-channel grayscale image).

    Args:
        self: dataset-like object exposing ``imgs``, ``masks`` and ``img_mask_list``.
        i: index of the sample to load.
        max_size: size handed to ``resize_keepasp`` for both arrays. For a tuple
            or list only the first element is used. ``None`` skips resizing.

    Returns:
        Tuple ``(img, mask)``. If resizing fails, the un-resized pair is returned.

    Raises:
        OSError: if the image or the mask has to be read from disk and cannot be decoded.
    """
    img, mask = self.imgs[i], self.masks[i]
    imp, maskp = self.img_mask_list[i]

    if img is None:
        img = cv2.imread(imp)
        # cv2.imread signals failure by returning None instead of raising.
        if img is None:
            raise OSError(f'cannot read image: {imp}')
    if mask is None:
        mask = cv2.imread(maskp, cv2.IMREAD_GRAYSCALE)
        if mask is None:
            raise OSError(f'cannot read mask: {maskp}')

    if max_size is not None:
        if isinstance(max_size, (tuple, list)):
            max_size = max_size[0]
        try:
            resized_img = resize_keepasp(img, max_size)
            resized_mask = resize_keepasp(mask, max_size, interpolation=cv2.INTER_AREA)
        except Exception as e:
            # Resizing stays best-effort, but the failure is reported and the
            # pair is kept consistent (never one resized and the other not).
            LOGGER.warning(f'failed to resize {imp} to {max_size}: {e}')
        else:
            img, mask = resized_img, resized_mask
    return img, mask
def mini_mosaic(self, img, mask):
    """Optionally paste a second, randomly chosen sample to the right of ``img``.

    A random dataset index is drawn and loaded at ``self.img_size``. The two
    samples are combined only when the second image is taller than wide and its
    height is strictly between 0.4x and 1.6x the height of ``img``; otherwise
    the inputs are returned unchanged.

    Args:
        self: dataset providing ``__len__``, ``img_size`` and the data read by
            ``load_image_mask``.
        img: first image (3-channel is required when the mosaic is built).
        mask: mask belonging to ``img``.

    Returns:
        ``(img, mask)``: either the inputs, or new zero-padded uint8 canvases
        holding both samples side by side, top-aligned.
    """
    base_h, base_w = img.shape[0], img.shape[1]
    pick = random.randint(0, len(self) - 1)
    other_img, other_mask = load_image_mask(self, pick, self.img_size)
    other_h, other_w = other_img.shape[0], other_img.shape[1]
    h_ratio = other_h / base_h

    if not (other_h > other_w and 0.4 < h_ratio < 1.6):
        return img, mask

    canvas_h = max(base_h, other_h)
    canvas_w = base_w + other_w

    img_canvas = np.zeros((canvas_h, canvas_w, 3), np.uint8)
    img_canvas[:base_h, :base_w] = img
    img_canvas[:other_h, base_w:] = other_img

    mask_canvas = np.zeros((canvas_h, canvas_w), np.uint8)
    mask_canvas[:base_h, :base_w] = mask
    mask_canvas[:other_h, base_w:] = other_mask

    return np.ascontiguousarray(img_canvas), np.ascontiguousarray(mask_canvas)
class LoadImageAndMask(Dataset):
    """Dataset of (image, binary mask) tensor pairs discovered on disk.

    Every file directly inside an image directory whose suffix is listed in
    ``IMG_EXT`` is paired with a mask called ``mask-<image stem>.png``, looked
    up in each mask directory. Images without a mask are skipped.

    Args:
        img_dir: image directory, or list of directories.
        mask_dir: mask directory or list of directories; ``None`` or ``''``
            means the masks live next to the images.
        img_size: base square size samples are resized and letterboxed to.
        augment: enable the random augmentations configured by ``aug_param``.
        aug_param: mapping with keys ``mini_mosaic``, ``hsv``, ``flip_lr``,
            ``neg`` (probabilities) and ``size_range`` (two factors applied to
            ``img_size``; a first value of ``-1`` disables multi-size training).
            Only read when ``augment`` is true.
        cache: pre-load samples into memory (stops after roughly 7 GB).
        stride: multi-size training sizes are multiples of this value.
        cache_mask_only: when caching, keep only the masks in memory.

    Raises:
        Exception: if ``img_dir`` or ``mask_dir`` is neither a string nor a list.
    """

    def __init__(self, img_dir, mask_dir=None, img_size=640, augment=False, aug_param=None,
                 cache=False, stride=128, cache_mask_only=True):
        self.img_dir = self._as_dir_list(img_dir, 'img_dir')
        if mask_dir is None or mask_dir == '':
            self.mask_dir = self.img_dir
        else:
            self.mask_dir = self._as_dir_list(mask_dir, 'mask_dir')

        self.img_mask_list = []
        self.img_size = (img_size, img_size)
        self.stride = stride
        self._augment = augment
        # Defaults so initialize() is safe to call whether or not augmentation is on.
        self.valid_size = None
        self.multi_size = False
        if self._augment:
            self._mini_mosaic = aug_param['mini_mosaic']
            self._augment_hsv = aug_param['hsv']
            self._flip_lr = aug_param['flip_lr']
            self._neg = aug_param['neg']
            size_range = aug_param['size_range']
            if size_range[0] != -1:
                min_size = round(img_size * size_range[0] / stride) * stride
                max_size = round(img_size * size_range[1] / stride) * stride
                self.valid_size = np.arange(min_size, max_size + 1, stride)
                self.multi_size = True

        # Collect (image path, mask path) pairs.
        for image_root in self.img_dir:
            for filep in glob.glob(osp.join(image_root, "*")):
                filename = osp.basename(filep)
                if Path(filename).suffix.lower() not in IMG_EXT:
                    continue
                # Swap only the trailing suffix; str.replace would also rewrite
                # an earlier occurrence of the suffix inside the file name.
                maskname = 'mask-' + Path(filename).stem + '.png'
                for mask_root in self.mask_dir:
                    maskp = osp.join(mask_root, maskname)
                    if osp.exists(maskp):
                        self.img_mask_list.append((filep, maskp))

        self._img_transform = transforms.Compose([transforms.ToTensor()])
        self._mask_transform = transforms.Compose([transforms.ToTensor()])

        n = len(self.img_mask_list)
        self.imgs, self.masks = [None] * n, [None] * n
        if cache:
            gb = 0  # bytes cached so far
            # The context manager shuts the worker threads down when caching ends.
            with ThreadPool(NUM_THREADS) as pool:
                results = pool.imap(lambda idx: load_image_mask(self, idx, max_size=img_size), range(n))
                pbar = tqdm(enumerate(results), total=n)
                for i, (im, mask) in pbar:
                    self.masks[i] = mask
                    if not cache_mask_only:
                        self.imgs[i] = im
                        gb += self.imgs[i].nbytes
                    gb += self.masks[i].nbytes
                    # Stop caching once about 7 GB are held; the rest loads lazily.
                    if gb / 1E9 > 7:
                        break
                    pbar.desc = f'Caching images ({gb / 1E9:.1f}GB )'
                pbar.close()

    @staticmethod
    def _as_dir_list(dirs, name):
        """Normalise a directory argument (str or list) to a list of directories."""
        if isinstance(dirs, str):
            return [dirs]
        if isinstance(dirs, list):
            return dirs
        raise Exception(f'unknown {name} format')

    def initialize(self):
        """Pick a new random training size; a no-op unless multi-size augmentation is on."""
        # Test the _augment flag: self.augment is the method below and is always truthy.
        if self._augment and self.multi_size:
            self.img_size = random.choice(self.valid_size)

    def transform(self, img, mask):
        """Convert a BGR uint8 image and a grayscale mask to float tensors.

        The channel order of ``img`` is swapped to RGB in place and the image
        is scaled by 1/255. Mask pixels greater than 30 become 1.0, all others 0.0.
        """
        cv2.cvtColor(img, cv2.COLOR_BGR2RGB, img)
        img = img.astype(np.float32) / 255
        mask = (mask > 30).astype(np.float32)
        img = self._img_transform(img)
        mask = self._mask_transform(mask)
        return img, mask

    def augment(self, img, mask):
        """Apply the configured random augmentations and letterbox to ``self.img_size``.

        In order: mini-mosaic (only for images taller than wide), letterbox,
        HSV jitter, horizontal flip and colour negation of the image, each
        drawn with its own probability from ``aug_param``.
        """
        im_h, im_w = img.shape[0], img.shape[1]
        if im_h > im_w and random.random() < self._mini_mosaic:
            img, mask = mini_mosaic(self, img, mask)
        img, _, _ = letterbox(img, new_shape=self.img_size, auto=False)
        mask, _, _ = letterbox(mask, new_shape=self.img_size, auto=False)
        if random.random() < self._augment_hsv:
            augment_hsv(img)
        if random.random() < self._flip_lr:
            # Flip both in place so image and mask stay aligned.
            cv2.flip(img, 1, img)
            cv2.flip(mask, 1, mask)
        if random.random() < self._neg:
            img = 255 - img
        return img, mask

    def inverse_transform(self, img: torch.Tensor):
        """Turn a channel-first tensor scaled to 0..1 back into an HWC uint8 array."""
        img = img.permute(1, 2, 0)
        img = img * 255
        img = img.cpu().numpy().astype(np.uint8)
        return img

    def __len__(self):
        """Number of (image, mask) pairs found on disk."""
        return len(self.img_mask_list)

    def __getitem__(self, idx):
        """Load sample ``idx``, augment or letterbox it, and return ``(img, mask)`` tensors."""
        img, mask = load_image_mask(self, idx, self.img_size)
        if self._augment:
            img, mask = self.augment(img, mask)
        else:
            img, _, _ = letterbox(img, new_shape=self.img_size, auto=False)
            mask, _, _ = letterbox(mask, new_shape=self.img_size, auto=False)
        return self.transform(img, mask)
def create_dataloader(img_dir, mask_dir, imgsz, batch_size, augment=False, aug_param=None, cache=False, workers=8, shuffle=False):
    """Build a ``LoadImageAndMask`` dataset and a ``DataLoader`` over it.

    Args:
        img_dir: image directory or list of directories.
        mask_dir: mask directory or list of directories.
        imgsz: base square image size passed to the dataset.
        batch_size: requested batch size; capped at the dataset length.
        augment: enable dataset augmentation.
        aug_param: augmentation settings forwarded to the dataset.
        cache: pre-load samples into memory.
        workers: upper bound on the number of loader worker processes.
        shuffle: reshuffle the samples every epoch.

    Returns:
        Tuple ``(dataset, loader)``.
    """
    dataset = LoadImageAndMask(img_dir, mask_dir, imgsz, augment, aug_param, cache)
    batch_size = min(batch_size, len(dataset))
    # os.cpu_count() returns None when the count cannot be determined.
    cpus_per_process = (os.cpu_count() or 1) // WORLD_SIZE
    # No worker processes for a batch size of 1; otherwise bounded by CPUs,
    # batch size and the caller's limit.
    nw = min([cpus_per_process, batch_size if batch_size > 1 else 0, workers])  # number of workers
    loader = DataLoader(dataset, batch_size=batch_size, shuffle=shuffle, pin_memory=True, num_workers=nw)
    return dataset, loader
if __name__ == '__main__':
    # Manual smoke test: build the train/val loaders from the hyper-parameter
    # file and display the first sample of every training batch.

    # Seed every RNG the pipeline draws from so the run is repeatable.
    random.seed(42)
    torch.random.manual_seed(42)
    np.random.seed(42)

    import yaml

    hyp_p = r'data/train_hyp.yaml'
    with open(hyp_p, 'r', encoding='utf8') as f:
        hyp = yaml.safe_load(f.read())

    # Local overrides of the dataset locations and caching for this test run.
    hyp['data']['train_img_dir'] = [r'D:/neonbub/datasets/codat_manga_v3/images/train', r'D:/neonbub/datasets/ComicErased/processed']
    hyp['data']['val_img_dir'] = [r'D:/neonbub/datasets/codat_manga_v3/images/val']
    hyp['data']['train_mask_dir'] = r'D:/neonbub/datasets/ComicSegV2'
    hyp['data']['val_mask_dir'] = r'D:/neonbub/datasets/ComicSegV2'
    hyp['data']['cache'] = False

    hyp_train = hyp['train']
    hyp_data = hyp['data']
    hyp_model = hyp['model']
    hyp_logger = hyp['logger']
    hyp_resume = hyp['resume']

    batch_size = hyp_train['batch_size']
    batch_size = 4  # the configured value is deliberately overridden for the test
    num_workers = 2

    train_img_dir = hyp_data['train_img_dir']
    train_mask_dir = hyp_data['train_mask_dir']
    imgsz = hyp_data['imgsz']
    augment = hyp_data['augment']
    aug_param = hyp_data['aug_param']
    val_img_dir = hyp_data['val_img_dir']
    val_mask_dir = hyp_data['val_mask_dir']

    train_dataset, train_loader = create_dataloader(
        train_img_dir, train_mask_dir, imgsz, batch_size, augment, aug_param,
        shuffle=True, workers=num_workers, cache=hyp_data['cache'])
    val_dataset, val_loader = create_dataloader(
        val_img_dir, val_mask_dir, imgsz, batch_size,
        augment=False, shuffle=False, workers=num_workers, cache=hyp_data['cache'])
    LOGGER.info(f'num training imgs: {len(train_dataset)}, num val imgs: {len(val_dataset)}')

    for epoch in range(4):
        # Lets the dataset pick the training size used for this epoch.
        train_dataset.initialize()
        pbar = tqdm(enumerate(train_loader), total=len(train_loader), bar_format='{l_bar}{bar:10}{r_bar}{bar:-10b}')  # progress bar
        pbar.set_description(f' training size: {train_dataset.img_size}')
        for _, (imgs, masks) in pbar:
            # Show the first sample of the batch; any key press advances.
            img = train_dataset.inverse_transform(imgs[0])
            mask = train_dataset.inverse_transform(masks[0])
            cv2.imshow('img', img)
            cv2.imshow('mask', mask)
            cv2.waitKey(0)
        pbar.close()
|