scaler.py 6.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154
  1. # Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. from types import MethodType
  15. import numpy as np
  16. import paddle
  17. from paddle import _C_ops, _legacy_C_ops
  18. from paddle.distributed import fleet
  19. from paddle.framework import core
  20. from .base.topology import ParallelMode
  21. def distributed_scaler(scaler):
  22. def unscale_method(self, optimizer):
  23. if not self._enable:
  24. return
  25. param_grads = []
  26. param_grads_bf16 = []
  27. param_grads_fp16 = []
  28. param_grads_fp32 = []
  29. if getattr(optimizer, '_param_groups', None) and isinstance(
  30. optimizer._param_groups[0], dict
  31. ):
  32. for group in optimizer._param_groups:
  33. for param in group['params']:
  34. tgt_grad = None
  35. if (
  36. hasattr(param, "main_grad")
  37. and param.main_grad is not None
  38. ):
  39. tgt_grad = param.main_grad
  40. elif param.grad is not None:
  41. tgt_grad = param.grad
  42. if tgt_grad is not None:
  43. param_grads.append(tgt_grad)
  44. if tgt_grad.dtype in [
  45. core.VarDesc.VarType.FP16,
  46. paddle.float16,
  47. ]:
  48. param_grads_fp16.append(tgt_grad)
  49. elif tgt_grad.dtype in [
  50. paddle.bfloat16,
  51. ]:
  52. param_grads_bf16.append(tgt_grad)
  53. else:
  54. param_grads_fp32.append(tgt_grad)
  55. else:
  56. strategy = fleet.fleet._user_defined_strategy
  57. sharding_stage_1_overlap = strategy.hybrid_configs[
  58. 'sharding_configs'
  59. ].comm_overlap
  60. if sharding_stage_1_overlap:
  61. # If sharding stage 1 enable comm overlap and need do loss scale. Here we have to wait all comm tasks.
  62. # If no need do loss scale, the wait for all comm tasks will do in the optimizer step.
  63. assert hasattr(optimizer, "_comm_buffers")
  64. assert hasattr(optimizer, "_sharding_enable")
  65. if optimizer._sharding_enable:
  66. # disable origin grad reduce in hybrid optimizer step
  67. optimizer._sharding_enable = False
  68. for buffer in optimizer._comm_buffers:
  69. buffer.scale_grads()
  70. # For sharding stage 1 under comm overlap, each rank only have to check finite for the response params.
  71. # For now, only sharding stage 1 contains this attr, this can be promoted to stage 2 and stage 3.
  72. assert hasattr(optimizer, "_local_parameter_list")
  73. parameters = optimizer._local_parameter_list
  74. else:
  75. parameters = optimizer._parameter_list
  76. for param in parameters:
  77. tgt_grad = None
  78. if hasattr(param, "main_grad") and param.main_grad is not None:
  79. tgt_grad = param.main_grad
  80. elif param.grad is not None:
  81. tgt_grad = param.grad
  82. if tgt_grad is not None:
  83. param_grads.append(tgt_grad)
  84. if tgt_grad.dtype in [
  85. core.VarDesc.VarType.FP16,
  86. paddle.float16,
  87. ]:
  88. param_grads_fp16.append(tgt_grad)
  89. elif tgt_grad.dtype in [
  90. paddle.bfloat16,
  91. ]:
  92. param_grads_bf16.append(tgt_grad)
  93. else:
  94. param_grads_fp32.append(tgt_grad)
  95. temp_found_inf_fp16 = paddle.to_tensor(np.array([0]).astype(np.bool_))
  96. temp_found_inf_bf16 = paddle.to_tensor(np.array([0]).astype(np.bool_))
  97. temp_found_inf_fp32 = paddle.to_tensor(np.array([0]).astype(np.bool_))
  98. self._found_inf = self._temp_found_inf_value_false
  99. if len(param_grads_fp16):
  100. _legacy_C_ops.check_finite_and_unscale(
  101. param_grads_fp16,
  102. self._scale,
  103. param_grads_fp16,
  104. temp_found_inf_fp16,
  105. )
  106. self._found_inf = _C_ops.bitwise_or(
  107. self._found_inf, temp_found_inf_fp16
  108. )
  109. if len(param_grads_bf16):
  110. _legacy_C_ops.check_finite_and_unscale(
  111. param_grads_bf16,
  112. self._scale,
  113. param_grads_bf16,
  114. temp_found_inf_bf16,
  115. )
  116. self._found_inf = _C_ops.bitwise_or(
  117. self._found_inf, temp_found_inf_bf16
  118. )
  119. if len(param_grads_fp32):
  120. _legacy_C_ops.check_finite_and_unscale(
  121. param_grads_fp32,
  122. self._scale,
  123. param_grads_fp32,
  124. temp_found_inf_fp32,
  125. )
  126. self._found_inf = _C_ops.bitwise_or(
  127. self._found_inf, temp_found_inf_fp32
  128. )
  129. self._found_inf = self._found_inf.cast("int32")
  130. # TODO(shenliang03) Since dp allreduce in the optimizer is
  131. # after the grad scaler, check_finite needs to synchronize global
  132. # information. In the future, we should use check_group to speed.
  133. paddle.distributed.all_reduce(
  134. self._found_inf, op=paddle.distributed.ReduceOp.MAX, group=None
  135. )
  136. self._found_inf = self._found_inf.cast("bool")
  137. # Only data_parallel doesn't need to modify scaler
  138. fleet_env = fleet.fleet
  139. if fleet_env._hcg.get_parallel_mode() is not ParallelMode.DATA_PARALLEL:
  140. scaler._unscale = MethodType(unscale_method, scaler)
  141. return scaler