import numpy as np
import yaml
import torch
import glob
import os
import os.path as osp
import random
from itertools import repeat
from multiprocessing.pool import Pool, ThreadPool
from pathlib import Path
from threading import Thread
import cv2
from tqdm import tqdm
from torchvision import transforms
from torch.utils.data import DataLoader, Dataset, dataloader
from utils.general import LOGGER, Loggers, CUDA, DEVICE
from utils.db_utils import MakeBorderMap, MakeShrinkMap
from seg_dataset import augment_hsv
from utils.imgproc_utils import rotate_polygons, letterbox, resize_keepasp
from PIL import Image

# Value of the WORLD_SIZE environment variable (1 when unset).
WORLD_SIZE = int(os.getenv('WORLD_SIZE', 1))
# Number of threads used for caching; os.cpu_count() may return None.
NUM_THREADS = min(8, max(1, (os.cpu_count() or 1) - 1))
IMG_EXT = ['.bmp', '.jpg', '.png', '.jpeg']


def db_val_collate_fn(batchs):
    """Collate a list of sample dicts into a single dict of batched values.

    Every ``np.ndarray`` value is converted to a tensor (the sample dicts are
    updated in place). Values under ``'text_polys'`` and ``'ignore_tags'`` are
    kept as per-sample lists; every other key is stacked along a new first
    dimension with ``torch.stack``.

    Args:
        batchs: non-empty list of dicts sharing the same keys.

    Returns:
        dict mapping each key to a stacked tensor or a list of per-sample
        values.
    """
    # These keys are not stacked; presumably their per-sample sizes differ.
    keep_as_list = ('text_polys', 'ignore_tags')
    collated = {}
    for key in batchs[0].keys():
        values = []
        for sample in batchs:
            if isinstance(sample[key], np.ndarray):
                sample[key] = torch.from_numpy(sample[key])
            values.append(sample[key])
        collated[key] = values if key in keep_as_list else torch.stack(values, 0)
    return collated


