| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693694695696697698699700701702
703704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776 |
- # Copyright 2022 The HuggingFace Team. All rights reserved.
- #
- # Licensed under the Apache License, Version 2.0 (the "License");
- # you may not use this file except in compliance with the License.
- # You may obtain a copy of the License at
- #
- # http://www.apache.org/licenses/LICENSE-2.0
- #
- # Unless required by applicable law or agreed to in writing, software
- # distributed under the License is distributed on an "AS IS" BASIS,
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- # See the License for the specific language governing permissions and
- # limitations under the License.
- import functools
- from collections.abc import Mapping
- from typing import Optional, Union
- import torch
- import torch.nn as nn
- from .state import PartialState
- from .utils import (
- PrefixedDataset,
- find_device,
- named_module_tensors,
- send_to_device,
- set_module_tensor_to_device,
- )
- from .utils.imports import (
- is_mlu_available,
- is_musa_available,
- is_npu_available,
- )
- from .utils.memory import clear_device_cache
- from .utils.modeling import get_non_persistent_buffers
- from .utils.other import recursive_getattr
# Instance-level attribute names that `remove_hook_from_module` pops from a module's `__dict__`
# (per the comment there: the warning hooks added by `dispatch_model`).
_accelerate_added_attributes = [
    "to",
    "cuda",
    "npu",
    "xpu",
    "mlu",
    "sdaa",
    "musa",
]
class ModelHook:
    """
    Base class for hooks whose callbacks run right before and right after the forward method of a model. Unlike
    PyTorch's existing hooks, these callbacks are also handed the keyword arguments of the call.

    Every callback of this base class is a pass-through: it returns what it was given, unchanged.

    Class attribute:
    - **no_grad** (`bool`, *optional*, defaults to `False`) -- Whether or not the actual forward pass should run
      under the `torch.no_grad()` context manager.
    """

    no_grad = False

    def init_hook(self, module):
        """
        Called when the hook gets attached to a module.

        Args:
            module (`torch.nn.Module`): The module this hook is being attached to.

        Returns:
            `torch.nn.Module`: The module to use from now on (here: the same one).
        """
        return module

    def pre_forward(self, module, *args, **kwargs):
        """
        Called right before the forward method of the model runs.

        Args:
            module (`torch.nn.Module`): The module whose forward pass is about to be executed.
            args (`Tuple[Any]`): The positional arguments passed to the module.
            kwargs (`Dict[Str, Any]`): The keyword arguments passed to the module.

        Returns:
            `Tuple[Tuple[Any], Dict[Str, Any]]`: The (possibly processed) `args` and `kwargs` as a pair.
        """
        return args, kwargs

    def post_forward(self, module, output):
        """
        Called right after the forward method of the model has run.

        Args:
            module (`torch.nn.Module`): The module whose forward pass has just been executed.
            output (`Any`): What the module returned.

        Returns:
            `Any`: The (possibly processed) `output`.
        """
        return output

    def detach_hook(self, module):
        """
        Called when the hook gets detached from a module.

        Args:
            module (`torch.nn.Module`): The module this hook is being detached from.

        Returns:
            `torch.nn.Module`: The module (here: the same one).
        """
        return module
class SequentialHook(ModelHook):
    """
    A composite hook: it holds several hooks and, on every event, calls each of them in the order they were given,
    feeding the result of one into the next.
    """

    def __init__(self, *hooks):
        # Kept as the tuple of hooks received; iteration order is call order.
        self.hooks = hooks

    def init_hook(self, module):
        for inner in self.hooks:
            module = inner.init_hook(module)
        return module

    def pre_forward(self, module, *args, **kwargs):
        for inner in self.hooks:
            # Each hook sees the args/kwargs as already processed by the previous ones.
            args, kwargs = inner.pre_forward(module, *args, **kwargs)
        return args, kwargs

    def post_forward(self, module, output):
        for inner in self.hooks:
            output = inner.post_forward(module, output)
        return output

    def detach_hook(self, module):
        for inner in self.hooks:
            module = inner.detach_hook(module)
        return module
