graph.py 79 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156115711581159116011611162116311641165116611671168116911701171117211731174117511761177117811791180118111821183118411851186118711881189119011911192119311941195119611971198119912001201120212031204120512061207120812091210121112121213121412151216121712181219122012211222122312241225122612271228122912301231123212331234123512361237123812391240124112421243124412451246124712481249125012511252125312541255125612571258125912601261126212631264126512661267126812691270127112721273127412751276127
71278127912801281128212831284128512861287128812891290129112921293129412951296129712981299130013011302130313041305130613071308130913101311131213131314131513161317131813191320132113221323132413251326132713281329133013311332133313341335133613371338133913401341134213431344134513461347134813491350135113521353135413551356135713581359136013611362136313641365136613671368136913701371137213731374137513761377137813791380138113821383138413851386138713881389139013911392139313941395139613971398139914001401140214031404140514061407140814091410141114121413141414151416141714181419142014211422142314241425142614271428142914301431143214331434143514361437143814391440144114421443144414451446144714481449145014511452145314541455145614571458145914601461146214631464146514661467146814691470147114721473147414751476147714781479148014811482148314841485148614871488148914901491149214931494149514961497149814991500150115021503150415051506150715081509151015111512151315141515151615171518151915201521152215231524152515261527152815291530153115321533153415351536153715381539154015411542154315441545154615471548154915501551155215531554155515561557155815591560156115621563156415651566156715681569157015711572157315741575157615771578157915801581158215831584158515861587158815891590159115921593159415951596159715981599160016011602160316041605160616071608160916101611161216131614161516161617161816191620162116221623162416251626162716281629163016311632163316341635163616371638163916401641164216431644164516461647164816491650165116521653165416551656165716581659166016611662166316641665166616671668166916701671167216731674167516761677167816791680168116821683168416851686168716881689169016911692169316941695169616971698169917001701170217031704170517061707170817091710171117121713171417151716171717181719172017211722172317241725172617271728172917301731173217331734173517361737173817391740174117421743174417451746174717481749175017511752175317541755175617571758175917601761176217631764176517661767176817691770177117721773177417751776177
71778177917801781178217831784178517861787178817891790179117921793179417951796179717981799180018011802180318041805180618071808180918101811181218131814181518161817181818191820182118221823182418251826182718281829183018311832183318341835183618371838183918401841184218431844184518461847184818491850185118521853185418551856185718581859186018611862186318641865186618671868186918701871187218731874187518761877187818791880188118821883188418851886188718881889189018911892189318941895189618971898189919001901190219031904190519061907190819091910191119121913191419151916191719181919192019211922192319241925192619271928192919301931193219331934193519361937193819391940194119421943194419451946194719481949195019511952195319541955195619571958195919601961196219631964196519661967196819691970197119721973197419751976197719781979198019811982198319841985198619871988198919901991199219931994199519961997199819992000200120022003200420052006200720082009201020112012201320142015201620172018201920202021202220232024202520262027202820292030203120322033203420352036203720382039204020412042204320442045204620472048204920502051205220532054205520562057205820592060206120622063206420652066206720682069207020712072207320742075207620772078207920802081208220832084208520862087208820892090209120922093209420952096209720982099210021012102210321042105210621072108210921102111211221132114
  1. # mypy: allow-untyped-defs
  2. import builtins
  3. import contextlib
  4. import copy
  5. import enum
  6. import functools
  7. import inspect
  8. import keyword
  9. import math
  10. import os
  11. import re
  12. import typing
  13. import warnings
  14. from collections import defaultdict
  15. from collections.abc import Iterable, Iterator
  16. from contextlib import contextmanager
  17. from dataclasses import dataclass
  18. from typing import Any, Callable, Literal, NamedTuple, Optional, TYPE_CHECKING
  19. import torch
  20. import torch.utils._pytree as pytree
  21. from torch._C import _fx_map_arg as map_arg, _NodeIter
  22. from torch.utils._dtype_abbrs import dtype_abbrs
  23. from . import _pytree as fx_pytree
  24. from ._compatibility import compatibility
  25. from .immutable_collections import immutable_dict
  26. from .node import _get_qualified_name, _type_repr, Argument, Node, Target
# Names exported by `from <this module> import *`.
__all__ = ["PythonCode", "CodeGen", "Graph"]

# These imports are only seen by static type checkers, never executed at
# runtime — presumably to avoid an import cycle with these modules; confirm.
if TYPE_CHECKING:
    from ._symbolic_trace import Tracer  # noqa: F401
    from .graph_module import GraphModule  # noqa: F401
  31. # Mapping of builtins to their `typing` equivalent.
  32. # (PEP585: See D68459095 test plan)
  33. _origin_type_map = {
  34. list: typing.List, # noqa: UP006
  35. dict: typing.Dict, # noqa: UP006
  36. set: typing.Set, # noqa: UP006
  37. frozenset: typing.FrozenSet, # noqa: UP006
  38. tuple: typing.Tuple, # noqa: UP006
  39. }
  40. _legal_ops = dict.fromkeys(
  41. ["call_function", "call_method", "get_attr", "call_module", "placeholder", "output"]
  42. )
  43. # Signature for functions thattransforms the body (`list[str]`) of the
  44. # generated code
  45. TransformCodeFunc = Callable[[list[str]], list[str]]
class _CustomBuiltin(NamedTuple):
    """Additional objs that we add to every graph's globals.

    The repr() for some standard library objects is not valid Python code without
    an import. For common objects of this sort, we bundle them in the globals of
    every FX graph.
    """

    # Field order matters: instances are built positionally, e.g.
    # `_CustomBuiltin(import_str, obj)`.

    # Python statement that makes this object available under its registered
    # name (e.g. "from math import inf", "import torch").
    import_str: str
    # The actual object, produced from that import string.
    obj: Any
  56. # Combined dict of disallowed variable names so we can check with one lookup
  57. _illegal_names = {k: object() for k in keyword.kwlist}
  58. _illegal_names.update(builtins.__dict__) # can't shadow a builtin name
  59. _custom_builtins: dict[str, _CustomBuiltin] = {}
  60. def _register_custom_builtin(name: str, import_str: str, obj: Any):
  61. _custom_builtins[name] = _CustomBuiltin(import_str, obj)
  62. _illegal_names[name] = obj
# Objects whose repr() is not valid Python without an import are bundled into
# every graph's globals. Registration order is the insertion order of
# `_custom_builtins`, which is the order they are pre-filled into generated
# globals — keep it stable.
_register_custom_builtin("inf", "from math import inf", math.inf)
_register_custom_builtin("nan", "from math import nan", math.nan)
# type(None) has no importable builtin name, so it is recreated by assignment.
_register_custom_builtin("NoneType", "NoneType = type(None)", type(None))
_register_custom_builtin("torch", "import torch", torch)
_register_custom_builtin("device", "from torch import device", torch.device)
_register_custom_builtin("fx_pytree", "import torch.fx._pytree as fx_pytree", fx_pytree)
_register_custom_builtin("pytree", "import torch.utils._pytree as pytree", pytree)
  70. def _is_magic(x: str) -> bool:
  71. return x.startswith("__") and x.endswith("__")
  72. def _snake_case(s: str) -> str:
  73. """
  74. Transforms the given string ``s`` to a Python-style variable name
  75. Examples:
  76. ``mod.snake_case`` -> ``mod.snake_case``
  77. ``mod.pascalCase``-> ``mod.pascal_case``
  78. ``mod.ALL_CAPS`` -> ``mod.all_caps``
  79. """
  80. return _snake_case_sub(s).lower()
  81. # Replace occurrences where a lowercase letter is followed by an uppercase letter
  82. _snake_case_sub = functools.partial(re.compile(r"(?<=[a-z])([A-Z])").sub, r"_\1")
  83. # Find chars that can't be in a Python identifier
  84. _illegal_char_regex = re.compile("[^0-9a-zA-Z_]+")
  85. # Combined check for variable names:
  86. # 1) Checks name is not empty
  87. # 2) Checks first character is not a digit
  88. # 3) Checks name has no illegal characters (_illegal_char_regex)
  89. # 3) Splits off the number suffix (if present)
  90. _name_regex = re.compile(r"^([a-zA-Z_][0-9a-zA-Z_]*?)(?:_(\d+))?$")
  91. # starts with torch but does not start with torch._dynamo. or torch._inductor.
  92. _torch_but_not_dynamo = re.compile(
  93. r"^torch(?:\.(?!_dynamo\.|_inductor\.)[^.]+)*$"
  94. ).fullmatch
  95. def _is_from_torch(obj: Any) -> bool:
  96. module_name = getattr(obj, "__module__", None)
  97. if module_name is not None:
  98. return _torch_but_not_dynamo(module_name) is not None
  99. name = getattr(obj, "__name__", None)
  100. # exclude torch because torch.torch.torch.torch works. idk mang
  101. if name is not None and name != "torch":
  102. for guess in [torch, torch.nn.functional]:
  103. if getattr(guess, name, None) is obj:
  104. return True
  105. return False
