| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561 |
- # Copyright 2022 The HuggingFace Team. All rights reserved.
- #
- # Licensed under the Apache License, Version 2.0 (the "License");
- # you may not use this file except in compliance with the License.
- # You may obtain a copy of the License at
- #
- # http://www.apache.org/licenses/LICENSE-2.0
- #
- # Unless required by applicable law or agreed to in writing, software
- # distributed under the License is distributed on an "AS IS" BASIS,
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- # See the License for the specific language governing permissions and
- # limitations under the License.
- import collections
- import platform
- import re
- import socket
- from codecs import encode
- from collections import OrderedDict
- from functools import partial, reduce
- from types import MethodType
- from typing import Optional
- import numpy as np
- import torch
- from packaging.version import Version
- from safetensors.torch import save_file as safe_save_file
- from ..commands.config.default import write_basic_config # noqa: F401
- from ..logging import get_logger
- from ..state import PartialState
- from .constants import FSDP_PYTORCH_VERSION
- from .dataclasses import DistributedType
- from .imports import (
- is_deepspeed_available,
- is_numpy_available,
- is_torch_distributed_available,
- is_torch_xla_available,
- is_weights_only_available,
- )
- from .modeling import id_tensor_storage
- from .transformer_engine import convert_model
- from .versions import is_torch_version
# Module-level logger shared by the helpers in this file.
logger = get_logger(__name__)

# `xm` is only bound when torch_xla is installed. Code that uses it (e.g. `save`) first checks that the
# distributed type is XLA; otherwise the name `xm` would be undefined.
if is_torch_xla_available():
    import torch_xla.core.xla_model as xm
def is_compiled_module(module: torch.nn.Module) -> bool:
    """
    Return `True` if `module` is the wrapper produced by `torch.compile()` (an `OptimizedModule`).

    Returns `False` when this torch build has no `torch._dynamo` at all.
    """
    dynamo = getattr(torch, "_dynamo", None)
    if dynamo is None:
        return False
    return isinstance(module, dynamo.eval_frame.OptimizedModule)
def has_compiled_regions(module: torch.nn.Module) -> bool:
    """
    Return `True` if any module in the hierarchy of `module` (including `module` itself) is an `OptimizedModule`
    produced by `torch.compile()`.

    Modules without registered children are never reported, and `False` is returned when `torch._dynamo` is missing.
    """
    if not hasattr(torch, "_dynamo") or not module._modules:
        return False
    optimized_cls = torch._dynamo.eval_frame.OptimizedModule
    return any(isinstance(candidate, optimized_cls) for candidate in module.modules())
def is_repeated_blocks(module: torch.nn.Module) -> bool:
    """
    Check whether the module is a repeated block, i.e. a non-empty `torch.nn.ModuleList` whose children are all
    instances of the class of its first child. This is useful to determine whether we should apply regional
    compilation to the module.

    An empty `ModuleList` is *not* treated as repeated blocks: it contains nothing to compile, and reporting it would
    make `has_repeated_blocks` (and therefore regional compilation) split up models that have no repeated blocks.
    """
    if not isinstance(module, torch.nn.ModuleList) or len(module) == 0:
        return False
    # Hoisted out of the loop; `isinstance` keeps subclasses of the first block's class accepted.
    block_class = type(module[0])
    return all(isinstance(block, block_class) for block in module)
def has_repeated_blocks(module: torch.nn.Module) -> bool:
    """
    Return `True` if `module` or any module below it is a repeated block (see `is_repeated_blocks`), i.e. a
    `torch.nn.ModuleList` whose children all share one class. Used to decide whether regional compilation applies.

    A module with no registered children is never reported, whatever its type.
    """
    # `modules()` yields `module` itself first, then every descendant.
    return bool(module._modules) and any(is_repeated_blocks(candidate) for candidate in module.modules())
