unflatten.py 68 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156115711581159116011611162116311641165116611671168116911701171117211731174117511761177117811791180118111821183118411851186118711881189119011911192119311941195119611971198119912001201120212031204120512061207120812091210121112121213121412151216121712181219122012211222122312241225122612271228122912301231123212331234123512361237123812391240124112421243124412451246124712481249125012511252125312541255125612571258125912601261126212631264126512661267126812691270127112721273127412751276127
712781279128012811282128312841285128612871288128912901291129212931294129512961297129812991300130113021303130413051306130713081309131013111312131313141315131613171318131913201321132213231324132513261327132813291330133113321333133413351336133713381339134013411342134313441345134613471348134913501351135213531354135513561357135813591360136113621363136413651366136713681369137013711372137313741375137613771378137913801381138213831384138513861387138813891390139113921393139413951396139713981399140014011402140314041405140614071408140914101411141214131414141514161417141814191420142114221423142414251426142714281429143014311432143314341435143614371438143914401441144214431444144514461447144814491450145114521453145414551456145714581459146014611462146314641465146614671468146914701471147214731474147514761477147814791480148114821483148414851486148714881489149014911492149314941495149614971498149915001501150215031504150515061507150815091510151115121513151415151516151715181519152015211522152315241525152615271528152915301531153215331534153515361537153815391540154115421543154415451546154715481549155015511552155315541555155615571558155915601561156215631564156515661567156815691570157115721573157415751576157715781579158015811582158315841585158615871588158915901591159215931594159515961597159815991600160116021603160416051606160716081609161016111612161316141615161616171618161916201621162216231624162516261627162816291630163116321633163416351636163716381639164016411642164316441645164616471648164916501651165216531654165516561657165816591660166116621663166416651666166716681669167016711672167316741675167616771678167916801681168216831684168516861687168816891690169116921693169416951696169716981699170017011702170317041705170617071708170917101711171217131714171517161717171817191720172117221723172417251726172717281729173017311732173317341735173617371738173917401741174217431744174517461747174817491750175117521753175417551756175717581759176017611762176317641765176617671768
  1. # mypy: allow-untyped-defs
  2. import abc
  3. import copy
  4. import logging
  5. import operator
  6. import re
  7. from collections import defaultdict
  8. from contextlib import contextmanager
  9. from copy import deepcopy
  10. from dataclasses import dataclass
  11. from enum import Enum
  12. from typing import Any, Callable, cast, Optional, Union
  13. import torch
  14. import torch.fx._pytree as fx_pytree
  15. import torch.utils._pytree as pytree
  16. from torch._library.fake_class_registry import FakeScriptObject
  17. from torch.export import ExportedProgram
  18. from torch.export._tree_utils import reorder_kwargs
  19. from torch.export.exported_program import (
  20. ConstantArgument,
  21. ExportGraphSignature,
  22. InputKind,
  23. ModuleCallSignature,
  24. SymBoolArgument,
  25. SymFloatArgument,
  26. SymIntArgument,
  27. TensorArgument,
  28. )
  29. from torch.fx._symbolic_trace import is_fx_symbolic_tracing
  30. from torch.fx.graph_module import _get_attr, _get_attr_via_attr_list, _print_readable
  31. from torch.utils._pytree import GetAttrKey, SequenceKey
  32. from ._remove_effect_tokens_pass import _remove_effect_tokens
# Module-level logger, named after this module's import path.
log = logging.getLogger(__name__)

# Public names exported by a star-import of this module.
__all__ = [
    "FlatArgsAdapter",
    "InterpreterModule",
    "InterpreterModuleDispatcher",
    "UnflattenedModule",
    "unflatten",
]
class _AttrKind(Enum):
    """Kinds of attribute that ``_assign_attr`` can install on a module."""

    PARAMETER = "parameter"  # installed through ``register_parameter``
    BUFFER = "buffer"  # installed through ``register_buffer``
    CONSTANT = "constant"  # tensor or script object, installed through ``setattr``
    MODULE = "module"  # ``torch.nn.Module``, installed through ``setattr``
