io.py 74 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693694695696697698699700701702703704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862863864865866867868869870871872873874875876877878879880881882883884885886887888889890891892893894895896897898899900901902903904905906907908909910911912913914915916917918919920921922923924925926927928929930931932933934935936937938939940941942943944945946947948949950951952953954955956957958959960961962963964965966967968969970971972973974975976977978979980981982983984985986987988989990991992993994995996997998999100010011002100310041005100610071008100910101011101210131014101510161017101810191020102110221023102410251026102710281029103010311032103310341035103610371038103910401041104210431044104510461047104810491050105110521053105410551056105710581059106010611062106310641065106610671068106910701071107210731074107510761077107810791080108110821083108410851086108710881089109010911092109310941095109610971098109911001101110211031104110511061107110811091110111111121113111411151116111711181119112011211122112311241125112611271128112911301131113211331134113511361137113811391140114111421143114411451146114711481149115011511152115311541155115611571158115911601161116211631164116511661167116811691170117111721173117411751176117711781179118011811182118311841185118611871188118911901191119211931194119511961197119811991200120112021203120412051206120712081209121012111212121312141215121612171218121912201221122212231224122512261227122812291230123112321233123412351236123712381239124012411242124312441245124612471248124912501251125212531254125512561257125812591260126112621263126412651266126712681269127012711272127312741275127612771278127912801281128212831284128512861287128812891290129112921293129412951296129712981299130013011302130313041305130613071308130913101311131213131314131513161317131813191320132113221323132413251326132713281329133013311332133313341335133613371338133913401341134213431344134513461347134813491350135113521353135413551356135713581359136013611362136313641365136613671368136913701371137213731374137513761377137813791380138113821383138413851386138713881389139013911392139313941395139613971398139914001401140214031404140514061407140814091410141114121413141414151416141714181419142014211422142314241425142614271428142914301431143214331434143514361437143814391440144114421443144414451446144714481449145014511452145314541455145614571458145914601461146214631464146514661467146814691470147114721473147414751476147714781479148014811482148314841485148614871488148914901491149214931494149514961497149814991500150115021503150415051506150715081509151015111512151315141515151615171518151915201521152215231524152515261527152815291530153115321533153415351536153715381539154015411542154315441545154615471548154915501551155215531554155515561557155815591560156115621563156415651566156715681569157015711572157315741575157615771578157915801581158215831584158515861587158815891590159115921593159415951596159715981599160016011602160316041605160616071608160916101611161216131614161516161617161816191620162116221623162416251626162716281629163016311632163316341635163616371638163916401641164216431644164516461647164816491650165116521653165416551656165716581659166016611662166316641665166616671668166916701671167216731674167516761677167816791680168116821683168416851686168716881689169016911692169316941695169616971698169917001701170217031704170517061707170817091710171117121713171417151716171717181719172017211722172317241725172617271728172917301731173217331734173517361737173817391740174117421743174417451746174717481749175017511752175317541755175617571758175917601761176217631764176517661767176817691770177117721773177417751776177717781779178017811782178317841785178617871788178917901791179217931794179517961797179817991800180118021803180418051806180718081809181018111812181318141815181618171818181918201821182218231824182518261827182818291830183118321833183418351836183718381839184018411842184318441845184618471848184918501851185218531854185518561857185818591860186118621863186418651866186718681869187018711872187318741875187618771878187918801881188218831884188518861887188818891890189118921893189418951896189718981899190019011902190319041905190619071908190919101911191219131914191519161917191819191920192119221923192419251926192719281929193019311932193319341935193619371938193919401941194219431944194519461947194819491950195119521953195419551956195719581959196019611962196319641965196619671968196919701971197219731974197519761977197819791980198119821983198419851986198719881989199019911992199319941995
  1. # Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. import errno
  15. import inspect
  16. import logging
  17. import os
  18. import pickle
  19. import sys
  20. import warnings
  21. import numpy as np
  22. import paddle
  23. from paddle.base import (
  24. Program,
  25. Variable,
  26. core,
  27. default_main_program,
  28. program_guard,
  29. unique_name,
  30. )
  31. from paddle.base.executor import Executor, global_scope
  32. from paddle.base.framework import (
  33. Parameter,
  34. dygraph_not_support,
  35. in_pir_mode,
  36. process_type_promotion,
  37. static_only,
  38. )
  39. from paddle.base.log_helper import get_logger
  40. from paddle.framework.io_utils import (
  41. _clone_var_in_block_,
  42. _load_program_scope,
  43. _pack_loaded_dict,
  44. _pickle_loads_mac,
  45. _unpack_saved_dict,
  46. is_belong_to_optimizer,
  47. is_parameter,
  48. is_persistable,
  49. )
  50. from .io_utils import (
  51. _check_args,
  52. _check_vars,
  53. _get_valid_program,
  54. _normalize_path_prefix,
  55. _safe_load_pickle,
  56. )
  57. from .pir_io import (
  58. get_pir_parameters,
  59. load_pir,
  60. load_pir_inference_model,
  61. load_vars_pir,
  62. normalize_pir_program,
  63. save_pir,
  64. save_pir_inference_model,
  65. save_vars_pir,
  66. )
  67. __all__ = []
  68. _logger = get_logger(
  69. __name__, logging.INFO, fmt='%(asctime)s-%(levelname)s: %(message)s'
  70. )
  71. def _clone_var_in_block(block, var):
  72. assert isinstance(var, Variable)
  73. if var.desc.type() == core.VarDesc.VarType.LOD_TENSOR:
  74. return block.create_var(
  75. name=var.name,
  76. shape=var.shape,
  77. dtype=var.dtype,
  78. type=var.type,
  79. lod_level=var.lod_level,
  80. persistable=True,
  81. )
  82. else:
  83. return block.create_var(
  84. name=var.name,
  85. shape=var.shape,
  86. dtype=var.dtype,
  87. type=var.type,
  88. persistable=True,
  89. )
  90. def prepend_feed_ops(
  91. inference_program, feed_target_names, feed_holder_name='feed'
  92. ):
  93. if len(feed_target_names) == 0:
  94. return
  95. global_block = inference_program.global_block()
  96. feed_var = global_block.create_var(
  97. name=feed_holder_name,
  98. type=core.VarDesc.VarType.FEED_MINIBATCH,
  99. persistable=True,
  100. )
  101. for i, name in enumerate(feed_target_names):
  102. if not global_block.has_var(name):
  103. raise ValueError(
  104. f"The feeded_var_names[{i}]: '{name}' doesn't exist in pruned inference program. "
  105. f"Please check whether '{name}' is a valid feed_var name, or remove it from feeded_var_names "
  106. f"if '{name}' is not involved in the target_vars calculation."
  107. )
  108. out = global_block.var(name)
  109. global_block._prepend_op(
  110. type='feed',
  111. inputs={'X': [feed_var]},
  112. outputs={'Out': [out]},
  113. attrs={'col': i},
  114. )
  115. def append_fetch_ops(
  116. inference_program, fetch_target_names, fetch_holder_name='fetch'
  117. ):
  118. global_block = inference_program.global_block()
  119. fetch_var = global_block.create_var(
  120. name=fetch_holder_name,
  121. type=core.VarDesc.VarType.FETCH_LIST,
  122. persistable=True,
  123. )
  124. for i, name in enumerate(fetch_target_names):
  125. global_block.append_op(
  126. type='fetch',
  127. inputs={'X': [name]},
  128. outputs={'Out': [fetch_var]},
  129. attrs={'col': i},
  130. )
  131. def normalize_program(program, feed_vars, fetch_vars, **kwargs):
  132. """
  133. Normalize/Optimize a program according to feed_vars and fetch_vars.
  134. Args:
  135. program(Program): Specify a program you want to optimize.
  136. feed_vars(Tensor | list[Tensor]): Variables needed by inference.
  137. fetch_vars(Tensor | list[Tensor]): Variables returned by inference.
  138. kwargs: Supported keys including ``skip_prune_program``.
  139. - skip_prune_program(bool): whether to skip pruning program. Defaults to False.
  140. Returns:
  141. Program: Normalized/Optimized program.
  142. Examples:
  143. .. code-block:: python
  144. >>> import paddle
  145. >>> paddle.enable_static()
  146. >>> path_prefix = "./infer_model"
  147. # User defined network, here a softmax regression example
  148. >>> image = paddle.static.data(name='img', shape=[None, 28, 28], dtype='float32')
  149. >>> label = paddle.static.data(name='label', shape=[None, 1], dtype='int64')
  150. >>> predict = paddle.static.nn.fc(image, 10, activation='softmax')
  151. >>> loss = paddle.nn.functional.cross_entropy(predict, label)
  152. >>> exe = paddle.static.Executor(paddle.CPUPlace())
  153. >>> exe.run(paddle.static.default_startup_program())
  154. # normalize main program.
  155. >>> program = paddle.static.default_main_program()
  156. >>> normalized_program = paddle.static.normalize_program(program, [image], [predict])
  157. """
  158. if in_pir_mode():
  159. return normalize_pir_program(program, feed_vars, fetch_vars, **kwargs)
  160. if not isinstance(program, Program):
  161. raise TypeError(
  162. "program type must be `base.Program`, but received `%s`"
  163. % type(program)
  164. )
  165. if not isinstance(feed_vars, list):
  166. feed_vars = [feed_vars]
  167. if not all(isinstance(v, Variable) for v in feed_vars):
  168. raise TypeError(
  169. "feed_vars type must be a Variable or a list of Variable."
  170. )
  171. if not isinstance(fetch_vars, list):
  172. fetch_vars = [fetch_vars]
  173. if not all(isinstance(v, Variable) for v in fetch_vars):
  174. raise TypeError(
  175. "fetch_vars type must be a Variable or a list of Variable."
  176. )
  177. # remind users to set auc_states to 0 if auc op were found.
  178. for op in program.global_block().ops:
  179. # clear device of Op
  180. device_attr_name = core.op_proto_and_checker_maker.kOpDeviceAttrName()
  181. op._set_attr(device_attr_name, "")
  182. if op.type == 'auc':
  183. warnings.warn(
  184. "Be sure that you have set auc states to 0 before saving inference model."
  185. )
  186. break
  187. # fix the bug that the activation op's output as target will be pruned.
  188. # will affect the inference performance.
  189. # TODO(Superjomn) add an IR pass to remove 1-scale op.
  190. with program_guard(program):
  191. uniq_fetch_vars = []
  192. for i, var in enumerate(fetch_vars):
  193. if var.dtype != paddle.bool:
  194. var = paddle.scale(var, 1.0, name=f"save_infer_model/scale_{i}")
  195. uniq_fetch_vars.append(var)
  196. fetch_vars = uniq_fetch_vars
  197. # serialize program
  198. copy_program = program.clone()
  199. global_block = copy_program.global_block()
  200. remove_op_idx = []
  201. for i, op in enumerate(global_block.ops):
  202. op.desc.set_is_target(False)
  203. if op.type == "feed" or op.type == "fetch":
  204. remove_op_idx.append(i)
  205. if op.type == "pylayer":
  206. sub_blocks_ids = op._blocks_attr_ids("blocks")
  207. if len(sub_blocks_ids) > 1:
  208. # pylayer op ``blocks`` attr contains forward block id and backward block id
  209. backward_block_id = sub_blocks_ids[-1]
  210. # remove backward block
  211. copy_program.blocks.pop(backward_block_id)
  212. # update attrs ``blocks``
  213. reserved_blocks = []
  214. for block_id in sub_blocks_ids[:-1]:
  215. reserved_blocks.append(copy_program.block(block_id))
  216. op._update_desc_attr("blocks", reserved_blocks)
  217. for idx in remove_op_idx[::-1]:
  218. global_block._remove_op(idx)
  219. copy_program.desc.flush()
  220. feed_var_names = [var.name for var in feed_vars]
  221. skip_prune_program = kwargs.get('skip_prune_program', False)
  222. if not skip_prune_program:
  223. copy_program = copy_program._prune_with_input(
  224. feeded_var_names=feed_var_names, targets=fetch_vars
  225. )
  226. copy_program = copy_program._inference_optimize(prune_read_op=True)
  227. fetch_var_names = [var.name for var in fetch_vars]
  228. prepend_feed_ops(copy_program, feed_var_names)
  229. append_fetch_ops(copy_program, fetch_var_names)
  230. copy_program.desc._set_version()
  231. return copy_program
  232. @static_only
  233. def serialize_program(feed_vars, fetch_vars, **kwargs):
  234. """
  235. Serialize default main program according to feed_vars and fetch_vars.
  236. Args:
  237. feed_vars(Tensor | list[Tensor]): Tensor needed by inference.
  238. fetch_vars(Tensor | list[Tensor]): Tensor returned by inference.
  239. kwargs: Supported keys including ``program``. Attention please, kwargs is used for backward compatibility mainly.
  240. - program(Program): specify a program if you don't want to use default main program.
  241. - legacy_format(bool): whether to save inference program in legacy format. Defaults to False.
  242. Returns:
  243. bytes: serialized program.
  244. Examples:
  245. .. code-block:: python
  246. >>> import paddle
  247. >>> paddle.enable_static()
  248. >>> path_prefix = "./infer_model"
  249. # User defined network, here a softmax regression example
  250. >>> image = paddle.static.data(name='img', shape=[None, 28, 28], dtype='float32')
  251. >>> label = paddle.static.data(name='label', shape=[None, 1], dtype='int64')
  252. >>> predict = paddle.static.nn.fc(image, 10, activation='softmax')
  253. >>> loss = paddle.nn.functional.cross_entropy(predict, label)
  254. >>> exe = paddle.static.Executor(paddle.CPUPlace())
  255. >>> exe.run(paddle.static.default_startup_program())
  256. # serialize the default main program to bytes.
  257. >>> serialized_program = paddle.static.serialize_program([image], [predict])
  258. # deserialize bytes to program
  259. >>> deserialized_program = paddle.static.deserialize_program(serialized_program)
  260. """
  261. # verify feed_vars
  262. _check_vars('feed_vars', feed_vars)
  263. # verify fetch_vars
  264. _check_vars('fetch_vars', fetch_vars)
  265. program = _get_valid_program(kwargs.get('program', None))
  266. program = normalize_program(program, feed_vars, fetch_vars)
  267. legacy_format = kwargs.get('legacy_format', False)
  268. return _serialize_program(program, legacy_format=legacy_format)
  269. def _serialize_program(program, legacy_format=False):
  270. """
  271. serialize given program to bytes.
  272. """
  273. return program.desc.serialize_to_string(legacy_format=legacy_format)
  274. @static_only
  275. def serialize_persistables(feed_vars, fetch_vars, executor, **kwargs):
  276. """
  277. Serialize parameters using given executor and default main program according to feed_vars and fetch_vars.
  278. Args:
  279. feed_vars(Tensor | list[Tensor]): Tensor needed by inference.
  280. fetch_vars(Tensor | list[Tensor]): Tensor returned by inference.
  281. kwargs: Supported keys including ``program``. Attention please, kwargs is used for backward compatibility mainly.
  282. - program(Program): specify a program if you don't want to use default main program.
  283. Returns:
  284. bytes: serialized program.
  285. Examples:
  286. .. code-block:: python
  287. >>> import paddle
  288. >>> paddle.enable_static()
  289. >>> path_prefix = "./infer_model"
  290. # User defined network, here a softmax regression example
  291. >>> image = paddle.static.data(name='img', shape=[None, 28, 28], dtype='float32')
  292. >>> label = paddle.static.data(name='label', shape=[None, 1], dtype='int64')
  293. >>> predict = paddle.static.nn.fc(image, 10, activation='softmax')
  294. >>> loss = paddle.nn.functional.cross_entropy(predict, label)
  295. >>> exe = paddle.static.Executor(paddle.CPUPlace())
  296. >>> exe.run(paddle.static.default_startup_program())
  297. # serialize parameters to bytes.
  298. >>> serialized_params = paddle.static.serialize_persistables([image], [predict], exe)
  299. # deserialize bytes to parameters.
  300. >>> main_program = paddle.static.default_main_program()
  301. >>> deserialized_params = paddle.static.deserialize_persistables(main_program, serialized_params, exe)
  302. """
  303. # verify feed_vars
  304. _check_vars('feed_vars', feed_vars)
  305. # verify fetch_vars
  306. _check_vars('fetch_vars', fetch_vars)
  307. program = _get_valid_program(kwargs.get('program', None))
  308. program = normalize_program(program, feed_vars, fetch_vars)
  309. return _serialize_persistables(program, executor)
  310. def _serialize_persistables(program, executor):
  311. """
  312. Serialize parameters using given program and executor.
  313. """
  314. vars_ = list(filter(is_persistable, program.list_vars()))
  315. # warn if no variable found in model
  316. if len(vars_) == 0:
  317. warnings.warn(
  318. "no variable in your model, please ensure there are any "
  319. "variables in your model to save"
  320. )
  321. return None
  322. # create a new program and clone persistable vars to it
  323. save_program = Program()
  324. save_block = save_program.global_block()
  325. save_var_map = {}
  326. for var in vars_:
  327. if var.type != core.VarDesc.VarType.RAW:
  328. var_copy = _clone_var_in_block(save_block, var)
  329. save_var_map[var_copy.name] = var
  330. # create in_vars and out_var, then append a save_combine op to save_program
  331. in_vars = []
  332. for name in sorted(save_var_map.keys()):
  333. in_vars.append(save_var_map[name])
  334. out_var_name = unique_name.generate("out_var")
  335. out_var = save_block.create_var(
  336. type=core.VarDesc.VarType.RAW, name=out_var_name
  337. )
  338. out_var.desc.set_persistable(True)
  339. save_block.append_op(
  340. type='save_combine',
  341. inputs={'X': in_vars},
  342. outputs={'Y': out_var},
  343. attrs={'file_path': '', 'save_to_memory': True},
  344. )
  345. # run save_program to save vars
  346. # NOTE(zhiqiu): save op will add variable kLookupTablePath to save_program.desc,
  347. # which leads to diff between save_program and its desc. Call _sync_with_cpp
  348. # to keep consistency.
  349. save_program._sync_with_cpp()
  350. executor.run(save_program)
  351. # return serialized bytes in out_var
  352. return global_scope().find_var(out_var_name).get_bytes()
  353. def save_to_file(path, content):
  354. """
  355. Save content to given path.
  356. Args:
  357. path(str): Path to write content to.
  358. content(bytes): Content to write.
  359. Returns:
  360. None
  361. Examples:
  362. .. code-block:: python
  363. >>> import paddle
  364. >>> paddle.enable_static()
  365. >>> path_prefix = "./infer_model"
  366. # 用户自定义网络,此处用 softmax 回归为例。
  367. >>> image = paddle.static.data(name='img', shape=[None, 28, 28], dtype='float32')
  368. >>> label = paddle.static.data(name='label', shape=[None, 1], dtype='int64')
  369. >>> predict = paddle.static.nn.fc(image, 10, activation='softmax')
  370. >>> loss = paddle.nn.functional.cross_entropy(predict, label)
  371. >>> exe = paddle.static.Executor(paddle.CPUPlace())
  372. >>> exe.run(paddle.static.default_startup_program())
  373. # 序列化参数
  374. >>> serialized_params = paddle.static.serialize_persistables([image], [predict], exe)
  375. # 将序列化之后的参数保存到文件
  376. >>> params_path = path_prefix + ".params"
  377. >>> paddle.static.save_to_file(params_path, serialized_params)
  378. """
  379. if not isinstance(content, bytes):
  380. raise ValueError("'content' type should be bytes.")
  381. with open(path, "wb") as f:
  382. f.write(content)
  383. @static_only
  384. def save_inference_model(
  385. path_prefix, feed_vars, fetch_vars, executor, **kwargs
  386. ):
  387. """
  388. Save current model and its parameters to given path. i.e.
  389. Given ``path_prefix = "PATH/modelname"``, after invoking
  390. ``save_inference_model(path_prefix, feed_vars, fetch_vars, executor)``,
  391. you will find two files named ``modelname.pdmodel`` and ``modelname.pdiparams``
  392. under ``PATH``, which represent your model and parameters respectively.
  393. Args:
  394. path_prefix(str): Directory path to save model + model name without suffix.
  395. feed_vars(Tensor | list[Tensor]): Variables needed by inference.
  396. fetch_vars(Tensor | list[Tensor]): Variables returned by inference.
  397. executor(Executor): The executor that saves the inference model. You can refer
  398. to :ref:`api_guide_executor_en` for more details.
  399. kwargs: Supported keys including 'program' and "clip_extra". Attention please, kwargs is used for backward compatibility mainly.
  400. - program(Program): specify a program if you don't want to use default main program.
  401. - clip_extra(bool): the flag indicating whether to clip extra information for every operator. Default: True.
  402. - legacy_format(bool): whether to save inference model in legacy format. Default: False.
  403. Returns:
  404. None
  405. Examples:
  406. .. code-block:: python
  407. >>> import paddle
  408. >>> paddle.enable_static()
  409. >>> path_prefix = "./infer_model"
  410. # User defined network, here a softmax regression example
  411. >>> image = paddle.static.data(name='img', shape=[None, 28, 28], dtype='float32')
  412. >>> label = paddle.static.data(name='label', shape=[None, 1], dtype='int64')
  413. >>> predict = paddle.static.nn.fc(image, 10, activation='softmax')
  414. >>> loss = paddle.nn.functional.cross_entropy(predict, label)
  415. >>> exe = paddle.static.Executor(paddle.CPUPlace())
  416. >>> exe.run(paddle.static.default_startup_program())
  417. # Feed data and train process
  418. # Save inference model. Note we don't save label and loss in this example
  419. >>> paddle.static.save_inference_model(path_prefix, [image], [predict], exe)
  420. # In this example, the save_inference_mode inference will prune the default
  421. # main program according to the network's input node (img) and output node(predict).
  422. # The pruned inference program is going to be saved in file "./infer_model.pdmodel"
  423. # and parameters are going to be saved in file "./infer_model.pdiparams".
  424. """
  425. if in_pir_mode():
  426. save_pir_inference_model(
  427. path_prefix, feed_vars, fetch_vars, executor, **kwargs
  428. )
  429. return
  430. # check path_prefix, set model_path and params_path
  431. path_prefix = _normalize_path_prefix(path_prefix)
  432. try:
  433. # mkdir may conflict if pserver and trainer are running on the same machine
  434. dirname = os.path.dirname(path_prefix)
  435. os.makedirs(dirname)
  436. except OSError as e:
  437. if e.errno != errno.EEXIST:
  438. raise
  439. model_path = path_prefix + ".pdmodel"
  440. params_path = path_prefix + ".pdiparams"
  441. if os.path.isdir(model_path):
  442. raise ValueError(f"'{model_path}' is an existing directory.")
  443. if os.path.isdir(params_path):
  444. raise ValueError(f"'{params_path}' is an existing directory.")
  445. # verify feed_vars
  446. _check_vars('feed_vars', feed_vars)
  447. # verify fetch_vars
  448. _check_vars('fetch_vars', fetch_vars)
  449. program = _get_valid_program(kwargs.get('program', None))
  450. # do type promotion
  451. program = process_type_promotion(program)
  452. clip_extra = kwargs.get('clip_extra', True)
  453. # serialize and save program
  454. program = normalize_program(
  455. program,
  456. feed_vars,
  457. fetch_vars,
  458. skip_prune_program=kwargs.get('skip_prune_program', False),
  459. )
  460. legacy_format = kwargs.get('legacy_format', False)
  461. program_bytes = _serialize_program(
  462. program._remove_training_info(clip_extra=clip_extra),
  463. legacy_format=legacy_format,
  464. )
  465. save_to_file(model_path, program_bytes)
  466. vars = list(filter(is_persistable, program.list_vars()))
  467. if len(list(vars)) == 0:
  468. warnings.warn(
  469. "no variable in your model, please ensure there are any variables in your model to save"
  470. )
  471. if len(vars) > 0:
  472. save_dirname = os.path.dirname(params_path)
  473. params_filename = os.path.basename(params_path)
  474. save_vars(
  475. executor,
  476. dirname=save_dirname,
  477. main_program=program,
  478. predicate=is_persistable,
  479. filename=params_filename,
  480. )
  481. @static_only
  482. def deserialize_program(data):
  483. """
  484. Deserialize given data to a program.
  485. Args:
  486. data(bytes): serialized program.
  487. Returns:
  488. Program: deserialized program.
  489. Examples:
  490. .. code-block:: python
  491. >>> import paddle
  492. >>> paddle.enable_static()
  493. >>> path_prefix = "./infer_model"
  494. # User defined network, here a softmax regression example
  495. >>> image = paddle.static.data(name='img', shape=[None, 28, 28], dtype='float32')
  496. >>> label = paddle.static.data(name='label', shape=[None, 1], dtype='int64')
  497. >>> predict = paddle.static.nn.fc(image, 10, activation='softmax')
  498. >>> loss = paddle.nn.functional.cross_entropy(predict, label)
  499. >>> exe = paddle.static.Executor(paddle.CPUPlace())
  500. >>> exe.run(paddle.static.default_startup_program())
  501. # serialize the default main program to bytes.
  502. >>> serialized_program = paddle.static.serialize_program([image], [predict])
  503. # deserialize bytes to program
  504. >>> deserialized_program = paddle.static.deserialize_program(serialized_program)
  505. """
  506. program = Program.parse_from_string(data)
  507. if not core._is_program_version_supported(program._version()):
  508. raise ValueError(
  509. "Unsupported program version: %d\n" % program._version()
  510. )
  511. return program
  512. # NOTE(liuyuanle): Due to load from memory, deserialize_persistables does not support loading weights with file sizes exceeding 2GB.
  513. @static_only
  514. def deserialize_persistables(program, data, executor):
  515. """
  516. Deserialize given data to parameters according to given program and executor.
  517. Args:
  518. program(Program): program that contains parameter names (to deserialize).
  519. data(bytes): serialized parameters.
  520. executor(Executor): executor used to run load op.
  521. Returns:
  522. Program: deserialized program.
  523. Examples:
  524. .. code-block:: python
  525. >>> import paddle
  526. >>> paddle.enable_static()
  527. >>> path_prefix = "./infer_model"
  528. # User defined network, here a softmax regression example
  529. >>> image = paddle.static.data(name='img', shape=[None, 28, 28], dtype='float32')
  530. >>> label = paddle.static.data(name='label', shape=[None, 1], dtype='int64')
  531. >>> predict = paddle.static.nn.fc(image, 10, activation='softmax')
  532. >>> loss = paddle.nn.functional.cross_entropy(predict, label)
  533. >>> exe = paddle.static.Executor(paddle.CPUPlace())
  534. >>> exe.run(paddle.static.default_startup_program())
  535. # serialize parameters to bytes.
  536. >>> serialized_params = paddle.static.serialize_persistables([image], [predict], exe)
  537. # deserialize bytes to parameters.
  538. >>> main_program = paddle.static.default_main_program()
  539. >>> deserialized_params = paddle.static.deserialize_persistables(main_program, serialized_params, exe)
  540. """
  541. if not isinstance(program, Program):
  542. raise TypeError(
  543. "program type must be `base.Program`, but received `%s`"
  544. % type(program)
  545. )
  546. # load params to a tmp program
  547. load_program = Program()
  548. load_block = load_program.global_block()
  549. vars_ = list(filter(is_persistable, program.list_vars()))
  550. origin_shape_map = {}
  551. load_var_map = {}
  552. check_vars = []
  553. sparse_vars = []
  554. for var in vars_:
  555. assert isinstance(var, Variable)
  556. if var.type == core.VarDesc.VarType.RAW:
  557. continue
  558. if isinstance(var, Parameter):
  559. origin_shape_map[var.name] = tuple(var.desc.get_shape())
  560. if var.type == core.VarDesc.VarType.SELECTED_ROWS:
  561. sparse_vars.append(var)
  562. continue
  563. var_copy = _clone_var_in_block(load_block, var)
  564. check_vars.append(var)
  565. load_var_map[var_copy.name] = var_copy
  566. if data is None:
  567. assert (
  568. len(origin_shape_map) == 0
  569. ), "Required 'data' shall be not None if program contains parameter, but received 'data' is None."
  570. return
  571. # append load_combine op to load parameters,
  572. load_var_list = []
  573. for name in sorted(load_var_map.keys()):
  574. load_var_list.append(load_var_map[name])
  575. load_block.append_op(
  576. type='load_combine',
  577. inputs={},
  578. outputs={"Out": load_var_list},
  579. # if load from memory, file_path is data
  580. attrs={'file_path': data, 'model_from_memory': True},
  581. )
  582. executor.run(load_program)
  583. # check var shape
  584. for var in check_vars:
  585. if not isinstance(var, Parameter):
  586. continue
  587. var_tmp = paddle.base.global_scope().find_var(var.name)
  588. assert var_tmp is not None, "can't not find var: " + var.name
  589. new_shape = (np.array(var_tmp.get_tensor())).shape
  590. assert var.name in origin_shape_map, var.name + " MUST in var list."
  591. origin_shape = origin_shape_map.get(var.name)
  592. if new_shape != origin_shape:
  593. raise RuntimeError(
  594. f"Shape mismatch, program needs a parameter with shape ({origin_shape}), "
  595. f"but the loaded parameter ('{var.name}') has a shape of ({new_shape})."
  596. )
  597. def load_from_file(path):
  598. """
  599. Load file in binary mode.
  600. Args:
  601. path(str): Path of an existed file.
  602. Returns:
  603. bytes: Content of file.
  604. Examples:
  605. .. code-block:: python
  606. >>> import paddle
  607. >>> paddle.enable_static()
  608. >>> path_prefix = "./infer_model"
  609. # 用户自定义网络,此处用 softmax 回归为例。
  610. >>> image = paddle.static.data(name='img', shape=[None, 28, 28], dtype='float32')
  611. >>> label = paddle.static.data(name='label', shape=[None, 1], dtype='int64')
  612. >>> predict = paddle.static.nn.fc(image, 10, activation='softmax')
  613. >>> loss = paddle.nn.functional.cross_entropy(predict, label)
  614. >>> exe = paddle.static.Executor(paddle.CPUPlace())
  615. >>> exe.run(paddle.static.default_startup_program())
  616. # 序列化参数
  617. >>> serialized_params = paddle.static.serialize_persistables([image], [predict], exe)
  618. # 将序列化之后的参数保存到文件
  619. >>> params_path = path_prefix + ".params"
  620. >>> paddle.static.save_to_file(params_path, serialized_params)
  621. # 从文件加载序列化之后的参数
  622. >>> serialized_params_copy = paddle.static.load_from_file(params_path)
  623. """
  624. with open(path, 'rb') as f:
  625. data = f.read()
  626. return data
  627. @static_only
  628. def load_inference_model(path_prefix, executor, **kwargs):
  629. """
  630. Load inference model from a given path. By this API, you can get the model
  631. structure(Inference Program) and model parameters.
  632. Args:
  633. path_prefix(str | None): One of the following:
  634. - Directory path to save model + model name without suffix.
  635. - Set to None when reading the model from memory.
  636. executor(Executor): The executor to run for loading inference model.
  637. See :ref:`api_guide_executor_en` for more details about it.
  638. kwargs: Supported keys including 'model_filename', 'params_filename'. Attention please, kwargs is used for backward compatibility mainly.
  639. - model_filename(str): specify model_filename if you don't want to use default name.
  640. - params_filename(str): specify params_filename if you don't want to use default name.
  641. Returns:
  642. list: The return of this API is a list with three elements:
  643. (program, feed_target_names, fetch_targets). The `program` is a
  644. ``Program`` (refer to :ref:`api_guide_Program_en`), which is used for inference.
  645. The `feed_target_names` is a list of ``str``, which contains names of variables
  646. that need to feed data in the inference program. The `fetch_targets` is a list of
  647. ``Variable`` (refer to :ref:`api_guide_Program_en`). It contains variables from which
  648. we can get inference results.
  649. Examples:
  650. .. code-block:: python
  651. >>> import paddle
  652. >>> import numpy as np
  653. >>> paddle.enable_static()
  654. # Build the model
  655. >>> startup_prog = paddle.static.default_startup_program()
  656. >>> main_prog = paddle.static.default_main_program()
  657. >>> with paddle.static.program_guard(main_prog, startup_prog):
  658. ... image = paddle.static.data(name="img", shape=[64, 784])
  659. ... w = paddle.create_parameter(shape=[784, 200], dtype='float32')
  660. ... b = paddle.create_parameter(shape=[200], dtype='float32')
  661. ... hidden_w = paddle.matmul(x=image, y=w)
  662. ... hidden_b = paddle.add(hidden_w, b)
  663. >>> exe = paddle.static.Executor(paddle.CPUPlace())
  664. >>> exe.run(startup_prog)
  665. # Save the inference model
  666. >>> path_prefix = "./infer_model"
  667. >>> paddle.static.save_inference_model(path_prefix, [image], [hidden_b], exe)
  668. >>> [inference_program, feed_target_names, fetch_targets] = (
  669. ... paddle.static.load_inference_model(path_prefix, exe))
  670. >>> tensor_img = np.array(np.random.random((64, 784)), dtype=np.float32)
  671. >>> results = exe.run(inference_program,
  672. ... feed={feed_target_names[0]: tensor_img},
  673. ... fetch_list=fetch_targets)
  674. # In this example, the inference program was saved in file
  675. # "./infer_model.pdmodel" and parameters were saved in file
  676. # " ./infer_model.pdiparams".
  677. # By the inference program, feed_target_names and
  678. # fetch_targets, we can use an executor to run the inference
  679. # program to get the inference result.
  680. """
  681. if in_pir_mode():
  682. return load_pir_inference_model(path_prefix, executor, **kwargs)
  683. # check kwargs
  684. supported_args = ('model_filename', 'params_filename')
  685. deprecated_args = ('pserver_endpoints',)
  686. caller = inspect.currentframe().f_code.co_name
  687. _check_args(caller, kwargs, supported_args, deprecated_args)
  688. # load from memory
  689. if path_prefix is None:
  690. _logger.warning(
  691. "Load inference model from memory is deprecated. Please specify path_prefix."
  692. )
  693. model_filename = kwargs.get('model_filename', None)
  694. params_filename = kwargs.get('params_filename', None)
  695. if params_filename is None:
  696. raise ValueError(
  697. "params_filename cannot be None when path_prefix is None."
  698. )
  699. program_bytes = model_filename
  700. # deserialize bytes to program
  701. program = deserialize_program(program_bytes)
  702. # do type promotion
  703. program = process_type_promotion(program)
  704. vars = list(filter(is_persistable, program.list_vars()))
  705. if len(vars) > 0:
  706. load_vars(
  707. executor,
  708. # load from memory, dirname is None
  709. dirname=None,
  710. main_program=program,
  711. predicate=is_persistable,
  712. filename=params_filename,
  713. )
  714. # load from file
  715. else:
  716. # check and norm path_prefix
  717. path_prefix = _normalize_path_prefix(path_prefix)
  718. dir_path = os.path.dirname(path_prefix)
  719. if not os.path.isdir(dir_path):
  720. raise ValueError(f"There is no directory named {dir_path}")
  721. # set model_path and params_path in new way,
  722. # path_prefix represents a file path without suffix in this case.
  723. if not kwargs:
  724. model_path = path_prefix + ".pdmodel"
  725. params_path = path_prefix + ".pdiparams"
  726. # set model_path and params_path in old way for compatible,
  727. # path_prefix represents a directory path.
  728. else:
  729. model_filename = kwargs.get('model_filename', None)
  730. params_filename = kwargs.get('params_filename', None)
  731. # set model_path
  732. if model_filename is None:
  733. model_path = os.path.join(path_prefix, "__model__")
  734. else:
  735. model_path = os.path.join(
  736. path_prefix, model_filename + ".pdmodel"
  737. )
  738. if not os.path.exists(model_path):
  739. model_path = os.path.join(path_prefix, model_filename)
  740. # set params_path
  741. if params_filename is None:
  742. params_path = os.path.join(path_prefix, "")
  743. else:
  744. params_path = os.path.join(
  745. path_prefix, params_filename + ".pdiparams"
  746. )
  747. if not os.path.exists(params_path):
  748. params_path = os.path.join(path_prefix, params_filename)
  749. _logger.warning(
  750. "The old way to load inference model is deprecated. Please specify path_prefix."
  751. f" model path: {model_path}, params path: {params_path}"
  752. )
  753. program_bytes = load_from_file(model_path)
  754. # deserialize bytes to program
  755. program = deserialize_program(program_bytes)
  756. # do type promotion
  757. program = process_type_promotion(program)
  758. vars = list(filter(is_persistable, program.list_vars()))
  759. if len(vars) > 0:
  760. load_dirname = os.path.dirname(params_path)
  761. params_filename = os.path.basename(params_path)
  762. load_vars(
  763. executor,
  764. dirname=load_dirname,
  765. main_program=program,
  766. predicate=is_persistable,
  767. filename=params_filename,
  768. )
  769. feed_target_names = program.desc.get_feed_target_names()
  770. fetch_target_names = program.desc.get_fetch_target_names()
  771. fetch_targets = [
  772. program.global_block().var(name) for name in fetch_target_names
  773. ]
  774. return [program, feed_target_names, fetch_targets]
  775. @dygraph_not_support
  776. def save_vars(
  777. executor,
  778. dirname,
  779. main_program=None,
  780. vars=None,
  781. predicate=None,
  782. filename=None,
  783. ):
  784. """
  785. Save specific variables in the `Program` to files.
  786. There are two ways to specify the variables to be saved: set variables in
  787. a list and assign it to the `vars`, or use the `predicate` function to select
  788. variables that make `predicate(variable) == True`. The first way has a higher priority.
  789. The `dirname` is used to specify the folder where to save variables.
  790. If you prefer to save variables in separate files in the `dirname` folder,
  791. do not set `filename`. If you prefer to save all variables in a single file,
  792. use `filename` to specify it.
  793. Args:
  794. executor(Executor): The executor to run for saving variables.
  795. dirname(str, optional): The folder where to save variables.
  796. When you need to save the parameter to the memory, set it to None.
  797. main_program(Program, optional): The program whose variables will be saved.
  798. If it is None, the default main program will
  799. be used automatically.
  800. Default: None
  801. vars(list[Variable], optional): The list contains all variables to be saved.
  802. Default: None
  803. predicate(function, optional): The function selects the variables that make
  804. `predicate(variable) == True`.
  805. Default: None
  806. filename(str, optional): If you prefer to save all variables in a single file,
  807. use `filename` to specify it. Otherwise, let `filename` be None.
  808. Default: None
  809. Returns:
  810. str: When saving parameters to a file, returns None.
  811. When saving parameters to memory, returns a binary string containing parameters.
  812. Raises:
  813. TypeError: If `main_program` is not an instance of Program nor None.
  814. Examples:
  815. .. code-block:: python
  816. >>> import paddle
  817. >>> import paddle.static as static
  818. >>> paddle.enable_static()
  819. >>> main_prog = static.Program()
  820. >>> startup_prog = static.Program()
  821. >>> with static.program_guard(main_prog, startup_prog):
  822. ... data = paddle.static.data(name="img", shape=[64, 784])
  823. ... w = paddle.create_parameter(shape=[784, 200], dtype='float32', name='fc_w')
  824. ... b = paddle.create_parameter(shape=[200], dtype='float32', name='fc_b')
  825. ... hidden_w = paddle.matmul(x=data, y=w)
  826. ... hidden_b = paddle.add(hidden_w, b)
  827. >>> place = static.CPUPlace()
  828. >>> exe = static.Executor(place)
  829. >>> exe.run(startup_prog)
  830. # The first usage: use `vars` to set the saved variables.
  831. >>> var_list = [w, b]
  832. >>> path = "./my_paddle_vars"
  833. # w and b will be save in a file named "var_file".
  834. >>> paddle.static.io.save_vars(executor=exe, dirname=path, vars=var_list,
  835. ... filename="vars_file")
  836. # The second usage: use `predicate` to select the saved variable.
  837. >>> def name_has_fc(var):
  838. ... res = "fc" in var.name
  839. ... return res
  840. >>> param_path = "./my_paddle_model"
  841. # all variables whose names contain "fc " are saved.
  842. >>> paddle.static.io.save_vars(executor=exe, dirname=param_path, main_program=main_prog, vars=None, predicate = name_has_fc)
  843. """
  844. if in_pir_mode():
  845. return save_vars_pir(dirname, main_program, vars, filename)
  846. save_to_memory = False
  847. if dirname is None and filename is None:
  848. save_to_memory = True
  849. main_program = _get_valid_program(main_program)
  850. if vars is None:
  851. return save_vars(
  852. executor,
  853. main_program=main_program,
  854. dirname=dirname,
  855. vars=list(filter(predicate, main_program.list_vars())),
  856. filename=filename,
  857. )
  858. else:
  859. params_var_name = "saved_params"
  860. # give warning when there is no var in model
  861. if len(list(vars)) == 0:
  862. warnings.warn(
  863. "no variable in your model, please ensure there are any variables in your model to save"
  864. )
  865. return None
  866. save_program = Program()
  867. save_block = save_program.global_block()
  868. save_var_map = {}
  869. for each_var in vars:
  870. # NOTE: don't save the variable which type is RAW
  871. if each_var.type == core.VarDesc.VarType.RAW:
  872. continue
  873. new_var = _clone_var_in_block_(save_block, each_var)
  874. if filename is None and save_to_memory is False:
  875. save_file_path = os.path.join(
  876. os.path.normpath(dirname), new_var.name
  877. )
  878. save_block.append_op(
  879. type='save',
  880. inputs={'X': [new_var]},
  881. outputs={},
  882. attrs={'file_path': os.path.normpath(save_file_path)},
  883. )
  884. else:
  885. save_var_map[new_var.name] = new_var
  886. if filename is not None or save_to_memory:
  887. save_var_list = []
  888. for name in sorted(save_var_map.keys()):
  889. save_var_list.append(save_var_map[name])
  890. save_path = ''
  891. if save_to_memory is False:
  892. save_path = os.path.join(os.path.normpath(dirname), filename)
  893. saved_params = save_block.create_var(
  894. type=core.VarDesc.VarType.RAW, name=params_var_name
  895. )
  896. saved_params.desc.set_persistable(True)
  897. save_block.append_op(
  898. type='save_combine',
  899. inputs={'X': save_var_list},
  900. outputs={'Y': saved_params},
  901. attrs={
  902. 'file_path': save_path,
  903. 'save_to_memory': save_to_memory,
  904. },
  905. )
  906. # NOTE(zhiqiu): save op will add variable kLookupTablePath in save_program.desc,
  907. # which leads to diff on save_program and its desc. Call _sync_with_cpp
  908. # to keep consistency.
  909. save_program._sync_with_cpp()
  910. # flush to root_scope
  911. executor.flush()
  912. executor.run(save_program)
  913. if save_to_memory:
  914. return global_scope().find_var(params_var_name).get_bytes()
  915. def load_vars(
  916. executor,
  917. dirname,
  918. main_program=None,
  919. vars=None,
  920. predicate=None,
  921. filename=None,
  922. ):
  923. """
  924. :api_attr: Static Graph
  925. This API loads variables from files by executor.
  926. There are two ways to specify the variables to be loaded: the first way, set
  927. variables in a list and assign it to the `vars`; the second way, use the
  928. `predicate` function to select variables that make `predicate(variable) == True`.
  929. The first way has a higher priority.
  930. The `dirname` is used to specify the folder where to load variables.
  931. If variables were saved in separate files in the folder `dirname`,
  932. set `filename` None. If all variables were saved in a single file,
  933. use `filename` to specify it.
  934. Args:
  935. executor(Executor): The executor to run for loading variables.
  936. dirname(str): The folder where to load the variables.
  937. main_program(Program, optional): The program whose variables will be loaded.
  938. If it is None, the default main program will
  939. be used automatically.
  940. Default: None
  941. vars(list[Variable], optional): The list that contains all variables to be loaded.
  942. Default: None
  943. predicate(function, optional): The function selects variables that make
  944. `predicate(variable) == True`.
  945. Default: None
  946. filename(str, optional): The file which saved all required variables. If variables
  947. were saved in separate files, set it to be None.
  948. Default: None
  949. Returns:
  950. None
  951. Examples:
  952. .. code-block:: python
  953. >>> import paddle
  954. >>> import paddle.static as static
  955. >>> paddle.enable_static()
  956. >>> main_prog = static.Program()
  957. >>> startup_prog = static.Program()
  958. >>> with static.program_guard(main_prog, startup_prog):
  959. ... data = paddle.static.data(name="img", shape=[64, 784])
  960. ... w = paddle.create_parameter(shape=[784, 200], dtype='float32', name='fc_w')
  961. ... b = paddle.create_parameter(shape=[200], dtype='float32', name='fc_b')
  962. ... hidden_w = paddle.matmul(x=data, y=w)
  963. ... hidden_b = paddle.add(hidden_w, b)
  964. >>> place = paddle.CPUPlace()
  965. >>> exe = static.Executor(place)
  966. >>> exe.run(startup_prog)
  967. # The first usage: using `vars` to specify the variables.
  968. >>> path = "./my_paddle_vars"
  969. >>> var_list = [w, b]
  970. >>> paddle.static.io.save_vars(executor=exe, dirname=path, vars=var_list,
  971. ... filename="vars_file")
  972. >>> paddle.static.io.load_vars(executor=exe, dirname=path, vars=var_list,
  973. ... filename="vars_file")
  974. # w and b will be loaded, and they are supposed to
  975. # be saved in the same file named 'var_file' in the path "./my_paddle_vars".
  976. # The second usage: using the `predicate` function to select variables
  977. >>> param_path = "./my_paddle_model"
  978. >>> def name_has_fc(var):
  979. ... res = "fc" in var.name
  980. ... return res
  981. >>> paddle.static.io.save_vars(executor=exe, dirname=param_path, main_program=main_prog,
  982. ... vars=None, predicate=name_has_fc)
  983. >>> paddle.static.io.load_vars(executor=exe, dirname=param_path, main_program=main_prog,
  984. ... vars=None, predicate=name_has_fc)
  985. # Load All variables in the `main_program` whose name includes "fc".
  986. # And all the variables are supposed to be saved in separate files.
  987. """
  988. if in_pir_mode():
  989. return load_vars_pir(executor, dirname, main_program, vars, filename)
  990. vars_from_memory = False
  991. if dirname is not None:
  992. dirname = os.path.normpath(dirname)
  993. else:
  994. vars_from_memory = True
  995. if filename == '':
  996. filename = None
  997. if vars is None:
  998. if main_program is None:
  999. main_program = default_main_program()
  1000. if not isinstance(main_program, Program):
  1001. raise TypeError(
  1002. "The type of input main_program is invalid, expected type is base.Program, but received %s"
  1003. % type(main_program)
  1004. )
  1005. load_vars(
  1006. executor,
  1007. dirname=dirname,
  1008. main_program=main_program,
  1009. vars=list(filter(predicate, main_program.list_vars())),
  1010. filename=filename,
  1011. )
  1012. else:
  1013. load_prog = Program()
  1014. load_block = load_prog.global_block()
  1015. if main_program is None:
  1016. main_program = default_main_program()
  1017. if not isinstance(main_program, Program):
  1018. raise TypeError(
  1019. "The type of input main_program is invalid, expected type is base.Program, but received %s"
  1020. % type(main_program)
  1021. )
  1022. # save origin param shape
  1023. orig_para_shape = {}
  1024. load_var_map = {}
  1025. check_vars = []
  1026. sparse_vars = []
  1027. for each_var in vars:
  1028. assert isinstance(each_var, Variable)
  1029. if each_var.type == core.VarDesc.VarType.RAW:
  1030. continue
  1031. if isinstance(each_var, Parameter):
  1032. orig_para_shape[each_var.name] = tuple(
  1033. each_var.desc.get_shape()
  1034. )
  1035. if each_var.type == core.VarDesc.VarType.SELECTED_ROWS:
  1036. sparse_vars.append(each_var)
  1037. continue
  1038. new_var = _clone_var_in_block_(load_block, each_var)
  1039. check_vars.append(each_var)
  1040. if filename is None:
  1041. if dirname is None:
  1042. raise ValueError(
  1043. "The directory path and params cannot be None at the same time."
  1044. )
  1045. load_block.append_op(
  1046. type='load',
  1047. inputs={},
  1048. outputs={'Out': [new_var]},
  1049. attrs={'file_path': os.path.join(dirname, new_var.name)},
  1050. )
  1051. else:
  1052. load_var_map[new_var.name] = new_var
  1053. for each_var in sparse_vars:
  1054. assert isinstance(each_var, Variable)
  1055. if filename is not None:
  1056. raise ValueError(
  1057. "SelectedRows can not be load with load_combine"
  1058. )
  1059. new_var = _clone_var_in_block_(load_block, each_var)
  1060. var_path = os.path.join(dirname, new_var.name)
  1061. if not os.path.exists(var_path):
  1062. raise ValueError(
  1063. f"SelectedRows var {new_var.name} can not find at {var_path}"
  1064. )
  1065. if os.path.isfile(var_path):
  1066. load_block.append_op(
  1067. type='load',
  1068. inputs={},
  1069. outputs={'Out': [new_var]},
  1070. attrs={'file_path': os.path.join(dirname, new_var.name)},
  1071. )
  1072. else:
  1073. blocks = []
  1074. block_paths = os.listdir(var_path)
  1075. for block in block_paths:
  1076. if block.startswith(new_var.name):
  1077. blocks.append(block)
  1078. slices = []
  1079. for block in blocks:
  1080. slice = load_block.create_var(
  1081. name=block,
  1082. type=new_var.type,
  1083. shape=new_var.shape,
  1084. dtype=new_var.dtype,
  1085. persistable=False,
  1086. )
  1087. slices.append(slice)
  1088. file_path = os.path.join(var_path, block, "Param")
  1089. load_block.append_op(
  1090. type='load',
  1091. inputs={},
  1092. outputs={'Out': [slice]},
  1093. attrs={'file_path': file_path},
  1094. )
  1095. load_block.append_op(
  1096. type='lookup_sparse_table_merge',
  1097. inputs={'X': slices},
  1098. outputs={'Out': new_var},
  1099. attrs={},
  1100. )
  1101. if filename is not None:
  1102. load_var_list = []
  1103. for name in sorted(load_var_map.keys()):
  1104. load_var_list.append(load_var_map[name])
  1105. if vars_from_memory is False:
  1106. filename = os.path.join(dirname, filename)
  1107. load_block.append_op(
  1108. type='load_combine',
  1109. inputs={},
  1110. outputs={"Out": load_var_list},
  1111. attrs={
  1112. 'file_path': filename,
  1113. 'model_from_memory': vars_from_memory,
  1114. },
  1115. )
  1116. executor.run(load_prog)
  1117. # check var shape
  1118. for each_var in check_vars:
  1119. if not isinstance(each_var, Parameter):
  1120. continue
  1121. var_temp = paddle.base.global_scope().find_var(each_var.name)
  1122. assert var_temp is not None, "can't not find var: " + each_var.name
  1123. new_shape = (np.array(var_temp.get_tensor())).shape
  1124. assert each_var.name in orig_para_shape, (
  1125. each_var.name + "MUST in var list"
  1126. )
  1127. orig_shape = orig_para_shape.get(each_var.name)
  1128. if new_shape != orig_shape:
  1129. raise RuntimeError(
  1130. f"Variable's shape does not match, the Program requires a parameter with the shape of ({orig_shape}), "
  1131. f"while the loaded parameter (namely [ {each_var.name} ]) has a shape of ({new_shape})."
  1132. )
  1133. @static_only
  1134. def save(program, model_path, protocol=4, **configs):
  1135. """
  1136. This function save parameters, optimizer information and network description to model_path.
  1137. The parameters contains all the trainable Tensor, will save to a file with suffix ".pdparams".
  1138. The optimizer information contains all the Tensor used by optimizer. For Adam optimizer, contains beta1, beta2, momentum etc. All the information will save to a file with suffix ".pdopt". (If the optimizer have no Tensor need to save (like SGD), the fill will not generated).
  1139. The network description is the description of the program. It's only used for deployment. The description will save to a file with a suffix ".pdmodel".
  1140. Args:
  1141. program(Program) : The program to saved.
  1142. model_path(str): the file prefix to save the program. The format is "dirname/file_prefix". If file_prefix is empty str. A exception will be raised
  1143. protocol(int, optional): The protocol version of pickle module must be greater than 1 and less than 5.
  1144. Default: 4
  1145. configs(dict, optional) : optional keyword arguments.
  1146. Returns:
  1147. None
  1148. Examples:
  1149. .. code-block:: python
  1150. >>> import paddle
  1151. >>> import paddle.static as static
  1152. >>> paddle.enable_static()
  1153. >>> x = static.data(name="x", shape=[10, 10], dtype='float32')
  1154. >>> y = static.nn.fc(x, 10)
  1155. >>> z = static.nn.fc(y, 10)
  1156. >>> place = paddle.CPUPlace()
  1157. >>> exe = static.Executor(place)
  1158. >>> exe.run(static.default_startup_program())
  1159. >>> prog = static.default_main_program()
  1160. >>> static.save(prog, "./temp")
  1161. """
  1162. if in_pir_mode():
  1163. return save_pir(program, model_path, protocol, **configs)
  1164. base_name = os.path.basename(model_path)
  1165. assert (
  1166. base_name != ""
  1167. ), "The input model_path MUST be format of dirname/filename [dirname\\filename in Windows system], but received model_path is empty string."
  1168. if 'pickle_protocol' in configs:
  1169. protocol = configs['pickle_protocol']
  1170. warnings.warn(
  1171. "'pickle_protocol' is a deprecated argument. Please use 'protocol' instead."
  1172. )
  1173. if not isinstance(protocol, int):
  1174. raise ValueError(
  1175. f"The 'protocol' MUST be `int`, but received {type(protocol)}"
  1176. )
  1177. if protocol < 2 or protocol > 4:
  1178. raise ValueError(
  1179. f"Expected 1<'protocol'<5, but received protocol={protocol}"
  1180. )
  1181. dir_name = os.path.dirname(model_path)
  1182. if dir_name and not os.path.exists(dir_name):
  1183. os.makedirs(dir_name)
  1184. def get_tensor(var):
  1185. t = global_scope().find_var(var.name).get_tensor()
  1186. return np.array(t)
  1187. parameter_list = list(filter(is_parameter, program.list_vars()))
  1188. param_dict = {p.name: get_tensor(p) for p in parameter_list}
  1189. param_dict = _unpack_saved_dict(param_dict, protocol)
  1190. # When value of dict is lager than 4GB ,there is a Bug on 'MAC python3'
  1191. if sys.platform == 'darwin' and sys.version_info.major == 3:
  1192. pickle_bytes = pickle.dumps(param_dict, protocol=protocol)
  1193. with open(model_path + ".pdparams", 'wb') as f:
  1194. max_bytes = 2**30
  1195. for i in range(0, len(pickle_bytes), max_bytes):
  1196. f.write(pickle_bytes[i : i + max_bytes])
  1197. else:
  1198. with open(model_path + ".pdparams", 'wb') as f:
  1199. pickle.dump(param_dict, f, protocol=protocol)
  1200. optimizer_var_list = list(
  1201. filter(is_belong_to_optimizer, program.list_vars())
  1202. )
  1203. opt_dict = {p.name: get_tensor(p) for p in optimizer_var_list}
  1204. with open(model_path + ".pdopt", 'wb') as f:
  1205. pickle.dump(opt_dict, f, protocol=protocol)
  1206. main_program = program.clone()
  1207. program.desc.flush()
  1208. with open(model_path + ".pdmodel", "wb") as f:
  1209. f.write(program.desc.serialize_to_string())
  1210. @static_only
  1211. def load(program, model_path, executor=None, var_list=None):
  1212. """
  1213. :api_attr: Static Graph
  1214. This function get parameters and optimizer information from program, and then get corresponding value from file.
  1215. An exception will throw if shape or dtype of the parameters is not match.
  1216. This function can also load model file saved with [ save_params, save_persistables, save_vars ].
  1217. var_list can not be None when load single model file
  1218. ( filename is not None When save_params, save_persistables or save_vars is called ).
  1219. Args:
  1220. program(Program): The program will be loaded
  1221. model_path(str): The file prefix store the program
  1222. executor(Executor, optional): The executor used for initialize the parameter
  1223. When startup program is not run.
  1224. var_list(list|tuple, optional): The Tensor list/tuple to load single model file saved with
  1225. [ save_params, save_persistables, save_vars ].
  1226. Default: None
  1227. Returns:
  1228. None
  1229. Examples:
  1230. .. code-block:: python
  1231. >>> import paddle
  1232. >>> import paddle.static as static
  1233. >>> paddle.enable_static()
  1234. >>> x = static.data(name="x", shape=[10, 10], dtype='float32')
  1235. >>> y = static.nn.fc(x, 10)
  1236. >>> z = static.nn.fc(y, 10)
  1237. >>> place = paddle.CPUPlace()
  1238. >>> exe = static.Executor(place)
  1239. >>> exe.run(static.default_startup_program())
  1240. >>> prog = static.default_main_program()
  1241. >>> static.save(prog, "./temp")
  1242. >>> static.load(prog, "./temp")
  1243. """
  1244. if in_pir_mode():
  1245. return load_pir(program, model_path, executor, var_list)
  1246. assert executor is None or isinstance(executor, Executor)
  1247. model_prefix = model_path
  1248. if model_prefix.endswith(".pdparams"):
  1249. model_prefix = model_prefix[:-9]
  1250. elif model_prefix.endswith(".pdopt"):
  1251. model_prefix = model_prefix[:-6]
  1252. elif model_prefix.endswith(".pdmodel"):
  1253. model_prefix = model_prefix[:-8]
  1254. parameter_file_name = model_prefix + ".pdparams"
  1255. if not os.path.exists(parameter_file_name):
  1256. # model file save by base.save not found, try to load model file saved with
  1257. # [save_vars, save_params, save_persistables]
  1258. _logger.debug(
  1259. f"{parameter_file_name} not found, try to load model file saved with [ save_params, save_persistables, save_vars ]"
  1260. )
  1261. if executor is None:
  1262. raise ValueError(
  1263. "executor is required when loading model file saved with [ save_params, save_persistables, save_vars ]"
  1264. )
  1265. if var_list is not None:
  1266. var_list_names = [var.name for var in var_list]
  1267. else:
  1268. var_list_names = None
  1269. if os.path.isdir(model_path):
  1270. binary_file_set = set()
  1271. for root, dirs, files in os.walk(model_path, topdown=False):
  1272. for f in files:
  1273. binary_file_set.add(
  1274. os.path.join(root, f).replace("\\", "/")
  1275. )
  1276. program_var_list = list(program.list_vars())
  1277. loaded_var_list = []
  1278. for var in program_var_list:
  1279. var_path = os.path.join(model_path, var.name).replace("\\", "/")
  1280. load_condition = (
  1281. var_list_names is None or var.name in var_list_names
  1282. )
  1283. if var_path in binary_file_set and load_condition:
  1284. loaded_var_list.append(var)
  1285. binary_file_set.remove(var_path)
  1286. if len(binary_file_set) > 0:
  1287. unused_var_list = " ".join(list(binary_file_set))
  1288. _logger.warning(
  1289. "variable file [ %s ] not used"
  1290. % (" ".join(list(binary_file_set)))
  1291. )
  1292. try:
  1293. load_vars(
  1294. executor=executor, dirname=model_path, vars=loaded_var_list
  1295. )
  1296. except RuntimeError as e:
  1297. _logger.error(e)
  1298. raise e
  1299. except:
  1300. raise RuntimeError(
  1301. "Failed to load model file, please make sure model file is saved with the "
  1302. "following APIs: save_params, save_persistables, save_vars"
  1303. )
  1304. return
  1305. elif os.path.isfile(model_path):
  1306. if var_list is None:
  1307. raise ValueError(
  1308. "var_list is required when loading model file saved with [ save_params, save_persistables, save_vars ]"
  1309. )
  1310. program_var_list = program.list_vars()
  1311. program_var_name_set = {var.name for var in program_var_list}
  1312. # check all the variable included in program
  1313. for var in var_list:
  1314. if var.name not in program_var_name_set:
  1315. raise LookupError(
  1316. "loaded var [{}] is not in program variable list"
  1317. )
  1318. dir_name, file_name = os.path.split(model_path)
  1319. try:
  1320. load_vars(
  1321. executor=executor,
  1322. dirname=dir_name,
  1323. vars=var_list,
  1324. filename=file_name,
  1325. )
  1326. except RuntimeError as e:
  1327. _logger.error(e)
  1328. raise e
  1329. except:
  1330. raise RuntimeError(
  1331. "Failed to load model file , please make sure model file is saved with the "
  1332. "the following APIs: [ save_params, save_persistables, save_vars ]. "
  1333. "When these API called, filename CANNOT be None"
  1334. )
  1335. return
  1336. def set_var(var, ndarray):
  1337. t = global_scope().find_var(var.name).get_tensor()
  1338. p = t._place()
  1339. if p.is_cpu_place():
  1340. place = paddle.base.CPUPlace()
  1341. elif p.is_cuda_pinned_place():
  1342. place = paddle.base.CUDAPinnedPlace()
  1343. elif p.is_xpu_place():
  1344. p = paddle.base.core.Place()
  1345. p.set_place(t._place())
  1346. place = paddle.base.XPUPlace(p.xpu_device_id())
  1347. elif p.is_custom_place():
  1348. p = paddle.base.core.Place()
  1349. p.set_place(t._place())
  1350. place = paddle.base.CustomPlace(
  1351. paddle.device.get_device().split(':')[0], p.custom_device_id()
  1352. )
  1353. else:
  1354. p = paddle.base.core.Place()
  1355. p.set_place(t._place())
  1356. place = paddle.base.CUDAPlace(p.gpu_device_id())
  1357. t.set(ndarray, place)
  1358. parameter_list = list(filter(is_parameter, program.list_vars()))
  1359. if executor:
  1360. paddle.base.core._create_loaded_parameter(
  1361. parameter_list, global_scope(), executor._default_executor
  1362. )
  1363. with open(parameter_file_name, 'rb') as f:
  1364. # When value of dict is lager than 4GB ,there is a Bug on 'MAC python3'
  1365. if sys.platform == 'darwin' and sys.version_info.major == 3:
  1366. load_dict = _pickle_loads_mac(parameter_file_name, f)
  1367. else:
  1368. load_dict = _safe_load_pickle(f, encoding='latin1')
  1369. load_dict = _pack_loaded_dict(load_dict)
  1370. for v in parameter_list:
  1371. assert (
  1372. v.name in load_dict
  1373. ), f"Can not find [{v.name}] in model file [{parameter_file_name}]"
  1374. set_var(v, load_dict[v.name])
  1375. optimizer_var_list = list(
  1376. filter(is_belong_to_optimizer, program.list_vars())
  1377. )
  1378. if len(optimizer_var_list) > 0:
  1379. opt_file_name = model_prefix + ".pdopt"
  1380. assert os.path.exists(
  1381. opt_file_name
  1382. ), f"Optimizer file [{opt_file_name}] not exits"
  1383. if executor:
  1384. paddle.base.core._create_loaded_parameter(
  1385. optimizer_var_list, global_scope(), executor._default_executor
  1386. )
  1387. with open(opt_file_name, 'rb') as f:
  1388. load_dict = _safe_load_pickle(f, encoding='latin1')
  1389. for v in optimizer_var_list:
  1390. assert (
  1391. v.name in load_dict
  1392. ), f"Can not find [{v.name}] in model file [{opt_file_name}]"
  1393. set_var(v, load_dict[v.name])
  1394. @static_only
  1395. def set_program_state(program, state_dict):
  1396. """
  1397. Set program parameter from state_dict
  1398. An exception will throw if shape or dtype of the parameters is not match.
  1399. NOTICE: This function MUST called after run start_up_program
  1400. Args:
  1401. program(Program): The program to be set
  1402. state_dict(dict): the dict store Parameter and optimizer information
  1403. Returns:
  1404. None
  1405. Examples:
  1406. .. code-block:: python
  1407. >>> import paddle
  1408. >>> import paddle.static as static
  1409. >>> paddle.enable_static()
  1410. >>> x = static.data(name="x", shape=[10, 10], dtype='float32')
  1411. >>> y = static.nn.fc(x, 10)
  1412. >>> z = static.nn.fc(y, 10)
  1413. >>> place = paddle.CPUPlace()
  1414. >>> exe = static.Executor(place)
  1415. >>> exe.run(static.default_startup_program())
  1416. >>> prog = static.default_main_program()
  1417. >>> static.save(prog, "./temp")
  1418. >>> program_state = static.load_program_state("./temp")
  1419. >>> static.set_program_state(prog, program_state)
  1420. """
  1421. state_dict = _pack_loaded_dict(state_dict)
  1422. if in_pir_mode():
  1423. params, opts = get_pir_parameters(program)
  1424. parameter_list = params + opts
  1425. parameter_list = [var for var in parameter_list if var.persistable]
  1426. else:
  1427. parameter_list = list(filter(is_persistable, program.list_vars()))
  1428. used_para_list = {}
  1429. for para in parameter_list:
  1430. var_temp = paddle.base.global_scope().find_var(para.name)
  1431. assert (
  1432. var_temp is not None
  1433. ), f"Variable [ {para.name} ] Not found, Please make sure run startup program"
  1434. if para.name in state_dict:
  1435. # set value from state dict
  1436. orig_para_np = np.array(var_temp.get_tensor())
  1437. new_para_np = state_dict[para.name]
  1438. assert orig_para_np.shape == new_para_np.shape, (
  1439. f"Parameter's shape does not match, the Program requires a parameter with the shape of ({orig_para_np.shape}), "
  1440. f"while the loaded parameter (namely [ {para.name} ]) has a shape of ({new_para_np.shape})."
  1441. )
  1442. assert orig_para_np.dtype == new_para_np.dtype, (
  1443. f"Parameter's data type does not match, the Program requires a parameter with a dtype of ({orig_para_np.dtype}), "
  1444. f"while the loaded parameter (namely [ {para.name} ]) has a dtype of ({new_para_np.dtype})."
  1445. )
  1446. ten = var_temp.get_tensor()
  1447. ten_place = ten._place()
  1448. # assert ten_place.is_gpu_place() or ten_place.is_cpu_place(), \
  1449. # "Place not support, only support CPUPlace and GPUPlace, now is {}".format(str(ten_place))
  1450. py_place = paddle.base.CPUPlace()
  1451. if ten_place.is_cuda_pinned_place():
  1452. place = paddle.base.CUDAPinnedPlace()
  1453. elif ten_place.is_gpu_place():
  1454. p = paddle.base.core.Place()
  1455. p.set_place(ten_place)
  1456. py_place = paddle.base.CUDAPlace(p.gpu_device_id())
  1457. elif ten_place.is_xpu_place():
  1458. p = paddle.base.core.Place()
  1459. p.set_place(ten_place)
  1460. py_place = paddle.base.XPUPlace(p.xpu_device_id())
  1461. ten.set(new_para_np, py_place)
  1462. used_para_list[para.name] = 1
  1463. unused_para_list = []
  1464. for k, v in state_dict.items():
  1465. if k not in used_para_list:
  1466. unused_para_list.append(k)
  1467. if len(unused_para_list) > 0:
  1468. warnings.warn(
  1469. "This list is not set, Because of Parameter not found in program. There are: {}".format(
  1470. " ".join(unused_para_list)
  1471. )
  1472. )
  1473. @dygraph_not_support
  1474. def get_program_persistable_vars(program):
  1475. """
  1476. Get all the persistable vars from Program.
  1477. Args:
  1478. var(Program): The Program to get persistable vars
  1479. Returns:
  1480. list: The list contains all persistable vars in the program
  1481. Examples:
  1482. .. code-block:: python
  1483. >>> import paddle
  1484. >>> import paddle.static.io as io
  1485. >>> paddle.enable_static()
  1486. >>> data = paddle.static.data(name="img", shape=[64, 784])
  1487. >>> w = paddle.create_parameter(shape=[784, 200], dtype='float32', name='fc_w')
  1488. >>> b = paddle.create_parameter(shape=[200], dtype='float32', name='fc_b')
  1489. >>> list_para = io.get_program_persistable_vars( paddle.static.default_main_program() )
  1490. """
  1491. return list(filter(is_persistable, program.list_vars()))
  1492. def load_program_state(model_path, var_list=None):
  1493. """
  1494. Load program state from local file
  1495. Args:
  1496. model_path(str): The file prefix store the program
  1497. var_list(list|tuple, optional): The Tensor list/tuple to load saved with
  1498. [ save_params, save_persistables, save_vars ].
  1499. Default: None.
  1500. The var_list is only used to get name,
  1501. will not be modified.
  1502. Returns:
  1503. state_dict(dict): the dict store Parameter and optimizer information
  1504. Examples:
  1505. .. code-block:: python
  1506. >>> import paddle
  1507. >>> import paddle.static as static
  1508. >>> paddle.enable_static()
  1509. >>> x = static.data(name="x", shape=[10, 10], dtype='float32')
  1510. >>> y = static.nn.fc(x, 10)
  1511. >>> z = static.nn.fc(y, 10)
  1512. >>> place = paddle.CPUPlace()
  1513. >>> exe = static.Executor(place)
  1514. >>> exe.run(static.default_startup_program())
  1515. >>> prog = static.default_main_program()
  1516. >>> static.save(prog, "./temp")
  1517. >>> program_state = static.load_program_state("./temp")
  1518. """
  1519. model_prefix = model_path
  1520. if model_prefix.endswith(".pdparams"):
  1521. model_prefix = model_prefix[:-9]
  1522. elif model_prefix.endswith(".pdopt"):
  1523. model_prefix = model_prefix[:-6]
  1524. elif model_prefix.endswith(".pdmodel"):
  1525. model_prefix = model_prefix[:-8]
  1526. parameter_file_name = model_prefix + ".pdparams"
  1527. if not os.path.exists(parameter_file_name):
  1528. # model file saved with base.save is not found, try to load model file saved with
  1529. # [save_vars, save_params, save_persistables]
  1530. _logger.debug(
  1531. f"{parameter_file_name} not found, try to load model file saved with [ save_params, save_persistables, save_vars ]"
  1532. )
  1533. var_name_list = []
  1534. if var_list is None and os.path.isfile(model_path):
  1535. raise ValueError(
  1536. "var_list can not be None when model_path is a file type"
  1537. )
  1538. for root, dirs, files in os.walk(model_path, topdown=False):
  1539. for f in files:
  1540. file_path = os.path.join(root, f)
  1541. var_temp_name = os.path.relpath(file_path, model_path)
  1542. var_temp_name = var_temp_name.replace("\\", "/")
  1543. var_name_list.append(var_temp_name)
  1544. with _load_program_scope():
  1545. load_prog = Program()
  1546. load_block = load_prog.global_block()
  1547. def clone_var_to_block(block, var):
  1548. if not isinstance(var, Variable):
  1549. raise TypeError("value in var_list must be variable")
  1550. return block.create_var(
  1551. name=var.name,
  1552. shape=var.shape,
  1553. dtype=var.dtype,
  1554. type=var.type,
  1555. lod_level=var.lod_level
  1556. if var.desc.type() == core.VarDesc.VarType.LOD_TENSOR
  1557. else None,
  1558. persistable=True,
  1559. )
  1560. def _load_vars_with_try_catch(
  1561. exe, dirname, vars, filename, raise_error=True
  1562. ):
  1563. try:
  1564. load_vars(
  1565. executor=exe,
  1566. dirname=dirname,
  1567. vars=vars,
  1568. filename=filename,
  1569. )
  1570. return True
  1571. except:
  1572. error_str = (
  1573. "Failed to load model/variables `%s`, please make sure "
  1574. "model/variables file is saved with the following APIs: "
  1575. "save_params, save_persistables, save_vars."
  1576. )
  1577. filenames = (
  1578. [var.name for var in vars]
  1579. if filename is None
  1580. else filename
  1581. )
  1582. if raise_error:
  1583. raise RuntimeError(error_str % filenames)
  1584. else:
  1585. warnings.warn(error_str % filenames, RuntimeWarning)
  1586. return False
  1587. place = paddle.base.CPUPlace()
  1588. exe = paddle.base.Executor(place)
  1589. loaded_var_list = []
  1590. if os.path.isfile(model_path):
  1591. # when model_path is file, var_list cannot be None
  1592. dir_name, file_name = os.path.split(model_path)
  1593. for var in var_list:
  1594. loaded_var_list.append(clone_var_to_block(load_block, var))
  1595. _load_vars_with_try_catch(
  1596. exe, dir_name, loaded_var_list, file_name
  1597. )
  1598. else:
  1599. # var_list can be None or not None
  1600. if var_list is not None:
  1601. for var in var_list:
  1602. loaded_var_list.append(
  1603. clone_var_to_block(load_block, var)
  1604. )
  1605. _load_vars_with_try_catch(
  1606. exe, model_path, loaded_var_list, None
  1607. )
  1608. else:
  1609. for var_name in var_name_list:
  1610. # NOTE(chenweihang): If identify which files the user wants
  1611. # to load from the disk, we load these variables one by one.
  1612. # If a file does not exist, we only warn the user that the
  1613. # file may be an irrelevant file, but does not throw an error
  1614. # to ensure that other legal variables can be loaded.
  1615. temp_var = load_block.create_var(
  1616. name=var_name, persistable=True
  1617. )
  1618. if _load_vars_with_try_catch(
  1619. exe, model_path, [temp_var], None, False
  1620. ):
  1621. loaded_var_list.append(temp_var)
  1622. res_dict = {}
  1623. for var in loaded_var_list:
  1624. res_dict[var.name] = np.asarray(
  1625. paddle.base.global_scope().find_var(var.name).get_tensor()
  1626. )
  1627. return res_dict
  1628. assert os.path.exists(
  1629. parameter_file_name
  1630. ), f"Parameter file [{parameter_file_name}] not exits"
  1631. with open(parameter_file_name, 'rb') as f:
  1632. # When value of dict is lager than 4GB ,there is a Bug on 'MAC python3'
  1633. if sys.platform == 'darwin' and sys.version_info.major == 3:
  1634. para_dict = _pickle_loads_mac(parameter_file_name, f)
  1635. else:
  1636. para_dict = _safe_load_pickle(f, encoding='latin1')
  1637. para_dict = _pack_loaded_dict(para_dict)
  1638. opt_file_name = model_prefix + ".pdopt"
  1639. if os.path.exists(opt_file_name):
  1640. with open(opt_file_name, 'rb') as f:
  1641. opti_dict = _safe_load_pickle(f, encoding='latin1')
  1642. para_dict.update(opti_dict)
  1643. return para_dict