fusion_conformer_attention.py 8.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222
  1. # -------------------------------------------------------------------------
  2. # Copyright (c) Microsoft Corporation. All rights reserved.
  3. # Licensed under the MIT License.
  4. # --------------------------------------------------------------------------
  5. import logging
  6. from fusion_attention import AttentionMask, FusionAttention
  7. from onnx_model import OnnxModel
  8. logger = logging.getLogger(__name__)
  9. class FusionConformerAttention(FusionAttention):
  10. """
  11. Fuse Conformer Attention subgraph into one MultiHeadAttention node.
  12. """
  13. def __init__(
  14. self,
  15. model: OnnxModel,
  16. hidden_size: int,
  17. num_heads: int,
  18. attention_mask: AttentionMask,
  19. ):
  20. super().__init__(model, hidden_size, num_heads, attention_mask)
  21. def fuse(self, normalize_node, input_name_to_nodes, output_name_to_node):
  22. # SkipLayerNormalization has two inputs, and one of them is the root input for attention.
  23. qkv_nodes = self.model.match_parent_path(
  24. normalize_node,
  25. ["Add", "MatMul", "Reshape", "Transpose", "MatMul"],
  26. [1, None, 0, 0, 0],
  27. )
  28. if qkv_nodes is None:
  29. logger.debug("fuse_conformer_attention: failed to match qkv path")
  30. return
  31. reshape_qkv, transpose_qkv, matmul_qkv = qkv_nodes[-3], qkv_nodes[-2], qkv_nodes[-1]
  32. past_v, present_v = "", ""
  33. v_nodes = self.model.match_parent_path(
  34. matmul_qkv,
  35. ["Concat", "Transpose", "Reshape", "Add", "MatMul"],
  36. [1, 1, 0, 0, 1],
  37. )
  38. if v_nodes is None:
  39. v_nodes = self.model.match_parent_path(
  40. matmul_qkv,
  41. ["Transpose", "Reshape", "Add", "MatMul"],
  42. [1, 0, 0, 0],
  43. )
  44. if v_nodes is None:
  45. logger.debug("fuse_conformer_attention: failed to match v path")
  46. return
  47. else:
  48. concat_v = v_nodes[0]
  49. concat_parent = self.model.get_parent(concat_v, 0, None)
  50. present_v = concat_v.output[0]
  51. past_v = concat_parent.output[0]
  52. add_v, matmul_v = v_nodes[-2], v_nodes[-1]
  53. attn_mask = ""
  54. qk_nodes = self.model.match_parent_path(
  55. matmul_qkv,
  56. ["Softmax", "Add", "MatMul"],
  57. [0, 0, 0],
  58. )
  59. if qk_nodes is None:
  60. qk_nodes = self.model.match_parent_path(
  61. matmul_qkv,
  62. ["Where", "Softmax", "Where", "Add", "MatMul"],
  63. [0, 2, 0, 2, 0],
  64. )
  65. if qk_nodes is None:
  66. logger.debug("fuse_conformer_attention: failed to match qk path")
  67. return
  68. where_qk = qk_nodes[2]
  69. mask_nodes = self.model.match_parent_path(
  70. where_qk,
  71. ["Equal", "Unsqueeze", "Cast"],
  72. [0, 0, 0],
  73. )
  74. if mask_nodes is not None:
  75. attn_mask = mask_nodes[-1].output[0]
  76. add_qk, matmul_qk = qk_nodes[-2], qk_nodes[-1]
  77. q_nodes = self.model.match_parent_path(
  78. matmul_qk,
  79. ["Div", "Transpose", "Reshape", "Add", "MatMul"],
  80. [0, 0, 0, 0, 1],
  81. )
  82. if q_nodes is None:
  83. q_nodes = self.model.match_parent_path(
  84. matmul_qk,
  85. ["Mul", "Transpose", "Reshape", "Add", "MatMul"],
  86. [0, 0, 0, 0, 0],
  87. )
  88. if q_nodes is None:
  89. logger.debug("fuse_conformer_attention: failed to match q path")
  90. return
  91. reshape_q, add_q, matmul_q = q_nodes[-3], q_nodes[-2], q_nodes[-1]
  92. extra_q_nodes = self.model.match_parent_path(
  93. add_qk,
  94. ["Reshape", "Transpose", "MatMul", "Transpose", "Reshape", "Div"],
  95. [1, 0, 0, 0, 0, 0],
  96. )
  97. if extra_q_nodes is not None and q_nodes[0] != extra_q_nodes[-1]:
  98. logger.debug("fuse_conformer_attention: failed to match extra q path")
  99. return
  100. past_k, present_k = "", ""
  101. k_nodes = self.model.match_parent_path(
  102. matmul_qk,
  103. ["Transpose", "Concat", "Transpose", "Reshape", "Add", "MatMul"],
  104. [1, 0, 1, 0, 0, 1],
  105. )
  106. if k_nodes is None:
  107. k_nodes = self.model.match_parent_path(
  108. matmul_qk,
  109. ["Transpose", "Transpose", "Reshape", "Add", "MatMul"],
  110. [1, 0, 0, 0, 0],
  111. )
  112. if k_nodes is None:
  113. k_nodes = self.model.match_parent_path(
  114. matmul_qk,
  115. ["Transpose", "Reshape", "Add", "MatMul"],
  116. [1, 0, 0, 0],
  117. )
  118. if k_nodes is None:
  119. logger.debug("fuse_conformer_attention: failed to match k path")
  120. return
  121. else:
  122. concat_k = k_nodes[1]
  123. concat_parent = self.model.get_parent(concat_k, 0, None)
  124. past_k = concat_parent.output[0]
  125. present_k = concat_k.output[0]
  126. add_k, matmul_k = k_nodes[-2], k_nodes[-1]
  127. num_heads, hidden_size = self.get_num_heads_and_hidden_size(reshape_q)
  128. if num_heads <= 0 or hidden_size <= 0 or (hidden_size % num_heads) != 0:
  129. logger.debug("fuse_conformer_attention: failed to detect num_heads or hidden_size")
  130. return
  131. new_node = None
  132. use_packed_attention_op = (
  133. matmul_q.input[0] == matmul_k.input[0] and matmul_k.input[0] == matmul_v.input[0] and extra_q_nodes is None
  134. )
  135. if use_packed_attention_op:
  136. # Self-attention, use Attention op
  137. new_node = self.create_attention_node(
  138. mask_index=attn_mask,
  139. q_matmul=matmul_q,
  140. k_matmul=matmul_k,
  141. v_matmul=matmul_v,
  142. q_add=add_q,
  143. k_add=add_k,
  144. v_add=add_v,
  145. num_heads=num_heads,
  146. hidden_size=hidden_size,
  147. first_input=matmul_q.input[0],
  148. output=reshape_qkv.output[0],
  149. add_qk_str=add_qk.input[1],
  150. past_k=past_k,
  151. past_v=past_v,
  152. present_k=present_k,
  153. present_v=present_v,
  154. )
  155. else:
  156. new_node = self.create_multihead_attention_node(
  157. q_matmul=matmul_q,
  158. k_matmul=matmul_k,
  159. v_matmul=matmul_v,
  160. q_add=add_q,
  161. k_add=add_k,
  162. v_add=add_v,
  163. num_heads=num_heads,
  164. hidden_size=hidden_size,
  165. output=reshape_qkv.output[0],
  166. key_padding_mask=attn_mask,
  167. add_qk=add_qk.input[1],
  168. past_k=past_k,
  169. past_v=past_v,
  170. present_k=present_k,
  171. present_v=present_v,
  172. )
  173. if new_node is None:
  174. logger.debug("fuse_conformer_attention: MultiHeadAttention node creation failed")
  175. return
  176. self.nodes_to_add.append(new_node)
  177. self.node_name_to_graph_name[new_node.name] = self.this_graph_name
  178. self.nodes_to_remove.extend([reshape_qkv, transpose_qkv, matmul_qkv])
  179. self.nodes_to_remove.extend(qk_nodes)
  180. # When using MultiHeadAttention, keep MatMul nodes unfused in original graph
  181. if not use_packed_attention_op:
  182. if q_nodes[-1].op_type == "MatMul":
  183. q_nodes.pop()
  184. if k_nodes[-1].op_type == "MatMul":
  185. k_nodes.pop()
  186. if v_nodes[-1].op_type == "MatMul":
  187. v_nodes.pop()
  188. if extra_q_nodes is None:
  189. # Don't remove Q nodes for conformer-transducer (CT) model since it has
  190. # an extra set of nodes attached to the output of the Q path that are not
  191. # part of the attention computation
  192. self.nodes_to_remove.extend(q_nodes)
  193. self.nodes_to_remove.extend(k_nodes)
  194. self.nodes_to_remove.extend(v_nodes)
  195. # Use prune graph to remove mask nodes since they are shared by all attention nodes.
  196. self.prune_graph = True