| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457 |
- # Copyright (c) Alibaba, Inc. and its affiliates.
- import os
- import random
- import shutil
- from typing import Optional
- import json
- import numpy as np
- import torch
- from modelscope.hub.check_model import check_model_is_id
- from modelscope.hub.push_to_hub import (UploadStrategy, push_to_hub_in_queue,
- wait_for_done)
- from modelscope.metainfo import Hooks
- from modelscope.trainers.hooks.builder import HOOKS
- from modelscope.trainers.hooks.checkpoint.checkpoint_processor import \
- CheckpointProcessor
- from modelscope.trainers.hooks.hook import Hook
- from modelscope.trainers.hooks.priority import Priority
- from modelscope.utils.constant import (DEFAULT_REPOSITORY_REVISION, LogKeys,
- ModelFile)
- from modelscope.utils.logger import get_logger
- from modelscope.utils.torch_utils import is_master
class CheckpointStrategy:
    """String constants naming the moments at which `CheckpointHook` saves.

    `CheckpointHook` only reacts to `by_epoch` and `by_step`; any other value
    (including `no`) makes its `_should_save` return False.
    """

    # Save at the end of training epochs.
    by_epoch = 'by_epoch'
    # Save at the end of training iterations.
    by_step = 'by_step'
    # Never save periodically.
    no = 'no'
