| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132 |
- import os
- import shutil
- from modelscope.metainfo import Hooks
- from modelscope.trainers import EpochBasedTrainer
- from modelscope.trainers.hooks.builder import HOOKS
- from modelscope.trainers.hooks.checkpoint.checkpoint_hook import (
- BestCkptSaverHook, CheckpointHook, CheckpointProcessor)
- from modelscope.trainers.hooks.checkpoint.load_checkpoint_hook import \
- LoadCheckpointHook
- from modelscope.trainers.hooks.hook import Hook
- from modelscope.utils.checkpoint import save_configuration
- from modelscope.utils.import_utils import is_swift_available
class SwiftCheckpointProcessor(CheckpointProcessor):
    """Checkpoint processor that additionally saves via ``save_pretrained``.

    Besides the trainer-state / model-state files handled by the steps this
    class does not define itself (presumably inherited from
    ``CheckpointProcessor``), every checkpoint gets a companion directory
    whose name ends with ``SWIFT_SAVE_SUFFIX``.
    """

    _BIN_FILE_DIR = 'model'
    # Appended to a checkpoint prefix / output dir to name the swift folder.
    SWIFT_SAVE_SUFFIX = '_swift'

    @staticmethod
    def copy_files_and_dump_config(trainer, output_dir, config, bin_file):
        """Export model/preprocessor files and dump the configuration.

        Args:
            trainer: Trainer exposing ``model``, ``unwrap_module``,
                ``train_preprocessor`` and ``eval_preprocessor``.
            output_dir: Target folder of the export.
            config: Configuration object. Hub-related keys are popped from
                its ``train.checkpoint.period`` and ``train.checkpoint.best``
                sections in place before it is saved.
            bin_file: Forwarded to ``save_pretrained`` of non-Swift models.

        Raises:
            ValueError: The model offers ``save_pretrained`` but swift is
                not available.
        """

        class SaveConfig:
            """Remembers the latest config it is called with; saves on demand."""

            def __init__(self, output_dir, config):
                self.output_dir = output_dir
                self.config = config

            def __call__(self, _output_dir, _config):
                # Only the config is recorded; the directory passed by the
                # caller is ignored in favour of the one given at creation.
                self.config = _config

            def save_config(self):
                save_configuration(self.output_dir, self.config)

        def _skip_saving(*args, **kwargs):
            """No-op stand-in passed as ``save_function``."""
            return None

        model = trainer.unwrap_module(trainer.model)

        # Drop hub-upload settings from both checkpoint sections.
        hub_keys = ('push_to_hub', 'hub_repo_id', 'hub_token', 'private_hub')
        sections = (
            ('train.checkpoint.period.', 'train.checkpoint.period'),
            ('train.checkpoint.best.', 'train.checkpoint.best'),
        )
        for hub_key in hub_keys:
            for key_prefix, section in sections:
                if config.safe_get(key_prefix + hub_key) is not None:
                    config.safe_get(section).pop(hub_key)

        save_config_fn = SaveConfig(output_dir, config)

        if hasattr(model, 'save_pretrained'):
            if not is_swift_available():
                raise ValueError(
                    'Please install swift by `pip install ms-swift` to use SwiftHook.'
                )
            from swift import SwiftModel

            if not isinstance(model, SwiftModel):
                # Non-Swift model: config writes are routed to save_config_fn
                # and the save function does nothing.
                model.save_pretrained(
                    output_dir,
                    bin_file,
                    save_function=_skip_saving,
                    config=save_config_fn.config,
                    save_config_function=save_config_fn)
            else:
                # Swift model: written to the suffixed sibling directory.
                swift_dir = (
                    output_dir + SwiftCheckpointProcessor.SWIFT_SAVE_SUFFIX)
                safe_serialization = config.safe_get(
                    'train.checkpoint.safe_serialization', False)
                adapter_name = config.safe_get(
                    'train.checkpoint.adapter_name', 'default')
                model.save_pretrained(
                    save_directory=swift_dir,
                    safe_serialization=safe_serialization,
                    adapter_name=adapter_name)

        # Train preprocessor first, then eval; each may update the recorded
        # config through save_config_fn.
        for attr_name in ('train_preprocessor', 'eval_preprocessor'):
            preprocessor = getattr(trainer, attr_name)
            if preprocessor is None:
                continue
            preprocessor.save_pretrained(
                output_dir,
                save_config_fn.config,
                save_config_function=save_config_fn)

        # Write whichever config was recorded last.
        save_config_fn.save_config()

    def link_dir(self, source_dir, output_dir):
        """Replace ``output_dir`` with a full copy of ``source_dir``.

        Despite the name, the directory is copied rather than linked; an
        existing ``output_dir`` is removed first.
        """
        if os.path.exists(output_dir):
            shutil.rmtree(output_dir)
        shutil.copytree(source_dir, output_dir)

    def save_swift_model_state(self, model, filename):
        """Call ``model.save_pretrained`` with ``filename`` as its argument."""
        model.save_pretrained(filename)

    def save_checkpoints(self,
                         trainer,
                         checkpoint_path_prefix,
                         output_dir,
                         meta=None,
                         save_optimizers=True):
        """Save trainer state, model state and the swift directory.

        Args:
            trainer: Trainer whose unwrapped model is saved.
            checkpoint_path_prefix: Prefix from which the state file names
                and the swift save directory are derived.
            output_dir: Output folder; the swift directory is copied to
                ``output_dir + SWIFT_SAVE_SUFFIX``.
            meta: Forwarded to ``save_trainer_state``.
            save_optimizers: Forwarded to ``save_trainer_state``.
        """
        suffix = SwiftCheckpointProcessor.SWIFT_SAVE_SUFFIX
        model = trainer.unwrap_module(trainer.model)
        model_file, train_state_file = self._get_state_file_name(
            checkpoint_path_prefix)
        swift_save_dir = checkpoint_path_prefix + suffix
        swift_output_dir = output_dir + suffix

        # Regular checkpoint files.
        self.save_trainer_state(trainer, model, train_state_file, meta,
                                save_optimizers)
        self.save_model_state(model, model_file)
        self.link(model, model_file, output_dir)

        # Swift directory, then its copy next to the output folder.
        self.save_swift_model_state(model, swift_save_dir)
        self.link_dir(swift_save_dir, swift_output_dir)
@HOOKS.register_module(module_name=Hooks.SwiftHook)
class SwiftHook(Hook):
    """Hook that installs a ``SwiftCheckpointProcessor`` on checkpoint hooks."""

    _BIN_FILE_DIR = 'model'

    def __init__(self):
        # Intentionally empty: no state of its own, base initializer not called.
        pass

    def register_processor(self, trainer: EpochBasedTrainer):
        """Attach one shared ``SwiftCheckpointProcessor`` to the trainer.

        For each of ``CheckpointHook``, ``BestCkptSaverHook`` and
        ``LoadCheckpointHook``, only the first hook returned by
        ``trainer.get_hook`` is considered, and it is left alone when its
        processor already is a ``SwiftCheckpointProcessor``.

        Args:
            trainer: Trainer whose hooks are inspected and updated.
        """
        processor = SwiftCheckpointProcessor()
        hook_types = (CheckpointHook, BestCkptSaverHook, LoadCheckpointHook)
        for hook_type in hook_types:
            found = trainer.get_hook(hook_type)
            if not len(found) > 0:
                continue
            first_hook = found[0]
            if isinstance(first_hook.processor, SwiftCheckpointProcessor):
                continue
            first_hook.set_processor(processor)
|