decomp.py 35 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862863864865866867868869870871872873874875876877878879880881882883884885886887888889890891892893894895896897898899900901902903904905906907908909910911912913914915916917918919920921922923924925926927928929930931
  1. # Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. import logging
  15. import typing
  16. import warnings
  17. import paddle
  18. from paddle import pir
  19. from paddle.autograd import ir_backward
  20. from paddle.autograd.backward_utils import ValueDict, ValueSet
  21. from paddle.base.core import (
  22. call_decomp,
  23. call_decomp_vjp,
  24. decomp_ops_contain_unused_output,
  25. has_decomp,
  26. has_decomp_vjp,
  27. )
  28. from paddle.base.libpaddle.pir import Block, Operation
  29. from paddle.base.wrapped_decorator import signature_safe_contextmanager
  30. from paddle.framework import core
  31. from . import register
  32. logger = logging.getLogger(__name__)
  33. @signature_safe_contextmanager
  34. def prim_guard():
  35. prim_state = core._is_all_prim_enabled()
  36. try:
  37. if not prim_state:
  38. core._set_prim_all_enabled(True)
  39. yield
  40. finally:
  41. if not prim_state:
  42. core._set_prim_all_enabled(False)
  43. def _build_tensor_tuple(xs):
  44. if isinstance(xs, pir.Value):
  45. return (xs,)
  46. elif isinstance(xs, typing.Sequence):
  47. return tuple(xs)
  48. return TypeError(f"Type {type(xs)} is not supported.")
  49. def _analyse_decomp_results(orig_outs, decomp_outs, op):
  50. assert len(orig_outs) == len(decomp_outs)
  51. res = []
  52. for idx, value in enumerate(decomp_outs):
  53. if isinstance(orig_outs[idx], pir.Value):
  54. if (
  55. op.name() in decomp_ops_contain_unused_output.keys()
  56. and idx in decomp_ops_contain_unused_output[op.name()]
  57. ):
  58. assert value[0] is None
  59. else:
  60. assert len(value) == 1 and isinstance(value[0], pir.Value)
  61. res.append(value[0])
  62. else:
  63. res.append(value)
  64. return res
  65. def _prepare_python_api_arguments(op):
  66. """
  67. For standard api of operator, its inputs should keep consistent with organization of its inputs and attrs.
  68. Args:
  69. op (Operator): The target operator.
  70. """
  71. combine_op_name = "builtin.combine"
  72. inputs = []
  73. for x in op.operands():
  74. input = x.source()
  75. if input.initialized():
  76. prev_op = input.get_defining_op()
  77. if (
  78. isinstance(prev_op, Operation)
  79. and prev_op.name() == combine_op_name
  80. ):
  81. input = [item.source() for item in prev_op.operands()]
  82. inputs.append(input)
  83. else:
  84. # for optional input, such as scale for layer_norm op,
  85. # if it is not set, there will be an empty Value which is not initialized in ops.operands
  86. # therefore append None for it.
  87. inputs.append(None)
  88. # The inputs of Pir op builtin.combine will be restored as list of tensor.
  89. if op.name() == combine_op_name:
  90. return (inputs,)
  91. api_arguments = inputs + [op.attrs()[x] for x in op.get_attr_names()]
  92. return tuple(api_arguments)
  93. def _check_prim_dynamic(op):
  94. combine_op_name = "builtin.combine"
  95. inputs = []
  96. for x in op.operands():
  97. input = x.source()
  98. if input.initialized():
  99. prev_op = input.get_defining_op()
  100. if (
  101. isinstance(prev_op, Operation)
  102. and prev_op.name() == combine_op_name
  103. ):
  104. for item in prev_op.operands():
  105. shape = item.source().shape
  106. if -1 in shape:
  107. warnings.warn(
  108. f"Decomp op does not support dynamic shape -1, but got shape {item.source().shape} in inputs of op {op.name()} "
  109. )
  110. return True
  111. else:
  112. shape = input.shape
  113. if -1 in shape:
  114. warnings.warn(
  115. f"Decomp op does not support dynamic shape -1, but got shape {input.shape} in op {op.name()} "
  116. )
  117. return True
  118. def _check_op_results(
  119. op_name, orig_outs, new_outs, orig_vars=None, dst_vars=None
  120. ):
  121. """
  122. Check whether the replaced outputs are consistent with origin outputs.
  123. Args:
  124. op_name (str): The name of operator.
  125. orig_outs (tuple): The outputs of original operator.
  126. new_outs (tuple): The outputs of replaced operator.
  127. orig_vars (dict): Origin variables of original block.
  128. dst_vars (list): Corresponding replaced variables of Origin variables.
  129. """
  130. assert len(orig_outs) == len(new_outs), (
  131. f'when replace origin op {op_name} with composite rule, num of origin outs should be equal to new outs, '
  132. f'but len(orig_outs) = {len(orig_outs)} and len(new_outs) = {len(new_outs)}'
  133. )
  134. for orig_out, new_out in zip(
  135. orig_outs,
  136. new_outs,
  137. ):
  138. if (orig_out is None or new_out is None) and (
  139. op_name not in core.ops_contain_none
  140. ):
  141. raise ValueError(
  142. f"op {op_name} should not contain any None value. original outs={orig_outs} and its composite rule outs={new_outs}"
  143. )
  144. if orig_out is None:
  145. # to keep same as phi op definition, orig_out may receive None
  146. continue
  147. elif new_out is not None:
  148. if orig_vars is not None and dst_vars is not None:
  149. if orig_out in orig_vars:
  150. dst_vars[orig_vars[orig_out]] = new_out
  151. orig_dtype = orig_out.dtype
  152. new_dtype = new_out.dtype
  153. orig_shape = orig_out.shape
  154. new_shape = new_out.shape
  155. assert orig_dtype == new_dtype, (
  156. f'when replace origin op {op_name} with composite rule, origin out dtype should be equal to new out dtype, '
  157. f'but orig_out dtype={orig_dtype} and new_out dtype={new_dtype}'
  158. )
  159. assert (
  160. -1 not in new_shape
  161. ), f'when replace origin op {op_name} with composite rule, composite out shape has -1.'
  162. assert orig_shape == new_shape, (
  163. f'when replace origin op {op_name} with composite rule, origin out shape should be equal to new out shape, '
  164. f'but orig_out shape={orig_shape} and new_out shape={new_shape}'
  165. )
  166. assert not (orig_out is None) ^ (
  167. new_out is None
  168. ), "orig_out and new_out should match."
  169. return
  170. def decompose(
  171. program,
  172. src_vars,
  173. blacklist=frozenset(),
  174. whitelist=frozenset(),
  175. start_index=0,
  176. end_index=-1,
  177. ):
  178. """
  179. Search nonbasic ops which have be registered composite rules and replace them with primitive ops.
  180. The operators in blacklist will be excluded from program when decomposed into primitives, and only the
  181. operators in whitelist will be decomposed. The priority of blacklist is higher than whitelist, it means
  182. an operator both in blacklist and whitelist will not be decomposed.
  183. The finally set that will be decomposed is:
  184. (block.ops & ops have decomposite rule & whitelist) - blacklist
  185. Note:
  186. All variables must be contained inside the given program.
  187. Args:
  188. program (Program): The program to be processed.
  189. src_vars (list[Value]): In program, once some operator is decomposed, its vars will be replaced by new ones. This argument means some vars will be used later and corresponding vars will be returned for later usage.
  190. blacklist (frozenset): The Operators that will be exclude when decomposed into primitives.
  191. whitelist (frozenset): Only the operators in whitelist will be decomposed into primitives.
  192. start_index (int): The start index of decomposed operator in global block, default 0;
  193. end_index (int): The end index of decomposed operator in global block, default -1 means all ops will be composed. start_index and end_index follow the principle of left closed and right open, that is [start_index, end_index).
  194. Returns:
  195. dst_vars (list): A list contains all vars which replace origin ones in src_vars.
  196. """
  197. blacklist = core.prim_config["forward_blacklist"] | blacklist
  198. assert isinstance(start_index, int)
  199. assert isinstance(end_index, int)
  200. return core.sinking_decomp(
  201. program, src_vars, blacklist, whitelist, start_index, end_index
  202. )
  203. def _check_combine_inputs(input1, input2):
  204. '''check whether the inputs of two builtins.combine ops are the same'''
  205. builtin_combine_op1 = input1.get_defining_op()
  206. builtin_combine_op2 = input2.get_defining_op()
  207. if builtin_combine_op1.num_operands() != builtin_combine_op2.num_operands():
  208. return False
  209. else:
  210. for i in range(builtin_combine_op1.num_operands()):
  211. if not (
  212. builtin_combine_op1.operand_source(i).is_same(
  213. builtin_combine_op2.operand_source(i)
  214. )
  215. ):
  216. return False
  217. return True
  218. def _check_op(
  219. fwd_op: pir.Operation,
  220. bwd_op: pir.Operation,
  221. ):
  222. '''check whether the bwd_op is corresponding to fwd_op'''
  223. if fwd_op is None or fwd_op.name() + "_grad" != bwd_op.name():
  224. return False
  225. bwd_op_input_names = bwd_op.get_input_names()
  226. bwd_inputs = [x.source() for x in bwd_op.operands()]
  227. assert len(bwd_op_input_names) == len(
  228. bwd_inputs
  229. ), "backward op names do not match backward op inputs"
  230. fwd_op_related_inputs_outputs = []
  231. for idx, name in enumerate(bwd_op_input_names):
  232. if "_grad" not in name:
  233. fwd_op_related_inputs_outputs.append(bwd_inputs[idx])
  234. fwd_inputs = [x.source() for x in fwd_op.operands()]
  235. fwd_outputs = fwd_op.results()
  236. fwd_vec_inputs = [
  237. x.source()
  238. for x in fwd_op.operands()
  239. if x.source().initialized()
  240. and x.source().get_defining_op().name() == "builtin.combine"
  241. ]
  242. inserted_op_name_list = ["pd_op.full_int_array", "pd_op.full"]
  243. for operand in fwd_op_related_inputs_outputs:
  244. if (
  245. operand.initialized()
  246. and operand.get_defining_op().name() == "builtin.combine"
  247. ): # for pir::VectorType<paddle::dialect::DenseTensorType>
  248. in_fwd = False
  249. for vec_input in fwd_vec_inputs:
  250. if _check_combine_inputs(operand, vec_input):
  251. in_fwd = True
  252. break
  253. if not in_fwd:
  254. return False
  255. else: # for pir::VectorType<paddle::dialect::DenseTensorType>
  256. if not (
  257. operand in ValueSet(fwd_inputs)
  258. or operand in ValueSet(fwd_outputs)
  259. or operand.get_defining_op().name() in inserted_op_name_list
  260. ):
  261. return False
  262. return True
  263. def _get_fwd_op(bwd_op, grad_var_to_var):
  264. bwd_op_input_names = bwd_op.get_input_names()
  265. out_grad_name = ["out_grad", "Out_grad", "loss_grad"]
  266. for idx, input_name in enumerate(bwd_op_input_names):
  267. if input_name in out_grad_name:
  268. out_grad = bwd_op.operand(idx).source()
  269. if out_grad in grad_var_to_var:
  270. out = grad_var_to_var[out_grad]
  271. fwd_op = out.get_defining_op()
  272. return fwd_op
  273. return None
  274. def _decomp_fwd_op(
  275. block: Block, fwd_op: pir.Operation, grad_var_to_var: dict, prev_op=None
  276. ) -> tuple:
  277. '''
  278. Decompose the forward op into a list of primitive ops.
  279. Args:
  280. block (Block): the block to which the forward op belongs.
  281. fwd_op (pir.Operation): the forward op to be decomposed.
  282. grad_var_to_var (dict): a dict obtained from distributed processing,
  283. which maps the backward grad variable to its corresponding forward variable.
  284. prev_op (pir.Operation): the previous op of fwd_op in the block. If prev_op is builtin.combine, insertion point when decomposing fwd_op will be set to prev_op.
  285. Returns:
  286. new_outputs (tuple(Value)): the new outputs after decomposing.
  287. has_decomposed: whether the forward op has been successfully decomposed.
  288. '''
  289. with pir.core.program_guard(block.program):
  290. op_name = fwd_op.name()
  291. orig_outs = fwd_op.results()
  292. decom_rule = register.get_decomp_rule(op_name)
  293. has_sink_decomp_rule = has_decomp(fwd_op)
  294. lower = decom_rule or has_sink_decomp_rule
  295. if lower:
  296. # step1: check dynamic shape, currently not supported
  297. if _check_prim_dynamic(fwd_op):
  298. return None, False
  299. # step2: check insertion point, if prev_op is builtin.combine (such as concat op), insertion point will be set to prev_op
  300. if prev_op is not None:
  301. pir.set_insertion_point(prev_op)
  302. else:
  303. pir.set_insertion_point(fwd_op)
  304. # step3: decompose op, and get new outputs
  305. input_args = _prepare_python_api_arguments(fwd_op)
  306. if has_sink_decomp_rule:
  307. decomp_outs = call_decomp(fwd_op)
  308. new_outs = _analyse_decomp_results(
  309. orig_outs, decomp_outs, fwd_op
  310. )
  311. else:
  312. new_outs = _build_tensor_tuple(decom_rule(*input_args))
  313. _check_op_results(op_name, orig_outs, new_outs)
  314. # step4: upgrade grad_var_to_var with new outputs
  315. _upgrade_grad_var_to_var(
  316. grad_var_to_var, orig_outs=orig_outs, new_outs=new_outs
  317. )
  318. # step5: replace original op with new ops, replace original output with new outputs
  319. if fwd_op.name() in decomp_ops_contain_unused_output.keys():
  320. for idx in range(len(orig_outs)):
  321. if (
  322. idx
  323. not in decomp_ops_contain_unused_output[fwd_op.name()]
  324. ):
  325. orig_outs[idx].replace_all_uses_with(new_outs[idx])
  326. else:
  327. if fwd_op.name() in decomp_ops_contain_unused_output.keys():
  328. orig_outs[0].replace_all_uses_with(new_outs[0])
  329. else:
  330. fwd_op.replace_all_uses_with(new_outs)
  331. block.remove_op(fwd_op)
  332. # step6: remove redundant prev_op (builtin.combine)
  333. if prev_op is not None:
  334. remove_op = True
  335. for item in prev_op.results():
  336. if item.has_one_use():
  337. remove_op = False
  338. break
  339. if remove_op:
  340. block.remove_op(prev_op)
  341. prev_op = None
  342. return new_outs, True
  343. else:
  344. return tuple(orig_outs), False
  345. def _prepare_inputs(fwd_op):
  346. new_inputs = []
  347. for input in fwd_op.operands():
  348. if (
  349. input.source().initialized()
  350. and input.source().get_defining_op().name() == "builtin.combine"
  351. ): # for pir::VectorType<paddle::dialect::DenseTensorType>
  352. builtin_combine_op = input.source().get_defining_op()
  353. new_input = [
  354. builtin_combine_op.operand_source(i)
  355. for i in range(0, builtin_combine_op.num_operands())
  356. ]
  357. new_inputs.append(new_input)
  358. else:
  359. new_inputs.append([input.source()]) # for DenseTensorType
  360. return new_inputs
  361. def _prepare_grad_outputs(fwd_op, bwd_op):
  362. # check forward outputs and backward inputs
  363. fwd_outputs = fwd_op.results()
  364. fwd_output_names = fwd_op.get_output_names()
  365. assert len(fwd_output_names) == len(
  366. fwd_outputs
  367. ), "forward op output names do not match forward op outputs"
  368. bwd_inputs = [x.source() for x in bwd_op.operands()]
  369. bwd_input_names = bwd_op.get_input_names()
  370. assert len(bwd_input_names) == len(
  371. bwd_inputs
  372. ), "backward op input names do not match backward op inputs"
  373. # cut gradients from backward op's inputs
  374. fwd_inputs = [x.source() for x in fwd_op.operands()]
  375. fwd_vec_inputs = [
  376. x.source()
  377. for x in fwd_op.operands()
  378. if x.source().initialized()
  379. and x.source().get_defining_op().name() == "builtin.combine"
  380. ]
  381. grad_outputs = []
  382. grad_output_names = []
  383. for i, bwd_input in enumerate(bwd_inputs):
  384. if (
  385. bwd_input.initialized()
  386. and bwd_input.get_defining_op().name() == "builtin.combine"
  387. ): # for pir::VectorType<paddle::dialect::DenseTensorType>
  388. in_fwd = False
  389. for vec_input in fwd_vec_inputs:
  390. if _check_combine_inputs(bwd_input, vec_input):
  391. in_fwd = True
  392. break
  393. if not in_fwd:
  394. grad_outputs.append([bwd_input])
  395. grad_output_names.append(bwd_input_names[i])
  396. else:
  397. if not (
  398. bwd_input in ValueSet(fwd_inputs)
  399. or bwd_input in ValueSet(fwd_outputs)
  400. ): # for paddle::dialect::DenseTensorType
  401. grad_outputs.append([bwd_input])
  402. grad_output_names.append(bwd_input_names[i])
  403. # add fake grads for forward op's outputs which are not used in backward op
  404. # this is necessary for the call_vjp(), which ensures that len(out_grads) must be equal to len(outputs)
  405. new_grad_outputs = []
  406. index = 0
  407. for fwd_output_name in fwd_output_names:
  408. if (fwd_output_name + "_grad") in grad_output_names:
  409. new_grad_outputs.append(grad_outputs[index])
  410. index += 1
  411. else:
  412. new_grad_outputs.append([pir.fake_value()])
  413. return new_grad_outputs
  414. def _prepare_stop_gradients(fwd_inputs, bwd_outputs):
  415. stop_gradients = []
  416. for idx, bwd_output in enumerate(bwd_outputs):
  417. if bwd_output.initialized():
  418. stop_gradient = [False] * len(fwd_inputs[idx])
  419. else:
  420. stop_gradient = [True] * len(fwd_inputs[idx])
  421. stop_gradients.append(stop_gradient)
  422. return stop_gradients
  423. def _upgrade_grad_var_to_var(
  424. grad_var_to_var,
  425. orig_grads=None,
  426. new_grads=None,
  427. orig_outs=None,
  428. new_outs=None,
  429. ):
  430. assert grad_var_to_var is not None, "grad_var_to_var should not be None"
  431. if orig_grads is not None and new_grads is not None:
  432. for idx, grad_input in enumerate(orig_grads):
  433. if grad_input in grad_var_to_var:
  434. grad_var_to_var[new_grads[idx]] = grad_var_to_var.pop(
  435. grad_input
  436. )
  437. if orig_outs is not None and new_outs is not None:
  438. for grad_var, var in grad_var_to_var.items():
  439. for i, orin_var in enumerate(orig_outs):
  440. if var.is_same(orin_var):
  441. grad_var_to_var[grad_var] = new_outs[i]