def add_hook_to_module(module: nn.Module, hook: ModelHook, append: bool = False):
    """
    Attaches a hook to a module by rewriting the module's `forward` so that it goes through the hook. Use
    `remove_hook_from_module` to undo this and get the original `forward` back.

    <Tip warning={true}>

    By default, a hook already present on the module is replaced by the new one. Pass `append=True` to chain both
    instead: the current and the new hook are then wrapped together in a `SequentialHook`.

    </Tip>

    Args:
        module (`torch.nn.Module`):
            The module to attach a hook to.
        hook (`ModelHook`):
            The hook to attach.
        append (`bool`, *optional*, defaults to `False`):
            Whether the hook should be chained with an existing one (if module already contains a hook) or not.

    Returns:
        `torch.nn.Module`: The same module, with the hook attached (the module is modified in place, so the result can
        be discarded).
    """
    existing_hook = getattr(module, "_hf_hook", None)
    if append and existing_hook is not None:
        # Detach the current hook first, then re-attach it chained with the new one.
        remove_hook_from_module(module)
        hook = SequentialHook(existing_hook, hook)

    already_wrapped = hasattr(module, "_hf_hook") and hasattr(module, "_old_forward")
    if already_wrapped:
        # A hook is being replaced: keep pointing at the forward saved by the first attachment.
        old_forward = module._old_forward
    else:
        old_forward = module.forward
        module._old_forward = old_forward

    module = hook.init_hook(module)
    module._hf_hook = hook

    def new_forward(module, *args, **kwargs):
        # The hook is looked up on the module at call time, not captured from the enclosing scope.
        args, kwargs = module._hf_hook.pre_forward(module, *args, **kwargs)
        if module._hf_hook.no_grad:
            with torch.no_grad():
                output = module._old_forward(*args, **kwargs)
        else:
            output = module._old_forward(*args, **kwargs)
        return module._hf_hook.post_forward(module, output)

    # Bind the module as first argument and copy the metadata of the original forward onto the wrapper.
    wrapped_forward = functools.update_wrapper(functools.partial(new_forward, module), old_forward)

    # Overriding a GraphModuleImpl forward freezes the forward call and later modifications on the graph will fail.
    # Reference: https://pytorch.slack.com/archives/C3PDTEV8E/p1705929610405409
    if "GraphModuleImpl" in str(type(module)):
        module.__class__.forward = wrapped_forward
    else:
        module.forward = wrapped_forward

    return module
def remove_hook_from_module(module: nn.Module, recurse=False):
    """
    Detaches whatever hook `add_hook_to_module` put on a module and restores the module's original `forward`.

    Args:
        module (`torch.nn.Module`): The module to remove the hook from.
        recurse (`bool`, *optional*, defaults to `False`): Whether to also remove the hooks of all child modules,
            recursively.

    Returns:
        `torch.nn.Module`: The same module, with the hook detached (the module is modified in place, so the result can
        be discarded).
    """
    if hasattr(module, "_hf_hook"):
        # Let the hook clean up after itself before forgetting about it.
        module._hf_hook.detach_hook(module)
        del module._hf_hook

    if hasattr(module, "_old_forward"):
        # Overriding a GraphModuleImpl forward freezes the forward call and later modifications on the graph will fail.
        # Reference: https://pytorch.slack.com/archives/C3PDTEV8E/p1705929610405409
        forward_owner = module.__class__ if "GraphModuleImpl" in str(type(module)) else module
        forward_owner.forward = module._old_forward
        del module._old_forward

    # Remove accelerate added warning hooks from dispatch_model
    for added_name in _accelerate_added_attributes:
        module.__dict__.pop(added_name, None)

    if recurse:
        for submodule in module.children():
            remove_hook_from_module(submodule, recurse)

    return module
