| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693694695696697698699700701702
70370470570670770870971071171271371471571671771871972072172272372472572672772872973073173273373473573673773873974074174274374474574674774874975075175275375475575675775875976076176276376476576676776876977077177277377477577677777877978078178278378478578678778878979079179279379479579679779879980080180280380480580680780880981081181281381481581681781881982082182282382482582682782882983083183283383483583683783883984084184284384484584684784884985085185285385485585685785885986086186286386486586686786886987087187287387487587687787887988088188288388488588688788888989089189289389489589689789889990090190290390490590690790890991091191291391491591691791891992092192292392492592692792892993093193293393493593693793893994094194294394494594694794894995095195295395495595695795895996096196296396496596696796896997097197297397497597697797897998098198298398498598698798898999099199299399499599699799899910001001100210031004100510061007100810091010101110121013101410151016101710181019102010211022102310241025102610271028102910301031103210331034103510361037103810391040104110421043104410451046104710481049105010511052105310541055105610571058105910601061106210631064106510661067106810691070107110721073107410751076107710781079108010811082108310841085108610871088108910901091109210931094109510961097109810991100110111021103110411051106110711081109111011111112111311141115111611171118111911201121112211231124112511261127112811291130113111321133113411351136113711381139114011411142114311441145114611471148114911501151115211531154115511561157115811591160116111621163116411651166116711681169 |
- # mypy: allow-untyped-defs
- # Copyright (c) Meta Platforms, Inc. and affiliates
- import logging
- import math
- import os
- import threading
- import warnings
- from collections.abc import Iterator
- from functools import reduce
- from itertools import chain, zip_longest
- from typing import Optional, TYPE_CHECKING, Union
- import torch
- from torch.distributed import is_available
- from torch.utils._typing_utils import not_none
- __all__ = ["init_device_mesh", "DeviceMesh"]
if not is_available():
    import sys

    # Without distributed support the real DeviceMesh cannot be built, but
    # ``torch.distributed.device_mesh`` must still expose ``DeviceMesh`` and
    # ``init_device_mesh``: the doc tests (``./.ci/pytorch/docs-test.sh``)
    # import those names. Publish placeholder objects on the module instead.
    class _DeviceMeshStub:
        pass

    def _init_device_mesh_stub():
        pass

    setattr(
        sys.modules["torch.distributed.device_mesh"],
        "DeviceMesh",
        _DeviceMeshStub,
    )
    setattr(
        sys.modules["torch.distributed.device_mesh"],
        "init_device_mesh",
        _init_device_mesh_stub,
    )
- else:
- from torch._C._distributed_c10d import Backend as C10dBackend
- from torch.distributed.distributed_c10d import (
- _get_default_group,
- _resolve_process_group,
- get_backend,
- get_process_group_ranks,
- get_rank,
- get_world_size,
- init_process_group,
- is_initialized,
- new_group,
- ProcessGroup,
- split_group,
- )
logger = logging.getLogger(name=__name__)

# ``ArrayLike`` appears only in string annotations (e.g. DeviceMesh.__init__),
# so numpy.typing is imported for static type checkers only; this branch never
# runs at runtime because TYPE_CHECKING is False there.
if TYPE_CHECKING:
    try:
        from numpy.typing import ArrayLike
    except ImportError:
        logger.warning(
            "DeviceMesh requires numpy >= 1.21 to be installed for type checking"
        )
