| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285 |
- import numpy as np
- import yaml
- import torch
- import glob
- import os
- import os.path as osp
- import random
- from itertools import repeat
- from multiprocessing.pool import Pool, ThreadPool
- from pathlib import Path
- from threading import Thread
- import cv2
- from torch.utils.data import Dataset
- from tqdm import tqdm
- from pathlib import Path
- from torchvision import transforms
- from torch.utils.data import DataLoader, Dataset, dataloader
- from utils.general import LOGGER, Loggers, CUDA, DEVICE
- from utils.db_utils import MakeBorderMap, MakeShrinkMap
- from seg_dataset import augment_hsv
- from utils.imgproc_utils import rotate_polygons, letterbox, resize_keepasp
- from PIL import Image
WORLD_SIZE = int(os.getenv('WORLD_SIZE', 1))  # DDP world size (number of processes); defaults to 1
# Threads used for dataset caching: leave one core free, at least 1, at most 8.
# os.cpu_count() may return None when the count cannot be determined.
NUM_THREADS = min(8, max(1, (os.cpu_count() or 1) - 1))
IMG_EXT = ['.bmp', '.jpg', '.png', '.jpeg']  # recognized image file suffixes (lowercase)
def db_val_collate_fn(batchs):
    """Collate a list of sample dicts into one batch dict.

    NumPy arrays are converted with ``torch.from_numpy``. ``'text_polys'`` and
    ``'ignore_tags'`` are returned as plain lists with one entry per sample
    (presumably because their length varies between samples). Every other key is
    stacked along a new leading batch dimension with ``torch.stack``.

    The input dicts are not modified. An empty batch yields an empty dict.

    Args:
        batchs: list of sample dicts that all share the keys of the first one.

    Returns:
        dict mapping each key to a stacked tensor or a per-sample list.
    """
    cat_list = ('text_polys', 'ignore_tags')
    if not batchs:
        return {}
    ret_batchs = {}
    for key in batchs[0]:
        values = [
            torch.from_numpy(sample[key]) if isinstance(sample[key], np.ndarray) else sample[key]
            for sample in batchs
        ]
        ret_batchs[key] = values if key in cat_list else torch.stack(values, 0)
    return ret_batchs