class AlignDevicesHook(ModelHook):
    """
    A generic `ModelHook` that ensures inputs and model weights are on the same device for the forward pass of the
    associated module, potentially offloading the weights after the forward pass.

    Args:
        execution_device (`int`, `str` or `torch.device`, *optional*):
            The device on which inputs and model weights should be placed before the forward pass.
        offload (`bool`, *optional*, defaults to `False`):
            Whether or not the weights should be offloaded after the forward pass.
        io_same_device (`bool`, *optional*, defaults to `False`):
            Whether or not the output should be placed on the same device as the input was.
        weights_map (`Mapping[str, torch.Tensor]`, *optional*):
            When the model weights are offloaded, a (potentially lazy) map from param names to the tensor values.
        offload_buffers (`bool`, *optional*, defaults to `False`):
            Whether or not to include the associated module's buffers when offloading.
        place_submodules (`bool`, *optional*, defaults to `False`):
            Whether to place the submodules on `execution_device` during the `init_hook` event.
        skip_keys (`str` or `List[str]`, *optional*):
            A list of keys to ignore when moving inputs or outputs between devices.
        tied_params_map (`Dict[int, Dict[torch.device, torch.Tensor]]`, *optional*, defaults to `None`):
            A map of data pointers to dictionaries of devices to already dispatched tied weights. For a given execution
            device, this parameter is useful to reuse the first available pointer of a shared weight for all others,
            instead of duplicating memory.
    """

    def __init__(
        self,
        execution_device: Optional[Union[int, str, torch.device]] = None,
        offload: bool = False,
        io_same_device: bool = False,
        weights_map: Optional[Mapping] = None,
        offload_buffers: bool = False,
        place_submodules: bool = False,
        skip_keys: Optional[Union[str, list[str]]] = None,
        tied_params_map: Optional[dict[int, dict[torch.device, torch.Tensor]]] = None,
    ):
        self.execution_device = execution_device
        self.offload = offload
        self.io_same_device = io_same_device
        self.weights_map = weights_map
        self.offload_buffers = offload_buffers
        self.place_submodules = place_submodules
        self.skip_keys = skip_keys

        # Will contain the input device when `io_same_device=True`.
        self.input_device = None
        self.param_original_devices = {}
        self.buffer_original_devices = {}
        # Filled by `init_hook` when offloading; defined here so `detach_hook` is safe on a hook that was never
        # initialized.
        self.original_devices = {}
        self.tied_params_names = set()
        # Filled by `pre_forward` when offloading; defined here so `post_forward` is safe even if it runs without a
        # preceding `pre_forward`.
        self.tied_pointers_to_remove = set()

        # The hook pre_forward/post_forward need to have knowledge of this dictionary, as with offloading we want to
        # avoid duplicating memory for tied weights already loaded on the target execution device.
        self.tied_params_map = tied_params_map

    def __repr__(self):
        return (
            f"AlignDevicesHook(execution_device={self.execution_device}, offload={self.offload}, "
            f"io_same_device={self.io_same_device}, offload_buffers={self.offload_buffers}, "
            f"place_submodules={self.place_submodules}, skip_keys={repr(self.skip_keys)})"
        )

    def init_hook(self, module):
        """
        Places the tensors of `module` according to the hook configuration.

        Without offloading, the tensors are moved to `execution_device` (when one is set). With offloading, the
        original devices are recorded, a CPU `weights_map` is built if none was given, and the offloaded tensors are
        replaced by `"meta"` placeholders.

        Args:
            module (`torch.nn.Module`): The module this hook is being attached to.

        Returns:
            `torch.nn.Module`: The same module.
        """
        # In case the AlignDevicesHook is on meta device, ignore tied weights as data_ptr() is then always zero.
        if self.execution_device == "meta" or self.execution_device == torch.device("meta"):
            self.tied_params_map = None

        if not self.offload and self.execution_device is not None:
            for name, _ in named_module_tensors(module, recurse=self.place_submodules):
                set_module_tensor_to_device(module, name, self.execution_device, tied_params_map=self.tied_params_map)
        elif self.offload:
            # Remember where every tensor lived so that `detach_hook` can put it back.
            self.original_devices = {
                name: param.device for name, param in named_module_tensors(module, recurse=self.place_submodules)
            }
            if self.weights_map is None:
                self.weights_map = {
                    name: param.to("cpu")
                    for name, param in named_module_tensors(
                        module, include_buffers=self.offload_buffers, recurse=self.place_submodules
                    )
                }
            for name, _ in named_module_tensors(
                module, include_buffers=self.offload_buffers, recurse=self.place_submodules, remove_non_persistent=True
            ):
                # When using disk offloading, we can not rely on `weights_map[name].data_ptr()` as the reference
                # pointer, as we have no guarantee that safetensors' `file.get_tensor()` will always give the same
                # pointer. As we have no reliable way to track the shared data pointer of tied weights in this case,
                # we use `tied_params_names` to add on the fly pointers to `tied_params_map` in the pre_forward call.
                if (
                    self.tied_params_map is not None
                    and recursive_getattr(module, name).data_ptr() in self.tied_params_map
                ):
                    self.tied_params_names.add(name)

                set_module_tensor_to_device(module, name, "meta")

            # Tensors that are not offloaded still have to live on the execution device.
            if not self.offload_buffers and self.execution_device is not None:
                for name, _ in module.named_buffers(recurse=self.place_submodules):
                    set_module_tensor_to_device(
                        module, name, self.execution_device, tied_params_map=self.tied_params_map
                    )
            elif self.offload_buffers and self.execution_device is not None:
                for name in get_non_persistent_buffers(module, recurse=self.place_submodules):
                    set_module_tensor_to_device(
                        module, name, self.execution_device, tied_params_map=self.tied_params_map
                    )

        return module

    def pre_forward(self, module, *args, **kwargs):
        """
        Loads offloaded weights onto `execution_device` (when offloading) and sends the inputs there.

        Args:
            module (`torch.nn.Module`): The module whose forward pass is about to be executed.
            args (`Tuple[Any]`): The positional arguments passed to the module.
            kwargs (`Dict[Str, Any]`): The keyword arguments passed to the module.

        Returns:
            `Tuple[Tuple[Any], Dict[Str, Any]]`: `args` and `kwargs` after `send_to_device` to `execution_device`
            (`skip_keys` is applied to `kwargs` only).
        """
        if self.io_same_device:
            # Remembered so that `post_forward` can send the output back to it.
            self.input_device = find_device([args, kwargs])

        if self.offload:
            self.tied_pointers_to_remove = set()

            for name, _ in named_module_tensors(
                module,
                include_buffers=self.offload_buffers,
                recurse=self.place_submodules,
                remove_non_persistent=True,
            ):
                fp16_statistics = None
                value = self.weights_map[name]
                # An int8 weight with a companion "SCB" entry in the weights map gets that entry passed along as
                # `fp16_statistics`.
                if "weight" in name and name.replace("weight", "SCB") in self.weights_map.keys():
                    if value.dtype == torch.int8:
                        fp16_statistics = self.weights_map[name.replace("weight", "SCB")]

                # In case we are using offloading with tied weights, we need to keep track of the offloaded weights
                # that are loaded on device at this point, as we will need to remove them as well from the dictionary
                # self.tied_params_map in order to allow to free memory.
                if name in self.tied_params_names and value.data_ptr() not in self.tied_params_map:
                    self.tied_params_map[value.data_ptr()] = {}

                if (
                    value is not None
                    and self.tied_params_map is not None
                    and value.data_ptr() in self.tied_params_map
                    and self.execution_device not in self.tied_params_map[value.data_ptr()]
                ):
                    self.tied_pointers_to_remove.add((value.data_ptr(), self.execution_device))

                set_module_tensor_to_device(
                    module,
                    name,
                    self.execution_device,
                    value=value,
                    fp16_statistics=fp16_statistics,
                    tied_params_map=self.tied_params_map,
                )

        return send_to_device(args, self.execution_device), send_to_device(
            kwargs, self.execution_device, skip_keys=self.skip_keys
        )

    def post_forward(self, module, output):
        """
        Puts offloaded weights back on `"meta"` (when offloading) and optionally moves the output back to the device
        the inputs came from.

        Args:
            module (`torch.nn.Module`): The module whose forward pass has just been executed.
            output (`Any`): The output of the module.

        Returns:
            `Any`: The output, sent to the input device when `io_same_device=True` and an input device was found.
        """
        if self.offload:
            for name, _ in named_module_tensors(
                module,
                include_buffers=self.offload_buffers,
                recurse=self.place_submodules,
                remove_non_persistent=True,
            ):
                set_module_tensor_to_device(module, name, "meta")
                # NOTE(review): presumably bitsandbytes' 8-bit linear layer; its cached state is dropped as well.
                if type(module).__name__ == "Linear8bitLt":
                    module.state.SCB = None
                    module.state.CxB = None

            # We may have loaded tied weights into self.tied_params_map (avoiding to load them several times in e.g.
            # submodules): remove them from this dictionary to allow the garbage collector to do its job.
            for value_pointer, device in self.tied_pointers_to_remove:
                # An integer device index is rewritten as an explicit backend string when one of these backends is
                # available. NOTE(review): presumably to match the keys stored in `tied_params_map` -- confirm
                # against `set_module_tensor_to_device`.
                if isinstance(device, int):
                    if is_npu_available():
                        device = f"npu:{device}"
                    elif is_mlu_available():
                        device = f"mlu:{device}"
                    elif is_musa_available():
                        device = f"musa:{device}"
                if device in self.tied_params_map[value_pointer]:
                    del self.tied_params_map[value_pointer][device]
            self.tied_pointers_to_remove = set()

        if self.io_same_device and self.input_device is not None:
            output = send_to_device(output, self.input_device, skip_keys=self.skip_keys)

        return output

    def detach_hook(self, module):
        """
        Restores offloaded tensors to the devices recorded by `init_hook` (tensors that were on `"meta"` are skipped).

        Args:
            module (`torch.nn.Module`): The module this hook is being detached from.

        Returns:
            `torch.nn.Module`: The same module.
        """
        if self.offload:
            for name, device in self.original_devices.items():
                if device != torch.device("meta"):
                    set_module_tensor_to_device(module, name, device, value=self.weights_map.get(name, None))
        return module
