dist_settings.py 1.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657
  1. # -------------------------------------------------------------------------
  2. # Copyright (c) Microsoft Corporation. All rights reserved.
  3. # Licensed under the MIT License. See License.txt in the project root for
  4. # license information.
  5. # --------------------------------------------------------------------------
  6. import os
  7. import torch.distributed as dist
  8. def init_dist():
  9. if "LOCAL_RANK" in os.environ:
  10. int(os.environ["LOCAL_RANK"])
  11. rank = int(os.environ["RANK"])
  12. world_size = int(os.environ["WORLD_SIZE"])
  13. dist.init_process_group("nccl", init_method="tcp://127.0.0.1:7645", world_size=world_size, rank=rank)
  14. elif "OMPI_COMM_WORLD_LOCAL_RANK" in os.environ:
  15. int(os.environ.get("OMPI_COMM_WORLD_LOCAL_RANK", "0"))
  16. rank = int(os.environ.get("OMPI_COMM_WORLD_RANK", "0"))
  17. world_size = int(os.environ.get("OMPI_COMM_WORLD_SIZE", "1"))
  18. dist.init_process_group("nccl", init_method="tcp://127.0.0.1:7647", world_size=world_size, rank=rank)
  19. else:
  20. # don't need to do init for single process
  21. pass
  22. def _get_comm():
  23. try:
  24. from mpi4py import MPI # noqa: PLC0415
  25. comm = MPI.COMM_WORLD
  26. return comm
  27. except ImportError:
  28. return None
  29. def get_rank():
  30. comm = _get_comm()
  31. return comm.Get_rank() if comm is not None else 0
  32. def get_size():
  33. comm = _get_comm()
  34. return comm.Get_size() if comm is not None else 1
  35. def barrier():
  36. comm = _get_comm()
  37. if comm is not None:
  38. comm.Barrier()
  39. def print_out(*args):
  40. if get_rank() == 0:
  41. print(*args)