| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386 |
- # -------------------------------------------------------------------------
- # Copyright (c) Microsoft Corporation. All rights reserved.
- # Licensed under the MIT License.
- # --------------------------------------------------------------------------
- import argparse
- import logging
- import os
- import coloredlogs
- from constants import (
- AttentionInputIDs,
- AttentionOutputIDs,
- MultiHeadAttentionInputIDs,
- MultiHeadAttentionOutputIDs,
- Operators,
- )
- from onnx import helper, load_model
- from onnx_model import NodeProto, OnnxModel
- from shape_infer_helper import SymbolicShapeInferenceHelper
- logger = logging.getLogger(__name__)  # Module-level logger, named after this module.
class PackingAttentionBase:
    """Common driver for converting the attention nodes of a graph to packing mode.

    ``convert`` queues a RemovePadding node on the tensor that feeds the first
    attention node, queues a RestorePadding node on the output of the last
    LayerNormalization / SkipLayerNormalization node, and delegates the actual
    attention replacement to the subclass.

    Subclasses must implement ``_are_attentions_supported`` and
    ``_replace_attention_with_packing_attention``.
    """

    def __init__(self, model: OnnxModel, attention_op_type: str):
        """Collect every node of ``attention_op_type`` found in ``model``."""
        self.model: OnnxModel = model
        self.attention_op_type = attention_op_type
        self.attention_nodes = self.model.get_nodes_by_op_type(attention_op_type)
        self.this_graph_name: str = self.model.model.graph.name
        # Pending graph edits; they are applied together near the end of convert().
        self.nodes_to_add: list = []
        self.nodes_to_remove: list = []
        self.node_name_to_graph_name: dict = {}
        self.prune_graph: bool = False

    def _try_getting_attention_mask(self) -> str | None:
        """Return the mask input shared by all attention nodes, or None.

        None is returned when there is no attention node, when any attention
        node lacks the mask slot, or when the nodes do not all use the same mask.
        """
        if self.attention_op_type == Operators.ATTENTION:
            mask_slot = AttentionInputIDs.MASK_INDEX
        else:
            mask_slot = MultiHeadAttentionInputIDs.KEY_PADDING_MASK

        head = self._try_getting_first_attention()
        if not head:
            return None
        if len(head.input) <= mask_slot:
            return None

        shared_mask = head.input[mask_slot]
        # Every attention node must carry exactly the same mask tensor name.
        uses_shared_mask = all(
            len(candidate.input) > mask_slot and candidate.input[mask_slot] == shared_mask
            for candidate in self.attention_nodes
        )
        return shared_mask if uses_shared_mask else None

    def _try_getting_first_attention(self) -> NodeProto | None:
        """Return the first collected attention node, or None when there is none."""
        return self.attention_nodes[0] if len(self.attention_nodes) > 0 else None

    def _try_getting_last_layernorm(self) -> NodeProto | None:
        """Return the last (Skip)LayerNormalization node in model order, or None."""
        layernorm_ops = (Operators.LAYERNORM, Operators.SKIPLAYERNORM)
        found = None
        for candidate in self.model.nodes():
            if candidate.op_type in layernorm_ops:
                found = candidate
        return found

    def _are_attentions_supported(self) -> bool:
        """Tell whether all attention nodes can be converted (subclass hook)."""
        raise NotImplementedError()

    def _queue_contrib_node(self, op_type: str, inputs: list[str], outputs: list[str]) -> None:
        """Build a ``com.microsoft`` node and queue it for insertion into this graph."""
        node_name = self.model.create_node_name(op_type)
        node = helper.make_node(op_type, inputs=inputs, outputs=outputs, name=node_name)
        node.domain = "com.microsoft"
        self.nodes_to_add.append(node)
        self.node_name_to_graph_name[node.name] = self.this_graph_name

    def _insert_removepadding_node(self, inputs: list[str], outputs: list[str]) -> None:
        """Queue a RemovePadding node with the given input and output names."""
        self._queue_contrib_node(Operators.REMOVEPADDING, inputs, outputs)

    def _insert_restorepadding_node(self, inputs: list[str], outputs: list[str]) -> None:
        """Queue a RestorePadding node with the given input and output names."""
        self._queue_contrib_node(Operators.RESTOREPADDING, inputs, outputs)

    def _replace_attention_with_packing_attention(self, token_offset: str, cumulative_sequence_length: str) -> None:
        """Queue packed replacements for the attention nodes (subclass hook)."""
        raise NotImplementedError()

    def _get_input_to_remove_padding(self, first_attention_node) -> str | None:
        """Return the tensor name RemovePadding should consume, or None.

        The base implementation only handles the Attention operator, for which
        it is the node's own data input.
        """
        if self.attention_op_type != Operators.ATTENTION:
            return None
        return first_attention_node.input[AttentionInputIDs.INPUT]

    def convert(self, use_symbolic_shape_infer: bool = True) -> None:
        """Rewrite the model in place to packing mode.

        Returns silently without touching the graph when the attention nodes
        are unsupported, when no shared mask is found, when there is no
        (Skip)LayerNormalization node, or when the RemovePadding input cannot
        be determined.

        Args:
            use_symbolic_shape_infer: when True, run symbolic shape inference
                on the rewritten model and keep its result if one is returned.
        """
        logger.debug("start converting to packing model...")

        if not self._are_attentions_supported():
            return

        padding_mask = self._try_getting_attention_mask()
        if not padding_mask:
            return

        leading_attention = self._try_getting_first_attention()
        trailing_layernorm = self._try_getting_last_layernorm()
        if not trailing_layernorm:
            return

        padded_input = self._get_input_to_remove_padding(leading_attention)
        if not padded_input:
            return

        # RemovePadding: existing consumers of the padded tensor are redirected to
        # the unpadded output. The queued node itself is not in the graph yet, so
        # it keeps the original padded tensor as its input.
        unpadded_output = f"{padded_input}_no_padding"
        token_offset = f"{padded_input}_token_offset"
        cumulated_seq_len = f"{padded_input}_cumulated_seq_len"
        max_seq_len = f"{padded_input}_max_seq_len"
        self._insert_removepadding_node(
            [padded_input, padding_mask],
            [unpadded_output, token_offset, cumulated_seq_len, max_seq_len],
        )
        self.model.replace_input_of_all_nodes(padded_input, unpadded_output)
        logger.debug("inserted RemovePadding before Attention")

        # RestorePadding: it takes over the original output name, while the
        # producer of that output is renamed to feed RestorePadding instead.
        final_output = trailing_layernorm.output[0]
        restore_input = f"{final_output}_restore_input"
        self._insert_restorepadding_node([restore_input, token_offset], [final_output])
        self.model.replace_output_of_all_nodes(final_output, restore_input)
        logger.debug(f"inserted RestorePadding after last {trailing_layernorm.op_type} layer")

        # Packed attention replacement is operator specific.
        self._replace_attention_with_packing_attention(token_offset, cumulated_seq_len)
        logger.debug(f"replaced {self.attention_op_type} with Packed{self.attention_op_type}")

        # Apply all queued edits.
        self.model.remove_nodes(self.nodes_to_remove)
        self.model.add_nodes(self.nodes_to_add, self.node_name_to_graph_name)

        if self.prune_graph:
            self.model.prune_graph()
        elif self.nodes_to_remove or self.nodes_to_add:
            self.model.update_graph()
        self.model.clean_shape_infer()

        if not use_symbolic_shape_infer:
            return

        # Use symbolic shape inference since custom operators (like Gelu,
        # SkipLayerNormalization etc) are not recognized by onnx shape inference.
        shape_infer_helper = SymbolicShapeInferenceHelper(self.model.model, verbose=0)
        inferred_model = shape_infer_helper.infer_shapes(self.model.model, auto_merge=True, guess_output_rank=False)
        if inferred_model:
            self.model.model = inferred_model
