_symbolic_trace.py 50 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156115711581159116011611162116311641165116611671168116911701171117211731174117511761177117811791180118111821183118411851186118711881189119011911192119311941195119611971198119912001201120212031204120512061207120812091210121112121213121412151216121712181219122012211222122312241225122612271228122912301231123212331234123512361237123812391240124112421243124412451246124712481249125012511252125312541255125612571258125912601261126212631264126512661267126812691270127112721273127412751276127
71278127912801281128212831284128512861287128812891290129112921293129412951296129712981299130013011302130313041305130613071308130913101311131213131314131513161317131813191320132113221323132413251326132713281329133013311332133313341335133613371338133913401341134213431344134513461347134813491350135113521353
  1. # mypy: allow-untyped-defs
  2. import builtins
  3. import collections
  4. import contextlib
  5. import copy
  6. import functools
  7. import inspect
  8. import logging
  9. import math
  10. import os
  11. import warnings
  12. from itertools import chain
  13. from types import CodeType, FunctionType, ModuleType
  14. from typing import Any, Callable, get_args, NamedTuple, Optional, Union
  15. from typing_extensions import TypeAlias
  16. import torch
  17. import torch.utils._pytree as pytree
  18. from torch._C import ScriptObject # type: ignore[attr-defined]
  19. from torch._library.fake_class_registry import FakeScriptObject
  20. from ._compatibility import compatibility
  21. from ._lazy_graph_module import _make_graph_module
  22. from .graph import _PyTreeCodeGen, _PyTreeInfo, Graph
  23. from .graph_module import GraphModule
  24. from .node import Argument, base_types, map_aggregate
  25. from .proxy import ParameterProxy, Proxy, Scope, ScopeContextManager, TracerBase
