checkpointing.py 5.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131
  1. # Copyright 2021-2022 The Alibaba PAI Team Authors.
  2. # Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved.
  3. #
  4. # Licensed under the Apache License, Version 2.0 (the "License");
  5. # you may not use this file except in compliance with the License.
  6. # You may obtain a copy of the License at
  7. #
  8. # http://www.apache.org/licenses/LICENSE-2.0
  9. #
  10. # Unless required by applicable law or agreed to in writing, software
  11. # distributed under the License is distributed on an "AS IS" BASIS,
  12. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. # See the License for the specific language governing permissions and
  14. # limitations under the License.
  15. import os
  16. import torch
  17. from megatron_util import mpu
  18. from megatron_util.model import Float16Module
  19. from megatron_util.utils import unwrap_model
  20. from torch.nn.parallel import DistributedDataParallel as torchDDP
  21. from .configuration import logger
  22. from .moe.layer import MoE
  23. def get_checkpoint_names(checkpoints_path,
  24. path_load_tag,
  25. num_experts,
  26. tensor_rank=None,
  27. expp_rank=None):
  28. """Determine the directory name for this rank's checkpoint."""
  29. if tensor_rank is None:
  30. tensor_rank = mpu.get_tensor_model_parallel_rank()
  31. common_path = os.path.join(checkpoints_path, path_load_tag,
  32. f'mp_rank_{tensor_rank:02d}')
  33. if num_experts[0] > 0:
  34. model_name = os.path.join(common_path, 'model_rng.pt')
  35. optim_name = os.path.join(
  36. checkpoints_path, path_load_tag,
  37. f'expp_rank_{expp_rank}_mp_rank_{tensor_rank:02d}_optim_states.pt')
  38. else:
  39. model_name = optim_name = os.path.join(common_path,
  40. 'model_optim_rng.pt')
  41. return model_name, optim_name
  42. def _get_expert_ckpt_name(checkpoints_path, layer_id, expert_id):
  43. mp_rank = mpu.get_tensor_model_parallel_rank()
  44. ckpt_name = os.path.join(
  45. os.path.join(checkpoints_path, 'model'),
  46. f'layer_{layer_id}_expert_{expert_id}_mp_rank_{mp_rank:02d}_model_states.pt'
  47. )
  48. return ckpt_name
  49. def _load_base_checkpoint(load_dir, path_load_tag=None, num_experts=None):
  50. """ Load the base state_dict from the given directory
  51. If rank0 is true, just loads rank 0 checkpoint, ignoring arguments.
  52. """
  53. largest_group_name = mpu.get_max_expert_size_name()
  54. expp_rank = mpu.get_expert_parallel_rank(largest_group_name)
  55. checkpoint_names = get_checkpoint_names(
  56. load_dir,
  57. path_load_tag=path_load_tag,
  58. num_experts=num_experts,
  59. expp_rank=expp_rank)
  60. model_checkpoint_name, optim_checkpoint_name = checkpoint_names
  61. logger.info(f'Loading model checkpoint from {model_checkpoint_name}')
  62. model_state_dict = torch.load(model_checkpoint_name, map_location='cpu')
  63. return model_state_dict
  64. def load_checkpoint(model,
  65. load_dir,
  66. num_experts=None,
  67. strict=True,
  68. path_load_tag='model',
  69. load_ds_ckpts=True):
  70. model = unwrap_model(model, (torchDDP, Float16Module))
  71. model_state_dict = _load_base_checkpoint(
  72. load_dir, path_load_tag=path_load_tag, num_experts=num_experts)
  73. assert model_state_dict is not None
  74. if load_ds_ckpts:
  75. load_moe_checkpoint(model, model_state_dict['module'], load_dir)
  76. else:
  77. load_moe_checkpoint(model, model_state_dict['model'], load_dir)
  78. if load_ds_ckpts:
  79. model.load_state_dict(model_state_dict['module'], strict=strict)
  80. else:
  81. model.load_state_dict(model_state_dict['model'], strict=strict)
  82. if torch.distributed.is_initialized():
  83. torch.distributed.barrier()
  84. def load_moe_checkpoint(model, state_dict, load_dir):
  85. moe_layer_id = 0
  86. for n_module, module in model.named_modules():
  87. if isinstance(module, MoE): # and torch.distributed.get_rank() == 0:
  88. group_name = module.expert_group_name
  89. num_local_experts = module.num_local_experts
  90. expp_rank = mpu.get_expert_parallel_rank(group_name)
  91. # loop all local_experts
  92. for local_expert_id in range(num_local_experts):
  93. global_expert_id = expp_rank * num_local_experts + local_expert_id
  94. moe_load_path = _get_expert_ckpt_name(load_dir, moe_layer_id,
  95. global_expert_id)
  96. logger.info(f'Loading expert states from {moe_load_path}')
  97. expert_state_dict = torch.load(
  98. moe_load_path, map_location=torch.device('cpu'))
  99. # Updating global -> local expert ids
  100. moe_str_prefix = '.deepspeed_moe.experts.deepspeed_experts.'
  101. for key in list(expert_state_dict.keys()):
  102. local_key = key.replace(
  103. f'{moe_str_prefix}{global_expert_id}',
  104. f'{moe_str_prefix}{local_expert_id}')
  105. expert_state_dict[local_key] = expert_state_dict.pop(key)
  106. state_dict.update(expert_state_dict)
  107. moe_layer_id += 1