# backward.py (110 KB)
  1. # Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. import collections
  15. import copy
  16. import logging
  17. import os
  18. import re
  19. import warnings
  20. from collections.abc import Sequence
  21. import paddle.base
  22. from . import core, framework, log_helper, unique_name
  23. from .data_feeder import check_type
  24. from .framework import program_guard
  25. from .proto import framework_pb2
  26. __all__ = []
  27. _logger = log_helper.get_logger(
  28. __name__, logging.INFO, fmt='%(asctime)s-%(levelname)s: %(message)s'
  29. )
  30. class ProgramStats:
  31. def __init__(self, block, ops):
  32. self.block = block
  33. self.ops = ops
  34. self.op_deps = {} # op-> in_ops, out_ops
  35. self.var_op_deps = {} # var as input op, var as output op
  36. def get_input_nodes(self):
  37. input_names = []
  38. for name in self.var_op_deps:
  39. if (
  40. len(self.var_op_deps[name]["var_as_output_ops"]) == 0
  41. and len(self.var_op_deps[name]["var_as_input_ops"]) > 0
  42. ):
  43. if self.block.var(name).persistable:
  44. continue
  45. input_names.append(name)
  46. for op in self.ops:
  47. if op.desc.type() == "read":
  48. input_names.extend(op.desc.output_arg_names())
  49. return input_names
  50. def get_reserved_vars(self):
  51. var_name = []
  52. for op in self.ops:
  53. if op.desc.type() == "seed":
  54. var_name.extend(op.desc.output_arg_names())
  55. return var_name
  56. def get_out_of_subgraph_vars(self, begin_op_idx, end_op_idx):
  57. var_name = []
  58. for i in range(begin_op_idx, end_op_idx, 1):
  59. for name in self.ops[i].desc.output_arg_names():
  60. if name in self.var_op_deps:
  61. for idx in self.var_op_deps[name]["var_as_input_ops"]:
  62. if idx >= end_op_idx:
  63. var_name.append(name)
  64. for name in self.ops[i].desc.input_arg_names():
  65. if name in self.var_op_deps:
  66. for idx in self.var_op_deps[name]["var_as_output_ops"]:
  67. if idx < begin_op_idx:
  68. var_name.append(name)
  69. return var_name
  70. def is_subgraph(self, var_group1, var_group2):
  71. # should traverse from var_group1 to var_group2
  72. # max op idx in var_group2
  73. # min op idx in var_group1
  74. min_op_idx = len(self.ops)
  75. max_op_idx = -1
  76. for name in var_group1:
  77. if name not in self.var_op_deps:
  78. return False, min_op_idx, max_op_idx
  79. for name in var_group2:
  80. if name not in self.var_op_deps:
  81. return False, min_op_idx, max_op_idx
  82. for name in var_group1:
  83. op_idx = self.var_op_deps[name]["var_as_input_ops"]
  84. for idx in op_idx:
  85. min_op_idx = min(min_op_idx, idx)
  86. for name in var_group2:
  87. op_idx = self.var_op_deps[name]["var_as_output_ops"]
  88. for idx in op_idx:
  89. max_op_idx = max(max_op_idx, idx)
  90. if min_op_idx >= max_op_idx:
  91. return False, min_op_idx, max_op_idx
  92. return True, min_op_idx, max_op_idx
  93. def _update_segment_start(self, min_idx, pre_segment_end_idx):
  94. """
  95. persist vars of amp-related cast should be included in recompute segment
  96. """
  97. def is_amp_cast(op):
  98. return (
  99. op.desc.type() == 'cast'
  100. and self.block.var(op.desc.input_arg_names()[0]).persistable
  101. )
  102. idx_ = min_idx - 1
  103. updated_min_idx = min_idx
  104. while idx_ > pre_segment_end_idx:
  105. if is_amp_cast(self.ops[idx_]):
  106. _logger.info(
  107. f"found amp-cast op: {self.ops[idx_].desc.type()}, : {self.ops[idx_].desc.input_arg_names()[0]}"
  108. )
  109. updated_min_idx = idx_
  110. idx_ -= 1
  111. else:
  112. break
  113. return updated_min_idx
  114. def build_stats(self):
  115. for i, op in enumerate(self.ops):
  116. self.op_deps[i] = {"in_ops": [], "out_ops": []}
  117. for j, name in enumerate(op.desc.input_arg_names()):
  118. if name in self.var_op_deps:
  119. self.op_deps[i]["in_ops"].extend(
  120. self.var_op_deps[name]["var_as_output_ops"]
  121. )
  122. for j, name in enumerate(op.desc.input_arg_names()):
  123. if name in self.var_op_deps:
  124. self.var_op_deps[name]["var_as_input_ops"].extend([i])
  125. else:
  126. self.var_op_deps[name] = {}
  127. self.var_op_deps[name]["var_as_input_ops"] = [i]
  128. self.var_op_deps[name]["var_as_output_ops"] = []
  129. for j, name in enumerate(op.desc.output_arg_names()):
  130. if name in self.var_op_deps:
  131. self.var_op_deps[name]["var_as_output_ops"].extend([i])
  132. else:
  133. self.var_op_deps[name] = {}
  134. self.var_op_deps[name]["var_as_input_ops"] = []
  135. self.var_op_deps[name]["var_as_output_ops"] = [i]
  136. for op_idx in self.op_deps[i]["in_ops"]:
  137. self.op_deps[op_idx]["out_ops"].extend([i])
  138. def sort_checkpoints(self, checkpoints_name):
  139. sorted_checkpoints = []
  140. for name in checkpoints_name:
  141. if name not in self.var_op_deps:
  142. _logger.info(
  143. "Recompute Optimizer: deleted %s from checkpoints, because it is not used in paddle program."
  144. % name
  145. )
  146. elif self.var_op_deps[name]["var_as_output_ops"] == []:
  147. # input nodes
  148. sorted_checkpoints.append((name, -1))
  149. else:
  150. sorted_checkpoints.append(
  151. (name, max(self.var_op_deps[name]["var_as_output_ops"]))
  152. )
  153. sorted_checkpoints = sorted(sorted_checkpoints, key=lambda x: x[1])
  154. return [x[0] for x in sorted_checkpoints]
  155. def modify_forward_desc_for_recompute(self):
  156. op_types = [op.desc.type() for op in self.ops]
  157. if "dropout" not in op_types:
  158. return
  159. op_idx = 0
  160. while op_idx < len(self.ops):
  161. op = self.ops[op_idx]
  162. if op.desc.type() != "dropout":
  163. op_idx += 1
  164. continue
  165. # already insert seed op before dropout
  166. if op.input('Seed') is not None and len(op.input('Seed')) == 1:
  167. op_idx += 1
  168. continue
  169. # add a seed op so that the two dropout op can generate same output
  170. op_unique_name = unique_name.generate("seed")
  171. var_unique_name = unique_name.generate_with_ignorable_key(
  172. ".".join([op_unique_name, 'tmp'])
  173. )
  174. added_var = self.block.create_var(
  175. name=var_unique_name,
  176. dtype='int32',
  177. type=core.VarDesc.VarType.LOD_TENSOR,
  178. persistable=False,
  179. stop_gradient=False,
  180. )
  181. seed = 0 if op.attr("fix_seed") is False else int(op.attr("seed"))
  182. op_device_attr_name = (
  183. core.op_proto_and_checker_maker.kOpDeviceAttrName()
  184. )
  185. op_device = ""
  186. if op.desc.has_attr(op_device_attr_name):
  187. op_device = op.desc.attr(op_device_attr_name)
  188. # Setting the force_cpu of seed to true will make the output of seed in cpu memory,
  189. # reduce the synchronous copy from GPU to CPU in dropout, and reduce the communication hang
  190. added_op = self.block._insert_op(
  191. index=op.idx,
  192. type='seed',
  193. inputs={},
  194. outputs={'Out': [added_var]},
  195. attrs={'seed': seed, 'op_device': op_device, 'force_cpu': True},
  196. )
  197. self.ops.insert(op_idx, added_op)
  198. # modify dropout op desc so that it accept a seed var as input
  199. op.desc.set_input("Seed", [var_unique_name])
  200. op.desc.remove_attr("fix_seed")
  201. op.desc.remove_attr("seed")
  202. self.block._sync_with_cpp()
  203. op_idx += 2
  204. def _pretty_op_desc_(op_desc, prefix):
  205. out_s = "{}\tname:[{}]\n{} \tinputs:[{}]\n{} \toutputs:[{}]".format(
  206. prefix + "_op",
  207. str(op_desc.type()),
  208. prefix + "_input",
  209. " ".join(op_desc.input_arg_names()),
  210. prefix + "_output",
  211. " ".join(op_desc.output_arg_names()),
  212. )
  213. return out_s
  214. def _add_needed_descs_to_block(
  215. descs, block, main_block, in_memory_vars, grad_op_id_to_fwd_op=None
  216. ):
  217. if len(descs) == 0:
  218. return []
  219. result_descs = []
  220. op_role_attr_name = core.op_proto_and_checker_maker.kOpRoleAttrName()
  221. backward = core.op_proto_and_checker_maker.OpRole.Backward
  222. for desc in descs:
  223. origin_desc = desc
  224. origin_is_operator = False
  225. if isinstance(desc, framework.Operator):
  226. desc = desc.desc
  227. origin_is_operator = True
  228. if isinstance(desc, tuple):
  229. desc = desc[0]
  230. is_needed = False
  231. for name in desc.output_arg_names():
  232. if main_block.has_var(name) and main_block.var(name).persistable:
  233. continue
  234. if name not in in_memory_vars:
  235. is_needed = True
  236. if is_needed:
  237. if origin_is_operator and grad_op_id_to_fwd_op is not None:
  238. grad_op_id_to_fwd_op[desc.original_id()] = origin_desc
  239. new_op_desc = block.desc.append_op()
  240. new_op_desc.copy_from(desc)
  241. new_op_desc._set_attr(op_role_attr_name, backward)
  242. if desc.has_attr('op_device'):
  243. new_op_desc._set_attr('op_device', desc.attr('op_device'))
  244. result_descs.append(new_op_desc)
  245. return result_descs
  246. def _add_descs_to_block(descs, block, grad_op_id_to_fwd_op=None):
  247. if len(descs) == 0:
  248. return []
  249. result_descs = []
  250. op_role_attr_name = core.op_proto_and_checker_maker.kOpRoleAttrName()
  251. backward = core.op_proto_and_checker_maker.OpRole.Backward
  252. for desc in descs:
  253. if isinstance(desc, framework.Operator):
  254. # for recompute, should record recompute ops
  255. if grad_op_id_to_fwd_op is not None:
  256. grad_op_id_to_fwd_op[desc.desc.original_id()] = desc
  257. desc = desc.desc
  258. if isinstance(desc, tuple):
  259. desc = desc[0]
  260. new_op_desc = block.desc.append_op()
  261. new_op_desc.copy_from(desc)
  262. new_op_desc._set_attr(op_role_attr_name, backward)
  263. if desc.has_attr('op_device'):
  264. new_op_desc._set_attr('op_device', desc.attr('op_device'))
  265. result_descs.append(new_op_desc)
  266. return result_descs
  267. def _find_loss_op_(loss):
  268. for op in reversed(loss.block.ops):
  269. assert isinstance(op, framework.Operator)
  270. if (
  271. len(op.output_arg_names) == 1
  272. and op.output_arg_names[0] == loss.name
  273. ):
  274. loss.op = op
  275. break
  276. if loss.op is None:
  277. raise ValueError("loss.op is None. Should not happen")
  278. def _rename_arg_(op_descs, old_name, new_name, begin_idx=None, end_idx=None):
  279. """
  280. Traverse all ops in op_descs[begin_idx : end_idx],
  281. if any op has inputs/outputs named "old_name", rename it as 'new_name'
  282. """
  283. if begin_idx is None:
  284. begin_idx = 0
  285. if end_idx is None:
  286. end_idx = len(op_descs)
  287. if isinstance(op_descs, (list, tuple)):
  288. for i in range(begin_idx, end_idx):
  289. op_desc = op_descs[i]
  290. if isinstance(op_desc, tuple):
  291. op_desc = op_desc[0]
  292. op_desc._rename_input(old_name, new_name)
  293. op_desc._rename_output(old_name, new_name)
  294. if isinstance(op_descs, collections.OrderedDict):
  295. for key, value in op_descs.items():
  296. if isinstance(value, (list, tuple)):
  297. for op_desc in value:
  298. op_desc._rename_input(old_name, new_name)
  299. op_desc._rename_output(old_name, new_name)
  300. def _create_op_desc_(op_type, inputs, outputs, attrs):
  301. """
  302. Create a C++ OpDesc object with specified inputs, outputs and attributes.
  303. """
  304. op_desc = core.OpDesc()
  305. op_desc.set_type(op_type)
  306. for para, args in inputs.items():
  307. op_desc.set_input(
  308. para,
  309. [arg.decode() if isinstance(arg, bytes) else arg for arg in args],
  310. )
  311. for para, args in outputs.items():
  312. op_desc.set_output(
  313. para,
  314. [arg.decode() if isinstance(arg, bytes) else arg for arg in args],
  315. )
  316. op_role_attr_name = core.op_proto_and_checker_maker.kOpRoleAttrName()
  317. op_device_attr_name = core.op_proto_and_checker_maker.kOpDeviceAttrName()
  318. if op_role_attr_name not in attrs:
  319. attrs[
  320. op_role_attr_name
  321. ] = core.op_proto_and_checker_maker.OpRole.Backward
  322. if op_device_attr_name not in attrs:
  323. attrs[op_device_attr_name] = ""
  324. for name, val in attrs.items():
  325. if isinstance(val, framework.Block):
  326. op_desc.set_block_attr(name, val.desc)
  327. else:
  328. op_desc._set_attr(name, val)
  329. return op_desc
  330. def _create_loss_op_desc_(loss):
  331. # 0-D Tensor or 0-Size Tensor
  332. if len(loss.shape) == 0 or 0 in loss.shape:
  333. create_shape = loss.shape
  334. else:
  335. create_shape = [1]
  336. op_desc = _create_op_desc_(
  337. "fill_constant",
  338. {},
  339. {"Out": [_append_grad_suffix_(loss.name)]},
  340. {
  341. "shape": create_shape,
  342. "value": 1.0,
  343. "dtype": loss.dtype,
  344. "force_cpu": False,
  345. core.op_proto_and_checker_maker.kOpRoleAttrName(): int(
  346. core.op_proto_and_checker_maker.OpRole.Backward
  347. )
  348. | int(core.op_proto_and_checker_maker.OpRole.Loss),
  349. core.op_proto_and_checker_maker.kOpDeviceAttrName(): loss.op.attr(
  350. core.op_proto_and_checker_maker.kOpDeviceAttrName()
  351. ),
  352. },
  353. )
  354. return op_desc
  355. def _infer_var_data_type_shape_(grad_var_name, block):
  356. """
  357. Infer the data type and shape of given grad variable
  358. """
  359. grad_var = block.desc.find_var(grad_var_name.encode())
  360. fwd_name = _strip_grad_suffix_(grad_var_name)
  361. if block.desc.has_var_recursive(fwd_name.encode()):
  362. fwd_var = block.desc.find_var_recursive(fwd_name.encode())
  363. grad_var.set_dtype(fwd_var.dtype())
  364. grad_var.set_shape(fwd_var.shape())
  365. else:
  366. # TODO(jiabin): Maybe we should not to this to cause some unexpected error on dtype
  367. warnings.warn(
  368. f"Set grad var: {grad_var_name} dtype to default FP32, since we can't find its related forward var"
  369. )
  370. grad_var.set_dtype(core.VarDesc.VarType.FP32)
  371. def _all_in_set_(cands, s):
  372. """
  373. Test if all elements of 'cands' are in set 's'
  374. """
  375. if len(cands) == 0:
  376. return False
  377. for c in cands:
  378. if c not in s:
  379. return False
  380. return True
  381. def _some_in_set_(cands, s):
  382. """
  383. Test if some elements of 'cands' are in set 's'
  384. """
  385. if len(cands) == 0:
  386. return False
  387. for c in cands:
  388. if c in s:
  389. return True
  390. return False
  391. def _strip_grad_suffix_(name):
  392. """
  393. Strip the grad suffix from the given variable name
  394. e.g. x@GRAD ==> x
  395. x@GRAD@GRAD ==> x
  396. y@GRAD@RENAME@1 ==> y
  397. z@GRAD_slice_0@GRAD ==> z@GRAD_slice_0
  398. grad/grad/z@GRAD@RENAME@block0@1@GRAD ==> z
  399. """
  400. pos = re.search(f'{core.grad_var_suffix()}+@', name) or re.search(
  401. f'{core.grad_var_suffix()}$', name
  402. )
  403. new_name = name[: pos.start()] if pos is not None else name
  404. new_pos = name.rfind('grad/')
  405. return new_name[new_pos + 5 :] if new_pos != -1 else new_name
  406. def _append_grad_suffix_(name):
  407. """
  408. Append grad suffix to the given variable name
  409. e.g. x ==> x@GRAD
  410. """
  411. return name + core.grad_var_suffix()
  412. def _accumulate_gradients_by_sum_op_(
  413. var_name, renamed_vars, pending_sum_ops, op_idx, op_device=""
  414. ):
  415. """
  416. Use sum op to accumulate_gradients, the gradients are stored in renamed_vars.
  417. """
  418. if op_idx not in pending_sum_ops.keys():
  419. pending_sum_ops[op_idx] = []
  420. pending_sum_ops[op_idx].append(
  421. _create_op_desc_(
  422. "sum",
  423. {"X": renamed_vars[var_name]},
  424. {"Out": [var_name]},
  425. {"op_device": op_device},
  426. )
  427. )
  428. renamed_vars[var_name] = [var_name]
  429. def _accumulate_gradients_by_add_ops_(
  430. var_name,
  431. renamed_vars,
  432. pending_sum_ops,
  433. op_idx,
  434. op_device="",
  435. grad_var_to_var=None,
  436. ):
  437. """
  438. Use several inplace add op to accumulate_gradients, the gradients are stored in renamed_vars.
  439. """
  440. if op_idx not in pending_sum_ops.keys():
  441. pending_sum_ops[op_idx] = []
  442. out_name = renamed_vars[var_name][0]
  443. for i in range(1, len(renamed_vars[var_name])):
  444. x_name = out_name
  445. y_name = renamed_vars[var_name][i]
  446. if i != len(renamed_vars[var_name]) - 1:
  447. out_name = var_name + '@ADD@' + str(i)
  448. else:
  449. out_name = var_name
  450. pending_sum_ops[op_idx].append(
  451. _create_op_desc_(
  452. "grad_add",
  453. {"X": [x_name], "Y": [y_name]},
  454. {"Out": [out_name]},
  455. {"op_device": op_device},
  456. )
  457. )
  458. # record mapping between out grad var name and fwd var name (only for auto parallel)
  459. if grad_var_to_var is not None:
  460. if var_name in grad_var_to_var:
  461. grad_var_to_var[out_name] = grad_var_to_var[var_name]
  462. else:
  463. grad_var_to_var[out_name] = var_name
  464. renamed_vars[var_name] = [var_name]
  465. def _addup_repetitive_outputs_(
  466. op_descs,
  467. block_idx,
  468. grad_var_to_var=None,
  469. grad_op_id_to_fwd_op=None,
  470. topo_order_for_backward=None,
  471. ):
  472. """
  473. In backward part, an variable may be the output of more than one ops.
  474. And one op may yield its multiple outputs to the same variable.
  475. In these cases, the variable should be the accumulation of all the outputs.
  476. `sum_op`s are added to implement the accumulate.
  477. Args:
  478. grad_var_to_var(dict): used to build the mapping between grad var name and forward var name.
  479. Only for auto parallel.
  480. """
  481. _MAX_ADD_NUM_ = framework._global_flags()['FLAGS_max_inplace_grad_add']
  482. topo_order_for_grad_name = {}
  483. # pending_sum_ops = []
  484. pending_sum_ops = collections.OrderedDict()
  485. var_rename_count = collections.defaultdict(int)
  486. renamed_vars = collections.defaultdict(list)
  487. renamed_var_start_idx = collections.defaultdict(list)
  488. var_device = collections.defaultdict(str)
  489. def _change_order_by_topo_order(var_name):
  490. if topo_order_for_backward is None:
  491. return
  492. origin_names = renamed_vars[var_name]
  493. origin_names.sort(key=lambda x: topo_order_for_grad_name[x])
  494. for idx, op_desc in enumerate(op_descs):
  495. op_device_attr_name = (
  496. core.op_proto_and_checker_maker.kOpDeviceAttrName()
  497. )
  498. op_device = ""
  499. if op_desc.has_attr(op_device_attr_name):
  500. op_device = op_desc.attr(op_device_attr_name)
  501. for var_name in op_desc.input_arg_names():
  502. if "@GRAD" not in var_name:
  503. continue
  504. if len(renamed_vars[var_name]) > 1:
  505. if len(renamed_vars[var_name]) > _MAX_ADD_NUM_:
  506. _change_order_by_topo_order(var_name)
  507. _accumulate_gradients_by_sum_op_(
  508. var_name,
  509. renamed_vars,
  510. pending_sum_ops,
  511. idx,
  512. var_device[var_name],
  513. )
  514. else:
  515. _change_order_by_topo_order(var_name)
  516. _accumulate_gradients_by_add_ops_(
  517. var_name,
  518. renamed_vars,
  519. pending_sum_ops,
  520. idx,
  521. var_device[var_name],
  522. grad_var_to_var,
  523. )
  524. for param_idx, param_name in enumerate(op_desc.output_names()):
  525. arg_names = op_desc.output(param_name)
  526. for arg_idx, var_name in enumerate(arg_names):
  527. if "@GRAD" not in var_name:
  528. continue
  529. # if "@RENAME@" in var_name:
  530. # continue
  531. if (
  532. var_name == core.empty_var_name()
  533. or var_name in op_desc.input_arg_names()
  534. ):
  535. # empty variable or inplace op
  536. continue
  537. if len(renamed_vars[var_name]) == 0:
  538. # it's the first time we get the variable
  539. renamed_vars[var_name] = [var_name]
  540. renamed_var_start_idx[var_name] = idx
  541. topo_order_for_grad_name[var_name] = (
  542. topo_order_for_backward[op_desc]
  543. if topo_order_for_backward
  544. and op_desc in topo_order_for_backward
  545. else 1
  546. )
  547. else:
  548. if len(renamed_vars[var_name]) == 1:
  549. new_name = (
  550. var_name
  551. + "@RENAME@block"
  552. + str(block_idx)
  553. + "@"
  554. + str(var_rename_count[var_name])
  555. )
  556. var_rename_count[var_name] += 1
  557. # Build the mapping between the new_name and var_name (Only for auto parallel)
  558. if grad_var_to_var is not None:
  559. if var_name in grad_var_to_var:
  560. grad_var_to_var[new_name] = grad_var_to_var[
  561. var_name
  562. ]
  563. else:
  564. grad_var_to_var[new_name] = var_name
  565. # rename original var_name
  566. topo_order_for_grad_name[
  567. new_name
  568. ] = topo_order_for_grad_name[var_name]
  569. renamed_vars[var_name][0] = new_name
  570. # before change: _rename_arg_(op_descs, var_name,
  571. # new_name, 0, idx)
  572. # rename arg from idx of the first appearance
  573. # in backward, not always from 0
  574. _rename_arg_(
  575. op_descs,
  576. var_name,
  577. new_name,
  578. renamed_var_start_idx[var_name],
  579. idx,
  580. )
  581. _rename_arg_(pending_sum_ops, var_name, new_name)
  582. for p in op_desc.output_names()[:param_idx]:
  583. p_arg_names = op_desc.output(p)
  584. if var_name in p_arg_names:
  585. op_desc.set_output(
  586. p,
  587. [
  588. new_name if x == var_name else x
  589. for x in p_arg_names
  590. ],
  591. )
  592. arg_names = [
  593. new_name if x == var_name else x
  594. for x in arg_names[:arg_idx]
  595. ] + arg_names[arg_idx:]
  596. new_name = (
  597. var_name
  598. + "@RENAME@block"
  599. + str(block_idx)
  600. + "@"
  601. + str(var_rename_count[var_name])
  602. )
  603. var_rename_count[var_name] += 1
  604. # Build the mapping between the new_name and var_name (Only for auto parallel)
  605. if grad_var_to_var is not None:
  606. if var_name in grad_var_to_var:
  607. grad_var_to_var[new_name] = grad_var_to_var[
  608. var_name
  609. ]
  610. else:
  611. grad_var_to_var[new_name] = var_name
  612. arg_names[arg_idx] = new_name
  613. op_desc.set_output(param_name, arg_names)
  614. renamed_vars[var_name].append(new_name)
  615. # record the latest device
  616. var_device[var_name] = op_device
  617. topo_order_for_grad_name[new_name] = (
  618. topo_order_for_backward[op_desc]
  619. if topo_order_for_backward
  620. and op_desc in topo_order_for_backward
  621. else 1
  622. )
  623. for var_name, inputs in renamed_vars.items():
  624. if len(renamed_vars[var_name]) > 1:
  625. if len(renamed_vars[var_name]) > _MAX_ADD_NUM_:
  626. _change_order_by_topo_order(var_name)
  627. _accumulate_gradients_by_sum_op_(
  628. var_name,
  629. renamed_vars,
  630. pending_sum_ops,
  631. len(op_descs),
  632. var_device[var_name],
  633. )
  634. else:
  635. _change_order_by_topo_order(var_name)
  636. _accumulate_gradients_by_add_ops_(
  637. var_name,
  638. renamed_vars,
  639. pending_sum_ops,
  640. len(op_descs),
  641. var_device[var_name],
  642. )
  643. op_descs_len = len(op_descs)
  644. # sum_op descs are sorted according to their insert position
  645. for key, value in collections.OrderedDict(
  646. reversed(list(pending_sum_ops.items()))
  647. ).items():
  648. # NOTE(zhiqiu): Since reversed, the idx of op_descs to be inserted will remains correct.
  649. # For example, [0, 1, 2], and we want to insert 'a' at idx 1, 'b' at idx 2, and the expected result is [0, 1, 'a', 2, 'b'].
  650. # If reversed, we first insert 'b' at idx 2, it becomes [0, 1, 2, 'b'], and then insert 'a' at idx 1, it becomes [0, 1, 'a', 2, 'b'].
  651. # If not reverse, we first insert 'a' at idx 1, it becomes [0, 1, 'a', 2], and then insert 'b' at idx 2, it becomes [0, 1, 'a', 'b', 2].
  652. idx = key
  653. for i, op in enumerate(value):
  654. # update the mapping between fwd and bwd
  655. target_idx = idx - 1 if idx == op_descs_len else idx + i
  656. if (
  657. grad_op_id_to_fwd_op is not None
  658. and grad_op_id_to_fwd_op.get(
  659. op_descs[target_idx].original_id(), None
  660. )
  661. is not None
  662. ):
  663. grad_op_id_to_fwd_op[op.original_id()] = grad_op_id_to_fwd_op[
  664. op_descs[target_idx].original_id()
  665. ]
  666. op_descs.insert(idx + i, op)
  667. return op_descs
  668. def _remove_no_grad_branch_(
  669. op_descs, no_grad_set, grad_op_id_to_fwd_op=None, target_vars=[]
  670. ):
  671. """
  672. Remove unnecessary grad ops
  673. A grad op can be removed in two cases:
  674. 1. all outputs of the grad op are in 'no_grad_set'
  675. 2. all grad inputs of the grad op are in 'no_grad_set'
  676. NOTE: we will skip target_vars's grad name.
  677. """
  678. def _op_can_be_removed_(op_desc, no_grad_set):
  679. out_arg_names = op_desc.output_arg_names()
  680. if len(out_arg_names) == 0 or _all_in_set_(out_arg_names, no_grad_set):
  681. return True
  682. if _all_in_set_(
  683. [
  684. name
  685. for name in op_desc.input_arg_names()
  686. if name.find(core.grad_var_suffix()) != -1
  687. ],
  688. no_grad_set,
  689. ):
  690. no_grad_set.update(set(out_arg_names) - target_grad_var_names)
  691. return True
  692. return False
  693. # Remove ops whose outputs are all in no_grad_dict
  694. target_grad_var_names = {
  695. var.name + core.grad_var_suffix() for var in target_vars
  696. }
  697. op_descs = [
  698. op_desc
  699. for op_desc in op_descs
  700. if not _op_can_be_removed_(op_desc, no_grad_set)
  701. ]
  702. # Insert fill_any_like_op with value 0
  703. to_insert = []
  704. if not core._is_bwd_prim_enabled():
  705. for idx, op_desc in enumerate(op_descs):
  706. for arg in op_desc.input_arg_names():
  707. # arg is a gradient var name and arg should not have gradient
  708. if core.grad_var_suffix() in arg and arg in no_grad_set:
  709. x_in = _strip_grad_suffix_(arg)
  710. # the reason should be: arg can be input of another grad op
  711. # and the op is a not-to-remove op
  712. new_op_desc = _create_op_desc_(
  713. "fill_any_like",
  714. {"X": [x_in]},
  715. {"Out": [arg]},
  716. {'value': 0, 'dtype': -1},
  717. )
  718. # update the mapping between fwd and bwd
  719. if (
  720. grad_op_id_to_fwd_op is not None
  721. and grad_op_id_to_fwd_op.get(
  722. op_desc.original_id(), None
  723. )
  724. is not None
  725. ):
  726. grad_op_id_to_fwd_op[
  727. new_op_desc.original_id()
  728. ] = grad_op_id_to_fwd_op[op_desc.original_id()]
  729. to_insert.append((new_op_desc, idx))
  730. [op_descs.insert(p[1], p[0]) for p in reversed(to_insert)]
  731. return op_descs
  732. def _find_not_need_ops(grad_op_descs, forward_ops, input_grad_names_set):
  733. """
  734. Pruning Program with Structural Analysis Method of Computational Graph.
  735. The nodes of the computational graph composed of backward OPS should be
  736. interconnected. If there are unconnected sub-graphs in the computational graph,
  737. these sub-graphs should be cut off.
  738. Args:
  739. grad_op_descs(list[core.OpDesc]): The candidate backward OpDescs.
  740. forward_ops(list[Operator]): The forward ops.
  741. input_grad_names_set(set): this set is used to store the gradients' name
  742. which is generated by backward ops, and input_grad_names_set can help
  743. to prune the unnecessary backward ops.
  744. Return:
  745. (set[core.OpDesc]): A set of OpDescs which should be pruned.
  746. """
  747. class Var:
  748. def __init__(self, var_name):
  749. self.var_name = var_name
  750. self.gen_op = None
  751. self.pending_ops = []
  752. def set_gen_op(self, gen_op):
  753. assert isinstance(gen_op, Op)
  754. assert self.gen_op is None
  755. self.gen_op = gen_op
  756. def add_pending_op(self, op):
  757. assert isinstance(op, Op)
  758. self.pending_ops.append(op)
  759. class Op:
  760. def __init__(self, op_desc):
  761. self.op_desc = op_desc
  762. self.inputs = []
  763. self.outputs = []
  764. def insert_input(self, var):
  765. assert isinstance(var, Var)
  766. self.inputs.append(var)
  767. def insert_output(self, var):
  768. assert isinstance(var, Var)
  769. self.outputs.append(var)
  770. var_versions = {}
  771. def _create_node(name):
  772. if name not in var_versions.keys():
  773. var_versions[name] = [Var(name)]
  774. else:
  775. var_versions[name].append(Var(name))
  776. return var_versions[name][-1]
  777. def _create_or_get_last_version_node(name):
  778. if name not in var_versions.keys():
  779. var_versions[name] = [Var(name)]
  780. return var_versions[name][-1]
  781. def _create_op_node(op_desc):
  782. op_node = Op(op_desc)
  783. for input in op_desc.input_arg_names():
  784. var = _create_or_get_last_version_node(name=input)
  785. var.add_pending_op(op_node)
  786. op_node.insert_input(var)
  787. for output in op_desc.output_arg_names():
  788. var = _create_node(name=output)
  789. var.set_gen_op(op_node)
  790. op_node.insert_output(var)
  791. return op_node
  792. # Record the forward vars
  793. forward_vars_set = (
  794. set() if input_grad_names_set is None else set(input_grad_names_set)
  795. )
  796. for op in forward_ops:
  797. forward_vars_set.update(op.desc.input_arg_names())
  798. forward_vars_set.update(op.desc.output_arg_names())
  799. # Record the vars which are created during backward and is not generated by op.
  800. backward_vars_set = set()
  801. # special_op_nodes is the candidate sub-graph head node.
  802. special_op_nodes = set()
  803. for op_desc in grad_op_descs:
  804. input_set = set(op_desc.input_arg_names())
  805. # The new_vars are created during backward and is not generated by op.
  806. new_vars = input_set - forward_vars_set - backward_vars_set
  807. backward_vars_set.update(op_desc.output_arg_names())
  808. op_node = _create_op_node(op_desc)
  809. if len(new_vars) == len(input_set):
  810. special_op_nodes.add(op_node)
  811. not_need_op_descs = []
  812. # Start traversing all candidate sub-graph headers to check whether
  813. # they are connected to backward computational graphs, and if they are
  814. # not, list them in not_need_op_descs
  815. for special_op_node in special_op_nodes:
  816. op_list = [special_op_node]
  817. ready_vars = set(special_op_node.inputs)
  818. remove_ops = True
  819. candidate_ops = [special_op_node]
  820. while len(candidate_ops) > 0:
  821. op_node = candidate_ops.pop(0)
  822. if _all_in_set_(op_node.inputs, ready_vars):
  823. for out_var in op_node.outputs:
  824. candidate_ops.extend(out_var.pending_ops)
  825. op_list.extend(out_var.pending_ops)
  826. ready_vars.update(op_node.outputs)
  827. else:
  828. remove_ops = False
  829. break
  830. if remove_ops:
  831. not_need_op_descs.extend([node.op_desc for node in op_list])
  832. not_need_op_descs_set = set(not_need_op_descs)
  833. grad_op_descs_set = set(grad_op_descs)
  834. # If a backward computational graph is simply one sub-graph header, the
  835. # not_need_op_descs will be whole graph, this IF clause avoids it.
  836. if grad_op_descs_set == not_need_op_descs_set:
  837. return set()
  838. return not_need_op_descs_set
  839. def serialize_op_decs(op_desc):
  840. protostr = op_desc.serialize_to_string()
  841. proto = framework_pb2.OpDesc.FromString(bytes(protostr))
  842. return proto.__str__()
  843. def _append_backward_ops_with_checkpoints_(
  844. block,
  845. ops,
  846. target_vars,
  847. target_block,
  848. no_grad_dict,
  849. grad_to_var,
  850. checkpoints,
  851. grad_op_id_to_fwd_op=None,
  852. ):
  853. """
  854. Create grad ops with forward ops, and insert them into given block
  855. Args:
  856. block(Block): the block where forward ops are
  857. ops(Op): the forward operators whose forward recomputation backward ops need to be added
  858. target_vars(list[Tensor]): the loss vars we want to calculate gradient.
  859. target_block(Block): the block which is going to hold new generated grad ops
  860. no_grad_dict(dict):
  861. key(int) block index
  862. val(str): corresponding forward variable name
  863. checkpoints: variables that a user defined as checkpoint for forward recomputation
  864. Algorithms:
  865. 0) deal with forward recomputing program descs
  866. 1) find ops between checkpoints, i.e. recompute_segments
  867. 2) go through all forward ops and induct all variables that will be hold in memory
  868. a. variables that are used across segments will be held in memory
  869. b. output of dropout op will be held in memory
  870. c. input variables will be held in memory
  871. 3) go through each recompute_segments, add backward ops with forward recomputation
  872. a. add ops in current recompute_segment as forward recomputation ops
  873. b. rename all non-checkpoint variables in recomputation ops
  874. c. add backward ops of current recomputation ops
  875. d. add sum op for repetitive_outputs
  876. 4) remove no grad branch as it is in _remove_no_grad_branch_
  877. 5) Note1: all appended ops' OpRole are Backward
  878. 6) Note2: all variables with new name should be returned so that _append_backward_vars_ can be called
  879. 7) Note3: current forward recomputation backpropagation does not handle programs with subblock
  880. """
  881. checkpoints_name = [x.name for x in checkpoints]
  882. checkpoints_name = list(set(checkpoints_name))
  883. local_block = block.program._create_block()
  884. buffer_block = block.program._create_block()
  885. # 0) deal with forward recomputing program descs
  886. program_stat = ProgramStats(block, ops)
  887. program_stat.modify_forward_desc_for_recompute()
  888. program_stat.build_stats()
  889. # 1) find ops between checkpoints, i.e. recompute_segments
  890. checkpoints_name = program_stat.sort_checkpoints(checkpoints_name)
  891. segments = []
  892. if len(checkpoints_name) == 1:
  893. # only one checkpoint
  894. max_op_idx = -1
  895. var_group = [checkpoints_name[0]]
  896. for name in var_group:
  897. if name not in program_stat.var_op_deps:
  898. break
  899. op_idx = program_stat.var_op_deps[name]["var_as_output_ops"]
  900. # only count the last generate op
  901. for idx in op_idx:
  902. max_op_idx = max(max_op_idx, idx)
  903. if max_op_idx > 0:
  904. segments.append([0, max_op_idx + 1])
  905. else:
  906. start_idx = 0
  907. pre_segment_end_idx = -1
  908. while True:
  909. if start_idx >= len(checkpoints_name) - 1:
  910. break
  911. # min_idx: checkpoint_1' s input op
  912. # max_idx: checkpoint_2' s output op
  913. flag, min_idx, max_idx = program_stat.is_subgraph(
  914. [checkpoints_name[start_idx]], [checkpoints_name[start_idx + 1]]
  915. )
  916. if flag:
  917. # max_idx + 1 since the exact and used segment end idx is max_idx
  918. min_idx = program_stat._update_segment_start(
  919. min_idx, pre_segment_end_idx
  920. )
  921. segments.append([min_idx, max_idx + 1])
  922. else:
  923. _logger.info(
  924. f"Could not recompute op range [{min_idx}] - [{max_idx + 1}] "
  925. )
  926. start_idx += 1
  927. if segments != [] and segments[0][0] != 0:
  928. recompute_segments = [[0, segments[0][0]]] + segments
  929. else:
  930. recompute_segments = segments
  931. for i, (idx1, idx2) in enumerate(recompute_segments):
  932. _logger.info(f"recompute segment[{i}]")
  933. _logger.info(
  934. f"segment start op: [{ops[idx1].desc.type()}]: [{ops[idx1].desc.input_arg_names()}]"
  935. )
  936. _logger.info(
  937. f"segment end op: [{ops[idx2 - 1].desc.type()}]: [{ops[idx2 - 1].desc.input_arg_names()}]"
  938. )
  939. _logger.info(f"recompute segment[{i}]")
  940. _logger.info(
  941. f"segment start op: [{ops[idx1].desc.type()}]: [{ops[idx1].desc.input_arg_names()}]"
  942. )
  943. _logger.info(
  944. f"segment end op: [{ops[idx2 - 1].desc.type()}]: [{ops[idx2 - 1].desc.input_arg_names()}]"
  945. )
  946. # 2) go through all forward ops and induct all variables that will be hold in memory
  947. vars_should_be_hold = []
  948. # a. variables that are used across segments will be held in memory
  949. for segment in recompute_segments:
  950. vars_should_be_hold.extend(
  951. program_stat.get_out_of_subgraph_vars(segment[0], segment[1])
  952. )
  953. cross_vars = set(vars_should_be_hold) - set(checkpoints_name)
  954. _logger.info(
  955. f"found [{len(cross_vars)}] vars which cross recompute segment: [{cross_vars}], better checkpoints might be set to reduce those vars"
  956. )
  957. # b. output of seed op should be kept in memory
  958. vars_should_be_hold.extend(program_stat.get_reserved_vars())
  959. # c. input variables are checkpoints
  960. vars_should_be_hold.extend(program_stat.get_input_nodes())
  961. vars_should_be_hold = list(set(vars_should_be_hold))
  962. # 3) go through each recompute_segments, add backward ops with forward recomputation
  963. grad_op_descs = []
  964. var_name_dict = {}
  965. vars_in_memory = vars_should_be_hold + checkpoints_name
  966. max_calculated_op_position = len(ops)
  967. device_attr_name = core.op_proto_and_checker_maker.kOpDeviceAttrName()
  968. if recompute_segments == []:
  969. gap_ops = ops[0:max_calculated_op_position]
  970. for op in reversed(gap_ops):
  971. if op.has_attr("sub_block"):
  972. raise Exception(
  973. "Recompute don't support ops with sub_block"
  974. "invoke op: %s"
  975. % _pretty_op_desc_(op.desc, "with_sub_block")
  976. )
  977. grad_op_desc, op_grad_to_var = core.get_grad_op_desc(
  978. op.desc, no_grad_dict[block.idx], []
  979. )
  980. # record the mapping between fwd and bwd
  981. if grad_op_id_to_fwd_op is not None:
  982. for op_desc in grad_op_desc:
  983. grad_op_id_to_fwd_op[op_desc.original_id()] = op
  984. # Set device for grad_op according to forward Op
  985. if op.desc.has_attr(device_attr_name):
  986. op_device = op.desc.attr(device_attr_name)
  987. for op_desc in grad_op_desc:
  988. op_desc._set_attr(device_attr_name, op_device)
  989. added_descs = _add_descs_to_block(
  990. grad_op_desc, local_block, grad_op_id_to_fwd_op
  991. )
  992. grad_op_descs.extend(added_descs)
  993. grad_to_var.update(op_grad_to_var)
  994. for i, segment in enumerate(recompute_segments[::-1]):
  995. gap_ops = ops[segment[1] : max_calculated_op_position]
  996. max_calculated_op_position = segment[0]
  997. for op in reversed(gap_ops):
  998. if op.has_attr("sub_block"):
  999. raise Exception(
  1000. "Recompute don't support ops with sub_block"
  1001. "invoke op: %s"
  1002. % _pretty_op_desc_(op.desc, "with_sub_block")
  1003. )
  1004. grad_op_desc, op_grad_to_var = core.get_grad_op_desc(
  1005. op.desc, no_grad_dict[block.idx], []
  1006. )
  1007. # record the mapping between fwd and bwd
  1008. if grad_op_id_to_fwd_op is not None:
  1009. for op_desc in grad_op_desc:
  1010. grad_op_id_to_fwd_op[op_desc.original_id()] = op
  1011. # Set device for grad_op according to forward Op
  1012. if op.desc.has_attr(device_attr_name):
  1013. op_device = op.desc.attr(device_attr_name)
  1014. for op_desc in grad_op_desc:
  1015. op_desc._set_attr(device_attr_name, op_device)
  1016. added_descs = _add_descs_to_block(
  1017. grad_op_desc, local_block, grad_op_id_to_fwd_op
  1018. )
  1019. grad_op_descs.extend(added_descs)
  1020. grad_to_var.update(op_grad_to_var)
  1021. ff_ops = ops[segment[0] : segment[1]]
  1022. var_suffix = ".subprog_%d" % i
  1023. for op in ff_ops:
  1024. if op.has_attr("sub_block"):
  1025. raise Exception(
  1026. "Recompute don't support ops with sub_block"
  1027. "invoke op: %s"
  1028. % _pretty_op_desc_(op.desc, "with_sub_block")
  1029. )
  1030. input_and_output_names = []
  1031. input_and_output_names.extend(op.desc.input_arg_names())
  1032. input_and_output_names.extend(op.desc.output_arg_names())
  1033. for name in input_and_output_names:
  1034. if block.var(name).persistable or name in checkpoints_name:
  1035. continue
  1036. if name in vars_should_be_hold:
  1037. continue
  1038. if name not in var_name_dict:
  1039. var_name_dict[name] = name + var_suffix
  1040. # we should create the rename var in subprog, otherwise its VarType will be BOOL
  1041. ref_var = block.program.global_block().var(name)
  1042. block.create_var(
  1043. name=var_name_dict[name],
  1044. shape=ref_var.shape,
  1045. dtype=ref_var.dtype,
  1046. type=ref_var.type,
  1047. persistable=ref_var.persistable,
  1048. stop_gradient=ref_var.stop_gradient,
  1049. )
  1050. # 3.a. add ops in current recompute_segment as forward recomputation ops
  1051. buffer_descs = _add_needed_descs_to_block(
  1052. ff_ops, buffer_block, block, vars_in_memory, grad_op_id_to_fwd_op
  1053. )
  1054. added_descs = _add_descs_to_block(
  1055. ff_ops, local_block, grad_op_id_to_fwd_op
  1056. )
  1057. # 3.b. rename all non-checkpoint variables in recomputation ops
  1058. for key in var_name_dict:
  1059. _rename_arg_(buffer_descs, key, var_name_dict[key])
  1060. # added_descs should be in grad_op_descs because it is backward op desc
  1061. grad_op_descs.extend(buffer_descs)
  1062. # 3.c. add backward ops for all ops in current segment
  1063. for op_desc in reversed(added_descs):
  1064. grad_op_desc, op_grad_to_var = core.get_grad_op_desc(
  1065. op_desc, no_grad_dict[block.idx], []
  1066. )
  1067. # record the mapping between fwd and bwd
  1068. if grad_op_id_to_fwd_op is not None:
  1069. for g_op_desc in grad_op_desc:
  1070. grad_op_id_to_fwd_op[
  1071. g_op_desc.original_id()
  1072. ] = grad_op_id_to_fwd_op[op_desc.original_id()]
  1073. # Set device for grad_op according to forward Op
  1074. if op_desc.has_attr(device_attr_name):
  1075. op_device = op_desc.attr(device_attr_name)
  1076. for g_op_desc in grad_op_desc:
  1077. g_op_desc._set_attr(device_attr_name, op_device)
  1078. for key in var_name_dict:
  1079. _rename_arg_(grad_op_desc, key, var_name_dict[key])
  1080. grad_op_descs.extend(grad_op_desc)
  1081. grad_to_var.update(op_grad_to_var)
  1082. # 3.d. add sum op for repetitive_outputs
  1083. grad_op_descs = _addup_repetitive_outputs_(
  1084. grad_op_descs, block.idx, grad_op_id_to_fwd_op=grad_op_id_to_fwd_op
  1085. )
  1086. # 4) remove no grad branch as it is in _remove_no_grad_branch_
  1087. grad_op_descs = _remove_no_grad_branch_(
  1088. grad_op_descs,
  1089. no_grad_dict[block.idx],
  1090. grad_op_id_to_fwd_op,
  1091. target_vars,
  1092. )
  1093. added_descs = _add_descs_to_block(
  1094. grad_op_descs, target_block, grad_op_id_to_fwd_op
  1095. )
  1096. return (
  1097. program_stat,
  1098. checkpoints_name,
  1099. vars_should_be_hold,
  1100. recompute_segments,
  1101. )
  1102. def _get_sub_block_path(
  1103. sub_block,
  1104. sub_block_op_desc,
  1105. no_grad_set,
  1106. op_path_dict,
  1107. sub_block_target_names=None,
  1108. ):
  1109. """
  1110. Get output vars in subblock which will be assigned to parent block.
  1111. It is used to find the grad path in subblock.
  1112. Args:
  1113. sub_block(Block): The sub-block in which to get op path.
  1114. sub_block_op_desc: The op desc of the sub-block op such as 'while', 'conditional_block' and 'recurrent'.
  1115. no_grad_set(set): The set of no grad var name. no_grad_set will be changed.
  1116. op_path_dict(dict): op_path_dict will be changed.
  1117. key(int) block index
  1118. val(list) the op path of block(index)
  1119. sub_block_target_names(set): Target var names of sub-block.
  1120. Return:
  1121. The forward op path of sub-block corresponding to backward op.
  1122. """
  1123. assert sub_block_op_desc.has_attr(
  1124. "sub_block"
  1125. ) and sub_block.idx == sub_block_op_desc._block_attr_id("sub_block")
  1126. assert isinstance(sub_block_target_names, (set, type(None)))
  1127. if sub_block_target_names is None:
  1128. sub_block_target_names = sub_block_op_desc.output_arg_names
  1129. # TODO(huihuangzheng): add support for recurrent op.
  1130. if sub_block_op_desc.type in ["conditional_block", "while"]:
  1131. # Step1: get the output vars in sub-block
  1132. sub_outputs = [
  1133. sub_block._var_recursive(var) for var in sub_block_target_names
  1134. ]
  1135. for var in sub_block_target_names:
  1136. for op_desc in sub_block.ops:
  1137. if var in op_desc.output_arg_names:
  1138. for name in op_desc.input_arg_names:
  1139. sub_outputs.append(sub_block._var_recursive(name))
  1140. # Step2: find op path of sub-block
  1141. is_while = sub_block_op_desc.type in ["while"]
  1142. sub_block_op_path = _find_op_path_(
  1143. sub_block, sub_outputs, [], no_grad_set, op_path_dict, is_while
  1144. )
  1145. return sub_block_op_path
  1146. return sub_block.ops
  1147. def _is_grad_op_(op):
  1148. op_maker = core.op_proto_and_checker_maker
  1149. backward = core.op_proto_and_checker_maker.OpRole.Backward
  1150. if op_maker.kOpRoleVarAttrName() in op.attr_names and int(
  1151. op.all_attrs()[op_maker.kOpRoleAttrName()]
  1152. ) == int(backward):
  1153. return True
  1154. return False
  1155. def _rename_grad_name_(name, grad_order):
  1156. return 'grad/' * grad_order + name
  1157. def _topo_order_map(block, target_vars):
  1158. """Analysis forward block and build a mapping from:
  1159. OpDesc -> Int
  1160. """
  1161. get_defined_op = {} # mapping from String -> OpDesc (defined op)
  1162. for op in block.ops:
  1163. for out_name in op.output_arg_names:
  1164. get_defined_op[out_name] = op
  1165. topo_order_map = {} # mapping from OpDesc -> Topologic Order
  1166. queue = [var.name for var in target_vars]
  1167. visited = {var.name for var in target_vars}
  1168. topo_order_counter = 0
  1169. while len(queue) > 0:
  1170. cur_var_name = queue.pop(0)
  1171. if cur_var_name not in get_defined_op:
  1172. continue
  1173. cur_op = get_defined_op[cur_var_name]
  1174. topo_order_map[cur_op] = topo_order_counter
  1175. topo_order_counter += 1
  1176. for inp in cur_op.input_arg_names:
  1177. if inp in get_defined_op and inp not in visited:
  1178. queue.append(inp)
  1179. visited.add(inp)
  1180. return topo_order_map
  1181. def _topo_bwd_order_map(topo_fwd_map, backward_op_map):
  1182. topo_bwd_map = {}
  1183. topo_fwd_map = {op.desc: order for op, order in topo_fwd_map.items()}
  1184. for fwd_op, bwd_ops in backward_op_map.items():
  1185. if fwd_op not in topo_fwd_map:
  1186. continue
  1187. for bwd_op in bwd_ops:
  1188. topo_bwd_map[bwd_op] = topo_fwd_map[fwd_op]
  1189. return topo_bwd_map
  1190. def _append_backward_ops_(
  1191. block,
  1192. ops,
  1193. target_vars,
  1194. target_block,
  1195. no_grad_dict,
  1196. grad_to_var,
  1197. callbacks=None,
  1198. input_grad_names_set=None,
  1199. op_path_dict=None,
  1200. distop_context=None,
  1201. rename_var_map=None,
  1202. grad_op_id_to_fwd_op=None,
  1203. ):
  1204. """
  1205. Create all grad ops, and insert them into given block
  1206. Args:
  1207. block(Block): the block where forward ops are
  1208. ops(Op): the forward operators whose backward ops need to be added
  1209. target_vars(list[Tensor]): the loss vars we want to calculate gradient.
  1210. target_block(Block): the block which is going to hold new generated grad ops
  1211. no_grad_dict(dict):
  1212. key(int) block index
  1213. val(set) a set of variable names. These variables have no gradient
  1214. grad_to_var(dict)(output argument):
  1215. key(str): grad variable name
  1216. val(str): corresponding forward variable name
  1217. callbacks(callable object): a callable object used to decorate new generated grad ops
  1218. input_grad_names_set(set): this set is used to store the gradients' name which is
  1219. generated by backward ops, and input_grad_names_set can help to prune the unnecessary
  1220. backward ops.
  1221. op_path_dict(dict): op_path_dict will be changed.
  1222. key(int) block index
  1223. val(list) the op path of block(index)
  1224. rename_var_map(dict): used to associate target_grad var name with first grad_op input name.
  1225. Only used in for high order gradient.
  1226. """
  1227. # Build the mapping between the forward op and backward op (Only for auto parallel)
  1228. def update_distop_context(
  1229. distop_context, op_grad_to_var, appending_grad_times
  1230. ):
  1231. distop_context.grad_var_to_var[appending_grad_times].update(
  1232. op_grad_to_var
  1233. )
  1234. for op_desc in grad_op_desc:
  1235. assert (
  1236. op_desc.original_id() not in distop_context.grad_op_id_to_op_id
  1237. )
  1238. distop_context.grad_op_id_to_op_id[
  1239. op_desc.original_id()
  1240. ] = op.desc.original_id()
  1241. if callbacks is not None:
  1242. assert isinstance(callbacks, (list, tuple))
  1243. for cb in callbacks:
  1244. if not callable(cb):
  1245. raise ValueError("'callback' must be a callable object.")
  1246. # grad_op_descs holds created grad_op, and will be appended to target_block
  1247. grad_op_descs = []
  1248. program = block.program
  1249. get_backward_op_desc = {} # for topo order map
  1250. if rename_var_map is None:
  1251. rename_var_map = {}
  1252. assert isinstance(rename_var_map, dict)
  1253. if core._is_bwd_prim_enabled():
  1254. composite_block = program.clone().current_block()
  1255. # Create output and infer shape for operators whose output haven't
  1256. # been created.
  1257. for op in composite_block.ops:
  1258. for name in op.output_arg_names:
  1259. if not (
  1260. composite_block.desc.has_var_recursive(name.encode())
  1261. or name == core.empty_var_name()
  1262. ):
  1263. composite_block.create_var(name=name)
  1264. op.desc.infer_var_type(composite_block.desc)
  1265. op.desc.infer_shape(composite_block.desc)
  1266. # add grad_op_desc by reversed ops
  1267. for op in reversed(ops):
  1268. grad_sub_block_list = []
  1269. # If the op has its own sub-block, deal with the sub-block first
  1270. if op.has_attr("sub_block"):
  1271. sub_block = program.block(op._block_attr_id("sub_block"))
  1272. grad_sub_block = program._create_block()
  1273. grad_sub_block._set_forward_block_idx(sub_block.idx)
  1274. # see following comments for why set None here.
  1275. pre_input_grad_names_set = copy.copy(input_grad_names_set)
  1276. input_grad_names_set = None
  1277. sub_block_path = op_path_dict[op._block_attr_id("sub_block")]
  1278. _append_backward_ops_(
  1279. sub_block,
  1280. sub_block_path,
  1281. target_vars,
  1282. grad_sub_block,
  1283. no_grad_dict,
  1284. grad_to_var,
  1285. callbacks,
  1286. input_grad_names_set,
  1287. op_path_dict,
  1288. grad_op_id_to_fwd_op=grad_op_id_to_fwd_op,
  1289. )
  1290. input_grad_names_set = pre_input_grad_names_set
  1291. program._rollback()
  1292. grad_sub_block_list.append(grad_sub_block.desc)
  1293. # In primitive mode, raw phi GradOp will be split into multiple small
  1294. # primitive operators, and the split rules are defined in c++ level,
  1295. # see details: paddle/base/prim/api/manual/backward/composite_backward_api.h
  1296. # It means that the output's shape and dtype of previous operators which
  1297. # maybe used as the input of next operators must be known. Therefore,
  1298. # we infer shape and dtype in a sandbox block(named composite_block) for
  1299. # used in c++ level.
  1300. # For example:
  1301. # forward:
  1302. # z = multiply(x, y) //maybe broadcast in kernel
  1303. # backward:
  1304. # x_grad_unreduce = z_grad * y // maybe unreduce
  1305. # reduced_axes = get_reduced_axes(x_grad.shape, x.shape) // need known shape
  1306. # x_grad = reduce_sum(x_grad_unreduce)
  1307. grad_op_desc = []
  1308. op_grad_to_var = {}
  1309. if core._is_bwd_prim_enabled():
  1310. def find_op_index(block_desc, cur_op_desc):
  1311. for idx in range(block_desc.op_size()):
  1312. if cur_op_desc == block_desc.op(idx):
  1313. return idx
  1314. return -1
  1315. grad_op_desc, op_grad_to_var = core.get_grad_op_desc(
  1316. composite_block.desc.op(find_op_index(block.desc, op.desc)),
  1317. no_grad_dict[composite_block.idx],
  1318. grad_sub_block_list,
  1319. )
  1320. for desc in grad_op_desc:
  1321. infershape_for_composite(composite_block, desc)
  1322. else:
  1323. # Getting op's corresponding grad_op
  1324. grad_op_desc, op_grad_to_var = core.get_grad_op_desc(
  1325. op.desc, no_grad_dict[block.idx], grad_sub_block_list
  1326. )
  1327. # record the mapping between fwd and bwd
  1328. get_backward_op_desc[op.desc] = grad_op_desc
  1329. if grad_op_id_to_fwd_op is not None:
  1330. for op_desc in grad_op_desc:
  1331. grad_op_id_to_fwd_op[op_desc.original_id()] = op
  1332. # Build the mapping between the forward op and backward op (Only for auto parallel)
  1333. if distop_context is not None:
  1334. update_distop_context(
  1335. distop_context, op_grad_to_var, program._appending_grad_times
  1336. )
  1337. else:
  1338. default_ctx = getattr(
  1339. paddle.distributed.auto_parallel.static.dist_context,
  1340. '_g_default_distributed_context',
  1341. None,
  1342. )
  1343. if default_ctx is not None:
  1344. distop_context = default_ctx.dist_op_context
  1345. update_distop_context(
  1346. distop_context,
  1347. op_grad_to_var,
  1348. program._appending_grad_times,
  1349. )
  1350. # Set device for grad_op according to forward Op
  1351. device_attr_name = core.op_proto_and_checker_maker.kOpDeviceAttrName()
  1352. if op.desc.has_attr(device_attr_name):
  1353. op_device = op.desc.attr(device_attr_name)
  1354. for op_desc in grad_op_desc:
  1355. op_desc._set_attr(device_attr_name, op_device)
  1356. # Rename internal gradient variables in multiple backward
  1357. # so that they have different names with previous backward.
  1358. # For example:
  1359. # y = x * x, grad = base.gradients(base.gradients(y, x) + y * y, x)
  1360. # In second-time backward, gradient variable names of partial
  1361. # forward network (y * y) may be have same names with first-time
  1362. # base.gradients(y, x).
  1363. # So rename here before _addup_repetitive_outputs_.
  1364. if program._appending_grad_times > 1:
  1365. for op_desc in grad_op_desc:
  1366. forward_op_inputs = op.desc.input_arg_names()
  1367. for name in op_desc.input_arg_names():
  1368. if name in rename_var_map and name not in forward_op_inputs:
  1369. op_desc._rename_input(name, rename_var_map[name])
  1370. for name in op_desc.output_arg_names():
  1371. if "@GRAD" not in name:
  1372. continue
  1373. if block.desc.find_var(name.encode("ascii")):
  1374. new_name = _rename_grad_name_(
  1375. name, program._appending_grad_times
  1376. )
  1377. op_desc._rename_output(name, new_name)
  1378. rename_var_map[name] = new_name
  1379. if name in op_grad_to_var:
  1380. # Build the mapping between the grad var name and var name (Only for auto parallel)
  1381. if distop_context is not None:
  1382. distop_context.grad_var_to_var[
  1383. program._appending_grad_times
  1384. ][new_name] = op_grad_to_var[name]
  1385. op_grad_to_var[new_name] = op_grad_to_var[name]
  1386. op_grad_to_var.pop(name)
  1387. # If input_grad_names_set is not None, extend grad_op_descs only when
  1388. # any input grad in outputs of previous grad ops.
  1389. # But this strategy is not suited for while op for some control flow,
  1390. # for example, for while op, the grads maybe generated in next loop.
  1391. if input_grad_names_set is not None:
  1392. is_grad_name = (
  1393. lambda name: name.find(core.grad_var_suffix()) != -1
  1394. or name in input_grad_names_set
  1395. )
  1396. is_append_grad = False
  1397. # NOTE: In primitive mode, the intermediate variable generated by
  1398. # decompositing raw grad op are not satisfied the rule of 'XX@GRAD',
  1399. # which will cause it be pruned according to current pruning logic.
  1400. # For simplicity, we treat all primitive operators as one raw
  1401. # operator, and keep the pruning logic consistent with currently
  1402. # logic. The drawback of this solution is may lead to some primitive
  1403. # operators are not pruned, which is needed to fixed.
  1404. # FIXME: Optimize pruning logic from the perspective of whole graph.
  1405. input_grad_names = []
  1406. for op_desc in grad_op_desc:
  1407. input_grad_names += [
  1408. name
  1409. for name in op_desc.input_arg_names()
  1410. if is_grad_name(name)
  1411. ]
  1412. # some code of gradient ops, like increment, are not very
  1413. # standard, there is no @GRAD in these ops' inputs.
  1414. if len(input_grad_names) == 0:
  1415. is_append_grad = True
  1416. continue
  1417. if _some_in_set_(input_grad_names, input_grad_names_set):
  1418. is_append_grad = True
  1419. for op_desc in grad_op_desc:
  1420. grad_op_descs.append(op_desc)
  1421. for name in op_desc.output_arg_names():
  1422. input_grad_names_set.add(name)
  1423. if is_append_grad:
  1424. grad_to_var.update(op_grad_to_var)
  1425. else:
  1426. grad_op_descs.extend(grad_op_desc)
  1427. grad_to_var.update(op_grad_to_var)
  1428. # record mapping between grad var name and var name (Only for auto parallel)
  1429. grad_var_to_var = None
  1430. if distop_context is not None:
  1431. grad_var_to_var = distop_context.grad_var_to_var[
  1432. program._appending_grad_times
  1433. ]
  1434. # sum parameter's gradients' var given multiple var gradient
  1435. if os.environ.get("FLAGS_program_topo_reorder", "False") in [
  1436. 'True',
  1437. '1',
  1438. 'true',
  1439. ]:
  1440. topo_order = _topo_order_map(block, target_vars)
  1441. topo_order_for_backward = _topo_bwd_order_map(
  1442. topo_order, get_backward_op_desc
  1443. )
  1444. else:
  1445. topo_order_for_backward = None
  1446. grad_op_descs = _addup_repetitive_outputs_(
  1447. grad_op_descs,
  1448. block.idx,
  1449. grad_var_to_var,
  1450. grad_op_id_to_fwd_op=grad_op_id_to_fwd_op,
  1451. topo_order_for_backward=topo_order_for_backward,
  1452. )
  1453. # if all outputs of the grad op are in no_grad_set, then just remove and fill zero
  1454. # if all inputs of the grad op are in no_grad_set, just remove this op
  1455. grad_op_descs = _remove_no_grad_branch_(
  1456. grad_op_descs,
  1457. no_grad_dict[block.idx],
  1458. grad_op_id_to_fwd_op,
  1459. target_vars,
  1460. )
  1461. # remove some backward ops
  1462. # TODO(Jiabin): Support this in prime later, it will prune add_grad, fix this problem
  1463. if not core._is_bwd_prim_enabled():
  1464. not_need_ops = _find_not_need_ops(
  1465. grad_op_descs, ops, input_grad_names_set
  1466. )
  1467. grad_op_descs = [
  1468. op_desc for op_desc in grad_op_descs if op_desc not in not_need_ops
  1469. ]
  1470. else:
  1471. logging.debug(
  1472. "Running backward composite and disable find_not_need_ops"
  1473. )
  1474. # append op_desc in grad_op_descs to target_block
  1475. op_role_attr_name = core.op_proto_and_checker_maker.kOpRoleAttrName()
  1476. backward = core.op_proto_and_checker_maker.OpRole.Backward
  1477. for op_desc in grad_op_descs:
  1478. new_op_desc = target_block.desc.append_op()
  1479. new_op_desc.copy_from(op_desc)
  1480. new_op_desc._set_attr(op_role_attr_name, backward)
  1481. grad_to_var["__current_op_desc__"] = new_op_desc
  1482. if callbacks is not None:
  1483. assert isinstance(callbacks, (list, tuple))
  1484. for cb in callbacks:
  1485. cb(block=target_block, context=grad_to_var)
  1486. def _is_grad_var_(var_name):
  1487. return core.grad_var_suffix() in var_name
  1488. # Find the op who holds the sub_block as its "sub_block" attr
  1489. def _find_parent_op_(sub_block):
  1490. sub_block_id = sub_block.idx
  1491. if sub_block_id == 0:
  1492. return None
  1493. program = sub_block.program
  1494. for block_id in range(program.num_blocks):
  1495. block_desc = program.block(block_id).desc
  1496. for op_idx in range(block_desc.op_size()):
  1497. op = block_desc.op(op_idx)
  1498. if (
  1499. op.has_attr("sub_block")
  1500. and op._block_attr_id("sub_block") == sub_block_id
  1501. ):
  1502. return op
  1503. # NOTE(paddle-dev): When optimizer is added in conditional block,
  1504. # sub_block may not be found.
  1505. return None
