split_utils.py 11 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312
  1. # mypy: allow-untyped-defs
  2. import copy
  3. from dataclasses import dataclass, field
  4. from typing import Optional, Union
  5. import torch.fx
  6. from torch.fx._compatibility import compatibility
  7. from torch.fx.graph import map_arg
  8. from torch.fx.passes.utils import HolderModule, lift_subgraph_as_module
  9. from .tools_common import NodeList
# Public API of this module: the dotted-path attribute helpers, the Component
# container and the tag-based GraphModule splitter defined below.
__all__ = ["getattr_recursive", "setattr_recursive", "Component", "split_by_tags"]
  11. @compatibility(is_backward_compatible=False)
  12. def getattr_recursive(obj, name):
  13. for layer in name.split("."):
  14. if isinstance(obj, torch.nn.ModuleList):
  15. if hasattr(obj, "_modules") and layer in obj._modules:
  16. obj = obj._modules[layer]
  17. else:
  18. return None
  19. elif hasattr(obj, layer):
  20. obj = getattr(obj, layer)
  21. else:
  22. return None
  23. return obj
  24. @compatibility(is_backward_compatible=False)
  25. def setattr_recursive(obj, attr, value):
  26. if "." not in attr:
  27. setattr(obj, attr, value)
  28. else:
  29. layer = attr.split(".")
  30. setattr_recursive(getattr(obj, layer[0]), ".".join(layer[1:]), value)
