bitsandbytes.py 23 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542
  1. import importlib.metadata
  2. import inspect
  3. import warnings
  4. from copy import deepcopy
  5. from inspect import signature
  6. from packaging import version
  7. from ..utils import (
  8. get_available_devices,
  9. is_accelerate_available,
  10. is_bitsandbytes_available,
  11. is_bitsandbytes_multi_backend_available,
  12. is_torch_available,
  13. logging,
  14. )
  15. if is_bitsandbytes_available():
  16. import bitsandbytes as bnb
  17. import torch
  18. import torch.nn as nn
  19. from ..pytorch_utils import Conv1D
  20. if is_accelerate_available():
  21. import accelerate
  22. from accelerate import init_empty_weights
  23. from accelerate.hooks import add_hook_to_module, remove_hook_from_module
  24. from accelerate.utils import find_tied_parameters
# Module-level logger used by the helpers in this file.
logger = logging.get_logger(__name__)
  26. def set_module_quantized_tensor_to_device(module, tensor_name, device, value=None, quantized_stats=None):
  27. """
  28. A helper function to set a given tensor (parameter of buffer) of a module on a specific device (note that doing
  29. `param.to(device)` creates a new tensor not linked to the parameter, which is why we need this function). The
  30. function is adapted from `set_module_tensor_to_device` function from accelerate that is adapted to support the
  31. class `Int8Params` from `bitsandbytes`.
  32. Args:
  33. module (`torch.nn.Module`):
  34. The module in which the tensor we want to move lives.
  35. tensor_name (`str`):
  36. The full name of the parameter/buffer.
  37. device (`int`, `str` or `torch.device`):
  38. The device on which to set the tensor.
  39. value (`torch.Tensor`, *optional*):
  40. The value of the tensor (useful when going from the meta device to any other device).
  41. quantized_stats (`dict[str, Any]`, *optional*):
  42. Dict with items for either 4-bit or 8-bit serialization
  43. """
  44. # Recurse if needed
  45. if "." in tensor_name:
  46. splits = tensor_name.split(".")
  47. for split in splits[:-1]:
  48. new_module = getattr(module, split)
  49. if new_module is None:
  50. raise ValueError(f"{module} has no attribute {split}.")
  51. module = new_module
  52. tensor_name = splits[-1]
  53. if tensor_name not in module._parameters and tensor_name not in module._buffers:
  54. raise ValueError(f"{module} does not have a parameter or a buffer named {tensor_name}.")
  55. is_buffer = tensor_name in module._buffers
  56. old_value = getattr(module, tensor_name)
  57. if old_value.device == torch.device("meta") and device not in ["meta", torch.device("meta")] and value is None:
  58. raise ValueError(f"{tensor_name} is on the meta device, we need a `value` to put in on {device}.")
  59. prequantized_loading = quantized_stats is not None
  60. if is_buffer or not is_bitsandbytes_available():
  61. is_8bit = False
  62. is_4bit = False
  63. else:
  64. is_4bit = hasattr(bnb.nn, "Params4bit") and isinstance(module._parameters[tensor_name], bnb.nn.Params4bit)
  65. is_8bit = isinstance(module._parameters[tensor_name], bnb.nn.Int8Params)
  66. if is_8bit or is_4bit:
  67. param = module._parameters[tensor_name]
  68. if param.device.type != "cuda":
  69. if value is None:
  70. new_value = old_value.to(device)
  71. elif isinstance(value, torch.Tensor):
  72. new_value = value.to("cpu")
  73. else:
  74. new_value = torch.tensor(value, device="cpu")
  75. # Support models using `Conv1D` in place of `nn.Linear` (e.g. openai-community/gpt2) by transposing the weight matrix prior to quantization.
  76. # Since weights are saved in the correct "orientation", we skip transposing when loading.
  77. if issubclass(module.source_cls, Conv1D) and not prequantized_loading:
  78. new_value = new_value.T
  79. kwargs = old_value.__dict__
  80. if prequantized_loading != (new_value.dtype in (torch.int8, torch.uint8)):
  81. raise ValueError(
  82. f"Value dtype `{new_value.dtype}` is not compatible with parameter quantization status."
  83. )
  84. if is_8bit:
  85. is_8bit_serializable = version.parse(importlib.metadata.version("bitsandbytes")) > version.parse(
  86. "0.37.2"
  87. )
  88. if new_value.dtype in (torch.int8, torch.uint8) and not is_8bit_serializable:
  89. raise ValueError(
  90. "Detected int8 weights but the version of bitsandbytes is not compatible with int8 serialization. "
  91. "Make sure to download the latest `bitsandbytes` version. `pip install --upgrade bitsandbytes`."
  92. )
  93. new_value = bnb.nn.Int8Params(new_value, requires_grad=False, **kwargs).to(device)
  94. if prequantized_loading:
  95. setattr(new_value, "SCB", quantized_stats["SCB"].to(device))
  96. elif is_4bit:
  97. if prequantized_loading:
  98. is_4bit_serializable = version.parse(importlib.metadata.version("bitsandbytes")) >= version.parse(
  99. "0.41.3"
  100. )
  101. if new_value.dtype in (torch.int8, torch.uint8) and not is_4bit_serializable:
  102. raise ValueError(
  103. "Detected 4-bit weights but the version of bitsandbytes is not compatible with 4-bit serialization. "
  104. "Make sure to download the latest `bitsandbytes` version. `pip install --upgrade bitsandbytes`."
  105. )
  106. new_value = bnb.nn.Params4bit.from_prequantized(
  107. data=new_value,
  108. quantized_stats=quantized_stats,
  109. requires_grad=False,
  110. device=device,
  111. **kwargs,
  112. )
  113. else:
  114. new_value = bnb.nn.Params4bit(new_value, requires_grad=False, **kwargs).to(device)
  115. module._parameters[tensor_name] = new_value
  116. else:
  117. if value is None:
  118. new_value = old_value.to(device)
  119. elif isinstance(value, torch.Tensor):
  120. new_value = value.to(device)
  121. else:
  122. new_value = torch.tensor(value, device=device)
  123. if is_buffer:
  124. module._buffers[tensor_name] = new_value
  125. else:
  126. new_value = nn.Parameter(new_value, requires_grad=old_value.requires_grad)
  127. module._parameters[tensor_name] = new_value
  128. def _replace_with_bnb_linear(
  129. model,
  130. modules_to_not_convert=None,
  131. current_key_name=None,
  132. quantization_config=None,
  133. has_been_replaced=False,
  134. ):
  135. """
  136. Private method that wraps the recursion for module replacement.
  137. Returns the converted model and a boolean that indicates if the conversion has been successful or not.
  138. """
  139. for name, module in model.named_children():
  140. if current_key_name is None:
  141. current_key_name = []
  142. current_key_name.append(name)
  143. if (isinstance(module, (nn.Linear, Conv1D))) and name not in modules_to_not_convert:
  144. # Check if the current key is not in the `modules_to_not_convert`
  145. current_key_name_str = ".".join(current_key_name)
  146. if not any(
  147. (key + "." in current_key_name_str) or (key == current_key_name_str) for key in modules_to_not_convert
  148. ):
  149. with init_empty_weights():
  150. if isinstance(module, Conv1D):
  151. in_features, out_features = module.weight.shape
  152. else:
  153. in_features = module.in_features
  154. out_features = module.out_features
  155. if quantization_config.quantization_method() == "llm_int8":
  156. model._modules[name] = bnb.nn.Linear8bitLt(
  157. in_features,
  158. out_features,
  159. module.bias is not None,
  160. has_fp16_weights=quantization_config.llm_int8_has_fp16_weight,
  161. threshold=quantization_config.llm_int8_threshold,
  162. )
  163. has_been_replaced = True
  164. else:
  165. if (
  166. quantization_config.llm_int8_skip_modules is not None
  167. and name in quantization_config.llm_int8_skip_modules
  168. ):
  169. pass
  170. else:
  171. extra_kwargs = (
  172. {"quant_storage": quantization_config.bnb_4bit_quant_storage}
  173. if "quant_storage" in list(signature(bnb.nn.Linear4bit).parameters)
  174. else {}
  175. )
  176. model._modules[name] = bnb.nn.Linear4bit(
  177. in_features,
  178. out_features,
  179. module.bias is not None,
  180. quantization_config.bnb_4bit_compute_dtype,
  181. compress_statistics=quantization_config.bnb_4bit_use_double_quant,
  182. quant_type=quantization_config.bnb_4bit_quant_type,
  183. **extra_kwargs,
  184. )
  185. has_been_replaced = True
  186. # Store the module class in case we need to transpose the weight later
  187. model._modules[name].source_cls = type(module)
  188. # Force requires grad to False to avoid unexpected errors
  189. model._modules[name].requires_grad_(False)
  190. if len(list(module.children())) > 0:
  191. _, has_been_replaced = _replace_with_bnb_linear(
  192. module,
  193. modules_to_not_convert,
  194. current_key_name,
  195. quantization_config,
  196. has_been_replaced=has_been_replaced,
  197. )
  198. # Remove the last key for recursion
  199. current_key_name.pop(-1)
  200. return model, has_been_replaced
  201. def replace_with_bnb_linear(model, modules_to_not_convert=None, current_key_name=None, quantization_config=None):
  202. """
  203. A helper function to replace all `torch.nn.Linear` modules by `bnb.nn.Linear8bit` modules from the `bitsandbytes`
  204. library. This will enable running your models using mixed int8 precision as described by the paper `LLM.int8():
  205. 8-bit Matrix Multiplication for Transformers at Scale`. Make sure `bitsandbytes` compiled with the correct CUDA
  206. version of your hardware is installed before running this function. `pip install -i https://test.pypi.org/simple/
  207. bitsandbytes`
  208. The function will be run recursively and replace all `torch.nn.Linear` modules except for the `lm_head` that should
  209. be kept as a `torch.nn.Linear` module. The replacement is done under `init_empty_weights` context manager so no
  210. CPU/GPU memory is required to run this function. Int8 mixed-precision matrix decomposition works by separating a
  211. matrix multiplication into two streams: (1) and systematic feature outlier stream matrix multiplied in fp16
  212. (0.01%), (2) a regular stream of int8 matrix multiplication (99.9%). With this method, int8 inference with no
  213. predictive degradation is possible for very large models (>=176B parameters).
  214. Parameters:
  215. model (`torch.nn.Module`):
  216. Input model or `torch.nn.Module` as the function is run recursively.
  217. modules_to_not_convert (`list[`str`]`, *optional*, defaults to `["lm_head"]`):
  218. Names of the modules to not convert in `Linear8bitLt`. In practice we keep the `lm_head` in full precision
  219. for numerical stability reasons.
  220. current_key_name (`list[`str`]`, *optional*):
  221. An array to track the current key of the recursion. This is used to check whether the current key (part of
  222. it) is not in the list of modules to not convert (for instances modules that are offloaded to `cpu` or
  223. `disk`).
  224. quantization_config ('transformers.utils.quantization_config.BitsAndBytesConfig'):
  225. To configure and manage settings related to quantization, a technique used to compress neural network models
  226. by reducing the precision of the weights and activations, thus making models more efficient in terms of both
  227. storage and computation.
  228. """
  229. modules_to_not_convert = ["lm_head"] if modules_to_not_convert is None else modules_to_not_convert
  230. model, has_been_replaced = _replace_with_bnb_linear(
  231. model, modules_to_not_convert, current_key_name, quantization_config
  232. )
  233. if not has_been_replaced:
  234. logger.warning(
  235. "You are loading your model in 8bit or 4bit but no linear modules were found in your model."
  236. " Please double check your model architecture, or submit an issue on github if you think this is"
  237. " a bug."
  238. )
  239. return model
  240. # For backward compatibility
  241. def replace_8bit_linear(*args, **kwargs):
  242. warnings.warn(
  243. "`replace_8bit_linear` will be deprecated in a future version, please use `replace_with_bnb_linear` instead",
  244. FutureWarning,
  245. )
  246. return replace_with_bnb_linear(*args, **kwargs)
  247. # For backward compatibility
  248. def set_module_8bit_tensor_to_device(*args, **kwargs):
  249. warnings.warn(
  250. "`set_module_8bit_tensor_to_device` will be deprecated in a future version, please use `set_module_quantized_tensor_to_device` instead",
  251. FutureWarning,
  252. )
  253. return set_module_quantized_tensor_to_device(*args, **kwargs)
  254. def get_keys_to_not_convert(model):
  255. r"""
  256. An utility function to get the key of the module to keep in full precision if any For example for CausalLM modules
  257. we may want to keep the lm_head in full precision for numerical stability reasons. For other architectures, we want
  258. to keep the tied weights of the model. The function will return a list of the keys of the modules to not convert in
  259. int8.
  260. Parameters:
  261. model (`torch.nn.Module`):
  262. Input model
  263. """
  264. # Create a copy of the model and tie the weights, then
  265. # check if it contains tied weights
  266. tied_model = deepcopy(model) # this has 0 cost since it is done inside `init_empty_weights` context manager`
  267. tied_model.tie_weights()
  268. tied_params = find_tied_parameters(tied_model)
  269. # For compatibility with Accelerate < 0.18
  270. if isinstance(tied_params, dict):
  271. tied_keys = sum(list(tied_params.values()), []) + list(tied_params.keys())
  272. else:
  273. tied_keys = sum(tied_params, [])
  274. has_tied_params = len(tied_keys) > 0
  275. # If there is not tied weights, we want to keep the lm_head(output_embedding) in full precision
  276. if not has_tied_params:
  277. output_emb = model.get_output_embeddings()
  278. if output_emb is not None:
  279. list_last_module = [name for name, module in model.named_modules() if id(module) == id(output_emb)]
  280. return list_last_module
  281. # otherwise, no tied weights, no output embedding defined, simply keep the last module in full precision
  282. list_modules = list(model.named_parameters())
  283. list_last_module = [list_modules[-1][0]]
  284. # add last module together with tied weights
  285. intersection = set(list_last_module) - set(tied_keys)
  286. list_untouched = list(set(tied_keys)) + list(intersection)
  287. # remove ".weight" from the keys
  288. names_to_remove = [".weight", ".bias"]
  289. filtered_module_names = []
  290. for name in list_untouched:
  291. for name_to_remove in names_to_remove:
  292. if name_to_remove in name:
  293. name = name.replace(name_to_remove, "")
  294. filtered_module_names.append(name)
  295. return filtered_module_names
  296. # Copied from PEFT: https://github.com/huggingface/peft/blob/47b3712898539569c02ec5b3ed4a6c36811331a1/src/peft/utils/integrations.py#L41
  297. def dequantize_bnb_weight(weight: "torch.nn.Parameter", dtype: "torch.dtype", state=None):
  298. """
  299. Helper function to dequantize 4bit or 8bit bnb weights.
  300. If the weight is not a bnb quantized weight, it will be returned as is.
  301. """
  302. if not isinstance(weight, torch.nn.Parameter):
  303. raise TypeError(f"Input weight should be of type nn.Parameter, got {type(weight)} instead")
  304. cls_name = weight.__class__.__name__
  305. if cls_name not in ("Params4bit", "Int8Params"):
  306. return weight
  307. if cls_name == "Params4bit":
  308. output_tensor = bnb.functional.dequantize_4bit(weight.data, weight.quant_state)
  309. logger.warning_once(
  310. f"The model is going to be dequantized in {output_tensor.dtype} - if you want to upcast it to another dtype, make sure to pass the desired dtype when quantizing the model through `bnb_4bit_quant_type` argument of `BitsAndBytesConfig`"
  311. )
  312. return output_tensor.to(dtype)
  313. if state.SCB is None:
  314. state.SCB = weight.SCB
  315. if hasattr(bnb.functional, "int8_vectorwise_dequant"):
  316. # Use bitsandbytes API if available (requires v0.45.0+)
  317. dequantized = bnb.functional.int8_vectorwise_dequant(weight.data, state.SCB)
  318. else:
  319. # Multiply by (scale/127) to dequantize.
  320. dequantized = weight.data * state.SCB.view(-1, 1) * 7.874015718698502e-3
  321. return dequantized.to(dtype)
  322. def _create_accelerate_new_hook(old_hook):
  323. r"""
  324. Creates a new hook based on the old hook. Use it only if you know what you are doing !
  325. This method is a copy of: https://github.com/huggingface/peft/blob/748f7968f3a31ec06a1c2b0328993319ad9a150a/src/peft/utils/other.py#L245
  326. with some changes
  327. """
  328. old_hook_cls = getattr(accelerate.hooks, old_hook.__class__.__name__)
  329. old_hook_attr = old_hook.__dict__
  330. filtered_old_hook_attr = {}
  331. old_hook_init_signature = inspect.signature(old_hook_cls.__init__)
  332. for k in old_hook_attr:
  333. if k in old_hook_init_signature.parameters:
  334. filtered_old_hook_attr[k] = old_hook_attr[k]
  335. new_hook = old_hook_cls(**filtered_old_hook_attr)
  336. return new_hook
  337. def _dequantize_and_replace(
  338. model,
  339. dtype,
  340. modules_to_not_convert=None,
  341. current_key_name=None,
  342. quantization_config=None,
  343. has_been_replaced=False,
  344. ):
  345. """
  346. Converts a quantized model into its dequantized original version. The newly converted model will have
  347. some performance drop compared to the original model before quantization - use it only for specific usecases
  348. such as QLoRA adapters merging.
  349. Returns the converted model and a boolean that indicates if the conversion has been successful or not.
  350. """
  351. quant_method = quantization_config.quantization_method()
  352. target_cls = bnb.nn.Linear8bitLt if quant_method == "llm_int8" else bnb.nn.Linear4bit
  353. for name, module in model.named_children():
  354. if current_key_name is None:
  355. current_key_name = []
  356. current_key_name.append(name)
  357. if isinstance(module, target_cls) and name not in modules_to_not_convert:
  358. # Check if the current key is not in the `modules_to_not_convert`
  359. current_key_name_str = ".".join(current_key_name)
  360. if not any(
  361. (key + "." in current_key_name_str) or (key == current_key_name_str) for key in modules_to_not_convert
  362. ):
  363. bias = getattr(module, "bias", None)
  364. device = module.weight.device
  365. with init_empty_weights():
  366. new_module = torch.nn.Linear(module.in_features, module.out_features, bias=bias is not None)
  367. if quant_method == "llm_int8":
  368. state = module.state
  369. else:
  370. state = None
  371. new_module.weight = torch.nn.Parameter(dequantize_bnb_weight(module.weight, dtype, state))
  372. if bias is not None:
  373. new_module.bias = bias
  374. # Create a new hook and attach it in case we use accelerate
  375. if hasattr(module, "_hf_hook"):
  376. old_hook = module._hf_hook
  377. new_hook = _create_accelerate_new_hook(old_hook)
  378. remove_hook_from_module(module)
  379. add_hook_to_module(new_module, new_hook)
  380. new_module.to(device)
  381. model._modules[name] = new_module
  382. has_been_replaced = True
  383. if len(list(module.children())) > 0:
  384. _, has_been_replaced = _dequantize_and_replace(
  385. module,
  386. dtype,
  387. modules_to_not_convert,
  388. current_key_name,
  389. quantization_config,
  390. has_been_replaced=has_been_replaced,
  391. )
  392. # Remove the last key for recursion
  393. current_key_name.pop(-1)
  394. return model, has_been_replaced
  395. def dequantize_and_replace(
  396. model,
  397. modules_to_not_convert=None,
  398. quantization_config=None,
  399. ):
  400. model, has_been_replaced = _dequantize_and_replace(
  401. model,
  402. model.dtype,
  403. modules_to_not_convert=modules_to_not_convert,
  404. quantization_config=quantization_config,
  405. )
  406. if not has_been_replaced:
  407. logger.warning(
  408. "For some reason the model has not been properly dequantized. You might see unexpected behavior."
  409. )
  410. return model
  411. def _validate_bnb_multi_backend_availability(raise_exception):
  412. import bitsandbytes as bnb
  413. bnb_supported_devices = getattr(bnb, "supported_torch_devices", set())
  414. available_devices = set(get_available_devices())
  415. if not available_devices.intersection(bnb_supported_devices):
  416. if raise_exception:
  417. err_msg = (
  418. f"None of the available devices `available_devices = {available_devices or None}` are supported by the bitsandbytes version you have installed: `bnb_supported_devices = {bnb_supported_devices}`. "
  419. "Please check the docs to see if the backend you intend to use is available and how to install it: https://huggingface.co/docs/bitsandbytes/main/en/installation"
  420. )
  421. logger.error(err_msg)
  422. raise RuntimeError(err_msg)
  423. logger.warning("No supported devices found for bitsandbytes multi-backend.")
  424. return False
  425. logger.debug("Multi-backend validation successful.")
  426. return True
  427. def _validate_bnb_cuda_backend_availability(raise_exception):
  428. if not is_torch_available():
  429. return False
  430. import torch
  431. if not torch.cuda.is_available():
  432. log_msg = (
  433. "CUDA is required but not available for bitsandbytes. Please consider installing the multi-platform enabled version of bitsandbytes, which is currently a work in progress. "
  434. "Please check currently supported platforms and installation instructions at https://huggingface.co/docs/bitsandbytes/main/en/installation#multi-backend"
  435. )
  436. if raise_exception:
  437. logger.error(log_msg)
  438. raise RuntimeError(log_msg)
  439. logger.warning(log_msg)
  440. return False
  441. logger.debug("CUDA backend validation successful.")
  442. return True
  443. def validate_bnb_backend_availability(raise_exception=False):
  444. """
  445. Validates if the available devices are supported by bitsandbytes, optionally raising an exception if not.
  446. """
  447. if not is_bitsandbytes_available():
  448. if importlib.util.find_spec("bitsandbytes") and version.parse(
  449. importlib.metadata.version("bitsandbytes")
  450. ) < version.parse("0.43.1"):
  451. return _validate_bnb_cuda_backend_availability(raise_exception)
  452. return False
  453. if is_bitsandbytes_multi_backend_available():
  454. return _validate_bnb_multi_backend_availability(raise_exception)
  455. return _validate_bnb_cuda_backend_availability(raise_exception)