def attach_execution_device_hook(
    module: torch.nn.Module,
    execution_device: Union[int, str, torch.device],
    skip_keys: Optional[Union[str, list[str]]] = None,
    preload_module_classes: Optional[list[str]] = None,
    tied_params_map: Optional[dict[int, dict[torch.device, torch.Tensor]]] = None,
):
    """
    Walks a model and attaches an `AlignDevicesHook` to every submodule that has no hook yet and a non-empty state
    dict, so that they all get the right execution device.

    Args:
        module (`torch.nn.Module`):
            The module where we want to attach the hooks.
        execution_device (`int`, `str` or `torch.device`):
            The device on which inputs and model weights should be placed before the forward pass.
        skip_keys (`str` or `List[str]`, *optional*):
            A list of keys to ignore when moving inputs or outputs between devices.
        preload_module_classes (`List[str]`, *optional*):
            A list of classes whose instances should load all their weights (even in the submodules) at the beginning
            of the forward. This should only be used for classes that have submodules which are registered but not
            called directly during the forward, for instance if a `dense` linear layer is registered, but at forward,
            `dense.weight` and `dense.bias` are used in some operations instead of calling `dense` directly.
        tied_params_map (Optional[Dict[int, Dict[torch.device, torch.Tensor]]], *optional*, defaults to `None`):
            A map of data pointers to dictionaries of devices to already dispatched tied weights. For a given execution
            device, this parameter is useful to reuse the first available pointer of a shared weight for all others,
            instead of duplicating memory.
    """
    has_no_hook = not hasattr(module, "_hf_hook")
    if has_no_hook and len(module.state_dict()) > 0:
        device_hook = AlignDevicesHook(execution_device, skip_keys=skip_keys, tied_params_map=tied_params_map)
        add_hook_to_module(module, device_hook)

    # Break the recursion if we get to a preload module.
    is_preload_module = preload_module_classes is not None and module.__class__.__name__ in preload_module_classes
    if is_preload_module:
        return

    for submodule in module.children():
        attach_execution_device_hook(
            submodule,
            execution_device,
            skip_keys=skip_keys,
            preload_module_classes=preload_module_classes,
            tied_params_map=tied_params_map,
        )
