| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657 |
- # -------------------------------------------------------------------------
- # Copyright (c) Microsoft Corporation. All rights reserved.
- # Licensed under the MIT License. See License.txt in the project root for
- # license information.
- # --------------------------------------------------------------------------
- import os
- import torch.distributed as dist
def init_dist():
    """Initialise the default ``torch.distributed`` process group if needed.

    The launcher is detected from the environment:

    * ``LOCAL_RANK`` present: ``RANK`` and ``WORLD_SIZE`` are read (both are
      required) and the group rendezvous at ``tcp://127.0.0.1:7645``.
    * ``OMPI_COMM_WORLD_LOCAL_RANK`` present (presumably an Open MPI launch):
      ``OMPI_COMM_WORLD_RANK`` / ``OMPI_COMM_WORLD_SIZE`` are read, defaulting
      to rank 0 of 1, and the group rendezvous at ``tcp://127.0.0.1:7647``.
    * Neither present: single-process run, nothing is initialised.

    Calling this function again after the process group has been created is a
    no-op instead of an error.

    Raises:
        KeyError: ``LOCAL_RANK`` is set but ``RANK`` or ``WORLD_SIZE`` is not.
        ValueError: a rank / world-size variable is not an integer.
    """
    if "LOCAL_RANK" in os.environ:
        rank = int(os.environ["RANK"])
        world_size = int(os.environ["WORLD_SIZE"])
        init_method = "tcp://127.0.0.1:7645"
    elif "OMPI_COMM_WORLD_LOCAL_RANK" in os.environ:
        rank = int(os.environ.get("OMPI_COMM_WORLD_RANK", "0"))
        world_size = int(os.environ.get("OMPI_COMM_WORLD_SIZE", "1"))
        init_method = "tcp://127.0.0.1:7647"
    else:
        # Single process: no process group is required.
        return

    # init_process_group raises when the default group already exists, so make
    # repeated calls harmless.
    if dist.is_initialized():
        return

    dist.init_process_group("nccl", init_method=init_method, world_size=world_size, rank=rank)
def _get_comm():
    """Return ``MPI.COMM_WORLD`` from mpi4py, or ``None`` if it cannot be imported."""
    try:
        # Imported lazily so that mpi4py stays an optional dependency.
        from mpi4py import MPI as mpi  # noqa: PLC0415

        world = mpi.COMM_WORLD
    except ImportError:
        world = None
    return world
def get_rank():
    """Return this process's rank in the MPI world communicator.

    Falls back to ``0`` when no communicator is available.
    """
    world = _get_comm()
    if world is None:
        return 0
    return world.Get_rank()
def get_size():
    """Return the number of processes in the MPI world communicator.

    Falls back to ``1`` when no communicator is available.
    """
    world = _get_comm()
    if world is None:
        return 1
    return world.Get_size()
def barrier():
    """Call ``Barrier()`` on the MPI world communicator.

    Does nothing when no communicator is available.
    """
    world = _get_comm()
    if world is None:
        return
    world.Barrier()
def print_out(*args):
    """Print ``args`` only on the process whose ``get_rank()`` is 0."""
    if get_rank() != 0:
        return
    print(*args)
|