io.py 24 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602
  1. # Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. import os
  15. import paddle
  16. from paddle.base.framework import Program, static_only
  17. from paddle.framework import core, dygraph_not_support
  18. def _load_distributed_persistables(executor, dirname, main_program=None):
  19. """
  20. customized load_persistables for distributed training.
  21. it should be used on parameter server,
  22. Args:
  23. executor(Executor): The executor to run for saving parameters.
  24. dirname(str): The load directory path.
  25. main_program(Program): The program whose parameters will be
  26. loaded. the main_program must be the pserver_program
  27. get after transpiler.
  28. Returns:
  29. None
  30. Examples:
  31. .. code-block:: python
  32. >>> # doctest: +REQUIRES(env: DISTRIBUTED)
  33. >>> import paddle
  34. >>> import paddle.base as base
  35. >>> paddle.enable_static()
  36. >>> exe = base.Executor(base.CPUPlace())
  37. >>> param_path = "./my_paddle_model"
  38. >>> t = paddle.distributed.transpiler.DistributeTranspiler()
  39. >>> t.transpile(...)
  40. >>> pserver_prog = t.get_pserver_program(...)
  41. >>> _load_distributed_persistables(executor=exe, dirname=param_path, main_program=pserver_prog)
  42. """
  43. def __is_distributed_part_var(varname):
  44. trainer_idx = varname.find(".trainer_")
  45. block_idx = varname.find(".block")
  46. return trainer_idx or block_idx
  47. def __load_persistable_vars(executor, dirname, need_load_vars):
  48. load_prog = Program()
  49. load_block = load_prog.global_block()
  50. need_delete_vars = []
  51. for param in need_load_vars:
  52. origin_var = param.origin
  53. slice_var = param.slice
  54. is_slice = param.is_slice
  55. offset = param.offset
  56. if is_slice:
  57. slice = load_block.create_var(
  58. name=slice_var.name,
  59. type=slice_var.type,
  60. shape=slice_var.shape,
  61. dtype=slice_var.dtype,
  62. persistable=True,
  63. )
  64. load_block.append_op(
  65. type='load',
  66. inputs={},
  67. outputs={'Out': [slice]},
  68. attrs={
  69. 'file_path': os.path.join(dirname, origin_var.name),
  70. 'seek': offset,
  71. 'shape': slice.shape,
  72. },
  73. )
  74. else:
  75. origin = load_block.create_var(
  76. name=f"{origin_var.name}",
  77. type=origin_var.type,
  78. shape=origin_var.shape,
  79. dtype=origin_var.dtype,
  80. persistable=True,
  81. )
  82. load_block.append_op(
  83. type='load',
  84. inputs={},
  85. outputs={'Out': [origin]},
  86. attrs={'file_path': os.path.join(dirname, origin_var.name)},
  87. )
  88. load_block.append_op(
  89. type='delete_var',
  90. inputs={'X': need_delete_vars},
  91. )
  92. executor.run(load_prog)
  93. if not isinstance(main_program, Program):
  94. raise TypeError("'main_program' should be an instance of Program.")
  95. if not main_program._is_distributed:
  96. raise ValueError(
  97. "'_load_distributed_persistables' just be designed for distributed training."
  98. )
  99. if not main_program._ps_endpoint:
  100. raise ValueError(
  101. "'_load_distributed_persistables' need current_endpoint set in DistributeTranspiler.transpile"
  102. )
  103. need_load_vars = (
  104. main_program._parameters_on_pservers.get_distributed_vars_by_ep(
  105. main_program._ps_endpoint
  106. )
  107. )
  108. __load_persistable_vars(executor, dirname, need_load_vars)
  109. @dygraph_not_support
  110. def load_persistables(executor, dirname, main_program=None, filename=None):
  111. """
  112. :api_attr: Static Graph
  113. This API filters out all variables with ``persistable==True`` from the
  114. given ``main_program`` and then tries to load these variables from the
  115. directory ``dirname`` or the file ``filename``.
  116. Use the ``dirname`` to specify the directory where persistable variables
  117. (refer to :ref:`api_guide_model_save_reader_en`) were saved. If variables
  118. were saved in separate files, set ``filename`` as None; if all variables
  119. were saved in a single file, use ``filename`` to specify the file name.
  120. Args:
  121. executor(Executor): The executor used for loading persistable variables.
  122. See :ref:`api_guide_executor_en` for more details about it.
  123. dirname(str): The directory path.
  124. main_program(Program, optional): The program whose persistable variables will
  125. be loaded. If it is None, the ``default_main_program``
  126. will be used automatically. See :ref:`api_guide_Program_en`
  127. for more about ``Program``.
  128. Default: None.
  129. filename(str, optional): The file which saved all persistable variables. If variables
  130. were saved in separated files, set it to None.
  131. Default: None.
  132. Returns:
  133. None
  134. Examples:
  135. .. code-block:: python
  136. >>> import paddle
  137. >>> import paddle.base as base
  138. >>> paddle.enable_static()
  139. >>> exe = base.Executor(base.CPUPlace())
  140. >>> param_path = "./my_paddle_model"
  141. >>> prog = base.default_main_program()
  142. >>> paddle.distributed.io.load_persistables(executor=exe, dirname=param_path,
  143. ... main_program=None)
  144. """
  145. if main_program and main_program._is_distributed:
  146. _load_distributed_persistables(
  147. executor, dirname=dirname, main_program=main_program
  148. )
  149. else:
  150. paddle.static.io.load_vars(
  151. executor,
  152. dirname=dirname,
  153. main_program=main_program,
  154. predicate=is_persistable,
  155. filename=filename,
  156. )
  157. def _save_distributed_persistables(executor, dirname, main_program):
  158. """
  159. save_persistables for distributed training.
  160. the method will do things listed below:
  161. 1.save part of persistable variables on trainer.
  162. 2.receive "remote prefetch variables" from parameter servers and merge them.
  163. 3.save "distributed lookup table" on parameter servers.
  164. 4.receive "optimizer variables" from parameter servers and merge them.
  165. Args:
  166. executor(Executor): The executor to run for saving parameters.
  167. dirname(str): The saving directory path.
  168. main_program(Program): The program whose parameters will be
  169. saved. the main_program must be the trainer_program
  170. get after transpiler.
  171. Returns:
  172. None
  173. Examples:
  174. .. code-block:: python
  175. >>> # doctest: +REQUIRES(env: DISTRIBUTED)
  176. >>> import paddle
  177. >>> import paddle
  178. >>> paddle.enable_static()
  179. >>> exe = paddle.static.Executor(paddle.CPUPlace())
  180. >>> param_path = "./my_paddle_model"
  181. >>> t = paddle.distributed.transpiler.DistributeTranspiler()
  182. >>> t.transpile(...)
  183. >>> train_program = t.get_trainer_program()
  184. >>> _save_distributed_persistables(executor=exe, dirname=param_path, main_program=train_program)
  185. """
  186. def __save_remote_params(executor, dirname, remote_params_map):
  187. """
  188. receive params on pserver through rpc.
  189. if the params are be sliced, will concat them to one, then save it.
  190. """
  191. if not remote_params_map:
  192. return
  193. prog = paddle.static.Program()
  194. block = prog.global_block()
  195. # recv optimize vars from pserver
  196. for name, remote_params in remote_params_map.items():
  197. origin = remote_params[0].origin
  198. is_slice = remote_params[0].is_slice
  199. slices = [None] * len(remote_params)
  200. slice_varnames = [None] * len(remote_params)
  201. remote_varnames = [None] * len(remote_params)
  202. endpoints = [None] * len(remote_params)
  203. for idx, optimizer in enumerate(remote_params):
  204. block_id = optimizer.block_id
  205. slice = optimizer.slice
  206. endpoint = optimizer.endpoint
  207. index = block_id if is_slice else idx
  208. slices[index] = slice
  209. slice_varnames[index] = f"{slice.name}.slice.{idx}"
  210. remote_varnames[index] = slice.name
  211. endpoints[index] = endpoint
  212. slice_shapes = []
  213. for slice in slices:
  214. tmp = [str(dim) for dim in slice.shape]
  215. slice_shapes.append(",".join(tmp))
  216. block.append_op(
  217. type='recv_save',
  218. attrs={
  219. "trainer_id": 0,
  220. "shape": origin.shape,
  221. "slice_shapes": slice_shapes,
  222. "slice_varnames": slice_varnames,
  223. "remote_varnames": remote_varnames,
  224. "endpoints": endpoints,
  225. "file_path": os.path.join(dirname, origin.name),
  226. },
  227. )
  228. executor.run(prog)
  229. def __save_distributed_lookup_tables(
  230. executor, dirname, distributed_lookup_table, endpoints
  231. ):
  232. """
  233. because the distributed lookup table may too huge to merge and save at one place,
  234. it will be saved at parameter server independent respectively.
  235. the save directory is dirname/"__lookup_table__".
  236. """
  237. prog = paddle.static.Program()
  238. block = prog.global_block()
  239. # if there is lookup table, the trainer 0 will notify all pserver to save.
  240. lookup_table_filename = os.path.join(dirname, "__lookup_table__")
  241. attrs = {}
  242. attrs['epmap'] = endpoints
  243. attrs['dir'] = lookup_table_filename
  244. attrs['lookup_table'] = distributed_lookup_table
  245. block.append_op(
  246. type='checkpoint_notify', inputs={}, outputs={}, attrs=attrs
  247. )
  248. executor.run(prog)
  249. def __exclude_vars(exclude_var_names=[]):
  250. def is_valid(var):
  251. if var.name in exclude_var_names:
  252. return False
  253. if (
  254. var.desc.type() == core.VarDesc.VarType.FEED_MINIBATCH
  255. or var.desc.type() == core.VarDesc.VarType.FETCH_LIST
  256. or var.desc.type() == core.VarDesc.VarType.READER
  257. ):
  258. return False
  259. return var.persistable
  260. return is_valid
  261. if not isinstance(main_program, Program):
  262. raise TypeError("'main_program' should be an instance of Program.")
  263. if not main_program._is_distributed:
  264. raise ValueError(
  265. "'_save_distributed_persistables' just be designed for distributed training."
  266. )
  267. remote_params_map = (
  268. main_program._parameters_on_pservers.get_distributed_vars_by_vtypes(
  269. ["Optimizer", "RemotePrefetch"], groupby=True
  270. )
  271. )
  272. exclude_var_names = []
  273. if remote_params_map:
  274. exclude_var_names.extend(remote_params_map.keys())
  275. if main_program._distributed_lookup_table:
  276. if isinstance(main_program._distributed_lookup_table, list):
  277. exclude_var_names.extend(main_program._distributed_lookup_table)
  278. else:
  279. exclude_var_names.append(main_program._distributed_lookup_table)
  280. local_vars = list(
  281. filter(__exclude_vars(exclude_var_names), main_program.list_vars())
  282. )
  283. paddle.static.save_vars(
  284. executor, main_program=main_program, dirname=dirname, vars=local_vars
  285. )
  286. if main_program._is_chief:
  287. if remote_params_map:
  288. __save_remote_params(executor, dirname, remote_params_map)
  289. if main_program._distributed_lookup_table:
  290. __save_distributed_lookup_tables(
  291. executor,
  292. dirname,
  293. main_program._distributed_lookup_table,
  294. main_program._endpoints,
  295. )
  296. def is_persistable(var):
  297. """
  298. Check whether the given variable is persistable.
  299. Args:
  300. var(Variable): The variable to be checked.
  301. Returns:
  302. bool: True if the given `var` is persistable
  303. False if not.
  304. Examples:
  305. .. code-block:: python
  306. >>> import paddle
  307. >>> paddle.enable_static()
  308. >>> image = paddle.static.data(
  309. ... name='image', shape=[None, 28], dtype='float32')
  310. >>> bias_attr = paddle.ParamAttr('fc.b')
  311. >>> fc = paddle.static.nn.fc(image, size=10, bias_attr=bias_attr)
  312. >>> param = paddle.static.default_main_program().global_block().var('fc.b')
  313. >>> res = paddle.distributed.io.is_persistable(param)
  314. """
  315. if (
  316. var.desc.type() == core.VarDesc.VarType.FEED_MINIBATCH
  317. or var.desc.type() == core.VarDesc.VarType.FETCH_LIST
  318. or var.desc.type() == core.VarDesc.VarType.READER
  319. ):
  320. return False
  321. return var.persistable
  322. @dygraph_not_support
  323. def save_persistables(executor, dirname, main_program=None, filename=None):
  324. """
  325. Save all persistable variables from :code:`main_program` to
  326. the folder :code:`dirname` or file :code:`filename`. You can refer to
  327. :ref:`api_guide_model_save_reader_en` for more details. And then
  328. saves these persistables variables to the folder :code:`dirname` or file
  329. :code:`filename`.
  330. The :code:`dirname` is used to specify the folder where persistable variables
  331. are going to be saved. If you would like to save variables in separate
  332. files, set :code:`filename` None; if you would like to save all variables in a
  333. single file, use :code:`filename` to specify the file name.
  334. Args:
  335. executor(Executor): The executor to run for saving persistable variables.
  336. You can refer to :ref:`api_guide_executor_en` for
  337. more details.
  338. dirname(str, optional): The saving directory path.
  339. When you need to save the parameter to the memory, set it to None.
  340. main_program(Program, optional): The program whose persistable variables will
  341. be saved. You can refer to
  342. :ref:`api_guide_Program_en` for more details.
  343. If it is None, the default main program will
  344. be used.
  345. Default: None.
  346. filename(str, optional): The file to save all variables. If you prefer to
  347. save variables in different files, set it to None.
  348. Default: None.
  349. Returns:
  350. str: When saving parameters to a file, returns None.
  351. When saving parameters to memory, returns a binary string containing parameters.
  352. Examples:
  353. .. code-block:: python
  354. >>> import paddle
  355. >>> paddle.enable_static()
  356. >>> dir_path = "./my_paddle_model"
  357. >>> file_name = "persistables"
  358. >>> image = paddle.static.data(name='img', shape=[None, 28, 28], dtype='float32')
  359. >>> label = paddle.static.data(name='label', shape=[None, 1], dtype='int64')
  360. >>> feeder = paddle.base.DataFeeder(feed_list=[image, label], place=paddle.CPUPlace())
  361. >>> predict = paddle.static.nn.fc(x=image, size=10, activation='softmax')
  362. >>> loss = paddle.nn.functional.cross_entropy(input=predict, label=label)
  363. >>> avg_loss = paddle.mean(loss)
  364. >>> exe = paddle.static.Executor(paddle.CPUPlace())
  365. >>> exe.run(paddle.static.default_startup_program())
  366. >>> paddle.distributed.io.save_persistables(executor=exe, dirname=dir_path, filename=file_name)
  367. >>> # The persistables variables weights and bias in the fc layer of the network
  368. >>> # are going to be saved in the same file named "persistables" in the path
  369. >>> # "./my_paddle_model"
  370. """
  371. if main_program and main_program._is_distributed:
  372. return _save_distributed_persistables(
  373. executor, dirname=dirname, main_program=main_program
  374. )
  375. else:
  376. return paddle.static.save_vars(
  377. executor,
  378. dirname=dirname,
  379. main_program=main_program,
  380. vars=None,
  381. predicate=is_persistable,
  382. filename=filename,
  383. )
  384. @static_only
  385. def load_inference_model_distributed(
  386. dirname,
  387. executor,
  388. model_filename=None,
  389. params_filename=None,
  390. pserver_endpoints=None,
  391. ):
  392. """
  393. Load the inference model from a given directory. By this API, you can get the model
  394. structure(Inference Program) and model parameters.
  395. You can refer to :ref:`api_guide_model_save_reader_en` for more details.
  396. Args:
  397. dirname(str): One of the following:
  398. - The given directory path.
  399. - Set to None when reading the model from memory.
  400. executor(Executor): The executor to run for loading inference model.
  401. See :ref:`api_guide_executor_en` for more details about it.
  402. model_filename(str, optional): One of the following:
  403. - The name of file to load the inference program.
  404. - If it is None, the default filename ``__model__`` will be used.
  405. - When ``dirname`` is ``None``, it must be set to a string containing model.
  406. Default: ``None``.
  407. params_filename(str, optional): It is only used for the case that all
  408. parameters were saved in a single binary file. One of the following:
  409. - The name of file to load all parameters.
  410. - When ``dirname`` is ``None``, it must be set to a string containing all the parameters.
  411. - If parameters were saved in separate files, set it as ``None``.
  412. Default: ``None``.
  413. pserver_endpoints(list, optional): It is only needed by the distributed inference.
  414. If using a distributed look up table during the training,
  415. this table is also needed by the inference process. Its value is
  416. a list of pserver endpoints.
  417. Returns:
  418. list: The return of this API is a list with three elements:
  419. (program, feed_target_names, fetch_targets). The `program` is a
  420. ``Program`` (refer to :ref:`api_guide_Program_en`), which is used for inference.
  421. The `feed_target_names` is a list of ``str``, which contains names of variables
  422. that need to feed data in the inference program. The `fetch_targets` is a list of
  423. ``Variable`` (refer to :ref:`api_guide_Program_en`). It contains variables from which
  424. we can get inference results.
  425. Examples:
  426. .. code-block:: python
  427. >>> import paddle
  428. >>> import paddle.base as base
  429. >>> import numpy as np
  430. >>> paddle.enable_static()
  431. >>> # Build the model
  432. >>> main_prog = paddle.static.Program()
  433. >>> startup_prog = paddle.static.Program()
  434. >>> with paddle.static.program_guard(main_prog, startup_prog):
  435. ... data = paddle.static.data(name="img", shape=[64, 784], append_batch_size=False)
  436. ... w = paddle.create_parameter(shape=[784, 200], dtype='float32')
  437. ... b = paddle.create_parameter(shape=[200], dtype='float32')
  438. ... hidden_w = paddle.matmul(x=data, y=w)
  439. ... hidden_b = base.layers.elementwise_add(hidden_w, b)
  440. >>> place = base.CPUPlace()
  441. >>> exe = base.Executor(place)
  442. >>> exe.run(startup_prog)
  443. >>> # Save the inference model
  444. >>> path = "./infer_model"
  445. >>> base.io.save_inference_model(dirname=path, feeded_var_names=['img'],
  446. ... target_vars=[hidden_b], executor=exe, main_program=main_prog)
  447. ...
  448. >>> # Demo one. Not need to set the distributed look up table, because the
  449. >>> # training doesn't use a distributed look up table.
  450. >>> [inference_program, feed_target_names, fetch_targets] = (
  451. ... paddle.distributed.io.load_inference_model_distributed(dirname=path, executor=exe))
  452. >>> tensor_img = np.array(np.random.random((1, 64, 784)), dtype=np.float32)
  453. >>> results = exe.run(inference_program,
  454. ... feed={feed_target_names[0]: tensor_img},
  455. ... fetch_list=fetch_targets)
  456. ...
  457. >>> # Demo two. If the training uses a distributed look up table, the pserver
  458. >>> # endpoints list should be supported when loading the inference model.
  459. >>> # The below is just an example.
  460. >>> endpoints = ["127.0.0.1:2023","127.0.0.1:2024"]
  461. >>> [dist_inference_program, dist_feed_target_names, dist_fetch_targets] = (
  462. ... paddle.distributed.io.load_inference_model_distributed(dirname=path,
  463. ... executor=exe,
  464. ... pserver_endpoints=endpoints))
  465. ...
  466. >>> # In this example, the inference program was saved in the file
  467. >>> # "./infer_model/__model__" and parameters were saved in
  468. >>> # separate files under the directory "./infer_model".
  469. >>> # By the inference program, feed_target_names and
  470. >>> # fetch_targets, we can use an executor to run the inference
  471. >>> # program for getting the inference result.
  472. """
  473. load_from_memory = False
  474. if dirname is not None:
  475. load_dirname = os.path.normpath(dirname)
  476. if not os.path.isdir(load_dirname):
  477. raise ValueError("There is no directory named '%s'" % dirname)
  478. if model_filename is None:
  479. model_filename = '__model__'
  480. model_filename = os.path.join(
  481. load_dirname, os.path.basename(model_filename)
  482. )
  483. if params_filename is not None:
  484. params_filename = os.path.basename(params_filename)
  485. with open(model_filename, "rb") as f:
  486. program_desc_str = f.read()
  487. else:
  488. load_from_memory = True
  489. if params_filename is None:
  490. raise ValueError(
  491. "The path of params cannot be None when the directory path is None."
  492. )
  493. load_dirname = dirname
  494. program_desc_str = model_filename
  495. params_filename = params_filename
  496. program = Program.parse_from_string(program_desc_str)
  497. if not core._is_program_version_supported(program._version()):
  498. raise ValueError(
  499. "Unsupported program version: %d\n" % program._version()
  500. )
  501. # Binary data also need versioning.
  502. load_persistables(executor, load_dirname, program, params_filename)
  503. feed_target_names = program.desc.get_feed_target_names()
  504. fetch_target_names = program.desc.get_fetch_target_names()
  505. fetch_targets = [
  506. program.global_block().var(name) for name in fetch_target_names
  507. ]
  508. return [program, feed_target_names, fetch_targets]