| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210 |
- import torch
- from torch.optim import SGD, Adam, lr_scheduler
- from tqdm import tqdm
- import math
- from torch.cuda import amp
- import torch
- from utils.loss import BinaryDiceLoss
- import torch.nn as nn
- import yaml
- from basemodel import TextDetector
- import numpy as np
- from datetime import datetime
- from torchsummary import summary
- import numexpr
- import os
- import shutil
- os.environ['NUMEXPR_MAX_THREADS'] = str(numexpr.detect_number_of_cores())
- from seg_dataset import create_dataloader
- from utils.general import LOGGER, Loggers, CUDA, DEVICE
- import random
def _seed_all(seed):
    """Seed the torch, stdlib ``random`` and NumPy global RNGs with ``seed``."""
    torch.random.manual_seed(seed)
    random.seed(seed)
    np.random.seed(seed)


# Fix every global RNG at import time so that runs start from the same state.
_seed_all(0)
def one_cycle(y1=0.0, y2=1.0, steps=100):
    """Build a half-cosine ramp that moves from ``y1`` to ``y2``.

    Args:
        y1: value produced at ``x == 0``.
        y2: value produced at ``x == steps``.
        steps: number of steps over which the ramp runs.

    Returns:
        A one-argument function mapping a step index ``x`` to the
        interpolated value.
    """
    span = y2 - y1

    def ramp(x):
        # (1 - cos) / 2 rises smoothly from 0 to 1 as x goes from 0 to steps.
        fraction = (1 - math.cos(x * math.pi / steps)) / 2
        return fraction * span + y1

    return ramp
def eval_model(model: nn.Module, val_loader):
    """Run ``model`` over ``val_loader`` and compute soft recall/precision and mean loss.

    The model is switched to eval mode and is NOT switched back; the caller is
    responsible for re-enabling training mode afterwards.

    Recall and precision are computed from summed products rather than from
    thresholded predictions: ``tp = sum(pred * mask)``, ``gt = sum(mask)``,
    ``pr = sum(pred)``.
    NOTE(review): this presumably expects ``pred`` and ``masks`` to hold
    values in [0, 1] -- confirm against the model head and the dataset.

    Args:
        model: network to evaluate; called as ``model(imgs)``.
        val_loader: iterable of ``(imgs, masks)`` batches that supports ``len()``.

    Returns:
        Tuple ``(recall, precision, m_loss)`` where ``m_loss`` is the running
        mean of ``BinaryDiceLoss`` over all batches.

    Raises:
        ValueError: if ``val_loader`` contains no batches (the metrics would
            otherwise fail with an opaque ``ZeroDivisionError``).
    """
    nb = len(val_loader)
    if nb == 0:
        raise ValueError('val_loader is empty: cannot evaluate the model')

    loss_func = BinaryDiceLoss()
    model.eval()
    pr = tp = gt = m_loss = 0
    bar_format = '{l_bar}{bar:10}{r_bar}{bar:-10b}'
    # The context manager guarantees the progress bar is closed, even on error.
    with torch.no_grad(), tqdm(enumerate(val_loader), total=nb, bar_format=bar_format) as pbar:
        for i, (imgs, masks) in pbar:
            imgs = imgs.to(DEVICE)
            masks = masks.to(DEVICE)
            pred = model(imgs)
            # Drop the input batch early to keep peak memory down.
            del imgs
            # Nothing here tracks gradients (no_grad), so no detach is needed.
            tp += torch.mul(pred, masks).sum()
            gt += masks.sum()
            pr += pred.sum()
            loss = loss_func(pred, masks)
            # Incremental mean over the batches seen so far.
            m_loss = (m_loss * i + loss.detach()) / (i + 1)
            del masks
    recall = tp / gt
    precision = tp / pr
    return recall, precision, m_loss