def compile_regions(module: torch.nn.Module, **compile_kwargs) -> torch.nn.Module:
    """
    Performs regional compilation where we target repeated blocks of the same class and compile them sequentially to
    hit the compiler's cache. For example, in `GPT2LMHeadModel`, the repeated block/class is `GPT2Block`, and can be
    accessed as `model.transformer.h[0]`. The rest of the model (e.g. model.lm_head) is compiled separately.

    This allows us to speed up the compilation overhead / cold start of models like LLMs and Transformers in general.
    See https://pytorch.org/tutorials/recipes/regional_compilation.html for more details.

    A new module is returned. Modules on the path to the repeated blocks are shallow clones (new instances of the same
    class holding the same attribute references), repeated blocks are wrapped individually with `torch.compile`, and
    every other subtree is compiled as a whole. The returned module keeps the original `module` in
    `__dict__["_orig_mod"]` so it can be unwrapped later.

    Args:
        module (`torch.nn.Module`):
            The model to compile.
        **compile_kwargs:
            Additional keyword arguments to pass to `torch.compile()`.

    Returns:
        `torch.nn.Module`: A new instance of the model with some compiled regions.

    Example:
    ```python
    >>> from accelerate.utils import compile_regions
    >>> from transformers import AutoModelForCausalLM

    >>> model = AutoModelForCausalLM.from_pretrained("gpt2")
    >>> compiled_model = compile_regions(model, mode="reduce-overhead")
    >>> compiled_model.transformer.h[0]
    OptimizedModule(
        (_orig_mod): GPT2Block(
                (ln_1): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
                (attn): GPT2Attention(
                (c_attn): Conv1D(nf=2304, nx=768)
                (c_proj): Conv1D(nf=768, nx=768)
                (attn_dropout): Dropout(p=0.1, inplace=False)
                (resid_dropout): Dropout(p=0.1, inplace=False)
            )
            (ln_2): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
            (mlp): GPT2MLP(
                (c_fc): Conv1D(nf=3072, nx=768)
                (c_proj): Conv1D(nf=768, nx=3072)
                (act): NewGELUActivation()
                (dropout): Dropout(p=0.1, inplace=False)
            )
        )
    )
    ```
    """

    def _compile_regions(module: torch.nn.Module, **compile_kwargs) -> torch.nn.Module:
        """Recursively build the regionally-compiled counterpart of `module`."""
        if is_repeated_blocks(module):
            # Repeated blocks: compile each block on its own so later blocks can reuse the compiled code.
            new_module = torch.nn.ModuleList()
            for submodule in module:
                new_module.append(torch.compile(submodule, **compile_kwargs))
        elif has_repeated_blocks(module):
            # Shallow clone: a new instance of the same class created without running `__init__`, whose `__dict__`
            # holds the same attribute references as the original. The child registry is replaced by a fresh dict
            # (so the original's children are untouched) and refilled with recursively compiled children.
            new_module = module.__class__.__new__(module.__class__)
            new_module.__dict__.update(module.__dict__)
            new_module._modules = {}
            for name, submodule in module.named_children():
                new_module.add_module(name, _compile_regions(submodule, **compile_kwargs))
        else:
            # No repeated blocks anywhere below: compile this subtree as a single unit.
            new_module = torch.compile(module, **compile_kwargs)

        return new_module

    new_module = _compile_regions(module, **compile_kwargs)

    # Skip if `__dict__` already holds `_orig_mod` (e.g. copied over by the shallow clone above when the input was
    # itself a previous output of this function).
    if "_orig_mod" not in new_module.__dict__:
        # Keeps a reference to the original module to decompile/unwrap it later
        new_module.__dict__["_orig_mod"] = module

    return new_module
def compile_regions_deepspeed(module: torch.nn.Module, **compile_kwargs):
    """
    Performs regional compilation the same way as `compile_regions`, but specifically for `DeepSpeedEngine.module`.

    Since the model is wrapped in a `DeepSpeedEngine` and has many added hooks, offloaded parameters, etc. that
    `torch.compile(...)` interferes with, this version of regional compilation uses the in-place `module.compile()`
    method instead. Nothing is returned: `module` and/or its submodules are compiled in place.

    Args:
        module (`torch.nn.Module`):
            The model to compile.
        **compile_kwargs:
            Additional keyword arguments to pass to `module.compile()`.
    """
    if is_repeated_blocks(module):
        # A list of same-class blocks: compile every block in place, one after the other.
        for block in module:
            block.compile(**compile_kwargs)
        return

    if not has_repeated_blocks(module):
        # Leaf of the recursion: nothing repeated below, so compile the whole module in place.
        module.compile(**compile_kwargs)
        return

    # Repeated blocks live somewhere below this module: descend into each direct child.
    for child in module.children():
        compile_regions_deepspeed(child, **compile_kwargs)
