| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564 |
- # Copyright 2022 The HuggingFace Team. All rights reserved.
- #
- # Licensed under the Apache License, Version 2.0 (the "License");
- # you may not use this file except in compliance with the License.
- # You may obtain a copy of the License at
- #
- # http://www.apache.org/licenses/LICENSE-2.0
- #
- # Unless required by applicable law or agreed to in writing, software
- # distributed under the License is distributed on an "AS IS" BASIS,
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- # See the License for the specific language governing permissions and
- # limitations under the License.
- import importlib
- import importlib.metadata
- import os
- import sys
- import warnings
- from functools import lru_cache, wraps
- import torch
- from packaging import version
- from packaging.version import parse
- from .environment import parse_flag_from_env, patch_environment, str_to_bool
- from .versions import compare_versions, is_torch_version
# Set USE_TORCH_XLA=0 to run a native PyTorch job in an environment where torch_xla is installed.
USE_TORCH_XLA = parse_flag_from_env("USE_TORCH_XLA", default=True)

_torch_xla_available = False
if USE_TORCH_XLA:
    try:
        import torch_xla.core.xla_model as xm  # noqa: F401
        import torch_xla.runtime
    except ImportError:
        # torch_xla is optional: leave the flag False when it cannot be imported.
        pass
    else:
        _torch_xla_available = True

# Retained only for `is_tpu_available`; it will be removed together with that function.
_tpu_available = _torch_xla_available

# Cache this result as it's a C FFI call, which can be pretty time-consuming.
_torch_distributed_available = torch.distributed.is_available()
def _is_package_available(pkg_name, metadata_name=None):
    """
    Check whether `pkg_name` is importable *and* backed by installed distribution metadata.

    Args:
        pkg_name (`str`): Importable module name, looked up with `importlib.util.find_spec`.
        metadata_name (`str`, *optional*): Distribution name to look up with `importlib.metadata.metadata` when it
            differs from `pkg_name` (e.g. `"ms-amp"` for the `msamp` module). Defaults to `pkg_name`.

    Returns:
        `bool`: `True` when both lookups succeed, `False` otherwise (never `None`).
    """
    # A bare `find_spec` hit could be a stray local directory named `pkg_name`; requiring distribution
    # metadata ensures the actual installed library was found.
    if importlib.util.find_spec(pkg_name) is None:
        return False
    try:
        importlib.metadata.metadata(pkg_name if metadata_name is None else metadata_name)
    except importlib.metadata.PackageNotFoundError:
        return False
    return True
def is_torch_distributed_available() -> bool:
    """Report whether `torch.distributed` is usable, using the value computed once at import time."""
    cached_result = _torch_distributed_available
    return cached_result
def is_xccl_available():
    """
    Check whether the XCCL distributed backend is available.

    The check is delegated to `torch.distributed.distributed_c10d.is_xccl_available`, which is only queried on
    PyTorch >= 2.7.0. On older PyTorch versions this always returns `False`.
    """
    if is_torch_version(">=", "2.7.0"):
        return torch.distributed.distributed_c10d.is_xccl_available()
    # No extra probing is needed here: every remaining path yields False. Calling `is_ipex_available()` would
    # only add the chance of an unrelated IPEX version-mismatch warning.
    return False
def is_ccl_available():
    """
    Check whether Intel(R) oneCCL Bindings for PyTorch* (`oneccl_bindings_for_pytorch`) can be found.

    Only the module spec is looked up; the bindings are not imported.

    Returns:
        `bool`: `True` if `oneccl_bindings_for_pytorch` is findable on the import path, `False` otherwise.
    """
    return importlib.util.find_spec("oneccl_bindings_for_pytorch") is not None
def get_ccl_version():
    """Return the installed version string of the `oneccl_bind_pt` distribution."""
    distribution = "oneccl_bind_pt"
    return importlib.metadata.version(distribution)
