train_seg.py 8.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210
  1. import torch
  2. from torch.optim import SGD, Adam, lr_scheduler
  3. from tqdm import tqdm
  4. import math
  5. from torch.cuda import amp
  6. import torch
  7. from utils.loss import BinaryDiceLoss
  8. import torch.nn as nn
  9. import yaml
  10. from basemodel import TextDetector
  11. import numpy as np
  12. from datetime import datetime
  13. from torchsummary import summary
  14. import numexpr
  15. import os
  16. import shutil
  17. os.environ['NUMEXPR_MAX_THREADS'] = str(numexpr.detect_number_of_cores())
  18. from seg_dataset import create_dataloader
  19. from utils.general import LOGGER, Loggers, CUDA, DEVICE
  20. import random
  21. torch.random.manual_seed(0)
  22. random.seed(0)
  23. np.random.seed(0)
  24. def one_cycle(y1=0.0, y2=1.0, steps=100):
  25. return lambda x: ((1 - math.cos(x * math.pi / steps)) / 2) * (y2 - y1) + y1
  26. def eval_model(model: nn.Module, val_loader):
  27. global DEVICE
  28. loss_func = BinaryDiceLoss()
  29. pbar = enumerate(val_loader)
  30. nb = len(val_loader)
  31. model.eval()
  32. pbar = tqdm(pbar, total=nb, bar_format='{l_bar}{bar:10}{r_bar}{bar:-10b}') # progress bar
  33. pr = tp = gt = m_loss = 0
  34. with torch.no_grad():
  35. for i, (imgs, masks) in pbar:
  36. imgs = imgs.to(DEVICE)
  37. masks = masks.to(DEVICE)
  38. pred = model(imgs)
  39. imgs.detach_()
  40. del imgs
  41. tp += torch.mul(pred, masks).sum().detach_()
  42. gt += masks.sum().detach_()
  43. pr += pred.sum().detach_()
  44. loss = loss_func(pred, masks)
  45. m_loss = (m_loss * i + loss.detach()) / (i + 1)
  46. masks.detach_()
  47. del masks
  48. recall = tp / gt
  49. precision = tp / pr
  50. return recall, precision, m_loss
  51. def train(hyp):
  52. with open(r'data/training_hyp.yaml', 'w', encoding='utf8') as f:
  53. yaml.safe_dump(hyp, f)
  54. start_epoch = 0
  55. hyp_train, hyp_data, hyp_model, hyp_logger, hyp_resume = hyp['train'], hyp['data'], hyp['model'], hyp['logger'], hyp['resume']
  56. epochs = hyp_train['epochs']
  57. batch_size = hyp_train['batch_size']
  58. model = TextDetector(**hyp_model)
  59. if CUDA:
  60. model.cuda()
  61. params = model.seg_net.parameters()
  62. if hyp_train['optimizer'] == 'adam':
  63. optimizer = Adam(params, lr=hyp_train['lr0'], betas=(hyp_train['momentum'], 0.999), weight_decay=hyp_train['weight_decay']) # adjust beta1 to momentum
  64. else:
  65. optimizer = SGD(params, lr=hyp_train['lr0'], momentum=hyp_train['momentum'], nesterov=True, weight_decay=hyp_train['weight_decay'])
  66. if hyp_train['linear_lr']:
  67. lf = lambda x: (1 - x / (epochs - 1)) * (1.0 - hyp_train['lrf']) + hyp_train['lrf'] # linear
  68. else:
  69. lf = one_cycle(1, hyp_train['lrf'], epochs) # cosine 1->hyp['lrf']
  70. scaler = amp.GradScaler(enabled=CUDA)
  71. loss_func = BinaryDiceLoss()
  72. # Scheduler
  73. if hyp_train['linear_lr']:
  74. lf = lambda x: (1 - x / (epochs - 1)) * (1.0 - hyp_train['lrf']) + hyp_train['lrf'] # linear
  75. else:
  76. lf = one_cycle(1, hyp_train['lrf'], epochs) # cosine 1->hyp['lrf']
  77. scheduler = lr_scheduler.LambdaLR(optimizer, lr_lambda=lf) # plot_lr_scheduler(optimizer, scheduler, epochs)
  78. logger = None
  79. if hyp_resume['resume_training']:
  80. LOGGER.info(f'resume traning ... ')
  81. ckpt = torch.load(hyp_resume['ckpt'], map_location=DEVICE)
  82. model.seg_net.load_state_dict(ckpt['weights'])
  83. optimizer.load_state_dict(ckpt['optimizer'])
  84. scheduler.load_state_dict(ckpt['scheduler'])
  85. scheduler.step()
  86. start_epoch = ckpt['epoch'] + 1
  87. hyp_logger['run_id'] = ckpt['run_id']
  88. logger = Loggers(hyp)
  89. else:
  90. if hyp_logger['type'] == 'wandb':
  91. logger = Loggers(hyp)
  92. num_workers = 8
  93. train_img_dir, train_mask_dir, imgsz, augment, aug_param = hyp_data['train_img_dir'], hyp_data['train_mask_dir'], hyp_data['imgsz'], hyp_data['augment'], hyp_data['aug_param']
  94. val_img_dir, val_mask_dir = hyp_data['val_img_dir'], hyp_data['val_mask_dir']
  95. train_dataset, train_loader = create_dataloader(train_img_dir, train_mask_dir, imgsz, batch_size, augment, aug_param, shuffle=True, workers=num_workers, cache=hyp_data['cache'])
  96. val_dataset, val_loader = create_dataloader(val_img_dir, val_mask_dir, imgsz, 4, augment=False, shuffle=False, workers=num_workers, cache=hyp_data['cache'])
  97. nb = len(train_loader)
  98. nw = max(round(3 * nb), 700)
  99. LOGGER.info(f'num training imgs: {len(train_dataset)}, num val imgs: {len(val_dataset)}')
  100. eval_interval = hyp_train['eval_interval']
  101. best_f1 = -1
  102. best_val_loss = np.inf
  103. accumulation_steps = hyp_train['accumulation_steps']
  104. summary(model, (3, 640, 640), device=DEVICE)
  105. for epoch in range(start_epoch, epochs): # epoch ------------------------------------------------------------------
  106. model.train_mask()
  107. train_dataset.initialize()
  108. pbar = enumerate(train_loader)
  109. pbar = tqdm(pbar, total=nb, bar_format='{l_bar}{bar:10}{r_bar}{bar:-10b}') # progress bar
  110. m_loss = 0
  111. for i, (imgs, masks) in pbar:
  112. pbar.set_description(f' training size: {train_dataset.img_size}')
  113. # warm up
  114. ni = i + nb * epoch
  115. if ni <= nw:
  116. xi = [0, nw] # x interp
  117. for j, x in enumerate(optimizer.param_groups):
  118. x['lr'] = np.interp(ni, xi, [hyp_train['warmup_bias_lr'] if j == 2 else 0.0, x['initial_lr'] * lf(epoch)])
  119. if 'momentum' in x:
  120. x['momentum'] = np.interp(ni, xi, [hyp_train['warmup_momentum'], hyp_train['momentum']])
  121. imgs, masks = imgs.to(DEVICE), masks.to(DEVICE)
  122. with amp.autocast():
  123. preds = model(imgs)
  124. imgs.detach_()
  125. del imgs
  126. loss = loss_func(preds, masks)
  127. masks.detach_()
  128. del masks
  129. scaler.scale(loss).backward()
  130. if i % accumulation_steps == 0:
  131. scaler.step(optimizer)
  132. scaler.update()
  133. optimizer.zero_grad()
  134. m_loss = (m_loss * i + loss.detach()) / (i + 1)
  135. if (epoch + 1) % eval_interval == 0:
  136. recall, precision, eval_m_loss = eval_model(model, val_loader)
  137. f1 = 2 * recall * precision / (recall + precision)
  138. last_ckpt = {'epoch': epoch,
  139. 'best_f1': best_f1,
  140. 'weights': model.seg_net.state_dict(),
  141. 'best_val_loss': best_val_loss,
  142. 'optimizer': optimizer.state_dict(),
  143. 'scheduler': scheduler.state_dict(),
  144. 'run_id': logger.wandb.id if logger is not None else None,
  145. 'date': datetime.now().isoformat(),
  146. 'hyp': hyp}
  147. torch.save(last_ckpt, 'data/unet_last.ckpt')
  148. if best_f1 < f1:
  149. best_f1 = f1
  150. LOGGER.info(f'saving model at epoch {epoch}, best val f1: {best_f1}')
  151. shutil.copy2('data/unet_last.ckpt', 'data/unet_best.ckpt')
  152. LOGGER.info(f'epoch {epoch}/{epochs-1} loss: {m_loss} precision: {precision} recall: {recall}')
  153. if logger is not None:
  154. log_dict = {}
  155. log_dict['train/lr'] = optimizer.param_groups[0]['lr']
  156. log_dict['train/loss'] = m_loss
  157. log_dict['eval/recall'] = recall
  158. log_dict['eval/precision'] = precision
  159. log_dict['eval/f1'] = f1
  160. log_dict['eval/eval_m_loss'] = eval_m_loss
  161. logger.on_train_epoch_end(epoch, log_dict)
  162. scheduler.step()
  163. pbar.close()
  164. if __name__ == '__main__':
  165. hyp_p = r'data/train_hyp.yaml'
  166. with open(hyp_p, 'r', encoding='utf8') as f:
  167. hyp = yaml.safe_load(f.read())
  168. hyp['data']['train_img_dir'] = [r'../datasets/codat_manga_v3/images/train', r'../datasets/ComicErased/processed']
  169. # hyp['data']['train_img_dir'] = [r'../datasets/codat_manga_v3/images/val']
  170. hyp['data']['val_img_dir'] = [r'../datasets/codat_manga_v3/images/val']
  171. hyp['data']['train_mask_dir'] = r'../datasets/ComicSegV2'
  172. hyp['data']['val_mask_dir'] = r'../datasets/ComicSegV2'
  173. hyp['data']['imgsz'] = 1024
  174. hyp['data']['cache'] = False
  175. hyp['data']['aug_param']['neg'] = 0.3
  176. hyp['data']['aug_param']['size_range'] = [0.85, 1.1]
  177. hyp['train']['lr0'] = 0.004
  178. hyp['train']['lrf'] = 0.005
  179. hyp['train']['weight_decay'] = 0.00002
  180. hyp['train']['epochs'] = 120
  181. hyp['train']['accumulation_steps'] = 4
  182. hyp['train']['batch_size'] = 4
  183. hyp['logger']['type'] = 'wandb'
  184. # hyp['resume']['resume_training'] = True
  185. # hyp['resume']['ckpt'] = 'data/unet_last.ckpt'
  186. train(hyp)