- class _MeshEnv(threading.local):
- def __init__(self) -> None:
- self.mesh_stack: list[DeviceMesh] = []
- self.child_to_root_mapping: dict[DeviceMesh, DeviceMesh] = {}
- self.mesh_dim_group_options: dict[
- int, tuple[Optional[str], Optional[C10dBackend.Options]]
- ] = {}
- self.root_to_flatten_mapping: dict[DeviceMesh, dict[str, DeviceMesh]] = {}
- # Record flatten mesh name to its mesh dim index in root mesh.
- self.flatten_name_to_root_dims: dict[
- DeviceMesh, dict[str, tuple[int, ...]]
- ] = {}
- def get_current_mesh(self) -> "DeviceMesh":
- if len(self.mesh_stack) == 0:
- raise RuntimeError("No device mesh is currently active!")
- return self.mesh_stack[-1]
- def create_sub_mesh(
- self,
- device_mesh: "DeviceMesh",
- submesh_dim_names: tuple[str, ...],
- submesh_dims: list[tuple[int, ...]],
- ) -> "DeviceMesh":
- # Get the submesh dim size from the submesh_dims.
- # For example, if we have a 3D mesh with mesh_shape (2, 2, 2) mesh_dim_names ("dp", "cp", "tp") and we want
- # to slice out mesh["dp_cp"], then submesh_dims = [(0, 1), (2,)] and submesh_dim_size = [2 * 2, 2] = [4, 2].
- # If we want to slice out mesh["dp", "cp"], then submesh_dims = [(0,), (1,)] and submesh_dim_size = [2, 2].
- slice_dim_size = [
- reduce(
- lambda x, y: x * device_mesh.mesh.size(y),
- mesh_dim,
- 1,
- )
- for mesh_dim in submesh_dims
- ]
- mesh_tensor = device_mesh.mesh
- # slice_dim_idx could be different from submesh_dims, as we may need to flatten out some dims.
- slice_dim_idx = []
- slice_dim_group_name = []
- # keep track of the number of dims that have been flattened so we can get the correct slice_dim_idx in the
- # flattened mesh tensor.
- num_dims_flatten = 0
- for mesh_dim_indices, mesh_dim_name in zip(submesh_dims, submesh_dim_names):
- # Currently, this only allows slicing out a contiguous flattened dim.
- # TODO: we need to handle reconstructing a non-contiguous flattened dim.
- if len(mesh_dim_indices) > 1:
- # We need to move the start_dim and end_dim to the left if some dims are already flattened.
- mesh_tensor = mesh_tensor.flatten(
- start_dim=mesh_dim_indices[0] - num_dims_flatten,
- end_dim=mesh_dim_indices[-1] - num_dims_flatten,
- )
- # If some dims are already flattened, we need to adjust the slice_dim_idx accordingly.
- # For example, if the submesh_dims = [(0, 1), (2,), (3, 4)] with 0-1 flattened and 3-4 flattened,
- # then the final slice_dim_idx should be [0, 1, 2].
- slice_dim_idx.append(mesh_dim_indices[0] - num_dims_flatten)
- num_dims_flatten += len(mesh_dim_indices) - 1
- slice_dim_group_name.append(
- self.root_to_flatten_mapping[device_mesh][
- mesh_dim_name
- ]._dim_group_names[0] # type: ignore[has-type]
- )
- else:
- slice_dim_idx.append(mesh_dim_indices[0] - num_dims_flatten)
- slice_dim_group_name.append(
- device_mesh._dim_group_names[mesh_dim_indices[0]] # type: ignore[has-type]
- )
- # mesh_tensor has already been flattened if needed. So mesh_tensor.ndim <= device_mesh.mesh.ndim now.
- mesh_dims_remained_idx = list(range(mesh_tensor.ndim))
- for idx in slice_dim_idx:
- if idx not in mesh_dims_remained_idx:
- raise NotImplementedError(
- "Currently, this only allows slicing out a contiguous flattened dim."
- )
- mesh_dims_remained_idx.remove(idx)
- # pg_ranks_by_dim is the size of [number of local ranks of the outermost submesh dimension, *slice_dim_idx]
- # This means on each local rank of the outermost slice mesh dim, we have a tensor of submesh size with
- # the pg ranks of the submesh. From this, we can extract the submesh mesh tensor contains the current rank.
- pg_ranks_by_dim = mesh_tensor.permute(
- *mesh_dims_remained_idx, *slice_dim_idx
- ).reshape(-1, *slice_dim_size)
- cur_rank = device_mesh.get_rank()
- for mesh_nd in pg_ranks_by_dim:
- submesh = DeviceMesh(
- device_mesh.device_type,
- mesh_nd,
- mesh_dim_names=submesh_dim_names,
- _init_backend=False,
- )
- if cur_rank in mesh_nd:
- res_submesh = submesh
- res_submesh._dim_group_names = slice_dim_group_name # type: ignore[possibly-undefined, has-type]
- self.child_to_root_mapping[res_submesh] = device_mesh
- return res_submesh
- def create_flatten_mesh(
- self,
- device_mesh: "DeviceMesh",
- mesh_dim_name: Optional[str] = None,
- backend_override: tuple[Optional[str], Optional[C10dBackend.Options]] = (
- None,
- None,
- ),
- ) -> "DeviceMesh":
- root_mesh = _mesh_resources.get_root_mesh(device_mesh)
- flatten_dims_in_root = [
- not_none(root_mesh.mesh_dim_names).index(flatten_mesh_dim_name)
- for flatten_mesh_dim_name in not_none(device_mesh.mesh_dim_names)
- ]
- if not mesh_dim_name:
- mesh_dim_name = "_".join(not_none(device_mesh.mesh_dim_names))
- # Check whether the mesh_dim_name for flattened mesh is valid.
- self.flatten_name_to_root_dims.setdefault(root_mesh, {})
- invalid_dim_names = chain(
- list(not_none(root_mesh.mesh_dim_names)),
- *self.flatten_name_to_root_dims[root_mesh].keys(),
- )
- if mesh_dim_name in invalid_dim_names:
- raise RuntimeError(
- f"{mesh_dim_name} already exists for submesh of the {root_mesh}. ",
- f"The mesh_dim_names of submesh and flattened mesh are {invalid_dim_names}. "
- f"Please specify another valid mesh_dim_name.",
- )
- # Quick return if the flatten mesh has been created before.
- # TODO: If we decide to restrict flatten initialization once, we should remove
- # this check and throw an error if the flatten mesh is already created before.
- if (
- root_mesh in self.root_to_flatten_mapping
- and mesh_dim_name in self.root_to_flatten_mapping[root_mesh]
- ):
- return self.root_to_flatten_mapping[root_mesh][mesh_dim_name]
- flattened_mesh_dim_size = math.prod(device_mesh.mesh.size())
- remained_dims_in_root = list(range(root_mesh.mesh.ndim))
- for flatten_dim_in_root in flatten_dims_in_root:
- remained_dims_in_root.remove(flatten_dim_in_root)
- pg_ranks_by_dim = root_mesh.mesh.permute(
- *remained_dims_in_root, *flatten_dims_in_root
- ).reshape(-1, flattened_mesh_dim_size)
- cur_rank = root_mesh.get_rank()
- for mesh_nd in pg_ranks_by_dim:
- # need to init backend here since the flattened pg doesn't exist in root mesh.
- flattened_mesh = DeviceMesh(
- root_mesh.device_type,
- mesh_nd,
- mesh_dim_names=(mesh_dim_name,),
- backend_override=(backend_override,),
- )
- if cur_rank in mesh_nd:
- res_flattened_mesh = flattened_mesh
- self.child_to_root_mapping[res_flattened_mesh] = root_mesh # type: ignore[possibly-undefined]
- self.root_to_flatten_mapping.setdefault(root_mesh, {})[mesh_dim_name] = (
- res_flattened_mesh # type: ignore[possibly-undefined]
- )
- self.flatten_name_to_root_dims[root_mesh][mesh_dim_name] = tuple(
- flatten_dims_in_root
- ) # type: ignore[possibly-undefined]
- return res_flattened_mesh
- def get_root_mesh(self, device_mesh: "DeviceMesh") -> "DeviceMesh":
- # If a mesh could not be found in the child_to_root_mapping, it is a root mesh itself.
- # A root mesh is not created through slicing.
- # We considers the root mesh of a root mesh is itself.
- root_mesh = self.child_to_root_mapping.get(device_mesh, None)
- return device_mesh if not root_mesh else root_mesh
- def get_root_mesh_dim(self, device_mesh: "DeviceMesh") -> Optional[int]:
- """
- Returns the index of the mesh dim in the root mesh.
- The device_mesh passed in needs to be sliced out from the root mesh
- or submesh of the root mesh.
- """
- root_mesh = self.get_root_mesh(device_mesh)
- child_mesh_dim_names = device_mesh.mesh_dim_names
- if root_mesh and child_mesh_dim_names:
- assert len(child_mesh_dim_names) == 1, (
- "The submesh can only be a 1D mesh."
- )
- child_mesh_dim_name = child_mesh_dim_names[0]
- return self.get_mesh_dim_by_name(root_mesh, child_mesh_dim_name)
- return None
- @staticmethod
- def num_devices_per_host(device_type: str) -> int:
- return _get_device_handle(device_type).device_count()
- @staticmethod
- def num_hosts(device_type: str) -> int:
- # ProcessGroup can't tell us this info so we have to infer it, assume
- # homogeneous hardware for now
- return get_world_size() // _MeshEnv.num_devices_per_host(device_type)
- def get_mesh_dim_by_name(
- self, device_mesh: "DeviceMesh", mesh_dim_name: str
- ) -> int:
- if (
- device_mesh.mesh_dim_names is None
- or len(device_mesh.mesh_dim_names) == 0
- ):
- raise KeyError(
- "No `mesh_dim_names` found.",
- )
- if mesh_dim_name not in device_mesh.mesh_dim_names:
- raise KeyError(
- f"Mesh dimension '{mesh_dim_name}' does not exist.",
- f"Available mesh dimensions are: mesh_dim_names={device_mesh.mesh_dim_names}",
- )
- return not_none(device_mesh.mesh_dim_names.index(mesh_dim_name))
- def _set_mesh_dim_group_options(
- self,
- dim: int,
- backend: Optional[str],
- pg_options: Optional[C10dBackend.Options] = None,
- ) -> None:
- self.mesh_dim_group_options[dim] = (backend, pg_options)
- def _get_slice_mesh_dims(
- self, device_mesh, mesh_dim_names
- ) -> list[tuple[int, ...]]:
- """
- Validate whether the mesh_dim_names is valid for slicing the given device_mesh.
- If valid, return dim indexes of the slice mesh in the device mesh.
- """
- if device_mesh != self.get_root_mesh(device_mesh):
- warnings.warn(
- "You are attempting to slice a submesh from another submesh. While we support this operation, "
- "it is users' responsibility to ensure that the submesh is consistently sliced across all ranks. "
- "If not, this may result in some ranks receiving the submesh while others encounter errors."
- )
- # The slice mesh_dim_names should consist either the device_mesh's mesh_dim_names
- # or its flattened mesh's mesh_dim_names.
- self.flatten_name_to_root_dims.setdefault(device_mesh, {})
- flatten_name_to_root_dims = self.flatten_name_to_root_dims[device_mesh]
- valid_mesh_dim_names = [
- *device_mesh.mesh_dim_names,
- *flatten_name_to_root_dims,
- ]
- if not all(
- mesh_dim_name in valid_mesh_dim_names
- for mesh_dim_name in mesh_dim_names
- ):
- raise KeyError(
- f"Invalid mesh_dim_names {mesh_dim_names} specified. "
- f"Valid mesh_dim_names are {valid_mesh_dim_names}."
- )
- # Validate the order of the slice mesh dim indices.
- # This needs to be in ascending order.
- curr_idx = -1
- slice_mesh_dims = []
- for mesh_dim_name in mesh_dim_names:
- if mesh_dim_name in flatten_name_to_root_dims:
- mesh_indices = flatten_name_to_root_dims[mesh_dim_name]
- # TODO: this doesn't allow non-contiguous slicing with flatten dim yet. next_idx
- # should be mesh_indices[0] once we support non-contiguous slicing with flatten dim.
- next_idx = mesh_indices[-1]
- slice_mesh_dims.append(mesh_indices)
- else:
- next_idx = device_mesh.mesh_dim_names.index(mesh_dim_name)
- slice_mesh_dims.append((next_idx,))
- if next_idx <= curr_idx:
- raise KeyError(
- f"Invalid mesh_dim_names {mesh_dim_names} specified. "
- f"Found mesh dim indices to slice: {slice_mesh_dims}. "
- "Mesh dim indices should be in ascending order."
- )
- curr_idx = next_idx
- return slice_mesh_dims
- def _get_all_submeshes(
- self, device_mesh: "DeviceMesh", mesh_dim_name: str
- ) -> list["DeviceMesh"]:
- """
- Return all the submeshes of a given mesh dimension of the device mesh.
- """
- mesh_dim = self.get_mesh_dim_by_name(device_mesh, mesh_dim_name)
- pg_ranks_by_dim = device_mesh.mesh.swapdims(-1, mesh_dim).reshape(
- -1, device_mesh.mesh.size(mesh_dim)
- )
- cur_rank = device_mesh.get_rank()
- res_submeshes = []
- for mesh_1d in pg_ranks_by_dim:
- submesh = DeviceMesh(
- device_mesh.device_type,
- mesh_1d,
- mesh_dim_names=(mesh_dim_name,),
- _init_backend=False,
- )
- submesh._dim_group_names = (
- [device_mesh._dim_group_names[mesh_dim]] # type: ignore[has-type]
- if cur_rank in mesh_1d
- else []
- )
- res_submeshes.append(submesh)
- return res_submeshes
- _mesh_resources: _MeshEnv = _MeshEnv()
def _get_device_handle(device_type: str = "cuda"):
    """
    Look up the torch submodule that drives ``device_type``.

    For a cuda or cuda-like device type the matching module is returned, e.g.
    ``torch.cuda`` for ``"cuda"``. When ``torch`` has no attribute named
    ``device_type``, ``None`` is returned instead of raising.
    """
    try:
        return getattr(torch, device_type)
    except AttributeError:
        return None
