| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693694695696697698699700701702
70370470570670770870971071171271371471571671771871972072172272372472572672772872973073173273373473573673773873974074174274374474574674774874975075175275375475575675775875976076176276376476576676776876977077177277377477577677777877978078178278378478578678778878979079179279379479579679779879980080180280380480580680780880981081181281381481581681781881982082182282382482582682782882983083183283383483583683783883984084184284384484584684784884985085185285385485585685785885986086186286386486586686786886987087187287387487587687787887988088188288388488588688788888989089189289389489589689789889990090190290390490590690790890991091191291391491591691791891992092192292392492592692792892993093193293393493593693793893994094194294394494594694794894995095195295395495595695795895996096196296396496596696796896997097197297397497597697797897998098198298398498598698798898999099199299399499599699799899910001001100210031004100510061007100810091010101110121013101410151016101710181019102010211022102310241025102610271028102910301031103210331034103510361037103810391040104110421043104410451046104710481049105010511052105310541055105610571058105910601061106210631064106510661067106810691070107110721073107410751076107710781079108010811082108310841085108610871088108910901091109210931094109510961097109810991100110111021103110411051106110711081109111011111112111311141115111611171118111911201121112211231124112511261127112811291130113111321133113411351136113711381139114011411142114311441145114611471148114911501151115211531154115511561157115811591160116111621163116411651166116711681169117011711172117311741175117611771178117911801181118211831184118511861187118811891190119111921193119411951196119711981199120012011202120312041205120612071208120912101211121212131214121512161217121812191220122112221223122412251226122712281229123012311232123312341235123612371238123912401241124212431244124512461247124812491250125112521253125412551256125712581259126012611262126312641265126612671268126912701271127212731274127512761
277127812791280128112821283128412851286128712881289129012911292129312941295129612971298129913001301130213031304130513061307130813091310131113121313131413151316131713181319132013211322132313241325132613271328132913301331133213331334133513361337133813391340134113421343134413451346134713481349135013511352135313541355135613571358135913601361136213631364136513661367136813691370137113721373 |
- # Copyright 2021 The HuggingFace Team. All rights reserved.
- #
- # Licensed under the Apache License, Version 2.0 (the "License");
- # you may not use this file except in compliance with the License.
- # You may obtain a copy of the License at
- #
- # http://www.apache.org/licenses/LICENSE-2.0
- #
- # Unless required by applicable law or agreed to in writing, software
- # distributed under the License is distributed on an "AS IS" BASIS,
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- # See the License for the specific language governing permissions and
- # limitations under the License.
- from __future__ import annotations
- import logging
- import os
- import threading
- import warnings
- import weakref
- from contextlib import contextmanager
- from functools import partial
- from typing import Any, Callable
- import torch
- from .utils import (
- DistributedType,
- DynamoBackend,
- GradientAccumulationPlugin,
- check_cuda_fp8_capability,
- check_cuda_p2p_ib_support,
- deepspeed_required,
- get_cpu_distributed_information,
- get_int_from_env,
- is_ccl_available,
- is_datasets_available,
- is_deepspeed_available,
- is_fp8_available,
- is_habana_gaudi1,
- is_hpu_available,
- is_ipex_available,
- is_mlu_available,
- is_mps_available,
- is_musa_available,
- is_npu_available,
- is_sdaa_available,
- is_torch_xla_available,
- is_xccl_available,
- is_xpu_available,
- parse_choice_from_env,
- parse_flag_from_env,
- set_numa_affinity,
- )
- from .utils.dataclasses import SageMakerDistributedType
- if is_torch_xla_available():
- import torch_xla.core.xla_model as xm
- import torch_xla.runtime as xr
- if is_mlu_available(check_device=False):
- import torch_mlu # noqa: F401
- if is_sdaa_available(check_device=False):
- import torch_sdaa # noqa: F401
- if is_musa_available(check_device=False):
- import torch_musa # noqa: F401
- if is_npu_available(check_device=False):
- import torch_npu # noqa: F401
# Module-level logger named after this module.
logger = logging.getLogger(name=__name__)
def is_initialized() -> bool:
    """
    Module-level equivalent of `AcceleratorState.initialized`.

    Returns:
        `bool`: `True` when the shared `AcceleratorState` dictionary has been populated (i.e. the state was set up
        from `Accelerator`), `False` while it is still empty.
    """
    shared_state = AcceleratorState._shared_state
    return len(shared_state) > 0
def do_nothing(*args, **kwargs):
    """
    No-op placeholder: accepts any positional and keyword arguments, ignores them and returns `None`.

    The `PartialState.on_*_process` decorators return this instead of the decorated function on processes that
    should skip the call.
    """
class ThreadLocalSharedDict(threading.local):
    """
    Descriptor that holds a dict shared between instances of a class in the same thread.

    Note: descriptors have slightly different semantics than a plain dict field. Both `PartialState(...)._shared_state`
    (instance access) and `PartialState._shared_state` (class access) return the same underlying `_storage` dict, and
    `PartialState(...)._shared_state = {...}` replaces that `_storage` dict as you would expect. However,
    `PartialState._shared_state = {}` replaces the descriptor object itself with a plain dict. Therefore, modify
    the `_storage` dict in place instead (e.g. `_shared_state.clear()`).

    See the Python documentation for an explanation of descriptors: https://docs.python.org/3/howto/descriptor.html

    This is required for using PyTorch/XLA with PJRT in multithreaded mode (required for TPU v2 and v3).
    See https://github.com/pytorch/xla/blob/r2.0/docs/pjrt.md#multithreading-on-tpu-v2v3
    """

    def __init__(self, thread_local: bool = False):
        # `thread_local` is accepted but not used. Because this class subclasses `threading.local`, the
        # `_storage` attribute set here is kept separately for each thread.
        self._storage = dict()

    def __get__(self, obj, objtype=None):
        # Instance access and class access (`obj is None`) both yield this thread's dict.
        storage = self._storage
        return storage

    def __set__(self, obj, value):
        # Swap in a new dict for the current thread.
        setattr(self, "_storage", value)
# Shared-state container used by the Borg-style state classes: a plain process-wide `dict` normally, but the
# per-thread `ThreadLocalSharedDict` descriptor whenever torch_xla is available (TPU).
if is_torch_xla_available():
    SharedDict = ThreadLocalSharedDict
else:
    SharedDict = dict
