lr_scheduler_hook.py 6.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196
  1. # Copyright (c) Alibaba, Inc. and its affiliates.
  2. from modelscope.metainfo import Hooks
  3. from modelscope.trainers.lrscheduler.builder import build_lr_scheduler
  4. from modelscope.utils.constant import LogKeys
  5. from modelscope.utils.logger import get_logger
  6. from modelscope.utils.torch_utils import is_master
  7. from .builder import HOOKS
  8. from .hook import Hook
  9. from .priority import Priority
  10. class LrSchedulerProcessor:
  11. def __init__(self):
  12. self.lr_strategy = None
  13. self.warmup_lr_scheduler = None
  14. def set_lr_strategy(self, lr_strategy):
  15. self.lr_strategy = lr_strategy
  16. def set_warmup_lr_scheduler(self, warmup_lr_scheduler):
  17. self.warmup_lr_scheduler = warmup_lr_scheduler
  18. def initialize_lr_scheduler(self, trainer):
  19. """Initialize the lr scheduler.
  20. This is a strategic function which can be registered by other hook's function.
  21. """
  22. pass
  23. def step(self, trainer):
  24. """Do lr scheduler's step.
  25. This is a strategic function which can be registered by other hook's function.
  26. """
  27. if self.warmup_lr_scheduler is not None:
  28. self.warmup_lr_scheduler.step()
  29. else:
  30. trainer.lr_scheduler.step()
  31. def get_current_lr(self, trainer):
  32. import torch
  33. if isinstance(trainer.optimizer, torch.optim.Optimizer):
  34. lr = [group['lr'] for group in trainer.optimizer.param_groups]
  35. elif isinstance(trainer.optimizer, dict):
  36. lr = dict()
  37. for name, optim in trainer.optimizer.items():
  38. lr[name] = [group['lr'] for group in optim.param_groups]
  39. else:
  40. raise RuntimeError(
  41. 'lr is not applicable because optimizer does not exist.')
  42. return lr
  43. class LrStrategy:
  44. by_epoch = 'by_epoch'
  45. by_step = 'by_step'
  46. no = 'no'
  47. @HOOKS.register_module(module_name=Hooks.LrSchedulerHook)
  48. class LrSchedulerHook(Hook):
  49. """Lr scheduler.
  50. Args:
  51. by_epoch (bool): Whether lr changes by epoch
  52. warmup (dict): warm up config
  53. """
  54. PRIORITY = Priority.LOW
  55. def __init__(self,
  56. lr_strategy=LrStrategy.by_epoch,
  57. warmup=None,
  58. **kwargs) -> None:
  59. super().__init__()
  60. if 'by_epoch' in kwargs:
  61. self.lr_strategy = LrStrategy.by_epoch if kwargs[
  62. 'by_epoch'] else LrStrategy.by_step
  63. else:
  64. self.lr_strategy = lr_strategy
  65. self.warmup = warmup
  66. self.warmup_lr_scheduler = None
  67. self.processor = LrSchedulerProcessor()
  68. def set_processor(self, processor):
  69. self.processor = processor
  70. def before_run(self, trainer):
  71. self.processor.set_lr_strategy(self.lr_strategy)
  72. if self.warmup is not None:
  73. assert isinstance(self.warmup, dict) and 'type' in self.warmup
  74. self.warmup_lr_scheduler = build_lr_scheduler(
  75. cfg=self.warmup,
  76. default_args={'base_scheduler': trainer.lr_scheduler})
  77. self.processor.set_warmup_lr_scheduler(self.warmup_lr_scheduler)
  78. self.processor.initialize_lr_scheduler(trainer)
  79. def after_train_iter(self, trainer):
  80. if self.lr_strategy == LrStrategy.by_step and trainer.iter >= getattr(
  81. trainer, 'cumulative_iters', 1) - 1:
  82. self.processor.step(trainer)
  83. trainer.log_buffer.output[LogKeys.LR] = self._get_log_lr(trainer)
  84. def before_train_epoch(self, trainer):
  85. trainer.log_buffer.output[LogKeys.LR] = self._get_log_lr(trainer)
  86. def after_train_epoch(self, trainer):
  87. if self.lr_strategy == LrStrategy.by_epoch:
  88. self.processor.step(trainer)
  89. def _get_log_lr(self, trainer):
  90. # forward compatibility with AddLrLogHook in EasyCV
  91. if not hasattr(self, 'processor'):
  92. self.processor = LrSchedulerProcessor()
  93. cur_lr = self.processor.get_current_lr(trainer)
  94. # only record lr of the first param group
  95. if isinstance(cur_lr, list):
  96. lr = cur_lr[0]
  97. else:
  98. assert isinstance(cur_lr, dict)
  99. lr = {}
  100. for k, lr_ in cur_lr.items():
  101. assert isinstance(lr_, list)
  102. lr.update({k: lr_[0]})
  103. return lr
  104. class PlateauLrSchedulerProcessor(LrSchedulerProcessor):
  105. def __init__(self, metric_key):
  106. super().__init__()
  107. self.metric_key = metric_key
  108. def step(self, trainer):
  109. # adapt to evaluation interval is greater than 1
  110. if trainer.metric_values is None:
  111. if is_master():
  112. print(
  113. f'Current epoch {trainer.epoch} has no evaluation metric values, skip lr_scheduler.step() !'
  114. )
  115. return
  116. metrics = trainer.metric_values[self.metric_key]
  117. if self.lr_strategy == LrStrategy.by_epoch:
  118. if self.warmup_lr_scheduler is not None:
  119. self.warmup_lr_scheduler.step(metrics=metrics)
  120. else:
  121. trainer.lr_scheduler.step(metrics=metrics)
  122. @HOOKS.register_module(module_name=Hooks.PlateauLrSchedulerHook)
  123. class PlateauLrSchedulerHook(Hook):
  124. """Lr scheduler hook for `ReduceLROnPlateau`.
  125. Args:
  126. metric_key (str): Metric key returned from `trainer.metric_values`,
  127. get the value of metric key and pass it to `ReduceLROnPlateau.step`.
  128. """
  129. PRIORITY = Priority.LOW # should be after EvaluationHook
  130. def __init__(self, metric_key, **kwargs):
  131. super().__init__()
  132. self.metric_key = metric_key
  133. def register_processor(self, trainer):
  134. lr_scheduler_hook = trainer.get_hook(LrSchedulerHook)
  135. if len(lr_scheduler_hook) > 0 and type(
  136. lr_scheduler_hook[0].processor) in (type(None),
  137. LrSchedulerProcessor):
  138. lr_scheduler_hook[0].set_processor(
  139. PlateauLrSchedulerProcessor(self.metric_key))
  140. def before_run(self, trainer):
  141. if not hasattr(trainer, 'logger'):
  142. self.logger = get_logger()
  143. else:
  144. self.logger = trainer.logger
  145. @HOOKS.register_module(module_name=Hooks.NoneLrSchedulerHook)
  146. class NoneLrSchedulerHook(LrSchedulerHook):
  147. PRIORITY = Priority.LOW # should be after EvaluationHook
  148. def __init__(self, by_epoch=True, warmup=None) -> None:
  149. super().__init__(by_epoch=by_epoch, warmup=warmup)
  150. def before_run(self, trainer):
  151. return
  152. def after_train_epoch(self, trainer):
  153. return