class PackingAttention(PackingAttentionBase):
    """Converts ``Attention`` nodes into ``PackedAttention`` nodes."""

    def __init__(self, model: OnnxModel):
        """Collect the Attention nodes of ``model``."""
        super().__init__(model, Operators.ATTENTION)

    @staticmethod
    def _has_input(node, index: int) -> bool:
        """Return True when ``node`` has a non-empty input name at ``index``.

        An empty name is the placeholder for an omitted optional input, so it
        counts as "not provided".
        """
        return len(node.input) > index and bool(node.input[index])

    def _are_attentions_supported(self) -> bool:
        """Return True when every Attention node can be replaced.

        A node is rejected when it sets ``past_present_share_buffer`` or
        ``do_rotary``, when ``unidirectional`` is non-zero, or when it actually
        receives a ``past`` / ``past_sequence_length`` input. The replacement
        node built in ``_replace_attention_with_packing_attention`` does not
        forward any of these.
        """
        for node in self.attention_nodes:
            if OnnxModel.get_node_attribute(node, "past_present_share_buffer") is not None:
                return False

            if OnnxModel.get_node_attribute(node, "do_rotary") is not None:
                return False

            unidirection_attr = OnnxModel.get_node_attribute(node, "unidirectional")
            if unidirection_attr is not None and unidirection_attr != 0:
                return False

            # Only a non-empty past input is a problem. An empty placeholder in
            # that slot (e.g. when only attention_bias is supplied) is fine.
            if self._has_input(node, AttentionInputIDs.PAST):
                return False

            if self._has_input(node, AttentionInputIDs.PAST_SEQUENCE_LENGTH):
                return False

        return True

    def _replace_attention_with_packing_attention(self, token_offset: str, cumulative_sequence_length: str) -> None:
        """Queue a PackedAttention node for each Attention node and drop the original.

        Args:
            token_offset: name of the token offset tensor produced by RemovePadding.
            cumulative_sequence_length: name of the cumulative sequence length
                tensor produced by RemovePadding.
        """
        copied_attributes = ("num_heads", "qkv_hidden_sizes", "scale")

        for attention in self.attention_nodes:
            # attention_bias is optional; an empty name marks it as absent.
            attention_bias = (
                attention.input[AttentionInputIDs.ATTENTION_BIAS]
                if len(attention.input) > AttentionInputIDs.ATTENTION_BIAS
                else ""
            )

            packed_attention = helper.make_node(
                Operators.PACKEDATTENTION,
                inputs=[
                    attention.input[AttentionInputIDs.INPUT],
                    attention.input[AttentionInputIDs.WEIGHTS],
                    attention.input[AttentionInputIDs.BIAS],
                    token_offset,
                    cumulative_sequence_length,
                    attention_bias,
                ],
                outputs=[attention.output[AttentionOutputIDs.OUTPUT]],
                name=self.model.create_node_name(Operators.PACKEDATTENTION),
            )

            # Only these attributes are carried over to the packed node.
            packed_attention.attribute.extend(
                [attr for attr in attention.attribute if attr.name in copied_attributes]
            )
            packed_attention.domain = "com.microsoft"

            self.nodes_to_add.append(packed_attention)
            self.nodes_to_remove.append(attention)
            self.node_name_to_graph_name[packed_attention.name] = self.this_graph_name

        logger.info("Converted %d Attention nodes to PackedAttention.", len(self.attention_nodes))