- class DeviceMesh:
- """
- DeviceMesh represents a mesh of devices, where the layout of devices can be
- represented as an n-dimensional array, and each value of the n-dimensional
- array is a global rank in the default process group.
- DeviceMesh can be used to set up the N-dimensional device connections across the cluster,
- and to manage the ProcessGroups for N-dimensional parallelisms. Communication can happen on
- each dimension of the DeviceMesh separately. DeviceMesh respects the device that the user has
- already selected (e.g. if the user calls `torch.cuda.set_device` before the DeviceMesh initialization),
- and will select/set the device for the current process if the user does not set the device
- beforehand. Note that manual device selection should happen BEFORE the DeviceMesh initialization.
- DeviceMesh can also be used as a context manager when using together with DTensor APIs.
- .. note::
- DeviceMesh follows SPMD programming model, which means the same PyTorch Python program
- is running on all processes/ranks in the cluster. Therefore, users need to make sure the
- `mesh` array (which describes the layout of devices) should be identical across all ranks.
- Inconsistent `mesh` will lead to silent hang.
- Args:
- device_type (str): The device type of the mesh. Currently supports: "cpu", "cuda/cuda-like".
- mesh (ndarray): A multi-dimensional array or an integer tensor describing the layout
- of devices, where the IDs are global IDs of the default process group.
- Returns:
- DeviceMesh: A :class:`DeviceMesh` object representing the device layout.
- The following program runs on each process/rank in an SPMD manner. In this example, we have 2
- hosts with 4 GPUs each.
- A reduction over the first dimension of mesh will reduce across
- columns (0, 4), .. and (3, 7), a reduction over the second dimension
- of mesh reduces across rows (0, 1, 2, 3) and (4, 5, 6, 7).
- Example::
- >>> # xdoctest: +SKIP("no rank")
- >>> from torch.distributed.device_mesh import DeviceMesh
- >>>
- >>> # Initialize device mesh as (2, 4) to represent the topology
- >>> # of cross-host(dim 0), and within-host (dim 1).
- >>> mesh = DeviceMesh(device_type="cuda", mesh=[[0, 1, 2, 3],[4, 5, 6, 7]])
- """
- device_type: str
- mesh: torch.Tensor
- mesh_dim_names: Optional[tuple[str, ...]]
- def __init__(
- self,
- device_type: str,
- mesh: Union[torch.Tensor, "ArrayLike"],
- *,
- mesh_dim_names: Optional[tuple[str, ...]] = None,
- backend_override: Optional[
- tuple[tuple[Optional[str], Optional[C10dBackend.Options]], ...]
- ] = None,
- _init_backend: bool = True,
- ) -> None:
- self.device_type = device_type
- if isinstance(mesh, torch.Tensor) and mesh.device.type != "cpu":
- raise ValueError(f"`mesh` must be a CPU tensor, got {mesh}")
- self.mesh = (
- mesh.detach().to(dtype=torch.int)
- if isinstance(mesh, torch.Tensor)
- else torch.tensor(mesh, device="cpu", dtype=torch.int)
- )
- self.mesh_dim_names = tuple(mesh_dim_names) if mesh_dim_names else None
- if backend_override is None:
- backend_override = ((None, None),) * self.mesh.ndim
- # private field to pre-generate DeviceMesh's hash
- self._flatten_mesh_list = tuple(self.mesh.flatten().tolist())
- self._thread_id = None
- # Skip process group initialization if xla device or init backend is False
- # TODO(yeounoh) implement DeviceMesh backend and register XLA backend.
- if device_type != "xla":
- # always try to create default (world) pg, even if it is not initialized
- # already. The world pg is used for device mesh identity (rank) on each
- # process (we need to know if the current global rank is in the mesh or not).
- if _init_backend:
- self._setup_world_group_and_device()
- self._init_process_groups(backend_override)
- if is_initialized() and get_backend() == "threaded":
- self._thread_id = threading.get_ident()
- # calculate the coordinates of the current global rank on the mesh
- rank_coords = (self.mesh == get_rank()).nonzero()
- assert rank_coords.size(0) in (0, 1)
- self._coordinate_on_dim: Optional[list[int]] = (
- rank_coords[0].tolist() if rank_coords.size(0) > 0 else None
- )
def _setup_world_group_and_device(self):
    """Ensure the default process group exists and pick a device for this process.

    Initializes the default process group when needed and checks the mesh fits
    in the world size. If the device module for ``self.device_type`` reports it
    is not initialized yet, the current device is set from ``LOCAL_RANK`` when
    present, otherwise from ``global_rank % num_devices_per_host``.

    Returns:
        The default process group.

    Raises:
        RuntimeError: if the mesh has more ranks than the world size, if no
            devices of ``self.device_type`` are visible when the heuristic is
            needed, or if the world size is not a multiple of the per-host
            device count (non-homogeneous hardware).
    """
    default_initialized = is_initialized()
    # TODO: think about how to allow pg options to be passed to world group
    # or mesh dimension groups
    if not default_initialized:
        init_process_group()

    world_size = get_world_size()
    if self.mesh.numel() > world_size:
        raise RuntimeError(
            f"Mesh should not be bigger than default world size {world_size}, but found {self.mesh.numel()} ranks!"
        )

    # ONLY set the device if the current device is not initialized, if user already
    # set the device before DeviceMesh init, we respect the user's choice.
    device_handle = _get_device_handle(self.device_type)
    if device_handle and not device_handle.is_initialized():
        # auto set the cuda/cuda-like device only if user has not set it, if there's LOCAL_RANK
        # env variable from launchers, we use it to set the device.
        if "LOCAL_RANK" in os.environ:
            local_rank = int(os.environ["LOCAL_RANK"])
            logger.info(
                "Setting default device for the current process based on LOCAL_RANK=%s",
                local_rank,
            )
            device_handle.set_device(local_rank)
        else:
            warnings.warn(
                "It seems like you did not set/select the default device for the current process before the DeviceMesh "
                "initialization or use a launcher (i.e. torchrun) which populates `LOCAL_RANK` environment variable. "
                "It is recommended to set the current device for the process BEFORE the DeviceMesh initialization so that "
                "the underlying communicator (i.e. NCCL) can be initialized properly. "
                "Given that the current process has no default device selected, DeviceMesh will use a heuristic to set the "
                "device_id via `global_rank % num_devices_per_host`, assuming homogeneous hardware cluster. "
            )
            # heuristic to set the current cuda/cuda-like device base on num of gpu devices available in each host
            # NOTE: This device selection would only work for homogeneous hardware.
            num_devices_per_host = device_handle.device_count()
            if num_devices_per_host == 0:
                # Without this guard the modulo operations below raise an
                # opaque ZeroDivisionError.
                raise RuntimeError(
                    f"DeviceMesh cannot select a default {self.device_type} device: "
                    f"no {self.device_type} devices are visible to this process."
                )
            if (
                world_size > num_devices_per_host
                and world_size % num_devices_per_host != 0
            ):
                raise RuntimeError(
                    f"DeviceMesh only support homogeneous hardware, but found "
                    f"{world_size} ranks and {num_devices_per_host} {self.device_type} devices!"
                )
            device_handle.set_device(get_rank() % num_devices_per_host)

    return _get_default_group()
