optimizer.py 80 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156115711581159116011611162116311641165116611671168116911701171117211731174117511761177117811791180118111821183118411851186118711881189119011911192119311941195119611971198119912001201120212031204120512061207120812091210121112121213121412151216121712181219122012211222122312241225122612271228122912301231123212331234123512361237123812391240124112421243124412451246124712481249125012511252125312541255125612571258125912601261126212631264126512661267126812691270127112721273127412751276127
71278127912801281128212831284128512861287128812891290129112921293129412951296129712981299130013011302130313041305130613071308130913101311131213131314131513161317131813191320132113221323132413251326132713281329133013311332133313341335133613371338133913401341134213431344134513461347134813491350135113521353135413551356135713581359136013611362136313641365136613671368136913701371137213731374137513761377137813791380138113821383138413851386138713881389139013911392139313941395139613971398139914001401140214031404140514061407140814091410141114121413141414151416141714181419142014211422142314241425142614271428142914301431143214331434143514361437143814391440144114421443144414451446144714481449145014511452145314541455145614571458145914601461146214631464146514661467146814691470147114721473147414751476147714781479148014811482148314841485148614871488148914901491149214931494149514961497149814991500150115021503150415051506150715081509151015111512151315141515151615171518151915201521152215231524152515261527152815291530153115321533153415351536153715381539154015411542154315441545154615471548154915501551155215531554155515561557155815591560156115621563156415651566156715681569157015711572157315741575157615771578157915801581158215831584158515861587158815891590159115921593159415951596159715981599160016011602160316041605160616071608160916101611161216131614161516161617161816191620162116221623162416251626162716281629163016311632163316341635163616371638163916401641164216431644164516461647164816491650165116521653165416551656165716581659166016611662166316641665166616671668166916701671167216731674167516761677167816791680168116821683168416851686168716881689169016911692169316941695169616971698169917001701170217031704170517061707170817091710171117121713171417151716171717181719172017211722172317241725172617271728172917301731173217331734173517361737173817391740174117421743174417451746174717481749175017511752175317541755175617571758175917601761176217631764176517661767176817691770177117721773177417751776177
7177817791780178117821783178417851786178717881789179017911792179317941795179617971798179918001801180218031804180518061807180818091810181118121813181418151816181718181819182018211822182318241825182618271828182918301831183218331834183518361837183818391840184118421843184418451846184718481849185018511852185318541855185618571858185918601861186218631864186518661867186818691870187118721873187418751876187718781879188018811882188318841885188618871888188918901891189218931894189518961897189818991900190119021903190419051906190719081909191019111912191319141915191619171918191919201921192219231924192519261927192819291930193119321933193419351936193719381939194019411942194319441945194619471948194919501951195219531954195519561957195819591960196119621963196419651966196719681969197019711972197319741975
  1. # Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. import logging
  15. import os
  16. from collections import defaultdict
  17. import numpy as np
  18. import paddle
  19. import paddle.autograd as imperative_base
  20. from paddle import _C_ops
  21. from paddle._pir_ops import parameter, set_parameter
  22. from paddle.autograd.backward_utils import ValueDict
  23. from paddle.base import core
  24. from paddle.base.framework import (
  25. Variable,
  26. _current_expected_place,
  27. default_main_program,
  28. device_guard,
  29. in_dygraph_mode,
  30. in_dynamic_or_pir_mode,
  31. in_pir_mode,
  32. name_scope,
  33. )
  34. from paddle.regularizer import L2Decay
  35. from ..base import framework, unique_name
  36. from ..base.backward import (
  37. _get_no_grad_set_name,
  38. _get_no_grad_set_value,
  39. append_backward,
  40. )
  41. from ..base.framework import Parameter
  42. from ..base.layer_helper import LayerHelper
  43. from .lr import LRScheduler
  44. __all__ = []
  45. g_shard_bypass_dygraph_optimizer = int(
  46. os.environ.get("FLAGS_shard_bypass_dygraph_optimizer", 0)
  47. )
  48. @framework.static_only
  49. def append_backward_new(
  50. loss_list,
  51. parameter_list=None,
  52. no_grad_set=None,
  53. callbacks=None,
  54. checkpoints=None,
  55. distop_context=None,
  56. ):
  57. from paddle.incubate.autograd.primx import Transform, orig2prim
  58. program = default_main_program()
  59. assert (
  60. program.num_blocks == 1
  61. ), "The append_backward_new interface is designed to process only one block."
  62. block = program.current_block()
  63. for el in loss_list:
  64. assert (
  65. el.block == block
  66. ), 'variable in loss_list should be in current block of main program'
  67. orig2prim(block)
  68. ad = Transform(block)
  69. if parameter_list is None:
  70. parameter_list = program.global_block().all_parameters()
  71. param_dot, loss_dot = ad.linearize(parameter_list, loss_list)
  72. loss_bar, param_bar = ad.transpose(loss_dot, param_dot)
  73. # remove param_dot and their constructor ops
  74. op_indexes = []
  75. for var in param_dot:
  76. if var is not None:
  77. op_index = block.ops.index(var.op)
  78. assert op_index >= 0
  79. op_indexes.append(op_index)
  80. ad.erase_ops(sorted(op_indexes))
  81. ad.erase_dots(param_dot)
  82. if len(parameter_list) == 1:
  83. params_and_grads = [(parameter_list, param_bar)]
  84. else:
  85. params_and_grads = []
  86. for i, param in enumerate(parameter_list):
  87. params_and_grads.append((param, param_bar[i]))
  88. return params_and_grads
