scatter.py 5.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145
  1. # Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. import numpy as np
  15. import paddle
  16. import paddle.distributed as dist
  17. from paddle import framework
  18. from paddle.distributed.communication import stream
  19. from .serialization_utils import (
  20. convert_object_to_tensor,
  21. convert_tensor_to_object,
  22. )
  23. def scatter(tensor, tensor_list=None, src=0, group=None, sync_op=True):
  24. """
  25. Scatter a tensor to all participators. As shown below, one process is started with a GPU and the source of the scatter
  26. is GPU0. Through scatter operator, the data in GPU0 will be sent to all GPUs averagely.
  27. .. image:: https://githubraw.cdn.bcebos.com/PaddlePaddle/docs/develop/docs/api/paddle/distributed/img/scatter.png
  28. :width: 800
  29. :alt: scatter
  30. :align: center
  31. Args:
  32. tensor (Tensor): The output Tensor. Its data type
  33. should be float16, float32, float64, int32, int64, int8, uint8, bool or bfloat16.
  34. tensor_list (list|tuple): A list/tuple of Tensors to scatter. Every element in the list must be a Tensor whose data type
  35. should be float16, float32, float64, int32, int64, int8, uint8, bool or bfloat16. Default value is None.
  36. src (int): The source rank id. Default value is 0.
  37. group (Group, optional): The group instance return by new_group or None for global default group.
  38. sync_op (bool, optional): Whether this op is a sync op. The default value is True.
  39. Returns:
  40. None.
  41. Examples:
  42. .. code-block:: python
  43. >>> # doctest: +REQUIRES(env: DISTRIBUTED)
  44. >>> import paddle
  45. >>> import paddle.distributed as dist
  46. >>> dist.init_parallel_env()
  47. >>> if dist.get_rank() == 0:
  48. ... data1 = paddle.to_tensor([7, 8, 9])
  49. ... data2 = paddle.to_tensor([10, 11, 12])
  50. ... dist.scatter(data1, src=1)
  51. >>> else:
  52. ... data1 = paddle.to_tensor([1, 2, 3])
  53. ... data2 = paddle.to_tensor([4, 5, 6])
  54. ... dist.scatter(data1, tensor_list=[data1, data2], src=1)
  55. >>> print(data1, data2)
  56. >>> # [1, 2, 3] [10, 11, 12] (2 GPUs, out for rank 0)
  57. >>> # [4, 5, 6] [4, 5, 6] (2 GPUs, out for rank 1)
  58. """
  59. return stream.scatter(tensor, tensor_list, src, group, sync_op)
  60. def scatter_object_list(
  61. out_object_list, in_object_list=None, src=0, group=None
  62. ):
  63. """
  64. Scatter picklable objects from the source to all others. Similiar to scatter(), but python object can be passed in.
  65. Args:
  66. out_object_list (list): The list of objects to store the scattered objects.
  67. in_object_list (list): The list of objects to scatter. Only objects on the src rank will be scattered.
  68. src (int): The source rank in global view.
  69. group (Group): The group instance return by new_group or None for global default group.
  70. Returns:
  71. None.
  72. Warning:
  73. This API only supports the dygraph mode.
  74. Examples:
  75. .. code-block:: python
  76. >>> # doctest: +REQUIRES(env: DISTRIBUTED)
  77. >>> import paddle.distributed as dist
  78. >>> dist.init_parallel_env()
  79. >>> out_object_list = []
  80. >>> if dist.get_rank() == 0:
  81. ... in_object_list = [{'foo': [1, 2, 3]}, {'foo': [4, 5, 6]}]
  82. >>> else:
  83. ... in_object_list = [{'bar': [1, 2, 3]}, {'bar': [4, 5, 6]}]
  84. >>> dist.scatter_object_list(out_object_list, in_object_list, src=1)
  85. >>> print(out_object_list)
  86. >>> # [{'bar': [1, 2, 3]}] (2 GPUs, out for rank 0)
  87. >>> # [{'bar': [4, 5, 6]}] (2 GPUs, out for rank 1)
  88. """
  89. assert (
  90. framework.in_dynamic_mode()
  91. ), "scatter_object_list doesn't support static graph mode."
  92. rank = dist.get_rank()
  93. in_obj_tensors = []
  94. in_obj_sizes = []
  95. if rank == src:
  96. for obj in in_object_list:
  97. obj_tensor, obj_size = convert_object_to_tensor(obj)
  98. in_obj_tensors.append(obj_tensor)
  99. in_obj_sizes.append(obj_size)
  100. max_obj_size_tensor = max(in_obj_sizes)
  101. else:
  102. max_obj_size_tensor = paddle.empty([], dtype="int64")
  103. stream.broadcast(max_obj_size_tensor, src)
  104. max_obj_size = int(max_obj_size_tensor.item())
  105. # resize to the same size
  106. in_tensor_list = []
  107. for tensor in in_obj_tensors:
  108. numpy_data = tensor.numpy()
  109. numpy_data = np.resize(numpy_data, [max_obj_size])
  110. in_tensor = paddle.to_tensor(numpy_data)
  111. in_tensor_list.append(in_tensor)
  112. out_tensor = paddle.empty([max_obj_size], dtype="uint8")
  113. scatter(out_tensor, in_tensor_list if rank == src else None, src, group)
  114. out_tensor_size = paddle.empty([], dtype="int64")
  115. scatter(out_tensor_size, in_obj_sizes if rank == src else None, src, group)
  116. out_object_list.clear()
  117. out_object_list.append(
  118. convert_tensor_to_object(out_tensor, out_tensor_size.item())
  119. )