# fusion_simplified_layernorm.py (7.5 KB)

  1. import logging
  2. from fusion_base import Fusion
  3. from fusion_skiplayernorm import FusionSkipLayerNormalization
  4. from onnx import helper
  5. from onnx_model import OnnxModel
  6. logger = logging.getLogger(__name__)
  7. class FusionSimplifiedLayerNormalization(Fusion):
  8. def __init__(self, model: OnnxModel):
  9. super().__init__(model, "SimplifiedLayerNormalization", "Mul")
  10. def fuse(self, node, input_name_to_nodes: dict, output_name_to_node: dict):
  11. if node.op_type != "Mul":
  12. return
  13. sim_ln_nodes = None
  14. # RMSNorm formula:
  15. # S = Pow(X, 2) or S = Mul(X, X)
  16. # MS = ReduceMean(S)
  17. # MSEps = Add(MS, epsilon)
  18. # RMS = Sqrt(MSEps)
  19. # InvRMS = Div(1, RMS) or InvRMS = Reciprocal(RMS)
  20. # Normalized = Mul(D, InvRMS)
  21. # Y = Mul(Normalized, Scale)
  22. #
  23. # (root_input) ----------------------------------------+
  24. # | |
  25. # v v
  26. # Pow --> ReduceMean --> Add ---> Sqrt --> Div --> Mul --> Mul (node)
  27. # (B=2) (A/B=eps) (A=1) (A/B=scale)
  28. #
  29. # (root_input) ----------------------------------------+
  30. # | | |
  31. # v v v
  32. # Mul --> ReduceMean --> Add ---> Sqrt --> Div --> Mul --> Mul (node)
  33. # (B=2) (A/B=eps) (A=1) (A/B=scale)
  34. #
  35. return_indice = []
  36. sim_ln_nodes = self.model.match_parent_path(
  37. node,
  38. ["Mul", "Div", "Sqrt", "Add", "ReduceMean"],
  39. [None, 1, 1, 0, None],
  40. output_name_to_node=output_name_to_node,
  41. return_indice=return_indice,
  42. )
  43. if sim_ln_nodes:
  44. mul_node, div_node, _sqrt_node, add_node, reduce_mean_node = sim_ln_nodes
  45. if not self.model.has_constant_input(div_node, 1.0):
  46. return
  47. node_parent = mul_node
  48. else:
  49. # Div(1, RMS) can also be represented as Reciprocal(RMS) like
  50. #
  51. # (root_input) -----------------------------------------------+
  52. # | |
  53. # v v
  54. # Pow --> ReduceMean --> Add ---> Sqrt --> Reciprocal --> Mul --> Mul (node)
  55. # (B=2) (A/B=eps) (A/B=scale)
  56. #
  57. # (root_input) -----------------------------------------------+
  58. # | | |
  59. # v v v
  60. # Mul --> ReduceMean --> Add ---> Sqrt --> Reciprocal --> Mul --> Mul (node)
  61. # (B=2) (A/B=eps) (A/B=scale)
  62. #
  63. return_indice = []
  64. sim_ln_nodes = self.model.match_parent_path(
  65. node,
  66. ["Mul", "Reciprocal", "Sqrt", "Add", "ReduceMean"],
  67. [None, 1, 0, 0, None],
  68. output_name_to_node=output_name_to_node,
  69. return_indice=return_indice,
  70. )
  71. if sim_ln_nodes is not None:
  72. mul_node, _reciprocal_node, _sqrt_node, add_node, reduce_mean_node = sim_ln_nodes
  73. node_parent = mul_node
  74. else:
  75. # (root_input) --------------------------------+
  76. # | |
  77. # v v
  78. # Pow --> ReduceMean --> Add ---> Sqrt --> Div --> Mul (node)
  79. # (B=2) (A/B=eps) (A/B=scale)
  80. #
  81. # (root_input) --------------------------------+
  82. # | | |
  83. # v v v
  84. # Mul --> ReduceMean --> Add ---> Sqrt --> Div --> Mul (node)
  85. # (B=2) (A/B=eps) (A/B=scale)
  86. #
  87. return_indice = []
  88. sim_ln_nodes = self.model.match_parent_path(
  89. node,
  90. ["Div", "Sqrt", "Add", "ReduceMean"],
  91. [None, 1, 0, None],
  92. output_name_to_node=output_name_to_node,
  93. return_indice=return_indice,
  94. )
  95. if sim_ln_nodes is not None:
  96. div_node, _sqrt_node, add_node, reduce_mean_node = sim_ln_nodes
  97. node_parent = div_node
  98. else:
  99. return
  100. reduce_mean_parent = self.model.get_parent(reduce_mean_node, 0, output_name_to_node)
  101. if reduce_mean_parent is None or reduce_mean_parent.op_type not in ["Pow", "Mul"]:
  102. return
  103. if reduce_mean_parent.op_type == "Pow":
  104. if self.model.find_constant_input(reduce_mean_parent, 2.0) != 1:
  105. return
  106. else:
  107. assert reduce_mean_parent.op_type == "Mul"
  108. if reduce_mean_parent[0] != reduce_mean_parent[1]:
  109. return
  110. root_input = reduce_mean_parent.input[0]
  111. if root_input not in node_parent.input:
  112. return
  113. _i, epsilon = self.model.get_constant_input(add_node)
  114. if epsilon is None or epsilon <= 0 or epsilon > 1.0e-4:
  115. logger.warning(f"epsilon value is not expected: {epsilon}")
  116. return
  117. # ReduceMean must have keepdims == 1
  118. keepdims = self.model.get_node_attribute(reduce_mean_node, "keepdims")
  119. if not keepdims:
  120. return
  121. # ReduceMean axes must refer only to the last dimension.
  122. # Axes became an input in opset 18. Before then, axes was an attribute.
  123. axes = self.model.get_node_attribute(reduce_mean_node, "axes")
  124. if (not axes) and len(reduce_mean_node.input) > 1:
  125. axes = self.model.get_constant_value(reduce_mean_node.input[1])
  126. # Make sure only one axis as required by SimplifiedLayerNormalization spec.
  127. if not axes or len(axes) != 1:
  128. return
  129. self.nodes_to_remove.extend(sim_ln_nodes)
  130. self.nodes_to_remove.append(reduce_mean_parent)
  131. self.nodes_to_remove.append(node)
  132. normalize_node = helper.make_node(
  133. "SimplifiedLayerNormalization",
  134. inputs=[root_input, node.input[1 - return_indice[0]]],
  135. outputs=[node.output[0]],
  136. name=self.model.create_node_name("SimplifiedLayerNormalization", name_prefix="RMSNorm"),
  137. )
  138. normalize_node.attribute.extend([helper.make_attribute("epsilon", float(epsilon))])
  139. normalize_node.attribute.extend([helper.make_attribute("axis", axes[0])])
  140. normalize_node.attribute.extend([helper.make_attribute("stash_type", 1)])
  141. self.nodes_to_add.append(normalize_node)
  142. self.node_name_to_graph_name[normalize_node.name] = self.this_graph_name
  143. class FusionSkipSimplifiedLayerNormalization(FusionSkipLayerNormalization):
  144. def __init__(self, model: OnnxModel):
  145. super().__init__(model, "SkipSimplifiedLayerNormalization", "SimplifiedLayerNormalization")
  146. def fuse(self, node, input_name_to_nodes, output_name_to_node):
  147. super().fuse(node, input_name_to_nodes, output_name_to_node)