| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165 |
- import logging
- from fusion_base import Fusion
- from fusion_skiplayernorm import FusionSkipLayerNormalization
- from onnx import helper
- from onnx_model import OnnxModel
- logger = logging.getLogger(__name__)
class FusionSimplifiedLayerNormalization(Fusion):
    """Fuse an RMSNorm subgraph that ends in a ``Mul`` into one ``SimplifiedLayerNormalization`` node.

    RMSNorm formula recognised by this fusion:
        S          = Pow(X, 2) or S = Mul(X, X)
        MS         = ReduceMean(S)
        MSEps      = Add(MS, epsilon)
        RMS        = Sqrt(MSEps)
        InvRMS     = Div(1, RMS) or InvRMS = Reciprocal(RMS)
        Normalized = Mul(X, InvRMS)      (or Div(X, RMS) directly)
        Y          = Mul(Normalized, Scale)
    """

    def __init__(self, model: OnnxModel):
        # The fused op is "SimplifiedLayerNormalization"; candidate subgraphs are anchored on "Mul" nodes.
        super().__init__(model, "SimplifiedLayerNormalization", "Mul")

    def _match_rms_norm_path(self, node, output_name_to_node: dict):
        """Match one of the three supported parent paths of ``node`` down to ``ReduceMean``.

        Args:
            node: the final ``Mul`` node (the multiplication by the scale).
            output_name_to_node: map from output tensor name to producing node, forwarded to
                ``match_parent_path``.

        Returns:
            ``None`` when nothing matches, or when the ``Div`` of the first pattern has no constant 1.0
            input. Otherwise a tuple ``(path_nodes, node_parent, add_node, reduce_mean_node,
            parent_index)`` where ``node_parent`` is the node that must consume the root input and
            ``parent_index`` is the first entry reported through ``return_indice`` (used by the caller
            to pick the other input of ``node`` as the scale).
        """
        # Pattern 1: the reciprocal is expressed as Div(1, RMS).
        #
        #   (root_input) ----------------------------------------+
        #       |                                                |
        #       v                                                v
        #      Pow --> ReduceMean --> Add ---> Sqrt --> Div --> Mul --> Mul (node)
        #      (B=2)                  (A/B=eps)         (A=1)           (A/B=scale)
        #
        #   (root_input) ----------------------------------------+
        #       |    |                                           |
        #       v    v                                           v
        #       Mul --> ReduceMean --> Add ---> Sqrt --> Div --> Mul --> Mul (node)
        #      (A==B)                  (A/B=eps)         (A=1)           (A/B=scale)
        #
        return_indice = []
        path = self.model.match_parent_path(
            node,
            ["Mul", "Div", "Sqrt", "Add", "ReduceMean"],
            [None, 1, 1, 0, None],
            output_name_to_node=output_name_to_node,
            return_indice=return_indice,
        )
        if path:
            mul_node, div_node, _sqrt_node, add_node, reduce_mean_node = path
            # Div must compute 1 / RMS; any other numerator is not an RMSNorm.
            if not self.model.has_constant_input(div_node, 1.0):
                return None
            return path, mul_node, add_node, reduce_mean_node, return_indice[0]

        # Pattern 2: Div(1, RMS) can also be represented as Reciprocal(RMS).
        #
        #   (root_input) -----------------------------------------------+
        #       |                                                       |
        #       v                                                       v
        #      Pow --> ReduceMean --> Add ---> Sqrt --> Reciprocal --> Mul --> Mul (node)
        #      (B=2)                  (A/B=eps)                                (A/B=scale)
        #
        #   (root_input) -----------------------------------------------+
        #       |    |                                                  |
        #       v    v                                                  v
        #       Mul --> ReduceMean --> Add ---> Sqrt --> Reciprocal --> Mul --> Mul (node)
        #      (A==B)                  (A/B=eps)                                (A/B=scale)
        #
        return_indice = []
        path = self.model.match_parent_path(
            node,
            ["Mul", "Reciprocal", "Sqrt", "Add", "ReduceMean"],
            [None, 1, 0, 0, None],
            output_name_to_node=output_name_to_node,
            return_indice=return_indice,
        )
        if path is not None:
            mul_node, _reciprocal_node, _sqrt_node, add_node, reduce_mean_node = path
            return path, mul_node, add_node, reduce_mean_node, return_indice[0]

        # Pattern 3: the root input is divided by RMS directly, without an intermediate Mul.
        #
        #   (root_input) --------------------------------+
        #       |                                        |
        #       v                                        v
        #      Pow --> ReduceMean --> Add ---> Sqrt --> Div --> Mul (node)
        #      (B=2)                  (A/B=eps)                 (A/B=scale)
        #
        #   (root_input) --------------------------------+
        #       |    |                                   |
        #       v    v                                   v
        #       Mul --> ReduceMean --> Add ---> Sqrt --> Div --> Mul (node)
        #      (A==B)                  (A/B=eps)                 (A/B=scale)
        #
        return_indice = []
        path = self.model.match_parent_path(
            node,
            ["Div", "Sqrt", "Add", "ReduceMean"],
            [None, 1, 0, None],
            output_name_to_node=output_name_to_node,
            return_indice=return_indice,
        )
        if path is not None:
            div_node, _sqrt_node, add_node, reduce_mean_node = path
            return path, div_node, add_node, reduce_mean_node, return_indice[0]

        return None

    def fuse(self, node, input_name_to_nodes: dict, output_name_to_node: dict):
        """Replace the RMSNorm subgraph ending at ``node`` with a ``SimplifiedLayerNormalization`` node.

        On success the matched nodes are appended to ``self.nodes_to_remove`` and the new node to
        ``self.nodes_to_add``. The method returns ``None`` in every case; when any structural or
        value check fails, it returns without touching those lists.

        Args:
            node: candidate node; anything other than a ``Mul`` is ignored.
            input_name_to_nodes: not used by this fusion.
            output_name_to_node: map from output tensor name to producing node, used for parent lookups.
        """
        if node.op_type != "Mul":
            return

        matched = self._match_rms_norm_path(node, output_name_to_node)
        if matched is None:
            return
        sim_ln_nodes, node_parent, add_node, reduce_mean_node, parent_index = matched

        # The input of ReduceMean must be the square of the root input: Pow(X, 2) or Mul(X, X).
        reduce_mean_parent = self.model.get_parent(reduce_mean_node, 0, output_name_to_node)
        if reduce_mean_parent is None or reduce_mean_parent.op_type not in ["Pow", "Mul"]:
            return

        if reduce_mean_parent.op_type == "Pow":
            # The exponent 2.0 has to be the second input (index 1) of Pow.
            if self.model.find_constant_input(reduce_mean_parent, 2.0) != 1:
                return
        elif reduce_mean_parent.input[0] != reduce_mean_parent.input[1]:
            # op_type is "Mul" here (guaranteed by the guard above). It only squares X when both
            # operands are the same tensor, so compare the input names; the node object itself
            # cannot be indexed.
            return

        # The tensor being squared must also be the tensor being normalized.
        root_input = reduce_mean_parent.input[0]
        if root_input not in node_parent.input:
            return

        _i, epsilon = self.model.get_constant_input(add_node)
        if epsilon is None or epsilon <= 0 or epsilon > 1.0e-4:
            logger.warning(f"epsilon value is not expected: {epsilon}")
            return

        # ReduceMean must have keepdims == 1.
        # NOTE(review): if get_node_attribute returns None for an absent attribute, graphs relying on
        # the ONNX default (keepdims=1) are skipped here - confirm whether that is intended.
        keepdims = self.model.get_node_attribute(reduce_mean_node, "keepdims")
        if not keepdims:
            return

        # Axes became an input in opset 18. Before then, axes was an attribute.
        axes = self.model.get_node_attribute(reduce_mean_node, "axes")
        if (not axes) and len(reduce_mean_node.input) > 1:
            axes = self.model.get_constant_value(reduce_mean_node.input[1])

        # Make sure only one axis as required by SimplifiedLayerNormalization spec.
        # Only the number of axes is verified; the axis value itself is forwarded unchanged.
        if not axes or len(axes) != 1:
            return

        self.nodes_to_remove.extend(sim_ln_nodes)
        self.nodes_to_remove.append(reduce_mean_parent)
        self.nodes_to_remove.append(node)

        # The scale is whichever input of the final Mul was NOT matched as the normalization branch.
        scale_input = node.input[1 - parent_index]
        normalize_node = helper.make_node(
            "SimplifiedLayerNormalization",
            inputs=[root_input, scale_input],
            outputs=[node.output[0]],
            name=self.model.create_node_name("SimplifiedLayerNormalization", name_prefix="RMSNorm"),
        )
        normalize_node.attribute.extend([helper.make_attribute("epsilon", float(epsilon))])
        normalize_node.attribute.extend([helper.make_attribute("axis", axes[0])])
        normalize_node.attribute.extend([helper.make_attribute("stash_type", 1)])

        self.nodes_to_add.append(normalize_node)
        self.node_name_to_graph_name[normalize_node.name] = self.this_graph_name
class FusionSkipSimplifiedLayerNormalization(FusionSkipLayerNormalization):
    """Variant of ``FusionSkipLayerNormalization`` configured for the simplified (RMSNorm) operator.

    No matching logic is defined here: the subclass only hands different operator names to the
    base-class constructor and forwards ``fuse`` to the base implementation.
    """

    def __init__(self, model: OnnxModel):
        # Presumably (fused op type, op type to search for) - confirm against the base class signature.
        fused_op = "SkipSimplifiedLayerNormalization"
        anchor_op = "SimplifiedLayerNormalization"
        super().__init__(model, fused_op, anchor_op)

    def fuse(self, node, input_name_to_nodes, output_name_to_node):
        """Run the inherited fusion on ``node``; whatever the base method returns is discarded."""
        super().fuse(
            node,
            input_name_to_nodes,
            output_name_to_node,
        )
|