io.py 48 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156115711581159116011611162116311641165116611671168116911701171117211731174117511761177117811791180118111821183118411851186118711881189119011911192119311941195119611971198119912001201120212031204120512061207120812091210121112121213121412151216121712181219122012211222122312241225122612271228122912301231123212331234123512361237123812391240124112421243124412451246124712481249125012511252125312541255125612571258125912601261126212631264126512661267126812691270127112721273127412751276127
712781279128012811282128312841285128612871288
  1. # Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. import collections
  15. import copyreg
  16. import os
  17. import pickle
  18. import sys
  19. import threading
  20. import warnings
  21. from collections.abc import Iterable
  22. import numpy as np
  23. import paddle
  24. # deprecated module import
  25. from paddle import base
  26. from paddle.base import core
  27. from paddle.base.framework import (
  28. EagerParamBase,
  29. Program,
  30. Variable,
  31. _create_tensor,
  32. _current_expected_place,
  33. _current_expected_place_,
  34. _dygraph_tracer,
  35. in_dygraph_mode,
  36. in_pir_mode,
  37. )
  38. from .io_utils import (
  39. _is_file_path,
  40. _is_memory_buffer,
  41. _legacy_static_save,
  42. _open_file_buffer,
  43. _pack_loaded_dict,
  44. _pickle_loads_mac,
  45. _unpack_saved_dict,
  46. )
