| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693694695696697698699700701702
7037047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015 |
- # Copyright 2024 The HuggingFace Team. All rights reserved.
- #
- # Licensed under the Apache License, Version 2.0 (the "License");
- # you may not use this file except in compliance with the License.
- # You may obtain a copy of the License at
- #
- # http://www.apache.org/licenses/LICENSE-2.0
- #
- # Unless required by applicable law or agreed to in writing, software
- # distributed under the License is distributed on an "AS IS" BASIS,
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- # See the License for the specific language governing permissions and
- # limitations under the License.
- """Contains pytorch-specific helpers."""
- import importlib
- import json
- import os
- import re
- from collections import defaultdict, namedtuple
- from functools import lru_cache
- from pathlib import Path
- from typing import TYPE_CHECKING, Any, Dict, Iterable, List, NamedTuple, Optional, Set, Tuple, Union
- from packaging import version
- from .. import constants, logging
- from ._base import MAX_SHARD_SIZE, StateDictSplit, split_state_dict_into_shards_factory
- logger = logging.get_logger(__file__)
- if TYPE_CHECKING:
- import torch
- # SAVING
def save_torch_model(
    model: "torch.nn.Module",
    save_directory: Union[str, Path],
    *,
    filename_pattern: Optional[str] = None,
    force_contiguous: bool = True,
    max_shard_size: Union[int, str] = MAX_SHARD_SIZE,
    metadata: Optional[Dict[str, str]] = None,
    safe_serialization: bool = True,
    is_main_process: bool = True,
    shared_tensors_to_discard: Optional[List[str]] = None,
):
    """
    Write the weights of a torch model to `save_directory`.

    This is a convenience wrapper: it extracts `model.state_dict()` and hands it, together with every option
    unchanged, to [`save_torch_state_dict`]. Sharding, shared-tensor handling, cleanup of previous shard files and
    index-file creation are all performed there; refer to that function for the details.

    For background on tensor sharing, see
    [this guide](https://huggingface.co/docs/safetensors/torch_shared_tensors).

    > [!WARNING]
    > A tensor larger than `max_shard_size` ends up alone in a shard that exceeds `max_shard_size`.

    > [!WARNING]
    > For a `transformers.PreTrainedModel`, pass `model._tied_weights_keys` as `shared_tensors_to_discard` so that
    > the correct duplicate tensors are discarded when saving.

    Args:
        model (`torch.nn.Module`):
            The model whose state dict is saved.
        save_directory (`str` or `Path`):
            Destination directory for the weight files.
        filename_pattern (`str`, *optional*):
            Template for the file names. It must be usable as `filename_pattern.format(suffix=...)` and contain
            the `suffix` keyword. Defaults to `"model{suffix}.safetensors"` or `pytorch_model{suffix}.bin`
            depending on `safe_serialization`.
        force_contiguous (`boolean`, *optional*):
            Save tensors as contiguous tensors. This does not affect correctness but may change performance if a
            specific layout was chosen on purpose. Defaults to `True`.
        max_shard_size (`int` or `str`, *optional*):
            Upper bound for the size of each shard, in bytes. Defaults to 5GB.
        metadata (`Dict[str, str]`, *optional*):
            Additional information stored with the model. Entries are added for each dropped tensor; they are not
            sufficient to rebuild the full sharing structure but can help understanding it.
        safe_serialization (`bool`, *optional*):
            Save as safetensors (default). When `False`, shards are saved as pickle, which is deprecated and
            discouraged for security reasons.
        is_main_process (`bool`, *optional*):
            Whether the caller is the main process. In distributed setups where every process calls this
            function, set it to `True` only on the main process to avoid race conditions. Defaults to `True`.
        shared_tensors_to_discard (`List[str]`, *optional*):
            Names of tensors to drop when shared tensors are saved. If omitted and shared tensors are detected,
            the first name in alphabetical order is dropped.

    Example:

    ```py
    >>> from huggingface_hub import save_torch_model
    >>> model = ... # A PyTorch model

    # Save the weights to "path/to/folder", sharded in files of at most 5GB, as safetensors.
    >>> save_torch_model(model, "path/to/folder")

    # Load the weights back
    >>> from huggingface_hub import load_torch_model
    >>> load_torch_model(model, "path/to/folder")
    ```
    """
    weights = model.state_dict()
    save_torch_state_dict(
        weights,
        save_directory,
        filename_pattern=filename_pattern,
        force_contiguous=force_contiguous,
        max_shard_size=max_shard_size,
        metadata=metadata,
        safe_serialization=safe_serialization,
        is_main_process=is_main_process,
        shared_tensors_to_discard=shared_tensors_to_discard,
    )
