_swap.py 17 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438
  1. import logging
  2. import operator
  3. import types
  4. from collections import defaultdict
  5. from typing import Optional
  6. import torch
  7. import torch.fx._pytree as fx_pytree
  8. import torch.utils._pytree as pytree
  9. from torch.export.exported_program import (
  10. ConstantArgument,
  11. ExportedProgram,
  12. ModuleCallSignature,
  13. )
  14. from torch.fx.passes.tools_common import legalize_graph, NodeList
  15. from torch.fx.passes.utils.fuser_utils import erase_nodes, fuse_as_graphmodule
  16. log = logging.getLogger(__name__)
  17. def _get_getitem_users(node: torch.fx.Node) -> set[torch.fx.Node]:
  18. node_users = list(node.users.keys())
  19. getitem_users = set()
  20. for user in node_users:
  21. if user.op == "output":
  22. continue
  23. assert user.op == "call_function" and user.target == operator.getitem, (
  24. f"Expected getitem node as user for {node}, instead got {user}"
  25. )
  26. getitem_users.update(list(user.users.keys()))
  27. return getitem_users
  28. def _try_remove_connecting_pytrees(curr_module_node: torch.fx.Node) -> None:
  29. """
  30. We want to try to remove extraneous pytree flatten/unflatten calls between modules
  31. calls. Instead of having the following:
  32. graph():
  33. ...
  34. %foo : [num_users=1] = call_module[target=foo](args = (%getitem_1, %getitem_2), kwargs = {})
  35. %tree_flatten_spec : [num_users=1] = call_function[target=torch.fx._pytree.tree_flatten_spec](args = (%foo, %_spec_1), kwargs = {})
  36. %getitem_4 : [num_users=1] = call_function[target=operator.getitem](args = (%tree_flatten_spec, 0), kwargs = {})
  37. %tree_unflatten_1 : [num_users=2] = call_function[target=torch.utils._pytree.tree_unflatten](args = ([%getitem_4], %_spec_2), kwargs = {})
  38. %getitem_5 : [num_users=1] = call_function[target=operator.getitem](args = (%tree_unflatten_1, 0), kwargs = {})
  39. %getitem_7 : [num_users=0] = call_function[target=operator.getitem](args = (%tree_unflatten_1, 1), kwargs = {})
  40. %getitem_6 : [num_users=1] = call_function[target=operator.getitem](args = (%getitem_5, 0), kwargs = {})
  41. %bar : [num_users=1] = call_module[target=bar](args = (%getitem_6,), kwargs = {})
  42. ...
  43. We could do the following, if we know that all the outputs of `foo` feed into `bar`:
  44. graph():
  45. ...
  46. %foo : [num_users=1] = call_module[target=foo](args = (%getitem_1, %getitem_2), kwargs = {})
  47. %bar : [num_users=1] = call_module[target=bar](args = (%getitem_6,), kwargs = {})
  48. ...
  49. Currently this optimization only works for the case where all of the outputs
  50. of `foo` go directly into `bar`, and `bar` has no other inputs.
  51. """ # noqa: B950
  52. log.debug("Trying to remove pytrees for module call %s", curr_module_node)
  53. curr_module_users = list(curr_module_node.users.keys())
  54. assert len(curr_module_users) == 1, (
  55. f"Expected only one user for module node, instead got {list(curr_module_users)}"
  56. )
  57. flatten_node = curr_module_users[0]
  58. assert (
  59. flatten_node.op == "call_function"
  60. and flatten_node.target == fx_pytree.tree_flatten_spec
  61. )
  62. flatten_getitem_users = _get_getitem_users(flatten_node)
  63. if len(flatten_getitem_users) != 1:
  64. log.debug(
  65. "More than one user found for flatten node, %s: %s. "
  66. "Unable to fuse it with another unflatten call.",
  67. flatten_node,
  68. flatten_getitem_users,
  69. )
  70. return
  71. unflatten_node = next(iter(flatten_getitem_users))
  72. if not (
  73. unflatten_node.op == "call_function"
  74. and unflatten_node.target == pytree.tree_unflatten
  75. ):
  76. log.debug(
  77. "Flatten node %s's user is not a pytree.tree_unflatten. "
  78. "Instead it is: %s. Passing...",
  79. flatten_node,
  80. unflatten_node,
  81. )
  82. return
  83. for i, arg in enumerate(unflatten_node.args[0]): # type: ignore[union-attr,arg-type]
  84. if arg not in flatten_node.users:
  85. log.debug(
  86. "Module %s's outputs are not all directly used as inputs to "
  87. "the subsequent module. Unable to fuse the connecting "
  88. "flatten/unflatten. The inputs to the subsequent module are: %s. ",
  89. curr_module_node,
  90. unflatten_node.args[0],
  91. )
  92. return
  93. if not (
  94. arg.op == "call_function"
  95. and arg.target == operator.getitem
  96. and arg.args[1] == i
  97. ):
  98. log.debug(
  99. "Module %s's outputs are not all directly used in the same "
  100. "order as outputted. Unable to fuse the connecting "
  101. "flatten/unflatten. The inputs to the "
  102. "subsequent module are: %s. ",
  103. curr_module_node,
  104. unflatten_node.args[0],
  105. )
  106. return
  107. # Unflatten has two levels of getitem, because it gets the args and kwargs
  108. unflatten_getitem_getitem_users = set()
  109. unflatten_getitem_users = _get_getitem_users(unflatten_node)
  110. for unflatten_getitem_user in unflatten_getitem_users:
  111. unflatten_getitem_getitem_users.update(
  112. list(unflatten_getitem_user.users.keys())
  113. )
  114. if len(unflatten_getitem_getitem_users) != 1:
  115. log.debug(
  116. "More than one user found for unflatten node, %s: %s. "
  117. "Unable to fuse it with another flatten call.",
  118. unflatten_node,
  119. unflatten_getitem_getitem_users,
  120. )
  121. return
  122. next_module_node = next(iter(unflatten_getitem_getitem_users))
  123. if not (next_module_node.op == "call_module"):
  124. log.debug(
  125. "Unflatten node %s's user is not a call_module. "
  126. "Instead it is: %s. Passing...",
  127. unflatten_node,
  128. next_module_node,
  129. )
  130. return
  131. # Directly put the outputs of the current module into the next module
  132. next_module_node.args = (curr_module_node,)
  133. def _remove_extraneous_pytrees(gm: torch.fx.GraphModule) -> None:
  134. """
  135. Remove extraneous pytree flatten/unflatten calls.
  136. We try a couple of optimizations here:
  137. 1. Remove pytree flatten/unflatten calls between modules
  138. 2. TODO: Remove module's in_spec + initial unflatten call
  139. 3. TODO: Remove module's out_spec + final flatten call
  140. """
  141. for node in gm.graph.nodes:
  142. if node.op == "call_module" and node.target != "_guards_fn":
  143. _try_remove_connecting_pytrees(node)
  144. gm.graph.eliminate_dead_code()
  145. def _construct_inputs(
  146. gm: torch.fx.GraphModule,
  147. signature: ModuleCallSignature,
  148. node_name_map: dict[str, torch.fx.Node],
  149. ) -> tuple[list[torch.fx.Node], dict[str, torch.fx.Node]]:
  150. tree_unflatten_args: list[Optional[torch.fx.Node]] = []
  151. for input_ in signature.inputs:
  152. if isinstance(input_, ConstantArgument) and input_.value is None:
  153. # Constants should be directly embedded into the graph and not used
  154. # as inputs
  155. tree_unflatten_args.append(None)
  156. elif input_.name not in node_name_map:
  157. # For unused inputs
  158. tree_unflatten_args.append(None)
  159. else:
  160. tree_unflatten_args.append(node_name_map[input_.name])
  161. # Insert unflatten call
  162. from .unflatten import _generate_unflatten
  163. unflatten_node = _generate_unflatten(gm, tree_unflatten_args, signature.in_spec)
  164. assert signature.in_spec.num_children == 2
  165. args_spec = signature.in_spec.children_specs[0]
  166. assert args_spec.context is None
  167. args_node = gm.graph.call_function(operator.getitem, (unflatten_node, 0))
  168. args_nodes = [
  169. gm.graph.call_function(operator.getitem, (args_node, i))
  170. for i in range(args_spec.num_children)
  171. ]
  172. kwargs_spec = signature.in_spec.children_specs[1]
  173. assert kwargs_spec.context is not None
  174. kwargs_node = gm.graph.call_function(operator.getitem, (unflatten_node, 1))
  175. kwargs_nodes = {
  176. k: gm.graph.call_function(operator.getitem, (kwargs_node, k))
  177. for k in kwargs_spec.context
  178. }
  179. return args_nodes, kwargs_nodes
  180. def _insert_call_module(
  181. gm: torch.fx.GraphModule,
  182. args_nodes: list[torch.fx.Node],
  183. kwargs_nodes: dict[str, torch.fx.Node],
  184. module_to_swap: torch.nn.Module,
  185. name: str,
  186. ) -> torch.fx.Node:
  187. from .unflatten import _assign_attr, _AttrKind
  188. _assign_attr(module_to_swap, gm, name, _AttrKind.MODULE)
  189. module_node = gm.graph.call_module(name, tuple(args_nodes), kwargs_nodes) # type: ignore[arg-type]
  190. return module_node
  191. def _deconstruct_outputs(
  192. gm: torch.fx.GraphModule,
  193. signature: ModuleCallSignature,
  194. module_node: torch.fx.Node,
  195. node_name_map: dict[str, torch.fx.Node],
  196. orig_outputs: tuple[torch.fx.Node, ...],
  197. ) -> None:
  198. from .unflatten import _generate_flatten_spec
  199. flatten_node = _generate_flatten_spec(gm, module_node, signature.out_spec)
  200. for i, orig_output in enumerate(orig_outputs):
  201. # Use Proxy to record getitem access.
  202. proxy_out = torch.fx.Proxy(flatten_node)[i].node # type: ignore[index]
  203. orig_output.replace_all_uses_with(proxy_out, propagate_meta=True)
  204. node_name_map[orig_output.name] = proxy_out
  205. def _swap_module_helper(
  206. gm: torch.fx.GraphModule,
  207. modules_to_swap: dict[str, torch.nn.Module],
  208. module_call_graph: dict[str, ModuleCallSignature],
  209. ) -> torch.fx.GraphModule:
  210. log.debug("Starting graph:")
  211. log.debug(gm.graph)
  212. legalize_graph(gm)
  213. partitions: dict[str, NodeList] = defaultdict(list)
  214. node_name_map: dict[str, torch.fx.Node] = {
  215. node.name: node for node in gm.graph.nodes
  216. }
  217. # TODO: Handle the duplicate module case
  218. for node in gm.graph.nodes:
  219. if nn_module_stack := node.meta.get("nn_module_stack"):
  220. for path, _ in nn_module_stack.values():
  221. if path in modules_to_swap:
  222. partitions[path].append(node)
  223. break
  224. for name, nodes in partitions.items():
  225. """
  226. Given a graph like the following, and we want to swap out the submodule "foo":
  227. graph():
  228. %x : [num_users=1] = placeholder[target=x]
  229. %y : [num_users=2] = placeholder[target=y]
  230. %add : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%y, %x), kwargs = {}), nn_module_stack = {"foo": ("foo", torch.nn.Module)}
  231. %sub : [num_users=1] = call_function[target=torch.ops.aten.sub.Tensor](args = (%y, %add), kwargs = {}), nn_module_stack = {"bar": ("bar", torch.nn.Module)}
  232. return (sub,)
  233. We will first partition out foo's subgraph:
  234. graph():
  235. %x : [num_users=1] = placeholder[target=x]
  236. %y : [num_users=2] = placeholder[target=y]
  237. %add : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%y, %x), kwargs = {})
  238. return add
  239. And then insert an unflatten + call_module + flatten to replace the subgraph:
  240. graph():
  241. %x : [num_users=1] = placeholder[target=x]
  242. %y : [num_users=1] = placeholder[target=y]
  243. %_spec_0 : [num_users=1] = get_attr[target=_spec_0]
  244. %tree_unflatten : [num_users=2] = call_function[target=torch.utils._pytree.tree_unflatten](args = ([%x, %y], %_spec_0), kwargs = {})
  245. %getitem : [num_users=2] = call_function[target=operator.getitem](args = (%tree_unflatten, 0), kwargs = {})
  246. %getitem_1 : [num_users=1] = call_function[target=operator.getitem](args = (%getitem, 0), kwargs = {})
  247. %getitem_2 : [num_users=1] = call_function[target=operator.getitem](args = (%getitem, 1), kwargs = {})
  248. %getitem_3 : [num_users=0] = call_function[target=operator.getitem](args = (%tree_unflatten, 1), kwargs = {})
  249. %foo : [num_users=0] = call_module[target=foo](args = (%getitem_1, %getitem_2), kwargs = {})
  250. %_spec_1 : [num_users=1] = get_attr[target=_spec_1]
  251. %tree_flatten_spec : [num_users=1] = call_function[target=torch.fx._pytree.tree_flatten_spec](args = (None, %_spec_1), kwargs = {})
  252. %getitem_4 : [num_users=1] = call_function[target=operator.getitem](args = (%tree_flatten_spec, 0), kwargs = {})
  253. %sub : [num_users=1] = call_function[target=torch.ops.aten.sub.Tensor](args = (%y, %getitem_4), kwargs = {})
  254. return (%sub,)
  255. The `tree_unflatten` call will construct tensor inputs into the input
  256. format needed by the swapped eager module.
  257. The `call_module` node should now reference the swapped torch.nn.Module.
  258. The `tree_flatten_spec` call will deconstruct the eager outputs of the
  259. swapped module into tensors.
  260. """ # noqa: B950
  261. submod_name = name.replace(".", "_")
  262. sub_gm, orig_inputs, orig_outputs = fuse_as_graphmodule(
  263. gm, nodes, f"fused_{submod_name}"
  264. )
  265. log.debug("Fused subgraph nodes:")
  266. log.debug(sub_gm.graph)
  267. signature: ModuleCallSignature = module_call_graph[name]
  268. args_nodes, kwargs_nodes = _construct_inputs(gm, signature, node_name_map)
  269. module_node = _insert_call_module(
  270. gm, args_nodes, kwargs_nodes, modules_to_swap[name], name
  271. )
  272. _deconstruct_outputs(gm, signature, module_node, node_name_map, orig_outputs)
  273. erase_nodes(gm, nodes)
  274. log.debug("Swapped graph:")
  275. log.debug(gm.graph)
  276. legalize_graph(gm)
  277. log.debug("Before removing extraneous pytrees:")
  278. log.debug(gm.graph)
  279. _remove_extraneous_pytrees(gm)
  280. log.debug("After removing extraneous pytrees:")
  281. log.debug(gm.graph)
  282. gm.recompile()
  283. return gm
  284. def _fix_input_output_signature(
  285. gm: torch.fx.GraphModule, signature: ModuleCallSignature
  286. ) -> None:
  287. """
  288. Given the unlifted module from calling ep.module(), we want to remove the
  289. pytree processing from the graph module's PyTreeCodeGen and instead make it
  290. nodes inside of the graph. This allows us to do some optimizations, like
  291. remove these pytree calls if it is unnecessary, and makes the PyTree part
  292. more obvious to graph passes.
  293. """
  294. from torch.export.unflatten import _generate_flatten, _generate_unflatten
  295. # Remove the registered pytree codegen because we will take care of it
  296. # through inserting pytree nodes into the graph
  297. gm.graph._codegen = torch.fx.graph.CodeGen()
  298. old_placeholders = [node for node in gm.graph.nodes if node.op == "placeholder"]
  299. new_placeholders = []
  300. forward_arg_names = signature.forward_arg_names
  301. if forward_arg_names is None:
  302. forward_arg_names = []
  303. assert signature.in_spec.num_children == 2
  304. arg_spec = signature.in_spec.children_specs[0]
  305. kwarg_spec = signature.in_spec.children_specs[1]
  306. assert arg_spec.type == tuple
  307. assert kwarg_spec.type == dict
  308. for i in range(arg_spec.num_children):
  309. forward_arg_names.append(f"arg_{i}")
  310. forward_arg_names.extend(kwarg_spec.context)
  311. for arg in forward_arg_names:
  312. with gm.graph.inserting_before(old_placeholders[0]):
  313. new_placeholders.append(gm.graph.placeholder(arg))
  314. # Insert flatten call for the inputs
  315. with gm.graph.inserting_before(old_placeholders[0]):
  316. flat_node = _generate_flatten(gm, tuple(new_placeholders))
  317. for i, old_placeholder in enumerate(old_placeholders):
  318. old_placeholder.op = "call_function"
  319. old_placeholder.target = operator.getitem
  320. old_placeholder.args = (flat_node, i)
  321. # Insert unflatten call for the outputs
  322. output_node = next(node for node in gm.graph.nodes if node.op == "output")
  323. with gm.graph.inserting_before(output_node):
  324. unflat = _generate_unflatten(gm, output_node.args[0], signature.out_spec)
  325. output_node.args = (unflat,)
  326. gm.recompile()
  327. def _swap_modules(
  328. ep: ExportedProgram, modules_to_swap: dict[str, torch.nn.Module]
  329. ) -> torch.fx.GraphModule:
  330. """
  331. Unlifts the given ExportedProgram into a fx.GraphModule, and then swaps
  332. previously traced modules with new eager modules specified. Returns a
  333. fx.GraphModule with a custom forward function.
  334. Args:
  335. ep (ExportedProgram): Exported program to modify
  336. modules_to_swap (Dict[str, torch.nn.Module]): Mapping from module fqn to
  337. eager module to swap with. The specified module fqn should have also
  338. been specified in the `preserve_module_call_signature` argument to
  339. torch.export so that we know how to restore the calling convention
  340. to this argument.
  341. run_with_interpreter: Whether or not to run the graph using
  342. fx.Interpreter. Setting to true will help result in better error
  343. messages and easier debugging, but it has found to result in a QPS
  344. drop.
  345. """
  346. module_call_graph = {
  347. entry.fqn: entry.signature for entry in ep.module_call_graph if entry.signature
  348. }
  349. gm = ep.module()
  350. gm.validate_inputs = False # type: ignore[assignment]
  351. gm.graph.eliminate_dead_code() # type: ignore[operator, union-attr]
  352. assert isinstance(gm, torch.fx.GraphModule)
  353. _fix_input_output_signature(gm, ep.module_call_graph[0].signature)
  354. gm.module_call_graph = ep.module_call_graph
  355. gm.train = types.MethodType(type(gm).train, gm) # type: ignore[assignment]
  356. gm.eval = types.MethodType(type(gm).eval, gm) # type: ignore[assignment]
  357. assert isinstance(gm, torch.fx.GraphModule)
  358. gm = _swap_module_helper(gm, modules_to_swap, module_call_graph)
  359. return gm