| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388 |
- # Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
- #
- # Licensed under the Apache License, Version 2.0 (the "License");
- # you may not use this file except in compliance with the License.
- # You may obtain a copy of the License at
- #
- # http://www.apache.org/licenses/LICENSE-2.0
- #
- # Unless required by applicable law or agreed to in writing, software
- # distributed under the License is distributed on an "AS IS" BASIS,
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- # See the License for the specific language governing permissions and
- # limitations under the License.
- import datetime
- import os
- import pickle
- import time
- from collections import namedtuple
- from paddle.base import core
- from paddle.distributed.launch.context import Node
- from paddle.distributed.rpc.internal import PythonFunc, _serialize
- from paddle.distributed.utils.launch_utils import logger
# Plain-Python record describing one RPC worker; this is the object that is
# pickled into the store by `_set_self_info`.
WorkerInfo = namedtuple("WorkerInfo", "name rank ip port")

# Default RPC timeout in seconds; values <= 0 mean "never time out".
_DEFAULT_RPC_TIMEOUT = -1
# Millisecond timeout substituted for a non-positive (infinite) timeout.
_MAX_RPC_TIMEOUT_MS = 2**31 - 1
# Number of days used as the effectively unbounded barrier timeout.
_BARRIER_TIMEOUT_MAX_DAYS = 99_999_999

# TCP store used by `_barrier_never_timeout`.
_barrier_store = None

# Counts how many times `_barrier_never_timeout` has been called, so that
# every barrier uses a unique key prefix.
_barrier_count = 0
def _set_barrier_store(store):
    """Install ``store`` as the module-level store used by the RPC helpers.

    Args:
        store: the store object to keep in the module global
            ``_barrier_store``.
    """
    global _barrier_store
    _barrier_store = store
def _del_barrier_store():
    """Drop the module-level reference to the barrier store.

    The global is reset to ``None`` (its initial value) rather than removed
    with ``del``: deleting the name would leave the module without a
    ``_barrier_store`` attribute, so a second call, or any later access to
    the global, would raise ``NameError``. Rebinding to ``None`` releases the
    reference to the store object just the same.
    """
    global _barrier_store
    _barrier_store = None
def _set_self_info(name, rank, ip, port):
    """Publish this worker's ``WorkerInfo`` in the barrier store.

    The record is pickled and written under the key ``str(rank)``.

    Args:
        name: worker name.
        rank: worker rank; its string form is the store key.
        ip: worker ip.
        port: worker port.
    """
    record = WorkerInfo(name, rank, ip, port)
    _barrier_store.set(str(rank), pickle.dumps(record))
def _exchange_all_service_infos(world_size):
    """Collect the info record of every worker from the barrier store.

    For each rank in ``range(world_size)`` the value stored under
    ``str(rank)`` is read and unpickled.

    Args:
        world_size (int): number of workers whose records are read.

    Returns:
        list: the unpickled records, ordered by rank.

    Raises:
        AssertionError: if two workers registered the same name.
    """
    all_infos = []
    seen_names = set()
    for rank in range(world_size):
        # NOTE(security): pickle.loads can execute arbitrary code, so the
        # store contents must come from trusted peers only (the public API
        # docs require a secure network environment).
        info = pickle.loads(_barrier_store.get(str(rank)))
        # Raised explicitly (instead of a bare `assert`) so the check is not
        # stripped under `python -O`; the exception type is unchanged.
        if info.name in seen_names:
            raise AssertionError(
                f"The Worker name must be unique, but name `{info.name}` is repeated."
            )
        seen_names.add(info.name)
        all_infos.append(info)
    return all_infos
def _gen_endpoint():
    """Build an ``"ip:port"`` endpoint string for the local node.

    The ip comes from ``Node.get_host_ip()`` and the port from
    ``Node.get_free_port()``.

    Returns:
        str: the endpoint formatted as ``"<ip>:<port>"``.
    """
    local_node = Node()
    return f"{local_node.get_host_ip()}:{local_node.get_free_port()}"
