quantizer_fp_quant.py 7.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179
  1. # Copyright 2025 The HuggingFace Inc. team. All rights reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. from typing import TYPE_CHECKING, Optional
  15. from .base import HfQuantizer
  16. from .quantizers_utils import get_module_from_name
  17. if TYPE_CHECKING:
  18. from ..modeling_utils import PreTrainedModel
  19. from ..utils import is_fp_quant_available, is_qutlass_available, is_torch_available, logging
  20. from ..utils.quantization_config import QuantizationConfigMixin
  21. if is_torch_available():
  22. import torch
  23. logger = logging.get_logger(__name__)
  24. class FPQuantHfQuantizer(HfQuantizer):
  25. """
  26. Quantizer for the FP-Quant method. Enables the loading of prequantized models and in-flight quantization of full-precision models.
  27. """
  28. requires_calibration = False
  29. requires_parameters_quantization = True
  30. is_qat_trainable = True
  31. required_packages = ["fp_quant"]
  32. def __init__(self, quantization_config: QuantizationConfigMixin, **kwargs):
  33. super().__init__(quantization_config, **kwargs)
  34. self.quantization_config = quantization_config
  35. def validate_environment(self, device_map, **kwargs):
  36. if not torch.cuda.is_available():
  37. raise NotImplementedError(
  38. "FPQuant quantization is only supported on GPU. Please use a different quantizer."
  39. )
  40. if not is_qutlass_available() and not self.quantization_config.pseudoquantization:
  41. raise ImportError(
  42. "Using `fp_quant` with real quantization requires a **Blackwell GPU** and qutlass: `git clone https://github.com/IST-DASLab/qutlass.git && cd qutlass && pip install --no-build-isolation .`. You can use `FPQuantConfig(pseudoquantization=True, ...)` to use Triton-based pseudo-quantization. It doesn't provide any speedups but emulates the quantization behavior of the real quantization."
  43. )
  44. if self.quantization_config.pseudoquantization:
  45. logger.warning(
  46. "Using pseudo-quantization for FP-Quant. This doesn't provide any speedups but emulates the quantization behavior of the real quantization."
  47. )
  48. if not is_fp_quant_available():
  49. raise ImportError("Using `fp_quant` quantization requires fp_quant: `pip install fp_quant`")
  50. if device_map is None and not self.quantization_config.pseudoquantization:
  51. raise ValueError(
  52. "You are attempting to load a FPQuant model without setting device_map."
  53. " Please set device_map comprised of 'cuda' devices."
  54. )
  55. elif (
  56. isinstance(device_map, dict)
  57. and ("cpu" in device_map.values() or "disk" in device_map.values())
  58. and not self.quantization_config.pseudoquantization
  59. ):
  60. raise ValueError(
  61. "You are attempting to load a FPQuant model with a device_map that contains a CPU or disk device."
  62. " This is not supported. Please remove the CPU or disk device from the device_map."
  63. )
  64. def update_dtype(self, dtype: "torch.dtype") -> "torch.dtype":
  65. if dtype is None:
  66. logger.info("`dtype` is None. Setting `dtype=torch.bfloat16` for qutlass compatibility.")
  67. dtype = torch.bfloat16
  68. elif dtype != torch.bfloat16:
  69. raise ValueError(f"Invalid `dtype` {dtype}. fp_quant quantization only supports `dtype=torch.bfloat16`.")
  70. return dtype
  71. def create_quantized_param(
  72. self,
  73. model: "PreTrainedModel",
  74. param_value: "torch.Tensor",
  75. param_name: str,
  76. target_device: "torch.device",
  77. **kwargs,
  78. ):
  79. module, _ = get_module_from_name(model, param_name)
  80. # The module holds either:
  81. # * `weight` when `store_master_weights=True`
  82. # * `qweight` and `scales` when `store_master_weights=False` and `pseudoquantization=False`
  83. # * `dqweight` when `store_master_weights=False` and `pseudoquantization=True`
  84. if param_name.endswith(".qweight"):
  85. # Loading a real quantized checkpoint without master weights
  86. module.qweight = torch.nn.Parameter(
  87. param_value.to(target_device),
  88. requires_grad=False,
  89. )
  90. module.weight = None
  91. module.dqweight = None
  92. return
  93. if param_name.endswith(".dqweight"):
  94. # Loading a pseudo-quantized checkpoint without master weights
  95. module.dqweight = torch.nn.Parameter(param_value.to(target_device))
  96. module.weight = None
  97. module.qweight = None
  98. module.scales = None
  99. return
  100. # Loading master weights or an unquantized checkpoint
  101. module.weight = torch.nn.Parameter(param_value.to(target_device))
  102. # Let pre-forward handle the quantization and set None where necessary
  103. module.pre_forward()
  104. def _process_model_before_weight_loading(
  105. self,
  106. model: "PreTrainedModel",
  107. **kwargs,
  108. ):
  109. from fp_quant import replace_with_fp_quant_linear
  110. from ..integrations.fp_quant import adapt_fp_quant_config
  111. replace_with_fp_quant_linear(
  112. model,
  113. fp_quant_linear_config=adapt_fp_quant_config(self.quantization_config),
  114. )
  115. model.config.quantization_config = self.quantization_config
  116. def _process_model_after_weight_loading(self, model: "PreTrainedModel", **kwargs):
  117. return model
  118. def update_missing_keys(self, model, missing_keys: list[str], prefix: str) -> list[str]:
  119. from fp_quant import FPQuantLinear
  120. fp_quant_names = {name for name, module in model.named_modules() if isinstance(module, FPQuantLinear)}
  121. def should_exclude(key: str) -> bool:
  122. if key.endswith(".weight") or key.endswith(".bias"):
  123. return False
  124. full_key = f"{prefix}.{key}"
  125. return any(name in key or name in full_key for name in fp_quant_names)
  126. return [key for key in missing_keys if not should_exclude(key)]
  127. @property
  128. def is_trainable(self, model: Optional["PreTrainedModel"] = None):
  129. trainable = self.quantization_config.store_master_weights
  130. if not trainable:
  131. logger.warning(
  132. "You are attempting to train a model with FPQuant quantization. This is only supported when `store_master_weights=True`. Please set `store_master_weights=True` to train the model."
  133. )
  134. return trainable
  135. def is_serializable(self, safe_serialization=None):
  136. return True
  137. def param_needs_quantization(self, model: "PreTrainedModel", param_name: str, **kwargs) -> bool:
  138. from fp_quant import FPQuantLinear
  139. module, tensor_name = get_module_from_name(model, param_name)
  140. if isinstance(module, FPQuantLinear) and tensor_name in ["weight", "qweight", "dqweight"]:
  141. # Only quantize weights of FPQuantLinear modules that are not already quantized
  142. return True
  143. else:
  144. return False