def train(hyp):
    """Train the segmentation sub-network of ``TextDetector``.

    Args:
        hyp: nested configuration dict with the sections ``'train'``,
            ``'data'``, ``'model'``, ``'logger'`` and ``'resume'``.

    Side effects:
        * writes the configuration to ``data/training_hyp.yaml``;
        * after every evaluated epoch writes ``data/unet_last.ckpt`` and, when
          the validation F1 improves, copies it to ``data/unet_best.ckpt``;
        * reports metrics through ``Loggers`` when a logger is configured.
    """
    with open(r'data/training_hyp.yaml', 'w', encoding='utf8') as f:
        yaml.safe_dump(hyp, f)

    start_epoch = 0
    hyp_train, hyp_data, hyp_model, hyp_logger, hyp_resume = hyp['train'], hyp['data'], hyp['model'], hyp['logger'], hyp['resume']
    epochs = hyp_train['epochs']
    batch_size = hyp_train['batch_size']

    model = TextDetector(**hyp_model)
    if CUDA:
        model.cuda()
    # Only the segmentation head is optimized.
    params = model.seg_net.parameters()

    if hyp_train['optimizer'] == 'adam':
        optimizer = Adam(params, lr=hyp_train['lr0'], betas=(hyp_train['momentum'], 0.999), weight_decay=hyp_train['weight_decay'])  # adjust beta1 to momentum
    else:
        optimizer = SGD(params, lr=hyp_train['lr0'], momentum=hyp_train['momentum'], nesterov=True, weight_decay=hyp_train['weight_decay'])

    scaler = amp.GradScaler(enabled=CUDA)
    loss_func = BinaryDiceLoss()

    # Scheduler: per-epoch multiplier applied to each group's initial lr.
    if hyp_train['linear_lr']:
        lf = lambda x: (1 - x / (epochs - 1)) * (1.0 - hyp_train['lrf']) + hyp_train['lrf']  # linear
    else:
        lf = one_cycle(1, hyp_train['lrf'], epochs)  # cosine 1->hyp['lrf']
    scheduler = lr_scheduler.LambdaLR(optimizer, lr_lambda=lf)  # plot_lr_scheduler(optimizer, scheduler, epochs)

    logger = None
    best_f1 = -1
    best_val_loss = np.inf
    if hyp_resume['resume_training']:
        LOGGER.info(f'resume traning ... ')
        ckpt = torch.load(hyp_resume['ckpt'], map_location=DEVICE)
        model.seg_net.load_state_dict(ckpt['weights'])
        optimizer.load_state_dict(ckpt['optimizer'])
        scheduler.load_state_dict(ckpt['scheduler'])
        # The checkpoint is written before the end-of-epoch scheduler step,
        # so advance once to reach the state of the next epoch.
        scheduler.step()
        start_epoch = ckpt['epoch'] + 1
        # Carry the best score over; otherwise the first evaluation after a
        # resume would always overwrite the best checkpoint.
        best_f1 = ckpt.get('best_f1', best_f1)
        hyp_logger['run_id'] = ckpt['run_id']
        logger = Loggers(hyp)
    else:
        if hyp_logger['type'] == 'wandb':
            logger = Loggers(hyp)

    num_workers = 8
    train_img_dir, train_mask_dir, imgsz, augment, aug_param = hyp_data['train_img_dir'], hyp_data['train_mask_dir'], hyp_data['imgsz'], hyp_data['augment'], hyp_data['aug_param']
    val_img_dir, val_mask_dir = hyp_data['val_img_dir'], hyp_data['val_mask_dir']
    train_dataset, train_loader = create_dataloader(train_img_dir, train_mask_dir, imgsz, batch_size, augment, aug_param, shuffle=True, workers=num_workers, cache=hyp_data['cache'])
    val_dataset, val_loader = create_dataloader(val_img_dir, val_mask_dir, imgsz, 4, augment=False, shuffle=False, workers=num_workers, cache=hyp_data['cache'])

    nb = len(train_loader)  # batches per epoch
    nw = max(round(3 * nb), 700)  # number of warm-up iterations
    LOGGER.info(f'num training imgs: {len(train_dataset)}, num val imgs: {len(val_dataset)}')
    eval_interval = hyp_train['eval_interval']
    accumulation_steps = hyp_train['accumulation_steps']
    summary(model, (3, 640, 640), device=DEVICE)

    for epoch in range(start_epoch, epochs):  # epoch ------------------------------------------------------------------
        model.train_mask()
        train_dataset.initialize()
        pbar = enumerate(train_loader)
        pbar = tqdm(pbar, total=nb, bar_format='{l_bar}{bar:10}{r_bar}{bar:-10b}')  # progress bar

        m_loss = 0
        for i, (imgs, masks) in pbar:
            pbar.set_description(f' training size: {train_dataset.img_size}')

            # warm up: interpolate lr (and momentum, if present) over the first nw iterations
            ni = i + nb * epoch
            if ni <= nw:
                xi = [0, nw]  # x interp
                for j, x in enumerate(optimizer.param_groups):
                    x['lr'] = np.interp(ni, xi, [hyp_train['warmup_bias_lr'] if j == 2 else 0.0, x['initial_lr'] * lf(epoch)])
                    if 'momentum' in x:
                        x['momentum'] = np.interp(ni, xi, [hyp_train['warmup_momentum'], hyp_train['momentum']])

            imgs, masks = imgs.to(DEVICE), masks.to(DEVICE)
            # Gate autocast on CUDA, matching GradScaler(enabled=CUDA) above.
            with amp.autocast(enabled=CUDA):
                preds = model(imgs)
                imgs.detach_()
                del imgs
                loss = loss_func(preds, masks)
                masks.detach_()
                del masks
            scaler.scale(loss).backward()

            # NOTE(review): this steps on i == 0, i.e. after a single batch at
            # the start of every epoch, and gradients left over at the end of
            # an epoch are only applied at the first step of the next one.
            # Confirm whether `(i + 1) % accumulation_steps == 0` was intended.
            if i % accumulation_steps == 0:
                scaler.step(optimizer)
                scaler.update()
                optimizer.zero_grad()
            # Incremental mean of the training loss over this epoch.
            m_loss = (m_loss * i + loss.detach()) / (i + 1)

        if (epoch + 1) % eval_interval == 0:
            recall, precision, eval_m_loss = eval_model(model, val_loader)
            f1 = 2 * recall * precision / (recall + precision)

            # Update the best score BEFORE building the checkpoint so that the
            # stored 'best_f1' already reflects this epoch.
            is_best = bool(best_f1 < f1)
            if is_best:
                best_f1 = f1

            last_ckpt = {'epoch': epoch,
                         'best_f1': best_f1,
                         'weights': model.seg_net.state_dict(),
                         'best_val_loss': best_val_loss,
                         'optimizer': optimizer.state_dict(),
                         'scheduler': scheduler.state_dict(),
                         'run_id': logger.wandb.id if logger is not None else None,
                         'date': datetime.now().isoformat(),
                         'hyp': hyp}
            torch.save(last_ckpt, 'data/unet_last.ckpt')
            if is_best:
                LOGGER.info(f'saving model at epoch {epoch}, best val f1: {best_f1}')
                shutil.copy2('data/unet_last.ckpt', 'data/unet_best.ckpt')
            LOGGER.info(f'epoch {epoch}/{epochs-1} loss: {m_loss} precision: {precision} recall: {recall}')

            if logger is not None:
                log_dict = {}
                log_dict['train/lr'] = optimizer.param_groups[0]['lr']
                log_dict['train/loss'] = m_loss
                log_dict['eval/recall'] = recall
                log_dict['eval/precision'] = precision
                log_dict['eval/f1'] = f1
                log_dict['eval/eval_m_loss'] = eval_m_loss
                logger.on_train_epoch_end(epoch, log_dict)

        scheduler.step()
        pbar.close()
