model.py 6.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178
  1. # Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. import paddle
  15. from paddle.distributed import fleet
  16. from .base.topology import ParallelMode
  17. from .meta_parallel import (
  18. PipelineLayer,
  19. PipelineParallel,
  20. PipelineParallelWithInterleave,
  21. PipelineParallelWithInterleaveFthenB,
  22. SegmentParallel,
  23. ShardingParallel,
  24. TensorParallel,
  25. )
  26. _grad_scalar = None
  27. def distributed_model(model):
  28. """
  29. Return distributed data parallel model (Only work in dygraph mode)
  30. Args:
  31. model (Layer): the user-defined model which inherits Layer.
  32. Returns:
  33. distributed data parallel model which inherits Layer.
  34. Examples:
  35. .. code-block:: python
  36. >>> import paddle
  37. >>> import paddle.nn as nn
  38. >>> from paddle.distributed import fleet
  39. >>> class LinearNet(nn.Layer):
  40. ... def __init__(self):
  41. ... super().__init__()
  42. ... self._linear1 = nn.Linear(10, 10)
  43. ... self._linear2 = nn.Linear(10, 1)
  44. ... def forward(self, x):
  45. ... return self._linear2(self._linear1(x))
  46. >>> # 1. initialize fleet environment
  47. >>> fleet.init(is_collective=True)
  48. >>> # 2. create layer & optimizer
  49. >>> layer = LinearNet()
  50. >>> loss_fn = nn.MSELoss()
  51. >>> adam = paddle.optimizer.Adam(
  52. ... learning_rate=0.001, parameters=layer.parameters())
  53. >>> # 3. get data_parallel model using fleet
  54. >>> adam = fleet.distributed_optimizer(adam)
  55. >>> dp_layer = fleet.distributed_model(layer)
  56. >>> # 4. run layer
  57. >>> inputs = paddle.randn([10, 10], 'float32')
  58. >>> outputs = dp_layer(inputs)
  59. >>> labels = paddle.randn([10, 1], 'float32')
  60. >>> loss = loss_fn(outputs, labels)
  61. >>> print("loss:", loss.numpy())
  62. >>> loss.backward()
  63. >>> adam.step()
  64. >>> adam.clear_grad()
  65. """
  66. fleet_env = fleet.fleet
  67. assert model is not None, "model should not be None"
  68. if paddle.distributed.get_world_size() <= 1:
  69. return model
  70. strategy = fleet_env._user_defined_strategy
  71. if strategy.amp:
  72. level = (
  73. "O2"
  74. if strategy.amp_configs['use_pure_fp16']
  75. or strategy.amp_configs['use_pure_bf16']
  76. else "O1"
  77. )
  78. if level == "O2":
  79. model = paddle.amp.decorate(
  80. models=model,
  81. optimizers=None,
  82. level="O2",
  83. master_weight=None,
  84. save_dtype=None,
  85. dtype="float16"
  86. if strategy.amp_configs['use_pure_fp16']
  87. else "bfloat16",
  88. )
  89. init_loss_scaling = strategy.amp_configs['init_loss_scaling']
  90. incr_ratio = strategy.amp_configs['incr_ratio']
  91. decr_ratio = strategy.amp_configs['decr_ratio']
  92. incr_every_n_steps = strategy.amp_configs['incr_every_n_steps']
  93. decr_every_n_nan_or_inf = strategy.amp_configs[
  94. 'decr_every_n_nan_or_inf'
  95. ]
  96. use_dynamic_loss_scaling = strategy.amp_configs[
  97. 'use_dynamic_loss_scaling'
  98. ]
  99. global _grad_scalar
  100. _grad_scalar = paddle.amp.GradScaler(
  101. init_loss_scaling=init_loss_scaling,
  102. incr_ratio=incr_ratio,
  103. decr_ratio=decr_ratio,
  104. incr_every_n_steps=incr_every_n_steps,
  105. decr_every_n_nan_or_inf=decr_every_n_nan_or_inf,
  106. use_dynamic_loss_scaling=use_dynamic_loss_scaling,
  107. )
  108. if strategy.heter_ccl_mode:
  109. distributed_model = paddle.DataParallel(
  110. model,
  111. comm_buffer_size=strategy.fuse_grad_size_in_MB,
  112. last_comm_buffer_size=strategy.last_comm_group_size_MB,
  113. find_unused_parameters=strategy.find_unused_parameters,
  114. )
  115. return distributed_model
  116. if fleet_env._hcg.get_parallel_mode() == ParallelMode.SHARDING_PARALLEL:
  117. model = ShardingParallel(model, fleet_env._hcg, strategy=strategy)
  118. elif fleet_env._hcg.get_parallel_mode() == ParallelMode.DATA_PARALLEL:
  119. model = paddle.DataParallel(
  120. model,
  121. comm_buffer_size=strategy.fuse_grad_size_in_MB,
  122. last_comm_buffer_size=strategy.last_comm_group_size_MB,
  123. find_unused_parameters=strategy.find_unused_parameters,
  124. group=fleet_env._hcg.get_data_parallel_group(),
  125. )
  126. elif fleet_env._hcg.get_parallel_mode() == ParallelMode.SEGMENT_PARALLEL:
  127. model = SegmentParallel(model, fleet_env._hcg, strategy=strategy)
  128. elif fleet_env._hcg.get_parallel_mode() == ParallelMode.TENSOR_PARALLEL:
  129. model = TensorParallel(model, fleet_env._hcg, strategy=strategy)
  130. elif fleet_env._hcg.get_parallel_mode() == ParallelMode.PIPELINE_PARALLEL:
  131. assert isinstance(
  132. model, PipelineLayer
  133. ), "For pipeline parallel, the model should an instance of PipelineLayer"
  134. if model.get_num_virtual_stages() == 1:
  135. # 1f1b pipeline
  136. model = PipelineParallel(model, fleet_env._hcg, strategy=strategy)
  137. else:
  138. accumulate_steps = strategy.pipeline_configs['accumulate_steps']
  139. pp_degree = fleet_env._hcg.get_pipe_parallel_world_size()
  140. if accumulate_steps >= 2 * pp_degree:
  141. # interleave pipeline
  142. model = PipelineParallelWithInterleave(
  143. model, fleet_env._hcg, strategy=strategy
  144. )
  145. elif pp_degree <= accumulate_steps < 2 * pp_degree:
  146. model = PipelineParallelWithInterleaveFthenB(
  147. model, fleet_env._hcg, strategy=strategy
  148. )
  149. else:
  150. raise ValueError(
  151. f"The accumulate_steps({accumulate_steps}) should be greater than or equal to pp_degree({pp_degree})"
  152. )
  153. return model