def _init_process_groups(
    self,
    backend_override: tuple[
        tuple[Optional[str], Optional[C10dBackend.Options]], ...
    ],
):
    """Create (or reuse) the process groups for each mesh dimension.

    For every mesh dimension, the ``group_name`` of the process group that
    contains the current rank is appended to ``self._dim_group_names`` (index
    ``i`` corresponds to mesh dimension ``i``).

    Args:
        backend_override: per-dimension ``(backend, pg_options)``;
            ``(None, None)`` means no override for that dimension.

    Raises:
        RuntimeError: if a dimension is configured both through
            ``backend_override`` and ``_mesh_resources._set_mesh_dim_group_options``,
            or if the current rank appears in more than one subgroup of the
            same dimension.
    """
    # group_name associated with each mesh dimension, each
    # mesh dimension should have one sub-group per rank
    dim_group_names: list[str] = []
    default_group = _get_default_group()

    if (
        self.mesh.ndim == 1
        and self.mesh.numel() == get_world_size()
        and _mesh_resources.mesh_dim_group_options.get(0, (None, None))
        == (None, None)
        and backend_override[0] == (None, None)
    ):
        # A 1D mesh spanning the whole world with no custom options reuses the
        # default pg, except when the default pg is gloo and CUDA is available:
        # then a new world group mapping cpu->gloo and cuda->nccl is created.
        ranks = list(range(get_world_size()))
        dim_group = (
            new_group(
                backend="cpu:gloo,cuda:nccl",
                ranks=ranks,
                group_desc="mesh_default",
            )
            if torch.cuda.is_available()
            and get_backend(default_group) == "gloo"
            else default_group
        )
        dim_group_names.append(dim_group.group_name)
    else:
        # create sub pgs base on the mesh argument specified
        for dim in range(self.mesh.ndim):
            # Swap the current dim to the last dim, then reshape to flatten out
            # the other dims: each row is the rank list of one subgroup along dim.
            pg_ranks_by_dim = self.mesh.swapdims(-1, dim).reshape(
                -1, self.mesh.size(dim)
            )

            # Respect dim group options specified via _mesh_resources._set_mesh_dim_group_options().
            # Otherwise use backend_override; (None, None) inherits from the parent group.
            if dim in _mesh_resources.mesh_dim_group_options:
                if backend_override[dim] != (None, None):
                    raise RuntimeError(
                        f"Dimension {dim} present both in the backend_override argument "
                        "and via _mesh_resources._set_mesh_dim_group_options"
                    )
                backend, pg_options = _mesh_resources.mesh_dim_group_options[dim]
            else:
                backend, pg_options = backend_override[dim]

            # For a 2D mesh with mesh_dim_names ("dp", "tp"), the subgroup
            # descriptions are `mesh_dp` and `mesh_tp`. If the mesh has no
            # mesh_dim_names, they are `mesh_dim_0` and `mesh_dim_1`.
            group_desc = (
                f"mesh_{self.mesh_dim_names[dim]}"
                if self.mesh_dim_names
                else f"mesh_dim_{dim}"
            )

            # If bound_device_id exists, the nccl communicator has been eagerly
            # initialized, so `split_group` can create all subgroups of this dim
            # with a single call (via `ncclCommSplit`). In a 2 * 4 mesh that is
            # 2 calls per rank in total.
            # Otherwise `new_group` is called once per subgroup: in a 2 * 4 mesh
            # that is 2 + 4 = 6 calls per rank.
            dim_group = None
            has_split_group = False
            if (
                getattr(default_group, "bound_device_id", None) is not None
                and torch.cuda.is_available()
                and (
                    backend is None
                    or default_group._get_backend(torch.device("cuda")).name()
                    == backend
                )
            ):
                dim_group = split_group(
                    parent_pg=default_group,
                    pg_options=pg_options,
                    split_ranks=pg_ranks_by_dim.tolist(),
                    group_desc=group_desc,
                )
                has_split_group = True

            # With `split_group`, only record the group name for the subgroup
            # containing this rank. Otherwise, every rank calls `new_group` for
            # every subgroup (same order on all ranks) and records the one it
            # belongs to.
            for dim_mesh in pg_ranks_by_dim:
                subgroup_ranks = dim_mesh.tolist()

                # We temporarily revert the reuse subgroup, since it breaks two internal tests.
                # Temporarily reverting to resolve test timeout while root-causing.
                # TODO: Add two tests to cover internal tests scenarios and re-enable reuse subgroup if exists.
                if not has_split_group:
                    dim_group = new_group(
                        ranks=subgroup_ranks,
                        backend=backend,
                        pg_options=pg_options,
                        group_desc=group_desc,
                    )

                # only add to dim_groups if the current rank in the subgroup
                if self.get_rank() in subgroup_ranks:
                    if len(dim_group_names) > dim:
                        raise RuntimeError(
                            f"Each device mesh dimension should get only one process group, but got {self.get_rank()} "
                            f"in {subgroup_ranks}!"
                        )
                    dim_group_names.append(dim_group.group_name)  # type: ignore[union-attr]
    self._dim_group_names = dim_group_names
