recompute.py 28 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794
  1. # Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. import math
  15. from typing import List, Sequence, Tuple
  16. import paddle
  17. from paddle import pir
  18. from paddle.autograd import backward_utils
  19. from paddle.base import core
# Size in bytes of a single element for each Paddle data type listed here.
# cal_value_node_size() multiplies a value's element count by this figure to
# estimate its memory footprint; a dtype missing from the table makes that
# lookup raise KeyError.
_PADDLE_DTYPE_2_NBYTES = {
    core.DataType.BOOL: 1,
    core.DataType.FLOAT16: 2,
    core.DataType.BFLOAT16: 2,
    core.DataType.FLOAT32: 4,
    core.DataType.FLOAT64: 8,
    core.DataType.INT8: 1,
    core.DataType.INT16: 2,
    core.DataType.INT32: 4,
    core.DataType.INT64: 8,
    core.DataType.UINT8: 1,
    core.DataType.COMPLEX64: 8,
    core.DataType.COMPLEX128: 16,
}
  34. # define the default recompute ops that can be fused between pairs
  35. DEFAULT_RECOMPUTABLE_OPS: List[str] = [
  36. "pd_op.full_int_array",
  37. "pd_op.full",
  38. "pd_op.sum",
  39. "pd_op.divide",
  40. "pd_op.subtract",
  41. "pd_op.add",
  42. "pd_op.multiply",
  43. "pd_op.elementwise_pow",
  44. "pd_op.rsqrt",
  45. "pd_op.reshape",
  46. "pd_op.full_like",
  47. "pd_op.assign",
  48. "pd_op.expand",
  49. "pd_op.scale",
  50. "pd_op.exp",
  51. "pd_op.equal",
  52. "pd_op.where",
  53. "pd_op.sin",
  54. "pd_op.cos",
  55. "pd_op.add_n",
  56. "pd_op.any",
  57. "pd_op.bitwise_and",
  58. "pd_op.cast",
  59. "pd_op.concat",
  60. "pd_op.full_with_tensor",
  61. "pd_op.gather_nd",
  62. "pd_op.greater_than",
  63. "pd_op.less_than",
  64. "pd_op.logical_and",
  65. "pd_op.logical_not",
  66. "pd_op.not_equal",
  67. "pd_op.pow",
  68. "pd_op.shape",
  69. "pd_op.slice",
  70. "pd_op.squeeze",
  71. "pd_op.unsqueeze",
  72. "pd_op.transpose",
  73. "pd_op.where",
  74. "pd_op.prod",
  75. "pd_op.log",
  76. "pd_op.log1p",
  77. "pd_op.logit",
  78. "pd_op.max",
  79. "pd_op.expand_as",
  80. "pd_op.split",
  81. "pd_op.arange",
  82. "pd_op.put_along_axis",
  83. "pd_op.tanh",
  84. "pd_op.atan",
  85. "pd_op.atanh",
  86. "pd_op.sinh",
  87. "pd_op.asin",
  88. "pd_op.asinh",
  89. "pd_op.cosh",
  90. "pd_op.acos",
  91. "pd_op.acosh",
  92. "pd_op.abs",
  93. "pd_op.sign",
  94. "pd_op.expm1",
  95. "pd_op.erf",
  96. "pd_op.erfinv",
  97. "pd_op.ceil",
  98. "pd_op.floor",
  99. "pd_op.frac",
  100. "pd_op.round",
  101. "pd_op.trunc",
  102. "pd_op.equal",
  103. "pd_op.angle",
  104. "pd_op.as_complex",
  105. "pd_op.as_real",
  106. "pd_op.complex",
  107. "pd_op.real",
  108. "pd_op.imag",
  109. "pd_op.conj",
  110. "pd_op.not_equal",
  111. "pd_op.greater_equal",
  112. "pd_op.greater_than",
  113. "pd_op.less_equal",
  114. "pd_op.less_than",
  115. "pd_op.bitwise_and",
  116. "pd_op.bitwise_not",
  117. "pd_op.bitwise_or",
  118. "pd_op.bitwise_xor",
  119. "pd_op.isinf",
  120. "pd_op.isnan",
  121. ]
# Op names treated as views. auto_recompute() appends them to the default
# recomputable ops and follows them when checking whether a value must be
# materialized in the backward pass. Currently empty.
VIEW_OPS: List[str] = []