def init_rpc(name, rank=None, world_size=None, master_endpoint=None):
    """
    Initialize RPC on the current worker.

    Arguments left as ``None`` are read from the environment:
    ``PADDLE_TRAINER_ID`` for ``rank``, ``PADDLE_TRAINERS_NUM`` for
    ``world_size`` and ``PADDLE_MASTER_ENDPOINT`` for ``master_endpoint``.
    The worker's own endpoint is taken from ``PADDLE_WORKER_ENDPOINT`` or
    generated when that variable is unset.

    Args:
        name (str): worker name.
        rank (int, optional): worker id, default is None.
        world_size (int, optional): number of workers, default is None.
        master_endpoint (str, optional): IP address of master, other nodes communicate with the master to
            get the information of all worker nodes, default is None.

    Returns:
        None.

    Examples:
        .. code-block:: python

            >>> # doctest: +REQUIRES(env:DISTRIBUTED)
            >>> import paddle.distributed.rpc as rpc

            >>> rpc.init_rpc("worker0", rank=0, world_size=1,
            ...             master_endpoint="127.0.0.1:8001")

            >>> rpc.shutdown()

    """
    # Explicit arguments win; the environment is only a fallback.
    if rank is None:
        rank = int(os.environ["PADDLE_TRAINER_ID"])
    if world_size is None:
        world_size = int(os.environ["PADDLE_TRAINERS_NUM"])

    worker_endpoint = os.getenv("PADDLE_WORKER_ENDPOINT", None)
    if worker_endpoint is None:
        worker_endpoint = _gen_endpoint()
    logger.info(f"Trainer {rank}: worker endpoint: {worker_endpoint}")

    if master_endpoint is None:
        master_endpoint = os.environ["PADDLE_MASTER_ENDPOINT"]
    master_addr, master_port_text = master_endpoint.split(":")
    master_port = int(master_port_text)

    # Rank 0 is passed as the "is master" flag of the store.
    stop_check_timeout = int(os.getenv("FLAGS_stop_check_timeout", "900"))
    _set_barrier_store(
        core.TCPStore(
            master_addr,
            master_port,
            rank == 0,
            world_size,
            timeout=stop_check_timeout,
        )
    )

    # Publish our own record, then read back the records of every worker.
    ip, port_text = worker_endpoint.split(":")
    _set_self_info(name, rank, ip, int(port_text))
    c_infos = [
        core.WorkerInfo(peer.name, peer.rank, peer.ip, peer.port)
        for peer in _exchange_all_service_infos(world_size)
    ]

    core.init_and_set_agent_instance(name, c_infos)
    core.rpc_start_worker()
    # ensure that all the workers are started
    _barrier_never_timeout(rank, world_size)
    core.rpc_start_client()
    logger.info(f"Trainer {rank}: Init RPC done!")
def rpc_sync(to, fn, args=None, kwargs=None, timeout=_DEFAULT_RPC_TIMEOUT):
    """
    Run function ``fn`` on worker ``to`` through a blocking RPC call and
    return its result. Attention: Users must use this API in a secure network environment.

    Args:
        to (str): name of the destination worker.
        fn (fn): a callable function, such as Python callables.
        args (tuple, optional): the argument tuple for the ``fn`` invocation, default is None.
        kwargs (dict, optional): is a dictionary of keyword arguments for the ``fn``
            invocation, default is None.
        timeout (int, optional): timeout in seconds to use for this RPC. If
            the RPC does not complete in this amount of
            time, an exception indicating it has
            timed out will be raised. A value less than or equal to 0
            indicates an infinite timeout, i.e. a timeout
            error will never be raised. The default value is -1.

    Returns:
        Returns the result of running ``fn`` with ``args`` and ``kwargs``.

    Examples:
        .. code-block:: python

            >>> # doctest: +REQUIRES(env:DISTRIBUTED)
            >>> import paddle.distributed.rpc as rpc

            >>> def add(a, b):
            ...     return a + b

            >>> rpc.init_rpc("worker0", rank=0, world_size=1,
            ...             master_endpoint="127.0.0.1:8002")

            >>> ret = rpc.rpc_sync("worker0", add, args=(2, 3))
            >>> rpc.shutdown()

    """
    # Issue the call, then block on the returned future.
    return _invoke_rpc(to, fn, args, kwargs, timeout).wait()
