# interpreter.py
  1. # mypy: allow-untyped-defs
  2. import inspect
  3. from contextlib import contextmanager
  4. from typing import Any, Optional, TYPE_CHECKING, Union
  5. import torch
  6. import torch.fx.traceback as fx_traceback
  7. from torch._logging import trace_structured
  8. from torch.hub import tqdm
  9. from . import config
  10. from ._compatibility import compatibility
  11. from ._lazy_graph_module import _make_graph_module
  12. from ._symbolic_trace import Tracer
  13. from .graph import Graph
  14. from .graph_module import GraphModule
  15. from .node import Argument, map_aggregate, map_arg, Node, Target
  16. from .proxy import Proxy
  17. if TYPE_CHECKING:
  18. from collections.abc import Iterator
  19. __all__ = ["Interpreter", "Transformer"]
  20. @compatibility(is_backward_compatible=True)
  21. class Interpreter:
  22. """
  23. An Interpreter executes an FX graph Node-by-Node. This pattern
  24. can be useful for many things, including writing code
  25. transformations as well as analysis passes.
  26. Methods in the Interpreter class can be overridden to customize
  27. the behavior of execution. The map of overridable methods
  28. in terms of call hierarchy::
  29. run()
  30. +-- run_node
  31. +-- placeholder()
  32. +-- get_attr()
  33. +-- call_function()
  34. +-- call_method()
  35. +-- call_module()
  36. +-- output()
  37. Example:
  38. Suppose we want to swap all instances of ``torch.neg`` with
  39. ``torch.sigmoid`` and vice versa (including their ``Tensor``
  40. method equivalents). We could subclass Interpreter like so::
  41. class NegSigmSwapInterpreter(Interpreter):
  42. def call_function(
  43. self, target: Target, args: Tuple, kwargs: Dict
  44. ) -> Any:
  45. if target == torch.sigmoid:
  46. return torch.neg(*args, **kwargs)
  47. return super().call_function(target, args, kwargs)
  48. def call_method(self, target: Target, args: Tuple, kwargs: Dict) -> Any:
  49. if target == "neg":
  50. call_self, *args_tail = args
  51. return call_self.sigmoid(*args_tail, **kwargs)
  52. return super().call_method(target, args, kwargs)
  53. def fn(x):
  54. return torch.sigmoid(x).neg()
  55. gm = torch.fx.symbolic_trace(fn)
  56. input = torch.randn(3, 4)
  57. result = NegSigmSwapInterpreter(gm).run(input)
  58. torch.testing.assert_close(result, torch.neg(input).sigmoid())
  59. Args:
  60. module (torch.nn.Module): The module to be executed
  61. garbage_collect_values (bool): Whether to delete values after their last
  62. use within the Module's execution. This ensures optimal memory usage during
  63. execution. This can be disabled to, for example, examine all of the intermediate
  64. values in the execution by looking at the ``Interpreter.env`` attribute.
  65. graph (Optional[Graph]): If passed, the interpreter will execute this
  66. graph instead of `module.graph`, using the provided `module`
  67. argument to satisfy any requests for state.
  68. """
  69. @compatibility(is_backward_compatible=True)
  70. def __init__(
  71. self,
  72. module: torch.nn.Module,
  73. garbage_collect_values: bool = True,
  74. graph: Optional[Graph] = None,
  75. ):
  76. self.module = module
  77. self.submodules = dict(self.module.named_modules())
  78. if graph is not None:
  79. self.graph = graph
  80. else:
  81. self.graph = self.module.graph # type: ignore[assignment]
  82. self.env: dict[Node, Any] = {}
  83. self.name = "Interpreter"
  84. self.garbage_collect_values = garbage_collect_values
  85. self.extra_traceback = True
  86. if self.garbage_collect_values:
  87. # Run through reverse nodes and record the first instance of a use
  88. # of a given node. This represents the *last* use of the node in the
  89. # execution order of the program, which we will use to free unused
  90. # values
  91. node_to_last_use: dict[Node, Node] = {}
  92. self.user_to_last_uses: dict[Node, list[Node]] = {}
  93. def register_last_uses(n: Node, user: Node):
  94. if n not in node_to_last_use:
  95. node_to_last_use[n] = user
  96. self.user_to_last_uses.setdefault(user, []).append(n)
  97. for node in reversed(self.graph.nodes):
  98. for n in node._input_nodes:
  99. register_last_uses(n, node)
  100. @compatibility(is_backward_compatible=True)
  101. def run(
  102. self,
  103. *args,
  104. initial_env: Optional[dict[Node, Any]] = None,
  105. enable_io_processing: bool = True,
  106. ) -> Any:
  107. """
  108. Run `module` via interpretation and return the result.
  109. Args:
  110. *args: The arguments to the Module to run, in positional order
  111. initial_env (Optional[Dict[Node, Any]]): An optional starting environment for execution.
  112. This is a dict mapping `Node` to any value. This can be used, for example, to
  113. pre-populate results for certain `Nodes` so as to do only partial evaluation within
  114. the interpreter.
  115. enable_io_processing (bool): If true, we process the inputs and outputs with graph's process_inputs and
  116. process_outputs function first before using them.
  117. Returns:
  118. Any: The value returned from executing the Module
  119. """
  120. self.env = initial_env if initial_env is not None else {}
  121. # Positional function args are consumed left-to-right by
  122. # `placeholder` nodes. Use an iterator to keep track of
  123. # position and extract those values.
  124. if enable_io_processing:
  125. args = self.graph.process_inputs(*args)
  126. self.args_iter: Iterator[Any] = iter(args)
  127. pbar = tqdm(
  128. total=len(self.graph.nodes),
  129. desc=f"{self.name}: {str(list(self.graph.nodes)) if config.verbose_progress else ''}",
  130. initial=0,
  131. position=0,
  132. leave=True,
  133. disable=config.disable_progress,
  134. delay=0,
  135. )
  136. for node in self.graph.nodes:
  137. pbar.update(1)
  138. if node in self.env:
  139. # Short circuit if we have this value. This could
  140. # be used, for example, for partial evaluation
  141. # where the caller has pre-populated `env` with
  142. # values for a subset of the program.
  143. continue
  144. try:
  145. self.env[node] = self.run_node(node)
  146. except Exception as e:
  147. if self.extra_traceback:
  148. msg = f"While executing {node.format_node()}"
  149. msg = f"{e.args[0]}\n\n{msg}" if e.args else str(msg)
  150. msg += f"\nOriginal traceback:\n{node.stack_trace}"
  151. if (
  152. isinstance(self.module, GraphModule)
  153. and self.module.graph is not None
  154. and isinstance(self.module.graph, torch.fx.Graph)
  155. ):
  156. trace_structured(
  157. "artifact",
  158. metadata_fn=lambda: {
  159. "name": "fx_interpreter_error",
  160. "encoding": "string",
  161. },
  162. payload_fn=lambda: (
  163. f"{msg}\nGraphModule: "
  164. f"{self.module.print_readable(print_output=False, include_stride=True)}" # type: ignore[operator]
  165. ),
  166. )
  167. msg += "\nUse tlparse to see full graph. "
  168. msg += "(https://github.com/pytorch/tlparse?tab=readme-ov-file#tlparse-parse-structured-pt2-logs)"
  169. e.args = (msg,) + e.args[1:]
  170. if isinstance(e, KeyError):
  171. raise RuntimeError(*e.args) from e
  172. raise
  173. if self.garbage_collect_values:
  174. for to_delete in self.user_to_last_uses.get(node, []):
  175. del self.env[to_delete]
  176. if node.op == "output":
  177. output_val = self.env[node]
  178. return (
  179. self.graph.process_outputs(output_val)
  180. if enable_io_processing
  181. else output_val
  182. )
  183. @compatibility(is_backward_compatible=True)
  184. def boxed_run(self, args_list):
  185. """
  186. Run `module` via interpretation and return the result. This uses the "boxed"
  187. calling convention, where you pass a list of arguments, which will be cleared
  188. by the interpreter. This ensures that input tensors are promptly deallocated.
  189. """
  190. args_iter = iter(args_list)
  191. env = {}
  192. for n in self.graph.nodes:
  193. if n.op == "placeholder":
  194. env[n] = next(args_iter)
  195. args_list.clear()
  196. return self.run(initial_env=env)
  197. @contextmanager
  198. def _set_current_node(self, node):
  199. with fx_traceback.set_current_meta(
  200. node, f"Interpreter_{self.__class__.__name__}"
  201. ):
  202. yield
  203. @compatibility(is_backward_compatible=True)
  204. def run_node(self, n: Node) -> Any:
  205. """
  206. Run a specific node ``n`` and return the result.
  207. Calls into placeholder, get_attr, call_function,
  208. call_method, call_module, or output depending
  209. on ``node.op``
  210. Args:
  211. n (Node): The Node to execute
  212. Returns:
  213. Any: The result of executing ``n``
  214. """
  215. with self._set_current_node(n):
  216. args, kwargs = self.fetch_args_kwargs_from_env(n)
  217. assert isinstance(args, tuple)
  218. assert isinstance(kwargs, dict)
  219. return getattr(self, n.op)(n.target, args, kwargs)
  220. # Main Node running APIs
  221. @compatibility(is_backward_compatible=True)
  222. def placeholder(
  223. self, target: "Target", args: tuple[Argument, ...], kwargs: dict[str, Any]
  224. ) -> Any:
  225. """
  226. Execute a ``placeholder`` node. Note that this is stateful:
  227. ``Interpreter`` maintains an internal iterator over
  228. arguments passed to ``run`` and this method returns
  229. next() on that iterator.
  230. Args:
  231. target (Target): The call target for this node. See
  232. `Node <https://pytorch.org/docs/main/fx.html#torch.fx.Node>`__ for
  233. details on semantics
  234. args (Tuple): Tuple of positional args for this invocation
  235. kwargs (Dict): Dict of keyword arguments for this invocation
  236. Returns:
  237. Any: The argument value that was retrieved.
  238. """
  239. assert isinstance(target, str)
  240. if target.startswith("*"):
  241. # For a starred parameter e.g. `*args`, retrieve all
  242. # remaining values from the args list.
  243. return list(self.args_iter)
  244. else:
  245. try:
  246. return next(self.args_iter)
  247. except StopIteration as si:
  248. if len(args) > 0:
  249. return args[0]
  250. else:
  251. raise RuntimeError(
  252. f"Expected positional argument for parameter {target}, but one was not passed in!"
  253. ) from si
  254. @compatibility(is_backward_compatible=True)
  255. def get_attr(
  256. self, target: "Target", args: tuple[Argument, ...], kwargs: dict[str, Any]
  257. ) -> Any:
  258. """
  259. Execute a ``get_attr`` node. Will retrieve an attribute
  260. value from the ``Module`` hierarchy of ``self.module``.
  261. Args:
  262. target (Target): The call target for this node. See
  263. `Node <https://pytorch.org/docs/main/fx.html#torch.fx.Node>`__ for
  264. details on semantics
  265. args (Tuple): Tuple of positional args for this invocation
  266. kwargs (Dict): Dict of keyword arguments for this invocation
  267. Return:
  268. Any: The value of the attribute that was retrieved
  269. """
  270. assert isinstance(target, str)
  271. return self.fetch_attr(target)
  272. @compatibility(is_backward_compatible=True)
  273. def call_function(
  274. self, target: "Target", args: tuple[Argument, ...], kwargs: dict[str, Any]
  275. ) -> Any:
  276. """
  277. Execute a ``call_function`` node and return the result.
  278. Args:
  279. target (Target): The call target for this node. See
  280. `Node <https://pytorch.org/docs/main/fx.html#torch.fx.Node>`__ for
  281. details on semantics
  282. args (Tuple): Tuple of positional args for this invocation
  283. kwargs (Dict): Dict of keyword arguments for this invocation
  284. Return
  285. Any: The value returned by the function invocation
  286. """
  287. assert not isinstance(target, str)
  288. # Execute the function and return the result
  289. return target(*args, **kwargs)
  290. @compatibility(is_backward_compatible=True)
  291. def call_method(
  292. self, target: "Target", args: tuple[Argument, ...], kwargs: dict[str, Any]
  293. ) -> Any:
  294. """
  295. Execute a ``call_method`` node and return the result.
  296. Args:
  297. target (Target): The call target for this node. See
  298. `Node <https://pytorch.org/docs/main/fx.html#torch.fx.Node>`__ for
  299. details on semantics
  300. args (Tuple): Tuple of positional args for this invocation
  301. kwargs (Dict): Dict of keyword arguments for this invocation
  302. Return
  303. Any: The value returned by the method invocation
  304. """
  305. # args[0] is the `self` object for this method call
  306. self_obj, *args_tail = args
  307. # Execute the method and return the result
  308. assert isinstance(target, str)
  309. return getattr(self_obj, target)(*args_tail, **kwargs)
  310. @compatibility(is_backward_compatible=True)
  311. def call_module(
  312. self, target: "Target", args: tuple[Argument, ...], kwargs: dict[str, Any]
  313. ) -> Any:
  314. """
  315. Execute a ``call_module`` node and return the result.
  316. Args:
  317. target (Target): The call target for this node. See
  318. `Node <https://pytorch.org/docs/main/fx.html#torch.fx.Node>`__ for
  319. details on semantics
  320. args (Tuple): Tuple of positional args for this invocation
  321. kwargs (Dict): Dict of keyword arguments for this invocation
  322. Return
  323. Any: The value returned by the module invocation
  324. """
  325. # Retrieve executed args and kwargs values from the environment
  326. # Execute the method and return the result
  327. assert isinstance(target, str)
  328. submod = self.fetch_attr(target)
  329. return submod(*args, **kwargs)
  330. @compatibility(is_backward_compatible=True)
  331. def output(
  332. self, target: "Target", args: tuple[Argument, ...], kwargs: dict[str, Any]
  333. ) -> Any:
  334. """
  335. Execute an ``output`` node. This really just retrieves
  336. the value referenced by the ``output`` node and returns it.
  337. Args:
  338. target (Target): The call target for this node. See
  339. `Node <https://pytorch.org/docs/main/fx.html#torch.fx.Node>`__ for
  340. details on semantics
  341. args (Tuple): Tuple of positional args for this invocation
  342. kwargs (Dict): Dict of keyword arguments for this invocation
  343. Return:
  344. Any: The return value referenced by the output node
  345. """
  346. return args[0]
  347. # Helper methods
  348. @compatibility(is_backward_compatible=True)
  349. def fetch_attr(self, target: str):
  350. """
  351. Fetch an attribute from the ``Module`` hierarchy of ``self.module``.
  352. Args:
  353. target (str): The fully-qualified name of the attribute to fetch
  354. Return:
  355. Any: The value of the attribute.
  356. """
  357. target_atoms = target.split(".")
  358. attr_itr = self.module
  359. for i, atom in enumerate(target_atoms):
  360. if not hasattr(attr_itr, atom):
  361. raise RuntimeError(
  362. f"Node referenced nonexistent target {'.'.join(target_atoms[: i + 1])}"
  363. )
  364. attr_itr = getattr(attr_itr, atom)
  365. return attr_itr
  366. @compatibility(is_backward_compatible=True)
  367. def fetch_args_kwargs_from_env(self, n: Node) -> tuple[tuple, dict]:
  368. """
  369. Fetch the concrete values of ``args`` and ``kwargs`` of node ``n``
  370. from the current execution environment.
  371. Args:
  372. n (Node): The node for which ``args`` and ``kwargs`` should be fetched.
  373. Return:
  374. Tuple[Tuple, Dict]: ``args`` and ``kwargs`` with concrete values for ``n``.
  375. """
  376. args = self.map_nodes_to_values(n.args, n)
  377. assert isinstance(args, tuple)
  378. kwargs = self.map_nodes_to_values(n.kwargs, n)
  379. assert isinstance(kwargs, dict)
  380. return args, kwargs
  381. @compatibility(is_backward_compatible=True)
  382. def map_nodes_to_values(self, args: Argument, n: Node) -> Argument:
  383. """
  384. Recursively descend through ``args`` and look up the concrete value
  385. for each ``Node`` in the current execution environment.
  386. Args:
  387. args (Argument): Data structure within which to look up concrete values
  388. n (Node): Node to which ``args`` belongs. This is only used for error reporting.
  389. """
  390. def load_arg(n_arg: Node) -> Any:
  391. if n_arg not in self.env:
  392. raise RuntimeError(
  393. f"Node {n} referenced nonexistent value {n_arg}! Run Graph.lint() "
  394. f"to diagnose such issues"
  395. )
  396. return self.env[n_arg]
  397. return map_arg(args, load_arg)
  398. @compatibility(is_backward_compatible=True)
  399. class Transformer(Interpreter):
  400. """
  401. ``Transformer`` is a special type of interpreter that produces a
  402. new ``Module``. It exposes a ``transform()`` method that returns
  403. the transformed ``Module``. ``Transformer`` does not require
  404. arguments to run, as ``Interpreter`` does. ``Transformer`` works
  405. entirely symbolically.
  406. Example:
  407. Suppose we want to swap all instances of ``torch.neg`` with
  408. ``torch.sigmoid`` and vice versa (including their ``Tensor``
  409. method equivalents). We could subclass ``Transformer`` like so::
  410. class NegSigmSwapXformer(Transformer):
  411. def call_function(
  412. self,
  413. target: "Target",
  414. args: Tuple[Argument, ...],
  415. kwargs: Dict[str, Any],
  416. ) -> Any:
  417. if target == torch.sigmoid:
  418. return torch.neg(*args, **kwargs)
  419. return super().call_function(target, args, kwargs)
  420. def call_method(
  421. self,
  422. target: "Target",
  423. args: Tuple[Argument, ...],
  424. kwargs: Dict[str, Any],
  425. ) -> Any:
  426. if target == "neg":
  427. call_self, *args_tail = args
  428. return call_self.sigmoid(*args_tail, **kwargs)
  429. return super().call_method(target, args, kwargs)
  430. def fn(x):
  431. return torch.sigmoid(x).neg()
  432. gm = torch.fx.symbolic_trace(fn)
  433. transformed: torch.nn.Module = NegSigmSwapXformer(gm).transform()
  434. input = torch.randn(3, 4)
  435. torch.testing.assert_close(transformed(input), torch.neg(input).sigmoid())
  436. Args:
  437. module (GraphModule): The ``Module`` to be transformed.
  438. """
  439. @compatibility(is_backward_compatible=True)
  440. def __init__(self, module):
  441. super().__init__(module)
  442. self.new_graph = Graph()
  443. self.new_graph.set_codegen(module.graph._codegen)
  444. class TransformerTracer(Tracer):
  445. def __init__(self, graph: Graph):
  446. super().__init__()
  447. self.graph = graph
  448. self.tensor_attrs: dict[torch.Tensor, str] = {} # type: ignore[assignment]
  449. def is_leaf_module(self, _, __) -> bool:
  450. return True
  451. self.tracer = TransformerTracer(self.new_graph)
  452. self.tracer.root = module
  453. @compatibility(is_backward_compatible=True)
  454. def placeholder(
  455. self, target: "Target", args: tuple[Argument, ...], kwargs: dict[str, Any]
  456. ) -> Proxy:
  457. """
  458. Execute a ``placeholder`` node. In ``Transformer``, this is
  459. overridden to insert a new ``placeholder`` into the output
  460. graph.
  461. Args:
  462. target (Target): The call target for this node. See
  463. `Node <https://pytorch.org/docs/main/fx.html#torch.fx.Node>`__ for
  464. details on semantics
  465. args (Tuple): Tuple of positional args for this invocation
  466. kwargs (Dict): Dict of keyword arguments for this invocation
  467. """
  468. assert isinstance(target, str)
  469. default_value = next(iter(args)) if args else inspect.Signature.empty
  470. return Proxy(
  471. self.new_graph.placeholder(target, default_value=default_value), self.tracer
  472. )
  473. @compatibility(is_backward_compatible=True)
  474. def get_attr(
  475. self, target: "Target", args: tuple[Argument, ...], kwargs: dict[str, Any]
  476. ) -> Proxy:
  477. """
  478. Execute a ``get_attr`` node. In ``Transformer``, this is
  479. overridden to insert a new ``get_attr`` node into the output
  480. graph.
  481. Args:
  482. target (Target): The call target for this node. See
  483. `Node <https://pytorch.org/docs/main/fx.html#torch.fx.Node>`__ for
  484. details on semantics
  485. args (Tuple): Tuple of positional args for this invocation
  486. kwargs (Dict): Dict of keyword arguments for this invocation
  487. """
  488. assert isinstance(target, str)
  489. return self.tracer.create_proxy("get_attr", target, args, kwargs)
  490. @compatibility(is_backward_compatible=True)
  491. def call_module(
  492. self, target: "Target", args: tuple[Argument, ...], kwargs: dict[str, Any]
  493. ) -> Any:
  494. # Override so that the leaf module policy from `self.tracer` is respected.
  495. assert isinstance(target, str)
  496. submod = self.fetch_attr(target)
  497. return self.tracer.call_module(submod, submod.forward, args, kwargs)
  498. @compatibility(is_backward_compatible=True)
  499. def call_function(
  500. self, target: "Target", args: tuple[Argument, ...], kwargs: dict[str, Any]
  501. ) -> Any:
  502. # Override so that functions that were wrapped are still wrapped.
  503. return self.tracer.create_proxy("call_function", target, args, kwargs)
  504. @compatibility(is_backward_compatible=True)
  505. def transform(self) -> GraphModule:
  506. """
  507. Transform ``self.module`` and return the transformed
  508. ``GraphModule``.
  509. """
  510. with fx_traceback.preserve_node_meta():
  511. result = super().run(enable_io_processing=False)
  512. if result is not None:
  513. def strip_proxy(a: Union[Argument, Proxy]) -> Any:
  514. return a.node if isinstance(a, Proxy) else a
  515. new_output_node = self.new_graph.output(map_aggregate(result, strip_proxy))
  516. # also preserve the metadata from the old output node, if it exists
  517. old_output_node = list(self.graph.nodes)[-1]
  518. assert old_output_node.op == "output"
  519. for k, v in old_output_node.meta.items():
  520. new_output_node.meta[k] = v
  521. return _make_graph_module(self.module, self.new_graph)