def _decomp_bwd_with_vjp(
    block: Block,
    fwd_op: pir.Operation,
    bwd_op: pir.Operation,
    grad_var_to_var: dict,
) -> tuple:
    '''
    Decompose the backward op into a list of primitive ops.
    If forward op has composite vjp rules (including custom vjp), call call_vjp() to get a list of primitive operators in backward graph, then replace backward op.

    Args:
        block (Block): the block that contains bwd_op; the ops generated by call_vjp() are moved to bwd_op's position in it.
        fwd_op (pir.Operation): the forward op corresponding to bwd_op.
        bwd_op (pir.Operation): the backward op to be replaced.
        grad_var_to_var (dict): maps a backward grad variable to its forward variable; updated in place for the replaced results.

    Returns:
        (None, False) when call_vjp() only re-created the original backward op,
        otherwise (tuple of the new results, True).
    '''
    # step1: prepare arguments for call_vjp()
    fwd_inputs_ = _prepare_inputs(fwd_op)
    fwd_outputs_ = [[fwd_output] for fwd_output in fwd_op.results()]
    grad_outputs_ = _prepare_grad_outputs(fwd_op, bwd_op)
    stop_gradients_ = _prepare_stop_gradients(fwd_inputs_, bwd_op.results())

    # step2: call call_vjp() to get a list of primitive operators which has the same meaning as the backward op
    # The op count is sampled before and after the call so the newly created
    # ops can be identified as the tail of block.ops.
    bwd_op_idx = block.ops.index(bwd_op)
    before_num_ops = len(block.ops)
    new_grad_inputs = core.call_vjp(
        fwd_op, fwd_inputs_, fwd_outputs_, grad_outputs_, stop_gradients_
    )
    after_num_ops = len(block.ops)
    num_appended_ops = after_num_ops - before_num_ops

    # if forward op has no composite vjp rules, call_vjp() appends the same op as original backward op, skip decomposing, return False
    if num_appended_ops == 1 and block.ops[-1].name() == bwd_op.name():
        # Drop the duplicate so the block is left as it was.
        block.remove_op(block.ops[-1])
        return None, False
    else:
        # step3: record new outputs of the decomposed backward op
        # When the last generated op is builtin.split, its first operand is
        # used as the single result instead of what call_vjp() returned.
        if block.ops[-1].name() == "builtin.split":
            new_grad_inputs = [[block.ops[-1].operand(0).source()]]
        res = []
        for grad_input in new_grad_inputs:
            if grad_input[0] is not None and grad_input[0].initialized():
                res.append(grad_input[0])
            else:
                # Missing gradients are represented by a fake value.
                res.append(pir.fake_value())
        assert len(res) == len(
            bwd_op.results()
        ), "results of original backward op do not match results of decomposed backward op"

        # step4: upgrade grad_var_to_var
        _upgrade_grad_var_to_var(
            grad_var_to_var, orig_grads=bwd_op.results(), new_grads=res
        )

        # step5: replace original backward op with new primitive ops
        # The ops in [before_num_ops, after_num_ops) are moved one by one to
        # the position of the original backward op, keeping their order.
        insert_idx = bwd_op_idx
        for i in range(before_num_ops, after_num_ops):
            block.move_op(block.ops[i], insert_idx)
            insert_idx += 1
        bwd_op.replace_all_uses_with(res)
        block.remove_op(bwd_op)

        return tuple(res), True
