| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693694695696697698699700701702
70370470570670770870971071171271371471571671771871972072172272372472572672772872973073173273373473573673773873974074174274374474574674774874975075175275375475575675775875976076176276376476576676776876977077177277377477577677777877978078178278378478578678778878979079179279379479579679779879980080180280380480580680780880981081181281381481581681781881982082182282382482582682782882983083183283383483583683783883984084184284384484584684784884985085185285385485585685785885986086186286386486586686786886987087187287387487587687787887988088188288388488588688788888989089189289389489589689789889990090190290390490590690790890991091191291391491591691791891992092192292392492592692792892993093193293393493593693793893994094194294394494594694794894995095195295395495595695795895996096196296396496596696796896997097197297397497597697797897998098198298398498598698798898999099199299399499599699799899910001001100210031004100510061007100810091010101110121013101410151016101710181019102010211022102310241025102610271028102910301031103210331034103510361037103810391040104110421043104410451046104710481049105010511052105310541055105610571058 |
- # Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
- #
- # Licensed under the Apache License, Version 2.0 (the 'License');
- # you may not use this file except in compliance with the License.
- # You may obtain a copy of the License at
- #
- # http://www.apache.org/licenses/LICENSE-2.0
- #
- # Unless required by applicable law or agreed to in writing, software
- # distributed under the License is distributed on an 'AS IS' BASIS,
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- # See the License for the specific language governing permissions and
- # limitations under the License.
- import os
- import paddle
- from paddle.base import unique_name
- from paddle.distributed.fleet.base.private_helper_function import (
- wait_server_ready,
- )
- from paddle.framework import core
- from paddle.static import default_main_program, default_startup_program
# Module-level shorthand for the op-role enum exposed by ``core``. Its members
# (Forward, Backward, Optimize, Loss) are used below both as op-role attribute
# values on inserted ops and, converted with ``int()``, as bit masks when
# classifying existing ops.
OpRole = core.op_proto_and_checker_maker.OpRole
class Collective:
    """Base class for transpilers that add collective-communication ops.

    ``transpile`` stores the job layout on the instance and then rewrites the
    startup program (communicator set-up and parameter broadcast) followed by
    the main program. Subclasses provide ``_transpile_main_program`` and are
    expected to set a ``mode`` string, which ``transpile`` consults when only
    one endpoint is given.
    """

    def __init__(self, nrings):
        """Create a transpiler that uses ``nrings`` communication rings.

        Args:
            nrings: Number of rings. One communicator is initialised per ring
                and inserted communication ops cycle through the ring ids.
        """
        self.nrings = nrings
        self.endpoints = None
        self.current_endpoint = None
        self.other_endpoints = None
        self.nranks = None
        self.rank = None
        self.startup_program = None
        self.main_program = None

        # Attribute names under which ops store their role / role variables.
        op_maker = core.op_proto_and_checker_maker
        self.op_role_key = op_maker.kOpRoleAttrName()
        self.op_role_var_key = op_maker.kOpRoleVarAttrName()

    def transpile(
        self,
        startup_program,
        main_program,
        rank,
        endpoints,
        current_endpoint,
        wait_port,
    ):
        """Validate the job layout and rewrite both programs in place.

        Before each program is rewritten, a clone of it is stored on its
        ``_origin_program`` attribute.

        Args:
            startup_program: Program that receives the initialisation ops;
                ``default_startup_program()`` is used when ``None``.
            main_program: Program that receives the communication ops;
                ``default_main_program()`` is used when ``None``.
            rank: Index of this trainer; must be non-negative.
            endpoints: List of endpoint strings, or a single comma-separated
                string such as ``'127.0.0.1:6700,127.0.0.1:6701'``.
            current_endpoint: Endpoint of this trainer; must be one of
                ``endpoints``.
            wait_port: Forwarded to communicator set-up; when truthy, rank 0
                waits for the other endpoints before the ops are appended.

        Raises:
            ValueError: If there is a single endpoint and the mode is neither
                ``"single_process_multi_thread"`` nor ``"box"``, if ``rank`` is
                negative, or if ``current_endpoint`` is not in ``endpoints``.
        """
        # in case of '127.0.0.1:6700,127.0.0.1:6701,...'
        if isinstance(endpoints, str):
            endpoints = endpoints.split(',')

        self.startup_program = startup_program
        if startup_program is None:
            self.startup_program = default_startup_program()

        self.main_program = main_program
        if main_program is None:
            self.main_program = default_main_program()

        self.nranks = len(endpoints)
        # ``mode`` is assigned by subclasses only; fall back to None so that a
        # mode-less instance gets the ValueError below, not an AttributeError.
        mode = getattr(self, 'mode', None)
        if self.nranks == 1 and mode not in (
            "single_process_multi_thread",
            "box",
        ):
            raise ValueError('the number of endpoints must > 1')

        if rank < 0:
            raise ValueError('rank must >= 0')
        self.rank = rank

        if current_endpoint not in endpoints:
            # Interpolate the message; passing the values as extra exception
            # arguments would leave the '%s' placeholders unformatted.
            raise ValueError(
                f'current endpoint {current_endpoint} is not in {endpoints}'
            )

        self.endpoints = endpoints
        self.current_endpoint = current_endpoint

        if current_endpoint:
            other_endpoints = endpoints[:]
            other_endpoints.remove(current_endpoint)
            self.other_endpoints = other_endpoints

        self.wait_port = wait_port

        self.startup_program._origin_program = self.startup_program.clone()
        self._transpile_startup_program()

        self.main_program._origin_program = self.main_program.clone()
        self._transpile_main_program()

    def _transpile_main_program(self):
        """Rewrite the main program; must be overridden by subclasses."""
        raise NotImplementedError('call the inherited method of subclasses')

    def _transpile_startup_program(self):
        """Initialise one communicator per ring, then broadcast parameters."""
        for ring_id in range(self.nrings):
            self._init_communicator(
                self.startup_program,
                self.current_endpoint,
                self.endpoints,
                self.rank,
                ring_id,
                self.wait_port,
            )
        self._broadcast_params()

    def _append_gen_comm_id_op(
        self,
        block,
        id_var_prefix,
        gen_op_type,
        rank,
        current_endpoint,
        other_endpoints,
    ):
        """Create a RAW persistable id variable and the op that fills it.

        Returns the created variable so the caller can feed it to the
        communicator-initialisation op.
        """
        comm_id_var = block.create_var(
            name=unique_name.generate(id_var_prefix),
            persistable=True,
            type=core.VarDesc.VarType.RAW,
        )
        block.append_op(
            type=gen_op_type,
            inputs={},
            outputs={'Out': comm_id_var},
            attrs={
                'rank': rank,
                'endpoint': current_endpoint,
                'other_endpoints': other_endpoints,
                self.op_role_key: OpRole.Forward,
            },
        )
        return comm_id_var

    def _init_communicator(
        self,
        program,
        current_endpoint,
        endpoints,
        rank,
        ring_id,
        wait_port,
        has_multitrainer=False,
    ):
        """Append the ops that set up the communicator for ``ring_id``.

        The op pair that is appended depends on the build / device: CUDA
        builds use ``c_gen_nccl_id``, XPU builds ``c_gen_bkcl_id`` and custom
        devices ``c_gen_xccl_id``, each followed by a ``c_comm_init`` op (or
        ``c_comm_init_multitrainer`` on CUDA when ``has_multitrainer`` is
        true). Nothing is appended when none of the three cases applies.

        Args:
            program: Program whose global block receives the ops.
            current_endpoint: Endpoint of this trainer; must be in
                ``endpoints``.
            endpoints: List of all endpoint strings.
            rank: Index of this trainer.
            ring_id: Ring the communicator belongs to.
            wait_port: When truthy and ``rank`` is 0, ``wait_server_ready`` is
                called on the other endpoints before any op is appended.
            has_multitrainer: Select the multi-trainer init op on CUDA builds.
        """
        endpoints_str = ",".join(endpoints)
        nranks = len(endpoints)
        other_endpoints = endpoints[:]
        other_endpoints.remove(current_endpoint)
        block = program.global_block()

        if rank == 0 and wait_port:
            wait_server_ready(other_endpoints)

        if core.is_compiled_with_cuda():
            nccl_id_var = self._append_gen_comm_id_op(
                block,
                'nccl_id',
                'c_gen_nccl_id',
                rank,
                current_endpoint,
                other_endpoints,
            )
            if not has_multitrainer:
                # Unlike the XPU / custom-device paths, no 'endpoints'
                # attribute is passed here.
                block.append_op(
                    type='c_comm_init',
                    inputs={'X': nccl_id_var},
                    outputs={},
                    attrs={
                        'nranks': nranks,
                        'rank': rank,
                        'ring_id': ring_id,
                        self.op_role_key: OpRole.Forward,
                    },
                )
            else:
                block.append_op(
                    type='c_comm_init_multitrainer',
                    inputs={'X': nccl_id_var},
                    outputs={},
                    attrs={
                        'ntrainers': nranks,
                        'trainer_id': rank,
                        'ring_id': ring_id,
                        self.op_role_key: OpRole.Forward,
                    },
                )
        elif core.is_compiled_with_xpu():
            bkcl_id_var = self._append_gen_comm_id_op(
                block,
                'bkcl_id',
                'c_gen_bkcl_id',
                rank,
                current_endpoint,
                other_endpoints,
            )
            block.append_op(
                type='c_comm_init',
                inputs={'X': bkcl_id_var},
                outputs={},
                attrs={
                    'nranks': nranks,
                    'rank': rank,
                    'ring_id': ring_id,
                    'endpoints': endpoints_str,
                    self.op_role_key: OpRole.Forward,
                },
            )
        elif (
            paddle.distributed.ParallelEnv().device_type
            in paddle.device.get_all_custom_device_type()
        ):
            xccl_id_var = self._append_gen_comm_id_op(
                block,
                'xccl_id',
                'c_gen_xccl_id',
                rank,
                current_endpoint,
                other_endpoints,
            )
            block.append_op(
                type='c_comm_init',
                inputs={'X': xccl_id_var},
                outputs={},
                attrs={
                    'nranks': nranks,
                    'rank': rank,
                    'ring_id': ring_id,
                    'endpoints': endpoints_str,
                    self.op_role_key: OpRole.Forward,
                },
            )

    def _broadcast_params(self):
        """Append a root-0 ``c_broadcast`` per non-distributed parameter.

        Ring ids are assigned round-robin. Afterwards one
        ``c_sync_comm_stream`` op per ring is appended, using the last
        parameter that was iterated as its input/output. When the startup
        block has no parameters at all, nothing is appended.
        """
        block = self.startup_program.global_block()
        ring_id = -1
        param = None
        for param in block.iter_parameters():
            if param.is_distributed:
                continue

            ring_id = (ring_id + 1) % self.nrings
            block.append_op(
                type='c_broadcast',
                inputs={'X': param},
                outputs={'Out': param},
                attrs={
                    'ring_id': ring_id,
                    'root': 0,
                    self.op_role_key: OpRole.Forward,
                },
            )

        if param is None:
            # No parameter was iterated, so there is nothing to synchronise
            # and no variable to attach the sync ops to.
            return

        for ring_id in range(self.nrings):
            block.append_op(
                type='c_sync_comm_stream',
                inputs={'X': param},
                outputs={'Out': param},
                attrs={'ring_id': ring_id, self.op_role_key: OpRole.Forward},
            )

    def _is_loss_grad_op(self, op):
        """Return a truthy value if ``op`` has both Backward and Loss role bits."""
        if self.op_role_key not in op.attr_names:
            return False
        op_role = int(op.all_attrs()[self.op_role_key])
        return op_role & int(OpRole.Backward) and op_role & int(OpRole.Loss)

    def _is_backward_op(self, op):
        """Return a truthy value if ``op`` carries the Backward role bit."""
        return self.op_role_key in op.attr_names and int(
            op.all_attrs()[self.op_role_key]
        ) & int(OpRole.Backward)

    def _is_update_op(self, op):
        """Return True if ``op`` has 'Param', 'Grad' and 'LearningRate' inputs."""
        return (
            'Param' in op.input_names
            and 'Grad' in op.input_names
            and "LearningRate" in op.input_names
        )

    def _is_optimizer_op(self, op):
        """Return a truthy value if ``op`` carries the Optimize role bit."""
        return self.op_role_key in op.attr_names and int(
            op.all_attrs()[self.op_role_key]
        ) & int(OpRole.Optimize)
