qdq_quantizer.py 68 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270370470570670770870971071171271371471571671771871972072172272372472572672772872973073173273373473573673773873974074174274374474574674774874975075175275375475575675775875976076176276376476576676776876977077177277377477577677777877978078178278378478578678778878979079179279379479579679779879980080180280380480580680780880981081181281381481581681781881982082182282382482582682782882983083183283383483583683783883984084184284384484584684784884985085185285385485585685785885986086186286386486586686786886987087187287387487587687787887988088188288388488588688788888989089189289389489589689789889990090190290390490590690790890991091191291391491591691791891992092192292392492592692792892993093193293393493593693793893994094194294394494594694794894995095195295395495595695795895996096196296396496596696796896997097197297397497597697797897998098198298398498598698798898999099199299399499599699799899910001001100210031004100510061007100810091010101110121013101410151016101710181019102010211022102310241025102610271028102910301031103210331034103510361037103810391040104110421043104410451046104710481049105010511052105310541055105610571058105910601061106210631064106510661067106810691070107110721073107410751076107710781079108010811082108310841085108610871088108910901091109210931094109510961097109810991100110111021103110411051106110711081109111011111112111311141115111611171118111911201121112211231124112511261127112811291130113111321133113411351136113711381139114011411142114311441145114611471148114911501151115211531154115511561157115811591160116111621163116411651166116711681169117011711172117311741175117611771178117911801181118211831184118511861187118811891190119111921193119411951196119711981199120012011202120312041205120612071208120912101211121212131214121512161217121812191220122112221223122412251226122712281229123012311232123312341235123612371238123912401241124212431244124512461247124812491250125112521253125412551256125712581259126012611262126312641265126612671268126912701271127212731274127512761277127812791280128112821283128412851286128712881289129012911292129312941295129612971298129913001301130213031304130513061307130813091310131113121313131413151316131713181319132013211322132313241325132613271328132913301331133213331334133513361337133813391340134113421343134413451346134713481349135013511352135313541355135613571358135913601361136213631364136513661367136813691370137113721373137413751376137713781379138013811382138313841385138613871388138913901391139213931394139513961397139813991400140114021403140414051406140714081409141014111412141314141415141614171418141914201421142214231424142514261427142814291430143114321433143414351436143714381439144014411442144314441445144614471448144914501451145214531454145514561457145814591460146114621463146414651466146714681469147014711472147314741475147614771478
  1. # -------------------------------------------------------------------------
  2. # Copyright (c) Microsoft Corporation. All rights reserved.
  3. # Licensed under the MIT License. See License.txt in the project root for
  4. # license information.
  5. # --------------------------------------------------------------------------
  6. from __future__ import annotations
  7. import logging
  8. from dataclasses import dataclass
  9. from enum import Enum
  10. from typing import Any
  11. import numpy as np
  12. import onnx
  13. import onnx.numpy_helper
  14. from onnx import TensorProto
  15. from onnx import onnx_pb as onnx_proto
  16. from .base_quantizer import BaseQuantizer, QuantizationParams
  17. from .calibrate import TensorData
  18. from .quant_utils import (
  19. DEQUANT_OP_NAME,
  20. ONNX_TYPE_TO_NP_TYPE,
  21. QUANT_OP_NAME,
  22. QuantizedValue,
  23. QuantizedValueType,
  24. __producer__,
  25. __version__,
  26. add_dequant_output_suffix,
  27. add_dequant_suffix,
  28. add_quant_input_suffix,
  29. add_quant_output_suffix,
  30. add_quant_suffix,
  31. compute_data_quant_params,
  32. compute_scale_zp,
  33. compute_scale_zp_float8,
  34. find_by_name,
  35. get_qmin_qmax_for_qType,
  36. ms_domain,
  37. normalize_axis,
  38. quantize_onnx_initializer,
  39. tensor_proto_to_array,
  40. )
  41. from .registry import CreateQDQQuantizer
  42. class QDQQuantTensorType(Enum):
  43. ACTIVATION = 0
  44. WEIGHT = 1
  45. BIAS = 2
  46. # Holds the name of the node input from which a node output will share the
  47. # same quantization param initializers (zero-point and scale initializers).
  48. # Ex: A Transpose node's output will use the same quant param initializers used at the input.
  49. @dataclass
  50. class QDQQuantParamProvider:
  51. input_name: str
  52. node_name: str
  53. # Holds information for tensors that have been marked for quantization by operator quantizers.
  54. # Does not hold information for bias tensors.
  55. class QDQTensorQuantInfo:
  56. def __init__(self, tensor_type=QDQQuantTensorType.ACTIVATION, quant_para_provider=None, axis=None, data_type=None):
  57. self.tensor_type = tensor_type
  58. self.quant_para_provider = quant_para_provider
  59. self.axis = axis
  60. self.is_shared = quant_para_provider is not None
  61. assert data_type is not None
  62. self.data_type = data_type
  63. # Holds information for bias tensors that have been marked for quantization by operator quantizers.
  64. @dataclass
  65. class QDQBiasQuantInfo:
  66. node_name: str
  67. input_name: str
  68. weight_name: str
  69. beta: float
  70. # Holds quantization parameter values (scale, zp) for a tensor.
  71. # A tensor typically has a one set of quantization parameters, unless the tensor is
  72. # at a "mixed-precision" boundary where the activation quantization type changes (e.g., from uint8 to uint16).
  73. @dataclass
  74. class QDQTensorQuantParams:
  75. original: QuantizationParams # Generated by producer node.
  76. converted: QuantizationParams | None # Converted type consumed by some (or all/none) consumer nodes.
  77. converted_recv_nodes: set[str] | None # The name of nodes that consume the converted type.
  78. def get_for_consumer(self, consumer_node_name) -> QuantizationParams:
  79. if self.converted is None: # Quantized value is not converted, return original
  80. return self.original
  81. if self.converted_recv_nodes is None: # All consumers receive the converted value
  82. return self.converted
  83. # Check if consumer node name is in the list of nodes that
  84. # receive the converted quantization value. If not, return the original value generated
  85. # by the tensor's producer.
  86. return self.converted if (consumer_node_name in self.converted_recv_nodes) else self.original
  87. # Holds scale and zero_point initializer TensorProtos.
  88. @dataclass
  89. class QDQScaleZpInitializers:
  90. scale: TensorProto
  91. zero_point: TensorProto
  92. # Holds all scale and zero-point initializers for a tensor.
  93. # A tensor typically has a one set of quantization parameters, unless the tensor is
  94. # at a "mixed-precision" boundary where the activation quantization type changes (e.g., from uint8 to uint16).
  95. @dataclass
  96. class QDQTensorScaleZpInitializers:
  97. original: QDQScaleZpInitializers
  98. converted: QDQScaleZpInitializers | None
  99. converted_recv_nodes: set[str] | None
  100. # Holds cached information of a tensor's quantized values (types, zp/scale initializer names, etc.).
  101. # A tensor typically has a one set of quantization parameters, unless the tensor is
  102. # at a "mixed-precision" boundary where the activation quantization type changes (e.g., from uint8 to uint16).
  103. @dataclass
  104. class QDQTensorQuantizedValue:
  105. original: QuantizedValue
  106. converted: QuantizedValue | None
  107. converted_recv_nodes: set[str] | None
  108. def get_for_consumer(self, consumer_node_name) -> QuantizedValue:
  109. if self.converted is None: # Quantized value is not converted, return original
  110. return self.original
  111. if self.converted_recv_nodes is None: # All consumers receive the converted value
  112. return self.converted
  113. # Check if consumer node name is in the list of nodes that
  114. # receive the converted quantization value. If not, return the original value generated
  115. # by the tensor's producer.
  116. return self.converted if (consumer_node_name in self.converted_recv_nodes) else self.original
  117. class QDQQuantizer(BaseQuantizer):
  118. def __init__(
  119. self,
  120. model,
  121. per_channel,
  122. reduce_range,
  123. weight_qType,
  124. activation_qType,
  125. tensors_range,
  126. nodes_to_quantize,
  127. nodes_to_exclude,
  128. op_types_to_quantize,
  129. extra_options=None,
  130. ):
  131. BaseQuantizer.__init__(
  132. self,
  133. model,
  134. per_channel,
  135. reduce_range,
  136. weight_qType,
  137. activation_qType,
  138. tensors_range,
  139. nodes_to_quantize,
  140. nodes_to_exclude,
  141. op_types_to_quantize,
  142. extra_options,
  143. )
  144. self.tensors_to_quantize: dict[str, QDQTensorQuantInfo] = {}
  145. self.bias_to_quantize: dict[str, QDQBiasQuantInfo] = {}
  146. self.nodes_to_remove = []
  147. # Specific op types to exclude qdq quantization for their outputs.
  148. # In TRT, it's not recommended to quantize outputs for weighted ops such as Conv, Matmul, Gemm
  149. # because those ops may be followed by nodes that require high resolution inputs.
  150. # Adding QDQ for those ops' output may end up with worse accuracy.
  151. # So, we don't recommend to add QDQ to node's output under such condition.
  152. self.op_types_to_exclude_output_quantization = extra_options.get("OpTypesToExcludeOutputQuantization", [])
  153. # We do quantization on Dequantizelinear's input to remove Quantizelinear for weight as an optimization.
  154. # In some cases, for example QDQ BERT model for TensorRT, QDQ should always appear as a pair.
  155. # Therefore, we need to disable this optimization and add qdq pair to weight.
  156. self.add_qdq_pair_to_weight = extra_options.get("AddQDQPairToWeight", False)
  157. # Some scenarios do not need the bias quantized. For example, in the case of Quantization Aware Training,
  158. # quantizing the bias is not needed. This is because in QAT, all model parameters are expected to be in
  159. # floating point format. To that end, we can use the FakeQuant operator for weights and activations that
  160. # can always have QDQ pairs (by using AddQDQPairToWeight). But for biases in a quantized model, we can't use
  161. # FakeQuant because it only ever appears before a DQ (since it is quantized as int32).
  162. self.quantize_bias = extra_options.get("QuantizeBias", True)
  163. # The default behavior is that multiple nodes can share a QDQ pair as their inputs.
  164. # In TRT, QDQ pair can`t be shared between nodes, so it will create dedicated QDQ pairs for each node.
  165. self.dedicated_qdq_pair = extra_options.get("DedicatedQDQPair", False)
  166. self.tensor_to_its_receiving_nodes: dict[str, list[onnx.NodeProto]] = {}
  167. # Maps a tensor to the DequantizeLinear node (in the original input model) that outputs the tensor.
  168. # Populated for input models with some pre-quantized weights (typically via a different tool).
  169. self.tensor_to_producing_dq: dict[str, onnx.NodeProto] = {}
  170. # Let user set channel axis for specific op type and it's effective only when per channel quantization is supported and per_channel is True.
  171. self.qdq_op_type_per_channel_support_to_axis = extra_options.get("QDQOpTypePerChannelSupportToAxis", {})
  172. self.qdq_op_domain = ms_domain if extra_options.get("UseQDQContribOps", False) else None
  173. # User can specify if removable activations, like Clip/Relu, should be kept in the graph.
  174. # Used in the QDQRemovableActivation class.
  175. self.qdq_keep_removable_activations = extra_options.get("QDQKeepRemovableActivations", False)
  176. # Let user disable adjustment of weight scales for bias inputs that are quantized to int32.
  177. self.qdq_disable_weight_adjust_for_int32_bias = extra_options.get("QDQDisableWeightAdjustForInt32Bias", False)
  178. # The ONNX spec did not support 16-bit Q/DQ ops before opset 21.
  179. # So, may have to override the Q/DQ op domain to 'com.microsoft' if the activation or weight types
  180. # are 16-bit or 4-bit integers.
  181. if self.opset_version < 21:
  182. opset21_types = (TensorProto.UINT16, TensorProto.INT16, TensorProto.UINT4, TensorProto.INT4)
  183. overrides_have_opset21_types = any(
  184. t.tensor_type in opset21_types for t in self.tensor_quant_override_qtypes
  185. )
  186. if not self.qdq_op_domain and (
  187. self.activation_qType in opset21_types
  188. or self.weight_qType in opset21_types
  189. or overrides_have_opset21_types
  190. ):
  191. logging.warning(
  192. "ONNX QuantizeLinear and DequantizeLinear operators do not support "
  193. "16-bit/4-bit integer quantization types prior to opset 21. "
  194. f"The domain of QuantizeLinear and DequantizeLinear operators will be set to '{ms_domain}' to "
  195. "enable support."
  196. )
  197. self.qdq_op_domain = ms_domain
  198. self.quantization_params = self.calc_graph_quant_params()
  199. self.initializer_quant_params: dict[str, QuantizationParams] = {}
  200. # Map of all original value names to quantized value names
  201. self.quantized_value_map = {}
  202. def _get_tensor_type(self, tensor_name):
  203. """
  204. Check if tensor can be quantized
  205. """
  206. weight = find_by_name(tensor_name, self.model.initializer())
  207. if weight is not None:
  208. return weight.data_type
  209. elif tensor_name in self.value_infos:
  210. vi = self.value_infos[tensor_name]
  211. if vi.type.HasField("tensor_type"):
  212. return vi.type.tensor_type.elem_type
  213. return None
  214. def _is_tensor_quantizable(self, tensor_name):
  215. """
  216. Check if tensor can be quantized
  217. """
  218. weight = find_by_name(tensor_name, self.model.initializer())
  219. if weight is not None:
  220. if weight.data_type in (onnx_proto.TensorProto.FLOAT, onnx_proto.TensorProto.FLOAT16):
  221. return True
  222. elif tensor_name in self.value_infos:
  223. vi = self.value_infos[tensor_name]
  224. if vi.type.HasField("tensor_type") and vi.type.tensor_type.elem_type in (
  225. TensorProto.FLOAT,
  226. TensorProto.FLOAT16,
  227. ):
  228. return True
  229. else:
  230. logging.warning(
  231. f"failed to infer the type of tensor: {tensor_name}. Skip to quantize it. Please check if it is expected."
  232. )
  233. return False
  234. def __quantize_tensor(self, tensor_name, quant_sharing_provider=None, tensor_type=QDQQuantTensorType.ACTIVATION):
  235. """
  236. Adds a tensor to the list (actually a dict) of tensors to quantize. Called indirectly by op quantizers that
  237. want to quantize a tensor (i.e., "mark" a tensor for quantization).
  238. If quant_sharing_provider is not None, tensor with name tensor_name will be quantized with the same
  239. quantization parameters as the node input specified in quant_sharing_provider. Ex: A Tranpose node's output
  240. will typically use the same quantization parameter initializers used at the Transpose node's input.
  241. Args:
  242. tensor_name: name of the tensor to quantize
  243. quant_sharing_provider: name of the tensor and node that provides quantization parameter
  244. tensor_type: QDQQuantTensorType default ACTIVATION
  245. """
  246. if self._is_tensor_quantizable(tensor_name):
  247. if quant_sharing_provider:
  248. if not isinstance(quant_sharing_provider, QDQQuantParamProvider):
  249. raise TypeError(
  250. f"quant_sharing_provider must be of type QDQQuantParamProvider, not {type(quant_sharing_provider)}."
  251. )
  252. data_type = self._get_tensor_type(tensor_name)
  253. self.tensors_to_quantize[tensor_name] = QDQTensorQuantInfo(
  254. tensor_type=tensor_type, quant_para_provider=quant_sharing_provider, data_type=data_type
  255. )
  256. elif tensor_name not in self.tensors_to_quantize:
  257. data_type = self._get_tensor_type(tensor_name)
  258. self.tensors_to_quantize[tensor_name] = QDQTensorQuantInfo(tensor_type=tensor_type, data_type=data_type)
  259. def quantize_activation_tensor(self, tensor_name: str):
  260. """
  261. Adds a tensor to the list of tensors to quantize. Called by op quantizers that
  262. want to quantize a tensor (i.e., "mark" a tensor for quantization).
  263. Args:
  264. tensor_name: name of the tensor to quantize
  265. """
  266. return self.__quantize_tensor(tensor_name, None, QDQQuantTensorType.ACTIVATION)
  267. def quantize_output_same_as_input(self, output_name: str, input_name: str, node_name: str):
  268. """
  269. Adds a tensor to the list of tensors to quantize. Called by op quantizers that
  270. want to quantize an output tensor using the same quantization parameters as one of the node's inputs.
  271. Ex: A Tranpose node's output will typically use the same quantization parameter initializers used at
  272. the Transpose node's input.
  273. Args:
  274. output_name: name of the node output to quantize so that it uses the same quantization params as an input.
  275. input_name: name of the node input from which the output tensor will get its quantization params.
  276. node_name: name of the node that consumes `input_name`.
  277. """
  278. return self.__quantize_tensor(
  279. output_name, QDQQuantParamProvider(input_name, node_name), QDQQuantTensorType.ACTIVATION
  280. )
  281. def quantize_weight_tensor(self, tensor_name: str):
  282. """
  283. Adds a tensor to the list of weight tensors to quantize. Called by op quantizers that
  284. want to quantize a weight (i.e., "mark" a weight for quantization).
  285. Args:
  286. tensor_name: name of the weight to quantize
  287. """
  288. return self.__quantize_tensor(tensor_name, None, QDQQuantTensorType.WEIGHT)
  289. def quantize_weight_tensor_per_channel(self, tensor_name, axis):
  290. weight = find_by_name(tensor_name, self.model.initializer())
  291. if weight:
  292. if weight.data_type in (onnx_proto.TensorProto.FLOAT, onnx_proto.TensorProto.FLOAT16):
  293. self.tensors_to_quantize[tensor_name] = QDQTensorQuantInfo(
  294. tensor_type=QDQQuantTensorType.WEIGHT, axis=axis, data_type=weight.data_type
  295. )
  296. else:
  297. logging.warning(f"only support per-channel quantization on weight. Tensor: {tensor_name} is not quantized.")
  298. def _dup_initializer(self, initializer: onnx.TensorProto) -> onnx.TensorProto:
  299. """
  300. Duplicates an existing initializer and adds it to the model. Returns the new initializer.
  301. """
  302. name_suffix: int = self.model.get_largest_initializer_name_suffix(initializer.name) + 1
  303. new_initializer_name = f"{initializer.name}{name_suffix}"
  304. new_initializer = onnx.TensorProto()
  305. new_initializer.CopyFrom(initializer)
  306. new_initializer.name = new_initializer_name
  307. self.model.add_initializer(new_initializer)
  308. return new_initializer
  309. def quantize_bias_tensor(self, node_name, bias_name, input_name, weight_name, beta=1.0):
  310. """
  311. Adds a bias tensor to the list of bias tensors to quantize. Called by op quantizers that
  312. want to quantize a bias with bias_zero_point = 0 and bias_scale = input_scale * weight_scale * beta.
  313. TODO: Explain the reasoning for using this formula.
  314. Args:
  315. node_name: name of the node that consumes the bias, input, and weight tensors.
  316. bias_name: name of the bias tensor to quantize.
  317. input_name: name of the input tensor whose scale is used to compute the bias's scale.
  318. weight_name: name of the weight tensor whose scale is used to compute the bias's scale.
  319. beta: Multiplier used to compute the bias's scale.
  320. """
  321. # If the user provided quantization overrides for this tensor, treat it as a regular weight.
  322. if self.tensor_quant_overrides.get(bias_name):
  323. logging.info(
  324. f"Quantizing bias tensor '{bias_name}' as a weight due to the presence of user-specified overrides"
  325. )
  326. is_per_channel, axis = self.is_tensor_per_channel(bias_name, default_axis=0)
  327. if is_per_channel:
  328. self.quantize_weight_tensor_per_channel(bias_name, axis)
  329. else:
  330. self.quantize_weight_tensor(bias_name)
  331. return
  332. bias_initializer = find_by_name(bias_name, self.model.initializer())
  333. if bias_initializer is None:
  334. logging.warning(f"Expected bias '{bias_name}' to be an initializer")
  335. return
  336. if bias_initializer.data_type not in (onnx_proto.TensorProto.FLOAT, onnx_proto.TensorProto.FLOAT16):
  337. logging.info(f"Expected bias '{bias_name}' to be an floating-point initializer")
  338. return
  339. actual_bias_name = bias_name
  340. if bias_name in self.bias_to_quantize:
  341. # This bias input is consumed by two different nodes. We need to duplicate the bias so that
  342. # each node has its own bias input. This is necessary because the bias's scale is computed
  343. # from the node's other input scales.
  344. new_bias_initializer = self._dup_initializer(bias_initializer)
  345. actual_bias_name = new_bias_initializer.name
  346. # Replace this node's bias input
  347. self.model.replace_input_of_nodes(bias_name, actual_bias_name, {node_name})
  348. logging.info(f"Created a copy of bias input '{bias_name}' called '{actual_bias_name}'")
  349. # Add this to our list of biases to quantize.
  350. self.bias_to_quantize[actual_bias_name] = QDQBiasQuantInfo(node_name, input_name, weight_name, beta)
  351. def _adjust_weight_scale_for_int32_bias(
  352. self,
  353. input_scale: np.ndarray,
  354. weight_scale: np.ndarray,
  355. weight_name: str,
  356. bias_tp: onnx.TensorProto,
  357. is_per_channel: bool,
  358. ) -> tuple[bool, np.ndarray | None]:
  359. """
  360. Checks if the bias scale (input_scale * weight_scale) that we intend to use is too small.
  361. A bias scale that is too small leads to quantized bias values that fall outside the range of a int32 and have to
  362. be clipped, which decreases accuracy. If this function detects such a scenario, the weight_scale value will be
  363. increased to prevent this from happening.
  364. Although the adjustment method and amount differs, the idea to adjust the weight's scale came from the following
  365. reference:
  366. https://github.com/tensorflow/tensorflow/blob/master/tensorflow/lite/tools/optimize/quantization_utils.cc#L252
  367. :param input_scale: The input's scale.
  368. :param weight_scale: The weight scale to potentially adjust.
  369. :param weight_name: The weight initializer's name. Used for logging.
  370. :param bias_tp: The bias ONNX initializer.
  371. :param is_per_channel: True if the bias and weight are quantized per-channel.
  372. :return: A tuple with a bool indicating if the weight's scale was adjusted and the new weight scale.
  373. """
  374. if not weight_scale.size:
  375. return False, None
  376. bias_float_data = tensor_proto_to_array(bias_tp)
  377. int32_info = np.iinfo(np.int32)
  378. multiplicative_epsilon = 1.0001
  379. qrange = np.array(int32_info.max, dtype=np.float64) - np.array(int32_info.min + 1, dtype=np.float64)
  380. weight_scale_dtype = weight_scale.dtype
  381. updated_an_elem = False
  382. if not is_per_channel:
  383. rmin = np.minimum(bias_float_data.min(), np.array(0, dtype=np.float64))
  384. rmax = np.maximum(bias_float_data.max(), np.array(0, dtype=np.float64))
  385. absmax = np.maximum(np.abs(rmin), np.abs(rmax))
  386. bias_smallest_valid_scale = multiplicative_epsilon * (2.0 * absmax) / qrange
  387. input_scale_fp64 = np.array(input_scale.item(), dtype=np.float64)
  388. weight_scale_fp64 = np.array(weight_scale.item(), dtype=np.float64)
  389. bias_candidate_scale = input_scale_fp64 * weight_scale_fp64
  390. if (bias_candidate_scale < bias_smallest_valid_scale) and (bias_candidate_scale > 0.0):
  391. # The candidate bias scale would be too small, so increase the weight_scale by the necessary ratio.
  392. ratio = bias_smallest_valid_scale / bias_candidate_scale
  393. logging.info(
  394. f"Increasing scale for weight `{weight_name}` by the ratio {ratio} to "
  395. f"ensure bias input `{bias_tp.name}` has a valid scale."
  396. )
  397. new_scale = weight_scale_fp64 * ratio
  398. weight_scale = new_scale.astype(weight_scale_dtype)
  399. updated_an_elem = True
  400. elif weight_scale.shape and len(weight_scale.shape) == 1:
  401. # per-channel case
  402. num_elems = weight_scale.shape[0]
  403. for i in range(num_elems):
  404. bias_rmax = np.abs(bias_float_data[i])
  405. bias_smallest_valid_scale = multiplicative_epsilon * (2.0 * bias_rmax) / qrange
  406. input_scale_fp64 = np.array(input_scale.item(), dtype=np.float64)
  407. weight_scale_fp64 = np.array(weight_scale[i].item(), dtype=np.float64)
  408. bias_candidate_scale = input_scale_fp64 * weight_scale_fp64
  409. if (bias_candidate_scale < bias_smallest_valid_scale) and (bias_candidate_scale > 0.0):
  410. # The candidate bias scale would be too small, so increase the weight_scale by the necessary ratio.
  411. ratio = bias_smallest_valid_scale / bias_candidate_scale
  412. logging.info(
  413. f"Increased scale[{i}] for weight `{weight_name}` by ratio {ratio} "
  414. f"to ensure bias input `{bias_tp.name}` has a valid scale."
  415. )
  416. new_scale = weight_scale_fp64 * ratio
  417. weight_scale[i] = new_scale.astype(weight_scale_dtype)
  418. updated_an_elem = True
  419. return updated_an_elem, weight_scale
  420. def _adjust_weight_quant_params_for_bias_tensors(self):
  421. """
  422. Iterates through all bias inputs that should be quantized to int32. If the intended
  423. bias scale (equal to input_scale * weight_scale) is too small, this function will increase
  424. the associated weight's scale to ensure the bias does not overflow the int32 range when quantized.
  425. """
  426. if self.qdq_disable_weight_adjust_for_int32_bias:
  427. # User passed an extra_option to disable this adjustment.
  428. return
  429. for bias_name, bias_info in self.bias_to_quantize.items():
  430. if (
  431. bias_info.input_name not in self.quantization_params
  432. or bias_info.input_name not in self.tensors_to_quantize
  433. or bias_info.weight_name not in self.initializer_quant_params
  434. ):
  435. continue
  436. # Get the associated input's scale.
  437. input_qparams = self.quantization_params[bias_info.input_name].get_for_consumer(bias_info.node_name)
  438. input_info = self.tensors_to_quantize[bias_info.input_name]
  439. input_scale = np.asarray(
  440. input_qparams["scale"], dtype=onnx.helper.tensor_dtype_to_np_dtype(input_info.data_type)
  441. )
  442. weight_quant_params = self.initializer_quant_params[bias_info.weight_name]
  443. weight_quant_type = weight_quant_params["quant_type"]
  444. if weight_quant_type not in (onnx.TensorProto.INT8, onnx.TensorProto.INT16):
  445. continue
  446. weight_zero_point: np.ndarray = weight_quant_params["zero_point"]
  447. if weight_zero_point.any():
  448. # Skip if zero_point(s) are not all zero (i.e., symmetric quant)
  449. continue
  450. weight_scale: np.ndarray = weight_quant_params["scale"]
  451. is_per_channel = weight_quant_params.get("axis", None) is not None
  452. # Get adjusted weight scales.
  453. did_update_weight_scale, new_weight_scale = self._adjust_weight_scale_for_int32_bias(
  454. input_scale,
  455. weight_scale,
  456. bias_info.weight_name,
  457. find_by_name(bias_name, self.model.initializer()),
  458. is_per_channel,
  459. )
  460. if did_update_weight_scale:
  461. weight_quant_params["scale"] = new_weight_scale
  462. def remove_node(self, node):
  463. self.nodes_to_remove.append(node)
  464. def remove_nodes(self):
  465. self.model.remove_nodes(self.nodes_to_remove)
  466. def quantize_model(self):
  467. for node in self.model.nodes():
  468. if self.should_quantize_node(node):
  469. op_quantizer = CreateQDQQuantizer(self, node)
  470. op_quantizer.quantize()
  471. for tensor_name in node.input:
  472. if tensor_name not in self.tensor_to_its_receiving_nodes:
  473. self.tensor_to_its_receiving_nodes[tensor_name] = []
  474. self.tensor_to_its_receiving_nodes[tensor_name].append(node)
  475. if node.op_type == DEQUANT_OP_NAME:
  476. for tensor_name in node.output:
  477. self.tensor_to_producing_dq[tensor_name] = node
  478. self.initializer_quant_params = self._calc_initializer_quant_params()
  479. self._adjust_weight_quant_params_for_bias_tensors()
  480. self._quantize_normal_tensors()
  481. self._quantize_sharing_param_tensors()
  482. if self.quantize_bias:
  483. self._quantize_bias_tensors()
  484. self.remove_nodes()
  485. if not self.add_qdq_pair_to_weight:
  486. self.model.clean_initializers()
  487. self.model.model.producer_name = __producer__
  488. self.model.model.producer_version = __version__
  489. if self.qdq_op_domain == ms_domain:
  490. self.model.set_opset_import(ms_domain, 1)
  491. return self.model.model
  492. def try_replacing_upstream_output(self, upstream_output_name, output_name):
  493. if (
  494. output_name in self.quantization_params
  495. and self.quantization_params[output_name].converted is None
  496. and self.quantization_params[upstream_output_name].converted is None
  497. and len(self.model.input_name_to_nodes()[upstream_output_name]) == 1
  498. and not self.model.is_graph_output(upstream_output_name)
  499. and not self.model.is_graph_input(upstream_output_name)
  500. ):
  501. self.model.replace_output_of_all_nodes(upstream_output_name, output_name)
  502. if upstream_output_name in self.tensors_to_quantize:
  503. del self.tensors_to_quantize[upstream_output_name]
  504. return True
  505. return False
  506. def _create_q_node(
  507. self,
  508. q_input: str,
  509. q_output: str,
  510. quant_node_name: str,
  511. scale_name: str,
  512. zp_name: str,
  513. axis: int | None = None,
  514. ):
  515. """
  516. Creates a QuantizeLinear node and adds it to the model.
  517. """
  518. qlinear_node = onnx.helper.make_node(
  519. QUANT_OP_NAME,
  520. [q_input, scale_name, zp_name],
  521. [q_output],
  522. quant_node_name,
  523. axis=axis,
  524. domain=self.qdq_op_domain,
  525. )
  526. self.model.add_nodes([qlinear_node])
  527. def _create_dq_node(
  528. self,
  529. dq_input: str,
  530. dq_output: str,
  531. dequant_node_name: str,
  532. scale_name: str,
  533. zp_name: str,
  534. axis: int | None = None,
  535. ):
  536. """
  537. Creates a DequantizeLinear node and adds it to the model.
  538. """
  539. dequant_node = onnx.helper.make_node(
  540. DEQUANT_OP_NAME,
  541. [dq_input, scale_name, zp_name],
  542. [dq_output],
  543. dequant_node_name,
  544. axis=axis,
  545. domain=self.qdq_op_domain,
  546. )
  547. self.model.add_nodes([dequant_node])
  548. def _create_qdq_nodes(
  549. self, q_input, q_output, quant_node_name, dq_input, dq_output, dequant_node_name, scale_name, zp_name, axis=None
  550. ):
  551. qlinear_node = onnx.helper.make_node(
  552. QUANT_OP_NAME,
  553. [q_input, scale_name, zp_name],
  554. [q_output],
  555. quant_node_name,
  556. axis=axis,
  557. domain=self.qdq_op_domain,
  558. )
  559. dequant_node = onnx.helper.make_node(
  560. DEQUANT_OP_NAME,
  561. [dq_input, scale_name, zp_name],
  562. [dq_output],
  563. dequant_node_name,
  564. axis=axis,
  565. domain=self.qdq_op_domain,
  566. )
  567. self.model.add_nodes([qlinear_node, dequant_node])
  568. def _add_qdq_nodes_for_initializer(self, weight_proto: onnx.TensorProto):
  569. """
  570. Adds Q/DQ nodes for an initializer. If `self.add_qdq_pair_to_weight` is true, creates
  571. the sequence (weight_f32 -> Q -> DQ -> ). Otherwise, this function quantizes the initializer
  572. and adds the sequence (weight_quant -> DQ ->).
  573. """
  574. weight_name = weight_proto.name
  575. if weight_name in self.quantized_value_map:
  576. return
  577. quant_params: QuantizationParams = self.initializer_quant_params[weight_name]
  578. axis: int = quant_params.get("axis")
  579. scale_zp_initializers = self._make_scale_zp_initializers(weight_name, quant_params)
  580. q_weight_name: str | None = None
  581. weight_dequant_output = add_dequant_output_suffix(weight_name)
  582. self.model.replace_input_of_all_nodes(weight_name, weight_dequant_output)
  583. if self.add_qdq_pair_to_weight:
  584. # Don't actually quantize the weight. Instead, keep floating-point weight and create the node
  585. # sequence (weight_f32 -> Q -> DQ -> weight_dequant)
  586. weight_quant_output = add_quant_output_suffix(weight_name)
  587. self._create_qdq_nodes(
  588. weight_name,
  589. weight_quant_output,
  590. add_quant_suffix(weight_name),
  591. weight_quant_output,
  592. weight_dequant_output,
  593. add_dequant_suffix(weight_name),
  594. scale_zp_initializers.scale.name,
  595. scale_zp_initializers.zero_point.name,
  596. axis,
  597. )
  598. else:
  599. # Quantize the weight and create the node sequence:
  600. # (weight_quantized -> DQ -> weight_dequant)
  601. quant_weight = quantize_onnx_initializer(
  602. weight_proto,
  603. quant_params["quant_type"],
  604. quant_params["zero_point"],
  605. quant_params["scale"],
  606. axis,
  607. )
  608. self.model.add_initializer(quant_weight)
  609. q_weight_name = quant_weight.name
  610. dequant_node = onnx.helper.make_node(
  611. DEQUANT_OP_NAME,
  612. [quant_weight.name, scale_zp_initializers.scale.name, scale_zp_initializers.zero_point.name],
  613. [weight_dequant_output],
  614. add_dequant_suffix(weight_name),
  615. axis=axis,
  616. domain=self.qdq_op_domain,
  617. )
  618. self.model.add_node(dequant_node)
  619. # Log entry for this quantized weight
  620. quantized_value = QuantizedValue(
  621. weight_name,
  622. q_weight_name,
  623. scale_zp_initializers.scale.name,
  624. scale_zp_initializers.zero_point.name,
  625. QuantizedValueType.Initializer,
  626. axis=axis,
  627. )
  628. self.quantized_value_map[weight_name] = QDQTensorQuantizedValue(quantized_value, None, None)
  629. def _add_qdq_pair_for_activation(self, tensor_name, scale_name, zp_name, data_type=None):
  630. if (
  631. self.dedicated_qdq_pair
  632. and tensor_name in self.tensor_to_its_receiving_nodes
  633. and len(self.tensor_to_its_receiving_nodes[tensor_name]) > 1
  634. ):
  635. num_dedicated_qdq_pair = len(self.tensor_to_its_receiving_nodes[tensor_name])
  636. for i in range(num_dedicated_qdq_pair):
  637. postfix = f"_{i + 1}"
  638. tensor_name_quant_output_postfix = add_quant_output_suffix(tensor_name) + postfix
  639. tensor_name_dequant_output_postfix = add_dequant_output_suffix(tensor_name) + postfix
  640. quant_node_name_postfix = add_quant_suffix(tensor_name) + postfix
  641. dequant_node_name_postfix = add_dequant_suffix(tensor_name) + postfix
  642. self._create_qdq_nodes(
  643. tensor_name,
  644. tensor_name_quant_output_postfix,
  645. quant_node_name_postfix,
  646. tensor_name_quant_output_postfix,
  647. tensor_name_dequant_output_postfix,
  648. dequant_node_name_postfix,
  649. scale_name,
  650. zp_name,
  651. )
  652. node = self.tensor_to_its_receiving_nodes[tensor_name][i]
  653. self.model.replace_node_input(node, tensor_name, tensor_name_dequant_output_postfix)
  654. if i == 0:
  655. quantized_value = QuantizedValue(
  656. tensor_name,
  657. tensor_name_dequant_output_postfix,
  658. scale_name,
  659. zp_name,
  660. QuantizedValueType.Input,
  661. scale_type=data_type,
  662. )
  663. self.quantized_value_map[tensor_name] = QDQTensorQuantizedValue(quantized_value, None, None)
  664. else:
  665. q_input = tensor_name
  666. dq_output = add_dequant_output_suffix(tensor_name)
  667. if self.model.is_graph_output(tensor_name):
  668. q_input = add_quant_input_suffix(tensor_name)
  669. dq_output = tensor_name
  670. self.model.replace_output_of_all_nodes(tensor_name, q_input)
  671. else:
  672. self.model.replace_input_of_all_nodes(tensor_name, dq_output)
  673. self._create_qdq_nodes(
  674. q_input,
  675. add_quant_output_suffix(tensor_name),
  676. add_quant_suffix(tensor_name),
  677. add_quant_output_suffix(tensor_name),
  678. dq_output,
  679. add_dequant_suffix(tensor_name),
  680. scale_name,
  681. zp_name,
  682. )
  683. quantized_value = QuantizedValue(
  684. tensor_name,
  685. dq_output,
  686. scale_name,
  687. zp_name,
  688. QuantizedValueType.Input,
  689. scale_type=data_type,
  690. )
  691. self.quantized_value_map[tensor_name] = QDQTensorQuantizedValue(quantized_value, None, None)
  692. def _add_qdq_ops_for_converted_activation(
  693. self,
  694. tensor_name,
  695. first_scale_name,
  696. first_zp_name,
  697. scale_data_type,
  698. convert_scale_name,
  699. convert_zp_name,
  700. convert_recv_nodes,
  701. ):
  702. """
  703. Adds Q and DQ ops to a tensor whose quantized data type is converted. That is, some consumers may use the
  704. original data type from the producer, while other consumers use the converted data type.
  705. This is generally done by adding a sequence of ops that convert from one data type (e.g., uint8) to another (e.g., uint16).
  706. T_float ---> Quant(to u8) ---> Convert(to u16) ---> Dequant(to float) ---> T_float'
  707. where Convert(to u16) is equivalent to: ---> Dequant(to float) ---> Quant(to u16) --->
  708. This function handles the following scenarios:
  709. 1) Tensor T is not a graph output; all consumers use the converted type
  710. <Producer> ---> Q1 ---> DQ1 ---> Q2 ---> DQ2 ---> <Consumers>
  711. 2) Tensor T is not a graph output; some consumers use the original type, others use the converted type
  712. <Producer> ---> Q1 -+-> DQ1 ---> <Consumers of original type>
  713. |
  714. +-> DQ1' ---> Q2 ---> DQ2 ---> <Consumers of converted type>
  715. 3) Tensor T is a graph output; all consumers use the converted type
  716. <Producer> ---> Q1 ---> DQ1 ---> Q2 ---> DQ2 -+-> <Consumers>
  717. |
  718. +-> <Graph output>
  719. 4) Tensor T is a graph output; some consumers use the original type, others use the converted type
  720. <Producer> ---> Q1 -+-> DQ1 -+-> <Consumers of original type>
  721. | |
  722. | +-> <Graph output>
  723. |
  724. +-> DQ1' ---> Q2 ---> DQ2 ---> <Consumers of converted type>
  725. 5) Tensor T is a graph output that is not consumed by any other nodes.
  726. <Producer> ---> Q1 ---> DQ1 ---> Q2 ---> DQ2 ---> <Graph output>
  727. """
  728. tensor_recv_nodes = {node.name for node in self.tensor_to_its_receiving_nodes.get(tensor_name, [])}
  729. if (
  730. self.dedicated_qdq_pair
  731. and tensor_name in self.tensor_to_its_receiving_nodes
  732. and len(self.tensor_to_its_receiving_nodes[tensor_name]) > 1
  733. ):
  734. # TODO: Add support for dedicated_qdq_pair if/when needed.
  735. raise ValueError(
  736. "Do not currently support converted quant_types in TensorQuantOverrides when the `dedicated_qdq_pair` extra_option is enabled"
  737. )
  738. # Determine which nodes consume the original quantized type and which nodes
  739. # consume the converted quantized type.
  740. original_recv_nodes = tensor_recv_nodes
  741. if convert_recv_nodes is None: # In this case, all consumers receive the converted type.
  742. convert_recv_nodes = tensor_recv_nodes
  743. original_recv_nodes = set()
  744. else:
  745. original_recv_nodes = original_recv_nodes - convert_recv_nodes
  746. all_use_converted = len(convert_recv_nodes) == len(tensor_recv_nodes)
  747. is_graph_output = self.model.is_graph_output(tensor_name)
  748. # Create first Q op.
  749. first_q_input = tensor_name
  750. if is_graph_output:
  751. first_q_input = add_quant_input_suffix(tensor_name)
  752. self.model.replace_output_of_all_nodes(tensor_name, first_q_input)
  753. first_q_output = add_quant_output_suffix(tensor_name)
  754. self._create_q_node(
  755. first_q_input, first_q_output, add_quant_suffix(tensor_name), first_scale_name, first_zp_name
  756. )
  757. # Create first DQ op.
  758. first_dq_output = add_dequant_output_suffix(tensor_name)
  759. if is_graph_output and not all_use_converted:
  760. first_dq_output = tensor_name
  761. if original_recv_nodes and first_dq_output != tensor_name:
  762. self.model.replace_input_of_nodes(tensor_name, first_dq_output, original_recv_nodes)
  763. self._create_dq_node(
  764. first_q_output, first_dq_output, add_dequant_suffix(tensor_name), first_scale_name, first_zp_name
  765. )
  766. # Create parallel clone of first DQ op if _not all_ consumers use the converted type.
  767. # --> DQ1' --> Q2 --> DQ2 --> <Consumers of converted type>
  768. #
  769. # This DQ clone would only have one consumer Q node (Q2) and could be potentially fused with
  770. # it by some EPs (e.g., QNN) without breaking other "node units".
  771. # Ex QNN fusion:
  772. # --> Convert (fused) --> DQ2 --> <Consumers of converted type>
  773. second_q_input = first_dq_output
  774. if not all_use_converted:
  775. second_q_input = add_quant_input_suffix(f"{tensor_name}_convert")
  776. self._create_dq_node(
  777. first_q_output,
  778. second_q_input,
  779. add_dequant_suffix(f"{tensor_name}_convert_clone"),
  780. first_scale_name,
  781. first_zp_name,
  782. )
  783. # Create second Q op.
  784. second_q_output = add_quant_output_suffix(f"{tensor_name}_convert")
  785. self._create_q_node(
  786. second_q_input,
  787. second_q_output,
  788. add_quant_suffix(f"{tensor_name}_convert"),
  789. convert_scale_name,
  790. convert_zp_name,
  791. )
  792. # Create second DQ op.
  793. second_dq_output = add_dequant_output_suffix(f"{tensor_name}_convert")
  794. if is_graph_output and all_use_converted:
  795. second_dq_output = tensor_name
  796. if convert_recv_nodes and second_dq_output != tensor_name:
  797. self.model.replace_input_of_nodes(tensor_name, second_dq_output, convert_recv_nodes)
  798. self._create_dq_node(
  799. second_q_output,
  800. second_dq_output,
  801. add_dequant_suffix(f"{tensor_name}_convert"),
  802. convert_scale_name,
  803. convert_zp_name,
  804. )
  805. # Store in quantized_value_map
  806. original_quantized_value = QuantizedValue(
  807. tensor_name,
  808. first_dq_output,
  809. first_scale_name,
  810. first_zp_name,
  811. QuantizedValueType.Input,
  812. scale_type=scale_data_type,
  813. )
  814. converted_quantized_value = QuantizedValue(
  815. tensor_name,
  816. second_dq_output,
  817. convert_scale_name,
  818. convert_zp_name,
  819. QuantizedValueType.Input,
  820. scale_type=scale_data_type,
  821. )
  822. self.quantized_value_map[tensor_name] = QDQTensorQuantizedValue(
  823. original_quantized_value, converted_quantized_value, convert_recv_nodes
  824. )
  825. def _quantize_normal_tensors(self):
  826. """
  827. Adds Q/DQ ops to tensors (activations and weights) that have been marked for quantization by op quantizers.
  828. """
  829. for tensor_name, tensor_info in self.tensors_to_quantize.copy().items():
  830. if tensor_name in self.quantized_value_map:
  831. continue
  832. if not tensor_info.is_shared:
  833. # Quantize the input
  834. initializer = find_by_name(tensor_name, self.model.initializer())
  835. if initializer:
  836. self._add_qdq_nodes_for_initializer(initializer)
  837. else:
  838. # Check if this tensor is already a dequantized value. If so, skip it.
  839. # This happens if the original input model already has some pre-quantized weights
  840. # generated by a different tool.
  841. # Ex: (quantized_weight -> DequantizeLinear -> this_tensor)
  842. if tensor_name in self.tensor_to_producing_dq:
  843. del self.tensors_to_quantize[tensor_name]
  844. continue
  845. tensor_qparam_initializers = self._make_tensor_scale_zp_initializers(tensor_name)
  846. if not tensor_qparam_initializers:
  847. raise ValueError(
  848. f"Quantization parameters are not specified for param {tensor_name}. "
  849. "In static mode quantization params for inputs and outputs of nodes to be quantized are required."
  850. )
  851. if tensor_qparam_initializers.converted is None:
  852. # Normal case: <producer> --> Q --> DQ --> <consumers>
  853. self._add_qdq_pair_for_activation(
  854. tensor_name,
  855. tensor_qparam_initializers.original.scale.name,
  856. tensor_qparam_initializers.original.zero_point.name,
  857. data_type=tensor_info.data_type,
  858. )
  859. else:
  860. # Conversion case: <producer> ---> Q1 -+-> DQ1 --> <consumers of original type>
  861. # |
  862. # +-> DQ1' --> Q2 --> DQ2 --> <consumers of converted type>
  863. assert tensor_info.data_type == tensor_qparam_initializers.original.scale.data_type
  864. self._add_qdq_ops_for_converted_activation(
  865. tensor_name,
  866. tensor_qparam_initializers.original.scale.name,
  867. tensor_qparam_initializers.original.zero_point.name,
  868. tensor_info.data_type,
  869. tensor_qparam_initializers.converted.scale.name,
  870. tensor_qparam_initializers.converted.zero_point.name,
  871. tensor_qparam_initializers.converted_recv_nodes,
  872. )
  873. del self.tensors_to_quantize[tensor_name]
  874. def _quantize_sharing_param_tensors(self):
  875. """
  876. Adds Q/DQ ops to tensors that have been marked for quantization by op quantizers.
  877. Only operates on tensors that want to use the quantization parameter initializers from an upstream tensor.
  878. For example, a Transpose node's output tensor will typically want to use the same quantization parameter
  879. initializers as the Transpose node's input.
  880. """
  881. while self.tensors_to_quantize:
  882. for tensor_name, tensor_info in self.tensors_to_quantize.copy().items():
  883. quant_provider = tensor_info.quant_para_provider
  884. if quant_provider and quant_provider.input_name in self.quantized_value_map:
  885. del self.tensors_to_quantize[tensor_name]
  886. quantized_value = self.quantized_value_map[quant_provider.input_name].get_for_consumer(
  887. quant_provider.node_name
  888. )
  889. if self.is_input_a_initializer(tensor_name):
  890. raise ValueError("Quantization parameter shared mode is not supported for weight yet")
  891. if tensor_name in self.tensor_to_producing_dq:
  892. raise ValueError(
  893. f"Quantization parameter sharing is invalid for tensor {tensor_name} "
  894. "because it has already been quantized"
  895. )
  896. # Need to check if this tensor's quant_type is converted for some consumers.
  897. # If so, create new scale/zp initializers for these consumers.
  898. converted_qparam_inits = None
  899. converted_recv_nodes = None
  900. if tensor_name in self.quantization_params:
  901. tensor_params = self.quantization_params[tensor_name]
  902. if tensor_params.converted:
  903. converted_qparam_inits = self._make_scale_zp_initializers(
  904. tensor_name, tensor_params.converted, "_convert"
  905. )
  906. converted_recv_nodes = tensor_params.converted_recv_nodes
  907. if converted_qparam_inits is None:
  908. # Normal case: <producer> --> Q_shared --> DQ_shared --> <consumers>
  909. self._add_qdq_pair_for_activation(
  910. tensor_name, quantized_value.scale_name, quantized_value.zp_name
  911. )
  912. else:
  913. # Conversion case: <producer> ---> Q_shared -+-> DQ_shared --> <consumers of original type>
  914. # |
  915. # +-> DQ_shared' --> Q2 --> DQ2 --> <consumers of converted type>
  916. self._add_qdq_ops_for_converted_activation(
  917. tensor_name,
  918. quantized_value.scale_name,
  919. quantized_value.zp_name,
  920. converted_qparam_inits.scale.data_type,
  921. converted_qparam_inits.scale.name,
  922. converted_qparam_inits.zero_point.name,
  923. converted_recv_nodes,
  924. )
  925. def _quantize_bias_tensors(self):
  926. """
  927. Adds DQ ops (or Cast) for bias tensors that have been marked for quantization by op quantizers.
  928. """
  929. for bias_name, bias_info in self.bias_to_quantize.items():
  930. if bias_name in self.quantized_value_map:
  931. continue
  932. # Quantize the input
  933. self.quantize_bias_static(bias_name, bias_info)
  934. init = find_by_name(bias_name, self.model.initializer())
  935. self.model.remove_initializer(init)
  936. quant_value = self.quantized_value_map[bias_name].original
  937. if quant_value.node_type == "Cast":
  938. # simple cast to float 16 and not DequantizeLinear
  939. # cublasLtMatmul only supports (b)float16, float bias.
  940. if not isinstance(init.data_type, int):
  941. raise TypeError(f"Unexpected type {type(init.data_type)} for input={bias_info.input_name!r}")
  942. node_name = add_dequant_suffix(bias_name)
  943. dequant_node = onnx.helper.make_node(
  944. "Cast",
  945. [quant_value.q_name],
  946. [bias_name],
  947. name=node_name,
  948. to=init.data_type,
  949. )
  950. elif quant_value.node_type in (None, "DequantizeLinear"):
  951. if quant_value.node_qtype in {
  952. onnx.TensorProto.FLOAT16,
  953. onnx.TensorProto.BFLOAT16,
  954. onnx.TensorProto.FLOAT,
  955. }:
  956. raise RuntimeError(f"Unexpected quantize type {quant_value.node_qtype} for DequantizeLinear.")
  957. inputs = [quant_value.q_name, quant_value.scale_name, quant_value.zp_name]
  958. node_name = add_dequant_suffix(bias_name)
  959. if quant_value.axis is not None:
  960. dequant_node = onnx.helper.make_node(
  961. "DequantizeLinear",
  962. inputs,
  963. [bias_name],
  964. node_name,
  965. axis=quant_value.axis,
  966. domain=self.qdq_op_domain,
  967. )
  968. else:
  969. dequant_node = onnx.helper.make_node(
  970. "DequantizeLinear",
  971. inputs,
  972. [bias_name],
  973. node_name,
  974. domain=self.qdq_op_domain,
  975. )
  976. else:
  977. raise RuntimeError(f"Unexpected operator type {quant_value.node_type!r}.")
  978. self.model.add_node(dequant_node)
  979. def is_tensor_quantized(self, tensor_name: str):
  980. return tensor_name in self.tensors_to_quantize or tensor_name in self.bias_to_quantize
  981. def is_tensor_per_channel(
  982. self,
  983. tensor_name: str,
  984. default_axis: int,
  985. op_type: str | None = None,
  986. ) -> tuple[bool, int | None]:
  987. """
  988. Checks if a given tensor is configured to be quantized per-channel. If so, also returns the channel axis.
  989. ORT only supports per-channel quantization on static weights (i.e., ONNX initializers). If the user did not provide
  990. tensor quantization overrides for this tensor, then the value of self.per_channel determines if the weight
  991. is to be quantized per-channel.
  992. Params:
  993. tensor_name: The name of the tensor to check.
  994. default_axis: The default channel axis. This method checks if the normalized axis is within bounds.
  995. Can be overridden via the extra_options 'QDQOpTypePerChannelSupportToAxis'
  996. and 'TensorQuantOverrides'.
  997. op_type: Optional, defaults to None. The operator type that is the only consumer of this weight.
  998. Used to access the extra option 'QDQOpTypePerChannelSupportToAxis'.
  999. Returns:
  1000. A tuple (is_per_channel, axis) in which the first element indicates whether the tensor is
  1001. quantized per-channel and the second element is the channel axis.
  1002. The returned axis is only None if the tensor is not per-channel or the axis is out of bounds.
  1003. """
  1004. weight_initializer = self.initializers.get(tensor_name)
  1005. if weight_initializer is None:
  1006. return False, None # Only support per-channel weights
  1007. if self.tensor_quant_overrides.has_per_tensor_overrides(tensor_name):
  1008. return False, None # User provided per-tensor overrides for this initializer
  1009. has_per_chan_overrides = self.tensor_quant_overrides.has_per_channel_overrides(tensor_name)
  1010. if not self.per_channel and not has_per_chan_overrides:
  1011. return False, None # global self.per_channel is off and user did not provide per-channel overrides.
  1012. axis = self.qdq_op_type_per_channel_support_to_axis.get(op_type, default_axis) if op_type else default_axis
  1013. if has_per_chan_overrides:
  1014. per_chan_overrides = self.tensor_quant_overrides.get_per_channel_overrides(tensor_name)
  1015. axis = per_chan_overrides[0]["axis"] # Prefer axis from user-specified tensor-level overrides if available
  1016. weight_rank = len(weight_initializer.dims)
  1017. axis_valid, axis = normalize_axis(axis, weight_rank)
  1018. if not axis_valid:
  1019. logging.warning(f"Axis {axis} is out-of-range for weight '{tensor_name}' with rank {weight_rank}")
  1020. return False, None
  1021. return True, axis
  1022. def _get_tensor_quantization_scale(self, tensor_name: str, consumer_node_name: str) -> np.ndarray | None:
  1023. """
  1024. Returns the quantization scale of a tensor that is consumed by the given node.
  1025. :parameter tensor_name: The name of the tensor.
  1026. :parameter consumer_node_name: The name of the node that consumes the tensor as input. Necessary in case
  1027. the quantization type of the tensor was converted.
  1028. Refer: QDQQuantizer::_add_qdq_ops_for_converted_activation.
  1029. :returns: The quantization scale or None.
  1030. """
  1031. initializers = self.model.initializer()
  1032. scale_initializer: onnx.TensorProto | None = None
  1033. if tensor_name in self.quantized_value_map:
  1034. # Tensor was quantized by this tool, so get scale from initializer created by this tool run.
  1035. scale_name = self.quantized_value_map[tensor_name].get_for_consumer(consumer_node_name).scale_name
  1036. scale_initializer = find_by_name(scale_name, initializers)
  1037. else:
  1038. # Tensor was already quantized in original model, so get scale from DQ node that outputs the tensor.
  1039. dq_node = self.tensor_to_producing_dq.get(tensor_name, None)
  1040. if dq_node:
  1041. scale_initializer = find_by_name(dq_node.input[1], initializers)
  1042. return tensor_proto_to_array(scale_initializer) if scale_initializer is not None else None
  1043. def quantize_bias_static(self, bias_name: str, bias_info: QDQBiasQuantInfo) -> str:
  1044. """
  1045. Quantized the bias. Zero Point == 0 and Scale == Input_Scale * Weight_Scale
  1046. """
  1047. # Handle case where bias already in quantization map
  1048. if bias_name in self.quantized_value_map:
  1049. return self.quantized_value_map[bias_name].original.q_name
  1050. # get scale for weight.
  1051. weight_scale = self._get_tensor_quantization_scale(bias_info.weight_name, bias_info.node_name)
  1052. if weight_scale is None:
  1053. raise ValueError(
  1054. f"Unable to get valid quantization scale for weight input '{bias_info.weight_name}' "
  1055. f"when quantizing bias '{bias_name}' to int32."
  1056. )
  1057. # get scale for input.
  1058. input_scale = self._get_tensor_quantization_scale(bias_info.input_name, bias_info.node_name)
  1059. if input_scale is None:
  1060. raise ValueError(
  1061. f"Unable to get valid quantization scale for input '{bias_info.input_name}' "
  1062. f"when quantizing bias '{bias_name}' to int32."
  1063. )
  1064. (
  1065. quantized_bias_name,
  1066. quantized_bias_scale_name,
  1067. quantized_bias_zp_name,
  1068. bias_scale_data,
  1069. node_type,
  1070. node_qtype,
  1071. ) = self.quantize_bias_static_impl(bias_name, input_scale, weight_scale, bias_info.beta)
  1072. quantized_value = QuantizedValue(
  1073. bias_name,
  1074. quantized_bias_name,
  1075. quantized_bias_scale_name,
  1076. quantized_bias_zp_name,
  1077. QuantizedValueType.Initializer,
  1078. 0 if bias_scale_data.size > 1 else None,
  1079. node_type=node_type,
  1080. node_qtype=node_qtype,
  1081. )
  1082. self.quantized_value_map[bias_name] = QDQTensorQuantizedValue(quantized_value, None, None)
  1083. return quantized_bias_name
  1084. def _make_scale_zp_initializers(
  1085. self, param_name: str, quant_params: QuantizationParams, init_name_suffix: str = ""
  1086. ) -> QDQScaleZpInitializers:
  1087. """
  1088. Creates and returns scale and zero-point initializers for the given quantization params. The initializers are
  1089. named:
  1090. - {param_name}_zero_point{init_name_suffix}
  1091. - {param_name}_scale{init_name_suffix}
  1092. """
  1093. zero_point = quant_params["zero_point"]
  1094. scale = quant_params["scale"]
  1095. zero_point_type = quant_params["quant_type"]
  1096. axis: int | None = quant_params.get("axis")
  1097. assert (axis is not None and len(scale.shape) == 1) or (axis is None and len(scale.shape) == 0), (
  1098. "Wrong scale/zp shapes"
  1099. )
  1100. assert len(scale.shape) == len(zero_point.shape), "Scale and zero-point must have the same rank"
  1101. zero_point_name = param_name + "_zero_point" + init_name_suffix
  1102. scale_name = param_name + "_scale" + init_name_suffix
  1103. # Add initializers to model
  1104. init_zp = onnx.helper.make_tensor(
  1105. zero_point_name, zero_point_type, zero_point.shape, zero_point.ravel().tolist()
  1106. )
  1107. self.model.add_initializer(init_zp)
  1108. if scale.dtype == np.float32:
  1109. scale_type = onnx_proto.TensorProto.FLOAT
  1110. elif scale.dtype == np.float16:
  1111. scale_type = onnx_proto.TensorProto.FLOAT16
  1112. else:
  1113. raise ValueError(f"Unexpected dtype={scale.dtype} for param_name={param_name!r}")
  1114. init_scale = onnx.helper.make_tensor(scale_name, scale_type, scale.shape, scale.ravel().tolist())
  1115. self.model.add_initializer(init_scale)
  1116. return QDQScaleZpInitializers(init_scale, init_zp)
  1117. def _make_tensor_scale_zp_initializers(self, tensor_name: str) -> QDQTensorScaleZpInitializers | None:
  1118. """
  1119. Create and returns all scale/zero_point initializers for a given tensor. If the tensor is converted
  1120. to a different quantization type, this function creates two pairs of zp/scale initializers. Otherwise,
  1121. only one pair of zp/scale initializers is created.
  1122. """
  1123. if self.quantization_params is None or tensor_name not in self.quantization_params:
  1124. logging.info(f'Quantization parameters for tensor:"{tensor_name}" not specified')
  1125. return None
  1126. tensor_params = self.quantization_params[tensor_name]
  1127. if not isinstance(tensor_params, QDQTensorQuantParams):
  1128. raise TypeError(f"Unexpected type {type(tensor_params)} for {tensor_name!r}.")
  1129. original_inits = self._make_scale_zp_initializers(tensor_name, tensor_params.original)
  1130. converted_inits = (
  1131. self._make_scale_zp_initializers(tensor_name, tensor_params.converted, "_convert")
  1132. if tensor_params.converted
  1133. else None
  1134. )
  1135. return QDQTensorScaleZpInitializers(original_inits, converted_inits, tensor_params.converted_recv_nodes)
  1136. def calc_quant_params(self, tensor_data: TensorData, quant_overrides: dict[str, Any]) -> QuantizationParams:
  1137. """
  1138. Calculates quantization parameters (scale/zero-point) given a tensor's min/max range and optional
  1139. user-provided overrides.
  1140. """
  1141. quant_type = self.activation_qType
  1142. if "quant_type" in quant_overrides:
  1143. quant_type = quant_overrides["quant_type"].tensor_type
  1144. if "scale" in quant_overrides and "zero_point" in quant_overrides:
  1145. zero, scale = quant_overrides["zero_point"], quant_overrides["scale"]
  1146. elif quant_type == onnx.TensorProto.FLOAT8E4M3FN:
  1147. zero, scale = compute_scale_zp_float8(quant_type, tensor_data.avg_std[1])
  1148. else:
  1149. rmin = quant_overrides.get("rmin", tensor_data.range_value[0])
  1150. rmax = quant_overrides.get("rmax", tensor_data.range_value[1])
  1151. symmetric = quant_overrides.get("symmetric", self.is_activation_symmetric)
  1152. reduce_range = quant_overrides.get("reduce_range", False)
  1153. qmin, qmax = get_qmin_qmax_for_qType(quant_type, reduce_range=reduce_range, symmetric=symmetric)
  1154. zero, scale = compute_scale_zp(rmin, rmax, qmin, qmax, symmetric, self.min_real_range)
  1155. return QuantizationParams(zero_point=zero.squeeze(), scale=scale.squeeze(), quant_type=quant_type)
  1156. def calc_graph_quant_params(self) -> dict[str, QDQTensorQuantParams]:
  1157. """
  1158. Calculates quantization parameters (scale/zero-point) for all tensors in the graph using each tensor's min/max range
  1159. and optional user-provided overrides.
  1160. """
  1161. if self.tensors_range is None:
  1162. return {}
  1163. self.adjust_tensor_ranges()
  1164. quantization_params = {}
  1165. for tensor_name in self.tensors_range:
  1166. td = self.tensors_range[tensor_name]
  1167. if not isinstance(td, TensorData):
  1168. raise TypeError(f"Unexpected type {type(td)} for {tensor_name!r}.")
  1169. quant_overrides = self.tensor_quant_overrides.get_per_tensor_overrides(tensor_name, default_val={})
  1170. original = self.calc_quant_params(td, quant_overrides)
  1171. converted = None
  1172. converted_recv_nodes = None
  1173. if "convert" in quant_overrides:
  1174. converted = self.calc_quant_params(td, quant_overrides["convert"])
  1175. converted_recv_nodes = quant_overrides["convert"].get("recv_nodes")
  1176. quantization_params[tensor_name] = QDQTensorQuantParams(original, converted, converted_recv_nodes)
  1177. return quantization_params
  1178. def _calc_initializer_quant_params(self) -> dict[str, QuantizationParams]:
  1179. """
  1180. Returns quantization parameters (scale/zero_point/quant_type) for all initializers.
  1181. """
  1182. quantization_params: dict[str, QuantizationParams] = {}
  1183. for tensor_name, tensor_info in self.tensors_to_quantize.items():
  1184. initializer = find_by_name(tensor_name, self.model.initializer())
  1185. if not initializer:
  1186. continue
  1187. initializer_data = tensor_proto_to_array(initializer)
  1188. initializer_rank = len(initializer_data.shape)
  1189. # initializers for elementwise ops use the quant_type for activations.
  1190. is_weight = tensor_info.tensor_type is QDQQuantTensorType.WEIGHT
  1191. quant_type = self.weight_qType if is_weight else self.activation_qType
  1192. # Try to get scale/zp directly from user's overrides and avoid computation.
  1193. if self.tensor_quant_overrides.overrides_scale_zp(tensor_name):
  1194. overrides = self.tensor_quant_overrides[tensor_name]
  1195. if "quant_type" in overrides[0]:
  1196. quant_type = overrides[0]["quant_type"].tensor_type
  1197. zp_dtype = ONNX_TYPE_TO_NP_TYPE[quant_type]
  1198. is_per_channel = "axis" in overrides[0]
  1199. if not is_per_channel:
  1200. quantization_params[tensor_name] = QuantizationParams(
  1201. zero_point=np.array(overrides[0]["zero_point"], dtype=zp_dtype),
  1202. scale=np.array(overrides[0]["scale"], initializer_data.dtype),
  1203. quant_type=quant_type,
  1204. )
  1205. else:
  1206. zero_points_list = []
  1207. scales_list = []
  1208. for chan_overrides in overrides:
  1209. zero_points_list.append(np.array(chan_overrides["zero_point"], zp_dtype))
  1210. scales_list.append(np.array(chan_overrides["scale"], dtype=initializer_data.dtype))
  1211. channel_axis = overrides[0]["axis"]
  1212. is_axis_valid, norm_channel_axis = normalize_axis(channel_axis, initializer_rank)
  1213. if not is_axis_valid:
  1214. raise ValueError(
  1215. f"Weight {initializer.name} has a per-channel axis with value {channel_axis} that is "
  1216. f"out-of-bounds for rank {initializer_rank}"
  1217. )
  1218. quantization_params[tensor_name] = QuantizationParams(
  1219. zero_point=np.array(zero_points_list),
  1220. scale=np.array(scales_list),
  1221. quant_type=quant_type,
  1222. axis=norm_channel_axis,
  1223. )
  1224. continue
  1225. # Compute scale/zp normally. User's overrides may still override parameters
  1226. # used to compute the scale/zp (e.g., rmin, rmax, symmetric, etc.)
  1227. overrides = self.tensor_quant_overrides.get(tensor_name, [{}])
  1228. if "quant_type" in overrides[0]:
  1229. quant_type = overrides[0]["quant_type"].tensor_type
  1230. channel_axis = overrides[0].get("axis", tensor_info.axis)
  1231. is_per_channel = channel_axis is not None
  1232. # Note: always quantize per-channel initializers as symmetric because QLinear* ops require the
  1233. # same zero-point in every channel, which is necessarily the case for symmetric quantization.
  1234. is_symmetric_default = is_per_channel or (
  1235. self.is_weight_symmetric(quant_type) if is_weight else self.is_activation_symmetric
  1236. )
  1237. is_symmetric = overrides[0].get("symmetric", is_symmetric_default)
  1238. reduce_range = overrides[0].get("reduce_range", self.reduce_range)
  1239. zero_point: np.ndarray | None = None
  1240. scale: np.ndarray | None = None
  1241. if not is_per_channel:
  1242. zero_point, scale = compute_data_quant_params(
  1243. initializer_data.flatten(),
  1244. quant_type,
  1245. is_symmetric,
  1246. reduce_range=reduce_range,
  1247. min_real_range=self.min_real_range,
  1248. rmin_override=overrides[0].get("rmin"),
  1249. rmax_override=overrides[0].get("rmax"),
  1250. )
  1251. else:
  1252. is_axis_valid, norm_channel_axis = normalize_axis(channel_axis, initializer_rank)
  1253. if not is_axis_valid:
  1254. raise ValueError(
  1255. f"Weight {initializer.name} has a per-channel axis with value {channel_axis} that is "
  1256. f"out-of-bounds for rank {initializer_rank}"
  1257. )
  1258. channel_axis = norm_channel_axis
  1259. channel_count = initializer_data.shape[channel_axis]
  1260. zero_points_list = []
  1261. scales_list = []
  1262. for i in range(channel_count):
  1263. per_channel_data = initializer_data.take(i, channel_axis)
  1264. channel_overrides = overrides[i] if overrides and i < len(overrides) else {}
  1265. channel_zero_point, channel_scale = compute_data_quant_params(
  1266. per_channel_data.ravel(),
  1267. quant_type,
  1268. is_symmetric,
  1269. reduce_range=reduce_range,
  1270. min_real_range=self.min_real_range,
  1271. rmin_override=channel_overrides.get("rmin"),
  1272. rmax_override=channel_overrides.get("rmax"),
  1273. )
  1274. zero_points_list.append(channel_zero_point)
  1275. scales_list.append(channel_scale)
  1276. zero_point = np.asarray(zero_points_list)
  1277. scale = np.asarray(scales_list)
  1278. quantization_params[tensor_name] = QuantizationParams(
  1279. zero_point=zero_point,
  1280. scale=scale,
  1281. quant_type=quant_type,
  1282. axis=channel_axis,
  1283. )
  1284. return quantization_params