parallelism_config.py 16 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378
  1. #
  2. # Copyright 2025 The HuggingFace Inc. team. All rights reserved.
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. import os
  15. import warnings
  16. from dataclasses import dataclass
  17. from typing import TYPE_CHECKING, Literal, Optional, Union
  18. from accelerate.utils.dataclasses import (
  19. DeepSpeedSequenceParallelConfig,
  20. DistributedType,
  21. TorchContextParallelConfig,
  22. TorchTensorParallelConfig,
  23. )
  24. from accelerate.utils.versions import is_torch_version
  25. if TYPE_CHECKING:
  26. from accelerate import Accelerator
  27. @dataclass
  28. class ParallelismConfig:
  29. """
  30. A dataclass to configure parallelisms applied to the model. Inspired by torchtitan's `ParallelDims`
  31. https://github.com/pytorch/torchtitan/blob/main/torchtitan/distributed/parallel_dims.py
  32. Args:
  33. dp_replicate_size (`int`, defaults to `1`):
  34. The size of the data parallel group. If `dp_replicate_size` is set to 1, the data parallel replication
  35. group will not be used.
  36. dp_shard_size (`int`, defaults to `1`):
  37. The size of the model shard group. If `dp_replicate_size > 1` and `tp_size > 1`, `dp_shard_size` must also
  38. be greater than 1, as composing DDP + TP is currently not supported.
  39. tp_size (`int`, defaults to `1`):
  40. The size of the tensor parallel group. If `tp_size` is set to `1`, the tensor parallel group will not be
  41. used.
  42. tp_handler (`~utils.TorchTensorParallelConfig`, defaults to `None`):
  43. The handler for the tensor parallel group.
  44. cp_size (`int`, defaults to `1`):
  45. The size of the context parallel group. Currently not supported, but reserved for future use and enabled
  46. for downstream libraries.
  47. cp_backend (`str`, defaults to `torch`):
  48. Which CP backend to use: `torch` (FSDP2)
  49. sp_size (`int`, defaults to `1`):
  50. The size of the sequence parallel group.
  51. sp_backend (`str`, defaults to `deepspeed`):
  52. Which SP backend to use:`deepspeed` (ALST/Ulysses)
  53. You may obtain different distributed data parallel paradigms by configuring `dp_replicate_size` and `dp_shard_size`
  54. together:
  55. - `dp_replicate_size == 1` and `dp_shard_size > 1`, we obtain Fully Sharded Data Parallel (FSDP).
  56. - `dp_replicate_size > 1` and `dp_shard_size > 1`, we obtain Hybrid Sharded Data Parallel (HSDP).
  57. - `dp_replicate_size > 1` and `dp_shard_size == 1` is an invalid configuration, to use pure DP, use
  58. `DistributedDataParallelKwargs` instead.
  59. """
  60. dp_replicate_size: Optional[int] = None
  61. dp_shard_size: Optional[int] = None
  62. tp_size: Optional[int] = None
  63. cp_size: Optional[int] = None
  64. cp_backend: Literal["torch"] = None
  65. sp_size: Optional[int] = None
  66. sp_backend: Literal["deepspeed"] = None
  67. # we use Union because we might support other x parallel plugins (i.e. deepspeed, etc)
  68. tp_handler: Union[None, TorchTensorParallelConfig] = None
  69. cp_handler: Union[None, TorchContextParallelConfig] = None
  70. sp_handler: Union[None, DeepSpeedSequenceParallelConfig] = None
  71. device_mesh = None
  72. def __repr__(self):
  73. return (
  74. "ParallelismConfig(\n "
  75. f"\tdp_replicate_size={self.dp_replicate_size},\n"
  76. f"\tdp_shard_size={self.dp_shard_size},\n"
  77. f"\ttp_size={self.tp_size},\n"
  78. f"\tcp_size={self.cp_size},\n"
  79. f"\tcp_backend={self.cp_backend},\n"
  80. f"\tsp_size={self.sp_size},\n"
  81. f"\tsp_backend={self.sp_backend},\n"
  82. f"\ttotal_size={self.total_size}\n"
  83. f"\ttp_handler={self.tp_handler},\n"
  84. f"\tcp_handler={self.cp_handler})\n"
  85. )
  86. def to_json(self):
  87. import copy
  88. _non_serializable_fields = ["device_mesh"]
  89. copy.deepcopy(
  90. {
  91. k: copy.deepcopy(v.__dict__) if hasattr(v, "__dict__") else v
  92. for k, v in self.__dict__.items()
  93. if k not in _non_serializable_fields
  94. }
  95. )
  96. @property
  97. def dp_dim_names(self):
  98. """Names of enabled dimensions across which data parallelism is applied."""
  99. dims = []
  100. if self.dp_replicate_enabled:
  101. dims += ["dp_replicate"]
  102. if self.dp_shard_enabled:
  103. dims += ["dp_shard"]
  104. return dims
  105. @property
  106. def non_dp_dim_names(self):
  107. """Names of enabled dimensions which will receive the same batch (non-data parallel dimensions)."""
  108. dims = []
  109. if self.tp_enabled:
  110. dims += ["tp"]
  111. if self.cp_enabled:
  112. dims += ["cp"]
  113. if self.sp_enabled:
  114. dims += ["sp"]
  115. return dims
  116. @property
  117. def dp_shard_cp_dim_names(self):
  118. """Names of enabled dimensions which will be flattened into a joint mesh across which is model sharded in FSDP."""
  119. dims = []
  120. if self.dp_shard_enabled:
  121. dims += ["dp_shard"]
  122. if self.cp_enabled:
  123. dims += ["cp"]
  124. return dims
  125. @property
  126. def dp_cp_dim_names(self):
  127. """Names of enabled dimensions across which loss should be averaged"""
  128. dims = []
  129. if self.dp_replicate_enabled:
  130. dims += ["dp_replicate"]
  131. if self.dp_shard_enabled:
  132. dims += ["dp_shard"]
  133. if self.cp_enabled:
  134. dims += ["cp"]
  135. return dims
  136. @property
  137. def fsdp_dim_names(self):
  138. """Names of enabled dimensions across which FSDP is applied, including data parallel replication."""
  139. dims = []
  140. if self.dp_replicate_enabled:
  141. dims += ["dp_replicate"]
  142. dims += ["dp_shard_cp"]
  143. return dims
  144. @property
  145. def total_size(self):
  146. """The total size of the parallelism configuration, which is the product of all sizes."""
  147. return self.dp_replicate_size * self.dp_shard_size * self.tp_size * self.cp_size * self.sp_size
  148. @property
  149. def non_data_parallel_size(self):
  150. """The size of the non-data parallel dimensions, which is the product of tensor and context parallel sizes."""
  151. return self.tp_size * self.cp_size * self.sp_size
  152. @property
  153. def data_parallel_size(self):
  154. """The size of the data parallel dimensions, which is the product of data parallel replication and"""
  155. return self.dp_replicate_size * self.dp_shard_size
  156. @property
  157. def dp_replicate_enabled(self):
  158. """True if data parallel replication is enabled, i.e. `dp_replicate_size > 1`."""
  159. return self.dp_replicate_size > 1
  160. @property
  161. def dp_shard_enabled(self):
  162. """True if data parallel sharding is enabled, i.e. `dp_shard_size > 1`."""
  163. return self.dp_shard_size > 1
  164. @property
  165. def tp_enabled(self):
  166. """True if tensor parallelism is enabled, i.e. `tp_size > 1`."""
  167. return self.tp_size > 1
  168. @property
  169. def cp_enabled(self):
  170. """True if context parallelism is enabled, i.e. `cp_size > 1`."""
  171. return self.cp_size > 1
  172. @property
  173. def sp_enabled(self):
  174. """True if context parallelism is enabled, i.e. `sp_size > 1`."""
  175. return self.sp_size > 1
  176. @property
  177. def active_mesh_dims(self):
  178. """Names of all active mesh dimensions."""
  179. return self.dp_dim_names + self.non_dp_dim_names
  180. def build_device_mesh(self, device_type: str):
  181. """Builds a device mesh for the given device type based on the parallelism configuration.
  182. This method will also create required joint meshes (e.g. `dp_shard_cp`, `dp_cp`, `dp`).
  183. Args:
  184. device_type (`str`): The type of device for which to build the mesh, e
  185. """
  186. if is_torch_version(">=", "2.2.0"):
  187. from torch.distributed.device_mesh import init_device_mesh
  188. else:
  189. raise RuntimeError("Building a device_mesh requires to have torch>=2.2.0")
  190. mesh = self._get_mesh()
  191. if len(mesh) == 0:
  192. return None
  193. mesh_dim_names, mesh_shape = mesh
  194. device_mesh = init_device_mesh(
  195. device_type,
  196. mesh_shape,
  197. mesh_dim_names=mesh_dim_names,
  198. )
  199. if self.dp_dim_names:
  200. device_mesh[self.dp_dim_names]._flatten("dp")
  201. if self.dp_shard_cp_dim_names:
  202. device_mesh[self.dp_shard_cp_dim_names]._flatten("dp_shard_cp")
  203. if self.dp_cp_dim_names:
  204. device_mesh[self.dp_cp_dim_names]._flatten("dp_cp")
  205. return device_mesh
  206. def get_device_mesh(self, device_type: Optional[str] = None):
  207. if self.device_mesh is None:
  208. if device_type is not None:
  209. self.device_mesh = self.build_device_mesh(device_type)
  210. else:
  211. raise ("You need to pass a device_type e.g cuda to build the device mesh")
  212. else:
  213. if device_type is not None:
  214. if self.device_mesh.device_type != device_type:
  215. raise ValueError(
  216. f"The device_mesh is already created with device type {self.device_mesh.device_type}. However, you are trying to get a device mesh with device_type {device_type}. Please check if you correctly initialized your device_mesh"
  217. )
  218. return self.device_mesh
  219. def _get_mesh(self) -> tuple[tuple[int, ...], tuple[str, ...]]:
  220. """Generate mesh shape and dimension names for torch.distributed.init_device_mesh()."""
  221. # Build mesh dimensions dictionary
  222. mesh_dims = {parallelism: self._sizes[parallelism] for parallelism in self.active_mesh_dims}
  223. # Apply canonical ordering
  224. mesh_order = ["dp_replicate", "dp_shard", "cp", "sp", "tp"]
  225. sorted_items = sorted(
  226. mesh_dims.items(),
  227. key=lambda x: (mesh_order.index(x[0])),
  228. )
  229. return tuple(zip(*sorted_items))
  230. def __post_init__(self):
  231. # Basic size validation
  232. if self.dp_replicate_size is None:
  233. self.dp_replicate_size = int(os.environ.get("PARALLELISM_CONFIG_DP_REPLICATE_SIZE", "1"))
  234. if self.dp_shard_size is None:
  235. self.dp_shard_size = int(os.environ.get("PARALLELISM_CONFIG_DP_SHARD_SIZE", "1"))
  236. if self.tp_size is None:
  237. self.tp_size = int(os.environ.get("PARALLELISM_CONFIG_TP_SIZE", "1"))
  238. if self.cp_size is None:
  239. self.cp_size = int(os.environ.get("PARALLELISM_CONFIG_CP_SIZE", "1"))
  240. if self.cp_backend is None:
  241. self.cp_backend = os.environ.get("PARALLELISM_CONFIG_CP_BACKEND", "torch")
  242. if self.sp_size is None:
  243. self.sp_size = int(os.environ.get("PARALLELISM_CONFIG_SP_SIZE", "1"))
  244. if self.sp_backend is None:
  245. self.sp_backend = os.environ.get("PARALLELISM_CONFIG_SP_BACKEND", "deepspeed")
  246. if self.tp_size > 1:
  247. if self.tp_handler is None:
  248. self.tp_handler = TorchTensorParallelConfig()
  249. if self.cp_size > 1:
  250. if self.cp_handler is None:
  251. self.cp_handler = TorchContextParallelConfig()
  252. else:
  253. cp_backends_config_map = dict(
  254. torch=TorchContextParallelConfig,
  255. )
  256. if not isinstance(self.cp_handler, cp_backends_config_map[self.cp_backend]):
  257. raise ValueError(
  258. f"ParallelismConfig's cp_backend={self.cp_backend} requires {cp_backends_config_map[self.cp_backend]}, but cp_handler was set to {type(self.cp_handler)}"
  259. )
  260. if self.sp_size > 1:
  261. if self.sp_handler is None:
  262. self.sp_handler = DeepSpeedSequenceParallelConfig()
  263. if self.dp_replicate_size < 1:
  264. raise ValueError(f"dp_replicate_size must be at least 1, but got {self.dp_replicate_size}")
  265. if self.dp_shard_size < 1:
  266. raise ValueError(f"dp_shard_size must be at least 1, but got {self.dp_shard_size}")
  267. if self.tp_size < 1:
  268. raise ValueError(f"tp_size must be at least 1, but got {self.tp_size}")
  269. if self.cp_size < 1:
  270. raise ValueError(f"cp_size must be at least 1, but got {self.cp_size}")
  271. valid_cp_backends = ["torch"]
  272. if self.cp_backend not in valid_cp_backends:
  273. raise ValueError(f"cp_backend must be one of {valid_cp_backends}, but got {self.cp_backend}")
  274. if self.sp_size < 1:
  275. raise ValueError(f"sp_size must be at least 1, but got {self.sp_size}")
  276. valid_sp_backends = ["deepspeed"]
  277. if self.sp_backend not in valid_sp_backends:
  278. raise ValueError(f"sp_backend must be one of {valid_sp_backends}, but got {self.sp_backend}")
  279. if (self.tp_size > 1 or self.cp_size > 1) and self.dp_replicate_size > 1 and self.dp_shard_size == 1:
  280. raise ValueError(
  281. "Tensor/Context parallelism (tp/cp_size > 1) cannot be used with pure data parallelism (dp_replicate_size > 1 and dp_shard_size == 1). "
  282. "Please set dp_shard_size > 1 and dp_replicate_size == 1 to compose FSDP + TP/CP for 2D parallel, "
  283. "or set dp_replicate_size == 1 and dp_shard_size > 1 to compose HSDP + TP/CP for 3D parallel."
  284. )
  285. self._sizes = {
  286. "dp_replicate": self.dp_replicate_size,
  287. "dp_shard": self.dp_shard_size,
  288. "tp": self.tp_size,
  289. "cp": self.cp_size,
  290. "sp": self.sp_size,
  291. }
  292. def _set_size(self, parallelism: str, size: int):
  293. assert parallelism in self._sizes.keys(), f"Parallelism must be one of {self._sizes.keys()}"
  294. self._sizes[parallelism] = size
  295. setattr(self, f"{parallelism}_size", size)
  296. def _validate_accelerator(self, accelerator: "Accelerator"):
  297. _warnings = set()
  298. if not accelerator.multi_device and self.total_size == 1:
  299. # No distributed setup, valid parallelism config
  300. return
  301. # We need this to ensure DDP works
  302. if self.total_size == 1:
  303. self._set_size("dp_replicate", accelerator.num_processes)
  304. if self.total_size != accelerator.num_processes:
  305. raise ValueError(
  306. f"ParallelismConfig total_size ({self.total_size}) does not match "
  307. f"num_processes ({accelerator.num_processes}). Please adjust dp_replicate_size/ "
  308. f"dp_shard_size/tp_size/cp_size/sp_size."
  309. )
  310. if self.total_size > 1 and not (
  311. accelerator.is_fsdp2
  312. or accelerator.multi_device
  313. or accelerator.distributed_type == DistributedType.DEEPSPEED
  314. ):
  315. raise ValueError(
  316. f"ParallelismConfig is only compatible DistributedType.FSDP (version 2) or DistributedType.Multi{{Device}} or DistributedType.DEEPSPEED, but got {accelerator.distributed_type}."
  317. )
  318. for parallelism, size in self._sizes.items():
  319. if size == 1 and getattr(self, f"{parallelism}_handler", None) is not None:
  320. _warnings.add(
  321. f"ParallelismConfig.{parallelism}_handler is set, but {parallelism}_size is set to 1. This handler will be ignored."
  322. )
  323. if _warnings and accelerator.is_main_process:
  324. warnings.warn(
  325. "ParallelismConfig has the following warnings:\n" + "\n".join(_warnings),
  326. UserWarning,
  327. )