class GradAllReduce(Collective):
    """Transpiler that adds one ``c_allreduce_sum`` op per parameter gradient."""

    def __init__(self, nrings=2):
        Collective.__init__(self, nrings)
        self.mode = "grad_allreduce"

    def _transpile_main_program(self):
        """Scale the loss gradient, then insert the gradient all-reduce ops."""
        self._insert_scale_loss_grad_ops()
        self._insert_allreduce_ops()

    def _insert_scale_loss_grad_ops(self):
        """Insert a ``scale`` op (factor ``1 / nranks``) after each loss-grad op.

        Keeps the learning rate consistent across different numbers of
        training workers by scaling the loss gradient by the worker count.
        """
        main_block = self.main_program.global_block()
        # Walk a snapshot back-to-front so that insertions do not shift the
        # positions of ops that are still to be visited.
        ops_snapshot = list(main_block.ops)
        for position in range(len(ops_snapshot) - 1, -1, -1):
            candidate = ops_snapshot[position]
            if not self._is_loss_grad_op(candidate):
                continue
            loss_grad = main_block.vars[candidate.output_arg_names[0]]
            scale_attrs = {
                'scale': 1.0 / self.nranks,
                self.op_role_key: OpRole.Backward,
            }
            main_block._insert_op(
                position + 1,
                type='scale',
                inputs={'X': loss_grad},
                outputs={'Out': loss_grad},
                attrs=scale_attrs,
            )

    def _insert_allreduce_ops(self):
        """Insert sync / all-reduce ops behind every gradient-producing op.

        For each backward op that lists (param, grad) name pairs under the
        op-role-var attribute, one ``c_sync_calc_stream`` op is inserted right
        after it, followed by a ``c_allreduce_sum`` per non-distributed
        parameter, with ring ids cycling over ``nrings``. Finally one
        ``c_sync_comm_stream`` per ring is placed before the first optimizer
        op, attached to the last gradient variable that was looked up.
        """
        main_block = self.main_program.global_block()
        ring = -1
        last_grad = None

        ops_snapshot = list(main_block.ops)
        for op_index in range(len(ops_snapshot) - 1, -1, -1):
            backward_op = ops_snapshot[op_index]
            if not self._is_backward_op(backward_op):
                continue
            if self.op_role_var_key not in backward_op.attr_names:
                continue

            role_vars = backward_op.all_attrs()[self.op_role_var_key]
            if len(role_vars) == 0:
                continue
            assert len(role_vars) % 2 == 0

            insert_at = op_index
            for pair_start in range(0, len(role_vars), 2):
                param = main_block.vars[role_vars[pair_start]]
                last_grad = main_block.vars[role_vars[pair_start + 1]]
                if param.is_distributed:
                    continue

                if insert_at == op_index:
                    # First gradient of this op: place the calc-stream sync
                    # directly behind the op, then move past it.
                    insert_at += 1
                    main_block._insert_op(
                        insert_at,
                        type='c_sync_calc_stream',
                        inputs={'X': last_grad},
                        outputs={'Out': last_grad},
                        attrs={self.op_role_key: OpRole.Backward},
                    )
                    insert_at += 1

                # Ops are visited in reverse, and the all-reduce ops are
                # inserted in that same order so ring ids keep alternating.
                ring = (ring + 1) % self.nrings
                main_block._insert_op(
                    insert_at,
                    type='c_allreduce_sum',
                    inputs={'X': last_grad},
                    outputs={'Out': last_grad},
                    attrs={
                        'ring_id': ring,
                        self.op_role_key: OpRole.Backward,
                    },
                )

        if last_grad is None:
            return

        first_optimizer_index = next(
            (
                position
                for position, candidate in enumerate(main_block.ops)
                if self._is_optimizer_op(candidate)
            ),
            None,
        )
        if first_optimizer_index is None:
            return

        for ring in range(self.nrings):
            main_block._insert_op(
                first_optimizer_index + ring,
                type='c_sync_comm_stream',
                inputs={'X': last_grad},
                outputs={'Out': last_grad},
                attrs={
                    'ring_id': ring,
                    self.op_role_key: OpRole.Backward,
                },
            )
