| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261 |
- # Copyright (c) OpenMMLab. All rights reserved.
- # Copyright (c) Alibaba, Inc. and its affiliates.
- from functools import wraps
- from modelscope.utils.constant import TrainerStages
- from modelscope.utils.import_utils import is_method_overridden
- from .priority import Priority
- class Hook:
- """
- The Hook base class of any modelscope trainer. You can build your own hook inherited from this class.
- """
- stages = (TrainerStages.after_init, TrainerStages.before_run,
- TrainerStages.before_val, TrainerStages.before_train_epoch,
- TrainerStages.before_train_iter, TrainerStages.after_train_iter,
- TrainerStages.after_train_epoch, TrainerStages.before_val_epoch,
- TrainerStages.before_val_iter, TrainerStages.after_val_iter,
- TrainerStages.after_val_epoch, TrainerStages.after_run,
- TrainerStages.after_val)
- PRIORITY = Priority.NORMAL
- def after_init(self, trainer):
- """
- Will be called at the end of the trainer's `__init__` method
- """
- pass
- def before_run(self, trainer):
- """
- Will be called before trainer loop begins.
- Args:
- trainer: The trainer instance.
- Returns: None
- """
- pass
- def after_run(self, trainer):
- """
- Will be called after trainer loop end.
- Args:
- trainer: The trainer instance.
- Returns: None
- """
- pass
- def before_val(self, trainer):
- """
- Will be called before eval loop begins.
- Args:
- trainer: The trainer instance.
- Returns: None
- """
- pass
- def after_val(self, trainer):
- """
- Will be called after eval loop end.
- Args:
- trainer: The trainer instance.
- Returns: None
- """
- pass
- def before_epoch(self, trainer):
- """
- Will be called before every epoch begins.
- Args:
- trainer: The trainer instance.
- Returns: None
- """
- pass
- def after_epoch(self, trainer):
- """
- Will be called after every epoch ends.
- Args:
- trainer: The trainer instance.
- Returns: None
- """
- pass
- def before_iter(self, trainer):
- """
- Will be called before every loop begins.
- Args:
- trainer: The trainer instance.
- Returns: None
- """
- pass
- def after_iter(self, trainer):
- """
- Will be called after every loop ends.
- Args:
- trainer: The trainer instance.
- Returns: None
- """
- pass
- def before_train_epoch(self, trainer):
- """
- Will be called before every train epoch begins. Default call ``self.before_epoch``
- Args:
- trainer: The trainer instance.
- Returns: None
- """
- self.before_epoch(trainer)
- def before_val_epoch(self, trainer):
- """
- Will be called before every validation epoch begins. Default call ``self.before_epoch``
- Args:
- trainer: The trainer instance.
- Returns: None
- """
- self.before_epoch(trainer)
- def after_train_epoch(self, trainer):
- """
- Will be called after every train epoch ends. Default call ``self.after_epoch``
- Args:
- trainer: The trainer instance.
- Returns: None
- """
- self.after_epoch(trainer)
- def after_val_epoch(self, trainer):
- """
- Will be called after every validation epoch ends. Default call ``self.after_epoch``
- Args:
- trainer: The trainer instance.
- Returns: None
- """
- self.after_epoch(trainer)
- def before_train_iter(self, trainer):
- """
- Will be called before every train loop begins. Default call ``self.before_iter``
- Args:
- trainer: The trainer instance.
- Returns: None
- """
- self.before_iter(trainer)
- def before_val_iter(self, trainer):
- """
- Will be called before every validation loop begins. Default call ``self.before_iter``
- Args:
- trainer: The trainer instance.
- Returns: None
- """
- self.before_iter(trainer)
- def after_train_iter(self, trainer):
- """
- Will be called after every train loop ends. Default call ``self.after_iter``
- Args:
- trainer: The trainer instance.
- Returns: None
- """
- self.after_iter(trainer)
- def after_val_iter(self, trainer):
- """
- Will be called after every validation loop ends. Default call ``self.after_iter``
- Args:
- trainer: The trainer instance.
- Returns: None
- """
- self.after_iter(trainer)
- @staticmethod
- def every_n_epochs(trainer, n):
- """
- Whether to reach every ``n`` epochs
- Returns: bool
- """
- return (trainer.epoch + 1) % n == 0 if n > 0 else False
- @staticmethod
- def every_n_inner_iters(runner, n):
- """
- Whether to reach every ``n`` iterations at every epoch
- Returns: bool
- """
- return (runner.inner_iter + 1) % n == 0 if n > 0 else False
- @staticmethod
- def every_n_iters(trainer, n):
- """
- Whether to reach every ``n`` iterations
- Returns: bool
- """
- return (trainer.iter + 1) % n == 0 if n > 0 else False
- @staticmethod
- def end_of_epoch(trainer):
- """
- Whether to reach the end of every epoch
- Returns: bool
- """
- return trainer.inner_iter + 1 == trainer.iters_per_epoch
- @staticmethod
- def is_last_epoch(trainer):
- """
- Whether to reach the last epoch
- Returns: bool
- """
- return trainer.epoch + 1 == trainer.max_epochs
- @staticmethod
- def is_last_iter(trainer):
- """
- Whether to reach the last iteration in the entire training process
- Returns: bool
- """
- return trainer.iter + 1 == trainer.max_iters
- def get_triggered_stages(self):
- trigger_stages = set()
- for stage in Hook.stages:
- if is_method_overridden(stage, Hook, self):
- trigger_stages.add(stage)
- return [stage for stage in Hook.stages if stage in trigger_stages]
- def state_dict(self):
- return {}
- def load_state_dict(self, state_dict):
- pass
|