subgraph_rewriter.py 16 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439
  1. import copy
  2. from dataclasses import dataclass
  3. from typing import Any, Callable, NamedTuple, Optional, TYPE_CHECKING, Union
  4. import torch
  5. from ._compatibility import compatibility
  6. from ._symbolic_trace import symbolic_trace
  7. from .graph import Graph
  8. from .graph_module import GraphModule
  9. from .node import Node
  10. if TYPE_CHECKING:
  11. from .passes.utils.matcher_with_name_node_map_utils import InternalMatch
  12. __all__ = [
  13. "Match",
  14. "replace_pattern",
  15. "replace_pattern_with_filters",
  16. "ReplacedPatterns",
  17. ]
@compatibility(is_backward_compatible=True)
class Match(NamedTuple):
    """One location where ``replace_pattern`` found (and replaced) the pattern.

    ``replace_pattern`` returns a list of these, one per applied match.
    Note that by the time it is returned, the graph nodes that ``nodes_map``
    maps the pattern's non-placeholder nodes to have already been erased
    from the graph by the rewrite.
    """

    # Node from which the match was found
    anchor: Node
    # Maps nodes in the pattern subgraph to nodes in the larger graph
    nodes_map: dict[Node, Node]
@compatibility(is_backward_compatible=False)
@dataclass
class ReplacedPatterns:
    """Record of one applied match, as returned by ``replace_pattern_with_filters``.

    Carries the same ``anchor`` / ``nodes_map`` information as ``Match`` plus
    the list of nodes the replacement inserted into the graph.
    """

    # Node from which the match was found
    anchor: Node
    # Maps nodes in the pattern subgraph to nodes in the larger graph
    nodes_map: dict[Node, Node]
    # List of nodes that were added into the graph
    replacements: list[Node]
  33. def _replace_attributes(gm: GraphModule, replacement: torch.nn.Module) -> None:
  34. gm.delete_all_unused_submodules()
  35. if isinstance(replacement, GraphModule):
  36. replacement.graph.lint()
  37. def try_get_attr(gm: torch.nn.Module, target: str) -> Optional[Any]:
  38. module_path, _, attr_name = target.rpartition(".")
  39. try:
  40. mod: torch.nn.Module = gm.get_submodule(module_path)
  41. except AttributeError:
  42. return None
  43. attr = getattr(mod, attr_name, None)
  44. return attr
  45. for node in gm.graph.nodes:
  46. if node.op == "call_module" or node.op == "get_attr":
  47. gm_attr = try_get_attr(gm, node.target)
  48. replacement_attr = try_get_attr(replacement, node.target)
  49. # CASE 1: This target already exists as an attribute in our
  50. # result GraphModule. Whether or not it exists in
  51. # `replacement`, the existing submodule takes precedence.
  52. if gm_attr is not None:
  53. continue
  54. # CASE 2: The target exists as an attribute in `replacement`
  55. # only, so we need to copy it over.
  56. elif replacement_attr is not None:
  57. new_attr = copy.deepcopy(replacement_attr)
  58. if isinstance(replacement_attr, torch.nn.Module):
  59. gm.add_submodule(node.target, new_attr)
  60. else:
  61. setattr(gm, node.target, new_attr)
  62. # CASE 3: The target doesn't exist as an attribute in `gm`
  63. # or `replacement`
  64. else:
  65. raise RuntimeError(
  66. 'Attempted to create a "',
  67. node.op,
  68. '" node during subgraph rewriting '
  69. f"with target {node.target}, but "
  70. "the referenced attribute does not "
  71. "exist in the replacement GraphModule",
  72. )
  73. gm.graph.lint()
  74. @compatibility(is_backward_compatible=True)
  75. def replace_pattern(
  76. gm: GraphModule,
  77. pattern: Union[Callable, GraphModule],
  78. replacement: Union[Callable, GraphModule],
  79. ) -> list[Match]:
  80. """
  81. Matches all possible non-overlapping sets of operators and their
  82. data dependencies (``pattern``) in the Graph of a GraphModule
  83. (``gm``), then replaces each of these matched subgraphs with another
  84. subgraph (``replacement``).
  85. Args:
  86. ``gm``: The GraphModule that wraps the Graph to operate on
  87. ``pattern``: The subgraph to match in ``gm`` for replacement
  88. ``replacement``: The subgraph to replace ``pattern`` with
  89. Returns:
  90. List[Match]: A list of ``Match`` objects representing the places
  91. in the original graph that ``pattern`` was matched to. The list
  92. is empty if there are no matches. ``Match`` is defined as:
  93. .. code-block:: python
  94. class Match(NamedTuple):
  95. # Node from which the match was found
  96. anchor: Node
  97. # Maps nodes in the pattern subgraph to nodes in the larger graph
  98. nodes_map: Dict[Node, Node]
  99. Examples:
  100. .. code-block:: python
  101. import torch
  102. from torch.fx import symbolic_trace, subgraph_rewriter
  103. class M(torch.nn.Module):
  104. def __init__(self) -> None:
  105. super().__init__()
  106. def forward(self, x, w1, w2):
  107. m1 = torch.cat([w1, w2]).sum()
  108. m2 = torch.cat([w1, w2]).sum()
  109. return x + torch.max(m1) + torch.max(m2)
  110. def pattern(w1, w2):
  111. return torch.cat([w1, w2])
  112. def replacement(w1, w2):
  113. return torch.stack([w1, w2])
  114. traced_module = symbolic_trace(M())
  115. subgraph_rewriter.replace_pattern(traced_module, pattern, replacement)
  116. The above code will first match ``pattern`` in the ``forward``
  117. method of ``traced_module``. Pattern-matching is done based on
  118. use-def relationships, not node names. For example, if you had
  119. ``p = torch.cat([a, b])`` in ``pattern``, you could match
  120. ``m = torch.cat([a, b])`` in the original ``forward`` function,
  121. despite the variable names being different (``p`` vs ``m``).
  122. The ``return`` statement in ``pattern`` is matched based on its
  123. value only; it may or may not match to the ``return`` statement in
  124. the larger graph. In other words, the pattern doesn't have to extend
  125. to the end of the larger graph.
  126. When the pattern is matched, it will be removed from the larger
  127. function and replaced by ``replacement``. If there are multiple
  128. matches for ``pattern`` in the larger function, each non-overlapping
  129. match will be replaced. In the case of a match overlap, the first
  130. found match in the set of overlapping matches will be replaced.
  131. ("First" here being defined as the first in a topological ordering
  132. of the Nodes' use-def relationships. In most cases, the first Node
  133. is the parameter that appears directly after ``self``, while the
  134. last Node is whatever the function returns.)
  135. One important thing to note is that the parameters of the
  136. ``pattern`` Callable must be used in the Callable itself,
  137. and the parameters of the ``replacement`` Callable must match
  138. the pattern. The first rule is why, in the above code block, the
  139. ``forward`` function has parameters ``x, w1, w2``, but the
  140. ``pattern`` function only has parameters ``w1, w2``. ``pattern``
  141. doesn't use ``x``, so it shouldn't specify ``x`` as a parameter.
  142. As an example of the second rule, consider replacing
  143. .. code-block:: python
  144. def pattern(x, y):
  145. return torch.neg(x) + torch.relu(y)
  146. with
  147. .. code-block:: python
  148. def replacement(x, y):
  149. return torch.relu(x)
  150. In this case, ``replacement`` needs the same number of parameters
  151. as ``pattern`` (both ``x`` and ``y``), even though the parameter
  152. ``y`` isn't used in ``replacement``.
  153. After calling ``subgraph_rewriter.replace_pattern``, the generated
  154. Python code looks like this:
  155. .. code-block:: python
  156. def forward(self, x, w1, w2):
  157. stack_1 = torch.stack([w1, w2])
  158. sum_1 = stack_1.sum()
  159. stack_2 = torch.stack([w1, w2])
  160. sum_2 = stack_2.sum()
  161. max_1 = torch.max(sum_1)
  162. add_1 = x + max_1
  163. max_2 = torch.max(sum_2)
  164. add_2 = add_1 + max_2
  165. return add_2
  166. """
  167. match_and_replacements = _replace_pattern(gm, pattern, replacement)
  168. return [
  169. Match(anchor=m.anchor, nodes_map=m.nodes_map) for m in match_and_replacements
  170. ]
  171. # Experimental API, not backward compatible
  172. @compatibility(is_backward_compatible=False)
  173. def replace_pattern_with_filters(
  174. gm: GraphModule,
  175. pattern: Union[Callable, Graph, GraphModule],
  176. replacement: Union[Callable, Graph, GraphModule, None] = None,
  177. match_filters: Optional[
  178. list[Callable[["InternalMatch", Graph, Graph], bool]]
  179. ] = None,
  180. ignore_literals: bool = False,
  181. # Placed at the end to avoid breaking backward compatibility
  182. replacement_callback: Optional[
  183. Callable[["InternalMatch", Graph, Graph], Graph]
  184. ] = None,
  185. node_name_match: str = "",
  186. ) -> list[ReplacedPatterns]:
  187. """
  188. See replace_pattern for documentation. This function is an overload with an additional match_filter argument.
  189. Args:
  190. ``match_filters``: A list of functions that take in
  191. (match: InternalMatch, original_graph: Graph, pattern_graph: Graph) and return a boolean indicating
  192. whether the match satisfies the condition.
  193. See matcher_utils.py for definition of InternalMatch.
  194. ``replacement_callback``: A function that takes in a match and returns a
  195. Graph to be used as the replacement. This allows you to construct a
  196. replacement graph based on the match.
  197. ``replacement_callback``: Node name to match. If not empty, it will try to match the node name.
  198. """
  199. return _replace_pattern(
  200. gm,
  201. pattern,
  202. replacement,
  203. match_filters,
  204. ignore_literals,
  205. replacement_callback,
  206. node_name_match,
  207. )
