# swift_hook.py (5.4 KB)
  1. import os
  2. import shutil
  3. from modelscope.metainfo import Hooks
  4. from modelscope.trainers import EpochBasedTrainer
  5. from modelscope.trainers.hooks.builder import HOOKS
  6. from modelscope.trainers.hooks.checkpoint.checkpoint_hook import (
  7. BestCkptSaverHook, CheckpointHook, CheckpointProcessor)
  8. from modelscope.trainers.hooks.checkpoint.load_checkpoint_hook import \
  9. LoadCheckpointHook
  10. from modelscope.trainers.hooks.hook import Hook
  11. from modelscope.utils.checkpoint import save_configuration
  12. from modelscope.utils.import_utils import is_swift_available
  13. class SwiftCheckpointProcessor(CheckpointProcessor):
  14. _BIN_FILE_DIR = 'model'
  15. SWIFT_SAVE_SUFFIX = '_swift'
  16. @staticmethod
  17. def copy_files_and_dump_config(trainer, output_dir, config, bin_file):
  18. """Copy useful files to target output folder and dumps the target configuration.json.
  19. """
  20. model = trainer.unwrap_module(trainer.model)
  21. class SaveConfig:
  22. def __init__(self, output_dir, config):
  23. self.output_dir = output_dir
  24. self.config = config
  25. def __call__(self, _output_dir, _config):
  26. self.config = _config
  27. def save_config(self):
  28. save_configuration(self.output_dir, self.config)
  29. for pop_key in [
  30. 'push_to_hub', 'hub_repo_id', 'hub_token', 'private_hub'
  31. ]:
  32. if config.safe_get('train.checkpoint.period.'
  33. + pop_key) is not None:
  34. config.safe_get('train.checkpoint.period').pop(pop_key)
  35. if config.safe_get('train.checkpoint.best.' + pop_key) is not None:
  36. config.safe_get('train.checkpoint.best').pop(pop_key)
  37. save_config_fn = SaveConfig(output_dir, config)
  38. if hasattr(model, 'save_pretrained'):
  39. if not is_swift_available():
  40. raise ValueError(
  41. 'Please install swift by `pip install ms-swift` to use SwiftHook.'
  42. )
  43. from swift import SwiftModel
  44. if isinstance(model, SwiftModel):
  45. _swift_output_dir = output_dir + SwiftCheckpointProcessor.SWIFT_SAVE_SUFFIX
  46. model.save_pretrained(
  47. save_directory=_swift_output_dir,
  48. safe_serialization=config.safe_get(
  49. 'train.checkpoint.safe_serialization', False),
  50. adapter_name=config.safe_get(
  51. 'train.checkpoint.adapter_name', 'default'))
  52. else:
  53. model.save_pretrained(
  54. output_dir,
  55. bin_file,
  56. save_function=lambda *args, **kwargs: None,
  57. config=save_config_fn.config,
  58. save_config_function=save_config_fn)
  59. if trainer.train_preprocessor is not None:
  60. trainer.train_preprocessor.save_pretrained(
  61. output_dir,
  62. save_config_fn.config,
  63. save_config_function=save_config_fn)
  64. if trainer.eval_preprocessor is not None:
  65. trainer.eval_preprocessor.save_pretrained(
  66. output_dir,
  67. save_config_fn.config,
  68. save_config_function=save_config_fn)
  69. save_config_fn.save_config()
  70. def link_dir(self, source_dir, output_dir):
  71. if os.path.exists(output_dir):
  72. shutil.rmtree(output_dir)
  73. shutil.copytree(source_dir, output_dir)
  74. def save_swift_model_state(self, model, filename):
  75. model.save_pretrained(filename)
  76. def save_checkpoints(self,
  77. trainer,
  78. checkpoint_path_prefix,
  79. output_dir,
  80. meta=None,
  81. save_optimizers=True):
  82. model = trainer.unwrap_module(trainer.model)
  83. _model_file, _train_state_file = self._get_state_file_name(
  84. checkpoint_path_prefix)
  85. _swift_save_dir = checkpoint_path_prefix + SwiftCheckpointProcessor.SWIFT_SAVE_SUFFIX
  86. _swift_output_dir = output_dir + SwiftCheckpointProcessor.SWIFT_SAVE_SUFFIX
  87. self.save_trainer_state(trainer, model, _train_state_file, meta,
  88. save_optimizers)
  89. self.save_model_state(model, _model_file)
  90. self.link(model, _model_file, output_dir)
  91. self.save_swift_model_state(model, _swift_save_dir)
  92. self.link_dir(_swift_save_dir, _swift_output_dir)
  93. @HOOKS.register_module(module_name=Hooks.SwiftHook)
  94. class SwiftHook(Hook):
  95. _BIN_FILE_DIR = 'model'
  96. def __init__(self):
  97. pass
  98. def register_processor(self, trainer: EpochBasedTrainer):
  99. processor = SwiftCheckpointProcessor()
  100. ckpt_hook = trainer.get_hook(CheckpointHook)
  101. if len(ckpt_hook) > 0 and not isinstance(ckpt_hook[0].processor,
  102. SwiftCheckpointProcessor):
  103. ckpt_hook[0].set_processor(processor)
  104. best_ckpt_hook = trainer.get_hook(BestCkptSaverHook)
  105. if len(best_ckpt_hook) > 0 and not isinstance(
  106. best_ckpt_hook[0].processor, SwiftCheckpointProcessor):
  107. best_ckpt_hook[0].set_processor(processor)
  108. load_ckpt_hook = trainer.get_hook(LoadCheckpointHook)
  109. if len(load_ckpt_hook) > 0 and not isinstance(
  110. load_ckpt_hook[0].processor, SwiftCheckpointProcessor):
  111. load_ckpt_hook[0].set_processor(processor)