class _Namespace:
    """A context for associating names uniquely with objects.

    The following invariants are enforced:
    - Each object gets a single name.
    - Each name is unique within a given namespace.
    - Names generated do not shadow builtins, unless the object is indeed that builtin.
    """

    def __init__(self):
        # obj -> the name already assigned to it (objects must be hashable).
        self._obj_to_name: dict[Any, str] = {}
        # Every name handed out by create_name or claimed by _rename_object.
        self._used_names: set[str] = set()
        # Written under the base name (see create_name): the last numeric
        # suffix chosen for that base, so later collisions resume from there.
        self._base_count: dict[str, int] = {}

    def create_name(self, candidate: str, obj: Optional[Any]) -> str:
        """Create a unique name.

        Arguments:
            candidate: used as the basis for the unique name, relevant to the user.
            obj: If not None, an object that will be associated with the unique name.

        Returns:
            The name already associated with ``obj`` if there is one; otherwise a
            valid identifier derived from ``candidate`` that is not yet in use in
            this namespace (a numeric ``_N`` suffix is added or bumped as needed).
        """
        # Each object keeps a single name: reuse it if already assigned.
        if obj is not None and obj in self._obj_to_name:
            return self._obj_to_name[obj]

        # optimistically check if candidate is already a valid name
        match = _name_regex.match(candidate)
        if match is None:
            # delete all characters that are illegal in a Python identifier
            candidate = _illegal_char_regex.sub("_", candidate)

            if not candidate:
                candidate = "_unnamed"

            # identifiers may not start with a digit
            if candidate[0].isdigit():
                candidate = f"_{candidate}"

            match = _name_regex.match(candidate)
            assert match is not None

        base, num = match.group(1, 2)
        if num is None or candidate in self._used_names:
            # NOTE(review): looked up by `candidate` but stored under `base`
            # below; the two differ when candidate already carries a used
            # numeric suffix — confirm this is intended.
            num = self._base_count.get(candidate, 0)
            # Keywords/builtins/custom builtins are only allowed for the very
            # object they name; anything else is forced onto a suffixed name.
            if _illegal_names.get(candidate, obj) is not obj:
                num += 1
                candidate = f"{base}_{num}"
                # assume illegal names don't end in _\d so no need to check again
        else:
            num = int(num)

        # Bump the suffix until the name is free.
        while candidate in self._used_names:
            num += 1
            candidate = f"{base}_{num}"

        self._used_names.add(candidate)
        self._base_count[base] = num
        if obj is not None:
            self._obj_to_name[obj] = candidate
        return candidate

    def associate_name_with_obj(self, name: str, obj: Any):
        """Associate a unique name with an object.

        Neither `name` nor `obj` should be associated already.
        """
        # setdefault returns the previously stored name if obj was already
        # associated, which makes the identity check below fail.
        # NOTE(review): `name` is not added to `_used_names` here, so
        # create_name could still hand it out — confirm callers reserve it.
        maybe_existing = self._obj_to_name.setdefault(obj, name)
        assert maybe_existing is name, "obj is already associated"

    def _rename_object(self, obj: Any, name: str):
        # Re-point an already-named obj at `name` and reserve `name`. The old
        # name stays in `_used_names`, so it will not be handed out again.
        assert obj in self._obj_to_name
        self._obj_to_name[obj] = name
        self._used_names.add(name)
@compatibility(is_backward_compatible=True)
@dataclass
class PythonCode:
    """
    Represents all the information necessary to exec or save a graph as Python code.
    """

    # Field order matters: the dataclass __init__ takes these positionally.

    # Python source code for the forward function definition.
    src: str
    # Values in global scope during execution of `src`.
    # (The field name shadows the `globals` builtin only as an attribute.)
    globals: dict[str, Any]
    # Optional mapping from the forward function's line number to
    # node index.
    _lineno_map: Optional[dict[int, Optional[int]]]
  176. def _format_target(base: str, target: str) -> str:
  177. elems = target.split(".")
  178. r = base
  179. for e in elems:
  180. if not e.isidentifier():
  181. r = f'getattr({r}, "{e}")'
  182. else:
  183. r = f"{r}.{e}"
  184. return r
  185. class _InsertPoint:
  186. def __init__(self, graph, new_insert):
  187. self.graph = graph
  188. self.orig_insert, graph._insert = graph._insert, new_insert
  189. def __enter__(self):
  190. pass
  191. def __exit__(self, type, value, tb):
  192. self.graph._insert = self.orig_insert
  193. class _node_list:
  194. def __init__(self, graph: "Graph", direction: Literal["_prev", "_next"] = "_next"):
  195. assert direction in ("_next", "_prev")
  196. self.graph = graph
  197. self.direction = direction
  198. def __len__(self):
  199. return self.graph._len
  200. def __iter__(self):
  201. return _NodeIter(self.graph._root, self.direction == "_prev")
  202. def __reversed__(self):
  203. return _node_list(self.graph, "_next" if self.direction == "_prev" else "_prev")
