state_dict.py 55 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156115711581159116011611162116311641165116611671168116911701171117211731174117511761177117811791180118111821183118411851186118711881189119011911192119311941195119611971198119912001201120212031204120512061207120812091210121112121213121412151216121712181219122012211222122312241225122612271228122912301231123212331234123512361237123812391240124112421243124412451246124712481249125012511252125312541255125612571258125912601261126212631264126512661267126812691270127112721273127412751276127
712781279128012811282128312841285128612871288128912901291129212931294129512961297129812991300130113021303130413051306130713081309131013111312131313141315131613171318131913201321132213231324132513261327132813291330133113321333133413351336133713381339134013411342134313441345134613471348134913501351135213531354135513561357135813591360136113621363136413651366136713681369137013711372137313741375137613771378137913801381138213831384138513861387138813891390139113921393139413951396139713981399140014011402140314041405140614071408140914101411141214131414141514161417141814191420142114221423142414251426142714281429143014311432143314341435143614371438143914401441144214431444144514461447144814491450145114521453145414551456145714581459146014611462146314641465146614671468146914701471147214731474147514761477147814791480148114821483148414851486148714881489149014911492149314941495149614971498
  1. # mypy: allow-untyped-defs
  2. import contextlib
  3. import functools
  4. import gc
  5. import warnings
  6. from collections.abc import Generator, Iterable
  7. from dataclasses import asdict, dataclass, field
  8. from itertools import chain
  9. from typing import Any, Callable, cast, no_type_check, Optional, Union
  10. import torch
  11. import torch.distributed as dist
  12. import torch.nn as nn
  13. from torch.distributed._shard.sharded_tensor import ShardedTensor
  14. from torch.distributed._state_dict_utils import (
  15. _broadcast_state_dict,
  16. _distribute_state_dict,
  17. _flatten_state_dict,
  18. _gather_state_dict,
  19. _offload_state_dict_to_cpu,
  20. _unflatten_state_dict,
  21. )
  22. from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import (
  23. _CHECKPOINT_PREFIX,
  24. )
  25. from torch.distributed.fsdp import (
  26. FullOptimStateDictConfig,
  27. FullStateDictConfig,
  28. FullyShardedDataParallel as FSDP,
  29. OptimStateDictConfig,
  30. ShardedOptimStateDictConfig,
  31. ShardedStateDictConfig,
  32. StateDictConfig,
  33. StateDictType,
  34. )
  35. from torch.distributed.fsdp._common_utils import (
  36. _get_module_fsdp_state_if_fully_sharded_module,
  37. FSDP_WRAPPED_MODULE,
  38. )
  39. from torch.distributed.tensor import DTensor
  40. from torch.nn.modules.module import _IncompatibleKeys
  41. from torch.nn.parallel import DistributedDataParallel as DDP
  42. from torch.utils._pytree import tree_map_only
# Public names exported by this module.
__all__ = [
    "FQNS_T",
    "PrimitiveType",
    "ValueType",
    "DictValueType",
    "ListDictValueType",
    "OptimizerStateType",
    "StateDictOptions",
    "get_model_state_dict",
    "get_optimizer_state_dict",
    "get_state_dict",
    "set_model_state_dict",
    "set_optimizer_state_dict",
    "set_state_dict",
]

# Attribute name under which FSDP exposes its FlatParameter (see _get_fqns).
_FLAT_PARAM = "_flat_param"
# Top-level optimizer state_dict key holding the list of parameter groups.
_PG = "param_groups"
# Key inside each parameter group that lists the group's parameters.
_PARAMS = "params"
# Top-level optimizer state_dict key holding per-parameter state.
_STATE = "state"

# A set of fully qualified names: one parameter may map to several FQNs
# (e.g. shared parameters or an FSDP FlatParameter).
FQNS_T = set[str]
PrimitiveType = Union[DTensor, ShardedTensor, torch.Tensor, int, float, str]
# Recursive alias: a primitive, a list/tuple of primitives, or a nested dict.
ValueType = Union[
    PrimitiveType, list[PrimitiveType], tuple[PrimitiveType], dict[str, "ValueType"]
]
DictValueType = dict[str, ValueType]
ListDictValueType = list[DictValueType]
# Shape of an optimizer state_dict: {"state": {...}, "param_groups": [...]}.
OptimizerStateType = dict[str, Union[DictValueType, ListDictValueType]]

# Callables registered here are bypassed by _state_dict_fn, which then calls
# the class-level implementation bound to the object instead.
_patched_state_dict: set[Callable] = set()
  71. @contextlib.contextmanager
  72. def _gc_context():
  73. is_enabled = gc.isenabled()
  74. gc.disable()
  75. try:
  76. yield
  77. finally:
  78. if is_enabled:
  79. gc.enable()
  80. @dataclass
  81. class StateDictOptions:
  82. """
  83. This dataclass specifies how get_state_dict/set_state_dict will work.
  84. - ``full_state_dict``: if this is set to True, all the tensors in the
  85. returned state_dict will be gathered. No ShardedTensor and DTensor
  86. will be in the returned state_dict.
  87. - ``cpu_offload``: offload all the tensors to cpu. To prevent CPU OOM, if
  88. ``full_state_dict`` is also true, then only the rank0 will get the
  89. state_dict and all other ranks will get empty state_dict.
  90. - ``ignore_frozen_params``: if the value is True, the returned state_dict
  91. won't contain any frozen parameters -- the ``requires_grad`` is False.
  92. The default value is False.
  93. - ``keep_submodule_prefixes`` (deprecated): when ``submodules`` is not None, this option
  94. indicates whether to keep the submodule prefixes from the state_dict keys.
  95. or example, if the submodule is ``module.pretrain`` and the full FQN of
  96. the parameter is ``pretrain.layer1.weight`` of the param. When this option
  97. is True, the parameter's key in the returned state_dict will be
  98. ``pretrain.layer1.weight``. If the options is False, the key will be
  99. ``layer1.weight``.
  100. Note that if ``keep_submodule_prefixes`` is False, there may be conflicted
  101. FQNs, hence there should be only one submodule in ``submodules``.
  102. - ``strict``: the ``strict`` option when ``set_state_dict`` calls
  103. model.load_state_dict().
  104. - ``broadcast_from_rank0``: when the option is True, rank0 should receive a
  105. full state_dict and will broadcast the tensors in the state_dict/
  106. optim_state_dict one by one to other ranks. Other ranks will receive
  107. the tensors and shard according to the local shards in the model and
  108. optimizer. ``full_state_dict`` must be set to True when using this option.
  109. This option currently only supports DTensor, not the legacy ShardedTensor.
  110. """
  111. full_state_dict: bool = False
  112. cpu_offload: bool = False
  113. ignore_frozen_params: bool = False
  114. keep_submodule_prefixes: bool = True
  115. strict: bool = True
  116. broadcast_from_rank0: bool = False
  117. flatten_optimizer_state_dict: bool = False
  118. dsd_fqn_modifiers: str = "_fqn_modifiers"