# Ops producing random results. auto_recompute() counts them as fusible, and
# they are part of the "unrecomputable" list consulted when
# AGGRESSIVE_RECOMPUTATION is enabled.
RANDOM_OPS: List[str] = ["pd_op.randint", "pd_op.uniform", "pd_op.dropout"]

# Ops considered too expensive to recompute; together with RANDOM_OPS they form
# the "unrecomputable" list consulted when AGGRESSIVE_RECOMPUTATION is enabled.
# NOTE(review): "pd_op.batchnorm" may be a misspelling of the registered op
# name (possibly "pd_op.batch_norm") - confirm against the op registry.
COMPUTE_INTENSIVE_OPS: List[str] = [
    "pd_op.matmul",
    "pd_op.conv2d",
    "pd_op.layer_norm",
    "pd_op.batchnorm",
    "pd_op.softmax",
]

# When True, auto_recompute() bans recomputation only for ops in
# RANDOM_OPS + COMPUTE_INTENSIVE_OPS; when False, the stricter allow-list and
# distance/size heuristics apply.
AGGRESSIVE_RECOMPUTATION = False
# Restricts the amount of computation recompute can do.
# Values whose distance to the backward graph exceeds this are never recomputed.
MAX_DIST_FROM_BW = 3
  134. def auto_recompute(
  135. program: paddle.static.Program,
  136. inputs: Sequence[pir.Value],
  137. outputs: Sequence[pir.Value],
  138. grad_outputs: Sequence[pir.Value],
  139. fwd_op_end_idx: int,
  140. backward_op_start_idx: int,
  141. recomputable_ops: Sequence[str] = None,
  142. ) -> Tuple[paddle.static.Program, int]:
  143. '''
  144. Considering the compiler fuse strategy, we model the pir graph.
  145. Convert the pir calculation graph into a networkx calculation
  146. graph. Find the cut point through the min-cut algorithm,
  147. which is the value to be saved in pir forward calculation graph.
  148. Recompute the forward computation graph to replace intermediate
  149. variables in the forward graph held by the backward graph.
  150. .. warning::
  151. This API is experimental and likely to change.
  152. Args:
  153. program (Program): The program to be recomputed.
  154. inputs:(list[Value]|tuple(Value)): The input Values
  155. of the forward graph.
  156. outputs:(list[Value]|tuple(Value)): The out Values
  157. of the forward graph.
  158. grad_outputs:(list[Value]|tuple(Value)): initial gradient values
  159. of `outputs` .
  160. forward_op_end_idx(int): The index of the last forward op.
  161. backward_op_start_idx(int): The index of the start backward op.
  162. recomputable_ops(list[str]|tuple(str)|None): The op names that can
  163. be recomputed. If 'recompute_ops' is None, we will use the
  164. default recomputable_ops. Default None.
  165. Returns:
  166. recomputed_program(Program): The recomputed program.
  167. fwd_op_end_idx(int): The index of the last forward op in recomputed program.
  168. Examples:
  169. .. code-block:: python
  170. >>> import numpy as np
  171. >>> import paddle
  172. >>> from paddle.autograd.ir_backward import grad as ir_grad
  173. >>> from paddle.base import core
  174. >>> from paddle.decomposition import decompose
  175. >>> def forward(x):
  176. ... y = paddle.sin(x)
  177. ... z = paddle.cos(y)
  178. ... return z
  179. >>> np_x = np.random.random(size=[4096, 4096]).astype("float32")
  180. >>> paddle.enable_static()
  181. >>> core._set_prim_all_enabled(True)
  182. >>> main_program = paddle.static.Program()
  183. >>> with paddle.static.program_guard(main_program):
  184. >>> x = paddle.static.data(
  185. >>> name="x", shape=[4096, 4096], dtype="float32"
  186. >>> )
  187. >>> x.stop_gradient = False
  188. >>> out = forward(x)
  189. >>> out_grad = paddle.full(
  190. >>> shape=out.shape, fill_value=3, dtype="float32"
  191. >>> )
  192. >>> [out] = decompose(main_program, [out])
  193. >>> [dx] = ir_grad(out, [x], out_grad)
  194. >>> main_program, _ = paddle.decomposition.auto_recompute(
  195. >>> main_program,
  196. >>> [x],
  197. >>> [out],
  198. >>> grad_outputs=[out_grad],
  199. >>> fwd_op_end_idx=2,
  200. >>> backward_op_start_idx=4
  201. >>> )
  202. >>> exe = paddle.static.Executor(paddle.CUDAPlace(0))
  203. >>> res = exe.run(
  204. >>> feed={'x': np_x},
  205. >>> fetch_list=[dx],
  206. >>> )
  207. >>> print(main_program)
  208. {
  209. (%0) = "pd_op.data" () {dtype:(pd_op.DataType)float32,name:"x",place:(pd_op.Place)Place(undefined:0),shape:(pd_op.IntArray)[4096,4096],stop_gradient:[false]} : () -> pd_op.tensor<4096x4096xf32>
  210. (%1) = "pd_op.sin" (%0) {stop_gradient:[false]} : (pd_op.tensor<4096x4096xf32>) -> pd_op.tensor<4096x4096xf32>
  211. (%2) = "pd_op.cos" (%1) {stop_gradient:[false]} : (pd_op.tensor<4096x4096xf32>) -> pd_op.tensor<4096x4096xf32>
  212. (%3) = "pd_op.full" () {dtype:(pd_op.DataType)float32,place:(pd_op.Place)Place(undefined:0),shape:(pd_op.IntArray)[4096,4096],stop_gradient:[true],value:(Float)3} : () -> pd_op.tensor<4096x4096xf32>
  213. (%4) = "pd_op.sin" (%0) {stop_gradient:[false]} : (pd_op.tensor<4096x4096xf32>) -> pd_op.tensor<4096x4096xf32>
  214. (%5) = "pd_op.sin" (%4) {stop_gradient:[false]} : (pd_op.tensor<4096x4096xf32>) -> pd_op.tensor<4096x4096xf32>
  215. (%6) = "pd_op.full" () {dtype:(pd_op.DataType)float32,place:(pd_op.Place)Place(cpu),shape:(pd_op.IntArray)[1],stop_gradient:[true],value:(Float)-1} : () -> pd_op.tensor<1xf32>
  216. (%7) = "pd_op.scale" (%5, %6) {bias:(Float)0,bias_after_scale:true,stop_gradient:[false]} : (pd_op.tensor<4096x4096xf32>, pd_op.tensor<1xf32>) -> pd_op.tensor<4096x4096xf32>
  217. (%8) = "pd_op.multiply" (%7, %3) {stop_gradient:[false]} : (pd_op.tensor<4096x4096xf32>, pd_op.tensor<4096x4096xf32>) -> pd_op.tensor<4096x4096xf32>
  218. (%9) = "pd_op.cos" (%0) {stop_gradient:[false]} : (pd_op.tensor<4096x4096xf32>) -> pd_op.tensor<4096x4096xf32>
  219. (%10) = "pd_op.multiply" (%9, %8) {stop_gradient:[false]} : (pd_op.tensor<4096x4096xf32>, pd_op.tensor<4096x4096xf32>) -> pd_op.tensor<4096x4096xf32>
  220. (%11) = "pd_op.fetch" (%10) {col:(Int32)0,is_persistable:[true],name:"fetch0",stop_gradient:[false]} : (pd_op.tensor<4096x4096xf32>) -> pd_op.tensor<4096x4096xf32>
  221. }
  222. '''
  223. # 1. find smart recompute needed saved values by min-cut algorithm
  224. # 1.1 classify value nodes
  225. import networkx as nx
  226. # model value as graph's node, op as graph's edge
  227. (
  228. required_fw_value_nodes,
  229. required_bw_value_nodes,
  230. unclaimed_value_nodes,
  231. ) = classify_value_node(program, grad_outputs, fwd_op_end_idx)
  232. if len(required_bw_value_nodes) == 0:
  233. return program, fwd_op_end_idx
  234. all_ops = program.global_block().ops
  235. # 1.2 cal value nodes dist to backward
  236. dist_from_bw = cal_value_nodes_dist_to_backward(
  237. all_ops, required_fw_value_nodes
  238. )
  239. # 1.3 classify ops
  240. default_recomputable_ops = DEFAULT_RECOMPUTABLE_OPS
  241. view_ops = VIEW_OPS
  242. default_recomputable_ops += view_ops
  243. recomputable_ops = (
  244. set(recomputable_ops)
  245. if recomputable_ops is not None
  246. else set(default_recomputable_ops)
  247. )
  248. random_ops = RANDOM_OPS
  249. compute_intensive_ops = COMPUTE_INTENSIVE_OPS
  250. unrecomputable_ops = random_ops + compute_intensive_ops
  251. fusible_ops = recomputable_ops | set(random_ops)
  252. def _is_fusible(value_node1, value_node2):
  253. return (
  254. value_node1.get_defining_op().name() in fusible_ops
  255. and value_node2.get_defining_op().name() in fusible_ops
  256. )
  257. def _is_materialized_backwards(value_node):
  258. cur_value_nodes = backward_utils.ValueSet()
  259. cur_value_nodes.add(value_node)
  260. while len(cur_value_nodes) > 0:
  261. cur_value_node = cur_value_nodes.pop()
  262. users = find_value_node_users(cur_value_node)
  263. for user in users:
  264. if user not in required_fw_value_nodes and not _is_fusible(
  265. cur_value_node, user
  266. ):
  267. return True
  268. if (
  269. user not in required_fw_value_nodes
  270. and get_real_define_op_name(user) in view_ops
  271. ):
  272. cur_value_nodes.add(user)
  273. return False
  274. def _is_materialized(value_node, placeholder_value_nodes):
  275. if value_node in placeholder_value_nodes:
  276. return True
  277. users = find_value_node_users(value_node)
  278. return not all(_is_fusible(value_node, user) for user in users)
  279. def _get_node_weight(value_node, placeholder_value_nodes):
  280. mem_sz = cal_value_node_size(value_node)
  281. # Heuristic to bias towards nodes closer to the backwards pass
  282. mem_sz = int(
  283. mem_sz * (1.1 ** max(min(dist_from_bw[value_node], 100), 1))
  284. )
  285. if _is_materialized(value_node, placeholder_value_nodes):
  286. return mem_sz
  287. else:
  288. return mem_sz * 2
  289. def _ban_recomputation(value_node):
  290. if AGGRESSIVE_RECOMPUTATION:
  291. return value_node.get_defining_op().name() in unrecomputable_ops
  292. else:
  293. if value_node.get_defining_op().name() not in recomputable_ops:
  294. return True
  295. # If a node *must* be materialized in the backwards pass, then we
  296. # should never recompute it. This is a pretty subtle point. In
  297. # general, the assumption we make is that recomputing a node in the
  298. # backwards pass is "free". However, if a node must be materialized
  299. # in the backwards pass, then recomputing it is never free.
  300. if _is_materialized_backwards(value_node):
  301. return True
  302. if dist_from_bw[value_node] > MAX_DIST_FROM_BW:
  303. return True
  304. # If the output of an op is 4x smaller (arbitrary choice),
  305. # then we don't allow recomputation.
  306. output_size = cal_value_node_size(value_node)
  307. inputs = get_real_input_nodes(value_node)
  308. inputs_size = sum(cal_value_node_size(i) for i in inputs)
  309. return output_size * 4 < inputs_size
  310. # 1.4 Model pir graph. Convert the pir calculation graph into a networkx calculation graph.
  311. outputs = backward_utils.ValueSet(outputs)
  312. inputs = backward_utils.ValueSet(inputs)
  313. value_id_dict = {}
  314. nx_graph = nx.DiGraph()
  315. for value_node in (
  316. required_fw_value_nodes
  317. | required_bw_value_nodes
  318. | unclaimed_value_nodes
  319. ):
  320. if value_node in outputs or not value_node.initialized():
  321. continue
  322. if value_node.get_defining_op().name() == "builtin.combine":
  323. continue
  324. if (
  325. len(value_node.all_used_ops()) == 1
  326. and value_node.all_used_ops()[0].name() == "builtin.split"
  327. ):
  328. continue
  329. if value_node in required_bw_value_nodes:
  330. nx_graph.add_edge(value_node.id + "_in", "sink", capacity=math.inf)
  331. value_id_dict[value_node.id] = value_node
  332. continue
  333. if value_node in inputs:
  334. nx_graph.add_edge(
  335. "source", value_node.id + "_in", capacity=math.inf
  336. )
  337. value_id_dict[value_node.id] = value_node
  338. # If a node can't be recomputed (too expensive or involves randomness),
  339. # we prevent it from being recomputed by adding an inf edge to the source
  340. # We only need to ban nodes in the fw pass, as those are the only ones that would be recomputed.
  341. if (
  342. _ban_recomputation(value_node)
  343. and value_node in required_fw_value_nodes
  344. ):
  345. nx_graph.add_edge(
  346. "source", value_node.id + "_in", capacity=math.inf
  347. )
  348. value_id_dict[value_node.id] = value_node
  349. # todo(wanghao107) hack for dynamic shape
  350. if is_dynamic_value_node(value_node):
  351. weight = 1
  352. else:
  353. weight = _get_node_weight(
  354. value_node, placeholder_value_nodes=inputs | outputs
  355. )
  356. # Creates the weights on the "node" edge
  357. nx_graph.add_edge(
  358. value_node.id + "_in", value_node.id + "_out", capacity=weight
  359. )
  360. value_id_dict[value_node.id] = value_node
  361. users = find_value_node_users(value_node)
  362. for user in users:
  363. nx_graph.add_edge(
  364. value_node.id + "_out", user.id + "_in", capacity=math.inf
  365. )
  366. # 1.5 find saved values by minimum cut.
  367. _, partition = nx.minimum_cut(nx_graph, "source", "sink")
  368. reachable, non_reachable = partition
  369. cutset = set()
  370. for u, nbrs in ((n, nx_graph[n]) for n in reachable):
  371. cutset.update((u, v) for v in nbrs if v in non_reachable)
  372. cut_value_nodes = backward_utils.ValueSet()
  373. for value_node_in, value_node_out in cutset:
  374. assert value_node_in[:-3] == value_node_out[:-4]
  375. value_node = value_id_dict[value_node_in[:-3]]
  376. cut_value_nodes.add(value_node)
  377. saved_values = cut_value_nodes
  378. # (TODO: wanghao107): remove it and fix model
  379. saved_values = cut_value_nodes | inputs
  380. # 2.patition the joint graph by saved values.
  381. (
  382. program_after_recompute,
  383. fwd_op_end_idx_after_recompute,
  384. ) = partition_joint_graph(
  385. program,
  386. saved_values,
  387. inputs,
  388. outputs,
  389. fwd_op_end_idx,
  390. backward_op_start_idx,
  391. )
  392. return program_after_recompute, fwd_op_end_idx_after_recompute
  393. def partition_joint_graph(
  394. program: paddle.static.Program,
  395. saved_values: List[pir.Value],
  396. inputs: List[pir.Value],
  397. outputs: List[pir.Value],
  398. fwd_op_end_idx: int,
  399. backward_op_start_idx: int,
  400. ) -> Tuple[paddle.static.Program, int]:
  401. """
  402. Partition the joint graph, recompute the intermediate values
  403. by saved values to save memory.
  404. Args:
  405. program(Program): The program to be recomputed.
  406. saved_values(list[valueiable]): The saved values
  407. of forward graph which used by backward graph.
  408. inputs:(list[Value]|tuple(Value)): The input Values
  409. of the forward graph.
  410. outputs(list[valueiable]): The out values
  411. of the forward graph.
  412. forward_op_end_idx(int): The index of the last forward op.
  413. backward_op_start_idx(int): The index of the start backward op.
  414. Returns:
  415. recomputed_program(Program): The recomputed program.
  416. fwd_op_end_idx(int): The index of the last forward op in
  417. recomputed program.
  418. """
  419. saved_values = backward_utils.ValueSet(saved_values)
  420. outputs = backward_utils.ValueSet(outputs)
  421. # 1. Analyze the program, get all forward porgram mid hold values
  422. mid_hold_values = analyze_mid_hold_values(
  423. program,
  424. saved_values,
  425. inputs,
  426. outputs,
  427. fwd_op_end_idx,
  428. backward_op_start_idx,
  429. )
  430. # 2. Extract the recompute subgraph and replace forward mid hold values with recompute subgraph's outputs
  431. program, fwd_op_end_idx = replace_mid_values_with_forward_subgraph(
  432. program,
  433. saved_values,
  434. mid_hold_values,
  435. fwd_op_end_idx,
  436. backward_op_start_idx,
  437. )
  438. return program, fwd_op_end_idx