class LocalSGD(Collective):
    """Transpiler that keeps a snapshot of each parameter and averages deltas.

    Every non-distributed parameter gets a persistable snapshot variable named
    ``<param name> + '@SNAPSHOT'``. The main program is rewritten so that the
    difference between snapshot and parameter is all-reduced, scaled by
    ``1 / nranks``, subtracted from the snapshot, and the result is stored as
    both the new parameter and the new snapshot.
    """

    def __init__(self, nrings=2):
        Collective.__init__(self, nrings)
        self.snapshot_key = '@SNAPSHOT'
        self.mode = "local_sgd"

    def _transpile_startup_program(self):
        """Run the base start-up rewrite, then initialise the snapshots.

        For each non-distributed parameter a snapshot variable is created and
        an ``assign`` op copying the parameter into it is appended.
        """
        Collective._transpile_startup_program(self)

        block = self.startup_program.global_block()
        # Collect first: variables are created in the same block below.
        non_dist_params = [
            param for param in block.iter_parameters() if not param.is_distributed
        ]

        for param in non_dist_params:
            snapshot = block.create_var(
                name=self.snapshot_name(param.name),
                shape=param.shape,
                persistable=True,
                stop_gradient=True,
            )
            block.append_op(
                type='assign',
                inputs={'X': [param]},
                outputs={'Out': [snapshot]},
                attrs={self.op_role_key: OpRole.Forward},
            )

    def snapshot_name(self, param_name):
        """Return the snapshot variable name for ``param_name``."""
        return param_name + self.snapshot_key

    def _transpile_main_program(self):
        """Insert the delta all-reduce and snapshot-update ops.

        Directly after every update op of a non-distributed parameter, three
        ops are inserted: ``elementwise_sub`` (snapshot - param, written into
        param), ``c_sync_calc_stream`` and ``c_allreduce_sum``. At the end of
        the block one ``c_sync_comm_stream`` per ring is appended, followed,
        per parameter, by ``scale`` (``1 / nranks``), ``elementwise_sub``
        (snapshot - param, written into param) and ``assign`` (param into
        snapshot). When the block contains no update op, it is left unchanged.
        """
        block = self.main_program.global_block()
        ordered_param_snapshot = []
        ring_id = -1
        # Stays None when no update op is found; the sync ops appended after
        # the loop need a variable to attach to.
        param = None
        for idx, op in reversed(list(enumerate(block.ops))):
            if self._is_update_op(op):
                param = block.vars[op.input('Param')[0]]
                if param.is_distributed:
                    continue

                snapshot = block.create_var(
                    name=self.snapshot_name(param.name),
                    shape=param.shape,
                    persistable=True,
                    stop_gradient=True,
                    dtype=param.dtype,
                )

                block._insert_op(
                    idx + 1,
                    type='elementwise_sub',
                    inputs={'X': [snapshot], 'Y': [param]},
                    outputs={'Out': [param]},
                    attrs={self.op_role_key: OpRole.Optimize},
                )
                block._insert_op(
                    idx + 2,
                    type='c_sync_calc_stream',
                    inputs={'X': param},
                    outputs={'Out': param},
                    attrs={self.op_role_key: OpRole.Optimize},
                )
                ring_id = (ring_id + 1) % self.nrings
                block._insert_op(
                    idx + 3,
                    type='c_allreduce_sum',
                    inputs={'X': [param]},
                    outputs={'Out': [param]},
                    attrs={
                        'ring_id': ring_id,
                        self.op_role_key: OpRole.Optimize,
                    },
                )

                ordered_param_snapshot.append((param, snapshot))

        if param is None:
            # No update op in the program: previously this fell through to the
            # sync loop below and failed on the unbound loop variable.
            return

        for ring_id in range(self.nrings):
            block.append_op(
                type='c_sync_comm_stream',
                inputs={'X': param},
                outputs={'Out': param},
                attrs={'ring_id': ring_id, self.op_role_key: OpRole.Optimize},
            )

        # Pairs were collected while walking the ops backwards; reverse them
        # so the trailing ops follow the original op order.
        for param, snapshot in reversed(ordered_param_snapshot):
            block.append_op(
                type='scale',
                inputs={'X': [param]},
                outputs={'Out': [param]},
                attrs={
                    'scale': 1.0 / self.nranks,
                    self.op_role_key: OpRole.Optimize,
                },
            )
            block.append_op(
                type='elementwise_sub',
                inputs={'X': [snapshot], 'Y': [param]},
                outputs={'Out': [param]},
                attrs={self.op_role_key: OpRole.Optimize},
            )
            block.append_op(
                type='assign',
                inputs={'X': [param]},
                outputs={'Out': [snapshot]},
                attrs={self.op_role_key: OpRole.Optimize},
            )
