backward_utils.py 20 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
  14. from __future__ import annotations
  15. import collections
  16. import logging
  17. import warnings
  18. from collections.abc import Sequence
  19. from functools import lru_cache
  20. from typing import Any
  21. from paddle import pir
  22. from paddle.base import core
  23. from paddle.base.libpaddle.pir import (
  24. get_used_external_value,
  25. )
  26. from paddle.base.wrapped_decorator import signature_safe_contextmanager
  27. # TODO: Consider a better way to mark these ops has no grad op.
  28. # Such as use a new trait to mark these ops.
  29. # Please keep them as alphabetical order.
  30. ALLOW_NO_GRAD_OPS = [
  31. # Compare ops
  32. "pd_op.equal",
  33. "pd_op.equal_",
  34. "pd_op.greater_than",
  35. "pd_op.greater_than_",
  36. "pd_op.greater_equal",
  37. "pd_op.greater_equal_",
  38. "pd_op.less_than",
  39. "pd_op.less_than_",
  40. "pd_op.less_equal",
  41. "pd_op.less_equal_",
  42. "pd_op.not_equal",
  43. "pd_op.not_equal_",
  44. # Logical ops
  45. "pd_op.logical_and",
  46. "pd_op.logical_and_",
  47. "pd_op.logical_not",
  48. "pd_op.logical_not_",
  49. "pd_op.logical_or",
  50. "pd_op.logical_or_",
  51. "pd_op.logical_xor",
  52. "pd_op.logical_xor_",
  53. # Bitwise ops
  54. "pd_op.bitwise_and",
  55. "pd_op.bitwise_and_",
  56. "pd_op.bitwise_left_shift",
  57. "pd_op.bitwise_left_shift_",
  58. "pd_op.bitwise_not",
  59. "pd_op.bitwise_not_",
  60. "pd_op.bitwise_or",
  61. "pd_op.bitwise_or_",
  62. "pd_op.bitwise_right_shift",
  63. "pd_op.bitwise_right_shift_",
  64. "pd_op.bitwise_xor",
  65. "pd_op.bitwise_xor_",
  66. # Array ops
  67. "pd_op.assign_array",
  68. "pd_op.assign_array_",
  69. "pd_op.array_length",
  70. "pd_op.array_pop",
  71. "pd_op.array_read",
  72. "pd_op.array_write_",
  73. "pd_op.create_array",
  74. "pd_op.create_array_like",
  75. "pd_op.slice_array",
  76. "pd_op.slice_array_dense",
  77. # Others
  78. "pd_op.accuracy",
  79. "pd_op.all",
  80. "pd_op.any",
  81. "pd_op.argmax",
  82. "pd_op.assign_value_",
  83. "pd_op.bernoulli",
  84. "pd_op.distribute_fpn_proposals",
  85. "pd_op.floor_divide",
  86. "pd_op.full_like",
  87. "pd_op.full_with_tensor",
  88. "pd_op.gaussian",
  89. "pd_op.isnan",
  90. "pd_op.isinf",
  91. "pd_op.nextafter",
  92. "pd_op.nonzero",
  93. "pd_op.one_hot",
  94. "pd_op.print",
  95. "pd_op.prior_box",
  96. "pd_op.randint",
  97. "pd_op.remainder",
  98. "pd_op.shape",
  99. "pd_op.share_data_",
  100. "pd_op.uniform",
  101. ]
  102. # TODO(CZ): to be removed when we support dynamic shape by default.
  103. ALLOW_DYNAMIC_SHAPE_VJP_OPS = [
  104. "pd_op.abs",
  105. "pd_op.assign",
  106. "pd_op.sin",
  107. "pd_op.cos",
  108. "pd_op.tanh",
  109. "pd_op.cast",
  110. "pd_op.log",
  111. "pd_op.exp",
  112. "pd_op.sqrt",
  113. "pd_op.rsqrt",
  114. "pd_op.sigmoid",
  115. "pd_op.silu",
  116. ]
  117. class ValueWrapper:
  118. def __init__(self, value) -> None:
  119. if isinstance(value, ValueWrapper):
  120. assert isinstance(value._value, (type(None), pir.Value))
  121. else:
  122. assert isinstance(value, (type(None), pir.Value))
  123. self._value = value._value if isinstance(value, ValueWrapper) else value
  124. def __hash__(self) -> int:
  125. if isinstance(self._value, pir.Value):
  126. return self._value.hash()
  127. else:
  128. return hash(self._value)
  129. def __eq__(self, other) -> bool:
  130. if not isinstance(other, ValueWrapper):
  131. warnings.warn(
  132. f'In ValueWrapper.__eq__ expected type of `other` is ValueWrapper but received {other.__class__}.'
  133. )
  134. return False
  135. if self._value is None or other._value is None:
  136. return self._value is None and other._value is None
  137. return self._value.is_same(other._value)
  138. class ValueDict:
  139. def __init__(
  140. self,
  141. iter=None,
  142. *,
  143. default_factory=None,
  144. ):
  145. self._items: dict[ValueWrapper] = {}
  146. self._default_factory = default_factory
  147. if iter is not None:
  148. for key, val in iter.items():
  149. self[key] = val
  150. def copy(self):
  151. ret = ValueDict()
  152. ret._items = self._items.copy()
  153. ret._default_factory = self._default_factory
  154. return ret
  155. def update(self, other_dict):
  156. for key, val in other_dict.items():
  157. self[key] = val
  158. def keys(self):
  159. for key in self._items.keys():
  160. yield key._value
  161. def values(self):
  162. return self._items.values()
  163. def items(self):
  164. for key, val in self._items.items():
  165. yield key._value, val
  166. def get(self, key, default=None):
  167. if not self.__contains__(key):
  168. return default
  169. return self._items[ValueWrapper(key)]
  170. def pop(self, key):
  171. if not self.__contains__(key):
  172. raise KeyError(f'{key} is not in ValueDict')
  173. return self._items.pop(ValueWrapper(key))
  174. def setdefault(self, key, default=None):
  175. if not self.__contains__(key):
  176. self[key] = default
  177. return self[key]
  178. def __setitem__(self, key, val: Any):
  179. self._items[ValueWrapper(key)] = val
  180. def __getitem__(self, key):
  181. if not self.__contains__(key):
  182. if self._default_factory is not None:
  183. self[key] = self._default_factory()
  184. else:
  185. raise KeyError(f'{key} is not in ValueDict')
  186. return self._items[ValueWrapper(key)]
  187. def __bool__(self):
  188. return bool(self._items)
  189. def __len__(self):
  190. return len(self._items)
  191. def __iter__(self):
  192. return self.keys()
  193. def __contains__(self, key):
  194. return ValueWrapper(key) in self._items
  195. def __repr__(self) -> str:
  196. items_str = ", ".join(f"{key}: {val}" for key, val in self.items())
  197. return f'ValueDict({items_str})'
  198. class ValueSet:
  199. def __init__(
  200. self, iter: Sequence[ValueWrapper] | set[ValueWrapper] | None = None
  201. ):
  202. self._set: set[ValueWrapper] = set()
  203. if iter is not None:
  204. for val in iter:
  205. self.add(val)
  206. def copy(self):
  207. ret = ValueSet()
  208. ret._set = self._set.copy()
  209. return ret
  210. def add(self, val):
  211. if not self.__contains__(val):
  212. self._set.add(ValueWrapper(val))
  213. def update(self, other: set):
  214. for val in other:
  215. self.add(val)
  216. def pop(self):
  217. return self._set.pop()._value
  218. def __and__(self, other: ValueSet):
  219. return ValueSet(self._set & other._set)
  220. def __or__(self, other: ValueSet):
  221. return ValueSet(self._set | other._set)
  222. def __bool__(self):
  223. return bool(self._set)
  224. def __len__(self):
  225. return len(self._set)
  226. def __iter__(self):
  227. for val in self._set:
  228. yield val._value
  229. def __contains__(self, val):
  230. return ValueWrapper(val) in self._set
  231. def __repr__(self) -> str:
  232. items_str = ", ".join(repr(item) for item in self)
  233. return f'ValueSet({items_str})'
  234. class State:
  235. """
  236. record relationship of forward op/value and backward op/value
  237. one state must be binding with a block, if block has parent block,
  238. state will include parent block info.
  239. """
  240. def __init__(self, block):
  241. self.block = block
  242. # value -> list(list(value))
  243. self.value_to_valuegrad = ValueDict(default_factory=list)
  244. self.value_to_sumvaluegrad = ValueDict(default_factory=list)
  245. # operation -> list(operation)
  246. self.op_to_opgrad = collections.defaultdict(list)
  247. # value -> list(value)
  248. self.valuegrad_to_value = ValueDict(default_factory=list)
  249. self.sumvaluegrad_to_value = ValueDict(default_factory=list)
  250. # operation -> list(operation)
  251. self.opgrad_to_op = collections.defaultdict(list)
  252. # only for controlflow
  253. # inside_value is sub block value, which will yield to parent block,
  254. # parent block value is outside_value
  255. self.inside_value_to_outside_value_map = ValueDict()
  256. def turn_map(self) -> None:
  257. self.valuegrad_to_value = ValueDict(default_factory=list)
  258. self.sumvaluegrad_to_value = ValueDict(default_factory=list)
  259. self.opgrad_to_op = collections.defaultdict(list)
  260. for k, v in self.value_to_valuegrad.items():
  261. if v != []:
  262. for value in v[0]:
  263. self.valuegrad_to_value[value] = [k]
  264. for k, v in self.value_to_sumvaluegrad.items():
  265. if v != []:
  266. for value in v[0]:
  267. self.sumvaluegrad_to_value[value] = [k]
  268. for k, v in self.op_to_opgrad.items():
  269. if v != []:
  270. self.opgrad_to_op[v[0]] = [k]
  271. def copy(self, new_block):
  272. state = State(new_block)
  273. state.value_to_valuegrad = self.value_to_valuegrad.copy()
  274. state.value_to_sumvaluegrad = self.value_to_sumvaluegrad.copy()
  275. # operation -> list(operation)
  276. state.op_to_opgrad = self.op_to_opgrad.copy()
  277. # value -> list(value)
  278. state.valuegrad_to_value = self.valuegrad_to_value.copy()
  279. state.sumvaluegrad_to_value = self.sumvaluegrad_to_value.copy()
  280. # operation -> list(operation)
  281. state.opgrad_to_op = self.opgrad_to_op.copy()
  282. # only for controlflow
  283. state.inside_value_to_outside_value_map = (
  284. self.inside_value_to_outside_value_map.copy()
  285. )
  286. return state
  287. def _check_vjp_dynamic_shape(op, inputs):
  288. for items in inputs:
  289. for item in items:
  290. if item.initialized() and -1 in item.shape:
  291. warnings.warn(
  292. f"[Prim] Decomp op does not support dynamic shape -1, but got shape {item.shape} in inputs of op {op.name()} . Prim will skip its vjp op."
  293. )
  294. return True
  295. # Prim currently does not support dynamic shape, when dynamic shape exits in shape of op inputs, prim will be skipped its vjp op.
  296. @signature_safe_contextmanager
  297. def dynamic_shape_prim_vjp_guard(op, inputs):
  298. origin_prim = core._is_bwd_prim_enabled()
  299. if op.name() == "cf.tuple_push":
  300. skip_prim = True
  301. else:
  302. skip_prim = (
  303. origin_prim
  304. and core._enable_prim_skip_dynamic_shape()
  305. and _check_vjp_dynamic_shape(op, inputs)
  306. and op.name() not in ALLOW_DYNAMIC_SHAPE_VJP_OPS
  307. )
  308. try:
  309. if origin_prim and skip_prim:
  310. core._set_prim_backward_enabled(False)
  311. yield
  312. finally:
  313. if origin_prim:
  314. core._set_prim_backward_enabled(True)
  315. def check_type(input, input_name, expected_type, op_name, extra_message=''):
  316. if not isinstance(input, expected_type):
  317. raise TypeError(
  318. f"The type of '{input_name}' in {op_name} must be {expected_type}, but received {type(input)}. {extra_message}"
  319. )
  320. def _as_list(x):
  321. if x is None:
  322. return []
  323. return list(x) if isinstance(x, Sequence) else [x]
  324. def some_in_set(value_list, value_set):
  325. return any(v in value_set for v in value_list)
  326. def is_control_flow(op):
  327. return op.name() == "pd_op.if" or op.name() == "pd_op.while"
  328. def is_builtin_op(op):
  329. dialect_name, opname = op.name().split(".")
  330. return dialect_name == "builtin"
  331. def update_no_grad_set_by_stopgradient(block, no_grad_set):
  332. for op in block.ops:
  333. if is_control_flow(op):
  334. for sub_block in op.blocks():
  335. update_no_grad_set_by_stopgradient(sub_block, no_grad_set)
  336. for value in op.results():
  337. if value.stop_gradient and value not in no_grad_set:
  338. no_grad_set.add(value)
  339. def get_real_op_inputs(op):
  340. if op.name() == "pd_op.if":
  341. return get_used_external_value(op)
  342. elif op.name() == "pd_op.while":
  343. return op.operands_source() + get_used_external_value(
  344. op.as_while_op().body()
  345. )
  346. elif op.name() == "pd_op.pylayer":
  347. return get_used_external_value(op)
  348. else:
  349. return op.operands_source()
