quantizer_torchao.py 20 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443
  1. # Copyright 2024 The HuggingFace Inc. team. All rights reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. import importlib
  15. import re
  16. import types
  17. from collections import defaultdict
  18. from typing import TYPE_CHECKING, Optional, Union
  19. from packaging import version
  20. from .base import HfQuantizer
  21. from .quantizers_utils import get_module_from_name
  22. if TYPE_CHECKING:
  23. from ..modeling_utils import PreTrainedModel
  24. from safetensors import safe_open
  25. from ..utils import is_torch_available, is_torchao_available, logging
  26. if is_torch_available():
  27. import torch
  28. import torch.nn as nn
  29. if is_torchao_available():
  30. import torchao
  31. if version.parse(importlib.metadata.version("torchao")) >= version.parse("0.14.0"):
  32. from torchao.prototype.safetensors.safetensors_support import (
  33. flatten_tensor_state_dict,
  34. unflatten_tensor_state_dict,
  35. )
  36. from torchao.prototype.safetensors.safetensors_utils import is_metadata_torchao
  37. logger = logging.get_logger(__name__)
  38. def fuzzy_match_size(config_name: str) -> Optional[str]:
  39. """
  40. Extract the size digit from strings like "4weight", "8weight".
  41. Returns the digit as an integer if found, otherwise None.
  42. """
  43. config_name = config_name.lower()
  44. str_match = re.search(r"(\d)weight", config_name)
  45. if str_match:
  46. return str_match.group(1)
  47. return None
  48. def _quantization_type(weight):
  49. from torchao.dtypes import AffineQuantizedTensor
  50. from torchao.quantization.linear_activation_quantized_tensor import LinearActivationQuantizedTensor
  51. if isinstance(weight, AffineQuantizedTensor):
  52. return f"{weight.__class__.__name__}({weight._quantization_type()})"
  53. if isinstance(weight, LinearActivationQuantizedTensor):
  54. return f"{weight.__class__.__name__}(activation={weight.input_quant_func}, weight={_quantization_type(weight.original_weight_tensor)})"
  55. def _linear_extra_repr(self):
  56. weight = _quantization_type(self.weight)
  57. if weight is None:
  58. return f"in_features={self.weight.shape[1]}, out_features={self.weight.shape[0]}, weight=None"
  59. else:
  60. return f"in_features={self.weight.shape[1]}, out_features={self.weight.shape[0]}, weight={weight}"
  61. if is_torchao_available():
  62. SUPPORTED_SAFE_SERIALIZATION_CONFIGS = [
  63. torchao.quantization.Float8WeightOnlyConfig,
  64. torchao.quantization.Float8DynamicActivationFloat8WeightConfig,
  65. ]
  66. TORCHAO_VERSION = version.parse(importlib.metadata.version("torchao"))
  67. class TorchAoHfQuantizer(HfQuantizer):
  68. """
  69. Quantizer for torchao: https://github.com/pytorch/ao/
  70. """
  71. requires_parameters_quantization = True
  72. requires_calibration = False
  73. required_packages = ["torchao"]
  74. def __init__(self, quantization_config, **kwargs):
  75. super().__init__(quantization_config, **kwargs)
  76. if isinstance(self.quantization_config.quant_type, str):
  77. is_int_4 = "int4" in self.quantization_config.quant_type
  78. else:
  79. config_name = self.quantization_config.quant_type.__class__.__name__
  80. is_int_4 = fuzzy_match_size(config_name) == "4"
  81. # TODO: better way to get the serialized key names? Hard to read from torchao codebase
  82. if is_int_4:
  83. self.weight_ao_keys = ["qdata", "scale", "zero_point"]
  84. else:
  85. self.weight_ao_keys = ["qdata", "scale"]
  86. # Instead of serializing the simple torch.Tensor like usual, torchao adds a `:_data` suffix so we need this
  87. self.full_ao_keys = self.weight_ao_keys + ["_data"]
  88. def validate_environment(self, *args, **kwargs):
  89. if not is_torchao_available():
  90. raise ImportError("Loading an torchao quantized model requires torchao library (`pip install torchao`)")
  91. self.offload = False
  92. device_map = kwargs.get("device_map")
  93. if isinstance(device_map, dict):
  94. if ("disk" in device_map.values() or "cpu" in device_map.values()) and len(device_map) > 1:
  95. self.offload = True
  96. if self.pre_quantized and "disk" in device_map.values():
  97. raise ValueError(
  98. "You are attempting to perform disk offload with a pre-quantized torchao model "
  99. "This is not supported yet . Please remove the disk device from the device_map."
  100. )
  101. if self.pre_quantized:
  102. weights_only = kwargs.get("weights_only")
  103. if weights_only:
  104. torch_version = version.parse(importlib.metadata.version("torch"))
  105. if torch_version < version.parse("2.5.0"):
  106. raise RuntimeError(
  107. f"In order to use torchao pre-quantized model, you need to have torch>=2.5.0. However, the current version is {torch_version}."
  108. f" You can also set with `weights_only=False` in `from_pretrained` if you don't want to update torch"
  109. )
  110. def update_dtype(self, dtype):
  111. if self.quantization_config.quant_type == "int4_weight_only":
  112. if dtype is not None and dtype != torch.bfloat16:
  113. logger.warning_once(
  114. f"Setting dtype to {dtype} for int4_weight_only quantization, but only bfloat16 is supported right now. Please set the dtype to bfloat16."
  115. )
  116. if dtype is None:
  117. logger.warning_once(
  118. "Setting dtype to torch.bfloat16 for int4_weight_only quantization since only bfloat16 is supported right now. Please set dtype=torch.bfloat16 to remove this warning."
  119. )
  120. dtype = torch.bfloat16
  121. if self.quantization_config.quant_type == "int8_dynamic_activation_int8_weight":
  122. if dtype is None:
  123. logger.info(
  124. "Setting dtype to torch.float32 for int8_dynamic_activation_int8_weight quantization as no dtype was specified in from_pretrained"
  125. )
  126. # we need to set the dtype, otherwise we have dtype mismatch when performing the quantized linear op
  127. dtype = torch.float32
  128. return dtype
  129. def get_state_dict_and_metadata(self, model, safe_serialization: Optional[bool] = False):
  130. """
  131. If the model is safe serializable, we flatten the state dict of tensor subclasses so that it is compatible with
  132. the safetensors format.
  133. """
  134. if type(self.quantization_config.quant_type) in SUPPORTED_SAFE_SERIALIZATION_CONFIGS and safe_serialization:
  135. if TORCHAO_VERSION >= version.parse("0.14.0"):
  136. return flatten_tensor_state_dict(model.state_dict())
  137. else:
  138. raise RuntimeError(
  139. f"In order to use safetensors with torchao, please use torchao version >= 0.14.0. Current version: {TORCHAO_VERSION}"
  140. )
  141. else:
  142. return None, {}
  143. def adjust_target_dtype(self, dtype: "torch.dtype") -> "torch.dtype":
  144. if version.parse(importlib.metadata.version("accelerate")) > version.parse("0.19.0"):
  145. from accelerate.utils import CustomDtype
  146. # Import AOBaseConfig directly since we know we have the right version
  147. if self.quantization_config._get_ao_version() > version.Version("0.9.0"):
  148. from torchao.core.config import AOBaseConfig
  149. quant_type = self.quantization_config.quant_type
  150. if isinstance(quant_type, AOBaseConfig):
  151. # Extract size digit using fuzzy match on the class name
  152. config_name = quant_type.__class__.__name__
  153. size_digit = fuzzy_match_size(config_name)
  154. # Map the extracted digit to appropriate dtype
  155. if size_digit == "4":
  156. return CustomDtype.INT4
  157. else:
  158. # Default to int8
  159. return torch.int8
  160. # Original mapping for non-AOBaseConfig types
  161. map_to_target_dtype = {
  162. "int4_weight_only": CustomDtype.INT4,
  163. "int8_weight_only": torch.int8,
  164. "int8_dynamic_activation_int8_weight": torch.int8,
  165. "autoquant": None,
  166. }
  167. return map_to_target_dtype[self.quantization_config.quant_type]
  168. else:
  169. raise ValueError(
  170. "You are using `device_map='auto'` on a torchao quantized model. To automatically compute"
  171. " the appropriate device map, you should upgrade your `accelerate` library with "
  172. "`pip install --upgrade accelerate`"
  173. )
  174. def adjust_max_memory(self, max_memory: dict[str, Union[int, str]]) -> dict[str, Union[int, str]]:
  175. # need more space for the quantization parameters (e.g. scale). Tested with int4 wo and group size = 128
  176. max_memory = {key: val * 0.9 for key, val in max_memory.items()}
  177. return max_memory
  178. def _process_model_before_weight_loading(
  179. self, model: "PreTrainedModel", keep_in_fp32_modules: Optional[list[str]] = None, **kwargs
  180. ):
  181. self.modules_to_not_convert = self.get_modules_to_not_convert(
  182. model, self.quantization_config.modules_to_not_convert, keep_in_fp32_modules
  183. )
  184. if self.quantization_config.include_input_output_embeddings:
  185. input_emb = model.get_input_embeddings()
  186. input_emb_names = [name for name, module in model.named_modules() if id(module) == id(input_emb)]
  187. output_emb = model.get_output_embeddings()
  188. output_emb_names = [name for name, module in model.named_modules() if id(module) == id(output_emb)]
  189. self.modules_to_not_convert = [
  190. x for x in self.modules_to_not_convert if x not in input_emb_names + output_emb_names
  191. ]
  192. return
  193. def update_unexpected_keys(self, model, unexpected_keys: list[str]) -> list[str]:
  194. return [k for k in unexpected_keys if not any(k.endswith(x) for x in self.full_ao_keys)]
  195. def param_needs_quantization(self, model: "PreTrainedModel", param_name: str, **kwargs) -> bool:
  196. if self.quantization_config.quant_type == "autoquant":
  197. return False
  198. # check if the param_name is not in self.modules_to_not_convert
  199. if any(key + "." in param_name or key == param_name for key in self.modules_to_not_convert):
  200. return False
  201. elif any(param_name.endswith(f":{x}") for x in self.full_ao_keys):
  202. return True
  203. else:
  204. # we only quantize the weight of nn.Linear and nn.Embedding
  205. module, tensor_name = get_module_from_name(model, param_name)
  206. _QUANTIZABLE = [torch.nn.Linear]
  207. if self.quantization_config.include_input_output_embeddings:
  208. _QUANTIZABLE.append(torch.nn.Embedding)
  209. return isinstance(module, tuple(_QUANTIZABLE)) and tensor_name == "weight"
  210. def create_quantized_param(
  211. self,
  212. model: "PreTrainedModel",
  213. param_value: "torch.Tensor",
  214. param_name: str,
  215. target_device: "torch.device",
  216. **kwargs,
  217. ):
  218. """
  219. Each nn.Linear layer that needs to be quantized is processed here.
  220. First, we set the value the weight tensor, then we move it to the target device. Finally, we quantize the module.
  221. """
  222. from torchao.quantization import quantize_
  223. full_name = param_name
  224. # Those are the pre quantized weights
  225. if ":" in param_name:
  226. param_name = param_name.rsplit(":", 1)[0]
  227. module, tensor_name = get_module_from_name(model, param_name)
  228. if self.pre_quantized:
  229. # If it's a bias, no need to do anything special (except removing the ":_data" part of the key, but was
  230. # already done) - if it's unsafe-serialized (i.e. not safetensors), not need for anything either
  231. is_unsafe_serialization = ":" not in full_name
  232. if tensor_name == "bias" or is_unsafe_serialization:
  233. module._parameters[tensor_name] = torch.nn.Parameter(
  234. param_value.to(target_device), requires_grad=param_value.requires_grad
  235. )
  236. return
  237. # Sanity check for the new serialization format
  238. elif not (TORCHAO_VERSION >= version.parse("0.14.0") and is_metadata_torchao(self.metadata)):
  239. raise ValueError("To use `safetensors` serialization, you should have `torchao>=0.14.0` installed")
  240. # Save the states for later quantization when they are all gathered
  241. if not hasattr(self, "ao_params"):
  242. self.ao_params = defaultdict(dict)
  243. self.ao_params[param_name].update({full_name: param_value})
  244. # We are ready for quantization in this case (we retrieved all the needed keys)
  245. if len(self.ao_params[param_name]) == len(self.weight_ao_keys):
  246. new_param = unflatten_tensor_state_dict(self.ao_params[param_name], self.metadata)[param_name]
  247. # Set it
  248. module._parameters[tensor_name] = torch.nn.Parameter(
  249. new_param.to(target_device), requires_grad=new_param.requires_grad
  250. )
  251. # Free memory
  252. del self.ao_params[param_name]
  253. # Add repr to the module
  254. if isinstance(module, nn.Linear):
  255. module.extra_repr = types.MethodType(_linear_extra_repr, module)
  256. else:
  257. module._parameters[tensor_name] = torch.nn.Parameter(
  258. param_value, requires_grad=param_value.requires_grad
  259. ).to(target_device)
  260. # if we are quantizing tied parameters, to avoid tying the quantized weights
  261. # the correct order to do it is
  262. # 1. load the weight to model
  263. # 2. run tie_weights to populate the weights
  264. # 3. quantize
  265. input_embed = model.get_input_embeddings()
  266. if self.quantization_config.untie_embedding_weights and id(module) == id(input_embed):
  267. model.tie_weights()
  268. setattr(model.config.get_text_config(decoder=True), "tie_word_embeddings", False)
  269. # handle ModuleFqnToConfig, introduced in torchao 0.12.0+
  270. if self.quantization_config._get_ao_version() >= version.Version("0.12.0"):
  271. from torchao.quantization import ModuleFqnToConfig
  272. config = self.quantization_config.get_apply_tensor_subclass()
  273. if isinstance(config, ModuleFqnToConfig):
  274. module_fqn, _ = param_name.rsplit(".", 1)
  275. c = None
  276. if module_fqn in config.module_fqn_to_config:
  277. c = config.module_fqn_to_config[module_fqn]
  278. else:
  279. c = config.module_fqn_to_config.get("_default", None)
  280. if c is not None:
  281. # filter_fn: not filtering out any modules
  282. quantize_(module, c, filter_fn=lambda x, fqn: True)
  283. return
  284. quantize_(module, self.quantization_config.get_apply_tensor_subclass())
  285. def _process_model_after_weight_loading(self, model, **kwargs):
  286. """No process required for torchao quantized model"""
  287. if self.quantization_config.quant_type == "autoquant":
  288. from torchao import autoquant
  289. from torchao.quantization import ALL_AUTOQUANT_CLASS_LIST
  290. model = torch.compile(model, mode="max-autotune")
  291. model = autoquant(
  292. model,
  293. qtensor_class_list=ALL_AUTOQUANT_CLASS_LIST,
  294. set_inductor_config=False,
  295. **self.quantization_config.quant_type_kwargs,
  296. )
  297. return model
  298. return
  299. def is_serializable(self, safe_serialization=None) -> bool:
  300. if safe_serialization:
  301. _is_torchao_serializable = type(
  302. self.quantization_config.quant_type
  303. ) in SUPPORTED_SAFE_SERIALIZATION_CONFIGS and TORCHAO_VERSION >= version.parse("0.14.0")
  304. if not _is_torchao_serializable:
  305. logger.warning(
  306. f"torchao quantized model only supports safe serialization for {SUPPORTED_SAFE_SERIALIZATION_CONFIGS}, \
  307. and torchao version >= 0.14.0, please set `safe_serialization` to False for \
  308. {type(self.quantization_config.quant_type)} and {TORCHAO_VERSION}."
  309. )
  310. return _is_torchao_serializable
  311. _is_torchao_serializable = version.parse(importlib.metadata.version("huggingface_hub")) >= version.parse(
  312. "0.25.0"
  313. )
  314. if not _is_torchao_serializable:
  315. logger.warning("torchao quantized model is only serializable after huggingface_hub >= 0.25.0 ")
  316. if self.offload and self.quantization_config.modules_to_not_convert is None:
  317. logger.warning(
  318. "The model contains offloaded modules and these modules are not quantized. We don't recommend saving the model as we won't be able to reload them."
  319. "If you want to specify modules to not quantize, please specify modules_to_not_convert in the quantization_config."
  320. )
  321. return False
  322. return _is_torchao_serializable
  323. def get_accelerator_warm_up_factor(self):
  324. """
  325. This factor is used in caching_allocator_warmup to determine how many bytes to pre-allocate for accelerator warmup.
  326. - A factor of 2 means we pre-allocate the full memory footprint of the model.
  327. - A factor of 4 means we pre-allocate half of that, and so on
  328. However, when using TorchAO, calculating memory usage with param.numel() * param.element_size() doesn't give the correct size for quantized weights (like int4 or int8)
  329. That's because TorchAO internally represents quantized tensors using subtensors and metadata, and the reported element_size() still corresponds to the dtype
  330. not the actual bit-width of the quantized data.
  331. To correct for this:
  332. - Use a division factor of 8 for int4 weights
  333. - Use a division factor of 4 for int8 weights
  334. """
  335. if self.quantization_config._get_ao_version() > version.Version("0.9.0"):
  336. from torchao.core.config import AOBaseConfig
  337. quant_type = self.quantization_config.quant_type
  338. # For autoquant case, it will be treated in the string implementation below in map_to_target_dtype
  339. if isinstance(quant_type, AOBaseConfig):
  340. # Extract size digit using fuzzy match on the class name
  341. config_name = quant_type.__class__.__name__
  342. size_digit = fuzzy_match_size(config_name)
  343. if size_digit == "4":
  344. return 8
  345. else:
  346. return 4
  347. # Original mapping for non-AOBaseConfig types
  348. map_to_target_dtype = {
  349. "int4_weight_only": 8,
  350. "int8_weight_only": 4,
  351. "int8_dynamic_activation_int8_weight": 4,
  352. "autoquant": 4,
  353. }
  354. return map_to_target_dtype[self.quantization_config.quant_type]
  355. @property
  356. def is_trainable(self) -> bool:
  357. supported_quant_types_for_training = [
  358. "int8_weight_only",
  359. "int8_dynamic_activation_int8_weight",
  360. ]
  361. return self.quantization_config.quant_type in supported_quant_types_for_training
  362. @property
  363. def is_compileable(self) -> bool:
  364. return True
  365. def set_metadata(self, checkpoint_files: list[str]):
  366. if checkpoint_files[0].endswith(".safetensors"):
  367. metadata = {}
  368. for checkpoint in checkpoint_files:
  369. with safe_open(checkpoint, framework="pt") as f:
  370. metadata_ = f.metadata() or {}
  371. metadata.update(metadata_)
  372. # Save it
  373. self.metadata = metadata