quant_utils.py 40 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066
  1. # -------------------------------------------------------------------------
  2. # Copyright (c) Microsoft Corporation. All rights reserved.
  3. # Licensed under the MIT License. See License.txt in the project root for
  4. # license information.
  5. # --------------------------------------------------------------------------
  6. from __future__ import annotations
  7. import copy
  8. import logging
  9. import os
  10. import tempfile
  11. from enum import Enum
  12. from pathlib import Path
  13. import numpy
  14. import onnx
  15. from onnx import ModelProto, TensorProto, external_data_helper
  16. from onnx import onnx_pb as onnx_proto
  17. from onnx.helper import make_graph, make_model, make_node, make_tensor_value_info
  18. from onnx.reference import ReferenceEvaluator
  19. from onnxruntime import GraphOptimizationLevel, InferenceSession, SessionOptions
  20. try:
  21. from onnx.reference.custom_element_types import float8e4m3fn
  22. except ImportError:
  23. float8e4m3fn = None
  24. # INT4 np.dtypes added in ONNX 1.16. These map to np.int8/np.uint8 because numpy
  25. # does not support sub-byte types.
  26. try:
  27. from onnx.reference.custom_element_types import int4, uint4
  28. except ImportError:
  29. int4 = None
  30. uint4 = None
  31. try:
  32. from onnx.reference.op_run import to_array_extended
  33. except ImportError:
  34. # old version of onnx.
  35. to_array_extended = None
# Producer name and version strings for this module.
__producer__ = "onnx.quantize"
__version__ = "0.1.0"

# Operator-set domain names.
onnx_domain = "ai.onnx"
ms_domain = "com.microsoft"

# Operator type names and name suffixes for the quantize/dequantize nodes and
# quantized tensors (the code that consumes these is outside this section).
QUANT_OP_NAME = "QuantizeLinear"
QUANT_INPUT_SUFFIX = "_QuantizeLinear_Input"
DEQUANT_OP_NAME = "DequantizeLinear"
DEQUANT_OUTPUT_SUFFIX = "_DequantizeLinear_Output"
TENSOR_NAME_QUANT_SUFFIX = "_quantized"

MODEL_SIZE_THRESHOLD = 2147483648  # Quant model should use external data if >= 2GB

# Cache of the finite values representable by each float 8 element type.
# Filled lazily by compute_scale_zp_float8().
FLOAT8_DISTRIBUTIONS = {}

# Reverse lookup from every integer-valued TensorProto attribute to its attribute name.
# If several attributes share a value, the last one returned by dir() wins.
type_to_name = {getattr(TensorProto, k): k for k in dir(TensorProto) if isinstance(getattr(TensorProto, k), int)}
  48. # Quantization mode
  49. # IntegerOps: Use IntegerOps in quantized model. Only ConvInteger and MatMulInteger ops are supported now.
  50. # QLinearOps: Use QLinearOps in quantized model. Only QLinearConv and QLinearMatMul ops are supported now.
  51. class QuantizationMode(Enum):
  52. IntegerOps = 0
  53. QLinearOps = 1
  54. def __str__(self):
  55. return self.name
  56. @staticmethod
  57. def from_string(mode):
  58. try:
  59. return QuantizationMode[mode]
  60. except KeyError:
  61. raise ValueError() # noqa: B904
  62. class QuantizedValueType(Enum):
  63. Input = 0
  64. Initializer = 1
  65. def __str__(self):
  66. return self.name
  67. @staticmethod
  68. def from_string(v):
  69. try:
  70. return QuantizedValueType[v]
  71. except KeyError:
  72. raise ValueError() # noqa: B904
  73. class QuantType(Enum):
  74. QInt8 = 0
  75. QUInt8 = 1
  76. QFLOAT8E4M3FN = 2
  77. QInt16 = 3
  78. QUInt16 = 4
  79. QInt4 = 5
  80. QUInt4 = 6
  81. def __str__(self):
  82. return self.name
  83. @staticmethod
  84. def from_string(t):
  85. try:
  86. return QuantType[t]
  87. except KeyError:
  88. raise ValueError() # noqa: B904
  89. @property
  90. def tensor_type(self):
  91. if self == QuantType.QInt8:
  92. return TensorProto.INT8
  93. if self == QuantType.QUInt8:
  94. return TensorProto.UINT8
  95. if self == QuantType.QUInt16:
  96. return TensorProto.UINT16
  97. if self == QuantType.QInt16:
  98. return TensorProto.INT16
  99. if self == QuantType.QFLOAT8E4M3FN:
  100. return TensorProto.FLOAT8E4M3FN
  101. if self == QuantType.QUInt4:
  102. return TensorProto.UINT4
  103. if self == QuantType.QInt4:
  104. return TensorProto.INT4
  105. raise ValueError(f"Unexpected value qtype={self!r}.")
  106. class QuantFormat(Enum):
  107. QOperator = 0
  108. QDQ = 1
  109. def __str__(self):
  110. return self.name
  111. @staticmethod
  112. def from_string(format):
  113. try:
  114. return QuantFormat[format]
  115. except KeyError:
  116. raise ValueError() # noqa: B904
# Maps each supported quantized ONNX element type to the numpy dtype used to hold
# its values. The float8/int4/uint4 entries are None when the guarded imports at the
# top of the file could not find them in the installed onnx package.
ONNX_TYPE_TO_NP_TYPE = {
    onnx_proto.TensorProto.INT8: numpy.dtype("int8"),
    onnx_proto.TensorProto.UINT8: numpy.dtype("uint8"),
    onnx_proto.TensorProto.INT16: numpy.dtype("int16"),
    onnx_proto.TensorProto.UINT16: numpy.dtype("uint16"),
    onnx_proto.TensorProto.FLOAT8E4M3FN: float8e4m3fn,
    onnx_proto.TensorProto.INT4: int4,  # base_dtype is np.int8
    onnx_proto.TensorProto.UINT4: uint4,  # base_dtype is np.uint8
}

