| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392 |
- # Copyright 2024 The HuggingFace Inc. team. All rights reserved.
- #
- # Licensed under the Apache License, Version 2.0 (the "License");
- # you may not use this file except in compliance with the License.
- # You may obtain a copy of the License at
- #
- # http://www.apache.org/licenses/LICENSE-2.0
- #
- # Unless required by applicable law or agreed to in writing, software
- # distributed under the License is distributed on an "AS IS" BASIS,
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- # See the License for the specific language governing permissions and
- # limitations under the License.
- from abc import ABC, abstractmethod
- from typing import TYPE_CHECKING, Any, Optional, Union
- from ..utils import is_torch_available, logging
- from ..utils.quantization_config import QuantizationConfigMixin, QuantizationMethod
- from .quantizers_utils import get_module_from_name
- if TYPE_CHECKING:
- from ..modeling_utils import PreTrainedModel
- if is_torch_available():
- import torch
- from torch.nn import ModuleList
- else:
- ModuleList = str
# Name the logger after the module's dotted import path rather than its file path: the logging hierarchy is
# dot-separated, so a path-based name would not nest under the package's logger (and its verbosity settings).
logger = logging.get_logger(__name__)
class HfQuantizer(ABC):
    """
    Abstract class of the HuggingFace quantizer. For now it supports quantizing HF transformers models for inference
    and/or quantization. This class is used only by `transformers.PreTrainedModel.from_pretrained` and cannot easily
    be used outside the scope of that method yet.

    Attributes:
        quantization_config (`transformers.utils.quantization_config.QuantizationConfigMixin`):
            The quantization config that defines the quantization parameters of the model you want to quantize.
        modules_to_not_convert (`list[str]`, *optional*):
            The list of module names to not convert when quantizing the model. Taken from the
            `modules_to_not_convert` kwarg of `__init__` (defaults to an empty list).
        pre_quantized (`bool`, *optional*, defaults to `True`):
            Whether the checkpoint being loaded already contains quantized weights. Taken from the `pre_quantized`
            kwarg of `__init__`.
        required_packages (`list[str]`, *optional*):
            The list of required pip packages to install prior to using the quantizer.
        requires_calibration (`bool`):
            Whether the quantization method requires calibrating the model before using it. Such methods only
            accept `pre_quantized=True`.
        requires_parameters_quantization (`bool`):
            Whether the quantization method requires creating a new Parameter. For example, for bitsandbytes, it is
            required to create a new xxxParameter in order to properly quantize the model.
    """

    requires_calibration = False
    required_packages = None
    requires_parameters_quantization = False

    def __init__(self, quantization_config: QuantizationConfigMixin, **kwargs):
        self.quantization_config = quantization_config

        # -- Handle extra kwargs below --
        self.modules_to_not_convert = kwargs.pop("modules_to_not_convert", [])
        self.pre_quantized = kwargs.pop("pre_quantized", True)

        # Methods that need calibration can only load checkpoints that are already quantized.
        if not self.pre_quantized and self.requires_calibration:
            raise ValueError(
                f"The quantization method {quantization_config.quant_method} does require the model to be pre-quantized."
                " You explicitly passed `pre_quantized=False` meaning your model weights are not quantized. Make sure to "
                "pass `pre_quantized=True` while knowing what you are doing."
            )

    def update_torch_dtype(self, dtype: "torch.dtype") -> "torch.dtype":
        """
        Deprecated in favor of `update_dtype`!

        Args:
            dtype (`torch.dtype`):
                The input dtype that is passed in `from_pretrained`

        Returns:
            `torch.dtype`: Whatever `update_dtype` returns for `dtype`.
        """
        logger.warning_once(
            "`update_torch_dtype` is deprecated in favor of `update_dtype`! It will be removed in version v4.57"
        )
        return self.update_dtype(dtype)

    def update_dtype(self, dtype: "torch.dtype") -> "torch.dtype":
        """
        Some quantization methods require to explicitly set the dtype of the model to a
        target dtype. You need to override this method in case you want to make sure that behavior is
        preserved. The default implementation returns `dtype` unchanged.

        Args:
            dtype (`torch.dtype`):
                The input dtype that is passed in `from_pretrained`
        """
        return dtype

    def update_device_map(self, device_map: Optional[dict[str, Any]]) -> Optional[dict[str, Any]]:
        """
        Override this method if you want to replace the existing device map with a new one. E.g. for bitsandbytes,
        since `accelerate` is a hard requirement, if no device_map is passed, the device_map is set to `"auto"`.
        The default implementation returns `device_map` unchanged.

        Args:
            device_map (`Union[dict, str]`, *optional*):
                The device_map that is passed through the `from_pretrained` method.
        """
        return device_map

    def adjust_target_dtype(self, dtype: "torch.dtype") -> "torch.dtype":
        """
        Override this method if you want to adjust the `target_dtype` variable used in `from_pretrained`
        to compute the device_map in case the device_map is a `str`. E.g. for bitsandbytes we force-set `target_dtype`
        to `torch.int8` and for 4-bit we pass a custom enum `accelerate.CustomDtype.int4`.
        The default implementation returns `dtype` unchanged.

        Args:
            dtype (`torch.dtype`, *optional*):
                The dtype that is used to compute the device_map.
        """
        return dtype

    def update_missing_keys(self, model, missing_keys: list[str], prefix: str) -> list[str]:
        """
        Override this method if you want to adjust the `missing_keys`.
        The default implementation ignores `model` and `prefix` and returns `missing_keys` unchanged.

        Args:
            missing_keys (`list[str]`, *optional*):
                The list of missing keys in the checkpoint compared to the state dict of the model
        """
        return missing_keys

    def update_expected_keys(self, model, expected_keys: list[str], loaded_keys: list[str]) -> list[str]:
        """
        Override this method if you want to adjust the `expected_keys`.
        The default implementation returns `expected_keys` unchanged.

        Args:
            expected_keys (`list[str]`, *optional*):
                The list of the expected keys in the initialized model.
            loaded_keys (`list[str]`, *optional*):
                The list of the loaded keys in the checkpoint.
        """
        return expected_keys

    def update_unexpected_keys(self, model, unexpected_keys: list[str]) -> list[str]:
        """
        Override this method if you want to adjust the `unexpected_keys`.
        The default implementation returns `unexpected_keys` unchanged.

        Args:
            unexpected_keys (`list[str]`):
                The list of keys found in the checkpoint that the model does not expect.
        """
        return unexpected_keys

    def get_special_dtypes_update(self, model, dtype: "torch.dtype") -> dict[str, "torch.dtype"]:
        """
        Returns dtypes for modules that are not quantized - used for the computation of the device_map in case
        one passes a str as a device_map. The method will use the `modules_to_not_convert` that is modified
        in `_process_model_before_weight_loading`.

        A parameter matches when any entry of `modules_to_not_convert` is a substring of its name.

        Args:
            model (`~transformers.PreTrainedModel`):
                The model to quantize
            dtype (`torch.dtype`):
                The dtype passed in `from_pretrained` method.

        Returns:
            `dict[str, torch.dtype]`: Parameter name -> `dtype` for every matching parameter.
        """
        # `modules_to_not_convert` comes straight from the caller's kwargs and may be `None`; treat that as empty.
        skipped_modules = self.modules_to_not_convert or []
        return {name: dtype for name, _ in model.named_parameters() if any(m in name for m in skipped_modules)}

    def adjust_max_memory(self, max_memory: dict[str, Union[int, str]]) -> dict[str, Union[int, str]]:
        """adjust max_memory argument for infer_auto_device_map() if extra memory is needed for quantization"""
        return max_memory

    def check_quantized_param(self, *args, **kwargs) -> bool:
        """DEPRECATED -> remove in v5. Forwards all arguments to `param_needs_quantization`."""
        logger.warning_once(
            "`check_quantized_param` is deprecated in favor of `param_needs_quantization`, which is a much "
            "more self-explanatory name for what the method achieves. It will be removed in v5"
        )
        return self.param_needs_quantization(*args, **kwargs)

    def param_needs_quantization(self, model: "PreTrainedModel", param_name: str, **kwargs) -> bool:
        """
        Check whether a given param needs quantization as defined by `create_quantized_param`.
        The default implementation always returns `False`.
        """
        return False

    def create_quantized_param(self, *args, **kwargs):
        """
        Take needed components from state_dict (those from which `param_needs_quantization` is True) and create
        quantized param.
        It usually also loads the new param directly in the `model`.
        Note: only applicable if requires_parameters_quantization == True.

        Raises:
            AttributeError: If `requires_parameters_quantization` is `False` for this quantizer.
        """
        if not self.requires_parameters_quantization:
            raise AttributeError(
                f"`.create_quantized_param()` method is not supported by quantizer class {self.__class__.__name__}."
            )

    def validate_environment(self, *args, **kwargs):
        """
        This method is used to check for potential conflicts with arguments that are passed in `from_pretrained`.
        You need to define it for all future quantizers that are integrated with transformers.
        If no explicit checks are needed, simply return nothing.
        """
        return

    def update_tp_plan(self, config):
        "Updates the tp plan for the scales. The default implementation returns `config` unchanged."
        return config

    def update_ep_plan(self, config):
        "Updates the ep plan for the scales. The default implementation returns `config` unchanged."
        return config

    def preprocess_model(self, model: "PreTrainedModel", **kwargs):
        """
        Setting model attributes and/or converting model before weights loading. At this point
        the model should be initialized on the meta device so you can freely manipulate the skeleton
        of the model in order to replace modules in-place. Make sure to override the abstract method
        `_process_model_before_weight_loading`.

        For pre-quantized checkpoints, modules listed in `MODULES_TO_PATCH_FOR_QUANTIZATION` are swapped first
        (see `_convert_model_for_quantization`).

        Args:
            model (`~transformers.PreTrainedModel`):
                The model to quantize
            kwargs (`dict`, *optional*):
                The keyword arguments that are passed along `_process_model_before_weight_loading`.
        """
        model.is_quantized = True
        model.quantization_method = self.quantization_config.quant_method
        if self.pre_quantized:
            self._convert_model_for_quantization(model)
        return self._process_model_before_weight_loading(model, **kwargs)

    def postprocess_model(self, model: "PreTrainedModel", **kwargs):
        """
        Post-process the model post weights loading.
        Make sure to override the abstract method `_process_model_after_weight_loading`.

        Args:
            model (`~transformers.PreTrainedModel`):
                The model to quantize
            kwargs (`dict`, *optional*):
                The keyword arguments that are passed along `_process_model_after_weight_loading`.
        """
        return self._process_model_after_weight_loading(model, **kwargs)

    def remove_quantization_config(self, model):
        """
        Remove the quantizer and the quantization config from `model` (in place) and mark it as not quantized.
        Attributes that are not present are skipped.
        """
        if hasattr(model, "hf_quantizer"):
            del model.hf_quantizer
        if hasattr(model.config, "quantization_config"):
            del model.config.quantization_config
        if hasattr(model.config, "_pre_quantization_dtype"):
            del model.config._pre_quantization_dtype
        if hasattr(model, "quantization_method"):
            del model.quantization_method
        model.is_quantized = False

    def dequantize(self, model):
        """
        Potentially dequantize the model to retrieve the original model, with some loss in accuracy / performance.
        Note not all quantization schemes support this.

        The conversion itself is delegated to `_dequantize` (which raises `NotImplementedError` by default). The
        quantizer and quantization config are then stripped from the returned model via
        `remove_quantization_config`.

        Returns:
            The model returned by `_dequantize`, with `is_quantized` set to `False`.
        """
        model = self._dequantize(model)

        # Delete quantizer and quantization config. Reusing `remove_quantization_config` (instead of unconditional
        # `del`s) means an attribute that was never set, e.g. `_pre_quantization_dtype`, no longer makes an
        # otherwise successful dequantization fail with `AttributeError`.
        self.remove_quantization_config(model)
        return model

    def get_accelerator_warm_up_factor(self):
        """
        The factor to be used in `caching_allocator_warmup` to get the number of bytes to pre-allocate to warm up accelerator.
        A factor of 2 means we allocate all bytes in the empty model (since we allocate in fp16), a factor of 4 means
        we allocate half the memory of the weights residing in the empty model, etc...
        """
        # By default we return 4, i.e. half the model size (this corresponds to the case where the model is not
        # really pre-processed, i.e. we do not have the info that weights are going to be 8 bits before actual
        # weight loading)
        return 4

    def _dequantize(self, model):
        """Method-specific dequantization hook; quantizers that support `dequantize` must override it."""
        raise NotImplementedError(
            f"{self.quantization_config.quant_method} has no implementation of `dequantize`, please raise an issue on GitHub."
        )

    def get_param_name(self, param_name: str) -> str:
        """
        Override this method if you want to adjust the `param_name`.
        The default implementation returns `param_name` unchanged.
        """
        return param_name

    @staticmethod
    def get_modules_to_not_convert(
        model: "PreTrainedModel",
        skip_modules: Optional[list[str]] = None,
        keep_in_fp32_modules: Optional[list[str]] = None,
        add_default_skips: bool = False,
    ):
        """
        Build the list of module names that must not be quantized.

        The list starts from `get_keys_to_not_convert(model)` when `skip_modules` is `None` or `add_default_skips`
        is `True` (otherwise it starts empty), then `skip_modules` and `keep_in_fp32_modules` are appended.
        Duplicates are not removed.

        Args:
            model (`~transformers.PreTrainedModel`):
                The model to quantize.
            skip_modules (`list[str]`, *optional*):
                Module names the user explicitly asked to skip.
            keep_in_fp32_modules (`list[str]`, *optional*):
                Module names that must stay in fp32.
            add_default_skips (`bool`, *optional*, defaults to `False`):
                Whether to include the default skips even when `skip_modules` is given.

        Returns:
            `list[str]`: The module names to leave unconverted.
        """
        from ..integrations import get_keys_to_not_convert

        if skip_modules is None or add_default_skips:
            modules_to_not_convert = get_keys_to_not_convert(model)
        else:
            modules_to_not_convert = []

        if skip_modules is not None:
            modules_to_not_convert.extend(skip_modules)

        if keep_in_fp32_modules is not None:
            modules_to_not_convert.extend(keep_in_fp32_modules)

        return modules_to_not_convert

    @property
    def is_qat_trainable(self) -> bool:
        """Flag indicating whether the quantized model can carry out quantization aware training"""
        return False

    @property
    def is_compileable(self) -> bool:
        """Flag indicating whether the quantized model can be compiled"""
        return False

    def get_state_dict_and_metadata(self, model, safe_serialization=False):
        """
        Get state dict and metadata. Useful when we need to modify a bit the state dict due to quantization.
        The default implementation returns `(None, {})`; presumably `None` tells the caller to fall back to the
        regular state dict (confirm against the caller).
        """
        return None, {}

    def update_state_dict_with_metadata(self, state_dict, metadata):
        """Update state dict with metadata. Default behaviour returns state_dict"""
        return state_dict

    @abstractmethod
    def _process_model_before_weight_loading(self, model, **kwargs): ...

    @abstractmethod
    def _process_model_after_weight_loading(self, model, **kwargs): ...

    @abstractmethod
    def is_serializable(self, safe_serialization=None): ...

    @property
    @abstractmethod
    def is_trainable(self): ...

    def _convert_model_for_quantization(self, model):
        """
        Replace modules listed in `MODULES_TO_PATCH_FOR_QUANTIZATION` with their quantization-friendly variants.

        A module is replaced when its class name is a key of that table and the entry's `quantization_methods`
        contains this quantizer's `quant_method`. The replacement is built from `model.config.get_text_config()`
        under `accelerate.init_empty_weights()` and written directly into the parent module's `_modules`.
        """
        from accelerate import init_empty_weights

        quant_method = self.quantization_config.quant_method
        for module_path, module in model.named_modules():
            patch = MODULES_TO_PATCH_FOR_QUANTIZATION.get(module.__class__.__name__)
            if patch is None or quant_method not in patch["quantization_methods"]:
                continue
            with init_empty_weights():
                # `get_module_from_name` splits the dotted path into (parent module, attribute name on the parent).
                parent_module, child_name = get_module_from_name(model, module_path)
                parent_module._modules[child_name] = patch["module_name"](model.config.get_text_config())
