distributed.py 107 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156115711581159116011611162116311641165116611671168116911701171117211731174117511761177117811791180118111821183118411851186118711881189119011911192119311941195119611971198119912001201120212031204120512061207120812091210121112121213121412151216121712181219122012211222122312241225122612271228122912301231123212331234123512361237123812391240124112421243124412451246124712481249125012511252125312541255125612571258125912601261126212631264126512661267126812691270127112721273127412751276127
71278127912801281128212831284128512861287128812891290129112921293129412951296129712981299130013011302130313041305130613071308130913101311131213131314131513161317131813191320132113221323132413251326132713281329133013311332133313341335133613371338133913401341134213431344134513461347134813491350135113521353135413551356135713581359136013611362136313641365136613671368136913701371137213731374137513761377137813791380138113821383138413851386138713881389139013911392139313941395139613971398139914001401140214031404140514061407140814091410141114121413141414151416141714181419142014211422142314241425142614271428142914301431143214331434143514361437143814391440144114421443144414451446144714481449145014511452145314541455145614571458145914601461146214631464146514661467146814691470147114721473147414751476147714781479148014811482148314841485148614871488148914901491149214931494149514961497149814991500150115021503150415051506150715081509151015111512151315141515151615171518151915201521152215231524152515261527152815291530153115321533153415351536153715381539154015411542154315441545154615471548154915501551155215531554155515561557155815591560156115621563156415651566156715681569157015711572157315741575157615771578157915801581158215831584158515861587158815891590159115921593159415951596159715981599160016011602160316041605160616071608160916101611161216131614161516161617161816191620162116221623162416251626162716281629163016311632163316341635163616371638163916401641164216431644164516461647164816491650165116521653165416551656165716581659166016611662166316641665166616671668166916701671167216731674167516761677167816791680168116821683168416851686168716881689169016911692169316941695169616971698169917001701170217031704170517061707170817091710171117121713171417151716171717181719172017211722172317241725172617271728172917301731173217331734173517361737173817391740174117421743174417451746174717481749175017511752175317541755175617571758175917601761176217631764176517661767176817691770177117721773177417751776177
71778177917801781178217831784178517861787178817891790179117921793179417951796179717981799180018011802180318041805180618071808180918101811181218131814181518161817181818191820182118221823182418251826182718281829183018311832183318341835183618371838183918401841184218431844184518461847184818491850185118521853185418551856185718581859186018611862186318641865186618671868186918701871187218731874187518761877187818791880188118821883188418851886188718881889189018911892189318941895189618971898189919001901190219031904190519061907190819091910191119121913191419151916191719181919192019211922192319241925192619271928192919301931193219331934193519361937193819391940194119421943194419451946194719481949195019511952195319541955195619571958195919601961196219631964196519661967196819691970197119721973197419751976197719781979198019811982198319841985198619871988198919901991199219931994199519961997199819992000200120022003200420052006200720082009201020112012201320142015201620172018201920202021202220232024202520262027202820292030203120322033203420352036203720382039204020412042204320442045204620472048204920502051205220532054205520562057205820592060206120622063206420652066206720682069207020712072207320742075207620772078207920802081208220832084208520862087208820892090209120922093209420952096209720982099210021012102210321042105210621072108210921102111211221132114211521162117211821192120212121222123212421252126212721282129213021312132213321342135213621372138213921402141214221432144214521462147214821492150215121522153215421552156215721582159216021612162216321642165216621672168216921702171217221732174217521762177217821792180218121822183218421852186218721882189219021912192219321942195219621972198219922002201220222032204220522062207220822092210221122122213221422152216221722182219222022212222222322242225222622272228222922302231223222332234223522362237223822392240224122422243224422452246224722482249225022512252225322542255225622572258225922602261226222632264226522662267226822692270227122722273227422752276227
7227822792280228122822283228422852286228722882289229022912292229322942295229622972298229923002301230223032304230523062307230823092310231123122313231423152316231723182319232023212322232323242325232623272328232923302331233223332334233523362337233823392340234123422343234423452346234723482349235023512352235323542355235623572358235923602361236223632364236523662367236823692370237123722373237423752376237723782379238023812382238323842385238623872388238923902391239223932394239523962397239823992400240124022403240424052406240724082409241024112412241324142415241624172418241924202421242224232424242524262427
  1. # mypy: allow-untyped-defs
  2. import copy
  3. import functools
  4. import inspect
  5. import itertools
  6. import logging
  7. import os
  8. import sys
  9. import warnings
  10. import weakref
  11. from collections import defaultdict, deque
  12. from contextlib import contextmanager
  13. from dataclasses import dataclass, fields, is_dataclass
  14. from enum import auto, Enum
  15. from typing import Any, Callable, Optional, TYPE_CHECKING
  16. import torch
  17. import torch.distributed as dist
  18. from torch._utils import _get_device_index
  19. from torch.autograd import Function, Variable
  20. from torch.distributed.algorithms.join import Join, Joinable, JoinHook
  21. from torch.nn.modules import Module
  22. from torch.nn.parallel.scatter_gather import gather, scatter_kwargs
  23. from torch.utils._pytree import tree_flatten, tree_unflatten
  24. RPC_AVAILABLE = False
  25. if dist.is_available():
  26. from torch.distributed.distributed_c10d import (
  27. _get_default_group,
  28. _rank_not_in_group,
  29. ReduceOp,
  30. )
  31. from torch.distributed.utils import (
  32. _alloc_storage,
  33. _cast_forward_inputs,
  34. _free_storage,
  35. _sync_module_states,
  36. _to_kwargs,
  37. _verify_param_shape_across_processes,
  38. )
  39. if dist.rpc.is_available():
  40. RPC_AVAILABLE = True
  41. from torch.distributed.rpc import RRef
  42. if TYPE_CHECKING:
  43. from torch.utils.hooks import RemovableHandle
  44. __all__ = ["DistributedDataParallel"]
  45. logger = logging.getLogger(__name__)
  46. @dataclass
  47. class _MixedPrecision:
  48. """
  49. This configures DDP-native mixed precision training.
  50. Attributes:
  51. param_dtype (torch.dtype): This specifies the dtype for model
  52. parameters, inputs (when ``cast_forward_inputs`` is set to
  53. ``True``), and therefore the dtype for computation.
  54. However, outside the forward and backward passes, parameters are in
  55. full precision. Model checkpointing always happens in full
  56. precision.
  57. reduce_dtype (torch.dtype): This specifies the dtype for gradient
  58. reduction, which is permitted to differ from ``param_dtype``.
  59. buffer_dtype (torch.dtype): This specifies the dtype for buffers.
  60. .. note:: This API is experimental and subject to change.
  61. .. note:: Only floating point tensors are cast to their specified dtypes.
  62. .. note:: ``state_dict`` checkpoints parameters and buffers in full
  63. precision.
  64. .. note:: Each low precision dtype must be specified explicitly. For
  65. example, ``_MixedPrecision(reduce_dtype=torch.float16)`` only specifies
  66. the reduction dtype to be low precision, and DDP will not cast
  67. parameters or buffers.
  68. .. note:: If a ``reduce_dtype`` is not specified, then gradient reduction
  69. happens in ``param_dtype`` if specified or the original parameter dtype
  70. otherwise. For example, ``_MixedPrecision(param_dtype=torch.float16)``
  71. would result in communication occurring in fp16.
  72. """
  73. param_dtype: Optional[torch.dtype] = None
  74. reduce_dtype: Optional[torch.dtype] = None
  75. buffer_dtype: Optional[torch.dtype] = None
  76. # TODO (rohan-varma): keep_low_precision_grads: bool = False
  77. # TODO (rohan-varma): APIs to allow users to run batchnorm and layernorm
  78. # in full precision. For DDP, this can be implemented by not performing the
  79. # parameter cast for BN and LN units.
  80. def _cast_buffers(mixed_precision_config, root_module):
  81. """Casts buffers to the given ``buffer_dtype``."""
  82. for buf in root_module.buffers():
  83. if hasattr(buf, "_ddp_ignored") and buf._ddp_ignored:
  84. continue
  85. buf.data = buf.to(dtype=mixed_precision_config.buffer_dtype)
  86. def _setup_mixed_precision_params(mixed_precision_config, root_module):
  87. """Create and free storage for the mixed precision parameters."""
  88. for param in root_module.parameters():
  89. # Do not setup mixed precision for DDP ignored parameters.
  90. if hasattr(param, "_ddp_ignored") and param._ddp_ignored:
  91. continue
  92. if not hasattr(param, "_mp_param"):
  93. param._mp_param = torch.zeros_like(
  94. param,
  95. device=param.device,
  96. dtype=mixed_precision_config.param_dtype,
  97. requires_grad=param.requires_grad,
  98. )
  99. _free_storage(param._mp_param)
  100. # _fp_param will point to the full precision param so it can be switched
  101. # back to at the end of forward / backward.
  102. param._fp_param = param.data
  103. def _tree_flatten_with_rref(output):
  104. output_is_rref = RPC_AVAILABLE and isinstance(output, RRef)
  105. if output_is_rref:
  106. output_tensor_list, treespec = tree_flatten(output.local_value())
  107. else:
  108. output_tensor_list, treespec = tree_flatten(output)
  109. # Need to return flattened tensors, spec to re-pack them, as well
  110. # as if the return type was actually an RRef to reconstruct.
  111. return output_tensor_list, treespec, output_is_rref
  112. def _tree_unflatten_with_rref(output, treespec, output_is_rref):
  113. output = tree_unflatten(output, treespec)
  114. if output_is_rref:
  115. output = RRef(output)
  116. return output
  117. def _find_tensors(obj):
  118. r"""Recursively find all tensors contained in the specified object."""
  119. if RPC_AVAILABLE and isinstance(obj, RRef):
  120. # If the current node is the owner of the RRef, unwrap it and try to
  121. # find Tensors.
  122. # TODO: Expand to remote RRefs.
  123. if obj.is_owner():
  124. return _find_tensors(obj.local_value())
  125. if isinstance(obj, torch.Tensor):
  126. return [obj]
  127. if isinstance(obj, (list, tuple)):
  128. return itertools.chain.from_iterable(map(_find_tensors, obj))
  129. if isinstance(obj, dict):
  130. return itertools.chain.from_iterable(map(_find_tensors, obj.values()))
  131. if is_dataclass(obj):
  132. return itertools.chain.from_iterable(
  133. map(_find_tensors, (getattr(obj, f.name) for f in fields(obj)))
  134. )
  135. return []
  136. def _dump_DDP_relevant_env_vars():
  137. relevant_env_vars = [
  138. "RANK",
  139. "LOCAL_RANK",
  140. "WORLD_SIZE",
  141. "MASTER_PORT",
  142. "MASTER_ADDR",
  143. "CUDA_VISIBLE_DEVICES",
  144. "GLOO_SOCKET_IFNAME",
  145. "GLOO_DEVICE_TRANSPORT",
  146. "NCCL_SOCKET_IFNAME",
  147. "TORCH_NCCL_BLOCKING_WAIT",
  148. "NCCL_DEBUG",
  149. "NCCL_DEBUG_SUBSYS",
  150. "NCCL_IB_DISABLE",
  151. # More NCCL env vars:
  152. "NCCL_P2P_DISABLE",
  153. "NCCL_P2P_LEVEL",
  154. "NCCL_SHM_DISABLE",
  155. "NCCL_SOCKET_NTHREADS",
  156. "NCCL_NSOCKS_PERTHREAD",
  157. "NCCL_BUFFSIZE",
  158. "NCCL_NTHREADS",
  159. "NCCL_RINGS",
  160. "NCCL_MAX_NCHANNELS",
  161. "NCCL_MIN_NCHANNELS",
  162. "NCCL_CHECKS_DISABLE",
  163. "NCCL_CHECK_POINTERS",
  164. "NCCL_LAUNCH_MODE",
  165. "NCCL_IB_HCA",
  166. "NCCL_IB_TIMEOUT",
  167. "NCCL_IB_RETRY_CNT",
  168. "NCCL_IB_GID_INDEX",
  169. "NCCL_IB_SL",
  170. "NCCL_IB_TC",
  171. "NCCL_IB_AR_THRESHOLD",
  172. "NCCL_IB_CUDA_SUPPORT",
  173. "NCCL_NET_GDR_LEVEL",
  174. "NCCL_NET_GDR_READ",
  175. "NCCL_SINGLE_RING_THRESHOLD",
  176. "NCCL_LL_THRESHOLD",
  177. "NCCL_TREE_THRESHOLD",
  178. "NCCL_ALGO",
  179. "NCCL_PROTO",
  180. "NCCL_IGNORE_CPU_AFFINITY",
  181. "NCCL_DEBUG_FILE",
  182. "NCCL_COLLNET_ENABLE",
  183. "NCCL_TOPO_FILE",
  184. "NCCL_TOPO_DUMP_FILE",
  185. "TORCH_NCCL_ASYNC_ERROR_HANDLING",
  186. ]
  187. formatted_output = ""
  188. for var in relevant_env_vars:
  189. value = os.environ[var] if var in os.environ else "N/A"
  190. formatted_output += f"env:{var}={value}\n"
  191. print(formatted_output)
  192. class _BufferCommHookLocation(Enum):
  193. PRE_FORWARD = auto()
  194. POST_FORWARD = auto()
  195. @dataclass
  196. class _BufferCommHook:
  197. buffer_comm_hook: Callable
  198. buffer_comm_hook_state: Any
  199. buffer_comm_hook_location: _BufferCommHookLocation
  200. # Add a DDPSink to run various functions when backwards starts, such as
  201. # queueing call back of out-most backward/graph task,
  202. # this helps call back is fired after all gradients' calculation
  203. # is completed.
  204. class _DDPSink(Function):
  205. @staticmethod
  206. def forward(ctx, ddp_weakref, *inputs):
  207. # set_materialize_grads(False) will ensure that None gradients stay as
  208. # None and are not filled with zeros.
  209. ctx.set_materialize_grads(False)
  210. ctx.ddp_weakref = ddp_weakref
  211. ret = inputs
  212. if ddp_weakref()._ddp_sink_clone:
  213. ret = tuple(
  214. inp.clone() if isinstance(inp, torch.Tensor) else inp for inp in inputs
  215. )
  216. return ret
  217. @staticmethod
  218. def backward(ctx, *grad_outputs):
  219. # Enqueue delay allreduce for static graph training on the first
  220. # iteration.
  221. ddp_weakref = ctx.ddp_weakref()
  222. reducer = ddp_weakref.reducer
  223. static_graph = ddp_weakref.static_graph
  224. delay_ar_enqueued = (
  225. static_graph and ddp_weakref._static_graph_delay_allreduce_enqueued
  226. )
  227. if static_graph and not delay_ar_enqueued:
  228. Variable._execution_engine.queue_callback( # type: ignore[call-arg,misc]
  229. reducer._delay_all_reduce
  230. )
  231. ddp_weakref._static_graph_delay_allreduce_enqueued = True
  232. return (None, *grad_outputs)
  233. class _DDPJoinHook(JoinHook):
  234. def __init__(self, ddp, divide_by_initial_world_size):
  235. """Set config variables for internal usage."""
  236. assert isinstance(ddp, DistributedDataParallel), (
  237. "DDP join hook requires passing in a DistributedDataParallel "
  238. "instance as the state"
  239. )
  240. assert ddp.logger is not None
  241. ddp.logger._set_uneven_input_join()
  242. self.ddp = ddp
  243. self.ddp._divide_by_initial_world_size = divide_by_initial_world_size
  244. super().__init__()
  245. def main_hook(self):
  246. """Shadow the DDP collective communication operations in the forward and backward passes."""
  247. ddp = self.ddp
  248. # Buckets are rebuilt only once during a training period
  249. ddp.reducer._rebuild_buckets()
  250. # Schedule a broadcast if we are syncing module buffers in the
  251. # forward pass
  252. # TODO: make DDP uneven inputs context manager support buffer
  253. # comm hook (https://github.com/pytorch/pytorch/issues/65436)
  254. ddp._check_and_sync_module_buffers()
  255. # Check if need to sync in the backward pass
  256. should_sync_backwards = ddp._check_global_requires_backward_grad_sync(
  257. is_joined_rank=True
  258. )
  259. # Forward parameter sync is disabled in the next iteration if we
  260. # are skipping gradient sync this iteration, so set
  261. # `require_forward_param_sync` accordingly
  262. ddp.require_forward_param_sync = should_sync_backwards
  263. if not should_sync_backwards:
  264. return
  265. # Schedule one allreduce per gradient bucket to match the backward
  266. # pass allreduce
  267. ddp._match_all_reduce_for_bwd_pass()
  268. # Check if we need to allreduce locally unused parameters
  269. if ddp.find_unused_parameters:
  270. ddp._match_unused_params_allreduce()
  271. # Rebuilt parameters are pushed only once during a training period
  272. ddp.reducer._push_all_rebuilt_params()
  273. def post_hook(self, is_last_joiner: bool):
  274. """Sync the final model to ensure that the model is the same across all processes."""
  275. self.ddp._sync_final_model(is_last_joiner)
  276. class DistributedDataParallel(Module, Joinable):
  277. r"""Implement distributed data parallelism based on ``torch.distributed`` at module level.
  278. This container provides data parallelism by synchronizing gradients
  279. across each model replica. The devices to synchronize across are
  280. specified by the input ``process_group``, which is the entire world
  281. by default. Note that ``DistributedDataParallel`` does not chunk or
  282. otherwise shard the input across participating GPUs; the user is
  283. responsible for defining how to do so, for example through the use
  284. of a :class:`DistributedSampler`.
  285. See also: :ref:`distributed-basics` and :ref:`cuda-nn-ddp-instead`.
  286. The same constraints on input as in :class:`torch.nn.DataParallel` apply.
  287. Creation of this class requires ``torch.distributed`` to be already
  288. initialized, by calling :func:`torch.distributed.init_process_group`.
  289. ``DistributedDataParallel`` is proven to be significantly faster than
  290. :class:`torch.nn.DataParallel` for single-node multi-GPU data
  291. parallel training.
  292. To use ``DistributedDataParallel`` on a host with N GPUs, you should spawn
  293. up ``N`` processes, ensuring that each process exclusively works on a single
  294. GPU from 0 to N-1. This can be done by either setting
  295. ``CUDA_VISIBLE_DEVICES`` for every process or by calling the following API for GPUs,
  296. >>> # xdoctest: +SKIP("undefined variables")
  297. >>> torch.cuda.set_device(i)
  298. or calling the unified API for :ref:`accelerator<accelerators>`,
  299. >>> # xdoctest: +SKIP("undefined variables")
  300. >>> torch.accelerator.set_device_index(i)
  301. where i is from 0 to N-1. In each process, you should refer to the following
  302. to construct this module:
  303. >>> # xdoctest: +SKIP("undefined variables")
  304. >>> if torch.accelerator.is_available():
  305. >>> device_type = torch.accelerator.current_accelerator().type
  306. >>> vendor_backend = torch.distributed.get_default_backend_for_device(device_type)
  307. >>>
  308. >>> torch.distributed.init_process_group(
  309. >>> backend=vendor_backend, world_size=N, init_method='...'
  310. >>> )
  311. >>> model = DistributedDataParallel(model, device_ids=[i], output_device=i)
  312. Or you can use the latest API for initialization:
  313. >>> torch.distributed.init_process_group(device_id=i)
  314. In order to spawn up multiple processes per node, you can use either
  315. ``torch.distributed.launch`` or ``torch.multiprocessing.spawn``.
  316. .. note::
  317. Please refer to `PyTorch Distributed Overview <https://pytorch.org/tutorials/beginner/dist_overview.html>`__
  318. for a brief introduction to all features related to distributed training.
  319. .. note::
  320. ``DistributedDataParallel`` can be used in conjunction with
  321. :class:`torch.distributed.optim.ZeroRedundancyOptimizer` to reduce
  322. per-rank optimizer states memory footprint. Please refer to
  323. `ZeroRedundancyOptimizer recipe <https://pytorch.org/tutorials/recipes/zero_redundancy_optimizer.html>`__
  324. for more details.
  325. .. note:: ``nccl`` backend is currently the fastest and highly recommended
  326. backend when using GPUs. This applies to both single-node and
  327. multi-node distributed training.
  328. .. note:: This module also supports mixed-precision distributed training.
  329. This means that your model can have different types of parameters such
  330. as mixed types of ``fp16`` and ``fp32``, the gradient reduction on these
  331. mixed types of parameters will just work fine.
  332. .. note:: If you use ``torch.save`` on one process to checkpoint the module,
  333. and ``torch.load`` on some other processes to recover it, make sure that
  334. ``map_location`` is configured properly for every process. Without
  335. ``map_location``, ``torch.load`` would recover the module to devices
  336. where the module was saved from.
  337. .. note:: When a model is trained on ``M`` nodes with ``batch=N``, the
  338. gradient will be ``M`` times smaller when compared to the same model
  339. trained on a single node with ``batch=M*N`` if the loss is summed (NOT
  340. averaged as usual) across instances in a batch (because the gradients
  341. between different nodes are averaged). You should take this into
  342. consideration when you want to obtain a mathematically equivalent
  343. training process compared to the local training counterpart. But in most
  344. cases, you can just treat a DistributedDataParallel wrapped model, a
  345. DataParallel wrapped model and an ordinary model on a single GPU as the
  346. same (E.g. using the same learning rate for equivalent batch size).
  347. .. note::
  348. Parameters are never broadcast between processes. The module performs
  349. an all-reduce step on gradients and assumes that they will be modified
  350. by the optimizer in all processes in the same way. Buffers
  351. (e.g. BatchNorm stats) are broadcast from the module in process of rank
  352. 0, to all other replicas in the system in every iteration.
  353. .. note::
  354. If you are using DistributedDataParallel in conjunction with the
  355. :ref:`distributed-rpc-framework`, you should always use
  356. :meth:`torch.distributed.autograd.backward` to compute gradients and
  357. :class:`torch.distributed.optim.DistributedOptimizer` for optimizing
  358. parameters.
  359. Example::
  360. >>> # xdoctest: +SKIP("undefined variables")
  361. >>> import torch.distributed.autograd as dist_autograd
  362. >>> from torch.nn.parallel import DistributedDataParallel as DDP
  363. >>> import torch
  364. >>> from torch import optim
  365. >>> from torch.distributed.optim import DistributedOptimizer
  366. >>> import torch.distributed.rpc as rpc
  367. >>> from torch.distributed.rpc import RRef
  368. >>>
  369. >>> t1 = torch.rand((3, 3), requires_grad=True)
  370. >>> t2 = torch.rand((3, 3), requires_grad=True)
  371. >>> rref = rpc.remote("worker1", torch.add, args=(t1, t2))
  372. >>> ddp_model = DDP(my_model)
  373. >>>
  374. >>> # Setup optimizer
  375. >>> optimizer_params = [rref]
  376. >>> for param in ddp_model.parameters():
  377. >>> optimizer_params.append(RRef(param))
  378. >>>
  379. >>> dist_optim = DistributedOptimizer(
  380. >>> optim.SGD,
  381. >>> optimizer_params,
  382. >>> lr=0.05,
  383. >>> )
  384. >>>
  385. >>> with dist_autograd.context() as context_id:
  386. >>> pred = ddp_model(rref.to_here())
  387. >>> loss = loss_func(pred, target)
  388. >>> dist_autograd.backward(context_id, [loss])
  389. >>> dist_optim.step(context_id)
  390. .. note::
  391. DistributedDataParallel currently offers limited support for gradient
  392. checkpointing with :meth:`torch.utils.checkpoint`.
  393. If the checkpoint is done with use_reentrant=False (recommended), DDP
  394. will work as expected without any limitations.
  395. If, however, the checkpoint is done with use_reentrant=True (the default),
  396. DDP will work as expected when there are no unused parameters in the model
  397. and each layer is checkpointed at most once (make sure you are not passing
  398. `find_unused_parameters=True` to DDP). We currently do not support the
  399. case where a layer is checkpointed multiple times, or when there are unused
  400. parameters in the checkpointed model.
  401. .. note::
  402. To let a non-DDP model load a state dict from a DDP model,
  403. :meth:`~torch.nn.modules.utils.consume_prefix_in_state_dict_if_present`
  404. needs to be applied to strip the prefix "module." in the DDP state dict before loading.
  405. .. warning::
  406. Constructor, forward method, and differentiation of the output (or a
  407. function of the output of this module) are distributed synchronization
  408. points. Take that into account in case different processes might be
  409. executing different code.
  410. .. warning::
  411. This module assumes all parameters are registered in the model by the
  412. time it is created. No parameters should be added nor removed later.
  413. Same applies to buffers.
  414. .. warning::
  415. This module assumes all parameters are registered in the model of each
  416. distributed processes are in the same order. The module itself will
  417. conduct gradient ``allreduce`` following the reverse order of the
  418. registered parameters of the model. In other words, it is users'
  419. responsibility to ensure that each distributed process has the exact
  420. same model and thus the exact same parameter registration order.
  421. .. warning::
  422. This module allows parameters with non-rowmajor-contiguous strides.
  423. For example, your model may contain some parameters whose
  424. :class:`torch.memory_format` is ``torch.contiguous_format``
  425. and others whose format is ``torch.channels_last``. However,
  426. corresponding parameters in different processes must have the
  427. same strides.
  428. .. warning::
  429. This module doesn't work with :func:`torch.autograd.grad` (i.e. it will
  430. only work if gradients are to be accumulated in ``.grad`` attributes of
  431. parameters).
  432. .. warning::
  433. If you plan on using this module with a ``nccl`` backend or a ``gloo``
  434. backend (that uses Infiniband), together with a DataLoader that uses
  435. multiple workers, please change the multiprocessing start method to
  436. ``forkserver`` (Python 3 only) or ``spawn``. Unfortunately
  437. Gloo (that uses Infiniband) and NCCL2 are not fork safe, and you will
  438. likely experience deadlocks if you don't change this setting.
  439. .. warning::
  440. You should never try to change your model's parameters after wrapping
  441. up your model with ``DistributedDataParallel``. Because, when
  442. wrapping up your model with ``DistributedDataParallel``, the constructor
  443. of ``DistributedDataParallel`` will register the additional gradient
  444. reduction functions on all the parameters of the model itself at the
  445. time of construction. If you change the model's parameters afterwards,
  446. gradient reduction functions no longer match the correct set of
  447. parameters.
  448. .. warning::
  449. Using ``DistributedDataParallel`` in conjunction with the
  450. :ref:`distributed-rpc-framework` is experimental and subject to change.
  451. Args:
  452. module (Module): module to be parallelized
  453. device_ids (list of int or torch.device): CUDA devices.
  454. 1) For single-device modules, ``device_ids`` can
  455. contain exactly one device id, which represents the only
  456. CUDA device where the input module corresponding to this process resides.
  457. Alternatively, ``device_ids`` can also be ``None``.
  458. 2) For multi-device modules and CPU modules,
  459. ``device_ids`` must be ``None``.
  460. When ``device_ids`` is ``None`` for both cases,
  461. both the input data for the forward pass and the actual module
  462. must be placed on the correct device.
  463. (default: ``None``)
  464. output_device (int or torch.device): Device location of output for
  465. single-device CUDA modules. For multi-device modules and
  466. CPU modules, it must be ``None``, and the module itself
  467. dictates the output location. (default: ``device_ids[0]``
  468. for single-device modules)
  469. broadcast_buffers (bool): Flag that enables syncing (broadcasting)
  470. buffers of the module at beginning of the ``forward``
  471. function. (default: ``True``)
  472. init_sync (bool): Whether to sync during initialization to verify param
  473. shapes and broadcast parameters and buffers.
  474. WARNING: if this is set to False the user is required
  475. to ensure themselves that the weights are the same on
  476. all ranks.
  477. (default: ``True``)
  478. process_group: The process group to be used for distributed data
  479. all-reduction. If ``None``, the default process group, which
  480. is created by :func:`torch.distributed.init_process_group`,
  481. will be used. (default: ``None``)
  482. bucket_cap_mb: ``DistributedDataParallel`` will bucket parameters into
  483. multiple buckets so that gradient reduction of each
  484. bucket can potentially overlap with backward computation.
  485. :attr:`bucket_cap_mb` controls the bucket size in
            mebibytes (MiB). If ``None``, a default size of 25 MiB
  487. will be used. (default: ``None``)
  488. find_unused_parameters (bool): Traverse the autograd graph from all
  489. tensors contained in the return value of the
  490. wrapped module's ``forward`` function. Parameters
  491. that don't receive gradients as part of this
  492. graph are preemptively marked as being ready to
  493. be reduced. In addition, parameters that may have
  494. been used in the wrapped module's ``forward``
  495. function but were not part of loss computation and
  496. thus would also not receive gradients are
  497. preemptively marked as ready to be reduced.
  498. (default: ``False``)
  499. check_reduction: This argument is deprecated.
  500. gradient_as_bucket_view (bool): When set to ``True``, gradients will be views
  501. pointing to different offsets of ``allreduce`` communication
  502. buckets. This can reduce peak memory usage, where the
  503. saved memory size will be equal to the total gradients
  504. size. Moreover, it avoids the overhead of copying between
  505. gradients and ``allreduce`` communication buckets. When
  506. gradients are views, ``detach_()`` cannot be called on the
  507. gradients. If hitting such errors, please fix it by
  508. referring to the :meth:`~torch.optim.Optimizer.zero_grad`
  509. function in ``torch/optim/optimizer.py`` as a solution.
            Note that gradients will be views after the first iteration, so
            the peak memory saving should be checked after the first iteration.
  512. static_graph (bool): When set to ``True``, DDP knows the trained graph is
  513. static. Static graph means 1) The set of used and unused
  514. parameters will not change during the whole training loop; in
  515. this case, it does not matter whether users set
  516. ``find_unused_parameters = True`` or not. 2) How the graph is trained
  517. will not change during the whole training loop (meaning there is
  518. no control flow depending on iterations).
  519. When static_graph is set to be ``True``, DDP will support cases that
            could not be supported in the past:
  521. 1) Reentrant backwards.
  522. 2) Activation checkpointing multiple times.
  523. 3) Activation checkpointing when model has unused parameters.
  524. 4) There are model parameters that are outside of forward function.
  525. 5) Potentially improve performance when there are unused parameters,
  526. as DDP will not search graph in each iteration to detect unused
  527. parameters when static_graph is set to be ``True``.
  528. To check whether you can set static_graph to be ``True``, one way is to
  529. check ddp logging data at the end of your previous model training,
            if ``ddp_logging_data.get("can_set_static_graph") == True``, you can
            most likely set ``static_graph = True`` as well.
  532. Example::
  533. >>> # xdoctest: +SKIP("undefined variables")
  534. >>> model_DDP = torch.nn.parallel.DistributedDataParallel(model)
  535. >>> # Training loop
  536. >>> ...
  537. >>> ddp_logging_data = model_DDP._get_ddp_logging_data()
  538. >>> static_graph = ddp_logging_data.get("can_set_static_graph")
  539. delay_all_reduce_named_params (list of tuple of str and torch.nn.Parameter): a list
  540. of named parameters whose all reduce will be delayed when the gradient of
  541. the parameter specified in ``param_to_hook_all_reduce`` is ready. Other
  542. arguments of DDP do not apply to named params specified in this argument
  543. as these named params will be ignored by DDP reducer.
  544. param_to_hook_all_reduce (torch.nn.Parameter): a parameter to hook delayed all reduce
  545. of parameters specified in ``delay_all_reduce_named_params``.
  546. skip_all_reduce_unused_params: When set to True, DDP will skip reducing unused parameters.
  547. This requires that unused parameters remain the same across all ranks throughout
  548. the entire training process. If this condition is not met, it may cause
            desynchronization and cause training to hang.
  550. Attributes:
  551. module (Module): the module to be parallelized.
  552. Example::
  553. >>> # xdoctest: +SKIP("undefined variables")
  554. >>> torch.distributed.init_process_group(backend='nccl', world_size=4, init_method='...')
  555. >>> net = torch.nn.parallel.DistributedDataParallel(model)
  556. """
    # Used to track whether the given thread is inside a DDP forward, for
    # torchdynamo purposes. The class-level default is ``None``; presumably it
    # holds the active DDP instance while its forward runs (the code that sets
    # and clears it is not visible here).
    _active_ddp_module: Optional["DistributedDataParallel"] = None
  559. def __init__(
  560. self,
  561. module,
  562. device_ids=None,
  563. output_device=None,
  564. dim=0,
  565. broadcast_buffers=True,
  566. init_sync=True,
  567. process_group=None,
  568. bucket_cap_mb=None,
  569. find_unused_parameters=False,
  570. check_reduction=False,
  571. gradient_as_bucket_view=False,
  572. static_graph=False,
  573. delay_all_reduce_named_params=None,
  574. param_to_hook_all_reduce=None,
  575. mixed_precision: Optional[_MixedPrecision] = None,
  576. device_mesh=None,
  577. skip_all_reduce_unused_params=False,
  578. ):
  579. super().__init__()
  580. Joinable.__init__(self)
  581. self._use_python_reducer = (
  582. torch._dynamo.utils.get_optimize_ddp_mode() == "python_reducer"
  583. )
  584. self.logger: Optional[dist.Logger] = None
  585. if bool(delay_all_reduce_named_params is not None) != bool(
  586. param_to_hook_all_reduce is not None
  587. ):
  588. self._log_and_throw(
  589. ValueError,
  590. "delay_all_reduce_named_params and param_to_hook_all_reduce "
  591. "need to be set at the same time.",
  592. )
  593. if process_group and device_mesh is not None:
  594. raise RuntimeError(
  595. "Cannot specify both process_group and device_mesh arguments."
  596. )
  597. elif process_group is None and device_mesh is None:
  598. self.process_group = _get_default_group()
  599. elif device_mesh is None:
  600. self.process_group = process_group
  601. else:
  602. if device_mesh.ndim != 1:
  603. raise RuntimeError(
  604. f"Only 1D device mesh is supported, but got {device_mesh}."
  605. )
  606. self.device_mesh = device_mesh
  607. self.process_group = device_mesh.get_group(mesh_dim=0)
  608. from torch.distributed.device_mesh import _mesh_resources
  609. root_mesh = _mesh_resources.get_root_mesh(device_mesh)
  610. # if a root mesh is not the same as device_mesh,
  611. # meaning the device_mesh is sliced out from the root mesh.
  612. if root_mesh != device_mesh:
  613. # TODO: This is a temporary work around to enable DDP + TP.
  614. # We should do the logic in DDP so that the 2D implementation is
  615. # sound and the state_dict works out of the box.
  616. # This has to be done before check UninitializedParameter.
  617. from torch.distributed.tensor.parallel.ddp import (
  618. _pre_dp_module_transform,
  619. )
  620. _pre_dp_module_transform(module)
  621. self._delay_all_reduce_params = []
  622. if hasattr(module, "_ddp_params_and_buffers_to_ignore"):
  623. self.parameters_to_ignore = set(module._ddp_params_and_buffers_to_ignore)
  624. else:
  625. self.parameters_to_ignore = set()
  626. if delay_all_reduce_named_params is not None:
  627. for name, param in delay_all_reduce_named_params:
  628. self.parameters_to_ignore.add(name)
  629. self._delay_all_reduce_params.append(param)
  630. self._module_parameters = [
  631. p
  632. for n, p in module.named_parameters()
  633. if n not in self.parameters_to_ignore
  634. ]
  635. if not any(p.requires_grad for p in self._module_parameters):
  636. if len(self._delay_all_reduce_params):
  637. logger.info("Delay the AllReduce of all parameters.")
  638. else:
  639. self._log_and_throw(
  640. RuntimeError,
  641. "DistributedDataParallel is not needed when a module "
  642. "doesn't have any parameter that requires a gradient.",
  643. )
  644. if device_ids is not None and len(device_ids) > 1:
  645. self._log_and_throw(
  646. ValueError,
  647. "device_ids can only be None or contain a single element.",
  648. )
  649. self.is_multi_device_module = (
  650. len({p.device for p in self._module_parameters}) > 1
  651. )
  652. distinct_device_types = {
  653. p.device.type for p in self._module_parameters if p.device is not None
  654. }
  655. if len(distinct_device_types) != 1:
  656. self._log_and_throw(
  657. ValueError,
  658. "DistributedDataParallel's input module must be on "
  659. f"the same type of devices, but input module parameters locate in {distinct_device_types}.",
  660. )
  661. self.device_type = next(iter(distinct_device_types))
  662. if (
  663. device_ids is None
  664. or len(device_ids) == 0 # For backward compatibility.
  665. or self.device_type == "cpu"
  666. or self.is_multi_device_module
  667. ):
  668. if device_ids or output_device:
  669. self._log_and_throw(
  670. ValueError,
  671. "DistributedDataParallel device_ids and output_device arguments "
  672. "only work with single-device/multiple-device GPU modules or CPU modules, "
  673. f"but got device_ids {device_ids}, output_device {output_device}, "
  674. f"and module parameters { ({p.device for p in self._module_parameters}) }.", # noqa: E201,E202
  675. )
  676. self.device_ids = None
  677. self.output_device = None
  678. else:
  679. self.device_ids = [_get_device_index(x, True) for x in device_ids]
  680. if output_device is None:
  681. output_device = device_ids[0]
  682. self.output_device = _get_device_index(output_device, True)
  683. self.static_graph = False
  684. self.dim = dim
  685. self.module = module
  686. self.device = next(iter(self._module_parameters)).device
  687. self.broadcast_buffers = broadcast_buffers
  688. self.find_unused_parameters = find_unused_parameters
  689. self.require_backward_grad_sync = True
  690. self.require_forward_param_sync = True
  691. self.gradient_as_bucket_view = gradient_as_bucket_view
  692. self.mixed_precision = mixed_precision
  693. if self.mixed_precision is not None:
  694. logger.warning("Received mixed precision config %s", self.mixed_precision)
  695. if check_reduction:
  696. # This argument is no longer used since the reducer
  697. # will ensure reduction completes even if some parameters
  698. # do not receive gradients.
  699. warnings.warn(
  700. "The `check_reduction` argument in `DistributedDataParallel` "
  701. "module is deprecated. Please avoid using it.",
  702. FutureWarning,
  703. stacklevel=2,
  704. )
  705. # Check that a module does not have Uninitialized parameters
  706. for param in self._module_parameters:
  707. if isinstance(param, torch.nn.parameter.UninitializedParameter):
  708. self._log_and_throw(
  709. RuntimeError,
  710. "Modules with uninitialized parameters can't be used with `DistributedDataParallel`. "
  711. "Run a dummy forward pass to correctly initialize the modules",
  712. )
  713. # used for intra-node param sync and inter-node sync as well
  714. self.broadcast_bucket_size = int(250 * 1024 * 1024)
  715. # reduction bucket size
  716. if bucket_cap_mb is None:
  717. # default case (bucket cap is 25 MiB)
  718. bucket_cap_mb = 25
  719. self.bucket_bytes_cap_default = True
  720. else:
  721. self.bucket_bytes_cap_default = False
  722. self.bucket_bytes_cap = int(bucket_cap_mb * 1024 * 1024)
  723. # Whether to perform input tensor CPU to GPU copies on a side-stream
  724. self.use_side_stream_for_tensor_copies = (
  725. os.environ.get("PYTORCH_DDP_USE_SIDE_STREAM", "1") == "1"
  726. )
  727. # Initialize gradient buffers and register all reduce hook
  728. self._delay_grad_buffer: Optional[torch.Tensor] = None
  729. self._delay_grad_views: list[torch.Tensor] = []
  730. self._delay_all_reduce_all_params = False
  731. if len(self._delay_all_reduce_params) != 0:
  732. self._register_delay_all_reduce_hook(
  733. bucket_cap_mb=bucket_cap_mb,
  734. param_to_hook_all_reduce=param_to_hook_all_reduce,
  735. device_ids=device_ids,
  736. )
  737. if self._delay_all_reduce_all_params:
  738. return
  739. self.skip_all_reduce_unused_params = skip_all_reduce_unused_params
  740. # Build parameters for reducer.
  741. parameters, expect_sparse_gradient = self._build_params_for_reducer()
  742. # All collectives during initialization are gated by this flag.
  743. if init_sync:
  744. # Verify model equivalence.
  745. _verify_param_shape_across_processes(self.process_group, parameters)
  746. # Sync params and buffers. Ensures all DDP models start off at the same value.
  747. _sync_module_states(
  748. module=self.module,
  749. process_group=self.process_group,
  750. broadcast_bucket_size=self.broadcast_bucket_size,
  751. src=0,
  752. params_and_buffers_to_ignore=self.parameters_to_ignore,
  753. broadcast_buffers=self.broadcast_buffers,
  754. )
  755. # In debug mode, build a mapping of parameter index -> parameter.
  756. param_to_name_mapping = self._build_debug_param_to_name_mapping(parameters)
  757. # Builds reducer.
  758. self._ddp_init_helper(
  759. parameters,
  760. expect_sparse_gradient,
  761. param_to_name_mapping,
  762. static_graph,
  763. )
  764. self._comm_hooks: list[tuple[Callable, object]] = []
  765. if self.mixed_precision is not None:
  766. _setup_mixed_precision_params(self.mixed_precision, self.module)
  767. _cast_buffers(self.mixed_precision, self.module)
  768. # Stream used for async low precision copies.
  769. self._mp_stream = torch.Stream()
  770. self._submodule_to_event = defaultdict(deque) # type: ignore[var-annotated]
  771. # Add forward pre-hook to root module to kick off copies to lower
  772. # precision.
  773. self.module.register_forward_pre_hook(
  774. self._root_copy_hook, prepend=False, with_kwargs=True
  775. )
  776. # Add forward pre hook to all submodules to wait for copy events
  777. # before running computation.
  778. for module in self.module.modules():
  779. module.register_forward_pre_hook(
  780. self._module_wait_for_copy_hook,
  781. prepend=False,
  782. with_kwargs=True,
  783. )
  784. # Set up callbacks in backward to upcast and use full precision
  785. # params. TODO (rohan-varma): Make this compose with general
  786. # comm hooks and apply_optimizer_in_backward. Importing inline to
  787. # avoid circular import issue.
  788. from torch.distributed.algorithms.ddp_comm_hooks.mixed_precision_hooks import (
  789. _AllreduceUpcastHookState,
  790. _reducer_allreduce_and_upcast_hook,
  791. )
  792. upcast_hook_state = _AllreduceUpcastHookState(
  793. ddp_weakref=weakref.ref(self),
  794. upcast_stream=torch.Stream(),
  795. )
  796. self.register_comm_hook(
  797. upcast_hook_state,
  798. _reducer_allreduce_and_upcast_hook,
  799. )
  800. # Inform reducer of reduced precision param dtype for correctness
  801. # of type checks between gradient and bucket.
  802. self.reducer._set_mixed_precision_param_dtype( # type: ignore[attr-defined]
  803. self.mixed_precision.param_dtype
  804. )
  805. self._has_rebuilt_buckets = False
  806. if static_graph:
  807. self._set_static_graph()
  808. self._lazy_init_ran = False
  809. # Register the AccumulateGrad post hooks if optimize_ddp is
  810. # True. The hooks will be deregistered if compiled_autograd is not
  811. # enabled.
  812. self._accum_grad_hooks: list[RemovableHandle] = []
  813. if self._use_python_reducer:
  814. torch._inductor.config._fuse_ddp_communication = True
  815. torch._inductor.config._fuse_ddp_bucket_size = bucket_cap_mb
  816. # Directly adding this to the trace rule will disturb the users
  817. # who are using DDPOptimizer.
  818. torch._dynamo.trace_rules.LEGACY_MOD_INLINELIST.add(
  819. "torch.nn.parallel.distributed"
  820. )
  821. torch._dynamo.trace_rules.get_legacy_mod_inlinelist.cache_clear()
  822. # NOTE: we should init these lazily
  823. self._register_accum_grad_hook()
  824. # Whether or not DDPSink performs a clone.
  825. self._ddp_sink_clone = True
  826. def _register_accum_grad_hook(self):
  827. import torch.distributed._functional_collectives as fcol
  828. def compiled_accum_grad_hook(
  829. param,
  830. *,
  831. param_index: int,
  832. ):
  833. if not self.require_backward_grad_sync:
  834. return
  835. if param.grad is None:
  836. return
  837. if self._comm_hooks:
  838. for hook, state in self._comm_hooks:
  839. hook(state, (param.grad, param))
  840. else:
  841. gradient = param.grad / self.process_group.size()
  842. gradient = fcol.all_reduce(gradient, "sum", self.process_group)
  843. param.grad.copy_(gradient)
  844. for index, param in enumerate(self._module_parameters):
  845. if not param.requires_grad:
  846. continue
  847. self._accum_grad_hooks.append(
  848. param.register_post_accumulate_grad_hook(
  849. functools.partial(
  850. compiled_accum_grad_hook,
  851. param_index=index,
  852. )
  853. )
  854. )
  855. def _delayed_all_reduce_hook(self, grad):
  856. world_size = dist.get_world_size(self.process_group)
  857. self._delay_grad_buffer.div_(world_size) # type: ignore[union-attr]
  858. _ = dist.all_reduce(
  859. self._delay_grad_buffer, group=self.process_group, async_op=True
  860. )
  861. return grad
  862. def _register_delay_all_reduce_hook(
  863. self,
  864. bucket_cap_mb,
  865. param_to_hook_all_reduce,
  866. device_ids,
  867. ):
  868. # 1. Create gradient buffer
  869. device = torch.device("cpu") if device_ids is None else device_ids[0]
  870. self._delay_grad_buffer = torch.zeros(
  871. sum(p.numel() for p in self._delay_all_reduce_params),
  872. device=device,
  873. )
  874. # 2. Broadcast the parameters
  875. detached_params = [p.detach() for p in self._delay_all_reduce_params]
  876. dist._broadcast_coalesced(self.process_group, detached_params, bucket_cap_mb, 0)
  877. # 3. Hook all reduce to the specified parameter
  878. param_to_hook_all_reduce.register_hook(self._delayed_all_reduce_hook)
  879. # 4. Build tensor views for gradients
  880. offset = 0
  881. for param in self._delay_all_reduce_params:
  882. grad_view = self._delay_grad_buffer[offset : (offset + param.numel())].view(
  883. param.shape
  884. )
  885. self._delay_grad_views.append(grad_view)
  886. offset = offset + param.numel()
  887. # 5. Check whether the all reduce of all params requiring grad is delayed.
  888. for module_name, module in self.module.named_modules():
  889. for param_name, param in module.named_parameters(recurse=False):
  890. if param.requires_grad:
  891. full_name = f"{module_name}.{param_name}"
  892. if full_name not in self.parameters_to_ignore:
  893. # There is at least a param whose all reduce will not be delayed.
  894. # In this case, we should not set self._delay_all_reduce_all_params
  895. # to True.
  896. return
  897. self._delay_all_reduce_all_params = True
  898. def _setup_in_backward_optimizers(self):
  899. # Check if user has used apply_optim_in_backward to overlap optimizer
  900. # step + DDP backward. Current constraints:
  901. # 1. Only allreduce is supported at the moment, no custom communication.
  902. # 2. For DDP-managed parameters that have their optimizer run in
  903. # backward, their gradients are set to ``None``. If your use case
  904. # requires DDP parameters grad not to be set to ``None`` after their
  905. # in-backward optimizer runs, please ping
  906. # https://github.com/pytorch/pytorch/issues/90052.
  907. # NOTE: we use self._module_parameters instead of .parameters() since
  908. # the former excludes ignored (non-DDP managed) parameters.
  909. if any(hasattr(p, "_in_backward_optimizers") for p in self._module_parameters):
  910. torch._C._log_api_usage_once("ddp.optimizer_in_backward")
  911. # Remove hooks that apply_optim_in_backward had registered because
  912. # DDP customizes how optimizer is overlapped with backward due to
  913. # the allreduce.
  914. param_to_handle_map = (
  915. dist.optim.apply_optimizer_in_backward.param_to_optim_hook_handle_map
  916. )
  917. for p in self._module_parameters:
  918. for handle in param_to_handle_map.get(p, []):
  919. handle.remove()
  920. # Need a weakref to DDP instance to run all_reduce (from reducer)
  921. # and get managed DDP parameters.
  922. ddp_weakref = weakref.ref(self)
  923. # Note: importing in function, otherwise this will cause a circular
  924. # import.
  925. from torch.distributed.algorithms.ddp_comm_hooks.optimizer_overlap_hooks import (
  926. _apply_optim_in_backward_hook,
  927. )
  928. self.register_comm_hook(
  929. ddp_weakref,
  930. _apply_optim_in_backward_hook(
  931. gradient_is_bucket_view=self.gradient_as_bucket_view
  932. ),
  933. )
  934. self.reducer._set_optimizer_in_backward() # type: ignore[attr-defined]
  935. def _fire_reducer_autograd_hook(self, idx, *unused):
  936. """
  937. Fire the reducer's autograd hook to allreduce params in a Reducer bucket.
  938. Note that this is only used during mixed precision training as the
  939. Reducer's hooks installed during construction time would not be called
  940. as we're working in the low precision parameter setting.
  941. """
  942. self.reducer._autograd_hook(idx) # type: ignore[attr-defined]
  943. def _root_copy_hook(self, *args: Any, **kwargs: Any) -> None:
  944. """
  945. For DDP mixed precision, put low precision copies on separate stream and create events to wait for them.
  946. When training with DDP mixed precision, this root pre-forward hook kicks
  947. off low precision copies on a separate stream and creates respective
  948. events to wait for them.
  949. """
  950. # Clear out previous iteration submodule to event. This is because we
  951. # may have populated some events for modules that didn't end up being
  952. # used.
  953. self._submodule_to_event = defaultdict(deque) # type: ignore[var-annotated]
  954. with self._mp_stream:
  955. for submodule in self.module.modules():
  956. for param in submodule.parameters(recurse=False):
  957. # Do not cast DDP ignored parameters.
  958. if hasattr(param, "_ddp_ignored") and param._ddp_ignored:
  959. continue
  960. _alloc_storage(param._mp_param, param.size())
  961. # copy() implicitly casts to low precision
  962. with torch.no_grad():
  963. param._mp_param.copy_(param.data)
  964. # TODO: when zero_grad(set_to_none=False) or in grad
  965. # accumulation case, accumulated grads can be in fp32
  966. # which can cause errors when running DDP backwards due
  967. # to mismatched incoming and accumulated gradient types.
  968. # So we manually cast the accumulated grad down for now,
  969. # in the future we may shift to FSDP style gradient
  970. # accumulation management where the accumulated gradient
  971. # is saved and .grad field is set to None, bypassing
  972. # this issue.
  973. if param.grad is not None:
  974. param.grad.data = param.grad.to(
  975. self.mixed_precision.param_dtype # type: ignore[union-attr]
  976. )
  977. param.data = param._mp_param
  978. copy_event = torch.Event()
  979. copy_event.record()
  980. self._submodule_to_event[submodule].append(copy_event)
  981. def _module_wait_for_copy_hook(
  982. self,
  983. module,
  984. *args: Any,
  985. **kwargs: Any,
  986. ) -> None:
  987. """Before carrying out computation, wait on the appropriate event to ensure low precision copies have finished."""
  988. try:
  989. event = self._submodule_to_event[module].popleft()
  990. except IndexError:
  991. # copy event has already been waited on
  992. return
  993. event.wait(stream=torch.accelerator.current_stream())
  994. for p in module.parameters(recurse=False):
  995. # Don't register hooks if param does not require grad
  996. if not p.requires_grad or (hasattr(p, "_ddp_ignored") and p._ddp_ignored):
  997. continue
  998. # We need to register autograd hook here instead of DDP's ctor
  999. # since we're working with the low precision param. Register them
  1000. # via obtaining the gradient accumulator.
  1001. tmp = p.expand_as(p)
  1002. grad_acc = tmp.grad_fn.next_functions[0][0]
  1003. hook = grad_acc.register_hook(
  1004. functools.partial(self._fire_reducer_autograd_hook, p._idx)
  1005. )
  1006. p._ddp_mp_hook_state = (grad_acc, hook)
  1007. def _log_and_throw(self, err_type, err_msg):
  1008. if self.logger is not None:
  1009. self.logger.set_error_and_log(f"{str(err_type)}: {err_msg}")
  1010. raise err_type(err_msg)
  1011. def _ddp_init_helper(
  1012. self,
  1013. parameters,
  1014. expect_sparse_gradient,
  1015. param_to_name_mapping,
  1016. static_graph,
  1017. ):
  1018. """
  1019. DDP init helper function to manage parameters, grad hooks, logging, and SyncBatchNorm.
  1020. Initialization helper function that does the following:
  1021. (1) bucketing the parameters for reductions
  1022. (2) resetting the bucketing states
  1023. (3) registering the grad hooks
  1024. (4) Logging construction-time DDP logging data
  1025. (5) passing a handle of DDP to SyncBatchNorm Layer
  1026. """
  1027. # Notice, the parameters order is not in the order in which they are used,
  1028. # especially in models with control flow.
  1029. #
  1030. # Alongside parameters are not presented in the real execution order,
  1031. # if a certain model happens to also
  1032. # 1) have other collectives comm ops in its backward graph.
  1033. # 2) have unused parameter in subset ranks of the whole world.
  1034. # bucketing could insert ALL-REDUCE comm op too early on the rank with unused parameter,
  1035. # matching up with other collectives comm ops on other ranks unexpectedly.
  1036. #
  1037. # In order to handle this corner case, when the parameters are not in the real execution order,
  1038. # we don't do bucketing, thus only one ALL-REDUCE is inserted after all the gradients
  1039. # of the whole graph are computed.
  1040. #
  1041. # Notice, here we only disable bucketing for the first iteration.
  1042. # After the first iteration, it's OK to rebuild buckets,
  1043. # because "bucket rebuild" bucketizes parameters based on its real execution order in backward graph.
  1044. # Can remove this branching once #73732 is landed.
  1045. if static_graph is True or self.find_unused_parameters is False:
  1046. bucket_size_limits = [sys.maxsize]
  1047. else:
  1048. if self.bucket_bytes_cap_default:
  1049. bucket_size_limits = [
  1050. dist._DEFAULT_FIRST_BUCKET_BYTES,
  1051. self.bucket_bytes_cap,
  1052. ]
  1053. else:
  1054. bucket_size_limits = [self.bucket_bytes_cap]
  1055. (
  1056. bucket_indices,
  1057. per_bucket_size_limits,
  1058. ) = dist._compute_bucket_assignment_by_size(
  1059. parameters,
  1060. bucket_size_limits,
  1061. expect_sparse_gradient,
  1062. )
  1063. # Remember index for parameters if we are in mixed precision, as we
  1064. # need to pass in index to Reducer's autograd hook via python.
  1065. if self.mixed_precision is not None:
  1066. for i, p in enumerate(parameters):
  1067. p._idx = i
  1068. # Note: reverse list of buckets because we want to approximate the
  1069. # order in which their gradients are produced, and assume they
  1070. # are used in the forward pass in the order they are defined.
  1071. self.reducer = dist.Reducer(
  1072. parameters,
  1073. list(reversed(bucket_indices)),
  1074. list(reversed(per_bucket_size_limits)),
  1075. self.process_group,
  1076. expect_sparse_gradient,
  1077. # The bucket size limit is specified in the constructor.
  1078. # Additionally, we allow for a single small bucket for parameters
  1079. # that are defined first, such that their gradients don't spill into
  1080. # a much larger bucket, adding unnecessary latency after gradient
  1081. # computation finishes. Experiments showed 1MB is a reasonable value.
  1082. self.bucket_bytes_cap,
  1083. self.find_unused_parameters,
  1084. self.gradient_as_bucket_view,
  1085. param_to_name_mapping,
  1086. # User can set dist._DEFAULT_FIRST_BUCKET_BYTES to tune DDP first
  1087. # bucket.
  1088. (
  1089. dist._DEFAULT_FIRST_BUCKET_BYTES
  1090. if self.bucket_bytes_cap_default
  1091. else self.bucket_bytes_cap
  1092. ),
  1093. self.skip_all_reduce_unused_params,
  1094. self._use_python_reducer,
  1095. )
  1096. self.logger = dist.Logger(self.reducer)
  1097. # Set as a weak reference to avoid reference cycle between
  1098. # logger and reducer.
  1099. self.reducer.set_logger(self.logger)
  1100. has_sync_bn = False
  1101. for submodule in self.module.modules():
  1102. if isinstance(submodule, torch.nn.SyncBatchNorm):
  1103. has_sync_bn = True
  1104. break
  1105. # Set logging data that can be got during construction time.
  1106. self.logger.set_construction_data_and_log(
  1107. self.module.__class__.__name__,
  1108. [] if self.device_ids is None else self.device_ids,
  1109. -1 if self.output_device is None else self.output_device,
  1110. self.broadcast_buffers,
  1111. has_sync_bn,
  1112. static_graph,
  1113. )
  1114. # passing a handle to torch.nn.SyncBatchNorm layer
  1115. self._passing_sync_batchnorm_handle(self.module)
  1116. def __getstate__(self):
  1117. self._check_default_group()
  1118. attrs = copy.copy(self.__dict__)
  1119. del attrs["process_group"]
  1120. del attrs["reducer"]
  1121. del attrs["logger"]
  1122. return attrs
  1123. def __setstate__(self, state):
  1124. # If serializable, then the process group should be the default one
  1125. self.process_group = _get_default_group()
  1126. super().__setstate__(state)
  1127. self.__dict__.setdefault("require_forward_param_sync", True)
  1128. self.__dict__.setdefault("require_backward_grad_sync", True)
  1129. parameters, expect_sparse_gradient = self._build_params_for_reducer()
  1130. # In debug mode, build a mapping of parameter index -> parameter.
  1131. param_to_name_mapping = self._build_debug_param_to_name_mapping(parameters)
  1132. # Builds reducer.
  1133. self._ddp_init_helper(
  1134. parameters,
  1135. expect_sparse_gradient,
  1136. param_to_name_mapping,
  1137. self.static_graph,
  1138. )
  1139. if self.static_graph:
  1140. self.reducer._set_static_graph()
  1141. assert self.logger is not None
  1142. self.logger._set_static_graph()
  1143. def _build_params_for_reducer(self):
  1144. # Build tuple of (module, parameter) for all parameters that require grads.
  1145. modules_and_parameters = [
  1146. (module, parameter)
  1147. for module_name, module in self.module.named_modules()
  1148. for parameter in [
  1149. param
  1150. # Note that we access module.named_parameters instead of
  1151. # parameters(module). parameters(module) is only needed in the
  1152. # single-process multi device case, where it accesses replicated
  1153. # parameters through _former_parameters.
  1154. for param_name, param in module.named_parameters(recurse=False)
  1155. if param.requires_grad
  1156. and f"{module_name}.{param_name}" not in self.parameters_to_ignore
  1157. ]
  1158. ]
  1159. # Deduplicate any parameters that might be shared across child modules.
  1160. memo = set()
  1161. modules_and_parameters = [
  1162. # "p not in memo" is the deduplication check.
  1163. # "not memo.add(p)" is always True, and it's only there to cause "add(p)" if needed.
  1164. (m, p)
  1165. for m, p in modules_and_parameters
  1166. if p not in memo and not memo.add(p) # type: ignore[func-returns-value]
  1167. ]
  1168. # Build list of parameters.
  1169. parameters = [parameter for _, parameter in modules_and_parameters]
  1170. # Checks if a module will produce a sparse gradient.
  1171. def produces_sparse_gradient(module):
  1172. if isinstance(module, (torch.nn.Embedding, torch.nn.EmbeddingBag)):
  1173. return module.sparse
  1174. return False
  1175. # Build list of booleans indicating whether or not to expect sparse
  1176. # gradients for the corresponding parameters.
  1177. expect_sparse_gradient = [
  1178. produces_sparse_gradient(module) for module, _ in modules_and_parameters
  1179. ]
  1180. self._assign_modules_buffers()
  1181. return parameters, expect_sparse_gradient
  1182. def _assign_modules_buffers(self):
  1183. """
  1184. Assign self.module.named_buffers to self.modules_buffers.
  1185. Assigns module buffers to self.modules_buffers which are then used to
  1186. broadcast across ranks when broadcast_buffers=True. Note that this
  1187. must be called every time buffers need to be synced because buffers can
  1188. be reassigned by user module,
  1189. see https://github.com/pytorch/pytorch/issues/63916.
  1190. """
  1191. # Collect buffers for modules, filtering out buffers that should be ignored.
  1192. named_module_buffers = [
  1193. (buffer, buffer_name)
  1194. for buffer_name, buffer in self.module.named_buffers()
  1195. if buffer_name not in self.parameters_to_ignore
  1196. ]
  1197. self.modules_buffers = [
  1198. buffer for (buffer, buffer_name) in named_module_buffers
  1199. ]
  1200. # Dict[str, tensor] representing module buffers not ignored by DDP.
  1201. self.named_module_buffers = {
  1202. buffer_name: buffer for (buffer, buffer_name) in named_module_buffers
  1203. }
  1204. def _build_debug_param_to_name_mapping(self, parameters):
  1205. param_to_param_index = {parameters[i]: i for i in range(len(parameters))}
  1206. param_set = set(parameters)
  1207. param_index_to_param_fqn = {}
  1208. for module_name, module in self.module.named_modules():
  1209. for param_name, param in module.named_parameters(recurse=False):
  1210. fqn = f"{module_name}.{param_name}"
  1211. # Bypass ignored parameters since those are not reduced by DDP
  1212. # to begin with.
  1213. if fqn not in self.parameters_to_ignore and param.requires_grad:
  1214. if param not in param_set:
  1215. self._log_and_throw(
  1216. ValueError,
  1217. f"Param with name {fqn} found in module parameters, but not DDP parameters."
  1218. " This indicates a bug in DDP, please report an issue to PyTorch.",
  1219. )
  1220. param_index = param_to_param_index[param]
  1221. param_index_to_param_fqn[param_index] = fqn
  1222. # Ensure we covered all parameters
  1223. if len(param_set) != len(param_index_to_param_fqn):
  1224. self._log_and_throw(
  1225. ValueError,
  1226. (
  1227. "Expected param to name mapping to cover all parameters, but"
  1228. f" got conflicting lengths: {len(param_set)} vs "
  1229. f"{len(param_index_to_param_fqn)}. This indicates a bug in DDP"
  1230. ", please report an issue to PyTorch."
  1231. ),
  1232. )
  1233. return param_index_to_param_fqn
  1234. def _get_parameters(self, m, recurse=True):
  1235. """Return a generator of module parameters."""
  1236. def model_parameters(m):
  1237. ps = (
  1238. m._former_parameters.values()
  1239. if hasattr(m, "_former_parameters")
  1240. else m.parameters(recurse=False)
  1241. )
  1242. yield from ps
  1243. for mod in m.modules() if recurse else [m]:
  1244. yield from model_parameters(mod)
  1245. def _check_default_group(self):
  1246. pickle_not_supported = False
  1247. try:
  1248. if self.process_group != _get_default_group():
  1249. pickle_not_supported = True
  1250. except RuntimeError:
  1251. pickle_not_supported = True
  1252. if pickle_not_supported:
  1253. self._log_and_throw(
  1254. RuntimeError,
  1255. "DDP Pickling/Unpickling are only supported "
  1256. "when using DDP with the default process "
  1257. "group. That is, when you have called "
  1258. "init_process_group and have not passed "
  1259. "process_group argument to DDP constructor",
  1260. )
  1261. @contextmanager
  1262. def no_sync(self):
  1263. r"""
  1264. Context manager to disable gradient synchronizations across DDP processes.
  1265. Within this context, gradients will be accumulated on module
  1266. variables, which will later be synchronized in the first
  1267. forward-backward pass exiting the context.
  1268. Example::
  1269. >>> # xdoctest: +SKIP("undefined variables")
  1270. >>> ddp = torch.nn.parallel.DistributedDataParallel(model, pg)
  1271. >>> with ddp.no_sync():
  1272. >>> for input in inputs:
  1273. >>> ddp(input).backward() # no synchronization, accumulate grads
  1274. >>> ddp(another_input).backward() # synchronize grads
  1275. .. warning::
  1276. The forward pass should be included inside the context manager, or
  1277. else gradients will still be synchronized.
  1278. """
  1279. old_require_backward_grad_sync = self.require_backward_grad_sync
  1280. self.require_backward_grad_sync = False
  1281. try:
  1282. yield
  1283. finally:
  1284. self.require_backward_grad_sync = old_require_backward_grad_sync
  1285. @classmethod
  1286. def _get_active_ddp_module(cls):
  1287. """`TorchDynamo` requires DDP's status and module for cooperative optimization."""
  1288. return cls._active_ddp_module
  1289. # note, this ctxmgr function is marked 'skip' in torchdynamo, so dynamo only kicks in
  1290. # for the 'module_to_run' underneath
  1291. # see torch._dynamo/eval_frame.py TorchPatcher.patch for more details
  1292. @contextmanager
  1293. @torch._disable_dynamo(recursive=False)
  1294. def _inside_ddp_forward(self):
  1295. DistributedDataParallel._active_ddp_module = self
  1296. try:
  1297. yield
  1298. finally:
  1299. DistributedDataParallel._active_ddp_module = None
  1300. def _run_ddp_forward(self, *inputs, **kwargs):
  1301. if self._use_python_reducer:
  1302. return self.module(*inputs, **kwargs) # type: ignore[index]
  1303. else:
  1304. with self._inside_ddp_forward():
  1305. return self.module(*inputs, **kwargs) # type: ignore[index]
  1306. def _clear_grad_buffer(self):
  1307. # Making param.grad points to the grad buffers before backward is based on the
  1308. # assumption that the grad accumulation is done in place in autograd engine,
  1309. # for some edge cases, if the grad accumulation in autograd engine is not in
  1310. # place, then the param.grad and grad buffers are detached.
  1311. if self._delay_grad_buffer is not None:
  1312. # We batch zero_grad for all params by resetting the whole grad
  1313. # buffer when the grad of all params is set to None.
  1314. all_param_grad_none = all(
  1315. param.grad is None for param in self._delay_all_reduce_params
  1316. )
  1317. for index, param in enumerate(self._delay_all_reduce_params):
  1318. if param.grad is None:
  1319. param.grad = self._delay_grad_views[index]
  1320. if not all_param_grad_none:
  1321. param.grad.zero_()
  1322. if all_param_grad_none:
  1323. self._delay_grad_buffer.zero_()
  1324. def _lazy_init(self):
  1325. # Initialization for DDP that occurs after construction, but lazily
  1326. # before the first forward pass.
  1327. self._setup_in_backward_optimizers()
  1328. self._lazy_init_ran = True
    def _pre_forward(self, *inputs, **kwargs):
        """Run DDP bookkeeping that must happen before the wrapped module's forward.

        Args:
            *inputs: Positional arguments destined for the wrapped module.
            **kwargs: Keyword arguments destined for the wrapped module.

        Returns:
            A ``(inputs, kwargs)`` pair to call the wrapped module with. It is
            returned untouched when the Python reducer is used or when
            ``self._delay_all_reduce_all_params`` is set. Otherwise, when
            ``self.device_ids`` is non-empty the pair comes from ``_to_kwargs``
            targeting ``torch.device(self.device_type, self.device_ids[0])``,
            and when ``self.mixed_precision`` is set it is additionally passed
            through ``_cast_forward_inputs`` with
            ``self.mixed_precision.param_dtype``.
        """
        # The Python reducer path skips all of the bookkeeping below.
        if self._use_python_reducer:
            return inputs, kwargs

        # One-time lazy initialization; not attempted while
        # torch.compiler.is_compiling() is True.
        if not self._lazy_init_ran and not torch.compiler.is_compiling():
            self._lazy_init()

        # With delayed all-reduce of all params, inputs pass through untouched.
        if self._delay_all_reduce_all_params:
            return inputs, kwargs

        # Only log stats and prepare the reducer on iterations that will
        # synchronize gradients in backward.
        if torch.is_grad_enabled() and self.require_backward_grad_sync:
            assert self.logger is not None
            self.logger.set_runtime_stats_and_log()
            self.reducer.prepare_for_forward()

        # Notify the join context that this process has not joined, if
        # needed
        work = Join.notify_join_context(self)
        if work:
            self.reducer._set_forward_pass_work_handle(
                work,
                self._divide_by_initial_world_size,  # type: ignore[arg-type]
            )

        # Calling _rebuild_buckets before forward computation,
        # It may allocate new buckets before deallocating old buckets
        # inside _rebuild_buckets. To save peak memory usage,
        # call _rebuild_buckets before the peak memory usage increases
        # during forward computation.
        # This should be called only once during whole training period.
        if torch.is_grad_enabled() and self.reducer._rebuild_buckets():
            logger.info("Reducer buckets have been rebuilt in this iteration.")
            self._has_rebuilt_buckets = True

        # sync params according to location (before/after forward) user
        # specified as part of hook, if hook was specified.
        if self._check_sync_bufs_pre_fwd():
            self._sync_buffers()

        if self._join_config.enable:
            # Notify joined ranks whether they should sync in backwards pass or not.
            self._check_global_requires_backward_grad_sync(is_joined_rank=False)

        if self.device_ids:
            moved_inputs, moved_kwargs = _to_kwargs(
                inputs,
                kwargs,
                torch.device(self.device_type, self.device_ids[0]),
                self.use_side_stream_for_tensor_copies,
            )
            # Only index 0 of what _to_kwargs returns is used (presumably one
            # entry per target device -- confirm against _to_kwargs).
            args, kwargs = moved_inputs[0], moved_kwargs[0]
            # Cast inputs to reduced precision if needed.
            if self.mixed_precision is not None:
                args, kwargs = _cast_forward_inputs(
                    self.mixed_precision.param_dtype,
                    *args,
                    **kwargs,
                )
            return args, kwargs
        else:
            # Cast inputs to reduced precision if needed.
            # TODO (rohan-varma) test this codepath.
            if self.mixed_precision is not None:
                inputs, kwargs = _cast_forward_inputs(
                    self.mixed_precision.param_dtype,
                    *inputs,
                    **kwargs,
                )
            return inputs, kwargs
  1390. def _post_forward(self, output):
  1391. if self._use_python_reducer:
  1392. return output
  1393. if self._delay_all_reduce_all_params:
  1394. self._clear_grad_buffer()
  1395. return output
  1396. # sync params according to location (before/after forward) user
  1397. # specified as part of hook, if hook was specified.
  1398. if self._check_sync_bufs_post_fwd():
  1399. self._sync_buffers()
  1400. if torch.is_grad_enabled() and self.require_backward_grad_sync:
  1401. self.require_forward_param_sync = True
  1402. # We'll return the output object verbatim since it is a freeform
  1403. # object. We need to find any tensors in this object, though,
  1404. # because we need to figure out which parameters were used during
  1405. # this forward pass, to ensure we short circuit reduction for any
  1406. # unused parameters. Only if `find_unused_parameters` is set.
  1407. if self.find_unused_parameters and not self.static_graph:
  1408. # Do not need to populate this for static graph.
  1409. self.reducer.prepare_for_backward(list(_find_tensors(output)))
  1410. else:
  1411. self.reducer.prepare_for_backward([])
  1412. else:
  1413. self.require_forward_param_sync = False
  1414. # TODO: DDPSink is currently enabled for unused parameter detection and
  1415. # static graph training for first iteration.
  1416. if (self.find_unused_parameters and not self.static_graph) or (
  1417. self.static_graph and not self._static_graph_delay_allreduce_enqueued
  1418. ):
  1419. (
  1420. output_tensor_list,
  1421. treespec,
  1422. output_is_rref,
  1423. ) = _tree_flatten_with_rref(output)
  1424. output_placeholders: list[Optional[torch.Tensor]] = [
  1425. None for _ in range(len(output_tensor_list))
  1426. ]
  1427. # Do not touch tensors that have no grad_fn, which can cause issues
  1428. # such as https://github.com/pytorch/pytorch/issues/60733
  1429. for i, output in enumerate(output_tensor_list):
  1430. if torch.is_tensor(output) and output.grad_fn is None:
  1431. output_placeholders[i] = output
  1432. # When find_unused_parameters=True, makes tensors which require grad
  1433. # run through the DDPSink backward pass. When not all outputs are
  1434. # used in loss, this makes those corresponding tensors receive
  1435. # undefined gradient which the reducer then handles to ensure
  1436. # param.grad field is not touched and we don't error out.
  1437. passthrough_tensor_list = _DDPSink.apply(
  1438. weakref.ref(self),
  1439. *output_tensor_list,
  1440. )
  1441. for i in range(len(output_placeholders)):
  1442. if output_placeholders[i] is None:
  1443. output_placeholders[i] = passthrough_tensor_list[i]
  1444. # Reconstruct output data structure.
  1445. output = _tree_unflatten_with_rref(
  1446. output_placeholders, treespec, output_is_rref
  1447. )
  1448. # At the end of the forward pass, reset the grad buffer and grad views
  1449. self._clear_grad_buffer()
  1450. return output
  1451. def forward(self, *inputs, **kwargs):
  1452. with torch.autograd.profiler.record_function("DistributedDataParallel.forward"):
  1453. inputs, kwargs = self._pre_forward(*inputs, **kwargs)
  1454. output = (
  1455. self.module.forward(*inputs, **kwargs)
  1456. if self._delay_all_reduce_all_params
  1457. else self._run_ddp_forward(*inputs, **kwargs)
  1458. )
  1459. return self._post_forward(output)
  1460. def scatter(self, inputs, kwargs, device_ids):
  1461. return scatter_kwargs(inputs, kwargs, device_ids, dim=self.dim)
  1462. def to_kwargs(self, inputs, kwargs, device_id):
  1463. # Kept for BC
  1464. return _to_kwargs(
  1465. inputs,
  1466. kwargs,
  1467. torch.device(self.device_type, device_id),
  1468. self.use_side_stream_for_tensor_copies,
  1469. )
  1470. def gather(self, outputs, output_device):
  1471. return gather(outputs, output_device, dim=self.dim)
  1472. def train(self, mode=True):
  1473. super().train(mode)
  1474. return self
  1475. # When running in join mode, schedules an allreduce to notify joined ranks
  1476. # of whether backwards pass synchronization will run this iteration or not.
  1477. def _check_global_requires_backward_grad_sync(self, is_joined_rank):
  1478. if not is_joined_rank and self.require_backward_grad_sync:
  1479. requires_sync_tensor = torch.ones(1, device=self.device)
  1480. else:
  1481. requires_sync_tensor = torch.zeros(1, device=self.device)
  1482. work = dist.all_reduce(
  1483. requires_sync_tensor, group=self.process_group, async_op=True
  1484. )
  1485. # (kwen2501) This if condition is a plain translation of previous
  1486. # behavior, i.e. in the `is_joined_rank=False` case, `work.wait()`
  1487. # is not called and it doesn't care about the result. I am guessing
  1488. # that it just wants to fire a matching all-reduce and does not want
  1489. # the main stream to wait.
  1490. if is_joined_rank:
  1491. work.wait()
  1492. should_sync_backwards = requires_sync_tensor.item() != 0
  1493. return should_sync_backwards
  1494. else:
  1495. return None # Return value is not/should not be used.
  1496. # When running in join mode, checks and performs sync of module buffers if
  1497. # the models have buffers that should be synchronized in the forward pass.
  1498. def _check_and_sync_module_buffers(self):
  1499. if self._check_sync_bufs_pre_fwd():
  1500. authoritative_rank = self._find_common_rank(self._distributed_rank, False)
  1501. self._sync_module_buffers(authoritative_rank)
  1502. # When running in join model, agrees upon a common rank and broadcast model
  1503. # parameters to all other ranks.
  1504. def _sync_final_model(self, is_last_joiner):
  1505. # Agree upon the process that will be the authoritative model copy.
  1506. # The current rank is a candidate for being the authoritative copy if
  1507. # is_last_joiner=True. We break ties via picking the larger rank.
  1508. self._authoritative_rank = self._find_common_rank(
  1509. self._distributed_rank, is_last_joiner
  1510. )
  1511. _sync_module_states(
  1512. module=self.module,
  1513. process_group=self.process_group,
  1514. broadcast_bucket_size=self.broadcast_bucket_size,
  1515. src=self._authoritative_rank,
  1516. params_and_buffers_to_ignore=self.parameters_to_ignore,
  1517. broadcast_buffers=self.broadcast_buffers,
  1518. )
  1519. # Schedule comm ops to match those scheduled in the reducer's backward
  1520. # pass.
  1521. def _match_all_reduce_for_bwd_pass(self):
  1522. comm_work = []
  1523. # Schedule comm in the same order as Reducer schedules them, i.e.
  1524. # the order of the buckets. Retrieving the bucket order from the reducer
  1525. # ensures that we keep the same order in join mode, such as when bucket
  1526. # order is rebuilt dynamically.
  1527. # Returns grad_buckets in order, but real tensors are substituted with
  1528. # zero tensors of the same shape.
  1529. grad_buckets = self.reducer._get_zeros_like_grad_buckets()
  1530. for grad_bucket in grad_buckets:
  1531. # Joined processes contribute zero gradient. In the case that
  1532. # divide_by_initial_world_size=True, we divide grads by the static
  1533. # world size, if not, the dividing factor is reduced by the number
  1534. # of joined processes.
  1535. work = self.reducer._run_comm_hook(grad_bucket)
  1536. comm_work.append(work)
  1537. for work in comm_work:
  1538. work.wait()
  1539. # Allreduces the used parameter mapping across ranks.
  1540. def _match_unused_params_allreduce(self):
  1541. locally_used_param_map = self.reducer._get_local_used_map()
  1542. self.process_group.allreduce(locally_used_param_map)
  1543. def join(
  1544. self,
  1545. divide_by_initial_world_size: bool = True,
  1546. enable: bool = True,
  1547. throw_on_early_termination: bool = False,
  1548. ):
  1549. r"""
  1550. Context manager for training with uneven inputs across processes in DDP.
  1551. This context manager will keep track of already-joined DDP processes,
  1552. and "shadow" the forward and backward passes by inserting collective
  1553. communication operations to match with the ones created by non-joined
  1554. DDP processes. This will ensure each collective call has a corresponding
  1555. call by already-joined DDP processes, preventing hangs or errors that
  1556. would otherwise happen when training with uneven inputs across
  1557. processes. Alternatively, if the flag ``throw_on_early_termination`` is
  1558. specified to be ``True``, all trainers will throw an error once one rank
  1559. runs out of inputs, allowing these errors to be caught and handled
  1560. according to application logic.
  1561. Once all DDP processes have joined, the context manager will broadcast
  1562. the model corresponding to the last joined process to all processes to
  1563. ensure the model is the same across all processes
  1564. (which is guaranteed by DDP).
  1565. To use this to enable training with uneven inputs across processes,
  1566. simply wrap this context manager around your training loop. No further
  1567. modifications to the model or data loading is required.
  1568. .. warning::
  1569. If the model or training loop this context manager is wrapped around
  1570. has additional distributed collective operations, such as
  1571. ``SyncBatchNorm`` in the model's forward pass, then the flag
  1572. ``throw_on_early_termination`` must be enabled. This is because this
  1573. context manager is not aware of non-DDP collective communication.
  1574. This flag will cause all ranks to throw when any one rank
  1575. exhausts inputs, allowing these errors to be caught and recovered
  1576. from across all ranks.
  1577. Args:
  1578. divide_by_initial_world_size (bool): If ``True``, will divide
  1579. gradients by the initial ``world_size`` DDP training was launched
  1580. with. If ``False``, will compute the effective world size
  1581. (number of ranks that have not depleted their inputs yet) and
  1582. divide gradients by that during allreduce. Set
  1583. ``divide_by_initial_world_size=True`` to ensure every input
  1584. sample including the uneven inputs have equal weight in terms of
  1585. how much they contribute to the global gradient. This is
  1586. achieved by always dividing the gradient by the initial
  1587. ``world_size`` even when we encounter uneven inputs. If you set
  1588. this to ``False``, we divide the gradient by the remaining
  1589. number of nodes. This ensures parity with training on a smaller
  1590. ``world_size`` although it also means the uneven inputs would
  1591. contribute more towards the global gradient. Typically, you
  1592. would want to set this to ``True`` for cases where the last few
  1593. inputs of your training job are uneven. In extreme cases, where
  1594. there is a large discrepancy in the number of inputs, setting
  1595. this to ``False`` might provide better results.
  1596. enable (bool): Whether to enable uneven input detection or not. Pass
  1597. in ``enable=False`` to disable in cases where you know that
  1598. inputs are even across participating processes. Default is
  1599. ``True``.
  1600. throw_on_early_termination (bool): Whether to throw an error
  1601. or continue training when at least one rank has exhausted
  1602. inputs. If ``True``, will throw upon the first rank reaching end
  1603. of data. If ``False``, will continue training with a smaller
  1604. effective world size until all ranks are joined. Note that if
  1605. this flag is specified, then the flag
  1606. ``divide_by_initial_world_size`` would be ignored. Default
  1607. is ``False``.
  1608. Example::
  1609. >>> # xdoctest: +SKIP("Distributed")
  1610. >>> import torch
  1611. >>> import torch.distributed as dist
  1612. >>> import os
  1613. >>> import torch.multiprocessing as mp
  1614. >>> import torch.nn as nn
  1615. >>> # On each spawned worker
  1616. >>> def worker(rank):
  1617. >>> dist.init_process_group("nccl", rank=rank, world_size=2)
  1618. >>> torch.cuda.set_device(rank)
  1619. >>> model = nn.Linear(1, 1, bias=False).to(rank)
  1620. >>> model = torch.nn.parallel.DistributedDataParallel(
  1621. >>> model, device_ids=[rank], output_device=rank
  1622. >>> )
  1623. >>> # Rank 1 gets one more input than rank 0.
  1624. >>> inputs = [torch.tensor([1]).float() for _ in range(10 + rank)]
  1625. >>> with model.join():
  1626. >>> for _ in range(5):
  1627. >>> for inp in inputs:
  1628. >>> loss = model(inp).sum()
  1629. >>> loss.backward()
  1630. >>> # Without the join() API, the below synchronization will hang
  1631. >>> # blocking for rank 1's allreduce to complete.
  1632. >>> torch.cuda.synchronize(device=rank)
  1633. """
  1634. return Join(
  1635. [self],
  1636. enable,
  1637. throw_on_early_termination,
  1638. divide_by_initial_world_size=divide_by_initial_world_size,
  1639. )
  1640. def join_hook(
  1641. self,
  1642. **kwargs,
  1643. ):
  1644. r"""
  1645. DDP join hook enables training on uneven inputs by mirroring communications in forward and backward passes.
  1646. Arguments:
  1647. kwargs (dict): a :class:`dict` containing any keyword arguments
  1648. to modify the behavior of the join hook at run time; all
  1649. :class:`Joinable` instances sharing the same join context
  1650. manager are forwarded the same value for ``kwargs``.
  1651. The hook supports the following keyword arguments:
  1652. divide_by_initial_world_size (bool, optional):
  1653. If ``True``, then gradients are divided by the initial world
  1654. size that DDP was launched with.
  1655. If ``False``, then gradients are divided by the effective world
  1656. size (i.e. the number of non-joined processes), meaning that
  1657. the uneven inputs contribute more toward the global gradient.
  1658. Typically, this should be set to ``True`` if the degree of
  1659. unevenness is small but can be set to ``False`` in extreme
  1660. cases for possibly better results.
  1661. Default is ``True``.
  1662. """
  1663. divide_by_initial_world_size = kwargs.get("divide_by_initial_world_size", True)
  1664. return _DDPJoinHook(
  1665. self, divide_by_initial_world_size=divide_by_initial_world_size
  1666. )
  1667. @property
  1668. def join_device(self):
  1669. return self.device
  1670. @property
  1671. def join_process_group(self):
  1672. return self.process_group
  1673. def _register_buffer_comm_hook(
  1674. self,
  1675. state,
  1676. hook: Callable,
  1677. comm_hook_location=_BufferCommHookLocation.POST_FORWARD,
  1678. ):
  1679. r"""
  1680. Allow custom registration of hooks that define how buffer are synchronized across ranks.
  1681. The hook takes in an optional state and is passed in a Dict[str, Tensor]
  1682. corresponding to buffer names and the buffers, and can run arbitrary reductions
  1683. on buffers as opposed to DDP's default broadcast from rank 0. This is useful for
  1684. example if a counter needs to be summed or averaged across ranks every iteration.
  1685. Args:
  1686. state (Any): Optional state that is passed to the hook.
  1687. hook (Callable): Callable with the following signature:
  1688. ``hook(state: object, bucket: dist.GradBucket) -> torch.futures.Future[torch.Tensor]``
  1689. comm_hook_location (_BufferCommHookLocation): Enum value indicating
  1690. where to run the hook.
  1691. _BufferCommHookLocation.PRE_FORWARD means that the
  1692. hook will run _before_ the forward pass, and
  1693. _BufferCommHookLocation.POST_FORWARD means that the
  1694. hook will run _after_ the forward pass.
  1695. NOTE: To maximize performance, users can return a
  1696. List[torch.futures.Future] from their hook, and DDP will
  1697. install and await these hooks appropriately at the end of
  1698. the backward pass. This will ensure all buffers are
  1699. synchronized by the end of the backward pass. If this
  1700. setting is used, it is recommended to pass
  1701. comm_hook_location=_BufferCommHookLocation.POST_FORWARD,
  1702. which will trigger the hook after the forward pass.
  1703. If _BufferCommHookLocation.PRE_FORWARD is used, users must
  1704. ensure appropriate synchronization when manipulating GPU
  1705. buffers in the forward pass.
  1706. """
  1707. assert callable(hook)
  1708. self.buffer_hook = _BufferCommHook(
  1709. buffer_comm_hook=hook,
  1710. buffer_comm_hook_state=state,
  1711. buffer_comm_hook_location=comm_hook_location,
  1712. )
  1713. def register_comm_hook(self, state: object, hook: Callable):
  1714. r"""
  1715. Register communication hook for user-defined DDP aggregation of gradients across multiple workers.
  1716. This hook would be very useful for researchers to try out new ideas. For
  1717. example, this hook can be used to implement several algorithms like GossipGrad
  1718. and gradient compression which involve different communication strategies for
  1719. parameter syncs while running Distributed DataParallel training.
  1720. Args:
  1721. state (object): Passed to the hook to maintain any state information during the training process.
  1722. Examples include error feedback in gradient compression,
  1723. peers to communicate with next in GossipGrad, etc.
  1724. It is locally stored by each worker
  1725. and shared by all the gradient tensors on the worker.
  1726. hook (Callable): Callable with the following signature:
  1727. ``hook(state: object, bucket: dist.GradBucket) -> torch.futures.Future[torch.Tensor]``:
  1728. This function is called once the bucket is ready. The
  1729. hook can perform whatever processing is needed and return
  1730. a Future indicating completion of any async work (ex: allreduce).
  1731. If the hook doesn't perform any communication, it still
  1732. must return a completed Future. The Future should hold the
  1733. new value of grad bucket's tensors. Once a bucket is ready,
  1734. c10d reducer would call this hook and use the tensors returned
  1735. by the Future and copy grads to individual parameters.
  1736. Note that the future's return type must be a single tensor.
  1737. We also provide an API called ``get_future`` to retrieve a
  1738. Future associated with the completion of ``c10d.ProcessGroup.Work``.
  1739. ``get_future`` is currently supported for NCCL and also supported for most
  1740. operations on GLOO and MPI, except for peer to peer operations (send/recv).
  1741. .. warning ::
  1742. Grad bucket's tensors will not be predivided by world_size. User is responsible
  1743. to divide by the world_size in case of operations like allreduce.
  1744. .. warning ::
  1745. DDP communication hook can only be registered once and should be registered
  1746. before calling backward.
  1747. .. warning ::
  1748. The Future object that hook returns should contain a single tensor
  1749. that has the same shape with the tensors inside grad bucket.
  1750. .. warning ::
  1751. ``get_future`` API supports NCCL, and partially GLOO and MPI backends (no support
  1752. for peer-to-peer operations like send/recv) and will return a ``torch.futures.Future``.
  1753. Example::
  1754. Below is an example of a noop hook that returns the same tensor.
  1755. >>> # xdoctest: +SKIP('undefined name')
  1756. >>> def noop(state: object, bucket: dist.GradBucket) -> torch.futures.Future[torch.Tensor]:
  1757. >>> fut = torch.futures.Future()
  1758. >>> fut.set_result(bucket.buffer())
  1759. >>> return fut
  1760. >>> ddp.register_comm_hook(state=None, hook=noop)
  1761. Example::
  1762. Below is an example of a Parallel SGD algorithm where gradients are encoded before
  1763. allreduce, and then decoded after allreduce.
  1764. >>> # xdoctest: +SKIP('undefined name')
  1765. >>> def encode_and_decode(state: object, bucket: dist.GradBucket) -> torch.futures.Future[torch.Tensor]:
  1766. >>> encoded_tensor = encode(bucket.buffer()) # encode gradients
  1767. >>> fut = torch.distributed.all_reduce(encoded_tensor).get_future()
  1768. >>> # Define the then callback to decode.
  1769. >>> def decode(fut):
  1770. >>> decoded_tensor = decode(fut.value()[0]) # decode gradients
  1771. >>> return decoded_tensor
  1772. >>> return fut.then(decode)
  1773. >>> ddp.register_comm_hook(state=None, hook=encode_and_decode)
  1774. """
  1775. self._check_comm_hook(hook)
  1776. assert self.logger is not None
  1777. self.logger._set_comm_hook_name(hook.__qualname__)
  1778. self._comm_hooks.append((hook, state))
  1779. dist._register_comm_hook(self.reducer, state, hook)
  1780. def _register_builtin_comm_hook(self, comm_hook_type):
  1781. r"""
  1782. Register a built-in communication hook that specifies how DDP aggregates gradients across multiple workers.
  1783. The built-in hooks aim to provide efficient C++ implementations for certain hooks,
  1784. which might not be as efficient if implemented in Python using a Python communication hook.
  1785. Args:
  1786. comm_hook_type (dist.BuiltinCommHookType): type of communication hook, such as ALLREDUCE, FP16_COMPRESS, etc.
  1787. .. warning ::
  1788. DDP communication hook can only be registered once and should be registered
  1789. before calling backward.
  1790. Example::
  1791. Below is an example of a FP16 compression where gradients are
  1792. compressed into 16-bit floating-point numbers before allreduce, and
  1793. then decompressed after allreduce.
  1794. >>> # xdoctest: +SKIP('undefined name')
  1795. >>> ddp._register_builtin_comm_hook(dist.BuiltinCommHookType.FP16_COMPRESS)
  1796. """
  1797. assert self.logger is not None
  1798. self.logger._set_comm_hook_name(str(comm_hook_type))
  1799. dist._register_builtin_comm_hook(self.reducer, comm_hook_type)
  def _register_fused_optim(self, optim: type, *args, optim_params=None, **kwargs):
      r"""
      Register an optimizer in DDP that updates each parameter right after its gradient is reduced.

      With a fused optimizer, the update for a parameter runs as soon as that
      parameter's gradient has finished reduction, instead of waiting for the
      gradients of all parameters. Depending on the workload this can speed
      up training, because optimizer steps overlap with gradient reduction
      that is still in flight for other parameters. It can also lower peak
      memory, since only one parameter's optimizer state needs to be loaded
      at a time rather than all per-parameter states at once.

      Args:
          optim (Type): a ``torch.optim.Optimizer`` class to register as a
              fused optimizer.
          *args (Sequence[Any]): positional arguments forwarded to ``optim``.
          optim_params (Optional[Iterable[torch.Tensor]]): parameters to
              optimize, like the ``params`` argument of a regular
              ``torch.optim`` optimizer. If omitted, all DDP model parameters
              are optimized.
          **kwargs (Dict[str, Any]): keyword arguments forwarded to ``optim``.

      Raises:
          RuntimeError: if registering the overlapped optimizer raises
              ``NotImplementedError``, i.e. ``optim`` does not support
              overlapping with DDP.

      .. warning ::
          _register_fused_optim should only be called once on a DDP instance;
          registering multiple fused optimizers for the same DDP model is not
          currently supported. Please ping
          https://github.com/pytorch/pytorch/issues/71595 if this is necessary
          for your use case.

      .. warning ::
          _register_fused_optim and register_comm_hook currently do not
          compose: custom DDP communication hooks are not supported with
          overlapped optimizers. Please ping
          https://github.com/pytorch/pytorch/issues/71595 if this is necessary
          for your use case.

      .. warning ::
          Gradient accumulation and DDP `no_sync` are currently not supported
          with an overlapped optimizer. Please ping
          https://github.com/pytorch/pytorch/issues/71595 if this is necessary
          for your use case.

      Example::

          >>> # xdoctest: +SKIP("No rendezvous handler")
          >>> torch.distributed.init_process_group(backend='nccl', world_size=4, init_method='...')
          >>> net = torch.nn.parallel.DistributedDataParallel(model, pg)
          >>> lr = 1e-2
          >>> betas = (0.9, 0.99)
          >>> eps = 1e-6
          >>> net._register_fused_optim(torch.optim.Adam, lr, betas=betas, eps=eps)
          >>> # Example with subset of parameters
          >>> params_to_opt = [list(net.parameters())[0]]
          >>> net._register_fused_optim(
          ...   torch.optim.Adam, lr, optim_params=params_to_opt, betas=betas, eps=eps
          ... )
      """
      # Imported lazily: the optimizer-overlap module itself imports
      # DistributedDataParallel, so a module-level import would be circular.
      from torch.distributed.algorithms._optimizer_overlap import _as_overlapped_optim

      overlapped = _as_overlapped_optim(optim, optim_params, *args, **kwargs)
      try:
          overlapped.register_ddp(self)
      except NotImplementedError as exc:
          raise RuntimeError(
              f"{optim} does not support overlapped DDP. Please file an issue to PyTorch or the respective owner of {optim}."
          ) from exc
  def _distributed_broadcast_coalesced(
      self, tensors, buffer_size, authoritative_rank=0
  ):
      """
      Broadcast ``tensors`` from ``authoritative_rank`` over ``self.process_group``.

      Thin wrapper over ``dist._broadcast_coalesced``; ``buffer_size`` is
      passed through unchanged as its coalescing buffer size.
      """
      group = self.process_group
      dist._broadcast_coalesced(group, tensors, buffer_size, authoritative_rank)
  1869. def _check_sync_bufs_post_fwd(self):
  1870. return (
  1871. self.will_sync_module_buffers()
  1872. and hasattr(self, "buffer_hook")
  1873. and self.buffer_hook.buffer_comm_hook_location
  1874. == _BufferCommHookLocation.POST_FORWARD
  1875. )
  1876. def _check_sync_bufs_pre_fwd(self):
  1877. return self.will_sync_module_buffers() and (
  1878. not hasattr(self, "buffer_hook")
  1879. or self.buffer_hook.buffer_comm_hook_location
  1880. == _BufferCommHookLocation.PRE_FORWARD
  1881. )
  def will_sync_module_buffers(self):
      """
      Return whether DDP will synchronize module buffers.

      Requires forward parameter sync to be on, ``broadcast_buffers`` to be
      enabled, and at least one tracked module buffer.
      """
      forward_sync = self.require_forward_param_sync
      if not forward_sync:
          return forward_sync
      broadcast = self.broadcast_buffers
      if not broadcast:
          return broadcast
      return len(self.modules_buffers) > 0
  def _find_common_rank(self, input_rank, rank_cond):
      """
      Agree on a single rank across ``self.process_group``.

      Each process contributes ``input_rank`` if ``rank_cond`` is true and
      ``-1`` otherwise. A MAX all-reduce then gives every process the largest
      contributed rank.

      Args:
          input_rank (int): the rank this process proposes.
          rank_cond (bool): whether this process is eligible to be chosen.

      Returns:
          int: the agreed-upon rank (the maximum eligible ``input_rank``).

      Raises:
          ValueError: (via ``self._log_and_throw``) if no process was eligible.
      """
      # -1 indicates that this rank is not under consideration to be the
      # common_rank.
      rank_to_use = torch.tensor(
          [input_rank if rank_cond else -1],
          device=self.device,
      )
      dist.all_reduce(rank_to_use, op=ReduceOp.MAX, group=self.process_group)
      # Read the reduced value back once: every ``.item()`` call is a separate
      # device-to-host copy, and the original read it twice.
      common_rank = rank_to_use.item()
      if common_rank == -1:
          self._log_and_throw(
              ValueError,
              "BUG! Expected rank_cond to be true for at least one process."
              " This indicates a bug in PyTorch, please report an issue.",
          )
      return common_rank
  1903. def _sync_buffers(self):
  1904. with torch.no_grad():
  1905. # module buffer sync
  1906. # Synchronize buffers across processes.
  1907. # If we are running DDP with the join manager, we have to agree
  1908. # upon a rank to sync module buffers from, since rank 0 may
  1909. # already have been joined and have stale module buffers.
  1910. if self._join_config.enable:
  1911. authoritative_rank = self._find_common_rank(
  1912. self._distributed_rank, True
  1913. )
  1914. else:
  1915. # The process with rank 0 is considered the authoritative copy.
  1916. authoritative_rank = 0
  1917. # Update self.modules_buffers in case any buffers were
  1918. # reassigned.
  1919. self._assign_modules_buffers()
  1920. self._sync_module_buffers(authoritative_rank)
  1921. def _sync_module_buffers(self, authoritative_rank):
  1922. if not hasattr(self, "buffer_hook"):
  1923. self._default_broadcast_coalesced(authoritative_rank=authoritative_rank)
  1924. else:
  1925. hook = self.buffer_hook.buffer_comm_hook
  1926. state = self.buffer_hook.buffer_comm_hook_state
  1927. futs = hook(state, self.named_module_buffers)
  1928. if futs is not None:
  1929. self.reducer._install_post_backward_futures(futs)
  1930. def _default_broadcast_coalesced(
  1931. self, bufs=None, bucket_size=None, authoritative_rank=0
  1932. ):
  1933. """
  1934. Broadcasts buffers from rank 0 to rest of workers.
  1935. If bufs, bucket_size are None, default values self.modules_buffers
  1936. and self.broadcast_bucket_size are used instead.
  1937. """
  1938. if bufs is None:
  1939. bufs = self.modules_buffers
  1940. if bucket_size is None:
  1941. bucket_size = self.broadcast_bucket_size
  1942. self._distributed_broadcast_coalesced(bufs, bucket_size, authoritative_rank)
  1943. def _passing_sync_batchnorm_handle(self, module):
  1944. for layer in module.modules():
  1945. if isinstance(layer, torch.nn.modules.SyncBatchNorm):
  1946. if self.device_type == "cpu":
  1947. self._log_and_throw(
  1948. ValueError,
  1949. "SyncBatchNorm layers only work with GPU modules",
  1950. )
  1951. def _check_comm_hook(self, hook):
  1952. if not callable(hook):
  1953. self._log_and_throw(TypeError, "Communication hook must be callable.")
  1954. sig = inspect.signature(hook)
  1955. if (
  1956. sig.parameters["bucket"].annotation != inspect._empty
  1957. and sig.parameters["bucket"].annotation != dist.GradBucket
  1958. ):
  1959. self._log_and_throw(
  1960. ValueError,
  1961. "Communication hook: bucket annotation should be dist.GradBucket.",
  1962. )
  1963. if (
  1964. sig.return_annotation != inspect._empty
  1965. and sig.return_annotation != torch.futures.Future[torch.Tensor]
  1966. ):
  1967. self._log_and_throw(
  1968. ValueError,
  1969. "Communication hook: return annotation should be torch.futures.Future[torch.Tensor].",
  1970. )
  1971. if hook.__name__ in ["bf16_compress_hook", "bf16_compress_wrapper_hook"]:
  1972. cuda_supported = (
  1973. torch.version.cuda is not None
  1974. ) or torch.version.hip is not None
  1975. nccl_supported = (
  1976. dist.is_available()
  1977. and dist.is_nccl_available()
  1978. and torch.cuda.nccl.version() >= (2, 10)
  1979. )
  1980. xpu_xccl_supported = (
  1981. dist.is_available()
  1982. and dist.is_xccl_available()
  1983. and torch.xpu.is_available()
  1984. )
  1985. if not ((cuda_supported and nccl_supported) or xpu_xccl_supported):
  1986. self._log_and_throw(
  1987. TypeError,
  1988. "BF16 all reduce communication hook required CUDA 11+ and NCCL 2.10+ or XPU and XCCL",
  1989. )
  @property
  def _distributed_rank(self):
      """Rank of the current process within ``self.process_group``, as reported by ``dist.get_rank``."""
      group = self.process_group
      return dist.get_rank(group)
  1993. @staticmethod
  1994. def _get_data_parallel_params(module, named_params=False):
  1995. """Return a generator of parameters managed by a given DDP unit."""
  1996. for param in (
  1997. module.parameters() if not named_params else module.named_parameters()
  1998. ):
  1999. if not hasattr(param, "_ddp_ignored"):
  2000. yield param
  2001. @staticmethod
  2002. def _set_params_and_buffers_to_ignore_for_model(
  2003. module, params_and_buffers_to_ignore
  2004. ):
  2005. """
  2006. Set parameters and buffers to be ignored by DDP.
  2007. Expected format for parameters is the fully qualified name: {module_name}.{param_name}, and
  2008. similarly, {module_name}.{buffer_name} for buffers. For example:
  2009. params_to_ignore = []
  2010. # NB: model here is vanilla PyTorch module, not yet wrapped with DDP.
  2011. for module_name, module in model.named_modules():
  2012. for param_name, param in module.named_parameters(recurse=False):
  2013. if should_ignore(param):
  2014. # Create expected format
  2015. fqn = f"{module_name}.{param_name}"
  2016. params_to_ignore.append(fqn)
  2017. torch.nn.parallel.DistributedDataParallel._set_params_and_buffers_to_ignore_for_model(
  2018. model,
  2019. params_to_ignore
  2020. )
  2021. """
  2022. # This is a workaround to set parameters and buffers DDP should ignore
  2023. # during synchronization. It will be removed when the API is finalized
  2024. # as part of addressing https://github.com/pytorch/pytorch/issues/43690.
  2025. module._ddp_params_and_buffers_to_ignore = params_and_buffers_to_ignore
  2026. for name, param in module.named_parameters():
  2027. if name in params_and_buffers_to_ignore:
  2028. param._ddp_ignored = True
  2029. for name, buffer in module.named_buffers():
  2030. if name in params_and_buffers_to_ignore:
  2031. buffer._ddp_ignored = True
  2032. def _get_ddp_logging_data(self):
  2033. r"""
  2034. Return a dictionary of logging data for debugging and analysis.
  2035. This interface can be called after DistributedDataParallel() is
  2036. constructed. It returns a dictionary of logging data. It could help
  2037. for debugging and analysis. The logging data includes DistributedDataParallel
  2038. constructor input parameters, some internal states of DistributedDataParallel
  2039. and performance metrics. Simply print the dictionary and see what
  2040. these metrics are.
  2041. This is a prototype interface and subject to change in the future.
  2042. """
  2043. assert self.logger is not None
  2044. ddp_logging_data = self.logger._get_ddp_logging_data()
  2045. return {**ddp_logging_data.strs_map, **ddp_logging_data.ints_map}
  2046. def _set_ddp_runtime_logging_sample_rate(self, sample_rate):
  2047. r"""
  2048. Set sample_rate of collecting runtime stats.
  2049. This interface allows users to set sample_rate of collecting
  2050. runtime stats. The runtime stats will be recorded for the
  2051. first 10 iterations, after 10 iterations runtime stats will be
  2052. recorded once every "sample_rate" training iterations. In
  2053. default, runtime stats are recorded for the first 10 iterations,
  2054. after 10 iterations runtime stats are recorded once every
  2055. "kDDPRuntimeLoggingSampleRate=100" training iterations.
  2056. This is a prototype interface and subject to change in the future.
  2057. """
  2058. if sample_rate < 1:
  2059. self._log_and_throw(
  2060. ValueError,
  2061. "DDP runtime logging sample rate should be equal or greater than 1",
  2062. )
  2063. self.reducer._set_ddp_runtime_logging_sample_rate(sample_rate)
  2064. def _set_static_graph(self):
  2065. """
  2066. Set static graph for DDP.
  2067. It is recommended to set static graph in the DDP constructor, which will
  2068. call this private API internally.
  2069. """
  2070. # If self.static_graph has been set, no need to set it again
  2071. if self.static_graph:
  2072. warnings.warn(
  2073. "You've set static_graph to be True, no need to set it again."
  2074. )
  2075. return
  2076. self.static_graph = True
  2077. self._static_graph_delay_allreduce_enqueued = False
  2078. self.reducer._set_static_graph()
  2079. assert self.logger is not None
  2080. self.logger._set_static_graph()
  2081. if self.find_unused_parameters:
  2082. warnings.warn(
  2083. "You passed find_unused_parameters=true to DistributedDataParallel, "
  2084. "`_set_static_graph` will detect unused parameters automatically, so "
  2085. "you do not need to set find_unused_parameters=true, just be sure these "
  2086. "unused parameters will not change during training loop while calling "
  2087. "`_set_static_graph`."
  2088. )
  2089. def _remove_autograd_hooks(self):
  2090. """Remove autograd hooks registered by the reducer on the model parameters."""
  2091. self.reducer._remove_autograd_hooks()
  2092. def _check_reducer_finalized(self):
  2093. """
  2094. Check if the reducer has processed all buckets and finalized the backward appropriately.
  2095. It is useful to call this method after calling .backward() in your training loop
  2096. in order to avoid subsequent hard to debug errors down the road due to the
  2097. reducer not finalizing backward.
  2098. """
  2099. self.reducer._check_reducer_finalized()
  2100. def _set_sparse_metadata(self, global_unique_ids):
  2101. self.reducer._set_sparse_metadata(global_unique_ids)
  2102. def _update_process_group(self, new_process_group):
  2103. """
  2104. Dynamically updates the process group for DDP so that we can shrink/expand DDP
  2105. world size without having to reinitialize DDP.
  2106. NOTE: If you are using custom communications hooks via, register_comm_hook,
  2107. you need to update the process groups for those hooks separately.
  2108. """
  2109. # Force a rebuild of buckets for a new process group. This ensures all ranks
  2110. # are synchronized in terms of when they will rebuild buckets and also
  2111. # re-evaluates previous assumptions of buckets given the world size might have
  2112. # changed.
  2113. self._has_rebuilt_buckets = False
  2114. self.reducer._reset_state()
  2115. if not _rank_not_in_group(new_process_group):
  2116. self.process_group = new_process_group
  2117. self.reducer._update_process_group(new_process_group)
  2118. def _set_ddp_sink_clone(self, val: bool):
  2119. """
  2120. Sets whether or not DDPSink should clone the output tensors or not.
  2121. The default is True since if the loss is modified in place we run
  2122. into the view is modified in-place error.
  2123. Although, cloning the tensors can add significant memory and
  2124. performance hit if the number and size of tensors are large. As
  2125. a result, this can be set to False if you are not modifying the
  2126. loss in place.
  2127. """
  2128. self._ddp_sink_clone = val