def attach_align_device_hook(
    module: torch.nn.Module,
    execution_device: Optional[torch.device] = None,
    offload: bool = False,
    weights_map: Optional[Mapping] = None,
    offload_buffers: bool = False,
    module_name: str = "",
    skip_keys: Optional[Union[str, list[str]]] = None,
    preload_module_classes: Optional[list[str]] = None,
    tied_params_map: Optional[dict[int, dict[torch.device, torch.Tensor]]] = None,
):
    """
    Recursively attaches `AlignDevicesHook` to all submodules of a given model that have direct parameters and/or
    buffers.

    Args:
        module (`torch.nn.Module`):
            The module where we want to attach the hooks.
        execution_device (`torch.device`, *optional*):
            The device on which inputs and model weights should be placed before the forward pass.
        offload (`bool`, *optional*, defaults to `False`):
            Whether or not the weights should be offloaded after the forward pass.
        weights_map (`Mapping[str, torch.Tensor]`, *optional*):
            When the model weights are offloaded, a (potentially lazy) map from param names to the tensor values.
        offload_buffers (`bool`, *optional*, defaults to `False`):
            Whether or not to include the associated module's buffers when offloading.
        module_name (`str`, *optional*, defaults to `""`):
            The name of the module.
        skip_keys (`str` or `List[str]`, *optional*):
            A list of keys to ignore when moving inputs or outputs between devices.
        preload_module_classes (`List[str]`, *optional*):
            A list of classes whose instances should load all their weights (even in the submodules) at the beginning
            of the forward. This should only be used for classes that have submodules which are registered but not
            called directly during the forward, for instance if a `dense` linear layer is registered, but at forward,
            `dense.weight` and `dense.bias` are used in some operations instead of calling `dense` directly.
        tied_params_map (Optional[Dict[int, Dict[torch.device, torch.Tensor]]], *optional*, defaults to `None`):
            A map of data pointers to dictionaries of devices to already dispatched tied weights. For a given execution
            device, this parameter is useful to reuse the first available pointer of a shared weight for all others,
            instead of duplicating memory.
    """
    # Attach the hook on this module if it has any direct tensor. Only existence matters, so stop at the first
    # tensor instead of materializing all of them.
    has_direct_tensors = any(True for _ in named_module_tensors(module))
    full_offload = (
        offload and preload_module_classes is not None and module.__class__.__name__ in preload_module_classes
    )

    if has_direct_tensors or full_offload:
        if weights_map is not None:
            # Give the hook a view of the weights map keyed relative to this module.
            prefix = f"{module_name}." if len(module_name) > 0 else ""
            prefixed_weights_map = PrefixedDataset(weights_map, prefix)
        else:
            prefixed_weights_map = None
        hook = AlignDevicesHook(
            execution_device=execution_device,
            offload=offload,
            weights_map=prefixed_weights_map,
            offload_buffers=offload_buffers,
            place_submodules=full_offload,
            skip_keys=skip_keys,
            tied_params_map=tied_params_map,
        )
        add_hook_to_module(module, hook, append=True)

    # We stop the recursion in case we hit the full offload.
    if full_offload:
        return

    # Recurse on all children of the module.
    for child_name, child in module.named_children():
        child_name = f"{module_name}.{child_name}" if len(module_name) > 0 else child_name
        attach_align_device_hook(
            child,
            execution_device=execution_device,
            offload=offload,
            weights_map=weights_map,
            offload_buffers=offload_buffers,
            module_name=child_name,
            preload_module_classes=preload_module_classes,
            skip_keys=skip_keys,
            tied_params_map=tied_params_map,
        )
