fsdp_utils.py 37 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829
  1. # Copyright 2023 The HuggingFace Team. All rights reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. import copy
  15. import functools
  16. import os
  17. import re
  18. import shutil
  19. import warnings
  20. from collections import defaultdict
  21. from collections.abc import Iterable
  22. from contextlib import nullcontext
  23. from pathlib import Path
  24. from typing import Callable, Union
  25. import torch
  26. from ..logging import get_logger
  27. from .constants import FSDP_MODEL_NAME, OPTIMIZER_NAME, SAFE_WEIGHTS_NAME, WEIGHTS_NAME
  28. from .dataclasses import get_module_class_from_name
  29. from .modeling import get_non_persistent_buffers, is_peft_model
  30. from .other import get_module_children_bottom_up, is_compiled_module, save
  31. from .versions import is_torch_version
# Module-level logger shared by every helper in this file.
logger = get_logger(__name__)
  33. def enable_fsdp_ram_efficient_loading():
  34. """
  35. Enables RAM efficient loading of Hugging Face models for FSDP in the environment.
  36. """
  37. # Sets values for `transformers.modeling_utils.is_fsdp_enabled`
  38. if "ACCELERATE_USE_FSDP" not in os.environ:
  39. os.environ["ACCELERATE_USE_FSDP"] = "True"
  40. os.environ["FSDP_CPU_RAM_EFFICIENT_LOADING"] = "True"
  41. def disable_fsdp_ram_efficient_loading():
  42. """
  43. Disables RAM efficient loading of Hugging Face models for FSDP in the environment.
  44. """
  45. os.environ["FSDP_CPU_RAM_EFFICIENT_LOADING"] = "False"
  46. def _get_model_state_dict(model, adapter_only=False, sd_options=None):
  47. if adapter_only and is_peft_model(model):
  48. from peft import get_peft_model_state_dict
  49. return get_peft_model_state_dict(model, adapter_name=model.active_adapter)
  50. # Invariant: `sd_options` is not None only for FSDP2
  51. if sd_options is not None:
  52. from torch.distributed.checkpoint.state_dict import get_model_state_dict
  53. return get_model_state_dict(model, options=sd_options)
  54. else:
  55. return model.state_dict()
  56. def _set_model_state_dict(model, state_dict, adapter_only=False, sd_options=None):
  57. if adapter_only and is_peft_model(model):
  58. from peft import set_peft_model_state_dict
  59. return set_peft_model_state_dict(model, state_dict, adapter_name=model.active_adapter)
  60. # Invariant: `sd_options` is not None only for FSDP2
  61. if sd_options is not None:
  62. from torch.distributed.checkpoint.state_dict import set_model_state_dict
  63. return set_model_state_dict(model, state_dict, options=sd_options)
  64. else:
  65. return model.load_state_dict(state_dict)
  66. def _prepare_sd_options(fsdp_plugin):
  67. sd_options = None
  68. # we use this only for FSDP2, as it requires torch >= 2.6.0 and this api requires torch >= 2.2.0
  69. if fsdp_plugin.fsdp_version == 2:
  70. from torch.distributed.checkpoint.state_dict import StateDictOptions
  71. from torch.distributed.fsdp.fully_sharded_data_parallel import StateDictType
  72. sd_options = StateDictOptions(
  73. full_state_dict=fsdp_plugin.state_dict_type == StateDictType.FULL_STATE_DICT,
  74. cpu_offload=getattr(fsdp_plugin.state_dict_config, "offload_to_cpu", False),
  75. broadcast_from_rank0=getattr(fsdp_plugin.state_dict_config, "rank0_only", False),
  76. )
  77. return sd_options
  78. def save_fsdp_model(fsdp_plugin, accelerator, model, output_dir, model_index=0, adapter_only=False):
  79. # Note: We import here to reduce import time from general modules, and isolate outside dependencies
  80. import torch.distributed.checkpoint as dist_cp
  81. from torch.distributed.checkpoint.default_planner import DefaultSavePlanner
  82. from torch.distributed.fsdp.fully_sharded_data_parallel import FullyShardedDataParallel as FSDP
  83. from torch.distributed.fsdp.fully_sharded_data_parallel import StateDictType
  84. os.makedirs(output_dir, exist_ok=True)
  85. if fsdp_plugin.state_dict_type == StateDictType.FULL_STATE_DICT:
  86. # FSDP raises error when single GPU is used with `offload_to_cpu=True` for FULL_STATE_DICT
  87. # so, only enable it when num_processes>1
  88. is_multi_process = accelerator.num_processes > 1
  89. fsdp_plugin.state_dict_config.offload_to_cpu = is_multi_process
  90. fsdp_plugin.state_dict_config.rank0_only = is_multi_process
  91. ctx = (
  92. FSDP.state_dict_type(
  93. model, fsdp_plugin.state_dict_type, fsdp_plugin.state_dict_config, fsdp_plugin.optim_state_dict_config
  94. )
  95. if fsdp_plugin.fsdp_version == 1
  96. else nullcontext()
  97. )
  98. sd_options = _prepare_sd_options(fsdp_plugin)
  99. with ctx:
  100. state_dict = _get_model_state_dict(model, adapter_only=adapter_only, sd_options=sd_options)
  101. if fsdp_plugin.state_dict_type == StateDictType.FULL_STATE_DICT:
  102. weights_name = f"{FSDP_MODEL_NAME}.bin" if model_index == 0 else f"{FSDP_MODEL_NAME}_{model_index}.bin"
  103. output_model_file = os.path.join(output_dir, weights_name)
  104. if accelerator.process_index == 0:
  105. logger.info(f"Saving model to {output_model_file}")
  106. torch.save(state_dict, output_model_file)
  107. logger.info(f"Model saved to {output_model_file}")
  108. # Invariant: `LOCAL_STATE_DICT` is never possible with `FSDP2`
  109. elif fsdp_plugin.state_dict_type == StateDictType.LOCAL_STATE_DICT:
  110. weights_name = (
  111. f"{FSDP_MODEL_NAME}_rank{accelerator.process_index}.bin"
  112. if model_index == 0
  113. else f"{FSDP_MODEL_NAME}_{model_index}_rank{accelerator.process_index}.bin"
  114. )
  115. output_model_file = os.path.join(output_dir, weights_name)
  116. logger.info(f"Saving model to {output_model_file}")
  117. torch.save(state_dict, output_model_file)
  118. logger.info(f"Model saved to {output_model_file}")
  119. elif fsdp_plugin.state_dict_type == StateDictType.SHARDED_STATE_DICT:
  120. ckpt_dir = os.path.join(output_dir, f"{FSDP_MODEL_NAME}_{model_index}")
  121. os.makedirs(ckpt_dir, exist_ok=True)
  122. logger.info(f"Saving model to {ckpt_dir}")
  123. state_dict = {"model": state_dict}
  124. dist_cp.save(
  125. state_dict=state_dict,
  126. storage_writer=dist_cp.FileSystemWriter(ckpt_dir),
  127. planner=DefaultSavePlanner(),
  128. )
  129. logger.info(f"Model saved to {ckpt_dir}")