def _decomp_bwd_without_vjp(
    block: Block,
    bwd_op: pir.Operation,
    grad_var_to_var: dict,
    fwd_inputs: list,
    fwd_outputs_after_decompose: tuple,
) -> tuple:
    '''
    Decompose the backward op into a list of primitive ops.
    If forward op has no composite vjp rules, and forward op has been decomposed to a list of primitive operators in forward graph previously,
    call grad() for the decomposed forward subgraph to get a list of primitive operators in backward graph, then replace backward op.

    Args:
        block (Block): the block that contains bwd_op.
        bwd_op (pir.Operation): the backward op to be replaced.
        grad_var_to_var (dict): maps a backward grad variable to its forward variable; read to find the forward values and updated in place for the replaced results.
        fwd_inputs (list): the inputs of the forward op.
        fwd_outputs_after_decompose (tuple): the outputs of the forward op after it has been decomposed.

    Returns:
        (tuple of the new results, True)

    Raises:
        RuntimeError: if fwd_outputs_after_decompose is None.
    '''

    if fwd_outputs_after_decompose is None:
        raise RuntimeError(
            "To decompose backward op, please decompose forward op firstly"
        )

    # step1: prepare arguments for grad()
    bwd_inputs = [x.source() for x in bwd_op.operands()]
    grad_inputs = bwd_op.results()
    # Inputs of the backward op that are neither forward inputs nor forward
    # outputs are taken to be the gradients of the forward outputs.
    grad_outputs = tuple(
        bwd_input
        for bwd_input in bwd_inputs
        if not (
            bwd_input in ValueSet(fwd_inputs)
            or bwd_input in ValueSet(fwd_outputs_after_decompose)
        )
    )
    # Forward values are looked up through grad_var_to_var.
    fwd_outputs_ = tuple(
        grad_var_to_var[grad_output] for grad_output in grad_outputs
    )
    fwd_inputs_ = tuple(
        grad_var_to_var[grad_input]
        for grad_input in grad_inputs
        if grad_input.initialized()
    )

    # step2: call grad() to get a list of primitive operators which has the same meaning as the backward op
    bwd_op_idx = block.ops.index(bwd_op)
    before_num_ops = len(block.ops)
    new_grad_inputs = ir_backward.grad(fwd_outputs_, fwd_inputs_, grad_outputs)
    after_num_ops = len(block.ops)

    # step3: record new outputs of the decomposed backward op
    # new_grad_inputs only covers the initialized results, so it is consumed
    # with its own counter; uninitialized results get a fake value.
    res = []
    input_grads_idx = 0
    for idx, grad_input in enumerate(grad_inputs):
        if grad_input.initialized():
            res.append(new_grad_inputs[input_grads_idx])
            input_grads_idx += 1
        else:
            res.append(pir.fake_value())

    # step4: upgrade grad_var_to_var
    _upgrade_grad_var_to_var(
        grad_var_to_var, orig_grads=grad_inputs, new_grads=res
    )

    # step5: replace original backward op with new primitive ops
    # The ops in [before_num_ops, after_num_ops) are moved one by one to the
    # position of the original backward op, keeping their order.
    insert_idx = bwd_op_idx
    for i in range(before_num_ops, after_num_ops):
        block.move_op(block.ops[i], insert_idx)
        insert_idx += 1
    bwd_op.replace_all_uses_with(res)
    block.remove_op(bwd_op)
    has_decomposed = True

    return tuple(res), has_decomposed
  556. def _decomp_bwd_op(
  557. block: Block,
  558. bwd_op: pir.Operation,
  559. grad_var_to_var: dict,
  560. ):
  561. '''
  562. Decompose a backward op in pir program.
  563. Get the corresponding forward op according to grad_var_to_var firstly, then
  564. (1) try to decompose backward op by calling _decompose_bwd_with_vjp, if forward op has composite vjp rules (including custom vjp),
  565. _decompose_bwd_with_vjp will call call_vjp() to get a list of primitive operators in backward graph, then replace backward op successfully and return True;
  566. (2) when _decompose_bwd_with_vjp return False, means there is no composite vjp rules,
  567. try to decompose forward op firstly by calling _decomp_fwd_op firstly and get corresponding primitive operators in backward graph by calling _decompose_bwd_without_vjp secondly, then replace backward op successfully and return True;
  568. (3) if the backward op is still not decomposed by the above two steps, returns False.
  569. Args:
  570. block (Block): the block to which the backward op belongs.
  571. bwd_op (pir.Operation): the backward op to be decomposed.
  572. grad_var_to_var (dict): a dict obtained from distributed processing,
  573. which maps the backward grad variable to its corresponding forward variable.
  574. Return:
  575. new_input_grads (tuple(Value)): new results of backward op after decomposing.
  576. has_decomposed: whether the backward op has been successfully decomposed.
  577. '''
  578. # get the corresponding forward op according to grad_var_to_var
  579. # check and ensure: bwd_inputs = out_grads + fwd_inputs[optional] + fwd_outputs[optional]
  580. fwd_op = _get_fwd_op(bwd_op, grad_var_to_var)
  581. if not _check_op(fwd_op, bwd_op):
  582. logger.debug(
  583. f'{bwd_op.name()} can not be decomposed due to the mismatch between forward op and backward op'
  584. )
  585. return None, False
  586. if _check_prim_dynamic(fwd_op) or _check_prim_dynamic(bwd_op):
  587. return None, False
  588. # try to decompose backward op directly
  589. (
  590. new_grads,
  591. bwd_has_decomposed,
  592. ) = _decomp_bwd_with_vjp(
  593. block,
  594. fwd_op,
  595. bwd_op,
  596. grad_var_to_var,
  597. )
  598. if not bwd_has_decomposed:
  599. # try to decompose the forward op
  600. fwd_inputs = [x.source() for x in fwd_op.operands()]
  601. (
  602. new_fwd_outputs,
  603. fwd_has_decomposed,
  604. ) = _decomp_fwd_op(
  605. block,
  606. fwd_op,
  607. grad_var_to_var,
  608. )
  609. if fwd_has_decomposed:
  610. # try to decompose the backward op
  611. (
  612. new_grads,
  613. bwd_has_decomposed,
  614. ) = _decomp_bwd_without_vjp(
  615. block,
  616. bwd_op,
  617. grad_var_to_var,
  618. fwd_inputs,
  619. new_fwd_outputs,
  620. )
  621. return new_grads, bwd_has_decomposed
  622. def _get_all_bwd_ops(pir_program):
  623. bwd_ops = []
  624. global_block = pir_program.global_block()
  625. for op in global_block.ops:
  626. if (
  627. op.name().endswith("_grad") or op.name().endswith("_grad_")
  628. ) and op.name() not in bwd_ops:
  629. bwd_ops.append(op.name())
  630. return bwd_ops
  631. def _set_prim_state():
  632. state = []
  633. prev_fwd_prim_state = core._is_fwd_prim_enabled()
  634. prev_bwd_prim_state = core._is_bwd_prim_enabled()
  635. state.append(prev_fwd_prim_state)
  636. state.append(prev_bwd_prim_state)
  637. core._set_prim_forward_enabled(True)
  638. core._set_prim_backward_enabled(True)
  639. prev_pir_api_flag = paddle.base.framework.get_flags("FLAGS_enable_pir_api")[
  640. "FLAGS_enable_pir_api"
  641. ]
  642. paddle.framework.set_flags(
  643. {"FLAGS_enable_pir_api": True}
  644. ) # set in pir mode for operator overloading
  645. paddle.base.framework.global_var._use_pir_api_ = True
  646. state.append(prev_pir_api_flag)
  647. return state
  648. def _reset_prim_state(state):
  649. assert (
  650. len(state) == 3
  651. ), "state should contain fwd_prim_state, bwd_prim_state and pir_api_state"
  652. core._set_prim_forward_enabled(state[0])
  653. core._set_prim_backward_enabled(state[1])
  654. paddle.framework.set_flags({"FLAGS_enable_pir_api": state[2]})
  655. paddle.base.framework.global_var._use_pir_api_ = state[2]
