translated_layer.py 60 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156115711581159116011611162116311641165116611671168116911701171117211731174117511761177117811791180118111821183118411851186118711881189119011911192119311941195119611971198119912001201120212031204120512061207120812091210121112121213121412151216121712181219122012211222122312241225122612271228122912301231123212331234123512361237123812391240124112421243124412451246124712481249125012511252125312541255125612571258125912601261126212631264126512661267126812691270127112721273127412751276127
71278127912801281128212831284128512861287128812891290129112921293129412951296129712981299130013011302130313041305130613071308130913101311131213131314131513161317131813191320132113221323132413251326132713281329133013311332133313341335133613371338133913401341134213431344134513461347134813491350135113521353135413551356135713581359136013611362136313641365136613671368136913701371137213731374137513761377137813791380138113821383138413851386138713881389139013911392139313941395139613971398139914001401140214031404140514061407140814091410141114121413141414151416141714181419142014211422142314241425142614271428142914301431143214331434143514361437143814391440144114421443144414451446144714481449145014511452145314541455145614571458145914601461146214631464146514661467146814691470147114721473147414751476147714781479148014811482148314841485148614871488148914901491149214931494149514961497149814991500150115021503150415051506150715081509151015111512151315141515151615171518151915201521152215231524152515261527152815291530153115321533153415351536153715381539154015411542154315441545154615471548154915501551155215531554155515561557155815591560156115621563156415651566156715681569157015711572157315741575157615771578157915801581158215831584158515861587158815891590159115921593159415951596159715981599160016011602160316041605160616071608160916101611161216131614161516161617161816191620162116221623162416251626162716281629
  1. # Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. import os
  15. import pickle
  16. import numpy as np
  17. import paddle
  18. from paddle import _legacy_C_ops
  19. from paddle.base import backward, core, framework, unique_name
  20. from paddle.base.data_feeder import check_type
  21. from paddle.base.dygraph.base import switch_to_static_graph
  22. from paddle.base.framework import OpProtoHolder
  23. from paddle.framework import in_dynamic_mode
  24. from paddle.jit.dy2static.partial_program import (
  25. LazyInitialized,
  26. add_build_strategy_for,
  27. )
  28. from paddle.jit.dy2static.utils import construct_grad_names
  29. from paddle.nn.layer import layers
  30. __all__ = []
  31. INFER_MODEL_SUFFIX = ".pdmodel"
  32. INFER_PARAMS_SUFFIX = ".pdiparams"
  33. INFER_PARAMS_INFO_SUFFIX = ".pdiparams.info"
  34. INFER_PROPERTY_SUFFIX = '.meta'
  35. LOADED_VAR_SUFFIX = "load"
  36. PARAMETER_NAME_PREFIX = "param"
  37. BUFFER_NAME_PREFIX = "buffer"
  38. def _load_program_desc(model_file_path):
  39. # 1. parse program desc
  40. with open(model_file_path, "rb") as f:
  41. program_desc_str = f.read()
  42. program_desc = core.ProgramDesc(program_desc_str)
  43. if not core._is_program_version_supported(program_desc._version()):
  44. raise ValueError(
  45. "Unsupported program version: %d\n" % program_desc._version()
  46. )
  47. return program_desc
  48. def _is_persistable(var_desc):
  49. if (
  50. var_desc.type() == core.VarDesc.VarType.FEED_MINIBATCH
  51. or var_desc.type() == core.VarDesc.VarType.FETCH_LIST
  52. or var_desc.type() == core.VarDesc.VarType.READER
  53. or var_desc.type() == core.VarDesc.VarType.RAW
  54. ):
  55. return False
  56. return var_desc.persistable()
  57. def _is_parameter(persistable_var_desc, program_desc):
  58. # 1. firstly, param should be input of op
  59. input_ops = [] # op can be repeated
  60. for block_idx in range(program_desc.num_blocks()):
  61. block = program_desc.block(block_idx)
  62. for op_idx in range(block.op_size()):
  63. op = block.op(op_idx)
  64. # NOTE: parameter is the input of a certain op
  65. if persistable_var_desc.name() in op.input_arg_names():
  66. input_ops.append(op)
  67. # 2. secondly, param should not be output of op or be same op's output
  68. for block_idx in range(program_desc.num_blocks()):
  69. block = program_desc.block(block_idx)
  70. for op_idx in range(block.op_size()):
  71. op = block.op(op_idx)
  72. if persistable_var_desc.name() in op.output_arg_names():
  73. # such as batch_norm_op
  74. if op in input_ops:
  75. continue
  76. else:
  77. return False
  78. return True
  79. def _get_persistable_vars(program_desc):
  80. persistable_vars = []
  81. for i in range(program_desc.num_blocks()):
  82. block = program_desc.block(i)
  83. persistable_vars.extend(list(filter(_is_persistable, block.all_vars())))
  84. return persistable_vars
  85. def _get_persistable_var_names(program_desc):
  86. """
  87. Get all persistable variable names in ProgramDesc.
  88. """
  89. var_names = []
  90. persistable_vars = _get_persistable_vars(program_desc)
  91. for var in persistable_vars:
  92. var_names.append(var.name())
  93. return var_names
  94. def _get_all_var_names(program_desc):
  95. all_var_names = set()
  96. for i in range(program_desc.num_blocks()):
  97. block = program_desc.block(i)
  98. for var in block.all_vars():
  99. all_var_names.add(var.name())
  100. return all_var_names
  101. @switch_to_static_graph
  102. def _append_loaded_suffix(name):
  103. """
  104. Append loaded suffix to the given variable name
  105. e.g. x ==> x.load_0, x.load_0 ==> x.load_0.load_0
  106. """
  107. suffix = LOADED_VAR_SUFFIX
  108. new_name = unique_name.generate_with_ignorable_key('.'.join((name, suffix)))
  109. return new_name
  110. @switch_to_static_graph
  111. def _generate_unique_var_name(prefix):
  112. return unique_name.generate_with_ignorable_key(prefix)
  113. def _append_loaded_suffix_to_var(program_desc):
  114. suffix_varname_dict = {}
  115. persistable_vars = _get_persistable_vars(program_desc)
  116. for var_desc in persistable_vars:
  117. old_name = var_desc.name()
  118. new_name = _append_loaded_suffix(var_desc.name())
  119. suffix_varname_dict[new_name] = old_name
  120. var_desc.set_name(new_name)
  121. for block_idx in range(program_desc.num_blocks()):
  122. block = program_desc.block(block_idx)
  123. block._rename_var(old_name.encode(), new_name.encode())
  124. for op_idx in range(block.op_size()):
  125. op = block.op(op_idx)
  126. op._rename_input(old_name, new_name)
  127. op._rename_output(old_name, new_name)
  128. return suffix_varname_dict
  129. @switch_to_static_graph
  130. def _generate_unique_var_name_sync_with_main_program(prefix):
  131. return unique_name.generate(prefix)
  132. def _get_loaded_var_new_old(program_desc, all_new_old_dict_all):
  133. new_old_dict = {}
  134. persistable_vars = _get_persistable_vars(program_desc)
  135. for var_desc in persistable_vars:
  136. name_new = var_desc.name()
  137. new_old_dict[name_new] = all_new_old_dict_all[name_new]
  138. return new_old_dict
  139. def _rename_var_program_desc(program_desc, include=None, exclude=None):
  140. """
  141. Change the name of the loaded variables.Use 'unique_name.generate' to avoid duplication.
  142. It is used when loading multiple program during inference.
  143. e.g. linear_0.tmp_3 ==> linear_0.tmp_1, x ==> x_0. For double grad, x@GRAD ==> x_0@GRAD
  144. If 'include' is not `None`,variables in include and the corresponding
  145. double grad variables (if exist) are renamed.
  146. If 'exclude' is not `None`,variables that are in exclude and the
  147. corresponding double grad variables (if exist) are not renamed.
  148. Args:
  149. program_desc(ProgramDesc):the variables in it will be modified.
  150. include(List):list of names of variables.
  151. exclude(List):list of names of variables.
  152. Returns:
  153. tuple of (dict_rename_var_new_old, dict_rename_var_old_new)
  154. dict_rename_var_new_old is a dict mapping from new name to old name
  155. dict_rename_var_old_new is a dict mapping from old name to new name
  156. """
  157. dict_rename_var_old_new = {}
  158. dict_rename_var_new_old = {}
  159. old_names = []
  160. # Store all old names
  161. for b_idx in range(program_desc.num_blocks()):
  162. cur_block = program_desc.block(b_idx)
  163. for var in cur_block.all_vars():
  164. old_names.append(var.name())
  165. # Create dict_rename_var_new_old and dict_rename_var_old_new for non double
  166. # grad variables
  167. has_double_grad = False
  168. for b_idx in range(program_desc.num_blocks()):
  169. cur_block = program_desc.block(b_idx)
  170. for var_idx, var in enumerate(cur_block.all_vars()):
  171. name_old = var.name()
  172. is_double_grad_var = "@GRAD" in name_old
  173. has_double_grad = has_double_grad or is_double_grad_var
  174. should_rename = (
  175. (include is None or name_old in include)
  176. and (exclude is None or name_old not in exclude)
  177. and not is_double_grad_var
  178. )
  179. if should_rename:
  180. temp_name = name_old.split('_')
  181. if len(temp_name) > 1 and temp_name[-1].isnumeric():
  182. temp_name = "_".join(temp_name[:-1])
  183. else:
  184. temp_name = name_old
  185. while True:
  186. name_new = _generate_unique_var_name_sync_with_main_program(
  187. temp_name
  188. )
  189. if (
  190. name_new
  191. not in old_names[:var_idx] + old_names[var_idx + 1 :]
  192. ):
  193. break
  194. else:
  195. name_new = name_old
  196. if name_old != name_new:
  197. cur_block._rename_var(name_old.encode(), name_new.encode())
  198. if not is_double_grad_var:
  199. dict_rename_var_old_new[name_old] = name_new
  200. dict_rename_var_new_old[name_new] = name_old
  201. # Handle double grad names
  202. if has_double_grad:
  203. double_grad_rename_dict = {}
  204. for name_old in dict_rename_var_old_new:
  205. for b_idx in range(program_desc.num_blocks()):
  206. cur_block = program_desc.block(b_idx)
  207. for var_idx, var in enumerate(cur_block.all_vars()):
  208. var_name = var.name()
  209. if "@GRAD" in var_name and name_old in var_name:
  210. new_var_name = var_name.replace(
  211. name_old, dict_rename_var_old_new[name_old]
  212. )
  213. double_grad_rename_dict[var_name] = new_var_name
  214. for var_name in double_grad_rename_dict:
  215. dict_rename_var_old_new[var_name] = double_grad_rename_dict[
  216. var_name
  217. ]
  218. dict_rename_var_new_old[
  219. double_grad_rename_dict[var_name]
  220. ] = var_name
  221. # Rename on program desc
  222. for b_idx in range(program_desc.num_blocks()):
  223. cur_block = program_desc.block(b_idx)
  224. for op_idx in range(cur_block.op_size()):
  225. op = cur_block.op(op_idx)
  226. for input_arg_name in op.input_arg_names():
  227. if input_arg_name in dict_rename_var_old_new:
  228. if (
  229. input_arg_name
  230. != dict_rename_var_old_new[input_arg_name]
  231. ):
  232. op._rename_input(
  233. input_arg_name,
  234. dict_rename_var_old_new[input_arg_name],
  235. )
  236. if cur_block.has_var(input_arg_name.encode()):
  237. cur_block._rename_var(
  238. input_arg_name.encode(),
  239. dict_rename_var_old_new[
  240. input_arg_name
  241. ].encode(),
  242. )
  243. for output_arg_name in op.output_arg_names():
  244. if output_arg_name in dict_rename_var_old_new:
  245. if (
  246. output_arg_name
  247. != dict_rename_var_old_new[output_arg_name]
  248. ):
  249. op._rename_output(
  250. output_arg_name,
  251. dict_rename_var_old_new[output_arg_name],
  252. )
  253. if cur_block.has_var(output_arg_name.encode()):
  254. cur_block._rename_var(
  255. output_arg_name.encode(),
  256. dict_rename_var_old_new[
  257. output_arg_name
  258. ].encode(),
  259. )
  260. program_desc.flush()
  261. return dict_rename_var_new_old, dict_rename_var_old_new
  262. @switch_to_static_graph
  263. def _build_program_by_desc(program_desc):
  264. prog = framework.Program()
  265. prog.desc = program_desc
  266. prog.blocks = [
  267. framework.Block(prog, i) for i in range(prog.desc.num_blocks())
  268. ]
  269. prog._sync_with_cpp()
  270. return prog
  271. def _change_is_test_status(program_desc, is_test):
  272. # change all `is_test` attributes
  273. for i in range(program_desc.num_blocks()):
  274. block = program_desc.block(i)
  275. for j in range(block.op_size()):
  276. op = block.op(j)
  277. if op.has_attr('is_test'):
  278. op._set_attr('is_test', is_test)
