| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420 |
- # -------------------------------------------------------------------------
- # Copyright (c) Microsoft Corporation. All rights reserved.
- # Licensed under the MIT License.
- # --------------------------------------------------------------------------
- from logging import getLogger
- import numpy as np
- from fusion_attention import AttentionMask
- from fusion_base import Fusion
- from fusion_utils import FusionUtils, NumpyHelper
- from onnx import NodeProto, helper
- from onnx_model import OnnxModel
# Module-level logger, named after this module so its output can be filtered per module.
logger = getLogger(name=__name__)
class FusionQOrderedAttention(Fusion):
    """Fuse a Q/DQ-quantized self-attention subgraph into a com.microsoft QOrderedAttention node.

    The search starts from QOrderedLayerNormalization nodes and walks upwards through the
    (quantized) residual Add to the attention projection, the Q/K/V projection branches, the
    MatMul(Q, K) -> Div -> Add(mask) -> Softmax chain and the mask-processing nodes. Every
    intermediate activation is expected to be wrapped in a QuantizeLinear/DequantizeLinear pair.
    """

    def __init__(
        self,
        model: OnnxModel,
        hidden_size: int,
        num_heads: int,
        attention_mask: AttentionMask,
    ):
        """
        Args:
            model: model that the fusion rewrites.
            hidden_size: user-specified hidden size, returned as a fallback when it cannot be
                detected from the graph; values <= 0 suppress the mismatch warning.
            num_heads: user-specified number of heads, same fallback/warning rules as hidden_size.
            attention_mask: helper used to convert the raw attention mask into a mask index.
        """
        self.hidden_size = hidden_size
        self.num_heads = num_heads
        self.attention_mask = attention_mask

        # Flags so each mismatch warning in get_num_heads_and_hidden_size is logged at most once.
        # They must exist before that method runs, otherwise a mismatch raises AttributeError.
        self.num_heads_warning = True
        self.hidden_size_warning = True

        super().__init__(model, "QOrderedAttention", "QOrderedLayerNormalization")

    def get_num_heads_and_hidden_size(self, reshape_q: NodeProto) -> tuple[int, int]:
        """Detect num_heads and hidden_size from a reshape node.

        The shape operand (input 1) of the Reshape is read either from an initializer or from
        the tensor attribute of a parent Constant node, and is expected to look like
        [0, 0, num_heads, head_size].

        Args:
            reshape_q (NodeProto): reshape node for Q

        Returns:
            tuple[int, int]: num_heads and hidden_size (= num_heads * head_size) as Python ints.
            The user-specified values are returned when the shape cannot be read or does not
            have the expected form.
        """
        # We assume that reshape fusion has been done, so the shape is a tensor like [0, 0, num_heads, head_size]
        q_shape = self.model.get_initializer(reshape_q.input[1])
        if q_shape is None:
            logger.debug(f"{reshape_q.input[1]} is not initializer.")
            # Check if the second input to Reshape flows through a Constant node
            # TODO: Investigate why FusionAttention doesn't have such logic
            constant_node = self.model.match_parent_path(reshape_q, ["Constant"], [1])
            if constant_node is None:
                return self.num_heads, self.hidden_size  # Fall back to user specified value
            constant_node = constant_node[0]
            if len(constant_node.attribute) != 1:
                return self.num_heads, self.hidden_size  # Fall back to user specified value
            # This is assuming it is a Tensor attribute (this is a safe assumption)
            q_shape = constant_node.attribute[0].t

        q_shape_value = NumpyHelper.to_array(q_shape)
        if len(q_shape_value) != 4 or (q_shape_value[2] <= 0 or q_shape_value[3] <= 0):
            logger.debug(f"q_shape_value={q_shape_value}. Expected value are like [0, 0, num_heads, head_size].")
            return self.num_heads, self.hidden_size  # Fall back to user specified value

        # Convert numpy scalars to Python ints so the annotated return type holds.
        num_heads = int(q_shape_value[2])
        head_size = int(q_shape_value[3])
        hidden_size = num_heads * head_size

        if self.num_heads > 0 and num_heads != self.num_heads:
            if self.num_heads_warning:
                logger.warning(f"--num_heads is {self.num_heads}. Detected value is {num_heads}. Using detected value.")
                self.num_heads_warning = False  # Do not show the warning more than once

        if self.hidden_size > 0 and hidden_size != self.hidden_size:
            if self.hidden_size_warning:
                logger.warning(
                    f"--hidden_size is {self.hidden_size}. Detected value is {hidden_size}. Using detected value."
                )
                self.hidden_size_warning = False  # Do not show the warning more than once

        return num_heads, hidden_size

    def _check_qdq_pair(self, quantize_node: NodeProto, dequantize_node: NodeProto) -> bool:
        """Return True when both nodes of a Q/DQ pair pass FusionUtils.check_qdq_node_for_fusion.

        The Q node is checked first; the DQ node is checked only if the Q node passes.
        """
        return FusionUtils.check_qdq_node_for_fusion(
            quantize_node, self.model
        ) and FusionUtils.check_qdq_node_for_fusion(dequantize_node, self.model)

    def _get_bias_input(self, add_node: NodeProto) -> str:
        """Return the name of the constant bias input of a projection Add node.

        Input 0 is used when it is an initializer; otherwise input 1 is taken as the bias.
        """
        if self.model.get_initializer(add_node.input[0]) is not None:
            return add_node.input[0]
        return add_node.input[1]

    def _match_projection_branch(self, parent_matmul: NodeProto, input_index: int, branch: str):
        """Match one Q/K/V projection branch feeding input `input_index` of `parent_matmul`.

        Expected pattern, from the projection upwards:
            MatMul -> Add -> QuantizeLinear -> DequantizeLinear -> Reshape -> Transpose
        where the MatMul weight (input 1) comes from a DequantizeLinear whose data input is an
        initializer.

        Args:
            parent_matmul: the MatMul consuming the branch output.
            input_index: which input of `parent_matmul` the branch feeds.
            branch: branch label ("q", "k" or "v") used in debug messages.

        Returns:
            (path_nodes, dequantize_weight) where path_nodes is ordered
            [Transpose, Reshape, DequantizeLinear, QuantizeLinear, Add, MatMul],
            or None when any part of the branch does not match.
        """
        path_nodes = self.model.match_parent_path(
            parent_matmul,
            ["Transpose", "Reshape", "DequantizeLinear", "QuantizeLinear", "Add", "MatMul"],
            [input_index, 0, 0, 0, 0, None],
        )
        if path_nodes is None:
            logger.debug(f"fuse_qordered_attention: failed to match {branch} path")
            return None

        _, _, dequantize_node, quantize_node, _, matmul_node = path_nodes
        # Make sure the Q/DQ has the proper zero points and constant per-tensor scales
        if not self._check_qdq_pair(quantize_node, dequantize_node):
            return None

        # MatMul weight
        dequantize_weight = self.model.match_parent_path(matmul_node, ["DequantizeLinear"], [1])
        if dequantize_weight is None:
            logger.debug(f"fuse_qordered_attention: failed to match {branch} path")
            return None
        dequantize_weight = dequantize_weight[0]

        # The weight must be an initializer: fuse() reads its shape and transposes it in place.
        # A weight produced by a Constant node would otherwise crash later on.
        if self.model.get_initializer(dequantize_weight.input[0]) is None:
            return None

        # Make sure the upstream DequantizeLinear has the proper zero points and scales.
        # Per-channel scales are supported for weights alone.
        if not FusionUtils.check_qdq_node_for_fusion(dequantize_weight, self.model, False):
            return None

        return path_nodes, dequantize_weight

    def fuse(self, normalize_node, input_name_to_nodes, output_name_to_node):
        """Try to fuse the quantized attention subgraph that feeds `normalize_node`.

        On success, a QOrderedAttention node is queued in `self.nodes_to_add` and the matched
        nodes are queued in `self.nodes_to_remove`. The Q/K/V weight initializers are transposed
        in place, and `self.prune_graph` is set so the shared mask nodes get pruned later. On any
        mismatch the method returns before modifying the graph.

        Args:
            normalize_node: candidate QOrderedLayerNormalization node.
            input_name_to_nodes: unused here; part of the Fusion.fuse interface.
            output_name_to_node: map from tensor name to the node producing it.
        """
        add_before_layernorm = self.model.match_parent_path(
            normalize_node,
            ["QuantizeLinear", "Add"],
            [0, 0],
        )
        if add_before_layernorm is None:
            return
        # Residual Add: one input is the attention branch, the other is the root input that
        # also feeds the Q/K/V projections (checked below).
        start_node = add_before_layernorm[-1]

        # Input QDQ nodes
        dequantize_input = self.model.match_parent_path(
            start_node,
            ["DequantizeLinear"],
            [None],
        )
        if dequantize_input is None:
            logger.debug("fuse_qordered_attention: failed to match input qdq nodes path")
            return
        dequantize_input = dequantize_input[-1]

        # QKV nodes
        qkv_nodes = self.model.match_parent_path(
            start_node,
            ["Add", "MatMul", "Reshape", "Transpose", "DequantizeLinear", "QuantizeLinear", "MatMul"],
            [None, None, 0, 0, 0, 0, 0],
        )
        if qkv_nodes is None:
            logger.debug("fuse_qordered_attention: failed to match qkv path")
            return
        (_, projection_matmul, reshape_qkv, transpose_qkv, dequantize_qkv, quantize_qkv, matmul_qkv) = qkv_nodes

        # Make sure the Q/DQ has the proper zero points and constant per-tensor scales
        if not self._check_qdq_pair(quantize_qkv, dequantize_qkv):
            return

        # Identify the root input to the Attention node: the node-produced input of start_node
        # that is not the attention branch.
        attention_branch_output = qkv_nodes[0].output[0]
        other_inputs = [
            name
            for name in start_node.input
            if name in output_name_to_node and name != attention_branch_output
        ]
        if len(other_inputs) != 1:
            return
        root_input = other_inputs[0]

        # V nodes
        v_branch = self._match_projection_branch(matmul_qkv, 1, "v")
        if v_branch is None:
            return
        v_nodes, dequantize_v_matmul_weight = v_branch
        _, _, dequantize_v, _, add_v, matmul_v = v_nodes

        # QK nodes: MatMul(Q, K) -> Q/DQ -> Div -> Add(mask) -> Softmax -> Q/DQ -> matmul_qkv
        qk_nodes = self.model.match_parent_path(
            matmul_qkv,
            [
                "DequantizeLinear",
                "QuantizeLinear",
                "Softmax",
                "Add",
                "Div",
                "DequantizeLinear",
                "QuantizeLinear",
                "MatMul",
            ],
            [0, 0, 0, 0, None, 0, 0, 0],
        )
        if qk_nodes is None:
            logger.debug("fuse_qordered_attention: failed to match qk path")
            return
        (
            dequantize_qk_softmax,
            quantize_qk_softmax,
            _softmax_qk,
            add_qk,
            _div_qk,
            dequantize_qk,
            quantize_qk,
            matmul_qk,
        ) = qk_nodes

        # Make sure the Q/DQ pairs have the proper zero points and constant per-tensor scales
        if not self._check_qdq_pair(quantize_qk_softmax, dequantize_qk_softmax):
            return
        if not self._check_qdq_pair(quantize_qk, dequantize_qk):
            return

        # Q nodes
        q_branch = self._match_projection_branch(matmul_qk, 0, "q")
        if q_branch is None:
            return
        q_nodes, dequantize_q_matmul_weight = q_branch
        _, reshape_q, dequantize_q, _, add_q, matmul_q = q_nodes

        # K nodes
        k_branch = self._match_projection_branch(matmul_qk, 1, "k")
        if k_branch is None:
            return
        k_nodes, dequantize_k_matmul_weight = k_branch
        _, _, dequantize_k, _, add_k, matmul_k = k_nodes

        # Mask nodes
        mask_nodes = self.model.match_parent_path(
            add_qk, ["Mul", "Sub", "Cast", "Unsqueeze", "Unsqueeze"], [None, 0, 1, 0, 0]
        )
        if mask_nodes is None:
            logger.debug("fuse_qordered_attention: failed to match mask_nodes path")
            return

        # All three projections must consume the same root input.
        if not all(matmul.input[0] == root_input for matmul in (matmul_v, matmul_q, matmul_k)):
            return

        # Form QOrderedAttention node
        q_weight = self.model.get_initializer(dequantize_q_matmul_weight.input[0])
        k_weight = self.model.get_initializer(dequantize_k_matmul_weight.input[0])
        v_weight = self.model.get_initializer(dequantize_v_matmul_weight.input[0])

        # Ascertain `qkv_hidden_sizes` from the weight shapes BEFORE they are transposed below.
        qkv_hidden_sizes = [int(np.prod(weight.dims[1:])) for weight in (q_weight, k_weight, v_weight)]

        mask_index = self.attention_mask.process_mask(mask_nodes[-1].input[0])

        # Ascertain `num_heads` (the detected hidden size is not needed here)
        num_heads, _ = self.get_num_heads_and_hidden_size(reshape_q)

        attention_inputs = [
            # Actual quantized input and its scale
            dequantize_input.input[0],
            dequantize_input.input[1],
            # Scales of the quantized Q, K and V projection outputs
            dequantize_q.input[1],
            dequantize_k.input[1],
            dequantize_v.input[1],
            # Quantized Q, K and V weights
            dequantize_q_matmul_weight.input[0],
            dequantize_k_matmul_weight.input[0],
            dequantize_v_matmul_weight.input[0],
            # Q, K and V weight scales
            dequantize_q_matmul_weight.input[1],
            dequantize_k_matmul_weight.input[1],
            dequantize_v_matmul_weight.input[1],
            # Q, K and V biases
            self._get_bias_input(add_q),
            self._get_bias_input(add_k),
            self._get_bias_input(add_v),
            # Scales of the QK product, the softmax output and the attention output
            quantize_qk.input[1],
            quantize_qk_softmax.input[1],
            dequantize_qkv.input[1],
            # Mask input; an empty name marks the optional input as absent
            mask_index if mask_index is not None else "",
        ]

        # The MatMul weight 'B' needs some post-processing:
        # transpose weight 'B' from order ROW to order COL.
        # This offline transpose is needed only while using the CUDA EP
        # TODO: Make this fusion logic EP-agnostic ?
        for weight in (q_weight, k_weight, v_weight):
            FusionUtils.transpose_2d_int8_tensor(weight)

        # Name and create Attention node
        attention_node_name = self.model.create_node_name("QOrderedAttention")
        attention_node = helper.make_node(
            "QOrderedAttention",
            inputs=attention_inputs,
            outputs=[reshape_qkv.output[0]],
            name=attention_node_name,
        )

        # Keep the output DequantizeLinear: it now reads the fused output and feeds the projection.
        self.model.replace_node_input(dequantize_qkv, dequantize_qkv.input[0], attention_node.output[0])
        self.model.replace_node_input(projection_matmul, projection_matmul.input[0], dequantize_qkv.output[0])

        attention_node.attribute.extend([helper.make_attribute("num_heads", num_heads)])
        attention_node.attribute.extend([helper.make_attribute("order_input", 1)])
        attention_node.attribute.extend([helper.make_attribute("order_weight", 0)])
        attention_node.attribute.extend([helper.make_attribute("order_output", 1)])
        attention_node.attribute.extend([helper.make_attribute("qkv_hidden_sizes", qkv_hidden_sizes)])
        attention_node.domain = "com.microsoft"

        self.nodes_to_add.append(attention_node)
        self.node_name_to_graph_name[attention_node.name] = self.this_graph_name

        self.nodes_to_remove.extend([reshape_qkv, transpose_qkv, quantize_qkv, matmul_qkv])
        self.nodes_to_remove.extend(qk_nodes)
        self.nodes_to_remove.extend(q_nodes)
        self.nodes_to_remove.extend(k_nodes)
        self.nodes_to_remove.extend(v_nodes)
        self.nodes_to_remove.extend(
            [dequantize_q_matmul_weight, dequantize_k_matmul_weight, dequantize_v_matmul_weight]
        )

        # Use prune graph to remove mask nodes since they are shared by all attention nodes.
        self.prune_graph = True
|