def __enter__(self) -> "DeviceMesh":
    """Push this mesh onto the thread-local mesh stack, making it current.

    Returns:
        ``self``, so ``with mesh as m:`` binds the mesh itself.
    """
    active_meshes = _mesh_resources.mesh_stack
    active_meshes.append(self)
    return self
# pyre-fixme[2]: Parameter must be annotated.
def __exit__(self, exc_type, exc_value, exc_traceback) -> None:
    """Drop the most recently entered mesh from the thread-local mesh stack.

    Returns None, so an exception raised inside the ``with`` body propagates.
    """
    del _mesh_resources.mesh_stack[-1]
- def __repr__(self) -> str:
- device_mesh_repr = (
- f"({', '.join(f'{k}={v}' for k, v in zip(self.mesh_dim_names, self.mesh.shape))})"
- if self.mesh_dim_names
- else f"{tuple(self.mesh.shape)}"
- )
- device_mesh_repr = f"DeviceMesh({device_mesh_repr}, device: '{self.device_type}', stride: {self.mesh.stride()}"
- # We only print the mesh tensor if the debug mode is turned on.
- if os.environ.get("TORCH_DISTRIBUTED_DEBUG", "") == "DETAIL":
- device_mesh_repr += f", Mesh: {self.mesh.tolist()}"
- return f"{device_mesh_repr})"
- def __hash__(self):
- # lazily compute hash
- self._hash = getattr(self, "_hash", None)
- if not self._hash:
- self._hash = hash(
- (
- self._flatten_mesh_list,
- self.mesh.shape,
- self.device_type,
- self.mesh_dim_names,
- self._thread_id,
- )
- )
- return self._hash
- def __eq__(self, other: object) -> bool:
- if self is other:
- return True
- if not isinstance(other, DeviceMesh):
- return False
- return (
- self._flatten_mesh_list == other._flatten_mesh_list
- and self.mesh.shape == other.mesh.shape
- and self.device_type == other.device_type
- and self.mesh_dim_names == other.mesh_dim_names
- and self._thread_id == other._thread_id
- )
- def __getitem__(
- self, mesh_dim_names: Union[str, tuple[str, ...]]
- ) -> "DeviceMesh":
- """
- Slice the current DeviceMesh based on the mesh_dim_names given to create a submesh.
- The submesh created consists of the dimensions and the communicators indicated by
- ``mesh_dim_names``
- Args:
- mesh_dim_names (Union[str, Tuple[str]]): the name or the tuple of names of the
- mesh dimension of the DeviceMesh to create the submesh for.
- Returns:
- A :class:`DeviceMesh` object
- The following program runs on each process/rank in an SPMD manner in a world size of 8.
- In the first example:
- Calling mesh_2d["tp"] on rank 0, 1, 2, 3 returns a 1D submesh of DeviceMesh:([0, 1, 2, 3]).
- Calling mesh_2d["tp"] on rank 4, 5, 6, 7 returns a 1D submesh of DeviceMesh:([4, 5, 6, 7]).
- Calling mesh_2d["dp"] on rank 0, 4 returns a 1D submesh of DeviceMesh:([0, 4]).
- Calling mesh_2d["dp"] on rank 1, 5 returns a 1D submesh of DeviceMesh:([1, 5]).
- Calling mesh_2d["dp"] on rank 2, 6 returns a 1D submesh of DeviceMesh:([2, 6]).
- Calling mesh_2d["dp"] on rank 3, 7 returns a 1D submesh of DeviceMesh:([3, 7]).
- In the second example:
- Calling mesh_3d["dp", "cp"] on rank 0, 1, 4, 5 returns a 2D submesh of DeviceMesh:([[0, 1], [4, 5]]).
- Calling mesh_3d["dp", "cp"] on rank 2, 3, 6, 7 returns a 2D submesh of DeviceMesh:([[2, 3], [6, 7]]).
- Calling mesh_3d["cp", "dp"] on rank 0, 1, 4, 5 returns a 2D submesh of DeviceMesh:([[0, 4], [1, 5]]).
- Calling mesh_3d["cp", "dp"] on rank 2, 3, 6, 7 returns a 2D submesh of DeviceMesh:([[2, 6], [3, 7]]).
- Example::
- >>> # xdoctest: +SKIP("no rank")
- >>> from torch.distributed.device_mesh import DeviceMesh
- >>>
- >>> # Initialize a 2D device mesh as (2, 4) to represent the topology
- >>> # of cross-host(dim 0), and within-host (dim 1).
- >>> mesh_2d = init_device_mesh(device_type="cuda", (2,4), mesh_dim_names=("dp", "tp"))
- >>> tp_mesh = mesh_2d["tp"]
- >>> dp_mesh = mesh_2d["dp"]
- >>>
- >>> # Initialize a 3D mesh.
- >>> mesh_3d = init_device_mesh(device_type="cuda", (2,2,2), mesh_dim_names=("dp", "pp", "cp"))
- >>> # The order of the mesh_dim_names provided deteremines the order of dimensions in the submesh.
- >>> dp_cp_mesh = mesh_3d["dp", "cp"]
- >>> cp_dp_mesh = mesh_3d["cp", "dp"]
- """
- if not self.mesh_dim_names:
- raise RuntimeError("Cannot slice a DeviceMesh without mesh_dim_names!")
- mesh_dim_names = (
- (mesh_dim_names,) if isinstance(mesh_dim_names, str) else mesh_dim_names
- )
- if mesh_dim_names == self.mesh_dim_names:
- return self
- else:
- slice_mesh_dims = _mesh_resources._get_slice_mesh_dims(
- self, mesh_dim_names
- )
- # When using FakeTensorMode to trace the model, `create_sub_mesh()` will
- # fail as it will require a real tensor to manipulate.
- # `unset_fake_temporarily()` will allow us to materialize the tensors
- # within `_mesh_resources`, which should not affect modling.
- #
- # Note that this should be orthogonal to torch.compile(). But whether
- # we can compile device_mesh `slicing` (no graph break) is not verified
- # yet and need a follow-up,
- # TODO: compiler + device_mesh slicing.
- with torch._subclasses.fake_tensor.unset_fake_temporarily():
- submesh = _mesh_resources.create_sub_mesh(
- self, mesh_dim_names, slice_mesh_dims
- )
- return submesh
def get_group(self, mesh_dim: Optional[Union[int, str]] = None) -> ProcessGroup:
    """
    Returns the single ProcessGroup specified by mesh_dim, or, if mesh_dim is not specified and the
    DeviceMesh is 1-dimensional, returns the only ProcessGroup in the mesh.

    Args:
        mesh_dim (str/int, optional): it can be the name of the mesh dimension or the index
            of the mesh dimension. Default is None.

    Returns:
        A :class:`ProcessGroup` object.

    Raises:
        RuntimeError: if the process groups of this mesh were never initialized
            (no ``_dim_group_names``), or if ``mesh_dim`` is omitted on a mesh
            with more than one dimension.
    """
    if not hasattr(self, "_dim_group_names"):
        raise RuntimeError("DeviceMesh process groups not initialized!")

    if self.mesh.ndim > 1 and mesh_dim is None:
        # Build ONE message string: handing RuntimeError several positional
        # arguments makes str(exc) render as a Python tuple.
        raise RuntimeError(
            f"Found the DeviceMesh have {self.mesh.ndim} dimensions. "
            "Optional kwarg `mesh_dim` needs to be specified when device_mesh.ndim > 1. "
            "If you want to get the list of all the ProcessGroups in the DeviceMesh, "
            "please use `get_all_groups()` instead."
        )

    # Quick return if the current device_mesh is a 1D mesh.
    if self.mesh.ndim == 1 and mesh_dim is None:
        return not_none(_resolve_process_group(self._dim_group_names[0]))

    # Flattened dimensions are registered against the root mesh, so a name
    # lookup there takes precedence over this mesh's own dimensions.
    root_mesh = _mesh_resources.get_root_mesh(self)
    root_to_flatten_mapping = _mesh_resources.root_to_flatten_mapping.get(
        root_mesh, None
    )
    if root_to_flatten_mapping and mesh_dim in root_to_flatten_mapping:
        dim_group_name = root_to_flatten_mapping[
            mesh_dim  # type: ignore[index]
        ]._dim_group_names[0]
        return not_none(_resolve_process_group(dim_group_name))

    # Translate a dimension name into its index on this mesh.
    mesh_dim = (
        _mesh_resources.get_mesh_dim_by_name(self, mesh_dim)
        if isinstance(mesh_dim, str)
        else mesh_dim
    )
    assert isinstance(mesh_dim, int)
    return not_none(_resolve_process_group(self._dim_group_names[mesh_dim]))
