| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100 |
- # Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
- #
- # Licensed under the Apache License, Version 2.0 (the "License");
- # you may not use this file except in compliance with the License.
- # You may obtain a copy of the License at
- #
- # http://www.apache.org/licenses/LICENSE-2.0
- #
- # Unless required by applicable law or agreed to in writing, software
- # distributed under the License is distributed on an "AS IS" BASIS,
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- # See the License for the specific language governing permissions and
- # limitations under the License.
- import copy
- from paddle.distributed import fleet
- from paddle.framework import in_dynamic_mode
- from .meta_optimizers import HeterParallelOptimizer, HybridParallelOptimizer
- from .utils.log_util import logger
- def _dygraph_distributed_optimizer(optimizer, strategy=None):
- """
- Optimizer for distributed training.
- For the distributed training, this method would rebuild a new instance of DistributedOptimizer.
- Which has basic Optimizer function and special features for distributed training.
- Args:
- optimizer(Optimizer): The executor to run for init server.
- strategy(DistributedStrategy): Extra properties for distributed optimizer.
- It is recommended to use DistributedStrategy in fleet.init(). The strategy
- here is for compatibility. If the strategy in fleet.distributed_optimizer()
- is not None, then it will overwrite the DistributedStrategy in fleet.init(),
- which will take effect in distributed training.
- Returns:
- Fleet: instance of fleet.
- Examples:
- .. code-block:: python
- >>> import paddle
- >>> import paddle.distributed.fleet as fleet
- >>> fleet.init(is_collective=True)
- >>> strategy = fleet.DistributedStrategy()
- >>> linear = paddle.nn.Linear(10, 10)
- >>> optimizer = paddle.optimizer.SGD(learning_rate=0.001, parameters=linear.parameters())
- >>> optimizer = fleet.distributed_optimizer(optimizer, strategy=strategy)
- """
- fleet_env = fleet.fleet
- fleet_env.user_defined_optimizer = optimizer
- if strategy is not None:
- if fleet_env._is_collective:
- logger.warning(
- "It is recommended to use DistributedStrategy "
- "in fleet_env.init(). The strategy here is only for compatibility. "
- "If the strategy in fleet_env.distributed_optimizer() is "
- "not None, then it will overwrite the DistributedStrategy in fleet_env.init(), "
- "which will take effect in distributed training."
- )
- fleet_env._user_defined_strategy = copy.deepcopy(strategy)
- fleet_env._context = {}
- if fleet_env.worker_num() > 1:
- if not fleet_env._user_defined_strategy.heter_ccl_mode:
- hp_optim = HybridParallelOptimizer(
- optimizer, fleet_env._hcg, fleet_env._user_defined_strategy
- )
- if fleet_env._user_defined_strategy.hybrid_configs[
- "pp_configs"
- ].dp_comm_overlap:
- # grad all-reduce of dp and sep with be fused
- hp_optim._dp_enable = False
- hp_optim._sep_enable = False
- if fleet_env._user_defined_strategy.hybrid_configs[
- "pp_configs"
- ].sharding_comm_overlap:
- hp_optim._sharding_enable = False
- assert (
- not hp_optim._sep_enable
- ), "sep parallel can not coexist with sharding_comm_overlap"
- return hp_optim
- else:
- return HeterParallelOptimizer(
- optimizer, fleet_env._user_defined_strategy
- )
- else:
- return optimizer
- def distributed_optimizer(*args, **kwargs):
- if in_dynamic_mode():
- return _dygraph_distributed_optimizer(*args, **kwargs)
- else:
- return fleet.fleet.distributed_optimizer(*args, **kwargs)
|