class _PyTreeInfo(NamedTuple):
    """
    Contains extra info stored when we're using Pytrees
    """

    # Field order matters: NamedTuple fields are positional.

    # Argument names of the original function — presumably before pytree
    # flattening; confirm against the codegen that consumes this.
    orig_args: list[str]
    # TreeSpec describing the structure of the inputs.
    in_spec: pytree.TreeSpec
    # TreeSpec describing the structure of the outputs, or None if not set.
    out_spec: Optional[pytree.TreeSpec]
  211. @dataclass(frozen=True)
  212. class _ParsedStackTrace:
  213. """
  214. Represents the top-most frame of a parsed stack trace
  215. """
  216. file: str
  217. lineno: str
  218. name: str
  219. code: str
  220. def get_summary_str(self):
  221. return f"File: {self.file}:{self.lineno} in {self.name}, code: {self.code}"
  222. # get File:lineno code from stack_trace
  223. def _parse_stack_trace(stack_trace: str):
  224. if stack_trace is None:
  225. return None
  226. pattern = re.compile(r"^File \"(.+)\", line (\d+), in (.+)$")
  227. lines = stack_trace.strip().split("\n")
  228. # stacktrace should have innermost frame last, so we
  229. # iterate backwards to find the first line that starts
  230. # with 'File '
  231. for idx in range(len(lines) - 2, -1, -1):
  232. line = lines[idx].strip()
  233. matches = pattern.match(line)
  234. if matches:
  235. file = matches.group(1)
  236. lineno = matches.group(2)
  237. name = matches.group(3)
  238. # next line should be the code
  239. code = lines[idx + 1].strip()
  240. return _ParsedStackTrace(file, lineno, name, code)
  241. return None
  242. @compatibility(is_backward_compatible=False)
  243. class CodeGen:
  244. # This is an override hook so we can customize the SymNode printer.
  245. _sym_repr: Callable[["torch.types.PySymType"], str] = lambda x: repr(x)
  246. def __init__(self):
  247. self._body_transformer: Optional[TransformCodeFunc] = None
  248. self._func_name: str = "forward"
  249. def _format_multiline_args(self, args: list[str]) -> str:
  250. """Helper to format function arguments in expanded multiline format."""
  251. return "".join(self._format_single_arg(arg) for arg in args)
  252. def _format_single_arg(self, arg: str) -> str:
  253. """Helper to format a single argument with optional comment."""
  254. if "#" in arg:
  255. arg_part, comment_part = arg.split("#", 1)
  256. return f" {arg_part.rstrip()}, # {comment_part.lstrip()}\n"
  257. else:
  258. return f" {arg},\n"
  259. def _get_delimiters(self, container) -> tuple[str, str]:
  260. """Helper to get opening and closing delimiters for containers."""
  261. return ("(", ")") if isinstance(container, tuple) else ("[", "]")
  262. def _format_multiline_container(self, items, descs=None, prefix="") -> str:
  263. """Helper to format containers (lists/tuples) in multiline format."""
  264. ldelim, rdelim = self._get_delimiters(items)
  265. desc_trailers = self._get_desc_trailers(items, descs)
  266. return (
  267. f"{prefix}{ldelim}\n"
  268. + "".join(
  269. f" {item},{trailer}\n" for item, trailer in zip(items, desc_trailers)
  270. )
  271. + f"{rdelim}"
  272. )
  273. def _get_desc_trailers(self, items, descs):
  274. """Helper to generate description trailers for items."""
  275. if descs is None:
  276. return [""] * len(items)
  277. return [f" # {desc}" for desc in descs]
  278. def _call_method_with_signature_check(self, method, *args, **kwargs):
  279. """Helper to call a method with optional parameters based on signature."""
  280. sig = inspect.signature(method)
  281. # Filter kwargs to only include parameters that exist in the method signature
  282. filtered_kwargs = {k: v for k, v in kwargs.items() if k in sig.parameters}
  283. return method(*args, **filtered_kwargs)
  284. def gen_fn_def(
  285. self,
  286. free_vars: list[str],
  287. maybe_return_annotation: str,
  288. *,
  289. expanded_def: bool = False,
  290. ) -> str:
  291. """
  292. Given the free variables and a return annotation, generates the beginning of the FX function.
  293. By default, `gen_fn_def(['a', 'b'], '') == 'def {self._func_name}(a, b):'`
  294. """
  295. # If the original function didn't have self as its first argument, we
  296. # would have added it.
  297. if len(free_vars) == 0 or free_vars[0] != "self":
  298. free_vars.insert(0, "self")
  299. if expanded_def:
  300. args_formatted = self._format_multiline_args(free_vars)
  301. return (
  302. f"def {self._func_name}(\n{args_formatted}){maybe_return_annotation}:"
  303. )
  304. else:
  305. return f"def {self._func_name}({', '.join(free_vars)}){maybe_return_annotation}:"
  306. def generate_output(
  307. self, output_args: Argument, *, descs: Optional[Any] = None
  308. ) -> str:
  309. """
  310. Given the output arguments, generates the return statement of the FX function.
  311. Note: The returned statement should not be indented.
  312. """
  313. if descs is not None and isinstance(output_args, (list, tuple)):
  314. return self._format_multiline_container(output_args, descs, "return ")
  315. else:
  316. return f"return {repr(output_args)}"
  317. def process_inputs(self, *args: Any) -> Any:
  318. """
  319. Transforms the inputs so that the graph can take them as arguments, as
  320. non-default codegen may result in the inputs to the function being
  321. different from the inputs to the graph.
  322. If the graph was directly runnable, this invariant should hold true
  323. `f.graph.process_outputs(f.graph(*f.graph.process_inputs(*inputs))) == f(*inputs)`
  324. """
  325. return args
  326. def process_outputs(self, outputs: Any) -> Any:
  327. """
  328. Transforms the outputs of the graph to be identical to the codegen.
  329. See ``process_inputs`` for more details.
  330. """
  331. return outputs
  332. def additional_globals(self) -> list[tuple[str, Any]]:
  333. """
  334. If your codegen uses extra global values, add tuples of (identifier,reference to the value) here.
  335. For example, return ['List', typing.List] if you need ``List`` in the global context.
  336. """
  337. return []
  338. def _gen_python_code(
  339. self,
  340. nodes,
  341. root_module: str,
  342. namespace: _Namespace,
  343. *,
  344. verbose: bool = False,
  345. include_stride: bool = False,
  346. include_device: bool = False,
  347. colored: bool = False,
  348. # Render each argument on its own line
  349. expanded_def: bool = False,
  350. ) -> PythonCode:
  351. free_vars: list[str] = []
  352. body: list[str] = []
  353. globals_: dict[str, Any] = {}
  354. wrapped_fns: dict[str, None] = {}
  355. # Wrap string in list to pass by reference
  356. maybe_return_annotation: list[str] = [""]
  357. include_stride = include_stride or (
  358. os.environ.get("FX_GRAPH_SHOW_STRIDE", "0") == "1"
  359. )
  360. include_device = include_device or (
  361. os.environ.get("FX_GRAPH_SHOW_DEVICE", "0") == "1"
  362. )
  363. def add_global(name_hint: str, obj: Any):
  364. """Add an obj to be tracked as a global.
  365. We call this for names that reference objects external to the
  366. Graph, like functions or types.
  367. Returns: the global name that should be used to reference 'obj' in generated source.
  368. """
  369. if (
  370. _is_from_torch(obj) and obj != torch.device
  371. ): # to support registering torch.device
  372. # HACK: workaround for how torch custom ops are registered. We
  373. # can't import them like normal modules so they must retain their
  374. # fully qualified name.
  375. return _get_qualified_name(obj)
  376. # normalize the name hint to get a proper identifier
  377. global_name = namespace.create_name(name_hint, obj)
  378. if global_name in globals_:
  379. assert globals_[global_name] == obj
  380. return global_name
  381. globals_[global_name] = obj
  382. return global_name
  383. # Pre-fill the globals table with registered builtins.
  384. for name, (_, obj) in _custom_builtins.items():
  385. add_global(name, obj)
  386. def type_repr(o: Any):
  387. if o == ():
  388. # Empty tuple is used for empty tuple type annotation Tuple[()]
  389. return "()"
  390. typename = _type_repr(o)
  391. if origin_type := getattr(o, "__origin__", None):
  392. # list[...], typing.List[...], TensorType[...]
  393. if isinstance(o, typing._GenericAlias): # type: ignore[attr-defined]
  394. # This is a generic pre-PEP585 type, e.g. typing.List[torch.Tensor]
  395. origin_type = _origin_type_map.get(origin_type, origin_type)
  396. origin_typename = add_global(_type_repr(origin_type), origin_type)
  397. if hasattr(o, "__args__"):
  398. # Assign global names for each of the inner type variables.
  399. args = [type_repr(arg) for arg in o.__args__]
  400. if len(args) == 0:
  401. # Bare type, such as `typing.Tuple` with no subscript
  402. # This code-path used in Python < 3.9
  403. return origin_typename
  404. return f"{origin_typename}[{','.join(args)}]"
  405. else:
  406. # Bare type, such as `typing.Tuple` with no subscript
  407. # This code-path used in Python 3.9+
  408. return origin_typename
  409. # Common case: this is a regular module name like 'foo.bar.baz'
  410. return add_global(typename, o)
  411. if colored:
  412. red = _color_fns["red"]
  413. dim_green = _color_fns["dim_green"]
  414. dim = _color_fns["dim"]
  415. dim_blue = _color_fns["dim_blue"]
  416. blue = _color_fns["blue"]
  417. else:
  418. red = _identity
  419. dim_green = _identity
  420. dim = _identity
  421. dim_blue = _identity
  422. blue = _identity
  423. def _get_repr(arg: Any) -> str:
  424. if isinstance(arg, Node): # first because common
  425. return repr(arg)
  426. elif isinstance(arg, tuple) and hasattr(arg, "_fields"):
  427. # Handle NamedTuples (if it has `_fields`) via add_global.
  428. qualified_name = _get_qualified_name(type(arg))
  429. global_name = add_global(qualified_name, type(arg))
  430. return f"{global_name}{repr(tuple(arg))}"
  431. elif isinstance(
  432. arg, (torch._ops.OpOverload, torch._ops.HigherOrderOperator)
  433. ):
  434. qualified_name = _get_qualified_name(arg)
  435. global_name = add_global(qualified_name, arg)
  436. return f"{global_name}"
  437. elif isinstance(arg, enum.Enum):
  438. cls = arg.__class__
  439. clsname = add_global(cls.__name__, cls)
  440. return f"{clsname}.{arg.name}"
  441. elif isinstance(arg, torch.Tensor):
  442. size = list(arg.size())
  443. dtype = str(arg.dtype).split(".")[-1]
  444. return f"torch.Tensor(size={size}, dtype={dtype})"
  445. elif isinstance(arg, tuple):
  446. if len(arg) == 1:
  447. return f"({_get_repr(arg[0])},)"
  448. else:
  449. return "(" + ", ".join(_get_repr(a) for a in arg) + ")"
  450. elif isinstance(arg, list):
  451. return "[" + ", ".join(_get_repr(a) for a in arg) + "]"
  452. elif isinstance(arg, slice):
  453. return f"slice({_get_repr(arg.start)}, {_get_repr(arg.stop)}, {_get_repr(arg.step)})"
  454. else:
  455. return blue(repr(arg))
  456. def _format_args(
  457. args: tuple[Argument, ...], kwargs: dict[str, Argument]
  458. ) -> str:
  459. res = [_get_repr(a) for a in args]
  460. res.extend([f"{k} = {_get_repr(v)}" for k, v in kwargs.items()])
  461. return ", ".join(res)
  462. # Run through reverse nodes and record the first instance of a use
  463. # of a given node. This represents the *last* use of the node in the
  464. # execution order of the program, which we will use to free unused
  465. # values
  466. node_to_last_use: dict[Node, Node] = {}
  467. user_to_last_uses: dict[Node, list[Node]] = {}
  468. def register_last_uses(n: Node, user: Node):
  469. if n not in node_to_last_use:
  470. node_to_last_use[n] = user
  471. user_to_last_uses.setdefault(user, []).append(n)
  472. for node in reversed(nodes):
  473. for input_node in node._input_nodes:
  474. register_last_uses(input_node, node)
  475. def delete_unused_values(user: Node):
  476. """
  477. Delete values after their last use. This ensures that values that are
  478. not used in the remainder of the code are freed and the memory usage
  479. of the code is optimal.
  480. """
  481. if user.op == "placeholder":
  482. return
  483. if user.op == "output":
  484. body.append("\n")
  485. return
  486. nodes_to_delete = user_to_last_uses.get(user, [])
  487. if len(user.users.keys()) == 0:
  488. # This node is not used by any others. however it's also not
  489. # removed by DCE since side-effect. We want to free it's outputs
  490. # right after its execution done to save memory.
  491. nodes_to_delete.append(user)
  492. if len(nodes_to_delete):
  493. to_delete_str = " = ".join(
  494. [repr(n) for n in nodes_to_delete] + ["None"]
  495. )
  496. body.append(f"; {dim(to_delete_str)}\n")
  497. else:
  498. body.append("\n")
  499. prev_stacktrace = None
  500. def append_stacktrace_summary(node: Node):
  501. """
  502. Append a summary of the stacktrace to the generated code. This is
  503. useful for debugging.
  504. """
  505. nonlocal prev_stacktrace
  506. if node.op not in {"placeholder", "output"}:
  507. stack_trace = node.stack_trace
  508. if stack_trace:
  509. if stack_trace != prev_stacktrace:
  510. prev_stacktrace = stack_trace
  511. if parsed_stack_trace := _parse_stack_trace(stack_trace):
  512. summary_str = parsed_stack_trace.get_summary_str()
  513. else:
  514. summary_str = ""
  515. body.append(f"\n {dim(f'# {summary_str}')}\n")
  516. elif prev_stacktrace != "":
  517. prev_stacktrace = ""
  518. no_stacktrace_msg = "# No stacktrace found for following nodes"
  519. body.append(f"\n{dim(no_stacktrace_msg)}\n")
  520. def stringify_shape(shape: Iterable) -> str:
  521. return f"[{', '.join([str(x) for x in shape])}]"
  522. def emit_node(node: Node):
  523. maybe_type_annotation = (
  524. "" if node.type is None else f" : {type_repr(node.type)}"
  525. )
  526. maybe_comment = ""
  527. if verbose:
  528. # override annotation with more detailed information
  529. from torch.fx.experimental.proxy_tensor import py_sym_types
  530. from torch.fx.passes.shape_prop import TensorMetadata
  531. meta_val = node.meta.get(
  532. "val",
  533. node.meta.get("tensor_meta", node.meta.get("example_value", None)),
  534. )
  535. # use string as annotation, to make it valid python code
  536. if isinstance(meta_val, torch.Tensor) and meta_val.layout not in (
  537. torch.sparse_csc,
  538. torch.sparse_csr,
  539. ):
  540. stride_annotation = (
  541. f"{stringify_shape(meta_val.stride())}"
  542. if include_stride
  543. else ""
  544. )
  545. device_annotation = f"{meta_val.device}" if include_device else ""
  546. maybe_type_annotation = (
  547. f': "{red(dtype_abbrs[meta_val.dtype])}{blue(stringify_shape(meta_val.shape))}'
  548. f'{dim_blue(stride_annotation)}{dim_green(device_annotation)}"'
  549. )
  550. elif isinstance(meta_val, py_sym_types):
  551. val_str = CodeGen._sym_repr(meta_val)
  552. maybe_type_annotation = f': "Sym({val_str})"'
  553. elif isinstance(meta_val, TensorMetadata):
  554. maybe_type_annotation = f': "{dtype_abbrs[meta_val.dtype]}{stringify_shape(meta_val.shape)}"'
  555. desc = None
  556. if expanded_def:
  557. desc = node.meta.get("desc", None)
  558. if desc is not None and node.op == "placeholder":
  559. maybe_comment += f" # {desc}"
  560. # output is handled specially
  561. if node.op == "placeholder":
  562. assert isinstance(node.target, str)
  563. maybe_default_arg = (
  564. "" if not node.args else f" = {_get_repr(node.args[0])}"
  565. )
  566. free_vars.append(
  567. f"{node.target}{maybe_type_annotation}{maybe_default_arg}{maybe_comment}"
  568. )
  569. raw_name = node.target.replace("*", "")
  570. if raw_name != repr(node):
  571. body.append(f"{repr(node)} = {raw_name}\n")
  572. return
  573. elif node.op == "call_method":
  574. assert isinstance(node.target, str)
  575. body.append(
  576. f"{repr(node)}{maybe_type_annotation} = {_format_target(_get_repr(node.args[0]), node.target)}"
  577. f"({_format_args(node.args[1:], node.kwargs)})"
  578. )
  579. return
  580. elif node.op == "call_function":
  581. assert callable(node.target)
  582. # pretty print operators
  583. if (
  584. getattr(node.target, "__module__", "") == "_operator"
  585. and node.target.__name__ in magic_methods
  586. ):
  587. assert isinstance(node.args, tuple)
  588. body.append(
  589. f"{repr(node)}{maybe_type_annotation} = "
  590. f"{magic_methods[node.target.__name__].format(*(_get_repr(a) for a in node.args))}"
  591. )
  592. return
  593. # pretty print inplace operators; required for jit.script to work properly
  594. # not currently supported in normal FX graphs, but generated by torchdynamo
  595. if (
  596. getattr(node.target, "__module__", "") == "_operator"
  597. and node.target.__name__ in inplace_methods
  598. ):
  599. body.append(
  600. f"{inplace_methods[node.target.__name__].format(*(_get_repr(a) for a in node.args))}; "
  601. f"{repr(node)}{maybe_type_annotation} = {_get_repr(node.args[0])}"
  602. )
  603. return
  604. qualified_name = _get_qualified_name(node.target)
  605. global_name = add_global(qualified_name, node.target)
  606. # special case for getattr: node.args could be 2-argument or 3-argument
  607. # 2-argument: attribute access; 3-argument: fall through to attrib function call with default value
  608. if (
  609. global_name == "getattr"
  610. and isinstance(node.args, tuple)
  611. and isinstance(node.args[1], str)
  612. and node.args[1].isidentifier()
  613. and len(node.args) == 2
  614. ):
  615. body.append(
  616. f"{repr(node)}{maybe_type_annotation} = {_format_target(_get_repr(node.args[0]), node.args[1])}"
  617. )
  618. return
  619. body.append(
  620. f"{repr(node)}{maybe_type_annotation} = {global_name}({_format_args(node.args, node.kwargs)})"
  621. )
  622. if node.meta.get("is_wrapped", False):
  623. wrapped_fns.setdefault(global_name)
  624. return
  625. elif node.op == "call_module":
  626. assert isinstance(node.target, str)
  627. body.append(
  628. f"{repr(node)}{maybe_type_annotation} = "
  629. f"{_format_target(root_module, node.target)}({_format_args(node.args, node.kwargs)})"
  630. )
  631. return
  632. elif node.op == "get_attr":
  633. assert isinstance(node.target, str)
  634. body.append(
  635. f"{repr(node)}{maybe_type_annotation} = {_format_target(root_module, node.target)}"
  636. )
  637. return
  638. elif node.op == "output":
  639. if node.type is not None:
  640. maybe_return_annotation[0] = f" -> {type_repr(node.type)}"
  641. body.append(
  642. self._call_method_with_signature_check(
  643. self.generate_output,
  644. node.args[0],
  645. descs=desc if expanded_def else None,
  646. )
  647. )
  648. return
  649. raise NotImplementedError(f"node: {node.op} {node.target}")
  650. for i, node in enumerate(nodes):
  651. # NOTE: emit_node does not emit a string with newline. It depends
  652. # on delete_unused_values to append one
  653. if verbose:
  654. append_stacktrace_summary(node)
  655. # emit a counter comment to keep track of
  656. # node index, which will be deleted later
  657. # after going through _body_transformer
  658. body.append(f"# COUNTER: {i}\n")
  659. emit_node(node)
  660. delete_unused_values(node)
  661. if len(body) == 0:
  662. # If the Graph has no non-placeholder nodes, no lines for the body
  663. # have been emitted. To continue to have valid Python code, emit a
  664. # single pass statement
  665. body.append("pass\n")
  666. if len(wrapped_fns) > 0:
  667. wrap_name = add_global("wrap", torch.fx.wrap)
  668. wrap_stmts = "\n".join([f'{wrap_name}("{name}")' for name in wrapped_fns])
  669. else:
  670. wrap_stmts = ""
  671. if self._body_transformer:
  672. body = self._body_transformer(body)
  673. for name, value in self.additional_globals():
  674. add_global(name, value)
  675. prologue = self._call_method_with_signature_check(
  676. self.gen_fn_def,
  677. free_vars,
  678. maybe_return_annotation[0],
  679. expanded_def=expanded_def,
  680. )
  681. # remove counter and generate lineno to node index mapping
  682. lineno_map: dict[int, Optional[int]] = {}
  683. prologue_len = prologue.count("\n") + 1
  684. new_lines: list[str] = []
  685. cur_idx = None
  686. for line in "".join(body).split("\n"):
  687. counter = _counter_regexp.search(line)
  688. if counter is not None:
  689. cur_idx = int(counter.group(1))
  690. else:
  691. lineno_map[len(new_lines) + prologue_len] = cur_idx
  692. new_lines.append(line)
  693. code = "\n".join(new_lines).lstrip("\n")
  694. code = "\n".join(" " + line for line in code.split("\n"))
  695. fn_code = f"""
  696. {wrap_stmts}
  697. {prologue}
  698. {code}"""
  699. return PythonCode(fn_code, globals_, _lineno_map=lineno_map)
