| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494 |
- import paddle
- import numpy as np
- import os
- import paddle.nn as nn
- import paddle.distributed as dist
- dist.get_world_size()
- dist.init_parallel_env()
- from loss import build_loss, LossDistill, DMLLoss, KLJSLoss
- from optimizer import create_optimizer
- from data_loader import build_dataloader
- from metric import create_metric
- from mv3 import MobileNetV3_large_x0_5, distillmv3_large_x0_5, build_model
- from config import preprocess
- import time
- from paddleslim.dygraph.quant import QAT
- from slim.slim_quant import PACT, quant_config
- from slim.slim_fpgm import prune_model
- from utils import load_model
- def _mkdir_if_not_exist(path, logger):
- """
- mkdir if not exists, ignore the exception when multiprocess mkdir together
- """
- if not os.path.exists(path):
- try:
- os.makedirs(path)
- except OSError as e:
- if e.errno == errno.EEXIST and os.path.isdir(path):
- logger.warning(
- "be happy if some process has already created {}".format(path)
- )
- else:
- raise OSError("Failed to mkdir {}".format(path))
def save_model(
    model, optimizer, model_path, logger, is_best=False, prefix="ppocr", **kwargs
):
    """Save model parameters and optimizer state(s) under ``model_path``.

    Files written (``model_path`` is created first if needed):
      * ``<model_path>/<prefix>.pdparams`` -- ``model.state_dict()``;
      * for a single optimizer: ``<prefix>.pdopt``;
      * for a list/tuple of optimizers: ``<prefix>.pdopt`` for the first,
        then ``<prefix>_1.pdopt``, ``<prefix>_2.pdopt``, ... for the rest.

    Args:
        model: layer whose ``state_dict()`` is saved.
        optimizer: an optimizer, or a list/tuple of optimizers.
        model_path: output directory.
        logger: logger used for the final status message.
        is_best: only selects the log message; the files go to the same
            paths either way, so a later call overwrites them.
        prefix: file-name prefix.
        **kwargs: accepted for compatibility and ignored.
    """
    _mkdir_if_not_exist(model_path, logger)
    model_prefix = os.path.join(model_path, prefix)
    paddle.save(model.state_dict(), model_prefix + ".pdparams")
    if isinstance(optimizer, (list, tuple)):
        for i, opt in enumerate(optimizer):
            # The first optimizer keeps the plain name; the others get "_<i>".
            suffix = "" if i == 0 else "_{}".format(i)
            paddle.save(opt.state_dict(), model_prefix + suffix + ".pdopt")
    else:
        paddle.save(optimizer.state_dict(), model_prefix + ".pdopt")
    if is_best:
        logger.info("save best model is to {}".format(model_prefix))
    else:
        logger.info("save model in {}".format(model_prefix))
def amp_scaler(config):
    """Return a ``paddle.amp.GradScaler`` if AMP is switched on, else ``None``.

    AMP counts as enabled only when ``config["AMP"]["use_amp"]`` is the
    literal ``True``. In that case the Paddle flag
    ``FLAGS_cudnn_batchnorm_spatial_persistent`` is set to 1 and the scaler
    is built from the optional ``scale_loss`` (default 1.0) and
    ``use_dynamic_loss_scaling`` (default False) entries of ``config["AMP"]``.
    """
    if "AMP" not in config or config["AMP"]["use_amp"] is not True:
        return None

    amp_cfg = config["AMP"]
    paddle.set_flags({"FLAGS_cudnn_batchnorm_spatial_persistent": 1})
    return paddle.amp.GradScaler(
        init_loss_scaling=amp_cfg.get("scale_loss", 1.0),
        use_dynamic_loss_scaling=amp_cfg.get("use_dynamic_loss_scaling", False),
    )
def set_seed(seed):
    """Seed Paddle's and then NumPy's global random generators with ``seed``."""
    for seed_fn in (paddle.seed, np.random.seed):
        seed_fn(seed)