class LoadImageAndAnnotations(Dataset):
    """Dataset of text images paired with polygon annotations, yielding shrink/threshold targets.

    Image pairing: every image ``<stem><ext>`` (``ext`` in ``IMG_EXT``, compared
    case-insensitively) found directly inside one of the image directories is
    paired with ``line-<stem>.txt``. The file is looked up in the annotation
    directories, which default to the image directories when ``ann_dir`` is empty.
    Each annotation line holds one polygon as a flat ``x0 y0 x1 y1 ...`` list of
    pixel coordinates.

    Sample contents: ``__getitem__`` returns the dict produced by
    ``MakeShrinkMap`` / ``MakeBorderMap``, with ``'imgs'`` converted to a float
    tensor. The demo script reads ``threshold_map``, ``threshold_mask``,
    ``shrink_map`` and ``shrink_mask`` from it. ``'text_polys'`` and
    ``'ignore_tags'`` are kept only when ``with_ann`` is true.

    Args:
        img_dir: directory (str or path-like) or list/tuple of image directories.
        ann_dir: directory or list/tuple of annotation directories; ``None`` or ``''``
            means "same as img_dir".
        img_size: square output size in pixels.
        augment: enable random augmentation (requires ``aug_param``).
        aug_param: dict with probabilities ``mini_mosaic``, ``hsv``, ``flip_lr``, ``neg``,
            ``rotate`` plus ``rotate_range`` (degrees) and ``size_range`` (multi-scale factors).
        cache: preload annotations (and images unless ``cache_ann_only``), up to about 7 GB.
        stride: granularity of the multi-scale training sizes.
        cache_ann_only: when caching, keep only annotations in memory.
        with_ann: keep polygons and ignore tags in each returned sample.

    Raises:
        TypeError: if ``img_dir`` or ``ann_dir`` has an unsupported type.
    """

    def __init__(self, img_dir, ann_dir=None, img_size=640, augment=False, aug_param=None, cache=False, stride=128, cache_ann_only=True, with_ann=False):
        self.img_dir = self._as_dir_list(img_dir, 'unknown img_dir format')
        if ann_dir is None or ann_dir == '':
            # Annotations live next to the images.
            self.ann_dir = self.img_dir
        else:
            self.ann_dir = self._as_dir_list(ann_dir, 'unknown ann_dir format')
        self.with_ann = with_ann
        self.make_border_map = MakeBorderMap(shrink_ratio=0.4)
        self.make_shrink_map = MakeShrinkMap(shrink_ratio=0.4)
        self.img_size = (img_size, img_size)
        self.stride = stride
        self._augment = augment
        # Defaults so initialize() is safe even when augmentation is disabled.
        self.valid_size = None
        self.multi_size = False
        if self._augment:
            self._init_augment(aug_param, img_size, stride)
        self.img_ann_list = self._index_files()
        self._img_transform = transforms.Compose([transforms.ToTensor()])
        n = len(self.img_ann_list)
        # Per-index caches; None means "load from disk on demand".
        self.imgs, self.anns = [None] * n, [None] * n
        if cache:
            self._cache(img_size, cache_ann_only)

    @staticmethod
    def _as_dir_list(dirs, err_msg):
        """Normalize one directory or a list/tuple of directories into a list."""
        if isinstance(dirs, (str, os.PathLike)):
            return [dirs]
        if isinstance(dirs, (list, tuple)):
            return list(dirs)
        raise TypeError(err_msg)

    def _init_augment(self, aug_param, img_size, stride):
        """Read augmentation settings from ``aug_param`` and set up multi-scale sizes."""
        self._mini_mosaic = aug_param['mini_mosaic']
        self._augment_hsv = aug_param['hsv']
        self._flip_lr = aug_param['flip_lr']
        self._neg = aug_param['neg']
        self._rotate = aug_param['rotate']
        self.rotate_range = aug_param['rotate_range']
        size_range = aug_param['size_range']
        if isinstance(size_range, list) and size_range[0] > 0:
            # Candidate sizes: multiples of `stride` between size_range[0] and size_range[1] times img_size.
            min_size = round(img_size * size_range[0] / stride) * stride
            max_size = round(img_size * size_range[1] / stride) * stride
            self.valid_size = np.arange(min_size, max_size + 1, stride)
            self.multi_size = True

    def _index_files(self):
        """Return ``(image_path, annotation_path)`` pairs for images that have an annotation.

        The first annotation directory containing ``line-<stem>.txt`` wins, so an
        image is listed at most once per image directory it appears in.
        """
        pairs = []
        for img_dir in self.img_dir:
            for filep in glob.glob(osp.join(img_dir, "*")):
                stem, suffix = osp.splitext(osp.basename(filep))
                if suffix.lower() not in IMG_EXT:
                    continue
                annname = 'line-' + stem + '.txt'
                for ann_dir in self.ann_dir:
                    annp = osp.join(ann_dir, annname)
                    if osp.exists(annp):
                        pairs.append((filep, annp))
                        break
        return pairs

    def _cache(self, img_size, cache_ann_only):
        """Preload annotations (and images unless ``cache_ann_only``) with a thread pool.

        Stops once roughly 7 GB have been cached; remaining entries stay None and
        are loaded lazily. The pool is terminated on exit, so no work continues in
        the background after an early stop.
        """
        n = len(self.img_ann_list)
        gb = 0
        with ThreadPool(NUM_THREADS) as pool:
            results = pool.imap(lambda i: load_image_annotations(self, i, max_size=img_size), range(n))
            with tqdm(enumerate(results), total=n) as pbar:
                for i, (im, ann) in pbar:
                    self.anns[i] = ann
                    if not cache_ann_only:
                        self.imgs[i] = im
                        gb += im.nbytes
                    gb += ann.nbytes
                    if gb / 1E9 > 7:
                        break
                    pbar.desc = f'Caching images ({gb / 1E9:.1f}GB )'

    def initialize(self):
        """Pick a new random training size from ``valid_size`` when multi-scale augmentation is on.

        Does nothing when augmentation or multi-scale training is disabled.
        """
        if self._augment and self.multi_size:
            self.img_size = random.choice(self.valid_size)

    def transform(self, img):
        """Convert a BGR uint8 image (H x W x C) into an RGB float tensor in [0, 1] (C x H x W).

        Note: the channel swap is done in place, so ``img`` itself is modified.
        """
        cv2.cvtColor(img, cv2.COLOR_BGR2RGB, img)
        img = img.astype(np.float32) / 255
        return self._img_transform(img)

    def mini_mosaic(self, img, ann):
        """Place a random portrait image to the right of ``img`` and merge the annotations.

        ``ann`` holds relative (0-1) polygon coordinates shaped (n, k, 2). It is
        rescaled in place to the combined canvas. If the sampled image is not
        portrait, the inputs are returned unchanged.
        """
        im_h, im_w = img.shape[:2]
        idx = random.randint(0, len(self) - 1)
        img2, ann2 = load_image_annotations(self, idx, self.img_size)
        img2_h, img2_w = img2.shape[:2]
        if img2_h <= img2_w:
            return img, ann

        canvas_h = max(im_h, img2_h)
        canvas_w = im_w + img2_w
        canvas = np.zeros((canvas_h, canvas_w, 3), np.uint8)
        canvas[:im_h, :im_w] = img
        canvas[:img2_h, im_w:] = img2
        ann[:, :, 0] = ann[:, :, 0] * im_w / canvas_w
        ann[:, :, 1] = ann[:, :, 1] * im_h / canvas_h
        if ann2.shape[1] > 0:
            ann2[:, :, 0] = ann2[:, :, 0] * img2_w / canvas_w + im_w / canvas_w
            ann2[:, :, 1] = ann2[:, :, 1] * img2_h / canvas_h
            # An empty annotation file loads as shape (1, 0, 2), which cannot be
            # concatenated with real polygons; keep only the non-empty set then.
            ann = ann2 if ann.shape[1] == 0 else np.concatenate((ann, ann2))
        return canvas, ann

    def augment(self, img, ann):
        """Randomly apply mini-mosaic, HSV jitter, horizontal flip, negation and rotation.

        Each step fires with the probability configured in ``aug_param``. ``ann``
        holds relative polygon coordinates and is kept in sync with the image.
        """
        im_h, im_w = img.shape[:2]
        if im_h > im_w and random.random() < self._mini_mosaic:
            img, ann = self.mini_mosaic(img, ann)
        if random.random() < self._augment_hsv:
            augment_hsv(img)
        if random.random() < self._flip_lr:
            cv2.flip(img, 1, img)
            ann[:, :, 0] = 1 - ann[:, :, 0]
        if random.random() < self._neg:
            img = 255 - img
        if random.random() < self._rotate:
            degrees = random.uniform(self.rotate_range[0], self.rotate_range[1])
            # Only rotations beyond +-15 degrees are applied; smaller angles are skipped.
            if abs(degrees) > 15:
                img, ann = self._rotate_image(img, ann, degrees)
        return img, ann

    @staticmethod
    def _rotate_image(img, ann, degrees):
        """Rotate ``img`` by ``degrees`` (expanding the canvas) and rotate ``ann`` to match."""
        pil_img = Image.fromarray(img)
        center = (pil_img.width / 2, pil_img.height / 2)
        ann[:, :, 0] *= pil_img.width
        ann[:, :, 1] *= pil_img.height
        ann = ann.reshape(len(ann), -1)
        pil_img = pil_img.rotate(degrees, resample=Image.BILINEAR, expand=1)
        new_center = (pil_img.width / 2, pil_img.height / 2)
        ann = rotate_polygons(center, ann, degrees, new_center, to_int=False)
        ann = ann.reshape(len(ann), -1, 2)
        ann[:, :, 0] /= pil_img.width
        ann[:, :, 1] /= pil_img.height
        # np.array (not np.asarray) so the result is writable; arrays viewing a
        # PIL image are read-only and would break later in-place operations.
        return np.array(pil_img), ann

    def inverse_transform(self, img: torch.Tensor, scale=255, to_uint8=True):
        """Convert a C x H x W tensor back into an H x W x C numpy image multiplied by ``scale``.

        With ``to_uint8``, values are clipped to [0, 255] before the cast, so
        out-of-range values saturate instead of overflowing unpredictably.
        """
        img = (img.detach().permute(1, 2, 0) * scale).cpu().numpy()
        if to_uint8:
            img = np.ascontiguousarray(np.clip(img, 0, 255), np.uint8)
        return img

    def __len__(self):
        return len(self.img_ann_list)

    def __getitem__(self, idx):
        """Return the training sample for ``idx`` (see the class docstring for its keys)."""
        img, ann = load_image_annotations(self, idx, self.img_size)
        if self._augment:
            img, ann = self.augment(img, ann)
        ignore_tags = [False] * ann.shape[0]
        img, _, (dw, dh) = letterbox(img, new_shape=self.img_size, auto=False)
        im_h, im_w = img.shape[:2]
        # Relative polygon coordinates -> pixels of the letterboxed image.
        # NOTE(review): scaling by (size - padding) with no offset assumes letterbox puts
        # all padding on the right/bottom edge — confirm against imgproc_utils.letterbox.
        ann[:, :, 0] *= (im_w - dw)
        ann[:, :, 1] *= (im_h - dh)
        ann = ann.astype(np.int64)
        data_dict = {'imgs': img, 'text_polys': ann, 'ignore_tags': ignore_tags}
        # Only the border-map dict is returned; the shrink-map keys still reach it
        # (the demo script reads them), presumably because the makers update data_dict in place.
        self.make_shrink_map(data_dict)
        sample = self.make_border_map(data_dict)
        text_polys = sample.pop('text_polys')
        tags = sample.pop('ignore_tags')
        if self.with_ann:
            sample['text_polys'] = torch.from_numpy(np.array(text_polys))
            sample['ignore_tags'] = torch.from_numpy(np.array(tags))
        sample['imgs'] = self.transform(sample['imgs'])
        return sample
