collective.py 11 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380
  1. # Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. import datetime
  15. import hashlib
  16. import os
  17. import paddle
  18. # (TODO: GhostScreaming) It will be removed later.
  19. from paddle.base import core
  20. from paddle.framework import in_dynamic_mode
  21. from .communication.group import Group, _add_new_group, is_initialized
  22. from .fleet.layers.mpu.mp_ops import ( # noqa: F401
  23. _c_concat,
  24. _c_identity,
  25. _c_lookup_table,
  26. _c_softmax_with_cross_entropy,
  27. _c_split,
  28. _Linear,
  29. _linear,
  30. _mp_allreduce,
  31. _parallel_embedding,
  32. _parallel_linear,
  33. _set_var_distributed,
  34. split,
  35. )
# This module exports no public names through ``from ... import *``.
__all__ = []

# Cached ParallelEnv instance; stays None until _get_global_env() fills it in.
_global_env = None
  38. def _get_global_env():
  39. global _global_env
  40. if not _global_env:
  41. _global_env = paddle.distributed.ParallelEnv()
  42. return _global_env
  43. # group map : the map of all group, 0 for GlobalGroup
  44. # Dict[int, Group]
  45. _group_map = {}
  46. _global_env_gid = 0
  47. # group map by name : the map of all groups from their names
  48. # Dict[name, Group]
  49. _group_map_by_name = {}
  50. # backend map by group : the map of all backend from their groups
  51. # Dict[group, backend]
  52. _group_map_backend = {}
  53. # Name of the default group for init_parallel_env
  54. _default_group_name = "_default_pg"
  55. _valid_backend_list = ['nccl', 'gloo', 'heter', 'xccl', 'bkcl']
  56. _default_store = None # the default tcp store
  57. _default_backend = None
  58. _default_timeout = datetime.timedelta(seconds=1800)
  59. _start_ring_id = 0
  60. def _set_default_backend(backend):
  61. global _default_backend
  62. _default_backend = backend
  63. def _set_default_store(store):
  64. global _default_store
  65. _default_store = store
  66. def _get_group_map():
  67. global _group_map
  68. if _global_env_gid not in _group_map:
  69. genv = _get_global_env()
  70. _group_map[_global_env_gid] = Group(
  71. genv.rank, 0, list(range(genv.world_size))
  72. )
  73. return _group_map
  74. def _get_global_group():
  75. return _get_group_map()[_global_env_gid]
  76. def _get_group_map_by_name():
  77. global _group_map_by_name
  78. return _group_map_by_name
  79. def _get_default_group():
  80. global _group_map_by_name
  81. assert is_initialized(), (
  82. "Call paddle.distributed.init_parallel_env first "
  83. "to initialize the distributed environment."
  84. )
  85. return _get_group_map_by_name()[_default_group_name]
  86. def _set_group_map(gid, group):
  87. global _group_map
  88. assert gid not in _group_map
  89. _group_map[gid] = group
  90. def _set_group_map_by_name(name, group):
  91. global _group_map_by_name
  92. assert name not in _group_map_by_name
  93. _group_map_by_name[name] = group
  94. def _set_group_map_backend(group, backend):
  95. global _group_map_backend
  96. assert group not in _group_map_backend
  97. _group_map_backend[group] = backend
  98. def _new_ring_id():
  99. # NOTE(liyurui): For compatible reason, auto parallel and eager mode relay on previous syntax.
  100. if in_dynamic_mode():
  101. global _start_ring_id
  102. _start_ring_id += 1
  103. return _start_ring_id + max(_get_global_env().nrings, 9)
  104. else:
  105. return len(_get_group_map()) + max(_get_global_env().nrings, 9)
  106. def _new_process_group_impl(
  107. backend,
  108. store,
  109. rank,
  110. world_size,
  111. group_name,
  112. pg_options,
  113. group_id=0,
  114. nccl_comm_init_option=0,
  115. ):
  116. pg = None
  117. genv = _get_global_env()
  118. assert backend in _valid_backend_list, "Unsupported backend: %s." % backend
  119. if backend == "gloo":
  120. pg = core.ProcessGroupGloo.create(store, rank, world_size, group_id)
  121. elif backend == "nccl":
  122. pg = core.ProcessGroupNCCL.create(
  123. store,
  124. rank,
  125. world_size,
  126. group_id,
  127. genv.pg_timeout,
  128. nccl_comm_init_option,
  129. )
  130. elif backend == "xccl":
  131. pg = core.ProcessGroupCustom.create(
  132. store, genv.device_type, rank, world_size, group_id
  133. )
  134. elif backend == "bkcl":
  135. pg = core.ProcessGroupBKCL.create(store, rank, world_size, group_id)
  136. return pg
  137. # _custom_gid provides a way for users to
  138. # set the group id, which is usually useful
  139. # to be compatible with the static graph mode.
  140. _custom_gid = None
  141. def _set_custom_gid(gid):
  142. global _custom_gid
  143. _custom_gid = gid
