parallel.py 47 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156115711581159116011611162116311641165116611671168116911701171117211731174117511761177117811791180118111821183118411851186118711881189119011911192119311941195119611971198119912001201120212031204120512061207120812091210121112121213121412151216121712181219122012211222122312241225122612271228122912301231123212331234123512361237123812391240124112421243124412451246124712481249125012511252125312541255125612571258125912601261126212631264126512661267126812691270127112721273127412751276127
7127812791280128112821283128412851286128712881289
  1. # Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. import itertools
  15. import os
  16. import sys
  17. import time
  18. import warnings
  19. from collections import OrderedDict, namedtuple
  20. from contextlib import contextmanager
  21. from multiprocessing import Manager, Process
  22. import numpy as np
  23. import paddle
  24. from paddle import _legacy_C_ops, framework
  25. from paddle.distributed.collective import (
  26. Group,
  27. _default_group_name,
  28. _get_group_map_by_name,
  29. _new_process_group_impl,
  30. _set_default_backend,
  31. _set_default_store,
  32. _set_group_map,
  33. _set_group_map_backend,
  34. _set_group_map_by_name,
  35. _valid_backend_list,
  36. )
  37. from paddle.distributed.communication.group import (
  38. _add_new_group,
  39. _get_global_group,
  40. is_initialized,
  41. )
  42. from paddle.distributed.fleet.base.private_helper_function import (
  43. wait_server_ready,
  44. )
  45. from paddle.distributed.fleet.launch_utils import check_backend
  46. # (TODO: GhostScreaming) It will be removed later.
  47. from paddle.framework import (
  48. _set_expected_place,
  49. base as imperative_base,
  50. core,
  51. in_dynamic_mode,
  52. )
  53. from paddle.nn.layer import layers
  54. from paddle.utils import deprecated
  55. from . import parallel_helper
  56. from .backup_env import getenv_or_backup
# Empty: ``from <this module> import *`` exports no names.
__all__ = []

