contract.py 10 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248
  1. # mypy: allow-untyped-defs
  2. import uuid
  3. from collections import OrderedDict
  4. from functools import wraps
  5. from typing import Callable, Generic, Optional, Protocol
  6. from typing_extensions import Concatenate, ParamSpec, TypeVar
  7. import torch
  8. import torch.nn as nn
  9. from torch.distributed._composable_state import _State
  10. from torch.distributed.utils import _get_root_modules
# Return type of a contract-decorated callable; it appears only in the return
# position of ``_ContractFn.__call__`` and is declared covariant.
_T = TypeVar("_T", covariant=True)
# Parameter specification of the decorated callable, used to carry its
# ``*args`` / ``**kwargs`` types through ``_ContractFn.__call__``.
_P = ParamSpec("_P")
  13. def generate_state_key(string="__composable_api_state_key"):
  14. return f"{string}_{str(uuid.uuid4())}"
# Keys under which ``contract``'s wrapper stores, in each module's
# ``__dict__``, the per-API state mapping (STATE_KEY) and the registry of
# applied APIs (REGISTRY_KEY). Both are generated once when this module is
# executed and carry a random UUID suffix.
STATE_KEY = generate_state_key()
REGISTRY_KEY = generate_state_key()
  17. # TODO: we can add additional info to RegistryItem to share across APIs. E.g.,
  18. # we can add args and kwargs here, and then we can detect whether fully_shard
  19. # is combined with reentrant activation checkpointing and error out with a clear
  20. # message.
  21. class RegistryItem:
  22. pass