class SingleProcessMultiThread(GradAllReduce):
    """
    single process multi thread mode

    Configuration is read from the environment at construction time:
    ``PADDLE_FUSE_ALLREDUCE`` (default ``1``) selects the all-reduce strategy,
    ``PADDLE_LOSS_SCALE`` (default ``1``) enables loss-gradient scaling, and
    the number of comma-separated entries in ``FLAGS_selected_gpus`` (default
    ``0,1,2,3,4,5,6,7``) is used as the GPU count.
    """

    def __init__(self):
        GradAllReduce.__init__(self, 1)
        self.mode = "single_process_multi_thread"
        self.fuse_allreduce = int(os.getenv("PADDLE_FUSE_ALLREDUCE", "1"))
        self.loss_scale = int(os.getenv("PADDLE_LOSS_SCALE", "1"))
        self.gpu_nums = len(
            os.getenv("FLAGS_selected_gpus", "0,1,2,3,4,5,6,7").split(",")
        )

    def _transpile_startup_program(self):
        """Append communicator set-up ops for the multi- or single-node case.

        The job counts as multi-node when the endpoints contain more than one
        distinct host part (the text before ``':'``). In that case ``nranks``
        becomes the number of distinct hosts and a multi-trainer communicator
        is initialised per ring; otherwise ``nranks`` is 1 and a single
        ``c_comm_init_all`` op is appended.
        """
        nodes_num = 0
        if len(self.endpoints) > 1:
            nodes_num = len({x.split(':')[0] for x in self.endpoints})
        # different ip num is multi node
        if nodes_num > 1:
            self.nranks = nodes_num
            print("begin to _transpile_startup_program for multi-node")
            print("current_endpoint: ", self.current_endpoint)
            print("total endpoints: ", self.endpoints)
            print("rank: %d, ring_id: %d" % (self.rank, self.nrings))
            for ring_id in range(self.nrings):
                self._init_communicator(
                    self.startup_program,
                    self.current_endpoint,
                    self.endpoints,
                    self.rank,
                    ring_id,
                    self.wait_port,
                    True,
                )
        else:
            self.nranks = 1
            print("begin to _transpile_startup_program for single-node")
            block = self.startup_program.global_block()
            block.append_op(type='c_comm_init_all', attrs={'ring_id': 0})

    def _transpile_main_program(self):
        """Insert loss scaling and the configured gradient all-reduce ops."""
        # not need loss scale and no dense param
        param_cnt = self._get_update_param_count()
        if self.loss_scale == 0 and param_cnt == 0:
            return
        # scale loss
        if self.loss_scale:
            self._insert_scale_loss_grad_ops(param_cnt)
        # no param
        if param_cnt == 0:
            return
        # fuse allreduce
        if self.fuse_allreduce > 0:
            print("begin used fuse_allreduce param count = %s" % (param_cnt))
            # use fuse allreduce
            self._insert_fuse_allreduce_ops()
        else:
            self._insert_allreduce_ops()

    def _get_update_param_count(self):
        """
        get need update param count

        Counts the non-distributed parameters listed in the op-role-var
        attribute of the backward ops of the main program.
        """
        param_count = 0
        block = self.main_program.global_block()
        for op in reversed(block.ops):
            if not self._is_backward_op(op):
                continue
            if self.op_role_var_key not in op.attr_names:
                continue
            op_role_var = op.all_attrs()[self.op_role_var_key]
            if len(op_role_var) == 0:
                continue
            assert len(op_role_var) % 2 == 0
            # Entries alternate between parameter name and gradient name.
            for i in range(0, len(op_role_var), 2):
                param = block.vars[op_role_var[i]]
                if param.is_distributed:
                    continue
                param_count = param_count + 1

        return param_count

    def _insert_scale_loss_grad_ops(self, param_cnt):
        '''
        In order to keep the learning rate consistent in different numbers of
        training workers, we scale the loss grad by the number of workers

        The factor is ``1 / nranks / gpu_nums`` when ``param_cnt`` is
        positive and ``1 / gpu_nums`` otherwise.
        '''
        if param_cnt > 0:
            scale = 1.0 / self.nranks / self.gpu_nums
        else:
            scale = 1.0 / self.gpu_nums
        print("begin _insert_scale_loss_grad_ops scale = %s" % (scale))
        block = self.main_program.global_block()
        for idx, op in reversed(list(enumerate(block.ops))):
            if not self._is_loss_grad_op(op):
                continue
            loss_grad_var = block.vars[op.output_arg_names[0]]
            block._insert_op(
                idx + 1,
                type='scale',
                inputs={'X': loss_grad_var},
                outputs={'Out': loss_grad_var},
                attrs={'scale': scale, self.op_role_key: OpRole.Backward},
            )

    def _insert_fuse_allreduce_ops(self):
        """
        insert coalesce_tensor and all reduce ops

        The gradients of all non-distributed parameters are collected and the
        new ops are inserted right after the last backward op that produces
        one of them. With ``fuse_allreduce == 2`` a single
        ``c_allreduce_xsum`` over the gradient list is used; otherwise the
        gradients are fused with ``coalesce_tensor`` and the fused variable is
        reduced with ``c_allreduce_sum``. Nothing is inserted when no such
        gradient exists.
        """
        block = self.main_program.global_block()
        ring_id = -1
        input_grads = []
        global_offset = 0  # find insert offset of fuse tensor, after the max dense grad offset
        for idx, op in reversed(list(enumerate(block.ops))):
            if (
                self._is_backward_op(op)
                and self.op_role_var_key in op.attr_names
            ):
                op_role_var = op.all_attrs()[self.op_role_var_key]
                if len(op_role_var) == 0:
                    continue
                assert len(op_role_var) % 2 == 0
                for i in range(0, len(op_role_var), 2):
                    param = block.vars[op_role_var[i]]
                    grad = block.vars[op_role_var[i + 1]]
                    if param.is_distributed:
                        continue
                    input_grads.append(grad)
                    global_offset = max(global_offset, idx + 1)

        if not input_grads:
            # Covers both "no gradient found" and "only distributed
            # parameters"; the code below indexes input_grads[0].
            return

        if self.fuse_allreduce == 2:
            # grads aggregation of multi-gpus
            block._insert_op(
                global_offset,
                type='c_sync_calc_stream',
                inputs={'X': input_grads[0]},
                outputs={'Out': input_grads[0]},
                attrs={self.op_role_key: OpRole.Backward},
            )
            global_offset += 1
            ring_id = (ring_id + 1) % self.nrings
            block._insert_op(
                global_offset,
                type='c_allreduce_xsum',
                inputs={'X': input_grads},
                outputs={'Out': input_grads},
                attrs={'ring_id': ring_id, self.op_role_key: OpRole.Backward},
            )
            global_offset += 1
            # sync before adam
            block._insert_op(
                global_offset,
                type='c_sync_comm_stream',
                inputs={'X': input_grads[0]},
                outputs={'Out': input_grads[0]},
                attrs={'ring_id': ring_id, self.op_role_key: OpRole.Backward},
            )
        else:
            # init output_grads
            output_grads = input_grads
            # init fused_output with temp shape, it will calculate real shape depend on inputs
            fused_output = block.create_var(
                name="fused_output",
                shape=[1],
                persistable=False,
                dtype=core.VarDesc.VarType.FP32,
                stop_gradient=True,
            )
            # fuse all grad tensors
            coalesce_tensor_attrs = {
                "copy_data": True,
                "set_constant": False,
                "dtype": core.VarDesc.VarType.FP32,
            }
            block._insert_op(
                global_offset,
                type='coalesce_tensor',
                inputs={'Input': input_grads},
                outputs={'Output': output_grads, 'FusedOutput': fused_output},
                attrs=coalesce_tensor_attrs,
            )
            global_offset += 1
            # grads aggregation of multi-gpus
            block._insert_op(
                global_offset,
                type='c_sync_calc_stream',
                inputs={'X': fused_output},
                outputs={'Out': fused_output},
                attrs={self.op_role_key: OpRole.Backward},
            )
            global_offset += 1
            ring_id = (ring_id + 1) % self.nrings
            block._insert_op(
                global_offset,
                type='c_allreduce_sum',
                inputs={'X': fused_output},
                outputs={'Out': fused_output},
                attrs={'ring_id': ring_id, self.op_role_key: OpRole.Backward},
            )
            global_offset += 1

            # sync before adam
            block._insert_op(
                global_offset,
                type='c_sync_comm_stream',
                inputs={'X': fused_output},
                outputs={'Out': fused_output},
                attrs={'ring_id': ring_id, self.op_role_key: OpRole.Backward},
            )
