onnx_model_bert_tf.py 24 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588
  1. # -------------------------------------------------------------------------
  2. # Copyright (c) Microsoft Corporation. All rights reserved.
  3. # Licensed under the MIT License.
  4. # --------------------------------------------------------------------------
  5. import logging
  6. import numpy as np
  7. import onnx
  8. from onnx import TensorProto, helper, numpy_helper
  9. from onnx_model_bert import BertOnnxModel
  # Module logger used by BertOnnxModelTF for fusion progress and diagnostics.
  logger = logging.getLogger(name=__name__)
  class BertOnnxModelTF(BertOnnxModel):
      """BERT graph optimizer specialized for ONNX models converted from TensorFlow.

      On top of BertOnnxModel it adds TF-specific graph clean-up (Identity removal and
      chained Reshape skipping), detection and fusion of the embedding layer into a
      com.microsoft EmbedLayerNormalization node, and attention fusion for the MatMul and
      Einsum patterns produced by TF exports.
      """

      def __init__(self, model, num_heads, hidden_size):
          super().__init__(model, num_heads, hidden_size)

      def remove_identity(self):
          """Bypass and remove every Identity node whose output is not a graph output."""
          nodes_to_remove = []
          for node in self.nodes():
              if node.op_type != "Identity" or self.find_graph_output(node.output[0]):
                  continue
              # Re-point consumers at the Identity's input so the node itself can be dropped.
              self.replace_input_of_all_nodes(node.output[0], node.input[0])
              nodes_to_remove.append(node)
          self.remove_nodes(nodes_to_remove)
          logger.info(f"Removed Identity count: {len(nodes_to_remove)}")

      def match_mask_path(self, add_or_sub_before_softmax):
          """Match the attention-mask subgraph feeding input 1 of the node before Softmax.

          Three known TF export patterns are tried in order and the first match wins.

          Args:
              add_or_sub_before_softmax: the Add/Sub node whose output feeds Softmax.

          Returns:
              The matched node list, starting with the Mul (index 0) and the Sub (index 1),
              or None when no pattern matches.
          """
          candidate_paths = [
              (["Mul", "Sub", "Reshape", "Cast"], [1, None, 1, 0]),
              (["Mul", "Sub", "Cast", "Slice", "Unsqueeze"], [1, 0, 1, 0, 0]),
              (["Mul", "Sub", "Cast", "Unsqueeze", "Unsqueeze"], [1, None, 1, 0, 0]),
          ]
          for op_types, input_indices in candidate_paths:
              mask_nodes = self.match_parent_path(add_or_sub_before_softmax, op_types, input_indices)
              if mask_nodes is not None:
                  return mask_nodes
          return None

      def get_2d_initializers_from_parent_subgraphs(self, current_node):
          """Find 2D initializers consumed anywhere in the parent subgraph of ``current_node``.

          Returns:
              dict mapping initializer name to its shape, for initializers whose array is 2D.
          """
          parent_nodes = self.get_parent_subgraph_nodes(current_node, [])
          initializers = {}
          for node in parent_nodes:
              for input_name in node.input:
                  initializer = self.get_initializer(input_name)
                  if initializer is None:
                      continue
                  array = numpy_helper.to_array(initializer)
                  if len(array.shape) == 2:
                      initializers[initializer.name] = array.shape
          return initializers

      def find_segment_ids(self, segment_embedding, input_ids):
          """Find the tensor that carries segment (token type) ids.

          Args:
              segment_embedding: name of the segment embedding initializer.
              input_ids: name of the graph input already identified as input_ids.

          Returns:
              - None when ``segment_embedding`` does not have exactly one consumer, or when
                several graph inputs feed that consumer.
              - The single upstream graph input, when it differs from ``input_ids``.
              - Otherwise input 1 of that consumer (presumably the Gather indices). There is
                one exception: when the ids are built as ConstantOfShape(... Shape(input_ids)),
                new Shape/ConstantOfShape nodes are added and "zeros_for_input_shape" is
                returned instead.
          """
          input_name_to_nodes = self.input_name_to_nodes()
          if segment_embedding not in input_name_to_nodes:
              return None

          nodes = input_name_to_nodes[segment_embedding]
          if len(nodes) != 1:
              return None

          graph_inputs = self.get_graph_inputs(nodes[0], recursive=True)
          if len(graph_inputs) > 1:
              logger.warning(f"Found multiple candidates of segment_ids: {graph_inputs}")
              return None
          # The segment id input must not be the same graph input as input_ids.
          if len(graph_inputs) == 1 and graph_inputs[0] != input_ids:
              return graph_inputs[0]

          # The only candidate is input_ids itself (or there is none). Try to assign alternative
          # segment ids and simplify the graph if needed.
          segment_ids = nodes[0].input[1]
          _, segment_id_path, _ = self.match_parent_paths(
              nodes[0],
              [
                  (
                      ["ConstantOfShape", "Cast", "Concat", "Slice", "Cast", "Shape"],
                      [1, 0, 0, 0, 0, 0],
                  ),
                  (
                      [
                          "ConstantOfShape",
                          "Cast",
                          "Concat",
                          "Unsqueeze",
                          "Squeeze",
                          "Slice",
                          "Cast",
                          "Shape",
                      ],
                      [1, 0, 0, 0, 0, 0, 0, 0],
                  ),
              ],
              None,
          )

          # The path ends in Shape(input_ids): rebuild the constant directly from the shape of input_ids.
          if segment_id_path and input_ids and input_ids == segment_id_path[-1].input[0]:
              logger.debug("Simplify segment id path...")
              constantofshape_node = segment_id_path[0]
              graph_name = self.get_graph_by_node(constantofshape_node).name
              self.add_node(
                  helper.make_node("Shape", inputs=[input_ids], outputs=["input_shape"]),
                  graph_name,
              )
              constantofshape_value = helper.get_attribute_value(constantofshape_node.attribute[0])
              self.add_node(
                  helper.make_node(
                      "ConstantOfShape",
                      inputs=["input_shape"],
                      outputs=["zeros_for_input_shape"],
                      value=constantofshape_value,
                  ),
                  graph_name,
              )
              segment_ids = "zeros_for_input_shape"
          return segment_ids

      def find_input_ids(self, word_embedding):
          """Find the graph input carrying token ids, traced from the consumer of ``word_embedding``.

          Returns:
              The single upstream graph input name. Returns None when ``word_embedding`` does
              not have exactly one consumer, or when that consumer does not trace back to
              exactly one graph input.
          """
          input_name_to_nodes = self.input_name_to_nodes()
          if word_embedding not in input_name_to_nodes:
              return None

          nodes = input_name_to_nodes[word_embedding]
          if len(nodes) != 1:
              return None

          graph_inputs = self.get_graph_inputs(nodes[0], recursive=True)
          if len(graph_inputs) == 1:
              return graph_inputs[0]

          # Only an ambiguous match deserves a warning; zero candidates is simply "not found".
          if len(graph_inputs) > 1:
              logger.warning(f"Found multiple candidates of input_ids: {graph_inputs}")
          return None

      def find_mask_input(self, excluded_graph_inputs):
          """Locate the attention-mask input by scanning Softmax nodes for the TF mask pattern.

          Matched pattern: Softmax <- Add <- Mul <- Sub <- Cast <- Slice <- Unsqueeze. The Mul
          must have a constant -10000 input and the Sub a constant 1 input.

          Args:
              excluded_graph_inputs: graph input names already claimed (input_ids, segment_ids).

          Returns:
              - The single non-excluded graph input upstream of the Sub.
              - None when several such inputs exist, or when nothing matches.
              - In the dynamic-axes case, the Unsqueeze input name. This applies when the only
                upstream input is an excluded one and the mask is
                ConstantOfShape(... Shape(that input)). New Shape/ConstantOfShape nodes are
                then added that produce the Unsqueeze input directly.
          """
          for node in self.nodes():
              if node.op_type != "Softmax":
                  continue

              mask_path = self.match_parent_path(
                  node,
                  ["Add", "Mul", "Sub", "Cast", "Slice", "Unsqueeze"],
                  [0, 1, None, 1, 0, 0],
              )
              if mask_path is None:
                  continue

              _, mul_node, sub_node, _, _, unsqueeze_node = mask_path
              if not (self.has_constant_input(mul_node, -10000) and self.has_constant_input(sub_node, 1)):
                  continue

              graph_inputs = self.get_graph_inputs(sub_node, recursive=True)
              inputs = [name for name in graph_inputs if name not in excluded_graph_inputs]
              if len(inputs) > 1:
                  logger.warning(f"Found multiple candidates of mask input: {inputs}")
                  return None
              if len(inputs) == 1:
                  return inputs[0]

              # Duplicated input found. Try to simplify the graph.
              path_to_be_simplified = self.match_parent_path(
                  mask_path[-1],
                  [
                      "ConstantOfShape",
                      "Cast",
                      "Concat",
                      "Unsqueeze",
                      "Squeeze",
                      "Slice",
                      "Cast",
                      "Shape",
                  ],
                  [0, 0, 0, 0, 0, 0, 0, 0],
              )
              duplicated_inputs = [name for name in graph_inputs if name in excluded_graph_inputs]

              # Simplify graph for dynamic axes.
              if (
                  path_to_be_simplified
                  and len(duplicated_inputs) == 1
                  and duplicated_inputs[0] == path_to_be_simplified[-1].input[0]
              ):
                  logger.debug("Simplify mask input path...")
                  constantofshape_node = path_to_be_simplified[0]
                  constantofshape_value = helper.get_attribute_value(constantofshape_node.attribute[0])
                  graph_name = self.get_graph_by_node(constantofshape_node).name
                  self.add_node(
                      helper.make_node(
                          "Shape",
                          inputs=[duplicated_inputs[0]],
                          outputs=["input_shape_for_mask"],
                      ),
                      graph_name,
                  )
                  # NOTE(review): the new node writes the same output name as the existing
                  # ConstantOfShape; the old chain is presumably dropped by a later prune_graph() - confirm.
                  self.add_node(
                      helper.make_node(
                          "ConstantOfShape",
                          inputs=["input_shape_for_mask"],
                          outputs=[unsqueeze_node.input[0]],
                          value=constantofshape_value,
                      ),
                      graph_name,
                  )
                  return unsqueeze_node.input[0]
          return None

      def create_embedding_subgraph(self, normalize_node, word_embedding, segment_embedding, position_embedding):
          """Replace the embedding layer ending at ``normalize_node`` with EmbedLayerNormalization.

          The method:
          - identifies input_ids, segment_ids and the attention-mask input;
          - records them in ``self.bert_inputs``;
          - routes them through the int32 cast helpers;
          - adds a com.microsoft EmbedLayerNormalization node whose output replaces
            ``normalize_node``'s output.

          The original embedding nodes are not removed here.

          Args:
              normalize_node: LayerNormalization node closing the embedding layer
                  (input 1 is gamma, input 2 is beta).
              word_embedding: name of the word embedding initializer.
              segment_embedding: name of the segment embedding initializer.
              position_embedding: name of the position embedding initializer.

          Returns:
              True if the fused node was added; False if a required input could not be found.
          """
          input_ids = self.find_input_ids(word_embedding)
          if input_ids is None:
              logger.info("Failed to find input_ids. Cannot fuse embedding layer.")
              return False

          segment_ids = self.find_segment_ids(segment_embedding, input_ids)
          if segment_ids is None:
              logger.info("Failed to find segment_ids. Cannot fuse embedding layer.")
              return False

          mask_input = self.find_mask_input([segment_ids, input_ids])
          if mask_input is None:
              logger.info("Failed to find input_mask. Cannot fuse embedding layer.")
              return False

          self.bert_inputs = [input_ids, segment_ids, mask_input]

          mask_index = self.create_node_name("mask_index")
          self.attention_mask.set_mask_indice(mask_input, mask_index)

          # input_ids is known to be a graph input (find_input_ids returns one).
          if self.find_graph_input(input_ids).type.tensor_type.elem_type != TensorProto.INT32:
              _, input_ids = self.utils.cast_graph_input_to_int32(input_ids)

          # segment_ids / mask_input may be intermediate tensors produced by the simplifications above.
          if self.find_graph_input(segment_ids):
              _, segment_ids = self.utils.cast_graph_input_to_int32(segment_ids)
          else:
              segment_ids, _ = self.utils.cast_input_to_int32(segment_ids)

          if self.find_graph_input(mask_input):
              _, mask_input = self.utils.cast_graph_input_to_int32(mask_input)
          else:
              mask_input, _ = self.utils.cast_input_to_int32(mask_input)

          embed_output = self.create_node_name("embed_output")
          embed_node = onnx.helper.make_node(
              "EmbedLayerNormalization",
              inputs=[
                  input_ids,
                  segment_ids,
                  word_embedding,
                  position_embedding,
                  segment_embedding,
                  normalize_node.input[1],  # gamma
                  normalize_node.input[2],  # beta
                  mask_input,
              ],
              outputs=[embed_output, mask_index],
              name="EmbedLayer",
          )
          embed_node.domain = "com.microsoft"

          self.replace_input_of_all_nodes(normalize_node.output[0], embed_output)
          self.add_node(embed_node, self.get_graph_by_node(normalize_node).name)
          return True

      def process_embedding(self):
          """Automatically detect word, segment and position embeddings and fuse them.

          Steps:
          1. Look for LayerNormalization <- Add <- Reshape <- Slice, where the Slice reads a 2D
             initializer (the position embedding).
          2. If input 0 of that Add is itself an Add, take the two 2D initializers above it as
             the word and segment embeddings; the one with 2 rows is the segment embedding.
          3. At most one embedding layer is fused, and the graph is pruned afterwards.
          """
          logger.info("start processing embedding layer...")
          output_name_to_node = self.output_name_to_node()

          layer_norm_nodes = self.get_nodes_by_op_type("LayerNormalization")
          for layer_norm_node in layer_norm_nodes:
              pos_embed_path = self.match_parent_path(
                  layer_norm_node,
                  ["Add", "Reshape", "Slice"],
                  [0, 1, 0],
                  output_name_to_node,
              )
              if pos_embed_path is None:
                  continue

              add_node, _, slice_node = pos_embed_path
              initializer = self.get_initializer(slice_node.input[0])
              if initializer is None:
                  continue

              position_table = numpy_helper.to_array(initializer)
              if len(position_table.shape) != 2:
                  logger.info(
                      f"Failed to find position embedding. name:{initializer.name}, shape:{position_table.shape}"
                  )
                  return
              logger.info(f"Found position embedding. name:{initializer.name}, shape:{position_table.shape}")
              position_embedding = initializer.name

              first_parent = self.get_parent(add_node, 0, output_name_to_node)
              if first_parent is None or first_parent.op_type != "Add":
                  continue

              embeddings = self.get_2d_initializers_from_parent_subgraphs(first_parent)
              if len(embeddings) != 2:
                  logger.warning(f"Failed to find two embeddings (word and segment) from Add node. Found {embeddings}")
                  return

              word_embedding = None
              segment_embedding = None
              for name, shape in embeddings.items():
                  if shape[0] == 2:
                      segment_embedding = name
                      logger.info(f"Found segment embedding. name:{name}, shape:{shape}")
                  else:
                      word_embedding = name
                      logger.info(f"Found words embedding. name:{name}, shape:{shape}")

              if word_embedding is None or segment_embedding is None:
                  logger.info("Failed to find both word and segment embedding")
                  return

              logger.info("Create Embedding node")
              self.create_embedding_subgraph(
                  layer_norm_node,
                  word_embedding,
                  segment_embedding,
                  position_embedding,
              )
              # Prune graph to remove those original embedding nodes.
              self.prune_graph()
              break

      def check_attention_input(self, matmul_q, matmul_k, matmul_v, parent, output_name_to_node):
          """Return True if the Q, K and V MatMul nodes all take input 0 from ``parent``.

          Args:
              matmul_q, matmul_k, matmul_v: the projection MatMul/Einsum nodes.
              parent: expected producer of their shared root input.
              output_name_to_node: mapping from tensor name to producing node.
          """
          for matmul in (matmul_q, matmul_k, matmul_v):
              root_input = matmul.input[0]
              # A graph input or initializer has no producer; .get() avoids a KeyError and such
              # an input simply does not match the expected parent.
              root_node = output_name_to_node.get(root_input)
              if root_node is not None and root_node == parent:
                  continue
              logger.debug(f"Check attention input failed:{root_input}, {parent.output[0]}")
              return False
          return True

      def fuse_attention(self):
          """Fuse TF-style self-attention subgraphs into com.microsoft Attention nodes.

          Anchors are every SkipLayerNormalization node, plus the Add feeding input 0 of every
          LayerNormalization node. From each anchor the following are matched:
          - the QKV output path;
          - the Q/K/V projection paths (MatMul+Add or Einsum+Add variants);
          - the attention-mask path.

          When Q, K and V share the same root node, an Attention node is created and the
          matched nodes are removed afterwards.
          """
          output_name_to_node = self.output_name_to_node()

          nodes_to_remove = []
          attention_count = 0
          # Producers accepted as the shared root of the Q/K/V projections.
          root_op_types = ("SkipLayerNormalization", "LayerNormalization", "Reshape")

          # Sometimes SkipLayerNormalization cannot be fused because the Add before LayerNormalization
          # has an output used by nodes outside it. Conceptually that Add is treated like a
          # SkipLayerNormalization node since they share the same pattern.
          start_nodes = list(self.get_nodes_by_op_type("SkipLayerNormalization"))
          start_nodes.extend(self.get_nodes_by_op_type("LayerNormalization"))

          for start_node in start_nodes:
              graph_name = self.get_graph_by_node(start_node).name

              # SkipLayerNormalization has two inputs, and one of them is the root input for attention.
              normalize_node = start_node
              if start_node.op_type == "LayerNormalization":
                  normalize_node = self.match_parent(start_node, "Add", 0)
                  if normalize_node is None:
                      continue

              parent = self.get_parent(normalize_node, 1)
              if parent is None or parent.op_type not in root_op_types:
                  parent = self.get_parent(normalize_node, 0)
                  if parent is None or parent.op_type not in root_op_types:
                      logger.debug("Failed to match parent of normalize_node")
                      continue

              qkv_nodes = self.match_parent_path(
                  normalize_node,
                  ["Add", "MatMul", "Reshape", "Transpose", "MatMul"],
                  [0, 0, 0, 0, 0],
              )
              if qkv_nodes is None:
                  qkv_nodes = self.match_parent_path(
                      normalize_node,
                      ["MatMul", "Reshape", "Transpose", "MatMul"],
                      [1, 0, 0, 0],
                  )
              if qkv_nodes is None:
                  qkv_nodes = self.match_parent_path(normalize_node, ["Add", "Einsum", "Einsum"], [0, 0, 0])
              if qkv_nodes is None:
                  logger.debug("Failed to match qkv nodes")
                  continue

              matmul_qkv = qkv_nodes[-1]
              v_nodes = self.match_parent_path(matmul_qkv, ["Transpose", "Reshape", "Add", "MatMul"], [1, 0, 0, 0])
              if v_nodes is None:
                  v_nodes = self.match_parent_path(matmul_qkv, ["Add", "Einsum"], [1, 0])
              if v_nodes is None:
                  logger.debug("Failed to match v path")
                  continue
              add_v = v_nodes[-2]
              matmul_v = v_nodes[-1]

              qk_nodes = self.match_parent_path(matmul_qkv, ["Softmax", "Add", "Mul", "MatMul"], [0, 0, 0, 0])
              if qk_nodes is None:
                  qk_nodes = self.match_parent_path(matmul_qkv, ["Softmax", "Add", "Einsum"], [0, 0, 0])
              if qk_nodes is None:
                  logger.debug("Failed to match qk_paths")
                  continue
              matmul_qk = qk_nodes[-1]

              q_nodes = self.match_parent_path(matmul_qk, ["Transpose", "Reshape", "Add", "MatMul"], [0, 0, 0, 0])
              if q_nodes is None:
                  q_nodes = self.match_parent_path(matmul_qk, ["Add", "Einsum"], [0, 0])
              if q_nodes is None:
                  logger.debug("Failed to match q path")
                  continue
              add_q = q_nodes[-2]
              matmul_q = q_nodes[-1]

              k_nodes = self.match_parent_path(matmul_qk, ["Transpose", "Reshape", "Add", "MatMul"], [1, 0, 0, 0])
              if k_nodes is None:
                  k_nodes = self.match_parent_path(matmul_qk, ["Mul", "Add", "Einsum"], [1, 0, 0])
              if k_nodes is None:
                  logger.debug("Failed to match k path")
                  continue
              add_k = k_nodes[-2]
              matmul_k = k_nodes[-1]

              # qk_nodes[1] is the Add right below Softmax; its input 1 carries the mask.
              mask_nodes = self.match_mask_path(qk_nodes[1])
              if mask_nodes is None:
                  logger.debug("Cannot find mask_nodes.")
                  continue
              if not self.has_constant_input(mask_nodes[1], 1):
                  logger.debug("Sub node expected to have an input with constant value 1.0.")
                  continue

              # Add a Squeeze node to convert a 3-d mask to 2-d.
              squeeze_node = self.match_parent_path(mask_nodes[-1], ["Squeeze"], [0]) or self.match_parent_path(
                  mask_nodes[-1], ["Expand"], [0]
              )
              squeeze_node_name = "Squeeze_3d_to_2d_mask"
              squeeze_output_name = squeeze_node_name + "_output"
              if squeeze_node is None and len(mask_nodes) == 5 and self.find_graph_input(mask_nodes[-1].input[0]) is None:
                  # NOTE(review): the node/output name is fixed, so if more than one layer reaches this
                  # branch the graph gets duplicate names; it also reads input 1 of the last mask
                  # node but rewires input 0 - confirm both are intended.
                  mask_input = mask_nodes[-1].input[1]
                  self.add_node(
                      helper.make_node(
                          "Squeeze",
                          [mask_input],
                          [squeeze_output_name],
                          squeeze_node_name,
                          axes=[1],
                      ),
                      graph_name,
                  )
                  mask_nodes[-1].input[0] = squeeze_output_name

              if not self.check_attention_input(matmul_q, matmul_k, matmul_v, parent, output_name_to_node):
                  logger.debug("Root node not matched.")
                  continue

              mask_index = self.attention_mask.process_mask(mask_nodes[-1].input[0])
              logger.debug("Create an Attention node.")
              # For TF models the Q and K branches are swapped relative to the Attention op's
              # expected order, so matmul_k/add_k are passed as Q and matmul_q/add_q as K.
              attention_node = self.attention_fusion.create_attention_node(
                  mask_index=mask_index,
                  q_matmul=matmul_k,
                  k_matmul=matmul_q,
                  v_matmul=matmul_v,
                  q_add=add_k,
                  k_add=add_q,
                  v_add=add_v,
                  num_heads=self.num_heads,
                  hidden_size=self.hidden_size,
                  first_input=parent.output[0],
                  output=qkv_nodes[2].output[0],
              )
              if attention_node is None:
                  continue

              if qkv_nodes[1].op_type == "Einsum":
                  # The output-projection Einsum is kept; feed it the Attention output reshaped to
                  # (0, 0, num_heads, head_size) through a new Reshape node.
                  tensor = helper.make_tensor(
                      name=qkv_nodes[1].name + "_newshape",
                      data_type=TensorProto.INT64,
                      dims=[4],
                      vals=np.int64(
                          [
                              [
                                  0,
                                  0,
                                  self.num_heads,
                                  int(self.hidden_size / self.num_heads),
                              ]
                          ]
                      ).tobytes(),
                      raw=True,
                  )
                  self.add_initializer(tensor, graph_name)
                  reshape_ = helper.make_node(
                      "Reshape",
                      inputs=[
                          attention_node.output[0],
                          qkv_nodes[1].name + "_newshape",
                      ],
                      outputs=[qkv_nodes[1].name + "_reshape_output"],
                      name=qkv_nodes[1].name + "_reshape",
                  )
                  qkv_nodes[1].input[0] = qkv_nodes[1].name + "_reshape_output"
                  self.add_node(reshape_, graph_name)

              if parent.op_type == "Reshape":
                  # Temporary work around: we require the skiplayernorm and attention op be fed with 3-d input.
                  # Element 1 of the existing shape initializer is reused as the hidden size.
                  hidden_size = numpy_helper.to_array(self.get_initializer(parent.input[1]))[1]
                  tensor = helper.make_tensor(
                      name=parent.name + "_modified",
                      data_type=TensorProto.INT64,
                      dims=[3],
                      vals=np.int64([[1, -1, hidden_size]]).tobytes(),
                      raw=True,
                  )
                  self.add_initializer(tensor, graph_name)
                  parent.input[1] = parent.name + "_modified"

              self.add_node(attention_node, graph_name)
              attention_count += 1

              # qkv_nodes[:2] stay in the graph: they consume the Attention output.
              nodes_to_remove.extend(qkv_nodes[2:])
              nodes_to_remove.extend(qk_nodes)
              nodes_to_remove.extend(q_nodes)
              nodes_to_remove.extend(k_nodes)
              nodes_to_remove.extend(v_nodes)
              nodes_to_remove.extend(mask_nodes)

          self.remove_nodes(nodes_to_remove)
          self.update_graph()
          logger.info(f"Fused Attention count:{attention_count}")

      def preprocess(self):
          """Remove Identity nodes, fuse the embedding layer, then skip chained Reshapes."""
          self.remove_identity()
          self.process_embedding()
          self.skip_reshape()

      def skip_reshape(self):
          """For each Reshape fed by another Reshape, read directly from the first Reshape's input.

          The upstream Reshape is not removed here.
          NOTE(review): this is only equivalent when the second Reshape's target shape has no 0
          entries (0 copies a dimension from the immediate input) - confirm for target models.
          """
          count = 0
          reshape_nodes = self.get_nodes_by_op_type("Reshape")
          for reshape_node in reshape_nodes:
              parent = self.get_parent(reshape_node, 0)
              if parent is not None and parent.op_type == "Reshape":
                  reshape_node.input[0] = parent.input[0]
                  count += 1

          if count > 0:
              logger.info(f"Skip consequent Reshape count: {count}")

      def remove_reshape_before_first_attention(self):
          """Remove a Reshape sitting between EmbedLayerNormalization and an Attention node.

          Only the first Attention node (in node order) with that pattern is handled.
          """
          attention_nodes = self.get_nodes_by_op_type("Attention")
          for attention_node in attention_nodes:
              path = self.match_parent_path(attention_node, ["Reshape", "EmbedLayerNormalization"], [0, 0])
              if path is None:
                  continue
              logger.info("Remove Reshape before first Attention node.")
              reshape, _ = path
              self.replace_input_of_all_nodes(reshape.output[0], reshape.input[0])
              self.remove_node(reshape)
              break

      def postprocess(self):
          """Remove the Reshape before the first Attention node, then prune the graph."""
          self.remove_reshape_before_first_attention()
          self.prune_graph()