| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196 |
- # Copyright (c) Alibaba, Inc. and its affiliates.
- from modelscope.metainfo import Hooks
- from modelscope.trainers.lrscheduler.builder import build_lr_scheduler
- from modelscope.utils.constant import LogKeys
- from modelscope.utils.logger import get_logger
- from modelscope.utils.torch_utils import is_master
- from .builder import HOOKS
- from .hook import Hook
- from .priority import Priority
class LrSchedulerProcessor:
    """Default strategy object used by the lr scheduler hooks.

    It stores the stepping strategy and an optional warmup scheduler,
    performs the actual ``step()`` call and reads the current learning
    rate(s) from the trainer's optimizer(s).
    """

    def __init__(self):
        # Both values are filled in later through the setters below.
        self.warmup_lr_scheduler = None
        self.lr_strategy = None

    def set_lr_strategy(self, lr_strategy):
        """Remember the stepping strategy chosen by the owning hook."""
        self.lr_strategy = lr_strategy

    def set_warmup_lr_scheduler(self, warmup_lr_scheduler):
        """Remember the warmup scheduler that `step` should drive."""
        self.warmup_lr_scheduler = warmup_lr_scheduler

    def initialize_lr_scheduler(self, trainer):
        """Initialize the lr scheduler.

        This is a strategic function which can be registered by other hook's function.
        The default implementation does nothing.
        """
        return None

    def step(self, trainer):
        """Do lr scheduler's step.

        This is a strategic function which can be registered by other hook's function.
        The warmup scheduler is stepped when one has been set; otherwise
        ``trainer.lr_scheduler`` is stepped.
        """
        scheduler = self.warmup_lr_scheduler
        if scheduler is None:
            scheduler = trainer.lr_scheduler
        scheduler.step()

    def get_current_lr(self, trainer):
        """Return the learning rates currently held by ``trainer.optimizer``.

        Returns:
            A list with one lr per param group when the optimizer is a
            single ``torch.optim.Optimizer``, or a dict mapping each
            optimizer name to such a list when it is a dict of optimizers.

        Raises:
            RuntimeError: If ``trainer.optimizer`` is neither of the above.
        """
        import torch

        optimizer = trainer.optimizer
        if isinstance(optimizer, torch.optim.Optimizer):
            return [group['lr'] for group in optimizer.param_groups]
        if isinstance(optimizer, dict):
            return {
                name: [group['lr'] for group in optim.param_groups]
                for name, optim in optimizer.items()
            }
        raise RuntimeError(
            'lr is not applicable because optimizer does not exist.')
class LrStrategy:
    """String constants naming when the lr scheduler gets stepped.

    `LrSchedulerHook` steps after every training iteration for `by_step`
    and after every training epoch for `by_epoch`; `no` matches neither
    of those checks.
    """

    no = 'no'
    by_step = 'by_step'
    by_epoch = 'by_epoch'
@HOOKS.register_module(module_name=Hooks.LrSchedulerHook)
class LrSchedulerHook(Hook):
    """Lr scheduler.

    Steps the lr scheduler according to the configured strategy and writes
    the current learning rate into the trainer's log buffer.

    Args:
        lr_strategy (str): One of the `LrStrategy` values, deciding whether
            the scheduler is stepped per epoch or per iteration.
            Defaults to `LrStrategy.by_epoch`.
        warmup (dict): warm up config. When given it must be a dict with a
            'type' key; it is passed to `build_lr_scheduler` at `before_run`.
        **kwargs: Only the legacy key `by_epoch` is read. When present it
            overrides `lr_strategy`: a truthy value selects
            `LrStrategy.by_epoch`, a falsy one `LrStrategy.by_step`.
    """
    PRIORITY = Priority.LOW

    def __init__(self,
                 lr_strategy=LrStrategy.by_epoch,
                 warmup=None,
                 **kwargs) -> None:
        super().__init__()
        if 'by_epoch' in kwargs:
            # Legacy boolean switch takes precedence over `lr_strategy`.
            self.lr_strategy = LrStrategy.by_epoch if kwargs[
                'by_epoch'] else LrStrategy.by_step
        else:
            self.lr_strategy = lr_strategy
        self.warmup = warmup
        self.warmup_lr_scheduler = None
        self.processor = LrSchedulerProcessor()

    def set_processor(self, processor):
        """Replace the processor that performs stepping and lr lookup."""
        self.processor = processor

    def before_run(self, trainer):
        """Configure the processor and build the warmup scheduler, if any.

        Raises:
            AssertionError: If `warmup` is given but is not a dict
                containing a 'type' key.
        """
        self.processor.set_lr_strategy(self.lr_strategy)
        if self.warmup is not None:
            # Explicit raise instead of a bare `assert`: the check survives
            # `python -O` and reports what is wrong. The exception type is
            # kept as AssertionError for backward compatibility.
            if not (isinstance(self.warmup, dict) and 'type' in self.warmup):
                raise AssertionError(
                    "`warmup` must be a dict containing the key 'type', "
                    f'but got: {self.warmup!r}')
            # The warmup scheduler is built around the trainer's scheduler.
            self.warmup_lr_scheduler = build_lr_scheduler(
                cfg=self.warmup,
                default_args={'base_scheduler': trainer.lr_scheduler})
            self.processor.set_warmup_lr_scheduler(self.warmup_lr_scheduler)
        self.processor.initialize_lr_scheduler(trainer)

    def after_train_iter(self, trainer):
        """Step per iteration when configured so, then log the current lr."""
        # `cumulative_iters` is read from the trainer and defaults to 1.
        if self.lr_strategy == LrStrategy.by_step and trainer.iter >= getattr(
                trainer, 'cumulative_iters', 1) - 1:
            self.processor.step(trainer)
        trainer.log_buffer.output[LogKeys.LR] = self._get_log_lr(trainer)

    def before_train_epoch(self, trainer):
        """Log the current lr at the start of every training epoch."""
        trainer.log_buffer.output[LogKeys.LR] = self._get_log_lr(trainer)

    def after_train_epoch(self, trainer):
        """Step the scheduler once per epoch when configured so."""
        if self.lr_strategy == LrStrategy.by_epoch:
            self.processor.step(trainer)

    def _get_log_lr(self, trainer):
        """Return the lr to log: a single value, or a dict of them.

        For one optimizer this is the lr of its first param group; for a
        dict of optimizers it maps each name to that optimizer's first lr.
        """
        # forward compatibility with AddLrLogHook in EasyCV
        if not hasattr(self, 'processor'):
            self.processor = LrSchedulerProcessor()
        cur_lr = self.processor.get_current_lr(trainer)
        # only record lr of the first param group
        if isinstance(cur_lr, list):
            lr = cur_lr[0]
        else:
            assert isinstance(cur_lr, dict)
            lr = {}
            for k, lr_ in cur_lr.items():
                assert isinstance(lr_, list)
                lr.update({k: lr_[0]})
        return lr