# Module-level alias for the C++-backed ``core.ParallelStrategy`` type; it is
# instantiated by ``_build_default_parallel_strategy`` below.
ParallelStrategy = core.ParallelStrategy
  59. def _build_default_parallel_strategy():
  60. strategy = ParallelStrategy()
  61. strategy.nranks = paddle.distributed.ParallelEnv().nranks
  62. strategy.local_rank = paddle.distributed.ParallelEnv().local_rank
  63. strategy.trainer_endpoints = (
  64. paddle.distributed.ParallelEnv().trainer_endpoints
  65. )
  66. strategy.current_endpoint = (
  67. paddle.distributed.ParallelEnv().current_endpoint
  68. )
  69. return strategy
  70. def _coalesce_tensors(var_groups):
  71. coalesced_grads_and_grad_vars = []
  72. for group_id, grad_vars in var_groups.items():
  73. flattened_vars = []
  74. g_var_shapes = []
  75. for g_var in grad_vars:
  76. g_var_shapes.append(g_var.shape)
  77. flattened_vars.append(
  78. paddle.reshape(x=g_var, shape=[np.prod(g_var.shape)])
  79. )
  80. coalesced_grad = paddle.concat(flattened_vars)
  81. coalesced_grads_and_grad_vars.append(
  82. [coalesced_grad, grad_vars, g_var_shapes]
  83. )
  84. return coalesced_grads_and_grad_vars
  85. @framework.dygraph_only
  86. def _reshape_inplace(x, shape):
  87. x_shape = framework._create_tensor(dtype=x.dtype)
  88. framework._dygraph_tracer().trace_op(
  89. type="reshape2",
  90. inputs={'X': x},
  91. outputs={'Out': x, 'XShape': x_shape},
  92. attrs={'shape': shape},
  93. )
  94. @framework.dygraph_only
  95. def _split_tensors(coalesced_grads_and_grad_vars):
  96. if in_dynamic_mode():
  97. for (
  98. coalesced_grad,
  99. origin_grad_vars,
  100. grad_shapes,
  101. ) in coalesced_grads_and_grad_vars:
  102. grad_var_len = [np.prod(g_shape) for g_shape in grad_shapes]
  103. attrs = ()
  104. attrs += ('sections', grad_var_len)
  105. attrs += ('axis', 0)
  106. _legacy_C_ops.split(coalesced_grad, origin_grad_vars, *attrs)
  107. for g_var, g_shape in zip(origin_grad_vars, grad_shapes):
  108. g_var.reshape_(shape=g_shape)
  109. assert g_var.shape == g_shape
  110. @imperative_base.no_grad
  111. @framework.dygraph_only
  112. def build_groups(vars, group_size):
  113. group_idx = 0
  114. memory_counter = 0
  115. var_groups = OrderedDict()
  116. dtype = vars[0].dtype
  117. for var in vars:
  118. bytes = np.prod(var.shape) * core.size_of_dtype(var.dtype)
  119. if memory_counter < group_size and dtype == var.dtype:
  120. memory_counter += bytes
  121. else:
  122. memory_counter = bytes
  123. dtype = var.dtype
  124. group_idx += 1
  125. var_groups.setdefault(group_idx, []).append(var)
  126. return _coalesce_tensors(var_groups)
  127. @imperative_base.no_grad
  128. @framework.dygraph_only
  129. def sync_params_buffers(
  130. model,
  131. comm_group=None,
  132. src_rank=0,
  133. is_model_parallel=False,
  134. fuse_params=True,
  135. ):
  136. model_vars = []
  137. for _, param in model._obtain_parameters_buffers().items():
  138. if not isinstance(param, core.eager.Tensor):
  139. raise TypeError(
  140. "The data type of '%s' must be core.eager.Tensor" % param.name
  141. )
  142. if is_model_parallel:
  143. if hasattr(param, "is_distributed") and param.is_distributed:
  144. continue
  145. # NOTE(shenliang03): Support situations that do not require synchronization parameters,
  146. # such as moe's expert parameters
  147. if getattr(param, "no_sync", False):
  148. continue
  149. if param.type == core.VarDesc.VarType.VOCAB:
  150. continue
  151. model_vars.append(param.detach())
  152. if len(model_vars) == 0:
  153. return
  154. if fuse_params:
  155. # group size is 128M
  156. coalesced_vars = build_groups(model_vars, 128 * 1024 * 1024)
  157. for coalesced_var, _, _ in coalesced_vars:
  158. paddle.distributed.broadcast(
  159. coalesced_var, src=src_rank, group=comm_group, sync_op=True
  160. )
  161. for coalesced_var, origin_vars, var_shapes in coalesced_vars:
  162. var_len = [np.prod(v_shape) for v_shape in var_shapes]
  163. paddle.base.framework._dygraph_tracer().trace_op(
  164. type='split',
  165. inputs={'X': coalesced_var},
  166. outputs={'Out': origin_vars},
  167. attrs={'sections': var_len, 'axis': 0},
  168. )
  169. else:
  170. for var in model_vars:
  171. paddle.distributed.broadcast(
  172. var, src=src_rank, group=comm_group, sync_op=True
  173. )
  174. class DataParallel(layers.Layer):
  175. """
  176. Run the dygraph module with data parallelism.
  177. Currently, DataParallel class only supports to run the dynamic graph
  178. with multi-process.
  179. Now supports two ways to start training:
  180. 1. start by ``paddle.distributed.spawn`` method, for example:
  181. ``python demo.py`` (spawn need to be called in ``__main__`` method)
  182. 2. start by ``paddle.distributed.launch`` module, for example:
  183. ``python -m paddle.distributed.launch --gpus=0,1 demo.py`` .
  184. And the content of `demo.py` is the code of examples.
  185. Args:
  186. layers(Layer): The module that should be executed by data parallel.
  187. strategy(ParallelStrategy, optional): (deprecated) The strategy of data parallelism,
  188. contains environment configuration related to parallel execution. Default: None.
  189. comm_buffer_size(int, optional): It limits the memory size(MB) of one buffer
  190. parameters' gradient which is the input of communication
  191. calling(e.g NCCLAllReduce). Default: 25.
  192. last_comm_buffer_size(float, optional): It limits memory size(MB) of last buffer in communication
  193. calling. Making the last communication buffer size small is useful to
  194. improve performance. Default: 1.
  195. find_unused_parameters(bool, optional): Whether to traverse the entire backward graph from the
  196. all tensors in the return value of the wrapped model's
  197. forward function. For parameters not involved in loss
  198. calculation, their gradients will be marked as ready in
  199. advance to prepare reduce. Please note that all forward
  200. outputs derived from the wrapped model parameters must
  201. participate in the calculation of loss and subsequent
  202. gradient calculations. If not, serious error will occur.
  203. Note that setting the find_unused_parameters to True
  204. will affect computing performance. Therefore, if all parameters
  205. are sure to participate in the loss calculation and the
  206. autograd graph construction, please set it False. Default: False.
  207. Returns:
  208. Layer: The data paralleled module.
  209. Examples:
  210. .. code-block:: python
  211. :name: dp-example
  212. >>> # doctest: +REQUIRES(env:DISTRIBUTED)
  213. >>> import paddle
  214. >>> import paddle.nn as nn
  215. >>> import paddle.optimizer as opt
  216. >>> import paddle.distributed as dist
  217. >>> class LinearNet(nn.Layer):
  218. ... def __init__(self):
  219. ... super().__init__()
  220. ... self._linear1 = nn.Linear(10, 10)
  221. ... self._linear2 = nn.Linear(10, 1)
  222. ... def forward(self, x):
  223. ... return self._linear2(self._linear1(x))
  224. >>> def train():
  225. ... # 1. initialize parallel environment
  226. ... dist.init_parallel_env()
  227. ... # 2. create data parallel layer & optimizer
  228. ... layer = LinearNet()
  229. ... dp_layer = paddle.DataParallel(layer)
  230. ... loss_fn = nn.MSELoss()
  231. ... adam = opt.Adam(
  232. ... learning_rate=0.001, parameters=dp_layer.parameters())
  233. ... # 3. run layer
  234. ... inputs = paddle.randn([10, 10], 'float32')
  235. ... outputs = dp_layer(inputs)
  236. ... labels = paddle.randn([10, 1], 'float32')
  237. ... loss = loss_fn(outputs, labels)
  238. ... loss.backward()
  239. ... adam.step()
  240. ... adam.clear_grad()
  241. >>> if __name__ == '__main__':
  242. ... # 1. start by ``paddle.distributed.spawn`` (default)
  243. ... dist.spawn(train, nprocs=2)
  244. ... # 2. start by ``paddle.distributed.launch``
  245. ... # train()
  246. .. note::
  247. ``PyLayer`` is not supported in DataParallel. To solve problems of this kind,
  248. it's recommended to skip gradient synchronization among multiple cards by 'no_sync',
  249. and manually implement 'all_reduce' before model optimization. There is an example
  250. showing specific implementation processing.
  251. Examples:
  252. .. code-block:: python
  253. :name: dp-pylayer-example
  254. >>> # doctest: +REQUIRES(env:DISTRIBUTED)
  255. >>> import numpy
  256. >>> import paddle
  257. >>> import paddle.distributed as dist
  258. >>> from paddle.autograd import PyLayer
  259. >>> from paddle.distributed.fleet.utils.hybrid_parallel_util import fused_allreduce_gradients
  260. >>> class cus_tanh(PyLayer):
  261. ... @staticmethod
  262. ... def forward(ctx, x):
  263. ... y = paddle.tanh(x)
  264. ... ctx.save_for_backward(y)
  265. ... return y
  266. ... @staticmethod
  267. ... def backward(ctx, dy):
  268. ... y, = ctx.saved_tensor()
  269. ... grad = dy * (1 - paddle.square(y))
  270. ... return grad
  271. >>> class SimpleNet(paddle.nn.Layer):
  272. ... def __init__(self):
  273. ... super().__init__()
  274. ... self.linear = paddle.nn.Linear(2, 2)
  275. ... def forward(self, inputs):
  276. ... inputs = cus_tanh.apply(inputs)
  277. ... return self.linear(inputs)
  278. >>> if __name__ == '__main__':
  279. ... dist.init_parallel_env()
  280. ... model = SimpleNet()
  281. ... model = paddle.DataParallel(model)
  282. ... opt = paddle.optimizer.SGD(learning_rate=0.01, parameters=model.parameters())
  283. ... for step in range(10):
  284. ... x_data = numpy.random.randn(2, 2).astype(numpy.float32)
  285. ... x = paddle.to_tensor(x_data)
  286. ... x.stop_gradient = False
  287. ... # step 1 : skip gradient synchronization by 'no_sync'
  288. ... with model.no_sync():
  289. ... y_pred = model(x)
  290. ... loss = y_pred.mean()
  291. ... loss.backward()
  292. ... # step 2 : fuse + allreduce manually before optimization
  293. ... fused_allreduce_gradients(list(model.parameters()), None)
  294. ... opt.step()
  295. ... opt.clear_grad()
  296. """
  297. def __init__(
  298. self,
  299. layers,
  300. strategy=None,
  301. comm_buffer_size=25,
  302. last_comm_buffer_size=1,
  303. find_unused_parameters=False,
  304. group=None,
  305. ):
  306. super().__init__(layers.full_name() + "_data_parallel")
  307. assert (
  308. in_dynamic_mode()
  309. ), "It's not supported to construct DataParallel in static graph mode."
  310. self._layers = layers
  311. self.find_unused_parameters = find_unused_parameters
  312. self.grad_need_sync = True
  313. self.group = group
  314. self.var_dtype = core.eager.Tensor
  315. # NOTE(chenweihang): The ParallelStrategy here is not strictly a strategy.
  316. # It just stores some environment variables, which can be constructed by
  317. # ParallelEnv. Here it is set as an optional argument.
  318. # This parameter is not removed because of compatibility with 1.x writing.
  319. if strategy is not None:
  320. self._strategy = strategy
  321. else:
  322. self._strategy = _build_default_parallel_strategy()
  323. if self._strategy.nranks > 1:
  324. # check the environment
  325. assert parallel_helper.__parallel_ctx__clz__ is not None, (
  326. "ParallelContext must be initialized before. You should use init_parallel_env() before"
  327. "constructing the DataParallel."
  328. )
  329. if in_dynamic_mode():
  330. self.group = (
  331. paddle.distributed.collective._get_default_group()
  332. if self.group is None
  333. else self.group
  334. )
  335. assert isinstance(
  336. self.group, paddle.distributed.collective.Group
  337. ), "ProcessGroup must be an instance of Group in DataParallel."
  338. # sync buffer and params
  339. sync_params_buffers(self._layers, fuse_params=False)
  340. self.comm_buffer_size = int(comm_buffer_size * 1024 * 1024)
  341. # NOTE(shenliang03): We can set environment variables to control
  342. # the size of the group, Default: 1MB. The role of this small group is:
  343. # when the last group allreduce, the overlap cannot work. Making the
  344. # the last group small is useful to improve performance.
  345. self.last_comm_buffer_size = int(
  346. last_comm_buffer_size * 1024 * 1024
  347. )
  348. self.init_reducer()
  349. else:
  350. warnings.warn(
  351. "The program will return to single-card operation. "
  352. "Please check 1, whether you use spawn or fleetrun "
  353. "to start the program. 2, Whether it is a multi-card "
  354. "program. 3, Is the current environment multi-card."
  355. )
  356. def init_reducer(self):
  357. layers_param = []
  358. params_set = set()
  359. for sublayer in self.sublayers():
  360. for _, param in sublayer.named_parameters(include_sublayers=False):
  361. if param is None or param in params_set:
  362. continue
  363. params_set.add(param)
  364. if not isinstance(param, self.var_dtype):
  365. raise TypeError(
  366. f"The data type of '{param.name}' must be '{self.var_dtype}'"
  367. )
  368. if param.trainable:
  369. layers_param.append((sublayer, param))
  370. trainable_parameters = list(
  371. filter(
  372. lambda x: not getattr(x, "no_sync", False),
  373. [param for _, param in layers_param],
  374. )
  375. )
  376. assert len(trainable_parameters) > 0, (
  377. "This model does not have any parameters to train, and "
  378. "does not need to use DataParallel"
  379. )
  380. # NOTE(shenliang03): Here we can only use the attributes to judge whether
  381. # parameter is sparse(or SelectedRows). The reason is that the sparse message
  382. # can't be obtained when bp hasn't happened yet. So if layer supports sparse parameter,
  383. # we should add the layer here like "paddle.nn.layer.common.Embedding".
  384. def check_layer_sparse(sublayer):
  385. if isinstance(sublayer, paddle.nn.layer.common.Embedding):
  386. return sublayer._sparse
  387. return False
  388. is_sparse_gradient = [
  389. check_layer_sparse(sublayer)
  390. for sublayer, param in layers_param
  391. if not getattr(param, "no_sync", False)
  392. ]
  393. if in_dynamic_mode():
  394. self.group_indices = core.eager_assign_group_by_size(
  395. trainable_parameters,
  396. is_sparse_gradient,
  397. [self.last_comm_buffer_size, self.comm_buffer_size],
  398. )
  399. self._reducer = core.EagerReducer(
  400. trainable_parameters,
  401. list(reversed(self.group_indices)),
  402. is_sparse_gradient,
  403. self.group.process_group,
  404. [self.last_comm_buffer_size, self.comm_buffer_size],
  405. self.find_unused_parameters,
  406. )
  407. def _find_tensor(self, obj):
  408. var_type = core.eager.Tensor
  409. if isinstance(obj, var_type):
  410. return [obj]
  411. if isinstance(obj, (list, tuple)):
  412. return itertools.chain(*map(self._find_tensor, obj))
  413. if isinstance(obj, dict):
  414. return itertools.chain(*map(self._find_tensor, obj.values()))
  415. return []
  416. @contextmanager
  417. def no_sync(self):
  418. """
  419. A context manager to stop gradient synchronization. Within no_sync(),
  420. gradients of parameters will only be accumulated on model and not
  421. synchronized util the first forward-backward out of this context.
  422. Examples:
  423. .. code-block:: python
  424. >>> # doctest: +REQUIRES(env:DISTRIBUTED)
  425. >>> import paddle
  426. >>> import paddle.nn as nn
  427. >>> import paddle.distributed as dist
  428. >>> class SimpleNet(nn.Layer):
  429. ... def __init__(self):
  430. ... super().__init__()
  431. ... self._linear = nn.Linear(10, 1)
  432. ... def forward(self, x):
  433. ... return self._linear(x)
  434. >>> dist.init_parallel_env()
  435. >>> model = SimpleNet()
  436. >>> dp_model = paddle.DataParallel(model)
  437. >>> inputs_1 = paddle.randn([10, 10], 'float32')
  438. >>> inputs_2 = paddle.ones([10, 10], 'float32')
  439. >>> with dp_model.no_sync():
  440. ... # gradients will not be synchronized
  441. ... dp_model(inputs_1).backward()
  442. >>> # synchronization happens here
  443. >>> dp_model(inputs_2).backward()
  444. """
  445. tmp_grad_need_sync = self.grad_need_sync
  446. self.grad_need_sync = False
  447. try:
  448. yield
  449. finally:
  450. self.grad_need_sync = tmp_grad_need_sync
  451. def forward(self, *inputs, **kwargs):
  452. outputs = self._layers(*inputs, **kwargs)
  453. if (
  454. self._strategy.nranks > 1
  455. and framework._dygraph_tracer()._has_grad
  456. and self.grad_need_sync
  457. ):
  458. self._reducer.prepare_for_backward(list(self._find_tensor(outputs)))
  459. return outputs
  460. @deprecated(
  461. since="2.0.0", reason="This method does not need to be called anymore."
  462. )
  463. def scale_loss(self, loss):
  464. """
  465. Deprecated method, now ``scale_loss`` is an empty method,
  466. keep this method just for compatibility.
  467. """
  468. return loss
  469. @deprecated(
  470. since="2.0.0", reason="This method does not need to be called anymore."
  471. )
  472. def apply_collective_grads(self):
  473. """
  474. Deprecated method, now ``apply_collective_grads`` is an empty method,
  475. keep this method just for compatibility.
  476. """
  477. return
  478. def state_dict(
  479. self,
  480. destination=None,
  481. include_sublayers=True,
  482. structured_name_prefix="",
  483. ):
  484. '''
  485. Get all parameters and persistable buffers of current layer and its sub-layers. And set them into a dict
  486. Parameters:
  487. destination(dict, optional) : If provide, all the parameters and persistable buffers will be set to this dict . Default: None
  488. include_sublayers(bool, optional) : If true, also include the parameters and persistable buffers from sublayers. Default: True
  489. Returns:
  490. dict: a dict contains all the parameters and persistable buffers.
  491. Examples:
  492. .. code-block:: python
  493. >>> # doctest: +REQUIRES(env:DISTRIBUTED)
  494. >>> import paddle
  495. >>> import paddle.distributed as dist
  496. >>> dist.init_parallel_env()
  497. >>> emb = paddle.nn.Embedding(10, 10)
  498. >>> emb = paddle.DataParallel(emb)
  499. >>> state_dict = emb.state_dict()
  500. >>> paddle.save(state_dict, "paddle_dy.pdparams")
  501. '''
  502. return self._layers.state_dict(
  503. destination=destination,
  504. include_sublayers=include_sublayers,
  505. structured_name_prefix=structured_name_prefix,
  506. )
  507. @framework.deprecate_stat_dict
  508. def set_state_dict(self, state_dict, use_structured_name=True):
  509. '''
  510. Set parameters and persistable buffers from state_dict. All the parameters and buffers will be reset by the tensor in the state_dict
  511. Parameters:
  512. state_dict(dict) : Dict contains all the parameters and persistable buffers.
  513. use_structured_name(bool, optional) : If true, use structured name as key, otherwise, use parameter or buffer name as key.
  514. Default: True
  515. Returns:
  516. None
  517. Examples:
  518. .. code-block:: python
  519. >>> # doctest: +REQUIRES(env:DISTRIBUTED)
  520. >>> import paddle
  521. >>> import paddle.distributed as dist
  522. >>> dist.init_parallel_env()
  523. >>> emb = paddle.nn.Embedding(10, 10)
  524. >>> emb = paddle.DataParallel(emb)
  525. >>> state_dict = emb.state_dict()
  526. >>> paddle.save(state_dict, "paddle_dy.pdparams")
  527. >>> para_state_dict = paddle.load("paddle_dy.pdparams")
  528. >>> emb.set_state_dict(para_state_dict)
  529. '''
  530. self._layers.set_state_dict(
  531. state_dict, use_structured_name=use_structured_name
  532. )
  533. # [aliases] Compatible with old method names
  534. set_dict = set_state_dict
  535. load_dict = set_state_dict
