group.py 10.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374
  1. # Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. import warnings
  15. import paddle
  16. import paddle.distributed as dist
  17. from paddle import framework
  18. class Group:
  19. """
  20. The abstract representation of group.
  21. """
  22. def __init__(self, rank_in_group, id, ranks, pg=None, name=None):
  23. self._rank_in_group = rank_in_group
  24. self._world_size = len(ranks) if rank_in_group >= 0 else -1
  25. self._id = id
  26. self._ranks = ranks
  27. self._pg = pg
  28. self._name = name
  29. @property
  30. def rank(self):
  31. return self._rank_in_group
  32. @property
  33. def ranks(self):
  34. return self._ranks
  35. @property
  36. def nranks(self):
  37. return len(self._ranks)
  38. @property
  39. def name(self):
  40. return self._name
  41. @property
  42. def process_group(self):
  43. return self._pg
  44. @property
  45. def world_size(self):
  46. return self._world_size
  47. @property
  48. def backend(self):
  49. return self._pg.name()
  50. @property
  51. def id(self):
  52. return self._id
  53. def is_member(self):
  54. if self.rank < 0:
  55. return False
  56. if self.nranks < 2:
  57. return False
  58. return True
  59. def get_group_rank(self, rank):
  60. if self.is_member():
  61. return self.ranks.index(rank)
  62. else:
  63. return -1
  64. def __repr__(self):
  65. debug_str = (
  66. f"rank: {self.rank}, nranks: {self.nranks}, id: {self.id}, ranks: "
  67. )
  68. debug_str += ", ".join(map(str, self.ranks))
  69. debug_str += "; name: "
  70. debug_str += self.name if self.name else "None"
  71. return debug_str
  72. class _GroupManager:
  73. global_group_id = 0
  74. group_map_by_id = {}
  75. def _get_global_group():
  76. if _GroupManager.global_group_id not in _GroupManager.group_map_by_id:
  77. raise RuntimeError("The global group is not initialized.")
  78. return _GroupManager.group_map_by_id[_GroupManager.global_group_id]
  79. def _add_new_group(group):
  80. if group.id in _GroupManager.group_map_by_id:
  81. raise RuntimeError(f"The group with id {group.id} already exist.")
  82. _GroupManager.group_map_by_id[group.id] = group
  83. def _is_global_group(group):
  84. return group.id == _GroupManager.global_group_id
  85. def _warn_cur_rank_not_in_group(group):
  86. global_rank = dist.get_rank()
  87. if group and not group.is_member():
  88. warnings.warn(
  89. f"Current global rank {global_rank} is not in group {group.name}"
  90. )
  91. return True
  92. return False
  93. def _get_or_throw_group_rank(global_rank, group):
  94. group_rank = group.get_group_rank(global_rank)
  95. assert (
  96. group_rank >= 0
  97. ), f"The input rank {global_rank} can not be found inside the group {group.name}"
  98. return group_rank
  99. def is_initialized():
  100. """
  101. Check whether the distributed environment has been initialized
  102. Returns:
  103. `True` if distributed environment has been initialized, otherwise `False`.
  104. Warning:
  105. This API only supports the dygraph mode.
  106. Examples:
  107. .. code-block:: python
  108. >>> # doctest: +REQUIRES(env: DISTRIBUTED)
  109. >>> import paddle
  110. >>> print(paddle.distributed.is_initialized())
  111. False
  112. >>> paddle.distributed.init_parallel_env()
  113. >>> print(paddle.distributed.is_initialized())
  114. True
  115. """
  116. return _GroupManager.global_group_id in _GroupManager.group_map_by_id
  117. def destroy_process_group(group=None):
  118. """
  119. Destroy a given group for communication
  120. Args:
  121. group (Group, optional): The group to be destroyed. All of process groups, including
  122. the default group, will be destroyed and the distributed
  123. environment will be deinitialized.
  124. Returns : None
  125. Warning:
  126. This API only supports the dygraph mode.
  127. Examples:
  128. .. code-block:: python
  129. >>> # doctest: +REQUIRES(env: DISTRIBUTED)
  130. >>> import paddle
  131. >>> import paddle.distributed as dist
  132. >>> dist.init_parallel_env()
  133. >>> group = dist.new_group([0, 1])
  134. >>> dist.destroy_process_group(group)
  135. >>> print(dist.is_initialized())
  136. True
  137. >>> dist.destroy_process_group()
  138. >>> print(dist.is_initialized())
  139. False
  140. """
  141. group = _get_global_group() if group is None else group
  142. assert (
  143. group.id in _GroupManager.group_map_by_id
  144. ), f"Destroy group with id {group.id} is invalid."
  145. if _is_global_group(group):
  146. _GroupManager.group_map_by_id.clear()
  147. else:
  148. del _GroupManager.group_map_by_id[group.id]
  149. def get_group(id=0):
  150. """
  151. Get group instance by group id.
  152. Args:
  153. id (int): the group id. Default value is 0.
  154. Returns:
  155. Group: the group instance.
  156. Examples:
  157. .. code-block:: python
  158. >>> # doctest: +REQUIRES(env: DISTRIBUTED)
  159. >>> import paddle
  160. >>> import paddle.distributed as dist
  161. >>> dist.init_parallel_env()
  162. >>> gid = paddle.distributed.new_group([2,4,6])
  163. >>> paddle.distributed.get_group(gid.id)
  164. """
  165. if id in _GroupManager.group_map_by_id:
  166. return _GroupManager.group_map_by_id[id]
  167. warnings.warn(f"Group {id} is not initialized.")
  168. return None
