launch.py 34 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693694695696697698699700701702703704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781
  1. # Copyright 2022 The HuggingFace Team. All rights reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. import argparse
  15. import os
  16. import subprocess
  17. import sys
  18. import warnings
  19. from ast import literal_eval
  20. from shutil import which
  21. from typing import Any
  22. import torch
  23. from ..commands.config.config_args import SageMakerConfig
  24. from ..utils import (
  25. DynamoBackend,
  26. PrecisionType,
  27. is_ccl_available,
  28. is_fp8_available,
  29. is_hpu_available,
  30. is_ipex_available,
  31. is_mlu_available,
  32. is_musa_available,
  33. is_npu_available,
  34. is_sdaa_available,
  35. is_torch_xla_available,
  36. is_xpu_available,
  37. )
  38. from ..utils.constants import DEEPSPEED_MULTINODE_LAUNCHERS
  39. from ..utils.other import get_free_port, is_port_in_use, merge_dicts
  40. from ..utils.versions import compare_versions
  41. from .dataclasses import DistributedType, SageMakerDistributedType
  42. def _filter_args(args, parser, default_args=[]):
  43. """
  44. Filters out all `accelerate` specific args
  45. """
  46. new_args, _ = parser.parse_known_args(default_args)
  47. for key, value in vars(args).items():
  48. if key in vars(new_args).keys():
  49. setattr(new_args, key, value)
  50. return new_args
  51. def _get_mpirun_args():
  52. """
  53. Determines the executable and argument names for mpirun, based on the type of install. The supported MPI programs
  54. are: OpenMPI, Intel MPI, or MVAPICH.
  55. Returns: Program name and arg names for hostfile, num processes, and processes per node
  56. """
  57. # Find the MPI program name
  58. mpi_apps = [x for x in ["mpirun", "mpiexec"] if which(x)]
  59. if len(mpi_apps) == 0:
  60. raise OSError("mpirun or mpiexec were not found. Ensure that Intel MPI, Open MPI, or MVAPICH are installed.")
  61. # Call the app with the --version flag to determine which MPI app is installed
  62. mpi_app = mpi_apps[0]
  63. mpirun_version = subprocess.check_output([mpi_app, "--version"])
  64. if b"Open MPI" in mpirun_version:
  65. return mpi_app, "--hostfile", "-n", "--npernode", "--bind-to"
  66. else:
  67. # Intel MPI and MVAPICH both use the same arg names
  68. return mpi_app, "-f", "-n", "-ppn", ""
  69. def setup_fp8_env(args: argparse.Namespace, current_env: dict[str, str]):
  70. """
  71. Setup the FP8 environment variables.
  72. """
  73. prefix = "ACCELERATE_"
  74. for arg in vars(args):
  75. if arg.startswith("fp8_"):
  76. value = getattr(args, arg)
  77. if value is not None:
  78. if arg == "fp8_override_linear_precision":
  79. current_env[prefix + "FP8_OVERRIDE_FPROP"] = str(value[0])
  80. current_env[prefix + "FP8_OVERRIDE_DGRAD"] = str(value[1])
  81. current_env[prefix + "FP8_OVERRIDE_WGRAD"] = str(value[2])
  82. else:
  83. current_env[f"{prefix}{arg.upper()}"] = str(getattr(args, arg))
  84. return current_env
  85. def prepare_simple_launcher_cmd_env(args: argparse.Namespace) -> tuple[list[str], dict[str, str]]:
  86. """
  87. Prepares and returns the command list and an environment with the correct simple launcher environment variables.
  88. """
  89. cmd = []
  90. if args.no_python and args.module:
  91. raise ValueError("--module and --no_python cannot be used together")
  92. num_processes = getattr(args, "num_processes", None)
  93. num_machines = args.num_machines
  94. if args.mpirun_hostfile is not None:
  95. mpi_app_name, hostfile_arg, num_proc_arg, proc_per_node_arg, bind_to_arg = _get_mpirun_args()
  96. bind_to = getattr(args, "bind-to", "socket")
  97. nproc_per_node = str(num_processes // num_machines) if num_processes and num_machines else "1"
  98. cmd += [
  99. mpi_app_name,
  100. hostfile_arg,
  101. args.mpirun_hostfile,
  102. proc_per_node_arg,
  103. nproc_per_node,
  104. ]
  105. if num_processes:
  106. cmd += [num_proc_arg, str(num_processes)]
  107. if bind_to_arg:
  108. cmd += [bind_to_arg, bind_to]
  109. if not args.no_python:
  110. cmd.append(sys.executable)
  111. if args.module:
  112. cmd.append("-m")
  113. cmd.append(args.training_script)
  114. cmd.extend(args.training_script_args)
  115. current_env = os.environ.copy()
  116. current_env["ACCELERATE_USE_CPU"] = str(args.cpu or args.use_cpu)
  117. if args.debug:
  118. current_env["ACCELERATE_DEBUG_MODE"] = "true"
  119. if args.gpu_ids != "all" and args.gpu_ids is not None:
  120. if is_xpu_available():
  121. current_env["ZE_AFFINITY_MASK"] = args.gpu_ids
  122. elif is_mlu_available():
  123. current_env["MLU_VISIBLE_DEVICES"] = args.gpu_ids
  124. elif is_sdaa_available():
  125. current_env["SDAA_VISIBLE_DEVICES"] = args.gpu_ids
  126. elif is_musa_available():
  127. current_env["MUSA_VISIBLE_DEVICES"] = args.gpu_ids
  128. elif is_npu_available():
  129. current_env["ASCEND_RT_VISIBLE_DEVICES"] = args.gpu_ids
  130. elif is_hpu_available():
  131. current_env["HABANA_VISIBLE_MODULES"] = args.gpu_ids
  132. else:
  133. current_env["CUDA_VISIBLE_DEVICES"] = args.gpu_ids
  134. if num_machines > 1:
  135. assert args.main_process_ip is not None, (
  136. "When using multiple machines, you need to specify the main process IP."
  137. )
  138. assert args.main_process_port is not None, (
  139. "When using multiple machines, you need to specify the main process port."
  140. )
  141. ccl_worker_count = getattr(args, "mpirun_ccl", 0) if is_ccl_available() else 0
  142. if (num_processes is not None and num_processes > 1) or num_machines > 1:
  143. current_env["MASTER_ADDR"] = args.main_process_ip if args.main_process_ip is not None else "127.0.0.1"
  144. current_env["MASTER_PORT"] = str(args.main_process_port) if args.main_process_port is not None else "29500"
  145. current_env["CCL_WORKER_COUNT"] = str(ccl_worker_count)
  146. if current_env["ACCELERATE_USE_CPU"]:
  147. current_env["KMP_AFFINITY"] = "granularity=fine,compact,1,0"
  148. current_env["KMP_BLOCKTIME"] = str(1)
  149. try:
  150. mixed_precision = PrecisionType(args.mixed_precision.lower())
  151. except ValueError:
  152. raise ValueError(
  153. f"Unknown mixed_precision mode: {args.mixed_precision.lower()}. Choose between {PrecisionType.list()}."
  154. )
  155. current_env["ACCELERATE_MIXED_PRECISION"] = str(mixed_precision)
  156. if args.mixed_precision.lower() == "fp8":
  157. if not is_fp8_available():
  158. raise RuntimeError(
  159. "FP8 is not available on this machine. Please ensure that either Transformer Engine, MSAMP or torchao is installed."
  160. )
  161. current_env = setup_fp8_env(args, current_env)
  162. try:
  163. dynamo_backend = DynamoBackend(args.dynamo_backend.upper())
  164. except ValueError:
  165. raise ValueError(
  166. f"Unknown dynamo backend: {args.dynamo_backend.upper()}. Choose between {DynamoBackend.list()}."
  167. )
  168. current_env["ACCELERATE_DYNAMO_BACKEND"] = dynamo_backend.value
  169. current_env["ACCELERATE_DYNAMO_MODE"] = args.dynamo_mode
  170. current_env["ACCELERATE_DYNAMO_USE_FULLGRAPH"] = str(args.dynamo_use_fullgraph)
  171. current_env["ACCELERATE_DYNAMO_USE_DYNAMIC"] = str(args.dynamo_use_dynamic)
  172. current_env["ACCELERATE_DYNAMO_USE_REGIONAL_COMPILATION"] = str(args.dynamo_use_regional_compilation)
  173. current_env["OMP_NUM_THREADS"] = str(args.num_cpu_threads_per_process)
  174. if is_ipex_available():
  175. current_env["ACCELERATE_USE_IPEX"] = str(args.ipex).lower()
  176. if args.enable_cpu_affinity:
  177. current_env["ACCELERATE_CPU_AFFINITY"] = "1"
  178. return cmd, current_env
  179. def prepare_multi_gpu_env(args: argparse.Namespace) -> dict[str, str]:
  180. """
  181. Prepares and returns an environment with the correct multi-GPU environment variables.
  182. """
  183. # get free port and update configurations
  184. if args.main_process_port == 0:
  185. args.main_process_port = get_free_port()
  186. elif args.main_process_port is None:
  187. args.main_process_port = 29500
  188. num_processes = args.num_processes
  189. num_machines = args.num_machines
  190. main_process_ip = args.main_process_ip
  191. main_process_port = args.main_process_port
  192. if num_machines > 1:
  193. args.nproc_per_node = str(num_processes // num_machines)
  194. args.nnodes = str(num_machines)
  195. args.node_rank = int(args.machine_rank)
  196. if getattr(args, "same_network", False):
  197. args.master_addr = str(main_process_ip)
  198. args.master_port = str(main_process_port)
  199. else:
  200. args.rdzv_endpoint = f"{main_process_ip}:{main_process_port}"
  201. else:
  202. args.nproc_per_node = str(num_processes)
  203. if main_process_port is not None:
  204. args.master_port = str(main_process_port)
  205. # only need to check port availability in main process, in case we have to start multiple launchers on the same machine
  206. # for some reasons like splitting log files.
  207. need_port_check = num_machines <= 1 or int(args.machine_rank) == 0
  208. if need_port_check and is_port_in_use(main_process_port):
  209. if num_machines <= 1:
  210. args.standalone = True
  211. warnings.warn(
  212. f"Port `{main_process_port}` is already in use. "
  213. "Accelerate will attempt to launch in a standalone-like mode by finding an open port automatically for this session. "
  214. "If this current attempt fails, or for more control in future runs, please specify a different port "
  215. "(e.g., `--main_process_port <your_chosen_port>`) or use `--main_process_port 0` for automatic selection "
  216. "in your launch command or Accelerate config file."
  217. )
  218. else:
  219. raise ConnectionError(
  220. f"Tried to launch distributed communication on port `{main_process_port}`, but another process is utilizing it. "
  221. "Please specify a different port (such as using the `--main_process_port` flag or specifying a different `main_process_port` in your config file)"
  222. " and rerun your script. To automatically use the next open port (on a single node), you can set this to `0`."
  223. )
  224. if args.module and args.no_python:
  225. raise ValueError("--module and --no_python cannot be used together")
  226. elif args.module:
  227. args.module = True
  228. elif args.no_python:
  229. args.no_python = True
  230. current_env = os.environ.copy()
  231. if args.debug:
  232. current_env["ACCELERATE_DEBUG_MODE"] = "true"
  233. gpu_ids = getattr(args, "gpu_ids", "all")
  234. if gpu_ids != "all" and args.gpu_ids is not None:
  235. if is_xpu_available():
  236. current_env["ZE_AFFINITY_MASK"] = gpu_ids
  237. elif is_mlu_available():
  238. current_env["MLU_VISIBLE_DEVICES"] = gpu_ids
  239. elif is_sdaa_available():
  240. current_env["SDAA_VISIBLE_DEVICES"] = gpu_ids
  241. elif is_musa_available():
  242. current_env["MUSA_VISIBLE_DEVICES"] = gpu_ids
  243. elif is_npu_available():
  244. current_env["ASCEND_RT_VISIBLE_DEVICES"] = gpu_ids
  245. elif is_hpu_available():
  246. current_env["HABANA_VISIBLE_MODULES"] = gpu_ids
  247. else:
  248. current_env["CUDA_VISIBLE_DEVICES"] = gpu_ids
  249. mixed_precision = args.mixed_precision.lower()
  250. try:
  251. mixed_precision = PrecisionType(mixed_precision)
  252. except ValueError:
  253. raise ValueError(f"Unknown mixed_precision mode: {mixed_precision}. Choose between {PrecisionType.list()}.")
  254. current_env["ACCELERATE_MIXED_PRECISION"] = str(mixed_precision)
  255. if args.mixed_precision.lower() == "fp8":
  256. if not is_fp8_available():
  257. raise RuntimeError(
  258. "FP8 is not available on this machine. Please ensure that either Transformer Engine, MSAMP or torchao is installed."
  259. )
  260. current_env = setup_fp8_env(args, current_env)
  261. try:
  262. dynamo_backend = DynamoBackend(args.dynamo_backend.upper())
  263. except ValueError:
  264. raise ValueError(
  265. f"Unknown dynamo backend: {args.dynamo_backend.upper()}. Choose between {DynamoBackend.list()}."
  266. )
  267. current_env["ACCELERATE_DYNAMO_BACKEND"] = dynamo_backend.value
  268. current_env["ACCELERATE_DYNAMO_MODE"] = args.dynamo_mode
  269. current_env["ACCELERATE_DYNAMO_USE_FULLGRAPH"] = str(args.dynamo_use_fullgraph)
  270. current_env["ACCELERATE_DYNAMO_USE_DYNAMIC"] = str(args.dynamo_use_dynamic)
  271. current_env["ACCELERATE_DYNAMO_USE_REGIONAL_COMPILATION"] = str(args.dynamo_use_regional_compilation)
  272. if args.use_fsdp:
  273. current_env["ACCELERATE_USE_FSDP"] = "true"
  274. if args.fsdp_cpu_ram_efficient_loading and not args.fsdp_sync_module_states:
  275. raise ValueError("When using `--fsdp_cpu_ram_efficient_loading` set `--fsdp_sync_module_states` to `True`")
  276. current_env["FSDP_VERSION"] = str(args.fsdp_version) if hasattr(args, "fsdp_version") else "1"
  277. # For backwards compatibility, we support this in launched scripts,
  278. # however, we do not ask users for this in `accelerate config` CLI
  279. current_env["FSDP_SHARDING_STRATEGY"] = str(args.fsdp_sharding_strategy)
  280. current_env["FSDP_RESHARD_AFTER_FORWARD"] = str(args.fsdp_reshard_after_forward).lower()
  281. current_env["FSDP_OFFLOAD_PARAMS"] = str(args.fsdp_offload_params).lower()
  282. current_env["FSDP_MIN_NUM_PARAMS"] = str(args.fsdp_min_num_params)
  283. if args.fsdp_auto_wrap_policy is not None:
  284. current_env["FSDP_AUTO_WRAP_POLICY"] = str(args.fsdp_auto_wrap_policy)
  285. if args.fsdp_transformer_layer_cls_to_wrap is not None:
  286. current_env["FSDP_TRANSFORMER_CLS_TO_WRAP"] = str(args.fsdp_transformer_layer_cls_to_wrap)
  287. if args.fsdp_backward_prefetch is not None:
  288. current_env["FSDP_BACKWARD_PREFETCH"] = str(args.fsdp_backward_prefetch)
  289. if args.fsdp_state_dict_type is not None:
  290. current_env["FSDP_STATE_DICT_TYPE"] = str(args.fsdp_state_dict_type)
  291. current_env["FSDP_FORWARD_PREFETCH"] = str(args.fsdp_forward_prefetch).lower()
  292. current_env["FSDP_USE_ORIG_PARAMS"] = str(args.fsdp_use_orig_params).lower()
  293. current_env["FSDP_CPU_RAM_EFFICIENT_LOADING"] = str(args.fsdp_cpu_ram_efficient_loading).lower()
  294. current_env["FSDP_SYNC_MODULE_STATES"] = str(args.fsdp_sync_module_states).lower()
  295. current_env["FSDP_ACTIVATION_CHECKPOINTING"] = str(args.fsdp_activation_checkpointing).lower()
  296. if getattr(args, "fsdp_ignored_modules", None) is not None:
  297. current_env["FSDP_IGNORED_MODULES"] = str(args.fsdp_ignored_modules)
  298. if args.use_megatron_lm:
  299. prefix = "MEGATRON_LM_"
  300. current_env["ACCELERATE_USE_MEGATRON_LM"] = "true"
  301. current_env[prefix + "TP_DEGREE"] = str(args.megatron_lm_tp_degree)
  302. current_env[prefix + "PP_DEGREE"] = str(args.megatron_lm_pp_degree)
  303. current_env[prefix + "GRADIENT_CLIPPING"] = str(args.megatron_lm_gradient_clipping)
  304. if args.megatron_lm_num_micro_batches is not None:
  305. current_env[prefix + "NUM_MICRO_BATCHES"] = str(args.megatron_lm_num_micro_batches)
  306. if args.megatron_lm_sequence_parallelism is not None:
  307. current_env[prefix + "SEQUENCE_PARALLELISM"] = str(args.megatron_lm_sequence_parallelism)
  308. if args.megatron_lm_recompute_activations is not None:
  309. current_env[prefix + "RECOMPUTE_ACTIVATIONS"] = str(args.megatron_lm_recompute_activations)
  310. if args.megatron_lm_use_distributed_optimizer is not None:
  311. current_env[prefix + "USE_DISTRIBUTED_OPTIMIZER"] = str(args.megatron_lm_use_distributed_optimizer)
  312. current_env["OMP_NUM_THREADS"] = str(args.num_cpu_threads_per_process)
  313. if args.enable_cpu_affinity:
  314. current_env["ACCELERATE_CPU_AFFINITY"] = "1"
  315. if args.use_parallelism_config:
  316. current_env = prepare_extend_env_parallelism_config(args, current_env)
  317. return current_env
  318. def prepare_extend_env_parallelism_config(
  319. args: argparse.Namespace, current_env: dict
  320. ) -> tuple[list[str], dict[str, str]]:
  321. """
  322. Extends `current_env` with context parallelism env vars if any have been set
  323. """
  324. prefix = "PARALLELISM_CONFIG_"
  325. current_env["ACCELERATE_USE_PARALLELISM_CONFIG"] = "true"
  326. current_env[prefix + "DP_REPLICATE_SIZE"] = str(args.parallelism_config_dp_replicate_size)
  327. current_env[prefix + "DP_SHARD_SIZE"] = str(args.parallelism_config_dp_shard_size)
  328. current_env[prefix + "TP_SIZE"] = str(args.parallelism_config_tp_size)
  329. current_env[prefix + "CP_SIZE"] = str(args.parallelism_config_cp_size)
  330. current_env[prefix + "CP_BACKEND"] = str(args.parallelism_config_cp_backend)
  331. current_env[prefix + "SP_SIZE"] = str(args.parallelism_config_sp_size)
  332. current_env[prefix + "SP_BACKEND"] = str(args.parallelism_config_sp_backend)
  333. if args.parallelism_config_cp_size > 1:
  334. current_env[prefix + "CP_COMM_STRATEGY"] = str(args.parallelism_config_cp_comm_strategy)
  335. if args.parallelism_config_sp_size > 1:
  336. current_env[prefix + "SP_SEQ_LENGTH"] = str(args.parallelism_config_sp_seq_length)
  337. current_env[prefix + "SP_SEQ_LENGTH_IS_VARIABLE"] = str(args.parallelism_config_sp_seq_length_is_variable)
  338. current_env[prefix + "SP_ATTN_IMPLEMENTATION"] = str(args.parallelism_config_sp_attn_implementation)
  339. return current_env
  340. def prepare_deepspeed_cmd_env(args: argparse.Namespace) -> tuple[list[str], dict[str, str]]:
  341. """
  342. Prepares and returns the command list and an environment with the correct DeepSpeed environment variables.
  343. """
  344. # get free port and update configurations
  345. if args.main_process_port == 0:
  346. args.main_process_port = get_free_port()
  347. elif args.main_process_port is None:
  348. args.main_process_port = 29500
  349. num_processes = args.num_processes
  350. num_machines = args.num_machines
  351. main_process_ip = args.main_process_ip
  352. main_process_port = args.main_process_port
  353. cmd = None
  354. # make sure launcher is not None
  355. if args.deepspeed_multinode_launcher is None:
  356. # set to default pdsh
  357. args.deepspeed_multinode_launcher = DEEPSPEED_MULTINODE_LAUNCHERS[0]
  358. if num_machines > 1 and args.deepspeed_multinode_launcher != DEEPSPEED_MULTINODE_LAUNCHERS[1]:
  359. cmd = ["deepspeed"]
  360. cmd.extend(["--hostfile", str(args.deepspeed_hostfile)])
  361. if args.deepspeed_multinode_launcher == "nossh":
  362. if compare_versions("deepspeed", "<", "0.14.5"):
  363. raise ValueError("nossh launcher requires DeepSpeed >= 0.14.5")
  364. cmd.extend(["--node_rank", str(args.machine_rank), "--no_ssh"])
  365. else:
  366. cmd.extend(["--no_local_rank", "--launcher", str(args.deepspeed_multinode_launcher)])
  367. if args.deepspeed_exclusion_filter is not None:
  368. cmd.extend(
  369. [
  370. "--exclude",
  371. str(args.deepspeed_exclusion_filter),
  372. ]
  373. )
  374. elif args.deepspeed_inclusion_filter is not None:
  375. cmd.extend(
  376. [
  377. "--include",
  378. str(args.deepspeed_inclusion_filter),
  379. ]
  380. )
  381. else:
  382. cmd.extend(["--num_gpus", str(args.num_processes // args.num_machines)])
  383. if main_process_ip:
  384. cmd.extend(["--master_addr", str(main_process_ip)])
  385. cmd.extend(["--master_port", str(main_process_port)])
  386. if args.module and args.no_python:
  387. raise ValueError("--module and --no_python cannot be used together")
  388. elif args.module:
  389. cmd.append("--module")
  390. elif args.no_python:
  391. cmd.append("--no_python")
  392. cmd.append(args.training_script)
  393. cmd.extend(args.training_script_args)
  394. elif num_machines > 1 and args.deepspeed_multinode_launcher == DEEPSPEED_MULTINODE_LAUNCHERS[1]:
  395. args.nproc_per_node = str(num_processes // num_machines)
  396. args.nnodes = str(num_machines)
  397. args.node_rank = int(args.machine_rank)
  398. if getattr(args, "same_network", False):
  399. args.master_addr = str(main_process_ip)
  400. args.master_port = str(main_process_port)
  401. else:
  402. args.rdzv_endpoint = f"{main_process_ip}:{main_process_port}"
  403. else:
  404. args.nproc_per_node = str(num_processes)
  405. if main_process_port is not None:
  406. args.master_port = str(main_process_port)
  407. # only need to check port availability in main process, in case we have to start multiple launchers on the same machine
  408. # for some reasons like splitting log files.
  409. need_port_check = num_machines <= 1 or int(args.machine_rank) == 0
  410. if need_port_check and is_port_in_use(main_process_port):
  411. if num_machines <= 1:
  412. args.standalone = True
  413. warnings.warn(
  414. f"Port `{main_process_port}` is already in use. "
  415. "Accelerate will attempt to launch in a standalone-like mode by finding an open port automatically for this session. "
  416. "If this current attempt fails, or for more control in future runs, please specify a different port "
  417. "(e.g., `--main_process_port <your_chosen_port>`) or use `--main_process_port 0` for automatic selection "
  418. "in your launch command or Accelerate config file."
  419. )
  420. else:
  421. raise ConnectionError(
  422. f"Tried to launch distributed communication on port `{main_process_port}`, but another process is utilizing it. "
  423. "Please specify a different port (such as using the `--main_process_port` flag or specifying a different `main_process_port` in your config file)"
  424. " and rerun your script. To automatically use the next open port (on a single node), you can set this to `0`."
  425. )
  426. if args.module and args.no_python:
  427. raise ValueError("--module and --no_python cannot be used together")
  428. elif args.module:
  429. args.module = True
  430. elif args.no_python:
  431. args.no_python = True
  432. current_env = os.environ.copy()
  433. if args.debug:
  434. current_env["ACCELERATE_DEBUG_MODE"] = "true"
  435. gpu_ids = getattr(args, "gpu_ids", "all")
  436. if gpu_ids != "all" and args.gpu_ids is not None:
  437. if is_xpu_available():
  438. current_env["ZE_AFFINITY_MASK"] = gpu_ids
  439. elif is_mlu_available():
  440. current_env["MLU_VISIBLE_DEVICES"] = gpu_ids
  441. elif is_sdaa_available():
  442. current_env["SDAA_VISIBLE_DEVICES"] = gpu_ids
  443. elif is_musa_available():
  444. current_env["MUSA_VISIBLE_DEVICES"] = gpu_ids
  445. elif is_npu_available():
  446. current_env["ASCEND_RT_VISIBLE_DEVICES"] = gpu_ids
  447. elif is_hpu_available():
  448. current_env["HABANA_VISIBLE_MODULES"] = gpu_ids
  449. else:
  450. current_env["CUDA_VISIBLE_DEVICES"] = gpu_ids
  451. try:
  452. mixed_precision = PrecisionType(args.mixed_precision.lower())
  453. except ValueError:
  454. raise ValueError(
  455. f"Unknown mixed_precision mode: {args.mixed_precision.lower()}. Choose between {PrecisionType.list()}."
  456. )
  457. current_env["PYTHONPATH"] = env_var_path_add("PYTHONPATH", os.path.abspath("."))
  458. current_env["ACCELERATE_MIXED_PRECISION"] = str(mixed_precision)
  459. if args.mixed_precision.lower() == "fp8":
  460. if not is_fp8_available():
  461. raise RuntimeError(
  462. "FP8 is not available on this machine. Please ensure that either Transformer Engine, MSAMP or torchao is installed."
  463. )
  464. current_env = setup_fp8_env(args, current_env)
  465. current_env["ACCELERATE_CONFIG_DS_FIELDS"] = str(args.deepspeed_fields_from_accelerate_config).lower()
  466. current_env["ACCELERATE_USE_DEEPSPEED"] = "true"
  467. if args.zero_stage is not None:
  468. current_env["ACCELERATE_DEEPSPEED_ZERO_STAGE"] = str(args.zero_stage)
  469. if args.gradient_accumulation_steps is not None:
  470. current_env["ACCELERATE_GRADIENT_ACCUMULATION_STEPS"] = str(args.gradient_accumulation_steps)
  471. if args.gradient_clipping is not None:
  472. current_env["ACCELERATE_GRADIENT_CLIPPING"] = str(args.gradient_clipping).lower()
  473. if args.offload_optimizer_device is not None:
  474. current_env["ACCELERATE_DEEPSPEED_OFFLOAD_OPTIMIZER_DEVICE"] = str(args.offload_optimizer_device).lower()
  475. if args.offload_param_device is not None:
  476. current_env["ACCELERATE_DEEPSPEED_OFFLOAD_PARAM_DEVICE"] = str(args.offload_param_device).lower()
  477. if args.zero3_init_flag is not None:
  478. current_env["ACCELERATE_DEEPSPEED_ZERO3_INIT"] = str(args.zero3_init_flag).lower()
  479. if args.zero3_save_16bit_model is not None:
  480. current_env["ACCELERATE_DEEPSPEED_ZERO3_SAVE_16BIT_MODEL"] = str(args.zero3_save_16bit_model).lower()
  481. if args.deepspeed_config_file is not None:
  482. current_env["ACCELERATE_DEEPSPEED_CONFIG_FILE"] = str(args.deepspeed_config_file)
  483. if args.enable_cpu_affinity:
  484. current_env["ACCELERATE_CPU_AFFINITY"] = "1"
  485. if args.deepspeed_moe_layer_cls_names is not None:
  486. current_env["ACCELERATE_DEEPSPEED_MOE_LAYER_CLS_NAMES"] = str(args.deepspeed_moe_layer_cls_names)
  487. if args.use_parallelism_config:
  488. current_env = prepare_extend_env_parallelism_config(args, current_env)
  489. return cmd, current_env
  490. def prepare_tpu(
  491. args: argparse.Namespace, current_env: dict[str, str], pod: bool = False
  492. ) -> tuple[argparse.Namespace, dict[str, str]]:
  493. """
  494. Prepares and returns an environment with the correct TPU environment variables.
  495. """
  496. if args.mixed_precision == "bf16" and is_torch_xla_available(check_is_tpu=True):
  497. if args.downcast_bf16:
  498. current_env["XLA_DOWNCAST_BF16"] = "1"
  499. else:
  500. current_env["XLA_USE_BF16"] = "1"
  501. if args.debug:
  502. current_env["ACCELERATE_DEBUG_MODE"] = "true"
  503. if pod:
  504. # Take explicit args and set them up for XLA
  505. args.vm = args.tpu_vm
  506. args.tpu = args.tpu_name
  507. return args, current_env
  508. def _convert_nargs_to_dict(nargs: list[str]) -> dict[str, str]:
  509. if len(nargs) < 0:
  510. return {}
  511. # helper function to infer type for argsparser
  512. def _infer_type(s):
  513. try:
  514. s = float(s)
  515. if s // 1 == s:
  516. return int(s)
  517. return s
  518. except ValueError:
  519. return s
  520. parser = argparse.ArgumentParser()
  521. _, unknown = parser.parse_known_args(nargs)
  522. for index, argument in enumerate(unknown):
  523. if argument.startswith(("-", "--")):
  524. action = None
  525. if index + 1 < len(unknown): # checks if next index would be in list
  526. if unknown[index + 1].startswith(("-", "--")): # checks if next element is an key
  527. # raise an error if element is store_true or store_false
  528. raise ValueError(
  529. "SageMaker doesn’t support argparse actions for `store_true` or `store_false`. Please define explicit types"
  530. )
  531. else: # raise an error if last element is store_true or store_false
  532. raise ValueError(
  533. "SageMaker doesn’t support argparse actions for `store_true` or `store_false`. Please define explicit types"
  534. )
  535. # adds argument to parser based on action_store true
  536. if action is None:
  537. parser.add_argument(argument, type=_infer_type)
  538. else:
  539. parser.add_argument(argument, action=action)
  540. return {
  541. key: (literal_eval(value) if value in ("True", "False") else value)
  542. for key, value in parser.parse_args(nargs).__dict__.items()
  543. }
  544. def prepare_sagemager_args_inputs(
  545. sagemaker_config: SageMakerConfig, args: argparse.Namespace
  546. ) -> tuple[argparse.Namespace, dict[str, Any]]:
  547. # configure environment
  548. print("Configuring Amazon SageMaker environment")
  549. os.environ["AWS_DEFAULT_REGION"] = sagemaker_config.region
  550. # configure credentials
  551. if sagemaker_config.profile is not None:
  552. os.environ["AWS_PROFILE"] = sagemaker_config.profile
  553. elif args.aws_access_key_id is not None and args.aws_secret_access_key is not None:
  554. os.environ["AWS_ACCESS_KEY_ID"] = args.aws_access_key_id
  555. os.environ["AWS_SECRET_ACCESS_KEY"] = args.aws_secret_access_key
  556. else:
  557. raise OSError("You need to provide an aws_access_key_id and aws_secret_access_key when not using aws_profile")
  558. # extract needed arguments
  559. source_dir = os.path.dirname(args.training_script)
  560. if not source_dir: # checks if string is empty
  561. source_dir = "."
  562. entry_point = os.path.basename(args.training_script)
  563. if not entry_point.endswith(".py"):
  564. raise ValueError(f'Your training script should be a python script and not "{entry_point}"')
  565. print("Converting Arguments to Hyperparameters")
  566. hyperparameters = _convert_nargs_to_dict(args.training_script_args)
  567. try:
  568. mixed_precision = PrecisionType(args.mixed_precision.lower())
  569. except ValueError:
  570. raise ValueError(
  571. f"Unknown mixed_precision mode: {args.mixed_precision.lower()}. Choose between {PrecisionType.list()}."
  572. )
  573. try:
  574. dynamo_backend = DynamoBackend(args.dynamo_backend.upper())
  575. except ValueError:
  576. raise ValueError(
  577. f"Unknown dynamo backend: {args.dynamo_backend.upper()}. Choose between {DynamoBackend.list()}."
  578. )
  579. # Environment variables to be set for use during training job
  580. environment = {
  581. "ACCELERATE_USE_SAGEMAKER": "true",
  582. "ACCELERATE_MIXED_PRECISION": str(mixed_precision),
  583. "ACCELERATE_DYNAMO_BACKEND": dynamo_backend.value,
  584. "ACCELERATE_DYNAMO_MODE": args.dynamo_mode,
  585. "ACCELERATE_DYNAMO_USE_FULLGRAPH": str(args.dynamo_use_fullgraph),
  586. "ACCELERATE_DYNAMO_USE_DYNAMIC": str(args.dynamo_use_dynamic),
  587. "ACCELERATE_DYNAMO_USE_REGIONAL_COMPILATION": str(args.dynamo_use_regional_compilation),
  588. "ACCELERATE_SAGEMAKER_DISTRIBUTED_TYPE": sagemaker_config.distributed_type.value,
  589. }
  590. if args.mixed_precision.lower() == "fp8":
  591. if not is_fp8_available():
  592. raise RuntimeError(
  593. "FP8 is not available on this machine. Please ensure that either Transformer Engine, MSAMP or torchao is installed."
  594. )
  595. environment = setup_fp8_env(args, environment)
  596. # configure distribution set up
  597. distribution = None
  598. if sagemaker_config.distributed_type == SageMakerDistributedType.DATA_PARALLEL:
  599. distribution = {"smdistributed": {"dataparallel": {"enabled": True}}}
  600. # configure sagemaker inputs
  601. sagemaker_inputs = None
  602. if sagemaker_config.sagemaker_inputs_file is not None:
  603. print(f"Loading SageMaker Inputs from {sagemaker_config.sagemaker_inputs_file} file")
  604. sagemaker_inputs = {}
  605. with open(sagemaker_config.sagemaker_inputs_file) as file:
  606. for i, line in enumerate(file):
  607. if i == 0:
  608. continue
  609. l = line.split("\t")
  610. sagemaker_inputs[l[0]] = l[1].strip()
  611. print(f"Loaded SageMaker Inputs: {sagemaker_inputs}")
  612. # configure sagemaker metrics
  613. sagemaker_metrics = None
  614. if sagemaker_config.sagemaker_metrics_file is not None:
  615. print(f"Loading SageMaker Metrics from {sagemaker_config.sagemaker_metrics_file} file")
  616. sagemaker_metrics = []
  617. with open(sagemaker_config.sagemaker_metrics_file) as file:
  618. for i, line in enumerate(file):
  619. if i == 0:
  620. continue
  621. l = line.split("\t")
  622. metric_dict = {
  623. "Name": l[0],
  624. "Regex": l[1].strip(),
  625. }
  626. sagemaker_metrics.append(metric_dict)
  627. print(f"Loaded SageMaker Metrics: {sagemaker_metrics}")
  628. # configure session
  629. print("Creating Estimator")
  630. args = {
  631. "image_uri": sagemaker_config.image_uri,
  632. "entry_point": entry_point,
  633. "source_dir": source_dir,
  634. "role": sagemaker_config.iam_role_name,
  635. "transformers_version": sagemaker_config.transformers_version,
  636. "pytorch_version": sagemaker_config.pytorch_version,
  637. "py_version": sagemaker_config.py_version,
  638. "base_job_name": sagemaker_config.base_job_name,
  639. "instance_count": sagemaker_config.num_machines,
  640. "instance_type": sagemaker_config.ec2_instance_type,
  641. "debugger_hook_config": False,
  642. "distribution": distribution,
  643. "hyperparameters": hyperparameters,
  644. "environment": environment,
  645. "metric_definitions": sagemaker_metrics,
  646. }
  647. if sagemaker_config.additional_args is not None:
  648. args = merge_dicts(sagemaker_config.additional_args, args)
  649. return args, sagemaker_inputs
  650. def env_var_path_add(env_var_name, path_to_add):
  651. """
  652. Extends a path-based environment variable's value with a new path and returns the updated value. It's up to the
  653. caller to set it in os.environ.
  654. """
  655. paths = [p for p in os.environ.get(env_var_name, "").split(":") if len(p) > 0]
  656. paths.append(str(path_to_add))
  657. return ":".join(paths)
  658. class PrepareForLaunch:
  659. """
  660. Prepare a function that will launched in a distributed setup.
  661. Args:
  662. launcher (`Callable`):
  663. The function to launch.
  664. distributed_type ([`~state.DistributedType`]):
  665. The distributed type to prepare for.
  666. debug (`bool`, *optional*, defaults to `False`):
  667. Whether or not this is a debug launch.
  668. """
  669. def __init__(self, launcher, distributed_type="NO", debug=False):
  670. self.launcher = launcher
  671. self.distributed_type = DistributedType(distributed_type)
  672. self.debug = debug
  673. def __call__(self, index, *args):
  674. if self.debug:
  675. world_size = int(os.environ.get("WORLD_SIZE"))
  676. rdv_file = os.environ.get("ACCELERATE_DEBUG_RDV_FILE")
  677. torch.distributed.init_process_group(
  678. "gloo",
  679. rank=index,
  680. store=torch.distributed.FileStore(rdv_file, world_size),
  681. world_size=world_size,
  682. )
  683. elif self.distributed_type in (
  684. DistributedType.MULTI_GPU,
  685. DistributedType.MULTI_MLU,
  686. DistributedType.MULTI_MUSA,
  687. DistributedType.MULTI_NPU,
  688. DistributedType.MULTI_XPU,
  689. DistributedType.MULTI_CPU,
  690. ):
  691. # Prepare the environment for torch.distributed
  692. os.environ["LOCAL_RANK"] = str(index)
  693. nproc = int(os.environ.get("NPROC", 1))
  694. node_rank = int(os.environ.get("NODE_RANK", 0))
  695. os.environ["RANK"] = str(nproc * node_rank + index)
  696. os.environ["FORK_LAUNCHED"] = str(1)
  697. self.launcher(*args)