class LoadImageAndAnnotations(Dataset):
    """Dataset pairing images with ``line-<image stem>.txt`` polygon files.

    Each item is the dict produced by ``MakeShrinkMap`` followed by
    ``MakeBorderMap``, with the image converted to a tensor under ``'imgs'``.
    """

    def __init__(self, img_dir, ann_dir=None, img_size=640, augment=False, aug_param=None, cache=False, stride=128, cache_ann_only=True, with_ann=False):
        """Index image/annotation pairs and optionally cache them.

        Args:
            img_dir: directory (str) or list of directories containing images.
            ann_dir: directory or list of directories containing annotation
                files; ``None`` or ``''`` means "same as ``img_dir``".
            img_size: target square size passed to the loader and letterbox.
            augment: enable the random augmentations in :meth:`augment`.
            aug_param: dict with keys ``mini_mosaic``, ``hsv``, ``flip_lr``,
                ``neg``, ``rotate``, ``rotate_range`` and ``size_range``;
                required when ``augment`` is true.
            cache: preload annotations (and images, see ``cache_ann_only``).
            stride: granularity of the multi-size training resolutions.
            cache_ann_only: when caching, keep only annotations in memory.
            with_ann: also return polygons and ignore tags from
                ``__getitem__``.

        Raises:
            Exception: if ``img_dir`` or ``ann_dir`` has an unsupported type.
        """
        if isinstance(img_dir, str):
            self.img_dir = [img_dir]
        elif isinstance(img_dir, list):
            self.img_dir = img_dir
        else:
            raise Exception('unknown img_dir format')

        if ann_dir is None or ann_dir == '':
            self.ann_dir = self.img_dir
        elif isinstance(ann_dir, str):
            self.ann_dir = [ann_dir]
        elif isinstance(ann_dir, list):
            self.ann_dir = ann_dir
        else:
            # Previously this case left self.ann_dir unset and failed later.
            raise Exception('unknown ann_dir format')

        self.with_ann = with_ann
        self.make_border_map = MakeBorderMap(shrink_ratio=0.4)
        self.make_shrink_map = MakeShrinkMap(shrink_ratio=0.4)
        self.img_ann_list = []
        self.img_size = (img_size, img_size)
        self.stride = stride
        self._augment = augment
        # Defaults so that initialize() is safe when augmentation is off.
        self.valid_size = None
        self.multi_size = False
        if self._augment:
            self._mini_mosaic = aug_param['mini_mosaic']
            self._augment_hsv = aug_param['hsv']
            self._flip_lr = aug_param['flip_lr']
            self._neg = aug_param['neg']
            self._rotate = aug_param['rotate']
            self.rotate_range = aug_param['rotate_range']
            size_range = aug_param['size_range']
            if isinstance(size_range, list) and size_range[0] > 0:
                # Candidate resolutions: multiples of `stride` between the
                # scaled minimum and maximum sizes (inclusive).
                min_size = round(img_size * size_range[0] / stride) * stride
                max_size = round(img_size * size_range[1] / stride) * stride
                self.valid_size = np.arange(min_size, max_size + 1, stride)
                self.multi_size = True

        for img_dir in self.img_dir:
            for filep in glob.glob(osp.join(img_dir, "*")):
                filename = osp.basename(filep)
                if Path(filename).suffix not in IMG_EXT:
                    continue
                # Use the stem rather than str.replace, which would rewrite
                # every occurrence of the suffix inside the file name.
                annname = 'line-' + Path(filename).stem + '.txt'
                for ann_dir in self.ann_dir:
                    annp = osp.join(ann_dir, annname)
                    if osp.exists(annp):
                        self.img_ann_list.append((filep, annp))

        self._img_transform = transforms.Compose([transforms.ToTensor()])
        n = len(self.img_ann_list)
        self.imgs, self.anns = [None] * n, [None] * n
        gb = 0
        if cache:
            # The context manager releases the worker threads when done.
            with ThreadPool(NUM_THREADS) as pool:
                results = pool.imap(lambda i: load_image_annotations(self, i, max_size=img_size), range(n))
                pbar = tqdm(enumerate(results), total=n)
                for i, (im, ann) in pbar:
                    self.anns[i] = ann
                    if not cache_ann_only:
                        self.imgs[i] = im
                        gb += im.nbytes
                    gb += ann.nbytes
                    # Stop caching once more than 7 GB have been stored.
                    if gb / 1E9 > 7:
                        break
                    pbar.desc = f'Caching images ({gb / 1E9:.1f}GB )'
                pbar.close()

    def initialize(self):
        """Pick a new random training resolution when multi-size is enabled."""
        # NOTE: must test the flag `_augment`; `self.augment` is the method
        # below and is therefore always truthy.
        if self._augment and self.multi_size:
            self.img_size = random.choice(self.valid_size)

    def transform(self, img):
        """Convert a BGR uint8 image to a float tensor scaled by 1/255.

        The channel swap is written into ``img`` in place before conversion.
        """
        cv2.cvtColor(img, cv2.COLOR_BGR2RGB, img)
        img = img.astype(np.float32) / 255
        img = self._img_transform(img)
        return img

    def mini_mosaic(self, img, ann):
        """Place a random second sample to the right of ``img``.

        The merge only happens when the randomly drawn second image is taller
        than wide; otherwise the inputs are returned unchanged. Polygon
        coordinates (relative to their own image) are rescaled so they are
        relative to the combined canvas.

        Returns:
            tuple ``(img, ann)``.
        """
        im_h, im_w = img.shape[:2]
        idx = random.randint(0, len(self) - 1)
        img2, ann2 = load_image_annotations(self, idx, self.img_size)
        img2_h, img2_w = img2.shape[:2]
        if img2_h <= img2_w:
            return img, ann

        imm_h = max(im_h, img2_h)
        imm_w = im_w + img2_w
        canvas = np.zeros((imm_h, imm_w, 3), np.uint8)
        canvas[:im_h, :im_w] = img
        canvas[:img2_h, im_w:] = img2
        ann[:, :, 0] = ann[:, :, 0] * im_w / imm_w
        ann[:, :, 1] = ann[:, :, 1] * im_h / imm_h
        if ann2.shape[1] > 0:
            ann2[:, :, 0] = ann2[:, :, 0] * img2_w / imm_w + im_w / imm_w
            ann2[:, :, 1] = ann2[:, :, 1] * img2_h / imm_h
            if ann.shape[1] > 0:
                ann = np.concatenate((ann, ann2))
            else:
                # An empty primary annotation has zero points per polygon and
                # cannot be concatenated with real polygons.
                ann = ann2
        return canvas, ann

    def augment(self, img, ann):
        """Apply the configured random augmentations to an image and its polygons.

        Each step (mini mosaic, HSV jitter, horizontal flip, negative,
        rotation) is applied with the probability read from ``aug_param``.

        Returns:
            tuple ``(img, ann)`` with ``ann`` still in relative coordinates.
        """
        im_h, im_w = img.shape[0], img.shape[1]
        # The mosaic is only attempted for images taller than wide.
        if im_h > im_w and random.random() < self._mini_mosaic:
            img, ann = self.mini_mosaic(img, ann)
        if random.random() < self._augment_hsv:
            augment_hsv(img)
        if random.random() < self._flip_lr:
            cv2.flip(img, 1, img)
            ann[:, :, 0] = 1 - ann[:, :, 0]
        if random.random() < self._neg:
            img = 255 - img
        if random.random() < self._rotate:
            degrees = random.uniform(self.rotate_range[0], self.rotate_range[1])
            # Angles of 15 degrees or less (in magnitude) are skipped.
            if abs(degrees) > 15:
                img = Image.fromarray(img)
                center = (img.width / 2, img.height / 2)
                # Rotation is done in absolute pixel coordinates.
                ann[:, :, 0] *= img.width
                ann[:, :, 1] *= img.height
                ann = ann.reshape(len(ann), -1)
                img = img.rotate(degrees, resample=Image.BILINEAR, expand=1)
                new_center = (img.width / 2, img.height / 2)
                ann = rotate_polygons(center, ann, degrees, new_center, to_int=False)
                ann = ann.reshape(len(ann), -1, 2)
                # Back to coordinates relative to the expanded image.
                ann[:, :, 0] /= img.width
                ann[:, :, 1] /= img.height
                img = np.asarray(img)
        return img, ann

    def inverse_transform(self, img: torch.Tensor, scale=255, to_uint8=True):
        """Turn a channel-first tensor back into a channel-last numpy image.

        Args:
            img: tensor whose first dimension is moved to the last position.
            scale: factor the values are multiplied by.
            to_uint8: cast the result to a contiguous ``uint8`` array.
        """
        img = img.permute(1, 2, 0)
        img = img * scale
        img = img.cpu().numpy()
        if to_uint8:
            img = np.ascontiguousarray(img, np.uint8)
        return img

    def __len__(self):
        """Return the number of indexed image/annotation pairs."""
        return len(self.img_ann_list)

    def __getitem__(self, idx):
        """Load, augment and letterbox sample ``idx`` and build its target maps.

        Returns:
            The dict returned by ``make_border_map``, with ``'imgs'`` replaced
            by the transformed tensor. ``'text_polys'`` and ``'ignore_tags'``
            are only present (as tensors) when ``with_ann`` is true.
        """
        img, ann = load_image_annotations(self, idx, self.img_size)
        if self._augment:
            img, ann = self.augment(img, ann)
        ignore_tags = [False] * ann.shape[0]
        img, ratio, (dw, dh) = letterbox(img, new_shape=self.img_size, auto=False)
        im_h, im_w = img.shape[:2]
        # Relative -> absolute pixels; (dw, dh) is presumably the padding
        # added by letterbox, so the content spans (im_w - dw, im_h - dh).
        ann[:, :, 0] *= (im_w - dw)
        ann[:, :, 1] *= (im_h - dh)
        ann = ann.astype(np.int64)

        data_dict = {'imgs': img, 'text_polys': ann, 'ignore_tags': ignore_tags}
        # Return value unused; presumably the call fills data_dict in place
        # before make_border_map runs on it - TODO confirm.
        self.make_shrink_map(data_dict)
        thresh_map = self.make_border_map(data_dict)
        tp = thresh_map.pop('text_polys')
        it = thresh_map.pop('ignore_tags')
        if self.with_ann:
            thresh_map['text_polys'] = torch.from_numpy(np.array(tp))
            thresh_map['ignore_tags'] = torch.from_numpy(np.array(it))
        thresh_map['imgs'] = self.transform(thresh_map['imgs'])
        return thresh_map


