cluster.py 38 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693694695696697698699700701702703704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862863864865866867868869870871872873874875876877878879880881882883884885886887888889890891892893894895896897898899900901902903904905906907908909910911912913914915916917
  1. #!/usr/bin/env python
  2. # Copyright 2021 The HuggingFace Team. All rights reserved.
  3. #
  4. # Licensed under the Apache License, Version 2.0 (the "License");
  5. # you may not use this file except in compliance with the License.
  6. # You may obtain a copy of the License at
  7. #
  8. # http://www.apache.org/licenses/LICENSE-2.0
  9. #
  10. # Unless required by applicable law or agreed to in writing, software
  11. # distributed under the License is distributed on an "AS IS" BASIS,
  12. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. # See the License for the specific language governing permissions and
  14. # limitations under the License.
  15. import os
  16. from ...utils import (
  17. ComputeEnvironment,
  18. DistributedType,
  19. is_deepspeed_available,
  20. is_fp8_available,
  21. is_hpu_available,
  22. is_mlu_available,
  23. is_mps_available,
  24. is_msamp_available,
  25. is_musa_available,
  26. is_npu_available,
  27. is_sdaa_available,
  28. is_transformer_engine_available,
  29. is_transformers_available,
  30. is_xpu_available,
  31. )
  32. from ...utils.constants import (
  33. DEEPSPEED_MULTINODE_LAUNCHERS,
  34. FSDP2_STATE_DICT_TYPE,
  35. FSDP_AUTO_WRAP_POLICY,
  36. FSDP_BACKWARD_PREFETCH,
  37. FSDP_SHARDING_STRATEGY,
  38. FSDP_STATE_DICT_TYPE,
  39. TORCH_DYNAMO_MODES,
  40. )
  41. from .config_args import ClusterConfig
  42. from .config_utils import (
  43. DYNAMO_BACKENDS,
  44. _ask_field,
  45. _ask_options,
  46. _convert_distributed_mode,
  47. _convert_dynamo_backend,
  48. _convert_fp8_backend,
  49. _convert_mixed_precision,
  50. _convert_yes_no_to_bool,
  51. )
  52. def get_cluster_input():
  53. distributed_type = _ask_options(
  54. "Which type of machine are you using?",
  55. [
  56. "No distributed training",
  57. "multi-CPU",
  58. "multi-XPU",
  59. "multi-HPU",
  60. "multi-GPU",
  61. "multi-NPU",
  62. "multi-MLU",
  63. "multi-SDAA",
  64. "multi-MUSA",
  65. "TPU",
  66. ],
  67. _convert_distributed_mode,
  68. )
  69. machine_rank = 0
  70. num_machines = 1
  71. num_processes = 1
  72. gpu_ids = None
  73. main_process_ip = None
  74. main_process_port = None
  75. rdzv_backend = "static"
  76. same_network = True
  77. debug = False
  78. if distributed_type in [
  79. DistributedType.MULTI_GPU,
  80. DistributedType.MULTI_MLU,
  81. DistributedType.MULTI_SDAA,
  82. DistributedType.MULTI_MUSA,
  83. DistributedType.MULTI_NPU,
  84. DistributedType.MULTI_XPU,
  85. DistributedType.MULTI_CPU,
  86. DistributedType.MULTI_HPU,
  87. ]:
  88. num_machines = _ask_field(
  89. "How many different machines will you use (use more than 1 for multi-node training)? [1]: ",
  90. int,
  91. default=1,
  92. )
  93. if num_machines > 1:
  94. machine_rank = _ask_options(
  95. "What is the rank of this machine?",
  96. list(range(num_machines)),
  97. int,
  98. )
  99. main_process_ip = _ask_field(
  100. "What is the IP address of the machine that will host the main process? ",
  101. )
  102. main_process_port = _ask_field(
  103. "What is the port you will use to communicate with the main process? ",
  104. int,
  105. )
  106. same_network = _ask_field(
  107. "Are all the machines on the same local network? Answer `no` if nodes are on the cloud and/or on different network hosts [YES/no]: ",
  108. _convert_yes_no_to_bool,
  109. default=True,
  110. error_message="Please enter yes or no.",
  111. )
  112. if not same_network:
  113. rdzv_backend = _ask_field(
  114. "What rendezvous backend will you use? ('static', 'c10d', ...): ", default="static"
  115. )
  116. debug = _ask_field(
  117. "Should distributed operations be checked while running for errors? This can avoid timeout issues but will be slower. [yes/NO]: ",
  118. _convert_yes_no_to_bool,
  119. default=False,
  120. error_message="Please enter yes or no.",
  121. )
  122. if distributed_type == DistributedType.NO:
  123. use_cpu = _ask_field(
  124. "Do you want to run your training on CPU only (even if a GPU / Apple Silicon / Ascend NPU device is available)? [yes/NO]:",
  125. _convert_yes_no_to_bool,
  126. default=False,
  127. error_message="Please enter yes or no.",
  128. )
  129. elif distributed_type == DistributedType.MULTI_CPU:
  130. use_cpu = True
  131. else:
  132. use_cpu = False
  133. ipex_config = {}
  134. mpirun_config = {}
  135. if use_cpu or is_xpu_available():
  136. ipex_config["ipex"] = _ask_field(
  137. "Do you want to use Intel PyTorch Extension (IPEX) to speed up training on CPU/XPU? [yes/NO]:",
  138. _convert_yes_no_to_bool,
  139. default=False,
  140. error_message="Please enter yes or no.",
  141. )
  142. if use_cpu:
  143. if distributed_type == DistributedType.MULTI_CPU:
  144. use_mpirun = _ask_field(
  145. "Do you want accelerate to launch mpirun? [yes/NO]: ",
  146. _convert_yes_no_to_bool,
  147. default=False,
  148. error_message="Please enter yes or no.",
  149. )
  150. if use_mpirun:
  151. mpirun_hostfile = _ask_field(
  152. "Please enter the path to the hostfile to use with mpirun [~/hostfile]: ",
  153. str,
  154. default="~/hostfile",
  155. )
  156. mpirun_config["mpirun_hostfile"] = os.path.expanduser(mpirun_hostfile.strip())
  157. mpirun_config["mpirun_ccl"] = _ask_field("Enter the number of oneCCL worker threads [1]: ", default=1)
  158. dynamo_config = {}
  159. use_dynamo = _ask_field(
  160. "Do you wish to optimize your script with torch dynamo?[yes/NO]:",
  161. _convert_yes_no_to_bool,
  162. default=False,
  163. error_message="Please enter yes or no.",
  164. )
  165. if use_dynamo:
  166. prefix = "dynamo_"
  167. dynamo_config[prefix + "backend"] = _ask_options(
  168. "Which dynamo backend would you like to use?",
  169. [x.lower() for x in DYNAMO_BACKENDS],
  170. _convert_dynamo_backend,
  171. default=2,
  172. )
  173. use_custom_options = _ask_field(
  174. "Do you want to customize the defaults sent to torch.compile? [yes/NO]: ",
  175. _convert_yes_no_to_bool,
  176. default=False,
  177. error_message="Please enter yes or no.",
  178. )
  179. if use_custom_options:
  180. dynamo_config[prefix + "mode"] = _ask_options(
  181. "Which mode do you want to use?",
  182. TORCH_DYNAMO_MODES,
  183. lambda x: TORCH_DYNAMO_MODES[int(x)],
  184. default=0,
  185. )
  186. dynamo_config[prefix + "use_fullgraph"] = _ask_field(
  187. "Do you want the fullgraph mode or it is ok to break model into several subgraphs? [yes/NO]: ",
  188. _convert_yes_no_to_bool,
  189. default=False,
  190. error_message="Please enter yes or no.",
  191. )
  192. dynamo_config[prefix + "use_dynamic"] = _ask_field(
  193. "Do you want to enable dynamic shape tracing? [yes/NO]: ",
  194. _convert_yes_no_to_bool,
  195. default=False,
  196. error_message="Please enter yes or no.",
  197. )
  198. dynamo_config[prefix + "use_regional_compilation"] = _ask_field(
  199. "Do you want to enable regional compilation? [yes/NO]: ",
  200. _convert_yes_no_to_bool,
  201. default=False,
  202. error_message="Please enter yes or no.",
  203. )
  204. use_mps = not use_cpu and is_mps_available()
  205. deepspeed_config = {}
  206. if (
  207. distributed_type
  208. in [
  209. DistributedType.MULTI_GPU,
  210. DistributedType.MULTI_XPU,
  211. DistributedType.MULTI_HPU,
  212. DistributedType.MULTI_NPU,
  213. DistributedType.MULTI_MLU,
  214. DistributedType.MULTI_SDAA,
  215. DistributedType.MULTI_MUSA,
  216. DistributedType.NO,
  217. ]
  218. and not use_mps
  219. ):
  220. use_deepspeed = _ask_field(
  221. "Do you want to use DeepSpeed? [yes/NO]: ",
  222. _convert_yes_no_to_bool,
  223. default=False,
  224. error_message="Please enter yes or no.",
  225. )
  226. if use_deepspeed:
  227. distributed_type = DistributedType.DEEPSPEED
  228. assert is_deepspeed_available(), (
  229. "DeepSpeed is not installed => run `pip3 install deepspeed` or build it from source"
  230. )
  231. if distributed_type == DistributedType.DEEPSPEED:
  232. use_deepspeed_config = _ask_field(
  233. "Do you want to specify a json file to a DeepSpeed config? [yes/NO]: ",
  234. _convert_yes_no_to_bool,
  235. default=False,
  236. error_message="Please enter yes or no.",
  237. )
  238. if use_deepspeed_config:
  239. deepspeed_config["deepspeed_config_file"] = _ask_field(
  240. "Please enter the path to the json DeepSpeed config file: ",
  241. str,
  242. default="none",
  243. )
  244. else:
  245. deepspeed_config["zero_stage"] = _ask_options(
  246. "What should be your DeepSpeed's ZeRO optimization stage?",
  247. [0, 1, 2, 3],
  248. int,
  249. default=2,
  250. )
  251. deepspeed_devices = ["none", "cpu", "nvme"]
  252. if deepspeed_config["zero_stage"] >= 2:
  253. deepspeed_config["offload_optimizer_device"] = _ask_options(
  254. "Where to offload optimizer states?", deepspeed_devices, lambda x: deepspeed_devices[int(x)]
  255. )
  256. deepspeed_config["offload_param_device"] = _ask_options(
  257. "Where to offload parameters?", deepspeed_devices, lambda x: deepspeed_devices[int(x)]
  258. )
  259. if deepspeed_config["offload_param_device"] == "nvme":
  260. deepspeed_config["offload_param_nvme_path"] = _ask_field(
  261. "Nvme Path to offload parameters?",
  262. str,
  263. default="/nvme",
  264. )
  265. if deepspeed_config["offload_optimizer_device"] == "nvme":
  266. deepspeed_config["offload_optimizer_nvme_path"] = _ask_field(
  267. "Nvme Path to offload optimizer states?",
  268. str,
  269. default="/nvme",
  270. )
  271. deepspeed_config["gradient_accumulation_steps"] = _ask_field(
  272. "How many gradient accumulation steps you're passing in your script? [1]: ",
  273. int,
  274. default=1,
  275. )
  276. use_gradient_clipping = _ask_field(
  277. "Do you want to use gradient clipping? [yes/NO]: ",
  278. _convert_yes_no_to_bool,
  279. default=False,
  280. error_message="Please enter yes or no.",
  281. )
  282. if use_gradient_clipping:
  283. deepspeed_config["gradient_clipping"] = _ask_field(
  284. "What is the gradient clipping value? [1.0]: ",
  285. float,
  286. default=1.0,
  287. )
  288. if deepspeed_config["zero_stage"] == 3:
  289. deepspeed_config["zero3_save_16bit_model"] = _ask_field(
  290. "Do you want to save 16-bit model weights when using ZeRO Stage-3? [yes/NO]: ",
  291. _convert_yes_no_to_bool,
  292. default=False,
  293. error_message="Please enter yes or no.",
  294. )
  295. deepspeed_config["zero3_init_flag"] = _ask_field(
  296. "Do you want to enable `deepspeed.zero.Init` when using ZeRO Stage-3 for constructing massive models? [yes/NO]: ",
  297. _convert_yes_no_to_bool,
  298. default=False,
  299. error_message="Please enter yes or no.",
  300. )
  301. if deepspeed_config["zero3_init_flag"]:
  302. if not is_transformers_available():
  303. raise Exception(
  304. "When `zero3_init_flag` is set, it requires Transformers to be installed. "
  305. "Please run `pip3 install transformers`."
  306. )
  307. use_moe = _ask_field(
  308. "Do you want to enable Mixture-of-Experts training (MoE)? [yes/NO]: ",
  309. _convert_yes_no_to_bool,
  310. default=False,
  311. error_message="Please enter yes or no.",
  312. )
  313. if use_moe:
  314. deepspeed_config["deepspeed_moe_layer_cls_names"] = _ask_field(
  315. "Specify the comma-separated list of transformers MoE layer class names (case-sensitive), e.g : "
  316. " `MixtralSparseMoeBlock`, `Qwen2MoeSparseMoeBlock`, `JetMoEAttention,JetMoEBlock` ... : ",
  317. str,
  318. )
  319. if num_machines > 1:
  320. launcher_query = "Which Type of launcher do you want to use?"
  321. deepspeed_config["deepspeed_multinode_launcher"] = _ask_options(
  322. launcher_query,
  323. DEEPSPEED_MULTINODE_LAUNCHERS,
  324. lambda x: DEEPSPEED_MULTINODE_LAUNCHERS[int(x)],
  325. )
  326. if deepspeed_config["deepspeed_multinode_launcher"] != DEEPSPEED_MULTINODE_LAUNCHERS[1]:
  327. deepspeed_config["deepspeed_hostfile"] = _ask_field(
  328. "DeepSpeed configures multi-node compute resources with hostfile. "
  329. "Each row is of the format `hostname slots=[num_gpus]`, e.g., `localhost slots=2`; "
  330. "for more information please refer official [documentation]"
  331. "(https://www.deepspeed.ai/getting-started/#resource-configuration-multi-node). "
  332. "Please specify the location of hostfile: ",
  333. str,
  334. )
  335. is_exclusion_filter = _ask_field(
  336. "Do you want to specify exclusion filter string? [yes/NO]: ",
  337. _convert_yes_no_to_bool,
  338. default=False,
  339. error_message="Please enter yes or no.",
  340. )
  341. if is_exclusion_filter:
  342. deepspeed_config["deepspeed_exclusion_filter"] = _ask_field(
  343. "DeepSpeed exclusion filter string: ",
  344. str,
  345. )
  346. is_inclusion_filter = _ask_field(
  347. "Do you want to specify inclusion filter string? [yes/NO]: ",
  348. _convert_yes_no_to_bool,
  349. default=False,
  350. error_message="Please enter yes or no.",
  351. )
  352. if is_inclusion_filter:
  353. deepspeed_config["deepspeed_inclusion_filter"] = _ask_field(
  354. "DeepSpeed inclusion filter string: ",
  355. str,
  356. )
  357. fsdp_config = {}
  358. if distributed_type in [
  359. DistributedType.MULTI_GPU,
  360. DistributedType.MULTI_NPU,
  361. DistributedType.MULTI_MLU,
  362. DistributedType.MULTI_SDAA,
  363. DistributedType.MULTI_MUSA,
  364. DistributedType.MULTI_XPU,
  365. DistributedType.MULTI_HPU,
  366. ]:
  367. use_fsdp = _ask_field(
  368. "Do you want to use FullyShardedDataParallel? [yes/NO]: ",
  369. _convert_yes_no_to_bool,
  370. default=False,
  371. error_message="Please enter yes or no.",
  372. )
  373. if use_fsdp:
  374. distributed_type = DistributedType.FSDP
  375. if distributed_type == DistributedType.FSDP:
  376. fsdp_config["fsdp_version"] = _ask_options(
  377. "What should be your FSDP version? [2]: ",
  378. [1, 2],
  379. lambda x: int(x) + 1,
  380. default=1,
  381. )
  382. fsdp_version = fsdp_config["fsdp_version"] # extract to a variable to simplify usage later
  383. if fsdp_version == 1:
  384. sharding_strategy_query = "What should be your sharding strategy?"
  385. fsdp_config["fsdp_reshard_after_forward"] = _ask_options(
  386. sharding_strategy_query,
  387. FSDP_SHARDING_STRATEGY,
  388. lambda x: FSDP_SHARDING_STRATEGY[int(x)],
  389. )
  390. else:
  391. fsdp_config["fsdp_reshard_after_forward"] = _ask_field(
  392. "Do you want to enable resharding after forward? [YES/no]: ",
  393. _convert_yes_no_to_bool,
  394. default=True,
  395. error_message="Please enter yes or no.",
  396. )
  397. fsdp_config["fsdp_offload_params"] = _ask_field(
  398. "Do you want to offload parameters and gradients to CPU? [yes/NO]: ",
  399. _convert_yes_no_to_bool,
  400. default=False,
  401. error_message="Please enter yes or no.",
  402. )
  403. fsdp_wrap_query = "What should be your auto wrap policy?"
  404. fsdp_config["fsdp_auto_wrap_policy"] = _ask_options(
  405. fsdp_wrap_query,
  406. FSDP_AUTO_WRAP_POLICY,
  407. lambda x: FSDP_AUTO_WRAP_POLICY[int(x)],
  408. )
  409. if fsdp_config["fsdp_auto_wrap_policy"] == FSDP_AUTO_WRAP_POLICY[0]:
  410. use_no_split_modules = _ask_field(
  411. "Do you want to use the model's `_no_split_modules` to wrap. Only applicable for 🤗 Transformers [yes/NO]: ",
  412. _convert_yes_no_to_bool,
  413. default=False,
  414. error_message="Please enter yes or no.",
  415. )
  416. if not use_no_split_modules:
  417. fsdp_config["fsdp_transformer_layer_cls_to_wrap"] = _ask_field(
  418. "Specify the comma-separated list of transformer layer class names (case-sensitive) to wrap ,e.g, :"
  419. "`BertLayer`, `GPTJBlock`, `T5Block`, `BertLayer,BertEmbeddings,BertSelfOutput` ...? : ",
  420. str,
  421. )
  422. elif fsdp_config["fsdp_auto_wrap_policy"] == FSDP_AUTO_WRAP_POLICY[1]:
  423. fsdp_config["fsdp_min_num_params"] = _ask_field(
  424. "What should be your FSDP's minimum number of parameters for Default Auto Wrapping Policy? [1e8]: ",
  425. int,
  426. default=100000000,
  427. )
  428. # Removed in FSDP2, ask for user input for FSDP1
  429. if fsdp_version == 1:
  430. fsdp_backward_prefetch_query = "What should be your FSDP's backward prefetch policy?"
  431. fsdp_config["fsdp_backward_prefetch"] = _ask_options(
  432. fsdp_backward_prefetch_query,
  433. FSDP_BACKWARD_PREFETCH,
  434. lambda x: FSDP_BACKWARD_PREFETCH[int(x)],
  435. )
  436. fsdp_state_dict_type_query = "What should be your FSDP's state dict type?"
  437. fsdp_config["fsdp_state_dict_type"] = _ask_options(
  438. fsdp_state_dict_type_query,
  439. FSDP_STATE_DICT_TYPE if fsdp_version == 1 else FSDP2_STATE_DICT_TYPE,
  440. lambda x: FSDP_STATE_DICT_TYPE[int(x)] if fsdp_version == 1 else FSDP2_STATE_DICT_TYPE[int(x)],
  441. default=0,
  442. )
  443. # Not implemented in FSDP2, ask for user input for FSDP1
  444. if fsdp_version == 1:
  445. fsdp_config["fsdp_forward_prefetch"] = _ask_field(
  446. "Do you want to enable FSDP's forward prefetch policy? [yes/NO]: ",
  447. _convert_yes_no_to_bool,
  448. default=False,
  449. error_message="Please enter yes or no.",
  450. )
  451. # Obsolete in FSDP2, ask for user input for FSDP1
  452. if fsdp_version == 1:
  453. fsdp_config["fsdp_use_orig_params"] = _ask_field(
  454. "Do you want to enable FSDP's `use_orig_params` feature? [YES/no]: ",
  455. _convert_yes_no_to_bool,
  456. default=True,
  457. error_message="Please enter yes or no.",
  458. )
  459. fsdp_config["fsdp_cpu_ram_efficient_loading"] = _ask_field(
  460. "Do you want to enable CPU RAM efficient model loading? Only applicable for 🤗 Transformers models. [YES/no]: ",
  461. _convert_yes_no_to_bool,
  462. default=True,
  463. error_message="Please enter yes or no.",
  464. )
  465. # Obsolete in FSDP2, ask for user input for FSDP1
  466. if fsdp_version == 1:
  467. if fsdp_config["fsdp_cpu_ram_efficient_loading"]:
  468. fsdp_config["fsdp_sync_module_states"] = True
  469. else:
  470. fsdp_config["fsdp_sync_module_states"] = _ask_field(
  471. "Do you want each individually wrapped FSDP unit to broadcast module parameters from rank 0 at the start? [YES/no]: ",
  472. _convert_yes_no_to_bool,
  473. default=True,
  474. error_message="Please enter yes or no.",
  475. )
  476. fsdp_config["fsdp_activation_checkpointing"] = _ask_field(
  477. "Do you want to enable FSDP activation checkpointing? [yes/NO]: ",
  478. _convert_yes_no_to_bool,
  479. default=False,
  480. error_message="Please enter yes or no.",
  481. )
  482. parallelism_config = {}
  483. if fsdp_config.get("fsdp_version", 1) == 2:
  484. use_parallelism_config = _ask_field(
  485. "Do you want to use the parallelism config? [yes/NO]: ",
  486. _convert_yes_no_to_bool,
  487. default=False,
  488. error_message="Please enter yes or no.",
  489. )
  490. if use_parallelism_config:
  491. prefix = "parallelism_config_"
  492. parallelism_config[prefix + "dp_replicate_size"] = _ask_field(
  493. "What is the data parallelism replicate size? [1]: ",
  494. int,
  495. default=1,
  496. error_message="Please enter an integer.",
  497. )
  498. parallelism_config[prefix + "dp_shard_size"] = _ask_field(
  499. "What is the FSDP shard size? [1]: ",
  500. int,
  501. default=1,
  502. error_message="Please enter an integer.",
  503. )
  504. parallelism_config[prefix + "tp_size"] = _ask_field(
  505. "What is the tensor parallelism size? [1]: ",
  506. int,
  507. default=1,
  508. error_message="Please enter an integer.",
  509. )
  510. parallelism_config[prefix + "cp_size"] = _ask_field(
  511. "What is the context parallelism size? [1]: ",
  512. int,
  513. default=1,
  514. error_message="Please enter an integer.",
  515. )
  516. if parallelism_config[prefix + "cp_size"] > 1:
  517. parallelism_config[prefix + "cp_comm_strategy"] = _ask_options(
  518. "What is the compute parallelism communication strategy?",
  519. ["allgather", "alltoall"],
  520. lambda x: ["allgather", "alltoall"][int(x)],
  521. default=0,
  522. )
  523. megatron_lm_config = {}
  524. if distributed_type in [DistributedType.MULTI_GPU]:
  525. use_megatron_lm = _ask_field(
  526. "Do you want to use Megatron-LM ? [yes/NO]: ",
  527. _convert_yes_no_to_bool,
  528. default=False,
  529. error_message="Please enter yes or no.",
  530. )
  531. if use_megatron_lm:
  532. distributed_type = DistributedType.MEGATRON_LM
  533. if distributed_type == DistributedType.MEGATRON_LM:
  534. prefix = "megatron_lm_"
  535. megatron_lm_config[prefix + "tp_degree"] = _ask_field(
  536. "What is the Tensor Parallelism degree/size? [1]:",
  537. int,
  538. default=1,
  539. error_message="Please enter an integer.",
  540. )
  541. if megatron_lm_config[prefix + "tp_degree"] > 1:
  542. megatron_lm_config[prefix + "sequence_parallelism"] = _ask_field(
  543. "Do you want to enable Sequence Parallelism? [YES/no]: ",
  544. _convert_yes_no_to_bool,
  545. default=True,
  546. error_message="Please enter yes or no.",
  547. )
  548. megatron_lm_config[prefix + "pp_degree"] = _ask_field(
  549. "What is the Pipeline Parallelism degree/size? [1]:",
  550. int,
  551. default=1,
  552. error_message="Please enter an integer.",
  553. )
  554. if megatron_lm_config[prefix + "pp_degree"] > 1:
  555. megatron_lm_config[prefix + "num_micro_batches"] = _ask_field(
  556. "What is the number of micro-batches? [1]:",
  557. int,
  558. default=1,
  559. error_message="Please enter an integer.",
  560. )
  561. megatron_lm_config[prefix + "recompute_activations"] = _ask_field(
  562. "Do you want to enable selective activation recomputation? [YES/no]: ",
  563. _convert_yes_no_to_bool,
  564. default=True,
  565. error_message="Please enter yes or no.",
  566. )
  567. megatron_lm_config[prefix + "use_distributed_optimizer"] = _ask_field(
  568. "Do you want to use distributed optimizer "
  569. "which shards optimizer state and gradients across data parallel ranks? [YES/no]: ",
  570. _convert_yes_no_to_bool,
  571. default=True,
  572. error_message="Please enter yes or no.",
  573. )
  574. megatron_lm_config[prefix + "gradient_clipping"] = _ask_field(
  575. "What is the gradient clipping value based on global L2 Norm (0 to disable)? [1.0]: ",
  576. float,
  577. default=1.0,
  578. )
  579. # TPU specific defaults
  580. tpu_commands = None
  581. tpu_command_file = None
  582. tpu_downcast_bf16 = "no"
  583. tpu_env = []
  584. tpu_name = None
  585. tpu_vm = None
  586. tpu_zone = None
  587. tpu_use_sudo = False
  588. tpu_use_cluster = False
  589. if distributed_type in [
  590. DistributedType.MULTI_CPU,
  591. DistributedType.MULTI_XPU,
  592. DistributedType.MULTI_HPU,
  593. DistributedType.MULTI_GPU,
  594. DistributedType.MULTI_MLU,
  595. DistributedType.MULTI_SDAA,
  596. DistributedType.MULTI_MUSA,
  597. DistributedType.MULTI_NPU,
  598. DistributedType.XLA,
  599. ]:
  600. machine_type = str(distributed_type).split(".")[1].replace("MULTI_", "")
  601. if machine_type == "TPU":
  602. machine_type += " cores"
  603. elif machine_type == "CPU":
  604. machine_type = "processes"
  605. else:
  606. machine_type += "(s)"
  607. num_processes = _ask_field(
  608. f"How many {machine_type} should be used for distributed training? [1]:",
  609. int,
  610. default=1,
  611. error_message="Please enter an integer.",
  612. )
  613. elif distributed_type in [DistributedType.FSDP, DistributedType.DEEPSPEED, DistributedType.MEGATRON_LM]:
  614. num_processes = _ask_field(
  615. "How many GPU(s) should be used for distributed training? [1]:",
  616. int,
  617. default=1,
  618. error_message="Please enter an integer.",
  619. )
  620. else:
  621. num_processes = 1
  622. if (distributed_type == DistributedType.MULTI_GPU) and (num_machines == 1) and (num_processes == 1):
  623. raise ValueError(
  624. f"Specified distributed type {distributed_type} but only using 1 GPU on a single machine. Please select `No distributed training` for the type of machine you are using."
  625. )
  626. if (
  627. distributed_type
  628. in [
  629. DistributedType.MULTI_GPU,
  630. DistributedType.MULTI_MLU,
  631. DistributedType.MULTI_SDAA,
  632. DistributedType.MULTI_MUSA,
  633. DistributedType.MULTI_NPU,
  634. DistributedType.MULTI_XPU,
  635. DistributedType.MULTI_HPU,
  636. DistributedType.NO,
  637. ]
  638. and not use_cpu
  639. and not use_mps
  640. ):
  641. if is_npu_available():
  642. machine_type = "NPU(s)"
  643. elif is_mlu_available():
  644. machine_type = "MLU(s)"
  645. elif is_sdaa_available():
  646. machine_type = "SDAA(s)"
  647. elif is_musa_available():
  648. machine_type = "MUSA(s)"
  649. elif is_xpu_available():
  650. machine_type = "XPU(s)"
  651. elif is_hpu_available():
  652. machine_type = "HPU(s)"
  653. else:
  654. machine_type = "GPU(s)"
  655. gpu_ids = _ask_field(
  656. f"What {machine_type} (by id) should be used for training on this machine as a comma-separated list? [all]:",
  657. default="all",
  658. )
  659. # CPU affinity is only supported on NVIDIA hardware for now
  660. enable_cpu_affinity = False
  661. if distributed_type in (DistributedType.NO, DistributedType.MULTI_GPU) and not use_cpu and not use_mps:
  662. enable_cpu_affinity = _ask_field(
  663. "Would you like to enable numa efficiency? (Currently only supported on NVIDIA hardware). [yes/NO]: ",
  664. _convert_yes_no_to_bool,
  665. default=False,
  666. error_message="Please enter yes or no.",
  667. )
  668. fp8_config = None
  669. if distributed_type == DistributedType.XLA:
  670. mixed_precision = "no"
  671. main_training_function = _ask_field(
  672. "What is the name of the function in your script that should be launched in all parallel scripts? [main]: ",
  673. default="main",
  674. )
  675. tpu_use_cluster = _ask_field(
  676. "Are you using a TPU cluster? [yes/NO]: ",
  677. _convert_yes_no_to_bool,
  678. default=False,
  679. error_message="Please enter yes or no.",
  680. )
  681. if tpu_use_cluster:
  682. tpu_name = _ask_field(
  683. "What is the name of your TPU cluster? ",
  684. default=None,
  685. error_message="Please enter the name of your TPU cluster.",
  686. )
  687. tpu_zone = _ask_field(
  688. "What is the zone of your TPU cluster? ",
  689. default=None,
  690. error_message="Please enter the zone of your TPU cluster.",
  691. )
  692. tpu_use_sudo = _ask_field(
  693. "To run a python script in a TPU pod, should `sudo` be used? [yes/NO]: ",
  694. default=False,
  695. error_message="Please enter yes or no.",
  696. )
  697. run_commands = _ask_field(
  698. "Do you have code you wish to run on startup in each pod? [yes/NO]: ",
  699. _convert_yes_no_to_bool,
  700. default=False,
  701. error_message="Please enter yes or no.",
  702. )
  703. if run_commands:
  704. use_command_file = _ask_field(
  705. "Is this code located in a bash script? [yes/NO]: ",
  706. _convert_yes_no_to_bool,
  707. default=False,
  708. error_message="Please enter yes or no.",
  709. )
  710. if use_command_file:
  711. tpu_command_file = _ask_field(
  712. "What is the path to your bash script? ",
  713. default=None,
  714. error_message="Please enter the path to your bash script.",
  715. )
  716. tpu_command_file = os.path.abspath(tpu_command_file)
  717. else:
  718. print("Please enter each command separately you wish to run on startup in each pod.")
  719. tpu_commands = []
  720. another_command = True
  721. while another_command:
  722. tpu_commands.append(
  723. _ask_field(
  724. "Please enter a single command to be ran ",
  725. default=None,
  726. error_message="Please enter the commands you wish to run on startup in each pod as a single string.",
  727. )
  728. )
  729. another_command = _ask_field(
  730. "Do you wish to add another command? [yes/NO]: ",
  731. _convert_yes_no_to_bool,
  732. default=False,
  733. error_message="Please enter yes or no.",
  734. )
  735. tpu_vm = _ask_field(
  736. "If not using an instance group, what are the names of the Compute VM instances to be used, separated by a comma: ",
  737. default="",
  738. ).split(",")
  739. tpu_env = _ask_field(
  740. "What environment variables do you wish to set in each pod, separated by a comma: ",
  741. default="",
  742. ).split(",")
  743. else:
  744. main_training_function = "main"
  745. if distributed_type == DistributedType.DEEPSPEED and use_deepspeed_config:
  746. mixed_precision = None
  747. else:
  748. mixed_precision = _ask_options(
  749. "Do you wish to use mixed precision?",
  750. ["no", "fp16", "bf16", "fp8"],
  751. _convert_mixed_precision,
  752. )
  753. if mixed_precision == "fp8":
  754. if not is_fp8_available():
  755. raise ValueError("FP8 (either Transformer Engine or MSAMP) is not installed on this machine.")
  756. fp8_config = {}
  757. fp8_config["backend"] = _ask_options(
  758. "Which FP8 backend do you want to use?",
  759. ["te", "msamp"],
  760. _convert_fp8_backend,
  761. )
  762. if fp8_config["backend"] == "TE":
  763. if not is_transformer_engine_available():
  764. raise ValueError("TransformersEngine was selected, but it is not installed on this machine.")
  765. fp8_config["use_autocast_during_eval"] = _ask_field(
  766. "Do you want to use FP8 autocast during eval mode? Generally better metrics are found when this is disabled [yes/NO]: ",
  767. _convert_yes_no_to_bool,
  768. default=False,
  769. )
  770. fp8_config["margin"] = _ask_field(
  771. "What margin should be used for gradient scaling? [0]: ",
  772. int,
  773. default=0,
  774. )
  775. fp8_config["interval"] = _ask_field(
  776. "What interval should be used for for how often the scaling factor is recomputed? [1]: ",
  777. int,
  778. default=1,
  779. )
  780. fp8_config["fp8_format"] = _ask_options(
  781. "Which weight format should be used?",
  782. ["HYBRID", "E4M3", "E5M2"],
  783. lambda i: ["HYBRID", "E4M3", "E5M2"][i],
  784. default=0,
  785. )
  786. fp8_config["amax_history_length"] = _ask_field(
  787. "What length of history should be used for the amax scaling factor computation? [1024]: ",
  788. int,
  789. default=1024,
  790. )
  791. fp8_config["amax_compute_algorithm"] = _ask_options(
  792. "Which algorithm should be used for the amax scaling factor computation?",
  793. ["max", "most_recent"],
  794. lambda x: "max" if x == 0 else "most_recent",
  795. default=0,
  796. )
  797. fp8_config["override_linear_precision"] = _ask_field(
  798. "Do you want to to execute `fprop`, `dgrad`, and `wgrad` GEMMS in higher precision? [yes/NO]: ",
  799. _convert_yes_no_to_bool,
  800. default=False,
  801. )
  802. if fp8_config["override_linear_precision"]:
  803. fprop = _ask_field(
  804. "Should `fprop` be executed in higher precision? [yes/NO]: ",
  805. _convert_yes_no_to_bool,
  806. default=False,
  807. )
  808. dgrad = _ask_field(
  809. "Should `dgrad` be executed in higher precision? [yes/NO]: ",
  810. _convert_yes_no_to_bool,
  811. default=False,
  812. )
  813. wgrad = _ask_field(
  814. "Should `wgrad` be executed in higher precision? [yes/NO]: ",
  815. _convert_yes_no_to_bool,
  816. default=False,
  817. )
  818. fp8_config["override_linear_precision"] = (fprop, dgrad, wgrad)
  819. else:
  820. fp8_config["override_linear_precision"] = (False, False, False)
  821. elif fp8_config["backend"] == "MSAMP":
  822. if not is_msamp_available():
  823. raise ValueError("MSAMP was selected, but it is not installed on this machine.")
  824. fp8_config["optimization_level"] = _ask_options(
  825. "Which optimization level should be used?",
  826. ["O1", "O2"],
  827. lambda x: "O1" if x == 0 else "O2",
  828. default=1,
  829. )
  830. if use_dynamo and mixed_precision == "no" and not use_cpu:
  831. print(
  832. "Torch dynamo used without mixed precision requires TF32 to be efficient. Accelerate will enable it by default when launching your scripts."
  833. )
  834. if distributed_type == DistributedType.XLA and mixed_precision == "bf16":
  835. tpu_downcast_bf16 = _ask_field(
  836. "Should `torch.float` be cast as `bfloat16` and `torch.double` remain `float32` on TPUs?", default="no"
  837. )
  838. return ClusterConfig(
  839. compute_environment=ComputeEnvironment.LOCAL_MACHINE,
  840. distributed_type=distributed_type,
  841. num_processes=num_processes,
  842. gpu_ids=gpu_ids,
  843. mixed_precision=mixed_precision,
  844. downcast_bf16=tpu_downcast_bf16,
  845. machine_rank=machine_rank,
  846. num_machines=num_machines,
  847. main_process_ip=main_process_ip,
  848. main_process_port=main_process_port,
  849. main_training_function=main_training_function,
  850. fp8_config=fp8_config,
  851. deepspeed_config=deepspeed_config,
  852. fsdp_config=fsdp_config,
  853. parallelism_config=parallelism_config,
  854. megatron_lm_config=megatron_lm_config,
  855. ipex_config=ipex_config,
  856. mpirun_config=mpirun_config,
  857. use_cpu=use_cpu,
  858. rdzv_backend=rdzv_backend,
  859. same_network=same_network,
  860. commands=tpu_commands,
  861. command_file=tpu_command_file,
  862. tpu_env=tpu_env,
  863. tpu_name=tpu_name,
  864. tpu_vm=tpu_vm,
  865. tpu_zone=tpu_zone,
  866. tpu_use_sudo=tpu_use_sudo,
  867. tpu_use_cluster=tpu_use_cluster,
  868. dynamo_config=dynamo_config,
  869. debug=debug,
  870. enable_cpu_affinity=enable_cpu_affinity,
  871. )