def train(config, scaler=None):
    """Train the single classification model described by ``config``.

    Args:
        config: configuration dict. Reads ``config["epoch"]``,
            ``config["TRAIN"]["batch_size"]``, ``config["TRAIN"]["num_workers"]``,
            ``config["save_model_dir"]`` and the optional boolean flags
            ``quant_train`` / ``prune_train``. The whole dict is also passed
            to ``build_model``, ``create_optimizer``, ``load_model``,
            ``build_loss`` and ``eval``.
        scaler: optional ``paddle.amp.GradScaler`` (see ``amp_scaler``). When
            given, the forward pass runs under ``paddle.amp.auto_cast`` and
            updates go through ``scaler.minimize``.

    The model is evaluated every 10 epochs and after the final epoch; after
    each evaluation rank 0 writes a checkpoint with prefix ``"cls"``.
    Uses the module-level ``logger``.
    """
    EPOCH = config["epoch"]
    batch_size = config["TRAIN"]["batch_size"]
    num_workers = config["TRAIN"]["num_workers"]
    train_loader = build_dataloader(
        "train", batch_size=batch_size, num_workers=num_workers
    )
    metric_func = create_metric

    model = build_model(config)
    optimizer, lr_scheduler = create_optimizer(
        config, parameter_list=model.parameters()
    )

    # load_model returns a dict of previously stored metrics; log it if any.
    pre_best_model_dict = load_model(config, model, optimizer)
    if len(pre_best_model_dict) > 0:
        pre_str = "The metric of loaded metric as follows {}".format(
            ", ".join(["{}: {}".format(k, v) for k, v in pre_best_model_dict.items()])
        )
        logger.info(pre_str)

    # NOTE(review): slim is applied after the optimizer was built from
    # model.parameters(), so parameters that QAT/PACT or pruning add or
    # replace may be missing from the optimizer. train_distill applies slim
    # before creating its optimizer -- confirm which order is intended.
    if "quant_train" in config and config["quant_train"] is True:
        quanter = QAT(config=quant_config, act_preprocess=PACT)
        quanter.quantize(model)
    elif "prune_train" in config and config["prune_train"] is True:
        model = prune_model(model, [1, 3, 32, 32], 0.1)

    model.train()
    model = paddle.DataParallel(model)
    loss_func = build_loss(config)

    data_num = len(train_loader)
    best_acc = {}
    for epoch in range(EPOCH):
        st = time.time()
        for idx, (img_batch, label) in enumerate(train_loader):
            # Move the last axis to position 1 (presumably NHWC -> NCHW).
            img_batch = paddle.transpose(img_batch, [0, 3, 1, 2])
            label = paddle.unsqueeze(label, -1)
            if scaler is not None:
                with paddle.amp.auto_cast():
                    outs = model(img_batch)
            else:
                outs = model(img_batch)
            acc = metric_func(outs, label)
            avg_loss = loss_func(outs, label)

            if scaler is None:
                avg_loss.backward()
                optimizer.step()
            else:
                scaled_avg_loss = scaler.scale(avg_loss)
                scaled_avg_loss.backward()
                scaler.minimize(optimizer, scaled_avg_loss)
            # Gradients accumulate across backward() calls, so they must be
            # cleared after every update -- on the AMP path as well.
            optimizer.clear_grad()
            if not isinstance(lr_scheduler, float):
                lr_scheduler.step()

            if idx % 10 == 0:
                et = time.time()
                strs = f"epoch: [{epoch}/{EPOCH}], iter: [{idx}/{data_num}], "
                strs += f"loss: {float(avg_loss)}"
                strs += (
                    f", acc_topk1: {float(acc['top1'])}, acc_top5: {float(acc['top5'])}"
                )
                strs += f", batch_time: {round(et-st, 4)} s"
                logger.info(strs)
                st = time.time()

        # Evaluate every 10 epochs and also after the last one, so the final
        # weights are always checkpointed.
        if epoch % 10 == 0 or epoch == EPOCH - 1:
            acc = eval(config, model)
            is_best = not best_acc or float(acc["top5"]) > float(best_acc["top5"])
            if is_best:
                best_acc = acc
                best_acc["epoch"] = epoch
            logger.info(
                f"The best acc: acc_topk1: {float(best_acc['top1'])}, acc_top5: {float(best_acc['top5'])}, best_epoch: {best_acc['epoch']}"
            )
            # Every rank runs this loop; only rank 0 writes the checkpoint so
            # processes do not write the same files concurrently.
            if dist.get_rank() == 0:
                save_model(
                    model,
                    optimizer,
                    config["save_model_dir"],
                    logger,
                    is_best,
                    prefix="cls",
                )
