api.py 63 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693694695696697698699700701702703704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862863864865866867868869870871872873874875876877878879880881882883884885886887888889890891892893894895896897898899900901902903904905906907908909910911912913914915916917918919920921922923924925926927928929930931932933934935936937938939940941942943944945946947948949950951952953954955956957958959960961962963964965966967968969970971972973974975976977978979980981982983984985986987988989990991992993994995996997998999100010011002100310041005100610071008100910101011101210131014101510161017101810191020102110221023102410251026102710281029103010311032103310341035103610371038103910401041104210431044104510461047104810491050105110521053105410551056105710581059106010611062106310641065106610671068106910701071107210731074107510761077107810791080108110821083108410851086108710881089109010911092109310941095109610971098109911001101110211031104110511061107110811091110111111121113111411151116111711181119112011211122112311241125112611271128112911301131113211331134113511361137113811391140114111421143114411451146114711481149115011511152115311541155115611571158115911601161116211631164116511661167116811691170117111721173117411751176117711781179118011811182118311841185118611871188118911901191119211931194119511961197119811991200120112021203120412051206120712081209121012111212121312141215121612171218121912201221122212231224122512261227122812291230123112321233123412351236123712381239124012411242124312441245124612471248124912501251125212531254125512561257125812591260126112621263126412651266126712681269127012711272127312741275127612771278127912801281128212831284128512861287128812891290129112921293129412951296129712981299130013011302130313041305130613071308130913101311131213131314131513161317131813191320132113221323132413251326132713281329133013311332133313341335133613371338133913401341134213431344134513461347134813491350135113521353135413551356135713581359136013611362136313641365136613671368136913701371137213731374137513761377137813791380138113821383138413851386138713881389139013911392139313941395139613971398139914001401140214031404140514061407140814091410141114121413141414151416141714181419142014211422142314241425142614271428142914301431143214331434143514361437143814391440144114421443144414451446144714481449145014511452145314541455145614571458145914601461146214631464146514661467146814691470147114721473147414751476147714781479148014811482148314841485148614871488148914901491149214931494149514961497149814991500150115021503150415051506150715081509151015111512151315141515151615171518151915201521152215231524152515261527152815291530153115321533153415351536153715381539154015411542154315441545154615471548154915501551155215531554155515561557155815591560156115621563156415651566156715681569157015711572157315741575157615771578157915801581158215831584158515861587158815891590159115921593159415951596159715981599160016011602160316041605160616071608160916101611161216131614161516161617161816191620162116221623162416251626162716281629163016311632163316341635163616371638163916401641164216431644164516461647
  1. # Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
  2. # Copyright (c) 2021 NVIDIA Corporation. All rights reserved.
  3. #
  4. # Licensed under the Apache License, Version 2.0 (the "License");
  5. # you may not use this file except in compliance with the License.
  6. # You may obtain a copy of the License at
  7. #
  8. # http://www.apache.org/licenses/LICENSE-2.0
  9. #
  10. # Unless required by applicable law or agreed to in writing, software
  11. # distributed under the License is distributed on an "AS IS" BASIS,
  12. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. # See the License for the specific language governing permissions and
  14. # limitations under the License.
  15. from __future__ import annotations
  16. import inspect
  17. import os
  18. import pickle
  19. import sys
  20. import threading
  21. import types
  22. import warnings
  23. from collections import OrderedDict
  24. from contextlib import contextmanager
  25. from typing import Any
  26. import paddle
  27. from paddle.base import core, dygraph
  28. from paddle.base.compiler import (
  29. BuildStrategy,
  30. )
  31. from paddle.base.dygraph.base import (
  32. switch_to_static_graph,
  33. )
  34. from paddle.base.executor import Executor, scope_guard
  35. from paddle.base.framework import (
  36. EagerParamBase,
  37. Parameter,
  38. Variable,
  39. _current_expected_place,
  40. dygraph_only,
  41. )
  42. from paddle.base.wrapped_decorator import wrap_decorator
  43. from paddle.framework import use_pir_api
  44. from paddle.nn import Layer
  45. from paddle.static.io import save_inference_model
  46. from paddle.utils.environments import (
  47. BooleanEnvironmentVariable,
  48. EnvironmentVariableGuard,
  49. )
  50. from .dy2static import logging_utils
  51. from .dy2static.convert_call_func import ConversionOptions, add_ignore_module
  52. from .dy2static.program_translator import (
  53. ASTStaticFunction,
  54. ProgramTranslator,
  55. StaticFunction,
  56. SymbolicStaticFunction,
  57. unwrap_decorators,
  58. )
  59. from .pir_translated_layer import PIR_INFER_MODEL_SUFFIX, PirTranslatedLayer
  60. from .translated_layer import (
  61. INFER_MODEL_SUFFIX,
  62. INFER_PARAMS_INFO_SUFFIX,
  63. INFER_PARAMS_SUFFIX,
  64. INFER_PROPERTY_SUFFIX,
  65. TranslatedLayer,
  66. )
  67. ENV_ENABLE_SOT = BooleanEnvironmentVariable("ENABLE_FALL_BACK", True)
  68. @contextmanager
  69. def sot_mode_guard(value: bool):
  70. with EnvironmentVariableGuard(ENV_ENABLE_SOT, value):
  71. yield
  72. def copy_decorator_attrs(original_func, decorated_obj):
  73. """
  74. Copies some necessary attributes from original function into decorated function.
  75. Args:
  76. original_func(callable): the original decorated function.
  77. decorated_obj(StaticFunction): the target decorated StaticFunction object.
  78. """
  79. decorator_name = "to_static"
  80. decorated_obj.__name__ = original_func.__name__
  81. decorated_obj._decorator_name = decorator_name
  82. decorated_obj.__wrapped__ = original_func
  83. decorated_obj.__doc__ = original_func.__doc__
  84. if hasattr(original_func, "__module__"):
  85. decorated_obj.__module__ = original_func.__module__
  86. return decorated_obj
  87. def ignore_module(modules: list[Any]):
  88. """
  89. Adds modules that ignore transcription.
  90. Builtin modules that have been ignored are collections, pdb, copy, inspect, re, numpy, logging, six
  91. Args:
  92. modules (List[Any]): Ignored modules that you want to add
  93. Examples:
  94. .. code-block:: python
  95. >>> import scipy
  96. >>> import astor
  97. >>> import paddle
  98. >>> from paddle.jit import ignore_module
  99. >>> modules = [
  100. ... scipy,
  101. ... astor,
  102. ... ]
  103. >>> ignore_module(modules)
  104. """
  105. add_ignore_module(modules)
  106. def _check_and_set_backend(backend, build_strategy):
  107. if backend not in ['CINN', None]:
  108. raise ValueError(
  109. f"The backend of to_static should be 'CINN' or None, but received {backend}."
  110. )
  111. if backend == 'CINN':
  112. build_strategy.build_cinn_pass = True
  113. def to_static(
  114. function=None,
  115. input_spec=None,
  116. build_strategy=None,
  117. backend=None,
  118. **kwargs,
  119. ):
  120. """
  121. Converts dynamic graph APIs into static graph function APIs. Decorator
  122. @to_static handles the Program and Executor of static graph mode and returns
  123. the result as dynamic graph Tensor(s). Users could use the returned dynamic
  124. graph Tensor(s) to do dynamic graph training, inference, or other operations.
  125. If the decorated function calls other dynamic graph function, the called one
  126. will be converted into static graph function as well.
  127. Args:
  128. function (callable): Callable dynamic graph function. If it used as a
  129. decorator, the decorated function will be parsed as this parameter.
  130. input_spec (list[InputSpec]|tuple[InputSpec]): list/tuple of InputSpec to
  131. specific the shape/dtype/name information of each input Tensor.
  132. build_strategy (BuildStrategy|None): This argument is used to compile the
  133. converted program with the specified options, such as operators' fusion
  134. in the computational graph and memory optimization during the execution
  135. of the computational graph. For more information about build_strategy,
  136. please refer to :code:`paddle.static.BuildStrategy`. The default is None.
  137. backend(str, Optional): Specifies compilation backend, which can be `CINN` or
  138. None. When backend is `CINN`, CINN compiler will be used to speed up
  139. training and inference.
  140. kwargs: Support keys including `property`, set `property` to True if the function
  141. is python property.
  142. Returns:
  143. Tensor(s): containing the numerical result.
  144. Examples:
  145. .. code-block:: python
  146. >>> # doctest: +SKIP('`paddle.jit.to_static` can not run in xdoctest')
  147. >>> import paddle
  148. >>> from paddle.jit import to_static
  149. >>> @to_static
  150. >>> def func(x):
  151. ... if paddle.mean(x) < 0:
  152. ... x_v = x - 1
  153. ... else:
  154. ... x_v = x + 1
  155. ... return x_v
  156. ...
  157. >>> x = paddle.ones([1, 2], dtype='float32')
  158. >>> x_v = func(x)
  159. >>> print(x_v)
  160. Tensor(shape=[1, 2], dtype=float32, place=Place(cpu), stop_gradient=True,
  161. [[2., 2.]])
  162. """
  163. property = kwargs.get("property", False)
  164. full_graph = kwargs.get("full_graph", None)
  165. def decorated(python_func):
  166. """
  167. Decorates a python function into a ASTStaticFunction object.
  168. """
  169. nonlocal full_graph
  170. if full_graph is None:
  171. flag = ENV_ENABLE_SOT.get()
  172. full_graph = not flag
  173. if sys.version_info >= (3, 13) and not full_graph:
  174. warnings.warn(
  175. "full_graph=False is not supported in Python 3.13+. Set full_graph=True automatically"
  176. )
  177. full_graph = True
  178. StaticClass = {
  179. False: SymbolicStaticFunction,
  180. True: ASTStaticFunction,
  181. }[full_graph]
  182. # Step 1. unwrap the function if it is already decorated.
  183. _, python_func = unwrap_decorators(python_func)
  184. # Step 2. copy some attributes from original python function.
  185. static_layer = copy_decorator_attrs(
  186. original_func=python_func,
  187. decorated_obj=StaticClass(
  188. function=python_func,
  189. input_spec=input_spec,
  190. build_strategy=build_strategy,
  191. property=property,
  192. backend=backend,
  193. ),
  194. )
  195. return static_layer
  196. build_strategy = build_strategy or BuildStrategy()
  197. if not isinstance(build_strategy, BuildStrategy):
  198. raise TypeError(
  199. f"Required type(build_strategy) shall be `paddle.static.BuildStrategy`, but received {type(build_strategy).__name__}"
  200. )
  201. _check_and_set_backend(backend, build_strategy)
  202. # for usage: `to_static(foo, ...)`
  203. if function is not None:
  204. if isinstance(function, Layer):
  205. if isinstance(function.forward, StaticFunction):
  206. class_name = function.__class__.__name__
  207. logging_utils.warn(
  208. f"`{class_name}.forward` has already been decorated somewhere. It will be redecorated to replace previous one."
  209. )
  210. function.forward = decorated(function.forward)
  211. return function
  212. else:
  213. return decorated(function)
  214. # for usage: `@to_static`
  215. return decorated
  216. def not_to_static(func=None):
  217. """
  218. A Decorator to suppresses the convention of a function.
  219. Args:
  220. func(callable): The function to decorate.
  221. Returns:
  222. callable: A function which won't be converted in Dynamic-to-Static.
  223. Examples:
  224. .. code-block:: python
  225. >>> # doctest: +SKIP('`paddle.jit.to_static` can not run in xdoctest')
  226. >>> import paddle
  227. >>> @paddle.jit.not_to_static
  228. ... def func_not_to_static(x):
  229. ... res = x - 1
  230. ... return res
  231. >>> @paddle.jit.to_static
  232. ... def func(x):
  233. ... if paddle.mean(x) < 0:
  234. ... out = func_not_to_static(x)
  235. ... else:
  236. ... out = x + 1
  237. ... return out
  238. ...
  239. >>> x = paddle.ones([1, 2], dtype='float32')
  240. >>> out = func(x)
  241. >>> print(out)
  242. Tensor(shape=[1, 2], dtype=float32, place=Place(cpu), stop_gradient=True,
  243. [[2., 2.]])
  244. """
  245. if func is None:
  246. return not_to_static
  247. options = ConversionOptions(not_convert=True)
  248. options.attach(func)
  249. return func
  250. class _SaveLoadConfig:
  251. def __init__(self):
  252. self._output_spec = None
  253. self._model_filename = None
  254. self._params_filename = None
  255. self._separate_params = False
  256. # used for `paddle.load`
  257. self._keep_name_table = False
  258. # NOTE: Users rarely use following configs, so these configs are not open to users,
  259. # reducing user learning costs, but we retain the configuration capabilities
  260. # If True, programs are modified to only support direct inference deployment.
  261. # Otherwise,more information will be stored for flexible optimization and re-training.
  262. # Currently, only True is supported
  263. self._export_for_deployment = True
  264. # If True, It will save inference program only, and do not save params of Program
  265. self._program_only = False
  266. self.with_hook = False
  267. # if True, multi `StaticFunction` will share params in one file.
  268. self.combine_params = False
  269. # when need to save a prune model, use input_names_after_prune to specify the inputs left after pruning
  270. self.input_names_after_prune = None
  271. # in the scene of llm-inference, pruning program can cause unexpectable result, an option to skip prune is necessary
  272. self.skip_prune_program = False
  273. @property
  274. def output_spec(self):
  275. return self._output_spec
  276. @output_spec.setter
  277. def output_spec(self, spec):
  278. if spec is None:
  279. return
  280. if not isinstance(spec, list):
  281. raise TypeError(
  282. f"The config `output_spec` should be 'list', but received input type is {type(input)}."
  283. )
  284. for var in spec:
  285. if not isinstance(var, core.eager.Tensor):
  286. raise TypeError(
  287. f"The element in config `output_spec` list should be 'Variable', but received element's type is {type(var)}."
  288. )
  289. self._output_spec = spec
  290. @property
  291. def model_filename(self):
  292. return self._model_filename
  293. @model_filename.setter
  294. def model_filename(self, filename):
  295. if filename is None:
  296. return
  297. if not isinstance(filename, str):
  298. raise TypeError(
  299. f"The config `model_filename` should be str, but received input's type is {type(filename)}."
  300. )
  301. if len(filename) == 0:
  302. raise ValueError("The config `model_filename` is empty string.")
  303. self._model_filename = filename
  304. @property
  305. def params_filename(self):
  306. return self._params_filename
  307. @params_filename.setter
  308. def params_filename(self, filename):
  309. if filename is None:
  310. return
  311. if not isinstance(filename, str):
  312. raise TypeError(
  313. f"The config `params_filename` should be str, but received input's type is {type(filename)}."
  314. )
  315. if len(filename) == 0:
  316. raise ValueError("The config `params_filename` is empty string.")
  317. self._params_filename = filename
  318. @property
  319. def keep_name_table(self):
  320. return self._keep_name_table
  321. @keep_name_table.setter
  322. def keep_name_table(self, value):
  323. if value is None:
  324. return
  325. if not isinstance(value, bool):
  326. raise TypeError(
  327. f"The config `keep_name_table` should be bool value, but received input's type is {type(value)}."
  328. )
  329. self._keep_name_table = value
  330. def _parse_save_configs(configs):
  331. supported_configs = [
  332. "output_spec",
  333. "with_hook",
  334. "combine_params",
  335. "clip_extra",
  336. "skip_forward",
  337. "input_names_after_prune",
  338. "skip_prune_program",
  339. ]
  340. # input check
  341. for key in configs:
  342. if key not in supported_configs:
  343. raise ValueError(
  344. f"The additional config ({key}) of `paddle.jit.save` is not supported."
  345. )
  346. # construct inner config
  347. inner_config = _SaveLoadConfig()
  348. inner_config.output_spec = configs.get("output_spec", None)
  349. inner_config.with_hook = configs.get("with_hook", False)
  350. inner_config.combine_params = configs.get("combine_params", False)
  351. inner_config.clip_extra = configs.get("clip_extra", True)
  352. inner_config.skip_forward = configs.get("skip_forward", False)
  353. inner_config.input_names_after_prune = configs.get(
  354. "input_names_after_prune", None
  355. )
  356. inner_config.skip_prune_program = configs.get("skip_prune_program", False)
  357. return inner_config
  358. def _parse_load_config(configs):
  359. supported_configs = ['model_filename', 'params_filename']
  360. # input check
  361. for key in configs:
  362. if key not in supported_configs:
  363. raise ValueError(
  364. f"The additional config ({key}) of `paddle.jit.load` is not supported."
  365. )
  366. # construct inner config
  367. inner_config = _SaveLoadConfig()
  368. inner_config.model_filename = configs.get('model_filename', None)
  369. inner_config.params_filename = configs.get('params_filename', None)
  370. return inner_config
  371. def _get_input_var_and_names(inputs, input_spec, input_names_after_prune):
  372. name_none_error = (
  373. "The %s's name is None. "
  374. "When using jit.save, please set InputSpec's name in "
  375. "to_static(input_spec=[]) and jit.save(input_spec=[]) "
  376. "and make sure they are consistent."
  377. )
  378. name_no_exists_error = (
  379. "The tensor `%s` does not exists. "
  380. "Please make sure the name of InputSpec or example Tensor "
  381. "in input_spec is the same as the name of InputSpec in "
  382. "`to_static` decorated on the Layer.forward method."
  383. )
  384. if input_names_after_prune is not None:
  385. input_spec = [
  386. x
  387. for x in input_spec
  388. if isinstance(x, paddle.static.InputSpec)
  389. and x.name in input_names_after_prune
  390. ]
  391. input_vars = [
  392. var
  393. for var in paddle.utils.flatten(inputs)
  394. if isinstance(var, (Variable, paddle.pir.Value))
  395. ]
  396. input_var_names = [
  397. var.name
  398. for var in paddle.utils.flatten(inputs)
  399. if isinstance(var, (Variable, paddle.pir.Value))
  400. ]
  401. if input_spec is None:
  402. # no prune
  403. return input_vars, input_var_names
  404. else:
  405. # filter out non-tensor type spec infos.
  406. input_spec = [
  407. spec
  408. for spec in input_spec
  409. if isinstance(spec, paddle.static.InputSpec)
  410. ]
  411. result_var_list = []
  412. result_name_list = []
  413. if len(input_spec) == len(input_var_names):
  414. # no prune
  415. result_var_list = input_vars
  416. result_name_list = input_var_names
  417. # if input spec name not in input_var_names, only raise warning
  418. for spec in input_spec:
  419. if spec.name is None:
  420. warnings.warn(name_none_error % spec)
  421. elif spec.name not in input_var_names:
  422. warnings.warn(name_no_exists_error % spec.name)
  423. else:
  424. # do nothing
  425. pass
  426. else:
  427. # prune
  428. for spec in input_spec:
  429. if spec.name is None:
  430. # name is None, the input_spec only can be InputSpec
  431. raise ValueError(name_none_error % spec)
  432. elif spec.name not in input_var_names:
  433. # the input_spec can be `InputSpec` or `Tensor`
  434. raise ValueError(name_no_exists_error % spec.name)
  435. else:
  436. result_var_list.append(spec)
  437. result_name_list.append(spec.name)
  438. return result_var_list, result_name_list
  439. def _get_output_vars(outputs, output_spec, with_hook=False):
  440. name_no_exists_error = (
  441. "The tensor `%s` does not exists. "
  442. "Please make sure the name of example Tensor "
  443. "in configs.output_spec is the output tensor of "
  444. "Layer.forward method."
  445. )
  446. output_spec_is_not_value_error = (
  447. "tensor `%s` is not support in pir mode, "
  448. "because pir value has no name sometimes, especially as ouptut,"
  449. "so we can't check tensor's name with output var name, please"
  450. "change as pir.value(to_static layer's output)"
  451. "or int(the position of to_static layer's output)"
  452. )
  453. if output_spec and with_hook:
  454. raise RuntimeError(
  455. "Currently not support specify output_spec while founding pre/post hooks in your outermost layer."
  456. )
  457. result_list = []
  458. if use_pir_api():
  459. from paddle.autograd.backward_utils import ValueSet
  460. for var in paddle.utils.flatten(outputs):
  461. if isinstance(var, paddle.pir.Value):
  462. result_list.append(var)
  463. if output_spec is not None:
  464. output_size = len(result_list)
  465. if len(output_spec) == output_size:
  466. for var in output_spec:
  467. if not isinstance(var, (paddle.pir.Value, int)):
  468. warnings.warn(output_spec_is_not_value_error % var.name)
  469. else:
  470. if var not in ValueSet(result_list):
  471. warnings.warn(name_no_exists_error % var.name)
  472. else:
  473. result_set = ValueSet(result_list)
  474. part_result_list = []
  475. for var in output_spec:
  476. if isinstance(var, paddle.pir.Value):
  477. if var not in result_set:
  478. raise ValueError(name_no_exists_error % var.name)
  479. else:
  480. part_result_list.append(var)
  481. elif isinstance(var, int):
  482. if var >= output_size:
  483. raise ValueError(
  484. "position %d should smaller than output's size % d",
  485. var,
  486. output_size,
  487. )
  488. else:
  489. part_result_list.append(result_list[var])
  490. else:
  491. raise ValueError(
  492. output_spec_is_not_value_error % var.name
  493. )
  494. return part_result_list
  495. else:
  496. output_vars_dict = OrderedDict()
  497. for var in paddle.utils.flatten(outputs):
  498. if isinstance(var, (Variable)):
  499. output_vars_dict[var.name] = var
  500. if output_spec is None:
  501. result_list = list(output_vars_dict.values())
  502. elif output_spec is not None and len(output_spec) == len(
  503. output_vars_dict
  504. ):
  505. result_list = list(output_vars_dict.values())
  506. for var in output_spec:
  507. if var.name not in output_vars_dict:
  508. warnings.warn(name_no_exists_error % var.name)
  509. else:
  510. for var in output_spec:
  511. if var.name not in output_vars_dict:
  512. raise ValueError(name_no_exists_error % var.name)
  513. else:
  514. result_list.append(output_vars_dict[var.name])
  515. return result_list
  516. # NOTE(chenweihang): [ Handling of use cases of API paddle.jit.load ]
  517. # `paddle.jit.load` may be used to load saved results of:
  518. # 1. Expected cases:
  519. # - paddle.jit.save
  520. # - paddle.static.save_inference_model
  521. # 2. Error cases:
  522. # - paddle.save: no .pdmodel for prefix
  523. # - paddle.static.save: no .pdiparams but .pdparams exists
  524. # - paddle.base.io.save_params/save_persistables: no __model__
  525. # TODO(chenweihang): polish error message in above error cases
  526. def _build_load_path_and_config(path, config):
  527. # NOTE(chenweihang): If both [prefix save format] and [directory save format] exist,
  528. # raise error, avoid confusing behavior
  529. if use_pir_api():
  530. model_suffix = PIR_INFER_MODEL_SUFFIX
  531. else:
  532. model_suffix = INFER_MODEL_SUFFIX
  533. prefix_format_path = path + model_suffix
  534. prefix_format_exist = os.path.exists(prefix_format_path)
  535. directory_format_exist = os.path.isdir(path)
  536. if prefix_format_exist and directory_format_exist:
  537. raise ValueError(
  538. f"The {path}.pdmodel and {path} directory exist at the same time, "
  539. "don't know which one to load, please make sure that the specified target "
  540. "of ``path`` is unique."
  541. )
  542. elif not prefix_format_exist and not directory_format_exist:
  543. raise ValueError(
  544. f"The ``path`` ({path}) to load model not exists. "
  545. "Please make sure that *.pdmodel exists or "
  546. "don't using ``skip_forward=True`` to jit.save."
  547. )
  548. else:
  549. if prefix_format_exist:
  550. file_prefix = os.path.basename(path)
  551. model_path = os.path.dirname(path)
  552. if config.model_filename is not None:
  553. warnings.warn(
  554. "When loading the result saved with the "
  555. "specified file prefix, the ``model_filename`` config does "
  556. "not take effect."
  557. )
  558. config.model_filename = file_prefix + model_suffix
  559. if config.params_filename is not None:
  560. warnings.warn(
  561. "When loading the result saved with the "
  562. "specified file prefix, the ``params_filename`` config does "
  563. "not take effect."
  564. )
  565. config.params_filename = file_prefix + INFER_PARAMS_SUFFIX
  566. else:
  567. # Compatible with the old save_inference_model format
  568. model_path = path
  569. return model_path, config
  570. _save_pre_hooks_lock = threading.Lock()
  571. _save_pre_hooks = []
  572. class HookRemoveHelper:
  573. """A HookRemoveHelper that can be used to remove hook."""
  574. def __init__(self, hook):
  575. self._hook = hook
  576. def remove(self):
  577. _remove_save_pre_hook(self._hook)
  578. def _register_save_pre_hook(hook):
  579. """
  580. Register a save pre-hook for `paddle.jit.save`.
  581. This hook will be executed before `save` function has been invoked.
  582. hook(layer, input_spec, configs) -> None
  583. - layer (Layer|function): This argument is corresponding to `layer` in `paddle.jit.save`.
  584. - input_spec (list or tuple[InputSpec|Tensor|Python built-in variable]): This argument is corresponding to `input_spec` in `paddle.jit.save`.
  585. - configs (dict): This argument is corresponding to `configs` in `paddle.jit.save`.
  586. Args:
  587. hook(function): a function registered as a save pre-hook
  588. Returns:
  589. HookRemoveHelper: a HookRemoveHelper object that can be used to remove the added hook by calling `hook_remove_helper.remove()`.
  590. Examples:
  591. .. code-block:: python
  592. >>> # doctest: +SKIP('`paddle.jit.api.to_static` can not run in xdoctest')
  593. >>> import numpy as np
  594. >>> import paddle
  595. >>> IMAGE_SIZE = 256
  596. >>> CLASS_NUM = 10
  597. >>> class LinearNet(paddle.nn.Layer):
  598. ... def __init__(self):
  599. ... super().__init__()
  600. ... self._linear = paddle.nn.Linear(IMAGE_SIZE, CLASS_NUM)
  601. ...
  602. ... def forward(self, x):
  603. ... return self._linear(x)
  604. ...
  605. >>> saving_count = 0
  606. >>> def save_pre_hook(layer, input_spec, configs):
  607. ... global saving_count
  608. ... saving_count += 1
  609. ...
  610. >>> remove_handler = paddle.jit.api._register_save_pre_hook(save_pre_hook)
  611. >>> layer = LinearNet()
  612. >>> paddle.jit.save(layer, "/tmp", [paddle.static.InputSpec(shape=[-1, IMAGE_SIZE])])
  613. >>> print(saving_count)
  614. 1
  615. >>> remove_handler.remove()
  616. >>> paddle.jit.save(layer, "/tmp", [paddle.static.InputSpec(shape=[-1, IMAGE_SIZE])])
  617. >>> print(saving_count)
  618. 1
  619. """
  620. global _save_pre_hooks_lock
  621. global _save_pre_hooks
  622. _save_pre_hooks_lock.acquire()
  623. if hook not in _save_pre_hooks:
  624. _save_pre_hooks.append(hook)
  625. _save_pre_hooks_lock.release()
  626. return HookRemoveHelper(hook)
  627. def _clear_save_pre_hooks():
  628. global _save_pre_hooks_lock
  629. global _save_pre_hooks
  630. _save_pre_hooks_lock.acquire()
  631. _save_pre_hooks.clear()
  632. _save_pre_hooks_lock.release()
  633. def _remove_save_pre_hook(hook):
  634. global _save_pre_hooks_lock
  635. global _save_pre_hooks
  636. _save_pre_hooks_lock.acquire()
  637. if hook in _save_pre_hooks:
  638. _save_pre_hooks.remove(hook)
  639. _save_pre_hooks_lock.release()
  640. @wrap_decorator
  641. def _run_save_pre_hooks(func):
  642. def wrapper(layer, path, input_spec=None, **configs):
  643. global _save_pre_hooks
  644. for hook in _save_pre_hooks:
  645. hook(layer, input_spec, configs)
  646. func(layer, path, input_spec, **configs)
  647. return wrapper
  648. def _save_property(filename: str, property_vals: list[tuple[Any, str]]):
  649. """class property serialization.
  650. Args:
  651. filename (str): *.meta
  652. property_vals (list[tuple[Any, str]]): class property.
  653. """
  654. def set_property(meta, key, val):
  655. if isinstance(val, float):
  656. meta.set_float(key, val)
  657. elif isinstance(val, int):
  658. meta.set_int(key, val)
  659. elif isinstance(val, str):
  660. meta.set_string(key, val)
  661. elif isinstance(val, (tuple, list)):
  662. if isinstance(val[0], float):
  663. meta.set_floats(key, val)
  664. elif isinstance(val[0], int):
  665. meta.set_ints(key, val)
  666. elif isinstance(val[0], str):
  667. meta.set_strings(key, val)
  668. else:
  669. raise ValueError(f"Note support val type: {type(val)}")
  670. with open(filename, 'wb') as f:
  671. meta = paddle.framework.core.Property()
  672. for item in property_vals:
  673. val, key = item[0], item[1]
  674. set_property(meta, key, val)
  675. f.write(meta.serialize_to_string())
  676. @_run_save_pre_hooks
  677. @switch_to_static_graph
  678. def save(layer, path, input_spec=None, **configs):
  679. """
  680. Saves input Layer or function as ``paddle.jit.TranslatedLayer``
  681. format model, which can be used for inference or fine-tuning after loading.
  682. It will save the translated program and all related persistable
  683. variables of input Layer to given ``path`` .
  684. ``path`` is the prefix of saved objects, and the saved translated program file
  685. suffix is ``.pdmodel`` , the saved persistable variables file suffix is ``.pdiparams`` ,
  686. and here also saved some additional variable description information to a file,
  687. its suffix is ``.pdiparams.info``, these additional information is used in fine-tuning.
  688. The saved model can be loaded by follow APIs:
  689. - ``paddle.jit.load``
  690. - ``paddle.static.load_inference_model``
  691. - Other C++ inference APIs
  692. .. note::
  693. When using ``paddle.jit.save`` to save a function, parameters will not be saved. If you have to
  694. save the parameter, please pass the Layer containing function and parameter to ``paddle.jit.save``.
  695. Args:
  696. layer (Layer|function): The Layer or function to be saved.
  697. path (str): The path prefix to save model. The format is ``dirname/file_prefix`` or ``file_prefix``.
  698. input_spec (list or tuple[InputSpec|Tensor|Python built-in variable], optional): Describes the input of the saved model's forward
  699. method, which can be described by InputSpec or example Tensor. Moreover, we support to specify non-tensor type argument,
  700. such as int, float, string, or list/dict of them.If None, all input variables of
  701. the original Layer's forward method would be the inputs of the saved model. Default None.
  702. **configs (dict, optional): Other save configuration options for compatibility. We do not
  703. recommend using these configurations, they may be removed in the future. If not necessary,
  704. DO NOT use them. Default None.
  705. The following options are currently supported:
  706. (1) output_spec (list[Tensor|Value|int]): Selects the output targets of the saved model,
  707. By default, all return variables of original Layer's forward method are kept as the
  708. output of the saved model. If the provided ``output_spec`` list is not all output variables,
  709. the saved model will be pruned according to the given ``output_spec`` list.
  710. in pir mode, Tensor is not supported, because value has no name in most cases,
  711. which can't be used to judge which tensor corresponds to which value; the value can't be found
  712. if the saved program is not the same as the program that includes output_spec, so we need to
  713. use the position of the output.
  714. Returns:
  715. None
  716. Examples:
  717. .. code-block:: python
  718. >>> # doctest: +SKIP('`paddle.jit.to_static` can not run in xdoctest')
  719. >>> # example 1: save layer
  720. >>> import numpy as np
  721. >>> import paddle
  722. >>> import paddle.nn as nn
  723. >>> import paddle.optimizer as opt
  724. >>> BATCH_SIZE = 16
  725. >>> BATCH_NUM = 4
  726. >>> EPOCH_NUM = 4
  727. >>> IMAGE_SIZE = 784
  728. >>> CLASS_NUM = 10
  729. >>> # define a random dataset
  730. >>> class RandomDataset(paddle.io.Dataset):
  731. ... def __init__(self, num_samples):
  732. ... self.num_samples = num_samples
  733. ...
  734. ... def __getitem__(self, idx):
  735. ... image = np.random.random([IMAGE_SIZE]).astype('float32')
  736. ... label = np.random.randint(0, CLASS_NUM - 1, (1, )).astype('int64')
  737. ... return image, label
  738. ...
  739. ... def __len__(self):
  740. ... return self.num_samples
  741. >>> class LinearNet(nn.Layer):
  742. ... def __init__(self):
  743. ... super().__init__()
  744. ... self._linear = nn.Linear(IMAGE_SIZE, CLASS_NUM)
  745. ...
  746. ... @paddle.jit.to_static
  747. ... def forward(self, x):
  748. ... return self._linear(x)
  749. >>> def train(layer, loader, loss_fn, opt):
  750. ... for epoch_id in range(EPOCH_NUM):
  751. ... for batch_id, (image, label) in enumerate(loader()):
  752. ... out = layer(image)
  753. ... loss = loss_fn(out, label)
  754. ... loss.backward()
  755. ... opt.step()
  756. ... opt.clear_grad()
  757. ... print("Epoch {} batch {}: loss = {}".format(
  758. ... epoch_id, batch_id, np.mean(loss.numpy())))
  759. >>> # 1. train & save model.
  760. >>> # create network
  761. >>> layer = LinearNet()
  762. >>> loss_fn = nn.CrossEntropyLoss()
  763. >>> adam = opt.Adam(learning_rate=0.001, parameters=layer.parameters())
  764. >>> # create data loader
  765. >>> dataset = RandomDataset(BATCH_NUM * BATCH_SIZE)
  766. >>> loader = paddle.io.DataLoader(dataset,
  767. ... batch_size=BATCH_SIZE,
  768. ... shuffle=True,
  769. ... drop_last=True,
  770. ... num_workers=2
  771. ... )
  772. >>> # train
  773. >>> train(layer, loader, loss_fn, adam)
  774. >>> # save
  775. >>> path = "example_model/linear"
  776. >>> paddle.jit.save(layer, path)
  777. >>> # example 2: save function
  778. >>> import paddle
  779. >>> from paddle.static import InputSpec
  780. >>> def save_function():
  781. ... @paddle.jit.to_static
  782. ... def fun(inputs):
  783. ... return paddle.tanh(inputs)
  784. ...
  785. ... path = 'test_jit_save_load_function_1/func'
  786. ... inps = paddle.rand([3, 6])
  787. ... origin = fun(inps)
  788. ...
  789. ... paddle.jit.save(fun, path)
  790. ... load_func = paddle.jit.load(path)
  791. ...
  792. ... load_result = load_func(inps)
  793. ... print((load_result - origin).abs().max() < 1e-10)
  794. >>> save_function()
  795. """
  796. # 1. input build & check
  797. prog_translator = ProgramTranslator()
  798. is_prim_infer = core._is_fwd_prim_enabled() and core._is_bwd_prim_enabled()
  799. if not prog_translator.enable_to_static:
  800. raise RuntimeError(
  801. "The paddle.jit.save doesn't work when setting 'paddle.jit.enable_to_static' to False."
  802. )
  803. if not (
  804. isinstance(layer, (Layer, StaticFunction)) or inspect.isfunction(layer)
  805. ):
  806. raise TypeError(
  807. f"The input of paddle.jit.save should be 'Layer' or 'Function', but received input type is {type(layer)}."
  808. )
  809. elif inspect.isfunction(layer) or isinstance(layer, StaticFunction):
  810. warnings.warn(
  811. 'What you save is a function, and `jit.save` will generate the name of the model file according to `path` you specify. When loading these files with `jit.load`, you get a `TranslatedLayer` whose inference result is the same as the inference result of the function you saved.'
  812. )
  813. # NOTE(chenweihang): If the input layer be wrapped by DataParallel,
  814. # the args and kwargs of forward method will can't be parsed by
  815. # function_spec, so here we save DataParallel._layers instead
  816. # DataParallel it self
  817. # NOTE(chenweihang): using inner_layer, do not change input layer
  818. if isinstance(layer, paddle.DataParallel):
  819. inner_layer = layer._layers
  820. else:
  821. inner_layer = layer
  822. # path check
  823. file_prefix = os.path.basename(path)
  824. if file_prefix == "":
  825. raise ValueError(
  826. "The input path MUST be format of dirname/file_prefix "
  827. "[dirname\\file_prefix in Windows system], but received "
  828. "file_prefix is empty string."
  829. )
  830. dirname = os.path.dirname(path)
  831. if dirname and not os.path.exists(dirname):
  832. os.makedirs(dirname)
  833. # avoid change user given input_spec
  834. inner_input_spec = None
  835. if input_spec is not None:
  836. if isinstance(layer, Layer):
  837. for attr_func in dir(inner_layer):
  838. static_func = getattr(inner_layer, attr_func, None)
  839. if (
  840. isinstance(static_func, StaticFunction)
  841. and 'forward' != attr_func
  842. ):
  843. raise ValueError(
  844. f"If there are static functions other than 'forward' that need to be saved, the input 'input_spec' should be None, but received the type of 'input_spec' is {type(input_spec)}."
  845. )
  846. if not isinstance(input_spec, (list, tuple)):
  847. raise TypeError(
  848. f"The input input_spec should be 'list', but received input_spec's type is {type(input_spec)}."
  849. )
  850. inner_input_spec = []
  851. for var in paddle.utils.flatten(input_spec):
  852. if isinstance(var, paddle.static.InputSpec):
  853. inner_input_spec.append(var)
  854. elif isinstance(
  855. var, (core.eager.Tensor, Variable, paddle.pir.Value)
  856. ):
  857. inner_input_spec.append(
  858. paddle.static.InputSpec.from_tensor(var)
  859. )
  860. else:
  861. # NOTE(Aurelius84): Support non-Tensor type in `input_spec`.
  862. inner_input_spec.append(var)
  863. # parse configs
  864. configs = _parse_save_configs(configs)
  865. # whether outermost layer has pre/post hook, if does, we need also save
  866. # these operators in program.
  867. with_hook = configs.with_hook
  868. combine_params = configs.combine_params
  869. if combine_params:
  870. configs._program_only = True
  871. scope = core.Scope()
  872. extra_var_info = {}
  873. if isinstance(layer, Layer):
  874. functions = list(set(dir(inner_layer)))
  875. functions = sorted(functions)
  876. if inner_layer._forward_pre_hooks or inner_layer._forward_post_hooks:
  877. with_hook = True
  878. else:
  879. # layer is function
  880. functions = [
  881. layer,
  882. ]
  883. combine_vars = {}
  884. combine_program = []
  885. property_vals = [] # (value, key)
  886. concrete_program = None
  887. for attr_func in functions:
  888. if isinstance(layer, Layer):
  889. static_func = get_ast_static_function(
  890. getattr(inner_layer, attr_func, None)
  891. )
  892. if isinstance(static_func, StaticFunction):
  893. if static_func.is_property:
  894. # property method to be exported
  895. immediate_val = static_func()
  896. property_vals.append(
  897. (
  898. immediate_val,
  899. layer.__class__.__name__ + '.' + attr_func,
  900. )
  901. )
  902. continue
  903. concrete_program = (
  904. static_func.concrete_program_specify_input_spec(
  905. inner_input_spec,
  906. with_hook=with_hook,
  907. is_prim_infer=is_prim_infer,
  908. )
  909. )
  910. elif 'forward' == attr_func:
  911. if configs.skip_forward:
  912. # do not jit.save forward function
  913. continue
  914. # transform in jit.save, if input_spec is incomplete, declarative will throw error
  915. # inner_input_spec is list[InputSpec], it should be packed with same structure
  916. # as original input_spec here.
  917. if inner_input_spec:
  918. inner_input_spec = paddle.utils.pack_sequence_as(
  919. input_spec, inner_input_spec
  920. )
  921. static_forward = to_static(
  922. inner_layer.forward,
  923. input_spec=inner_input_spec,
  924. full_graph=True,
  925. )
  926. concrete_program = (
  927. static_forward.concrete_program_specify_input_spec(
  928. with_hook=with_hook, is_prim_infer=is_prim_infer
  929. )
  930. )
  931. # the input_spec has been used in declarative, which is equal to
  932. # @to_static with input_spec and jit.save without input_spec,
  933. # avoid needless warning
  934. inner_input_spec = None
  935. else:
  936. continue
  937. else:
  938. # When layer is a function
  939. if isinstance(attr_func, StaticFunction):
  940. static_func = get_ast_static_function(attr_func)
  941. if static_func.is_property:
  942. # property method to be exported
  943. immediate_val = static_func()
  944. property_vals.append((immediate_val, static_func))
  945. continue
  946. concrete_program = (
  947. static_func.concrete_program_specify_input_spec(
  948. inner_input_spec, is_prim_infer=is_prim_infer
  949. )
  950. )
  951. else:
  952. static_func = get_ast_static_function(attr_func)
  953. if inner_input_spec:
  954. inner_input_spec = paddle.utils.pack_sequence_as(
  955. input_spec, inner_input_spec
  956. )
  957. static_function = to_static(
  958. static_func,
  959. input_spec=inner_input_spec,
  960. full_graph=True,
  961. )
  962. concrete_program = static_function.concrete_program
  963. if static_function._class_instance is None:
  964. warnings.warn(
  965. f'`jit.save` will only save the `Program`, not the parameters. If you have to save the parameters, please make sure that {layer} is a member function of `paddle.nn.Layer` and the saved parameters are in `state_dict`'
  966. )
  967. # when save multi `StaticFunction`, all `StaticFunction` share params.
  968. dygraph_state_dict = None
  969. if isinstance(inner_layer, Layer):
  970. dygraph_state_dict = inner_layer.to_static_state_dict()
  971. elif isinstance(attr_func, StaticFunction):
  972. if static_func._class_instance:
  973. dygraph_state_dict = (
  974. static_func._class_instance.to_static_state_dict()
  975. )
  976. if dygraph_state_dict:
  977. # NOTE(chenweihang): we maintain the mapping of variable name to
  978. # structured name, the buffer variable (non-persistable)
  979. # saved to inference program may not need by dygraph Layer,
  980. # we only record the state_dict variable's structured name
  981. state_names_dict = {}
  982. state_var_dict = {}
  983. for structured_name, var in dygraph_state_dict.items():
  984. state_names_dict[var.name] = structured_name
  985. state_var_dict[var.name] = var
  986. # 3. share parameters from Layer to scope & record var info
  987. with dygraph.guard():
  988. if use_pir_api():
  989. for tensor, value in zip(*concrete_program.parameters):
  990. if not value.persistable:
  991. continue
  992. param_or_buffer_tensor = scope.var(value.name).get_tensor()
  993. src_tensor = (
  994. state_var_dict[tensor.name].value().get_tensor()
  995. )
  996. param_or_buffer_tensor._share_data_with(src_tensor)
  997. else:
  998. for param_or_buffer in concrete_program.parameters:
  999. # share to scope
  1000. if param_or_buffer.type == core.VarDesc.VarType.VOCAB:
  1001. scr_tensor = param_or_buffer.value().get_map_tensor()
  1002. tgt_var = scope.var(param_or_buffer.name)
  1003. tgt_var.set_vocab(scr_tensor)
  1004. else:
  1005. param_or_buffer_tensor = scope.var(
  1006. param_or_buffer.name
  1007. ).get_tensor()
  1008. # src_tensor = param_or_buffer.value().get_tensor()
  1009. src_tensor = (
  1010. state_var_dict[param_or_buffer.name]
  1011. .value()
  1012. .get_tensor()
  1013. )
  1014. param_or_buffer_tensor._share_data_with(src_tensor)
  1015. # record var info
  1016. if param_or_buffer.name not in extra_var_info:
  1017. extra_info_dict = {}
  1018. if param_or_buffer.name in state_names_dict:
  1019. extra_info_dict[
  1020. 'structured_name'
  1021. ] = state_names_dict[param_or_buffer.name]
  1022. extra_info_dict[
  1023. 'stop_gradient'
  1024. ] = param_or_buffer.stop_gradient
  1025. if isinstance(param_or_buffer, EagerParamBase):
  1026. extra_info_dict[
  1027. 'trainable'
  1028. ] = param_or_buffer.trainable
  1029. extra_var_info[param_or_buffer.name] = extra_info_dict
  1030. # 4. build input & output of save_inference_model
  1031. # NOTE(chenweihang): [ Get input variables name ]
  1032. # There are two cases, whether to prune the inputs or not
  1033. # - not prune inputs (recommend):
  1034. # - the len(input_spec) == len((concrete_program.inputs) - 1
  1035. # - here can use concrete_program.inputs directly
  1036. # - prune inputs:
  1037. # - the input_spec length < len((concrete_program.inputs) - 1
  1038. # - the input_spec's name should be in concrete_program.inputs
  1039. input_vars, input_var_names = _get_input_var_and_names(
  1040. concrete_program.inputs,
  1041. inner_input_spec,
  1042. configs.input_names_after_prune,
  1043. )
  1044. # NOTE(chenweihang): [ Get output variables ]
  1045. # the rule is like [ Get input variables name ]. For output var,
  1046. # we only support Tensor spec, and actually, we only need the
  1047. # var name of output, and we don't recommended to use output_spec
  1048. # NOTE(Ruting): in pir mode, Tensor is not supported, because value has no name in most cases,
  1049. # which can't be used to judge which tensor corresponds to which value; the value can't be found
  1050. # if the saved program is not the same as the program that includes output_spec, so we need to
  1051. # use the position of the output.
  1052. output_vars = _get_output_vars(
  1053. concrete_program.outputs, configs.output_spec, with_hook
  1054. )
  1055. # 5. save inference model
  1056. # construct new save_inference_model arguments
  1057. model_path = dirname
  1058. # NOTE(chenweihang): because prefix contains model and params filename,
  1059. # so we don't support set model_filename & params_filename
  1060. if 'forward' == attr_func or not isinstance(layer, Layer):
  1061. model_filename = file_prefix + INFER_MODEL_SUFFIX
  1062. params_filename = file_prefix + INFER_PARAMS_SUFFIX
  1063. path_prefix = file_prefix
  1064. else:
  1065. model_filename = file_prefix + '.' + attr_func + INFER_MODEL_SUFFIX
  1066. params_filename = (
  1067. file_prefix + '.' + attr_func + INFER_PARAMS_SUFFIX
  1068. )
  1069. path_prefix = file_prefix + '.' + attr_func
  1070. file_path = os.path.join(model_path, path_prefix)
  1071. with scope_guard(scope):
  1072. if use_pir_api():
  1073. value_map = paddle.pir.IrMapping()
  1074. clone_program = concrete_program.main_program.clone(value_map)
  1075. clone_input_vars = []
  1076. for v in input_vars:
  1077. if type(v) is paddle.static.InputSpec:
  1078. name = v.name
  1079. for op in clone_program.global_block().ops:
  1080. if (
  1081. op.name() == 'pd_op.data'
  1082. and op.attrs()["name"] == name
  1083. ):
  1084. clone_input_vars.append(op.result(0))
  1085. else:
  1086. clone_input_vars.append(value_map.look_up(v))
  1087. clone_output_vars = [value_map.look_up(v) for v in output_vars]
  1088. else:
  1089. input_vars = [
  1090. concrete_program.main_program.global_block().var(name)
  1091. for name in input_var_names
  1092. ]
  1093. clone_program = concrete_program.main_program.clone()
  1094. clone_input_vars = input_vars
  1095. clone_output_vars = output_vars
  1096. save_inference_model(
  1097. path_prefix=file_path,
  1098. feed_vars=clone_input_vars,
  1099. fetch_vars=clone_output_vars,
  1100. executor=Executor(_current_expected_place()),
  1101. program=clone_program,
  1102. clip_extra=configs.clip_extra,
  1103. skip_prune_program=configs.skip_prune_program,
  1104. )
  1105. if combine_params:
  1106. if use_pir_api():
  1107. # NOTE(Ruting): concrete_program has been pruned when init partialProgramLayer,
  1108. # so we do not neet to prune again.
  1109. for var in concrete_program.main_program.list_vars():
  1110. if var.persistable:
  1111. combine_vars[var.name] = var
  1112. # NOTE(Ruting): concrete_program will delete after this loop item,
  1113. # value delete at the same time, so we use list to Extend its lifecycle
  1114. combine_program.append(concrete_program.main_program)
  1115. else:
  1116. clone_main_program = concrete_program.main_program.clone()
  1117. clone_main_program = clone_main_program._prune_with_input(
  1118. input_var_names, output_vars
  1119. )
  1120. for block in clone_main_program.blocks:
  1121. combine_vars.update(block.vars)
  1122. # save shared params
  1123. if combine_params:
  1124. # sort vars by name
  1125. combine_vars = sorted(combine_vars.items(), key=lambda item: item[0])
  1126. ordered_vars = []
  1127. for name, var in combine_vars:
  1128. ordered_vars.append(var)
  1129. params_filename = file_prefix + INFER_PARAMS_SUFFIX
  1130. with scope_guard(scope):
  1131. if use_pir_api():
  1132. paddle.static.save_vars(
  1133. Executor(_current_expected_place()),
  1134. dirname=model_path,
  1135. vars=ordered_vars,
  1136. filename=params_filename,
  1137. )
  1138. else:
  1139. paddle.static.save_vars(
  1140. Executor(_current_expected_place()),
  1141. dirname=model_path,
  1142. vars=list(
  1143. filter(
  1144. paddle.framework.io_utils.is_persistable,
  1145. ordered_vars,
  1146. )
  1147. ),
  1148. filename=params_filename,
  1149. )
  1150. # save property
  1151. property_save_path = os.path.join(
  1152. os.path.normpath(model_path), file_prefix + INFER_PROPERTY_SUFFIX
  1153. )
  1154. _save_property(property_save_path, property_vals)
  1155. # NOTE(chenweihang): [ Save extra variable info ]
  1156. # save_inference_model will lose some important variable information, including:
  1157. # - Variable name and correspondence (when saved variables as one file)
  1158. # - Variable.stop_gradient information
  1159. # - Which persistent variable are parameter and which are not
  1160. # - Parameter.trainable information
  1161. #
  1162. # The lost information cannot be recovered when it is loaded again,
  1163. # so if we want to perform fine-tune after loading, we may need to
  1164. # configure redundant information to proceed.
  1165. #
  1166. # Due to compatibility issues, we cannot change the original storage structure,
  1167. # but we can save these information in `jit.save` without changing the original
  1168. # storage to improve user experience. So we save extra information into
  1169. # file `***.pdiparams.info`
  1170. # "layer" can only be Layer or function or StaticFunction.
  1171. contain_parameter = False
  1172. if concrete_program is not None:
  1173. for var in concrete_program.main_program.list_vars():
  1174. if use_pir_api():
  1175. is_persistable = (
  1176. var.get_defining_op().has_attr("persistable")
  1177. and var.get_defining_op().attrs()["persistable"] is True
  1178. )
  1179. contain_parameter |= is_persistable
  1180. else:
  1181. contain_parameter |= isinstance(var, Parameter)
  1182. if (isinstance(layer, Layer) or contain_parameter) and extra_var_info:
  1183. with scope_guard(scope):
  1184. extra_var_info_path = path + INFER_PARAMS_INFO_SUFFIX
  1185. with open(extra_var_info_path, 'wb') as f:
  1186. pickle.dump(extra_var_info, f, protocol=2)
  1187. @dygraph_only
  1188. def load(path, **configs):
  1189. """
  1190. :api_attr: imperative
  1191. Load model saved by ``paddle.jit.save`` or ``paddle.static.save_inference_model`` or
  1192. paddle 1.x API ``paddle.static.save_inference_model`` as ``paddle.jit.TranslatedLayer``,
  1193. then performing inference or fine-tune training.
  1194. .. note::
  1195. If you load model saved by ``paddle.static.save_inference_model`` ,
  1196. there will be the following limitations when using it in fine-tuning:
  1197. 1. Imperative mode do not support LoDTensor. All original model's feed targets or parameters that depend on LoD are temporarily unavailable.
  1198. 2. All saved model's feed targets need to be passed into TranslatedLayer's forward function.
  1199. 3. The variable's ``stop_gradient`` information is lost and can not be recovered.
  1200. 4. The parameter's ``trainable`` information is lost and can not be recovered.
  1201. Args:
  1202. path (str): The path prefix to load model. The format is ``dirname/file_prefix`` or ``file_prefix`` .
  1203. **configs (dict, optional): Other load configuration options for compatibility. We do not
  1204. recommend using these configurations, they may be removed in the future. If not necessary,
  1205. DO NOT use them. Default None.
  1206. The following options are currently supported:
  1207. (1) model_filename (str): The inference model file name of the paddle 1.x
  1208. ``save_inference_model`` save format. Default file name is :code:`__model__` .
  1209. (2) params_filename (str): The persistable variables file name of the paddle 1.x
  1210. ``save_inference_model`` save format. No default file name, save variables separately
  1211. by default.
  1212. Returns:
  1213. TranslatedLayer: A Layer object can run saved translated model.
  1214. Examples:
  1215. 1. Load model saved by ``paddle.jit.save`` then performing inference and fine-tune training.
  1216. .. code-block:: python
  1217. :name: code-example1
  1218. >>> # doctest: +SKIP('`paddle.jit.to_static` can not run in xdoctest')
  1219. >>> import numpy as np
  1220. >>> import paddle
  1221. >>> import paddle.nn as nn
  1222. >>> import paddle.optimizer as opt
  1223. >>> BATCH_SIZE = 16
  1224. >>> BATCH_NUM = 4
  1225. >>> EPOCH_NUM = 4
  1226. >>> IMAGE_SIZE = 784
  1227. >>> CLASS_NUM = 10
  1228. >>> # define a random dataset
  1229. >>> class RandomDataset(paddle.io.Dataset):
  1230. ... def __init__(self, num_samples):
  1231. ... self.num_samples = num_samples
  1232. ...
  1233. ... def __getitem__(self, idx):
  1234. ... image = np.random.random([IMAGE_SIZE]).astype('float32')
  1235. ... label = np.random.randint(0, CLASS_NUM - 1, (1, )).astype('int64')
  1236. ... return image, label
  1237. ...
  1238. ... def __len__(self):
  1239. ... return self.num_samples
  1240. >>> class LinearNet(nn.Layer):
  1241. ... def __init__(self):
  1242. ... super().__init__()
  1243. ... self._linear = nn.Linear(IMAGE_SIZE, CLASS_NUM)
  1244. ...
  1245. ... @paddle.jit.to_static
  1246. ... def forward(self, x):
  1247. ... return self._linear(x)
  1248. ...
  1249. >>> def train(layer, loader, loss_fn, opt):
  1250. ... for epoch_id in range(EPOCH_NUM):
  1251. ... for batch_id, (image, label) in enumerate(loader()):
  1252. ... out = layer(image)
  1253. ... loss = loss_fn(out, label)
  1254. ... loss.backward()
  1255. ... opt.step()
  1256. ... opt.clear_grad()
  1257. ... print("Epoch {} batch {}: loss = {}".format(
  1258. ... epoch_id, batch_id, np.mean(loss.numpy())))
  1259. >>> # 1. train & save model.
  1260. >>> # create network
  1261. >>> layer = LinearNet()
  1262. >>> loss_fn = nn.CrossEntropyLoss()
  1263. >>> adam = opt.Adam(learning_rate=0.001, parameters=layer.parameters())
  1264. >>> # create data loader
  1265. >>> dataset = RandomDataset(BATCH_NUM * BATCH_SIZE)
  1266. >>> loader = paddle.io.DataLoader(
  1267. ... dataset,
  1268. ... batch_size=BATCH_SIZE,
  1269. ... shuffle=True,
  1270. ... drop_last=True,
  1271. ... num_workers=2
  1272. ... )
  1273. >>> # train
  1274. >>> train(layer, loader, loss_fn, adam)
  1275. >>> # save
  1276. >>> path = "example_model/linear"
  1277. >>> paddle.jit.save(layer, path)
  1278. >>> # 2. load model
  1279. >>> # load
  1280. >>> loaded_layer = paddle.jit.load(path)
  1281. >>> # inference
  1282. >>> loaded_layer.eval()
  1283. >>> x = paddle.randn([1, IMAGE_SIZE], 'float32')
  1284. >>> pred = loaded_layer(x)
  1285. >>> # fine-tune
  1286. >>> loaded_layer.train()
  1287. >>> adam = opt.Adam(learning_rate=0.001, parameters=loaded_layer.parameters())
  1288. >>> train(loaded_layer, loader, loss_fn, adam)
  1289. 2. Load model saved by ``paddle.static.save_inference_model`` then performing and fine-tune training.
  1290. .. code-block:: python
  1291. :name: code-example2
  1292. >>> # doctest: +SOLO('can not use multiprocessing testing `DataLoader`')
  1293. >>> import numpy as np
  1294. >>> import paddle
  1295. >>> import paddle.static as static
  1296. >>> import paddle.nn as nn
  1297. >>> import paddle.optimizer as opt
  1298. >>> import paddle.nn.functional as F
  1299. >>> BATCH_SIZE = 16
  1300. >>> BATCH_NUM = 4
  1301. >>> EPOCH_NUM = 4
  1302. >>> IMAGE_SIZE = 784
  1303. >>> CLASS_NUM = 10
  1304. >>> # define a random dataset
  1305. >>> class RandomDataset(paddle.io.Dataset):
  1306. ... def __init__(self, num_samples):
  1307. ... self.num_samples = num_samples
  1308. ...
  1309. ... def __getitem__(self, idx):
  1310. ... image = np.random.random([IMAGE_SIZE]).astype('float32')
  1311. ... label = np.random.randint(0, CLASS_NUM - 1, (1, )).astype('int64')
  1312. ... return image, label
  1313. ...
  1314. ... def __len__(self):
  1315. ... return self.num_samples
  1316. >>> paddle.enable_static()
  1317. >>> image = static.data(name='image', shape=[None, 784], dtype='float32')
  1318. >>> label = static.data(name='label', shape=[None, 1], dtype='int64')
  1319. >>> pred = static.nn.fc(x=image, size=10, activation='softmax')
  1320. >>> loss = F.cross_entropy(input=pred, label=label)
  1321. >>> avg_loss = paddle.mean(loss)
  1322. >>> optimizer = paddle.optimizer.SGD(learning_rate=0.001)
  1323. >>> optimizer.minimize(avg_loss)
  1324. >>> place = paddle.CPUPlace()
  1325. >>> exe = static.Executor(place)
  1326. >>> exe.run(static.default_startup_program())
  1327. >>> # create data loader
  1328. >>> dataset = RandomDataset(BATCH_NUM * BATCH_SIZE)
  1329. >>> loader = paddle.io.DataLoader(dataset,
  1330. ... feed_list=[image, label],
  1331. ... places=place,
  1332. ... batch_size=BATCH_SIZE,
  1333. ... shuffle=True,
  1334. ... drop_last=True,
  1335. ... return_list=False,
  1336. ... num_workers=2
  1337. ... )
  1338. >>> # 1. train and save inference model
  1339. >>> for data in loader():
  1340. ... exe.run(
  1341. ... static.default_main_program(),
  1342. ... feed=data,
  1343. ... fetch_list=[avg_loss]
  1344. ... )
  1345. >>> model_path = "fc.example.model"
  1346. >>> paddle.static.save_inference_model(
  1347. ... model_path,
  1348. ... [image],
  1349. ... [pred],
  1350. ... exe
  1351. ... )
  1352. >>> # 2. load model
  1353. >>> # enable dygraph mode
  1354. >>> paddle.disable_static(place)
  1355. >>> # load
  1356. >>> fc = paddle.jit.load(model_path)
  1357. >>> # inference
  1358. >>> fc.eval()
  1359. >>> x = paddle.randn([1, IMAGE_SIZE], 'float32')
  1360. >>> pred = fc(x)
  1361. >>> # fine-tune
  1362. >>> fc.train()
  1363. >>> loss_fn = nn.CrossEntropyLoss()
  1364. >>> adam = opt.Adam(learning_rate=0.001, parameters=fc.parameters())
  1365. >>> loader = paddle.io.DataLoader(dataset,
  1366. ... places=place,
  1367. ... batch_size=BATCH_SIZE,
  1368. ... shuffle=True,
  1369. ... drop_last=True,
  1370. ... num_workers=2
  1371. ... )
  1372. >>> for epoch_id in range(EPOCH_NUM):
  1373. ... for batch_id, (image, label) in enumerate(loader()):
  1374. ... out = fc(image)
  1375. ... loss = loss_fn(out, label)
  1376. ... loss.backward()
  1377. ... adam.step()
  1378. ... adam.clear_grad()
  1379. ... print("Epoch {} batch {}: loss = {}".format(
  1380. ... epoch_id, batch_id, np.mean(loss.numpy())))
  1381. """
  1382. # 1. construct correct config
  1383. config = _parse_load_config(configs)
  1384. model_path, config = _build_load_path_and_config(path, config)
  1385. if use_pir_api():
  1386. return PirTranslatedLayer._construct(model_path, config)
  1387. else:
  1388. return TranslatedLayer._construct(model_path, config)
  1389. def set_dynamic_shape(variable, shape_list):
  1390. if paddle.base.dygraph.base.in_to_static_mode():
  1391. if isinstance(variable, paddle.base.framework.Variable):
  1392. variable.desc.set_shape(shape_list)
  1393. elif isinstance(variable, paddle.pir.Value):
  1394. variable.set_shape(shape_list)
  1395. else:
  1396. raise TypeError(
  1397. "In to_static mode, variable must be a Variable or Value"
  1398. )
  1399. else:
  1400. # in dygraph mode, dynamic shape is not needed, just do nothing.
  1401. return
  1402. def get_ast_static_function(function):
  1403. if isinstance(function, SymbolicStaticFunction):
  1404. if function._class_instance:
  1405. dygraph_function = types.MethodType(
  1406. function._dygraph_function, function._class_instance
  1407. )
  1408. else:
  1409. dygraph_function = function._dygraph_function
  1410. if function._function_spec._input_spec is None:
  1411. ast_static_function = ASTStaticFunction(
  1412. dygraph_function,
  1413. function.last_call_input_spec,
  1414. **function._kwargs,
  1415. )
  1416. return ast_static_function
  1417. else:
  1418. ast_static_function = ASTStaticFunction(
  1419. dygraph_function,
  1420. function._function_spec._input_spec,
  1421. **function._kwargs,
  1422. )
  1423. return ast_static_function
  1424. return function