class PackingMultiHeadAttention(PackingAttentionBase):
    """Converts ``MultiHeadAttention`` nodes into ``PackedMultiHeadAttention`` nodes."""

    # Attributes accepted on the source node and copied to the packed node.
    _SUPPORTED_ATTRIBUTES = ("num_heads", "mask_filter_value", "scale")

    def __init__(self, model: OnnxModel):
        """Collect the MultiHeadAttention nodes of ``model``."""
        super().__init__(model, Operators.MULTI_HEAD_ATTENTION)

    @staticmethod
    def _input_or_empty(node, index: int) -> str:
        """Return the input name at ``index``, or "" when the node has no such slot."""
        return node.input[index] if len(node.input) > index else ""

    def _check_empty_input(self, node, index: int, name: str):
        """Check that a node does not have the given input.

        Returns True when the slot is missing or empty; otherwise logs an error
        mentioning ``name`` and returns False.
        """
        if len(node.input) > index and len(node.input[index]) > 0:
            logger.error(f"node input {index} ({name}) is not supported in PackedMultiHeadAttention: {node}")
            return False
        return True

    def _check_empty_output(self, node, index: int, name: str):
        """Check that a node does not have the given output.

        Returns True when the slot is missing or empty; otherwise logs an error
        mentioning ``name`` and returns False.
        """
        if len(node.output) > index and len(node.output[index]) > 0:
            logger.error(f"node output {index} ({name}) is not supported in PackedMultiHeadAttention: {node}")
            return False
        return True

    def _are_attentions_supported(self) -> bool:
        """Return True when every MultiHeadAttention node can be replaced.

        A node is rejected (with an error log) when it has an attribute outside
        the supported set, when it has a key input but no value input, or when
        it uses any past/present key/value tensor.
        """
        for node in self.attention_nodes:
            for attr in node.attribute:
                if attr.name not in self._SUPPORTED_ATTRIBUTES:
                    logger.error(f"node attribute {attr.name} is not supported in PackedMultiHeadAttention: {node}")
                    return False

            # This runs before any input-count validation, so read the optional
            # key/value slots defensively instead of indexing past the end.
            key_input = self._input_or_empty(node, MultiHeadAttentionInputIDs.KEY)
            value_input = self._input_or_empty(node, MultiHeadAttentionInputIDs.VALUE)
            if key_input and not value_input:
                logger.error("packed kv format is not supported in PackedMultiHeadAttention")
                return False

            if not (
                self._check_empty_input(node, MultiHeadAttentionInputIDs.PAST_KEY, "past_key")
                and self._check_empty_input(node, MultiHeadAttentionInputIDs.PAST_VALUE, "past_value")
                and self._check_empty_output(node, MultiHeadAttentionOutputIDs.PRESENT_KEY, "present_key")
                and self._check_empty_output(node, MultiHeadAttentionOutputIDs.PRESENT_VALUE, "present_value")
            ):
                return False

        return True

    def _replace_attention_with_packing_attention(self, token_offset: str, cumulative_sequence_length: str) -> None:
        """Queue a PackedMultiHeadAttention node for each MultiHeadAttention node.

        When the attention bias is produced by a six-input
        GatedRelativePositionBias node, ``token_offset`` is appended to that
        node's inputs as well.

        Args:
            token_offset: name of the token offset tensor produced by RemovePadding.
            cumulative_sequence_length: name of the cumulative sequence length
                tensor produced by RemovePadding.
        """
        gated_relative_pos_bias_count = 0

        for mha in self.attention_nodes:
            # Optional inputs that are absent are passed on as empty names.
            attention_bias = self._input_or_empty(mha, MultiHeadAttentionInputIDs.ATTENTION_BIAS)

            packed_mha = helper.make_node(
                Operators.PACKED_MULTI_HEAD_ATTENTION,
                inputs=[
                    mha.input[MultiHeadAttentionInputIDs.QUERY],
                    self._input_or_empty(mha, MultiHeadAttentionInputIDs.KEY),
                    self._input_or_empty(mha, MultiHeadAttentionInputIDs.VALUE),
                    self._input_or_empty(mha, MultiHeadAttentionInputIDs.BIAS),
                    token_offset,
                    cumulative_sequence_length,
                    attention_bias,
                ],
                outputs=[mha.output[MultiHeadAttentionOutputIDs.OUTPUT]],
                name=self.model.create_node_name(Operators.PACKED_MULTI_HEAD_ATTENTION),
            )

            packed_mha.attribute.extend(
                [attr for attr in mha.attribute if attr.name in self._SUPPORTED_ATTRIBUTES]
            )
            packed_mha.domain = "com.microsoft"

            self.nodes_to_add.append(packed_mha)
            self.nodes_to_remove.append(mha)
            self.node_name_to_graph_name[packed_mha.name] = self.this_graph_name

            # Append token_offset input to GatedRelativePositionBias.
            # The input-count test also keeps a bias node shared by several
            # attention nodes from being extended (and counted) more than once.
            if attention_bias:
                rel_pos_bias_node = self.model.get_parent(mha, MultiHeadAttentionInputIDs.ATTENTION_BIAS)
                if (
                    rel_pos_bias_node
                    and rel_pos_bias_node.op_type == "GatedRelativePositionBias"
                    and len(rel_pos_bias_node.input) == 6
                ):
                    rel_pos_bias_node.input.append(token_offset)
                    gated_relative_pos_bias_count += 1

        logger.info("Converted %d MultiHeadAttention nodes to PackedMultiHeadAttention.", len(self.attention_nodes))
        logger.info("Converted %d GatedRelativePositionBias nodes to packing mode.", gated_relative_pos_bias_count)

    def _get_input_to_remove_padding(self, first_attention_node) -> str | None:
        """Return the first input of the MatMul feeding input 0 of the node, or None.

        When there are query, key and value inputs, RemovePadding has to be
        applied to the first input of the parent MatMul node; None is returned
        when that parent is missing or is not a MatMul.
        """
        matmul = self.model.get_parent(first_attention_node, 0)
        if matmul and matmul.op_type == "MatMul":
            return matmul.input[0]
        return None