def remove_hook_from_submodules(module: nn.Module):
    """
    Removes the hooks attached to a model and to every one of its submodules, recursively.

    Args:
        module (`torch.nn.Module`): The module on which to remove all hooks.
    """
    # Same traversal as removing the hook here and then recursing into the children by hand.
    remove_hook_from_module(module, recurse=True)
def attach_align_device_hook_on_blocks(
    module: nn.Module,
    execution_device: Optional[Union[torch.device, dict[str, torch.device]]] = None,
    offload: Union[bool, dict[str, bool]] = False,
    weights_map: Optional[Mapping] = None,
    offload_buffers: bool = False,
    module_name: str = "",
    skip_keys: Optional[Union[str, list[str]]] = None,
    preload_module_classes: Optional[list[str]] = None,
    tied_params_map: Optional[dict[int, dict[torch.device, torch.Tensor]]] = None,
):
    """
    Attaches `AlignDevicesHook` to all blocks of a given model as needed.

    Args:
        module (`torch.nn.Module`):
            The module where we want to attach the hooks.
        execution_device (`torch.device` or `Dict[str, torch.device]`, *optional*):
            The device on which inputs and model weights should be placed before the forward pass. It can be one device
            for the whole module, or a dictionary mapping module name to device.
        offload (`bool` or `Dict[str, bool]`, *optional*, defaults to `False`):
            Whether or not the weights should be offloaded after the forward pass. It can be one boolean for the whole
            module, or a dictionary mapping module name to boolean.
        weights_map (`Mapping[str, torch.Tensor]`, *optional*):
            When the model weights are offloaded, a (potentially lazy) map from param names to the tensor values.
        offload_buffers (`bool`, *optional*, defaults to `False`):
            Whether or not to include the associated module's buffers when offloading.
        module_name (`str`, *optional*, defaults to `""`):
            The name of the module.
        skip_keys (`str` or `List[str]`, *optional*):
            A list of keys to ignore when moving inputs or outputs between devices.
        preload_module_classes (`List[str]`, *optional*):
            A list of classes whose instances should load all their weights (even in the submodules) at the beginning
            of the forward. This should only be used for classes that have submodules which are registered but not
            called directly during the forward, for instance if a `dense` linear layer is registered, but at forward,
            `dense.weight` and `dense.bias` are used in some operations instead of calling `dense` directly.
        tied_params_map (Optional[Dict[int, Dict[torch.device, torch.Tensor]]], *optional*, defaults to `None`):
            A map of data pointers to dictionaries of devices to already dispatched tied weights. For a given execution
            device, this parameter is useful to reuse the first available pointer of a shared weight for all others,
            instead of duplicating memory.
    """
    # If one device and one offload, we've got one hook. Both arguments are tested against `Mapping`, like in the
    # normalization below, so that a non-dict mapping is never mistaken for a single boolean.
    if not isinstance(execution_device, Mapping) and not isinstance(offload, Mapping):
        if not offload:
            hook = AlignDevicesHook(
                execution_device=execution_device,
                io_same_device=True,
                skip_keys=skip_keys,
                place_submodules=True,
                tied_params_map=tied_params_map,
            )
            add_hook_to_module(module, hook)
        else:
            attach_align_device_hook(
                module,
                execution_device=execution_device,
                offload=True,
                weights_map=weights_map,
                offload_buffers=offload_buffers,
                module_name=module_name,
                skip_keys=skip_keys,
                # Forwarded like in the per-block offload branch below, so the documented option is honored here too.
                preload_module_classes=preload_module_classes,
                tied_params_map=tied_params_map,
            )
        return

    # Exactly one of the two may still be a single value: broadcast it over the keys of the other.
    if not isinstance(execution_device, Mapping):
        execution_device = {key: execution_device for key in offload.keys()}
    if not isinstance(offload, Mapping):
        offload = {key: offload for key in execution_device.keys()}

    if module_name in execution_device and module_name in offload and not offload[module_name]:
        # This block is placed as a whole on its device, without offloading.
        hook = AlignDevicesHook(
            execution_device=execution_device[module_name],
            offload_buffers=offload_buffers,
            io_same_device=(module_name == ""),
            place_submodules=True,
            skip_keys=skip_keys,
            tied_params_map=tied_params_map,
        )
        add_hook_to_module(module, hook)
        attach_execution_device_hook(
            module, execution_device[module_name], skip_keys=skip_keys, tied_params_map=tied_params_map
        )
    elif module_name in execution_device and module_name in offload:
        # This block is offloaded.
        attach_align_device_hook(
            module,
            execution_device=execution_device[module_name],
            offload=True,
            weights_map=weights_map,
            offload_buffers=offload_buffers,
            module_name=module_name,
            skip_keys=skip_keys,
            preload_module_classes=preload_module_classes,
            tied_params_map=tied_params_map,
        )
        # The call above may not have put a hook on the block itself: make sure it has one.
        if not hasattr(module, "_hf_hook"):
            hook = AlignDevicesHook(
                execution_device=execution_device[module_name],
                io_same_device=(module_name == ""),
                skip_keys=skip_keys,
                tied_params_map=tied_params_map,
            )
            add_hook_to_module(module, hook)
        attach_execution_device_hook(
            module,
            execution_device[module_name],
            preload_module_classes=preload_module_classes,
            skip_keys=skip_keys,
            tied_params_map=tied_params_map,
        )
    elif module_name == "":
        # The root module has no entry of its own: it still gets a hook that returns outputs on the input device.
        hook = AlignDevicesHook(
            execution_device=execution_device.get(""),
            io_same_device=True,
            skip_keys=skip_keys,
            tied_params_map=tied_params_map,
        )
        add_hook_to_module(module, hook)

    for child_name, child in module.named_children():
        child_name = f"{module_name}.{child_name}" if len(module_name) > 0 else child_name
        attach_align_device_hook_on_blocks(
            child,
            execution_device=execution_device,
            offload=offload,
            weights_map=weights_map,
            offload_buffers=offload_buffers,
            module_name=child_name,
            preload_module_classes=preload_module_classes,
            skip_keys=skip_keys,
            tied_params_map=tied_params_map,
        )