if __name__ == '__main__':
    # Load the base hyper-parameters, then apply the overrides for this run.
    hyp_path = r'data/train_hyp.yaml'
    with open(hyp_path, 'r', encoding='utf8') as f:
        hyp = yaml.safe_load(f.read())

    data_cfg, train_cfg = hyp['data'], hyp['train']

    # Dataset locations and loading options.
    # Alternative for quick runs: train_img_dir = [r'../datasets/codat_manga_v3/images/val']
    data_cfg.update({
        'train_img_dir': [r'../datasets/codat_manga_v3/images/train', r'../datasets/ComicErased/processed'],
        'val_img_dir': [r'../datasets/codat_manga_v3/images/val'],
        'train_mask_dir': r'../datasets/ComicSegV2',
        'val_mask_dir': r'../datasets/ComicSegV2',
        'imgsz': 1024,
        'cache': False,
    })

    # Augmentation overrides.
    data_cfg['aug_param'].update({
        'neg': 0.3,
        'size_range': [0.85, 1.1],
    })

    # Optimisation schedule.
    train_cfg.update({
        'lr0': 0.004,
        'lrf': 0.005,
        'weight_decay': 0.00002,
        'epochs': 120,
        'accumulation_steps': 4,
        'batch_size': 4,
    })

    hyp['logger']['type'] = 'wandb'

    # To resume an interrupted run, enable the two lines below:
    # hyp['resume']['resume_training'] = True
    # hyp['resume']['ckpt'] = 'data/unet_last.ckpt'
    train(hyp)
|