def save_torch_state_dict(
    state_dict: Dict[str, "torch.Tensor"],
    save_directory: Union[str, Path],
    *,
    filename_pattern: Optional[str] = None,
    force_contiguous: bool = True,
    max_shard_size: Union[int, str] = MAX_SHARD_SIZE,
    metadata: Optional[Dict[str, str]] = None,
    safe_serialization: bool = True,
    is_main_process: bool = True,
    shared_tensors_to_discard: Optional[List[str]] = None,
) -> None:
    """
    Save a model state dictionary to the disk, handling sharding and shared tensors issues.

    See also [`save_torch_model`] to directly save a PyTorch model.

    For more information about tensor sharing, check out [this guide](https://huggingface.co/docs/safetensors/torch_shared_tensors).

    The model state dictionary is split into shards so that each shard is smaller than a given size. The shards are
    saved in the `save_directory` with the given `filename_pattern`. If the model is too big to fit in a single shard,
    an index file is saved in the `save_directory` to indicate where each tensor is saved. This helper uses
    [`split_torch_state_dict_into_shards`] under the hood. If `safe_serialization` is `True`, the shards are saved as
    safetensors (the default). Otherwise, the shards are saved as pickle.

    Before saving the model, the main process removes from `save_directory` the files left by a previous save. Only
    files whose full name is `filename_pattern` formatted with an empty or `-XXXXX-of-XXXXX` suffix, optionally
    followed by `.index.json`, are removed; any other file is left untouched.

    > [!WARNING]
    > If one of the model's tensor is bigger than `max_shard_size`, it will end up in its own shard which will have a
    > size greater than `max_shard_size`.

    > [!WARNING]
    > If your model is a `transformers.PreTrainedModel`, you should pass `model._tied_weights_keys` as `shared_tensors_to_discard` to properly handle shared tensors saving. This ensures the correct duplicate tensors are discarded during saving.

    Args:
        state_dict (`Dict[str, torch.Tensor]`):
            The state dictionary to save.
        save_directory (`str` or `Path`):
            The directory in which the model will be saved.
        filename_pattern (`str`, *optional*):
            The pattern to generate the files names in which the model will be saved. Pattern must be a string that
            can be formatted with `filename_pattern.format(suffix=...)` and must contain the keyword `suffix`
            Defaults to `"model{suffix}.safetensors"` or `pytorch_model{suffix}.bin` depending on `safe_serialization`
            parameter.
        force_contiguous (`boolean`, *optional*):
            Forcing the state_dict to be saved as contiguous tensors. This has no effect on the correctness of the
            model, but it could potentially change performance if the layout of the tensor was chosen specifically for
            that reason. Defaults to `True`.
        max_shard_size (`int` or `str`, *optional*):
            The maximum size of each shard, in bytes. Defaults to 5GB.
        metadata (`Dict[str, str]`, *optional*):
            Extra information to save along with the model. Some metadata will be added for each dropped tensors.
            This information will not be enough to recover the entire shared structure but might help understanding
            things. For a single-file save it is written in the file itself; for a sharded save it is written in
            the index file.
        safe_serialization (`bool`, *optional*):
            Whether to save as safetensors, which is the default behavior. If `False`, the shards are saved as pickle.
            Safe serialization is recommended for security reasons. Saving as pickle is deprecated and will be removed
            in a future version.
        is_main_process (`bool`, *optional*):
            Whether the process calling this is the main process or not. Useful when in distributed training like
            TPUs and need to call this function from all processes. In this case, set `is_main_process=True` only on
            the main process to avoid race conditions. Defaults to True.
        shared_tensors_to_discard (`List[str]`, *optional*):
            List of tensor names to drop when saving shared tensors. If not provided and shared tensors are
            detected, it will drop the first name alphabetically.

    Raises:
        [`ImportError`](https://docs.python.org/3/library/exceptions.html#ImportError)
            If `safe_serialization=True` and `safetensors` is not installed.

    Example:

    ```py
    >>> from huggingface_hub import save_torch_state_dict
    >>> model = ... # A PyTorch model

    # Save state dict to "path/to/folder". The model will be split into shards of 5GB each and saved as safetensors.
    >>> state_dict = model_to_save.state_dict()
    >>> save_torch_state_dict(state_dict, "path/to/folder")
    ```
    """
    save_directory = str(save_directory)

    if filename_pattern is None:
        filename_pattern = (
            constants.SAFETENSORS_WEIGHTS_FILE_PATTERN
            if safe_serialization
            else constants.PYTORCH_WEIGHTS_FILE_PATTERN
        )

    if metadata is None:
        metadata = {}

    if safe_serialization:
        try:
            from safetensors.torch import save_file as save_file_fn
        except ImportError as e:
            raise ImportError(
                "Please install `safetensors` to use safe serialization. "
                "You can install it with `pip install safetensors`."
            ) from e
        # Clean state dict for safetensors
        state_dict = _clean_state_dict_for_safetensors(
            state_dict,
            metadata,
            force_contiguous=force_contiguous,
            shared_tensors_to_discard=shared_tensors_to_discard,
        )
    else:
        from torch import save as save_file_fn  # type: ignore[assignment, no-redef]

        logger.warning(
            "You are using unsafe serialization. Due to security reasons, it is recommended not to load "
            "pickled models from untrusted sources. If you intend to share your model, we strongly recommend "
            "using safe serialization by installing `safetensors` with `pip install safetensors`."
        )

    # Split dict
    state_dict_split = split_torch_state_dict_into_shards(
        state_dict, filename_pattern=filename_pattern, max_shard_size=max_shard_size
    )

    # Only main process should clean up existing files to avoid race conditions in distributed environment
    if is_main_process:
        # The pattern is a file-name template, not a regex: escape it so that e.g. "." only matches a literal
        # dot, then substitute the (escaped) placeholder with the optional shard-suffix expression. A NUL-based
        # placeholder cannot collide with anything a real file name contains.
        suffix_placeholder = "\x00suffix\x00"
        escaped_pattern = re.escape(filename_pattern.format(suffix=suffix_placeholder))
        existing_files_regex = re.compile(
            escaped_pattern.replace(re.escape(suffix_placeholder), r"(-\d{5}-of-\d{5})?") + r"(\.index\.json)?"
        )
        for filename in os.listdir(save_directory):
            # `fullmatch` so that files merely *starting* like a shard (e.g. "model.safetensors.bak") are kept.
            if existing_files_regex.fullmatch(filename):
                try:
                    logger.debug(f"Removing existing file '{filename}' from folder.")
                    os.remove(os.path.join(save_directory, filename))
                except Exception as e:
                    # Best effort: a leftover file must not prevent saving the new weights.
                    logger.warning(
                        f"Error when trying to remove existing '{filename}' from folder: {e}. Continuing..."
                    )

    # Save each shard
    per_file_metadata = {"format": "pt"}
    if not state_dict_split.is_sharded:
        # Single file: user metadata lives in the file itself (sharded saves put it in the index instead).
        per_file_metadata.update(metadata)
    safe_file_kwargs = {"metadata": per_file_metadata} if safe_serialization else {}
    for filename, tensors in state_dict_split.filename_to_tensors.items():
        shard = {tensor: state_dict[tensor] for tensor in tensors}
        save_file_fn(shard, os.path.join(save_directory, filename), **safe_file_kwargs)  # ty: ignore[invalid-argument-type]
        logger.debug(f"Shard saved to {filename}")

    # Save the index (if any)
    if state_dict_split.is_sharded:
        index_path = filename_pattern.format(suffix="") + ".index.json"
        index = {
            "metadata": {**state_dict_split.metadata, **metadata},
            "weight_map": state_dict_split.tensor_to_filename,
        }
        # Explicit encoding: the index is read back as UTF-8 when loading a sharded checkpoint.
        with open(os.path.join(save_directory, index_path), "w", encoding="utf-8") as f:
            json.dump(index, f, indent=2)
            logger.info(
                f"The model is bigger than the maximum size per checkpoint ({max_shard_size}). "
                f"Model weighs have been saved in {len(state_dict_split.filename_to_tensors)} checkpoint shards. "
                f"You can find where each parameters has been saved in the index located at {index_path}."
            )

    logger.info(f"Model weights successfully saved to {save_directory}!")