- # Inspired by Alex Martelli's 'Borg'.
- class PartialState:
- """
- Singleton class that has information about the current training environment and functions to help with process
- control. Designed to be used when only process control and device execution states are needed. Does *not* need to
- be initialized from `Accelerator`.
- Args:
- cpu (`bool`, *optional*):
- Whether or not to force the script to execute on CPU. Will ignore any accelerators available if set to
- `True` and force the execution on the CPU.
- kwargs (additional keyword arguments, *optional*):
- Additional keyword arguments to pass to the relevant `init_process_group` function. Valid `kwargs` can be
- found in [`utils.InitProcessGroupKwargs`]. See the example section for detailed usage.
- **Available attributes:**
- - **device** (`torch.device`) -- The device to use.
- - **distributed_type** ([`~accelerate.state.DistributedType`]) -- The type of distributed environment currently
- in use.
- - **local_process_index** (`int`) -- The index of the current process on the current server.
- - **mixed_precision** (`str`) -- Whether or not the current script will use mixed precision, and if so the type
- of mixed precision being performed. (Choose from 'no','fp16','bf16 or 'fp8').
- - **num_processes** (`int`) -- The number of processes currently launched in parallel.
- - **process_index** (`int`) -- The index of the current process.
- - **is_last_process** (`bool`) -- Whether or not the current process is the last one.
- - **is_main_process** (`bool`) -- Whether or not the current process is the main one.
- - **is_local_main_process** (`bool`) -- Whether or not the current process is the main one on the local node.
- - **debug** (`bool`) -- Whether or not the current script is being run in debug mode.
- Example:
- ```python
- from accelerate.utils import InitProcessGroupKwargs
- # To include `InitProcessGroupKwargs`, init then call `.to_kwargs()`
- kwargs = InitProcessGroupKwargs(...).to_kwargs()
- state = PartialState(**kwargs)
- ```
- """
- _shared_state = SharedDict()
- _known_attrs = [
- "_cpu",
- "_mixed_precision",
- "_shared_state",
- "backend",
- "debug",
- "device",
- "distributed_type",
- "fork_launched",
- "local_process_index",
- "num_processes",
- "process_index",
- ]
- def __init__(self, cpu: bool = False, **kwargs):
- self.__dict__ = self._shared_state
- if not self.initialized:
- self._cpu = cpu
- self.backend = None
- env_device = os.environ.get("ACCELERATE_TORCH_DEVICE", None)
- self.device = torch.device(env_device) if env_device is not None else None
- self.debug = parse_flag_from_env("ACCELERATE_DEBUG_MODE")
- use_sagemaker_dp = kwargs.pop("_use_sagemaker_dp", None)
- dist_information = None
- if use_sagemaker_dp is None:
- use_sagemaker_dp = (
- os.environ.get("ACCELERATE_USE_SAGEMAKER", "false").lower() == "true"
- and os.environ.get("ACCELERATE_SAGEMAKER_DISTRIBUTED_TYPE") != SageMakerDistributedType.NO
- )
- # Sets up self.backend + imports
- original_backend = kwargs.pop("backend", None)
- backend, distributed_type = self._prepare_backend(cpu, use_sagemaker_dp, original_backend)
- if original_backend is not None and backend != original_backend:
- raise ValueError(f"Your assigned backend {original_backend} is not available, please use {backend}")
- self.backend = backend
- self.distributed_type = distributed_type
- use_deepspeed = False
- if not cpu and self.backend != "xla":
- if int(os.environ.get("LOCAL_RANK", -1)) != -1:
- # Deal with spawning deepspeed
- if os.environ.get("ACCELERATE_USE_DEEPSPEED", "false").lower() == "true":
- if not is_deepspeed_available():
- raise ImportError(
- "DeepSpeed is not available => install it using `pip3 install deepspeed` or build it from source"
- )
- from deepspeed import comm as dist
- if not dist.is_initialized():
- if self.backend == "tccl":
- local_rank = os.environ.get("LOCAL_RANK", -1)
- torch.sdaa.set_device(f"sdaa:{local_rank}")
- dist.init_distributed(dist_backend=self.backend, auto_mpi_discovery=False, **kwargs)
- # We need to flag to `use_deepspeed` to be True to override `distributed_type` later
- use_deepspeed = True
- # Deal with all other backends but XPU and CPU, that gets handled special later
- elif (
- self.distributed_type not in (DistributedType.MULTI_XPU, DistributedType.MULTI_CPU)
- and not torch.distributed.is_initialized()
- ):
- if self.backend == "tccl":
- local_rank = os.environ.get("LOCAL_RANK", -1)
- torch.sdaa.set_device(f"sdaa:{local_rank}")
- if (
- self.backend == "nccl"
- and os.environ.get("ACCELERATE_USE_FSDP", "false").lower() == "true"
- and (
- os.environ.get("FSDP_OFFLOAD_PARAMS", "false").lower() == "true"
- or os.environ.get("FSDP_STATE_DICT_TYPE", "SHARDED_STATE_DICT") == "FULL_STATE_DICT"
- )
- ):
- self.backend = "cuda:nccl,cpu:gloo"
- torch.distributed.init_process_group(backend=self.backend, **kwargs)
- # XPU and CPU require special env configs to be set
- if self.distributed_type in (DistributedType.MULTI_XPU, DistributedType.MULTI_CPU):
- dist_information = get_cpu_distributed_information()
- os.environ["RANK"] = str(dist_information.rank)
- os.environ["WORLD_SIZE"] = str(dist_information.world_size)
- os.environ["LOCAL_RANK"] = str(dist_information.local_rank)
- os.environ["LOCAL_WORLD_SIZE"] = str(dist_information.local_world_size)
- if not os.environ.get("MASTER_PORT", None):
- os.environ["MASTER_PORT"] = "29500"
- if (
- not os.environ.get("MASTER_ADDR", None)
- and dist_information.local_world_size != dist_information.world_size
- and self.backend != "mpi"
- ):
- raise ValueError(
- "Tried to launch on distributed with multinode, but `MASTER_ADDR` env was not set, "
- "please try exporting rank 0's hostname as `MASTER_ADDR`"
- )
- kwargs["rank"] = dist_information.rank
- kwargs["world_size"] = dist_information.world_size
- if (
- self.distributed_type == DistributedType.MULTI_CPU
- and get_int_from_env(["OMP_NUM_THREADS"], 0) == 0
- ):
- import psutil
- num_cpu_threads_per_process = int(
- psutil.cpu_count(logical=False) / dist_information.local_world_size
- )
- if num_cpu_threads_per_process == 0:
- num_cpu_threads_per_process = 1
- torch.set_num_threads(num_cpu_threads_per_process)
- warnings.warn(
- f"OMP_NUM_THREADS/MKL_NUM_THREADS unset, we set it at {num_cpu_threads_per_process} to improve oob"
- " performance."
- )
- if not torch.distributed.is_initialized():
- torch.distributed.init_process_group(backend=self.backend, **kwargs)
- # No backend == no distributed training
- if self.backend is None:
- self.distributed_type = DistributedType.NO
- self.num_processes = 1
- self.process_index = 0
- self.local_process_index = 0
- elif self.backend == "xla":
- # XLA needs device setting first for `set_replication`
- self.set_device()
- xm.set_replication(self.device, xm.get_xla_supported_devices())
- self.num_processes = xr.world_size()
- self.process_index = xr.global_ordinal()
- if is_torch_xla_available(check_is_tpu=True):
- self.local_process_index = xm.get_local_ordinal()
- else:
- self.local_process_index = int(os.environ.get("LOCAL_RANK", -1))
- else:
- self.num_processes = torch.distributed.get_world_size()
- self.process_index = torch.distributed.get_rank()
- self.local_process_index = (
- int(os.environ.get("LOCAL_RANK", -1)) if dist_information is None else dist_information.local_rank
- )
- self.set_device()
- # Now we can change to deepseed
- if use_deepspeed:
- self.distributed_type = DistributedType.DEEPSPEED
- # Set CPU affinity if enabled
- if parse_flag_from_env("ACCELERATE_CPU_AFFINITY", False):
- set_numa_affinity(self.local_process_index)
- # Check for old RTX 4000's that can't use P2P or IB and are on old drivers
- if self.device.type == "cuda" and not check_cuda_p2p_ib_support():
- if "NCCL_P2P_DISABLE" not in os.environ or "NCCL_IB_DISABLE" not in os.environ:
- raise NotImplementedError(
- "Using RTX 4000 series doesn't support faster communication broadband via P2P or IB. "
- 'Please set `NCCL_P2P_DISABLE="1"` and `NCCL_IB_DISABLE="1" or use `accelerate launch` which '
- "will do this automatically."
- )
- # Important: This should be the *only* code outside of `self.initialized!`
- self.fork_launched = parse_flag_from_env("FORK_LAUNCHED", 0)
- def __repr__(self) -> str:
- return (
- f"Distributed environment: {self.distributed_type}{(' Backend: ' + self.backend) if self.backend else ''}\n"
- f"Num processes: {self.num_processes}\n"
- f"Process index: {self.process_index}\n"
- f"Local process index: {self.local_process_index}\n"
- f"Device: {self.device}\n"
- )
- @staticmethod
- def _reset_state():
- "Resets `_shared_state`, is used internally and should not be called"
- PartialState._shared_state.clear()
- @property
- def initialized(self) -> bool:
- "Returns whether the `PartialState` has been initialized"
- return self._shared_state != {}
- @property
- def use_distributed(self):
- """
- Whether the Accelerator is configured for distributed training
- """
- return self.distributed_type != DistributedType.NO and self.num_processes > 1
- @property
- def is_last_process(self) -> bool:
- "Returns whether the current process is the last one"
- return self.process_index == self.num_processes - 1
- @property
- def is_main_process(self) -> bool:
- "Returns whether the current process is the main process"
- return (
- self.process_index == 0 if self.distributed_type != DistributedType.MEGATRON_LM else self.is_last_process
- )
- @property
- def is_local_main_process(self) -> bool:
- "Returns whether the current process is the main process on the local node"
- return (
- self.local_process_index == 0
- if self.distributed_type != DistributedType.MEGATRON_LM
- else self.is_last_process
- )
- def wait_for_everyone(self):
- """
- Will stop the execution of the current process until every other process has reached that point (so this does
- nothing when the script is only run in one process). Useful to do before saving a model.
- Example:
- ```python
- >>> # Assuming two GPU processes
- >>> import time
- >>> from accelerate.state import PartialState
- >>> state = PartialState()
- >>> if state.is_main_process:
- ... time.sleep(2)
- >>> else:
- ... print("I'm waiting for the main process to finish its sleep...")
- >>> state.wait_for_everyone()
- >>> # Should print on every process at the same time
- >>> print("Everyone is here")
- ```
- """
- if self.distributed_type in (
- DistributedType.MULTI_GPU,
- DistributedType.MULTI_MLU,
- DistributedType.MULTI_SDAA,
- DistributedType.MULTI_MUSA,
- DistributedType.MULTI_NPU,
- DistributedType.MULTI_XPU,
- DistributedType.MULTI_CPU,
- DistributedType.MULTI_HPU,
- DistributedType.DEEPSPEED,
- DistributedType.FSDP,
- ):
- torch.distributed.barrier(device_ids=[self.local_process_index])
- elif self.distributed_type == DistributedType.XLA:
- xm.rendezvous("accelerate.utils.wait_for_everyone")
- def _goes_first(self, is_main: bool):
- if not is_main:
- self.wait_for_everyone()
- yield
- if is_main:
- self.wait_for_everyone()
- @contextmanager
- def split_between_processes(self, inputs: list | tuple | dict | torch.Tensor, apply_padding: bool = False):
- """
- Splits `input` between `self.num_processes` quickly and can be then used on that process. Useful when doing
- distributed inference, such as with different prompts.
- Note that when using a `dict`, all keys need to have the same number of elements.
- Args:
- inputs (`list`, `tuple`, `torch.Tensor`, `dict` of `list`/`tuple`/`torch.Tensor`, or `datasets.Dataset`):
- The input to split between processes.
- apply_padding (`bool`, `optional`, defaults to `False`):
- Whether to apply padding by repeating the last element of the input so that all processes have the same
- number of elements. Useful when trying to perform actions such as `gather()` on the outputs or passing
- in less inputs than there are processes. If so, just remember to drop the padded elements afterwards.
- Example:
- ```python
- # Assume there are two processes
- from accelerate import PartialState
- state = PartialState()
- with state.split_between_processes(["A", "B", "C"]) as inputs:
- print(inputs)
- # Process 0
- ["A", "B"]
- # Process 1
- ["C"]
- with state.split_between_processes(["A", "B", "C"], apply_padding=True) as inputs:
- print(inputs)
- # Process 0
- ["A", "B"]
- # Process 1
- ["C", "C"]
- ```
- """
- if self.num_processes == 1:
- yield inputs
- return
- length = len(inputs)
- # Nested dictionary of any types
- if isinstance(inputs, dict):
- length = len(inputs[list(inputs.keys())[0]])
- if not all(len(v) == length for v in inputs.values()):
- raise ValueError("All values in the dictionary must have the same length")
- num_samples_per_process, num_extras = divmod(length, self.num_processes)
- start_index = self.process_index * num_samples_per_process + min(self.process_index, num_extras)
- end_index = start_index + num_samples_per_process + (1 if self.process_index < num_extras else 0)
- def _split_values(inputs, start_index, end_index):
- if isinstance(inputs, (list, tuple, torch.Tensor)):
- if start_index >= len(inputs):
- result = inputs[-1:]
- else:
- result = inputs[start_index:end_index]
- if apply_padding:
- if isinstance(result, torch.Tensor):
- from accelerate.utils import pad_across_processes, send_to_device
- # The tensor needs to be on the device before we can pad it
- tensorized_result = send_to_device(result, self.device)
- result = pad_across_processes(tensorized_result, pad_index=inputs[-1])
- else:
- result += [result[-1]] * (num_samples_per_process + (1 if num_extras > 0 else 0) - len(result))
- return result
- elif isinstance(inputs, dict):
- for key in inputs.keys():
- inputs[key] = _split_values(inputs[key], start_index, end_index)
- return inputs
- else:
- if is_datasets_available():
- from datasets import Dataset
- if isinstance(inputs, Dataset):
- if start_index >= len(inputs):
- start_index = len(inputs) - 1
- if end_index > len(inputs):
- end_index = len(inputs)
- result_idcs = list(range(start_index, end_index))
- if apply_padding:
- result_idcs += [end_index - 1] * (
- num_samples_per_process + (1 if num_extras > 0 else 0) - len(result_idcs)
- )
- return inputs.select(result_idcs)
- return inputs
- yield _split_values(inputs, start_index, end_index)
- @contextmanager
- def main_process_first(self):
- """
- Lets the main process go first inside a with block.
- The other processes will enter the with block after the main process exits.
- Example:
- ```python
- >>> from accelerate import Accelerator
- >>> accelerator = Accelerator()
- >>> with accelerator.main_process_first():
- ... # This will be printed first by process 0 then in a seemingly
- ... # random order by the other processes.
- ... print(f"This will be printed by process {accelerator.process_index}")
- ```
- """
- yield from self._goes_first(self.is_main_process)
- @contextmanager
- def local_main_process_first(self):
- """
- Lets the local main process go inside a with block.
- The other processes will enter the with block after the main process exits.
- Example:
- ```python
- >>> from accelerate.state import PartialState
- >>> state = PartialState()
- >>> with state.local_main_process_first():
- ... # This will be printed first by local process 0 then in a seemingly
- ... # random order by the other processes.
- ... print(f"This will be printed by process {state.local_process_index}")
- ```
- """
- yield from self._goes_first(self.is_local_main_process)
- def on_main_process(self, function: Callable[..., Any] | None = None):
- """
- Decorator that only runs the decorated function on the main process.
- Args:
- function (`Callable`): The function to decorate.
- Example:
- ```python
- >>> from accelerate.state import PartialState
- >>> state = PartialState()
- >>> @state.on_main_process
- ... def print_something():
- ... print("This will be printed by process 0 only.")
- >>> print_something()
- "This will be printed by process 0 only"
- ```
- """
- if not self.initialized:
- raise ValueError("The `PartialState` or `Accelerator` must be initialized before calling this function.")
- if self.is_main_process or not self.use_distributed:
- return function
- return do_nothing
- def on_local_main_process(self, function: Callable[..., Any] | None = None):
- """
- Decorator that only runs the decorated function on the local main process.
- Args:
- function (`Callable`): The function to decorate.
- Example:
- ```python
- # Assume we have 2 servers with 4 processes each.
- from accelerate.state import PartialState
- state = PartialState()
- @state.on_local_main_process
- def print_something():
- print("This will be printed by process 0 only on each server.")
- print_something()
- # On server 1:
- "This will be printed by process 0 only"
- # On server 2:
- "This will be printed by process 0 only"
- ```
- """
- if self.is_local_main_process or not self.use_distributed:
- return function
- return do_nothing
- def on_last_process(self, function: Callable[..., Any]):
- """
- Decorator that only runs the decorated function on the last process.
- Args:
- function (`Callable`): The function to decorate.
- Example:
- ```python
- # Assume we have 4 processes.
- from accelerate.state import PartialState
- state = PartialState()
- @state.on_last_process
- def print_something():
- print(f"Printed on process {state.process_index}")
- print_something()
- "Printed on process 3"
- ```
- """
- if self.is_last_process or not self.use_distributed:
- return function
- return do_nothing
- def on_process(self, function: Callable[..., Any] | None = None, process_index: int | None = None):
- """
- Decorator that only runs the decorated function on the process with the given index.
- Args:
- function (`Callable`, `optional`):
- The function to decorate.
- process_index (`int`, `optional`):
- The index of the process on which to run the function.
- Example:
- ```python
- # Assume we have 4 processes.
- from accelerate.state import PartialState
- state = PartialState()
- @state.on_process(process_index=2)
- def print_something():
- print(f"Printed on process {state.process_index}")
- print_something()
- "Printed on process 2"
- ```
- """
- if function is None:
- return partial(self.on_process, process_index=process_index)
- if (self.process_index == process_index) or (not self.use_distributed):
- return function
- return do_nothing
- def on_local_process(self, function: Callable[..., Any] | None = None, local_process_index: int | None = None):
- """
- Decorator that only runs the decorated function on the process with the given index on the current node.
- Args:
- function (`Callable`, *optional*):
- The function to decorate.
- local_process_index (`int`, *optional*):
- The index of the local process on which to run the function.
- Example:
- ```python
- # Assume we have 2 servers with 4 processes each.
- from accelerate import Accelerator
- accelerator = Accelerator()
- @accelerator.on_local_process(local_process_index=2)
- def print_something():
- print(f"Printed on process {accelerator.local_process_index}")
- print_something()
- # On server 1:
- "Printed on process 2"
- # On server 2:
- "Printed on process 2"
- ```
- """
- if function is None:
- return partial(self.on_local_process, local_process_index=local_process_index)
- if (self.local_process_index == local_process_index) or (not self.use_distributed):
- return function
- return do_nothing
- def print(self, *args, **kwargs):
- if self.is_local_main_process:
- print(*args, **kwargs)
- @property
- def default_device(self) -> torch.device:
- """
- Returns the default device which is:
- - MPS if `torch.backends.mps.is_available()` and `torch.backends.mps.is_built()` both return True.
- - CUDA if `torch.cuda.is_available()`
- - MLU if `is_mlu_available()`
- - SDAA if `is_sdaa_available()`
- - MUSA if `is_musa_available()`
- - NPU if `is_npu_available()`
- - HPU if `is_hpu_available()`
- - CPU otherwise
- """
- if is_mps_available():
- os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = "1"
- return torch.device("mps")
- elif is_mlu_available():
- return torch.device("mlu")
- elif is_sdaa_available():
- return torch.device("sdaa")
- elif is_musa_available():
- return torch.device("musa")
- # NPU should be checked before CUDA when using `transfer_to_npu`
- # See issue #3020: https://github.com/huggingface/accelerate/issues/3020
- elif is_npu_available():
- return torch.device("npu")
- elif is_hpu_available():
- return torch.device("hpu")
- elif torch.cuda.is_available():
- return torch.device("cuda")
- elif is_xpu_available():
- return torch.device("xpu")
- else:
- return torch.device("cpu")
- def _prepare_backend(
- self, cpu: bool = False, sagemaker_dp=False, backend: str | None = None
- ) -> tuple[str, DistributedType]:
- "Prepares any imports needed before initializing the distributed backend and sets `self.backend` properly"
- distributed_type = None
- if sagemaker_dp:
- import smdistributed.dataparallel.torch.torch_smddp # noqa
- backend = "smddp"
- distributed_type = DistributedType.MULTI_GPU
- elif is_torch_xla_available():
- backend = "xla"
- distributed_type = DistributedType.XLA
- elif int(os.environ.get("LOCAL_RANK", -1)) != -1 and not cpu:
- if is_mlu_available():
- backend = "cncl"
- distributed_type = DistributedType.MULTI_MLU
- if is_sdaa_available():
- backend = "tccl"
- distributed_type = DistributedType.MULTI_SDAA
- elif is_musa_available():
- backend = "mccl"
- distributed_type = DistributedType.MULTI_MUSA
- # NPU should be checked before CUDA when using `transfer_to_npu`
- # See issue #3020: https://github.com/huggingface/accelerate/issues/3020
- elif is_npu_available():
- backend = "hccl"
- distributed_type = DistributedType.MULTI_NPU
- elif is_hpu_available(init_hccl=True):
- if backend is None:
- backend = "hccl"
- distributed_type = DistributedType.MULTI_HPU
- elif torch.cuda.is_available():
- if backend is None:
- backend = "nccl"
- distributed_type = DistributedType.MULTI_GPU
- elif is_xpu_available() and is_xccl_available():
- if backend is None:
- backend = "xccl"
- distributed_type = DistributedType.MULTI_XPU
- if distributed_type is None and (
- int(os.environ.get("LOCAL_RANK", -1)) != -1
- or get_int_from_env(["PMI_SIZE", "OMPI_COMM_WORLD_SIZE", "MV2_COMM_WORLD_SIZE", "WORLD_SIZE"], 1) > 1
- ):
- if not cpu and is_xpu_available():
- distributed_type = DistributedType.MULTI_XPU
- else:
- distributed_type = DistributedType.MULTI_CPU
- if (
- backend in (None, "ccl")
- and is_ccl_available()
- and (get_int_from_env(["CCL_WORKER_COUNT"], 0) > 0 or distributed_type == DistributedType.MULTI_XPU)
- ):
- import oneccl_bindings_for_pytorch # noqa: F401
- backend = "ccl"
- elif backend in (None, "mpi") and torch.distributed.is_mpi_available():
- backend = "mpi"
- else:
- backend = "gloo"
- if distributed_type is None:
- distributed_type = DistributedType.NO
- return backend, distributed_type
- def set_device(self):
- """
- Sets the device in `self.device` to the current distributed environment.
- """
- if self.device is not None:
- return
- if self.distributed_type == DistributedType.NO:
- self.device = torch.device("cpu") if self._cpu else self.default_device
- return
- device = str(self.distributed_type).split(".")[-1].replace("MULTI_", "").lower()
- if device not in ("cpu", "gpu", "mlu", "musa", "npu", "xpu", "xla", "hpu", "sdaa"):
- raise ValueError(
- f"Can't set device for {self.distributed_type} ({device}), verify we should be calling `_set_device()` for it!"
- )
- if device == "xla":
- self.device = xm.xla_device()
- elif device == "hpu":
- self.device = torch.device("hpu", torch.hpu.current_device())
- else:
- if device == "gpu":
- device = "cuda"
- device_module = getattr(torch, device)
- device_index = self.local_process_index % device_module.device_count()
- self.device = torch.device(device, device_index)
- device_module.set_device(self.device)
def destroy_process_group(self, group=None):
    """
    Tear down a `torch.distributed` process group.

    Args:
        group (optional): The process group to destroy. `None` targets the default process group.

    This is a no-op when the default group is requested on a state whose `fork_launched` flag is set, and when
    `torch.distributed` reports that it has not been initialized.
    """
    if self.fork_launched and group is None:
        return
    # Only meaningful once `torch.distributed.init_process_group` has actually been called.
    if not torch.distributed.is_initialized():
        return
    torch.distributed.destroy_process_group(group)
