wrap.py 22 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596
  1. # mypy: allow-untyped-defs
  2. # Copyright (c) Facebook, Inc. and its affiliates.
  3. #
  4. # This source code is licensed under the BSD license found in the
  5. # LICENSE file in the root directory of this source tree.
  6. import contextlib
  7. import copy
  8. from abc import ABC, abstractmethod
  9. from collections.abc import Generator, Iterable, Sequence
  10. from typing import Any, Callable, cast, Optional, Union
  11. import torch.nn as nn
  12. __all__ = [
  13. "always_wrap_policy",
  14. "lambda_auto_wrap_policy",
  15. "transformer_auto_wrap_policy",
  16. "size_based_auto_wrap_policy",
  17. "enable_wrap",
  18. "wrap",
  19. "CustomPolicy",
  20. "ModuleWrapPolicy",
  21. ]
  22. # NOTE: We intentionally keep this function simple and isolate the complexity
  23. # to `fn` to enable using this function generically. We may move this to a
  24. # non-FSDP-specific folder and/or make it public in the future.
  25. def _post_order_apply(
  26. root_module: nn.Module,
  27. fn: Callable[[nn.Module], Optional[nn.Module]],
  28. ):
  29. """
  30. This applies ``fn`` to every module in the module tree of ``root_module``
  31. following a post-order traversal. If ``fn`` returns an :class:`nn.Module`,
  32. then this replaces the original module with the newly returned one in the
  33. tree. Otherwise, ``fn`` should return ``None``, in which case the module is
  34. not changed.
  35. """
  36. # Track visited modules to avoid visiting shared modules multiple times
  37. visited_modules: set[nn.Module] = {root_module}
  38. def _post_order_apply_inner(
  39. module: nn.Module,
  40. module_name: str,
  41. parent_module: Optional[nn.Module],
  42. ):
  43. for child_module_name, child_module in module.named_children():
  44. if child_module not in visited_modules:
  45. visited_modules.add(child_module)
  46. _post_order_apply_inner(child_module, child_module_name, module)
  47. optional_module = fn(module)
  48. if optional_module is not None:
  49. assert isinstance(parent_module, nn.Module), (
  50. "Non-root modules should have their parent module set but got "
  51. f"{parent_module} for {module}"
  52. )
  53. assert module_name, (
  54. "Non-root modules should have their module name set but got "
  55. f"an empty module name for {module}"
  56. )
  57. assert isinstance(optional_module, nn.Module), (
  58. f"fn should return None or an nn.Module but got {optional_module}"
  59. )
  60. setattr(parent_module, module_name, optional_module)
  61. _post_order_apply_inner(root_module, "", None)
  62. def _construct_wrap_fn(
  63. root_module: nn.Module,
  64. target_module_to_kwargs: dict[nn.Module, dict[str, Any]],
  65. fsdp_fn: Callable,
  66. ) -> Callable[[nn.Module], Optional[nn.Module]]:
  67. """
  68. This constructs the "wrap" function to pass to :func:`_post_order_apply`
  69. based on ``target_module_to_kwargs``, which should be constructed from the
  70. wrapping policy.
  71. """
  72. def fn(module: nn.Module) -> Optional[nn.Module]:
  73. # Explicitly avoid wrapping the root module since for FSDP, it is
  74. # handled by the caller
  75. if module in target_module_to_kwargs and module is not root_module:
  76. kwargs = target_module_to_kwargs[module]
  77. return fsdp_fn(module, **kwargs)
  78. return None
  79. return fn
  80. def _run_mixed_precision_override_policy(
  81. root_module: nn.Module,
  82. module_classes: Iterable[type[nn.Module]],
  83. ignored_modules: set[nn.Module],
  84. root_kwargs: dict[str, Any],
  85. target_module_to_kwargs: dict[nn.Module, dict[str, Any]],
  86. ):
  87. module_classes_tuple = tuple(set(module_classes))
  88. for module in root_module.modules():
  89. if module in ignored_modules:
  90. continue
  91. elif isinstance(module, module_classes_tuple):
  92. # This policy overrides any existing policy
  93. if module not in target_module_to_kwargs:
  94. # Only inherit from the root kwargs if not already specified
  95. target_module_to_kwargs[module] = root_kwargs
  96. target_module_to_kwargs[module]["mixed_precision"] = None
  97. return target_module_to_kwargs
  98. def always_wrap_policy(*args, **kwargs) -> bool:
  99. """
  100. A simple recursive wrap policy that always returns ``True``. This means
  101. that every submodule is wrapped by the wrapper class in
  102. :func:`_recursive_wrap`.
  103. """
  104. return True
  105. class _Policy(ABC):
  106. """
  107. This defines an abstract base class that represents a policy for applying
  108. a module-level API.
  109. """
  110. @abstractmethod
  111. def _run_policy(
  112. self,
  113. root_module: nn.Module,
  114. ignored_modules: set[nn.Module],
  115. root_kwargs: dict[str, Any],
  116. ) -> dict[nn.Module, dict[str, Any]]:
  117. """
  118. This should return a dict ``target_module_to_kwargs`` that maps from
  119. each target module to wrap to its kwargs.
  120. """
  121. ...
  122. def _module_wrap_policy(
  123. module: nn.Module,
  124. recurse: bool,
  125. nonwrapped_numel: int,
  126. module_classes: set[type[nn.Module]],
  127. ) -> bool:
  128. """
  129. This auto wrap policy wraps every module that is an instance of any type in
  130. ``module_classes`` as its own FSDP instance. The root module given by
  131. ``module`` is always wrapped as an FSDP instance regardless. Since the
  132. wrapping proceeds bottom up, each FSDP instance manages the parameters in
  133. its subtree excluding any already managed by a child FSDP instance.
  134. Args:
  135. module (nn.Module): Current module being considered.
  136. recurse (bool): If ``False``, then this function must decide whether
  137. ``module`` should be wrapped as an FSDP instance or not. If
  138. ``True``, then the function is still recursing down the module
  139. tree as a part of the DFS.
  140. nonwrapped_numel (int): Parameter numel not yet wrapped.
  141. module_classes (Set[Type[nn.Module]]): Set of module classes that are
  142. wrapped as FSDP instances.
  143. Returns:
  144. ``True`` if ``recurse=True``, and whether ``module`` should be wrapped
  145. if ``recurse=False``.
  146. """
  147. if recurse:
  148. return True # always recurse
  149. return isinstance(module, tuple(module_classes))
  150. class ModuleWrapPolicy(_Policy):
  151. """
  152. This policy applies to every module of the specified module classes,
  153. passing in the kwargs given to the root.
  154. """
  155. def __init__(self, module_classes: Iterable[type[nn.Module]]):
  156. module_classes_set = set(module_classes)
  157. self._module_classes = module_classes_set
  158. self._module_classes_str = str(module_classes_set)
  159. def _run_policy(
  160. self,
  161. root_module: nn.Module,
  162. ignored_modules: set[nn.Module],
  163. root_kwargs: dict[str, Any],
  164. ) -> dict[nn.Module, dict[str, Any]]:
  165. module_classes = tuple(self._module_classes)
  166. target_module_to_kwargs: dict[nn.Module, dict[str, Any]] = {}
  167. for module in root_module.modules():
  168. if module in ignored_modules:
  169. continue
  170. elif isinstance(module, module_classes):
  171. # Shallow copy to avoid coupling changes across modules
  172. target_module_to_kwargs[module] = copy.copy(root_kwargs)
  173. return target_module_to_kwargs
  174. def __call__(self, module, recurse, *args, **kwargs):
  175. # nonwrapped_numel is not used.
  176. return _module_wrap_policy(
  177. module, recurse, nonwrapped_numel=-1, module_classes=self._module_classes
  178. )
  179. def __repr__(self) -> str:
  180. return super().__repr__() + f"({self._module_classes_str})"
  181. class CustomPolicy(_Policy):
  182. """
  183. This policy takes in a lambda function that maps a given ``nn.Module`` to
  184. either ``False``, ``True``, or a kwarg dictionary.
  185. - If the function returns ``False`` or an empty dictionary, then the module
  186. does not have the API applied.
  187. - If the function returns ``True``, then the module has the API applied
  188. with the root's kwargs.
  189. - If the function returns a non-empty dictionary, then the module has the
  190. API applied, and the dictionary overrides the root's kwargs.
  191. Example::
  192. >>> # xdoctest: +SKIP("undefined variables")
  193. >>> model = init_transformer_model(...)
  194. >>> def lambda_fn(module: nn.Module):
  195. >>> if module is model.lm_head:
  196. >>> return {"sharding_strategy": ShardingStrategy.SHARD_GRAD_OP}
  197. >>> elif isinstance(module, TransformerBlock):
  198. >>> return True
  199. >>> return False
  200. >>> policy = CustomPolicy(lambda_fn)
  201. >>> fsdp_model = FSDP(model, auto_wrap_policy=policy)
  202. """
  203. def __init__(self, lambda_fn: Callable[[nn.Module], Union[bool, dict[str, Any]]]):
  204. self._lambda_fn = lambda_fn
  205. def _run_policy(
  206. self,
  207. root_module: nn.Module,
  208. ignored_modules: set[nn.Module],
  209. root_kwargs: dict[str, Any],
  210. ) -> dict[nn.Module, dict[str, Any]]:
  211. target_module_to_kwargs: dict[nn.Module, dict[str, Any]] = {}
  212. for module in root_module.modules():
  213. if module in ignored_modules:
  214. continue
  215. res = self._lambda_fn(module)
  216. if not isinstance(res, (dict, bool)):
  217. raise ValueError(
  218. "The lambda_fn passed to CustomPolicy should return "
  219. f"False/True or a kwarg dict, but it returned {res}"
  220. )
  221. if not res:
  222. continue
  223. kwargs = copy.copy(root_kwargs)
  224. if isinstance(res, dict):
  225. # Override the root kwargs with the ones specified by the
  226. # lambda function
  227. kwargs.update(res)
  228. target_module_to_kwargs[module] = kwargs
  229. return target_module_to_kwargs
  230. def lambda_auto_wrap_policy(
  231. module: nn.Module, recurse: bool, nonwrapped_numel: int, lambda_fn: Callable
  232. ) -> bool:
  233. """
  234. A convenient auto wrap policy to wrap submodules based on an arbitrary user
  235. function. If `lambda_fn(submodule) == True``, the submodule will be wrapped as
  236. a `wrapper_cls` unit.
  237. Return if a module should be wrapped during auto wrapping.
  238. The first three parameters are required by :func:`_recursive_wrap`.
  239. Args:
  240. module (nn.Module): Current module being considered.
  241. recurse (bool): If ``False``, then this function must decide whether
  242. ``module`` should be wrapped as an FSDP instance or not. If
  243. ``True``, then the function is still recursing down the module
  244. tree as a part of the DFS.
  245. nonwrapped_numel (int): Parameter numel not yet wrapped.
  246. lambda_fn (Callable[[nn.Module], bool]): If this returns ``True``, then
  247. this module will be wrapped.
  248. """
  249. if recurse:
  250. return True # always recurse
  251. return lambda_fn(module)
  252. def transformer_auto_wrap_policy(
  253. module: nn.Module,
  254. recurse: bool,
  255. nonwrapped_numel: int,
  256. transformer_layer_cls: set[type[nn.Module]],
  257. ) -> bool:
  258. """
  259. See :func:`_module_wrap_policy`, where ``transformer_layer_cls`` is the
  260. same as ``module_classes``. Note that shared parameters must be wrapped in
  261. the same FSDP instance, so this auto wrap policy can help wrap shared
  262. embeddings into the same FSDP instance for transformer models.
  263. """
  264. return _module_wrap_policy(module, recurse, nonwrapped_numel, transformer_layer_cls)
  265. def _wrap_module_cls_individually(
  266. module: nn.Module, module_classes: Sequence[type], recurse: bool, *args, **kwargs
  267. ):
  268. if recurse:
  269. # always recurse
  270. return True
  271. else:
  272. # if not recursing, decide whether we should wrap based on whether the type of module
  273. # is in `module_classes`.
  274. return isinstance(module, tuple(module_classes))
  275. def _or_policy(
  276. module: nn.Module,
  277. recurse: bool,
  278. nonwrapped_numel: int,
  279. policies,
  280. ) -> bool:
  281. """
  282. A policy that wraps ``module`` if any policy in the passed in iterable of
  283. ``policies`` returns ``True``.
  284. """
  285. return any(
  286. policy(module=module, recurse=recurse, nonwrapped_numel=nonwrapped_numel)
  287. for policy in policies
  288. )
  289. def size_based_auto_wrap_policy(
  290. module: nn.Module,
  291. recurse: bool,
  292. nonwrapped_numel: int,
  293. # Additional custom arguments
  294. min_num_params: int = int(1e8),
  295. force_leaf_modules: Optional[set[type[nn.Module]]] = None,
  296. exclude_wrap_modules: Optional[set[type[nn.Module]]] = None,
  297. ) -> bool:
  298. """
  299. A size-based auto wrap policy.
  300. Args:
  301. module (nn.Module): Current module being considered.
  302. recurse (bool): If ``False``, then this function must decide whether
  303. ``module`` should be wrapped as an FSDP instance or not. If
  304. ``True``, then the function is still recursing down the module
  305. tree as a part of the DFS.
  306. nonwrapped_numel (int): Parameter numel not yet wrapped.
  307. min_num_params (int): Customizable policy input that controls the size
  308. threshold over which a module is ready to be wrapped. This is in
  309. units of numel.
  310. force_leaf_modules (Optional[set[type[nn.Module]]]): Set of module types to keep
  311. as leaves, i.e. their children will never be wrapped.
  312. exclude_wrap_modules (Optional[set[type[nn.Module]]]): Set of module types to be
  313. excluded in wrapping.
  314. Returns:
  315. Whether ``module`` should be wrapped.
  316. """
  317. force_leaf_modules = (
  318. size_based_auto_wrap_policy.FORCE_LEAF_MODULES # type: ignore[attr-defined]
  319. if force_leaf_modules is None
  320. else force_leaf_modules
  321. )
  322. exclude_wrap_modules = (
  323. size_based_auto_wrap_policy.EXCLUDE_WRAP_MODULES # type: ignore[attr-defined]
  324. if exclude_wrap_modules is None
  325. else exclude_wrap_modules
  326. )
  327. # Keep the argument `min_num_params` for BC for now, but it represents the
  328. # minimum non-wrapped *numel* before triggering a wrapping
  329. min_nonwrapped_numel = min_num_params
  330. is_large = nonwrapped_numel >= min_nonwrapped_numel
  331. if recurse:
  332. # We should recurse if the module is big enough but not in force_leaf_modules list.
  333. return is_large and not isinstance(module, tuple(force_leaf_modules))
  334. else:
  335. # If we are not recursing, determine if we should wrap.
  336. return is_large and not isinstance(module, tuple(exclude_wrap_modules))
  337. # Set those defaults to the size_based_auto_wrap_policy function. Make them easy to be imported.
  338. size_based_auto_wrap_policy.EXCLUDE_WRAP_MODULES = {nn.ModuleList, nn.ModuleDict} # type: ignore[attr-defined]
  339. size_based_auto_wrap_policy.FORCE_LEAF_MODULES = {nn.MultiheadAttention} # type: ignore[attr-defined]
  340. @contextlib.contextmanager
  341. def enable_wrap(
  342. *, wrapper_cls: Any, **wrapper_kwargs: Any
  343. ) -> Generator[None, None, None]:
  344. """
  345. Context manager to wrap modules using a wrapper.
  346. Useful for when you'd like to apply the same configuration arguments to all
  347. child modules that you wrap. A particularly important use case is wrapping
  348. large layers so that they get sharded (in-place) during initialization, to
  349. avoid running out of system memory. Large layers can indicate that they
  350. should be sharded via the ``wrap`` annotation and this context manager can
  351. provide the exact configuration for these nested instances.
  352. Usage::
  353. with enable_wrap(wrapper_cls, **params):
  354. # Wraps layer in FSDP by default if within context
  355. self.l1 = wrap(torch.nn.Linear(5, 5))
  356. Args:
  357. wrapper_cls:
  358. Class that `wrap` annotation will `wrap` modules with, such as
  359. `FullyShardedDataParallel`.
  360. **wrapper_kwargs:
  361. Configuration settings that will be passed to all ``wrap``
  362. instances inside the context
  363. """
  364. kwargs = {
  365. "wrapper_cls": wrapper_cls,
  366. **wrapper_kwargs,
  367. }
  368. with _ConfigAutoWrap(**kwargs):
  369. yield
  370. def wrap(module: nn.Module, **wrap_overrides: Any) -> nn.Module:
  371. """
  372. Annotate that a module should be wrapped. Annotated modules will only be
  373. wrapped if inside of an :func:`enable_wrap` context manager. This allows
  374. a module to be initialized both with and without a wrapper without code
  375. change.
  376. The class that this function wraps the passed in ``nn.Module`` with is the
  377. passed in ``wrapper_cls`` argument into ``enable_wrap``. Both
  378. ``enable_wrap`` and ``wrap`` can take in kwargs specifying how to construct
  379. the ``wrapper_cls`` instance. In the case of duplicate kwargs in
  380. ``enable_wrap`` and ``wrap``, the argument passed into ``wrap`` will be
  381. respected.
  382. Usage::
  383. with enable_wrap(wrapper_cls=FSDP, **fsdp_config):
  384. # Wraps layer in FSDP by default if within context
  385. self.l1 = wrap(torch.nn.Linear(5, 5))
  386. Args:
  387. module (nn.Module): module to wrap (if in :func:`enable_wrap` context)
  388. **wrap_overrides: configuration overrides that will take priority over
  389. the values provided by the :func:`enable_wrap` context
  390. """
  391. if _ConfigAutoWrap.in_autowrap_context:
  392. assert _ConfigAutoWrap.wrapper_cls is not None
  393. wrap_overrides = {**_ConfigAutoWrap.kwargs, **wrap_overrides}
  394. return _wrap(
  395. module,
  396. _ConfigAutoWrap.wrapper_cls,
  397. **wrap_overrides,
  398. )
  399. return module
  400. def _wrap(module: nn.Module, wrapper_cls: Callable, **kwargs) -> nn.Module:
  401. assert wrapper_cls is not None
  402. if hasattr(module, "_wrap_overrides"):
  403. # If module has a _wrap_overrides attribute, we force overriding the
  404. # FSDP config with these attributes for this module. Currently this
  405. # is only used to disable mixed precision for BatchNorm when
  406. # auto_wrapping.
  407. overrides = {**kwargs, **module._wrap_overrides} # type: ignore[arg-type, dict-item]
  408. return wrapper_cls(module, **overrides)
  409. return wrapper_cls(module, **kwargs)
  410. def _recursive_wrap(
  411. module: nn.Module,
  412. auto_wrap_policy: Callable,
  413. wrapper_cls: Callable,
  414. ignored_modules: set[nn.Module],
  415. ignored_params: set[nn.Parameter],
  416. only_wrap_children: bool = False,
  417. **kwargs: Any,
  418. ) -> tuple[nn.Module, int]:
  419. """
  420. Wraps submodules of ``module`` for which ``auto_wrap_policy`` returns
  421. ``True`` with ``wrapper_cls``.
  422. Args:
  423. module (nn.Module): Module to recursively wrap.
  424. auto_wrap_policy (Callable): A callable representing a policy that
  425. determines which modules to recursively wrap with ``wrapper_cls``.
  426. ignored_modules (set[torch.nn.Module]): Modules to ignore when
  427. wrapping.
  428. ignored_params (set[torch.nn.Parameter]): Parameters to ignore when
  429. wrapping; these should be the parameters contained in the modules
  430. in ``ignored_modules``.
  431. Returns:
  432. (nn.Module, int):
  433. ``module`` after wrapping and the numel recursively wrapped.
  434. """
  435. assert auto_wrap_policy is not None, "Must specify auto_wrap_policy."
  436. assert wrapper_cls is not None, "Must specify wrapper_cls"
  437. # Make sure no child is already wrapped.
  438. for _, child in module.named_modules():
  439. if child in ignored_modules:
  440. continue
  441. try:
  442. assert not isinstance(child, cast(type, wrapper_cls))
  443. except TypeError:
  444. # wrapper_cls is a function as opposed to a class type, just bypass above check.
  445. pass
  446. # We count all params, assuming none of them are already wrapped.
  447. nonwrapped_numel = sum(
  448. p.numel() for p in module.parameters() if p not in ignored_params
  449. )
  450. assert auto_wrap_policy is not None
  451. if auto_wrap_policy(module=module, recurse=True, nonwrapped_numel=nonwrapped_numel):
  452. total_wrapped_numel = 0
  453. # Iterate through the children, recursively wrap if necessary
  454. for name, child in module.named_children():
  455. if child in ignored_modules:
  456. continue
  457. wrapped_child, num_wrapped_params = _recursive_wrap(
  458. module=child,
  459. auto_wrap_policy=auto_wrap_policy,
  460. wrapper_cls=wrapper_cls,
  461. ignored_modules=ignored_modules,
  462. ignored_params=ignored_params,
  463. **kwargs,
  464. )
  465. setattr(module, name, wrapped_child)
  466. # Keep track of how many parameters have been wrapped
  467. total_wrapped_numel += num_wrapped_params
  468. # decide if we need to wrap the current module,
  469. # since the left over parameters exceed the number of params to wrap
  470. remainder = nonwrapped_numel - total_wrapped_numel
  471. if not only_wrap_children and auto_wrap_policy(
  472. module=module, recurse=False, nonwrapped_numel=remainder
  473. ):
  474. # Leaf node or final wrapping of the remainder both happen here.
  475. return _wrap(module, wrapper_cls, **kwargs), nonwrapped_numel
  476. else:
  477. return module, total_wrapped_numel
  478. return module, 0
  479. class _ConfigAutoWrap:
  480. """
  481. Helper class to wrap modules based on default config args via a context manager.
  482. See :func:`enable_wrap` for more information.
  483. """
  484. in_autowrap_context: bool = False # Context flag
  485. wrapper_cls: Optional[Callable] = None # The wrapper class
  486. kwargs: dict[str, Any] = {} # Wrapper's args
  487. def __init__(self, **kwargs: dict[str, Any]):
  488. self.kwargs = kwargs
  489. @staticmethod
  490. def enable_autowrap_context(kwargs: Any) -> None:
  491. if _ConfigAutoWrap.in_autowrap_context:
  492. raise NotImplementedError(
  493. "You are already within an autowrap context and we currently do not supported nested autowrap."
  494. )
  495. _ConfigAutoWrap.in_autowrap_context = True
  496. # Get and save the wrapper cls for the context.
  497. assert "wrapper_cls" in kwargs.keys(), (
  498. "Expected to pass in wrapper_cls arg into _ConfigAutoWrap."
  499. )
  500. _ConfigAutoWrap.wrapper_cls = cast(Callable, kwargs["wrapper_cls"])
  501. del kwargs["wrapper_cls"]
  502. # Save the rest.
  503. _ConfigAutoWrap.kwargs = kwargs
  504. @staticmethod
  505. def disable_autowrap_context() -> None:
  506. _ConfigAutoWrap.in_autowrap_context = False
  507. _ConfigAutoWrap.wrapper_cls = None
  508. _ConfigAutoWrap.kwargs = {}
  509. def __enter__(self) -> None:
  510. self.enable_autowrap_context(self.kwargs)
  511. def __exit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> None:
  512. self.disable_autowrap_context()