trainer_desc.py 15 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454
  1. # Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. """Definition of trainers."""
  15. import os
  16. import sys
  17. __all__ = []
  18. class TrainerDesc:
  19. '''
  20. Set proto from python to c++.
  21. Can be initialized from train_desc.
  22. '''
  23. def __init__(self):
  24. '''
  25. self.proto_desc = data_feed_pb2.DataFeedDesc()
  26. with open(proto_file, 'r') as f:
  27. text_format.Parse(f.read(), self.proto_desc)
  28. '''
  29. # Workaround for relative import in protobuf under python3
  30. # TODO: should be fixed
  31. cur_path = os.path.dirname(__file__)
  32. if cur_path not in sys.path:
  33. sys.path.append(cur_path)
  34. if cur_path + "/proto" not in sys.path:
  35. sys.path.append(cur_path + "/proto")
  36. from proto import trainer_desc_pb2
  37. self.proto_desc = trainer_desc_pb2.TrainerDesc()
  38. import multiprocessing as mp
  39. # set default thread num == cpu count
  40. self.proto_desc.thread_num = mp.cpu_count()
  41. self._fleet_desc = None
  42. self._device_worker = None
  43. self._program = None
  44. self._infer = False
  45. def _set_heter_info(self, ret):
  46. # ret = = fu.split_program_by_device(program)
  47. # start_list, end_list, send_list, recv_list, program_list = fu.split_program_by_device(program)
  48. # if len(start_list) != 3:
  49. # print("start_list len=", len(start_list), " will not set heter info")
  50. # return
  51. # for i in start_list[0]:
  52. # self.proto_desc.op_run_start_idx.append(i)
  53. # for i in end_list[0]:
  54. # self.proto_desc.op_run_end_idx.append(i)
  55. # for i in send_list[0]:
  56. # self.proto_desc.op_run_send_list.append(i)
  57. # for i in recv_list[0]:
  58. # self.proto_desc.op_run_recv_list.append(i)
  59. if ret is None:
  60. return
  61. # for i in ret[0]: # start_list[1]:
  62. # self.proto_desc.xpu_start_idx.append(i)
  63. self.proto_desc.xpu_start_idx = ret[0]
  64. # for i in ret[1]: #end_list[1]:
  65. # self.proto_desc.o_end_idx.append(i)
  66. self.proto_desc.xpu_end_idx = ret[1]
  67. for i in ret[2]: # send_list[1]:
  68. self.proto_desc.xpu_send_list.append(i)
  69. for i in ret[3]: # recv_list[1]:
  70. self.proto_desc.xpu_recv_list.append(i)
  71. # for i in start_list[2]:
  72. # self.proto_desc.op_run_end_start_idx.append(i)
  73. # for i in end_list[2]:
  74. # self.proto_desc.op_run_end_idx.append(i)
  75. # for i in send_list[2]:
  76. # self.proto_desc.op_run_end_send_list.append(i)
  77. # for i in recv_list[2]:
  78. # self.proto_desc.op_run_end_recv_list.append(i)
  79. def _set_fetch_var_and_info(self, fetch_vars, fetch_info, print_period):
  80. # convert fetch_info to list
  81. fetch_info = list(fetch_info)
  82. for i, v in enumerate(fetch_vars):
  83. self.proto_desc.fetch_config.fetch_var_names.extend([v.name])
  84. self.proto_desc.fetch_config.fetch_var_str_format.extend(
  85. [fetch_info[i]]
  86. )
  87. self.proto_desc.fetch_config.print_period = print_period
  88. def _set_debug(self, debug):
  89. self.proto_desc.debug = debug
  90. def _set_thread(self, thread_num):
  91. self.proto_desc.thread_num = thread_num
  92. def _set_device_worker(self, device_worker):
  93. self._device_worker = device_worker
  94. def _set_infer(self, infer):
  95. self._infer = infer
  96. def _set_fleet_desc(self, fleet_desc):
  97. self._fleet_desc = fleet_desc
  98. # serialize fleet_desc
  99. from google.protobuf import text_format
  100. fleet_desc_str = text_format.MessageToString(fleet_desc)
  101. self.proto_desc.fleet_desc = fleet_desc_str
  102. def _gen_trainer_desc(self):
  103. pass
  104. def _set_program(self, program):
  105. self._program = program
  106. def _set_trainer_id(self, trainer_id):
  107. self.proto_desc.trainer_id = trainer_id
  108. def _set_trainers(self, trainers):
  109. for trainer_num in trainers:
  110. self.proto_desc.trainers.append(trainer_num)
  111. def _set_use_cvm(self, use_cvm=False):
  112. self.proto_desc.use_cvm = use_cvm
  113. def _set_no_cvm(self, no_cvm=False):
  114. self.proto_desc.no_cvm = no_cvm
  115. def _set_scale_sparse_grad_with_batch_size(
  116. self, scale_sparse_gradient_with_batch_size=True
  117. ):
  118. self.proto_desc.scale_sparse_gradient_with_batch_size = (
  119. scale_sparse_gradient_with_batch_size
  120. )
  121. def _set_scale_datanorm(self, scale_datanorm=-1):
  122. self.proto_desc.scale_datanorm = scale_datanorm
  123. def _set_dump_slot(self, dump_slot):
  124. self.proto_desc.dump_slot = dump_slot
  125. def _set_mpi_rank(self, mpi_rank):
  126. self.proto_desc.mpi_rank = mpi_rank
  127. def _set_mpi_size(self, mpi_size):
  128. self.proto_desc.mpi_size = mpi_size
  129. def _set_dump_fields(self, dump_fields):
  130. for field in dump_fields:
  131. self.proto_desc.dump_fields.append(field)
  132. def _set_is_dump_in_simple_mode(self, is_dump_in_simple_mode):
  133. self.proto_desc.is_dump_in_simple_mode = is_dump_in_simple_mode
  134. def _set_dump_num_decimals(self, dump_num_decimals):
  135. self.proto_desc.dump_num_decimals = dump_num_decimals
  136. def _set_dump_fields_path(self, path):
  137. self.proto_desc.dump_fields_path = path
  138. def _set_dump_file_num(self, dump_file_num):
  139. self.proto_desc.dump_file_num = dump_file_num
  140. def _set_user_define_dump_filename(self, user_define_dump_filename):
  141. self.proto_desc.user_define_dump_filename = user_define_dump_filename
  142. def _set_dump_converter(self, converter):
  143. self.proto_desc.dump_converter = converter
  144. def _set_enable_random_dump(self, enable_random_dump):
  145. self.proto_desc.enable_random_dump = enable_random_dump
  146. def _set_dump_interval(self, dump_interval):
  147. self.proto_desc.dump_interval = dump_interval
  148. def _set_random_with_lineid(self, random_with_lineid):
  149. self.proto_desc.random_with_lineid = random_with_lineid
  150. def _set_dump_param(self, dump_param):
  151. for param in dump_param:
  152. self.proto_desc.dump_param.append(param)
  153. def _set_dump_fields_mode(self, mode):
  154. self.proto_desc.dump_fields_mode = mode
  155. def _set_worker_places(self, worker_places):
  156. for place in worker_places:
  157. self.proto_desc.worker_places.append(place)
  158. def _set_use_ps_gpu(self, use_ps_gpu=False):
  159. self.proto_desc.use_ps_gpu = use_ps_gpu
  160. def _set_thread_barrier(self, thread_barrier):
  161. self.proto_desc.thread_barrier = thread_barrier
  162. def _set_check_nan_var_names(self, check_nan_var_names):
  163. for var in check_nan_var_names:
  164. self.proto_desc.check_nan_var_names.append(var)
  165. def _set_loss_names(self, loss_names):
  166. for loss in loss_names:
  167. self.proto_desc.loss_names.append(loss)
  168. def _set_adjust_ins_weight(self, config_dict):
  169. self.proto_desc.adjust_ins_weight_config.need_adjust = config_dict.get(
  170. "need_adjust", False
  171. )
  172. self.proto_desc.adjust_ins_weight_config.nid_slot = config_dict.get(
  173. "nid_slot", ""
  174. )
  175. self.proto_desc.adjust_ins_weight_config.nid_adjw_threshold = (
  176. config_dict.get("nid_adjw_threshold", 0.0)
  177. )
  178. self.proto_desc.adjust_ins_weight_config.nid_adjw_ratio = (
  179. config_dict.get("nid_adjw_ratio", 0.0)
  180. )
  181. self.proto_desc.adjust_ins_weight_config.ins_weight_slot = (
  182. config_dict.get("ins_weight_slot", "")
  183. )
  184. def _set_copy_table_config(self, config_dict):
  185. config = self.proto_desc.copy_table_config
  186. config.need_copy = config_dict.get("need_copy", False)
  187. config.batch_num = config_dict.get("batch_num", 100)
  188. src_sparse_tables = config_dict.get("src_sparse_tables", [])
  189. if not isinstance(src_sparse_tables, list):
  190. src_sparse_tables = [src_sparse_tables]
  191. dest_sparse_tables = config_dict.get("dest_sparse_tables", [])
  192. if not isinstance(dest_sparse_tables, list):
  193. dest_sparse_tables = [dest_sparse_tables]
  194. if len(src_sparse_tables) != len(dest_sparse_tables):
  195. raise ValueError(
  196. "len(src_sparse_tables) != len(dest_sparse_tables),"
  197. f" {len(src_sparse_tables)} vs {len(dest_sparse_tables)}"
  198. )
  199. for i in src_sparse_tables:
  200. config.src_sparse_tables.append(i)
  201. for i in dest_sparse_tables:
  202. config.dest_sparse_tables.append(i)
  203. src_dense_tables = config_dict.get("src_dense_tables", [])
  204. if not isinstance(src_dense_tables, list):
  205. src_dense_tables = [src_dense_tables]
  206. dest_dense_tables = config_dict.get("dest_dense_tables", [])
  207. if not isinstance(dest_dense_tables, list):
  208. dest_dense_tables = [dest_dense_tables]
  209. if len(src_dense_tables) != len(dest_dense_tables):
  210. raise ValueError(
  211. "len(src_dense_tables) != len(dest_dense_tables),"
  212. f" {len(src_dense_tables)} vs {len(dest_dense_tables)}"
  213. )
  214. for i in src_dense_tables:
  215. config.src_dense_tables.append(i)
  216. for i in dest_dense_tables:
  217. config.dest_dense_tables.append(i)
  218. # user can also specify dense variables to copy,
  219. # instead of copy dense table
  220. src_var_list = config_dict.get("src_var_list", [])
  221. if not isinstance(src_var_list, list):
  222. src_var_list = [src_var_list]
  223. dest_var_list = config_dict.get("dest_var_list", [])
  224. if not isinstance(dest_var_list, list):
  225. dest_var_list = [dest_var_list]
  226. if len(src_var_list) != len(dest_var_list):
  227. raise ValueError(
  228. f"len(src_var_list) != len(dest_var_list), {len(src_var_list)} vs"
  229. f" {len(dest_var_list)}"
  230. )
  231. for i in src_var_list:
  232. config.src_var_list.append(i)
  233. for i in dest_var_list:
  234. config.dest_var_list.append(i)
  235. dependency_map = config_dict.get("dependency_map", {})
  236. for key in dependency_map:
  237. m = config.table_dependency_map.add()
  238. m.key = key
  239. values = dependency_map[key]
  240. if not isinstance(values, list):
  241. values = [values]
  242. if len(values) != 1:
  243. raise ValueError("dependency len %s != 1" % len(values))
  244. for value in values:
  245. m.values.append(value)
  246. config.dense_pull_after_copy = config_dict.get(
  247. "dense_pull_after_copy", True
  248. )
  249. config.enable_dependency = config_dict.get("enable_dependency", False)
  250. config.sparse_copy_by_feasign = config_dict.get(
  251. "sparse_copy_by_feasign", True
  252. )
  253. def _desc(self):
  254. return self.proto_desc.SerializeToString()
  255. def __str__(self):
  256. from google.protobuf import text_format
  257. return text_format.MessageToString(self.proto_desc)
  258. class MultiTrainer(TrainerDesc):
  259. '''
  260. Implement of MultiTrainer.
  261. Can be init from TrainerDesc.
  262. '''
  263. def __init__(self):
  264. super().__init__()
  265. pass
  266. def _set_program(self, program):
  267. super()._set_program(program)
  268. self._program = program
  269. def _gen_trainer_desc(self):
  270. super()._gen_trainer_desc()
  271. self.proto_desc.class_name = "MultiTrainer"
  272. self._device_worker._set_infer(self._infer)
  273. self._device_worker._set_program(self._program)
  274. self._device_worker._gen_worker_desc(self.proto_desc)
  275. class DistMultiTrainer(TrainerDesc):
  276. """
  277. Implement of DistMultiTrainer.
  278. It's for Distributed training.
  279. """
  280. def __init__(self):
  281. super().__init__()
  282. pass
  283. def _set_program(self, program):
  284. super()._set_program(program)
  285. self._program = program
  286. def _gen_trainer_desc(self):
  287. super()._gen_trainer_desc()
  288. self.proto_desc.class_name = "DistMultiTrainer"
  289. if self._program is None:
  290. raise RuntimeError("None Program")
  291. self._device_worker._set_infer(self._infer)
  292. self._device_worker._set_program(self._program)
  293. self._device_worker._gen_worker_desc(self.proto_desc)
  294. class HeterXpuTrainer(TrainerDesc):
  295. """
  296. Implement of HeterXpuTrainer.
  297. It's for Distributed training.
  298. """
  299. def __init__(self):
  300. super().__init__()
  301. pass
  302. def _set_program(self, program):
  303. super()._set_program(program)
  304. self._program = program
  305. def _gen_trainer_desc(self):
  306. super()._gen_trainer_desc()
  307. self.proto_desc.class_name = "HeterXpuTrainer"
  308. if self._program is None:
  309. raise RuntimeError("None Program")
  310. self._device_worker._set_infer(self._infer)
  311. self._device_worker._set_program(self._program)
  312. self._device_worker._gen_worker_desc(self.proto_desc)
  313. class PSGPUTrainer(TrainerDesc):
  314. """
  315. Implement of PSGPUTrainer.
  316. It's for Distributed training.
  317. """
  318. def __init__(self):
  319. super().__init__()
  320. pass
  321. def _set_program(self, program):
  322. super()._set_program(program)
  323. self._program = program
  324. def _gen_trainer_desc(self):
  325. super()._gen_trainer_desc()
  326. self.proto_desc.class_name = "PSGPUTrainer"
  327. if self._program is None:
  328. raise RuntimeError("None Program")
  329. self._device_worker._set_infer(self._infer)
  330. self._device_worker._set_program(self._program)
  331. self._device_worker._gen_worker_desc(self.proto_desc)
  332. class HeterPipelineTrainer(TrainerDesc):
  333. """
  334. Implement of HeterPipelineTrainer.
  335. It's for HeterPS Pipeline training.
  336. """
  337. def __init__(self):
  338. super().__init__()
  339. pass
  340. def _set_program(self, program):
  341. super()._set_program(program)
  342. self._program = program
  343. def _gen_trainer_desc(self):
  344. super()._gen_trainer_desc()
  345. self.proto_desc.class_name = "HeterPipelineTrainer"
  346. if self._program is None:
  347. raise RuntimeError("None Program")
  348. self._device_worker._set_infer(self._infer)
  349. self._device_worker._set_program(self._program)
  350. self._device_worker._gen_worker_desc(self.proto_desc)
  351. class PipelineTrainer(TrainerDesc):
  352. """
  353. Implement of PipelineTrainer.
  354. It's for Pipeline.
  355. """
  356. def __init__(self):
  357. super().__init__()
  358. pass
  359. def _set_program(self, program):
  360. super()._set_program(program)
  361. self._program = program
  362. def _gen_trainer_desc(self):
  363. super()._gen_trainer_desc()
  364. self.proto_desc.class_name = "PipelineTrainer"
  365. if self._program is None:
  366. raise RuntimeError("None Program")
  367. self._device_worker._set_infer(self._infer)
  368. self._device_worker._set_program(self._program)
  369. self._device_worker._gen_worker_desc(self.proto_desc)