fusion_qordered_attention.py 16 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420
  1. # -------------------------------------------------------------------------
  2. # Copyright (c) Microsoft Corporation. All rights reserved.
  3. # Licensed under the MIT License.
  4. # --------------------------------------------------------------------------
  5. from logging import getLogger
  6. import numpy as np
  7. from fusion_attention import AttentionMask
  8. from fusion_base import Fusion
  9. from fusion_utils import FusionUtils, NumpyHelper
  10. from onnx import NodeProto, helper
  11. from onnx_model import OnnxModel
  12. logger = getLogger(__name__)
  13. class FusionQOrderedAttention(Fusion):
  14. def __init__(
  15. self,
  16. model: OnnxModel,
  17. hidden_size: int,
  18. num_heads: int,
  19. attention_mask: AttentionMask,
  20. ):
  21. self.hidden_size = hidden_size
  22. self.num_heads = num_heads
  23. self.attention_mask = attention_mask
  24. super().__init__(model, "QOrderedAttention", "QOrderedLayerNormalization")
  25. def get_num_heads_and_hidden_size(self, reshape_q: NodeProto) -> tuple[int, int]:
  26. """Detect num_heads and hidden_size from a reshape node.
  27. Args:
  28. reshape_q (NodeProto): reshape node for Q
  29. Returns:
  30. Tuple[int, int]: num_heads and hidden_size
  31. """
  32. # we assume that reshape fusion has done, so the shape is a tensor like [0, 0, num_heads, head_size]
  33. q_shape = self.model.get_initializer(reshape_q.input[1])
  34. if q_shape is None:
  35. logger.debug(f"{reshape_q.input[1]} is not initializer.")
  36. # Check if the second input to Reshape flows through a Constant node
  37. # TODO: Investigate why FusionAttention doesn't have such logic
  38. constant_node = self.model.match_parent_path(reshape_q, ["Constant"], [1])
  39. if constant_node is None:
  40. return self.num_heads, self.hidden_size # Fall back to user specified value
  41. else:
  42. constant_node = constant_node[0]
  43. if len(constant_node.attribute) != 1:
  44. return self.num_heads, self.hidden_size # Fall back to user specified value
  45. # This is assuming it is a Tensor attribute (this is a safe assumption)
  46. q_shape = constant_node.attribute[0].t
  47. q_shape_value = NumpyHelper.to_array(q_shape)
  48. if len(q_shape_value) != 4 or (q_shape_value[2] <= 0 or q_shape_value[3] <= 0):
  49. logger.debug(f"q_shape_value={q_shape_value}. Expected value are like [0, 0, num_heads, head_size].")
  50. return self.num_heads, self.hidden_size # Fall back to user specified value
  51. num_heads = q_shape_value[2]
  52. head_size = q_shape_value[3]
  53. hidden_size = num_heads * head_size
  54. if self.num_heads > 0 and num_heads != self.num_heads:
  55. if self.num_heads_warning:
  56. logger.warning(f"--num_heads is {self.num_heads}. Detected value is {num_heads}. Using detected value.")
  57. self.num_heads_warning = False # Do not show the warning more than once
  58. if self.hidden_size > 0 and hidden_size != self.hidden_size:
  59. if self.hidden_size_warning:
  60. logger.warning(
  61. f"--hidden_size is {self.hidden_size}. Detected value is {hidden_size}. Using detected value."
  62. )
  63. self.hidden_size_warning = False # Do not show the warning more than once
  64. return num_heads, hidden_size
  65. def fuse(self, normalize_node, input_name_to_nodes, output_name_to_node):
  66. add_before_layernorm = self.model.match_parent_path(
  67. normalize_node,
  68. ["QuantizeLinear", "Add"],
  69. [0, 0],
  70. )
  71. if add_before_layernorm is not None:
  72. start_node = add_before_layernorm[-1]
  73. else:
  74. return
  75. # Input QDQ nodes
  76. dequantize_input = self.model.match_parent_path(
  77. start_node,
  78. ["DequantizeLinear"],
  79. [None],
  80. )
  81. if dequantize_input is None:
  82. logger.debug("fuse_qordered_attention: failed to match input qdq nodes path")
  83. return
  84. dequantize_input = dequantize_input[-1]
  85. # QKV nodes
  86. qkv_nodes = self.model.match_parent_path(
  87. start_node,
  88. ["Add", "MatMul", "Reshape", "Transpose", "DequantizeLinear", "QuantizeLinear", "MatMul"],
  89. [None, None, 0, 0, 0, 0, 0],
  90. )
  91. if qkv_nodes is None:
  92. logger.debug("fuse_qordered_attention: failed to match qkv path")
  93. return
  94. (_, projection_matmul, reshape_qkv, transpose_qkv, dequantize_qkv, quantize_qkv, matmul_qkv) = qkv_nodes
  95. # Make sure the Q/DQ has the proper zero points and constant per-tensor scales
  96. if not FusionUtils.check_qdq_node_for_fusion(quantize_qkv, self.model):
  97. return
  98. if not FusionUtils.check_qdq_node_for_fusion(dequantize_qkv, self.model):
  99. return
  100. # Identify the root input to the Attention node
  101. other_inputs = []
  102. for _i, input in enumerate(start_node.input):
  103. if input not in output_name_to_node:
  104. continue
  105. if input == qkv_nodes[0].output[0]:
  106. continue
  107. other_inputs.append(input)
  108. if len(other_inputs) != 1:
  109. return
  110. root_input = other_inputs[0]
  111. # V nodes
  112. v_nodes = self.model.match_parent_path(
  113. matmul_qkv,
  114. ["Transpose", "Reshape", "DequantizeLinear", "QuantizeLinear", "Add", "MatMul"],
  115. [1, 0, 0, 0, 0, None],
  116. )
  117. if v_nodes is None:
  118. logger.debug("fuse_qordered_attention: failed to match v path")
  119. return
  120. (_, _, dequantize_v, quantize_v, add_v, matmul_v) = v_nodes
  121. # Make sure the Q/DQ has the proper zero points and constant per-tensor scales
  122. if not FusionUtils.check_qdq_node_for_fusion(quantize_v, self.model):
  123. return
  124. if not FusionUtils.check_qdq_node_for_fusion(dequantize_v, self.model):
  125. return
  126. # V MatMul weight
  127. dequantize_v_matmul_weight = self.model.match_parent_path(matmul_v, ["DequantizeLinear"], [1])
  128. if dequantize_v_matmul_weight is None:
  129. logger.debug("fuse_qordered_attention: failed to match v path")
  130. return
  131. dequantize_v_matmul_weight = dequantize_v_matmul_weight[0]
  132. if self.model.get_constant_value(dequantize_v_matmul_weight.input[0]) is None:
  133. return
  134. # Make sure the upstream DequantizeLinear-1 has the proper zero points and scales
  135. # Per-channel scales are supported for weights alone
  136. if not FusionUtils.check_qdq_node_for_fusion(dequantize_v_matmul_weight, self.model, False):
  137. return
  138. # QK nodes
  139. qk_nodes = self.model.match_parent_path(
  140. matmul_qkv,
  141. [
  142. "DequantizeLinear",
  143. "QuantizeLinear",
  144. "Softmax",
  145. "Add",
  146. "Div",
  147. "DequantizeLinear",
  148. "QuantizeLinear",
  149. "MatMul",
  150. ],
  151. [0, 0, 0, 0, None, 0, 0, 0],
  152. )
  153. if qk_nodes is None:
  154. logger.debug("fuse_qordered_attention: failed to match qk path")
  155. return
  156. (
  157. dequantize_qk_softmax,
  158. quantize_qk_softmax,
  159. softmax_qk,
  160. add_qk,
  161. div_qk,
  162. dequantize_qk,
  163. quantize_qk,
  164. matmul_qk,
  165. ) = qk_nodes
  166. # Make sure the Q/DQ has the proper zero points and constant per-tensor scales
  167. if not FusionUtils.check_qdq_node_for_fusion(quantize_qk_softmax, self.model):
  168. return
  169. if not FusionUtils.check_qdq_node_for_fusion(dequantize_qk_softmax, self.model):
  170. return
  171. if not FusionUtils.check_qdq_node_for_fusion(quantize_qk, self.model):
  172. return
  173. if not FusionUtils.check_qdq_node_for_fusion(dequantize_qk, self.model):
  174. return
  175. # Q nodes
  176. q_nodes = self.model.match_parent_path(
  177. matmul_qk,
  178. ["Transpose", "Reshape", "DequantizeLinear", "QuantizeLinear", "Add", "MatMul"],
  179. [0, 0, 0, 0, 0, None],
  180. )
  181. if q_nodes is None:
  182. logger.debug("fuse_qordered_attention: failed to match q path")
  183. return
  184. (_, reshape_q, dequantize_q, quantize_q, add_q, matmul_q) = q_nodes
  185. # Make sure the Q/DQ has the proper zero points and constant per-tensor scales
  186. if not FusionUtils.check_qdq_node_for_fusion(quantize_q, self.model):
  187. return
  188. if not FusionUtils.check_qdq_node_for_fusion(dequantize_q, self.model):
  189. return
  190. # Q MatMul weight
  191. dequantize_q_matmul_weight = self.model.match_parent_path(matmul_q, ["DequantizeLinear"], [1])
  192. if dequantize_q_matmul_weight is None:
  193. logger.debug("fuse_qordered_attention: failed to match q path")
  194. return
  195. dequantize_q_matmul_weight = dequantize_q_matmul_weight[0]
  196. if self.model.get_constant_value(dequantize_q_matmul_weight.input[0]) is None:
  197. return
  198. # Make sure the upstream DequantizeLinear-1 has the proper zero points and scales
  199. # Per-channel scales are supported for weights alone
  200. if not FusionUtils.check_qdq_node_for_fusion(dequantize_q_matmul_weight, self.model, False):
  201. return
  202. # K nodes
  203. k_nodes = self.model.match_parent_path(
  204. matmul_qk,
  205. ["Transpose", "Reshape", "DequantizeLinear", "QuantizeLinear", "Add", "MatMul"],
  206. [1, 0, 0, 0, 0, None],
  207. )
  208. if k_nodes is None:
  209. logger.debug("fuse_qordered_attention: failed to match k path")
  210. return
  211. (_, _, dequantize_k, quantize_k, add_k, matmul_k) = k_nodes
  212. # Make sure the Q/DQ has the proper zero points and constant per-tensor scales
  213. if not FusionUtils.check_qdq_node_for_fusion(quantize_k, self.model):
  214. return
  215. if not FusionUtils.check_qdq_node_for_fusion(dequantize_k, self.model):
  216. return
  217. # K MatMul weight
  218. dequantize_k_matmul_weight = self.model.match_parent_path(matmul_k, ["DequantizeLinear"], [1])
  219. if dequantize_k_matmul_weight is None:
  220. logger.debug("fuse_qordered_attention: failed to match k path")
  221. return
  222. dequantize_k_matmul_weight = dequantize_k_matmul_weight[0]
  223. if self.model.get_constant_value(dequantize_k_matmul_weight.input[0]) is None:
  224. return
  225. # Make sure the upstream DequantizeLinear-1 has the proper zero points and scales
  226. # Per-channel scales are supported for weights alone
  227. if not FusionUtils.check_qdq_node_for_fusion(dequantize_k_matmul_weight, self.model, False):
  228. return
  229. # Mask nodes
  230. mask_nodes = self.model.match_parent_path(
  231. add_qk, ["Mul", "Sub", "Cast", "Unsqueeze", "Unsqueeze"], [None, 0, 1, 0, 0]
  232. )
  233. if mask_nodes is None:
  234. logger.debug("fuse_qordered_attention: failed to match mask_nodes path")
  235. return
  236. # Ascertain `qkv_hidden_sizes` attribute value
  237. q_weight = self.model.get_initializer(dequantize_q_matmul_weight.input[0])
  238. k_weight = self.model.get_initializer(dequantize_k_matmul_weight.input[0])
  239. v_weight = self.model.get_initializer(dequantize_v_matmul_weight.input[0])
  240. qw = NumpyHelper.to_array(q_weight)
  241. kw = NumpyHelper.to_array(k_weight)
  242. vw = NumpyHelper.to_array(v_weight)
  243. qw_out_size = np.prod(qw.shape[1:])
  244. kw_out_size = np.prod(kw.shape[1:])
  245. vw_out_size = np.prod(vw.shape[1:])
  246. # Form QOrderedAttention node
  247. if matmul_v.input[0] == root_input and matmul_q.input[0] == root_input and matmul_k.input[0] == root_input:
  248. mask_index = self.attention_mask.process_mask(mask_nodes[-1].input[0])
  249. # Ascertain `num_heads` and `hidden_size`
  250. num_heads, hidden_size = self.get_num_heads_and_hidden_size(reshape_q)
  251. # Formulate the inputs
  252. # Actual quantized input
  253. attention_inputs = [dequantize_input.input[0]]
  254. attention_inputs.append(dequantize_input.input[1])
  255. attention_inputs.append(dequantize_q.input[1])
  256. attention_inputs.append(dequantize_k.input[1])
  257. attention_inputs.append(dequantize_v.input[1])
  258. attention_inputs.append(dequantize_q_matmul_weight.input[0])
  259. attention_inputs.append(dequantize_k_matmul_weight.input[0])
  260. attention_inputs.append(dequantize_v_matmul_weight.input[0])
  261. attention_inputs.append(dequantize_q_matmul_weight.input[1])
  262. attention_inputs.append(dequantize_k_matmul_weight.input[1])
  263. attention_inputs.append(dequantize_v_matmul_weight.input[1])
  264. if self.model.get_initializer(add_q.input[0]):
  265. attention_inputs.append(add_q.input[0])
  266. else: # second input is the constant bias
  267. attention_inputs.append(add_q.input[1])
  268. if self.model.get_initializer(add_k.input[0]):
  269. attention_inputs.append(add_k.input[0])
  270. else: # second input is the constant bias
  271. attention_inputs.append(add_k.input[1])
  272. if self.model.get_initializer(add_v.input[0]):
  273. attention_inputs.append(add_v.input[0])
  274. else: # second input is the constant bias
  275. attention_inputs.append(add_v.input[1])
  276. attention_inputs.append(quantize_qk.input[1])
  277. attention_inputs.append(quantize_qk_softmax.input[1])
  278. attention_inputs.append(dequantize_qkv.input[1])
  279. # Mask input
  280. if mask_index is not None:
  281. attention_inputs.append(mask_index)
  282. else:
  283. attention_inputs.append("")
  284. # The MatMul weight 'B' and 'bias' need some post-processing
  285. # Transpose weight 'B' from order ROW to order COL
  286. # This offline transpose is needed only while using the CUDA EP
  287. # TODO: Make this fusion logic EP-agnostic ?
  288. q_weight_tensor = self.model.get_initializer(dequantize_q_matmul_weight.input[0])
  289. FusionUtils.transpose_2d_int8_tensor(q_weight_tensor)
  290. k_weight_tensor = self.model.get_initializer(dequantize_k_matmul_weight.input[0])
  291. FusionUtils.transpose_2d_int8_tensor(k_weight_tensor)
  292. v_weight_tensor = self.model.get_initializer(dequantize_v_matmul_weight.input[0])
  293. FusionUtils.transpose_2d_int8_tensor(v_weight_tensor)
  294. # Name and create Attention node
  295. attention_node_name = self.model.create_node_name("QOrderedAttention")
  296. attention_node = helper.make_node(
  297. "QOrderedAttention",
  298. inputs=attention_inputs,
  299. outputs=[reshape_qkv.output[0]],
  300. name=attention_node_name,
  301. )
  302. self.model.replace_node_input(dequantize_qkv, dequantize_qkv.input[0], attention_node.output[0])
  303. self.model.replace_node_input(projection_matmul, projection_matmul.input[0], dequantize_qkv.output[0])
  304. attention_node.attribute.extend([helper.make_attribute("num_heads", num_heads)])
  305. attention_node.attribute.extend([helper.make_attribute("order_input", 1)])
  306. attention_node.attribute.extend([helper.make_attribute("order_weight", 0)])
  307. attention_node.attribute.extend([helper.make_attribute("order_output", 1)])
  308. attention_node.attribute.extend(
  309. [helper.make_attribute("qkv_hidden_sizes", [qw_out_size, kw_out_size, vw_out_size])]
  310. )
  311. attention_node.domain = "com.microsoft"
  312. self.nodes_to_add.append(attention_node)
  313. self.node_name_to_graph_name[attention_node.name] = self.this_graph_name
  314. self.nodes_to_remove.extend([reshape_qkv, transpose_qkv, quantize_qkv, matmul_qkv])
  315. self.nodes_to_remove.extend(qk_nodes)
  316. self.nodes_to_remove.extend(q_nodes)
  317. self.nodes_to_remove.extend(k_nodes)
  318. self.nodes_to_remove.extend(v_nodes)
  319. self.nodes_to_remove.extend(
  320. [dequantize_q_matmul_weight, dequantize_k_matmul_weight, dequantize_v_matmul_weight]
  321. )
  322. # Use prune graph to remove mask nodes since they are shared by all attention nodes.
  323. # self.nodes_to_remove.extend(mask_nodes)
  324. self.prune_graph = True