| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269 |
- # -*- coding: utf-8 -*-
- # @Time : 2019/8/23 21:50
- # @Author : zhoujun
- import os
- import pathlib
- import shutil
- from pprint import pformat
- import anyconfig
- import paddle
- import numpy as np
- import random
- from paddle.jit import to_static
- from paddle.static import InputSpec
- from utils import setup_logger
class BaseTrainer:
    """Common scaffolding for training a Paddle model.

    The constructor prepares the output directory, logger, model (optionally
    converted with ``to_static``), LR scheduler, optimizer, optional
    checkpoint restore, optional VisualDL writer, optional AMP decoration and
    optional ``DataParallel`` wrapping.

    Subclasses must implement ``_train_epoch``, ``_eval``,
    ``_on_epoch_finish`` and ``_on_train_finish``.
    """

    def __init__(
        self,
        config,
        model,
        criterion,
        train_loader,
        validate_loader,
        metric_cls,
        post_process=None,
    ):
        """Set up everything needed before ``train()`` is called.

        :param config: nested configuration mapping; the keys read here are
            ``name``, ``trainer``, ``lr_scheduler``, ``optimizer``,
            ``dataset`` and (optionally) ``amp``
        :param model: model to train; must expose ``name`` and ``parameters()``
        :param criterion: loss object, stored as ``self.criterion``
        :param train_loader: training data loader (``len()`` and ``.dataset``
            are used)
        :param validate_loader: validation data loader, or ``None``
        :param metric_cls: metric object; required when ``validate_loader``
            is given
        :param post_process: post-processing callable; required when
            ``validate_loader`` is given
        """
        # NOTE(review): ``__name__`` is the module name, not a file path, so
        # ``abspath(__name__)`` is "<cwd>/<module name>" and its parent is the
        # current working directory. A relative ``output_dir`` is therefore
        # resolved against the cwd. Kept as-is to preserve existing layouts.
        config["trainer"]["output_dir"] = os.path.join(
            str(pathlib.Path(os.path.abspath(__name__)).parent),
            config["trainer"]["output_dir"],
        )
        config["name"] = config["name"] + "_" + model.name
        self.save_dir = config["trainer"]["output_dir"]
        self.checkpoint_dir = os.path.join(self.save_dir, "checkpoint")
        os.makedirs(self.checkpoint_dir, exist_ok=True)

        self.global_step = 0
        self.start_epoch = 0
        self.config = config
        self.criterion = criterion

        # logger and VisualDL
        self.visualdl_enable = self.config["trainer"].get("visual_dl", False)
        self.epochs = self.config["trainer"]["epochs"]
        self.log_iter = self.config["trainer"]["log_iter"]
        # Only rank 0 dumps the config and owns a logger; ``logger_info``
        # applies the same rank check before touching ``self.logger``.
        if paddle.distributed.get_rank() == 0:
            anyconfig.dump(config, os.path.join(self.save_dir, "config.yaml"))
            self.logger = setup_logger(os.path.join(self.save_dir, "train.log"))
            self.logger_info(pformat(self.config))

        self.model = self.apply_to_static(model)

        # device
        if (
            paddle.device.cuda.device_count() > 0
            and paddle.device.is_compiled_with_cuda()
        ):
            self.with_cuda = True
            # NOTE(review): the RNGs are seeded only on this CUDA branch, so
            # CPU-only runs are left unseeded — confirm whether intended.
            random.seed(self.config["trainer"]["seed"])
            np.random.seed(self.config["trainer"]["seed"])
            paddle.seed(self.config["trainer"]["seed"])
        else:
            self.with_cuda = False
        self.logger_info("train with and paddle {}".format(paddle.__version__))

        # metrics
        self.metrics = {
            "recall": 0,
            "precision": 0,
            "hmean": 0,
            "train_loss": float("inf"),
            "best_model_epoch": 0,
        }

        self.train_loader = train_loader
        if validate_loader is not None:
            assert post_process is not None and metric_cls is not None
        self.validate_loader = validate_loader
        self.post_process = post_process
        self.metric_cls = metric_cls
        self.train_loader_len = len(train_loader)
        if self.validate_loader is not None:
            self.logger_info(
                "train dataset has {} samples,{} in dataloader, validate dataset has {} samples,{} in dataloader".format(
                    len(self.train_loader.dataset),
                    self.train_loader_len,
                    len(self.validate_loader.dataset),
                    len(self.validate_loader),
                )
            )
        else:
            self.logger_info(
                "train dataset has {} samples,{} in dataloader".format(
                    len(self.train_loader.dataset), self.train_loader_len
                )
            )

        # The optimizer receives the scheduler as its learning rate, so the
        # scheduler has to exist first.
        self._initialize_scheduler()
        self._initialize_optimizer()

        # resume or finetune
        if self.config["trainer"]["resume_checkpoint"] != "":
            self._load_checkpoint(
                self.config["trainer"]["resume_checkpoint"], resume=True
            )
        elif self.config["trainer"]["finetune_checkpoint"] != "":
            self._load_checkpoint(
                self.config["trainer"]["finetune_checkpoint"], resume=False
            )

        if self.visualdl_enable and paddle.distributed.get_rank() == 0:
            from visualdl import LogWriter

            self.writer = LogWriter(self.save_dir)

        # mixed-precision training
        amp_config = self.config.get("amp", None)
        if amp_config == "None":
            amp_config = None
        # Work on a shallow copy: the GradScaler stored under "scaler" is a
        # runtime object and must not leak into ``self.config``, which is
        # written into every checkpoint by ``_save_checkpoint``.
        self.amp = dict(amp_config) if amp_config else amp_config
        if self.amp:
            self.amp["scaler"] = paddle.amp.GradScaler(
                init_loss_scaling=self.amp.get("scale_loss", 1024),
                use_dynamic_loss_scaling=self.amp.get("use_dynamic_loss_scaling", True),
            )
            self.model, self.optimizer = paddle.amp.decorate(
                models=self.model,
                optimizers=self.optimizer,
                level=self.amp.get("amp_level", "O2"),
            )

        # distributed training
        if paddle.device.cuda.device_count() > 1:
            self.model = paddle.DataParallel(self.model)

        # Remember the Normalize statistics so ``inverse_normalize`` can undo
        # the transform.
        self.UN_Normalize = False
        for t in self.config["dataset"]["train"]["dataset"]["args"]["transforms"]:
            if t["type"] == "Normalize":
                self.normalize_mean = t["args"]["mean"]
                self.normalize_std = t["args"]["std"]
                self.UN_Normalize = True

    def apply_to_static(self, model):
        """Optionally convert ``model`` with ``paddle.jit.to_static``.

        The conversion happens only when ``config["trainer"]["to_static"]`` is
        truthy; otherwise ``model`` is returned untouched.

        :param model: the model to (maybe) convert
        :return: the converted model, or ``model`` itself
        """
        support_to_static = self.config["trainer"].get("to_static", False)
        if support_to_static:
            specs = [InputSpec([None, 3, -1, -1])]
            model = to_static(model, input_spec=specs)
            self.logger_info(
                "Successfully to apply @to_static with specs: {}".format(specs)
            )
        return model

    def train(self):
        """
        Full training logic: run every remaining epoch, then finish up.
        """
        for epoch in range(self.start_epoch + 1, self.epochs + 1):
            self.epoch_result = self._train_epoch(epoch)
            self._on_epoch_finish()
        # ``self.writer`` exists only on rank 0 with VisualDL enabled.
        if paddle.distributed.get_rank() == 0 and self.visualdl_enable:
            self.writer.close()
        self._on_train_finish()

    def _train_epoch(self, epoch):
        """
        Training logic for an epoch; must be provided by subclasses.

        :param epoch: Current epoch number
        """
        raise NotImplementedError

    def _eval(self, epoch):
        """
        Eval logic for an epoch; must be provided by subclasses.

        :param epoch: Current epoch number
        """
        raise NotImplementedError

    def _on_epoch_finish(self):
        """Hook called by ``train`` after each epoch; must be overridden."""
        raise NotImplementedError

    def _on_train_finish(self):
        """Hook called by ``train`` after the last epoch; must be overridden."""
        raise NotImplementedError

    def _save_checkpoint(self, epoch, file_name):
        """
        Save a checkpoint into ``self.checkpoint_dir``.

        The checkpoint holds the epoch, global step, model and optimizer
        state dicts, the config and the current metrics.

        :param epoch: current epoch number
        :param file_name: file name of the checkpoint inside the checkpoint
            directory
        """
        state_dict = self.model.state_dict()
        state = {
            "epoch": epoch,
            "global_step": self.global_step,
            "state_dict": state_dict,
            "optimizer": self.optimizer.state_dict(),
            "config": self.config,
            "metrics": self.metrics,
        }
        filename = os.path.join(self.checkpoint_dir, file_name)
        paddle.save(state, filename)

    def _load_checkpoint(self, checkpoint_path, resume):
        """
        Load model weights from a checkpoint, optionally resuming training.

        :param checkpoint_path: path of the checkpoint to load
        :param resume: when true, also restore the global step, start epoch,
            optimizer state and (if present) metrics; when false only the
            model weights are loaded (finetune)
        """
        self.logger_info("Loading checkpoint: {} ...".format(checkpoint_path))
        checkpoint = paddle.load(checkpoint_path)
        self.model.set_state_dict(checkpoint["state_dict"])
        if resume:
            self.global_step = checkpoint["global_step"]
            self.start_epoch = checkpoint["epoch"]
            # NOTE(review): when called from ``__init__`` the scheduler has
            # already been built, so this only updates the config entry.
            self.config["lr_scheduler"]["args"]["last_epoch"] = self.start_epoch
            # self.scheduler.load_state_dict(checkpoint['scheduler'])
            self.optimizer.set_state_dict(checkpoint["optimizer"])
            if "metrics" in checkpoint:
                self.metrics = checkpoint["metrics"]
            self.logger_info(
                "resume from checkpoint {} (epoch {})".format(
                    checkpoint_path, self.start_epoch
                )
            )
        else:
            self.logger_info("finetune from checkpoint {}".format(checkpoint_path))

    def _initialize(self, name, module, *args, **kwargs):
        """Instantiate ``module.<type>`` as described by ``config[name]``.

        ``config[name]["type"]`` names the attribute to look up on ``module``
        and ``config[name]["args"]`` supplies its keyword arguments. Extra
        ``kwargs`` are merged in but may not collide with configured keys.

        The configured ``args`` mapping is never modified, so runtime objects
        passed through ``kwargs`` do not end up in the config (and therefore
        not in saved checkpoints), and the method can be called repeatedly.

        :param name: top-level config section to read
        :param module: object on which the configured type is looked up
        :param args: positional arguments forwarded to the constructor
        :param kwargs: additional keyword arguments forwarded to the constructor
        :return: whatever calling ``module.<type>`` returns
        :raises AssertionError: if a key in ``kwargs`` is already configured
        """
        module_name = self.config[name]["type"]
        # A missing or ``None`` "args" entry means "no configured arguments".
        configured_args = self.config[name].get("args") or {}
        assert all(
            k not in configured_args for k in kwargs
        ), "Overwriting kwargs given in config file is not allowed"
        # Merge into a fresh dict instead of updating the config in place.
        module_args = dict(configured_args)
        module_args.update(kwargs)
        return getattr(module, module_name)(*args, **module_args)

    def _initialize_scheduler(self):
        """Build ``self.lr_scheduler`` from the ``lr_scheduler`` config section."""
        self.lr_scheduler = self._initialize("lr_scheduler", paddle.optimizer.lr)

    def _initialize_optimizer(self):
        """Build ``self.optimizer`` from the ``optimizer`` config section.

        The optimizer is given the model parameters and uses
        ``self.lr_scheduler`` as its learning rate.
        """
        self.optimizer = self._initialize(
            "optimizer",
            paddle.optimizer,
            parameters=self.model.parameters(),
            learning_rate=self.lr_scheduler,
        )

    def inverse_normalize(self, batch_img):
        """Undo the Normalize transform on ``batch_img`` in place.

        Does nothing when no ``Normalize`` transform was found in the config.
        Nothing is returned; the input is modified directly.

        :param batch_img: image batch indexed as ``[:, channel, :, :]``; the
            first three channels are rescaled with the stored std and mean
        """
        if self.UN_Normalize:
            for channel in range(3):
                batch_img[:, channel, :, :] = (
                    batch_img[:, channel, :, :] * self.normalize_std[channel]
                    + self.normalize_mean[channel]
                )

    def logger_info(self, s):
        """Log ``s`` at INFO level on rank 0 only; a no-op on other ranks."""
        if paddle.distributed.get_rank() == 0:
            self.logger.info(s)
|