# train_db.py
  1. from torch.autograd.grad_mode import F
  2. from torch.nn.functional import sigmoid
  3. from torch.nn.modules.loss import CrossEntropyLoss
  4. from torch.optim import SGD, Adam, lr_scheduler
  5. from tqdm import tqdm
  6. import math
  7. from torch.cuda import amp
  8. import torch
  9. from utils.loss import DBLoss
  10. import torch.nn as nn
  11. import yaml
  12. from basemodel import TextDetector
  13. from utils.db_utils import SegDetectorRepresenter, QuadMetric
  14. import numpy as np
  15. from datetime import datetime
  16. from torchsummary import summary
  17. import numexpr
  18. import os
  19. import shutil
  20. os.environ['NUMEXPR_MAX_THREADS'] = str(numexpr.detect_number_of_cores())
  21. from db_dataset import create_dataloader
  22. from utils.general import LOGGER, Loggers, CUDA, DEVICE
  23. import time
  24. import random
  25. torch.random.manual_seed(0)
  26. random.seed(0)
  27. np.random.seed(0)
  28. def one_cycle(y1=0.0, y2=1.0, steps=100):
  29. return lambda x: ((1 - math.cos(x * math.pi / steps)) / 2) * (y2 - y1) + y1
  30. def eval_model(model: nn.Module, val_loader, post_process, metric_cls):
  31. # global DEVICE
  32. raw_metrics = []
  33. total_frame = 0.0
  34. total_time = 0.0
  35. model.eval()
  36. for i, batch in tqdm(enumerate(val_loader), total=len(val_loader), desc='test model'):
  37. with torch.no_grad():
  38. # 数据进行转换和丢到gpu
  39. for key, value in batch.items():
  40. if value is not None:
  41. if isinstance(value, torch.Tensor):
  42. batch[key] = value.to(DEVICE)
  43. start = time.time()
  44. with amp.autocast():
  45. preds = model(batch['imgs'])
  46. boxes, scores = post_process(batch, preds,is_output_polygon=False)
  47. total_frame += batch['imgs'].size()[0]
  48. total_time += time.time() - start
  49. raw_metric = metric_cls.validate_measure(batch, (boxes, scores))
  50. raw_metrics.append(raw_metric)
  51. metrics = metric_cls.gather_measure(raw_metrics)
  52. LOGGER.info('FPS:{}'.format(total_frame / total_time))
  53. return metrics['recall'].avg, metrics['precision'].avg, metrics['fmeasure'].avg
  54. def train(hyp):
  55. start_epoch = 0
  56. hyp_train, hyp_data, hyp_model, hyp_logger, hyp_resume = hyp['train'], hyp['data'], hyp['model'], hyp['logger'], hyp['resume']
  57. epochs = hyp_train['epochs']
  58. batch_size = hyp_train['batch_size']
  59. scaler = amp.GradScaler(enabled=CUDA)
  60. criterion = DBLoss()
  61. use_bce = False
  62. if hyp_train['loss'] == 'bce':
  63. use_bce = True
  64. shrink_with_sigmoid = not use_bce
  65. model = TextDetector(hyp_model['weights'], map_location='cpu', act=hyp_model['act'])
  66. model.initialize_db(hyp_model['unet_weights'])
  67. model.dbnet.shrink_with_sigmoid = shrink_with_sigmoid
  68. model.train_db()
  69. model.to(DEVICE)
  70. if hyp_model['db_weights'] != '':
  71. model.dbnet.load_state_dict(torch.load(hyp_model['db_weights'])['weights'])
  72. if hyp_train['optimizer'] == 'adam':
  73. optimizer = Adam(model.dbnet.parameters(), lr=hyp_train['lr0'], betas=(0.937, 0.999), weight_decay=0.00002) # adjust beta1 to momentum
  74. else:
  75. optimizer = SGD(model.dbnet.parameters(), lr=hyp_train['lr0'], momentum=hyp_train['momentum'], nesterov=True, weight_decay=hyp_train['weight_decay'])
  76. if hyp_train['linear_lr']:
  77. lf = lambda x: (1 - x / (epochs - 1)) * (1.0 - hyp_train['lrf']) + hyp_train['lrf'] # linear
  78. else:
  79. lf = one_cycle(1, hyp_train['lrf'], epochs) # cosine 1->hyp['lrf']
  80. if hyp_train['linear_lr']:
  81. lf = lambda x: (1 - x / (epochs - 1)) * (1.0 - hyp_train['lrf']) + hyp_train['lrf'] # linear
  82. else:
  83. lf = one_cycle(1, hyp_train['lrf'], epochs) # cosine 1->hyp['lrf']
  84. scheduler = lr_scheduler.LambdaLR(optimizer, lr_lambda=lf) # plot_lr_scheduler(optimizer, scheduler, epochs)
  85. logger = None
  86. if hyp_resume['resume_training']:
  87. LOGGER.info(f'resume traning ... ')
  88. ckpt = torch.load(hyp_resume['ckpt'], map_location=DEVICE)
  89. model.dbnet.load_state_dict(ckpt['weights'])
  90. optimizer.load_state_dict(ckpt['optimizer'])
  91. scheduler.load_state_dict(ckpt['scheduler'])
  92. scheduler.step()
  93. start_epoch = ckpt['epoch'] + 1
  94. hyp_logger['run_id'] = ckpt['run_id']
  95. logger = Loggers(hyp)
  96. else:
  97. # if hyp_logger['type'] == 'wandb':
  98. logger = Loggers(hyp)
  99. train_img_dir, train_mask_dir, imgsz, augment, aug_param = hyp_data['train_img_dir'], hyp_data['train_mask_dir'], hyp_data['imgsz'], hyp_data['augment'], hyp_data['aug_param']
  100. val_img_dir, val_mask_dir = hyp_data['val_img_dir'], hyp_data['val_mask_dir']
  101. train_dataset, train_loader = create_dataloader(train_img_dir, train_mask_dir, imgsz, batch_size, augment, aug_param, shuffle=True, workers=hyp_data['num_workers'], cache=hyp_data['cache'])
  102. val_dataset, val_loader = create_dataloader(val_img_dir, val_mask_dir, imgsz, batch_size, augment=False, shuffle=False, workers=hyp_data['num_workers'], cache=hyp_data['cache'], with_ann=True)
  103. nb = len(train_loader)
  104. nw = max(round(3 * nb), 700)
  105. LOGGER.info(f'num training imgs: {len(train_dataset)}, num val imgs: {len(val_dataset)}')
  106. eval_interval = hyp_train['eval_interval']
  107. best_f1 = best_epoch = -1
  108. best_val_loss = np.inf
  109. accumulation_steps = hyp_train['accumulation_steps']
  110. summary(model, (3, 640, 640), device=DEVICE)
  111. metric_cls = QuadMetric()
  112. post_process = SegDetectorRepresenter(thresh=0.5)
  113. best_f1 = -1
  114. for epoch in range(start_epoch, epochs): # epoch ------------------------------------------------------------------
  115. model.train_db()
  116. pbar = enumerate(train_loader)
  117. pbar = tqdm(pbar, total=nb, bar_format='{l_bar}{bar:10}{r_bar}{bar:-10b}') # progress bar
  118. m_loss = 0
  119. m_loss_s = 0
  120. m_loss_t = 0
  121. m_loss_b = 0
  122. for i, batchs in pbar:
  123. if (i+2) % 256 == 0:
  124. train_dataset.initialize()
  125. pbar.set_description(f' training size: {train_dataset.img_size}')
  126. # warm up
  127. if hyp_train['warm_up']:
  128. ni = i + nb * epoch
  129. if ni <= nw:
  130. xi = [0, nw] # x interp
  131. for j, x in enumerate(optimizer.param_groups):
  132. x['lr'] = np.interp(ni, xi, [hyp_train['warmup_bias_lr'] if j == 2 else 0.0, x['initial_lr'] * lf(epoch)])
  133. if 'momentum' in x:
  134. x['momentum'] = np.interp(ni, xi, [hyp_train['warmup_momentum'], hyp_train['momentum']])
  135. with amp.autocast():
  136. for key in batchs.keys():
  137. batchs[key] = batchs[key].cuda()
  138. preds = model(batchs['imgs'])
  139. metric = criterion(preds, batchs, use_bce)
  140. loss = metric['loss'] / accumulation_steps
  141. scaler.scale(loss).backward()
  142. if (i+1) % accumulation_steps == 0:
  143. scaler.step(optimizer)
  144. scaler.update()
  145. optimizer.zero_grad()
  146. m_loss = (m_loss * i + metric['loss'].detach()) / (i + 1)
  147. m_loss_s = (m_loss_s * i + metric['loss_shrink_maps'].detach()) / (i + 1)
  148. m_loss_t = (m_loss_t * i + metric['loss_threshold_maps'].detach()) / (i + 1)
  149. m_loss_b = (m_loss_b * i + metric['loss_binary_maps'].detach()) / (i + 1)
  150. if i % eval_interval == 0:
  151. recall, precision, fmeasure = eval_model(model, val_loader, post_process, metric_cls)
  152. log_dict = {}
  153. log_dict['train/lr'] = optimizer.param_groups[0]['lr']
  154. log_dict['train/loss'] = m_loss
  155. log_dict['train/loss_shrink'] = m_loss_s
  156. log_dict['train/loss_threshold'] = m_loss_t
  157. log_dict['train/loss_binary_maps'] = m_loss_b
  158. log_dict['eval/recall'] = recall
  159. log_dict['eval/precision'] = precision
  160. log_dict['eval/f1'] = fmeasure
  161. save_best = best_f1 < fmeasure
  162. if save_best:
  163. best_f1 = fmeasure
  164. last_ckpt = {'epoch': epoch,
  165. 'best_f1': best_f1,
  166. 'weights': model.dbnet.state_dict(),
  167. 'best_val_loss': best_val_loss,
  168. 'optimizer': optimizer.state_dict(),
  169. 'scheduler': scheduler.state_dict(),
  170. 'run_id': logger.wandb.id if logger.wandb is not None else None,
  171. 'date': datetime.now().isoformat(),
  172. 'hyp': hyp}
  173. torch.save(last_ckpt, 'data/db_last.ckpt')
  174. if save_best:
  175. shutil.copy('data/db_last.ckpt', 'data/db_best.ckpt')
  176. if logger is not None:
  177. logger.on_train_epoch_end(epoch, log_dict)
  178. scheduler.step()
  179. pbar.close()
  180. if __name__ == '__main__':
  181. hyp_p = r'data/train_db_hyp.yaml'
  182. with open(hyp_p, 'r', encoding='utf8') as f:
  183. hyp = yaml.safe_load(f.read())
  184. # hyp['data']['train_img_dir'] = r'../datasets/pixanimegirls/processed'
  185. hyp['data']['train_img_dir'] = [r'../datasets/codat_manga_v3/images/train', r'../datasets/codat_manga_v3/images/val', r'../datasets/pixanimegirls/processed']
  186. hyp['data']['train_mask_dir'] = r'../datasets/TextLines'
  187. # hyp['data']['train_img_dir'] = r'data/dataset/db_sub'
  188. hyp['data']['val_img_dir'] = r'data/dataset/db_sub'
  189. hyp['data']['cache'] = False
  190. # hyp['data']['aug_param']['size_range'] = [-1]
  191. hyp['train']['lr0'] = 0.01
  192. hyp['train']['lrf'] = 0.002
  193. hyp['train']['weight_decay'] = 0.00002
  194. hyp['train']['batch_size'] = 4
  195. hyp['train']['epochs'] = 160
  196. # hyp['train']['optimizer'] = 'sgd'
  197. hyp['train']['loss'] = 'bce'
  198. hyp['logger']['type'] = 'wandb'
  199. # hyp['resume']['resume_training'] = True
  200. # hyp['resume']['ckpt'] = 'data/db_last_bk.ckpt'
  201. # hyp['model']['db_weights'] = r'data/db_last.ckpt'
  202. train(hyp)