@dataclass(frozen=True)
class _TensorID:
    """Custom tensor identifier containing storage, stride, size, and storage
    offset information.

    The dataclass is frozen (and keeps the default ``eq=True``), so instances
    are immutable and hashable and can be used as dictionary keys.
    """

    untyped_storage: torch.UntypedStorage  # value of ``Tensor.untyped_storage()``
    stride: tuple  # value of ``Tensor.stride()``
    size: tuple  # value of ``Tensor.size()``
    storage_offset: int  # value of ``Tensor.storage_offset()``
  53. RUN_WITH_INTERPRETER = True
  54. @contextmanager
  55. def _disable_interpreter():
  56. global RUN_WITH_INTERPRETER
  57. old_flag = RUN_WITH_INTERPRETER
  58. RUN_WITH_INTERPRETER = False
  59. try:
  60. yield
  61. finally:
  62. RUN_WITH_INTERPRETER = old_flag
  63. # Assign attribute 'from_obj' to the qualified name 'target' on 'to_module
  64. # This installs empty Modules where none exist yet if they are subpaths of target
  65. def _assign_attr(
  66. from_obj: Union[torch.Tensor, torch.ScriptObject, torch.nn.Module],
  67. to_module: torch.nn.Module,
  68. target: str,
  69. attr_kind: _AttrKind,
  70. persistent: bool = True,
  71. ):
  72. *prefix, field = target.split(".")
  73. # We need to generate all submodules of `to_module` that are at `prefix` and
  74. # variants of `prefix` that differ only by call name. All of these submodules
  75. # will then be assigned `from_obj` at `field` so that they can share this attribute.
  76. # For example, if target is foo.bar.f, foo has another call name foo@1,
  77. # and bar has other call names bar@1, bar@2, then we will assign f to
  78. # foo.bar, foo.bar@1, foo.bar@2, foo@1.bar, foo@1.bar@1, foo@1.bar@2.
  79. to_modules = {to_module}
  80. for item in prefix:
  81. ts: set[torch.nn.Module] = set()
  82. for to_module in to_modules:
  83. if not hasattr(to_module, item):
  84. setattr(to_module, item, torch.nn.Module())
  85. ts.update(
  86. t_call # type: ignore[misc]
  87. for k, t_call in to_module._modules.items()
  88. if _is_call_name(k, item)
  89. )
  90. to_modules = ts
  91. for to_module in to_modules:
  92. if attr_kind == _AttrKind.PARAMETER:
  93. assert isinstance(from_obj, torch.nn.Parameter)
  94. to_module.register_parameter(field, from_obj)
  95. elif attr_kind == _AttrKind.BUFFER:
  96. assert isinstance(from_obj, torch.Tensor)
  97. to_module.register_buffer(field, from_obj, persistent=persistent)
  98. elif attr_kind == _AttrKind.CONSTANT:
  99. assert not isinstance(from_obj, FakeScriptObject), (
  100. "FakeScriptObject should only exist during tracing."
  101. )
  102. assert isinstance(
  103. from_obj,
  104. (
  105. torch.Tensor,
  106. torch.ScriptObject,
  107. ),
  108. )
  109. setattr(to_module, field, from_obj)
  110. elif attr_kind == _AttrKind.MODULE:
  111. assert isinstance(from_obj, torch.nn.Module)
  112. setattr(to_module, field, from_obj)
  113. class _SubmoduleBase:
  114. _ty: Optional[str]
  115. def type_name(self) -> Optional[str]:
  116. """
  117. Subclass of this class - InterpreterModule, InterpreterModuleDispatcher, represents
  118. corresponding model in eager model. To get this type information for those modules
  119. in eager model we need to use this method.
  120. """
  121. return self._ty
  122. class InterpreterModule(_SubmoduleBase, torch.nn.Module):
  123. """A module that uses torch.fx.Interpreter to execute instead of the usual
  124. codegen that GraphModule uses. This provides better stack trace information
  125. and makes it easier to debug execution.
  126. """
  127. graph_module: Optional[torch.fx.GraphModule]
  128. def __init__(
  129. self,
  130. graph: torch.fx.Graph,
  131. ty: Optional[str] = None,
  132. ):
  133. super().__init__()
  134. self.graph = graph
  135. self._ty = ty
  136. self.graph.owning_module = self # type: ignore[assignment]
  137. self._run_with_interpreter = RUN_WITH_INTERPRETER
  138. def forward(self, *args, **kwargs):
  139. assert self.graph_module is not None, "Didn't finalize this InterpreterModule"
  140. if not is_fx_symbolic_tracing() and (
  141. torch.compiler.is_dynamo_compiling() or not self._run_with_interpreter
  142. ):
  143. # Dynamo cannot trace through torch.fx.Interpreter, so fall back to
  144. # GraphModule codegen in this instance.
  145. # Patch the codegened forward to run with this InterpreterModule,
  146. # so attribute accesses, etc. are on this module instead.
  147. return type(self.graph_module).forward(self, *args, **kwargs)
  148. else:
  149. if kwargs:
  150. # Handle **kwargs. FX only natively supports positional
  151. # arguments (through placeholders). So in order to pass in
  152. # kwargs, we must correspond the names of the placeholders with
  153. # the keys in the kwarg dict.
  154. arg_list = list(args)
  155. kwarg_names = self.arg_names[len(arg_list) :]
  156. arg_list.extend(
  157. kwargs[kwarg_name]
  158. for kwarg_name in kwarg_names
  159. if kwarg_name in kwargs
  160. )
  161. # Assert that the kwargs passed in exactly match the positional
  162. # arguments specified by the GraphModule. This should be
  163. # guaranteed by the unflattening process.
  164. assert len(kwarg_names) == len(kwargs)
  165. assert len(arg_list) == len(self.arg_names)
  166. args = tuple(arg_list)
  167. return torch.fx.Interpreter(self, graph=self.graph).run(
  168. *args, enable_io_processing=False
  169. )
  170. def finalize(self):
  171. # We need to "finalize" because GraphModule populates its own state_dict
  172. # based on the get_attrs observed in the graph. So we need to fully
  173. # construct the graph and call _sink_params before generating this
  174. # GraphModule.
  175. # need to set `graph_module` directly on the dict to avoid it getting
  176. # registered as a submodule.
  177. self.__dict__["graph_module"] = torch.fx.GraphModule(self, self.graph)
  178. self.graph.lint()
  179. # Cache arg names for kwarg handling (see forward())
  180. self.arg_names = []
  181. for node in self.graph.nodes:
  182. if node.op == "placeholder":
  183. self.arg_names.append(node.target)
  184. def print_readable(
  185. self,
  186. print_output=True,
  187. include_stride=False,
  188. include_device=False,
  189. colored=False,
  190. ):
  191. return _print_readable(
  192. self,
  193. "InterpreterModule",
  194. print_output,
  195. include_stride,
  196. include_device,
  197. colored,
  198. )
  199. class InterpreterModuleDispatcher(_SubmoduleBase, torch.nn.Module):
  200. """
  201. A module that carries a sequence of InterpreterModules corresponding to
  202. a sequence of calls of that module. Each call to the module dispatches
  203. to the next InterpreterModule, and wraps back around after the last.
  204. """
  205. def __init__(self, attrs: set[str], call_modules: list[InterpreterModule]):
  206. super().__init__()
  207. assert call_modules
  208. self._modules = call_modules[0]._modules
  209. for accessor in attrs:
  210. setattr(self, accessor, getattr(call_modules[0], accessor))
  211. self._ty = call_modules[0]._ty
  212. self._call_modules = call_modules
  213. self._num_calls = 0
  214. def forward(self, *args, **kwargs):
  215. call_module = self._call_modules[self._num_calls]
  216. self._num_calls = (self._num_calls + 1) % len(self._call_modules)
  217. try:
  218. return call_module(*args, **kwargs)
  219. except Exception:
  220. self._num_calls = 0
  221. raise
  222. def call_modules(self):
  223. return self._call_modules
  224. def print_readable(
  225. self,
  226. print_output=True,
  227. include_stride=False,
  228. include_device=False,
  229. colored=False,
  230. ):
  231. outputs = [
  232. mod.print_readable(
  233. print_output,
  234. include_stride,
  235. include_device,
  236. colored,
  237. )
  238. for mod in self._call_modules
  239. ]
  240. return "\n".join(outputs)
  241. class FlatArgsAdapter(abc.ABC):
  242. """
  243. Adapts input arguments with ``input_spec`` to align ``target_spec``.
  244. """
  245. @abc.abstractmethod
  246. def adapt(
  247. self,
  248. target_spec: pytree.TreeSpec,
  249. input_spec: pytree.TreeSpec,
  250. input_args: list[Any],
  251. metadata: Optional[dict[str, Any]] = None,
  252. obj: Optional[Any] = None,
  253. ) -> list[Any]:
  254. """NOTE: This adapter may mutate given ``input_args_with_path``."""
  255. ...
  256. def get_flat_arg_paths(self) -> list[str]:
  257. """Returns a list of paths that are used to access the flat args."""
  258. return []
  259. class UnflattenedModule(torch.nn.Module):
  260. def __init__(
  261. self,
  262. export_module: ExportedProgram,
  263. flat_args_adapter: Optional[FlatArgsAdapter] = None,
  264. ):
  265. super().__init__()
  266. if export_module.graph_signature.backward_signature is not None:
  267. raise ValueError("Unflattening on JointExportModule NYI")
  268. def _id(obj):
  269. """Returns _TensorID dataclass for tensors, otherwise id()."""
  270. if isinstance(obj, torch.Tensor):
  271. return _TensorID(
  272. untyped_storage=obj.untyped_storage(),
  273. stride=obj.stride(),
  274. size=obj.size(),
  275. storage_offset=obj.storage_offset(), # type: ignore[arg-type]
  276. )
  277. return id(obj)
  278. fqn_list = [entry.fqn for entry in export_module.module_call_graph]
  279. assert fqn_list[0] == ""
  280. export_graph = deepcopy(export_module.graph)
  281. self.graph_signature = deepcopy(export_module.graph_signature)
  282. self.graph = torch.fx.Graph()
  283. self.graph.owning_module = self # type: ignore[assignment]
  284. self.module_call_graph = deepcopy(export_module.module_call_graph)
  285. self.flat_args_adapter = flat_args_adapter
  286. self.meta = export_module.graph_module.meta
  287. self.meta["unflattened_module"] = self
  288. # Flag to indicate whether args have been adapted.
  289. self.adapted = False
  290. self._run_with_interpreter = RUN_WITH_INTERPRETER
  291. _inplace_buffer_and_input_mutations(export_graph, self.graph_signature)
  292. _fix_nn_module_stacks(export_graph)
  293. self.ivals = _IVals()
  294. # for any intermediate value of a mutation that is read, track the mutation
  295. seen_modules, seen_attrs = _outline_submodules(export_graph, self)
  296. # for each read intermediate value of a mutation, find where it was created,
  297. # and perform the mutation
  298. self.ivals.update(seen_modules.values())
  299. # move attributes that correspond to graph arguments for HOPs
  300. # from exported program to unflattened submodules
  301. _copy_graph_attrs(export_module._graph_module, self, seen_attrs)
  302. self.range_constraints = export_module.range_constraints
  303. self.equality_constraints: list = []
  304. # aliasing/unused param or buffer issues:
  305. # in strict-mode export, dynamo export will deduplicate aliased tensors,
  306. # and ignore unused tensors. For aliasing, this causes issues when some aliases
  307. # are unused, and we're unable to match the placeholder node to the correct FQN.
  308. # This leads to the graph signature potentially having the wrong target FQN,
  309. # and downstream issues where parameters are assigned to the wrong target attribute,
  310. # mismatching the relevant placeholder node in the unflattened module.
  311. # To resolve this we restore (_assign_attr) all aliased/unused tensors in
  312. # the state_dict as module attributes, but only keep the used tensors in the
  313. # graph's forward pass (_sink_params).
  314. state_dict = export_module.state_dict
  315. assigned_params: set[str] = set() # tracking unused params
  316. id_to_param: dict[
  317. Union[int, _TensorID], torch.nn.Parameter
  318. ] = {} # handling weight-sharing
  319. for name in self.graph_signature.parameters: # this loop adds used params
  320. param = state_dict[name]
  321. if _id(param) not in id_to_param:
  322. id_to_param[_id(param)] = torch.nn.Parameter(
  323. param.clone(), requires_grad=param.requires_grad
  324. )
  325. _assign_attr(
  326. id_to_param[_id(param)],
  327. self,
  328. name,
  329. attr_kind=_AttrKind.PARAMETER,
  330. )
  331. assigned_params.add(name)
  332. non_persistent_buffers = set(self.graph_signature.non_persistent_buffers)
  333. assigned_buffers: set[str] = set() # tracking unused buffers
  334. id_to_buffer: dict[Union[int, _TensorID], tuple[torch.nn.Parameter, bool]] = {}
  335. for name in self.graph_signature.buffers: # this loop adds used buffers
  336. if name in non_persistent_buffers:
  337. persistent = False
  338. buffer = export_module.constants[name]
  339. else:
  340. persistent = True
  341. buffer = state_dict[name]
  342. if _id(buffer) not in id_to_buffer:
  343. id_to_buffer[_id(buffer)] = (buffer.clone(), persistent)
  344. _assign_attr(
  345. id_to_buffer[_id(buffer)][0],
  346. self,
  347. name,
  348. attr_kind=_AttrKind.BUFFER,
  349. persistent=persistent,
  350. )
  351. assigned_buffers.add(name)
  352. # restore aliased/unused params and buffers
  353. # these appear in state dict but not graph signature
  354. for name, tensor in state_dict.items():
  355. if name in assigned_params or name in assigned_buffers: # already assigned
  356. continue
  357. is_buffer = False
  358. if _id(tensor) in id_to_buffer or not isinstance(
  359. tensor, torch.nn.Parameter
  360. ): # aliased buffer
  361. is_buffer = True
  362. if is_buffer:
  363. if (
  364. _id(tensor) not in id_to_buffer
  365. ): # this is completely unused (not weight-sharing)
  366. id_to_buffer[_id(tensor)] = (
  367. tensor,
  368. True,
  369. ) # assign to respect original model
  370. _assign_attr(
  371. id_to_buffer[_id(tensor)][0],
  372. self,
  373. name,
  374. attr_kind=_AttrKind.BUFFER,
  375. persistent=True,
  376. )
  377. else:
  378. if _id(tensor) not in id_to_param: # this is unused
  379. id_to_param[_id(tensor)] = tensor
  380. _assign_attr(
  381. id_to_param[_id(tensor)],
  382. self,
  383. name,
  384. attr_kind=_AttrKind.PARAMETER,
  385. )
  386. # use id map so we don't double-clone aliased constants
  387. id_to_const: dict[
  388. Union[int, _TensorID], Union[torch.Tensor, torch._C.ScriptObject]
  389. ] = {}
  390. for fqn, constant in export_module.constants.items():
  391. if _id(constant) not in id_to_const:
  392. if isinstance(constant, torch.Tensor):
  393. constant = constant.clone()
  394. id_to_const[_id(constant)] = constant
  395. _constant = id_to_const[_id(constant)]
  396. _assign_attr(
  397. _constant,
  398. self,
  399. fqn,
  400. attr_kind=_AttrKind.CONSTANT,
  401. )
  402. # This is to handle parameters/buffers that point to the same tensor
  403. # object id -> list of (node_name, target_name)
  404. consts_map: dict[Union[int, _TensorID], list[tuple[str, str]]] = defaultdict(
  405. list
  406. )
  407. consts_targets: set[str] = set()
  408. def add_to_consts_map(obj_id, node_name, target_name):
  409. name_list = consts_map[obj_id]
  410. name_list.append((node_name, target_name))
  411. # track aliased/unused params, buffers
  412. # prefer using untyped_storage() over id() when it's available
  413. added_params_buffers: set[str] = set()
  414. for s in self.graph_signature.input_specs:
  415. if s.kind == InputKind.PARAMETER or (
  416. s.kind == InputKind.BUFFER and s.persistent
  417. ):
  418. assert hasattr(s.arg, "name")
  419. assert isinstance(s.target, str)
  420. add_to_consts_map(
  421. _id(export_module.state_dict[s.target]),
  422. s.arg.name,
  423. s.target,
  424. )
  425. consts_targets.add(s.target)
  426. added_params_buffers.add(s.target)
  427. elif (
  428. s.kind == InputKind.BUFFER
  429. and not s.persistent
  430. or s.kind == InputKind.CONSTANT_TENSOR
  431. or s.kind == InputKind.CUSTOM_OBJ
  432. ):
  433. assert hasattr(s.arg, "name")
  434. assert isinstance(s.target, str)
  435. add_to_consts_map(
  436. _id(export_module.constants[s.target]),
  437. s.arg.name,
  438. s.target,
  439. )
  440. consts_targets.add(s.target)
  441. # add constants that are aliased and don't appear in graph signature
  442. for const_name, const in export_module.constants.items():
  443. if const_name not in consts_targets:
  444. const_id = _id(const)
  445. assert const_id in consts_map
  446. ph_name, _ = consts_map[const_id][0]
  447. add_to_consts_map(const_id, ph_name, const_name)
  448. added_params_buffers.add(s.target)
  449. # add aliased/unused params and buffers that don't appear in graph signature
  450. for fqn, tensor in export_module.state_dict.items():
  451. if fqn not in added_params_buffers:
  452. tensor_id = _id(tensor)
  453. if tensor_id not in consts_map:
  454. # completely unused (no weight-sharing), ignore.
  455. # this weight doesn't appear in graph module,
  456. # so won't cause FQN assignment issues
  457. continue
  458. ph_name, _ = consts_map[tensor_id][0]
  459. add_to_consts_map(tensor_id, ph_name, fqn)
  460. # node name -> list of possible targets
  461. inputs_to_state: dict[str, list[str]] = {}
  462. for node_target in consts_map.values():
  463. targets = [t[1] for t in node_target]
  464. for n, _ in node_target:
  465. inputs_to_state[n] = targets
  466. _sink_params(self, inputs_to_state, [])
  467. redirected_call_indices = _deduplicate_modules(seen_modules.values())
  468. fqn_list = [fqn for fqn in fqn_list if fqn not in redirected_call_indices]
  469. self._dispatch_modules(redirected_call_indices, consts_targets)
  470. fqn_list = [fqn for fqn in fqn_list if "@" not in fqn]
  471. # Cache so we don't have to compute this every time.
  472. # NOTE: this needs to be kept in sync with the placeholders in
  473. # self.graph, but currently we have no way to guarantee that.
  474. self.input_placeholders = [
  475. node for node in self.graph.nodes if node.op == "placeholder"
  476. ]
  477. self.check_input_constraints = True
  478. # TODO(zhxchen17) We can register modules ahead of time instead of reorder later.
  479. fqn_order = {fqn: i for i, fqn in enumerate(fqn_list)}
  480. # In the case of legacy IR, we might be missing some modules from metadata.
  481. for name, _ in self.named_modules(remove_duplicate=False):
  482. if name not in fqn_order:
  483. fqn_order[name] = len(fqn_order)
  484. _reorder_submodules(self, fqn_order)
  485. self.graph.lint()
  486. self.finalize()
  487. def _print_graph(self):
  488. for fqn, mod in self.named_modules():
  489. print(fqn + ":")
  490. if hasattr(mod, "graph") and isinstance(mod.graph, torch.fx.Graph):
  491. print(mod.graph)
  492. def _adapt_flat_args(self, flat_args, in_spec, input):
  493. signature = self.module_call_graph[0].signature
  494. if in_spec == signature.in_spec:
  495. return flat_args
  496. if self.flat_args_adapter is None:
  497. raise TypeError(
  498. "There is no flat args adapter specified. "
  499. "Are you sure you are calling this with the right arguments? "
  500. )
  501. else:
  502. flat_args = self.flat_args_adapter.adapt(
  503. target_spec=signature.in_spec,
  504. input_spec=in_spec,
  505. input_args=flat_args,
  506. metadata=self.meta,
  507. obj=input,
  508. )
  509. if len(flat_args) != signature.in_spec.num_leaves:
  510. raise TypeError(
  511. f"Flat args adaption failed, number of args mismatch "
  512. f"Adatped: {len(flat_args)} \n"
  513. f"Exported module: {signature.in_spec.num_leaves}"
  514. )
  515. return flat_args
  516. def process_forward_inputs(self, *args, **kwargs):
  517. signature = self.module_call_graph[0].signature
  518. reordered_kwargs = kwargs
  519. if kwargs:
  520. reordered_kwargs = reorder_kwargs(kwargs, signature.in_spec)
  521. flat_args_with_path, in_spec = pytree.tree_flatten_with_path(
  522. (args, reordered_kwargs)
  523. )
  524. flat_args = [x[1] for x in flat_args_with_path]
  525. if is_fx_symbolic_tracing():
  526. return flat_args
  527. if in_spec != signature.in_spec:
  528. if not self.adapted:
  529. print(
  530. "Input treespec does not match with exported module's: \n"
  531. f"Input treespec: {in_spec}. ",
  532. f"Exported module treespec: {signature.in_spec}",
  533. )
  534. print("Adapting flat arg to match exported module's treespec")
  535. flat_args = self._adapt_flat_args(flat_args, in_spec, args)
  536. self.adapted = True
  537. if self.check_input_constraints:
  538. # Import here to avoid an unfortunate circular dependency.
  539. # TODO(suo): untangle this.
  540. from torch._export.utils import _check_input_constraints_for_graph
  541. if self.adapted is True:
  542. flat_arg_paths = (
  543. self.flat_args_adapter.get_flat_arg_paths()
  544. if self.flat_args_adapter
  545. else []
  546. )
  547. assert not flat_arg_paths or len(flat_arg_paths) == len(flat_args)
  548. new_flat_args_with_path = [ # type: ignore[var-annotated]
  549. (
  550. (
  551. SequenceKey(idx=idx),
  552. GetAttrKey(
  553. name=flat_arg_paths[idx]
  554. if flat_arg_paths
  555. else "<unknown location>"
  556. ),
  557. ),
  558. arg,
  559. )
  560. for idx, arg in enumerate(flat_args)
  561. ]
  562. else:
  563. new_flat_args_with_path = flat_args_with_path # type: ignore[assignment]
  564. _check_input_constraints_for_graph(
  565. self.input_placeholders, new_flat_args_with_path, self.range_constraints
  566. )
  567. return flat_args
  568. def forward(self, *args, **kwargs):
  569. flat_args = self.process_forward_inputs(*args, **kwargs)
  570. signature = self.module_call_graph[0].signature
  571. if is_fx_symbolic_tracing():
  572. return_val = torch.fx.Interpreter(self, graph=self.graph).run(
  573. *flat_args, enable_io_processing=False
  574. )
  575. # For scalar return value, fx.Graph wraps in a tuple
  576. if isinstance(return_val, tuple) and len(return_val) == 1:
  577. return return_val[0]
  578. return return_val
  579. if torch.compiler.is_dynamo_compiling() or not self._run_with_interpreter:
  580. tree_out = type(self.graph_module).forward(self, *flat_args) # type: ignore[union-attr]
  581. else:
  582. tree_out = torch.fx.Interpreter(self, graph=self.graph).run(
  583. *flat_args, enable_io_processing=False
  584. )
  585. return pytree.tree_unflatten(tree_out, signature.out_spec)
  586. def finalize(self):
  587. self.__dict__["graph_module"] = torch.fx.GraphModule(self, self.graph)
  588. self.graph.lint()
  589. def _dispatch_modules(self, redirected_call_indices, consts_targets):
  590. """For a module whose call signatures are preserved, replace
  591. multiple modules corresponding to multiple calls to that module
  592. with a single dispatcher module that tracks which module to call.
  593. """
  594. # for each fqn whose module call signature is preserved,
  595. # map that fqn to a list of called modules
  596. called_modules = defaultdict(list)
  597. for entry in self.module_call_graph:
  598. if entry.fqn and entry.signature:
  599. # some modules were removed and their fqns redirected to other
  600. # fqns during deduplication
  601. fqn = entry.fqn
  602. mod = _get_attr(self, redirected_call_indices.get(fqn, fqn))
  603. base, idx = fqn.split("@") if "@" in fqn else [fqn, "0"]
  604. called_modules[base].append((int(idx), mod))
  605. attrs_map = defaultdict(set)
  606. for target in consts_targets:
  607. if "." in target:
  608. orig_fqn, name = target.rsplit(".", 1)
  609. attrs_map[orig_fqn].add(name)
  610. else:
  611. attrs_map[""].add(target)
  612. # replace multiple call modules with a single dispatcher module
  613. for orig_fqn, indexed_call_modules in called_modules.items():
  614. call_modules = [mod for _, mod in sorted(indexed_call_modules)]
  615. if len(call_modules) > 1:
  616. for i in range(len(call_modules)):
  617. fqn = _call_name(orig_fqn, i + 1)
  618. if fqn not in redirected_call_indices:
  619. *prefix, name = fqn.split(".")
  620. _get_attr_via_attr_list(self, prefix)._modules.pop(name)
  621. self.set_submodule(
  622. orig_fqn,
  623. InterpreterModuleDispatcher(attrs_map[orig_fqn], call_modules),
  624. )
  625. # elide call indices in call modules because they are
  626. # tracked automatically inside the dispatcher module
  627. def elide_call_indices(prefix, graph):
  628. for node in graph.nodes:
  629. if node.op == "call_module":
  630. fqn = node.target.split("@")[0]
  631. path = f"{prefix}.{fqn}" if prefix else fqn
  632. if path in called_modules:
  633. node.target = fqn
  634. for fqn, mod in self.named_modules(remove_duplicate=False):
  635. if hasattr(mod, "graph"):
  636. elide_call_indices(fqn, mod.graph)
  637. elif hasattr(mod, "_call_modules"):
  638. for mod_ in mod._call_modules:
  639. assert hasattr(mod_, "graph")
  640. elide_call_indices(fqn, mod_.graph)
  641. def print_readable(
  642. self,
  643. print_output=True,
  644. include_stride=False,
  645. include_device=False,
  646. colored=False,
  647. ):
  648. return _print_readable(
  649. self,
  650. "UnflattenedModule",
  651. print_output,
  652. include_stride,
  653. include_device,
  654. colored,
  655. )
  656. def unflatten(
  657. module: ExportedProgram, flat_args_adapter: Optional[FlatArgsAdapter] = None
  658. ) -> UnflattenedModule:
  659. """Unflatten an ExportedProgram, producing a module with the same module
  660. hierarchy as the original eager module. This can be useful if you are trying
  661. to use :mod:`torch.export` with another system that expects a module
  662. hierarchy instead of the flat graph that :mod:`torch.export` usually produces.
  663. .. note:: The args/kwargs of unflattened modules will not necessarily match
  664. the eager module, so doing a module swap (e.g. :code:`self.submod =
  665. new_mod`) will not necessarily work. If you need to swap a module out, you
  666. need to set the :code:`preserve_module_call_signature` parameter of
  667. :func:`torch.export.export`.
  668. Args:
  669. module (ExportedProgram): The ExportedProgram to unflatten.
  670. flat_args_adapter (Optional[FlatArgsAdapter]): Adapt flat args if input TreeSpec does not match with exported module's.
  671. Returns:
  672. An instance of :class:`UnflattenedModule`, which has the same module
  673. hierarchy as the original eager module pre-export.
  674. """
  675. module = _remove_effect_tokens(module)
  676. m = UnflattenedModule(module, flat_args_adapter)
  677. # Disable process_forward_inputs as the adapter has many
  678. # non-dynamo-traceable behavior.
  679. m.process_forward_inputs = torch._dynamo.disable( # type: ignore[method-assign]
  680. m.process_forward_inputs,
  681. reason="do not trace into preprocessing the inputs",
  682. recursive=True,
  683. )
  684. return m
  685. def _inplace_buffer_and_input_mutations(
  686. graph: torch.fx.Graph,
  687. graph_signature: ExportGraphSignature,
  688. ) -> None:
  689. """Transform buffer and input mutations from their functionalized form
  690. into copy_ nodes in the graph.
  691. Functionalization represents a buffer mutation by passing the buffer as
  692. an input and output. For example, consider the eager code:
  693. def forward(self, x):
  694. self.buffer += x
  695. return x * x
  696. This corresponds to a graph that looks like:
  697. def forward(self, buffer, x):
  698. mutated_buffer = aten.add(buffer, x)
  699. mul = aten.mul(x, x)
  700. return (mutated_buffer, mul)
  701. We want to inplace this into something that looks like the original
  702. eager code:
  703. def forward(self, buffer, x):
  704. mutated_buffer = aten.add(buffer, x)
  705. buffer.copy_(mutated_buffer)
  706. mul = aten.mul(x, x)
  707. return (mul,)
  708. Input mutations are handled similarly.
  709. """
  710. output_node = next(iter(reversed(graph.nodes)))
  711. assert output_node.op == "output" and len(output_node.args) == 1
  712. return_args = output_node.args[0]
  713. input_name_to_node = {
  714. node.name: node for node in graph.nodes if node.op == "placeholder"
  715. }
  716. mutation_name_to_input_name = {}
  717. # Collect mutated buffers.
  718. buffer_fqn_to_input_name = {
  719. buffer_fqn: k for k, buffer_fqn in graph_signature.inputs_to_buffers.items()
  720. }
  721. mutation_name_to_input_name = {
  722. k: buffer_fqn_to_input_name[buffer_fqn]
  723. for k, buffer_fqn in graph_signature.buffers_to_mutate.items()
  724. }
  725. # Collect mutated user inputs.
  726. mutation_name_to_input_name.update(graph_signature.user_inputs_to_mutate)
  727. num_mutations = len(mutation_name_to_input_name)
  728. for mutation in return_args[:num_mutations]:
  729. input_name = mutation_name_to_input_name[mutation.name]
  730. input_node = input_name_to_node[input_name]
  731. with graph.inserting_after(mutation):
  732. # Create a copy_ node that inplaces the mutation.
  733. new_node = graph.create_node(
  734. "call_function", torch.ops.aten.copy_.default, (input_node, mutation)
  735. )
  736. for k, v in mutation.meta.items():
  737. new_node.meta[k] = v
  738. # Replace all uses of the previously functional mutation with
  739. # our copy_ node.
  740. mutation.replace_all_uses_with(new_node, lambda x: x is not new_node)
  741. # Remove the mutated buffer / input from the graph outputs, since we don't
  742. # need to thread it through anymore.
  743. user_outputs = tuple(return_args[num_mutations:])
  744. output_node.args = ((user_outputs),)
  745. def _fix_nn_module_stacks(graph):
  746. # For each nn module stack in the graph, check if the fqns in it represent a stack:
  747. # 1. Each fqn must be a prefix of the next fqn.
  748. # 2. If not, remove the entries starting from the next fqn, emitting a warning.
  749. for node in graph.nodes:
  750. if "nn_module_stack" not in node.meta:
  751. continue
  752. nn_module_stack = node.meta["nn_module_stack"]
  753. fqns = [
  754. fqn.split("@")[0] if "@" in fqn else fqn
  755. for fqn, _t in nn_module_stack.values()
  756. ]
  757. # Check if each FQN is a prefix of the next one
  758. prev_fqn, *next_fqns = fqns
  759. num_valid_indices = 1 # root FQN
  760. for curr_fqn in next_fqns:
  761. # Check if the previous FQN is a prefix of the current one
  762. if _is_prefix(prev_fqn, curr_fqn):
  763. num_valid_indices += 1
  764. prev_fqn = curr_fqn
  765. else:
  766. # Found a non-prefix FQN, stop here
  767. break
  768. # If we need to remove entries, create a new stack with only valid entries
  769. if num_valid_indices < len(nn_module_stack):
  770. log.warning(
  771. "nn_module_stack fqns %s at node %s do not form a stack! dropping last %d entries",
  772. fqns,
  773. node,
  774. len(nn_module_stack) - num_valid_indices,
  775. )
  776. node.meta["nn_module_stack"] = dict(
  777. list(nn_module_stack.items())[:num_valid_indices]
  778. )
  779. def _is_prefix(candidate, target):
  780. """Check whether `candidate` is a prefix of `target`."""
  781. return len(candidate) < len(target) and target[: len(candidate)] == candidate
  782. def _compute_accessor(parent_fqn: str, child_fqn: str) -> str:
  783. if parent_fqn == "":
  784. # Handle the root module correctly.
  785. return child_fqn
  786. parent_split = parent_fqn.split(".")
  787. child_split = child_fqn.split(".")
  788. # TODO: support skip connection by inlining the child module.
  789. if child_split[: len(parent_split)] != parent_split:
  790. raise RuntimeError(
  791. f"Child module '{child_fqn}' is not a descendant of parent module '{parent_fqn}'."
  792. "This is currently unsupported."
  793. "Please try to make child module attach to parent module directly."
  794. )
  795. return ".".join(child_split[len(parent_split) :])
  796. def _check_graph_equivalence(x: torch.nn.Module, y: torch.nn.Module):
  797. def graph_dump(graph: torch.fx.Graph) -> str:
  798. ret = []
  799. nodes_idx: dict[int, int] = {}
  800. def arg_dump(arg) -> str:
  801. if isinstance(arg, torch.fx.Node):
  802. return "%" + str(nodes_idx[id(arg)])
  803. return str(arg)
  804. for i, node in enumerate(graph.nodes):
  805. args_dump = [str(arg) for arg in pytree.tree_map(arg_dump, node.args)]
  806. args_dump += [
  807. f"{key}={value}"
  808. for key, value in pytree.tree_map(arg_dump, node.kwargs).items()
  809. ]
  810. target = node.target if node.op in ("call_function", "get_attr") else ""
  811. ret.append(f"{i}: {node.op}[{target}]({', '.join(args_dump)})")
  812. nodes_idx[id(node)] = i
  813. return "\n".join(ret)
  814. assert isinstance(x.graph, torch.fx.Graph)
  815. assert isinstance(y.graph, torch.fx.Graph)
  816. return graph_dump(x.graph) == graph_dump(y.graph)
  817. def _add_spec(gm: torch.nn.Module, spec) -> str:
  818. i = 0
  819. while hasattr(gm, f"_spec_{i}"):
  820. i += 1
  821. name = f"_spec_{i}"
  822. setattr(gm, name, spec)
  823. return name
  824. def _generate_flatten(gm: torch.fx.GraphModule, node) -> torch.fx.Node:
  825. flatten = gm.graph.call_function(pytree.tree_flatten, (node,))
  826. getitem_0 = gm.graph.call_function(operator.getitem, (flatten, 0))
  827. return getitem_0
  828. def _generate_flatten_spec(
  829. gm: Union[torch.fx.GraphModule, InterpreterModule, UnflattenedModule], node, spec
  830. ) -> torch.fx.Node:
  831. name = _add_spec(gm, spec)
  832. spec_node = gm.graph.get_attr(name)
  833. return gm.graph.call_function(fx_pytree.tree_flatten_spec, (node, spec_node))
  834. def _generate_unflatten(
  835. gm: Union[torch.fx.GraphModule, InterpreterModule, UnflattenedModule], nodes, spec
  836. ) -> torch.fx.Node:
  837. name = _add_spec(gm, spec)
  838. spec_node = gm.graph.get_attr(name)
  839. return gm.graph.call_function(pytree.tree_unflatten, (nodes, spec_node))
  840. def _get_submodule(mod: torch.nn.Module, target: str):
  841. *prefix, field = target.split(".")
  842. for item in prefix:
  843. submod = getattr(mod, item, None)
  844. if submod is None:
  845. return None
  846. if not isinstance(submod, torch.nn.Module):
  847. return None
  848. mod = submod
  849. return getattr(mod, field, None)
  850. def _add_submodule(
  851. mod: torch.nn.Module,
  852. target: str,
  853. module_to_add: torch.nn.Module,
  854. create_module: Optional[Callable[[str], torch.nn.Module]] = None,
  855. ):
  856. *prefix, field = target.split(".")
  857. for i, item in enumerate(prefix):
  858. submod = getattr(mod, item, None)
  859. if submod is None:
  860. if create_module is not None:
  861. submod = create_module(".".join(prefix[: i + 1]))
  862. else:
  863. submod = torch.nn.Module()
  864. setattr(mod, item, submod)
  865. if not isinstance(submod, torch.nn.Module):
  866. return False
  867. mod = submod
  868. mod.add_module(field, module_to_add)
  869. def _call_name(base: str, n: int) -> str:
  870. # Given n >= 0, generate call names to a submodule `base` of the form
  871. # `base`, `base@1`, `base@2`, etc.
  872. return base if n == 1 else f"{base}@{n - 1}"
  873. def _is_call_name(call_name: str, base: str) -> bool:
  874. # Recognize when call_name = _call_name(base, n) for some n >= 0.
  875. return re.match(re.escape(base) + r"(@\d+)?$", call_name) is not None