def load_fsdp_model(fsdp_plugin, accelerator, model, input_dir, model_index=0, adapter_only=False):
    """
    Load FSDP model weights from `input_dir` into `model`, following `fsdp_plugin.state_dict_type`.

    Args:
        fsdp_plugin: Plugin describing the FSDP setup (version, state dict type and configs).
        accelerator: Provides process information and `wait_for_everyone`.
        model: The model to load the weights into.
        input_dir: Directory holding the checkpoint. For `SHARDED_STATE_DICT` it is used as-is when its path already
            contains `FSDP_MODEL_NAME`, otherwise the `{FSDP_MODEL_NAME}_{model_index}` sub-folder is used.
        model_index (`int`, *optional*, defaults to 0): Index inserted in file/folder names when non-zero.
        adapter_only (`bool`, *optional*, defaults to `False`): Forwarded to the state dict helpers.

    Returns:
        The result of `_set_model_state_dict`, or `None` when a non-zero rank returns early in the FSDP1
        `FULL_STATE_DICT` case with a model that is not (yet) an `FSDP` instance.

    Raises:
        ValueError: In that same early-return case, if `sync_module_states` is disabled under FSDP1.
    """
    # Note: We import here to reduce import time from general modules, and isolate outside dependencies
    import torch.distributed.checkpoint as dist_cp
    from torch.distributed.checkpoint.default_planner import DefaultLoadPlanner
    from torch.distributed.fsdp.fully_sharded_data_parallel import FullyShardedDataParallel as FSDP
    from torch.distributed.fsdp.fully_sharded_data_parallel import StateDictType

    # All processes synchronise before any of them starts reading
    accelerator.wait_for_everyone()
    if fsdp_plugin.state_dict_type == StateDictType.FULL_STATE_DICT:
        # FSDP raises error when single GPU is used with `offload_to_cpu=True` for FULL_STATE_DICT
        # so, only enable it when num_processes>1
        is_multi_process = accelerator.num_processes > 1
        # Side effect: the plugin's state dict config is mutated here
        fsdp_plugin.state_dict_config.offload_to_cpu = is_multi_process
        fsdp_plugin.state_dict_config.rank0_only = is_multi_process

    # Only FSDP1 needs the `state_dict_type` context; FSDP2 relies on `sd_options` instead
    ctx = (
        FSDP.state_dict_type(
            model, fsdp_plugin.state_dict_type, fsdp_plugin.state_dict_config, fsdp_plugin.optim_state_dict_config
        )
        if fsdp_plugin.fsdp_version == 1
        else nullcontext()
    )
    sd_options = _prepare_sd_options(fsdp_plugin)
    with ctx:
        if fsdp_plugin.state_dict_type == StateDictType.FULL_STATE_DICT:
            # FSDP1 with a not-yet-wrapped model: only process 0 loads the file, the other ranks return early
            if type(model) is not FSDP and accelerator.process_index != 0 and not accelerator.is_fsdp2:
                if not fsdp_plugin.sync_module_states and fsdp_plugin.fsdp_version == 1:
                    raise ValueError(
                        "Set the `sync_module_states` flag to `True` so that model states are synced across processes when "
                        "initializing FSDP object"
                    )
                return
            weights_name = f"{FSDP_MODEL_NAME}.bin" if model_index == 0 else f"{FSDP_MODEL_NAME}_{model_index}.bin"
            input_model_file = os.path.join(input_dir, weights_name)
            logger.info(f"Loading model from {input_model_file}")
            # we want an empty state dict for FSDP2 as we use `broadcast_from_rank0`
            load_model = not accelerator.is_fsdp2 or accelerator.is_main_process
            if load_model:
                state_dict = torch.load(input_model_file, weights_only=True)
            else:
                state_dict = {}
            logger.info(f"Model loaded from {input_model_file}")
        elif fsdp_plugin.state_dict_type == StateDictType.LOCAL_STATE_DICT:
            # Every rank reads back its own per-rank file
            weights_name = (
                f"{FSDP_MODEL_NAME}_rank{accelerator.process_index}.bin"
                if model_index == 0
                else f"{FSDP_MODEL_NAME}_{model_index}_rank{accelerator.process_index}.bin"
            )
            input_model_file = os.path.join(input_dir, weights_name)
            logger.info(f"Loading model from {input_model_file}")
            state_dict = torch.load(input_model_file, weights_only=True)
            logger.info(f"Model loaded from {input_model_file}")
        elif fsdp_plugin.state_dict_type == StateDictType.SHARDED_STATE_DICT:
            # Substring test on the path: `input_dir` is expected to be a `str` here
            ckpt_dir = (
                os.path.join(input_dir, f"{FSDP_MODEL_NAME}_{model_index}")
                if f"{FSDP_MODEL_NAME}" not in input_dir
                else input_dir
            )
            logger.info(f"Loading model from {ckpt_dir}")
            # The current state dict is passed to `dist_cp.load`, which fills it from the checkpoint
            state_dict = {"model": _get_model_state_dict(model, adapter_only=adapter_only, sd_options=sd_options)}
            dist_cp.load(
                state_dict=state_dict,
                storage_reader=dist_cp.FileSystemReader(ckpt_dir),
                planner=DefaultLoadPlanner(),
            )
            state_dict = state_dict["model"]
            logger.info(f"Model loaded from {ckpt_dir}")
        load_result = _set_model_state_dict(model, state_dict, adapter_only=adapter_only, sd_options=sd_options)
    return load_result
  197. def save_fsdp_optimizer(fsdp_plugin, accelerator, optimizer, model, output_dir, optimizer_index=0):
  198. # Note: We import here to reduce import time from general modules, and isolate outside dependencies
  199. import torch.distributed.checkpoint as dist_cp
  200. from torch.distributed.checkpoint.default_planner import DefaultSavePlanner
  201. from torch.distributed.fsdp.fully_sharded_data_parallel import FullyShardedDataParallel as FSDP
  202. from torch.distributed.fsdp.fully_sharded_data_parallel import StateDictType
  203. os.makedirs(output_dir, exist_ok=True)
  204. ctx = (
  205. FSDP.state_dict_type(
  206. model, fsdp_plugin.state_dict_type, fsdp_plugin.state_dict_config, fsdp_plugin.optim_state_dict_config
  207. )
  208. if fsdp_plugin.fsdp_version == 1
  209. else nullcontext()
  210. )
  211. sd_options = _prepare_sd_options(fsdp_plugin)
  212. with ctx:
  213. if fsdp_plugin.fsdp_version == 2:
  214. from torch.distributed.checkpoint.state_dict import get_optimizer_state_dict
  215. optim_state = get_optimizer_state_dict(model, optimizer, options=sd_options)
  216. else:
  217. optim_state = FSDP.optim_state_dict(model, optimizer)
  218. if fsdp_plugin.state_dict_type == StateDictType.FULL_STATE_DICT:
  219. if accelerator.process_index == 0:
  220. optim_state_name = (
  221. f"{OPTIMIZER_NAME}.bin" if optimizer_index == 0 else f"{OPTIMIZER_NAME}_{optimizer_index}.bin"
  222. )
  223. output_optimizer_file = os.path.join(output_dir, optim_state_name)
  224. logger.info(f"Saving Optimizer state to {output_optimizer_file}")
  225. torch.save(optim_state, output_optimizer_file)
  226. logger.info(f"Optimizer state saved in {output_optimizer_file}")
  227. else:
  228. ckpt_dir = os.path.join(output_dir, f"{OPTIMIZER_NAME}_{optimizer_index}")
  229. os.makedirs(ckpt_dir, exist_ok=True)
  230. logger.info(f"Saving Optimizer state to {ckpt_dir}")
  231. dist_cp.save(
  232. state_dict={"optimizer": optim_state},
  233. storage_writer=dist_cp.FileSystemWriter(ckpt_dir),
  234. planner=DefaultSavePlanner(),
  235. )
  236. logger.info(f"Optimizer state saved in {ckpt_dir}")
  237. def load_fsdp_optimizer(fsdp_plugin, accelerator, optimizer, model, input_dir, optimizer_index=0, adapter_only=False):
  238. # Note: We import here to reduce import time from general modules, and isolate outside dependencies
  239. import torch.distributed.checkpoint as dist_cp
  240. from torch.distributed.fsdp.fully_sharded_data_parallel import FullyShardedDataParallel as FSDP
  241. from torch.distributed.fsdp.fully_sharded_data_parallel import StateDictType
  242. accelerator.wait_for_everyone()
  243. ctx = (
  244. FSDP.state_dict_type(
  245. model, fsdp_plugin.state_dict_type, fsdp_plugin.state_dict_config, fsdp_plugin.optim_state_dict_config
  246. )
  247. if fsdp_plugin.fsdp_version == 1
  248. else nullcontext()
  249. )
  250. sd_options = _prepare_sd_options(fsdp_plugin)
  251. with ctx:
  252. if fsdp_plugin.state_dict_type == StateDictType.FULL_STATE_DICT:
  253. optim_state = None
  254. if accelerator.process_index == 0 or not fsdp_plugin.optim_state_dict_config.rank0_only:
  255. optimizer_name = (
  256. f"{OPTIMIZER_NAME}.bin" if optimizer_index == 0 else f"{OPTIMIZER_NAME}_{optimizer_index}.bin"
  257. )
  258. input_optimizer_file = os.path.join(input_dir, optimizer_name)
  259. logger.info(f"Loading Optimizer state from {input_optimizer_file}")
  260. optim_state = torch.load(input_optimizer_file, weights_only=True)
  261. logger.info(f"Optimizer state loaded from {input_optimizer_file}")
  262. else:
  263. ckpt_dir = (
  264. os.path.join(input_dir, f"{OPTIMIZER_NAME}_{optimizer_index}")
  265. if f"{OPTIMIZER_NAME}" not in input_dir
  266. else input_dir
  267. )
  268. logger.info(f"Loading Optimizer from {ckpt_dir}")
  269. optim_state = {"optimizer": optimizer.state_dict()}
  270. dist_cp.load(
  271. optim_state,
  272. checkpoint_id=ckpt_dir,
  273. storage_reader=dist_cp.FileSystemReader(ckpt_dir),
  274. )
  275. optim_state = optim_state["optimizer"]
  276. logger.info(f"Optimizer loaded from {ckpt_dir}")
  277. if fsdp_plugin.fsdp_version == 1:
  278. flattened_osd = FSDP.optim_state_dict_to_load(model=model, optim=optimizer, optim_state_dict=optim_state)
  279. optimizer.load_state_dict(flattened_osd)
  280. else:
  281. from torch.distributed.checkpoint.state_dict import set_optimizer_state_dict
  282. set_optimizer_state_dict(model, optimizer, optim_state, options=sd_options)
  283. def _distributed_checkpoint_to_merged_weights(checkpoint_dir: str, save_path: str, safe_serialization: bool = True):
  284. """
  285. Passthrough to `torch.distributed.checkpoint.format_utils.dcp_to_torch_save`
  286. Will save under `save_path` as either `model.safetensors` or `pytorch_model.bin`.
  287. """
  288. # Note: We import here to reduce import time from general modules, and isolate outside dependencies
  289. import torch.distributed.checkpoint as dist_cp
  290. import torch.distributed.checkpoint.format_utils as dist_cp_format_utils
  291. state_dict = {}
  292. save_path = Path(save_path)
  293. save_path.mkdir(exist_ok=True)
  294. dist_cp_format_utils._load_state_dict(
  295. state_dict,
  296. storage_reader=dist_cp.FileSystemReader(checkpoint_dir),
  297. planner=dist_cp_format_utils._EmptyStateDictLoadPlanner(),
  298. no_dist=True,
  299. )
  300. save_path = save_path / SAFE_WEIGHTS_NAME if safe_serialization else save_path / WEIGHTS_NAME
  301. # To handle if state is a dict like {model: {...}}
  302. if len(state_dict.keys()) == 1:
  303. state_dict = state_dict[list(state_dict)[0]]
  304. save(state_dict, save_path, safe_serialization=safe_serialization)
  305. return save_path
  306. def merge_fsdp_weights(
  307. checkpoint_dir: str, output_path: str, safe_serialization: bool = True, remove_checkpoint_dir: bool = False
  308. ):
  309. """
  310. Merge the weights from sharded FSDP model checkpoints into a single combined checkpoint. Should be used if
  311. `SHARDED_STATE_DICT` was used for the model. Weights will be saved to `{output_path}/model.safetensors` if
  312. `safe_serialization` else `pytorch_model.bin`.
  313. Note: this is a CPU-bound process.
  314. Args:
  315. checkpoint_dir (`str`):
  316. The directory containing the FSDP checkpoints (can be either the model or optimizer).
  317. output_path (`str`):
  318. The path to save the merged checkpoint.
  319. safe_serialization (`bool`, *optional*, defaults to `True`):
  320. Whether to save the merged weights with safetensors (recommended).
  321. remove_checkpoint_dir (`bool`, *optional*, defaults to `False`):
  322. Whether to remove the checkpoint directory after merging.
  323. """
  324. checkpoint_dir = Path(checkpoint_dir)
  325. from accelerate.state import PartialState
  326. if not is_torch_version(">=", "2.3.0"):
  327. raise ValueError("`merge_fsdp_weights` requires PyTorch >= 2.3.0`")
  328. # Verify that the checkpoint directory exists
  329. if not checkpoint_dir.exists():
  330. model_path_exists = (checkpoint_dir / "pytorch_model_fsdp_0").exists()
  331. optimizer_path_exists = (checkpoint_dir / "optimizer_0").exists()
  332. err = f"Tried to load from {checkpoint_dir} but couldn't find a valid metadata file."
  333. if model_path_exists and optimizer_path_exists:
  334. err += " However, potential model and optimizer checkpoint directories exist."
  335. err += f"Please pass in either {checkpoint_dir}/pytorch_model_fsdp_0 or {checkpoint_dir}/optimizer_0"
  336. err += "instead."
  337. elif model_path_exists:
  338. err += " However, a potential model checkpoint directory exists."
  339. err += f"Please try passing in {checkpoint_dir}/pytorch_model_fsdp_0 instead."
  340. elif optimizer_path_exists:
  341. err += " However, a potential optimizer checkpoint directory exists."
  342. err += f"Please try passing in {checkpoint_dir}/optimizer_0 instead."
  343. raise ValueError(err)
  344. # To setup `save` to work
  345. state = PartialState()
  346. if state.is_main_process:
  347. logger.info(f"Merging FSDP weights from {checkpoint_dir}")
  348. save_path = _distributed_checkpoint_to_merged_weights(checkpoint_dir, output_path, safe_serialization)
  349. logger.info(f"Successfully merged FSDP weights and saved to {save_path}")
  350. if remove_checkpoint_dir:
  351. logger.info(f"Removing old checkpoint directory {checkpoint_dir}")
  352. shutil.rmtree(checkpoint_dir)
  353. state.wait_for_everyone()
  354. def ensure_weights_retied(param_init_fn, model: torch.nn.Module, device: torch.device):
  355. _tied_names = getattr(model, "_tied_weights_keys", None)
  356. if not _tied_names:
  357. # if no tied names just passthrough
  358. return param_init_fn
  359. # get map of parameter instances to params.
  360. # - needed for replacement later
  361. _tied_params = {}
  362. for name in _tied_names:
  363. name = name.split(".")
  364. name, param_name = ".".join(name[:-1]), name[-1]
  365. mod = model.get_submodule(name)
  366. param = getattr(mod, param_name)
  367. _tied_params[id(param)] = None # placeholder for the param first
  368. # build param_init_fn for the case with tied params
  369. def param_init_fn_tied_param(module: torch.nn.Module):
  370. # track which params to tie
  371. # - usually only 1, but for completeness consider > 1
  372. params_to_tie = defaultdict(list)
  373. for n, param in module.named_parameters(recurse=False):
  374. if id(param) in _tied_params:
  375. params_to_tie[id(param)].append(n)
  376. # call the param init fn, which potentially re-allocates the
  377. # parameters
  378. module = param_init_fn(module)
  379. # search the parameters again and tie them up again
  380. for id_key, _param_names in params_to_tie.items():
  381. for param_name in _param_names:
  382. param = _tied_params[id_key]
  383. if param is None:
  384. # everything will be tied to the first time the
  385. # param is observed
  386. _tied_params[id_key] = getattr(module, param_name)
  387. else:
  388. setattr(module, param_name, param) # tie
  389. return module
  390. return param_init_fn_tied_param
