megatron_hook.py 6.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175
  1. import os
  2. import shutil
  3. import torch
  4. from megatron_util import mpu
  5. from modelscope.metainfo import Hooks
  6. from modelscope.trainers import EpochBasedTrainer
  7. from modelscope.trainers.hooks.builder import HOOKS
  8. from modelscope.trainers.hooks.checkpoint.checkpoint_hook import (
  9. BestCkptSaverHook, CheckpointHook, CheckpointProcessor)
  10. from modelscope.trainers.hooks.checkpoint.load_checkpoint_hook import \
  11. LoadCheckpointHook
  12. from modelscope.trainers.hooks.hook import Hook
  13. from modelscope.utils.checkpoint import load_checkpoint, save_checkpoint
  14. from modelscope.utils.constant import DistributedParallelType
  15. from modelscope.utils.device import create_device
  16. from modelscope.utils.logger import get_logger
  17. from modelscope.utils.megatron_utils import is_megatron_initialized
  18. from modelscope.utils.torch_utils import get_local_rank
  class MpuProcessor(CheckpointProcessor):
      """Checkpoint processor for Megatron (mpu) tensor-model-parallel training.

      Trainer state (optimizer, lr scheduler, meta) is saved without the model
      state_dict; the model weights of the current tensor-parallel rank are
      saved to a separate ``mp_rank_XX_model_states.pt`` shard file.
      """

      # Sub-directory of ``output_dir`` that receives the model shard files.
      _BIN_FILE_DIR = 'model'

      def rank_name(self):
          """Return the rank suffix used in the trainer-state file name.

          Returns ``''`` when the tensor-parallel world size is 1, or when mpu
          raises ImportError/AssertionError while being queried; otherwise
          ``'_mp_rank_XX'`` with the zero-padded tensor-parallel rank.
          """
          # TODO
          try:
              tp_world_size = mpu.get_tensor_model_parallel_world_size()
              if tp_world_size == 1:
                  return ''
              mp_rank = mpu.get_tensor_model_parallel_rank()
              return '_mp_rank_{:02d}'.format(mp_rank)
          except (ImportError, AssertionError):
              return ''

      def get_bin_filename(self):
          """Return the model-shard file name for this tensor-parallel rank."""
          mp_rank = mpu.get_tensor_model_parallel_rank()
          return 'mp_rank_{:02d}_model_states.pt'.format(mp_rank)

      def should_save_on_rank(self, trainer):
          """Return True when this process should write checkpoints.

          Every process saves when torch.distributed is not initialized;
          otherwise only data-parallel rank 0 saves.
          """
          # TODO
          return (not torch.distributed.is_initialized()
                  or mpu.get_data_parallel_rank() == 0)

      def prepare_output(self, trainer, output_dir):
          """Prepare ``output_dir`` for saving.

          Delegates to ``CheckpointProcessor.copy_files_and_dump_config`` with
          the trainer config, then makes sure the model sub-directory exists.
          """
          CheckpointProcessor.copy_files_and_dump_config(
              trainer, output_dir, trainer.cfg, self._BIN_FILE_DIR)
          os.makedirs(
              os.path.join(output_dir, self._BIN_FILE_DIR), exist_ok=True)

      def _mpu_train_state_file(self, checkpoint_path_prefix):
          """Return the trainer-state file path for ``checkpoint_path_prefix``."""
          return (checkpoint_path_prefix + self.rank_name()
                  + CheckpointProcessor.TRAINER_STATE_SUFFIX)

      def _mpu_prefixed_bin_file(self, checkpoint_path_prefix):
          """Return ``<dirname>/<basename>_<bin file>`` for this rank's shard."""
          save_dir = os.path.dirname(checkpoint_path_prefix)
          prefix = os.path.basename(checkpoint_path_prefix)
          return os.path.join(save_dir, prefix + '_' + self.get_bin_filename())

      def save_checkpoints(self,
                           trainer,
                           checkpoint_path_prefix,
                           output_dir,
                           meta=None,
                           save_optimizers=True):
          """Save trainer state and this rank's model shard.

          Args:
              trainer: The trainer being checkpointed.
              checkpoint_path_prefix: Path prefix for the checkpoint files.
              output_dir: Output directory; the shard is also placed (hard
                  link, or copy as a fallback) in its ``model`` sub-directory.
              meta: Optional meta info stored with the trainer state.
              save_optimizers: Whether to store optimizer and lr scheduler
                  state in the trainer-state file.
          """
          model = trainer.unwrap_module(trainer.model)
          # Trainer-state file: optimizer/scheduler/meta only, no model weights.
          save_checkpoint(
              model,
              self._mpu_train_state_file(checkpoint_path_prefix),
              trainer.optimizer if save_optimizers else None,
              trainer.lr_scheduler if save_optimizers else None,
              meta=meta,
              with_model=False)

          # Model weights for this tensor-parallel rank, next to the prefix.
          bin_file = self.get_bin_filename()
          src_file = self._mpu_prefixed_bin_file(checkpoint_path_prefix)
          save_checkpoint(model, src_file, with_meta=False)

          # Expose the shard in output_dir/model, replacing any stale file.
          dest_file = os.path.join(output_dir, self._BIN_FILE_DIR, bin_file)
          if os.path.isfile(dest_file):
              os.unlink(dest_file)
          try:
              # A hard link avoids storing the weights twice on disk.
              os.link(src_file, dest_file)
          except OSError as e:
              get_logger().error(
                  f'Link {src_file} to {dest_file} error: {e}, '
                  'changing to copy the bin file, this may cause more space usage.'
              )
              shutil.copyfile(src_file, dest_file)

      def remove_checkpoints(self, trainer, checkpoint_path_prefix):
          """Delete the trainer-state file and model shard for a prefix.

          Files that do not exist are skipped.
          """
          train_state_file = self._mpu_train_state_file(checkpoint_path_prefix)
          if os.path.isfile(train_state_file):
              os.remove(train_state_file)
          prefixed_bin_file = self._mpu_prefixed_bin_file(checkpoint_path_prefix)
          if os.path.isfile(prefixed_bin_file):
              os.remove(prefixed_bin_file)

      def load_checkpoints(self, checkpoint_path_prefix, trainer, load_all_state,
                           strict):
          """Load this rank's model shard and, for a prefix, the trainer state.

          Args:
              checkpoint_path_prefix: Either a directory that directly contains
                  ``get_bin_filename()``, or a checkpoint path prefix as used by
                  ``save_checkpoints``.
              trainer: The trainer whose (unwrapped) model is loaded.
              load_all_state: Forwarded to
                  ``LoadCheckpointHook.load_trainer_state``.
              strict: Accepted for interface compatibility; not used here.

          Returns:
              The meta returned by ``load_trainer_state`` when loading from a
              prefix, or ``None`` when loading from a directory (which has no
              trainer-state file).
          """
          model = trainer.unwrap_module(trainer.model)
          # Previously unassigned on the directory branch, so `return meta`
          # could raise UnboundLocalError there.
          meta = None
          if os.path.isdir(checkpoint_path_prefix):
              model_file = os.path.join(checkpoint_path_prefix,
                                        self.get_bin_filename())
          else:
              meta = LoadCheckpointHook.load_trainer_state(
                  trainer, self._mpu_train_state_file(checkpoint_path_prefix),
                  load_all_state)
              model_file = self._mpu_prefixed_bin_file(checkpoint_path_prefix)
          load_checkpoint(model_file, model, None, None)
          return meta
  @HOOKS.register_module(module_name=Hooks.MegatronHook)
  class MegatronHook(Hook):
      """Hook that prepares a trainer for Megatron model-parallel training.

      It installs an ``MpuProcessor`` on the checkpoint-related hooks, moves the
      model to the local CUDA device, records the mpu parallel groups on the
      trainer, and wraps the model with ``trainer.to_parallel`` at most once.
      """

      _BIN_FILE_DIR = 'model'

      def __init__(self):
          # Whether trainer.model has already been wrapped by to_parallel.
          self.wrapped = False

      def register_processor(self, trainer: EpochBasedTrainer):
          """Install one shared MpuProcessor on the checkpoint-related hooks.

          For each of CheckpointHook, BestCkptSaverHook and LoadCheckpointHook,
          the first registered instance (if any) receives the processor unless
          it already uses an MpuProcessor.
          """
          processor = MpuProcessor()
          for hook_cls in (CheckpointHook, BestCkptSaverHook,
                           LoadCheckpointHook):
              hooks = trainer.get_hook(hook_cls)
              if len(hooks) > 0 and not isinstance(hooks[0].processor,
                                                   MpuProcessor):
                  hooks[0].set_processor(processor)

      def after_init(self, trainer):
          """Move the model to this process's GPU and record the mpu groups.

          Raises:
              AssertionError: If Megatron has not been initialized.
          """
          # Explicit raise instead of `assert` so the check survives `python -O`;
          # AssertionError is kept so callers see the same exception type.
          if not is_megatron_initialized():
              raise AssertionError(
                  'Megatron must be initialized before MegatronHook.after_init')
          local_rank = get_local_rank()
          trainer.device = create_device(f'cuda:{local_rank}')
          trainer.model.to(trainer.device)
          trainer.parallel_groups[
              DistributedParallelType.DP] = mpu.get_data_parallel_group()
          trainer.parallel_groups[
              DistributedParallelType.TP] = mpu.get_tensor_model_parallel_group()
          trainer.parallel_groups[
              DistributedParallelType.PP] = mpu.get_pipeline_model_parallel_group()

      def before_run(self, trainer):
          """Wrap the model before training starts."""
          self.wrap_module(trainer)

      def before_val(self, trainer):
          """Wrap the model before validation (no-op if already wrapped)."""
          self.wrap_module(trainer)

      def wrap_module(self, trainer):
          """Wrap ``trainer.model`` via ``trainer.to_parallel`` once, if distributed."""
          if trainer._dist and not self.wrapped:
              trainer.model = trainer.to_parallel(trainer.model)
              self.wrapped = True