api.py 19 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417
  1. """
  2. This file includes public APIs for FSDP such as the classes used for the
  3. constructor arguments.
  4. """
  5. from collections.abc import Sequence
  6. from dataclasses import dataclass
  7. from enum import auto, Enum
  8. from typing import Optional
  9. import torch
  10. from torch.nn.modules.batchnorm import _BatchNorm
  11. __all__ = [
  12. "ShardingStrategy",
  13. "BackwardPrefetch",
  14. "MixedPrecision",
  15. "CPUOffload",
  16. "StateDictType",
  17. "StateDictConfig",
  18. "FullStateDictConfig",
  19. "LocalStateDictConfig",
  20. "ShardedStateDictConfig",
  21. "OptimStateDictConfig",
  22. "FullOptimStateDictConfig",
  23. "LocalOptimStateDictConfig",
  24. "ShardedOptimStateDictConfig",
  25. "StateDictSettings",
  26. ]
  27. class ShardingStrategy(Enum):
  28. """
  29. This specifies the sharding strategy to be used for distributed training by
  30. :class:`FullyShardedDataParallel`.
  31. - ``FULL_SHARD``: Parameters, gradients, and optimizer states are sharded.
  32. For the parameters, this strategy unshards (via all-gather) before the
  33. forward, reshards after the forward, unshards before the backward
  34. computation, and reshards after the backward computation. For gradients,
  35. it synchronizes and shards them (via reduce-scatter) after the backward
  36. computation. The sharded optimizer states are updated locally per rank.
  37. - ``SHARD_GRAD_OP``: Gradients and optimizer states are sharded during
  38. computation, and additionally, parameters are sharded outside
  39. computation. For the parameters, this strategy unshards before the
  40. forward, does not reshard them after the forward, and only reshards them
  41. after the backward computation. The sharded optimizer states are updated
  42. locally per rank. Inside ``no_sync()``, the parameters are not resharded
  43. after the backward computation.
  44. - ``NO_SHARD``: Parameters, gradients, and optimizer states are not sharded
  45. but instead replicated across ranks similar to PyTorch's
  46. :class:`DistributedDataParallel` API. For gradients, this strategy
  47. synchronizes them (via all-reduce) after the backward computation. The
  48. unsharded optimizer states are updated locally per rank.
  49. - ``HYBRID_SHARD``: Apply ``FULL_SHARD`` within a node, and replicate parameters across
  50. nodes. This results in reduced communication volume as expensive all-gathers and
  51. reduce-scatters are only done within a node, which can be more performant for medium
  52. -sized models.
  53. - ``_HYBRID_SHARD_ZERO2``: Apply ``SHARD_GRAD_OP`` within a node, and replicate parameters across
  54. nodes. This is like ``HYBRID_SHARD``, except this may provide even higher throughput
  55. since the unsharded parameters are not freed after the forward pass, saving the
  56. all-gathers in the pre-backward.
  57. """
  58. FULL_SHARD = auto()
  59. SHARD_GRAD_OP = auto()
  60. NO_SHARD = auto()
  61. HYBRID_SHARD = auto()
  62. _HYBRID_SHARD_ZERO2 = auto()
  63. class BackwardPrefetch(Enum):
  64. """
  65. This configures explicit backward prefetching, which improves throughput by
  66. enabling communication and computation overlap in the backward pass at the
  67. cost of slightly increased memory usage.
  68. - ``BACKWARD_PRE``: This enables the most overlap but increases memory
  69. usage the most. This prefetches the next set of parameters *before* the
  70. current set of parameters' gradient computation. This overlaps the *next
  71. all-gather* and the *current gradient computation*, and at the peak, it
  72. holds the current set of parameters, next set of parameters, and current
  73. set of gradients in memory.
  74. - ``BACKWARD_POST``: This enables less overlap but requires less memory
  75. usage. This prefetches the next set of parameters *after* the current
  76. set of parameters' gradient computation. This overlaps the *current
  77. reduce-scatter* and the *next gradient computation*, and it frees the
  78. current set of parameters before allocating memory for the next set of
  79. parameters, only holding the next set of parameters and current set of
  80. gradients in memory at the peak.
  81. - FSDP's ``backward_prefetch`` argument accepts ``None``, which disables
  82. the backward prefetching altogether. This has no overlap and does not
  83. increase memory usage. In general, we do not recommend this setting since
  84. it may degrade throughput significantly.
  85. For more technical context: For a single process group using NCCL backend,
  86. any collectives, even if issued from different streams, contend for the
  87. same per-device NCCL stream, which implies that the relative order in which
  88. the collectives are issued matters for overlapping. The two backward
  89. prefetching values correspond to different issue orders.
  90. """
  91. # NOTE: For both modes, the ordering that defines "current" and "next" is
  92. # not always exact in the current implementation. A mistargeted prefetch
  93. # simply means that the parameter memory is allocated earlier than needed,
  94. # possibly increasing peak memory usage, but does not affect correctness.
  95. BACKWARD_PRE = auto()
  96. BACKWARD_POST = auto()
  97. @dataclass
  98. class MixedPrecision:
  99. """
  100. This configures FSDP-native mixed precision training.
  101. Attributes:
  102. param_dtype (Optional[torch.dtype]): This specifies the dtype for model
  103. parameters during forward and backward and thus the dtype for
  104. forward and backward computation. Outside forward and backward, the
  105. *sharded* parameters are kept in full precision (e.g. for the
  106. optimizer step), and for model checkpointing, the parameters are
  107. always saved in full precision. (Default: ``None``)
  108. reduce_dtype (Optional[torch.dtype]): This specifies the dtype for
  109. gradient reduction (i.e. reduce-scatter or all-reduce). If this is
  110. ``None`` but ``param_dtype`` is not ``None``, then this takes on
  111. the ``param_dtype`` value, still running gradient reduction in low
  112. precision. This is permitted to differ from ``param_dtype``, e.g.
  113. to force gradient reduction to run in full precision. (Default:
  114. ``None``)
  115. buffer_dtype (Optional[torch.dtype]): This specifies the dtype for
  116. buffers. FSDP does not shard buffers. Rather, FSDP casts them to
  117. ``buffer_dtype`` in the first forward pass and keeps them in that
  118. dtype thereafter. For model checkpointing, the buffers are saved
  119. in full precision except for ``LOCAL_STATE_DICT``. (Default:
  120. ``None``)
  121. keep_low_precision_grads (bool): If ``False``, then FSDP upcasts
  122. gradients to full precision after the backward pass in preparation
  123. for the optimizer step. If ``True``, then FSDP keeps the gradients
  124. in the dtype used for gradient reduction, which can save memory if
  125. using a custom optimizer that supports running in low precision.
  126. (Default: ``False``)
  127. cast_forward_inputs (bool): If ``True``, then this FSDP module casts
  128. its forward args and kwargs to ``param_dtype``. This is to ensure
  129. that parameter and input dtypes match for forward computation, as
  130. required by many ops. This may need to be set to ``True`` when only
  131. applying mixed precision to some but not all FSDP modules, in which
  132. case a mixed-precision FSDP submodule needs to recast its inputs.
  133. (Default: ``False``)
  134. cast_root_forward_inputs (bool): If ``True``, then the root FSDP module
  135. casts its forward args and kwargs to ``param_dtype``, overriding
  136. the value of ``cast_forward_inputs``. For non-root FSDP modules,
  137. this does not do anything. (Default: ``True``)
  138. _module_classes_to_ignore: (Sequence[Type[nn.Module]]): This specifies
  139. module classes to ignore for mixed precision when using an
  140. ``auto_wrap_policy``: Modules of these classes will have FSDP
  141. applied to them separately with mixed precision disabled (meaning
  142. that the final FSDP construction would deviate from the specified
  143. policy). If ``auto_wrap_policy`` is not specified, then this does
  144. not do anything. This API is experimental and subject to change.
  145. (Default: ``(_BatchNorm,)``)
  146. .. note:: This API is experimental and subject to change.
  147. .. note:: Only floating point tensors are cast to their specified dtypes.
  148. .. note:: In ``summon_full_params``, parameters are forced to full
  149. precision, but buffers are not.
  150. .. note:: Layer norm and batch norm accumulate in ``float32`` even when
  151. their inputs are in a low precision like ``float16`` or ``bfloat16``.
  152. Disabling FSDP's mixed precision for those norm modules only means that
  153. the affine parameters are kept in ``float32``. However, this incurs
  154. separate all-gathers and reduce-scatters for those norm modules, which
  155. may be inefficient, so if the workload permits, the user should prefer
  156. to still apply mixed precision to those modules.
  157. .. note:: By default, if the user passes a model with any ``_BatchNorm``
  158. modules and specifies an ``auto_wrap_policy``, then the batch norm
  159. modules will have FSDP applied to them separately with mixed precision
  160. disabled. See the ``_module_classes_to_ignore`` argument.
  161. .. note:: ``MixedPrecision`` has ``cast_root_forward_inputs=True`` and
  162. ``cast_forward_inputs=False`` by default. For the root FSDP instance,
  163. its ``cast_root_forward_inputs`` takes precedence over its
  164. ``cast_forward_inputs``. For non-root FSDP instances, their
  165. ``cast_root_forward_inputs`` values are ignored. The default setting is
  166. sufficient for the typical case where each FSDP instance has the same
  167. ``MixedPrecision`` configuration and only needs to cast inputs to the
  168. ``param_dtype`` at the beginning of the model's forward pass.
  169. .. note:: For nested FSDP instances with different ``MixedPrecision``
  170. configurations, we recommend setting individual ``cast_forward_inputs``
  171. values to configure casting inputs or not before each instance's
  172. forward. In such a case, since the casts happen before each FSDP
  173. instance's forward, a parent FSDP instance should have its non-FSDP
  174. submodules run before its FSDP submodules to avoid the activation dtype
  175. being changed due to a different ``MixedPrecision`` configuration.
  176. Example::
  177. >>> # xdoctest: +SKIP("undefined variables")
  178. >>> model = nn.Sequential(nn.Linear(3, 3), nn.Linear(3, 3))
  179. >>> model[1] = FSDP(
  180. >>> model[1],
  181. >>> mixed_precision=MixedPrecision(param_dtype=torch.float16, cast_forward_inputs=True),
  182. >>> )
  183. >>> model = FSDP(
  184. >>> model,
  185. >>> mixed_precision=MixedPrecision(param_dtype=torch.bfloat16, cast_forward_inputs=True),
  186. >>> )
  187. The above shows a working example. On the other hand, if ``model[1]``
  188. were replaced with ``model[0]``, meaning that the submodule using
  189. different ``MixedPrecision`` ran its forward first, then ``model[1]``
  190. would incorrectly see ``float16`` activations instead of ``bfloat16``
  191. ones.
  192. """
  193. param_dtype: Optional[torch.dtype] = None
  194. reduce_dtype: Optional[torch.dtype] = None
  195. buffer_dtype: Optional[torch.dtype] = None
  196. keep_low_precision_grads: bool = False
  197. cast_forward_inputs: bool = False
  198. cast_root_forward_inputs: bool = True
  199. _module_classes_to_ignore: Sequence[type[torch.nn.Module]] = (_BatchNorm,)
  200. @dataclass
  201. class CPUOffload:
  202. """
  203. This configures CPU offloading.
  204. Attributes:
  205. offload_params (bool): This specifies whether to offload parameters to
  206. CPU when not involved in computation. If ``True``, then this
  207. offloads gradients to CPU as well, meaning that the optimizer step
  208. runs on CPU.
  209. """
  210. offload_params: bool = False
  211. class StateDictType(Enum):
  212. """
  213. This enum indicates that which type of ``state_dict`` the FSDP module is
  214. currently processing (returning or loading).
  215. The default value is FULL_STATE_DICT to comply the PyTorch convention.
  216. .. note::
  217. FSDP currently supports three types of ``state_dict``:
  218. 1. ``state_dict/load_state_dict`: this pair of APIs return and load
  219. the non-sharded, unflattened parameters. The semantics is the
  220. same as using DDP.
  221. 2. ``_local_state_dict/_load_local_state_dict``: this pair of APIs return
  222. and load local sharded, flattened parameters. The values returned
  223. by ``_local_state_dict`` can be directly used by FSDP and is only
  224. meaningful to FSDP (because parameters are flattened). Note that
  225. these APIs are meant for use via the :func:`state_dict_type`
  226. context manager as follows:
  227. >>> # xdoctest: +SKIP("undefined variables")
  228. >>> with fsdp.state_dict_type(StateDictType.LOCAL_STATE_DICT):
  229. ... state = fsdp.state_dict() # loads local state dict
  230. 3. ``_sharded_state_dict/_load_sharded_state_dict``: this pair of APIs
  231. return and load sharded, unflattened parameters. The ``state_dict``
  232. return by ``sharded_state_dict`` can be used by all other parallel
  233. schemes (resharding may be required).
  234. """
  235. FULL_STATE_DICT = auto()
  236. LOCAL_STATE_DICT = auto()
  237. SHARDED_STATE_DICT = auto()
  238. @dataclass
  239. class StateDictConfig:
  240. """
  241. ``StateDictConfig`` is the base class for all ``state_dict`` configuration
  242. classes. Users should instantiate a child class (e.g.
  243. ``FullStateDictConfig``) in order to configure settings for the
  244. corresponding ``state_dict`` type supported by FSDP.
  245. Attributes:
  246. offload_to_cpu (bool): If ``True``, then FSDP offloads the state dict
  247. values to CPU, and if ``False``, then FSDP keeps them on GPU.
  248. (Default: ``False``)
  249. """
  250. offload_to_cpu: bool = False
  251. @dataclass
  252. class FullStateDictConfig(StateDictConfig):
  253. """
  254. ``FullStateDictConfig`` is a config class meant to be used with
  255. ``StateDictType.FULL_STATE_DICT``. We recommend enabling both
  256. ``offload_to_cpu=True`` and ``rank0_only=True`` when saving full state
  257. dicts to save GPU memory and CPU memory, respectively. This config class
  258. is meant to be used via the :func:`state_dict_type` context manager as
  259. follows:
  260. >>> # xdoctest: +SKIP("undefined variables")
  261. >>> from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
  262. >>> fsdp = FSDP(model, auto_wrap_policy=...)
  263. >>> cfg = FullStateDictConfig(offload_to_cpu=True, rank0_only=True)
  264. >>> with FSDP.state_dict_type(fsdp, StateDictType.FULL_STATE_DICT, cfg):
  265. >>> state = fsdp.state_dict()
  266. >>> # `state` will be empty on non rank 0 and contain CPU tensors on rank 0.
  267. >>> # To reload checkpoint for inference, finetuning, transfer learning, etc:
  268. >>> model = model_fn() # Initialize model in preparation for wrapping with FSDP
  269. >>> if dist.get_rank() == 0:
  270. >>> # Load checkpoint only on rank 0 to avoid memory redundancy
  271. >>> state_dict = torch.load("my_checkpoint.pt")
  272. >>> model.load_state_dict(state_dict)
  273. >>> # All ranks initialize FSDP module as usual. `sync_module_states` argument
  274. >>> # communicates loaded checkpoint states from rank 0 to rest of the world.
  275. >>> fsdp = FSDP(
  276. ... model,
  277. ... device_id=torch.cuda.current_device(),
  278. ... auto_wrap_policy=...,
  279. ... sync_module_states=True,
  280. ... )
  281. >>> # After this point, all ranks have FSDP model with loaded checkpoint.
  282. Attributes:
  283. rank0_only (bool): If ``True``, then only rank 0 saves the full state
  284. dict, and nonzero ranks save an empty dict. If ``False``, then all
  285. ranks save the full state dict. (Default: ``False``)
  286. """
  287. rank0_only: bool = False
  288. @dataclass
  289. class LocalStateDictConfig(StateDictConfig):
  290. pass
  291. @dataclass
  292. class ShardedStateDictConfig(StateDictConfig):
  293. """
  294. ``ShardedStateDictConfig`` is a config class meant to be used with
  295. ``StateDictType.SHARDED_STATE_DICT``.
  296. Attributes:
  297. _use_dtensor (bool): If ``True``, then FSDP saves the state dict values
  298. as ``DTensor``, and if ``False``, then FSDP saves them as
  299. ``ShardedTensor``. (Default: ``False``)
  300. .. warning:: ``_use_dtensor`` is a private field of :class:`ShardedStateDictConfig`
  301. and it is used by FSDP to determine the type of state dict values. Users should not
  302. manually modify ``_use_dtensor``.
  303. """
  304. _use_dtensor: bool = False
  305. @dataclass
  306. class OptimStateDictConfig:
  307. """
  308. ``OptimStateDictConfig`` is the base class for all ``optim_state_dict``
  309. configuration classes. Users should instantiate a child class (e.g.
  310. ``FullOptimStateDictConfig``) in order to configure settings for the
  311. corresponding ``optim_state_dict`` type supported by FSDP.
  312. Attributes:
  313. offload_to_cpu (bool): If ``True``, then FSDP offloads the state dict's
  314. tensor values to CPU, and if ``False``, then FSDP keeps them on the
  315. original device (which is GPU unless parameter CPU offloading is
  316. enabled). (Default: ``True``)
  317. """
  318. offload_to_cpu: bool = True
  319. @dataclass
  320. class FullOptimStateDictConfig(OptimStateDictConfig):
  321. """
  322. Attributes:
  323. rank0_only (bool): If ``True``, then only rank 0 saves the full state
  324. dict, and nonzero ranks save an empty dict. If ``False``, then all
  325. ranks save the full state dict. (Default: ``False``)
  326. """
  327. rank0_only: bool = False
  328. @dataclass
  329. class LocalOptimStateDictConfig(OptimStateDictConfig):
  330. offload_to_cpu: bool = False
  331. @dataclass
  332. class ShardedOptimStateDictConfig(OptimStateDictConfig):
  333. """
  334. ``ShardedOptimStateDictConfig`` is a config class meant to be used with
  335. ``StateDictType.SHARDED_STATE_DICT``.
  336. Attributes:
  337. _use_dtensor (bool): If ``True``, then FSDP saves the state dict values
  338. as ``DTensor``, and if ``False``, then FSDP saves them as
  339. ``ShardedTensor``. (Default: ``False``)
  340. .. warning:: ``_use_dtensor`` is a private field of :class:`ShardedOptimStateDictConfig`
  341. and it is used by FSDP to determine the type of state dict values. Users should not
  342. manually modify ``_use_dtensor``.
  343. """
  344. _use_dtensor: bool = False
  345. @dataclass
  346. class StateDictSettings:
  347. state_dict_type: StateDictType
  348. state_dict_config: StateDictConfig
  349. optim_state_dict_config: OptimStateDictConfig