# No names are exported through ``from ... import *``.
__all__ = []
# Threads started by `async_save`; `clear_async_save_task_queue` pops and
# joins them.
async_save_queue = []
  49. def clear_async_save_task_queue():
  50. '''
  51. wait until all async save task to be done.
  52. '''
  53. while len(async_save_queue) > 0:
  54. task = async_save_queue.pop()
  55. if task and task.is_alive():
  56. task.join()
  57. def async_save(obj, path, protocol=4, sync_other_task=False, **configs):
  58. '''
  59. async version of paddle.save.
  60. Note:
  61. currently only support dygraph mode.
  62. Note:
  63. any argument passed through configs will be overridden by default setting.
  64. Args:
  65. obj(Object) : The object to be saved.
  66. path(str|BytesIO) : The path/buffer of the object to be saved.
  67. If saved in the current directory, the input path string will be used as the file name.
  68. protocol(int, optional): The protocol version of pickle module must be greater than 1 and less than 5.
  69. Default: 4
  70. sync_other_task(bool) : Determine whether to wait other async save task to be finished before this one be put in queue.
  71. **configs(dict, optional): compatible argument to paddle.save, but will be overridden by default setting.
  72. Examples:
  73. .. code-block:: python
  74. :name: code-example-1
  75. import paddle
  76. emb = paddle.nn.Embedding(10, 10)
  77. layer_state_dict = emb.state_dict()
  78. # call paddle.async_save with the same style of paddle.save
  79. paddle.async_save(layer_state_dict, "emb.pdparams")
  80. for i in range(10):
  81. # do some calculations here
  82. # wait if any async_save task has not been done
  83. paddle.clear_async_task_queue()
  84. '''
  85. if not in_dygraph_mode():
  86. raise ValueError(
  87. "async_save currently is not supported in static mode."
  88. )
  89. if len(configs) > 0:
  90. warnings.warn(
  91. "configs are not supported in async mode, will be overridden by default settings."
  92. )
  93. # TODO: make this part async
  94. def move_state_dict_to_cpu(sd):
  95. for k, v in sd.items():
  96. if isinstance(v, dict):
  97. move_state_dict_to_cpu(v)
  98. elif isinstance(v, core.eager.Tensor):
  99. sd[k] = v.pin_memory() if core.is_compiled_with_cuda() else v
  100. if isinstance(obj, dict):
  101. move_state_dict_to_cpu(obj)
  102. elif isinstance(obj, core.eager.Tensor):
  103. obj = obj.pin_memory() if core.is_compiled_with_cuda() else obj
  104. else:
  105. # other types are currently not supported
  106. raise TypeError(
  107. f"currently async_save does not support this type: {type(obj)}"
  108. )
  109. if sync_other_task:
  110. clear_async_save_task_queue()
  111. t = threading.Thread(target=save, args=(obj, path, protocol))
  112. t.start()
  113. async_save_queue.append(t)
  114. def _build_saved_state_dict(state_dict):
  115. save_dict = {}
  116. name_table = {}
  117. for key, value in state_dict.items():
  118. if isinstance(value, (Variable, core.eager.Tensor)):
  119. if value.type == core.VarDesc.VarType.VOCAB:
  120. save_dict[key] = value.value().get_map_tensor()
  121. else:
  122. if not value.value().get_tensor()._is_initialized():
  123. raise ValueError(
  124. "The saved tensor is not initialized. If you used group sharded, please use save_group_sharded_model."
  125. )
  126. if value.is_dense() and value.place.is_custom_place():
  127. value = paddle._C_ops.npu_identity(value, -1)
  128. save_dict[key] = np.array(value.cpu())
  129. name_table[key] = value.name
  130. else:
  131. save_dict[key] = value
  132. save_dict["StructuredToParameterName@@"] = name_table
  133. return save_dict
  134. def _load_state_dict_from_save_inference_model(model_path, config):
  135. # 1. load program desc & construct _ProgramHolder
  136. # TODO(GGBond8488):From a long-term perspective, it is inappropriate for the framework to
  137. # rely on jit. It is necessary to migrate the dependency from jit to the framework in the future
  138. if in_pir_mode():
  139. from paddle.jit.pir_translated_layer import (
  140. _construct_params_and_buffers,
  141. _construct_program_holders,
  142. )
  143. programs = _construct_program_holders(model_path, config.model_filename)
  144. else:
  145. from paddle.jit.translated_layer import (
  146. _construct_params_and_buffers,
  147. _construct_program_holders,
  148. )
  149. programs = _construct_program_holders(model_path, config.model_filename)
  150. # 2. load layer parameters & buffers
  151. with base.dygraph.guard():
  152. persistable_var_dict = _construct_params_and_buffers(
  153. model_path, programs, config.params_filename
  154. )
  155. # 3. construct state_dict
  156. load_param_dict = {}
  157. for var_name in persistable_var_dict:
  158. tmp_var = persistable_var_dict[var_name]
  159. if tmp_var.is_dense() and tmp_var.place.is_custom_place():
  160. load_param_dict[var_name] = np.array(
  161. paddle._C_ops.npu_identity(tmp_var, -1).cpu()
  162. )
  163. else:
  164. load_param_dict[var_name] = np.array(tmp_var.cpu())
  165. # if *.info exists, we can recover structured_name
  166. var_info_filename = str(config.params_filename) + ".info"
  167. var_info_path = os.path.join(model_path, var_info_filename)
  168. if os.path.exists(var_info_path):
  169. with open(var_info_path, 'rb') as f:
  170. extra_var_info = pickle.load(f)
  171. structured_para_dict = {}
  172. for var_name in load_param_dict:
  173. structured_name = extra_var_info[var_name].get(
  174. 'structured_name', None
  175. )
  176. assert structured_name is not None, (
  177. "Cannot find saved variable (%s)'s structured name in saved model."
  178. % var_name
  179. )
  180. structured_para_dict[structured_name] = load_param_dict[
  181. var_name
  182. ]
  183. load_param_dict = structured_para_dict
  184. return load_param_dict
  185. def _load_state_dict_from_save_params(model_path):
  186. # Try to load all the files in the directory in Tensor format,
  187. # the file name is used as the name of Tensor
  188. load_var_list = []
  189. # 1. load file names
  190. var_name_list = []
  191. for root, _, files in os.walk(model_path):
  192. for filename in files:
  193. file_path = os.path.join(root, filename)
  194. tmp_var_name = os.path.relpath(file_path, model_path)
  195. var_name = tmp_var_name.replace("\\", "/")
  196. var_name_list.append(var_name)
  197. # 2. create and load Tensor
  198. with base.dygraph.guard():
  199. for name in var_name_list:
  200. new_var = _create_tensor(name=name, persistable=True)
  201. _dygraph_tracer().trace_op(
  202. type='load',
  203. inputs={},
  204. outputs={'Out': new_var},
  205. attrs={'file_path': os.path.join(model_path, name)},
  206. )
  207. load_var_list.append(new_var)
  208. # 3. construct state_dict
  209. load_param_dict = {}
  210. for var in load_var_list:
  211. if var.is_dense() and var.place.is_custom_place():
  212. var = paddle._C_ops.npu_identity(var, -1)
  213. load_param_dict[var.name] = np.array(var.cpu())
  214. return load_param_dict
  215. # NOTE(chenweihang): [ Handling of use cases of API paddle.load ]
  216. # `paddle.load` may be used to load saved results of:
  217. # 1. Expected cases:
  218. # - need [full filename] when loading
  219. # - paddle.save
  220. # - paddle.static.save
  221. # - need [prefix] when loading [compatible for paddle 2.x]
  222. # - paddle.jit.save
  223. # - paddle.static.save_inference_model
  224. # - need [directory] when loading [compatible for paddle 1.x]
  225. # - paddle.base.io.save_inference_model
  226. # - paddle.base.io.save_params/save_persistable
  227. # 2. Error cases:
  228. # - no error case
  229. def _build_load_path_and_config(path, config):
  230. # NOTE(chenweihang): If both [prefix save format] and [directory save format] exist,
  231. # raise error, avoid confusing behavior
  232. # TODO(GGBond8488):From a long-term perspective, it is inappropriate for the framework to
  233. # rely on jit. It is necessary to migrate the dependency from jit to the framework in the future
  234. from paddle.jit.pir_translated_layer import (
  235. PIR_INFER_MODEL_SUFFIX,
  236. )
  237. from paddle.jit.translated_layer import (
  238. INFER_MODEL_SUFFIX,
  239. INFER_PARAMS_SUFFIX,
  240. )
  241. if in_pir_mode():
  242. prefix_format_path = path + PIR_INFER_MODEL_SUFFIX
  243. else:
  244. prefix_format_path = path + INFER_MODEL_SUFFIX
  245. prefix_format_exist = os.path.exists(prefix_format_path)
  246. directory_format_exist = os.path.isdir(path)
  247. if prefix_format_exist and directory_format_exist:
  248. raise ValueError(
  249. f"The {path}.pdmodel and {path} directory exist at the same time, "
  250. "don't know which one to load, please make sure that the specified target "
  251. "of ``path`` is unique."
  252. )
  253. elif not prefix_format_exist and not directory_format_exist:
  254. error_msg = "The ``path`` (%s) to load model not exists."
  255. # if current path is a prefix, and the path.pdparams or path.pdopt
  256. # is exist, users may want use `paddle.load` load the result of
  257. # `base.save_dygraph`, we raise error here for users
  258. params_file_path = path + ".pdparams"
  259. opti_file_path = path + ".pdopt"
  260. if os.path.exists(params_file_path) or os.path.exists(opti_file_path):
  261. error_msg += (
  262. "please specify the full file name, not just the file name prefix. For "
  263. "example, it should be written as `paddle.load('model.pdparams')` instead of "
  264. "`paddle.load('model')`."
  265. )
  266. raise ValueError(error_msg % path)
  267. else:
  268. if prefix_format_exist:
  269. file_prefix = os.path.basename(path)
  270. model_path = os.path.dirname(path)
  271. if config.model_filename is not None:
  272. warnings.warn(
  273. "When loading the result saved with the "
  274. "specified file prefix, the ``model_filename`` config does "
  275. "not take effect."
  276. )
  277. if in_pir_mode():
  278. config.model_filename = file_prefix + PIR_INFER_MODEL_SUFFIX
  279. else:
  280. config.model_filename = file_prefix + INFER_MODEL_SUFFIX
  281. if config.params_filename is not None:
  282. warnings.warn(
  283. "When loading the result saved with the "
  284. "specified file prefix, the ``params_filename`` config does "
  285. "not take effect."
  286. )
  287. config.params_filename = file_prefix + INFER_PARAMS_SUFFIX
  288. else:
  289. # Compatible with the old save_inference_model format
  290. model_path = path
  291. return model_path, config
  292. def _parse_load_config(configs):
  293. supported_configs = [
  294. 'model_filename',
  295. 'params_filename',
  296. 'keep_name_table',
  297. 'return_numpy',
  298. ]
  299. # input check
  300. for key in configs:
  301. if key not in supported_configs:
  302. raise ValueError(
  303. "The additional config (%s) of `paddle.load` is not supported."
  304. % key
  305. )
  306. # construct inner config
  307. # TODO(GGBond8488):From a long-term perspective, it is inappropriate for the framework to
  308. # rely on jit. It is necessary to migrate the dependency from jit to the framework in the future
  309. from paddle.jit.api import _SaveLoadConfig
  310. inner_config = _SaveLoadConfig()
  311. inner_config.model_filename = configs.get('model_filename', None)
  312. inner_config.params_filename = configs.get('params_filename', None)
  313. inner_config.keep_name_table = configs.get('keep_name_table', None)
  314. inner_config.return_numpy = configs.get('return_numpy', False)
  315. return inner_config
  316. def _parse_save_config(configs):
  317. supported_configs = ['use_binary_format', 'pickle_protocol']
  318. # input check
  319. for key in configs:
  320. if key not in supported_configs:
  321. raise ValueError(
  322. "The additional config (%s) of `paddle.save` is not supported."
  323. % key
  324. )
  325. # construct inner config
  326. # TODO(GGBond8488):From a long-term perspective, it is inappropriate for the framework to
  327. # rely on jit. It is necessary to migrate the dependency from jit to the framework in the future
  328. from paddle.jit.api import _SaveLoadConfig
  329. inner_config = _SaveLoadConfig()
  330. inner_config.use_binary_format = configs.get('use_binary_format', False)
  331. inner_config.pickle_protocol = configs.get('pickle_protocol', None)
  332. return inner_config
  333. def _pickle_save(obj, f, protocol):
  334. # TODO(weixin):add support for BytesIO.
  335. if not isinstance(protocol, int):
  336. raise ValueError(
  337. f"The 'protocol' MUST be `int`, but received {type(protocol)}"
  338. )
  339. if protocol < 2 or protocol > 4:
  340. raise ValueError(
  341. f"Expected 1<'protocol'<5, but received protocol={protocol}"
  342. )
  343. def reduce_varbase(self):
  344. if self.is_dense() and self.place.is_custom_place():
  345. data = np.array(paddle._C_ops.npu_identity(self, -1).cpu())
  346. else:
  347. data = np.array(self.cpu())
  348. name = self.name
  349. return (tuple, ((name, data),))
  350. def reduce_LoDTensor(self):
  351. p = core.Place()
  352. p.set_place(paddle.CPUPlace())
  353. if self._place().is_custom_place():
  354. data = np.array(paddle._C_ops.npu_identity(self, -1)._copy(p))
  355. else:
  356. data = np.array(self._copy(p))
  357. return (eval, ('data', {'data': data}))
  358. def reduce_Layer(self):
  359. raise ValueError(
  360. "paddle do not support saving `paddle.nn.Layer` object."
  361. )
  362. dispatch_table_layer = {}
  363. def create_layer_dispatch_table(layer):
  364. dispatch_table_layer[layer.__class__] = reduce_Layer
  365. return layer
  366. _parse_every_object(
  367. obj,
  368. lambda v: isinstance(v, paddle.nn.Layer),
  369. create_layer_dispatch_table,
  370. )
  371. def add_dispatch_table():
  372. # This is not a good method, because the pickle module has been modified.
  373. pickle.dispatch_table[core.eager.Tensor] = reduce_varbase
  374. pickle.dispatch_table[EagerParamBase] = reduce_varbase
  375. pickle.dispatch_table[core.LoDTensor] = reduce_LoDTensor
  376. pickle.dispatch_table.update(dispatch_table_layer)
  377. def pop_dispatch_table():
  378. pickle.dispatch_table.pop(core.LoDTensor)
  379. pickle.dispatch_table.pop(core.eager.Tensor)
  380. pickle.dispatch_table.pop(EagerParamBase)
  381. for k in dispatch_table_layer:
  382. pickle.dispatch_table.pop(k)
  383. # When value of dict is lager than 4GB ,there is a Bug on 'MAC python3'
  384. if sys.platform == 'darwin' and sys.version_info.major == 3:
  385. add_dispatch_table()
  386. pickle_bytes = pickle.dumps(obj)
  387. pop_dispatch_table()
  388. max_bytes = 2**30
  389. for i in range(0, len(pickle_bytes), max_bytes):
  390. f.write(pickle_bytes[i : i + max_bytes])
  391. else:
  392. pickler = pickle.Pickler(f, protocol)
  393. pickler.dispatch_table = copyreg.dispatch_table.copy()
  394. pickler.dispatch_table[core.LoDTensor] = reduce_LoDTensor
  395. pickler.dispatch_table[core.eager.Tensor] = reduce_varbase
  396. pickler.dispatch_table[EagerParamBase] = reduce_varbase
  397. pickler.dispatch_table.update(dispatch_table_layer)
  398. pickler.dump(obj)
  399. def _contain_x(obj, condition_func):
  400. if isinstance(obj, core.SelectedRows):
  401. raise NotImplementedError(
  402. "`paddle.save` do not support saving 'SelectedRows'."
  403. )
  404. if condition_func(obj):
  405. return True
  406. elif type(obj) in (dict, collections.OrderedDict, list, tuple):
  407. if type(obj) in (dict, collections.OrderedDict):
  408. keys = list(obj.keys())
  409. else:
  410. keys = range(len(obj))
  411. flag = False
  412. for key in keys:
  413. flag |= _contain_x(obj[key], condition_func)
  414. if flag:
  415. return True
  416. return flag
  417. else:
  418. return False
  419. def _is_state_dict(obj):
  420. if isinstance(obj, dict):
  421. def condition(obj):
  422. return isinstance(
  423. obj,
  424. (
  425. paddle.nn.Layer,
  426. Program,
  427. core.eager.Tensor,
  428. core.LoDTensor,
  429. core.SelectedRows,
  430. ),
  431. )
  432. # If the value of a dict is a core.Tensor/LoDTensor or a dict
  433. # that does not contain a paddle type(Layer, Program, Tensor, LoDTensor, SelectedRows),
  434. # the dict is considered to be a state_ dict.
  435. for key, value in obj.items():
  436. if isinstance(value, dict):
  437. for k, v in value.items():
  438. if _contain_x(v, condition):
  439. return False
  440. elif not isinstance(value, (core.eager.Tensor, core.LoDTensor)):
  441. return False
  442. return True
  443. return False
  444. def _transformed_from_varbase(obj):
  445. # In paddle2.1 version, Tensor is saved as tuple(tensor.name, tensor.numpy()).
  446. # When executing paddle.load, use this function to determine whether to restore to Tensor/LoDTensor.
  447. if isinstance(obj, tuple) and len(obj) == 2:
  448. name_types = str
  449. if isinstance(obj[0], name_types) and isinstance(obj[1], np.ndarray):
  450. return True
  451. return False
  452. def _transformed_from_lodtensor(obj):
  453. # In paddle2.1 version, LoDTensor is saved as np.array(tensor).
  454. # When executing paddle.load, use this function to determine whether to restore to Tensor/LoDTensor.
  455. if isinstance(obj, np.ndarray):
  456. return True
  457. return False
  458. def _to_LodTensor(ndarray):
  459. if not isinstance(ndarray, np.ndarray):
  460. raise TypeError(
  461. f'Type of `ndarray` should be numpy.ndarray, but received {type(ndarray)}.'
  462. )
  463. t = core.LoDTensor()
  464. place = _current_expected_place_()
  465. t.set(ndarray, place)
  466. return t
  467. def _tuple_to_tensor(obj, return_numpy):
  468. if return_numpy:
  469. return obj[1]
  470. if in_dygraph_mode():
  471. t = paddle.to_tensor(obj[1])
  472. # This function does modify the name of return value.
  473. # Loading the same variable multiple times may cause the same name.
  474. t.name = obj[0]
  475. return t
  476. else:
  477. return _to_LodTensor(obj[1])
  478. def _ndarray_to_tensor(obj, return_numpy):
  479. if return_numpy:
  480. return obj
  481. if in_dygraph_mode():
  482. return paddle.to_tensor(obj)
  483. else:
  484. return _to_LodTensor(obj)
  485. def _lod_tensor2varbase(tensor):
  486. return_var = _create_tensor()
  487. return_var.value().get_tensor().set(tensor, _current_expected_place())
  488. return return_var
  489. def _parse_every_object(obj, condition_func, convert_func):
  490. if condition_func(obj):
  491. return convert_func(obj)
  492. elif type(obj) in (dict, collections.OrderedDict, list):
  493. if type(obj) == list:
  494. keys = range(len(obj))
  495. else:
  496. keys = list(obj.keys())
  497. for key in keys:
  498. if condition_func(obj[key]):
  499. obj[key] = convert_func(obj[key])
  500. else:
  501. obj[key] = _parse_every_object(
  502. obj[key], condition_func, convert_func
  503. )
  504. return obj
  505. elif type(obj) == tuple:
  506. return tuple(
  507. _parse_every_object(list(obj), condition_func, convert_func)
  508. )
  509. elif type(obj) == set:
  510. return set(_parse_every_object(list(obj), condition_func, convert_func))
  511. else:
  512. if isinstance(obj, Iterable) and not isinstance(
  513. obj,
  514. (str, np.ndarray, core.eager.Tensor, core.LoDTensor),
  515. ):
  516. raise NotImplementedError(
  517. f"The iterable objects supported are tuple, list, dict, OrderedDict, string. But received {type(obj)}."
  518. )
  519. return obj
  520. def _parse_load_result(obj, return_numpy):
  521. def is_layer(obj):
  522. return isinstance(obj, paddle.nn.Layer)
  523. def parse_layer(obj):
  524. temp_dict = _parse_load_result(obj.__dict__, False)
  525. obj.__dict__.update(temp_dict)
  526. return obj
  527. if _contain_x(obj, is_layer):
  528. if not in_dygraph_mode():
  529. raise ValueError(
  530. "Layer can only be loaded in dynamic graph mode, but now in static graph mode."
  531. )
  532. _parse_every_object(obj, is_layer, parse_layer)
  533. def tuple_to_tensor(obj):
  534. return _tuple_to_tensor(obj, return_numpy=return_numpy)
  535. def ndarray_to_tensor(obj):
  536. return _ndarray_to_tensor(obj, return_numpy=return_numpy)
  537. # tuple(name, ndarray) was converted from varbase of paddle2.1,
  538. # and all tuple(name, ndarray) are converted to tensor.
  539. if _contain_x(obj, _transformed_from_varbase):
  540. return _parse_every_object(
  541. obj, _transformed_from_varbase, tuple_to_tensor
  542. )
  543. # If there is no tuple(name, ndarray), it is considered to be saved by paddle2.0
  544. # or converted from LoDTensor, and all ndarrays are converted to tensor.
  545. else:
  546. return _parse_every_object(
  547. obj, _transformed_from_lodtensor, ndarray_to_tensor
  548. )
  549. def _save_lod_tensor(tensor, file_name):
  550. if not tensor._is_initialized():
  551. raise ValueError(
  552. "The saved tensor is not initialized. If you used group sharded, please use save_group_sharded_model firstly."
  553. )
  554. if _is_file_path(file_name):
  555. _seek = core.save_lod_tensor(tensor, file_name)
  556. # '_seek' is the end position of this tensor in the file.
  557. elif _is_memory_buffer(file_name):
  558. tensor_bytes = core.save_lod_tensor_to_memory(tensor)
  559. with _open_file_buffer(file_name, 'wb') as f:
  560. f.write(tensor_bytes)
  561. _seek = f.tell()
  562. else:
  563. raise NotImplementedError(
  564. f'Only supports saving objects to file or BytesIO, but received {type(file_name)}'
  565. )
  566. return _seek
  567. def _load_lod_tensor(file_name):
  568. temp_t = paddle.base.core.LoDTensor()
  569. if _is_file_path(file_name):
  570. # '_seek' is the end position of this tensor in the file.
  571. _seek = paddle.base.core.load_lod_tensor(temp_t, file_name)
  572. elif _is_memory_buffer(file_name):
  573. with _open_file_buffer(file_name, 'rb') as f:
  574. tensor_bytes = f.read()
  575. paddle.base.core.load_lod_tensor_from_memory(temp_t, tensor_bytes)
  576. _seek = f.tell()
  577. else:
  578. raise NotImplementedError(
  579. f'Only supports load objects from file or BytesIO, but received {type(file_name)}'
  580. )
  581. return temp_t, _seek
  582. def _save_selected_rows(selected_rows, file_name):
  583. if not selected_rows.get_tensor()._is_initialized():
  584. raise ValueError("The saved tensor is not initialized.")
  585. if _is_file_path(file_name):
  586. # '_seek' is the end position of this SelectedRows in the file.
  587. _seek = core.save_selected_rows(selected_rows, file_name)
  588. elif _is_memory_buffer(file_name):
  589. selected_rows_bytes = core.save_selected_rows_to_memory(selected_rows)
  590. with _open_file_buffer(file_name, 'wb') as f:
  591. f.write(selected_rows_bytes)
  592. _seek = f.tell()
  593. else:
  594. raise NotImplementedError(
  595. f'Only supports saving objects to file or BytesIO, but received {type(file_name)}'
  596. )
  597. return _seek
  598. def _load_selected_rows(file_name):
  599. temp_sr = core.SelectedRows()
  600. if _is_file_path(file_name):
  601. # '_seek' is the end position of this SelectedRows in the file.
  602. _seek = core.load_selected_rows(temp_sr, file_name)
  603. elif _is_memory_buffer(file_name):
  604. with _open_file_buffer(file_name, 'rb') as f:
  605. selected_rows_bytes = f.read()
  606. paddle.base.core.load_selected_rows_from_memory(
  607. temp_sr, selected_rows_bytes
  608. )
  609. _seek = f.tell()
  610. else:
  611. raise NotImplementedError(
  612. f'Only supports load objects from file or BytesIO, but received {type(file_name)}'
  613. )
  614. return temp_sr, _seek
  615. def _save_binary_var(obj, path):
  616. if isinstance(obj, core.LoDTensor):
  617. _save_lod_tensor(obj, path)
  618. elif isinstance(obj, core.SelectedRows):
  619. _save_selected_rows(obj, path)
  620. elif isinstance(obj, core.eager.Tensor):
  621. _save_lod_tensor(obj.value().get_tensor(), path)
  622. else:
  623. # Since the concept of 'Tensor' is only exposed to users, the error message can only contain tensor instead of 'LoDTensor' or 'SelectedRows'
  624. raise NotImplementedError(
  625. f"When use_binary_format = True, `paddle.save` expected Tensor, but received {type(obj)}."
  626. )
  627. def save(obj, path, protocol=4, **configs):
  628. '''
  629. Save an object to the specified path.
  630. Note:
  631. Now supports saving ``state_dict`` of Layer/Optimizer, Tensor and nested structure containing Tensor, Program.
  632. Note:
  633. Different from ``paddle.jit.save``, since the save result of ``paddle.save`` is a single file,
  634. there is no need to distinguish multiple saved files by adding a suffix. The argument ``path``
  635. of ``paddle.save`` will be directly used as the saved file name instead of a prefix.
  636. In order to unify the saved file name format, we recommend using the paddle standard suffix:
  637. 1. for ``Layer.state_dict`` , recommend to use ``.pdparams`` ;
  638. 2. for ``Optimizer.state_dict`` , recommend to use ``.pdopt`` .
  639. For specific examples, please refer to API code examples.
  640. Args:
  641. obj(Object) : The object to be saved.
  642. path(str|BytesIO) : The path/buffer of the object to be saved.
  643. If saved in the current directory, the input path string will be used as the file name.
  644. protocol(int, optional): The protocol version of pickle module must be greater than 1 and less than 5.
  645. Default: 4
  646. **configs(dict, optional): optional keyword arguments. The following options are currently supported:
  647. use_binary_format(bool): When the saved object is static graph variable, you can specify ``use_binary_for_var``.
  648. If True, save the file in the c++ binary format when saving a single static graph variable; otherwise, save it in pickle format.
  649. Default: False
  650. Returns:
  651. None
  652. Examples:
  653. .. code-block:: python
  654. :name: code-example-1
  655. >>> # example 1: dynamic graph
  656. >>> import paddle
  657. >>> emb = paddle.nn.Embedding(10, 10)
  658. >>> layer_state_dict = emb.state_dict()
  659. >>> # save state_dict of emb
  660. >>> paddle.save(layer_state_dict, "emb.pdparams")
  661. >>> scheduler = paddle.optimizer.lr.NoamDecay(
  662. ... d_model=0.01, warmup_steps=100, verbose=True)
  663. >>> adam = paddle.optimizer.Adam(
  664. ... learning_rate=scheduler,
  665. ... parameters=emb.parameters())
  666. >>> opt_state_dict = adam.state_dict()
  667. >>> # save state_dict of optimizer
  668. >>> paddle.save(opt_state_dict, "adam.pdopt")
  669. >>> # save weight of emb
  670. >>> paddle.save(emb.weight, "emb.weight.pdtensor")
  671. .. code-block:: python
  672. :name: code-example-2
  673. >>> # example 2: Save multiple state_dict at the same time
  674. >>> import paddle
  675. >>> from paddle import nn
  676. >>> from paddle.optimizer import Adam
  677. >>> layer = paddle.nn.Linear(3, 4)
  678. >>> adam = Adam(learning_rate=0.001, parameters=layer.parameters())
  679. >>> obj = {'model': layer.state_dict(), 'opt': adam.state_dict(), 'epoch': 100}
  680. >>> path = 'example/model.pdparams'
  681. >>> paddle.save(obj, path)
  682. .. code-block:: python
  683. :name: code-example-3
  684. >>> # example 3: static graph
  685. >>> import paddle
  686. >>> import paddle.static as static
  687. >>> paddle.enable_static()
  688. >>> # create network
  689. >>> x = paddle.static.data(name="x", shape=[None, 224], dtype='float32')
  690. >>> z = paddle.static.nn.fc(x, 10)
  691. >>> place = paddle.CPUPlace()
  692. >>> exe = paddle.static.Executor(place)
  693. >>> exe.run(paddle.static.default_startup_program())
  694. >>> prog = paddle.static.default_main_program()
  695. >>> for var in prog.list_vars():
  696. ... if list(var.shape) == [224, 10]:
  697. ... tensor = var.get_value()
  698. ... break
  699. >>> # save/load tensor
  700. >>> path_tensor = 'temp/tensor.pdtensor'
  701. >>> paddle.save(tensor, path_tensor)
  702. >>> # save/load state_dict
  703. >>> path_state_dict = 'temp/model.pdparams'
  704. >>> paddle.save(prog.state_dict("param"), path_tensor)
  705. .. code-block:: python
  706. :name: code-example-4
  707. >>> # example 4: save program
  708. >>> import paddle
  709. >>> paddle.enable_static()
  710. >>> data = paddle.static.data(
  711. ... name='x_static_save', shape=(None, 224), dtype='float32')
  712. >>> y_static = z = paddle.static.nn.fc(data, 10)
  713. >>> main_program = paddle.static.default_main_program()
  714. >>> path = "example/main_program.pdmodel"
  715. >>> paddle.save(main_program, path)
  716. .. code-block:: python
  717. :name: code-example-5
  718. >>> # example 5: save object to memory
  719. >>> from io import BytesIO
  720. >>> import paddle
  721. >>> from paddle.nn import Linear
  722. >>> paddle.disable_static()
  723. >>> linear = Linear(5, 10)
  724. >>> state_dict = linear.state_dict()
  725. >>> byio = BytesIO()
  726. >>> paddle.save(state_dict, byio)
  727. >>> paddle.seed(2023)
  728. >>> tensor = paddle.randn([2, 3], dtype='float32')
  729. >>> paddle.save(tensor, byio)
  730. '''
  731. if _is_file_path(path):
  732. # 1. input check
  733. filename = os.path.basename(path)
  734. if filename == "":
  735. raise ValueError(
  736. "The input path MUST be format of dirname/filename "
  737. "[dirname\\filename in Windows system], but received "
  738. "filename is empty string."
  739. )
  740. # 2. save object
  741. dirname = os.path.dirname(path)
  742. if dirname and not os.path.exists(dirname):
  743. os.makedirs(dirname, exist_ok=True)
  744. elif not _is_memory_buffer(path):
  745. raise ValueError(
  746. f"only supports saving objects to file and `BytesIO`, but got {type(path)}"
  747. )
  748. config = _parse_save_config(configs)
  749. if not isinstance(config.use_binary_format, bool):
  750. raise TypeError(
  751. f"Type of `use_binary_format` should be bool, but received {type(config.use_binary_format)}."
  752. )
  753. if config.use_binary_format:
  754. _save_binary_var(obj, path)
  755. else:
  756. # `protocol` need to be used, `pickle_protocol` is a deprecated arg.
  757. if config.pickle_protocol is not None:
  758. protocol = config.pickle_protocol
  759. warnings.warn(
  760. "'pickle_protocol' is a deprecated argument. Please use 'protocol' instead."
  761. )
  762. if isinstance(obj, paddle.static.Program):
  763. if in_pir_mode():
  764. paddle.core.serialize_pir_program(
  765. obj, path, 1, True, False, True
  766. )
  767. else:
  768. obj.desc.flush()
  769. with _open_file_buffer(path, "wb") as f:
  770. f.write(obj.desc.serialize_to_string())
  771. elif _is_state_dict(obj):
  772. if in_dygraph_mode():
  773. _legacy_save(obj, path, protocol)
  774. else:
  775. _legacy_static_save(obj, path, protocol)
  776. else:
  777. with _open_file_buffer(path, 'wb') as f:
  778. _pickle_save(obj, f, protocol)
  779. def _legacy_save(obj, path, protocol=2):
  780. # 1. input check
  781. if not isinstance(obj, dict):
  782. raise NotImplementedError(
  783. "Now only supports save state_dict of Layer or Optimizer, "
  784. "expect dict, but received %s." % type(obj)
  785. )
  786. if len(obj) == 0:
  787. warnings.warn("The input state dict is empty, no need to save.")
  788. if not isinstance(protocol, int):
  789. raise ValueError(
  790. f"The 'protocol' MUST be `int`, but received {type(protocol)}"
  791. )
  792. if protocol < 2 or protocol > 4:
  793. raise ValueError(
  794. f"Expected 1<'protocol'<5, but received protocol={protocol}"
  795. )
  796. if _is_file_path(path):
  797. filename = os.path.basename(path)
  798. if filename == "":
  799. raise ValueError(
  800. "The input path MUST be format of dirname/filename "
  801. "[dirname\\filename in Windows system], but received "
  802. "filename is empty string."
  803. )
  804. # 2. save object
  805. dirname = os.path.dirname(path)
  806. if dirname and not os.path.exists(dirname):
  807. os.makedirs(dirname, exist_ok=True)
  808. if isinstance(obj, dict):
  809. saved_obj = _build_saved_state_dict(obj)
  810. saved_obj = _unpack_saved_dict(saved_obj, protocol)
  811. # When value of dict is lager than 4GB ,there is a Bug on 'MAC python3'
  812. if (
  813. _is_file_path(path)
  814. and sys.platform == 'darwin'
  815. and sys.version_info.major == 3
  816. ):
  817. pickle_bytes = pickle.dumps(saved_obj, protocol=protocol)
  818. with open(path, 'wb') as f:
  819. max_bytes = 2**30
  820. for i in range(0, len(pickle_bytes), max_bytes):
  821. f.write(pickle_bytes[i : i + max_bytes])
  822. else:
  823. with _open_file_buffer(path, 'wb') as f:
  824. pickle.dump(saved_obj, f, protocol=protocol)
  825. def load(path, **configs):
  826. '''
  827. Load an object can be used in paddle from specified path.
  828. Note:
  829. Now supports loading ``state_dict`` of Layer/Optimizer, Tensor and nested structure containing Tensor, Program.
  830. Note:
  831. In order to use the model parameters saved by paddle more efficiently,
  832. ``paddle.load`` supports loading ``state_dict`` of Layer from the result of
  833. other save APIs except ``paddle.save`` , but the argument ``path`` format is
  834. different:
  835. 1. loading from ``paddle.static.save`` or ``paddle.Model().save(training=True)`` ,
  836. ``path`` needs to be a complete file name, such as ``model.pdparams`` or
  837. ``model.pdopt`` ;
  838. 2. loading from ``paddle.jit.save`` or ``paddle.static.save_inference_model``
  839. or ``paddle.Model().save(training=False)`` , ``path`` need to be a file prefix,
  840. such as ``model/mnist``, and ``paddle.load`` will get information from
  841. ``mnist.pdmodel`` and ``mnist.pdiparams`` ;
  842. 3. loading from paddle 1.x APIs ``paddle.base.io.save_inference_model`` or
  843. ``paddle.base.io.save_params/save_persistables`` , ``path`` need to be a
  844. directory, such as ``model`` and model is a directory.
  845. Note:
  846. If you load ``state_dict`` from the saved result of static graph mode API such as
  847. ``paddle.static.save`` or ``paddle.static.save_inference_model`` ,
  848. the structured variable name in dynamic mode will cannot be restored.
  849. You need to set the argument ``use_structured_name=False`` when using
  850. ``Layer.set_state_dict`` later.
  851. Args:
  852. path(str|BytesIO) : The path/buffer to load the target object. Generally, the path is the target
  853. file path. When loading state_dict from the saved result of the API used to save
  854. the inference model, the path may be a file prefix or directory.
  855. **configs (dict, optional): other load configuration options for compatibility. We do not
  856. recommend using these configurations, they may be removed in the future. If not necessary,
  857. DO NOT use them. Default None.
  858. The following options are currently supported:
  859. (1) model_filename (str): The inference model file name of the paddle 1.x
  860. ``save_inference_model`` save format. Default file name is :code:`__model__` .
  861. (2) params_filename (str): The persistable variables file name of the paddle 1.x
  862. ``save_inference_model`` save format. No default file name, save variables separately
  863. by default.
  864. (3) return_numpy(bool): If specified as True, return tensor as numpy.ndarray, otherwise return tensor as paddle.Tensor.
  865. Default False.
  866. Returns:
  867. Object(Object): a target object can be used in paddle
  868. Examples:
  869. .. code-block:: python
  870. :name: code-example-1
  871. >>> # example 1: dynamic graph
  872. >>> import paddle
  873. >>> emb = paddle.nn.Embedding(10, 10)
  874. >>> layer_state_dict = emb.state_dict()
  875. >>> # save state_dict of emb
  876. >>> paddle.save(layer_state_dict, "emb.pdparams")
  877. >>> scheduler = paddle.optimizer.lr.NoamDecay(
  878. ... d_model=0.01, warmup_steps=100, verbose=True)
  879. >>> adam = paddle.optimizer.Adam(
  880. ... learning_rate=scheduler,
  881. ... parameters=emb.parameters())
  882. >>> opt_state_dict = adam.state_dict()
  883. >>> # save state_dict of optimizer
  884. >>> paddle.save(opt_state_dict, "adam.pdopt")
  885. >>> # save weight of emb
  886. >>> paddle.save(emb.weight, "emb.weight.pdtensor")
  887. >>> # load state_dict of emb
  888. >>> load_layer_state_dict = paddle.load("emb.pdparams")
  889. >>> # load state_dict of optimizer
  890. >>> load_opt_state_dict = paddle.load("adam.pdopt")
  891. >>> # load weight of emb
  892. >>> load_weight = paddle.load("emb.weight.pdtensor")
  893. .. code-block:: python
  894. :name: code-example-2
  895. >>> # example 2: Load multiple state_dict at the same time
  896. >>> import paddle
  897. >>> from paddle import nn
  898. >>> from paddle.optimizer import Adam
  899. >>> layer = paddle.nn.Linear(3, 4)
  900. >>> adam = Adam(learning_rate=0.001, parameters=layer.parameters())
  901. >>> obj = {'model': layer.state_dict(), 'opt': adam.state_dict(), 'epoch': 100}
  902. >>> path = 'example/model.pdparams'
  903. >>> paddle.save(obj, path)
  904. >>> obj_load = paddle.load(path)
  905. .. code-block:: python
  906. :name: code-example-3
  907. >>> # example 3: static graph
  908. >>> import paddle
  909. >>> import paddle.static as static
  910. >>> paddle.enable_static()
  911. >>> # create network
  912. >>> x = paddle.static.data(name="x", shape=[None, 224], dtype='float32')
  913. >>> z = paddle.static.nn.fc(x, 10)
  914. >>> place = paddle.CPUPlace()
  915. >>> exe = paddle.static.Executor(place)
  916. >>> exe.run(paddle.static.default_startup_program())
  917. >>> prog = paddle.static.default_main_program()
  918. >>> for var in prog.list_vars():
  919. ... if list(var.shape) == [224, 10]:
  920. ... tensor = var.get_value()
  921. ... break
  922. >>> # save/load tensor
  923. >>> path_tensor = 'temp/tensor.pdtensor'
  924. >>> paddle.save(tensor, path_tensor)
  925. >>> load_tensor = paddle.load(path_tensor)
  926. >>> # save/load state_dict
  927. >>> path_state_dict = 'temp/model.pdparams'
  928. >>> paddle.save(prog.state_dict("param"), path_tensor)
  929. >>> load_state_dict = paddle.load(path_tensor)
  930. .. code-block:: python
  931. :name: code-example-4
  932. >>> # example 4: load program
  933. >>> import paddle
  934. >>> paddle.enable_static()
  935. >>> data = paddle.static.data(
  936. ... name='x_static_save', shape=(None, 224), dtype='float32')
  937. >>> y_static = z = paddle.static.nn.fc(data, 10)
  938. >>> main_program = paddle.static.default_main_program()
  939. >>> path = "example/main_program.pdmodel"
  940. >>> paddle.save(main_program, path)
  941. >>> load_main = paddle.load(path)
  942. .. code-block:: python
  943. :name: code-example-5
  944. >>> # example 5: save object to memory
  945. >>> from io import BytesIO
  946. >>> import paddle
  947. >>> from paddle.nn import Linear
  948. >>> paddle.disable_static()
  949. >>> linear = Linear(5, 10)
  950. >>> state_dict = linear.state_dict()
  951. >>> byio = BytesIO()
  952. >>> paddle.save(state_dict, byio)
  953. >>> paddle.seed(2023)
  954. >>> tensor = paddle.randn([2, 3], dtype='float32')
  955. >>> paddle.save(tensor, byio)
  956. >>> byio.seek(0)
  957. >>> # load state_dict
  958. >>> dict_load = paddle.load(byio)
  959. '''
  960. if _is_memory_buffer(path) or os.path.isfile(path):
  961. config = _parse_load_config(configs)
  962. exception_type = pickle.UnpicklingError
  963. try:
  964. with _open_file_buffer(path, 'rb') as f:
  965. # When value of dict is lager than 4GB ,there is a Bug on 'MAC python3'
  966. if (
  967. _is_file_path(path)
  968. and sys.platform == 'darwin'
  969. and sys.version_info.major == 3
  970. ):
  971. load_result = _pickle_loads_mac(path, f)
  972. else:
  973. load_result = pickle.load(f, encoding='latin1')
  974. # TODO(weixin):If `obj` is any object, the judgment condition should be more precise.
  975. if isinstance(load_result, dict):
  976. load_result = _pack_loaded_dict(load_result)
  977. # paddle2.0: paddle.save/load
  978. if "StructuredToParameterName@@" in load_result:
  979. for key, name in load_result[
  980. "StructuredToParameterName@@"
  981. ].items():
  982. if isinstance(load_result[key], np.ndarray):
  983. load_result[key] = _ndarray_to_tensor(
  984. load_result[key], config.return_numpy
  985. )
  986. # default name is "generatedxxx" which is set in Tensor init, if not set
  987. if not config.return_numpy and getattr(
  988. load_result[key], "name", ""
  989. ):
  990. load_result[key].name = name
  991. if (
  992. not config.keep_name_table
  993. and "StructuredToParameterName@@" in load_result
  994. ):
  995. del load_result["StructuredToParameterName@@"]
  996. else:
  997. # paddle2.1 static.save/load
  998. load_result = _parse_load_result(
  999. load_result, config.return_numpy
  1000. )
  1001. else:
  1002. load_result = _parse_load_result(
  1003. load_result, config.return_numpy
  1004. )
  1005. except exception_type as msg_pickle:
  1006. try:
  1007. tensor, _ = _load_selected_rows(path)
  1008. return tensor
  1009. except:
  1010. try:
  1011. tensor, _ = _load_lod_tensor(path)
  1012. if config.return_numpy:
  1013. p = core.Place()
  1014. p.set_place(paddle.CPUPlace())
  1015. if tensor._place().is_custom_place():
  1016. return np.array(
  1017. paddle._C_ops.npu_identity(tensor, -1)._copy(p)
  1018. )
  1019. else:
  1020. return np.array(tensor._copy(p))
  1021. else:
  1022. if in_dygraph_mode():
  1023. return _lod_tensor2varbase(tensor)
  1024. return tensor
  1025. except:
  1026. try:
  1027. if in_pir_mode():
  1028. program = paddle.static.Program()
  1029. paddle.core.deserialize_pir_program(
  1030. path, program, 1
  1031. )
  1032. return program
  1033. with _open_file_buffer(path, "rb") as f:
  1034. program_desc_str = f.read()
  1035. program = Program.parse_from_string(
  1036. program_desc_str
  1037. )
  1038. return program
  1039. except:
  1040. raise ValueError(
  1041. f"`paddle.load` can not parse the file:{path}."
  1042. )
  1043. else:
  1044. load_result = _legacy_load(path, **configs)
  1045. return load_result
  1046. def _legacy_load(path, **configs):
  1047. load_result = None
  1048. config = _parse_load_config(configs)
  1049. if os.path.isfile(path) or _is_memory_buffer(path):
  1050. # we think path is file means this file is created by paddle.save
  1051. with _open_file_buffer(path, 'rb') as f:
  1052. load_result = pickle.load(f, encoding='latin1')
  1053. load_result = _pack_loaded_dict(load_result)
  1054. if (
  1055. not config.keep_name_table
  1056. and "StructuredToParameterName@@" in load_result
  1057. ):
  1058. del load_result["StructuredToParameterName@@"]
  1059. else:
  1060. # file prefix and directory are compatible cases
  1061. model_path, config = _build_load_path_and_config(path, config)
  1062. # check whether model file exists
  1063. if config.model_filename is None:
  1064. model_filename = '__model__'
  1065. else:
  1066. model_filename = config.model_filename
  1067. model_file_path = os.path.join(model_path, model_filename)
  1068. if os.path.exists(model_file_path):
  1069. # Load state dict by `jit.save/io.save_inference_model` save format
  1070. # NOTE(chenweihang): [ Compatibility of save_inference_model save format ]
  1071. # The model saved by `save_inference_model` does not completely correspond to
  1072. # the information required by the `state_dict` under the dygraph.
  1073. # `save_inference_model` not save structured name, we need to remind
  1074. # the user to configure the `use_structured_name` argument when `set_state_dict`
  1075. # NOTE(chenweihang): `jit.save` doesn't save optimizer state
  1076. load_result = _load_state_dict_from_save_inference_model(
  1077. model_path, config
  1078. )
  1079. else:
  1080. # load state dict by `io.save_params/persistables` save format
  1081. # TODO(chenweihang): [ Now only supports loading parameters separately ]
  1082. # If users save all parameters as one file, the [ variable.name -> variable ]
  1083. # mapping info will lost, so users need to give variable list, but users build
  1084. # variable list in dygraph mode is difficult, we recommend users to use
  1085. # paddle.static.load_program_state in this case
  1086. load_result = _load_state_dict_from_save_params(model_path)
  1087. return load_result