# -------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.
# --------------------------------------------------------------------------
- import logging
- from fusion_attention import AttentionMask, FusionAttention
- from onnx_model import OnnxModel
- logger = logging.getLogger(__name__)
class FusionConformerAttention(FusionAttention):
    """
    Fuse Conformer Attention subgraph into one MultiHeadAttention node.

    When the Q, K and V projection MatMuls all read the same input and there is
    no extra Q branch feeding the QK Add, the fused node is built with
    ``create_attention_node`` instead of ``create_multihead_attention_node``.
    """

    def __init__(
        self,
        model: OnnxModel,
        hidden_size: int,
        num_heads: int,
        attention_mask: AttentionMask,
    ):
        """Forward all arguments unchanged to the ``FusionAttention`` constructor."""
        super().__init__(model, hidden_size, num_heads, attention_mask)

    def fuse(self, normalize_node, input_name_to_nodes, output_name_to_node):
        """Try to fuse the attention subgraph that ends at ``normalize_node``.

        The method walks parent paths upward from ``normalize_node`` to locate
        the QKV output path, then the V, QK, Q and K paths. If any required
        path cannot be matched, a debug message is logged and the method
        returns without touching the graph bookkeeping.

        On success the fused node is appended to ``self.nodes_to_add``, the
        matched nodes are queued in ``self.nodes_to_remove`` and
        ``self.prune_graph`` is set to True.

        Args:
            normalize_node: Node from which the parent-path matching starts.
            input_name_to_nodes: Not used by this method.
            output_name_to_node: Not used by this method.

        Returns:
            None.
        """
        # SkipLayerNormalization has two inputs, and one of them is the root input for attention.
        qkv_nodes = self.model.match_parent_path(
            normalize_node,
            ["Add", "MatMul", "Reshape", "Transpose", "MatMul"],
            [1, None, 0, 0, 0],
        )
        if qkv_nodes is None:
            logger.debug("fuse_conformer_attention: failed to match qkv path")
            return

        reshape_qkv, transpose_qkv, matmul_qkv = qkv_nodes[-3], qkv_nodes[-2], qkv_nodes[-1]

        # ---- V path: first try the variant with a Concat (past/present value). ----
        past_v, present_v = "", ""
        v_nodes = self.model.match_parent_path(
            matmul_qkv,
            ["Concat", "Transpose", "Reshape", "Add", "MatMul"],
            [1, 1, 0, 0, 1],
        )
        if v_nodes is not None:
            concat_v = v_nodes[0]
            concat_parent = self.model.get_parent(concat_v, 0, None)
            if concat_parent is None:
                # get_parent yields None when input 0 of the Concat has no producer
                # node; skip the fusion instead of failing on `None.output`.
                logger.debug("fuse_conformer_attention: failed to find parent of past value concat")
                return
            present_v = concat_v.output[0]
            past_v = concat_parent.output[0]
        else:
            v_nodes = self.model.match_parent_path(
                matmul_qkv,
                ["Transpose", "Reshape", "Add", "MatMul"],
                [1, 0, 0, 0],
            )
            if v_nodes is None:
                logger.debug("fuse_conformer_attention: failed to match v path")
                return

        add_v, matmul_v = v_nodes[-2], v_nodes[-1]

        # ---- QK path: plain Softmax variant, then the Where-masked variant. ----
        attn_mask = ""
        qk_nodes = self.model.match_parent_path(
            matmul_qkv,
            ["Softmax", "Add", "MatMul"],
            [0, 0, 0],
        )
        if qk_nodes is None:
            qk_nodes = self.model.match_parent_path(
                matmul_qkv,
                ["Where", "Softmax", "Where", "Add", "MatMul"],
                [0, 2, 0, 2, 0],
            )
            if qk_nodes is None:
                logger.debug("fuse_conformer_attention: failed to match qk path")
                return

            # Index 2 is the Where node that sits before the Softmax in this variant.
            where_qk = qk_nodes[2]
            mask_nodes = self.model.match_parent_path(
                where_qk,
                ["Equal", "Unsqueeze", "Cast"],
                [0, 0, 0],
            )
            if mask_nodes is not None:
                # The mask is taken from the output of the Cast node.
                attn_mask = mask_nodes[-1].output[0]

        add_qk, matmul_qk = qk_nodes[-2], qk_nodes[-1]

        # ---- Q path: a Div or a Mul node directly precedes the QK MatMul. ----
        q_nodes = self.model.match_parent_path(
            matmul_qk,
            ["Div", "Transpose", "Reshape", "Add", "MatMul"],
            [0, 0, 0, 0, 1],
        )
        if q_nodes is None:
            q_nodes = self.model.match_parent_path(
                matmul_qk,
                ["Mul", "Transpose", "Reshape", "Add", "MatMul"],
                [0, 0, 0, 0, 0],
            )
            if q_nodes is None:
                logger.debug("fuse_conformer_attention: failed to match q path")
                return

        reshape_q, add_q, matmul_q = q_nodes[-3], q_nodes[-2], q_nodes[-1]

        # Optional extra branch from the second input of the QK Add back to the
        # Q path. When present, it must end at the same node the Q path starts with.
        extra_q_nodes = self.model.match_parent_path(
            add_qk,
            ["Reshape", "Transpose", "MatMul", "Transpose", "Reshape", "Div"],
            [1, 0, 0, 0, 0, 0],
        )
        if extra_q_nodes is not None and q_nodes[0] != extra_q_nodes[-1]:
            logger.debug("fuse_conformer_attention: failed to match extra q path")
            return

        # ---- K path: Concat (past/present key) variant first, then two fallbacks. ----
        past_k, present_k = "", ""
        k_nodes = self.model.match_parent_path(
            matmul_qk,
            ["Transpose", "Concat", "Transpose", "Reshape", "Add", "MatMul"],
            [1, 0, 1, 0, 0, 1],
        )
        if k_nodes is not None:
            concat_k = k_nodes[1]
            concat_parent = self.model.get_parent(concat_k, 0, None)
            if concat_parent is None:
                # Same guard as for the V path: no producer node for the past key.
                logger.debug("fuse_conformer_attention: failed to find parent of past key concat")
                return
            past_k = concat_parent.output[0]
            present_k = concat_k.output[0]
        else:
            k_nodes = self.model.match_parent_path(
                matmul_qk,
                ["Transpose", "Transpose", "Reshape", "Add", "MatMul"],
                [1, 0, 0, 0, 0],
            )
            if k_nodes is None:
                k_nodes = self.model.match_parent_path(
                    matmul_qk,
                    ["Transpose", "Reshape", "Add", "MatMul"],
                    [1, 0, 0, 0],
                )
            if k_nodes is None:
                logger.debug("fuse_conformer_attention: failed to match k path")
                return

        add_k, matmul_k = k_nodes[-2], k_nodes[-1]

        num_heads, hidden_size = self.get_num_heads_and_hidden_size(reshape_q)
        if num_heads <= 0 or hidden_size <= 0 or (hidden_size % num_heads) != 0:
            logger.debug("fuse_conformer_attention: failed to detect num_heads or hidden_size")
            return

        # The packed Attention op is only used when Q, K and V are all projected
        # from the same tensor and there is no extra Q branch.
        use_packed_attention_op = (
            matmul_q.input[0] == matmul_k.input[0] and matmul_k.input[0] == matmul_v.input[0] and extra_q_nodes is None
        )
        if use_packed_attention_op:
            # Self-attention, use Attention op
            new_node = self.create_attention_node(
                mask_index=attn_mask,
                q_matmul=matmul_q,
                k_matmul=matmul_k,
                v_matmul=matmul_v,
                q_add=add_q,
                k_add=add_k,
                v_add=add_v,
                num_heads=num_heads,
                hidden_size=hidden_size,
                first_input=matmul_q.input[0],
                output=reshape_qkv.output[0],
                add_qk_str=add_qk.input[1],
                past_k=past_k,
                past_v=past_v,
                present_k=present_k,
                present_v=present_v,
            )
        else:
            new_node = self.create_multihead_attention_node(
                q_matmul=matmul_q,
                k_matmul=matmul_k,
                v_matmul=matmul_v,
                q_add=add_q,
                k_add=add_k,
                v_add=add_v,
                num_heads=num_heads,
                hidden_size=hidden_size,
                output=reshape_qkv.output[0],
                key_padding_mask=attn_mask,
                add_qk=add_qk.input[1],
                past_k=past_k,
                past_v=past_v,
                present_k=present_k,
                present_v=present_v,
            )

        if new_node is None:
            logger.debug("fuse_conformer_attention: MultiHeadAttention node creation failed")
            return

        self.nodes_to_add.append(new_node)
        self.node_name_to_graph_name[new_node.name] = self.this_graph_name

        self.nodes_to_remove.extend([reshape_qkv, transpose_qkv, matmul_qkv])
        self.nodes_to_remove.extend(qk_nodes)

        # When using MultiHeadAttention, keep MatMul nodes unfused in original graph
        if not use_packed_attention_op:
            for path_nodes in (q_nodes, k_nodes, v_nodes):
                if path_nodes[-1].op_type == "MatMul":
                    path_nodes.pop()

        if extra_q_nodes is None:
            # Don't remove Q nodes for conformer-transducer (CT) model since it has
            # an extra set of nodes attached to the output of the Q path that are not
            # part of the attention computation
            self.nodes_to_remove.extend(q_nodes)

        self.nodes_to_remove.extend(k_nodes)
        self.nodes_to_remove.extend(v_nodes)

        # Use prune graph to remove mask nodes since they are shared by all attention nodes.
        self.prune_graph = True
|