tensor_parallel.py 48 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862863864865866867868869870871872873874875876877878879880881882883884885886887888889890891892893894895896897898899900901902903904905906907908909910911912913914915916917918919920921922923924925926927928929930931932933934935936937938939940941942943944945946947948949950951952953954955956957958959960961962963964965966967968969970971972973974975976977978979980981982983984985986987988989990991992993994995996997998999100010011002100310041005100610071008100910101011101210131014101510161017101810191020102110221023102410251026102710281029103010311032103310341035103610371038103910401041104210431044104510461047104810491050105110521053105410551056105710581059106010611062106310641065106610671068106910701071107210731074107510761077107810791080108110821083108410851086108710881089109010911092109310941095109610971098109911001101110211031104110511061107110811091110111111121113111411151116
  1. # Copyright 2024 The HuggingFace Team. All rights reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. from __future__ import annotations
  15. import math
  16. import operator
  17. import os
  18. import re
  19. from functools import partial, reduce
  20. import torch
  21. import torch.distributed as dist
  22. from torch import nn
  23. from ..distributed import DistributedConfig
  24. from ..utils import is_torch_greater_or_equal, logging
  25. from ..utils.generic import GeneralInterface
logger = logging.get_logger(__name__)

# Cache this result as it's a C FFI call which can be pretty time-consuming.
_torch_distributed_available = torch.distributed.is_available()