def model_has_dtensor(model: torch.nn.Module) -> bool:
    """
    Report whether at least one parameter of `model` is a `DTensor`.

    Args:
        model (`torch.nn.Module`):
            The model whose parameters are inspected.

    Returns:
        `bool`: `True` as soon as a `DTensor` parameter is found, `False` otherwise.
    """
    if not is_torch_version(">=", "2.5.0"):
        # from torch 2.0.0 (oldest supported accelerate torch version), DTensor is in torch.distributed._tensor
        from torch.distributed._tensor import DTensor
    else:
        from torch.distributed.tensor import DTensor

    for param in model.parameters():
        if isinstance(param, DTensor):
            return True
    return False
def extract_model_from_parallel(
    model, keep_fp32_wrapper: bool = True, keep_torch_compile: bool = True, recursive: bool = False
):
    """
    Extract a model from its distributed containers.

    Args:
        model (`torch.nn.Module`):
            The model to extract.
        keep_fp32_wrapper (`bool`, *optional*, defaults to `True`):
            Whether to keep the mixed precision wrapper around `forward`. When `False`, `forward` is unwrapped back to
            the function recorded in `_original_forward` (if present), and a model converted to Transformer Engine is
            converted back.
        keep_torch_compile (`bool`, *optional*, defaults to `True`):
            Whether to keep the compiled wrapper (from `torch.compile` or `compile_regions`). When `True`, the
            wrapper is returned with its reference updated to the extracted model. When `False`, the uncompiled
            extracted model is returned.
        recursive (`bool`, *optional*, defaults to `False`):
            Whether to recursively extract all cases of `module.module` from `model` as well as unwrap child sublayers
            recursively, not just the top-level distributed containers.

    Returns:
        `torch.nn.Module`: The extracted model.
    """
    options = (torch.nn.parallel.DistributedDataParallel, torch.nn.DataParallel)

    is_compiled = is_compiled_module(model)
    # `compile_regions` stores the uncompiled model in `__dict__["_orig_mod"]`. A model that merely *contains*
    # compiled submodules (e.g. the user called `torch.compile` on a child, or a DDP wrapper around such a model)
    # also has compiled regions but no such back-reference. Indexing it unconditionally raised `KeyError`, so only
    # unwrap when the reference is actually present.
    has_compiled = not is_compiled and has_compiled_regions(model) and "_orig_mod" in model.__dict__
    if is_compiled:
        compiled_model = model
        model = model._orig_mod
    elif has_compiled:
        compiled_model = model
        model = model.__dict__["_orig_mod"]

    if is_deepspeed_available():
        from deepspeed import DeepSpeedEngine

        options += (DeepSpeedEngine,)

    if is_torch_version(">=", FSDP_PYTORCH_VERSION) and is_torch_distributed_available():
        from torch.distributed.fsdp.fully_sharded_data_parallel import FullyShardedDataParallel as FSDP

        options += (FSDP,)

    # Peel off any number of nested distributed containers, all of which expose the wrapped model as `.module`.
    while isinstance(model, options):
        model = model.module

    if recursive:
        # This is needed in cases such as using FSDPv2 on XLA
        def _recursive_unwrap(module):
            # Wrapped modules are standardly wrapped as `module`, similar to the cases earlier
            # with DDP, DataParallel, DeepSpeed, and FSDP
            if hasattr(module, "module"):
                unwrapped_module = _recursive_unwrap(module.module)
            else:
                unwrapped_module = module
            # Next unwrap child sublayers recursively
            for name, child in unwrapped_module.named_children():
                setattr(unwrapped_module, name, _recursive_unwrap(child))
            return unwrapped_module

        # Start with top-level
        model = _recursive_unwrap(model)

    if not keep_fp32_wrapper:
        forward = model.forward
        original_forward = model.__dict__.pop("_original_forward", None)
        if original_forward is not None:
            # Follow the `__wrapped__` chain down to the recorded original forward function.
            while hasattr(forward, "__wrapped__"):
                forward = forward.__wrapped__
                if forward == original_forward:
                    break
            model.forward = MethodType(forward, model)
        if getattr(model, "_converted_to_transformer_engine", False):
            convert_model(model, to_transformer_engine=False)

    if keep_torch_compile:
        # Re-attach the (possibly unwrapped) model to the compiled wrapper and return the wrapper.
        if is_compiled:
            compiled_model._orig_mod = model
            model = compiled_model
        elif has_compiled:
            compiled_model.__dict__["_orig_mod"] = model
            model = compiled_model

    return model