def rpc_async(to, fn, args=None, kwargs=None, timeout=_DEFAULT_RPC_TIMEOUT):
    """
    Run function ``fn`` on worker ``to`` through a non-blocking RPC call and
    return a future for its result. Attention: Users must use this API in a secure network environment.

    Args:
        to (str): name of the destination worker.
        fn (fn): a callable function, such as Python callables.
        args (tuple, optional): the argument tuple for the ``fn`` invocation, default is None.
        kwargs (dict, optional): is a dictionary of keyword arguments for the ``fn``
            invocation, default is None.
        timeout (int, optional): timeout in seconds to use for this RPC. If
            the RPC does not complete in this amount of
            time, an exception indicating it has
            timed out will be raised. A value less than or equal to 0
            indicates an infinite timeout, i.e. a timeout
            error will never be raised. The default value is -1.

    Returns:
        Returns a :class:`FutureWrapper` object that can be waited
        on. When completed, the return value of ``fn`` on ``args`` and
        ``kwargs`` can be got by `fut.wait()`.

    Examples:
        .. code-block:: python

            >>> # doctest: +REQUIRES(env:DISTRIBUTED)
            >>> import paddle.distributed.rpc as rpc

            >>> def add(a, b):
            ...     return a + b

            >>> rpc.init_rpc("worker0", rank=0, world_size=1,
            ...             master_endpoint="127.0.0.1:8003")

            >>> fut = rpc.rpc_async("worker0", add, args=(2, 3))
            >>> print(fut.wait())
            5

            >>> rpc.shutdown()

    """
    # The caller decides when (or whether) to wait on the future.
    pending = _invoke_rpc(to, fn, args, kwargs, timeout)
    return pending
def _invoke_rpc(to, fn, args, kwargs, timeout):
    """Serialize a call to ``fn`` and send it to worker ``to``.

    Args:
        to: name of the destination worker, passed through to
            ``core.invoke_rpc``.
        fn: the callable to run remotely.
        args: positional arguments for ``fn``; a falsy value means ``()``.
        kwargs: keyword arguments for ``fn``; a falsy value means ``{}``.
        timeout: timeout in seconds. A value <= 0 selects
            ``_MAX_RPC_TIMEOUT_MS``.

    Returns:
        The future object returned by ``core.invoke_rpc``.
    """
    args = args if args else ()
    kwargs = kwargs if kwargs else {}
    serial_obj = _serialize(PythonFunc(fn, args, kwargs))
    if timeout <= 0:
        timeout_ms = _MAX_RPC_TIMEOUT_MS
    else:
        # Always hand over an integer millisecond count. A positive timeout
        # is kept at >= 1 ms so truncation never turns it into 0, and it is
        # capped at _MAX_RPC_TIMEOUT_MS, the value already used for
        # "infinite", so very large timeouts cannot exceed that bound.
        timeout_ms = min(max(int(timeout * 1000), 1), _MAX_RPC_TIMEOUT_MS)
    return core.invoke_rpc(to, serial_obj, timeout_ms)
def _barrier_never_timeout(global_rank, global_world_size):
    """Synchronize all workers through the barrier store.

    Does nothing when ``global_world_size`` is less than 2. Otherwise each
    call uses a fresh key prefix ``"Barrier/<count>/"``:

    * rank 0 adds its own key, then waits until the keys of all other ranks
      read back as 1;
    * every other rank waits until rank 0's key reads back as 1, then adds
      its own key.

    The wait is bounded by ``_BARRIER_TIMEOUT_MAX_DAYS`` days.

    Args:
        global_rank (int): rank of the calling worker.
        global_world_size (int): total number of workers.

    Raises:
        RuntimeError: if the awaited keys are not ready within the timeout.
    """
    global _barrier_count

    if global_world_size < 2:
        return

    # max timeout, converted once instead of on every polling iteration
    timeout_seconds = datetime.timedelta(
        days=_BARRIER_TIMEOUT_MAX_DAYS
    ).total_seconds()

    barrier_prefix = "Barrier/" + str(_barrier_count) + "/"
    _barrier_count += 1
    is_master = global_rank == 0

    def _check_keys_ready(wait_keys):
        # A monotonic clock keeps the elapsed time immune to wall-clock jumps.
        start_time = time.monotonic()
        while wait_keys:
            time.sleep(0.1)
            if time.monotonic() - start_time > timeout_seconds:
                raise RuntimeError(
                    f"Keys {wait_keys} are not ready since rank {global_rank} is waiting them."
                )
            # Keep only the keys whose value is not 1 yet.
            wait_keys = [
                key for key in wait_keys if int(_barrier_store.get(key)) != 1
            ]

    if is_master:
        # the master will add key, wait for all workers' exiting key and exit in the end.
        # Note: the master must exit in the end to ensure that the TcpServer is destroyed in the end.
        wait_keys = [
            barrier_prefix + str(rank) for rank in range(1, global_world_size)
        ]
        _barrier_store.add(barrier_prefix + str(0), 1)
        _check_keys_ready(wait_keys)
    else:
        wait_keys = [barrier_prefix + str(0)]
        _check_keys_ready(wait_keys)
        _barrier_store.add(barrier_prefix + str(global_rank), 1)