def _replace_pattern(
    gm: GraphModule,
    pattern: Union[Callable, Graph, GraphModule],
    replacement: Union[Callable, Graph, GraphModule, None] = None,
    match_filters: Optional[
        list[Callable[["InternalMatch", Graph, Graph], bool]]
    ] = None,
    ignore_literals: bool = False,
    # Placed at the end to avoid breaking backward compatibility
    replacement_callback: Optional[
        Callable[["InternalMatch", Graph, Graph], Graph]
    ] = None,
    node_name_match: str = "",
) -> list[ReplacedPatterns]:
    """Shared implementation behind ``replace_pattern`` and
    ``replace_pattern_with_filters``.

    Finds occurrences of ``pattern`` in ``gm.graph`` and keeps those accepted
    by every function in ``match_filters``. For each kept match, it splices
    a copy of the replacement graph into the graph, reroutes users of the
    match's returning nodes to the copies, and erases the matched nodes.
    It then recompiles ``gm``. ``gm`` is mutated in place.

    Args:
        gm: Module whose graph is searched and rewritten.
        pattern: Subgraph to look for. A ``GraphModule`` (its ``.graph`` is
            used), a ``Graph``, or anything else, which is passed to
            ``symbolic_trace``.
        replacement: Subgraph spliced in for every match. A ``GraphModule``,
            a ``Graph``, or a callable passed to ``symbolic_trace``. May be
            ``None`` only if ``replacement_callback`` is given. If it is an
            ``nn.Module``, attributes referenced by the rewritten graph are
            copied from it onto ``gm`` at the end.
        match_filters: Predicates called as
            ``(match, original_graph, pattern_graph)``. A match is kept only
            if all of them return True. They see matches computed against
            the original, not-yet-rewritten graph.
        ignore_literals: Forwarded to ``SubgraphMatcher``.
        replacement_callback: If given, called as
            ``(match, original_graph, pattern_graph)`` for each match. The
            Graph it returns is used instead of ``replacement``.
        node_name_match: Forwarded to ``SubgraphMatcher.match``.

    Returns:
        One ``ReplacedPatterns`` per applied match, in match order.

    Raises:
        AssertionError: If neither a usable ``replacement`` nor a
            ``replacement_callback`` is supplied, if a replacement graph's
            placeholder count differs from the match's, or if no insertion
            point can be found.
    """
    from torch.fx.passes.utils.matcher_utils import InternalMatch, SubgraphMatcher

    if match_filters is None:
        match_filters = []

    # Get the graphs for `gm`, `pattern`, `replacement`
    original_graph: Graph = gm.graph

    if isinstance(pattern, GraphModule):
        pattern_graph = pattern.graph
    elif isinstance(pattern, Graph):
        pattern_graph = pattern
    else:
        pattern_graph = symbolic_trace(pattern).graph

    # Overlapping matches are removed by the matcher itself. See
    # SubgraphMatcher for the exact meaning of match_output / match_placeholder.
    matcher = SubgraphMatcher(
        pattern_graph,
        match_output=False,
        match_placeholder=False,
        remove_overlapping_matches=True,
        ignore_literals=ignore_literals,
    )
    # All matches are computed up front, before any rewriting happens. Node
    # references inside later matches can therefore go stale once earlier
    # matches are replaced (handled via `match_changed_node` below).
    _matches: list[InternalMatch] = matcher.match(
        original_graph, node_name_match=node_name_match
    )

    # Filter out matches that don't match the filter
    _matches = [
        m
        for m in _matches
        if all(
            match_filter(m, original_graph, pattern_graph)
            for match_filter in match_filters
        )
    ]

    if isinstance(replacement, GraphModule):
        common_replacement_graph = replacement.graph
    elif isinstance(replacement, Graph):
        common_replacement_graph = replacement
    elif callable(replacement):
        common_replacement_graph = symbolic_trace(replacement).graph
    else:
        assert replacement_callback is not None, (
            "Must provide either a replacement GraphModule or a replacement callback"
        )
        common_replacement_graph = None

    # As we progressively replace nodes, we'll need to keep track of how the match results should change
    # Maps a returning node of an already-applied match to the copied node
    # that replaced it, so a later match that consumes it as an input can be
    # redirected to the new node.
    match_changed_node: dict[Node, Node] = {}

    match_and_replacements: list[ReplacedPatterns] = []
    for match in _matches:
        # A per-match callback, when provided, wins over the shared replacement.
        if replacement_callback is not None:
            replacement_graph = replacement_callback(
                match, original_graph, pattern_graph
            )
        else:
            assert common_replacement_graph is not None, (
                "Must provide either a replacement GraphModule or a replacement callback"
            )
            replacement_graph = common_replacement_graph
        replacement_placeholders = [
            n for n in replacement_graph.nodes if n.op == "placeholder"
        ]

        # Build connecting between replacement graph's input and original graph input producer node

        # Initialize `val_map` with mappings from placeholder nodes in
        # `replacement` to their corresponding node in `original_graph`
        # (placeholders are paired positionally, so the counts must agree)
        assert len(match.placeholder_nodes) == len(replacement_placeholders)
        val_map: dict[Node, Node] = {}
        for rn, gn in zip(replacement_placeholders, match.placeholder_nodes):
            if isinstance(gn, Node):
                # If an earlier replacement already swapped `gn` out, bind to
                # its replacement instead of the erased original.
                val_map[rn] = match_changed_node.get(gn, gn)
                if gn != val_map[rn]:
                    # Update match.placeholder_nodes and match.nodes_map with the node that replaced gn
                    # (list.index/values().index update the first remaining
                    # occurrence, so repeated occurrences of gn are handled
                    # one per loop iteration)
                    gn_ind = match.placeholder_nodes.index(gn)
                    match.placeholder_nodes[gn_ind] = match_changed_node[gn]
                    map_key = list(match.nodes_map.keys())[
                        list(match.nodes_map.values()).index(gn)
                    ]
                    match.nodes_map[map_key] = match_changed_node[gn]
            else:
                # Non-Node match input (presumably a literal when
                # ignore_literals is set); bound to the placeholder as-is.
                val_map[rn] = gn

        # Copy the replacement graph over
        # The copy is inserted right before the earliest (in graph order)
        # user of any returning node, so it precedes every consumer.
        user_nodes: set[Node] = set()
        for n in match.returning_nodes:
            user_nodes.update(n.users)

        first_user_node = None
        if len(user_nodes) == 0:
            first_user_node = None
        elif len(user_nodes) == 1:
            first_user_node = next(iter(user_nodes))
        else:
            # If there are multiple user nodes, we need to find the first user node
            # in the current execution order of the `original_graph`
            for n in original_graph.nodes:
                if n in user_nodes:
                    first_user_node = n
                    break

        first_next_node = None
        if first_user_node is None:
            # no users, so we insert the replacement graph before the first next
            # node of returning nodes
            # Walking backwards, the first returning node hit is the last one
            # in graph order; `next_node` is the node immediately after it.
            next_node = None
            for n in reversed(original_graph.nodes):
                if n in match.returning_nodes:
                    first_next_node = next_node
                    break
                else:
                    next_node = n
        insert_point = (
            first_user_node if first_user_node is not None else first_next_node
        )
        assert insert_point is not None, "The insert point can't be None"
        with original_graph.inserting_before(insert_point):
            # val_map is pre-seeded with the placeholder bindings built above.
            copied_returning_nodes = original_graph.graph_copy(
                replacement_graph, val_map
            )

        # A single-output replacement yields a bare Node; normalize to a tuple
        # so it can be paired with match.returning_nodes below.
        if isinstance(copied_returning_nodes, Node):
            copied_returning_nodes = (copied_returning_nodes,)

        # Get a list of nodes that have been replaced into the graph
        # (relies on graph_copy having recorded each copied node in val_map;
        # dropping the placeholder bindings leaves only inserted nodes)
        replacement_nodes: list[Node] = [
            v for v in val_map.values() if v not in match.placeholder_nodes
        ]

        # Hook the output Node of the replacement subgraph in to the
        # original Graph at the correct location
        assert len(match.returning_nodes) == len(copied_returning_nodes)  # type: ignore[arg-type]
        for gn, copied_node in zip(match.returning_nodes, copied_returning_nodes):  # type: ignore[arg-type]
            gn.replace_all_uses_with(copied_node)
            # Remember the swap for later matches that consume `gn`.
            match_changed_node[gn] = copied_node
        # Remove the original nodes
        # Walk the pattern in reverse so each consumer is erased before the
        # nodes it consumes; placeholders (inputs) and output are kept.
        for node in reversed(pattern_graph.nodes):
            if node.op != "placeholder" and node.op != "output":
                gn = match.nodes_map[node]
                gm.graph.erase_node(gn)

        match_and_replacements.append(
            ReplacedPatterns(
                anchor=match.anchors[0],
                nodes_map=match.nodes_map,
                replacements=replacement_nodes,
            )
        )

    # Update the passed-in GraphModule to reflect the new state of
    # `original_graph`
    gm.recompile()

    # If `replacement` was an nn.Module, we'll need to make sure that
    # all the submodules have been copied over correctly
    if isinstance(replacement, torch.nn.Module):
        _replace_attributes(gm, replacement)

    return match_and_replacements