def is_import_timer_available():
    """Check whether the `import_timer` package is installed."""
    return _is_package_available(pkg_name="import_timer")
def is_pynvml_available():
    """Check for the `pynvml` module, provided either by the `pynvml` distribution or by `nvidia-ml-py`."""
    found = _is_package_available("pynvml")
    if not found:
        found = _is_package_available("pynvml", metadata_name="nvidia-ml-py")
    return found
def is_pytest_available():
    """Check whether the `pytest` package is installed."""
    return _is_package_available(pkg_name="pytest")
def is_msamp_available():
    """Check whether MS-AMP is installed (module `msamp`, distribution `ms-amp`)."""
    return _is_package_available("msamp", metadata_name="ms-amp")
def is_schedulefree_available():
    """Check whether the `schedulefree` package is installed."""
    return _is_package_available(pkg_name="schedulefree")
def is_transformer_engine_available():
    """
    Check for Transformer Engine: `intel_transformer_engine` when an HPU is available, `transformer_engine`
    otherwise.
    """
    if not is_hpu_available():
        return _is_package_available("transformer_engine", "transformer-engine")
    return _is_package_available("intel_transformer_engine", "intel-transformer-engine")
def is_transformer_engine_mxfp8_available():
    """Check whether Transformer Engine is installed and reports MXFP8 support."""
    if not _is_package_available("transformer_engine", "transformer-engine"):
        return False

    import transformer_engine.pytorch as te

    # Only the first element of the support-check result is used (presumably the support flag).
    support_info = te.fp8.check_mxfp8_support()
    return support_info[0]
def is_lomo_available():
    """Check whether the `lomo_optim` package is installed."""
    return _is_package_available(pkg_name="lomo_optim")
def is_cuda_available():
    """
    Report whether CUDA is available, asking PyTorch for its NVML-based check so the CUDA drivers are not
    triggered and CUDA stays uninitialized.
    """
    # PYTORCH_NVML_BASED_CUDA_CHECK=1 is applied through `patch_environment` only around the query.
    with patch_environment(PYTORCH_NVML_BASED_CUDA_CHECK="1"):
        cuda_found = torch.cuda.is_available()
    return cuda_found
@lru_cache
def is_torch_xla_available(check_is_tpu=False, check_is_gpu=False):
    """
    Check if `torch_xla` is available. To train a native pytorch job in an environment with torch xla installed, set
    the USE_TORCH_XLA environment variable to false.

    Args:
        check_is_tpu (`bool`, defaults to `False`): Additionally require the XLA runtime device type to be `"TPU"`.
        check_is_gpu (`bool`, defaults to `False`): Additionally require the XLA runtime device type to be `"GPU"`
            or `"CUDA"`.

    Returns:
        `bool`: Whether torch_xla was imported (and, if requested, whether the device type matches).

    Raises:
        `ValueError`: If both `check_is_tpu` and `check_is_gpu` are `True`.
    """
    # Validate explicitly: an `assert` would be stripped when Python runs with -O.
    if check_is_tpu and check_is_gpu:
        raise ValueError("The check_is_tpu and check_is_gpu cannot both be true.")
    if not _torch_xla_available:
        return False
    if check_is_gpu:
        return torch_xla.runtime.device_type() in ("GPU", "CUDA")
    if check_is_tpu:
        return torch_xla.runtime.device_type() == "TPU"
    return True
def is_torchao_available():
    """Check that `torchao` is installed at version 0.6.1 or newer."""
    if not _is_package_available("torchao"):
        return False
    installed = version.parse(importlib.metadata.version("torchao"))
    return compare_versions(installed, ">=", "0.6.1")
def is_deepspeed_available():
    """Check whether the `deepspeed` package is installed."""
    return _is_package_available(pkg_name="deepspeed")
def is_pippy_available():
    """Report whether the installed PyTorch is 2.4.0 or newer, the version this check requires."""
    minimum_torch = "2.4.0"
    return is_torch_version(">=", minimum_torch)
