optimization.py 17 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486
  1. # mypy: allow-untyped-defs
  2. import copy
  3. import logging
  4. import operator
  5. import time
  6. from collections import defaultdict
  7. from collections.abc import Iterable
  8. from enum import Enum
  9. from typing import Any, cast, Optional
  10. import torch
  11. import torch.fx as fx
  12. import torch.nn as nn
  13. import torch.nn.functional as F
  14. import torch.utils.mkldnn as th_mkldnn
  15. from torch.fx.node import Argument, Target
  16. from torch.fx.passes.shape_prop import ShapeProp
  17. from torch.nn.utils.fusion import fuse_conv_bn_eval, fuse_linear_bn_eval
  18. __all__ = [
  19. "matches_module_pattern",
  20. "replace_node_module",
  21. "fuse",
  22. "remove_dropout",
  23. "extract_subgraph",
  24. "modules_to_mkldnn",
  25. "reset_modules",
  26. "MklSubgraph",
  27. "gen_mkl_autotuner",
  28. "use_mkl_length",
  29. "UnionFind",
  30. "optimize_for_inference",
  31. ]
  32. def _parent_name(target: str) -> tuple[str, str]:
  33. """
  34. Splits a qualname into parent path and last atom.
  35. For example, `foo.bar.baz` -> (`foo.bar`, `baz`)
  36. """
  37. *parent, name = target.rsplit(".", 1)
  38. return parent[0] if parent else "", name
  39. # Works for length 2 patterns with 2 modules
  40. def matches_module_pattern(
  41. pattern: Iterable[type], node: fx.Node, modules: dict[str, Any]
  42. ):
  43. if len(node.args) == 0:
  44. return False
  45. nodes: tuple[Any, fx.Node] = (node.args[0], node)
  46. for expected_type, current_node in zip(pattern, nodes):
  47. if not isinstance(current_node, fx.Node):
  48. return False
  49. if current_node.op != "call_module":
  50. return False
  51. if not isinstance(current_node.target, str):
  52. return False
  53. if current_node.target not in modules:
  54. return False
  55. if type(modules[current_node.target]) is not expected_type:
  56. return False
  57. return True
  58. def replace_node_module(
  59. node: fx.Node, modules: dict[str, Any], new_module: torch.nn.Module
  60. ):
  61. assert isinstance(node.target, str)
  62. parent_name, name = _parent_name(node.target)
  63. modules[node.target] = new_module
  64. setattr(modules[parent_name], name, new_module)
  65. def fuse(model: torch.nn.Module, inplace=False, no_trace=False) -> torch.nn.Module:
  66. """
  67. Fuses convolution/BN and linear/BN layers for inference purposes.
  68. Will deepcopy your model by default, but can modify the model inplace as well.
  69. """
  70. patterns = [
  71. (nn.Conv1d, nn.BatchNorm1d),
  72. (nn.Conv2d, nn.BatchNorm2d),
  73. (nn.Conv3d, nn.BatchNorm3d),
  74. (nn.Linear, nn.BatchNorm1d),
  75. ]
  76. if not inplace:
  77. model = copy.deepcopy(model)
  78. if not no_trace or not isinstance(model, torch.fx.GraphModule):
  79. fx_model = fx.symbolic_trace(model)
  80. else:
  81. fx_model = model
  82. modules = dict(fx_model.named_modules())
  83. new_graph = copy.deepcopy(fx_model.graph)
  84. for pattern in patterns:
  85. for node in new_graph.nodes:
  86. if matches_module_pattern(pattern, node, modules):
  87. if len(node.args[0].users) > 1:
  88. # Output of conv/linear is used by other nodes
  89. continue
  90. first_layer = modules[node.args[0].target]
  91. bn = modules[node.target]
  92. if not bn.track_running_stats:
  93. continue
  94. if pattern[0] in [nn.Conv1d, nn.Conv2d, nn.Conv3d]:
  95. fused_layer = fuse_conv_bn_eval(first_layer, bn)
  96. else: # nn.Linear
  97. fused_layer = fuse_linear_bn_eval(first_layer, bn)
  98. replace_node_module(node.args[0], modules, fused_layer)
  99. node.replace_all_uses_with(node.args[0])
  100. new_graph.erase_node(node)
  101. return fx.GraphModule(fx_model, new_graph)
  102. def remove_dropout(model: nn.Module) -> nn.Module:
  103. """
  104. Removes all dropout layers from the module.
  105. """
  106. fx_model = fx.symbolic_trace(model)
  107. class DropoutRemover(torch.fx.Transformer):
  108. def call_module(
  109. self, target: Target, args: tuple[Argument, ...], kwargs: dict[str, Any]
  110. ) -> Any:
  111. if isinstance(self.submodules[target], nn.Dropout):
  112. assert len(args) == 1
  113. return args[0]
  114. else:
  115. return super().call_module(target, args, kwargs)
  116. return DropoutRemover(fx_model).transform()
  117. def extract_subgraph(
  118. orig_module: nn.Module,
  119. nodes: list[fx.Node],
  120. inputs: list[fx.Node],
  121. outputs: list[fx.Node],
  122. ):
  123. """
  124. Given lists of nodes from an existing graph that represent a subgraph, returns a submodule that executes that subgraph.
  125. """
  126. new_graph = fx.Graph()
  127. env: dict[fx.Node, fx.Node] = {}
  128. for input in inputs:
  129. new_node = new_graph.placeholder(input.name)
  130. env[input] = new_node
  131. for node in nodes:
  132. new_node = new_graph.node_copy(node, lambda x: env[x])
  133. env[node] = new_node
  134. new_graph.output([env[output] for output in outputs])
  135. new_graph.lint()
  136. return fx.GraphModule(orig_module, new_graph)
  137. mkldnn_supported = [
  138. nn.Conv2d,
  139. nn.Linear,
  140. nn.BatchNorm2d,
  141. nn.ReLU,
  142. nn.MaxPool2d,
  143. nn.AvgPool2d,
  144. nn.AdaptiveAvgPool2d,
  145. torch.relu,
  146. torch.transpose,
  147. torch.sigmoid,
  148. F.relu,
  149. F.avg_pool2d,
  150. F.adaptive_avg_pool2d,
  151. ]
  152. # These are operators that may not be convertible into MKLDNN ops (e.g. the
  153. # args are scalar values). Thus, we only include them in the subgraph if their
  154. # arguments are already in MKLDNN.
  155. # TODO: Determine whether this can be removed after type inference.
  156. mkldnn_supported_unknown = [operator.add, operator.mul]
  157. mkldnn_map = {
  158. nn.Conv2d: th_mkldnn.MkldnnConv2d,
  159. nn.Linear: th_mkldnn.MkldnnLinear,
  160. nn.BatchNorm2d: lambda a, _: th_mkldnn.MkldnnBatchNorm(a),
  161. }
  162. def modules_to_mkldnn(nodes: list[fx.Node], modules: dict[str, nn.Module]):
  163. """
  164. For each node, if it's a module that can be preconverted into MKLDNN,
  165. then we do so and create a mapping to allow us to convert from the MKLDNN
  166. version of the module to the original.
  167. """
  168. old_modules: dict[nn.Module, nn.Module] = {}
  169. for node in nodes:
  170. if node.op == "call_module":
  171. assert isinstance(node.target, str)
  172. cur_module = modules[node.target]
  173. if type(cur_module) in mkldnn_map:
  174. new_module = mkldnn_map[type(cur_module)](cur_module, torch.float)
  175. assert isinstance(new_module, nn.Module)
  176. old_modules[new_module] = copy.deepcopy(cur_module)
  177. replace_node_module(node, modules, new_module)
  178. return old_modules
  179. def reset_modules(
  180. nodes: list[fx.Node],
  181. modules: dict[str, nn.Module],
  182. old_modules: dict[nn.Module, nn.Module],
  183. ):
  184. """
  185. Maps each module that's been changed with `modules_to_mkldnn` back to its
  186. original.
  187. """
  188. for node in nodes:
  189. if node.op == "call_module":
  190. assert isinstance(node.target, str)
  191. cur_module = modules[node.target]
  192. if cur_module in old_modules:
  193. replace_node_module(node, modules, old_modules[cur_module])
  194. class MklSubgraph:
  195. def __init__(self, fx_graph: fx.Graph):
  196. self.fx_graph = fx_graph
  197. self.nodes: list[fx.Node] = []
  198. self.start_nodes: list[fx.Node] = []
  199. self.end_nodes: list[fx.Node] = []
  200. def gen_mkl_autotuner(example_inputs, iters=10, warmup=1):
  201. """
  202. This generates a heuristic that can be passed into `optimize_for_inference` that
  203. determines whether a subgraph should be run in MKL by running it with the example_inputs.
  204. Example usage:
  205. heuristic = gen_mkl_autotuner(example_inputs, iters=10)
  206. fast_model = optimization.optimize_for_inference(model, heuristic)
  207. """
  208. fx_model = None
  209. old_modules = None
  210. def use_mkl_heuristic(graph: MklSubgraph) -> bool:
  211. nonlocal fx_model, old_modules
  212. input_nodes = graph.start_nodes
  213. if fx_model is None:
  214. fx_model = graph.fx_graph.owning_module
  215. old_modules = graph.fx_graph.old_modules # type: ignore[attr-defined]
  216. ShapeProp(fx_model).propagate(example_inputs)
  217. sample_inputs = [torch.randn(node.shape) for node in input_nodes] # type: ignore[attr-defined]
  218. output_args = cast(list[fx.Node], [node.args[0] for node in graph.end_nodes])
  219. submodule = extract_subgraph(fx_model, graph.nodes, input_nodes, output_args)
  220. def benchmark(f):
  221. for _ in range(warmup):
  222. f()
  223. begin = time.time()
  224. for _ in range(iters):
  225. f()
  226. return time.time() - begin
  227. mkl_time = benchmark(
  228. lambda: [
  229. i.to_dense() for i in submodule(*[i.to_mkldnn() for i in sample_inputs])
  230. ]
  231. )
  232. reset_modules(
  233. submodule.graph.nodes, dict(submodule.named_modules()), old_modules
  234. )
  235. no_mkl_time = benchmark(lambda: submodule(*sample_inputs))
  236. return mkl_time < no_mkl_time
  237. return use_mkl_heuristic
  238. def use_mkl_length(graph: MklSubgraph) -> bool:
  239. """
  240. This is a heuristic that can be passed into `optimize_for_inference` that
  241. determines whether a subgraph should be run in MKL by checking if there
  242. are more than 2 nodes in it
  243. """
  244. return len(graph.nodes) > 2
  245. class UnionFind:
  246. def __init__(self, n):
  247. self.parent: list[Optional[int]] = [None] * n
  248. self.size: list[int] = [0] * n
  249. def make_set(self, v: int):
  250. self.parent[v] = v
  251. self.size[v] = 1
  252. def find(self, v: int) -> int:
  253. par = self.parent[v]
  254. if v == par:
  255. return v
  256. assert par is not None
  257. self.parent[v] = self.find(par)
  258. return cast(int, self.parent[v])
  259. def join(self, a: int, b: int):
  260. a, b = self.find(a), self.find(b)
  261. if a == b:
  262. return a
  263. if self.size[a] < self.size[b]:
  264. a, b = b, a
  265. self.parent[b] = a
  266. self.size[a] += self.size[b]
  267. def optimize_for_inference(
  268. model: torch.nn.Module,
  269. pass_config: Optional[dict[str, Any]] = None,
  270. tracer: type[fx.Tracer] = fx.Tracer,
  271. ) -> torch.nn.Module:
  272. """
  273. Performs a set of optimization passes to optimize a model for the
  274. purposes of inference. Specifically, the passes that are run are:
  275. 1. Conv/BN fusion
  276. 2. Dropout removal
  277. 3. MKL layout optimizations
  278. The third optimization takes a function `use_mkl_heuristic` that's used
  279. to determine whether a subgraph should be explicitly run in MKL layout.
  280. Note: As FX does not currently handle aliasing, this pass currently
  281. assumes nothing aliases. If that isn't true, use at your own risk.
  282. """
  283. default_pass_config = {
  284. "conv_bn_fuse": True,
  285. "remove_dropout": True,
  286. "mkldnn_layout_optimize": {"heuristic": use_mkl_length},
  287. }
  288. if pass_config is None:
  289. pass_config = {}
  290. default_pass_config.update(pass_config)
  291. if default_pass_config["conv_bn_fuse"]:
  292. model = fuse(model)
  293. if default_pass_config["remove_dropout"]:
  294. model = remove_dropout(model)
  295. if default_pass_config["mkldnn_layout_optimize"] is False:
  296. return model
  297. if not isinstance(default_pass_config["mkldnn_layout_optimize"], dict):
  298. raise RuntimeError("mkldnn_layout_optimize config is not a dict")
  299. if "heuristic" not in default_pass_config["mkldnn_layout_optimize"]:
  300. raise RuntimeError("Heuristic not found in mkldnn_layout_optimize config")
  301. use_mkl_heuristic = default_pass_config["mkldnn_layout_optimize"]["heuristic"]
  302. cur_tracer = tracer()
  303. fx_graph = cur_tracer.trace(copy.deepcopy(model))
  304. fx.GraphModule(cur_tracer.root, fx_graph)
  305. modules: dict[str, nn.Module] = dict(model.named_modules())
  306. class MklSupport(Enum):
  307. NO = 1
  308. YES = 2
  309. UNKNOWN = 3
  310. # Inserts to_mkldnn and to_dense around every node we want to be a MKLDNN node.
  311. # If the op is in `mkldnn_supported` then we always treat it as a MKLDNN node.
  312. # However, if it's in `mkldnn_supported_unknown`, then we only treat it as
  313. # a MKLDNN node if its inputs are MKLDNN nodes.
  314. for node in list(fx_graph.nodes):
  315. supports_mkldnn = MklSupport.NO
  316. if node.op == "call_module":
  317. cur_module = modules[node.target]
  318. if type(cur_module) in mkldnn_supported:
  319. supports_mkldnn = MklSupport.YES
  320. sample_parameter = next(cur_module.parameters(), None)
  321. if sample_parameter is not None:
  322. assert sample_parameter.dtype == torch.float, (
  323. "this pass is only for torch.float modules"
  324. )
  325. assert sample_parameter.device == torch.device("cpu"), (
  326. "this pass is only for CPU modules"
  327. )
  328. elif node.op == "call_function":
  329. if node.target in mkldnn_supported:
  330. supports_mkldnn = MklSupport.YES
  331. elif node.target in mkldnn_supported_unknown:
  332. supports_mkldnn = MklSupport.UNKNOWN
  333. if supports_mkldnn != MklSupport.NO:
  334. if supports_mkldnn == MklSupport.UNKNOWN:
  335. if not any(arg.target == "to_dense" for arg in node.args):
  336. continue
  337. with fx_graph.inserting_before(node):
  338. mkldnn_args = fx.map_arg(
  339. node.args, lambda n: fx_graph.call_method("to_mkldnn", (n,))
  340. )
  341. node.args = cast(tuple[fx.node.Argument], mkldnn_args)
  342. with fx_graph.inserting_after(node):
  343. dense_x = fx_graph.create_node("call_method", "to_dense", (node,))
  344. node.replace_all_uses_with(dense_x)
  345. dense_x.args = (node,)
  346. # Does pre-conversion of all modules into MKLDNN (when possible)
  347. old_modules = modules_to_mkldnn(list(fx_graph.nodes), modules)
  348. fx_graph.old_modules = old_modules # type: ignore[attr-defined]
  349. # optimizes all a -> to_dense -> to_mkldnn -> b patterns into a -> b
  350. for node in fx_graph.nodes:
  351. if node.op == "call_method" and node.target == "to_dense":
  352. prv_node = node.args[0]
  353. users = list(node.users)
  354. for user in users:
  355. if user.op == "call_method" and user.target == "to_mkldnn":
  356. user.replace_all_uses_with(prv_node)
  357. fx_graph.erase_node(user)
  358. if len(node.users) == 0:
  359. fx_graph.erase_node(node)
  360. num_nodes = len(fx_graph.nodes)
  361. uf = UnionFind(num_nodes)
  362. def get_color(n):
  363. if hasattr(n, "color"): # Current node is part of a MKL subgraph
  364. return uf.find(n.color)
  365. if hasattr(n, "start_color"): # Current node is input to MKL subgraph
  366. return uf.find(n.start_color)
  367. return None
  368. # This code is to find each MKLDNN subgraph. Each MKLDNN subgraph consists
  369. # of input nodes (which are only `to_mkldnn` calls), output nodes
  370. # (`to_dense` calls), and intermediate nodes, which are run entirely on
  371. # MKLDNN layout tensors.
  372. #
  373. # Specifically, this code does a flood fill on a directed acyclic graph
  374. # (DAG), starting from each possible "start node" (i.e: `to_mkldnn` nodes).
  375. # If every node only had one input, this would be sufficient. However, in
  376. # the case that a node has multiple inputs coming from different start
  377. # nodes (i.e. colors), we need to join these 2 colors into 1. That's done
  378. # using a Disjoint Set Union.
  379. for cur_idx, node in enumerate(fx_graph.nodes):
  380. if node.op == "call_method" and node.target == "to_mkldnn":
  381. node.start_color = cur_idx
  382. uf.make_set(cur_idx)
  383. elif node.op == "call_method" and node.target == "to_dense":
  384. assert get_color(node.args[0]) is not None
  385. node.end_color = get_color(node.args[0])
  386. else:
  387. cur_colors = [
  388. get_color(i)
  389. for i in node.all_input_nodes
  390. if isinstance(i, fx.Node)
  391. if get_color(i) is not None
  392. ]
  393. if len(cur_colors) == 0:
  394. continue
  395. assert not any(i is None for i in cur_colors)
  396. cur_colors = sorted(cur_colors)
  397. node.color = cur_colors[0]
  398. for other_color in cur_colors[1:]:
  399. uf.join(cur_colors[0], other_color)
  400. mkldnn_graphs: dict[int, MklSubgraph] = defaultdict(lambda: MklSubgraph(fx_graph))
  401. for node in fx_graph.nodes:
  402. if hasattr(node, "color"):
  403. mkldnn_graphs[uf.find(node.color)].nodes.append(node)
  404. if hasattr(node, "start_color"):
  405. mkldnn_graphs[uf.find(node.start_color)].start_nodes.append(node)
  406. if hasattr(node, "end_color"):
  407. mkldnn_graphs[uf.find(node.end_color)].end_nodes.append(node)
  408. # Now that we have all the subgraphs, we need to decide which MKLDNN
  409. # subgraphs we actually want to keep in MKLDNN.
  410. for graph in mkldnn_graphs.values():
  411. if not use_mkl_heuristic(graph):
  412. for node in graph.start_nodes + graph.end_nodes:
  413. prv = node.args[0]
  414. node.replace_all_uses_with(prv) # type: ignore[arg-type]
  415. fx_graph.erase_node(node)
  416. reset_modules(graph.nodes, modules, old_modules)
  417. mkldnn_conversions = 0
  418. for node in fx_graph.nodes:
  419. if node.target == "to_mkldnn" or node.target == "to_dense":
  420. mkldnn_conversions += 1
  421. logging.getLogger(__name__).info("mkldnn conversions: %s", mkldnn_conversions)
  422. fx_graph.lint()
  423. result = fx.GraphModule(model, fx_graph)
  424. return result