# Full (qmin, qmax) range of every integer quantization type, stored as 0-d numpy
# arrays of the quantized dtype.
ONNX_INT_TYPE_RANGE = {
    onnx_proto.TensorProto.UINT8: (numpy.array(0, dtype=numpy.uint8), numpy.array(255, dtype=numpy.uint8)),
    onnx_proto.TensorProto.INT8: (numpy.array(-128, dtype=numpy.int8), numpy.array(127, dtype=numpy.int8)),
    onnx_proto.TensorProto.UINT16: (numpy.array(0, dtype=numpy.uint16), numpy.array(65535, dtype=numpy.uint16)),
    onnx_proto.TensorProto.INT16: (numpy.array(-32768, dtype=numpy.int16), numpy.array(32767, dtype=numpy.int16)),
    onnx_proto.TensorProto.UINT4: (numpy.array(0, dtype=uint4), numpy.array(15, dtype=uint4)),
    onnx_proto.TensorProto.INT4: (numpy.array(-8, dtype=int4), numpy.array(7, dtype=int4)),
}

# (qmin, qmax) used for symmetric quantization. There are no 4-bit entries, so
# get_qmin_qmax_for_qType() falls back to ONNX_INT_TYPE_RANGE for INT4/UINT4.
ONNX_INT_TYPE_SYMMETRIC_RANGE = {
    onnx_proto.TensorProto.UINT8: (numpy.array(0, dtype=numpy.uint8), numpy.array(254, dtype=numpy.uint8)),
    onnx_proto.TensorProto.INT8: (numpy.array(-127, dtype=numpy.int8), numpy.array(127, dtype=numpy.int8)),
    onnx_proto.TensorProto.UINT16: (numpy.array(0, dtype=numpy.uint16), numpy.array(65534, dtype=numpy.uint16)),
    onnx_proto.TensorProto.INT16: (numpy.array(-32767, dtype=numpy.int16), numpy.array(32767, dtype=numpy.int16)),
}
  140. ONNX_INT_TYPE_REDUCED_RANGE = {
  141. onnx_proto.TensorProto.UINT8: (numpy.array(0, dtype=numpy.uint8), numpy.array(127, dtype=numpy.uint8)),
  142. onnx_proto.TensorProto.INT8: (numpy.array(-64, dtype=numpy.int8), numpy.array(64, dtype=numpy.int8)),
  143. onnx_proto.TensorProto.UINT16: (numpy.array(0, dtype=numpy.uint16), numpy.array(32767, dtype=numpy.uint16)),
  144. onnx_proto.TensorProto.INT16: (numpy.array(-16384, dtype=numpy.int16), numpy.array(16384, dtype=numpy.int16)),
  145. onnx_proto.TensorProto.UINT4: (numpy.array(0, dtype=int4), numpy.array(7, dtype=int4)),
  146. onnx_proto.TensorProto.INT4: (numpy.array(-4, dtype=int4), numpy.array(3, dtype=int4)),
  147. }
  148. def _check_type(*args, zero_point_index=-1):
  149. new_args = []
  150. for i, a in enumerate(args):
  151. if numpy.issubdtype(type(a), numpy.number):
  152. new_args.append(numpy.array(a))
  153. elif isinstance(a, numpy.ndarray):
  154. new_args.append(a)
  155. else:
  156. raise TypeError(f"arg {i} is not an array: {a}")
  157. if i == zero_point_index:
  158. v = new_args[-1]
  159. if v.dtype == numpy.float32 or v.dtype == numpy.float16:
  160. raise TypeError(f"zero_point cannot be {v.dtype}")
  161. return tuple(new_args) if len(new_args) > 1 else new_args[0]
  162. def quantize_nparray(qType, arr, scale, zero_point, low=None, high=None):
  163. assert qType in ONNX_TYPE_TO_NP_TYPE, (
  164. f"Unexpected data type {qType} requested. Only INT8, UINT8, INT16, and UINT16 are supported."
  165. )
  166. if qType in (
  167. onnx_proto.TensorProto.FLOAT8E4M3FN,
  168. onnx_proto.TensorProto.FLOAT8E4M3FNUZ,
  169. onnx_proto.TensorProto.FLOAT8E5M2,
  170. onnx_proto.TensorProto.FLOAT8E5M2FNUZ,
  171. ):
  172. if zero_point != 0:
  173. raise NotImplementedError(f"zero_point is expected to be null for float 8 not {zero_point!r}.")
  174. if arr.dtype == numpy.float32:
  175. onnx_type = TensorProto.FLOAT
  176. elif arr.dtype == numpy.float16:
  177. onnx_type = TensorProto.FLOAT16
  178. else:
  179. raise ValueError(f"Unexpected dtype {arr.dtype}.")
  180. onnx_model = make_model(
  181. make_graph(
  182. [
  183. make_node(
  184. "Constant", [], ["zero_point"], value=onnx.helper.make_tensor("zero_point", qType, [], [0])
  185. ),
  186. make_node("QuantizeLinear", ["X", "scale", "zero_point"], ["Y"]),
  187. ],
  188. "qu",
  189. [
  190. make_tensor_value_info("X", onnx_type, None),
  191. make_tensor_value_info("scale", onnx_type, None),
  192. ],
  193. [make_tensor_value_info("Y", qType, None)],
  194. )
  195. )
  196. ref = ReferenceEvaluator(onnx_model)
  197. return _check_type(ref.run(None, {"X": arr, "scale": scale})[0])
  198. else:
  199. # Quantizes data for all integer types.
  200. #
  201. # For int4 types, the quantized data is returned as either np.int8 or np.uint8,
  202. # which matches the python reference ONNX implementation of QuantizeLinear.
  203. # This data can be packed into 4-bit elements by using pack_bytes_to_4bit().
  204. dtype = ONNX_TYPE_TO_NP_TYPE[qType]
  205. qmin, qmax = get_qmin_qmax_for_qType(qType, reduce_range=False, symmetric=False)
  206. cliplow = max(qmin, low) if low is not None else qmin
  207. cliphigh = min(qmax, high) if high is not None else qmax
  208. arr_fp32 = numpy.asarray((arr.astype(numpy.float32) / scale).round() + zero_point)
  209. numpy.clip(arr_fp32, cliplow, cliphigh, out=arr_fp32)
  210. return _check_type(arr_fp32.astype(dtype))
  211. def compute_scale_zp(rmin, rmax, qmin, qmax, symmetric=False, min_real_range=None):
  212. """Calculate the scale s and zero point z for the quantization relation
  213. r = s(q-z), where r are the original values and q are the corresponding
  214. quantized values.
  215. r and z are calculated such that every value within [rmin,rmax] has an
  216. approximate representation within [qmin,qmax]. In addition, qmin <= z <=
  217. qmax is enforced. If the symmetric flag is set to True, the interval
  218. [rmin,rmax] is symmetrized to [-absmax, +absmax], where
  219. absmax = max(abs(rmin), abs(rmax)).
  220. :parameter rmin: minimum value of r
  221. :parameter rmax: maximum value of r
  222. :parameter qmin: minimum value representable by the target quantization data type
  223. :parameter qmax: maximum value representable by the target quantization data type
  224. :parameter symmetric: True if the floating-point range should be made symmetric. Defaults to False.
  225. :parameter min_real_range: Minimum floating-point range (i.e., rmax - rmin) to enforce. Defaults to None.
  226. :return: zero and scale [z, s]
  227. """
  228. if qmin > 0 or qmax < 0:
  229. raise ValueError(f"qmin and qmax must meet requirement: qmin <= 0 <= qmax while qmin:{qmin}, qmmax:{qmax}")
  230. # Adjust rmin and rmax such that 0 is included in the range. This is
  231. # required to make sure zero can be represented by the quantization data
  232. # type (i.e. to make sure qmin <= zero_point <= qmax)
  233. rmin = numpy.minimum(rmin, numpy.array(0, dtype=rmin.dtype))
  234. rmax = numpy.maximum(rmax, numpy.array(0, dtype=rmax.dtype))
  235. # Ensure a minimum float-point range if specified.
  236. if min_real_range is not None:
  237. rmax = max(rmax, rmin + numpy.asarray(min_real_range, dtype=rmin.dtype))
  238. if symmetric:
  239. absmax = numpy.maximum(numpy.abs(rmin), numpy.abs(rmax))
  240. rmin = -absmax
  241. rmax = +absmax
  242. assert qmin <= qmax, f"qmin={rmin} > qmax={rmax}"
  243. dr = numpy.array(rmax - rmin, dtype=numpy.float64)
  244. dq = numpy.array(qmax, dtype=numpy.float64) - numpy.array(qmin, dtype=numpy.float64)
  245. scale = numpy.array(dr / dq)
  246. assert scale >= 0, "scale issue"
  247. if scale < numpy.finfo(rmax.dtype).tiny:
  248. scale = numpy.array(1.0, dtype=rmax.dtype)
  249. zero_point = numpy.array(0, dtype=qmin.dtype)
  250. else:
  251. if symmetric:
  252. # When symmetric (i.e., rmax == -rmin), the zero_point formula reduces to round((qmax + qmin) / 2.0).
  253. # This simpler formula doesn't depend on scale and guarantees that the zero point values
  254. # for int8, uint8, int16, and uint16 are always 0, 128, 0, and 32768, respectively.
  255. # This is important for per-channel/symmetric QLinearConv on CPU EP, which requires all channels to have
  256. # the exact same zero_point values.
  257. zero_point = numpy.array(
  258. numpy.round((qmin + qmax) / numpy.array(2.0, dtype=numpy.float64)), dtype=qmin.dtype
  259. )
  260. else:
  261. zero_point = numpy.array(numpy.round(qmin - rmin / scale), dtype=qmin.dtype)
  262. scale = scale.astype(rmax.dtype)
  263. return [zero_point, scale]
  264. def compute_scale_zp_float8(element_type, std):
  265. """Calculate the scale s for a float8 type (E4M3FN).
  266. The function assumes the coefficient distribution and the float 8
  267. distribution are similar to two gaussian laws.
  268. :return: zero and scale [z, s]
  269. More details in notebook `quantization_fp8.ipynb
  270. <https://github.com/microsoft/onnxruntime/blob/main/docs/python/notebooks/quantization_fp8.ipynb>`_.
  271. """
  272. zp_dtype = None
  273. if element_type not in FLOAT8_DISTRIBUTIONS:
  274. if element_type == TensorProto.FLOAT8E4M3FN:
  275. from onnx.numpy_helper import float8e4m3_to_float32 # noqa: PLC0415
  276. from onnx.reference.custom_element_types import float8e4m3fn # noqa: PLC0415
  277. zp_dtype = float8e4m3fn
  278. all_values = [float8e4m3_to_float32(i) for i in range(256)]
  279. values = numpy.array(
  280. [f for f in all_values if not numpy.isnan(f) and not numpy.isinf(f)], dtype=numpy.float32
  281. )
  282. else:
  283. raise ValueError(f"Quantization to element_type={element_type} not implemented.")
  284. FLOAT8_DISTRIBUTIONS[element_type] = values
  285. elif element_type == TensorProto.FLOAT8E4M3FN:
  286. from onnx.reference.custom_element_types import float8e4m3fn # noqa: PLC0415
  287. zp_dtype = float8e4m3fn
  288. if zp_dtype is None:
  289. raise TypeError(f"Unexpected element_type {element_type}.")
  290. std_f8 = numpy.std(FLOAT8_DISTRIBUTIONS[element_type])
  291. zero = numpy.array(0, dtype=zp_dtype)
  292. scale = numpy.array(std / std_f8, dtype=std.dtype)
  293. return [zero, scale]
  294. def compute_data_quant_params(
  295. data: numpy.ndarray,
  296. quant_type: onnx.TensorProto.DataType,
  297. symmetric: bool,
  298. reduce_range: bool = False,
  299. min_real_range: float | None = None,
  300. rmin_override: float | None = None,
  301. rmax_override: float | None = None,
  302. ) -> tuple[numpy.ndarray, numpy.ndarray]:
  303. """
  304. Returns the zero_point and scale for the given data.
  305. :param data: The data for which to compute quantization parameters.
  306. :param quant_type: The quantization data type.
  307. :param symmetric: whether symmetric quantization is used or not.
  308. :parameter reduce_range: True if the quantization range should be reduced. Defaults to False.
  309. :parameter min_real_range: Minimum floating-point range (i.e., rmax - rmin) to enforce. Defaults to None.
  310. :parameter rmin_override: The value of rmin to use if not None. Otherwise, uses min(data).
  311. :parameter rmax_override: The value of rmax to use if not None. Otherwise, uses max(data).
  312. :return: zero point and scale
  313. """
  314. if not isinstance(data, numpy.ndarray):
  315. raise TypeError(f"Weight must be given as an array not {type(data)}.")
  316. if rmin_override is not None:
  317. rmin = rmin_override
  318. else:
  319. rmin = data.min() if len(data) else 0.0
  320. if rmax_override is not None:
  321. rmax = rmax_override
  322. else:
  323. rmax = data.max() if len(data) else 0.0
  324. rmin = numpy.array(rmin, dtype=data.dtype)
  325. rmax = numpy.array(rmax, dtype=data.dtype)
  326. scale = numpy.array(1.0, dtype=data.dtype)
  327. if quant_type == TensorProto.FLOAT8E4M3FN:
  328. if reduce_range:
  329. raise RuntimeError("Unsupported option reduce_range=True for float 8.")
  330. std = numpy.std(data)
  331. zero_point, scale = compute_scale_zp_float8(quant_type, std)
  332. return _check_type(zero_point, scale, zero_point_index=0)
  333. if quant_type in (
  334. TensorProto.INT8,
  335. TensorProto.UINT8,
  336. TensorProto.INT16,
  337. TensorProto.UINT16,
  338. TensorProto.INT4,
  339. TensorProto.UINT4,
  340. ):
  341. qmin, qmax = get_qmin_qmax_for_qType(quant_type, reduce_range, symmetric=symmetric)
  342. if len(data):
  343. zero_point, scale = compute_scale_zp(rmin, rmax, qmin, qmax, symmetric, min_real_range)
  344. else:
  345. zero_point = numpy.array(0, dtype=qmin.dtype)
  346. return _check_type(zero_point, scale, zero_point_index=0)
  347. raise ValueError(f"Unexpected value for quant_type={quant_type}.")
  348. def quantize_data(
  349. data, qType, symmetric, reduce_range=False, min_real_range=None, rmin_override=None, rmax_override=None
  350. ) -> tuple[numpy.ndarray, numpy.ndarray, numpy.ndarray]:
  351. """
  352. :param data: data to quantize
  353. :param qType: data type to quantize to.
  354. :param symmetric: whether symmetric quantization is used or not.
  355. :parameter reduce_range: True if the quantization range should be reduced. Defaults to False.
  356. :parameter min_real_range: Minimum floating-point range (i.e., rmax - rmin) to enforce. Defaults to None.
  357. :parameter rmin_override: The value of rmin to use if not None. Otherwise, uses min(data).
  358. :parameter rmax_override: The value of rmax to use if not None. Otherwise, uses max(data).
  359. :return: minimum, maximum, zero point, scale, and quantized weights
  360. To pack weights, we compute a linear transformation
  361. - when data `type == uint8` mode, from `[rmin, rmax]` -> :math:`[0, 2^{b-1}]` and
  362. - when data `type == int8`, from `[-m , m]` -> :math:`[-(2^{b-1}-1), 2^{b-1}-1]` where
  363. `m = max(abs(rmin), abs(rmax))`
  364. and add necessary intermediate nodes to transform quantized weight to full weight using the equation
  365. :math:`r = S(q-z)`, where
  366. - *r*: real original value
  367. - *q*: quantized value
  368. - *S*: scale
  369. - *z*: zero point
  370. """
  371. zero_point, scale = compute_data_quant_params(
  372. data,
  373. qType,
  374. symmetric,
  375. reduce_range,
  376. min_real_range,
  377. rmin_override,
  378. rmax_override,
  379. )
  380. if qType == TensorProto.FLOAT8E4M3FN:
  381. quantized_data = quantize_nparray(qType, data, scale, zero_point)
  382. if any((quantized_data.astype(numpy.uint8).ravel() & 127) == 127):
  383. np_data = numpy.asarray(data)
  384. raise RuntimeError(
  385. f"One of the quantized value is NaN data in [{np_data.min()}, {np_data.max()}], "
  386. f"quantized_data in [{quantized_data.min()}, {quantized_data.max()}]."
  387. )
  388. return zero_point, scale, quantized_data
  389. if qType in (
  390. TensorProto.INT8,
  391. TensorProto.UINT8,
  392. TensorProto.INT16,
  393. TensorProto.UINT16,
  394. TensorProto.INT4,
  395. TensorProto.UINT4,
  396. ):
  397. quantized_data = quantize_nparray(qType, data, scale, zero_point)
  398. return zero_point, scale, quantized_data
  399. raise ValueError(f"Unexpected value for qType={qType}.")
  400. def quantize_onnx_initializer(
  401. weight: onnx.TensorProto,
  402. quant_type: onnx.TensorProto.DataType,
  403. zero_point: numpy.ndarray,
  404. scale: numpy.ndarray,
  405. axis: int | None = None,
  406. quant_weight_name: str | None = None,
  407. ) -> onnx.TensorProto:
  408. """
  409. Returns a quantized version of the given ONNX initializer.
  410. :param weight: The ONNX initializer to quantize.
  411. :param quant_type: The final quantized data type.
  412. :param zero_point: The zero-point value to use for quantization.
  413. :param scale: The scale value to use for quantization.
  414. :param axis: The quantization axis if quantizing per-channel. Defaults to None.
  415. :param quant_weight_name: The name of the quantized initializer.
  416. If not specified, the quantized name is generated.
  417. :return: The quantized ONNX initializer.
  418. """
  419. weight_data = tensor_proto_to_array(weight)
  420. q_weight_data: numpy.ndarray | None = None
  421. if axis is None: # Per-tensor quantization
  422. q_weight_data = quantize_nparray(quant_type, weight_data.ravel(), scale, zero_point)
  423. else: # Per-channel quantization
  424. channel_count = weight_data.shape[axis]
  425. channel_dims = list(weight_data.shape) # deep copy
  426. channel_dims[axis] = 1 # only one per channel for reshape
  427. quantized_channel_data_list = []
  428. for i in range(channel_count):
  429. channel_data = weight_data.take(i, axis)
  430. channel_scale = scale[i]
  431. channel_zero_point = zero_point[i]
  432. quantized_channel_data = quantize_nparray(
  433. quant_type, channel_data.ravel(), channel_scale, channel_zero_point
  434. )
  435. quantized_channel_data_list.append(numpy.asarray(quantized_channel_data).reshape(channel_dims))
  436. q_weight_data = numpy.concatenate(quantized_channel_data_list, axis)
  437. q_weight_name = quant_weight_name if quant_weight_name else f"{weight.name}{TENSOR_NAME_QUANT_SUFFIX}"
  438. if quant_type == onnx.TensorProto.FLOAT8E4M3FN:
  439. q_weight_initializer = onnx.TensorProto()
  440. q_weight_initializer.data_type = quant_type
  441. q_weight_initializer.dims.extend(weight.dims)
  442. q_weight_initializer.name = q_weight_name
  443. # Do not remove .flatten().copy() numpy is not clear about data persistence.
  444. q_weight_initializer.raw_data = q_weight_data.flatten().copy().tobytes()
  445. if to_array_extended is not None:
  446. # This test should not be needed but it helped catch some issues
  447. # with data persistence and tobytes.
  448. check = to_array_extended(q_weight_initializer)
  449. if check.shape != weight_data.shape or check.tobytes() != q_weight_data.tobytes():
  450. raise RuntimeError(
  451. f"The initializer of shape {weight_data.shape} could not be created, expecting "
  452. f"{q_weight_data.tobytes()[:10]}, got {check.tobytes()[:10]} and shape={weight.shape}"
  453. f"\nraw={str(q_weight_initializer)[:200]}."
  454. )
  455. elif quant_type in (onnx.TensorProto.INT4, onnx.TensorProto.UINT4):
  456. if q_weight_data.dtype not in (numpy.int8, numpy.uint8):
  457. raise RuntimeError(f"Quantized weights for {q_weight_name} must be 8-bit before packing as 4-bit values.")
  458. # We do not use onnx.helper.pack_float32_to_4bit() due to performance.
  459. # This can be the difference between a large model taking 30 minutes to quantize vs 5 minutes.
  460. packed_data = bytes(pack_bytes_to_4bit(q_weight_data.tobytes()))
  461. # We only use onnx.helper.make_tensor with raw data due to bug: https://github.com/onnx/onnx/pull/6161
  462. q_weight_initializer = onnx.helper.make_tensor(q_weight_name, quant_type, weight.dims, packed_data, raw=True)
  463. else:
  464. quant_np_dtype = onnx.helper.tensor_dtype_to_np_dtype(quant_type)
  465. q_weight_data = numpy.asarray(q_weight_data, dtype=quant_np_dtype).reshape(weight.dims)
  466. q_weight_initializer = onnx.numpy_helper.from_array(q_weight_data, q_weight_name)
  467. return q_weight_initializer
  468. def get_qmin_qmax_for_qType(qType, reduce_range=False, symmetric=False): # noqa: N802
  469. """
  470. Return qmin and qmax, the minimum and maximum value representable by the given qType
  471. :parameter qType: onnx.onnx_pb.TensorProto.UINT8 or onnx.onnx_pb.TensorProto.UINT8
  472. :return: qmin, qmax
  473. """
  474. if qType == onnx_proto.TensorProto.FLOAT8E4M3FN:
  475. raise NotImplementedError("This function is not implemented for float 8 as not needed.")
  476. qrange = None
  477. if reduce_range:
  478. qrange = ONNX_INT_TYPE_REDUCED_RANGE.get(qType)
  479. elif symmetric and qType in ONNX_INT_TYPE_SYMMETRIC_RANGE:
  480. qrange = ONNX_INT_TYPE_SYMMETRIC_RANGE[qType]
  481. else:
  482. qrange = ONNX_INT_TYPE_RANGE.get(qType)
  483. if not qrange:
  484. raise ValueError(f"Unexpected data type {qType} requested. Only INT8, UINT8, INT16, and UINT16 are supported.")
  485. qmin, qmax = qrange
  486. if qmin > 0 or qmax < 0:
  487. raise ValueError(
  488. f"qmin and qmax must meet requirement: qmin <= 0 <= qmax while "
  489. f"qmin:{qmin}, qmmax:{qmax}, dtype={qmin.dtype}, reduce_range={reduce_range}, "
  490. f"symmetric={symmetric}, qType={qType}"
  491. )
  492. return qrange
  493. def get_qrange_for_qType(qType, reduce_range=False, symmetric=False): # noqa: N802
  494. """
  495. Helper function to get the quantization range for a type.
  496. parameter qType: quantization type.
  497. return: quantization range.
  498. """
  499. qmin, qmax = get_qmin_qmax_for_qType(qType, reduce_range, symmetric=symmetric)
  500. return qmax - qmin
  501. def normalize_axis(axis: int, rank: int) -> tuple[bool, int]:
  502. """
  503. Helper function that tries to return a normalized axis in the range [0, rank - 1].
  504. :parameter axis: The axis to normalize.
  505. :parameter rank: The tensor rank (number of dimensions).
  506. :return (is_valid, axis_norm)
  507. """
  508. axis_norm = axis + rank if axis < 0 else axis
  509. is_valid = axis_norm >= 0 and axis_norm < rank
  510. return is_valid, axis_norm
  511. def pack_bytes_to_4bit(src_8bit: bytes) -> bytearray:
  512. """
  513. Copies a source array of 8-bit values into a destination bytearray of packed 4-bit values.
  514. Assumes that the source values are already in the appropriate int4 range.
  515. :parameter src_8bit: The 8-bit element values to pack.
  516. :return A bytearray with every two 8-bit src elements packed into a single byte.
  517. """
  518. num_elems = len(src_8bit)
  519. if num_elems == 0:
  520. return bytearray()
  521. dst_size = (num_elems + 1) // 2 # Ex: 5 8-bit elems packed into 3 bytes
  522. dst = bytearray(dst_size)
  523. src_i: int = 0
  524. dst_i: int = 0
  525. # Pack two 8-bit elements into a single byte in each iteration.
  526. while src_i < num_elems - 1:
  527. dst[dst_i] = ((src_8bit[src_i + 1] & 0xF) << 4) | (src_8bit[src_i] & 0xF)
  528. dst_i += 1
  529. src_i += 2
  530. if src_i < num_elems:
  531. # Odd number of elements.
  532. dst[dst_i] = src_8bit[src_i] & 0xF
  533. return dst
  534. class QuantizedInitializer:
  535. """
  536. Represents a linearly quantized weight input from ONNX operators
  537. """
  538. def __init__(
  539. self,
  540. name,
  541. initializer,
  542. rmins,
  543. rmaxs,
  544. zero_points,
  545. scales,
  546. data=[], # noqa: B006
  547. quantized_data=[], # noqa: B006
  548. axis=None,
  549. ):
  550. self.name = name
  551. self.initializer = initializer # TensorProto initializer in ONNX graph
  552. self.rmins = rmins # List of minimum range for each axis
  553. self.rmaxs = rmaxs # List of maximum range for each axis
  554. # 1D tensor of zero points computed for each axis. scalar if axis is empty
  555. self.zero_points = zero_points
  556. self.scales = scales # 1D tensor of scales computed for each axis. scalar if axis is empty
  557. self.data = data # original data from initializer TensorProto
  558. self.quantized_data = quantized_data # weight-packed data from data
  559. # Scalar to specify which dimension in the initializer to weight pack.
  560. self.axis = axis
  561. # If empty, single zero point and scales computed from a single rmin and rmax
  562. class QuantizedValue:
  563. """
  564. Represents a linearly quantized value (input\\output\\intializer)
  565. """
  566. def __init__(
  567. self,
  568. name,
  569. new_quantized_name,
  570. scale_name,
  571. zero_point_name,
  572. quantized_value_type,
  573. axis=None,
  574. node_type=None,
  575. node_qtype=None,
  576. scale_type=None,
  577. ):
  578. self.original_name = name
  579. self.q_name = new_quantized_name
  580. self.scale_name = scale_name
  581. self.zp_name = zero_point_name
  582. self.value_type = quantized_value_type
  583. self.axis = axis
  584. self.node_type = node_type
  585. self.node_qtype = node_qtype
  586. self.scale_type = scale_type
  587. class BiasToQuantize:
  588. """
  589. Represents a bias to be quantized
  590. """
  591. def __init__(self, bias_name, input_name, weight_name):
  592. self.bias_name = bias_name
  593. self.input_name = input_name
  594. self.weight_name = weight_name
  595. def attribute_to_kwarg(attribute):
  596. """
  597. Convert attribute to kwarg format for use with onnx.helper.make_node.
  598. :parameter attribute: attribute in AttributeProto format.
  599. :return: attribute in {key: value} format.
  600. """
  601. if attribute.type == 0:
  602. raise ValueError(f"attribute {attribute.name} does not have type specified.")
  603. # Based on attribute type definitions from AttributeProto
  604. # definition in https://github.com/onnx/onnx/blob/main/onnx/onnx.proto
  605. if attribute.type == 1:
  606. value = attribute.f
  607. elif attribute.type == 2:
  608. value = attribute.i
  609. elif attribute.type == 3:
  610. value = attribute.s
  611. elif attribute.type == 4:
  612. value = attribute.t
  613. elif attribute.type == 5:
  614. value = attribute.g
  615. elif attribute.type == 6:
  616. value = attribute.floats
  617. elif attribute.type == 7:
  618. value = attribute.ints
  619. elif attribute.type == 8:
  620. value = attribute.strings
  621. elif attribute.type == 9:
  622. value = attribute.tensors
  623. elif attribute.type == 10:
  624. value = attribute.graphs
  625. else:
  626. raise ValueError(f"attribute {attribute.name} has unsupported type {attribute.type}.")
  627. return {attribute.name: value}
  628. def find_by_name(item_name, item_list):
  629. """
  630. Helper function to find item by name in a list.
  631. parameter item_name: name of the item.
  632. parameter item_list: list of items.
  633. return: item if found. None otherwise.
  634. """
  635. items = [item for item in item_list if item.name == item_name]
  636. return items[0] if len(items) > 0 else None
  637. def get_elem_index(elem_name, elem_list):
  638. """
  639. Helper function to return index of an item in a node list
  640. """
  641. elem_idx = -1
  642. for i in range(len(elem_list)):
  643. if elem_list[i] == elem_name:
  644. elem_idx = i
  645. return elem_idx
  646. def get_mul_node(inputs, output, name):
  647. """
  648. Helper function to create a Mul node.
  649. parameter inputs: list of input names.
  650. parameter output: output name.
  651. parameter name: name of the node.
  652. return: Mul node in NodeProto format.
  653. """
  654. return onnx.helper.make_node("Mul", inputs, [output], name)
  655. def generate_identified_filename(filename: Path, identifier: str) -> Path:
  656. """
  657. Helper function to generate a identifiable filepath by concatenating the given identifier as a suffix.
  658. """
  659. return filename.parent.joinpath(filename.stem + identifier + filename.suffix)
  660. def apply_plot(hist, hist_edges):
  661. import sys # noqa: PLC0415
  662. import matplotlib.pyplot as plt # noqa: PLC0415
  663. import numpy # noqa: PLC0415
  664. numpy.set_printoptions(threshold=sys.maxsize)
  665. print("Histogram:")
  666. print(hist)
  667. print("Histogram Edges:")
  668. print(hist_edges)
  669. plt.stairs(hist, hist_edges, fill=True)
  670. plt.xlabel("Tensor value")
  671. plt.ylabel("Counts")
  672. plt.title("Tensor value V.S. Counts")
  673. plt.show()
  674. def write_calibration_table(calibration_cache, dir="."):
  675. """
  676. Helper function to write calibration table to files.
  677. """
  678. import json # noqa: PLC0415
  679. import flatbuffers # noqa: PLC0415
  680. import numpy as np # noqa: PLC0415
  681. import onnxruntime.quantization.CalTableFlatBuffers.KeyValue as KeyValue # noqa: PLC0415
  682. import onnxruntime.quantization.CalTableFlatBuffers.TrtTable as TrtTable # noqa: PLC0415
  683. from onnxruntime.quantization.calibrate import CalibrationMethod, TensorData, TensorsData # noqa: PLC0415
  684. logging.info(f"calibration cache: {calibration_cache}")
  685. class MyEncoder(json.JSONEncoder):
  686. def default(self, obj):
  687. if isinstance(obj, (TensorData, TensorsData)):
  688. return obj.to_dict()
  689. if isinstance(obj, np.ndarray):
  690. return {"data": obj.tolist(), "dtype": str(obj.dtype), "CLS": "numpy.array"}
  691. if isinstance(obj, CalibrationMethod):
  692. return {"CLS": obj.__class__.__name__, "value": str(obj)}
  693. return json.JSONEncoder.default(self, obj)
  694. json_data = json.dumps(calibration_cache, cls=MyEncoder)
  695. with open(os.path.join(dir, "calibration.json"), "w") as file:
  696. file.write(json_data) # use `json.loads` to do the reverse
  697. # Serialize data using FlatBuffers
  698. zero = np.array(0)
  699. builder = flatbuffers.Builder(1024)
  700. key_value_list = []
  701. for key in sorted(calibration_cache.keys()):
  702. values = calibration_cache[key]
  703. d_values = values.to_dict()
  704. floats = [
  705. float(d_values.get("highest", zero).item()),
  706. float(d_values.get("lowest", zero).item()),
  707. ]
  708. value = str(max(floats))
  709. flat_key = builder.CreateString(key)
  710. flat_value = builder.CreateString(value)
  711. KeyValue.KeyValueStart(builder)
  712. KeyValue.KeyValueAddKey(builder, flat_key)
  713. KeyValue.KeyValueAddValue(builder, flat_value)
  714. key_value = KeyValue.KeyValueEnd(builder)
  715. key_value_list.append(key_value)
  716. TrtTable.TrtTableStartDictVector(builder, len(key_value_list))
  717. for key_value in key_value_list:
  718. builder.PrependUOffsetTRelative(key_value)
  719. main_dict = builder.EndVector()
  720. TrtTable.TrtTableStart(builder)
  721. TrtTable.TrtTableAddDict(builder, main_dict)
  722. cal_table = TrtTable.TrtTableEnd(builder)
  723. builder.Finish(cal_table)
  724. buf = builder.Output()
  725. with open(os.path.join(dir, "calibration.flatbuffers"), "wb") as file:
  726. file.write(buf)
  727. # Deserialize data (for validation)
  728. if os.environ.get("QUANTIZATION_DEBUG", "0") in (1, "1"):
  729. cal_table = TrtTable.TrtTable.GetRootAsTrtTable(buf, 0)
  730. dict_len = cal_table.DictLength()
  731. for i in range(dict_len):
  732. key_value = cal_table.Dict(i)
  733. logging.info(key_value.Key())
  734. logging.info(key_value.Value())
  735. # write plain text
  736. with open(os.path.join(dir, "calibration.cache"), "w") as file:
  737. for key in sorted(calibration_cache.keys()):
  738. values = calibration_cache[key]
  739. d_values = values.to_dict()
  740. floats = [
  741. float(d_values.get("highest", zero).item()),
  742. float(d_values.get("lowest", zero).item()),
  743. ]
  744. value = key + " " + str(max(floats))
  745. file.write(value)
  746. file.write("\n")
  747. def smooth_distribution(p, eps=0.0001):
  748. """Given a discrete distribution (may have not been normalized to 1),
  749. smooth it by replacing zeros with eps multiplied by a scaling factor
  750. and taking the corresponding amount off the non-zero values.
  751. Ref: http://web.engr.illinois.edu/~hanj/cs412/bk3/KL-divergence.pdf
  752. https://github.com//apache/incubator-mxnet/blob/master/python/mxnet/contrib/quantization.py
  753. """
  754. is_zeros = (p == 0).astype(numpy.float32)
  755. is_nonzeros = (p != 0).astype(numpy.float32)
  756. n_zeros = is_zeros.sum()
  757. n_nonzeros = p.size - n_zeros
  758. if not n_nonzeros:
  759. # raise ValueError('The discrete probability distribution is malformed. All entries are 0.')
  760. return None
  761. eps1 = eps * float(n_zeros) / float(n_nonzeros)
  762. assert eps1 < 1.0, f"n_zeros={n_zeros}, n_nonzeros={n_nonzeros}, eps1={eps1}"
  763. hist = p.astype(numpy.float32)
  764. hist += eps * is_zeros + (-eps1) * is_nonzeros
  765. assert (hist <= 0).sum() == 0
  766. return hist
  767. def model_has_external_data(model_path: Path):
  768. model = onnx.load(model_path.as_posix(), load_external_data=False)
  769. return any(external_data_helper.uses_external_data(intializer) for intializer in model.graph.initializer)
  770. def optimize_model(model_path: Path, opt_model_path: Path):
  771. """
  772. Generate model that applies graph optimization (constant folding, etc.)
  773. parameter model_path: path to the original onnx model
  774. parameter opt_model_path: path to the optimized onnx model
  775. :return: optimized onnx model
  776. """
  777. sess_option = SessionOptions()
  778. sess_option.optimized_model_filepath = opt_model_path.as_posix()
  779. sess_option.graph_optimization_level = GraphOptimizationLevel.ORT_ENABLE_BASIC
  780. kwargs = {}
  781. # This will rename constant initializer names, disable it to make test pass.
  782. kwargs["disabled_optimizers"] = ["ConstantSharing"]
  783. _ = InferenceSession(model_path.as_posix(), sess_option, providers=["CPUExecutionProvider"], **kwargs)
  784. def add_pre_process_metadata(model: ModelProto):
  785. """Tag the model that it went through quantization pre-processing"""
  786. metadata_props = {"onnx.quant.pre_process": "onnxruntime.quant"}
  787. if model.metadata_props:
  788. for prop in model.metadata_props:
  789. metadata_props.update({prop.key: prop.value})
  790. onnx.helper.set_model_props(model, metadata_props)
  791. def model_has_pre_process_metadata(model: ModelProto) -> bool:
  792. """Check the model whether it went through quantization pre-processing"""
  793. if model.metadata_props:
  794. for prop in model.metadata_props:
  795. if prop.key == "onnx.quant.pre_process" and prop.value == "onnxruntime.quant":
  796. return True
  797. return False
  798. def add_infer_metadata(model: ModelProto):
  799. metadata_props = {"onnx.infer": "onnxruntime.quant"}
  800. if model.metadata_props:
  801. for p in model.metadata_props:
  802. metadata_props.update({p.key: p.value})
  803. onnx.helper.set_model_props(model, metadata_props)
  804. def model_has_infer_metadata(model: ModelProto) -> bool:
  805. if model.metadata_props:
  806. for p in model.metadata_props:
  807. if p.key == "onnx.infer" and p.value == "onnxruntime.quant":
  808. return True
  809. return False
  810. def get_opset_version(model: ModelProto) -> int:
  811. ai_onnx_domain = [opset for opset in model.opset_import if not opset.domain or opset.domain == "ai.onnx"]
  812. if len(ai_onnx_domain) != 1:
  813. raise ValueError("Failed to find proper ai.onnx domain")
  814. opset_version = ai_onnx_domain[0].version
  815. return opset_version
  816. def update_opset_version(model: ModelProto, weight_type: QuantType) -> ModelProto:
  817. opset_version = get_opset_version(model)
  818. target_opset_version = opset_version
  819. weight_quant_type = getattr(weight_type, "tensor_type", weight_type)
  820. if opset_version < 19 and weight_quant_type == onnx.TensorProto.FLOAT8E4M3FN:
  821. logging.warning(
  822. f"The original model opset version is {opset_version}, which does not support quantization to float 8. "
  823. "Please update the model to opset >= 19. Automatically update the model to opset 19. "
  824. "Please verify the quantized model."
  825. )
  826. target_opset_version = 19
  827. elif opset_version == 10:
  828. logging.warning(
  829. f"The original model opset version is {opset_version}, which does not support node fusions. "
  830. "Please update the model to opset >= 11 for better performance."
  831. )
  832. elif opset_version < 10:
  833. logging.warning(
  834. f"The original model opset version is {opset_version}, which does not support quantization. "
  835. "Please update the model to opset >= 11. Automatically update the model to opset 11. "
  836. "Please verify the quantized model."
  837. )
  838. target_opset_version = 11
  839. if target_opset_version != opset_version:
  840. model = onnx.version_converter.convert_version(model, target_opset_version)
  841. # Additional nodes may be added to the model during the opset version conversion. Run shape inference
  842. # to ensure all nodes are included in model.graph.value_info.
  843. model = save_and_reload_model_with_shape_infer(model)
  844. return model
  845. def load_model_with_shape_infer(model_path: Path) -> ModelProto:
  846. inferred_model_path = generate_identified_filename(model_path, "-inferred")
  847. onnx.shape_inference.infer_shapes_path(str(model_path), str(inferred_model_path))
  848. model = onnx.load(inferred_model_path.as_posix())
  849. add_infer_metadata(model)
  850. inferred_model_path.unlink()
  851. return model
  852. def save_and_reload_model_with_shape_infer(model: ModelProto) -> ModelProto:
  853. with tempfile.TemporaryDirectory(prefix="ort.quant.") as quant_tmp_dir:
  854. model_copy = copy.deepcopy(model)
  855. model_path = Path(quant_tmp_dir).joinpath("model.onnx")
  856. onnx.save_model(model_copy, model_path.as_posix(), save_as_external_data=True)
  857. return load_model_with_shape_infer(model_path)
  858. def tensor_proto_to_array(initializer: TensorProto) -> numpy.ndarray:
  859. if initializer.data_type in (onnx_proto.TensorProto.FLOAT, onnx_proto.TensorProto.FLOAT16):
  860. return onnx.numpy_helper.to_array(initializer)
  861. raise ValueError(
  862. f"Only float type is supported. Weights {initializer.name} is {type_to_name[initializer.data_type]}"
  863. )
  864. def add_quant_suffix(tensor_name: str) -> str:
  865. return tensor_name + "_QuantizeLinear"
  866. def add_quant_input_suffix(tensor_name: str) -> str:
  867. return tensor_name + QUANT_INPUT_SUFFIX
  868. def add_quant_output_suffix(tensor_name) -> str:
  869. return tensor_name + "_QuantizeLinear_Output"
  870. def add_dequant_suffix(tensor_name) -> str:
  871. return tensor_name + "_DequantizeLinear"
  872. def add_dequant_input_suffix(tensor_name) -> str:
  873. return tensor_name + "_DequantizeLinear_Input"
  874. def add_dequant_output_suffix(tensor_name) -> str:
  875. return tensor_name + DEQUANT_OUTPUT_SUFFIX