def wait_for_everyone():
    """
    Block the current process until every process has reached this same point, then continue.

    <Tip warning={true}>

    Every process must eventually execute this call; otherwise the processes that did reach it will hang forever.

    </Tip>
    """
    state = PartialState()
    state.wait_for_everyone()
def clean_state_dict_for_safetensors(state_dict: dict):
    """
    Cleans the state dictionary from a model and removes tensor aliasing if present.

    Entries whose tensors share the same storage (as identified by `id_tensor_storage`) are deduplicated: only the
    first key of each group is kept and a warning lists the removed keys. Every `torch.Tensor` value is then made
    contiguous. The input mapping is not modified; a new `dict` is returned.

    Args:
        state_dict (`dict`):
            The state dictionary from a model

    Returns:
        `dict`: A new, de-aliased state dictionary with contiguous tensors.
    """
    # Work on a shallow copy so the caller's state dict keeps all of its keys (e.g. tied weights it may still use).
    state_dict = dict(state_dict)

    ptrs = collections.defaultdict(list)
    # When bnb serialization is used, weights in state dict can be strings
    for name, tensor in state_dict.items():
        if not isinstance(tensor, str):
            ptrs[id_tensor_storage(tensor)].append(name)

    # These are all pointers of tensors with shared memory
    shared_ptrs = {ptr: names for ptr, names in ptrs.items() if len(names) > 1}
    warn_names = set()
    for names in shared_ptrs.values():
        # When not all duplicates have been cleaned, we still remove those keys but put a clear warning.
        # If the link between tensors was done at runtime then `from_pretrained` will not get
        # the key back leading to random tensor. A proper warning will be shown
        # during reload (if applicable), but since the file is not necessarily compatible with
        # the config, better show a proper warning.
        found_names = [name for name in names if name in state_dict]
        warn_names.update(found_names[1:])
        for name in found_names[1:]:
            del state_dict[name]
    if len(warn_names) > 0:
        logger.warning(
            f"Removed shared tensor {warn_names} while saving. This should be OK, but check by verifying that you don't receive any warning while reloading",
        )
    state_dict = {k: v.contiguous() if isinstance(v, torch.Tensor) else v for k, v in state_dict.items()}
    return state_dict
def save(obj, f, save_on_each_node: bool = False, safe_serialization: bool = False):
    """
    Save the data to disk. Use in place of `torch.save()`.

    Args:
        obj:
            The data to save
        f:
            The file (or file-like object) to use to save the data
        save_on_each_node (`bool`, *optional*, defaults to `False`):
            If `False`, only the global main process writes `obj`. If `True`, the local main process of every node
            writes it.
        safe_serialization (`bool`, *optional*, defaults to `False`):
            Whether to save `obj` using `safetensors` or the traditional PyTorch way (that uses `pickle`). With
            `safetensors`, a dict `obj` is first de-aliased and made contiguous via `clean_state_dict_for_safetensors`.
    """
    state = PartialState()

    # When TorchXLA is enabled, it's necessary to transfer all data to the CPU before saving.
    # Another issue arises with `id_tensor_storage`, which treats all XLA tensors as identical.
    # If tensors remain on XLA, calling `clean_state_dict_for_safetensors` will result in only
    # one XLA tensor remaining.
    if state.distributed_type == DistributedType.XLA:
        obj = xm._maybe_convert_to_cpu(obj)

    if safe_serialization:
        save_func = partial(safe_save_file, metadata={"format": "pt"})
        # Remove shared-storage duplicates from state dicts. `OrderedDict` (what `state_dict()` returns) is a `dict`
        # subclass, so plain dicts are now cleaned as well instead of being rejected by safetensors.
        if isinstance(obj, dict):
            obj = clean_state_dict_for_safetensors(obj)
    else:
        save_func = torch.save

    should_save = state.is_local_main_process if save_on_each_node else state.is_main_process
    if should_save:
        save_func(obj, f)
