graph_module.py 41 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
370470570670770870971071171271371471571671771871972072172272372472572672772872973073173273373473573673773873974074174274374474574674774874975075175275375475575675775875976076176276376476576676776876977077177277377477577677777877978078178278378478578678778878979079179279379479579679779879980080180280380480580680780880981081181281381481581681781881982082182282382482582682782882983083183283383483583683783883984084184284384484584684784884985085185285385485585685785885986086186286386486586686786886987087187287387487587687787887988088188288388488588688788888989089189289389489589689789889990090190290390490590690790890991091191291391491591691791891992092192292392492592692792892993093193293393493593693793893994094194294394494594694794894995095195295395495595695795895996096196296396496596696796896997097197297397497597697797897998098198298398498598698798898999099199299399499599699799899910001001100210031004100510061007100810091010101110121013101410151016101710181019102010211022102310241025102610271028102910301031103210331034103510361037103810391040104110421043104410451046104710481049105010511052105310541055105610571058105910601061106210631064106510661067106810691070107110721073
  1. # mypy: allow-untyped-defs
  2. import contextlib
  3. import copy
  4. import itertools
  5. import linecache
  6. import os
  7. import sys
  8. import traceback
  9. import warnings
  10. from pathlib import Path
  11. from typing import Any, Callable, Optional, Union
  12. import torch
  13. import torch.nn as nn
  14. import torch.overrides
  15. from torch.nn.modules.module import _addindent
  16. from torch.package import Importer, PackageExporter, PackageImporter, sys_importer
  17. from ._compatibility import compatibility
  18. from .graph import (
  19. _custom_builtins,
  20. _is_from_torch,
  21. _override_sym_repr,
  22. _PyTreeCodeGen,
  23. Graph,
  24. PythonCode,
  25. )
  26. __all__ = [
  27. "reduce_graph_module",
  28. "reduce_package_graph_module",
  29. "GraphModule",
  30. ]
  31. _USER_PRESERVED_ATTRIBUTES_KEY = "_user_preserved_attributes"
  32. # Normal exec loses the source code, however we can work with
  33. # the linecache module to recover it.
  34. # Using _exec_with_source will add it to our local cache
  35. # and then tools like TorchScript will be able to get source info.
  36. class _EvalCacheLoader:
  37. def __init__(self):
  38. self.eval_cache = {}
  39. self.next_id = 0
  40. def cache(self, src: str, globals: dict[str, Any], co_fields=None):
  41. """Store the source in a private cache, and add a lazy entry in linecache
  42. that allows the source to be retrieved by 'filename'.
  43. Args:
  44. src (str): The module source to cache
  45. globals (dict): The module globals
  46. Returns:
  47. str: The cache key (and dummy filename) generated for src.
  48. """
  49. key = self._get_key()
  50. if co_fields:
  51. key += f" from {co_fields['co_filename']}:{co_fields['co_firstlineno']} in {co_fields['co_name']}"
  52. self.eval_cache[key] = src
  53. # Don't mutate globals so that this loader is only used
  54. # to populate linecache, and doesn't interact with other modules
  55. # that might check `__loader__`
  56. globals_copy = globals.copy()
  57. globals_copy["__file__"] = key
  58. globals_copy["__name__"] = key
  59. globals_copy["__loader__"] = self
  60. linecache.lazycache(key, globals_copy)
  61. return key
  62. # Part of the loader protocol (PEP 302)
  63. # linecache will use this method when trying to find source code
  64. def get_source(self, module_name) -> Optional[str]:
  65. if module_name in self.eval_cache:
  66. return self.eval_cache[module_name]
  67. return None
  68. def _get_key(self):
  69. key = f"<eval_with_key>.{self.next_id}"
  70. self.next_id += 1
  71. return key
  72. _loader = _EvalCacheLoader()
  73. def _exec_with_source(src: str, globals: dict[str, Any], co_fields=None):
  74. key = _loader.cache(src, globals, co_fields)
  75. exec(compile(src, key, "exec"), globals)
  76. def _forward_from_src(src: str, globals: dict[str, Any], co_fields=None):
  77. return _method_from_src(
  78. method_name="forward", src=src, globals=globals, co_fields=co_fields
  79. )
  80. def _method_from_src(
  81. method_name: str, src: str, globals: dict[str, Any], co_fields=None
  82. ) -> Callable:
  83. # avoid mutating the passed in dict
  84. globals_copy = globals.copy()
  85. _exec_with_source(src, globals_copy, co_fields)
  86. fn = globals_copy[method_name]
  87. del globals_copy[method_name]
  88. return fn
  89. def _format_import_statement(name: str, obj: Any, importer: Importer) -> str:
  90. if name in _custom_builtins:
  91. return _custom_builtins[name].import_str
  92. if _is_from_torch(name):
  93. return "import torch"
  94. module_name, attr_name = importer.get_name(obj)
  95. return f"from {module_name} import {attr_name} as {name}"
  96. def _format_import_block(globals: dict[str, Any], importer: Importer):
  97. import_strs: set[str] = {
  98. _format_import_statement(name, obj, importer) for name, obj in globals.items()
  99. }
  100. # Sort the imports so we have a stable import block that allows us to
  101. # hash the graph module and get a consistent key for use in a cache.
  102. return "\n".join(sorted(import_strs))
  103. @compatibility(is_backward_compatible=True)
  104. def reduce_graph_module(body: dict[Any, Any], import_block: str) -> torch.nn.Module:
  105. # BC: attribute name was changed from `code` to `_code` to facilitate
  106. # making `code` into a property and adding a docstring to it
  107. fn_src = body.get("_code") or body["code"]
  108. forward = _forward_from_src(import_block + fn_src, {})
  109. return _deserialize_graph_module(forward, body)
  110. @compatibility(is_backward_compatible=True)
  111. def reduce_package_graph_module(
  112. importer: PackageImporter, body: dict[Any, Any], generated_module_name: str
  113. ) -> torch.nn.Module:
  114. forward = importer.import_module(generated_module_name).forward
  115. return _deserialize_graph_module(forward, body)