def load_image_annotations(self, i, max_size=None, ann_abs2rel=True):
    """Load image ``i`` and its polygons from the cache or from disk.

    Args:
        self: dataset providing ``imgs``, ``anns`` and ``img_ann_list``.
        i: sample index.
        max_size: if given, the image is passed through ``resize_keepasp``
            with this size (the first element is used when it is a tuple).
        ann_abs2rel: divide freshly loaded x/y coordinates by the original
            image width/height.

    Returns:
        tuple ``(img, ann)``; ``ann`` has shape ``(n_polys, n_points, 2)`` and
        is always a fresh array, so callers may modify it in place.

    Raises:
        FileNotFoundError: if the image cannot be read.
    """
    img, ann = self.imgs[i], self.anns[i]
    imp, ann_path = self.img_ann_list[i]
    if img is None:
        img = cv2.imread(imp)
        if img is None:
            # cv2.imread signals failure by returning None.
            raise FileNotFoundError(f'failed to read image: {imp}')
    im_h, im_w = img.shape[:2]

    if ann is None:
        ann = np.loadtxt(ann_path)
        if len(ann.shape) == 1:
            # A single-line file yields a 1-D array; promote it to one row.
            ann = np.array([ann])
        if ann_abs2rel:
            ann[:, ::2] /= im_w
            ann[:, 1::2] /= im_h
        ann = ann.reshape(len(ann), -1, 2)
    else:
        # Never hand out the cached array itself.
        ann = np.copy(ann)

    if max_size is not None:
        if isinstance(max_size, tuple):
            max_size = max_size[0]
        img = resize_keepasp(img, max_size)
    return img, ann