# Objects that `load` temporarily allow-lists for `torch.load(..., weights_only=True)` so that numpy arrays and RNG
# states can be reconstructed. They are added right before loading and removed again afterwards.
# numpy 2.0 moved `numpy.core` to `numpy._core`.
np_core = np.core if not is_numpy_available("2.0.0") else np._core
TORCH_SAFE_GLOBALS = [
    # numpy arrays hold plain numbers rather than arbitrary objects, so rebuilding them is safe
    np_core.multiarray._reconstruct,
    np.ndarray,
    # required to restore the saved RNG states
    encode,
    np.dtype,
]
if is_numpy_available("1.25.0"):
    # `np.dtypes` is only available from numpy 1.25 on
    TORCH_SAFE_GLOBALS += [np.dtypes.UInt32DType]
def load(f, map_location=None, **kwargs):
    """
    Compatible drop-in replacement of `torch.load()` which allows for `weights_only` to be used if `torch` version is
    2.4.0 or higher. Otherwise will ignore the kwarg.

    When `weights_only` is supported, it defaults to `True`, and `TORCH_SAFE_GLOBALS` (numpy arrays, RNG-state helpers)
    is temporarily added to `torch.serialization`'s safe-globals list. Afterwards the list is restored to what it was
    before the call, even if loading fails.

    Args:
        f:
            The file (or file-like object) to use to load the data
        map_location:
            a function, `torch.device`, string or a dict specifying how to remap storage locations
        **kwargs:
            Additional keyword arguments to pass to `torch.load()`.

    Returns:
        The object loaded by `torch.load()`.
    """
    # Evaluate once so the setup and the cleanup below always agree.
    weights_only_available = is_weights_only_available()
    # Snapshot the allow-list *before* the `try`: if this call fails there is nothing to restore. Previously the
    # snapshot was taken inside the `try`, and the `finally` then wiped the list and raised `UnboundLocalError`.
    old_safe_globals = torch.serialization.get_safe_globals() if weights_only_available else None

    if weights_only_available:
        kwargs.setdefault("weights_only", True)
    else:
        kwargs.pop("weights_only", None)

    try:
        if weights_only_available:
            torch.serialization.add_safe_globals(TORCH_SAFE_GLOBALS)
        loaded_obj = torch.load(f, map_location=map_location, **kwargs)
    finally:
        if weights_only_available:
            torch.serialization.clear_safe_globals()
            if old_safe_globals:
                torch.serialization.add_safe_globals(old_safe_globals)
    return loaded_obj
def get_pretty_name(obj):
    """
    Return a human-readable name for `obj`.

    Objects that have neither `__qualname__` nor `__name__` (e.g. most instances) are described by their class. The
    qualified name is preferred over the plain name, and `str()` is the last resort.
    """
    target = obj
    if not (hasattr(target, "__qualname__") or hasattr(target, "__name__")):
        target = getattr(target, "__class__", target)
    for attr_name in ("__qualname__", "__name__"):
        if hasattr(target, attr_name):
            return getattr(target, attr_name)
    return str(target)
def merge_dicts(source, destination):
    """
    Recursively merges two dictionaries.

    Values from `source` win. A nested dict in `source` is merged into the mapping stored under the same key in
    `destination`. If that key is missing, or holds a non-mapping value (e.g. `None` or a scalar), it is replaced by a
    new dict first instead of crashing.

    Args:
        source (`dict`): The dictionary to merge into `destination`.
        destination (`dict`): The dictionary to merge `source` into.

    Returns:
        `dict`: `destination`, modified in place.
    """
    from collections.abc import MutableMapping

    for key, value in source.items():
        if isinstance(value, dict):
            node = destination.get(key)
            if not isinstance(node, MutableMapping):
                node = destination[key] = {}
            merge_dicts(value, node)
        else:
            destination[key] = value

    return destination
