auto.py 13 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335
  1. # Copyright 2024 The HuggingFace Inc. team. All rights reserved.
  2. # Modifications Copyright (C) 2025, Advanced Micro Devices, Inc. All rights reserved.
  3. #
  4. # Licensed under the Apache License, Version 2.0 (the "License");
  5. # you may not use this file except in compliance with the License.
  6. # You may obtain a copy of the License at
  7. #
  8. # http://www.apache.org/licenses/LICENSE-2.0
  9. #
  10. # Unless required by applicable law or agreed to in writing, software
  11. # distributed under the License is distributed on an "AS IS" BASIS,
  12. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. # See the License for the specific language governing permissions and
  14. # limitations under the License.
  15. import warnings
  16. from typing import Optional, Union
  17. from ..models.auto.configuration_auto import AutoConfig
  18. from ..utils import logging
  19. from ..utils.quantization_config import (
  20. AqlmConfig,
  21. AutoRoundConfig,
  22. AwqConfig,
  23. BitNetQuantConfig,
  24. BitsAndBytesConfig,
  25. CompressedTensorsConfig,
  26. EetqConfig,
  27. FbgemmFp8Config,
  28. FineGrainedFP8Config,
  29. FPQuantConfig,
  30. GPTQConfig,
  31. HiggsConfig,
  32. HqqConfig,
  33. Mxfp4Config,
  34. QuantizationConfigMixin,
  35. QuantizationMethod,
  36. QuantoConfig,
  37. QuarkConfig,
  38. SpQRConfig,
  39. TorchAoConfig,
  40. VptqConfig,
  41. )
  42. from .base import HfQuantizer
  43. from .quantizer_aqlm import AqlmHfQuantizer
  44. from .quantizer_auto_round import AutoRoundQuantizer
  45. from .quantizer_awq import AwqQuantizer
  46. from .quantizer_bitnet import BitNetHfQuantizer
  47. from .quantizer_bnb_4bit import Bnb4BitHfQuantizer
  48. from .quantizer_bnb_8bit import Bnb8BitHfQuantizer
  49. from .quantizer_compressed_tensors import CompressedTensorsHfQuantizer
  50. from .quantizer_eetq import EetqHfQuantizer
  51. from .quantizer_fbgemm_fp8 import FbgemmFp8HfQuantizer
  52. from .quantizer_finegrained_fp8 import FineGrainedFP8HfQuantizer
  53. from .quantizer_fp_quant import FPQuantHfQuantizer
  54. from .quantizer_gptq import GptqHfQuantizer
  55. from .quantizer_higgs import HiggsHfQuantizer
  56. from .quantizer_hqq import HqqHfQuantizer
  57. from .quantizer_mxfp4 import Mxfp4HfQuantizer
  58. from .quantizer_quanto import QuantoHfQuantizer
  59. from .quantizer_quark import QuarkHfQuantizer
  60. from .quantizer_spqr import SpQRHfQuantizer
  61. from .quantizer_torchao import TorchAoHfQuantizer
  62. from .quantizer_vptq import VptqHfQuantizer
  63. AUTO_QUANTIZER_MAPPING = {
  64. "awq": AwqQuantizer,
  65. "bitsandbytes_4bit": Bnb4BitHfQuantizer,
  66. "bitsandbytes_8bit": Bnb8BitHfQuantizer,
  67. "gptq": GptqHfQuantizer,
  68. "aqlm": AqlmHfQuantizer,
  69. "quanto": QuantoHfQuantizer,
  70. "quark": QuarkHfQuantizer,
  71. "fp_quant": FPQuantHfQuantizer,
  72. "eetq": EetqHfQuantizer,
  73. "higgs": HiggsHfQuantizer,
  74. "hqq": HqqHfQuantizer,
  75. "compressed-tensors": CompressedTensorsHfQuantizer,
  76. "fbgemm_fp8": FbgemmFp8HfQuantizer,
  77. "torchao": TorchAoHfQuantizer,
  78. "bitnet": BitNetHfQuantizer,
  79. "vptq": VptqHfQuantizer,
  80. "spqr": SpQRHfQuantizer,
  81. "fp8": FineGrainedFP8HfQuantizer,
  82. "auto-round": AutoRoundQuantizer,
  83. "mxfp4": Mxfp4HfQuantizer,
  84. }
  85. AUTO_QUANTIZATION_CONFIG_MAPPING = {
  86. "awq": AwqConfig,
  87. "bitsandbytes_4bit": BitsAndBytesConfig,
  88. "bitsandbytes_8bit": BitsAndBytesConfig,
  89. "eetq": EetqConfig,
  90. "gptq": GPTQConfig,
  91. "aqlm": AqlmConfig,
  92. "quanto": QuantoConfig,
  93. "quark": QuarkConfig,
  94. "fp_quant": FPQuantConfig,
  95. "hqq": HqqConfig,
  96. "compressed-tensors": CompressedTensorsConfig,
  97. "fbgemm_fp8": FbgemmFp8Config,
  98. "higgs": HiggsConfig,
  99. "torchao": TorchAoConfig,
  100. "bitnet": BitNetQuantConfig,
  101. "vptq": VptqConfig,
  102. "spqr": SpQRConfig,
  103. "fp8": FineGrainedFP8Config,
  104. "auto-round": AutoRoundConfig,
  105. "mxfp4": Mxfp4Config,
  106. }
  107. logger = logging.get_logger(__name__)
  108. class AutoQuantizationConfig:
  109. """
  110. The Auto-HF quantization config class that takes care of automatically dispatching to the correct
  111. quantization config given a quantization config stored in a dictionary.
  112. """
  113. @classmethod
  114. def from_dict(cls, quantization_config_dict: dict):
  115. quant_method = quantization_config_dict.get("quant_method")
  116. # We need a special care for bnb models to make sure everything is BC ..
  117. if quantization_config_dict.get("load_in_8bit", False) or quantization_config_dict.get("load_in_4bit", False):
  118. suffix = "_4bit" if quantization_config_dict.get("load_in_4bit", False) else "_8bit"
  119. quant_method = QuantizationMethod.BITS_AND_BYTES + suffix
  120. elif quant_method is None:
  121. raise ValueError(
  122. "The model's quantization config from the arguments has no `quant_method` attribute. Make sure that the model has been correctly quantized"
  123. )
  124. if quant_method not in AUTO_QUANTIZATION_CONFIG_MAPPING:
  125. raise ValueError(
  126. f"Unknown quantization type, got {quant_method} - supported types are:"
  127. f" {list(AUTO_QUANTIZER_MAPPING.keys())}"
  128. )
  129. target_cls = AUTO_QUANTIZATION_CONFIG_MAPPING[quant_method]
  130. return target_cls.from_dict(quantization_config_dict)
  131. @classmethod
  132. def from_pretrained(cls, pretrained_model_name_or_path, **kwargs):
  133. model_config = AutoConfig.from_pretrained(pretrained_model_name_or_path, **kwargs)
  134. if getattr(model_config, "quantization_config", None) is None:
  135. raise ValueError(
  136. f"Did not found a `quantization_config` in {pretrained_model_name_or_path}. Make sure that the model is correctly quantized."
  137. )
  138. quantization_config_dict = model_config.quantization_config
  139. quantization_config = cls.from_dict(quantization_config_dict)
  140. # Update with potential kwargs that are passed through from_pretrained.
  141. quantization_config.update(**kwargs)
  142. return quantization_config
  143. class AutoHfQuantizer:
  144. """
  145. The Auto-HF quantizer class that takes care of automatically instantiating to the correct
  146. `HfQuantizer` given the `QuantizationConfig`.
  147. """
  148. @classmethod
  149. def from_config(cls, quantization_config: Union[QuantizationConfigMixin, dict], **kwargs):
  150. # Convert it to a QuantizationConfig if the q_config is a dict
  151. if isinstance(quantization_config, dict):
  152. quantization_config = AutoQuantizationConfig.from_dict(quantization_config)
  153. quant_method = quantization_config.quant_method
  154. # Again, we need a special care for bnb as we have a single quantization config
  155. # class for both 4-bit and 8-bit quantization
  156. if quant_method == QuantizationMethod.BITS_AND_BYTES:
  157. if quantization_config.load_in_8bit:
  158. quant_method += "_8bit"
  159. else:
  160. quant_method += "_4bit"
  161. if quant_method not in AUTO_QUANTIZER_MAPPING:
  162. raise ValueError(
  163. f"Unknown quantization type, got {quant_method} - supported types are:"
  164. f" {list(AUTO_QUANTIZER_MAPPING.keys())}"
  165. )
  166. target_cls = AUTO_QUANTIZER_MAPPING[quant_method]
  167. return target_cls(quantization_config, **kwargs)
  168. @classmethod
  169. def from_pretrained(cls, pretrained_model_name_or_path, **kwargs):
  170. quantization_config = AutoQuantizationConfig.from_pretrained(pretrained_model_name_or_path, **kwargs)
  171. return cls.from_config(quantization_config)
  172. @classmethod
  173. def merge_quantization_configs(
  174. cls,
  175. quantization_config: Union[dict, QuantizationConfigMixin],
  176. quantization_config_from_args: Optional[QuantizationConfigMixin],
  177. ):
  178. """
  179. handles situations where both quantization_config from args and quantization_config from model config are present.
  180. """
  181. if quantization_config_from_args is not None:
  182. warning_msg = (
  183. "You passed `quantization_config` or equivalent parameters to `from_pretrained` but the model you're loading"
  184. " already has a `quantization_config` attribute. The `quantization_config` from the model will be used."
  185. )
  186. else:
  187. warning_msg = ""
  188. if isinstance(quantization_config, dict):
  189. # Convert the config based on the type of quantization_config_from_args (e.g., AutoRoundConfig), which takes priority before automatic configuration dispatch.
  190. if isinstance(quantization_config_from_args, AutoRoundConfig):
  191. quantization_config = AutoRoundConfig.from_dict(quantization_config)
  192. else:
  193. quantization_config = AutoQuantizationConfig.from_dict(quantization_config)
  194. if (
  195. quantization_config_from_args is not None
  196. and quantization_config.__class__.__name__ != quantization_config_from_args.__class__.__name__
  197. ):
  198. raise ValueError(
  199. f"The model is quantized with {quantization_config.__class__.__name__} but you are passing a {quantization_config_from_args.__class__.__name__} config. "
  200. "Please make sure to pass the same quantization config class to `from_pretrained` with different loading attributes."
  201. )
  202. if (
  203. isinstance(
  204. quantization_config,
  205. (GPTQConfig, AwqConfig, AutoRoundConfig, FbgemmFp8Config, CompressedTensorsConfig, Mxfp4Config),
  206. )
  207. and quantization_config_from_args is not None
  208. ):
  209. loading_attr_dict = quantization_config_from_args.get_loading_attributes()
  210. for attr, val in loading_attr_dict.items():
  211. setattr(quantization_config, attr, val)
  212. warning_msg += f"However, loading attributes (e.g. {list(loading_attr_dict.keys())}) will be overwritten with the one you passed to `from_pretrained`. The rest will be ignored."
  213. if warning_msg != "" and not isinstance(quantization_config, Mxfp4Config):
  214. warnings.warn(warning_msg)
  215. else:
  216. # in the case of mxfp4, we don't want to print the warning message, bit confusing for users
  217. logger.info(warning_msg)
  218. return quantization_config
  219. @staticmethod
  220. def supports_quant_method(quantization_config_dict):
  221. quant_method = quantization_config_dict.get("quant_method", None)
  222. if quantization_config_dict.get("load_in_8bit", False) or quantization_config_dict.get("load_in_4bit", False):
  223. suffix = "_4bit" if quantization_config_dict.get("load_in_4bit", False) else "_8bit"
  224. quant_method = QuantizationMethod.BITS_AND_BYTES + suffix
  225. elif quant_method is None:
  226. raise ValueError(
  227. "The model's quantization config from the arguments has no `quant_method` attribute. Make sure that the model has been correctly quantized"
  228. )
  229. if quant_method not in AUTO_QUANTIZATION_CONFIG_MAPPING:
  230. logger.warning(
  231. f"Unknown quantization type, got {quant_method} - supported types are:"
  232. f" {list(AUTO_QUANTIZER_MAPPING.keys())}. Hence, we will skip the quantization. "
  233. "To remove the warning, you can delete the quantization_config attribute in config.json"
  234. )
  235. return False
  236. return True
  237. def register_quantization_config(method: str):
  238. """Register a custom quantization configuration."""
  239. def register_config_fn(cls):
  240. if method in AUTO_QUANTIZATION_CONFIG_MAPPING:
  241. raise ValueError(f"Config '{method}' already registered")
  242. if not issubclass(cls, QuantizationConfigMixin):
  243. raise TypeError("Config must extend QuantizationConfigMixin")
  244. AUTO_QUANTIZATION_CONFIG_MAPPING[method] = cls
  245. return cls
  246. return register_config_fn
  247. def register_quantizer(name: str):
  248. """Register a custom quantizer."""
  249. def register_quantizer_fn(cls):
  250. if name in AUTO_QUANTIZER_MAPPING:
  251. raise ValueError(f"Quantizer '{name}' already registered")
  252. if not issubclass(cls, HfQuantizer):
  253. raise ValueError("Quantizer must extend HfQuantizer")
  254. AUTO_QUANTIZER_MAPPING[name] = cls
  255. return cls
  256. return register_quantizer_fn
  257. def get_hf_quantizer(config, quantization_config, dtype, from_tf, from_flax, device_map, weights_only, user_agent):
  258. pre_quantized = hasattr(config, "quantization_config")
  259. if pre_quantized and not AutoHfQuantizer.supports_quant_method(config.quantization_config):
  260. pre_quantized = False
  261. if pre_quantized or quantization_config is not None:
  262. if pre_quantized:
  263. config.quantization_config = AutoHfQuantizer.merge_quantization_configs(
  264. config.quantization_config, quantization_config
  265. )
  266. else:
  267. config.quantization_config = quantization_config
  268. hf_quantizer = AutoHfQuantizer.from_config(
  269. config.quantization_config,
  270. pre_quantized=pre_quantized,
  271. )
  272. else:
  273. hf_quantizer = None
  274. if hf_quantizer is not None:
  275. hf_quantizer.validate_environment(
  276. dtype=dtype,
  277. from_tf=from_tf,
  278. from_flax=from_flax,
  279. device_map=device_map,
  280. weights_only=weights_only,
  281. )
  282. dtype = hf_quantizer.update_dtype(dtype)
  283. device_map = hf_quantizer.update_device_map(device_map)
  284. config = hf_quantizer.update_tp_plan(config)
  285. config = hf_quantizer.update_ep_plan(config)
  286. # In order to ensure popular quantization methods are supported. Can be disable with `disable_telemetry`
  287. if not getattr(hf_quantizer.quantization_config, "dequantize", False):
  288. quant_method = hf_quantizer.quantization_config.quant_method
  289. user_agent["quant"] = getattr(quant_method, "value", quant_method)
  290. return hf_quantizer, config, dtype, device_map