megatron_utils.py 7.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208
  1. # Copyright (c) Alibaba, Inc. and its affiliates.
  2. import os
  3. import shutil
  4. from typing import Dict, List, Union
  5. import torch
  6. from torch import nn
  7. from modelscope.utils.logger import get_logger
  8. from modelscope.utils.torch_utils import is_master
  9. logger = get_logger()
  10. _DEFAULT_CFG_WITH_MODEL_TYPE = {
  11. 'gpt-moe': {
  12. 'version': 'moe',
  13. 'world_size': 8
  14. },
  15. 'plug': {
  16. 'version': 'v1',
  17. 'world_size': 8,
  18. 'tensor_model_parallel_size': 8,
  19. 'seed': 1234
  20. },
  21. 'mglm-text-summarization': {
  22. 'version': 'v1',
  23. 'seed': 1234
  24. },
  25. }
  26. _CHECKPOINT_FORMAT = 'mp_rank_XX_model_states.pt'
  27. _IS_MEGATRON_INITIALIZED = False
  28. def init_megatron_util(megatron_cfg=None, model_dir=None, **kwargs):
  29. """Initialize megatron_util environment for megatron_based model.
  30. If argument `megatron_cfg` is not specified, then the megatorn_cfg will be load
  31. from configuration.json file in the model_dir.
  32. Args:
  33. megatron_cfg (Dict, optional): Megatron Config will be send to megatron_util.
  34. model_dir (str, optional): The model path for configuration. Defaults to None.
  35. """
  36. from modelscope.utils.hub import read_config
  37. from megatron_util import initialize_megatron
  38. assert not (megatron_cfg is None and model_dir is None), \
  39. 'cfg and model_dir cannot both be None when initializing megatron_util'
  40. if megatron_cfg is None:
  41. cfg = read_config(model_dir)
  42. try:
  43. megatron_cfg = cfg.megatron
  44. except AttributeError:
  45. try:
  46. model_type = cfg.model.type
  47. except AttributeError:
  48. # Fit models without model type, such as mglm
  49. model_type = cfg.pipeline.type
  50. megatron_cfg = _DEFAULT_CFG_WITH_MODEL_TYPE[model_type] \
  51. if model_type in _DEFAULT_CFG_WITH_MODEL_TYPE else {}
  52. megatron_cfg.update(kwargs)
  53. initialize_megatron(megatron_cfg)
  54. global _IS_MEGATRON_INITIALIZED
  55. _IS_MEGATRON_INITIALIZED = True
  56. def is_megatron_initialized() -> bool:
  57. return _IS_MEGATRON_INITIALIZED
  58. def convert_megatron_checkpoint(
  59. model: nn.Module, checkpoint_dir: Union[str, bytes, os.PathLike],
  60. target_dir: Union[str, bytes, os.PathLike]) -> None:
  61. """Split or Merge checkpoint for megatron_based model.
  62. Args:
  63. model (nn.Module): Any megatron_based model.
  64. checkpoint_dir (Union[str, bytes, os.PathLike]): The save path of origin checkpoint.
  65. target_dir (Union[str, bytes, os.PathLike]): The target path of new checkpoint.
  66. """
  67. def log_master(information: str):
  68. if is_master():
  69. logger.info(information)
  70. if os.path.exists(os.path.join(checkpoint_dir, 'model')):
  71. checkpoint_dir = os.path.join(checkpoint_dir, 'model')
  72. origin_num_partitions = len(os.listdir(checkpoint_dir))
  73. target_num_partitions = int(os.getenv('WORLD_SIZE'))
  74. _check_origin_dir(checkpoint_dir)
  75. _check_target_num_partitions(target_num_partitions)
  76. log_master(
  77. f'origin_num_partitions: {origin_num_partitions}, target_num_partitions: {target_num_partitions}'
  78. )
  79. if origin_num_partitions < target_num_partitions:
  80. os.makedirs(target_dir, exist_ok=True)
  81. state_dict = _split_checkpoint(
  82. model, checkpoint_dir,
  83. target_num_partitions // origin_num_partitions)
  84. _save_converted_checkpoint(state_dict, target_dir)
  85. log_master('Split checkpoints succeeded.')
  86. elif origin_num_partitions > target_num_partitions:
  87. os.makedirs(target_dir, exist_ok=True)
  88. state_dict = _merge_checkpoint(
  89. model, checkpoint_dir,
  90. origin_num_partitions // target_num_partitions)
  91. _save_converted_checkpoint(state_dict, target_dir)
  92. log_master('Merge checkpoints succeeded.')
  93. else:
  94. shutil.copytree(checkpoint_dir, target_dir)
  95. log_master('Copy checkpoints succeeded.')
  96. def _check_origin_dir(origin_dir: Union[str, bytes, os.PathLike]) -> None:
  97. filenames = os.listdir(origin_dir)
  98. assert len(filenames) & (
  99. len(filenames) - 1) == 0, 'The number of files must be a power of 2!'
  100. for i in range(len(filenames)):
  101. checkpoint_name = _CHECKPOINT_FORMAT.replace('XX', f'{i:02d}')
  102. assert checkpoint_name in filenames, \
  103. f'Can not find {checkpoint_name} file!'
  104. def _check_target_num_partitions(num_partitions: int) -> None:
  105. assert num_partitions & (num_partitions - 1) == 0, \
  106. 'The number of target partitions must be a power of 2!'
  107. def _split_checkpoint(model: nn.Module, checkpoint_dir: Union[str, bytes,
  108. os.PathLike],
  109. num_partitions: int) -> Dict[str, torch.Tensor]:
  110. target_rank = int(os.getenv('RANK'))
  111. origin_rank = target_rank // num_partitions
  112. state_dict = _load_by_rank(checkpoint_dir, origin_rank)
  113. target_state_dict = {}
  114. for name, parameter in model.named_parameters():
  115. dim = _get_diff_dim(parameter, state_dict[name])
  116. if dim == -1:
  117. target_state_dict[name] = state_dict[name]
  118. continue
  119. partitions_list = _split_tensor(state_dict[name], num_partitions, dim)
  120. target_state_dict[name] = partitions_list[target_rank
  121. % num_partitions].clone()
  122. return target_state_dict
  123. def _merge_checkpoint(model: nn.Module, checkpoint_dir: Union[str, bytes,
  124. os.PathLike],
  125. num_partitions: int) -> Dict[str, torch.Tensor]:
  126. target_rank = int(os.getenv('RANK'))
  127. origin_rank_list = [
  128. target_rank * num_partitions + i for i in range(num_partitions)
  129. ]
  130. state_dict_list = [
  131. _load_by_rank(checkpoint_dir, i) for i in origin_rank_list
  132. ]
  133. target_state_dict = {}
  134. for name, parameter in model.named_parameters():
  135. dim = _get_diff_dim(parameter, state_dict_list[0][name])
  136. if dim == -1:
  137. target_state_dict[name] = state_dict_list[0][name]
  138. continue
  139. target_state_dict[name] = torch.cat(
  140. [state_dict[name] for state_dict in state_dict_list],
  141. dim=dim).clone()
  142. return target_state_dict
  143. def _save_converted_checkpoint(
  144. state_dict: Dict[str, torch.Tensor],
  145. target_dir: Union[str, bytes, os.PathLike]) -> None:
  146. target_rank = int(os.getenv('RANK'))
  147. target_name = _CHECKPOINT_FORMAT.replace('XX', f'{target_rank:02d}')
  148. torch.save(state_dict, os.path.join(target_dir, target_name))
  149. def _get_diff_dim(tensor1: torch.Tensor, tensor2: torch.Tensor) -> int:
  150. for i, (s1, s2) in enumerate(zip(tensor1.shape, tensor2.shape)):
  151. if s1 != s2:
  152. return i
  153. return -1
  154. def _load_by_rank(checkpoint_dir: Union[str, bytes, os.PathLike],
  155. rank: int) -> Dict[str, torch.Tensor]:
  156. checkpoint_name = _CHECKPOINT_FORMAT.replace('XX', f'{rank:02d}')
  157. state_dict = torch.load(
  158. os.path.join(checkpoint_dir, checkpoint_name),
  159. map_location=lambda storage, loc: storage,
  160. weights_only=True)
  161. return state_dict['module'] if 'module' in state_dict else state_dict
  162. def _split_tensor(tensor: torch.Tensor, num_partitions: int,
  163. partition_dim: int) -> List[torch.Tensor]:
  164. from megatron_util import mpu
  165. per_partition_size = mpu.utils.divide(
  166. tensor.size(partition_dim), num_partitions)
  167. partitions_list = torch.split(
  168. tensor, per_partition_size, dim=partition_dim)
  169. return partitions_list