base.py 2.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475
  1. # Copyright (c) Alibaba, Inc. and its affiliates.
  2. from torch.optim.lr_scheduler import _LRScheduler
  3. class BaseWarmup(_LRScheduler):
  4. """Base warmup scheduler
  5. Args:
  6. base_scheduler (torch.optim._LRScheduler): an instance of torch.optim._LRScheduler type
  7. warmup_iters (int | list): Warmup iterations
  8. last_epoch (int): The index of last epoch.
  9. """
  10. def __init__(self,
  11. base_scheduler,
  12. warmup_iters,
  13. last_epoch=-1,
  14. verbose=False):
  15. self.base_scheduler = base_scheduler
  16. self.warmup_iters = warmup_iters
  17. optimizer = self.base_scheduler.optimizer
  18. self._is_init_step = True
  19. super(BaseWarmup, self).__init__(
  20. optimizer, last_epoch=last_epoch, verbose=verbose)
  21. def get_lr(self):
  22. return self.base_scheduler.get_lr()
  23. def state_dict(self):
  24. return self.base_scheduler.state_dict()
  25. def load_state_dict(self, state_dict):
  26. return self.base_scheduler.load_state_dict(state_dict)
  27. def scale(self):
  28. """Scale the learning rates.
  29. """
  30. scale_value = self.get_warmup_scale(self.base_scheduler._step_count
  31. - 1)
  32. if isinstance(scale_value, (int, float)):
  33. scale_value = [
  34. scale_value for _ in range(len(self.optimizer.param_groups))
  35. ]
  36. else:
  37. assert isinstance(
  38. scale_value, (list, tuple)), 'Only support list or tuple type!'
  39. assert len(scale_value) == len(
  40. self.optimizer.param_groups), ('Size mismatch {} != {}'.format(
  41. len(scale_value), len(self.optimizer.param_groups)))
  42. for i, group in enumerate(self.optimizer.param_groups):
  43. group['lr'] *= scale_value[i]
  44. def step(self, *args, **kwargs):
  45. """
  46. When ``self.base_scheduler._step_count`` is less than ``self.warmup_iters``, multiply lr by scale
  47. """
  48. if self.base_scheduler._step_count > self.warmup_iters:
  49. return self.base_scheduler.step(*args, **kwargs)
  50. for group, lr in zip(self.optimizer.param_groups, self.base_lrs):
  51. group['lr'] = lr
  52. # `base_scheduler` has done step() at init when build
  53. if self._is_init_step:
  54. self._is_init_step = False
  55. else:
  56. self.base_scheduler.step(*args, **kwargs)
  57. self.scale()
  58. @classmethod
  59. def get_warmup_scale(self, cur_iter):
  60. pass