# NOTE(chenweihang): Maintain a global parallel env to avoid
# initializing ParallelEnv every time and improve performance
# Module-level cache populated lazily by ``_get_global_parallel_env``;
# stays None until that function is first called.
_global_parallel_env = None
  539. class ParallelEnv:
  540. """
  541. .. note::
  542. This API is not recommended, if you need to get rank and world_size,
  543. it is recommended to use ``paddle.distributed.get_rank()`` and
  544. ``paddle.distributed.get_world_size()`` .
  545. This class is used to obtain the environment variables required for
  546. the parallel execution of ``paddle.nn.Layer`` in dynamic mode.
  547. The parallel execution in dynamic mode needs to be started using ``paddle.distributed.launch``
  548. or ``paddle.distributed.spawn`` .
  549. Examples:
  550. .. code-block:: python
  551. >>> # doctest: +REQUIRES(env:DISTRIBUTED)
  552. >>> import paddle
  553. >>> import paddle.distributed as dist
  554. >>> def train():
  555. ... # 1. initialize parallel environment
  556. ... dist.init_parallel_env()
  557. ... # 2. get current ParallelEnv
  558. ... parallel_env = dist.ParallelEnv()
  559. ... print("rank: ", parallel_env.rank)
  560. ... print("world_size: ", parallel_env.world_size)
  561. >>> if __name__ == '__main__':
  562. ... # 1. start by ``paddle.distributed.spawn`` (default)
  563. ... dist.spawn(train, nprocs=2)
  564. ... # 2. start by ``paddle.distributed.launch``
  565. ... train()
  566. # Print result in process 1:
  567. rank: 1
  568. world_size: 2
  569. # Print result in process 2:
  570. rank: 2
  571. world_size: 2
  572. """
  573. def __init__(self):
  574. self._rank = int(os.getenv("PADDLE_TRAINER_ID", "0"))
  575. self._world_size = int(os.getenv("PADDLE_TRAINERS_NUM", "1"))
  576. self._device_type = str(os.getenv("PADDLE_XCCL_BACKEND", ""))
  577. self._pg_timeout = int(os.getenv("PADDLE_PG_TIMEOUT", "1800000"))
  578. # imperative only support one gpu or xpu
  579. if self._device_type != "":
  580. FLAGS_selected_custom_devices = (
  581. f'FLAGS_selected_{self._device_type}s'
  582. )
  583. selected_custom_devices = os.getenv(
  584. FLAGS_selected_custom_devices, "0"
  585. ).split(",")
  586. self._device_id = int(selected_custom_devices[0])
  587. else:
  588. if core.is_compiled_with_cuda():
  589. selected_gpus = os.getenv("FLAGS_selected_gpus", "0").split(",")
  590. self._device_id = int(selected_gpus[0])
  591. elif core.is_compiled_with_xpu():
  592. selected_xpus = os.getenv("FLAGS_selected_xpus", "0").split(",")
  593. self._device_id = int(selected_xpus[0])
  594. self._trainer_endpoints = getenv_or_backup(
  595. "PADDLE_TRAINER_ENDPOINTS", ""
  596. ).split(",")
  597. self._current_endpoint = os.getenv("PADDLE_CURRENT_ENDPOINT", "")
  598. self._nrings = int(os.getenv("FLAGS_nccl_nrings", "1"))
  599. assert (
  600. self._nrings > 0
  601. ), "nccl_nrings must be an integer greater than 0."
  602. assert (
  603. self._nrings < 9
  604. ), "nccl_nrings should be less than 9, which is enough in most scenarios."
  605. @property
  606. def rank(self):
  607. """
  608. Rank of current trainer.
  609. Its value is equal to the value of the environment variable ``PADDLE_TRAINER_ID`` . The default value is 0.
  610. Examples:
  611. .. code-block:: python
  612. >>> # doctest: +REQUIRES(env:DISTRIBUTED)
  613. >>> # execute this command in terminal: export PADDLE_TRAINER_ID=0
  614. >>> import paddle.distributed as dist
  615. >>> env = dist.ParallelEnv()
  616. >>> print("The rank is %d" % env.rank)
  617. The rank is 0
  618. """
  619. return self._rank
  620. @property
  621. def world_size(self):
  622. """
  623. The number of trainers (number of processes participating in current job).
  624. Its value is equal to the value of the environment variable ``PADDLE_TRAINERS_NUM`` . The default value is 1.
  625. Examples:
  626. .. code-block:: python
  627. >>> # doctest: +REQUIRES(env:DISTRIBUTED)
  628. >>> # execute this command in terminal: export PADDLE_TRAINERS_NUM=4
  629. >>> import paddle.distributed as dist
  630. >>> env = dist.ParallelEnv()
  631. >>> print("The world_size is %d" % env.world_size)
  632. The world_size is 4
  633. """
  634. return self._world_size
  635. @property
  636. def device_id(self):
  637. """
  638. The ID of selected GPU card for parallel training.
  639. Its value is equal to the value of the environment variable ``FLAGS_selected_gpus`` . The default value is 0.
  640. Examples:
  641. .. code-block:: python
  642. >>> # doctest: +REQUIRES(env:DISTRIBUTED)
  643. >>> # execute this command in terminal: export FLAGS_selected_gpus=1
  644. >>> import paddle.distributed as dist
  645. >>> env = dist.ParallelEnv()
  646. >>> print("The device id are %d" % env.device_id)
  647. The device id are 1
  648. """
  649. return self._device_id
  650. @property
  651. def device_type(self):
  652. """
  653. The type of custom device for parallel training.
  654. Its value is equal to the value of the environment variable ``PADDLE_XCCL_BACKEND`` . The default value is None.
  655. """
  656. return self._device_type
  657. @property
  658. def current_endpoint(self):
  659. """
  660. The endpoint of current trainer, it is in the form of (node IP + port).
  661. Its value is equal to the value of the environment variable ``PADDLE_CURRENT_ENDPOINT`` . The default value is "".
  662. Examples:
  663. .. code-block:: python
  664. >>> # doctest: +REQUIRES(env:DISTRIBUTED)
  665. >>> # execute this command in terminal: export PADDLE_CURRENT_ENDPOINT=127.0.0.1:6170
  666. >>> import paddle.distributed as dist
  667. >>> env = dist.ParallelEnv()
  668. >>> print("The current endpoint are %s" % env.current_endpoint)
  669. The current endpoint are 127.0.0.1:6170
  670. """
  671. return self._current_endpoint
  672. @property
  673. def trainer_endpoints(self):
  674. """
  675. The endpoints of all trainer nodes in the task,
  676. which are used to broadcast the NCCL ID when NCCL2 is initialized.
  677. Its value is equal to the value of the environment variable ``PADDLE_TRAINER_ENDPOINTS`` . The default value is "".
  678. Examples:
  679. .. code-block:: python
  680. >>> # doctest: +REQUIRES(env:DISTRIBUTED)
  681. >>> # execute this command in terminal: export PADDLE_TRAINER_ENDPOINTS=127.0.0.1:6170,127.0.0.1:6171
  682. >>> import paddle.distributed as dist
  683. >>> env = dist.ParallelEnv()
  684. >>> print("The trainer endpoints are %s" % env.trainer_endpoints)
  685. The trainer endpoints are ['127.0.0.1:6170', '127.0.0.1:6171']
  686. """
  687. return self._trainer_endpoints
  688. @property
  689. def nrings(self):
  690. """
  691. Nrings of current trainer.
  692. Its value is equal to the value of the environment variable ``FLAGS_nccl_nrings`` . The default value is 1.
  693. Examples:
  694. .. code-block:: python
  695. >>> # doctest: +REQUIRES(env:DISTRIBUTED)
  696. >>> # execute this command in terminal: export FLAGS_nccl_nrings=1
  697. >>> import paddle.distributed as dist
  698. >>> env = dist.ParallelEnv()
  699. >>> print("The nrings is %d" % env.nrings)
  700. The nrings is 1
  701. """
  702. return self._nrings
  703. @property
  704. def pg_timeout(self):
  705. """
  706. timeout of process group.
  707. Its value is equal to the value of the environment variable ``PADDLE_PG_TIMEOUT`` . The default value is 30 minutes.
  708. Examples:
  709. .. code-block:: python
  710. >>> # execute this command in terminal: export PADDLE_PG_TIMEOUT=1800000
  711. >>> import paddle.distributed as dist
  712. >>> env = dist.ParallelEnv()
  713. >>> # the pg_timeout of process group 1800000
  714. """
  715. return self._pg_timeout
  716. # [aliases] Compatible with old method names
  717. local_rank = rank
  718. nranks = world_size
  719. dev_id = device_id
  720. def _get_global_parallel_env():
  721. global _global_parallel_env
  722. if _global_parallel_env is None:
  723. _global_parallel_env = ParallelEnv()
  724. return _global_parallel_env
  725. def _start_kv_server(port, http_server_d, size):
  726. from paddle.distributed.fleet.utils.http_server import KVServer
  727. http_server = KVServer(int(port), size=size)
  728. http_server.start()
  729. wait_seconds = 3
  730. while http_server_d.get("running", False) or not http_server.should_stop():
  731. time.sleep(wait_seconds)
  732. http_server.stop()
  733. def _is_cpuonly(backend):
  734. check_backend(backend)
  735. if (
  736. backend in ['auto', 'nccl', 'bkcl', 'heter']
  737. and (core.is_compiled_with_cuda() or core.is_compiled_with_xpu())
  738. ) or backend == 'xccl':
  739. # passes 'auto' and can use cuda or xpu, use the default logics. so return False
  740. return False
  741. else:
  742. return True
  743. def _check_var_exists(var_name):
  744. var = getenv_or_backup(var_name, None)
  745. if var is None:
  746. raise ValueError(
  747. "paddle.distributed initialize error, "
  748. "environment variable %s is needed, but not set." % var_name
  749. )
  750. def _get_modified_flags():
  751. ret = []
  752. FLAGS = namedtuple('FLAGS', ['name', 'current_value', 'default_value'])
  753. global_flags = core.globals()
  754. for key in global_flags.keys():
  755. value = global_flags.get(key)
  756. default_value = global_flags.get_default(key)
  757. if not value == default_value:
  758. ret.append(FLAGS(key, value, default_value))
  759. return ret
  760. def _print_modified_flags(modified_flags):
  761. if len(modified_flags) > 0:
  762. sys.stderr.write(
  763. "======================= Modified FLAGS detected =======================\n"
  764. )
  765. for flag in modified_flags:
  766. sys.stderr.write(str(flag))
  767. sys.stderr.write("\n")
  768. sys.stderr.write(
  769. "=======================================================================\n"
  770. )
  771. def init_parallel_env():
  772. """
  773. Initialize parallel training environment in dynamic graph mode.
  774. Note:
  775. Now initialize both `NCCL` and `GLOO` contexts for communication.
  776. Args:
  777. backend (string): A string represents the backend used by DataParallel,
  778. should be one of 'gloo'(for cpu), 'nccl'(for cuda), 'bkcl'(for xpu), 'auto'(auto detect).
  779. The auto detection prefer 'nccl', 'bkcl' than 'gloo'.
  780. Returns:
  781. None
  782. Examples:
  783. .. code-block:: python
  784. >>> # doctest: +REQUIRES(env:GPU, env:DISTRIBUTED)
  785. >>> import paddle
  786. >>> import paddle.nn as nn
  787. >>> import paddle.optimizer as opt
  788. >>> import paddle.distributed as dist
  789. >>> class LinearNet(nn.Layer):
  790. ... def __init__(self):
  791. ... super().__init__()
  792. ... self._linear1 = nn.Linear(10, 10)
  793. ... self._linear2 = nn.Linear(10, 1)
  794. ... def forward(self, x):
  795. ... return self._linear2(self._linear1(x))
  796. >>> def train():
  797. ... # 1. initialize parallel environment
  798. ... dist.init_parallel_env()
  799. ... # 2. create data parallel layer & optimizer
  800. ... layer = LinearNet()
  801. ... dp_layer = paddle.DataParallel(layer)
  802. ... loss_fn = nn.MSELoss()
  803. ... adam = opt.Adam(
  804. ... learning_rate=0.001, parameters=dp_layer.parameters())
  805. ... # 3. run layer
  806. ... inputs = paddle.randn([10, 10], 'float32')
  807. ... outputs = dp_layer(inputs)
  808. ... labels = paddle.randn([10, 1], 'float32')
  809. ... loss = loss_fn(outputs, labels)
  810. ... loss.backward()
  811. ... adam.step()
  812. ... adam.clear_grad()
  813. >>> if __name__ == '__main__':
  814. ... dist.spawn(train)
  815. """
  816. modified_flags = _get_modified_flags()
  817. _print_modified_flags(modified_flags)
  818. # 0. get env & check world size
  819. global _global_parallel_env
  820. # when call init_parallel_env, need update `_global_parallel_env`
  821. _global_parallel_env = ParallelEnv()
  822. parallel_env = _global_parallel_env
  823. # if not parallel, `init_parallel_env` do nothing
  824. if parallel_env.world_size < 2:
  825. warnings.warn(
  826. "Currently not a parallel execution environment, `paddle.distributed.init_parallel_env` will not do anything."
  827. )
  828. return
  829. # NOTE(xiongkun): support cpu gloo only, add this environment variable to
  830. # enable cpu only gloo parallel training)
  831. backend = os.environ.get('PADDLE_DISTRI_BACKEND', 'auto')
  832. is_cpu_only = _is_cpuonly(backend)
  833. # 1. gpu xpu check, must be gpu or xpu,
  834. if not (
  835. is_cpu_only
  836. or core.is_compiled_with_cuda()
  837. or core.is_compiled_with_xpu()
  838. or backend == "xccl"
  839. ):
  840. raise NotImplementedError(
  841. "If you want to use CPU-only version, please use 'gloo' as backend"
  842. )
  843. if backend == "xccl":
  844. FLAGS_selected_custom_devices = (
  845. f'FLAGS_selected_{parallel_env.device_type}s'
  846. )
  847. _check_var_exists(FLAGS_selected_custom_devices)
  848. else:
  849. if not is_cpu_only and core.is_compiled_with_cuda():
  850. _check_var_exists("FLAGS_selected_gpus")
  851. backend = "nccl" if backend == "auto" else backend
  852. elif not is_cpu_only and core.is_compiled_with_xpu():
  853. _check_var_exists('FLAGS_selected_xpus')
  854. backend = "bkcl" if backend == "auto" else backend
  855. _check_var_exists("PADDLE_TRAINER_ID")
  856. _check_var_exists("PADDLE_CURRENT_ENDPOINT")
  857. _check_var_exists("PADDLE_TRAINERS_NUM")
  858. # NOTE(chenweihang): [ why config global place here? ]
  859. # the dygraph mode will be set to default mode,
  860. # users will not call `dygraph.guard` or `enable_dygraph`
  861. # directly, if they want to switch default place,
  862. # they need to call a function to change default place,
  863. # here just set correctly place to users
  864. if backend == "xccl":
  865. place = core.CustomPlace(
  866. parallel_env.device_type, parallel_env.device_id
  867. )
  868. elif is_cpu_only:
  869. place = core.CPUPlace()
  870. elif core.is_compiled_with_cuda():
  871. place = core.CUDAPlace(parallel_env.device_id)
  872. elif core.is_compiled_with_xpu():
  873. place = core.XPUPlace(parallel_env.device_id)
  874. _set_expected_place(place)
  875. group = None
  876. if backend in _valid_backend_list and in_dynamic_mode():
  877. if _default_group_name in _get_group_map_by_name():
  878. return _get_group_map_by_name()[_default_group_name]
  879. _set_default_backend(backend)
  880. rank = int(os.getenv("PADDLE_TRAINER_ID"))
  881. world_size = int(os.getenv("PADDLE_TRAINERS_NUM"))
  882. assert rank >= 0 and world_size > rank and world_size > 1, (
  883. "rank must be non-negative and world_size must be the "
  884. "maximum rank plus one. Moreover, at least two processes are "
  885. "required to create a process group."
  886. )
  887. master_addr = os.getenv("MASTER_ADDR", None)
  888. master_port = os.getenv("MASTER_PORT", None)
  889. endpoints = (
  890. ":".join([master_addr, master_port])
  891. if master_addr and master_port
  892. else None
  893. )
  894. if endpoints is None:
  895. endpoints = os.getenv("PADDLE_MASTER", None)
  896. if endpoints is None:
  897. endpoints = getenv_or_backup("PADDLE_TRAINER_ENDPOINTS").split(',')[
  898. 0
  899. ]
  900. assert endpoints, (
  901. "The environment variable 'MASTER_ADDR' and 'MASTER_PORT' "
  902. "must be specified, for example 'export MASTER_ADDR=127.0.0.1' "
  903. "and 'export MASTER_ADDR=54612'. Or you can start your training"
  904. "with paddle.distributed.run module."
  905. )
  906. master_addr, master_port = endpoints.split(":")
  907. master_port = int(master_port)
  908. is_master = rank == 0
  909. stop_check_timeout = int(os.getenv("FLAGS_stop_check_timeout", "900"))
  910. default_store = core.create_or_get_global_tcp_store()
  911. _set_default_store(default_store)
  912. pg = _new_process_group_impl(
  913. backend,
  914. default_store,
  915. rank,
  916. world_size,
  917. _default_group_name,
  918. pg_options=None,
  919. )
  920. ranks = list(range(world_size))
  921. group = Group(rank, 0, ranks, pg=pg, name=_default_group_name)
  922. _set_group_map_by_name(_default_group_name, group)
  923. _set_group_map(0, group)
  924. _set_group_map_backend(group, backend)
  925. _add_new_group(group)
  926. parallel_helper._set_parallel_ctx(True)
  927. # barrier will call CreateNCCLEnvCache which will call CreateNCCLCommContext.
  928. # Set device_id to prevent creating null dev_ctx.
  929. # TODO(mine): support XPU and other backends.
  930. if backend in ["nccl", 'xccl', 'bkcl']:
  931. core.CommContextManager.set_device_id(parallel_env.device_id)
  932. if int(os.getenv("FLAGS_eager_communication_connection", 0)) == 1:
  933. paddle.distributed.all_reduce(
  934. paddle.zeros([1], dtype=paddle.float32),
  935. group=group,
  936. sync_op=True,
  937. )
  938. return group
  939. node_num = {i.split(":")[0] for i in parallel_env.trainer_endpoints}
  940. # 3: init gloo context (step 1: httpserver start)
  941. init_gloo = int(os.getenv("PADDLE_WITH_GLOO", "0"))
  942. if is_cpu_only or init_gloo or backend == "heter":
  943. ep_rank_0 = parallel_env.trainer_endpoints[0].split(":")
  944. manager = Manager()
  945. # global dict to store status
  946. http_server_d = manager.dict()
  947. http_server_d["running"] = False
  948. if parallel_env.rank == 0:
  949. # The scope for worker used by http server is '_worker'
  950. size = {'_worker': parallel_env.world_size}
  951. if backend == "heter":
  952. size = {'_worker': len(node_num)}
  953. http_server = Process(
  954. target=_start_kv_server,
  955. args=(int(ep_rank_0[1]), http_server_d, size),
  956. )
  957. http_server.daemon = True
  958. http_server_d["running"] = True
  959. http_server.start()
  960. # 4. init NCCL ParallelStrategy
  961. strategy = ParallelStrategy()
  962. if parallel_helper._is_parallel_ctx_initialized():
  963. warnings.warn("The parallel environment has been initialized.")
  964. strategy.nranks = parallel_env.world_size
  965. strategy.local_rank = parallel_env.rank
  966. strategy.trainer_endpoints = parallel_env.trainer_endpoints
  967. strategy.current_endpoint = parallel_env.current_endpoint
  968. strategy.nrings = parallel_env.nrings
  969. # init nccl or bkcl or heter context
  970. if is_cpu_only:
  971. parallel_helper._set_parallel_ctx(
  972. core.GLOOParallelContext(strategy, place)
  973. )
  974. elif backend == "heter":
  975. parallel_helper._set_parallel_ctx(
  976. core.HeterParallelContext(strategy, parallel_env.device_id)
  977. )
  978. elif core.is_compiled_with_cuda():
  979. parallel_helper._set_parallel_ctx(
  980. core.NCCLParallelContext(strategy, place)
  981. )
  982. elif core.is_compiled_with_xpu():
  983. parallel_helper._set_parallel_ctx(
  984. core.BKCLParallelContext(strategy, place)
  985. )
  986. if backend != "heter":
  987. other_endpoints = strategy.trainer_endpoints[:]
  988. other_endpoints.remove(strategy.current_endpoint)
  989. if not is_cpu_only and strategy.local_rank == 0:
  990. wait_server_ready(other_endpoints)
  991. parallel_helper._init_parallel_ctx()
  992. # 5: init gloo context (step 2: gloo init)
  993. # dividing init_gloo into two part because nccl and gloo
  994. # are separately looking for free ports which sometimes
  995. # leads to port-conflict.
  996. if (is_cpu_only or backend == "heter") and parallel_env.rank == 0:
  997. # compare to init_gloo, we don't need to
  998. # init gloo, because we do this in _init_parallel_ctx;
  999. http_server_d["running"] = False
  1000. http_server.join()
  1001. elif init_gloo:
  1002. wait_server_ready([parallel_env.trainer_endpoints[0]])
  1003. gloo_strategy = core.GlooParallelStrategy()
  1004. gloo_strategy.rank = parallel_env.rank
  1005. gloo_strategy.rank_num = parallel_env.world_size
  1006. gloo_strategy.ip_address = ep_rank_0[0]
  1007. gloo_strategy.ip_port = int(ep_rank_0[1])
  1008. default_init_timeout_seconds = 3600
  1009. default_run_timeout_seconds = 9999999
  1010. gloo_strategy.init_seconds = default_init_timeout_seconds
  1011. gloo_strategy.run_seconds = default_run_timeout_seconds
  1012. gloo = core.GlooParallelContext(gloo_strategy)
  1013. gloo.init()
  1014. if parallel_env.rank == 0:
  1015. http_server_d["running"] = False
  1016. http_server.join()
  1017. return group
  1018. def get_rank(group=None):
  1019. """
  1020. Returns the rank of current trainer in the given group, ranks are consecutive integers in [0, ``world_size``).
  1021. If none of the group is given, the global group will be used as default.
  1022. Args:
  1023. group (Group, optional): The communication group you want to get rank of current trainer, use global group as default if group is None.
  1024. Returns:
  1025. (int) The rank of current trainer in the given group. Return -1 if the process is not part of the given group.
  1026. Warning:
  1027. Argument ``group`` only supports in dygraph mode.
  1028. Examples:
  1029. .. code-block:: python
  1030. >>> # doctest: +REQUIRES(env:DISTRIBUTED)
  1031. >>> # Execute this script using distributed launch with one card configs.
  1032. >>> import paddle
  1033. >>> import paddle.distributed as dist
  1034. >>> dist.init_parallel_env()
  1035. >>> print("The rank is %d" % dist.get_rank())
  1036. The rank is 0
  1037. """
  1038. if in_dynamic_mode() and group:
  1039. return group.rank
  1040. assert group is None, "Only support group argument in eager mode."
  1041. return _get_global_parallel_env().rank
  1042. def get_world_size(group=None):
  1043. """
  1044. Returns the number of trainers (number of processes participating in current job) in the given group.
  1045. If none of the group is given, the global group will be used as default.
  1046. Args:
  1047. group (Group, optional): The communication group you want to check world size, use global group as default if group is None.
  1048. Returns:
  1049. (int) The number of trainers in the given group. Return -1 if the process if not part of the given group.
  1050. Warning:
  1051. Argument ``group`` only supports in dygraph mode.
  1052. Examples:
  1053. .. code-block:: python
  1054. >>> # doctest: +REQUIRES(env:DISTRIBUTED)
  1055. >>> # Execute this script using distributed launch with one card configs.
  1056. >>> import paddle
  1057. >>> import paddle.distributed as dist
  1058. >>> dist.init_parallel_env()
  1059. >>> print("The world_size is %d" % dist.get_world_size())
  1060. The world_size is 1
  1061. """
  1062. if in_dynamic_mode() and (group is None):
  1063. if is_initialized():
  1064. group = _get_global_group()
  1065. if in_dynamic_mode() and group:
  1066. return group.world_size
  1067. assert group is None, "Only support group argument in eager mode."
  1068. return _get_global_parallel_env().world_size