def split_torch_state_dict_into_shards(
    state_dict: Dict[str, "torch.Tensor"],
    *,
    filename_pattern: str = constants.SAFETENSORS_WEIGHTS_FILE_PATTERN,
    max_shard_size: Union[int, str] = MAX_SHARD_SIZE,
) -> StateDictSplit:
    """
    Plan how a torch state dictionary is distributed over shards of bounded size.

    This delegates to `split_state_dict_into_shards_factory`, supplying the torch-specific callbacks
    `get_torch_storage_size` and `get_torch_storage_id`.

    Shards are filled by walking `state_dict` in key order; no attempt is made to pack shards as close as possible
    to the limit. For example, with a 10GB limit and tensors of sizes [6GB, 6GB, 2GB, 6GB, 2GB, 2GB], the result is
    [6GB], [6+2GB], [6+2+2GB] and not [6+2+2GB], [6+2GB], [6GB].

    > [!TIP]
    > To write a state dictionary to disk, use [`save_torch_state_dict`], which relies on this helper.

    > [!WARNING]
    > A tensor bigger than `max_shard_size` gets its own shard, whose size is then greater than `max_shard_size`.

    Args:
        state_dict (`Dict[str, torch.Tensor]`):
            The state dictionary to split.
        filename_pattern (`str`, *optional*):
            Template for the shard file names. It must be usable as `filename_pattern.format(suffix=...)` and
            contain the `suffix` keyword. Defaults to `"model{suffix}.safetensors"`.
        max_shard_size (`int` or `str`, *optional*):
            Upper bound for the size of each shard, in bytes. Defaults to 5GB.

    Returns:
        [`StateDictSplit`]: A `StateDictSplit` object containing the shards and the index to retrieve them.

    Example:
    ```py
    >>> import json
    >>> import os
    >>> from safetensors.torch import save_file as safe_save_file
    >>> from huggingface_hub import split_torch_state_dict_into_shards

    >>> def save_state_dict(state_dict: Dict[str, torch.Tensor], save_directory: str):
    ...     state_dict_split = split_torch_state_dict_into_shards(state_dict)
    ...     for filename, tensors in state_dict_split.filename_to_tensors.items():
    ...         shard = {tensor: state_dict[tensor] for tensor in tensors}
    ...         safe_save_file(
    ...             shard,
    ...             os.path.join(save_directory, filename),
    ...             metadata={"format": "pt"},
    ...         )
    ...     if state_dict_split.is_sharded:
    ...         index = {
    ...             "metadata": state_dict_split.metadata,
    ...             "weight_map": state_dict_split.tensor_to_filename,
    ...         }
    ...         with open(os.path.join(save_directory, "model.safetensors.index.json"), "w") as f:
    ...             f.write(json.dumps(index, indent=2))
    ```
    """
    shard_plan = split_state_dict_into_shards_factory(
        state_dict,
        filename_pattern=filename_pattern,
        max_shard_size=max_shard_size,
        get_storage_id=get_torch_storage_id,
        get_storage_size=get_torch_storage_size,
    )
    return shard_plan
- # LOADING
def load_torch_model(
    model: "torch.nn.Module",
    checkpoint_path: Union[str, os.PathLike],
    *,
    strict: bool = False,
    safe: bool = True,
    weights_only: bool = False,
    map_location: Optional[Union[str, "torch.device"]] = None,
    mmap: bool = False,
    filename_pattern: Optional[str] = None,
) -> NamedTuple:
    """
    Load a checkpoint into a model, handling both sharded and non-sharded checkpoints.

    Resolution order:
    1. If `checkpoint_path` is a file, it is loaded directly.
    2. If it is a directory containing an index file, the sharded checkpoint is loaded.
    3. Otherwise the directory must contain exactly one `*.safetensors` file (`safe=True`) or exactly one `*.bin`
       file (`safe=False`), which is loaded.

    Args:
        model (`torch.nn.Module`):
            The model in which to load the checkpoint.
        checkpoint_path (`str` or `os.PathLike`):
            Path to either the checkpoint file or directory containing the checkpoint(s).
        strict (`bool`, *optional*, defaults to `False`):
            Whether to strictly enforce that the keys in the model state dict match the keys in the checkpoint.
        safe (`bool`, *optional*, defaults to `True`):
            Only used when `checkpoint_path` is a directory. The index file derived from `filename_pattern` is
            looked up first. If it is missing and `safe` is `False`, the pickle index file is looked up instead.
            When no index file is found, a single `*.safetensors` file is searched if `safe` is `True`, a single
            `*.bin` file otherwise.
        weights_only (`bool`, *optional*, defaults to `False`):
            If True, only loads the model weights without optimizer states and other metadata.
            Only supported in PyTorch >= 1.13.
        map_location (`str` or `torch.device`, *optional*):
            A `torch.device` object, string or a dict specifying how to remap storage locations. It
            indicates the location where all tensors should be loaded. It is used for single-file checkpoints and
            is not forwarded when loading a sharded checkpoint.
        mmap (`bool`, *optional*, defaults to `False`):
            Whether to use memory-mapped file loading. Memory mapping can improve loading performance
            for large models in PyTorch >= 2.1.0 with zipfile-based checkpoints. It is used for single-file
            checkpoints and is not forwarded when loading a sharded checkpoint.
        filename_pattern (`str`, *optional*):
            The pattern to look for the index file. Pattern must be a string that
            can be formatted with `filename_pattern.format(suffix=...)` and must contain the keyword `suffix`
            Defaults to `"model{suffix}.safetensors"`.

    Returns:
        `NamedTuple`: A named tuple with `missing_keys` and `unexpected_keys` fields.
            - `missing_keys` is a list of str containing the missing keys, i.e. keys that are in the model but not in the checkpoint.
            - `unexpected_keys` is a list of str containing the unexpected keys, i.e. keys that are in the checkpoint but not in the model.

    Raises:
        [`ImportError`](https://docs.python.org/3/library/exceptions.html#ImportError)
            If safetensors or torch is not installed when trying to load a .safetensors file or a PyTorch checkpoint respectively.
        [`ValueError`](https://docs.python.org/3/library/exceptions.html#ValueError)
            If `checkpoint_path` does not exist, or if it is a directory that contains neither an index file nor
            exactly one model file.

    Example:
    ```python
    >>> from huggingface_hub import load_torch_model
    >>> model = ... # A PyTorch model
    >>> load_torch_model(model, "path/to/checkpoint")
    ```
    """
    checkpoint_path = Path(checkpoint_path)

    if not checkpoint_path.exists():
        raise ValueError(f"Checkpoint path {checkpoint_path} does not exist")

    # 1. Check if checkpoint is a single file
    if checkpoint_path.is_file():
        state_dict = load_state_dict_from_file(
            checkpoint_file=checkpoint_path,
            map_location=map_location,
            weights_only=weights_only,
            # Forward `mmap` here too, so a direct file path behaves like a single file found in a directory.
            mmap=mmap,
        )
        return model.load_state_dict(state_dict, strict=strict)

    # 2. If not, checkpoint_path is a directory
    if filename_pattern is None:
        filename_pattern = constants.SAFETENSORS_WEIGHTS_FILE_PATTERN
    index_path = checkpoint_path / (filename_pattern.format(suffix="") + ".index.json")
    # Only fallback to pickle format if safetensors index is not found and safe is False.
    if not index_path.is_file() and not safe:
        filename_pattern = constants.PYTORCH_WEIGHTS_FILE_PATTERN
        index_path = checkpoint_path / (filename_pattern.format(suffix="") + ".index.json")

    if index_path.is_file():
        return _load_sharded_checkpoint(
            model=model,
            save_directory=checkpoint_path,
            strict=strict,
            weights_only=weights_only,
            filename_pattern=filename_pattern,
        )

    # Look for single model file
    model_files = list(checkpoint_path.glob("*.safetensors" if safe else "*.bin"))
    if len(model_files) == 1:
        state_dict = load_state_dict_from_file(
            checkpoint_file=model_files[0],
            map_location=map_location,
            weights_only=weights_only,
            mmap=mmap,
        )
        return model.load_state_dict(state_dict, strict=strict)

    raise ValueError(
        f"Directory '{checkpoint_path}' does not contain a valid checkpoint. "
        "Expected either a sharded checkpoint with an index file, or a single model file."
    )
