tools_common.py 11 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319
  1. # mypy: allow-untyped-defs
  2. import collections
  3. import operator
  4. from collections.abc import Mapping
  5. from dataclasses import dataclass
  6. from typing import Any, Optional, Union
  7. import torch
  8. import torch.fx
  9. from torch.fx._compatibility import compatibility
  10. from torch.fx.node import _get_qualified_name
  11. __all__ = [
  12. "get_acc_ops_name",
  13. "get_node_target",
  14. "is_node_output_tensor",
  15. "FxNetAccFusionsFinder",
  16. "legalize_graph",
  17. ]
  18. Tensors = Union[tuple[torch.Tensor], list[torch.Tensor]]
  19. TensorOrTensors = Union[torch.Tensor, Tensors]
  20. NodeList = list[torch.fx.Node]
  21. NodeSet = set[torch.fx.Node]
  22. Names = list[str]
  23. CALLABLE_NODE_OPS = {"call_module", "call_function", "call_method"}
  24. @compatibility(is_backward_compatible=False)
  25. def get_acc_ops_name(k):
  26. if isinstance(k, str):
  27. return k
  28. elif k.__module__ and "acc_ops" in k.__module__:
  29. return f"acc_ops.{k.__name__}"
  30. else:
  31. module = k.__module__.replace(
  32. "torch._ops", "torch.ops"
  33. ) # WAR for bug in how torch.ops assigns module
  34. return f"{module if module else ''}.{k.__name__}"
  35. @compatibility(is_backward_compatible=False)
  36. def get_node_target(
  37. submodules: Mapping[str, torch.nn.Module], node: torch.fx.Node
  38. ) -> str:
  39. """
  40. Given a `node` returns its target typename.
  41. For "call_method" node, return node.target which is the name of that method being called.
  42. This could potential lead to conflict but should be okay because normally it's on a tensor.
  43. For "call_function" node, return typename of node.target.
  44. For "call_module" node, return typename of the module that node.target point to.
  45. If seeing "_VariableFunctionsClass" in the target name string, it will be replaced by
  46. "torch". e.g. _VariableFunctionsClass.relu would become torch.relu.
  47. """
  48. assert node.op in CALLABLE_NODE_OPS, (
  49. "Expect op types of " + ", ".join(CALLABLE_NODE_OPS) + f", but found {node.op}"
  50. )
  51. if node.op == "call_module":
  52. assert isinstance(node.target, str)
  53. submod = submodules[node.target]
  54. submod_type = getattr(submod, "_base_class_origin", type(submod))
  55. return get_acc_ops_name(submod_type)
  56. elif node.op == "call_function":
  57. target: Any = node.target
  58. return (
  59. f"acc_ops.{target.__name__}"
  60. if target.__module__ is not None and "acc_ops" in target.__module__
  61. else _get_qualified_name(target)
  62. )
  63. else:
  64. assert isinstance(node.target, str)
  65. return node.target
  66. @compatibility(is_backward_compatible=False)
  67. def is_node_output_tensor(node: torch.fx.Node) -> bool:
  68. """Checks if the node output produces a Tensor or not.
  69. NOTE: This requires to run `ShapeProp` on the containing fx graph before
  70. calling this function. This is because it works by checking the `type`
  71. metadata on the node. This metadata is produced by the `ShapeProp`.
  72. """
  73. type_ = node.meta.get("type", None)
  74. return type_ is not None and issubclass(type_, torch.Tensor)
  75. @compatibility(is_backward_compatible=False)
  76. class FxNetAccFusionsFinder:
  77. """
  78. Finds groups of connected ACC nodes that pass non-tensor data between each other.
  79. Such groups are called fusion groups.
  80. """
  81. def __init__(self, module: torch.fx.GraphModule, acc_nodes: NodeSet):
  82. self.module = module
  83. self.nodes = list(module.graph.nodes)
  84. self.acc_nodes = acc_nodes
  85. @dataclass
  86. class FusionGroup:
  87. # The smallest idx of nodes in the fusion group after topological sorting all the nodes in the model.
  88. top_node_idx: int
  89. # Nodes in this fusion group.
  90. nodes: NodeSet
  91. # Inputs to this fusion group.
  92. inputs: NodeSet
  93. # Nodes that in the fusion group that haven't been processed yet.
  94. nodes_need_process: NodeSet
  95. def add_node(self, node):
  96. """
  97. Add a node to fusion group.
  98. """
  99. if node in self.nodes:
  100. return
  101. self.nodes_need_process.add(node)
  102. self.nodes.add(node)
  103. self.inputs.discard(node)
  104. self.inputs.update(
  105. {
  106. n
  107. for n in node.all_input_nodes
  108. if n.op in CALLABLE_NODE_OPS and n not in self.nodes
  109. }
  110. )
  111. def recursive_add_node(
  112. self,
  113. fusion_group: "FxNetAccFusionsFinder.FusionGroup",
  114. inputs: Union[NodeSet, NodeList],
  115. visited: Optional[NodeSet] = None,
  116. ):
  117. """
  118. Start from inputs and going reverse topological order. If any upstream node
  119. is in the fusion group, add all the nodes in this path to fusion group.
  120. """
  121. for arg in inputs:
  122. # skip the node if already seen
  123. if visited is not None:
  124. if arg in visited:
  125. continue
  126. visited.add(arg)
  127. # Skip placeholder and get_attr because they won't be in the fusion group.
  128. if arg.op not in CALLABLE_NODE_OPS:
  129. continue
  130. # If the node has smaller idx, it's already an upstream node of the fusion
  131. # group. We don't need to check it anymore.
  132. if self.nodes.index(arg) < fusion_group.top_node_idx:
  133. continue
  134. # If the node is in the fusion group, return True.
  135. if arg in fusion_group.nodes:
  136. return True
  137. # Check the upstream nodes of the node, if any of them is in the fusion group
  138. # we'll add this node to fusion group and return True.
  139. if self.recursive_add_node(fusion_group, arg.all_input_nodes, visited):
  140. fusion_group.add_node(arg)
  141. return True
  142. return False
  143. def __call__(self) -> dict[torch.fx.Node, NodeSet]:
  144. result: dict[torch.fx.Node, NodeSet] = {}
  145. acc_nodes = list(self.acc_nodes)
  146. for node in acc_nodes:
  147. if node in result:
  148. continue
  149. if node.op not in CALLABLE_NODE_OPS:
  150. continue
  151. if "tensor_meta" in node.meta:
  152. continue
  153. if node not in self.acc_nodes:
  154. continue
  155. fusion_group: FxNetAccFusionsFinder.FusionGroup = self.FusionGroup(
  156. top_node_idx=self.nodes.index(node),
  157. nodes={node},
  158. inputs=set(node.all_input_nodes),
  159. nodes_need_process={node},
  160. )
  161. while fusion_group.nodes_need_process:
  162. node = fusion_group.nodes_need_process.pop()
  163. self.recursive_add_node(
  164. fusion_group,
  165. fusion_group.inputs,
  166. visited=set(),
  167. )
  168. # Optionally add downstream nodes
  169. if "tensor_meta" not in node.meta:
  170. for user in node.users:
  171. if user.op not in CALLABLE_NODE_OPS:
  172. continue
  173. if user in fusion_group.nodes:
  174. continue
  175. fusion_group.add_node(user)
  176. self.recursive_add_node(
  177. fusion_group,
  178. fusion_group.inputs,
  179. visited=set(),
  180. )
  181. # Add some upstream nodes
  182. for arg in node.all_input_nodes:
  183. if arg.op not in CALLABLE_NODE_OPS:
  184. continue
  185. if "tensor_meta" in arg.meta:
  186. continue
  187. if arg in fusion_group.nodes:
  188. continue
  189. fusion_group.add_node(arg)
  190. fusion_group.top_node_idx = min(
  191. fusion_group.top_node_idx, self.nodes.index(arg)
  192. )
  193. self.recursive_add_node(
  194. fusion_group,
  195. fusion_group.inputs,
  196. visited=set(),
  197. )
  198. if not (set(fusion_group.nodes) <= self.acc_nodes):
  199. self.acc_nodes -= fusion_group.nodes
  200. else:
  201. for n in fusion_group.nodes:
  202. result[n] = fusion_group.nodes
  203. return result
  204. @compatibility(is_backward_compatible=False)
  205. def legalize_graph(gm: torch.fx.GraphModule) -> torch.fx.GraphModule:
  206. """
  207. Replace the graph of the given GraphModule with one that contains the same nodes as the
  208. original, but in topologically sorted order.
  209. This is used by the merge_matmul transformation below, which disturbs the topologically sorted
  210. order of its input GraphModule, so that this order is restored before further transformation.
  211. Arguments:
  212. gm: The graph module to topologically sort. It is modified in-place.
  213. Returns:
  214. The graph module in-place sorted
  215. """
  216. # These operators are used for making runtime assertions before any
  217. # data-dependent operators occur. We want to prioritize sorting these to
  218. # ensure that these assertions appear before any data-dependent operations
  219. # in the graph.
  220. PRIORITIZED_OPS = [
  221. operator.add,
  222. operator.mul,
  223. operator.sub,
  224. operator.floordiv,
  225. operator.truediv,
  226. operator.mod,
  227. operator.le,
  228. operator.lt,
  229. operator.ge,
  230. operator.gt,
  231. operator.eq,
  232. operator.ne,
  233. torch.ops.aten.sym_constrain_range.default,
  234. torch.ops.aten.sym_constrain_range_for_size.default,
  235. torch.ops.aten._assert_async.msg,
  236. torch.ops.aten.scalar_tensor.default,
  237. torch.ops.aten._assert_scalar.default,
  238. ]
  239. indeg = dict.fromkeys(gm.graph.nodes, 0)
  240. new_graph = torch.fx.Graph()
  241. # Track how many unfulfilled dependencies each node has
  242. for node in gm.graph.nodes:
  243. for user in node.users:
  244. indeg[user] += 1
  245. queue: collections.deque = collections.deque()
  246. # Add all nodes with no dependencies to the queue
  247. for node in gm.graph.nodes:
  248. if indeg[node] == 0:
  249. queue.append(node)
  250. env: dict[torch.fx.Node, torch.fx.Node] = {}
  251. # Pop nodes from the queue, and add nodes that have had all their
  252. # dependencies fulfilled
  253. while len(queue) > 0:
  254. cur = queue.popleft()
  255. env[cur] = new_graph.node_copy(cur, lambda x: env[x])
  256. for user in cur.users:
  257. indeg[user] -= 1
  258. if indeg[user] == 0:
  259. if user.op == "call_function" and user.target in PRIORITIZED_OPS:
  260. queue.appendleft(user)
  261. else:
  262. queue.append(user)
  263. # If the new graph's size is not as large as the old one, then there must be
  264. # a cycle (i.e. some node's dependencies were not satisfied.)
  265. if len(new_graph.nodes) < len(gm.graph.nodes):
  266. raise RuntimeError(
  267. f"Input graph has cycles, unable to add {[node for node in indeg if indeg[node] != 0]}"
  268. )
  269. new_graph._codegen = gm.graph._codegen
  270. gm.graph = new_graph
  271. return gm