train.py 16 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494
  1. import paddle
  2. import numpy as np
  3. import os
  4. import paddle.nn as nn
  5. import paddle.distributed as dist
# Module-level distributed setup, executed at import time and before the
# project-local modules below are imported. The return value of
# get_world_size() is discarded.
dist.get_world_size()
dist.init_parallel_env()
  8. from loss import build_loss, LossDistill, DMLLoss, KLJSLoss
  9. from optimizer import create_optimizer
  10. from data_loader import build_dataloader
  11. from metric import create_metric
  12. from mv3 import MobileNetV3_large_x0_5, distillmv3_large_x0_5, build_model
  13. from config import preprocess
  14. import time
  15. from paddleslim.dygraph.quant import QAT
  16. from slim.slim_quant import PACT, quant_config
  17. from slim.slim_fpgm import prune_model
  18. from utils import load_model
  19. def _mkdir_if_not_exist(path, logger):
  20. """
  21. mkdir if not exists, ignore the exception when multiprocess mkdir together
  22. """
  23. if not os.path.exists(path):
  24. try:
  25. os.makedirs(path)
  26. except OSError as e:
  27. if e.errno == errno.EEXIST and os.path.isdir(path):
  28. logger.warning(
  29. "be happy if some process has already created {}".format(path)
  30. )
  31. else:
  32. raise OSError("Failed to mkdir {}".format(path))
  33. def save_model(
  34. model, optimizer, model_path, logger, is_best=False, prefix="ppocr", **kwargs
  35. ):
  36. """
  37. save model to the target path
  38. """
  39. _mkdir_if_not_exist(model_path, logger)
  40. model_prefix = os.path.join(model_path, prefix)
  41. paddle.save(model.state_dict(), model_prefix + ".pdparams")
  42. if type(optimizer) is list:
  43. paddle.save(optimizer[0].state_dict(), model_prefix + ".pdopt")
  44. paddle.save(optimizer[1].state_dict(), model_prefix + "_1" + ".pdopt")
  45. else:
  46. paddle.save(optimizer.state_dict(), model_prefix + ".pdopt")
  47. # # save metric and config
  48. # with open(model_prefix + '.states', 'wb') as f:
  49. # pickle.dump(kwargs, f, protocol=2)
  50. if is_best:
  51. logger.info("save best model is to {}".format(model_prefix))
  52. else:
  53. logger.info("save model in {}".format(model_prefix))
  54. def amp_scaler(config):
  55. if "AMP" in config and config["AMP"]["use_amp"] is True:
  56. AMP_RELATED_FLAGS_SETTING = {
  57. "FLAGS_cudnn_batchnorm_spatial_persistent": 1,
  58. }
  59. paddle.set_flags(AMP_RELATED_FLAGS_SETTING)
  60. scale_loss = config["AMP"].get("scale_loss", 1.0)
  61. use_dynamic_loss_scaling = config["AMP"].get("use_dynamic_loss_scaling", False)
  62. scaler = paddle.amp.GradScaler(
  63. init_loss_scaling=scale_loss,
  64. use_dynamic_loss_scaling=use_dynamic_loss_scaling,
  65. )
  66. return scaler
  67. else:
  68. return None
  69. def set_seed(seed):
  70. paddle.seed(seed)
  71. np.random.seed(seed)
  72. def train(config, scaler=None):
  73. EPOCH = config["epoch"]
  74. topk = config["topk"]
  75. batch_size = config["TRAIN"]["batch_size"]
  76. num_workers = config["TRAIN"]["num_workers"]
  77. train_loader = build_dataloader(
  78. "train", batch_size=batch_size, num_workers=num_workers
  79. )
  80. # build metric
  81. metric_func = create_metric
  82. # build model
  83. # model = MobileNetV3_large_x0_5(class_dim=100)
  84. model = build_model(config)
  85. # build_optimizer
  86. optimizer, lr_scheduler = create_optimizer(
  87. config, parameter_list=model.parameters()
  88. )
  89. # load model
  90. pre_best_model_dict = load_model(config, model, optimizer)
  91. if len(pre_best_model_dict) > 0:
  92. pre_str = "The metric of loaded metric as follows {}".format(
  93. ", ".join(["{}: {}".format(k, v) for k, v in pre_best_model_dict.items()])
  94. )
  95. logger.info(pre_str)
  96. # about slim prune and quant
  97. if "quant_train" in config and config["quant_train"] is True:
  98. quanter = QAT(config=quant_config, act_preprocess=PACT)
  99. quanter.quantize(model)
  100. elif "prune_train" in config and config["prune_train"] is True:
  101. model = prune_model(model, [1, 3, 32, 32], 0.1)
  102. else:
  103. pass
  104. # distribution
  105. model.train()
  106. model = paddle.DataParallel(model)
  107. # build loss function
  108. loss_func = build_loss(config)
  109. data_num = len(train_loader)
  110. best_acc = {}
  111. for epoch in range(EPOCH):
  112. st = time.time()
  113. for idx, data in enumerate(train_loader):
  114. img_batch, label = data
  115. img_batch = paddle.transpose(img_batch, [0, 3, 1, 2])
  116. label = paddle.unsqueeze(label, -1)
  117. if scaler is not None:
  118. with paddle.amp.auto_cast():
  119. outs = model(img_batch)
  120. else:
  121. outs = model(img_batch)
  122. # cal metric
  123. acc = metric_func(outs, label)
  124. # cal loss
  125. avg_loss = loss_func(outs, label)
  126. if scaler is None:
  127. # backward
  128. avg_loss.backward()
  129. optimizer.step()
  130. optimizer.clear_grad()
  131. else:
  132. scaled_avg_loss = scaler.scale(avg_loss)
  133. scaled_avg_loss.backward()
  134. scaler.minimize(optimizer, scaled_avg_loss)
  135. if not isinstance(lr_scheduler, float):
  136. lr_scheduler.step()
  137. if idx % 10 == 0:
  138. et = time.time()
  139. strs = f"epoch: [{epoch}/{EPOCH}], iter: [{idx}/{data_num}], "
  140. strs += f"loss: {float(avg_loss)}"
  141. strs += (
  142. f", acc_topk1: {float(acc['top1'])}, acc_top5: {float(acc['top5'])}"
  143. )
  144. strs += f", batch_time: {round(et-st, 4)} s"
  145. logger.info(strs)
  146. st = time.time()
  147. if epoch % 10 == 0:
  148. acc = eval(config, model)
  149. if len(best_acc) < 1 or float(acc["top5"]) > best_acc["top5"]:
  150. best_acc = acc
  151. best_acc["epoch"] = epoch
  152. is_best = True
  153. else:
  154. is_best = False
  155. logger.info(
  156. f"The best acc: acc_topk1: {float(best_acc['top1'])}, acc_top5: {float(best_acc['top5'])}, best_epoch: {best_acc['epoch']}"
  157. )
  158. save_model(
  159. model,
  160. optimizer,
  161. config["save_model_dir"],
  162. logger,
  163. is_best,
  164. prefix="cls",
  165. )
  166. def train_distill(config, scaler=None):
  167. EPOCH = config["epoch"]
  168. topk = config["topk"]
  169. batch_size = config["TRAIN"]["batch_size"]
  170. num_workers = config["TRAIN"]["num_workers"]
  171. train_loader = build_dataloader(
  172. "train", batch_size=batch_size, num_workers=num_workers
  173. )
  174. # build metric
  175. metric_func = create_metric
  176. # model = distillmv3_large_x0_5(class_dim=100)
  177. model = build_model(config)
  178. # pact quant train
  179. if "quant_train" in config and config["quant_train"] is True:
  180. quanter = QAT(config=quant_config, act_preprocess=PACT)
  181. quanter.quantize(model)
  182. elif "prune_train" in config and config["prune_train"] is True:
  183. model = prune_model(model, [1, 3, 32, 32], 0.1)
  184. else:
  185. pass
  186. # build_optimizer
  187. optimizer, lr_scheduler = create_optimizer(
  188. config, parameter_list=model.parameters()
  189. )
  190. # load model
  191. pre_best_model_dict = load_model(config, model, optimizer)
  192. if len(pre_best_model_dict) > 0:
  193. pre_str = "The metric of loaded metric as follows {}".format(
  194. ", ".join(["{}: {}".format(k, v) for k, v in pre_best_model_dict.items()])
  195. )
  196. logger.info(pre_str)
  197. model.train()
  198. model = paddle.DataParallel(model)
  199. # build loss function
  200. loss_func_distill = LossDistill(model_name_list=["student", "student1"])
  201. loss_func_dml = DMLLoss(model_name_pairs=["student", "student1"])
  202. loss_func_js = KLJSLoss(mode="js")
  203. data_num = len(train_loader)
  204. best_acc = {}
  205. for epoch in range(EPOCH):
  206. st = time.time()
  207. for idx, data in enumerate(train_loader):
  208. img_batch, label = data
  209. img_batch = paddle.transpose(img_batch, [0, 3, 1, 2])
  210. label = paddle.unsqueeze(label, -1)
  211. if scaler is not None:
  212. with paddle.amp.auto_cast():
  213. outs = model(img_batch)
  214. else:
  215. outs = model(img_batch)
  216. # cal metric
  217. acc = metric_func(outs["student"], label)
  218. # cal loss
  219. avg_loss = (
  220. loss_func_distill(outs, label)["student"]
  221. + loss_func_distill(outs, label)["student1"]
  222. + loss_func_dml(outs, label)["student_student1"]
  223. )
  224. # backward
  225. if scaler is None:
  226. avg_loss.backward()
  227. optimizer.step()
  228. optimizer.clear_grad()
  229. else:
  230. scaled_avg_loss = scaler.scale(avg_loss)
  231. scaled_avg_loss.backward()
  232. scaler.minimize(optimizer, scaled_avg_loss)
  233. if not isinstance(lr_scheduler, float):
  234. lr_scheduler.step()
  235. if idx % 10 == 0:
  236. et = time.time()
  237. strs = f"epoch: [{epoch}/{EPOCH}], iter: [{idx}/{data_num}], "
  238. strs += f"loss: {float(avg_loss)}"
  239. strs += (
  240. f", acc_topk1: {float(acc['top1'])}, acc_top5: {float(acc['top5'])}"
  241. )
  242. strs += f", batch_time: {round(et-st, 4)} s"
  243. logger.info(strs)
  244. st = time.time()
  245. if epoch % 10 == 0:
  246. acc = eval(config, model._layers.student)
  247. if len(best_acc) < 1 or float(acc["top5"]) > best_acc["top5"]:
  248. best_acc = acc
  249. best_acc["epoch"] = epoch
  250. is_best = True
  251. else:
  252. is_best = False
  253. logger.info(
  254. f"The best acc: acc_topk1: {float(best_acc['top1'])}, acc_top5: {float(best_acc['top5'])}, best_epoch: {best_acc['epoch']}"
  255. )
  256. save_model(
  257. model,
  258. optimizer,
  259. config["save_model_dir"],
  260. logger,
  261. is_best,
  262. prefix="cls_distill",
  263. )
  264. def train_distill_multiopt(config, scaler=None):
  265. EPOCH = config["epoch"]
  266. topk = config["topk"]
  267. batch_size = config["TRAIN"]["batch_size"]
  268. num_workers = config["TRAIN"]["num_workers"]
  269. train_loader = build_dataloader(
  270. "train", batch_size=batch_size, num_workers=num_workers
  271. )
  272. # build metric
  273. metric_func = create_metric
  274. # model = distillmv3_large_x0_5(class_dim=100)
  275. model = build_model(config)
  276. # build_optimizer
  277. optimizer, lr_scheduler = create_optimizer(
  278. config, parameter_list=model.student.parameters()
  279. )
  280. optimizer1, lr_scheduler1 = create_optimizer(
  281. config, parameter_list=model.student1.parameters()
  282. )
  283. # load model
  284. pre_best_model_dict = load_model(config, model, optimizer)
  285. if len(pre_best_model_dict) > 0:
  286. pre_str = "The metric of loaded metric as follows {}".format(
  287. ", ".join(["{}: {}".format(k, v) for k, v in pre_best_model_dict.items()])
  288. )
  289. logger.info(pre_str)
  290. # quant train
  291. if "quant_train" in config and config["quant_train"] is True:
  292. quanter = QAT(config=quant_config, act_preprocess=PACT)
  293. quanter.quantize(model)
  294. elif "prune_train" in config and config["prune_train"] is True:
  295. model = prune_model(model, [1, 3, 32, 32], 0.1)
  296. else:
  297. pass
  298. model.train()
  299. model = paddle.DataParallel(model)
  300. # build loss function
  301. loss_func_distill = LossDistill(model_name_list=["student", "student1"])
  302. loss_func_dml = DMLLoss(model_name_pairs=["student", "student1"])
  303. loss_func_js = KLJSLoss(mode="js")
  304. data_num = len(train_loader)
  305. best_acc = {}
  306. for epoch in range(EPOCH):
  307. st = time.time()
  308. for idx, data in enumerate(train_loader):
  309. img_batch, label = data
  310. img_batch = paddle.transpose(img_batch, [0, 3, 1, 2])
  311. label = paddle.unsqueeze(label, -1)
  312. if scaler is not None:
  313. with paddle.amp.auto_cast():
  314. outs = model(img_batch)
  315. else:
  316. outs = model(img_batch)
  317. # cal metric
  318. acc = metric_func(outs["student"], label)
  319. # cal loss
  320. avg_loss = (
  321. loss_func_distill(outs, label)["student"]
  322. + loss_func_dml(outs, label)["student_student1"]
  323. )
  324. avg_loss1 = (
  325. loss_func_distill(outs, label)["student1"]
  326. + loss_func_dml(outs, label)["student_student1"]
  327. )
  328. if scaler is None:
  329. # backward
  330. avg_loss.backward(retain_graph=True)
  331. optimizer.step()
  332. optimizer.clear_grad()
  333. avg_loss1.backward()
  334. optimizer1.step()
  335. optimizer1.clear_grad()
  336. else:
  337. scaled_avg_loss = scaler.scale(avg_loss)
  338. scaled_avg_loss.backward()
  339. scaler.minimize(optimizer, scaled_avg_loss)
  340. scaled_avg_loss = scaler.scale(avg_loss1)
  341. scaled_avg_loss.backward()
  342. scaler.minimize(optimizer1, scaled_avg_loss)
  343. if not isinstance(lr_scheduler, float):
  344. lr_scheduler.step()
  345. if not isinstance(lr_scheduler1, float):
  346. lr_scheduler1.step()
  347. if idx % 10 == 0:
  348. et = time.time()
  349. strs = f"epoch: [{epoch}/{EPOCH}], iter: [{idx}/{data_num}], "
  350. strs += f"loss: {float(avg_loss)}, loss1: {float(avg_loss1)}"
  351. strs += (
  352. f", acc_topk1: {float(acc['top1'])}, acc_top5: {float(acc['top5'])}"
  353. )
  354. strs += f", batch_time: {round(et-st, 4)} s"
  355. logger.info(strs)
  356. st = time.time()
  357. if epoch % 10 == 0:
  358. acc = eval(config, model._layers.student)
  359. if len(best_acc) < 1 or float(acc["top5"]) > best_acc["top5"]:
  360. best_acc = acc
  361. best_acc["epoch"] = epoch
  362. is_best = True
  363. else:
  364. is_best = False
  365. logger.info(
  366. f"The best acc: acc_topk1: {float(best_acc['top1'])}, acc_top5: {float(best_acc['top5'])}, best_epoch: {best_acc['epoch']}"
  367. )
  368. save_model(
  369. model,
  370. [optimizer, optimizer1],
  371. config["save_model_dir"],
  372. logger,
  373. is_best,
  374. prefix="cls_distill_multiopt",
  375. )
  376. def eval(config, model):
  377. batch_size = config["VALID"]["batch_size"]
  378. num_workers = config["VALID"]["num_workers"]
  379. valid_loader = build_dataloader(
  380. "test", batch_size=batch_size, num_workers=num_workers
  381. )
  382. # build metric
  383. metric_func = create_metric
  384. outs = []
  385. labels = []
  386. for idx, data in enumerate(valid_loader):
  387. img_batch, label = data
  388. img_batch = paddle.transpose(img_batch, [0, 3, 1, 2])
  389. label = paddle.unsqueeze(label, -1)
  390. out = model(img_batch)
  391. outs.append(out)
  392. labels.append(label)
  393. outs = paddle.concat(outs, axis=0)
  394. labels = paddle.concat(labels, axis=0)
  395. acc = metric_func(outs, labels)
  396. strs = f"The metric are as follows: acc_topk1: {float(acc['top1'])}, acc_top5: {float(acc['top5'])}"
  397. logger.info(strs)
  398. return acc
  399. if __name__ == "__main__":
  400. config, logger = preprocess(is_train=False)
  401. # AMP scaler
  402. scaler = amp_scaler(config)
  403. model_type = config["model_type"]
  404. if model_type == "cls":
  405. train(config)
  406. elif model_type == "cls_distill":
  407. train_distill(config)
  408. elif model_type == "cls_distill_multiopt":
  409. train_distill_multiopt(config)
  410. else:
  411. raise ValueError("model_type should be one of ['']")