def _translate_gradvartovar_to_pir(param_mapping, grad_var_to_var):
    '''translate grad_var_to_var (mapping VarDesc->VarDesc) to pir_grad_var_to_var (mapping Value->Value)

    Args:
        param_mapping: maps each old variable to a list of pir Values.
        grad_var_to_var: maps each old grad variable to its forward variable.

    Returns:
        ValueDict: maps the translated grad Value(s) to the translated forward Value.
        Pairs for which either side is missing from param_mapping are skipped.
    '''
    pir_grad_var_to_var = ValueDict()
    for grad_var, var in grad_var_to_var.items():
        if grad_var in param_mapping.keys() and var in param_mapping.keys():
            if (
                len(param_mapping[grad_var]) == 1
                and len(param_mapping[var]) == 1
            ):
                # Simple case: both sides translate to exactly one Value.
                new_grad_var = param_mapping[grad_var][0]
                new_var = param_mapping[var][0]
                pir_grad_var_to_var[new_grad_var] = new_var
            else:
                new_grad_vars = []
                new_vars = []
                # Grad side: a single Value, the second Value when it is
                # produced by builtin.slice, or otherwise every Value.
                if len(param_mapping[grad_var]) == 1:
                    new_grad_vars.append(param_mapping[grad_var][0])
                elif (
                    len(param_mapping[grad_var]) == 2
                    and param_mapping[grad_var][1].get_defining_op().name()
                    == "builtin.slice"
                ):
                    new_grad_vars.append(param_mapping[grad_var][1])
                else:
                    for i in range(0, len(param_mapping[grad_var])):
                        new_grad_vars.append(param_mapping[grad_var][i])
                # Forward side: a single Value, the second Value when it is
                # produced by builtin.slice, or the first Value when the op
                # defining the last Value has a name ending with "_".
                if len(param_mapping[var]) == 1:
                    new_vars.append(param_mapping[var][0])
                elif (
                    len(param_mapping[var]) == 2
                    and param_mapping[var][1].get_defining_op().name()
                    == "builtin.slice"
                ):
                    new_vars.append(param_mapping[var][1])
                else:
                    last_op = param_mapping[var][-1].get_defining_op()
                    if last_op.name().endswith("_"):
                        new_vars.append(param_mapping[var][0])
                # Exactly one forward Value must have been selected.
                assert len(new_vars) == 1, "translate pir_grad_var_to_var error"
                # Every selected grad Value maps to that one forward Value.
                for i in range(0, len(new_grad_vars)):
                    pir_grad_var_to_var[new_grad_vars[i]] = new_vars[0]
    return pir_grad_var_to_var
  698. def _decomp_bwd_program(pir_program, pir_grad_var_to_var):
  699. '''Traverse and decompose all backward OPs in program'''
  700. with paddle.pir.core.program_guard(pir_program):
  701. bwd_ops = _get_all_bwd_ops(pir_program)
  702. undecomposed_bwd_ops = []
  703. ops = pir_program.global_block().ops
  704. for op in ops:
  705. bwd_op_name = op.name()
  706. if op.name() in bwd_ops:
  707. _, bwd_has_decomposed = _decomp_bwd_op(
  708. pir_program.global_block(), op, pir_grad_var_to_var
  709. )
  710. if (
  711. not bwd_has_decomposed
  712. and bwd_op_name not in undecomposed_bwd_ops
  713. ):
  714. undecomposed_bwd_ops.append(bwd_op_name)
  715. logger.debug(
  716. f'Following backward ops can not be decomposed: {undecomposed_bwd_ops}'
  717. )
  718. def _decomp_fwd_program(pir_program, pir_grad_var_to_var):
  719. '''Traverse and decompose all forward OPs in program'''
  720. with paddle.pir.core.program_guard(pir_program):
  721. ops = pir_program.global_block().ops
  722. bwd_ops = _get_all_bwd_ops(pir_program)
  723. # ops including compile-time infermeta, causing mismatched input shape and output shape, which is unsupported when decomposing.
  724. black_fwd_ops = ["pd_op.stack", "pd_op.squeeze"]
  725. undecomposed_fwd_ops = []
  726. prev_op = None
  727. for op in ops:
  728. fwd_op_name = op.name()
  729. if op.name() not in bwd_ops:
  730. if op.name() not in black_fwd_ops:
  731. _, fwd_has_decomposed = _decomp_fwd_op(
  732. pir_program.global_block(),
  733. op,
  734. pir_grad_var_to_var,
  735. prev_op,
  736. )
  737. if (
  738. not fwd_has_decomposed
  739. and fwd_op_name not in undecomposed_fwd_ops
  740. ):
  741. undecomposed_fwd_ops.append(fwd_op_name)
  742. else:
  743. if fwd_op_name not in undecomposed_fwd_ops:
  744. undecomposed_fwd_ops.append(fwd_op_name)
  745. prev_op = op if op.name() == "builtin.combine" else None
  746. logger.debug(
  747. f'Following forward ops can not be decomposed: {undecomposed_fwd_ops}'
  748. )
  749. def decompose_dist_program(pir_program):
  750. '''
  751. Decompose all non-primitive ops into primitive ops in a pir program. It may contain forward ops and backward ops.
  752. '''
  753. # decomp forward composite ops
  754. decompose(pir_program, [])
  755. # decomp backward ops
  756. blacklist = core.prim_config["backward_blacklist"]
  757. block = pir_program.global_block()
  758. pre_combine_op = None
  759. with paddle.pir.core.program_guard(pir_program):
  760. ops = pir_program.global_block().ops
  761. for op in ops:
  762. bwd_op_name = op.name()
  763. if bwd_op_name.split(".")[-1] in blacklist:
  764. continue
  765. skip_decomp = False
  766. if has_decomp_vjp(op):
  767. if (
  768. not core._enable_prim_dynamic_shape()
  769. ) and _check_prim_dynamic(op):
  770. skip_decomp = True
  771. if not skip_decomp:
  772. pir.set_insertion_point(op)
  773. orig_outs = op.results()
  774. is_next_split = False
  775. decomp_outs = call_decomp_vjp(op)
  776. for i in range(len(orig_outs)):
  777. if orig_outs[i].has_one_use():
  778. next_op = orig_outs[i].first_use().owner()
  779. if next_op.name() == "builtin.split":
  780. is_next_split = True
  781. _check_op_results(
  782. next_op.name(),
  783. next_op.results(),
  784. decomp_outs[i],
  785. )
  786. next_op.replace_all_uses_with(decomp_outs[i])
  787. block.remove_op(next_op)
  788. if not is_next_split:
  789. new_outs = _analyse_decomp_results(
  790. orig_outs, decomp_outs, op
  791. )
  792. _check_op_results(op.name(), orig_outs, new_outs)
  793. op.replace_all_uses_with(new_outs)
  794. block.remove_op(op)
  795. if op.name() == "builtin.combine":
  796. pre_combine_op = op
  797. if pre_combine_op is not None:
  798. remove_op = True
  799. for item in pre_combine_op.results():
  800. if item.has_one_use():
  801. remove_op = False
  802. break
  803. if remove_op:
  804. block.remove_op(pre_combine_op)
  805. pre_combine_op = None
  806. paddle.pir.set_insertion_point_to_block_end(block)
  807. def decompose_pir_program(pir_program, param_mapping, grad_var_to_var):
  808. '''
  809. Decompose all PHI ops into prim ops in a pir program.
  810. Args:
  811. pir_program (Program): the program to be decomposed
  812. param_mapping (dict): a map of program variables to pir program values
  813. grad_var_to_var (dict): a dict obtained from distributed processing,
  814. which maps the backward grad variable to its corresponding forward variable.
  815. '''
  816. # set prim flags and pir_api flags
  817. state = _set_prim_state()
  818. # translate grad_var_to_var to pir
  819. pir_grad_var_to_var = _translate_gradvartovar_to_pir(
  820. param_mapping, grad_var_to_var
  821. )
  822. # decompose
  823. _decomp_bwd_program(pir_program, pir_grad_var_to_var)
  824. _decomp_fwd_program(pir_program, pir_grad_var_to_var)
  825. # reset prim flags and pir_api flags
  826. _reset_prim_state(state)