def new_group(
    ranks=None,
    backend=None,
    timeout=_default_timeout,
    nccl_comm_init_option=0,
):
    """

    Creates a new distributed communication group.

    In dynamic mode the group is backed by a process group created through
    ``_new_process_group_impl`` and is registered in the module's group
    registries. Otherwise (static graph path) a communication ring is
    initialized through a parallel context and the group is registered in
    ``_group_map`` only.

    Args:
        ranks (list, optional): The global ranks of group members. In dynamic
            mode, None means the ranks of the default group (for backends
            other than 'heter'). The static graph path tests membership in
            ``ranks`` directly, so it must not be None there.
        backend (str, optional): The backend used to create group. In dynamic
            mode, None falls back to the default backend (for multi-rank
            groups) and the name must be one accepted by
            ``_new_process_group_impl``. On the static graph path, None means
            'nccl' and only 'nccl' is accepted.
        timeout (datetime.timedelta, optional): The waiting timeout for store relevant options, default is 30 minutes.
            NOTE(review): this argument is not referenced in this function's
            body.
        nccl_comm_init_option (int, optional): Forwarded unchanged to
            ``_new_process_group_impl`` in dynamic mode. Default is 0.

    Returns:
        Group: The group instance.

    Examples:
        .. code-block:: python

            >>> # doctest: +REQUIRES(env: DISTRIBUTED)
            >>> import paddle

            >>> paddle.distributed.init_parallel_env()
            >>> tindata = paddle.randn(shape=[2, 3])
            >>> gp = paddle.distributed.new_group([2, 4, 6])
            >>> paddle.distributed.all_reduce(tindata, group=gp, sync_op=False)

    """
    global _custom_gid
    global _group_map
    if in_dynamic_mode():
        global _default_group_name
        # A truthy custom gid takes precedence over a generated ring id.
        gid = _custom_gid if _custom_gid else _new_ring_id()
        group_name = _default_group_name + str(gid)
        if backend != 'heter' and (ranks is None or len(ranks) > 1):
            global_group = _get_default_group()
            global_rank = global_group.rank
            global_ranks = global_group.ranks
            backend = _default_backend if backend is None else backend
            if ranks is None:
                ranks = global_ranks
            assert len(ranks) <= len(global_ranks), (
                "Size of new group must be less than or "
                "equal to that of the default global group."
            )
        size = len(ranks)
        ranks = sorted(ranks)
        # NOTE(review): ``global_rank`` is only bound inside the branch above.
        # For single-rank groups the ``size > 1`` test short-circuits first;
        # for backend == 'heter' with more than one rank it looks like this
        # would raise NameError - confirm against callers.
        if size > 1 and global_rank in ranks:
            rank = 0 if backend == 'heter' else ranks.index(global_rank)
            pg = _new_process_group_impl(
                backend,
                _default_store,
                rank,
                size,
                group_name,
                pg_options=None,
                group_id=gid,
                nccl_comm_init_option=nccl_comm_init_option,
            )
        else:
            # Not a member (or a single-rank group): no process group is made.
            rank = -1
            pg = None
        group = Group(rank, gid, ranks, pg=pg, name=group_name)
        _group_map_by_name[group_name] = group
        _group_map[gid] = group
        _group_map_backend[group] = backend
        # TODO: The method below is a new method for group management, will replace the previous
        # three in the future.
        _add_new_group(group)

        # When FLAGS_eager_communication_connection is 1, run one all_reduce
        # on the new group right away instead of waiting for first real use.
        if int(os.getenv("FLAGS_eager_communication_connection", 0)) == 1:
            paddle.distributed.all_reduce(
                paddle.zeros([1], dtype=paddle.float32),
                group=group,
                sync_op=True,
            )
        return group

    # Static graph path: only the nccl backend name is accepted here.
    if not backend:
        backend = 'nccl'
    assert backend == 'nccl', "backend other than nccl is not supported yet"

    genv = _get_global_env()
    global_rank = genv.rank

    ring_id = _new_ring_id()

    if global_rank not in ranks:
        # Non-members still register a placeholder group with rank -1.
        gp = Group(-1, ring_id, ranks)
        _group_map[ring_id] = gp
    else:
        ranks = sorted(ranks)
        group_rank = ranks.index(global_rank)
        group_size = len(ranks)
        gp = Group(group_rank, ring_id, ranks)
        _group_map[ring_id] = gp

        if group_size >= 2:
            strategy = core.ParallelStrategy()
            strategy.nranks = group_size
            strategy.local_rank = group_rank
            # Endpoints are selected by indexing with the global ranks.
            strategy.trainer_endpoints = [
                genv.trainer_endpoints[i] for i in ranks
            ]
            strategy.current_endpoint = genv.current_endpoint
            strategy.nrings = 1

            if core.is_compiled_with_cuda():
                place = core.CUDAPlace(genv.device_id)
                core.NCCLParallelContext(strategy, place).init_with_ring_id(
                    ring_id
                )
            elif core.is_compiled_with_xpu():
                place = core.XPUPlace(genv.device_id)
                core.BKCLParallelContext(strategy, place).init_with_ring_id(
                    ring_id
                )
            else:
                raise AssertionError("no cuda device found")
        else:
            # Single-member group: returns before the all_reduce below.
            return gp

    # TODO(shenliang03): This is a temporary solution to solve the problem of
    # hang caused by cross-creation of new_group
    # NOTE(review): the dynamic-mode path returns above, so presumably only
    # the ``paddle.full`` branch of this conditional is ever taken - confirm.
    tmp = (
        paddle.to_tensor([1], dtype="int32")
        if in_dynamic_mode()
        else paddle.full([0], 1, dtype="int32")
    )
    paddle.distributed.all_reduce(tmp, sync_op=True)
    paddle.distributed.wait(tmp)
    return gp
  263. def is_available():
  264. """
  265. Check whether the distributed package is available.
  266. Returns:
  267. Returns True if the distributed package is available, otherwise False.
  268. Examples:
  269. .. code-block:: python
  270. >>> import paddle
  271. >>> print(paddle.distributed.is_available())
  272. """
  273. return core.is_compiled_with_dist()
  274. def _init_parallel_env(backend):
  275. store = core.create_or_get_global_tcp_store()
  276. global_env = _get_global_env()
  277. rank = global_env.rank
  278. world_size = global_env.world_size
  279. dev_id = global_env.device_id
  280. if backend == "gloo":
  281. core.CommContextManager.create_gloo_comm_context(
  282. store, "0", rank, world_size
  283. )
  284. elif backend == "nccl":
  285. endpoints_str = ""
  286. for endpoint in global_env.trainer_endpoints:
  287. endpoints_str += endpoint
  288. endpoints_str += "ring_id:{}".format("0")
  289. endpoints_str_hash = hashlib.md5(
  290. endpoints_str.encode(encoding='UTF-8')
  291. ).hexdigest()
  292. core.CommContextManager.set_device_id(dev_id)
  293. core.CommContextManager.create_nccl_comm_context(
  294. store, "0", rank, world_size, endpoints_str_hash
  295. )
  296. elif backend == "xccl":
  297. dev_type = global_env.device_type
  298. paddle.device.set_device(f"{dev_type}:{dev_id}")
  299. core.CommContextManager.create_xccl_comm_context(
  300. store, "0", rank, world_size, dev_type
  301. )
  302. elif backend == "bkcl":
  303. endpoints_str = ""
  304. for endpoint in global_env.trainer_endpoints:
  305. endpoints_str += endpoint
  306. endpoints_str += "ring_id:{}".format("0")
  307. endpoints_str_hash = hashlib.md5(
  308. endpoints_str.encode(encoding='UTF-8')
  309. ).hexdigest()
  310. core.CommContextManager.set_device_id(dev_id)
  311. core.CommContextManager.create_bkcl_comm_context(
  312. store, "0", rank, world_size, endpoints_str_hash
  313. )