def get_all_groups(self) -> list[ProcessGroup]:
    """
    Collect the ProcessGroup of every mesh dimension, ordered by dimension index.

    Returns:
        A list of :class:`ProcessGroup` object, one per mesh dimension.
    """
    dim_indices = range(self.mesh.ndim)
    return list(map(self.get_group, dim_indices))
@staticmethod
def from_group(
    group: Union[ProcessGroup, list[ProcessGroup]],
    device_type: str,
    mesh: Optional[Union[torch.Tensor, "ArrayLike"]] = None,
    *,
    mesh_dim_names: Optional[tuple[str, ...]] = None,
) -> "DeviceMesh":
    """
    Constructs a :class:`DeviceMesh` with ``device_type`` from an
    existing :class:`ProcessGroup` or a list of existing :class:`ProcessGroup`.

    The constructed device mesh has number of dimensions equal to the
    number of groups passed. For example, if a single process group is passed in,
    the resulted DeviceMesh is a 1D mesh. If a list of 2 process groups is passed in,
    the resulted DeviceMesh is a 2D mesh.

    If more than one group is passed, then the ``mesh`` and ``mesh_dim_names`` arguments
    are required. The order of the process groups passed in determines the topology of
    the mesh. For example, the first process group will be the 0th dimension of the DeviceMesh.
    The `mesh` tensor passed in must have the same number of dimensions as the number of process
    groups passed in, and the order of the dimensions in the `mesh` tensor must match the order
    in the process groups passed in.

    Args:
        group (ProcessGroup or list[ProcessGroup]): the existing ProcessGroup
            or a list of existing ProcessGroups.
        device_type (str): The device type of the mesh. Currently supports: "cpu",
            "cuda/cuda-like". Passing in a device type with a GPU index, such as "cuda:0",
            is not allowed.
        mesh (torch.Tensor or ArrayLike, optional): A multi-dimensional array or an
            integer tensor describing the layout of devices, where the IDs are global IDs
            of the default process group. Default is None.
        mesh_dim_names (tuple[str], optional): A tuple of mesh dimension names to assign
            to each dimension of the multi-dimensional array describing the layout of devices.
            Its length must match the length of `mesh_shape`. Each string in `mesh_dim_names`
            must be unique. Default is None.

    Returns:
        DeviceMesh: A :class:`DeviceMesh` object representing the device layout.

    Raises:
        ValueError: if, for a single group, ``mesh`` does not list exactly that
            group's ranks; if a list of groups is empty; if ``mesh`` or
            ``mesh_dim_names`` is missing for a list of groups; or if
            ``mesh.ndim`` differs from the number of groups.
    """
    # 1D scenario
    if isinstance(group, ProcessGroup):
        group_ranks = get_process_group_ranks(group)
        if mesh is not None:
            # Normalize every array-like (tensor, list, tuple, NumPy array, ...)
            # to a nested Python list before comparing. A plain ``!=`` against a
            # NumPy array is element-wise (its truth value is ambiguous), and a
            # tuple never compares equal to the list of ranks.
            if isinstance(mesh, torch.Tensor):
                mesh_ranks = mesh.tolist()
            else:
                try:
                    mesh_ranks = torch.as_tensor(mesh).tolist()
                except (TypeError, ValueError, RuntimeError):
                    # Not convertible to a tensor: it cannot match the ranks.
                    mesh_ranks = None
            if mesh_ranks != group_ranks:
                raise ValueError(
                    f"Invalid mesh {str(mesh)} for ProcessGroup with ranks {group_ranks}"
                )
        mesh = torch.tensor(group_ranks, device="cpu", dtype=torch.int)
        device_mesh = DeviceMesh(
            device_type,
            mesh,
            mesh_dim_names=mesh_dim_names,
            _init_backend=False,
        )
        # Reuse the existing group rather than creating a new one.
        device_mesh._dim_group_names = [group.group_name]
        return device_mesh

    # nD scenario
    groups = list(group)
    if not groups:
        raise ValueError("Expects at least one ProcessGroup to be passed")
    if mesh is None:
        raise ValueError("Must pass mesh if passing multiple ProcessGroups")
    if mesh_dim_names is None:
        raise ValueError(
            "Must pass mesh_dim_names if passing multiple ProcessGroups"
        )
    # The mesh tensor always lives on CPU as int32, whatever the input was.
    mesh = (
        mesh.detach().to(dtype=torch.int, device="cpu")
        if isinstance(mesh, torch.Tensor)
        else torch.tensor(mesh, device="cpu", dtype=torch.int)
    )
    if mesh.ndim != len(groups):
        raise ValueError(
            "Expects mesh with ndim equal to number of ProcessGroups but got "
            f"mesh {mesh.tolist()} and {len(groups)} ProcessGroups"
        )
    device_mesh = DeviceMesh(
        device_type, mesh, mesh_dim_names=mesh_dim_names, _init_backend=False
    )
    device_mesh._dim_group_names = [pg.group_name for pg in groups]
    return device_mesh