class PlateauLrSchedulerProcessor(LrSchedulerProcessor):
    """Processor that forwards an evaluation metric to the scheduler step.

    Args:
        metric_key (str): Key used to pick the metric out of
            `trainer.metric_values`.
    """

    def __init__(self, metric_key):
        super().__init__()
        self.metric_key = metric_key

    def step(self, trainer):
        """Step the scheduler with the metric stored under `metric_key`.

        Nothing is stepped when the trainer has no metric values (a notice
        is printed on the master process) or when the strategy is not
        `LrStrategy.by_epoch`.
        """
        metric_values = trainer.metric_values
        # adapt to evaluation interval is greater than 1
        if metric_values is None:
            if is_master():
                print(
                    f'Current epoch {trainer.epoch} has no evaluation metric values, skip lr_scheduler.step() !'
                )
            return

        # The lookup happens before the strategy check, as in the original.
        metrics = metric_values[self.metric_key]
        if self.lr_strategy != LrStrategy.by_epoch:
            return

        scheduler = self.warmup_lr_scheduler
        if scheduler is None:
            scheduler = trainer.lr_scheduler
        scheduler.step(metrics=metrics)
@HOOKS.register_module(module_name=Hooks.PlateauLrSchedulerHook)
class PlateauLrSchedulerHook(Hook):
    """Lr scheduler hook for `ReduceLROnPlateau`.

    Args:
        metric_key (str): Metric key returned from `trainer.metric_values`,
            get the value of metric key and pass it to `ReduceLROnPlateau.step`.
    """
    PRIORITY = Priority.LOW  # should be after EvaluationHook

    def __init__(self, metric_key, **kwargs):
        # Extra keyword arguments are accepted but not used.
        super().__init__()
        self.metric_key = metric_key

    def register_processor(self, trainer):
        """Install a `PlateauLrSchedulerProcessor` on the first `LrSchedulerHook`.

        The swap only happens when that hook's processor is `None` or an
        exact `LrSchedulerProcessor` instance; any other processor type is
        left in place.
        """
        found_hooks = trainer.get_hook(LrSchedulerHook)
        if len(found_hooks) == 0:
            return

        replaceable_types = (type(None), LrSchedulerProcessor)
        # Exact type match on purpose: subclasses are not replaced.
        if type(found_hooks[0].processor) not in replaceable_types:
            return

        found_hooks[0].set_processor(
            PlateauLrSchedulerProcessor(self.metric_key))

    def before_run(self, trainer):
        """Use the trainer's logger when it has one, else the default logger."""
        self.logger = trainer.logger if hasattr(trainer,
                                                'logger') else get_logger()
@HOOKS.register_module(module_name=Hooks.NoneLrSchedulerHook)
class NoneLrSchedulerHook(LrSchedulerHook):
    """`LrSchedulerHook` variant whose `before_run` and `after_train_epoch` do nothing.

    Args:
        by_epoch (bool): Forwarded to `LrSchedulerHook` as the legacy
            `by_epoch` switch.
        warmup (dict): Forwarded to `LrSchedulerHook`; it is never built
            here because `before_run` is a no-op.
    """
    # NOTE(review): `after_train_iter` and `before_train_epoch` are still
    # inherited, so with `by_epoch=False` the processor is stepped per
    # iteration - confirm this is intended.
    PRIORITY = Priority.LOW  # should be after EvaluationHook

    def __init__(self, by_epoch=True, warmup=None) -> None:
        super().__init__(warmup=warmup, by_epoch=by_epoch)

    def before_run(self, trainer):
        """Skip the processor setup and warmup construction of the parent."""
        return None

    def after_train_epoch(self, trainer):
        """Skip the per-epoch scheduler step of the parent."""
        return None
|