fusion_qordered_matmul.py 8.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216
  1. # -------------------------------------------------------------------------
  2. # Copyright (c) Microsoft Corporation. All rights reserved.
  3. # Licensed under the MIT License.
  4. # --------------------------------------------------------------------------
  5. from logging import getLogger
  6. from fusion_base import Fusion
  7. from fusion_utils import FusionUtils
  8. from onnx import helper
  9. from onnx_model import OnnxModel
  10. logger = getLogger(__name__)
  11. class FusionQOrderedMatMul(Fusion):
  12. def __init__(self, model: OnnxModel):
  13. super().__init__(model, "QOrderedMatMul", "MatMul")
  14. def fuse(self, node, input_name_to_nodes: dict, output_name_to_node: dict):
  15. matmul_children = self.model.get_children(node, input_name_to_nodes)
  16. # Should only have 1 child - Bias Add
  17. if len(matmul_children) != 1 or matmul_children[0].op_type != "Add":
  18. return
  19. bias_add_node = matmul_children[0]
  20. # Atleast one of the inputs to Bias Add node must be a constant
  21. bias_add_node_index = 0
  22. if (
  23. self.model.get_constant_value(bias_add_node.input[0]) is None
  24. and self.model.get_constant_value(bias_add_node.input[1]) is None
  25. ):
  26. return
  27. if self.model.get_constant_value(bias_add_node.input[0]) is None:
  28. bias_add_node_index = 1
  29. bias_add_children = self.model.get_children(bias_add_node, input_name_to_nodes)
  30. if len(bias_add_children) != 1:
  31. return
  32. bias_add_child = bias_add_children[0]
  33. # Bias Add can have another Add downstream (Residual Add layer)
  34. residual_add_node = None
  35. downstream_quantize_node = None
  36. if bias_add_child.op_type == "Add":
  37. residual_add_node = bias_add_child
  38. residual_add_children = self.model.get_children(residual_add_node, input_name_to_nodes)
  39. if len(residual_add_children) != 1 or residual_add_children[0].op_type != "QuantizeLinear":
  40. return
  41. downstream_quantize_node = residual_add_children[0]
  42. elif bias_add_child.op_type == "QuantizeLinear":
  43. downstream_quantize_node = bias_add_child
  44. else:
  45. return
  46. # Make sure the downstream QuantizeLinear has the proper zero points and scales
  47. if not FusionUtils.check_qdq_node_for_fusion(downstream_quantize_node, self.model):
  48. return
  49. # The first input to MatMul should flow through a DequantizeLinear node
  50. first_path_id, first_input_parent_nodes, _ = self.model.match_parent_paths(
  51. node,
  52. [(["DequantizeLinear"], [0])],
  53. output_name_to_node,
  54. )
  55. # If Attention is not fused, this is the pattern to look for
  56. # leading upto the MatMul
  57. reshape_node_0 = None
  58. transpose_node_0 = None
  59. if first_path_id < 0:
  60. first_path_id, first_input_parent_nodes, _ = self.model.match_parent_paths(
  61. node,
  62. [(["Reshape", "Transpose", "DequantizeLinear", "QuantizeLinear"], [0, 0, 0, 0])],
  63. output_name_to_node,
  64. )
  65. if first_path_id < 0:
  66. return
  67. reshape_node_0 = first_input_parent_nodes[0]
  68. transpose_node_0 = first_input_parent_nodes[1]
  69. dequantize_node_0 = first_input_parent_nodes[2]
  70. else:
  71. dequantize_node_0 = first_input_parent_nodes[0]
  72. # Make sure the upstream DequantizeLinear-0 has the proper zero points and scales
  73. if not FusionUtils.check_qdq_node_for_fusion(dequantize_node_0, self.model):
  74. return
  75. # The second input to MatMul should flow through a DequantizeLinear node
  76. dequantize_node_1 = None
  77. is_weight_transpose_required = True
  78. weight_path_id, weight_nodes, _ = self.model.match_parent_paths(
  79. node,
  80. [(["DequantizeLinear", "QuantizeLinear", "Transpose", "DequantizeLinear"], [1, 0, 0, 0])],
  81. output_name_to_node,
  82. )
  83. if weight_path_id < 0:
  84. weight_path_id, weight_nodes, _ = self.model.match_parent_paths(
  85. node,
  86. [(["DequantizeLinear"], [1])],
  87. output_name_to_node,
  88. )
  89. if weight_path_id < 0:
  90. return
  91. dequantize_node_1 = weight_nodes[0]
  92. else:
  93. is_weight_transpose_required = False
  94. dequantize_node_1 = weight_nodes[3]
  95. # Check if weight 'B' is a constant
  96. if self.model.get_constant_value(dequantize_node_1.input[0]) is None:
  97. return
  98. # Make sure the upstream DequantizeLinear-1 has the proper zero points and scales
  99. # Per-channel scales are supported for weights alone
  100. if not FusionUtils.check_qdq_node_for_fusion(dequantize_node_1, self.model, False):
  101. return
  102. # Make sure the upstream flow into the Residual Add node flows through a DQ node
  103. residual_add_dequantize_node = None
  104. if residual_add_node is not None:
  105. residual_path_id, residual_input_parent_nodes, _ = self.model.match_parent_paths(
  106. residual_add_node,
  107. [
  108. (["DequantizeLinear"], [1]),
  109. ],
  110. output_name_to_node,
  111. )
  112. if residual_path_id < 0:
  113. return
  114. residual_add_dequantize_node = residual_input_parent_nodes[0]
  115. # Make sure the upstream DequantizeLinear to the Residual Add has the proper zero points and scales
  116. if residual_add_dequantize_node is not None and not FusionUtils.check_qdq_node_for_fusion(
  117. residual_add_dequantize_node, self.model
  118. ):
  119. return
  120. # Subgraph nodes to be fused
  121. subgraph_nodes = [node, bias_add_node] # MatMul + Bias Add
  122. if residual_add_node is not None:
  123. subgraph_nodes.extend([residual_add_node]) # Residual Add
  124. subgraph_nodes.extend(weight_nodes)
  125. subgraph_nodes.extend([downstream_quantize_node]) # Downstream Q node
  126. if not self.model.is_safe_to_fuse_nodes(
  127. subgraph_nodes, downstream_quantize_node.output, input_name_to_nodes, output_name_to_node
  128. ):
  129. logger.debug("It is not safe to fuse QOrderedMatMul node. Skip")
  130. return
  131. # Deal with the case where-in the Attention subgraph is not fused
  132. if transpose_node_0 is not None:
  133. self.model.replace_node_input(transpose_node_0, transpose_node_0.input[0], dequantize_node_0.input[0])
  134. # Make inputs
  135. fused_node_inputs = [
  136. reshape_node_0.output[0] if reshape_node_0 is not None else dequantize_node_0.input[0],
  137. dequantize_node_0.input[1],
  138. dequantize_node_1.input[0],
  139. dequantize_node_1.input[1],
  140. downstream_quantize_node.input[1],
  141. bias_add_node.input[bias_add_node_index],
  142. ]
  143. if residual_add_node is not None:
  144. fused_node_inputs.append(residual_add_dequantize_node.input[0])
  145. fused_node_inputs.append(residual_add_dequantize_node.input[1])
  146. # The MatMul weight 'B' and 'bias' need some post-processing
  147. # Transpose weight 'B' from order ROW to order COL
  148. # This offline transpose is needed only while using the CUDA EP
  149. # TODO: Make this fusion logic EP-agnostic ?
  150. if is_weight_transpose_required:
  151. weight_tensor = self.model.get_initializer(dequantize_node_1.input[0])
  152. FusionUtils.transpose_2d_int8_tensor(weight_tensor)
  153. fused_node = helper.make_node(
  154. "QOrderedMatMul",
  155. inputs=fused_node_inputs,
  156. outputs=[downstream_quantize_node.output[0]],
  157. name=self.model.create_node_name("QOrderedMatMul", name_prefix="QOrderedMatMul"),
  158. )
  159. fused_node.attribute.extend([helper.make_attribute("order_A", 1)])
  160. fused_node.attribute.extend([helper.make_attribute("order_B", 0)])
  161. fused_node.attribute.extend([helper.make_attribute("order_Y", 1)])
  162. fused_node.domain = "com.microsoft"
  163. self.nodes_to_remove.extend(subgraph_nodes)
  164. self.nodes_to_add.append(fused_node)
  165. self.node_name_to_graph_name[fused_node.name] = self.this_graph_name