def size(self, mesh_dim: Optional[int] = None) -> int:
    """
    Return the total number of entries in the mesh, or its extent along one dimension.

    Args:
        mesh_dim (int, optional): the dimension to measure. When omitted, the
            element count of the whole mesh is returned.
    """
    if mesh_dim is None:
        return self.mesh.numel()
    return self.mesh.size(mesh_dim)
@property
def ndim(self) -> int:
    """Number of dimensions of the underlying mesh tensor."""
    return self.mesh.dim()
@property
def shape(self) -> tuple[int, ...]:
    """Extent of each mesh dimension, as a plain tuple of ints."""
    dims = self.mesh.size()
    return tuple(dims)
def get_rank(self) -> int:
    """
    Return the current global rank (delegates to the module-level ``get_rank()``
    called with no group argument).
    """
    global_rank = get_rank()
    return global_rank
def get_local_rank(self, mesh_dim: Optional[Union[int, str]] = None) -> int:
    """
    Returns the local rank of the given mesh_dim of the DeviceMesh.

    Args:
        mesh_dim (str/int, optional): it can be the name of the mesh dimension or the index
            of the mesh dimension. Default is None, which is only accepted on a 1D mesh
            (dimension 0 is then used).

    Returns:
        An integer denotes the local rank.

    Raises:
        RuntimeError: if ``mesh_dim`` is omitted on a mesh with more than one dimension.

    The following program runs on each process/rank in an SPMD manner. In this example, we have 2
    hosts with 4 GPUs each.
    Calling mesh_2d.get_local_rank(mesh_dim=0) on rank 0, 1, 2, 3 would return 0.
    Calling mesh_2d.get_local_rank(mesh_dim=0) on rank 4, 5, 6, 7 would return 1.
    Calling mesh_2d.get_local_rank(mesh_dim=1) on rank 0, 4 would return 0.
    Calling mesh_2d.get_local_rank(mesh_dim=1) on rank 1, 5 would return 1.
    Calling mesh_2d.get_local_rank(mesh_dim=1) on rank 2, 6 would return 2.
    Calling mesh_2d.get_local_rank(mesh_dim=1) on rank 3, 7 would return 3.

    Example::

        >>> # xdoctest: +SKIP("no rank")
        >>> from torch.distributed.device_mesh import DeviceMesh
        >>>
        >>> # Initialize device mesh as (2, 4) to represent the topology
        >>> # of cross-host(dim 0), and within-host (dim 1).
        >>> mesh_2d = DeviceMesh(device_type="cuda", mesh=[[0, 1, 2, 3],[4, 5, 6, 7]])
    """
    if self.ndim > 1 and mesh_dim is None:
        # One message string: several positional args would make str(exc) a tuple.
        raise RuntimeError(
            f"Found the DeviceMesh have {self.mesh.ndim} dimensions. "
            "Optional kwarg `mesh_dim` needs to be specified when device_mesh.ndim > 1."
        )
    if mesh_dim is None:
        mesh_dim = 0

    mesh_dim_group = not_none(self.get_group(mesh_dim))
    assert isinstance(mesh_dim_group, ProcessGroup), (
        "We expect ProcessGroup before calling `get_rank`!"
    )
    return not_none(get_rank(mesh_dim_group))
def get_coordinate(self) -> Optional[list[int]]:
    """
    Return this rank's index along each mesh dimension.

    Returns None when no coordinate is recorded (falsy ``_coordinate_on_dim``),
    i.e. when this rank is not part of the mesh.
    """
    return self._coordinate_on_dim or None