class Optimizer:
    r"""Optimizer Base class.

    Define the common interface of an optimizer.
    User should not use this class directly,
    but need to use one of its implementations.

    Args:
        learning_rate (float|LRScheduler): The learning rate used to update ``Parameter``.
            It can be a float value or any subclass of ``LRScheduler`` .
        parameters (list|tuple, optional): List/Tuple of ``Tensor`` names to update to minimize ``loss``. \
            This parameter is required in dygraph mode. And you can specify different options for \
            different parameter groups such as the learning rate, weight decay, etc, \
            then the parameters are list of dict. Note that the learning_rate in parameter groups \
            represents the scale of base learning_rate. \
            The default value is None in static graph mode, at this time all parameters will be updated.
        weight_decay (float|WeightDecayRegularizer, optional): The strategy of regularization. \
            It can be a float value as coeff of L2 regularization or \
            :ref:`api_paddle_regularizer_L1Decay`, :ref:`api_paddle_regularizer_L2Decay`.
            If a parameter has set regularizer using :ref:`api_paddle_ParamAttr` already, \
            the regularization setting here in optimizer will be ignored for this parameter. \
            Otherwise, the regularization setting here in optimizer will take effect. \
            Default None, meaning there is no regularization.
        grad_clip (GradientClipBase, optional): Gradient clipping strategy, it's an instance of \
            some derived class of ``GradientClipBase`` . There are three clipping strategies \
            ( :ref:`api_paddle_nn_ClipGradByGlobalNorm` , :ref:`api_paddle_nn_ClipGradByNorm` , \
            :ref:`api_paddle_nn_ClipGradByValue` ). Default None, meaning there is no gradient clipping.
        name (str, optional): Normally there is no need for user to set this property.
            For more information, please refer to :ref:`api_guide_Name`.
            The default value is None.

    Returns:
       Base class for optimizer.

    Examples:
        .. code-block:: python

            >>> # Take the subclass adam as an example
            >>> import paddle
            >>> linear = paddle.nn.Linear(10, 10)
            >>> inp = paddle.uniform(shape=[10, 10], min=-0.1, max=0.1)
            >>> out = linear(inp)
            >>> loss = paddle.mean(out)
            >>> adam = paddle.optimizer.Adam(learning_rate=0.1,
            ...         parameters=linear.parameters())
            >>> loss.backward()
            >>> adam.step()
            >>> adam.clear_grad()

            >>> # Take the subclass sgd as an example
            >>> # Optimize parameters in linear_1 and linear_2 with different options.
            >>> # Note that the learning_rate of linear_2 is 0.01.
            >>> linear_1 = paddle.nn.Linear(10, 10)
            >>> linear_2 = paddle.nn.Linear(10, 10)
            >>> inp = paddle.uniform(shape=[10, 10], min=-0.1, max=0.1)
            >>> out = linear_1(inp)
            >>> out = linear_2(out)
            >>> loss = paddle.mean(out)
            >>> sgd = paddle.optimizer.SGD(
            ...     learning_rate=0.1,
            ...     parameters=[{
            ...         'params': linear_1.parameters()
            ...     }, {
            ...         'params': linear_2.parameters(),
            ...         'weight_decay': 0.001,
            ...         'learning_rate': 0.1
            ...     }],
            ...     weight_decay=0.01)
            >>> loss.backward()
            >>> sgd.step()
            >>> sgd.clear_grad()

    """
  155. @imperative_base.no_grad()
  156. def __init__(
  157. self,
  158. learning_rate,
  159. parameters=None,
  160. weight_decay=None,
  161. grad_clip=None,
  162. name=None,
  163. ):
  164. if parameters is not None:
  165. # paddle.Tensor is also iterable, so here we don't check whether
  166. # the input is iterable, if the input is paddle.Tensor, the
  167. # list(paddle.Tensor) will be a error value
  168. if isinstance(parameters, (paddle.Tensor, core.eager.Tensor)):
  169. raise TypeError(
  170. "`parameters` argument given to the optimizer should be "
  171. f"an iterable of paddle Tensors, but got argument type is `{type(parameters)}`."
  172. )
  173. if isinstance(parameters, dict):
  174. raise TypeError(
  175. "`parameters` argument should not get dict type, "
  176. "if parameter groups is needed, please set `parameters`"
  177. " as list of dict"
  178. )
  179. self._parameter_list = list(parameters)
  180. else:
  181. self._parameter_list = None
  182. self._name = name
  183. if framework.in_dygraph_mode():
  184. if self._parameter_list is None:
  185. raise AttributeError(
  186. "parameters argument given to the Optimizer should not be None in dygraph mode."
  187. )
  188. if weight_decay is not None:
  189. if not isinstance(self._parameter_list[0], dict):
  190. for param in self._parameter_list:
  191. if (
  192. hasattr(param, 'regularizer')
  193. and param.regularizer is not None
  194. ):
  195. logging.info(
  196. "If regularizer of a Parameter has been set by 'paddle.ParamAttr' or 'static.WeightNormParamAttr' already. "
  197. "The weight_decay[%s] in Optimizer will not take effect, and it will only be applied to other Parameters!"
  198. % weight_decay.__str__()
  199. )
  200. break
  201. if not isinstance(learning_rate, (float, LRScheduler)):
  202. raise TypeError(
  203. "learning rate should be float or LRScheduler, got %s here"
  204. % type(learning_rate)
  205. )
  206. if grad_clip is not None:
  207. if not isinstance(grad_clip, paddle.nn.clip.GradientClipBase):
  208. raise TypeError(
  209. "'grad_clip' should be an instance of GradientClipBase's derived class"
  210. )
  211. if isinstance(weight_decay, float):
  212. self.regularization = L2Decay(weight_decay)
  213. else:
  214. self.regularization = weight_decay
  215. self._grad_clip = grad_clip
  216. self._learning_rate = learning_rate
  217. self._dtype = None
  218. # Infer the dtype form parameter
  219. if self._parameter_list:
  220. if isinstance(self._parameter_list[0], dict):
  221. for param_group in self._parameter_list:
  222. assert (
  223. 'params' in param_group
  224. ), 'params should be set in parameters if parameter groups are optimized in different options'
  225. self._dtype = self._parameter_list[0]['params'][0].dtype
  226. else:
  227. self._dtype = self._parameter_list[0].dtype
  228. # each program should have a independent learning rate
  229. # program -> tensor(learning_rate)
  230. self._learning_rate_map = {}
  231. # Dictionary of accumulators. Some optimizer subclasses need to
  232. # allocate and manage extra tensors associated with the parameters
  233. # to train. These tensors are called accumulators.
  234. # {accum_name : { paramter_name : accumulator_for_parameter, ...}, ...}
  235. self._accumulators = defaultdict(lambda: {})
  236. self.helper = None
  237. self._opti_name_list = []
  238. self._accumulators_holder = {}
  239. self._param_device_map = {}
  240. self.clear_gradients = self.clear_grad
  241. self._default_dict = {
  242. 'weight_decay': self.regularization,
  243. 'grad_clip': self._grad_clip,
  244. }
  245. self._param_groups = []
  246. if self._parameter_list and isinstance(self._parameter_list[0], dict):
  247. for param_group in self._parameter_list:
  248. self._add_param_group(param_group.copy())
  249. else:
  250. self._param_groups = self._parameter_list
  251. # NOTE: Multi Tensor: Pass in all parameters and gradients to the op kernel of the Optimizer at one time for updating for dygraph mode.
  252. # Optimizer support list: [ paddle.optimizer.Momentum, paddle.optimizer.Adam].
  253. self._use_multi_tensor = None
  254. self._param_dict = self._create_multi_tensor_dict()
  255. self._auxiliary_vars = {}
  256. self._already_create_accumulator = set()
  257. self._master_weights = {}
  258. # create master gradients' states
  259. self._create_master_grad_states()
  260. def _create_master_grad_states(self):
  261. # master gradients states
  262. if in_pir_mode():
  263. self._master_grads = ValueDict()
  264. else:
  265. self._master_grads = {}
  266. self._master_grad = False
  267. def _set_auxiliary_var(self, key, val):
  268. self._auxiliary_vars[key] = val
  269. def _create_multi_tensor_dict(self):
  270. n = len(self._param_groups) if self._param_groups is not None else 1
  271. return {
  272. 'FP32_LODTensor': [[] for _ in range(n)],
  273. 'FP16_LODTensor': [[] for _ in range(n)],
  274. }
  275. def _get_auxiliary_var(self, key):
  276. return self._auxiliary_vars.get(key, None)
  277. @framework.dygraph_only
  278. def state_dict(self):
  279. '''
  280. Get state dict information from optimizer. It contain all the tensor used by optimizer. For Adam optimizer, contains beta1, beta2, momentum etc. If LRScheduler have been used, global_step will be include in state dict.
  281. If the optimizer never be called(minimize function), the state_dict is empty.
  282. Args:
  283. None
  284. Returns:
  285. state_dict(dict) : dict contains all the Tensor used by optimizer
  286. Examples:
  287. .. code-block:: python
  288. >>> import paddle
  289. >>> emb = paddle.nn.Embedding(10, 10)
  290. >>> adam = paddle.optimizer.Adam(0.001, parameters=emb.parameters())
  291. >>> state_dict = adam.state_dict()
  292. '''
  293. state_dict = {}
  294. if len(self._accumulators) == 0 and len(self._accumulators_holder) > 0:
  295. for name, var in self._accumulators_holder.items():
  296. state_dict[name] = var
  297. else:
  298. for k, v in self._accumulators.items():
  299. for para_name, var_tmp in v.items():
  300. state_dict[var_tmp.name] = var_tmp
  301. # save scale value for xpu
  302. if core.is_compiled_with_xpu():
  303. state_dict[
  304. var_tmp.name + ".SCALE_VALUE"
  305. ] = var_tmp.get_tensor().get_xpu_scale_value()
  306. # if has master weight and then save master weight
  307. if hasattr(self, "_master_weights"):
  308. if len(self._master_weights) != 0:
  309. state_dict["master_weights"] = self._master_weights
  310. # global step if use lr decay
  311. if isinstance(self._learning_rate, LRScheduler):
  312. state_dict["LR_Scheduler"] = self._learning_rate.state_dict()
  313. return state_dict
  314. @framework.dygraph_only
  315. def set_state_dict(self, state_dict):
  316. '''
  317. Load optimizer state dict. For Adam optimizer, contains beta1, beta2, momentum etc. If LRScheduler have been used, global_step will be changed.
  318. Args:
  319. state_dict(dict) : Dict contains all the Tensor needed by optimizer
  320. Return:
  321. None
  322. Examples:
  323. .. code-block:: python
  324. >>> import paddle
  325. >>> emb = paddle.nn.Embedding(10, 10)
  326. >>> layer_state_dict = emb.state_dict()
  327. >>> paddle.save(layer_state_dict, "emb.pdparams")
  328. >>> scheduler = paddle.optimizer.lr.NoamDecay(
  329. ... d_model=0.01, warmup_steps=100, verbose=True)
  330. >>> adam = paddle.optimizer.Adam(
  331. ... learning_rate=scheduler,
  332. ... parameters=emb.parameters())
  333. >>> opt_state_dict = adam.state_dict()
  334. >>> paddle.save(opt_state_dict, "adam.pdopt")
  335. >>> opti_state_dict = paddle.load("adam.pdopt")
  336. >>> adam.set_state_dict(opti_state_dict)
  337. '''
  338. if isinstance(self._learning_rate, LRScheduler):
  339. self._learning_rate.set_state_dict(state_dict["LR_Scheduler"])
  340. # NOTE: exclude learning rate scheduler's state from
  341. # _accumulators_holder.
  342. state_dict = state_dict.copy()
  343. if "LR_Scheduler" in state_dict:
  344. state_dict.pop("LR_Scheduler")
  345. if "master_weights" in state_dict:
  346. if hasattr(self, "_master_weights"):
  347. self._master_weights = state_dict["master_weights"]
  348. state_dict.pop("master_weights")
  349. self._accumulators_holder = state_dict
  350. for k, v in self._accumulators.items():
  351. for para_name, var_tmp in v.items():
  352. assert (
  353. var_tmp.name in state_dict
  354. ), f"optimizer Tensor {var_tmp.name} not found"
  355. var = var_tmp.value()
  356. tensor = var.get_tensor()
  357. # load scale value for xpu
  358. if core.is_compiled_with_xpu():
  359. tensor.set_xpu_scale_value(
  360. state_dict.get(var_tmp.name + ".SCALE_VALUE", -1.0)
  361. )
  362. var.set_value(state_dict[var_tmp.name])
  363. def get_opti_var_name_list(self):
  364. return self._opti_name_list
    def _create_global_learning_rate(self):
        """Create the learning-rate variable for the current main program.

        The created variable/value is recorded in ``self._learning_rate_map``
        keyed by the default main program. Four paths are handled: the
        learning rate is an ``LRScheduler`` or a plain ``float``, each in PIR
        mode or on the non-PIR path. The work runs inside
        ``dygraph_guard_if_declarative()``.
        """

        def do_create():
            # lr var can't be float16 or bfloat16, for pure fp16 or bf16 training, should extra handle the dtype for lr
            _lr_dtype = (
                paddle.get_default_dtype()
                if self._dtype is None
                else self._dtype
            )
            # Fall back to float32 when the inferred dtype is fp16/bf16 but
            # the framework default dtype is not that same half type.
            _lr_dtype = (
                paddle.float32
                if (
                    (
                        paddle.get_default_dtype() != "float16"
                        and _lr_dtype == paddle.float16
                    )
                    or (
                        paddle.get_default_dtype() != "bfloat16"
                        and _lr_dtype == paddle.bfloat16
                    )
                )
                else _lr_dtype
            )
            if isinstance(self._learning_rate, LRScheduler):
                lr_var = self._global_learning_rate()
                # only create global lr_var once
                if in_pir_mode():
                    startup_program = paddle.static.default_startup_program()
                    main_program = paddle.static.default_main_program()
                    lr_name = unique_name.generate('learning_rate')
                    # startup program insert && set_parameter
                    # NOTE(review): as reconstructed here, this startup-side
                    # initialization is not guarded by the "already created"
                    # check below — confirm against the upstream file.
                    lr_value = float(self._learning_rate())
                    with paddle.static.program_guard(startup_program):
                        initializer = paddle.nn.initializer.Constant(
                            value=lr_value
                        )
                        paramete_meta = paddle.pir.core.ParameterMeta(
                            [], _lr_dtype
                        )
                        init_result = initializer(
                            paramete_meta, startup_program.global_block()
                        )
                        init_result.persistable = True
                        set_parameter(init_result, lr_name)
                    main_program.set_parameters_from(startup_program)

                    # Only register a main-program parameter when no PIR
                    # value exists yet for this program.
                    if not isinstance(lr_var, paddle.pir.Value):
                        self._learning_rate._var_name = lr_name
                        with paddle.static.program_guard(main_program):
                            param = parameter(lr_name, _lr_dtype, [])
                        param.stop_gradient = True
                        param.persistable = True
                        main_program.lr_scheduler = self._learning_rate
                        main_program.lr_var = param
                        main_program.lr_name = lr_name
                        self._learning_rate_map[main_program] = param

                else:
                    if not isinstance(lr_var, framework.Variable):
                        lr_name = unique_name.generate('learning_rate')
                        self._learning_rate._var_name = lr_name
                        lr_var = self.helper.create_global_variable(
                            name=lr_name,
                            shape=[],
                            persistable=True,
                            stop_gradient=True,
                            dtype=_lr_dtype,
                        )
                        main_prog = framework.default_main_program()
                        main_prog.lr_scheduler = self._learning_rate
                        main_prog.lr_var = lr_var

                        self._learning_rate_map[
                            framework.default_main_program()
                        ] = lr_var

                    # Initialise the variable with the scheduler's current value.
                    lr_value = float(self._learning_rate())

                    self.helper.set_variable_initializer(
                        lr_var,
                        initializer=paddle.nn.initializer.Constant(
                            value=lr_value
                        ),
                    )
            elif isinstance(self._learning_rate, float):
                # only create global lr_var once
                lr = self._global_learning_rate()
                if in_pir_mode():
                    if isinstance(lr, paddle.pir.Value):
                        return
                    else:
                        # NOTE(review): `place` is assigned but not used in
                        # this branch.
                        place = _current_expected_place()
                        # Convert the dtype to a PIR DataType if it is not
                        # one already (from a VarType or a numpy-style dtype).
                        if not isinstance(_lr_dtype, paddle.base.core.DataType):
                            if isinstance(
                                _lr_dtype, paddle.base.libpaddle.VarDesc.VarType
                            ):
                                _lr_dtype = paddle.pir.core.vartype_to_datatype[
                                    _lr_dtype
                                ]
                            else:
                                _lr_dtype = (
                                    paddle.pir.core.convert_np_dtype_to_dtype_(
                                        _lr_dtype
                                    )
                                )
                        self._learning_rate_map[
                            paddle.static.default_main_program()
                        ] = paddle.pir.core.create_persistable_value(
                            dtype=_lr_dtype,
                            shape=[],
                            name=unique_name.generate("learning_rate"),
                            initializer=paddle.nn.initializer.ConstantInitializer(
                                value=float(self._learning_rate)
                            ),
                        )
                else:
                    if isinstance(lr, framework.Variable):
                        return
                    else:
                        self._learning_rate_map[
                            framework.default_main_program()
                        ] = paddle.static.create_global_var(
                            name=unique_name.generate("learning_rate"),
                            shape=[],
                            value=float(self._learning_rate),
                            dtype=_lr_dtype,
                            persistable=True,
                        )

        with paddle.base.framework.dygraph_guard_if_declarative():
            do_create()
  489. @framework.dygraph_only
  490. def set_lr(self, value):
  491. """
  492. :api_attr: imperative
  493. Set the value of the learning rate manually in the optimizer. If the optimizer use LRScheduler,
  494. this API cannot be invoked, because it will lead to conflict.
  495. Args:
  496. value (float): the value of learning rate
  497. Returns:
  498. None
  499. Examples:
  500. .. code-block:: python
  501. >>> import paddle
  502. >>> linear = paddle.nn.Linear(10, 10)
  503. >>> adam = paddle.optimizer.Adam(0.1, parameters=linear.parameters())
  504. >>> # set learning rate manually by python float value
  505. >>> lr_list = [0.2, 0.3, 0.4, 0.5, 0.6]
  506. >>> for i in range(5):
  507. ... adam.set_lr(lr_list[i])
  508. ... lr = adam.get_lr()
  509. ... print("current lr is {}".format(lr))
  510. current lr is 0.2
  511. current lr is 0.3
  512. current lr is 0.4
  513. current lr is 0.5
  514. current lr is 0.6
  515. """
  516. if not isinstance(value, (int, float)):
  517. raise TypeError(
  518. "The type of 'value' in optimizer.set_lr must be float, but received %s."
  519. % (type(value))
  520. )
  521. if isinstance(self._learning_rate, LRScheduler):
  522. raise RuntimeError(
  523. "optimizer's learning rate can't be LRScheduler when invoke this API, because this will lead to conflict."
  524. )
  525. self._learning_rate = float(value)
  526. current_lr = self._global_learning_rate()
  527. if current_lr is not None:
  528. if in_dygraph_mode():
  529. place = _current_expected_place()
  530. _C_ops.full_(
  531. current_lr,
  532. list(current_lr.shape),
  533. float(value),
  534. current_lr.dtype,
  535. place,
  536. )
  537. else:
  538. global_block = framework.default_main_program().global_block()
  539. global_block.append_op(
  540. type='fill_constant',
  541. outputs={'Out': [current_lr]},
  542. attrs={
  543. 'dtype': current_lr.dtype,
  544. 'shape': list(current_lr.shape),
  545. 'value': float(value),
  546. },
  547. stop_gradient=True,
  548. )
  549. @framework.dygraph_only
  550. def set_lr_scheduler(self, scheduler):
  551. """
  552. :api_attr: imperative
  553. Set the LRScheduler of the learning rate manually in the optimizer. If the optimizer already used LRScheduler previously,
  554. this API will set it be the new one.
  555. Args:
  556. scheduler (LRScheduler): the LRScheduler of learning rate
  557. Returns:
  558. None
  559. Examples:
  560. .. code-block:: python
  561. >>> import paddle
  562. >>> linear = paddle.nn.Linear(10, 10)
  563. >>> adam = paddle.optimizer.Adam(0.1, parameters=linear.parameters())
  564. >>> # set learning rate manually by class LRScheduler
  565. >>> scheduler = paddle.optimizer.lr.MultiStepDecay(learning_rate=0.5, milestones=[2,4,6], gamma=0.8)
  566. >>> adam.set_lr_scheduler(scheduler)
  567. >>> lr = adam.get_lr()
  568. >>> print("current lr is {}".format(lr))
  569. current lr is 0.5
  570. >>> # set learning rate manually by another LRScheduler
  571. >>> scheduler = paddle.optimizer.lr.StepDecay(learning_rate=0.1, step_size=5, gamma=0.6)
  572. >>> adam.set_lr_scheduler(scheduler)
  573. >>> lr = adam.get_lr()
  574. >>> print("current lr is {}".format(lr))
  575. current lr is 0.1
  576. """
  577. from paddle.optimizer.lr import LRScheduler
  578. if not isinstance(scheduler, LRScheduler):
  579. raise TypeError(
  580. "The type of 'scheduler' in optimizer.set_lr_schduler must be LRScheduler, but received %s."
  581. % (type(scheduler))
  582. )
  583. self._learning_rate = scheduler
  584. def get_lr(self):
  585. """
  586. Get current learning rate of optimizer.
  587. If 'LRScheduler' is not used, the return value is all the same.
  588. If 'LRScheduler' is used, the return value is the current scheduled learing rete.
  589. Returns:
  590. float: The current learning rate of optimizer.
  591. Examples:
  592. .. code-block:: python
  593. >>> # train on default dynamic graph mode
  594. >>> import paddle
  595. >>> import numpy as np
  596. >>> emb = paddle.nn.Embedding(10, 3)
  597. >>> ## example1: LRScheduler is not used, return the same value is all the same
  598. >>> adam = paddle.optimizer.Adam(0.01, parameters = emb.parameters())
  599. >>> for batch in range(10):
  600. ... input = paddle.randint(low=0, high=5, shape=[5])
  601. ... out = emb(input)
  602. ... out.backward()
  603. ... print("Learning rate of step{}: {}".format(batch, adam.get_lr())) # 0.01
  604. ... adam.step()
  605. Learning rate of step0: 0.01
  606. Learning rate of step1: 0.01
  607. Learning rate of step2: 0.01
  608. Learning rate of step3: 0.01
  609. Learning rate of step4: 0.01
  610. Learning rate of step5: 0.01
  611. Learning rate of step6: 0.01
  612. Learning rate of step7: 0.01
  613. Learning rate of step8: 0.01
  614. Learning rate of step9: 0.01
  615. >>> ## example2: StepDecay is used, return the scheduled learning rate
  616. >>> scheduler = paddle.optimizer.lr.StepDecay(learning_rate=0.5, step_size=2, gamma=0.1)
  617. >>> adam = paddle.optimizer.Adam(scheduler, parameters = emb.parameters())
  618. >>> for batch in range(10):
  619. ... input = paddle.randint(low=0, high=5, shape=[5])
  620. ... out = emb(input)
  621. ... out.backward()
  622. ... print("Learning rate of step{}: {}".format(batch, adam.get_lr())) # 0.5->0.05...
  623. ... adam.step()
  624. ... scheduler.step()
  625. Learning rate of step0: 0.5
  626. Learning rate of step1: 0.5
  627. Learning rate of step2: 0.05
  628. Learning rate of step3: 0.05
  629. Learning rate of step4: 0.005000000000000001
  630. Learning rate of step5: 0.005000000000000001
  631. Learning rate of step6: 0.0005000000000000001
  632. Learning rate of step7: 0.0005000000000000001
  633. Learning rate of step8: 5.000000000000001e-05
  634. Learning rate of step9: 5.000000000000001e-05
  635. >>> # train on static graph mode
  636. >>> paddle.enable_static()
  637. >>> main_prog = paddle.static.Program()
  638. >>> start_prog = paddle.static.Program()
  639. >>> with paddle.static.program_guard(main_prog, start_prog):
  640. ... x = paddle.static.data(name='x', shape=[None, 10])
  641. ... z = paddle.static.nn.fc(x, 100)
  642. ... loss = paddle.mean(z)
  643. ... scheduler = paddle.optimizer.lr.StepDecay(learning_rate=0.5, step_size=2, gamma=0.1)
  644. ... adam = paddle.optimizer.Adam(learning_rate=scheduler)
  645. ... adam.minimize(loss)
  646. >>> exe = paddle.static.Executor()
  647. >>> exe.run(start_prog)
  648. >>> for batch in range(10):
  649. ... print("Learning rate of step{}: {}".format(batch, adam.get_lr())) # 0.5->0.05->0.005...
  650. ... out = exe.run(main_prog, feed={'x': np.random.randn(3, 10).astype('float32')})
  651. ... scheduler.step()
  652. Learning rate of step0: 0.5
  653. Learning rate of step1: 0.5
  654. Learning rate of step2: 0.05
  655. Learning rate of step3: 0.05
  656. Learning rate of step4: 0.005000000000000001
  657. Learning rate of step5: 0.005000000000000001
  658. Learning rate of step6: 0.0005000000000000001
  659. Learning rate of step7: 0.0005000000000000001
  660. Learning rate of step8: 5.000000000000001e-05
  661. Learning rate of step9: 5.000000000000001e-05
  662. """
  663. if isinstance(self._learning_rate, float):
  664. return self._learning_rate
  665. else:
  666. return self._learning_rate()
  667. def _global_learning_rate(self, program=None):
  668. """
  669. get global decayed learning rate
  670. :return:
  671. """
  672. if program is None:
  673. if in_dygraph_mode():
  674. program = framework.default_main_program()
  675. else:
  676. program = paddle.static.default_main_program()
  677. return self._learning_rate_map.get(program, None)
  678. def _append_optimize_op(self, block, param_and_grad):
  679. """append optimize operator to block and return all the added optimize_op"""
  680. raise NotImplementedError(
  681. "Class \"Optimizer\" connot be used directly as an optimizer, please use its subclasses such as \"Adam\""
  682. )
  683. def _create_param_lr(self, param_and_grad):
  684. # create learning rate tensor for every parameter
  685. param = param_and_grad[0]
  686. if hasattr(param, 'optimize_attr') and param.optimize_attr is not None:
  687. param_lr = param.optimize_attr['learning_rate']
  688. if isinstance(param_lr, (Variable, paddle.pir.Value)):
  689. return param_lr
  690. else:
  691. if param_lr == 1.0:
  692. return self._global_learning_rate()
  693. else:
  694. with paddle.static.default_main_program()._lr_schedule_guard(
  695. is_with_opt=True
  696. ), framework.name_scope(
  697. 'scale_with_param_lr'
  698. ):
  699. return self._global_learning_rate() * param_lr
  700. else:
  701. return self._global_learning_rate()
  702. def _create_master_weight(self, param):
  703. if param.name in self._master_weights:
  704. var = self._master_weights[param.name]
  705. else:
  706. var_name = self._gen_master_weight_var_name(param)
  707. if in_pir_mode():
  708. startup_program = paddle.static.default_startup_program()
  709. main_program = paddle.static.default_main_program()
  710. with paddle.static.program_guard(startup_program):
  711. def get_param_from_startup(startup, name):
  712. for op in startup.global_block().ops:
  713. if (
  714. op.name() == 'builtin.set_parameter'
  715. and name == op.attrs()['parameter_name']
  716. ):
  717. return op.operand(0).source()
  718. return None
  719. startup_param = get_param_from_startup(
  720. startup_program, param.name
  721. )
  722. var = paddle.cast(startup_param, 'float32')
  723. var.persistable = True
  724. paddle._pir_ops.set_persistable_value(var, var_name)
  725. with paddle.static.program_guard(main_program):
  726. paddle.pir.reset_insertion_point_to_start()
  727. var = paddle.static.data(
  728. var_name, var.shape, var.dtype, core.Place()
  729. )
  730. var.persistable = True
  731. elif framework.in_dygraph_mode():
  732. var = paddle.cast(param, 'float32')
  733. var.name = var_name
  734. else:
  735. assert isinstance(self.helper, LayerHelper)
  736. var = paddle.static.create_global_var(
  737. name=var_name,
  738. shape=param.shape,
  739. value=0,
  740. dtype='float32',
  741. persistable=True,
  742. )
  743. block = self.helper.startup_program.global_block()
  744. block.append_op(
  745. type="cast",
  746. inputs={"X": [param]},
  747. outputs={"Out": [var]},
  748. attrs={
  749. "in_dtype": param.dtype,
  750. "out_dtype": core.VarDesc.VarType.FP32,
  751. },
  752. )
  753. self._master_weights[param.name] = var
  754. return var
  755. def _gen_master_weight_var_name(self, param):
  756. var_name = param.name + "_fp32_master"
  757. return unique_name.generate(var_name)
  758. def _create_master_grad(self, grad):
  759. assert self._is_dtype_fp16_or_bf16(grad.dtype)
  760. if in_pir_mode():
  761. if grad in self._master_grads:
  762. var = self._master_grads[grad]
  763. else:
  764. var = paddle.cast(grad, 'float32')
  765. self._master_grads[grad] = var
  766. else:
  767. if grad.name in self._master_grads:
  768. var = self._master_grads[grad.name]
  769. else:
  770. var_name = grad.name + "_fp32_master"
  771. var_name = unique_name.generate(var_name)
  772. var = grad.block.create_var(
  773. name=var_name,
  774. shape=grad.shape,
  775. value=0,
  776. dtype='float32',
  777. lod_level=grad.lod_level,
  778. persistable=grad.persistable,
  779. is_data=grad.is_data,
  780. )
  781. self._master_grads[grad.name] = var
  782. return var
  783. def _create_accumulators(self, block, parameters):
  784. """Create all accumulators needed by the parameters
  785. Args:
  786. block: the block in which the loss tensor is present
  787. parameters: list of parameter tensors for the optimizer
  788. """
  789. pass
  790. def _finish_update(self, block, parameters_and_grads):
  791. """Finish any custom updates needed
  792. before completing an optimization step
  793. Args:
  794. block: the block in which the loss tensor is present
  795. parameters: list of parameter tensors for the optimizer
  796. Returns:
  797. None
  798. """
  799. pass