@HOOKS.register_module(module_name=Hooks.CheckpointHook)
class CheckpointHook(Hook):
    """Hook that writes training checkpoints on a fixed schedule.

    Args:
        save_strategy (str): When to save, one of `by_epoch`, `by_step` or `no`.
        interval (int): The saving frequency handed to `every_n_epochs` (epoch strategy)
            or `every_n_iters` (step strategy).
        save_dir (str): The directory receiving the checkpoints. When empty, `trainer.work_dir` is used.
        output_dir (str): The absolute path to save the output files for inference. When empty, it is
            resolved in `before_run` to the `output_sub_dir` folder (if given) or the
            `ModelFile.TRAIN_OUTPUT_DIR` folder under `save_dir`.
        save_last (bool): Whether to also save at the last epoch/iteration. Default: True.
        max_checkpoint_num (int): The max number of checkpoints to keep, default None which means never
            delete anything. Values below 1 are raised to 1. When the limit is exceeded, the earliest
            checkpoints are deleted first.
        push_to_hub (bool): Whether to push the saved output to the model hub.
        hub_repo_id (str): The hub repo id.
        hub_token (str): The token of the modelhub. You can also set the environment variable
            `MODELSCOPE_API_TOKEN`.
        private_hub (bool): Whether to push to a private hub, default True.
        hub_revision (str): Which branch to push the model to, default is `DEFAULT_REPOSITORY_REVISION`.
        upload_strategy (str): The action adopted when the previous uploading is not done
            and the next one is coming, can be `cancel` or `wait`.
        save_trainer_state (bool): Save the trainer state for continue training, default True.
        kwargs:
            by_epoch (bool): Legacy switch; when present it overrides `save_strategy`
                (truthy -> `by_epoch`, falsy -> `by_step`).
            output_sub_dir (str): The folder under `save_dir` to save the output checkpoint for inference.
                When present, the `output_dir` argument is ignored. Kept to fit the existing configs.
    """

    PRIORITY = Priority.LOW

    EVAL_RESULT_FILE = 'eval_result.txt'

    PUSH_TO_HUB_QUEUE_NAME = 'train.checkpoint'

    def __init__(self,
                 save_strategy: Optional[str] = CheckpointStrategy.by_epoch,
                 interval: Optional[int] = 0,
                 save_dir: Optional[str] = None,
                 output_dir: Optional[str] = None,
                 save_last: Optional[bool] = True,
                 max_checkpoint_num: Optional[int] = None,
                 push_to_hub: Optional[bool] = False,
                 hub_repo_id: Optional[str] = None,
                 hub_token: Optional[str] = None,
                 private_hub: Optional[bool] = True,
                 hub_revision: Optional[str] = DEFAULT_REPOSITORY_REVISION,
                 upload_strategy: Optional[str] = UploadStrategy.cancel,
                 save_trainer_state: bool = True,
                 **kwargs):
        self.interval = interval
        self.save_dir = save_dir

        # The legacy `by_epoch` flag, when supplied, wins over `save_strategy`.
        self.save_strategy = save_strategy
        if 'by_epoch' in kwargs:
            if kwargs['by_epoch']:
                self.save_strategy = CheckpointStrategy.by_epoch
            else:
                self.save_strategy = CheckpointStrategy.by_step

        # A legacy `output_sub_dir` replaces the explicit `output_dir`; the
        # final path is computed later in `before_run`.
        has_sub_dir = 'output_sub_dir' in kwargs
        self.output_sub_dir = kwargs['output_sub_dir'] if has_sub_dir else None
        self.output_dir = None if has_sub_dir else output_dir

        self.save_last = save_last
        self.rng_state = None

        # Hub upload settings.
        self.push_to_hub = push_to_hub
        self.hub_repo_id = hub_repo_id
        self.hub_token = hub_token
        self.private_hub = private_hub
        self.hub_revision = hub_revision
        self.upload_strategy = upload_strategy
        self.tag = -1
        self.is_model_id = None

        self.save_trainer_state = save_trainer_state

        # Keep at least one checkpoint whenever a limit is requested.
        self.max_checkpoint_num = (None if max_checkpoint_num is None else
                                   max(int(max_checkpoint_num), 1))
        self.history_checkpoints = []
        self.processor = CheckpointProcessor()

    def set_processor(self, processor):
        """Replace the processor that performs the actual saving/deleting work."""
        self.processor = processor

    def before_run(self, trainer):
        """Resolve the directories and the logger, then let the master prepare the output folder."""
        self.tag = -1
        self.save_dir = self.save_dir or trainer.work_dir

        if not self.output_dir:
            sub_dir = self.output_sub_dir or ModelFile.TRAIN_OUTPUT_DIR
            self.output_dir = os.path.join(self.save_dir, sub_dir)

        if not os.path.exists(self.save_dir):
            os.makedirs(self.save_dir, exist_ok=True)

        self.logger = trainer.logger if hasattr(trainer,
                                                'logger') else get_logger()

        if not is_master():
            return
        # only global master prepares the output folder
        self.processor.prepare_output(trainer, self.output_dir)
        self.logger.info(f'Checkpoints will be saved to {self.save_dir}')

    def generate_prefix(self, trainer, save_strategy):
        """Build the checkpoint name prefix from the 1-based epoch or iteration number."""
        if save_strategy == CheckpointStrategy.by_epoch:
            key, finished = LogKeys.EPOCH, trainer.epoch
        else:
            key, finished = LogKeys.ITER, trainer.iter
        return f'{key}_{finished + 1}'

    def _do_save(self, trainer, save_strategy):
        """Save a checkpoint on the ranks chosen by the processor and optionally queue an upload."""
        prefix = self.generate_prefix(trainer, save_strategy)
        if not self.processor.should_save_on_rank(trainer):
            return

        if is_master():
            if save_strategy == CheckpointStrategy.by_epoch:
                self.logger.info(
                    f'Saving checkpoint at {trainer.epoch + 1} epoch')
            else:
                self.logger.info(
                    f'Saving checkpoint at {trainer.iter + 1} iter')

        self._save_checkpoint(trainer, prefix)

        if not (is_master() and self.push_to_hub):
            return

        upload_in_place = self.upload_strategy == UploadStrategy.cancel
        if upload_in_place:
            upload_dir = self.output_dir
        else:
            # Upload from a private copy so that the next save cannot touch
            # the files while they are being uploaded; the copy is flagged
            # for deletion via `delete_dir`.
            upload_dir = self.output_dir + '_upload_' + prefix
            shutil.copytree(self.output_dir, upload_dir, dirs_exist_ok=True)
        self._push_to_hub(trainer, prefix, upload_dir, not upload_in_place)

    def after_train_epoch(self, trainer):
        """Save after an epoch when the strategy is `by_epoch` and a save is due."""
        if self.save_strategy == CheckpointStrategy.by_epoch \
                and self._should_save(trainer):
            self._do_save(trainer, CheckpointStrategy.by_epoch)

    def after_train_iter(self, trainer):
        """Save after an iteration when the strategy is `by_step` and a save is due."""
        if self.save_strategy == CheckpointStrategy.by_step \
                and self._should_save(trainer):
            self._do_save(trainer, CheckpointStrategy.by_step)

    def after_run(self, trainer):
        """Send the `done` signal to the upload queue and wait for it to finish."""
        queue_name = self.PUSH_TO_HUB_QUEUE_NAME
        self.logger.info('Train finished. Uploading models, waiting...')
        push_to_hub_in_queue(
            queue_name, strategy=self.upload_strategy, done=True)
        wait_for_done(queue_name)
        if self.push_to_hub:
            self.logger.info('Uploading models done.')

    def _push_to_hub(self, trainer, prefix, output_dir, delete_dir=False):
        """Queue `output_dir` for upload, tagging successive uploads `v1.0`, `v1.1`, ..."""
        # The model-id check is performed once and cached.
        if self.is_model_id is None:
            self.is_model_id = check_model_is_id(trainer.input_model_id,
                                                 self.hub_token)

        self.tag += 1
        source_repo = trainer.input_model_id if self.is_model_id else ''
        return push_to_hub_in_queue(
            self.PUSH_TO_HUB_QUEUE_NAME,
            strategy=self.upload_strategy,
            repo_name=self.hub_repo_id,
            output_dir=output_dir,
            token=self.hub_token,
            private=self.private_hub,
            commit_message=prefix,
            tag=f'v1.{self.tag}',
            revision=self.hub_revision,
            source_repo=source_repo,
            delete_dir=delete_dir)

    def save_evaluate_results(self, trainer):
        """Write `trainer.metric_values` as JSON to `EVAL_RESULT_FILE` inside `output_dir`."""
        result_file = os.path.join(self.output_dir, self.EVAL_RESULT_FILE)
        with open(result_file, 'w') as f:
            f.write(json.dumps(trainer.metric_values))

    def _save_checkpoint(self, trainer, prefix):
        """Save checkpoint files and remove obsolete ones.

        Returns:
            The `prefix` that was passed in.
        """
        ckpt_prefix = os.path.join(self.save_dir, prefix)
        training_state = self._create_training_state(trainer)
        self.processor.save_checkpoints(trainer, ckpt_prefix, self.output_dir,
                                        training_state,
                                        self.save_trainer_state)
        self.save_evaluate_results(trainer)
        self.history_checkpoints.append(ckpt_prefix)
        self._remove_obsolete_checkpoints(trainer)
        return prefix

    def _remove_obsolete_checkpoints(self, trainer):
        """Delete the oldest checkpoints once more than `max_checkpoint_num` are recorded."""
        limit = self.max_checkpoint_num
        if limit is None or len(self.history_checkpoints) <= limit:
            return

        recorded = list(self.history_checkpoints)
        surplus = len(recorded) - limit
        # Empty the list in place, delete the oldest entries, then restore
        # the survivors in their original order.
        self.history_checkpoints.clear()
        for stale_prefix in recorded[:surplus]:
            self.logger.info(f'deleting checkpoint: {stale_prefix}')
            self.processor.remove_checkpoints(
                trainer, checkpoint_path_prefix=stale_prefix)
        self.history_checkpoints.extend(recorded[surplus:])

    def _should_save(self, trainer):
        """Return True when the schedule, or `save_last` at the final step, asks for a save."""
        if self.save_strategy == CheckpointStrategy.by_epoch:
            on_schedule, at_end = self.every_n_epochs, self.is_last_epoch
        elif self.save_strategy == CheckpointStrategy.by_step:
            on_schedule, at_end = self.every_n_iters, self.is_last_iter
        else:
            return False

        return bool(
            on_schedule(trainer, self.interval)
            or (self.save_last and at_end(trainer)))

    def _create_training_state(self, trainer):
        """Collect progress counters, RNG states and hook states to store with the checkpoint."""
        self.rng_state = {
            'random': random.getstate(),
            'numpy': np.random.get_state(),
            'cpu': torch.random.get_rng_state(),
            'cuda': torch.cuda.get_rng_state_all(),
        }

        # keep epoch/iter/inner_iter/random_state
        meta = {
            'epoch': trainer.epoch,
            'iter': trainer.iter + 1,
            'inner_iter': trainer.inner_iter + 1,
            'rng_state': self.rng_state,
        }

        # keep hooks state; the numeric suffix counts only the hooks whose
        # state is actually stored
        stateful_hooks = (
            hook for hook in trainer.hooks if hasattr(hook, 'state_dict')
            and getattr(hook, '_should_save', True))
        for index, hook in enumerate(stateful_hooks):
            meta[f'{hook.__class__}-{index}'] = hook.state_dict()

        return meta