class _ProgramHolder:
    """
    Holds the execution information of a Program.

    _ProgramHolder is the execution unit of TranslatedLayer. If a
    TranslatedLayer contains multiple _ProgramHolder objects, it can execute
    multiple methods.

    On construction the given program desc is preprocessed in place (see
    ``_preprocess``): persistable vars are renamed, feed/fetch ops are
    stripped, and scale ops are appended to the outputs. The training, forward
    and backward programs are derived from it lazily.

    _ProgramHolder is an internal concept.
    """

    def __init__(self, program_desc):
        super().__init__()

        # input, output, persistable, double_grads var info
        # (all filled in by `_preprocess` / `_append_backward_desc`)
        self._input_descs = []
        self._output_descs = []
        self._persistable_names = []
        self._grad_var_names = {}

        # execution scope
        self._inner_scope = core.Scope()

        # mapping from renamed var name -> original var name, set in
        # `_preprocess`
        self._suffix_varname_dict = None
        # forward program (the preprocessed inference program desc)
        self._infer_program_desc = self._preprocess(program_desc)

    # forward:
    @switch_to_static_graph
    def _create_forward_train_program(self):
        """
        Build the forward part of ``train_program``.

        Ops [0, end_op_index) of block 0 are selected via
        ``add_build_strategy_for``, where end_op_index is the op count of the
        inference program. If that count is 0, the whole train program is
        returned.
        """
        whole_program = _build_program_by_desc(self.train_program)
        end_op_index = self._infer_program_desc.block(0).op_size()
        if end_op_index > 0:
            return add_build_strategy_for(whole_program, 0, end_op_index)
        else:
            return whole_program

    @LazyInitialized
    def _forward_program_desc(self):
        return self._create_forward_train_program().desc

    # backward
    @switch_to_static_graph
    def _create_backward_train_program(self):
        """
        Build the backward part of ``train_program``.

        Selects block-0 ops from (inference op count + number of outputs) to
        the end. The extra ``len(self._output_descs)`` skipped ops are
        presumably the ops that seed output gradients, one per output
        (TODO confirm). Returns an empty ``paddle.static.Program`` if there
        are no backward ops.
        """
        whole_program = _build_program_by_desc(self.train_program)
        start_op_index = self._infer_program_desc.block(0).op_size() + len(
            self._output_descs
        )
        end_op_index = whole_program.desc.block(0).op_size()
        if start_op_index < end_op_index:
            return add_build_strategy_for(
                whole_program, start_op_index, end_op_index
            )
        else:
            return paddle.static.Program()

    @LazyInitialized
    def _backward_program_desc(self):
        return self._create_backward_train_program().desc

    @property
    def infer_program(self):
        """The preprocessed inference ProgramDesc."""
        return self._infer_program_desc

    @LazyInitialized
    def train_program(self):
        """Inference program with backward ops appended; built on first access."""
        return self._append_backward_desc(self._infer_program_desc)

    @property
    def forward_program(self):
        return self._forward_program_desc

    @property
    def backward_program(self):
        return self._backward_program_desc

    @property
    def input_descs(self):
        return self._input_descs

    @property
    def output_descs(self):
        return self._output_descs

    @property
    def persistable_names(self):
        return self._persistable_names

    @property
    def scope(self):
        return self._inner_scope

    @property
    def grad_var_names(self):
        return self._grad_var_names

    def _preprocess(self, program_desc):
        """
        Prepare a loaded inference ProgramDesc for execution (in place).

        Renames persistable vars, removes feed/fetch/save-scale ops while
        recording input and output descs, appends scale ops to the outputs,
        and records persistable names plus the new->old name mapping.

        Returns:
            ProgramDesc: the same (mutated) ``program_desc``.
        """
        # rename persistable variables of 'program_desc'
        list_persistable_var = _get_persistable_var_names(program_desc)
        rename_new_old_dict, _ = _rename_var_program_desc(
            program_desc, list_persistable_var
        )
        # 1. Prune original program
        # remove feed, fetch and scale-1 op, remove op_callstack attr
        ops_to_remove = []
        root_block = program_desc.block(0)
        for i in range(root_block.op_size()):
            op = root_block.op(i)
            if op.type() == 'feed':
                ops_to_remove.append(i)
                feed_var_name = op.input('X')[0].encode()
                root_block._remove_var(feed_var_name)
                self._input_descs.append(
                    root_block.find_var(op.output('Out')[0].encode())
                )
            elif op.type() == 'scale' and op.output('Out')[0].startswith(
                'save_infer_model/scale_'
            ):
                ops_to_remove.append(i)
                out_var_name = op.output('Out')[0].encode()
                root_block._remove_var(out_var_name)
                self._output_descs.append(
                    root_block.find_var(op.input('X')[0].encode())
                )
            elif op.type() == 'fetch':
                ops_to_remove.append(i)
                fetch_var_name = op.output('Out')[0].encode()
                root_block._remove_var(fetch_var_name)
                # NOTE: some old pre-train models have no extra scale_op
                if not op.input('X')[0].startswith('save_infer_model/scale_'):
                    self._output_descs.append(
                        root_block.find_var(op.input('X')[0].encode())
                    )
            else:
                if op.has_attr("op_callstack"):
                    op.remove_attr("op_callstack")

        # Remove from the back so earlier indices stay valid.
        for op_idx in reversed(ops_to_remove):
            root_block._remove_op(op_idx, op_idx + 1)

        # 2. Input processing, reverse feed vars
        # (feed ops were visited in program order; the reversal presumably
        # restores the original input order — TODO confirm)
        self._input_descs.reverse()

        # 3. Output processing, add scale for outputs
        tmp_program = _build_program_by_desc(program_desc)
        # NOTE: [why need append scale for outputs]
        # When dealing with some more complex pre-trained models, the model
        # may have multiple fetch outputs. In that scenario, multiple outputs
        # of the model may lie on the same branch, while the user's later
        # code may attach them to different branches. Those later operations
        # are unknown to TranslatedLayer at initialization, so the gradient
        # accumulation required on an output node in the middle of a branch
        # would not be performed, resulting in errors. For details, see pull
        # request: [https://github.com/PaddlePaddle/Paddle/pull/24627]
        self._append_scale_to_output(tmp_program)

        # 4. Persistable vars processing
        # - record the new -> old names of persistable vars (they were
        #   renamed by `_rename_var_program_desc` at the top of this method)
        # NOTE: [why need to append suffix to persistable vars]
        # Dygraph and static graph mode use the same naming mechanism.
        # If users want to load the model and fine-tune it, they may add
        # new Layers to the loaded model to enhance the network. For example,
        # the original saved model has a linear layer, and after loading a
        # new linear layer is added. This would cause duplicate names, so the
        # parameters of the loaded model are given new, unique names.
        self._suffix_varname_dict = _get_loaded_var_new_old(
            program_desc, rename_new_old_dict
        )

        # - get persistable var
        self._persistable_names = _get_persistable_var_names(program_desc)

        return program_desc

    @switch_to_static_graph
    def _append_scale_to_output(self, program):
        """
        Append a ``scale(x, 1.0)`` op after every output of ``program``.

        Each output desc is then replaced by the scale result's desc. If any
        output has bool dtype, nothing is appended at all.
        """
        # 0. scale don't support bool output, we skip append scale for it
        for out_desc in self._output_descs:
            if out_desc.dtype() == paddle.bool:
                return

        # 1. append scale & save var
        scale_output_vars = []
        with framework.program_guard(program):
            for i, out in enumerate(self._output_descs):
                var = program.global_block().var(out.name())
                var = paddle.scale(var, 1.0, name=f"translated_layer/scale_{i}")
                scale_output_vars.append(var)
        # 2. update output names & descs
        for i, var in enumerate(scale_output_vars):
            self._output_descs[i] = var.desc

    @switch_to_static_graph
    def _get_train_forward_program(self, infer_program_desc):
        """
        Build a training-mode forward Program from a copy of the inference desc.

        Steps: set ``is_test`` to False, give batch_norm ops a ReserveSpace
        output if they lack one, and create missing intermediate outputs
        declared by op protos. The input desc itself is not modified.
        """
        program_desc_copy = core.ProgramDesc(infer_program_desc)

        # 1. set all `is_test` attributes to False
        _change_is_test_status(program_desc_copy, False)

        # 2. prepare program and related var
        # NOTE: To reuse backward interfaces, build Program firstly.
        # Originally, there is no need to build a program, but need to almost
        # rewrite a series of methods for append_backward for program_desc.
        # Therefore, in order to reuse the method of backward.py, build the program here.
        program = _build_program_by_desc(program_desc_copy)
        # 3. Add the outputs which is only used for training and not saved in
        # inference program.
        for block_idx in range(program.num_blocks):
            block = program.block(block_idx)
            for op in block.ops:
                if op.type == "batch_norm":
                    if (
                        "ReserveSpace" not in op.output_names
                        or len(op.output("ReserveSpace")) == 0
                    ):
                        reserve_space = block.create_var(
                            name=unique_name.generate_with_ignorable_key(
                                ".".join(["reserve_space", 'tmp'])
                            ),
                            dtype=block.var(op.input("X")[0]).dtype,
                            type=core.VarDesc.VarType.LOD_TENSOR,
                            persistable=False,
                            stop_gradient=True,
                        )
                        op.desc.set_output("ReserveSpace", [reserve_space.name])
                    # batch_norm skips the generic intermediate-output pass
                    continue

                # There are some situations that users will add backward op in Forward
                # function of Layer. And because backward op doesn't have proto. So, we
                # should skip it when we meet it.
                if not OpProtoHolder.instance().has_op_proto(op.type):
                    continue

                proto = OpProtoHolder.instance().get_op_proto(op.type)
                has_create_intermediate_out = False
                for output_proto in proto.outputs:
                    if output_proto.intermediate:
                        intermediate_name = output_proto.name
                        if intermediate_name not in op.output_names:
                            has_create_intermediate_out = True
                            intermediate_var = block.create_var(
                                name=unique_name.generate_with_ignorable_key(
                                    ".".join(
                                        [
                                            op.type + '_' + intermediate_name,
                                            'tmp',
                                        ]
                                    )
                                ),
                                type=core.VarDesc.VarType.LOD_TENSOR,
                                persistable=False,
                                stop_gradient=True,
                            )
                            op.desc.set_output(
                                intermediate_name, [intermediate_var.name]
                            )
                # Re-run type/shape inference only for ops that got new outputs.
                if has_create_intermediate_out:
                    op.desc.infer_var_type(block.desc)
                    op.desc.infer_shape(block.desc)

        return program

    @switch_to_static_graph
    def _append_backward_desc(self, infer_program_desc):
        """
        Build the training program: forward ops plus appended backward ops.

        Gradients are computed w.r.t. the output vars, and
        ``self._grad_var_names`` is filled via ``construct_grad_names`` for
        the input, persistable and output vars.

        Returns:
            ProgramDesc: desc of the training program.
        """
        program = self._get_train_forward_program(infer_program_desc)

        targets = []
        for out in self._output_descs:
            targets.append(program.global_block().var(out.name()))

        # 3. append backward
        check_type(
            targets,
            'targets',
            (framework.Variable, list, tuple),
            'paddle.static.gradients',
        )
        grad_info_map = backward.calc_gradient_helper(
            targets=targets, inputs=[]
        )
        x_vars = [
            program.block(0).var(desc.name()) for desc in self._input_descs
        ]
        param_vars = [
            program.block(0).var(name) for name in self._persistable_names
        ]
        out_vars = [
            program.block(0).var(desc.name()) for desc in self._output_descs
        ]

        self._grad_var_names = construct_grad_names(
            grad_info_map, x_vars, param_vars, out_vars
        )

        return program.desc
