# bitnet.py (15 KB)
  1. from ..utils import is_accelerate_available, is_torch_available, logging
  2. if is_accelerate_available():
  3. from accelerate import init_empty_weights
  4. if is_torch_available():
  5. import torch
  6. import torch.nn as nn
  7. import torch.nn.functional as F
logger = logging.get_logger(__name__)

# The weights are ternary, so each value can be represented with 2 bits. They are packed in uint8 tensors,
# hence each packed item (one uint8) holds 8 / 2 = 4 values.
VALUES_PER_ITEM = 4
  11. def pack_weights(quantized_weights: torch.Tensor) -> torch.Tensor:
  12. """
  13. Packs a tensor of quantized weights into a compact format using 2 bits per value.
  14. Parameters:
  15. -----------
  16. quantized_weights : torch.Tensor
  17. A tensor containing ternary quantized weights with values in {-1, 0, 1}. These values are adjusted to
  18. {0, 1, 2} before being packed.
  19. Returns:
  20. --------
  21. torch.Tensor
  22. A packed tensor where each element stores 4 quantized values (each using 2 bits) in an 8-bit format.
  23. """
  24. original_shape = quantized_weights.shape
  25. row_dim = (original_shape[0] + VALUES_PER_ITEM - 1) // VALUES_PER_ITEM
  26. if len(original_shape) == 1:
  27. packed_tensor_shape = (row_dim,)
  28. else:
  29. packed_tensor_shape = (row_dim, *original_shape[1:])
  30. quantized_weights += 1
  31. packed = torch.zeros(packed_tensor_shape, device=quantized_weights.device, dtype=torch.uint8)
  32. unpacked = quantized_weights.to(torch.uint8)
  33. it = min(VALUES_PER_ITEM, (original_shape[0] // row_dim) + 1)
  34. for i in range(it):
  35. start = i * row_dim
  36. end = min(start + row_dim, original_shape[0])
  37. packed[: (end - start)] |= unpacked[start:end] << 2 * i
  38. return packed
  39. @torch.compile
  40. def unpack_weights(packed: torch.Tensor, dtype: torch.dtype) -> torch.Tensor:
  41. """
  42. Unpacks a tensor of quantized weights that were stored in a packed format using 2 bits per value.
  43. Parameters:
  44. -----------
  45. packed : torch.Tensor
  46. A tensor containing packed weights where each element represents 4 quantized values (using 2 bits per value).
  47. dtype : torch.dtype
  48. The dtype of the returned Tensor
  49. Returns:
  50. --------
  51. torch.Tensor
  52. A tensor of unpacked weights, where each value is converted from its packed 2-bit representation.
  53. Example:
  54. --------
  55. packed = torch.tensor([[0b10100001, 0b00011000],
  56. [0b10010000, 0b00001010]], dtype=torch.uint8)
  57. # Unpack the values
  58. unpacked = unpack_weights(packed)
  59. # Resulting unpacked tensor
  60. print(unpacked)
  61. # Output: tensor([[ 0, -1],
  62. [-1, 1],
  63. [-1, 1],
  64. [-1, 1],
  65. [ 1, 0],
  66. [ 0, -1],
  67. [ 1, -1],
  68. [ 1, -1]])
  69. Explanation of the example:
  70. ---------------------------
  71. Let's take the first value for example 0b10100001, we we will only focus on the first column,
  72. because every element is unpacked across the first dimension
  73. - First 2 bits: `01` → 0 at [0][0]
  74. - Second 2 bits: `00` → -1 at [0][2]
  75. - Third 2 bits: `10` → 1 at [0][4]
  76. - Fourth 2 bits: `10` → 1 at [0][6]
  77. the second value of the same row (0b10010000) will give the values for [0][1], [0][3], [0][5], [0][7]
  78. We subtract 1 because during the packing process, it's easier to work with values like 0, 1, and 2. To make this possible,
  79. we add 1 to the original ternary weights (which are typically -1, 0, and 1) when packing them. When unpacking, we reverse
  80. this by subtracting 1 to restore the original ternary values.
  81. """
  82. packed_shape = packed.shape
  83. if len(packed_shape) == 1:
  84. original_row_dim = packed_shape[0] * VALUES_PER_ITEM
  85. unpacked_shape = (original_row_dim,)
  86. else:
  87. original_row_dim = packed_shape[0] * VALUES_PER_ITEM
  88. unpacked_shape = (original_row_dim, *packed_shape[1:])
  89. unpacked = torch.zeros(unpacked_shape, device=packed.device, dtype=torch.uint8)
  90. for i in range(VALUES_PER_ITEM):
  91. start = i * packed_shape[0]
  92. end = start + packed_shape[0]
  93. mask = 3 << (2 * i)
  94. unpacked[start:end] = (packed & mask) >> (2 * i)
  95. return unpacked.to(dtype) - 1
  96. class BitLinear(nn.Module):
  97. def __init__(
  98. self,
  99. in_features: int,
  100. out_features: int,
  101. bias: bool,
  102. device=None,
  103. dtype=None,
  104. use_rms_norm: bool = False,
  105. rms_norm_eps: float = 1e-6,
  106. ):
  107. super().__init__()
  108. self.dtype = dtype
  109. self.in_features = in_features
  110. self.out_features = out_features
  111. self.register_buffer(
  112. "weight",
  113. torch.zeros(
  114. (out_features // VALUES_PER_ITEM, in_features),
  115. dtype=torch.uint8,
  116. device=device,
  117. ),
  118. )
  119. self.register_buffer(
  120. "weight_scale",
  121. torch.ones(
  122. (1),
  123. dtype=dtype,
  124. device=device,
  125. ),
  126. )
  127. if bias:
  128. self.register_buffer("bias", torch.zeros((out_features), dtype=dtype, device=device))
  129. else:
  130. self.bias = None
  131. # Optional RMSNorm (applied on the activations before quantization).
  132. self.rms_norm = None
  133. if use_rms_norm:
  134. from ..models.llama.modeling_llama import LlamaRMSNorm
  135. self.rms_norm = LlamaRMSNorm(in_features, eps=rms_norm_eps)
  136. @torch.compile
  137. def activation_quant(self, input, num_bits=8):
  138. """
  139. Activation function : Performs symmetric, per-token quantization on the input activations.
  140. Parameters:
  141. -----------
  142. x : torch.Tensor
  143. Input activations to be quantized.
  144. num_bits : int, optional (default=8)
  145. Number of bits to use for quantization, determining the quantization range.
  146. Returns:
  147. --------
  148. result : torch.Tensor
  149. Quantized activation tensor, with values mapped to an `int8` range.
  150. scale : torch.Tensor
  151. The per-channel scaling factors used to quantize the tensor.
  152. """
  153. Qn = -(2 ** (num_bits - 1))
  154. Qp = 2 ** (num_bits - 1) - 1
  155. scale = Qp / input.abs().max(dim=-1, keepdim=True).values.clamp(min=1e-5)
  156. result = (input * scale).round().clamp(Qn, Qp)
  157. return result.to(torch.int8), scale
  158. @torch.compile
  159. def post_quant_process(self, input, input_scale, weight_scale):
  160. out = input / (input_scale * weight_scale)
  161. return out
  162. def forward(self, input):
  163. # Apply RMSNorm on the input if requested.
  164. if self.rms_norm is not None:
  165. input = self.rms_norm(input)
  166. w = self.weight
  167. w_quant = unpack_weights(w, dtype=self.dtype)
  168. input_quant, input_scale = self.activation_quant(input)
  169. y = F.linear(input_quant.to(self.dtype), w_quant)
  170. y = self.post_quant_process(y, self.weight_scale, input_scale)
  171. if self.bias is not None:
  172. y += self.bias.view(1, -1).expand_as(y)
  173. return y
  174. class WeightQuant(torch.autograd.Function):
  175. """
  176. Implements a custom autograd function for weight quantization.
  177. This performs ternary quantization (-1, 0, 1) based on scaling by the
  178. mean absolute value of the weights. It uses the Straight-Through Estimator
  179. (STE) for the backward pass.
  180. """
  181. @staticmethod
  182. @torch.compile
  183. def forward(ctx, weight):
  184. dtype = weight.dtype
  185. weight = weight.float()
  186. scale = 1.0 / weight.abs().mean().clamp_(min=1e-5)
  187. weight = (weight * scale).round().clamp(-1, 1) / scale
  188. return weight.to(dtype)
  189. @staticmethod
  190. def backward(ctx, grad_output):
  191. grad_input = grad_output.clone()
  192. return grad_input
  193. class ActQuant(torch.autograd.Function):
  194. """
  195. Implements a custom autograd function for activation quantization.
  196. This performs symmetric 8-bit quantization (to the range [-128, 127])
  197. based on the maximum absolute value along the last dimension (per-token/row scaling).
  198. It uses the Straight-Through Estimator (STE) for the backward pass.
  199. """
  200. @staticmethod
  201. @torch.compile
  202. def forward(ctx, activation):
  203. dtype = activation.dtype
  204. activation = activation.float()
  205. scale = 127 / activation.abs().max(dim=-1, keepdim=True).values.clamp_(min=1e-5)
  206. activation = (activation * scale).round().clamp(-128, 127) / scale
  207. return activation.to(dtype)
  208. @staticmethod
  209. def backward(ctx, grad_output):
  210. grad_input = grad_output.clone()
  211. return grad_input
  212. class AutoBitLinear(nn.Linear):
  213. def __init__(
  214. self,
  215. in_features: int,
  216. out_features: int,
  217. bias: bool = True,
  218. device=None,
  219. dtype=None,
  220. online_quant: bool = False,
  221. use_rms_norm: bool = False,
  222. rms_norm_eps: float = 1e-6,
  223. ):
  224. super().__init__(in_features, out_features, bias)
  225. self.online_quant = online_quant
  226. # Optional RMSNorm
  227. self.rms_norm = None
  228. if use_rms_norm:
  229. from ..models.llama.modeling_llama import LlamaRMSNorm
  230. self.rms_norm = LlamaRMSNorm(in_features, eps=rms_norm_eps)
  231. if not online_quant:
  232. self.register_buffer(
  233. "weight_scale",
  234. torch.ones(
  235. (1),
  236. dtype=dtype,
  237. device=device,
  238. ),
  239. )
  240. self._register_load_state_dict_pre_hook(self.load_hook)
  241. def load_hook(
  242. self,
  243. state_dict,
  244. prefix,
  245. *args,
  246. **kwargs,
  247. ):
  248. if (prefix + "weight") in state_dict and state_dict[prefix + "weight"].dtype != self.weight.dtype:
  249. state_dict[prefix + "weight"] = unpack_weights(state_dict[prefix + "weight"], dtype=self.weight.dtype)
  250. return state_dict
  251. def forward(self, input):
  252. # Optional RMSNorm on activations prior to quantization.
  253. if self.rms_norm is not None:
  254. input = self.rms_norm(input)
  255. if self.online_quant:
  256. weight = WeightQuant.apply(self.weight)
  257. else:
  258. weight = self.weight
  259. input = ActQuant.apply(input)
  260. output = F.linear(input, weight, self.bias)
  261. if not self.online_quant:
  262. output = output * self.weight_scale
  263. return output
  264. def _replace_with_bitnet_linear(
  265. model,
  266. modules_to_not_convert=None,
  267. current_key_name=None,
  268. quantization_config=None,
  269. has_been_replaced=False,
  270. pre_quantized=False,
  271. ):
  272. """
  273. Private method that wraps the recursion for module replacement.
  274. Returns the converted model and a boolean that indicates if the conversion has been successful or not.
  275. """
  276. if current_key_name is None:
  277. current_key_name = []
  278. for name, module in model.named_children():
  279. if current_key_name is None:
  280. current_key_name = []
  281. current_key_name.append(name)
  282. # Check if the current key is not in the `modules_to_not_convert`
  283. if not any(key in ".".join(current_key_name) for key in modules_to_not_convert):
  284. with init_empty_weights():
  285. if isinstance(module, nn.Linear) and name not in modules_to_not_convert:
  286. in_features = module.in_features
  287. out_features = module.out_features
  288. if quantization_config and quantization_config.linear_class == "autobitlinear":
  289. model._modules[name] = AutoBitLinear(
  290. in_features=in_features,
  291. out_features=out_features,
  292. bias=module.bias is not None,
  293. device=module.weight.device,
  294. dtype=module.weight.dtype,
  295. online_quant=(quantization_config.quantization_mode == "online"),
  296. use_rms_norm=quantization_config.use_rms_norm,
  297. rms_norm_eps=quantization_config.rms_norm_eps,
  298. )
  299. if quantization_config.quantization_mode == "offline":
  300. model._modules[name].requires_grad_(False)
  301. else:
  302. model._modules[name] = BitLinear(
  303. in_features=in_features,
  304. out_features=out_features,
  305. bias=module.bias is not None,
  306. device=module.weight.device,
  307. dtype=module.weight.dtype,
  308. use_rms_norm=quantization_config.use_rms_norm if quantization_config else False,
  309. rms_norm_eps=quantization_config.rms_norm_eps if quantization_config else 1e-6,
  310. )
  311. model._modules[name].requires_grad_(False)
  312. has_been_replaced = True
  313. if len(list(module.children())) > 0:
  314. _, has_been_replaced = _replace_with_bitnet_linear(
  315. module,
  316. modules_to_not_convert=modules_to_not_convert,
  317. current_key_name=current_key_name,
  318. quantization_config=quantization_config,
  319. has_been_replaced=has_been_replaced,
  320. )
  321. # Remove the last key for recursion
  322. current_key_name.pop(-1)
  323. return model, has_been_replaced
  324. def replace_with_bitnet_linear(
  325. model,
  326. modules_to_not_convert=None,
  327. current_key_name=None,
  328. quantization_config=None,
  329. pre_quantized=False,
  330. ):
  331. """
  332. A helper function to replace all `torch.nn.Linear` modules by `BitLinear158` modules`.
  333. The function will be run recursively and replace all `torch.nn.Linear` modules except for the `lm_head` that should
  334. be kept as a `torch.nn.Linear` module. The replacement is done under `init_empty_weights` context manager so no
  335. CPU/GPU memory is required to run this function. Each weight will be quantized along the channel.
  336. Parameters:
  337. model (`torch.nn.Module`):
  338. Input model or `torch.nn.Module` as the function is run recursively.
  339. modules_to_not_convert (`list[`str`]`, *optional*, defaults to `["lm_head"]`):
  340. Names of the modules to not convert in `BitLinear`. In practice we keep the `lm_head` in full precision
  341. for numerical stability reasons.
  342. current_key_name (`list[`str`]`, *optional*):
  343. An array to track the current key of the recursion. This is used to check whether the current key (part of
  344. it) is not in the list of modules to not convert (for instances modules that are offloaded to `cpu` or
  345. `disk`).
  346. """
  347. modules_to_not_convert = ["lm_head"] if modules_to_not_convert is None else modules_to_not_convert
  348. if quantization_config and quantization_config.modules_to_not_convert is not None:
  349. modules_to_not_convert.extend(quantization_config.modules_to_not_convert)
  350. modules_to_not_convert = list(set(modules_to_not_convert))
  351. model, has_been_replaced = _replace_with_bitnet_linear(
  352. model,
  353. modules_to_not_convert,
  354. current_key_name,
  355. quantization_config,
  356. pre_quantized=pre_quantized,
  357. )
  358. if not has_been_replaced:
  359. logger.warning(
  360. "You are loading your model using bitnet but no linear modules were found in your model."
  361. " Please double check your model architecture, or submit an issue on github if you think this is"
  362. " a bug."
  363. )
  364. return model