def train_distill(config, scaler=None):
    """Jointly train the two students of a distillation model.

    The model comes from ``build_model(config)``. Its forward output is
    indexed with ``"student"`` and ``"student1"``, and a single optimizer
    minimizes ``distill["student"] + distill["student1"] +
    dml["student_student1"]``. Accuracy is logged for the ``"student"``
    output, and evaluation runs on the wrapped model's ``student`` sub-layer.

    Args:
        config: configuration dict. Reads ``config["epoch"]``,
            ``config["TRAIN"]["batch_size"]``, ``config["TRAIN"]["num_workers"]``,
            ``config["save_model_dir"]`` and the optional boolean flags
            ``quant_train`` / ``prune_train``. The dict is also passed to
            ``build_model``, ``create_optimizer``, ``load_model`` and ``eval``.
        scaler: optional ``paddle.amp.GradScaler``. When given, the forward
            pass runs under ``paddle.amp.auto_cast`` and updates go through
            ``scaler.minimize``.

    The model is evaluated every 10 epochs and after the final epoch; after
    each evaluation rank 0 writes a checkpoint with prefix ``"cls_distill"``.
    Uses the module-level ``logger``.
    """
    EPOCH = config["epoch"]
    batch_size = config["TRAIN"]["batch_size"]
    num_workers = config["TRAIN"]["num_workers"]
    train_loader = build_dataloader(
        "train", batch_size=batch_size, num_workers=num_workers
    )
    metric_func = create_metric

    model = build_model(config)
    # Slim (PACT quantization or pruning) is applied before the optimizer is
    # built, so the optimizer is created from the transformed model.
    if "quant_train" in config and config["quant_train"] is True:
        quanter = QAT(config=quant_config, act_preprocess=PACT)
        quanter.quantize(model)
    elif "prune_train" in config and config["prune_train"] is True:
        model = prune_model(model, [1, 3, 32, 32], 0.1)

    optimizer, lr_scheduler = create_optimizer(
        config, parameter_list=model.parameters()
    )

    # load_model returns a dict of previously stored metrics; log it if any.
    pre_best_model_dict = load_model(config, model, optimizer)
    if len(pre_best_model_dict) > 0:
        pre_str = "The metric of loaded metric as follows {}".format(
            ", ".join(["{}: {}".format(k, v) for k, v in pre_best_model_dict.items()])
        )
        logger.info(pre_str)

    model.train()
    model = paddle.DataParallel(model)
    loss_func_distill = LossDistill(model_name_list=["student", "student1"])
    loss_func_dml = DMLLoss(model_name_pairs=["student", "student1"])

    data_num = len(train_loader)
    best_acc = {}
    for epoch in range(EPOCH):
        st = time.time()
        for idx, (img_batch, label) in enumerate(train_loader):
            # Move the last axis to position 1 (presumably NHWC -> NCHW).
            img_batch = paddle.transpose(img_batch, [0, 3, 1, 2])
            label = paddle.unsqueeze(label, -1)
            if scaler is not None:
                with paddle.amp.auto_cast():
                    outs = model(img_batch)
            else:
                outs = model(img_batch)
            acc = metric_func(outs["student"], label)

            # Evaluate each loss function once and reuse the result.
            distill_loss = loss_func_distill(outs, label)
            dml_loss = loss_func_dml(outs, label)
            avg_loss = (
                distill_loss["student"]
                + distill_loss["student1"]
                + dml_loss["student_student1"]
            )

            if scaler is None:
                avg_loss.backward()
                optimizer.step()
            else:
                scaled_avg_loss = scaler.scale(avg_loss)
                scaled_avg_loss.backward()
                scaler.minimize(optimizer, scaled_avg_loss)
            # Gradients accumulate across backward() calls, so they must be
            # cleared after every update -- on the AMP path as well.
            optimizer.clear_grad()
            if not isinstance(lr_scheduler, float):
                lr_scheduler.step()

            if idx % 10 == 0:
                et = time.time()
                strs = f"epoch: [{epoch}/{EPOCH}], iter: [{idx}/{data_num}], "
                strs += f"loss: {float(avg_loss)}"
                strs += (
                    f", acc_topk1: {float(acc['top1'])}, acc_top5: {float(acc['top5'])}"
                )
                strs += f", batch_time: {round(et-st, 4)} s"
                logger.info(strs)
                st = time.time()

        # Evaluate every 10 epochs and also after the last one, so the final
        # weights are always checkpointed.
        if epoch % 10 == 0 or epoch == EPOCH - 1:
            # model._layers is the model wrapped by DataParallel; only its
            # student branch is evaluated.
            acc = eval(config, model._layers.student)
            is_best = not best_acc or float(acc["top5"]) > float(best_acc["top5"])
            if is_best:
                best_acc = acc
                best_acc["epoch"] = epoch
            logger.info(
                f"The best acc: acc_topk1: {float(best_acc['top1'])}, acc_top5: {float(best_acc['top5'])}, best_epoch: {best_acc['epoch']}"
            )
            # Every rank runs this loop; only rank 0 writes the checkpoint so
            # processes do not write the same files concurrently.
            if dist.get_rank() == 0:
                save_model(
                    model,
                    optimizer,
                    config["save_model_dir"],
                    logger,
                    is_best,
                    prefix="cls_distill",
                )
