# Copyright 2021-2022 The Alibaba PAI Team Authors.
# Copyright (c) 2019, NVIDIA CORPORATION.  All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os

import torch
from megatron_util import mpu
from megatron_util.model import Float16Module
from megatron_util.utils import unwrap_model
from torch.nn.parallel import DistributedDataParallel as torchDDP

from .configuration import logger
from .moe.layer import MoE


def get_checkpoint_names(checkpoints_path,
                         path_load_tag,
                         num_experts,
                         tensor_rank=None,
                         expp_rank=None):
    """Build the model and optimizer checkpoint paths for this rank.

    Args:
        checkpoints_path: Root directory that holds the checkpoints.
        path_load_tag: Sub-directory of ``checkpoints_path`` to load from.
        num_experts: Sequence whose first element is the expert count. When
            it is ``None``, empty, or its first element is not positive, the
            dense (non-MoE) file layout is used.
        tensor_rank: Tensor-model-parallel rank. Defaults to the rank
            reported by ``mpu.get_tensor_model_parallel_rank()``.
        expp_rank: Expert-parallel rank; only used in the MoE optimizer
            file name.

    Returns:
        A ``(model_name, optim_name)`` tuple of file paths. In the dense
        layout both entries are the same path.
    """
    if tensor_rank is None:
        tensor_rank = mpu.get_tensor_model_parallel_rank()

    tag_dir = os.path.join(checkpoints_path, path_load_tag)
    common_path = os.path.join(tag_dir, f'mp_rank_{tensor_rank:02d}')

    # Callers in this module default ``num_experts`` to None, so guard the
    # index instead of failing with a TypeError/IndexError.
    has_experts = (num_experts is not None and len(num_experts) > 0
                   and num_experts[0] > 0)

    if has_experts:
        model_name = os.path.join(common_path, 'model_rng.pt')
        optim_name = os.path.join(
            tag_dir,
            f'expp_rank_{expp_rank}_mp_rank_{tensor_rank:02d}_optim_states.pt')
    else:
        model_name = optim_name = os.path.join(common_path,
                                               'model_optim_rng.pt')

    return model_name, optim_name


def _get_expert_ckpt_name(checkpoints_path, layer_id, expert_id):
    """Return the path of one expert's state file for this tensor rank.

    The file lives under ``<checkpoints_path>/model`` and is keyed by the MoE
    layer index, the expert id and the tensor-model-parallel rank.
    """
    # NOTE(review): the 'model' sub-directory is hard-coded here and does not
    # follow ``path_load_tag`` used for the base checkpoint -- confirm that
    # this is intended.
    mp_rank = mpu.get_tensor_model_parallel_rank()
    file_name = (f'layer_{layer_id}_expert_{expert_id}'
                 f'_mp_rank_{mp_rank:02d}_model_states.pt')
    return os.path.join(checkpoints_path, 'model', file_name)


def _load_base_checkpoint(load_dir, path_load_tag=None, num_experts=None):
    """Load the base (non-expert) checkpoint of this rank onto the CPU.

    Args:
        load_dir: Root directory that holds the checkpoints.
        path_load_tag: Sub-directory of ``load_dir`` to load from.
        num_experts: Forwarded to :func:`get_checkpoint_names` to select the
            file layout.

    Returns:
        The object returned by ``torch.load`` for the model checkpoint file.
    """
    largest_group_name = mpu.get_max_expert_size_name()
    expp_rank = mpu.get_expert_parallel_rank(largest_group_name)

    # Only the model file is needed here; the optimizer path is ignored.
    model_checkpoint_name, _ = get_checkpoint_names(
        load_dir,
        path_load_tag=path_load_tag,
        num_experts=num_experts,
        expp_rank=expp_rank)

    logger.info(f'Loading model checkpoint from {model_checkpoint_name}')
    # NOTE: torch.load unpickles the file; only load trusted checkpoints.
    return torch.load(model_checkpoint_name, map_location='cpu')


def load_checkpoint(model,
                    load_dir,
                    num_experts=None,
                    strict=True,
                    path_load_tag='model',
                    load_ds_ckpts=True):
    """Load base and expert weights from ``load_dir`` into ``model``.

    Args:
        model: Model to load into; ``torchDDP`` / ``Float16Module`` wrappers
            are stripped with ``unwrap_model`` first.
        load_dir: Root directory that holds the checkpoints.
        num_experts: Forwarded to the base-checkpoint lookup to select the
            file layout.
        strict: Forwarded to ``model.load_state_dict``.
        path_load_tag: Sub-directory of ``load_dir`` holding the base
            checkpoint.
        load_ds_ckpts: When true the weights are read from the checkpoint's
            ``'module'`` entry, otherwise from its ``'model'`` entry.
    """
    model = unwrap_model(model, (torchDDP, Float16Module))

    model_state_dict = _load_base_checkpoint(
        load_dir, path_load_tag=path_load_tag, num_experts=num_experts)
    assert model_state_dict is not None

    # Pick the weights entry once; it is used for both the expert merge and
    # the final load.
    weights_key = 'module' if load_ds_ckpts else 'model'
    weights = model_state_dict[weights_key]

    # Expert weights are merged into ``weights`` in place before loading.
    load_moe_checkpoint(model, weights, load_dir)
    model.load_state_dict(weights, strict=strict)

    if torch.distributed.is_initialized():
        torch.distributed.barrier()


def load_moe_checkpoint(model, state_dict, load_dir):
    """Merge the per-expert state files of this rank into ``state_dict``.

    For every ``MoE`` module found in ``model`` (in ``named_modules`` order)
    the state files of its local experts are loaded, their keys are renamed
    from the global expert id to the local expert id, and the result is
    merged into ``state_dict`` in place.

    Args:
        model: Model whose ``MoE`` sub-modules determine which expert files
            are read.
        state_dict: Dictionary updated in place with the expert weights.
        load_dir: Root directory that holds the checkpoints.
    """
    moe_str_prefix = '.deepspeed_moe.experts.deepspeed_experts.'
    moe_layer_id = 0

    for _, module in model.named_modules():
        if not isinstance(module, MoE):
            continue

        num_local_experts = module.num_local_experts
        expp_rank = mpu.get_expert_parallel_rank(module.expert_group_name)

        for local_expert_id in range(num_local_experts):
            global_expert_id = expp_rank * num_local_experts + local_expert_id
            moe_load_path = _get_expert_ckpt_name(load_dir, moe_layer_id,
                                                  global_expert_id)
            logger.info(f'Loading expert states from {moe_load_path}')
            # NOTE: torch.load unpickles the file; only load trusted
            # checkpoints.
            expert_state_dict = torch.load(
                moe_load_path, map_location=torch.device('cpu'))

            # Rename global -> local expert ids. The trailing '.' anchors the
            # match on the whole id, so id 1 cannot match inside id 12.
            global_token = f'{moe_str_prefix}{global_expert_id}.'
            local_token = f'{moe_str_prefix}{local_expert_id}.'
            state_dict.update({
                key.replace(global_token, local_token): value
                for key, value in expert_state_dict.items()
            })

        moe_layer_id += 1