# We create a dummy class here because symbolic_trace pulls the forward()
# function off of the class, rather than the instance. This class is used
# in _deserialize_graph_module() below, which assigns the recovered forward
# onto this class before tracing.
class _CodeOnlyModule(torch.nn.Module):
    """Placeholder module whose instance state is exactly the given ``body`` dict."""

    def __init__(self, body):
        super().__init__()
        # Replace the freshly initialized instance dict wholesale: ``body``
        # itself (not a copy) becomes this instance's ``__dict__``.
        self.__dict__ = body
def _deserialize_graph_module(
    forward, body: dict[Any, Any], graph_module_cls=None
) -> torch.nn.Module:
    """
    Deserialize a GraphModule given the dictionary of the original module,
    using the code to reconstruct the graph. We delete the actual graph before
    saving the dictionary so that changes to the in-memory graph format do not
    get serialized.

    Args:
        forward: The generated ``forward`` function recovered from the saved code.
        body: The saved module state. It becomes the ``__dict__`` of a
            temporary ``_CodeOnlyModule`` and is mutated here: the
            ``_graphmodule_graph_node_meta_stack_trace`` entry is removed.
        graph_module_cls: Passed through to ``_make_graph_module``.

    Returns:
        torch.nn.Module: The reconstructed graph module.
    """
    # Attach the recovered forward to the placeholder class, because the
    # tracer reads forward() from the class. NOTE: this mutates a class
    # attribute shared by every _CodeOnlyModule.
    _CodeOnlyModule.forward = forward

    tracer_cls = body.get("_tracer_cls")
    if tracer_cls is None:
        from ._symbolic_trace import Tracer

        tracer_cls = Tracer

    graphmodule_cls_name = body.get("_graphmodule_cls_name", "GraphModule")

    # This is a workaround for a mypy linter issue related to
    # passing base class as an argument - https://github.com/python/mypy/issues/5865.
    cls_tracer: Any = tracer_cls

    class KeepModules(cls_tracer):
        # we shouldn't trace into any of the submodules,
        # because they were not traced in the original GraphModule
        def is_leaf_module(self, _: torch.nn.Module, __: str) -> bool:
            return True

    com = _CodeOnlyModule(body)

    tracer_extras = body.get("_tracer_extras", {})
    graph = KeepModules().trace(com, **tracer_extras)

    # Recover node.meta["stack_trace"] after re-tracing
    node_meta_stack_trace = body.get("_graphmodule_graph_node_meta_stack_trace", None)
    if node_meta_stack_trace is not None:
        # Removing the entry here also keeps it out of the attribute-copy
        # loop at the end of this function.
        del body["_graphmodule_graph_node_meta_stack_trace"]
        for node in graph.nodes:
            if node_meta_stack_trace.get(node.name, None) is not None:
                node.meta["stack_trace"] = node_meta_stack_trace[node.name]

    # Manually set Tracer class on the reconstructed Graph, to avoid
    # referencing the private local subclass KeepModules.
    graph._tracer_cls = tracer_cls
    from ._lazy_graph_module import _make_graph_module

    gm = _make_graph_module(
        com, graph, class_name=graphmodule_cls_name, graph_module_cls=graph_module_cls
    )

    # The GraphModule constructor only retains attributes referenced by the graph.
    # In this case, our goal is return a GraphModule as close to identical as the one
    # put into the package. If any additional attributes were present in body,
    # we should keep them.
    for k, v in body.items():
        if not hasattr(gm, k):
            setattr(gm, k, v)
    return gm
  172. # copy an attribute value with qualified name 'target' from 'from_module' to 'to_module'
  173. # This installs empty Modules where none exist yet if they are subpaths of target
  174. def _copy_attr(from_module: torch.nn.Module, to_module: torch.nn.Module, target: str):
  175. *prefix, field = target.split(".")
  176. for item in prefix:
  177. f = getattr(from_module, item)
  178. t = getattr(to_module, item, None)
  179. if f is t:
  180. # we have already installed one of its parents
  181. # (e.g. target = root.linear.weight, but we have already installed root.linear)
  182. # once we install a parent, we no longer need to copy the children
  183. # since all the needed properties will already be present
  184. return
  185. if t is None:
  186. t = torch.nn.Module()
  187. setattr(to_module, item, t)
  188. from_module, to_module = f, t
  189. orig = getattr(from_module, field)
  190. # If it is a tensor and not a parameter attribute of a module, it should be a named buffer.
  191. # So, we register it as a named buffer in the target module.
  192. if isinstance(orig, torch.Tensor) and not isinstance(orig, torch.nn.Parameter):
  193. to_module.register_buffer(field, orig)
  194. else:
  195. setattr(to_module, field, orig)
  196. # Assign attribute 'from_obj' to the qualified name 'target' on 'to_module
  197. # This installs empty Modules where none exist yet if they are subpaths of target
  198. def _assign_attr(from_obj: Any, to_module: torch.nn.Module, target: str):
  199. *prefix, field = target.split(".")
  200. for item in prefix:
  201. t = getattr(to_module, item, None)
  202. if t is None:
  203. t = torch.nn.Module()
  204. setattr(to_module, item, t)
  205. to_module = t
  206. # If it is a tensor and not a parameter attribute of a module, it should be a named buffer.
  207. # So, we register it as a named buffer in the target module.
  208. if isinstance(from_obj, torch.Tensor) and not isinstance(
  209. from_obj, torch.nn.Parameter
  210. ):
  211. to_module.register_buffer(field, from_obj)
  212. else:
  213. setattr(to_module, field, from_obj)
  214. # Recursively look up target from a graph module.
  215. def _get_attr(model: torch.nn.Module, attr_name: str):
  216. return _get_attr_via_attr_list(model, attr_name.split("."))
  217. def _del_attr(model: torch.nn.Module, attr_name: str):
  218. attr_names = attr_name.split(".")
  219. t = _get_attr_via_attr_list(model, attr_names[:-1])
  220. return delattr(t, attr_names[-1])
  221. def _get_attr_via_attr_list(model: torch.nn.Module, attr_list: list[str]):
  222. if len(attr_list) == 0:
  223. return model
  224. *prefix, field = attr_list
  225. t = model
  226. for item in prefix:
  227. t = getattr(t, item, None) # type: ignore[assignment]
  228. assert t is not None
  229. return getattr(t, field)
  230. def _has_attr(model: torch.nn.Module, attr_name: str):
  231. *prefix, field = attr_name.split(".")
  232. t = model
  233. for item in prefix:
  234. t = hasattr(t, item) # type: ignore[assignment]
  235. if t is False:
  236. return False
  237. return hasattr(t, field)
  238. def _print_readable(
  239. module,
  240. module_name,
  241. print_output=True,
  242. include_stride=False,
  243. include_device=False,
  244. colored=False,
  245. expanded_def=False,
  246. ):
  247. graph = module.graph
  248. assert graph is not None and isinstance(graph, torch.fx.Graph), (
  249. "print_readable must be used on a module with a graph"
  250. )
  251. verbose_python_code = graph.python_code(
  252. root_module="self",
  253. verbose=True,
  254. include_stride=include_stride,
  255. include_device=include_device,
  256. colored=colored,
  257. expanded_def=expanded_def,
  258. )
  259. module_code = verbose_python_code.src
  260. module_code = module_code.lstrip("\n")
  261. module_code = f"class {module_name}(torch.nn.Module):\n" + module_code
  262. module_code = _addindent(module_code, 4)
  263. submodule_code_list = [""]
  264. for submodule_name, submodule in module.named_children():
  265. if hasattr(submodule, "graph"):
  266. submodule_code_list.append(
  267. _print_readable(
  268. submodule,
  269. submodule_name,
  270. print_output=False,
  271. include_stride=include_stride,
  272. include_device=include_device,
  273. colored=colored,
  274. )
  275. )
  276. submodule_code = "\n".join(submodule_code_list)
  277. submodule_code = _addindent(submodule_code, 4)
  278. output = module_code + submodule_code
  279. if print_output:
  280. print(module_code + submodule_code)
  281. return output
  282. class _WrappedCall:
  283. def __init__(self, cls, cls_call):
  284. self.cls = cls
  285. self.cls_call = cls_call
  286. # Previously, if an error occurred when valid
  287. # symbolically-traced code was run with an invalid input, the
  288. # user would see the source of the error as coming from
  289. # `File "<eval_with_key_N">`, where N is some number. We use
  290. # this function to generate a more informative error message. We
  291. # return the traceback itself, a message explaining that the
  292. # error occurred in a traced Module's generated forward
  293. # function, and five lines of context surrounding the faulty
  294. # line
  295. @staticmethod
  296. def _generate_error_message(frame_summary: traceback.FrameSummary) -> str:
  297. # auxiliary variables (for readability)
  298. err_lineno = frame_summary.lineno
  299. assert err_lineno is not None
  300. line = frame_summary.line
  301. assert line is not None
  302. err_line_len = len(line)
  303. all_src_lines = linecache.getlines(frame_summary.filename)
  304. # constituent substrings of the error message
  305. tb_repr = torch._dynamo.disable(
  306. traceback.format_exc,
  307. reason="do not trace into traceback.format_exc when generating error message",
  308. )()
  309. custom_msg = (
  310. "Call using an FX-traced Module, "
  311. f"line {err_lineno} of the traced Module's "
  312. "generated forward function:"
  313. )
  314. before_err = "".join(all_src_lines[err_lineno - 2 : err_lineno])
  315. marker = "~" * err_line_len + "~~~ <--- HERE"
  316. err_and_after_err = "\n".join(all_src_lines[err_lineno : err_lineno + 2])
  317. # joined message
  318. return "\n".join([tb_repr, custom_msg, before_err, marker, err_and_after_err])
  319. def __call__(self, obj, *args, **kwargs):
  320. try:
  321. if self.cls_call is not None:
  322. return self.cls_call(obj, *args, **kwargs)
  323. else:
  324. return super(self.cls, obj).__call__(*args, **kwargs) # type: ignore[misc]
  325. except Exception as e:
  326. assert e.__traceback__
  327. topmost_framesummary: traceback.FrameSummary = (
  328. traceback.StackSummary.extract(traceback.walk_tb(e.__traceback__))[-1]
  329. )
  330. if "eval_with_key" in topmost_framesummary.filename:
  331. print(
  332. _WrappedCall._generate_error_message(topmost_framesummary),
  333. file=sys.stderr,
  334. )
  335. raise e.with_traceback(None) # noqa: B904
  336. else:
  337. raise e
  338. @compatibility(is_backward_compatible=True)
  339. class GraphModule(torch.nn.Module):
  340. """
  341. GraphModule is an nn.Module generated from an fx.Graph. Graphmodule has a
  342. ``graph`` attribute, as well as ``code`` and ``forward`` attributes generated
  343. from that ``graph``.
  344. .. warning::
  345. When ``graph`` is reassigned, ``code`` and ``forward`` will be automatically
  346. regenerated. However, if you edit the contents of the ``graph`` without reassigning
  347. the ``graph`` attribute itself, you must call ``recompile()`` to update the generated
  348. code.
  349. """
    def __new__(cls: "type[GraphModule]", *args, **kwargs):
        """Create the instance as an object of a fresh subclass of ``cls``.

        Each instance gets its own ``GraphModuleImpl`` class, so methods such as
        ``forward`` can be installed on that class without affecting other
        instances.
        """
        # each instance of a graph module needs its own forward method
        # so create a new singleton class for each instance.
        # it is a subclass of the user-defined class, the only difference
        # is an extra layer to install the forward method

        # address issue described at https://github.com/pytorch/pytorch/issues/63883
        # in other words, traverse class hierarchy to fix the redundant class definition problem
        # (skip existing GraphModuleImpl layers, matched by the last component
        # of __qualname__, so wrappers do not stack on top of each other)
        for t in cls.__mro__:
            c = t.__qualname__.split(".")[-1]
            if c != "GraphModuleImpl":
                cls = t
                break

        # The class name must stay "GraphModuleImpl": the loop above relies on it.
        class GraphModuleImpl(cls):  # type: ignore[misc, valid-type]
            pass

        return super().__new__(GraphModuleImpl)
    @compatibility(is_backward_compatible=True)
    def __init__(
        self,
        root: Union[torch.nn.Module, dict[str, Any]],
        graph: Graph,
        class_name: str = "GraphModule",
    ):
        """
        Construct a GraphModule.

        Args:

            root (Union[torch.nn.Module, Dict[str, Any]]):
                ``root`` can either be an nn.Module instance or a Dict mapping strings to any attribute type.
                In the case that ``root`` is a Module, any references to Module-based objects (via qualified
                name) in the Graph's Nodes' ``target`` field will be copied over from the respective place
                within ``root``'s Module hierarchy into the GraphModule's module hierarchy.
                In the case that ``root`` is a dict, the qualified name found in a Node's ``target`` will be
                looked up directly in the dict's keys. The object mapped to by the Dict will be copied
                over into the appropriate place within the GraphModule's module hierarchy.

            graph (Graph): ``graph`` contains the nodes this GraphModule should use for code generation

            class_name (str): ``class_name`` denotes the name of this GraphModule for debugging purposes. If it's unset, all
                error messages will report as originating from ``GraphModule``. It may be helpful to set this
                to ``root``'s original name or a name that makes sense within the context of your transform.

        Raises:
            RuntimeError: if ``root`` is a dict that lacks a target referenced by a
                ``get_attr``/``call_module`` node, or if ``root`` is neither an
                ``nn.Module`` nor a dict.
        """
        super().__init__()
        # Renames this instance's own class (the per-instance subclass created
        # in ``__new__``), not the shared base class.
        self.__class__.__name__ = class_name
        if isinstance(root, torch.nn.Module):
            if hasattr(root, "training"):
                self.training = root.training

            # When we pickle/unpickle graph module, we don't want to drop any module or attributes.
            # (A _CodeOnlyModule root comes from deserialization, so every child,
            # buffer and parameter is copied, not only graph-referenced ones.)
            if isinstance(root, _CodeOnlyModule):
                for k, _ in root.named_children():
                    _copy_attr(root, self, k)

                for k, _ in root.named_buffers():
                    _copy_attr(root, self, k)

                for k, _ in root.named_parameters():
                    _copy_attr(root, self, k)

            for node in graph.nodes:
                if node.op in ["get_attr", "call_module"]:
                    assert isinstance(node.target, str)
                    _copy_attr(root, self, node.target)
        elif isinstance(root, dict):
            targets_to_copy = []
            for node in graph.nodes:
                if node.op in ["get_attr", "call_module"]:
                    assert isinstance(node.target, str)
                    if node.target not in root:
                        raise RuntimeError(
                            "Node "
                            + str(node)
                            + " referenced target "
                            + node.target
                            + " but that target was not provided in ``root``!"
                        )
                    targets_to_copy.append(node.target)
            # Sort targets in ascending order of the # of atoms.
            # This will ensure that less deeply nested attributes are assigned
            # before more deeply nested attributes. For example, foo.bar
            # will be assigned before foo.bar.baz. Otherwise, we might assign
            # the user-provided ``foo.bar`` and wipe out the previously-assigned
            # ``foo.bar.baz``
            targets_to_copy.sort(key=lambda t: t.count("."))
            for target_to_copy in targets_to_copy:
                _assign_attr(root[target_to_copy], self, target_to_copy)
        else:
            raise RuntimeError("Unsupported type " + str(root) + " passed for root!")

        # Goes through the ``graph`` property setter, which sets
        # ``graph.owning_module`` and calls ``recompile()``; attributes must be
        # installed before this point.
        self.graph = graph

        # Store the Tracer class responsible for creating a Graph separately as part of the
        # GraphModule state, except when the Tracer is defined in a local namespace.
        # Locally defined Tracers are not pickleable. This is needed because torch.package will
        # serialize a GraphModule without retaining the Graph, and needs to use the correct Tracer
        # to re-create the Graph during deserialization.
        self._tracer_cls = None
        if (
            self.graph._tracer_cls
            and "<locals>" not in self.graph._tracer_cls.__qualname__
        ):
            self._tracer_cls = self.graph._tracer_cls

        self._tracer_extras = {}
        if self.graph._tracer_extras:
            self._tracer_extras = self.graph._tracer_extras

        # Dictionary to store metadata
        self.meta: dict[str, Any] = {}
        self._replace_hooks: list[Callable] = []
        self._create_node_hooks: list[Callable] = []
        self._erase_node_hooks: list[Callable] = []
        # Used to remove hooks from deepcopied graph modules within a context manager.
        self._deepcopy_hooks: list[Callable] = []
  452. # TorchScript breaks trying to compile the graph setter because of the
  453. # continued string literal. Issue here: https://github.com/pytorch/pytorch/issues/44842
  454. #
  455. # Shouldn't be an issue since these methods shouldn't be used in TorchScript anyway
  456. __jit_unused_properties__ = ["graph"]
  457. @property
  458. def graph(self) -> Graph:
  459. """
  460. Return the ``Graph`` underlying this ``GraphModule``
  461. """
  462. return self._graph
  463. @graph.setter
  464. def graph(self, g: Graph) -> None:
  465. """
  466. Set the underlying ``Graph`` for this ``GraphModule``. This will internally
  467. recompile the ``GraphModule`` so that the generated ``forward()`` function
  468. corresponds to ``g``
  469. """
  470. assert isinstance(g, Graph), f"Expected a Graph instance, but got {type(g)}"
  471. self._graph = g
  472. g.owning_module = self
  473. self.recompile()
  474. @compatibility(is_backward_compatible=False)
  475. def to_folder(self, folder: Union[str, os.PathLike], module_name: str = "FxModule"):
  476. """Dumps out module to ``folder`` with ``module_name`` so that it can be
  477. imported with ``from <folder> import <module_name>``
  478. Args:
  479. folder (Union[str, os.PathLike]): The folder to write the code out to
  480. module_name (str): Top-level name to use for the ``Module`` while
  481. writing out the code
  482. """
  483. folder = Path(folder)
  484. Path(folder).mkdir(exist_ok=True)
  485. torch.save(self.state_dict(), folder / "state_dict.pt")
  486. tab = " " * 4
  487. custom_builtins = "\n".join([v.import_str for v in _custom_builtins.values()])
  488. model_str = f"""
  489. import torch
  490. {custom_builtins}
  491. from torch.nn import *
  492. class {module_name}(torch.nn.Module):
  493. def __init__(self):
  494. super().__init__()
  495. """
  496. def _gen_model_repr(module_name: str, module: torch.nn.Module) -> Optional[str]:
  497. safe_reprs = [
  498. nn.Linear,
  499. nn.Conv1d,
  500. nn.Conv2d,
  501. nn.Conv3d,
  502. nn.BatchNorm1d,
  503. nn.BatchNorm2d,
  504. nn.BatchNorm3d,
  505. ]
  506. if type(module) in safe_reprs:
  507. return f"{module.__repr__()}"
  508. else:
  509. return None
  510. blobified_modules = []
  511. for module_name, module in self.named_children():
  512. module_str = _gen_model_repr(module_name, module)
  513. if module_str is None:
  514. module_file = folder / f"{module_name}.pt"
  515. torch.save(module, module_file)
  516. blobified_modules.append(module_name)
  517. module_repr = module.__repr__().replace("\r", " ").replace("\n", " ")
  518. # weights_only=False as this is legacy code that saves the model
  519. module_str = (
  520. f"torch.load(r'{module_file}', weights_only=False) # {module_repr}"
  521. )
  522. model_str += f"{tab * 2}self.{module_name} = {module_str}\n"
  523. for buffer_name, buffer in self._buffers.items():
  524. if buffer is None:
  525. continue
  526. model_str += f"{tab * 2}self.register_buffer('{buffer_name}', torch.empty({list(buffer.shape)}, dtype={buffer.dtype}))\n" # noqa: B950
  527. for param_name, param in self._parameters.items():
  528. if param is None:
  529. continue
  530. model_str += f"{tab * 2}self.{param_name} = torch.nn.Parameter(torch.empty({list(param.shape)}, dtype={param.dtype}))\n" # noqa: B950
  531. model_str += (
  532. f"{tab * 2}self.load_state_dict(torch.load(r'{folder}/state_dict.pt'))\n"
  533. )
  534. model_str += f"{_addindent(self.code, 4)}\n"
  535. module_file = folder / "module.py"
  536. module_file.write_text(model_str)
  537. init_file = folder / "__init__.py"
  538. init_file.write_text("from .module import *")
  539. if len(blobified_modules) > 0:
  540. warnings.warn(
  541. "Was not able to save the following children modules as reprs -"
  542. f"saved as pickled files instead: {blobified_modules}"
  543. )
  544. @compatibility(is_backward_compatible=True)
  545. def add_submodule(self, target: str, m: torch.nn.Module) -> bool:
  546. """
  547. Adds the given submodule to ``self``.
  548. This installs empty Modules where none exist yet if they are
  549. subpaths of ``target``.
  550. Args:
  551. target: The fully-qualified string name of the new submodule
  552. (See example in ``nn.Module.get_submodule`` for how to
  553. specify a fully-qualified string.)
  554. m: The submodule itself; the actual object we want to
  555. install in the current Module
  556. Return:
  557. bool: Whether or not the submodule could be inserted. For
  558. this method to return True, each object in the chain
  559. denoted by ``target`` must either a) not exist yet,
  560. or b) reference an ``nn.Module`` (not a parameter or
  561. other attribute)
  562. """
  563. *prefix, field = target.split(".")
  564. mod: torch.nn.Module = self
  565. for item in prefix:
  566. submod = getattr(mod, item, None)
  567. if submod is None:
  568. submod = torch.nn.Module()
  569. setattr(mod, item, submod)
  570. if not isinstance(submod, torch.nn.Module):
  571. return False
  572. mod = submod
  573. mod.add_module(field, m)
  574. return True
  575. @compatibility(is_backward_compatible=True)
  576. def delete_submodule(self, target: str) -> bool:
  577. """
  578. Deletes the given submodule from ``self``.
  579. The module will not be deleted if ``target`` is not a valid
  580. target.
  581. Args:
  582. target: The fully-qualified string name of the new submodule
  583. (See example in ``nn.Module.get_submodule`` for how to
  584. specify a fully-qualified string.)
  585. Returns:
  586. bool: Whether or not the target string referenced a
  587. submodule we want to delete. A return value of ``False``
  588. means that the ``target`` was not a valid reference to
  589. a submodule.
  590. """
  591. atoms = target.split(".")
  592. path, target_submod = atoms[:-1], atoms[-1]
  593. mod: torch.nn.Module = self
  594. # Get the parent module
  595. for item in path:
  596. if not hasattr(mod, item):
  597. return False
  598. mod = getattr(mod, item)
  599. if not isinstance(mod, torch.nn.Module):
  600. return False
  601. if not hasattr(mod, target_submod):
  602. return False
  603. if not isinstance(getattr(mod, target_submod), torch.nn.Module):
  604. return False
  605. delattr(mod, target_submod)
  606. return True
  607. @compatibility(is_backward_compatible=True)
  608. def delete_all_unused_submodules(self) -> None:
  609. """
  610. Deletes all unused submodules from ``self``.
  611. A Module is considered "used" if any one of the following is
  612. true:
  613. 1. It has children that are used
  614. 2. Its forward is called directly via a ``call_module`` node
  615. 3. It has a non-Module attribute that is used from a
  616. ``get_attr`` node
  617. This method can be called to clean up an ``nn.Module`` without
  618. manually calling ``delete_submodule`` on each unused submodule.
  619. """
  620. used: list[str] = []
  621. for node in self.graph.nodes:
  622. if node.op == "call_module" or node.op == "get_attr":
  623. # A list of strings representing the different parts
  624. # of the path. For example, `foo.bar.baz` gives us
  625. # ["foo", "bar", "baz"]
  626. fullpath = node.target.split(".")
  627. # If we're looking at multiple parts of a path, join
  628. # join them with a dot. Otherwise, return that single
  629. # element without doing anything to it.
  630. def join_fn(x: str, y: str) -> str:
  631. return ".".join([x, y] if y else [x])
  632. # Progressively collect all the names of intermediate
  633. # modules. For example, if we have the target
  634. # `foo.bar.baz`, we'll add `foo`, `foo.bar`, and
  635. # `foo.bar.baz` to the list.
  636. used.extend(itertools.accumulate(fullpath, join_fn))
  637. # For a `call_module` node, also register all recursive submodules
  638. # as used
  639. if node.op == "call_module":
  640. try:
  641. submod = self.get_submodule(node.target)
  642. for submod_name, _ in submod.named_modules():
  643. if submod_name != "":
  644. used.append(".".join([node.target, submod_name]))
  645. except AttributeError:
  646. # Node referenced nonexistent submodule, don't need to
  647. # worry about GCing anything
  648. pass
  649. to_delete = [name for name, _ in self.named_modules() if name not in used]
  650. for name in to_delete:
  651. self.delete_submodule(name)
  @property
  def code(self) -> str:
      """
      The Python source text generated from this ``GraphModule``'s ``Graph``.

      Raises:
          RuntimeError: if no source has been generated yet (``_code`` is
              not set), which indicates an internal bug.
      """
      missing = object()
      generated = getattr(self, "_code", missing)
      if generated is missing:
          raise RuntimeError(
              "Code has not been generated! Please report a bug to PyTorch"
          )
      return generated
  @compatibility(is_backward_compatible=True)
  def recompile(self) -> PythonCode:
      """
      Recompile this GraphModule from its ``graph`` attribute. This should be
      called after editing the contained ``graph``, otherwise the generated
      code of this ``GraphModule`` will be out of date.

      Steps:
        1. If the graph uses ``_PyTreeCodeGen``, copy its pytree input/output
           specs onto ``self._in_spec`` / ``self._out_spec``.
        2. Generate Python source for the graph (with ``"self"`` as the root
           module name) and cache it in ``self._code`` and
           ``self._lineno_map``.
        3. Compile that source into ``forward`` and install it on
           ``type(self)``, then install a ``__call__`` on the class that
           dispatches through ``_wrapped_call``.

      Returns:
          The ``PythonCode`` object produced by ``Graph.python_code``. Its
          ``globals`` are consumed by the serialization hooks
          (``__reduce__`` / ``__reduce_package__``) to build import blocks.
      """
      # Keep the cached pytree specs in sync with the graph's codegen.
      if isinstance(self._graph._codegen, _PyTreeCodeGen):
          self._in_spec = self._graph._codegen.pytree_info.in_spec
          self._out_spec = self._graph._codegen.pytree_info.out_spec
      python_code = self._graph.python_code(root_module="self")
      self._code = python_code.src
      self._lineno_map = python_code._lineno_map
      # ``forward`` and ``__call__`` are patched on the class, not the
      # instance. NOTE(review): this presumably relies on each GraphModule
      # instance having its own class; confirm against ``__new__``.
      cls = type(self)
      # ``_co_fields`` is optional on the graph; fall back to no extra fields.
      co_fields = self._graph._co_fields if hasattr(self._graph, "_co_fields") else {}
      cls.forward = _forward_from_src(self._code, python_code.globals, co_fields)
      # Determine whether this class explicitly defines a __call__ implementation
      # to wrap. If it does, save it in order to have wrapped_call invoke it.
      # If it does not, wrapped_call can use a dynamic call to super() instead.
      # In most cases, super().__call__ should be torch.nn.Module.__call__.
      # We do not want to hold a reference to Module.__call__ here; doing so will
      # bypass patching of torch.nn.Module.__call__ done while symbolic tracing.
      cls_call = cls.__call__ if "__call__" in vars(cls) else None
      # Build the wrapper only once per class. After the first recompile,
      # ``vars(cls)["__call__"]`` is the ``call_wrapped`` installed below, so
      # this guard keeps ``call_wrapped`` from being wrapped by itself.
      if "_wrapped_call" not in vars(cls):
          cls._wrapped_call = _WrappedCall(cls, cls_call)  # type: ignore[attr-defined]

      # Route every call on an instance through the class-level _wrapped_call.
      def call_wrapped(self, *args, **kwargs):
          return self._wrapped_call(self, *args, **kwargs)

      cls.__call__ = call_wrapped  # type: ignore[method-assign]
      return python_code
  # Passing Tracer as an argument allows subclasses extending fx.GraphModule
  # to define their own Tracer (extending fx.Tracer).
  def __reduce_package__(self, exporter: PackageExporter):
      """
      ``torch.package`` hook for this module.

      The generated forward source (prefixed with the import block its
      globals require) is saved into the package as a fresh module named
      ``fx-generated._<unique id>``. The returned reducer rebuilds the
      module from the instance state (minus ``_graph``) and that module name.
      """
      state = dict(self.__dict__)
      state["_graphmodule_cls_name"] = self.__class__.__name__
      state.pop("_graph")

      # Store node.meta["stack_trace"] so we can recover them after
      # re-tracing during deserialization.
      stack_traces = {}
      for graph_node in self.graph.nodes:
          if "stack_trace" in graph_node.meta:
              stack_traces[graph_node.name] = graph_node.meta["stack_trace"]
      state["_graphmodule_graph_node_meta_stack_trace"] = stack_traces

      module_name = f"fx-generated._{exporter.get_unique_id()}"
      generated = self.recompile()
      header = _format_import_block(generated.globals, exporter.importer)
      exporter.save_source_string(module_name, header + self.code)
      return reduce_package_graph_module, (state, module_name)
  def __reduce__(self):
      """
      Pickle support for GraphModule.

      Only the generated Python source is serialized, never the ``Graph``
      itself: ``Graph`` has no on-disk backward-compatibility guarantee,
      whereas Python source code does. When unpickling, the generated code
      is symbolically traced again to rebuild the ``Graph``.
      """
      # Snapshot the instance state first; the recompile below refreshes
      # the cached code on ``self`` but not in this snapshot.
      state = dict(self.__dict__)
      generated = self.recompile()
      header = _format_import_block(generated.globals, sys_importer)
      state.pop("_graph")
      return reduce_graph_module, (state, header)
  def _deepcopy_init(self):
      """
      Return the initializer that ``__deepcopy__`` applies to the freshly
      allocated copy (here, ``GraphModule.__init__``). Presumably an
      override point for subclasses with a different constructor; confirm.
      """
      init_fn = GraphModule.__init__
      return init_fn
  # Because __reduce__ is defined for serialization, we need to define
  # __deepcopy__ as well; otherwise copy.deepcopy falls back to __reduce__
  # and triggers symbolic tracing every time we try to copy the object.
  def __deepcopy__(self, memo):
      """
      Deep-copy this module without going through ``__reduce__``.

      A fresh instance of ``type(self)`` is allocated and initialized with
      the callable returned by ``self._deepcopy_init()``, using a deep copy of
      ``self.__dict__``. Afterwards:

      * selected hook containers are deep-copied over;
      * ``meta`` is deep-copied over;
      * user-preserved attributes stored in ``meta`` are re-attached;
      * any registered deepcopy hooks are run on the result.

      Args:
          memo: the ``copy.deepcopy`` memo dictionary.

      Returns:
          The newly created copy.
      """
      res = type(self).__new__(type(self))
      # Register the result in the memo before copying any state, so that
      # references back to ``self`` found inside ``__dict__`` map to ``res``.
      memo[id(self)] = res
      # NOTE(review): _CodeOnlyModule presumably wraps the copied attributes
      # so they can act as the ``root`` argument of the init; confirm.
      fake_mod = _CodeOnlyModule(copy.deepcopy(self.__dict__, memo))
      self._deepcopy_init()(res, fake_mod, fake_mod.__dict__["_graph"])
      # hooks are lost during `GraphModule.__init__`, so we need to copy over
      # them explicitly. This list covers the state_dict-related hooks plus the
      # FX node replace/create/erase and deepcopy hooks; forward/backward
      # hooks are intentionally not copied (to limit BC impact) and could be
      # added in the future if needed.
      extra_preserved_attrs = [
          "_state_dict_hooks",
          "_load_state_dict_pre_hooks",
          "_load_state_dict_post_hooks",
          "_replace_hooks",
          "_create_node_hooks",
          "_erase_node_hooks",
          "_deepcopy_hooks",
      ]
      for attr in extra_preserved_attrs:
          if attr in self.__dict__:
              setattr(res, attr, copy.deepcopy(self.__dict__[attr], memo))
      res.meta = copy.deepcopy(getattr(self, "meta", {}), memo)
      # Re-attach attributes stashed under _USER_PRESERVED_ATTRIBUTES_KEY
      # (their values come from the deep-copied meta).
      if _USER_PRESERVED_ATTRIBUTES_KEY in res.meta:
          for attr_name, attr in res.meta[_USER_PRESERVED_ATTRIBUTES_KEY].items():
              setattr(res, attr_name, attr)
      # Each deepcopy hook receives the finished copy.
      if hasattr(self, "_deepcopy_hooks"):
          for hook in self._deepcopy_hooks:
              hook(res)
      return res
  def __copy__(self):
      """
      Shallow copy: build a new graph module from ``self`` and ``self.graph``
      via ``_make_graph_module``. The copy's ``meta`` is the very same object
      as the original's (it is not duplicated).
      """
      from ._lazy_graph_module import _make_graph_module

      duplicate = _make_graph_module(self, self.graph)
      duplicate.meta = getattr(self, "meta", {})
      return duplicate
  @compatibility(is_backward_compatible=False)
  def print_readable(
      self,
      print_output=True,
      include_stride=False,
      include_device=False,
      colored=False,
      *,
      # If `fast_sympy_print` is True then we use a sympy printer which is faster
      # but may result in less-readable output.
      fast_sympy_print: bool = False,
      expanded_def: bool = False,
  ):
      """
      Return the Python code generated for the current GraphModule and its
      children GraphModules.

      When ``fast_sympy_print`` is set, a repr built on
      ``torch._inductor.utils.sympy_str`` is installed through
      ``_override_sym_repr`` for the duration of the printing.
      """
      with contextlib.ExitStack() as stack:
          if fast_sympy_print:
              from torch._inductor.utils import sympy_str

              def _fast_repr(expr: torch.types.PySymType) -> str:
                  return sympy_str(expr.node.expr)

              stack.enter_context(_override_sym_repr(_fast_repr))
          readable = _print_readable(
              self,
              self._get_name(),
              print_output,
              include_stride,
              include_device,
              colored,
              expanded_def,
          )
      return readable
  def __str__(self) -> str:
      """
      The ``nn.Module`` string form, followed by the generated code and a
      hint pointing at ``print_readable``, joined by newlines.
      """
      pieces = [
          super().__str__(),
          self._code,
          "# To see more debug info, please use `graph_module.print_readable()`",
      ]
      return "\n".join(pieces)
  def _replicate_for_data_parallel(self):
      """Return a ``__copy__`` of this module marked with ``_is_replica = True``."""
      replica = self.__copy__()
      setattr(replica, "_is_replica", True)
      return replica
  @contextlib.contextmanager
  def _set_replace_hook(self, f):
      """
      Temporarily install ``f`` as a replace-node hook.

      ``f`` is registered on entry and unregistered on exit, even if the
      ``with`` body raises. It is meant to be called whenever a node is
      replaced by a new node or a node is renamed, with three arguments:
      the old node being changed, the NAME of the new node, and the user
      node that consumes the old node.
      """
      assert callable(f), "Replace hook must be a callable."
      self._register_replace_node_hook(f)
      try:
          yield
      finally:
          # Always detach the hook, regardless of how the body exits.
          self._unregister_replace_node_hook(f)
  def _register_replace_node_hook(self, f):
      """
      Register ``f`` to be called every time a node is replaced by a new
      node, or a node's name is changed.

      The callable takes three arguments: the old node being changed, the
      NAME of the new node, and the user node that consumes the old node
      being replaced.
      """
      # Message previously said "create_node" (copy-paste from the
      # create-node hook); it now names the hook kind actually registered.
      assert callable(f), "replace_node hook must be a callable."
      self._replace_hooks.append(f)
  def _unregister_replace_node_hook(self, f):
      """
      Unregister ``f``, previously added with ``_register_replace_node_hook``,
      so it is no longer invoked on node replacement.

      Removal uses ``self._replace_hooks.remove(f)``; assuming that is a list
      (as the append/remove usage suggests), only the first matching entry
      is removed, and ``ValueError`` is raised if ``f`` is not registered.
      """
      # Message previously said "create_node" (copy-paste from the
      # create-node hook); it now names the hook kind actually removed.
      assert callable(f), "replace_node hook must be a callable."
      self._replace_hooks.remove(f)
  def _register_create_node_hook(self, f):
      """
      Register ``f`` to run after each node creation.

      ``f`` receives the newly created node and is expected to return None.
      """
      assert callable(f), "create_node hook must be a callable."
      hooks = self._create_node_hooks
      hooks.append(f)
  def _unregister_create_node_hook(self, f):
      """
      Remove ``f``, previously added via ``_register_create_node_hook``, so it
      no longer runs on node creation. Removal goes through ``.remove(f)``
      on the hook container (presumably a list).
      """
      assert callable(f), "create_node hook must be a callable."
      hooks = self._create_node_hooks
      hooks.remove(f)
  def _register_erase_node_hook(self, f):
      """
      Register ``f`` to run after a node is erased.

      ``f`` receives the node being erased and is expected to return None.
      """
      assert callable(f), "erase_node hook must be a callable."
      hooks = self._erase_node_hooks
      hooks.append(f)
  def _unregister_erase_node_hook(self, f):
      """
      Remove ``f``, previously added via ``_register_erase_node_hook``, so it
      no longer runs on node erasure. Removal goes through ``.remove(f)``
      on the hook container (presumably a list).
      """
      assert callable(f), "erase_node hook must be a callable."
      hooks = self._erase_node_hooks
      hooks.remove(f)
  def _register_deepcopy_hook(self, f):
      """
      Register ``f`` to run whenever this graph module is deep-copied.

      ``f`` receives the resulting deep-copied graph module.
      """
      assert callable(f), "deepcopy hook must be a callable."
      hooks = self._deepcopy_hooks
      hooks.append(f)
  def _unregister_deepcopy_hook(self, f):
      """
      Remove ``f``, previously added via ``_register_deepcopy_hook``, so it
      no longer runs on deepcopy. Removal goes through ``.remove(f)`` on the
      hook container (presumably a list).
      """
      assert callable(f), "deepcopy hook must be a callable."
      hooks = self._deepcopy_hooks
      hooks.remove(f)
  # Workarounds for issues in __torch_function__.
  # WAR (workaround) for __torch_function__ not handling tensor lists;
  # the fix is in https://github.com/pytorch/pytorch/pull/34725
  886. # orig_cat = torch.cat
  887. # def patched_cat(*args, **kwargs):
  888. # tensors = args[0]
  889. # for t in tensors:
  890. # if isinstance(t, Proxy):
  891. # return t.__torch_function__(patched_cat, (), args, kwargs)
  892. # return orig_cat(*args, **kwargs)
  893. # patched_cat.__module__ = 'torch'
  894. # patched_cat.__name__ = 'cat'
  895. # torch.cat = patched_cat