# Copyright 2025 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Since https://github.com/huggingface/transformers/pull/36963, models are always loaded on the meta device. However, the
`init_empty_weights` and `find_tied_parameters` functions come from accelerate, which is still only a soft dependency,
so we copy them here to use them natively in Transformers.

The `init_empty_weights` and `init_on_device` functions were copied from `accelerate/big_modeling.py`, and
`find_tied_parameters` was copied from `accelerate/utils/modeling.py`.
"""

from contextlib import contextmanager

from ..utils import is_torch_available, logging


if is_torch_available():
    import torch
    import torch.nn as nn


logger = logging.get_logger(__name__)


@contextmanager
def init_empty_weights(include_buffers: bool = False):
    """
    A context manager under which models are initialized with all parameters on the meta device, therefore creating an
    empty model. Useful when just initializing the model would exceed the available RAM.

    Args:
        include_buffers (`bool`, *optional*, defaults to `False`):
            Whether or not to also put all buffers on the meta device while initializing.

    Example:

    ```python
    import torch.nn as nn
    from accelerate import init_empty_weights

    # Initialize a model with 100 billion parameters in no time and without allocating memory for its weights.
    with init_empty_weights():
        tst = nn.Sequential(*[nn.Linear(10000, 10000) for _ in range(1000)])
    ```

    <Tip warning={true}>

    Any model created under this context manager has no weights: its parameters live on the meta device and hold no
    data. As such you can't do something like `model.to(some_device)` with it. To load weights inside your empty model,
    see [`load_checkpoint_and_dispatch`]. Make sure to override the default `device_map` argument of
    [`load_checkpoint_and_dispatch`], otherwise dispatch is not called.

    </Tip>
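
    For illustration, here is a minimal sketch in plain PyTorch. It assumes PyTorch 2.0 or later, where a `torch.device`
    can be used as a context manager (the same mechanism `init_on_device` relies on when `include_buffers=True`), and
    the layer sizes are arbitrary. A tensor on the meta device has a shape and dtype but no storage. `.to()` therefore
    has nothing to copy and raises. `nn.Module.to_empty` instead allocates new, uninitialized storage on a real device,
    into which weights from a checkpoint can then be loaded.

    ```python
    import torch
    import torch.nn as nn

    # Build a layer directly on the meta device: no memory is allocated for its weights.
    with torch.device("meta"):
        layer = nn.Linear(1024, 1024)

    assert layer.weight.is_meta
    print(layer.weight.shape)  # torch.Size([1024, 1024])

    # A meta tensor holds no data, so copying it to a real device raises.
    try:
        layer.to("cpu")
    except NotImplementedError as err:
        print(f"cannot copy: {err}")

    # `to_empty` allocates uninitialized storage on the target device instead; the values stay arbitrary until real
    # weights are loaded into them.
    layer = layer.to_empty(device="cpu")
    assert layer.weight.device.type == "cpu"
    ```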
    """
    with init_on_device(torch.device("meta"), include_buffers=include_buffers) as f:
        yield f


@contextmanager
def init_on_device(device: "torch.device", include_buffers: bool = False):
    """
    A context manager under which models are initialized with all parameters on the specified device.

    Args:
        device (`torch.device`):
            Device to initialize all parameters on.
        include_buffers (`bool`, *optional*, defaults to `False`):
            Whether or not to also put all buffers on the specified device while initializing.

    Example:

    ```python
    import torch
    import torch.nn as nn
    from accelerate import init_on_device

    with init_on_device(device=torch.device("cuda")):
        tst = nn.Linear(100, 100)  # on `cuda` device
    ```
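
    The default path (`include_buffers=False`) can be sketched in a few lines. The sketch below is a simplified,
    hypothetical re-implementation, and `params_on_device` is not part of this module. It patches
    `nn.Module.register_parameter` so that every parameter moves to `device` as soon as it is registered, leaves buffers
    untouched, and restores the original method on exit. Unlike the function above, it does not preserve `nn.Parameter`
    subclasses or their extra attributes.

    ```python
    from contextlib import contextmanager

    import torch
    import torch.nn as nn


    @contextmanager
    def params_on_device(device):
        # Hypothetical, simplified helper: only parameters are redirected, buffers keep their default device.
        original_register_parameter = nn.Module.register_parameter

        def register_on_device(module, name, param):
            original_register_parameter(module, name, param)
            if param is not None:
                # Swap in a copy of the freshly registered parameter that lives on `device`.
                module._parameters[name] = nn.Parameter(param.to(device), requires_grad=param.requires_grad)

        nn.Module.register_parameter = register_on_device
        try:
            yield
        finally:
            # Restore the original method even if model construction fails.
            nn.Module.register_parameter = original_register_parameter


    with params_on_device(torch.device("meta")):
        bn = nn.BatchNorm1d(8)

    print(bn.weight.device)  # meta
    print(bn.running_mean.device)  # cpu: buffers are not patched
    ```

    The patch is applied to the `nn.Module` class itself. It therefore catches parameters registered by nested
    submodules as well, and the `finally` clause restores the class even if model construction raises.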
    """
    if include_buffers:
        # A `torch.device` can be used as a context manager (PyTorch 2.0+): every tensor created inside the block
        # without an explicit device, parameters and buffers alike, is allocated on `device`.
        with device:
            yield
        return

    # Past this point `include_buffers` is always False, so the `include_buffers` branches below are never taken; they
    # are kept as-is from the accelerate implementation.
    old_register_parameter = nn.Module.register_parameter
    if include_buffers:
        old_register_buffer = nn.Module.register_buffer

    def register_empty_parameter(module, name, param):
        # Register the parameter normally, then re-create it on `device` with the same class, passing its instance
        # attributes and its `requires_grad` flag back to the constructor.
        old_register_parameter(module, name, param)
        if param is not None:
            param_cls = type(module._parameters[name])
            kwargs = module._parameters[name].__dict__
            kwargs["requires_grad"] = param.requires_grad
            module._parameters[name] = param_cls(module._parameters[name].to(device), **kwargs)

    def register_empty_buffer(module, name, buffer, persistent=True):
        old_register_buffer(module, name, buffer, persistent=persistent)
        if buffer is not None:
            module._buffers[name] = module._buffers[name].to(device)

    # Patch tensor creation
    if include_buffers:
        tensor_constructors_to_patch = {
            torch_function_name: getattr(torch, torch_function_name)
            for torch_function_name in ["empty", "zeros", "ones", "full"]
        }
    else:
        tensor_constructors_to_patch = {}

    def patch_tensor_constructor(fn):
        def wrapper(*args, **kwargs):
            kwargs["device"] = device
            return fn(*args, **kwargs)

        return wrapper

    try:
        nn.Module.register_parameter = register_empty_parameter
        if include_buffers:
            nn.Module.register_buffer = register_empty_buffer
        for torch_function_name in tensor_constructors_to_patch:
            setattr(torch, torch_function_name, patch_tensor_constructor(getattr(torch, torch_function_name)))
        yield
    finally:
        # Always restore the original methods and functions, even if model construction raised.
        nn.Module.register_parameter = old_register_parameter
        if include_buffers:
            nn.Module.register_buffer = old_register_buffer
        for torch_function_name, old_torch_function in tensor_constructors_to_patch.items():
            setattr(torch, torch_function_name, old_torch_function)


def find_tied_parameters(model: "nn.Module", **kwargs):
    """
    Find the tied parameters in a given model.

    <Tip warning={true}>

    The signature accepts keyword arguments, but they are not used by this implementation and you should ignore them.

    </Tip>

    Args:
        model (`torch.nn.Module`): The model to inspect.

    Returns:
        list[list[str]]: A list of groups, each group being the sorted names of parameters that are tied together.

    Example:

    ```py
    >>> from collections import OrderedDict
    >>> import torch.nn as nn

    >>> model = nn.Sequential(OrderedDict([("linear1", nn.Linear(4, 4)), ("linear2", nn.Linear(4, 4))]))
    >>> model.linear2.weight = model.linear1.weight
    >>> find_tied_parameters(model)
    [['linear1.weight', 'linear2.weight']]
    ```
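
    The function relies on `nn.Module.named_parameters` listing a tied parameter under each of its names only when
    `remove_duplicate=False`. The short illustration below shows that difference on a hypothetical three-layer model in
    which all weights are tied. It assumes, as the function itself does, a PyTorch version whose `named_parameters`
    accepts `remove_duplicate`.

    ```python
    from collections import OrderedDict

    import torch.nn as nn

    model = nn.Sequential(OrderedDict([("a", nn.Linear(4, 4)), ("b", nn.Linear(4, 4)), ("c", nn.Linear(4, 4))]))
    # Tie all three weights to the same Parameter object.
    model.b.weight = model.a.weight
    model.c.weight = model.a.weight

    all_names = [name for name, _ in model.named_parameters(remove_duplicate=False)]
    unique_names = [name for name, _ in model.named_parameters(remove_duplicate=True)]

    # The extra names are the duplicates; both point to the parameter first seen as "a.weight".
    duplicates = sorted(set(all_names) - set(unique_names))
    print(duplicates)  # ['b.weight', 'c.weight']
    ```

    On this model, `find_tied_parameters(model)` groups both duplicate names with `a.weight` and returns
    `[['a.weight', 'b.weight', 'c.weight']]`.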
    """
    # Get ALL model parameters and their names; a tied parameter is listed once per name.
    all_named_parameters = dict(model.named_parameters(remove_duplicate=False))

    # Get ONLY unique named parameters: if a parameter is tied and has multiple names, it is included only once,
    # under the first name encountered.
    no_duplicate_named_parameters = dict(model.named_parameters(remove_duplicate=True))

    # The difference of the two sets gives the extra names under which a tied parameter appears again.
    tied_param_names = set(all_named_parameters.keys()) - set(no_duplicate_named_parameters.keys())

    # `tied_param_names` only holds those extra names, so we do not yet know which parameter each one refers to.
    # To find out, we group each of them with the unique name of the same parameter object.
    tied_param_groups = {}
    for tied_param_name in tied_param_names:
        tied_param = all_named_parameters[tied_param_name]
        for param_name, param in no_duplicate_named_parameters.items():
            # Parameters are tied if both names point to the very same object.
            if param is tied_param:
                if param_name not in tied_param_groups:
                    tied_param_groups[param_name] = []
                tied_param_groups[param_name].append(tied_param_name)

    return [sorted([weight] + list(set(tied))) for weight, tied in tied_param_groups.items()]