fuse.py 7.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191
  1. # mypy: allow-untyped-defs
  2. import warnings
  3. from typing import Any, Callable, Union
  4. from torch.ao.quantization.backend_config import (
  5. BackendConfig,
  6. get_native_backend_config,
  7. )
  8. from torch.ao.quantization.backend_config.utils import (
  9. get_fuser_method_mapping,
  10. get_fusion_pattern_to_extra_inputs_getter,
  11. get_fusion_pattern_to_root_node_getter,
  12. )
  13. from torch.ao.quantization.utils import NodePattern, Pattern
  14. from torch.fx import GraphModule, map_arg, Node
  15. from torch.fx.graph import Graph
  16. from .custom_config import FuseCustomConfig
  17. from .fuse_handler import _get_fusion_pattern_to_fuse_handler_cls, FuseHandler
  18. from .match_utils import _is_match, MatchAllNode
  19. from .pattern_utils import _sorted_patterns_dict
  20. __all__ = [
  21. "fuse",
  22. # TODO: We should make this private in the future
  23. # This is currently needed for test_public_bindings for some reason
  24. "FuseHandler",
  25. ]
  26. def fuse(
  27. model: GraphModule,
  28. is_qat: bool,
  29. fuse_custom_config: Union[FuseCustomConfig, dict[str, Any], None] = None,
  30. backend_config: Union[BackendConfig, dict[str, Any], None] = None,
  31. ) -> GraphModule:
  32. if fuse_custom_config is None:
  33. fuse_custom_config = FuseCustomConfig()
  34. if isinstance(fuse_custom_config, dict):
  35. warnings.warn(
  36. "Passing a fuse_custom_config_dict to fuse is deprecated and will not be supported "
  37. "in a future version. Please pass in a FuseCustomConfig instead.",
  38. FutureWarning,
  39. stacklevel=2,
  40. )
  41. fuse_custom_config = FuseCustomConfig.from_dict(fuse_custom_config)
  42. if isinstance(backend_config, dict):
  43. warnings.warn(
  44. "Passing a backend_config_dict to prepare is deprecated and will not be supported "
  45. "in a future version. Please pass in a BackendConfig instead.",
  46. FutureWarning,
  47. stacklevel=2,
  48. )
  49. backend_config = BackendConfig.from_dict(backend_config)
  50. named_modules = dict(model.named_modules())
  51. if backend_config is None:
  52. backend_config = get_native_backend_config()
  53. fusion_pattern_to_fuse_handler_cls = _sorted_patterns_dict(
  54. _get_fusion_pattern_to_fuse_handler_cls(backend_config)
  55. )
  56. fuser_method_mapping = get_fuser_method_mapping(backend_config)
  57. fusion_pattern_to_root_node_getter = get_fusion_pattern_to_root_node_getter(
  58. backend_config
  59. )
  60. fusion_pattern_to_extra_inputs_getter = get_fusion_pattern_to_extra_inputs_getter(
  61. backend_config
  62. )
  63. # find fusion
  64. fusion_pairs = _find_matches(model, model.graph, fusion_pattern_to_fuse_handler_cls)
  65. # TODO: change this to inplace changes to graph, since we no longer construct
  66. # new GraphModule anymore
  67. fused_graph = Graph()
  68. env: dict[Any, Any] = {}
  69. def load_arg(a):
  70. return map_arg(a, lambda node: env[node.name])
  71. def default_root_node_getter(node_pattern):
  72. while not isinstance(node_pattern[-1], Node):
  73. node_pattern = node_pattern[-1]
  74. return node_pattern[-1]
  75. for node in model.graph.nodes:
  76. (
  77. maybe_last_node,
  78. pattern,
  79. matched_node_pattern,
  80. obj,
  81. node_to_subpattern,
  82. ) = fusion_pairs.get(node.name, (None, None, None, None, None))
  83. # get the corresponding subpattern for the current node
  84. if node_to_subpattern is not None:
  85. node_subpattern = node_to_subpattern.get(node, None)
  86. else:
  87. node_subpattern = None
  88. if maybe_last_node is node:
  89. assert obj is not None
  90. root_node_getter = fusion_pattern_to_root_node_getter.get(
  91. pattern, default_root_node_getter
  92. )
  93. root_node = root_node_getter(matched_node_pattern) # type: ignore[index]
  94. extra_inputs_getter = fusion_pattern_to_extra_inputs_getter.get(
  95. pattern, None
  96. )
  97. extra_inputs = []
  98. if extra_inputs_getter is not None:
  99. extra_inputs = extra_inputs_getter(matched_node_pattern)
  100. # TODO: add validation that root_node is a module and has the same type
  101. # as the root_module in the configuration
  102. env[node.name] = obj.fuse(
  103. load_arg,
  104. named_modules,
  105. fused_graph,
  106. root_node,
  107. extra_inputs,
  108. matched_node_pattern, # type: ignore[arg-type]
  109. fuse_custom_config,
  110. fuser_method_mapping,
  111. is_qat,
  112. )
  113. elif maybe_last_node is None or node_subpattern is MatchAllNode:
  114. env[node.name] = fused_graph.node_copy(node, load_arg)
  115. # node matched in patterns and is not root is removed here
  116. model = GraphModule(model, fused_graph)
  117. return model
  118. def _find_matches(
  119. root: GraphModule,
  120. graph: Graph,
  121. pattern_to_fuse_handler_cls: dict[Pattern, Callable],
  122. ) -> dict[str, tuple[Node, Pattern, NodePattern, FuseHandler, dict[Node, Any]]]:
  123. modules = dict(root.named_modules())
  124. # node name -> (root_node, match_value)
  125. match_map: dict[
  126. str, tuple[Node, Pattern, NodePattern, FuseHandler, dict[Node, Any]]
  127. ] = {}
  128. # a map from node to the matched subpattern
  129. node_to_subpattern: dict[Node, Any] = {}
  130. # TODO: dedup with quantization matching function in match_utils.py
  131. def apply_match(pattern, node, match, matched_node_pattern, node_to_subpattern):
  132. if isinstance(pattern, tuple):
  133. s, *args = pattern
  134. current_node_pattern: list[Node] = []
  135. apply_match(s, node, match, current_node_pattern, node_to_subpattern)
  136. for subpattern, arg in zip(args, node.args):
  137. apply_match(
  138. subpattern, arg, match, current_node_pattern, node_to_subpattern
  139. )
  140. matched_node_pattern.append(tuple(current_node_pattern))
  141. else:
  142. # the first pattern matches will take precedence
  143. if node.name not in match_map:
  144. matched_node_pattern.append(node)
  145. # MatchAllNode here is actually MatchAllInputNode which should not
  146. # be added to match_map
  147. if pattern is not MatchAllNode:
  148. node_to_subpattern[node] = pattern
  149. root_node, pattern, handler = match
  150. match_map[node.name] = (
  151. root_node,
  152. pattern,
  153. matched_node_pattern,
  154. handler,
  155. node_to_subpattern,
  156. )
  157. for node in reversed(graph.nodes):
  158. if node.name not in match_map:
  159. for pattern, fuse_handler_cls in pattern_to_fuse_handler_cls.items():
  160. matched_node_pattern: list[Node] = []
  161. if _is_match(modules, node, pattern):
  162. apply_match(
  163. pattern,
  164. node,
  165. (node, pattern, fuse_handler_cls(node)),
  166. matched_node_pattern,
  167. node_to_subpattern,
  168. )
  169. break
  170. return match_map