lamb.py 12 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322
  1. # Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. from paddle import _C_ops
  15. from paddle.base.executor import global_scope
  16. from ..base import core, framework
  17. from ..base.framework import Variable
  18. from .optimizer import Optimizer
# Empty ``__all__``: a wildcard import (``from <this module> import *``)
# exports no names from this module.
__all__ = []
  20. class Lamb(Optimizer):
  21. r"""
  22. LAMB (Layer-wise Adaptive Moments optimizer for Batching training) Optimizer.
  23. LAMB Optimizer is designed to scale up the batch size of training without losing
  24. accuracy, which supports adaptive element-wise updating and accurate layer-wise
  25. correction. For more information, please refer to `Large Batch Optimization for
  26. Deep Learning: Training BERT in 76 minutes <https://arxiv.org/abs/1904.00962>`_ .
  27. The updating of parameters follows:
  28. .. math::
  29. m_t &= \beta_1 m_{t - 1}+ (1 - \beta_1)g_t
  30. v_t &= \beta_2 v_{t - 1} + (1 - \beta_2)g_t^2
  31. m_t &= \frac{m_t}{\beta_1^t}
  32. v_t &= \frac{v_t}{\beta_2^t}
  33. r_t &= \frac{m_t}{\sqrt{v_t}+\epsilon}
  34. w_t &= w_{t-1} -\eta_t \frac{\left \| w_{t-1}\right \|}{\left \| r_t + \lambda w_{t-1}\right \|} (r_t + \lambda w_{t-1})
  35. where :math:`m` is the 1st moment, and :math:`v` the 2nd moment, :math:`\\eta` the
  36. learning rate, :math:`\\lambda` the LAMB weight decay rate.
  37. Args:
  38. learning_rate (float|Variable, optional): the learning rate used to update parameters. \
  39. Can be a float value or a Variable with data type float32. Default 0.001.
  40. lamb_weight_decay (float, optional): The LAMB weight decay rate. Default 0.01. Remind that weight_decay should be None.
  41. beta1 (float, optional): The exponential decay rate for the 1st moment estimates.
  42. Default 0.9.
  43. beta2 (float, optional): The exponential decay rate for the 2nd moment estimates.
  44. Default 0.999.
  45. epsilon (float, optional): A small float value for numerical stability. Default 1e-6.
  46. parameters (Iterable, optional): Iterable of ``Variable`` names to update to minimize ``loss``. \
  47. This parameter is required in dygraph mode. And you can specify different options for \
  48. different parameter groups such as the learning rate, weight decay, etc, \
  49. then the parameters are list of dict. Note that the learning_rate in parameter groups \
  50. represents the scale of base learning_rate. \
  51. The default value is None in static graph mode, at this time all parameters will be updated.
  52. grad_clip (GradientClipBase, optional): Gradient clipping strategy, it's an instance of
  53. some derived class of ``GradientClipBase`` . There are three clipping strategies
  54. ( :ref:`api_paddle_base_clip_ClipGradByGlobalNorm` , :ref:`api_paddle_base_clip_ClipGradByNorm` ,
  55. :ref:`api_paddle_base_clip_ClipGradByValue` ). If you want better convergence, it is recommended
  56. to use :ref:`api_paddle_base_clip_ClipGradByGlobalNorm` . Default None, meaning there is no gradient clipping.
  57. exclude_from_weight_decay_fn (function, optional): whether to skip weight decay for a parameter when this function returns True while take the parameter as input.
  58. always_adapt (bool, optional): whether to use Layer-wise LR adaptation. By default, skip adaptation on parameters that are
  59. excluded from weight decay, unless always_adapt == True, then always enable LR adaptation.
  60. name(str|None): For detailed information, please refer to
  61. :ref:`api_guide_Name` . Usually name is no need to set and None by default.
  62. Examples:
  63. .. code-block:: python
  64. >>> import paddle
  65. >>> inp = paddle.uniform(shape=[10, 10], dtype='float32', min=-0.1, max=0.1)
  66. >>> linear = paddle.nn.Linear(10, 10)
  67. >>> out = linear(inp)
  68. >>> loss = paddle.mean(out)
  69. >>> beta1 = paddle.to_tensor([0.9], dtype="float32")
  70. >>> beta2 = paddle.to_tensor([0.85], dtype="float32")
  71. >>> lamb = paddle.optimizer.Lamb(learning_rate=0.002, parameters=linear.parameters(), lamb_weight_decay=0.01)
  72. >>> back = out.backward()
  73. >>> lamb.step()
  74. >>> lamb.clear_grad()
  75. """
  76. _moment1_acc_str = "moment1"
  77. _moment2_acc_str = "moment2"
  78. _beta1_pow_acc_str = "beta1_pow_acc"
  79. _beta2_pow_acc_str = "beta2_pow_acc"
  80. def __init__(
  81. self,
  82. learning_rate=0.001,
  83. lamb_weight_decay=0.01,
  84. beta1=0.9,
  85. beta2=0.999,
  86. epsilon=1e-6,
  87. parameters=None,
  88. grad_clip=None,
  89. exclude_from_weight_decay_fn=None,
  90. multi_precision=False,
  91. always_adapt=False,
  92. name=None,
  93. ):
  94. assert learning_rate is not None
  95. assert beta1 is not None
  96. assert beta2 is not None
  97. assert epsilon is not None
  98. super().__init__(
  99. learning_rate=learning_rate,
  100. parameters=parameters,
  101. weight_decay=None,
  102. grad_clip=grad_clip,
  103. name=name,
  104. )
  105. self.type = "lamb"
  106. self._beta1 = beta1
  107. self._beta2 = beta2
  108. self._epsilon = epsilon
  109. self._lamb_weight_decay = lamb_weight_decay
  110. self._exclude_from_weight_decay_fn = exclude_from_weight_decay_fn
  111. self._default_dict = {
  112. 'beta1': beta1,
  113. 'beta2': beta2,
  114. 'epsilon': epsilon,
  115. 'lamb_weight_decay': lamb_weight_decay,
  116. 'exclude_from_weight_decay_fn': exclude_from_weight_decay_fn,
  117. }
  118. self._master_weights = {}
  119. self._used_master_weights = {}
  120. # TODO(zengjinle): expose API as soon as possible
  121. self._multi_precision = multi_precision
  122. self.always_adapt = always_adapt
  123. def _get_parameter(self, name, scope=None):
  124. if scope is None:
  125. scope = global_scope()
  126. p_t = scope.find_var(name).get_tensor()
  127. master_name = self._used_master_weights.get(name)
  128. if master_name is not None:
  129. master_p_t = scope.find_var(master_name).get_tensor()
  130. assert master_p_t._dtype() != p_t._dtype()
  131. assert master_p_t.shape() == p_t.shape()
  132. else:
  133. master_p_t = None
  134. return p_t, master_p_t
  135. def _create_accumulators(self, block, parameters):
  136. assert isinstance(block, framework.Block)
  137. if isinstance(parameters, dict):
  138. parameters = self._update_param_group(parameters)
  139. # Create accumulator tensors for first and second moments
  140. for p in parameters:
  141. if p.name in self._already_create_accumulator:
  142. continue
  143. if self._multi_precision and self._is_dtype_fp16_or_bf16(p.dtype):
  144. master_p = self._create_master_weight(p)
  145. self._add_moments_pows(master_p)
  146. self._already_create_accumulator.add(p.name)
  147. else:
  148. self._add_moments_pows(p)
  149. self._already_create_accumulator.add(p.name)
  150. def _add_moments_pows(self, p):
  151. acc_dtype = p.dtype
  152. if self._is_dtype_fp16_or_bf16(acc_dtype):
  153. acc_dtype = core.VarDesc.VarType.FP32
  154. self._add_accumulator(self._moment1_acc_str, p, dtype=acc_dtype)
  155. self._add_accumulator(self._moment2_acc_str, p, dtype=acc_dtype)
  156. self._add_accumulator(
  157. name=self._beta1_pow_acc_str,
  158. param=p,
  159. dtype=acc_dtype,
  160. fill_value=0.9
  161. if isinstance(self._beta1, Variable)
  162. else self._beta1,
  163. shape=[1],
  164. type=core.VarDesc.VarType.LOD_TENSOR,
  165. device='cpu',
  166. )
  167. self._add_accumulator(
  168. name=self._beta2_pow_acc_str,
  169. param=p,
  170. dtype=acc_dtype,
  171. fill_value=0.999
  172. if isinstance(self._beta2, Variable)
  173. else self._beta2,
  174. shape=[1],
  175. type=core.VarDesc.VarType.LOD_TENSOR,
  176. device='cpu',
  177. )
  178. def _append_optimize_op(self, block, param_and_grad):
  179. assert isinstance(block, framework.Block)
  180. if isinstance(param_and_grad, dict):
  181. param_and_grad = self._update_param_group(param_and_grad)
  182. block.program._use_lamb = True
  183. moment1 = self._get_accumulator_master(
  184. self._moment1_acc_str, param_and_grad[0]
  185. )
  186. moment2 = self._get_accumulator_master(
  187. self._moment2_acc_str, param_and_grad[0]
  188. )
  189. beta1_pow_acc = self._get_accumulator_master(
  190. self._beta1_pow_acc_str, param_and_grad[0]
  191. )
  192. beta2_pow_acc = self._get_accumulator_master(
  193. self._beta2_pow_acc_str, param_and_grad[0]
  194. )
  195. if (
  196. self._exclude_from_weight_decay_fn is not None
  197. and self._exclude_from_weight_decay_fn(param_and_grad[0])
  198. ):
  199. weight_decay = 0.0
  200. else:
  201. weight_decay = self._lamb_weight_decay
  202. lr = self._create_param_lr(param_and_grad)
  203. find_master = self._multi_precision and self._is_dtype_fp16_or_bf16(
  204. param_and_grad[0].dtype
  205. )
  206. p_name = param_and_grad[0].name
  207. if find_master:
  208. master_weight = self._master_weights[p_name]
  209. self._used_master_weights[p_name] = master_weight.name
  210. else:
  211. master_weight = None
  212. if framework.in_dynamic_or_pir_mode():
  213. _C_ops.lamb_(
  214. param_and_grad[0],
  215. param_and_grad[1],
  216. lr,
  217. moment1,
  218. moment2,
  219. beta1_pow_acc,
  220. beta2_pow_acc,
  221. master_weight,
  222. None,
  223. weight_decay,
  224. self._beta1,
  225. self._beta2,
  226. self._epsilon,
  227. self.always_adapt,
  228. find_master,
  229. )
  230. return None
  231. else:
  232. # create the lamb optimize op
  233. inputs = {
  234. "Param": param_and_grad[0],
  235. "Grad": param_and_grad[1],
  236. "LearningRate": lr,
  237. "Moment1": moment1,
  238. "Moment2": moment2,
  239. "Beta1Pow": beta1_pow_acc,
  240. "Beta2Pow": beta2_pow_acc,
  241. }
  242. outputs = {
  243. "ParamOut": param_and_grad[0],
  244. "Moment1Out": moment1,
  245. "Moment2Out": moment2,
  246. "Beta1PowOut": beta1_pow_acc,
  247. "Beta2PowOut": beta2_pow_acc,
  248. }
  249. attrs = {
  250. "beta1": self._beta1,
  251. "beta2": self._beta2,
  252. "epsilon": self._epsilon,
  253. "weight_decay": weight_decay,
  254. "always_adapt": self.always_adapt,
  255. "multi_precision": find_master,
  256. }
  257. if find_master:
  258. inputs["MasterParam"] = master_weight
  259. outputs["MasterParamOut"] = master_weight
  260. found_inf = self._get_auxiliary_var('found_inf')
  261. if found_inf:
  262. inputs["SkipUpdate"] = found_inf
  263. lamb_op = block.append_op(
  264. type=self.type,
  265. inputs=inputs,
  266. outputs=outputs,
  267. attrs=attrs,
  268. stop_gradient=True,
  269. )
  270. return lamb_op
  271. def _update_param_group(self, parameters):
  272. self._beta1 = parameters.get('beta1', self._default_dict['beta1'])
  273. self._beta2 = parameters.get('beta2', self._default_dict['beta2'])
  274. self._epsilon = parameters.get('epsilon', self._default_dict['epsilon'])
  275. self._lamb_weight_decay = parameters.get(
  276. 'lamb_weight_decay', self._default_dict['lamb_weight_decay']
  277. )
  278. self._exclude_from_weight_decay_fn = parameters.get(
  279. 'exclude_from_weight_decay_fn',
  280. self._default_dict['exclude_from_weight_decay_fn'],
  281. )
  282. parameters = parameters.get('params')
  283. return parameters