def _sync_calc_stream(tensor):
    """
    Issue the ``c_sync_calc_stream`` op on ``tensor``.

    In dynamic mode the op is called directly with ``tensor`` as both
    arguments and its result is returned. Otherwise the op is appended to the
    program with ``tensor`` as both input ``X`` and output ``Out`` and nothing
    is returned.

    NOTE(review): presumably this synchronizes the calculation stream before
    ``tensor`` is used further -- confirm against the op's documentation.
    """
    if framework.in_dynamic_mode():
        return paddle._legacy_C_ops.c_sync_calc_stream(tensor, tensor)
    else:
        op_type = 'c_sync_calc_stream'
        # `**locals()` forwards every local name defined so far (`tensor`,
        # `op_type`) as keyword arguments; renaming or adding locals above
        # this line changes what LayerHelper receives.
        helper = framework.LayerHelper(op_type, **locals())
        helper.append_op(
            type=op_type,
            inputs={'X': [tensor]},
            outputs={'Out': [tensor]},
        )
def _sync_comm_stream(tensor, ring_id=0):
    """
    Issue the ``c_sync_comm_stream`` op on ``tensor`` for ``ring_id``.

    In dynamic mode the op is called directly with ``[tensor]`` as both list
    arguments plus the ``ring_id`` attribute, and its result is returned.
    Otherwise the op is appended to the program with ``tensor`` as both input
    ``X`` and output ``Out`` and ``ring_id`` as an attribute; nothing is
    returned.

    NOTE(review): presumably this synchronizes the communication stream of
    the given ring -- confirm against the op's documentation.
    """
    if framework.in_dynamic_mode():
        return paddle._legacy_C_ops.c_sync_comm_stream(
            [tensor], [tensor], 'ring_id', ring_id
        )
    else:
        op_type = 'c_sync_comm_stream'
        # `**locals()` forwards every local name defined so far (`tensor`,
        # `ring_id`, `op_type`) as keyword arguments; renaming or adding
        # locals above this line changes what LayerHelper receives.
        helper = framework.LayerHelper(op_type, **locals())
        helper.append_op(
            type=op_type,
            inputs={'X': [tensor]},
            outputs={'Out': [tensor]},
            attrs={'ring_id': ring_id},
        )
  194. def wait(tensor, group=None, use_calc_stream=True):
  195. """
  196. wait to sync stream for group.
  197. Args:
  198. tensor (Tensor): The Tensor used before sync.
  199. group (Group): The Group instance to perform sync.
  200. use_calc_stream (bool): Wether to use calculation stream (True) or communication stream (False).
  201. Default to True.
  202. Returns:
  203. None.
  204. Examples:
  205. .. code-block:: python
  206. >>> # doctest: +REQUIRES(env: DISTRIBUTED)
  207. >>> import paddle
  208. >>> paddle.distributed.init_parallel_env()
  209. >>> tindata = paddle.randn(shape=[2, 3])
  210. >>> paddle.distributed.all_reduce(tindata, sync_op=True)
  211. >>> paddle.distributed.wait(tindata)
  212. """
  213. if group is not None and not group.is_member():
  214. return
  215. if use_calc_stream:
  216. _sync_calc_stream(tensor)
  217. else:
  218. ring_id = 0 if group is None else group.id
  219. _sync_comm_stream(tensor, ring_id)
  220. def barrier(group=None):
  221. """
  222. Barrier among all participators in the group.
  223. Args:
  224. group (Group): The group instance return by new_group or None for global default group.
  225. Returns:
  226. None.
  227. Examples:
  228. .. code-block:: python
  229. >>> # doctest: +REQUIRES(env: DISTRIBUTED)
  230. >>> import paddle
  231. >>> from paddle.distributed import init_parallel_env
  232. >>> paddle.set_device('gpu:%d'%paddle.distributed.ParallelEnv().dev_id)
  233. >>> init_parallel_env()
  234. >>> paddle.distributed.barrier()
  235. """
  236. if group is not None and not group.is_member():
  237. return
  238. if framework.in_dynamic_mode():
  239. group = _get_global_group() if group is None else group
  240. place = framework._current_expected_place()
  241. if isinstance(place, framework.CPUPlace):
  242. task = group.process_group.barrier()
  243. else:
  244. device_id = place.get_device_id()
  245. task = group.process_group.barrier(device_id)
  246. task.wait()
  247. return
  248. ring_id = 0 if group is None else group.id
  249. barrier_tensor = paddle.full([1], 1, dtype="int32")
  250. if framework.in_dynamic_mode():
  251. return paddle._legacy_C_ops.barrier(
  252. barrier_tensor, barrier_tensor, 'ring_id', ring_id
  253. )
  254. else:
  255. op_type = 'barrier'
  256. if not isinstance(ring_id, int):
  257. raise ValueError("The type of 'group' for barrier must be int.")
  258. helper = framework.LayerHelper(op_type, **locals())
  259. helper.append_op(
  260. type=op_type,
  261. inputs={'X': [barrier_tensor]},
  262. outputs={'Out': [barrier_tensor]},
  263. attrs={'ring_id': ring_id},
  264. )
  265. def get_backend(group=None):
  266. """
  267. Get the backend of given group.
  268. Args:
  269. group (Group): The group to work on. Use the global group as default.
  270. Returns:
  271. Returns the name of the given group backend.
  272. Examples:
  273. .. code-block:: python
  274. >>> # doctest: +REQUIRES(env: DISTRIBUTED)
  275. >>> import paddle
  276. >>> paddle.distributed.init_parallel_env()
  277. >>> paddle.distributed.get_backend()
  278. NCCL
  279. """
  280. if _warn_cur_rank_not_in_group(group):
  281. raise RuntimeError("Invalid group specified")
  282. group = _get_global_group() if group is None else group
  283. return group.backend