optimizer.py 4.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100
  1. # Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. import copy
  15. from paddle.distributed import fleet
  16. from paddle.framework import in_dynamic_mode
  17. from .meta_optimizers import HeterParallelOptimizer, HybridParallelOptimizer
  18. from .utils.log_util import logger
  19. def _dygraph_distributed_optimizer(optimizer, strategy=None):
  20. """
  21. Optimizer for distributed training.
  22. For the distributed training, this method would rebuild a new instance of DistributedOptimizer.
  23. Which has basic Optimizer function and special features for distributed training.
  24. Args:
  25. optimizer(Optimizer): The executor to run for init server.
  26. strategy(DistributedStrategy): Extra properties for distributed optimizer.
  27. It is recommended to use DistributedStrategy in fleet.init(). The strategy
  28. here is for compatibility. If the strategy in fleet.distributed_optimizer()
  29. is not None, then it will overwrite the DistributedStrategy in fleet.init(),
  30. which will take effect in distributed training.
  31. Returns:
  32. Fleet: instance of fleet.
  33. Examples:
  34. .. code-block:: python
  35. >>> import paddle
  36. >>> import paddle.distributed.fleet as fleet
  37. >>> fleet.init(is_collective=True)
  38. >>> strategy = fleet.DistributedStrategy()
  39. >>> linear = paddle.nn.Linear(10, 10)
  40. >>> optimizer = paddle.optimizer.SGD(learning_rate=0.001, parameters=linear.parameters())
  41. >>> optimizer = fleet.distributed_optimizer(optimizer, strategy=strategy)
  42. """
  43. fleet_env = fleet.fleet
  44. fleet_env.user_defined_optimizer = optimizer
  45. if strategy is not None:
  46. if fleet_env._is_collective:
  47. logger.warning(
  48. "It is recommended to use DistributedStrategy "
  49. "in fleet_env.init(). The strategy here is only for compatibility. "
  50. "If the strategy in fleet_env.distributed_optimizer() is "
  51. "not None, then it will overwrite the DistributedStrategy in fleet_env.init(), "
  52. "which will take effect in distributed training."
  53. )
  54. fleet_env._user_defined_strategy = copy.deepcopy(strategy)
  55. fleet_env._context = {}
  56. if fleet_env.worker_num() > 1:
  57. if not fleet_env._user_defined_strategy.heter_ccl_mode:
  58. hp_optim = HybridParallelOptimizer(
  59. optimizer, fleet_env._hcg, fleet_env._user_defined_strategy
  60. )
  61. if fleet_env._user_defined_strategy.hybrid_configs[
  62. "pp_configs"
  63. ].dp_comm_overlap:
  64. # grad all-reduce of dp and sep with be fused
  65. hp_optim._dp_enable = False
  66. hp_optim._sep_enable = False
  67. if fleet_env._user_defined_strategy.hybrid_configs[
  68. "pp_configs"
  69. ].sharding_comm_overlap:
  70. hp_optim._sharding_enable = False
  71. assert (
  72. not hp_optim._sep_enable
  73. ), "sep parallel can not coexist with sharding_comm_overlap"
  74. return hp_optim
  75. else:
  76. return HeterParallelOptimizer(
  77. optimizer, fleet_env._user_defined_strategy
  78. )
  79. else:
  80. return optimizer
  81. def distributed_optimizer(*args, **kwargs):
  82. if in_dynamic_mode():
  83. return _dygraph_distributed_optimizer(*args, **kwargs)
  84. else:
  85. return fleet.fleet.distributed_optimizer(*args, **kwargs)