@HOOKS.register_module(module_name=Hooks.BestCkptSaverHook)
class BestCkptSaverHook(CheckpointHook):
    """
    Save best checkpoints hook.

    Args:
        metric_key (str): Metric key to compare rule for best score.
        save_best (bool): Save the best checkpoint, if set to False, this hook will have no effect.
        rule (str): Comparison rule for best score. Support "max" and "min". If rule is "max", the checkpoint
            at the maximum `metric_key` will be saved, If rule is "min", the checkpoint at the minimum `metric_key`
            will be saved.
        save_file_name (str): A manually specified checkpoint name (joined with `save_dir`) used instead of the
            generated prefix.
        restore_best (bool): Whether to restore the best checkpoint after training.
        max_checkpoint_num (int): The max number of checkpoint files, default 1; None means never delete anything.
            If the number exceeds the limit, checkpoints with worse metric will be deleted, which is judged by the
            `rule` and `metric_key` arguments.
        save_trainer_state (bool): Save the trainer state for continue training, default True.

    The `BestCkptSaverHook` class accepts `output_sub_dir` and `output_dir` argument as its super class does.
    If neither of them is passed, `output_sub_dir` defaults to `ModelFile.TRAIN_BEST_OUTPUT_DIR`.

    This class ignores the `interval` and `save_strategy` arguments, because saving follows the evaluations
    performed by the `EvaluationHook`.
    """

    PRIORITY = Priority.LOW
    # Maps a rule name to "is the new value better than the recorded one".
    rule_map = {'max': lambda x, y: x > y, 'min': lambda x, y: x < y}

    def __init__(self,
                 metric_key: str,
                 save_best: Optional[bool] = True,
                 rule: Optional[str] = 'max',
                 save_file_name: Optional[str] = None,
                 restore_best: Optional[bool] = False,
                 max_checkpoint_num: Optional[int] = 1,
                 save_trainer_state: bool = True,
                 **kwargs):
        assert rule in ['max', 'min'], 'Only support "max" or "min" rule now.'

        # Use the dedicated "best" output folder unless the caller chose one.
        output_kwargs = {}
        if 'output_sub_dir' not in kwargs and 'output_dir' not in kwargs:
            output_kwargs['output_sub_dir'] = ModelFile.TRAIN_BEST_OUTPUT_DIR

        # Scheduling arguments are meaningless here: saving follows evaluation.
        kwargs.pop('interval', None)
        kwargs.pop('save_strategy', None)
        super().__init__(
            max_checkpoint_num=max_checkpoint_num,
            save_trainer_state=save_trainer_state,
            **kwargs,
            **output_kwargs,
        )
        self.save_best = save_best
        self.metric_key = metric_key
        self.rule = rule
        self._best_metric = None
        self._best_ckpt_file = None
        self.save_file_name = save_file_name
        self.restore_best = restore_best
        # A set (not the parent's list): re-saving under the same name must
        # not count twice against `max_checkpoint_num`.
        self.history_checkpoints = set()

    def _save_if_evaluated(self, trainer, eval_tag, save_strategy):
        """Save when the latest evaluation carries `eval_tag` and produced a new best metric."""
        # Imported here, as in the original code, rather than at module level.
        from modelscope.trainers.hooks import EvaluationHook
        eval_hooks = trainer.get_hook(EvaluationHook)
        if not eval_hooks:
            self.logger.error(
                'Trying to save the best checkpoint, but there is no evaluation, skipping.'
            )
            # Actually skip: indexing the empty result would raise IndexError.
            return

        if eval_hooks[0].last_eval_tag == eval_tag and self._should_save(
                trainer):
            self._do_save(trainer, save_strategy)

    def after_train_epoch(self, trainer):
        """Save the checkpoint if an evaluation just ran for this epoch and improved the metric."""
        self._save_if_evaluated(trainer, ('epoch', trainer.epoch),
                                CheckpointStrategy.by_epoch)

    def after_train_iter(self, trainer):
        """Save the checkpoint if an evaluation just ran for this iteration and improved the metric."""
        self._save_if_evaluated(trainer, ('iter', trainer.iter),
                                CheckpointStrategy.by_step)

    def _should_save(self, trainer):
        """Return a truthy value when saving is enabled and the current metric is a new best.

        Note: a positive answer also records the metric as the new best value.
        """
        return self.save_best and self._is_best_metric(trainer.metric_values)

    def _is_best_metric(self, metric_values):
        """Compare `metric_values[metric_key]` with the best value so far and record it when better.

        Raises:
            ValueError: If `metric_key` is missing from `metric_values`.
        """
        if metric_values is None:
            return False

        if self.metric_key not in metric_values:
            raise ValueError(
                f'Not find metric_key: {self.metric_key} in {metric_values}')

        current = metric_values[self.metric_key]
        # The very first metric is the best by definition.
        if self._best_metric is None or self.rule_map[self.rule](
                current, self._best_metric):
            self._best_metric = current
            return True
        return False

    def generate_prefix(self, trainer, save_strategy):
        """Build a prefix embedding the epoch/iteration number and the best metric value.

        The metric value at the end of the name is parsed back by
        `_remove_obsolete_checkpoints`.
        """
        if save_strategy == CheckpointStrategy.by_epoch:
            return f'best_{LogKeys.EPOCH}{trainer.epoch + 1}_{self.metric_key}{self._best_metric}'
        else:
            return f'best_{LogKeys.ITER}{trainer.iter + 1}_{self.metric_key}{self._best_metric}'

    def _save_checkpoint(self, trainer, prefix):
        """Save the best checkpoint, remember its path and prune worse checkpoints.

        Returns:
            The `prefix` that was passed in.
        """
        name = prefix if self.save_file_name is None else self.save_file_name
        checkpoint_path_prefix = os.path.join(self.save_dir, name)

        self._best_ckpt_file = checkpoint_path_prefix
        meta = self._create_training_state(trainer)
        self.processor.save_checkpoints(trainer, checkpoint_path_prefix,
                                        self.output_dir, meta,
                                        self.save_trainer_state)
        self.save_evaluate_results(trainer)
        self.history_checkpoints.add(checkpoint_path_prefix)
        self._remove_obsolete_checkpoints(trainer)
        return prefix

    def _remove_obsolete_checkpoints(self, trainer):
        """Keep the `max_checkpoint_num` best checkpoints and delete the rest."""

        def sort_key(checkpoint_name):
            # Split on the LAST occurrence of the metric key so that the key
            # showing up earlier in the path (e.g. inside `save_dir`) cannot
            # hijack the parsing.
            metric = float(checkpoint_name.rsplit(self.metric_key, 1)[-1])
            # Best first: descending for "max", ascending for "min".
            return -metric if self.rule == 'max' else metric

        limit = self.max_checkpoint_num
        if limit is None or len(self.history_checkpoints) <= limit:
            return

        ranked = sorted(self.history_checkpoints, key=sort_key)
        self.history_checkpoints.clear()
        self.history_checkpoints.update(ranked[:limit])
        for checkpoint_path_prefix in ranked[limit:]:
            self.logger.info(f'deleting checkpoint: {checkpoint_path_prefix}')
            self.processor.remove_checkpoints(
                trainer, checkpoint_path_prefix=checkpoint_path_prefix)

    def state_dict(self):
        """Return the state to persist with a checkpoint: the best metric seen so far."""
        return {
            'best_metric': self._best_metric,
        }

    def load_state_dict(self, state_dict):
        """Restore the best metric from a state produced by `state_dict`."""
        if state_dict is not None and len(state_dict) > 0:
            self._best_metric = state_dict.get('best_metric')
        else:
            # `self.logger` is only attached in `before_run`; fall back to the
            # default logger in case the state is restored before that.
            logger = getattr(self, 'logger', None) or get_logger()
            logger.warning(
                'The state_dict is not available, the best metric value will be affected.'
            )

    def after_run(self, trainer):
        """Optionally reload the best checkpoint into the trainer once training has finished."""
        if not self.restore_best:
            return

        if self._best_ckpt_file is None:
            # Nothing was saved by this hook during the run, so there is no
            # checkpoint path to hand to the loader.
            self.logger.warning(
                'No best checkpoint was saved during training, skip restoring the best checkpoint.'
            )
            return

        # If restore_best is True, will call the LoadCheckpointHook to load the best checkpoint
        # for later evaluation or prediction.
        from modelscope.trainers.hooks.checkpoint.load_checkpoint_hook import LoadCheckpointHook
        LoadCheckpointHook.load_checkpoint(self._best_ckpt_file, trainer)
|