| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175 |
- import os
- import shutil
- import torch
- from megatron_util import mpu
- from modelscope.metainfo import Hooks
- from modelscope.trainers import EpochBasedTrainer
- from modelscope.trainers.hooks.builder import HOOKS
- from modelscope.trainers.hooks.checkpoint.checkpoint_hook import (
- BestCkptSaverHook, CheckpointHook, CheckpointProcessor)
- from modelscope.trainers.hooks.checkpoint.load_checkpoint_hook import \
- LoadCheckpointHook
- from modelscope.trainers.hooks.hook import Hook
- from modelscope.utils.checkpoint import load_checkpoint, save_checkpoint
- from modelscope.utils.constant import DistributedParallelType
- from modelscope.utils.device import create_device
- from modelscope.utils.logger import get_logger
- from modelscope.utils.megatron_utils import is_megatron_initialized
- from modelscope.utils.torch_utils import get_local_rank
class MpuProcessor(CheckpointProcessor):
    """Checkpoint processor that names files by tensor-model-parallel rank.

    Two files are handled per checkpoint prefix: a trainer-state file
    (``<prefix><rank_name><TRAINER_STATE_SUFFIX>``) and a model weight file
    (``<prefix>_mp_rank_XX_model_states.pt``). On save the weight file is
    additionally linked (or copied) into ``<output_dir>/model``.
    """

    # Sub-directory of the output dir that receives the weight files.
    _BIN_FILE_DIR = 'model'

    def rank_name(self):
        """Return the filename infix for the current tensor-parallel rank.

        Returns:
            ``'_mp_rank_XX'`` (two-digit, zero-padded rank) when the
            tensor-model-parallel world size is not 1. An empty string when
            the world size is 1, or when querying ``mpu`` raises
            ``ImportError`` or ``AssertionError``.
        """
        # TODO
        try:
            if mpu.get_tensor_model_parallel_world_size() == 1:
                return ''
            return '_mp_rank_{:02d}'.format(
                mpu.get_tensor_model_parallel_rank())
        except (ImportError, AssertionError):
            return ''

    def get_bin_filename(self):
        """Return the weight-file name for the current tensor-parallel rank.

        Unlike ``rank_name`` this does not catch errors raised by ``mpu``.
        """
        rank = '{:02d}'.format(mpu.get_tensor_model_parallel_rank())
        return f'mp_rank_{rank}_model_states.pt'

    def should_save_on_rank(self, trainer):
        """Tell whether the calling process should write checkpoints.

        Args:
            trainer: Unused here; part of the processor interface.

        Returns:
            True when ``torch.distributed`` is not initialized, or when the
            data-parallel rank reported by ``mpu`` is 0.
        """
        # TODO
        if not torch.distributed.is_initialized():
            return True
        return mpu.get_data_parallel_rank() == 0

    def prepare_output(self, trainer, output_dir):
        """Dump config/files into ``output_dir`` and create its weight dir.

        Args:
            trainer: Trainer whose ``cfg`` is passed to the base helper.
            output_dir: Target directory.
        """
        CheckpointProcessor.copy_files_and_dump_config(
            trainer, output_dir, trainer.cfg, self._BIN_FILE_DIR)
        weights_dir = os.path.join(output_dir, self._BIN_FILE_DIR)
        os.makedirs(weights_dir, exist_ok=True)

    def _trainer_state_path(self, checkpoint_path_prefix):
        """Build the per-rank trainer-state file path for a prefix."""
        return (checkpoint_path_prefix + self.rank_name()
                + CheckpointProcessor.TRAINER_STATE_SUFFIX)

    def _prefixed_bin_path(self, checkpoint_path_prefix, bin_file):
        """Build ``<dir of prefix>/<basename of prefix>_<bin_file>``."""
        save_dir = os.path.dirname(checkpoint_path_prefix)
        prefix = os.path.basename(checkpoint_path_prefix)
        return os.path.join(save_dir, prefix + '_' + bin_file)

    def save_checkpoints(self,
                         trainer,
                         checkpoint_path_prefix,
                         output_dir,
                         meta=None,
                         save_optimizers=True):
        """Write the trainer-state and weight files for this rank.

        Args:
            trainer: Trainer providing the model, optimizer and lr scheduler.
            checkpoint_path_prefix: Path prefix shared by the written files.
            output_dir: Directory whose ``model`` sub-directory receives a
                link to (or copy of) the weight file.
            meta: Passed through to ``save_checkpoint`` for the state file.
            save_optimizers: When false, optimizer and lr scheduler are not
                handed to ``save_checkpoint``.
        """
        model = trainer.unwrap_module(trainer.model)

        # Save pth file without model state_dict
        save_checkpoint(
            model,
            self._trainer_state_path(checkpoint_path_prefix),
            trainer.optimizer if save_optimizers else None,
            trainer.lr_scheduler if save_optimizers else None,
            meta=meta,
            with_model=False)

        bin_file = self.get_bin_filename()
        src_file = self._prefixed_bin_path(checkpoint_path_prefix, bin_file)
        save_checkpoint(model, src_file, with_meta=False)

        # Replace any previous weight file in the output dir with a hard
        # link to the fresh one; fall back to copying if linking fails.
        dest_file = os.path.join(output_dir, self._BIN_FILE_DIR, bin_file)
        if os.path.isfile(dest_file):
            os.unlink(dest_file)
        try:
            os.link(src_file, dest_file)
        except OSError as e:
            get_logger().error(
                f'Link {src_file} to {dest_file} error: {e}, '
                'changing to copy the bin file, this may case more space usage.'
            )
            shutil.copyfile(src_file, dest_file)

    def remove_checkpoints(self, trainer, checkpoint_path_prefix):
        """Delete this rank's trainer-state and weight files, if present.

        Args:
            trainer: Unused here; part of the processor interface.
            checkpoint_path_prefix: Prefix the files were saved under.
        """
        stale_files = (
            self._trainer_state_path(checkpoint_path_prefix),
            self._prefixed_bin_path(checkpoint_path_prefix,
                                    self.get_bin_filename()),
        )
        for stale_file in stale_files:
            if os.path.isfile(stale_file):
                os.remove(stale_file)

    def load_checkpoints(self, checkpoint_path_prefix, trainer, load_all_state,
                         strict):
        """Load this rank's weights and, for a prefix, the trainer state.

        Args:
            checkpoint_path_prefix: Either a directory containing the weight
                file, or a checkpoint path prefix.
            trainer: Trainer whose unwrapped model receives the weights.
            load_all_state: Forwarded to
                ``LoadCheckpointHook.load_trainer_state``.
            strict: Accepted for interface compatibility; not used here.

        Returns:
            The value returned by ``LoadCheckpointHook.load_trainer_state``
            when a prefix is given; ``None`` when a directory is given, since
            no trainer state is read in that case.
        """
        model = trainer.unwrap_module(trainer.model)
        # Bound up front so the directory branch has a defined return value.
        meta = None
        if os.path.isdir(checkpoint_path_prefix):
            model_file = os.path.join(checkpoint_path_prefix,
                                      self.get_bin_filename())
        else:
            meta = LoadCheckpointHook.load_trainer_state(
                trainer, self._trainer_state_path(checkpoint_path_prefix),
                load_all_state)
            model_file = self._prefixed_bin_path(checkpoint_path_prefix,
                                                 self.get_bin_filename())
        load_checkpoint(model_file, model, None, None)
        return meta
