| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256 |
- # -*- coding: utf-8 -*-
- # @Time : 2019/8/23 21:58
- # @Author : zhoujun
- import time
- import paddle
- from tqdm import tqdm
- from base import BaseTrainer
- from utils import runningScore, cal_text_score, Polynomial, profiler
class Trainer(BaseTrainer):
    """Epoch-level training and evaluation loop built on ``BaseTrainer``.

    The methods below read several attributes that are not assigned in this
    class (``optimizer``, ``lr_scheduler``, ``amp``, ``global_step``,
    ``log_iter``, ``train_loader_len``, ``epochs``, ``metrics``,
    ``epoch_result``, ``checkpoint_dir``, ``visualdl_enable``, ``writer``,
    ``logger_info``, ``_save_checkpoint``, ``_initialize``); they are
    presumably provided by ``BaseTrainer`` -- confirm against that class.
    """

    def __init__(
        self,
        config,
        model,
        criterion,
        train_loader,
        validate_loader,
        metric_cls,
        post_process=None,
        profiler_options=None,
    ):
        """Forward the common arguments to ``BaseTrainer`` and store local options.

        Args:
            config: Nested configuration mapping. ``config["trainer"]`` may
                contain ``"enable_eval"`` (defaults to ``True``).
            model: Network to train; called as ``model(batch["img"])``.
            criterion: Loss callable returning a mapping that contains a
                ``"loss"`` entry.
            train_loader: Iterable of training batches.
            validate_loader: Iterable of validation batches, or ``None``.
            metric_cls: Metric object used by ``_eval``, or ``None``.
            post_process: Callable turning predictions into boxes and scores.
            profiler_options: Passed unchanged to
                ``profiler.add_profiler_step`` on every training step.
        """
        super().__init__(
            config,
            model,
            criterion,
            train_loader,
            validate_loader,
            metric_cls,
            post_process,
        )
        self.profiler_options = profiler_options
        self.enable_eval = config["trainer"].get("enable_eval", True)

    def _train_epoch(self, epoch):
        """Run one training epoch and return a summary of it.

        Args:
            epoch: Epoch number, used only for logging and the returned dict.

        Returns:
            dict with ``"train_loss"`` (summed loss divided by
            ``self.train_loader_len``), ``"lr"`` (last learning rate read from
            the optimizer), ``"time"`` (seconds spent in this call) and
            ``"epoch"``.
        """
        self.model.train()
        total_samples = 0
        train_reader_cost = 0.0
        train_batch_cost = 0.0
        reader_start = time.time()
        epoch_start = time.time()
        train_loss = 0.0
        running_metric_text = runningScore(2)
        # Bind lr before the loop so the summary returned below is well defined
        # even when the loader yields no batch at all.
        lr = self.optimizer.get_lr()
        for i, batch in enumerate(self.train_loader):
            profiler.add_profiler_step(self.profiler_options)
            if i >= self.train_loader_len:
                break
            self.global_step += 1
            lr = self.optimizer.get_lr()
            cur_batch_size = batch["img"].shape[0]
            # Time elapsed since the previous step finished, i.e. waiting on data.
            train_reader_cost += time.time() - reader_start

            if self.amp:
                # Mixed precision: forward under auto_cast, loss on float32
                # predictions, backward/update through the gradient scaler.
                with paddle.amp.auto_cast(
                    enable="gpu" in paddle.device.get_device(),
                    custom_white_list=self.amp.get("custom_white_list", []),
                    custom_black_list=self.amp.get("custom_black_list", []),
                    level=self.amp.get("level", "O2"),
                ):
                    preds = self.model(batch["img"])
                loss_dict = self.criterion(preds.astype(paddle.float32), batch)
                scaled_loss = self.amp["scaler"].scale(loss_dict["loss"])
                scaled_loss.backward()
                self.amp["scaler"].minimize(self.optimizer, scaled_loss)
            else:
                preds = self.model(batch["img"])
                loss_dict = self.criterion(preds, batch)
                # backward
                loss_dict["loss"].backward()
                self.optimizer.step()
            self.lr_scheduler.step()
            self.optimizer.clear_grad()

            train_batch_time = time.time() - reader_start
            train_batch_cost += train_batch_time
            total_samples += cur_batch_size

            # acc / iou of the first prediction channel against the shrink map
            score_shrink_map = cal_text_score(
                preds[:, 0, :, :],
                batch["shrink_map"],
                batch["shrink_mask"],
                running_metric_text,
                thred=self.config["post_processing"]["args"]["thresh"],
            )

            # Record loss and acc in the log. Convert every entry to a plain
            # Python number in place first (keys are unchanged, so mutating
            # while iterating the keys is safe).
            for key in loss_dict:
                loss_dict[key] = loss_dict[key].item()
            # Every component carries its own trailing ", " so the string
            # always joins cleanly with the "lr:" field that follows it in the
            # log line, regardless of where "loss" sits in the mapping.
            loss_str = "loss: {:.4f}, ".format(loss_dict["loss"])
            loss_str += "".join(
                "{}: {:.4f}, ".format(key, value)
                for key, value in loss_dict.items()
                if key != "loss"
            )
            train_loss += loss_dict["loss"]

            acc = score_shrink_map["Mean Acc"]
            iou_shrink_map = score_shrink_map["Mean IoU"]

            if self.global_step % self.log_iter == 0:
                self.logger_info(
                    "[{}/{}], [{}/{}], global_step: {}, ips: {:.1f} samples/sec, avg_reader_cost: {:.5f} s, avg_batch_cost: {:.5f} s, avg_samples: {}, acc: {:.4f}, iou_shrink_map: {:.4f}, {}lr:{:.6}, time:{:.2f}".format(
                        epoch,
                        self.epochs,
                        i + 1,
                        self.train_loader_len,
                        self.global_step,
                        total_samples / train_batch_cost,
                        train_reader_cost / self.log_iter,
                        train_batch_cost / self.log_iter,
                        total_samples / self.log_iter,
                        acc,
                        iou_shrink_map,
                        loss_str,
                        lr,
                        train_batch_cost,
                    )
                )
                # Start a fresh accumulation window for the next log line.
                total_samples = 0
                train_reader_cost = 0.0
                train_batch_cost = 0.0

            if self.visualdl_enable and paddle.distributed.get_rank() == 0:
                # write scalars for visualisation (rank 0 only)
                for key, value in loss_dict.items():
                    self.writer.add_scalar(
                        "TRAIN/LOSS/{}".format(key), value, self.global_step
                    )
                self.writer.add_scalar("TRAIN/ACC_IOU/acc", acc, self.global_step)
                self.writer.add_scalar(
                    "TRAIN/ACC_IOU/iou_shrink_map", iou_shrink_map, self.global_step
                )
                self.writer.add_scalar("TRAIN/lr", lr, self.global_step)
            reader_start = time.time()

        return {
            "train_loss": train_loss / self.train_loader_len,
            "lr": lr,
            "time": time.time() - epoch_start,
            "epoch": epoch,
        }

    def _eval(self, epoch):
        """Evaluate the model on ``self.validate_loader``.

        Args:
            epoch: Unused here; kept for interface compatibility.

        Returns:
            Tuple ``(recall, precision, fmeasure)`` taken from the ``.avg``
            attribute of the corresponding entries returned by
            ``self.metric_cls.gather_measure``.
        """
        self.model.eval()
        raw_metrics = []
        total_frame = 0.0
        total_time = 0.0
        for batch in tqdm(
            self.validate_loader,
            total=len(self.validate_loader),
            desc="test model",
        ):
            with paddle.no_grad():
                start = time.time()
                if self.amp:
                    with paddle.amp.auto_cast(
                        enable="gpu" in paddle.device.get_device(),
                        custom_white_list=self.amp.get("custom_white_list", []),
                        custom_black_list=self.amp.get("custom_black_list", []),
                        level=self.amp.get("level", "O2"),
                    ):
                        preds = self.model(batch["img"])
                    preds = preds.astype(paddle.float32)
                else:
                    preds = self.model(batch["img"])
                boxes, scores = self.post_process(
                    batch, preds, is_output_polygon=self.metric_cls.is_output_polygon
                )
                # Only forward + post-processing time counts towards FPS.
                total_frame += batch["img"].shape[0]
                total_time += time.time() - start
                raw_metric = self.metric_cls.validate_measure(batch, (boxes, scores))
                raw_metrics.append(raw_metric)
        metrics = self.metric_cls.gather_measure(raw_metrics)
        # Guard against an empty loader (no time measured) instead of raising
        # ZeroDivisionError while merely formatting a log line.
        fps = total_frame / total_time if total_time > 0 else 0.0
        self.logger_info("FPS:{}".format(fps))
        return metrics["recall"].avg, metrics["precision"].avg, metrics["fmeasure"].avg

    def _on_epoch_finish(self):
        """Log the epoch summary, save the latest checkpoint and track the best one.

        Checkpointing, evaluation and best-model bookkeeping run on rank 0
        only. When evaluation is possible and enabled, the best model is the
        one with the highest hmean; otherwise the one with the lowest training
        loss.
        """
        self.logger_info(
            "[{}/{}], train_loss: {:.4f}, time: {:.4f}, lr: {}".format(
                self.epoch_result["epoch"],
                self.epochs,
                self.epoch_result["train_loss"],
                self.epoch_result["time"],
                self.epoch_result["lr"],
            )
        )
        net_save_path = "{}/model_latest.pth".format(self.checkpoint_dir)
        net_save_path_best = "{}/model_best.pth".format(self.checkpoint_dir)

        if paddle.distributed.get_rank() == 0:
            self._save_checkpoint(self.epoch_result["epoch"], net_save_path)
            save_best = False
            if (
                self.validate_loader is not None
                and self.metric_cls is not None
                and self.enable_eval
            ):  # use F1 (hmean) as the best-model criterion
                recall, precision, hmean = self._eval(self.epoch_result["epoch"])

                if self.visualdl_enable:
                    self.writer.add_scalar("EVAL/recall", recall, self.global_step)
                    self.writer.add_scalar(
                        "EVAL/precision", precision, self.global_step
                    )
                    self.writer.add_scalar("EVAL/hmean", hmean, self.global_step)
                self.logger_info(
                    "test: recall: {:.6f}, precision: {:.6f}, hmean: {:.6f}".format(
                        recall, precision, hmean
                    )
                )

                if hmean >= self.metrics["hmean"]:
                    save_best = True
                    self.metrics["train_loss"] = self.epoch_result["train_loss"]
                    self.metrics["hmean"] = hmean
                    self.metrics["precision"] = precision
                    self.metrics["recall"] = recall
                    self.metrics["best_model_epoch"] = self.epoch_result["epoch"]
            else:
                # No evaluation available: fall back to the training loss.
                if self.epoch_result["train_loss"] <= self.metrics["train_loss"]:
                    save_best = True
                    self.metrics["train_loss"] = self.epoch_result["train_loss"]
                    self.metrics["best_model_epoch"] = self.epoch_result["epoch"]

            best_str = "current best, " + "".join(
                "{}: {:.6f}, ".format(k, v) for k, v in self.metrics.items()
            )
            self.logger_info(best_str)
            if save_best:
                import shutil

                # The latest checkpoint was written above; the best model is a
                # plain file copy of it.
                shutil.copy(net_save_path, net_save_path_best)
                self.logger_info("Saving current best: {}".format(net_save_path_best))
            else:
                self.logger_info("Saving checkpoint: {}".format(net_save_path))

    def _on_train_finish(self):
        """Log the tracked best metrics (when evaluation is enabled) and a final message."""
        if self.enable_eval:
            for k, v in self.metrics.items():
                self.logger_info("{}:{}".format(k, v))
        self.logger_info("finish train")

    def _initialize_scheduler(self):
        """Create ``self.lr_scheduler`` from ``config["lr_scheduler"]``.

        For the ``"Polynomial"`` type the epoch count and the number of steps
        per epoch are written into the scheduler arguments (mutating the
        config) before the ``Polynomial`` helper is built and called to obtain
        the scheduler. Any other type is resolved through ``self._initialize``
        against ``paddle.optimizer.lr``.
        """
        scheduler_cfg = self.config["lr_scheduler"]
        if scheduler_cfg["type"] == "Polynomial":
            scheduler_cfg["args"]["epochs"] = self.config["trainer"]["epochs"]
            scheduler_cfg["args"]["step_each_epoch"] = len(self.train_loader)
            self.lr_scheduler = Polynomial(**scheduler_cfg["args"])()
        else:
            self.lr_scheduler = self._initialize("lr_scheduler", paddle.optimizer.lr)
|