@dataclass
class _StateDictInfo(StateDictOptions):
    """
    Internal, resolved form of ``StateDictOptions`` built by ``_verify_options``:
    the user's options plus the parameter/FQN mappings and the FSDP context
    used by the get/set helpers in this module.
    """

    # Two-way mapping: each parameter tensor -> the set of its FQNs, and each
    # FQN -> its parameter tensor.
    fqn_param_mapping: dict[
        Union[str, torch.Tensor],
        Union[FQNS_T, torch.Tensor],
    ] = field(default_factory=dict)
    # Same layout as fqn_param_mapping, restricted to parameters reachable
    # under more than one name (shared parameters).
    shared_params_mapping: dict[
        Union[str, torch.Tensor],
        Union[FQNS_T, torch.Tensor],
    ] = field(default_factory=dict)
    # FQN prefixes (each ending with ".") of the deprecated ``submodules`` filter.
    submodule_prefixes: set[str] = field(default_factory=set)
    # Whether model state (False when optim_only) / optimizer state (False
    # when no optimizer was passed) should be processed.
    handle_model: bool = True
    handle_optim: bool = True
    # Zero-argument callable returning the context manager to enter around
    # FSDP state_dict calls (nullcontext when the model has no FSDP module).
    fsdp_context: Callable = contextlib.nullcontext
    # FSDP modules found in the model; empty when FSDP is not used.
    fsdp_modules: list[nn.Module] = field(default_factory=list)
  134. def _get_fqns(
  135. model: nn.Module,
  136. name: str,
  137. dsd_fqn_modifiers: str = "_fqn_modifiers",
  138. skip_ddp_prefix: bool = True,
  139. skip_compiler_prefix: bool = True,
  140. ) -> FQNS_T:
  141. """
  142. This API is used to convert the name of a parameter to the FQNs. For FSDP
  143. without `use_orig_params`, the name of FlatParameter can be mapped to
  144. multiple original parameters. As a result, the return type of this function
  145. is `set[str]`.
  146. Args:
  147. module (nn.Module): the root model.
  148. name (str): the name
  149. skip_ddp_prefix (bool): whether to skip DDP's `module` prefix
  150. Returns:
  151. The canonical FQNs based on the model traversal.
  152. """
  153. # Remove the checkpoint prefix, if it exists.
  154. name = name.replace(_CHECKPOINT_PREFIX, "")
  155. if "." not in name:
  156. return {name}
  157. obj_names = name.split(".")
  158. fqn_obj_names = []
  159. curr_obj = model
  160. for i, curr_obj_name in enumerate(obj_names):
  161. if isinstance(curr_obj, DDP):
  162. assert curr_obj_name == "module"
  163. curr_obj = curr_obj.module
  164. if not skip_ddp_prefix:
  165. fqn_obj_names.append(curr_obj_name)
  166. elif isinstance(curr_obj, FSDP):
  167. if i < len(obj_names) - 1 and obj_names[i + 1] == _FLAT_PARAM:
  168. prefix = ".".join(fqn_obj_names)
  169. flat_param = getattr(curr_obj, _FLAT_PARAM)
  170. if prefix:
  171. prefix = f"{prefix}."
  172. return {f"{prefix}{fqn}" for fqn in flat_param._fqns}
  173. curr_obj = getattr(curr_obj, FSDP_WRAPPED_MODULE)
  174. if curr_obj_name != FSDP_WRAPPED_MODULE:
  175. fqn_obj_names.append(curr_obj_name)
  176. curr_obj = getattr(curr_obj, curr_obj_name)
  177. elif isinstance(curr_obj, torch._dynamo.eval_frame.OptimizedModule):
  178. assert curr_obj_name == "_orig_mod"
  179. curr_obj = curr_obj._orig_mod
  180. if not skip_compiler_prefix:
  181. fqn_obj_names.append(curr_obj_name)
  182. else:
  183. # In some modules, _fqn_modifiers would not shown in the state_dict keys,
  184. # skip them in the fqn to ensure load stat dict successfully for them.
  185. if hasattr(curr_obj, dsd_fqn_modifiers):
  186. if removed_fqn := getattr(curr_obj, dsd_fqn_modifiers)().get(
  187. curr_obj_name
  188. ):
  189. if hasattr(curr_obj, removed_fqn):
  190. curr_obj = getattr(curr_obj, removed_fqn)
  191. fqn_obj_names.append(curr_obj_name)
  192. if curr_obj_name == nn.modules.module._EXTRA_STATE_KEY_SUFFIX:
  193. if i != len(obj_names) - 1:
  194. raise RuntimeError("Expect `_extra_state` to be the last obj name")
  195. else:
  196. curr_obj = getattr(curr_obj, curr_obj_name)
  197. return {".".join(fqn_obj_names).replace(_CHECKPOINT_PREFIX, "")}
  198. class _EXTRA_STATE:
  199. pass
  200. def _iterate_valid_model_state(model, dsd_fqn_modifiers="_fqn_modifiers"):
  201. visited_modules: set[nn.Module] = set()
  202. def recurse(module: nn.Module, curr_fqn: str) -> Generator:
  203. visited_modules.add(module)
  204. curr_fqn = f"{curr_fqn}." if curr_fqn else ""
  205. for name, submodule in module.named_children():
  206. if submodule in visited_modules:
  207. continue
  208. # if user have state_dict_hooks in their model, they can add the state_dict key changes
  209. # at dsd_fqn_modifiers in input to align with the function of state_dict_hook
  210. if (
  211. hasattr(module, dsd_fqn_modifiers)
  212. and name in getattr(module, dsd_fqn_modifiers)().values()
  213. ):
  214. # skip _fqn_modifiers here thus remove the last `.` added
  215. new_fqn = curr_fqn[:-1]
  216. else:
  217. new_fqn = f"{curr_fqn}{name}"
  218. yield from recurse(submodule, new_fqn)
  219. for name, obj in chain(
  220. module.named_buffers(recurse=False), module.named_parameters(recurse=False)
  221. ):
  222. if name in module._non_persistent_buffers_set:
  223. continue
  224. new_fqn = f"{curr_fqn}{name}"
  225. yield new_fqn, obj
  226. if (
  227. getattr(module.__class__, "get_extra_state", nn.Module.get_extra_state)
  228. != nn.Module.get_extra_state
  229. ):
  230. new_fqn = f"{curr_fqn}{nn.modules.module._EXTRA_STATE_KEY_SUFFIX}"
  231. yield new_fqn, _EXTRA_STATE()
  232. yield from recurse(model, "")
  233. def _verify_options(
  234. model: nn.Module,
  235. optims: tuple[torch.optim.Optimizer, ...],
  236. optim_only: bool,
  237. *,
  238. submodules: Optional[set[nn.Module]] = None,
  239. options: Optional[StateDictOptions] = None,
  240. ) -> _StateDictInfo:
  241. """
  242. Verify the model and options passed by the user and generates _StateDictInfo.
  243. """
  244. if submodules:
  245. warnings.warn(
  246. "Getting submodules only model/optim state_dict is deprecated and "
  247. "will be removed in 2.5. This feature can be achieved by manually "
  248. "filtering out the state_dict returned from get_state_dict.",
  249. FutureWarning,
  250. )
  251. if optim_only and not optims:
  252. raise RuntimeError(
  253. "Optimizers are not passed in but optim_only is set to True."
  254. )
  255. options = options or StateDictOptions()
  256. fqn_param_mapping: dict[
  257. Union[str, torch.Tensor], Union[set[str], torch.Tensor]
  258. ] = {}
  259. shared_params_mapping: dict[
  260. Union[str, torch.Tensor], Union[set[str], torch.Tensor]
  261. ] = {}
  262. for name, param in _iterate_valid_model_state(model):
  263. if isinstance(param, _EXTRA_STATE):
  264. continue
  265. fqns = _get_fqns(model, name)
  266. fqn = fqn_param_mapping.get(param, None)
  267. if fqn is not None:
  268. cast(set[str], fqn_param_mapping[param]).update(fqns)
  269. shared_params_mapping[param] = fqn_param_mapping[param]
  270. else:
  271. # We need to do copy as _get_fqns is lru_cached
  272. fqn_param_mapping[param] = fqns.copy()
  273. for fqn in fqns:
  274. if not isinstance(param, _EXTRA_STATE):
  275. fqn_param_mapping[fqn] = param
  276. for param_, fqns_ in list(shared_params_mapping.items()):
  277. for fqn in fqns_:
  278. shared_params_mapping[fqn] = cast(torch.Tensor, param_)
  279. submodule_prefixes: set[str] = set()
  280. if submodules:
  281. submodules = set(submodules)
  282. for name, module in model.named_modules():
  283. if module not in submodules:
  284. continue
  285. fqns = _get_fqns(model, name)
  286. assert len(fqns) == 1, "Submodule FQN should only have 1 instance"
  287. submodule_prefixes.update(f"{fqn}." for fqn in fqns)
  288. if options.broadcast_from_rank0 and not options.full_state_dict:
  289. raise ValueError(
  290. "full_state_dict must be True when broadcast_from_rank0 is True."
  291. )
  292. fsdp_modules = FSDP.fsdp_modules(model)
  293. state_dict_config: StateDictConfig
  294. optim_state_dict_config: OptimStateDictConfig
  295. fsdp_context: Callable
  296. if fsdp_modules:
  297. # FSDP API only work if at least one FSDP instance exists.
  298. if options.full_state_dict:
  299. state_dict_config = FullStateDictConfig(
  300. offload_to_cpu=options.cpu_offload, rank0_only=options.cpu_offload
  301. )
  302. optim_state_dict_config = FullOptimStateDictConfig(
  303. offload_to_cpu=options.cpu_offload,
  304. rank0_only=(options.cpu_offload or options.broadcast_from_rank0),
  305. )
  306. state_dict_type = StateDictType.FULL_STATE_DICT
  307. else:
  308. state_dict_config = ShardedStateDictConfig(
  309. offload_to_cpu=options.cpu_offload,
  310. )
  311. optim_state_dict_config = ShardedOptimStateDictConfig(
  312. offload_to_cpu=options.cpu_offload,
  313. )
  314. state_dict_type = StateDictType.SHARDED_STATE_DICT
  315. @contextlib.contextmanager
  316. def fsdp_state_dict_type_without_warning(
  317. module,
  318. state_dict_type,
  319. state_dict_config,
  320. optim_state_dict_config,
  321. ):
  322. with warnings.catch_warnings():
  323. warnings.filterwarnings(
  324. "ignore", message="FSDP.state_dict_type", category=FutureWarning
  325. )
  326. with FSDP.state_dict_type(
  327. module=module,
  328. state_dict_type=state_dict_type,
  329. state_dict_config=state_dict_config,
  330. optim_state_dict_config=optim_state_dict_config,
  331. ):
  332. yield
  333. fsdp_context = functools.partial(
  334. fsdp_state_dict_type_without_warning,
  335. module=model,
  336. state_dict_type=state_dict_type,
  337. state_dict_config=state_dict_config,
  338. optim_state_dict_config=optim_state_dict_config,
  339. )
  340. else:
  341. fsdp_context = contextlib.nullcontext
  342. return _StateDictInfo(
  343. **asdict(options),
  344. fqn_param_mapping=fqn_param_mapping,
  345. shared_params_mapping=shared_params_mapping,
  346. submodule_prefixes=submodule_prefixes,
  347. fsdp_context=fsdp_context,
  348. fsdp_modules=cast(list[nn.Module], fsdp_modules),
  349. handle_model=not optim_only,
  350. handle_optim=(len(optims) > 0),
  351. )
  352. def _verify_state_dict(
  353. model_state_dict: dict[str, ValueType],
  354. optim_state_dict: OptimizerStateType,
  355. info: _StateDictInfo,
  356. ) -> None:
  357. for module in info.fsdp_modules:
  358. fsdp_state = _get_module_fsdp_state_if_fully_sharded_module(module)
  359. assert fsdp_state is not None, "Expected a fsdp_state with a fsdp module."
  360. # Verify if the model_state_dict and optim_state_dict are valid. This API
  361. # should give the users an explicit error message to debug or report.
  362. if (
  363. info.handle_model
  364. and not model_state_dict
  365. and not info.submodule_prefixes
  366. and not info.ignore_frozen_params
  367. and not (info.cpu_offload and info.full_state_dict)
  368. and info.strict
  369. and not info.broadcast_from_rank0
  370. ):
  371. raise RuntimeError(
  372. "The option indicates that model state_dict is required to save "
  373. "or load, but model state_dict is empty."
  374. f"rank = {dist.get_rank()=}."
  375. )
  376. if info.handle_optim:
  377. if (
  378. not optim_state_dict
  379. and not (info.cpu_offload and info.full_state_dict)
  380. and (not info.broadcast_from_rank0)
  381. ):
  382. raise RuntimeError(
  383. "The option indicates that model state_dict is required to save, "
  384. f"or load but optim state_dict is empty. {optim_state_dict}"
  385. )
  386. for key in model_state_dict.keys():
  387. if _FLAT_PARAM in key:
  388. raise RuntimeError(
  389. f"{key} contains {_FLAT_PARAM}. This can happen if the model "
  390. "is not the root module."
  391. )
  392. def _state_dict_fn(obj: Union[nn.Module, torch.optim.Optimizer], api: str) -> Callable:
  393. call = getattr(obj, api)
  394. if call in _patched_state_dict:
  395. call = functools.partial(getattr(obj.__class__, api), self=obj)
  396. return call
  397. def _maybe_full_or_cpu_state_dict(
  398. state_dict: dict[str, Any], info: _StateDictInfo
  399. ) -> dict[str, Any]:
  400. if info.full_state_dict:
  401. ranks_only = (
  402. ()
  403. if (not info.cpu_offload or not torch.distributed.is_initialized())
  404. else (0,)
  405. )
  406. return _gather_state_dict(
  407. state_dict, cpu_offload=info.cpu_offload, ranks_only=ranks_only
  408. )
  409. elif info.cpu_offload:
  410. return _offload_state_dict_to_cpu(state_dict)
  411. else:
  412. return state_dict
  413. @torch.no_grad()
  414. def _get_model_state_dict(
  415. model: nn.Module, info: _StateDictInfo
  416. ) -> dict[str, ValueType]:
  417. if not info.handle_model:
  418. return {}
  419. with info.fsdp_context():
  420. state_dict = _state_dict_fn(model, "state_dict")()
  421. for key in list(state_dict.keys()):
  422. fqns = _get_fqns(model, key)
  423. assert len(fqns) == 1, (key, fqns)
  424. fqn = next(iter(fqns))
  425. if fqn != key:
  426. # As we only support FSDP, DDP, and TP, the only cases are
  427. # wrapper-based DDP and compiler. Verify if the assumption
  428. # is correct.
  429. def verify(key, fqn) -> bool:
  430. if len(fqn) >= len(key):
  431. return False
  432. fqn_split = fqn.split(".")
  433. key_split = key.split(".")
  434. fqn_idx = 0
  435. for key_idx, key_name in enumerate(key_split):
  436. if key_name == fqn_split[fqn_idx]:
  437. fqn_idx += 1
  438. if fqn_idx == len(fqn_split):
  439. return key_idx == len(key_split) - 1
  440. elif key_name in ("module", "_orig_mod"):
  441. continue
  442. else:
  443. return False
  444. return True
  445. if not verify(key, fqn):
  446. raise RuntimeError(f"An unexpected key, {key}, exists. FQN is {fqn}")
  447. state_dict[fqn] = state_dict.pop(key)
  448. if info.submodule_prefixes:
  449. new_state_dict: dict[str, ValueType] = {}
  450. # TODO: make this faster.
  451. for fqn in state_dict.keys():
  452. for prefix in info.submodule_prefixes:
  453. if not fqn.startswith(prefix):
  454. continue
  455. if info.keep_submodule_prefixes:
  456. new_state_dict[fqn] = state_dict[fqn]
  457. else:
  458. new_fqn = fqn[len(prefix) :]
  459. new_state_dict[new_fqn] = state_dict[fqn]
  460. state_dict = new_state_dict
  461. if info.ignore_frozen_params:
  462. for key, param in model.named_parameters():
  463. if param.requires_grad:
  464. continue
  465. fqns = _get_fqns(model, key)
  466. for fqn in fqns:
  467. state_dict.pop(fqn)
  468. return _maybe_full_or_cpu_state_dict(state_dict, info)
  469. @torch.no_grad()
  470. def _load_model_state_dict(
  471. model: nn.Module,
  472. state_dict: dict[str, ValueType],
  473. info: _StateDictInfo,
  474. ) -> _IncompatibleKeys:
  475. if not info.handle_model or (not state_dict and not info.broadcast_from_rank0):
  476. return _IncompatibleKeys({}, {})
  477. local_state_dict = {}
  478. for key, value in _iterate_valid_model_state(model, info.dsd_fqn_modifiers):
  479. fqns = _get_fqns(model, key, info.dsd_fqn_modifiers)
  480. fqns_with_prefix = _get_fqns(
  481. model,
  482. key,
  483. info.dsd_fqn_modifiers,
  484. skip_ddp_prefix=False,
  485. skip_compiler_prefix=False,
  486. )
  487. for fqn, fqn_with_prefix in zip(fqns, fqns_with_prefix):
  488. if (
  489. not info.broadcast_from_rank0 or dist.get_rank() == 0
  490. ) and fqn != fqn_with_prefix:
  491. load_value = state_dict.pop(fqn, None)
  492. if load_value is None:
  493. if info.strict:
  494. raise RuntimeError(f"Missing key: {fqn}.")
  495. else:
  496. state_dict[fqn_with_prefix] = load_value
  497. local_state_dict[fqn_with_prefix] = value
  498. assign = False
  499. if info.broadcast_from_rank0 or info.full_state_dict:
  500. devices = set()
  501. for key, value in local_state_dict.items():
  502. if torch.is_tensor(value) and value.dim() > 0:
  503. devices.add(value.device)
  504. # In lora state_dict, there could be multiple devices, with meta device inside.
  505. # Take the other device in the broadcast/distribtue, and set assign to True
  506. if torch.device("meta") in devices:
  507. devices.remove(torch.device("meta"))
  508. assign = True
  509. if len(devices) == 0:
  510. devices.add(dist.distributed_c10d._get_pg_default_device())
  511. elif len(devices) > 1:
  512. raise ValueError("Multiple devices found")
  513. if info.broadcast_from_rank0:
  514. _broadcast_state_dict(
  515. state_dict,
  516. local_state_dict,
  517. device=devices.pop(),
  518. strict=info.strict,
  519. cpu_offload=info.cpu_offload,
  520. )
  521. elif info.full_state_dict:
  522. _distribute_state_dict(state_dict, local_state_dict, device=devices.pop())
  523. state_dict.update(local_state_dict)
  524. with info.fsdp_context():
  525. return cast(
  526. _IncompatibleKeys,
  527. _state_dict_fn(model, "load_state_dict")(
  528. state_dict=state_dict, strict=info.strict, assign=assign
  529. ),
  530. )
  531. def _init_optim_state(optim: torch.optim.Optimizer) -> None:
  532. """
  533. Initialize optim states by calling the step() with zero grads.
  534. """
  535. if optim.state:
  536. # The optimizer state is initialized.
  537. return
  538. # There are some stateless optimizers like SGD. These optimizer will
  539. # not return in the above condition. So if gradients exist, we should also
  540. # return. If gradients do not exist, the following initialization should
  541. # not disturb SGD because the gradients and lr are both zero.
  542. for param_group in optim.param_groups:
  543. for param in param_group[_PARAMS]:
  544. if param.grad is not None:
  545. return
  546. for param_group in optim.param_groups:
  547. for param in param_group[_PARAMS]:
  548. if param.requires_grad:
  549. param.grad = torch.zeros_like(param)
  550. # Some optimizers will update parameters regardless of grads due to lr, so
  551. # make lr to zero when calling `step()`.
  552. lrs = []
  553. for param_group in optim.param_groups:
  554. if "lr" in param_group:
  555. lrs.append(param_group["lr"])
  556. param_group["lr"] = (
  557. torch.tensor(0.0)
  558. if isinstance(param_group["lr"], torch.Tensor)
  559. else 0.0
  560. )
  561. optim.step(closure=None)
  562. # Whether to recover the "lr" should not matter too much as we will
  563. # restore checkpointing later.
  564. for param_group in optim.param_groups:
  565. if "lr" in param_group:
  566. param_group["lr"] = lrs.pop(0)
  567. optim.zero_grad(set_to_none=True)
  568. def _flatten_optim_state_dict(state_dict: OptimizerStateType) -> dict[str, ValueType]:
  569. """
  570. This API flattens the optimizer state_dict to support optimizer resharding for
  571. MPMD, e.g., pipeline parallelism.
  572. Without the API, the original optimizer state_dict looks like:
  573. {
  574. "state": {
  575. "layer1.weight": {
  576. "step": 10, "exp_avg": SomeTensor, "exp_avg_sq": SomeTensor
  577. },
  578. "layer2.weight": {
  579. "step": 10, "exp_avg": SomeTensor, "exp_avg_sq": SomeTensor
  580. },
  581. },
  582. "param_group": [
  583. {
  584. "lr": 0.0,
  585. "betas": (0.9, 0.95), ...,
  586. "params": ["layer1.weight", "layer2.weight"]
  587. }
  588. ]
  589. }
  590. With this API, the optimizer state_dict looks like:
  591. {
  592. "state.layer1.weight.step": 10,
  593. "state.layer2.weight.step": 10,
  594. "state.layer1.weight.exp_avg": SomeTensor,
  595. "state.layer2.weight.exp_avg": SomeTensor,
  596. "state.layer1.weight.exp_avg_sq": SomeTensor,
  597. "state.layer2.weight.exp_avg_sq": SomeTensor,
  598. "param_group.layer1.weight.lr" : 0.1,
  599. "param_group.layer2.weight.lr" : 0.1,
  600. "param_group.layer1.weight.betas" : (0.9, 0.95),
  601. "param_group.layer2.weight.betas" : (0.9, 0.95),
  602. }
  603. Note that if any of the value is a container, like the betas in the example,
  604. this API won't flattent it.
  605. """
  606. def _raise_if_type_not_supported(v):
  607. if not isinstance(v, (torch.Tensor, int, float)):
  608. raise NotImplementedError(
  609. "Flattening optimizer state_dict only supports "
  610. "tensor, int, float states now. "
  611. f"Type is {type(v)}."
  612. )
  613. ret: dict[str, ValueType] = {}
  614. for fqn, state in cast(DictValueType, state_dict[_STATE]).items():
  615. for k, v in cast(DictValueType, state).items():
  616. _raise_if_type_not_supported(v)
  617. ret[f"{_STATE}.{fqn}.{k}"] = v
  618. for param_group in cast(ListDictValueType, state_dict[_PG]):
  619. fqns = param_group.pop(_PARAMS)
  620. for fqn in cast(list[str], fqns):
  621. for k, v in param_group.items():
  622. ret[f"{_PG}.{fqn}.{k}"] = v
  623. return ret