def _load_sharded_checkpoint(
    model: "torch.nn.Module",
    save_directory: os.PathLike,
    *,
    strict: bool = False,
    weights_only: bool = False,
    filename_pattern: str = constants.SAFETENSORS_WEIGHTS_FILE_PATTERN,
) -> NamedTuple:
    """
    Loads a sharded checkpoint into a model. This is the same as
    [`torch.nn.Module.load_state_dict`](https://pytorch.org/docs/stable/generated/torch.nn.Module.html?highlight=load_state_dict#torch.nn.Module.load_state_dict)
    but for a sharded checkpoint. Each shard is loaded one by one and removed from memory after being loaded into the model.

    Key checking is done once, against the index file: in strict mode a mismatch raises before any shard is read.
    Individual shards are then always applied non-strictly, since a shard only holds a subset of the model keys.

    Args:
        model (`torch.nn.Module`):
            The model in which to load the checkpoint.
        save_directory (`str` or `os.PathLike`):
            A path to a folder containing the sharded checkpoint.
        strict (`bool`, *optional*, defaults to `False`):
            Whether to strictly enforce that the keys in the model state dict match the keys in the sharded checkpoint.
        weights_only (`bool`, *optional*, defaults to `False`):
            If True, only loads the model weights without optimizer states and other metadata.
            Only supported in PyTorch >= 1.13.
        filename_pattern (`str`, *optional*, defaults to `"model{suffix}.safetensors"`):
            The pattern to look for the index file. Pattern must be a string that
            can be formatted with `filename_pattern.format(suffix=...)` and must contain the keyword `suffix`
            Defaults to `"model{suffix}.safetensors"`.

    Returns:
        `NamedTuple`: A named tuple with `missing_keys` and `unexpected_keys` fields, both computed by comparing
        the keys listed in the index file with the keys of `model.state_dict()`.
            - `missing_keys` is a list of str containing the missing keys
            - `unexpected_keys` is a list of str containing the unexpected keys

    Raises:
        `RuntimeError`: In strict mode, if the keys of the index file and of the model differ.
    """

    # 1. Load and validate index file
    # The index file contains mapping of parameter names to shard files
    index_path = filename_pattern.format(suffix="") + ".index.json"
    index_file = os.path.join(save_directory, index_path)
    with open(index_file, "r", encoding="utf-8") as f:
        index = json.load(f)
    weight_map = index["weight_map"]

    # 2. Validate keys if in strict mode
    # This is done before loading any shards to fail fast
    if strict:
        _validate_keys_for_strict_loading(model, weight_map.keys())

    # 3. Load each shard using `load_state_dict`
    # Get unique shard files (multiple parameters can be in same shard); sorted for a deterministic load order
    for shard_file in sorted(set(weight_map.values())):
        # Load shard into memory
        shard_path = os.path.join(save_directory, shard_file)
        state_dict = load_state_dict_from_file(
            shard_path,
            map_location="cpu",
            weights_only=weights_only,
        )
        # Update model with parameters from this shard.
        # Never strict here: a shard lacks the keys stored in the other shards, so a strict per-shard load would
        # always report them as missing. Strictness has already been enforced on the whole index in step 2.
        model.load_state_dict(state_dict, strict=False)
        # Explicitly remove the state dict from memory
        del state_dict

    # 4. Return compatibility info
    loaded_keys = set(weight_map.keys())
    model_keys = set(model.state_dict().keys())
    return _IncompatibleKeys(
        missing_keys=list(model_keys - loaded_keys), unexpected_keys=list(loaded_keys - model_keys)
    )