# The DTensor names are only imported when torch is recent enough (>= 2.5) and distributed support is available.
if is_torch_greater_or_equal("2.5") and _torch_distributed_available:
    from torch.distributed.tensor import DTensor, Placement, Replicate, Shard
  31. def initialize_tensor_parallelism(tp_plan, tp_size=None):
  32. r"""
  33. Sets up the device mesh and initialized the backend for tensor parallelism.
  34. This function is called when the model is loaded and the TP plan is set to 'auto'.
  35. """
  36. if tp_plan is None:
  37. return None, None, None
  38. if not is_torch_greater_or_equal("2.5"):
  39. raise OSError("Tensor parallel is only supported for `torch>=2.5`.")
  40. # Detect the accelerator on the machine. If no accelerator is available, it returns CPU.
  41. device_type = torch._C._get_accelerator().type
  42. current_device = getattr(torch, device_type)
  43. if not torch.distributed.is_initialized():
  44. try:
  45. rank = int(os.environ["RANK"])
  46. local_rank = int(os.environ["LOCAL_RANK"])
  47. world_size = int(os.environ["WORLD_SIZE"])
  48. backend_map = {"cuda": "nccl", "cpu": "gloo", "xpu": "xccl", "hpu": "hccl"}
  49. backend = backend_map.get(device_type)
  50. if device_type == "cpu" and int(os.environ.get("CCL_WORKER_COUNT", "0")):
  51. backend = "ccl"
  52. if device_type == "xpu" and not is_torch_greater_or_equal("2.8", accept_dev=True):
  53. backend = "ccl"
  54. torch.distributed.init_process_group(backend=backend, rank=rank, world_size=world_size)
  55. current_device = getattr(torch, device_type)
  56. if device_type != "cpu":
  57. current_device.set_device(local_rank)
  58. except Exception as e:
  59. raise OSError(
  60. "We tried to initialize torch.distributed for you, but it failed. Make "
  61. "sure you init torch distributed in your script to use `tp_plan='auto'`."
  62. ) from e
  63. if device_type != "cpu":
  64. current_device.set_device(int(os.environ["LOCAL_RANK"]))
  65. index = current_device.current_device() if device_type != "cpu" else None
  66. tp_device = torch.device(device_type, index)
  67. # Silence output for non-primary ranks
  68. if index is not None and index > 0:
  69. import sys
  70. sys.stdout = open(os.devnull, "w")
  71. sys.stderr = open(os.devnull, "w")
  72. device_map = tp_device
  73. tp_size = tp_size if tp_size is not None else torch.distributed.get_world_size()
  74. device_mesh = torch.distributed.init_device_mesh(tp_device.type, (tp_size,))
  75. return tp_device, device_map, device_mesh, tp_size
  76. def _blocks_to_block_sizes(total_size: int, blocks: int | list[int]) -> list[int]:
  77. """
  78. Convert block count or proportions to block sizes.
  79. This function accepts
  80. - The number of blocks (int), in which case the block size is
  81. total_size//blocks; or
  82. - A list of block sizes (list[int]).
  83. In the second case, if sum(blocks) < total_size, the ratios between
  84. the block sizes will be preserved. For instance, if blocks is
  85. [2, 1, 1] and total_size is 1024, the returned block sizes are
  86. [512, 256, 256].
  87. """
  88. if isinstance(blocks, list):
  89. total_blocks = sum(blocks)
  90. assert total_size % total_blocks == 0, f"Cannot split {total_size} in proportional blocks: {blocks}"
  91. part_size = total_size // total_blocks
  92. return [part_size * block for block in blocks]
  93. else:
  94. assert total_size % blocks == 0, f"Prepacked is not divisible by {blocks}"
  95. single_size = total_size // blocks
  96. return [single_size] * blocks
  97. def _get_parameter_tp_plan(parameter_name: str, tp_plan: dict[str, str], is_weight=True) -> str | None:
  98. """
  99. Get the TP style for a parameter from the TP plan.
  100. The TP plan is a dictionary that maps parameter names to TP styles.
  101. The parameter name can be a generic name with wildcards (e.g. "*.weight") or a specific name (e.g. "layer_1.weight").
  102. The `is_weight` is important because for weights, we want to support `.weights` and `.bias` cases seamlessly! but
  103. not parent classes for `post_init` calls
  104. """
  105. generic_param_name = re.sub(r"\d+", "*", parameter_name)
  106. if generic_param_name in tp_plan:
  107. return tp_plan[generic_param_name]
  108. elif "." in generic_param_name and generic_param_name.rsplit(".", 1)[0] in tp_plan and is_weight:
  109. return tp_plan[generic_param_name.rsplit(".", 1)[0]]
  110. return None
  111. str_to_dtype = {
  112. "BOOL": torch.bool,
  113. "U8": torch.uint8,
  114. "I8": torch.int8,
  115. "I16": torch.int16,
  116. "F16": torch.float16,
  117. "BF16": torch.bfloat16,
  118. "I32": torch.int32,
  119. "F32": torch.float32,
  120. "F64": torch.float64,
  121. "I64": torch.int64,
  122. "F8_E4M3": torch.float8_e4m3fn,
  123. }
  124. def get_packed_weights(param, empty_param, device_mesh, rank, dim):
  125. """
  126. When weights are packed (gate_up_proj), we need to make sure each shard gets its correct share.
  127. So if you have: gate_proj ( 16, 5120, 8190)
  128. and up_proj ( 16, 5120, 8190)
  129. packed as gate_up_proj ( 16, 5120, 2 * 8190)
  130. And you shard along the last dimension, you need to interleave the gate and up values:
  131. Now, if we shard along the last dimension across TP_size (Tensor Parallelism size), we must interleave the values from gate and up projections correctly.
  132. Let's take TP_size = 4 for an example:
  133. Packed tensor `gate_up_proj`
  134. ---------------------------------------------------------------
  135. [ G0 G1 G2 G3 | G4 G5 G6 G7 | ... | U0 U1 U2 U3 | U4 U5 U6 U7 | ... ]
  136. ↑─────────────↑ ↑─────────────↑ ↑─────────────↑ ↑─────────────↑
  137. Gate Slice 0 Gate Slice 1 Up Slice 0 Up Slice 1
  138. Explanation:
  139. - The first half of the tensor (left of the center) holds the gate_proj values.
  140. - The second half (right of the center) holds the up_proj values.
  141. - For TP=4, we divide each half into 4 slices. In this example, we show two slices for brevity.
  142. - Each shard receives one slice from the gate part and the corresponding slice from the up part.
  143. For instance:
  144. • Shard 0 gets: [ Gate Slice 0, Up Slice 0 ] = [ G0, G1, G2, G3, U0, U1, U2, U3 ]
  145. • Shard 1 gets: [ Gate Slice 1, Up Slice 1 ] = [ G4, G5, G6, G7, U4, U5, U6, U7 ]
  146. • … and so on.
  147. This ensures that each shard receives an equal portion of both gate and up projections, maintaining consistency across tensor parallelism.
  148. """
  149. slice_ = param
  150. total_size = empty_param.shape[dim]
  151. world_size = device_mesh.size()
  152. block_sizes = _blocks_to_block_sizes(total_size=total_size, blocks=2)
  153. tensors_slices = []
  154. block_offset = 0
  155. for block_size in block_sizes:
  156. shard_block_size = block_size // world_size
  157. start = rank * shard_block_size
  158. stop = (rank + 1) * shard_block_size
  159. tensors_slices += range(block_offset + start, block_offset + stop)
  160. block_offset += block_size
  161. slice_dtype = slice_.get_dtype()
  162. # Handle F8_E4M3 dtype by converting to float16 before slicing
  163. # Without upcasting, the slicing causes : RuntimeError: "index_cpu" not implemented for 'Float8_e4m3fn'
  164. casted = False
  165. if slice_dtype == "F8_E4M3" or slice_dtype == "F8_E5M2":
  166. slice_ = slice_[...].to(torch.float16)
  167. casted = True
  168. if dim == 0:
  169. tensor = slice_[tensors_slices, ...]
  170. elif dim == 1 or dim == -2:
  171. tensor = slice_[:, tensors_slices, ...]
  172. elif dim == 2 or dim == -1:
  173. tensor = slice_[..., tensors_slices]
  174. else:
  175. raise ValueError(f"Unsupported dim {dim}, only dim 0, 1 or 2 are supported")
  176. if casted:
  177. return tensor
  178. else:
  179. return tensor.to(str_to_dtype[slice_dtype])
  180. def repack_weights(
  181. packed_parameter: torch.Tensor,
  182. sharded_dim: int, # The dimension index in the global tensor that was sharded
  183. world_size: int,
  184. num_blocks: int = 2,
  185. ) -> torch.Tensor:
  186. """
  187. Reorders a tensor that was reconstructed from sharded packed weights into its canonical packed format.
  188. For example, if a weight was packed (e.g., gate_proj and up_proj) and then sharded,
  189. DTensor.full_tensor() might produce an interleaved layout like [G0, U0, G1, U1, ...]
  190. along the sharded dimension. This function reorders it to [G0, G1, ..., U0, U1, ...].
  191. This is an inverse operation to get_packed_weights.
  192. Args:
  193. reconstructed_tensor: The tensor reconstructed from DTensor (e.g., via .full_tensor().contiguous()).
  194. sharded_dim: The dimension index in the reconstructed_tensor that was originally sharded.
  195. world_size: The tensor parallel world size.
  196. num_packed_projs: The number of projections that were packed together (e.g., 2 for gate_up_proj).
  197. Returns:
  198. The reordered tensor in canonical packed format.
  199. """
  200. if num_blocks != 2:
  201. raise ValueError(
  202. "Num blocks different from 2 is not supported yet. This is most likely a bug in your implementation as we only pack gate and up projections together."
  203. )
  204. actual_sharded_dim = sharded_dim if sharded_dim >= 0 else sharded_dim + packed_parameter.ndim
  205. total_size_on_sharded_dim = packed_parameter.shape[actual_sharded_dim]
  206. original_block_size_on_dim = total_size_on_sharded_dim // num_blocks
  207. shard_chunk_size = original_block_size_on_dim // world_size
  208. prefix_shape = packed_parameter.shape[:actual_sharded_dim]
  209. suffix_shape = packed_parameter.shape[actual_sharded_dim + 1 :]
  210. tensor_view = packed_parameter.view(
  211. *prefix_shape,
  212. world_size,
  213. num_blocks,
  214. shard_chunk_size,
  215. *suffix_shape,
  216. )
  217. # Permute to bring num_packed_projs first, then world_size, then shard_chunk_size
  218. # This groups all chunks of G together, then all chunks of U together.
  219. # Target order of these middle dimensions: (num_packed_projs, world_size, shard_chunk_size)
  220. # Current order of view's middle dimensions: (world_size, num_packed_projs, shard_chunk_size)
  221. # Absolute indices of the dimensions to be permuted (world_size, num_packed_projs)
  222. axis_ws_abs = len(prefix_shape)
  223. axis_npp_abs = len(prefix_shape) + 1
  224. permute_order = list(range(tensor_view.ndim))
  225. permute_order[axis_ws_abs], permute_order[axis_npp_abs] = permute_order[axis_npp_abs], permute_order[axis_ws_abs]
  226. tensor_permuted = tensor_view.permute(*permute_order)
  227. # Reshape back to the original tensor's ndim, with the sharded dimension now correctly ordered as [G_all, U_all].
  228. # The final shape should be the same as reconstructed_tensor.
  229. final_ordered_tensor = tensor_permuted.reshape_as(packed_parameter)
  230. return final_ordered_tensor
  231. def get_tensor_shard(param, empty_param, device_mesh, rank, dim):
  232. """
  233. Generalized tensor sharding across a multi-dimensional device mesh.
  234. Extract only the fraction of the parameter owned by the given `rank` when the parameter would have gone sharding at provided `dim`.
  235. Extraction follows the pytorch `Shard` placement so that sharding and materializing back to full tensor follows `Shard` semantics.
  236. `Shard` follows torch.chunk style sharding of the tensor. We demonstrate some cases below on how sharding happens including some edge cases
  237. such as some ranks having an empty tensor as shard. Below implementation is robut to all these cases.
  238. Case (1)
  239. empty_param (16, 5120, 8190)
  240. dim 0
  241. device_mesh.size() 4
  242. rank 0 gets (4, 5120, 8190) (0 ... 4, 5120, 8190)
  243. rank 1 gets (4, 5120, 8190) (4 ... 8, 5120, 8190)
  244. rank 2 gets (4, 5120, 8190) (8 ... 12, 5120, 8190)
  245. rank 3 gets (4, 5120, 8190) (12 ... 16, 5120, 8190)
  246. Case (2)
  247. empty_param (16, 5120, 8190)
  248. dim 0
  249. device_mesh.size() 14
  250. rank 0 gets (2, 5120, 8190) (0 ... 2, 5120, 8190)
  251. rank 1 gets (2, 5120, 8190) (2 ... 4, 5120, 8190)
  252. rank 2 gets (2, 5120, 8190) (4 ... 6, 5120, 8190)
  253. rank 3 gets (2, 5120, 8190) (6 ... 8, 5120, 8190)
  254. rank 4 gets (2, 5120, 8190) (8 ... 10, 5120, 8190)
  255. rank 5 gets (2, 5120, 8190) (10 ... 12, 5120, 8190)
  256. rank 6 gets (2, 5120, 8190) (12 ... 14, 5120, 8190)
  257. rank 7 gets (2, 5120, 8190) (14 ... 16, 5120, 8190)
  258. rank 8 gets (0, 5120, 8190)
  259. rank 9 gets (0, 5120, 8190)
  260. rank 10 gets (0, 5120, 8190)
  261. rank 11 gets (0, 5120, 8190)
  262. rank 12 gets (0, 5120, 8190)
  263. rank 13 gets (0, 5120, 8190)
  264. Case (3)
  265. empty_param (16, 5120, 8190)
  266. dim 0
  267. device_mesh.size() 3
  268. rank 0 gets (6, 5120, 8190) (0 ... 6, 5120, 8190)
  269. rank 1 gets (6, 5120, 8190) (6 ... 12, 5120, 8190)
  270. rank 2 gets (4, 5120, 8190) (12 ... 16, 5120, 8190)
  271. In case (2), empty shards are returned with appropriate dimension to allow for operations to work smoothly.
  272. Args:
  273. param (torch.Tensor): The tensor to shard.
  274. empty_param (torch.Tensor): A tensor used for shape reference.
  275. device_mesh (torch.Tensor): Shape [d_0, ..., d_n] representing the mesh.
  276. rank (int): Global rank of the current process/device.
  277. dim (int): Dimension along which to shard the tensor.
  278. """
  279. param_dim = empty_param.dim()
  280. if dim < 0:
  281. dim = param_dim + dim
  282. if dim >= param_dim:
  283. raise ValueError(f"dim {dim} is out of bounds for tensor of dimension {param_dim}")
  284. # Flatten the mesh to get the total number of devices
  285. mesh_shape = device_mesh.shape
  286. world_size = reduce(operator.mul, mesh_shape)
  287. if rank >= world_size:
  288. raise ValueError(f"Rank {rank} is out of bounds for mesh size {world_size}")
  289. shard_size = math.ceil(empty_param.shape[dim] / world_size)
  290. start = rank * shard_size
  291. # Construct slicing index dynamically
  292. end = min(start + shard_size, empty_param.shape[dim])
  293. slice_indices = [slice(None)] * param_dim
  294. if start < empty_param.shape[dim]:
  295. slice_indices[dim] = slice(start, end)
  296. return param[tuple(slice_indices)]
  297. dimensions = list(param.shape)
  298. dimensions[dim] = 0
  299. return torch.empty(tuple(dimensions), dtype=torch.int64)
  300. def distribute_module(
  301. module: nn.Module,
  302. device_mesh=None,
  303. input_fn=None,
  304. output_fn=None,
  305. ) -> nn.Module:
  306. """
  307. Copy pasted from torch's function but we remove the communications (partitioning)
  308. as well as buffer registering that is similarly not efficient.
  309. """
  310. if len(module._forward_pre_hooks) == 0:
  311. if input_fn is not None:
  312. module.register_forward_pre_hook(lambda mod, inputs: input_fn(mod, inputs, device_mesh))
  313. if output_fn is not None:
  314. module.register_forward_hook(lambda mod, inputs, outputs: output_fn(mod, outputs, device_mesh))
  315. return module
  316. class TensorParallelLayer:
  317. """
  318. General tensor parallel layer for transformers.
  319. """
  320. use_dtensor = True
  321. @staticmethod
  322. def _prepare_input_fn(input_layouts, desired_input_layouts, mod, inputs, device_mesh): ...
  323. @staticmethod
  324. def _prepare_output_fn(output_layouts, use_local_output, mod, outputs, device_mesh): ...
  325. def partition_tensor(self, param, empty_param, param_type, param_casting_dtype, to_contiguous, rank, device_mesh):
  326. raise NotImplementedError
  327. def prepare_module_tp(self, module: nn.Module, device_mesh) -> nn.Module:
  328. if self.use_dtensor:
  329. distribute_module(
  330. module,
  331. device_mesh,
  332. partial(self._prepare_input_fn, self.input_layouts, self.desired_input_layouts),
  333. partial(self._prepare_output_fn, self.output_layouts, self.use_local_output),
  334. )
  335. # use_dtensor needs to be set to false for nn.Parameter when you want to view, chunk, slice
  336. # you name it. Whatever you want to do that is a bit unconventional, you need local tensors
  337. class GatherParallel(TensorParallelLayer):
  338. """
  339. Simple class used to define the hooks to add to a layer when we just want to gather the outputs
  340. """
  341. def __init__(
  342. self,
  343. *,
  344. input_layouts: Placement | None = None,
  345. output_layouts: Placement | None = None,
  346. use_local_output: bool = True,
  347. ):
  348. super().__init__()
  349. self.input_layouts = (input_layouts or Replicate(),)
  350. self.output_layouts = output_layouts
  351. self.desired_input_layouts = (Replicate(),)
  352. self.use_local_output = use_local_output
  353. @staticmethod
  354. def _prepare_input_fn(input_layouts, desired_input_layouts, mod, inputs, device_mesh):
  355. mod.expert_parallel_group = device_mesh.get_group()
  356. if inputs and isinstance(inputs[0], DTensor):
  357. inputs = inputs[0].to_local()
  358. return inputs
  359. @staticmethod
  360. def _prepare_output_fn(output_layouts, use_local_output, mod, outputs, device_mesh):
  361. if isinstance(outputs, torch.Tensor):
  362. dist.all_reduce(outputs, op=dist.ReduceOp.SUM, async_op=False)
  363. else:
  364. dist.all_reduce(outputs[0], op=dist.ReduceOp.SUM, async_op=False)
  365. return outputs
  366. def prepare_module_tp(self, module: nn.Module, device_mesh) -> nn.Module:
  367. distribute_module(
  368. module,
  369. device_mesh,
  370. partial(self._prepare_input_fn, None, None),
  371. partial(self._prepare_output_fn, None, None),
  372. )
  373. class IsolatedParallel(TensorParallelLayer):
  374. """
  375. This class is used to isolate computation in a TP layer from the rest of the world.
  376. Parameters need to be LOCAL, so not dtensors
  377. """
  378. @staticmethod
  379. def _prepare_input_fn(input_layouts, desired_input_layouts, mod, inputs, device_mesh=None):
  380. # annotate module input placements/sharding with input_layouts
  381. input_tensor = inputs[0]
  382. if isinstance(input_tensor, DTensor):
  383. input_tensor = input_tensor.to_local()
  384. return input_tensor
  385. @staticmethod
  386. def _prepare_output_fn(output_layouts, use_local_output, mod, outputs, device_mesh=None):
  387. # TODO: figure out dynamo support for instance method and switch this to instance method
  388. return outputs
  389. def partition_tensor(self, param, empty_param, param_type, param_casting_dtype, to_contiguous, rank, device_mesh):
  390. param = param[...].to(param_casting_dtype)
  391. if to_contiguous:
  392. param = param.contiguous()
  393. param = param / device_mesh.size() # TODO should be optionable
  394. # TODO: assumes parent module will allreduce the output afterwards (e.g rowlinear bias is IsolatedParallel and parent module is GatherParallel)
  395. return param
  396. def prepare_module_tp(self, module: nn.Module, device_mesh) -> nn.Module:
  397. distribute_module(
  398. module,
  399. device_mesh,
  400. partial(self._prepare_input_fn, None, None),
  401. partial(self._prepare_output_fn, None, None),
  402. )
  403. class ReplicateParallel(TensorParallelLayer):
  404. """
  405. This class is used to replicate computation in a TP layer (used in SP regions when we don't use sequence parallelism for example)
  406. """
  407. def __init__(self, *, use_dtensor=True, use_local_output=True):
  408. super().__init__()
  409. self.input_layouts = (Replicate(),)
  410. self.output_layouts = (Replicate(),)
  411. self.desired_input_layouts = (Replicate(),)
  412. self.use_local_output = use_local_output
  413. self.use_dtensor = use_dtensor
  414. @staticmethod
  415. def _prepare_input_fn(input_layouts, desired_input_layouts, mod, inputs, device_mesh):
  416. # TODO: figure out dynamo support for instance method and switch this to instance method
  417. # annotate module input placements/sharding with input_layouts
  418. input_tensor = inputs[0]
  419. if not isinstance(input_tensor, DTensor):
  420. input_tensor = DTensor.from_local(input_tensor, device_mesh, input_layouts, run_check=False)
  421. return input_tensor
  422. @staticmethod
  423. def _prepare_output_fn(output_layouts, use_local_output, mod, outputs, device_mesh):
  424. return outputs.to_local() if use_local_output and isinstance(outputs, DTensor) else outputs
  425. def partition_tensor(self, param, empty_param, param_type, param_casting_dtype, to_contiguous, rank, device_mesh):
  426. param = param[...].to(param_casting_dtype)
  427. if to_contiguous:
  428. param = param.contiguous()
  429. param = DTensor.from_local(param, device_mesh, [Replicate()], run_check=False)
  430. return param
  431. class ColwiseParallel(TensorParallelLayer):
  432. """
  433. General tensor parallel layer for transformers.
  434. """
  435. def __init__(
  436. self,
  437. *,
  438. input_layouts: Placement | None = None,
  439. output_layouts: Placement | None = None,
  440. use_local_output: bool = True,
  441. use_dtensor=True,
  442. ):
  443. super().__init__()
  444. self.input_layouts = (input_layouts or Replicate(),)
  445. self.output_layouts = (output_layouts or Shard(-1),)
  446. self.desired_input_layouts = (Replicate(),)
  447. self.use_local_output = use_local_output
  448. self.use_dtensor = use_dtensor
  449. @staticmethod
  450. def _prepare_input_fn(input_layouts, desired_input_layouts, mod, inputs, device_mesh):
  451. # TODO: figure out dynamo support for instance method and switch this to instance method
  452. # annotate module input placements/sharding with input_layouts
  453. input_tensor = inputs[0]
  454. if not isinstance(input_tensor, DTensor):
  455. input_tensor = DTensor.from_local(input_tensor, device_mesh, input_layouts, run_check=False)
  456. # transform the input layouts to the desired layouts of ColwiseParallel
  457. if input_layouts != desired_input_layouts:
  458. input_tensor = input_tensor.redistribute(placements=desired_input_layouts, async_op=False)
  459. return input_tensor
  460. def partition_tensor(self, param, empty_param, param_type, param_casting_dtype, to_contiguous, rank, device_mesh):
  461. # colwise shard weight/bias to Shard(0), weight be Shard(-2) (0 if you have 1 dim only)
  462. # means Colwise as Linear is input * weight^T + bias, where
  463. # weight would become Shard(1)
  464. if param_type == "bias":
  465. parameter = get_tensor_shard(param, empty_param, device_mesh, rank, -1)
  466. shard = [Shard(-1)]
  467. else:
  468. shard = [Shard(-2)]
  469. parameter = get_tensor_shard(param, empty_param, device_mesh, rank, -2)
  470. parameter = parameter.to(param_casting_dtype)
  471. if to_contiguous:
  472. parameter = parameter.contiguous()
  473. if self.use_dtensor:
  474. parameter = DTensor.from_local(
  475. parameter, device_mesh, shard, run_check=False, shape=empty_param.size(), stride=empty_param.stride()
  476. )
  477. return nn.Parameter(parameter, requires_grad=parameter.is_floating_point())
  478. @staticmethod
  479. def _prepare_output_fn(output_layouts, use_local_output, mod, outputs, device_mesh):
  480. # outputs is a shard on last dimension DTensor, i.e. Shard(-1)
  481. if outputs.placements != output_layouts:
  482. outputs = outputs.redistribute(placements=output_layouts, async_op=False)
  483. # back to local tensor
  484. return outputs.to_local() if use_local_output and isinstance(outputs, DTensor) else outputs
  485. class PackedColwiseParallel(ColwiseParallel):
  486. def partition_tensor(self, param, empty_param, param_type, param_casting_dtype, to_contiguous, rank, device_mesh):
  487. # colwise shard weight/bias to Shard(0), weight be Shard(-2) (0 if you have 1 dim only)
  488. # means Colwise as Linear is input * weight^T + bias, where
  489. # weight would become Shard(1)
  490. parameter = get_packed_weights(param, empty_param, device_mesh, rank, -2)
  491. parameter = parameter.to(param_casting_dtype)
  492. if to_contiguous:
  493. parameter = parameter.contiguous()
  494. if self.use_dtensor:
  495. parameter = DTensor.from_local(parameter, device_mesh, [Shard(-2)], run_check=False)
  496. return nn.Parameter(parameter, requires_grad=parameter.is_floating_point())
  497. class RowwiseParallel(TensorParallelLayer):
  498. """
  499. Partition a compatible nn.Module in a row-wise fashion. Currently supports nn.Linear and nn.Embedding.
  500. Users can compose it with ColwiseParallel to achieve the sharding of more complicated modules.
  501. (i.e. MLP, Attention)
  502. Keyword Args:
  503. input_layouts (Placement, optional):
  504. The DTensor layout of input tensor for the nn.Module, this is used to annotate the input tensor to
  505. become a DTensor. If not specified, we assume the input tensor to be sharded on the last dimension.
  506. output_layouts (Placement, optional):
  507. The DTensor layout of the output for the nn.Module, this is used to ensure the output of the nn.Module
  508. with the user desired layout. If not specified, the output tensor is replicated.
  509. use_local_output (bool, optional):
  510. Whether to use local :class:`torch.Tensor` instead of :class:`DTensor` for the module output, default: True.
  511. Returns:
  512. A :class:`ParallelStyle` object that represents Rowwise sharding of the nn.Module.
  513. """
  514. def __init__(
  515. self,
  516. *,
  517. input_layouts: Placement | None = None,
  518. output_layouts: Placement | None = None,
  519. use_local_output: bool = True,
  520. use_dtensor=True,
  521. ):
  522. super().__init__()
  523. self.input_layouts = (input_layouts or Shard(-1),)
  524. self.output_layouts = (output_layouts or Replicate(),)
  525. self.use_local_output = use_local_output
  526. self.use_dtensor = use_dtensor
  527. def partition_tensor(self, param, empty_param, param_type, param_casting_dtype, to_contiguous, rank, device_mesh):
  528. # Rowwise shard weight to Shard(1), bias to Replicate(), weight be Shard(1)
  529. # means Rowwise as nn.Linear is input * weight^T + bias, where
  530. # weight would become Shard(0)
  531. if param_type != "bias":
  532. parameter = get_tensor_shard(param, empty_param, device_mesh, rank, -1)
  533. shard = [Shard(-1)]
  534. else:
  535. shard = [Replicate()]
  536. parameter = param[:]
  537. parameter = parameter.to(param_casting_dtype)
  538. if to_contiguous:
  539. parameter = parameter.contiguous()
  540. if self.use_dtensor:
  541. parameter = DTensor.from_local(
  542. parameter, device_mesh, shard, run_check=False, shape=empty_param.size(), stride=empty_param.stride()
  543. )
  544. return nn.Parameter(parameter, requires_grad=parameter.is_floating_point())
  545. @staticmethod
  546. def _prepare_input_fn(input_layouts, desired_input_layouts, mod, inputs, device_mesh):
  547. if hasattr(mod, "bias") and mod.bias is not None:
  548. mod._bias = mod.bias.to_local()
  549. mod.bias = None
  550. input_tensor = inputs[0]
  551. if not isinstance(input_tensor, DTensor):
  552. input_tensor = DTensor.from_local(input_tensor, device_mesh, input_layouts, run_check=False)
  553. if input_layouts != desired_input_layouts:
  554. input_tensor = input_tensor.redistribute(placements=desired_input_layouts, async_op=True)
  555. return input_tensor
  556. @staticmethod
  557. def _prepare_output_fn(output_layouts, use_local_output, mod, outputs, device_mesh):
  558. # Rowwise sharding produces partial output, depending on output layouts:
  559. # 1. to replicate -> allreduce
  560. # 2. to shard -> reduce_scatter
  561. if outputs.placements != output_layouts:
  562. outputs = outputs.redistribute(placements=output_layouts, async_op=True)
  563. outputs = outputs.to_local() # otherwise the `+=` op will gather
  564. if hasattr(mod, "_bias"):
  565. outputs = outputs + mod._bias
  566. # back to local tensor if use_local_output is True
  567. return outputs
  def prepare_module_tp(self, module: nn.Module, device_mesh) -> nn.Module:
      """
      Mark `module` as processed and, in DTensor mode, register the rowwise input/output hooks on it.

      The desired runtime input layout depends on the module type:
        * `nn.Linear`    -> input sharded on its last dim, `Shard(-1)`
        * `nn.Embedding` -> input replicated, `Replicate()`
        * `nn.Parameter` -> `Shard(-1)`

      Raises:
          NotImplementedError: in DTensor mode, when `module` is none of the types above.
      """
      module._distribute_module_applied = True
      if not self.use_dtensor:
          return

      # Checked in this order; the first matching type wins.
      layout_by_type = (
          (nn.Linear, Shard(-1)),
          (nn.Embedding, Replicate()),
          (nn.Parameter, Shard(-1)),
      )
      for module_type, placement in layout_by_type:
          if isinstance(module, module_type):
              self.desired_input_layouts = (placement,)
              break
      else:
          raise NotImplementedError("RowwiseParallel currently only support nn.Linear and nn.Embedding!")

      input_hook = partial(self._prepare_input_fn, self.input_layouts, self.desired_input_layouts)
      output_hook = partial(self._prepare_output_fn, self.output_layouts, self.use_local_output)
      distribute_module(module, device_mesh, input_hook, output_hook)
  class PackedRowwiseParallel(RowwiseParallel):
      """
      Rowwise style for packed parameters: the local piece comes from `get_packed_weights` instead of
      `get_tensor_shard`. Everything else is inherited from `RowwiseParallel`.
      """

      def partition_tensor(self, param, empty_param, param_type, param_casting_dtype, to_contiguous, rank, device_mesh):
          """
          Return this rank's piece of a packed parameter, taken along the last dimension (dim `-1`).

          Unlike the parent class, `param_type` is not inspected: every parameter goes through
          `get_packed_weights` and is tagged `Shard(-1)`. The `DTensor` is built without an explicit global
          shape/stride.
          """
          local_piece = get_packed_weights(param, empty_param, device_mesh, rank, -1).to(param_casting_dtype)
          if to_contiguous:
              local_piece = local_piece.contiguous()
          if not self.use_dtensor:
              return nn.Parameter(local_piece, requires_grad=local_piece.is_floating_point())

          distributed = DTensor.from_local(local_piece, device_mesh, [Shard(-1)], run_check=False)
          return nn.Parameter(distributed, requires_grad=distributed.is_floating_point())
  class SequenceParallel(TensorParallelLayer):
      """
      SequenceParallel keeps the parameters of a module replicated and runs its computation with the input
      sharded on the sequence dimension.

      This style follows the operation described in the paper
      `Reducing Activation Recomputation in Large Transformer Models <https://huggingface.co/papers/2205.05198>`__

      Behaviour of this implementation:

      - Input: a plain :class:`torch.Tensor` is wrapped as a :class:`DTensor` with the `Replicate()` layout and
        then redistributed to `Shard(sequence_dim)`.
      - Output: always redistributed to `Replicate()` and returned as a local :class:`torch.Tensor`.
      - Parameters: loaded whole on every rank and tagged `Replicate()`.

      Keyword Args:
          sequence_dim (int, optional):
              The sequence dimension of the input tensor for the module; the input is sharded on this dimension,
              default: 1.
          use_local_output (bool, optional):
              Stored on the instance. NOTE: `_prepare_output_fn` does not consult it and always returns a local
              tensor. Default: False.
          use_dtensor (bool, optional):
              Accepted for signature compatibility only; `self.use_dtensor` is always set to `True`.

      .. note:: Parameters are wrapped with ``run_check=False``, so nothing verifies that the local copies held
          by the different ranks are actually identical.
      """

      def __init__(self, *, sequence_dim: int = 1, use_local_output: bool = False, use_dtensor=False):
          super().__init__()
          self.input_layouts = (Replicate(),)
          # Honour `sequence_dim` rather than hard-coding dim 1 (identical for the default value).
          self.desired_input_layouts = (Shard(sequence_dim),)
          self.output_layouts = (Replicate(),)
          self.use_local_output = use_local_output
          # Deliberately not taken from the `use_dtensor` argument: this style always works on DTensors.
          self.use_dtensor = True
          self.sequence_sharding = (Shard(sequence_dim),)

      @staticmethod
      def _prepare_input_fn(input_layouts, desired_input_layouts, mod, inputs, device_mesh):
          """
          Pre-forward hook: wrap `inputs[0]` as a `DTensor` if needed and redistribute it to
          `desired_input_layouts` when that differs from `input_layouts`. Only the first input is processed.
          """
          input_tensor = inputs[0]
          if not isinstance(input_tensor, DTensor):
              input_tensor = DTensor.from_local(input_tensor, device_mesh, input_layouts, run_check=False)
          if input_layouts != desired_input_layouts:
              input_tensor = input_tensor.redistribute(placements=desired_input_layouts, async_op=True)
          return input_tensor

      @staticmethod
      def _prepare_output_fn(output_layouts, use_local_output, mod, outputs, device_mesh):
          """
          Post-forward hook: replicate the output and hand back the local tensor.

          `output_layouts` and `use_local_output` are accepted but not consulted. The output is replicated
          because the next layer may not be sharded.
          """
          outputs = outputs.redistribute(placements=(Replicate(),), async_op=True)
          return outputs.to_local()

      def partition_tensor(self, param, empty_param, param_type, param_casting_dtype, to_contiguous, rank, device_mesh):
          """
          Load the full parameter on this rank (no slicing), cast it, and tag it as replicated.

          `empty_param`, `param_type` and `rank` are not used. Returns an `nn.Parameter` whose `requires_grad`
          is enabled only for floating point data.
          """
          parameter = param[...]
          parameter = parameter.to(param_casting_dtype)
          if to_contiguous:
              parameter = parameter.contiguous()
          if self.use_dtensor:
              parameter = DTensor.from_local(parameter, device_mesh, [Replicate()], run_check=False)
          return nn.Parameter(parameter, requires_grad=parameter.is_floating_point())
  class GroupedGemmParallel(TensorParallelLayer):
      """
      Applies Expert Parallelism to MoE experts by loading the correct experts on each device.
      """

      def __init__(self):
          super().__init__()
          # Expert slices are kept as plain local tensors.
          self.use_dtensor = False

      def partition_tensor(self, param, empty_param, param_type, param_casting_dtype, to_contiguous, rank, device_mesh):
          """
          Slice out the contiguous group of experts owned by `rank` along dim 0.

          The expert count is read from `empty_param.shape[0]` and must be a multiple of `device_mesh.size()`.
          The returned value is the cast slice itself; it is not wrapped in an `nn.Parameter` here.

          Raises:
              ValueError: if the number of experts is not divisible by the number of devices.
          """
          global_num_experts = empty_param.shape[0]
          experts_per_rank, leftover = divmod(global_num_experts, device_mesh.size())
          if leftover != 0:
              raise ValueError(
                  f"Global number of experts must be divisible by number of devices: {global_num_experts} % {device_mesh.size()} != 0"
              )

          first_expert = rank * experts_per_rank
          last_expert = (rank + 1) * experts_per_rank
          local_experts = param[first_expert:last_expert].to(param_casting_dtype)
          return local_experts.contiguous() if to_contiguous else local_experts
  class RouterParallel(TensorParallelLayer):
      """
      Allows to reshape the router scores to support running expert parallel.
      """

      def __init__(self, *args, **kwargs):
          # Initialise the base class like every other parallel style in this file does.
          super().__init__()
          self.args = args
          self.kwargs = kwargs
          self.use_dtensor = False

      @staticmethod
      def _prepare_input_fn(input_layouts, desired_input_layouts, mod, inputs, device_mesh):
          """
          Pre-forward hook: pass the first input through unchanged.

          Raises:
              NotImplementedError: if the first input is a `DTensor`.
          """
          input_tensor = inputs[0]
          if isinstance(input_tensor, DTensor):
              raise NotImplementedError("RouterParallel does not support DTensor input for now")
          return input_tensor

      @staticmethod
      def _prepare_output_fn(output_layouts, use_local_output, mod, outputs, device_mesh):
          """
          Post-forward hook: restrict the router output to the experts that live on this rank.

          Imagine you had 4 tokens, top_k = 4, and 128 experts.
          With EP = 8, num_local_experts is 128 / 8 = 16.
          Imagine router_indices being:
                  [ 52,  42, 119,  67],
                  [102,  89,  61,  40],
                  [ 82, 103,   4,  34],
                  [ 93,  23, 109,  11],

          then (index // num_local_experts) tells which rank owns each selected expert:
                  [3, 2, 7, 4],
                  [6, 5, 3, 2],
                  [5, 6, 0, 2],
                  [5, 1, 6, 0],

          For rank 0, every index owned by another rank is replaced with 16 (num_local_experts):
                  [ 16, 16, 16, 16],
                  [ 16, 16, 16, 16],
                  [ 16, 16,  4, 16],
                  [ 16, 16, 16, 11],

          For the other ranks the kept indices are additionally reduced modulo num_local_experts, because the
          next operation one-hot encodes the router index vector. This tells us directly which local expert is
          hit. num_local_experts is used as the masking class of that one-hot encoding, so the computation skips
          the masked entries.

          Similarly, the scores are sliced down to the columns of the local experts.

          The kinda naive training loop used for device_map "auto" follows a similar logic. Here each rank is
          made to believe it is alone and computes its own part of the hidden states.

          `output_layouts` and `use_local_output` are not used.

          Raises:
              ValueError: if `mod.num_experts` is not divisible by the mesh size.
          """
          ep_rank, ep_size = device_mesh.get_local_rank(), device_mesh.size()
          if mod.num_experts % ep_size != 0:
              raise ValueError(
                  f"The number of experts must be divisible by number of ep_size: {mod.num_experts} % {ep_size} != 0"
              )
          num_local_experts = mod.num_experts // ep_size
          router_scores, router_indices = outputs
          # Keep only the score columns of the experts owned by this rank.
          router_scores = router_scores[:, ep_rank * num_local_experts : (ep_rank + 1) * num_local_experts]
          # Mark every index owned by another rank with the temporary sentinel -1.
          router_indices = router_indices.masked_fill((router_indices // num_local_experts) != ep_rank, -1)
          # As -1 % 1 is 0, we can only use mask fill when num_local_experts is 1
          if num_local_experts > 1:
              # fmod keeps the sign of the dividend, so the -1 sentinel survives the reduction.
              router_indices = torch.fmod(router_indices, num_local_experts)
          else:
              router_indices = router_indices.masked_fill(router_indices > 0, 0).masked_fill(router_indices < 0, -1)
          router_indices = router_indices.masked_fill(
              router_indices == -1, num_local_experts
          )  # masking class for one hot
          return router_scores, router_indices

      def partition_tensor(self, param, empty_param, param_type, param_casting_dtype, to_contiguous, rank, device_mesh):
          """
          Load the whole parameter (no slicing) and cast it; the result is not wrapped in an `nn.Parameter`.
          """
          # TODO: i'd like for this to be the default
          param = param[...].to(param_casting_dtype)
          if to_contiguous:
              param = param.contiguous()
          return param

      def prepare_module_tp(self, module: nn.Module, device_mesh) -> nn.Module:
          """
          Register the router input/output hooks on `module`. The layout arguments of both hooks are bound to
          `None` because neither hook uses them.
          """
          # TODO: need an abstract Parallel class that is different from TensorParallelLayer
          distribute_module(
              module,
              device_mesh,
              partial(self._prepare_input_fn, None, None),
              partial(self._prepare_output_fn, None, None),
          )
  class ParallelInterface(GeneralInterface):
      # `_global_mapping` lives on the class, so that a call to `register` is reflected into all other files
      # correctly, even if a new instance is created (in order to locally override a given entry).
      #
      # The style objects are only instantiated when torch >= 2.5 and torch.distributed is available;
      # otherwise the registry starts out empty.
      if is_torch_greater_or_equal("2.5") and _torch_distributed_available:
          _global_mapping = {
              "colwise": ColwiseParallel(),
              "rowwise": RowwiseParallel(),
              "colwise_rep": ColwiseParallel(output_layouts=Replicate()),
              "rowwise_rep": RowwiseParallel(input_layouts=Replicate()),
              "local_colwise": ColwiseParallel(use_dtensor=False),
              "local_rowwise": RowwiseParallel(use_dtensor=False),
              "local": IsolatedParallel(),
              "gather": GatherParallel(),
              "local_packed_rowwise": PackedRowwiseParallel(use_dtensor=False),
              "sequence_parallel": SequenceParallel(),
              "replicate": ReplicateParallel(),
              "grouped_gemm": GroupedGemmParallel(),
              "ep_router": RouterParallel(),
          }
      else:
          _global_mapping = {}


  # Shared registry instance used to resolve a TP plan string to its parallel style object.
  ALL_PARALLEL_STYLES: ParallelInterface = ParallelInterface()
  def convert_local_tensor_to_dtensor(
      parameter: torch.Tensor, parameter_name: str, device_mesh, tp_plan: dict[str, str]
  ) -> DTensor:
      """
      Converts a local variant of weights to a DTensor with corresponding placements. Shouldn't be done ever except of before saving the model.

      Only parameters whose plan is one of the `local_*` styles are converted; anything else (including a
      parameter without a plan) is returned unchanged.

      Args:
          parameter: The local tensor held by this rank.
          parameter_name: Name of the parameter; its last dot-separated component is the parameter type
              (e.g. `"weight"` or `"bias"`). A name without a dot is used as the type itself.
          device_mesh: Mesh the resulting `DTensor` is placed on.
          tp_plan: Mapping used by `_get_parameter_tp_plan` to look up the style of `parameter_name`.

      Returns:
          A `DTensor` built with `run_check=False`, or `parameter` itself when no conversion applies.
      """
      # `rsplit(...)[-1]` also covers names without a dot, where 2-tuple unpacking used to fail.
      param_type = parameter_name.rsplit(".", 1)[-1]
      tp_style = _get_parameter_tp_plan(parameter_name, tp_plan)
      if tp_style not in ("local_packed_rowwise", "local_rowwise", "local_colwise"):
          return parameter

      # TODO: this logic should be wrapped in a function, this is copied from corresponding tp classes.
      is_bias = param_type == "bias"
      if tp_style == "local_packed_rowwise":
          placements = [Shard(-1)]
      elif tp_style == "local_rowwise":
          placements = [Replicate()] if is_bias else [Shard(-1)]
      else:  # "local_colwise"
          placements = [Shard(-1)] if is_bias else [Shard(-2)]
      return DTensor.from_local(parameter, device_mesh, placements, run_check=False)
  def replace_state_dict_local_with_dtensor(
      state_dict: dict[str, torch.Tensor],
      tp_plan: dict[str, str],
      device_mesh,
  ) -> dict[str, torch.Tensor]:
      """
      Replaces all tensors that were sharded with `local_*` strategy with DTensor to make determining their proper size possible.

      The dictionary is updated in place and also returned. Values that are not tensors, or that already are
      `DTensor`s, are left untouched.
      """
      for name, value in state_dict.items():
          is_plain_tensor = isinstance(value, torch.Tensor) and not isinstance(value, DTensor)
          if not is_plain_tensor:
              continue
          # Only existing keys are re-assigned, so the dict size never changes during iteration.
          state_dict[name] = convert_local_tensor_to_dtensor(value, name, device_mesh, tp_plan)
      return state_dict
  def add_tensor_parallel_hooks_to_module(
      model, module, tp_plan, layer_name, current_module_plan, device_mesh, parameter_name=None
  ):
      r"""
      This function is called in `PretrainedModel.post_init()`. It is responsible of adding hooks
      to the modules of the `model`, based on the `PretrainedModel._tp_plan`.

      This is the place where we add the `pre_forward` and `post_forwards` hooks. These are defined
      for each `TensorParallelLayer` as `_prepare_input_fn` and `_prepare_output_fn`.

      Args:
          model, tp_plan, parameter_name: Accepted for interface compatibility; not used here.
          module: Module to prepare. Nothing is done when `current_module_plan` is `None`.
          layer_name: Only used in the diagnostic printed when the style rejects the module.
          current_module_plan: Key of the parallel style in `ALL_PARALLEL_STYLES`, or `None`.
          device_mesh: Mesh forwarded to the style's `prepare_module_tp`.
      """
      if current_module_plan is None:
          return

      tp_layer = ALL_PARALLEL_STYLES[current_module_plan]
      try:
          tp_layer.prepare_module_tp(module, device_mesh)
      except NotImplementedError as e:
          # Best effort: an unsupported module is reported but does not abort the setup.
          print(
              f"Trying to prepare {layer_name}, but it's not supported. Corresponding module: {module} Fix it's TP plan: {e}"
          )

      module._hf_tp_plan = current_module_plan
      # Go through the class-level `__repr__`: looking up `module.__repr__` inside the lambda would find the
      # lambda itself (it is stored on the instance) and recurse forever.
      module.__repr__ = lambda: f"{type(module).__repr__(module)}\nTP Plan: {current_module_plan}"
  def shard_and_distribute_module(
      model, param, empty_param, parameter_name, param_casting_dtype, is_contiguous, rank, device_mesh
  ):  # TODO: rename to shard_and_distribute_param
      r"""
      This function is called in `from_pretrained` when loading a model's checkpoints.
      It receives the pointer to the parameter (or the parameter itself) and takes care of "sharding".
      All process run this function, so they just load the partition of the tensor that they require.

      Main uses cases:
      - column / rowise parallelism, you just shard all the weights of the layer (weight and bias)
      - packed layers: you slice the weights, then shard like above
      - custom operation:
          - you want to add an all-gather at the end of a local layer.
          - you want to have a layer that is isolated from the rest of the world (because torch.DTensor does not work well with `.view` for instance)

      Args:
          model: Model owning the parameter; provides `tp_plan` and `get_submodule`.
          param: Source of the checkpoint values for this parameter.
          empty_param: Placeholder parameter, forwarded to the style and used to decide `requires_grad`.
          parameter_name: Fully qualified name, e.g. `"<module path>.<param type>"`. A name without a dot
              designates a parameter that lives directly on `model`.
          param_casting_dtype: Dtype the loaded values are cast to.
          is_contiguous: Forwarded to the style as its `to_contiguous` flag.
          rank: Rank of this process; converted with `int(...)`.
          device_mesh: Mesh forwarded to the style.

      Returns:
          The `torch.nn.Parameter` that was set on the owning module.
      """
      if "." in parameter_name:
          param_name, param_type = parameter_name.rsplit(".", 1)
      else:
          # No module prefix: unpacking the bare string would fail, so address the model itself.
          param_name, param_type = "", parameter_name
      tp_plan = model.tp_plan or {}
      module_to_tp = model.get_submodule(param_name)  # TODO: can i loop over modules?
      rank = int(rank)
      current_shard_plan = _get_parameter_tp_plan(parameter_name, tp_plan)

      if dist.get_rank() == 0:
          if current_shard_plan is None:
              logger.info(f"Tensor sharding plan for {param_name} not found, using default 'replicate' plan.")
          else:
              logger.info(f"Tensor sharding plan for {param_name}: {current_shard_plan}")

      if current_shard_plan is not None:
          # Looked up outside the `try` so the name is always bound when the handler formats its message.
          tp_layer = ALL_PARALLEL_STYLES[current_shard_plan]
          try:
              param = tp_layer.partition_tensor(
                  param, empty_param, param_type, param_casting_dtype, is_contiguous, rank, device_mesh
              )
          except NotImplementedError as e:
              print(
                  f"Trying to prepare {parameter_name}, but it's not supported. Corresponding module: {module_to_tp} Fix it's TP plan, current layer: {tp_layer} : {e}"
              )
      else:
          # No plan: load the full tensor on every rank.
          param = param[:].to(param_casting_dtype)

      # SUPER IMPORTANT we have to use setattr
      # otherwise loading is crazy slow
      if not isinstance(param, torch.nn.Parameter):
          param = torch.nn.Parameter(param, requires_grad=empty_param.is_floating_point())
      setattr(module_to_tp, param_type, param)
      # module_to_tp.load_state_dict({param_type: param}, strict=False, assign=True)
      return param
  889. def verify_tp_plan(expected_keys: list[str], tp_plan: dict[str, str] | None):
  890. """
  891. Verify the TP plan of the model, log a warning if the layers that were not sharded and the rules that were not applied.
  892. """
  893. if tp_plan is None:
  894. return
  895. generic_keys = {re.sub(r"\d+", "*", key) for key in expected_keys}
  896. unsharded_layers = set(generic_keys)
  897. unused_rules = tp_plan
  898. for key in generic_keys:
  899. param_name = key.rsplit(".", 1)[0] if "." in key else key
  900. generic_param_name = re.sub(r"\d+", "*", param_name)
  901. if generic_param_name in tp_plan:
  902. unused_rules.pop(generic_param_name)
  903. unsharded_layers.discard(key)
  904. elif "." in generic_param_name and (parent_param_name := generic_param_name.rsplit(".", 1)[0]) in tp_plan:
  905. unused_rules.pop(parent_param_name)
  906. unsharded_layers.discard(key)
  907. else:
  908. pass # we couldn't find the rule for this parameter, so it's not sharded
  909. if len(unused_rules) > 0:
  910. logger.warning(f"The following TP rules were not applied on any of the layers: {unused_rules}")
  911. if len(unsharded_layers) > 0:
  912. logger.warning(f"The following layers were not sharded: {', '.join(unsharded_layers)}")
  def distribute_model(model, distributed_config, device_mesh, tp_size):
      """
      Attach the tensor parallel settings to `model` and hook every module that has a TP plan.

      Args:
          model: Model to prepare; `_tp_size` and `_device_mesh` are set on it.
          distributed_config: `None`, a dict (converted with `DistributedConfig.from_dict`) or a config
              object; when not `None` it is stored on `model.config.distributed_config`.
          device_mesh: Mesh stored on the model and forwarded to the module hooks.
          tp_size: Tensor parallel size stored on the model.

      Returns:
          The same `model` object.

      Raises:
          ValueError: if the model's plan uses a style that is not in `ALL_PARALLEL_STYLES`.
      """
      model._tp_size = tp_size
      model._device_mesh = device_mesh
      if distributed_config is not None:
          if isinstance(distributed_config, dict):
              distributed_config = DistributedConfig.from_dict(distributed_config)
          model.config.distributed_config = distributed_config

      model_plan = model.tp_plan
      if model_plan is not None and is_torch_greater_or_equal("2.5") and _torch_distributed_available:
          # Validate the whole plan before touching any module.
          for v in model_plan.values():
              if v not in ALL_PARALLEL_STYLES:
                  raise ValueError(f"Unsupported tensor parallel style {v}. Supported styles are {ALL_PARALLEL_STYLES}")

          for name, module in model.named_modules():
              if not getattr(module, "_is_hooked", False):
                  plan = _get_parameter_tp_plan(parameter_name=name, tp_plan=model_plan, is_weight=False)
                  add_tensor_parallel_hooks_to_module(
                      model=model,
                      module=module,
                      tp_plan=model_plan,
                      # Pass the real module name so the "not supported" diagnostic identifies the layer.
                      layer_name=name,
                      current_module_plan=plan,
                      device_mesh=device_mesh,
                  )
              # `_is_hooked` guards against registering the hooks twice on a later call.
              module._is_hooked = True
      return model