# Type of the state object created by ``contract``; bounded by ``_State`` and
# declared covariant.
_TState = TypeVar("_TState", bound="_State", covariant=True)
# Type of the decorated function's first argument and return value:
# constrained to either a single module or a list of modules.
_M = TypeVar("_M", nn.Module, list[nn.Module])
class _ContractFn(Protocol, Generic[_P, _T, _TState]):
    """Structural type of the callable produced by ``contract``'s decorator.

    The object is callable with parameters ``_P`` returning ``_T`` and also
    exposes a ``state`` method returning a ``_TState``.
    """

    # Invoke the decorated API with its original parameters.
    def __call__(self, *args: _P.args, **kwargs: _P.kwargs) -> _T: ...

    # Look up the state object associated with ``module`` for this API.
    def state(self, module: nn.Module) -> _TState: ...
  28. def contract(
  29. state_cls: type[_TState] = _State, # type: ignore[assignment]
  30. ) -> Callable[
  31. [Callable[Concatenate[_M, _P], _M]],
  32. _ContractFn[Concatenate[_M, _P], _M, _TState],
  33. ]:
  34. r"""
  35. Decorate a function as a composable distributed API, where the first
  36. argument of the function must be an :class:`nn.Module` instance or sequence
  37. of :class:`nn.Module` instances.
  38. The decorator verifies that the decorated function does not modify
  39. fully-qualified names (FQNs) for parameters, buffers, or modules. The
  40. decorated function can return different module instances than the input
  41. modules; the FQN invariant will be enforced following the input order.
  42. When a function ``func`` is decorated by ``@contract()``, a
  43. ``.state(module: nn.Module)`` method will be installed to the decorated
  44. function. Then you can retrieve and modify the state on a module by calling
  45. ``func.state(module)``.
  46. Example::
  47. >>> # xdoctest: +SKIP
  48. >>> import torch.nn as nn
  49. >>>
  50. >>> class MyModel(nn.Module):
  51. >>> def __init__(self) -> None:
  52. >>> super().__init__()
  53. >>> self.l1 = nn.Linear(10, 10)
  54. >>> self.l2 = nn.Linear(10, 10)
  55. >>>
  56. >>> def forward(self, x):
  57. >>> return self.l2(self.l1(x))
  58. >>>
  59. >>> @contract()
  60. >>> def my_feature(module: nn.Module) -> nn.Module:
  61. >>> my_feature.state(module).some_state = "any value"
  62. >>> return module
  63. >>>
  64. >>> model = MyModel()
  65. >>> my_feature(model.l1)
  66. >>> assert my_feature.state(model.l1).some_state == "any value"
  67. >>> my_feature(model.l2)
  68. >>> model(torch.randn(2, 10)).sum().backward()
  69. """
  70. # wraps will make functions decorated with contract() pickleable - needed for integration with torch.package
  71. @wraps(state_cls) # type: ignore[arg-type]
  72. def inner(
  73. func: Callable[Concatenate[_M, _P], _M],
  74. ) -> _ContractFn[Concatenate[_M, _P], _M, _TState]:
  75. @wraps(func)
  76. def wrapper(
  77. module: _M,
  78. *args: _P.args,
  79. **kwargs: _P.kwargs,
  80. ) -> _M:
  81. inp_module = module
  82. modules: list[nn.Module]
  83. if isinstance(module, nn.Module):
  84. modules = [module]
  85. else:
  86. # If the user passes a sequence of modules, then we assume that
  87. # we only need to insert the state object on the root modules
  88. # (i.e. those without a parent) among the passed-in modules.
  89. modules = _get_root_modules(list(module))
  90. state = state_cls() # shared across all modules
  91. registry_item = RegistryItem() # shared across all modules
  92. # `func` is allowed to return different module instances than the
  93. # input modules as long as FQNs are preserved following the input
  94. # module order
  95. all_orig_named_params: list[dict[str, nn.Parameter]] = []
  96. all_orig_named_buffers: list[dict[str, torch.Tensor]] = []
  97. all_orig_named_modules: list[dict[str, nn.Module]] = []
  98. for module in modules:
  99. default_all_state: dict[Callable, _State] = OrderedDict()
  100. default_registry: dict[str, RegistryItem] = OrderedDict()
  101. all_state: dict[Callable, _State] = module.__dict__.setdefault( # type: ignore[call-overload]
  102. STATE_KEY, default_all_state
  103. )
  104. if not isinstance(all_state, dict):
  105. raise AssertionError(
  106. f"Distributed composable API states corrupted: {all_state}"
  107. )
  108. registry: dict[str, RegistryItem] = module.__dict__.setdefault( # type: ignore[call-overload]
  109. REGISTRY_KEY, default_registry
  110. )
  111. if not isinstance(registry, dict):
  112. raise AssertionError(
  113. f"Distributed composable API registry corrupted: {registry}"
  114. )
  115. if func in all_state or func.__name__ in registry:
  116. raise AssertionError(
  117. "Each distinct composable distributed API can only be applied to a "
  118. f"module once. {func.__name__} has already been applied to the "
  119. f"following module:\n{module}"
  120. )
  121. all_state.setdefault(func, state)
  122. registry.setdefault(func.__name__, registry_item)
  123. all_orig_named_params.append(OrderedDict(module.named_parameters()))
  124. all_orig_named_buffers.append(OrderedDict(module.named_buffers()))
  125. all_orig_named_modules.append(OrderedDict(module.named_modules()))
  126. updated = func(inp_module, *args, **kwargs)
  127. if updated is None:
  128. updated = inp_module # type: ignore[assignment]
  129. updated_modules: list[nn.Module]
  130. if isinstance(updated, nn.Module):
  131. updated_modules = [updated]
  132. else:
  133. updated_modules = _get_root_modules(list(inp_module)) # type: ignore[arg-type, call-overload]
  134. all_new_named_params: list[dict[str, nn.Parameter]] = []
  135. all_new_named_buffers: list[dict[str, torch.Tensor]] = []
  136. all_new_named_modules: list[dict[str, nn.Module]] = []
  137. for module in updated_modules:
  138. all_new_named_params.append(OrderedDict(module.named_parameters()))
  139. all_new_named_buffers.append(OrderedDict(module.named_buffers()))
  140. all_new_named_modules.append(OrderedDict(module.named_modules()))
  141. num_orig_modules = len(all_orig_named_modules)
  142. num_new_modules = len(all_new_named_modules)
  143. if num_orig_modules != num_new_modules:
  144. raise AssertionError(
  145. f"{func.__name__} should return the same number of modules as input modules"
  146. f"Inputs: {num_orig_modules} modules\n"
  147. f"Outputs: {num_new_modules} modules"
  148. )
  149. def check_fqn(orig_fqns: list[str], new_fqns: list[str], check_key: str):
  150. if orig_fqns == new_fqns:
  151. return
  152. orig_fqn_set, new_fqn_set = set(orig_fqns), set(new_fqns)
  153. orig_only = orig_fqn_set - new_fqn_set
  154. new_only = new_fqn_set - orig_fqn_set
  155. if len(orig_only) or len(new_only):
  156. raise RuntimeError(
  157. f"{check_key}"
  158. "Composable distributed API implementations cannot modify FQNs.\n"
  159. f"FQNs only in original: {orig_only}\n"
  160. f"FQNs only in new: {new_only}"
  161. )
  162. else:
  163. raise RuntimeError(
  164. f"{check_key}"
  165. "Composable distributed API implementations cannot modify "
  166. "the order of FQNs.\n"
  167. f"Original FQNs: {orig_only}\n"
  168. f"New FQNs: {new_only}"
  169. )
  170. for orig_named_params, new_named_params in zip(
  171. all_orig_named_params, all_new_named_params
  172. ):
  173. check_fqn(
  174. list(orig_named_params.keys()),
  175. list(new_named_params.keys()),
  176. "Checking parameters: ",
  177. )
  178. for orig_named_buffers, new_named_buffers in zip(
  179. all_orig_named_buffers, all_new_named_buffers
  180. ):
  181. check_fqn(
  182. list(orig_named_buffers.keys()),
  183. list(new_named_buffers.keys()),
  184. "Checking buffers: ",
  185. )
  186. for orig_named_modules, new_named_modules in zip(
  187. all_orig_named_modules, all_new_named_modules
  188. ):
  189. check_fqn(
  190. list(orig_named_modules.keys()),
  191. list(new_named_modules.keys()),
  192. "Checking modules: ",
  193. )
  194. # TODO: verify that installed distributed paradigms are compatible with
  195. # each other.
  196. return updated
  197. def get_state(module: nn.Module) -> _State:
  198. return module.__dict__.setdefault( # type: ignore[call-overload]
  199. STATE_KEY,
  200. {}, # TODO(@yhcharles): this is a temporary fix, need a better way
  201. ).get(func) # type: ignore[call-overload]
  202. wrapper.state = get_state # type: ignore[attr-defined]
  203. return wrapper # type: ignore[return-value]
  204. return inner # type: ignore[return-value]
  205. def _get_registry(module: nn.Module) -> Optional[dict[str, RegistryItem]]:
  206. r"""
  207. Get an ``OrderedDict`` of composable APIs that have been applied to the
  208. ``module``, indexed by the API name. If no API has been applied, then this
  209. returns ``None``.
  210. """
  211. return getattr(module, REGISTRY_KEY, None)