# fusion_layernorm.py
# -------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License. See License.txt in the project root for
# license information.
# --------------------------------------------------------------------------
  6. from __future__ import annotations
  7. import onnx
  8. from ..onnx_model import ONNXModel
  9. from .fusion import Fusion
  10. class FusionLayerNormalization(Fusion):
  11. def __init__(self, model: ONNXModel):
  12. super().__init__(model, "LayerNormalization", "ReduceMean")
  13. def fuse(
  14. self,
  15. reduce_mean_node: onnx.NodeProto,
  16. input_name_to_nodes: dict[str, list[onnx.NodeProto]],
  17. output_name_to_node: dict[str, onnx.NodeProto],
  18. ):
  19. """
  20. Interface function that tries to fuse a node sequence containing a ReduceMean node into a single
  21. LayerNormalization node.
  22. +----------------------+
  23. | |
  24. | v
  25. [Root] --> ReduceMean --> Sub --> Pow --> ReduceMean --> Add --> Sqrt --> Div --> Mul --> Add
  26. (axis=2 or -1) | (Y=2) (axis=2 or -1) (E-6 or E-12 or 0) ^
  27. | |
  28. +-------------------------------------------------+
  29. It also handles cases of duplicated sub nodes exported from older version of PyTorch:
  30. +----------------------+
  31. | v
  32. | +-------> Sub-----------------------------------------------+
  33. | | |
  34. | | v
  35. [Root] --> ReduceMean --> Sub --> Pow --> ReduceMean --> Add --> Sqrt --> Div --> Mul --> Add
  36. | ^
  37. | |
  38. +----------------------+
  39. """
  40. children = self.model.get_children(reduce_mean_node, input_name_to_nodes)
  41. if len(children) == 0 or len(children) > 2:
  42. return
  43. root_input = reduce_mean_node.input[0]
  44. if children[0].op_type != "Sub" or children[0].input[0] != root_input:
  45. return
  46. if len(children) == 2:
  47. if children[1].op_type != "Sub" or children[1].input[0] != root_input:
  48. return
  49. div_node = None
  50. for child in children:
  51. div_node = self.find_first_child_by_type(child, "Div", input_name_to_nodes, recursive=False)
  52. if div_node is not None:
  53. break
  54. if div_node is None:
  55. return
  56. path_id, parent_nodes, _ = self.match_parent_paths(
  57. div_node,
  58. [
  59. (["Sqrt", "Add", "ReduceMean", "Pow", "Sub"], [1, 0, 0, 0, 0]),
  60. (
  61. ["Sqrt", "Add", "ReduceMean", "Pow", "Cast", "Sub"],
  62. [1, 0, 0, 0, 0, 0],
  63. ),
  64. ],
  65. output_name_to_node,
  66. )
  67. if path_id < 0:
  68. return
  69. sub_node = parent_nodes[-1]
  70. if sub_node not in children:
  71. return
  72. second_add_node = parent_nodes[1]
  73. i, add_weight = self.get_constant_input(second_add_node)
  74. if add_weight is None or add_weight <= 0 or add_weight > 1.0e-4:
  75. # Skip fusion since epsilon value is not expected.
  76. return
  77. pow_node = parent_nodes[3]
  78. if self.find_constant_input(pow_node, 2.0) != 1:
  79. return
  80. mul_node = input_name_to_nodes[div_node.output[0]][0]
  81. if mul_node.op_type != "Mul":
  82. return
  83. last_add_node = input_name_to_nodes[mul_node.output[0]][0]
  84. if last_add_node.op_type != "Add":
  85. return
  86. subgraph_nodes = [reduce_mean_node]
  87. subgraph_nodes.extend(children)
  88. subgraph_nodes.extend(parent_nodes[:-1])
  89. subgraph_nodes.extend([last_add_node, mul_node, div_node])
  90. if not self.is_safe_to_fuse_nodes(
  91. subgraph_nodes,
  92. last_add_node.output,
  93. input_name_to_nodes,
  94. output_name_to_node,
  95. ):
  96. return
  97. weight_input = mul_node.input[1 - self.input_index(div_node.output[0], mul_node)]
  98. if not self.is_constant_with_specified_rank(weight_input, 1):
  99. return
  100. bias_input = last_add_node.input[1 - self.input_index(mul_node.output[0], last_add_node)]
  101. if not self.is_constant_with_specified_rank(bias_input, 1):
  102. return
  103. self.nodes_to_remove.extend(subgraph_nodes)
  104. normalize_node = onnx.helper.make_node(
  105. "LayerNormalization",
  106. name=self.create_unique_node_name(),
  107. inputs=[reduce_mean_node.input[0], weight_input, bias_input],
  108. outputs=[last_add_node.output[0]],
  109. )
  110. normalize_node.attribute.extend([onnx.helper.make_attribute("epsilon", float(add_weight))])
  111. self.nodes_to_add.append(normalize_node)