def __getattr__(self, name: str):
    # Python only falls back to `__getattr__` after normal attribute lookup has failed, so `name` is
    # definitely missing; the only decision left is which error message is most helpful.
    expected_after_init = name in self._known_attrs
    if not expected_after_init:
        # Raise a typical AttributeError
        raise AttributeError(f"'PartialState' object has no attribute '{name}'")
    # The attribute normally exists, so the shared state was most likely wiped by a reset.
    raise AttributeError(
        f"`PartialState` object has no attribute `{name}`. "
        "This happens if `PartialState._reset_state()` was called and "
        "an `Accelerator` or `PartialState` was not reinitialized."
    )
class AcceleratorState:
    """
    Singleton class that has information about the current training environment.

    **Available attributes:**

        - **device** (`torch.device`) -- The device to use.
        - **distributed_type** ([`~accelerate.state.DistributedType`]) -- The type of distributed environment currently
          in use.
        - **parallelism_config** ([`~accelerate.utils.ParallelismConfig`]) -- The parallelism configuration for the
          current training environment. This is used to configure the distributed training environment.
        - **initialized** (`bool`) -- Whether or not the `AcceleratorState` has been initialized from `Accelerator`.
        - **local_process_index** (`int`) -- The index of the current process on the current server.
        - **mixed_precision** (`str`) -- Whether or not the current script will use mixed precision, and if so the type
          of mixed precision being performed. (Choose from 'no', 'fp16', 'bf16' or 'fp8').
        - **num_processes** (`int`) -- The number of processes currently launched in parallel.
        - **process_index** (`int`) -- The index of the current process.
        - **is_last_process** (`bool`) -- Whether or not the current process is the last one.
        - **is_main_process** (`bool`) -- Whether or not the current process is the main one.
        - **is_local_main_process** (`bool`) -- Whether or not the current process is the main one on the local node.
        - **debug** (`bool`) -- Whether or not the current script is being run in debug mode.
    """

    _shared_state = SharedDict()
    # Names expected on an initialized state. Accessing one of them after `_reset_state()` yields the targeted
    # error message from `__getattr__` instead of a generic AttributeError.
    _known_attrs = PartialState._known_attrs + [
        "deepspeed_plugin",
        "deepspeed_plugins",
        "use_ipex",
        "fsdp_plugin",
        "megatron_lm_plugin",
        "dynamo_plugin",
        "torch_tp_plugin",
        "parallelism_config",
        "device_mesh",
        "_mixed_precision",
    ]

    def __init__(
        self,
        mixed_precision: str | None = None,
        cpu: bool = False,
        dynamo_plugin=None,
        deepspeed_plugin=None,
        fsdp_plugin=None,
        torch_tp_plugin=None,
        megatron_lm_plugin=None,
        parallelism_config=None,
        _from_accelerator: bool = False,
        **kwargs,
    ):
        """
        Attach to the shared accelerator state and, on first construction, configure it.

        Args:
            mixed_precision (`str`, *optional*):
                Requested mixed precision. Lower-cased when given; when `None` it is read from the
                `ACCELERATE_MIXED_PRECISION` environment variable (default `"no"`).
            cpu (`bool`, defaults to `False`):
                Force CPU execution. Also forced when the `ACCELERATE_USE_CPU` flag is set in the environment.
            dynamo_plugin (*optional*):
                Stored as `self.dynamo_plugin`; its `backend` decides whether TF32 matmuls are enabled.
            deepspeed_plugin (*optional*):
                A single DeepSpeed plugin or a dict of named plugins, used when `ACCELERATE_USE_DEEPSPEED` is
                `"true"` and `cpu` is not set. For a dict, the first entry becomes the active plugin.
            fsdp_plugin (*optional*):
                FSDP plugin; its presence (or `ACCELERATE_USE_FSDP`) switches a multi-device setup to FSDP.
            torch_tp_plugin (*optional*):
                Stored as `self.torch_tp_plugin`.
            megatron_lm_plugin (*optional*):
                Used when `ACCELERATE_USE_MEGATRON_LM` is `"true"` on a supported multi-device setup.
            parallelism_config (*optional*):
                Stored as `self.parallelism_config`; its `cp_enabled` flag participates in FSDP selection.
            _from_accelerator (`bool`, defaults to `False`):
                Must be `True` for the first initialization; only `Accelerator` is expected to pass it.
            **kwargs:
                Forwarded to `PartialState` when it has not been created yet.

        Raises:
            ValueError: If fp8 is requested without a supporting library, if the first initialization does not come
                from `Accelerator`, if context parallelism is requested without a usable FSDP2 plugin, or (via
                `_check_initialized`) if an already-initialized state is asked for a conflicting `cpu` /
                `mixed_precision`.
        """
        self.__dict__ = self._shared_state
        if parse_flag_from_env("ACCELERATE_USE_CPU"):
            cpu = True
        if PartialState._shared_state == {}:
            PartialState(cpu, **kwargs)
        # Mirror the process-level state; `initialized` stays False until this class adds attributes of its own.
        self.__dict__.update(PartialState._shared_state)
        self._check_initialized(mixed_precision, cpu)
        if not self.initialized:
            self.deepspeed_plugins = None
            self.use_ipex = None
            self.torch_tp_plugin = torch_tp_plugin
            self.parallelism_config = parallelism_config
            self.device_mesh = None
            mixed_precision = (
                parse_choice_from_env("ACCELERATE_MIXED_PRECISION", "no")
                if mixed_precision is None
                else mixed_precision.lower()
            )
            if mixed_precision == "fp8":
                # this is confusing, why is is_fp8_available only checks for library availability ?
                if not is_fp8_available():
                    raise ValueError(
                        "Using `fp8` precision requires `transformer_engine` or `MS-AMP` to be installed."
                    )
                elif torch.cuda.is_available() and not check_cuda_fp8_capability():
                    logger.warning(
                        f"The current device has compute capability of {torch.cuda.get_device_capability()} which is "
                        "insufficient for FP8 mixed precision training (requires a GPU Hopper/Ada Lovelace "
                        "or higher, compute capability of 8.9 or higher). Will use FP16 instead."
                    )
                    mixed_precision = "fp16"
                elif is_habana_gaudi1():
                    logger.warning(
                        "The current HPU device is Gaudi1 which does not support FP8 mixed precision training (requires "
                        "Gaudi2 or higher). Will use BF16 instead."
                    )
                    mixed_precision = "bf16"

            self.dynamo_plugin = dynamo_plugin
            if not _from_accelerator:
                raise ValueError(
                    "Please make sure to properly initialize your accelerator via `accelerator = Accelerator()` "
                    "before using any functionality from the `accelerate` library."
                )
            # deepspeed handles mixed_precision using deepspeed_config. But we need to set it to fp8
            # if we're using fp8.
            if self.distributed_type == DistributedType.DEEPSPEED and mixed_precision != "fp8":
                self._mixed_precision = "no"
            else:
                self._mixed_precision = mixed_precision

            if self.distributed_type == DistributedType.XLA and is_torch_xla_available(check_is_tpu=True):
                if mixed_precision == "bf16":
                    # Translate the accelerate-level choice into the env vars the XLA runtime reads.
                    if os.environ.get("ACCELERATE_DOWNCAST_BF16"):
                        os.environ["XLA_USE_BF16"] = str(0)
                        os.environ["XLA_DOWNCAST_BF16"] = str(1)
                        self.downcast_bfloat = True
                    else:
                        os.environ["XLA_USE_BF16"] = str(1)
                        os.environ["XLA_DOWNCAST_BF16"] = str(0)
                        self.downcast_bfloat = False
            elif os.environ.get("ACCELERATE_USE_DEEPSPEED", "false").lower() == "true" and not cpu:
                self.distributed_type = DistributedType.DEEPSPEED
                if not isinstance(deepspeed_plugin, dict):
                    deepspeed_plugin.set_mixed_precision(mixed_precision)
                    deepspeed_plugin.select(_from_accelerator_state=True)
                else:
                    for plugin in deepspeed_plugin.values():
                        plugin.set_mixed_precision(mixed_precision)
                    # The first plugin passed in is always the active one
                    first_plugin = next(iter(deepspeed_plugin.values()))
                    first_plugin.select(_from_accelerator_state=True)
                self.deepspeed_plugins = deepspeed_plugin
            elif self.distributed_type in [
                DistributedType.MULTI_GPU,
                DistributedType.MULTI_MLU,
                DistributedType.MULTI_SDAA,
                DistributedType.MULTI_MUSA,
                DistributedType.MULTI_NPU,
                DistributedType.MULTI_XPU,
                DistributedType.MULTI_HPU,
            ]:
                # TODO: Siro - remove when axolotl fixes their side
                if not os.environ.get("ACCELERATE_ALLOW_CP_STANDALONE", "false").lower() == "true":
                    if self.parallelism_config and self.parallelism_config.cp_enabled and fsdp_plugin is None:
                        raise ValueError(
                            "`cp_size > 1` specified in the `parallelism_config`, but no `fsdp_plugin` was provided. We need a `fsdp_plugin` to use context parallelism with `cp_backend=torch`, as we also shard the model across the device mesh to save more memory"
                        )
                    # The check above guarantees `fsdp_plugin` is not None whenever context parallelism is enabled.
                    if (
                        self.parallelism_config is not None
                        and self.parallelism_config.cp_enabled
                        and fsdp_plugin.fsdp_version == 1
                    ):
                        raise ValueError(
                            "Using `cp_size>1` requires FSDP2, but the provided `fsdp_plugin` is using FSDP1. "
                        )
                if (os.environ.get("ACCELERATE_USE_FSDP", "false").lower() == "true" or fsdp_plugin is not None) or (
                    self.parallelism_config is not None and self.parallelism_config.cp_enabled
                ):
                    self.distributed_type = DistributedType.FSDP
                    if self._mixed_precision != "no" and fsdp_plugin is not None:
                        fsdp_plugin.set_mixed_precision(self._mixed_precision)
                    self.fsdp_plugin = fsdp_plugin
                if os.environ.get(
                    "ACCELERATE_USE_MEGATRON_LM", "false"
                ).lower() == "true" and self.distributed_type not in [
                    DistributedType.MULTI_XPU,
                ]:
                    self.distributed_type = DistributedType.MEGATRON_LM
                    megatron_lm_plugin.set_mixed_precision(self._mixed_precision)
                    self.megatron_lm_plugin = megatron_lm_plugin
            elif self.distributed_type in [DistributedType.MULTI_CPU, DistributedType.MULTI_XPU, DistributedType.NO]:
                if is_ipex_available():
                    # check if user disables it explicitly
                    self.use_ipex = parse_flag_from_env("ACCELERATE_USE_IPEX", default=True)
                else:
                    self.use_ipex = False

            # With a dynamo backend and no mixed precision, allow TF32 matmuls on CUDA / MUSA devices.
            # `dynamo_plugin` defaults to None, so guard before reading its backend.
            if (
                self.dynamo_plugin is not None
                and self.dynamo_plugin.backend != DynamoBackend.NO
                and self._mixed_precision == "no"
            ):
                if self.device.type == "cuda":
                    torch.backends.cuda.matmul.allow_tf32 = True
                elif self.device.type == "musa":
                    torch.backends.musa.matmul.allow_tf32 = True
            # Propagate any distributed-type upgrade (DeepSpeed / FSDP / Megatron-LM) to the process-level state.
            PartialState._shared_state["distributed_type"] = self.distributed_type

    @property
    def initialized(self) -> bool:
        "Returns whether this state holds anything beyond the mirrored `PartialState` data"
        return self._shared_state != PartialState._shared_state

    def __repr__(self):
        text = PartialState().__repr__() + f"\nMixed precision type: {self.mixed_precision}\n"
        if self.distributed_type == DistributedType.DEEPSPEED:
            text += f"ds_config: {self.deepspeed_plugin.deepspeed_config}\n"
        return text

    def _check_initialized(self, mixed_precision=None, cpu=None):
        "Checks if a modification is trying to be made and the `AcceleratorState` has already been initialized"
        if self.initialized:
            err = "AcceleratorState has already been initialized and cannot be changed, restart your runtime completely and pass `{flag}` to `Accelerator()`."
            if cpu and self.device.type != "cpu":
                raise ValueError(err.format(flag="cpu=True"))
            # Under DeepSpeed the stored value is "no" (the DeepSpeed config decides), so a mismatch is expected.
            if (
                mixed_precision is not None
                and mixed_precision != self._mixed_precision
                and self.distributed_type != DistributedType.DEEPSPEED
            ):
                raise ValueError(err.format(flag=f"mixed_precision='{mixed_precision}'"))

    @property
    def mixed_precision(self):
        """
        Returns the effective mixed precision.

        Under DeepSpeed (unless fp8 was requested) this is read from the active plugin's `deepspeed_config`:
        `"fp16"` or `"bf16"` when the corresponding section is enabled, `"no"` otherwise.
        """
        if self.distributed_type == DistributedType.DEEPSPEED and self._mixed_precision != "fp8":
            config = self.deepspeed_plugin.deepspeed_config
            if config.get("fp16", {}).get("enabled", False):
                mixed_precision = "fp16"
            elif config.get("bf16", {}).get("enabled", False):
                mixed_precision = "bf16"
            else:
                mixed_precision = "no"
        else:
            mixed_precision = self._mixed_precision
        return mixed_precision

    @staticmethod
    def _reset_state(reset_partial_state: bool = False):
        "Resets `_shared_state`, is used internally and should not be called"
        AcceleratorState._shared_state.clear()
        if reset_partial_state:
            PartialState._reset_state()

    def destroy_process_group(self, group=None):
        """
        Destroys the process group. If one is not specified, the default process group is destroyed.

        If `self.fork_launched` is `True` and `group` is `None`, nothing happens.
        """
        PartialState().destroy_process_group(group)

    @property
    def fork_launched(self):
        "Returns `PartialState().fork_launched`"
        return PartialState().fork_launched

    @property
    def use_distributed(self):
        """
        Whether the Accelerator is configured for distributed training
        """
        return PartialState().use_distributed

    @property
    def is_fsdp2(self) -> bool:
        "Returns whether FSDP is in use with an `fsdp_plugin` whose `fsdp_version` is 2"
        return self.distributed_type == DistributedType.FSDP and self.fsdp_plugin.fsdp_version == 2

    @property
    def is_last_process(self) -> bool:
        "Returns whether the current process is the last one"
        return PartialState().is_last_process

    @property
    def is_main_process(self) -> bool:
        "Returns whether the current process is the main process"
        return PartialState().is_main_process

    @property
    def is_local_main_process(self) -> bool:
        "Returns whether the current process is the main process on the local node"
        return PartialState().is_local_main_process

    def wait_for_everyone(self):
        "Delegates to `PartialState().wait_for_everyone()`"
        PartialState().wait_for_everyone()

    @contextmanager
    def split_between_processes(self, inputs: list | tuple | dict | torch.Tensor, apply_padding: bool = False):
        """
        Splits `inputs` between `self.num_processes` quickly and can be then used on that process. Useful when doing
        distributed inference, such as with different prompts.

        Note that when using a `dict`, all keys need to have the same number of elements.

        Args:
            inputs (`list`, `tuple`, `torch.Tensor`, or `dict` of `list`/`tuple`/`torch.Tensor`):
                The input to split between processes.
            apply_padding (`bool`, `optional`, defaults to `False`):
                Whether to apply padding by repeating the last element of the input so that all processes have the same
                number of elements. Useful when trying to perform actions such as `gather()` on the outputs or passing
                in less inputs than there are processes. If so, just remember to drop the padded elements afterwards.

        Example:

        ```python
        # Assume there are two processes
        from accelerate.state import AcceleratorState

        state = AcceleratorState()
        with state.split_between_processes(["A", "B", "C"]) as inputs:
            print(inputs)
        # Process 0
        ["A", "B"]
        # Process 1
        ["C"]

        with state.split_between_processes(["A", "B", "C"], apply_padding=True) as inputs:
            print(inputs)
        # Process 0
        ["A", "B"]
        # Process 1
        ["C", "C"]
        ```
        """
        with PartialState().split_between_processes(inputs, apply_padding=apply_padding) as inputs:
            yield inputs

    @contextmanager
    def main_process_first(self):
        """
        Lets the main process go first inside a with block.

        The other processes will enter the with block after the main process exits.
        """
        with PartialState().main_process_first():
            yield

    @contextmanager
    def local_main_process_first(self):
        """
        Lets the local main process go first inside a with block.

        The other processes will enter the with block after the local main process exits.
        """
        with PartialState().local_main_process_first():
            yield

    @property
    def deepspeed_plugin(self):
        """
        Returns the currently active DeepSpeedPlugin.

        If not using deepspeed, returns `None`.
        """
        # To maintain original behavior, return None if not using deepspeed.
        if self.distributed_type != DistributedType.DEEPSPEED:
            return None
        from accelerate.utils.deepspeed import get_active_deepspeed_plugin

        return get_active_deepspeed_plugin(self)

    @deepspeed_required
    def get_deepspeed_plugin(self, name: str):
        """
        Returns the DeepSpeedPlugin with the given plugin_key.
        """
        return self.deepspeed_plugins[name]

    @deepspeed_required
    def select_deepspeed_plugin(self, name: str | None = None):
        """
        Activates the DeepSpeedPlugin with the given `name`, and will disable all other plugins.

        Raises:
            KeyError: If no plugin is registered under `name`. The current selection is left untouched in that case.
        """
        plugins = self.deepspeed_plugins
        # Validate before unselecting anything, so that a bad name cannot leave every plugin deselected.
        if name not in plugins.keys():
            raise KeyError(f"No DeepSpeedPlugin named {name!r}; available plugins: {list(plugins.keys())}")
        for key, plugin in plugins.items():
            if key != name:
                plugin._unselect()
        plugins[name].select(_from_accelerator_state=True)

    def print(self, *args, **kwargs):
        "Delegates to `PartialState().print(...)`"
        PartialState().print(*args, **kwargs)

    def __getattr__(self, name: str):
        # By this point we know that no attributes of `self` contain `name`,
        # so we just modify the error message
        if name in self._known_attrs:
            raise AttributeError(
                f"`AcceleratorState` object has no attribute `{name}`. "
                "This happens if `AcceleratorState._reset_state()` was called and "
                "an `Accelerator` or `PartialState` was not reinitialized."
            )
        # Raise a typical AttributeError
        raise AttributeError(f"'AcceleratorState' object has no attribute '{name}'")
