state.py 57 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270370470570670770870971071171271371471571671771871972072172272372472572672772872973073173273373473573673773873974074174274374474574674774874975075175275375475575675775875976076176276376476576676776876977077177277377477577677777877978078178278378478578678778878979079179279379479579679779879980080180280380480580680780880981081181281381481581681781881982082182282382482582682782882983083183283383483583683783883984084184284384484584684784884985085185285385485585685785885986086186286386486586686786886987087187287387487587687787887988088188288388488588688788888989089189289389489589689789889990090190290390490590690790890991091191291391491591691791891992092192292392492592692792892993093193293393493593693793893994094194294394494594694794894995095195295395495595695795895996096196296396496596696796896997097197297397497597697797897998098198298398498598698798898999099199299399499599699799899910001001100210031004100510061007100810091010101110121013101410151016101710181019102010211022102310241025102610271028102910301031103210331034103510361037103810391040104110421043104410451046104710481049105010511052105310541055105610571058105910601061106210631064106510661067106810691070107110721073107410751076107710781079108010811082108310841085108610871088108910901091109210931094109510961097109810991100110111021103110411051106110711081109111011111112111311141115111611171118111911201121112211231124112511261127112811291130113111321133113411351136113711381139114011411142114311441145114611471148114911501151115211531154115511561157115811591160116111621163116411651166116711681169117011711172117311741175117611771178117911801181118211831184118511861187118811891190119111921193119411951196119711981199120012011202120312041205120612071208120912101211121212131214121512161217121812191220122112221223122412251226122712281229123012311232123312341235123612371238123912401241124212431244124512461247124812491250125112521253125412551256125712581259126012611262126312641265126612671268126912701271127212731274127512761277127812791280128112821283128412851286128712881289129012911292129312941295129612971298129913001301130213031304130513061307130813091310131113121313131413151316131713181319132013211322132313241325132613271328132913301331133213331334133513361337133813391340134113421343134413451346134713481349135013511352135313541355135613571358135913601361136213631364136513661367136813691370137113721373
  1. # Copyright 2021 The HuggingFace Team. All rights reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. from __future__ import annotations
  15. import logging
  16. import os
  17. import threading
  18. import warnings
  19. import weakref
  20. from contextlib import contextmanager
  21. from functools import partial
  22. from typing import Any, Callable
  23. import torch
  24. from .utils import (
  25. DistributedType,
  26. DynamoBackend,
  27. GradientAccumulationPlugin,
  28. check_cuda_fp8_capability,
  29. check_cuda_p2p_ib_support,
  30. deepspeed_required,
  31. get_cpu_distributed_information,
  32. get_int_from_env,
  33. is_ccl_available,
  34. is_datasets_available,
  35. is_deepspeed_available,
  36. is_fp8_available,
  37. is_habana_gaudi1,
  38. is_hpu_available,
  39. is_ipex_available,
  40. is_mlu_available,
  41. is_mps_available,
  42. is_musa_available,
  43. is_npu_available,
  44. is_sdaa_available,
  45. is_torch_xla_available,
  46. is_xccl_available,
  47. is_xpu_available,
  48. parse_choice_from_env,
  49. parse_flag_from_env,
  50. set_numa_affinity,
  51. )
  52. from .utils.dataclasses import SageMakerDistributedType
  53. if is_torch_xla_available():
  54. import torch_xla.core.xla_model as xm
  55. import torch_xla.runtime as xr
  56. if is_mlu_available(check_device=False):
  57. import torch_mlu # noqa: F401
  58. if is_sdaa_available(check_device=False):
  59. import torch_sdaa # noqa: F401
  60. if is_musa_available(check_device=False):
  61. import torch_musa # noqa: F401
  62. if is_npu_available(check_device=False):
  63. import torch_npu # noqa: F401
  64. logger = logging.getLogger(__name__)
  65. def is_initialized() -> bool:
  66. """
  67. Checks if the `AcceleratorState` has been initialized from `Accelerator`. Same as `AcceleratorState.initialized`,
  68. but works as a module method.
  69. """
  70. return AcceleratorState._shared_state != {}
  71. # Lambda function that does nothing
  72. def do_nothing(*args, **kwargs):
  73. return None
  74. class ThreadLocalSharedDict(threading.local):
  75. """
  76. Descriptor that holds a dict shared between instances of a class in the same thread.
  77. Note: Descriptors have slightly different semantics than just a dict field on its own.
  78. `PartialState(...)._shared_state` and `PartialState._shared_state` (instance vs class) give the same value: the
  79. underlying _storage dict. Likewise, `PartialState(...)._shared_state = {...}` overrides the _storage dict inside
  80. the descriptor as you would expect. However, `PartialState._shared_state = {}` actually replaces the descriptor
  81. object with a dict instead Thus, you should modify the _storage dict in-place (e.g. `_shared_state.clear()`).
  82. See Python documentation for an explanation of descriptors: https://docs.python.org/3/howto/descriptor.html
  83. This is required for using PyTorch/XLA with PJRT in multithreaded mode (required for TPU v2 and v3).
  84. See https://github.com/pytorch/xla/blob/r2.0/docs/pjrt.md#multithreading-on-tpu-v2v3
  85. """
  86. def __init__(self, thread_local: bool = False):
  87. self._storage = {}
  88. def __get__(self, obj, objtype=None):
  89. return self._storage
  90. def __set__(self, obj, value):
  91. self._storage = value
  92. # Prefer global shared dictionary, except when using TPU.
  93. SharedDict = dict if not is_torch_xla_available() else ThreadLocalSharedDict
  94. # Inspired by Alex Martelli's 'Borg'.
  95. class PartialState:
  96. """
  97. Singleton class that has information about the current training environment and functions to help with process
  98. control. Designed to be used when only process control and device execution states are needed. Does *not* need to
  99. be initialized from `Accelerator`.
  100. Args:
  101. cpu (`bool`, *optional*):
  102. Whether or not to force the script to execute on CPU. Will ignore any accelerators available if set to
  103. `True` and force the execution on the CPU.
  104. kwargs (additional keyword arguments, *optional*):
  105. Additional keyword arguments to pass to the relevant `init_process_group` function. Valid `kwargs` can be
  106. found in [`utils.InitProcessGroupKwargs`]. See the example section for detailed usage.
  107. **Available attributes:**
  108. - **device** (`torch.device`) -- The device to use.
  109. - **distributed_type** ([`~accelerate.state.DistributedType`]) -- The type of distributed environment currently
  110. in use.
  111. - **local_process_index** (`int`) -- The index of the current process on the current server.
  112. - **mixed_precision** (`str`) -- Whether or not the current script will use mixed precision, and if so the type
  113. of mixed precision being performed. (Choose from 'no','fp16','bf16 or 'fp8').
  114. - **num_processes** (`int`) -- The number of processes currently launched in parallel.
  115. - **process_index** (`int`) -- The index of the current process.
  116. - **is_last_process** (`bool`) -- Whether or not the current process is the last one.
  117. - **is_main_process** (`bool`) -- Whether or not the current process is the main one.
  118. - **is_local_main_process** (`bool`) -- Whether or not the current process is the main one on the local node.
  119. - **debug** (`bool`) -- Whether or not the current script is being run in debug mode.
  120. Example:
  121. ```python
  122. from accelerate.utils import InitProcessGroupKwargs
  123. # To include `InitProcessGroupKwargs`, init then call `.to_kwargs()`
  124. kwargs = InitProcessGroupKwargs(...).to_kwargs()
  125. state = PartialState(**kwargs)
  126. ```
  127. """
  128. _shared_state = SharedDict()
  129. _known_attrs = [
  130. "_cpu",
  131. "_mixed_precision",
  132. "_shared_state",
  133. "backend",
  134. "debug",
  135. "device",
  136. "distributed_type",
  137. "fork_launched",
  138. "local_process_index",
  139. "num_processes",
  140. "process_index",
  141. ]
  142. def __init__(self, cpu: bool = False, **kwargs):
  143. self.__dict__ = self._shared_state
  144. if not self.initialized:
  145. self._cpu = cpu
  146. self.backend = None
  147. env_device = os.environ.get("ACCELERATE_TORCH_DEVICE", None)
  148. self.device = torch.device(env_device) if env_device is not None else None
  149. self.debug = parse_flag_from_env("ACCELERATE_DEBUG_MODE")
  150. use_sagemaker_dp = kwargs.pop("_use_sagemaker_dp", None)
  151. dist_information = None
  152. if use_sagemaker_dp is None:
  153. use_sagemaker_dp = (
  154. os.environ.get("ACCELERATE_USE_SAGEMAKER", "false").lower() == "true"
  155. and os.environ.get("ACCELERATE_SAGEMAKER_DISTRIBUTED_TYPE") != SageMakerDistributedType.NO
  156. )
  157. # Sets up self.backend + imports
  158. original_backend = kwargs.pop("backend", None)
  159. backend, distributed_type = self._prepare_backend(cpu, use_sagemaker_dp, original_backend)
  160. if original_backend is not None and backend != original_backend:
  161. raise ValueError(f"Your assigned backend {original_backend} is not available, please use {backend}")
  162. self.backend = backend
  163. self.distributed_type = distributed_type
  164. use_deepspeed = False
  165. if not cpu and self.backend != "xla":
  166. if int(os.environ.get("LOCAL_RANK", -1)) != -1:
  167. # Deal with spawning deepspeed
  168. if os.environ.get("ACCELERATE_USE_DEEPSPEED", "false").lower() == "true":
  169. if not is_deepspeed_available():
  170. raise ImportError(
  171. "DeepSpeed is not available => install it using `pip3 install deepspeed` or build it from source"
  172. )
  173. from deepspeed import comm as dist
  174. if not dist.is_initialized():
  175. if self.backend == "tccl":
  176. local_rank = os.environ.get("LOCAL_RANK", -1)
  177. torch.sdaa.set_device(f"sdaa:{local_rank}")
  178. dist.init_distributed(dist_backend=self.backend, auto_mpi_discovery=False, **kwargs)
  179. # We need to flag to `use_deepspeed` to be True to override `distributed_type` later
  180. use_deepspeed = True
  181. # Deal with all other backends but XPU and CPU, that gets handled special later
  182. elif (
  183. self.distributed_type not in (DistributedType.MULTI_XPU, DistributedType.MULTI_CPU)
  184. and not torch.distributed.is_initialized()
  185. ):
  186. if self.backend == "tccl":
  187. local_rank = os.environ.get("LOCAL_RANK", -1)
  188. torch.sdaa.set_device(f"sdaa:{local_rank}")
  189. if (
  190. self.backend == "nccl"
  191. and os.environ.get("ACCELERATE_USE_FSDP", "false").lower() == "true"
  192. and (
  193. os.environ.get("FSDP_OFFLOAD_PARAMS", "false").lower() == "true"
  194. or os.environ.get("FSDP_STATE_DICT_TYPE", "SHARDED_STATE_DICT") == "FULL_STATE_DICT"
  195. )
  196. ):
  197. self.backend = "cuda:nccl,cpu:gloo"
  198. torch.distributed.init_process_group(backend=self.backend, **kwargs)
  199. # XPU and CPU require special env configs to be set
  200. if self.distributed_type in (DistributedType.MULTI_XPU, DistributedType.MULTI_CPU):
  201. dist_information = get_cpu_distributed_information()
  202. os.environ["RANK"] = str(dist_information.rank)
  203. os.environ["WORLD_SIZE"] = str(dist_information.world_size)
  204. os.environ["LOCAL_RANK"] = str(dist_information.local_rank)
  205. os.environ["LOCAL_WORLD_SIZE"] = str(dist_information.local_world_size)
  206. if not os.environ.get("MASTER_PORT", None):
  207. os.environ["MASTER_PORT"] = "29500"
  208. if (
  209. not os.environ.get("MASTER_ADDR", None)
  210. and dist_information.local_world_size != dist_information.world_size
  211. and self.backend != "mpi"
  212. ):
  213. raise ValueError(
  214. "Tried to launch on distributed with multinode, but `MASTER_ADDR` env was not set, "
  215. "please try exporting rank 0's hostname as `MASTER_ADDR`"
  216. )
  217. kwargs["rank"] = dist_information.rank
  218. kwargs["world_size"] = dist_information.world_size
  219. if (
  220. self.distributed_type == DistributedType.MULTI_CPU
  221. and get_int_from_env(["OMP_NUM_THREADS"], 0) == 0
  222. ):
  223. import psutil
  224. num_cpu_threads_per_process = int(
  225. psutil.cpu_count(logical=False) / dist_information.local_world_size
  226. )
  227. if num_cpu_threads_per_process == 0:
  228. num_cpu_threads_per_process = 1
  229. torch.set_num_threads(num_cpu_threads_per_process)
  230. warnings.warn(
  231. f"OMP_NUM_THREADS/MKL_NUM_THREADS unset, we set it at {num_cpu_threads_per_process} to improve oob"
  232. " performance."
  233. )
  234. if not torch.distributed.is_initialized():
  235. torch.distributed.init_process_group(backend=self.backend, **kwargs)
  236. # No backend == no distributed training
  237. if self.backend is None:
  238. self.distributed_type = DistributedType.NO
  239. self.num_processes = 1
  240. self.process_index = 0
  241. self.local_process_index = 0
  242. elif self.backend == "xla":
  243. # XLA needs device setting first for `set_replication`
  244. self.set_device()
  245. xm.set_replication(self.device, xm.get_xla_supported_devices())
  246. self.num_processes = xr.world_size()
  247. self.process_index = xr.global_ordinal()
  248. if is_torch_xla_available(check_is_tpu=True):
  249. self.local_process_index = xm.get_local_ordinal()
  250. else:
  251. self.local_process_index = int(os.environ.get("LOCAL_RANK", -1))
  252. else:
  253. self.num_processes = torch.distributed.get_world_size()
  254. self.process_index = torch.distributed.get_rank()
  255. self.local_process_index = (
  256. int(os.environ.get("LOCAL_RANK", -1)) if dist_information is None else dist_information.local_rank
  257. )
  258. self.set_device()
  259. # Now we can change to deepseed
  260. if use_deepspeed:
  261. self.distributed_type = DistributedType.DEEPSPEED
  262. # Set CPU affinity if enabled
  263. if parse_flag_from_env("ACCELERATE_CPU_AFFINITY", False):
  264. set_numa_affinity(self.local_process_index)
  265. # Check for old RTX 4000's that can't use P2P or IB and are on old drivers
  266. if self.device.type == "cuda" and not check_cuda_p2p_ib_support():
  267. if "NCCL_P2P_DISABLE" not in os.environ or "NCCL_IB_DISABLE" not in os.environ:
  268. raise NotImplementedError(
  269. "Using RTX 4000 series doesn't support faster communication broadband via P2P or IB. "
  270. 'Please set `NCCL_P2P_DISABLE="1"` and `NCCL_IB_DISABLE="1" or use `accelerate launch` which '
  271. "will do this automatically."
  272. )
  273. # Important: This should be the *only* code outside of `self.initialized!`
  274. self.fork_launched = parse_flag_from_env("FORK_LAUNCHED", 0)
  275. def __repr__(self) -> str:
  276. return (
  277. f"Distributed environment: {self.distributed_type}{(' Backend: ' + self.backend) if self.backend else ''}\n"
  278. f"Num processes: {self.num_processes}\n"
  279. f"Process index: {self.process_index}\n"
  280. f"Local process index: {self.local_process_index}\n"
  281. f"Device: {self.device}\n"
  282. )
  283. @staticmethod
  284. def _reset_state():
  285. "Resets `_shared_state`, is used internally and should not be called"
  286. PartialState._shared_state.clear()
  287. @property
  288. def initialized(self) -> bool:
  289. "Returns whether the `PartialState` has been initialized"
  290. return self._shared_state != {}
  291. @property
  292. def use_distributed(self):
  293. """
  294. Whether the Accelerator is configured for distributed training
  295. """
  296. return self.distributed_type != DistributedType.NO and self.num_processes > 1
  297. @property
  298. def is_last_process(self) -> bool:
  299. "Returns whether the current process is the last one"
  300. return self.process_index == self.num_processes - 1
  301. @property
  302. def is_main_process(self) -> bool:
  303. "Returns whether the current process is the main process"
  304. return (
  305. self.process_index == 0 if self.distributed_type != DistributedType.MEGATRON_LM else self.is_last_process
  306. )
  307. @property
  308. def is_local_main_process(self) -> bool:
  309. "Returns whether the current process is the main process on the local node"
  310. return (
  311. self.local_process_index == 0
  312. if self.distributed_type != DistributedType.MEGATRON_LM
  313. else self.is_last_process
  314. )
  315. def wait_for_everyone(self):
  316. """
  317. Will stop the execution of the current process until every other process has reached that point (so this does
  318. nothing when the script is only run in one process). Useful to do before saving a model.
  319. Example:
  320. ```python
  321. >>> # Assuming two GPU processes
  322. >>> import time
  323. >>> from accelerate.state import PartialState
  324. >>> state = PartialState()
  325. >>> if state.is_main_process:
  326. ... time.sleep(2)
  327. >>> else:
  328. ... print("I'm waiting for the main process to finish its sleep...")
  329. >>> state.wait_for_everyone()
  330. >>> # Should print on every process at the same time
  331. >>> print("Everyone is here")
  332. ```
  333. """
  334. if self.distributed_type in (
  335. DistributedType.MULTI_GPU,
  336. DistributedType.MULTI_MLU,
  337. DistributedType.MULTI_SDAA,
  338. DistributedType.MULTI_MUSA,
  339. DistributedType.MULTI_NPU,
  340. DistributedType.MULTI_XPU,
  341. DistributedType.MULTI_CPU,
  342. DistributedType.MULTI_HPU,
  343. DistributedType.DEEPSPEED,
  344. DistributedType.FSDP,
  345. ):
  346. torch.distributed.barrier(device_ids=[self.local_process_index])
  347. elif self.distributed_type == DistributedType.XLA:
  348. xm.rendezvous("accelerate.utils.wait_for_everyone")
  349. def _goes_first(self, is_main: bool):
  350. if not is_main:
  351. self.wait_for_everyone()
  352. yield
  353. if is_main:
  354. self.wait_for_everyone()
  355. @contextmanager
  356. def split_between_processes(self, inputs: list | tuple | dict | torch.Tensor, apply_padding: bool = False):
  357. """
  358. Splits `input` between `self.num_processes` quickly and can be then used on that process. Useful when doing
  359. distributed inference, such as with different prompts.
  360. Note that when using a `dict`, all keys need to have the same number of elements.
  361. Args:
  362. inputs (`list`, `tuple`, `torch.Tensor`, `dict` of `list`/`tuple`/`torch.Tensor`, or `datasets.Dataset`):
  363. The input to split between processes.
  364. apply_padding (`bool`, `optional`, defaults to `False`):
  365. Whether to apply padding by repeating the last element of the input so that all processes have the same
  366. number of elements. Useful when trying to perform actions such as `gather()` on the outputs or passing
  367. in less inputs than there are processes. If so, just remember to drop the padded elements afterwards.
  368. Example:
  369. ```python
  370. # Assume there are two processes
  371. from accelerate import PartialState
  372. state = PartialState()
  373. with state.split_between_processes(["A", "B", "C"]) as inputs:
  374. print(inputs)
  375. # Process 0
  376. ["A", "B"]
  377. # Process 1
  378. ["C"]
  379. with state.split_between_processes(["A", "B", "C"], apply_padding=True) as inputs:
  380. print(inputs)
  381. # Process 0
  382. ["A", "B"]
  383. # Process 1
  384. ["C", "C"]
  385. ```
  386. """
  387. if self.num_processes == 1:
  388. yield inputs
  389. return
  390. length = len(inputs)
  391. # Nested dictionary of any types
  392. if isinstance(inputs, dict):
  393. length = len(inputs[list(inputs.keys())[0]])
  394. if not all(len(v) == length for v in inputs.values()):
  395. raise ValueError("All values in the dictionary must have the same length")
  396. num_samples_per_process, num_extras = divmod(length, self.num_processes)
  397. start_index = self.process_index * num_samples_per_process + min(self.process_index, num_extras)
  398. end_index = start_index + num_samples_per_process + (1 if self.process_index < num_extras else 0)
  399. def _split_values(inputs, start_index, end_index):
  400. if isinstance(inputs, (list, tuple, torch.Tensor)):
  401. if start_index >= len(inputs):
  402. result = inputs[-1:]
  403. else:
  404. result = inputs[start_index:end_index]
  405. if apply_padding:
  406. if isinstance(result, torch.Tensor):
  407. from accelerate.utils import pad_across_processes, send_to_device
  408. # The tensor needs to be on the device before we can pad it
  409. tensorized_result = send_to_device(result, self.device)
  410. result = pad_across_processes(tensorized_result, pad_index=inputs[-1])
  411. else:
  412. result += [result[-1]] * (num_samples_per_process + (1 if num_extras > 0 else 0) - len(result))
  413. return result
  414. elif isinstance(inputs, dict):
  415. for key in inputs.keys():
  416. inputs[key] = _split_values(inputs[key], start_index, end_index)
  417. return inputs
  418. else:
  419. if is_datasets_available():
  420. from datasets import Dataset
  421. if isinstance(inputs, Dataset):
  422. if start_index >= len(inputs):
  423. start_index = len(inputs) - 1
  424. if end_index > len(inputs):
  425. end_index = len(inputs)
  426. result_idcs = list(range(start_index, end_index))
  427. if apply_padding:
  428. result_idcs += [end_index - 1] * (
  429. num_samples_per_process + (1 if num_extras > 0 else 0) - len(result_idcs)
  430. )
  431. return inputs.select(result_idcs)
  432. return inputs
  433. yield _split_values(inputs, start_index, end_index)
  434. @contextmanager
  435. def main_process_first(self):
  436. """
  437. Lets the main process go first inside a with block.
  438. The other processes will enter the with block after the main process exits.
  439. Example:
  440. ```python
  441. >>> from accelerate import Accelerator
  442. >>> accelerator = Accelerator()
  443. >>> with accelerator.main_process_first():
  444. ... # This will be printed first by process 0 then in a seemingly
  445. ... # random order by the other processes.
  446. ... print(f"This will be printed by process {accelerator.process_index}")
  447. ```
  448. """
  449. yield from self._goes_first(self.is_main_process)
  450. @contextmanager
  451. def local_main_process_first(self):
  452. """
  453. Lets the local main process go inside a with block.
  454. The other processes will enter the with block after the main process exits.
  455. Example:
  456. ```python
  457. >>> from accelerate.state import PartialState
  458. >>> state = PartialState()
  459. >>> with state.local_main_process_first():
  460. ... # This will be printed first by local process 0 then in a seemingly
  461. ... # random order by the other processes.
  462. ... print(f"This will be printed by process {state.local_process_index}")
  463. ```
  464. """
  465. yield from self._goes_first(self.is_local_main_process)
  466. def on_main_process(self, function: Callable[..., Any] | None = None):
  467. """
  468. Decorator that only runs the decorated function on the main process.
  469. Args:
  470. function (`Callable`): The function to decorate.
  471. Example:
  472. ```python
  473. >>> from accelerate.state import PartialState
  474. >>> state = PartialState()
  475. >>> @state.on_main_process
  476. ... def print_something():
  477. ... print("This will be printed by process 0 only.")
  478. >>> print_something()
  479. "This will be printed by process 0 only"
  480. ```
  481. """
  482. if not self.initialized:
  483. raise ValueError("The `PartialState` or `Accelerator` must be initialized before calling this function.")
  484. if self.is_main_process or not self.use_distributed:
  485. return function
  486. return do_nothing
  487. def on_local_main_process(self, function: Callable[..., Any] | None = None):
  488. """
  489. Decorator that only runs the decorated function on the local main process.
  490. Args:
  491. function (`Callable`): The function to decorate.
  492. Example:
  493. ```python
  494. # Assume we have 2 servers with 4 processes each.
  495. from accelerate.state import PartialState
  496. state = PartialState()
  497. @state.on_local_main_process
  498. def print_something():
  499. print("This will be printed by process 0 only on each server.")
  500. print_something()
  501. # On server 1:
  502. "This will be printed by process 0 only"
  503. # On server 2:
  504. "This will be printed by process 0 only"
  505. ```
  506. """
  507. if self.is_local_main_process or not self.use_distributed:
  508. return function
  509. return do_nothing
  510. def on_last_process(self, function: Callable[..., Any]):
  511. """
  512. Decorator that only runs the decorated function on the last process.
  513. Args:
  514. function (`Callable`): The function to decorate.
  515. Example:
  516. ```python
  517. # Assume we have 4 processes.
  518. from accelerate.state import PartialState
  519. state = PartialState()
  520. @state.on_last_process
  521. def print_something():
  522. print(f"Printed on process {state.process_index}")
  523. print_something()
  524. "Printed on process 3"
  525. ```
  526. """
  527. if self.is_last_process or not self.use_distributed:
  528. return function
  529. return do_nothing
  530. def on_process(self, function: Callable[..., Any] | None = None, process_index: int | None = None):
  531. """
  532. Decorator that only runs the decorated function on the process with the given index.
  533. Args:
  534. function (`Callable`, `optional`):
  535. The function to decorate.
  536. process_index (`int`, `optional`):
  537. The index of the process on which to run the function.
  538. Example:
  539. ```python
  540. # Assume we have 4 processes.
  541. from accelerate.state import PartialState
  542. state = PartialState()
  543. @state.on_process(process_index=2)
  544. def print_something():
  545. print(f"Printed on process {state.process_index}")
  546. print_something()
  547. "Printed on process 2"
  548. ```
  549. """
  550. if function is None:
  551. return partial(self.on_process, process_index=process_index)
  552. if (self.process_index == process_index) or (not self.use_distributed):
  553. return function
  554. return do_nothing
  555. def on_local_process(self, function: Callable[..., Any] | None = None, local_process_index: int | None = None):
  556. """
  557. Decorator that only runs the decorated function on the process with the given index on the current node.
  558. Args:
  559. function (`Callable`, *optional*):
  560. The function to decorate.
  561. local_process_index (`int`, *optional*):
  562. The index of the local process on which to run the function.
  563. Example:
  564. ```python
  565. # Assume we have 2 servers with 4 processes each.
  566. from accelerate import Accelerator
  567. accelerator = Accelerator()
  568. @accelerator.on_local_process(local_process_index=2)
  569. def print_something():
  570. print(f"Printed on process {accelerator.local_process_index}")
  571. print_something()
  572. # On server 1:
  573. "Printed on process 2"
  574. # On server 2:
  575. "Printed on process 2"
  576. ```
  577. """
  578. if function is None:
  579. return partial(self.on_local_process, local_process_index=local_process_index)
  580. if (self.local_process_index == local_process_index) or (not self.use_distributed):
  581. return function
  582. return do_nothing
  583. def print(self, *args, **kwargs):
  584. if self.is_local_main_process:
  585. print(*args, **kwargs)
  586. @property
  587. def default_device(self) -> torch.device:
  588. """
  589. Returns the default device which is:
  590. - MPS if `torch.backends.mps.is_available()` and `torch.backends.mps.is_built()` both return True.
  591. - CUDA if `torch.cuda.is_available()`
  592. - MLU if `is_mlu_available()`
  593. - SDAA if `is_sdaa_available()`
  594. - MUSA if `is_musa_available()`
  595. - NPU if `is_npu_available()`
  596. - HPU if `is_hpu_available()`
  597. - CPU otherwise
  598. """
  599. if is_mps_available():
  600. os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = "1"
  601. return torch.device("mps")
  602. elif is_mlu_available():
  603. return torch.device("mlu")
  604. elif is_sdaa_available():
  605. return torch.device("sdaa")
  606. elif is_musa_available():
  607. return torch.device("musa")
  608. # NPU should be checked before CUDA when using `transfer_to_npu`
  609. # See issue #3020: https://github.com/huggingface/accelerate/issues/3020
  610. elif is_npu_available():
  611. return torch.device("npu")
  612. elif is_hpu_available():
  613. return torch.device("hpu")
  614. elif torch.cuda.is_available():
  615. return torch.device("cuda")
  616. elif is_xpu_available():
  617. return torch.device("xpu")
  618. else:
  619. return torch.device("cpu")
  620. def _prepare_backend(
  621. self, cpu: bool = False, sagemaker_dp=False, backend: str | None = None
  622. ) -> tuple[str, DistributedType]:
  623. "Prepares any imports needed before initializing the distributed backend and sets `self.backend` properly"
  624. distributed_type = None
  625. if sagemaker_dp:
  626. import smdistributed.dataparallel.torch.torch_smddp # noqa
  627. backend = "smddp"
  628. distributed_type = DistributedType.MULTI_GPU
  629. elif is_torch_xla_available():
  630. backend = "xla"
  631. distributed_type = DistributedType.XLA
  632. elif int(os.environ.get("LOCAL_RANK", -1)) != -1 and not cpu:
  633. if is_mlu_available():
  634. backend = "cncl"
  635. distributed_type = DistributedType.MULTI_MLU
  636. if is_sdaa_available():
  637. backend = "tccl"
  638. distributed_type = DistributedType.MULTI_SDAA
  639. elif is_musa_available():
  640. backend = "mccl"
  641. distributed_type = DistributedType.MULTI_MUSA
  642. # NPU should be checked before CUDA when using `transfer_to_npu`
  643. # See issue #3020: https://github.com/huggingface/accelerate/issues/3020
  644. elif is_npu_available():
  645. backend = "hccl"
  646. distributed_type = DistributedType.MULTI_NPU
  647. elif is_hpu_available(init_hccl=True):
  648. if backend is None:
  649. backend = "hccl"
  650. distributed_type = DistributedType.MULTI_HPU
  651. elif torch.cuda.is_available():
  652. if backend is None:
  653. backend = "nccl"
  654. distributed_type = DistributedType.MULTI_GPU
  655. elif is_xpu_available() and is_xccl_available():
  656. if backend is None:
  657. backend = "xccl"
  658. distributed_type = DistributedType.MULTI_XPU
  659. if distributed_type is None and (
  660. int(os.environ.get("LOCAL_RANK", -1)) != -1
  661. or get_int_from_env(["PMI_SIZE", "OMPI_COMM_WORLD_SIZE", "MV2_COMM_WORLD_SIZE", "WORLD_SIZE"], 1) > 1
  662. ):
  663. if not cpu and is_xpu_available():
  664. distributed_type = DistributedType.MULTI_XPU
  665. else:
  666. distributed_type = DistributedType.MULTI_CPU
  667. if (
  668. backend in (None, "ccl")
  669. and is_ccl_available()
  670. and (get_int_from_env(["CCL_WORKER_COUNT"], 0) > 0 or distributed_type == DistributedType.MULTI_XPU)
  671. ):
  672. import oneccl_bindings_for_pytorch # noqa: F401
  673. backend = "ccl"
  674. elif backend in (None, "mpi") and torch.distributed.is_mpi_available():
  675. backend = "mpi"
  676. else:
  677. backend = "gloo"
  678. if distributed_type is None:
  679. distributed_type = DistributedType.NO
  680. return backend, distributed_type
  681. def set_device(self):
  682. """
  683. Sets the device in `self.device` to the current distributed environment.
  684. """
  685. if self.device is not None:
  686. return
  687. if self.distributed_type == DistributedType.NO:
  688. self.device = torch.device("cpu") if self._cpu else self.default_device
  689. return
  690. device = str(self.distributed_type).split(".")[-1].replace("MULTI_", "").lower()
  691. if device not in ("cpu", "gpu", "mlu", "musa", "npu", "xpu", "xla", "hpu", "sdaa"):
  692. raise ValueError(
  693. f"Can't set device for {self.distributed_type} ({device}), verify we should be calling `_set_device()` for it!"
  694. )
  695. if device == "xla":
  696. self.device = xm.xla_device()
  697. elif device == "hpu":
  698. self.device = torch.device("hpu", torch.hpu.current_device())
  699. else:
  700. if device == "gpu":
  701. device = "cuda"
  702. device_module = getattr(torch, device)
  703. device_index = self.local_process_index % device_module.device_count()
  704. self.device = torch.device(device, device_index)
  705. device_module.set_device(self.device)
  706. def destroy_process_group(self, group=None):
  707. """
  708. Destroys the process group. If one is not specified, the default process group is destroyed.
  709. """
  710. if self.fork_launched and group is None:
  711. return
  712. # needed when using torch.distributed.init_process_group
  713. if torch.distributed.is_initialized():
  714. torch.distributed.destroy_process_group(group)
  715. def __getattr__(self, name: str):
  716. # By this point we know that no attributes of `self` contain `name`,
  717. # so we just modify the error message
  718. if name in self._known_attrs:
  719. raise AttributeError(
  720. f"`PartialState` object has no attribute `{name}`. "
  721. "This happens if `PartialState._reset_state()` was called and "
  722. "an `Accelerator` or `PartialState` was not reinitialized."
  723. )
  724. # Raise a typical AttributeError
  725. raise AttributeError(f"'PartialState' object has no attribute '{name}'")
  726. class AcceleratorState:
  727. """
  728. Singleton class that has information about the current training environment.
  729. **Available attributes:**
  730. - **device** (`torch.device`) -- The device to use.
  731. - **distributed_type** ([`~accelerate.state.DistributedType`]) -- The type of distributed environment currently
  732. in use.
  733. - **parallelism_config** ([`~accelerate.utils.ParallelismConfig`]) -- The parallelism configuration for the
  734. current training environment. This is used to configure the distributed training environment.
  735. - **initialized** (`bool`) -- Whether or not the `AcceleratorState` has been initialized from `Accelerator`.
  736. - **local_process_index** (`int`) -- The index of the current process on the current server.
  737. - **mixed_precision** (`str`) -- Whether or not the current script will use mixed precision, and if so the type
  738. of mixed precision being performed. (Choose from 'no','fp16','bf16 or 'fp8').
  739. - **num_processes** (`int`) -- The number of processes currently launched in parallel.
  740. - **process_index** (`int`) -- The index of the current process.
  741. - **is_last_process** (`bool`) -- Whether or not the current process is the last one.
  742. - **is_main_process** (`bool`) -- Whether or not the current process is the main one.
  743. - **is_local_main_process** (`bool`) -- Whether or not the current process is the main one on the local node.
  744. - **debug** (`bool`) -- Whether or not the current script is being run in debug mode.
  745. """
  746. _shared_state = SharedDict()
  747. _known_attrs = PartialState._known_attrs + [
  748. "deepspeed_plugin",
  749. "use_ipex",
  750. "fsdp_plugin",
  751. "megatron_lm_plugin",
  752. "dynamo_plugin",
  753. ]
  754. def __init__(
  755. self,
  756. mixed_precision: str | None = None,
  757. cpu: bool = False,
  758. dynamo_plugin=None,
  759. deepspeed_plugin=None,
  760. fsdp_plugin=None,
  761. torch_tp_plugin=None,
  762. megatron_lm_plugin=None,
  763. parallelism_config=None,
  764. _from_accelerator: bool = False,
  765. **kwargs,
  766. ):
  767. self.__dict__ = self._shared_state
  768. if parse_flag_from_env("ACCELERATE_USE_CPU"):
  769. cpu = True
  770. if PartialState._shared_state == {}:
  771. PartialState(cpu, **kwargs)
  772. self.__dict__.update(PartialState._shared_state)
  773. self._check_initialized(mixed_precision, cpu)
  774. if not self.initialized:
  775. self.deepspeed_plugins = None
  776. self.use_ipex = None
  777. self.torch_tp_plugin = torch_tp_plugin
  778. self.parallelism_config = parallelism_config
  779. self.device_mesh = None
  780. mixed_precision = (
  781. parse_choice_from_env("ACCELERATE_MIXED_PRECISION", "no")
  782. if mixed_precision is None
  783. else mixed_precision.lower()
  784. )
  785. if mixed_precision == "fp8":
  786. # this is confusing, why is is_fp8_available only checks for library availability ?
  787. if not is_fp8_available():
  788. raise ValueError(
  789. "Using `fp8` precision requires `transformer_engine` or `MS-AMP` to be installed."
  790. )
  791. elif torch.cuda.is_available() and not check_cuda_fp8_capability():
  792. logger.warning(
  793. f"The current device has compute capability of {torch.cuda.get_device_capability()} which is "
  794. "insufficient for FP8 mixed precision training (requires a GPU Hopper/Ada Lovelace "
  795. "or higher, compute capability of 8.9 or higher). Will use FP16 instead."
  796. )
  797. mixed_precision = "fp16"
  798. elif is_habana_gaudi1():
  799. logger.warning(
  800. "The current HPU device is Gaudi1 which does not support FP8 mixed precision training (requires "
  801. "Gaudi2 or higher). Will use BF16 instead."
  802. )
  803. mixed_precision = "bf16"
  804. self.dynamo_plugin = dynamo_plugin
  805. if not _from_accelerator:
  806. raise ValueError(
  807. "Please make sure to properly initialize your accelerator via `accelerator = Accelerator()` "
  808. "before using any functionality from the `accelerate` library."
  809. )
  810. # deepspeed handles mixed_precision using deepspeed_config. But we need to set it to fp8
  811. # if we're using fp8.
  812. if self.distributed_type == DistributedType.DEEPSPEED and mixed_precision != "fp8":
  813. self._mixed_precision = "no"
  814. else:
  815. self._mixed_precision = mixed_precision
  816. if self.distributed_type == DistributedType.XLA and is_torch_xla_available(check_is_tpu=True):
  817. if mixed_precision == "bf16":
  818. if os.environ.get("ACCELERATE_DOWNCAST_BF16"):
  819. os.environ["XLA_USE_BF16"] = str(0)
  820. os.environ["XLA_DOWNCAST_BF16"] = str(1)
  821. self.downcast_bfloat = True
  822. else:
  823. os.environ["XLA_USE_BF16"] = str(1)
  824. os.environ["XLA_DOWNCAST_BF16"] = str(0)
  825. self.downcast_bfloat = False
  826. elif os.environ.get("ACCELERATE_USE_DEEPSPEED", "false").lower() == "true" and not cpu:
  827. self.distributed_type = DistributedType.DEEPSPEED
  828. if not isinstance(deepspeed_plugin, dict):
  829. deepspeed_plugin.set_mixed_precision(mixed_precision)
  830. deepspeed_plugin.select(_from_accelerator_state=True)
  831. else:
  832. for plugin in deepspeed_plugin.values():
  833. plugin.set_mixed_precision(mixed_precision)
  834. # The first plugin passed in is always the active one
  835. first_plugin = next(iter(deepspeed_plugin.values()))
  836. first_plugin.select(_from_accelerator_state=True)
  837. self.deepspeed_plugins = deepspeed_plugin
  838. elif self.distributed_type in [
  839. DistributedType.MULTI_GPU,
  840. DistributedType.MULTI_MLU,
  841. DistributedType.MULTI_SDAA,
  842. DistributedType.MULTI_MUSA,
  843. DistributedType.MULTI_NPU,
  844. DistributedType.MULTI_XPU,
  845. DistributedType.MULTI_HPU,
  846. ]:
  847. # TODO: Siro - remove when axolotl fixes their side
  848. if not os.environ.get("ACCELERATE_ALLOW_CP_STANDALONE", "false").lower() == "true":
  849. if self.parallelism_config and self.parallelism_config.cp_enabled and fsdp_plugin is None:
  850. raise ValueError(
  851. "`cp_size > 1` specified in the `parallelism_config`, but no `fsdp_plugin` was provided. We need a `fsdp_plugin` to use context parallelism with `cp_backend=torch`, as we also shard the model across the device mesh to save more memory"
  852. )
  853. if (
  854. self.parallelism_config is not None
  855. and self.parallelism_config.cp_enabled
  856. and fsdp_plugin.fsdp_version == 1
  857. ):
  858. raise ValueError(
  859. "Using `cp_size>1` requires FSDP2, but the provided `fsdp_plugin` is using FSDP1. "
  860. )
  861. if (os.environ.get("ACCELERATE_USE_FSDP", "false").lower() == "true" or fsdp_plugin is not None) or (
  862. self.parallelism_config is not None and self.parallelism_config.cp_enabled
  863. ):
  864. self.distributed_type = DistributedType.FSDP
  865. if self._mixed_precision != "no" and fsdp_plugin is not None:
  866. fsdp_plugin.set_mixed_precision(self._mixed_precision)
  867. self.fsdp_plugin = fsdp_plugin
  868. if os.environ.get(
  869. "ACCELERATE_USE_MEGATRON_LM", "false"
  870. ).lower() == "true" and self.distributed_type not in [
  871. DistributedType.MULTI_XPU,
  872. ]:
  873. self.distributed_type = DistributedType.MEGATRON_LM
  874. megatron_lm_plugin.set_mixed_precision(self._mixed_precision)
  875. self.megatron_lm_plugin = megatron_lm_plugin
  876. elif self.distributed_type in [DistributedType.MULTI_CPU, DistributedType.MULTI_XPU, DistributedType.NO]:
  877. if is_ipex_available():
  878. # check if user disables it explicitly
  879. self.use_ipex = parse_flag_from_env("ACCELERATE_USE_IPEX", default=True)
  880. else:
  881. self.use_ipex = False
  882. if (
  883. self.dynamo_plugin.backend != DynamoBackend.NO
  884. and self._mixed_precision == "no"
  885. and self.device.type == "cuda"
  886. ):
  887. torch.backends.cuda.matmul.allow_tf32 = True
  888. if (
  889. self.dynamo_plugin.backend != DynamoBackend.NO
  890. and self._mixed_precision == "no"
  891. and self.device.type == "musa"
  892. ):
  893. torch.backends.musa.matmul.allow_tf32 = True
  894. PartialState._shared_state["distributed_type"] = self.distributed_type
  895. @property
  896. def initialized(self) -> bool:
  897. return self._shared_state != PartialState._shared_state
  898. def __repr__(self):
  899. repr = PartialState().__repr__() + f"\nMixed precision type: {self.mixed_precision}\n"
  900. if self.distributed_type == DistributedType.DEEPSPEED:
  901. repr += f"ds_config: {self.deepspeed_plugin.deepspeed_config}\n"
  902. return repr
  903. def _check_initialized(self, mixed_precision=None, cpu=None):
  904. "Checks if a modification is trying to be made and the `AcceleratorState` has already been initialized"
  905. if self.initialized:
  906. err = "AcceleratorState has already been initialized and cannot be changed, restart your runtime completely and pass `{flag}` to `Accelerator()`."
  907. if cpu and self.device.type != "cpu":
  908. raise ValueError(err.format(flag="cpu=True"))
  909. if (
  910. mixed_precision is not None
  911. and mixed_precision != self._mixed_precision
  912. and self.distributed_type != DistributedType.DEEPSPEED
  913. ):
  914. raise ValueError(err.format(flag=f"mixed_precision='{mixed_precision}'"))
  915. @property
  916. def mixed_precision(self):
  917. if self.distributed_type == DistributedType.DEEPSPEED and self._mixed_precision != "fp8":
  918. config = self.deepspeed_plugin.deepspeed_config
  919. if config.get("fp16", {}).get("enabled", False):
  920. mixed_precision = "fp16"
  921. elif config.get("bf16", {}).get("enabled", False):
  922. mixed_precision = "bf16"
  923. else:
  924. mixed_precision = "no"
  925. else:
  926. mixed_precision = self._mixed_precision
  927. return mixed_precision
  928. @staticmethod
  929. def _reset_state(reset_partial_state: bool = False):
  930. "Resets `_shared_state`, is used internally and should not be called"
  931. AcceleratorState._shared_state.clear()
  932. if reset_partial_state:
  933. PartialState._reset_state()
  934. def destroy_process_group(self, group=None):
  935. """
  936. Destroys the process group. If one is not specified, the default process group is destroyed.
  937. If `self.fork_launched` is `True` and `group` is `None`, nothing happens.
  938. """
  939. PartialState().destroy_process_group(group)
  940. @property
  941. def fork_launched(self):
  942. return PartialState().fork_launched
  943. @property
  944. def use_distributed(self):
  945. """
  946. Whether the Accelerator is configured for distributed training
  947. """
  948. return PartialState().use_distributed
  949. @property
  950. def is_fsdp2(self) -> bool:
  951. return self.distributed_type == DistributedType.FSDP and self.fsdp_plugin.fsdp_version == 2
  952. @property
  953. def is_last_process(self) -> bool:
  954. "Returns whether the current process is the last one"
  955. return PartialState().is_last_process
  956. @property
  957. def is_main_process(self) -> bool:
  958. "Returns whether the current process is the main process"
  959. return PartialState().is_main_process
  960. @property
  961. def is_local_main_process(self) -> bool:
  962. "Returns whether the current process is the main process on the local node"
  963. return PartialState().is_local_main_process
  964. def wait_for_everyone(self):
  965. PartialState().wait_for_everyone()
  966. @contextmanager
  967. def split_between_processes(self, inputs: list | tuple | dict | torch.Tensor, apply_padding: bool = False):
  968. """
  969. Splits `input` between `self.num_processes` quickly and can be then used on that process. Useful when doing
  970. distributed inference, such as with different prompts.
  971. Note that when using a `dict`, all keys need to have the same number of elements.
  972. Args:
  973. inputs (`list`, `tuple`, `torch.Tensor`, or `dict` of `list`/`tuple`/`torch.Tensor`):
  974. The input to split between processes.
  975. apply_padding (`bool`, `optional`, defaults to `False`):
  976. Whether to apply padding by repeating the last element of the input so that all processes have the same
  977. number of elements. Useful when trying to perform actions such as `gather()` on the outputs or passing
  978. in less inputs than there are processes. If so, just remember to drop the padded elements afterwards.
  979. Example:
  980. ```python
  981. # Assume there are two processes
  982. from accelerate.state import AcceleratorState
  983. state = AcceleratorState()
  984. with state.split_between_processes(["A", "B", "C"]) as inputs:
  985. print(inputs)
  986. # Process 0
  987. ["A", "B"]
  988. # Process 1
  989. ["C"]
  990. with state.split_between_processes(["A", "B", "C"], apply_padding=True) as inputs:
  991. print(inputs)
  992. # Process 0
  993. ["A", "B"]
  994. # Process 1
  995. ["C", "C"]
  996. ```
  997. """
  998. with PartialState().split_between_processes(inputs, apply_padding=apply_padding) as inputs:
  999. yield inputs
  1000. @contextmanager
  1001. def main_process_first(self):
  1002. """
  1003. Lets the main process go first inside a with block.
  1004. The other processes will enter the with block after the main process exits.
  1005. """
  1006. with PartialState().main_process_first():
  1007. yield
  1008. @contextmanager
  1009. def local_main_process_first(self):
  1010. """
  1011. Lets the local main process go inside a with block.
  1012. The other processes will enter the with block after the main process exits.
  1013. """
  1014. with PartialState().local_main_process_first():
  1015. yield
  1016. @property
  1017. def deepspeed_plugin(self):
  1018. """
  1019. Returns the currently active DeepSpeedPlugin.
  1020. If not using deepspeed, returns `None`.
  1021. """
  1022. # To maintain original behavior, return None if not using deepspeed.
  1023. if self.distributed_type != DistributedType.DEEPSPEED:
  1024. return None
  1025. from accelerate.utils.deepspeed import get_active_deepspeed_plugin
  1026. return get_active_deepspeed_plugin(self)
  1027. @deepspeed_required
  1028. def get_deepspeed_plugin(self, name: str):
  1029. """
  1030. Returns the DeepSpeedPlugin with the given plugin_key.
  1031. """
  1032. return self.deepspeed_plugins[name]
  1033. @deepspeed_required
  1034. def select_deepspeed_plugin(self, name: str | None = None):
  1035. """
  1036. Activates the DeepSpeedPlugin with the given `name`, and will disable all other plugins.
  1037. """
  1038. for key, plugin in self.deepspeed_plugins.items():
  1039. if key != name:
  1040. plugin._unselect()
  1041. self.deepspeed_plugins[name].select(_from_accelerator_state=True)
  1042. def print(self, *args, **kwargs):
  1043. PartialState().print(*args, **kwargs)
  1044. def __getattr__(self, name: str):
  1045. # By this point we know that no attributes of `self` contain `name`,
  1046. # so we just modify the error message
  1047. if name in self._known_attrs:
  1048. raise AttributeError(
  1049. f"`AcceleratorState` object has no attribute `{name}`. "
  1050. "This happens if `AcceleratorState._reset_state()` was called and "
  1051. "an `Accelerator` or `PartialState` was not reinitialized."
  1052. )
  1053. # Raise a typical AttributeError
  1054. raise AttributeError(f"'AcceleratorState' object has no attribute '{name}'")
  1055. class GradientState:
  1056. """
  1057. Singleton class that has information related to gradient synchronization for gradient accumulation
  1058. **Available attributes:**
  1059. - **end_of_dataloader** (`bool`) -- Whether we have reached the end the current dataloader
  1060. - **remainder** (`int`) -- The number of extra samples that were added from padding the dataloader
  1061. - **sync_gradients** (`bool`) -- Whether the gradients should be synced across all devices
  1062. - **active_dataloader** (`Optional[DataLoader]`) -- The dataloader that is currently being iterated over
  1063. - **dataloader_references** (`List[Optional[DataLoader]]`) -- A list of references to the dataloaders that are
  1064. being iterated over
  1065. - **num_steps** (`int`) -- The number of steps to accumulate over
  1066. - **adjust_scheduler** (`bool`) -- Whether the scheduler should be adjusted to account for the gradient
  1067. accumulation
  1068. - **sync_with_dataloader** (`bool`) -- Whether the gradients should be synced at the end of the dataloader
  1069. iteration and the number of total steps reset
  1070. - **is_xla_gradients_synced** (`bool`) -- Whether the XLA gradients have been synchronized. It is initialized
  1071. as false. Once gradients have been reduced before the optimizer step, this flag is set to true. Subsequently,
  1072. after each step, the flag is reset to false. FSDP will always synchronize the gradients, hence
  1073. is_xla_gradients_synced is always true.
  1074. """
  1075. _shared_state = SharedDict()
  1076. def __init__(self, gradient_accumulation_plugin: GradientAccumulationPlugin | None = None):
  1077. self.__dict__ = self._shared_state
  1078. if not self.initialized:
  1079. self.sync_gradients = True
  1080. self._dataloader_references_ref = [None]
  1081. self.plugin_kwargs = (
  1082. gradient_accumulation_plugin.to_kwargs() if gradient_accumulation_plugin is not None else {}
  1083. )
  1084. self._is_xla_gradients_synced = False
  1085. # Plugin args are different and can be updated
  1086. if gradient_accumulation_plugin is not None and self.plugin_kwargs != gradient_accumulation_plugin.to_kwargs():
  1087. self.plugin_kwargs = gradient_accumulation_plugin.to_kwargs()
  1088. @property
  1089. def num_steps(self) -> int:
  1090. "Returns the number of steps to accumulate over"
  1091. return self.plugin_kwargs.get("num_steps", 1)
  1092. @property
  1093. def adjust_scheduler(self) -> bool:
  1094. "Returns whether the scheduler should be adjusted"
  1095. return self.plugin_kwargs.get("adjust_scheduler", False)
  1096. @property
  1097. def sync_with_dataloader(self) -> bool:
  1098. "Returns whether the gradients should be synced at the end of the dataloader iteration and the number of total steps reset"
  1099. return self.plugin_kwargs.get("sync_with_dataloader", True)
  1100. @property
  1101. def initialized(self) -> bool:
  1102. "Returns whether the `GradientState` has been initialized"
  1103. return GradientState._shared_state != {}
  1104. @property
  1105. def end_of_dataloader(self) -> bool:
  1106. "Returns whether we have reached the end of the current dataloader"
  1107. if not self.in_dataloader:
  1108. return False
  1109. return self.active_dataloader.end_of_dataloader
  1110. @property
  1111. def remainder(self) -> int:
  1112. "Returns the number of extra samples that were added from padding the dataloader"
  1113. if not self.in_dataloader:
  1114. return -1
  1115. return self.active_dataloader.remainder
  1116. def __repr__(self):
  1117. return (
  1118. f"Sync Gradients: {self.sync_gradients}\n"
  1119. f"At end of current dataloader: {self.end_of_dataloader}\n"
  1120. f"Extra samples added: {self.remainder}\n"
  1121. f"Gradient accumulation plugin: {self.plugin_kwargs}\n"
  1122. )
  1123. @property
  1124. def is_xla_gradients_synced(self):
  1125. "Returns the value of is_xla_gradients_synced. FSDP will always synchronize the gradients, hence is_xla_gradients_synced is always true."
  1126. if parse_flag_from_env("ACCELERATE_USE_FSDP", default=False):
  1127. return True
  1128. return self._is_xla_gradients_synced
  1129. @is_xla_gradients_synced.setter
  1130. def is_xla_gradients_synced(self, is_synced):
  1131. "Set the _is_xla_gradients_synced attribute."
  1132. self._is_xla_gradients_synced = is_synced
  1133. def _set_sync_gradients(self, sync_gradients):
  1134. "Private function that sets whether gradients should be synchronized. Users should not have to call this."
  1135. self.sync_gradients = sync_gradients
  1136. # Allow grad-sync to automatically work on TPUs
  1137. if (
  1138. self.sync_gradients
  1139. and is_torch_xla_available(check_is_tpu=True)
  1140. and PartialState().distributed_type == DistributedType.XLA
  1141. ):
  1142. xm.mark_step()
  1143. def _add_dataloader(self, dataloader):
  1144. "Private function that adds a dataloader to `self.dataloader_references` and sets `in_dataloader` to `True`. Users should not have to call this."
  1145. # We explicitly use assignment to ensure that the property setter is triggered, which is required for garbage collection.
  1146. # Avoid using self.dataloader_references.append as it will not trigger the setter.
  1147. self.dataloader_references += [dataloader]
  1148. def _remove_dataloader(self, dataloader):
  1149. "Private function that removes a dataloader from `self.dataloader_references` and sets `in_dataloader` to `False` if there are no more dataloaders. Users should not have to call this."
  1150. # We explicitly use assignment to ensure that the property setter is triggered.
  1151. self.dataloader_references = [
  1152. dataloader_ref for dataloader_ref in self.dataloader_references if dataloader_ref != dataloader
  1153. ]
  1154. @property
  1155. def active_dataloader(self):
  1156. return self.dataloader_references[-1]
  1157. @property
  1158. def dataloader_references(self):
  1159. # We use a property getter and setter with weakrefs to avoid circular references that prevent garbage collection
  1160. return [reference() if reference is not None else reference for reference in self._dataloader_references_ref]
  1161. @dataloader_references.setter
  1162. def dataloader_references(self, references):
  1163. self._dataloader_references_ref = [
  1164. weakref.ref(dataloader) if dataloader is not None else dataloader for dataloader in references
  1165. ]
  1166. @property
  1167. def in_dataloader(self) -> bool:
  1168. "Returns whether the current process is in a dataloader"
  1169. return self.active_dataloader is not None
  1170. @staticmethod
  1171. def _reset_state():
  1172. "Resets `_shared_state`, is used internally and should not be called"
  1173. GradientState._shared_state.clear()