nccl.py 4.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152
  1. # mypy: allow-untyped-defs
  2. import collections
  3. import warnings
  4. from collections.abc import Sequence
  5. from typing import Optional, Union
  6. import torch.cuda
# Names exported by a star-import of this module. The helpers defined below
# (is_available, version, unique_id, init_rank) are intentionally not listed.
__all__ = ["all_reduce", "reduce", "broadcast", "all_gather", "reduce_scatter"]
# Default value of the ``op`` argument of all_reduce, reduce and reduce_scatter.
SUM = 0 # ncclRedOp_t
  9. def is_available(tensors):
  10. if not hasattr(torch._C, "_nccl_all_reduce"):
  11. warnings.warn("PyTorch is not compiled with NCCL support")
  12. return False
  13. devices = set()
  14. for tensor in tensors:
  15. if tensor.is_sparse:
  16. return False
  17. if not tensor.is_contiguous():
  18. return False
  19. if not tensor.is_cuda:
  20. return False
  21. device = tensor.get_device()
  22. if device in devices:
  23. return False
  24. devices.add(device)
  25. return True
  26. def version():
  27. """
  28. Returns the version of the NCCL.
  29. This function returns a tuple containing the major, minor, and patch version numbers of the NCCL.
  30. The suffix is also included in the tuple if a version suffix exists.
  31. Returns:
  32. tuple: The version information of the NCCL.
  33. """
  34. ver = torch._C._nccl_version()
  35. major = ver >> 32
  36. minor = (ver >> 16) & 65535
  37. patch = ver & 65535
  38. suffix = torch._C._nccl_version_suffix().decode("utf-8")
  39. if suffix == "":
  40. return (major, minor, patch)
  41. else:
  42. return (major, minor, patch, suffix)
  43. def unique_id():
  44. return torch._C._nccl_unique_id()
  45. def init_rank(num_ranks, uid, rank):
  46. return torch._C._nccl_init_rank(num_ranks, uid, rank)
  47. def _check_sequence_type(inputs: Union[torch.Tensor, Sequence[torch.Tensor]]) -> None:
  48. if not isinstance(inputs, collections.abc.Container) or isinstance(
  49. inputs, torch.Tensor
  50. ):
  51. raise TypeError("Inputs should be a collection of tensors")
  52. def all_reduce(inputs, outputs=None, op=SUM, streams=None, comms=None):
  53. _check_sequence_type(inputs)
  54. if outputs is None:
  55. outputs = inputs
  56. _check_sequence_type(outputs)
  57. torch._C._nccl_all_reduce(inputs, outputs, op, streams, comms)
  58. # `output` used to be `outputs`, taking in a list of tensors. So we have two
  59. # arguments for BC reasons.
  60. def reduce(
  61. inputs: Sequence[torch.Tensor],
  62. output: Optional[Union[torch.Tensor, Sequence[torch.Tensor]]] = None,
  63. root: int = 0,
  64. op: int = SUM,
  65. streams: Optional[Sequence[torch.cuda.Stream]] = None,
  66. comms=None,
  67. *,
  68. outputs: Optional[Sequence[torch.Tensor]] = None,
  69. ) -> None:
  70. _check_sequence_type(inputs)
  71. _output: torch.Tensor
  72. if outputs is not None:
  73. if output is not None:
  74. raise ValueError(
  75. "'output' and 'outputs' can not be both specified. 'outputs' is deprecated in "
  76. "favor of 'output', taking in a single output tensor. The signature of reduce is: "
  77. "reduce(inputs, output=None, root=0, op=SUM, streams=None, comms=None)."
  78. )
  79. else:
  80. warnings.warn(
  81. "`nccl.reduce` with an output tensor list is deprecated. "
  82. "Please specify a single output tensor with argument 'output' instead instead.",
  83. FutureWarning,
  84. stacklevel=2,
  85. )
  86. _output = outputs[root]
  87. elif not isinstance(output, torch.Tensor) and isinstance(
  88. output, collections.abc.Sequence
  89. ):
  90. # User called old API with positional arguments of list of output tensors.
  91. warnings.warn(
  92. "nccl.reduce with an output tensor list is deprecated. "
  93. "Please specify a single output tensor.",
  94. FutureWarning,
  95. stacklevel=2,
  96. )
  97. _output = output[root]
  98. else:
  99. _output = inputs[root] if output is None else output
  100. torch._C._nccl_reduce(inputs, _output, root, op, streams, comms)
  101. def broadcast(
  102. inputs: Sequence[torch.Tensor], root: int = 0, streams=None, comms=None
  103. ) -> None:
  104. _check_sequence_type(inputs)
  105. torch._C._nccl_broadcast(inputs, root, streams, comms)
  106. def all_gather(
  107. inputs: Sequence[torch.Tensor],
  108. outputs: Sequence[torch.Tensor],
  109. streams=None,
  110. comms=None,
  111. ) -> None:
  112. _check_sequence_type(inputs)
  113. _check_sequence_type(outputs)
  114. torch._C._nccl_all_gather(inputs, outputs, streams, comms)
  115. def reduce_scatter(
  116. inputs: Sequence[torch.Tensor],
  117. outputs: Sequence[torch.Tensor],
  118. op: int = SUM,
  119. streams=None,
  120. comms=None,
  121. ) -> None:
  122. _check_sequence_type(inputs)
  123. _check_sequence_type(outputs)
  124. torch._C._nccl_reduce_scatter(inputs, outputs, op, streams, comms)