fusion_gpt_attention_megatron.py 13 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355
  1. # -------------------------------------------------------------------------
  2. # Copyright (c) Microsoft Corporation. All rights reserved.
  3. # Licensed under the MIT License.
  4. # --------------------------------------------------------------------------
  5. from logging import getLogger
  6. import numpy as np
  7. from fusion_gpt_attention import FusionGptAttentionPastBase
  8. from onnx import helper
  9. from onnx_model import OnnxModel
  10. logger = getLogger(__name__)
  11. def is_close(value, expected_value):
  12. return abs(value - expected_value) <= 1e-6
  13. class FusionGptAttentionMegatron(FusionGptAttentionPastBase):
  14. """
  15. Fuse GPT-2 Attention with past state subgraph from Megatron into one Attention node.
  16. """
  17. def __init__(self, model: OnnxModel, num_heads: int):
  18. super().__init__(model, num_heads)
  19. def fuse_attention_node(
  20. self,
  21. matmul_before_split,
  22. add_before_split,
  23. past,
  24. present,
  25. input,
  26. reshape_qkv,
  27. mask,
  28. ):
  29. attention_node_name = self.model.create_node_name("GptAttention")
  30. int32_mask = self.cast_attention_mask(mask)
  31. output = reshape_qkv.output[0]
  32. i = 1 if (add_before_split.input[0] == matmul_before_split.output[0]) else 0
  33. attention_node = helper.make_node(
  34. "Attention",
  35. inputs=[
  36. input,
  37. matmul_before_split.input[1],
  38. add_before_split.input[i],
  39. int32_mask,
  40. past,
  41. ],
  42. outputs=[output, present],
  43. name=attention_node_name,
  44. )
  45. attention_node.domain = "com.microsoft"
  46. attention_node.attribute.extend(
  47. [
  48. helper.make_attribute("num_heads", self.num_heads),
  49. helper.make_attribute("unidirectional", 0), # unidirectional shall not be ON for 4D attention mask
  50. ]
  51. )
  52. if self.mask_filter_value is not None:
  53. attention_node.attribute.extend([helper.make_attribute("mask_filter_value", float(self.mask_filter_value))])
  54. nodes_to_add = [attention_node]
  55. self.nodes_to_add.extend(nodes_to_add)
  56. for node in nodes_to_add:
  57. self.node_name_to_graph_name[node.name] = self.this_graph_name
  58. self.nodes_to_remove.append(reshape_qkv)
  59. # we rely on prune_graph() to clean old subgraph nodes
  60. self.prune_graph = True
  61. def match_mask(self, sub_qk, mul_qk, matmul_qk, layernorm_before_attention):
  62. mask_nodes = self.model.match_parent_path(sub_qk, ["Mul", "Sub", "Slice", "Slice"], [1, 0, 1, 0])
  63. if mask_nodes is None:
  64. logger.debug("fuse_attention: failed to match unidirectional mask path")
  65. return None
  66. (mul_mask, sub_mask, last_slice_mask, slice_mask) = mask_nodes
  67. if len(mask_nodes) > 1 and mask_nodes[0].op_type == "Mul":
  68. _, mul_val = self.model.get_constant_input(mask_nodes[0])
  69. if mul_val != 10000:
  70. self.mask_filter_value = -mul_val
  71. if mul_qk.input[1] != last_slice_mask.output[0]:
  72. logger.debug("fuse_attention failed: mul_qk.input[1] != last_slice_mask.output[0]")
  73. return None
  74. if not self.utils.check_node_input_value(mul_mask, 1, 10000.0):
  75. logger.debug("fuse_attention failed: mul_mask input 1 is not constant 10000.0")
  76. return None
  77. if not self.utils.check_node_input_value(sub_mask, 0, 1.0):
  78. logger.debug("fuse_attention failed: sub_mask input 0 is not constant 1.0")
  79. return None
  80. if not self.model.find_graph_input(slice_mask.input[0]):
  81. logger.info("expect slick_mask input 0 to be graph input")
  82. return None
  83. if not self.utils.check_node_input_value(last_slice_mask, 1, [0]):
  84. logger.debug("fuse_attention failed: last_slice_mask input 1 (starts) is not constant [0]")
  85. return None
  86. if not self.utils.check_node_input_value(last_slice_mask, 3, [3]):
  87. logger.debug("fuse_attention failed: last_slice_mask input 3 (axes) is not constant [3]")
  88. return False
  89. if not self.utils.check_node_input_value(last_slice_mask, 4, [1]):
  90. logger.debug("fuse_attention failed: last_slice_mask input 4 (steps) is not constant [1]")
  91. return False
  92. if not self.utils.check_node_input_value(slice_mask, 3, [2]):
  93. logger.debug("fuse_attention failed: slice_mask input 3 (axes) is not constant [2]")
  94. return None
  95. if not self.utils.check_node_input_value(slice_mask, 4, [1]):
  96. logger.debug("fuse_attention failed: slice_mask input 4 (steps) is not constant [1]")
  97. return None
  98. last_slice_path = self.model.match_parent_path(
  99. last_slice_mask, ["Unsqueeze", "Gather", "Shape", "MatMul"], [2, 0, 0, 0]
  100. )
  101. if last_slice_path is None or last_slice_path[-1] != matmul_qk:
  102. logger.debug("fuse_attention: failed to match last slice path")
  103. return None
  104. first_slice_path = self.model.match_parent_path(
  105. slice_mask, ["Unsqueeze", "Gather", "Shape", "MatMul"], [2, 0, 0, 0]
  106. )
  107. if first_slice_path is None or first_slice_path[-1] != matmul_qk:
  108. logger.debug("fuse_attention: failed to match first slice path")
  109. return None
  110. first_slice_sub = self.model.match_parent_path(
  111. slice_mask,
  112. ["Unsqueeze", "Sub", "Gather", "Shape", "MatMul"],
  113. [1, 0, 0, 0, 0],
  114. )
  115. if first_slice_sub is None or first_slice_sub[-1] != matmul_qk:
  116. logger.debug("fuse_attention: failed to match last slice sub path")
  117. return None
  118. first_slice_sub_1 = self.model.match_parent_path(
  119. slice_mask,
  120. ["Unsqueeze", "Sub", "Gather", "Shape", "LayerNormalization"],
  121. [1, 0, 1, 0, 0],
  122. )
  123. if first_slice_sub_1 is None:
  124. first_slice_sub_1 = self.model.match_parent_path(
  125. slice_mask,
  126. ["Unsqueeze", "Sub", "Gather", "Shape", "SkipLayerNormalization"],
  127. [1, 0, 1, 0, 0],
  128. )
  129. if first_slice_sub_1 is None or first_slice_sub_1[-1] != layernorm_before_attention:
  130. logger.debug("fuse_attention: failed to match last slice sub path 1")
  131. return None
  132. return slice_mask.input[0]
  133. def fuse(self, normalize_node, input_name_to_nodes, output_name_to_node):
  134. past = None
  135. present = None
  136. is_normalize_node_skiplayernorm = normalize_node.op_type == "SkipLayerNormalization"
  137. qkv_nodes = None
  138. if not is_normalize_node_skiplayernorm:
  139. qkv_nodes = self.model.match_parent_path(
  140. normalize_node,
  141. ["Add", "Add", "MatMul", "Reshape", "Transpose", "MatMul"],
  142. [0, 1, None, 0, 0, 0],
  143. output_name_to_node=output_name_to_node,
  144. )
  145. else:
  146. qkv_nodes = self.model.match_parent_path(
  147. normalize_node,
  148. ["Add", "MatMul", "Reshape", "Transpose", "MatMul"],
  149. [1, None, 0, 0, 0],
  150. output_name_to_node=output_name_to_node,
  151. )
  152. if qkv_nodes is None:
  153. return
  154. skip_input = None
  155. if not is_normalize_node_skiplayernorm:
  156. (
  157. add_skip,
  158. add_after_attention,
  159. matmul_after_attention,
  160. reshape_qkv,
  161. transpose_qkv,
  162. matmul_qkv,
  163. ) = qkv_nodes
  164. skip_input = add_skip.input[0]
  165. else:
  166. (
  167. add_after_attention,
  168. matmul_after_attention,
  169. reshape_qkv,
  170. transpose_qkv,
  171. matmul_qkv,
  172. ) = qkv_nodes
  173. skip_input = normalize_node.input[0]
  174. v_nodes = self.model.match_parent_path(
  175. matmul_qkv,
  176. [
  177. "Concat",
  178. "Transpose",
  179. "Reshape",
  180. "Split",
  181. "Add",
  182. "MatMul",
  183. "LayerNormalization",
  184. ],
  185. [1, 1, 0, 0, 0, None, 0],
  186. )
  187. if v_nodes is None:
  188. v_nodes = self.model.match_parent_path(
  189. matmul_qkv,
  190. [
  191. "Concat",
  192. "Transpose",
  193. "Reshape",
  194. "Split",
  195. "Add",
  196. "MatMul",
  197. "SkipLayerNormalization",
  198. ],
  199. [1, 1, 0, 0, 0, None, 0],
  200. )
  201. if v_nodes is None:
  202. logger.debug("fuse_attention: failed to match v path")
  203. return
  204. (
  205. concat_v,
  206. transpose_v,
  207. reshape_v,
  208. split_v,
  209. add_before_split,
  210. matmul_before_split,
  211. layernorm_before_attention,
  212. ) = v_nodes
  213. if (
  214. layernorm_before_attention.op_type == "LayerNormalization"
  215. and skip_input != layernorm_before_attention.input[0]
  216. ):
  217. logger.debug("fuse_attention: skip_input != layernorm_before_attention.input[0]")
  218. return
  219. if (
  220. layernorm_before_attention.op_type == "SkipLayerNormalization"
  221. and skip_input != layernorm_before_attention.output[3]
  222. ):
  223. logger.debug("fuse_attention: skip_input != layernorm_before_attention.input[0]")
  224. return
  225. qk_nodes = self.model.match_parent_path(matmul_qkv, ["Softmax", "Sub", "Mul", "MatMul"], [0, 0, 0, 0])
  226. if qk_nodes is None:
  227. logger.debug("fuse_attention: failed to match qk path")
  228. return None
  229. (softmax_qk, sub_qk, mul_qk, matmul_qk) = qk_nodes
  230. if self.model.get_node_attribute(softmax_qk, "axis") != 3:
  231. logger.debug("fuse_attention failed: softmax_qk axis != 3")
  232. return None
  233. attention_mask = self.match_mask(sub_qk, mul_qk, matmul_qk, layernorm_before_attention)
  234. q_nodes = self.model.match_parent_path(matmul_qk, ["Div", "Transpose", "Reshape", "Split"], [0, 0, 0, 0])
  235. if q_nodes is None:
  236. logger.debug("fuse_attention: failed to match q path")
  237. return
  238. (div_q, transpose_q, reshape_q, split_q) = q_nodes
  239. if split_v != split_q:
  240. logger.debug("fuse_attention: skip since split_v != split_q")
  241. return
  242. k_nodes = self.model.match_parent_path(
  243. matmul_qk,
  244. ["Div", "Transpose", "Concat", "Transpose", "Reshape", "Split"],
  245. [1, 0, 0, 1, 0, 0],
  246. )
  247. if k_nodes is None:
  248. logger.debug("fuse_attention: failed to match k path")
  249. return
  250. (div_k, _, concat_k, transpose_k, reshape_k, split_k) = k_nodes
  251. if split_v != split_k:
  252. logger.debug("fuse_attention: skip since split_v != split_k")
  253. return
  254. i, value = self.model.get_constant_input(reshape_k)
  255. if not (
  256. isinstance(value, np.ndarray)
  257. and list(value.shape) == [4]
  258. and value[0] == 0
  259. and value[1] == 0
  260. and value[2] > 0
  261. and value[3] > 0
  262. ):
  263. logger.debug("fuse_attention: reshape constant input is not [0, 0, N, H]")
  264. return
  265. num_heads = value[2]
  266. if num_heads != self.num_heads:
  267. logger.info(f"Detected num_heads={num_heads}. Ignore user specified value {self.num_heads}")
  268. self.num_heads = num_heads
  269. hidden_size_per_head = value[3]
  270. i, value = self.model.get_constant_input(div_k)
  271. expected_value = float(np.sqrt(np.sqrt(hidden_size_per_head)))
  272. if not is_close(value, expected_value):
  273. logger.debug(f"fuse_attention: div_k value={value} expected={expected_value}")
  274. return
  275. i, value = self.model.get_constant_input(div_q)
  276. if not is_close(value, expected_value):
  277. logger.debug(f"fuse_attention: div_q value={value} expected={expected_value}")
  278. return
  279. # Match past and present paths
  280. past = self.match_past_pattern_2(concat_k, concat_v, output_name_to_node)
  281. if past is None:
  282. logger.debug("fuse_attention: match past failed")
  283. return
  284. if not self.model.find_graph_input(past):
  285. logger.debug("fuse_attention: past is not graph input.")
  286. # For GPT2LMHeadModel_BeamSearchStep, there is an extra Gather node to select beam index so it is not graph input.
  287. present = self.match_present(concat_v, input_name_to_nodes)
  288. if present is None:
  289. logger.debug("fuse_attention: match present failed")
  290. return
  291. if not self.model.find_graph_output(present):
  292. logger.info("fuse_attention: expect present to be graph output")
  293. return
  294. self.fuse_attention_node(
  295. matmul_before_split,
  296. add_before_split,
  297. past,
  298. present,
  299. layernorm_before_attention.output[0],
  300. reshape_qkv,
  301. attention_mask,
  302. )