checkpointing.py 14 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331
  1. # Copyright 2022 The HuggingFace Team. All rights reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. import random
  15. from pathlib import Path
  16. from typing import Optional
  17. import numpy as np
  18. import torch
  19. from safetensors.torch import load_model
  20. from .utils import (
  21. MODEL_NAME,
  22. OPTIMIZER_NAME,
  23. RNG_STATE_NAME,
  24. SAFE_MODEL_NAME,
  25. SAFE_WEIGHTS_NAME,
  26. SAMPLER_NAME,
  27. SCALER_NAME,
  28. SCHEDULER_NAME,
  29. WEIGHTS_NAME,
  30. get_pretty_name,
  31. is_cuda_available,
  32. is_hpu_available,
  33. is_mlu_available,
  34. is_musa_available,
  35. is_sdaa_available,
  36. is_torch_version,
  37. is_torch_xla_available,
  38. is_xpu_available,
  39. load,
  40. save,
  41. )
  42. if is_torch_version(">=", "2.4.0"):
  43. from torch.amp import GradScaler
  44. else:
  45. from torch.cuda.amp import GradScaler
  46. if is_torch_xla_available():
  47. import torch_xla.core.xla_model as xm
  48. from .logging import get_logger
  49. from .state import PartialState
  50. logger = get_logger(__name__)
  51. def save_accelerator_state(
  52. output_dir: str,
  53. model_states: list[dict],
  54. optimizers: list,
  55. schedulers: list,
  56. dataloaders: list,
  57. process_index: int,
  58. step: int,
  59. scaler: Optional[GradScaler] = None,
  60. save_on_each_node: bool = False,
  61. safe_serialization: bool = True,
  62. ):
  63. """
  64. Saves the current states of the models, optimizers, scaler, and RNG generators to a given directory.
  65. <Tip>
  66. If `safe_serialization` is `True`, models will be saved with `safetensors` while the rest are saved using native
  67. `pickle`.
  68. </Tip>
  69. Args:
  70. output_dir (`str` or `os.PathLike`):
  71. The name of the folder to save all relevant weights and states.
  72. model_states (`List[torch.nn.Module]`):
  73. A list of model states
  74. optimizers (`List[torch.optim.Optimizer]`):
  75. A list of optimizer instances
  76. schedulers (`List[torch.optim.lr_scheduler._LRScheduler]`):
  77. A list of learning rate schedulers
  78. dataloaders (`List[torch.utils.data.DataLoader]`):
  79. A list of dataloader instances to save their sampler states
  80. process_index (`int`):
  81. The current process index in the Accelerator state
  82. step (`int`):
  83. The current step in the internal step tracker
  84. scaler (`torch.amp.GradScaler`, *optional*):
  85. An optional gradient scaler instance to save;
  86. save_on_each_node (`bool`, *optional*):
  87. Whether to save on every node, or only the main node.
  88. safe_serialization (`bool`, *optional*, defaults to `True`):
  89. Whether to save the model using `safetensors` or the traditional PyTorch way (that uses `pickle`).
  90. """
  91. output_dir = Path(output_dir)
  92. # Model states
  93. for i, state in enumerate(model_states):
  94. weights_name = WEIGHTS_NAME if not safe_serialization else SAFE_WEIGHTS_NAME
  95. if i > 0:
  96. weights_name = weights_name.replace(".", f"_{i}.")
  97. output_model_file = output_dir.joinpath(weights_name)
  98. save(state, output_model_file, save_on_each_node=save_on_each_node, safe_serialization=safe_serialization)
  99. logger.info(f"Model weights saved in {output_model_file}")
  100. # Optimizer states
  101. for i, opt in enumerate(optimizers):
  102. state = opt.state_dict()
  103. optimizer_name = f"{OPTIMIZER_NAME}.bin" if i == 0 else f"{OPTIMIZER_NAME}_{i}.bin"
  104. output_optimizer_file = output_dir.joinpath(optimizer_name)
  105. save(state, output_optimizer_file, save_on_each_node=save_on_each_node, safe_serialization=False)
  106. logger.info(f"Optimizer state saved in {output_optimizer_file}")
  107. # Scheduler states
  108. for i, scheduler in enumerate(schedulers):
  109. state = scheduler.state_dict()
  110. scheduler_name = f"{SCHEDULER_NAME}.bin" if i == 0 else f"{SCHEDULER_NAME}_{i}.bin"
  111. output_scheduler_file = output_dir.joinpath(scheduler_name)
  112. save(state, output_scheduler_file, save_on_each_node=save_on_each_node, safe_serialization=False)
  113. logger.info(f"Scheduler state saved in {output_scheduler_file}")
  114. # DataLoader states
  115. for i, dataloader in enumerate(dataloaders):
  116. sampler_name = f"{SAMPLER_NAME}.bin" if i == 0 else f"{SAMPLER_NAME}_{i}.bin"
  117. output_sampler_file = output_dir.joinpath(sampler_name)
  118. # Only save if we have our custom sampler
  119. from .data_loader import IterableDatasetShard, SeedableRandomSampler
  120. if isinstance(dataloader.dataset, IterableDatasetShard):
  121. sampler = dataloader.get_sampler()
  122. if isinstance(sampler, SeedableRandomSampler):
  123. save(sampler, output_sampler_file, save_on_each_node=save_on_each_node, safe_serialization=False)
  124. if getattr(dataloader, "use_stateful_dataloader", False):
  125. dataloader_state_dict_name = "dl_state_dict.bin" if i == 0 else f"dl_state_dict_{i}.bin"
  126. output_dataloader_state_dict_file = output_dir.joinpath(dataloader_state_dict_name)
  127. state_dict = dataloader.state_dict()
  128. torch.save(state_dict, output_dataloader_state_dict_file)
  129. logger.info(f"Sampler state for dataloader {i} saved in {output_sampler_file}")
  130. # GradScaler state
  131. if scaler is not None:
  132. state = scaler.state_dict()
  133. output_scaler_file = output_dir.joinpath(SCALER_NAME)
  134. torch.save(state, output_scaler_file)
  135. logger.info(f"Gradient scaler state saved in {output_scaler_file}")
  136. # Random number generator states
  137. states = {}
  138. states_name = f"{RNG_STATE_NAME}_{process_index}.pkl"
  139. states["step"] = step
  140. states["random_state"] = random.getstate()
  141. states["numpy_random_seed"] = np.random.get_state()
  142. states["torch_manual_seed"] = torch.get_rng_state()
  143. if is_xpu_available():
  144. states["torch_xpu_manual_seed"] = torch.xpu.get_rng_state_all()
  145. if is_mlu_available():
  146. states["torch_mlu_manual_seed"] = torch.mlu.get_rng_state_all()
  147. elif is_sdaa_available():
  148. states["torch_sdaa_manual_seed"] = torch.sdaa.get_rng_state_all()
  149. elif is_musa_available():
  150. states["torch_musa_manual_seed"] = torch.musa.get_rng_state_all()
  151. if is_hpu_available():
  152. states["torch_hpu_manual_seed"] = torch.hpu.get_rng_state_all()
  153. if is_cuda_available():
  154. states["torch_cuda_manual_seed"] = torch.cuda.get_rng_state_all()
  155. if is_torch_xla_available():
  156. states["xm_seed"] = xm.get_rng_state()
  157. output_states_file = output_dir.joinpath(states_name)
  158. torch.save(states, output_states_file)
  159. logger.info(f"Random states saved in {output_states_file}")
  160. return output_dir
  161. def load_accelerator_state(
  162. input_dir,
  163. models,
  164. optimizers,
  165. schedulers,
  166. dataloaders,
  167. process_index,
  168. scaler=None,
  169. map_location=None,
  170. load_kwargs=None,
  171. **load_model_func_kwargs,
  172. ):
  173. """
  174. Loads states of the models, optimizers, scaler, and RNG generators from a given directory.
  175. Args:
  176. input_dir (`str` or `os.PathLike`):
  177. The name of the folder to load all relevant weights and states.
  178. models (`List[torch.nn.Module]`):
  179. A list of model instances
  180. optimizers (`List[torch.optim.Optimizer]`):
  181. A list of optimizer instances
  182. schedulers (`List[torch.optim.lr_scheduler._LRScheduler]`):
  183. A list of learning rate schedulers
  184. process_index (`int`):
  185. The current process index in the Accelerator state
  186. scaler (`torch.amp.GradScaler`, *optional*):
  187. An optional *GradScaler* instance to load
  188. map_location (`str`, *optional*):
  189. What device to load the optimizer state onto. Should be one of either "cpu" or "on_device".
  190. load_kwargs (`dict`, *optional*):
  191. Additional arguments that can be passed to the `load` function.
  192. load_model_func_kwargs (`dict`, *optional*):
  193. Additional arguments that can be passed to the model's `load_state_dict` method.
  194. Returns:
  195. `dict`: Contains the `Accelerator` attributes to override while loading the state.
  196. """
  197. # stores the `Accelerator` attributes to override
  198. override_attributes = dict()
  199. if map_location not in [None, "cpu", "on_device"]:
  200. raise TypeError(
  201. "Unsupported optimizer map location passed, please choose one of `None`, `'cpu'`, or `'on_device'`"
  202. )
  203. if map_location is None:
  204. map_location = "cpu"
  205. elif map_location == "on_device":
  206. map_location = PartialState().device
  207. if load_kwargs is None:
  208. load_kwargs = {}
  209. input_dir = Path(input_dir)
  210. # Model states
  211. for i, model in enumerate(models):
  212. ending = f"_{i}" if i > 0 else ""
  213. input_model_file = input_dir.joinpath(f"{SAFE_MODEL_NAME}{ending}.safetensors")
  214. if input_model_file.exists():
  215. load_model(model, input_model_file, device=str(map_location), **load_model_func_kwargs)
  216. else:
  217. # Load with torch
  218. input_model_file = input_dir.joinpath(f"{MODEL_NAME}{ending}.bin")
  219. state_dict = load(input_model_file, map_location=map_location)
  220. model.load_state_dict(state_dict, **load_model_func_kwargs)
  221. logger.info("All model weights loaded successfully")
  222. # Optimizer states
  223. for i, opt in enumerate(optimizers):
  224. optimizer_name = f"{OPTIMIZER_NAME}.bin" if i == 0 else f"{OPTIMIZER_NAME}_{i}.bin"
  225. input_optimizer_file = input_dir.joinpath(optimizer_name)
  226. optimizer_state = load(input_optimizer_file, map_location=map_location, **load_kwargs)
  227. optimizers[i].load_state_dict(optimizer_state)
  228. logger.info("All optimizer states loaded successfully")
  229. # Scheduler states
  230. for i, scheduler in enumerate(schedulers):
  231. scheduler_name = f"{SCHEDULER_NAME}.bin" if i == 0 else f"{SCHEDULER_NAME}_{i}.bin"
  232. input_scheduler_file = input_dir.joinpath(scheduler_name)
  233. scheduler_state = load(input_scheduler_file, **load_kwargs)
  234. scheduler.load_state_dict(scheduler_state)
  235. logger.info("All scheduler states loaded successfully")
  236. for i, dataloader in enumerate(dataloaders):
  237. sampler_name = f"{SAMPLER_NAME}.bin" if i == 0 else f"{SAMPLER_NAME}_{i}.bin"
  238. input_sampler_file = input_dir.joinpath(sampler_name)
  239. # Only load if we have our custom sampler
  240. from .data_loader import IterableDatasetShard, SeedableRandomSampler
  241. if isinstance(dataloader.dataset, IterableDatasetShard):
  242. sampler = dataloader.get_sampler()
  243. if isinstance(sampler, SeedableRandomSampler):
  244. sampler = dataloader.set_sampler(load(input_sampler_file))
  245. if getattr(dataloader, "use_stateful_dataloader", False):
  246. dataloader_state_dict_name = "dl_state_dict.bin" if i == 0 else f"dl_state_dict_{i}.bin"
  247. input_dataloader_state_dict_file = input_dir.joinpath(dataloader_state_dict_name)
  248. if input_dataloader_state_dict_file.exists():
  249. state_dict = load(input_dataloader_state_dict_file, **load_kwargs)
  250. dataloader.load_state_dict(state_dict)
  251. logger.info("All dataloader sampler states loaded successfully")
  252. # GradScaler state
  253. if scaler is not None:
  254. input_scaler_file = input_dir.joinpath(SCALER_NAME)
  255. scaler_state = load(input_scaler_file)
  256. scaler.load_state_dict(scaler_state)
  257. logger.info("GradScaler state loaded successfully")
  258. # Random states
  259. try:
  260. states = load(input_dir.joinpath(f"{RNG_STATE_NAME}_{process_index}.pkl"))
  261. if "step" in states:
  262. override_attributes["step"] = states["step"]
  263. random.setstate(states["random_state"])
  264. np.random.set_state(states["numpy_random_seed"])
  265. torch.set_rng_state(states["torch_manual_seed"])
  266. if is_xpu_available():
  267. torch.xpu.set_rng_state_all(states["torch_xpu_manual_seed"])
  268. if is_mlu_available():
  269. torch.mlu.set_rng_state_all(states["torch_mlu_manual_seed"])
  270. elif is_sdaa_available():
  271. torch.sdaa.set_rng_state_all(states["torch_sdaa_manual_seed"])
  272. elif is_musa_available():
  273. torch.musa.set_rng_state_all(states["torch_musa_manual_seed"])
  274. else:
  275. torch.cuda.set_rng_state_all(states["torch_cuda_manual_seed"])
  276. if is_torch_xla_available():
  277. xm.set_rng_state(states["xm_seed"])
  278. logger.info("All random states loaded successfully")
  279. except Exception:
  280. logger.info("Could not load random states")
  281. return override_attributes
  282. def save_custom_state(obj, path, index: int = 0, save_on_each_node: bool = False):
  283. """
  284. Saves the state of `obj` to `{path}/custom_checkpoint_{index}.pkl`
  285. """
  286. # Should this be the right way to get a qual_name type value from `obj`?
  287. save_location = Path(path) / f"custom_checkpoint_{index}.pkl"
  288. logger.info(f"Saving the state of {get_pretty_name(obj)} to {save_location}")
  289. save(obj.state_dict(), save_location, save_on_each_node=save_on_each_node)
  290. def load_custom_state(obj, path, index: int = 0):
  291. """
  292. Loads the state of `obj` at `{path}/custom_checkpoint_{index}.pkl`. Will always set `weights_only=False` when
  293. loading the state.
  294. """
  295. load_location = f"{path}/custom_checkpoint_{index}.pkl"
  296. logger.info(f"Loading the state of {get_pretty_name(obj)} from {load_location}")
  297. obj.load_state_dict(load(load_location, map_location="cpu", weights_only=False))