def replace_mid_values_with_forward_subgraph(
    program, saved_values, mid_values, fwd_op_end_idx, backward_op_start_idx
):
    """
    Clone the forward ops needed to rebuild ``mid_values`` and make backward
    ops consume the clones.

    Args:
        program: The program being rewritten; its global block is modified.
        saved_values: Values at which the search for producing ops stops.
        mid_values: Forward values that should be recomputed.
        fwd_op_end_idx: Index of the last forward op; returned unchanged.
        backward_op_start_idx: Index of the first backward op.

    Returns:
        The tuple ``(program, fwd_op_end_idx)``.

    Raises:
        Exception: If the search reaches an op without operands that is not
            ``pd_op.full`` or ``pd_op.full_int_array``.
    """

    def _extract_forward_recompute_subgraph_for_backward(
        saved_values, mid_values
    ):
        # Collects the ops producing mid_values and the saved values they read.
        def _find_recompute_ops(
            recompute_value,
            saved_values,
            marked_recompute_ops,
            needed_saved_values,
        ):
            # Depth-first walk from recompute_value towards its producers.
            # NOTE(review): recursion depth follows the length of the producer
            # chain; very deep graphs could hit Python's recursion limit.
            define_op = recompute_value.get_defining_op()
            if define_op in marked_recompute_ops:
                return
            op_inputs = define_op.operands_source()
            # An op without operands is only acceptable as a constant creator.
            if len(op_inputs) == 0 and define_op.name() not in [
                "pd_op.full",
                "pd_op.full_int_array",
            ]:
                raise Exception(
                    f"Every path to recompute value {recompute_value} must have saved value or starting point of the path is one of op in [pd_op.full, pd_op.full_int_array], but find {define_op.name()} op"
                )
            for op_input in op_inputs:
                # Saved values end the walk and are recorded as subgraph inputs.
                if op_input in saved_values:
                    if op_input not in needed_saved_values:
                        needed_saved_values.add(op_input)
                    continue
                _find_recompute_ops(
                    op_input,
                    saved_values,
                    marked_recompute_ops,
                    needed_saved_values,
                )
            # Marked only after all producers have been visited.
            marked_recompute_ops.add(define_op)
            return

        # {inputs:[...], ops: [...], needed_outputs: [...]}
        recompute_subgraph_ops = set()
        recompute_subgraph_inputs = backward_utils.ValueSet()
        recompute_subgraph_outputs_backward_needed = mid_values
        for recompute_value in mid_values:
            _find_recompute_ops(
                recompute_value,
                saved_values,
                recompute_subgraph_ops,
                recompute_subgraph_inputs,
            )
        recompute_subgraph = {
            "inputs": recompute_subgraph_inputs,
            "recompute_ops": recompute_subgraph_ops,
            "outputs": recompute_subgraph_outputs_backward_needed,
        }
        return recompute_subgraph

    # NOTE(review): forward_ops is not read anywhere below.
    forward_ops = set(program.global_block().ops[: fwd_op_end_idx + 1])
    backward_ops = set(program.global_block().ops[backward_op_start_idx:])
    first_backward_op = program.global_block().ops[backward_op_start_idx]

    # 1. find forward subgraph to recompute mid values that backward need to hold.
    recompute_forward_subgraph = (
        _extract_forward_recompute_subgraph_for_backward(
            saved_values, mid_values
        )
    )

    # 2. clone subgraph which need to be recomputed
    origin_ops = recompute_forward_subgraph["recompute_ops"]
    origin_subgraph_inputs = recompute_forward_subgraph["inputs"]
    origin_subgraph_outputs = recompute_forward_subgraph["outputs"]
    # clone_graph uses first_backward_op as the insertion point for the clones.
    cloned_ops, value_map = clone_graph(
        program, origin_ops, origin_subgraph_inputs, first_backward_op
    )

    # 3. replace mid values that backward need to hold with recompute subgraph's outputs
    cloned_subgraph_outputs = backward_utils.ValueSet()
    for origin_value in origin_subgraph_outputs:
        cloned_value = value_map.look_up(origin_value)
        # Presumably redirects uses of origin_value inside backward_ops to
        # cloned_value - confirm against pir.Value.replace_grad_users_with.
        origin_value.replace_grad_users_with(cloned_value, backward_ops)
        cloned_subgraph_outputs.add(cloned_value)

    # 4. reset recomputed ops location in program
    reseted_ops = set()
    # The op list is read again here, i.e. after the clones were created.
    backward_ops_list = program.global_block().ops[backward_op_start_idx:]
    for op in backward_ops_list:
        op_inputs = op.operands_source()
        for op_input in op_inputs:
            if op_input in cloned_subgraph_outputs:
                parent_ops = find_parent_ops(op_input)
                # Iterating cloned_ops keeps their creation order; each clone
                # is moved at most once, in front of the first op found to
                # depend on it.
                for cloned_op in cloned_ops:
                    if cloned_op in parent_ops and cloned_op not in reseted_ops:
                        cloned_op.move_before(op)
                        reseted_ops.add(cloned_op)
    return program, fwd_op_end_idx
  527. def classify_value_node(program, grad_outputs, fwd_op_end_idx):
  528. all_ops = program.global_block().ops
  529. required_fw_value_nodes = backward_utils.ValueSet()
  530. required_fw_ops = set(all_ops[: fwd_op_end_idx + 1])
  531. for required_fw_op in required_fw_ops:
  532. fw_op_outputs = required_fw_op.results()
  533. required_fw_value_nodes = (
  534. required_fw_value_nodes | backward_utils.ValueSet(fw_op_outputs)
  535. )
  536. required_bw_value_nodes = backward_utils.ValueSet()
  537. required_bw_ops = set()
  538. for grad_output in grad_outputs:
  539. required_bw_ops = required_bw_ops | find_child_ops(grad_output)
  540. required_bw_ops.add(grad_output.get_defining_op())
  541. for required_bw_op in required_bw_ops:
  542. bw_op_outputs = required_bw_op.results()
  543. required_bw_value_nodes = (
  544. required_bw_value_nodes | backward_utils.ValueSet(bw_op_outputs)
  545. )
  546. unclaimed_value_nodes = backward_utils.ValueSet()
  547. unclaimed_ops = {
  548. op
  549. for op in all_ops
  550. if op not in required_fw_ops and op not in required_bw_ops
  551. }
  552. for unclaimed_op in unclaimed_ops:
  553. unclaimed_op_outputs = unclaimed_op.results()
  554. unclaimed_value_nodes = unclaimed_value_nodes | backward_utils.ValueSet(
  555. unclaimed_op_outputs
  556. )
  557. return (
  558. required_fw_value_nodes,
  559. required_bw_value_nodes,
  560. unclaimed_value_nodes,
  561. )
  562. def find_value_node_users(value_node):
  563. '''
  564. Find all the value nodes which use the same value node to be computed.
  565. '''
  566. users = backward_utils.ValueSet()
  567. for op in value_node.all_used_ops():
  568. if op.name() == "builtin.combine":
  569. combine_result = op.results()[0]
  570. for combine_res_used_op in combine_result.all_used_ops():
  571. results = combine_res_used_op.results()
  572. for result in results:
  573. if (
  574. len(result.all_used_ops()) == 1
  575. and result.all_used_ops()[0].name() == "builtin.split"
  576. ):
  577. split_results = result.all_used_ops()[0].results()
  578. users |= backward_utils.ValueSet(split_results)
  579. else:
  580. users.add(result)
  581. else:
  582. results = op.results()
  583. for result in results:
  584. if (
  585. len(result.all_used_ops()) == 1
  586. and result.all_used_ops()[0].name() == "builtin.split"
  587. ):
  588. split_results = result.all_used_ops()[0].results()
  589. users |= backward_utils.ValueSet(split_results)
  590. else:
  591. users.add(result)
  592. return users
  593. def get_real_input_nodes(output_value_node):
  594. real_input_nodes = backward_utils.ValueSet()
  595. define_op = output_value_node.get_defining_op()
  596. if define_op.name() == "builtin.split":
  597. op_input = define_op.operands_source()[0]
  598. real_define_op = op_input.get_defining_op()
  599. input_value_nodes = real_define_op.operands_source()
  600. else:
  601. input_value_nodes = define_op.operands_source()
  602. for input_value_node in input_value_nodes:
  603. if input_value_node.get_defining_op().name() == "builtin.combine":
  604. real_input_nodes |= backward_utils.ValueSet(
  605. input_value_node.get_defining_op().operands_source()
  606. )
  607. else:
  608. real_input_nodes.add(input_value_node)
  609. return real_input_nodes
  610. def get_real_define_op_name(value_node):
  611. define_op = value_node.get_defining_op()
  612. if define_op.name() == "builtin.split":
  613. op_input = define_op.operands_source()[0]
  614. return op_input.get_defining_op().name()
  615. else:
  616. return define_op.name()
  617. def is_dynamic_value_node(value_node):
  618. return -1 in value_node.shape
  619. def cal_value_node_size(value_node):
  620. # todo(wanghao107) hack for dynamic shape
  621. if is_dynamic_value_node(value_node):
  622. return 1
  623. return value_node.numel() * _PADDLE_DTYPE_2_NBYTES[value_node.dtype]
  624. def cal_value_nodes_dist_to_backward(all_ops, required_fw_value_nodes):
  625. dist_from_bw = backward_utils.ValueDict()
  626. # caculate value node the shortest dist to backward graph
  627. for op in reversed(all_ops):
  628. if op.name() == "builtin.combine":
  629. continue
  630. op_results = op.results()
  631. for op_result in op_results:
  632. used_ops = op_result.all_used_ops()
  633. if len(used_ops) == 1 and used_ops[0].name() == "builtin.split":
  634. continue
  635. real_users = find_value_node_users(op_result)
  636. if op_result not in required_fw_value_nodes:
  637. dist_from_bw[op_result] = 0
  638. else:
  639. dist_from_bw[op_result] = int(1e9)
  640. for user in real_users:
  641. dist_from_bw[op_result] = min(
  642. dist_from_bw[op_result], dist_from_bw[user] + 1
  643. )
  644. return dist_from_bw
  645. def analyze_mid_hold_values(
  646. program,
  647. saved_values,
  648. inputs,
  649. outputs,
  650. fwd_op_end_idx,
  651. backward_op_start_idx,
  652. ):
  653. forward_ops = set(program.global_block().ops[: fwd_op_end_idx + 1])
  654. backward_ops = set(program.global_block().ops[backward_op_start_idx:])
  655. mid_hold_values = backward_utils.ValueSet()
  656. for op in forward_ops:
  657. for result in op.results():
  658. all_used_ops = result.all_used_ops()
  659. if (
  660. any(op in backward_ops for op in all_used_ops)
  661. and result not in saved_values
  662. and result not in outputs
  663. and result not in inputs
  664. ):
  665. mid_hold_values.add(result)
  666. return mid_hold_values
  667. def clone_graph(program, origin_ops, graph_inputs, clone_insertion_op):
  668. pir.set_insertion_point(clone_insertion_op)
  669. all_ops = program.global_block().ops
  670. value_map = paddle.pir.IrMapping()
  671. origin_ops = set(origin_ops)
  672. cloned_ops = []
  673. for input_value in graph_inputs:
  674. value_map.add(input_value, input_value)
  675. for op in all_ops:
  676. if op in origin_ops:
  677. cloned_ops.append(
  678. op.clone(value_map, paddle.pir.CloneOptions(False, True, True))
  679. )
  680. pir.set_insertion_point_to_block_end(program.global_block())
  681. return cloned_ops, value_map
  682. def find_parent_ops(value):
  683. visited = backward_utils.ValueSet()
  684. def _find_parent_ops(value):
  685. parent_ops = set()
  686. if value in visited:
  687. return parent_ops
  688. visited.add(value)
  689. parent_op = value.get_defining_op()
  690. parent_ops.add(parent_op)
  691. op_inputs = parent_op.operands_source()
  692. for op_input in op_inputs:
  693. parent_ops = parent_ops | _find_parent_ops(op_input)
  694. return parent_ops
  695. return _find_parent_ops(value)
  696. def find_child_ops(value):
  697. visited = backward_utils.ValueSet()
  698. def _find_child_ops(value):
  699. child_ops = set()
  700. if value in visited:
  701. return child_ops
  702. visited.add(value)
  703. used_ops = value.all_used_ops()
  704. child_ops |= set(used_ops)
  705. op_results = backward_utils.ValueSet()
  706. for used_op in used_ops:
  707. op_results = op_results | backward_utils.ValueSet(used_op.results())
  708. for op_result in op_results:
  709. child_ops = child_ops | _find_child_ops(op_result)
  710. return child_ops
  711. return _find_child_ops(value)