class PackingMode:
    """Entry point that picks the packing converter matching the model's attention op."""

    def __init__(self, model: OnnxModel):
        """Keep a reference to the model that ``convert`` rewrites."""
        self.model = model

    def convert(self, use_symbolic_shape_infer: bool = True) -> None:
        """Convert the model to packing mode.

        Logs an error and does nothing when the graph mixes Attention and
        MultiHeadAttention nodes, or contains neither of them.

        Args:
            use_symbolic_shape_infer: forwarded to the selected converter.
        """
        has_attention = bool(self.model.get_nodes_by_op_type(Operators.ATTENTION))
        has_mha = bool(self.model.get_nodes_by_op_type(Operators.MULTI_HEAD_ATTENTION))

        if has_attention and has_mha:
            logger.error("Packing mode does not support both Attention and MultiHeadAttention in same graph.")
            return None

        if not (has_attention or has_mha):
            logger.error("Packing mode requires either Attention or MultiHeadAttention node in onnx graph.")
            return None

        converter_cls = PackingAttention if has_attention else PackingMultiHeadAttention
        return converter_cls(self.model).convert(use_symbolic_shape_infer)
- def _parse_arguments():
- parser = argparse.ArgumentParser(
- description="Convert to packing mode tool for ONNX Runtime. It converts BERT like model to use packing mode."
- )
- parser.add_argument("--input", required=True, type=str, help="input onnx model path")
- parser.add_argument("--output", required=True, type=str, help="optimized onnx model path")
- parser.add_argument("--verbose", required=False, action="store_true", help="show debug information.")
- parser.set_defaults(verbose=False)
- parser.add_argument(
- "--use_external_data_format",
- required=False,
- action="store_true",
- help="use external data format to store large model (>2GB)",
- )
- parser.set_defaults(use_external_data_format=False)
- args = parser.parse_args()
- return args
def _setup_logger(verbose):
    """Install coloredlogs; a truthy ``verbose`` selects DEBUG level and a detailed format."""
    if not verbose:
        coloredlogs.install(fmt="%(funcName)20s: %(message)s")
        return

    coloredlogs.install(
        level="DEBUG",
        fmt="[%(filename)s:%(lineno)s - %(funcName)20s()] %(message)s",
    )
def main():
    """Command line entry point: load the model, convert it to packing mode, save it."""
    args = _parse_arguments()
    _setup_logger(args.verbose)

    logger.debug(f"arguments:{args}")

    # Writing to the input path is allowed, but the user is warned about it.
    writes_over_input = os.path.realpath(args.input) == os.path.realpath(args.output)
    if writes_over_input:
        logger.warning("Specified the same input and output path. Note that this may overwrite the original model")

    onnx_model = OnnxModel(load_model(args.input))
    converter = PackingMode(onnx_model)
    converter.convert()
    converter.model.save_model_to_file(args.output, use_external_data_format=args.use_external_data_format)


if __name__ == "__main__":
    main()
|