| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131 |
- # Copyright 2021-2022 The Alibaba PAI Team Authors.
- # Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved.
- #
- # Licensed under the Apache License, Version 2.0 (the "License");
- # you may not use this file except in compliance with the License.
- # You may obtain a copy of the License at
- #
- # http://www.apache.org/licenses/LICENSE-2.0
- #
- # Unless required by applicable law or agreed to in writing, software
- # distributed under the License is distributed on an "AS IS" BASIS,
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- # See the License for the specific language governing permissions and
- # limitations under the License.
- import os
- import torch
- from megatron_util import mpu
- from megatron_util.model import Float16Module
- from megatron_util.utils import unwrap_model
- from torch.nn.parallel import DistributedDataParallel as torchDDP
- from .configuration import logger
- from .moe.layer import MoE
def get_checkpoint_names(checkpoints_path,
                         path_load_tag,
                         num_experts,
                         tensor_rank=None,
                         expp_rank=None):
    """Build the model and optimizer checkpoint paths for this rank.

    Args:
        checkpoints_path: Root directory that holds the checkpoints.
        path_load_tag: Sub-directory (tag) below ``checkpoints_path``.
        num_experts: Expert counts; only the first entry is inspected. A
            plain integer is accepted as well, and ``None`` or an empty
            sequence is treated as "no experts".
        tensor_rank: Tensor-model-parallel rank. When ``None`` it is taken
            from ``mpu.get_tensor_model_parallel_rank()``.
        expp_rank: Expert-parallel rank. It is only interpolated into the
            optimizer file name of the expert (MoE) layout.

    Returns:
        A ``(model_name, optim_name)`` tuple. When the first expert count is
        positive the two paths differ (``model_rng.pt`` inside the
        ``mp_rank_XX`` directory, and a per-expert-rank
        ``*_optim_states.pt`` file next to that directory); otherwise both
        entries are the same ``model_optim_rng.pt`` path.
    """
    if tensor_rank is None:
        tensor_rank = mpu.get_tensor_model_parallel_rank()

    # Normalise ``num_experts`` so that the default ``None`` forwarded by the
    # loaders (or an empty/int value) no longer fails on ``num_experts[0]``.
    if num_experts is None:
        first_expert_count = 0
    elif isinstance(num_experts, int):
        first_expert_count = num_experts
    else:
        first_expert_count = num_experts[0] if len(num_experts) > 0 else 0

    tag_dir = os.path.join(checkpoints_path, path_load_tag)
    common_path = os.path.join(tag_dir, f'mp_rank_{tensor_rank:02d}')

    if first_expert_count > 0:
        model_name = os.path.join(common_path, 'model_rng.pt')
        # Optimizer states of the expert layout live directly under the tag
        # directory, one file per (expert rank, tensor rank) pair.
        optim_name = os.path.join(
            tag_dir,
            f'expp_rank_{expp_rank}_mp_rank_{tensor_rank:02d}_optim_states.pt')
    else:
        model_name = optim_name = os.path.join(common_path,
                                               'model_optim_rng.pt')
    return model_name, optim_name
def _get_expert_ckpt_name(checkpoints_path, layer_id, expert_id):
    """Return the checkpoint path of one expert for this tensor rank.

    The file is located in the ``model`` sub-directory of
    ``checkpoints_path`` and its name encodes the MoE layer index, the
    expert index and the tensor-model-parallel rank reported by ``mpu``.
    """
    tensor_rank = mpu.get_tensor_model_parallel_rank()
    file_name = (f'layer_{layer_id}_expert_{expert_id}'
                 f'_mp_rank_{tensor_rank:02d}_model_states.pt')
    return os.path.join(checkpoints_path, 'model', file_name)
