recv.py 3.0 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788
  1. # Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. from paddle.distributed.communication import stream
  15. def recv(tensor, src=0, group=None, sync_op=True):
  16. """
  17. Receive a tensor to the sender.
  18. Args:
  19. tensor (Tensor): The tensor to receive. Its data type
  20. should be float16, float32, float64, int32, int64, int8, uint8, bool or bfloat16.
  21. src (int): The source rank id.
  22. group (Group, optional): The group instance return by new_group or None for global default group. Default: None.
  23. sync_op (bool, optional): Whether this op is a sync op. The default value is True.
  24. Returns:
  25. Return a task object.
  26. Examples:
  27. .. code-block:: python
  28. >>> # doctest: +REQUIRES(env: DISTRIBUTED)
  29. >>> import paddle
  30. >>> import paddle.distributed as dist
  31. >>> dist.init_parallel_env()
  32. >>> if dist.get_rank() == 0:
  33. ... data = paddle.to_tensor([7, 8, 9])
  34. ... dist.send(data, dst=1)
  35. >>> else:
  36. ... data = paddle.to_tensor([1, 2, 3])
  37. ... dist.recv(data, src=0)
  38. >>> print(data)
  39. >>> # [7, 8, 9] (2 GPUs)
  40. """
  41. return stream.recv(
  42. tensor, src=src, group=group, sync_op=sync_op, use_calc_stream=False
  43. )
  44. def irecv(tensor, src=None, group=None):
  45. """
  46. Receive a tensor to the sender.
  47. Args:
  48. tensor (Tensor): The Tensor to receive. Its data type
  49. should be float16, float32, float64, int32, int64, int8, uint8, bool or bfloat16.
  50. src (int): The source rank id.
  51. group (Group, optional): The group instance return by new_group or None for global default group. Default: None.
  52. Returns:
  53. Return a task object.
  54. Warning:
  55. This API only supports the dygraph mode.
  56. Examples:
  57. .. code-block:: python
  58. >>> # doctest: +REQUIRES(env: DISTRIBUTED)
  59. >>> import paddle
  60. >>> import paddle.distributed as dist
  61. >>> dist.init_parallel_env()
  62. >>> if dist.get_rank() == 0:
  63. ... data = paddle.to_tensor([7, 8, 9])
  64. ... task = dist.isend(data, dst=1)
  65. >>> else:
  66. ... data = paddle.to_tensor([1, 2, 3])
  67. ... task = dist.irecv(data, src=0)
  68. >>> task.wait()
  69. >>> print(data)
  70. >>> # [7, 8, 9] (2 GPUs)
  71. """
  72. return recv(tensor, src, group, sync_op=False)