# adamp.py (3.7 KB)
"""
AdamP Optimizer Implementation copied from https://github.com/clovaai/AdamP/blob/master/adamp/adamp.py

Paper: `Slowing Down the Weight Norm Increase in Momentum-based Optimizers` - https://arxiv.org/abs/2006.08217
Code: https://github.com/clovaai/AdamP

Copyright (c) 2020-present NAVER Corp.
MIT license
"""
import math

import torch
import torch.nn.functional as F
from torch.optim.optimizer import Optimizer
  12. def _channel_view(x) -> torch.Tensor:
  13. return x.reshape(x.size(0), -1)
  14. def _layer_view(x) -> torch.Tensor:
  15. return x.reshape(1, -1)
  16. def projection(p, grad, perturb, delta: float, wd_ratio: float, eps: float):
  17. wd = 1.
  18. expand_size = (-1,) + (1,) * (len(p.shape) - 1)
  19. for view_func in [_channel_view, _layer_view]:
  20. param_view = view_func(p)
  21. grad_view = view_func(grad)
  22. cosine_sim = F.cosine_similarity(grad_view, param_view, dim=1, eps=eps).abs_()
  23. # FIXME this is a problem for PyTorch XLA
  24. if cosine_sim.max() < delta / math.sqrt(param_view.size(1)):
  25. p_n = p / param_view.norm(p=2, dim=1).add_(eps).reshape(expand_size)
  26. perturb -= p_n * view_func(p_n * perturb).sum(dim=1).reshape(expand_size)
  27. wd = wd_ratio
  28. return perturb, wd
  29. return perturb, wd
  30. class AdamP(Optimizer):
  31. def __init__(
  32. self,
  33. params,
  34. lr=1e-3,
  35. betas=(0.9, 0.999),
  36. eps=1e-8,
  37. weight_decay=0,
  38. delta=0.1,
  39. wd_ratio=0.1,
  40. nesterov=False,
  41. ):
  42. defaults = dict(
  43. lr=lr,
  44. betas=betas,
  45. eps=eps,
  46. weight_decay=weight_decay,
  47. delta=delta,
  48. wd_ratio=wd_ratio,
  49. nesterov=nesterov,
  50. )
  51. super(AdamP, self).__init__(params, defaults)
  52. @torch.no_grad()
  53. def step(self, closure=None):
  54. loss = None
  55. if closure is not None:
  56. with torch.enable_grad():
  57. loss = closure()
  58. for group in self.param_groups:
  59. for p in group['params']:
  60. if p.grad is None:
  61. continue
  62. grad = p.grad
  63. beta1, beta2 = group['betas']
  64. nesterov = group['nesterov']
  65. state = self.state[p]
  66. # State initialization
  67. if len(state) == 0:
  68. state['step'] = 0
  69. state['exp_avg'] = torch.zeros_like(p)
  70. state['exp_avg_sq'] = torch.zeros_like(p)
  71. # Adam
  72. exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq']
  73. state['step'] += 1
  74. bias_correction1 = 1 - beta1 ** state['step']
  75. bias_correction2 = 1 - beta2 ** state['step']
  76. exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1)
  77. exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2)
  78. denom = (exp_avg_sq.sqrt() / math.sqrt(bias_correction2)).add_(group['eps'])
  79. step_size = group['lr'] / bias_correction1
  80. if nesterov:
  81. perturb = (beta1 * exp_avg + (1 - beta1) * grad) / denom
  82. else:
  83. perturb = exp_avg / denom
  84. # Projection
  85. wd_ratio = 1.
  86. if len(p.shape) > 1:
  87. perturb, wd_ratio = projection(p, grad, perturb, group['delta'], group['wd_ratio'], group['eps'])
  88. # Weight decay
  89. if group['weight_decay'] > 0:
  90. p.mul_(1. - group['lr'] * group['weight_decay'] * wd_ratio)
  91. # Step
  92. p.add_(perturb, alpha=-step_size)
  93. return loss