# # The implementation of this file is based on: # https://github.com/intel/neural-compressor/tree/master/neural_compressor # # Copyright (c) 2023 Intel Corporation # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. """Class for ONNX model.""" import copy import logging import os import sys from collections import deque from pathlib import Path import onnx import onnx.external_data_helper from .util import MAXIMUM_PROTOBUF, find_by_name logger = logging.getLogger("neural_compressor") # TODO: Check https://github.com/microsoft/onnxruntime/blob/main/onnxruntime/python/tools/quantization/onnx_model.py to see if we can integrate with it. class ONNXModel: """Build ONNX model.""" def __init__(self, model, **kwargs): """Initialize an ONNX model. Args: model (str or ModelProto): path to onnx model or loaded ModelProto model object. ignore_warning (bool): ignore large model warning. Default is False. load_external_data (bool): load external data for large model. Default is True. """ self._model = model if not isinstance(model, str) else onnx.load(model, load_external_data=False) self._model_path = None if not isinstance(model, str) else model self.check_is_large_model() if self._is_large_model and self._model_path is None and not kwargs.get("ignore_warning", False): logger.warning("Model size > 2GB. Please use model path instead of onnx model object to quantize") if self._is_large_model and isinstance(model, str) and kwargs.get("load_external_data", True): onnx.external_data_helper.load_external_data_for_model(self._model, os.path.dirname(self._model_path)) self._config = None if isinstance(model, str) and os.path.exists(Path(model).parent.joinpath("config.json").as_posix()): from transformers import AutoConfig # noqa: PLC0415 self._config = AutoConfig.from_pretrained(Path(model).parent.as_posix()) self.node_name_counter = {} self._output_name_to_node = {} self._input_name_to_nodes = {} self._get_input_name_to_nodes(self._model.graph.node) self._get_output_name_to_node(self._model.graph.node) self._graph_info = {} self._get_graph_info() self._q_config = None def check_is_large_model(self): """Check model > 2GB.""" init_size = 0 for init in self._model.graph.initializer: # if initializer has external data location, return True if init.HasField("data_location") and init.data_location == onnx.TensorProto.EXTERNAL: self._is_large_model = True return # if raise error of initializer size > 2GB, return True try: init_bytes = init.SerializeToString() init_size += sys.getsizeof(init_bytes) except Exception as e: if "exceeds maximum protobuf size of 2GB" in str(e): self._is_large_model = True return else: # pragma: no cover raise e if init_size > MAXIMUM_PROTOBUF: self._is_large_model = True return self._is_large_model = False @property def is_large_model(self): """Check the onnx model is over 2GB.""" return self._is_large_model @property def model_path(self): """Return model path.""" return self._model_path @model_path.setter def model_path(self, path): """Set model path.""" self._model_path = path def framework(self): """Return framework.""" return "onnxruntime" @property def q_config(self): """Return q_config.""" return self._q_config @q_config.setter def q_config(self, q_config): """Set q_config.""" self._q_config = q_config @property def hf_config(self): """Return huggingface config if model is Transformer-based.""" return self._config @property def model(self): """Return model itself.""" return self._model @model.setter def model(self, model): """Set model itself.""" self._model = model self._graph_info = {} self._get_graph_info() self._output_name_to_node = {} self._input_name_to_nodes = {} self._get_input_name_to_nodes(self._model.graph.node) self._get_output_name_to_node(self._model.graph.node) def input(self): """Return input of model.""" return [i.name for i in self._model.graph.input] def output(self): """Return output of model.""" return [i.name for i in self._model.graph.output] def update(self): """Update model info.""" self._graph_info = {} self._get_graph_info() self._output_name_to_node = {} self._input_name_to_nodes = {} self._get_input_name_to_nodes(self._model.graph.node) self._get_output_name_to_node(self._model.graph.node) @property def graph_info(self): """Return ORT Graph Info object holding information about backend graph.""" return self._graph_info def _get_graph_info(self): """Update graph info.""" for node in self._model.graph.node: self.graph_info.update({node.name: node.op_type}) def save(self, root): """Save ONNX model.""" if os.path.split(root)[0] != "" and not os.path.exists(os.path.split(root)[0]): raise ValueError('"root" directory does not exists.') if self.is_large_model: onnx.external_data_helper.load_external_data_for_model(self._model, os.path.split(self._model_path)[0]) onnx.save_model( self._model, root, save_as_external_data=True, all_tensors_to_one_file=True, location=root.split("/")[-1] + "_data", size_threshold=1024, convert_attribute=False, ) else: onnx.save(self._model, root) if self._config is not None: model_type = "" if not hasattr(self._config, "model_type") else self._config.model_type self._config.__class__.model_type = model_type output_config_file = Path(root).parent.joinpath("config.json").as_posix() self._config.to_json_file(output_config_file, use_diff=False) def nodes(self): """Return model nodes.""" return self._model.graph.node def initializer(self): """Return model initializer.""" return self._model.graph.initializer def graph(self): """Return model graph.""" return self._model.graph def ir_version(self): """Return model ir_version.""" return self._model.ir_version def opset_import(self): """Return model opset_import.""" return self._model.opset_import def remove_node(self, node): """Remove a node from model.""" if node in self._model.graph.node: self._model.graph.node.remove(node) def remove_nodes(self, nodes_to_remove): """Remove nodes from model.""" for node in nodes_to_remove: self.remove_node(node) def add_node(self, node): """Add a node to model.""" self._model.graph.node.extend([node]) def add_nodes(self, nodes_to_add): """Add nodes to model.""" self._model.graph.node.extend(nodes_to_add) def add_initializer(self, tensor): """Add a initializer to model.""" if find_by_name(tensor.name, self._model.graph.initializer) is None: self._model.graph.initializer.extend([tensor]) def add_initializers(self, tensors): """Add initializers to model.""" for tensor in tensors: self.add_initializer(tensor) def get_initializer(self, name): """Get an initializer by name.""" for tensor in self._model.graph.initializer: if tensor.name == name: return tensor return None def get_initializer_share_num(self, name): """Get the number of shares of initializer.""" num = 0 if self.get_initializer(name) is None: return num for node in self.nodes(): if name in node.input: num += 1 return num def get_node(self, name): """Get a node by name.""" for node in self._model.graph.node: if node.name == name: return node return None def remove_initializer(self, tensor): """Remove an initializer from model.""" if tensor in self._model.graph.initializer: self._model.graph.initializer.remove(tensor) def remove_initializers(self, init_to_remove): """Remove initializers from model.""" for initializer in init_to_remove: self.remove_initializer(initializer) def set_initializer(self, tensor, array, raw=False): """Update initializer.""" old_tensor = self.get_initializer(tensor) self.remove_initializer(old_tensor) dims = old_tensor.dims data_type = old_tensor.data_type new_tensor = ( onnx.helper.make_tensor(tensor, data_type, dims, array.flatten().tolist()) if not raw else onnx.helper.make_tensor(tensor, data_type, dims, array.tostring(), raw=raw) ) self.add_initializer(new_tensor) @property def input_name_to_nodes(self): """Return input names of nodes.""" return self._input_name_to_nodes def _get_input_name_to_nodes(self, nodes): """Get input names of nodes.""" for node in nodes: attrs = [ attr for attr in node.attribute if attr.type == onnx.AttributeProto.GRAPH or attr.type == onnx.AttributeProto.GRAPHS ] if len(attrs) > 0: for attr in attrs: self._get_input_name_to_nodes(attr.g.node) for input_name in node.input: if len(input_name.strip()) != 0: if input_name not in self._input_name_to_nodes: self._input_name_to_nodes[input_name] = [node] else: self._input_name_to_nodes[input_name].append(node) @property def output_name_to_node(self): """Return output names of nodes.""" return self._output_name_to_node def _get_output_name_to_node(self, nodes): """Get output names of nodes.""" for node in nodes: attrs = [ attr for attr in node.attribute if attr.type == onnx.AttributeProto.GRAPH or attr.type == onnx.AttributeProto.GRAPHS ] if len(attrs) > 0: for attr in attrs: self._get_output_name_to_node(attr.g.node) for output_name in node.output: if len(output_name.strip()) != 0: self._output_name_to_node[output_name] = node def get_siblings(self, node): """Get siblings nodes.""" siblings = [] for parent in self.get_parents(node): for child in self.get_children(parent): if child.name != node.name: siblings.append(child) return siblings def get_children(self, node, input_name_to_nodes=None): """Get children nodes.""" if input_name_to_nodes is None: input_name_to_nodes = self._input_name_to_nodes children = [] for output in node.output: if output in input_name_to_nodes: for child in input_name_to_nodes[output]: children.append(child) # noqa: PERF402 return children def get_parents(self, node, output_name_to_node=None): """Get parents nodes.""" if output_name_to_node is None: output_name_to_node = self._output_name_to_node parents = [] for input in node.input: if input in output_name_to_node: parents.append(output_name_to_node[input]) return parents def get_parent(self, node, idx, output_name_to_node=None): """Get parent node by idx.""" if output_name_to_node is None: output_name_to_node = self._output_name_to_node if len(node.input) <= idx: return None input = node.input[idx] if input not in output_name_to_node: return None return output_name_to_node[input] def find_node_by_name(self, node_name, new_nodes_list, graph): """Find out node by name.""" graph_nodes_list = list(graph.node) # deep copy graph_nodes_list.extend(new_nodes_list) node = find_by_name(node_name, graph_nodes_list) return node def find_nodes_by_initializer(self, graph, initializer): """Find all nodes with given initializer as an input.""" nodes = [] for node in graph.node: for node_input in node.input: if node_input == initializer.name: nodes.append(node) return nodes def get_scale_zero(self, tensor): """Help function to get scale and zero_point.""" if not tensor.endswith("_quantized"): logger.debug(f"Find {tensor} in the quantized graph is not quantized.") return None, None def _searcher(tensor_name): """Search scale and zero point tensor recursively.""" node = self._input_name_to_nodes[tensor_name][0] parent = self._output_name_to_node.get(tensor_name, None) direct_int8 = ["Reshape", "Transpose", "Squeeze", "Unsqueeze", "MaxPool", "Pad", "Split"] if parent is not None and parent.op_type in direct_int8: fp32_tensor_name = ( parent.input[0] .replace("_quantized", "") .replace("_QuantizeLinear", "") .replace("_QuantizeInput", "") ) elif node.op_type in ["Gather"]: # pragma: no cover fp32_tensor_name = ( node.output[0] .replace("_quantized", "") .replace("_QuantizeLinear", "") .replace("_QuantizeInput", "") ) else: fp32_tensor_name = ( tensor_name.replace("_quantized", "").replace("_QuantizeLinear", "").replace("_QuantizeInput", "") ) scale = fp32_tensor_name + "_scale" scale_tensor = self.get_initializer(scale) zo = fp32_tensor_name + "_zero_point" zo_tensor = self.get_initializer(zo) if scale_tensor is None or zo_tensor is None: if parent is not None: scale_tensor, zo_tensor = _searcher(parent.input[0]) return scale_tensor, zo_tensor node = self._input_name_to_nodes[tensor][0] # TODO check if scale_tensor and zero_point is needed # for bias of qlinearconv, scale and zero_point is not needed if (node.op_type == "QLinearConv" and tensor == node.input[-1]) or ( node.op_type == "QGemm" and tensor == node.input[-3] ): return None, None else: scale_tensor, zo_tensor = _searcher(tensor) assert scale_tensor, f"missing scale for tensor {tensor}" assert zo_tensor, f"missing zero point for tensor {tensor}" return scale_tensor, zo_tensor def save_model_to_file(self, output_path, use_external_data_format=False): """Save model to external data, which is needed for model size > 2GB.""" if use_external_data_format: onnx.external_data_helper.convert_model_to_external_data( self._model, all_tensors_to_one_file=True, location=Path(output_path).name + ".data" ) onnx.save_model(self._model, output_path) @staticmethod def replace_node_input(node, old_input_name, new_input_name): """Replace input of a node.""" assert isinstance(old_input_name, str) and isinstance(new_input_name, str) for j in range(len(node.input)): if node.input[j] == old_input_name: node.input[j] = new_input_name def replace_input_of_all_nodes(self, old_input_name, new_input_name, white_optype=None, black_optype=None): """Replace inputs of all nodes.""" if white_optype is None: white_optype = [] if black_optype is None: black_optype = [] if len(white_optype) > 0: for node in self.model.graph.node: if node.op_type in white_optype: ONNXModel.replace_node_input(node, old_input_name, new_input_name) else: for node in self.model.graph.node: if node.op_type not in black_optype: ONNXModel.replace_node_input(node, old_input_name, new_input_name) @staticmethod def replace_node_output(node, old_output_name, new_output_name): """Replace output of a node.""" assert isinstance(old_output_name, str) and isinstance(new_output_name, str) for j in range(len(node.output)): if node.output[j] == old_output_name: node.output[j] = new_output_name def replace_output_of_all_nodes(self, old_output_name, new_output_name, white_optype=None, black_optype=None): """Replace outputs of all nodes.""" if white_optype is None: white_optype = [] if black_optype is None: black_optype = [] if len(white_optype) > 0: for node in self.model.graph.node: if node.op_type in white_optype: ONNXModel.replace_node_output(node, old_output_name, new_output_name) else: for node in self.model.graph.node: if node.op_type not in black_optype: ONNXModel.replace_node_output(node, old_output_name, new_output_name) def remove_unused_nodes(self): """Remove unused nodes.""" unused_nodes = [] nodes = self.nodes() for node in nodes: if ( node.op_type == "Constant" and node.output[0] not in self._model.graph.output and node.output[0] not in self._input_name_to_nodes ): unused_nodes.append(node) elif ( node.op_type == "QuantizeLinear" and len(self.get_children(node)) == 1 and self.get_children(node)[0].op_type == "DequantizeLinear" and node.input[0] not in self._output_name_to_node and self.get_children(node)[0].output[0] not in self._input_name_to_nodes ): unused_nodes.append(node) unused_nodes.extend(self.get_children(node)) else: # remove the node if it does not serve as the input or output of any other nodes unused = True for output in node.output: if output in self._input_name_to_nodes or output in self.output(): unused = False break for input in node.input: if self.get_initializer(input) is not None: continue elif input in self._output_name_to_node or input in self.input(): unused = False break if unused: unused_nodes.append(node) self.remove_nodes(unused_nodes) ununsed_weights = [] for w in self._model.graph.initializer: if w.name not in self._input_name_to_nodes and w.name not in self._model.graph.output: ununsed_weights.append(w) # Remove from graph.input for graph_input in self.graph().input: if graph_input.name == w.name: self.graph().input.remove(graph_input) self.remove_initializers(ununsed_weights) self.update() def topological_sort(self, enable_subgraph=False): """Topological sort the model.""" if not enable_subgraph: input_name_to_nodes = {} output_name_to_node = {} for node in self.model.graph.node: for input_name in node.input: if len(input_name.strip()) != 0: if input_name not in input_name_to_nodes: input_name_to_nodes[input_name] = [node] else: input_name_to_nodes[input_name].append(node) for output_name in node.output: if len(output_name.strip()) != 0: output_name_to_node[output_name] = node else: # pragma: no cover input_name_to_nodes = self._input_name_to_nodes output_name_to_node = self._output_name_to_node all_nodes = {} q = deque() wait = deque() for inp in self.model.graph.input: q.extend(input_name_to_nodes[inp.name]) for n in self.model.graph.node: if all(i not in output_name_to_node and i not in self.input() for i in n.input): q.append(n) while q: n = q.popleft() if not all(output_name_to_node[i].name in all_nodes for i in n.input if i in output_name_to_node): if n not in wait: wait.append(n) continue all_nodes[n.name] = n for out in n.output: if out in input_name_to_nodes: q.extend([i for i in input_name_to_nodes[out] if i.name not in all_nodes and i not in q]) if len(q) == 0 and len(wait) != 0: q = copy.deepcopy(wait) wait.clear() nodes = [i[1] for i in all_nodes.items()] assert len(list({n.name for n in nodes})) == len(list({n.name for n in self.model.graph.node})) self.model.graph.ClearField("node") self.model.graph.node.extend(nodes) def get_nodes_chain(self, start, stop, result_chain=None): """Get nodes chain with given start node and stop node.""" if result_chain is None: result_chain = [] # process start node list start_node = deque() for node in start: if isinstance(node, str): start_node.append(node) elif isinstance(node, onnx.NodeProto): start_node.append(node.name) else: assert False, "'get_nodes_chain' function only support list[string]or list[NodeProto] params" # noqa: B011 # process stop node list stop_node = [] for node in stop: if isinstance(node, str): stop_node.append(node) elif isinstance(node, onnx.NodeProto): stop_node.append(node.name) else: assert False, "'get_nodes_chain' function only support list[string]or list[NodeProto] params" # noqa: B011 while start_node: node_name = start_node.popleft() if node_name in stop_node: continue if node_name not in result_chain: result_chain.append(node_name) else: continue node = find_by_name(node_name, list(self.model.graph.node)) for parent in self.get_parents(node): start_node.append(parent.name) return result_chain def find_split_node_for_layer_wise_quantization(self): """Find split node for layer wise quantization.""" # find split nodes of decoder blocks # embed -> decoder.0 -(split_node)-> ... -(split_node)-> decoder.n -(split_node)-> norm -> head # after split: embed -> decoder.0, # decoder.1, # decoder.2, # ..., # decoder.n, # norm -> head start_nodes = [] for node in self._model.graph.node: start_node, qkv_nodes_list = None, None if node.op_type == "SkipLayerNormalization": start_node = node qkv_nodes_list = [ self.match_parent_path( start_node, ["MatMul", "Reshape", "Transpose", "Reshape", "MatMul"], [None, 0, 0, 0, 0], ), self.match_parent_path( start_node, ["Add", "MatMul", "Reshape", "Transpose", "MatMul"], [1, 1, 0, 0, 0], ), ] if node.op_type == "Add": start_node = node qkv_nodes_list = [ # match base attention structure self.match_parent_path( start_node, ["Add", "MatMul", "Reshape", "Transpose", "MatMul"], [0, None, 0, 0, 0], ), self.match_parent_path( start_node, ["Add", "MatMul", "Reshape", "Transpose", "MatMul"], [1, None, 0, 0, 0] ), # match gpt attention no past structure self.match_parent_path( start_node, ["Reshape", "Gemm", "Reshape", "Reshape", "Transpose", "MatMul"], [None, 0, 0, 0, 0, 0], output_name_to_node=self.output_name_to_node, return_indice=[], ), # match bart attention structure self.match_parent_path( start_node, ["Add", "MatMul", "Reshape", "Transpose", "Reshape", "MatMul"], [0, None, 0, 0, 0, 0], ), self.match_parent_path( start_node, ["Add", "MatMul", "Reshape", "Transpose", "Reshape", "MatMul"], [1, None, 0, 0, 0, 0], ), self.match_parent_path( start_node, ["MatMul", "Mul", "MatMul", "Mul", "Div", "Add"], [None, 0, None, 0, None, 0], ), self.match_parent_path( start_node, ["MatMul", "Mul", "MatMul", "SimplifiedLayerNormalization", "Add"], [None, 0, None, 0, 0], ), ] if not start_node: continue if not any(qkv_nodes_list): continue start_nodes.append(start_node) return start_nodes def find_qkv_in_attention(self, find_all=False): """Find qkv MatMul in Attention. Args: find_all (bool, optional): find all qkv MatMul. Defaults to False Returns: qkv (list): qkv MatMul list """ qkv = [] for node in self._model.graph.node: if node.op_type == "Attention": qkv.append([node.name]) continue start_node, qkv_nodes_list = None, None if node.op_type == "SkipLayerNormalization": start_node = node qkv_nodes_list = [ self.match_parent_path( start_node, ["MatMul", "Reshape", "Transpose", "Reshape", "MatMul"], [None, 0, 0, 0, 0], ), self.match_parent_path( start_node, ["Add", "MatMul", "Reshape", "Transpose", "MatMul"], [1, 1, 0, 0, 0], ), ] if node.op_type == "Add": start_node = node qkv_nodes_list = [ # match base attention structure self.match_parent_path( start_node, ["Add", "MatMul", "Reshape", "Transpose", "MatMul"], [0, None, 0, 0, 0], ), self.match_parent_path( start_node, ["Add", "MatMul", "Reshape", "Transpose", "MatMul"], [1, None, 0, 0, 0] ), # match gpt attention no past structure self.match_parent_path( start_node, ["Reshape", "Gemm", "Reshape", "Reshape", "Transpose", "MatMul"], [None, 0, 0, 0, 0, 0], output_name_to_node=self.output_name_to_node, return_indice=[], ), # match bart attention structure self.match_parent_path( start_node, ["Add", "MatMul", "Reshape", "Transpose", "Reshape", "MatMul"], [0, None, 0, 0, 0, 0], ), self.match_parent_path( start_node, ["Add", "MatMul", "Reshape", "Transpose", "Reshape", "MatMul"], [1, None, 0, 0, 0, 0], ), ] if not start_node: continue if not any(qkv_nodes_list): continue qkv_nodes = [qkv for qkv in qkv_nodes_list if qkv is not None][-1] other_inputs = [] for input in start_node.input: if input not in self.output_name_to_node: continue if input == qkv_nodes[0].output[0]: continue other_inputs.append(input) if len(other_inputs) != 1: continue root_input = other_inputs[0] input_name_to_nodes = self.input_name_to_nodes children = input_name_to_nodes[root_input] children_types = [child.op_type for child in children] if children_types.count("MatMul") == 3: qkv.append([child.name for child in children if child.op_type == "MatMul"]) if not find_all: break return qkv def find_ffn_matmul(self, attention_index, attention_matmul_list, block_len): """Find MatMul in FFN. Args: attention_index (list): index of Attention attention_matmul_list (list): list of Attention and MatMul nodes block_len (int): block length Returns: list: list of MatMul in FFN """ ffn_matmul = [] for idx in range(len(attention_index)): if idx != len(attention_index) - 1: index = attention_index[idx + 1] if index - 2 >= 0: ffn_matmul.append([attention_matmul_list[index - 2], attention_matmul_list[index - 1]]) else: index = attention_index[idx] if index + block_len - 1 < len(attention_matmul_list): ffn_matmul.append( [attention_matmul_list[index + block_len - 2], attention_matmul_list[index + block_len - 1]] ) return ffn_matmul def export(self, save_path, conf): """Export Qlinear to QDQ model.""" from neural_compressor.config import ONNXQlinear2QDQConfig # noqa: PLC0415 from neural_compressor.utils.export import onnx_qlinear_to_qdq # noqa: PLC0415 if isinstance(conf, ONNXQlinear2QDQConfig): add_nodes, remove_nodes, inits = onnx_qlinear_to_qdq(self._model, self._input_name_to_nodes) self.add_nodes(add_nodes) self.remove_nodes(remove_nodes) self.add_initializers(inits) self.update() self.remove_unused_nodes() self.topological_sort() self.save(save_path) else: logger.warning("Unsupported config for export, only ONNXQlinear2QDQConfig is supported!") exit(0) def add_tensors_to_outputs(self, tensor_names): """Add the tensors to the model outputs to gets their values. Args: tensor_names: The names of tensors to be dumped. """ added_outputs = [] for tensor in tensor_names: if tensor not in self.output(): added_tensor = onnx.helper.ValueInfoProto() added_tensor.name = tensor added_outputs.append(added_tensor) self._model.graph.output.extend(added_outputs) # pylint: disable=no-member def remove_tensors_from_outputs(self, tensor_names): """Remove the tensors from the model outputs. Args: tensor_names: The names of tensors to be removed. """ removed_outputs = [] for tensor in tensor_names: if tensor in self.output(): removed_outputs.append(self._model.graph.output[self.output().index(tensor)]) for output in removed_outputs: self._model.graph.output.remove(output) def match_first_parent(self, node, parent_op_type, output_name_to_node, exclude=None): """Find parent node based on constraints on op_type. Args: node (str): current node name. parent_op_type (str): constraint of parent node op_type. output_name_to_node (dict): dictionary with output name as key, and node as value. exclude (list): list of nodes that are excluded (not allowed to match as parent). Returns: parent: The matched parent node. None if not found. index: The input index of matched parent node. None if not found. """ if exclude is None: exclude = [] for i, input in enumerate(node.input): if input in output_name_to_node: parent = output_name_to_node[input] if parent.op_type == parent_op_type and parent not in exclude: return parent, i return None, None def match_parent( self, node, parent_op_type, input_index=None, output_name_to_node=None, exclude=None, return_indice=None, ): """Find parent node based on constraints on op_type and index. Args: node (str): current node name. parent_op_type (str): constraint of parent node op_type. input_index (int or None): only check the parent given input index of current node. output_name_to_node (dict): dictionary with output name as key, and node as value. exclude (list): list of nodes that are excluded (not allowed to match as parent). return_indice (list): a list to append the input index when input_index is None. Returns: parent: The matched parent node. """ assert node is not None assert input_index is None or input_index >= 0 if exclude is None: exclude = [] if output_name_to_node is None: output_name_to_node = self._output_name_to_node if input_index is None: parent, index = self.match_first_parent(node, parent_op_type, output_name_to_node, exclude) if return_indice is not None: return_indice.append(index) return parent if input_index >= len(node.input): return None parent = self.get_parent(node, input_index, output_name_to_node) if parent is not None and parent.op_type == parent_op_type and parent not in exclude: return parent return None def match_parent_path( self, node, parent_op_types, parent_input_index, output_name_to_node=None, return_indice=None, ): """Find a sequence of input edges based on constraints on parent op_type and index. Args: node (str): current node name. parent_op_types (str): constraint of parent node op_type of each input edge. parent_input_index (list): constraint of input index of each input edge. None means no constraint. output_name_to_node (dict): dictionary with output name as key, and node as value. return_indice (list): a list to append the input index when there is no constraint on input index of an edge. Returns: parents: a list of matched parent node. """ assert len(parent_input_index) == len(parent_op_types) if output_name_to_node is None: output_name_to_node = self._output_name_to_node current_node = node matched_parents = [] for i, op_type in enumerate(parent_op_types): matched_parent = self.match_parent( current_node, op_type, parent_input_index[i], output_name_to_node, exclude=[], return_indice=return_indice, ) if matched_parent is None: return None matched_parents.append(matched_parent) current_node = matched_parent return matched_parents def is_smoothquant_model(self): """Check the model is smooth quantized or not. Returns: bool: the model is smooth quantized or not. """ for init in self.model.graph.initializer: # noqa: SIM110 if "_smooth_scale" in init.name: return True return False def find_split_nodes(self): """Find split nodes for layer-wise quantization.""" split_nodes = self.find_split_node_for_layer_wise_quantization() return split_nodes def split_model_with_node( self, split_node_name, path_of_model_to_split, shape_infer=True, save_both_split_models=True ): """Split model into two parts at a given node. Args: split_node_name (str): name of the node where the model is split at> path_of_model_to_split (str): path of model to be split. shape_infer (bool): do shape inference. Default is True. save_both_split_models (bool): whether to save the two split models. False means only save the first split model. True means save both the two split models. Default id True. Returns: tuple: the first split model, the second split model """ # origin model : ... -> node_1 -> split_node -> node_2 -> ... # split model 1: ... -> node_1 -> split_node # split model 2: node_2 -> ... split_model_part_1 = onnx.ModelProto() split_model_part_1.CopyFrom(self._model) split_model_part_1.graph.ClearField("node") split_model_part_2 = onnx.ModelProto() split_model_part_2.CopyFrom(self._model) split_model_part_2.graph.ClearField("node") split_node_output = None part_idx = 1 for node in self._model.graph.node: if part_idx == 1: split_model_part_1.graph.node.append(node) elif part_idx == 2: split_model_part_2.graph.node.append(node) if node.name == split_node_name: split_node_output = node.output part_idx = 2 assert len(split_node_output) == 1, ( f"Only support split at node with 1 output tensor, while current split node {split_node_name} has {len(split_node_output)} output tensors" ) split_tensor_name = split_node_output[0] # infer shape of the model to be split if shape_infer: try: from neural_compressor.adaptor.ox_utils.util import infer_shapes # noqa: PLC0415 self._model = infer_shapes(self._model, auto_merge=True, base_dir=os.path.dirname(self._model_path)) except Exception as e: # pragma: no cover logger.error( "Shape infer fails for layer-wise quantization. " "We would recommend checking the graph optimization level of your model " "and setting it to 'DISABLE_ALL' or 'ENABLE_BASIC', " "as this may help avoid this error." ) raise e split_tensor_type, split_tensor_shape = self._get_output_type_shape_by_tensor_name(split_tensor_name) split_tensor = onnx.helper.make_tensor_value_info(split_tensor_name, split_tensor_type, split_tensor_shape) split_model_part_1 = ONNXModel(split_model_part_1, ignore_warning=True) split_model_part_2 = ONNXModel(split_model_part_2, ignore_warning=True) # remove unused input & output split_model_part_1._remove_unused_input_output() split_model_part_2._remove_unused_input_output() split_model_part_1.model.graph.output.append(split_tensor) split_model_part_2.model.graph.input.append(split_tensor) insert_output_for_model_1 = [] insert_input_for_model_2 = [] for output in split_model_part_1.output_name_to_node: if output in split_model_part_2.input_name_to_nodes: output_type, output_shape = self._get_output_type_shape_by_tensor_name(output) output_tensor = onnx.helper.make_tensor_value_info(output, output_type, output_shape) if output_tensor not in split_model_part_1.model.graph.output: insert_output_for_model_1.append(output_tensor) if output_tensor not in split_model_part_2.model.graph.input: insert_input_for_model_2.append(output_tensor) # insert model 1 output for output in insert_output_for_model_1: split_model_part_1.model.graph.output.append(output) # insert model 2 input for input in insert_input_for_model_2: split_model_part_2.model.graph.input.append(input) # remove unused init split_model_part_1.remove_unused_init() split_model_part_2.remove_unused_init() split_model_part_1.update() split_model_part_2.update() dir_of_model_to_split = os.path.dirname(path_of_model_to_split) split_model_part_1.load_model_initializer_by_tensor(dir_of_model_to_split) split_model_part_1_path = os.path.join(dir_of_model_to_split, "split_model_part_1.onnx") split_model_part_1.model_path = split_model_part_1_path split_model_part_1._save_split_model(split_model_part_1_path) split_model_part_1.check_is_large_model() logger.debug(f"save split model part 1 to {split_model_part_1_path} for layer wise quantization") if save_both_split_models: split_model_part_2.load_model_initializer_by_tensor(dir_of_model_to_split) split_model_part_2_path = os.path.join(dir_of_model_to_split, "split_model_part_2.onnx") split_model_part_2.model_path = split_model_part_2_path split_model_part_2._save_split_model(split_model_part_2_path) split_model_part_2.check_is_large_model() logger.debug(f"save split model part 2 to {split_model_part_2_path} for layer wise quantization") return split_model_part_1, split_model_part_2 else: return split_model_part_1, split_model_part_2 def _save_split_model(self, save_path): """Save split model as external data for layer wise quantization. Args: save_path (str): the path to save the split model """ if os.path.exists(save_path + "_data"): os.remove(save_path + "_data") onnx.save_model( self._model, save_path, save_as_external_data=True, all_tensors_to_one_file=True, location=save_path.split("/")[-1] + "_data", size_threshold=1024, convert_attribute=False, ) def _get_output_type_shape_by_tensor_name(self, tensor_name): """Get output type and shape with a tensor name. Args: tensor_name (str): name of a tensor Returns: tuple: output type and shape """ elem_type = onnx.TensorProto.FLOAT shape = None for output in self._model.graph.value_info: if output.name == tensor_name: elem_type = output.type.tensor_type.elem_type shape = [ dim.dim_value if dim.HasField("dim_value") else -1 for dim in output.type.tensor_type.shape.dim ] break return elem_type, shape def _remove_unused_input_output(self): """Remove unused input & output for split model.""" remove_outputs = [] remove_inputs = [] for output in self._model.graph.output: if output.name not in self.output_name_to_node: remove_outputs.append(output) for input in self._model.graph.input: if input.name not in self.input_name_to_nodes: remove_inputs.append(input) for output in remove_outputs: self._model.graph.output.remove(output) for input in remove_inputs: self._model.graph.input.remove(input) def remove_unused_init(self): """Remove unused init.""" remov_inits = [] for init in self._model.graph.initializer: if init.name not in self.input_name_to_nodes: remov_inits.append(init) self.remove_initializers(remov_inits) def load_model_initializer_by_tensor(self, data_path=None): """Load model initializer by tensor. Args: data_path (str, optional): the directory of saved initializer. Defaults to None. """ if data_path is None: data_path = os.path.dirname(self._model_path) for init in self._model.graph.initializer: if init.HasField("data_location") and init.data_location == onnx.TensorProto.EXTERNAL: onnx.external_data_helper.load_external_data_for_tensor(init, data_path) def write_external_data_to_new_location(self, external_data_location="external.data", overwrite=False): """Write external data of merged quantized model to new location to save memory. Args: external_data_location (str, optional): external data location of merged quantized model. Defaults to "external.data". overwrite (bool, optional): if True, remove existed externa data. Defaults to False. """ if overwrite and os.path.exists(os.path.join(os.path.dirname(self._model_path), external_data_location)): os.remove(os.path.join(os.path.dirname(self._model_path), external_data_location)) self.load_model_initializer_by_tensor() onnx.external_data_helper.convert_model_to_external_data(self._model, location=external_data_location) # TODO : if init is already saved, skip write it onnx.external_data_helper.write_external_data_tensors(self._model, filepath=os.path.dirname(self._model_path)) def merge_split_models(self, to_merge_model): """Merge two split model into final model.""" to_merge_model.write_external_data_to_new_location() self.add_nodes(list(to_merge_model.nodes())) self.add_initializers(list(to_merge_model.initializer())) self.update() # add new output for output in to_merge_model.graph().output: if output.name not in self.output(): self._model.graph.output.append(output) # remove unused output remove_output = [] for output in self._model.graph.output: if output.name in to_merge_model.input(): remove_output.append(output) for output in remove_output: self._model.graph.output.remove(output) # add new input for input in to_merge_model.graph().input: if ( input.name not in self.input() and input.name not in self.output() and input.name not in self.output_name_to_node ): self._model.graph.input.append(input) def re_org_output(self, origin_output): """Re-org output of merged model for layer-wise quantization.""" outputs = {} tmp_remove = [] for output in self._model.graph.output: outputs[output.name] = output tmp_remove.append(output) for output in tmp_remove: self._model.graph.output.remove(output) for out_name in origin_output: self._model.graph.output.append(outputs[out_name])