launch.py 28 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772
  1. # Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. r"""
  15. fleetrun is a module that spawns multiple distributed
  16. processes on each training node for gpu training and cpu training.
  17. Usage:
  18. For both single node training and multiple node training, this module
  19. launches a process on each of the given gpu cards or cpu machines.
  20. GPU training:
  21. 1. for single node training with all visible gpu cards:
  22. fleetrun your_training_py (arg1 arg2 and all others)
  23. 2. for single node training with [0,4) cards
  24. fleetrun --gpus="0,1,2,3" your_training_py (arg1 arg2 and all others)
  25. 3. for multiple node training such as two node:192.168.0.16, 192.168.0.17
  26. on 192.168.0.16:
  27. fleetrun --ips="192.168.0.16,192.168.0.17" \
  28. your_training_py (arg1 arg2 and all others)
  29. on 192.168.0.17:
  30. fleetrun --ips="192.168.0.16,192.168.0.17" \
  31. your_training_py (arg1 arg2 and all others)
  32. CPU training:
  33. 1. for single node training with multi servers and workers:
  34. fleetrun --server_num=2 --worker_num=2 your_training_py (arg1 arg2 and all others)
  35. 2. for multiple node training such as two node:192.168.0.16, 192.168.0.17 \
  36. with 2 servers and 4 workers.
  37. on 192.168.0.16:
  38. fleetrun --servers="192.168.0.16:6170,192.168.0.17:6170" \
  39. --workers="192.168.0.16,192.168.0.17,192.168.0.16,192.168.0.17" \
  40. your_training_py (arg1 arg2 and all others)
  41. on 192.168.0.17:
  42. fleetrun --servers="192.168.0.16:6170,192.168.0.17:6170" \
  43. --workers="192.168.0.16,192.168.0.17,192.168.0.16,192.168.0.17" \
  44. your_training_py (arg1 arg2 and all others)
  45. 3. use gloo backend for multiple node training such as two node:192.168.0.16, 192.168.0.17 \
  46. with 2 servers and 4 workers. (workers should set port)
  47. on 192.168.0.16:
  48. fleetrun --servers="192.168.0.16:6170,192.168.0.17:6170" \
  49. --workers="192.168.0.16:6171,192.168.0.17:6171,192.168.0.16:6172,192.168.0.17:6172" \
  50. your_training_py (arg1 arg2 and all others)
  51. on 192.168.0.17:
  52. fleetrun --servers="192.168.0.16:6170,192.168.0.17:6170" \
  53. --workers="192.168.0.16:6171,192.168.0.17:6171,192.168.0.16:6172,192.168.0.17:6172" \
  54. your_training_py (arg1 arg2 and all others)
  55. """
  56. import copy
  57. import os
  58. import pathlib
  59. import shutil
  60. import sys
  61. import tempfile
  62. import time
  63. from argparse import REMAINDER, ArgumentParser
  64. from paddle import framework
  65. from paddle.distributed.fleet import cloud_utils, launch_utils
  66. from paddle.distributed.fleet.elastic import enable_elastic, launch_elastic
  67. from paddle.distributed.fleet.launch_utils import (
  68. DeviceMode,
  69. DistributeMode,
  70. ParameterServerLauncher,
  71. block_windows_and_macos,
  72. check_backend,
  73. direct_start,
  74. find_free_ports,
  75. get_cluster,
  76. get_host_name_ip,
  77. get_logger,
  78. logger,
  79. start_local_trainers,
  80. terminate_local_procs,
  81. watch_local_trainers,
  82. )
  83. __all__ = []
  84. def _print_arguments(args):
  85. print("----------- Configuration Arguments -----------")
  86. for arg, value in sorted(vars(args).items()):
  87. print(f"{arg}: {value}")
  88. print("------------------------------------------------")
  89. def _parse_args():
  90. """
  91. Helper function parsing the command line options
  92. @retval ArgumentParser
  93. """
  94. parser = ArgumentParser(
  95. description='''start paddle training using multi-process mode.
  96. see: http://www.paddlepaddle.org/documentation/docs/zh/1.6/user_guides/howto/training/cluster_howto.html#permalink-8--nccl2-
  97. '''
  98. )
  99. base_group = parser.add_argument_group("Base Parameters")
  100. base_group.add_argument(
  101. "--log_dir",
  102. type=str,
  103. default="log",
  104. help="The path for each process's log. Default --log_dir=log/",
  105. )
  106. base_group.add_argument(
  107. "--backend",
  108. type=str,
  109. default=os.environ.get('PADDLE_DISTRI_BACKEND', 'auto'),
  110. help="Specify the backend, can be gloo|nccl|bkcl|auto|heter. "
  111. "Default value is auto which prefers nccl or bkcl.",
  112. )
  113. base_group.add_argument(
  114. "--nproc_per_node",
  115. type=int,
  116. default=None,
  117. help="The number of processes to launch on a node."
  118. "In gpu training, it should be less or equal to the gpus number of you system(or you set by --gpus). And so each process can"
  119. " bound to one or average number of gpus.",
  120. )
  121. base_group.add_argument(
  122. "--run_mode",
  123. type=str,
  124. default=None,
  125. help="run mode of job, can be:collective/ps/ps-heter",
  126. )
  127. if framework.core.is_compiled_with_cuda():
  128. base_group.add_argument(
  129. "--gpus",
  130. type=str,
  131. default=None,
  132. help="It's for gpu training."
  133. "For example:"
  134. "--gpus=\"0,1,2,3\" will launch four training processes each bound to one gpu.",
  135. )
  136. base_group.add_argument("--selected_gpus", dest="gpus")
  137. if framework.core.is_compiled_with_xpu():
  138. base_group.add_argument(
  139. "--xpus",
  140. type=str,
  141. default=None,
  142. help="It's for xpu training. For example: "
  143. "--xpus=\"0,1,2,3\" will launch four training processes each bound to one xpu.",
  144. )
  145. base_group.add_argument("--selected_xpus", dest="xpus")
  146. base_group.add_argument(
  147. "training_script",
  148. type=str,
  149. help="The full path to the single GPU training "
  150. "program/script to be launched in parallel, "
  151. "followed by all the arguments for the "
  152. "training script",
  153. )
  154. base_group.add_argument('training_script_args', nargs=REMAINDER)
  155. # Optional arguments for the launch helper
  156. # for collective
  157. collective_group = parser.add_argument_group("Collective Parameters")
  158. collective_group.add_argument(
  159. "--ips",
  160. type=str,
  161. default="127.0.0.1",
  162. help="Paddle cluster nodes ips, such as 192.168.0.16,192.168.0.17..",
  163. )
  164. collective_group.add_argument(
  165. "--cluster_topo_path",
  166. type=str,
  167. default=None,
  168. help="A json format file will be stored in this path which is used"
  169. "to represent the cluster topology information for auto parallel.",
  170. )
  171. collective_group.add_argument(
  172. "--rank_mapping_path",
  173. type=str,
  174. default=None,
  175. help="A json format file will be stored in this path which is used"
  176. "to map processes to machines for auto parallel.",
  177. )
  178. collective_group.add_argument(
  179. "--enable_auto_mapping",
  180. type=bool,
  181. default=False,
  182. help="Set true to enable the lazy launch for auto-parallel scenario.",
  183. )
  184. ps_group = parser.add_argument_group("Parameter-Server Parameters")
  185. # for parameter server
  186. ps_group.add_argument(
  187. "--servers", type=str, default="", help="User defined servers ip:port"
  188. )
  189. ps_group.add_argument(
  190. "--workers", type=str, default="", help="User defined workers ip:port"
  191. )
  192. ps_group.add_argument(
  193. "--coordinators",
  194. type=str,
  195. default="",
  196. help="User defined coordinators ip:port",
  197. )
  198. ps_group.add_argument(
  199. "--heter_workers",
  200. type=str,
  201. default="",
  202. help="User defined heter workers in each stage ip1:port1;ip2:port2",
  203. )
  204. ps_group.add_argument(
  205. "--heter_devices",
  206. type=str,
  207. default="",
  208. help="User defined heter devices in each stage cpu;gpu;cpu",
  209. )
  210. ps_group.add_argument("--worker_num", type=int, help="number of workers")
  211. ps_group.add_argument(
  212. "--coordinator_num", type=int, help="number of coordinators"
  213. )
  214. ps_group.add_argument("--server_num", type=int, help="number of servers")
  215. ps_group.add_argument(
  216. "--heter_worker_num",
  217. type=str,
  218. help="number of heter_workers in each stage 1;2;3",
  219. )
  220. ps_group.add_argument("--http_port", type=int, help="Gloo http Port")
  221. # parameter elastic mode
  222. elastic_group = parser.add_argument_group("Elastic Parameters")
  223. elastic_group.add_argument(
  224. "--elastic_server", type=str, help="etcd server host:port"
  225. )
  226. elastic_group.add_argument(
  227. "--elastic_pre_hook", type=str, help="elastic pre_hook shell cmd"
  228. )
  229. elastic_group.add_argument("--job_id", type=str, help="job unique id")
  230. elastic_group.add_argument("--np", type=int, help="job pod/node number")
  231. elastic_group.add_argument("--scale", type=int, default=0, help="scale np")
  232. elastic_group.add_argument(
  233. "--host", type=str, help="bind host, default to POD_IP env"
  234. )
  235. elastic_group.add_argument(
  236. "--force", type=bool, default=False, help="update np force"
  237. )
  238. known_args, _ = parser.parse_known_args()
  239. return known_args
  240. def get_cluster_from_args(args, device_mode, devices_per_proc):
  241. node_ips = [x.strip() for x in args.ips.split(',')]
  242. if len(node_ips) == 1:
  243. node_ip = node_ips[0]
  244. else:
  245. if args.host:
  246. node_ip = args.host
  247. else:
  248. _, node_ip = get_host_name_ip()
  249. assert (
  250. node_ip in node_ips
  251. ), f"Can't find your local ip {{{node_ip}}} in node_ips: {{{node_ips}}}"
  252. node_rank = node_ips.index(node_ip)
  253. logger.debug(
  254. f"parsed from args: node_ips:{node_ips} node_ip:{node_ip} node_rank:{node_rank}"
  255. )
  256. free_ports = None
  257. if (
  258. not cloud_utils.use_paddlecloud()
  259. and len(node_ips) <= 1
  260. and os.environ.get('FLAGS_START_PORT') is None
  261. ):
  262. free_ports = find_free_ports(len(devices_per_proc))
  263. if free_ports is not None:
  264. free_ports = list(free_ports)
  265. logger.info(f"find free ports:{free_ports}")
  266. else:
  267. start_port = 6070
  268. if os.environ.get('FLAGS_START_PORT') is not None:
  269. start_port = int(os.environ.get('FLAGS_START_PORT'))
  270. free_ports = list(range(start_port, start_port + len(devices_per_proc)))
  271. trainer_endpoints = []
  272. for ip in node_ips:
  273. trainer_endpoints.append(["%s:%d" % (ip, port) for port in free_ports])
  274. return get_cluster(
  275. node_ips, node_ip, trainer_endpoints, device_mode, devices_per_proc
  276. )
  277. def cpuonly_check(args):
  278. if args.ips and len(args.ips.split(',')) > 1:
  279. raise RuntimeError(
  280. "CPUONLY launch only support single trainer, that is len(ips)=1, but got %s."
  281. % args.ips
  282. )
  283. if args.run_mode:
  284. assert (
  285. args.run_mode == 'cpuonly'
  286. ), "CPUONLY launch only support run mode is CPUONLY"
  287. if args.servers:
  288. raise RuntimeError("CPUONLY launch can't have --servers as arguments.")
  289. return True
  290. def get_cluster_info(args):
  291. # parse arguments, used for cloud-single-machine and local
  292. if args.backend == 'gloo':
  293. cpuonly_check(args)
  294. if args.enable_auto_mapping:
  295. (device_mode, devices_per_proc) = (DeviceMode.GPU, [])
  296. else:
  297. (device_mode, devices_per_proc) = launch_utils.get_device_proc_info(
  298. args
  299. )
  300. trainers_num = cloud_utils.get_trainers_num()
  301. logger.debug(
  302. f"parsed from args trainers_num:{trainers_num} mode:{device_mode} devices:{devices_per_proc}"
  303. )
  304. cuda_visible_devices = os.getenv("CUDA_VISIBLE_DEVICES")
  305. cluster = None
  306. pod = None
  307. start_port = 6170
  308. if os.environ.get('FLAGS_START_PORT') is not None:
  309. start_port = os.environ.get('FLAGS_START_PORT')
  310. # auto mapping between processes and devices for auto-parallel
  311. if args.enable_auto_mapping:
  312. assert (
  313. args.cluster_topo_path is not None
  314. ), "The cluster topology must be provided when enabling auto mapping."
  315. rank_mapping_path = args.rank_mapping_path or os.getenv(
  316. "PADDLE_RANK_MAPPING_PATH"
  317. )
  318. if not rank_mapping_path:
  319. os.environ["PADDLE_NEED_RANK_MAPPING"] = str(True)
  320. os.environ["PADDLE_ENABLE_ELASTIC"] = str(
  321. enable_elastic(args, device_mode)
  322. )
  323. cwd = pathlib.Path().resolve()
  324. rank_mapping_path = os.path.join(
  325. cwd, "auto_parallel_rank_mapping.json"
  326. )
  327. os.environ["PADDLE_RANK_MAPPING_PATH"] = str(rank_mapping_path)
  328. original_args = sys.argv[1:]
  329. os.environ["PADDLE_ORIGINAL_CMD_ARGS"] = " ".join(original_args)
  330. os.environ["PADDLE_CLUSTER_TOPO_PATH"] = str(args.cluster_topo_path)
  331. os.environ["PADDLE_ENABLE_AUTO_MAPPING"] = str(
  332. args.enable_auto_mapping
  333. )
  334. (
  335. cluster,
  336. pod,
  337. ) = launch_utils.get_mapped_cluster_from_args_without_rank_mapping(
  338. args, device_mode
  339. )
  340. else:
  341. os.environ["PADDLE_NEED_RANK_MAPPING"] = str(False)
  342. os.environ["PADDLE_ENABLE_ELASTIC"] = str(
  343. enable_elastic(args, device_mode)
  344. )
  345. os.environ["PADDLE_CLUSTER_TOPO_PATH"] = str(args.cluster_topo_path)
  346. os.environ["PADDLE_RANK_MAPPING_PATH"] = str(rank_mapping_path)
  347. os.environ["PADDLE_ENABLE_AUTO_MAPPING"] = str(
  348. args.enable_auto_mapping
  349. )
  350. (
  351. cluster,
  352. pod,
  353. ) = launch_utils.get_mapped_cluster_from_args_with_rank_mapping(
  354. args, device_mode
  355. )
  356. elif cloud_utils.use_paddlecloud() and trainers_num != 1:
  357. cluster, pod = cloud_utils.get_cloud_cluster(
  358. args.ips, device_mode, devices_per_proc, start_port
  359. )
  360. logger.debug(f"get cluster from cloud:{cluster}")
  361. else:
  362. # trainers_num = 1 or not use paddlecloud ips="a,b"
  363. cluster, pod = get_cluster_from_args(
  364. args, device_mode, devices_per_proc
  365. )
  366. logger.debug(f"get cluster from args:{cluster}")
  367. return cluster, pod
  368. def get_global_envs(args, tmp_dir):
  369. global_envs = copy.copy(os.environ.copy())
  370. # add gloo env
  371. global_envs["PADDLE_WITH_GLOO"] = str(os.getenv("PADDLE_WITH_GLOO", "0"))
  372. global_envs["PADDLE_GLOO_RENDEZVOUS"] = "3"
  373. global_envs["PADDLE_GLOO_FS_PATH"] = tmp_dir
  374. global_envs["PADDLE_DISTRI_BACKEND"] = args.backend
  375. return global_envs
  376. def launch_collective(args):
  377. tmp_dir = tempfile.mkdtemp()
  378. cluster, pod = get_cluster_info(args)
  379. global_envs = get_global_envs(args, tmp_dir)
  380. procs = start_local_trainers(
  381. cluster,
  382. pod,
  383. training_script=args.training_script,
  384. training_script_args=args.training_script_args,
  385. log_dir=args.log_dir,
  386. envs=global_envs,
  387. )
  388. for idx, proc in enumerate(procs):
  389. print(f"launch proc_id:{proc.proc.pid} idx:{idx}")
  390. while True:
  391. try:
  392. alive = watch_local_trainers(procs, cluster.trainers_nranks())
  393. if not alive:
  394. logger.info("Local processes completed.")
  395. logger.debug(f"POD info:{pod}")
  396. break
  397. time.sleep(3)
  398. except:
  399. logger.warning("Terminating... exit")
  400. terminate_local_procs(procs)
  401. sys.exit(1)
  402. if os.path.exists(tmp_dir):
  403. shutil.rmtree(tmp_dir)
  404. def launch_ps(args, distribute_mode):
  405. cloud_flag = cloud_utils.use_paddlecloud()
  406. # for ps-cpu on paddlecloud
  407. if cloud_flag and distribute_mode == DistributeMode.PS:
  408. direct_start(args)
  409. return
  410. # elif cloud_flag and distribute_mode == DistributeMode.PS_HETER:
  411. # cloud_ps_heter_env_set(args)
  412. # args.workers = os.getenv("PADDLE_TRAINER_ENDPOINTS")
  413. # args.servers = os.getenv("PADDLE_PSERVERS_IP_PORT_LIST")
  414. # args.heter_workers = os.getenv("PADDLE_HETER_TRAINER_IP_PORT_LIST")
  415. ps_launcher = ParameterServerLauncher(args, distribute_mode)
  416. ps_launcher.start_ps()
  417. return
  418. def infer_backend(args):
  419. if args.backend != "auto":
  420. return
  421. if framework.core.is_compiled_with_cuda():
  422. args.backend = 'nccl'
  423. elif framework.core.is_compiled_with_xpu():
  424. args.backend = 'bkcl'
  425. else:
  426. args.backend = 'gloo'
  427. def which_distributed_mode(args):
  428. infer_backend(args) # modify the args.backend
  429. if args.run_mode is not None:
  430. assert args.run_mode in ["collective", "ps", "ps-heter"]
  431. if args.run_mode == "collective":
  432. return DistributeMode.COLLECTIVE
  433. elif args.run_mode == "ps":
  434. return DistributeMode.PS
  435. elif args.run_mode == "ps-heter":
  436. return DistributeMode.PS_HETER
  437. ps_args = [
  438. '--worker_num',
  439. '--server_num',
  440. '--heter_worker_num',
  441. '--servers',
  442. '--workers',
  443. '--heter_workers',
  444. '--heter_devices',
  445. '--http_port',
  446. ]
  447. collective_args = ['--ips']
  448. ps_heter_args = ["--heter_worker_num", "--heter_workers", "--heter_devices"]
  449. coordinator_args = ["--coordinator_num", "--coordinators"]
  450. has_ps_args = [
  451. ps_arg for ps_arg in ps_args if ps_arg in " ".join(sys.argv[1:-1])
  452. ]
  453. has_collective_args = [
  454. co_arg
  455. for co_arg in collective_args
  456. if co_arg in " ".join(sys.argv[1:-1])
  457. ]
  458. if len(has_ps_args) > 1 and len(has_collective_args) > 1:
  459. raise ValueError(
  460. "Only one mode(Collective or Parameter-Server) can be selected at the same time, but more than one configuration was received."
  461. )
  462. if framework.core.is_compiled_with_cuda():
  463. accelerators = framework.core.get_cuda_device_count()
  464. elif framework.core.is_compiled_with_xpu():
  465. accelerators = framework.core.get_xpu_device_count()
  466. else:
  467. accelerators = 0
  468. if len(has_ps_args) > 0:
  469. logger.info(
  470. f"Run parameter-sever mode. pserver arguments:{has_ps_args}, accelerators count:{accelerators}"
  471. )
  472. has_ps_heter_args = list(set(has_ps_args) & set(ps_heter_args))
  473. has_coordinator_args = list(set(has_ps_args) & set(coordinator_args))
  474. if len(has_ps_heter_args) > 0:
  475. return DistributeMode.PS_HETER
  476. else:
  477. return DistributeMode.PS
  478. elif len(has_collective_args) > 0:
  479. logger.info(
  480. f"Run collective mode. gpu arguments:{has_collective_args}, cuda count:{accelerators}"
  481. )
  482. return DistributeMode.COLLECTIVE
  483. else:
  484. if (
  485. not framework.core.is_compiled_with_cuda()
  486. and not framework.core.is_compiled_with_xpu()
  487. ):
  488. if args.servers:
  489. logger.warning(
  490. "Not found distinct arguments and not compiled with cuda or xpu. "
  491. "But found args.servers not empty, default use ps mode"
  492. )
  493. return DistributeMode.PS
  494. else:
  495. return DistributeMode.COLLECTIVE
  496. else:
  497. logger.warning(
  498. "Not found distinct arguments and compiled with cuda or xpu. "
  499. "Default use collective mode"
  500. )
  501. return DistributeMode.COLLECTIVE
  502. def launch():
  503. """
  504. Paddle distribution training entry ``python -m paddle.distributed.launch``.
  505. Usage:
  506. .. code-block:: bash
  507. :name: code-block-bash1
  508. python -m paddle.distributed.launch [-h] [--log_dir LOG_DIR] [--nproc_per_node NPROC_PER_NODE] [--run_mode RUN_MODE] [--gpus GPUS]
  509. [--selected_gpus GPUS] [--ips IPS] [--servers SERVERS] [--workers WORKERS] [--heter_workers HETER_WORKERS]
  510. [--worker_num WORKER_NUM] [--server_num SERVER_NUM] [--heter_worker_num HETER_WORKER_NUM]
  511. [--http_port HTTP_PORT] [--elastic_server ELASTIC_SERVER] [--job_id JOB_ID] [--np NP] [--scale SCALE]
  512. [--host HOST] [--force FORCE]
  513. training_script ...
  514. Base Parameters:
  515. - ``--log_dir``: The path for each process's log. e.g., ``--log_dir=output_dir``. Default ``--log_dir=log``.
  516. - ``--nproc_per_node``: The number of processes to launch on a node. In gpu training, it should be less or equal to the gpus number of you system(or you set by --gpus). e.g., ``--nproc_per_node=8``
  517. - ``--run_mode``: run mode of job, can be:collective/ps/ps-heter. e.g., ``--run_mode=ps``. Default ``--run_mode=collective``.
  518. - ``--gpus``: It's for gpu training. e.g., ``--gpus=0,1,2,3`` will launch four training processes each bound to one gpu.
  519. - ``--selected_gpus``: gpus aliases, recommend to use ``--gpus``.
  520. - ``--xpus``: It's for xpu training if xpu is available. e.g., ``--xpus=0,1,2,3``.
  521. - ``--selected_xpus``: xpus aliases, recommend to use ``--xpus``.
  522. - ``training_script``: The full path to the single GPU training program/script to be launched in parallel, followed by all the arguments for the training script. e.g., ``training.py``
  523. - ``training_script_args``: The args of training_script. e.g., ``--lr=0.1``
  524. Collective Parameters:
  525. - ``--ips``: Paddle cluster nodes ips, e.g., ``--ips=192.168.0.16,192.168.0.17``. Default ``--ips=127.0.0.1``.
  526. Parameter-Server Parameters:
  527. - ``--servers``: User defined servers ip:port, e.g., ``--servers="192.168.0.16:6170,192.168.0.17:6170"``
  528. - ``--workers``: User defined workers ip:port, e.g., ``--workers="192.168.0.16:6171,192.168.0.16:6172,192.168.0.17:6171,192.168.0.17:6172"``
  529. - ``--heter_workers``: User defined heter workers ip1:port1;ip2:port2, e.g., ``--heter_workers="192.168.0.16:6172;192.168.0.17:6172"``
  530. - ``--worker_num``: Number of workers (It recommend to set when in the emulated distributed environment using single node)
  531. - ``--server_num``: Number of servers (It recommend to set when in the emulated distributed environment using single node)
  532. - ``--heter_worker_num``: Number of heter_workers in each stage (It recommend to set when in the emulated distributed environment using single node)
  533. - ``--heter_devices``: Type of heter_device in each stage
  534. - ``--http_port``: Gloo http Port
  535. Elastic Parameters:
  536. - ``--elastic_server``: etcd server host:port, e.g., ``--elastic_server=127.0.0.1:2379``
  537. - ``--job_id``: job unique id, e.g., ``--job_id=job1``
  538. - ``--np``: job pod/node number, e.g., ``--np=2``
  539. - ``--host``: bind host, default to POD_IP env.
  540. Returns:
  541. ``None``
  542. Examples 1 (collective, single node):
  543. .. code-block:: bash
  544. :name: code-block-example-bash1
  545. # For training on single node using 4 gpus.
  546. python -m paddle.distributed.launch --gpus=0,1,2,3 train.py --lr=0.01
  547. Examples 2 (collective, multi node):
  548. .. code-block:: bash
  549. :name: code-block-example-bash2
  550. # The parameters of --gpus and --ips must be consistent in each node.
  551. # For training on multiple nodes, e.g., 192.168.0.16, 192.168.0.17
  552. # On 192.168.0.16:
  553. python -m paddle.distributed.launch --gpus=0,1,2,3 --ips=192.168.0.16,192.168.0.17 train.py --lr=0.01
  554. # On 192.168.0.17:
  555. python -m paddle.distributed.launch --gpus=0,1,2,3 --ips=192.168.0.16,192.168.0.17 train.py --lr=0.01
  556. Examples 3 (ps, cpu, single node):
  557. .. code-block:: bash
  558. :name: code-block-example-bash3
  559. # To simulate distributed environment using single node, e.g., 2 servers and 4 workers.
  560. python -m paddle.distributed.launch --server_num=2 --worker_num=4 train.py --lr=0.01
  561. Examples 4 (ps, cpu, multi node):
  562. .. code-block:: bash
  563. :name: code-block-example-bash4
  564. # For training on multiple nodes, e.g., 192.168.0.16, 192.168.0.17 where each node with 1 server and 2 workers.
  565. # On 192.168.0.16:
  566. python -m paddle.distributed.launch --servers="192.168.0.16:6170,192.168.0.17:6170" --workers="192.168.0.16:6171,192.168.0.16:6172,192.168.0.17:6171,192.168.0.17:6172" train.py --lr=0.01
  567. # On 192.168.0.17:
  568. python -m paddle.distributed.launch --servers="192.168.0.16:6170,192.168.0.17:6170" --workers="192.168.0.16:6171,192.168.0.16:6172,192.168.0.17:6171,192.168.0.17:6172" train.py --lr=0.01
  569. Examples 5 (ps, gpu, single node):
  570. .. code-block:: bash
  571. :name: code-block-example-bash5
  572. # To simulate distributed environment using single node, e.g., 2 servers and 4 workers, each worker use single gpu.
  573. export CUDA_VISIBLE_DEVICES=0,1,2,3
  574. python -m paddle.distributed.launch --server_num=2 --worker_num=4 train.py --lr=0.01
  575. Examples 6 (ps, gpu, multi node):
  576. .. code-block:: bash
  577. :name: code-block-example-bash6
  578. # For training on multiple nodes, e.g., 192.168.0.16, 192.168.0.17 where each node with 1 server and 2 workers.
  579. # On 192.168.0.16:
  580. export CUDA_VISIBLE_DEVICES=0,1
  581. python -m paddle.distributed.launch --servers="192.168.0.16:6170,192.168.0.17:6170" --workers="192.168.0.16:6171,192.168.0.16:6172,192.168.0.17:6171,192.168.0.17:6172" train.py --lr=0.01
  582. # On 192.168.0.17:
  583. export CUDA_VISIBLE_DEVICES=0,1
  584. python -m paddle.distributed.launch --servers="192.168.0.16:6170,192.168.0.17:6170" --workers="192.168.0.16:6171,192.168.0.16:6172,192.168.0.17:6171,192.168.0.17:6172" train.py --lr=0.01
  585. Examples 7 (ps-heter, cpu + gpu, single node):
  586. .. code-block:: bash
  587. :name: code-block-example-bash7
  588. # To simulate distributed environment using single node, e.g., 2 servers and 4 workers, two workers use gpu, two workers use cpu.
  589. export CUDA_VISIBLE_DEVICES=0,1
  590. python -m paddle.distributed.launch --server_num=2 --worker_num=2 --heter_worker_num=2 train.py --lr=0.01
  591. Examples 8 (ps-heter, cpu + gpu, multi node):
  592. .. code-block:: bash
  593. :name: code-block-example-bash8
  594. # For training on multiple nodes, e.g., 192.168.0.16, 192.168.0.17 where each node with 1 server, 1 gpu worker, 1 cpu worker.
  595. # On 192.168.0.16:
  596. export CUDA_VISIBLE_DEVICES=0
  597. python -m paddle.distributed.launch --servers="192.168.0.16:6170,192.168.0.17:6170" --workers="192.168.0.16:6171,192.168.0.17:6171" --heter_workers="192.168.0.16:6172,192.168.0.17:6172" train.py --lr=0.01
  598. # On 192.168.0.17:
  599. export CUDA_VISIBLE_DEVICES=0
  600. python -m paddle.distributed.launch --servers="192.168.0.16:6170,192.168.0.17:6170" --workers="192.168.0.16:6171,192.168.0.17:6171" --heter_workers="192.168.0.16:6172,192.168.0.17:6172" train.py --lr=0.01
  601. Examples 9 (elastic):
  602. .. code-block:: bash
  603. :name: code-block-example-bash9
  604. python -m paddle.distributed.launch --elastic_server=127.0.0.1:2379 --np=2 --job_id=job1 --gpus=0,1,2,3 train.py
  605. """
  606. args = _parse_args()
  607. logger = get_logger()
  608. _print_arguments(args)
  609. if args.backend == 'auto':
  610. distribute_mode = which_distributed_mode(
  611. args
  612. ) # which_distributed_mode must modify args.backend
  613. else:
  614. assert (
  615. args.run_mode == 'collective' or args.run_mode is None
  616. ), "When backend is not 'auto', run mode must be collective"
  617. check_backend(args.backend)
  618. distribute_mode = DistributeMode.COLLECTIVE
  619. # assert args.backend in ['gloo', 'nccl', 'bkcl', 'heter', 'unknown']
  620. if args.backend == 'gloo':
  621. logger.warning("launch start with CPUONLY mode")
  622. block_windows_and_macos(
  623. args.backend
  624. ) # raise error when using gloo on windows or macos
  625. if enable_elastic(args, distribute_mode):
  626. launch_elastic(args, distribute_mode)
  627. return
  628. if distribute_mode == DistributeMode.COLLECTIVE:
  629. launch_collective(args)
  630. else:
  631. launch_ps(args, distribute_mode)
# Script entry point: run the launcher when this file is executed directly.
if __name__ == "__main__":
    launch()