step_lr.py 1.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263
  1. """ Step Scheduler
  2. Basic step LR schedule with warmup, noise.
  3. Hacked together by / Copyright 2020 Ross Wightman
  4. """
  5. import math
  6. import torch
  7. from typing import List
  8. from .scheduler import Scheduler
  9. class StepLRScheduler(Scheduler):
  10. """
  11. """
  12. def __init__(
  13. self,
  14. optimizer: torch.optim.Optimizer,
  15. decay_t: float,
  16. decay_rate: float = 1.,
  17. warmup_t=0,
  18. warmup_lr_init=0,
  19. warmup_prefix=True,
  20. t_in_epochs=True,
  21. noise_range_t=None,
  22. noise_pct=0.67,
  23. noise_std=1.0,
  24. noise_seed=42,
  25. initialize=True,
  26. ) -> None:
  27. super().__init__(
  28. optimizer,
  29. param_group_field="lr",
  30. t_in_epochs=t_in_epochs,
  31. noise_range_t=noise_range_t,
  32. noise_pct=noise_pct,
  33. noise_std=noise_std,
  34. noise_seed=noise_seed,
  35. initialize=initialize,
  36. )
  37. self.decay_t = decay_t
  38. self.decay_rate = decay_rate
  39. self.warmup_t = warmup_t
  40. self.warmup_lr_init = warmup_lr_init
  41. self.warmup_prefix = warmup_prefix
  42. if self.warmup_t:
  43. self.warmup_steps = [(v - warmup_lr_init) / self.warmup_t for v in self.base_values]
  44. super().update_groups(self.warmup_lr_init)
  45. else:
  46. self.warmup_steps = [1 for _ in self.base_values]
  47. def _get_lr(self, t: int) -> List[float]:
  48. if t < self.warmup_t:
  49. lrs = [self.warmup_lr_init + t * s for s in self.warmup_steps]
  50. else:
  51. if self.warmup_prefix:
  52. t = t - self.warmup_t
  53. lrs = [v * (self.decay_rate ** (t // self.decay_t)) for v in self.base_values]
  54. return lrs