class _ModuleFrame:
    """One frame of the walk that outlines a flat graph into nested modules.

    A frame owns one module and that module's fx graph. ``run_from`` copies
    flat-graph nodes whose ``nn_module_stack`` equals this frame's
    ``module_stack`` into the frame's graph, creates a nested frame when a
    node's stack extends it, and finalizes the frame (emitting its output node
    and wiring results into the parent graph) when a node's stack no longer
    starts with it.
    """

    def __init__(
        self,
        flat_graph: torch.fx.Graph,
        nodes: tuple[torch.fx.Node, ...],
        seen_nodes,
        seen_modules,
        seen_attrs,
        created_modules,
        parent,
        module_stack: list[tuple[str, Optional[str], int]],
        module_id,
        module_call_graph: dict[str, ModuleCallSignature],
        module: Optional[Union[torch.fx.GraphModule, UnflattenedModule]] = None,
    ):
        """Set up the frame and, for non-root frames, hook it into the parent.

        Args:
            flat_graph: the flat graph being outlined.
            nodes: the nodes of the flat graph, indexed by ``run_from``.
            seen_nodes: shared mapping from node name to flat-graph node,
                filled in by ``copy_node``.
            seen_modules: shared mapping from ``module_id`` to the list of
                ``_SubmoduleEntry`` records created for it.
            seen_attrs: shared mapping from call name to the set of
                ``get_attr`` targets encountered under it.
            created_modules: shared mapping from fqn to modules created for
                frames that have not been finalized yet.
            parent: the enclosing frame, or None for the outermost frame.
            module_stack: ``(fqn, type, call count)`` entries; the last entry
                describes this frame.
            module_id: key under which this frame is recorded in
                ``seen_modules``.
            module_call_graph: mapping from call name to preserved signature.
            module: module to populate; when None, one is looked up in
                ``created_modules`` or created fresh.
        """
        self.flat_graph = flat_graph
        self.nodes = nodes
        self.seen_nodes = seen_nodes
        self.seen_modules = seen_modules
        self.seen_attrs = seen_attrs
        self.created_modules = created_modules
        self.parent = parent
        self.module_stack = module_stack
        self.module_id = module_id

        self.module_call_graph = module_call_graph
        # Gates the debug output produced through self.print.
        self.verbose = False

        self.fqn, ty, num_calls = self.module_stack[-1]
        # generate call name for self.fqn
        self.child_fqn = _call_name(self.fqn, num_calls + 1)

        self.module: Union[torch.fx.GraphModule, UnflattenedModule, InterpreterModule]
        if module is not None:
            self.module = module
            self.ivals = module.ivals if hasattr(module, "ivals") else {}  # type: ignore[var-annotated]
        else:
            # Reuse a module already created for this fqn, if any; frames
            # without an explicit module share the parent's ivals.
            self.module = self.created_modules.get(
                self.fqn,
                InterpreterModule(torch.fx.Graph(), ty=ty),
            )
            self.ivals = parent.ivals

        self.graph = self.module.graph

        # Mapping of nodes in the flat graph to nodes in this graph.
        self.node_map: dict[torch.fx.Node, torch.fx.Node] = {}
        # Mapping of flat-graph nodes to the nodes that stand in for them as
        # inputs of this graph.
        self.node_to_placeholder = {}

        self.parent_call_module: Optional[torch.fx.Node] = None
        if parent is not None:
            accessor = _compute_accessor(parent.fqn, self.child_fqn)

            # Factory handed to _add_submodule for intermediate modules on the
            # accessor path; results are cached in created_modules by full path.
            def create_module(fqn):
                path = f"{parent.fqn}.{fqn}" if parent.fqn else fqn
                if path in self.created_modules:
                    return self.created_modules[path]
                submod = InterpreterModule(torch.fx.Graph(), ty=ty)
                self.created_modules[path] = submod
                return submod

            _add_submodule(parent.module, accessor, self.module, create_module)
            self.parent_call_module = parent.graph.call_module(accessor)
            if self.seen_modules[self.module_id]:
                # A frame with this module_id was recorded before: share the
                # first recorded module's _modules dict with this module.
                base_module_frame = self.seen_modules[self.module_id][0]
                self.module._modules = base_module_frame.module._modules
            self.seen_modules[self.module_id].append(
                _SubmoduleEntry(
                    parent_fqn=self.parent.fqn,
                    parent_module=self.parent.module,
                    parent_call_module=self.parent_call_module,
                    fqn=self.fqn,
                    call_idx=num_calls + 1,
                    module=self.module,
                )
            )

        signature = module_call_graph.get(self.child_fqn)
        if signature is not None and self.parent is not None:
            # A signature is recorded for this call: build the input side of
            # the graph from it.
            assert signature.in_spec.num_children == 2
            args_spec = signature.in_spec.children_specs[0]
            kwargs_spec = signature.in_spec.children_specs[1]
            assert args_spec.context is None
            assert kwargs_spec.context is not None

            with self.graph.inserting_after(None):
                # One placeholder per positional argument and per keyword.
                arg_nodes = [
                    self.graph.placeholder(f"_positional_arg_{idx}")
                    for idx in range(args_spec.num_children)
                ]
                kwarg_nodes = {}
                for name in kwargs_spec.context:
                    kwarg_nodes[name] = self.graph.placeholder(name)
                # Flatten (args, kwargs) inside this graph, then expose each
                # flat input through a getitem node.
                flat_args = _generate_flatten_spec(
                    self.module,
                    (tuple(arg_nodes), kwarg_nodes),
                    signature.in_spec,
                )
                for idx, arg in enumerate(signature.inputs):
                    flat_arg_node = self.graph.create_node(
                        op="call_function",
                        target=operator.getitem,
                        args=(flat_args, idx),
                        name=(
                            arg.name
                            if not isinstance(arg, ConstantArgument)
                            else f"_constant_{idx}"
                        ),
                    )
                    if isinstance(arg, ConstantArgument):
                        continue

                    if arg.name in self.seen_nodes:
                        # The getitem node stands in for the flat-graph node
                        # of the same name within this frame.
                        flat_arg_node.meta = copy.copy(self.seen_nodes[arg.name].meta)
                        self.node_to_placeholder[self.seen_nodes[arg.name]] = (
                            flat_arg_node
                        )

            with self.parent.graph.inserting_before(self.parent_call_module):
                # Build the call arguments in the parent graph: collect the
                # flat inputs, unflatten them with in_spec, then index out
                # the positional and keyword arguments.
                input_nodes: list[Optional[torch.fx.Node]] = []
                for input in signature.inputs:
                    if isinstance(input, ConstantArgument):
                        input_nodes.append(input.value)  # type: ignore[arg-type]
                    elif input.name not in self.seen_nodes:
                        input_nodes.append(None)
                    else:
                        assert isinstance(
                            input,
                            (
                                TensorArgument,
                                SymIntArgument,
                                SymBoolArgument,
                                SymFloatArgument,
                            ),
                        )
                        input_nodes.append(
                            self.parent.remap_input(self.seen_nodes[input.name])
                        )

                inputs_node = _generate_unflatten(
                    self.parent.module,
                    input_nodes,
                    signature.in_spec,
                )

                args_node = self.parent.graph.call_function(
                    operator.getitem, (inputs_node, 0)
                )
                kwargs_node = self.parent.graph.call_function(
                    operator.getitem, (inputs_node, 1)
                )
                arg_nodes = [
                    self.parent.graph.call_function(operator.getitem, (args_node, i))
                    for i in range(args_spec.num_children)
                ]
                kwarg_nodes = {
                    k: self.parent.graph.call_function(
                        operator.getitem, (kwargs_node, k)
                    )
                    for k in kwargs_spec.context
                }
            assert self.parent_call_module is not None
            self.parent_call_module.args = tuple(arg_nodes)
            self.parent_call_module.kwargs = kwarg_nodes  # type: ignore[assignment]

    def add_placeholder(self, x):
        """Create a placeholder in this graph for flat-graph node `x`.

        The placeholder takes `x`'s name and type, receives a shallow copy of
        `x`'s meta, and is recorded in ``node_to_placeholder``. Not allowed on
        the root frame.
        """
        assert self.fqn != "", f"Cannot add placeholder {x} to root module"
        assert x.graph is self.flat_graph
        # x is not in subgraph, create a new placeholder for subgraph
        with self.graph.inserting_before(None):
            placeholder_node = self.graph.placeholder(x.name, type_expr=x.type)
        # copy all meta fields, even if some fields might be irrelevant for
        # the placeholder node
        placeholder_node.meta = copy.copy(x.meta)
        self.node_to_placeholder[x] = placeholder_node

    def copy_sym_call_function(self, x):
        """Re-create call_function node `x` in this graph and record it in ``node_map``.

        Node arguments are remapped through ``remap_input``; meta is
        shallow-copied from `x`. Returns the new node.
        """
        # This only exists because we deduplicate sym_size nodes in the flat export graph,
        # and if preserve_module_call_signature is set, we may not be able to pass sym_size
        # nodes, or their downstream users, as inputs to submodule calls.
        # To avoid this we copy these call_function nodes with sym_type results.
        # This should however only be done for sym_type nodes - call_function nodes on tensors
        # should not be deduplicated in the first place.
        args = pytree.tree_map_only(torch.fx.Node, self.remap_input, x.args)
        kwargs = pytree.tree_map_only(torch.fx.Node, self.remap_input, x.kwargs)
        node = self.graph.call_function(x.target, args, kwargs)
        node.meta = copy.copy(x.meta)
        self.node_map[x] = node
        return node

    def remap_input(self, x):
        """Return the node in this frame's graph that represents flat-graph node `x`.

        Resolution order: an existing ``node_map`` entry; an existing
        ``node_to_placeholder`` entry; a newly added placeholder (also
        prepended to the parent's call arguments); a fresh copy of `x` for the
        listed call_function targets; otherwise ``self.ivals.read`` when a
        signature is recorded for ``self.fqn``.

        Raises:
            RuntimeError: if none of the cases above applies.
        """
        assert x.graph is self.flat_graph
        if x in self.node_map:
            return self.node_map[x]
        self.print(f"remap_input({x})")
        if x in self.node_to_placeholder:
            return self.node_to_placeholder[x]
        elif (
            x.op == "placeholder" or self.module_call_graph.get(self.fqn) is None
            # allow placeholder creation if we are not preserving module call signature
        ):
            self.add_placeholder(x)
            if self.parent_call_module is not None:
                # Important to *prepend* the output to match how we are
                # inserting placeholder nodes.
                with self.parent.graph.inserting_before(self.parent_call_module):
                    self.parent_call_module.insert_arg(0, self.parent.remap_input(x))
            return self.node_to_placeholder[x]
        elif x.op == "call_function" and (
            x.target
            in (
                torch.ops.aten.sym_size.int,
                torch.ops.aten.item.default,
                torch.ops.aten.unbind.int,
                torch.ops.aten.sum.dim_IntList,
                torch.ops.aten.view.default,
                torch.ops.aten.diff.default,
            )
            or (hasattr(x.target, "__module__") and x.target.__module__ == "_operator")
        ):
            # export deduplicates sym_size nodes, and may need to re-copy them
            # if module call signature needs to be preserved
            self.copy_sym_call_function(x)
            return self.node_map[x]
        elif self.module_call_graph.get(self.fqn) is not None:
            # x is reading the intermediate value of a mutation, so record it;
            # later we will find where it was created and perform the update
            return self.ivals.read(self, x)  # type: ignore[operator, union-attr]
        else:
            raise RuntimeError(
                f"Could not run remap_input() on op type: {x.op} for node {x}"
            )

    def finalize_outputs(self):
        """Emit this graph's output node and expose its results to the parent frame.

        With a recorded signature (and a parent), outputs follow
        ``signature.outputs``: they are unflattened with ``out_spec`` inside
        this graph and flattened again in the parent graph. Otherwise every
        copied node that has a user whose name is not in ``seen_nodes`` is
        exposed. In both cases ``parent.node_map`` is updated so the parent
        can refer to the exposed values.

        Raises:
            RuntimeError: for an unsupported output argument type, or when a
                signature output has no counterpart in this graph.
        """
        # This frame is done, so its module is no longer pending creation.
        self.created_modules.pop(self.fqn, None)

        orig_outputs = []

        signature = self.module_call_graph.get(self.child_fqn)
        if signature is not None and self.parent is not None:
            for output in signature.outputs:
                if isinstance(
                    output,
                    (
                        TensorArgument,
                        SymIntArgument,
                        SymBoolArgument,
                        SymFloatArgument,
                        ConstantArgument,
                    ),
                ):
                    if output.name in self.seen_nodes:
                        orig_outputs.append(self.seen_nodes[output.name])
                    else:
                        # Keep the slot so positions still line up with the
                        # signature's outputs.
                        orig_outputs.append(None)
                else:
                    raise RuntimeError(
                        f"Unsupported data type for output node: {output}"
                    )

            # Translate a flat-graph output node into the node of this graph
            # that represents it (None stays None).
            def get_actual_output_node(output):
                if output is None:
                    return None

                seen_node = self.seen_nodes[output.name]
                if seen_node in self.node_map:
                    return self.node_map[seen_node]
                elif seen_node in self.node_to_placeholder:
                    return self.node_to_placeholder[seen_node]
                else:
                    raise RuntimeError(
                        f"Could not find output node {output}. Graph: {self.graph}"
                    )

            tree_out_node = _generate_unflatten(
                self.module,
                tuple(get_actual_output_node(output) for output in orig_outputs),
                signature.out_spec,
            )
            parent_out: Optional[torch.fx.Node] = _generate_flatten_spec(
                self.parent.module, self.parent_call_module, signature.out_spec
            )
            graph_outputs: Union[torch.fx.Node, list[torch.fx.Node]] = tree_out_node
        else:
            graph_outputs = []
            # Iterate through nodes we have copied into self.graph.
            for orig_node in self.node_map.keys():
                for user_node in orig_node.users:
                    if user_node.name not in self.seen_nodes:
                        # external user node, need to expose as an output
                        orig_outputs.append(orig_node)
                        graph_outputs.append(self.node_map[orig_node])
                        break

            parent_out = self.parent_call_module
            # A single output is returned bare rather than as a one-item list.
            if len(graph_outputs) == 1:
                graph_outputs = graph_outputs[0]

        assert isinstance(graph_outputs, (list, torch.fx.Node))

        self.graph.output(graph_outputs)

        # Rewrite outputs in parent module
        if parent_out is None:
            return

        parent_out.meta["val"] = (
            graph_outputs.meta.get("val")
            if isinstance(graph_outputs, torch.fx.Node)
            else [o.meta.get("val") for o in graph_outputs]
        )

        if len(orig_outputs) == 1 and signature is None:
            # The call node itself is the single exposed value.
            self.parent.node_map[orig_outputs[0]] = parent_out
        else:
            for i, orig_output in enumerate(orig_outputs):
                if orig_output is None:
                    continue
                # Use Proxy to record getitem access.
                proxy_out = torch.fx.Proxy(parent_out)[i].node  # type: ignore[index]
                proxy_out.meta["val"] = orig_output.meta.get("val")
                self.parent.node_map[orig_output] = proxy_out

    def copy_node(self, node):
        """Copy flat-graph `node` into this graph and mark it as seen.

        Arguments are remapped through ``remap_input``; the copy is recorded
        in ``node_map`` and the original in ``seen_nodes`` under its name.
        """
        self.print("copying", node.format_node())
        self.node_map[node] = self.graph.node_copy(node, self.remap_input)
        self.seen_nodes[node.name] = node

    def run_outer(self):
        """Drive the whole walk from the outermost frame.

        Copies the leading placeholder nodes, hands the remaining nodes to
        ``run_from``, and finally copies the flat graph's output node(s).
        """
        for i, node in enumerate(self.flat_graph.nodes):
            self.print(i, node.meta.get("nn_module_stack"), node.format_node())

        # Copy all graph inputs
        node_idx: int = 0
        node = self.nodes[node_idx]
        while node.op == "placeholder":
            self.copy_node(node)
            node_idx += 1
            node = self.nodes[node_idx]

        self.run_from(node_idx)

        # Copy graph outputs
        for node in self.flat_graph.nodes:
            if node.op == "output":
                self.copy_node(node)

    def print(self, *args, **kwargs):
        """Forward to the builtin ``print`` only when ``self.verbose`` is set."""
        if self.verbose:
            print(*args, **kwargs)

    def run_from(self, node_idx):
        """Process nodes starting at index `node_idx` until this frame ends.

        Returns the index of the first node this frame did not consume, when
        the loop exits through one of the explicit returns (output node
        reached, or a node whose module stack no longer starts with this
        frame's stack).

        Raises:
            RuntimeError: if a non-output node has no ``nn_module_stack``.
        """
        module_idx = 0
        # Walk through the graph, building up a new graph with the right submodules
        while node_idx < len(self.nodes):
            node = self.nodes[node_idx]
            assert node.op != "placeholder"

            self.print()
            self.print("STEP", node_idx, node.format_node())
            self.print(self.module_stack)
            depth = len(self.module_stack)
            if node.op == "output":
                if depth == 1:
                    # We want the output node of the original graph to be handled
                    # specially by the outermost stack frame (in run_outer). So
                    # skip finalization here.
                    return node_idx

                # We've reached the end of the graph. Wrap up all the existing stack frames.
                self.finalize_outputs()
                return node_idx

            if len(node.meta.get("nn_module_stack", {})) == 0:
                raise RuntimeError(f"Unable to find nn_module_stack for node {node}")

            nn_module_stack = node.meta["nn_module_stack"]
            from torch._export.passes._node_metadata_hook import (
                _EMPTY_NN_MODULE_STACK_KEY,
            )

            if (
                len(nn_module_stack) == 1
                and _EMPTY_NN_MODULE_STACK_KEY in nn_module_stack
            ):
                # Empty case from the node_metadata_hook
                node_module_stack = self.module_stack
            else:
                # Build (path, type, call count) entries; the call count is
                # parsed from an "@<n>" suffix on the stack key, defaulting to 0.
                node_module_stack = [
                    (
                        path,
                        ty if path else None,
                        int(k.split("@")[-1]) if "@" in k else 0,
                    )
                    for k, (path, ty) in node.meta["nn_module_stack"].items()
                ]

            if node_module_stack[:depth] != self.module_stack:
                # This means that the current module is done executing and the
                # current node is the beginning of a new module.
                #
                # In this case, we should finalize this module and return without
                # incrementing the node counter.
                self.finalize_outputs()
                self.print("outlining", self.fqn)
                self.print(self.graph)
                return node_idx

            assert node_module_stack is not None

            if _is_prefix(self.module_stack, node_module_stack):
                # This means that the current node represents the execution of a new
                # module.
                next_module = node_module_stack[depth]
                self.print("Creating new stack frame for", next_module)
                # Run a nested version of module outliner from the current node
                # counter. Once it is complete, continue from that point.
                next_module_key = list(node.meta["nn_module_stack"].keys())[depth]
                node_idx = _ModuleFrame(
                    self.flat_graph,
                    self.nodes,
                    self.seen_nodes,
                    self.seen_modules,
                    self.seen_attrs,
                    self.created_modules,
                    self,
                    self.module_stack + [next_module],
                    next_module_key.split("@")[0],
                    self.module_call_graph,
                ).run_from(node_idx)
                module_idx += 1
                continue

            # The only remaining possibility is that we are in the right stack
            # frame. Copy the node into this frame's graph and increment the node counter.
            assert node_module_stack == self.module_stack

            if node.op == "get_attr":
                # this must be a graph argument for a HOP
                self.seen_attrs[self.child_fqn].add(node.target)

            self.copy_node(node)
            node_idx += 1
