recompute.py 33 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693694695696697698699700701702703704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787
  1. # Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. import logging
  15. import paddle
  16. from paddle.base import core, framework, unique_name
  17. from paddle.base.backward import append_backward
  18. from paddle.base.framework import Variable, in_dygraph_mode, program_guard
  19. from paddle.optimizer import Optimizer
  20. class RecomputeOptimizer(Optimizer):
  21. """
  22. :api_attr: Static Graph
  23. Recompute Optimizer Wrapper
  24. Normally, a training step contains three sub-steps: first, run forward
  25. Operators to calculate the loss; second, run backward Operators to
  26. calculate gradient of the parameters; third, apply optimization method
  27. to update the value of the parameters.
  28. In the forward computation process, all variables that are needed by
  29. backward computation process will be kept in memory, which occupy a great
  30. amount of memory when the network becomes very deep.
  31. Recompute split the network to k segments. In each segment, It will
  32. recompute the forward Operators, before running backward operators. It is
  33. very helpful for saving memory.
  34. The Variables that separate a network to segments are called as checkpoints,
  35. and users should set it manually. The usage is very simple:
  36. Args:
  37. optimizer (Optimizer): The optimizer that is applied to parameters.
  38. Examples:
  39. .. code-block:: python
  40. >>> import paddle
  41. >>> import numpy as np
  42. >>> paddle.enable_static()
  43. >>> def gen_data():
  44. ... return {"x": np.random.random(size=(32, 32)).astype('float32'),
  45. ... "y": np.random.randint(2, size=(32, 1)).astype('int64')}
  46. >>> def mlp(input_x, input_y, hid_dim=128, label_dim=2):
  47. ... print(input_x)
  48. ... fc_1 = paddle.static.nn.fc(x=input_x, size=hid_dim)
  49. ... prediction = paddle.static.nn.fc(x=[fc_1], size=label_dim, activation='softmax')
  50. ... cost = paddle.nn.functional.cross_entropy(
  51. ... input=prediction, label=input_y,
  52. ... reduction='none', use_softmax=False
  53. ... )
  54. ... sum_cost = paddle.mean(cost)
  55. ... return sum_cost, fc_1, prediction
  56. >>> input_x = paddle.static.data(name="x", shape=[-1,32], dtype='float32')
  57. >>> input_y = paddle.static.data(name="y", shape=[-1,1], dtype='int64')
  58. >>> cost, fc_1, pred = mlp(input_x, input_y)
  59. >>> sgd = paddle.optimizer.Adam(learning_rate=0.01)
  60. >>> sgd = paddle.incubate.optimizer.RecomputeOptimizer(sgd)
  61. >>> sgd._set_checkpoints([fc_1, pred])
  62. >>> sgd.minimize(cost)
  63. >>> print("Finished optimize")
  64. Finished optimize
  65. >>> place = paddle.CPUPlace()
  66. >>> exe = paddle.static.Executor(place)
  67. >>> exe.run(paddle.static.default_startup_program())
  68. >>> step = 10
  69. >>> for i in range(step):
  70. ... cost_val = exe.run(feed=gen_data(),
  71. ... program=paddle.static.default_main_program(),
  72. ... fetch_list=[cost.name])
  73. ... print("step=%d cost=%f" % (i, cost_val[0]))
  74. var x : LOD_TENSOR.shape(-1, 32).dtype(float32).stop_gradient(True)
  75. Finished optimize
  76. step=0 cost=0.737203
  77. step=1 cost=1.308077
  78. step=2 cost=0.768422
  79. step=3 cost=1.239475
  80. step=4 cost=0.882643
  81. step=5 cost=0.738027
  82. step=6 cost=0.819374
  83. step=7 cost=0.818534
  84. step=8 cost=0.753692
  85. step=9 cost=0.787448
  86. """
  87. def __init__(self, optimizer):
  88. if in_dygraph_mode():
  89. raise Exception("In dygraph, don't support RecomputeOptimizer.")
  90. self._optimizer = optimizer
  91. self._checkpoints = None
  92. self._learning_rate = self._optimizer._learning_rate
  93. self._learning_rate_map = self._optimizer._learning_rate_map
  94. self.enable_offload = False
  95. def _set_checkpoints(self, checkpoints):
  96. """
  97. Args:
  98. checkpoints (list): List of Variable or string
  99. """
  100. assert isinstance(
  101. checkpoints, list
  102. ), "_checkpoints should be a list of Variable or a list of String"
  103. for ckpt in checkpoints:
  104. assert isinstance(
  105. ckpt, (Variable, str)
  106. ), "_checkpoints should be a list of Variable or a list of String"
  107. self._checkpoints = checkpoints
  108. # should enable offload before calling backward
  109. def _enable_offload(self):
  110. self.enable_offload = True
  111. @framework.deprecate_stat_dict
  112. def load(self, state_dict):
  113. """
  114. :api_attr: Static Graph
  115. load function is not supported by Recompute Optimizer for now.
  116. :return: None
  117. Args:
  118. state_dict: the dict load by load_persistable method
  119. Examples:
  120. .. code-block:: python
  121. >>> import paddle
  122. >>> paddle.enable_static()
  123. >>> def mlp(input_x, input_y, hid_dim=128, label_dim=2):
  124. ... fc_1 = paddle.static.nn.fc(x=input_x, size=hid_dim)
  125. ... prediction = paddle.static.nn.fc(x=[fc_1], size=label_dim, activation='softmax')
  126. ... cost = paddle.nn.functional.cross_entropy(
  127. ... input=prediction, label=input_y,
  128. ... reduction='none', use_softmax=False
  129. ... )
  130. ... sum_cost = paddle.mean(cost)
  131. ... return sum_cost, fc_1, prediction
  132. >>> input_x = paddle.static.data(name="x", shape=[-1,32], dtype='float32')
  133. >>> input_y = paddle.static.data(name="y", shape=[-1,1], dtype='int64')
  134. >>> cost, fc_1, pred = mlp(input_x, input_y)
  135. >>> print("Finished FF")
  136. Finished FF
  137. >>> sgd = paddle.optimizer.Adam(learning_rate=0.01)
  138. >>> sgd = paddle.incubate.optimizer.RecomputeOptimizer(sgd)
  139. >>> sgd._set_checkpoints([fc_1, pred])
  140. >>> try:
  141. ... state_dict = {}
  142. ... sgd.load(state_dict)
  143. >>> except NotImplementedError as e:
  144. ... print(e)
  145. load function is not supported by Recompute Optimizer for now
  146. """
  147. raise NotImplementedError(
  148. "load function is not supported by Recompute Optimizer for now"
  149. )
  150. def apply_gradients(self, params_grads):
  151. """
  152. call apply_gradients function of self._optimizer.
  153. Args:
  154. params_grads (list): list of (param, grad) pair to do optimization.
  155. Returns:
  156. list: A list of operators appended to the current program.
  157. Examples:
  158. .. code-block:: python
  159. >>> import paddle
  160. >>> import paddle.base.framework as framework
  161. >>> paddle.enable_static()
  162. >>> def mlp(input_x, input_y, hid_dim=128, label_dim=2):
  163. ... fc_1 = paddle.static.nn.fc(x=input_x, size=hid_dim)
  164. ... prediction = paddle.static.nn.fc(x=[fc_1], size=label_dim, activation='softmax')
  165. ... cost = paddle.nn.functional.cross_entropy(
  166. ... input=prediction, label=input_y,
  167. ... reduction='none', use_softmax=False
  168. ... )
  169. ... sum_cost = paddle.mean(cost)
  170. ... return sum_cost, fc_1, prediction
  171. >>> input_x = paddle.static.data(name="x", shape=[-1,32], dtype='float32')
  172. >>> input_y = paddle.static.data(name="y", shape=[-1,1], dtype='int64')
  173. >>> cost, fc_1, pred = mlp(input_x, input_y)
  174. >>> print("Finished FF")
  175. Finished FF
  176. >>> sgd = paddle.optimizer.Adam(learning_rate=0.01)
  177. >>> sgd = paddle.incubate.optimizer.RecomputeOptimizer(sgd)
  178. >>> sgd._set_checkpoints([fc_1, pred])
  179. >>> params_grads = sgd.backward(
  180. ... cost,
  181. ... startup_program=None,
  182. ... parameter_list=None,
  183. ... no_grad_set=None)
  184. >>> program = cost.block.program
  185. >>> with framework.program_guard(program, None):
  186. ... optimize_ops = sgd.apply_gradients(params_grads)
  187. >>> print("Finished apply gradients")
  188. Finished apply gradients
  189. """
  190. return self._optimizer.apply_gradients(params_grads=params_grads)
  191. def _create_vars(self, varname):
  192. pinned_var_name = unique_name.generate(varname + "@Pinned")
  193. fetched_var_name = unique_name.generate(varname + "@Fetch")
  194. pinned_var = self._main_program.global_block().create_var(
  195. name=pinned_var_name,
  196. shape=self.checkpoint_shape,
  197. dtype=self._main_program.global_block().var(varname).dtype,
  198. persistable=False,
  199. stop_gradient=True,
  200. )
  201. fetch_var = self._main_program.global_block().create_var(
  202. name=fetched_var_name,
  203. shape=self.checkpoint_shape,
  204. dtype=self._main_program.global_block().var(varname).dtype,
  205. persistable=False,
  206. stop_gradient=False,
  207. )
  208. return pinned_var_name, fetched_var_name
  209. def _append_fill_constant_ops(self, startup_program):
  210. """
  211. add fill_constant_ops to the end of the prog
  212. we should fill the pinned vars before running the main_prog
  213. to instantiate their tensor hold_, which could tell us whether
  214. the host memory could hold all the checkpoints from all the
  215. GPU devices in this node.
  216. """
  217. op_role = 0
  218. block = startup_program.global_block()
  219. fill_constant_vars = self.checkpoint_name2pinned_name.values()
  220. OP_ROLE_KEY = core.op_proto_and_checker_maker.kOpRoleAttrName()
  221. for varname in fill_constant_vars:
  222. var = self._main_program.global_block().var(varname)
  223. # NOTE (JZ-LIANG) to pre-allocate the CUDAPinned MEM
  224. pinned_var = block.create_var(
  225. name=varname,
  226. shape=self.checkpoint_shape,
  227. dtype=self._main_program.global_block().var(var.name).dtype,
  228. persistable=False,
  229. stop_gradient=True,
  230. )
  231. block.append_op(
  232. type='fill_constant',
  233. outputs={'Out': varname},
  234. attrs={
  235. "shape": var.shape,
  236. "dtype": var.dtype,
  237. "value": 0.0,
  238. "place_type": 2,
  239. OP_ROLE_KEY: op_role,
  240. },
  241. )
  242. def _insert_async_memcpy_op(
  243. self, insert_idx, src_varname, dst_varname, op_role, dst_place_type
  244. ):
  245. OP_ROLE_KEY = core.op_proto_and_checker_maker.kOpRoleAttrName()
  246. self.block._insert_op_without_sync(
  247. insert_idx,
  248. type='memcpy',
  249. inputs={'X': [self._main_program.global_block().var(src_varname)]},
  250. outputs={
  251. 'Out': [self._main_program.global_block().var(dst_varname)]
  252. },
  253. attrs={"dst_place_type": int(dst_place_type), OP_ROLE_KEY: op_role},
  254. )
  255. def _insert_fetch_op(self, idx, varname):
  256. assert (
  257. varname in self.checkpoint_name2pinned_name
  258. ), f"Try to fetch {varname} from Pinned Memory, but it is NOT a checkpoint"
  259. pinned_varname = self.checkpoint_name2pinned_name[varname]
  260. fetch_varname = self.checkpoint_name2fetch_name[varname]
  261. self._insert_async_memcpy_op(idx, pinned_varname, fetch_varname, 1, 1)
  262. def _insert_offload_op(self, idx, varname):
  263. assert (
  264. varname in self.checkpoint_name2pinned_name
  265. ), f"Try to offload {varname} to Pinned Memory, but it is NOT a checkpoint"
  266. pinned_varname = self.checkpoint_name2pinned_name[varname]
  267. self._insert_async_memcpy_op(idx, varname, pinned_varname, 0, 2)
  268. def _insert_sync_op(self, op_idx, checkpoint_name):
  269. # single stream offload no need sync
  270. pass
  271. def _record_fetch_op(self, idx):
  272. assert (
  273. len(self.un_fetch_checkpoint_names) > 0
  274. ), "Could NOT found checkpoint to fetch"
  275. checkpoint_name = self.un_fetch_checkpoint_names.pop(-1)
  276. logging.debug(f"Record fetch [{checkpoint_name}]")
  277. self.idx2insertions[idx] = ("fetch", checkpoint_name)
  278. return checkpoint_name
  279. def _record_offload_op(self, idx, checkpoint_name):
  280. expected_checkpoint_name = self.un_offload_checkpoint_names.pop(0)
  281. assert (
  282. checkpoint_name == expected_checkpoint_name
  283. ), f"expected to offload [{expected_checkpoint_name}] but got [{checkpoint_name}]"
  284. logging.debug(f"Record offload [{checkpoint_name}]")
  285. self.idx2insertions[idx] = ("offload", checkpoint_name)
  286. def _record_sync_op(self, idx, checkpoint_name):
  287. assert (
  288. checkpoint_name not in self.synced_checkpoints
  289. ), f"Try to sync the checkpoint [{checkpoint_name}] twice"
  290. self.synced_checkpoints.add(checkpoint_name)
  291. logging.debug(f"Record offload sync [{checkpoint_name}]")
  292. self.idx2insertions[idx] = ("sync", checkpoint_name)
  293. def _parse_backward(self):
  294. self.idx2insertions = {}
  295. # don't offload the last checkpoints, to favor throughput
  296. self.un_fetch_checkpoint_names = self.sorted_checkpoint_names[:]
  297. self.un_fetch_checkpoint_names.pop(-1)
  298. need_fetch_checkpoint_names = self.un_fetch_checkpoint_names[:]
  299. self.checkpoint_usage_count = {}
  300. for checkpoint_name in self.un_fetch_checkpoint_names:
  301. self.checkpoint_usage_count[checkpoint_name] = 0
  302. self.bw_start_op_idx = len(self.block.ops)
  303. for idx, op in enumerate(self.block.ops):
  304. if int(op.desc.attr("op_role")) == 1:
  305. self.bw_start_op_idx = idx
  306. break
  307. assert self.bw_start_op_idx < len(
  308. self.block.ops
  309. ), "Could NOT found backward op in prog"
  310. # fetch second to last checkpoint at the beginning of BW
  311. fetched_checkpoint_varname = self._record_fetch_op(self.bw_start_op_idx)
  312. last_last_fetch_checkpoint = None
  313. for i, op in enumerate(self.block.ops[self.bw_start_op_idx :]):
  314. idx = self.bw_start_op_idx + i
  315. input_vars = op.desc.input_arg_names()
  316. for input_var in input_vars:
  317. if input_var in need_fetch_checkpoint_names:
  318. if input_var not in self.un_fetch_checkpoint_names:
  319. # fetch the offload checkpoint when the first usage of its previous one
  320. if self.checkpoint_usage_count[input_var] == 0:
  321. # TODO (JZ-LIANG) sync memcpy_stream if extra stream for memcpy
  322. second_to_last_fetch_checkpoint = (
  323. fetched_checkpoint_varname
  324. )
  325. # there is NO fetch ahead the first checkpoint
  326. if input_var != self.sorted_checkpoint_names[0]:
  327. fetched_checkpoint_varname = (
  328. self._record_fetch_op(idx)
  329. )
  330. # should check the current used checkpoint is ths last fetch one
  331. assert (
  332. second_to_last_fetch_checkpoint == input_var
  333. ), f"Current recompute segment should use [{second_to_last_fetch_checkpoint}] BUT got [{input_var}]"
  334. # rename
  335. self.block.ops[idx]._rename_input(
  336. input_var,
  337. self.checkpoint_name2fetch_name[input_var],
  338. )
  339. self.checkpoint_usage_count[input_var] += 1
  340. else:
  341. raise ValueError(
  342. f"use checkpoint [{input_var}] before fetch in BW"
  343. )
  344. assert (
  345. len(self.un_fetch_checkpoint_names) == 0
  346. ), f"{self.un_fetch_checkpoint_names} checkpoints have NOT been Recorded"
  347. def _update_backward(self):
  348. if len(self.idx2insertions) == 0:
  349. return
  350. total_op = len(self.block.ops)
  351. for op_idx in reversed(range(self.bw_start_op_idx, total_op)):
  352. if op_idx in self.idx2insertions:
  353. operation, checkpoint_name = self.idx2insertions[op_idx]
  354. if operation == "fetch":
  355. self._insert_fetch_op(op_idx, checkpoint_name)
  356. logging.debug(f"Insert [{checkpoint_name}] fetch op.")
  357. del self.idx2insertions[op_idx]
  358. elif operation == "sync":
  359. self._insert_sync_op(op_idx, checkpoint_name)
  360. logging.debug(f"Sync [{checkpoint_name}] fetch op.")
  361. self.block._sync_with_cpp()
  362. assert (
  363. len(self.idx2insertions) == 0
  364. ), f"{[ele[1] for ele in self.idx2insertions.values()]} checkpoints left un-Fetched"
  365. def _parse_forward(self):
  366. self.idx2insertions = {}
  367. # don't offload the last checkpoints, faster, less memory saving
  368. self.un_offload_checkpoint_names = self.sorted_checkpoint_names[:]
  369. last_checkpoint = self.un_offload_checkpoint_names.pop(-1)
  370. need_offload_checkpoint_names = self.un_offload_checkpoint_names[:]
  371. self.checkpoint_usage_count_and_idx = {}
  372. for checkpoint_name in self.un_offload_checkpoint_names:
  373. self.checkpoint_usage_count_and_idx[checkpoint_name] = {
  374. 'count': 0,
  375. 'idx': -1,
  376. }
  377. self.synced_checkpoints = set()
  378. self.fw_start_op_idx = len(self.block.ops)
  379. for idx, op in enumerate(self.block.ops):
  380. if int(op.desc.attr("op_role")) == 0:
  381. self.fw_start_op_idx = idx
  382. break
  383. assert self.fw_start_op_idx < len(
  384. self.block.ops
  385. ), "Could NOT found Forward op in prog"
  386. last_offload_checkpoint = None
  387. for i, op in enumerate(
  388. self.block.ops[self.fw_start_op_idx : self.bw_start_op_idx]
  389. ):
  390. idx = self.fw_start_op_idx + i
  391. output_vars = op.desc.output_arg_names()
  392. input_vars = op.desc.input_arg_names()
  393. for output_var in output_vars:
  394. if output_var in need_offload_checkpoint_names:
  395. assert (
  396. len(output_vars) == 1
  397. ), f"checkpoint should be the only Output of a certain op, but [{output_var}] is from [{op}]"
  398. if output_var in self.un_offload_checkpoint_names:
  399. # insert sync op if last checkpoint has not been sync
  400. if last_offload_checkpoint is not None:
  401. if (
  402. self.checkpoint_usage_count_and_idx[
  403. last_offload_checkpoint
  404. ]['count']
  405. == 0
  406. ):
  407. self._record_sync_op(
  408. idx, last_offload_checkpoint
  409. )
  410. else:
  411. last_usage_idx = (
  412. self.checkpoint_usage_count_and_idx[
  413. last_offload_checkpoint
  414. ]['idx']
  415. )
  416. assert (
  417. last_usage_idx > 0
  418. ), f"last_usage_idx of checkpoint [{last_offload_checkpoint}] should large than 0"
  419. self._record_sync_op(
  420. last_usage_idx + 1, last_offload_checkpoint
  421. )
  422. # insert offload op after the checkpoint's generation op
  423. self._record_offload_op(idx + 1, output_var)
  424. last_offload_checkpoint = output_var
  425. else:
  426. raise ValueError(
  427. f"There should be just ONE op that output checkpoint [{output_var}]"
  428. )
  429. # need to sync the last need to offload checkpoint before the last checkpoint as output op
  430. if output_var == last_checkpoint:
  431. assert (
  432. len(output_vars) == 1
  433. ), f"checkpoint should be the only Output of a certain op, but [{output_var}] is from [{op}]"
  434. assert (
  435. last_offload_checkpoint
  436. == self.sorted_checkpoint_names[-2]
  437. ), f"the last offload checkpoint before [{last_checkpoint}] is suppose to be [{self.sorted_checkpoint_names[-2]}], but got [{last_offload_checkpoint}]"
  438. # sync if last checkpoint has not been sync
  439. if (
  440. self.checkpoint_usage_count_and_idx[
  441. last_offload_checkpoint
  442. ]['idx']
  443. == 0
  444. ):
  445. self._record_sync_op(idx, last_offload_checkpoint)
  446. else:
  447. last_usage_idx = self.checkpoint_usage_count_and_idx[
  448. last_offload_checkpoint
  449. ]['idx']
  450. assert (
  451. last_usage_idx > 0
  452. ), f"last_usage_idx of checkpoint [{last_offload_checkpoint}] should large than 0"
  453. self._record_sync_op(
  454. last_usage_idx + 1, last_offload_checkpoint
  455. )
  456. # record checkpoint usage
  457. for input_var in input_vars:
  458. if input_var in need_offload_checkpoint_names:
  459. assert (
  460. input_var not in self.synced_checkpoints
  461. ), f"checkpoint [{input_var}] used after sync"
  462. self.checkpoint_usage_count_and_idx[input_var]['count'] += 1
  463. self.checkpoint_usage_count_and_idx[input_var]['idx'] = idx
  464. assert (
  465. len(self.un_offload_checkpoint_names) == 0
  466. ), f"{self.un_fetch_checkpoint_names} checkpoints have NOT been Recorded"
  467. assert len(self.synced_checkpoints) == len(
  468. need_offload_checkpoint_names
  469. ), f"{set(need_offload_checkpoint_names) - set(self.synced_checkpoints)} checkpoints have NOT been Recorded"
  470. def _update_forward(self):
  471. if len(self.idx2insertions) == 0:
  472. return
  473. for op_idx in reversed(
  474. range(self.fw_start_op_idx, self.bw_start_op_idx)
  475. ):
  476. if op_idx in self.idx2insertions:
  477. operation, checkpoint_name = self.idx2insertions[op_idx]
  478. if operation == "offload":
  479. self._insert_offload_op(op_idx, checkpoint_name)
  480. logging.debug(f"Insert [{checkpoint_name}] offload op.")
  481. del self.idx2insertions[op_idx]
  482. elif operation == "sync":
  483. self._insert_sync_op(op_idx, checkpoint_name)
  484. logging.debug(
  485. f"Insert [{checkpoint_name}] offload_sync op."
  486. )
  487. del self.idx2insertions[op_idx]
  488. self.block._sync_with_cpp()
  489. assert (
  490. len(self.idx2insertions) == 0
  491. ), f"{[ele[1] for ele in self.idx2insertions.values()]} checkpoints left un-Offloaded"
  492. def _check_offload_fetch(self):
  493. # TODO(JZ-LIANG) the single stream offload need no sync
  494. pass
  495. def _offload(self, loss, startup_program=None):
  496. """
  497. core steps for recompute offload
  498. 1. create pinned vars and temp vars
  499. 2. parse & update Forward pass: offload, sync
  500. 3. parse & update Backward pass: rename, fetch, sync
  501. 4. verify the correctness
  502. """
  503. self._main_program = loss.block.program
  504. self.block = loss.block
  505. if startup_program is None:
  506. startup_program = paddle.static.default_startup_program()
  507. with program_guard(self._main_program, startup_program):
  508. assert (
  509. len(self.checkpoint_shape) > 0
  510. ), f"checkpoints shape {self.checkpoint_shape} should be an non empty list like: [12, 512, 1024]"
  511. assert all(
  512. ele > 0 for ele in self.checkpoint_shape
  513. ), f"all ele in checkpoints shape {self.checkpoint_shape} should be a determined integer larger than 0"
  514. self.checkpoint_name2pinned_name = {}
  515. self.checkpoint_name2fetch_name = {}
  516. for checkpoint_varname in self.sorted_checkpoint_names:
  517. pinned_var_name, fetch_var_name = self._create_vars(
  518. checkpoint_varname
  519. )
  520. self.checkpoint_name2pinned_name[
  521. checkpoint_varname
  522. ] = pinned_var_name
  523. self.checkpoint_name2fetch_name[
  524. checkpoint_varname
  525. ] = fetch_var_name
  526. self._append_fill_constant_ops(startup_program)
  527. # TODO (JZ-LIANG) to provide two offload strategy in future
  528. # step 2. parse & update FW: rename, offload, sync
  529. self._parse_backward()
  530. self._update_backward()
  531. # step 3. parse & update BW: rename, offload, sync
  532. self._parse_forward()
  533. self._update_forward()
  534. # step 4. verify the correctness
  535. self._check_offload_fetch()
  536. def backward(
  537. self,
  538. loss,
  539. startup_program=None,
  540. parameter_list=None,
  541. no_grad_set=None,
  542. callbacks=None,
  543. ):
  544. """
  545. call append_backward with checkpoints.
  546. Args:
  547. loss (Variable): loss variable to run optimizations.
  548. startup_program (Program): startup_program for initializing parameters
  549. in `parameter_list`.
  550. parameter_list (list): list of Variables or Variable.names to update.
  551. no_grad_set (set|None): set of Variables or Variables.names should be ignored.
  552. callbacks (list|None): list of callables to run when appending backward
  553. operator for one parameter.
  554. checkpoints (list): list of Variables as checkpoints
  555. Examples:
  556. .. code-block:: python
  557. >>> import paddle
  558. >>> paddle.enable_static()
  559. >>> def mlp(input_x, input_y, hid_dim=128, label_dim=2):
  560. ... fc_1 = paddle.static.nn.fc(x=input_x, size=hid_dim)
  561. ... prediction = paddle.static.nn.fc(x=[fc_1], size=label_dim, activation='softmax')
  562. ... cost = paddle.nn.functional.cross_entropy(
  563. ... input=prediction, label=input_y,
  564. ... reduction='none', use_softmax=False
  565. ... )
  566. ... sum_cost = paddle.mean(cost)
  567. ... return sum_cost, fc_1, prediction
  568. >>> input_x = paddle.static.data(name="x", shape=[-1,32], dtype='float32')
  569. >>> input_y = paddle.static.data(name="y", shape=[-1,1], dtype='int64')
  570. >>> cost, fc_1, pred = mlp(input_x, input_y)
  571. >>> print("Finished FF")
  572. Finished FF
  573. >>> sgd = paddle.optimizer.Adam(learning_rate=0.01)
  574. >>> sgd = paddle.incubate.optimizer.RecomputeOptimizer(sgd)
  575. >>> sgd._set_checkpoints([fc_1, pred])
  576. >>> params_grads = sgd.backward(
  577. ... cost,
  578. ... startup_program=None,
  579. ... parameter_list=None,
  580. ... no_grad_set=None)
  581. >>> print("Finished backward")
  582. Finished backward
  583. """
  584. assert (
  585. self._checkpoints is not None
  586. ), "You should call _set_checkpoints first"
  587. if in_dygraph_mode():
  588. raise NotImplementedError(
  589. "DyGraph current does not support recompute"
  590. )
  591. self._dtype = loss.dtype
  592. program = loss.block.program
  593. with program_guard(program, startup_program):
  594. checkpoint_vars = []
  595. for ckpt in self._checkpoints:
  596. if isinstance(ckpt, Variable):
  597. checkpoint_vars.append(ckpt)
  598. else:
  599. checkpoint_vars.append(loss.block.var(ckpt))
  600. # allow return to non-recompute when checkpoints is empty
  601. if len(checkpoint_vars) > 0:
  602. params_grads, sorted_checkpoint_names = append_backward(
  603. loss,
  604. parameter_list,
  605. no_grad_set,
  606. checkpoints=checkpoint_vars,
  607. )
  608. else:
  609. params_grads = append_backward(
  610. loss,
  611. parameter_list,
  612. no_grad_set,
  613. checkpoints=checkpoint_vars,
  614. )
  615. if self.enable_offload:
  616. self.sorted_checkpoint_names = sorted_checkpoint_names
  617. self._offload(loss, startup_program=startup_program)
  618. return params_grads
  619. def apply_optimize(self, loss, startup_program, params_grads):
  620. """
  621. call the apply_optimize function of self._optimizer
  622. Args:
  623. loss (Variable): loss variable to run optimizations.
  624. startup_program (Program): startup_program for initializing parameters
  625. in `parameter_list`.
  626. params_grads (list): list of (param, grad) pair to do optimization.
  627. Examples:
  628. .. code-block:: python
  629. >>> import paddle
  630. >>> paddle.enable_static()
  631. >>> def mlp(input_x, input_y, hid_dim=128, label_dim=2):
  632. ... fc_1 = paddle.static.nn.fc(x=input_x, size=hid_dim)
  633. ... prediction = paddle.static.nn.fc(x=[fc_1], size=label_dim, activation='softmax')
  634. ... cost = paddle.nn.functional.cross_entropy(
  635. ... input=prediction, label=input_y,
  636. ... reduction='none', use_softmax=False
  637. ... )
  638. ... sum_cost = paddle.mean(cost)
  639. ... return sum_cost, fc_1, prediction
  640. >>> input_x = paddle.static.data(name="x", shape=[-1,32], dtype='float32')
  641. >>> input_y = paddle.static.data(name="y", shape=[-1,1], dtype='int64')
  642. >>> cost, fc_1, pred = mlp(input_x, input_y)
  643. >>> print("Finished FF")
  644. Finished FF
  645. >>> sgd = paddle.optimizer.Adam(learning_rate=0.01)
  646. >>> sgd = paddle.incubate.optimizer.RecomputeOptimizer(sgd)
  647. >>> sgd._set_checkpoints([fc_1, pred])
  648. >>> params_grads = sgd.backward(
  649. ... cost,
  650. ... startup_program=None,
  651. ... parameter_list=None,
  652. ... no_grad_set=None)
  653. >>> optimize_ops = sgd.apply_optimize(
  654. ... cost, startup_program=None, params_grads=params_grads)
  655. >>> print("Finished apply_optimize")
  656. Finished apply_optimize
  657. """
  658. func = (
  659. self._optimizer.apply_optimize
  660. if hasattr(self._optimizer, 'apply_optimize')
  661. else self._optimizer._apply_optimize
  662. )
  663. return func(
  664. loss, startup_program=startup_program, params_grads=params_grads
  665. )
  666. def minimize(
  667. self, loss, startup_program=None, parameter_list=None, no_grad_set=None
  668. ):
  669. assert isinstance(loss, Variable), "The loss should be an Variable."
  670. assert (
  671. self._checkpoints is not None
  672. ), "You should call _set_checkpoints first"
  673. if in_dygraph_mode():
  674. raise NotImplementedError(
  675. "DyGraph current does not support recompute"
  676. )
  677. params_grads = self.backward(
  678. loss,
  679. startup_program=startup_program,
  680. parameter_list=parameter_list,
  681. no_grad_set=no_grad_set,
  682. )
  683. optimize_ops = self.apply_optimize(
  684. loss, startup_program=startup_program, params_grads=params_grads
  685. )
  686. return optimize_ops, params_grads