# Copyright (c) Alibaba, Inc. and its affiliates.
import os
import shutil
from typing import Dict, List, Union

import torch
from torch import nn

from modelscope.utils.logger import get_logger
from modelscope.utils.torch_utils import is_master

logger = get_logger()

# Fallback megatron settings, keyed by model type (or pipeline type), used
# when the configuration file has no `megatron` section.
_DEFAULT_CFG_WITH_MODEL_TYPE = {
    'gpt-moe': {
        'version': 'moe',
        'world_size': 8
    },
    'plug': {
        'version': 'v1',
        'world_size': 8,
        'tensor_model_parallel_size': 8,
        'seed': 1234
    },
    'mglm-text-summarization': {
        'version': 'v1',
        'seed': 1234
    },
}

# File name template of one checkpoint shard; 'XX' is replaced by the
# zero-padded, two-digit rank.
_CHECKPOINT_FORMAT = 'mp_rank_XX_model_states.pt'

_IS_MEGATRON_INITIALIZED = False


def init_megatron_util(megatron_cfg=None, model_dir=None, **kwargs):
    """Initialize megatron_util environment for megatron_based model.

    If argument `megatron_cfg` is not specified, then the megatron_cfg will be
    loaded from the configuration.json file in the model_dir. When that
    configuration has no `megatron` section, a default config is looked up by
    `cfg.model.type` (falling back to `cfg.pipeline.type`); unknown types
    start from an empty config.

    Args:
        megatron_cfg (Dict, optional): Megatron Config will be send to megatron_util.
            When given, it is updated in place with `kwargs`.
        model_dir (str, optional): The model path for configuration. Defaults to None.
        **kwargs: Extra settings merged over the resolved config before it is
            passed to `initialize_megatron`.

    Raises:
        AssertionError: If both `megatron_cfg` and `model_dir` are None.
    """
    from modelscope.utils.hub import read_config
    from megatron_util import initialize_megatron
    assert not (megatron_cfg is None and model_dir is None), \
        'cfg and model_dir cannot both be None when initializing megatron_util'
    if megatron_cfg is None:
        cfg = read_config(model_dir)
        try:
            megatron_cfg = cfg.megatron
        except AttributeError:
            try:
                model_type = cfg.model.type
            except AttributeError:
                # Fit models without model type, such as mglm
                model_type = cfg.pipeline.type
            # Copy the default entry: the `update(kwargs)` below must not
            # leak per-call overrides into the module-level defaults.
            megatron_cfg = dict(
                _DEFAULT_CFG_WITH_MODEL_TYPE.get(model_type, {}))
    megatron_cfg.update(kwargs)
    initialize_megatron(megatron_cfg)
    global _IS_MEGATRON_INITIALIZED
    _IS_MEGATRON_INITIALIZED = True


def is_megatron_initialized() -> bool:
    """Return True once `init_megatron_util` has completed in this process."""
    return _IS_MEGATRON_INITIALIZED


