rpc.py 12 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388
  1. # Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. import datetime
  15. import os
  16. import pickle
  17. import time
  18. from collections import namedtuple
  19. from paddle.base import core
  20. from paddle.distributed.launch.context import Node
  21. from paddle.distributed.rpc.internal import PythonFunc, _serialize
  22. from paddle.distributed.utils.launch_utils import logger
  23. WorkerInfo = namedtuple("WorkerInfo", ["name", "rank", "ip", "port"])
  24. _DEFAULT_RPC_TIMEOUT = -1
  25. _MAX_RPC_TIMEOUT_MS = 0x7FFFFFFF
  26. _BARRIER_TIMEOUT_MAX_DAYS = 99999999
  27. # tcp store for `_barrier_never_timeout`
  28. _barrier_store = None
  29. # count the number of `_barrier_never_timeout` is called and
  30. # ensure that the barrier key is unique
  31. _barrier_count = 0
  32. def _set_barrier_store(store):
  33. global _barrier_store
  34. _barrier_store = store
  35. def _del_barrier_store():
  36. global _barrier_store
  37. del _barrier_store
  38. def _set_self_info(name, rank, ip, port):
  39. self_info = pickle.dumps(WorkerInfo(name, rank, ip, port))
  40. _barrier_store.set(str(rank), self_info)
  41. def _exchange_all_service_infos(world_size):
  42. all_infos = []
  43. s = set()
  44. for rank in range(world_size):
  45. info = pickle.loads(_barrier_store.get(str(rank)))
  46. assert (
  47. info.name not in s
  48. ), "The Worker name must be unique, but name `{}` is repeated."
  49. s.add(info.name)
  50. all_infos.append(info)
  51. return all_infos
  52. def _gen_endpoint():
  53. node = Node()
  54. ip = node.get_host_ip()
  55. free_port = node.get_free_port()
  56. return f"{ip}:{free_port}"
  57. def init_rpc(name, rank=None, world_size=None, master_endpoint=None):
  58. """
  59. init rpc.
  60. Args:
  61. name (str): worker name.
  62. rank (int, optional): worker id, default is None.
  63. world_size (int, optional): number of workers, default is None.
  64. master_endpoint (str, optional): id address of master, other nodes communicate with the master to
  65. get the information of all worker nodes, default is None.
  66. Returns:
  67. None.
  68. Examples:
  69. .. code-block:: python
  70. >>> # doctest: +REQUIRES(env:DISTRIBUTED)
  71. >>> import paddle.distributed.rpc as rpc
  72. >>> rpc.init_rpc("worker0", rank=0, world_size=1,
  73. ... master_endpoint="127.0.0.1:8001")
  74. >>> rpc.shutdown()
  75. """
  76. rank = int(os.environ["PADDLE_TRAINER_ID"]) if rank is None else rank
  77. world_size = (
  78. int(os.environ["PADDLE_TRAINERS_NUM"])
  79. if world_size is None
  80. else world_size
  81. )
  82. worker_endpoint = os.getenv("PADDLE_WORKER_ENDPOINT", None)
  83. if worker_endpoint is None:
  84. worker_endpoint = _gen_endpoint()
  85. logger.info(f"Trainer {rank}: worker endpoint: {worker_endpoint}")
  86. master_endpoint = (
  87. master_endpoint
  88. if master_endpoint is not None
  89. else os.environ["PADDLE_MASTER_ENDPOINT"]
  90. )
  91. master_addr, master_port = master_endpoint.split(":")
  92. master_port = int(master_port)
  93. stop_check_timeout = int(os.getenv("FLAGS_stop_check_timeout", "900"))
  94. store = core.TCPStore(
  95. master_addr,
  96. master_port,
  97. rank == 0,
  98. world_size,
  99. timeout=stop_check_timeout,
  100. )
  101. _set_barrier_store(store)
  102. ip, port = worker_endpoint.split(":")
  103. port = int(port)
  104. _set_self_info(name, rank, ip, port)
  105. all_infos = _exchange_all_service_infos(world_size)
  106. c_infos = []
  107. for node_info in all_infos:
  108. info = core.WorkerInfo(
  109. node_info.name, node_info.rank, node_info.ip, node_info.port
  110. )
  111. c_infos.append(info)
  112. core.init_and_set_agent_instance(name, c_infos)
  113. core.rpc_start_worker()
  114. # ensure that all the workers are started
  115. _barrier_never_timeout(rank, world_size)
  116. core.rpc_start_client()
  117. logger.info(f"Trainer {rank}: Init RPC done!")
  118. def rpc_sync(to, fn, args=None, kwargs=None, timeout=_DEFAULT_RPC_TIMEOUT):
  119. """
  120. Make a blocking RPC call to run function ``fn`` on worker ``to``. Attention: Users must use this API in a secure network environment.
  121. Args:
  122. to (str): name of the destination worker.
  123. fn (fn): a callable function, such as Python callables.
  124. args (tuple, optional): the argument tuple for the ``fn`` invocation, default is None.
  125. kwargs (dict, optional): is a dictionary of keyword arguments for the ``fn``
  126. invocation, default is None.
  127. timeout (int, optional): timeout in seconds to use for this RPC. If
  128. the RPC does not complete in this amount of
  129. time, an exception indicating it has
  130. timed out will be raised. A value less than or equal to 0
  131. indicates an infinite timeout, i.e. a timeout
  132. error will never be raised. The default value is -1.
  133. Returns:
  134. Returns the result of running ``fn`` with ``args`` and ``kwargs``.
  135. Examples:
  136. .. code-block:: python
  137. >>> # doctest: +REQUIRES(env:DISTRIBUTED)
  138. >>> import paddle.distributed.rpc as rpc
  139. >>> def add(a, b):
  140. ... return a + b
  141. >>> rpc.init_rpc("worker0", rank=0, world_size=1,
  142. ... master_endpoint="127.0.0.1:8002")
  143. >>> ret = rpc.rpc_sync("worker0", add, args=(2, 3))
  144. >>> rpc.shutdown()
  145. """
  146. fut = _invoke_rpc(to, fn, args, kwargs, timeout)
  147. return fut.wait()
  148. def rpc_async(to, fn, args=None, kwargs=None, timeout=_DEFAULT_RPC_TIMEOUT):
  149. """
  150. Make a non-blocking RPC call to run function ``fn`` on worker ``to``. Attention: Users must use this API in a secure network environment.
  151. Args:
  152. to (str): name of the destination worker.
  153. fn (fn): a callable function, such as Python callables.
  154. args (tuple, optional): the argument tuple for the ``fn`` invocation, default is None.
  155. kwargs (dict, optional): is a dictionary of keyword arguments for the ``fn``
  156. invocation, default is None.
  157. timeout (int, optional): timeout in seconds to use for this RPC. If
  158. the RPC does not complete in this amount of
  159. time, an exception indicating it has
  160. timed out will be raised. A value less than or equal to 0
  161. indicates an infinite timeout, i.e. a timeout
  162. error will never be raised. The default value is -1.
  163. Returns:
  164. Returns a :class:`FutureWrapper` object that can be waited
  165. on. When completed, the return value of ``fn`` on ``args`` and
  166. ``kwargs`` can be got by `fut.wait()`.
  167. Examples:
  168. .. code-block:: python
  169. >>> # doctest: +REQUIRES(env:DISTRIBUTED)
  170. >>> import paddle.distributed.rpc as rpc
  171. >>> def add(a, b):
  172. ... return a + b
  173. >>> rpc.init_rpc("worker0", rank=0, world_size=1,
  174. ... master_endpoint="127.0.0.1:8003")
  175. >>> fut = rpc.rpc_async("worker0", add, args=(2, 3))
  176. >>> print(fut.wait())
  177. 5
  178. >>> rpc.shutdown()
  179. """
  180. return _invoke_rpc(to, fn, args, kwargs, timeout)
  181. def _invoke_rpc(to, fn, args, kwargs, timeout):
  182. args = args if args else ()
  183. kwargs = kwargs if kwargs else {}
  184. serial_obj = _serialize(PythonFunc(fn, args, kwargs))
  185. timeout_ms = timeout * 1000
  186. timeout_ms = _MAX_RPC_TIMEOUT_MS if timeout_ms <= 0 else timeout_ms
  187. future = core.invoke_rpc(to, serial_obj, timeout_ms)
  188. return future
  189. def _barrier_never_timeout(global_rank, global_world_size):
  190. # max timeout
  191. timeout = datetime.timedelta(days=_BARRIER_TIMEOUT_MAX_DAYS)
  192. if global_world_size < 2:
  193. return
  194. global _barrier_count
  195. barrier_prefix = "Barrier/" + str(_barrier_count) + "/"
  196. _barrier_count += 1
  197. is_master = global_rank == 0
  198. def _check_keys_ready(wait_keys):
  199. start_time = time.time()
  200. while len(wait_keys) > 0:
  201. time.sleep(0.1)
  202. elapse_time = time.time() - start_time
  203. if datetime.timedelta(seconds=elapse_time) > timeout:
  204. raise RuntimeError(
  205. f"Keys {wait_keys} are not ready sinck rank {global_rank} is waiting them."
  206. )
  207. wait_keys = list(
  208. filter(lambda key: int(_barrier_store.get(key)) != 1, wait_keys)
  209. )
  210. if is_master:
  211. # the master will add key, wait for all workers'exiting key and exit in the end.
  212. # Note: the master must exit in the end to ensure that the TcpServer is destroyed in the end.
  213. wait_keys = [
  214. barrier_prefix + str(rank) for rank in range(1, global_world_size)
  215. ]
  216. _barrier_store.add(barrier_prefix + str(0), 1)
  217. _check_keys_ready(wait_keys)
  218. else:
  219. wait_keys = [barrier_prefix + str(0)]
  220. _check_keys_ready(wait_keys)
  221. _barrier_store.add(barrier_prefix + str(global_rank), 1)
  222. def shutdown():
  223. """
  224. Perform a shutdown of the RPC agent, stop the worker and destroy the agent.
  225. This will block until all local and remote RPC processes reach this method
  226. and wait for all outstanding work to complete.
  227. Returns:
  228. None.
  229. Examples:
  230. .. code-block:: python
  231. >>> # doctest: +REQUIRES(env:DISTRIBUTED)
  232. >>> import paddle.distributed.rpc as rpc
  233. >>> rpc.init_rpc("worker0", rank=0, world_size=1,
  234. ... master_endpoint="127.0.0.1:8004")
  235. >>> rpc.shutdown()
  236. """
  237. info = get_current_worker_info()
  238. rank = info.rank
  239. world_size = len(get_all_worker_infos())
  240. # master will exit in the end
  241. _barrier_never_timeout(rank, world_size)
  242. core.rpc_stop_worker()
  243. _del_barrier_store()
  244. logger.info(f"Trainer {rank}: rpc shutdown!")
  245. def get_worker_info(name):
  246. """
  247. Get worker information by worker name.
  248. Args:
  249. name (str): name of the worker.
  250. Returns:
  251. class `WorkerInfo` with attribute `name`, `rank`, `ip` and `port`.
  252. Examples:
  253. .. code-block:: python
  254. >>> # doctest: +REQUIRES(env:DISTRIBUTED)
  255. >>> import paddle.distributed.rpc as rpc
  256. >>> import os
  257. >>> os.environ["PADDLE_WORKER_ENDPOINT"] = "127.0.0.1:9002"
  258. >>> rpc.init_rpc("worker0", rank=0, world_size=1,
  259. ... master_endpoint="127.0.0.1:8005")
  260. >>> print(rpc.get_worker_info("worker0"))
  261. {name: worker0, rank: 0, ip: 127.0.0.1, port: 9002}
  262. >>> rpc.shutdown()
  263. """
  264. return core.rpc_get_worker_info(name)
  265. def get_all_worker_infos():
  266. """
  267. Get all worker informations.
  268. Returns:
  269. List[WorkerInfo].
  270. Examples:
  271. .. code-block:: python
  272. >>> # doctest: +REQUIRES(env:DISTRIBUTED)
  273. >>> import paddle.distributed.rpc as rpc
  274. >>> import os
  275. >>> os.environ["PADDLE_WORKER_ENDPOINT"] = "127.0.0.1:9003"
  276. >>> rpc.init_rpc("worker0", rank=0, world_size=1,
  277. ... master_endpoint="127.0.0.1:8006")
  278. >>> print(rpc.get_all_worker_infos())
  279. [{name: worker0, rank: 0, ip: 127.0.0.1, port: 9003}]
  280. >>> rpc.shutdown()
  281. """
  282. return core.rpc_get_all_worker_infos()
  283. def get_current_worker_info():
  284. """
  285. Get current worker information.
  286. Returns:
  287. class `WorkerInfo` with attribute `name`, `rank`, `ip` and `port`.
  288. Examples:
  289. .. code-block:: python
  290. >>> # doctest: +REQUIRES(env:DISTRIBUTED)
  291. >>> import paddle.distributed.rpc as rpc
  292. >>> import os
  293. >>> os.environ["PADDLE_WORKER_ENDPOINT"] = "127.0.0.1:9004"
  294. >>> rpc.init_rpc("worker0", rank=0, world_size=1,
  295. ... master_endpoint="127.0.0.1:8007")
  296. >>> print(rpc.get_current_worker_info())
  297. {name: worker0, rank: 0, ip: 127.0.0.1, port: 9004}
  298. >>> rpc.shutdown()
  299. """
  300. return core.rpc_get_current_worker_info()