def is_bf16_available(ignore_tpu=False):
    """
    Report bf16 support for the first detected backend.

    A TPU reports support unless `ignore_tpu` is set. CUDA, MLU, XPU and MPS are then probed in that order, and
    the first one present decides. With none of them present, bf16 is assumed to be supported.
    """
    if is_torch_xla_available(check_is_tpu=True):
        return not ignore_tpu

    # Lambdas defer attribute access so backend-specific `torch.*` namespaces are only touched when present.
    backend_checks = (
        (is_cuda_available, lambda: torch.cuda.is_bf16_supported()),
        (is_mlu_available, lambda: torch.mlu.is_bf16_supported()),
        (is_xpu_available, lambda: torch.xpu.is_bf16_supported()),
        (is_mps_available, lambda: torch.backends.mps.is_macos_or_newer(14, 0)),
    )
    for backend_present, bf16_supported in backend_checks:
        if backend_present():
            return bf16_supported()
    return True
def is_fp16_available():
    """Report fp16 support: available unless the device is a first-generation Habana Gaudi."""
    return not is_habana_gaudi1()
def is_fp8_available():
    """Report fp8 support: true when any of MS-AMP, Transformer Engine or torchao is available."""
    backends = (is_msamp_available, is_transformer_engine_available, is_torchao_available)
    return any(backend() for backend in backends)
def is_4bit_bnb_available():
    """
    Check whether `bitsandbytes` is installed at version 0.39.0 or newer, the minimum used for 4-bit support.

    Returns:
        `bool`: `False` when `bitsandbytes` is missing or older than 0.39.0.
    """
    # Reuse the shared version check instead of duplicating it; `bool()` keeps the result a strict boolean.
    return bool(is_bnb_available(min_version="0.39.0"))
def is_8bit_bnb_available():
    """
    Check whether `bitsandbytes` is installed at version 0.37.2 or newer, the minimum used for 8-bit support.

    Returns:
        `bool`: `False` when `bitsandbytes` is missing or older than 0.37.2.
    """
    # Reuse the shared version check instead of duplicating it; `bool()` keeps the result a strict boolean.
    return bool(is_bnb_available(min_version="0.37.2"))
def is_bnb_available(min_version=None):
    """
    Check whether `bitsandbytes` is installed, optionally also requiring a version of at least `min_version`.
    """
    installed = _is_package_available("bitsandbytes")
    if not installed or min_version is None:
        return installed
    bnb_version = version.parse(importlib.metadata.version("bitsandbytes"))
    return compare_versions(bnb_version, ">=", min_version)
def is_bitsandbytes_multi_backend_available():
    """Check whether the installed `bitsandbytes` lists `"multi_backend"` among its `features`."""
    if is_bnb_available():
        import bitsandbytes as bnb

        return "multi_backend" in getattr(bnb, "features", set())
    return False
def is_torchvision_available():
    """Check whether the `torchvision` package is installed."""
    return _is_package_available(pkg_name="torchvision")
def is_megatron_lm_available():
    """
    Check whether Megatron-LM can be used.

    Requires the `ACCELERATE_USE_MEGATRON_LM` environment flag to be truthy, the `megatron` package to be
    findable, `megatron-core` to be at version 0.8.0 or newer, and `megatron.training` to be findable.

    Returns:
        `bool`: Always a strict boolean; any exception during version detection emits a warning and yields `False`.
    """
    if str_to_bool(os.environ.get("ACCELERATE_USE_MEGATRON_LM", "False")) != 1:
        return False
    if importlib.util.find_spec("megatron") is None:
        return False
    try:
        megatron_version = parse(importlib.metadata.version("megatron-core"))
        if not compare_versions(megatron_version, ">=", "0.8.0"):
            return False
        # `find_spec` returns a ModuleSpec or None; normalize it to a bool.
        return importlib.util.find_spec(".training", "megatron") is not None
    except Exception as e:
        warnings.warn(f"Parse Megatron version failed. Exception:{e}")
        return False
