onnx_model.py 49 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862863864865866867868869870871872873874875876877878879880881882883884885886887888889890891892893894895896897898899900901902903904905906907908909910911912913914915916917918919920921922923924925926927928929930931932933934935936937938939940941942943944945946947948949950951952953954955956957958959960961962963964965966967968969970971972973974975976977978979980981982983984985986987988989990991992993994995996997998999100010011002100310041005100610071008100910101011101210131014101510161017101810191020102110221023102410251026102710281029103010311032103310341035103610371038103910401041104210431044104510461047104810491050105110521053105410551056105710581059106010611062106310641065106610671068106910701071107210731074107510761077107810791080108110821083108410851086108710881089109010911092109310941095109610971098109911001101110211031104110511061107110811091110111111121113111411151116111711181119112011211122112311241125112611271128112911301131113211331134113511361137113811391140114111421143114411451146114711481149115011511152115311541155115611571158115911601161116211631164116511661167116811691170117111721173117411751176117711781179118011811182118311841185118611871188118911901191119211931194119511961197119811991200120112021203120412051206120712081209121012111212121312141215121612171218121912201221122212231224122512261227122812291230123112321233123412351236123712381239124012411242124312441245124612471248124912501251
  1. #
  2. # The implementation of this file is based on:
  3. # https://github.com/intel/neural-compressor/tree/master/neural_compressor
  4. #
  5. # Copyright (c) 2023 Intel Corporation
  6. #
  7. # Licensed under the Apache License, Version 2.0 (the "License");
  8. # you may not use this file except in compliance with the License.
  9. # You may obtain a copy of the License at
  10. #
  11. # http://www.apache.org/licenses/LICENSE-2.0
  12. #
  13. # Unless required by applicable law or agreed to in writing, software
  14. # distributed under the License is distributed on an "AS IS" BASIS,
  15. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  16. # See the License for the specific language governing permissions and
  17. # limitations under the License.
  18. """Class for ONNX model."""
  19. import copy
  20. import logging
  21. import os
  22. import sys
  23. from collections import deque
  24. from pathlib import Path
  25. import onnx
  26. import onnx.external_data_helper
  27. from .util import MAXIMUM_PROTOBUF, find_by_name
  28. logger = logging.getLogger("neural_compressor")
  29. # TODO: Check https://github.com/microsoft/onnxruntime/blob/main/onnxruntime/python/tools/quantization/onnx_model.py to see if we can integrate with it.
  30. class ONNXModel:
  31. """Build ONNX model."""
  def __init__(self, model, **kwargs):
      """Initialize an ONNX model wrapper.

      Args:
          model (str, os.PathLike or ModelProto): path to an onnx model file, or an
              already loaded ModelProto object.
          ignore_warning (bool): do not warn when a model larger than 2GB is passed as
              an object instead of a path. Default is False.
          load_external_data (bool): for a large model given by path, load its external
              tensor data into memory. Default is True.
      """
      # Accept pathlib.Path (and any other path-like object) as well as str, so it is
      # loaded from disk instead of being mistaken for a ModelProto.
      if isinstance(model, os.PathLike):
          model = os.fspath(model)
      is_path = isinstance(model, str)
      # External data is not loaded here; check_is_large_model() decides first whether
      # it is needed, and it is loaded explicitly below.
      self._model = onnx.load(model, load_external_data=False) if is_path else model
      self._model_path = model if is_path else None
      self.check_is_large_model()
      if self._is_large_model and self._model_path is None and not kwargs.get("ignore_warning", False):
          logger.warning("Model size > 2GB. Please use model path instead of onnx model object to quantize")
      if self._is_large_model and is_path and kwargs.get("load_external_data", True):
          onnx.external_data_helper.load_external_data_for_model(self._model, os.path.dirname(self._model_path))

      # A HuggingFace config.json next to the model file is loaded so save() can write it back.
      self._config = None
      if is_path and os.path.exists(Path(model).parent.joinpath("config.json").as_posix()):
          from transformers import AutoConfig  # noqa: PLC0415

          self._config = AutoConfig.from_pretrained(Path(model).parent.as_posix())

      self.node_name_counter = {}
      # Tensor-name indexes: produced tensor -> producer node, consumed tensor -> consumer nodes.
      self._output_name_to_node = {}
      self._input_name_to_nodes = {}
      self._get_input_name_to_nodes(self._model.graph.node)
      self._get_output_name_to_node(self._model.graph.node)
      # node name -> op type for every top-level node.
      self._graph_info = {}
      self._get_graph_info()
      self._q_config = None
  54. self._get_output_name_to_node(self._model.graph.node)
  55. self._graph_info = {}
  56. self._get_graph_info()
  57. self._q_config = None
  58. def check_is_large_model(self):
  59. """Check model > 2GB."""
  60. init_size = 0
  61. for init in self._model.graph.initializer:
  62. # if initializer has external data location, return True
  63. if init.HasField("data_location") and init.data_location == onnx.TensorProto.EXTERNAL:
  64. self._is_large_model = True
  65. return
  66. # if raise error of initializer size > 2GB, return True
  67. try:
  68. init_bytes = init.SerializeToString()
  69. init_size += sys.getsizeof(init_bytes)
  70. except Exception as e:
  71. if "exceeds maximum protobuf size of 2GB" in str(e):
  72. self._is_large_model = True
  73. return
  74. else: # pragma: no cover
  75. raise e
  76. if init_size > MAXIMUM_PROTOBUF:
  77. self._is_large_model = True
  78. return
  79. self._is_large_model = False
  80. @property
  81. def is_large_model(self):
  82. """Check the onnx model is over 2GB."""
  83. return self._is_large_model
  84. @property
  85. def model_path(self):
  86. """Return model path."""
  87. return self._model_path
  88. @model_path.setter
  89. def model_path(self, path):
  90. """Set model path."""
  91. self._model_path = path
  def framework(self):
      """Return the backend identifier used for this model."""
      backend = "onnxruntime"
      return backend
  95. @property
  96. def q_config(self):
  97. """Return q_config."""
  98. return self._q_config
  99. @q_config.setter
  100. def q_config(self, q_config):
  101. """Set q_config."""
  102. self._q_config = q_config
  103. @property
  104. def hf_config(self):
  105. """Return huggingface config if model is Transformer-based."""
  106. return self._config
  107. @property
  108. def model(self):
  109. """Return model itself."""
  110. return self._model
  111. @model.setter
  112. def model(self, model):
  113. """Set model itself."""
  114. self._model = model
  115. self._graph_info = {}
  116. self._get_graph_info()
  117. self._output_name_to_node = {}
  118. self._input_name_to_nodes = {}
  119. self._get_input_name_to_nodes(self._model.graph.node)
  120. self._get_output_name_to_node(self._model.graph.node)
  121. def input(self):
  122. """Return input of model."""
  123. return [i.name for i in self._model.graph.input]
  124. def output(self):
  125. """Return output of model."""
  126. return [i.name for i in self._model.graph.output]
  127. def update(self):
  128. """Update model info."""
  129. self._graph_info = {}
  130. self._get_graph_info()
  131. self._output_name_to_node = {}
  132. self._input_name_to_nodes = {}
  133. self._get_input_name_to_nodes(self._model.graph.node)
  134. self._get_output_name_to_node(self._model.graph.node)
  135. @property
  136. def graph_info(self):
  137. """Return ORT Graph Info object holding information about backend graph."""
  138. return self._graph_info
  139. def _get_graph_info(self):
  140. """Update graph info."""
  141. for node in self._model.graph.node:
  142. self.graph_info.update({node.name: node.op_type})
  143. def save(self, root):
  144. """Save ONNX model."""
  145. if os.path.split(root)[0] != "" and not os.path.exists(os.path.split(root)[0]):
  146. raise ValueError('"root" directory does not exists.')
  147. if self.is_large_model:
  148. onnx.external_data_helper.load_external_data_for_model(self._model, os.path.split(self._model_path)[0])
  149. onnx.save_model(
  150. self._model,
  151. root,
  152. save_as_external_data=True,
  153. all_tensors_to_one_file=True,
  154. location=root.split("/")[-1] + "_data",
  155. size_threshold=1024,
  156. convert_attribute=False,
  157. )
  158. else:
  159. onnx.save(self._model, root)
  160. if self._config is not None:
  161. model_type = "" if not hasattr(self._config, "model_type") else self._config.model_type
  162. self._config.__class__.model_type = model_type
  163. output_config_file = Path(root).parent.joinpath("config.json").as_posix()
  164. self._config.to_json_file(output_config_file, use_diff=False)
  165. def nodes(self):
  166. """Return model nodes."""
  167. return self._model.graph.node
  168. def initializer(self):
  169. """Return model initializer."""
  170. return self._model.graph.initializer
  171. def graph(self):
  172. """Return model graph."""
  173. return self._model.graph
  174. def ir_version(self):
  175. """Return model ir_version."""
  176. return self._model.ir_version
  177. def opset_import(self):
  178. """Return model opset_import."""
  179. return self._model.opset_import
  180. def remove_node(self, node):
  181. """Remove a node from model."""
  182. if node in self._model.graph.node:
  183. self._model.graph.node.remove(node)
  184. def remove_nodes(self, nodes_to_remove):
  185. """Remove nodes from model."""
  186. for node in nodes_to_remove:
  187. self.remove_node(node)
  188. def add_node(self, node):
  189. """Add a node to model."""
  190. self._model.graph.node.extend([node])
  191. def add_nodes(self, nodes_to_add):
  192. """Add nodes to model."""
  193. self._model.graph.node.extend(nodes_to_add)
  def add_initializer(self, tensor):
      """Add ``tensor`` to the graph initializers unless one with the same name exists."""
      initializers = self._model.graph.initializer
      already_present = find_by_name(tensor.name, initializers) is not None
      if not already_present:
          initializers.extend([tensor])
  def add_initializers(self, tensors):
      """Add each tensor via add_initializer(), skipping names already present."""
      for new_tensor in tensors:
          self.add_initializer(new_tensor)
  202. def get_initializer(self, name):
  203. """Get an initializer by name."""
  204. for tensor in self._model.graph.initializer:
  205. if tensor.name == name:
  206. return tensor
  207. return None
  208. def get_initializer_share_num(self, name):
  209. """Get the number of shares of initializer."""
  210. num = 0
  211. if self.get_initializer(name) is None:
  212. return num
  213. for node in self.nodes():
  214. if name in node.input:
  215. num += 1
  216. return num
  217. def get_node(self, name):
  218. """Get a node by name."""
  219. for node in self._model.graph.node:
  220. if node.name == name:
  221. return node
  222. return None
  223. def remove_initializer(self, tensor):
  224. """Remove an initializer from model."""
  225. if tensor in self._model.graph.initializer:
  226. self._model.graph.initializer.remove(tensor)
  227. def remove_initializers(self, init_to_remove):
  228. """Remove initializers from model."""
  229. for initializer in init_to_remove:
  230. self.remove_initializer(initializer)
  def set_initializer(self, tensor, array, raw=False):
      """Replace the data of an existing initializer, keeping its dims and data type.

      Args:
          tensor (str): name of the initializer to replace.
          array: new values; must support ``flatten().tolist()`` and ``tobytes()``
              (presumably a numpy ndarray).
          raw (bool): if True, store the values as raw bytes; otherwise as a flat
              Python list.
      """
      old_tensor = self.get_initializer(tensor)
      self.remove_initializer(old_tensor)
      dims = old_tensor.dims
      data_type = old_tensor.data_type
      if raw:
          # ndarray.tobytes() replaces tostring(), which is deprecated and removed in newer NumPy.
          new_tensor = onnx.helper.make_tensor(tensor, data_type, dims, array.tobytes(), raw=raw)
      else:
          new_tensor = onnx.helper.make_tensor(tensor, data_type, dims, array.flatten().tolist())
      self.add_initializer(new_tensor)
  243. @property
  244. def input_name_to_nodes(self):
  245. """Return input names of nodes."""
  246. return self._input_name_to_nodes
  247. def _get_input_name_to_nodes(self, nodes):
  248. """Get input names of nodes."""
  249. for node in nodes:
  250. attrs = [
  251. attr
  252. for attr in node.attribute
  253. if attr.type == onnx.AttributeProto.GRAPH or attr.type == onnx.AttributeProto.GRAPHS
  254. ]
  255. if len(attrs) > 0:
  256. for attr in attrs:
  257. self._get_input_name_to_nodes(attr.g.node)
  258. for input_name in node.input:
  259. if len(input_name.strip()) != 0:
  260. if input_name not in self._input_name_to_nodes:
  261. self._input_name_to_nodes[input_name] = [node]
  262. else:
  263. self._input_name_to_nodes[input_name].append(node)
  264. @property
  265. def output_name_to_node(self):
  266. """Return output names of nodes."""
  267. return self._output_name_to_node
  268. def _get_output_name_to_node(self, nodes):
  269. """Get output names of nodes."""
  270. for node in nodes:
  271. attrs = [
  272. attr
  273. for attr in node.attribute
  274. if attr.type == onnx.AttributeProto.GRAPH or attr.type == onnx.AttributeProto.GRAPHS
  275. ]
  276. if len(attrs) > 0:
  277. for attr in attrs:
  278. self._get_output_name_to_node(attr.g.node)
  279. for output_name in node.output:
  280. if len(output_name.strip()) != 0:
  281. self._output_name_to_node[output_name] = node
  282. def get_siblings(self, node):
  283. """Get siblings nodes."""
  284. siblings = []
  285. for parent in self.get_parents(node):
  286. for child in self.get_children(parent):
  287. if child.name != node.name:
  288. siblings.append(child)
  289. return siblings
  290. def get_children(self, node, input_name_to_nodes=None):
  291. """Get children nodes."""
  292. if input_name_to_nodes is None:
  293. input_name_to_nodes = self._input_name_to_nodes
  294. children = []
  295. for output in node.output:
  296. if output in input_name_to_nodes:
  297. for child in input_name_to_nodes[output]:
  298. children.append(child) # noqa: PERF402
  299. return children
  300. def get_parents(self, node, output_name_to_node=None):
  301. """Get parents nodes."""
  302. if output_name_to_node is None:
  303. output_name_to_node = self._output_name_to_node
  304. parents = []
  305. for input in node.input:
  306. if input in output_name_to_node:
  307. parents.append(output_name_to_node[input])
  308. return parents
  309. def get_parent(self, node, idx, output_name_to_node=None):
  310. """Get parent node by idx."""
  311. if output_name_to_node is None:
  312. output_name_to_node = self._output_name_to_node
  313. if len(node.input) <= idx:
  314. return None
  315. input = node.input[idx]
  316. if input not in output_name_to_node:
  317. return None
  318. return output_name_to_node[input]
  def find_node_by_name(self, node_name, new_nodes_list, graph):
      """Look up ``node_name`` among ``graph.node`` followed by ``new_nodes_list``.

      Returns whatever ``find_by_name`` yields for the combined list.
      """
      # A new (shallow) list, so adding the new nodes does not mutate graph.node.
      candidates = [*graph.node, *new_nodes_list]
      return find_by_name(node_name, candidates)
  def find_nodes_by_initializer(self, graph, initializer):
      """Return the nodes of ``graph`` that take ``initializer`` as an input.

      A node is listed once per input slot that uses the initializer.
      """
      target = initializer.name
      return [node for node in graph.node for node_input in node.input if node_input == target]
  def get_scale_zero(self, tensor):
      """Help function to get scale and zero_point.

      Finds the scale and zero-point initializers of a quantized tensor. They are
      located by naming convention ("<fp32 name>_scale" / "<fp32 name>_zero_point"),
      walking up through producer nodes when the direct names are missing.

      Args:
          tensor (str): name of a tensor in the quantized graph.

      Returns:
          tuple: (scale_tensor, zero_point_tensor). Returns (None, None) when ``tensor``
          does not end with "_quantized". It also returns (None, None) when ``tensor``
          is the last input of its first QLinearConv consumer or the third-from-last
          input of its first QGemm consumer; these are treated as bias inputs that
          need no scale/zero point.

      Raises:
          KeyError: if ``tensor`` (or a tensor visited by the search) has no entry in
              input_name_to_nodes.
          AssertionError: if the scale or the zero point cannot be found.
      """
      if not tensor.endswith("_quantized"):
          logger.debug(f"Find {tensor} in the quantized graph is not quantized.")
          return None, None

      def _searcher(tensor_name):
          """Search scale and zero point tensor recursively."""
          # First consumer of the tensor (only used in the Gather branch) and its producer.
          node = self._input_name_to_nodes[tensor_name][0]
          parent = self._output_name_to_node.get(tensor_name, None)
          # Presumably ops that pass int8 data through, so the output shares the
          # quantization params of the producer's first input.
          direct_int8 = ["Reshape", "Transpose", "Squeeze", "Unsqueeze", "MaxPool", "Pad", "Split"]
          if parent is not None and parent.op_type in direct_int8:
              # Derive the fp32 base name from the producer's first input.
              fp32_tensor_name = (
                  parent.input[0]
                  .replace("_quantized", "")
                  .replace("_QuantizeLinear", "")
                  .replace("_QuantizeInput", "")
              )
          elif node.op_type in ["Gather"]:  # pragma: no cover
              # Derive the fp32 base name from the Gather consumer's output.
              fp32_tensor_name = (
                  node.output[0]
                  .replace("_quantized", "")
                  .replace("_QuantizeLinear", "")
                  .replace("_QuantizeInput", "")
              )
          else:
              fp32_tensor_name = (
                  tensor_name.replace("_quantized", "").replace("_QuantizeLinear", "").replace("_QuantizeInput", "")
              )
          scale = fp32_tensor_name + "_scale"
          scale_tensor = self.get_initializer(scale)
          zo = fp32_tensor_name + "_zero_point"
          zo_tensor = self.get_initializer(zo)
          if scale_tensor is None or zo_tensor is None:
              # Not found under this name: retry from the producer's first input.
              if parent is not None:
                  scale_tensor, zo_tensor = _searcher(parent.input[0])
          return scale_tensor, zo_tensor

      node = self._input_name_to_nodes[tensor][0]
      # TODO check if scale_tensor and zero_point is needed
      # for bias of qlinearconv, scale and zero_point is not needed
      if (node.op_type == "QLinearConv" and tensor == node.input[-1]) or (
          node.op_type == "QGemm" and tensor == node.input[-3]
      ):
          return None, None
      else:
          scale_tensor, zo_tensor = _searcher(tensor)
          assert scale_tensor, f"missing scale for tensor {tensor}"
          assert zo_tensor, f"missing zero point for tensor {tensor}"
          return scale_tensor, zo_tensor
  def save_model_to_file(self, output_path, use_external_data_format=False):
      """Write the model to ``output_path``.

      Args:
          output_path (str): destination file.
          use_external_data_format (bool): if True, first move all tensors into one
              external data file named ``<file name of output_path>.data``, as needed
              for models larger than 2GB.
      """
      if not use_external_data_format:
          onnx.save_model(self._model, output_path)
          return
      data_file_name = Path(output_path).name + ".data"
      onnx.external_data_helper.convert_model_to_external_data(
          self._model, all_tensors_to_one_file=True, location=data_file_name
      )
      onnx.save_model(self._model, output_path)
  388. @staticmethod
  389. def replace_node_input(node, old_input_name, new_input_name):
  390. """Replace input of a node."""
  391. assert isinstance(old_input_name, str) and isinstance(new_input_name, str)
  392. for j in range(len(node.input)):
  393. if node.input[j] == old_input_name:
  394. node.input[j] = new_input_name
  395. def replace_input_of_all_nodes(self, old_input_name, new_input_name, white_optype=None, black_optype=None):
  396. """Replace inputs of all nodes."""
  397. if white_optype is None:
  398. white_optype = []
  399. if black_optype is None:
  400. black_optype = []
  401. if len(white_optype) > 0:
  402. for node in self.model.graph.node:
  403. if node.op_type in white_optype:
  404. ONNXModel.replace_node_input(node, old_input_name, new_input_name)
  405. else:
  406. for node in self.model.graph.node:
  407. if node.op_type not in black_optype:
  408. ONNXModel.replace_node_input(node, old_input_name, new_input_name)
  409. @staticmethod
  410. def replace_node_output(node, old_output_name, new_output_name):
  411. """Replace output of a node."""
  412. assert isinstance(old_output_name, str) and isinstance(new_output_name, str)
  413. for j in range(len(node.output)):
  414. if node.output[j] == old_output_name:
  415. node.output[j] = new_output_name
  416. def replace_output_of_all_nodes(self, old_output_name, new_output_name, white_optype=None, black_optype=None):
  417. """Replace outputs of all nodes."""
  418. if white_optype is None:
  419. white_optype = []
  420. if black_optype is None:
  421. black_optype = []
  422. if len(white_optype) > 0:
  423. for node in self.model.graph.node:
  424. if node.op_type in white_optype:
  425. ONNXModel.replace_node_output(node, old_output_name, new_output_name)
  426. else:
  427. for node in self.model.graph.node:
  428. if node.op_type not in black_optype:
  429. ONNXModel.replace_node_output(node, old_output_name, new_output_name)
  430. def remove_unused_nodes(self):
  431. """Remove unused nodes."""
  432. unused_nodes = []
  433. nodes = self.nodes()
  434. for node in nodes:
  435. if (
  436. node.op_type == "Constant"
  437. and node.output[0] not in self._model.graph.output
  438. and node.output[0] not in self._input_name_to_nodes
  439. ):
  440. unused_nodes.append(node)
  441. elif (
  442. node.op_type == "QuantizeLinear"
  443. and len(self.get_children(node)) == 1
  444. and self.get_children(node)[0].op_type == "DequantizeLinear"
  445. and node.input[0] not in self._output_name_to_node
  446. and self.get_children(node)[0].output[0] not in self._input_name_to_nodes
  447. ):
  448. unused_nodes.append(node)
  449. unused_nodes.extend(self.get_children(node))
  450. else:
  451. # remove the node if it does not serve as the input or output of any other nodes
  452. unused = True
  453. for output in node.output:
  454. if output in self._input_name_to_nodes or output in self.output():
  455. unused = False
  456. break
  457. for input in node.input:
  458. if self.get_initializer(input) is not None:
  459. continue
  460. elif input in self._output_name_to_node or input in self.input():
  461. unused = False
  462. break
  463. if unused:
  464. unused_nodes.append(node)
  465. self.remove_nodes(unused_nodes)
  466. ununsed_weights = []
  467. for w in self._model.graph.initializer:
  468. if w.name not in self._input_name_to_nodes and w.name not in self._model.graph.output:
  469. ununsed_weights.append(w)
  470. # Remove from graph.input
  471. for graph_input in self.graph().input:
  472. if graph_input.name == w.name:
  473. self.graph().input.remove(graph_input)
  474. self.remove_initializers(ununsed_weights)
  475. self.update()
  def topological_sort(self, enable_subgraph=False):
      """Reorder ``self.model.graph.node`` topologically, in place.

      Args:
          enable_subgraph (bool): if True, reuse the cached tensor-name indexes, which
              also cover subgraph nodes. Otherwise build fresh indexes from the
              top-level nodes only.

      Raises:
          AssertionError: if not every node name could be placed (for example when the
              graph contains a cycle).
      """
      if not enable_subgraph:
          input_name_to_nodes = {}
          output_name_to_node = {}
          for node in self.model.graph.node:
              for input_name in node.input:
                  if input_name.strip():
                      input_name_to_nodes.setdefault(input_name, []).append(node)
              for output_name in node.output:
                  if output_name.strip():
                      output_name_to_node[output_name] = node
      else:  # pragma: no cover
          input_name_to_nodes = self._input_name_to_nodes
          output_name_to_node = self._output_name_to_node

      graph_inputs = set(self.input())
      # NOTE(review): placed nodes are keyed by name, so this assumes node names are unique.
      all_nodes = {}
      q = deque()
      wait = deque()
      # Seed with the consumers of graph inputs; an unused graph input has no entry.
      for inp in self.model.graph.input:
          q.extend(input_name_to_nodes.get(inp.name, []))
      # ... and with nodes none of whose inputs is produced by a node or is a graph input.
      for n in self.model.graph.node:
          if all(i not in output_name_to_node and i not in graph_inputs for i in n.input):
              q.append(n)
      while q:
          n = q.popleft()
          if not all(output_name_to_node[i].name in all_nodes for i in n.input if i in output_name_to_node):
              # Some producer is not placed yet; park the node and retry later.
              if n not in wait:
                  wait.append(n)
              continue
          all_nodes[n.name] = n
          for out in n.output:
              if out in input_name_to_nodes:
                  q.extend([i for i in input_name_to_nodes[out] if i.name not in all_nodes and i not in q])
          if not q and wait:
              # Retry the parked nodes. The same node objects are reused (no deep copy):
              # they are copied into the graph by extend() below anyway.
              q, wait = wait, deque()
      nodes = list(all_nodes.values())
      assert len({n.name for n in nodes}) == len({n.name for n in self.model.graph.node})
      self.model.graph.ClearField("node")
      self.model.graph.node.extend(nodes)
  519. def get_nodes_chain(self, start, stop, result_chain=None):
  520. """Get nodes chain with given start node and stop node."""
  521. if result_chain is None:
  522. result_chain = []
  523. # process start node list
  524. start_node = deque()
  525. for node in start:
  526. if isinstance(node, str):
  527. start_node.append(node)
  528. elif isinstance(node, onnx.NodeProto):
  529. start_node.append(node.name)
  530. else:
  531. assert False, "'get_nodes_chain' function only support list[string]or list[NodeProto] params" # noqa: B011
  532. # process stop node list
  533. stop_node = []
  534. for node in stop:
  535. if isinstance(node, str):
  536. stop_node.append(node)
  537. elif isinstance(node, onnx.NodeProto):
  538. stop_node.append(node.name)
  539. else:
  540. assert False, "'get_nodes_chain' function only support list[string]or list[NodeProto] params" # noqa: B011
  541. while start_node:
  542. node_name = start_node.popleft()
  543. if node_name in stop_node:
  544. continue
  545. if node_name not in result_chain:
  546. result_chain.append(node_name)
  547. else:
  548. continue
  549. node = find_by_name(node_name, list(self.model.graph.node))
  550. for parent in self.get_parents(node):
  551. start_node.append(parent.name)
  552. return result_chain
  def find_split_node_for_layer_wise_quantization(self):
      """Find split node for layer wise quantization.

      Scans the top-level nodes in graph order. A SkipLayerNormalization or Add node is
      kept when at least one of the parent-path patterns below matches it through
      ``self.match_parent_path``. That method is not visible here; it presumably
      returns the matched parent nodes, or None when a pattern does not match.

      Returns:
          list: the matching SkipLayerNormalization/Add nodes, in graph order.
      """
      # find split nodes of decoder blocks
      # embed -> decoder.0 -(split_node)-> ... -(split_node)-> decoder.n -(split_node)-> norm -> head
      # after split: embed -> decoder.0,
      #              decoder.1,
      #              decoder.2,
      #              ...,
      #              decoder.n,
      #              norm -> head
      start_nodes = []
      for node in self._model.graph.node:
          start_node, qkv_nodes_list = None, None
          if node.op_type == "SkipLayerNormalization":
              start_node = node
              qkv_nodes_list = [
                  self.match_parent_path(
                      start_node,
                      ["MatMul", "Reshape", "Transpose", "Reshape", "MatMul"],
                      [None, 0, 0, 0, 0],
                  ),
                  self.match_parent_path(
                      start_node,
                      ["Add", "MatMul", "Reshape", "Transpose", "MatMul"],
                      [1, 1, 0, 0, 0],
                  ),
              ]
          if node.op_type == "Add":
              start_node = node
              qkv_nodes_list = [
                  # match base attention structure
                  self.match_parent_path(
                      start_node,
                      ["Add", "MatMul", "Reshape", "Transpose", "MatMul"],
                      [0, None, 0, 0, 0],
                  ),
                  self.match_parent_path(
                      start_node, ["Add", "MatMul", "Reshape", "Transpose", "MatMul"], [1, None, 0, 0, 0]
                  ),
                  # match gpt attention no past structure
                  self.match_parent_path(
                      start_node,
                      ["Reshape", "Gemm", "Reshape", "Reshape", "Transpose", "MatMul"],
                      [None, 0, 0, 0, 0, 0],
                      output_name_to_node=self.output_name_to_node,
                      return_indice=[],
                  ),
                  # match bart attention structure
                  self.match_parent_path(
                      start_node,
                      ["Add", "MatMul", "Reshape", "Transpose", "Reshape", "MatMul"],
                      [0, None, 0, 0, 0, 0],
                  ),
                  self.match_parent_path(
                      start_node,
                      ["Add", "MatMul", "Reshape", "Transpose", "Reshape", "MatMul"],
                      [1, None, 0, 0, 0, 0],
                  ),
                  # NOTE(review): the two patterns below are not attention paths; they look
                  # like MLP / normalization-related structures — confirm intent.
                  self.match_parent_path(
                      start_node,
                      ["MatMul", "Mul", "MatMul", "Mul", "Div", "Add"],
                      [None, 0, None, 0, None, 0],
                  ),
                  self.match_parent_path(
                      start_node,
                      ["MatMul", "Mul", "MatMul", "SimplifiedLayerNormalization", "Add"],
                      [None, 0, None, 0, 0],
                  ),
              ]
          # No pattern set applies to this op type.
          if not start_node:
              continue
          # Keep the node only if at least one pattern produced a truthy match.
          if not any(qkv_nodes_list):
              continue
          start_nodes.append(start_node)
      return start_nodes
  def find_qkv_in_attention(self, find_all=False):
      """Find qkv MatMul in Attention.

      Every fused ``Attention`` node contributes ``[node.name]``. For unfused
      attention, each SkipLayerNormalization/Add node is checked against parent-path
      patterns via ``self.match_parent_path``. That method is not visible here; it
      presumably returns the matched parent nodes, or None when a pattern does not
      match.

      For a matching node, the search then looks at the node's other produced input
      (the one that is not the output of the last matched path's first node). When
      there is exactly one such input and it feeds exactly three MatMul nodes, the
      names of those MatMuls are recorded.

      Args:
          find_all (bool, optional): find all qkv MatMul. Defaults to False, in which case
              the scan stops after the first unfused match.

      Returns:
          qkv (list): qkv MatMul list. Each entry is a list of node names.
      """
      qkv = []
      for node in self._model.graph.node:
          if node.op_type == "Attention":
              qkv.append([node.name])
              continue
          start_node, qkv_nodes_list = None, None
          if node.op_type == "SkipLayerNormalization":
              start_node = node
              qkv_nodes_list = [
                  self.match_parent_path(
                      start_node,
                      ["MatMul", "Reshape", "Transpose", "Reshape", "MatMul"],
                      [None, 0, 0, 0, 0],
                  ),
                  self.match_parent_path(
                      start_node,
                      ["Add", "MatMul", "Reshape", "Transpose", "MatMul"],
                      [1, 1, 0, 0, 0],
                  ),
              ]
          if node.op_type == "Add":
              start_node = node
              qkv_nodes_list = [
                  # match base attention structure
                  self.match_parent_path(
                      start_node,
                      ["Add", "MatMul", "Reshape", "Transpose", "MatMul"],
                      [0, None, 0, 0, 0],
                  ),
                  self.match_parent_path(
                      start_node, ["Add", "MatMul", "Reshape", "Transpose", "MatMul"], [1, None, 0, 0, 0]
                  ),
                  # match gpt attention no past structure
                  self.match_parent_path(
                      start_node,
                      ["Reshape", "Gemm", "Reshape", "Reshape", "Transpose", "MatMul"],
                      [None, 0, 0, 0, 0, 0],
                      output_name_to_node=self.output_name_to_node,
                      return_indice=[],
                  ),
                  # match bart attention structure
                  self.match_parent_path(
                      start_node,
                      ["Add", "MatMul", "Reshape", "Transpose", "Reshape", "MatMul"],
                      [0, None, 0, 0, 0, 0],
                  ),
                  self.match_parent_path(
                      start_node,
                      ["Add", "MatMul", "Reshape", "Transpose", "Reshape", "MatMul"],
                      [1, None, 0, 0, 0, 0],
                  ),
              ]
          # No pattern set applies to this op type.
          if not start_node:
              continue
          if not any(qkv_nodes_list):
              continue
          # Use the last non-None match. The comprehension variable ``qkv`` is local to
          # the comprehension and does not overwrite the result list ``qkv``.
          qkv_nodes = [qkv for qkv in qkv_nodes_list if qkv is not None][-1]
          # Collect the start node's inputs that are produced by some node, excluding
          # the output of the first node of the matched path.
          other_inputs = []
          for input in start_node.input:
              if input not in self.output_name_to_node:
                  continue
              if input == qkv_nodes[0].output[0]:
                  continue
              other_inputs.append(input)
          if len(other_inputs) != 1:
              continue
          root_input = other_inputs[0]
          input_name_to_nodes = self.input_name_to_nodes
          children = input_name_to_nodes[root_input]
          children_types = [child.op_type for child in children]
          # Exactly three MatMul consumers of the shared input are taken as Q, K and V.
          if children_types.count("MatMul") == 3:
              qkv.append([child.name for child in children if child.op_type == "MatMul"])
              if not find_all:
                  break
      return qkv
  def find_ffn_matmul(self, attention_index, attention_matmul_list, block_len):
      """Find MatMul in FFN.

      For every attention entry except the last, the two items just before the next
      attention index are taken. For the last one, the items at offsets
      ``block_len - 2`` and ``block_len - 1`` from it are taken, if they exist.

      Args:
          attention_index (list): index of Attention
          attention_matmul_list (list): list of Attention and MatMul nodes
          block_len (int): block length

      Returns:
          list: list of MatMul in FFN, as pairs ``[first, second]``.
      """
      ffn_pairs = []
      last_position = len(attention_index) - 1
      for position, attn_idx in enumerate(attention_index):
          if position < last_position:
              next_attn = attention_index[position + 1]
              if next_attn >= 2:
                  ffn_pairs.append([attention_matmul_list[next_attn - 2], attention_matmul_list[next_attn - 1]])
          else:
              block_end = attn_idx + block_len
              if block_end - 1 < len(attention_matmul_list):
                  ffn_pairs.append([attention_matmul_list[block_end - 2], attention_matmul_list[block_end - 1]])
      return ffn_pairs
  def export(self, save_path, conf):
      """Export Qlinear to QDQ model.

      Converts this model in place with ``onnx_qlinear_to_qdq``, cleans it up
      (unused nodes removed, topologically sorted) and saves it to ``save_path``.

      Args:
          save_path (str): path the converted model is saved to.
          conf: export configuration; only ``ONNXQlinear2QDQConfig`` is supported.

      Any other ``conf`` type logs a warning and terminates the interpreter with
      exit status 0.
      """
      import sys  # noqa: PLC0415

      from neural_compressor.config import ONNXQlinear2QDQConfig  # noqa: PLC0415
      from neural_compressor.utils.export import onnx_qlinear_to_qdq  # noqa: PLC0415

      if not isinstance(conf, ONNXQlinear2QDQConfig):
          logger.warning("Unsupported config for export, only ONNXQlinear2QDQConfig is supported!")
          # The builtin ``exit`` is injected by the ``site`` module for interactive use and
          # is missing under ``python -S`` or embedded interpreters; ``sys.exit`` is the
          # supported API.
          # NOTE(review): status 0 reports success for an error path; consider a non-zero
          # status or raising instead once callers are checked.
          sys.exit(0)

      add_nodes, remove_nodes, inits = onnx_qlinear_to_qdq(self._model, self._input_name_to_nodes)
      self.add_nodes(add_nodes)
      self.remove_nodes(remove_nodes)
      self.add_initializers(inits)
      self.update()
      self.remove_unused_nodes()
      self.topological_sort()
      self.save(save_path)
  def add_tensors_to_outputs(self, tensor_names):
      """Add the tensors to the model outputs to get their values.

      A name that is already a graph output is skipped. A name repeated inside
      ``tensor_names`` is added only once, so the graph never ends up with
      duplicate output names.

      Args:
          tensor_names: The names of tensors to be dumped.
      """
      # Snapshot the current output names once; new names are recorded as they are added
      # so repeats within ``tensor_names`` are also filtered.
      seen = set(self.output())
      added_outputs = []
      for tensor in tensor_names:
          if tensor in seen:
              continue
          seen.add(tensor)
          added_tensor = onnx.helper.ValueInfoProto()
          added_tensor.name = tensor
          added_outputs.append(added_tensor)
      self._model.graph.output.extend(added_outputs)  # pylint: disable=no-member
  def remove_tensors_from_outputs(self, tensor_names):
      """Remove the tensors from the model outputs.

      Names that are not graph outputs are ignored, and a name listed several
      times in ``tensor_names`` is removed only once.

      Args:
          tensor_names: The names of tensors to be removed.
      """
      # De-duplicate the request so the same output entry is never removed twice
      # (a second ``remove`` of an already-removed entry raises ValueError).
      wanted = set(tensor_names)
      output_names = self.output()
      removed_outputs = [
          self._model.graph.output[position] for position, name in enumerate(output_names) if name in wanted
      ]
      for output in removed_outputs:
          self._model.graph.output.remove(output)
  def match_first_parent(self, node, parent_op_type, output_name_to_node, exclude=None):
      """Return the first producer of ``node``'s inputs whose op_type matches.

      Inputs are scanned in order. Inputs with no producing node in
      ``output_name_to_node`` are skipped.

      Args:
          node: current node; its ``input`` list holds tensor names.
          parent_op_type (str): constraint of parent node op_type.
          output_name_to_node (dict): dictionary with output name as key, and node as value.
          exclude (list): list of nodes that are excluded (not allowed to match as parent).

      Returns:
          parent: The matched parent node. None if not found.
          index: The input index of matched parent node. None if not found.
      """
      banned = [] if exclude is None else exclude
      for slot, tensor_name in enumerate(node.input):
          if tensor_name not in output_name_to_node:
              continue
          producer = output_name_to_node[tensor_name]
          if producer.op_type != parent_op_type:
              continue
          if producer in banned:
              continue
          return producer, slot
      return None, None
  def match_parent(
      self,
      node,
      parent_op_type,
      input_index=None,
      output_name_to_node=None,
      exclude=None,
      return_indice=None,
  ):
      """Find parent node based on constraints on op_type and index.

      Args:
          node: current node; its ``input`` list holds tensor names.
          parent_op_type (str): constraint of parent node op_type.
          input_index (int or None): only check the parent given input index of current node.
              None means the first matching input (via ``match_first_parent``) is used.
          output_name_to_node (dict): dictionary with output name as key, and node as value.
              Defaults to ``self._output_name_to_node``.
          exclude (list): list of nodes that are excluded (not allowed to match as parent).
          return_indice (list): when ``input_index`` is None, receives the input index
              reported by ``match_first_parent`` (None if nothing matched).

      Returns:
          parent: The matched parent node, or None.
      """
      assert node is not None
      assert input_index is None or input_index >= 0
      banned = exclude if exclude is not None else []
      lookup = output_name_to_node if output_name_to_node is not None else self._output_name_to_node

      if input_index is None:
          # No fixed slot: take the first input whose producer matches.
          parent, slot = self.match_first_parent(node, parent_op_type, lookup, banned)
          if return_indice is not None:
              return_indice.append(slot)
          return parent

      if input_index >= len(node.input):
          return None

      candidate = self.get_parent(node, input_index, lookup)
      matched = candidate is not None and candidate.op_type == parent_op_type and candidate not in banned
      return candidate if matched else None
  def match_parent_path(
      self,
      node,
      parent_op_types,
      parent_input_index,
      output_name_to_node=None,
      return_indice=None,
  ):
      """Find a sequence of input edges based on constraints on parent op_type and index.

      Walks upward from ``node`` one parent at a time. Each step must match the
      corresponding entry of ``parent_op_types`` and ``parent_input_index``.

      Args:
          node: current node; its ``input`` list holds tensor names.
          parent_op_types (list): constraint of parent node op_type of each input edge.
          parent_input_index (list): constraint of input index of each input edge.
              None means no constraint.
          output_name_to_node (dict): dictionary with output name as key, and node as value.
              Defaults to ``self._output_name_to_node``.
          return_indice (list): a list to append the input index when there is
              no constraint on input index of an edge.

      Returns:
          parents: a list of matched parent nodes (nearest first), or None as soon
          as one step fails to match.
      """
      assert len(parent_input_index) == len(parent_op_types)
      lookup = self._output_name_to_node if output_name_to_node is None else output_name_to_node
      chain = []
      child = node
      for hop, op_type in enumerate(parent_op_types):
          parent = self.match_parent(
              child,
              op_type,
              parent_input_index[hop],
              lookup,
              exclude=[],
              return_indice=return_indice,
          )
          if parent is None:
              return None
          chain.append(parent)
          child = parent
      return chain
  def is_smoothquant_model(self):
      """Report whether the model carries SmoothQuant scale initializers.

      Returns:
          bool: True if any initializer name contains ``"_smooth_scale"``.
      """
      initializer_names = (tensor.name for tensor in self.model.graph.initializer)
      return any("_smooth_scale" in name for name in initializer_names)
  def find_split_nodes(self):
      """Find split nodes for layer-wise quantization.

      Thin alias of ``find_split_node_for_layer_wise_quantization``.

      Returns:
          Whatever ``find_split_node_for_layer_wise_quantization`` returns.
      """
      return self.find_split_node_for_layer_wise_quantization()
  def split_model_with_node(
      self, split_node_name, path_of_model_to_split, shape_infer=True, save_both_split_models=True
  ):
      """Split model into two parts at a given node.

      Nodes up to and including the node named ``split_node_name`` (in graph
      order) go to part 1, and the remaining nodes go to part 2. The split node's
      single output tensor becomes a graph output of part 1 and a graph input of
      part 2. Every other tensor produced in part 1 and consumed in part 2 is wired
      the same way. Initializers not used by a part are dropped from it.

      Args:
          split_node_name (str): name of the node where the model is split at.
          path_of_model_to_split (str): path of model to be split. The split models
              are written to the same directory.
          shape_infer (bool): do shape inference. Default is True. When enabled,
              ``self._model`` is replaced by the shape-inferred model.
          save_both_split_models (bool): whether to save the two split models.
              False means only save the first split model.
              True means save both the two split models.
              Default is True.

      Returns:
          tuple: the first split model, the second split model

      Raises:
          AssertionError: if no node is named ``split_node_name`` or the split node
              does not have exactly one output tensor.
      """
      # origin model : ... -> node_1 -> split_node -> node_2 -> ...
      # split model 1: ... -> node_1 -> split_node
      # split model 2: node_2 -> ...
      split_model_part_1 = onnx.ModelProto()
      split_model_part_1.CopyFrom(self._model)
      split_model_part_1.graph.ClearField("node")
      split_model_part_2 = onnx.ModelProto()
      split_model_part_2.CopyFrom(self._model)
      split_model_part_2.graph.ClearField("node")

      split_node_output = None
      part_idx = 1
      for node in self._model.graph.node:
          if part_idx == 1:
              split_model_part_1.graph.node.append(node)
          elif part_idx == 2:
              split_model_part_2.graph.node.append(node)
          if node.name == split_node_name:
              split_node_output = node.output
              part_idx = 2

      # Without this check a missing split node surfaces as an opaque
      # "object of type 'NoneType' has no len()" TypeError on the next assert.
      assert split_node_output is not None, f"Split node {split_node_name} is not found in the model"
      assert len(split_node_output) == 1, (
          f"Only support split at node with 1 output tensor, while current split node {split_node_name} has {len(split_node_output)} output tensors"
      )
      split_tensor_name = split_node_output[0]

      # infer shape of the model to be split
      if shape_infer:
          try:
              from neural_compressor.adaptor.ox_utils.util import infer_shapes  # noqa: PLC0415

              self._model = infer_shapes(self._model, auto_merge=True, base_dir=os.path.dirname(self._model_path))
          except Exception:  # pragma: no cover
              logger.error(
                  "Shape infer fails for layer-wise quantization. "
                  "We would recommend checking the graph optimization level of your model "
                  "and setting it to 'DISABLE_ALL' or 'ENABLE_BASIC', "
                  "as this may help avoid this error."
              )
              raise

      split_tensor_type, split_tensor_shape = self._get_output_type_shape_by_tensor_name(split_tensor_name)
      split_tensor = onnx.helper.make_tensor_value_info(split_tensor_name, split_tensor_type, split_tensor_shape)

      split_model_part_1 = ONNXModel(split_model_part_1, ignore_warning=True)
      split_model_part_2 = ONNXModel(split_model_part_2, ignore_warning=True)

      # remove unused input & output
      split_model_part_1._remove_unused_input_output()
      split_model_part_2._remove_unused_input_output()

      split_model_part_1.model.graph.output.append(split_tensor)
      split_model_part_2.model.graph.input.append(split_tensor)

      # Any other tensor produced in part 1 and consumed in part 2 must cross the boundary too.
      insert_output_for_model_1 = []
      insert_input_for_model_2 = []
      for output in split_model_part_1.output_name_to_node:
          if output in split_model_part_2.input_name_to_nodes:
              output_type, output_shape = self._get_output_type_shape_by_tensor_name(output)
              output_tensor = onnx.helper.make_tensor_value_info(output, output_type, output_shape)
              if output_tensor not in split_model_part_1.model.graph.output:
                  insert_output_for_model_1.append(output_tensor)
              if output_tensor not in split_model_part_2.model.graph.input:
                  insert_input_for_model_2.append(output_tensor)

      # insert model 1 output
      for output in insert_output_for_model_1:
          split_model_part_1.model.graph.output.append(output)

      # insert model 2 input
      for input in insert_input_for_model_2:
          split_model_part_2.model.graph.input.append(input)

      # remove unused init
      split_model_part_1.remove_unused_init()
      split_model_part_2.remove_unused_init()

      split_model_part_1.update()
      split_model_part_2.update()

      dir_of_model_to_split = os.path.dirname(path_of_model_to_split)

      split_model_part_1.load_model_initializer_by_tensor(dir_of_model_to_split)
      split_model_part_1_path = os.path.join(dir_of_model_to_split, "split_model_part_1.onnx")
      split_model_part_1.model_path = split_model_part_1_path
      split_model_part_1._save_split_model(split_model_part_1_path)
      split_model_part_1.check_is_large_model()
      logger.debug(f"save split model part 1 to {split_model_part_1_path} for layer wise quantization")

      if save_both_split_models:
          split_model_part_2.load_model_initializer_by_tensor(dir_of_model_to_split)
          split_model_part_2_path = os.path.join(dir_of_model_to_split, "split_model_part_2.onnx")
          split_model_part_2.model_path = split_model_part_2_path
          split_model_part_2._save_split_model(split_model_part_2_path)
          split_model_part_2.check_is_large_model()
          logger.debug(f"save split model part 2 to {split_model_part_2_path} for layer wise quantization")
      return split_model_part_1, split_model_part_2
  def _save_split_model(self, save_path):
      """Save split model as external data for layer wise quantization.

      All external tensors go into one file named ``<model file name>_data``
      next to the model. Tensors below the 1024-byte ``size_threshold`` stay inline.
      A data file left over at that location from an earlier save is deleted first.

      Args:
          save_path (str): the path to save the split model
      """
      data_file = save_path + "_data"
      if os.path.exists(data_file):
          os.remove(data_file)
      onnx.save_model(
          self._model,
          save_path,
          save_as_external_data=True,
          all_tensors_to_one_file=True,
          # ``location`` is resolved relative to the model's directory, so it must be a bare
          # file name; os.path.basename also handles Windows separators, unlike split("/").
          location=os.path.basename(save_path) + "_data",
          size_threshold=1024,
          convert_attribute=False,
      )
  def _get_output_type_shape_by_tensor_name(self, tensor_name):
      """Get output type and shape with a tensor name.

      ``graph.value_info`` is searched first, then ``graph.output`` and
      ``graph.input``, so tensors whose type is recorded only on a graph output or
      input are resolved too. Dimensions without a concrete ``dim_value`` are
      reported as -1.

      Args:
          tensor_name (str): name of a tensor

      Returns:
          tuple: ``(elem_type, shape)``. ``shape`` is a list of ints. It is None when
          the tensor is found without shape information (unknown rank) or is not
          found at all. In the not-found case ``elem_type`` defaults to
          ``onnx.TensorProto.FLOAT``.
      """
      graph = self._model.graph
      for value_infos in (graph.value_info, graph.output, graph.input):
          for value_info in value_infos:
              if value_info.name != tensor_name:
                  continue
              tensor_type = value_info.type.tensor_type
              if not tensor_type.HasField("shape"):
                  # An empty dim list passed to make_tensor_value_info would declare a
                  # scalar; None keeps the rank unknown.
                  return tensor_type.elem_type, None
              shape = [dim.dim_value if dim.HasField("dim_value") else -1 for dim in tensor_type.shape.dim]
              return tensor_type.elem_type, shape
      return onnx.TensorProto.FLOAT, None
  def _remove_unused_input_output(self):
      """Remove unused input & output for split model.

      Graph outputs that no node produces and graph inputs that no node consumes
      are dropped. Both sets are collected first, then removed.
      """
      graph = self._model.graph
      produced = self.output_name_to_node
      dangling_outputs = [value for value in graph.output if value.name not in produced]
      consumed = self.input_name_to_nodes
      dangling_inputs = [value for value in graph.input if value.name not in consumed]
      for value in dangling_outputs:
          graph.output.remove(value)
      for value in dangling_inputs:
          graph.input.remove(value)
  def remove_unused_init(self):
      """Remove initializers that no node consumes as an input."""
      consumers = self.input_name_to_nodes
      unused_inits = [init for init in self._model.graph.initializer if init.name not in consumers]
      self.remove_initializers(unused_inits)
  def load_model_initializer_by_tensor(self, data_path=None):
      """Load model initializer by tensor.

      Only initializers whose ``data_location`` is set to ``EXTERNAL`` are loaded
      into memory. All others are left untouched.

      Args:
          data_path (str, optional): the directory of saved initializer.
              Defaults to the directory of ``self._model_path``.
      """
      base_dir = os.path.dirname(self._model_path) if data_path is None else data_path
      for tensor in self._model.graph.initializer:
          stored_externally = tensor.HasField("data_location") and tensor.data_location == onnx.TensorProto.EXTERNAL
          if stored_externally:
              onnx.external_data_helper.load_external_data_for_tensor(tensor, base_dir)
  def write_external_data_to_new_location(self, external_data_location="external.data", overwrite=False):
      """Write external data of merged quantized model to new location to save memory.

      Initializers are first loaded into memory. They are then converted to
      external data pointing at ``external_data_location`` and written into the
      directory of ``self._model_path``.

      Args:
          external_data_location (str, optional): external data location of merged quantized model.
              Defaults to "external.data".
          overwrite (bool, optional): if True, delete an existing external data file at
              that location first. Defaults to False.
      """
      model_dir = os.path.dirname(self._model_path)
      target_file = os.path.join(model_dir, external_data_location)
      if overwrite and os.path.exists(target_file):
          os.remove(target_file)
      self.load_model_initializer_by_tensor()
      onnx.external_data_helper.convert_model_to_external_data(self._model, location=external_data_location)
      # TODO : if init is already saved, skip write it
      onnx.external_data_helper.write_external_data_tensors(self._model, filepath=model_dir)
  def merge_split_models(self, to_merge_model):
      """Merge two split model into final model.

      ``to_merge_model``'s external data is rewritten first. Its nodes and
      initializers are then appended to this model, and graph I/O is reconciled:

      * outputs of ``to_merge_model`` not already listed in ``self.output()`` are appended;
      * graph outputs whose name is among ``to_merge_model``'s inputs are dropped;
      * inputs of ``to_merge_model`` are appended unless they are already a graph
        input, a graph output, or produced by a node of the merged graph.

      Args:
          to_merge_model: split model whose nodes and initializers are appended to this one.
      """
      to_merge_model.write_external_data_to_new_location()
      self.add_nodes(list(to_merge_model.nodes()))
      self.add_initializers(list(to_merge_model.initializer()))
      self.update()

      graph_outputs = self._model.graph.output
      graph_inputs = self._model.graph.input

      # add new output; self.output() is re-read each time so a repeated name is appended once
      for candidate in to_merge_model.graph().output:
          if candidate.name in self.output():
              continue
          graph_outputs.append(candidate)

      # remove unused output: tensors consumed by the merged part are now internal
      consumed_names = to_merge_model.input()
      internal_outputs = [out for out in graph_outputs if out.name in consumed_names]
      for out in internal_outputs:
          graph_outputs.remove(out)

      # add new input that is still a genuine graph input after merging
      for candidate in to_merge_model.graph().input:
          already_known = (
              candidate.name in self.input()
              or candidate.name in self.output()
              or candidate.name in self.output_name_to_node
          )
          if not already_known:
              graph_inputs.append(candidate)
  def re_org_output(self, origin_output):
      """Re-org output of merged model for layer-wise quantization.

      All current graph outputs are removed and re-added in the order given by
      ``origin_output``. Outputs whose name is not listed are dropped.

      Args:
          origin_output (list): output names in the desired order. Each must be a
              current graph output; otherwise ``KeyError`` is raised after the
              existing outputs have already been cleared.
      """
      graph_outputs = self._model.graph.output
      current = list(graph_outputs)
      by_name = {value.name: value for value in current}
      for value in current:
          graph_outputs.remove(value)
      for name in origin_output:
          graph_outputs.append(by_name[name])