base_trainer.py 9.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269
  1. # -*- coding: utf-8 -*-
  2. # @Time : 2019/8/23 21:50
  3. # @Author : zhoujun
  4. import os
  5. import pathlib
  6. import shutil
  7. from pprint import pformat
  8. import anyconfig
  9. import paddle
  10. import numpy as np
  11. import random
  12. from paddle.jit import to_static
  13. from paddle.static import InputSpec
  14. from utils import setup_logger
  15. class BaseTrainer:
  16. def __init__(
  17. self,
  18. config,
  19. model,
  20. criterion,
  21. train_loader,
  22. validate_loader,
  23. metric_cls,
  24. post_process=None,
  25. ):
  26. config["trainer"]["output_dir"] = os.path.join(
  27. str(pathlib.Path(os.path.abspath(__name__)).parent),
  28. config["trainer"]["output_dir"],
  29. )
  30. config["name"] = config["name"] + "_" + model.name
  31. self.save_dir = config["trainer"]["output_dir"]
  32. self.checkpoint_dir = os.path.join(self.save_dir, "checkpoint")
  33. os.makedirs(self.checkpoint_dir, exist_ok=True)
  34. self.global_step = 0
  35. self.start_epoch = 0
  36. self.config = config
  37. self.criterion = criterion
  38. # logger and tensorboard
  39. self.visualdl_enable = self.config["trainer"].get("visual_dl", False)
  40. self.epochs = self.config["trainer"]["epochs"]
  41. self.log_iter = self.config["trainer"]["log_iter"]
  42. if paddle.distributed.get_rank() == 0:
  43. anyconfig.dump(config, os.path.join(self.save_dir, "config.yaml"))
  44. self.logger = setup_logger(os.path.join(self.save_dir, "train.log"))
  45. self.logger_info(pformat(self.config))
  46. self.model = self.apply_to_static(model)
  47. # device
  48. if (
  49. paddle.device.cuda.device_count() > 0
  50. and paddle.device.is_compiled_with_cuda()
  51. ):
  52. self.with_cuda = True
  53. random.seed(self.config["trainer"]["seed"])
  54. np.random.seed(self.config["trainer"]["seed"])
  55. paddle.seed(self.config["trainer"]["seed"])
  56. else:
  57. self.with_cuda = False
  58. self.logger_info("train with and paddle {}".format(paddle.__version__))
  59. # metrics
  60. self.metrics = {
  61. "recall": 0,
  62. "precision": 0,
  63. "hmean": 0,
  64. "train_loss": float("inf"),
  65. "best_model_epoch": 0,
  66. }
  67. self.train_loader = train_loader
  68. if validate_loader is not None:
  69. assert post_process is not None and metric_cls is not None
  70. self.validate_loader = validate_loader
  71. self.post_process = post_process
  72. self.metric_cls = metric_cls
  73. self.train_loader_len = len(train_loader)
  74. if self.validate_loader is not None:
  75. self.logger_info(
  76. "train dataset has {} samples,{} in dataloader, validate dataset has {} samples,{} in dataloader".format(
  77. len(self.train_loader.dataset),
  78. self.train_loader_len,
  79. len(self.validate_loader.dataset),
  80. len(self.validate_loader),
  81. )
  82. )
  83. else:
  84. self.logger_info(
  85. "train dataset has {} samples,{} in dataloader".format(
  86. len(self.train_loader.dataset), self.train_loader_len
  87. )
  88. )
  89. self._initialize_scheduler()
  90. self._initialize_optimizer()
  91. # resume or finetune
  92. if self.config["trainer"]["resume_checkpoint"] != "":
  93. self._load_checkpoint(
  94. self.config["trainer"]["resume_checkpoint"], resume=True
  95. )
  96. elif self.config["trainer"]["finetune_checkpoint"] != "":
  97. self._load_checkpoint(
  98. self.config["trainer"]["finetune_checkpoint"], resume=False
  99. )
  100. if self.visualdl_enable and paddle.distributed.get_rank() == 0:
  101. from visualdl import LogWriter
  102. self.writer = LogWriter(self.save_dir)
  103. # 混合精度训练
  104. self.amp = self.config.get("amp", None)
  105. if self.amp == "None":
  106. self.amp = None
  107. if self.amp:
  108. self.amp["scaler"] = paddle.amp.GradScaler(
  109. init_loss_scaling=self.amp.get("scale_loss", 1024),
  110. use_dynamic_loss_scaling=self.amp.get("use_dynamic_loss_scaling", True),
  111. )
  112. self.model, self.optimizer = paddle.amp.decorate(
  113. models=self.model,
  114. optimizers=self.optimizer,
  115. level=self.amp.get("amp_level", "O2"),
  116. )
  117. # 分布式训练
  118. if paddle.device.cuda.device_count() > 1:
  119. self.model = paddle.DataParallel(self.model)
  120. # make inverse Normalize
  121. self.UN_Normalize = False
  122. for t in self.config["dataset"]["train"]["dataset"]["args"]["transforms"]:
  123. if t["type"] == "Normalize":
  124. self.normalize_mean = t["args"]["mean"]
  125. self.normalize_std = t["args"]["std"]
  126. self.UN_Normalize = True
  127. def apply_to_static(self, model):
  128. support_to_static = self.config["trainer"].get("to_static", False)
  129. if support_to_static:
  130. specs = None
  131. print("static")
  132. specs = [InputSpec([None, 3, -1, -1])]
  133. model = to_static(model, input_spec=specs)
  134. self.logger_info(
  135. "Successfully to apply @to_static with specs: {}".format(specs)
  136. )
  137. return model
  138. def train(self):
  139. """
  140. Full training logic
  141. """
  142. for epoch in range(self.start_epoch + 1, self.epochs + 1):
  143. self.epoch_result = self._train_epoch(epoch)
  144. self._on_epoch_finish()
  145. if paddle.distributed.get_rank() == 0 and self.visualdl_enable:
  146. self.writer.close()
  147. self._on_train_finish()
  148. def _train_epoch(self, epoch):
  149. """
  150. Training logic for an epoch
  151. :param epoch: Current epoch number
  152. """
  153. raise NotImplementedError
  154. def _eval(self, epoch):
  155. """
  156. eval logic for an epoch
  157. :param epoch: Current epoch number
  158. """
  159. raise NotImplementedError
  160. def _on_epoch_finish(self):
  161. raise NotImplementedError
  162. def _on_train_finish(self):
  163. raise NotImplementedError
  164. def _save_checkpoint(self, epoch, file_name):
  165. """
  166. Saving checkpoints
  167. :param epoch: current epoch number
  168. :param log: logging information of the epoch
  169. :param save_best: if True, rename the saved checkpoint to 'model_best.pth.tar'
  170. """
  171. state_dict = self.model.state_dict()
  172. state = {
  173. "epoch": epoch,
  174. "global_step": self.global_step,
  175. "state_dict": state_dict,
  176. "optimizer": self.optimizer.state_dict(),
  177. "config": self.config,
  178. "metrics": self.metrics,
  179. }
  180. filename = os.path.join(self.checkpoint_dir, file_name)
  181. paddle.save(state, filename)
  182. def _load_checkpoint(self, checkpoint_path, resume):
  183. """
  184. Resume from saved checkpoints
  185. :param checkpoint_path: Checkpoint path to be resumed
  186. """
  187. self.logger_info("Loading checkpoint: {} ...".format(checkpoint_path))
  188. checkpoint = paddle.load(checkpoint_path)
  189. self.model.set_state_dict(checkpoint["state_dict"])
  190. if resume:
  191. self.global_step = checkpoint["global_step"]
  192. self.start_epoch = checkpoint["epoch"]
  193. self.config["lr_scheduler"]["args"]["last_epoch"] = self.start_epoch
  194. # self.scheduler.load_state_dict(checkpoint['scheduler'])
  195. self.optimizer.set_state_dict(checkpoint["optimizer"])
  196. if "metrics" in checkpoint:
  197. self.metrics = checkpoint["metrics"]
  198. self.logger_info(
  199. "resume from checkpoint {} (epoch {})".format(
  200. checkpoint_path, self.start_epoch
  201. )
  202. )
  203. else:
  204. self.logger_info("finetune from checkpoint {}".format(checkpoint_path))
  205. def _initialize(self, name, module, *args, **kwargs):
  206. module_name = self.config[name]["type"]
  207. module_args = self.config[name].get("args", {})
  208. assert all(
  209. [k not in module_args for k in kwargs]
  210. ), "Overwriting kwargs given in config file is not allowed"
  211. module_args.update(kwargs)
  212. return getattr(module, module_name)(*args, **module_args)
  213. def _initialize_scheduler(self):
  214. self.lr_scheduler = self._initialize("lr_scheduler", paddle.optimizer.lr)
  215. def _initialize_optimizer(self):
  216. self.optimizer = self._initialize(
  217. "optimizer",
  218. paddle.optimizer,
  219. parameters=self.model.parameters(),
  220. learning_rate=self.lr_scheduler,
  221. )
  222. def inverse_normalize(self, batch_img):
  223. if self.UN_Normalize:
  224. batch_img[:, 0, :, :] = (
  225. batch_img[:, 0, :, :] * self.normalize_std[0] + self.normalize_mean[0]
  226. )
  227. batch_img[:, 1, :, :] = (
  228. batch_img[:, 1, :, :] * self.normalize_std[1] + self.normalize_mean[1]
  229. )
  230. batch_img[:, 2, :, :] = (
  231. batch_img[:, 2, :, :] * self.normalize_std[2] + self.normalize_mean[2]
  232. )
  233. def logger_info(self, s):
  234. if paddle.distributed.get_rank() == 0:
  235. self.logger.info(s)