# -*- coding: utf-8 -*-
# @Time : 2019/8/23 21:58
# @Author : zhoujun
import shutil
import time

import paddle
from tqdm import tqdm

from base import BaseTrainer
from utils import runningScore, cal_text_score, Polynomial, profiler
  9. class Trainer(BaseTrainer):
  10. def __init__(
  11. self,
  12. config,
  13. model,
  14. criterion,
  15. train_loader,
  16. validate_loader,
  17. metric_cls,
  18. post_process=None,
  19. profiler_options=None,
  20. ):
  21. super(Trainer, self).__init__(
  22. config,
  23. model,
  24. criterion,
  25. train_loader,
  26. validate_loader,
  27. metric_cls,
  28. post_process,
  29. )
  30. self.profiler_options = profiler_options
  31. self.enable_eval = config["trainer"].get("enable_eval", True)
  32. def _train_epoch(self, epoch):
  33. self.model.train()
  34. total_samples = 0
  35. train_reader_cost = 0.0
  36. train_batch_cost = 0.0
  37. reader_start = time.time()
  38. epoch_start = time.time()
  39. train_loss = 0.0
  40. running_metric_text = runningScore(2)
  41. for i, batch in enumerate(self.train_loader):
  42. profiler.add_profiler_step(self.profiler_options)
  43. if i >= self.train_loader_len:
  44. break
  45. self.global_step += 1
  46. lr = self.optimizer.get_lr()
  47. cur_batch_size = batch["img"].shape[0]
  48. train_reader_cost += time.time() - reader_start
  49. if self.amp:
  50. with paddle.amp.auto_cast(
  51. enable="gpu" in paddle.device.get_device(),
  52. custom_white_list=self.amp.get("custom_white_list", []),
  53. custom_black_list=self.amp.get("custom_black_list", []),
  54. level=self.amp.get("level", "O2"),
  55. ):
  56. preds = self.model(batch["img"])
  57. loss_dict = self.criterion(preds.astype(paddle.float32), batch)
  58. scaled_loss = self.amp["scaler"].scale(loss_dict["loss"])
  59. scaled_loss.backward()
  60. self.amp["scaler"].minimize(self.optimizer, scaled_loss)
  61. else:
  62. preds = self.model(batch["img"])
  63. loss_dict = self.criterion(preds, batch)
  64. # backward
  65. loss_dict["loss"].backward()
  66. self.optimizer.step()
  67. self.lr_scheduler.step()
  68. self.optimizer.clear_grad()
  69. train_batch_time = time.time() - reader_start
  70. train_batch_cost += train_batch_time
  71. total_samples += cur_batch_size
  72. # acc iou
  73. score_shrink_map = cal_text_score(
  74. preds[:, 0, :, :],
  75. batch["shrink_map"],
  76. batch["shrink_mask"],
  77. running_metric_text,
  78. thred=self.config["post_processing"]["args"]["thresh"],
  79. )
  80. # loss 和 acc 记录到日志
  81. loss_str = "loss: {:.4f}, ".format(loss_dict["loss"].item())
  82. for idx, (key, value) in enumerate(loss_dict.items()):
  83. loss_dict[key] = value.item()
  84. if key == "loss":
  85. continue
  86. loss_str += "{}: {:.4f}".format(key, loss_dict[key])
  87. if idx < len(loss_dict) - 1:
  88. loss_str += ", "
  89. train_loss += loss_dict["loss"]
  90. acc = score_shrink_map["Mean Acc"]
  91. iou_shrink_map = score_shrink_map["Mean IoU"]
  92. if self.global_step % self.log_iter == 0:
  93. self.logger_info(
  94. "[{}/{}], [{}/{}], global_step: {}, ips: {:.1f} samples/sec, avg_reader_cost: {:.5f} s, avg_batch_cost: {:.5f} s, avg_samples: {}, acc: {:.4f}, iou_shrink_map: {:.4f}, {}lr:{:.6}, time:{:.2f}".format(
  95. epoch,
  96. self.epochs,
  97. i + 1,
  98. self.train_loader_len,
  99. self.global_step,
  100. total_samples / train_batch_cost,
  101. train_reader_cost / self.log_iter,
  102. train_batch_cost / self.log_iter,
  103. total_samples / self.log_iter,
  104. acc,
  105. iou_shrink_map,
  106. loss_str,
  107. lr,
  108. train_batch_cost,
  109. )
  110. )
  111. total_samples = 0
  112. train_reader_cost = 0.0
  113. train_batch_cost = 0.0
  114. if self.visualdl_enable and paddle.distributed.get_rank() == 0:
  115. # write tensorboard
  116. for key, value in loss_dict.items():
  117. self.writer.add_scalar(
  118. "TRAIN/LOSS/{}".format(key), value, self.global_step
  119. )
  120. self.writer.add_scalar("TRAIN/ACC_IOU/acc", acc, self.global_step)
  121. self.writer.add_scalar(
  122. "TRAIN/ACC_IOU/iou_shrink_map", iou_shrink_map, self.global_step
  123. )
  124. self.writer.add_scalar("TRAIN/lr", lr, self.global_step)
  125. reader_start = time.time()
  126. return {
  127. "train_loss": train_loss / self.train_loader_len,
  128. "lr": lr,
  129. "time": time.time() - epoch_start,
  130. "epoch": epoch,
  131. }
  132. def _eval(self, epoch):
  133. self.model.eval()
  134. raw_metrics = []
  135. total_frame = 0.0
  136. total_time = 0.0
  137. for i, batch in tqdm(
  138. enumerate(self.validate_loader),
  139. total=len(self.validate_loader),
  140. desc="test model",
  141. ):
  142. with paddle.no_grad():
  143. start = time.time()
  144. if self.amp:
  145. with paddle.amp.auto_cast(
  146. enable="gpu" in paddle.device.get_device(),
  147. custom_white_list=self.amp.get("custom_white_list", []),
  148. custom_black_list=self.amp.get("custom_black_list", []),
  149. level=self.amp.get("level", "O2"),
  150. ):
  151. preds = self.model(batch["img"])
  152. preds = preds.astype(paddle.float32)
  153. else:
  154. preds = self.model(batch["img"])
  155. boxes, scores = self.post_process(
  156. batch, preds, is_output_polygon=self.metric_cls.is_output_polygon
  157. )
  158. total_frame += batch["img"].shape[0]
  159. total_time += time.time() - start
  160. raw_metric = self.metric_cls.validate_measure(batch, (boxes, scores))
  161. raw_metrics.append(raw_metric)
  162. metrics = self.metric_cls.gather_measure(raw_metrics)
  163. self.logger_info("FPS:{}".format(total_frame / total_time))
  164. return metrics["recall"].avg, metrics["precision"].avg, metrics["fmeasure"].avg
  165. def _on_epoch_finish(self):
  166. self.logger_info(
  167. "[{}/{}], train_loss: {:.4f}, time: {:.4f}, lr: {}".format(
  168. self.epoch_result["epoch"],
  169. self.epochs,
  170. self.epoch_result["train_loss"],
  171. self.epoch_result["time"],
  172. self.epoch_result["lr"],
  173. )
  174. )
  175. net_save_path = "{}/model_latest.pth".format(self.checkpoint_dir)
  176. net_save_path_best = "{}/model_best.pth".format(self.checkpoint_dir)
  177. if paddle.distributed.get_rank() == 0:
  178. self._save_checkpoint(self.epoch_result["epoch"], net_save_path)
  179. save_best = False
  180. if (
  181. self.validate_loader is not None
  182. and self.metric_cls is not None
  183. and self.enable_eval
  184. ): # 使用f1作为最优模型指标
  185. recall, precision, hmean = self._eval(self.epoch_result["epoch"])
  186. if self.visualdl_enable:
  187. self.writer.add_scalar("EVAL/recall", recall, self.global_step)
  188. self.writer.add_scalar(
  189. "EVAL/precision", precision, self.global_step
  190. )
  191. self.writer.add_scalar("EVAL/hmean", hmean, self.global_step)
  192. self.logger_info(
  193. "test: recall: {:.6f}, precision: {:.6f}, hmean: {:.6f}".format(
  194. recall, precision, hmean
  195. )
  196. )
  197. if hmean >= self.metrics["hmean"]:
  198. save_best = True
  199. self.metrics["train_loss"] = self.epoch_result["train_loss"]
  200. self.metrics["hmean"] = hmean
  201. self.metrics["precision"] = precision
  202. self.metrics["recall"] = recall
  203. self.metrics["best_model_epoch"] = self.epoch_result["epoch"]
  204. else:
  205. if self.epoch_result["train_loss"] <= self.metrics["train_loss"]:
  206. save_best = True
  207. self.metrics["train_loss"] = self.epoch_result["train_loss"]
  208. self.metrics["best_model_epoch"] = self.epoch_result["epoch"]
  209. best_str = "current best, "
  210. for k, v in self.metrics.items():
  211. best_str += "{}: {:.6f}, ".format(k, v)
  212. self.logger_info(best_str)
  213. if save_best:
  214. import shutil
  215. shutil.copy(net_save_path, net_save_path_best)
  216. self.logger_info("Saving current best: {}".format(net_save_path_best))
  217. else:
  218. self.logger_info("Saving checkpoint: {}".format(net_save_path))
  219. def _on_train_finish(self):
  220. if self.enable_eval:
  221. for k, v in self.metrics.items():
  222. self.logger_info("{}:{}".format(k, v))
  223. self.logger_info("finish train")
  224. def _initialize_scheduler(self):
  225. if self.config["lr_scheduler"]["type"] == "Polynomial":
  226. self.config["lr_scheduler"]["args"]["epochs"] = self.config["trainer"][
  227. "epochs"
  228. ]
  229. self.config["lr_scheduler"]["args"]["step_each_epoch"] = len(
  230. self.train_loader
  231. )
  232. self.lr_scheduler = Polynomial(**self.config["lr_scheduler"]["args"])()
  233. else:
  234. self.lr_scheduler = self._initialize("lr_scheduler", paddle.optimizer.lr)