def _add_accumulator(
    self,
    name,
    param,
    dtype=None,
    fill_value=0.0,
    shape=None,
    type=None,
    device=None,
):
    """Utility function to add an accumulator for a parameter

    The accumulator is registered in ``self._accumulators[name][param.name]``
    (``name`` is first prefixed with ``self._name`` when that is set).

    Args:
        name: name of the accumulator
        param: parameter tensor for which accumulator is to be added
        dtype: data type of the accumulator tensor; falls back to
            ``param.dtype`` when falsy
        fill_value: value to initialize the accumulator tensor
        shape: shape of the accumulator; defaults to ``param.shape``
        type: not referenced in this method body
        device: device for the initializer; when ``None`` it is looked up
            with ``self._get_device_for_param(param.name)``

    Returns:
        The accumulator tensor. In dygraph mode an already registered
        accumulator is returned instead of creating a new one.

    Raises:
        Exception: if the accumulator already exists for ``param`` and the
            optimizer is not running in dygraph mode.
    """
    if self._name is not None:
        name = self._name + "_" + name
    if (
        name in self._accumulators
        and param.name in self._accumulators[name]
    ):
        # Dygraph re-enters this method on later steps, so an existing
        # accumulator is simply reused there.
        if framework.in_dygraph_mode():
            return self._accumulators[name][param.name]
        raise Exception(
            f"Accumulator {name} already exists for parameter {param.name}"
        )
    if shape is None:
        shape = param.shape

    var_name = param.name + "_" + name
    var_name = unique_name.generate(var_name)
    # Remember every generated accumulator name on the optimizer.
    self._opti_name_list.append(var_name)

    if device is None:
        device = self._get_device_for_param(param.name)

    if in_pir_mode():
        # PIR: the persistable value is created together with its
        # constant initializer.
        var = paddle.pir.core.create_persistable_value(
            dtype or param.dtype,
            shape,
            var_name,
            initializer=paddle.nn.initializer.Constant(
                value=float(fill_value)
            ),
        )
    else:
        assert isinstance(self.helper, LayerHelper)
        var = self.helper.create_global_variable(
            name=var_name,
            persistable=True,
            dtype=dtype or param.dtype,
            type=core.VarDesc.VarType.LOD_TENSOR,
            shape=shape,
            belong_to_optimizer=True,
        )

        if in_dygraph_mode() and (
            device == 'cpu' or isinstance(device, core.CPUPlace)
        ):
            # Dygraph + CPU: fill the tensor in place on the CPU.
            _C_ops.full_(
                var,
                var.shape,
                str(float(fill_value)),
                var.dtype,
                core.CPUPlace(),
            )
        else:
            # Otherwise attach a constant initializer under the
            # resolved device.
            with device_guard(device):
                self.helper.set_variable_initializer(
                    var,
                    initializer=paddle.nn.initializer.Constant(
                        value=float(fill_value)
                    ),
                )

    if framework.in_dygraph_mode():
        # A non-empty holder means a state dict was loaded earlier; the
        # saved value replaces the fresh initial value and is consumed.
        if len(self._accumulators_holder) > 0:
            assert (
                var_name in self._accumulators_holder
            ), f"Optimizer set error, {var_name} should in state dict"
            var.set_value(self._accumulators_holder.pop(var_name))

            # load scale value for xpu
            if core.is_compiled_with_xpu():
                var.get_tensor().set_xpu_scale_value(
                    self._accumulators_holder.get(
                        var_name + ".SCALE_VALUE", -1.0
                    )
                )

    self._accumulators[name][param.name] = var
    return var
  888. def _get_accumulator(self, name, param):
  889. """Utility function to fetch an accumulator for a parameter
  890. Args:
  891. name: name of the accumulator
  892. param: parameter tensor for which accumulator is to be fetched
  893. Returns:
  894. accumulator tensor for the parameter
  895. """
  896. if self._name is not None:
  897. name = self._name + "_" + name
  898. if (
  899. name not in self._accumulators
  900. or param.name not in self._accumulators[name]
  901. ):
  902. raise Exception(
  903. f"Accumulator {name} does not exist for parameter {param.name}"
  904. )
  905. return self._accumulators[name][param.name]
  906. def _get_accumulator_master(self, name, param):
  907. """Utility function to fetch an accumulator for a parameter
  908. Args:
  909. name: name of the accumulator
  910. param: parameter variable for which accumulator is to be fetched
  911. Returns:
  912. accumulator variable for the parameter
  913. """
  914. if self._name is not None:
  915. name = self._name + "_" + name
  916. find_master = self._multi_precision and self._is_dtype_fp16_or_bf16(
  917. param.dtype
  918. )
  919. target_param = (
  920. self._master_weights[param.name] if find_master else param
  921. )
  922. target_name = target_param.name
  923. if (
  924. name not in self._accumulators
  925. or target_name not in self._accumulators[name]
  926. ):
  927. raise Exception(
  928. f"Accumulator {name} does not exist for parameter {target_name}"
  929. )
  930. return self._accumulators[name][target_name]
  931. def _update_param_device_map(self, parameters_and_grads, target_block):
  932. for param_and_grad in parameters_and_grads:
  933. if param_and_grad[0].stop_gradient is False:
  934. param_name = param_and_grad[0].name
  935. ops = target_block.ops
  936. device_attr_name = (
  937. core.op_proto_and_checker_maker.kOpDeviceAttrName()
  938. )
  939. for op in ops:
  940. input_arg_names = op.input_arg_names
  941. if param_name in input_arg_names:
  942. self._param_device_map[param_name] = op.attr(
  943. device_attr_name
  944. )
  945. break
  946. def _get_device_for_param(self, param_name):
  947. device = None
  948. if param_name in self._param_device_map:
  949. device = self._param_device_map[param_name]
  950. return device
