save_load.py 14 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390
  1. # Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. from __future__ import absolute_import
  15. from __future__ import division
  16. from __future__ import print_function
  17. import errno
  18. import os
  19. import pickle
  20. import json
  21. from packaging import version
  22. import paddle
  23. from ppocr.utils.logging import get_logger
  24. from ppocr.utils.network import maybe_download_params
  25. try:
  26. import encryption # Attempt to import the encryption module for AIStudio's encryption model
  27. encrypted = encryption.is_encryption_needed()
  28. except ImportError:
  29. print("Skipping import of the encryption module.")
  30. encrypted = False # Encryption is not needed if the module cannot be imported
# Names exported by ``from <this module> import *``.
__all__ = ["load_model"]
  32. # just to determine the inference model file format
  33. def get_FLAGS_json_format_model():
  34. # json format by default
  35. return os.environ.get("FLAGS_json_format_model", "1").lower() in ("1", "true", "t")
  36. FLAGS_json_format_model = get_FLAGS_json_format_model()
  37. def _mkdir_if_not_exist(path, logger):
  38. """
  39. mkdir if not exists, ignore the exception when multiprocess mkdir together
  40. """
  41. if not os.path.exists(path):
  42. try:
  43. os.makedirs(path)
  44. except OSError as e:
  45. if e.errno == errno.EEXIST and os.path.isdir(path):
  46. logger.warning(
  47. "be happy if some process has already created {}".format(path)
  48. )
  49. else:
  50. raise OSError("Failed to mkdir {}".format(path))
  51. def load_model(config, model, optimizer=None, model_type="det"):
  52. """
  53. load model from checkpoint or pretrained_model
  54. """
  55. logger = get_logger()
  56. global_config = config["Global"]
  57. checkpoints = global_config.get("checkpoints")
  58. pretrained_model = global_config.get("pretrained_model")
  59. best_model_dict = {}
  60. is_float16 = False
  61. is_nlp_model = model_type == "kie" and config["Architecture"]["algorithm"] not in [
  62. "SDMGR"
  63. ]
  64. if is_nlp_model is True:
  65. # NOTE: for kie model dsitillation, resume training is not supported now
  66. if config["Architecture"]["algorithm"] in ["Distillation"]:
  67. return best_model_dict
  68. checkpoints = config["Architecture"]["Backbone"]["checkpoints"]
  69. # load kie method metric
  70. if checkpoints:
  71. if os.path.exists(os.path.join(checkpoints, "metric.states")):
  72. with open(os.path.join(checkpoints, "metric.states"), "rb") as f:
  73. states_dict = pickle.load(f, encoding="latin1")
  74. best_model_dict = states_dict.get("best_model_dict", {})
  75. if "epoch" in states_dict:
  76. best_model_dict["start_epoch"] = states_dict["epoch"] + 1
  77. logger.info("resume from {}".format(checkpoints))
  78. if optimizer is not None:
  79. if checkpoints[-1] in ["/", "\\"]:
  80. checkpoints = checkpoints[:-1]
  81. if os.path.exists(checkpoints + ".pdopt"):
  82. optim_dict = paddle.load(checkpoints + ".pdopt")
  83. optimizer.set_state_dict(optim_dict)
  84. else:
  85. logger.warning(
  86. "{}.pdopt is not exists, params of optimizer is not loaded".format(
  87. checkpoints
  88. )
  89. )
  90. return best_model_dict
  91. if checkpoints:
  92. if checkpoints.endswith(".pdparams"):
  93. checkpoints = checkpoints.replace(".pdparams", "")
  94. assert os.path.exists(
  95. checkpoints + ".pdparams"
  96. ), "The {}.pdparams does not exists!".format(checkpoints)
  97. # load params from trained model
  98. params = paddle.load(checkpoints + ".pdparams")
  99. state_dict = model.state_dict()
  100. new_state_dict = {}
  101. for key, value in state_dict.items():
  102. if key not in params:
  103. logger.warning(
  104. "{} not in loaded params {} !".format(key, params.keys())
  105. )
  106. continue
  107. pre_value = params[key]
  108. if pre_value.dtype == paddle.float16:
  109. is_float16 = True
  110. if pre_value.dtype != value.dtype:
  111. pre_value = pre_value.astype(value.dtype)
  112. if list(value.shape) == list(pre_value.shape):
  113. new_state_dict[key] = pre_value
  114. else:
  115. logger.warning(
  116. "The shape of model params {} {} not matched with loaded params shape {} !".format(
  117. key, value.shape, pre_value.shape
  118. )
  119. )
  120. model.set_state_dict(new_state_dict)
  121. if is_float16:
  122. logger.info(
  123. "The parameter type is float16, which is converted to float32 when loading"
  124. )
  125. if optimizer is not None:
  126. if os.path.exists(checkpoints + ".pdopt"):
  127. optim_dict = paddle.load(checkpoints + ".pdopt")
  128. optimizer.set_state_dict(optim_dict)
  129. else:
  130. logger.warning(
  131. "{}.pdopt is not exists, params of optimizer is not loaded".format(
  132. checkpoints
  133. )
  134. )
  135. if os.path.exists(checkpoints + ".states"):
  136. with open(checkpoints + ".states", "rb") as f:
  137. states_dict = pickle.load(f, encoding="latin1")
  138. best_model_dict = states_dict.get("best_model_dict", {})
  139. best_model_dict["acc"] = 0.0
  140. if "epoch" in states_dict:
  141. best_model_dict["start_epoch"] = states_dict["epoch"] + 1
  142. logger.info("resume from {}".format(checkpoints))
  143. elif pretrained_model:
  144. is_float16 = load_pretrained_params(model, pretrained_model)
  145. else:
  146. logger.info("train from scratch")
  147. best_model_dict["is_float16"] = is_float16
  148. return best_model_dict
  149. def load_pretrained_params(model, path):
  150. logger = get_logger()
  151. path = maybe_download_params(path)
  152. if path.endswith(".pdparams"):
  153. path = path.replace(".pdparams", "")
  154. assert os.path.exists(
  155. path + ".pdparams"
  156. ), "The {}.pdparams does not exists!".format(path)
  157. params = paddle.load(path + ".pdparams")
  158. state_dict = model.state_dict()
  159. new_state_dict = {}
  160. is_float16 = False
  161. for k1 in params.keys():
  162. if k1 not in state_dict.keys():
  163. logger.warning("The pretrained params {} not in model".format(k1))
  164. else:
  165. if params[k1].dtype == paddle.float16:
  166. is_float16 = True
  167. if params[k1].dtype != state_dict[k1].dtype:
  168. params[k1] = params[k1].astype(state_dict[k1].dtype)
  169. if list(state_dict[k1].shape) == list(params[k1].shape):
  170. new_state_dict[k1] = params[k1]
  171. else:
  172. logger.warning(
  173. "The shape of model params {} {} not matched with loaded params {} {} !".format(
  174. k1, state_dict[k1].shape, k1, params[k1].shape
  175. )
  176. )
  177. model.set_state_dict(new_state_dict)
  178. if is_float16:
  179. logger.info(
  180. "The parameter type is float16, which is converted to float32 when loading"
  181. )
  182. logger.info("load pretrain successful from {}".format(path))
  183. return is_float16
  184. def save_model(
  185. model,
  186. optimizer,
  187. model_path,
  188. logger,
  189. config,
  190. is_best=False,
  191. prefix="ppocr",
  192. **kwargs,
  193. ):
  194. """
  195. save model to the target path
  196. """
  197. _mkdir_if_not_exist(model_path, logger)
  198. model_prefix = os.path.join(model_path, prefix)
  199. if prefix == "best_accuracy":
  200. best_model_path = os.path.join(model_path, "best_model")
  201. _mkdir_if_not_exist(best_model_path, logger)
  202. paddle.save(optimizer.state_dict(), model_prefix + ".pdopt")
  203. if prefix == "best_accuracy":
  204. paddle.save(
  205. optimizer.state_dict(), os.path.join(best_model_path, "model.pdopt")
  206. )
  207. is_nlp_model = config["Architecture"]["model_type"] == "kie" and config[
  208. "Architecture"
  209. ]["algorithm"] not in ["SDMGR"]
  210. if is_nlp_model is not True:
  211. paddle.save(model.state_dict(), model_prefix + ".pdparams")
  212. metric_prefix = model_prefix
  213. if prefix == "best_accuracy":
  214. paddle.save(
  215. model.state_dict(), os.path.join(best_model_path, "model.pdparams")
  216. )
  217. else: # for kie system, we follow the save/load rules in NLP
  218. if config["Global"]["distributed"]:
  219. arch = model._layers
  220. else:
  221. arch = model
  222. if config["Architecture"]["algorithm"] in ["Distillation"]:
  223. arch = arch.Student
  224. arch.backbone.model.save_pretrained(model_prefix)
  225. metric_prefix = os.path.join(model_prefix, "metric")
  226. if prefix == "best_accuracy":
  227. arch.backbone.model.save_pretrained(best_model_path)
  228. save_model_info = kwargs.pop("save_model_info", False)
  229. if save_model_info:
  230. with open(os.path.join(model_path, f"{prefix}.info.json"), "w") as f:
  231. json.dump(kwargs, f)
  232. logger.info("Already save model info in {}".format(model_path))
  233. if prefix != "latest":
  234. done_flag = kwargs.pop("done_flag", False)
  235. update_train_results(config, prefix, save_model_info, done_flag=done_flag)
  236. # save metric and config
  237. with open(metric_prefix + ".states", "wb") as f:
  238. pickle.dump(kwargs, f, protocol=2)
  239. if is_best:
  240. logger.info("save best model is to {}".format(model_prefix))
  241. else:
  242. logger.info("save model in {}".format(model_prefix))
  243. def update_train_results(config, prefix, metric_info, done_flag=False, last_num=5):
  244. if paddle.distributed.get_rank() != 0:
  245. return
  246. assert last_num >= 1
  247. train_results_path = os.path.join(
  248. config["Global"]["save_model_dir"], "train_result.json"
  249. )
  250. save_model_tag = ["pdparams", "pdopt", "pdstates"]
  251. paddle_version = version.parse(paddle.__version__)
  252. if FLAGS_json_format_model or paddle_version >= version.parse("3.0.0"):
  253. save_inference_files = {
  254. "inference_config": "inference.yml",
  255. "pdmodel": "inference.json",
  256. "pdiparams": "inference.pdiparams",
  257. }
  258. else:
  259. save_inference_files = {
  260. "inference_config": "inference.yml",
  261. "pdmodel": "inference.pdmodel",
  262. "pdiparams": "inference.pdiparams",
  263. "pdiparams.info": "inference.pdiparams.info",
  264. }
  265. if os.path.exists(train_results_path):
  266. with open(train_results_path, "r") as fp:
  267. train_results = json.load(fp)
  268. else:
  269. train_results = {}
  270. train_results["model_name"] = config["Global"]["model_name"]
  271. label_dict_path = config["Global"].get("character_dict_path", "")
  272. if label_dict_path != "":
  273. label_dict_path = os.path.abspath(label_dict_path)
  274. if not os.path.exists(label_dict_path):
  275. label_dict_path = ""
  276. train_results["label_dict"] = label_dict_path
  277. train_results["train_log"] = "train.log"
  278. train_results["visualdl_log"] = ""
  279. train_results["config"] = "config.yaml"
  280. train_results["models"] = {}
  281. for i in range(1, last_num + 1):
  282. train_results["models"][f"last_{i}"] = {}
  283. train_results["models"]["best"] = {}
  284. train_results["done_flag"] = done_flag
  285. if "best" in prefix:
  286. if "acc" in metric_info["metric"]:
  287. metric_score = metric_info["metric"]["acc"]
  288. elif "precision" in metric_info["metric"]:
  289. metric_score = metric_info["metric"]["precision"]
  290. elif "exp_rate" in metric_info["metric"]:
  291. metric_score = metric_info["metric"]["exp_rate"]
  292. else:
  293. raise ValueError("No metric score found.")
  294. train_results["models"]["best"]["score"] = metric_score
  295. for tag in save_model_tag:
  296. if tag == "pdparams" and encrypted:
  297. train_results["models"]["best"][tag] = os.path.join(
  298. prefix,
  299. (
  300. f"{prefix}.encrypted.{tag}"
  301. if tag != "pdstates"
  302. else f"{prefix}.states"
  303. ),
  304. )
  305. else:
  306. train_results["models"]["best"][tag] = os.path.join(
  307. prefix,
  308. f"{prefix}.{tag}" if tag != "pdstates" else f"{prefix}.states",
  309. )
  310. for key in save_inference_files:
  311. train_results["models"]["best"][key] = os.path.join(
  312. prefix, "inference", save_inference_files[key]
  313. )
  314. else:
  315. for i in range(last_num - 1, 0, -1):
  316. train_results["models"][f"last_{i + 1}"] = train_results["models"][
  317. f"last_{i}"
  318. ].copy()
  319. if "acc" in metric_info["metric"]:
  320. metric_score = metric_info["metric"]["acc"]
  321. elif "precision" in metric_info["metric"]:
  322. metric_score = metric_info["metric"]["precision"]
  323. elif "exp_rate" in metric_info["metric"]:
  324. metric_score = metric_info["metric"]["exp_rate"]
  325. else:
  326. metric_score = 0
  327. train_results["models"][f"last_{1}"]["score"] = metric_score
  328. for tag in save_model_tag:
  329. if tag == "pdparams" and encrypted:
  330. train_results["models"][f"last_{1}"][tag] = os.path.join(
  331. prefix,
  332. (
  333. f"{prefix}.encrypted.{tag}"
  334. if tag != "pdstates"
  335. else f"{prefix}.states"
  336. ),
  337. )
  338. else:
  339. train_results["models"][f"last_{1}"][tag] = os.path.join(
  340. prefix,
  341. f"{prefix}.{tag}" if tag != "pdstates" else f"{prefix}.states",
  342. )
  343. for key in save_inference_files:
  344. train_results["models"][f"last_{1}"][key] = os.path.join(
  345. prefix, "inference", save_inference_files[key]
  346. )
  347. with open(train_results_path, "w") as fp:
  348. json.dump(train_results, fp)