fusion_utils.py 13 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317
  1. # -------------------------------------------------------------------------
  2. # Copyright (c) Microsoft Corporation. All rights reserved.
  3. # Licensed under the MIT License.
  4. # --------------------------------------------------------------------------
  5. from logging import getLogger
  6. import numpy
  7. from numpy import array_equal, ndarray
  8. from onnx import NodeProto, TensorProto, helper, numpy_helper
  9. from onnx import onnx_pb as onnx_proto
  10. from onnx_model import OnnxModel
# Module-level logger shared by the helpers in this file; named after the module.
logger = getLogger(__name__)
  12. class FusionUtils:
  13. def __init__(self, model: OnnxModel):
  14. self.model: OnnxModel = model
  15. def cast_graph_input_to_int32(self, input_name: str) -> tuple[bool, str]:
  16. graph_input = self.model.find_graph_input(input_name)
  17. if graph_input is not None and graph_input.type.tensor_type.elem_type != TensorProto.INT32:
  18. cast_output, cast_node = self.cast_input_to_int32(input_name)
  19. logger.debug(f"Casted graph input {input_name} to int32")
  20. return True, cast_output
  21. logger.debug(f"Did not cast graph input {input_name} to int32: found {graph_input is not None}")
  22. return False, input_name
  23. def cast_input(self, input_name: str, target_type="int32"):
  24. output_name = input_name + "_" + target_type
  25. if target_type == "int32":
  26. to_type = int(TensorProto.INT32)
  27. elif target_type == "float32":
  28. to_type = int(TensorProto.FLOAT)
  29. elif target_type == "float16":
  30. to_type = int(TensorProto.FLOAT16)
  31. else:
  32. raise ValueError("Invalid target_type: {target_type}")
  33. cast_node = self.add_cast_node(input_name, to_type, output_name)
  34. return output_name, cast_node
  35. def add_cast_node(
  36. self,
  37. input_name: str,
  38. to_type: int,
  39. output_name: str | None = None,
  40. output_name_to_node=None,
  41. graph_name: str | None = None,
  42. ):
  43. if output_name is None:
  44. output_name = input_name + f"_cast_to_{to_type}"
  45. # Avoid consequent Cast nodes.
  46. inputs = [input_name]
  47. if output_name_to_node is None:
  48. output_name_to_node = self.model.output_name_to_node()
  49. if input_name in output_name_to_node:
  50. parent_node = output_name_to_node[input_name]
  51. if parent_node and parent_node.op_type == "Cast":
  52. inputs = [parent_node.input[0]]
  53. cast_node = helper.make_node("Cast", inputs=inputs, outputs=[output_name])
  54. cast_node.attribute.extend([helper.make_attribute("to", to_type)])
  55. self.model.add_node(cast_node, graph_name=graph_name)
  56. return cast_node
  57. def cast_input_to_int32(self, input_name: str):
  58. return self.cast_input(input_name, "int32")
  59. def remove_cast_int32(self, input_name: str):
  60. input_name_to_nodes = self.model.input_name_to_nodes()
  61. nodes = input_name_to_nodes[input_name]
  62. for node in nodes:
  63. if node.op_type == "Cast":
  64. is_int32 = False
  65. for att in node.attribute:
  66. if att.name == "to" and att.i == int(TensorProto.INT32):
  67. is_int32 = True
  68. break
  69. if is_int32:
  70. output_name = node.output[0]
  71. self.model.remove_node(node)
  72. self.model.replace_input_of_all_nodes(output_name, input_name)
  73. @staticmethod
  74. def update_node_input(node, i, new_input_name, input_name_to_nodes):
  75. old_input_reference = 0
  76. if (node.input[i] in input_name_to_nodes) and node in input_name_to_nodes[node.input[i]]:
  77. input_name_to_nodes[node.input[i]].remove(node)
  78. old_input_reference = len(input_name_to_nodes[node.input[i]])
  79. node.input[i] = new_input_name
  80. if new_input_name in input_name_to_nodes:
  81. input_name_to_nodes[new_input_name].append(node)
  82. else:
  83. input_name_to_nodes[new_input_name] = [node]
  84. return old_input_reference
  85. @staticmethod
  86. def skip_parent(model: OnnxModel, node, parent_node, input_name_to_nodes, node_input_index=0, parent_input_index=0):
  87. """
  88. Before:
  89. (input)-->parent-->node-->(output)
  90. After:
  91. (input)-->parent-->
  92. |
  93. +----->node-->(output)
  94. This function returns a flag whether the parent node can be removed.
  95. """
  96. old_input_name = node.input[node_input_index]
  97. new_input_name = parent_node.input[parent_input_index]
  98. old_input_reference = FusionUtils.update_node_input(node, node_input_index, new_input_name, input_name_to_nodes)
  99. # We can remove the first Transpose if its output is not used (linked to graph output or other nodes) anymore.
  100. parent_can_be_removed = (old_input_reference == 0) and not model.find_graph_output(old_input_name)
  101. return parent_can_be_removed
  102. def get_squeeze_or_unsqueeze_axes(self, node: NodeProto) -> ndarray | None:
  103. assert node.op_type in ["Squeeze", "Unsqueeze"]
  104. # For opset >= 13, axes is an input instead of an attribute.
  105. if len(node.input) > 1:
  106. return self.model.get_constant_value(node.input[1])
  107. axes = None
  108. for attr in node.attribute:
  109. if attr.name == "axes":
  110. axes = helper.get_attribute_value(attr)
  111. return axes
  112. @staticmethod
  113. def check_node_attribute(node, attribute_name: str, expected_value, default_value=None):
  114. """Verify that a node has expected value for an attribute.
  115. Args:
  116. node (NodeProto): a node to check
  117. attribute_name (str): name of attribute
  118. expected_value (Any): expected value of the attribute
  119. default_value (Any, optional): default value if the attribute does not exist. Defaults to None.
  120. Returns:
  121. bool: whether the check is passed or not
  122. """
  123. value = default_value
  124. for attr in node.attribute:
  125. if attr.name == attribute_name:
  126. value = helper.get_attribute_value(attr)
  127. if isinstance(expected_value, list):
  128. return (isinstance(value, (ndarray, list))) and array_equal(expected_value, value, equal_nan=False)
  129. else:
  130. return value == expected_value
  131. @staticmethod
  132. def transpose_2d_int8_tensor(tensor: onnx_proto.TensorProto):
  133. """Transpose a 2-D INT8 TensorProto
  134. Args:
  135. tensor (TensorProto): tensor to be transposed
  136. Returns:
  137. tensor (TensorProto): transposed tensor
  138. """
  139. if not isinstance(tensor, onnx_proto.TensorProto):
  140. raise ValueError(f"Expected input type is an ONNX TensorProto but got {type(tensor)}")
  141. if len(tensor.dims) != 2 or tensor.data_type != onnx_proto.TensorProto.INT8:
  142. raise ValueError("Only INT8 2-D tensors can be transposed")
  143. if tensor.raw_data:
  144. int32_data = numpy.reshape(numpy.frombuffer(tensor.raw_data, dtype="int8"), tensor.dims)
  145. int32_transposed_data = numpy.transpose(int32_data, [1, 0])
  146. tensor.raw_data = int32_transposed_data.tobytes()
  147. else:
  148. raise ValueError("only raw buffer supported")
  149. return tensor
  150. @staticmethod
  151. def check_qdq_node_for_fusion(node: NodeProto, model: OnnxModel, allow_per_tensor_quantization_only=True):
  152. """Verify if a provided QuantizeLinear (Q) / DequantizeLinear (DQ) node is a good candidate for fusion.
  153. It is a good candidate for fusion if:
  154. (1) The Q/DQ node is for per-tensor quantization if allow_per_tensor_quantization_only is `True`
  155. (2) The Q/DQ node should have constant scale
  156. (3) The Q/DQ node should have a zero point of 0
  157. Args:
  158. node (NodeProto): a Q/DQ node to check
  159. Returns:
  160. bool: whether the check is passed or not
  161. """
  162. if node.op_type not in {"QuantizeLinear", "DequantizeLinear"}:
  163. logger.debug(f"Provided node is not a Q/DQ node. Op Type: {node.op_type}")
  164. scale = model.get_constant_value(node.input[1])
  165. # Scale is not constant
  166. if scale is None:
  167. return False
  168. # Not per-tensor quantization
  169. scale_has_single_element = scale.ndim == 0 or (scale.ndim == 1 and scale.shape[0] == 1)
  170. if allow_per_tensor_quantization_only and not scale_has_single_element:
  171. return False
  172. # If the Q/DQ node has no zero point input, it is assumed to be 0 (per ONNX spec)
  173. if len(node.input) == 2:
  174. return True
  175. # Zero point should be constant and should have a value of 0
  176. zero_point = model.get_constant_value(node.input[2])
  177. # Zero point and scale should have same number of dims
  178. if scale.ndim != zero_point.ndim:
  179. return False
  180. # Zero point is not constant or zero point is not zero
  181. if zero_point is None:
  182. return False
  183. return numpy.all(zero_point == 0)
  184. def check_node_input_value(self, node, input_index: int, expected_value):
  185. """Verify that a node has expected input value
  186. Args:
  187. node (NodeProto): a node to check
  188. input_index (int): index of its input to be verified
  189. expected_value (Any): expected value of the input
  190. Returns:
  191. bool: whether the check is passed or not
  192. """
  193. assert len(node.input) > input_index
  194. value = self.model.get_constant_value(node.input[input_index])
  195. if isinstance(expected_value, list):
  196. return (isinstance(value, (ndarray, list))) and array_equal(expected_value, value, equal_nan=False)
  197. else:
  198. return value == expected_value
  199. def remove_identity_nodes(self):
  200. """Remove Identity nodes, except those right before graph output."""
  201. nodes_to_remove = []
  202. graph_output_names = self.model.get_graphs_output_names()
  203. for node in self.model.nodes():
  204. if node.op_type == "Identity":
  205. if node.output[0] not in graph_output_names:
  206. self.model.replace_input_of_all_nodes(node.output[0], node.input[0])
  207. nodes_to_remove.append(node)
  208. if nodes_to_remove:
  209. self.model.remove_nodes(nodes_to_remove)
  210. logger.info(f"Removed {len(nodes_to_remove)} Identity nodes")
  211. def remove_cascaded_cast_nodes(self):
  212. self.model.remove_cascaded_cast_nodes()
  213. def remove_useless_cast_nodes(self):
  214. self.model.remove_useless_cast_nodes()
  215. def remove_useless_reshape_nodes(self):
  216. """Remove reshape node that is not needed based on symbolic shape inference: input and output has same shape"""
  217. shape_infer = self.model.infer_runtime_shape(update=True)
  218. if shape_infer is None:
  219. return
  220. nodes_to_remove = []
  221. for node in self.model.nodes():
  222. if node.op_type == "Reshape":
  223. input_shape = shape_infer.get_edge_shape(node.input[0])
  224. output_shape = shape_infer.get_edge_shape(node.output[0])
  225. if input_shape and output_shape and input_shape == output_shape:
  226. logger.info(
  227. f"Remove reshape node {node.name} since its input shape is same as output: {input_shape}"
  228. )
  229. nodes_to_remove.append(node)
  230. if nodes_to_remove:
  231. graph_input_names = set(self.model.get_graphs_input_names())
  232. graph_output_names = set(self.model.get_graphs_output_names())
  233. for node in nodes_to_remove:
  234. if bool(set(node.output) & graph_output_names):
  235. if (
  236. not bool(set(node.input) & graph_input_names)
  237. and len(self.model.input_name_to_nodes()[node.input[0]]) == 1 # parent has only one child
  238. ):
  239. self.model.replace_output_of_all_nodes(node.input[0], node.output[0])
  240. else:
  241. continue
  242. else:
  243. self.model.replace_input_of_all_nodes(node.output[0], node.input[0])
  244. self.model.remove_node(node)
  245. class NumpyHelper:
  246. @staticmethod
  247. def to_array(tensor: TensorProto, fill_zeros: bool = False) -> ndarray:
  248. # When weights are in external data format but not presented, we can still test the optimizer with two changes:
  249. # (1) set fill_zeros = True (2) change load_external_data=False in optimizer.py
  250. if fill_zeros:
  251. return ndarray(
  252. shape=tensor.dims,
  253. dtype=helper.tensor_dtype_to_np_dtype(tensor.data_type),
  254. )
  255. return numpy_helper.to_array(tensor)