def inverse_sort_op(ops):
    '''
    Return ``ops`` in reverse topological order (consumers before producers).

    if topo graph is op1 -> op2 -> op3
    return [op3, op2, op1]

    Dependencies come from ``get_real_op_inputs`` and only edges between
    ops inside ``ops`` are counted. A post-pass then moves each
    ``pd_op.increment_`` op (see the comment before that pass).

    Raises:
        ValueError: if not every op could be ordered (the sorted list is
            shorter or longer than ``ops``).
    '''
    # init pending_count[op] which describes number of
    # pending edges for its grad_op
    # (i.e. how many ops in `ops` consume a result of `op`).
    pending_count = collections.defaultdict(int)
    ops_set = set(ops)
    sorted_list = []
    for op in ops:
        for x in get_real_op_inputs(op):
            if not pir.is_fake_value(x) and x.get_defining_op() in ops_set:
                pending_count[x.get_defining_op()] += 1
    # Start from ops none of whose results are consumed inside `ops`.
    queue = collections.deque()
    for op in ops:
        if pending_count[op] == 0:
            queue.append(op)
    while queue:
        op = queue.popleft()
        sorted_list.append(op)
        # NOTE(review): unlike the counting pass above, this pass neither
        # skips fake values nor producers outside `ops`. Those counters
        # start at 0 and are only decremented here, so they go negative and
        # such producers are never enqueued. What get_defining_op() returns
        # for a fake value is not visible here — confirm it cannot be an op
        # in `ops` (that would unbalance its counter).
        for x in get_real_op_inputs(op):
            x_op = x.get_defining_op()
            pending_count[x_op] -= 1
            if pending_count[x_op] == 0:
                queue.append(x_op)
    if len(sorted_list) != len(ops):
        raise ValueError(
            "inverse_sort_op wrong, sorted_list size is not equal to origin_list size"
        )
    change_list = []
    # Example: %0 = op1, %1 = increment_(%0), %3 = op2(%0),
    # tuple_push(%0, %1, %3).
    # Nothing consumes %1, so the topological pass above places increment_
    # too early: sorted_list = [increment, op2, op1], although it should be
    # [op2, increment, op1] because op2 effectively depends on %1.
    # tuple_push must remain the last forward op / first backward op, so it
    # is never chosen as a swap partner.
    #
    # For every increment_ op: among the ops placed BEFORE it in
    # sorted_list, find those (other than cf.tuple_push) whose real inputs
    # contain any of increment_'s operands. Because the search walks
    # backwards and overwrites idx_2 on each match, idx_2 ends as the
    # SMALLEST such index. The two positions are swapped afterwards; all
    # indices are computed on the unswapped list.
    # NOTE(review): the search only considers ops placed before increment_
    # in sorted_list; confirm this matches the example above.
    # NOTE(review): sorted_list.index() is O(n) and is called inside the
    # nested loop, so this pass is quadratic in len(ops).
    for op in reversed(sorted_list):
        if op.name() == 'pd_op.increment_':
            idx_1 = sorted_list.index(op)
            idx_2 = sorted_list.index(op)
            for op_in in reversed(sorted_list[: sorted_list.index(op)]):
                if (
                    some_in_set(
                        op.operands_source(),
                        ValueSet(get_real_op_inputs(op_in)),
                    )
                    and op_in.name() != "cf.tuple_push"
                ):
                    idx_2 = sorted_list.index(op_in)
            if idx_1 != idx_2:
                change_list.append((idx_1, idx_2))
    for idx_1, idx_2 in change_list:
        sorted_list[idx_1], sorted_list[idx_2] = (
            sorted_list[idx_2],
            sorted_list[idx_1],
        )
    return sorted_list
  406. def is_inplace_net(op_list):
  407. '''
  408. when program has inplace op , it's difficult to find the actual pending_count.
  409. '''
  410. for op in op_list:
  411. if op.name() in ["pd_op.array_write_", "pd_op.assign_out_"]:
  412. return True
  413. if is_control_flow(op):
  414. for block in op.blocks():
  415. if is_inplace_net(block.ops):
  416. return True
  417. return False
  418. def remove_op(block, op, state):
  419. '''
  420. remove op from block
  421. '''
  422. if state.opgrad_to_op[op] != []:
  423. fwd_op = state.opgrad_to_op[op][0]
  424. state.op_to_opgrad[fwd_op].remove(op)
  425. for valuegrad in op.results():
  426. if state.valuegrad_to_value[valuegrad] != []:
  427. value = state.valuegrad_to_value[valuegrad][0]
  428. state.value_to_valuegrad[value] = []
  429. if value in state.sumvaluegrad_to_value:
  430. raise ValueError(
  431. 'input_grad in [%s] is value which need to sum ', op.name()
  432. )
  433. # NOTE(SigureMo): Ensure access to the op's results before removing it.
  434. # Otherwise, the op will be deconstructed and access the num_results
  435. # will be undefined behavior, it always cause hanging on the macOS.
  436. block.remove_op(op)
  437. def while_prune_check(while_tuple_ops):
  438. if len(while_tuple_ops) != 0:
  439. for opresult in while_tuple_ops[0].results():
  440. if not opresult.use_empty():
  441. return False
  442. return True
  443. return False