def _create_optimization_pass(
    self, parameters_and_grads, param_group_idx=0
):
    """Add optimization operators to update gradients to tensors.

    Args:
        parameters_and_grads(list(tuple(Tensor, Tensor)) | dict):
            a list of (tensor, gradient) pair to update, or a parameter
            group dict whose ``'params'`` entry holds such a list.
        param_group_idx(int): index of the parameter group; only used by
            the multi-tensor path.

    Returns:
        return_op_list: a list of operators that will complete one step of
        optimization. This will include parameter update ops, global step
        update ops and any other custom ops required by subclasses to manage
        their internal state.
    """
    # This is a default implementation of create_optimization_pass that
    # can be shared by most optimizers. This implementation assumes that
    # the subclass will implement the _append_optimize_op method and the
    # _initialize_tensors method. The subclass can extend the
    # _create_accumulators method if it needs to create accumulators
    # for parameters and extend _finish_update method to add custom ops.

    # Always called under program_guard; the global block is used as the
    # loss block. But if the current block is inside control flow, the
    # optimize ops are appended to the grad block of the current block.
    global_block = framework.default_main_program().global_block()
    target_block = global_block
    current_block = framework.default_main_program().current_block()
    if current_block.idx != global_block.idx:
        assert (
            current_block.backward_block_idx != -1
        ), "current block is not global_block, but it doesn't have backward block."
        target_block = framework.default_main_program().blocks[
            current_block.backward_block_idx
        ]

    # Ops appended from here on are the ones returned at the end.
    start = len(target_block.ops)
    self.helper = LayerHelper(self.__class__.__name__)

    self._create_global_learning_rate()

    # NOTE: Multi Tensor support [ Momentum, Adam ] for dygraph mode
    if self._use_multi_tensor and self.__class__.__name__ in [
        'Momentum',
        'Adam',
    ]:
        # Initialize the multi-tensor state only while both parameter
        # buckets of this group are still empty.
        if (
            len(self._param_dict['FP32_LODTensor'][param_group_idx]) == 0
            and len(self._param_dict['FP16_LODTensor'][param_group_idx])
            == 0
        ):
            if isinstance(parameters_and_grads, list):
                assert param_group_idx == 0
                self._multi_tensor_init(
                    target_block,
                    [
                        p[0]
                        for p in parameters_and_grads
                        if not p[0].stop_gradient
                    ],
                    param_group_idx,
                )
            else:
                self._update_param_group(parameters_and_grads)
                self._multi_tensor_init(
                    target_block,
                    [
                        p[0]
                        for p in parameters_and_grads['params']
                        if not p[0].stop_gradient
                    ],
                    param_group_idx,
                )
        if framework.in_dygraph_mode():
            self._append_optimize_multi_tensor_op(
                target_block,
                parameters_and_grads,
                param_group_idx=param_group_idx,
            )
        else:
            self._update_param_device_map(
                parameters_and_grads, target_block
            )
            # NOTE: Multi Tensor requires all parameters to be in the same device and program.
            # param_grad_list = [p_0,g_0,p_1,g_1,....]
            param_grad_list = []
            for param_and_grad in parameters_and_grads:
                if (
                    not param_and_grad[0].stop_gradient
                    and param_and_grad[1] is not None
                ):
                    param_grad_list.append(param_and_grad[0])
                    param_grad_list.append(param_and_grad[1])
            # The first collected parameter supplies the program and the
            # device for the whole fused update.
            with param_grad_list[0].block.program._optimized_guard(
                param_grad_list
            ), name_scope("optimizer"):
                device = self._get_device_for_param(param_grad_list[0].name)
                with device_guard(device):
                    self._append_optimize_multi_tensor_op(
                        target_block,
                        parameters_and_grads,
                        param_group_idx=param_group_idx,
                    )
    else:
        if not framework.in_dygraph_mode():
            params_grads_device_map = (
                parameters_and_grads['params']
                if isinstance(parameters_and_grads, dict)
                else parameters_and_grads
            )
            self._update_param_device_map(
                params_grads_device_map, target_block
            )

        # Accumulators are created only for parameters that require
        # gradients.
        if isinstance(parameters_and_grads, list):
            with paddle.base.framework.dygraph_guard_if_declarative():
                self._create_accumulators(
                    target_block,
                    [
                        p[0]
                        for p in parameters_and_grads
                        if not p[0].stop_gradient
                    ],
                )
        else:
            # Work on a shallow copy so the caller's group dict keeps its
            # (param, grad) pairs.
            params_acc_dict = parameters_and_grads.copy()
            params_acc_dict['params'] = [
                p[0]
                for p in params_acc_dict['params']
                if not p[0].stop_gradient
            ]
            with paddle.base.framework.dygraph_guard_if_declarative():
                self._create_accumulators(target_block, params_acc_dict)

        if framework.in_dygraph_mode():
            found_inf = self._get_auxiliary_var('found_inf')
            if found_inf:
                # A truthy found_inf skips every parameter update of this
                # step; a tensor flag is normalized to a plain bool.
                if isinstance(found_inf, core.eager.Tensor):
                    self._set_auxiliary_var('found_inf', True)
            else:
                if isinstance(found_inf, core.eager.Tensor):
                    self._set_auxiliary_var('found_inf', False)
                if isinstance(parameters_and_grads, list):
                    for param_and_grad in parameters_and_grads:
                        # Parameters can be uninitialized in pipeline parallel of semi-auto parallel.
                        # Since gradient clip and parameters update mixed up in one interface, so we
                        # need to filter again here.
                        if (
                            param_and_grad[1] is None
                            or not param_and_grad[0]._is_initialized()
                        ):
                            continue
                        if param_and_grad[0].stop_gradient is False:
                            self._append_optimize_op(
                                target_block, param_and_grad
                            )
                else:
                    for param_and_grad in parameters_and_grads['params']:
                        if (
                            param_and_grad[1] is None
                            or not param_and_grad[0]._is_initialized()
                        ):
                            continue
                        if param_and_grad[0].stop_gradient is False:
                            # Hand the subclass a per-parameter dict: the
                            # single pair plus all other group options.
                            param_grad_dict = {}
                            param_grad_dict['params'] = param_and_grad
                            param_grad_dict.update(
                                {
                                    k: v
                                    for k, v in parameters_and_grads.items()
                                    if k != 'params'
                                }
                            )
                            self._append_optimize_op(
                                target_block, param_grad_dict
                            )
        else:
            for param_and_grad in parameters_and_grads:
                if param_and_grad[1] is None:
                    continue

                with param_and_grad[0].block.program._optimized_guard(
                    param_and_grad
                ), name_scope("optimizer"):
                    if param_and_grad[0].stop_gradient is False:
                        device = self._get_device_for_param(
                            param_and_grad[0].name
                        )
                        with device_guard(device):
                            # NOTE(review): the returned op is bound but
                            # never read afterwards.
                            optimize_op = self._append_optimize_op(
                                target_block, param_and_grad
                            )

    # Get custom finish ops for subclasses
    # FIXME: Need to fix this once we figure out how to handle dependencies
    self._finish_update(target_block, parameters_and_grads)

    paddle.base.core._set_warmup(False)

    # Return exactly the ops appended by this pass.
    end = len(target_block.ops)
    return target_block._slice_ops(start, end)
  1140. def _pir_create_optimization_pass(
  1141. self, parameters_and_grads, param_group_idx=0
  1142. ):
  1143. """Add optimization operators to update gradients to tensors.
  1144. Args:
  1145. parameters_and_grads(list(tuple(Tensor, Tensor))):
  1146. a list of (tensor, gradient) pair to update.
  1147. Returns:
  1148. return_op_list: a list of operators that will complete one step of
  1149. optimization. This will include parameter update ops, global step
  1150. update ops and any other custom ops required by subclasses to manage
  1151. their internal state.
  1152. """
  1153. global_block = framework.default_main_program().global_block()
  1154. target_block = global_block
  1155. start = len(target_block.ops)
  1156. self._create_global_learning_rate()
  1157. params_grads_device_map = (
  1158. parameters_and_grads['params']
  1159. if isinstance(parameters_and_grads, dict)
  1160. else parameters_and_grads
  1161. )
  1162. self._update_param_device_map(params_grads_device_map, target_block)
  1163. if isinstance(parameters_and_grads, list):
  1164. self._create_accumulators(
  1165. target_block,
  1166. [p[0] for p in parameters_and_grads if not p[0].stop_gradient],
  1167. )
  1168. else:
  1169. params_acc_dict = parameters_and_grads.copy()
  1170. params_acc_dict['params'] = [
  1171. p[0]
  1172. for p in params_acc_dict['params']
  1173. if not p[0].stop_gradient
  1174. ]
  1175. self._create_accumulators(target_block, params_acc_dict)
  1176. if isinstance(parameters_and_grads, list):
  1177. for param_and_grad in parameters_and_grads:
  1178. if param_and_grad[1] is None:
  1179. continue
  1180. if param_and_grad[0].stop_gradient is False:
  1181. self._append_optimize_op(target_block, param_and_grad)
  1182. else:
  1183. for param_and_grad in parameters_and_grads['params']:
  1184. if param_and_grad[1] is None:
  1185. continue
  1186. if param_and_grad[0].stop_gradient is False:
  1187. param_grad_dict = {}
  1188. param_grad_dict['params'] = param_and_grad
  1189. param_grad_dict.update(
  1190. {
  1191. k: v
  1192. for k, v in parameters_and_grads.items()
  1193. if k != 'params'
  1194. }
  1195. )
  1196. self._append_optimize_op(target_block, param_grad_dict)
  1197. # Get custom finish ops for subclasses
  1198. # FIXME: Need to fix this once we figure out how to handle dependencies
  1199. self._finish_update(target_block, parameters_and_grads)
  1200. paddle.base.core._set_warmup(False)
  1201. end = len(target_block.ops)
  1202. return target_block._slice_ops(start, end)
  1203. def backward(
  1204. self,
  1205. loss,
  1206. startup_program=None,
  1207. parameters=None,
  1208. no_grad_set=None,
  1209. callbacks=None,
  1210. ):
  1211. """
  1212. The first part of ``minimize``, do auto-diff to append backward operations for
  1213. the current program.
  1214. Args:
  1215. loss (Tensor): ``loss`` tensor to run optimizations.
  1216. startup_program (Program, optional): :ref:`api_paddle_static_Program` for
  1217. initializing parameters in ``parameters``. The default value
  1218. is None, at this time :ref:`api_paddle_static_default_startup_program` will be used.
  1219. parameters (list, optional): List of ``Tensor`` or ``Tensor.name`` to update
  1220. to minimize ``loss``. The default value is None, at this time all parameters
  1221. will be updated.
  1222. no_grad_set (set, optional): Set of ``Tensor`` or ``Tensor.name`` that don't need
  1223. to be updated. The default value is None.
  1224. callbacks (list, optional): list of callable objects to run when appending backward
  1225. operator for one parameter. The default value is None.
  1226. Return:
  1227. list: list of (param, grad) tensor pairs, param is ``Parameter``,
  1228. grad is the gradient value corresponding to the parameter.
  1229. Examples:
  1230. .. code-block:: python
  1231. >>> import paddle
  1232. >>> x = paddle.arange(26, dtype="float32").reshape([2, 13])
  1233. >>> linear = paddle.nn.Linear(13, 5)
  1234. >>> # This can be any optimizer supported by dygraph.
  1235. >>> adam = paddle.optimizer.Adam(learning_rate = 0.01,
  1236. ... parameters = linear.parameters())
  1237. >>> out = linear(x)
  1238. >>> out.backward()
  1239. >>> adam.step()
  1240. >>> adam.clear_grad()
  1241. """
  1242. act_no_grad_set = None
  1243. if framework.in_dygraph_mode():
  1244. pass
  1245. else:
  1246. act_no_grad_set = self._get_no_grad_set(loss, no_grad_set)
  1247. # Infer dtype by loss if None
  1248. if self._dtype is None:
  1249. self._dtype = loss.dtype
  1250. if framework.in_dygraph_mode():
  1251. parameter_list = parameters if parameters else self._parameter_list
  1252. # It is very time-consuming to call c++ functions in a loop on the python side.
  1253. # We put this part of the code on the c++ side to improve the speed in eager mode.
  1254. params_grads = []
  1255. grads = core.eager.get_all_grads(parameter_list)
  1256. for index, grad in enumerate(grads):
  1257. if grad is not None:
  1258. params_grads.append((parameter_list[index], grad))
  1259. else:
  1260. if callbacks is None:
  1261. callbacks = [paddle.nn.clip.error_clip_callback]
  1262. else:
  1263. assert isinstance(callbacks, list)
  1264. program = loss.block.program
  1265. assert np.prod(loss.shape) == 1, (
  1266. f"The number of elements of loss should be 1, but the current loss.shape is {loss.shape}, whose number of elements is not 1. "
  1267. "Maybe that you should call paddle.mean to process the current loss."
  1268. )
  1269. parameter_list = parameters if parameters else self._parameter_list
  1270. with paddle.static.program_guard(program, startup_program):
  1271. if in_pir_mode():
  1272. if parameter_list is None:
  1273. # all parameters will be updated.
  1274. program_all_params = (
  1275. program.global_block().all_parameters()
  1276. )
  1277. parameter_list = [
  1278. param
  1279. for param in program_all_params
  1280. if param.stop_gradient is False
  1281. ]
  1282. params_grads = []
  1283. grads = paddle.autograd.ir_backward.grad(
  1284. loss, parameter_list, no_grad_vars=act_no_grad_set
  1285. )
  1286. for index, grad in enumerate(grads):
  1287. if grad is not None:
  1288. params_grads.append((parameter_list[index], grad))
  1289. else:
  1290. from paddle.incubate.autograd.utils import prim_enabled
  1291. if prim_enabled():
  1292. params_grads = append_backward_new(
  1293. [loss], parameter_list, act_no_grad_set, callbacks
  1294. )
  1295. else:
  1296. params_grads = append_backward(
  1297. loss, parameter_list, act_no_grad_set, callbacks
  1298. )
  1299. return params_grads
  1300. def apply_gradients(self, params_grads):
  1301. """
  1302. Second part of `minimize`, appending optimization operators for
  1303. given `params_grads` pairs.
  1304. Args:
  1305. params_grads (list): list of (param, grad) pair to do optimization.
  1306. Returns:
  1307. list: A list of operators appended to the current program.
  1308. Examples:
  1309. .. code-block:: python
  1310. >>> import paddle
  1311. >>> inp = paddle.uniform([10, 10], dtype="float32", min=-0.1, max=0.1)
  1312. >>> linear = paddle.nn.Linear(10, 10)
  1313. >>> out = linear(inp)
  1314. >>> loss = paddle.mean(out)
  1315. >>> optimizer = paddle.optimizer.Adam(learning_rate=0.1,
  1316. ... parameters=linear.parameters())
  1317. >>> params_grads = optimizer.backward(loss)
  1318. >>> optimizer.apply_gradients(params_grads)
  1319. """
  1320. # NOTE(zhaoyinglia): AutoParallel set '_sorted' attribute to skip the 'sorted' operator.
  1321. if not hasattr(self, "_sorted"):
  1322. params_grads = sorted(params_grads, key=lambda x: x[0].name)
  1323. # 'optimizer(grad_clip)' or 'set_gradient_clip'
  1324. if self._grad_clip is not None:
  1325. params_grads = self._grad_clip(params_grads)
  1326. else:
  1327. params_grads = paddle.nn.clip.append_gradient_clip_ops(params_grads)
  1328. # Add regularization if any
  1329. params_grads = self.append_regularization_ops(
  1330. params_grads, self.regularization
  1331. )
  1332. optimize_ops = self._create_optimization_pass(params_grads)
  1333. return optimize_ops
  1334. def _apply_optimize(
  1335. self, loss, startup_program, params_grads, param_group_idx=0
  1336. ):
  1337. """
  1338. Second part of `minimize`, appending optimization operators for
  1339. given `params_grads` pairs.
  1340. Args:
  1341. loss (Tensor): loss tensor to run optimizations.
  1342. startup_program (Program): startup_program for initializing parameters
  1343. in `parameters`.
  1344. params_grads (list): list of (param, grad) pair to do optimization.
  1345. Returns:
  1346. list: A list of operators appended to the current program.
  1347. """
  1348. if framework.in_dygraph_mode() and g_shard_bypass_dygraph_optimizer:
  1349. return
  1350. if in_dynamic_or_pir_mode():
  1351. with paddle.static.program_guard(
  1352. paddle.static.default_main_program(),
  1353. paddle.static.default_startup_program(),
  1354. ):
  1355. if isinstance(params_grads, list):
  1356. if self._grad_clip is not None:
  1357. params_grads = self._grad_clip(params_grads)
  1358. params_grads = self.append_regularization_ops(
  1359. params_grads, self.regularization
  1360. )
  1361. else:
  1362. grad_clip = params_grads['grad_clip']
  1363. if grad_clip is not None:
  1364. params_grads['params'] = grad_clip(
  1365. params_grads['params']
  1366. )
  1367. params_grads['params'] = self.append_regularization_ops(
  1368. params_grads['params'], self.regularization
  1369. )
  1370. if in_pir_mode():
  1371. optimize_ops = self._pir_create_optimization_pass(
  1372. params_grads, param_group_idx=param_group_idx
  1373. )
  1374. else:
  1375. optimize_ops = self._create_optimization_pass(
  1376. params_grads, param_group_idx=param_group_idx
  1377. )
  1378. else:
  1379. assert param_group_idx == 0
  1380. program = loss.block.program
  1381. with paddle.static.program_guard(program, startup_program):
  1382. optimize_ops = self.apply_gradients(params_grads)
  1383. return optimize_ops
  1384. def _create_regularization_of_grad(self, param, grad, regularization=None):
  1385. """Create and add backward regularization Operators
  1386. Function helper of append_regularization_ops.
  1387. """
  1388. # If no gradient or no regularization is specified, then we don't need to do anything
  1389. if grad is None or (
  1390. (
  1391. not hasattr(param, 'regularizer')
  1392. or (hasattr(param, 'regularizer') and param.regularizer is None)
  1393. )
  1394. and regularization is None
  1395. ):
  1396. return grad
  1397. regularization_term = None
  1398. # when master_grad is true in amp training, grad will be fp32, but param maybe fp16.
  1399. # we get master weight when master_grad is true to avoid type mismatch error.
  1400. def get_target_param(param, grad):
  1401. target_param = param
  1402. if param.dtype != grad.dtype:
  1403. find_master = (
  1404. self._multi_precision
  1405. and self._is_dtype_fp16_or_bf16(param.dtype)
  1406. )
  1407. if find_master and len(self._master_weights) != 0:
  1408. target_param = self._master_weights[param.name]
  1409. else:
  1410. target_param = param.astype(grad.dtype)
  1411. return target_param
  1412. param = get_target_param(param, grad)
  1413. if hasattr(param, 'regularizer') and param.regularizer is not None:
  1414. # Add variable for regularization term in grad block
  1415. regularization_term = param.regularizer(param, grad, grad.block)
  1416. elif regularization is not None:
  1417. regularization_term = regularization(param, grad, grad.block)
  1418. assert regularization_term is not None
  1419. if in_dynamic_or_pir_mode():
  1420. return _C_ops.add_n([grad, regularization_term])
  1421. else:
  1422. new_grad = grad
  1423. if grad.type == core.VarDesc.VarType.SELECTED_ROWS:
  1424. # FIXME(zcd): If the grad is SELECTED_ROWS, after regularization,
  1425. # the grad's type and name will be changed. But the gradient's name
  1426. # is used in ParallelExecutor Reduce mode, so I add a flag for
  1427. # the new_grad here.
  1428. new_grad = grad.block.create_var(
  1429. name=grad.name + core.kNewGradSuffix(),
  1430. dtype=param.dtype,
  1431. shape=param.shape,
  1432. lod_level=param.lod_level,
  1433. type=core.VarDesc.VarType.LOD_TENSOR,
  1434. )
  1435. inputs = {"X": [grad, regularization_term]}
  1436. outputs = {"Out": [new_grad]}
  1437. grad.block.append_op(type='sum', inputs=inputs, outputs=outputs)
  1438. return new_grad
  def append_regularization_ops(
      self, parameters_and_grads, regularization=None
  ):
      r"""Build regularized gradients for a list of (parameter, gradient) pairs.

      Each gradient is replaced by the value returned from
      ``self._create_regularization_of_grad(param, grad, regularization)``.
      In dygraph or PIR mode this is done directly; otherwise the work is
      performed inside the ``'regularization'`` name scope and under the
      program's ``_optimized_guard`` for the pair being processed.

      Args:
          parameters_and_grads: A list of (parameters, gradients) pairs
              that need to be regularized.
          regularization: A global regularizer, forwarded unchanged to
              ``_create_regularization_of_grad`` for every pair.

      Returns:
          list[(Variable, Variable)]: list of (parameters, gradients)
          pairs carrying the regularized gradient.
      """
      if framework.in_dygraph_mode() or in_pir_mode():
          return [
              (
                  param,
                  self._create_regularization_of_grad(
                      param, grad, regularization
                  ),
              )
              for param, grad in parameters_and_grads
          ]

      message = (
          "If regularizer of a Parameter has been set by 'base.ParamAttr' or 'base.WeightNormParamAttr' already. "
          "The Regularization[%s] in Optimizer will not take effect, and it will only be applied to other Parameters!"
      )
      regularized_pairs = []
      # The notice about a per-parameter regularizer taking precedence over
      # the optimizer-level one is logged at most once per call.
      notice_logged = False
      with framework.name_scope('regularization'):
          for param, grad in parameters_and_grads:
              if not notice_logged:
                  if (
                      param.regularizer is not None
                      and regularization is not None
                  ):
                      notice_logged = True
                      logging.info(message % regularization.__str__())
              with param.block.program._optimized_guard([param, grad]):
                  regularized_grad = self._create_regularization_of_grad(
                      param, grad, regularization
                  )
                  regularized_pairs.append((param, regularized_grad))
      return regularized_pairs
  def _get_no_grad_set(self, loss, no_grad_set=None):
      """Return ``no_grad_set`` extended with every non-trainable parameter.

      The parameters are taken from the global block of the program that
      owns ``loss``. In PIR mode the set is normalized with
      ``_get_no_grad_set_value`` and the parameter objects themselves are
      added; otherwise it is normalized with ``_get_no_grad_set_name`` and
      the parameter names are added.

      Args:
          loss: The loss whose program supplies the parameter list.
          no_grad_set: The user-supplied collection to extend; may be None.

      Returns:
          The normalized set, updated with the parameters (or parameter
          names) whose ``stop_gradient`` is True.
      """
      if in_pir_mode():
          excluded = _get_no_grad_set_value(no_grad_set)
          all_params = loss.block.program.global_block().all_parameters()
          frozen = [p for p in all_params if p.stop_gradient is True]
      else:
          excluded = _get_no_grad_set_name(no_grad_set)
          all_params = loss.block.program.global_block().all_parameters()
          frozen = {p.name for p in all_params if p.stop_gradient is True}

      # A parameter that is not trainable should not have a gradient.
      excluded.update(frozen)
      return excluded
  @framework.non_static_only
  def clear_grad(self, set_to_zero=True):
      """
      Clear the gradients of all optimized parameters for model.

      If not, new gradient will accumulate on previous gradient.

      There are two method to clear grad: set_to_zero or delete grad.

      Args:
          set_to_zero (bool, optional): If set grads to zero or not, default is True.

      Returns:
          None

      Examples:
          .. code-block:: python

              >>> import paddle

              >>> a = paddle.arange(26, dtype="float32").reshape([2, 13])
              >>> linear = paddle.nn.Linear(13, 5)
              >>> # This can be any optimizer supported by dygraph.
              >>> adam = paddle.optimizer.Adam(learning_rate = 0.01,
              ...                              parameters = linear.parameters())
              >>> out = linear(a)
              >>> out.backward()
              >>> adam.step()
              >>> adam.clear_grad()

      """
      # With no parameters (None or an empty list) there is nothing to
      # clear. Previously None was iterated (TypeError) and an empty list
      # was indexed with [0] (IndexError).
      if not self._parameter_list:
          return

      if isinstance(self._parameter_list[0], dict):
          # Parameters were supplied as groups; walk every group's 'params'.
          candidates = (
              p for group in self._param_groups for p in group['params']
          )
      else:
          candidates = self._parameter_list

      # Collect first, then clear, so the selection is complete before any
      # gradient is touched.
      trainable = [p for p in candidates if not p.stop_gradient]
      for p in trainable:
          p.clear_gradient(set_to_zero)
  @imperative_base.no_grad()
  def minimize(
      self, loss, startup_program=None, parameters=None, no_grad_set=None
  ):
      """
      Add operations to minimize ``loss`` by updating ``parameters``.

      This runs ``self.backward`` to obtain the (param, grad) pairs and then
      hands them to ``self._apply_optimize``.

      Args:
          loss (Tensor): A ``Tensor`` containing the value to minimize.
          startup_program (Program, optional): :ref:`api_paddle_static_Program` for
              initializing parameters in ``parameters``. The default value
              is None, at this time :ref:`api_paddle_static_default_startup_program` will be used.
          parameters (list, optional): List of ``Tensor`` or ``Tensor.name`` to update
              to minimize ``loss``. The default value is None, at this time all parameters
              will be updated.
          no_grad_set (set, optional): Set of ``Tensor`` or ``Tensor.name`` that don't need
              to be updated. The default value is None.

      Returns:
          tuple: tuple (optimize_ops, params_grads), A list of operators appended
          by minimize and a list of (param, grad) tensor pairs, param is
          ``Parameter``, grad is the gradient value corresponding to the parameter.
          In static graph mode, the returned tuple can be passed to ``fetch_list`` in ``Executor.run()`` to
          indicate program pruning. If so, the program will be pruned by ``feed`` and
          ``fetch_list`` before run, see details in ``Executor``.

      Examples:
          .. code-block:: python

              >>> import paddle
              >>> linear = paddle.nn.Linear(10, 10)
              >>> input = paddle.uniform(shape=[10, 10], min=-0.1, max=0.1)
              >>> out = linear(input)
              >>> loss = paddle.mean(out)

              >>> beta1 = paddle.to_tensor([0.9], dtype="float32")
              >>> beta2 = paddle.to_tensor([0.99], dtype="float32")

              >>> adam = paddle.optimizer.Adam(learning_rate=0.1,
              ...         parameters=linear.parameters(),
              ...         weight_decay=0.01)
              >>> loss.backward()
              >>> adam.minimize(loss)
              >>> adam.clear_grad()

      """
      assert isinstance(
          loss, (Variable, paddle.pir.Value)
      ), "The loss should be an Tensor."

      # An explicit, non-empty ``parameters`` argument overrides the list
      # the optimizer was constructed with.
      target_parameters = parameters or self._parameter_list

      grads_by_param = self.backward(
          loss,
          startup_program=startup_program,
          parameters=target_parameters,
          no_grad_set=no_grad_set,
      )
      appended_ops = self._apply_optimize(
          loss,
          startup_program=startup_program,
          params_grads=grads_by_param,
      )
      return appended_ops, grads_by_param
  def _declarative_step(self):
      """
      In declarative mode, we forward `call step` to `call apply_gradients`.

      The (param, grad) pairs are built from the parameters of the default
      main program's global block that are trainable, whose name matches a
      parameter in ``self._parameter_list``, and that have a ``grad``
      attribute.
      """
      program_params = (
          paddle.static.default_main_program().global_block().all_parameters()
      )
      assert not isinstance(
          self._parameter_list[0], dict
      ), "Only list of parameters is supported while using optimizer in @paddle.jit.static."

      wanted_names = {param.name for param in self._parameter_list}
      pairs = [
          (param, param.grad)
          for param in program_params
          if param.trainable
          and param.name in wanted_names
          and hasattr(param, "grad")
      ]
      self.apply_gradients(pairs)
  @imperative_base.no_grad()
  @framework.non_static_only
  def step(self):
      """
      Execute the optimizer and update parameters once.

      In to-static mode the call is forwarded to ``_declarative_step``.
      Otherwise every parameter that does not stop gradients and currently
      has a gradient is paired with that gradient and passed to
      ``_apply_optimize``, either as one flat list or once per parameter
      group.

      Returns:
          None

      Examples:
          .. code-block:: python

              >>> import paddle

              >>> a = paddle.arange(26, dtype="float32").reshape([2, 13])
              >>> linear = paddle.nn.Linear(13, 5)
              >>> # This can be any optimizer supported by dygraph.
              >>> adam = paddle.optimizer.Adam(learning_rate = 0.01,
              ...                         parameters = linear.parameters())
              >>> out = linear(a)
              >>> out.backward()
              >>> adam.step()
              >>> adam.clear_grad()
      """
      if paddle.base.dygraph.base.in_to_static_mode():
          self._declarative_step()
          return

      if isinstance(self._param_groups[0], dict):
          # Optimize parameters group by group; each call carries the
          # group's own options next to its (param, grad) pairs.
          for group_idx, group in enumerate(self._param_groups):
              group_payload = defaultdict(list)
              for param in group['params']:
                  if param.stop_gradient or param._grad_ivar() is None:
                      continue
                  group_payload['params'].append(
                      (param, param._grad_ivar())
                  )
              group_payload.update(
                  {
                      key: value
                      for key, value in group.items()
                      if key != 'params'
                  }
              )
              self._apply_optimize(
                  loss=None,
                  startup_program=None,
                  params_grads=group_payload,
                  param_group_idx=group_idx,
              )
          return

      # Flat list of parameters: a single optimize call with group index 0.
      pairs = []
      for param in self._param_groups:
          if param.stop_gradient or param._grad_ivar() is None:
              continue
          pairs.append((param, param._grad_ivar()))
      self._apply_optimize(
          loss=None,
          startup_program=None,
          params_grads=pairs,
          param_group_idx=0,
      )
  def _add_param_group(self, param_group):
      """
      Add a param group to parameter_list.

      The group's ``'params'`` entry is normalized to a list, missing
      options are filled in from ``self._default_dict``, and each
      parameter in the group gets its ``regularizer`` and its
      ``optimize_attr['learning_rate']`` set from the group's options.
      ``param_group`` is modified in place and appended to
      ``self._param_groups``.

      Args:
          param_group (dict): The group of Tensors to be optimized with
              different optimization options.

      Raises:
          TypeError: If ``param_group['params']`` is a set.
          ValueError: If a parameter already belongs to another group.
      """
      members = param_group['params']
      if isinstance(members, Parameter):
          members = [members]
      elif isinstance(members, set):
          raise TypeError(
              "optimizer parameters should be in ordered collections,"
              "but received set, please use list instead."
          )
      else:
          members = list(members)
      param_group['params'] = members

      # Update optimization options for each groups
      for option, default_value in self._default_dict.items():
          param_group.setdefault(option, default_value)

      already_grouped = {
          param for group in self._param_groups for param in group['params']
      }
      if not already_grouped.isdisjoint(set(param_group['params'])):
          raise ValueError(
              "some parameters appear in more than one parameter group"
          )

      for param in param_group['params']:
          decay = param_group['weight_decay']
          # A plain float is wrapped in L2Decay; anything else is used as
          # the regularizer as-is.
          param.regularizer = (
              L2Decay(decay) if isinstance(decay, float) else decay
          )
          param.optimize_attr['learning_rate'] = param_group.get(
              'learning_rate', 1.0
          )

      self._param_groups.append(param_group)
  1711. def _update_param_group(self, parameters):
  1712. """
  1713. Update the param group with new entry
  1714. Args:
  1715. parameters (dict): The extra group of Tensors to be optimzed with
  1716. different optimization options. Only used in child class.
  1717. """
  1718. pass
  @framework.dygraph_only
  def _multi_tensor_init(self, target_block, parameters, param_group_idx):
      """
      All parameters used for optimizer (such as: parameters, master_weight, velocity_acc for momentum) calculations are grouped into a python list by data type (float16, float32).

      This function will be overridden in the corresponding optimizer file;
      this base implementation does nothing.

      Args:
          target_block: the block in which the loss tensor is present
          parameters: list of parameter tensors for the optimizer
          param_group_idx: presumably the index of the parameter group being
              initialized -- TODO confirm against the overriding optimizers
      """
  @framework.dygraph_only
  def _append_optimize_multi_tensor_op(
      self, target_block, parameters_and_grads, param_group_idx
  ):
      """
      For Multi Tensor, append optimize merged_operator to block.

      This base implementation does nothing.
      """
  def _is_dtype_fp16_or_bf16(self, dtype):
      """
      Check whether the dtype is fp16 or bf16.

      :param dtype: instance of core.VarDesc.VarType or core.DataType
      :return: True if dtype is one of fp16 or bf16, False otherwise
      """
      assert isinstance(
          dtype, (core.VarDesc.VarType, core.DataType)
      ), "The dtype should be an instance of core.VarDesc.VarType or core.DataType."

      if isinstance(dtype, core.VarDesc.VarType):
          half_precision = (
              core.VarDesc.VarType.FP16,
              core.VarDesc.VarType.BF16,
          )
      else:
          # NOTE(review): UINT16 appears to be how bf16 is represented for
          # core.DataType here -- confirm.
          half_precision = (core.DataType.FLOAT16, core.DataType.UINT16)

      first, second = half_precision
      return dtype == first or dtype == second