class GradientState:
    """
    Singleton class that has information related to gradient synchronization for gradient accumulation

    **Available attributes:**

        - **end_of_dataloader** (`bool`) -- Whether we have reached the end of the current dataloader
        - **remainder** (`int`) -- The number of extra samples that were added from padding the dataloader
        - **sync_gradients** (`bool`) -- Whether the gradients should be synced across all devices
        - **active_dataloader** (`Optional[DataLoader]`) -- The dataloader that is currently being iterated over
        - **dataloader_references** (`List[Optional[DataLoader]]`) -- A list of references to the dataloaders that are
            being iterated over
        - **num_steps** (`int`) -- The number of steps to accumulate over
        - **adjust_scheduler** (`bool`) -- Whether the scheduler should be adjusted to account for the gradient
            accumulation
        - **sync_with_dataloader** (`bool`) -- Whether the gradients should be synced at the end of the dataloader
            iteration and the number of total steps reset
        - **is_xla_gradients_synced** (`bool`) -- Whether the XLA gradients have been synchronized. It is initialized
            as false. Once gradients have been reduced before the optimizer step, this flag is set to true. Subsequently,
            after each step, the flag is reset to false. FSDP will always synchronize the gradients, hence
            is_xla_gradients_synced is always true.
    """

    _shared_state = SharedDict()

    def __init__(self, gradient_accumulation_plugin: GradientAccumulationPlugin | None = None):
        """
        Attach to the shared gradient state, initializing it on first construction.

        Args:
            gradient_accumulation_plugin (*optional*):
                Plugin whose `to_kwargs()` provides `num_steps`, `adjust_scheduler` and `sync_with_dataloader`.
                When the state already exists and the plugin's kwargs differ from the stored ones, the stored
                kwargs are replaced.
        """
        self.__dict__ = self._shared_state
        if not self.initialized:
            self.sync_gradients = True
            # Weak references to the dataloaders (see `dataloader_references`); the leading `None` stands for
            # "not inside any dataloader".
            self._dataloader_references_ref = [None]
            self.plugin_kwargs = (
                gradient_accumulation_plugin.to_kwargs() if gradient_accumulation_plugin is not None else {}
            )
            self._is_xla_gradients_synced = False

        # Plugin args are different and can be updated
        if gradient_accumulation_plugin is not None and self.plugin_kwargs != gradient_accumulation_plugin.to_kwargs():
            self.plugin_kwargs = gradient_accumulation_plugin.to_kwargs()

    @property
    def num_steps(self) -> int:
        "Returns the number of steps to accumulate over"
        return self.plugin_kwargs.get("num_steps", 1)

    @property
    def adjust_scheduler(self) -> bool:
        "Returns whether the scheduler should be adjusted"
        return self.plugin_kwargs.get("adjust_scheduler", False)

    @property
    def sync_with_dataloader(self) -> bool:
        "Returns whether the gradients should be synced at the end of the dataloader iteration and the number of total steps reset"
        return self.plugin_kwargs.get("sync_with_dataloader", True)

    @property
    def initialized(self) -> bool:
        "Returns whether the `GradientState` has been initialized"
        return GradientState._shared_state != {}

    @property
    def end_of_dataloader(self) -> bool:
        "Returns whether we have reached the end of the current dataloader (always `False` outside a dataloader)"
        if not self.in_dataloader:
            return False
        return self.active_dataloader.end_of_dataloader

    @property
    def remainder(self) -> int:
        "Returns the number of extra samples that were added from padding the dataloader, or -1 outside a dataloader"
        if not self.in_dataloader:
            return -1
        return self.active_dataloader.remainder

    def __repr__(self):
        return (
            f"Sync Gradients: {self.sync_gradients}\n"
            f"At end of current dataloader: {self.end_of_dataloader}\n"
            f"Extra samples added: {self.remainder}\n"
            f"Gradient accumulation plugin: {self.plugin_kwargs}\n"
        )

    @property
    def is_xla_gradients_synced(self):
        "Returns the value of is_xla_gradients_synced. FSDP will always synchronize the gradients, hence is_xla_gradients_synced is always true."
        if parse_flag_from_env("ACCELERATE_USE_FSDP", default=False):
            return True
        return self._is_xla_gradients_synced

    @is_xla_gradients_synced.setter
    def is_xla_gradients_synced(self, is_synced):
        "Set the _is_xla_gradients_synced attribute."
        self._is_xla_gradients_synced = is_synced

    def _set_sync_gradients(self, sync_gradients):
        "Private function that sets whether gradients should be synchronized. Users should not have to call this."
        self.sync_gradients = sync_gradients
        # Allow grad-sync to automatically work on TPUs
        if (
            self.sync_gradients
            and is_torch_xla_available(check_is_tpu=True)
            and PartialState().distributed_type == DistributedType.XLA
        ):
            xm.mark_step()

    def _add_dataloader(self, dataloader):
        "Private function that pushes a dataloader onto `self.dataloader_references`, making it the active one. Users should not have to call this."
        # We explicitly use assignment to ensure that the property setter is triggered, which is required for garbage collection.
        # Avoid using self.dataloader_references.append as it will not trigger the setter.
        self.dataloader_references += [dataloader]

    def _remove_dataloader(self, dataloader):
        "Private function that removes a dataloader from `self.dataloader_references`; `in_dataloader` becomes `False` once only `None` entries remain. Users should not have to call this."
        # References are tracked by identity (weakrefs), so filter by identity rather than `__eq__`.
        # We explicitly use assignment to ensure that the property setter is triggered.
        self.dataloader_references = [
            dataloader_ref for dataloader_ref in self.dataloader_references if dataloader_ref is not dataloader
        ]

    @property
    def active_dataloader(self):
        "Returns the most recently added dataloader that is still referenced, or `None`"
        return self.dataloader_references[-1]

    @property
    def dataloader_references(self):
        # We use a property getter and setter with weakrefs to avoid circular references that prevent garbage collection
        # A dead weakref dereferences to None, so collected dataloaders show up as None entries.
        return [reference() if reference is not None else reference for reference in self._dataloader_references_ref]

    @dataloader_references.setter
    def dataloader_references(self, references):
        self._dataloader_references_ref = [
            weakref.ref(dataloader) if dataloader is not None else dataloader for dataloader in references
        ]

    @property
    def in_dataloader(self) -> bool:
        "Returns whether the current process is in a dataloader"
        return self.active_dataloader is not None

    @staticmethod
    def _reset_state():
        "Resets `_shared_state`, is used internally and should not be called"
        GradientState._shared_state.clear()
|