def _flatten(
    self,
    mesh_dim_name: Optional[str] = None,
    backend_override: Union[
        None, str, C10dBackend.Options, tuple[str, C10dBackend.Options]
    ] = None,
) -> "DeviceMesh":
    """
    Collapse this DeviceMesh into a 1D DeviceMesh.

    When ``mesh_dim_name`` is omitted, the new dimension is named by joining this
    mesh's dimension names with "_". For instance, given the 3D mesh
    DeviceMesh([[[0, 1], [2, 3]], [[4, 5], [6, 7]]], mesh_dim_names=("dp", "cp", "tp")),
    ``mesh_3d["dp", "cp"]._flatten()`` produces DeviceMesh([0, 2, 4, 6], mesh_dim_names=("dp_cp",))
    on ranks 0, 2, 4, 6 and DeviceMesh([1, 3, 5, 7], mesh_dim_names=("dp_cp",)) on ranks 1, 3, 5, 7.
    Once created, the flattened dimension is reachable from the original mesh through
    ordinary slicing, e.g. ``mesh_3d["dp_cp"]``.

    Args:
        mesh_dim_name (str, optional): name of the flattened dimension.
        backend_override (optional): a backend name, backend options, or a
            ``(name, options)`` pair for the flattened dimension.

    Raises:
        RuntimeError: if this mesh has no ``mesh_dim_names``.
    """
    if not self.mesh_dim_names:
        raise RuntimeError("Cannot flatten a DeviceMesh without mesh_dim_names!")

    # Default: neither a backend name nor backend options.
    backend_override_tuple = (None, None)
    if backend_override is not None:
        # Reuse the n-D normalizer by presenting the override as dim 0 of a 1D mesh;
        # it yields exactly one (name, options) pair.
        [backend_override_tuple] = _normalize_backend_override({0: backend_override}, 1)

    return _mesh_resources.create_flatten_mesh(
        self, mesh_dim_name, backend_override_tuple
    )
- def _normalize_backend_override(
- backend_override: dict[
- Union[int, str],
- Union[str, C10dBackend.Options, tuple[str, C10dBackend.Options]],
- ],
- ndim: int,
- mesh_dim_names: Optional[tuple[str, ...]] = None,
- ) -> Iterator[tuple[Optional[str], Optional[C10dBackend.Options]]]:
- if mesh_dim_names is None:
- mesh_dim_names = ()
- for dim_idx, dim_name in zip_longest(range(ndim), mesh_dim_names):
- if dim_name is not None and dim_name in backend_override:
- if dim_idx in backend_override:
- raise RuntimeError(
- f"Found redundant dim index {dim_idx} and "
- f"name {dim_name} in backend_override"
- )
- val = backend_override.pop(dim_name)
- elif dim_idx in backend_override:
- val = backend_override.pop(dim_idx)
- else:
- yield (None, None)
- continue
- if isinstance(val, str):
- yield (val, None)
- elif isinstance(val, C10dBackend.Options):
- yield (None, val)
- else:
- yield val
- if backend_override:
- raise RuntimeError(
- f"Found invalid keys in backend_override: got {list(backend_override.keys())}, "
- f"expected integers in range [0, {ndim}) or one of {mesh_dim_names}"
- )
- def init_device_mesh(
- device_type: str,
- mesh_shape: tuple[int, ...],
- *,
- mesh_dim_names: Optional[tuple[str, ...]] = None,
- backend_override: Optional[
- dict[
- Union[int, str],
- Union[str, C10dBackend.Options, tuple[str, C10dBackend.Options]],
- ]
- ] = None,
- ) -> DeviceMesh:
- """
- Initializes a `DeviceMesh` based on `device_type`, `mesh_shape`, and `mesh_dim_names` parameters.
- This creates a DeviceMesh with an n-dimensional array layout, where `n` is the length of `mesh_shape`.
- If `mesh_dim_names` is provided, each dimension is labeled as `mesh_dim_names[i]`.
- .. note::
- `init_device_mesh` follows SPMD programming model, meaning the same PyTorch Python program
- runs on all processes/ranks in the cluster. Ensure `mesh_shape` (the dimensions of the nD array
- describing device layout) is identical across all ranks. Inconsistent `mesh_shape` may lead to hanging.
- .. note::
- If no process group is found, init_device_mesh will initialize distributed process group/groups
- required for distributed communications behind the scene.
- Args:
- device_type (str): The device type of the mesh. Currently supports: "cpu", "cuda/cuda-like", "xpu".
- Passing in a device type with a GPU index, such as "cuda:0", is not allowed.
- mesh_shape (Tuple[int]): A tuple defining the dimensions of the multi-dimensional array
- describing the layout of devices.
- mesh_dim_names (Tuple[str], optional): A tuple of mesh dimension names to assign to each dimension
- of the multi-dimensional array describing the layout of devices. Its length must match the length
- of `mesh_shape`. Each string in `mesh_dim_names` must be unique.
- backend_override (Dict[int | str, tuple[str, Options] | str | Options], optional): Overrides for some or all of
- the ProcessGroups that will be created for each mesh dimension. Each key can be either the index of a
- dimension or its name (if mesh_dim_names is provided). Each value can be a tuple containing the name
- of the backend and its options, or just one of these two components (in which case the other will be
- set to its default value).
- Returns:
- DeviceMesh: A :class:`DeviceMesh` object representing the device layout.
- Example::
- >>> # xdoctest: +SKIP("no rank")
- >>> from torch.distributed.device_mesh import init_device_mesh
- >>>
- >>> mesh_1d = init_device_mesh("cuda", mesh_shape=(8,))
- >>> mesh_2d = init_device_mesh("cuda", mesh_shape=(2, 8), mesh_dim_names=("dp", "tp"))
- """
- if mesh_dim_names is not None:
- if len(set(mesh_dim_names)) != len(mesh_dim_names):
- raise RuntimeError(
- "Each mesh_dim_name must be unique.",
- f"Found repeated mesh_dim_name in mesh_dim_names {mesh_dim_names}",
- )
- if len(mesh_shape) != len(mesh_dim_names):
- raise RuntimeError(
- "mesh_shape and mesh_dim_names should have same length!",
- f"Found len(mesh_dim_names): {len(mesh_dim_names)} and len(mesh_shape):{len(mesh_shape)}.",
- )
- if backend_override is not None:
- backend_override_tuple = tuple(
- _normalize_backend_override(
- backend_override, len(mesh_shape), mesh_dim_names
- )
- )
- else:
- backend_override_tuple = None
- # assume valid device types are all letters
- if device_type and not device_type.isalpha():
- raise RuntimeError(
- f"Device type with index is not supported but got {device_type}. ",
- "If you maintained a 'torch.device' object, it's recommended to pass in 'device.type'.",
- )
- # Always initialize the mesh's tensor on CPU, regardless of what the
- # external device type has been set to be (e.g. meta)
- with torch.device("cpu"):
- mesh = torch.arange(math.prod(mesh_shape), dtype=torch.int).view(mesh_shape)
- device_mesh = DeviceMesh(
- device_type=device_type,
- mesh=mesh,
- mesh_dim_names=mesh_dim_names,
- backend_override=backend_override_tuple,
- )
- return device_mesh
|