nadam.py 13 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338
  1. # Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. import warnings
  15. from paddle import _C_ops
  16. from paddle.base.libpaddle import DataType
  17. from ..base import core, framework
  18. from ..base.framework import (
  19. in_dynamic_or_pir_mode,
  20. in_pir_mode,
  21. )
  22. from .optimizer import Optimizer
# Empty export list: ``from <this module> import *`` brings in no names.
__all__ = []
  24. class NAdam(Optimizer):
  25. r"""
  26. The NAdam optimizer is implemented based on the Adam Optimization
  27. in paper `Incorporating Nesterov Momentum into Adam <https://openreview.net/forum?id=OM0jvwB8jIp57ZJjtNEZ>`_.
  28. The main improvement is to combine the advantages of Nesterov momentum and Adam adaptive learning rate.
  29. .. math::
  30. \begin{aligned}
  31. &\textbf{for} \: t=1 \: \textbf{to} \: \ldots \: \textbf{do} \\
  32. &\hspace{5mm} \mu_t \leftarrow \beta_1 \big(1 - \frac{1}{2} \rho ^{t \psi} \big) \\
  33. &\hspace{5mm} \mu_{t+1} \leftarrow \beta_1 \big(1 - \frac{1}{2} 0.96 ^{(t+1)\psi}\big)\\
  34. &\hspace{5mm}m_t \leftarrow \beta_1 m_{t-1} + (1 - \beta_1) g_t \\
  35. &\hspace{5mm}v_t \leftarrow \beta_2 v_{t-1} + (1-\beta_2) g^2_t \\
  36. &\hspace{5mm}\widehat{m_t} \leftarrow \mu_{t+1} m_t/(1-\prod_{i=1}^{t+1}\mu_i) + (1-\mu_t) g_t /(1-\prod_{i=1}^{t} \mu_{i}) \\
  37. &\hspace{5mm}\widehat{v_t} \leftarrow v_t/\big(1-\beta_2^t \big) \\
  38. &\hspace{5mm}\theta_t \leftarrow \theta_t - \gamma \widehat{m_t}/
  39. \big(\sqrt{\widehat{v_t}} + \epsilon \big) \\
  40. &\hspace{0mm} \text{ with: } \gamma_t \text{ (lr)}, \: \beta_1,\beta_2 \text{ (betas)}, \: \theta_0 \text{ (params)}, \: f(\theta) \text{ (objective)} \\
  41. &\hspace{0mm} \: \lambda \text{ (weight decay)}, \:\psi \text{ (momentum decay)} \\
  42. \end{aligned}
  43. Args:
  44. learning_rate (float|LRScheduler, optional): The learning rate used to update ``Parameter``.
  45. It can be a float value or a LRScheduler. The default value is 0.002.
  46. parameters (list|tuple, optional): List/Tuple of ``Tensor`` names to update to minimize ``loss``.
  47. This parameter is required in dygraph mode. And you can specify different options for
  48. different parameter groups such as the learning rate, weight decay, etc,
  49. then the parameters are list of dict. Note that the learning_rate in parameter groups
  50. represents the scale of base learning_rate.
  51. The default value is None in static graph mode, at this time all parameters will be updated.
  52. beta1 (float|Tensor, optional): The exponential decay rate for the 1st moment estimates.
  53. It should be a float number or a 0-D Tensor with shape [] and data type as float32.
  54. The default value is 0.9.
  55. beta2 (float|Tensor, optional): The exponential decay rate for the 2nd moment estimates.
  56. It should be a float number or a 0-D Tensor with shape [] and data type as float32.
  57. The default value is 0.999.
  58. epsilon (float, optional): A small float value for numerical stability.
  59. The default value is 1e-08.
  60. weight_decay (float|Tensor, optional): The weight decay coefficient, it can be float or Tensor.
  61. Default None, meaning there is no regularization.
  62. momentum_decay (float, optional): momentum momentum_decay. The default value is 0.004.
  63. grad_clip (GradientClipBase, optional): Gradient clipping strategy, it's an instance of
  64. some derived class of ``GradientClipBase`` . There are three clipping strategies
  65. ( :ref:`api_paddle_nn_ClipGradByGlobalNorm` , :ref:`api_paddle_nn_ClipGradByNorm` ,
  66. :ref:`api_paddle_nn_ClipGradByValue` ). Default None, meaning there is no gradient clipping.
  67. name (str, optional): Normally there is no need for user to set this property.
  68. For more information, please refer to :ref:`api_guide_Name`.
  69. The default value is None.
  70. Notes:
  71. Currently, NAdam doesn't support sparse parameter optimization.
  72. Examples:
  73. .. code-block:: python
  74. >>> import paddle
  75. >>> inp = paddle.rand([10,10], dtype="float32")
  76. >>> linear = paddle.nn.Linear(10, 10)
  77. >>> out = linear(inp)
  78. >>> loss = paddle.mean(out)
  79. >>> nadam = paddle.optimizer.NAdam(learning_rate=0.1,
  80. ... parameters=linear.parameters())
  81. >>> out.backward()
  82. >>> nadam.step()
  83. >>> nadam.clear_grad()
  84. >>> # Note that the learning_rate of linear_2 is 0.01.
  85. >>> linear_1 = paddle.nn.Linear(10, 10)
  86. >>> linear_2 = paddle.nn.Linear(10, 10)
  87. >>> inp = paddle.uniform(shape=[10, 10], min=-0.1, max=0.1)
  88. >>> out = linear_1(inp)
  89. >>> out = linear_2(out)
  90. >>> loss = paddle.mean(out)
  91. >>> opt = paddle.optimizer.NAdam(
  92. ... learning_rate=0.1,
  93. ... parameters=[{
  94. ... 'params': linear_1.parameters()
  95. ... }, {
  96. ... 'params': linear_2.parameters(),
  97. ... 'weight_decay': 0.001,
  98. ... 'learning_rate': 0.1,
  99. ... 'beta1': 0.8
  100. ... }],
  101. ... weight_decay=0.01,
  102. ... beta1=0.9
  103. ... )
  104. >>> loss.backward()
  105. >>> opt.step()
  106. >>> opt.clear_grad()
  107. """
  108. _momentum_decay_pow_acc_str = "momentum_decay_pow"
  109. _beta2_pow_acc_str = "beta2_pow"
  110. _mu_product_acc_str = "mu_product"
  111. _moment1_acc_str = "moment1"
  112. _moment2_acc_str = "moment2"
  113. def __init__(
  114. self,
  115. learning_rate=0.002,
  116. beta1=0.9,
  117. beta2=0.999,
  118. epsilon=1.0e-8,
  119. momentum_decay=0.004,
  120. parameters=None,
  121. weight_decay=None,
  122. grad_clip=None,
  123. name=None,
  124. ):
  125. if isinstance(learning_rate, (float, int)) and not 0.0 <= learning_rate:
  126. raise ValueError(
  127. f"Invalid learning rate: {learning_rate}, expect learning_rate >= 0."
  128. )
  129. if not 0.0 <= epsilon:
  130. raise ValueError(
  131. f"Invalid epsilon value: {epsilon}, expect epsilon >= 0."
  132. )
  133. if not 0.0 <= beta1 < 1.0:
  134. raise ValueError(
  135. f"Invalid beta1: {beta1}, expect 0. <= beta1 < 1.0."
  136. )
  137. if not 0.0 <= beta2 < 1.0:
  138. raise ValueError(
  139. f"Invalid beta2: {beta2}, expect 0. <= beta2 < 1.0."
  140. )
  141. if not 0.0 <= momentum_decay:
  142. raise ValueError(
  143. f"Invalid momentum_decay value: {momentum_decay}, expect momentum_decay >= 0."
  144. )
  145. super().__init__(
  146. learning_rate=learning_rate,
  147. parameters=parameters,
  148. weight_decay=weight_decay,
  149. grad_clip=grad_clip,
  150. name=name,
  151. )
  152. self.type = "nadam"
  153. self._beta1 = beta1
  154. self._beta2 = beta2
  155. self._epsilon = epsilon
  156. self._momentum_decay = momentum_decay
  157. self._multi_precision = False
  158. self._master_weights = {}
  159. self._default_dict = {
  160. 'beta1': beta1,
  161. 'beta2': beta2,
  162. 'epsilon': epsilon,
  163. 'momentum_decay': momentum_decay,
  164. }
  165. def _add_moments_pows(self, p):
  166. acc_dtype = p.dtype
  167. if self._is_dtype_fp16_or_bf16(acc_dtype):
  168. acc_dtype = (
  169. DataType.FLOAT32 if in_pir_mode() else core.VarDesc.VarType.FP32
  170. )
  171. self._add_accumulator(
  172. name=self._momentum_decay_pow_acc_str,
  173. param=p,
  174. dtype=acc_dtype,
  175. fill_value=1.0,
  176. )
  177. self._add_accumulator(
  178. name=self._beta2_pow_acc_str,
  179. param=p,
  180. dtype=acc_dtype,
  181. fill_value=1.0,
  182. )
  183. self._add_accumulator(
  184. name=self._mu_product_acc_str,
  185. param=p,
  186. dtype=acc_dtype,
  187. fill_value=1.0,
  188. )
  189. self._add_accumulator(self._moment1_acc_str, p, dtype=acc_dtype)
  190. self._add_accumulator(self._moment2_acc_str, p, dtype=acc_dtype)
  191. def _create_accumulators(self, block, parameters):
  192. if not isinstance(block, framework.Block):
  193. raise TypeError("block is not instance of framework.Block.")
  194. if isinstance(parameters, dict):
  195. parameters = parameters.get('params')
  196. for p in parameters:
  197. if p.name in self._already_create_accumulator:
  198. continue
  199. if self._multi_precision and self._is_dtype_fp16_or_bf16(p.dtype):
  200. master_p = self._create_master_weight(p)
  201. self._add_moments_pows(master_p)
  202. self._already_create_accumulator.add(p.name)
  203. continue
  204. if (
  205. self._is_dtype_fp16_or_bf16(p.dtype)
  206. and not self._multi_precision
  207. ):
  208. warnings.warn(
  209. "Accumulating with FP16 in optimizer can lead to poor accuracy or slow convergence."
  210. "Consider using multi_precision=True option of the Lars optimizer."
  211. )
  212. self._add_moments_pows(p)
  213. self._already_create_accumulator.add(p.name)
  214. def _append_optimize_op(self, block, param_and_grad):
  215. if not isinstance(block, framework.Block):
  216. raise TypeError("block is not instance of framework.Block.")
  217. if isinstance(param_and_grad, dict):
  218. param_and_grad = self._update_param_group(param_and_grad)
  219. momentum_decay_pow_acc = self._get_accumulator_master(
  220. self._momentum_decay_pow_acc_str, param_and_grad[0]
  221. )
  222. beta2_pow_acc = self._get_accumulator_master(
  223. self._beta2_pow_acc_str, param_and_grad[0]
  224. )
  225. mu_product_acc = self._get_accumulator_master(
  226. self._mu_product_acc_str, param_and_grad[0]
  227. )
  228. moment1_acc = self._get_accumulator_master(
  229. self._moment1_acc_str, param_and_grad[0]
  230. )
  231. moment2_acc = self._get_accumulator_master(
  232. self._moment2_acc_str, param_and_grad[0]
  233. )
  234. find_master = self._multi_precision and self._is_dtype_fp16_or_bf16(
  235. param_and_grad[0].dtype
  236. )
  237. master_weight = (
  238. self._master_weights[param_and_grad[0].name]
  239. if find_master
  240. else None
  241. )
  242. if in_dynamic_or_pir_mode():
  243. _C_ops.nadam_(
  244. param_and_grad[0],
  245. param_and_grad[1],
  246. self._create_param_lr(param_and_grad),
  247. momentum_decay_pow_acc,
  248. beta2_pow_acc,
  249. mu_product_acc,
  250. moment1_acc,
  251. moment2_acc,
  252. master_weight,
  253. self._beta1,
  254. self._beta2,
  255. self._epsilon,
  256. self._momentum_decay,
  257. find_master,
  258. )
  259. return None
  260. else:
  261. inputs = {
  262. "param": param_and_grad[0],
  263. "grad": param_and_grad[1],
  264. "momentum_decay_pow": momentum_decay_pow_acc,
  265. "beta2_pow": beta2_pow_acc,
  266. "mu_product": mu_product_acc,
  267. "moment1": moment1_acc,
  268. "moment2": moment2_acc,
  269. "learning_rate": self._create_param_lr(param_and_grad),
  270. }
  271. outputs = {
  272. "param_out": param_and_grad[0],
  273. "momentum_decay_pow_out": momentum_decay_pow_acc,
  274. "beta2_pow_out": beta2_pow_acc,
  275. "mu_product_out": mu_product_acc,
  276. "moment1_out": moment1_acc,
  277. "moment2_out": moment2_acc,
  278. }
  279. if find_master:
  280. inputs["master_param"] = master_weight
  281. outputs["master_param_out"] = master_weight
  282. nadam_op = block.append_op(
  283. type=self.type,
  284. inputs=inputs,
  285. outputs=outputs,
  286. attrs={
  287. "epsilon": self._epsilon,
  288. "beta1": self._beta1,
  289. "beta2": self._beta2,
  290. "momentum_decay": self._momentum_decay,
  291. },
  292. stop_gradient=True,
  293. )
  294. return nadam_op
  295. def _update_param_group(self, parameters):
  296. self._epsilon = parameters.get('epsilon', self._default_dict['epsilon'])
  297. self._beta1 = parameters.get('beta1', self._default_dict['beta1'])
  298. self._beta2 = parameters.get('beta2', self._default_dict['beta2'])
  299. self._momentum_decay = parameters.get(
  300. 'momentum_decay', self._default_dict['momentum_decay']
  301. )
  302. parameters = parameters.get('params')
  303. return parameters