def train_distill_multiopt(config, scaler=None):
    """Train the two students of a distillation model with separate optimizers.

    ``optimizer`` owns ``model.student``'s parameters and minimizes
    ``distill["student"] + dml["student_student1"]``. ``optimizer1`` owns
    ``model.student1``'s parameters and minimizes
    ``distill["student1"] + dml["student_student1"]``. Accuracy is logged for
    the ``"student"`` output, and evaluation runs on the ``student`` sub-layer.

    Args:
        config: configuration dict. Reads ``config["epoch"]``,
            ``config["TRAIN"]["batch_size"]``, ``config["TRAIN"]["num_workers"]``,
            ``config["save_model_dir"]`` and the optional boolean flags
            ``quant_train`` / ``prune_train``. The dict is also passed to
            ``build_model``, ``create_optimizer``, ``load_model`` and ``eval``.
        scaler: optional ``paddle.amp.GradScaler``. When given, the forward
            pass runs under ``paddle.amp.auto_cast`` and updates go through
            ``scaler.minimize``.

    The model is evaluated every 10 epochs and after the final epoch; after
    each evaluation rank 0 writes a checkpoint with prefix
    ``"cls_distill_multiopt"`` containing both optimizers' states.
    Uses the module-level ``logger``.
    """
    EPOCH = config["epoch"]
    batch_size = config["TRAIN"]["batch_size"]
    num_workers = config["TRAIN"]["num_workers"]
    train_loader = build_dataloader(
        "train", batch_size=batch_size, num_workers=num_workers
    )
    metric_func = create_metric

    model = build_model(config)
    optimizer, lr_scheduler = create_optimizer(
        config, parameter_list=model.student.parameters()
    )
    optimizer1, lr_scheduler1 = create_optimizer(
        config, parameter_list=model.student1.parameters()
    )

    # load_model returns a dict of previously stored metrics; log it if any.
    # NOTE(review): only `optimizer` is passed, so optimizer1's state (saved
    # by save_model as "<prefix>_1.pdopt") is not restored -- confirm.
    pre_best_model_dict = load_model(config, model, optimizer)
    if len(pre_best_model_dict) > 0:
        pre_str = "The metric of loaded metric as follows {}".format(
            ", ".join(["{}: {}".format(k, v) for k, v in pre_best_model_dict.items()])
        )
        logger.info(pre_str)

    # NOTE(review): slim is applied after both optimizers were built, so
    # parameters that QAT/PACT or pruning add or replace may be missing from
    # them -- confirm the intended order.
    if "quant_train" in config and config["quant_train"] is True:
        quanter = QAT(config=quant_config, act_preprocess=PACT)
        quanter.quantize(model)
    elif "prune_train" in config and config["prune_train"] is True:
        model = prune_model(model, [1, 3, 32, 32], 0.1)

    model.train()
    model = paddle.DataParallel(model)
    loss_func_distill = LossDistill(model_name_list=["student", "student1"])
    loss_func_dml = DMLLoss(model_name_pairs=["student", "student1"])

    data_num = len(train_loader)
    best_acc = {}
    for epoch in range(EPOCH):
        st = time.time()
        for idx, (img_batch, label) in enumerate(train_loader):
            # Move the last axis to position 1 (presumably NHWC -> NCHW).
            img_batch = paddle.transpose(img_batch, [0, 3, 1, 2])
            label = paddle.unsqueeze(label, -1)
            if scaler is not None:
                with paddle.amp.auto_cast():
                    outs = model(img_batch)
            else:
                outs = model(img_batch)
            acc = metric_func(outs["student"], label)

            # Evaluate each loss function once; the DML term is shared by
            # both objectives.
            distill_loss = loss_func_distill(outs, label)
            dml_loss = loss_func_dml(outs, label)["student_student1"]
            avg_loss = distill_loss["student"] + dml_loss
            avg_loss1 = distill_loss["student1"] + dml_loss

            # Each optimizer must only see gradients of its own loss. A
            # backward pass through the shared DML term can also write
            # gradients into the other student's parameters, so both
            # optimizers' gradients are cleared after each update. The first
            # backward retains the graph because the second one reuses it.
            if scaler is None:
                avg_loss.backward(retain_graph=True)
                optimizer.step()
                optimizer.clear_grad()
                optimizer1.clear_grad()
                avg_loss1.backward()
                optimizer1.step()
                optimizer1.clear_grad()
                optimizer.clear_grad()
            else:
                scaled_avg_loss = scaler.scale(avg_loss)
                scaled_avg_loss.backward(retain_graph=True)
                scaler.minimize(optimizer, scaled_avg_loss)
                optimizer.clear_grad()
                optimizer1.clear_grad()
                scaled_avg_loss1 = scaler.scale(avg_loss1)
                scaled_avg_loss1.backward()
                scaler.minimize(optimizer1, scaled_avg_loss1)
                optimizer1.clear_grad()
                optimizer.clear_grad()
            if not isinstance(lr_scheduler, float):
                lr_scheduler.step()
            if not isinstance(lr_scheduler1, float):
                lr_scheduler1.step()

            if idx % 10 == 0:
                et = time.time()
                strs = f"epoch: [{epoch}/{EPOCH}], iter: [{idx}/{data_num}], "
                strs += f"loss: {float(avg_loss)}, loss1: {float(avg_loss1)}"
                strs += (
                    f", acc_topk1: {float(acc['top1'])}, acc_top5: {float(acc['top5'])}"
                )
                strs += f", batch_time: {round(et-st, 4)} s"
                logger.info(strs)
                st = time.time()

        # Evaluate every 10 epochs and also after the last one, so the final
        # weights are always checkpointed.
        if epoch % 10 == 0 or epoch == EPOCH - 1:
            # model._layers is the model wrapped by DataParallel; only its
            # student branch is evaluated.
            acc = eval(config, model._layers.student)
            is_best = not best_acc or float(acc["top5"]) > float(best_acc["top5"])
            if is_best:
                best_acc = acc
                best_acc["epoch"] = epoch
            logger.info(
                f"The best acc: acc_topk1: {float(best_acc['top1'])}, acc_top5: {float(best_acc['top5'])}, best_epoch: {best_acc['epoch']}"
            )
            # Every rank runs this loop; only rank 0 writes the checkpoint so
            # processes do not write the same files concurrently.
            if dist.get_rank() == 0:
                save_model(
                    model,
                    [optimizer, optimizer1],
                    config["save_model_dir"],
                    logger,
                    is_best,
                    prefix="cls_distill_multiopt",
                )