@compatibility(is_backward_compatible=False)
@dataclass
class Component:
    """
    A component serves as a container for a subgraph we want to create afterwards.

    ``split_by_tags`` creates one Component per tag, copies the nodes carrying
    that tag into ``graph`` and finally lifts ``graph`` into ``gm``.
    """

    # Subgraph that receives copies of the nodes belonging to this component.
    graph: torch.fx.Graph
    # Position of this component in the list of components (split_by_tags uses
    # the index of the tag in ``tags``); a node's component must have an order
    # >= the orders of the components it consumes values from.
    order: int
    # Component name; split_by_tags uses the tag string, and it becomes the
    # submodule name in the top-level split module.
    name: str
    # Stores the placeholder nodes in `graph`; kept index-aligned with
    # `orig_inputs` (both are appended together).
    input_placeholders: list[torch.fx.Node] = field(default_factory=list)
    # Store the nodes in original graph that are placeholder in `graph`, in
    # order of first use; also the argument order of the submodule call.
    orig_inputs: list[torch.fx.Node] = field(default_factory=list)
    # Store the nodes in original graph that are outputs in `graph`.
    orig_outputs: list[torch.fx.Node] = field(default_factory=list)
    # Mapping from get_attr node in original graph to get_attr node in `graph`.
    getattr_maps: dict[torch.fx.Node, torch.fx.Node] = field(default_factory=dict)
    # NOTE(review): not read or written by the code in this file; purpose
    # unclear from here — confirm with external callers.
    constructor_args: list[str] = field(default_factory=list)
    # GraphModule lifted from `graph`; None until split_by_tags fills it in.
    gm: Optional[torch.fx.GraphModule] = None
  50. @compatibility(is_backward_compatible=False)
  51. def split_by_tags(
  52. gm: torch.fx.GraphModule,
  53. tags: list[str],
  54. return_fqn_mapping: bool = False,
  55. return_tuple: bool = False,
  56. GraphModuleCls: type[torch.fx.GraphModule] = torch.fx.GraphModule,
  57. ) -> Union[torch.fx.GraphModule, tuple[torch.fx.GraphModule, dict[str, str]]]:
  58. """
  59. Splits a GraphModule using tags on its graph nodes. We honor the order of
  60. tags. For example, we have tags = ["a", "b", "c"], the function will create
  61. the initial submodules in the order of "a", "b", "c".
  62. To set a tag:
  63. gm.graph.nodes[idx].tag = "mytag"
  64. This will result in all nodes with the same tag being extracted and placed in their
  65. own submodule. For placeholder, output and get_attr node, the tag is ignored. placeholder
  66. and output nodes are created when needed while get_attr nodes get copied to submodules
  67. where they are used.
  68. Given the following module def:
  69. class SimpleModule(torch.nn.Module):
  70. def __init__(self) -> None:
  71. super().__init__()
  72. self.linear1 = torch.nn.Linear(...)
  73. self.linear2 = torch.nn.Linear(...)
  74. self.linear3 = torch.nn.Linear(...)
  75. def forward(self, in1, in2):
  76. r1 = self.linear1(in1)
  77. r2 = self.linear2(in2)
  78. r3 = torch.cat([r1, r2])
  79. return self.linear3(r3)
  80. Marking the node corresponding to in1 with the tag sc.REQUEST_ONLY.lower() results in the following split:
  81. ro:
  82. def forward(self, in1):
  83. self = self.root
  84. linear1 = self.linear1(in1)
  85. return linear1
  86. main:
  87. def forward(self, in2, linear1):
  88. self = self.root
  89. linear2 = self.linear2(in2)
  90. cat_1 = torch.cat([linear1, linear2])
  91. linear3 = self.linear3(cat_1)
  92. return linear3
  93. main:
  94. def forward(self, in1, in2):
  95. self = self.root
  96. ro_0 = self.ro_0(in1)
  97. main_1 = self.main_1(in2, ro_0)
  98. return main_1
  99. Returns:
  100. split_gm: torch fx graph after split
  101. orig_to_split_fqn_mapping: a map between the original fqn and the fqn
  102. after split for call_module and get_attr.
  103. """
  104. def flatten(x: torch.fx.node.Argument) -> NodeList:
  105. """
  106. Stores nodes in x to a list and returns the list.
  107. """
  108. r: NodeList = []
  109. map_arg(x, r.append)
  110. return r
  111. # Mapping from node in original module to node in created submodule.
  112. node_remapping: dict[torch.fx.Node, torch.fx.Node] = {}
  113. # Mapping from node in original module or created submodules to
  114. # corresponding component.
  115. node_to_component: dict[torch.fx.Node, Component] = {}
  116. # Mapping from tag to the corresponding component.
  117. tag_to_component: dict[str, Component] = {}
  118. # Stores all components.
  119. all_components: list[Component] = []
  120. # Stores nodes that will be used in main graph.
  121. used_in_main: dict[torch.fx.Node, None] = {}
  122. # Main graph after split.
  123. main_g = torch.fx.Graph()
  124. # Mapping from node in original module to node in main graph after split.
  125. main_remapping: dict[torch.fx.Node, torch.fx.Node] = {}
  126. # Output node of original module.
  127. output_node: Optional[torch.fx.Node] = None
  128. # Create a component for each tag, we don't expect to create other components afterwards.
  129. for tag in tags:
  130. comp = Component(torch.fx.Graph(), len(all_components), f"{tag}")
  131. all_components.append(comp)
  132. tag_to_component[tag] = comp
  133. # Traverse the nodes in original graph and take care of them.
  134. for node in gm.graph.nodes:
  135. if node.op == "output":
  136. if output_node is not None:
  137. raise RuntimeError("Multiple output nodes in graph!")
  138. output_node = node
  139. continue
  140. # Placeholders in the original graph get copied to main graph.
  141. if node.op == "placeholder":
  142. main_remapping[node] = main_g.placeholder(node.name, type_expr=node.type)
  143. main_remapping[node].meta = copy.copy(node.meta)
  144. continue
  145. # Get_attr nodes are ignored because we are not tagging them.
  146. # Instead, we copy them directly to the submodules use them afterwards.
  147. if node.op == "get_attr":
  148. continue
  149. # Now we process callable nodes which are nodes with op of call_module,
  150. # call_function or call_method. Every callable nodes should be tagged.
  151. assert hasattr(node, "tag"), f"Node does not have tag: {node.format_node()}"
  152. upstream_components = [
  153. node_to_component[x]
  154. for x in flatten(node.args) + flatten(node.kwargs)
  155. if x.op not in {"placeholder", "get_attr"}
  156. ]
  157. comp = tag_to_component[node.tag]
  158. node_to_component[node] = comp
  159. # Max order of upperstream components.
  160. mx = max((c.order for c in upstream_components), default=0)
  161. # Expect the component for `node` has higher order then its upstream components.
  162. assert comp.order >= mx, (
  163. f"Component {comp.name} order must be >= max of its upstream components, order={comp.order} and max={mx}"
  164. )
  165. # Map a input of `node` to nodes in the component's graph.
  166. def remap_func(x):
  167. # If input is a get_attr node, copy it to current component's graph.
  168. # Returns the get_attr node in current component's graph.
  169. if x.op == "get_attr":
  170. if x not in comp.getattr_maps:
  171. comp.getattr_maps[x] = comp.graph.get_attr(
  172. x.target, type_expr=x.type
  173. )
  174. comp.getattr_maps[x].meta = copy.copy(x.meta)
  175. return comp.getattr_maps[x]
  176. # If input is not a placeholder, it should have been put into a component
  177. # already. If it's the current component then we return the corresponding
  178. # node in the component.
  179. if x.op != "placeholder" and node_to_component[x] == comp:
  180. return node_remapping[x]
  181. # If input is a placeholder or it's in other components, we want to make it
  182. # as a placeholder in current component's graph.
  183. if x not in comp.orig_inputs:
  184. comp.orig_inputs.append(x)
  185. placeholder = comp.graph.placeholder(x.name, type_expr=x.type)
  186. placeholder.meta = copy.copy(x.meta)
  187. comp.input_placeholders.append(placeholder)
  188. used_in_main[x] = None
  189. return comp.input_placeholders[comp.orig_inputs.index(x)]
  190. n = comp.graph.node_copy(node, remap_func)
  191. n.tag = node.tag # type: ignore[attr-defined]
  192. node_remapping[node] = n
  193. node_to_component[n] = comp
  194. if output_node is None:
  195. raise RuntimeError("Graph had no output node!")
  196. for x in flatten(output_node.args[0]):
  197. if x.op == "get_attr":
  198. # We don't need components mapping for nodes of type "get_attr"
  199. # that are consumed by the output. Only need to make sure we create
  200. # corresponding counterparts in the resulting graph.
  201. main_remapping[x] = main_g.get_attr(x.name, type_expr=x.type)
  202. else:
  203. # All component results consumed by the output node should be
  204. # marked as "used in main".
  205. used_in_main[x] = None
  206. # If a node is used in main graph then we mark it as an output in the component
  207. # it belongs to.
  208. for n in used_in_main:
  209. if n.op != "placeholder":
  210. node_to_component[n].orig_outputs.append(n)
  211. # Now we create a graphmodule for each component.
  212. orig_to_split_fqn_mapping: dict[str, str] = {}
  213. for comp in all_components:
  214. outs = tuple(map(node_remapping.__getitem__, comp.orig_outputs))
  215. if return_tuple:
  216. comp.graph.output(outs)
  217. else:
  218. # Take care of the args of FX output node. If there's a single
  219. # output then the output node args is like (output_single), else
  220. # if there're multiple outputs then the output node args is like
  221. # ((output_0, output_1, ...)).
  222. comp.graph.output(outs[0] if len(outs) == 1 else outs)
  223. comp.gm, comp_orig_to_split_fqn_mapping = lift_subgraph_as_module(
  224. gm, subgraph=comp.graph, comp_name=comp.name
  225. )
  226. orig_to_split_fqn_mapping.update(comp_orig_to_split_fqn_mapping)
  227. # Create a call_module node in main graph.
  228. main_node = main_g.call_module(
  229. comp.name,
  230. args=tuple(map(main_remapping.__getitem__, comp.orig_inputs)),
  231. kwargs=None,
  232. )
  233. if len(outs) == 1 and not return_tuple:
  234. main_remapping[comp.orig_outputs[0]] = main_node
  235. else:
  236. for i, o in enumerate(comp.orig_outputs):
  237. # Use Proxy to record getitem access.
  238. main_remapping[o] = torch.fx.Proxy(main_node)[i].node # type: ignore[index]
  239. main_g.output(map_arg(output_node.args[0], main_remapping.__getitem__))
  240. main_root = HolderModule({comp.name: comp.gm for comp in all_components})
  241. main_g._codegen = gm.graph._codegen
  242. # If the output nodes consumes get_attr directly in the original graph,
  243. # then we need to make sure get_attr is copied to the new graph.
  244. for x in flatten(output_node.args[0]):
  245. if x.op == "get_attr":
  246. setattr(main_root, x.name, getattr_recursive(gm, x.target)) # type: ignore[arg-type]
  247. result_gm = GraphModuleCls(main_root, main_g)
  248. if return_fqn_mapping:
  249. return result_gm, orig_to_split_fqn_mapping
  250. return result_gm