load_checkpoint.py 4.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120
  1. # Copyright 2021-2022 The Alibaba DAMO NLP Team Authors.
  2. # Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
  3. #
  4. # Licensed under the Apache License, Version 2.0 (the "License");
  5. # you may not use this file except in compliance with the License.
  6. # You may obtain a copy of the License at
  7. #
  8. # http://www.apache.org/licenses/LICENSE-2.0
  9. #
  10. # Unless required by applicable law or agreed to in writing, software
  11. # distributed under the License is distributed on an "AS IS" BASIS,
  12. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. # See the License for the specific language governing permissions and
  14. # limitations under the License.
  15. import os
  16. import torch
  def load_checkpoint(model,
                      load_dir,
                      tag,
                      load_module_strict=True,
                      load_optimizer_states=True,
                      load_lr_scheduler_states=True):
      r"""Load a training checkpoint into ``model``.

      Arguments:
          model: Required. The engine to restore. Besides what
              ``_load_checkpoint`` uses, it must provide
              ``zero_optimization()`` and ``_load_zero_checkpoint()``.
          load_dir: Required. Directory to load the checkpoint from.
          tag: Required. Checkpoint tag used as a unique identifier for the
              checkpoint, e.g. the global step.
          load_module_strict: Optional. Whether the keys in the module's
              state_dict must exactly match the keys in the checkpoint.
          load_optimizer_states: Optional. Whether to load the training
              optimizer states (e.g. Adam's momentum and variance).
          load_lr_scheduler_states: Optional. Whether to load the
              learning-rate scheduler states from the checkpoint.

      Returns:
          A ``(load_path, client_state)`` tuple. ``load_path`` is the path of
          the loaded checkpoint file. ``client_state`` holds the remaining
          checkpoint entries, which the client code uses to restore its own
          training state. Both are ``None`` if the checkpoint file does not
          exist.
      """
      load_path, client_states = _load_checkpoint(
          model,
          load_dir,
          tag,
          load_module_strict=load_module_strict,
          load_optimizer_states=load_optimizer_states,
          load_lr_scheduler_states=load_lr_scheduler_states)

      # When ZeRO is enabled, _load_checkpoint does not touch the optimizer,
      # so its state is restored here through _load_zero_checkpoint instead.
      # There is nothing to restore if no checkpoint file was found.
      restore_zero_optimizer = (load_optimizer_states
                                and model.zero_optimization()
                                and load_path is not None)
      if restore_zero_optimizer:
          model._load_zero_checkpoint(
              load_dir, tag, load_optimizer_states=load_optimizer_states)

      return load_path, client_states
  def _get_ckpt_name(mp_rank, checkpoints_path, tag):
      """Build the path of the model-states file for one model-parallel rank.

      The path has the form
      ``<checkpoints_path>/<tag>/mp_rank_XX_model_states.pt``, where ``XX`` is
      ``mp_rank`` zero-padded to two digits.
      """
      rank_file = 'mp_rank_{:02d}'.format(mp_rank) + '_model_states.pt'
      tag_dir = str(tag)
      return os.path.join(checkpoints_path, tag_dir, rank_file)
  def pre_load(mp_rank, load_dir, tag=''):
      """Load the saved model-states file for ``mp_rank`` from ``load_dir``/``tag``.

      The file is read with ``torch.load(..., weights_only=True)``. If the
      loaded object contains a ``'module'`` entry, that entry is returned.
      Otherwise the whole loaded object is returned unchanged.
      """
      ckpt_path = _get_ckpt_name(mp_rank, load_dir, tag)
      state = torch.load(
          ckpt_path,
          # Return each storage as deserialized, i.e. keep tensors on CPU
          # rather than moving them to the device recorded in the file.
          map_location=lambda storage, loc: storage,
          weights_only=True)
      if 'module' in state:
          return state['module']
      return state
  def _load_checkpoint(model,
                       load_dir,
                       tag,
                       load_module_strict=True,
                       load_optimizer_states=True,
                       load_lr_scheduler_states=True):
      """Restore the module, optimizer, scheduler and engine counters of ``model``.

      Arguments:
          model: engine whose ``_get_ckpt_name`` locates the checkpoint file
              and whose module, optimizer, lr_scheduler and step counters are
              restored in place.
          load_dir: directory containing the checkpoint.
          tag: checkpoint tag, e.g. the global step.
          load_module_strict: forwarded as ``strict`` to
              ``model.load_module_state_dict``.
          load_optimizer_states: whether to restore the optimizer state. The
              optimizer is never touched here when
              ``model.zero_optimization()`` is true.
          load_lr_scheduler_states: whether to restore the scheduler state.
              Only applies when ``model.lr_scheduler`` is not None.

      Returns:
          ``(load_path, client_state)``. ``client_state`` contains every
          checkpoint entry whose key is not in the engine-owned key list below.
          Returns ``(None, None)`` when the checkpoint file does not exist.
      """
      ckpt_path = model._get_ckpt_name(load_dir, tag)
      if not os.path.exists(ckpt_path):
          return None, None

      ckpt = torch.load(
          ckpt_path,
          map_location=lambda storage, loc: storage,
          weights_only=True)

      model.load_module_state_dict(
          state_dict=ckpt['module'], strict=load_module_strict)

      # With ZeRO enabled, the optimizer state is not restored in this function.
      if not model.zero_optimization() and load_optimizer_states:
          if model.fp16_enabled():
              # The fp16 path forwards the load_optimizer_states flag to the
              # optimizer's load_state_dict.
              model.optimizer.load_state_dict(
                  ckpt['optimizer'],
                  load_optimizer_states=load_optimizer_states)
          else:
              model.optimizer.load_state_dict(ckpt['optimizer'])

      if load_lr_scheduler_states and model.lr_scheduler is not None:
          model.lr_scheduler.load_state_dict(ckpt['lr_scheduler'])

      model.csr_tensor_module_names = ckpt['csr_tensor_module_names']
      model.global_steps = ckpt['global_steps']
      # Checkpoints without 'global_samples' fall back to
      # global_steps * train_batch_size().
      model.global_samples = ckpt.get(
          'global_samples', model.global_steps * model.train_batch_size())
      model.skipped_steps = ckpt['skipped_steps']
      model.loaded_checkpoint_mp_world_size = ckpt['mp_world_size']
      model.loaded_checkpoint_dp_world_size = ckpt['dp_world_size']

      # Keys consumed by the engine itself. Note that 'global_samples' is not
      # listed, so when present it is also passed through in client_state.
      engine_keys = (
          'module', 'optimizer', 'lr_scheduler', 'csr_tensor_module_names',
          'skipped_steps', 'global_steps', 'dp_world_size', 'mp_world_size',
      )
      client_state = {
          name: entry
          for name, entry in ckpt.items()
          if name not in engine_keys
      }
      return ckpt_path, client_state