def eval(config, model):
    """Evaluate ``model`` on the "test" split and return its metrics.

    The name shadows the builtin ``eval``; it is kept because callers use it.

    The model is switched to inference mode (``model.eval()``) and run under
    ``paddle.no_grad()``. This keeps normalization/dropout layers in their
    inference behaviour and avoids building a gradient graph. The previous
    training mode is restored before returning, even if an error occurs.

    Args:
        config: configuration dict; reads ``config["VALID"]["batch_size"]``
            and ``config["VALID"]["num_workers"]``.
        model: a paddle layer called as ``model(img_batch)``.

    Returns:
        The result of ``create_metric`` over the whole split. Callers read
        its ``"top1"`` and ``"top5"`` entries. Uses the module-level
        ``logger``.
    """
    batch_size = config["VALID"]["batch_size"]
    num_workers = config["VALID"]["num_workers"]
    valid_loader = build_dataloader(
        "test", batch_size=batch_size, num_workers=num_workers
    )
    metric_func = create_metric

    was_training = model.training
    model.eval()
    outs = []
    labels = []
    try:
        with paddle.no_grad():
            for img_batch, label in valid_loader:
                # Move the last axis to position 1 (presumably NHWC -> NCHW).
                img_batch = paddle.transpose(img_batch, [0, 3, 1, 2])
                label = paddle.unsqueeze(label, -1)
                outs.append(model(img_batch))
                labels.append(label)
    finally:
        if was_training:
            model.train()

    outs = paddle.concat(outs, axis=0)
    labels = paddle.concat(labels, axis=0)
    acc = metric_func(outs, labels)
    strs = f"The metric are as follows: acc_topk1: {float(acc['top1'])}, acc_top5: {float(acc['top5'])}"
    logger.info(strs)
    return acc
if __name__ == "__main__":
    # `config` and `logger` are assigned at module scope here; the training
    # and evaluation functions in this file read `logger` as a global.
    # NOTE(review): is_train=False in a training entry point looks
    # unexpected -- confirm what preprocess() does with this flag.
    config, logger = preprocess(is_train=False)
    # AMP scaler: None unless config["AMP"]["use_amp"] is True.
    scaler = amp_scaler(config)

    trainers = {
        "cls": train,
        "cls_distill": train_distill,
        "cls_distill_multiopt": train_distill_multiopt,
    }
    model_type = config["model_type"]
    if model_type not in trainers:
        raise ValueError(
            "model_type should be one of {}, but got {}".format(
                list(trainers), model_type
            )
        )
    # Forward the scaler so the AMP configuration actually takes effect.
    trainers[model_type](config, scaler=scaler)
|