communicator.py 8.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267
  1. # Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
  15. #
# Licensed under the Apache License, Version 2.0 (the "License");
  17. # you may not use this file except in compliance with the License.
  18. # You may obtain a copy of the License at
  19. #
# http://www.apache.org/licenses/LICENSE-2.0
  21. #
  22. # Unless required by applicable law or agreed to in writing, software
  23. # distributed under the License is distributed on an "AS IS" BASIS,
  24. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  25. # See the License for the specific language governing permissions and
  26. # limitations under the License.
"""
Communicator is used for asynchronous distributed training in distribute_transpiler mode.
It is a wrapper around the C++ Communicator class and should be used inside the fleet API.
"""
  31. import paddle
  32. from paddle.distributed.ps.utils.public import DistributedMode
  33. from paddle.framework import core
  34. __all__ = []
  35. class Communicator:
  36. def __init__(self, mode, kwargs=None, envs=None):
  37. """
  38. Communicator is used for async distribute training in distribute_transpiler mode.
  39. It's a wrapper of a cpp class Communicator and should be used inside fleet API.
  40. Args:
  41. program(Program): the trainers program after transpile of distribute_transpiler.
  42. It's used by communicator to extract the information to do communication.
  43. Returns:
  44. None
  45. Examples:
  46. .. code-block:: python
  47. >>> import paddle
  48. >>> prog = paddle.static.Program()
  49. >>> comm = paddle.distributed.communicator.Communicator(prog)
  50. >>> comm.start()
  51. >>> comm.stop()
  52. """
  53. # set all recv op to not_run mode
  54. if kwargs is None:
  55. if envs is None:
  56. envs = {}
  57. else:
  58. if mode == DistributedMode.SYNC:
  59. envs["pserver_endpoints"] = ','.join(
  60. kwargs["pserver_endpoints"]
  61. )
  62. envs["trainers"] = str(kwargs["trainers"])
  63. envs["trainer_id"] = str(kwargs["trainer_id"])
  64. envs["need_global_step"] = str(kwargs["need_global_step"])
  65. envs["barrier_table_id"] = str(kwargs["barrier_table_id"])
  66. mode_str = None
  67. if mode == DistributedMode.SYNC:
  68. mode_str = "SYNC"
  69. elif mode == DistributedMode.ASYNC:
  70. mode_str = "ASYNC"
  71. elif mode == DistributedMode.HALF_ASYNC:
  72. mode_str = "HALF_ASYNC"
  73. elif mode == DistributedMode.GEO:
  74. mode_str = "GEO"
  75. self.mode = mode_str
  76. self.envs = envs
  77. self.communicator_ = None
  78. self.send_ctx_ = None
  79. self.recv_ctx_ = None
  80. def init_with_ctx(
  81. self, send_ctx, recv_ctx, proto_txt, unit64_hosts, scope=None
  82. ):
  83. if scope is None:
  84. scope = paddle.static.global_scope()
  85. self.communicator_ = core.DistCommunicator(
  86. self.mode,
  87. proto_txt,
  88. unit64_hosts,
  89. send_ctx,
  90. recv_ctx,
  91. scope,
  92. self.envs,
  93. )
  94. self.send_ctx_ = send_ctx
  95. self.recv_ctx_ = recv_ctx
  96. def create_client_to_client_connection(
  97. self,
  98. pserver_timeout_ms=500000,
  99. pserver_connect_timeout_ms=10000,
  100. max_retry=3,
  101. ):
  102. self.communicator_.create_client_to_client_connection(
  103. pserver_timeout_ms, pserver_connect_timeout_ms, max_retry
  104. )
  105. def get_client_info(self):
  106. return self.communicator_.get_client_info()
  107. def set_clients(self, host_list):
  108. self.communicator_.set_clients(host_list)
  109. def start(self):
  110. """
  111. Start communicator. Should call before training process.
  112. Returns:
  113. None
  114. Examples:
  115. .. code-block:: python
  116. >>> import paddle
  117. >>> prog = paddle.static.Program()
  118. >>> comm = paddle.distributed.communicator.Communicator(prog)
  119. >>> comm.start()
  120. >>> comm.stop()
  121. """
  122. if self.communicator_ is None:
  123. print('you must call init_with_ctx first to init comm before start')
  124. return
  125. self.communicator_.start()
  126. def stop(self):
  127. """
  128. Stop communicator. Should call after training process.
  129. Returns:
  130. None
  131. Examples:
  132. .. code-block:: python
  133. >>> import paddle
  134. >>> prog = paddle.static.Program()
  135. >>> comm = paddle.distributed.communicator.Communicator(prog)
  136. >>> comm.start()
  137. >>> comm.stop()
  138. """
  139. if self.communicator_ is None:
  140. print('you must call init_with_ctx first to init comm before stop')
  141. return
  142. self.communicator_.stop()
  143. def is_running(self):
  144. """
  145. Get communicator is running or stop.
  146. Returns:
  147. bool
  148. Examples:
  149. .. code-block:: python
  150. >>> import paddle
  151. >>> prog = paddle.static.Program()
  152. >>> comm = paddle.distributed.communicator.Communicator(prog)
  153. >>> comm.is_running()
  154. """
  155. if self.communicator_ is None:
  156. print('you must call init_with_ctx first to init comm before stop')
  157. return
  158. self.communicator_.is_running()
  159. def recv(self):
  160. self.communicator_.recv()
  161. def init_params(self, context):
  162. self.communicator_.init_params(context)
  163. def pull_dense(self, context):
  164. self.communicator_.pull_dense(context)
  165. def push_sparse_param(self, var_name, table_id=-1, scope=None):
  166. if scope is None:
  167. scope = paddle.static.global_scope()
  168. if not self.is_running():
  169. raise ValueError(
  170. "Communicator should init first. Using fleet.init_worker() before push_sparse_param()"
  171. )
  172. assert isinstance(var_name, str)
  173. assert isinstance(table_id, int)
  174. if table_id == -1:
  175. table_id = self.send_ctx_[var_name].table_id()
  176. self.communicator_.push_sparse_param(var_name, table_id, scope)
  177. class FLCommunicator(Communicator): # only for coordinator
  178. def __init__(self, ps_hosts, kwargs=None):
  179. mode = None
  180. super().__init__(mode, kwargs)
  181. send_ctx = {}
  182. dense_map = {}
  183. prototxt = ""
  184. self.mode = "WITH_COORDINATOR"
  185. self.init_with_ctx(send_ctx, dense_map, prototxt, ps_hosts)
  186. def start_coordinator(self, self_endpoint, trainer_endpoints):
  187. if self.communicator_ is not None:
  188. self.communicator_.start_coordinator(
  189. self_endpoint, trainer_endpoints
  190. )
  191. def save_fl_strategy(self, mp):
  192. if self.communicator_ is not None:
  193. self.communicator_.save_fl_strategy(mp)
  194. else:
  195. raise ValueError("self.communicator_ is null")
  196. def query_fl_clients_info(self):
  197. info_mp = {}
  198. if self.communicator_ is not None:
  199. info_mp = self.communicator_.query_fl_clients_info()
  200. return info_mp
  201. class LargeScaleKV:
  202. def __init__(self):
  203. self.scale_kv = core.LargeScaleKV()
  204. def save(self, varname, dirname):
  205. self.scale_kv.save(varname, dirname)
  206. def load(self, varname, dirname):
  207. self.scale_kv.load(varname, dirname)
  208. def size(self, varname):
  209. return self.scale_kv.size(varname)
  210. class HeterClient:
  211. def __init__(self, endpoint, previous_endpoint, trainer_id):
  212. self.heter_client_ = core.HeterClient(
  213. endpoint, previous_endpoint, trainer_id
  214. )
  215. def stop(self):
  216. self.heter_client_.stop()