| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322 |
- # Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
- #
- # Licensed under the Apache License, Version 2.0 (the "License");
- # you may not use this file except in compliance with the License.
- # You may obtain a copy of the License at
- #
- # http://www.apache.org/licenses/LICENSE-2.0
- #
- # Unless required by applicable law or agreed to in writing, software
- # distributed under the License is distributed on an "AS IS" BASIS,
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- # See the License for the specific language governing permissions and
- # limitations under the License.
- from paddle import _C_ops
- from paddle.base.executor import global_scope
- from ..base import core, framework
- from ..base.framework import Variable
- from .optimizer import Optimizer
- __all__ = []
class Lamb(Optimizer):
    r"""
    LAMB (Layer-wise Adaptive Moments optimizer for Batching training) Optimizer.

    LAMB Optimizer is designed to scale up the batch size of training without losing
    accuracy, which supports adaptive element-wise updating and accurate layer-wise
    correction. For more information, please refer to `Large Batch Optimization for
    Deep Learning: Training BERT in 76 minutes <https://arxiv.org/abs/1904.00962>`_ .

    The updating of parameters follows:

    ..  math::

        m_t &= \beta_1 m_{t - 1}+ (1 - \beta_1)g_t

        v_t &= \beta_2 v_{t - 1} + (1 - \beta_2)g_t^2

        \hat{m}_t &= \frac{m_t}{1 - \beta_1^t}

        \hat{v}_t &= \frac{v_t}{1 - \beta_2^t}

        r_t &= \frac{\hat{m}_t}{\sqrt{\hat{v}_t}+\epsilon}

        w_t &= w_{t-1} -\eta_t \frac{\left \| w_{t-1}\right \|}{\left \| r_t + \lambda w_{t-1}\right \|} (r_t + \lambda w_{t-1})

    where :math:`m` is the 1st moment, :math:`v` the 2nd moment, :math:`\hat{m}` and
    :math:`\hat{v}` their bias-corrected estimates, :math:`\eta` the learning rate and
    :math:`\lambda` the LAMB weight decay rate.

    Args:
        learning_rate (float|Variable, optional): the learning rate used to update parameters. \
            Can be a float value or a Variable with data type float32. Default 0.001.
        lamb_weight_decay (float, optional): The LAMB weight decay rate. Default 0.01.
            The generic ``weight_decay`` of the base optimizer is always disabled
            (passed as None); decay is applied only through this argument.
        beta1 (float, optional): The exponential decay rate for the 1st moment estimates.
            Default 0.9.
        beta2 (float, optional): The exponential decay rate for the 2nd moment estimates.
            Default 0.999.
        epsilon (float, optional): A small float value for numerical stability. Default 1e-6.
        parameters (Iterable, optional): Iterable of ``Variable`` objects to update to minimize ``loss``. \
            This parameter is required in dygraph mode. And you can specify different options for \
            different parameter groups such as the learning rate, weight decay, etc, \
            then the parameters are list of dict. Note that the learning_rate in parameter groups \
            represents the scale of base learning_rate. \
            The default value is None in static graph mode, at this time all parameters will be updated.
        grad_clip (GradientClipBase, optional): Gradient clipping strategy, it's an instance of
            some derived class of ``GradientClipBase`` . There are three clipping strategies
            ( :ref:`api_paddle_base_clip_ClipGradByGlobalNorm` , :ref:`api_paddle_base_clip_ClipGradByNorm` ,
            :ref:`api_paddle_base_clip_ClipGradByValue` ). If you want better convergence, it is recommended
            to use :ref:`api_paddle_base_clip_ClipGradByGlobalNorm` . Default None, meaning there is no gradient clipping.
        exclude_from_weight_decay_fn (function, optional): Called with a parameter; when it
            returns True, weight decay is skipped (set to 0.0) for that parameter.
        multi_precision (bool, optional): If True, FP16/BF16 parameters get a separate master
            weight that holds the optimizer update, and their moment accumulators are
            kept in FP32. Default False.
        always_adapt (bool, optional): whether to use Layer-wise LR adaptation. By default, skip adaptation on parameters that are
            excluded from weight decay, unless always_adapt == True, then always enable LR adaptation.
        name(str|None): For detailed information, please refer to
            :ref:`api_guide_Name` . Usually name is no need to set and None by default.

    Examples:
        .. code-block:: python

            >>> import paddle

            >>> inp = paddle.uniform(shape=[10, 10], dtype='float32', min=-0.1, max=0.1)
            >>> linear = paddle.nn.Linear(10, 10)
            >>> out = linear(inp)
            >>> loss = paddle.mean(out)
            >>> lamb = paddle.optimizer.Lamb(learning_rate=0.002, parameters=linear.parameters(), lamb_weight_decay=0.01)
            >>> loss.backward()
            >>> lamb.step()
            >>> lamb.clear_grad()

    """

    _moment1_acc_str = "moment1"
    _moment2_acc_str = "moment2"
    _beta1_pow_acc_str = "beta1_pow_acc"
    _beta2_pow_acc_str = "beta2_pow_acc"

    def __init__(
        self,
        learning_rate=0.001,
        lamb_weight_decay=0.01,
        beta1=0.9,
        beta2=0.999,
        epsilon=1e-6,
        parameters=None,
        grad_clip=None,
        exclude_from_weight_decay_fn=None,
        multi_precision=False,
        always_adapt=False,
        name=None,
    ):
        assert learning_rate is not None
        assert beta1 is not None
        assert beta2 is not None
        assert epsilon is not None
        # LAMB applies its own decay (lamb_weight_decay), so the base
        # optimizer's generic weight_decay is always disabled.
        super().__init__(
            learning_rate=learning_rate,
            parameters=parameters,
            weight_decay=None,
            grad_clip=grad_clip,
            name=name,
        )
        self.type = "lamb"
        self._beta1 = beta1
        self._beta2 = beta2
        self._epsilon = epsilon
        self._lamb_weight_decay = lamb_weight_decay
        self._exclude_from_weight_decay_fn = exclude_from_weight_decay_fn
        # Fallback values used by _update_param_group when a parameter group
        # does not override a hyper-parameter.
        self._default_dict = {
            'beta1': beta1,
            'beta2': beta2,
            'epsilon': epsilon,
            'lamb_weight_decay': lamb_weight_decay,
            'exclude_from_weight_decay_fn': exclude_from_weight_decay_fn,
        }
        self._master_weights = {}
        # Maps parameter name -> master-weight name for parameters whose
        # update op was built with a master weight (read by _get_parameter).
        self._used_master_weights = {}
        # TODO(zengjinle): expose API as soon as possible
        self._multi_precision = multi_precision
        self.always_adapt = always_adapt

    def _get_parameter(self, name, scope=None):
        """Look up a parameter tensor, and its master weight if one was used.

        Args:
            name (str): Parameter variable name.
            scope: Scope to search; defaults to ``global_scope()``.

        Returns:
            tuple: ``(param_tensor, master_tensor)`` where ``master_tensor`` is
            None unless a master weight was recorded for ``name`` by
            ``_append_optimize_op``.
        """
        if scope is None:
            scope = global_scope()

        p_t = scope.find_var(name).get_tensor()

        master_name = self._used_master_weights.get(name)
        if master_name is None:
            return p_t, None

        master_p_t = scope.find_var(master_name).get_tensor()
        # The master copy must differ in dtype but match the parameter's shape.
        assert master_p_t._dtype() != p_t._dtype()
        assert master_p_t.shape() == p_t.shape()
        return p_t, master_p_t

    def _create_accumulators(self, block, parameters):
        """Create moment and beta-power accumulators for each parameter once.

        When multi-precision is enabled and the parameter is FP16/BF16, the
        accumulators are attached to a newly created master weight instead of
        the low-precision parameter.
        """
        assert isinstance(block, framework.Block)
        if isinstance(parameters, dict):
            parameters = self._update_param_group(parameters)

        for p in parameters:
            if p.name in self._already_create_accumulator:
                continue
            if self._multi_precision and self._is_dtype_fp16_or_bf16(p.dtype):
                self._add_moments_pows(self._create_master_weight(p))
            else:
                self._add_moments_pows(p)
            self._already_create_accumulator.add(p.name)

    def _add_moments_pows(self, p):
        """Register moment1/moment2 and beta1/beta2 power accumulators for ``p``.

        FP16/BF16 inputs get FP32 accumulators. The beta-power accumulators are
        single-element tensors placed on CPU and seeded with beta (i.e. beta^1).
        """
        acc_dtype = p.dtype
        if self._is_dtype_fp16_or_bf16(acc_dtype):
            acc_dtype = core.VarDesc.VarType.FP32
        self._add_accumulator(self._moment1_acc_str, p, dtype=acc_dtype)
        self._add_accumulator(self._moment2_acc_str, p, dtype=acc_dtype)
        # NOTE(review): when beta is a Variable, its value is not a Python
        # number here, so the power accumulator is seeded with the default
        # constant (0.9 / 0.999) instead — confirm this is intended.
        self._add_accumulator(
            name=self._beta1_pow_acc_str,
            param=p,
            dtype=acc_dtype,
            fill_value=0.9
            if isinstance(self._beta1, Variable)
            else self._beta1,
            shape=[1],
            type=core.VarDesc.VarType.LOD_TENSOR,
            device='cpu',
        )
        self._add_accumulator(
            name=self._beta2_pow_acc_str,
            param=p,
            dtype=acc_dtype,
            fill_value=0.999
            if isinstance(self._beta2, Variable)
            else self._beta2,
            shape=[1],
            type=core.VarDesc.VarType.LOD_TENSOR,
            device='cpu',
        )

    def _append_optimize_op(self, block, param_and_grad):
        """Emit the LAMB update for one ``(param, grad)`` pair.

        In dynamic or PIR mode the in-place ``_C_ops.lamb_`` kernel is called
        directly and None is returned; in legacy static mode a ``lamb`` op is
        appended to ``block`` and returned.
        """
        assert isinstance(block, framework.Block)
        if isinstance(param_and_grad, dict):
            param_and_grad = self._update_param_group(param_and_grad)

        # NOTE(review): flag read elsewhere in the framework; meaning not visible here.
        block.program._use_lamb = True

        param = param_and_grad[0]
        grad = param_and_grad[1]

        moment1 = self._get_accumulator_master(self._moment1_acc_str, param)
        moment2 = self._get_accumulator_master(self._moment2_acc_str, param)
        beta1_pow_acc = self._get_accumulator_master(
            self._beta1_pow_acc_str, param
        )
        beta2_pow_acc = self._get_accumulator_master(
            self._beta2_pow_acc_str, param
        )

        if (
            self._exclude_from_weight_decay_fn is not None
            and self._exclude_from_weight_decay_fn(param)
        ):
            weight_decay = 0.0
        else:
            weight_decay = self._lamb_weight_decay

        lr = self._create_param_lr(param_and_grad)

        find_master = self._multi_precision and self._is_dtype_fp16_or_bf16(
            param.dtype
        )
        p_name = param.name
        if find_master:
            master_weight = self._master_weights[p_name]
            # Remember the pairing so _get_parameter can return the master copy.
            self._used_master_weights[p_name] = master_weight.name
        else:
            master_weight = None

        if framework.in_dynamic_or_pir_mode():
            _C_ops.lamb_(
                param,
                grad,
                lr,
                moment1,
                moment2,
                beta1_pow_acc,
                beta2_pow_acc,
                master_weight,
                # Presumably the skip-update (found_inf) input, which the
                # static path wires as "SkipUpdate"; left unset here.
                None,
                weight_decay,
                self._beta1,
                self._beta2,
                self._epsilon,
                self.always_adapt,
                find_master,
            )
            return None

        # Legacy static graph: build the lamb op explicitly.
        inputs = {
            "Param": param,
            "Grad": grad,
            "LearningRate": lr,
            "Moment1": moment1,
            "Moment2": moment2,
            "Beta1Pow": beta1_pow_acc,
            "Beta2Pow": beta2_pow_acc,
        }
        # Outputs alias the inputs: the op updates state in place.
        outputs = {
            "ParamOut": param,
            "Moment1Out": moment1,
            "Moment2Out": moment2,
            "Beta1PowOut": beta1_pow_acc,
            "Beta2PowOut": beta2_pow_acc,
        }
        attrs = {
            "beta1": self._beta1,
            "beta2": self._beta2,
            "epsilon": self._epsilon,
            "weight_decay": weight_decay,
            "always_adapt": self.always_adapt,
            "multi_precision": find_master,
        }

        if find_master:
            inputs["MasterParam"] = master_weight
            outputs["MasterParamOut"] = master_weight

        # Only wire SkipUpdate when an auxiliary found_inf var was registered;
        # test against None explicitly rather than the truthiness of a graph Variable.
        found_inf = self._get_auxiliary_var('found_inf')
        if found_inf is not None:
            inputs["SkipUpdate"] = found_inf

        lamb_op = block.append_op(
            type=self.type,
            inputs=inputs,
            outputs=outputs,
            attrs=attrs,
            stop_gradient=True,
        )

        return lamb_op

    def _update_param_group(self, parameters):
        """Apply a parameter group's hyper-parameters and return its params.

        Missing keys fall back to the constructor values in ``_default_dict``.
        Note that this overwrites the optimizer's current ``_beta1``,
        ``_beta2``, ``_epsilon``, ``_lamb_weight_decay`` and
        ``_exclude_from_weight_decay_fn``, which persist until the next call.
        """
        defaults = self._default_dict
        self._beta1 = parameters.get('beta1', defaults['beta1'])
        self._beta2 = parameters.get('beta2', defaults['beta2'])
        self._epsilon = parameters.get('epsilon', defaults['epsilon'])
        self._lamb_weight_decay = parameters.get(
            'lamb_weight_decay', defaults['lamb_weight_decay']
        )
        self._exclude_from_weight_decay_fn = parameters.get(
            'exclude_from_weight_decay_fn',
            defaults['exclude_from_weight_decay_fn'],
        )
        return parameters.get('params')
|