def load_state_dict_from_file(
    checkpoint_file: Union[str, os.PathLike],
    map_location: Optional[Union[str, "torch.device"]] = None,
    weights_only: bool = False,
    mmap: bool = False,
) -> Union[Dict[str, "torch.Tensor"], Any]:
    """
    Loads a checkpoint file, handling both safetensors and pickle checkpoint formats.

    The format is chosen from the file extension: `.safetensors` files are read with `safetensors`, anything else
    is read with `torch.load`.

    Args:
        checkpoint_file (`str` or `os.PathLike`):
            Path to the checkpoint file to load. Can be either a safetensors or pickle (`.bin`) checkpoint.
        map_location (`str` or `torch.device`, *optional*):
            A `torch.device` object, string or a dict specifying how to remap storage locations. It
            indicates the location where all tensors should be loaded. For safetensors files, a `torch.device`
            is passed on as its string form (e.g. `"cuda:1"`), and the `"meta"` device falls back to `"cpu"`.
        weights_only (`bool`, *optional*, defaults to `False`):
            If True, only loads the model weights without optimizer states and other metadata.
            Only supported for pickle (`.bin`) checkpoints with PyTorch >= 1.13. Has no effect when
            loading safetensors files.
        mmap (`bool`, *optional*, defaults to `False`):
            Whether to use memory-mapped file loading. Memory mapping can improve loading performance
            for large models in PyTorch >= 2.1.0 with zipfile-based checkpoints. Has no effect when
            loading safetensors files, as the `safetensors` library uses memory mapping by default.

    Returns:
        `Union[Dict[str, "torch.Tensor"], Any]`: The loaded checkpoint.
            - For safetensors files: always returns a dictionary mapping parameter names to tensors.
            - For pickle files: returns any Python object that was pickled (commonly a state dict, but could be
              an entire model, optimizer state, or any other Python object).

    Raises:
        [`FileNotFoundError`](https://docs.python.org/3/library/exceptions.html#FileNotFoundError)
            If `checkpoint_file` does not exist or is not a regular file.
        [`ImportError`](https://docs.python.org/3/library/exceptions.html#ImportError)
            If safetensors or torch is not installed when trying to load a .safetensors file or a PyTorch checkpoint respectively.
        [`OSError`](https://docs.python.org/3/library/exceptions.html#OSError)
            If a safetensors file carries metadata whose `format` entry is neither `"pt"` nor `"mlx"`.

    Example:
    ```python
    >>> from huggingface_hub import load_state_dict_from_file

    # Load a PyTorch checkpoint
    >>> state_dict = load_state_dict_from_file("path/to/model.bin", map_location="cpu")
    >>> model.load_state_dict(state_dict)

    # Load a safetensors checkpoint
    >>> state_dict = load_state_dict_from_file("path/to/model.safetensors")
    >>> model.load_state_dict(state_dict)
    ```
    """
    checkpoint_path = Path(checkpoint_file)

    # Check if file exists and is a regular file (not a directory)
    if not checkpoint_path.is_file():
        raise FileNotFoundError(
            f"No checkpoint file found at '{checkpoint_path}'. Please verify the path is correct and "
            "the file has been properly downloaded."
        )

    # Load safetensors checkpoint
    if checkpoint_path.suffix == ".safetensors":
        try:
            from safetensors import safe_open
            from safetensors.torch import load_file
        except ImportError as e:
            raise ImportError(
                "Please install `safetensors` to load safetensors checkpoint. "
                "You can install it with `pip install safetensors`."
            ) from e

        # Check format of the archive
        with safe_open(checkpoint_file, framework="pt") as f:  # type: ignore[attr-defined]
            metadata = f.metadata()
        # see comment: https://github.com/huggingface/transformers/blob/3d213b57fe74302e5902d68ed9478c3ad1aaa713/src/transformers/modeling_utils.py#L3966
        if metadata is not None and metadata.get("format") not in ["pt", "mlx"]:
            raise OSError(
                f"The safetensors archive passed at {checkpoint_file} does not contain the valid metadata. Make sure "
                "you save your model with the `save_torch_model` method."
            )

        # Objects exposing `.type` are treated as `torch.device`. The device *type* is only used to detect the
        # meta device; the full string form is what gets passed on, so that an index ("cuda:1") is not lost.
        is_device_object = map_location is not None and hasattr(map_location, "type")
        device_type = str(map_location.type) if is_device_object else map_location
        device = str(map_location) if is_device_object else map_location
        # meta device is not supported with safetensors, falling back to CPU
        if device_type == "meta":
            logger.warning("Meta device is not supported with safetensors. Falling back to CPU device.")
            device = "cpu"
        return load_file(checkpoint_file, device=device)  # type: ignore[arg-type]

    # Otherwise, load from pickle
    try:
        import torch
        from torch import load
    except ImportError as e:
        raise ImportError(
            "Please install `torch` to load torch tensors. You can install it with `pip install torch`."
        ) from e

    # Optional kwargs are only passed when the installed torch version knows them.
    torch_version = version.parse(torch.__version__)
    additional_kwargs = {}
    # mmap is only supported in torch >= 2.1.0
    if torch_version >= version.parse("2.1.0"):
        additional_kwargs["mmap"] = mmap
    # weights_only is only supported in torch >= 1.13.0
    if torch_version >= version.parse("1.13.0"):
        additional_kwargs["weights_only"] = weights_only

    # NOTE(security): this unpickles the file. With `weights_only=False` (the default here), loading a checkpoint
    # from an untrusted source can execute arbitrary code; prefer safetensors or `weights_only=True`.
    return load(
        checkpoint_file,
        map_location=map_location,
        **additional_kwargs,
    )
- # HELPERS
def _validate_keys_for_strict_loading(
    model: "torch.nn.Module",
    loaded_keys: Iterable[str],
) -> None:
    """
    Check that the checkpoint keys and the model keys are exactly the same set.

    Args:
        model: The PyTorch model being loaded; its keys are taken from `model.state_dict()`.
        loaded_keys: The keys present in the checkpoint.

    Raises:
        RuntimeError: If at least one key is missing from the checkpoint or unexpected for the model. The message
            names the model class and lists the offending keys in sorted order.
    """
    checkpoint_keys = set(loaded_keys)
    expected_keys = set(model.state_dict().keys())

    absent = expected_keys - checkpoint_keys  # in the model, not in the checkpoint
    extra = checkpoint_keys - expected_keys  # in the checkpoint, not in the model

    # Guard clause: nothing to report when both sets agree.
    if not absent and not extra:
        return

    def _quoted(keys: Iterable[str]) -> str:
        # Render keys as a comma-separated list of double-quoted names, sorted for stable output.
        return ",".join(f'"{key}"' for key in sorted(keys))

    message_parts = [f"Error(s) in loading state_dict for {model.__class__.__name__}"]
    if absent:
        message_parts.append(f"\nMissing key(s): {_quoted(absent)}.")
    if extra:
        message_parts.append(f"\nUnexpected key(s): {_quoted(extra)}.")
    raise RuntimeError("".join(message_parts))