@HOOKS.register_module(module_name=Hooks.MegatronHook)
class MegatronHook(Hook):
    """Trainer hook that wires megatron parallelism into a trainer.

    It installs ``MpuProcessor`` on the trainer's checkpoint hooks, moves the
    model to the local CUDA device, records the parallel groups reported by
    ``mpu`` on the trainer, and wraps the model for parallel execution once.
    """

    _BIN_FILE_DIR = 'model'

    def __init__(self):
        # Set to True after the model has been wrapped by ``wrap_module``.
        self.wrapped = False

    def register_processor(self, trainer: EpochBasedTrainer):
        """Install one shared ``MpuProcessor`` on the checkpoint hooks.

        For each of ``CheckpointHook``, ``BestCkptSaverHook`` and
        ``LoadCheckpointHook``, only the first hook returned by the trainer
        is touched, and only if its processor is not already an
        ``MpuProcessor``.
        """
        shared_processor = MpuProcessor()
        hook_types = (CheckpointHook, BestCkptSaverHook, LoadCheckpointHook)
        for hook_type in hook_types:
            registered = trainer.get_hook(hook_type)
            if len(registered) == 0:
                continue
            first_hook = registered[0]
            if isinstance(first_hook.processor, MpuProcessor):
                continue
            first_hook.set_processor(shared_processor)

    def after_init(self, trainer):
        """Place the model on the local GPU and record parallel groups.

        Asserts that megatron has been initialized before doing anything.
        """
        assert is_megatron_initialized()
        trainer.device = create_device(f'cuda:{get_local_rank()}')
        trainer.model.to(trainer.device)

        # Getters are invoked one at a time, in DP, TP, PP order.
        group_getters = (
            (DistributedParallelType.DP, mpu.get_data_parallel_group),
            (DistributedParallelType.TP,
             mpu.get_tensor_model_parallel_group),
            (DistributedParallelType.PP,
             mpu.get_pipeline_model_parallel_group),
        )
        for parallel_type, get_group in group_getters:
            trainer.parallel_groups[parallel_type] = get_group()

    def before_run(self, trainer):
        """Ensure the model is wrapped before the run starts."""
        self.wrap_module(trainer)

    def before_val(self, trainer):
        """Ensure the model is wrapped before validation starts."""
        self.wrap_module(trainer)

    def wrap_module(self, trainer):
        """Wrap ``trainer.model`` via ``trainer.to_parallel`` at most once.

        Does nothing when ``trainer._dist`` is falsy or the model has
        already been wrapped by this hook.
        """
        if not trainer._dist or self.wrapped:
            return
        trainer.model = trainer.to_parallel(trainer.model)
        self.wrapped = True
|