# onnx_model_bert_keras.py (18 KB)
# -------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.
# --------------------------------------------------------------------------
  5. import logging
  6. import onnx
  7. from onnx import numpy_helper
  8. from onnx_model_bert_tf import BertOnnxModelTF
  9. logger = logging.getLogger(__name__)
  10. class BertOnnxModelKeras(BertOnnxModelTF):
  11. def __init__(self, model, num_heads, hidden_size):
  12. super().__init__(model, num_heads, hidden_size)
  13. def match_mask_path(self, add_or_sub_before_softmax):
  14. mask_nodes = self.match_parent_path(
  15. add_or_sub_before_softmax,
  16. ["Mul", "Sub", "Reshape", "Cast"],
  17. [1, None, 1, 0],
  18. )
  19. if mask_nodes is not None:
  20. return mask_nodes
  21. mask_nodes = self.match_parent_path(
  22. add_or_sub_before_softmax,
  23. ["Mul", "Sub", "Cast", "Slice", "Unsqueeze"],
  24. [1, 1, 1, 0, 0],
  25. )
  26. if mask_nodes is not None:
  27. return mask_nodes
  28. mask_nodes = self.match_parent_path(
  29. add_or_sub_before_softmax,
  30. ["Mul", "Sub", "Cast", "Unsqueeze", "Unsqueeze"],
  31. [1, None, 1, 0, 0],
  32. )
  33. return mask_nodes
  34. def check_attention_input(self, matmul_q, matmul_k, matmul_v, parent, output_name_to_node):
  35. reshape_nodes = []
  36. for x in [matmul_q, matmul_k, matmul_v]:
  37. root_input = x.input[0]
  38. root_node = output_name_to_node[root_input]
  39. if root_node == parent:
  40. continue
  41. if root_node.op_type == "Reshape" and root_node.input[0] == parent.output[0]:
  42. reshape_nodes.append(root_node)
  43. continue
  44. logger.debug(f"Check attention input failed:{root_input}, {parent.output[0]}")
  45. return False, []
  46. return True, reshape_nodes
  47. def fuse_attention(self):
  48. self.input_name_to_nodes()
  49. output_name_to_node = self.output_name_to_node()
  50. nodes_to_remove = []
  51. attention_count = 0
  52. skip_layer_norm_nodes = self.get_nodes_by_op_type("SkipLayerNormalization")
  53. for normalize_node in skip_layer_norm_nodes:
  54. # SkipLayerNormalization has two inputs, and one of them is the root input for attention.
  55. parent = self.get_parent(normalize_node, 0)
  56. if parent is None or parent.op_type not in [
  57. "SkipLayerNormalization",
  58. "EmbedLayerNormalization",
  59. ]:
  60. if parent.op_type == "Add":
  61. parent = self.get_parent(normalize_node, 1)
  62. if parent is None or parent.op_type not in [
  63. "SkipLayerNormalization",
  64. "EmbedLayerNormalization",
  65. ]:
  66. logger.debug(f"First input for skiplayernorm: {parent.op_type if parent is not None else None}")
  67. continue
  68. else:
  69. logger.debug(f"First input for skiplayernorm: {parent.op_type if parent is not None else None}")
  70. continue
  71. else:
  72. # TODO: shall we add back the checking of children op types.
  73. pass
  74. qkv_nodes = self.match_parent_path(
  75. normalize_node,
  76. ["Add", "Reshape", "MatMul", "Reshape", "Transpose", "MatMul"],
  77. [None, 0, 0, 0, 0, 0],
  78. )
  79. if qkv_nodes is None:
  80. logger.debug("Failed to match qkv nodes")
  81. continue
  82. (
  83. add,
  84. extra_reshape_0,
  85. matmul,
  86. reshape_qkv,
  87. transpose_qkv,
  88. matmul_qkv,
  89. ) = qkv_nodes
  90. logger.debug("Matched qkv nodes")
  91. v_nodes = self.match_parent_path(
  92. matmul_qkv,
  93. ["Transpose", "Reshape", "Add", "Reshape", "MatMul"],
  94. [1, 0, 0, 0, 0],
  95. )
  96. if v_nodes is None:
  97. logger.debug("Failed to match v path")
  98. continue
  99. (transpose_v, reshape_v, add_v, extra_reshape_1, matmul_v) = v_nodes
  100. qk_nodes = self.match_parent_path(matmul_qkv, ["Softmax", "Sub", "MatMul"], [0, 0, 0])
  101. if qk_nodes is not None:
  102. (softmax_qk, sub_qk, matmul_qk) = qk_nodes
  103. q_nodes = self.match_parent_path(
  104. matmul_qk,
  105. ["Mul", "Transpose", "Reshape", "Add", "Reshape", "MatMul"],
  106. [0, None, 0, 0, 0, 0],
  107. )
  108. if q_nodes is not None:
  109. (
  110. mul_q,
  111. transpose_q,
  112. reshape_q,
  113. add_q,
  114. extra_reshape_2,
  115. matmul_q,
  116. ) = q_nodes
  117. else:
  118. qk_nodes = self.match_parent_path(matmul_qkv, ["Softmax", "Add", "Mul", "MatMul"], [0, 0, 0, None])
  119. if qk_nodes is None:
  120. qk_nodes = self.match_parent_path(matmul_qkv, ["Softmax", "Add", "Div", "MatMul"], [0, 0, 0, None])
  121. if qk_nodes is None:
  122. logger.debug("Failed to match qk path")
  123. continue
  124. (softmax_qk, add_qk, mul_qk, matmul_qk) = qk_nodes
  125. q_nodes = self.match_parent_path(
  126. matmul_qk,
  127. ["Transpose", "Reshape", "Add", "Reshape", "MatMul"],
  128. [0, 0, 0, 0, 0],
  129. )
  130. if q_nodes is not None:
  131. (transpose_q, reshape_q, add_q, extra_reshape_2, matmul_q) = q_nodes
  132. if q_nodes is None:
  133. logger.debug("Failed to match q path")
  134. continue
  135. k_nodes = self.match_parent_path(
  136. matmul_qk,
  137. ["Transpose", "Reshape", "Add", "Reshape", "MatMul"],
  138. [1, 0, 0, 0, 0],
  139. )
  140. if k_nodes is None:
  141. logger.debug("Failed to match k path")
  142. continue
  143. (transpose_k, reshape_k, add_k, extra_reshape_3, matmul_k) = k_nodes
  144. mask_nodes = self.match_mask_path(qk_nodes[1])
  145. if mask_nodes is None:
  146. logger.debug("Failed to match mask path")
  147. continue
  148. if not self.has_constant_input(mask_nodes[1], 1):
  149. logger.debug("Sub node expected to have an input with constant value 1.0.")
  150. continue
  151. is_same_root, reshape_nodes = self.check_attention_input(
  152. matmul_q, matmul_k, matmul_v, parent, output_name_to_node
  153. )
  154. if is_same_root:
  155. mask_index = self.attention_mask.process_mask(mask_nodes[-1].input[0])
  156. logger.debug("Create an Attention node.")
  157. attention_node = self.attention_fusion.create_attention_node(
  158. mask_index=mask_index,
  159. q_matmul=matmul_q,
  160. k_matmul=matmul_k,
  161. v_matmul=matmul_v,
  162. q_add=add_q,
  163. k_add=add_k,
  164. v_add=add_v,
  165. num_heads=self.num_heads,
  166. hidden_size=self.hidden_size,
  167. first_input=parent.output[0],
  168. output=reshape_qkv.output[0],
  169. )
  170. if attention_node is None:
  171. continue
  172. self.add_node(attention_node)
  173. attention_count += 1
  174. nodes_to_remove.extend([reshape_qkv, transpose_qkv, matmul_qkv])
  175. nodes_to_remove.extend(qk_nodes)
  176. nodes_to_remove.extend(q_nodes)
  177. nodes_to_remove.extend(k_nodes)
  178. nodes_to_remove.extend(v_nodes)
  179. nodes_to_remove.extend(mask_nodes)
  180. nodes_to_remove.extend(reshape_nodes)
  181. nodes_to_remove.append(extra_reshape_0)
  182. self.replace_node_input(add, extra_reshape_0.output[0], matmul.output[0])
  183. else:
  184. logger.debug("Root node not matched.")
  185. continue
  186. self.remove_nodes(nodes_to_remove)
  187. self.update_graph()
  188. logger.info(f"Fused Attention count:{attention_count}")
  189. def preprocess(self):
  190. self.process_embedding()
  191. self.fuse_mask()
  192. self.skip_reshape()
  193. def skip_reshape(self):
  194. self.input_name_to_nodes()
  195. self.output_name_to_node()
  196. count = 0
  197. reshape_nodes = self.get_nodes_by_op_type("Reshape")
  198. for reshape_node in reshape_nodes:
  199. parent = self.get_parent(reshape_node, 0)
  200. if parent is not None and parent.op_type == "Reshape":
  201. reshape_node.input[0] = parent.input[0]
  202. count += 1
  203. if count > 0:
  204. logger.info(f"Skip consequent Reshape count: {count}")
  205. def fuse_embedding(self, node, output_name_to_node):
  206. assert node.op_type == "LayerNormalization"
  207. logger.debug(f"start fusing embedding from node with output={node.output[0]}...")
  208. word_embed_path = self.match_parent_path(node, ["Add", "Add", "Gather"], [0, 0, 0], output_name_to_node)
  209. if word_embed_path is None:
  210. logger.debug("failed to match word_embed_path")
  211. return False
  212. skip_node, add_node, gather_node = word_embed_path
  213. word_initializer = self.get_initializer(gather_node.input[0])
  214. if word_initializer is None:
  215. logger.debug("failed to get word initializer")
  216. return False
  217. temp = numpy_helper.to_array(word_initializer)
  218. if len(temp.shape) == 2:
  219. logger.info(f"Found word embedding. name:{word_initializer.name}, shape:{temp.shape}")
  220. word_embedding = word_initializer.name
  221. else:
  222. logger.info(f"Failed to find word embedding. name:{word_initializer.name}, shape:{temp.shape}")
  223. return False
  224. pos_initializer = self.get_initializer(add_node.input[1])
  225. if pos_initializer is not None:
  226. temp = numpy_helper.to_array(pos_initializer)
  227. if len(temp.shape) == 3 and temp.shape[0] == 1:
  228. tensor = numpy_helper.from_array(temp.reshape((temp.shape[1], temp.shape[2])), "position_embedding")
  229. self.add_initializer(tensor)
  230. logger.info(f"Found position embedding. name:{pos_initializer.name}, shape:{temp.shape[1:]}")
  231. position_embedding = "position_embedding"
  232. else:
  233. logger.info(f"Failed to find position embedding. name:{pos_initializer.name}, shape:{temp.shape}")
  234. return False
  235. else:
  236. pos_embed_path = self.match_parent_path(add_node, ["Gather", "Slice"], [1, 1], output_name_to_node)
  237. if pos_embed_path is None:
  238. logger.debug("failed to match pos_embed_path")
  239. return False
  240. pos_gather, pos_slice = pos_embed_path
  241. pos_initializer = self.get_initializer(pos_gather.input[0])
  242. if pos_initializer is None:
  243. logger.debug("failed to get pos initializer")
  244. return False
  245. temp = numpy_helper.to_array(pos_initializer)
  246. if len(temp.shape) == 2:
  247. logger.info(f"Found word embedding. name:{pos_initializer.name}, shape:{temp.shape}")
  248. position_embedding = pos_initializer.name
  249. else:
  250. logger.info(f"Failed to find position embedding. name:{pos_initializer.name}, shape:{temp.shape}")
  251. return False
  252. gather = self.get_parent(skip_node, 1, output_name_to_node)
  253. if gather is None or gather.op_type != "Gather":
  254. logger.debug("failed to get gather")
  255. return False
  256. segment_initializer = self.get_initializer(gather.input[0])
  257. if segment_initializer is None:
  258. logger.debug("failed to get segment initializer")
  259. return False
  260. temp = numpy_helper.to_array(segment_initializer)
  261. if len(temp.shape) == 2:
  262. logger.info(f"Found segment embedding. name:{segment_initializer.name}, shape:{temp.shape}")
  263. segment_embedding = segment_initializer.name
  264. else:
  265. logger.info(f"Failed to find segment embedding. name:{segment_initializer.name}, shape:{temp.shape}")
  266. return False
  267. logger.info("Create Embedding node")
  268. self.create_embedding_subgraph(node, word_embedding, segment_embedding, position_embedding)
  269. return True
  270. def process_embedding(self):
  271. """
  272. Automatically detect word, segment and position embeddings.
  273. """
  274. logger.info("start processing embedding layer...")
  275. output_name_to_node = self.output_name_to_node()
  276. for node in self.nodes():
  277. if node.op_type == "LayerNormalization":
  278. if self.fuse_embedding(node, output_name_to_node):
  279. return
  280. break
  281. def fuse_mask(self):
  282. nodes_to_remove = []
  283. for node in self.nodes():
  284. if node.op_type == "Mul" and self.has_constant_input(node, -10000):
  285. mask_path = self.match_parent_path(node, ["Sub", "Cast", "Slice", "Unsqueeze"], [0, 1, 0, 0])
  286. if mask_path is None:
  287. continue
  288. sub_node, cast_node, slice_node, unsqueeze_node = mask_path
  289. mask_input_name = self.attention_mask.get_first_mask()
  290. if unsqueeze_node.input[0] != mask_input_name:
  291. print(f"Cast input {unsqueeze_node.input[0]} is not mask input {mask_input_name}")
  292. continue
  293. unsqueeze_added_1 = onnx.helper.make_node(
  294. "Unsqueeze",
  295. inputs=[mask_input_name],
  296. outputs=["mask_fuse_unsqueeze1_output"],
  297. name="Mask_UnSqueeze_1",
  298. axes=[1],
  299. )
  300. unsqueeze_added_2 = onnx.helper.make_node(
  301. "Unsqueeze",
  302. inputs=["mask_fuse_unsqueeze1_output"],
  303. outputs=["mask_fuse_unsqueeze2_output"],
  304. name="Mask_UnSqueeze_2",
  305. axes=[2],
  306. )
  307. # self.replace_node_input(cast_node, cast_node.input[0], 'mask_fuse_unsqueeze2_output')
  308. cast_node_2 = onnx.helper.make_node(
  309. "Cast",
  310. inputs=["mask_fuse_unsqueeze2_output"],
  311. outputs=["mask_fuse_cast_output"],
  312. )
  313. cast_node_2.attribute.extend([onnx.helper.make_attribute("to", 1)])
  314. self.replace_node_input(sub_node, sub_node.input[1], "mask_fuse_cast_output")
  315. nodes_to_remove.extend([slice_node, unsqueeze_node, cast_node])
  316. self.add_node(unsqueeze_added_1)
  317. self.add_node(unsqueeze_added_2)
  318. self.add_node(cast_node_2)
  319. self.remove_nodes(nodes_to_remove)
  320. # Prune graph is done after removing nodes to remove island nodes.
  321. if len(nodes_to_remove) > 0:
  322. self.prune_graph()
  323. logger.info("Fused mask" if len(nodes_to_remove) > 0 else "Failed to fuse mask")
  324. def remove_extra_reshape(self):
  325. skiplayernorm_nodes = self.get_nodes_by_op_type("SkipLayerNormalization")
  326. reshape_removed = 0
  327. for skiplayernorm_node in skiplayernorm_nodes:
  328. path = self.match_parent_path(
  329. skiplayernorm_node,
  330. [
  331. "Add",
  332. "Reshape",
  333. "MatMul",
  334. "Reshape",
  335. "Gelu",
  336. "Add",
  337. "Reshape",
  338. "MatMul",
  339. "SkipLayerNormalization",
  340. ],
  341. [0, 0, 0, 0, 0, 0, 0, 0, 0],
  342. )
  343. if path is None:
  344. continue
  345. (
  346. add_1,
  347. reshape_1,
  348. matmul_1,
  349. reshape_2,
  350. gelu,
  351. add_2,
  352. reshape_3,
  353. matmul_2,
  354. skiplayernorm,
  355. ) = path
  356. add_2.input[0] = matmul_2.output[0]
  357. self.remove_node(reshape_3)
  358. matmul_1.input[0] = gelu.output[0]
  359. self.remove_node(reshape_2)
  360. add_1.input[0] = matmul_1.output[0]
  361. self.remove_node(reshape_1)
  362. reshape_removed += 3
  363. return reshape_removed
  364. def remove_extra_reshape_2(self):
  365. skiplayernorm_nodes = self.get_nodes_by_op_type("SkipLayerNormalization")
  366. reshape_removed = 0
  367. for skiplayernorm_node in skiplayernorm_nodes:
  368. path = self.match_parent_path(
  369. skiplayernorm_node,
  370. [
  371. "Add",
  372. "Reshape",
  373. "MatMul",
  374. "Reshape",
  375. "Gelu",
  376. "Add",
  377. "Reshape",
  378. "MatMul",
  379. "Reshape",
  380. "SkipLayerNormalization",
  381. ],
  382. [None, 0, 0, 0, 0, 0, 0, 0, 0, 0],
  383. )
  384. if path is None:
  385. continue
  386. (
  387. add_1,
  388. reshape_1,
  389. matmul_1,
  390. reshape_2,
  391. gelu,
  392. add_2,
  393. reshape_3,
  394. matmul_2,
  395. reshape_4,
  396. skiplayernorm,
  397. ) = path
  398. matmul_2.input[0] = skiplayernorm.output[0]
  399. self.remove_node(reshape_4)
  400. add_2.input[0] = matmul_2.output[0]
  401. self.remove_node(reshape_3)
  402. matmul_1.input[0] = gelu.output[0]
  403. self.remove_node(reshape_2)
  404. add_1.input[0] = matmul_1.output[0]
  405. self.remove_node(reshape_1)
  406. reshape_removed += 4
  407. return reshape_removed
  408. def postprocess(self):
  409. reshape_removed = self.remove_extra_reshape() + self.remove_extra_reshape_2()
  410. logger.info(f"Remove {reshape_removed} Reshape nodes.")
  411. self.prune_graph()