bnb.py 20 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469
  1. # Copyright 2023 The HuggingFace Team. All rights reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. import logging
  15. import os
  16. from copy import deepcopy
  17. from typing import Optional, Union
  18. import torch
  19. import torch.nn as nn
  20. from accelerate.utils.imports import (
  21. is_4bit_bnb_available,
  22. is_8bit_bnb_available,
  23. )
  24. from ..big_modeling import dispatch_model, init_empty_weights
  25. from .dataclasses import BnbQuantizationConfig
  26. from .modeling import (
  27. find_tied_parameters,
  28. get_balanced_memory,
  29. infer_auto_device_map,
  30. load_checkpoint_in_model,
  31. offload_weight,
  32. set_module_tensor_to_device,
  33. )
  34. logger = logging.getLogger(__name__)
  35. def load_and_quantize_model(
  36. model: torch.nn.Module,
  37. bnb_quantization_config: BnbQuantizationConfig,
  38. weights_location: Optional[Union[str, os.PathLike]] = None,
  39. device_map: Optional[dict[str, Union[int, str, torch.device]]] = None,
  40. no_split_module_classes: Optional[list[str]] = None,
  41. max_memory: Optional[dict[Union[int, str], Union[int, str]]] = None,
  42. offload_folder: Optional[Union[str, os.PathLike]] = None,
  43. offload_state_dict: bool = False,
  44. ):
  45. """
  46. This function will quantize the input model with the associated config passed in `bnb_quantization_config`. If the
  47. model is in the meta device, we will load and dispatch the weights according to the `device_map` passed. If the
  48. model is already loaded, we will quantize the model and put the model on the GPU,
  49. Args:
  50. model (`torch.nn.Module`):
  51. Input model. The model can be already loaded or on the meta device
  52. bnb_quantization_config (`BnbQuantizationConfig`):
  53. The bitsandbytes quantization parameters
  54. weights_location (`str` or `os.PathLike`):
  55. The folder weights_location to load. It can be:
  56. - a path to a file containing a whole model state dict
  57. - a path to a `.json` file containing the index to a sharded checkpoint
  58. - a path to a folder containing a unique `.index.json` file and the shards of a checkpoint.
  59. - a path to a folder containing a unique pytorch_model.bin file.
  60. device_map (`Dict[str, Union[int, str, torch.device]]`, *optional*):
  61. A map that specifies where each submodule should go. It doesn't need to be refined to each parameter/buffer
  62. name, once a given module name is inside, every submodule of it will be sent to the same device.
  63. no_split_module_classes (`List[str]`, *optional*):
  64. A list of layer class names that should never be split across device (for instance any layer that has a
  65. residual connection).
  66. max_memory (`Dict`, *optional*):
  67. A dictionary device identifier to maximum memory. Will default to the maximum memory available if unset.
  68. offload_folder (`str` or `os.PathLike`, *optional*):
  69. If the `device_map` contains any value `"disk"`, the folder where we will offload weights.
  70. offload_state_dict (`bool`, *optional*, defaults to `False`):
  71. If `True`, will temporarily offload the CPU state dict on the hard drive to avoid getting out of CPU RAM if
  72. the weight of the CPU state dict + the biggest shard does not fit.
  73. Returns:
  74. `torch.nn.Module`: The quantized model
  75. """
  76. load_in_4bit = bnb_quantization_config.load_in_4bit
  77. load_in_8bit = bnb_quantization_config.load_in_8bit
  78. if load_in_8bit and not is_8bit_bnb_available():
  79. raise ImportError(
  80. "You have a version of `bitsandbytes` that is not compatible with 8bit quantization,"
  81. " make sure you have the latest version of `bitsandbytes` installed."
  82. )
  83. if load_in_4bit and not is_4bit_bnb_available():
  84. raise ValueError(
  85. "You have a version of `bitsandbytes` that is not compatible with 4bit quantization,"
  86. "make sure you have the latest version of `bitsandbytes` installed."
  87. )
  88. modules_on_cpu = []
  89. # custom device map
  90. if isinstance(device_map, dict) and len(device_map.keys()) > 1:
  91. modules_on_cpu = [key for key, value in device_map.items() if value in ["disk", "cpu"]]
  92. # We keep some modules such as the lm_head in their original dtype for numerical stability reasons
  93. if bnb_quantization_config.skip_modules is None:
  94. bnb_quantization_config.skip_modules = get_keys_to_not_convert(model)
  95. # add cpu modules to skip modules only for 4-bit modules
  96. if load_in_4bit:
  97. bnb_quantization_config.skip_modules.extend(modules_on_cpu)
  98. modules_to_not_convert = bnb_quantization_config.skip_modules
  99. # We add the modules we want to keep in full precision
  100. if bnb_quantization_config.keep_in_fp32_modules is None:
  101. bnb_quantization_config.keep_in_fp32_modules = []
  102. keep_in_fp32_modules = bnb_quantization_config.keep_in_fp32_modules
  103. modules_to_not_convert.extend(keep_in_fp32_modules)
  104. # compatibility with peft
  105. model.is_loaded_in_4bit = load_in_4bit
  106. model.is_loaded_in_8bit = load_in_8bit
  107. model_device = get_parameter_device(model)
  108. if model_device.type != "meta":
  109. # quantization of an already loaded model
  110. logger.warning(
  111. "It is not recommended to quantize a loaded model. "
  112. "The model should be instantiated under the `init_empty_weights` context manager."
  113. )
  114. model = replace_with_bnb_layers(model, bnb_quantization_config, modules_to_not_convert=modules_to_not_convert)
  115. # convert param to the right dtype
  116. dtype = bnb_quantization_config.torch_dtype
  117. for name, param in model.state_dict().items():
  118. if any(module_to_keep_in_fp32 in name for module_to_keep_in_fp32 in keep_in_fp32_modules):
  119. param.to(torch.float32)
  120. if param.dtype != torch.float32:
  121. name = name.replace(".weight", "").replace(".bias", "")
  122. param = getattr(model, name, None)
  123. if param is not None:
  124. param.to(torch.float32)
  125. elif torch.is_floating_point(param):
  126. param.to(dtype)
  127. if model_device.type == "cuda":
  128. model.cuda(torch.cuda.current_device())
  129. torch.cuda.empty_cache()
  130. elif torch.cuda.is_available():
  131. model.to(torch.cuda.current_device())
  132. elif torch.xpu.is_available():
  133. model.to(torch.xpu.current_device())
  134. else:
  135. raise RuntimeError("No GPU or Intel XPU found. A GPU or Intel XPU is needed for quantization.")
  136. logger.info(
  137. f"The model device type is {model_device.type}. However, gpu or intel xpu is needed for quantization."
  138. "We move the model to it."
  139. )
  140. return model
  141. elif weights_location is None:
  142. raise RuntimeError(
  143. f"`weights_location` needs to be the folder path containing the weights of the model, but we found {weights_location} "
  144. )
  145. else:
  146. with init_empty_weights():
  147. model = replace_with_bnb_layers(
  148. model, bnb_quantization_config, modules_to_not_convert=modules_to_not_convert
  149. )
  150. device_map = get_quantized_model_device_map(
  151. model,
  152. bnb_quantization_config,
  153. device_map,
  154. max_memory=max_memory,
  155. no_split_module_classes=no_split_module_classes,
  156. )
  157. if offload_state_dict is None and device_map is not None and "disk" in device_map.values():
  158. offload_state_dict = True
  159. offload = any(x in list(device_map.values()) for x in ["cpu", "disk"])
  160. load_checkpoint_in_model(
  161. model,
  162. weights_location,
  163. device_map,
  164. dtype=bnb_quantization_config.torch_dtype,
  165. offload_folder=offload_folder,
  166. offload_state_dict=offload_state_dict,
  167. keep_in_fp32_modules=bnb_quantization_config.keep_in_fp32_modules,
  168. offload_8bit_bnb=load_in_8bit and offload,
  169. )
  170. return dispatch_model(model, device_map=device_map, offload_dir=offload_folder)
  171. def get_quantized_model_device_map(
  172. model, bnb_quantization_config, device_map=None, max_memory=None, no_split_module_classes=None
  173. ):
  174. if device_map is None:
  175. if torch.cuda.is_available():
  176. device_map = {"": torch.cuda.current_device()}
  177. elif torch.xpu.is_available():
  178. device_map = {"": torch.xpu.current_device()}
  179. else:
  180. raise RuntimeError("No GPU found. A GPU is needed for quantization.")
  181. logger.info("The device_map was not initialized.Setting device_map to `{'':torch.cuda.current_device()}`.")
  182. if isinstance(device_map, str):
  183. if device_map not in ["auto", "balanced", "balanced_low_0", "sequential"]:
  184. raise ValueError(
  185. "If passing a string for `device_map`, please choose 'auto', 'balanced', 'balanced_low_0' or "
  186. "'sequential'."
  187. )
  188. special_dtypes = {}
  189. special_dtypes.update(
  190. {
  191. name: bnb_quantization_config.torch_dtype
  192. for name, _ in model.named_parameters()
  193. if any(m in name for m in bnb_quantization_config.skip_modules)
  194. }
  195. )
  196. special_dtypes.update(
  197. {
  198. name: torch.float32
  199. for name, _ in model.named_parameters()
  200. if any(m in name for m in bnb_quantization_config.keep_in_fp32_modules)
  201. }
  202. )
  203. kwargs = {}
  204. kwargs["special_dtypes"] = special_dtypes
  205. kwargs["no_split_module_classes"] = no_split_module_classes
  206. kwargs["dtype"] = bnb_quantization_config.target_dtype
  207. # get max_memory for each device.
  208. if device_map != "sequential":
  209. max_memory = get_balanced_memory(
  210. model,
  211. low_zero=(device_map == "balanced_low_0"),
  212. max_memory=max_memory,
  213. **kwargs,
  214. )
  215. kwargs["max_memory"] = max_memory
  216. device_map = infer_auto_device_map(model, **kwargs)
  217. if isinstance(device_map, dict):
  218. # check if don't have any quantized module on the cpu
  219. modules_not_to_convert = bnb_quantization_config.skip_modules + bnb_quantization_config.keep_in_fp32_modules
  220. device_map_without_some_modules = {
  221. key: device_map[key] for key in device_map.keys() if key not in modules_not_to_convert
  222. }
  223. for device in ["cpu", "disk"]:
  224. if device in device_map_without_some_modules.values():
  225. if bnb_quantization_config.load_in_4bit:
  226. raise ValueError(
  227. """
  228. Some modules are dispatched on the CPU or the disk. Make sure you have enough GPU RAM to fit
  229. the quantized model. If you want to dispatch the model on the CPU or the disk while keeping
  230. these modules in `torch_dtype`, you need to pass a custom `device_map` to
  231. `load_and_quantize_model`. Check
  232. https://huggingface.co/docs/accelerate/main/en/usage_guides/quantization#offload-modules-to-cpu-and-disk
  233. for more details.
  234. """
  235. )
  236. else:
  237. logger.info(
  238. "Some modules are are offloaded to the CPU or the disk. Note that these modules will be converted to 8-bit"
  239. )
  240. del device_map_without_some_modules
  241. return device_map
  242. def replace_with_bnb_layers(model, bnb_quantization_config, modules_to_not_convert=None, current_key_name=None):
  243. """
  244. A helper function to replace all `torch.nn.Linear` modules by `bnb.nn.Linear8bit` modules or by `bnb.nn.Linear4bit`
  245. modules from the `bitsandbytes`library. The function will be run recursively and replace `torch.nn.Linear` modules.
  246. Parameters:
  247. model (`torch.nn.Module`):
  248. Input model or `torch.nn.Module` as the function is run recursively.
  249. modules_to_not_convert (`List[str]`):
  250. Names of the modules to not quantize convert. In practice we keep the `lm_head` in full precision for
  251. numerical stability reasons.
  252. current_key_name (`List[str]`, *optional*):
  253. An array to track the current key of the recursion. This is used to check whether the current key (part of
  254. it) is not in the list of modules to not convert.
  255. """
  256. if modules_to_not_convert is None:
  257. modules_to_not_convert = []
  258. model, has_been_replaced = _replace_with_bnb_layers(
  259. model, bnb_quantization_config, modules_to_not_convert, current_key_name
  260. )
  261. if not has_been_replaced:
  262. logger.warning(
  263. "You are loading your model in 8bit or 4bit but no linear modules were found in your model."
  264. " this can happen for some architectures such as gpt2 that uses Conv1D instead of Linear layers."
  265. " Please double check your model architecture, or submit an issue on github if you think this is"
  266. " a bug."
  267. )
  268. return model
  269. def _replace_with_bnb_layers(
  270. model,
  271. bnb_quantization_config,
  272. modules_to_not_convert=None,
  273. current_key_name=None,
  274. ):
  275. """
  276. Private method that wraps the recursion for module replacement.
  277. Returns the converted model and a boolean that indicates if the conversion has been successful or not.
  278. """
  279. # bitsandbytes will initialize CUDA on import, so it needs to be imported lazily
  280. import bitsandbytes as bnb
  281. has_been_replaced = False
  282. for name, module in model.named_children():
  283. if current_key_name is None:
  284. current_key_name = []
  285. current_key_name.append(name)
  286. if isinstance(module, nn.Linear) and name not in modules_to_not_convert:
  287. # Check if the current key is not in the `modules_to_not_convert`
  288. current_key_name_str = ".".join(current_key_name)
  289. proceed = True
  290. for key in modules_to_not_convert:
  291. if (
  292. (key in current_key_name_str) and (key + "." in current_key_name_str)
  293. ) or key == current_key_name_str:
  294. proceed = False
  295. break
  296. if proceed:
  297. # Load bnb module with empty weight and replace ``nn.Linear` module
  298. if bnb_quantization_config.load_in_8bit:
  299. bnb_module = bnb.nn.Linear8bitLt(
  300. module.in_features,
  301. module.out_features,
  302. module.bias is not None,
  303. has_fp16_weights=False,
  304. threshold=bnb_quantization_config.llm_int8_threshold,
  305. )
  306. elif bnb_quantization_config.load_in_4bit:
  307. bnb_module = bnb.nn.Linear4bit(
  308. module.in_features,
  309. module.out_features,
  310. module.bias is not None,
  311. bnb_quantization_config.bnb_4bit_compute_dtype,
  312. compress_statistics=bnb_quantization_config.bnb_4bit_use_double_quant,
  313. quant_type=bnb_quantization_config.bnb_4bit_quant_type,
  314. )
  315. else:
  316. raise ValueError("load_in_8bit and load_in_4bit can't be both False")
  317. bnb_module.weight.data = module.weight.data
  318. if module.bias is not None:
  319. bnb_module.bias.data = module.bias.data
  320. bnb_module.requires_grad_(False)
  321. setattr(model, name, bnb_module)
  322. has_been_replaced = True
  323. if len(list(module.children())) > 0:
  324. _, _has_been_replaced = _replace_with_bnb_layers(
  325. module, bnb_quantization_config, modules_to_not_convert, current_key_name
  326. )
  327. has_been_replaced = has_been_replaced | _has_been_replaced
  328. # Remove the last key for recursion
  329. current_key_name.pop(-1)
  330. return model, has_been_replaced
  331. def get_keys_to_not_convert(model):
  332. r"""
  333. An utility function to get the key of the module to keep in full precision if any For example for CausalLM modules
  334. we may want to keep the lm_head in full precision for numerical stability reasons. For other architectures, we want
  335. to keep the tied weights of the model. The function will return a list of the keys of the modules to not convert in
  336. int8.
  337. Parameters:
  338. model (`torch.nn.Module`):
  339. Input model
  340. """
  341. # Create a copy of the model
  342. with init_empty_weights():
  343. tied_model = deepcopy(model) # this has 0 cost since it is done inside `init_empty_weights` context manager`
  344. tied_params = find_tied_parameters(tied_model)
  345. # For compatibility with Accelerate < 0.18
  346. if isinstance(tied_params, dict):
  347. tied_keys = sum(list(tied_params.values()), []) + list(tied_params.keys())
  348. else:
  349. tied_keys = sum(tied_params, [])
  350. has_tied_params = len(tied_keys) > 0
  351. # Check if it is a base model
  352. is_base_model = False
  353. if hasattr(model, "base_model_prefix"):
  354. is_base_model = not hasattr(model, model.base_model_prefix)
  355. # Ignore this for base models (BertModel, GPT2Model, etc.)
  356. if (not has_tied_params) and is_base_model:
  357. return []
  358. # otherwise they have an attached head
  359. list_modules = list(model.named_children())
  360. list_last_module = [list_modules[-1][0]]
  361. # add last module together with tied weights
  362. intersection = set(list_last_module) - set(tied_keys)
  363. list_untouched = list(set(tied_keys)) + list(intersection)
  364. # remove ".weight" from the keys
  365. names_to_remove = [".weight", ".bias"]
  366. filtered_module_names = []
  367. for name in list_untouched:
  368. for name_to_remove in names_to_remove:
  369. if name_to_remove in name:
  370. name = name.replace(name_to_remove, "")
  371. filtered_module_names.append(name)
  372. return filtered_module_names
  373. def has_4bit_bnb_layers(model):
  374. """Check if we have `bnb.nn.Linear4bit` or `bnb.nn.Linear8bitLt` layers inside our model"""
  375. # bitsandbytes will initialize CUDA on import, so it needs to be imported lazily
  376. import bitsandbytes as bnb
  377. for m in model.modules():
  378. if isinstance(m, bnb.nn.Linear4bit):
  379. return True
  380. return False
  381. def get_parameter_device(parameter: nn.Module):
  382. return next(parameter.parameters()).device
  383. def quantize_and_offload_8bit(model, param, param_name, new_dtype, offload_folder, offload_index, fp16_statistics):
  384. # if it is not quantized, we quantize and offload the quantized weights and the SCB stats
  385. if fp16_statistics is None:
  386. set_module_tensor_to_device(model, param_name, 0, dtype=new_dtype, value=param)
  387. tensor_name = param_name
  388. module = model
  389. if "." in tensor_name:
  390. splits = tensor_name.split(".")
  391. for split in splits[:-1]:
  392. new_module = getattr(module, split)
  393. if new_module is None:
  394. raise ValueError(f"{module} has no attribute {split}.")
  395. module = new_module
  396. tensor_name = splits[-1]
  397. # offload weights
  398. module._parameters[tensor_name].requires_grad = False
  399. offload_weight(module._parameters[tensor_name], param_name, offload_folder, index=offload_index)
  400. if hasattr(module._parameters[tensor_name], "SCB"):
  401. offload_weight(
  402. module._parameters[tensor_name].SCB,
  403. param_name.replace("weight", "SCB"),
  404. offload_folder,
  405. index=offload_index,
  406. )
  407. else:
  408. offload_weight(param, param_name, offload_folder, index=offload_index)
  409. offload_weight(fp16_statistics, param_name.replace("weight", "SCB"), offload_folder, index=offload_index)
  410. set_module_tensor_to_device(model, param_name, "meta", dtype=new_dtype, value=torch.empty(*param.size()))