def _append_backward_vars_(block, start_op_idx, grad_to_var, grad_info_map):
    """
    Create new variables required by backward pass.

    Ops in ``block.desc`` from ``start_op_idx`` onward are scanned in order.
    Ops that own a ``sub_block`` are processed recursively first. An op is
    scheduled for removal when it has no non-empty outputs, or when it has
    gradient inputs and none of them exists yet. For every remaining op the
    missing output variables are created, then attrs are checked and var
    type / shape inference is run.

    Args:
        block(Block): the block where new variables will be created
        start_op_idx(int): Only variables required by ops in block.ops[start_op_idx : ] will be created
        grad_to_var(dict):
            key(str): grad variable name
            val(str): corresponding forward variable name
            In most cases, this dict is generated by _append_backward_ops_()
        grad_info_map(dict)(output argument):
            key(str): forward variable name
            val(tuple): a tuple of (str, Block), str is the corresponding grad name, Block is the block containing grad variable

    Returns:
        None. ``block.desc`` and ``grad_info_map`` are modified in place.
    """
    # Indices of ops to drop; removal is deferred until the scan is finished.
    ops_to_remove = []
    '''
    NOTE(paddle-dev): while_grad op may hold some inputs which are not found
    in the parent/forward block, and they are also the outputs of while_grad
    op. These kinds of inputs are the recursive outputs inside while_grad op.
    They should be considered as "already created" when scanning the inner
    ops of while_grad ops.
    '''
    parent_op = _find_parent_op_(block)
    # Names that are both an input and an output of the parent op.
    parent_op_vars = []
    if parent_op is not None:
        input_args = parent_op.input_arg_names()
        output_args = parent_op.output_arg_names()
        for in_arg in input_args:
            if in_arg in output_args:
                parent_op_vars.append(in_arg)

    for op_idx in range(start_op_idx, block.desc.op_size()):
        op_desc = block.desc.op(op_idx)
        if op_desc.has_attr("sub_block"):
            # Handle the whole sub-block (from its first op) before this op.
            sub_block = block.program.block(op_desc._block_attr_id("sub_block"))
            _append_backward_vars_(sub_block, 0, grad_to_var, grad_info_map)

        grad_var_ins = [
            var for var in op_desc.input_arg_names() if _is_grad_var_(var)
        ]
        # NOTE(review): `grad_var_outs` and `inputs` are computed here but are
        # not read anywhere else in this function.
        grad_var_outs = [
            var for var in op_desc.output_arg_names() if _is_grad_var_(var)
        ]

        inputs = [
            var
            for var in op_desc.input_arg_names()
            if var != core.empty_var_name()
        ]
        outputs = [
            var
            for var in op_desc.output_arg_names()
            if var != core.empty_var_name()
        ]

        # If the outputs of grad op is empty, just remove it
        if not outputs:
            ops_to_remove.append(op_idx)
            continue
        else:
            '''
            If the output is not empty and there is any grad input, find
            whether there is any existing input. If not, just remove it.
            '''
            if grad_var_ins:
                # A grad input counts as existing when it is visible from this
                # block (recursive lookup) or is one of the parent op's
                # in/out variables collected above.
                existing_grad_var_ins = [
                    var
                    for var in grad_var_ins
                    if block.desc.has_var_recursive(var.encode())
                    or var in parent_op_vars
                ]
                if not existing_grad_var_ins:
                    ops_to_remove.append(op_idx)
                    continue

        # sum may create invalid variable, here to deal with it.
        if op_desc.type() == 'sum':
            new_inputs = []
            for grad_var_name in op_desc.input_arg_names():
                if block.desc.has_var_recursive(grad_var_name.encode()):
                    # meet invalid sum variables, remove the invalid operand.
                    new_inputs.append(grad_var_name)
            assert (
                len(new_inputs) > 0
            ), "After remove invalid variables, sum op have no inputs."
            # Only the operands that actually exist are kept as "X".
            op_desc.set_input("X", new_inputs)

        # Names of the variables created for this op.
        new_vars = set()
        # create new gradient variables
        for grad_var_name in op_desc.output_arg_names():
            if (
                block.desc.has_var_recursive(grad_var_name.encode())
                or grad_var_name == core.empty_var_name()
            ):
                continue
            block.desc.var(grad_var_name.encode())
            new_vars.add(grad_var_name)
            if grad_var_name not in grad_to_var:
                continue
            # Record, keyed by forward var name, where its gradient lives.
            grad_info_map[grad_to_var[grad_var_name]] = (grad_var_name, block)
        # infer_shape and infer_type
        op_desc.check_attrs()
        op_desc.infer_var_type(block.desc)
        op_desc.infer_shape(block.desc)

        # Only the variables created above are passed to
        # _infer_var_data_type_shape_.
        for arg in op_desc.output_arg_names():
            if arg in new_vars:
                _infer_var_data_type_shape_(arg, block)

    # Remove from the highest index down so the remaining recorded indices
    # are not shifted by earlier removals.
    for op_idx in reversed(ops_to_remove):
        block.desc._remove_op(op_idx, op_idx + 1)
  1609. def infershape_for_composite(block, grad_op_desc):
  1610. # NOTE: why pruning the operator with empty output here ?
  1611. # Some backward operator will output empty var, which will cause infer
  1612. # shape error, such assign with input's stop_gradient=True
  1613. if len(grad_op_desc.output_arg_names()) == 0:
  1614. return
  1615. # create output variable
  1616. new_vars = set()
  1617. for grad_var_name in grad_op_desc.output_arg_names():
  1618. if not (
  1619. block.desc.has_var_recursive(grad_var_name.encode())
  1620. or grad_var_name == core.empty_var_name()
  1621. ):
  1622. # NOTE: stop_gradient will be set in append_op
  1623. desc = block.desc.var(grad_var_name.encode())
  1624. block.create_var(name=grad_var_name, desc=desc, type=desc.type())
  1625. new_vars.add(grad_var_name)
  1626. # NOTE For the primitive operator generated by decompositing phi grad kernel,
  1627. # we Operator to reconstruct the op_desc for reusing some complex logic, such
  1628. # as processing dispensable input, intermediate output, extra attrs, etc...
  1629. if framework.OpProtoHolder.instance().has_op_proto(grad_op_desc.type()):
  1630. op = block.append_op(
  1631. type=grad_op_desc.type(),
  1632. inputs={
  1633. name: [block._find_var_recursive(arg) for arg in args]
  1634. for name, args in grad_op_desc.inputs().items()
  1635. },
  1636. outputs={
  1637. name: [block._find_var_recursive(arg) for arg in args]
  1638. for name, args in grad_op_desc.outputs().items()
  1639. },
  1640. # NOTE Runtime attr will be ignore as the c++ GetRuntimeAttr
  1641. # interface cann't be exported to python. Please note the WARNING
  1642. # message logged in RuntimeAttrs of composite_grad_desc_maker.h
  1643. attrs=grad_op_desc.get_attr_map(),
  1644. )
  1645. op.desc._set_attr(
  1646. core.op_proto_and_checker_maker.kOpRoleAttrName(),
  1647. core.op_proto_and_checker_maker.OpRole.Backward,
  1648. )
  1649. grad_op_desc.copy_from(op.desc)
  1650. # For the backward operator, we reuse the logic of _append_backward_var
  1651. else:
  1652. op_desc = block.desc.append_op()
  1653. op_desc.copy_from(grad_op_desc)
  1654. op_desc._set_attr(
  1655. core.op_proto_and_checker_maker.kOpRoleAttrName(),
  1656. core.op_proto_and_checker_maker.OpRole.Backward,
  1657. )
  1658. op_desc.check_attrs()
  1659. op_desc.infer_var_type(block.desc)
  1660. op_desc.infer_shape(block.desc)
  1661. grad_op_desc.copy_from(op_desc)
  1662. if not framework.OpProtoHolder.instance().has_op_proto(grad_op_desc.type()):
  1663. # NOTE: Some raw base grad operators which hadn't been decomposed may not
  1664. # implement InferVarType method, such as elementwise_xx_grad, and it will
  1665. # cause the dtype or shape of corresponding cotangent incorrect. This
  1666. # patch set the cotangent dtype and shape same with corresponding
  1667. # forward variable. For primitive operators, we have ensure all
  1668. # InferVarType method to be executed correctly in PR#52818, we skip
  1669. # this patch for primitive operators.
  1670. for arg in grad_op_desc.output_arg_names():
  1671. if arg in new_vars:
  1672. _infer_var_data_type_shape_(arg, block)
  1673. def _rename_grad_(
  1674. block, start_op_idx, grad_to_var, target_grad_map, skip_rename_var_list
  1675. ):
  1676. var_map = copy.copy(target_grad_map)
  1677. for op_idx in range(start_op_idx, block.desc.op_size()):
  1678. op_desc = block.desc.op(op_idx)
  1679. for name in op_desc.input_arg_names():
  1680. if name in var_map:
  1681. op_desc._rename_input(name, var_map[name])
  1682. for name in op_desc.output_arg_names():
  1683. if "@GRAD" not in name:
  1684. continue
  1685. if block.desc.find_var(name.encode("ascii")):
  1686. if name in skip_rename_var_list:
  1687. continue
  1688. new_name = unique_name.generate(name)
  1689. op_desc._rename_output(name, new_name)
  1690. var_map[name] = new_name
  1691. for g, ng in var_map.items():
  1692. if g in grad_to_var:
  1693. grad_to_var[ng] = grad_to_var[g]
  1694. grad_to_var.pop(g)
  1695. def _get_stop_gradients_(program):
  1696. no_grad_dict = {}
  1697. assert isinstance(program, framework.Program)
  1698. for block in program.blocks:
  1699. assert isinstance(block, framework.Block)
  1700. block_no_grad_set = set()
  1701. for var in list(block.vars.values()):
  1702. assert isinstance(var, framework.Variable)
  1703. if var.stop_gradient:
  1704. block_no_grad_set.add(_append_grad_suffix_(var.name))
  1705. no_grad_dict[block.idx] = block_no_grad_set
  1706. return no_grad_dict
  1707. def _get_son_parent_block_idx_dict(program, current_block_idx):
  1708. son_parent_block_idx_dict = collections.OrderedDict()
  1709. while current_block_idx >= 0:
  1710. parent_block_idx = program.block(current_block_idx).parent_idx
  1711. son_parent_block_idx_dict[current_block_idx] = parent_block_idx
  1712. current_block_idx = parent_block_idx
  1713. return son_parent_block_idx_dict
  1714. def _get_no_grad_set_name(no_grad_set):
  1715. no_grad_set_name = set()
  1716. if no_grad_set is not None:
  1717. if isinstance(no_grad_set, (set, list, tuple)):
  1718. for i, no_grad_var in enumerate(no_grad_set):
  1719. if isinstance(no_grad_var, framework.Variable):
  1720. no_grad_set_name.add(no_grad_var.name)
  1721. elif isinstance(no_grad_var, str):
  1722. no_grad_set_name.add(no_grad_var)
  1723. else:
  1724. raise TypeError(
  1725. "The type of no_grad_set's member must be paddle.base.Variable or str, but received %s."
  1726. % (type(no_grad_var))
  1727. )
  1728. else:
  1729. raise TypeError(
  1730. f"The type of no_grad_set should be set or list or tuple, but received {type(no_grad_set)}"
  1731. )
  1732. return no_grad_set_name
  1733. def _get_no_grad_set_value(no_grad_set):
  1734. no_grad_set_value = paddle.autograd.backward_utils.ValueSet()
  1735. if no_grad_set is not None:
  1736. if isinstance(no_grad_set, (set, list, tuple)):
  1737. for i, no_grad_value in enumerate(no_grad_set):
  1738. if isinstance(no_grad_value, paddle.pir.Value):
  1739. no_grad_set_value.add(no_grad_value)
  1740. else:
  1741. raise TypeError(
  1742. "The type of no_grad_set's member must be paddle.pir.Value, but received %s."
  1743. % (type(no_grad_value))
  1744. )
  1745. else:
  1746. raise TypeError(
  1747. f"The type of no_grad_set should be set or list or tuple, but received {type(no_grad_set)}"
  1748. )
  1749. return no_grad_set_value
