collective.py 38 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
370470570670770870971071171271371471571671771871972072172272372472572672772872973073173273373473573673773873974074174274374474574674774874975075175275375475575675775875976076176276376476576676776876977077177277377477577677777877978078178278378478578678778878979079179279379479579679779879980080180280380480580680780880981081181281381481581681781881982082182282382482582682782882983083183283383483583683783883984084184284384484584684784884985085185285385485585685785885986086186286386486586686786886987087187287387487587687787887988088188288388488588688788888989089189289389489589689789889990090190290390490590690790890991091191291391491591691791891992092192292392492592692792892993093193293393493593693793893994094194294394494594694794894995095195295395495595695795895996096196296396496596696796896997097197297397497597697797897998098198298398498598698798898999099199299399499599699799899910001001100210031004100510061007100810091010101110121013101410151016101710181019102010211022102310241025102610271028102910301031103210331034103510361037103810391040104110421043104410451046104710481049105010511052105310541055105610571058
  1. # Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the 'License');
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an 'AS IS' BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. import os
  15. import paddle
  16. from paddle.base import unique_name
  17. from paddle.distributed.fleet.base.private_helper_function import (
  18. wait_server_ready,
  19. )
  20. from paddle.framework import core
  21. from paddle.static import default_main_program, default_startup_program
  # Op-role flags from the C++ op maker. The code below tags inserted ops
  # with OpRole.Forward / .Backward / .Optimize and tests existing ops'
  # roles bitwise (e.g. Backward combined with Loss).
  OpRole = core.op_proto_and_checker_maker.OpRole
  class Collective:
      """Base class of the collective transpilers.

      ``transpile`` validates the cluster description, then rewrites the
      startup program (communicator initialisation plus a broadcast of the
      parameters from rank 0) and the main program, whose rewrite is left to
      subclasses. Subclasses are expected to set ``self.mode``.
      """

      def __init__(self, nrings):
          # Number of communication rings; ring ids run from 0 to nrings - 1.
          self.nrings = nrings
          self.endpoints = None
          self.current_endpoint = None
          self.other_endpoints = None
          self.nranks = None
          self.rank = None
          self.startup_program = None
          self.main_program = None
          op_maker = core.op_proto_and_checker_maker
          # Names of the op attributes holding an op's role and the flat list
          # of (param, grad) names it is associated with.
          self.op_role_key = op_maker.kOpRoleAttrName()
          self.op_role_var_key = op_maker.kOpRoleVarAttrName()

      def transpile(
          self,
          startup_program,
          main_program,
          rank,
          endpoints,
          current_endpoint,
          wait_port,
      ):
          """Rewrite the startup and main programs for collective training.

          Args:
              startup_program: program receiving the communicator set-up ops;
                  the default startup program is used when ``None``.
              main_program: program to rewrite; the default main program is
                  used when ``None``.
              rank: index of this trainer; must be ``>= 0``.
              endpoints: list of ``ip:port`` strings, or a single
                  comma-separated string of them.
              current_endpoint: this trainer's endpoint; must appear in
                  ``endpoints``.
              wait_port: forwarded to ``_init_communicator``; when true,
                  rank 0 waits for the other endpoints before generating ids.

          Raises:
              ValueError: if there is a single endpoint outside the
                  ``single_process_multi_thread`` / ``box`` modes, if ``rank``
                  is negative, or if ``current_endpoint`` is not in
                  ``endpoints``.
          """
          # in case of '127.0.0.1:6700,127.0.0.1:6701,...'
          if isinstance(endpoints, str):
              endpoints = endpoints.split(',')

          self.startup_program = startup_program
          if startup_program is None:
              self.startup_program = default_startup_program()
          self.main_program = main_program
          if main_program is None:
              self.main_program = default_main_program()

          self.nranks = len(endpoints)
          # A lone endpoint is only meaningful for the in-process modes.
          # getattr keeps the base class usable when no subclass set ``mode``.
          mode = getattr(self, 'mode', None)
          if self.nranks == 1 and mode not in (
              "single_process_multi_thread",
              "box",
          ):
              raise ValueError('the number of endpoints must > 1')

          if rank < 0:
              raise ValueError('rank must >= 0')
          self.rank = rank

          if current_endpoint not in endpoints:
              # Build the message here: extra ValueError arguments are not
              # %-formatted into the first one.
              raise ValueError(
                  f'current endpoint {current_endpoint} is not in {endpoints}'
              )

          self.endpoints = endpoints
          self.current_endpoint = current_endpoint
          if current_endpoint:
              other_endpoints = endpoints[:]
              other_endpoints.remove(current_endpoint)
              self.other_endpoints = other_endpoints
          self.wait_port = wait_port

          # Keep an untouched clone of each program before rewriting it.
          self.startup_program._origin_program = self.startup_program.clone()
          self._transpile_startup_program()
          self.main_program._origin_program = self.main_program.clone()
          self._transpile_main_program()

      def _transpile_main_program(self):
          """Rewrite ``self.main_program``; must be provided by subclasses."""
          raise NotImplementedError('call the inherited method of subclasses')

      def _transpile_startup_program(self):
          """Initialise one communicator per ring, then broadcast params."""
          for ring_id in range(self.nrings):
              self._init_communicator(
                  self.startup_program,
                  self.current_endpoint,
                  self.endpoints,
                  self.rank,
                  ring_id,
                  self.wait_port,
              )
          self._broadcast_params()

      def _init_communicator(
          self,
          program,
          current_endpoint,
          endpoints,
          rank,
          ring_id,
          wait_port,
          has_multitrainer=False,
      ):
          """Append ops to ``program`` that set up the communicator of a ring.

          A persistable RAW variable receives the unique id produced by the
          backend's ``c_gen_*_id`` op, and a ``c_comm_init`` op (or
          ``c_comm_init_multitrainer`` on CUDA when ``has_multitrainer`` is
          true) consumes it. The backend is CUDA, XPU or a registered custom
          device, checked in that order; ``has_multitrainer`` only affects the
          CUDA branch. Nothing is appended when no backend matches.

          When ``rank`` is 0 and ``wait_port`` is true, this call blocks until
          the other endpoints are ready (``wait_server_ready``).
          """
          endpoints_str = ",".join(endpoints)
          nranks = len(endpoints)
          other_endpoints = endpoints[:]
          other_endpoints.remove(current_endpoint)
          if rank == 0 and wait_port:
              wait_server_ready(other_endpoints)
          block = program.global_block()
          if core.is_compiled_with_cuda():
              nccl_id_var = block.create_var(
                  name=unique_name.generate('nccl_id'),
                  persistable=True,
                  type=core.VarDesc.VarType.RAW,
              )
              block.append_op(
                  type='c_gen_nccl_id',
                  inputs={},
                  outputs={'Out': nccl_id_var},
                  attrs={
                      'rank': rank,
                      'endpoint': current_endpoint,
                      'other_endpoints': other_endpoints,
                      self.op_role_key: OpRole.Forward,
                  },
              )
              if not has_multitrainer:
                  # 'endpoints': endpoints_str,
                  block.append_op(
                      type='c_comm_init',
                      inputs={'X': nccl_id_var},
                      outputs={},
                      attrs={
                          'nranks': nranks,
                          'rank': rank,
                          'ring_id': ring_id,
                          self.op_role_key: OpRole.Forward,
                      },
                  )
              else:
                  block.append_op(
                      type='c_comm_init_multitrainer',
                      inputs={'X': nccl_id_var},
                      outputs={},
                      attrs={
                          'ntrainers': nranks,
                          'trainer_id': rank,
                          'ring_id': ring_id,
                          self.op_role_key: OpRole.Forward,
                      },
                  )
          elif core.is_compiled_with_xpu():
              bkcl_id_var = block.create_var(
                  name=unique_name.generate('bkcl_id'),
                  persistable=True,
                  type=core.VarDesc.VarType.RAW,
              )
              block.append_op(
                  type='c_gen_bkcl_id',
                  inputs={},
                  outputs={'Out': bkcl_id_var},
                  attrs={
                      'rank': rank,
                      'endpoint': current_endpoint,
                      'other_endpoints': other_endpoints,
                      self.op_role_key: OpRole.Forward,
                  },
              )
              block.append_op(
                  type='c_comm_init',
                  inputs={'X': bkcl_id_var},
                  outputs={},
                  attrs={
                      'nranks': nranks,
                      'rank': rank,
                      'ring_id': ring_id,
                      'endpoints': endpoints_str,
                      self.op_role_key: OpRole.Forward,
                  },
              )
          elif (
              paddle.distributed.ParallelEnv().device_type
              in paddle.device.get_all_custom_device_type()
          ):
              xccl_id_var = block.create_var(
                  name=unique_name.generate('xccl_id'),
                  persistable=True,
                  type=core.VarDesc.VarType.RAW,
              )
              block.append_op(
                  type='c_gen_xccl_id',
                  inputs={},
                  outputs={'Out': xccl_id_var},
                  attrs={
                      'rank': rank,
                      'endpoint': current_endpoint,
                      'other_endpoints': other_endpoints,
                      self.op_role_key: OpRole.Forward,
                  },
              )
              block.append_op(
                  type='c_comm_init',
                  inputs={'X': xccl_id_var},
                  outputs={},
                  attrs={
                      'nranks': nranks,
                      'rank': rank,
                      'ring_id': ring_id,
                      'endpoints': endpoints_str,
                      self.op_role_key: OpRole.Forward,
                  },
              )

      def _broadcast_params(self):
          """Broadcast each non-distributed parameter from rank 0.

          Broadcasts rotate over the rings; one ``c_sync_comm_stream`` per
          ring is appended afterwards. Nothing is appended when the startup
          block has no parameters.
          """
          block = self.startup_program.global_block()
          ring_id = -1
          param = None
          for param in block.iter_parameters():
              if param.is_distributed:
                  continue
              ring_id = (ring_id + 1) % self.nrings
              block.append_op(
                  type='c_broadcast',
                  inputs={'X': param},
                  outputs={'Out': param},
                  attrs={
                      'ring_id': ring_id,
                      'root': 0,
                      self.op_role_key: OpRole.Forward,
                  },
              )
          # The sync ops below hang on the last parameter seen; with no
          # parameter at all there is nothing to attach them to.
          if param is None:
              return
          for ring_id in range(self.nrings):
              block.append_op(
                  type='c_sync_comm_stream',
                  inputs={'X': param},
                  outputs={'Out': param},
                  attrs={'ring_id': ring_id, self.op_role_key: OpRole.Forward},
              )

      def _is_loss_grad_op(self, op):
          """Truthy when ``op``'s role has both the Backward and Loss bits."""
          if self.op_role_key not in op.attr_names:
              return False
          op_role = int(op.all_attrs()[self.op_role_key])
          return op_role & int(OpRole.Backward) and op_role & int(OpRole.Loss)

      def _is_backward_op(self, op):
          """Truthy when ``op`` has a role attribute with the Backward bit."""
          return self.op_role_key in op.attr_names and int(
              op.all_attrs()[self.op_role_key]
          ) & int(OpRole.Backward)

      def _is_update_op(self, op):
          """True when ``op`` takes Param, Grad and LearningRate inputs."""
          return (
              'Param' in op.input_names
              and 'Grad' in op.input_names
              and "LearningRate" in op.input_names
          )

      def _is_optimizer_op(self, op):
          """Truthy when ``op`` has a role attribute with the Optimize bit."""
          return self.op_role_key in op.attr_names and int(
              op.all_attrs()[self.op_role_key]
          ) & int(OpRole.Optimize)
  class GradAllReduce(Collective):
      """Transpiler that sums every dense gradient across all ranks.

      The loss gradient is scaled by ``1 / nranks`` beforehand, so the summed
      gradients amount to an average over the workers.
      """

      def __init__(self, nrings=2):
          Collective.__init__(self, nrings)
          self.mode = "grad_allreduce"

      def _transpile_main_program(self):
          """Scale the loss gradient, then all-reduce the dense gradients."""
          self._insert_scale_loss_grad_ops()
          self._insert_allreduce_ops()

      def _insert_scale_loss_grad_ops(self):
          """Insert an in-place ``scale(1 / nranks)`` after each loss-grad op.

          Dividing the loss gradient by the number of training workers keeps
          the effective learning rate independent of that number.
          """
          block = self.main_program.global_block()
          indexed_ops = list(enumerate(block.ops))
          for position, op in reversed(indexed_ops):
              if not self._is_loss_grad_op(op):
                  continue
              loss_grad_var = block.vars[op.output_arg_names[0]]
              scale_attrs = {
                  'scale': 1.0 / self.nranks,
                  self.op_role_key: OpRole.Backward,
              }
              block._insert_op(
                  position + 1,
                  type='scale',
                  inputs={'X': loss_grad_var},
                  outputs={'Out': loss_grad_var},
                  attrs=scale_attrs,
              )

      def _insert_allreduce_ops(self):
          """All-reduce each dense gradient right after the op producing it.

          For every backward op carrying (param, grad) pairs, one
          ``c_sync_calc_stream`` is inserted right after the op, followed by a
          ``c_allreduce_sum`` per non-distributed gradient, with ring ids
          rotating over ``range(nrings)``. If any pair was seen, one
          ``c_sync_comm_stream`` per ring is then placed in front of the first
          optimizer op.
          """
          block = self.main_program.global_block()
          ring_id = -1
          grad = None
          for op_idx, op in reversed(list(enumerate(block.ops))):
              if not self._is_backward_op(op):
                  continue
              if self.op_role_var_key not in op.attr_names:
                  continue
              role_vars = op.all_attrs()[self.op_role_var_key]
              if not role_vars:
                  continue
              assert len(role_vars) % 2 == 0
              insert_pos = op_idx
              for param_name, grad_name in zip(role_vars[0::2], role_vars[1::2]):
                  param = block.vars[param_name]
                  grad = block.vars[grad_name]
                  if param.is_distributed:
                      continue
                  if insert_pos == op_idx:
                      # First dense gradient of this op: wait for the
                      # computation once, then reduce after that sync.
                      block._insert_op(
                          op_idx + 1,
                          type='c_sync_calc_stream',
                          inputs={'X': grad},
                          outputs={'Out': grad},
                          attrs={self.op_role_key: OpRole.Backward},
                      )
                      insert_pos = op_idx + 2
                  # As we search ops reversely, we should insert c_allreduce_sum
                  # op in the same way to keep the ring_id alternate
                  ring_id = (ring_id + 1) % self.nrings
                  block._insert_op(
                      insert_pos,
                      type='c_allreduce_sum',
                      inputs={'X': grad},
                      outputs={'Out': grad},
                      attrs={
                          'ring_id': ring_id,
                          self.op_role_key: OpRole.Backward,
                      },
                  )

          if grad is None:
              return

          first_optimizer_pos = next(
              (
                  pos
                  for pos, candidate in enumerate(block.ops)
                  if self._is_optimizer_op(candidate)
              ),
              None,
          )
          if first_optimizer_pos is None:
              return
          for ring in range(self.nrings):
              block._insert_op(
                  first_optimizer_pos + ring,
                  type='c_sync_comm_stream',
                  inputs={'X': grad},
                  outputs={'Out': grad},
                  attrs={
                      'ring_id': ring,
                      self.op_role_key: OpRole.Backward,
                  },
              )
  class LocalSGD(Collective):
      """Local SGD transpiler.

      Each non-distributed parameter keeps a persistable snapshot (suffix
      ``@SNAPSHOT``). After a parameter update, the snapshot-minus-parameter
      difference is summed over ranks, divided by ``nranks`` and subtracted
      from the snapshot; the result becomes both the new parameter and the
      new snapshot.
      """

      def __init__(self, nrings=2):
          Collective.__init__(self, nrings)
          self.snapshot_key = '@SNAPSHOT'
          self.mode = "local_sgd"

      def _transpile_startup_program(self):
          """Run the base startup rewrite, then snapshot each dense param."""
          Collective._transpile_startup_program(self)
          block = self.startup_program.global_block()
          # Collect first so the loop below does not iterate the block's
          # parameters while variables are being added to it.
          non_dist_params = [
              param
              for param in block.iter_parameters()
              if not param.is_distributed
          ]
          for param in non_dist_params:
              snapshot = block.create_var(
                  name=self.snapshot_name(param.name),
                  shape=param.shape,
                  persistable=True,
                  stop_gradient=True,
                  # Match the main-program declaration of the same snapshot.
                  dtype=param.dtype,
              )
              block.append_op(
                  type='assign',
                  inputs={'X': [param]},
                  outputs={'Out': [snapshot]},
                  attrs={self.op_role_key: OpRole.Forward},
              )

      def snapshot_name(self, param_name):
          """Return the snapshot variable name for ``param_name``."""
          return param_name + self.snapshot_key

      def _transpile_main_program(self):
          """Insert the snapshot-difference all-reduce after each update op.

          Right after every update op of a non-distributed parameter:
          ``param = snapshot - param``, ``c_sync_calc_stream`` and a
          ``c_allreduce_sum`` (ring ids rotate). At the end of the block:
          one ``c_sync_comm_stream`` per ring, then for each parameter
          ``param = snapshot - param / nranks`` and ``snapshot = param``.
          Nothing is appended when the block has no update op.
          """
          block = self.main_program.global_block()
          ordered_param_snapshot = []
          ring_id = -1
          param = None
          for idx, op in reversed(list(enumerate(block.ops))):
              if not self._is_update_op(op):
                  continue
              param = block.vars[op.input('Param')[0]]
              if param.is_distributed:
                  continue

              snapshot = block.create_var(
                  name=self.snapshot_name(param.name),
                  shape=param.shape,
                  persistable=True,
                  stop_gradient=True,
                  dtype=param.dtype,
              )

              block._insert_op(
                  idx + 1,
                  type='elementwise_sub',
                  inputs={'X': [snapshot], 'Y': [param]},
                  outputs={'Out': [param]},
                  attrs={self.op_role_key: OpRole.Optimize},
              )
              block._insert_op(
                  idx + 2,
                  type='c_sync_calc_stream',
                  inputs={'X': param},
                  outputs={'Out': param},
                  attrs={self.op_role_key: OpRole.Optimize},
              )
              ring_id = (ring_id + 1) % self.nrings
              block._insert_op(
                  idx + 3,
                  type='c_allreduce_sum',
                  inputs={'X': [param]},
                  outputs={'Out': [param]},
                  attrs={
                      'ring_id': ring_id,
                      self.op_role_key: OpRole.Optimize,
                  },
              )

              ordered_param_snapshot.append((param, snapshot))

          # Without any update op there is no variable to attach the sync
          # ops to (``param`` would be unbound) and nothing to average.
          if param is None:
              return

          for ring_id in range(self.nrings):
              block.append_op(
                  type='c_sync_comm_stream',
                  inputs={'X': param},
                  outputs={'Out': param},
                  attrs={'ring_id': ring_id, self.op_role_key: OpRole.Optimize},
              )

          for param, snapshot in reversed(ordered_param_snapshot):
              block.append_op(
                  type='scale',
                  inputs={'X': [param]},
                  outputs={'Out': [param]},
                  attrs={
                      'scale': 1.0 / self.nranks,
                      self.op_role_key: OpRole.Optimize,
                  },
              )
              block.append_op(
                  type='elementwise_sub',
                  inputs={'X': [snapshot], 'Y': [param]},
                  outputs={'Out': [param]},
                  attrs={self.op_role_key: OpRole.Optimize},
              )
              block.append_op(
                  type='assign',
                  inputs={'X': [param]},
                  outputs={'Out': [snapshot]},
                  attrs={self.op_role_key: OpRole.Optimize},
              )
  class SingleProcessMultiThread(GradAllReduce):
      """Transpiler for the single-process, multi-thread mode.

      Settings are read from the environment when the object is created:

      * ``PADDLE_FUSE_ALLREDUCE`` (default 1): values ``<= 0`` use one
        all-reduce per gradient; 2 uses a single ``c_allreduce_xsum`` over
        the gradient list; any other positive value packs the gradients with
        ``coalesce_tensor`` and reduces the fused buffer.
      * ``PADDLE_LOSS_SCALE`` (default 1): non-zero enables loss-gradient
        scaling.
      * ``FLAGS_selected_gpus``: its number of comma-separated entries is
        used as the local device count (8 entries by default).
      """

      def __init__(self):
          GradAllReduce.__init__(self, 1)
          self.mode = "single_process_multi_thread"
          self.fuse_allreduce = int(os.getenv("PADDLE_FUSE_ALLREDUCE", "1"))
          self.loss_scale = int(os.getenv("PADDLE_LOSS_SCALE", "1"))
          self.gpu_nums = len(
              os.getenv("FLAGS_selected_gpus", "0,1,2,3,4,5,6,7").split(",")
          )

      def _transpile_startup_program(self):
          """Set up communicators and record the node count in ``nranks``.

          Endpoints on more than one distinct IP are treated as a multi-node
          run: ``nranks`` becomes the number of distinct IPs and a
          multi-trainer communicator is initialised per ring. Otherwise
          ``nranks`` is 1 and a single ``c_comm_init_all`` op is appended.
          """
          nodes_num = 0
          if len(self.endpoints) > 1:
              nodes_num = len({x.split(':')[0] for x in self.endpoints})
          # different ip num is multi node
          if nodes_num > 1:
              self.nranks = nodes_num
              print("begin to _transpile_startup_program for multi-node")
              print("current_endpoint: ", self.current_endpoint)
              print("total endpoints: ", self.endpoints)
              print("rank: %d, ring_id: %d" % (self.rank, self.nrings))
              for ring_id in range(self.nrings):
                  self._init_communicator(
                      self.startup_program,
                      self.current_endpoint,
                      self.endpoints,
                      self.rank,
                      ring_id,
                      self.wait_port,
                      True,
                  )
          else:
              self.nranks = 1
              print("begin to _transpile_startup_program for single-node")
              block = self.startup_program.global_block()
              block.append_op(type='c_comm_init_all', attrs={'ring_id': 0})

      def _transpile_main_program(self):
          """Optionally scale the loss gradient, then reduce dense gradients."""
          # not need loss scale and no dense param
          param_cnt = self._get_update_param_count()
          if self.loss_scale == 0 and param_cnt == 0:
              return
          # scale loss
          if self.loss_scale:
              self._insert_scale_loss_grad_ops(param_cnt)
          # no param
          if param_cnt == 0:
              return
          # fuse allreduce
          if self.fuse_allreduce > 0:
              print("begin used fuse_allreduce param count = %s" % (param_cnt))
              # use fuse allreduce
              self._insert_fuse_allreduce_ops()
          else:
              self._insert_allreduce_ops()

      def _get_update_param_count(self):
          """Count the (param, grad) pairs whose param is not distributed.

          Only backward ops carrying the op-role-var attribute are inspected.
          """
          param_count = 0
          block = self.main_program.global_block()
          for op in reversed(block.ops):
              if not self._is_backward_op(op):
                  continue
              if self.op_role_var_key not in op.attr_names:
                  continue
              op_role_var = op.all_attrs()[self.op_role_var_key]
              if len(op_role_var) == 0:
                  continue
              assert len(op_role_var) % 2 == 0
              for i in range(0, len(op_role_var), 2):
                  param = block.vars[op_role_var[i]]
                  if param.is_distributed:
                      continue
                  param_count = param_count + 1
          return param_count

      def _insert_scale_loss_grad_ops(self, param_cnt):
          """Scale each loss gradient in place to keep the learning rate stable.

          The factor is ``1 / (nranks * gpu_nums)`` when there are dense
          parameters to reduce (``param_cnt > 0``), else ``1 / gpu_nums``.
          """
          if param_cnt > 0:
              scale = 1.0 / self.nranks / self.gpu_nums
          else:
              scale = 1.0 / self.gpu_nums
          print("begin _insert_scale_loss_grad_ops scale = %s" % (scale))
          block = self.main_program.global_block()
          for idx, op in reversed(list(enumerate(block.ops))):
              if not self._is_loss_grad_op(op):
                  continue
              loss_grad_var = block.vars[op.output_arg_names[0]]
              block._insert_op(
                  idx + 1,
                  type='scale',
                  inputs={'X': loss_grad_var},
                  outputs={'Out': loss_grad_var},
                  attrs={'scale': scale, self.op_role_key: OpRole.Backward},
              )

      def _insert_fuse_allreduce_ops(self):
          """Insert one fused all-reduce over all dense gradients.

          The ops are placed right after the last backward op that produces a
          dense gradient. With ``fuse_allreduce == 2`` the gradient list is
          reduced by a single ``c_allreduce_xsum``; otherwise the gradients
          are first packed into ``fused_output`` by ``coalesce_tensor`` and
          that buffer is reduced by ``c_allreduce_sum``. Both variants are
          preceded by ``c_sync_calc_stream`` and followed by
          ``c_sync_comm_stream``. Nothing is inserted when there is no dense
          gradient.
          """
          block = self.main_program.global_block()
          ring_id = -1
          input_grads = []
          # find insert offset of fuse tensor, after the max dense grad offset
          global_offset = 0
          for idx, op in reversed(list(enumerate(block.ops))):
              if not (
                  self._is_backward_op(op)
                  and self.op_role_var_key in op.attr_names
              ):
                  continue
              op_role_var = op.all_attrs()[self.op_role_var_key]
              if len(op_role_var) == 0:
                  continue
              assert len(op_role_var) % 2 == 0
              for i in range(0, len(op_role_var), 2):
                  param = block.vars[op_role_var[i]]
                  grad = block.vars[op_role_var[i + 1]]
                  if param.is_distributed:
                      continue
                  input_grads.append(grad)
                  global_offset = max(global_offset, idx + 1)

          # With only distributed parameters there is nothing to reduce: the
          # xsum branch would fail on input_grads[0], and the coalesce branch
          # would insert its ops at offset 0 (the start of the program).
          if not input_grads:
              return

          if self.fuse_allreduce == 2:
              # grads aggregation of multi-gpus
              block._insert_op(
                  global_offset,
                  type='c_sync_calc_stream',
                  inputs={'X': input_grads[0]},
                  outputs={'Out': input_grads[0]},
                  attrs={self.op_role_key: OpRole.Backward},
              )
              global_offset += 1
              ring_id = (ring_id + 1) % self.nrings
              block._insert_op(
                  global_offset,
                  type='c_allreduce_xsum',
                  inputs={'X': input_grads},
                  outputs={'Out': input_grads},
                  attrs={'ring_id': ring_id, self.op_role_key: OpRole.Backward},
              )
              global_offset += 1
              # sync before adam
              block._insert_op(
                  global_offset,
                  type='c_sync_comm_stream',
                  inputs={'X': input_grads[0]},
                  outputs={'Out': input_grads[0]},
                  attrs={'ring_id': ring_id, self.op_role_key: OpRole.Backward},
              )
          else:
              # init output_grads
              output_grads = input_grads
              # init fused_output with temp shape, it will calculate real shape depend on inputs
              fused_output = block.create_var(
                  name="fused_output",
                  shape=[1],
                  persistable=False,
                  dtype=core.VarDesc.VarType.FP32,
                  stop_gradient=True,
              )
              # fuse all grad tensors
              coalesce_tensor_attrs = {
                  "copy_data": True,
                  "set_constant": False,
                  "dtype": core.VarDesc.VarType.FP32,
              }
              block._insert_op(
                  global_offset,
                  type='coalesce_tensor',
                  inputs={'Input': input_grads},
                  outputs={'Output': output_grads, 'FusedOutput': fused_output},
                  attrs=coalesce_tensor_attrs,
              )
              global_offset += 1
              # grads aggregation of multi-gpus
              block._insert_op(
                  global_offset,
                  type='c_sync_calc_stream',
                  inputs={'X': fused_output},
                  outputs={'Out': fused_output},
                  attrs={self.op_role_key: OpRole.Backward},
              )
              global_offset += 1
              ring_id = (ring_id + 1) % self.nrings
              block._insert_op(
                  global_offset,
                  type='c_allreduce_sum',
                  inputs={'X': fused_output},
                  outputs={'Out': fused_output},
                  attrs={'ring_id': ring_id, self.op_role_key: OpRole.Backward},
              )
              global_offset += 1
              # sync before adam
              block._insert_op(
                  global_offset,
                  type='c_sync_comm_stream',
                  inputs={'X': fused_output},
                  outputs={'Out': fused_output},
                  attrs={'ring_id': ring_id, self.op_role_key: OpRole.Backward},
              )
  class MultiThread(GradAllReduce):
      """Transpiler for the ``box`` mode.

      ``trans_mode`` selects the main-program rewrite: ``"all_gather"``,
      ``"fuse_all_reduce"`` (the default), ``"all_reduce_xpu"`` (skipped when
      ``FLAGS_selected_gpus`` lists a single device), or anything else for a
      plain per-gradient all-reduce.
      """

      def __init__(self, nrings=1, trans_mode="fuse_all_reduce"):
          GradAllReduce.__init__(self, nrings)
          self.mode = "box"
          self.trans_mode = trans_mode
          # Maximum number of gradients packed into one coalesced tensor.
          self.fuse_grad_size_in_num = 128
          gpu_nums = os.getenv("FLAGS_selected_gpus", "0,1,2,3,4,5,6,7,8").split(
              ","
          )
          self.gpu_num = len(gpu_nums)

      @staticmethod
      def _selected_devices():
          """Return the comma-separated entries of ``FLAGS_selected_gpus``.

          Raises:
              ValueError: if ``FLAGS_selected_gpus`` is not set. Without this
                  check the XPU paths failed with an opaque AttributeError
                  on ``None.split``.
          """
          selected = os.getenv("FLAGS_selected_gpus")
          if selected is None:
              raise ValueError(
                  "FLAGS_selected_gpus must be set to use the XPU code paths"
              )
          return selected.split(",")

      def _transpile_startup_program(self):
          """Initialise the communicators.

          With several endpoints, a multi-trainer communicator is set up per
          ring. Otherwise a single ``c_comm_init_all`` op is appended; in XPU
          trans modes it also receives the selected device ids.
          """
          if len(self.endpoints) > 1:
              print("begin to _transpile_startup_program for multi-node")
              print("current_endpoint: ", self.current_endpoint)
              print("total endpoints: ", self.endpoints)
              print("rank: %d, ring_id: %d" % (self.rank, self.nrings))
              for ring_id in range(self.nrings):
                  self._init_communicator(
                      self.startup_program,
                      self.current_endpoint,
                      self.endpoints,
                      self.rank,
                      ring_id,
                      self.wait_port,
                      True,
                  )
          elif "xpu" in self.trans_mode:
              print(
                  "begin to _transpile_startup_program for single-node in XPU"
              )
              block = self.startup_program.global_block()
              block.append_op(
                  type='c_comm_init_all',
                  attrs={
                      'devices': [
                          int(device) for device in self._selected_devices()
                      ],
                      'ring_id': 0,
                  },
              )
          else:
              print("begin to _transpile_startup_program for single-node")
              block = self.startup_program.global_block()
              block.append_op(type='c_comm_init_all', attrs={'ring_id': 0})

      def _transpile_main_program(self):
          """Scale the loss gradient, then apply the ``trans_mode`` rewrite."""
          self._insert_scale_loss_grad_ops()
          if self.trans_mode == "all_gather":
              print("begin to transpile in all-gather mode")
              self.allgather_ranks = self.nranks * self.gpu_num
              self._insert_allgather_ops()
              self._update_adam_ops()
          elif self.trans_mode == "fuse_all_reduce":
              print("begin to transpile in fuse all-reduce mode")
              self._insert_fuse_allreduce_ops()
          elif (
              self.trans_mode == "all_reduce_xpu"
              and len(self._selected_devices()) == 1
          ):
              print(
                  "skip transpile in all-reduce-xpu mode when number of devices is only one"
              )
          else:
              print("begin to transpile in all-reduce mode")
              self._insert_allreduce_ops()

      def _insert_allgather_ops(self):
          """Insert a ``c_allgather`` of each dense gradient.

          Each gradient is gathered into a new FP32 variable named
          ``<param>_allgather`` whose leading dimension is
          ``allgather_ranks``; ``_update_adam_ops`` consumes these. A
          ``c_sync_calc_stream`` precedes the gathers of each backward op, and
          one ``c_sync_comm_stream`` per ring is placed before the first
          optimizer op.
          """
          block = self.main_program.global_block()
          ring_id = -1
          grad = None
          for idx, op in reversed(list(enumerate(block.ops))):
              if (
                  self._is_backward_op(op)
                  and self.op_role_var_key in op.attr_names
              ):
                  op_role_var = op.all_attrs()[self.op_role_var_key]
                  if len(op_role_var) == 0:
                      continue
                  assert len(op_role_var) % 2 == 0
                  offset = idx
                  for i in range(0, len(op_role_var), 2):
                      param = block.vars[op_role_var[i]]
                      # Created even for distributed params: _update_adam_ops
                      # looks up "<param>_allgather" for every adam/lamb op.
                      new_grad_var = block.create_var(
                          name=op_role_var[i] + "_allgather",
                          shape=[self.allgather_ranks] + list(param.shape),
                          persistable=False,
                          dtype=core.VarDesc.VarType.FP32,
                          stop_gradient=True,
                      )
                      grad = block.vars[op_role_var[i + 1]]
                      if param.is_distributed:  # no need to care: used in PLSC
                          continue
                      if offset == idx:
                          offset += 1
                          block._insert_op(
                              offset,
                              type='c_sync_calc_stream',
                              inputs={'X': grad},
                              outputs={'Out': grad},
                              attrs={self.op_role_key: OpRole.Backward},
                          )
                          offset += 1
                      # As we search ops reversely, we should insert c_allgather
                      # op in the same way to keep the ring_id alternate
                      ring_id = (ring_id + 1) % self.nrings
                      block._insert_op(
                          offset,
                          type='c_allgather',
                          inputs={'X': grad},
                          outputs={'Out': new_grad_var},
                          attrs={
                              'nranks': self.allgather_ranks,
                              'ring_id': ring_id,
                              self.op_role_key: OpRole.Backward,
                          },
                      )
          if grad is None:
              return
          for idx, op in enumerate(block.ops):
              if self._is_optimizer_op(op):
                  for ring_id in range(self.nrings):
                      block._insert_op(
                          idx + ring_id,
                          type='c_sync_comm_stream',
                          inputs={'X': grad},
                          outputs={'Out': grad},
                          attrs={
                              'ring_id': ring_id,
                              self.op_role_key: OpRole.Backward,
                          },
                      )
                  break

      def _update_adam_ops(self):
          """Replace each adam/lamb op by one op per gathered gradient slice.

          The ``<param>_allgather`` variable is split along axis 0 into
          ``allgather_ranks`` pieces, one optimizer op of the original type is
          inserted per piece (same inputs/outputs/attrs, ``Grad`` replaced by
          the piece), and the original op is removed.
          """
          block = self.main_program.global_block()
          for idx, op in reversed(list(enumerate(block.ops))):
              if self._is_optimizer_op(op):
                  offset = idx
                  if (
                      op.type != 'adam' and op.type != 'lamb'
                  ):  # filter out scale op
                      continue
                  param_name = op.input("Param")[0]
                  inputs = {
                      "Param": block.vars[op.input("Param")[0]],
                      "LearningRate": block.vars[op.input("LearningRate")[0]],
                      "Moment1": block.vars[op.input("Moment1")[0]],
                      "Moment2": block.vars[op.input("Moment2")[0]],
                      "Beta1Pow": block.vars[op.input("Beta1Pow")[0]],
                      "Beta2Pow": block.vars[op.input("Beta2Pow")[0]],
                  }
                  outputs = {
                      "ParamOut": block.vars[op.output("ParamOut")[0]],
                      "Moment1Out": block.vars[op.output("Moment1Out")[0]],
                      "Moment2Out": block.vars[op.output("Moment2Out")[0]],
                      "Beta1PowOut": block.vars[op.output("Beta1PowOut")[0]],
                      "Beta2PowOut": block.vars[op.output("Beta2PowOut")[0]],
                  }
                  # NOTE(review): these attribute names are adam's; confirm a
                  # lamb op actually exposes lazy_mode and
                  # min_row_size_to_use_multithread.
                  attrs = {
                      "epsilon": op.attr('epsilon'),
                      "beta1": op.attr('beta1'),
                      "beta2": op.attr('beta2'),
                      "lazy_mode": op.attr('lazy_mode'),
                      "min_row_size_to_use_multithread": op.attr(
                          'min_row_size_to_use_multithread'
                      ),
                  }
                  split_vars = [
                      block.create_var(
                          name=param_name + "_" + str(i),
                          shape=block.vars[op.input("Param")[0]].shape,
                          persistable=False,
                          dtype=core.VarDesc.VarType.FP32,
                          stop_gradient=True,
                      )
                      for i in range(self.allgather_ranks)
                  ]
                  block._insert_op(
                      offset,
                      type="split",
                      inputs={
                          'X': block.vars[op.input("Param")[0] + "_allgather"]
                      },
                      outputs={'Out': split_vars},
                      attrs={'num': self.allgather_ranks, 'axis': 0},
                  )
                  offset += 1
                  for split_var in split_vars:
                      # Fresh inputs dict per op instead of mutating a shared one.
                      block._insert_op(
                          offset,
                          type=op.type,
                          inputs={**inputs, "Grad": split_var},
                          outputs=outputs,
                          attrs=attrs,
                      )
                      offset += 1
                  # remove the original adam op
                  block._remove_op(offset)

      def _insert_fuse_allreduce_ops(self):
          """Coalesce the dense gradients and all-reduce the fused buffers.

          Gradients are grouped into segments of at most
          ``fuse_grad_size_in_num`` consecutive gradients sharing a dtype.
          Each segment is packed by ``coalesce_tensor``; each fused buffer is
          reduced by ``c_allreduce_sum`` preceded by ``c_sync_calc_stream``,
          and one ``c_sync_comm_stream`` follows. All of these ops are placed
          in front of the first optimizer op.
          """
          block = self.main_program.global_block()
          ring_id = 0 % self.nrings
          grad = None
          param_grads = []
          # find all grad params
          for op in reversed(block.ops):
              if (
                  self._is_backward_op(op)
                  and self.op_role_var_key in op.attr_names
              ):
                  op_role_var = op.all_attrs()[self.op_role_var_key]
                  if len(op_role_var) == 0:
                      continue
                  assert len(op_role_var) % 2 == 0, (
                      "vars need to be one param var followed by one grad var, "
                      "but got odd number of vars"
                  )
                  for i in range(0, len(op_role_var), 2):
                      param_name = op_role_var[i]
                      param = block.var(param_name)
                      grad_name = op_role_var[i + 1]
                      grad = block.var(grad_name)
                      if param.is_distributed:
                          continue
                      param_grads.append(grad)
          if grad is None:
              return

          segments = []
          last_dtype = None
          # split the grad based on dtype and fused size
          for var in param_grads:
              if (
                  len(segments) == 0
                  or len(segments[-1]) == self.fuse_grad_size_in_num
                  or var.dtype != last_dtype
              ):
                  segments.append([var])
                  last_dtype = var.dtype
              else:
                  segments[-1].append(var)

          fused_vars = []
          for idx, op in enumerate(block.ops):
              if self._is_optimizer_op(op):
                  for segment in segments:
                      # insert coalesce tensor
                      tmp_var = block.create_var(
                          name=unique_name.generate(
                              f'FusedOutput_{segment[0].name}'
                          ),
                          dtype=segment[0].dtype,
                          persistable=False,
                          stop_gradient=True,
                      )
                      fused_vars.append(tmp_var)
                      block._insert_op(
                          idx,
                          type="coalesce_tensor",
                          inputs={"Input": segment},
                          outputs={"Output": segment, "FusedOutput": tmp_var},
                          attrs={
                              "copy_data": True,
                              "use_align": True,
                              "dtype": segment[0].dtype,
                              self.op_role_key: OpRole.Backward,
                          },
                      )
                  break

          # insert the allreduce_sum op
          for idx, op in enumerate(block.ops):
              if self._is_optimizer_op(op):
                  for fused_var in fused_vars:
                      block._insert_op(
                          idx,
                          type='c_allreduce_sum',
                          inputs={'X': fused_var},
                          outputs={'Out': fused_var},
                          attrs={
                              'ring_id': ring_id,
                              'use_calc_stream': False,
                              self.op_role_key: OpRole.Backward,
                          },
                      )
                      # Inserted at the same index, so it lands before the
                      # all-reduce just inserted.
                      block._insert_op(
                          idx,
                          type='c_sync_calc_stream',
                          inputs={'X': fused_var},
                          outputs={'Out': fused_var},
                          attrs={self.op_role_key: OpRole.Backward},
                      )
                  break

          if len(fused_vars) == 0:
              block._sync_with_cpp()
              return

          # insert the sync comm op
          for idx, op in enumerate(block.ops):
              if self._is_optimizer_op(op):
                  block._insert_op(
                      idx,
                      type='c_sync_comm_stream',
                      inputs={'X': fused_vars[0]},
                      outputs={'Out': fused_vars[0]},
                      attrs={
                          'ring_id': ring_id,
                          self.op_role_key: OpRole.Backward,
                      },
                  )
                  break
          block._sync_with_cpp()