radam.py 13 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329
  1. # Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. import warnings
  15. from paddle import _C_ops
  16. from paddle.base.libpaddle import DataType
  17. from ..base import core, framework
  18. from ..base.framework import (
  19. in_dynamic_or_pir_mode,
  20. in_pir_mode,
  21. )
  22. from .optimizer import Optimizer
  23. __all__ = []
  24. class RAdam(Optimizer):
  25. r"""
  26. The RAdam optimizer is implemented based on the Adam Optimization
  27. in paper `On the Variance of the Adaptive Learning Rate and Beyond <https://arxiv.org/abs/1908.03265>`_.
  28. RAdam improved the initial stability of training by modifying Adam's momentum term.
  29. .. math::
  30. \begin{aligned}
  31. &\textbf{for} \: t=1 \: \textbf{to} \: \ldots \: \textbf{do} \\
  32. &\hspace{6mm}m_t \leftarrow \beta_1 m_{t-1} + (1 - \beta_1) g_t \\
  33. &\hspace{6mm}v_t \leftarrow \beta_2 v_{t-1} + (1-\beta_2) g^2_t \\
  34. &\hspace{6mm}\widehat{m_t} \leftarrow m_t/\big(1-\beta_1^t \big) \\
  35. &\hspace{6mm}\rho_t \leftarrow \rho_{\infty} -
  36. 2 t \beta^t_2 /\big(1-\beta_2^t \big) \\
  37. &\hspace{6mm}\textbf{if} \: \rho_t > 5 \\
  38. &\hspace{12mm} l_t \leftarrow \frac{\sqrt{ (1-\beta^t_2) }}{ \sqrt{v_t} +\epsilon } \\
  39. &\hspace{12mm} r_t \leftarrow
  40. \sqrt{\frac{(\rho_t-4)(\rho_t-2)\rho_{\infty}}{(\rho_{\infty}-4)(\rho_{\infty}-2) \rho_t}} \\
  41. &\hspace{12mm}\theta_t \leftarrow \theta_t - \gamma \widehat{m_t} r_t l_t \\
  42. &\hspace{6mm}\textbf{else} \\
  43. &\hspace{12mm}\theta_t \leftarrow \theta_t - \gamma \widehat{m_t} \\
  44. &\hspace{0mm} \text{ with: } \gamma_t \text{ (lr)}, \: \beta_1,\beta_2 \text{ (betas)}, \: \theta_t \text{ (params)} \\
  45. &\hspace{0mm} \rho_{\infty} \leftarrow 2/(1-\beta_2) -1
  46. \end{aligned}
  47. Args:
  48. learning_rate (float|LRScheduler, optional): The learning rate used to update ``Parameter``.
  49. It can be a float value or a LRScheduler. The default value is 0.001.
  50. parameters (list|tuple, optional): List/Tuple of ``Tensor`` names to update to minimize ``loss``.
  51. This parameter is required in dygraph mode. And you can specify different options for
  52. different parameter groups such as the learning rate, weight decay, etc,
  53. then the parameters are list of dict. Note that the learning_rate in parameter groups
  54. represents the scale of base learning_rate.
  55. The default value is None in static graph mode, at this time all parameters will be updated.
  56. beta1 (float|Tensor, optional): The exponential decay rate for the 1st moment estimates.
  57. It should be a float number or a 0-D Tensor with shape [] and data type as float32.
  58. The default value is 0.9.
  59. beta2 (float|Tensor, optional): The exponential decay rate for the 2nd moment estimates.
  60. It should be a float number or a 0-D Tensor with shape [] and data type as float32.
  61. The default value is 0.999.
  62. epsilon (float, optional): A small float value for numerical stability.
  63. The default value is 1e-08.
  64. weight_decay (float|Tensor, optional): The weight decay coefficient, it can be float or Tensor.
  65. Default None, meaning there is no regularization.
  66. grad_clip (GradientClipBase, optional): Gradient clipping strategy, it's an instance of
  67. some derived class of ``GradientClipBase`` . There are three clipping strategies
  68. ( :ref:`api_paddle_nn_ClipGradByGlobalNorm` , :ref:`api_paddle_nn_ClipGradByNorm` ,
  69. :ref:`api_paddle_nn_ClipGradByValue` ). Default None, meaning there is no gradient clipping.
  70. name (str, optional): Normally there is no need for user to set this property.
  71. For more information, please refer to :ref:`api_guide_Name`.
  72. The default value is None.
  73. Note:
  74. Currently, RAdam doesn't support sparse parameter optimization.
  75. Examples:
  76. .. code-block:: python
  77. >>> import paddle
  78. >>> inp = paddle.rand([10,10], dtype="float32")
  79. >>> linear = paddle.nn.Linear(10, 10)
  80. >>> out = linear(inp)
  81. >>> loss = paddle.mean(out)
  82. >>> radam = paddle.optimizer.RAdam(learning_rate=0.1,
  83. ... parameters=linear.parameters())
  84. >>> out.backward()
  85. >>> radam.step()
  86. >>> radam.clear_grad()
  87. >>> # Note that the learning_rate of linear_2 is 0.01.
  88. >>> linear_1 = paddle.nn.Linear(10, 10)
  89. >>> linear_2 = paddle.nn.Linear(10, 10)
  90. >>> inp = paddle.uniform(shape=[10, 10], min=-0.1, max=0.1)
  91. >>> out = linear_1(inp)
  92. >>> out = linear_2(out)
  93. >>> loss = paddle.mean(out)
  94. >>> opt = paddle.optimizer.RAdam(
  95. ... learning_rate=0.1,
  96. ... parameters=[{
  97. ... 'params': linear_1.parameters()
  98. ... }, {
  99. ... 'params': linear_2.parameters(),
  100. ... 'weight_decay': 0.001,
  101. ... 'learning_rate': 0.1,
  102. ... 'beta1': 0.8
  103. ... }],
  104. ... weight_decay=0.01,
  105. ... beta1=0.9
  106. ... )
  107. >>> loss.backward()
  108. >>> opt.step()
  109. >>> opt.clear_grad()
  110. """
  111. _beta1_pow_acc_str = "beta1_pow"
  112. _beta2_pow_acc_str = "beta2_pow"
  113. _rho_acc_str = "rho"
  114. _moment1_acc_str = "moment1"
  115. _moment2_acc_str = "moment2"
  116. def __init__(
  117. self,
  118. learning_rate=0.001,
  119. beta1=0.9,
  120. beta2=0.999,
  121. epsilon=1.0e-8,
  122. parameters=None,
  123. weight_decay=None,
  124. grad_clip=None,
  125. name=None,
  126. ):
  127. if isinstance(learning_rate, (float, int)) and not 0.0 <= learning_rate:
  128. raise ValueError(
  129. f"Invalid learning rate: {learning_rate}, expect learning_rate >= 0."
  130. )
  131. if not 0.0 <= epsilon:
  132. raise ValueError(
  133. f"Invalid epsilon value: {epsilon}, expect epsilon >= 0."
  134. )
  135. if not 0.0 <= beta1 < 1.0:
  136. raise ValueError(
  137. f"Invalid beta1: {beta1}, expect 0. <= beta1 < 1.0."
  138. )
  139. if not 0.0 <= beta2 < 1.0:
  140. raise ValueError(
  141. f"Invalid beta2: {beta2}, expect 0. <= beta2 < 1.0."
  142. )
  143. super().__init__(
  144. learning_rate=learning_rate,
  145. parameters=parameters,
  146. weight_decay=weight_decay,
  147. grad_clip=grad_clip,
  148. name=name,
  149. )
  150. self.type = "radam"
  151. self._beta1 = beta1
  152. self._beta2 = beta2
  153. self._epsilon = epsilon
  154. self._multi_precision = False
  155. self._master_weights = {}
  156. self._default_dict = {
  157. 'beta1': beta1,
  158. 'beta2': beta2,
  159. 'epsilon': epsilon,
  160. }
  161. def _add_moments_pows(self, p):
  162. acc_dtype = p.dtype
  163. if self._is_dtype_fp16_or_bf16(acc_dtype):
  164. acc_dtype = (
  165. DataType.FLOAT32 if in_pir_mode() else core.VarDesc.VarType.FP32
  166. )
  167. self._add_accumulator(
  168. name=self._beta1_pow_acc_str,
  169. param=p,
  170. dtype=acc_dtype,
  171. fill_value=1.0,
  172. )
  173. self._add_accumulator(
  174. name=self._beta2_pow_acc_str,
  175. param=p,
  176. dtype=acc_dtype,
  177. fill_value=1.0,
  178. )
  179. self._add_accumulator(
  180. name=self._rho_acc_str,
  181. param=p,
  182. dtype=acc_dtype,
  183. fill_value=1.0,
  184. )
  185. self._add_accumulator(self._moment1_acc_str, p, dtype=acc_dtype)
  186. self._add_accumulator(self._moment2_acc_str, p, dtype=acc_dtype)
  187. def _create_accumulators(self, block, parameters):
  188. if not isinstance(block, framework.Block):
  189. raise TypeError("block is not instance of framework.Block.")
  190. if isinstance(parameters, dict):
  191. parameters = parameters.get('params')
  192. for p in parameters:
  193. if p.name in self._already_create_accumulator:
  194. continue
  195. if self._multi_precision and self._is_dtype_fp16_or_bf16(p.dtype):
  196. master_p = self._create_master_weight(p)
  197. self._add_moments_pows(master_p)
  198. self._already_create_accumulator.add(p.name)
  199. continue
  200. if (
  201. self._is_dtype_fp16_or_bf16(p.dtype)
  202. and not self._multi_precision
  203. ):
  204. warnings.warn(
  205. "Accumulating with FP16 in optimizer can lead to poor accuracy or slow convergence."
  206. "Consider using multi_precision=True option of the Lars optimizer."
  207. )
  208. self._add_moments_pows(p)
  209. self._already_create_accumulator.add(p.name)
  210. def _append_optimize_op(self, block, param_and_grad):
  211. if not isinstance(block, framework.Block):
  212. raise TypeError("block is not instance of framework.Block.")
  213. if isinstance(param_and_grad, dict):
  214. param_and_grad = self._update_param_group(param_and_grad)
  215. beta1_pow_acc = self._get_accumulator_master(
  216. self._beta1_pow_acc_str, param_and_grad[0]
  217. )
  218. beta2_pow_acc = self._get_accumulator_master(
  219. self._beta2_pow_acc_str, param_and_grad[0]
  220. )
  221. rho_acc = self._get_accumulator_master(
  222. self._rho_acc_str, param_and_grad[0]
  223. )
  224. moment1_acc = self._get_accumulator_master(
  225. self._moment1_acc_str, param_and_grad[0]
  226. )
  227. moment2_acc = self._get_accumulator_master(
  228. self._moment2_acc_str, param_and_grad[0]
  229. )
  230. find_master = self._multi_precision and self._is_dtype_fp16_or_bf16(
  231. param_and_grad[0].dtype
  232. )
  233. master_weight = (
  234. self._master_weights[param_and_grad[0].name]
  235. if find_master
  236. else None
  237. )
  238. if in_dynamic_or_pir_mode():
  239. _C_ops.radam_(
  240. param_and_grad[0],
  241. param_and_grad[1],
  242. self._create_param_lr(param_and_grad),
  243. beta1_pow_acc,
  244. beta2_pow_acc,
  245. rho_acc,
  246. moment1_acc,
  247. moment2_acc,
  248. master_weight,
  249. self._beta1,
  250. self._beta2,
  251. self._epsilon,
  252. find_master,
  253. )
  254. return None
  255. else:
  256. inputs = {
  257. "param": param_and_grad[0],
  258. "grad": param_and_grad[1],
  259. "beta1_pow": beta1_pow_acc,
  260. "beta2_pow": beta2_pow_acc,
  261. "rho": rho_acc,
  262. "moment1": moment1_acc,
  263. "moment2": moment2_acc,
  264. "learning_rate": self._create_param_lr(param_and_grad),
  265. }
  266. outputs = {
  267. "param_out": param_and_grad[0],
  268. "beta1_pow_out": beta1_pow_acc,
  269. "beta2_pow_out": beta2_pow_acc,
  270. "rho_out": rho_acc,
  271. "moment1_out": moment1_acc,
  272. "moment2_out": moment2_acc,
  273. }
  274. if find_master:
  275. inputs["master_param"] = master_weight
  276. outputs["master_param_out"] = master_weight
  277. radam_op = block.append_op(
  278. type=self.type,
  279. inputs=inputs,
  280. outputs=outputs,
  281. attrs={
  282. "epsilon": self._epsilon,
  283. "beta1": self._beta1,
  284. "beta2": self._beta2,
  285. },
  286. stop_gradient=True,
  287. )
  288. return radam_op
  289. def _update_param_group(self, parameters):
  290. self._epsilon = parameters.get('epsilon', self._default_dict['epsilon'])
  291. self._beta1 = parameters.get('beta1', self._default_dict['beta1'])
  292. self._beta2 = parameters.get('beta2', self._default_dict['beta2'])
  293. parameters = parameters.get('params')
  294. return parameters