class SequentialLlama4TextExperts(ModuleList):
    """
    Expert block built from `config.num_local_experts` independent `Llama4TextMLP` modules held in a `ModuleList`,
    so every expert is an ordinary submodule.

    It is registered in `MODULES_TO_PATCH_FOR_QUANTIZATION` as the replacement for `Llama4TextExperts`.
    """

    def __init__(self, config):
        # Imported lazily; presumably to avoid loading the Llama4 modeling code unless this patch is used.
        from transformers.models.llama4.modeling_llama4 import Llama4TextMLP

        expert_count = config.num_local_experts
        super().__init__([Llama4TextMLP(config) for _ in range(expert_count)])
        self.num_experts = expert_count

    def forward(
        self,
        hidden_states: "torch.Tensor",
    ) -> "torch.Tensor":
        """
        Run expert `i` on the `i`-th slice of the input.

        `hidden_states` is reshaped to `(num_experts, -1, hidden_states.shape[-1])`; the result has that same
        shape, since it is written into a zero buffer shaped like the reshaped input.
        """
        feature_dim = hidden_states.shape[-1]
        per_expert_inputs = hidden_states.reshape(self.num_experts, -1, feature_dim)
        # Output buffer mirrors the reshaped input, so each expert must produce an output of its input's shape.
        stacked_outputs = torch.zeros_like(per_expert_inputs)
        for slot in range(self.num_experts):
            expert = self[slot]
            stacked_outputs[slot] = expert(per_expert_inputs[slot])
        return stacked_outputs
# Modules swapped by `HfQuantizer._convert_model_for_quantization` when loading a pre-quantized checkpoint,
# keyed by the original module's class name. Each entry gives the replacement class ("module_name") and the
# quantization methods ("quantization_methods") for which the swap applies.
MODULES_TO_PATCH_FOR_QUANTIZATION = {
    "Llama4TextExperts": dict(
        module_name=SequentialLlama4TextExperts,
        quantization_methods=[
            QuantizationMethod.COMPRESSED_TENSORS,
            QuantizationMethod.BITS_AND_BYTES,
        ],
    ),
}
|