# Module-level logger; used by ``is_fx_tracing_warning``.
log = logging.getLogger(__name__)
# Code-object flag bits marking a function that takes ``*args`` and/or
# ``**kwargs``. ``create_args_for_root`` tests them and ``_patch_function``
# clears them.
HAS_VARSTUFF = inspect.CO_VARARGS | inspect.CO_VARKEYWORDS
# These need to run in global scope to handle nested calls correctly
# (they capture the original ``nn.Module`` methods at import time, presumably so
# they stay reachable while tracing patches ``nn.Module`` -- TODO confirm).
_orig_module_call: Callable = torch.nn.Module.__call__
_orig_module_getattr: Callable = torch.nn.Module.__getattr__
# Insertion-ordered set (a dict whose values are all ``None``) of classes created
# with ``ProxyableClassMeta``. ``Tracer.create_arg`` interns instances of these
# classes as constant attributes.
_proxyable_classes: dict[type, None] = {}
# Set to True on entry to ``Tracer.trace``; read by ``is_fx_tracing`` and
# ``is_fx_symbolic_tracing``.
_is_fx_tracing_flag = False
# Value types that ``Tracer.create_arg`` stores as attributes of the root module
# and references through ``get_attr`` nodes rather than embedding in the graph.
_ConstantAttributeType: TypeAlias = Union[
    torch.Tensor, torch.ScriptObject, FakeScriptObject, pytree.TreeSpec
]
# The same member types as a plain tuple, suitable for ``isinstance`` checks.
_constant_attribute_types = get_args(_ConstantAttributeType)
  37. # We only want to print this once to avoid flooding logs
  38. @functools.lru_cache
  39. def is_fx_tracing_warning():
  40. log.warning(
  41. "is_fx_tracing will return true for both fx.symbolic_trace and "
  42. "torch.export. Please use "
  43. "is_fx_tracing_symbolic_tracing() for specifically fx.symbolic_trace "
  44. "or torch.compiler.is_compiling() for specifically torch.export/compile."
  45. )
  46. def is_fx_tracing():
  47. is_fx_tracing_warning()
  48. return _is_fx_tracing_flag
  49. def is_fx_symbolic_tracing():
  50. return _is_fx_tracing_flag and not torch.compiler.is_compiling()
  51. @compatibility(is_backward_compatible=True)
  52. class ProxyableClassMeta(type):
  53. """
  54. ProxyableClassMeta allows you to make construction of a given Python class
  55. symbolically traceable. For example::
  56. import torch
  57. import torch.fx
  58. class TensorPair(metaclass=torch.fx.ProxyableClassMeta):
  59. def __init__(self, left, right):
  60. self.left, self.right = left, right
  61. def add(self, other):
  62. l = self.left + other.left
  63. r = self.right + other.right
  64. return TensorPair(l, r)
  65. def mul(self, other):
  66. l = self.left * other.left
  67. r = self.right * other.right
  68. return TensorPair(l, r)
  69. def use_tensor_pair_ctor(x: TensorPair, y: torch.Tensor):
  70. s = x.add(TensorPair(y, y))
  71. return s.mul(x)
  72. x = TensorPair(torch.randn(5, 3), torch.randn(5, 3))
  73. y = torch.randn(5, 3)
  74. ref_out = use_tensor_pair_ctor(x, y)
  75. traced = torch.fx.symbolic_trace(use_tensor_pair_ctor)
  76. print(traced.code)
  77. '''
  78. def forward(self, x : __main___TensorPair, y : torch.Tensor):
  79. tensor_pair = __main___TensorPair(y, y); y = None
  80. add = x.add(tensor_pair); tensor_pair = None
  81. mul = add.mul(x); add = x = None
  82. return mul
  83. '''
  84. From this example, we can see that construction of a class (``TensorPair``)
  85. defined with ``ProxyableClassMeta`` as metaclass can be recorded in symbolic
  86. tracing.
  87. """
  88. def __init__(cls, name, bases, attrs):
  89. _proxyable_classes.setdefault(cls)
  90. super().__init__(name, bases, attrs)
  91. def __call__(cls, *args, **kwargs):
  92. instance = cls.__new__(cls) # type: ignore[call-overload]
  93. if not is_fx_tracing():
  94. cls.__init__(instance, *args, **kwargs) # type: ignore[misc]
  95. return instance
  96. found_proxies = []
  97. def check_proxy(a):
  98. if isinstance(a, Proxy):
  99. found_proxies.append(a)
  100. map_aggregate(args, check_proxy)
  101. map_aggregate(kwargs, check_proxy)
  102. if len(found_proxies) != 0:
  103. tracer = found_proxies[0].tracer
  104. return tracer.create_proxy("call_function", cls, args, kwargs)
  105. else:
  106. cls.__init__(instance, *args, **kwargs) # type: ignore[misc]
  107. return instance
  108. def _patch_function(fn: FunctionType, nargs: int) -> FunctionType:
  109. co = fn.__code__
  110. co_flags = co.co_flags & ~HAS_VARSTUFF
  111. co_args: tuple
  112. if hasattr(co, "co_qualname"):
  113. # Python-3.11+ code signature
  114. co_args = (
  115. nargs,
  116. 0,
  117. 0,
  118. co.co_nlocals,
  119. co.co_stacksize,
  120. co_flags,
  121. co.co_code,
  122. co.co_consts,
  123. co.co_names,
  124. co.co_varnames,
  125. co.co_filename,
  126. co.co_name,
  127. co.co_qualname, # type: ignore[attr-defined]
  128. co.co_firstlineno,
  129. co.co_linetable,
  130. co.co_exceptiontable, # type: ignore[attr-defined]
  131. co.co_freevars,
  132. co.co_cellvars,
  133. )
  134. elif hasattr(co, "co_posonlyargcount"):
  135. co_args = (
  136. nargs,
  137. 0,
  138. 0,
  139. co.co_nlocals,
  140. co.co_stacksize,
  141. co_flags,
  142. co.co_code,
  143. co.co_consts,
  144. co.co_names,
  145. co.co_varnames,
  146. co.co_filename,
  147. co.co_name,
  148. co.co_firstlineno,
  149. co.co_lnotab,
  150. co.co_freevars,
  151. co.co_cellvars,
  152. )
  153. else:
  154. co_args = (
  155. nargs,
  156. 0,
  157. co.co_nlocals,
  158. co.co_stacksize,
  159. co_flags,
  160. co.co_code,
  161. co.co_consts,
  162. co.co_names,
  163. co.co_varnames,
  164. co.co_filename,
  165. co.co_name,
  166. co.co_firstlineno,
  167. co.co_lnotab,
  168. co.co_freevars,
  169. co.co_cellvars,
  170. )
  171. new_code = CodeType(*co_args) # type: ignore[arg-type]
  172. return FunctionType(
  173. new_code, fn.__globals__, fn.__name__, fn.__defaults__, fn.__closure__
  174. )
  175. # we need to insert placeholder nodes for *args and **kwargs
  176. # we can't call this function normally, otherwise it would try to unpack them
  177. # instead, let's make python think that args and kwargs are normal variables
  178. @compatibility(is_backward_compatible=False)
  179. class PHBase:
  180. """
  181. Object representing an input placeholder to `concrete_args`
  182. """
  183. def __repr__(self):
  184. return "PH"
  185. PH = PHBase()
  186. @compatibility(is_backward_compatible=False)
  187. class PHWithMeta(PHBase):
  188. """
  189. Object representing an input placeholder to `concrete_args`
  190. """
  191. def __init__(self, ph_key: Optional[str] = None):
  192. super().__init__()
  193. # Provide a hey for user to identify placeholder node during analysis
  194. self.ph_key = ph_key
  195. def _transfer_attrs(fr, to):
  196. for attr_name in dir(fr):
  197. attr_val = getattr(fr, attr_name)
  198. if (
  199. not callable(attr_val)
  200. and not attr_name.startswith("__")
  201. and not hasattr(to, attr_name)
  202. ):
  203. setattr(to, attr_name, attr_val)
  204. @compatibility(is_backward_compatible=True)
  205. class Tracer(TracerBase):
  206. # Reference: https://github.com/pytorch/pytorch/issues/54354
  207. # The first line of this docstring overrides the one Sphinx generates for the
  208. # documentation. We need it so that Sphinx doesn't leak `math`s path from the
  209. # build environment (e.g. `<module 'math' from '/leaked/path').
  210. """Tracer(autowrap_modules=(math,), autowrap_functions=())
  211. ``Tracer`` is the class that implements the symbolic tracing functionality
  212. of ``torch.fx.symbolic_trace``. A call to ``symbolic_trace(m)`` is equivalent
  213. to ``Tracer().trace(m)``.
  214. Tracer can be subclassed to override various behaviors of the tracing
  215. process. The different behaviors that can be overridden are described
  216. in the docstrings of the methods on this class.
  217. """
  218. # Not checking BC on this API because the default value for `autowrap_modules`
  219. # includes the local filepath to the `math` module, which would jitter
  220. # across machines.
  221. @compatibility(is_backward_compatible=True)
  222. def __init__(
  223. self,
  224. autowrap_modules: tuple[ModuleType] = (math,),
  225. autowrap_functions: tuple[Callable, ...] = (),
  226. param_shapes_constant: bool = False,
  227. ) -> None:
  228. # This method's signature is overridden by the first line of this class'
  229. # docstring. If this method's signature is modified, the signature that
  230. # overrides it also should be modified accordingly.
  231. """
  232. Construct a Tracer object.
  233. Args:
  234. autowrap_modules (Tuple[ModuleType]): defaults to `(math, )`,
  235. Python modules whose functions should be wrapped automatically
  236. without needing to use fx.wrap(). Backward-compatibility for
  237. this parameter is guaranteed.
  238. autowrap_functions (Tuple[Callable, ...]): defaults to `()`,
  239. Python functions that should be wrapped automatically without
  240. needing to use fx.wrap(). Backward compatibility for this
  241. parameter is guaranteed.
  242. param_shapes_constant (bool): When this flag is set, calls to shape,
  243. size and a few other shape like attributes of a module's parameter
  244. will be evaluated directly, rather than returning a new Proxy value
  245. for an attribute access. Backward compatibility for this parameter
  246. is guaranteed.
  247. """
  248. super().__init__()
  249. # Functions we will eagerly wrap when we see them while tracing
  250. # this captures both `math.sqrt()` and `from math import sqrt` automatically
  251. self._autowrap_function_ids: set[int] = {
  252. id(value)
  253. for name, value in chain.from_iterable(
  254. m.__dict__.items() for m in autowrap_modules
  255. )
  256. if not name.startswith("_") and callable(value)
  257. }
  258. self._autowrap_function_ids.update({id(f) for f in autowrap_functions})
  259. # Python modules to apply autowrap to at the start, in addition to
  260. # modules we see while tracing
  261. self._autowrap_search: list[ModuleType] = list(autowrap_modules)
  262. self.param_shapes_constant = param_shapes_constant
  263. self.submodule_paths: Optional[dict[torch.nn.Module, str]] = None
  264. self.root_module_name: str = ""
  265. # Maps the containing module's name to the operator name
  266. self.scope = Scope("", None)
  267. # Records the module call stack
  268. self.module_stack = collections.OrderedDict()
  269. self.num_calls: dict[str, int] = {}
  270. # Mapping of node name to module scope
  271. self.node_name_to_scope: dict[str, tuple[str, type]] = {}
    # Maps a name prefix to the next numeric suffix get_fresh_qualname will try
    # for it. This class-level default is shared by every instance that does not
    # assign its own ``_qualname_counter``.
    _qualname_counter: dict[str, int] = collections.defaultdict(int)
  273. @compatibility(is_backward_compatible=True)
  274. def get_fresh_qualname(self, prefix: str) -> str:
  275. """
  276. Gets a fresh name for a prefix and returns it. This function ensures
  277. that it will not clash with an existing attribute on the graph.
  278. """
  279. # The idea here is that if the module doesn't have this prefix at all we
  280. # should reset the counter to start from the beginning
  281. # It's a ... little bit hacky (doesn't cover all cases) but the precise
  282. # naming of the prefixes isn't a correctness issue, just a niceness
  283. # issue
  284. qualname = f"{prefix}0"
  285. if not hasattr(self.root, qualname):
  286. self._qualname_counter[prefix] = 0
  287. return qualname
  288. i = self._qualname_counter[prefix]
  289. while True:
  290. qualname = f"{prefix}{i}"
  291. i += 1
  292. if not hasattr(self.root, qualname):
  293. break
  294. self._qualname_counter[prefix] = i
  295. return qualname
    @compatibility(is_backward_compatible=True)
    def create_arg(self, a: Any) -> "Argument":
        """
        A method to specify the behavior of tracing when preparing values to
        be used as arguments to nodes in the ``Graph``.

        By default, the behavior includes:

        #. Iterate through collection types (e.g. tuple, list, dict) and recursively
           call ``create_arg`` on the elements.
        #. Given a Proxy object, return a reference to the underlying IR ``Node``
        #. Given a non-Proxy Tensor object, emit IR for various cases:

            * For a Parameter, emit a ``get_attr`` node referring to that Parameter
            * For a non-Parameter Tensor, store the Tensor in a special attribute
              of the root module and emit a ``get_attr`` referring to that attribute.

        This method can be overridden to support more types.

        The cases handled here, checked in order before deferring to the
        base-class ``create_arg``:

        * ``nn.Parameter``: ``get_attr`` to its qualified name in ``self.root``.
        * A tensor that is a buffer of ``self.root``: ``get_attr`` to the buffer.
        * An ``nn.Module`` that is a submodule of ``self.root``: ``get_attr``.
        * A NamedTuple instance: a ``call_function`` node that calls the
          NamedTuple's class on the converted fields.
        * Tensor / ScriptObject / FakeScriptObject / TreeSpec constants:
          ``get_attr`` to the name recorded in ``self.tensor_attrs``. If no name
          is recorded, the value is stored on ``self.root`` under a fresh name
          first.
        * Instances of classes registered by ``ProxyableClassMeta``: stored on
          ``self.root`` under a fresh name and referenced via ``get_attr``.

        Args:

            a (Any): The value to be emitted as an ``Argument`` in the ``Graph``.


        Returns:

            The value ``a`` converted into the appropriate ``Argument``

        Raises:

            NameError: if ``a`` is an ``nn.Parameter`` that is not a parameter of
                ``self.root``.
        """
        # The base tracer is used to construct Graphs when there is no associated
        # module hierarchy, so it can never create parameter references.
        # The default tracer adds the ability to refer to parameters when
        # tracing modules.
        if isinstance(a, torch.nn.Parameter):
            for n, p in self.root.named_parameters():
                if a is p:
                    return self.create_node("get_attr", n, (), {})
            raise NameError("parameter is not a member of this module")
        elif isinstance(a, torch.Tensor):
            # Buffers are matched by identity; a tensor that is not a buffer
            # falls through to the constant handling below.
            for n_, p_ in self.root.named_buffers():
                if a is p_:
                    return self.create_node("get_attr", n_, (), {})
        elif isinstance(a, torch.nn.Module):
            for n_, p_ in self.root.named_modules():
                if a is p_:
                    return self.create_node("get_attr", n_, (), {})
        # For NamedTuple instances that appear literally as args, we emit
        # a node to construct the NamedTuple and use that Node as the argument.
        # (A tuple exposing ``_fields`` is treated as a NamedTuple.)
        if isinstance(a, tuple) and hasattr(a, "_fields"):
            args = tuple(self.create_arg(elem) for elem in a)
            return self.create_node("call_function", a.__class__, args, {})

        # Tensors do not have a reliable string repr() from which they can be
        # constructed (and we probably don't want to rely on that, either), so
        # for any constant Tensor values we encounter, first search for if they
        # are an attribute of some module in the module hierarchy. If so, emit
        # a get_attr to retrieve that tensor. Otherwise, we'll store away the
        # tensor value into a special attribute on the Module s.t. we can
        # retrieve it with a get_attr.
        if isinstance(a, _constant_attribute_types):
            # self.tensor_attrs maps constant values already reachable as
            # attributes in the module hierarchy to their dotted paths (it is
            # populated in ``trace``) and also caches names stored below.
            qualname: Optional[str] = self.tensor_attrs.get(a)

            # Tensor was not found in the Module hierarchy, stow it away in a
            # special attribute and set the qualname to refer to that
            if not qualname:
                if isinstance(a, torch.Tensor):
                    base_name = "_tensor_constant"
                elif isinstance(a, (FakeScriptObject, ScriptObject)):
                    base_name = "_torchbind_obj"
                elif isinstance(a, pytree.TreeSpec):
                    base_name = "_tree_spec_constant"
                else:
                    # Defensive: the isinstance check above should make this
                    # unreachable.
                    raise RuntimeError(
                        f"cannot create constant arg for {a} of type {type(a)}."
                    )
                qualname = self.get_fresh_qualname(base_name)
                assert isinstance(qualname, str)
                # Remember the name so later uses of the same value share it.
                self.tensor_attrs[a] = qualname
                setattr(self.root, qualname, a)

            return self.create_node("get_attr", qualname, (), {})

        if type(a) in _proxyable_classes:
            # This is an instance of a proxyable class for which we did not
            # witness its construction. Intern this as a constant attribute

            # TODO: binary search
            # NOTE(review): unlike the constants above, the chosen name is not
            # cached, so the same instance used twice is stored twice.
            qualname = self.get_fresh_qualname(f"_{a.__class__.__name__}_constant_")
            assert isinstance(qualname, str)
            setattr(self.root, qualname, a)
            return self.create_node("get_attr", qualname, (), {})

        return super().create_arg(a)
  373. @compatibility(is_backward_compatible=True)
  374. def is_leaf_module(self, m: torch.nn.Module, module_qualified_name: str) -> bool:
  375. """
  376. A method to specify whether a given ``nn.Module`` is a "leaf" module.
  377. Leaf modules are the atomic units that appear in
  378. the IR, referenced by ``call_module`` calls. By default,
  379. Modules in the PyTorch standard library namespace (torch.nn)
  380. are leaf modules. All other modules are traced through and
  381. their constituent ops are recorded, unless specified otherwise
  382. via this parameter.
  383. Args:
  384. m (Module): The module being queried about
  385. module_qualified_name (str): The path to root of this module. For example,
  386. if you have a module hierarchy where submodule ``foo`` contains
  387. submodule ``bar``, which contains submodule ``baz``, that module will
  388. appear with the qualified name ``foo.bar.baz`` here.
  389. """
  390. return (
  391. m.__module__.startswith("torch.nn")
  392. or m.__module__.startswith("torch.ao.nn")
  393. ) and not isinstance(m, torch.nn.Sequential)
  394. @compatibility(is_backward_compatible=True)
  395. def path_of_module(self, mod: torch.nn.Module) -> str:
  396. """
  397. Helper method to find the qualified name of ``mod`` in the Module hierarchy
  398. of ``root``. For example, if ``root`` has a submodule named ``foo``, which has
  399. a submodule named ``bar``, passing ``bar`` into this function will return
  400. the string "foo.bar".
  401. Args:
  402. mod (str): The ``Module`` to retrieve the qualified name for.
  403. """
  404. # Prefer the O(1) algorithm
  405. if self.submodule_paths:
  406. path = self.submodule_paths.get(mod)
  407. if path is None:
  408. raise NameError("module is not installed as a submodule")
  409. assert isinstance(path, str)
  410. return path
  411. # O(N^2) fallback in the case that we didn't store the submodule
  412. # paths.
  413. else:
  414. for n, p in self.root.named_modules():
  415. if mod is p:
  416. return n
  417. raise NameError("module is not installed as a submodule")
  418. @compatibility(is_backward_compatible=True)
  419. def call_module(
  420. self,
  421. m: torch.nn.Module,
  422. forward: Callable[..., Any],
  423. args: tuple[Any, ...],
  424. kwargs: dict[str, Any],
  425. ) -> Any:
  426. """
  427. Method that specifies the behavior of this ``Tracer`` when it encounters
  428. a call to an ``nn.Module`` instance.
  429. By default, the behavior is to check if the called module is a leaf module
  430. via ``is_leaf_module``. If it is, emit a ``call_module`` node referring to
  431. ``m`` in the ``Graph``. Otherwise, call the ``Module`` normally, tracing through
  432. the operations in its ``forward`` function.
  433. This method can be overridden to--for example--create nested traced
  434. GraphModules, or any other behavior you would want while tracing across
  435. ``Module`` boundaries.
  436. Args:
  437. m (Module): The module for which a call is being emitted
  438. forward (Callable): The forward() method of the ``Module`` to be invoked
  439. args (Tuple): args of the module callsite
  440. kwargs (Dict): kwargs of the module callsite
  441. Return:
  442. The return value from the Module call. In the case that a ``call_module``
  443. node was emitted, this is a ``Proxy`` value. Otherwise, it is whatever
  444. value was returned from the ``Module`` invocation.
  445. """
  446. module_qualified_name = self.path_of_module(m)
  447. with ScopeContextManager(
  448. self.scope, Scope(module_qualified_name, type(m))
  449. ) as _scope:
  450. # module_stack is an ordered dict so writing then deleting the
  451. # entry is equivalent to push/pop on a list
  452. num_calls = self.num_calls.get(module_qualified_name, 0)
  453. module_key = (
  454. f"{_scope.module_path}@{num_calls}"
  455. if num_calls > 0
  456. else _scope.module_path
  457. )
  458. self.module_stack[module_key] = (module_qualified_name, _scope.module_type)
  459. self.num_calls[module_qualified_name] = num_calls + 1
  460. if not self.is_leaf_module(m, module_qualified_name):
  461. ret_val = forward(*args, **kwargs)
  462. else:
  463. ret_val = self.create_proxy(
  464. "call_module", module_qualified_name, args, kwargs
  465. )
  466. key, _ = self.module_stack.popitem(last=True)
  467. assert key == module_key, f" Unexpected key {key}"
  468. return ret_val
  469. @compatibility(is_backward_compatible=False)
  470. def getattr(self, attr: str, attr_val: Any, parameter_proxy_cache: dict[str, Any]):
  471. """
  472. Method that specifies the behavior of this ``Tracer`` when we call getattr
  473. on a call to an ``nn.Module`` instance.
  474. By default, the behavior is to return a proxy value for the attribute. It
  475. also stores the proxy value in the ``parameter_proxy_cache``, so that future
  476. calls will reuse the proxy rather than creating a new one.
  477. This method can be overridden to --for example-- not return proxies when
  478. querying parameters.
  479. Args:
  480. attr (str): The name of the attribute being queried
  481. attr_val (Any): The value of the attribute
  482. parameter_proxy_cache (Dict[str, Any]): A cache of attr names to proxies
  483. Return:
  484. The return value from the getattr call.
  485. """
  486. def maybe_get_proxy_for_attr(
  487. attr_val, collection_to_search, parameter_proxy_cache
  488. ):
  489. for n, p in collection_to_search:
  490. if attr_val is p:
  491. if n not in parameter_proxy_cache:
  492. kwargs = {}
  493. if (
  494. "proxy_factory_fn"
  495. in inspect.signature(self.create_proxy).parameters
  496. ):
  497. kwargs["proxy_factory_fn"] = (
  498. None
  499. if not self.param_shapes_constant
  500. else lambda node: ParameterProxy(
  501. self, node, n, attr_val
  502. )
  503. )
  504. val_proxy = self.create_proxy("get_attr", n, (), {}, **kwargs) # type: ignore[arg-type]
  505. parameter_proxy_cache[n] = val_proxy
  506. return parameter_proxy_cache[n]
  507. return None
  508. if isinstance(attr_val, torch.nn.Parameter):
  509. maybe_parameter_proxy = maybe_get_proxy_for_attr(
  510. attr_val, self.root.named_parameters(), parameter_proxy_cache
  511. )
  512. if maybe_parameter_proxy is not None:
  513. return maybe_parameter_proxy
  514. if self.proxy_buffer_attributes and isinstance(attr_val, torch.Tensor):
  515. maybe_buffer_proxy = maybe_get_proxy_for_attr(
  516. attr_val, self.root.named_buffers(), parameter_proxy_cache
  517. )
  518. if maybe_buffer_proxy is not None:
  519. return maybe_buffer_proxy
  520. return attr_val
    # This method will be refactored
    @compatibility(is_backward_compatible=False)
    def create_args_for_root(self, root_fn, is_module, concrete_args=None):
        """
        Create ``placeholder`` nodes corresponding to the signature of the ``root``
        Module. This method introspects root's signature and emits those
        nodes accordingly, also supporting ``*args`` and ``**kwargs``.

        Args:
            root_fn: The function to trace. Its signature is read from
                ``inspect.unwrap(root_fn)``. For a module this is presumably the
                method looked up on the module's class -- TODO confirm at the
                call site.
            is_module: If true, the first parameter is ``self``. ``self.root``
                is passed for it and no placeholder is created.
            concrete_args: ``None``, a dict of parameter name -> value, or a
                tuple of values. A tuple is either passed through one placeholder
                per element (when the function takes ``*args``/``**kwargs`` and
                has exactly one named parameter), or matched positionally
                against the named parameters after ``self``.

        Returns:
            A ``(fn, args)`` pair: the callable to invoke and the arguments to
            invoke it with. ``fn`` is one of three things: ``root_fn`` itself, a
            copy from ``_patch_function`` in which ``*args``/``**kwargs`` and
            keyword-only parameters are positionals, or a wrapper that takes
            pytree-flattened inputs.

        Raises:
            RuntimeError: if ``is_module`` is set but the function declares no
                named (positional or keyword-only) parameters. Also raised if a
                tuple ``concrete_args`` does not have one entry per named
                parameter.
        """
        # In some cases, a function or method has been decorated with a wrapper
        # defined via ``functools.wraps``. In this case, the outer code object
        # will likely not contain the actual parameters we care about, so unwrap
        # the function to get to the innermost callable.
        fn_for_analysis = inspect.unwrap(root_fn)
        co = fn_for_analysis.__code__
        # Named parameters only; *args / **kwargs are not counted here.
        total_args = co.co_argcount + co.co_kwonlyargcount
        orig_args = list(co.co_varnames)
        # co_varnames lists positional, then keyword-only, then *args, then
        # **kwargs names; this iterator is consumed in that order below.
        names_iter = iter(co.co_varnames)
        args: list[Any] = []
        skip_arg_idx = 0
        if is_module:
            if total_args == 0:
                raise RuntimeError(
                    "``self`` argument cannot be part of *args expansion!"
                )
            skip_arg_idx = 1
            next(names_iter)  # skip self
            args.append(self.root)

        sig = inspect.signature(fn_for_analysis)

        # This covers the very specific case where we are passing in flat
        # concrete_args as a tuple, but our traced fn takes (*args, **kwargs).
        # In this case, just take the concrete_args and pass them through.
        name_idx = 0
        if (
            isinstance(concrete_args, tuple)
            and len(concrete_args) > 0
            and (co.co_flags & HAS_VARSTUFF)
            and total_args == 1
        ):
            for concrete_arg in concrete_args:
                out = self.create_proxy("placeholder", f"input_{name_idx}", (), {})
                if isinstance(concrete_arg, PHBase):
                    # PHBase defines no __eq__, so this is an identity check
                    # against the PH singleton.
                    if concrete_arg != PH:
                        # Transfer attrs in the case where you're using a placeholder other
                        # than the singleton PH (PH has no attributes to transfer).
                        # Proxies were created out of the placeholders.
                        # Transfer any metadata (put on the placeholders in the form of
                        # attributes set by the user) from the placeholder to the
                        # underlying nodes (the proxy is unwrapped by the user, but
                        # the metadata should hold).
                        _transfer_attrs(fr=concrete_arg, to=out.node)
                args.append(out)
                name_idx += 1
            return root_fn, args

        # Names of the named parameters after any skipped ``self``.
        arg_names = [next(names_iter) for idx in range(skip_arg_idx, total_args)]
        if isinstance(concrete_args, tuple):
            if len(arg_names) != len(concrete_args):
                raise RuntimeError(
                    f"Tracing expected {len(arg_names)} arguments but got {len(concrete_args)} concrete arguments"
                )
            # Normalize a positional tuple into the name -> value form.
            concrete_args = dict(zip(arg_names, concrete_args))

        def proxy_placeholder(name):
            return self._proxy_placeholder(name, concrete_args, sig, fn_for_analysis)

        args.extend(proxy_placeholder(names) for names in arg_names)

        if co.co_kwonlyargcount > 0 or co.co_flags & HAS_VARSTUFF:
            # TODO: type annotations for *args and **kwargs
            if co.co_flags & inspect.CO_VARARGS:
                args.append(proxy_placeholder("*" + next(names_iter)))
            if co.co_flags & inspect.CO_VARKEYWORDS:
                args.append(proxy_placeholder("**" + next(names_iter)))
            # Keyword-only and variadic parameters cannot be filled from a flat
            # positional list, so rewrite the function to take exactly
            # len(args) positionals.
            # NOTE(review): this patches ``root_fn``, not ``fn_for_analysis``;
            # for a functools.wraps wrapper the two code objects differ -- confirm
            # this is intended.
            root_fn = _patch_function(root_fn, len(args))

        flat_args, in_spec = pytree.tree_flatten(tuple(args))
        if not all(child.is_leaf() for child in in_spec.children_specs):
            # In the case that we have pytree-flattened inputs in
            # `concrete_args`, generate a flattening wrapper around the
            # original root function and return that.
            self.graph._codegen = _PyTreeCodeGen(  # type: ignore[has-type]
                _PyTreeInfo(orig_args[:total_args], in_spec, None)
            )

            def flatten_fn(*args):
                # Rebuild the nested inputs, call the root, then flatten the
                # output and record its spec on the graph's codegen.
                tree_args = pytree.tree_unflatten(list(args), in_spec)
                tree_out = root_fn(*tree_args)
                out_args, out_spec = pytree.tree_flatten(tree_out)
                assert isinstance(self.graph._codegen, _PyTreeCodeGen)  # type: ignore[has-type]
                self.graph._codegen.pytree_info = (
                    self.graph._codegen.pytree_info._replace(out_spec=out_spec)
                )
                return out_args

            return flatten_fn, flat_args
        return root_fn, args
  610. @compatibility(is_backward_compatible=True)
  611. def trace(
  612. self,
  613. root: Union[torch.nn.Module, Callable[..., Any]],
  614. concrete_args: Optional[dict[str, Any]] = None,
  615. ) -> Graph:
  616. """
  617. Trace ``root`` and return the corresponding FX ``Graph`` representation. ``root``
  618. can either be an ``nn.Module`` instance or a Python callable.
  619. Note that after this call, ``self.root`` may be different from the ``root`` passed
  620. in here. For example, when a free function is passed to ``trace()``, we will
  621. create an ``nn.Module`` instance to use as the root and add embedded constants
  622. to.
  623. Args:
  624. root (Union[Module, Callable]): Either a ``Module`` or a function to be
  625. traced through. Backwards-compatibility for this parameter is
  626. guaranteed.
  627. concrete_args (Optional[Dict[str, any]]): Concrete arguments that should
  628. not be treated as Proxies. This parameter is experimental and
  629. its backwards-compatibility is *NOT* guaranteed.
  630. Returns:
  631. A ``Graph`` representing the semantics of the passed-in ``root``.
  632. """
  633. global _is_fx_tracing_flag
  634. old_is_fx_tracing_flag = _is_fx_tracing_flag
  635. _is_fx_tracing_flag = True
  636. try:
  637. if isinstance(root, torch.nn.Module):
  638. # do real recompilation for _LazyGraphModule before retracing since the trace
  639. # method can not trace the _lazy_forward method. Got error:
  640. # https://gist.github.com/shunting314/75549c2e82ae07ac1139c94a3583d259
  641. # without this.
  642. from torch.fx._lazy_graph_module import _LazyGraphModule
  643. _LazyGraphModule.force_recompile(root)
  644. self.root = root
  645. assert hasattr(type(root), self.traced_func_name), (
  646. f"traced_func_name={self.traced_func_name} doesn't exist in {type(root).__name__}"
  647. )
  648. fn = getattr(type(root), self.traced_func_name)
  649. self.root_module_name = root._get_name()
  650. self.submodule_paths = {mod: name for name, mod in root.named_modules()}
  651. else:
  652. self.root = torch.nn.Module()
  653. fn = root
  654. tracer_cls: Optional[type[Tracer]] = getattr(self, "__class__", None)
  655. self.graph = Graph(tracer_cls=tracer_cls)
  656. if hasattr(fn, "__code__"):
  657. code = fn.__code__
  658. self.graph._co_fields = {
  659. "co_name": code.co_name,
  660. "co_filename": code.co_filename,
  661. "co_firstlineno": code.co_firstlineno,
  662. }
  663. # When we encounter a Tensor value that's not a parameter, we look if it
  664. # is some other attribute on the model. Construct a dict mapping Tensor
  665. # values to the qualified name here for efficiency. This is used downstream
  666. # in create_arg
  667. self.tensor_attrs: dict[
  668. _ConstantAttributeType,
  669. str,
  670. ] = {}
  671. def collect_tensor_attrs(m: torch.nn.Module, prefix_atoms: list[str]):
  672. for k, v in m.__dict__.items():
  673. if isinstance(v, _constant_attribute_types):
  674. self.tensor_attrs[v] = ".".join(prefix_atoms + [k])
  675. for k, v in m.named_children():
  676. collect_tensor_attrs(v, prefix_atoms + [k])
  677. collect_tensor_attrs(self.root, [])
  678. assert isinstance(fn, FunctionType)
  679. fn_globals = fn.__globals__ # run before it gets patched
  680. fn, args = self.create_args_for_root(
  681. fn, isinstance(root, torch.nn.Module), concrete_args
  682. )
  683. parameter_proxy_cache: dict[
  684. str, Proxy
  685. ] = {} # Reduce number of get_attr calls
  686. # Method dispatch on parameters is not recorded unless it's directly used.
  687. # Thus, we need to insert a proxy when __getattr__ requests a parameter.
  688. @functools.wraps(_orig_module_getattr)
  689. def module_getattr_wrapper(mod, attr):
  690. attr_val = _orig_module_getattr(mod, attr)
  691. return self.getattr(attr, attr_val, parameter_proxy_cache)
  692. @functools.wraps(_orig_module_call)
  693. def module_call_wrapper(mod, *args, **kwargs):
  694. def forward(*args, **kwargs):
  695. return _orig_module_call(mod, *args, **kwargs)
  696. _autowrap_check(
  697. patcher, # type: ignore[has-type]
  698. getattr(getattr(mod, "forward", mod), "__globals__", {}),
  699. self._autowrap_function_ids,
  700. )
  701. return self.call_module(mod, forward, args, kwargs)
  702. with _new_patcher() as patcher:
  703. # allow duplicate patches to support the case of nested calls
  704. patcher.patch_method(
  705. torch.nn.Module,
  706. "__getattr__",
  707. module_getattr_wrapper,
  708. deduplicate=False,
  709. )
  710. patcher.patch_method(
  711. torch.nn.Module,
  712. "__call__",
  713. module_call_wrapper,
  714. deduplicate=False,
  715. )
  716. _patch_wrapped_functions(patcher)
  717. _autowrap_check(patcher, fn_globals, self._autowrap_function_ids)
  718. for module in self._autowrap_search:
  719. _autowrap_check(
  720. patcher, module.__dict__, self._autowrap_function_ids
  721. )
  722. self.create_node(
  723. "output",
  724. "output",
  725. (self.create_arg(fn(*args)),),
  726. {},
  727. type_expr=fn.__annotations__.get("return", None),
  728. )
  729. self.submodule_paths = None
  730. except RuntimeError as e:
  731. if isinstance(e.args[0], str) and "data-dependent" in e.args[0]:
  732. partial_fx_graph = self.graph.python_code(
  733. root_module="self",
  734. verbose=True,
  735. ).src
  736. e.partial_fx_graph = partial_fx_graph # type: ignore[attr-defined]
  737. raise
  738. raise
  739. finally:
  740. _is_fx_tracing_flag = old_is_fx_tracing_flag
  741. return self.graph
  742. def __deepcopy__(self, memo):
  743. # _autowrap_search contains modules, which cannot be deepcopied.
  744. new_tracer = Tracer.__new__(Tracer)
  745. for k, v in self.__dict__.items():
  746. if k in {"_autowrap_search"}:
  747. new_obj = copy.copy(v)
  748. else:
  749. new_obj = copy.deepcopy(v, memo)
  750. new_tracer.__dict__[k] = new_obj
  751. return new_tracer
    def _proxy_placeholder(self, name, concrete_args, sig, fn_for_analysis):
        """
        Create the placeholder input(s) for parameter ``name`` of the traced function.

        If ``name`` is a key of ``concrete_args``, its concrete value is walked with
        ``pytree.tree_map`` and one placeholder named ``{name}_{k}`` (k counting up
        from 1) is created per leaf:

        * ``PHBase`` leaves are replaced by their placeholder Proxy, so they stay
          symbolic. Attributes of non-singleton placeholders are copied onto the
          node.
        * Every other leaf is passed through unchanged (specialized). A guard is
          recorded in the graph when possible (``torch._assert`` for bool and
          other non-Tensor ``base_types``, ``_assert_is_none`` for None).
          Otherwise a warning is emitted.

        Otherwise a single placeholder named ``name`` is created. Names starting
        with ``"*"`` get no default. For other names, the default comes from
        ``sig.parameters[name]`` and the type annotation from
        ``fn_for_analysis.__annotations__``.

        Args:
            name: parameter name. Presumably prefixed with ``*`` / ``**`` for
                var-positional / var-keyword parameters (cf. the ``"**" + ...``
                name built in create_args_for_root); TODO confirm.
            concrete_args: optional mapping from parameter name to a concrete,
                possibly nested, value to specialize on.
            sig: presumably an ``inspect.Signature`` of the traced function; its
                ``parameters`` supply defaults.
            fn_for_analysis: function whose ``__annotations__`` provide the
                placeholder's ``type_expr``.

        Returns:
            The placeholder Proxy. For a concrete arg, it is instead the result of
            ``pytree.tree_map`` over the concrete value.
        """
        if concrete_args is not None and name in concrete_args:
            # Shared by all leaves of this parameter's concrete value; numbers the
            # per-leaf placeholders name_1, name_2, ...
            cnt = 0

            def replace_ph(x):
                # Invoked by pytree.tree_map once per leaf ``x`` of concrete_args[name].
                nonlocal cnt
                cnt += 1
                param = sig.parameters[name]
                default: tuple[Any, ...] = (
                    () if param.default is inspect.Parameter.empty else (param.default,)
                )
                # A placeholder is created for every leaf, including leaves that are
                # specialized to their concrete value below.
                out = self.create_proxy(
                    "placeholder", f"{name}_{str(cnt)}", default, {}
                )
                if isinstance(x, PHBase):
                    if x != PH:
                        # Transfer attrs in the case where you're using a placeholder other
                        # than the singleton PH (PH has no attributes to transfer).
                        # Proxies were created out of the placeholders.
                        # Transfer any metadata (put on the placeholders in the form of
                        # attributes set by the user) from the placeholder to the
                        # underlying nodes (the proxy is unwrapped by the user, but
                        # the metadata should hold).
                        _transfer_attrs(fr=x, to=out.node)

                    # Placeholder leaves stay symbolic: the traced code sees the Proxy.
                    return out
                # Union[int, bool] == bool in Python <= 3.6
                # NOTE(review): ``and`` binds tighter than ``or``, so this condition is
                # ``bool or (in base_types and not a Tensor)``.
                if type(x) == bool or type(x) in base_types and type(x) != torch.Tensor:
                    # Record a runtime guard that the input equals the specialized value.
                    torch._assert(
                        out == x,
                        f"{name} has been specialized to have value {x} but got another value",
                    )
                elif x is None:
                    # Record a call to _assert_is_none(out, msg) as the runtime guard.
                    args = (
                        out,
                        f"{name} has been specialized to have value None but got another value",
                    )
                    self.create_proxy("call_function", _assert_is_none, args, {})
                else:
                    # No guard can be recorded for this value's type; warn instead.
                    warnings.warn(
                        f"Was not able to add assertion to guarantee correct input {name} to "
                        f"specialized function. It is up to the user to make sure that your inputs match the "
                        f"inputs you specialized the function with."
                    )

                # Specialized leaf: the traced code sees the concrete value.
                return x

            return pytree.tree_map(replace_ph, concrete_args[name])
        # Not specialized: a single placeholder named after the parameter.
        if name[0] == "*":
            # Star-prefixed (var-positional / var-keyword) names get no default.
            default: tuple[Any, ...] = ()
        else:
            param = sig.parameters[name]
            default = (  # type: ignore[assignment]
                () if param.default is inspect.Parameter.empty else (param.default,)
            )
        return self.create_proxy(
            "placeholder",
            name,
            default,
            {},
            type_expr=fn_for_analysis.__annotations__.get(name, None),
        )
  810. # Dictionary of (id(globals dict), function name) => globals_dict to patch for
  811. # the purposes of the wrap() API.
  812. # We key by the globals dict id and function name to ensure we're wrapping a given
  813. # function only once.
  814. _wrapped_fns_to_patch: dict[tuple[int, str], dict] = {}
  815. # List of methods on classes to wrap (class type, function name)
  816. # this currently only works for Tensor.* methods that aren't traced properly
  817. _wrapped_methods_to_patch: list[tuple[type, str]] = []
  818. if os.environ.get("FX_PATCH_GETITEM") == "1":
  819. # This change is needed to trace models like PositionalEmbedding from BERT:
  820. # https://github.com/pytorch/benchmark/blob/master/torchbenchmark/models/BERT_pytorch/bert_pytorch/model/embedding/position.py
  821. # but causes issues in quantization documented here:
  822. # https://github.com/pytorch/pytorch/issues/50710
  823. # once that is fixed we can make this the default behavior.
  824. _wrapped_methods_to_patch.append((torch.Tensor, "__getitem__"))
  825. def _find_proxy(*objects_to_search):
  826. """
  827. Recursively search a data structure for a Proxy() and return it,
  828. return None if not found.
  829. """
  830. proxy = None
  831. def find_proxy(x):
  832. nonlocal proxy
  833. if isinstance(x, Proxy):
  834. proxy = x
  835. map_aggregate(objects_to_search, find_proxy)
  836. return proxy
  837. def _create_wrapped_func(orig_fn):
  838. @functools.wraps(orig_fn)
  839. def wrapped(*args, **kwargs):
  840. """
  841. Given an closed-over ``orig_function`` to invoke, search the args and kwargs for
  842. a Proxy object. If there is one, emit a ``call_function`` node to preserve the
  843. call to this leaf function directly. Otherwise, just return the results of
  844. this function call, as this function is not being traced.
  845. """
  846. proxy = _find_proxy(args, kwargs)
  847. if proxy is not None:
  848. return_proxy = proxy.tracer.create_proxy(
  849. "call_function", orig_fn, args, kwargs
  850. )
  851. return_proxy.node.meta["is_wrapped"] = True
  852. return return_proxy
  853. return orig_fn(*args, **kwargs)
  854. return wrapped
  855. def _create_wrapped_method(cls, name):
  856. orig_fn = getattr(cls, name)
  857. @functools.wraps(orig_fn)
  858. def wrapped(*args, **kwargs):
  859. """
  860. Search the args and kwargs for a Proxy object. If there is one,
  861. emit a ``call_method`` node to preserve the call to this method
  862. directly. Otherwise, just return the results of this function
  863. call, as this function is not being traced.
  864. """
  865. proxy = _find_proxy(args, kwargs)
  866. if proxy is not None:
  867. return proxy.tracer.create_proxy("call_method", name, args, kwargs)
  868. return orig_fn(*args, **kwargs)
  869. return wrapped
  870. class _PatchedFn(NamedTuple):
  871. frame_dict: Any
  872. fn_name: str
  873. orig_fn: Any
  874. new_fn: Any
  875. def revert(self):
  876. raise NotImplementedError
  877. def patch(self):
  878. raise NotImplementedError
  879. class _PatchedFnSetItem(_PatchedFn):
  880. def revert(self):
  881. self.frame_dict[self.fn_name] = self.orig_fn
  882. def patch(self):
  883. self.frame_dict[self.fn_name] = self.new_fn
  884. class _PatchedFnDel(_PatchedFn):
  885. def revert(self):
  886. del self.frame_dict[self.fn_name]
  887. def patch(self):
  888. self.frame_dict[self.fn_name] = self.new_fn
  889. class _PatchedFnSetAttr(_PatchedFn):
  890. def revert(self):
  891. setattr(self.frame_dict, self.fn_name, self.orig_fn)
  892. def patch(self):
  893. setattr(self.frame_dict, self.fn_name, self.new_fn)
  894. class _Patcher:
  895. def __init__(self) -> None:
  896. super().__init__()
  897. self.patches_made: list[_PatchedFn] = []
  898. self.visited: set[int] = set()
  899. def patch(
  900. self,
  901. frame_dict: dict[str, Any],
  902. name: str,
  903. new_fn: Callable,
  904. deduplicate: bool = True,
  905. ):
  906. """
  907. Replace frame_dict[name] with new_fn until we exit the context manager.
  908. """
  909. new_fn.__fx_already_patched = deduplicate # type: ignore[attr-defined]
  910. if name not in frame_dict and hasattr(builtins, name):
  911. self.patches_made.append(_PatchedFnDel(frame_dict, name, None, new_fn))
  912. self.patches_made[-1].patch()
  913. elif getattr(frame_dict[name], "__fx_already_patched", False):
  914. return # already patched, no need to do it again
  915. else:
  916. self.patches_made.append(
  917. _PatchedFnSetItem(frame_dict, name, frame_dict[name], new_fn)
  918. )
  919. self.patches_made[-1].patch()
  920. def patch_method(
  921. self, cls: type, name: str, new_fn: Callable, deduplicate: bool = True
  922. ):
  923. """
  924. Replace object_or_dict.name with new_fn until we exit the context manager.
  925. """
  926. new_fn.__fx_already_patched = deduplicate # type: ignore[attr-defined]
  927. orig_fn = getattr(cls, name)
  928. if getattr(orig_fn, "__fx_already_patched", False):
  929. return # already patched, no need to do it again
  930. self.patches_made.append(_PatchedFnSetAttr(cls, name, orig_fn, new_fn))
  931. self.patches_made[-1].patch()
  932. def visit_once(self, thing: Any):
  933. """Return True on the first call to with thing, otherwise false"""
  934. idx = id(thing)
  935. if idx in self.visited:
  936. return False
  937. self.visited.add(idx)
  938. return True
  939. def revert_all_patches(self):
  940. """
  941. Remove all the stored patcheds. It doesn't modify patches_made.
  942. """
  943. for patch in self.patches_made:
  944. patch.revert()
  945. return self.patches_made
  946. def reapply_all_patches(self):
  947. """
  948. Patch all the stored patcheds. It doesn't modify patches_made.
  949. """
  950. for patch in self.patches_made:
  951. patch.patch()
  952. return self.patches_made
  953. def __enter__(self):
  954. return self
  955. def __exit__(self, exc_type, exc_val, exc_tb):
  956. """
  957. Undo all the changes made via self.patch() and self.patch_method()
  958. """
  959. while self.patches_made:
  960. # unpatch in reverse order to handle duplicates correctly
  961. self.patches_made.pop().revert()
  962. self.visited.clear()
  963. CURRENT_PATCHER: Optional[_Patcher] = None
  964. @contextlib.contextmanager
  965. def _new_patcher():
  966. global CURRENT_PATCHER
  967. prior_patcher = CURRENT_PATCHER
  968. try:
  969. CURRENT_PATCHER = _Patcher()
  970. yield CURRENT_PATCHER
  971. finally:
  972. # Clear all the patches made by when using current patcher.
  973. assert CURRENT_PATCHER is not None
  974. CURRENT_PATCHER.revert_all_patches()
  975. CURRENT_PATCHER = prior_patcher
  976. @contextlib.contextmanager
  977. def _maybe_revert_all_patches():
  978. current_patcher = CURRENT_PATCHER
  979. patches_made = None
  980. patches_removed = None
  981. try:
  982. if current_patcher is not None:
  983. patches_removed = current_patcher.revert_all_patches()
  984. yield
  985. finally:
  986. if current_patcher is not None:
  987. patches_made = current_patcher.reapply_all_patches()
  988. assert patches_made == patches_removed, (
  989. "CURRENT_PATCHER was changed during a revert_all_patches"
  990. )
  991. def _patch_wrapped_functions(patcher: _Patcher):
  992. """
  993. Go through ``_wrapped_fn_patch_table`` and, for each frame object, wrap
  994. the listed global functions in the `_create_wrapped_func` wrapper.
  995. """
  996. for (_, name), frame_dict in _wrapped_fns_to_patch.copy().items():
  997. if name not in frame_dict and hasattr(builtins, name):
  998. orig_fn = getattr(builtins, name)
  999. else:
  1000. orig_fn = frame_dict[name]
  1001. patcher.patch(frame_dict, name, _create_wrapped_func(orig_fn))
  1002. for cls, name in _wrapped_methods_to_patch:
  1003. patcher.patch_method(cls, name, _create_wrapped_method(cls, name))
  1004. def _autowrap_check(
  1005. patcher: _Patcher, frame_dict: dict[str, Any], function_ids: set[int]
  1006. ):
  1007. """
  1008. Some methods, like `math.sqrt` are common enough we want to automatically wrap them as we see them.
  1009. This method searches a scope for them and patches them if found.
  1010. """
  1011. if patcher.visit_once(frame_dict):
  1012. for name, value in frame_dict.items():
  1013. if (
  1014. not name.startswith("_")
  1015. and callable(value)
  1016. and id(value) in function_ids
  1017. ):
  1018. patcher.patch(frame_dict, name, _create_wrapped_func(value))
  1019. @compatibility(is_backward_compatible=True)
  1020. def wrap(fn_or_name: Union[str, Callable]):
  1021. """
  1022. This function can be called at module-level scope to register fn_or_name as a "leaf function".
  1023. A "leaf function" will be preserved as a CallFunction node in the FX trace instead of being
  1024. traced through::
  1025. # foo/bar/baz.py
  1026. def my_custom_function(x, y):
  1027. return x * x + y * y
  1028. torch.fx.wrap("my_custom_function")
  1029. def fn_to_be_traced(x, y):
  1030. # When symbolic tracing, the below call to my_custom_function will be inserted into
  1031. # the graph rather than tracing it.
  1032. return my_custom_function(x, y)
  1033. This function can also equivalently be used as a decorator::
  1034. # foo/bar/baz.py
  1035. @torch.fx.wrap
  1036. def my_custom_function(x, y):
  1037. return x * x + y * y
  1038. A wrapped function can be thought of a "leaf function", analogous to the concept of
  1039. "leaf modules", that is, they are functions that are left as calls in the FX trace
  1040. rather than traced through.
  1041. Args:
  1042. fn_or_name (Union[str, Callable]): The function or name of the global function to insert into the
  1043. graph when it's called
  1044. """
  1045. if not callable(fn_or_name) and not isinstance(fn_or_name, str):
  1046. raise RuntimeError(
  1047. "Unsupported type for global function! Must be either a callable or "
  1048. "string name"
  1049. )
  1050. if callable(fn_or_name):
  1051. assert not isinstance(fn_or_name, str) # to make mypy happy
  1052. fn_name = fn_or_name.__name__
  1053. else:
  1054. assert isinstance(fn_or_name, str), (
  1055. "fn_or_name must be a global function or string name"
  1056. )
  1057. fn_name = fn_or_name
  1058. currentframe = inspect.currentframe()
  1059. assert currentframe is not None
  1060. f = currentframe.f_back
  1061. assert f is not None
  1062. if f.f_code.co_name != "<module>":
  1063. raise NotImplementedError("wrap must be called at the top level of a module")
  1064. # consider implementing Callable version of this via _autowrap_function_ids / _autowrap_search
  1065. # semantics would be slightly different, but would add support `from x import wrapped_function`
  1066. _wrapped_fns_to_patch[(id(f.f_globals), fn_name)] = f.f_globals
  1067. return fn_or_name
  1068. @compatibility(is_backward_compatible=True)
  1069. def symbolic_trace(
  1070. root: Union[torch.nn.Module, Callable[..., Any]],
  1071. concrete_args: Optional[dict[str, Any]] = None,
  1072. ) -> GraphModule:
  1073. """
  1074. Symbolic tracing API
  1075. Given an ``nn.Module`` or function instance ``root``, this function will return a ``GraphModule``
  1076. constructed by recording operations seen while tracing through ``root``.
  1077. ``concrete_args`` allows you to partially specialize your function, whether it's to remove control flow or data structures.
  1078. For example::
  1079. def f(a, b):
  1080. if b == True:
  1081. return a
  1082. else:
  1083. return a * 2
  1084. FX can typically not trace through this due to the presence of control
  1085. flow. However, we can use `concrete_args` to specialize on the value of
  1086. `b` to trace through this::
  1087. f = fx.symbolic_trace(f, concrete_args={"b": False})
  1088. assert f(3, False) == 6
  1089. Note that although you can still pass in different values of `b`, they will be ignored.
  1090. We can also use `concrete_args` to eliminate data-structure handling from
  1091. our function. This will use pytrees to flatten your input. To avoid
  1092. overspecializing, pass in `fx.PH` for values that shouldn't be
  1093. specialized. For example::
  1094. def f(x):
  1095. out = 0
  1096. for v in x.values():
  1097. out += v
  1098. return out
  1099. f = fx.symbolic_trace(
  1100. f, concrete_args={"x": {"a": fx.PH, "b": fx.PH, "c": fx.PH}}
  1101. )
  1102. assert f({"a": 1, "b": 2, "c": 4}) == 7
  1103. Args:
  1104. root (Union[torch.nn.Module, Callable]): Module or function to be traced and converted
  1105. into a Graph representation.
  1106. concrete_args (Optional[Dict[str, any]]): Inputs to be partially specialized
  1107. Returns:
  1108. GraphModule: a Module created from the recorded operations from ``root``.
  1109. """
  1110. tracer = Tracer()
  1111. graph = tracer.trace(root, concrete_args)
  1112. name = (
  1113. root.__class__.__name__ if isinstance(root, torch.nn.Module) else root.__name__
  1114. )
  1115. return _make_graph_module(tracer.root, graph, name)
  1116. @wrap
  1117. def _assert_is_none(value, msg):
  1118. assert value is None, msg