class CpuOffload(ModelHook):
    """
    Offloads a model on the CPU until its forward pass is called. The model will not be offloaded back to the CPU after
    the forward, the user needs to call the `init_hook` method again for this.

    Args:
        execution_device(`str`, `int` or `torch.device`, *optional*):
            The device on which the model should be executed. Defaults to `PartialState().default_device` when not
            given.
        prev_module_hook (`UserCpuOffloadHook`, *optional*):
            The hook sent back by [`cpu_offload_with_hook`] for a previous model in the pipeline you are running. If
            passed, its offload method will be called just before the forward of the model to which this hook is
            attached.
    """

    def __init__(
        self,
        execution_device: Optional[Union[str, int, torch.device]] = None,
        prev_module_hook: Optional["UserCpuOffloadHook"] = None,
    ):
        self.prev_module_hook = prev_module_hook
        self.execution_device = execution_device if execution_device is not None else PartialState().default_device

    def init_hook(self, module):
        """Moves `module` to the CPU and returns the result of `module.to("cpu")`."""
        return module.to("cpu")

    def pre_forward(self, module, *args, **kwargs):
        """
        Offloads the previous module of the pipeline (if any), then brings `module` and its inputs to the execution
        device.

        Args:
            module (`torch.nn.Module`): The module whose forward pass is about to be executed.
            args (`Tuple[Any]`): The positional arguments passed to the module.
            kwargs (`Dict[Str, Any]`): The keyword arguments passed to the module.

        Returns:
            `Tuple[Tuple[Any], Dict[Str, Any]]`: `args` and `kwargs` unchanged when the first parameter of `module` is
            already on the execution device, otherwise both sent to the execution device.
        """
        if self.prev_module_hook is not None and isinstance(self.prev_module_hook, UserCpuOffloadHook):
            # A module without parameters has no device to inspect: `next(..., None)` avoids a `StopIteration`.
            prev_param = next(self.prev_module_hook.model.parameters(), None)

            # Only offload the previous module if it is not already on CPU (or if its device cannot be determined).
            if prev_param is None or prev_param.device != torch.device("cpu"):
                self.prev_module_hook.offload()
                clear_device_cache()

        # If the current device is already the self.execution_device, we can skip the transfer. A module without
        # parameters has no device to compare, so it always takes the regular transfer path below.
        current_param = next(module.parameters(), None)
        if current_param is not None and current_param.device == self.execution_device:
            return args, kwargs

        module.to(self.execution_device)
        return send_to_device(args, self.execution_device), send_to_device(kwargs, self.execution_device)
