# File: convert_to_packing_mode.py
  1. # -------------------------------------------------------------------------
  2. # Copyright (c) Microsoft Corporation. All rights reserved.
  3. # Licensed under the MIT License.
  4. # --------------------------------------------------------------------------
  5. import argparse
  6. import logging
  7. import os
  8. import coloredlogs
  9. from constants import (
  10. AttentionInputIDs,
  11. AttentionOutputIDs,
  12. MultiHeadAttentionInputIDs,
  13. MultiHeadAttentionOutputIDs,
  14. Operators,
  15. )
  16. from onnx import helper, load_model
  17. from onnx_model import NodeProto, OnnxModel
  18. from shape_infer_helper import SymbolicShapeInferenceHelper
# Module-level logger used by the conversion classes and the command-line entry point.
logger = logging.getLogger(__name__)
  20. class PackingAttentionBase:
  21. def __init__(self, model: OnnxModel, attention_op_type: str):
  22. self.model: OnnxModel = model
  23. self.nodes_to_remove: list = []
  24. self.nodes_to_add: list = []
  25. self.prune_graph: bool = False
  26. self.node_name_to_graph_name: dict = {}
  27. self.this_graph_name: str = self.model.model.graph.name
  28. self.attention_op_type = attention_op_type
  29. self.attention_nodes = self.model.get_nodes_by_op_type(attention_op_type)
  30. def _try_getting_attention_mask(self) -> str | None:
  31. mask_index = (
  32. AttentionInputIDs.MASK_INDEX
  33. if self.attention_op_type == Operators.ATTENTION
  34. else MultiHeadAttentionInputIDs.KEY_PADDING_MASK
  35. )
  36. first_attention_node = self._try_getting_first_attention()
  37. # check if attention has mask
  38. if not first_attention_node or len(first_attention_node.input) <= mask_index:
  39. return None
  40. attention_mask = first_attention_node.input[mask_index]
  41. # check if all attention nodes have same mask
  42. for node in self.attention_nodes:
  43. if len(node.input) <= mask_index or node.input[mask_index] != attention_mask:
  44. return None
  45. return attention_mask
  46. def _try_getting_first_attention(self) -> NodeProto | None:
  47. if len(self.attention_nodes) <= 0:
  48. return None
  49. return self.attention_nodes[0]
  50. def _try_getting_last_layernorm(self) -> NodeProto | None:
  51. last_layernorm_node = None
  52. for node in self.model.nodes():
  53. if node.op_type == Operators.LAYERNORM or node.op_type == Operators.SKIPLAYERNORM:
  54. last_layernorm_node = node
  55. return last_layernorm_node
  56. def _are_attentions_supported(self) -> bool:
  57. raise NotImplementedError()
  58. def _insert_removepadding_node(self, inputs: list[str], outputs: list[str]) -> None:
  59. new_node = helper.make_node(
  60. Operators.REMOVEPADDING,
  61. inputs=inputs,
  62. outputs=outputs,
  63. name=self.model.create_node_name(Operators.REMOVEPADDING),
  64. )
  65. new_node.domain = "com.microsoft"
  66. self.nodes_to_add.append(new_node)
  67. self.node_name_to_graph_name[new_node.name] = self.this_graph_name
  68. def _insert_restorepadding_node(self, inputs: list[str], outputs: list[str]) -> None:
  69. new_node = helper.make_node(
  70. Operators.RESTOREPADDING,
  71. inputs=inputs,
  72. outputs=outputs,
  73. name=self.model.create_node_name(Operators.RESTOREPADDING),
  74. )
  75. new_node.domain = "com.microsoft"
  76. self.nodes_to_add.append(new_node)
  77. self.node_name_to_graph_name[new_node.name] = self.this_graph_name
  78. def _replace_attention_with_packing_attention(self, token_offset: str, cumulative_sequence_length: str) -> None:
  79. raise NotImplementedError()
  80. def _get_input_to_remove_padding(self, first_attention_node) -> str | None:
  81. if self.attention_op_type == Operators.ATTENTION:
  82. return first_attention_node.input[AttentionInputIDs.INPUT]
  83. return None
  84. def convert(self, use_symbolic_shape_infer: bool = True) -> None:
  85. logger.debug("start converting to packing model...")
  86. if not self._are_attentions_supported():
  87. return
  88. attention_mask = self._try_getting_attention_mask()
  89. if not attention_mask:
  90. return
  91. first_attention_node = self._try_getting_first_attention()
  92. last_layernorm_node = self._try_getting_last_layernorm()
  93. if not last_layernorm_node:
  94. return
  95. # insert RemovePadding
  96. input_to_remove_padding = self._get_input_to_remove_padding(first_attention_node)
  97. if not input_to_remove_padding:
  98. return
  99. output_without_padding = input_to_remove_padding + "_no_padding"
  100. token_offset = input_to_remove_padding + "_token_offset"
  101. cumulated_seq_len = input_to_remove_padding + "_cumulated_seq_len"
  102. max_seq_len = input_to_remove_padding + "_max_seq_len"
  103. self._insert_removepadding_node(
  104. [input_to_remove_padding, attention_mask],
  105. [output_without_padding, token_offset, cumulated_seq_len, max_seq_len],
  106. )
  107. self.model.replace_input_of_all_nodes(input_to_remove_padding, output_without_padding)
  108. logger.debug("inserted RemovePadding before Attention")
  109. # insert RestorePadding
  110. restorepadding_input = last_layernorm_node.output[0] + "_restore_input"
  111. self._insert_restorepadding_node([restorepadding_input, token_offset], [last_layernorm_node.output[0]])
  112. self.model.replace_output_of_all_nodes(last_layernorm_node.output[0], restorepadding_input)
  113. logger.debug(f"inserted RestorePadding after last {last_layernorm_node.op_type} layer")
  114. # insert PackedAttention
  115. self._replace_attention_with_packing_attention(token_offset, cumulated_seq_len)
  116. logger.debug(f"replaced {self.attention_op_type} with Packed{self.attention_op_type}")
  117. self.model.remove_nodes(self.nodes_to_remove)
  118. self.model.add_nodes(self.nodes_to_add, self.node_name_to_graph_name)
  119. if self.prune_graph:
  120. self.model.prune_graph()
  121. elif self.nodes_to_remove or self.nodes_to_add:
  122. self.model.update_graph()
  123. self.model.clean_shape_infer()
  124. if use_symbolic_shape_infer:
  125. # Use symbolic shape inference since custom operators (like Gelu, SkipLayerNormalization etc)
  126. # are not recognized by onnx shape inference.
  127. shape_infer_helper = SymbolicShapeInferenceHelper(self.model.model, verbose=0)
  128. inferred_model = shape_infer_helper.infer_shapes(self.model.model, auto_merge=True, guess_output_rank=False)
  129. if inferred_model:
  130. self.model.model = inferred_model
  131. class PackingAttention(PackingAttentionBase):
  132. def __init__(self, model: OnnxModel):
  133. super().__init__(model, Operators.ATTENTION)
  134. def _are_attentions_supported(self) -> bool:
  135. for node in self.attention_nodes:
  136. if OnnxModel.get_node_attribute(node, "past_present_share_buffer") is not None:
  137. return False
  138. if OnnxModel.get_node_attribute(node, "do_rotary") is not None:
  139. return False
  140. unidirection_attr = OnnxModel.get_node_attribute(node, "unidirectional")
  141. if unidirection_attr is not None and unidirection_attr != 0:
  142. return False
  143. if len(node.input) > AttentionInputIDs.PAST and not node.input[AttentionInputIDs.PAST]:
  144. return False
  145. if (
  146. len(node.input) > AttentionInputIDs.PAST_SEQUENCE_LENGTH
  147. and not node.input[AttentionInputIDs.PAST_SEQUENCE_LENGTH]
  148. ):
  149. return False
  150. return True
  151. def _replace_attention_with_packing_attention(self, token_offset: str, cumulative_sequence_length: str) -> None:
  152. for attention in self.attention_nodes:
  153. attention_bias = (
  154. attention.input[AttentionInputIDs.ATTENTION_BIAS]
  155. if len(attention.input) > AttentionInputIDs.ATTENTION_BIAS
  156. else ""
  157. )
  158. packed_attention = helper.make_node(
  159. Operators.PACKEDATTENTION,
  160. inputs=[
  161. attention.input[AttentionInputIDs.INPUT],
  162. attention.input[AttentionInputIDs.WEIGHTS],
  163. attention.input[AttentionInputIDs.BIAS],
  164. token_offset,
  165. cumulative_sequence_length,
  166. attention_bias,
  167. ],
  168. outputs=[attention.output[AttentionOutputIDs.OUTPUT]],
  169. name=self.model.create_node_name(Operators.PACKEDATTENTION),
  170. )
  171. attributes = []
  172. for attr in attention.attribute:
  173. if attr.name in ["num_heads", "qkv_hidden_sizes", "scale"]:
  174. attributes.append(attr)
  175. packed_attention.attribute.extend(attributes)
  176. packed_attention.domain = "com.microsoft"
  177. self.nodes_to_add.append(packed_attention)
  178. self.nodes_to_remove.append(attention)
  179. self.node_name_to_graph_name[packed_attention.name] = self.this_graph_name
  180. logger.info("Converted %d Attention nodes to PackedAttention.", len(self.attention_nodes))
  181. class PackingMultiHeadAttention(PackingAttentionBase):
  182. def __init__(self, model: OnnxModel):
  183. super().__init__(model, Operators.MULTI_HEAD_ATTENTION)
  184. def _check_empty_input(self, node, index: int, name: str):
  185. """Check a node does not have given input."""
  186. if len(node.input) > index:
  187. if len(node.input[index]) > 0:
  188. logger.error(f"node input {index} ({name}) is not supported in PackedMultiHeadAttention: {node}")
  189. return False
  190. return True
  191. def _check_empty_output(self, node, index: int, name: str):
  192. """Check a node does not have given input."""
  193. if len(node.output) > index:
  194. if len(node.output[index]) > 0:
  195. logger.error(f"node output {index} ({name}) is not supported in PackedMultiHeadAttention: {node}")
  196. return False
  197. return True
  198. def _are_attentions_supported(self) -> bool:
  199. for node in self.attention_nodes:
  200. for attr in node.attribute:
  201. if attr.name not in ["num_heads", "mask_filter_value", "scale"]:
  202. logger.error(f"node attribute {attr.name} is not supported in PackedMultiHeadAttention: {node}")
  203. return False
  204. if node.input[MultiHeadAttentionInputIDs.KEY] and not node.input[MultiHeadAttentionInputIDs.VALUE]:
  205. logger.error("packed kv format is not supported in PackedMultiHeadAttention")
  206. return False
  207. if not (
  208. self._check_empty_input(node, MultiHeadAttentionInputIDs.PAST_KEY, "past_key")
  209. and self._check_empty_input(node, MultiHeadAttentionInputIDs.PAST_VALUE, "past_key")
  210. and self._check_empty_output(node, MultiHeadAttentionOutputIDs.PRESENT_KEY, "present_key")
  211. and self._check_empty_output(node, MultiHeadAttentionOutputIDs.PRESENT_VALUE, "present_key")
  212. ):
  213. return False
  214. return True
  215. def _replace_attention_with_packing_attention(self, token_offset: str, cumulative_sequence_length: str) -> None:
  216. gated_relative_pos_bias_count = 0
  217. for mha in self.attention_nodes:
  218. attention_bias = (
  219. mha.input[MultiHeadAttentionInputIDs.ATTENTION_BIAS]
  220. if len(mha.input) > MultiHeadAttentionInputIDs.ATTENTION_BIAS
  221. else ""
  222. )
  223. packed_mha = helper.make_node(
  224. Operators.PACKED_MULTI_HEAD_ATTENTION,
  225. inputs=[
  226. mha.input[MultiHeadAttentionInputIDs.QUERY],
  227. mha.input[MultiHeadAttentionInputIDs.KEY],
  228. mha.input[MultiHeadAttentionInputIDs.VALUE],
  229. mha.input[MultiHeadAttentionInputIDs.BIAS],
  230. token_offset,
  231. cumulative_sequence_length,
  232. attention_bias,
  233. ],
  234. outputs=[mha.output[MultiHeadAttentionOutputIDs.OUTPUT]],
  235. name=self.model.create_node_name(Operators.PACKED_MULTI_HEAD_ATTENTION),
  236. )
  237. attributes = []
  238. for attr in mha.attribute:
  239. if attr.name in ["num_heads", "mask_filter_value", "scale"]:
  240. attributes.append(attr)
  241. packed_mha.attribute.extend(attributes)
  242. packed_mha.domain = "com.microsoft"
  243. self.nodes_to_add.append(packed_mha)
  244. self.nodes_to_remove.append(mha)
  245. self.node_name_to_graph_name[packed_mha.name] = self.this_graph_name
  246. # Append token_offset input to GatedRelativePositionBias
  247. if attention_bias:
  248. rel_pos_bias_node = self.model.get_parent(mha, MultiHeadAttentionInputIDs.ATTENTION_BIAS)
  249. if (
  250. rel_pos_bias_node
  251. and rel_pos_bias_node.op_type == "GatedRelativePositionBias"
  252. and len(rel_pos_bias_node.input) == 6
  253. ):
  254. rel_pos_bias_node.input.append(token_offset)
  255. gated_relative_pos_bias_count += 1
  256. logger.info("Converted %d MultiHeadAttention nodes to PackedMultiHeadAttention.", len(self.attention_nodes))
  257. logger.info("Converted %d GatedRelativePositionBias nodes to packing mode.", gated_relative_pos_bias_count)
  258. def _get_input_to_remove_padding(self, first_attention_node) -> str | None:
  259. # When there are query, key and value inputs, we need to find the first input of the parent MatMul node.
  260. matmul = self.model.get_parent(first_attention_node, 0)
  261. if matmul and matmul.op_type == "MatMul":
  262. return matmul.input[0]
  263. return None
  264. class PackingMode:
  265. def __init__(self, model: OnnxModel):
  266. self.model = model
  267. def convert(self, use_symbolic_shape_infer: bool = True) -> None:
  268. if self.model.get_nodes_by_op_type(Operators.ATTENTION):
  269. if self.model.get_nodes_by_op_type(Operators.MULTI_HEAD_ATTENTION):
  270. logger.error("Packing mode does not support both Attention and MultiHeadAttention in same graph.")
  271. return None
  272. packing = PackingAttention(self.model)
  273. return packing.convert(use_symbolic_shape_infer)
  274. elif self.model.get_nodes_by_op_type(Operators.MULTI_HEAD_ATTENTION):
  275. packing = PackingMultiHeadAttention(self.model)
  276. return packing.convert(use_symbolic_shape_infer)
  277. else:
  278. logger.error("Packing mode requires either Attention or MultiHeadAttention node in onnx graph.")
  279. return None
  280. def _parse_arguments():
  281. parser = argparse.ArgumentParser(
  282. description="Convert to packing mode tool for ONNX Runtime. It converts BERT like model to use packing mode."
  283. )
  284. parser.add_argument("--input", required=True, type=str, help="input onnx model path")
  285. parser.add_argument("--output", required=True, type=str, help="optimized onnx model path")
  286. parser.add_argument("--verbose", required=False, action="store_true", help="show debug information.")
  287. parser.set_defaults(verbose=False)
  288. parser.add_argument(
  289. "--use_external_data_format",
  290. required=False,
  291. action="store_true",
  292. help="use external data format to store large model (>2GB)",
  293. )
  294. parser.set_defaults(use_external_data_format=False)
  295. args = parser.parse_args()
  296. return args
  297. def _setup_logger(verbose):
  298. if verbose:
  299. coloredlogs.install(
  300. level="DEBUG",
  301. fmt="[%(filename)s:%(lineno)s - %(funcName)20s()] %(message)s",
  302. )
  303. else:
  304. coloredlogs.install(fmt="%(funcName)20s: %(message)s")
  305. def main():
  306. args = _parse_arguments()
  307. _setup_logger(args.verbose)
  308. logger.debug(f"arguments:{args}")
  309. if os.path.realpath(args.input) == os.path.realpath(args.output):
  310. logger.warning("Specified the same input and output path. Note that this may overwrite the original model")
  311. model = load_model(args.input)
  312. packing_mode = PackingMode(OnnxModel(model))
  313. packing_mode.convert()
  314. packing_mode.model.save_model_to_file(args.output, use_external_data_format=args.use_external_data_format)
# Run the command-line conversion only when this file is executed as a script.
if __name__ == "__main__":
    main()