@dataclass
class _SubmoduleEntry:
    """Record of one submodule call, appended by ``_ModuleFrame.__init__``."""

    # fqn of the frame that issued the call
    parent_fqn: str
    # module owned by that parent frame
    parent_module: torch.nn.Module
    # call_module node created for this call in the parent's graph
    parent_call_module: torch.fx.Node
    # fqn of the called module, without any call-index suffix
    fqn: str
    # 1-based call index; fed to _call_name to build the call name
    call_idx: int
    # module populated for this particular call
    module: torch.nn.Module
  1280. def _outline_submodules(orig_graph: torch.fx.Graph, root_module: UnflattenedModule):
  1281. seen_nodes: dict[str, torch.fx.Node] = {}
  1282. seen_modules: dict[int, list[_SubmoduleEntry]] = defaultdict(list)
  1283. seen_attrs: dict[str, set[str]] = defaultdict(set)
  1284. created_modules: dict[str, torch.nn.Module] = {}
  1285. _ModuleFrame(
  1286. orig_graph,
  1287. tuple(orig_graph.nodes),
  1288. seen_nodes,
  1289. seen_modules,
  1290. seen_attrs,
  1291. created_modules,
  1292. None,
  1293. [("", None, 0)],
  1294. "",
  1295. {
  1296. entry.fqn: entry.signature
  1297. for entry in root_module.module_call_graph
  1298. if entry.signature
  1299. },
  1300. module=root_module,
  1301. ).run_outer()
  1302. return seen_modules, seen_attrs
  1303. def _reorder_submodules(
  1304. parent: torch.nn.Module, fqn_order: dict[str, int], prefix: str = ""
  1305. ):
  1306. # TODO Can be optimized by adding submodules ahead of time.
  1307. if prefix == "":
  1308. for fqn in list(fqn_order.keys())[1:]:
  1309. if _get_submodule(parent, fqn) is None:
  1310. _add_submodule(parent, fqn, torch.nn.Module())
  1311. children = []
  1312. for name, child in list(parent._modules.items()):
  1313. if child is None:
  1314. continue
  1315. fqn = prefix + name
  1316. _reorder_submodules(child, fqn_order, prefix=fqn.split("@")[0] + ".")
  1317. delattr(parent, name)
  1318. children.append((fqn_order[fqn], name, child))
  1319. children.sort(key=operator.itemgetter(0))
  1320. for _, name, child in children:
  1321. parent.register_module(name, child)
  1322. class _IVals:
  1323. """
  1324. Collect the intermediate values of mutations in a graph.
  1325. Example: in the following graph, suppose that buf_in and buf_out
  1326. are the input and output values of a buffer.
  1327. buf_in = placeholder()
  1328. ...
  1329. ival1 = f0(buf_in, ...) # inside self.n0(...)
  1330. ...
  1331. ival2 = f1(ival1, ...) # inside self.n1(...)
  1332. ...
  1333. buf_out = f2(ival2, ...) # inside self.n2(...)
  1334. return buf_out, ...
  1335. Here ival1 and ival2 are intermediate values created inside
  1336. calls to n0 and n1 respectively, and used inside calls to
  1337. n1 and n2 respectively.
  1338. """
  1339. def __init__(self):
  1340. # for each fqn, set of node names corresponding to intermediate values
  1341. self.node_names_by_fqn = defaultdict(set)
  1342. def _is_mutable(self, target):
  1343. if isinstance(target, torch._ops.OpOverload):
  1344. return target._schema.is_mutable
  1345. return False
  1346. def read(self, mf, node):
  1347. """
  1348. Read state corresponding to a given intermediate value.
  1349. """
  1350. # we can assume that the node must be from a mutation
  1351. assert node.op == "call_function"
  1352. b = self._is_mutable(node.target)
  1353. print("Checking mutability", node.target, b)
  1354. if not b:
  1355. # so the mutation was functionalized;
  1356. # we will apply the original mutation later (see below)
  1357. fqn, _ = next(reversed(node.meta["nn_module_stack"].values()))
  1358. self.node_names_by_fqn[fqn].add(node.name)
  1359. return mf.remap_input(node.args[0])
  1360. def update(self, partitions):
  1361. """
  1362. Update states corresponding to intermediate values that were read.
  1363. """
  1364. for shared_submodules in partitions:
  1365. for entry in shared_submodules:
  1366. graph = entry.module.graph
  1367. node_names = self.node_names_by_fqn[entry.fqn]
  1368. nodes = [n for n in graph.nodes if n.name in node_names]
  1369. for node in nodes:
  1370. # so node must be from a functionalized mutation;
  1371. # we perform the original mutation now
  1372. with graph.inserting_after(node):
  1373. new_node = graph.create_node(
  1374. "call_function",
  1375. torch.ops.aten.copy_.default,
  1376. (node.args[0], node),
  1377. )
  1378. new_node.meta = copy.copy(node.meta)
  1379. def _copy_graph_attrs(
  1380. gm: torch.fx.GraphModule,
  1381. root_module: UnflattenedModule,
  1382. seen_attrs: dict[str, set[str]],
  1383. ):
  1384. for child_fqn, names in seen_attrs.items():
  1385. module = _get_attr(root_module, child_fqn) if child_fqn else root_module
  1386. for name in names:
  1387. val = getattr(gm, name)
  1388. setattr(module, name, val)
  1389. def _deduplicate_modules(partitions):
  1390. redirected_call_indices = {}
  1391. for shared_submodules in partitions:
  1392. for i, entry in enumerate(shared_submodules):
  1393. child_fqn = _call_name(entry.fqn, entry.call_idx)
  1394. target = _compute_accessor(entry.parent_fqn, child_fqn)
  1395. deduplicated = False
  1396. # Iterate over all previously seen modules, and deduplicate if possible
  1397. for seen in shared_submodules[:i]:
  1398. if _check_graph_equivalence(seen.module, entry.module):
  1399. parent = entry.parent_module
  1400. # Since graphs are equivalent, we can deduplicate.
  1401. # There are two cases.
  1402. if seen.fqn == entry.fqn:
  1403. # Case 1: The current module has the same fqn as the seen module.
  1404. # In this case we have generated a call name that can be optimized away.
  1405. # So we remove the current module from the hierarchy and replace
  1406. # the current call name with the seen call name in the parent graph.
  1407. *prefix, name = target.split(".")
  1408. _get_attr_via_attr_list(parent, prefix)._modules.pop(name)
  1409. seen_child_fqn = _call_name(seen.fqn, seen.call_idx)
  1410. seen_target = _compute_accessor(
  1411. entry.parent_fqn, seen_child_fqn
  1412. )
  1413. entry.parent_call_module.target = seen_target
  1414. redirected_call_indices[child_fqn] = seen_child_fqn
  1415. break
  1416. elif not deduplicated:
  1417. # Case 2: The current module has a different fqn than the seen module.
  1418. # In this case we replace the current module with the seen module.
  1419. # There should be nothing pointing to the current module any more,
  1420. # so it can be garbage collected.
  1421. # NOTE: We *do not* replace the current call name with the seen call name
  1422. # in the parent graph, because this will lose information on which fqn
  1423. # was actually called. However, it is possible that the current call name
  1424. # will be optimized away when we find another seen module with the same fqn,
  1425. # so we do not break out of the loop yet.
  1426. parent.set_submodule(target, seen.module)
  1427. deduplicated = True
  1428. return redirected_call_indices
  1429. def _sink_params(
  1430. module: torch.nn.Module,
  1431. inputs_to_state: dict[str, list[str]],
  1432. scope: list[str],
  1433. module_id_to_inputs_removed: Optional[dict[int, set[str]]] = None,
  1434. ):
  1435. """Sink params, buffers, and constants from graph inputs into get_attr nodes.
  1436. Exported modules are purely functional, so they pass their parameters and
  1437. buffers in as inputs to the graph.
  1438. To replicate eager's semantics, we need to get them from the module state
  1439. via get_attr instead.
  1440. module: GraphModule, potentially containing nested submodules.
  1441. inputs_to_state: mapping graph input names to the corresponding key in the state_dict.
  1442. scope: tracks where we are in the module hierarchy, so that we can emit the
  1443. right `getattr(self, "foo.bar")` calls, etc.
  1444. module_id_to_inputs_removed: records inputs removed by child modules, mapping
  1445. the module object id to the list of placeholder node names in the child module
  1446. that were removed.
  1447. """
  1448. if module_id_to_inputs_removed is None:
  1449. module_id_to_inputs_removed = defaultdict(set)
  1450. if id(module) in module_id_to_inputs_removed:
  1451. return {id(module): module_id_to_inputs_removed[id(module)]}
  1452. # We need to use _modules here instead of named_children(), because we
  1453. # explicitly want duplicate modules to show up in the traversal.
  1454. for name, submodule in module._modules.items():
  1455. submod_id_to_inputs_removed = _sink_params(
  1456. cast("torch.nn.Module", submodule),
  1457. inputs_to_state,
  1458. scope + [name],
  1459. module_id_to_inputs_removed,
  1460. )
  1461. for k, v in submod_id_to_inputs_removed.items():
  1462. module_id_to_inputs_removed[k].update(v)
  1463. graph = getattr(module, "graph", None)
  1464. if graph is None or len(graph.nodes) == 0:
  1465. # Not all modules have graphs defined, if they are empty modules with no operations (like ParameterList)
  1466. return module_id_to_inputs_removed
  1467. assert isinstance(graph, torch.fx.Graph)
  1468. inputs = list(filter(lambda n: n.op == "placeholder", graph.nodes))
  1469. the_last_input = None if len(inputs) == 0 else inputs[-1]
  1470. # Also remove from call_module nodes
  1471. call_module_nodes = filter(lambda n: n.op == "call_module", graph.nodes)
  1472. for node in call_module_nodes:
  1473. submodule = _get_attr(module, node.target)
  1474. # remove placeholder from call_module node arguments, only if we've
  1475. # erased the placeholder node in the corresponding _sink_params() call
  1476. if submodule is not None and id(submodule) in module_id_to_inputs_removed:
  1477. node.args = tuple(
  1478. filter(
  1479. lambda n: n.name not in module_id_to_inputs_removed[id(submodule)],
  1480. node.args,
  1481. )
  1482. )
  1483. # Filter out inputs_to_state corresponding to current scope.
  1484. inputs_to_state_of_scope: dict[torch.fx.Node, list[str]] = {}
  1485. for node in inputs:
  1486. if node.name not in inputs_to_state:
  1487. continue
  1488. state_name = None
  1489. for sn in inputs_to_state[node.name]:
  1490. sn_split = sn.split(".")
  1491. if sn_split[: len(scope)] == [x.split("@")[0] for x in scope]:
  1492. state_name = sn_split
  1493. break
  1494. # If there's a mismatch between scope name and state name, then
  1495. # there must be multiple scopes pointing to the same state name,
  1496. # meaning some modules are shared. In such case, we can simply skip
  1497. # updating the current node because another later iteration will
  1498. # take care of this input node when the unique match between scope
  1499. # and state name occurs. To make sure this always happen, we should
  1500. # enforce the invariant that no placeholder node in the unflattened
  1501. # graph appears in inputs_to_state dict, which means all the extra
  1502. # input nodes have been handled.
  1503. if state_name is None:
  1504. continue
  1505. inputs_to_state_of_scope[node] = state_name
  1506. # Record name of remove inputs for return purpose.
  1507. inputs_removed: set[str] = set()
  1508. for node, state_name in inputs_to_state_of_scope.items():
  1509. if len(node.users) > 0:
  1510. attr_path = state_name[len(scope) :]
  1511. state_attr = _get_attr_via_attr_list(module, attr_path)
  1512. assert isinstance(state_attr, (torch.Tensor, torch.ScriptObject))
  1513. # Make sure the newly created get_attr node is placed after the last placeholder node
  1514. with graph.inserting_after(the_last_input):
  1515. new_node = graph.create_node("get_attr", ".".join(attr_path))
  1516. node.replace_all_uses_with(new_node, propagate_meta=True)
  1517. graph.erase_node(node)
  1518. inputs_removed.add(node.name)
  1519. if isinstance(module, InterpreterModule):
  1520. module.finalize()
  1521. return {id(module): inputs_removed}