| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154 |
- # Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
- #
- # Licensed under the Apache License, Version 2.0 (the "License");
- # you may not use this file except in compliance with the License.
- # You may obtain a copy of the License at
- #
- # http://www.apache.org/licenses/LICENSE-2.0
- #
- # Unless required by applicable law or agreed to in writing, software
- # distributed under the License is distributed on an "AS IS" BASIS,
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- # See the License for the specific language governing permissions and
- # limitations under the License.
- from types import MethodType
- import numpy as np
- import paddle
- from paddle import _C_ops, _legacy_C_ops
- from paddle.distributed import fleet
- from paddle.framework import core
- from .base.topology import ParallelMode
- def distributed_scaler(scaler):
- def unscale_method(self, optimizer):
- if not self._enable:
- return
- param_grads = []
- param_grads_bf16 = []
- param_grads_fp16 = []
- param_grads_fp32 = []
- if getattr(optimizer, '_param_groups', None) and isinstance(
- optimizer._param_groups[0], dict
- ):
- for group in optimizer._param_groups:
- for param in group['params']:
- tgt_grad = None
- if (
- hasattr(param, "main_grad")
- and param.main_grad is not None
- ):
- tgt_grad = param.main_grad
- elif param.grad is not None:
- tgt_grad = param.grad
- if tgt_grad is not None:
- param_grads.append(tgt_grad)
- if tgt_grad.dtype in [
- core.VarDesc.VarType.FP16,
- paddle.float16,
- ]:
- param_grads_fp16.append(tgt_grad)
- elif tgt_grad.dtype in [
- paddle.bfloat16,
- ]:
- param_grads_bf16.append(tgt_grad)
- else:
- param_grads_fp32.append(tgt_grad)
- else:
- strategy = fleet.fleet._user_defined_strategy
- sharding_stage_1_overlap = strategy.hybrid_configs[
- 'sharding_configs'
- ].comm_overlap
- if sharding_stage_1_overlap:
- # If sharding stage 1 enable comm overlap and need do loss scale. Here we have to wait all comm tasks.
- # If no need do loss scale, the wait for all comm tasks will do in the optimizer step.
- assert hasattr(optimizer, "_comm_buffers")
- assert hasattr(optimizer, "_sharding_enable")
- if optimizer._sharding_enable:
- # disable origin grad reduce in hybrid optimizer step
- optimizer._sharding_enable = False
- for buffer in optimizer._comm_buffers:
- buffer.scale_grads()
- # For sharding stage 1 under comm overlap, each rank only have to check finite for the response params.
- # For now, only sharding stage 1 contains this attr, this can be promoted to stage 2 and stage 3.
- assert hasattr(optimizer, "_local_parameter_list")
- parameters = optimizer._local_parameter_list
- else:
- parameters = optimizer._parameter_list
- for param in parameters:
- tgt_grad = None
- if hasattr(param, "main_grad") and param.main_grad is not None:
- tgt_grad = param.main_grad
- elif param.grad is not None:
- tgt_grad = param.grad
- if tgt_grad is not None:
- param_grads.append(tgt_grad)
- if tgt_grad.dtype in [
- core.VarDesc.VarType.FP16,
- paddle.float16,
- ]:
- param_grads_fp16.append(tgt_grad)
- elif tgt_grad.dtype in [
- paddle.bfloat16,
- ]:
- param_grads_bf16.append(tgt_grad)
- else:
- param_grads_fp32.append(tgt_grad)
- temp_found_inf_fp16 = paddle.to_tensor(np.array([0]).astype(np.bool_))
- temp_found_inf_bf16 = paddle.to_tensor(np.array([0]).astype(np.bool_))
- temp_found_inf_fp32 = paddle.to_tensor(np.array([0]).astype(np.bool_))
- self._found_inf = self._temp_found_inf_value_false
- if len(param_grads_fp16):
- _legacy_C_ops.check_finite_and_unscale(
- param_grads_fp16,
- self._scale,
- param_grads_fp16,
- temp_found_inf_fp16,
- )
- self._found_inf = _C_ops.bitwise_or(
- self._found_inf, temp_found_inf_fp16
- )
- if len(param_grads_bf16):
- _legacy_C_ops.check_finite_and_unscale(
- param_grads_bf16,
- self._scale,
- param_grads_bf16,
- temp_found_inf_bf16,
- )
- self._found_inf = _C_ops.bitwise_or(
- self._found_inf, temp_found_inf_bf16
- )
- if len(param_grads_fp32):
- _legacy_C_ops.check_finite_and_unscale(
- param_grads_fp32,
- self._scale,
- param_grads_fp32,
- temp_found_inf_fp32,
- )
- self._found_inf = _C_ops.bitwise_or(
- self._found_inf, temp_found_inf_fp32
- )
- self._found_inf = self._found_inf.cast("int32")
- # TODO(shenliang03) Since dp allreduce in the optimizer is
- # after the grad scaler, check_finite needs to synchronize global
- # information. In the future, we should use check_group to speed.
- paddle.distributed.all_reduce(
- self._found_inf, op=paddle.distributed.ReduceOp.MAX, group=None
- )
- self._found_inf = self._found_inf.cast("bool")
- # Only data_parallel doesn't need to modify scaler
- fleet_env = fleet.fleet
- if fleet_env._hcg.get_parallel_mode() is not ParallelMode.DATA_PARALLEL:
- scaler._unscale = MethodType(unscale_method, scaler)
- return scaler
|