evaluation_hook.py 3.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102
  1. # Copyright (c) Alibaba, Inc. and its affiliates.
  2. from collections import OrderedDict
  3. from typing import Optional
  4. from modelscope.metainfo import Hooks
  5. from .builder import HOOKS
  6. from .hook import Hook
  7. class EvaluationStrategy:
  8. by_epoch = 'by_epoch'
  9. by_step = 'by_step'
  10. no = 'no'
  11. @HOOKS.register_module(module_name=Hooks.EvaluationHook)
  12. class EvaluationHook(Hook):
  13. """
  14. Evaluation hook.
  15. Args:
  16. interval (int): Evaluation interval.
  17. by_epoch (bool): Evaluate by epoch or by iteration.
  18. start_idx (int or None, optional): The epoch or iterations validation begins.
  19. Default: None, validate every interval epochs/iterations from scratch.
  20. """
  21. def __init__(self,
  22. interval: Optional[int] = 1,
  23. eval_strategy: Optional[str] = EvaluationStrategy.by_epoch,
  24. start_idx: Optional[int] = None,
  25. **kwargs):
  26. assert interval > 0, 'interval must be a positive number'
  27. self.interval = interval
  28. self.start_idx = start_idx
  29. self.last_eval_tag = (None, None)
  30. if 'by_epoch' in kwargs:
  31. self.eval_strategy = EvaluationStrategy.by_epoch if kwargs[
  32. 'by_epoch'] else EvaluationStrategy.by_step
  33. else:
  34. self.eval_strategy = eval_strategy
  35. def after_train_iter(self, trainer):
  36. """Called after every training iter to evaluate the results."""
  37. if self.eval_strategy == EvaluationStrategy.by_step and self._should_evaluate(
  38. trainer):
  39. self.do_evaluate(trainer)
  40. self.last_eval_tag = ('iter', trainer.iter)
  41. def after_train_epoch(self, trainer):
  42. """Called after every training epoch to evaluate the results."""
  43. if self.eval_strategy == EvaluationStrategy.by_epoch and self._should_evaluate(
  44. trainer):
  45. self.do_evaluate(trainer)
  46. self.last_eval_tag = ('epoch', trainer.epoch)
  47. def add_visualization_info(self, trainer, results):
  48. if trainer.visualization_buffer.output.get('eval_results',
  49. None) is None:
  50. trainer.visualization_buffer.output['eval_results'] = OrderedDict()
  51. trainer.visualization_buffer.output['eval_results'].update(
  52. trainer.visualize(results))
  53. def do_evaluate(self, trainer):
  54. """Evaluate the results."""
  55. eval_res = trainer.evaluate()
  56. for name, val in eval_res.items():
  57. trainer.log_buffer.output['evaluation/' + name] = val
  58. trainer.log_buffer.ready = True
  59. def _should_evaluate(self, trainer):
  60. """Judge whether to perform evaluation.
  61. Here is the rule to judge whether to perform evaluation:
  62. 1. It will not perform evaluation during the epoch/iteration interval,
  63. which is determined by ``self.interval``.
  64. 2. It will not perform evaluation if the ``start_idx`` is larger than
  65. current epochs/iters.
  66. 3. It will not perform evaluation when current epochs/iters is larger than
  67. the ``start_idx`` but during epoch/iteration interval.
  68. Returns:
  69. bool: The flag indicating whether to perform evaluation.
  70. """
  71. if self.eval_strategy == EvaluationStrategy.by_epoch:
  72. current = trainer.epoch
  73. check_time = self.every_n_epochs
  74. else:
  75. current = trainer.iter
  76. check_time = self.every_n_iters
  77. if self.start_idx is None:
  78. if not check_time(trainer, self.interval):
  79. return False
  80. elif (current + 1) < self.start_idx:
  81. return False
  82. else:
  83. if (current + 1 - self.start_idx) % self.interval:
  84. return False
  85. return True