def create_dataloader(img_dir, ann_dir, imgsz, batch_size, augment=False, aug_param=None, cache=False, workers=8, shuffle=False, with_ann=False):
    """Build a :class:`LoadImageAndAnnotations` dataset and its ``DataLoader``.

    The batch size is capped at the dataset length. The worker count is the
    minimum of CPUs per ``WORLD_SIZE``, the batch size (0 when the batch size
    is 1 or less) and ``workers``. ``db_val_collate_fn`` is used as the
    collate function when ``with_ann`` is true.

    Returns:
        tuple ``(dataset, loader)``.
    """
    dataset = LoadImageAndAnnotations(img_dir, ann_dir, imgsz, augment, aug_param, cache, with_ann=with_ann)
    batch_size = min(batch_size, len(dataset))
    # os.cpu_count() may return None when the count cannot be determined.
    cpu_per_proc = (os.cpu_count() or 1) // WORLD_SIZE
    nw = min([cpu_per_proc, batch_size if batch_size > 1 else 0, workers])  # number of workers
    collate_fn = db_val_collate_fn if with_ann else None
    loader = DataLoader(dataset,
                        batch_size=batch_size,
                        shuffle=shuffle,
                        pin_memory=True,
                        num_workers=nw,
                        collate_fn=collate_fn)
    return dataset, loader


if __name__ == '__main__':
    # Visual sanity check: iterate the training loader and display each
    # sample together with its target maps.
    img_dir = 'data/dataset/db_sub'
    hyp_p = r'data/train_db_hyp.yaml'
    with open(hyp_p, 'r', encoding='utf8') as f:
        hyp = yaml.safe_load(f.read())
    hyp['data']['train_img_dir'] = img_dir
    hyp['data']['cache'] = False
    hyp_data = hyp['data']

    batch_size = 1
    num_workers = 0
    train_dataset, train_loader = create_dataloader(
        hyp_data['train_img_dir'], hyp_data['train_mask_dir'], hyp_data['imgsz'],
        batch_size, hyp_data['augment'], hyp_data['aug_param'],
        shuffle=True, workers=num_workers, cache=hyp_data['cache'], with_ann=True)

    for ii in range(10):
        for batchs in train_loader:
            train_dataset.initialize()
            print(train_dataset.img_size)
            img = batchs['imgs'][0]
            img = train_dataset.inverse_transform(img)
            threshold_map = batchs['threshold_map'][0]
            threshold_mask = batchs['threshold_mask'][0]
            shrink_map = batchs['shrink_map'][0]
            shrink_mask = batchs['shrink_mask'][0]
            polys = batchs['text_polys'][0].numpy().astype(np.int32)
            for p in polys:
                cv2.polylines(img, [p], True, (255, 0, 0), thickness=2)
            cv2.imshow('imgs', img)
            cv2.imshow('threshold_map', threshold_map.numpy())
            cv2.imshow('threshold_mask', threshold_mask.numpy())
            cv2.imshow('shrink_map', shrink_map.numpy())
            cv2.imshow('shrink_mask', shrink_mask.numpy())
            cv2.waitKey(0)