gather.py 2.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
  14. from paddle import framework
  15. from paddle.distributed.communication import stream
  16. def gather(tensor, gather_list=None, dst=0, group=None, sync_op=True):
  17. """
  18. Gather tensors from all participators.
  19. Args:
  20. tensor (Tensor): The input Tensor. Its data type
  21. should be float16, float32, float64, int32, int64, int8, uint8, bool or bfloat16.
  22. gather_list (list): A list of Tensors to hold the gathered tensors. Every element in the list must be a Tensor whose data type
  23. should be float16, float32, float64, int32, int64, int8, uint8, bool or bfloat16. Default value is None.
  24. dst (int): The dst rank id. Default value is 0.
  25. group (Group, optional): The group instance return by new_group or None for global default group.
  26. sync_op (bool, optional): Whether this op is a sync op. The default value is True.
  27. Returns:
  28. Async work handle,which can be wait on, if async_op is set to True.
  29. None, if not async_op
  30. Examples:
  31. .. code-block:: python
  32. >>> # doctest: +REQUIRES(env: DISTRIBUTED)
  33. >>> import paddle
  34. >>> import paddle.distributed as dist
  35. >>> dist.init_parallel_env()
  36. >>> gather_list = []
  37. >>> if dist.get_rank() == 0:
  38. ... data = paddle.to_tensor([1, 2, 3])
  39. ... dist.gather(data, gather_list, dst=0)
  40. >>> else:
  41. ... data = paddle.to_tensor([4, 5, 6])
  42. ... dist.gather(data1, gather_list, dst=0)
  43. >>> print(gather_list)
  44. >>> # [[1, 2, 3], [4, 5, 6]] (2 GPUs, out for rank 0)
  45. >>> # [] (2 GPUs, out for rank 1)
  46. """
  47. assert (
  48. framework.in_dynamic_mode()
  49. ), "gather doesn't support static graph mode yet."
  50. return stream.gather(tensor, gather_list, dst, group, sync_op)