def is_transformers_available():
    """Check whether the `transformers` package is installed."""
    return _is_package_available(pkg_name="transformers")
def is_datasets_available():
    """Check whether the `datasets` package is installed."""
    return _is_package_available(pkg_name="datasets")
def is_peft_available():
    """Check whether the `peft` package is installed."""
    return _is_package_available(pkg_name="peft")
def is_timm_available():
    """Check whether the `timm` package is installed."""
    return _is_package_available(pkg_name="timm")
def is_triton_available():
    """Check for Triton; when an XPU is available, its metadata is looked up under `pytorch-triton-xpu`."""
    distribution = "pytorch-triton-xpu" if is_xpu_available() else None
    return _is_package_available("triton", distribution)
def is_aim_available():
    """Check that `aim` is installed at a version below 4.0.0."""
    if not _is_package_available("aim"):
        return False
    aim_version = version.parse(importlib.metadata.version("aim"))
    return compare_versions(aim_version, "<", "4.0.0")
def is_tensorboard_available():
    """Check whether either `tensorboard` or `tensorboardX` is installed."""
    found = _is_package_available("tensorboard")
    if not found:
        found = _is_package_available("tensorboardX")
    return found
def is_wandb_available():
    """Check whether the `wandb` package is installed."""
    return _is_package_available(pkg_name="wandb")
def is_comet_ml_available():
    """Check whether the `comet_ml` package is installed."""
    return _is_package_available(pkg_name="comet_ml")
def is_swanlab_available():
    """Check whether the `swanlab` package is installed."""
    return _is_package_available(pkg_name="swanlab")
def is_trackio_available():
    """Check for `trackio`; it is only considered on Python 3.10 or newer."""
    if sys.version_info < (3, 10):
        return False
    return _is_package_available("trackio")
def is_boto3_available():
    """Check whether the `boto3` package is installed."""
    return _is_package_available(pkg_name="boto3")
def is_rich_available():
    """`rich` output is opt-in: require both the package and the ACCELERATE_ENABLE_RICH flag."""
    if not _is_package_available("rich"):
        return False
    return parse_flag_from_env("ACCELERATE_ENABLE_RICH", False)
def is_sagemaker_available():
    """Check whether the `sagemaker` package is installed."""
    return _is_package_available(pkg_name="sagemaker")
def is_tqdm_available():
    """Check whether the `tqdm` package is installed."""
    return _is_package_available(pkg_name="tqdm")
def is_clearml_available():
    """Check whether the `clearml` package is installed."""
    return _is_package_available(pkg_name="clearml")
def is_pandas_available():
    """Check whether the `pandas` package is installed."""
    return _is_package_available(pkg_name="pandas")
def is_matplotlib_available():
    """Check whether the `matplotlib` package is installed."""
    return _is_package_available(pkg_name="matplotlib")
def is_mlflow_available():
    """Check for the `mlflow` module, provided either by the `mlflow` distribution or by `mlflow-skinny`."""
    return bool(_is_package_available("mlflow") or _is_package_available("mlflow", "mlflow-skinny"))
def is_mps_available(min_version="1.12"):
    """
    Report whether the MPS device can be used.

    PyTorch must be at least `min_version` (default 1.12), and `torch.backends.mps` must report MPS as both
    available and built.
    """
    # torch.backends.mps exists from PyTorch 1.12; torch.mps was added in 2.0.0.
    if not is_torch_version(">=", min_version):
        return False
    mps = torch.backends.mps
    return mps.is_available() and mps.is_built()
