fusion_qordered_gelu.py 4.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118
  1. # -------------------------------------------------------------------------
  2. # Copyright (c) Microsoft Corporation. All rights reserved.
  3. # Licensed under the MIT License.
  4. # --------------------------------------------------------------------------
  5. from logging import getLogger
  6. from fusion_base import Fusion
  7. from fusion_utils import FusionUtils
  8. from onnx import helper
  9. from onnx_model import OnnxModel
  10. logger = getLogger(__name__)
  11. class FusionQOrderedGelu(Fusion):
  12. def __init__(self, model: OnnxModel):
  13. super().__init__(model, "QOrderedGelu", ["Gelu", "FastGelu"])
  14. def fuse(self, node, input_name_to_nodes: dict, output_name_to_node: dict):
  15. """
  16. INPUT PATTERN
  17. Fuse (quantized) Gelu subgraph into one node QOrderedGelu:
  18. -> quantized input -> DQ -> Gelu -> Q ->
  19. (or)
  20. -> quantized input -> DQ -> FastGelu -> Q ->
  21. OUTPUT PATTERN
  22. -> QOrderedGelu ->
  23. """
  24. gelu_children = self.model.get_children(node, input_name_to_nodes)
  25. # Should only have 1 child - QuantizeLinear (or)
  26. # Should have 2 children - QuantizeLinear + Shape
  27. if not (
  28. (len(gelu_children) == 1 and gelu_children[0].op_type == "QuantizeLinear")
  29. or (
  30. len(gelu_children) == 2
  31. and gelu_children[0].op_type == "QuantizeLinear"
  32. and gelu_children[1].op_type == "Shape"
  33. )
  34. ):
  35. return
  36. downstream_quantize_node = gelu_children[0]
  37. downstream_shape_node = None
  38. if len(gelu_children) == 2:
  39. downstream_shape_node = gelu_children[1]
  40. if not FusionUtils.check_qdq_node_for_fusion(downstream_quantize_node, self.model):
  41. return
  42. # The first input to Gelu should flow through a DequantizeLinear node
  43. first_path_id, first_input_parent_nodes, _ = self.model.match_parent_paths(
  44. node,
  45. [(["DequantizeLinear"], [0])],
  46. output_name_to_node,
  47. )
  48. if first_path_id < 0:
  49. return
  50. upstream_dequantize_node = first_input_parent_nodes[0]
  51. if not FusionUtils.check_qdq_node_for_fusion(upstream_dequantize_node, self.model):
  52. return
  53. # Fusion logic
  54. subgraph_nodes = [node] # Gelu/FastGelu
  55. subgraph_nodes.extend([downstream_quantize_node, upstream_dequantize_node]) # Relevant Q, DQ nodes
  56. if not self.model.is_safe_to_fuse_nodes(
  57. subgraph_nodes,
  58. (
  59. [node.output[0], downstream_quantize_node.output[0]]
  60. if downstream_shape_node is not None
  61. else downstream_quantize_node.output
  62. ),
  63. input_name_to_nodes,
  64. output_name_to_node,
  65. ):
  66. logger.debug("It is not safe to fuse QOrderedGelu node. Skip")
  67. return
  68. self.nodes_to_remove.extend(subgraph_nodes)
  69. ordered_gelu_node = helper.make_node(
  70. "QOrderedGelu",
  71. inputs=[
  72. upstream_dequantize_node.input[0],
  73. upstream_dequantize_node.input[1],
  74. downstream_quantize_node.input[1],
  75. ],
  76. outputs=[downstream_quantize_node.output[0]],
  77. name=self.model.create_node_name("QOrderedGelu", name_prefix="QOrderedGelu"),
  78. )
  79. # Arrange the downstream Shape's input to be fed from the
  80. # downstream QuantizeLinear node, so that fusion will
  81. # be deemed safe
  82. if downstream_shape_node is not None:
  83. self.model.replace_node_input(
  84. downstream_shape_node, downstream_shape_node.input[0], downstream_quantize_node.output[0]
  85. )
  86. # TODO: We only support CuBlasLt order ORDER_ROW for now.
  87. # Once we start supporting other data ordering format(s), we
  88. # will support user configuring the data ordering for the op.
  89. ordered_gelu_node.attribute.extend([helper.make_attribute("order_X", 1)])
  90. ordered_gelu_node.attribute.extend([helper.make_attribute("order_Y", 1)])
  91. ordered_gelu_node.domain = "com.microsoft"
  92. self.nodes_to_add.append(ordered_gelu_node)
  93. self.node_name_to_graph_name[ordered_gelu_node.name] = self.this_graph_name