def _get_unique_id(tensor: "torch.Tensor") -> Union[int, Tuple[Any, ...]]:
    """Return an identifier for the memory backing `tensor`.

    For a plain tensor this is a single integer. For a traceable wrapper tensor
    subclass it is a (potentially nested) tuple holding the identifier of every
    inner tensor listed by `__tensor_flatten__`.

    Args:
        tensor: The tensor to identify.

    Returns:
        - `DTensor`: the storage pointer of the local tensor (`tensor.to_local()`).
        - traceable wrapper subclass: a tuple with the id of each inner tensor.
        - XLA tensor, when `is_torch_tpu_available()` is true: the id reported by `torch_xla`.
        - anything else: `storage_ptr(tensor)`.
    """
    try:
        from torch.distributed.tensor import DTensor
    except ImportError:
        # `torch.distributed.tensor` is not importable with this torch install.
        pass
    else:
        if isinstance(tensor, DTensor):
            # Go through `storage_ptr` rather than `Tensor.storage()`: the latter is
            # deprecated (it warns on recent torch), while `storage_ptr` uses
            # `untyped_storage()` and keeps the legacy accessor as a fallback.
            return storage_ptr(tensor.to_local())

    try:
        # for torch 2.1 and above we can also handle tensor subclasses
        from torch.utils._python_dispatch import is_traceable_wrapper_subclass
    except ImportError:
        # for torch version less than 2.1, we can fallback to original implementation
        pass
    else:
        if is_traceable_wrapper_subclass(tensor):
            inner_names, _ = tensor.__tensor_flatten__()  # type: ignore[attr-defined]
            return tuple(_get_unique_id(getattr(tensor, inner_name)) for inner_name in inner_names)

    if tensor.device.type == "xla" and is_torch_tpu_available():
        # NOTE: xla tensors dont have storage
        # use some other unique id to distinguish.
        # this is a XLA tensor, it must be created using torch_xla's
        # device. So the following import is safe:
        import torch_xla  # type: ignore[import]

        return torch_xla._XLAC._xla_get_tensor_id(tensor)

    return storage_ptr(tensor)
def get_torch_storage_id(tensor: "torch.Tensor") -> Optional[Tuple["torch.device", Union[int, Tuple[Any, ...]], int]]:
    """
    Return unique identifier to a tensor storage.

    Multiple different tensors can share the same underlying storage. This identifier is
    guaranteed to be unique and constant for this tensor's storage during its lifetime. Two tensor storages with
    non-overlapping lifetimes may have the same id.
    In the case of meta tensors, we return None since we can't tell if they share the same storage.

    Taken from https://github.com/huggingface/transformers/blob/1ecf5f7c982d761b4daaa96719d162c324187c64/src/transformers/pytorch_utils.py#L278.

    Args:
        tensor: The tensor whose storage should be identified.

    Returns:
        `None` for a tensor on the meta device, otherwise the triple
        `(device, unique id, storage size in bytes)`.
    """
    device = tensor.device
    if device.type == "meta":
        return None

    unique_id = _get_unique_id(tensor)
    storage_size = get_torch_storage_size(tensor)
    return device, unique_id, storage_size
def get_torch_storage_size(tensor: "torch.Tensor") -> int:
    """
    Return the size, in bytes, of the storage backing `tensor`.

    Taken from https://github.com/huggingface/safetensors/blob/08db34094e9e59e2f9218f2df133b7b4aaff5a99/bindings/python/py_src/safetensors/torch.py#L31C1-L41C59

    Args:
        tensor: The tensor to measure.

    Returns:
        - `DTensor`: `tensor.nbytes`.
        - traceable wrapper subclass: the sum of the sizes of its inner tensors.
        - otherwise: the byte size of the underlying storage, which can be larger than
          the tensor itself when the tensor is a view.
    """
    try:
        from torch.distributed.tensor import DTensor

        if isinstance(tensor, DTensor):
            # this returns the size of the FULL tensor in bytes
            return tensor.nbytes
    except ImportError:
        pass

    try:
        # for torch 2.1 and above we can also handle tensor subclasses
        from torch.utils._python_dispatch import is_traceable_wrapper_subclass

        if is_traceable_wrapper_subclass(tensor):
            inner_names, _ = tensor.__tensor_flatten__()  # type: ignore[attr-defined]
            total_bytes = 0
            for inner_name in inner_names:
                total_bytes += get_torch_storage_size(getattr(tensor, inner_name))
            return total_bytes
    except ImportError:
        # for torch version less than 2.1, we can fallback to original implementation
        pass

    try:
        return tensor.untyped_storage().nbytes()
    except AttributeError:
        # `untyped_storage` is missing: use the fallbacks below.
        pass

    # Fallback for torch==1.10
    try:
        return tensor.storage().size() * _get_dtype_size(tensor.dtype)
    except NotImplementedError:
        # Fallback for meta storage
        # On torch >=2.0 this is the tensor size
        return tensor.nelement() * _get_dtype_size(tensor.dtype)
@lru_cache()
def is_torch_tpu_available(check_device=True):
    """
    Checks if `torch_xla` is installed and potentially if a TPU is in the environment

    Taken from https://github.com/huggingface/transformers/blob/1ecf5f7c982d761b4daaa96719d162c324187c64/src/transformers/utils/import_utils.py#L463.

    Args:
        check_device: When true, also try to obtain an XLA device and report `False`
            if that raises a `RuntimeError`. When false, only the presence of the
            `torch_xla` package is checked.

    Returns:
        `True` if the checks pass, `False` otherwise. The result is memoized per
        `check_device` value by `lru_cache`.
    """
    if importlib.util.find_spec("torch_xla") is None:
        return False
    if not check_device:
        return True

    # We need to check if `xla_device` can be found, will raise a RuntimeError if not
    try:
        import torch_xla.core.xla_model as xm  # type: ignore[import]

        _ = xm.xla_device()
    except RuntimeError:
        return False
    return True