def load_image_annotations(self, i, max_size=None, ann_abs2rel=True):
    """Load image ``i`` and its polygon annotations, preferring the dataset's caches.

    Args:
        self: the dataset; its ``imgs`` / ``anns`` caches and ``img_ann_list`` are read.
        i: dataset index.
        max_size: if given (an int, or a tuple whose first element is used), the
            image is passed through ``resize_keepasp(img, max_size)``. That call
            presumably bounds the image size while keeping its aspect ratio; see
            utils.imgproc_utils.
        ann_abs2rel: when reading annotations from disk, divide x by the image
            width and y by the image height.

    Returns:
        ``(img, ann)``: the image array and a float array of polygons shaped (n, k, 2).
        Cached annotations are returned as a copy, so callers may modify them in place.

    Raises:
        FileNotFoundError: if OpenCV cannot read the image file.
    """
    img, ann = self.imgs[i], self.anns[i]
    imp, ann_path = self.img_ann_list[i]
    if img is None:
        img = cv2.imread(imp)
        if img is None:
            # cv2.imread reports missing/unreadable files by returning None instead of raising.
            raise FileNotFoundError(f'cannot read image: {imp}')
        im_h, im_w = img.shape[:2]
    if ann is None:
        ann = np.loadtxt(ann_path)
        if ann.ndim == 1:
            # A single polygon (or an empty file) loads as 1-D; make it one row.
            ann = np.array([ann])
        if ann_abs2rel:
            # Even columns are x coordinates, odd columns are y coordinates.
            ann[:, ::2] /= im_w
            ann[:, 1::2] /= im_h
        ann = ann.reshape(len(ann), -1, 2)
    else:
        # Copy so in-place edits by callers (augmentation, scaling) never touch the cache.
        ann = np.copy(ann)
    if max_size is not None:
        if isinstance(max_size, tuple):
            max_size = max_size[0]
        img = resize_keepasp(img, max_size)
    return img, ann
