| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378 |
- #
- # Copyright 2025 The HuggingFace Inc. team. All rights reserved.
- # Licensed under the Apache License, Version 2.0 (the "License");
- # you may not use this file except in compliance with the License.
- # You may obtain a copy of the License at
- #
- # http://www.apache.org/licenses/LICENSE-2.0
- #
- # Unless required by applicable law or agreed to in writing, software
- # distributed under the License is distributed on an "AS IS" BASIS,
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- # See the License for the specific language governing permissions and
- # limitations under the License.
- import os
- import warnings
- from dataclasses import dataclass
- from typing import TYPE_CHECKING, Literal, Optional, Union
- from accelerate.utils.dataclasses import (
- DeepSpeedSequenceParallelConfig,
- DistributedType,
- TorchContextParallelConfig,
- TorchTensorParallelConfig,
- )
- from accelerate.utils.versions import is_torch_version
- if TYPE_CHECKING:
- from accelerate import Accelerator
@dataclass
class ParallelismConfig:
    """
    A dataclass to configure parallelisms applied to the model. Inspired by torchtitan's `ParallelDims`
    https://github.com/pytorch/torchtitan/blob/main/torchtitan/distributed/parallel_dims.py

    Every size/backend left as `None` is resolved in `__post_init__` from the matching `PARALLELISM_CONFIG_*`
    environment variable (e.g. `PARALLELISM_CONFIG_TP_SIZE`), falling back to the default listed below.

    Args:
        dp_replicate_size (`int`, defaults to `1`):
            The size of the data parallel group. If `dp_replicate_size` is set to 1, the data parallel replication
            group will not be used.
        dp_shard_size (`int`, defaults to `1`):
            The size of the model shard group. If `dp_replicate_size > 1` and `tp_size > 1`, `dp_shard_size` must also
            be greater than 1, as composing DDP + TP is currently not supported.
        tp_size (`int`, defaults to `1`):
            The size of the tensor parallel group. If `tp_size` is set to `1`, the tensor parallel group will not be
            used.
        tp_handler (`~utils.TorchTensorParallelConfig`, defaults to `None`):
            The handler for the tensor parallel group. A default `TorchTensorParallelConfig` is created when
            `tp_size > 1` and no handler is given.
        cp_size (`int`, defaults to `1`):
            The size of the context parallel group. If `cp_size` is set to `1`, the context parallel group will not
            be used.
        cp_backend (`str`, defaults to `torch`):
            Which CP backend to use: `torch` (FSDP2)
        cp_handler (`~utils.TorchContextParallelConfig`, defaults to `None`):
            The handler for the context parallel group. A default `TorchContextParallelConfig` is created when
            `cp_size > 1` and no handler is given; a user-supplied handler must match `cp_backend`.
        sp_size (`int`, defaults to `1`):
            The size of the sequence parallel group.
        sp_backend (`str`, defaults to `deepspeed`):
            Which SP backend to use: `deepspeed` (ALST/Ulysses)
        sp_handler (`~utils.DeepSpeedSequenceParallelConfig`, defaults to `None`):
            The handler for the sequence parallel group. A default `DeepSpeedSequenceParallelConfig` is created when
            `sp_size > 1` and no handler is given.

    You may obtain different distributed data parallel paradigms by configuring `dp_replicate_size` and `dp_shard_size`
    together:
        - `dp_replicate_size == 1` and `dp_shard_size > 1`, we obtain Fully Sharded Data Parallel (FSDP).
        - `dp_replicate_size > 1` and `dp_shard_size > 1`, we obtain Hybrid Sharded Data Parallel (HSDP).
        - `dp_replicate_size > 1` and `dp_shard_size == 1` is pure data parallel replication; it cannot be combined
          with `tp_size > 1` or `cp_size > 1`. To configure pure DP, use `DistributedDataParallelKwargs` instead.
    """

    dp_replicate_size: Optional[int] = None
    dp_shard_size: Optional[int] = None
    tp_size: Optional[int] = None
    cp_size: Optional[int] = None
    cp_backend: Optional[Literal["torch"]] = None
    sp_size: Optional[int] = None
    sp_backend: Optional[Literal["deepspeed"]] = None

    # we use Union because we might support other x parallel plugins (i.e. deepspeed, etc)
    tp_handler: Union[None, TorchTensorParallelConfig] = None
    cp_handler: Union[None, TorchContextParallelConfig] = None
    sp_handler: Union[None, DeepSpeedSequenceParallelConfig] = None

    # Intentionally unannotated: this is a plain class attribute, not a dataclass field, so it is not an
    # `__init__` parameter. `get_device_mesh` stores the built mesh on the instance.
    device_mesh = None

    def __repr__(self):
        """Return a multi-line summary of all sizes, backends, the total size and the handlers."""
        return (
            "ParallelismConfig(\n "
            f"\tdp_replicate_size={self.dp_replicate_size},\n"
            f"\tdp_shard_size={self.dp_shard_size},\n"
            f"\ttp_size={self.tp_size},\n"
            f"\tcp_size={self.cp_size},\n"
            f"\tcp_backend={self.cp_backend},\n"
            f"\tsp_size={self.sp_size},\n"
            f"\tsp_backend={self.sp_backend},\n"
            f"\ttotal_size={self.total_size},\n"
            f"\ttp_handler={self.tp_handler},\n"
            f"\tcp_handler={self.cp_handler},\n"
            f"\tsp_handler={self.sp_handler})\n"
        )

    def to_json(self):
        """Return a deep-copied, plain-dict view of this config.

        Instance attributes that expose a `__dict__` (e.g. the handler configs) are replaced by a copy of that
        `__dict__`; `device_mesh` is skipped because it is not serializable.

        Returns:
            `dict`: A copy of the instance attributes that shares no mutable state with this config.
        """
        import copy

        _non_serializable_fields = ["device_mesh"]

        return copy.deepcopy(
            {
                k: copy.deepcopy(v.__dict__) if hasattr(v, "__dict__") else v
                for k, v in self.__dict__.items()
                if k not in _non_serializable_fields
            }
        )

    @property
    def dp_dim_names(self):
        """Names of enabled dimensions across which data parallelism is applied."""
        dims = []
        if self.dp_replicate_enabled:
            dims += ["dp_replicate"]
        if self.dp_shard_enabled:
            dims += ["dp_shard"]
        return dims

    @property
    def non_dp_dim_names(self):
        """Names of enabled dimensions which will receive the same batch (non-data parallel dimensions)."""
        dims = []
        if self.tp_enabled:
            dims += ["tp"]
        if self.cp_enabled:
            dims += ["cp"]
        if self.sp_enabled:
            dims += ["sp"]
        return dims

    @property
    def dp_shard_cp_dim_names(self):
        """Names of enabled dimensions which will be flattened into a joint mesh across which is model sharded in FSDP."""
        dims = []
        if self.dp_shard_enabled:
            dims += ["dp_shard"]
        if self.cp_enabled:
            dims += ["cp"]
        return dims

    @property
    def dp_cp_dim_names(self):
        """Names of enabled dimensions across which loss should be averaged"""
        dims = []
        if self.dp_replicate_enabled:
            dims += ["dp_replicate"]
        if self.dp_shard_enabled:
            dims += ["dp_shard"]
        if self.cp_enabled:
            dims += ["cp"]
        return dims

    @property
    def fsdp_dim_names(self):
        """Names of enabled dimensions across which FSDP is applied, including data parallel replication."""
        dims = []
        if self.dp_replicate_enabled:
            dims += ["dp_replicate"]
        # The joint `dp_shard_cp` name is always appended, whether or not dp_shard/cp are enabled.
        dims += ["dp_shard_cp"]
        return dims

    @property
    def total_size(self):
        """The total size of the parallelism configuration, which is the product of all sizes."""
        return self.dp_replicate_size * self.dp_shard_size * self.tp_size * self.cp_size * self.sp_size

    @property
    def non_data_parallel_size(self):
        """The size of the non-data parallel dimensions, which is the product of tensor, context and sequence parallel sizes."""
        return self.tp_size * self.cp_size * self.sp_size

    @property
    def data_parallel_size(self):
        """The size of the data parallel dimensions, which is the product of data parallel replication and shard sizes."""
        return self.dp_replicate_size * self.dp_shard_size

    @property
    def dp_replicate_enabled(self):
        """True if data parallel replication is enabled, i.e. `dp_replicate_size > 1`."""
        return self.dp_replicate_size > 1

    @property
    def dp_shard_enabled(self):
        """True if data parallel sharding is enabled, i.e. `dp_shard_size > 1`."""
        return self.dp_shard_size > 1

    @property
    def tp_enabled(self):
        """True if tensor parallelism is enabled, i.e. `tp_size > 1`."""
        return self.tp_size > 1

    @property
    def cp_enabled(self):
        """True if context parallelism is enabled, i.e. `cp_size > 1`."""
        return self.cp_size > 1

    @property
    def sp_enabled(self):
        """True if sequence parallelism is enabled, i.e. `sp_size > 1`."""
        return self.sp_size > 1

    @property
    def active_mesh_dims(self):
        """Names of all active mesh dimensions."""
        return self.dp_dim_names + self.non_dp_dim_names

    def build_device_mesh(self, device_type: str):
        """Builds a device mesh for the given device type based on the parallelism configuration.
        This method will also create required joint meshes (e.g. `dp_shard_cp`, `dp_cp`, `dp`).

        Args:
            device_type (`str`): The type of device for which to build the mesh, e.g. `cuda`.

        Returns:
            The mesh created by `torch.distributed.device_mesh.init_device_mesh`, or `None` when no parallelism
            dimension is enabled.

        Raises:
            RuntimeError: If the installed torch version is older than 2.2.0.
        """
        if is_torch_version(">=", "2.2.0"):
            from torch.distributed.device_mesh import init_device_mesh
        else:
            raise RuntimeError("Building a device_mesh requires to have torch>=2.2.0")

        mesh = self._get_mesh()
        # `_get_mesh` returns an empty tuple when every size is 1: nothing to build.
        if len(mesh) == 0:
            return None
        mesh_dim_names, mesh_shape = mesh
        device_mesh = init_device_mesh(
            device_type,
            mesh_shape,
            mesh_dim_names=mesh_dim_names,
        )
        # Register the flattened joint dimensions under their own names on the mesh.
        if self.dp_dim_names:
            device_mesh[self.dp_dim_names]._flatten("dp")
        if self.dp_shard_cp_dim_names:
            device_mesh[self.dp_shard_cp_dim_names]._flatten("dp_shard_cp")
        if self.dp_cp_dim_names:
            device_mesh[self.dp_cp_dim_names]._flatten("dp_cp")

        return device_mesh

    def get_device_mesh(self, device_type: Optional[str] = None):
        """Return the cached device mesh, building it on first use.

        Args:
            device_type (`str`, *optional*):
                Device type used to build the mesh (e.g. `cuda`). Required while no mesh has been built yet; when
                a mesh already exists it is only used to check consistency with that mesh.

        Returns:
            The device mesh stored on this config (may be `None` if `build_device_mesh` had nothing to build).

        Raises:
            ValueError: If no mesh exists yet and `device_type` is `None`, or if `device_type` differs from the
                device type of the already-built mesh.
        """
        if self.device_mesh is None:
            if device_type is None:
                # Previously this raised a bare string, which Python rejects with an unrelated TypeError.
                raise ValueError("You need to pass a device_type e.g cuda to build the device mesh")
            self.device_mesh = self.build_device_mesh(device_type)
        elif device_type is not None and self.device_mesh.device_type != device_type:
            raise ValueError(
                f"The device_mesh is already created with device type {self.device_mesh.device_type}. However, you are trying to get a device mesh with device_type {device_type}. Please check if you correctly initialized your device_mesh"
            )
        return self.device_mesh

    def _get_mesh(self) -> tuple[tuple[str, ...], tuple[int, ...]]:
        """Generate dimension names and mesh shape for torch.distributed.init_device_mesh().

        Returns:
            A `(dim_names, sizes)` pair covering the active dimensions in canonical order, or an empty tuple when
            no dimension is active.
        """
        # Build mesh dimensions dictionary
        mesh_dims = {parallelism: self._sizes[parallelism] for parallelism in self.active_mesh_dims}

        # Apply canonical ordering
        mesh_order = ["dp_replicate", "dp_shard", "cp", "sp", "tp"]
        sorted_items = sorted(
            mesh_dims.items(),
            key=lambda x: (mesh_order.index(x[0])),
        )
        # Transpose [(name, size), ...] into ((names...), (sizes...)).
        return tuple(zip(*sorted_items))

    def __post_init__(self):
        """Resolve unset values from the environment, validate them, and create default handlers.

        Raises:
            ValueError: If a size is smaller than 1, a backend is unknown, `cp_handler` does not match
                `cp_backend`, or TP/CP is combined with pure data parallel replication.
        """
        # Fill in anything the caller left as None from PARALLELISM_CONFIG_* env vars (default: 1 / backend name).
        if self.dp_replicate_size is None:
            self.dp_replicate_size = int(os.environ.get("PARALLELISM_CONFIG_DP_REPLICATE_SIZE", "1"))
        if self.dp_shard_size is None:
            self.dp_shard_size = int(os.environ.get("PARALLELISM_CONFIG_DP_SHARD_SIZE", "1"))
        if self.tp_size is None:
            self.tp_size = int(os.environ.get("PARALLELISM_CONFIG_TP_SIZE", "1"))
        if self.cp_size is None:
            self.cp_size = int(os.environ.get("PARALLELISM_CONFIG_CP_SIZE", "1"))
        if self.cp_backend is None:
            self.cp_backend = os.environ.get("PARALLELISM_CONFIG_CP_BACKEND", "torch")
        if self.sp_size is None:
            self.sp_size = int(os.environ.get("PARALLELISM_CONFIG_SP_SIZE", "1"))
        if self.sp_backend is None:
            self.sp_backend = os.environ.get("PARALLELISM_CONFIG_SP_BACKEND", "deepspeed")

        # Basic size/backend validation. This runs before any handler logic so that an unknown backend is
        # reported as a ValueError rather than a KeyError from the backend -> handler-class lookup below.
        if self.dp_replicate_size < 1:
            raise ValueError(f"dp_replicate_size must be at least 1, but got {self.dp_replicate_size}")
        if self.dp_shard_size < 1:
            raise ValueError(f"dp_shard_size must be at least 1, but got {self.dp_shard_size}")
        if self.tp_size < 1:
            raise ValueError(f"tp_size must be at least 1, but got {self.tp_size}")
        if self.cp_size < 1:
            raise ValueError(f"cp_size must be at least 1, but got {self.cp_size}")
        valid_cp_backends = ["torch"]
        if self.cp_backend not in valid_cp_backends:
            raise ValueError(f"cp_backend must be one of {valid_cp_backends}, but got {self.cp_backend}")
        if self.sp_size < 1:
            raise ValueError(f"sp_size must be at least 1, but got {self.sp_size}")
        valid_sp_backends = ["deepspeed"]
        if self.sp_backend not in valid_sp_backends:
            raise ValueError(f"sp_backend must be one of {valid_sp_backends}, but got {self.sp_backend}")

        if (self.tp_size > 1 or self.cp_size > 1) and self.dp_replicate_size > 1 and self.dp_shard_size == 1:
            raise ValueError(
                "Tensor/Context parallelism (tp/cp_size > 1) cannot be used with pure data parallelism (dp_replicate_size > 1 and dp_shard_size == 1). "
                "Please set dp_shard_size > 1 and dp_replicate_size == 1 to compose FSDP + TP/CP for 2D parallel, "
                "or set dp_replicate_size > 1 and dp_shard_size > 1 to compose HSDP + TP/CP for 3D parallel."
            )

        # Create default handlers for every enabled parallelism that did not receive one.
        if self.tp_size > 1 and self.tp_handler is None:
            self.tp_handler = TorchTensorParallelConfig()

        if self.cp_size > 1:
            if self.cp_handler is None:
                self.cp_handler = TorchContextParallelConfig()
            else:
                # A user-supplied handler must be of the class that belongs to the selected backend.
                cp_backends_config_map = dict(
                    torch=TorchContextParallelConfig,
                )
                if not isinstance(self.cp_handler, cp_backends_config_map[self.cp_backend]):
                    raise ValueError(
                        f"ParallelismConfig's cp_backend={self.cp_backend} requires {cp_backends_config_map[self.cp_backend]}, but cp_handler was set to {type(self.cp_handler)}"
                    )

        if self.sp_size > 1 and self.sp_handler is None:
            self.sp_handler = DeepSpeedSequenceParallelConfig()

        # Name -> size lookup used by `_get_mesh`; kept in sync with the `*_size` attributes by `_set_size`.
        self._sizes = {
            "dp_replicate": self.dp_replicate_size,
            "dp_shard": self.dp_shard_size,
            "tp": self.tp_size,
            "cp": self.cp_size,
            "sp": self.sp_size,
        }

    def _set_size(self, parallelism: str, size: int):
        """Update the size of one parallelism, in both `_sizes` and the matching `<parallelism>_size` attribute.

        Args:
            parallelism (`str`): One of the keys of `_sizes` (e.g. `dp_replicate`, `tp`).
            size (`int`): The new size.

        Raises:
            ValueError: If `parallelism` is not a known dimension.
        """
        # Explicit check instead of `assert`, which is stripped under `python -O` and would then let an
        # unknown name silently create a bogus key and attribute.
        if parallelism not in self._sizes:
            raise ValueError(f"Parallelism must be one of {self._sizes.keys()}")
        self._sizes[parallelism] = size
        setattr(self, f"{parallelism}_size", size)

    def _validate_accelerator(self, accelerator: "Accelerator"):
        """Check this config against the running `Accelerator` and warn about handlers that will be ignored.

        Reads `multi_device`, `num_processes`, `is_fsdp2`, `distributed_type` and `is_main_process` from
        `accelerator`. When every size is 1 in a multi-device run, `dp_replicate_size` is set to
        `accelerator.num_processes`.

        Raises:
            ValueError: If `total_size` does not equal `accelerator.num_processes`, or if the accelerator's
                distributed setup is not one of the supported ones.
        """
        # A list (not a set) keeps the emitted warning text in a deterministic order.
        _warnings = []
        if not accelerator.multi_device and self.total_size == 1:
            # No distributed setup, valid parallelism config
            return

        # We need this to ensure DDP works
        if self.total_size == 1:
            self._set_size("dp_replicate", accelerator.num_processes)

        if self.total_size != accelerator.num_processes:
            raise ValueError(
                f"ParallelismConfig total_size ({self.total_size}) does not match "
                f"num_processes ({accelerator.num_processes}). Please adjust dp_replicate_size/ "
                f"dp_shard_size/tp_size/cp_size/sp_size."
            )

        if self.total_size > 1 and not (
            accelerator.is_fsdp2
            or accelerator.multi_device
            or accelerator.distributed_type == DistributedType.DEEPSPEED
        ):
            raise ValueError(
                f"ParallelismConfig is only compatible DistributedType.FSDP (version 2) or DistributedType.Multi{{Device}} or DistributedType.DEEPSPEED, but got {accelerator.distributed_type}."
            )

        # A handler configured for a dimension of size 1 has no effect; collect one warning per such handler.
        for parallelism, size in self._sizes.items():
            if size == 1 and getattr(self, f"{parallelism}_handler", None) is not None:
                _warnings.append(
                    f"ParallelismConfig.{parallelism}_handler is set, but {parallelism}_size is set to 1. This handler will be ignored."
                )

        if _warnings and accelerator.is_main_process:
            warnings.warn(
                "ParallelismConfig has the following warnings:\n" + "\n".join(_warnings),
                UserWarning,
            )
|