graph_drawer.py 19 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501
  1. # mypy: allow-untyped-defs
  2. import hashlib
  3. from itertools import chain
  4. from types import ModuleType
  5. from typing import Any, Optional, TYPE_CHECKING
  6. import torch
  7. import torch.fx
  8. from torch.fx._compatibility import compatibility
  9. from torch.fx.graph import _parse_stack_trace
  10. from torch.fx.node import _format_arg, _get_qualified_name
  11. from torch.fx.operator_schemas import normalize_function
  12. from torch.fx.passes.shape_prop import TensorMetadata
  13. if TYPE_CHECKING:
  14. import pydot
  15. HAS_PYDOT = True
  16. else:
  17. pydot: Optional[ModuleType]
  18. try:
  19. import pydot
  20. HAS_PYDOT = True
  21. except ModuleNotFoundError:
  22. HAS_PYDOT = False
  23. pydot = None
# The only public name is the drawer class. Depending on HAS_PYDOT, it is either
# the real implementation or a stub whose constructor raises RuntimeError.
__all__ = ["FxGraphDrawer"]
  25. _COLOR_MAP = {
  26. "placeholder": '"AliceBlue"',
  27. "call_module": "LemonChiffon1",
  28. "get_param": "Yellow2",
  29. "get_attr": "LightGrey",
  30. "output": "PowderBlue",
  31. }
  32. _HASH_COLOR_MAP = [
  33. "CadetBlue1",
  34. "Coral",
  35. "DarkOliveGreen1",
  36. "DarkSeaGreen1",
  37. "GhostWhite",
  38. "Khaki1",
  39. "LavenderBlush1",
  40. "LightSkyBlue",
  41. "MistyRose1",
  42. "MistyRose2",
  43. "PaleTurquoise2",
  44. "PeachPuff1",
  45. "Salmon",
  46. "Thistle1",
  47. "Thistle3",
  48. "Wheat1",
  49. ]
  50. _WEIGHT_TEMPLATE = {
  51. "fillcolor": "Salmon",
  52. "style": '"filled,rounded"',
  53. "fontcolor": "#000000",
  54. }
  55. if HAS_PYDOT:
  56. @compatibility(is_backward_compatible=False)
  57. class FxGraphDrawer:
  58. """
  59. Visualize a torch.fx.Graph with graphviz
  60. Basic usage:
  61. g = FxGraphDrawer(symbolic_traced, "resnet18")
  62. g.get_dot_graph().write_svg("a.svg")
  63. """
  64. def __init__(
  65. self,
  66. graph_module: torch.fx.GraphModule,
  67. name: str,
  68. ignore_getattr: bool = False,
  69. ignore_parameters_and_buffers: bool = False,
  70. skip_node_names_in_args: bool = True,
  71. parse_stack_trace: bool = False,
  72. dot_graph_shape: Optional[str] = None,
  73. normalize_args: bool = False,
  74. ):
  75. self._name = name
  76. self.dot_graph_shape = (
  77. dot_graph_shape if dot_graph_shape is not None else "record"
  78. )
  79. self.normalize_args = normalize_args
  80. _WEIGHT_TEMPLATE["shape"] = self.dot_graph_shape
  81. self._dot_graphs = {
  82. name: self._to_dot(
  83. graph_module,
  84. name,
  85. ignore_getattr,
  86. ignore_parameters_and_buffers,
  87. skip_node_names_in_args,
  88. parse_stack_trace,
  89. )
  90. }
  91. for node in graph_module.graph.nodes:
  92. if node.op != "call_module":
  93. continue
  94. leaf_node = self._get_leaf_node(graph_module, node)
  95. if not isinstance(leaf_node, torch.fx.GraphModule):
  96. continue
  97. self._dot_graphs[f"{name}_{node.target}"] = self._to_dot(
  98. leaf_node,
  99. f"{name}_{node.target}",
  100. ignore_getattr,
  101. ignore_parameters_and_buffers,
  102. skip_node_names_in_args,
  103. parse_stack_trace,
  104. )
  105. def get_dot_graph(self, submod_name=None) -> pydot.Dot:
  106. """
  107. Visualize a torch.fx.Graph with graphviz
  108. Example:
  109. >>> # xdoctest: +REQUIRES(module:pydot)
  110. >>> # xdoctest: +REQUIRES(module:ubelt)
  111. >>> # define module
  112. >>> class MyModule(torch.nn.Module):
  113. >>> def __init__(self) -> None:
  114. >>> super().__init__()
  115. >>> self.linear = torch.nn.Linear(4, 5)
  116. >>> def forward(self, x):
  117. >>> return self.linear(x).clamp(min=0.0, max=1.0)
  118. >>> module = MyModule()
  119. >>> # trace the module
  120. >>> symbolic_traced = torch.fx.symbolic_trace(module)
  121. >>> # setup output file
  122. >>> import ubelt as ub
  123. >>> dpath = ub.Path.appdir("torch/tests/FxGraphDrawer").ensuredir()
  124. >>> fpath = dpath / "linear.svg"
  125. >>> # draw the graph
  126. >>> g = FxGraphDrawer(symbolic_traced, "linear")
  127. >>> g.get_dot_graph().write_svg(fpath)
  128. """
  129. if submod_name is None:
  130. return self.get_main_dot_graph()
  131. else:
  132. return self.get_submod_dot_graph(submod_name)
  133. def get_main_dot_graph(self) -> pydot.Dot:
  134. return self._dot_graphs[self._name]
  135. def get_submod_dot_graph(self, submod_name) -> pydot.Dot:
  136. return self._dot_graphs[f"{self._name}_{submod_name}"]
  137. def get_all_dot_graphs(self) -> dict[str, pydot.Dot]:
  138. return self._dot_graphs
  139. def _get_node_style(self, node: torch.fx.Node) -> dict[str, str]:
  140. template = {
  141. "shape": self.dot_graph_shape,
  142. "fillcolor": "#CAFFE3",
  143. "style": '"filled,rounded"',
  144. "fontcolor": "#000000",
  145. }
  146. if node.op in _COLOR_MAP:
  147. template["fillcolor"] = _COLOR_MAP[node.op]
  148. else:
  149. # Use a random color for each node; based on its name so it's stable.
  150. target_name = node._pretty_print_target(node.target)
  151. target_hash = int(
  152. hashlib.md5(
  153. target_name.encode(), usedforsecurity=False
  154. ).hexdigest()[:8],
  155. 16,
  156. )
  157. template["fillcolor"] = _HASH_COLOR_MAP[
  158. target_hash % len(_HASH_COLOR_MAP)
  159. ]
  160. return template
  161. def _get_leaf_node(
  162. self, module: torch.nn.Module, node: torch.fx.Node
  163. ) -> torch.nn.Module:
  164. py_obj = module
  165. assert isinstance(node.target, str)
  166. atoms = node.target.split(".")
  167. for atom in atoms:
  168. if not hasattr(py_obj, atom):
  169. raise RuntimeError(
  170. str(py_obj) + " does not have attribute " + atom + "!"
  171. )
  172. py_obj = getattr(py_obj, atom)
  173. return py_obj
  174. def _typename(self, target: Any) -> str:
  175. if isinstance(target, torch.nn.Module):
  176. ret = torch.typename(target)
  177. elif isinstance(target, str):
  178. ret = target
  179. else:
  180. ret = _get_qualified_name(target)
  181. # Escape "{" and "}" to prevent dot files like:
  182. # https://gist.github.com/SungMinCho/1a017aab662c75d805c5954d62c5aabc
  183. # which triggers `Error: bad label format (...)` from dot
  184. return ret.replace("{", r"\{").replace("}", r"\}")
  185. # shorten path to avoid drawing long boxes
  186. # for full path = '/home/weif/pytorch/test.py'
  187. # return short path = 'pytorch/test.py'
  188. def _shorten_file_name(
  189. self,
  190. full_file_name: str,
  191. truncate_to_last_n: int = 2,
  192. ):
  193. splits = full_file_name.split("/")
  194. if len(splits) >= truncate_to_last_n:
  195. return "/".join(splits[-truncate_to_last_n:])
  196. return full_file_name
  197. def _get_node_label(
  198. self,
  199. module: torch.fx.GraphModule,
  200. node: torch.fx.Node,
  201. skip_node_names_in_args: bool,
  202. parse_stack_trace: bool,
  203. ) -> str:
  204. def _get_str_for_args_kwargs(arg):
  205. if isinstance(arg, tuple):
  206. prefix, suffix = r"|args=(\l", r",\n)\l"
  207. arg_strs_list = [_format_arg(a, max_list_len=8) for a in arg]
  208. elif isinstance(arg, dict):
  209. prefix, suffix = r"|kwargs={\l", r",\n}\l"
  210. arg_strs_list = [
  211. f"{k}: {_format_arg(v, max_list_len=8)}" for k, v in arg.items()
  212. ]
  213. else: # Fall back to nothing in unexpected case.
  214. return ""
  215. # Strip out node names if requested.
  216. if skip_node_names_in_args:
  217. arg_strs_list = [a for a in arg_strs_list if "%" not in a]
  218. if len(arg_strs_list) == 0:
  219. return ""
  220. arg_strs = prefix + r",\n".join(arg_strs_list) + suffix
  221. if len(arg_strs_list) == 1:
  222. arg_strs = arg_strs.replace(r"\l", "").replace(r"\n", "")
  223. return arg_strs.replace("{", r"\{").replace("}", r"\}")
  224. label = "{" + f"name=%{node.name}|op_code={node.op}\n"
  225. if node.op == "call_module":
  226. leaf_module = self._get_leaf_node(module, node)
  227. label += r"\n" + self._typename(leaf_module) + r"\n|"
  228. extra = ""
  229. if hasattr(leaf_module, "__constants__"):
  230. extra = r"\n".join(
  231. [
  232. f"{c}: {getattr(leaf_module, c)}"
  233. for c in leaf_module.__constants__ # type: ignore[union-attr]
  234. ] # type: ignore[union-attr]
  235. )
  236. label += extra + r"\n"
  237. else:
  238. label += f"|target={self._typename(node.target)}" + r"\n"
  239. if self.normalize_args:
  240. try:
  241. args, kwargs = normalize_function( # type: ignore[misc]
  242. node.target, # type: ignore[arg-type]
  243. node.args, # type: ignore[arg-type]
  244. node.kwargs,
  245. normalize_to_only_use_kwargs=True,
  246. )
  247. except Exception:
  248. # Fallback to not normalizing if there's an exception.
  249. # Some functions need overloads specified to normalize.
  250. args, kwargs = node.args, node.kwargs
  251. else:
  252. args, kwargs = node.args, node.kwargs
  253. if len(args) > 0:
  254. label += _get_str_for_args_kwargs(args)
  255. if len(kwargs) > 0:
  256. label += _get_str_for_args_kwargs(kwargs)
  257. label += f"|num_users={len(node.users)}" + r"\n"
  258. tensor_meta = node.meta.get("tensor_meta")
  259. label += self._tensor_meta_to_label(tensor_meta)
  260. # for original fx graph
  261. # print buf=buf0, n_origin=6
  262. buf_meta = node.meta.get("buf_meta", None)
  263. if buf_meta is not None:
  264. label += f"|buf={buf_meta.name}" + r"\n"
  265. label += f"|n_origin={buf_meta.n_origin}" + r"\n"
  266. # for original fx graph
  267. # print file:lineno code
  268. if parse_stack_trace and node.stack_trace is not None:
  269. parsed_stack_trace = _parse_stack_trace(node.stack_trace)
  270. fname = self._shorten_file_name(parsed_stack_trace.file)
  271. label += (
  272. f"|file={fname}:{parsed_stack_trace.lineno} {parsed_stack_trace.code}"
  273. + r"\n"
  274. )
  275. return label + "}"
  276. def _tensor_meta_to_label(self, tm) -> str:
  277. if tm is None:
  278. return ""
  279. elif isinstance(tm, TensorMetadata):
  280. return self._stringify_tensor_meta(tm)
  281. elif isinstance(tm, list):
  282. result = ""
  283. for item in tm:
  284. result += self._tensor_meta_to_label(item)
  285. return result
  286. elif isinstance(tm, dict):
  287. result = ""
  288. for v in tm.values():
  289. result += self._tensor_meta_to_label(v)
  290. return result
  291. elif isinstance(tm, tuple):
  292. result = ""
  293. for item in tm:
  294. result += self._tensor_meta_to_label(item)
  295. return result
  296. else:
  297. raise RuntimeError(f"Unsupported tensor meta type {type(tm)}")
  298. def _stringify_tensor_meta(self, tm: TensorMetadata) -> str:
  299. result = ""
  300. if not hasattr(tm, "dtype"):
  301. print("tm", tm)
  302. result += "|" + "dtype" + "=" + str(tm.dtype) + r"\n"
  303. result += "|" + "shape" + "=" + str(tuple(tm.shape)) + r"\n"
  304. result += "|" + "requires_grad" + "=" + str(tm.requires_grad) + r"\n"
  305. result += "|" + "stride" + "=" + str(tm.stride) + r"\n"
  306. if tm.is_quantized:
  307. assert tm.qparams is not None
  308. assert "qscheme" in tm.qparams
  309. qscheme = tm.qparams["qscheme"]
  310. if qscheme in {
  311. torch.per_tensor_affine,
  312. torch.per_tensor_symmetric,
  313. }:
  314. result += "|" + "q_scale" + "=" + str(tm.qparams["scale"]) + r"\n"
  315. result += (
  316. "|"
  317. + "q_zero_point"
  318. + "="
  319. + str(tm.qparams["zero_point"])
  320. + r"\n"
  321. )
  322. elif qscheme in {
  323. torch.per_channel_affine,
  324. torch.per_channel_symmetric,
  325. torch.per_channel_affine_float_qparams,
  326. }:
  327. result += (
  328. "|"
  329. + "q_per_channel_scale"
  330. + "="
  331. + str(tm.qparams["scale"])
  332. + r"\n"
  333. )
  334. result += (
  335. "|"
  336. + "q_per_channel_zero_point"
  337. + "="
  338. + str(tm.qparams["zero_point"])
  339. + r"\n"
  340. )
  341. result += (
  342. "|"
  343. + "q_per_channel_axis"
  344. + "="
  345. + str(tm.qparams["axis"])
  346. + r"\n"
  347. )
  348. else:
  349. raise RuntimeError(f"Unsupported qscheme: {qscheme}")
  350. result += "|" + "qscheme" + "=" + str(tm.qparams["qscheme"]) + r"\n"
  351. return result
  352. def _get_tensor_label(self, t: torch.Tensor) -> str:
  353. return str(t.dtype) + str(list(t.shape)) + r"\n"
  354. # when parse_stack_trace=True
  355. # print file:lineno code
  356. def _to_dot(
  357. self,
  358. graph_module: torch.fx.GraphModule,
  359. name: str,
  360. ignore_getattr: bool,
  361. ignore_parameters_and_buffers: bool,
  362. skip_node_names_in_args: bool,
  363. parse_stack_trace: bool,
  364. ) -> pydot.Dot:
  365. """
  366. Actual interface to visualize a fx.Graph. Note that it takes in the GraphModule instead of the Graph.
  367. If ignore_parameters_and_buffers is True, the parameters and buffers
  368. created with the module will not be added as nodes and edges.
  369. """
  370. # "TB" means top-to-bottom rank direction in layout
  371. dot_graph = pydot.Dot(name, rankdir="TB")
  372. buf_name_to_subgraph = {}
  373. for node in graph_module.graph.nodes:
  374. if ignore_getattr and node.op == "get_attr":
  375. continue
  376. style = self._get_node_style(node)
  377. dot_node = pydot.Node(
  378. node.name,
  379. label=self._get_node_label(
  380. graph_module, node, skip_node_names_in_args, parse_stack_trace
  381. ),
  382. **style, # type: ignore[arg-type]
  383. )
  384. current_graph = dot_graph
  385. buf_meta = node.meta.get("buf_meta", None)
  386. if buf_meta is not None and buf_meta.n_origin > 1:
  387. buf_name = buf_meta.name
  388. if buf_name not in buf_name_to_subgraph:
  389. buf_name_to_subgraph[buf_name] = pydot.Cluster(
  390. buf_name, label=buf_name
  391. )
  392. current_graph = buf_name_to_subgraph.get(buf_name) # type: ignore[assignment]
  393. current_graph.add_node(dot_node)
  394. def get_module_params_or_buffers():
  395. for pname, ptensor in chain(
  396. leaf_module.named_parameters(), leaf_module.named_buffers()
  397. ):
  398. pname1 = node.name + "." + pname
  399. label1 = (
  400. pname1 + "|op_code=get_" + "parameter"
  401. if isinstance(ptensor, torch.nn.Parameter)
  402. else "buffer" + r"\l"
  403. )
  404. dot_w_node = pydot.Node(
  405. pname1,
  406. label="{" + label1 + self._get_tensor_label(ptensor) + "}",
  407. **_WEIGHT_TEMPLATE, # type: ignore[arg-type]
  408. )
  409. dot_graph.add_node(dot_w_node)
  410. dot_graph.add_edge(pydot.Edge(pname1, node.name))
  411. if node.op == "call_module":
  412. leaf_module = self._get_leaf_node(graph_module, node)
  413. if not ignore_parameters_and_buffers and not isinstance(
  414. leaf_module, torch.fx.GraphModule
  415. ):
  416. get_module_params_or_buffers()
  417. for subgraph in buf_name_to_subgraph.values():
  418. subgraph.set("color", "royalblue")
  419. subgraph.set("penwidth", "2")
  420. dot_graph.add_subgraph(subgraph) # type: ignore[arg-type]
  421. for node in graph_module.graph.nodes:
  422. if ignore_getattr and node.op == "get_attr":
  423. continue
  424. for user in node.users:
  425. dot_graph.add_edge(pydot.Edge(node.name, user.name))
  426. return dot_graph
  427. else:
  428. if not TYPE_CHECKING:
  429. @compatibility(is_backward_compatible=False)
  430. class FxGraphDrawer:
  431. def __init__(
  432. self,
  433. graph_module: torch.fx.GraphModule,
  434. name: str,
  435. ignore_getattr: bool = False,
  436. ignore_parameters_and_buffers: bool = False,
  437. skip_node_names_in_args: bool = True,
  438. parse_stack_trace: bool = False,
  439. dot_graph_shape: Optional[str] = None,
  440. normalize_args: bool = False,
  441. ):
  442. raise RuntimeError(
  443. "FXGraphDrawer requires the pydot package to be installed. Please install "
  444. "pydot through your favorite Python package manager."
  445. )