def create_dataloader(img_dir, ann_dir, imgsz, batch_size, augment=False, aug_param=None, cache=False, workers=8, shuffle=False, with_ann=False):
    """Build a ``LoadImageAndAnnotations`` dataset and a ``DataLoader`` over it.

    Batch size: capped at the dataset length.

    Worker count: the minimum of three values — the CPU count divided by
    ``WORLD_SIZE``, the batch size (0 for single-sample batches, i.e. load
    in-process), and ``workers``.

    Collation: when ``with_ann`` is true, ``db_val_collate_fn`` keeps the
    polygons and ignore tags as per-sample lists.

    Returns:
        ``(dataset, loader)``.

    Raises:
        ValueError: if no image/annotation pairs were found.
    """
    dataset = LoadImageAndAnnotations(img_dir, ann_dir, imgsz, augment, aug_param, cache, with_ann=with_ann)
    if len(dataset) == 0:
        # Without this check DataLoader fails later with a cryptic "batch_size=0" error.
        raise ValueError(f'no image/annotation pairs found (img_dir={img_dir!r}, ann_dir={ann_dir!r})')
    batch_size = min(batch_size, len(dataset))
    cpu_count = os.cpu_count() or 1  # os.cpu_count() may return None
    nw = min([cpu_count // WORLD_SIZE, batch_size if batch_size > 1 else 0, workers])  # number of workers
    collate_fn = db_val_collate_fn if with_ann else None
    # Pinned memory only helps host-to-GPU copies; it just warns on CPU-only machines.
    loader = DataLoader(dataset, batch_size=batch_size, shuffle=shuffle, pin_memory=torch.cuda.is_available(), num_workers=nw, collate_fn=collate_fn)
    return dataset, loader
if __name__ == '__main__':
    # Visual sanity check of the data pipeline: shows each sample with its polygons and
    # the threshold/shrink targets. Press any key to advance.
    img_dir = 'data/dataset/db_sub'
    hyp_p = r'data/train_db_hyp.yaml'
    with open(hyp_p, 'r', encoding='utf8') as f:
        hyp = yaml.safe_load(f)
    hyp_data = hyp['data']
    hyp_data['train_img_dir'] = img_dir
    hyp_data['cache'] = False
    batch_size = 1
    # With num_workers=0 the loader uses this process's dataset object, so the
    # train_dataset.initialize() calls below affect the samples that follow.
    num_workers = 0
    train_dataset, train_loader = create_dataloader(
        hyp_data['train_img_dir'], hyp_data['train_mask_dir'], hyp_data['imgsz'], batch_size,
        hyp_data['augment'], hyp_data['aug_param'], shuffle=True, workers=num_workers,
        cache=hyp_data['cache'], with_ann=True)
    for _ in range(10):
        for batchs in train_loader:
            train_dataset.initialize()
            print(train_dataset.img_size)
            # inverse_transform returns RGB (transform() converted BGR->RGB); OpenCV shows BGR.
            img = cv2.cvtColor(train_dataset.inverse_transform(batchs['imgs'][0]), cv2.COLOR_RGB2BGR)
            threshold_map = batchs['threshold_map'][0]
            threshold_mask = batchs['threshold_mask'][0]
            shrink_map = batchs['shrink_map'][0]
            shrink_mask = batchs['shrink_mask'][0]
            polys = batchs['text_polys'][0].numpy().astype(np.int32)
            for p in polys:
                cv2.polylines(img, [p], True, (255, 0, 0), thickness=2)
            cv2.imshow('imgs', img)
            cv2.imshow('threshold_map', threshold_map.numpy())
            cv2.imshow('threshold_mask', threshold_mask.numpy())
            cv2.imshow('shrink_map', shrink_map.numpy())
            cv2.imshow('shrink_mask', shrink_mask.numpy())
            cv2.waitKey(0)
    cv2.destroyAllWindows()
|