class MultiThread(GradAllReduce):
    """Transpiler for the ``"box"`` mode with a selectable gradient strategy.

    ``trans_mode`` chooses how gradients are combined in the main program:
    ``"all_gather"``, ``"fuse_all_reduce"``, ``"all_reduce_xpu"`` (skipped
    when only one device is selected), or, for any other value, the plain
    per-gradient all-reduce inherited from ``GradAllReduce``.
    """

    # Value assumed for FLAGS_selected_gpus when the variable is not set.
    _DEFAULT_SELECTED_GPUS = "0,1,2,3,4,5,6,7,8"

    def __init__(self, nrings=1, trans_mode="fuse_all_reduce"):
        GradAllReduce.__init__(self, nrings)
        self.mode = "box"
        self.trans_mode = trans_mode
        # Maximum number of gradients fused into one coalesced segment.
        self.fuse_grad_size_in_num = 128
        self.gpu_num = len(self._selected_gpus())

    @staticmethod
    def _selected_gpus():
        """Return the entries of ``FLAGS_selected_gpus`` as a list of strings.

        Falls back to ``_DEFAULT_SELECTED_GPUS`` when the variable is unset,
        so every reader sees the same default and none of them calls
        ``split`` on ``None``.
        """
        return os.getenv(
            "FLAGS_selected_gpus", MultiThread._DEFAULT_SELECTED_GPUS
        ).split(",")

    def _transpile_startup_program(self):
        """Append communicator set-up ops for the multi- or single-node case.

        With more than one endpoint a multi-trainer communicator is
        initialised per ring. Otherwise one ``c_comm_init_all`` op is
        appended; when ``trans_mode`` contains ``"xpu"`` it additionally
        carries the selected device ids as its ``devices`` attribute.
        """
        if len(self.endpoints) > 1:
            print("begin to _transpile_startup_program for multi-node")
            print("current_endpoint: ", self.current_endpoint)
            print("total endpoints: ", self.endpoints)
            print("rank: %d, ring_id: %d" % (self.rank, self.nrings))
            for ring_id in range(self.nrings):
                self._init_communicator(
                    self.startup_program,
                    self.current_endpoint,
                    self.endpoints,
                    self.rank,
                    ring_id,
                    self.wait_port,
                    True,
                )
        else:
            if "xpu" in self.trans_mode:
                print(
                    "begin to _transpile_startup_program for single-node in XPU"
                )
                block = self.startup_program.global_block()
                block.append_op(
                    type='c_comm_init_all',
                    attrs={
                        'devices': list(map(int, self._selected_gpus())),
                        'ring_id': 0,
                    },
                )
            else:
                print("begin to _transpile_startup_program for single-node")
                block = self.startup_program.global_block()
                block.append_op(type='c_comm_init_all', attrs={'ring_id': 0})

    def _transpile_main_program(self):
        """Scale the loss gradient and apply the strategy from ``trans_mode``."""
        self._insert_scale_loss_grad_ops()
        if self.trans_mode == "all_gather":
            print("begin to transpile in all-gather mode")
            self.allgather_ranks = self.nranks * self.gpu_num
            self._insert_allgather_ops()
            self._update_adam_ops()
        elif self.trans_mode == "fuse_all_reduce":
            print("begin to transpile in fuse all-reduce mode")
            self._insert_fuse_allreduce_ops()
        elif (
            self.trans_mode == "all_reduce_xpu"
            and len(self._selected_gpus()) == 1
        ):
            print(
                "skip transpile in all-reduce-xpu mode when number of devices is only one"
            )
        else:
            print("begin to transpile in all-reduce mode")
            self._insert_allreduce_ops()

    def _insert_allgather_ops(self):
        """
        insert allgather op to the main_program

        For every (param, grad) pair listed on a backward op, a variable
        named ``<param name> + '_allgather'`` with shape
        ``[allgather_ranks] + param.shape`` is created. For non-distributed
        parameters a ``c_allgather`` op writing the gradient into that
        variable is inserted after the backward op, preceded once per op by a
        ``c_sync_calc_stream``. One ``c_sync_comm_stream`` per ring is placed
        before the first optimizer op.
        """
        block = self.main_program.global_block()
        ring_id = -1
        grad = None
        for idx, op in reversed(list(enumerate(block.ops))):
            if (
                self._is_backward_op(op)
                and self.op_role_var_key in op.attr_names
            ):
                op_role_var = op.all_attrs()[self.op_role_var_key]
                if len(op_role_var) == 0:
                    continue
                assert len(op_role_var) % 2 == 0

                offset = idx
                for i in range(0, len(op_role_var), 2):
                    param = block.vars[op_role_var[i]]
                    new_grad_var = block.create_var(
                        name=op_role_var[i] + "_allgather",
                        shape=[self.allgather_ranks] + list(param.shape),
                        persistable=False,
                        dtype=core.VarDesc.VarType.FP32,
                        stop_gradient=True,
                    )
                    grad = block.vars[op_role_var[i + 1]]
                    if param.is_distributed:  # no need to care: used in PLSC
                        continue

                    if offset == idx:
                        offset += 1
                        block._insert_op(
                            offset,
                            type='c_sync_calc_stream',
                            inputs={'X': grad},
                            outputs={'Out': grad},
                            attrs={self.op_role_key: OpRole.Backward},
                        )
                        offset += 1

                    # As we search ops reversely, we should insert c_allgather
                    # op in the same way to keep the ring_id alternate
                    ring_id = (ring_id + 1) % self.nrings
                    block._insert_op(
                        offset,
                        type='c_allgather',
                        inputs={'X': grad},
                        outputs={'Out': new_grad_var},
                        attrs={
                            'nranks': self.allgather_ranks,
                            'ring_id': ring_id,
                            self.op_role_key: OpRole.Backward,
                        },
                    )

        if grad is None:
            return

        for idx, op in enumerate(block.ops):
            if self._is_optimizer_op(op):
                for ring_id in range(self.nrings):
                    block._insert_op(
                        idx + ring_id,
                        type='c_sync_comm_stream',
                        inputs={'X': grad},
                        outputs={'Out': grad},
                        attrs={
                            'ring_id': ring_id,
                            self.op_role_key: OpRole.Backward,
                        },
                    )
                break

    def _update_adam_ops(self):
        """
        remove the original adam op, and add new adam ops

        Each ``adam`` / ``lamb`` optimizer op is replaced by a ``split`` of
        the parameter's ``_allgather`` variable into ``allgather_ranks``
        pieces along axis 0, followed by one op of the original type per
        piece with that piece as its ``Grad`` input.
        """
        block = self.main_program.global_block()

        for idx, op in reversed(list(enumerate(block.ops))):
            if self._is_optimizer_op(op):
                offset = idx
                if (
                    op.type != 'adam' and op.type != 'lamb'
                ):  # filter out scale op
                    continue
                param_name = op.input("Param")[0]
                inputs = {
                    "Param": block.vars[op.input("Param")[0]],
                    "LearningRate": block.vars[op.input("LearningRate")[0]],
                    "Moment1": block.vars[op.input("Moment1")[0]],
                    "Moment2": block.vars[op.input("Moment2")[0]],
                    "Beta1Pow": block.vars[op.input("Beta1Pow")[0]],
                    "Beta2Pow": block.vars[op.input("Beta2Pow")[0]],
                }
                outputs = {
                    "ParamOut": block.vars[op.output("ParamOut")[0]],
                    "Moment1Out": block.vars[op.output("Moment1Out")[0]],
                    "Moment2Out": block.vars[op.output("Moment2Out")[0]],
                    "Beta1PowOut": block.vars[op.output("Beta1PowOut")[0]],
                    "Beta2PowOut": block.vars[op.output("Beta2PowOut")[0]],
                }
                attrs = {
                    "epsilon": op.attr('epsilon'),
                    "beta1": op.attr('beta1'),
                    "beta2": op.attr('beta2'),
                    "lazy_mode": op.attr('lazy_mode'),
                    "min_row_size_to_use_multithread": op.attr(
                        'min_row_size_to_use_multithread'
                    ),
                }
                split_vars = [
                    block.create_var(
                        name=param_name + "_" + str(i),
                        shape=block.vars[op.input("Param")[0]].shape,
                        persistable=False,
                        dtype=core.VarDesc.VarType.FP32,
                        stop_gradient=True,
                    )
                    for i in range(self.allgather_ranks)
                ]
                block._insert_op(
                    offset,
                    type="split",
                    inputs={
                        'X': block.vars[op.input("Param")[0] + "_allgather"]
                    },
                    outputs={'Out': split_vars},
                    attrs={'num': self.allgather_ranks, 'axis': 0},
                )
                offset += 1

                for i in range(self.allgather_ranks):
                    inputs["Grad"] = split_vars[i]
                    block._insert_op(
                        offset,
                        type=op.type,
                        inputs=inputs,
                        outputs=outputs,
                        attrs=attrs,
                    )
                    offset += 1
                # remove the original adam op
                # (the insertions above pushed it to position ``offset``)
                block._remove_op(offset)

    def _insert_fuse_allreduce_ops(self):
        """
        insert coalesce_tensor and all reduce ops

        Gradients of non-distributed parameters are grouped into segments of
        equal dtype holding at most ``fuse_grad_size_in_num`` entries. Before
        the first optimizer op, each segment gets a ``coalesce_tensor`` op,
        each fused variable a ``c_sync_calc_stream`` / ``c_allreduce_sum``
        pair, and a single ``c_sync_comm_stream`` is inserted on the first
        fused variable. The block is synchronised with its C++ counterpart at
        the end.
        """
        block = self.main_program.global_block()
        ring_id = 0 % self.nrings
        grad = None
        param_grads = []
        # find all grad params
        for op in reversed(block.ops):
            if (
                self._is_backward_op(op)
                and self.op_role_var_key in op.attr_names
            ):
                op_role_var = op.all_attrs()[self.op_role_var_key]
                if len(op_role_var) == 0:
                    continue
                assert len(op_role_var) % 2 == 0, (
                    "vars need to be one param var followed by one grad var, "
                    "but got odd number of vars"
                )
                for i in range(0, len(op_role_var), 2):
                    param_name = op_role_var[i]
                    param = block.var(param_name)
                    grad_name = op_role_var[i + 1]
                    grad = block.var(grad_name)
                    if param.is_distributed:
                        continue
                    param_grads.append(grad)

        if grad is None:
            return

        segments = []
        last_dtype = None
        # split the grad based on dtype and fused size
        for var in param_grads:
            if (
                len(segments) == 0
                or len(segments[-1]) == self.fuse_grad_size_in_num
                or var.dtype != last_dtype
            ):
                segments.append([var])
                last_dtype = var.dtype
            else:
                segments[-1].append(var)

        fused_vars = []
        for idx, op in enumerate(block.ops):
            if self._is_optimizer_op(op):
                for segment in segments:
                    # insert coalesce tensor
                    tmp_var = block.create_var(
                        name=unique_name.generate(
                            f'FusedOutput_{segment[0].name}'
                        ),
                        dtype=segment[0].dtype,
                        persistable=False,
                        stop_gradient=True,
                    )
                    fused_vars.append(tmp_var)
                    block._insert_op(
                        idx,
                        type="coalesce_tensor",
                        inputs={"Input": segment},
                        outputs={"Output": segment, "FusedOutput": tmp_var},
                        attrs={
                            "copy_data": True,
                            "use_align": True,
                            "dtype": segment[0].dtype,
                            self.op_role_key: OpRole.Backward,
                        },
                    )
                break

        # insert the allreduce_sum op
        for idx, op in enumerate(block.ops):
            if self._is_optimizer_op(op):
                for fused_var in fused_vars:
                    # Both ops go in at the same index, so the sync op that is
                    # inserted second ends up in front of the all-reduce.
                    block._insert_op(
                        idx,
                        type='c_allreduce_sum',
                        inputs={'X': fused_var},
                        outputs={'Out': fused_var},
                        attrs={
                            'ring_id': ring_id,
                            'use_calc_stream': False,
                            self.op_role_key: OpRole.Backward,
                        },
                    )
                    block._insert_op(
                        idx,
                        type='c_sync_calc_stream',
                        inputs={'X': fused_var},
                        outputs={'Out': fused_var},
                        attrs={self.op_role_key: OpRole.Backward},
                    )
                break

        if len(fused_vars) == 0:
            block._sync_with_cpp()
            return

        # insert the sync comm op
        for idx, op in enumerate(block.ops):
            if self._is_optimizer_op(op):
                block._insert_op(
                    idx,
                    type='c_sync_comm_stream',
                    inputs={'X': fused_vars[0]},
                    outputs={'Out': fused_vars[0]},
                    attrs={
                        'ring_id': ring_id,
                        self.op_role_key: OpRole.Backward,
                    },
                )
                break
        block._sync_with_cpp()
|