def _unflatten_optim_state_dict(
    optim: torch.optim.Optimizer,
    state_dict: dict[str, ValueType],
    info: _StateDictInfo,
) -> OptimizerStateType:
    """
    This API unflattens the state_dict generated by _flatten_optim_state_dict().
    See the docstring of _flatten_optim_state_dict() for more detail.

    The nested layout is rebuilt by walking ``optim.param_groups``: each
    parameter is translated to its FQN(s) through ``info.fqn_param_mapping``;
    per-parameter state is read for the state names currently present in
    ``optim.state[param]`` (only for parameters that require grad); group
    hyperparameters are read from the entries of the group's first FQN.

    Raises:
        KeyError: an expected flattened key is missing from ``state_dict``.
        IndexError: a parameter group ends up with no FQN.
    """
    state: DictValueType = {}
    pg_state: ListDictValueType = []
    return_osd: OptimizerStateType = {_STATE: state, _PG: pg_state}

    for param_group in optim.param_groups:
        pg_state.append({_PARAMS: []})
        for param in param_group[_PARAMS]:
            for fqn in info.fqn_param_mapping[param]:
                # If a parameter is shared, only one of the FQN will be used.
                # So we need to verify which if this fqn is actually used in
                # the state_dict.
                if fqn in info.shared_params_mapping:
                    in_params = False
                    # Only the first non-"params" hyperparameter key is probed.
                    for k in param_group.keys():
                        if k == _PARAMS:
                            continue
                        flatten_key = f"{_PG}.{fqn}.{k}"
                        if flatten_key in state_dict:
                            in_params = True
                        break
                else:
                    in_params = True

                if not in_params:
                    continue

                params = pg_state[-1][_PARAMS]
                assert isinstance(params, list)  # typing
                params.append(fqn)
                if not param.requires_grad:
                    continue
                state[fqn] = {}
                for state_name in optim.state[param].keys():
                    cast(DictValueType, state[fqn])[state_name] = state_dict[
                        f"{_STATE}.{fqn}.{state_name}"
                    ]

        # Raises IndexError if no FQN was collected for this group.
        first_param_fqn = cast(list[str], pg_state[-1][_PARAMS])[0]
        for k in param_group.keys():
            if k == _PARAMS:
                continue
            value = state_dict[f"{_PG}.{first_param_fqn}.{k}"]
            # NOTE(review): pg_state[-1] only holds _PARAMS when this loop
            # starts and each key k is visited once, so the ``elif`` below can
            # never be taken; the other FQNs of the group are not compared.
            if k not in pg_state[-1]:
                pg_state[-1][k] = value
            elif pg_state[-1][k] != value:
                raise RuntimeError(
                    "All the parameters in the same parameter group should have "
                    f"the same saved param_group value. But {first_param_fqn}.{k} "
                    f"is {value} while other(s) is {pg_state[-1][k]}."
                )

    return return_osd