def is_ipex_available():
    """
    Check that Intel Extension for PyTorch (IPEX) is installed and matches PyTorch's major.minor version.

    A version mismatch emits a warning and returns `False`.
    """

    def major_minor(full_version):
        parsed = version.parse(full_version)
        return f"{parsed.major}.{parsed.minor}"

    torch_full_version = importlib.metadata.version("torch")
    if importlib.util.find_spec("intel_extension_for_pytorch") is None:
        return False
    try:
        ipex_full_version = importlib.metadata.version("intel_extension_for_pytorch")
    except importlib.metadata.PackageNotFoundError:
        return False

    torch_mm = major_minor(torch_full_version)
    ipex_mm = major_minor(ipex_full_version)
    if torch_mm == ipex_mm:
        return True
    warnings.warn(
        f"Intel Extension for PyTorch {ipex_mm} needs to work with PyTorch {ipex_mm}.*,"
        f" but PyTorch {torch_full_version} is found. Please switch to the matching version and run again."
    )
    return False
@lru_cache
def is_mlu_available(check_device=False):
    """
    Check whether an MLU is available.

    The check asks `torch_mlu` for its cndev-based availability test, which does not trigger the drivers and
    leaves the MLU uninitialized. `check_device` is currently unused.
    """
    if importlib.util.find_spec("torch_mlu") is None:
        return False

    import torch_mlu  # noqa: F401

    with patch_environment(PYTORCH_CNDEV_BASED_MLU_CHECK="1"):
        mlu_found = torch.mlu.is_available()
    return mlu_found
@lru_cache
def is_musa_available(check_device=False):
    """
    Check whether `torch_musa` is installed and reports MUSA as available.

    With `check_device=True`, the device count is also queried, and a `RuntimeError` from that query yields
    `False`.
    """
    if importlib.util.find_spec("torch_musa") is None:
        return False

    import torch_musa  # noqa: F401

    if not check_device:
        return hasattr(torch, "musa") and torch.musa.is_available()
    try:
        # Raises RuntimeError when no MUSA device is found.
        torch.musa.device_count()
        return torch.musa.is_available()
    except RuntimeError:
        return False
@lru_cache
def is_npu_available(check_device=False):
    """
    Check whether `torch_npu` is installed, importable and reports an NPU as available.

    With `check_device=True`, the device count is also queried, and a `RuntimeError` from that query yields
    `False`.
    """
    if importlib.util.find_spec("torch_npu") is None:
        return False

    # Importing torch_npu can fail in some environments (e.g. a CPU-only container with torch_npu installed).
    try:
        import torch_npu  # noqa: F401
    except Exception:
        return False

    if not check_device:
        return hasattr(torch, "npu") and torch.npu.is_available()
    try:
        # Raises RuntimeError when no NPU is found.
        torch.npu.device_count()
        return torch.npu.is_available()
    except RuntimeError:
        return False
@lru_cache
def is_sdaa_available(check_device=False):
    """
    Check whether `torch_sdaa` is installed and reports an SDAA as available.

    With `check_device=True`, the device count is also queried, and a `RuntimeError` from that query yields
    `False`.
    """
    if importlib.util.find_spec("torch_sdaa") is None:
        return False

    import torch_sdaa  # noqa: F401

    if not check_device:
        return hasattr(torch, "sdaa") and torch.sdaa.is_available()
    try:
        # Raises RuntimeError when no SDAA device is found.
        torch.sdaa.device_count()
        return torch.sdaa.is_available()
    except RuntimeError:
        return False
@lru_cache
def is_hpu_available(init_hccl=False):
    """
    Check whether Habana's PyTorch integration is installed and `torch.hpu` reports an HPU as available.

    With `init_hccl=True`, `habana_frameworks.torch.distributed.hccl` is also imported.
    """
    for module_name in ("habana_frameworks", "habana_frameworks.torch"):
        if importlib.util.find_spec(module_name) is None:
            return False

    import habana_frameworks.torch  # noqa: F401

    if init_hccl:
        import habana_frameworks.torch.distributed.hccl as hccl  # noqa: F401

    return hasattr(torch, "hpu") and torch.hpu.is_available()
