imports.py 17 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564
  1. # Copyright 2022 The HuggingFace Team. All rights reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. import importlib
  15. import importlib.metadata
  16. import os
  17. import sys
  18. import warnings
  19. from functools import lru_cache, wraps
  20. import torch
  21. from packaging import version
  22. from packaging.version import parse
  23. from .environment import parse_flag_from_env, patch_environment, str_to_bool
  24. from .versions import compare_versions, is_torch_version
  25. # Try to run Torch native job in an environment with TorchXLA installed by setting this value to 0.
  26. USE_TORCH_XLA = parse_flag_from_env("USE_TORCH_XLA", default=True)
  27. _torch_xla_available = False
  28. if USE_TORCH_XLA:
  29. try:
  30. import torch_xla.core.xla_model as xm # noqa: F401
  31. import torch_xla.runtime
  32. _torch_xla_available = True
  33. except ImportError:
  34. pass
  35. # Keep it for is_tpu_available. It will be removed along with is_tpu_available.
  36. _tpu_available = _torch_xla_available
  37. # Cache this result has it's a C FFI call which can be pretty time-consuming
  38. _torch_distributed_available = torch.distributed.is_available()
  39. def _is_package_available(pkg_name, metadata_name=None):
  40. # Check we're not importing a "pkg_name" directory somewhere but the actual library by trying to grab the version
  41. package_exists = importlib.util.find_spec(pkg_name) is not None
  42. if package_exists:
  43. try:
  44. # Some libraries have different names in the metadata
  45. _ = importlib.metadata.metadata(pkg_name if metadata_name is None else metadata_name)
  46. return True
  47. except importlib.metadata.PackageNotFoundError:
  48. return False
  49. def is_torch_distributed_available() -> bool:
  50. return _torch_distributed_available
  51. def is_xccl_available():
  52. if is_torch_version(">=", "2.7.0"):
  53. return torch.distributed.distributed_c10d.is_xccl_available()
  54. if is_ipex_available():
  55. return False
  56. return False
  57. def is_ccl_available():
  58. try:
  59. pass
  60. except ImportError:
  61. print(
  62. "Intel(R) oneCCL Bindings for PyTorch* is required to run DDP on Intel(R) XPUs, but it is not"
  63. " detected. If you see \"ValueError: Invalid backend: 'ccl'\" error, please install Intel(R) oneCCL"
  64. " Bindings for PyTorch*."
  65. )
  66. return importlib.util.find_spec("oneccl_bindings_for_pytorch") is not None
  67. def get_ccl_version():
  68. return importlib.metadata.version("oneccl_bind_pt")
  69. def is_import_timer_available():
  70. return _is_package_available("import_timer")
  71. def is_pynvml_available():
  72. return _is_package_available("pynvml") or _is_package_available("pynvml", "nvidia-ml-py")
  73. def is_pytest_available():
  74. return _is_package_available("pytest")
  75. def is_msamp_available():
  76. return _is_package_available("msamp", "ms-amp")
  77. def is_schedulefree_available():
  78. return _is_package_available("schedulefree")
  79. def is_transformer_engine_available():
  80. if is_hpu_available():
  81. return _is_package_available("intel_transformer_engine", "intel-transformer-engine")
  82. else:
  83. return _is_package_available("transformer_engine", "transformer-engine")
  84. def is_transformer_engine_mxfp8_available():
  85. if _is_package_available("transformer_engine", "transformer-engine"):
  86. import transformer_engine.pytorch as te
  87. return te.fp8.check_mxfp8_support()[0]
  88. return False
  89. def is_lomo_available():
  90. return _is_package_available("lomo_optim")
  91. def is_cuda_available():
  92. """
  93. Checks if `cuda` is available via an `nvml-based` check which won't trigger the drivers and leave cuda
  94. uninitialized.
  95. """
  96. with patch_environment(PYTORCH_NVML_BASED_CUDA_CHECK="1"):
  97. available = torch.cuda.is_available()
  98. return available
  99. @lru_cache
  100. def is_torch_xla_available(check_is_tpu=False, check_is_gpu=False):
  101. """
  102. Check if `torch_xla` is available. To train a native pytorch job in an environment with torch xla installed, set
  103. the USE_TORCH_XLA to false.
  104. """
  105. assert not (check_is_tpu and check_is_gpu), "The check_is_tpu and check_is_gpu cannot both be true."
  106. if not _torch_xla_available:
  107. return False
  108. elif check_is_gpu:
  109. return torch_xla.runtime.device_type() in ["GPU", "CUDA"]
  110. elif check_is_tpu:
  111. return torch_xla.runtime.device_type() == "TPU"
  112. return True
  113. def is_torchao_available():
  114. package_exists = _is_package_available("torchao")
  115. if package_exists:
  116. torchao_version = version.parse(importlib.metadata.version("torchao"))
  117. return compare_versions(torchao_version, ">=", "0.6.1")
  118. return False
  119. def is_deepspeed_available():
  120. return _is_package_available("deepspeed")
  121. def is_pippy_available():
  122. return is_torch_version(">=", "2.4.0")
  123. def is_bf16_available(ignore_tpu=False):
  124. "Checks if bf16 is supported, optionally ignoring the TPU"
  125. if is_torch_xla_available(check_is_tpu=True):
  126. return not ignore_tpu
  127. if is_cuda_available():
  128. return torch.cuda.is_bf16_supported()
  129. if is_mlu_available():
  130. return torch.mlu.is_bf16_supported()
  131. if is_xpu_available():
  132. return torch.xpu.is_bf16_supported()
  133. if is_mps_available():
  134. return torch.backends.mps.is_macos_or_newer(14, 0)
  135. return True
  136. def is_fp16_available():
  137. "Checks if fp16 is supported"
  138. if is_habana_gaudi1():
  139. return False
  140. return True
  141. def is_fp8_available():
  142. "Checks if fp8 is supported"
  143. return is_msamp_available() or is_transformer_engine_available() or is_torchao_available()
  144. def is_4bit_bnb_available():
  145. package_exists = _is_package_available("bitsandbytes")
  146. if package_exists:
  147. bnb_version = version.parse(importlib.metadata.version("bitsandbytes"))
  148. return compare_versions(bnb_version, ">=", "0.39.0")
  149. return False
  150. def is_8bit_bnb_available():
  151. package_exists = _is_package_available("bitsandbytes")
  152. if package_exists:
  153. bnb_version = version.parse(importlib.metadata.version("bitsandbytes"))
  154. return compare_versions(bnb_version, ">=", "0.37.2")
  155. return False
  156. def is_bnb_available(min_version=None):
  157. package_exists = _is_package_available("bitsandbytes")
  158. if package_exists and min_version is not None:
  159. bnb_version = version.parse(importlib.metadata.version("bitsandbytes"))
  160. return compare_versions(bnb_version, ">=", min_version)
  161. else:
  162. return package_exists
  163. def is_bitsandbytes_multi_backend_available():
  164. if not is_bnb_available():
  165. return False
  166. import bitsandbytes as bnb
  167. return "multi_backend" in getattr(bnb, "features", set())
  168. def is_torchvision_available():
  169. return _is_package_available("torchvision")
  170. def is_megatron_lm_available():
  171. if str_to_bool(os.environ.get("ACCELERATE_USE_MEGATRON_LM", "False")) == 1:
  172. if importlib.util.find_spec("megatron") is not None:
  173. try:
  174. megatron_version = parse(importlib.metadata.version("megatron-core"))
  175. if compare_versions(megatron_version, ">=", "0.8.0"):
  176. return importlib.util.find_spec(".training", "megatron")
  177. except Exception as e:
  178. warnings.warn(f"Parse Megatron version failed. Exception:{e}")
  179. return False
  180. def is_transformers_available():
  181. return _is_package_available("transformers")
  182. def is_datasets_available():
  183. return _is_package_available("datasets")
  184. def is_peft_available():
  185. return _is_package_available("peft")
  186. def is_timm_available():
  187. return _is_package_available("timm")
  188. def is_triton_available():
  189. if is_xpu_available():
  190. return _is_package_available("triton", "pytorch-triton-xpu")
  191. return _is_package_available("triton")
  192. def is_aim_available():
  193. package_exists = _is_package_available("aim")
  194. if package_exists:
  195. aim_version = version.parse(importlib.metadata.version("aim"))
  196. return compare_versions(aim_version, "<", "4.0.0")
  197. return False
  198. def is_tensorboard_available():
  199. return _is_package_available("tensorboard") or _is_package_available("tensorboardX")
  200. def is_wandb_available():
  201. return _is_package_available("wandb")
  202. def is_comet_ml_available():
  203. return _is_package_available("comet_ml")
  204. def is_swanlab_available():
  205. return _is_package_available("swanlab")
  206. def is_trackio_available():
  207. return sys.version_info >= (3, 10) and _is_package_available("trackio")
  208. def is_boto3_available():
  209. return _is_package_available("boto3")
  210. def is_rich_available():
  211. if _is_package_available("rich"):
  212. return parse_flag_from_env("ACCELERATE_ENABLE_RICH", False)
  213. return False
  214. def is_sagemaker_available():
  215. return _is_package_available("sagemaker")
  216. def is_tqdm_available():
  217. return _is_package_available("tqdm")
  218. def is_clearml_available():
  219. return _is_package_available("clearml")
  220. def is_pandas_available():
  221. return _is_package_available("pandas")
  222. def is_matplotlib_available():
  223. return _is_package_available("matplotlib")
  224. def is_mlflow_available():
  225. if _is_package_available("mlflow"):
  226. return True
  227. if importlib.util.find_spec("mlflow") is not None:
  228. try:
  229. _ = importlib.metadata.metadata("mlflow-skinny")
  230. return True
  231. except importlib.metadata.PackageNotFoundError:
  232. return False
  233. return False
  234. def is_mps_available(min_version="1.12"):
  235. "Checks if MPS device is available. The minimum version required is 1.12."
  236. # With torch 1.12, you can use torch.backends.mps
  237. # With torch 2.0.0, you can use torch.mps
  238. return is_torch_version(">=", min_version) and torch.backends.mps.is_available() and torch.backends.mps.is_built()
  239. def is_ipex_available():
  240. "Checks if ipex is installed."
  241. def get_major_and_minor_from_version(full_version):
  242. return str(version.parse(full_version).major) + "." + str(version.parse(full_version).minor)
  243. _torch_version = importlib.metadata.version("torch")
  244. if importlib.util.find_spec("intel_extension_for_pytorch") is None:
  245. return False
  246. _ipex_version = "N/A"
  247. try:
  248. _ipex_version = importlib.metadata.version("intel_extension_for_pytorch")
  249. except importlib.metadata.PackageNotFoundError:
  250. return False
  251. torch_major_and_minor = get_major_and_minor_from_version(_torch_version)
  252. ipex_major_and_minor = get_major_and_minor_from_version(_ipex_version)
  253. if torch_major_and_minor != ipex_major_and_minor:
  254. warnings.warn(
  255. f"Intel Extension for PyTorch {ipex_major_and_minor} needs to work with PyTorch {ipex_major_and_minor}.*,"
  256. f" but PyTorch {_torch_version} is found. Please switch to the matching version and run again."
  257. )
  258. return False
  259. return True
  260. @lru_cache
  261. def is_mlu_available(check_device=False):
  262. """
  263. Checks if `mlu` is available via an `cndev-based` check which won't trigger the drivers and leave mlu
  264. uninitialized.
  265. """
  266. if importlib.util.find_spec("torch_mlu") is None:
  267. return False
  268. import torch_mlu # noqa: F401
  269. with patch_environment(PYTORCH_CNDEV_BASED_MLU_CHECK="1"):
  270. available = torch.mlu.is_available()
  271. return available
  272. @lru_cache
  273. def is_musa_available(check_device=False):
  274. "Checks if `torch_musa` is installed and potentially if a MUSA is in the environment"
  275. if importlib.util.find_spec("torch_musa") is None:
  276. return False
  277. import torch_musa # noqa: F401
  278. if check_device:
  279. try:
  280. # Will raise a RuntimeError if no MUSA is found
  281. _ = torch.musa.device_count()
  282. return torch.musa.is_available()
  283. except RuntimeError:
  284. return False
  285. return hasattr(torch, "musa") and torch.musa.is_available()
  286. @lru_cache
  287. def is_npu_available(check_device=False):
  288. "Checks if `torch_npu` is installed and potentially if a NPU is in the environment"
  289. if importlib.util.find_spec("torch_npu") is None:
  290. return False
  291. # NOTE: importing torch_npu may raise error in some envs
  292. # e.g. inside cpu-only container with torch_npu installed
  293. try:
  294. import torch_npu # noqa: F401
  295. except Exception:
  296. return False
  297. if check_device:
  298. try:
  299. # Will raise a RuntimeError if no NPU is found
  300. _ = torch.npu.device_count()
  301. return torch.npu.is_available()
  302. except RuntimeError:
  303. return False
  304. return hasattr(torch, "npu") and torch.npu.is_available()
  305. @lru_cache
  306. def is_sdaa_available(check_device=False):
  307. "Checks if `torch_sdaa` is installed and potentially if a SDAA is in the environment"
  308. if importlib.util.find_spec("torch_sdaa") is None:
  309. return False
  310. import torch_sdaa # noqa: F401
  311. if check_device:
  312. try:
  313. # Will raise a RuntimeError if no NPU is found
  314. _ = torch.sdaa.device_count()
  315. return torch.sdaa.is_available()
  316. except RuntimeError:
  317. return False
  318. return hasattr(torch, "sdaa") and torch.sdaa.is_available()
  319. @lru_cache
  320. def is_hpu_available(init_hccl=False):
  321. "Checks if `torch.hpu` is installed and potentially if a HPU is in the environment"
  322. if (
  323. importlib.util.find_spec("habana_frameworks") is None
  324. or importlib.util.find_spec("habana_frameworks.torch") is None
  325. ):
  326. return False
  327. import habana_frameworks.torch # noqa: F401
  328. if init_hccl:
  329. import habana_frameworks.torch.distributed.hccl as hccl # noqa: F401
  330. return hasattr(torch, "hpu") and torch.hpu.is_available()
  331. def is_habana_gaudi1():
  332. if is_hpu_available():
  333. import habana_frameworks.torch.utils.experimental as htexp # noqa: F401
  334. if htexp._get_device_type() == htexp.synDeviceType.synDeviceGaudi:
  335. return True
  336. return False
  337. @lru_cache
  338. def is_xpu_available(check_device=False):
  339. """
  340. Checks if XPU acceleration is available either via `intel_extension_for_pytorch` or via stock PyTorch (>=2.4) and
  341. potentially if a XPU is in the environment
  342. """
  343. if is_ipex_available():
  344. import intel_extension_for_pytorch # noqa: F401
  345. else:
  346. if is_torch_version("<=", "2.3"):
  347. return False
  348. if check_device:
  349. try:
  350. # Will raise a RuntimeError if no XPU is found
  351. _ = torch.xpu.device_count()
  352. return torch.xpu.is_available()
  353. except RuntimeError:
  354. return False
  355. return hasattr(torch, "xpu") and torch.xpu.is_available()
  356. def is_dvclive_available():
  357. return _is_package_available("dvclive")
  358. def is_torchdata_available():
  359. return _is_package_available("torchdata")
  360. # TODO: Remove this function once stateful_dataloader is a stable feature in torchdata.
  361. def is_torchdata_stateful_dataloader_available():
  362. package_exists = _is_package_available("torchdata")
  363. if package_exists:
  364. torchdata_version = version.parse(importlib.metadata.version("torchdata"))
  365. return compare_versions(torchdata_version, ">=", "0.8.0")
  366. return False
  367. def torchao_required(func):
  368. """
  369. A decorator that ensures the decorated function is only called when torchao is available.
  370. """
  371. @wraps(func)
  372. def wrapper(*args, **kwargs):
  373. if not is_torchao_available():
  374. raise ImportError(
  375. "`torchao` is not available, please install it before calling this function via `pip install torchao`."
  376. )
  377. return func(*args, **kwargs)
  378. return wrapper
  379. # TODO: Rework this into `utils.deepspeed` and migrate the "core" chunks into `accelerate.deepspeed`
  380. def deepspeed_required(func):
  381. """
  382. A decorator that ensures the decorated function is only called when deepspeed is enabled.
  383. """
  384. @wraps(func)
  385. def wrapper(*args, **kwargs):
  386. from accelerate.state import AcceleratorState
  387. from accelerate.utils.dataclasses import DistributedType
  388. if AcceleratorState._shared_state != {} and AcceleratorState().distributed_type != DistributedType.DEEPSPEED:
  389. raise ValueError(
  390. "DeepSpeed is not enabled, please make sure that an `Accelerator` is configured for `deepspeed` "
  391. "before calling this function."
  392. )
  393. return func(*args, **kwargs)
  394. return wrapper
  395. def is_weights_only_available():
  396. # Weights only with allowlist was added in 2.4.0
  397. # ref: https://github.com/pytorch/pytorch/pull/124331
  398. return is_torch_version(">=", "2.4.0")
  399. def is_numpy_available(min_version="1.25.0"):
  400. numpy_version = parse(importlib.metadata.version("numpy"))
  401. return compare_versions(numpy_version, ">=", min_version)