def _load_base_checkpoint(load_dir, path_load_tag=None, num_experts=None):
    """Load this rank's base model checkpoint from ``load_dir`` onto the CPU.

    The expert-parallel rank of the group named by
    ``mpu.get_max_expert_size_name()`` is used to resolve the checkpoint
    file names; only the model file is read here.

    Args:
        load_dir: Root checkpoint directory.
        path_load_tag: Sub-directory below ``load_dir``; forwarded to
            ``get_checkpoint_names``.
        num_experts: Expert counts; forwarded to ``get_checkpoint_names``.

    Returns:
        The object deserialized by ``torch.load`` from the model checkpoint
        file, with tensors mapped to the CPU.
    """
    largest_group_name = mpu.get_max_expert_size_name()
    expp_rank = mpu.get_expert_parallel_rank(largest_group_name)
    # Only the model path is needed here; the optimizer path is discarded.
    model_checkpoint_name, _ = get_checkpoint_names(
        load_dir,
        path_load_tag=path_load_tag,
        num_experts=num_experts,
        expp_rank=expp_rank)
    logger.info(f'Loading model checkpoint from {model_checkpoint_name}')
    # NOTE: torch.load unpickles the file, so only checkpoints from trusted
    # sources should be loaded.
    return torch.load(model_checkpoint_name, map_location='cpu')
def load_checkpoint(model,
                    load_dir,
                    num_experts=None,
                    strict=True,
                    path_load_tag='model',
                    load_ds_ckpts=True):
    """Restore the weights of ``model`` from the checkpoint in ``load_dir``.

    The model is first unwrapped from ``torchDDP`` / ``Float16Module``. The
    base checkpoint is then loaded, the per-expert weights are merged into
    its state dict by ``load_moe_checkpoint``, and the result is applied via
    ``load_state_dict``. When ``torch.distributed`` is initialized, a
    barrier is issued at the end.

    Args:
        model: The (possibly wrapped) model to restore.
        load_dir: Root checkpoint directory.
        num_experts: Expert counts, forwarded to the base checkpoint loader.
        strict: Forwarded to ``load_state_dict``.
        path_load_tag: Sub-directory below ``load_dir`` holding the base
            checkpoint.
        load_ds_ckpts: Selects the key of the weights inside the checkpoint:
            ``'module'`` when true, ``'model'`` otherwise.
    """
    unwrapped = unwrap_model(model, (torchDDP, Float16Module))

    checkpoint = _load_base_checkpoint(
        load_dir, path_load_tag=path_load_tag, num_experts=num_experts)
    assert checkpoint is not None

    weights_key = 'module' if load_ds_ckpts else 'model'
    weights = checkpoint[weights_key]

    # Expert weights are merged into ``weights`` in place before applying it.
    load_moe_checkpoint(unwrapped, weights, load_dir)
    unwrapped.load_state_dict(weights, strict=strict)

    if torch.distributed.is_initialized():
        torch.distributed.barrier()
def load_moe_checkpoint(model, state_dict, load_dir):
    """Merge the per-expert checkpoint files into ``state_dict`` in place.

    Every ``MoE`` sub-module of ``model`` is visited in ``named_modules()``
    order and numbered consecutively. For each of its local experts the
    matching expert file is loaded onto the CPU, the global expert index in
    its keys is rewritten to the local index, and the entries are added to
    ``state_dict``.

    Args:
        model: Model whose ``MoE`` sub-modules determine which expert files
            are read.
        state_dict: Mapping updated in place with the expert entries.
        load_dir: Root checkpoint directory holding the expert files.
    """
    expert_key_prefix = '.deepspeed_moe.experts.deepspeed_experts.'
    layer_index = 0

    for _, submodule in model.named_modules():
        if not isinstance(submodule, MoE):
            continue

        expert_group = submodule.expert_group_name
        experts_per_rank = submodule.num_local_experts
        rank_in_group = mpu.get_expert_parallel_rank(expert_group)

        for local_id in range(experts_per_rank):
            global_id = rank_in_group * experts_per_rank + local_id
            expert_path = _get_expert_ckpt_name(load_dir, layer_index,
                                                global_id)
            logger.info(f'Loading expert states from {expert_path}')
            expert_weights = torch.load(
                expert_path, map_location=torch.device('cpu'))

            # Rename keys from the global expert index to the local one.
            global_marker = f'{expert_key_prefix}{global_id}'
            local_marker = f'{expert_key_prefix}{local_id}'
            for old_key in tuple(expert_weights):
                new_key = old_key.replace(global_marker, local_marker)
                expert_weights[new_key] = expert_weights.pop(old_key)

            state_dict.update(expert_weights)

        # Only MoE modules advance the layer counter used in file names.
        layer_index += 1
|