weight_only.py 35 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693694695696697698699700701702703704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862863864865866867868869870871872873874875876877878879880881882883884885886887888889890891892893894895896897898899900901902903904905906907908909910911912913914915916917918919920921922923924925926927928929930931932
  1. #
  2. # The implementation of this file is based on:
  3. # https://github.com/intel/neural-compressor/tree/master/neural_compressor
  4. #
  5. # Copyright (c) 2023 Intel Corporation
  6. #
  7. # Licensed under the Apache License, Version 2.0 (the "License");
  8. # you may not use this file except in compliance with the License.
  9. # You may obtain a copy of the License at
  10. #
  11. # http://www.apache.org/licenses/LICENSE-2.0
  12. #
  13. # Unless required by applicable law or agreed to in writing, software
  14. # distributed under the License is distributed on an "AS IS" BASIS,
  15. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  16. # See the License for the specific language governing permissions and
  17. # limitations under the License.
  18. #
  19. # Modifications:
  20. # Add k-quant quantization method.
  21. # Copyright (c) Microsoft Corporation. All rights reserved.
  22. # Licensed under the MIT License.
  23. """WeightOnly for onnxrt adaptor."""
  24. import copy
  25. import logging
  26. import os
  27. import sys
  28. import numpy as np
  29. import onnx
  30. from onnx import numpy_helper
  31. from onnx.helper import np_dtype_to_tensor_dtype
  32. import onnxruntime as ort
  33. from .onnx_model import ONNXModel
  34. from .util import simple_progress_bar
  35. logger = logging.getLogger("neural_compressor")
  36. def make_matmul_weight_only_node(
  37. node,
  38. weight_shape,
  39. num_bits,
  40. group_size,
  41. k_blocks,
  42. q_weight,
  43. scale,
  44. zero_point,
  45. accuracy_level=0,
  46. ): # pragma: no cover
  47. """Build MatMulNBits node.
  48. Args:
  49. node: original matmul node
  50. weight_shape: original weight shape
  51. num_bits (int): num_bits
  52. group_size (int): how many elements share one scale/zp
  53. k_blocks (int): block number
  54. q_weight (array): quantized weight
  55. scale (array): scale
  56. zero_point (array): zero point
  57. accuracy_level (int): accuracy level. Support 0 (unset), 1(fp32), 2(fp16), 3(bf16), or 4(int8).
  58. Returns:
  59. matmul_weight_only_node: MatMulNBits node
  60. new_inits: initializers of the new node
  61. """
  62. blob_size = group_size * num_bits // 8
  63. packed = np.zeros((q_weight.shape[0], blob_size), dtype="uint8")
  64. q_weight_name = node.input[1] + f"_Q{num_bits!s}G{group_size!s}"
  65. input_names = [node.input[0], q_weight_name]
  66. new_inits = []
  67. kwargs = {}
  68. op_type = "MatMulNBits"
  69. # pack quantized weight
  70. if num_bits == 4:
  71. q_weight_pairs = q_weight[:, ::2] | q_weight[:, 1::2] << 4
  72. packed[:, :] = q_weight_pairs[:, :blob_size]
  73. elif num_bits == 8:
  74. packed = q_weight
  75. else:
  76. logger.error(f"MatMulNBits does not have kernel support for num_bits = {num_bits}.")
  77. packed = np.reshape(packed, (-1, k_blocks, blob_size))
  78. # build scale tensor
  79. scale = np.reshape(scale, (-1, k_blocks))
  80. assert scale.dtype == np.float32 or scale.dtype == np.float16
  81. scale_tensor = onnx.helper.make_tensor(
  82. name=node.input[1] + "_scale",
  83. data_type=np_dtype_to_tensor_dtype(scale.dtype),
  84. dims=scale.shape,
  85. vals=scale.tobytes(),
  86. raw=True,
  87. )
  88. input_names.append(scale_tensor.name)
  89. new_inits.append(scale_tensor)
  90. # build zero_point tensor
  91. if zero_point is not None:
  92. if num_bits == 8:
  93. packed_zp = zero_point.astype("uint8")
  94. elif num_bits == 4:
  95. # For 4-bit case, the default zeros is 0x8. So it is 0x88 = 136 if we fill lower/higher 4 bits with 0x8.
  96. packed_zp = np.full((zero_point.shape[0] + 1) // 2, 136, dtype="uint8")
  97. # create an index array
  98. idx = np.arange(zero_point.shape[0] // k_blocks * k_blocks).reshape(-1)
  99. # separate odd and even indices
  100. even_idx = idx[::2]
  101. odd_idx = idx[1::2]
  102. # vectorized operation for even and odd indices
  103. packed_zp[even_idx // 2] = (packed_zp[even_idx // 2] & 0xF0) | zero_point[even_idx].ravel()
  104. packed_zp[odd_idx // 2] = (packed_zp[odd_idx // 2] & 0x0F) | (zero_point[odd_idx].ravel() << 4)
  105. else:
  106. raise ValueError(f"MatMulNBits does not have kernel support for num_bits = {num_bits}.")
  107. packed_zp = np.reshape(packed_zp, (weight_shape[1], -1))
  108. zp_tensor = onnx.helper.make_tensor(
  109. name=node.input[1] + "_zp", data_type=2, dims=packed_zp.shape, vals=packed_zp.tobytes(), raw=True
  110. )
  111. input_names.append(zp_tensor.name)
  112. new_inits.append(zp_tensor)
  113. # set kwargs
  114. kwargs["K"] = weight_shape[0]
  115. kwargs["N"] = weight_shape[1]
  116. kwargs["bits"] = num_bits
  117. kwargs["block_size"] = group_size
  118. if accuracy_level > 0:
  119. # require onnxruntime > 1.16.3
  120. kwargs["accuracy_level"] = accuracy_level
  121. q_weight_tensor = onnx.helper.make_tensor(
  122. name=q_weight_name,
  123. data_type=2,
  124. dims=packed.shape,
  125. vals=packed.tobytes(),
  126. raw=True,
  127. )
  128. new_inits.append(q_weight_tensor)
  129. matmul_weight_only_node = onnx.helper.make_node(
  130. op_type,
  131. inputs=input_names,
  132. outputs=node.output,
  133. name=node.name + "_Q" + str(num_bits) if node.name else "_Q" + str(num_bits),
  134. domain="com.microsoft",
  135. **kwargs,
  136. )
  137. return matmul_weight_only_node, new_inits
  138. def quant_tensor(data, num_bits=4, group_size=32, scheme="asym", dtype="int", ratio=1.0):
  139. """Quantize tensor per group.
  140. Args:
  141. data : input weight
  142. num_bits (int, optional): num_bits. Defaults to 4.
  143. group_size (int, optional): how many elements share one scale/zp. Defaults to 4.
  144. scheme (str, optional): quantization scheme. Defaults to "asym".
  145. dtype (str, optional): data type. Defaults to "int".
  146. ratio (float, optional): percentile of clip. Defaults to 1.0.
  147. Returns:
  148. output: quantized weight
  149. scale: scale
  150. zero_point: zero point
  151. """
  152. data = np.reshape(data, (-1, group_size))
  153. if scheme == "asym" or dtype == "uint":
  154. maxq = 2**num_bits - 1
  155. minq = 0
  156. elif scheme == "sym":
  157. maxq = 2 ** (num_bits - 1) - 1 if num_bits != 1 else 0
  158. minq = -(2 ** (num_bits - 1)) if num_bits != 1 else -1
  159. rmin = np.min(data, axis=1, keepdims=True) * ratio
  160. rmax = np.max(data, axis=1, keepdims=True) * ratio
  161. if scheme == "sym":
  162. max_range = np.maximum(np.abs(rmin), np.abs(rmax))
  163. scale = np.ones(rmax.shape)
  164. mask = max_range > 0
  165. scale[mask] = (max_range[mask] * 2.0).astype(np.float64) / (maxq - minq)
  166. zero_point = (
  167. np.zeros(scale.shape) if dtype == "int" else np.ones(rmax.shape, dtype="uint8") * (1 << (num_bits - 1))
  168. )
  169. else:
  170. scale = np.ones(rmax.shape)
  171. scale[rmin != rmax] = np.array(
  172. [float(i) / (maxq - minq) for i in (rmax - rmin)[rmin != rmax].flatten().tolist()]
  173. )
  174. zero_point = (
  175. ((np.zeros(scale.shape) - rmin) / scale).round()
  176. if dtype == "int"
  177. else np.maximum(0, np.minimum(maxq, ((np.zeros(scale.shape) - rmin) / scale).round())).astype("uint8")
  178. )
  179. q_weight = np.empty_like(data, dtype=scale.dtype)
  180. np.divide(data, scale, out=q_weight)
  181. np.add(q_weight, zero_point, out=q_weight)
  182. np.round(q_weight, out=q_weight)
  183. np.clip(q_weight, minq, maxq, out=q_weight)
  184. return q_weight, scale, zero_point
  185. def quant_tensor_k_quant_cpu(data, num_bits=4, group_size=32):
  186. """Quantize tensor per group based on k quant.
  187. Ref: https://github.com/ggml-org/llama.cpp/blob/64eda5deb9859e87a020e56bab5d2f9ca956f1de/ggml/src/ggml-quants.c
  188. Args:
  189. data : input weight
  190. num_bits (int, optional): num_bits. Defaults to 4.
  191. group_size (int, optional): how many elements share one scale/zp. Defaults to 32.
  192. Returns:
  193. output: quantized weight
  194. scale: scale
  195. zero_point: zero point
  196. """
  197. data = np.reshape(data, (-1, group_size)).astype(np.float32) # nb = data.shape[0], (nb, group_size)
  198. maxq = 2**num_bits - 1
  199. minq = 0
  200. sum_x2 = np.sum(data**2, axis=1, keepdims=True) # (nb, 1)
  201. av_x = np.sqrt(sum_x2 / group_size) # (nb, 1)
  202. weights = np.add(av_x, np.abs(data)) # (nb, group_size)
  203. rmin = np.min(data, axis=1, keepdims=True) # (nb, 1)
  204. rmax = np.max(data, axis=1, keepdims=True) # (nb, 1)
  205. sum_w = np.sum(weights, axis=1, keepdims=True) # (nb, 1)
  206. sum_x = np.sum(weights * data, axis=1, keepdims=True) # (nb, group_size)
  207. iscale = np.ones(rmax.shape, dtype=data.dtype) # (nb, 1)
  208. mask = rmin != rmax
  209. iscale[mask] = (maxq - minq) / (rmax[mask] - rmin[mask])
  210. scale = 1 / iscale
  211. quant_data = np.clip(np.round(iscale * (data - rmin)), minq, maxq) # (nb, group_size)
  212. diff = scale * quant_data + rmin - data # (nb, group_size)
  213. best_mad = np.sum(weights * diff**2, axis=1, keepdims=True) # (nb, 1)
  214. nstep = 20
  215. rdelta = 0.1
  216. # nstep * rdelta = -2 * rrmin, maxq - minq = 2**num_bits - 1
  217. rrmin = -1
  218. for is_ in range(nstep):
  219. iscale_new = np.ones(rmax.shape, dtype=data.dtype) # (nb, 1)
  220. factor = np.array([rrmin + rdelta * is_ + maxq - minq]).astype(data.dtype)[0]
  221. mask = rmin != rmax
  222. iscale_new[mask] = factor / (rmax[mask] - rmin[mask])
  223. quant_data_new = np.clip(np.round(iscale_new * (data - rmin)), minq, maxq) # (nb, group_size)
  224. mul_weights_quant_data_new = weights * quant_data_new
  225. sum_l = np.sum(mul_weights_quant_data_new, axis=1, keepdims=True) # (nb, 1)
  226. sum_l2 = np.sum(mul_weights_quant_data_new * quant_data_new, axis=1, keepdims=True) # (nb, 1)
  227. sum_xl = np.sum(mul_weights_quant_data_new * data, axis=1, keepdims=True) # (nb, 1)
  228. D = np.subtract(sum_w * sum_l2, sum_l**2) # noqa: N806
  229. this_scale = (sum_w * sum_xl - sum_x * sum_l) / D # (nb, 1)
  230. this_min = (sum_l2 * sum_x - sum_l * sum_xl) / D # (nb, 1)
  231. diff = this_scale * quant_data_new + this_min - data # (nb, group_size)
  232. mad = np.sum(weights * diff**2, axis=1, keepdims=True) # (nb, 1)
  233. mad_1 = np.array(mad)
  234. best_mad_1 = np.array(best_mad)
  235. idx_to_replace = np.where(mad_1 < best_mad_1)[0]
  236. quant_data[idx_to_replace, :] = quant_data_new[idx_to_replace, :]
  237. best_mad[idx_to_replace] = mad[idx_to_replace]
  238. scale[idx_to_replace] = this_scale[idx_to_replace]
  239. rmin[idx_to_replace] = this_min[idx_to_replace]
  240. zero_point = np.clip(((-rmin) / scale).round(), 0, maxq).astype("uint8")
  241. scale = scale.astype(np.float64)
  242. q_weight = np.empty_like(data, dtype=scale.dtype)
  243. np.divide(data, scale, out=q_weight)
  244. np.add(q_weight, zero_point, out=q_weight)
  245. np.round(q_weight, out=q_weight)
  246. np.clip(q_weight, minq, maxq, out=q_weight)
  247. return q_weight, scale, zero_point
  248. def quant_tensor_k_quant_cuda(data, num_bits=4, group_size=32):
  249. """Quantize tensor per group based on k quant.
  250. Ref: https://github.com/ggml-org/llama.cpp/blob/64eda5deb9859e87a020e56bab5d2f9ca956f1de/ggml/src/ggml-quants.c
  251. Args:
  252. data : input weight
  253. num_bits (int, optional): num_bits. Defaults to 4.
  254. group_size (int, optional): how many elements share one scale/zp. Defaults to 4.
  255. Returns:
  256. output: quantized weight
  257. scale: scale
  258. zero_point: zero point
  259. """
  260. try:
  261. import cupy as cp # noqa: PLC0415
  262. import torch # noqa: PLC0415
  263. if torch.cuda.is_available():
  264. data = cp.asarray(data)
  265. data = data.reshape((-1, group_size)).astype(cp.float32) # nb = data.shape[0], (nb, group_size)
  266. maxq = 2**num_bits - 1
  267. minq = 0
  268. sum_x2 = cp.sum(data**2, axis=1, keepdims=True) # (nb, 1)
  269. av_x = cp.sqrt(sum_x2 / group_size) # (nb, 1)
  270. weights = cp.add(av_x, cp.abs(data)) # (nb, group_size)
  271. rmin = cp.min(data, axis=1, keepdims=True) # (nb, 1)
  272. rmax = cp.max(data, axis=1, keepdims=True) # (nb, 1)
  273. sum_w = cp.sum(weights, axis=1, keepdims=True) # (nb, 1)
  274. sum_x = cp.sum(weights * data, axis=1, keepdims=True) # (nb, group_size)
  275. iscale = cp.ones(rmax.shape, dtype=data.dtype) # (nb, 1)
  276. mask = rmin != rmax
  277. iscale[mask] = (maxq - minq) / (rmax[mask] - rmin[mask])
  278. scale = 1 / iscale
  279. quant_data = cp.clip(cp.round(iscale * (data - rmin)), minq, maxq) # (nb, group_size)
  280. diff = scale * quant_data + rmin - data # (nb, group_size)
  281. best_mad = cp.sum(weights * diff**2, axis=1, keepdims=True) # (nb, 1)
  282. nstep = 20
  283. rdelta = 0.1
  284. rrmin = -1
  285. for is_ in range(nstep):
  286. iscale_new = cp.ones(rmax.shape, dtype=data.dtype) # (nb, 1)
  287. factor = cp.array([rrmin + rdelta * is_ + maxq - minq]).astype(data.dtype)[0]
  288. mask = rmin != rmax
  289. iscale_new[mask] = factor / (rmax[mask] - rmin[mask])
  290. quant_data_new = cp.clip(cp.round(iscale_new * (data - rmin)), minq, maxq) # (nb, group_size)
  291. mul_weights_quant_data_new = weights * quant_data_new
  292. sum_l = cp.sum(mul_weights_quant_data_new, axis=1, keepdims=True) # (nb, 1)
  293. sum_l2 = cp.sum(mul_weights_quant_data_new * quant_data_new, axis=1, keepdims=True) # (nb, 1)
  294. sum_xl = cp.sum(mul_weights_quant_data_new * data, axis=1, keepdims=True) # (nb, 1)
  295. D = cp.subtract(sum_w * sum_l2, sum_l**2) # noqa: N806
  296. this_scale = (sum_w * sum_xl - sum_x * sum_l) / D # (nb, 1)
  297. this_min = (sum_l2 * sum_x - sum_l * sum_xl) / D # (nb, 1)
  298. diff = this_scale * quant_data_new + this_min - data # (nb, group_size)
  299. mad = cp.sum(weights * diff**2, axis=1, keepdims=True) # (nb, 1)
  300. mad_1 = cp.array(mad)
  301. best_mad_1 = cp.array(best_mad)
  302. idx_to_replace = cp.where(mad_1 < best_mad_1)[0]
  303. quant_data[idx_to_replace, :] = quant_data_new[idx_to_replace, :]
  304. best_mad[idx_to_replace] = mad[idx_to_replace]
  305. scale[idx_to_replace] = this_scale[idx_to_replace]
  306. rmin[idx_to_replace] = this_min[idx_to_replace]
  307. zero_point = cp.clip(((-rmin) / scale).round(), 0, maxq).astype("uint8")
  308. scale = scale.astype(cp.float64)
  309. q_weight = cp.empty_like(data, dtype=scale.dtype)
  310. cp.divide(data, scale, out=q_weight)
  311. cp.add(q_weight, zero_point, out=q_weight)
  312. cp.round(q_weight, out=q_weight)
  313. cp.clip(q_weight, minq, maxq, out=q_weight)
  314. return q_weight.get(), scale.get(), zero_point.get()
  315. else:
  316. logger.warning(
  317. "Try to use k-quant quantization on CUDA. However, CUDA is not available."
  318. "Fall back to k-quant quantization on CPU."
  319. )
  320. return quant_tensor_k_quant_cpu(data, num_bits, group_size)
  321. except ImportError:
  322. logger.info(
  323. "Now we are using k-quant quantization on cpu, which is time consuming."
  324. "Please consider install cupy to speed up on CUDA. See https://cupy.dev/"
  325. "Please also install torch to check CUDA availability."
  326. )
  327. return quant_tensor_k_quant_cpu(data, num_bits, group_size)
  328. def qdq_tensor(data, num_bits=4, group_size=32, scheme="asym", dtype="int", ratio=1.0):
  329. """Quant dequant tensor per group.
  330. Args:
  331. data : input weight
  332. num_bits (int, optional): num_bits. Defaults to 4.
  333. group_size (int, optional): how many elements share one scale/zp. Defaults to 4.
  334. scheme (str, optional): quantization scheme. Defaults to "asym".
  335. dtype (str, optional): data type. Defaults to "int".
  336. ratio (float, optional): percentile of clip. Defaults to 1.0.
  337. Returns:
  338. output: quant-dequant weight
  339. """
  340. org_shape = data.shape
  341. weight, scale, zp = quant_tensor(data, num_bits, group_size, scheme, dtype, ratio)
  342. return np.reshape(scale * (weight - zp), org_shape)
  343. def pad_tensor(weight, group_size, k_blocks):
  344. """Pad tensor rowi so that it can be is divisible by group_size.
  345. Args:
  346. weight (array): weight
  347. group_size (int): how many elements share one scale/zp
  348. k_blocks (int): the number of block
  349. Returns:
  350. weight: paded weight
  351. """
  352. if group_size == -1:
  353. return weight
  354. org_w_shape = weight.shape
  355. padded_rows = k_blocks * group_size
  356. pad_len = padded_rows - org_w_shape[0]
  357. if pad_len > 0:
  358. weight = np.pad(weight, ((0, pad_len), (0, 0)), "constant")
  359. return weight
  360. def rtn_quantize(
  361. model,
  362. weight_config={}, # noqa: B006
  363. num_bits=4,
  364. group_size=32,
  365. scheme="asym",
  366. ratios={}, # noqa: B006
  367. accuracy_level=0,
  368. providers=["CPUExecutionProvider"], # noqa: B006
  369. algorithm="k_quant",
  370. ):
  371. """Quant the model with round to nearst method.
  372. Args:
  373. model (ModelProto or ONNXModel): onnx model
  374. weight_config (dict): quantization config
  375. For example,
  376. weight_config = {
  377. 'fc2':
  378. {
  379. 'bits': 4,
  380. 'group_size': 32,
  381. 'scheme': 'sym',
  382. 'algorithm': 'RTN'
  383. }
  384. }
  385. num_bits (int, optional): num_bits. Default is 4.
  386. group_size (int, optional): how many elements share one scale/zp. Default is 32.
  387. scheme (str, optional): sym or asym. Defaults to "asym".
  388. ratios (dict, optional): percentile of clip. Defaults to {}.
  389. accuracy_level (int): accuracy level. Support 0 (unset),1(fp32), 2(fp16), 3(bf16), or 4(int8).
  390. providers (list): providers to use
  391. Returns:
  392. model: fake quantized ONNXModel
  393. """
  394. model = ONNXModel(model)
  395. base_dir = os.path.dirname(model.model_path) if model.model_path is not None else ""
  396. new_nodes = []
  397. remove_nodes = []
  398. total_num = len([i for i in model.nodes() if i.op_type in ["MatMul"]])
  399. curr_id = 0
  400. for node in model.nodes():
  401. if node.op_type in ["MatMul"]:
  402. curr_id += 1
  403. simple_progress_bar(total_num, curr_id)
  404. if (
  405. node.op_type in ["MatMul"]
  406. and model.get_initializer(node.input[1]) is not None
  407. and weight_config.get(node.name, {}) != "fp32"
  408. ):
  409. weight_tensor = model.get_initializer(node.input[1])
  410. weight = numpy_helper.to_array(weight_tensor, base_dir=base_dir).copy()
  411. if len(weight.shape) != 2:
  412. continue
  413. dtype = weight.dtype
  414. if node.name in weight_config:
  415. num_bits = weight_config[node.name]["bits"]
  416. group_size = weight_config[node.name]["group_size"]
  417. scheme = weight_config[node.name]["scheme"]
  418. org_w_shape = weight.shape # ic, oc
  419. group_size = group_size if group_size != -1 else org_w_shape[0]
  420. k_blocks = (org_w_shape[0] - 1) // group_size + 1
  421. init_share_num = model.get_initializer_share_num(node.input[1])
  422. weight = pad_tensor(weight, group_size, k_blocks)
  423. satisfy_MatMulNBits_condition = num_bits == 4 or num_bits == 8 # noqa: N806
  424. if satisfy_MatMulNBits_condition: # pragma: no cover
  425. if algorithm == "k_quant":
  426. q_weight, scale, zp = quant_tensor_k_quant_cuda(weight.T, num_bits, group_size)
  427. else:
  428. q_weight, scale, zp = quant_tensor(
  429. weight.T, num_bits, group_size, scheme, "uint", ratios.get(node.input[1], 1)
  430. )
  431. q_matmul_node, new_inits = make_matmul_weight_only_node(
  432. node=node,
  433. weight_shape=org_w_shape,
  434. num_bits=num_bits,
  435. group_size=group_size,
  436. k_blocks=k_blocks,
  437. q_weight=q_weight.astype("uint8"),
  438. scale=scale.astype(dtype),
  439. zero_point=zp if scheme == "asym" or algorithm == "k_quant" else None,
  440. accuracy_level=accuracy_level,
  441. )
  442. model.add_initializers(new_inits)
  443. remove_nodes.append(node)
  444. new_nodes.append(q_matmul_node)
  445. else:
  446. q_weight = qdq_tensor(weight.T, num_bits, group_size, scheme, "int", ratios.get(node.input[1], 1))
  447. q_weight = np.reshape(q_weight, (org_w_shape[1], -1))
  448. q_weight = np.transpose(q_weight)
  449. q_weight = q_weight[: org_w_shape[0], :].astype(dtype)
  450. q_weight_tensor = onnx.helper.make_tensor(
  451. name=node.input[1] + f"_Q{num_bits!s}G{group_size!s}",
  452. data_type=np_dtype_to_tensor_dtype(dtype),
  453. dims=weight.shape,
  454. vals=q_weight.tobytes(),
  455. raw=True,
  456. )
  457. model.add_initializer(q_weight_tensor)
  458. node.input[1] = q_weight_tensor.name
  459. if init_share_num == 1:
  460. model.remove_initializer(weight_tensor)
  461. model.add_nodes(new_nodes)
  462. model.remove_nodes(remove_nodes)
  463. model.topological_sort()
  464. return model
  465. def get_weight_scale(weight, group_size):
  466. """Get the scale of weight."""
  467. org_shape = weight.shape
  468. weight = np.reshape(weight, (-1, group_size)) if group_size != -1 else weight
  469. scale = np.mean(np.reshape(np.abs(weight) / np.max(np.abs(weight), axis=1, keepdims=True), org_shape), axis=0)
  470. return scale
  471. def prepare_inputs(model, n_samples, dataloader, providers):
  472. """Prepare inputs for weight only quantization.
  473. Args:
  474. model (ModelProto or ONNXModel): onnx model
  475. n_samples (int, optional): calibration sample number. -1 means all samples.
  476. dataloader (object): dataloader for calibration.
  477. providers (list): providers to use
  478. Returns:
  479. inputs: prepared inputs.
  480. so: session options
  481. """
  482. from importlib.util import find_spec # noqa: PLC0415
  483. from .util import to_numpy # noqa: PLC0415
  484. so = ort.SessionOptions()
  485. if sys.version_info < (3, 11) and find_spec("onnxruntime_extensions"): # pragma: no cover
  486. from onnxruntime_extensions import get_library_path # noqa: PLC0415
  487. so.register_custom_ops_library(get_library_path())
  488. if model.is_large_model:
  489. onnx.save_model(
  490. model.model,
  491. model.model_path + "_augment.onnx",
  492. save_as_external_data=True,
  493. all_tensors_to_one_file=True,
  494. convert_attribute=False,
  495. )
  496. session = (
  497. ort.InferenceSession(model.model.SerializeToString(), so, providers=providers)
  498. if not model.is_large_model
  499. else ort.InferenceSession(model.model_path + "_augment.onnx", so, providers=providers)
  500. )
  501. inputs_names = [i.name for i in session.get_inputs()]
  502. del session
  503. inputs = []
  504. for i, data in enumerate(dataloader):
  505. if n_samples != -1 and ((i + 1) * dataloader.batch_size) > n_samples:
  506. break
  507. if len(inputs_names) != 1 or isinstance(data[0], dict):
  508. assert len(data[0]) == len(inputs_names), (
  509. f"Input number mismatch, require {len(inputs_names)} but get {len(data[0])}"
  510. )
  511. if isinstance(data[0], dict):
  512. inputs.append(dict([(name, to_numpy(inp_data)) for name, inp_data in data[0].items()])) # noqa: C404
  513. elif isinstance(data[0], np.ndarray): # pragma: no cover
  514. inputs.append(dict([(name, inp) for name, inp in zip(inputs_names, [data[0]], strict=False)])) # noqa: C404
  515. else: # pragma: no cover
  516. inputs.append(dict([(name, to_numpy(inp)) for name, inp in zip(inputs_names, data[0], strict=False)])) # noqa: C404
  517. return inputs, so
  518. def gptq(
  519. W,
  520. H,
  521. num_bits=4,
  522. group_size=32,
  523. scheme="asym",
  524. blocksize=128,
  525. percdamp=0.01,
  526. actorder=False,
  527. mse=False,
  528. perchannel=True,
  529. ):
  530. """Quant the weight with GPTQ method.
  531. Args:
  532. W (array): weight.
  533. H (array): Hessian matrix.
  534. num_bits (int, optional): num_bits. Default is 4.
  535. group_size (int, optional): how many elements share one scale/zp. Default is 32.
  536. scheme (str, optional): sym or asym. Defaults to "asym".
  537. blocksize (int, optional): blocksize to quantize weight.
  538. percdamp (float, optional): percent of the average Hessian diagonal to use for dampening.
  539. actorder (bool, optional): whether rearrange Hessian matrix considering the diag's value.
  540. mse (bool, optional): whether get scale and zero point with mse error.
  541. perchannel (bool, optional): whether quantize weight per-channel.
  542. Returns:
  543. Q: fake quantized weight
  544. """
  545. maxq = 2**num_bits - 1
  546. grid = 100
  547. maxshrink = 0.8
  548. norm = 2.4
  549. def find_params(weight):
  550. org_shape = weight.shape
  551. # find zp, scale
  552. if not perchannel:
  553. weight = np.expand_dims(weight.flatten(), axis=1)
  554. tmp = np.zeros(weight.shape[1])
  555. xmin = np.minimum(np.min(weight, axis=0), tmp)
  556. xmax = np.maximum(np.max(weight, axis=0), tmp)
  557. if scheme == "sym":
  558. xmax = np.maximum(np.abs(xmin), xmax)
  559. tmp = xmin < 0
  560. if np.any(tmp):
  561. xmin[tmp] = -xmax[tmp]
  562. tmp = (xmin == 0) & (xmax == 0)
  563. xmin[tmp] = -1
  564. xmax[tmp] = +1
  565. scale = (xmax - xmin) / maxq
  566. if scheme == "sym":
  567. zero = np.ones(scale.shape) * (maxq + 1) / 2
  568. else:
  569. zero = np.round(-xmin / scale)
  570. if mse:
  571. best = np.ones([weight.shape[1]]) * float("inf")
  572. for i in range(int(maxshrink * grid)):
  573. p = 1 - i / grid
  574. xmin1 = p * xmin
  575. xmax1 = p * xmax
  576. scale1 = (xmax1 - xmin1) / maxq
  577. zero1 = np.round(-xmin1 / scale1) if scheme != "sym" else zero
  578. q = np.clip(np.round(weight / scale1) + zero1, 0, maxq)
  579. q -= weight
  580. q = np.power(np.abs(q), norm)
  581. err = np.sum(q, 0)
  582. tmp = err < best
  583. if np.any(tmp):
  584. best[tmp] = err[tmp]
  585. scale[tmp] = scale1[tmp]
  586. zero[tmp] = zero1[tmp]
  587. if not perchannel:
  588. tmp = org_shape[1]
  589. scale = np.repeat(scale, tmp)
  590. zero = np.repeat(zero, tmp)
  591. shape = [-1] + [1] * (len(org_shape) - 1)
  592. scale = np.reshape(scale, shape)
  593. zero = np.reshape(zero, shape)
  594. return scale, zero
  595. shape = W.shape
  596. scale, zp = find_params(W)
  597. dead = np.diag(H) == 0
  598. H[dead, dead] = 1
  599. W[dead, :] = 0 # such channel makes no contribution to quantization computation
  600. # rearrange considering the diag's value
  601. if actorder:
  602. perm = np.argsort(np.diag(H))[::-1]
  603. W = W[perm, :] # noqa: N806
  604. H = H[perm, :][:, perm] # noqa: N806
  605. Losses = np.zeros_like(W) # noqa: N806
  606. Q = np.zeros_like(W) # noqa: N806
  607. damp = percdamp * np.mean(np.diag(H))
  608. diag = np.arange(shape[0])
  609. H[diag, diag] += damp # add a average value of
  610. H = np.linalg.cholesky(np.linalg.inv(H)).T # noqa: N806
  611. Hinv = H # noqa: N806
  612. for i1 in range(0, shape[0], blocksize):
  613. i2 = min(i1 + blocksize, shape[0])
  614. count = i2 - i1
  615. W1 = copy.deepcopy(W[i1:i2, :]) # noqa: N806
  616. Q1 = np.zeros_like(W1) # noqa: N806
  617. Err1 = np.zeros_like(W1) # noqa: N806
  618. Losses1 = np.zeros_like(W1) # noqa: N806
  619. Hinv1 = Hinv[i1:i2, i1:i2] # noqa: N806
  620. for i in range(count): # within a block, channel wise
  621. w = W1[i, :]
  622. d = Hinv1[i, i]
  623. if group_size != -1:
  624. if (i1 + i) % group_size == 0:
  625. scale, zp = find_params(W[(i1 + i) : (i1 + i + group_size), :])
  626. q = (scale * (np.clip(np.round(w[:, np.newaxis] / scale) + zp, 0, maxq) - zp)).flatten()
  627. Q1[i, :] = q
  628. Losses1[i, :] = (w - q) ** 2 / d**2
  629. err1 = (w - q) / d
  630. W1[i:, :] -= np.matmul(np.expand_dims(Hinv1[i:, i], axis=1), np.expand_dims(err1, axis=0))
  631. Err1[i, :] = err1
  632. Q[i1:i2, :] = Q1
  633. Losses[i1:i2, :] = Losses1 / 2
  634. W[i2:, :] -= np.matmul(Hinv[i2:, i1:i2], Err1)
  635. if actorder:
  636. invperm = np.argsort(perm)
  637. Q = Q[invperm, :] # noqa: N806
  638. Q = np.reshape(Q, W.shape) # noqa: N806
  639. del W
  640. return Q
  641. def gptq_quantize(
  642. model,
  643. dataloader,
  644. weight_config={}, # noqa: B006
  645. num_bits=4,
  646. group_size=32,
  647. scheme="asym",
  648. n_samples=128,
  649. percdamp=0.01,
  650. blocksize=128,
  651. actorder=False,
  652. mse=False,
  653. perchannel=True,
  654. accuracy_level=0,
  655. providers=["CPUExecutionProvider"], # noqa: B006
  656. ):
  657. """Quant the model with GPTQ method.
  658. Args:
  659. model (ModelProto or ONNXModel): onnx model
  660. dataloader (object): dataloader for calibration.
  661. weight_config (dict): quantization config
  662. For example,
  663. weight_config = {
  664. 'fc2':
  665. {
  666. 'bits': 4,
  667. 'group_size': 32,
  668. 'scheme': 'sym',
  669. 'algorithm': 'GPTQ'
  670. }
  671. }
  672. num_bits (int, optional): num_bits. Default is 4.
  673. group_size (int, optional): how many elements share one scale/zp. Default is 32.
  674. scheme (str, optional): sym or asym. Defaults to "asym".
  675. n_samples (int, optional): calibration sample number.
  676. percdamp (float, optional): percent of the average Hessian diagonal to use for dampening.
  677. blocksize (int, optional): blocksize to quantize weight.
  678. actorder (bool, optional): whether rearrange Hessian matrix considering the diag's value.
  679. mse (bool, optional): whether get scale and zero point with mse error.
  680. perchannel (bool, optional): whether quantize weight per-channel.
  681. accuracy_level (int): accuracy level. Support 0 (unset), 1(fp32), 2(fp16), 3(bf16), or 4(int8).
  682. providers (list): providers to use
  683. Returns:
  684. model: fake quantized ONNXModel
  685. """
  686. model = ONNXModel(model)
  687. base_dir = os.path.dirname(model.model_path) if model.model_path is not None else ""
  688. inputs, so = prepare_inputs(model, n_samples, dataloader, providers)
  689. del dataloader
  690. org_output = copy.deepcopy(model.model.graph.output)
  691. model.remove_tensors_from_outputs([i.name for i in org_output])
  692. output_names = []
  693. for node in model.nodes():
  694. if (
  695. node.op_type in ["MatMul"]
  696. and weight_config.get(node.name, {}) != "fp32"
  697. and weight_config.get(node.name, {}).get("algorithm", "GPTQ") == "GPTQ"
  698. ):
  699. output_names.append(node.input[0])
  700. output_names = list(set(output_names))
  701. model.add_tensors_to_outputs(output_names)
  702. if model.is_large_model:
  703. onnx.save_model(
  704. model.model,
  705. model.model_path + "_augment.onnx",
  706. save_as_external_data=True,
  707. all_tensors_to_one_file=True,
  708. convert_attribute=False,
  709. )
  710. session = (
  711. ort.InferenceSession(model.model.SerializeToString(), so, providers=providers)
  712. if not model.is_large_model
  713. else ort.InferenceSession(model.model_path + "_augment.onnx", so, providers=providers)
  714. )
  715. for idx, input_name in enumerate(output_names):
  716. simple_progress_bar(len(output_names), idx + 1)
  717. node_list = []
  718. weights = []
  719. for node in model.input_name_to_nodes[input_name]:
  720. if (
  721. node.op_type in ["MatMul"]
  722. and weight_config.get(node.name, {}) != "fp32"
  723. and weight_config.get(node.name, {}).get("algorithm", "GPTQ") == "GPTQ"
  724. and model.get_initializer(node.input[1]) is not None
  725. ):
  726. weight = numpy_helper.to_array(
  727. model.get_initializer(model.get_node(node.name).input[1]), base_dir
  728. ).copy()
  729. if len(weight.shape) != 2:
  730. continue
  731. weights.append(weight)
  732. node_list.append(model.get_node(node.name))
  733. if len(weights) == 0:
  734. continue
  735. Hs = [np.zeros((i.shape[0], i.shape[0])) for i in weights] # noqa: N806
  736. nsamples = 0
  737. for data in inputs:
  738. inp = session.run([input_name], data)[0]
  739. tmp = inp.shape[0]
  740. inp = np.reshape(inp, (-1, inp.shape[-1]))
  741. Hs = [i * (nsamples / (nsamples + tmp)) for i in Hs] # noqa: N806
  742. nsamples += tmp
  743. inp = np.sqrt(2 / nsamples) * inp
  744. Hs = [i + np.matmul(inp.T, inp) for i in Hs] # noqa: N806
  745. for (
  746. node,
  747. weight,
  748. H, # noqa: N806
  749. ) in zip(node_list, weights, Hs, strict=False):
  750. if node.name in weight_config:
  751. num_bits = weight_config[node.name]["bits"]
  752. group_size = weight_config[node.name]["group_size"]
  753. scheme = weight_config[node.name]["scheme"]
  754. group_size = group_size if group_size != -1 else weight.shape[0]
  755. dtype = weight.dtype
  756. q_weight = gptq(
  757. weight,
  758. H,
  759. num_bits=num_bits,
  760. group_size=group_size,
  761. scheme=scheme,
  762. blocksize=blocksize,
  763. percdamp=percdamp,
  764. actorder=actorder,
  765. mse=mse,
  766. perchannel=perchannel,
  767. )
  768. weight_tensor = model.get_initializer(node.input[1])
  769. init_share_num = model.get_initializer_share_num(node.input[1])
  770. satisfy_MatMulNBits_condition = num_bits == 4 # noqa: N806
  771. if satisfy_MatMulNBits_condition: # pragma: no cover
  772. org_shape = weight.shape
  773. k_blocks = (org_shape[0] + group_size - 1) // group_size
  774. q_weight = pad_tensor(q_weight, group_size, k_blocks)
  775. q_weight, scale, zp = quant_tensor(q_weight.T, num_bits, group_size, scheme, "uint")
  776. q_matmul_node, new_inits = make_matmul_weight_only_node(
  777. node=node,
  778. weight_shape=org_shape,
  779. num_bits=num_bits,
  780. group_size=group_size,
  781. k_blocks=k_blocks,
  782. q_weight=q_weight.astype("uint8"),
  783. scale=scale.astype(dtype),
  784. zero_point=zp if scheme == "asym" else None,
  785. accuracy_level=accuracy_level,
  786. )
  787. model.add_initializers(new_inits)
  788. model.remove_node(node)
  789. model.add_node(q_matmul_node)
  790. else:
  791. q_weight_tensor = onnx.helper.make_tensor(
  792. name=node.input[1] + f"_Q{num_bits!s}G{group_size!s}",
  793. data_type=np_dtype_to_tensor_dtype(dtype),
  794. dims=q_weight.shape,
  795. vals=q_weight.astype(dtype).tobytes(),
  796. raw=True,
  797. )
  798. model.add_initializer(q_weight_tensor)
  799. node.input[1] = q_weight_tensor.name
  800. if init_share_num == 1:
  801. model.remove_initializer(weight_tensor)
  802. model.remove_tensors_from_outputs(output_names)
  803. model.model.graph.output.MergeFrom(org_output)
  804. model.topological_sort()
  805. # reload external data to prevent external data file path errors
  806. if model.is_large_model:
  807. from onnx.external_data_helper import load_external_data_for_model # noqa: PLC0415
  808. load_external_data_for_model(model.model, os.path.split(model.model_path)[0])
  809. return model