class UserCpuOffloadHook:
    """
    A small handle that pairs a model with a `ModelHook` and offers shortcuts to call the hook's init method again or
    to remove the hook entirely.
    """

    def __init__(self, model, hook):
        self.model, self.hook = model, hook

    def offload(self):
        # Re-run the hook's init step on the stored model; the result is discarded.
        hook, model = self.hook, self.model
        hook.init_hook(model)

    def remove(self):
        # Detach whatever hook is currently attached to the stored model.
        remove_hook_from_module(self.model)
class LayerwiseCastingHook(ModelHook):
    r"""
    A hook that keeps the weights of a module in a low precision dtype for storage and casts them to a high precision
    dtype for the duration of the computation. This process may lead to quality loss in the output, but can
    significantly reduce the memory footprint.
    """

    _is_stateful = False

    def __init__(self, storage_dtype: torch.dtype, compute_dtype: torch.dtype, non_blocking: bool) -> None:
        self.storage_dtype = storage_dtype
        self.compute_dtype = compute_dtype
        self.non_blocking = non_blocking

    def _cast(self, module: torch.nn.Module, dtype: torch.dtype) -> None:
        # Single place for the `module.to(...)` call shared by the three events below.
        module.to(dtype=dtype, non_blocking=self.non_blocking)

    def init_hook(self, module: torch.nn.Module):
        # The module rests in the storage dtype as soon as the hook is attached.
        self._cast(module, self.storage_dtype)
        return module

    def pre_forward(self, module: torch.nn.Module, *args, **kwargs):
        # Cast to the compute dtype just for the forward pass; the inputs are passed through untouched.
        self._cast(module, self.compute_dtype)
        return args, kwargs

    def post_forward(self, module: torch.nn.Module, output):
        # Back to the storage dtype once the forward pass is done; the output is passed through untouched.
        self._cast(module, self.storage_dtype)
        return output
|