def storage_ptr(tensor: "torch.Tensor") -> Union[int, Tuple[Any, ...]]:
    """
    Return the address of the storage backing `tensor`.

    Taken from https://github.com/huggingface/safetensors/blob/079781fd0dc455ba0fe851e2b4507c33d0c0d407/bindings/python/py_src/safetensors/torch.py#L11.

    Args:
        tensor: The tensor to inspect.

    Returns:
        The storage data pointer as an integer, the result of `_get_unique_id(tensor)`
        for a traceable wrapper subclass, or `0` when the legacy storage accessor
        raises `NotImplementedError`.
    """
    try:
        # for torch 2.1 and above we can also handle tensor subclasses
        from torch.utils._python_dispatch import is_traceable_wrapper_subclass

        if is_traceable_wrapper_subclass(tensor):
            return _get_unique_id(tensor)  # type: ignore
    except ImportError:
        # for torch version less than 2.1, we can fallback to original implementation
        pass

    try:
        return tensor.untyped_storage().data_ptr()
    except Exception:
        # Any failure here falls through to the legacy accessor below.
        pass

    # Fallback for torch==1.10
    try:
        return tensor.storage().data_ptr()
    except NotImplementedError:
        # Fallback for meta storage
        return 0
def _clean_state_dict_for_safetensors(
    state_dict: Dict[str, "torch.Tensor"],
    metadata: Dict[str, str],
    force_contiguous: bool = True,
    shared_tensors_to_discard: Optional[List[str]] = None,
):
    """Remove shared tensors from state_dict and update metadata accordingly (for reloading).

    Warning: `state_dict` and `metadata` are mutated in-place!

    Taken from https://github.com/huggingface/safetensors/blob/079781fd0dc455ba0fe851e2b4507c33d0c0d407/bindings/python/py_src/safetensors/torch.py#L155.

    Args:
        state_dict: Mapping from tensor name to tensor. Duplicate names are deleted from it.
        metadata: Mapping that receives `removed name -> kept name` entries. Existing
            entries are never overwritten. If `None` is passed, the entries are written to
            a throwaway dict that is not returned.
        force_contiguous: When true, return a new dict in which every remaining tensor has
            been passed through `.contiguous()`.
        shared_tensors_to_discard: Forwarded to `_remove_duplicate_names` as `discard_names`.

    Returns:
        The cleaned state dict: `state_dict` itself when `force_contiguous` is false,
        otherwise a new dict built from it.
    """
    if metadata is None:
        metadata = {}

    duplicates = _remove_duplicate_names(state_dict, discard_names=shared_tensors_to_discard)
    for kept_name, dropped_names in duplicates.items():
        for dropped_name in dropped_names:
            # Do not override user data
            metadata.setdefault(dropped_name, kept_name)
            del state_dict[dropped_name]

    if not force_contiguous:
        return state_dict
    return {name: tensor.contiguous() for name, tensor in state_dict.items()}
def _end_ptr(tensor: "torch.Tensor") -> int:
    """
    Return the address one past the last element of `tensor`.

    Taken from https://github.com/huggingface/safetensors/blob/079781fd0dc455ba0fe851e2b4507c33d0c0d407/bindings/python/py_src/safetensors/torch.py#L23.

    Args:
        tensor: The tensor to inspect. It is flattened with `view(-1)`.

    Returns:
        For a non-empty tensor, the data pointer of its last element plus the size of
        one element. For an empty tensor, the tensor's own data pointer.
    """
    if not tensor.nelement():
        return tensor.data_ptr()

    last_element = tensor.view(-1)[-1]
    return last_element.data_ptr() + _get_dtype_size(tensor.dtype)
def _filter_shared_not_shared(tensors: List[Set[str]], state_dict: Dict[str, "torch.Tensor"]) -> List[Set[str]]:
    """
    Split groups of names that share a storage into groups that really overlap in memory.

    Taken from https://github.com/huggingface/safetensors/blob/079781fd0dc455ba0fe851e2b4507c33d0c0d407/bindings/python/py_src/safetensors/torch.py#L44

    Each tensor is seen as the byte interval `[data_ptr, _end_ptr)`. Within a group,
    intervals are sorted by start address and merged whenever they overlap; tensors
    whose intervals are disjoint end up in separate groups.

    Args:
        tensors: Groups of tensor names. Groups with fewer than two names are passed through
            unchanged.
        state_dict: Mapping from tensor name to tensor, used to look up the addresses.

    Returns:
        A list of name sets, one per cluster of overlapping tensors.
    """
    filtered_tensors: List[Set[str]] = []
    for shared in tensors:
        if len(shared) < 2:
            filtered_tensors.append(shared)
            continue

        areas = []
        for name in shared:
            tensor = state_dict[name]
            areas.append((tensor.data_ptr(), _end_ptr(tensor), name))
        areas.sort()

        _, last_stop, first_name = areas[0]
        filtered_tensors.append({first_name})
        for start, stop, name in areas[1:]:
            if start >= last_stop:
                filtered_tensors.append({name})
            else:
                filtered_tensors[-1].add(name)
            # Keep the furthest end seen so far. Assigning `stop` directly would shrink
            # the running interval when a small view is nested inside a larger tensor,
            # and a later view of that larger tensor would wrongly start a new group.
            last_stop = max(last_stop, stop)

    return filtered_tensors
def _find_shared_tensors(state_dict: Dict[str, "torch.Tensor"]) -> List[Set[str]]:
    """
    Group the names of `state_dict` by the storage their tensors live in.

    Taken from https://github.com/huggingface/safetensors/blob/079781fd0dc455ba0fe851e2b4507c33d0c0d407/bindings/python/py_src/safetensors/torch.py#L69.

    Tensors on the meta device, with a storage pointer equal to `0`, or with a storage
    size of `0` are ignored and do not appear in the result.

    Args:
        state_dict: Mapping from tensor name to tensor.

    Returns:
        The groups of names, after `_filter_shared_not_shared` has split the groups whose
        tensors share a storage without overlapping.
    """
    import torch

    meta_device = torch.device("meta")
    names_by_storage = defaultdict(set)
    for name, tensor in state_dict.items():
        if tensor.device == meta_device:
            continue
        pointer = storage_ptr(tensor)
        if pointer == 0:
            continue
        size = get_torch_storage_size(tensor)
        if size == 0:
            continue
        # Need to add device as key because of multiple GPU.
        names_by_storage[(tensor.device, pointer, size)].add(name)

    groups = sorted(names_by_storage.values())
    return _filter_shared_not_shared(groups, state_dict)