def convert_megatron_checkpoint(
        model: nn.Module, checkpoint_dir: Union[str, bytes, os.PathLike],
        target_dir: Union[str, bytes, os.PathLike]) -> None:
    """Split or Merge checkpoint for megatron_based model.

    The number of origin partitions is the number of files in the checkpoint
    directory; the number of target partitions is read from the `WORLD_SIZE`
    environment variable. Fewer origin partitions than target partitions
    triggers a split, more triggers a merge, and an equal count copies the
    directory as is.

    Args:
        model (nn.Module): Any megatron_based model.
        checkpoint_dir (Union[str, bytes, os.PathLike]): The save path of origin checkpoint.
            If it contains a `model` sub-directory, that sub-directory is used.
        target_dir (Union[str, bytes, os.PathLike]): The target path of new checkpoint.

    Raises:
        AssertionError: If the origin directory is empty, either partition
            count is not a positive power of 2, or a shard file is missing.
    """

    def log_master(information: str):
        # Only the master process logs, to avoid one message per rank.
        if is_master():
            logger.info(information)

    if os.path.exists(os.path.join(checkpoint_dir, 'model')):
        checkpoint_dir = os.path.join(checkpoint_dir, 'model')

    origin_num_partitions = len(os.listdir(checkpoint_dir))
    target_num_partitions = int(os.getenv('WORLD_SIZE'))

    # Both checks reject zero, so the integer divisions below are safe.
    _check_origin_dir(checkpoint_dir)
    _check_target_num_partitions(target_num_partitions)
    log_master(
        f'origin_num_partitions: {origin_num_partitions}, target_num_partitions: {target_num_partitions}'
    )

    if origin_num_partitions < target_num_partitions:
        os.makedirs(target_dir, exist_ok=True)
        state_dict = _split_checkpoint(
            model, checkpoint_dir,
            target_num_partitions // origin_num_partitions)
        _save_converted_checkpoint(state_dict, target_dir)
        log_master('Split checkpoints succeeded.')
    elif origin_num_partitions > target_num_partitions:
        os.makedirs(target_dir, exist_ok=True)
        state_dict = _merge_checkpoint(
            model, checkpoint_dir,
            origin_num_partitions // target_num_partitions)
        _save_converted_checkpoint(state_dict, target_dir)
        log_master('Merge checkpoints succeeded.')
    else:
        # NOTE(review): copytree raises if `target_dir` already exists, and
        # this line is not guarded by rank -- presumably several ranks can
        # race on a shared target; confirm against the callers.
        shutil.copytree(checkpoint_dir, target_dir)
        log_master('Copy checkpoints succeeded.')


def _check_origin_dir(origin_dir: Union[str, bytes, os.PathLike]) -> None:
    """Validate that `origin_dir` holds a complete set of checkpoint shards.

    The directory must contain a positive power-of-2 number of files, and
    one shard file named after `_CHECKPOINT_FORMAT` for every rank index
    below that number.

    Raises:
        AssertionError: If the file count is not a positive power of 2 or a
            shard file is missing.
    """
    filenames = os.listdir(origin_dir)
    num_files = len(filenames)
    # `n & (n - 1) == 0` alone also holds for n == 0, so an empty directory
    # must be rejected explicitly.
    assert num_files > 0 and (num_files & (num_files - 1)) == 0, \
        'The number of files must be a power of 2!'
    for i in range(num_files):
        checkpoint_name = _CHECKPOINT_FORMAT.replace('XX', f'{i:02d}')
        assert checkpoint_name in filenames, \
            f'Can not find {checkpoint_name} file!'


def _check_target_num_partitions(num_partitions: int) -> None:
    """Validate that `num_partitions` is a positive power of 2.

    Raises:
        AssertionError: If `num_partitions` is not a positive power of 2.
    """
    # Zero passes the bit trick but is not a power of 2; reject it too.
    assert num_partitions > 0 and \
        (num_partitions & (num_partitions - 1)) == 0, \
        'The number of target partitions must be a power of 2!'
def _split_checkpoint(model: nn.Module,
                      checkpoint_dir: Union[str, bytes, os.PathLike],
                      num_partitions: int) -> Dict[str, torch.Tensor]:
    """Build the state dict of the current rank by slicing one origin shard.

    The current rank is read from the `RANK` environment variable. For each
    named parameter of `model`, the origin tensor is split along the first
    dimension whose size differs from the parameter; tensors with matching
    shapes are taken over unchanged.

    Args:
        model (nn.Module): Model whose parameter shapes define the target layout.
        checkpoint_dir (Union[str, bytes, os.PathLike]): Directory of origin shards.
        num_partitions (int): Number of target ranks produced per origin shard.

    Returns:
        Dict[str, torch.Tensor]: Parameter name to tensor for the current rank.
    """
    target_rank = int(os.getenv('RANK'))
    # Each origin shard feeds `num_partitions` consecutive target ranks.
    origin_rank = target_rank // num_partitions
    state_dict = _load_by_rank(checkpoint_dir, origin_rank)

    target_state_dict = {}
    for name, parameter in model.named_parameters():
        origin_tensor = state_dict[name]
        dim = _get_diff_dim(parameter, origin_tensor)
        if dim == -1:
            # Same shape on both sides: nothing to split.
            target_state_dict[name] = origin_tensor
            continue
        partitions_list = _split_tensor(origin_tensor, num_partitions, dim)
        # Clone so the stored slice does not keep a view on the full tensor.
        target_state_dict[name] = partitions_list[target_rank
                                                  % num_partitions].clone()
    return target_state_dict


def _merge_checkpoint(model: nn.Module,
                      checkpoint_dir: Union[str, bytes, os.PathLike],
                      num_partitions: int) -> Dict[str, torch.Tensor]:
    """Build the state dict of the current rank by joining origin shards.

    The current rank is read from the `RANK` environment variable. For each
    named parameter of `model`, the origin tensors are concatenated along the
    first dimension whose size differs from the parameter; tensors with
    matching shapes are taken from the first loaded shard.

    Args:
        model (nn.Module): Model whose parameter shapes define the target layout.
        checkpoint_dir (Union[str, bytes, os.PathLike]): Directory of origin shards.
        num_partitions (int): Number of origin shards merged into one target rank.

    Returns:
        Dict[str, torch.Tensor]: Parameter name to tensor for the current rank.
    """
    target_rank = int(os.getenv('RANK'))
    # The current rank owns `num_partitions` consecutive origin shards.
    first_origin_rank = target_rank * num_partitions
    state_dict_list = [
        _load_by_rank(checkpoint_dir, first_origin_rank + i)
        for i in range(num_partitions)
    ]

    target_state_dict = {}
    for name, parameter in model.named_parameters():
        dim = _get_diff_dim(parameter, state_dict_list[0][name])
        if dim == -1:
            # Same shape on both sides: take the first shard's copy.
            target_state_dict[name] = state_dict_list[0][name]
            continue
        # torch.cat already allocates a new tensor, so no extra clone is
        # needed on top of it.
        target_state_dict[name] = torch.cat(
            [state_dict[name] for state_dict in state_dict_list], dim=dim)
    return target_state_dict


def _save_converted_checkpoint(
        state_dict: Dict[str, torch.Tensor],
        target_dir: Union[str, bytes, os.PathLike]) -> None:
    """Save `state_dict` into `target_dir` as the shard of the current rank.

    The file name follows `_CHECKPOINT_FORMAT` with the rank taken from the
    `RANK` environment variable.
    """
    target_rank = int(os.getenv('RANK'))
    target_name = _CHECKPOINT_FORMAT.replace('XX', f'{target_rank:02d}')
    torch.save(state_dict, os.path.join(target_dir, target_name))


def _get_diff_dim(tensor1: torch.Tensor, tensor2: torch.Tensor) -> int:
    """Return the first dimension where the two shapes differ, or -1.

    Only the dimensions both tensors have in common are compared.
    """
    for i, (s1, s2) in enumerate(zip(tensor1.shape, tensor2.shape)):
        if s1 != s2:
            return i
    return -1


def _load_by_rank(checkpoint_dir: Union[str, bytes, os.PathLike],
                  rank: int) -> Dict[str, torch.Tensor]:
    """Load the checkpoint shard of `rank` from `checkpoint_dir`.

    Returns the value stored under the 'module' key when the loaded object
    has one, otherwise the loaded object itself.
    """
    checkpoint_name = _CHECKPOINT_FORMAT.replace('XX', f'{rank:02d}')
    state_dict = torch.load(
        os.path.join(checkpoint_dir, checkpoint_name),
        # Return storages as deserialized instead of moving them back to
        # the device they were saved from.
        map_location=lambda storage, loc: storage,
        weights_only=True)
    return state_dict['module'] if 'module' in state_dict else state_dict


def _split_tensor(tensor: torch.Tensor, num_partitions: int,
                  partition_dim: int) -> List[torch.Tensor]:
    """Split `tensor` into `num_partitions` pieces along `partition_dim`.

    The per-piece size is computed by `megatron_util.mpu.utils.divide`.

    Returns:
        List[torch.Tensor]: The pieces (views on `tensor`) in order.
    """
    from megatron_util import mpu
    per_partition_size = mpu.utils.divide(
        tensor.size(partition_dim), num_partitions)
    # torch.split returns a tuple; convert to honour the declared List type.
    return list(torch.split(tensor, per_partition_size, dim=partition_dim))