@framework.static_only
def append_backward(
    loss,
    parameter_list=None,
    no_grad_set=None,
    callbacks=None,
    checkpoints=None,
    distop_context=None,
):
    """
    :api_attr: Static Graph

    This function appends backward part to main_program.

    A complete neural network training is made up of forward and backward
    propagation. However, when we configure a network, we only need to
    specify its forward part. This function uses the chain rule to automatically
    generate the backward part according to the forward part.

    In most cases, users do not need to invoke this function manually.
    It will be automatically invoked by the optimizer's `minimize` function.

    In PIR mode the call is delegated to
    ``paddle.autograd.ir_backward.append_backward`` with only ``loss``,
    ``parameter_list`` and ``no_grad_set``.

    Parameters:
        loss(Tensor): The loss Tensor of the network.
        parameter_list(list[Tensor|str]|tuple[Tensor|str], optional): List/Tuple of Parameters or Parameter.names
                                           that need to be updated by optimizers.
                                           If it is None, all parameters
                                           will be updated.
                                           Default: None.
        no_grad_set(set[Tensor|str], optional): Set of Tensors or Tensor.names in the :ref:`api_guide_Block_en` 0 whose gradients
                               should be ignored. All Tensors with
                               `stop_gradient=True` from all blocks will
                               be automatically added into this set.
                               If this parameter is not None, the Tensors or Tensor.names in this set will be added to the default set.
                               Default: None.
        callbacks(list[callable object]|tuple[callable object], optional): List/Tuple of callback functions.
                                               The callbacks are used for
                                               doing some custom jobs during
                                               backward part building. All
                                               callable objects in it will
                                               be invoked once each time a
                                               new gradient operator is added
                                               into the program. The callable
                                               object must have two input
                                               parameters: ``block`` and ``context`` .
                                               The ``block`` is the :ref:`api_guide_Block_en` which
                                               the new gradient operator will
                                               be added to. The ``context`` is a
                                               map, whose keys are gradient
                                               Tensor names and values are
                                               corresponding original :ref:`api_guide_tensor_en` .
                                               In addition to this, the ``context``
                                               has another special key-value pair:
                                               the key is string ``__current_op_desc__``
                                               and the value is the op_desc of the
                                               gradient operator who has just
                                               triggered the callable object.
                                               Default: None.
        checkpoints(list, optional): When this is a non-empty list, the backward ops are
                               built by ``_append_backward_ops_with_checkpoints_`` and the
                               checkpoint names are returned as well (see Returns).
                               Default: None.
        distop_context(optional): Forwarded unchanged to ``_append_backward_ops_``.
                               Default: None.

    Returns:
        list of tuple ( :ref:`api_guide_tensor_en` , :ref:`api_guide_tensor_en` ): Pairs of parameter and its corresponding gradients.
        The key is the parameter and the value is gradient Tensor.
        When ``checkpoints`` is a non-empty list, a tuple ``(params_and_grads, checkpoint_names)``
        is returned instead.

    Raises:
        AssertionError: If ``loss`` is not an instance of Tensor.

    Examples:
        .. code-block:: python

            >>> import paddle
            >>> import paddle.nn.functional as F

            >>> paddle.enable_static()

            >>> x = paddle.static.data(name='x', shape=[None, 13], dtype='int64')
            >>> y = paddle.static.data(name='y', shape=[None, 1], dtype='float32')
            >>> x_emb = paddle.static.nn.embedding(x, size=[100, 256])
            >>> y_predict = paddle.static.nn.fc(x=x_emb, size=1, activation=None, name='my_fc')
            >>> loss = F.square_error_cost(input=y_predict, label=y)
            >>> avg_loss = paddle.mean(loss)

            >>> # Get all weights in main_program, not include bias.
            >>> all_weights = [param for param in paddle.static.default_main_program().block(0).all_parameters() if 'w_' in param.name]
            >>> all_weights_name = [w.name for w in all_weights]

            >>> # return all param_grads needed to be updated if parameter_list set default None.
            >>> p_g_list1 = paddle.static.append_backward(loss=avg_loss)
            >>> # output: [(embedding_0.w_0, embedding_0.w_0@GRAD), (my_fc.w_0, my_fc.w_0@GRAD), (my_fc.b_0, my_fc.b_0@GRAD)]

            >>> # return the param_grads corresponding to parameter_list that can be list of param (Tensor).
            >>> p_g_list2 = paddle.static.append_backward(loss=avg_loss, parameter_list=all_weights)
            >>> # output: [(embedding_0.w_0, embedding_0.w_0@GRAD), (my_fc.w_0, my_fc.w_0@GRAD)]

            >>> # parameter_list can be list of param.name (str).
            >>> p_g_list3 = paddle.static.append_backward(loss=avg_loss, parameter_list=all_weights_name)
            >>> # output: [(embedding_0.w_0, embedding_0.w_0@GRAD), (my_fc.w_0, my_fc.w_0@GRAD)]

            >>> # no_grad_set can be set of Tensors that means grad will be cut off from these Tensors.
            >>> p_g_list4 = paddle.static.append_backward(loss=avg_loss, no_grad_set=set([x_emb]))
            >>> # output: [(my_fc.w_0, my_fc.w_0@GRAD), (my_fc.b_0, my_fc.b_0@GRAD)]

            >>> # no_grad_set can be set of Tensor.name when the Tensor is created inside layers and can't be specified explicitly.
            >>> p_g_list5 = paddle.static.append_backward(loss=avg_loss, no_grad_set=set(['my_fc.b_0']))
            >>> # output: [(embedding_0.w_0, embedding_0.w_0@GRAD), (my_fc.w_0, my_fc.w_0@GRAD)]

            >>> # return [] because all param_grads are filtered by no_grad_set.
            >>> p_g_list6 = paddle.static.append_backward(loss=avg_loss, parameter_list=all_weights, no_grad_set=set(all_weights))

    """
    # PIR programs are handled by a separate implementation.
    if framework.in_pir_mode():
        return paddle.autograd.ir_backward.append_backward(
            loss, parameter_list, no_grad_set
        )

    grad_op_id_to_fwd_op = (
        {}
    )  # for cuda graph usage, recording the mapping between grad op original id to fwd op

    check_type(
        loss, 'loss', framework.Variable, 'paddle.static.append_backward'
    )

    if loss.op is None:
        # the loss is from a cloned program. Find loss op manually.
        _find_loss_op_(loss)

    # Mark the op producing the loss with both the Forward and Loss roles.
    loss.op._set_attr(
        core.op_proto_and_checker_maker.kOpRoleAttrName(),
        int(core.op_proto_and_checker_maker.OpRole.Forward)
        | int(core.op_proto_and_checker_maker.OpRole.Loss),
    )

    if callbacks is not None:
        check_type(
            callbacks,
            'callbacks',
            (list, tuple),
            'paddle.static.append_backward',
        )

    program = loss.block.program
    root_block = program.block(0)
    current_block_idx = program.current_block_idx
    current_block = program.block(current_block_idx)

    # Any current block other than block 0 is treated as control flow.
    is_in_control_flow = current_block_idx != 0

    # Double grad is not supported in sub-block (control flow)
    if not is_in_control_flow:
        # _appending_grad_times used for double grad
        program._appending_grad_times += 1

    if no_grad_set is None:
        no_grad_set = set()
    else:
        no_grad_set = _get_no_grad_set_name(copy.copy(no_grad_set))
    no_grad_dict = _get_stop_gradients_(program)
    # no_grad_set only contains vars in block 0
    # Todo(liym27): support vars in sub block
    no_grad_dict[0].update(list(map(_append_grad_suffix_, no_grad_set)))

    # Currently it is only to support the optimizer.minimize
    # in a switch branch, which can append_backward in a sub_block.
    # Note: while_loop is in control flow, but it makes no sense to call optimizer in while.
    # Todo: report error when it is in while_loop
    if is_in_control_flow:
        # create grad block if in switch control flow.
        target_grad_block = program._create_block(
            parent_idx=current_block.parent_idx
        )
        target_grad_block._set_forward_block_idx(current_block_idx)
        # after _create_block, program.current_block changes
    else:
        target_grad_block = root_block

    # Current block and all of its ancestors, ordered from child to parent.
    son_parent_block_idx_dict = _get_son_parent_block_idx_dict(
        program, current_block_idx
    )

    # Snapshot the op counts before the loss-grad op below is appended.
    block_fwd_op_num_dict = {}  # block_id: fwd_op_num
    for idx in son_parent_block_idx_dict:
        block_fwd_op_num_dict[idx] = program.block(idx).desc.op_size()

    grad_to_var = {}

    # pass the cuda_graph_attr to the fill_constant which generates the loss_grad
    op_desc = _create_loss_op_desc_(loss)
    grad_op_id_to_fwd_op[op_desc.original_id()] = loss.op
    target_grad_block.desc.append_op().copy_from(op_desc)

    for block_idx in son_parent_block_idx_dict:
        block = program.block(block_idx)

        block_no_grad_set = set(
            map(_strip_grad_suffix_, no_grad_dict[block_idx])
        )

        op_path_dict = {}
        op_path = _find_op_path_(
            block, [loss], [], block_no_grad_set, op_path_dict
        )

        # NOTE: this rebinds the `no_grad_set` argument to the per-block result.
        no_grad_set = _find_no_grad_vars(
            block, op_path, [loss], block_no_grad_set
        )

        block_no_grad_set.update(no_grad_set)
        no_grad_dict[block_idx].update(
            list(map(_append_grad_suffix_, block_no_grad_set))
        )

        input_grad_names_set = None
        # For double backward, input_grad_names is used for filtering
        # some non-used gradients op(s).

        # TODO(liym27): need a better design.
        # not support double grad in control flow sub-block now.
        if not is_in_control_flow:
            if program._appending_grad_times > 1:
                input_grad_names_set = {_append_grad_suffix_(loss.name)}

        # TODO: support _append_backward_ops_with_checkpoints_ in
        #  sub-block (control flow)
        is_recompute = False
        if (
            checkpoints is not None
            and isinstance(checkpoints, list)
            and len(checkpoints) > 0
        ):
            is_recompute = True
            # NOTE(review): of the four returned values only `checkpoint_names`
            # is read later in this function. The checkpoint path always works
            # on `root_block`, not on `block` / `target_grad_block`.
            (
                program_stat,
                checkpoint_names,
                vars_should_be_hold,
                recompute_segments,
            ) = _append_backward_ops_with_checkpoints_(
                root_block,
                op_path,
                [loss],
                root_block,
                no_grad_dict,
                grad_to_var,
                checkpoints,
                grad_op_id_to_fwd_op,
            )
        else:
            _append_backward_ops_(
                block,  # the block where forward ops are in
                op_path,
                [loss],
                target_grad_block,
                no_grad_dict,
                grad_to_var,
                callbacks,
                input_grad_names_set=input_grad_names_set,
                op_path_dict=op_path_dict,
                distop_context=distop_context,
                grad_op_id_to_fwd_op=grad_op_id_to_fwd_op,
            )

    grad_info_map = {}

    # if in control flow, target_grad_block is a created new block which only contains grad ops,
    # so fwd_op_num is set to 0.
    fwd_op_num = (
        block_fwd_op_num_dict[current_block_idx]
        if not is_in_control_flow
        else 0
    )

    # Because append_backward may be called multiple times,
    # we need rename the internal gradient variables so that they have
    # different names.
    _rename_grad_(target_grad_block, fwd_op_num, grad_to_var, {}, [])

    _append_backward_vars_(
        target_grad_block, fwd_op_num, grad_to_var, grad_info_map
    )

    # Restore the current block index saved at the start of this function.
    program.current_block_idx = current_block_idx
    program._sync_with_cpp()

    # for cuda graph, copy the cuda graph attr from forward op to backward op
    for op in target_grad_block.ops:
        if grad_op_id_to_fwd_op.get(op.desc.original_id(), None) is not None:
            fwd_op = grad_op_id_to_fwd_op[op.desc.original_id()]
            op._cuda_graph_attr = fwd_op._cuda_graph_attr

    # Resolve the parameter names whose gradients should be returned.
    if parameter_list is not None:
        check_type(
            parameter_list,
            'parameter_list',
            (list, tuple, set),
            'base.backward.append_backward',
        )
        parameters = []
        for i, param in enumerate(parameter_list):
            check_type(
                param,
                'parameter_list[%s]' % i,
                (framework.Variable, str),
                'base.backward.append_backward',
            )
            if isinstance(param, framework.Variable):
                parameters.append(param.name)
            elif isinstance(param, str):
                parameters.append(param)
    else:
        params = program.global_block().all_parameters()
        parameters = [param.name for param in params if param.trainable]

    params_and_grads = []
    op_role_var_attr_name = core.op_proto_and_checker_maker.kOpRoleVarAttrName()
    for param in parameters:
        # Parameters with no entry in grad_info_map are left out of the result.
        if param not in grad_info_map:
            continue
        grad_info = grad_info_map[param]
        grad_block = grad_info[1]
        if not grad_block.has_var(grad_info[0]):
            raise ValueError(
                f"grad block[{grad_info[1]}] did not have grad var {grad_info[0]}"
            )
        # Get the param var from the global block
        param_var = program.global_block().var(param)
        grad_var = grad_block.var(grad_info[0])
        if not is_in_control_flow:
            # Outside control flow the gradient is only reported when the
            # loss block itself has the grad var; otherwise it is None.
            if loss.block.has_var(grad_info[0]):
                params_and_grads.append((param_var, grad_var))
            else:
                params_and_grads.append((param_var, None))
        else:
            params_and_grads.append((param_var, grad_var))

    for p, g in params_and_grads:
        if g is None:
            continue
        # NOTE(review): `grad_block` here is the value left over from the last
        # iteration of the loop above, not one looked up for this pair.
        ops = (
            grad_block.ops if is_in_control_flow else program.global_block().ops
        )
        # Search backwards so the last op writing the gradient is picked.
        for op in reversed(ops):
            assert isinstance(op, framework.Operator)
            if g.name in op.output_arg_names:
                g.op = op
                break

        if g.op is None:
            raise ValueError("Unexpected branch")
        # Put this (param, grad) pair in front of any names the op already has
        # under the op-role-var attribute.
        attr_val = [p.name, g.name]
        if g.op.has_attr(op_role_var_attr_name):
            attr_val.extend(g.op.attr(op_role_var_attr_name))
        g.op._set_attr(op_role_var_attr_name, attr_val)

    if is_recompute:
        return params_and_grads, checkpoint_names
    else:
        return params_and_grads
  2055. def _as_list(x):
  2056. if x is None:
  2057. return []
  2058. return list(x) if isinstance(x, Sequence) else [x]
  2059. def _is_ancestor_block(ancestor_block, block):
  2060. prog = block.program
  2061. ancestor_idx = ancestor_block.idx
  2062. parent_idx = block.parent_idx
  2063. while parent_idx != -1:
  2064. if parent_idx == ancestor_idx:
  2065. return True
  2066. parent_idx = prog.block(parent_idx).parent_idx
  2067. return False
  2068. def _get_output_names(cur_block, targets):
  2069. """
  2070. In `cur_block`, get output names those linked to targets.
  2071. NOTE:
  2072. 1. `targets` can be in `cur_block`;
  2073. Usually, `targets` is in `cur_block`. However, considering control flow,
  2074. 2. `targets` may be in sub-block but `cur_block` is an ancestor of `targets[0].block`;
  2075. 3. `targets` may be in the block which is ancestor of `cur_block`.
  2076. """
  2077. block = targets[0].block if targets else cur_block
  2078. current_output_names = {out.name for out in targets}
  2079. # 1. If `targets` in cur_block or the ancestral block of `cur_block`
  2080. if block.idx == cur_block.idx or _is_ancestor_block(block, cur_block):
  2081. return current_output_names
  2082. # 2. If `cur_block` is an ancestor of `targets[0].block`, run while loop
  2083. prog = cur_block.program
  2084. while block.idx != cur_block.idx:
  2085. assert block.parent_idx != -1
  2086. parent_block = prog.block(block.parent_idx)
  2087. parent_block_output_names = set()
  2088. for op in reversed(block.ops):
  2089. if _some_in_set_(op.desc.output_arg_names(), current_output_names):
  2090. for name in op.desc.input_arg_names():
  2091. current_output_names.add(name)
  2092. if not block.desc.find_var(
  2093. name.encode()
  2094. ) and parent_block.desc.find_var(name.encode()):
  2095. parent_block_output_names.add(name)
  2096. block = parent_block
  2097. current_output_names = parent_block_output_names
  2098. return current_output_names
  2099. def _find_no_grad_vars(block, op_path, targets, no_grad_set):
  2100. """
  2101. Find the vars which is not used in the program, and
  2102. those vars belong to no_grad_var.
  2103. """
  2104. output_names = _get_output_names(block, targets)
  2105. no_grad_var = []
  2106. for i, op in reversed(list(enumerate(op_path))):
  2107. # If the op has sub_block, it is too complicated to find the correct no_grad_var.
  2108. if not op.has_attr("sub_block"):
  2109. for out_var in op.desc.output_arg_names():
  2110. if (
  2111. out_var not in output_names
  2112. and out_var not in op.desc.input_arg_names()
  2113. and not block.vars[out_var].stop_gradient
  2114. ):
  2115. no_grad_var.append(out_var)
  2116. for name in op.desc.input_arg_names():
  2117. if name not in no_grad_set:
  2118. output_names.add(name)
  2119. return set(no_grad_var)
  2120. def _find_op_path_(
  2121. block, targets, inputs, no_grad_set, op_path_dict=None, is_while=False
  2122. ):
  2123. """
  2124. It is used to find the grad path in `block`.
  2125. Args:
  2126. block(Block): The block in which to get op path.
  2127. targets(list[Variable]): The target variables.
  2128. inputs(list[Variable]): The input variables.
  2129. no_grad_set(set): The set of no grad var name. no_grad_set will be changed.
  2130. op_path_dict(dict): op_path_dict will be changed. op_path_dict will be changed.
  2131. key(int) block index
  2132. val(list) the op path of block(index)
  2133. is_while(bool): Whether or not `block` is while block
  2134. Return:
  2135. The forward op path of block corresponding to backward op.
  2136. """
  2137. input_names = {inp.name for inp in inputs}
  2138. output_names = _get_output_names(block, targets)
  2139. if op_path_dict is None:
  2140. op_path_dict = {}
  2141. relevant_op_flags = [True] * len(block.ops)
  2142. # All the inputs of the block are used if inputs is empty,
  2143. if inputs:
  2144. for i, op in enumerate(block.ops):
  2145. if _some_in_set_(
  2146. op.desc.input_arg_names(), input_names
  2147. ) and not core.has_empty_grad_op_maker(op.type):
  2148. for name in op.desc.output_arg_names():
  2149. if name not in no_grad_set:
  2150. input_names.add(name)
  2151. else:
  2152. relevant_op_flags[i] = False
  2153. for i, op in reversed(list(enumerate(block.ops))):
  2154. if op.has_attr("sub_block"):
  2155. sub_block_id = op._block_attr_id("sub_block")
  2156. sub_block = block.program.block(sub_block_id)
  2157. sub_block_target_names = output_names & set(op.output_arg_names)
  2158. sub_block_path = _get_sub_block_path(
  2159. sub_block, op, set(), op_path_dict, sub_block_target_names
  2160. )
  2161. op_path_dict[sub_block_id] = sub_block_path
  2162. if _some_in_set_(
  2163. op.desc.output_arg_names(), output_names
  2164. ) and not core.has_empty_grad_op_maker(op.type):
  2165. for name in op.desc.input_arg_names():
  2166. if name not in no_grad_set:
  2167. output_names.add(name)
  2168. else:
  2169. relevant_op_flags[i] = False
  2170. if is_while:
  2171. # If block is while block, dealing with op specifically again.
  2172. # TODO(liym27): Consider special types of ops.
  2173. for i, op in reversed(list(enumerate(block.ops))):
  2174. if relevant_op_flags[i] is False and _some_in_set_(
  2175. op.desc.output_arg_names(), output_names
  2176. ):
  2177. relevant_op_flags[i] = True
  2178. if not core.has_empty_grad_op_maker(op.type):
  2179. for name in op.desc.input_arg_names():
  2180. if name not in no_grad_set:
  2181. output_names.add(name)
  2182. op_path = [
  2183. block.ops[i] for i in range(len(block.ops)) if relevant_op_flags[i]
  2184. ]
  2185. if inputs:
  2186. for op in op_path:
  2187. for name in op.desc.input_arg_names():
  2188. if name not in input_names and block.vars[name].stop_gradient:
  2189. no_grad_set.add(name)
  2190. return op_path
  2191. def calc_gradient_helper(
  2192. targets, inputs, target_gradients=None, no_grad_set=None
  2193. ):
  2194. '''
  2195. Calculate gradient and return grad_info_map
  2196. '''
  2197. targets = _as_list(targets)
  2198. inputs = _as_list(inputs)
  2199. target_gradients = _as_list(target_gradients)
  2200. block = targets[0].block
  2201. prog = block.program
  2202. # increase appending gradients times
  2203. prog._appending_grad_times += 1
  2204. block_idx = block.idx
  2205. if not target_gradients:
  2206. target_gradients = [None] * len(targets)
  2207. if len(targets) != len(target_gradients):
  2208. raise ValueError(
  2209. "Should have the same number of target_gradients as targets"
  2210. )
  2211. if no_grad_set is None:
  2212. no_grad_set = set()
  2213. else:
  2214. no_grad_set = _get_no_grad_set_name(copy.copy(no_grad_set))
  2215. no_grad_dict = _get_stop_gradients_(prog)
  2216. no_grad_dict[0].update(list(map(_append_grad_suffix_, no_grad_set)))
  2217. fwd_op_num = block.desc.op_size()
  2218. input_grad_names_set = set()
  2219. target_grad_map = {}
  2220. rename_var_map = {}
  2221. skip_rename_var_list = []
  2222. grad_name_set = set()
  2223. for i, grad in enumerate(target_gradients):
  2224. target = targets[i]
  2225. grad_name = _append_grad_suffix_(target.name)
  2226. if grad is None:
  2227. op_desc = _create_op_desc_(
  2228. "fill_any_like",
  2229. {"X": [target.name]},
  2230. {"Out": [grad_name]},
  2231. {
  2232. "value": 1.0,
  2233. "dtype": target.dtype,
  2234. },
  2235. )
  2236. block.desc.append_op().copy_from(op_desc)
  2237. block.program._sync_with_cpp()
  2238. input_grad_names_set.add(grad_name)
  2239. skip_rename_var_list.append(grad_name)
  2240. else:
  2241. if target.block.idx != block_idx or target.block.program != prog:
  2242. raise ValueError("all targets must be in the same block")
  2243. if target.shape != grad.shape:
  2244. raise ValueError(
  2245. f"The shapes of target and grad are different: {target.name} {grad.name}"
  2246. )
  2247. target_grad_map[_append_grad_suffix_(target.name)] = grad.name
  2248. input_grad_names_set.add(grad.name)
  2249. rename_var_map[grad_name] = grad.name
  2250. grad_name_set.add(grad_name)
  2251. if core._is_bwd_prim_enabled():
  2252. core._set_prim_target_grad_name(target_grad_map)
  2253. # For double backward, input_grad_names is used for filter
  2254. # some non-used gradients op. rename_var_map is used to
  2255. # associate target_grad var name with first grad_op input name.
  2256. if prog._appending_grad_times == 1:
  2257. input_grad_names_set = None
  2258. rename_var_map = {}
  2259. for input in inputs:
  2260. if input.block.program != prog:
  2261. raise ValueError("input must be in the same program as targets")
  2262. block_no_grad_set = set(map(_strip_grad_suffix_, no_grad_dict[0]))
  2263. op_path_dict = {}
  2264. op_path = _find_op_path_(
  2265. block, targets, inputs, block_no_grad_set, op_path_dict
  2266. )
  2267. # only for composite to add grad_var of the last forward op
  2268. # who has more than one output, but targets only has one,
  2269. # so targets_gradients only add one grad_var,
  2270. # eg: op1 -> op2 -> var1 / var2 targets = var1,
  2271. # targets_gradients = var1_grad, need to add var2_grad here.
  2272. tmp_targets = targets
  2273. if core._is_bwd_prim_enabled():
  2274. for op in reversed(block.ops):
  2275. if op.type == "fill_any_like":
  2276. continue
  2277. # Some outputs of composite op are not needed and will be removed.
  2278. # Thus, those vars should not be added with another op.
  2279. keep_var_list = []
  2280. if op.type in core.ops_contain_none.keys():
  2281. values = core.ops_contain_none[op.type]
  2282. if isinstance(values, list):
  2283. none_vars = values
  2284. else:
  2285. none_vars = values(op)
  2286. for none_var_name in none_vars:
  2287. keep_var_list.append(op.output(none_var_name)[0])
  2288. for var_name in op.desc.output_arg_names():
  2289. if keep_var_list and (var_name in keep_var_list):
  2290. continue
  2291. grad_var_name = _append_grad_suffix_(var_name)
  2292. if grad_var_name not in grad_name_set:
  2293. op_desc = _create_op_desc_(
  2294. "fill_any_like",
  2295. {"X": [var_name]},
  2296. {"Out": [grad_var_name]},
  2297. {'value': 0, 'dtype': targets[0].dtype},
  2298. )
  2299. block.desc.append_op().copy_from(op_desc)
  2300. tmp_targets.append(block.var(var_name))
  2301. break
  2302. block.program._sync_with_cpp()
  2303. # find no grad var by op_path
  2304. no_grad_set = _find_no_grad_vars(
  2305. block, op_path, tmp_targets, block_no_grad_set
  2306. )
  2307. block_no_grad_set.update(no_grad_set)
  2308. no_grad_dict[0].update(list(map(_append_grad_suffix_, block_no_grad_set)))
  2309. grad_to_var = {}
  2310. grad_info_map = {}
  2311. _append_backward_ops_(
  2312. block,
  2313. op_path,
  2314. targets,
  2315. block,
  2316. no_grad_dict,
  2317. grad_to_var,
  2318. input_grad_names_set=input_grad_names_set,
  2319. op_path_dict=op_path_dict,
  2320. rename_var_map=rename_var_map,
  2321. )
  2322. # Because calc_gradient may be called multiple times,
  2323. # we need rename the internal gradient variables so that they have
  2324. # different names.
  2325. _rename_grad_(
  2326. block, fwd_op_num, grad_to_var, target_grad_map, skip_rename_var_list
  2327. )
  2328. _append_backward_vars_(block, fwd_op_num, grad_to_var, grad_info_map)
  2329. prog._sync_with_cpp()
  2330. return grad_info_map
  2331. def _get_grad_vars(grad_info_map, inputs):
  2332. inputs = _as_list(inputs)
  2333. grad_vars = []
  2334. for input_var in inputs:
  2335. if input_var.name not in grad_info_map:
  2336. grad_vars.append(None)
  2337. else:
  2338. grad_info = grad_info_map[input_var.name]
  2339. grad_block = grad_info[1]
  2340. grad_var = grad_block.var(grad_info[0])
  2341. grad_vars.append(grad_var)
  2342. return grad_vars
  2343. def calc_gradient(targets, inputs, target_gradients=None, no_grad_set=None):
  2344. """
  2345. Backpropagate the gradients of targets to inputs.
  2346. Args:
  2347. targets(Tensor|list[Tensor]|tuple[Tensor]): The target Tensors
  2348. inputs(Tensor|list[Tensor]|tuple[Tensor]): The input Tensors
  2349. target_gradients (Tensor|list[Tensor]|tuple[Tensor], optional): The gradient Tensors
  2350. of targets which has the same shape with targets, If None, ones will
  2351. be created for them.
  2352. no_grad_set(set[Tensor|str], optional): Set of Tensors or Tensor.names in the :ref:`api_guide_Block_en` 0 whose gradients
  2353. should be ignored. All Tensors with
  2354. `stop_gradient=True` from all blocks will
  2355. be automatically added into this set.
  2356. If this parameter is not None, the Tensors or Tensor.names in this set will be added to the default set.
  2357. Default: None.
  2358. Return:
  2359. (list[Tensor]): A list of gradients for inputs
  2360. If an input does not affect targets, the corresponding gradient Tensor
  2361. will be None
  2362. """
  2363. # NOTE: If you want to modify the logic of calc_gradient, please modify
  2364. # it inside the calc_gradient_helper and _get_grad_vars functions
  2365. # to ensure the correctness of dy2st mode.
  2366. grad_info_map = calc_gradient_helper(
  2367. targets,
  2368. inputs,
  2369. target_gradients=target_gradients,
  2370. no_grad_set=no_grad_set,
  2371. )
  2372. grad_vars = _get_grad_vars(grad_info_map, inputs)
  2373. if len(grad_vars) == 1:
  2374. return grad_vars[0]
  2375. else:
  2376. return grad_vars
  2377. @framework.static_only
  2378. def gradients(targets, inputs, target_gradients=None, no_grad_set=None):
  2379. """
  2380. Backpropagate the gradients of targets to inputs.
  2381. Args:
  2382. targets (Tensor|list[Tensor]|tuple[Tensor]): The target Tensors.
  2383. inputs (Tensor|list[Tensor]|tuple[Tensor]): The input Tensors.
  2384. target_gradients (Tensor|list[Tensor]|tuple[Tensor], optional): The gradient Tensor
  2385. of targets which has the same shape with targets, If None, ones will
  2386. be created for them.
  2387. no_grad_set (set[Tensor|str], optional): Set of Tensors or Tensor.names in the :ref:`api_guide_Block_en` 0 whose gradients
  2388. should be ignored. All Tensors with ``stop_gradient=True`` from all blocks will
  2389. be automatically added into this set. If this parameter is not None, the Tensors or Tensor.names
  2390. in this set will be added to the default set. Default: None.
  2391. Return:
  2392. (list[Tensor]): A list of gradients for inputs
  2393. If an input does not affect targets, the corresponding gradient Tensor
  2394. will be None.
  2395. Examples:
  2396. .. code-block:: python
  2397. >>> import paddle
  2398. >>> import paddle.nn.functional as F
  2399. >>> paddle.enable_static()
  2400. >>> x = paddle.static.data(name='x', shape=[None, 2, 8, 8], dtype='float32')
  2401. >>> x.stop_gradient=False
  2402. >>> y = paddle.static.nn.conv2d(x, 4, 1, bias_attr=False)
  2403. >>> y = F.relu(y)
  2404. >>> z = paddle.static.gradients([y], x)
  2405. >>> print(z)
  2406. [var x@GRAD : LOD_TENSOR.shape(-1, 2, 8, 8).dtype(float32).stop_gradient(False)]
  2407. """
  2408. if framework.in_pir_mode():
  2409. check_type(
  2410. targets,
  2411. 'targets',
  2412. (paddle.pir.Value, list, tuple),
  2413. 'paddle.autograd.ir_backward.grad',
  2414. )
  2415. check_type(
  2416. inputs,
  2417. 'inputs',
  2418. (paddle.pir.Value, list, tuple),
  2419. 'paddle.autograd.ir_backward.grad',
  2420. )
  2421. check_type(
  2422. target_gradients,
  2423. 'target_gradients',
  2424. (paddle.pir.Value, list, tuple, type(None)),
  2425. 'paddle.autograd.ir_backward.grad',
  2426. )
  2427. check_type(
  2428. no_grad_set,
  2429. 'no_grad_set',
  2430. (
  2431. paddle.pir.Value,
  2432. list,
  2433. tuple,
  2434. set,
  2435. type(None),
  2436. ),
  2437. 'paddle.autograd.ir_backward.grad',
  2438. )
  2439. targets = _as_list(targets)
  2440. inputs = _as_list(inputs)
  2441. target_gradients = _as_list(target_gradients)
  2442. from paddle.autograd.backward_utils import ValueSet
  2443. from paddle.autograd.ir_backward import (
  2444. calc_gradient as pir_calc_gradient,
  2445. )
  2446. if no_grad_set is None:
  2447. no_grad_set = ValueSet()
  2448. else:
  2449. no_grad_set = ValueSet(no_grad_set)
  2450. input_grad = pir_calc_gradient(
  2451. targets, inputs, target_gradients, no_grad_set
  2452. )
  2453. return input_grad
  2454. check_type(
  2455. targets,
  2456. 'targets',
  2457. (framework.Variable, list, tuple),
  2458. 'paddle.static.gradients',
  2459. )
  2460. check_type(
  2461. inputs,
  2462. 'inputs',
  2463. (framework.Variable, list, tuple),
  2464. 'paddle.static.gradients',
  2465. )
  2466. check_type(
  2467. target_gradients,
  2468. 'target_gradients',
  2469. (framework.Variable, list, tuple, type(None)),
  2470. 'paddle.static.gradients',
  2471. )
  2472. outs = calc_gradient(targets, inputs, target_gradients, no_grad_set)
  2473. return _as_list(outs)
  2474. @framework.static_only
  2475. def gradients_with_optimizer(program, optimizer, inputs=None, outputs=None):
  2476. """
  2477. :api_attr: Static Graph
  2478. Backpropagate the gradients of the program and apply the gradients with the given optimizer.
  2479. Args:
  2480. program (Program): The input program.
  2481. optimizer (Optimizer): The optimizer to apply the gradients.
  2482. inputs (Tensor|list[Tensor]|tuple[Tensor], optional): The input Tensors.
  2483. If None, the inputs will be created from the input variables in the given program. Default:None.
  2484. outputs (Tensor|list[Tensor]|tuple[Tensor], optional): The output Tensors.
  2485. If None, the outputs will be created from the output variables in the given program. Default: None.
  2486. Return:
  2487. tuple: tuple (optimize_ops, params_grads), A list of operators appended
  2488. by gradients_with_optimizer and a list of (param, grad) variable pairs, param is
  2489. ``Parameter``, grad is the gradient value corresponding to the parameter.
  2490. The returned tuple can be passed to ``fetch_list`` in ``Executor.run()`` to
  2491. indicate program pruning. If so, the program will be pruned by ``feed`` and
  2492. ``fetch_list`` before run, see details in ``Executor``.
  2493. Examples:
  2494. .. code-block:: python
  2495. >>> import paddle
  2496. >>> import paddle.static as static
  2497. >>> paddle.enable_static()
  2498. >>> img = static.data(name='image', shape=[None, 784])
  2499. >>> pred = static.nn.fc(x=img, size=10, activation='relu')
  2500. >>> loss = paddle.mean(pred)
  2501. >>> opt = paddle.optimizer.SGD(learning_rate=1e-3)
  2502. >>> opt_ops, pram_grads = paddle.base.backward.gradients_with_optimizer(static.default_main_program(), opt)
  2503. >>> print(opt_ops)
  2504. [{ParamOut=['fc_0.b_0']} = sgd(inputs={Grad=['fc_0.b_0@GRAD'],
  2505. LearningRate=['learning_rate_0'],
  2506. MasterParam=[],
  2507. ...
  2508. with_quant_attr = False)]
  2509. """
  2510. check_type(
  2511. program,
  2512. 'program',
  2513. paddle.base.Program,
  2514. 'paddle.static.gradients_with_optimizer',
  2515. )
  2516. check_type(
  2517. optimizer,
  2518. 'optimizer',
  2519. paddle.optimizer.Optimizer,
  2520. 'paddle.static.gradients_with_optimizer',
  2521. )
  2522. if inputs is None or outputs is None:
  2523. in_set = set()
  2524. out_set = set()
  2525. for block in program.blocks:
  2526. for op in block.ops:
  2527. for name in op.input_arg_names:
  2528. in_set.add(block.vars[name])
  2529. for name in op.output_arg_names:
  2530. out_set.add(block.vars[name])
  2531. if inputs is None:
  2532. inputs = list(in_set.difference(out_set))
  2533. if outputs is None:
  2534. outputs = list(out_set.difference(in_set))
  2535. grads = gradients(outputs, inputs)
  2536. with program_guard(program, None):
  2537. pram_grads = [
  2538. (pram, grad)
  2539. for pram, grad in zip(inputs, grads)
  2540. if isinstance(pram, paddle.base.framework.Parameter)
  2541. and grad is not None
  2542. ]
  2543. optimize_ops = optimizer.apply_gradients(pram_grads)
  2544. return optimize_ops, pram_grads