def _is_complete(tensor: "torch.Tensor") -> bool:
    """
    Tell whether `tensor` covers its whole storage.

    Taken from https://github.com/huggingface/safetensors/blob/079781fd0dc455ba0fe851e2b4507c33d0c0d407/bindings/python/py_src/safetensors/torch.py#L80

    Args:
        tensor: The tensor to inspect.

    Returns:
        `True` when the tensor starts at the beginning of its storage and its byte size
        (`nelement * dtype size`) equals the storage size. For a traceable wrapper
        subclass, `True` only if every inner tensor is complete.
    """
    try:
        # for torch 2.1 and above we can also handle tensor subclasses
        from torch.utils._python_dispatch import is_traceable_wrapper_subclass

        if is_traceable_wrapper_subclass(tensor):
            inner_names, _ = tensor.__tensor_flatten__()  # type: ignore[attr-defined]
            for inner_name in inner_names:
                if not _is_complete(getattr(tensor, inner_name)):
                    return False
            return True
    except ImportError:
        # for torch version less than 2.1, we can fallback to original implementation
        pass

    if tensor.data_ptr() != storage_ptr(tensor):
        return False
    tensor_bytes = tensor.nelement() * _get_dtype_size(tensor.dtype)
    return tensor_bytes == get_torch_storage_size(tensor)
def _remove_duplicate_names(
    state_dict: Dict[str, "torch.Tensor"],
    *,
    preferred_names: Optional[List[str]] = None,
    discard_names: Optional[List[str]] = None,
) -> Dict[str, List[str]]:
    """
    Decide, for every group of tensors sharing memory, which name is kept and which are dropped.

    Taken from https://github.com/huggingface/safetensors/blob/079781fd0dc455ba0fe851e2b4507c33d0c0d407/bindings/python/py_src/safetensors/torch.py#L80

    Only names whose tensor covers its entire storage (`_is_complete`) can be kept. Among
    those, the kept name is the alphabetically first one taken from, by decreasing priority:
    the names listed in `preferred_names`, the names not listed in `discard_names`, and
    finally all complete names.

    Args:
        state_dict: Mapping from tensor name to tensor. It is not modified.
        preferred_names: Names to keep in priority when they are complete.
        discard_names: Names to avoid keeping when another complete name exists.

    Returns:
        Mapping from each kept name to the sorted list of names to remove in its favour.
        Groups with nothing to remove do not appear.

    Raises:
        RuntimeError: If a group contains no tensor covering the entire storage.
    """
    unique_preferred_names = set() if preferred_names is None else set(preferred_names)
    unique_discard_names = set() if discard_names is None else set(discard_names)

    to_remove = defaultdict(list)
    for shared in _find_shared_tensors(state_dict):
        complete_names = {name for name in shared if _is_complete(state_dict[name])}
        if not complete_names:
            raise RuntimeError(
                "Error while trying to find names to remove to save state dict, but found no suitable name to keep"
                f" for saving amongst: {shared}. None is covering the entire storage. Refusing to save/load the model"
                " since you could be storing much more memory than needed. Please refer to"
                " https://huggingface.co/docs/safetensors/torch_shared_tensors for more information. Or open an"
                " issue."
            )

        # Mechanism to preferentially select keys to keep
        # coming from the on-disk file to allow
        # loading models saved with a different choice
        # of keep_name
        candidates = complete_names
        not_discarded = complete_names - unique_discard_names
        if not_discarded:
            candidates = not_discarded
        explicitly_preferred = unique_preferred_names & complete_names
        if explicitly_preferred:
            candidates = explicitly_preferred
        keep_name = min(candidates)

        for name in sorted(shared):
            if name == keep_name:
                continue
            to_remove[keep_name].append(name)
    return to_remove
@lru_cache()
def _get_dtype_size(dtype: "torch.dtype") -> int:
    """
    Return the size, in bytes, of one element of `dtype`.

    Taken from https://github.com/huggingface/safetensors/blob/08db34094e9e59e2f9218f2df133b7b4aaff5a99/bindings/python/py_src/safetensors/torch.py#L344

    Common dtypes are resolved from a fixed table. Any other `torch.dtype` (for example
    the complex or wider unsigned integer dtypes) is resolved from the dtype itself
    instead of failing with a `KeyError`.

    Args:
        dtype: The torch dtype to measure.

    Returns:
        The number of bytes taken by one element.

    Raises:
        KeyError: If `dtype` is not a `torch.dtype`.
    """
    import torch

    known_sizes = {
        torch.int64: 8,
        torch.float32: 4,
        torch.int32: 4,
        torch.bfloat16: 2,
        torch.float16: 2,
        torch.int16: 2,
        torch.uint8: 1,
        torch.int8: 1,
        torch.bool: 1,
        torch.float64: 8,
    }
    # torch.float8 formats require 2.1; register them only when this torch provides them
    # so that no spurious `None` key ends up in the table.
    for float8_name in ("float8_e4m3fn", "float8_e5m2"):
        float8_dtype = getattr(torch, float8_name, None)
        if float8_dtype is not None:
            known_sizes[float8_dtype] = 1

    try:
        return known_sizes[dtype]
    except KeyError:
        if not isinstance(dtype, torch.dtype):
            raise

    # `torch.dtype.itemsize` exists on torch >= 2.1; older versions go through a scalar tensor.
    itemsize = getattr(dtype, "itemsize", None)
    if itemsize is not None:
        return itemsize
    return torch.empty((), dtype=dtype).element_size()
- class _IncompatibleKeys(namedtuple("IncompatibleKeys", ["missing_keys", "unexpected_keys"])):
- """
- This is used to report missing and unexpected keys in the state dict.
- Taken from https://github.com/pytorch/pytorch/blob/main/torch/nn/modules/module.py#L52.
- """
- def __repr__(self) -> str:
- if not self.missing_keys and not self.unexpected_keys:
- return "<All keys matched successfully>"
- return super().__repr__()
- __str__ = __repr__
|