| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119 |
- # Copyright (c) Alibaba, Inc. and its affiliates.
- import numpy as np
- from modelscope.metainfo import Hooks
- from modelscope.utils.logger import get_logger
- from .builder import HOOKS
- from .hook import Hook
- from .priority import Priority
class EarlyStopStrategy:
    """String constants naming when the early-stop check is evaluated."""

    # Evaluate the check from the after-epoch hook.
    by_epoch = 'by_epoch'
    # Evaluate the check from the after-iteration hook.
    by_step = 'by_step'
    # Matches neither hook, so the check is never evaluated.
    no = 'no'
@HOOKS.register_module(module_name=Hooks.EarlyStopHook)
class EarlyStopHook(Hook):
    """Early stop when a specific metric stops improving.

    Args:
        metric_key (str): Metric key to be monitored.
        rule (str): Comparison rule for best score. Support "max" and "min".
            If rule is "max", the training will stop when `metric_key` has
            stopped increasing. If rule is "min", the training will stop when
            `metric_key` has stopped decreasing.
        patience (int): Trainer will stop if the monitored metric did not
            improve for the last `patience` checks.
        min_delta (float): Minimum change in the monitored metric to qualify
            as an improvement. Only its magnitude is used: with "max" the
            metric must exceed the best score by more than `min_delta`, with
            "min" it must fall below the best score by more than `min_delta`.
        check_finite (bool): If true, stops training when the metric becomes
            NaN or infinite.
        early_stop_strategy (str): The strategy to early stop, can be
            by_epoch/by_step/no (see `EarlyStopStrategy`).
        interval (int): The frequency to trigger early stop check, by epoch
            or step.
        **kwargs: If a `by_epoch` key is present it overrides
            `early_stop_strategy`: a truthy value selects `by_epoch`, a falsy
            value selects `by_step`. Other keys are ignored.

    Raises:
        ValueError: If `rule` is not one of the keys of `rule_map`.
    """

    PRIORITY = Priority.VERY_LOW

    # Maps a rule name to a predicate "x is strictly better than y".
    rule_map = {'max': lambda x, y: x > y, 'min': lambda x, y: x < y}

    def __init__(self,
                 metric_key: str,
                 rule: str = 'max',
                 patience: int = 3,
                 min_delta: float = 0.0,
                 check_finite: bool = True,
                 early_stop_strategy: str = EarlyStopStrategy.by_epoch,
                 interval: int = 1,
                 **kwargs):
        # Fail fast: an unknown rule would otherwise only surface as a
        # KeyError at the first metric check, after training has started.
        if rule not in self.rule_map:
            raise ValueError(
                f'Unsupported rule: {rule!r}, expected one of {sorted(self.rule_map)}'
            )

        self.metric_key = metric_key
        self.rule = rule
        self.patience = patience
        self.min_delta = min_delta
        self.check_finite = check_finite

        # A `by_epoch` flag passed through kwargs takes precedence over the
        # explicit `early_stop_strategy` argument.
        if 'by_epoch' in kwargs:
            self.early_stop_strategy = (
                EarlyStopStrategy.by_epoch
                if kwargs['by_epoch'] else EarlyStopStrategy.by_step)
        else:
            self.early_stop_strategy = early_stop_strategy

        self.interval = interval
        # Number of consecutive checks without an improvement.
        self.wait_count = 0
        # Start from the worst possible score so the first finite value wins.
        self.best_score = float('inf') if rule == 'min' else -float('inf')

    def before_run(self, trainer):
        """Bind the logger: the trainer's own if it has one, else the default."""
        if not hasattr(trainer, 'logger'):
            self.logger = get_logger()
        else:
            self.logger = trainer.logger

    def _should_stop(self, trainer):
        """Update best score / patience counter and decide whether to stop.

        Reads `trainer.metric_values`; when it is None no decision is made
        and the counters are left untouched.

        Returns:
            bool: True if the metric is non-finite (and `check_finite` is
            set) or has not improved for `patience` consecutive checks.

        Raises:
            ValueError: If `metric_key` is missing from the metric values.
        """
        metric_values = trainer.metric_values
        if metric_values is None:
            return False

        if self.metric_key not in metric_values:
            raise ValueError(
                f'Metric not found: {self.metric_key} not in {metric_values}')

        should_stop = False
        current_score = metric_values[self.metric_key]

        # Shift the score *against* the direction of improvement, so only a
        # change larger than min_delta counts. For "min" the shift must be
        # added; subtracting it would let a slightly worse score register as
        # a new best and reset the patience counter.
        margin = abs(self.min_delta)
        if self.rule == 'min':
            shifted_score = current_score + margin
        else:
            shifted_score = current_score - margin

        if self.check_finite and not np.isfinite(current_score):
            should_stop = True
            self.logger.warning(
                f'Metric {self.metric_key} = {current_score} is not finite. '
                f'Previous best metric: {self.best_score:.4f}.')
        elif self.rule_map[self.rule](shifted_score, self.best_score):
            self.best_score = current_score
            self.wait_count = 0
        else:
            self.wait_count += 1
            if self.wait_count >= self.patience:
                should_stop = True
                self.logger.info(
                    f'Metric {self.metric_key} did not improve in the last {self.wait_count} epochs or iterations. '
                    f'Best score: {self.best_score:.4f}.')
        return should_stop

    def _stop_training(self, trainer):
        """Log the stop and raise the trainer's `_stop_training` flag."""
        self.logger.info('Early Stopping!')
        # NOTE(review): presumably the trainer loop polls this flag - confirm.
        trainer._stop_training = True

    def after_train_epoch(self, trainer):
        """Run the early-stop check after an epoch (by_epoch strategy only)."""
        if self.early_stop_strategy != EarlyStopStrategy.by_epoch:
            return
        # `every_n_epochs` is not defined in this class; presumably inherited
        # from Hook.
        if not self.every_n_epochs(trainer, self.interval):
            return
        if self._should_stop(trainer):
            self._stop_training(trainer)

    def after_train_iter(self, trainer):
        """Run the early-stop check after an iteration (by_step strategy only)."""
        if self.early_stop_strategy != EarlyStopStrategy.by_step:
            return
        # `every_n_iters` is not defined in this class; presumably inherited
        # from Hook.
        if not self.every_n_iters(trainer, self.interval):
            return
        if self._should_stop(trainer):
            self._stop_training(trainer)
|