deepspeed_hook.py 17 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404
  1. # Copyright (c) Alibaba, Inc. and its affiliates.
  2. # Copyright 2020 The HuggingFace Team. All rights reserved.
  3. import math
  4. import os
  5. import shutil
  6. from functools import partialmethod
  7. import deepspeed
  8. import torch
  9. from deepspeed import DeepSpeedEngine
  10. from megatron_util import mpu, print_rank_0
  11. from transformers.deepspeed import HfTrainerDeepSpeedConfig
  12. from modelscope.metainfo import Hooks
  13. from modelscope.trainers.hooks import LoadCheckpointHook
  14. from modelscope.trainers.hooks.builder import HOOKS
  15. from modelscope.trainers.hooks.checkpoint.checkpoint_hook import (
  16. BestCkptSaverHook, CheckpointHook)
  17. from modelscope.trainers.hooks.checkpoint.checkpoint_processor import \
  18. CheckpointProcessor
  19. from modelscope.trainers.hooks.hook import Hook
  20. from modelscope.trainers.hooks.lr_scheduler_hook import (LrSchedulerHook,
  21. LrSchedulerProcessor)
  22. from modelscope.trainers.hooks.optimizer.base import (OptimizerHook,
  23. OptimizerProcessor)
  24. from modelscope.trainers.hooks.priority import Priority
  25. from modelscope.utils.checkpoint import save_checkpoint
  26. from modelscope.utils.constant import DistributedParallelType
  27. from modelscope.utils.device import create_device
  28. from modelscope.utils.logger import get_logger
  29. from modelscope.utils.torch_utils import (get_dist_info, get_local_rank,
  30. init_dist)
  31. class DeepSpeedConfig(HfTrainerDeepSpeedConfig):
  32. """
  33. The `DeepSpeedConfig` object is meant to be created during `TrainingArguments` object creation and has the
  34. same lifespan as the latter.
  35. """
  36. def is_auto(self, ds_key_long):
  37. val = self.get_value(ds_key_long)
  38. if val is None:
  39. return False
  40. else:
  41. return val == 'auto'
  42. def trainer_config_finalize(self, args, model, num_training_steps):
  43. """
  44. This stage runs after we have the model and know num_training_steps.
  45. Now we can complete the configuration process.
  46. """
  47. # zero
  48. # deal with config keys that use `auto` value and rely on model's hidden_size
  49. hidden_size_based_keys = [
  50. 'zero_optimization.reduce_bucket_size',
  51. 'zero_optimization.stage3_prefetch_bucket_size',
  52. 'zero_optimization.stage3_param_persistence_threshold',
  53. ]
  54. hidden_size_auto_keys = [
  55. x for x in hidden_size_based_keys if self.is_auto(x)
  56. ]
  57. if len(hidden_size_auto_keys) > 0:
  58. if hasattr(model.config, 'hidden_size'):
  59. hidden_size = model.config.hidden_size
  60. elif hasattr(model.config, 'hidden_sizes'):
  61. # if there are many hidden sizes pick the largest one
  62. hidden_size = max(model.config.hidden_sizes)
  63. else:
  64. raise ValueError(
  65. "The model's config file has neither `hidden_size` nor `hidden_sizes` entry, "
  66. "therefore it's not possible to automatically fill out the following `auto` entries "
  67. f'in the DeepSpeed config file: {hidden_size_auto_keys}. You can fix that by replacing '
  68. '`auto` values for these keys with an integer value of your choice.'
  69. )
  70. self.fill_only('zero_optimization.reduce_bucket_size',
  71. hidden_size * hidden_size)
  72. if self.is_zero3():
  73. # automatically assign the optimal config values based on model config
  74. self.fill_only('zero_optimization.stage3_prefetch_bucket_size',
  75. 0.9 * hidden_size * hidden_size)
  76. self.fill_only(
  77. 'zero_optimization.stage3_param_persistence_threshold',
  78. 10 * hidden_size)
  79. # scheduler
  80. options = args.train.optimizer.get('options', {})
  81. warmup = options.get('warmup', {})
  82. warmup_steps = warmup.get('warmup_steps', 0)
  83. warmup_ratio = warmup.get('warmup_ratio', 0.0)
  84. warmup_steps = warmup_steps if warmup_steps > 0 else math.ceil(
  85. num_training_steps * warmup_ratio)
  86. self.fill_match('scheduler.params.total_num_steps', num_training_steps)
  87. self.fill_match('scheduler.params.warmup_num_steps', warmup_steps)
  88. if len(self.mismatches) > 0:
  89. mismatches = '\n'.join(self.mismatches)
  90. raise ValueError(
  91. 'Please correct the following DeepSpeed config values that mismatch TrainingArguments'
  92. f" values:\n{mismatches}\nThe easiest method is to set these DeepSpeed config values to 'auto'."
  93. )
  94. def deepspeed_optim_sched(trainer, hf_deepspeed_config, num_training_steps):
  95. config = hf_deepspeed_config.config
  96. optimizer = None
  97. if 'optimizer' not in config:
  98. if hf_deepspeed_config.is_offload():
  99. logger.info(
  100. 'Detected ZeRO Offload and non-DeepSpeed optimizers: This combination should work as long as the'
  101. ' custom optimizer has both CPU and GPU implementation (except LAMB)'
  102. )
  103. # ds supports Adam, OneBitAdam, and Lamb optimizers and can import other optimizers from torch.
  104. # But trainer uses AdamW by default.
  105. optimizer = trainer.optimizer
  106. # To use other optimizers requires voiding warranty with: `zero_allow_untested_optimizer`
  107. config['zero_allow_untested_optimizer'] = True
  108. lr_scheduler = None
  109. if 'scheduler' not in config:
  110. lr_scheduler = trainer.scheduler
  111. return optimizer, lr_scheduler
  112. class DeepspeedProcessor(CheckpointProcessor, LrSchedulerProcessor,
  113. OptimizerProcessor):
  114. _BIN_FILE_DIR = 'model'
  115. def rank_name(self):
  116. # TODO
  117. try:
  118. tp_world_size = mpu.get_tensor_model_parallel_world_size()
  119. if tp_world_size == 1:
  120. return ''
  121. mp_rank = mpu.get_tensor_model_parallel_rank()
  122. return '_mp_rank_{:02d}'.format(mp_rank)
  123. except (ImportError, AssertionError):
  124. return ''
  125. def get_bin_filename(self, with_mpu=True):
  126. if not with_mpu:
  127. return 'pytorch_model.bin'
  128. else:
  129. mp_rank = mpu.get_tensor_model_parallel_rank()
  130. rank = '{:02d}'.format(mp_rank)
  131. return f'mp_rank_{rank}_model_states.pt'
  132. def save_checkpoints(self,
  133. trainer,
  134. checkpoint_path_prefix,
  135. output_dir,
  136. meta=None,
  137. save_optimizers=True):
  138. model = trainer.unwrap_module(trainer.model)
  139. _train_state_file = checkpoint_path_prefix + self.rank_name(
  140. ) + CheckpointProcessor.TRAINER_STATE_SUFFIX
  141. # Save pth file without model state_dict
  142. save_checkpoint(
  143. model, _train_state_file, None, None, meta=meta, with_model=False)
  144. save_dir = os.path.dirname(checkpoint_path_prefix)
  145. prefix = os.path.basename(checkpoint_path_prefix)
  146. with_mpu = not mpu.is_unitialized()
  147. bin_file = self.get_bin_filename(with_mpu)
  148. src_file = os.path.join(checkpoint_path_prefix, bin_file)
  149. if self.zero_stage == 3 or with_mpu:
  150. trainer.model.save_checkpoint(save_dir, prefix)
  151. else:
  152. save_checkpoint(
  153. model, src_file, None, None, meta=None, with_meta=False)
  154. if self.zero_stage == 3:
  155. return
  156. if with_mpu:
  157. dest_file = os.path.join(output_dir, self._BIN_FILE_DIR, bin_file)
  158. else:
  159. dest_file = os.path.join(output_dir, bin_file)
  160. if os.path.isfile(dest_file):
  161. os.unlink(dest_file)
  162. try:
  163. os.link(src_file, dest_file)
  164. except OSError as e:
  165. get_logger().error(
  166. f'Link {src_file} to {dest_file} error: {e}, '
  167. 'changing to copy the bin file, this may case more space usage.'
  168. )
  169. shutil.copyfile(src_file, dest_file)
  170. def remove_checkpoints(self, trainer, checkpoint_path_prefix):
  171. _train_state_file = checkpoint_path_prefix + self.rank_name(
  172. ) + CheckpointProcessor.TRAINER_STATE_SUFFIX
  173. if os.path.isfile(_train_state_file):
  174. os.remove(_train_state_file)
  175. shutil.rmtree(checkpoint_path_prefix, ignore_errors=True)
  176. def load_checkpoints(self, checkpoint_path_prefix, trainer, load_all_state,
  177. strict):
  178. assert os.path.isdir(checkpoint_path_prefix)
  179. path = os.path.dirname(checkpoint_path_prefix)
  180. tag = os.path.basename(checkpoint_path_prefix)
  181. meta = {}
  182. _train_state_file = checkpoint_path_prefix + self.rank_name(
  183. ) + CheckpointProcessor.TRAINER_STATE_SUFFIX
  184. if os.path.isfile(_train_state_file):
  185. meta = self.load_trainer_state(trainer, _train_state_file,
  186. load_all_state)
  187. if isinstance(trainer.model, DeepSpeedEngine):
  188. # DeepSpeedEngine is initialized
  189. trainer.model.load_checkpoint(
  190. path,
  191. tag,
  192. load_module_strict=strict,
  193. load_module_only=not load_all_state,
  194. )
  195. else:
  196. # in eval or prediction
  197. save_dir = checkpoint_path_prefix
  198. bin_file = self.get_bin_filename()
  199. model_file = os.path.join(save_dir, bin_file)
  200. checkpoint = torch.load(
  201. model_file, map_location=lambda storage, loc: storage)
  202. checkpoint = checkpoint['module']
  203. model_dict = trainer.unwrap_module(trainer.model).state_dict()
  204. for key in checkpoint:
  205. if key not in model_dict.keys():
  206. print_rank_0('Skip key: ' + key)
  207. else:
  208. print_rank_0('Loading key: ' + key)
  209. trainer.unwrap_module(trainer.model).load_state_dict(
  210. checkpoint, strict=strict)
  211. return meta
  212. def backward(self, trainer, loss_keys, cumulative_iters, grad_clip):
  213. # assert cumulative_iters == 1, 'DeepSpeed only support cumulative_iters=1'
  214. # The `trainer.model` here is actually a deepspeed engine object.
  215. # backward step
  216. for k in loss_keys:
  217. loss = trainer.train_outputs[k]
  218. trainer.model.backward(loss)
  219. # update parameters
  220. # Optimizer step for deepspeed must be called on every step regardless of
  221. # the value of gradient accumulation iters
  222. trainer.model.step()
  223. def initialize_optimizer(self, trainer):
  224. pass
  225. def step(self, trainer):
  226. pass
  227. def should_save_on_rank(self, trainer):
  228. return True
  229. def get_current_lr(self, trainer):
  230. if isinstance(trainer.optimizer, torch.optim.Optimizer) or isinstance(
  231. trainer.optimizer, deepspeed.DeepSpeedOptimizer):
  232. lr = [group['lr'] for group in trainer.optimizer.param_groups]
  233. elif isinstance(trainer.optimizer, dict):
  234. lr = dict()
  235. for name, optim in trainer.optimizer.items():
  236. lr[name] = [group['lr'] for group in optim.param_groups]
  237. else:
  238. raise RuntimeError(
  239. 'lr is not applicable because optimizer does not exist.')
  240. return lr
  241. @HOOKS.register_module(module_name=Hooks.DeepspeedHook)
  242. class DeepspeedHook(Hook):
  243. PRIORITY = Priority.VERY_HIGH
  244. def __init__(self,
  245. config=None,
  246. deepspeed_activation_checkpointing=True,
  247. save_zero_checkpoint=False,
  248. with_mpu=True,
  249. zero_stage=None):
  250. self.save_zero_checkpoint = save_zero_checkpoint
  251. self.deepspeed_activation_checkpointing = deepspeed_activation_checkpointing
  252. self.with_mpu = with_mpu
  253. self.deepspeed_config = config
  254. if zero_stage is not None:
  255. assert zero_stage in (0, 1, 2,
  256. 3), 'zero_stage must in (0, 1, 2, 3)!'
  257. self.zero_stage = zero_stage
  258. def register_processor(self, trainer):
  259. processor = DeepspeedProcessor()
  260. optimizer_hook = trainer.get_hook(OptimizerHook)
  261. if len(optimizer_hook) > 0 and not isinstance(
  262. optimizer_hook[0].processor, DeepspeedProcessor):
  263. optimizer_hook[0].set_processor(processor)
  264. ckpt_hook = trainer.get_hook(CheckpointHook)
  265. if len(ckpt_hook) > 0 and not isinstance(ckpt_hook[0].processor,
  266. DeepspeedProcessor):
  267. ckpt_hook[0].set_processor(processor)
  268. best_ckpt_hook = trainer.get_hook(BestCkptSaverHook)
  269. if len(best_ckpt_hook) > 0 and not isinstance(
  270. best_ckpt_hook[0].processor, DeepspeedProcessor):
  271. best_ckpt_hook[0].set_processor(processor)
  272. load_ckpt_hook = trainer.get_hook(LoadCheckpointHook)
  273. if len(load_ckpt_hook) > 0 and not isinstance(
  274. load_ckpt_hook[0].processor, DeepspeedProcessor):
  275. load_ckpt_hook[0].set_processor(processor)
  276. lr_scheduler_hook = trainer.get_hook(LrSchedulerHook)
  277. if len(lr_scheduler_hook) > 0 and not isinstance(
  278. lr_scheduler_hook[0].processor, DeepspeedProcessor):
  279. lr_scheduler_hook[0].set_processor(processor)
  280. self.processor = processor
  281. def prepare_args(self, args):
  282. args.per_device_train_batch_size = args.train.dataloader.get(
  283. 'batch_size_per_gpu', 4)
  284. args.max_grad_norm = args.train.get('clip_grad', 1.0)
  285. args.learning_rate = args.train.optimizer.get('lr', 2e-5)
  286. args.adam_beta1 = args.train.optimizer.get('adam_beta1', 0.9)
  287. args.adam_beta2 = args.train.optimizer.get('adam_beta2', 0.999)
  288. args.adam_epsilon = args.train.optimizer.get('adam_epsilon', 1e-8)
  289. args.weight_decay = args.train.optimizer.get('weight_decay', 0.0)
  290. args.fp16 = args.train.get('use_fp16', False)
  291. args.fp16_full_eval = args.train.get('use_fp16', False)
  292. args.fp16_backend = args.train.get('fp16_backend', 'amp')
  293. args.save_on_each_node = args.train.get('save_on_each_node', False)
  294. args.fp16_opt_level = args.train.get('fp16_opt_level', None)
  295. args.fp16_opt_level = next((item.get('opt_level', args.fp16_opt_level)
  296. for item in args.train.hooks
  297. if item['type'] == 'ApexAMPOptimizerHook'),
  298. args.fp16_opt_level)
  299. if not args.fp16_opt_level:
  300. args.fp16_opt_level = 'O1'
  301. args.bf16 = args.train.get('bf16', False)
  302. def get_deepspeed_config(self, trainer, args, max_steps):
  303. _, args.world_size = get_dist_info()
  304. self.prepare_args(args)
  305. if os.path.exists(self.deepspeed_config):
  306. deepspeed_config = self.deepspeed_config
  307. else:
  308. deepspeed_config = os.path.join(trainer.model_dir,
  309. self.deepspeed_config)
  310. if not os.path.exists(deepspeed_config):
  311. raise RuntimeError(
  312. f'No such DeepSpeed json config file: {self.deepspeed_config}.'
  313. )
  314. self.logger.info(f'Loading deepspeed config from {deepspeed_config}')
  315. ds_config = DeepSpeedConfig(deepspeed_config)
  316. ds_config.trainer_config_process(args)
  317. ds_config.trainer_config_finalize(args, trainer.model, max_steps)
  318. return ds_config
  319. def after_init(self, trainer):
  320. init_dist('pytorch')
  321. local_rank = get_local_rank()
  322. trainer.device = create_device(f'cuda:{local_rank}')
  323. trainer.model.to(trainer.device)
  324. trainer.parallel_groups[DistributedParallelType.DP] = None
  325. def before_val(self, trainer):
  326. pass
  327. def before_run(self, trainer):
  328. if not hasattr(trainer, 'logger'):
  329. self.logger = get_logger()
  330. else:
  331. self.logger = trainer.logger
  332. # deepspeed init
  333. args = trainer.cfg
  334. args.gradient_accumulation_steps = args.train.optimizer.get(
  335. 'options', {}).get('cumulative_iters', 1)
  336. num_update_steps_per_epoch = trainer.iters_per_epoch // args.gradient_accumulation_steps
  337. max_steps = math.ceil(trainer._max_epochs * num_update_steps_per_epoch)
  338. ds_config = self.get_deepspeed_config(trainer, args, max_steps)
  339. optimizer, lr_scheduler = deepspeed_optim_sched(
  340. trainer, ds_config, max_steps)
  341. config = ds_config.config
  342. if self.zero_stage is not None:
  343. config['zero_optimization']['stage'] = self.zero_stage
  344. self.processor.zero_stage = config['zero_optimization'].get('stage', 0)
  345. trainer.model, trainer.optimizer, _, trainer.lr_scheduler = deepspeed.initialize(
  346. model=trainer.model,
  347. optimizer=optimizer,
  348. config=config,
  349. lr_scheduler=lr_scheduler)