# Ideally, we'd like to refactor all of the pytree logic into this codegen
# class. Unfortunately, there are 3 areas where we currently need extra logic in FX:
# 1. In the initial symbolic trace, the pytree logic is tied up with `concrete_args`.
# 2. In the FX graph, we need to access 2 attributes - in_spec and out_spec.
#    Since we can't access .graph within the FX forward, we need to copy these attributes to the module.
# 3. We currently can't register the pytree imports with `add_global` - the reason is not yet understood.
  706. class _PyTreeCodeGen(CodeGen):
  707. def __init__(self, pytree_info: _PyTreeInfo):
  708. super().__init__()
  709. self.pytree_info: _PyTreeInfo = pytree_info
  710. def process_inputs(self, *inputs: Any) -> Any:
  711. flat_args = pytree.arg_tree_leaves(*inputs)
  712. return flat_args
  713. def process_outputs(self, out: Any) -> Any:
  714. if self.pytree_info is None or self.pytree_info.out_spec is None:
  715. return out
  716. if not isinstance(out, (list, tuple)):
  717. out = [out]
  718. assert self.pytree_info.out_spec is not None
  719. return pytree.tree_unflatten(out, self.pytree_info.out_spec)
  720. def _format_annotations(self, free_vars: list[str], expanded_def: bool) -> str:
  721. """Helper to format annotations for variables in pytree codegen."""
  722. if not free_vars:
  723. return ""
  724. has_annotation = [x for x in free_vars if ":" in x]
  725. if not has_annotation:
  726. return ""
  727. if expanded_def:
  728. return "\n " + "\n ".join(has_annotation)
  729. else:
  730. return "\n " + "".join(x + "; " for x in has_annotation) + "\n"
  731. def gen_fn_def(
  732. self, free_vars, maybe_return_annotation, *, expanded_def: bool = False
  733. ):
  734. # Given a user function/model:
  735. # myargs = [myargs0, myargs1]
  736. # mykwargs = {'mykwargs0': ..., 'mykwargs1': ...}
  737. # def forward(self, mypos, *myargs, mykey=None, **mykwargs):
  738. #
  739. # The generated code flattens all keywords into positional arguments for `forward()`
  740. # e.g forward(self, mypos, myargs0, myargs1, mykey, mykwargs0, mykwargs1):
  741. #
  742. # Within `forward`, `tree_flatten_spec``still parses args and kwargs separately
  743. # e.g. tree_flatten_spec(([mypos, myargs0, myargs1],
  744. # {'mykey':mykey, 'mykwargs0':mykwargs0, 'mykwargs1':mykwargs1}),
  745. # self._in_spec)
  746. #
  747. # If the user function/model does not have keywords, the dict is suppressed from tree_flatten_spec
  748. # e.g. tree_flatten_spec([mypos, myargs0, myargs1]), self._in_spec)
  749. if self.pytree_info is None:
  750. return super().gen_fn_def(
  751. free_vars, maybe_return_annotation, expanded_def=expanded_def
  752. )
  753. fn_args = self.pytree_info.orig_args
  754. has_orig_self = (fn_args[0] == "self") if len(fn_args) > 0 else False
  755. if has_orig_self:
  756. free_vars.insert(0, "self")
  757. fn_definition = super().gen_fn_def(
  758. fn_args[:], maybe_return_annotation, expanded_def=expanded_def
  759. )
  760. if len(free_vars) > 0: # pytree has placeholders in it
  761. # when kwargs is present, in_spec is tuple(args, kwargs)
  762. has_args_kwargs_tuple = (
  763. self.pytree_info.in_spec.type == tuple
  764. and self.pytree_info.in_spec.num_children == 2
  765. and self.pytree_info.in_spec.children_specs[0].type == tuple
  766. and self.pytree_info.in_spec.children_specs[1].type == dict
  767. )
  768. fn_kwargs = "{}"
  769. fn_signature = f"[{', '.join(fn_args)}], self._in_spec"
  770. if has_args_kwargs_tuple:
  771. count_args = self.pytree_info.in_spec.children_specs[0].num_children
  772. fn_args = self.pytree_info.orig_args[:count_args]
  773. fn_kwargs = (
  774. "{"
  775. + ", ".join(
  776. f"'{k}':{v}"
  777. for k, v in zip(
  778. self.pytree_info.in_spec.children_specs[1].context,
  779. self.pytree_info.orig_args[count_args:],
  780. )
  781. )
  782. + "}"
  783. )
  784. fn_signature = f"([{', '.join(fn_args)}], {fn_kwargs}), self._in_spec"
  785. # in Python, `var1: annotation1, var2: annotation2 = function_call()` is invalid.
  786. # we need to split it to two lines:
  787. # one for annotation: `var1: annotation1; var2: annotation2;` (note the semicolon)
  788. # one for code: `var1, var2, = function_call()`
  789. without_annotation = [x.split(":")[0].split("#")[0] for x in free_vars]
  790. fn_definition += self._format_annotations(free_vars, expanded_def)
  791. fn_definition += f"""
  792. {", ".join(without_annotation)}, = fx_pytree.tree_flatten_spec({fn_signature})"""
  793. return fn_definition
  794. def generate_output(self, output_args, *, descs: Optional[Any] = None):
  795. if self.pytree_info and self.pytree_info.out_spec:
  796. if descs is not None and isinstance(output_args, (list, tuple)):
  797. return (
  798. self._format_multiline_container(
  799. output_args, descs, "return pytree.tree_unflatten("
  800. )
  801. + ", self._out_spec)"
  802. )
  803. else:
  804. return (
  805. f"return pytree.tree_unflatten({repr(output_args)}, self._out_spec)"
  806. )
  807. else:
  808. return super().generate_output(output_args, descs=descs)
  809. class _FindNodesLookupTable:
  810. """
  811. Side table for the graph for the purpose of doing fast queries
  812. """
  813. def __init__(self):
  814. self.table: dict[tuple[str, Optional[Target]], dict[Node, None]] = defaultdict(
  815. dict
  816. )
  817. def _key(self, node) -> tuple[str, Optional[Target]]:
  818. return (node.op, node.target if node.op == "call_function" else None)
  819. def __contains__(self, node) -> bool:
  820. return node in self.table[self._key(node)]
  821. def insert(self, node: Node) -> None:
  822. self.table[self._key(node)][node] = None
  823. def remove(self, node: Node) -> None:
  824. self.table[self._key(node)].pop(node)
  825. def find_nodes(self, *, op: str, target: Optional["Target"] = None):
  826. if op == "call_function":
  827. assert target is not None
  828. return [*self.table[(op, target)].keys()]
  829. if target is None:
  830. return [*self.table[(op, None)].keys()]
  831. # op is call_method, get_attr, call_module
  832. return [node for node in self.table[(op, None)].keys() if node.target == target]
  833. @compatibility(is_backward_compatible=True)
  834. class Graph:
  835. """
  836. ``Graph`` is the main data structure used in the FX Intermediate Representation.
  837. It consists of a series of ``Node`` s, each representing callsites (or other
  838. syntactic constructs). The list of ``Node`` s, taken together, constitute a
  839. valid Python function.
  840. For example, the following code
  841. .. code-block:: python
  842. import torch
  843. import torch.fx
  844. class MyModule(torch.nn.Module):
  845. def __init__(self):
  846. super().__init__()
  847. self.param = torch.nn.Parameter(torch.rand(3, 4))
  848. self.linear = torch.nn.Linear(4, 5)
  849. def forward(self, x):
  850. return torch.topk(
  851. torch.sum(self.linear(x + self.linear.weight).relu(), dim=-1), 3
  852. )
  853. m = MyModule()
  854. gm = torch.fx.symbolic_trace(m)
  855. Will produce the following Graph::
  856. print(gm.graph)
  857. .. code-block:: text
  858. graph(x):
  859. %linear_weight : [num_users=1] = self.linear.weight
  860. %add_1 : [num_users=1] = call_function[target=operator.add](args = (%x, %linear_weight), kwargs = {})
  861. %linear_1 : [num_users=1] = call_module[target=linear](args = (%add_1,), kwargs = {})
  862. %relu_1 : [num_users=1] = call_method[target=relu](args = (%linear_1,), kwargs = {})
  863. %sum_1 : [num_users=1] = call_function[target=torch.sum](args = (%relu_1,), kwargs = {dim: -1})
  864. %topk_1 : [num_users=1] = call_function[target=torch.topk](args = (%sum_1, 3), kwargs = {})
  865. return topk_1
  866. For the semantics of operations represented in the ``Graph``, please see :class:`Node`.
  867. """
  868. @compatibility(is_backward_compatible=True)
  869. def __init__(
  870. self,
  871. owning_module: Optional["GraphModule"] = None,
  872. tracer_cls: Optional[type["Tracer"]] = None,
  873. tracer_extras: Optional[dict[str, Any]] = None,
  874. ):
  875. """
  876. Construct an empty Graph.
  877. """
  878. self._root: Node = Node(self, "", "root", "", (), {})
  879. self._used_names: dict[str, int] = {} # base name -> number
  880. self._insert = self._root.prepend
  881. self._len = 0
  882. self._graph_namespace = _Namespace()
  883. self._owning_module = owning_module
  884. self._tracer_cls = tracer_cls
  885. self._tracer_extras = tracer_extras
  886. self._codegen = CodeGen()
  887. self._co_fields: dict[str, Any] = {}
  888. self._find_nodes_lookup_table = _FindNodesLookupTable()
  889. @property
  890. def owning_module(self):
  891. return self._owning_module
  892. @owning_module.setter
  893. def owning_module(self, mod: Optional["GraphModule"]):
  894. self._owning_module = mod
  895. @property
  896. def nodes(self) -> _node_list:
  897. """
  898. Get the list of Nodes that constitute this Graph.
  899. Note that this ``Node`` list representation is a doubly-linked list. Mutations
  900. during iteration (e.g. delete a Node, add a Node) are safe.
  901. Returns:
  902. A doubly-linked list of Nodes. Note that ``reversed`` can be called on
  903. this list to switch iteration order.
  904. """
  905. return _node_list(self)
  906. @compatibility(is_backward_compatible=False)
  907. def output_node(self) -> Node:
  908. output_node = next(iter(reversed(self.nodes)))
  909. assert output_node.op == "output"
  910. return output_node
  911. @compatibility(is_backward_compatible=False)
  912. def find_nodes(
  913. self, *, op: str, target: Optional["Target"] = None, sort: bool = True
  914. ):
  915. """
  916. Allows for fast query of nodes
  917. Args:
  918. op (str): the name of the operation
  919. target (Optional[Target]): the target of the node. For call_function,
  920. the target is required. For other ops, the target is optional.
  921. sort (bool): whether to return nodes in the order they appear on
  922. on the graph.
  923. Returns:
  924. Iterable of nodes with the requested op and target.
  925. """
  926. node_list = self._find_nodes_lookup_table.find_nodes(op=op, target=target)
  927. if sort:
  928. return sorted(node_list)
  929. return node_list
  930. @compatibility(is_backward_compatible=True)
  931. def graph_copy(
  932. self, g: "Graph", val_map: dict[Node, Node], return_output_node=False
  933. ) -> "Optional[Argument]":
  934. """
  935. Copy all nodes from a given graph into ``self``.
  936. Args:
  937. g (Graph): The source graph from which to copy Nodes.
  938. val_map (Dict[Node, Node]): a dictionary that will be populated with a mapping
  939. from nodes in ``g`` to nodes in ``self``. Note that ``val_map`` can be passed
  940. in with values in it already to override copying of certain values.
  941. Returns:
  942. The value in ``self`` that is now equivalent to the output value in ``g``,
  943. if ``g`` had an ``output`` node. ``None`` otherwise.
  944. """
  945. for node in g.nodes:
  946. if node in val_map:
  947. continue
  948. if node.op == "output":
  949. rv = map_arg(node.args[0], lambda n: val_map[n])
  950. return rv if not return_output_node else (rv, node)
  951. val_map[node] = self.node_copy(node, lambda n: val_map[n])
  952. return None
  953. def __deepcopy__(self, memo=None) -> "Graph":
  954. """
  955. Explicitly implement __deepcopy__ to prevent excessive recursion depth
  956. from the default implementation. This uses graph_copy to copy the nodes
  957. in an iterative way, rather than recursive. It also populates the
  958. memoization table to prevent unnecessary copies (e.g. references to
  959. nodes or other parts of the Graph from a custom GraphModule implementation.
  960. """
  961. memo = memo if memo else {}
  962. g = Graph(tracer_cls=self._tracer_cls)
  963. output_vals = g.graph_copy(self, val_map=memo, return_output_node=True)
  964. g._codegen = copy.deepcopy(self._codegen)
  965. if output_vals is not None:
  966. assert isinstance(output_vals, tuple)
  967. output_val, old_output_node = output_vals
  968. new_output_node = g.output(
  969. output_val, type_expr=getattr(old_output_node, "type", None)
  970. )
  971. new_output_node.meta = copy.copy(old_output_node.meta)
  972. return g
  973. @compatibility(is_backward_compatible=True)
  974. def create_node(
  975. self,
  976. op: str,
  977. target: "Target",
  978. args: Optional[tuple["Argument", ...]] = None,
  979. kwargs: Optional[dict[str, "Argument"]] = None,
  980. name: Optional[str] = None,
  981. type_expr: Optional[Any] = None,
  982. ) -> Node:
  983. """
  984. Create a ``Node`` and add it to the ``Graph`` at the current insert-point.
  985. Note that the current insert-point can be set via :meth:`Graph.inserting_before`
  986. and :meth:`Graph.inserting_after`.
  987. Args:
  988. op (str): the opcode for this Node. One of 'call_function', 'call_method', 'get_attr',
  989. 'call_module', 'placeholder', or 'output'. The semantics of these opcodes are
  990. described in the ``Graph`` docstring.
  991. args (Optional[Tuple[Argument, ...]]): is a tuple of arguments to this node.
  992. kwargs (Optional[Dict[str, Argument]]): the kwargs of this Node
  993. name (Optional[str]): an optional string name for the ``Node``.
  994. This will influence the name of the value assigned to in the
  995. Python generated code.
  996. type_expr (Optional[Any]): an optional type annotation representing the
  997. Python type the output of this node will have.
  998. Returns:
  999. The newly-created and inserted node.
  1000. """
  1001. # `target in _legal_ops` is checked in Node.__init__
  1002. if not args:
  1003. args = ()
  1004. else:
  1005. assert isinstance(args, tuple), "args must be a tuple"
  1006. if not kwargs:
  1007. kwargs = immutable_dict()
  1008. else:
  1009. assert isinstance(kwargs, dict), "kwargs must be a dict"
  1010. candidate = name if name is not None else self._target_to_str(target)
  1011. name = self._graph_namespace.create_name(candidate, None)
  1012. n = Node(self, name, op, target, args, kwargs, type_expr)
  1013. if (
  1014. self.owning_module is not None
  1015. and getattr(self.owning_module, "_create_node_hooks", None) is not None
  1016. ):
  1017. for f in self.owning_module._create_node_hooks:
  1018. f(n)
  1019. self._graph_namespace.associate_name_with_obj(name, n)
  1020. self._insert(n)
  1021. self._find_nodes_lookup_table.insert(n)
  1022. self._len += 1
  1023. return n
  1024. @compatibility(is_backward_compatible=False)
  1025. def process_inputs(self, *args):
  1026. """
  1027. Processes args so that they can be passed to the FX graph.
  1028. """
  1029. return self._codegen.process_inputs(*args)
  1030. @compatibility(is_backward_compatible=False)
  1031. def process_outputs(self, out):
  1032. return self._codegen.process_outputs(out)
    @compatibility(is_backward_compatible=True)
    def erase_node(self, to_erase: Node) -> None:
        """
        Erases a ``Node`` from the ``Graph``. Throws an exception if
        there are still users of that node in the ``Graph``.

        Erasing a node that is already marked as erased only emits a warning.

        Args:
            to_erase (Node): The ``Node`` to erase from the ``Graph``.

        Raises:
            RuntimeError: if ``to_erase`` still has users, or belongs to a
                different graph.
        """
        if len(to_erase.users) > 0:
            raise RuntimeError(
                f"Tried to erase Node {to_erase} but it still had {len(to_erase.users)} "
                f"users in the graph: {to_erase.users}!"
            )
        if to_erase.graph != self:
            raise RuntimeError(f"Attempting to remove {to_erase} from wrong graph!")
        if to_erase._erased:
            warnings.warn(f"erase_node({to_erase}) on an already erased node")
            return
        # Run the owning module's erase hooks while the node is still linked
        # into the graph.
        if (
            self.owning_module is not None
            and getattr(self.owning_module, "_erase_node_hooks", None) is not None
        ):
            for f in self.owning_module._erase_node_hooks:
                f(to_erase)
        # Drop it from the find_nodes side table, then unlink it from the list.
        self._find_nodes_lookup_table.remove(to_erase)
        to_erase._remove_from_list()
        to_erase._erased = True  # iterators may retain handles to erased nodes
        self._len -= 1
        # Null out this Node's argument nodes so that the Nodes referred to
        # can update their ``users`` accordingly
        to_erase._update_args_kwargs(
            map_arg(to_erase._args, lambda n: None),
            map_arg(to_erase._kwargs, lambda n: None),
        )
  1067. @compatibility(is_backward_compatible=True)
  1068. def inserting_before(self, n: Optional[Node] = None):
  1069. """Set the point at which create_node and companion methods will insert into the graph.
  1070. When used within a 'with' statement, this will temporary set the insert point and
  1071. then restore it when the with statement exits::
  1072. with g.inserting_before(n):
  1073. ... # inserting before node n
  1074. ... # insert point restored to what it was previously
  1075. g.inserting_before(n) # set the insert point permanently
  1076. Args:
  1077. n (Optional[Node]): The node before which to insert. If None this will insert before
  1078. the beginning of the entire graph.
  1079. Returns:
  1080. A resource manager that will restore the insert point on ``__exit__``.
  1081. """
  1082. if n is None:
  1083. return self.inserting_after(self._root)
  1084. assert n.graph == self, "Node to insert before is not in graph."
  1085. return _InsertPoint(self, n.prepend)
  1086. @compatibility(is_backward_compatible=True)
  1087. def inserting_after(self, n: Optional[Node] = None):
  1088. """Set the point at which create_node and companion methods will insert into the graph.
  1089. When used within a 'with' statement, this will temporary set the insert point and
  1090. then restore it when the with statement exits::
  1091. with g.inserting_after(n):
  1092. ... # inserting after node n
  1093. ... # insert point restored to what it was previously
  1094. g.inserting_after(n) # set the insert point permanently
  1095. Args:
  1096. n (Optional[Node]): The node before which to insert. If None this will insert after
  1097. the beginning of the entire graph.
  1098. Returns:
  1099. A resource manager that will restore the insert point on ``__exit__``.
  1100. """
  1101. if n is None:
  1102. return self.inserting_before(self._root)
  1103. assert n.graph == self, "Node to insert after is not in graph."
  1104. return _InsertPoint(self, n.append)
  1105. @compatibility(is_backward_compatible=True)
  1106. def placeholder(
  1107. self,
  1108. name: str,
  1109. type_expr: Optional[Any] = None,
  1110. default_value: Any = inspect.Signature.empty,
  1111. ) -> Node:
  1112. """
  1113. Insert a ``placeholder`` node into the Graph. A ``placeholder`` represents
  1114. a function input.
  1115. Args:
  1116. name (str): A name for the input value. This corresponds to the name
  1117. of the positional argument to the function this ``Graph`` represents.
  1118. type_expr (Optional[Any]): an optional type annotation representing the
  1119. Python type the output of this node will have. This is needed in some
  1120. cases for proper code generation (e.g. when the function is used
  1121. subsequently in TorchScript compilation).
  1122. default_value (Any): The default value this function argument should take
  1123. on. NOTE: to allow for `None` as a default value, `inspect.Signature.empty`
  1124. should be passed as this argument to specify that the parameter does _not_
  1125. have a default value.
  1126. .. note::
  1127. The same insertion point and type expression rules apply for this method
  1128. as ``Graph.create_node``.
  1129. """
  1130. args = () if default_value is inspect.Signature.empty else (default_value,)
  1131. return self.create_node("placeholder", name, args=args, type_expr=type_expr)
  1132. @compatibility(is_backward_compatible=True)
  1133. def get_attr(self, qualified_name: str, type_expr: Optional[Any] = None) -> Node:
  1134. """
  1135. Insert a ``get_attr`` node into the Graph. A ``get_attr`` ``Node`` represents the
  1136. fetch of an attribute from the ``Module`` hierarchy.
  1137. Args:
  1138. qualified_name (str): the fully-qualified name of the attribute to be retrieved.
  1139. For example, if the traced Module has a submodule named ``foo``, which has a
  1140. submodule named ``bar``, which has an attribute named ``baz``, the qualified
  1141. name ``foo.bar.baz`` should be passed as ``qualified_name``.
  1142. type_expr (Optional[Any]): an optional type annotation representing the
  1143. Python type the output of this node will have.
  1144. Returns:
  1145. The newly-created and inserted ``get_attr`` node.
  1146. .. note::
  1147. The same insertion point and type expression rules apply for this method
  1148. as ``Graph.create_node``.
  1149. """
  1150. def _get_attr_reference_exists(
  1151. mod: torch.nn.Module, qualified_name: str
  1152. ) -> bool:
  1153. module_path, _, name = qualified_name.rpartition(".")
  1154. try:
  1155. submod: torch.nn.Module = mod.get_submodule(module_path)
  1156. except AttributeError:
  1157. warnings.warn(f"Failed to fetch module {module_path}!")
  1158. return False
  1159. if not hasattr(submod, name):
  1160. return False
  1161. res = getattr(submod, name)
  1162. if (
  1163. not isinstance(res, torch.nn.Module)
  1164. and not isinstance(res, torch.nn.Parameter)
  1165. and name not in submod._buffers
  1166. ):
  1167. return False
  1168. return True
  1169. if self.owning_module and not _get_attr_reference_exists(
  1170. self.owning_module, qualified_name
  1171. ):
  1172. warnings.warn(
  1173. "Attempted to insert a get_attr Node with no "
  1174. "underlying reference in the owning "
  1175. "GraphModule! Call "
  1176. "GraphModule.add_submodule to add the "
  1177. "necessary submodule, "
  1178. "GraphModule.add_parameter to add the "
  1179. "necessary Parameter, or "
  1180. "nn.Module.register_buffer to add the "
  1181. "necessary buffer",
  1182. stacklevel=2,
  1183. )
  1184. return self.create_node("get_attr", qualified_name, type_expr=type_expr)
  1185. @compatibility(is_backward_compatible=True)
  1186. def call_module(
  1187. self,
  1188. module_name: str,
  1189. args: Optional[tuple["Argument", ...]] = None,
  1190. kwargs: Optional[dict[str, "Argument"]] = None,
  1191. type_expr: Optional[Any] = None,
  1192. ) -> Node:
  1193. """
  1194. Insert a ``call_module`` ``Node`` into the ``Graph``. A ``call_module`` node
  1195. represents a call to the forward() function of a ``Module`` in the ``Module``
  1196. hierarchy.
  1197. Args:
  1198. module_name (str): The qualified name of the ``Module`` in the ``Module``
  1199. hierarchy to be called. For example, if the traced ``Module`` has a
  1200. submodule named ``foo``, which has a submodule named ``bar``, the
  1201. qualified name ``foo.bar`` should be passed as ``module_name`` to
  1202. call that module.
  1203. args (Optional[Tuple[Argument, ...]]): The positional arguments to be passed
  1204. to the called method. Note that this should *not* include a ``self`` argument.
  1205. kwargs (Optional[Dict[str, Argument]]): The keyword arguments to be passed
  1206. to the called method
  1207. type_expr (Optional[Any]): an optional type annotation representing the
  1208. Python type the output of this node will have.
  1209. Returns:
  1210. The newly-created and inserted ``call_module`` node.
  1211. .. note::
  1212. The same insertion point and type expression rules apply for this method
  1213. as :meth:`Graph.create_node`.
  1214. """
  1215. if self.owning_module and self.owning_module.get_submodule(module_name) is None:
  1216. warnings.warn(
  1217. "Attempted to insert a call_module Node with "
  1218. "no underlying reference in the owning "
  1219. "GraphModule! Call "
  1220. "GraphModule.add_submodule to add the "
  1221. "necessary submodule"
  1222. )
  1223. return self.create_node(
  1224. "call_module", module_name, args, kwargs, type_expr=type_expr
  1225. )
  1226. @compatibility(is_backward_compatible=True)
  1227. def call_method(
  1228. self,
  1229. method_name: str,
  1230. args: Optional[tuple["Argument", ...]] = None,
  1231. kwargs: Optional[dict[str, "Argument"]] = None,
  1232. type_expr: Optional[Any] = None,
  1233. ) -> Node:
  1234. """
  1235. Insert a ``call_method`` ``Node`` into the ``Graph``. A ``call_method`` node
  1236. represents a call to a given method on the 0th element of ``args``.
  1237. Args:
  1238. method_name (str): The name of the method to apply to the self argument.
  1239. For example, if args[0] is a ``Node`` representing a ``Tensor``,
  1240. then to call ``relu()`` on that ``Tensor``, pass ``relu`` to ``method_name``.
  1241. args (Optional[Tuple[Argument, ...]]): The positional arguments to be passed
  1242. to the called method. Note that this *should* include a ``self`` argument.
  1243. kwargs (Optional[Dict[str, Argument]]): The keyword arguments to be passed
  1244. to the called method
  1245. type_expr (Optional[Any]): an optional type annotation representing the
  1246. Python type the output of this node will have.
  1247. Returns:
  1248. The newly created and inserted ``call_method`` node.
  1249. .. note::
  1250. The same insertion point and type expression rules apply for this method
  1251. as :meth:`Graph.create_node`.
  1252. """
  1253. return self.create_node(
  1254. "call_method", method_name, args, kwargs, type_expr=type_expr
  1255. )
  1256. @compatibility(is_backward_compatible=True)
  1257. def call_function(
  1258. self,
  1259. the_function: Callable[..., Any],
  1260. args: Optional[tuple["Argument", ...]] = None,
  1261. kwargs: Optional[dict[str, "Argument"]] = None,
  1262. type_expr: Optional[Any] = None,
  1263. name: Optional[str] = None,
  1264. ) -> Node:
  1265. """
  1266. Insert a ``call_function`` ``Node`` into the ``Graph``. A ``call_function`` node
  1267. represents a call to a Python callable, specified by ``the_function``.
  1268. Args:
  1269. the_function (Callable[..., Any]): The function to be called. Can be any PyTorch
  1270. operator, Python function, or member of the ``builtins`` or ``operator``
  1271. namespaces.
  1272. args (Optional[Tuple[Argument, ...]]): The positional arguments to be passed
  1273. to the called function.
  1274. kwargs (Optional[Dict[str, Argument]]): The keyword arguments to be passed
  1275. to the called function
  1276. type_expr (Optional[Any]): an optional type annotation representing the
  1277. Python type the output of this node will have.
  1278. name (Optional[str]): The name of the node. If not specified, set to None
  1279. Returns:
  1280. The newly created and inserted ``call_function`` node.
  1281. .. note::
  1282. The same insertion point and type expression rules apply for this method
  1283. as :meth:`Graph.create_node`.
  1284. """
  1285. return self.create_node(
  1286. "call_function", the_function, args, kwargs, name=name, type_expr=type_expr
  1287. )
  @compatibility(is_backward_compatible=True)
  def node_copy(
      self, node: Node, arg_transform: Callable[[Node], "Argument"] = lambda x: x
  ) -> Node:
      """
      Copy a node from one graph into another. ``arg_transform`` needs to
      transform arguments from the graph of ``node`` to the graph of ``self``.
      Example::

          # Copying all the nodes in `g` into `new_graph`
          g: torch.fx.Graph = ...
          new_graph = torch.fx.Graph()
          value_remap = {}
          for node in g.nodes:
              value_remap[node] = new_graph.node_copy(node, lambda n: value_remap[n])

      Args:
          node (Node): The node to copy into ``self``.
          arg_transform (Callable[[Node], Argument]): A function that transforms
              ``Node`` arguments in node's ``args`` and ``kwargs`` into the
              equivalent argument in ``self``. In the simplest case, this should
              retrieve a value out of a table mapping Nodes in the original
              graph to ``self``.

      Returns:
          The new node created in ``self``. It has the same ``op``, ``target``
          and ``type`` as ``node``. It is created with ``node.name`` as its
          requested name, and its ``meta`` is a shallow copy of ``node.meta``.
      """
      # Remap every Node referenced from args/kwargs into this graph.
      args = map_arg(node.args, arg_transform)
      kwargs = map_arg(node.kwargs, arg_transform)
      # Narrow types: the remapped containers are expected to keep their
      # tuple/dict shapes.
      assert isinstance(args, tuple)
      assert isinstance(kwargs, dict)
      result_node = self.create_node(
          node.op, node.target, args, kwargs, node.name, node.type
      )
      # Shallow copy: the new node gets its own meta dict, but the values
      # inside it are shared with the source node.
      result_node.meta = copy.copy(node.meta)
      return result_node
  @compatibility(is_backward_compatible=True)
  def output(self, result: "Argument", type_expr: Optional[Any] = None):
      """
      Insert an ``output`` ``Node`` into the ``Graph``.

      An ``output`` node represents a ``return`` statement in Python code;
      ``result`` is the value that should be returned.

      Args:
          result (Argument): The value to be returned. It becomes the single
              positional argument of the new node.
          type_expr (Optional[Any]): an optional type annotation representing
              the Python type the output of this node will have.

      Returns:
          The newly created and inserted ``output`` node.

      .. note::
          The same insertion point and type expression rules apply for this
          method as :meth:`Graph.create_node`.
      """
      node_args = (result,)
      return self.create_node(
          op="output", target="output", args=node_args, type_expr=type_expr
      )
  def _target_to_str(self, target: Optional[Target]) -> str:
      """Derive a snake_case base name for a node from its ``target``.

      A callable target contributes its ``__name__``. Otherwise ``target`` must
      be a ``str``. When ``_is_magic`` reports it as a magic name, its first and
      last two characters are sliced off. Either way, the result is passed
      through ``_snake_case``.
      """
      if not callable(target):
          assert isinstance(target, str)
          stem = target[2:-2] if _is_magic(target) else target
          return _snake_case(stem)
      return _snake_case(target.__name__)
  @compatibility(is_backward_compatible=True)
  def python_code(
      self,
      root_module: str,
      *,
      verbose: bool = False,
      include_stride: bool = False,
      include_device: bool = False,
      colored: bool = False,
      expanded_def: bool = False,
  ) -> PythonCode:
      """
      Turn this ``Graph`` into valid Python code.

      Args:
          root_module (str): The name of the root module on which to look-up
              qualified name targets. This is usually 'self'.
          verbose, include_stride, include_device, colored, expanded_def (bool):
              Keyword-only code-generation flags. They are forwarded unchanged
              to ``self._python_code``.

      Returns:
          A PythonCode object, consisting of two fields:
              src: the Python source code representing the object
              globals: a dictionary of global names in `src` -> the objects that they reference.
      """
      # NOTE: [Graph Namespaces]
      #
      # Generated source contains two kinds of symbols: locals, defined by the
      # output of a node in the Graph, and globals, which reference external
      # objects such as functions or types.
      #
      # Every printed name must be unique (to avoid shadowing bugs) and stable
      # (an object is always referenced by the same name). To guarantee this,
      # a fresh namespace is created just for this source, and every printed
      # name comes from it.
      #
      # node.name cannot simply be reused, because it was generated within
      # `self._graph_namespace`. Uniqueness has to hold across both locals
      # (node.name) *and* globals, so all identifiers go into one new namespace.
      namespace = _Namespace()

      # Node.__repr__ defers to node._repr_fn, so redirecting that hook makes
      # repr() of any Node (or of containers such as Tuple[Node, Node]) yield
      # a valid name inside `namespace`.
      def _name_in_namespace(n: Node):
          return namespace.create_name(n.name, n)

      @contextmanager
      def override_node_repr(graph: Graph):
          # Remember every node's current hook so it can be restored even if
          # code generation raises.
          saved_repr_fns = {node: node._repr_fn for node in graph.nodes}
          for node in graph.nodes:
              node._repr_fn = _name_in_namespace
          try:
              yield None
          finally:
              for node in graph.nodes:
                  node._repr_fn = saved_repr_fns[node]

      with override_node_repr(self):
          return self._python_code(
              root_module,
              namespace,
              verbose=verbose,
              include_stride=include_stride,
              include_device=include_device,
              colored=colored,
              expanded_def=expanded_def,
          )
  def _python_code(
      self,
      root_module: str,
      namespace: _Namespace,
      *,
      verbose: bool = False,
      include_stride: bool = False,
      include_device: bool = False,
      colored: bool = False,
      expanded_def: bool = False,
  ) -> PythonCode:
      """Generate code for ``self.nodes`` using a caller-supplied ``namespace``.

      This is a thin forwarder to ``self._codegen._gen_python_code``. Every
      argument is passed through unchanged. ``python_code`` calls it while node
      reprs are redirected into ``namespace``.
      """
      codegen_flags = dict(
          verbose=verbose,
          include_stride=include_stride,
          include_device=include_device,
          colored=colored,
          expanded_def=expanded_def,
      )
      return self._codegen._gen_python_code(
          self.nodes, root_module, namespace, **codegen_flags
      )
  1437. def __str__(self) -> str:
  1438. """
  1439. Return a human-readable (not machine-readable) string representation
  1440. of this Graph
  1441. """
  1442. placeholder_names: list[str] = []
  1443. # This is a one-element array just so ``format_node`` can modify the closed
  1444. # over value
  1445. maybe_return_typename: list[str] = [""]
  1446. node_strs = [node.format_node(placeholder_names) for node in self.nodes]
  1447. param_str = ", ".join(placeholder_names)
  1448. s = f"graph({param_str}){maybe_return_typename[0]}:"
  1449. for node_str in node_strs:
  1450. if node_str:
  1451. s += "\n " + node_str
  1452. return s
  @compatibility(is_backward_compatible=True)
  def print_tabular(self):
      """
      Print the graph's nodes as a table to stdout.

      Each row shows one node's opcode, name, target, args and kwargs. This
      API requires the third-party ``tabulate`` module. If it is missing, a
      hint is printed and the ``ImportError`` is re-raised.
      """
      try:
          from tabulate import tabulate
      except ImportError:
          print(
              "`print_tabular` relies on the library `tabulate`, "
              "which could not be found on this machine. Run `pip "
              "install tabulate` to install the library."
          )
          raise
      column_names = ["opcode", "name", "target", "args", "kwargs"]
      rows = [[n.op, n.name, n.target, n.args, n.kwargs] for n in self.nodes]
      print(tabulate(rows, headers=column_names))
  1473. @compatibility(is_backward_compatible=True)
  1474. def lint(self):
  1475. """
  1476. Runs various checks on this Graph to make sure it is well-formed. In
  1477. particular:
  1478. - Checks Nodes have correct ownership (owned by this graph)
  1479. - Checks Nodes appear in topological order
  1480. - If this Graph has an owning GraphModule, checks that targets
  1481. exist in that GraphModule
  1482. """
  1483. # Check topo order
  1484. def check_arg(arg: Node, n: Optional[Node] = None) -> None:
  1485. context_str = f" of Node '{n}' " if n else " "
  1486. if arg.graph is not self:
  1487. raise RuntimeError(
  1488. f"Argument '{arg}'{context_str}does not belong to this Graph, "
  1489. f"but was used as an argument! If you are copying nodes from another graph, make "
  1490. f"sure to use ``arg_transform`` on node_copy() to remap values\n{self}"
  1491. )
  1492. if arg not in seen_values:
  1493. raise RuntimeError(
  1494. f"Argument '{arg}'{context_str}was used before it has been "
  1495. f"defined! Please check that Nodes in the graph are topologically ordered\n{self}"
  1496. )
  1497. seen_names: set[str] = set()
  1498. seen_values: set[Node] = set()
  1499. for node in self.nodes:
  1500. if node.op not in _legal_ops:
  1501. raise RuntimeError(f"Node {node} had unknown opcode {node.op}!")
  1502. if node.graph is not self:
  1503. raise RuntimeError(f"Node '{node}' does not belong to this Graph!")
  1504. if node not in self._find_nodes_lookup_table:
  1505. raise RuntimeError(f"Node '{node}' is not added to the side table")
  1506. for arg in node._input_nodes:
  1507. check_arg(arg, node)
  1508. seen_values.add(node)
  1509. if node.name in seen_names:
  1510. raise RuntimeError(f"Node redefined name {node.name}!")
  1511. seen_names.add(node.name)
  1512. # Check targets are legit
  1513. if self.owning_module:
  1514. for node in self.nodes:
  1515. if node.op == "call_function":
  1516. if not callable(node.target):
  1517. raise ValueError(
  1518. f"Node {node} target {node.target} has type {torch.typename(node.target)} but "
  1519. "a Callable is expected"
  1520. )
  1521. else:
  1522. if not isinstance(node.target, str):
  1523. raise ValueError(
  1524. f"Node {node} target {node.target} has type {torch.typename(node.target)} but "
  1525. "a str is expected"
  1526. )
  1527. if node.op in ["get_attr", "call_module"]:
  1528. target_atoms = node.target.split(".")
  1529. m_itr = self.owning_module
  1530. for i, atom in enumerate(target_atoms):
  1531. new_m_itr = getattr(m_itr, atom, None)
  1532. seen_qualname = ".".join(target_atoms[:i])
  1533. if new_m_itr is None:
  1534. raise RuntimeError(
  1535. f"Node {node} target {node.target} references nonexistent attribute "
  1536. f"{atom} of {seen_qualname}"
  1537. )
  1538. if node.op == "call_module" and not isinstance(
  1539. new_m_itr, torch.nn.Module
  1540. ):
  1541. raise RuntimeError(
  1542. f"Node {node} target {node.target} {atom} of {seen_qualname} does "
  1543. "not reference an nn.Module"
  1544. )
  1545. m_itr = new_m_itr
  1546. @compatibility(is_backward_compatible=True)
  1547. def eliminate_dead_code(
  1548. self, is_impure_node: Optional[Callable[[Node], bool]] = None
  1549. ) -> bool:
  1550. """
  1551. Remove all dead code from the graph, based on each node's number of
  1552. users, and whether the nodes have any side effects. The graph must be
  1553. topologically sorted before calling.
  1554. Args:
  1555. is_impure_node (Optional[Callable[[Node], bool]]): A function that returns
  1556. whether a node is impure. If this is None, then the default behavior is to
  1557. use Node.is_impure.
  1558. Returns:
  1559. bool: Whether the graph was changed as a result of the pass.
  1560. Example:
  1561. Before dead code is eliminated, `a` from `a = x + 1` below has no users
  1562. and thus can be eliminated from the graph without having an effect.
  1563. .. code-block:: python
  1564. def forward(self, x):
  1565. a = x + 1
  1566. return x + self.attr_1
  1567. After dead code is eliminated, `a = x + 1` has been removed, and the rest
  1568. of `forward` remains.
  1569. .. code-block:: python
  1570. def forward(self, x):
  1571. return x + self.attr_1
  1572. .. warning::
  1573. Dead code elimination has some heuristics to avoid removing
  1574. side-effectful nodes (see Node.is_impure) but in general coverage
  1575. is very bad, so you should assume that this method is not sound
  1576. to call unless you know that your FX graph consists entirely
  1577. of functional operations or you supply your own custom
  1578. function for detecting side-effectful nodes.
  1579. """
  1580. from torch.utils._ordered_set import OrderedSet
  1581. # Lint the graph first to make sure its topologically sorted, otherwise
  1582. # DCE below will not behave as expected.
  1583. self.lint()
  1584. impure_random = True
  1585. if torch._guards.TracingContext.try_get():
  1586. impure_random = torch._inductor.config.fallback_random
  1587. def has_side_effect(node):
  1588. if is_impure_node is not None:
  1589. return is_impure_node(node)
  1590. return node.is_impure(impure_random)
  1591. # Reverse iterate so that when we remove a node, any nodes used as an
  1592. # input to that node have an updated user count that no longer reflects
  1593. # the removed node.
  1594. changed = False
  1595. for node in reversed(self.nodes):
  1596. if not has_side_effect(node) and len(node.users) == 0:
  1597. self.erase_node(node)
  1598. changed = True
  1599. # Call DCE on the subgraphs
  1600. if self.owning_module is not None:
  1601. subgraph_names = OrderedSet(
  1602. x.target for x in self.find_nodes(op="get_attr")
  1603. )
  1604. for child_name, child_module in self.owning_module.named_children():
  1605. # Sometimes an owning_module can have unused children. Skip them
  1606. # by checking them from get_attr node targets.
  1607. if child_name in subgraph_names and isinstance(
  1608. child_module, torch.fx.GraphModule
  1609. ):
  1610. changed |= child_module.graph.eliminate_dead_code()
  1611. child_module.recompile()
  1612. return changed
  @compatibility(is_backward_compatible=False)
  def set_codegen(self, codegen: CodeGen):
      """Install ``codegen`` as this graph's code generator.

      The object is stored as ``self._codegen``. Other methods use it for
      ``_gen_python_code`` and for the ``_body_transformer`` hook.
      """
      self._codegen = codegen
  @compatibility(is_backward_compatible=False)
  def on_generate_code(
      self,
      make_transformer: Callable[[Optional[TransformCodeFunc]], TransformCodeFunc],
  ):
      """Register a transformer function when python code is generated.

      The new transformer is installed immediately, when this method is
      called. The returned object is a context manager whose exit restores
      the transformer that was registered before.

      Args:
          make_transformer (Callable[[Optional[TransformCodeFunc]], TransformCodeFunc]):
              a function that returns a code transformer to be registered.
              ``on_generate_code`` calls it once to obtain the code transformer.
              It receives the currently registered code transformer (or None
              if nothing is registered) as its input, so a new transformer can
              wrap the existing one instead of overwriting it. This makes it
              possible to chain code transformers together.

      Returns:
          a context manager that, when used in a `with` statement, automatically
          restores the previously registered code transformer on exit.

      Example:

      .. code-block:: python

          gm: fx.GraphModule = ...

          # This is a code transformer we want to register. This code
          # transformer prepends a pdb import and trace statement at the very
          # beginning of the generated torch.fx code to allow for manual
          # debugging with the PDB library.
          def insert_pdb(body):
              return ["import pdb; pdb.set_trace()\\n", *body]

          # Registers `insert_pdb`, and overwrites the current registered
          # code transformer (given by `_` to the lambda):
          gm.graph.on_generate_code(lambda _: insert_pdb)

          # Or alternatively, registers a code transformer which first
          # runs `body` through existing registered transformer, then
          # through `insert_pdb`:
          gm.graph.on_generate_code(
              lambda current_trans: (
                  lambda body: insert_pdb(
                      current_trans(body) if current_trans else body
                  )
              )
          )

          gm.recompile()
          gm(*inputs)  # drops into pdb

      This function can also be used as a context manager, which
      automatically restores the previously registered code transformer:

      .. code-block:: python

          # ... continue from previous example

          with gm.graph.on_generate_code(lambda _: insert_pdb):
              # do more stuff with `gm`...
              gm.recompile()
              gm(*inputs)  # drops into pdb

          # now previous code transformer is restored (but `gm`'s code with pdb
          # remains - that means you can run `gm` with pdb here too, until you
          # run next `recompile()`).
      """
      previous_transformer = self._codegen._body_transformer
      self._codegen._body_transformer = make_transformer(previous_transformer)

      @contextlib.contextmanager
      def _restore_previous_transformer():
          try:
              yield
          finally:
              # Write back through self._codegen as it is at exit time.
              self._codegen._body_transformer = previous_transformer

      return _restore_previous_transformer()
  1679. @contextmanager
  1680. def _override_sym_repr(
  1681. override: Callable[["torch.types.PySymType"], str],
  1682. ) -> Iterator[None]:
  1683. tmp = CodeGen._sym_repr
  1684. try:
  1685. CodeGen._sym_repr = override
  1686. yield
  1687. finally:
  1688. CodeGen._sym_repr = tmp
  1689. def _identity(x):
  1690. return x
  1691. def _make_color_fn(code):
  1692. def f(s):
  1693. reset = "\033[0m"
  1694. return f"{code}{s}{reset}"
  1695. return f
  # ANSI escape prefixes by color name; the "dim_*" entries combine the dim
  # attribute with a foreground color.
  _color_codes = dict(
      yellow="\033[33m",
      cyan="\033[36m",
      green="\033[32m",
      blue="\033[34m",
      red="\033[31m",
      dim="\033[2m",
      dim_blue="\033[2m\033[34m",
      dim_green="\033[2m\033[32m",
  )
  # One wrapping function per color name, in the same order as _color_codes.
  _color_fns = {color: _make_color_fn(prefix) for color, prefix in _color_codes.items()}
  # Matches "# COUNTER: <digits>" and captures the digits in group 1.
  _counter_regexp = re.compile(r"# COUNTER: (\d+)")
  # str.format templates keyed by operator name; each "{}" is an operand
  # slot. Per the variable name, these are presumably the operators that
  # also have reflected forms.
  reflectable_magic_methods = dict(
      add="{} + {}",
      sub="{} - {}",
      mul="{} * {}",
      floordiv="{} // {}",
      truediv="{} / {}",
      div="{} / {}",
      mod="{} % {}",
      pow="{} ** {}",
      lshift="{} << {}",
      rshift="{} >> {}",
      and_="{} & {}",
      or_="{} | {}",
      xor="{} ^ {}",
      getitem="{}[{}]",
      matmul="{} @ {}",
  )
  # Comparison and unary templates first, followed by every entry of
  # reflectable_magic_methods in its original order.
  magic_methods = dict(
      eq="{} == {}",
      ne="{} != {}",
      lt="{} < {}",
      gt="{} > {}",
      le="{} <= {}",
      ge="{} >= {}",
      pos="+{}",
      neg="-{}",
      invert="~{}",
      **reflectable_magic_methods,
  )
  # Augmented-assignment statement templates keyed by operator name; each
  # "{}" is an operand slot. "setitem" takes three slots (target, index, value).
  inplace_methods = dict(
      iadd="{} += {}",
      iand="{} &= {}",
      ifloordiv="{} //= {}",
      ilshift="{} <<= {}",
      imod="{} %= {}",
      imul="{} *= {}",
      imatmul="{} @= {}",
      ior="{} |= {}",
      ipow="{} **= {}",
      irshift="{} >>= {}",
      isub="{} -= {}",
      itruediv="{} /= {}",
      ixor="{} ^= {}",
      setitem="{}[{}] = {}",
  )