def fsdp2_load_full_state_dict(accelerator, model: torch.nn.Module, full_sd: dict):
    """
    Loads the full state dict (could be only on rank 0) into the sharded model. This is done by broadcasting the
    parameters from rank 0 to all other ranks. This function modifies the model in-place.

    Args:
        accelerator (`Accelerator`): The accelerator instance
        model (`torch.nn.Module`):
            The model to load the state dict into, expected to be on meta device or a VRAM spike can occur
        full_sd (`dict`): The full state dict to load, can only be on rank 0

    Returns:
        `torch.nn.Module`: The same `model`, after `load_state_dict(..., assign=True)` with the sharded tensors.
    """
    import torch.distributed as dist
    from torch.distributed.tensor import DTensor, distribute_tensor

    # Model was previously copied to meta device
    meta_sharded_sd = model.state_dict()
    sharded_sd = {}

    # Rank 0 distributes the full state dict to other ranks
    def _infer_parameter_dtype(model, param_name, empty_param):
        # Returns `(to_contiguous, casting_dtype)` for the entry named `param_name`, based on the tensor currently
        # registered on the model under that name.
        try:
            old_param = model.get_parameter_or_buffer(param_name)
        except AttributeError:
            # Need this for LORA, as there some params are not *parameters* of sorts
            base_param_name, local_param_name = param_name.rsplit(".", 1)
            submodule = model.get_submodule(base_param_name)
            old_param = getattr(submodule, local_param_name)

        is_torch_e4m3fn_available = hasattr(torch, "float8_e4m3fn")
        casting_dtype = None
        is_param_float8_e4m3fn = is_torch_e4m3fn_available and empty_param.dtype == torch.float8_e4m3fn

        # Floating point tensors (except float8_e4m3fn) are cast to the dtype the model currently holds
        if empty_param.dtype.is_floating_point and not is_param_float8_e4m3fn:
            casting_dtype = old_param.dtype

        return old_param is not None and old_param.is_contiguous(), casting_dtype

    def _cast_and_contiguous(tensor, to_contiguous, dtype):
        # Applies the optional dtype cast, then the optional `.contiguous()` call
        if dtype is not None:
            tensor = tensor.to(dtype=dtype)
        if to_contiguous:
            tensor = tensor.contiguous()
        return tensor

    if accelerator.is_main_process:
        # Entries are paired positionally, so `full_sd` and the model's state dict must share the same ordering
        for (param_name, full_param), sharded_param in zip(full_sd.items(), meta_sharded_sd.values()):
            device_mesh = sharded_param.device_mesh
            full_param = full_param.detach().to(device_mesh.device_type)
            if isinstance(full_param, DTensor):
                # dist.broadcast() only supports torch.Tensor.
                # After prepare_tp(), model parameters may become DTensor.
                # To broadcast such a parameter, convert it to a local tensor first.
                full_param = full_param.to_local()
            dist.broadcast(full_param, src=0, group=dist.group.WORLD)
            sharded_tensor = distribute_tensor(full_param, device_mesh, sharded_param.placements)
            to_contiguous, casting_dtype = _infer_parameter_dtype(
                model,
                param_name,
                full_param,
            )
            sharded_tensor = _cast_and_contiguous(sharded_tensor, to_contiguous, casting_dtype)
            sharded_sd[param_name] = sharded_tensor
    # We need this else to have a matching `broadcast` for all of the ranks, else we deadlock
    else:
        for param_name, sharded_param in meta_sharded_sd.items():
            device_mesh = sharded_param.device_mesh
            # Receive buffer sized and typed after the model's own (meta) entry
            full_tensor = torch.empty(sharded_param.size(), device=device_mesh.device_type, dtype=sharded_param.dtype)
            dist.broadcast(full_tensor, src=0, group=dist.group.WORLD)
            sharded_tensor = distribute_tensor(full_tensor, device_mesh, sharded_param.placements)
            to_contiguous, casting_dtype = _infer_parameter_dtype(
                model,
                param_name,
                full_tensor,
            )
            sharded_tensor = _cast_and_contiguous(sharded_tensor, to_contiguous, casting_dtype)
            sharded_sd[param_name] = sharded_tensor

    # we set `assign=True` because our params are on meta device
    model.load_state_dict(sharded_sd, assign=True)
    return model
  462. def fsdp2_switch_optimizer_parameters(optimizer: torch.optim.Optimizer, mapping: dict):
  463. """
  464. Switches the parameters of the optimizer to new ones (sharded parameters in usual case). This function modifies the
  465. optimizer in-place.
  466. Args:
  467. optimizer (`torch.optim.Optimizer`): Optimizer instance which contains the original model parameters
  468. mapping (`dict`): Mapping from the original parameter (specified by `data_ptr`) to the sharded parameter
  469. Raises:
  470. KeyError:
  471. If a parameter in the optimizer couldn't be switched to its sharded version. This should never happen and
  472. indicates a bug. If we kept the original params instead of raising, the training wouldn't be numerically
  473. correct and weights wouldn't get updated.
  474. """
  475. from torch.distributed.tensor import DTensor
  476. accessor_mapping = {}
  477. accessor_mapping[DTensor] = "_local_tensor"
  478. try:
  479. for param_group in optimizer.param_groups:
  480. param_group["params"] = [mapping[p.data_ptr] for p in param_group["params"]]
  481. except KeyError:
  482. # This shouldn't ever happen, but we want to fail here else training wouldn't be numerically correct
  483. # This basically means that we're missing a mapping from the original parameter to the sharded parameter
  484. raise KeyError(
  485. "A parameter in the optimizer couldn't be switched to its sharded version. This breaks the training. Please raise an issue on GitHub."
  486. )
  487. def fsdp2_apply_ac(accelerator, model: torch.nn.Module):
  488. """
  489. Applies the activation checkpointing to the model.
  490. Args:
  491. accelerator (`Accelerator`): The accelerator instance
  492. model (`torch.nn.Module`): The model to apply the activation checkpointing to
  493. Returns:
  494. `torch.nn.Module`: The model with the activation checkpointing applied
  495. """
  496. from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import (
  497. checkpoint_wrapper,
  498. )
  499. auto_wrap_policy_func = fsdp2_prepare_auto_wrap_policy(accelerator.state.fsdp_plugin, model)
  500. for layer_name, layer in get_module_children_bottom_up(model, return_fqns=True)[:-1]:
  501. if len(layer_name.split(".")) > 1:
  502. parent_name, child_name = layer_name.rsplit(".", 1)
  503. else:
  504. parent_name = None
  505. child_name = layer_name
  506. parent_module = model.get_submodule(parent_name) if parent_name else model
  507. if auto_wrap_policy_func(parent_module):
  508. layer = checkpoint_wrapper(layer, preserve_rng_state=False)
  509. parent_module.register_module(child_name, layer)
  510. return model
  511. def fsdp2_prepare_model(accelerator, model: torch.nn.Module) -> torch.nn.Module:
  512. """Prepares the model for FSDP2 in-place. Also returns the model to avoid misuse of the original model.
  513. Args:
  514. accelerator (`Accelerator`): The accelerator instance
  515. model (`torch.nn.Module`): The model to prepare
  516. Returns:
  517. `torch.nn.Module`: Prepared model
  518. """
  519. from torch.distributed.fsdp import FSDPModule, MixedPrecisionPolicy, fully_shard
  520. is_type_fsdp = isinstance(model, FSDPModule) or (
  521. is_compiled_module(model) and isinstance(model._orig_mod, FSDPModule)
  522. )
  523. if is_type_fsdp:
  524. return model
  525. fsdp2_plugin = accelerator.state.fsdp_plugin
  526. fsdp2_plugin.set_auto_wrap_policy(model)
  527. original_sd = model.state_dict()
  528. mesh = getattr(accelerator, "torch_device_mesh", None)
  529. fsdp2_kwargs = {
  530. "reshard_after_forward": fsdp2_plugin.reshard_after_forward,
  531. "offload_policy": fsdp2_plugin.cpu_offload,
  532. # `fully_shard` doesn't accept `None` in case of `MixedPrecisionPolicy`
  533. "mp_policy": fsdp2_plugin.mixed_precision_policy or MixedPrecisionPolicy(),
  534. "mesh": mesh[tuple(accelerator.parallelism_config.fsdp_dim_names)] if mesh is not None else None,
  535. "ignored_params": get_parameters_from_modules(fsdp2_plugin.ignored_modules, model, accelerator.device),
  536. }
  537. model_has_params4bit = False
  538. for name, param in model.named_parameters():
  539. # this is a temporary fix whereby loading models with bnb params cannot be moved from
  540. # GPU to a meta device due with FSDP2 because torch operations don't return the original class type
  541. # bypassing the move to meta will still cause the VRAM spike, but at least it still will load
  542. if param.__class__.__name__ == "Params4bit":
  543. model_has_params4bit = True
  544. break
  545. if fsdp2_plugin.cpu_ram_efficient_loading and not model_has_params4bit:
  546. # Context: `fully_shard` moves the model to GPU if it was on CPU, however it can also be on `meta` and then it stays there even after `fully_shard`
  547. # For this reason, we need to move the model to `meta` device, as then sharding happens on `meta` device
  548. # If we kept the model on CPU (`cpu_ram_efficient_loading` has model be on CPU on all ranks, though non-main ranks only have `torch.empty`), `fully_shard` would move it to GPU
  549. # Afterwards, when we call `fsdp2_load_full_state_dict`, us creating the state_dict would result into briefly having two copies of model state_dict on the GPU -> VRAM spike
  550. # We need to keep the original non-persistent buffers, as those MAY not be in the state_dict, resulting in them staying on meta device
  551. # Also, these buffers aren't getting sharded by default
  552. # We get the FQNs of all non-persistent buffers, to re-register them after
  553. non_persistent_buffer_fqns = get_non_persistent_buffers(model, recurse=True, fqns=True)
  554. original_non_persistent_buffers = copy.deepcopy(
  555. {k: v for k, v in model.named_buffers() if k in non_persistent_buffer_fqns}
  556. )
  557. # We move the model to meta device, as then sharding happens on meta device
  558. model = model.to(torch.device("meta"))
  559. # We need to re-tie the weights, not exactly sure why, but if we don't do this, reference to `lm_head/embed_tokens` stay hanging -> more VRAM usage
  560. # We assume `transformers` models have a `tie_weights` method if they support it
  561. if hasattr(model, "tie_weights"):
  562. model.tie_weights()
  563. auto_wrap_policy_func = fsdp2_prepare_auto_wrap_policy(fsdp2_plugin, model)
  564. if auto_wrap_policy_func is not None:
  565. # We skip the model itself, as that one is always wrapped
  566. for module in get_module_children_bottom_up(model)[:-1]:
  567. if auto_wrap_policy_func(module) and not isinstance(module, FSDPModule):
  568. fully_shard(module, **fsdp2_kwargs)
  569. if not isinstance(model, FSDPModule):
  570. fully_shard(model, **fsdp2_kwargs)
  571. if fsdp2_plugin.cpu_ram_efficient_loading:
  572. # If `cpu_ram_efficient_loading` is enabled, only rank 0 loads the weights
  573. # Other ranks have an empty model on `meta` device, so we need to distribute the weights properly
  574. fsdp2_load_full_state_dict(accelerator, model, original_sd)
  575. if fsdp2_plugin.cpu_ram_efficient_loading and not model_has_params4bit:
  576. # We re-register the buffers, as they may not be in the state_dict
  577. for fqn, buffer_tensor in original_non_persistent_buffers.items():
  578. buffer_tensor = buffer_tensor.to(accelerator.device)
  579. if "." in fqn:
  580. parent_fqn, local_buffer_name = fqn.rsplit(".", 1)
  581. parent_module = model.get_submodule(parent_fqn)
  582. else:
  583. local_buffer_name = fqn
  584. parent_module = model
  585. parent_module.register_buffer(local_buffer_name, buffer_tensor, persistent=False)
  586. # We need to tie the weights again, as call to `load_full_state_dict` breaks the tie
  587. # Needs to be called both here and above
  588. # removing this call makes the have slightly different loss
  589. # removing the call above leads to extra memory usage as explained in the comment above
  590. if hasattr(model, "tie_weights"):
  591. model.tie_weights()
  592. # There is no `dtype` attribution for nn.Module
  593. # Set it to None if it doesn't exist and do the upcast always
  594. model_dtype = getattr(model, "dtype", None)
  595. if accelerator.mixed_precision != "no" and (model_dtype is None or model_dtype != torch.float32):
  596. # We upcast the model according to `deepspeed`'s implementation
  597. # More info about this can be found in `accelerator.py:prepare_model`s FSDP1 section
  598. model = model.to(torch.float32)
  599. if accelerator.is_main_process:
  600. # TODO(siro1): Add a warning for each parameter that was upcasted
  601. warnings.warn(
  602. "FSDP upcast of low precision parameters to fp32 (since mixed_precision != 'no') may affect the precision of model checkpoints."
  603. )
  604. return model
  605. def fsdp2_prepare_auto_wrap_policy(fsdp2_plugin, model: torch.nn.Module) -> Callable[[torch.nn.Module], bool]:
  606. """Prepares the auto wrap policy based on its type, done to mimic the behaviour of FSDP1 auto wrap policy.
  607. Args:
  608. fsdp2_plugin (`FullyShardedDataParallelPlugin`):
  609. Instance of `FullyShardedDataParallelPlugin` containing the configuration options
  610. auto_wrap_policy_type (`str`):
  611. Either `transformer` or `size`
  612. model (`torch.nn.Module`):
  613. The model to wrap
  614. Returns:
  615. `Callable[[torch.nn.Module], bool]`:
  616. The auto wrap policy function to be applied to the model
  617. """
  618. from torch.distributed.fsdp.wrap import size_based_auto_wrap_policy, transformer_auto_wrap_policy
  619. fn = fsdp2_plugin.auto_wrap_policy
  620. if isinstance(fn, functools.partial):
  621. fn = fn.func
  622. if fn is transformer_auto_wrap_policy:
  623. no_split_modules = getattr(model, "_no_split_modules", None)
  624. if no_split_modules is None:
  625. no_split_modules = []
  626. transformer_cls_names_to_wrap = list(no_split_modules)
  627. if fsdp2_plugin.transformer_cls_names_to_wrap is not None:
  628. transformer_cls_names_to_wrap = fsdp2_plugin.transformer_cls_names_to_wrap
  629. transformer_cls_to_wrap = set()
  630. for layer_class in transformer_cls_names_to_wrap:
  631. transformer_cls = get_module_class_from_name(model, layer_class)
  632. if transformer_cls is None:
  633. raise ValueError(f"Could not find the transformer layer class {layer_class} in the model.")
  634. transformer_cls_to_wrap.add(transformer_cls)
  635. def policy(module: torch.nn.Module) -> bool:
  636. if fsdp2_plugin.transformer_cls_names_to_wrap is None:
  637. return False
  638. return isinstance(module, tuple(transformer_cls_to_wrap))
  639. elif fn is size_based_auto_wrap_policy:
  640. def policy(module: torch.nn.Module) -> bool:
  641. module_num_params = sum(p.numel() for p in module.parameters())
  642. return module_num_params > fsdp2_plugin.min_num_params
  643. else:
  644. return None
  645. return policy
  646. def get_fsdp2_grad_scaler(**kwargs):
  647. """
  648. Returns a `GradScaler` for FSDP2, as the current implementation of `get_grad_scaler` doesn't accept other args. We
  649. need this as current `get_grad_scaler` accepts only `distributed_type` as arg, which doesn't differentiate between
  650. FSDP1 and FSDP2
  651. """
  652. from torch.amp.grad_scaler import GradScaler
  653. return GradScaler(**kwargs)
  654. def fsdp2_canonicalize_names(named_params: dict) -> dict:
  655. """Removes parameter name modifiers in order to map them back to their original names.
  656. See huggingface/accelerate#3554 for more context.
  657. Args:
  658. named_params (`dict`): The named parameters dictionary to canonicalize.
  659. Returns:
  660. `dict`: The canonicalized named parameters dictionary
  661. """
  662. named_params = {k.replace("._checkpoint_wrapped_module", ""): v for k, v in named_params.items()}
  663. named_params = {
  664. k.replace("_orig_mod.", "") if k.startswith("_orig_mod.") else k: v for k, v in named_params.items()
  665. }
  666. named_params = {k.replace("._orig_mod", ""): v for k, v in named_params.items()}
  667. return named_params
  668. def get_parameters_from_modules(
  669. modules: Union[Iterable[torch.nn.Module], str], model, device
  670. ) -> set[torch.nn.Parameter]:
  671. """Converts modules to parameters where modules can be a string or list of torch.nn.Module
  672. Args:
  673. modules (`Union[Iterable[torch.nn.Module], str]`): List of modules
  674. Returns:
  675. `set[torch.nn.Parameter]`: List of parameters
  676. """
  677. if modules is None:
  678. return set()
  679. parameters = []
  680. # code taken from accelerate while preparing kwargs for FSDP
  681. if isinstance(modules, str):
  682. reg = re.compile(modules)
  683. mapped_modules = []
  684. for name, module in model.named_modules():
  685. if reg.fullmatch(name):
  686. module.to(device)
  687. mapped_modules.append(module)
  688. modules = mapped_modules
  689. for module in modules:
  690. parameters.extend(list(module.parameters()))
  691. return set(parameters)