# quantizer_hqq.py
# Copyright 2024 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
  14. from collections import defaultdict
  15. from typing import TYPE_CHECKING
  16. from ..integrations import prepare_for_hqq_linear
  17. from ..utils import is_hqq_available, is_torch_available, logging
  18. from .base import HfQuantizer
  19. from .quantizers_utils import get_module_from_name
  20. if TYPE_CHECKING:
  21. from ..modeling_utils import PreTrainedModel
  22. if is_torch_available():
  23. import torch
  24. if is_hqq_available():
  25. from hqq.core.quantize import HQQLinear
  26. # This is a compatibility hack. HQQ-quantized linear layers do not have a `weight` attribute,
  27. # but some models attempt to access `weight.dtype` during the forward pass. To prevent runtime errors,
  28. # we patch HQQLinear with a dummy `weight` property that returns an empty tensor with the correct dtype and device.
  29. @property
  30. def weight(self):
  31. return torch.empty(0, dtype=self.compute_dtype, device=self.device)
  32. HQQLinear.weight = weight
  33. logger = logging.get_logger(__name__)
  34. class HqqHfQuantizer(HfQuantizer):
  35. """
  36. HQQ quantizer base HF class.
  37. nn.Linear modules are first tagged with quant_config in _process_model_before_weight_loading().
  38. """
  39. use_keep_in_fp32_modules = False
  40. requires_parameters_quantization = True
  41. requires_calibration = False
  42. required_packages = ["hqq"]
  43. def __init__(self, quantization_config, **kwargs):
  44. if not is_hqq_available():
  45. raise ImportError(
  46. "A valid HQQ version (>=0.2.1) is not available. Please follow the instructions to install it: `https://github.com/mobiusml/hqq/`."
  47. )
  48. super().__init__(quantization_config, **kwargs)
  49. self.dtype = None
  50. self.using_multi_gpu = False
  51. # Keys that are serialized specifically by hqq
  52. self.hqq_keys = HQQLinear(None, None).state_dict_keys() - {"bias"}
  53. if kwargs.get("from_tf", False) or kwargs.get("from_flax", False):
  54. raise ValueError(
  55. "Converting weights from tf/flax weights is currently not supported, please make"
  56. " sure the weights are in PyTorch format."
  57. )
  58. if self.dtype is None:
  59. if "dtype" in kwargs:
  60. self.dtype = kwargs["dtype"]
  61. else:
  62. self.dtype = torch.float32
  63. logger.info("Setting dtype to torch.float32 as the default value since it was not specified.")
  64. device_map = kwargs.get("device_map")
  65. if isinstance(device_map, dict):
  66. if "cpu" in device_map.values() or "disk" in device_map.values():
  67. raise ValueError(
  68. "You are attempting to use an HQQ model with a device_map that contains a CPU or disk device."
  69. " This is not supported. Please remove the CPU or disk device from the device_map."
  70. )
  71. else:
  72. self.using_multi_gpu = len(set(device_map.values())) > 1
  73. def update_missing_keys(
  74. self, model: "PreTrainedModel", missing_keys: list[str], prefix: str, **kwargs
  75. ) -> list[str]:
  76. if self.pre_quantized:
  77. return [key for key in missing_keys if ("weight" not in key)]
  78. else:
  79. return missing_keys
  80. # Adds missing keys for HQQLinear modules that are loaded but the model with initialized with torch.nn.Linear
  81. def update_expected_keys(
  82. self, model: "PreTrainedModel", expected_keys: list[str], loaded_keys: list[str]
  83. ) -> list[str]:
  84. if not self.pre_quantized:
  85. return expected_keys
  86. # Collects all quantizable (linear) layers
  87. def _find_hqq_quantizable_layers(model, layers):
  88. for name, module in model.named_children():
  89. if isinstance(module, (torch.nn.Linear)):
  90. layers.add(module.name)
  91. _find_hqq_quantizable_layers(module, layers)
  92. new_keys = set(expected_keys)
  93. # Name modules
  94. for name, module in model.named_modules():
  95. module.name = name
  96. # valid modules are Linear layers that have HQQLinear state_dict. We ignore skip_modules and any layers with Linear state_dict() params
  97. _valid_modules = set()
  98. _find_hqq_quantizable_layers(model, _valid_modules)
  99. # Remove skipped modules
  100. _skipped_modules = set()
  101. for _module in _valid_modules:
  102. for _skip_module in model.config.quantization_config["skip_modules"]:
  103. if _skip_module in _module:
  104. _skipped_modules.add(_module)
  105. _valid_modules -= _skipped_modules
  106. # Append new expected layers based on _ref_keys
  107. _ref_keys = HQQLinear(
  108. linear_layer=None,
  109. quant_config=None,
  110. compute_dtype=torch.float16,
  111. device="cpu",
  112. del_orig=False,
  113. ).state_dict_keys() - {"bias"}
  114. # Clean-up
  115. _rm_keys = set()
  116. for key in new_keys:
  117. if any(_module in key for _module in _valid_modules):
  118. _rm_keys.add(key)
  119. new_keys -= _rm_keys
  120. # At this point, new_keys contains all the keys of the layers that are NOT HQQLinear or torch.nn.Linear
  121. # Re-populate Linear/HQQLinear
  122. for _module in _valid_modules:
  123. if _module + ".weight" in loaded_keys:
  124. new_keys.add(_module + ".weight")
  125. else:
  126. new_keys.update({_module + "." + _ref_key for _ref_key in _ref_keys})
  127. if _module + ".bias" in loaded_keys:
  128. new_keys.add(_module + ".bias")
  129. return list(new_keys)
  130. def param_needs_quantization(self, model: "PreTrainedModel", param_name: str, **kwargs) -> bool:
  131. module, _ = get_module_from_name(model, param_name)
  132. # Since we do not prepare the modules in advance, we need every param of the Linear layer to go through
  133. # `create_quantized_param`, even when `self.is_quantized == True`
  134. return isinstance(module, torch.nn.Linear)
  135. def create_quantized_param(
  136. self,
  137. model: "PreTrainedModel",
  138. param_value: "torch.Tensor",
  139. param_name: str,
  140. target_device: "torch.device",
  141. **kwargs,
  142. ):
  143. module, tensor_name = get_module_from_name(model, param_name)
  144. module_name = param_name.rsplit(".", 1)[0]
  145. parent_module, node = get_module_from_name(model, module_name)
  146. quant_config = model.config.quantization_config["quant_config"]
  147. skip_modules = model.config.quantization_config["skip_modules"]
  148. # In this case we do not quantize this layer (it's explicitly skipped) -> simply load param
  149. if any(skip_module in module.name for skip_module in skip_modules):
  150. module.load_state_dict(
  151. {tensor_name: param_value.to(device=target_device, dtype=self.dtype)}, strict=False, assign=True
  152. )
  153. return
  154. # We need this hack as the model is not pre-prepared as an empty skeleton on meta device
  155. if self.pre_quantized:
  156. # Save them for later
  157. if not hasattr(self, "hqq_params"):
  158. self.hqq_params = defaultdict(dict)
  159. self.hqq_params[module_name].update({tensor_name: param_value})
  160. hqq_params = self.hqq_params[module_name]
  161. # If they are all present and saved, make it a HQQLinear layer! (we cannot do it param after param because
  162. # hqq does not support it...)
  163. if all(k in hqq_params for k in self.hqq_keys) and ("bias" in hqq_params or module.bias is None):
  164. hqq_layer = HQQLinear(
  165. linear_layer=None,
  166. quant_config=None,
  167. compute_dtype=self.dtype,
  168. device=target_device,
  169. del_orig=False,
  170. )
  171. hqq_layer.load_state_dict(hqq_params)
  172. if hqq_layer.bias is not None and isinstance(hqq_layer.bias, torch.Tensor):
  173. hqq_layer.bias = torch.nn.Parameter(hqq_layer.bias)
  174. if self.using_multi_gpu:
  175. hqq_layer = self._patch_layer_for_multigpu(hqq_layer)
  176. setattr(parent_module, node, hqq_layer)
  177. del self.hqq_params[module_name], module
  178. return
  179. # Load param in the module (without caring about device or dtype, it will be changed later)
  180. module.load_state_dict({tensor_name: param_value}, strict=False, assign=True)
  181. # If both the weight and bias have already been loaded, time to quantize!
  182. module_is_ready = module.weight.device.type != "meta" and (
  183. module.bias is None or module.bias.device.type != "meta"
  184. )
  185. if module_is_ready:
  186. module_tag = ".".join(module.name.split(".")[-2:])
  187. if "weight_quant_params" in quant_config:
  188. module_quant_config = quant_config
  189. elif module_tag in quant_config:
  190. module_quant_config = quant_config[module_tag]
  191. hqq_layer = HQQLinear(
  192. module,
  193. quant_config=module_quant_config,
  194. compute_dtype=self.dtype,
  195. device=target_device,
  196. del_orig=True,
  197. )
  198. if hqq_layer.bias is not None and isinstance(hqq_layer.bias, torch.Tensor):
  199. hqq_layer.bias = torch.nn.Parameter(hqq_layer.bias)
  200. if self.using_multi_gpu:
  201. hqq_layer = self._patch_layer_for_multigpu(hqq_layer)
  202. setattr(parent_module, node, hqq_layer)
  203. def _patch_layer_for_multigpu(self, hqq_layer):
  204. def forward_with_device(self, x):
  205. out = torch.matmul(x.to(self.device), self.dequantize().t())
  206. if self.bias is not None:
  207. out += self.bias
  208. return out
  209. hqq_layer.forward = lambda x: forward_with_device(hqq_layer, x)
  210. return hqq_layer
  211. def _process_model_before_weight_loading(
  212. self,
  213. model: "PreTrainedModel",
  214. **kwargs,
  215. ):
  216. # Add the corresponding quant_config to each valid module. This allows us to do the actual nn.Linear -> HQQLinear conversion in create_quantized_param().
  217. # prepare_for_hqq_linear() also sets the right quantization config inside the model (model.config.quantization_config) and the layers (hqq_layer.quant_config)
  218. model = prepare_for_hqq_linear(model, quantization_config=self.quantization_config)
  219. def _process_model_after_weight_loading(self, model: "PreTrainedModel", **kwargs):
  220. model.is_hqq_quantized = True
  221. model.is_hqq_serializable = self.is_serializable()
  222. return model
  223. def is_serializable(self, safe_serialization=None):
  224. return True
  225. @property
  226. def is_trainable(self) -> bool:
  227. return True