# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import importlib
import importlib.metadata
from collections import defaultdict
from functools import cached_property
from typing import TYPE_CHECKING, Optional, Union

from packaging import version

from .base import HfQuantizer
from .quantizers_utils import get_module_from_name


if TYPE_CHECKING:
    from ..modeling_utils import PreTrainedModel

from ..utils import (
    ACCELERATE_MIN_VERSION,
    is_accelerate_available,
    is_bitsandbytes_available,
    is_torch_available,
    is_torch_hpu_available,
    is_torch_npu_available,
    is_torch_xpu_available,
    logging,
)


if is_torch_available():
    import torch

    from ..pytorch_utils import Conv1D

logger = logging.get_logger(__name__)
  37. class Bnb4BitHfQuantizer(HfQuantizer):
  38. """
  39. 4-bit quantization from bitsandbytes.py quantization method:
  40. before loading: converts transformer layers into Linear4bit during loading: load 16bit weight and pass to the
  41. layer object after: quantizes individual weights in Linear4bit into 4bit at the first .cuda() call
  42. saving:
  43. from state dict, as usual; saves weights and `quant_state` components
  44. loading:
  45. need to locate `quant_state` components and pass to Param4bit constructor
  46. """
  47. use_keep_in_fp32_modules = True
  48. requires_parameters_quantization = True
  49. requires_calibration = False
  50. required_packages = ["bitsandbytes", "accelerate"]
  51. def __init__(self, quantization_config, **kwargs):
  52. super().__init__(quantization_config, **kwargs)
  53. if self.quantization_config.llm_int8_skip_modules is not None:
  54. self.modules_to_not_convert = self.quantization_config.llm_int8_skip_modules
  55. # This describes the additional items that are saved on the state dict (on the params themselves)
  56. self.bnb_keys = [
  57. f"quant_state.bitsandbytes__{self.quantization_config.bnb_4bit_quant_type}",
  58. "absmax",
  59. "quant_map",
  60. ]
  61. if self.quantization_config.bnb_4bit_use_double_quant:
  62. self.bnb_keys.extend(["nested_absmax", "nested_quant_map"])
  63. def validate_environment(self, *args, **kwargs):
  64. if not is_accelerate_available():
  65. raise ImportError(
  66. f"Using `bitsandbytes` 4-bit quantization requires Accelerate: `pip install 'accelerate>={ACCELERATE_MIN_VERSION}'`"
  67. )
  68. if not is_bitsandbytes_available(check_library_only=True):
  69. raise ImportError(
  70. "Using `bitsandbytes` 4-bit quantization requires the latest version of bitsandbytes: `pip install -U bitsandbytes`"
  71. )
  72. if not is_torch_available():
  73. raise ImportError(
  74. "The bitsandbytes library requires PyTorch but it was not found in your environment. "
  75. "You can install it with `pip install torch`."
  76. )
  77. # `bitsandbytes` versions older than 0.43.1 eagerly require CUDA at import time,
  78. # so those versions of the library are practically only available when CUDA is too.
  79. if version.parse(importlib.metadata.version("bitsandbytes")) < version.parse("0.43.1"):
  80. if not torch.cuda.is_available():
  81. raise ImportError(
  82. "The installed version of bitsandbytes (<0.43.1) requires CUDA, but CUDA is not available. "
  83. "You may need to install PyTorch with CUDA support or upgrade bitsandbytes to >=0.43.1."
  84. )
  85. from ..integrations import validate_bnb_backend_availability
  86. from ..utils import is_bitsandbytes_multi_backend_available
  87. bnb_multibackend_is_enabled = is_bitsandbytes_multi_backend_available()
  88. validate_bnb_backend_availability(raise_exception=True)
  89. if kwargs.get("from_tf", False) or kwargs.get("from_flax", False):
  90. raise ValueError(
  91. "Converting into 4-bit or 8-bit weights from tf/flax weights is currently not supported, please make"
  92. " sure the weights are in PyTorch format."
  93. )
  94. device_map = kwargs.get("device_map")
  95. if (
  96. device_map is not None
  97. and isinstance(device_map, dict)
  98. and not self.quantization_config.llm_int8_enable_fp32_cpu_offload
  99. ):
  100. device_map_without_lm_head = {
  101. key: device_map[key] for key in device_map if key not in self.modules_to_not_convert
  102. }
  103. if set(device_map.values()) == {"cpu"} and bnb_multibackend_is_enabled:
  104. pass
  105. elif "cpu" in device_map_without_lm_head.values() or "disk" in device_map_without_lm_head.values():
  106. raise ValueError(
  107. "Some modules are dispatched on the CPU or the disk. Make sure you have enough GPU RAM to fit the "
  108. "quantized model. If you want to dispatch the model on the CPU or the disk while keeping these modules "
  109. "in 32-bit, you need to set `llm_int8_enable_fp32_cpu_offload=True` and pass a custom `device_map` to "
  110. "`from_pretrained`. Check "
  111. "https://huggingface.co/docs/transformers/main/en/main_classes/quantization#offload-between-cpu-and-gpu "
  112. "for more details. "
  113. )
  114. def adjust_target_dtype(self, target_dtype: "torch.dtype") -> "torch.dtype":
  115. if version.parse(importlib.metadata.version("accelerate")) > version.parse("0.19.0"):
  116. from accelerate.utils import CustomDtype
  117. if target_dtype != torch.int8:
  118. logger.info("target_dtype {target_dtype} is replaced by `CustomDtype.INT4` for 4-bit BnB quantization")
  119. return CustomDtype.INT4
  120. else:
  121. raise ValueError(
  122. "You are using `device_map='auto'` on a 4bit loaded version of the model. To automatically compute"
  123. " the appropriate device map, you should upgrade your `accelerate` library,"
  124. "`pip install --upgrade accelerate` or install it from source to support fp4 auto device map"
  125. "calculation. You may encounter unexpected behavior, or pass your own device map"
  126. )
  127. def update_unexpected_keys(self, model, unexpected_keys: list[str]) -> list[str]:
  128. return [k for k in unexpected_keys if not any(k.endswith(x) for x in self.bnb_keys)]
  129. def param_needs_quantization(self, model: "PreTrainedModel", param_name: str, **kwargs) -> bool:
  130. import bitsandbytes as bnb
  131. # They are on the params themselves, so we cannot easily extract the module from the name
  132. if any(param_name.endswith(x) for x in self.bnb_keys):
  133. return True
  134. module, name = get_module_from_name(model, param_name)
  135. return isinstance(module, bnb.nn.Linear4bit) and name != "bias"
  136. def get_param_name(self, param_name: str) -> str:
  137. """
  138. Get the right param_name in order to get the module associated with the param.
  139. This is useful for quantized stats lile absmax or quant_map as we need to update the param_name to get the module as they are stored in ...weight.absmax.
  140. """
  141. if self.pre_quantized:
  142. # We need to get the param name of quantized weights and not its components. Otherwise, we won't be able to get the nn.Module associated.
  143. if any(param_name.endswith(x) for x in self.bnb_keys):
  144. param_name = (
  145. param_name.rsplit(".", 1)[0] if "quant_state." not in param_name else param_name.rsplit(".", 2)[0]
  146. )
  147. return param_name
  148. def create_quantized_param(
  149. self,
  150. model: "PreTrainedModel",
  151. param_value: "torch.Tensor",
  152. param_name: str,
  153. target_device: "torch.device",
  154. **kwargs,
  155. ):
  156. import bitsandbytes as bnb
  157. full_name = param_name
  158. # update param name to get the weights instead of the quantized stats
  159. param_name = self.get_param_name(param_name)
  160. module, tensor_name = get_module_from_name(model, param_name)
  161. # `torch.Tensor.to(<int num>)` is not supported by `torch_npu` (see this [issue](https://github.com/Ascend/pytorch/issues/16)).
  162. if isinstance(target_device, int) and is_torch_npu_available():
  163. target_device = f"npu:{target_device}"
  164. # construct `new_value` for the module._parameters[tensor_name]
  165. if self.pre_quantized:
  166. module_name = param_name.rsplit(".", 1)[0]
  167. # Save the states for later quantization when they are all gathered
  168. if not hasattr(self, "param_quant_stats"):
  169. self.param_quant_stats = defaultdict(dict)
  170. self.param_quant_stats[module_name].update({full_name: param_value})
  171. # We are ready for quantization in this case (note, the +1 is for the weight itself)
  172. if len(self.param_quant_stats[module_name]) == len(self.bnb_keys) + 1:
  173. param_kwargs = {}
  174. if self.is_bnb_supports_quant_storage_module:
  175. param_kwargs["module"] = module
  176. weight = self.param_quant_stats[module_name].pop(f"{module_name}.weight")
  177. new_value = bnb.nn.Params4bit.from_prequantized(
  178. data=weight,
  179. quantized_stats=self.param_quant_stats[module_name],
  180. requires_grad=False,
  181. device=target_device,
  182. **param_kwargs,
  183. )
  184. # Set it
  185. module._parameters[tensor_name] = new_value
  186. # Delete the states
  187. del self.param_quant_stats[module_name]
  188. else:
  189. new_value = param_value.to("cpu")
  190. old_value = getattr(module, tensor_name)
  191. # Support models using `Conv1D` in place of `nn.Linear` (e.g. openai-community/gpt2) by transposing the weight matrix prior to quantization.
  192. # Since weights are saved in the correct "orientation", we skip transposing when loading.
  193. if issubclass(module.source_cls, Conv1D):
  194. new_value = new_value.T
  195. kwargs = old_value.__dict__
  196. kwargs.pop("_is_hf_initialized", None)
  197. new_value = bnb.nn.Params4bit(new_value, requires_grad=False, **kwargs).to(target_device)
  198. module._parameters[tensor_name] = new_value
  199. # Copied from transformers.quantizers.quantizer_bnb_8bit.Bnb8BitHfQuantizer.adjust_max_memory
  200. def adjust_max_memory(self, max_memory: dict[str, Union[int, str]]) -> dict[str, Union[int, str]]:
  201. # need more space for buffers that are created during quantization
  202. max_memory = {key: val * 0.90 for key, val in max_memory.items()}
  203. return max_memory
  204. # Copied from transformers.quantizers.quantizer_bnb_8bit.Bnb8BitHfQuantizer.update_dtype
  205. def update_dtype(self, dtype: "torch.dtype") -> "torch.dtype":
  206. if dtype is None:
  207. # We force the `dtype` to be float16, this is a requirement from `bitsandbytes`
  208. logger.info(
  209. "Overriding dtype=%s with `dtype=torch.float16` due to "
  210. "requirements of `bitsandbytes` to enable model loading in 8-bit or 4-bit. "
  211. "Pass your own dtype to specify the dtype of the remaining non-linear layers or pass"
  212. " dtype=torch.float16 to remove this warning.",
  213. dtype,
  214. )
  215. dtype = torch.float16
  216. return dtype
  217. def update_device_map(self, device_map):
  218. if device_map is None:
  219. if torch.cuda.is_available():
  220. device_map = {"": torch.cuda.current_device()}
  221. elif is_torch_npu_available():
  222. device_map = {"": f"npu:{torch.npu.current_device()}"}
  223. elif is_torch_hpu_available():
  224. device_map = {"": f"hpu:{torch.hpu.current_device()}"}
  225. elif is_torch_xpu_available():
  226. device_map = {"": torch.xpu.current_device()}
  227. else:
  228. device_map = {"": "cpu"}
  229. logger.info(
  230. "The device_map was not initialized. "
  231. f"Setting device_map to {device_map}. "
  232. "If you want to use the model for inference, please set device_map ='auto' "
  233. )
  234. return device_map
  235. # Copied from transformers.quantizers.quantizer_bnb_8bit.Bnb8BitHfQuantizer._process_model_before_weight_loading
  236. def _process_model_before_weight_loading(
  237. self,
  238. model: "PreTrainedModel",
  239. device_map,
  240. keep_in_fp32_modules: Optional[list[str]] = None,
  241. **kwargs,
  242. ):
  243. from ..integrations import replace_with_bnb_linear
  244. llm_int8_enable_fp32_cpu_offload = self.quantization_config.llm_int8_enable_fp32_cpu_offload
  245. self.modules_to_not_convert = self.get_modules_to_not_convert(
  246. model, self.quantization_config.llm_int8_skip_modules, keep_in_fp32_modules
  247. )
  248. # Extend `self.modules_to_not_convert` to keys that are supposed to be offloaded to `cpu` or `disk`
  249. if isinstance(device_map, dict) and len(device_map.keys()) > 1:
  250. keys_on_cpu = [key for key, value in device_map.items() if value in ["disk", "cpu"]]
  251. if len(keys_on_cpu) > 0 and not llm_int8_enable_fp32_cpu_offload:
  252. raise ValueError(
  253. "If you want to offload some keys to `cpu` or `disk`, you need to set "
  254. "`llm_int8_enable_fp32_cpu_offload=True`. Note that these modules will not be "
  255. " converted to 8-bit but kept in 32-bit."
  256. )
  257. self.modules_to_not_convert.extend(keys_on_cpu)
  258. model = replace_with_bnb_linear(
  259. model, modules_to_not_convert=self.modules_to_not_convert, quantization_config=self.quantization_config
  260. )
  261. model.config.quantization_config = self.quantization_config
  262. # Copied from transformers.quantizers.quantizer_bnb_8bit.Bnb8BitHfQuantizer._process_model_after_weight_loading with 8bit->4bit
  263. def _process_model_after_weight_loading(self, model: "PreTrainedModel", **kwargs):
  264. model.is_loaded_in_4bit = True
  265. model.is_4bit_serializable = self.is_serializable()
  266. return model
  267. def is_serializable(self, safe_serialization=None):
  268. _is_4bit_serializable = version.parse(importlib.metadata.version("bitsandbytes")) >= version.parse("0.41.3")
  269. if not _is_4bit_serializable:
  270. logger.warning(
  271. "You are calling `save_pretrained` to a 4-bit converted model, but your `bitsandbytes` version doesn't support it. "
  272. "If you want to save 4-bit models, make sure to have `bitsandbytes>=0.41.3` installed."
  273. )
  274. return False
  275. return True
  276. @cached_property
  277. def is_bnb_supports_quant_storage_module(self) -> bool:
  278. """
  279. determines if the current version of bitsandbytes supports
  280. the `module` parameter in `Params4bit.from_prequantized`
  281. :return:
  282. """
  283. return version.parse(importlib.metadata.version("bitsandbytes")) >= version.parse("0.43.3")
  284. @property
  285. def is_trainable(self) -> bool:
  286. return True
  287. def _dequantize(self, model):
  288. from ..integrations import dequantize_and_replace
  289. model = dequantize_and_replace(
  290. model, self.modules_to_not_convert, quantization_config=self.quantization_config
  291. )
  292. return model