@torch.no_grad()
def _get_optim_state_dict(
    model: nn.Module,
    optimizers: tuple[torch.optim.Optimizer, ...],
    info: _StateDictInfo,
) -> OptimizerStateType:
    """Collect the state_dicts of ``optimizers`` and merge them, keyed by FQN.

    For every optimizer the native ``state_dict()`` is fetched, its parameter
    references are rewritten to parameter FQNs, and its ``_STATE`` and ``_PG``
    entries are merged into a single combined dict.

    Args:
        model (nn.Module): the root model whose parameters the optimizers own.
        optimizers (tuple[torch.optim.Optimizer, ...]): the optimizers to collect.
        info (_StateDictInfo): state dict information.

    Returns:
        ``{}`` when ``info.handle_optim`` is false. Otherwise the merged
        optimizer state_dict is returned, flattened first when
        ``info.flatten_optimizer_state_dict`` is set, and then passed through
        ``_maybe_full_or_cpu_state_dict``.
    """
    if not info.handle_optim:
        return {}

    optim_state_dict: OptimizerStateType = {_STATE: {}, _PG: []}
    for optim in optimizers:
        # NOTE(review): presumably ensures the optimizer state is materialized
        # before state_dict() is taken -- confirm in _init_optim_state.
        _init_optim_state(optim)
        osd = _state_dict_fn(optim, "state_dict")()
        if info.fsdp_modules:
            # FSDP.optim_state_dict is relied on to produce the FQN-keyed form
            # here; the else-branch below does that conversion by hand.
            with info.fsdp_context():
                osd = FSDP.optim_state_dict(model, optim, osd)

            # We need to specially handle FlatParameter FSDP as
            # FlatParameter FSDP converts the FQNs.
            # There are no easy ways to do this conversion systematically.
            # We can only use a string replacement without correctness check.
            if not osd:
                continue
            # Drop the "_orig_mod." component (presumably the torch.compile
            # wrapper prefix -- confirm) from the state keys and from every
            # param group's parameter list.
            for k in list(osd[_STATE].keys()):
                if "_orig_mod" in k:
                    osd[_STATE][k.replace("_orig_mod.", "")] = osd[_STATE].pop(k)
            for g in osd[_PG]:
                params = [k.replace("_orig_mod.", "") for k in g[_PARAMS]]
                g[_PARAMS] = params
        else:
            # Number this optimizer's parameters 0..N-1 in param_groups order.
            # This assumes the optimizer's own state_dict() uses the same
            # numbering for its integer parameter ids.
            params = list(chain.from_iterable(g[_PARAMS] for g in optim.param_groups))
            param_pid_mapping = dict(zip(params, range(len(params))))
            # One dict holds both directions (fqn -> pid and pid -> fqn); only
            # the pid -> fqn direction is read in this function.
            fqn_pid_mapping = {}
            for key, param in model.named_parameters():
                fqns = _get_fqns(model, key)
                assert len(fqns) == 1
                fqn = next(iter(fqns))
                # Skip model parameters not managed by this optimizer.
                if param not in param_pid_mapping:
                    continue
                pid = param_pid_mapping[param]
                fqn_pid_mapping[fqn] = pid
                fqn_pid_mapping[pid] = fqn

            # Re-key the per-parameter state and the param group lists from
            # integer ids to FQNs. Iterate over a snapshot of the keys because
            # osd[_STATE] is mutated inside the loop.
            for key in list(osd[_STATE].keys()):
                fqn = fqn_pid_mapping[key]
                osd[_STATE][fqn] = osd[_STATE].pop(key)
            for group in osd[_PG]:
                group[_PARAMS] = [fqn_pid_mapping[pid] for pid in group[_PARAMS]]

        if not osd:
            continue

        # Merge this optimizer's entries into the combined result.
        cast(DictValueType, optim_state_dict[_STATE]).update(osd[_STATE])
        cast(ListDictValueType, optim_state_dict[_PG]).extend(osd[_PG])

    if info.flatten_optimizer_state_dict:
        optim_state_dict = cast(
            OptimizerStateType, _flatten_optim_state_dict(optim_state_dict)
        )

    return _maybe_full_or_cpu_state_dict(optim_state_dict, info)
  729. if info.flatten_optimizer_state_dict:
  730. optim_state_dict = cast(
  731. OptimizerStateType, _flatten_optim_state_dict(optim_state_dict)
  732. )
  733. return _maybe_full_or_cpu_state_dict(optim_state_dict, info)
  734. def _split_optim_state_dict(
  735. model: nn.Module,
  736. optim: torch.optim.Optimizer,
  737. optim_state_dict: OptimizerStateType,
  738. info: _StateDictInfo,
  739. ) -> OptimizerStateType:
  740. """
  741. Extract the corresponding optim state_dict from ``optim_state_dict`` for
  742. ``optim`` and return the result optim state_dict.
  743. Args:
  744. model (nn.Module): the root model.
  745. optim (torch.optim.Optimizer): the optimizer.
  746. optim_state_dict (Dict[str, ValueType]): the superset optim state_dict that
  747. contains the optim state_dict of ``optim``.
  748. info (_StateDictInfo): state dict information.
  749. Returns:
  750. The optim state_dict of ``optim``.
  751. """
  752. state: DictValueType = {}
  753. pg_state: ListDictValueType = []
  754. return_osd: OptimizerStateType = {_STATE: state, _PG: pg_state}
  755. pg_mapping: dict[int, int] = {}
  756. if all(
  757. isinstance(k, int) for k in cast(DictValueType, optim_state_dict[_STATE]).keys()
  758. ):
  759. return optim_state_dict
  760. for param_group in optim.param_groups:
  761. pg_state.append({_PARAMS: []})
  762. for param in param_group[_PARAMS]:
  763. for fqn in info.fqn_param_mapping[param]:
  764. if fqn in info.shared_params_mapping:
  765. in_params = False
  766. for loaded_param_group in cast(
  767. ListDictValueType, optim_state_dict[_PG]
  768. ):
  769. if fqn in cast(list[str], loaded_param_group[_PARAMS]):
  770. in_params = True
  771. break
  772. else:
  773. in_params = True
  774. if not in_params:
  775. continue
  776. params = pg_state[-1][_PARAMS]
  777. assert isinstance(params, list)
  778. params.append(fqn)
  779. if param.requires_grad:
  780. state[fqn] = cast(DictValueType, optim_state_dict[_STATE])[fqn]
  781. for loaded_param_group in cast(
  782. ListDictValueType, optim_state_dict[_PG]
  783. ):
  784. if fqn in cast(list[str], loaded_param_group[_PARAMS]):
  785. pg_mapping[id(loaded_param_group)] = len(return_osd[_PG]) - 1
  786. if len(param_group[_PARAMS]) == 0:
  787. # Param_group with empty params.
  788. ret = []
  789. for loaded_param_group in cast(ListDictValueType, optim_state_dict[_PG]):
  790. if len(cast(list[str], loaded_param_group[_PARAMS])) == 0:
  791. ret.append(loaded_param_group)
  792. if len(ret) != 1:
  793. raise ValueError(
  794. "There are param groups that have zero parameters. "
  795. "In such a case, DSD only support exactly one param group "
  796. "with zero parameters."
  797. "But the loaded state_dict has zero or more than one param groups "
  798. "that have zero parameters."
  799. )
  800. if len(optim_state_dict[_PG]) != len(optim.param_groups):
  801. raise ValueError(
  802. "When there is a parameter group that has zero parameters, "
  803. "multiple optimizers are not supported."
  804. )
  805. pg_mapping[id(loaded_param_group)] = len(return_osd[_PG]) - 1
  806. for param_group in cast(ListDictValueType, optim_state_dict[_PG]):
  807. pg_idx = pg_mapping.get(id(param_group), -1)
  808. if pg_idx == -1:
  809. continue
  810. for key, value in param_group.items():
  811. if key == _PARAMS:
  812. continue
  813. # TODO: check if value is the same if exists.
  814. pg_state[pg_idx][key] = value
  815. return return_osd
  816. @torch.no_grad()
  817. def _load_optim_state_dict(
  818. model: nn.Module,
  819. optimizers: tuple[torch.optim.Optimizer, ...],
  820. state_dict: OptimizerStateType,
  821. info: _StateDictInfo,
  822. ) -> None:
  823. if not info.handle_optim:
  824. return
  825. for optim in optimizers:
  826. _init_optim_state(optim)
  827. if state_dict:
  828. if _STATE in state_dict:
  829. optim_state_dict = _split_optim_state_dict(
  830. model, optim, state_dict, info
  831. )
  832. else:
  833. optim_state_dict = _unflatten_optim_state_dict(
  834. optim, cast(dict[str, ValueType], state_dict), info
  835. )
  836. else:
  837. optim_state_dict = {}
  838. if info.fsdp_modules:
  839. # We need to specially handle FlatParameter FSDP as
  840. # FlatParameter FSDP converts the FQNs.
  841. for original_fqn, _ in model.named_parameters():
  842. fqns = _get_fqns(model, original_fqn)
  843. fqns_with_compiler = _get_fqns(
  844. model, original_fqn, skip_compiler_prefix=False
  845. )
  846. if fqns == fqns_with_compiler:
  847. continue
  848. assert len(fqns) == 1
  849. fqn = fqns.pop()
  850. fqn_with_compiler = fqns_with_compiler.pop()
  851. for g in optim_state_dict[_PG]:
  852. val = cast(dict[str, Any], g)
  853. params = [
  854. key.replace(fqn, fqn_with_compiler) for key in val[_PARAMS]
  855. ]
  856. val[_PARAMS] = params
  857. osd_state = cast(DictValueType, optim_state_dict[_STATE])
  858. for k in list(osd_state.keys()):
  859. if fqn in k:
  860. osd_state[k.replace(fqn, fqn_with_compiler)] = osd_state.pop(k)
  861. with info.fsdp_context():
  862. optim_state_dict = FSDP.optim_state_dict_to_load(
  863. model, optim, optim_state_dict
  864. )
  865. elif info.full_state_dict:
  866. info.full_state_dict = False
  867. local_state_dict = _get_optim_state_dict(model, (optim,), info)
  868. info.full_state_dict = True
  869. device = None
  870. def _device(t):
  871. if t.dim() > 0:
  872. nonlocal device
  873. if device is None:
  874. device = t.device
  875. elif device != t.device:
  876. raise ValueError("Device mismatch")
  877. return t
  878. _ = tree_map_only(torch.Tensor, _device, local_state_dict)
  879. assert device is not None
  880. flatten_osd, osd_mapping = _flatten_state_dict(optim_state_dict)
  881. flatten_local_osd, local_osd_mapping = _flatten_state_dict(local_state_dict)
  882. if info.broadcast_from_rank0:
  883. _broadcast_state_dict(flatten_osd, flatten_local_osd, device=device)
  884. else:
  885. _distribute_state_dict(flatten_osd, flatten_local_osd, device=device)
  886. # The modifications listed seek to address the problem where optim might possess
  887. # dissimilar parameters in comparison to optim_state_dict. This is achieved by
  888. # incorporating differential parameters within local, which may result in optim
  889. # having additional parameters ultimately.
  890. for optim_key in flatten_osd.keys():
  891. if optim_key not in flatten_local_osd:
  892. assert optim_key in osd_mapping
  893. flatten_local_osd[optim_key] = flatten_osd[optim_key]
  894. local_osd_mapping[optim_key] = osd_mapping[optim_key]
  895. optim_state_dict = _unflatten_state_dict(
  896. flatten_local_osd, local_osd_mapping
  897. )
  898. for pg in optim_state_dict[_PG]:
  899. if _PARAMS not in pg:
  900. cast(dict[str, ValueType], pg)[_PARAMS] = []
  901. # Note that we do not have to convert the FQN back to param id here if
  902. # order in optim.param_groups[idx][_PARAMS] is the same as the one in
  903. # optim_state_dict[_PG][idx][_PARAMS].
  904. _state_dict_fn(optim, "load_state_dict")(state_dict=optim_state_dict)
  905. def get_model_state_dict(
  906. model: nn.Module,
  907. *,
  908. submodules: Optional[set[nn.Module]] = None,
  909. options: Optional[StateDictOptions] = None,
  910. ) -> dict[str, ValueType]:
  911. """
  912. Return the model state_dict of ``model``.
  913. See ``get_state_dict`` for the detail usage.
  914. Args:
  915. model (nn.Module): the nn.Module to the model.
  916. submodules (deprecated): Optional[set[nn.Module]]: only return the model parameters
  917. that belong to the submodules.
  918. options (StateDictOptions): the options to control how
  919. model state_dict and optimizer state_dict should be returned. See
  920. `StateDictOptions` for the details.
  921. Returns:
  922. The state_dict for ``model``.
  923. :rtype: typing.Dict[str, ValueType]
  924. """
  925. with _gc_context():
  926. info = _verify_options(
  927. model,
  928. (),
  929. optim_only=False,
  930. submodules=submodules,
  931. options=options,
  932. )
  933. model_state_dict = _get_model_state_dict(model, info)
  934. _verify_state_dict(model_state_dict, {}, info)
  935. return model_state_dict
  936. def get_optimizer_state_dict(
  937. model: nn.Module,
  938. optimizers: Union[torch.optim.Optimizer, Iterable[torch.optim.Optimizer]],
  939. *,
  940. submodules: Optional[set[nn.Module]] = None,
  941. options: Optional[StateDictOptions] = None,
  942. ) -> OptimizerStateType:
  943. """
  944. Return the combined state_dict for optimizers.
  945. See ``get_state_dict`` for the detail usage.
  946. Args:
  947. model (nn.Module): the nn.Module to the model.
  948. optimizers (Union[None, Optimizer, Iterable[Optimizer]]):
  949. The optimizers that are used to optimize ``model``.
  950. submodules (deprecated): Optional[set[nn.Module]]: only return the model parameters
  951. that belong to the submodules.
  952. options (StateDictOptions): the options to control how
  953. model state_dict and optimizer state_dict should be returned. See
  954. `StateDictOptions` for the details.
  955. Returns:
  956. The state_dict for ``optimizers``.
  957. :rtype: OptimizerStateType
  958. """
  959. with _gc_context():
  960. optimizers = (
  961. (optimizers,)
  962. if isinstance(optimizers, torch.optim.Optimizer)
  963. else tuple(optimizers)
  964. )
  965. info = _verify_options(
  966. model,
  967. optimizers,
  968. optim_only=True,
  969. submodules=submodules,
  970. options=options,
  971. )
  972. optim_state_dict = _get_optim_state_dict(model, optimizers, info)
  973. _verify_state_dict({}, optim_state_dict, info)
  974. return optim_state_dict
  975. def get_state_dict(
  976. model: nn.Module,
  977. optimizers: Union[torch.optim.Optimizer, Iterable[torch.optim.Optimizer]],
  978. *,
  979. submodules: Optional[set[nn.Module]] = None,
  980. options: Optional[StateDictOptions] = None,
  981. ) -> tuple[dict[str, ValueType], OptimizerStateType]:
  982. """
  983. Return the model state_dict and optimizers state_dict.
  984. ``get_state_dict`` can process any module that is parallelized by PyTorch
  985. FSDP/fully_shard, DDP/replicate, tensor_parallel/parallelize_module, and any
  986. combination of these parallelisms. The main functions of ``get_state_dict``
  987. are: 1.) returning a model and optimizer state_dict that can be resharded
  988. with a different number of trainers and/or different parallelisms.
  989. 2.) hiding the parallelism-specific state_dict APIs. Users don't have to call
  990. these APIs.
  991. 3.) sanity checking the result state_dict.
  992. The keys of the result state dictionary are the canonical FQNs (Fully
  993. Qualified Names). A canonical FQN refers to the FQN based on a parameter's
  994. position in an nn.Module hierarchy. More specifically, a canonical FQN to a
  995. parameter is the FQN returned by ``module.named_parameters()`` or
  996. ``module.named_buffers()`` when the module is not distributed by any
  997. parallelisms. Since the optimizer internally uses parameter IDs to represent
  998. a parameter, there will be a conversion from the parameter IDs to the
  999. canonical FQNs when calling this API.
  1000. ``get_state_dict`` can also process a module that is not parallelized. In
  1001. such a case, ``get_state_dict`` only performs one function -- converting the
  1002. optimizer parameter IDs to the canonical FQNs.
  1003. Example:
  1004. >>> # xdoctest: +SKIP
  1005. >>> import torch
  1006. >>> from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
  1007. >>> from torch.nn.parallel import DistributedDataParallel as DDP
  1008. >>> from torch.distributed.checkpoint.state_dict import get_state_dict
  1009. >>> fsdp_model = FSDP(copy.deepcopy(model))
  1010. >>> fsdp_optim = torch.optim.Adam(model.parameters(), lr=1e-3)
  1011. >>> ddp_model = DDP(copy.deepcopy(model))
  1012. >>> ddp_optim = torch.optim.Adam(model.parameters(), lr=1e-3)
  1013. >>> ddp_state_dict, ddp_optim_state_dict = get_state_dict(ddp_model, ddp_optim)
  1014. >>> fsdp_state_dict, fsdp_optim_state_dict = get_state_dict(
  1015. ... fsdp_model, fsdp_optim
  1016. ... )
  1017. >>> # if we simply call ddp_model.state_dict() and fsdp_model.state_dict(),
  1018. >>> # the asserts will fail.
  1019. >>> assert ddp_state_dict == fsdp_state_dict
  1020. >>> assert ddp_optim_state == fsdp_optim_state_dict
  1021. Args:
  1022. model (nn.Module): the nn.Module to the model.
  1023. optimizers (Union[None, Optimizer, Iterable[Optimizer]]):
  1024. The optimizers that are used to optimize ``model``.
  1025. submodules (deprecated): Optional[set[nn.Module]]: only return the model parameters
  1026. that belong to the submodules.
  1027. options (StateDictOptions): the options to control how
  1028. model state_dict and optimizer state_dict should be returned. See
  1029. `StateDictOptions` for the details.
  1030. Returns:
  1031. ``Tuple`` that contain model state_dict and optimizer state_dict.
  1032. :rtype: typing.Tuple[typing.Dict[str, ValueType], OptimizerStateType]
  1033. """
  1034. with _gc_context():
  1035. optimizers = (
  1036. (optimizers,)
  1037. if isinstance(optimizers, torch.optim.Optimizer)
  1038. else tuple(optimizers)
  1039. )
  1040. info = _verify_options(
  1041. model,
  1042. optimizers,
  1043. optim_only=False,
  1044. submodules=submodules,
  1045. options=options,
  1046. )
  1047. model_state_dict = _get_model_state_dict(model, info)
  1048. optim_state_dict = _get_optim_state_dict(model, optimizers, info)
  1049. _verify_state_dict(model_state_dict, optim_state_dict, info)
  1050. return model_state_dict, optim_state_dict
  1051. def _unflatten_model_state_dict(
  1052. model: nn.Module,
  1053. state_dict: Union[dict[nn.Module, dict[str, ValueType]], dict[str, ValueType]],
  1054. ) -> dict[str, ValueType]:
  1055. if not state_dict:
  1056. return {}
  1057. if isinstance(next(iter(state_dict.keys())), nn.Module):
  1058. warnings.warn(
  1059. "Passing model_state_dict as a ``Dict[nn.Module, Dict[str, Any]]``"
  1060. "is deprecated and will be removed in 2.5. If you need this "
  1061. "feature, please preprocessing the model_state_dict to achieve the "
  1062. "same functionality.",
  1063. FutureWarning,
  1064. )
  1065. cast_state_dict = cast(dict[nn.Module, dict[str, ValueType]], state_dict)
  1066. new_state_dict: dict[str, ValueType] = {}
  1067. for submodule, sub_state_dict in cast_state_dict.items():
  1068. for name, m in model.named_modules():
  1069. if m != submodule:
  1070. continue
  1071. fqns = _get_fqns(model, name)
  1072. assert len(fqns) == 1, "FQNs for a submodule should only have 1 element"
  1073. prefix = f"{next(iter(fqns))}."
  1074. new_state_dict.update(
  1075. {prefix + subfqn: value for subfqn, value in sub_state_dict.items()}
  1076. )
  1077. return new_state_dict
  1078. else:
  1079. return cast(dict[str, ValueType], state_dict)
  1080. def set_model_state_dict(
  1081. model: nn.Module,
  1082. model_state_dict: dict[str, ValueType],
  1083. *,
  1084. options: Optional[StateDictOptions] = None,
  1085. ) -> _IncompatibleKeys:
  1086. """Load the model state_dict.
  1087. The counterpart of ``get_model_state_dict`` to set the state_dict to the
  1088. model. See ``set_state_dict`` for the detail usage.
  1089. Args:
  1090. model (nn.Module): the nn.Module to the model.
  1091. model_state_dict: (Dict[str, ValueType]):
  1092. the model state_dict to load. If the key of the ``model_state_dict``
  1093. is nn.Module, the key is a submodule of ``model`` and the value should
  1094. be the state_dict of the submodule. When loading the state_dict,
  1095. the prefix of the submodule will be append to the state_dict.
  1096. options (StateDictOptions): the options to control how
  1097. model state_dict and optimizer state_dict should be loaded. See
  1098. `StateDictOptions` for the details.
  1099. Returns:
  1100. ``NamedTuple`` with ``missing_keys`` and ``unexpected_keys`` fields:
  1101. * **missing_keys** is a list of str containing the missing keys
  1102. * **unexpected_keys** is a list of str containing the unexpected keys
  1103. :type model_state_dict: typing.Dict[str, ValueType]
  1104. """
  1105. model_state_dict: dict[str, ValueType] = _unflatten_model_state_dict(
  1106. model, model_state_dict
  1107. )
  1108. with _gc_context():
  1109. info = _verify_options(model, (), optim_only=False, options=options)
  1110. _verify_state_dict(model_state_dict, {}, info)
  1111. return _load_model_state_dict(model, model_state_dict, info)
  1112. def set_optimizer_state_dict(
  1113. model: nn.Module,
  1114. optimizers: Union[torch.optim.Optimizer, Iterable[torch.optim.Optimizer]],
  1115. optim_state_dict: OptimizerStateType,
  1116. *,
  1117. options: Optional[StateDictOptions] = None,
  1118. ) -> None:
  1119. """Load the optimizers state_dict.
  1120. The counterpart of ``get_optimizer_state_dict`` to set the state_dict to the
  1121. optimizers. See ``set_state_dict`` for the detail usage.
  1122. WARN: ``set_optimizer_state_dict`` can only be called before ``backward()`` or after
  1123. ``step()`` is called on the optimizers. Otherwise, the optimizer states won't be
  1124. initialized correctly.
  1125. Args:
  1126. model (nn.Module): the nn.Module to the model.
  1127. optimizers (Union[Optimizer, Iterable[Optimizer]]):
  1128. The optimizers that are used to optimize ``model``.
  1129. optim_state_dict: OptimizerStateType:
  1130. the optimizer state_dict to load.
  1131. options (StateDictOptions): the options to control how
  1132. model state_dict and optimizer state_dict should be loaded. See
  1133. `StateDictOptions` for the details.
  1134. Returns:
  1135. None
  1136. :type optim_state_dict: typing.OptimizerStateType
  1137. """
  1138. with _gc_context():
  1139. optimizers = (
  1140. (optimizers,)
  1141. if isinstance(optimizers, torch.optim.Optimizer)
  1142. else tuple(optimizers)
  1143. )
  1144. info = _verify_options(model, optimizers, optim_only=True, options=options)
  1145. _verify_state_dict({}, optim_state_dict, info)
  1146. _load_optim_state_dict(model, optimizers, optim_state_dict, info)
  1147. def set_state_dict(
  1148. model: nn.Module,
  1149. optimizers: Union[torch.optim.Optimizer, Iterable[torch.optim.Optimizer]],
  1150. *,
  1151. model_state_dict: dict[str, ValueType],
  1152. optim_state_dict: OptimizerStateType,
  1153. options: Optional[StateDictOptions] = None,
  1154. ) -> _IncompatibleKeys:
  1155. """Load the model state_dict and optimizers state_dict.
  1156. The counterpart of ``get_state_dict`` to set the state_dict to the model and
  1157. optimizers. The given ``model_state_dict`` and ``optim_state_dict`` do not
  1158. have to be returned by ``get_state_dict`` but must meet the following
  1159. requirements: 1) all FQNs are canonical FQNs as defined in ``get_state_dict``,
  1160. 2) if a tensor is sharded, it must be either a ShardedTensor or DTensor,
  1161. 3) optimizer state_dict cannot contain the parameter IDs; the keys should be
  1162. the canonical FQNs.
  1163. WARN: ``set_state_dict`` can only be called before ``backward()`` or after ``step()``
  1164. is called on the optimizers. Otherwise, the optimizer states won't be initialized
  1165. correctly.
  1166. Args:
  1167. model (nn.Module): the nn.Module to the model.
  1168. optimizers (Union[Optimizer, Iterable[Optimizer]]):
  1169. The optimizers that are used to optimize ``model``.
  1170. model_state_dict: (Union[Dict[nn.Module, Dict[str, ValueType]], Dict[str, ValueType]]):
  1171. the model state_dict to load. If the key of the ``model_state_dict``
  1172. is nn.Module, the key is a submodule of ``model`` and the value should
  1173. be the state_dict of the submodule. When loading the state_dict,
  1174. the prefix of the submodule will be append to the state_dict.
  1175. optim_state_dict: OptimizerStateType:
  1176. the optimizer state_dict to load.
  1177. options (StateDictOptions): the options to control how
  1178. model state_dict and optimizer state_dict should be loaded. See
  1179. `StateDictOptions` for the details.
  1180. Returns:
  1181. ``NamedTuple`` with ``missing_keys`` and ``unexpected_keys`` fields:
  1182. * **missing_keys** is a list of str containing the missing keys of the model state_dict.
  1183. * **unexpected_keys** is a list of str containing the unexpected keys of the model state_dict.
  1184. :type model_state_dict: typing.Dict[str, ValueType]
  1185. :type optim_state_dict: typing.OptimizerStateType
  1186. """
  1187. model_state_dict: dict[str, ValueType] = _unflatten_model_state_dict(
  1188. model, model_state_dict
  1189. )
  1190. with _gc_context():
  1191. optimizers = (
  1192. (optimizers,)
  1193. if isinstance(optimizers, torch.optim.Optimizer)
  1194. else tuple(optimizers)
  1195. )
  1196. info = _verify_options(
  1197. model, optimizers, optim_only=not model_state_dict, options=options
  1198. )
  1199. _verify_state_dict(model_state_dict, optim_state_dict, info)
  1200. _load_optim_state_dict(model, optimizers, optim_state_dict, info)
  1201. return _load_model_state_dict(model, model_state_dict, info)
  1202. # TODO: correct the state_dict function signature.
  1203. # TODO: this API is not yet fully tested. Make it private
  1204. @no_type_check
  1205. def _patch_model_state_dict(
  1206. model: nn.Module,
  1207. *,
  1208. options: Optional[StateDictOptions] = None,
  1209. ) -> None:
  1210. """Patch the ``state_dict`` and ``load_state_dict`` attributes of ``model``.
  1211. Patch the ``state_dict`` and ``load_state_dict`` attributes of ``model`` to
  1212. be a partial function to call ``get_state_dict`` and ``set_state_dict``.
  1213. Example:
  1214. from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
  1215. from torch.distributed.checkpoint.state_dict import patch_model_state_dict
  1216. model = fsdp(model)
  1217. patch_model_state_dict(model)
  1218. Args:
  1219. model (nn.Module): the nn.Module to the model.
  1220. options (StateDictOptions): the options to control how
  1221. model state_dict and optimizer state_dict should be loaded. See
  1222. `StateDictOptions` for the details.
  1223. Returns:
  1224. None
  1225. """
  1226. _state_dict_call = functools.partial(
  1227. get_model_state_dict,
  1228. model=model,
  1229. options=options,
  1230. )
  1231. def state_dict_call():
  1232. return _state_dict_call()
  1233. model.state_dict = state_dict_call
  1234. _load_state_dict_call = functools.partial(
  1235. set_model_state_dict,
  1236. model=model,
  1237. options=options,
  1238. )
  1239. def load_state_dict_call(state_dict: dict[str, Any]):
  1240. _load_state_dict_call(model_state_dict=state_dict)
  1241. model.load_state_dict = load_state_dict_call
  1242. _patched_state_dict.add(state_dict_call)
  1243. _patched_state_dict.add(load_state_dict_call)
  1244. # TODO: correct the load_state_dict function signature.
  1245. # TODO: this API is not yet fully tested. Make it private
  1246. @no_type_check
  1247. def _patch_optimizer_state_dict(
  1248. model: nn.Module,
  1249. *,
  1250. optimizers: tuple[torch.optim.Optimizer, ...],
  1251. options: Optional[StateDictOptions] = None,
  1252. ) -> None:
  1253. """Patch the ``state_dict`` and ``load_state_dict`` attributes of ``optimizers``.
  1254. Patch the ``state_dict`` and ``load_state_dict`` attributes of ``optimizers`` to
  1255. be a partial function to call ``get_state_dict`` and ``set_state_dict``.
  1256. Note that if there are multiple optimizers, all of the optimizers will be patched.
  1257. So users only need to call one of the state_dict() to get the full result.
  1258. Example:
  1259. from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
  1260. from torch.distributed.checkpoint.state_dict import patch_model_state_dict
  1261. model = fsdp(model)
  1262. patch_model_state_dict(model)
  1263. Args:
  1264. model (nn.Module): the nn.Module to the model.
  1265. options (StateDictOptions): the options to control how
  1266. model state_dict and optimizer state_dict should be loaded. See
  1267. `StateDictOptions` for the details.
  1268. Returns:
  1269. None
  1270. """
  1271. _state_dict_call = functools.partial(
  1272. get_optimizer_state_dict,
  1273. model=model,
  1274. optimizers=optimizers,
  1275. options=options,
  1276. )
  1277. def state_dict_call():
  1278. return _state_dict_call()
  1279. _load_state_dict_call = functools.partial(
  1280. set_optimizer_state_dict,
  1281. model=model,
  1282. optimizers=optimizers,
  1283. options=options,
  1284. )
  1285. def load_state_dict_call(state_dict: dict[str, Any]):
  1286. _load_state_dict_call(optim_state_dict=state_dict)
  1287. _patched_state_dict.add(state_dict_call)
  1288. _patched_state_dict.add(load_state_dict_call)
  1289. optimizers = (
  1290. (optimizers,)
  1291. if isinstance(optimizers, torch.optim.Optimizer)
  1292. else tuple(optimizers)
  1293. )
  1294. for optim in optimizers:
  1295. optim.state_dict = state_dict_call
  1296. optim.load_state_dict = load_state_dict_call