def is_port_in_use(port: Optional[int] = None) -> bool:
    """
    Return whether something accepts TCP connections on `localhost:port` (default port 29500).

    Handy for detecting that another `accelerate launch` command is already using the port.
    """
    target_port = 29500 if port is None else port
    with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as probe:
        # `connect_ex` returns an error code instead of raising; 0 means the connection succeeded.
        return probe.connect_ex(("localhost", target_port)) == 0
def get_free_port() -> int:
    """
    Ask the OS for a currently unused TCP port on `localhost`. Useful for automatic port selection when port 0 is
    specified in distributed training scenarios.

    Returns:
        int: An available port number
    """
    with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as sock:
        # Binding to port 0 lets the OS choose a free port for us.
        sock.bind(("", 0))
        _host, chosen_port = sock.getsockname()
    return chosen_port
def convert_bytes(size):
    "Express `size` (in bytes) in the largest unit, up to PB, that keeps the value below 1024, rounded to 2 decimals"
    labels = ("bytes", "KB", "MB", "GB", "TB", "PB")
    unit_index = 0
    # Written as `not size < 1024.0` (rather than `>=`) so non-comparable values such as NaN are scaled
    # all the way up to PB, exactly like a "check below 1024, otherwise divide" loop would.
    while unit_index < 5 and not size < 1024.0:
        size /= 1024.0
        unit_index += 1
    return f"{round(size, 2)} {labels[unit_index]}"
def check_os_kernel():
    """
    Warns if the kernel version is below the recommended minimum on Linux.

    Does nothing on other systems, or when the kernel release string contains no `X.Y.Z` version number (in which
    case the check is skipped rather than crashing).
    """
    # see issue #1929
    info = platform.uname()
    if info.system != "Linux":
        return

    # `re.split` + tuple unpacking raised `ValueError` for release strings without an `X.Y.Z` part (e.g. "6.1");
    # `re.search` finds the same first match but lets us bail out cleanly when there is none.
    match = re.search(r"\d+\.\d+\.\d+", info.release)
    if match is None:
        return
    version = match.group(0)

    min_version = "5.5.0"
    if Version(version) < Version(min_version):
        msg = (
            f"Detected kernel version {version}, which is below the recommended minimum of {min_version}; this can "
            "cause the process to hang. It is recommended to upgrade the kernel to the minimum version or higher."
        )
        logger.warning(msg, main_process_only=True)
def recursive_getattr(obj, attr: str):
    """
    Resolve a dotted attribute path on `obj`, like a chained `getattr`.

    Args:
        obj:
            The object to start from.
        attr (`str`):
            Dot-separated attribute path, e.g. 'attribute1.attribute2'.

    Returns:
        The value at the end of the path. `AttributeError` propagates if any segment is missing.
    """
    current = obj
    for segment in attr.split("."):
        current = getattr(current, segment)
    return current
def get_module_children_bottom_up(model: torch.nn.Module, return_fqns: bool = False) -> list[torch.nn.Module]:
    """Traverse the model in bottom-up order and return the children modules in that order.

    The traversal is an iterative depth-first pre-order walk whose result is reversed. Therefore every module appears
    after all of its descendants, and `model` itself comes last.

    Args:
        model (`torch.nn.Module`): the model to get the children of
        return_fqns (`bool`, *optional*, defaults to `False`): if `True`, return `(fully_qualified_name, module)`
            tuples instead of bare modules; names are dot-joined child names and `model` itself is named `""`.

    Returns:
        `list[torch.nn.Module]`: a list of children modules of `model` in bottom-up order (or `(name, module)` tuples
        when `return_fqns=True`). The last element is the `model` itself.
    """
    pending = [("", model)]
    visit_order = []
    while pending:
        prefix, node = pending.pop()
        for child_name, child in node.named_children():
            if isinstance(child, torch.nn.Module):
                qualified_name = f"{prefix}.{child_name}" if prefix else child_name
                pending.append((qualified_name, child))
        visit_order.append((prefix, node))

    bottom_up = visit_order[::-1]
    if return_fqns:
        return bottom_up
    return [node for _, node in bottom_up]
|