# early_stop_hook.py (4.3 KB)
  1. # Copyright (c) Alibaba, Inc. and its affiliates.
  2. import numpy as np
  3. from modelscope.metainfo import Hooks
  4. from modelscope.utils.logger import get_logger
  5. from .builder import HOOKS
  6. from .hook import Hook
  7. from .priority import Priority
  8. class EarlyStopStrategy:
  9. by_epoch = 'by_epoch'
  10. by_step = 'by_step'
  11. no = 'no'
  12. @HOOKS.register_module(module_name=Hooks.EarlyStopHook)
  13. class EarlyStopHook(Hook):
  14. """Early stop when a specific metric stops improving.
  15. Args:
  16. metric_key (str): Metric key to be monitored.
  17. rule (str): Comparison rule for best score. Support "max" and "min".
  18. If rule is "max", the training will stop when `metric_key` has stopped increasing.
  19. If rule is "min", the training will stop when `metric_key` has stopped decreasing.
  20. patience (int): Trainer will stop if the monitored metric did not improve for the last `patience` times.
  21. min_delta (float): Minimum change in the monitored metric to qualify as an improvement.
  22. check_finite (bool): If true, stops training when the metric becomes NaN or infinite.
  23. early_stop_strategy (str): The strategy to early stop, can be by_epoch/by_step/none
  24. interval (int): The frequency to trigger early stop check, by epoch or step.
  25. """
  26. PRIORITY = Priority.VERY_LOW
  27. rule_map = {'max': lambda x, y: x > y, 'min': lambda x, y: x < y}
  28. def __init__(self,
  29. metric_key: str,
  30. rule: str = 'max',
  31. patience: int = 3,
  32. min_delta: float = 0.0,
  33. check_finite: bool = True,
  34. early_stop_strategy: str = EarlyStopStrategy.by_epoch,
  35. interval: int = 1,
  36. **kwargs):
  37. self.metric_key = metric_key
  38. self.rule = rule
  39. self.patience = patience
  40. self.min_delta = min_delta
  41. self.check_finite = check_finite
  42. if 'by_epoch' in kwargs:
  43. self.early_stop_strategy = EarlyStopStrategy.by_epoch if kwargs[
  44. 'by_epoch'] else EarlyStopStrategy.by_step
  45. else:
  46. self.early_stop_strategy = early_stop_strategy
  47. self.interval = interval
  48. self.wait_count = 0
  49. self.best_score = float('inf') if rule == 'min' else -float('inf')
  50. def before_run(self, trainer):
  51. if not hasattr(trainer, 'logger'):
  52. self.logger = get_logger()
  53. else:
  54. self.logger = trainer.logger
  55. def _should_stop(self, trainer):
  56. metric_values = trainer.metric_values
  57. if metric_values is None:
  58. return False
  59. if self.metric_key not in metric_values:
  60. raise ValueError(
  61. f'Metric not found: {self.metric_key} not in {metric_values}')
  62. should_stop = False
  63. current_score = metric_values[self.metric_key]
  64. if self.check_finite and not np.isfinite(current_score):
  65. should_stop = True
  66. self.logger.warning(
  67. f'Metric {self.metric_key} = {current_score} is not finite. '
  68. f'Previous best metric: {self.best_score:.4f}.')
  69. elif self.rule_map[self.rule](current_score - self.min_delta,
  70. self.best_score):
  71. self.best_score = current_score
  72. self.wait_count = 0
  73. else:
  74. self.wait_count += 1
  75. if self.wait_count >= self.patience:
  76. should_stop = True
  77. self.logger.info(
  78. f'Metric {self.metric_key} did not improve in the last {self.wait_count} epochs or iterations. '
  79. f'Best score: {self.best_score:.4f}.')
  80. return should_stop
  81. def _stop_training(self, trainer):
  82. self.logger.info('Early Stopping!')
  83. trainer._stop_training = True
  84. def after_train_epoch(self, trainer):
  85. if self.early_stop_strategy != EarlyStopStrategy.by_epoch:
  86. return
  87. if not self.every_n_epochs(trainer, self.interval):
  88. return
  89. if self._should_stop(trainer):
  90. self._stop_training(trainer)
  91. def after_train_iter(self, trainer):
  92. if self.early_stop_strategy != EarlyStopStrategy.by_step:
  93. return
  94. if not self.every_n_iters(trainer, self.interval):
  95. return
  96. if self._should_stop(trainer):
  97. self._stop_training(trainer)