"""
SGDP Optimizer Implementation copied from https://github.com/clovaai/AdamP/blob/master/adamp/sgdp.py

Paper: `Slowing Down the Weight Norm Increase in Momentum-based Optimizers` - https://arxiv.org/abs/2006.08217
Code: https://github.com/clovaai/AdamP

Copyright (c) 2020-present NAVER Corp.
MIT license
"""
import math

import torch
import torch.nn.functional as F
from torch.optim.optimizer import Optimizer, required

from .adamp import projection


  13. class SGDP(Optimizer):
  14. def __init__(
  15. self,
  16. params,
  17. lr=required,
  18. momentum=0,
  19. dampening=0,
  20. weight_decay=0,
  21. nesterov=False,
  22. eps=1e-8,
  23. delta=0.1,
  24. wd_ratio=0.1
  25. ):
  26. defaults = dict(
  27. lr=lr,
  28. momentum=momentum,
  29. dampening=dampening,
  30. weight_decay=weight_decay,
  31. nesterov=nesterov,
  32. eps=eps,
  33. delta=delta,
  34. wd_ratio=wd_ratio,
  35. )
  36. super(SGDP, self).__init__(params, defaults)
  37. @torch.no_grad()
  38. def step(self, closure=None):
  39. loss = None
  40. if closure is not None:
  41. with torch.enable_grad():
  42. loss = closure()
  43. for group in self.param_groups:
  44. weight_decay = group['weight_decay']
  45. momentum = group['momentum']
  46. dampening = group['dampening']
  47. nesterov = group['nesterov']
  48. for p in group['params']:
  49. if p.grad is None:
  50. continue
  51. grad = p.grad
  52. state = self.state[p]
  53. # State initialization
  54. if len(state) == 0:
  55. state['momentum'] = torch.zeros_like(p)
  56. # SGD
  57. buf = state['momentum']
  58. buf.mul_(momentum).add_(grad, alpha=1. - dampening)
  59. if nesterov:
  60. d_p = grad + momentum * buf
  61. else:
  62. d_p = buf
  63. # Projection
  64. wd_ratio = 1.
  65. if len(p.shape) > 1:
  66. d_p, wd_ratio = projection(p, grad, d_p, group['delta'], group['wd_ratio'], group['eps'])
  67. # Weight decay
  68. if weight_decay != 0:
  69. p.mul_(1. - group['lr'] * group['weight_decay'] * wd_ratio / (1-momentum))
  70. # Step
  71. p.add_(d_p, alpha=-group['lr'])
  72. return loss