def shutdown():
    """
    Shut down the RPC agent: stop the worker and destroy the agent.
    This will block until all local and remote RPC processes reach this method
    and wait for all outstanding work to complete.

    Returns:
        None.

    Examples:
        .. code-block:: python

            >>> # doctest: +REQUIRES(env:DISTRIBUTED)
            >>> import paddle.distributed.rpc as rpc

            >>> rpc.init_rpc("worker0", rank=0, world_size=1,
            ...             master_endpoint="127.0.0.1:8004")

            >>> rpc.shutdown()

    """
    rank = get_current_worker_info().rank
    worker_count = len(get_all_worker_infos())
    # master will exit in the end
    _barrier_never_timeout(rank, worker_count)
    core.rpc_stop_worker()
    _del_barrier_store()
    logger.info(f"Trainer {rank}: rpc shutdown!")
def get_worker_info(name):
    """
    Look up the information of the worker called ``name``.

    Args:
        name (str): name of the worker.

    Returns:
        class `WorkerInfo` with attribute `name`, `rank`, `ip` and `port`.

    Examples:
        .. code-block:: python

            >>> # doctest: +REQUIRES(env:DISTRIBUTED)
            >>> import paddle.distributed.rpc as rpc
            >>> import os

            >>> os.environ["PADDLE_WORKER_ENDPOINT"] = "127.0.0.1:9002"
            >>> rpc.init_rpc("worker0", rank=0, world_size=1,
            ...             master_endpoint="127.0.0.1:8005")

            >>> print(rpc.get_worker_info("worker0"))
            {name: worker0, rank: 0, ip: 127.0.0.1, port: 9002}

            >>> rpc.shutdown()

    """
    worker_info = core.rpc_get_worker_info(name)
    return worker_info
def get_all_worker_infos():
    """
    Return the information of every worker.

    Returns:
        List[WorkerInfo].

    Examples:
        .. code-block:: python

            >>> # doctest: +REQUIRES(env:DISTRIBUTED)
            >>> import paddle.distributed.rpc as rpc
            >>> import os

            >>> os.environ["PADDLE_WORKER_ENDPOINT"] = "127.0.0.1:9003"
            >>> rpc.init_rpc("worker0", rank=0, world_size=1,
            ...             master_endpoint="127.0.0.1:8006")

            >>> print(rpc.get_all_worker_infos())
            [{name: worker0, rank: 0, ip: 127.0.0.1, port: 9003}]

            >>> rpc.shutdown()

    """
    worker_infos = core.rpc_get_all_worker_infos()
    return worker_infos
def get_current_worker_info():
    """
    Return the information of the worker this process runs as.

    Returns:
        class `WorkerInfo` with attribute `name`, `rank`, `ip` and `port`.

    Examples:
        .. code-block:: python

            >>> # doctest: +REQUIRES(env:DISTRIBUTED)
            >>> import paddle.distributed.rpc as rpc
            >>> import os

            >>> os.environ["PADDLE_WORKER_ENDPOINT"] = "127.0.0.1:9004"
            >>> rpc.init_rpc("worker0", rank=0, world_size=1,
            ...             master_endpoint="127.0.0.1:8007")

            >>> print(rpc.get_current_worker_info())
            {name: worker0, rank: 0, ip: 127.0.0.1, port: 9004}

            >>> rpc.shutdown()

    """
    current_info = core.rpc_get_current_worker_info()
    return current_info
|