split_module.py 26 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640
  1. # mypy: allow-untyped-defs
  2. import inspect
  3. import logging
  4. from collections import OrderedDict
  5. from typing import Any, Callable, Optional
  6. import torch
  7. from torch.fx._compatibility import compatibility
  8. from torch.fx._utils import lazy_format_graph_code
  9. from torch.fx.graph_module import GraphModule
  10. from torch.fx.node import Node
  11. __all__ = ["Partition", "split_module"]
  12. log = _LOGGER = logging.getLogger(__name__)
  13. @compatibility(is_backward_compatible=True)
  14. class Partition:
  15. def __init__(self, name: str):
  16. self.name: str = name
  17. self.submod_name = f"submod_{name}"
  18. self.node_names: list[str] = []
  19. self.inputs: dict[str, None] = {}
  20. self.outputs: dict[str, None] = {}
  21. self.dependencies: dict[str, None] = {}
  22. self.dependents: dict[str, None] = {}
  23. self.graph: torch.fx.graph.Graph = torch.fx.graph.Graph()
  24. self.environment: dict[Node, Node] = {}
  25. self.targets: dict[str, Any] = {}
  26. def __repr__(self) -> str:
  27. return (
  28. f"name: {self.name},\n"
  29. f" nodes: {self.node_names},\n"
  30. f" inputs: {self.inputs},\n"
  31. f" outputs: {self.outputs},\n"
  32. f" partitions depended on: {self.dependencies},\n"
  33. f" partition dependents: {self.dependents}"
  34. )
  35. def _get_attr_from_qualname(mod: torch.nn.Module, qualname: str) -> Any:
  36. attr_val = mod
  37. for atom in qualname.split("."): # type: ignore[union-attr]
  38. if not hasattr(attr_val, atom):
  39. raise AttributeError(f"Node target {qualname} not found!")
  40. attr_val = getattr(attr_val, atom)
  41. return attr_val
  42. # Creates subgraphs out of main graph
  43. @compatibility(is_backward_compatible=True)
  44. def split_module(
  45. m: GraphModule,
  46. root_m: torch.nn.Module,
  47. split_callback: Callable[[Node], int],
  48. qualname_map: Optional[dict[str, str]] = None,
  49. keep_original_order: Optional[bool] = False,
  50. keep_original_node_name: Optional[bool] = False,
  51. keep_original_input_name: bool = True,
  52. ):
  53. """
  54. Creates subgraphs out of main graph
  55. Args:
  56. m (GraphModule): Graph module to split
  57. root_m (torch.nn.Module): root nn module. Not currently used. Included
  58. because the root nn module is usually transformed via
  59. torch.fx._symbolic_trace.symbolic_trace (see example below)
  60. split_callback (Callable[[Node], int]): Callable function
  61. that maps a given Node instance to a numeric partition identifier.
  62. split_module will use this function as the policy for which operations
  63. appear in which partitions in the output Module.
  64. qualname_map: Optional[Dict[str, str]]: optional output parameter that returns a
  65. mapping from new target names in the module after split to old target
  66. names in the original module.
  67. keep_original_order: Optional[bool]: keep the original order of the GraphModule
  68. or use the Topological order of the new constructed GraphModule
  69. keep_original_node_name: Optional[bool]: If the partitioned graphs should
  70. have the same node names as the original graph.
  71. keep_original_input_name: bool: If the partitioned graphs should
  72. have the same input names as the original graph.
  73. Returns:
  74. GraphModule: the module after split.
  75. Example:
  76. This is a sample setup:
  77. import torch
  78. from torch.fx.symbolic_trace import symbolic_trace
  79. from torch.fx.graph_module import GraphModule
  80. from torch.fx.node import Node
  81. from torch.fx.passes.split_module import split_module
  82. class MyModule(torch.nn.Module):
  83. def __init__(self) -> None:
  84. super().__init__()
  85. self.param = torch.nn.Parameter(torch.rand(3, 4))
  86. self.linear = torch.nn.Linear(4, 5)
  87. def forward(self, x, y):
  88. z = self.linear(x + self.param).clamp(min=0.0, max=1.0)
  89. w = self.linear(y).clamp(min=0.0, max=1.0)
  90. return z + w
  91. # symbolically trace model
  92. my_module = MyModule()
  93. my_module_traced = symbolic_trace(my_module)
  94. # random mod partitioning
  95. partition_counter = 0
  96. NPARTITIONS = 3
  97. def mod_partition(node: Node):
  98. global partition_counter
  99. partition = partition_counter % NPARTITIONS
  100. partition_counter = (partition_counter + 1) % NPARTITIONS
  101. return partition
  102. # split module in module with submodules
  103. module_with_submodules = split_module(
  104. my_module_traced, my_module, mod_partition
  105. )
  106. Output looks like this. Original graph is broken into partitions
  107. > print(module_with_submodules)
  108. GraphModule(
  109. (submod_0): GraphModule(
  110. (linear): Linear(in_features=4, out_features=5, bias=True)
  111. )
  112. (submod_1): GraphModule(
  113. (linear): Linear(in_features=4, out_features=5, bias=True)
  114. )
  115. (submod_2): GraphModule()
  116. )
  117. def forward(self, x, y):
  118. param = self.param
  119. submod_0 = self.submod_0(x, param, y); x = param = y = None
  120. getitem = submod_0[0]
  121. getitem_1 = submod_0[1]; submod_0 = None
  122. submod_1 = self.submod_1(getitem, getitem_1); getitem = getitem_1 = None
  123. getitem_2 = submod_1[0]
  124. getitem_3 = submod_1[1]; submod_1 = None
  125. submod_2 = self.submod_2(getitem_2, getitem_3); getitem_2 = getitem_3 = None
  126. return submod_2
  127. Output of split module is the same as output of input traced module.
  128. This is an example within a test setting:
  129. > orig_out = my_module_traced(x, y)
  130. > submodules_out = module_with_submodules(x, y)
  131. > self.assertEqual(orig_out, submodules_out)
  132. True
  133. """
  134. log.debug(
  135. "%s",
  136. lazy_format_graph_code("pre split_module", m, colored=True),
  137. )
  138. def construct_graph(
  139. node: Node,
  140. base_mod_env: dict[str, Node],
  141. base_mod_attrs: dict[str, torch.fx.graph_module.GraphModule],
  142. ):
  143. if node.op == "placeholder":
  144. default_value = (
  145. node.args[0] if len(node.args) > 0 else inspect.Signature.empty
  146. )
  147. if keep_original_node_name:
  148. args = (
  149. () if default_value is inspect.Signature.empty else (default_value,)
  150. )
  151. base_mod_env[node.name] = base_mod_graph.create_node(
  152. "placeholder",
  153. node.name,
  154. args=args, # type: ignore[arg-type]
  155. type_expr=node.type,
  156. )
  157. else:
  158. base_mod_env[node.name] = base_mod_graph.placeholder(
  159. node.target, # type: ignore[arg-type]
  160. type_expr=node.type,
  161. default_value=default_value,
  162. )
  163. base_mod_env[node.name].meta = node.meta.copy()
  164. elif node.op == "get_attr":
  165. base_mod_env[node.name] = base_mod_graph.get_attr(node.target) # type: ignore[arg-type]
  166. base_mod_env[node.name].meta = node.meta.copy()
  167. assert isinstance(node.target, str)
  168. attr_val = _get_attr_from_qualname(m, node.target)
  169. base_mod_attrs[node.target] = attr_val # type: ignore[index]
  170. return base_mod_env, base_mod_attrs
  171. import sympy
  172. partitions: dict[str, Partition] = {}
  173. orig_nodes: dict[str, Node] = {}
  174. symbol_to_node: dict[sympy.Symbol, Node] = {}
  175. def record_cross_partition_use(def_node: Node, use_node: Optional[Node]):
  176. from torch.fx.experimental.symbolic_shapes import free_symbols
  177. defined = getattr(def_node, "_fx_partition", None)
  178. used = getattr(use_node, "_fx_partition", None)
  179. log.debug(
  180. "record_cross_partition_use %s (%s) %s (%s)",
  181. def_node.name,
  182. defined,
  183. use_node.name if use_node is not None else "-",
  184. used,
  185. )
  186. if defined != used:
  187. if defined is not None:
  188. def_partition = partitions[defined]
  189. def_partition.outputs.setdefault(def_node.name)
  190. if used is not None:
  191. def_partition.dependents.setdefault(used)
  192. if used is not None:
  193. use_partition = partitions[used]
  194. use_partition.inputs.setdefault(def_node.name)
  195. # We have made def_node an input to the use_partition. If
  196. # this input has symbolic symbols in its size, those also must
  197. # be made as inputs to the partition
  198. if (def_val := def_node.meta.get("example_value")) is not None:
  199. for s in sorted(free_symbols(def_val), key=str):
  200. s_node = symbol_to_node[s]
  201. use_partition.inputs.setdefault(s_node.name)
  202. if symbol_to_node[s].op != "placeholder":
  203. # If the node that defines the symbol is not a
  204. # placeholder, we must make it an output of the
  205. # partition. Note that this may be in a different
  206. # partition than defined! Although, this doesn't
  207. # really make a difference for correctness, since
  208. # defined is guaranteed to have the symbol in
  209. # scope and can return it; you just get less
  210. # optimal codegen in this case.
  211. s_defined = getattr(s_node, "_fx_partition", None)
  212. if s_defined is not None:
  213. s_def_partition = partitions[s_defined]
  214. s_def_partition.outputs.setdefault(s_node.name)
  215. s_def_partition.dependents.setdefault(used)
  216. use_partition.dependencies.setdefault(s_defined)
  217. if defined is not None:
  218. use_partition.dependencies.setdefault(defined)
  219. def instantiate_node_partition_mapping(node):
  220. partition_name = str(split_callback(node))
  221. log.debug(
  222. "instantiate_node_partition_mapping %s (%s)", node.name, partition_name
  223. )
  224. # add node to partitions
  225. partition = partitions.get(partition_name)
  226. if partition is None:
  227. partitions[partition_name] = partition = Partition(partition_name)
  228. partition.node_names.append(node.name)
  229. node._fx_partition = partition_name
  230. # Global State Nodes are nodes which by their global state effects,
  231. # "taint" all downstream nodes while they are active.
  232. GLOBAL_STATE_NODES = [
  233. torch.amp._enter_autocast,
  234. torch.amp._exit_autocast,
  235. torch._C._set_grad_enabled,
  236. ]
  237. # For grad regions:
  238. # ------------------------
  239. # 1. first region: we do nothing
  240. # 2. subsequent regions: we insert the set_grad at the beginning
  241. grad_regions: OrderedDict[Node, set[int]] = OrderedDict()
  242. # For autocast regions:
  243. # ------------------------
  244. # 1. first region: we will only insert the _exit at the end
  245. # 2. intermediate regions: we will insert both the
  246. # _enter at the beginning and _exit at the end
  247. # 3. last region: we will only insert _enter at the beginning
  248. # We will do so in the order in which the autocasts were instantiated.
  249. autocast_regions: OrderedDict[Node, set[int]] = OrderedDict()
  250. autocast_exits: dict[Node, Optional[Node]] = {}
  251. active_grad = None
  252. active_autocasts = set()
  253. for node in m.graph.nodes:
  254. # This will prefer placeholder bindings, because those come first.
  255. # This is a little dangerous though: it is possible that an unbacked
  256. # symbol is used without any binding site for it, in which case we
  257. # will get a KeyError not able to find it. I'd like to fix this by
  258. # having passes.runtime_assert establish some invariants that I can
  259. # rely on later, but this needs some extra work. Quick fix first.
  260. # See https://github.com/pytorch/pytorch/issues/130534
  261. if (
  262. (val := node.meta.get("example_value")) is not None
  263. and isinstance(val, (torch.SymInt, torch.SymFloat))
  264. and isinstance(s0 := val.node.expr, sympy.Symbol)
  265. and s0 not in symbol_to_node
  266. ):
  267. symbol_to_node[val.node.expr] = node
  268. if node.op in ["placeholder", "get_attr", "output"]:
  269. continue
  270. instantiate_node_partition_mapping(node)
  271. if node.op == "call_function" and node.target in GLOBAL_STATE_NODES:
  272. if node.target == torch._C._set_grad_enabled:
  273. assert len(node.args) == 1
  274. assert isinstance(node.args[0], bool)
  275. active_grad = node
  276. grad_regions[active_grad] = set({split_callback(node)})
  277. elif node.target == torch.amp._enter_autocast:
  278. # Should all be python constants
  279. assert all(not isinstance(arg, Node) for arg in node.args)
  280. active_autocasts.add(node)
  281. autocast_regions[node] = set({split_callback(node)})
  282. autocast_exits[node] = None
  283. elif node.target == torch.amp._exit_autocast:
  284. assert len(node.args) == 1
  285. autocast_regions[node.args[0]].add(split_callback(node))
  286. active_autocasts.remove(node.args[0])
  287. autocast_exits[node.args[0]] = node
  288. if active_grad is not None:
  289. grad_regions[active_grad].add(split_callback(node))
  290. for a in active_autocasts:
  291. autocast_regions[a].add(split_callback(node))
  292. assert all(v is not None for v in autocast_exits.values()), "autocast must exit"
  293. autocast_regions = {k: sorted(v) for k, v in autocast_regions.items()}
  294. grad_regions = {k: sorted(v) for k, v in grad_regions.items()}
  295. if _LOGGER.isEnabledFor(logging.DEBUG):
  296. _LOGGER.debug("autocast_regions: %s", autocast_regions)
  297. _LOGGER.debug("grad_regions: %s", grad_regions)
  298. assert_monotonically_increasing = bool(autocast_regions) or bool(grad_regions)
  299. # split nodes into partitions
  300. highest_partition = -1
  301. for node in m.graph.nodes:
  302. orig_nodes[node.name] = node
  303. # TODO currently placeholders/parameters aren't put into random partitions,
  304. # rather they're added to the graphs where they are used down below
  305. if node.op in ["placeholder", "get_attr"]:
  306. continue
  307. if node.op == "output":
  308. torch.fx.graph.map_arg(
  309. node.args[0], lambda n: record_cross_partition_use(n, None)
  310. )
  311. continue
  312. if assert_monotonically_increasing:
  313. pid = split_callback(node)
  314. assert highest_partition <= pid, (
  315. "autocast or set_grad_enabled require monotonically increasing partitions:"
  316. f"highest: {highest_partition}, this node's: {pid}"
  317. )
  318. highest_partition = pid
  319. # do not capture cross-partition dependencies for global state nodes as they will be
  320. # self-contained - their setup and unwind will be isolated to each partition submodule.
  321. if node.target not in GLOBAL_STATE_NODES:
  322. torch.fx.graph.map_arg(
  323. node.args, lambda def_node: record_cross_partition_use(def_node, node)
  324. )
  325. torch.fx.graph.map_arg(
  326. node.kwargs, lambda def_node: record_cross_partition_use(def_node, node)
  327. ) # noqa: B950
  328. original_partition_order = list(partitions.keys())
  329. # find partitions with no dependencies
  330. root_partitions: list[str] = []
  331. for partition_name, partition in partitions.items():
  332. if not len(partition.dependencies):
  333. root_partitions.append(partition_name)
  334. # check partitions for circular dependencies and create topological partition ordering
  335. sorted_partitions: list[str] = []
  336. while root_partitions:
  337. root_partition = root_partitions.pop()
  338. sorted_partitions.append(root_partition)
  339. for dependent in partitions[root_partition].dependents:
  340. partitions[dependent].dependencies.pop(root_partition) # noqa: B909
  341. if not partitions[dependent].dependencies:
  342. root_partitions.append(dependent)
  343. if len(sorted_partitions) != len(partitions):
  344. raise RuntimeError("cycle exists between partitions!")
  345. # Enter prelude
  346. for regions_mapping in [autocast_regions, grad_regions]:
  347. for node, regions in regions_mapping.items():
  348. assert len(regions) > 0
  349. partitions[str(regions[0])].environment[node] = node
  350. for r in regions[1:]:
  351. partition = partitions[str(r)]
  352. new_node = partition.graph.create_node(
  353. op=node.op,
  354. target=node.target,
  355. args=tuple(arg for arg in node.args),
  356. kwargs={},
  357. type_expr=node.type,
  358. )
  359. new_node.meta = (
  360. node.meta.copy()
  361. ) # is it really a good idea to copy this?
  362. partition.environment[node] = new_node
  363. # add placeholders to partition inputs
  364. for partition_name in sorted_partitions:
  365. partition = partitions[partition_name]
  366. new_inputs: dict[str, None] = {}
  367. counter = 0
  368. for inp in partition.inputs:
  369. orig_node = orig_nodes[inp]
  370. # We don't pass in get_attr nodes as inputs to the partition, but
  371. # instead set them as targets and use getattr within the module
  372. def add_placeholder():
  373. if keep_original_input_name:
  374. name = inp
  375. else:
  376. nonlocal counter
  377. name = f"arg_{counter}"
  378. counter += 1
  379. placeholder = partition.graph.placeholder(
  380. name,
  381. type_expr=orig_nodes[inp].type,
  382. )
  383. new_inputs[inp] = None
  384. return placeholder
  385. if orig_node.op == "get_attr":
  386. assert isinstance(orig_node.target, str)
  387. orig_attr = _get_attr_from_qualname(m, orig_node.target)
  388. if isinstance(orig_attr, torch.nn.Module):
  389. placeholder = partition.graph.get_attr(orig_node.target)
  390. partition.targets[orig_node.target] = orig_attr
  391. else:
  392. placeholder = add_placeholder()
  393. else:
  394. placeholder = add_placeholder()
  395. placeholder.meta = orig_nodes[inp].meta.copy()
  396. partition.environment[orig_nodes[inp]] = placeholder
  397. partition.inputs = new_inputs
  398. # Transform nodes and collect targets for partition's submodule
  399. for node in m.graph.nodes:
  400. if hasattr(node, "_fx_partition"):
  401. partition = partitions[node._fx_partition]
  402. # swap out old graph nodes in kw/args with references to new nodes in this submodule
  403. environment = partition.environment
  404. gathered_args = torch.fx.graph.map_arg(node.args, lambda n: environment[n])
  405. gathered_kwargs = torch.fx.graph.map_arg(
  406. node.kwargs, lambda n: environment[n]
  407. )
  408. if node.op not in ["call_module", "get_attr"]:
  409. target = node.target
  410. else:
  411. target_attr = _get_attr_from_qualname(m, node.target)
  412. target = node.target.replace(".", "_")
  413. partition.targets[target] = target_attr
  414. # Fill in the passed-in mapping from new qualname to old qualname
  415. if qualname_map is not None:
  416. # When creating the split module later, the submodules will have
  417. # path prefix matching the corresponding partition's submod_name
  418. qualname = f"{partition.submod_name}.{target}"
  419. qualname_map[qualname] = node.target
  420. assert isinstance(gathered_args, tuple)
  421. assert isinstance(gathered_kwargs, dict)
  422. name = node.name if keep_original_node_name else None
  423. new_node = partition.graph.create_node(
  424. op=node.op,
  425. target=target,
  426. args=gathered_args,
  427. kwargs=gathered_kwargs,
  428. type_expr=node.type,
  429. name=name,
  430. )
  431. new_node.meta = node.meta.copy()
  432. partition.environment[node] = new_node
  433. # Exit epilogue
  434. for regions_mapping in [autocast_regions]:
  435. for node in reversed(regions_mapping):
  436. regions = regions_mapping[node]
  437. assert len(regions) > 0
  438. for r in regions[:-1]:
  439. partition = partitions[str(r)]
  440. exit_node = autocast_exits[node]
  441. assert exit_node is not None, "Missing exit node"
  442. new_node = partition.graph.create_node(
  443. op=exit_node.op,
  444. target=exit_node.target,
  445. args=(partition.environment[node],),
  446. kwargs={},
  447. type_expr=exit_node.type,
  448. )
  449. new_node.meta = (
  450. exit_node.meta.copy()
  451. ) # is it really a good idea to copy this?
  452. # original module environment dict mapping node names to nodes
  453. orig_mod_env: dict[str, Node] = {}
  454. # Set up values to construct base module
  455. base_mod_env: dict[str, Node] = {}
  456. base_mod_graph: torch.fx.graph.Graph = torch.fx.graph.Graph()
  457. base_mod_attrs: dict[str, torch.fx.graph_module.GraphModule] = {}
  458. if not keep_original_order:
  459. for node in m.graph.nodes:
  460. base_mod_env, base_mod_attrs = construct_graph(
  461. node, base_mod_env, base_mod_attrs
  462. )
  463. else:
  464. # Go through the graph to construct the mapping dict
  465. for node in m.graph.nodes:
  466. orig_mod_env[node.name] = node
  467. # Do some things iterating over the partitions in topological order again:
  468. # 1) Finish off submodule Graphs by setting corresponding outputs
  469. # 2) Construct GraphModules for each submodule
  470. # 3) Construct the base graph by emitting calls to those submodules in
  471. # topological order or original order specified by keep_original_order
  472. construct_order_partitions = (
  473. sorted_partitions if not keep_original_order else original_partition_order
  474. )
  475. already_constructed_attr_nodes = set()
  476. # We actually need to insert the placeholder nodes in the original order
  477. # otherwise graph signature will be wrong.
  478. original_order = [node for node in m.graph.nodes if node.op == "placeholder"]
  479. for partition_name in construct_order_partitions:
  480. partition = partitions[partition_name]
  481. # Set correct output values
  482. output_vals = tuple(
  483. partition.environment[orig_nodes[name]] for name in partition.outputs
  484. )
  485. # skip output node generation if there are no output values
  486. num_output_vals = len(output_vals)
  487. if num_output_vals == 1:
  488. partition.graph.output(output_vals[0])
  489. elif num_output_vals > 1:
  490. partition.graph.output(output_vals)
  491. else:
  492. # Invariant - Graph should always have an output node.
  493. partition.graph.output(())
  494. if keep_original_order:
  495. # first get the attr nodes required by this partition
  496. orig_mod_attr_nodes: list[Node] = [
  497. orig_mod_env[key]
  498. for key in partition.inputs
  499. if key not in original_order
  500. ]
  501. for node in original_order:
  502. if node in already_constructed_attr_nodes:
  503. continue # already added this attr to the base graph
  504. base_mod_env, _based_mod_attrs = construct_graph(
  505. node, base_mod_env, base_mod_attrs
  506. )
  507. already_constructed_attr_nodes.add(node)
  508. # Construct GraphModule for this partition
  509. for node in orig_mod_attr_nodes: # type: ignore[attr-defined]
  510. if node in already_constructed_attr_nodes:
  511. continue
  512. base_mod_env, base_mod_attrs = construct_graph(
  513. node, base_mod_env, base_mod_attrs
  514. )
  515. already_constructed_attr_nodes.add(node)
  516. base_mod_attrs[partition.submod_name] = torch.fx.graph_module.GraphModule(
  517. partition.targets, partition.graph
  518. ) # noqa: B950
  519. # Emit call in base graph to this submodule
  520. output_val = base_mod_graph.call_module(
  521. partition.submod_name,
  522. tuple(base_mod_env[name] for name in partition.inputs),
  523. )
  524. num_outputs = len(partition.outputs)
  525. if num_outputs > 1:
  526. # Unpack multiple return values from submodule
  527. output_val_proxy = torch.fx.proxy.Proxy(output_val)
  528. for i, output_name in enumerate(partition.outputs):
  529. base_mod_env[output_name] = output_val_proxy[i].node # type: ignore[index]
  530. elif num_outputs == 1:
  531. base_mod_env[next(iter(partition.outputs))] = output_val
  532. # When keep_original_order=True and if the graph doesn't have any
  533. # `call_function` node then `base_mod_graph`, `base_mod_env` and `base_mod_attrs`
  534. # are never populated.
  535. # For this case, we call `construct_graph` here which takes care of updating them.
  536. if keep_original_order and not base_mod_env:
  537. for node in m.graph.nodes:
  538. base_mod_env, base_mod_attrs = construct_graph(
  539. node, base_mod_env, base_mod_attrs
  540. )
  541. # Add output node to `base_mod_graph` (i.e. the split graph) which will be returned.
  542. for node in m.graph.nodes:
  543. if node.op == "output":
  544. base_mod_graph.output(
  545. torch.fx.graph.map_arg(node.args[0], lambda n: base_mod_env[n.name])
  546. ) # noqa: B950
  547. ret = torch.fx.graph_module.GraphModule(base_mod_attrs, base_mod_graph)
  548. log.debug(
  549. "%s",
  550. lazy_format_graph_code("post split_module", ret, colored=True),
  551. )
  552. return ret