# [ TranslatedLayer : Run program in imperative mode ]
#
# DESIGN IDEA: use a special operator, `RunProgram`, to execute the program
# inside an operator.
#
# Op's Inputs:
#   - the input variables fed by the user
#   - the necessary parameters of the network
# Op's Outputs:
#   - the output variables to be fetched
#
# This op receives a complete program desc, internally creates a scope
# and an executor, and executes the program. Key points:
#
# 1. Data Sharing:
#   The variables/parameters of the dynamic graph are not in the scope, so
#   before the op executes the program internally, it creates persistable
#   variables in the scope with the same names as the feed, parameter, and
#   fetch variables, and shares the LoDTensors of the op inputs with them.
#
# 2. Forward and Backward Separation:
#   Because the dynamic graph op performs the forward and backward passes
#   separately, the forward op RunProgram executes only the forward part of
#   the whole program, and the backward op RunProgramGrad executes only the
#   backward part. The program is not split into two independent programs,
#   because that would break some control-flow execution logic.
  566. # NOTE: [compatible] deal with model saved by save_inference_model,
  567. # which need get var info from program desc
  568. def _load_persistable_vars_by_program(
  569. model_path, program_holder, params_filename=None
  570. ):
  571. # make sure the path has been checked
  572. persistable_vars = _get_persistable_vars(program_holder.infer_program)
  573. load_var_dict = {}
  574. for each_var in persistable_vars:
  575. orig_each_name = program_holder._suffix_varname_dict[each_var.name()]
  576. if _is_parameter(each_var, program_holder.infer_program):
  577. # create output param
  578. new_var = framework.EagerParamBase(
  579. shape=each_var.shape(),
  580. dtype=each_var.dtype(),
  581. name=each_var.name(),
  582. type=each_var.type(),
  583. persistable=True,
  584. )
  585. else:
  586. new_var = framework._create_tensor(
  587. type=each_var.type(),
  588. name=each_var.name(),
  589. shape=each_var.shape(),
  590. dtype=each_var.dtype(),
  591. persistable=True,
  592. )
  593. if params_filename is None:
  594. framework._dygraph_tracer().trace_op(
  595. type='load',
  596. inputs={},
  597. outputs={'Out': new_var},
  598. attrs={'file_path': os.path.join(model_path, orig_each_name)},
  599. )
  600. new_var.stop_gradient = False
  601. load_var_dict[each_var.name()] = new_var
  602. if params_filename is not None:
  603. load_var_list = []
  604. dict_name_old_new = {
  605. v: k for k, v in program_holder._suffix_varname_dict.items()
  606. }
  607. for name in sorted(dict_name_old_new.keys()):
  608. load_var_list.append(load_var_dict[dict_name_old_new[name]])
  609. framework._dygraph_tracer().trace_op(
  610. type='load_combine',
  611. inputs={},
  612. outputs={'Out': load_var_list},
  613. attrs={'file_path': os.path.join(model_path, params_filename)},
  614. )
  615. for each_var in persistable_vars:
  616. if not _is_parameter(each_var, program_holder.infer_program):
  617. continue
  618. param = load_var_dict[each_var.name()]
  619. param.stop_gradient = False
  620. # NOTE: [Recovery stop gradient information based on the program]
  621. # After loading the model, the stop_gradient information
  622. # of the original variable is lost, but if a parameter does not
  623. # have a corresponding @GRAD variable in the backward program,
  624. # it can be said that it is also stop_gradient
  625. all_var_names = _get_all_var_names(program_holder.train_program)
  626. for var_name in load_var_dict:
  627. grad_var_name = var_name + core.grad_var_suffix()
  628. if grad_var_name not in all_var_names:
  629. load_var_dict[var_name].stop_gradient = True
  630. return load_var_dict
  631. def _load_persistable_vars(
  632. model_path, var_info_path, program_holder, params_filename
  633. ):
  634. # 1. load extra var info
  635. with open(var_info_path, 'rb') as f:
  636. extra_var_info = pickle.load(f)
  637. # 2. construct var dict
  638. load_var_dict = {}
  639. load_var_list = []
  640. inv_suffix_varname_dict = {
  641. value: key for key, value in program_holder._suffix_varname_dict.items()
  642. }
  643. # NOTE(chenweihang): we need load persistable vars based the program,
  644. # because the program may be pruned when `save_inference_model`, some
  645. # var in `extra_var_info` may have been pruned
  646. for name in sorted(inv_suffix_varname_dict):
  647. if name not in extra_var_info:
  648. raise RuntimeError(
  649. "The model to be loaded is not complete."
  650. "The variable `%s` of program cannot be found in loaded model.",
  651. name,
  652. )
  653. # get suffix var name, see [why need to append suffix to persistable vars]
  654. new_name = inv_suffix_varname_dict[name]
  655. # create output var or param
  656. if extra_var_info[name].get('trainable', None) is not None:
  657. # use default shape and dtype
  658. new_var = framework.EagerParamBase(
  659. shape=[1], # only to pass check, this shape is not meaningful
  660. dtype=core.VarDesc.VarType.FP32,
  661. name=new_name,
  662. persistable=True,
  663. )
  664. else:
  665. new_var = framework._create_tensor(name=new_name, persistable=True)
  666. new_var.stop_gradient = extra_var_info[name]['stop_gradient']
  667. load_var_dict[new_name] = new_var
  668. load_var_list.append(new_var)
  669. # 3. load all vars
  670. assert params_filename is not None, "params_filename should not be None."
  671. var_file_path = os.path.join(model_path, params_filename)
  672. if not os.path.exists(var_file_path):
  673. if len(extra_var_info) != 0:
  674. raise ValueError("The model to be loaded is incomplete.")
  675. else:
  676. framework._dygraph_tracer().trace_op(
  677. type='load_combine',
  678. inputs={},
  679. outputs={'Out': load_var_list},
  680. attrs={'file_path': var_file_path},
  681. )
  682. return load_var_dict
  683. # NOTE(chenweihang): to adapt paddle.load to get state_dict
  684. def _remove_varname_suffix(var_dict, program_holder):
  685. no_suffix_var_dict = {}
  686. for var_name in var_dict:
  687. no_suffix_name = program_holder._suffix_varname_dict[var_name]
  688. no_suffix_var_dict[no_suffix_name] = var_dict[var_name]
  689. return no_suffix_var_dict
  690. def _construct_program_holders(model_path, model_filename=None):
  691. # make sure the path has been checked
  692. program_holder_dict = {}
  693. if model_filename is not None:
  694. # [compatible] if assign model_filename, only can load one program as Layer.forward
  695. model_filename = os.path.basename(model_filename)
  696. model_file_path = os.path.join(model_path, model_filename)
  697. model_name = model_filename[: -len(INFER_MODEL_SUFFIX)]
  698. # Load every file that meets the requirements in the directory model_path.
  699. for filename in os.listdir(model_path):
  700. if model_filename == filename:
  701. func_name = 'forward'
  702. model_file_path = os.path.join(model_path, model_filename)
  703. elif filename.endswith(INFER_MODEL_SUFFIX) and filename.startswith(
  704. model_name
  705. ):
  706. parsing_names = filename[
  707. len(model_name) : -len(INFER_MODEL_SUFFIX) + 1
  708. ].split('.')
  709. if len(parsing_names) == 3 and len(parsing_names[1]) > 0:
  710. func_name = parsing_names[1]
  711. model_file_path = os.path.join(model_path, filename)
  712. else:
  713. continue
  714. else:
  715. continue
  716. program_holder_dict[func_name] = _ProgramHolder(
  717. _load_program_desc(model_file_path)
  718. )
  719. else:
  720. for _, _, file_names in os.walk(model_path):
  721. for name in file_names:
  722. if 'model' in name:
  723. model_file_path = os.path.join(model_path, name)
  724. method_name = name.strip('_')
  725. if method_name == 'model':
  726. method_name = 'forward'
  727. else:
  728. method_name.replace('model', '')
  729. program_holder_dict[method_name] = _ProgramHolder(
  730. _load_program_desc(model_file_path)
  731. )
  732. return program_holder_dict
  733. def _construct_params_and_buffers(
  734. model_path, programs, params_filename=None, append_suffix=True
  735. ):
  736. var_info_filename = str(params_filename) + ".info"
  737. var_info_path = os.path.join(model_path, var_info_filename)
  738. params_path = os.path.join(model_path, str(params_filename))
  739. if os.path.exists(var_info_path):
  740. var_dict = _load_persistable_vars(
  741. model_path, var_info_path, programs['forward'], params_filename
  742. )
  743. model_name = params_filename[: -len(INFER_PARAMS_SUFFIX)]
  744. # Load every file that meets the requirements in the directory model_path.
  745. for file_name in os.listdir(model_path):
  746. if file_name.startswith(model_name) and file_name.endswith(
  747. INFER_PARAMS_SUFFIX
  748. ):
  749. parsing_names = file_name[
  750. len(model_name) : -len(INFER_PARAMS_SUFFIX) + 1
  751. ].split('.')
  752. if len(parsing_names) == 3 and len(parsing_names[1]) > 0:
  753. func_name = parsing_names[1]
  754. else:
  755. continue
  756. else:
  757. continue
  758. var_info_path = os.path.join(model_path, var_info_filename)
  759. var_dict.update(
  760. _load_persistable_vars(
  761. model_path, var_info_path, programs[func_name], file_name
  762. )
  763. )
  764. elif params_filename is not None and not os.path.exists(params_path):
  765. # When saving XX, there is only '*.pdmodel'
  766. return {}
  767. else:
  768. var_dict = _load_persistable_vars_by_program(
  769. model_path, programs['forward'], params_filename
  770. )
  771. if not append_suffix:
  772. var_dict = _remove_varname_suffix(var_dict, programs['forward'])
  773. return var_dict
  774. def _valid_vars(vars):
  775. return vars if vars else None
  776. def _run_dygraph(instance, input, program_holder):
  777. # 1. prepare inputs, outputs, attrs
  778. input_vars = []
  779. input_var_names = []
  780. for i, value in enumerate(input):
  781. if not isinstance(value, (np.ndarray, core.eager.Tensor)):
  782. raise TypeError(
  783. f"The type of input in TranslatedLayer must be numpy array or Variable(Tensor), but received {type(value)}."
  784. )
  785. # NOTE: In order to unify the API, firstly convert the input to Tensor
  786. if isinstance(value, np.ndarray):
  787. var = core.eager.Tensor(
  788. value=value,
  789. name=program_holder.input_descs[i].name(),
  790. persistable=False,
  791. place=framework._current_expected_place(),
  792. zero_copy=True,
  793. )
  794. else:
  795. var = value
  796. # NOTE: we changed var name here,
  797. # but it may be an important name set by user
  798. var.name = program_holder.input_descs[i].name()
  799. input_var_names.append(var.name)
  800. input_vars.append(var)
  801. if instance._input_args_names is None:
  802. instance._input_args_names = [
  803. ins.name() for ins in program_holder.input_descs
  804. ]
  805. persistable_vars = []
  806. for var_name in program_holder.persistable_names:
  807. dy_var_name = instance._persistable_var_name_dict[var_name]
  808. if dy_var_name in instance._parameters:
  809. persistable_vars.append(instance._parameters[dy_var_name])
  810. elif dy_var_name in instance._buffers:
  811. persistable_vars.append(instance._buffers[dy_var_name])
  812. else:
  813. raise ValueError(
  814. f"The persistable variable {var_name} does not exist in current TranslatedLayer."
  815. )
  816. output_vars = []
  817. for var_desc in program_holder.output_descs:
  818. var = core.eager.Tensor(
  819. dtype=var_desc.dtype(),
  820. dims=var_desc.shape(),
  821. name=var_desc.name(),
  822. type=var_desc.type(),
  823. persistable=False,
  824. )
  825. output_vars.append(var)
  826. # hold forward variables
  827. tmp_scope_vec = [program_holder.scope]
  828. # 2. run program by op
  829. trace_program = (
  830. program_holder.infer_program
  831. if instance._is_test
  832. else program_holder.train_program
  833. )
  834. forward_program = (
  835. program_holder._infer_program_desc
  836. if instance._is_test
  837. else program_holder.forward_program
  838. )
  839. end_op_index = program_holder.infer_program.block(0).op_size()
  840. attrs = [
  841. 'global_block',
  842. trace_program.block(0),
  843. 'start_op_index',
  844. 0,
  845. 'end_op_index',
  846. end_op_index,
  847. 'is_test',
  848. instance._is_test,
  849. 'program_id',
  850. paddle.utils._hash_with_id(trace_program, instance),
  851. 'x_names',
  852. input_var_names,
  853. ]
  854. if not instance._is_test:
  855. attrs.extend(
  856. (
  857. 'param_grad_names',
  858. program_holder.grad_var_names.get('param', []),
  859. 'out_grad_names',
  860. program_holder.grad_var_names.get('out', []),
  861. 'x_grad_names',
  862. program_holder.grad_var_names.get('x', []),
  863. )
  864. )
  865. use_interpretorcore = True
  866. attrs.extend(('use_interpretorcore', use_interpretorcore))
  867. if use_interpretorcore:
  868. attrs.extend(
  869. (
  870. 'forward_global_block',
  871. forward_program.block(0),
  872. )
  873. )
  874. if not instance._is_test:
  875. attrs.extend(
  876. (
  877. 'backward_global_block',
  878. program_holder.backward_program.block(0),
  879. )
  880. )
  881. _legacy_C_ops.run_program(
  882. _valid_vars(input_vars),
  883. _valid_vars(persistable_vars),
  884. _valid_vars(output_vars),
  885. tmp_scope_vec,
  886. None,
  887. *attrs,
  888. )
  889. # NOTE: [ why need set param's gradient type here ]
  890. # if user set sparse gradient mode, the param's gradient
  891. # will be SelectedRows, not LoDTensor. But tracer will just
  892. # set param grad Tensor by forward Tensor(LoDTensor)
  893. # If we don't change grad_var type here, RunProgramOp need
  894. # transform SelectedRows to LoDTensor forcibly, it may not
  895. # be user wanted result.
  896. for persistable_var in persistable_vars:
  897. grad_var_name = persistable_var.name + core.grad_var_suffix()
  898. grad_var = trace_program.block(0).find_var(grad_var_name.encode())
  899. # NOTE: cannot find var desc maybe not problem,
  900. # such as in batch_norm
  901. if grad_var is None:
  902. continue
  903. persistable_var._set_grad_type(grad_var.type())
  904. # 3. prepare output, keep same form with inputs
  905. outs = output_vars
  906. if len(output_vars) == 1:
  907. outs = output_vars[0]
  908. return outs
  909. def _run_static_graph(input, program_holder, trace_program):
  910. main_program = framework.default_main_program()
  911. param_var_names = _get_persistable_var_names(trace_program)
  912. _, dict_rename_var_old_new = _rename_var_program_desc(
  913. trace_program, exclude=param_var_names
  914. )
  915. trace_program.flush()
  916. # append blocks from 'trace_program'
  917. _append_block(
  918. main_program,
  919. trace_program,
  920. program_holder,
  921. input,
  922. dict_rename_var_old_new,
  923. )
  924. main_program._sync_with_cpp()
  925. outs = _get_output_from_program(
  926. main_program, program_holder, dict_rename_var_old_new
  927. )
  928. if len(outs) == 1:
  929. outs = outs[0]
  930. return outs
  931. def _collect_current_and_parent_var(program, block_idx):
  932. '''
  933. Get variables in current block and its parent block.
  934. Args:
  935. program(Program): The program containing the current block.
  936. block_idx(int): index of current block.
  937. Returns:
  938. List: list of variables.
  939. '''
  940. vars = []
  941. if block_idx < 0:
  942. return vars
  943. for var in program.block(block_idx).vars:
  944. vars.append(var)
  945. parent_idx = program.block(block_idx).parent_idx
  946. if parent_idx > -1:
  947. vars += _collect_current_and_parent_var(program, parent_idx)
  948. return vars
  949. def _append_block(
  950. dest_program,
  951. src_program_desc,
  952. program_holder,
  953. input_variables,
  954. dict_rename_var_old_new=None,
  955. ):
  956. '''
  957. Append Variables and Operators in 'src_program_desc' to dest_program.
  958. Args:
  959. dest_program(Program): Variables and Operators are appended to it.
  960. src_program_desc(ProgramDesc): Variables in it will be appended to 'dest_program'.
  961. program_holder(_ProgramHolder): program_holder of TranslatedLayer
  962. input_variables(list): list of input variables
  963. dict_rename_var_old_new(None|dict): When using '_rename_var_program_desc',
  964. use it to map the name of the variable before it was modified and the new name.
  965. '''
  966. origin_block_idx = dest_program.current_block_idx
  967. param_var_names = _collect_current_and_parent_var(
  968. dest_program, origin_block_idx
  969. )
  970. append_var_from_block_desc_static(
  971. dest_program.block(origin_block_idx),
  972. src_program_desc.block(0),
  973. exclude=param_var_names,
  974. )
  975. name_inp_desc = [inp.name() for inp in program_holder.input_descs]
  976. input_names = [inp.name for inp in input_variables]
  977. if len(name_inp_desc) != len(input_names):
  978. raise ValueError(
  979. f"The number of input is invalid, expected {len(name_inp_desc)}, but received {len(input_names)}."
  980. )
  981. for i, out_name in enumerate(name_inp_desc):
  982. if dict_rename_var_old_new:
  983. out_name = dict_rename_var_old_new[out_name]
  984. dest_program.block(origin_block_idx).append_op(
  985. type='assign',
  986. inputs={'X': [input_names[i]]},
  987. outputs={'Out': [out_name]},
  988. )
  989. append_ops = append_op_from_block_desc_static(
  990. dest_program.block(origin_block_idx), src_program_desc.block(0)
  991. )
  992. dest_program._sync_with_cpp()
  993. offset_block_idx = dest_program.num_blocks - 1
  994. parent_idx = 0
  995. if src_program_desc.num_blocks() > 1:
  996. for src_block_idx in range(1, src_program_desc.num_blocks()):
  997. src_block = src_program_desc.block(src_block_idx)
  998. src_parent_idx = src_block.parent
  999. if src_parent_idx > 0:
  1000. parent_idx = offset_block_idx + parent_idx
  1001. else:
  1002. parent_idx = origin_block_idx
  1003. dest_block = dest_program._create_block(parent_idx=parent_idx)
  1004. append_var_from_block_desc_static(
  1005. dest_block, src_block, exclude=param_var_names
  1006. )
  1007. append_ops += append_op_from_block_desc_static(
  1008. dest_block, src_block
  1009. )
  1010. dest_program._sync_with_cpp()
  1011. for op in append_ops:
  1012. if op.has_attr('sub_block'):
  1013. sub = op.attr('sub_block')
  1014. if isinstance(sub, framework.core.BlockDesc):
  1015. origin_id = sub.id
  1016. if isinstance(sub, framework.Block):
  1017. origin_id = sub.idx
  1018. op._set_attr(
  1019. 'sub_block', dest_program.block(offset_block_idx + origin_id)
  1020. )
  1021. dest_program._sync_with_cpp()
  1022. dest_program.current_block_idx = origin_block_idx
  1023. def _get_output_from_program(
  1024. program, program_holder, dict_rename_var_old_new=None
  1025. ):
  1026. """
  1027. Get output name of 'program' according to program_holder
  1028. """
  1029. outs = []
  1030. for var in program_holder.output_descs:
  1031. for idx in range(program.num_blocks):
  1032. vars = program.block(idx).vars
  1033. var_name = var.name()
  1034. if dict_rename_var_old_new:
  1035. var_name = dict_rename_var_old_new[var_name]
  1036. if var_name in vars:
  1037. out = vars[var_name]
  1038. if out not in outs:
  1039. outs.append(out)
  1040. return outs
  1041. def append_op_from_block_desc_static(block, src_block_desc):
  1042. """
  1043. Append Operators of 'src_block_desc' to current block.
  1044. Args:
  1045. block(Block): append OP of 'src_block_desc' to it.
  1046. src_block_desc(BlockDesc): append var of 'src_block_desc'
  1047. Returns:
  1048. List: list of the OP that are append to current block.
  1049. """
  1050. ops = []
  1051. for i in range(src_block_desc.op_size()):
  1052. ops.append(append_op_from_desc_static(block, src_block_desc.op(i)))
  1053. return ops
  1054. def append_op_from_desc_static(block, op_desc):
  1055. """
  1056. Append Operators to 'block' according to 'op_desc'.
  1057. Args:
  1058. block(Block): append OP of 'src_block_desc' to it.
  1059. op_desc(OpDesc): create OP according to it.
  1060. Returns:
  1061. Operator: OP appended to 'block'.
  1062. """
  1063. op_type = op_desc.type()
  1064. op_append = block.desc.append_op()
  1065. op_append.copy_from(op_desc)
  1066. op = framework.Operator(
  1067. block=block,
  1068. desc=op_append,
  1069. type=op_type,
  1070. inputs=None,
  1071. outputs=None,
  1072. attrs=None,
  1073. )
  1074. block.ops.append(op)
  1075. return op
  1076. def append_var_from_block_desc_static(
  1077. block, src_block_desc, include=None, exclude=None
  1078. ):
  1079. """
  1080. Append Variables of 'src_block_desc' to current block.
  1081. If 'include' is not `None`,variables that are not in include are not append.
  1082. If 'exclude' is not `None`,variables that are in exclude will are not append.
  1083. Args:
  1084. block(Block): append Variables of 'src_block_desc' to it.
  1085. src_block_desc(BlockDesc): append var of 'src_block_desc'
  1086. include(List):list of names of variables
  1087. exclude(List):list of names of variables
  1088. Returns:
  1089. List: list of the variables that are append to current block.
  1090. """
  1091. vars_append = []
  1092. for var_desc in src_block_desc.all_vars():
  1093. var_desc_name = var_desc.name()
  1094. should_append = (include is None or var_desc_name in include) and (
  1095. exclude is None or var_desc_name not in exclude
  1096. )
  1097. if not block.has_var(var_desc_name) and should_append:
  1098. var_type = var_desc.type()
  1099. if var_type in [
  1100. core.VarDesc.VarType.SELECTED_ROWS,
  1101. core.VarDesc.VarType.LOD_TENSOR,
  1102. core.VarDesc.VarType.LOD_TENSOR_ARRAY,
  1103. ]:
  1104. data_type = var_desc.dtype()
  1105. var_shape = var_desc.shape()
  1106. else:
  1107. data_type = None
  1108. var_shape = None
  1109. if var_type in [
  1110. core.VarDesc.VarType.LOD_TENSOR,
  1111. core.VarDesc.VarType.LOD_TENSOR_ARRAY,
  1112. ]:
  1113. lod_level = var_desc.lod_level()
  1114. else:
  1115. lod_level = None
  1116. if var_desc.persistable():
  1117. current_block = block.program.global_block()
  1118. else:
  1119. current_block = block
  1120. vars_append.append(
  1121. current_block.create_var(
  1122. name=var_desc.name(),
  1123. dtype=data_type,
  1124. type=var_type,
  1125. shape=var_shape,
  1126. lod_level=lod_level,
  1127. persistable=var_desc.persistable(),
  1128. set_need_check_feed=var_desc.need_check_feed(),
  1129. )
  1130. )
  1131. return vars_append
  1132. class TranslatedLayer(layers.Layer):
  1133. """
  1134. TranslatedLayer is a ``paddle.nn.Layer`` for holding the model
  1135. loaded by :ref:`api_paddle_jit_load` . It can be used like a
  1136. general Layer object in eval or train mode.
  1137. .. note:
  1138. The TranslatedLayer objects should not be created by constructor, it only can be loaded and constructed by :ref:`api_paddle_jit_load` .
  1139. Examples:
  1140. .. code-block:: python
  1141. >>> # doctest: +SKIP('`paddle.jit.to_static` can not run in xdoctest')
  1142. >>> import numpy as np
  1143. >>> import paddle
  1144. >>> import paddle.nn as nn
  1145. >>> import paddle.optimizer as opt
  1146. >>> BATCH_SIZE = 16
  1147. >>> BATCH_NUM = 4
  1148. >>> EPOCH_NUM = 4
  1149. >>> IMAGE_SIZE = 784
  1150. >>> CLASS_NUM = 10
  1151. >>> # define a random dataset
  1152. >>> class RandomDataset(paddle.io.Dataset):
  1153. ... def __init__(self, num_samples):
  1154. ... self.num_samples = num_samples
  1155. ...
  1156. ... def __getitem__(self, idx):
  1157. ... image = np.random.random([IMAGE_SIZE]).astype('float32')
  1158. ... label = np.random.randint(0, CLASS_NUM - 1, (1, )).astype('int64')
  1159. ... return image, label
  1160. ...
  1161. ... def __len__(self):
  1162. ... return self.num_samples
  1163. ...
  1164. >>> class LinearNet(nn.Layer):
  1165. ... def __init__(self):
  1166. ... super().__init__()
  1167. ... self._linear = nn.Linear(IMAGE_SIZE, CLASS_NUM)
  1168. ...
  1169. ... @paddle.jit.to_static
  1170. ... def forward(self, x):
  1171. ... return self._linear(x)
  1172. ...
  1173. >>> def train(layer, loader, loss_fn, opt):
  1174. ... for epoch_id in range(EPOCH_NUM):
  1175. ... for batch_id, (image, label) in enumerate(loader()):
  1176. ... out = layer(image)
  1177. ... loss = loss_fn(out, label)
  1178. ... loss.backward()
  1179. ... opt.step()
  1180. ... opt.clear_grad()
  1181. ... print("Epoch {} batch {}: loss = {}".format(
  1182. ... epoch_id, batch_id, np.mean(loss.numpy())))
  1183. ...
  1184. >>> # 1. train & save model.
  1185. >>> # create network
  1186. >>> layer = LinearNet()
  1187. >>> loss_fn = nn.CrossEntropyLoss()
  1188. >>> adam = opt.Adam(learning_rate=0.001, parameters=layer.parameters())
  1189. >>> # create data loader
  1190. >>> dataset = RandomDataset(BATCH_NUM * BATCH_SIZE)
  1191. >>> loader = paddle.io.DataLoader(dataset,
  1192. ... batch_size=BATCH_SIZE,
  1193. ... shuffle=True,
  1194. ... drop_last=True,
  1195. ... num_workers=2
  1196. ... )
  1197. >>> # train
  1198. >>> train(layer, loader, loss_fn, adam)
  1199. >>> # save
  1200. >>> model_path = "linear.example.model"
  1201. >>> paddle.jit.save(layer, model_path)
  1202. >>> # 2. load model as TranslatedLayer
  1203. >>> # load
  1204. >>> translated_layer = paddle.jit.load(model_path)
  1205. >>> # inference
  1206. >>> translated_layer.eval()
  1207. >>> x = paddle.randn([1, IMAGE_SIZE], 'float32')
  1208. >>> pred = translated_layer(x)
  1209. >>> # fine-tune
  1210. >>> translated_layer.train()
  1211. >>> adam = opt.Adam(learning_rate=0.001, parameters=translated_layer.parameters())
  1212. >>> train(translated_layer, loader, loss_fn, adam)
  1213. """
  1214. def __init__(self, programs, persistable_vars):
  1215. super().__init__()
  1216. if not isinstance(programs, dict):
  1217. raise TypeError(
  1218. "TranslatedLayer need to use _ProgramHolder's dict for initialization."
  1219. )
  1220. if not isinstance(persistable_vars, dict):
  1221. raise TypeError(
  1222. "TranslatedLayer need to use persistable variable dict for initialization."
  1223. )
  1224. self._program_holder_dict = programs
  1225. # NOTE(chenweihang): [ why not use var name directly? ]
  1226. # When add parameter or buffer to Layer by follow apis,
  1227. # the variable name can't contain `.`, because which may cause
  1228. # AttributeError when access the newly added parameter or buffer
  1229. # in the form of `self.**.**``, but the EagerParamBase or BarBase
  1230. # name contains `.` originally, such as `linear_0.w_0`, so here
  1231. # need to generate new var name for each var
  1232. self._persistable_var_name_dict = {}
  1233. # the TranslatedLayer object held var names count started from 0
  1234. with unique_name.guard():
  1235. for name, var in persistable_vars.items():
  1236. if isinstance(var, framework.EagerParamBase):
  1237. dy_name = _generate_unique_var_name(PARAMETER_NAME_PREFIX)
  1238. self._persistable_var_name_dict[name] = dy_name
  1239. self.add_parameter(dy_name, var)
  1240. elif isinstance(var, core.eager.Tensor):
  1241. dy_name = _generate_unique_var_name(BUFFER_NAME_PREFIX)
  1242. self._persistable_var_name_dict[name] = dy_name
  1243. self.register_buffer(dy_name, var)
  1244. else:
  1245. raise TypeError(
  1246. "Adding persistent variable which to layer is not supported now"
  1247. )
  1248. self._is_test = True
  1249. self._input_args_names = None
  1250. @staticmethod
  1251. @framework.dygraph_only
  1252. def _construct(model_path, configs=None):
  1253. # 0. dir and filename check
  1254. model_path = os.path.normpath(model_path)
  1255. if not os.path.isdir(model_path):
  1256. raise ValueError(f"There is no directory named '{model_path}'")
  1257. model_filename = None
  1258. params_filename = None
  1259. if configs is not None:
  1260. model_filename = configs.model_filename
  1261. params_filename = configs.params_filename
  1262. # 1. load program desc & construct _ProgramHolder
  1263. programs = _construct_program_holders(model_path, model_filename)
  1264. # 2. load layer parameters & buffers
  1265. persistable_vars = _construct_params_and_buffers(
  1266. model_path, programs, params_filename
  1267. )
  1268. # 3. construct TranslatedLayer object
  1269. translated_layer = TranslatedLayer(programs, persistable_vars)
  1270. # 4. create TranslatedLayer's execution method
  1271. for method_name, program_holder in programs.items():
  1272. if translated_layer._input_args_names is None:
  1273. translated_layer._input_args_names = [
  1274. ins.name() for ins in program_holder.input_descs
  1275. ]
  1276. setattr(
  1277. TranslatedLayer,
  1278. method_name,
  1279. TranslatedLayer._execution_method_creator(
  1280. method_name, program_holder
  1281. ),
  1282. )
  1283. # 5. set TranslatedLayer's default mode to eval
  1284. translated_layer.eval()
  1285. return translated_layer
  1286. @staticmethod
  1287. def _execution_method_creator(method_name, program_holder):
  1288. def __i_m_p_l__(self, *input):
  1289. program_holder = self._program_holder_dict[__i_m_p_l__.__name__]
  1290. # When using jit.save, it runs in static graph mode.
  1291. # Run in dynamic graph mode when the model is inferring.
  1292. if in_dynamic_mode():
  1293. return _run_dygraph(self, input, program_holder)
  1294. else:
  1295. # NOTE(weixin): [ why not use 'program_holder.infer_program' directly? ]
  1296. # When use '_run_static_graph(input, program_holder, program_holder.infer_program)',
  1297. # because '_run_static_graph' modifies 'ProgramDesc', 'OpDesc.op_size()' will return a very large wrong number.
  1298. # A Segmentation fault error may occur if used 'p=ProgramDesc(program_holder.infer_program)'.
  1299. p = framework.Program._construct_from_desc(
  1300. core.ProgramDesc(program_holder.infer_program)
  1301. )
  1302. return _run_static_graph(input, program_holder, p.desc)
  1303. __i_m_p_l__.__name__ = method_name
  1304. return __i_m_p_l__
  1305. def train(self):
  1306. self._is_test = False
  1307. self.training = True
  1308. def eval(self):
  1309. self._is_test = True
  1310. self.training = False
  1311. def program(self, method_name='forward'):
  1312. """
  1313. Gets translated program of specified method.
  1314. Args:
  1315. - method_name (string): method name corresponding to the program
  1316. to be obtained. Default: 'forward'.
  1317. Returns:
  1318. Program
  1319. Examples:
  1320. .. code-block:: python
  1321. >>> # doctest: +SKIP('`paddle.jit.to_static` can not run in xdoctest')
  1322. >>> import numpy as np
  1323. >>> import paddle
  1324. >>> from paddle import nn
  1325. >>> import paddle.optimizer as opt
  1326. >>> BATCH_SIZE = 16
  1327. >>> BATCH_NUM = 4
  1328. >>> EPOCH_NUM = 4
  1329. >>> IMAGE_SIZE = 784
  1330. >>> CLASS_NUM = 10
  1331. >>> # define a random dataset
  1332. >>> class RandomDataset(paddle.io.Dataset):
  1333. ... def __init__(self, num_samples):
  1334. ... self.num_samples = num_samples
  1335. ...
  1336. ... def __getitem__(self, idx):
  1337. ... image = np.random.random([IMAGE_SIZE]).astype('float32')
  1338. ... label = np.random.randint(0, CLASS_NUM - 1, (1, )).astype('int64')
  1339. ... return image, label
  1340. ...
  1341. ... def __len__(self):
  1342. ... return self.num_samples
  1343. ...
  1344. >>> class LinearNet(nn.Layer):
  1345. ... def __init__(self):
  1346. ... super().__init__()
  1347. ... self._linear = nn.Linear(IMAGE_SIZE, CLASS_NUM)
  1348. ...
  1349. ... @paddle.jit.to_static
  1350. ... def forward(self, x):
  1351. ... return self._linear(x)
  1352. ...
  1353. >>> def train(layer, loader, loss_fn, opt):
  1354. ... for epoch_id in range(EPOCH_NUM):
  1355. ... for batch_id, (image, label) in enumerate(loader()):
  1356. ... out = layer(image)
  1357. ... loss = loss_fn(out, label)
  1358. ... loss.backward()
  1359. ... opt.step()
  1360. ... opt.clear_grad()
  1361. ... print("Epoch {} batch {}: loss = {}".format(
  1362. ... epoch_id, batch_id, np.mean(loss.numpy())))
  1363. ...
  1364. >>> # create network
  1365. >>> layer = LinearNet()
  1366. >>> loss_fn = nn.CrossEntropyLoss()
  1367. >>> adam = opt.Adam(learning_rate=0.001, parameters=layer.parameters())
  1368. >>> # create data loader
  1369. >>> dataset = RandomDataset(BATCH_NUM * BATCH_SIZE)
  1370. >>> loader = paddle.io.DataLoader(dataset,
  1371. ... batch_size=BATCH_SIZE,
  1372. ... shuffle=True,
  1373. ... drop_last=True,
  1374. ... num_workers=2
  1375. ... )
  1376. >>> # train
  1377. >>> train(layer, loader, loss_fn, adam)
  1378. >>> # save
  1379. >>> model_path = "linear.example.model"
  1380. >>> paddle.jit.save(layer, model_path)
  1381. >>> # load
  1382. >>> translated_layer = paddle.jit.load(model_path)
  1383. >>> # get program
  1384. >>> program = translated_layer.program()
  1385. """
  1386. # 1. get program holder
  1387. program_holder = self._get_program_holder(method_name)
  1388. # 2. get inference program desc
  1389. program_desc = program_holder.infer_program
  1390. # 3. construct program
  1391. program = _build_program_by_desc(program_desc)
  1392. return program
  1393. def _get_program_holder(self, method_name='forward'):
  1394. program_holder = self._program_holder_dict.get(method_name, None)
  1395. if program_holder is None:
  1396. raise ValueError(
  1397. f"The method `{method_name}` does not exist in loaded TranslatedLayer."
  1398. )
  1399. return program_holder
  1400. def _input_spec(self, method_name='forward'):
  1401. # 1. get program holder
  1402. program_holder = self._get_program_holder(method_name)
  1403. # 2. build input spec by input desc
  1404. input_spec = []
  1405. for var_desc in program_holder.input_descs:
  1406. spec = paddle.static.InputSpec(
  1407. shape=var_desc.shape(),
  1408. dtype=var_desc.dtype(),
  1409. name=var_desc.name(),
  1410. )
  1411. input_spec.append(spec)
  1412. return input_spec
  1413. def _output_spec(self, method_name='forward'):
  1414. # 1. get program holder
  1415. program_holder = self._get_program_holder(method_name)
  1416. # 2. build output spec by output desc
  1417. output_spec = []
  1418. for var_desc in program_holder.output_descs:
  1419. # NOTE(chenweihang): InputSpec describes a tensor, not just input.
  1420. # Maybe the name is not good enough. Here we use InputSpec to
  1421. # construct the description of Output tensor
  1422. spec = paddle.static.InputSpec(
  1423. shape=var_desc.shape(),
  1424. dtype=var_desc.dtype(),
  1425. name=var_desc.name(),
  1426. )
  1427. output_spec.append(spec)
  1428. return output_spec