adamw.py 7.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181
  1. # mypy: allow-untyped-defs
  2. from typing import Optional, Union
  3. from torch import Tensor
  4. from .adam import Adam, adam
  5. from .optimizer import (
  6. _capturable_doc,
  7. _differentiable_doc,
  8. _foreach_doc,
  9. _fused_doc,
  10. _maximize_doc,
  11. _params_doc,
  12. ParamsT,
  13. )
# Public names exported by ``from <module> import *``: the AdamW optimizer
# class and its functional counterpart ``adamw``.
__all__ = ["AdamW", "adamw"]
  15. class AdamW(Adam):
  16. def __init__(
  17. self,
  18. params: ParamsT,
  19. lr: Union[float, Tensor] = 1e-3,
  20. betas: tuple[Union[float, Tensor], Union[float, Tensor]] = (0.9, 0.999),
  21. eps: float = 1e-8,
  22. weight_decay: float = 1e-2,
  23. amsgrad: bool = False,
  24. *,
  25. maximize: bool = False,
  26. foreach: Optional[bool] = None,
  27. capturable: bool = False,
  28. differentiable: bool = False,
  29. fused: Optional[bool] = None,
  30. ):
  31. super().__init__(
  32. params,
  33. lr,
  34. betas,
  35. eps,
  36. weight_decay,
  37. amsgrad,
  38. foreach=foreach,
  39. maximize=maximize,
  40. capturable=capturable,
  41. differentiable=differentiable,
  42. fused=fused,
  43. decoupled_weight_decay=True,
  44. )
  45. # Preserve decoupled_weight_decay from AdamW for backwards compatibility. The following
  46. # guarantees that decoupled_weight_decay will always be True for loading any state into
  47. # AdamW
  48. def __setstate__(self, state):
  49. super().__setstate__(state)
  50. for group in self.param_groups:
  51. group["decoupled_weight_decay"] = True
  52. AdamW.__doc__ = (
  53. r"""Implements AdamW algorithm, where weight decay does not accumulate in the momentum nor variance.
  54. .. math::
  55. \begin{aligned}
  56. &\rule{110mm}{0.4pt} \\
  57. &\textbf{input} : \gamma \text{(lr)}, \: \beta_1, \beta_2
  58. \text{(betas)}, \: \theta_0 \text{(params)}, \: f(\theta) \text{(objective)},
  59. \: \epsilon \text{ (epsilon)} \\
  60. &\hspace{13mm} \lambda \text{(weight decay)}, \: \textit{amsgrad},
  61. \: \textit{maximize} \\
  62. &\textbf{initialize} : m_0 \leftarrow 0 \text{ (first moment)}, v_0 \leftarrow 0
  63. \text{ ( second moment)}, \: v_0^{max}\leftarrow 0 \\[-1.ex]
  64. &\rule{110mm}{0.4pt} \\
  65. &\textbf{for} \: t=1 \: \textbf{to} \: \ldots \: \textbf{do} \\
  66. &\hspace{5mm}\textbf{if} \: \textit{maximize}: \\
  67. &\hspace{10mm}g_t \leftarrow -\nabla_{\theta} f_t (\theta_{t-1}) \\
  68. &\hspace{5mm}\textbf{else} \\
  69. &\hspace{10mm}g_t \leftarrow \nabla_{\theta} f_t (\theta_{t-1}) \\
  70. &\hspace{5mm} \theta_t \leftarrow \theta_{t-1} - \gamma \lambda \theta_{t-1} \\
  71. &\hspace{5mm}m_t \leftarrow \beta_1 m_{t-1} + (1 - \beta_1) g_t \\
  72. &\hspace{5mm}v_t \leftarrow \beta_2 v_{t-1} + (1-\beta_2) g^2_t \\
  73. &\hspace{5mm}\widehat{m_t} \leftarrow m_t/\big(1-\beta_1^t \big) \\
  74. &\hspace{5mm}\textbf{if} \: amsgrad \\
  75. &\hspace{10mm} v_t^{max} \leftarrow \mathrm{max}(v_{t-1}^{max},v_t) \\
  76. &\hspace{10mm}\widehat{v_t} \leftarrow v_t^{max}/\big(1-\beta_2^t \big) \\
  77. &\hspace{5mm}\textbf{else} \\
  78. &\hspace{10mm}\widehat{v_t} \leftarrow v_t/\big(1-\beta_2^t \big) \\
  79. &\hspace{5mm}\theta_t \leftarrow \theta_t - \gamma \widehat{m_t}/
  80. \big(\sqrt{\widehat{v_t}} + \epsilon \big) \\
  81. &\rule{110mm}{0.4pt} \\[-1.ex]
  82. &\bf{return} \: \theta_t \\[-1.ex]
  83. &\rule{110mm}{0.4pt} \\[-1.ex]
  84. \end{aligned}
  85. For further details regarding the algorithm we refer to `Decoupled Weight Decay Regularization`_.
  86. """
  87. + rf"""
  88. Args:
  89. {_params_doc}
  90. lr (float, Tensor, optional): learning rate (default: 1e-3). A tensor LR
  91. is not yet supported for all our implementations. Please use a float
  92. LR if you are not also specifying fused=True or capturable=True.
  93. betas (Tuple[float, float], optional): coefficients used for computing
  94. running averages of gradient and its square (default: (0.9, 0.999))
  95. eps (float, optional): term added to the denominator to improve
  96. numerical stability (default: 1e-8)
  97. weight_decay (float, optional): weight decay coefficient (default: 1e-2)
  98. amsgrad (bool, optional): whether to use the AMSGrad variant of this
  99. algorithm from the paper `On the Convergence of Adam and Beyond`_
  100. (default: False)
  101. {_maximize_doc}
  102. {_foreach_doc}
  103. {_capturable_doc}
  104. {_differentiable_doc}
  105. {_fused_doc}
  106. .. Note::
  107. A prototype implementation of Adam and AdamW for MPS supports `torch.float32` and `torch.float16`.
  108. .. _Decoupled Weight Decay Regularization:
  109. https://arxiv.org/abs/1711.05101
  110. .. _On the Convergence of Adam and Beyond:
  111. https://openreview.net/forum?id=ryQu7f-RZ
  112. """
  113. )
  114. # @_disable_dynamo_if_unsupported logic occurs in the decorator that's applied to F.adam
  115. def adamw(
  116. params: list[Tensor],
  117. grads: list[Tensor],
  118. exp_avgs: list[Tensor],
  119. exp_avg_sqs: list[Tensor],
  120. max_exp_avg_sqs: list[Tensor],
  121. state_steps: list[Tensor],
  122. # kwonly args with defaults are not supported by functions compiled with torchscript issue #70627
  123. # setting this as kwarg for now as functional API is compiled by torch/distributed/optim
  124. foreach: Optional[bool] = None,
  125. capturable: bool = False,
  126. differentiable: bool = False,
  127. fused: Optional[bool] = None,
  128. grad_scale: Optional[Tensor] = None,
  129. found_inf: Optional[Tensor] = None,
  130. has_complex: bool = False,
  131. *,
  132. amsgrad: bool,
  133. beta1: float,
  134. beta2: float,
  135. lr: Union[float, Tensor],
  136. weight_decay: float,
  137. eps: float,
  138. maximize: bool,
  139. ):
  140. r"""Functional API that performs AdamW algorithm computation.
  141. See :class:`~torch.optim.AdamW` for details.
  142. """
  143. adam(
  144. params,
  145. grads,
  146. exp_avgs,
  147. exp_avg_sqs,
  148. max_exp_avg_sqs,
  149. state_steps,
  150. foreach=foreach,
  151. capturable=capturable,
  152. differentiable=differentiable,
  153. fused=fused,
  154. grad_scale=grad_scale,
  155. found_inf=found_inf,
  156. has_complex=has_complex,
  157. amsgrad=amsgrad,
  158. beta1=beta1,
  159. beta2=beta2,
  160. lr=lr,
  161. weight_decay=weight_decay,
  162. eps=eps,
  163. maximize=maximize,
  164. decoupled_weight_decay=True,
  165. )