splitter_base.py 40 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
370470570670770870971071171271371471571671771871972072172272372472572672772872973073173273373473573673773873974074174274374474574674774874975075175275375475575675775875976076176276376476576676776876977077177277377477577677777877978078178278378478578678778878979079179279379479579679779879980080180280380480580680780880981081181281381481581681781881982082182282382482582682782882983083183283383483583683783883984084184284384484584684784884985085185285385485585685785885986086186286386486586686786886987087187287387487587687787887988088188288388488588688788888989089189289389489589689789889990090190290390490590690790890991091191291391491591691791891992092192292392492592692792892993093193293393493593693793893994094194294394494594694794894995095195295395495595695795895996096196296396496596696796896997097197297397497597697797897998098198298398498598698798898999099199299399499599699799899910001001100210031004100510061007100810091010101110121013101410151016101710181019102010211022102310241025102610271028102910301031103210331034103510361037103810391040104110421043104410451046104710481049105010511052105310541055105610571058105910601061106210631064106510661067106810691070107110721073107410751076107710781079108010811082108310841085108610871088108910901091109210931094109510961097109810991100110111021103110411051106110711081109111011111112
  1. # mypy: allow-untyped-defs
  2. import argparse
  3. import copy
  4. import json
  5. import logging
  6. import os
  7. from collections import defaultdict
  8. from collections.abc import Iterable, Sequence
  9. from dataclasses import dataclass
  10. from typing import Any, Literal, NamedTuple, Optional
  11. import torch
  12. from torch._logging import trace_structured
  13. from torch.fx._compatibility import compatibility
  14. from torch.fx.node import map_arg
  15. from torch.fx.passes.graph_manipulation import get_size_of_node
  16. from .graph_drawer import FxGraphDrawer
  17. from .operator_support import get_node_target, OperatorSupportBase
  18. from .shape_prop import ShapeProp
  19. from .split_utils import split_by_tags
  20. from .tools_common import (
  21. CALLABLE_NODE_OPS,
  22. FxNetAccFusionsFinder,
  23. is_node_output_tensor,
  24. NodeList,
  25. NodeSet,
  26. Tensors,
  27. )
  # Names exported by ``from <this module> import *``.
  __all__ = [
      "FxNetAccNodesFinder",
      "FxNetSplitterInternalError",
      "Subgraph",
      "SplitResult",
      "generate_inputs_for_submodules",
      "NodeEvent",
      "NodeEventTracker",
  ]
  _LOGGER = logging.getLogger(__name__)

  # Default constructor arguments of _SplitterSettingBase.
  DEFAULT_MIN_ACC_MODULE_SIZE = 1
  DEFAULT_SKIP_FUSION = False
  DEFAULT_ALLOW_NON_TENSOR = False

  # ENV var and constants for node tracker
  TRACKER_DUMP_PATH = "_fx_net_tracker"
  NODES_SUFFIX = "_nodes.txt"
  ALL_SUFFIX = "_all.txt"
  ENV_FX_NET_ACC_SPLITTER_TRACKER_MODE = "FX_NET_ACC_SPLITTER_TRACKER_MODE"
  ENV_FX_NET_ACC_SPLITTER_TRACKER_DUMP_PATH = "FX_NET_ACC_SPLITTER_TRACKER_DUMP_PATH"
  ENV_FX_NET_ACC_SPLITTER_TRACKER_TRACKED_NODES = (
      "FX_NET_ACC_SPLITTER_TRACKER_TRACKED_NODES"
  )

  # Path prefix of the local dump files; ALL_SUFFIX / NODES_SUFFIX are appended to it.
  DUMP_PREFIX = os.environ.get(
      ENV_FX_NET_ACC_SPLITTER_TRACKER_DUMP_PATH, TRACKER_DUMP_PATH
  )

  # Modes of the event tracker for local debugging, read from the
  # FX_NET_ACC_SPLITTER_TRACKER_MODE environment variable:
  #   "0": No local dumps. Information is available by setting breakpoints and
  #        inspecting the tracker in pdb.
  #   "1": Dump all events to <DUMP_PREFIX>_all.txt.
  #   "2": In addition to the events dump, recursively track the nodes listed
  #        (comma-separated names) in FX_NET_ACC_SPLITTER_TRACKER_TRACKED_NODES
  #        and dump them to <DUMP_PREFIX>_nodes.txt.
  #   "3": In addition to the events dump, recursively track every node with more
  #        than 1 event and dump them to <DUMP_PREFIX>_nodes.txt.
  # Independently of the local dumps, the tracker's dump() always emits its
  # events through trace_structured.
  TRACKER_MODE: Literal["0", "1", "2", "3"] = os.environ.get(
      ENV_FX_NET_ACC_SPLITTER_TRACKER_MODE, "0"
  )  # type: ignore[assignment]
  class _SplitterSettingBase:
      """
      Settings shared by splitters.

      Every setting can come from the constructor or from the command line:
      ``sys.argv`` is parsed with ``parse_known_args`` (unrelated arguments are
      ignored) and a value given there takes precedence over the constructor
      argument. The boolean flags use ``store_true`` and can therefore only
      switch a setting on.
      """

      def __init__(
          self,
          min_acc_module_size=DEFAULT_MIN_ACC_MODULE_SIZE,
          skip_fusion=DEFAULT_SKIP_FUSION,
          allow_non_tensor=DEFAULT_ALLOW_NON_TENSOR,
          max_acc_splits: int = -1,
      ):
          """
          Args:
              min_acc_module_size: minimum size limit of an accelerator subgraph.
              skip_fusion: if True, no fusion groups are built.
              allow_non_tensor: if True, non-tensor data flow between the
                  accelerator and cpu is not prevented.
              max_acc_splits: maximum number of split subgraphs to enforce
                  (default -1).
          """
          parser = argparse.ArgumentParser()
          parser.add_argument(
              "--min-acc-module-size",
              "--min_acc_module_size",
              required=False,
              type=int,
              help="Minimum size limit of an accelerator subgraph.",
          )
          parser.add_argument(
              "--max-acc-splits",
              "--max_acc_splits",
              required=False,
              type=int,
              help="Enforce a maximum number of split subgraphs.",
          )
          parser.add_argument(
              "--skip-fusion",
              "--skip_fusion",
              default=False,
              action="store_true",
              help="If true then no fusion groups. Fusion group is used to "
              "enforce no non-tensor data flow between submodules. If we don't "
              "have this constrain, setting this to false is recommended as it "
              "can reduce overhead.",
          )
          parser.add_argument(
              "--allow-non-tensor",
              "--allow_non_tensor",
              default=False,
              action="store_true",
              help="For some backends non-tensor data flow between cpu and them "
              "are not allowed. Therefore, if a node supported by accelerator but "
              "it has non-tensor inputs or outputs to a cpu node we would want to "
              "consider it as a cpu node during splitting. However, for some backends "
              "we might not care about non-tensor data flow and we can set this option "
              "to true to disable the functionality that prevent non-tensor data flow.",
          )
          args, _unknown = parser.parse_known_args()

          # Integer options are None when absent from the command line, so test for
          # None rather than truthiness: an explicit ``--min-acc-module-size 0``
          # must be honoured instead of silently falling back to the default.
          self.min_acc_module_size: int = (
              args.min_acc_module_size
              if args.min_acc_module_size is not None
              else min_acc_module_size
          )
          self.skip_fusion: bool = args.skip_fusion if args.skip_fusion else skip_fusion
          self.allow_non_tensor: bool = (
              args.allow_non_tensor if args.allow_non_tensor else allow_non_tensor
          )
          # ``--max-acc-splits`` is declared above; apply it like the other options.
          self.max_acc_splits: int = (
              args.max_acc_splits
              if args.max_acc_splits is not None
              else max_acc_splits
          )
  @compatibility(is_backward_compatible=False)
  class NodeEvent:
      """
      A single event that happened to a node while the graph was being split.

      source: the node the event is about.
      desc: readable description, formatted as ``<event_type>|<explanation>``.
      dep: optional node that caused the event; when absent the event is
          attributed to ``source`` itself.
      """

      def __init__(
          self, source: torch.fx.Node, desc: str, dep: Optional[torch.fx.Node] = None
      ):
          self.source = source
          self.desc = desc
          self.dep = dep

      def to_str(self):
          """Render as ``"<source name>: <desc> <cause>"``, where cause is the name
          of ``dep`` or ``#`` when there is no dependency."""
          if self.dep:
              cause = self.dep.name
          else:
              cause = "#"
          return f"{self.source.name}: {self.desc} {cause}"
  141. @compatibility(is_backward_compatible=False)
  142. class NodeEventTracker:
  143. """
  144. Tracks node events during the splitter execution.
  145. """
  146. def __init__(self, tracker_mode, dump_prefix):
  147. self.tracker_mode = tracker_mode
  148. self.dump_prefix = dump_prefix
  149. # list of events
  150. self.events = []
  151. # dict from node name to event index
  152. self.node_events = {}
  153. self.writer = print
  154. def add(self, node: torch.fx.Node, desc: str, dep: Optional[torch.fx.Node] = None):
  155. """
  156. Add a new event to the tracker.
  157. """
  158. event = NodeEvent(node, desc, dep)
  159. self.events.append(event)
  160. if node.name not in self.node_events:
  161. self.node_events[node.name] = []
  162. self.node_events[node.name].append(len(self.events) - 1)
  163. def print_node(self, node_name, recursive=False, tab="", writer=None):
  164. """
  165. Print a node and its events.
  166. @param recursive: if True, print nodes that caused the events on this current node.
  167. @param tab: Indentation for dependencies.
  168. @param writer: function to write to file. If None, use print.
  169. """
  170. if not writer:
  171. writer = self.writer
  172. for idx in self.node_events.get(node_name, []):
  173. event = self.events[idx]
  174. writer(tab + event.to_str())
  175. if recursive and event.dep is not None:
  176. self.print_node(
  177. event.dep.name, recursive=True, tab="| " + tab, writer=writer
  178. )
  179. def to_dict(self):
  180. """
  181. Create dict dump on all events.
  182. """
  183. ret: dict[str, list[str]] = {}
  184. for name in self.node_events.keys():
  185. ret[name] = []
  186. for idx in self.node_events.get(name, []):
  187. event = self.events[idx]
  188. ret[name].append(event.to_str())
  189. return ret
  190. def print_all(self, writer=None):
  191. """
  192. Print all nodes in a list.
  193. @param writer: function to write to file. If None, use print.
  194. """
  195. if not writer:
  196. writer = self.writer
  197. for name in self.node_events.keys():
  198. writer(f"Node: {name}:")
  199. self.print_node(name, recursive=False, tab=" ", writer=writer)
  200. def dump(self):
  201. """
  202. Function to be invoked at the end of the finder execution to printout tracked events specified by the mode.
  203. """
  204. # dump via trace_structured
  205. trace_structured(
  206. "artifact",
  207. metadata_fn=lambda: {
  208. "name": "fx_net_acc_splitter_finder_events",
  209. "encoding": "json",
  210. },
  211. payload_fn=lambda: json.dumps(self.to_dict()),
  212. )
  213. def writeln(f):
  214. def fn(x):
  215. return f.write(x + "\n")
  216. return fn
  217. # Mode 0: no local dump
  218. # Mode >=1: Dump all events to file
  219. if self.tracker_mode >= 1:
  220. with open(self.dump_prefix + ALL_SUFFIX, "w") as f:
  221. self.print_all(writeln(f))
  222. def dump_selected_nodes(nodes):
  223. with open(self.dump_prefix + NODES_SUFFIX, "w") as f:
  224. for node_name in nodes:
  225. writeln(f"===== Tracking node {node_name} =====")
  226. self.print_node(
  227. node_name, recursive=True, tab="|-", writer=writeln(f)
  228. )
  229. writeln(f"===== End of tracking node {node_name} =====")
  230. # Mode 2: Dump specific nodes in recursive manner.
  231. # Mode 3: Dump all nodes with more than 1 event in recursive manner.
  232. if self.tracker_mode == 2 or self.tracker_mode == 3:
  233. nodes = (
  234. os.environ.get(ENV_FX_NET_ACC_SPLITTER_TRACKER_TRACKED_NODES, "").split(
  235. ","
  236. )
  237. if self.tracker_mode == 2
  238. else [
  239. name for name, events in self.node_events.items() if len(events) > 1
  240. ]
  241. )
  242. dump_selected_nodes(nodes)
  @compatibility(is_backward_compatible=False)
  class FxNetAccNodesFinder:
      """
      Finds a set of nodes that can be supported on ACC, excluding nodes that have non-tensor
      input/output to cpu nodes to prevent non-tensor data flow between backends and cpu.
      I.e. if we have a chain:
      ACC_NODE_1 -> ACC_NODE_2 -> ACC_NODE_3 -> CPU_NODE_1
      where every ACC node produces non-tensor output, then they all should be treated as CPU nodes.
      This behavior can be turned off by passing allow_non_tensor=True.
      """

      def __init__(
          self,
          module: torch.fx.GraphModule,
          operator_support: OperatorSupportBase,
          allow_non_tensor: bool,
      ):
          self.module = module
          self.operator_support = operator_support
          self.allow_non_tensor = allow_non_tensor
          self.acc_nodes: NodeSet = set()
          # Records why each node ended up on ACC or CPU; mode and dump location
          # come from the module-level TRACKER_MODE / DUMP_PREFIX settings.
          self.tracker = NodeEventTracker(int(TRACKER_MODE), DUMP_PREFIX)

      def reduce_acc_nodes_non_tensor_input_helper(self, cpu_worklist: NodeList):
          """
          Transitively excludes nodes from ACC supported set.
          For every node in the worklist:
          - removes its downstream ACC nodes from ACC supported set,
          - if any downstream ACC node produces non-tensor output,
            then it gets added into the worklist.
          The caller's ``cpu_worklist`` is consumed (left empty).
          """
          from collections import deque

          # FIFO order is kept; deque.popleft() is O(1) whereas list.pop(0) is O(n).
          worklist = deque(cpu_worklist)
          cpu_worklist.clear()
          while worklist:
              node = worklist.popleft()
              for user in node.users:
                  if user in self.acc_nodes:
                      self.acc_nodes.remove(user)
                      self.tracker.add(user, "acc_del|user_of_new_cpu_node", node)
                      if not is_node_output_tensor(user):
                          self.tracker.add(user, "new_cpu_node|non_tensor_output")
                          worklist.append(user)

      def reduce_acc_nodes_non_tensor_input(self):
          """
          Excludes nodes from ACC supported set that have direct
          upstream CPU nodes that produce non-tensor outputs.
          """
          non_tensor_cpu_nodes: NodeList = []
          for node in self.module.graph.nodes:
              if node.op not in CALLABLE_NODE_OPS:
                  continue
              if node in self.acc_nodes:
                  continue
              if is_node_output_tensor(node):
                  continue
              # NOTE(review): the label says "non_tensor_input" although the
              # condition is a CPU node with non-tensor *output*; kept unchanged
              # because dump consumers may match on the event string.
              self.tracker.add(node, "new_cpu_node|callable_non_tensor_input")
              non_tensor_cpu_nodes.append(node)
          self.reduce_acc_nodes_non_tensor_input_helper(non_tensor_cpu_nodes)

      def reduce_acc_nodes_non_tensor_output(self):
          """
          Excludes nodes from ACC supported set that produce non-tensor
          outputs and have downstream CPU nodes.
          Repeats until a round finds no such node.
          """
          while True:
              new_cpu_nodes: NodeList = []
              for acc_node in self.acc_nodes:
                  if is_node_output_tensor(acc_node):
                      continue
                  for user in acc_node.users:
                      if user not in self.acc_nodes:
                          new_cpu_nodes.append(acc_node)
                          self.tracker.add(
                              acc_node, "acc_del|non_tensor_output_with_cpu_user", user
                          )
                          break
              if not new_cpu_nodes:
                  break
              # Removal happens after the scan so self.acc_nodes is not mutated
              # while being iterated.
              for new_cpu_node in new_cpu_nodes:
                  self.acc_nodes.remove(new_cpu_node)
              self.reduce_acc_nodes_non_tensor_input_helper(new_cpu_nodes)

      def __call__(self) -> NodeSet:
          """
          Compute and return the set of nodes that can run on ACC, then dump the
          tracked events.
          """
          submodules = dict(self.module.named_modules())
          self.acc_nodes = set()
          for n in self.module.graph.nodes:
              if n.op not in CALLABLE_NODE_OPS:
                  self.tracker.add(n, "init_cpu|not_callable")
                  continue
              if not self.operator_support.is_node_supported(submodules, n):
                  self.tracker.add(n, "init_cpu|operator_support")
                  continue
              self.tracker.add(n, "init_acc|callable_and_operator_supported")
              self.acc_nodes.add(n)
          if not self.allow_non_tensor:
              self.reduce_acc_nodes_non_tensor_input()
              self.reduce_acc_nodes_non_tensor_output()
          self.tracker.dump()
          return self.acc_nodes
  @compatibility(is_backward_compatible=False)
  class FxNetSplitterInternalError(Exception):
      """Exception type for internal errors of the FX net splitter."""
  @compatibility(is_backward_compatible=False)
  @dataclass
  class Subgraph:
      """
      A group of nodes placed together by the splitter.

      is_acc: True when the nodes are meant for the accelerator, False for cpu.
      nodes: the nodes of the group.
      device_ordinal: optional device ordinal; None unless provided.
      """

      is_acc: bool
      nodes: NodeList
      device_ordinal: Optional[int] = None
  @compatibility(is_backward_compatible=False)
  class SplitResult(NamedTuple):
      """
      Output of the splitter.

      Attributes:
          split_module: the root module after splitting.
          submodule_inputs: maps each submodule name to its inputs.
          non_acc_submodule_prefix: name prefix used for non-acc submodules.
              Acc submodules always use the prefix "_run_on_acc_".
      """

      split_module: torch.fx.GraphModule
      submodule_inputs: dict[str, Any]
      non_acc_submodule_prefix: str
  @compatibility(is_backward_compatible=False)
  def generate_inputs_for_submodules(
      model: torch.nn.Module,
      inputs: Sequence[Any],
      target_submodules: Iterable[str],
      deepcopy: bool = False,
  ) -> dict[str, Any]:
      """
      Generate inputs for the target submodules of the given model by running the
      model once (under ``torch.no_grad()``) with forward pre-hooks installed.

      Note that if two submodules refer to the same obj, this function doesn't
      work. Submodules that are ``torch.jit.ScriptModule`` get no hook and thus
      never appear in the result.

      Args:
          model: root model.
          inputs: inputs to the root model.
          target_submodules: names (as yielded by ``model.named_modules()``) of
              the submodules we want to generate inputs for. Any iterable works,
              including a one-shot generator.
          deepcopy: if True, store a deep copy of the captured inputs.

      Returns:
          A dict that maps from submodule name to the inputs captured by its
          forward pre-hook. If a submodule runs more than once, the last call wins.
      """
      # Materialize once: ``name in target_submodules`` on a generator would
      # consume it, and a set makes every membership test O(1).
      targets = set(target_submodules)
      results: dict[str, Any] = {}
      submodule_to_names = {mod: name for name, mod in model.named_modules()}

      def pre_forward(module, module_inputs):
          results[submodule_to_names[module]] = (
              copy.deepcopy(module_inputs) if deepcopy else module_inputs
          )

      handles = []
      # ``finally`` removes every registered hook whether the forward pass (or a
      # registration) succeeds or raises.
      try:
          for name, mod in model.named_modules():
              if name in targets and not isinstance(mod, torch.jit.ScriptModule):
                  handles.append(mod.register_forward_pre_hook(pre_forward))
          with torch.no_grad():
              model(*inputs)
      finally:
          for h in handles:
              h.remove()
      return results
  397. class _SplitterBase:
  398. """
  399. Splits a GraphModule into sub-GraphModules for execution on CPU or the accelerator.
  400. Output is a GraphModule with supported and unsupported operators grouped into as few sub-GraphModules as possible.
  401. Assumes that only "call_module", "call_function" and "call_method" from FX IR can potentially be executed on the accelerator.
  402. Given the following graph:
  403. ==> b ==>
  404. // \\
  405. a d
  406. \\ //
  407. ==> c ==>
  408. class SimpleModule(torch.nn.Module):
  409. def forward(self, a):
  410. b = torch.sin(a)
  411. c = torch.cos(a)
  412. d = b + c
  413. return d
  414. and providing "operator_support" that indicates that 'b' and 'c' can be executed on the accelerator,
  415. we will get the following split result:
  416. main:
  417. def forward(self, a):
  418. run_on_acc_0_0 = self._run_on_acc_0_0(a)
  419. getitem = run_on_acc_0_0[0]
  420. getitem_1 = run_on_acc_0_0[1]
  421. run_on_cpu_1_1 = self._run_on_cpu_1_1(getitem, getitem_1)
  422. return run_on_cpu_1_1
  423. _run_on_acc_0_0:
  424. def forward(self, a):
  425. sin_1 = torch.sin(a)
  426. cos_1 = torch.cos(a)
  427. return (sin_1, cos_1)
  428. _run_on_cpu_1_1:
  429. def forward(self, sin_1, cos_1):
  430. add_1 = sin_1 + cos_1
  431. return add_1
  432. """
  433. # PCIe bandwidth for the backend, default to 100 GB/s
  434. PCIe_BW = 100 * 2**30
  def __init__(
      self,
      module: torch.fx.GraphModule,
      sample_input: Sequence[Any],
      operator_support: OperatorSupportBase,
      settings: _SplitterSettingBase,
      non_acc_submodule_name: str = "_run_on_cpu_",
      return_tuple: bool = False,
      nodes_finder: Optional[FxNetAccNodesFinder] = None,
  ):
      """
      Prepare the graph for splitting. In order:
        - propagate shapes through ``module`` with ``sample_input``
          (done first, presumably because the finders below read the shape
          metadata — keep this ordering),
        - find the nodes supported by ACC (``self.acc_nodes``) with
          ``nodes_finder``, or a default FxNetAccNodesFinder when None,
        - find fusion groups (``self.fusions``) unless ``settings.skip_fusion``,
        - build the direct dependency graph (``self.deps``) and extend it for
          fused nodes.
      """
      assert isinstance(module, torch.fx.GraphModule)
      self.module = module
      ShapeProp(self.module).propagate(*sample_input)

      self.settings = settings
      self.operator_support = operator_support
      self.sample_input = sample_input

      finder = (
          nodes_finder
          if nodes_finder is not None
          else FxNetAccNodesFinder(
              self.module, self.operator_support, self.settings.allow_non_tensor
          )
      )
      self.acc_nodes = finder()

      self.fusions = (
          {}
          if self.settings.skip_fusion
          else FxNetAccFusionsFinder(module, self.acc_nodes)()
      )

      # Modify deps to add more deps for fused nodes
      self.deps = self.find_deps()
      self.update_deps_for_fusions()

      self.non_acc_submodule_name = non_acc_submodule_name
      self._node_submodule_map: dict[str, str] = {}
      self._return_tuple = return_tuple
      self.tags: list[str] = []
  475. # ===============================================================
  476. # Helpers for ctor and initial state
  477. # ===============================================================
  def get_node_submodule_map(self) -> dict[str, str]:
      """
      Return the mapping from node name to the name of the submodule holding it,
      e.g. node
      ``main_module_impl_impl_over_arch_unary_multiple_embedding_pooling_embedding_pooling_sparse_entity_equivalence_key_proxy_embedding_bag``
      maps to submodule name ``_run_on_acc_1``.

      The splitter's own dict is returned (not a copy).
      """
      node_to_submodule = self._node_submodule_map
      return node_to_submodule
  def find_deps(self) -> dict[torch.fx.Node, NodeSet]:
      """
      Build the direct dependency graph: ``deps[n]`` is the set of callable nodes
      whose outputs ``n`` consumes. Only direct (non-transitive) edges are stored,
      the "output" node never becomes a key, and the result is a defaultdict, so
      nodes without dependencies map to an empty set.
      """
      deps: dict[torch.fx.Node, NodeSet] = defaultdict(set)
      callable_producers = (
          n for n in self.module.graph.nodes if n.op in CALLABLE_NODE_OPS
      )
      for producer in callable_producers:
          for consumer in producer.users:
              if consumer.op == "output":
                  continue
              deps[consumer].add(producer)
      return deps
  def update_deps_for_fusions(self):
      """
      Extend ``self.deps`` in place so that:
      - each fusion key node also depends on the outer dependencies of every
        member of its fusion,
      - every outer user of a fusion member depends on the fusion key node.
      """
      for node, fusion in self.fusions.items():
          for member in fusion:
              outer_deps = self.deps[member] - fusion
              self.deps[node].update(outer_deps)
              for consumer in member.users:
                  if consumer in fusion:
                      continue
                  self.deps[consumer].add(node)
  514. # ===============================================================
  515. # Helpers for preview
  516. # ===============================================================
  def _lower_model_to_backend(
      self, mod: torch.fx.GraphModule, inputs: Tensors
  ) -> torch.nn.Module:
      """
      Lower ``mod`` to the target backend.

      This base implementation performs no lowering: ``inputs`` is ignored and
      ``mod`` is returned unchanged (presumably backend-specific splitters
      override this).
      """
      lowered = mod
      return lowered
  def _find_culprit(self, mod: torch.fx.GraphModule, inputs: Tensors) -> str:
      """
      Describe which part of ``mod`` caused an error during lowering or while
      running the lowered module; ``split_preview`` appends the returned text to
      its report. This base implementation always returns a fixed message.
      """
      message = "Unable to find a culprit because _find_culprit() function is not implemented."
      return message
  def _draw_graph_based_on_node_support(
      self, mod: torch.fx.GraphModule, supported_nodes: NodeList
  ):
      """
      Write ``node_support.dot`` (relative to the current directory) showing
      ``mod`` with supported nodes filled "chartreuse1", unsupported callable
      nodes "crimson" and all other nodes "AliceBlue".
      """
      color_map = {
          "default": "AliceBlue",
          "supported": "chartreuse1",
          "unsupported": "crimson",
      }
      # Build the membership set once: ``node in supported_nodes`` on the list
      # costs O(len(list)) per node, i.e. quadratic over the graph.
      supported = set(supported_nodes)

      class CustomDrawer(FxGraphDrawer):
          def _get_node_style(self, node):
              template = super()._get_node_style(node)
              if node in supported:
                  template["fillcolor"] = color_map["supported"]
              elif node.op in CALLABLE_NODE_OPS:
                  template["fillcolor"] = color_map["unsupported"]
              else:
                  template["fillcolor"] = color_map["default"]
              return template

      drawer = CustomDrawer(mod, "node_support", ignore_getattr=True)
      dot_graph = drawer.get_main_dot_graph()
      # pyre-fixme[16]: `pydot.Dot` has no attribute `write_raw`.
      dot_graph.write_raw("node_support.dot")  # type: ignore[attr-defined]
  def node_support_preview(self, dump_graph: bool = False):
      """
      Report which callable node targets are supported by ``operator_support``.

      Each callable node contributes its target together with a dtype signature:
      the dtypes of its positional args (None for non-tensor args, trailing
      Nones dropped) and of its Node-valued kwargs. The report is printed and
      returned. With ``dump_graph`` a colored node-support graph is written too.
      """
      submodules = dict(self.module.named_modules())
      supported_nodes: NodeList = []
      supported_node_types = defaultdict(set)
      unsupported_node_types = defaultdict(set)

      def get_dtype(arg):
          tensor_meta = arg.meta.get("tensor_meta")
          return getattr(tensor_meta, "dtype", None)

      def dtype_or_none(arg):
          return get_dtype(arg) if isinstance(arg, torch.fx.Node) else None

      for node in self.module.graph.nodes:
          if node.op not in CALLABLE_NODE_OPS:
              continue
          target = get_node_target(submodules, node)

          # dtype of every positional arg, then strip the trailing Nones.
          arg_dtypes = [dtype_or_none(arg) for arg in node.args]
          while arg_dtypes and arg_dtypes[-1] is None:
              arg_dtypes.pop()
          signature = (
              tuple(arg_dtypes),
              tuple(
                  (key, get_dtype(value))
                  for key, value in node.kwargs.items()
                  if isinstance(value, torch.fx.Node)
              ),
          )

          if self.operator_support.is_node_supported(submodules, node):
              supported_nodes.append(node)
              supported_node_types[target].add(signature)
          else:
              unsupported_node_types[target].add(signature)

      if dump_graph:
          self._draw_graph_based_on_node_support(self.module, supported_nodes)

      def describe(node_types):
          return "".join(
              f"{t}: ({arg_sig}, {dict(kwarg_sig)})\n"
              for t, dtypes in node_types.items()
              for arg_sig, kwarg_sig in dtypes
          )

      reports = (
          "\nSupported node types in the model:\n"
          + describe(supported_node_types)
          + "\nUnsupported node types in the model:\n"
          + describe(unsupported_node_types)
      )
      print(reports)
      # Return reports for testing purpose
      return reports
  def split_preview(self, dump_graph: bool = False):
      """
      Dry-run the split and report on the result.

      The report lists the number of acc/cpu subgraphs before and after small
      acc subgraphs are removed and the node count of each subgraph. For every
      acc submodule it also lists:
        - non-tensor inputs and outputs,
        - total input/output bytes,
        - the theoretical max QPS bound implied by ``PCIe_BW``,
        - whether lowering with ``_lower_model_to_backend`` and running the
          lowered module succeed.

      Args:
          dump_graph: if True, write one ``<name>.dot`` file per graph of the
              split module.

      Returns:
          The report string, which is also printed.
      """
      reports = ""
      subgraphs = self.put_nodes_into_subgraphs()
      acc_subgraphs_num = len([g for g in subgraphs if g.is_acc])
      cpu_subgraphs_num = len(subgraphs) - acc_subgraphs_num
      reports += f"Before removing small acc subgraphs, total {len(subgraphs)} subgraphs are created:"
      reports += f" {acc_subgraphs_num} acc subgraphs and {cpu_subgraphs_num} cpu subgraphs.\n"

      subgraphs = self.remove_small_acc_subgraphs(subgraphs)
      acc_subgraphs_num = len([g for g in subgraphs if g.is_acc])
      cpu_subgraphs_num = len(subgraphs) - acc_subgraphs_num
      reports += f"After removing small acc subgraphs, total {len(subgraphs)} subgraphs are created:"
      reports += f" {acc_subgraphs_num} acc subgraphs and {cpu_subgraphs_num} cpu subgraphs.\n"

      for i, subgraph in enumerate(subgraphs):
          reports += (
              f"_run_on_acc_{i}: "
              if subgraph.is_acc
              else f"{self.non_acc_submodule_name}{i}: "
          )
          reports += f"{len(subgraph.nodes)} node(s)\n"

      self.tag(subgraphs)
      split_mod = self.split(remove_tag=True)
      split_mod.eval()

      if dump_graph:
          drawer = FxGraphDrawer(split_mod, "preview", ignore_getattr=True)
          dot_graphs = drawer.get_all_dot_graphs()
          for name, dot_graph in dot_graphs.items():
              # pyre-fixme[16]: `pydot.Dot` has no attribute `write_raw`.
              dot_graph.write_raw(f"{name}.dot")  # type: ignore[attr-defined]

      def get_submod_inputs(main_mod, submod, example_inputs):
          """Run ``main_mod`` once and return the inputs ``submod`` was called with."""
          sub_inputs = None

          def get_inputs(_module, inputs):
              nonlocal sub_inputs
              sub_inputs = inputs

          handle = submod.register_forward_pre_hook(get_inputs)
          try:
              main_mod(*example_inputs)
          finally:
              # Do not leave the hook installed if the forward pass raises.
              handle.remove()
          return sub_inputs

      max_qps: float = self.PCIe_BW
      bottleneck_module = ""

      for node in split_mod.graph.nodes:
          if node.op == "call_module" and "acc" in node.target:
              reports += f"\nProcessing acc submodule {node.target}\n"

              submod = getattr(split_mod, node.target)
              submod_inputs = get_submod_inputs(split_mod, submod, self.sample_input)
              ShapeProp(submod).propagate(*submod_inputs)

              total_input_bytes = 0
              total_output_bytes = 0

              reports += "Checking inputs...\n"
              output_node = None
              for n in submod.graph.nodes:
                  if n.op == "placeholder":
                      if not is_node_output_tensor(n):
                          reports += f"Input {n.name} is not a tensor, this might cause problems during lowering!\n"
                      else:
                          total_input_bytes += get_size_of_node(submod, n)[0]
                  if n.op == "output":
                      output_node = n

              reports += "Checking outputs...\n"

              def get_bytes(out_node: torch.fx.Node):
                  nonlocal total_output_bytes
                  nonlocal reports
                  if not is_node_output_tensor(out_node):
                      reports += f"Output {out_node.name} is not a tensor, this might cause problems during lowering!\n"
                  else:
                      total_output_bytes += get_size_of_node(submod, out_node)[0]

              map_arg(output_node.args, get_bytes)  # type: ignore[union-attr]

              # A submodule that moves no tensor bytes is not bounded by PCIe
              # bandwidth; guard against ZeroDivisionError.
              max_io_bytes = max(total_input_bytes, total_output_bytes)
              qps = self.PCIe_BW / max_io_bytes if max_io_bytes > 0 else float("inf")
              reports += f"Total input size in bytes is {total_input_bytes}, total output size in bytes is {total_output_bytes},"
              reports += f" theoretical max qps (bounds by PCIe bandwidth) for this submodule is {qps}.\n"

              if qps < max_qps:
                  max_qps = qps
                  bottleneck_module = node.target

              try:
                  lowered_submod = self._lower_model_to_backend(submod, submod_inputs)
              except RuntimeError:
                  reports += "Run into an error during lowering!\n"
                  reports += self._find_culprit(submod, submod_inputs)
                  continue

              try:
                  lowered_submod(*submod_inputs)
              except RuntimeError:
                  reports += "Run into an error during inference!\n"
                  reports += self._find_culprit(submod, submod_inputs)
              else:
                  reports += "Lowering and running succeed!\n"

      reports += f"\nTheoretical max qps (bounds by PCIe bandwidth) for this model is {max_qps},"
      reports += f" bottleneck is submodule {bottleneck_module}."
      print(reports)

      # return the reports for testing purposes
      return reports
  694. # ===============================================================
  695. # Helpers for extend_acc_subgraph() method
  696. # ===============================================================
  def find_reverse_deps(
      self, tag_id: Optional[int] = None
  ) -> dict[torch.fx.Node, NodeSet]:
      """
      Build the reverse dependency graph: ``result[n]`` is the set of callable
      users of callable node ``n``. When ``tag_id`` is given, users whose tag
      ends in a number ``>= tag_id`` (the same or a later subgraph) are ignored;
      tags are parsed as ``<prefix>_<int>``.
      """
      result: dict[torch.fx.Node, NodeSet] = defaultdict(set)
      for node in self.module.graph.nodes:
          if node.op not in CALLABLE_NODE_OPS:
              continue
          for user in node.users:
              if user.op not in CALLABLE_NODE_OPS:
                  continue
              if tag_id is not None and int(user.tag.split("_")[-1]) >= tag_id:
                  continue
              result[node].add(user)
      return result
  def update_reverse_deps_for_fusions(self, deps: dict[torch.fx.Node, NodeSet]):
      """
      Rewrite ``deps`` in place so each fusion group in ``self.fusions`` acts
      as a single unit.

      Each group is processed once, however many of its members appear as
      keys in ``self.fusions``. For that group:
        * every member is rebound to ONE shared set object. The set holds the
          union of all members' entries, minus the members themselves.
        * every input node of a member that lies outside the group gets all
          group members added to its own entry.
      """
      handled: set = set()
      for member, group in self.fusions.items():
          if member in handled:
              continue
          # Union of every member's dependencies; a group never depends on itself.
          shared = set().union(*(deps[m] for m in group))
          shared.difference_update(group)
          for m in group:
              # Note: all members are bound to the same set object.
              deps[m] = shared
              outside_inputs = (a for a in m.all_input_nodes if a not in group)
              for arg in outside_inputs:
                  deps[arg].update(group)
              handled.add(m)
  def find_parent_nodes_of_subgraph(self, tag: str) -> NodeSet:
      """
      Return the parent nodes of the ``tag`` subgraph.

      A parent is an input of some callable node tagged ``tag`` that meets two
      conditions: it is itself callable (op in ``CALLABLE_NODE_OPS``), and it
      carries a different tag. Inputs whose op is not callable, such as
      placeholders, are skipped.
      """
      members = (
          node
          for node in self.module.graph.nodes
          if node.op in CALLABLE_NODE_OPS and node.tag == tag
      )
      return {
          arg
          for node in members
          for arg in node.all_input_nodes
          if arg.op in CALLABLE_NODE_OPS and arg.tag != tag
      }
  def extend_acc_subgraph(self, tag: str):
      """
      Grow the acc subgraph ``tag`` in reversed topological direction.

      ``tag`` must end in ``_<int>``. That integer is passed to
      ``find_reverse_deps`` so that users in subgraph ``tag`` and in later
      subgraphs are ignored.

      Starting from the subgraph's parent nodes, the loop repeatedly picks a
      candidate that satisfies two conditions: it is in ``self.acc_nodes``, and
      all of its remaining dependencies have already been moved. It then
      retags the candidate to ``tag`` and enqueues the candidate's fusion
      buddies and callable inputs as new candidates. It stops when no
      candidate qualifies.

      Each moved node gets both ``node.tag`` and
      ``self._node_submodule_map[node.name]`` updated, matching what ``tag()``
      does, so the node -> submodule map does not go stale.

      Raises:
          ValueError: if the part of ``tag`` after the last ``_`` is not an int.
      """
      # Maps each node to the users that must be moved before it can be.
      # Users that are in this subgraph or a later one are ignored.
      deps = self.find_reverse_deps(tag_id=int(tag.rsplit("_", maxsplit=1)[-1]))
      self.update_reverse_deps_for_fusions(deps)
      # Candidates for absorption, seeded with the parents of the subgraph.
      parent_nodes = self.find_parent_nodes_of_subgraph(tag)
      visited_nodes: NodeSet = set()
      while parent_nodes:
          # Find an acc node whose dependencies have all been moved already.
          node = next(
              (
                  n
                  for n in parent_nodes
                  if deps[n] <= visited_nodes and n in self.acc_nodes
              ),
              None,
          )
          if node is None:
              break
          # Put the node into `tag` subgraph and keep the name -> tag map in sync
          # (tag() writes both of these together as well).
          node.tag = tag  # type: ignore[attr-defined]
          self._node_submodule_map[node.name] = tag
          parent_nodes.remove(node)
          visited_nodes.add(node)
          # If the node is in a fusion group, its fusion buddies become candidates.
          if node in self.fusions:
              parent_nodes.update(
                  buddy for buddy in self.fusions[node] if buddy not in visited_nodes
              )
          # Callable inputs of the node become candidates as well.
          parent_nodes.update(
              arg
              for arg in node.all_input_nodes
              if arg.op in CALLABLE_NODE_OPS and arg not in visited_nodes
          )
  779. # ===============================================================
  780. # Helpers for split() method
  781. # ===============================================================
  def starter_nodes(self) -> tuple[NodeSet, NodeSet]:
      """
      Find the nodes the subgraph traversal starts from.

      There are two kinds of starters:
        (a) every user of a placeholder or get_attr node;
        (b) every callable node that has no input nodes at all (e.g. a factory
            call with only constant arguments).

      Nodes of kind (b) are not users of any node. The traversal in
      ``put_nodes_into_subgraphs`` only enqueues users of nodes it has already
      placed. Without seeding kind (b) here, such nodes and everything that
      depends on them would never be placed into a subgraph.

      Returns:
          ``(starter_cpu_nodes, starter_acc_nodes)``, split by membership in
          ``self.acc_nodes``.
      """
      starter_cpu_nodes: NodeSet = set()
      starter_acc_nodes: NodeSet = set()

      def add_starter(node: torch.fx.Node) -> None:
          if node in self.acc_nodes:
              starter_acc_nodes.add(node)
          else:
              starter_cpu_nodes.add(node)

      for node in self.module.graph.nodes:
          if node.op in CALLABLE_NODE_OPS and not node.all_input_nodes:
              # Edge case: a callable node with no dependencies.
              add_starter(node)
          elif node.op in {"placeholder", "get_attr"}:
              for user in node.users:
                  add_starter(user)
      return starter_cpu_nodes, starter_acc_nodes
  def put_nodes_into_subgraphs(self) -> list[Subgraph]:
      """
      Partition the graph's nodes into alternating acc / non-acc subgraphs.

      Traversal starts from ``starter_nodes()``. A node is ready once every
      node in ``self.deps[node]`` has been placed. Ready nodes of the current
      kind are appended greedily. When none is left, the current subgraph is
      closed and the kind flips.

      After a node is placed, two sets of nodes become candidates:
        * its fusion buddies, queued with the node's own kind;
        * its callable users, each queued by its own kind.

      Raises:
          FxNetSplitterInternalError: if a subgraph would be closed while
              empty, or if no subgraph is produced at all.
      """
      cpu_frontier, acc_frontier = self.starter_nodes()
      # Index the two work queues by "is acc?" so the mode selects one directly.
      frontier = {True: acc_frontier, False: cpu_frontier}
      placed: NodeSet = set()
      # Begin with acc unless some cpu starter has no dependencies at all.
      acc_subgraph: bool = all(len(self.deps[n]) != 0 for n in cpu_frontier)
      pending: NodeList = []
      subgraphs: list[Subgraph] = []

      while frontier[False] or frontier[True]:
          queue = frontier[acc_subgraph]
          ready = next((n for n in queue if self.deps[n] <= placed), None)
          if ready is None:
              # Nothing of this kind can run yet: close the subgraph and flip mode.
              if not pending:
                  raise FxNetSplitterInternalError("Subgraph can't be empty")
              subgraphs.append(Subgraph(is_acc=acc_subgraph, nodes=pending))
              acc_subgraph = not acc_subgraph
              pending = []
              continue

          queue.remove(ready)
          placed.add(ready)
          pending.append(ready)

          # Fusion buddies join the queue of the placed node's own kind.
          if ready in self.fusions:
              frontier[ready in self.acc_nodes].update(self.fusions[ready] - placed)
          # Callable downstream nodes join the queue of their own kind.
          for user in ready.users:
              if user.op in CALLABLE_NODE_OPS:
                  frontier[user in self.acc_nodes].add(user)

      # Flush the last, still-open subgraph.
      if pending:
          subgraphs.append(Subgraph(is_acc=acc_subgraph, nodes=pending))
      if not subgraphs:
          raise FxNetSplitterInternalError("Couldn't create subgraphs")
      return subgraphs
  def remove_small_acc_subgraphs(self, subgraphs: list[Subgraph]) -> list[Subgraph]:
      """
      Fold acc subgraphs smaller than ``settings.min_acc_module_size`` into
      their neighbours.

      A too-small acc subgraph has its nodes appended to the previous entry of
      the result. If it is the very first entry, it is instead relabelled as
      non-acc and kept. Consecutive non-acc subgraphs are coalesced.

      The given Subgraph objects are mutated: their ``nodes`` lists may be
      extended, and ``is_acc`` may be flipped.
      """
      merged: list[Subgraph] = []
      for sg in subgraphs:
          if not sg.is_acc:
              # Coalesce with a preceding non-acc subgraph when there is one.
              if merged and not merged[-1].is_acc:
                  merged[-1].nodes.extend(sg.nodes)
              else:
                  merged.append(sg)
              continue

          threshold = self.settings.min_acc_module_size
          if len(sg.nodes) >= threshold:
              merged.append(sg)
              continue

          print(
              "Eliminating acc subgraph because it's smaller than the threshold: "
              f"{len(sg.nodes)} < {threshold}"
          )
          if not merged:
              # Nothing to merge into yet: keep it, but as a non-acc subgraph.
              sg.is_acc = False
              merged.append(sg)
          else:
              merged[-1].nodes.extend(sg.nodes)
      return merged
  def tag(self, subgraphs: list[Subgraph]):
      """
      Tag every node of every subgraph and record the tags.

      Subgraph ``i`` is named ``_run_on_acc_<i>`` when it is acc, and
      ``<non_acc_submodule_name><i>`` otherwise. The name is used in three
      places:
        * appended to ``self.tags`` (which is reset first);
        * set as ``node.tag``;
        * stored in ``self._node_submodule_map[node.name]``.

      Raises:
          FxNetSplitterInternalError: if a node already has a ``tag``.
      """
      self.tags = []
      for index, subgraph in enumerate(subgraphs):
          if subgraph.is_acc:
              tag_name = f"_run_on_acc_{index}"
          else:
              tag_name = f"{self.non_acc_submodule_name}{index}"
          self.tags.append(tag_name)
          for node in subgraph.nodes:
              if hasattr(node, "tag"):
                  raise FxNetSplitterInternalError(f"Node {node} was already tagged")
              node.tag = tag_name  # type: ignore[attr-defined]
              self._node_submodule_map[node.name] = tag_name
  def split(self, remove_tag: bool = False) -> torch.fx.GraphModule:
      """
      Build the split module from ``self.tags`` via ``split_by_tags``, passing
      ``self._return_tuple`` as ``return_tuple``.

      If ``remove_tag`` is true, the ``tag`` attribute is then deleted from
      every node of ``self.module.graph`` that has one.
      """
      result = split_by_tags(self.module, self.tags, return_tuple=self._return_tuple)
      if remove_tag:
          tagged = [n for n in self.module.graph.nodes if hasattr(n, "tag")]
          for n in tagged:
              del n.tag
      return result  # type: ignore[return-value]
  def __call__(self) -> torch.fx.GraphModule:
      """
      Run the full split pipeline:
        1. build subgraphs;
        2. fold away too-small acc subgraphs;
        3. print an acc / non-acc count;
        4. tag the nodes;
        5. return the split module.
      """
      subgraphs = self.remove_small_acc_subgraphs(self.put_nodes_into_subgraphs())
      acc_count = sum(1 for sg in subgraphs if sg.is_acc)
      print(
          f"Got {acc_count} acc subgraphs and {len(subgraphs) - acc_count} non-acc subgraphs"
      )
      self.tag(subgraphs)
      return self.split()
  def generate_split_results(self) -> SplitResult:
      """
      Split the module and generate sample inputs for each child submodule.

      Returns:
          ``SplitResult(split_module, submodule_inputs, non_acc_submodule_name)``.

      Raises:
          ValueError: if ``settings.max_acc_splits`` is positive and the split
              module has more direct children than that. Every child
              submodule is counted, acc or not.
      """
      split_module = self()
      submodule_names = [name for name, _ in split_module.named_children()]
      limit = self.settings.max_acc_splits
      if limit > 0 and len(submodule_names) > limit:
          raise ValueError(
              "Cannot fulfill max_acc_splits limit. "
              "This may cause split fragmentation and "
              "result in performance issues."
          )
      submodule_inputs = generate_inputs_for_submodules(
          split_module, self.sample_input, submodule_names
      )
      return SplitResult(split_module, submodule_inputs, self.non_acc_submodule_name)