def remove_useless_full_like_ops(block, ops, state):
    '''
    Remove unused ``pd_op.full_like`` ops from ``block``.

    Each ``pd_op.full_like`` op in ``ops`` whose first result has no users is
    removed together with the op defining its operand 1. Sub-blocks of
    control-flow ops are processed recursively (over all of their ops) while
    scanning. Removals in ``block`` happen after the scan, through
    ``remove_op``, which also updates ``state``.
    '''
    remove_ops = []
    inverse_ops = inverse_sort_op(list(ops))
    # from output to input
    for op in inverse_ops:
        if op.name() == "pd_op.full_like":
            if op.result(0).use_empty():
                # NOTE(review): the producer of operand 1 is removed without
                # checking that it has no other users, and get_defining_op()
                # is presumably None when operand 1 is a block argument —
                # confirm callers guarantee a dedicated producer op in `block`.
                full_op = op.operand_source(1).get_defining_op()
                # full_like is queued before its producer so it is removed first.
                remove_ops.append(op)
                remove_ops.append(full_op)
        elif is_control_flow(op):
            for sub_block in op.blocks():
                remove_useless_full_like_ops(sub_block, sub_block.ops, state)
    for op in remove_ops:
        remove_op(block, op, state)
  462. def all_stop_gradient_true(block):
  463. for op in block.ops:
  464. for value in op.results():
  465. if value.stop_gradient is False:
  466. return False
  467. return True
  468. def all_input_stop_gradient_true(list_of_list):
  469. for list_ in list_of_list:
  470. for stop_gradient in list_:
  471. if stop_gradient is False:
  472. return False
  473. return True
  474. def all_output_grad_none(list_of_list):
  475. for list_ in list_of_list:
  476. for value in list_:
  477. if value is not None:
  478. return False
  479. return True
  480. def op_has_vjp(op):
  481. # NOTE(MarioLulab): In PIR mode, even though the `PyLayer` op does
  482. # not have a vjp interface, we still need to generate the backward
  483. # block based on its registered backward function. To achieve this,
  484. # we add more handling logic for `PyLayer` Op in the `call_vjp` function
  485. return core.has_vjp(op) or op.name() == "pd_op.pylayer"
  486. def parent_total_ops(block):
  487. '''
  488. when block is sub_block, forward op should include its parent block ops
  489. (sub block nest should Add on demand to avoid block copy)
  490. '''
  491. total_ops = []
  492. if block.parent_block is not None:
  493. if block.parent_block.parent_block:
  494. total_ops += block.parent_block.parent_block.ops
  495. total_ops += block.parent_block.ops
  496. total_ops += block.ops
  497. return total_ops
  498. # only for control_flow to find corresponding value or value_list
  499. def return_map_value(value, map):
  500. output = value
  501. while output in map:
  502. output = map[output]
  503. return output
  504. def return_map_value_list(value, map):
  505. output = []
  506. for i in range(len(value)):
  507. if value[i] in map:
  508. output.append(return_map_value(value[i], map))
  509. else:
  510. output.append(value[i])
  511. return output
  512. def argument_to_value(while_op):
  513. '''
  514. return while op's relationship of (block_argument to input value) and (input value to block_argument).
  515. '''
  516. if while_op.name() != "pd_op.while":
  517. return ValueDict(), ValueDict()
  518. assert len(while_op.as_while_op().block_arguments()) + 1 == len(
  519. while_op.operands_source()
  520. ), "while op's block_arguments size + 1 should same to while op's operands_source size"
  521. arg_to_value_map = ValueDict()
  522. value_to_arg_map = ValueDict()
  523. for arg, value in zip(
  524. while_op.as_while_op().block_arguments(),
  525. while_op.operands_source()[1:],
  526. ):
  527. arg_to_value_map[arg] = value
  528. value_to_arg_map[value] = arg
  529. return arg_to_value_map, value_to_arg_map
  530. def get_grad_semantic_info(op):
  531. '''
  532. return whether op's inputs has grad, usually handled from yaml.
  533. some op has uncertain inputs need special handling.
  534. '''
  535. if op.name() in [
  536. "builtin.combine",
  537. "pd_op.if",
  538. "pd_op.while",
  539. "pd_op.pylayer",
  540. "cf.tuple_push",
  541. ]:
  542. grad_semantic_info = [True for _ in range(len(get_real_op_inputs(op)))]
  543. if op.name() == "pd_op.if":
  544. grad_semantic_info[0] = False
  545. else:
  546. grad_semantic_info = op.get_input_grad_semantics()
  547. return grad_semantic_info
  548. def get_split_op(value):
  549. for op in value.all_used_ops():
  550. if op.name() == "builtin.split":
  551. return op
  552. return None
  553. @lru_cache
  554. def warning_once(message: str):
  555. logging.warning(message)