warmup.py 2.9 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677787980
  1. # Copyright (c) Alibaba, Inc. and its affiliates.
  2. from modelscope.metainfo import LR_Schedulers
  3. from modelscope.trainers.lrscheduler.builder import LR_SCHEDULER
  4. from .base import BaseWarmup
  5. @LR_SCHEDULER.register_module(module_name=LR_Schedulers.ConstantWarmup)
  6. class ConstantWarmup(BaseWarmup):
  7. """Linear warmup scheduler.
  8. Args:
  9. base_scheduler (torch.optim._LRScheduler): an instance of torch.optim._LRScheduler type
  10. warmup_ratio (float): Lr used at warmup stage equals to warmup_ratio * initial_lr
  11. warmup_iters (int | list): Warmup iterations
  12. last_epoch (int): The index of last epoch.
  13. """
  14. def __init__(self,
  15. base_scheduler,
  16. warmup_iters,
  17. warmup_ratio=0.1,
  18. last_epoch=-1):
  19. self.warmup_ratio = warmup_ratio
  20. super(ConstantWarmup, self).__init__(
  21. base_scheduler, warmup_iters=warmup_iters, last_epoch=last_epoch)
  22. def get_warmup_scale(self, cur_iter):
  23. if cur_iter >= self.warmup_iters:
  24. return 1.0
  25. return self.warmup_ratio
  26. @LR_SCHEDULER.register_module(module_name=LR_Schedulers.LinearWarmup)
  27. class LinearWarmup(BaseWarmup):
  28. """Linear warmup scheduler.
  29. Args:
  30. base_scheduler (torch.optim._LRScheduler): an instance of torch.optim._LRScheduler type
  31. warmup_iters (int | list): Warmup iterations
  32. warmup_ratio (float): Lr used at the beginning of warmup equals to warmup_ratio * initial_lr
  33. last_epoch (int): The index of last epoch.
  34. """
  35. def __init__(self,
  36. base_scheduler,
  37. warmup_iters,
  38. warmup_ratio=0.1,
  39. last_epoch=-1):
  40. self.warmup_ratio = warmup_ratio
  41. super(LinearWarmup, self).__init__(
  42. base_scheduler, warmup_iters=warmup_iters, last_epoch=last_epoch)
  43. def get_warmup_scale(self, cur_iter):
  44. k = (1 - cur_iter / self.warmup_iters) * (1 - self.warmup_ratio)
  45. return 1 - k
  46. @LR_SCHEDULER.register_module(module_name=LR_Schedulers.ExponentialWarmup)
  47. class ExponentialWarmup(BaseWarmup):
  48. """Exponential warmup scheduler.
  49. Args:
  50. base_scheduler (torch.optim._LRScheduler): an instance of torch.optim._LRScheduler type
  51. warmup_iters (int | list): Warmup iterations
  52. warmup_ratio (float): Lr used at the beginning of warmup equals to warmup_ratio * initial_lr
  53. last_epoch (int): The index of last epoch.
  54. """
  55. def __init__(self,
  56. base_scheduler,
  57. warmup_iters,
  58. warmup_ratio=0.1,
  59. last_epoch=-1):
  60. self.warmup_ratio = warmup_ratio
  61. super(ExponentialWarmup, self).__init__(
  62. base_scheduler, warmup_iters=warmup_iters, last_epoch=last_epoch)
  63. def get_warmup_scale(self, cur_iter):
  64. k = self.warmup_ratio**(1 - cur_iter / self.warmup_iters)
  65. return k