fusion_spacetodepth.py 6.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162
  1. # -------------------------------------------------------------------------
  2. # Copyright (c) Microsoft Corporation. All rights reserved.
  3. # Licensed under the MIT License. See License.txt in the project root for
  4. # license information.
  5. # --------------------------------------------------------------------------
  6. """Define SpaceToDepth fusion."""
  7. import onnx
  8. from ... import fusions, onnx_model
  9. class FusionSpaceToDepth(fusions.Fusion):
  10. """Fusion for SpaceToDepth."""
  11. def __init__(self, model: onnx_model.ONNXModel):
  12. """Initialize.
  13. Args:
  14. model: An onnx_model.ONNXModel instance.
  15. """
  16. super().__init__(model, "SpaceToDepth", "Reshape")
  17. def _fuse_yolo(
  18. self,
  19. node: onnx.NodeProto,
  20. input_name_to_nodes: dict[str, list[onnx.NodeProto]],
  21. output_name_to_node: dict[str, onnx.NodeProto],
  22. ):
  23. """Fuse for early version of YOLO.
  24. Pattern:
  25. | [N, C, H, W]
  26. Reshape
  27. | [N, C, H/blk, blk, W/blk, blk]
  28. Transpose
  29. | [N, C, H/blk, W/blk, blk, blk]
  30. Reshape
  31. | [N, C, H/blk * W/blk, blk * blk]
  32. Transpose
  33. | [N, C, blk * blk, H/blk * W/blk]
  34. Reshape
  35. | [N, C, blk * blk, H/blk, W/blk]
  36. Transpose
  37. | [N, blk * blk, C, H/blk, W/blk]
  38. Reshape
  39. | [N, blk * blk * C, H/blk, W/blk]
  40. This sequence can be fused into a single SpaceToDepth with blocksize `blk`. Note that unlike DepthToSpace
  41. supporting DCR or CRD mode, SpaceToDepth only supports DCR mode in its latest opset version (13), which matches
  42. the pattern here.
  43. """
  44. reshape_node1 = node
  45. def get_target_child(parent_node, target_op_type):
  46. """Get target child of given node."""
  47. if parent_node.output[0] not in input_name_to_nodes:
  48. return None
  49. children = input_name_to_nodes[parent_node.output[0]]
  50. if len(children) > 1 or children[0].op_type != target_op_type:
  51. return None
  52. return children[0]
  53. if (
  54. (transpose_node1 := get_target_child(reshape_node1, "Transpose")) is None
  55. or (reshape_node2 := get_target_child(transpose_node1, "Reshape")) is None
  56. or (transpose_node2 := get_target_child(reshape_node2, "Transpose")) is None
  57. or (reshape_node3 := get_target_child(transpose_node2, "Reshape")) is None
  58. or (transpose_node3 := get_target_child(reshape_node3, "Transpose")) is None
  59. or (reshape_node4 := get_target_child(transpose_node3, "Reshape")) is None
  60. ):
  61. return False
  62. def get_tensor_shape(tensor_name):
  63. """Get shape for given tensor name."""
  64. tensor_type = self.model.get_tensor_type(tensor_name)
  65. if not tensor_type:
  66. return None
  67. tensor_shape = self.tensor_shape_to_list(tensor_type)
  68. if not tensor_shape:
  69. return None
  70. return tensor_shape
  71. if (
  72. (input_shape := get_tensor_shape(reshape_node1.input[0])) is None
  73. or (reshape_shape1 := get_tensor_shape(reshape_node1.output[0])) is None
  74. or (reshape_shape2 := get_tensor_shape(reshape_node2.output[0])) is None
  75. or (reshape_shape3 := get_tensor_shape(reshape_node3.output[0])) is None
  76. or (reshape_shape4 := get_tensor_shape(reshape_node4.output[0])) is None
  77. ):
  78. return False
  79. transpose_perm1 = self.get_node_attribute(transpose_node1, "perm")
  80. transpose_perm2 = self.get_node_attribute(transpose_node2, "perm")
  81. transpose_perm3 = self.get_node_attribute(transpose_node3, "perm")
  82. # Check rank.
  83. if (
  84. len(input_shape) != 4
  85. or len(reshape_shape1) != 6
  86. or len(reshape_shape2) != 4
  87. or len(reshape_shape3) != 5
  88. or len(reshape_shape4) != 4
  89. ):
  90. return False
  91. # Check shape and perm.
  92. batch, channel, height, width = input_shape
  93. blocksize = reshape_shape1[3]
  94. if (
  95. reshape_shape1 != [batch, channel, height // blocksize, blocksize, width // blocksize, blocksize]
  96. or transpose_perm1 != [0, 1, 2, 4, 3, 5]
  97. or reshape_shape2 != [batch, channel, (height // blocksize) * (width // blocksize), blocksize**2]
  98. or transpose_perm2 != [0, 1, 3, 2]
  99. or reshape_shape3 != [batch, channel, blocksize**2, height // blocksize, width // blocksize]
  100. or transpose_perm3 != [0, 2, 1, 3, 4]
  101. or reshape_shape4 != [batch, blocksize**2 * channel, height // blocksize, width // blocksize]
  102. ):
  103. return False
  104. self.nodes_to_remove.extend(
  105. [
  106. reshape_node1,
  107. transpose_node1,
  108. reshape_node2,
  109. transpose_node2,
  110. reshape_node3,
  111. transpose_node3,
  112. reshape_node4,
  113. ]
  114. )
  115. s2d_node = onnx.helper.make_node(
  116. self.fused_op_type,
  117. name=self.create_unique_node_name(),
  118. inputs=[reshape_node1.input[0]],
  119. outputs=[reshape_node4.output[0]],
  120. blocksize=blocksize,
  121. )
  122. self.nodes_to_add.append(s2d_node)
  123. return True
  124. def fuse(
  125. self,
  126. node: onnx.NodeProto,
  127. input_name_to_nodes: dict[str, list[onnx.NodeProto]],
  128. output_name_to_node: dict[str, onnx.NodeProto],
  129. ):
  130. """Fuse a sequence of Reshape and Transpose nodes into a single SpaceToDepth node.
  131. Args:
  132. node: An onnx.NodeProto matching the specified search type (i.e., Reshape).
  133. input_name_to_nodes: A dict mapping tensor name to consumed nodes.
  134. output_name_to_node: A dict mapping tensor name to produced node.
  135. """
  136. self._fuse_yolo(node, input_name_to_nodes, output_name_to_node)