def is_habana_gaudi1():
    """Report whether an HPU is available and its device type is `synDeviceGaudi`."""
    if not is_hpu_available():
        return False

    import habana_frameworks.torch.utils.experimental as htexp  # noqa: F401

    return bool(htexp._get_device_type() == htexp.synDeviceType.synDeviceGaudi)
@lru_cache
def is_xpu_available(check_device=False):
    """
    Check whether XPU acceleration is available.

    Detection goes through `intel_extension_for_pytorch` when it is installed, and otherwise through stock
    PyTorch (versions at or below 2.3 are rejected). With `check_device=True`, the device count is also queried,
    and a `RuntimeError` from that query yields `False`.
    """
    # Call is_ipex_available() only once; it can emit a version-mismatch warning.
    ipex_installed = is_ipex_available()
    if ipex_installed:
        import intel_extension_for_pytorch  # noqa: F401
    elif is_torch_version("<=", "2.3"):
        return False

    if not check_device:
        return hasattr(torch, "xpu") and torch.xpu.is_available()
    try:
        # Raises RuntimeError when no XPU is found.
        torch.xpu.device_count()
        return torch.xpu.is_available()
    except RuntimeError:
        return False
def is_dvclive_available():
    """Check whether the `dvclive` package is installed."""
    return _is_package_available(pkg_name="dvclive")
def is_torchdata_available():
    """Check whether the `torchdata` package is installed."""
    return _is_package_available(pkg_name="torchdata")
# TODO: Remove this function once stateful_dataloader is a stable feature in torchdata.
def is_torchdata_stateful_dataloader_available():
    """Check that `torchdata` is installed at version 0.8.0 or newer, the minimum this check accepts."""
    if not _is_package_available("torchdata"):
        return False
    installed = version.parse(importlib.metadata.version("torchdata"))
    return compare_versions(installed, ">=", "0.8.0")
def torchao_required(func):
    """
    Decorate `func` so that each call first checks for torchao and raises `ImportError` when it is missing.
    """

    @wraps(func)
    def guarded(*args, **kwargs):
        if is_torchao_available():
            return func(*args, **kwargs)
        raise ImportError(
            "`torchao` is not available, please install it before calling this function via `pip install torchao`."
        )

    return guarded
# TODO: Rework this into `utils.deepspeed` and migrate the "core" chunks into `accelerate.deepspeed`
def deepspeed_required(func):
    """
    Decorate `func` so that it raises `ValueError` when `AcceleratorState` has non-empty shared state whose
    distributed type is not DeepSpeed. With empty shared state the call goes through.
    """

    @wraps(func)
    def guarded(*args, **kwargs):
        from accelerate.state import AcceleratorState
        from accelerate.utils.dataclasses import DistributedType

        has_shared_state = AcceleratorState._shared_state != {}
        if has_shared_state and AcceleratorState().distributed_type != DistributedType.DEEPSPEED:
            raise ValueError(
                "DeepSpeed is not enabled, please make sure that an `Accelerator` is configured for `deepspeed` "
                "before calling this function."
            )
        return func(*args, **kwargs)

    return guarded
def is_weights_only_available():
    """Report whether weights-only loading with an allowlist is supported (PyTorch 2.4.0 or newer)."""
    # The allowlist was added in PyTorch 2.4.0: https://github.com/pytorch/pytorch/pull/124331
    minimum_torch = "2.4.0"
    return is_torch_version(">=", minimum_torch)
def is_numpy_available(min_version="1.25.0"):
    """
    Check whether NumPy is installed at version `min_version` or newer.

    Args:
        min_version (`str`, defaults to `"1.25.0"`): Minimum acceptable NumPy version.

    Returns:
        `bool`: `False` when NumPy is not installed or is older than `min_version`.
    """
    # An availability check should report absence instead of raising PackageNotFoundError.
    try:
        installed = importlib.metadata.version("numpy")
    except importlib.metadata.PackageNotFoundError:
        return False
    return compare_versions(parse(installed), ">=", min_version)
|