unflatten.py 69 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693694695696697698699700701702703704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862863864865866867868869870871872873874875876877878879880881882883884885886887888889890891892893894895896897898899900901902903904905906907908909910911912913914915916917918919920921922923924925926927928929930931932933934935936937938939940941942943944945946947948949950951952953954955956957958959960961962963964965966967968969970971972973974975976977978979980981982983984985986987988989990991992993994995996997998999100010011002100310041005100610071008100910101011101210131014101510161017101810191020102110221023102410251026102710281029103010311032103310341035103610371038103910401041104210431044104510461047104810491050105110521053105410551056105710581059106010611062106310641065106610671068106910701071107210731074107510761077107810791080108110821083108410851086108710881089109010911092109310941095109610971098109911001101110211031104110511061107110811091110111111121113111411151116111711181119112011211122112311241125112611271128112911301131113211331134113511361137113811391140114111421143114411451146114711481149115011511152115311541155115611571158115911601161116211631164116511661167116811691170117111721173117411751176117711781179118011811182118311841185118611871188118911901191119211931194119511961197119811991200120112021203120412051206120712081209121012111212121312141215121612171218121912201221122212231224122512261227122812291230123112321233123412351236123712381239124012411242124312441245124612471248124912501251125212531254125512561257125812591260126112621263126412651266126712681269127012711272127312741275127612771278127912801281128212831284128512861287128812891290129112921293129412951296129712981299130013011302130313041305130613071308130913101311131213131314131513161317131813191320132113221323132413251326132713281329133013311332133313341335133613371338133913401341134213431344134513461347134813491350135113521353135413551356135713581359136013611362136313641365136613671368136913701371137213731374137513761377137813791380138113821383138413851386138713881389139013911392139313941395139613971398139914001401140214031404140514061407140814091410141114121413141414151416141714181419142014211422142314241425142614271428142914301431143214331434143514361437143814391440144114421443144414451446144714481449145014511452145314541455145614571458145914601461146214631464146514661467146814691470147114721473147414751476147714781479148014811482148314841485148614871488148914901491149214931494149514961497149814991500150115021503150415051506150715081509151015111512151315141515151615171518151915201521152215231524152515261527152815291530153115321533153415351536153715381539154015411542154315441545154615471548154915501551155215531554155515561557155815591560156115621563156415651566156715681569157015711572157315741575157615771578157915801581158215831584158515861587158815891590159115921593159415951596159715981599160016011602160316041605160616071608160916101611161216131614161516161617161816191620162116221623162416251626162716281629163016311632163316341635163616371638163916401641164216431644164516461647164816491650165116521653165416551656165716581659166016611662166316641665166616671668166916701671167216731674167516761677167816791680168116821683168416851686168716881689169016911692169316941695169616971698169917001701170217031704170517061707170817091710171117121713171417151716171717181719172017211722172317241725172617271728172917301731173217331734173517361737173817391740174117421743174417451746174717481749175017511752175317541755175617571758175917601761176217631764176517661767176817691770177117721773177417751776177717781779178017811782178317841785178617871788178917901791179217931794179517961797179817991800180118021803
  1. # mypy: allow-untyped-defs
  2. import abc
  3. import copy
  4. import logging
  5. import operator
  6. import re
  7. from collections import defaultdict
  8. from collections.abc import Callable
  9. from contextlib import contextmanager
  10. from copy import deepcopy
  11. from dataclasses import dataclass
  12. from enum import Enum
  13. from typing import Any, cast
  14. import torch
  15. import torch.fx._pytree as fx_pytree
  16. import torch.utils._pytree as pytree
  17. from torch._library.fake_class_registry import FakeScriptObject
  18. from torch.export import ExportedProgram
  19. from torch.export._tree_utils import reorder_kwargs
  20. from torch.export.exported_program import (
  21. ConstantArgument,
  22. ExportGraphSignature,
  23. InputKind,
  24. ModuleCallSignature,
  25. SymBoolArgument,
  26. SymFloatArgument,
  27. SymIntArgument,
  28. TensorArgument,
  29. )
  30. from torch.fx._symbolic_trace import is_fx_symbolic_tracing
  31. from torch.fx.graph_module import _get_attr, _get_attr_via_attr_list, _print_readable
  32. from torch.utils._pytree import GetAttrKey, SequenceKey
  33. from ._remove_effect_tokens_pass import _remove_effect_tokens
  34. log = logging.getLogger(__name__)
  35. __all__ = [
  36. "FlatArgsAdapter",
  37. "InterpreterModule",
  38. "InterpreterModuleDispatcher",
  39. "UnflattenedModule",
  40. "unflatten",
  41. ]
  42. class _AttrKind(Enum):
  43. PARAMETER = "parameter"
  44. BUFFER = "buffer"
  45. CONSTANT = "constant"
  46. MODULE = "module"
  47. @dataclass(frozen=True)
  48. class _TensorID:
  49. """Custom tensor identifier containing storage, stride, and size information."""
  50. untyped_storage: torch.UntypedStorage
  51. stride: tuple
  52. size: tuple
  53. storage_offset: int
  54. RUN_WITH_INTERPRETER = True
  55. @contextmanager
  56. def _disable_interpreter():
  57. global RUN_WITH_INTERPRETER
  58. old_flag = RUN_WITH_INTERPRETER
  59. RUN_WITH_INTERPRETER = False
  60. try:
  61. yield
  62. finally:
  63. RUN_WITH_INTERPRETER = old_flag
  64. # Assign attribute 'from_obj' to the qualified name 'target' on 'to_module
  65. # This installs empty Modules where none exist yet if they are subpaths of target
  66. def _assign_attr(
  67. from_obj: torch.Tensor | torch.ScriptObject | torch.nn.Module,
  68. to_module: torch.nn.Module,
  69. target: str,
  70. attr_kind: _AttrKind,
  71. persistent: bool = True,
  72. ):
  73. *prefix, field = target.split(".")
  74. # We need to generate all submodules of `to_module` that are at `prefix` and
  75. # variants of `prefix` that differ only by call name. All of these submodules
  76. # will then be assigned `from_obj` at `field` so that they can share this attribute.
  77. # For example, if target is foo.bar.f, foo has another call name foo@1,
  78. # and bar has other call names bar@1, bar@2, then we will assign f to
  79. # foo.bar, foo.bar@1, foo.bar@2, foo@1.bar, foo@1.bar@1, foo@1.bar@2.
  80. to_modules = {to_module}
  81. for item in prefix:
  82. ts: set[torch.nn.Module] = set()
  83. for to_module in to_modules:
  84. if not hasattr(to_module, item):
  85. setattr(to_module, item, torch.nn.Module())
  86. ts.update(
  87. t_call # type: ignore[misc]
  88. for k, t_call in to_module._modules.items()
  89. if _is_call_name(k, item)
  90. )
  91. to_modules = ts
  92. for to_module in to_modules:
  93. if attr_kind == _AttrKind.PARAMETER:
  94. assert isinstance(from_obj, torch.nn.Parameter)
  95. to_module.register_parameter(field, from_obj)
  96. elif attr_kind == _AttrKind.BUFFER:
  97. assert isinstance(from_obj, torch.Tensor)
  98. to_module.register_buffer(field, from_obj, persistent=persistent)
  99. elif attr_kind == _AttrKind.CONSTANT:
  100. assert not isinstance(from_obj, FakeScriptObject), (
  101. "FakeScriptObject should only exist during tracing."
  102. )
  103. assert isinstance(
  104. from_obj,
  105. (
  106. torch.Tensor,
  107. torch.ScriptObject,
  108. ),
  109. )
  110. setattr(to_module, field, from_obj)
  111. elif attr_kind == _AttrKind.MODULE:
  112. assert isinstance(from_obj, torch.nn.Module)
  113. setattr(to_module, field, from_obj)
  114. class _SubmoduleBase:
  115. _ty: str | None
  116. def type_name(self) -> str | None:
  117. """
  118. Subclass of this class - InterpreterModule, InterpreterModuleDispatcher, represents
  119. corresponding model in eager model. To get this type information for those modules
  120. in eager model we need to use this method.
  121. """
  122. return self._ty
  123. class InterpreterModule(_SubmoduleBase, torch.nn.Module):
  124. """A module that uses torch.fx.Interpreter to execute instead of the usual
  125. codegen that GraphModule uses. This provides better stack trace information
  126. and makes it easier to debug execution.
  127. """
  128. graph_module: torch.fx.GraphModule | None
  129. def __init__(
  130. self,
  131. graph: torch.fx.Graph,
  132. ty: str | None = None,
  133. ):
  134. super().__init__()
  135. self.graph = graph
  136. self._ty = ty
  137. self.graph.owning_module = self # type: ignore[assignment]
  138. self._run_with_interpreter = RUN_WITH_INTERPRETER
  139. def forward(self, *args, **kwargs):
  140. assert self.graph_module is not None, "Didn't finalize this InterpreterModule"
  141. if not is_fx_symbolic_tracing() and (
  142. torch.compiler.is_dynamo_compiling() or not self._run_with_interpreter
  143. ):
  144. # Dynamo cannot trace through torch.fx.Interpreter, so fall back to
  145. # GraphModule codegen in this instance.
  146. # Patch the codegened forward to run with this InterpreterModule,
  147. # so attribute accesses, etc. are on this module instead.
  148. return type(self.graph_module).forward(self, *args, **kwargs)
  149. else:
  150. if kwargs:
  151. # Handle **kwargs. FX only natively supports positional
  152. # arguments (through placeholders). So in order to pass in
  153. # kwargs, we must correspond the names of the placeholders with
  154. # the keys in the kwarg dict.
  155. arg_list = list(args)
  156. kwarg_names = self.arg_names[len(arg_list) :]
  157. arg_list.extend(
  158. kwargs[kwarg_name]
  159. for kwarg_name in kwarg_names
  160. if kwarg_name in kwargs
  161. )
  162. # Assert that the kwargs passed in exactly match the positional
  163. # arguments specified by the GraphModule. This should be
  164. # guaranteed by the unflattening process.
  165. assert len(kwarg_names) == len(kwargs)
  166. assert len(arg_list) == len(self.arg_names)
  167. args = tuple(arg_list)
  168. return torch.fx.Interpreter(self, graph=self.graph).run(
  169. *args, enable_io_processing=False
  170. )
  171. def finalize(self):
  172. # We need to "finalize" because GraphModule populates its own state_dict
  173. # based on the get_attrs observed in the graph. So we need to fully
  174. # construct the graph and call _sink_params before generating this
  175. # GraphModule.
  176. # need to set `graph_module` directly on the dict to avoid it getting
  177. # registered as a submodule.
  178. self.__dict__["graph_module"] = torch.fx.GraphModule(self, self.graph)
  179. self.graph.lint()
  180. # Cache arg names for kwarg handling (see forward())
  181. self.arg_names = []
  182. for node in self.graph.nodes:
  183. if node.op == "placeholder":
  184. self.arg_names.append(node.target)
  185. def print_readable(
  186. self,
  187. print_output=True,
  188. include_stride=False,
  189. include_device=False,
  190. colored=False,
  191. ):
  192. return _print_readable(
  193. self,
  194. "InterpreterModule",
  195. print_output,
  196. include_stride,
  197. include_device,
  198. colored,
  199. )
  200. class InterpreterModuleDispatcher(_SubmoduleBase, torch.nn.Module):
  201. """
  202. A module that carries a sequence of InterpreterModules corresponding to
  203. a sequence of calls of that module. Each call to the module dispatches
  204. to the next InterpreterModule, and wraps back around after the last.
  205. """
  206. def __init__(self, attrs: set[str], call_modules: list[InterpreterModule]):
  207. super().__init__()
  208. assert call_modules
  209. self._modules = call_modules[0]._modules
  210. for accessor in attrs:
  211. setattr(self, accessor, getattr(call_modules[0], accessor))
  212. self._ty = call_modules[0]._ty
  213. self._call_modules = call_modules
  214. self._num_calls = 0
  215. def forward(self, *args, **kwargs):
  216. call_module = self._call_modules[self._num_calls]
  217. self._num_calls = (self._num_calls + 1) % len(self._call_modules)
  218. try:
  219. return call_module(*args, **kwargs)
  220. except Exception:
  221. self._num_calls = 0
  222. raise
  223. def call_modules(self):
  224. return self._call_modules
  225. def print_readable(
  226. self,
  227. print_output=True,
  228. include_stride=False,
  229. include_device=False,
  230. colored=False,
  231. ):
  232. outputs = [
  233. mod.print_readable(
  234. print_output,
  235. include_stride,
  236. include_device,
  237. colored,
  238. )
  239. for mod in self._call_modules
  240. ]
  241. return "\n".join(outputs)
  242. class FlatArgsAdapter(abc.ABC):
  243. """
  244. Adapts input arguments with ``input_spec`` to align ``target_spec``.
  245. """
  246. @abc.abstractmethod
  247. def adapt(
  248. self,
  249. target_spec: pytree.TreeSpec,
  250. input_spec: pytree.TreeSpec,
  251. input_args: list[Any],
  252. metadata: dict[str, Any] | None = None,
  253. obj: Any | None = None,
  254. ) -> list[Any]:
  255. """NOTE: This adapter may mutate given ``input_args_with_path``."""
  256. ...
  257. def get_flat_arg_paths(self) -> list[str]:
  258. """Returns a list of paths that are used to access the flat args."""
  259. return []
  260. class UnflattenedModule(_SubmoduleBase, torch.nn.Module):
  261. def __init__(
  262. self,
  263. export_module: ExportedProgram,
  264. flat_args_adapter: FlatArgsAdapter | None = None,
  265. ):
  266. super().__init__()
  267. if export_module.graph_signature.backward_signature is not None:
  268. raise ValueError("Unflattening on JointExportModule NYI")
  269. def _id(obj):
  270. """Returns _TensorID dataclass for tensors, otherwise id()."""
  271. if isinstance(obj, torch.Tensor):
  272. return _TensorID(
  273. untyped_storage=obj.untyped_storage(),
  274. stride=obj.stride(),
  275. size=obj.size(),
  276. storage_offset=obj.storage_offset(), # type: ignore[arg-type]
  277. )
  278. return id(obj)
  279. fqn_list = [entry.fqn for entry in export_module.module_call_graph]
  280. assert fqn_list[0] == ""
  281. export_graph = deepcopy(export_module.graph)
  282. self.graph_signature = deepcopy(export_module.graph_signature)
  283. self.graph = torch.fx.Graph()
  284. self.graph.owning_module = self # type: ignore[assignment]
  285. self.module_call_graph = deepcopy(export_module.module_call_graph)
  286. self.flat_args_adapter = flat_args_adapter
  287. self.meta = export_module.graph_module.meta
  288. self.meta["unflattened_module"] = self
  289. # Flag to indicate whether args have been adapted.
  290. self.adapted = False
  291. self._run_with_interpreter = RUN_WITH_INTERPRETER
  292. _inplace_buffer_and_input_mutations(export_graph, self.graph_signature)
  293. _fix_nn_module_stacks(export_graph)
  294. self._ty = _root_module_type(export_graph)
  295. self.ivals = _IVals()
  296. # for any intermediate value of a mutation that is read, track the mutation
  297. seen_modules, seen_attrs = _outline_submodules(export_graph, self)
  298. # for each read intermediate value of a mutation, find where it was created,
  299. # and perform the mutation
  300. self.ivals.update(seen_modules.values())
  301. # move attributes that correspond to graph arguments for HOPs
  302. # from exported program to unflattened submodules
  303. _copy_graph_attrs(export_module._graph_module, self, seen_attrs)
  304. self.range_constraints = export_module.range_constraints
  305. self.equality_constraints: list = []
  306. # aliasing/unused param or buffer issues:
  307. # in strict-mode export, dynamo export will deduplicate aliased tensors,
  308. # and ignore unused tensors. For aliasing, this causes issues when some aliases
  309. # are unused, and we're unable to match the placeholder node to the correct FQN.
  310. # This leads to the graph signature potentially having the wrong target FQN,
  311. # and downstream issues where parameters are assigned to the wrong target attribute,
  312. # mismatching the relevant placeholder node in the unflattened module.
  313. # To resolve this we restore (_assign_attr) all aliased/unused tensors in
  314. # the state_dict as module attributes, but only keep the used tensors in the
  315. # graph's forward pass (_sink_params).
  316. state_dict = export_module.state_dict
  317. assigned_params: set[str] = set() # tracking unused params
  318. id_to_param: dict[
  319. int | _TensorID, torch.nn.Parameter
  320. ] = {} # handling weight-sharing
  321. for name in self.graph_signature.parameters: # this loop adds used params
  322. param = state_dict[name]
  323. if _id(param) not in id_to_param:
  324. id_to_param[_id(param)] = torch.nn.Parameter(
  325. param.clone(), requires_grad=param.requires_grad
  326. )
  327. _assign_attr(
  328. id_to_param[_id(param)],
  329. self,
  330. name,
  331. attr_kind=_AttrKind.PARAMETER,
  332. )
  333. assigned_params.add(name)
  334. non_persistent_buffers = set(self.graph_signature.non_persistent_buffers)
  335. assigned_buffers: set[str] = set() # tracking unused buffers
  336. id_to_buffer: dict[int | _TensorID, tuple[torch.nn.Parameter, bool]] = {}
  337. for name in self.graph_signature.buffers: # this loop adds used buffers
  338. if name in non_persistent_buffers:
  339. persistent = False
  340. buffer = export_module.constants[name]
  341. else:
  342. persistent = True
  343. buffer = state_dict[name]
  344. if _id(buffer) not in id_to_buffer:
  345. id_to_buffer[_id(buffer)] = (buffer.clone(), persistent)
  346. _assign_attr(
  347. id_to_buffer[_id(buffer)][0],
  348. self,
  349. name,
  350. attr_kind=_AttrKind.BUFFER,
  351. persistent=persistent,
  352. )
  353. assigned_buffers.add(name)
  354. # restore aliased/unused params and buffers
  355. # these appear in state dict but not graph signature
  356. for name, tensor in state_dict.items():
  357. if name in assigned_params or name in assigned_buffers: # already assigned
  358. continue
  359. is_buffer = False
  360. if _id(tensor) in id_to_buffer or not isinstance(
  361. tensor, torch.nn.Parameter
  362. ): # aliased buffer
  363. is_buffer = True
  364. if is_buffer:
  365. if (
  366. _id(tensor) not in id_to_buffer
  367. ): # this is completely unused (not weight-sharing)
  368. id_to_buffer[_id(tensor)] = (
  369. tensor,
  370. True,
  371. ) # assign to respect original model
  372. _assign_attr(
  373. id_to_buffer[_id(tensor)][0],
  374. self,
  375. name,
  376. attr_kind=_AttrKind.BUFFER,
  377. persistent=True,
  378. )
  379. else:
  380. if _id(tensor) not in id_to_param: # this is unused
  381. id_to_param[_id(tensor)] = tensor
  382. _assign_attr(
  383. id_to_param[_id(tensor)],
  384. self,
  385. name,
  386. attr_kind=_AttrKind.PARAMETER,
  387. )
  388. # use id map so we don't double-clone aliased constants
  389. id_to_const: dict[int | _TensorID, torch.Tensor | torch._C.ScriptObject] = {}
  390. for fqn, constant in export_module.constants.items():
  391. if _id(constant) not in id_to_const:
  392. if isinstance(constant, torch.Tensor):
  393. constant = constant.clone()
  394. id_to_const[_id(constant)] = constant
  395. _constant = id_to_const[_id(constant)]
  396. _assign_attr(
  397. _constant,
  398. self,
  399. fqn,
  400. attr_kind=_AttrKind.CONSTANT,
  401. )
  402. # This is to handle parameters/buffers that point to the same tensor
  403. # object id -> list of (node_name, target_name)
  404. consts_map: dict[int | _TensorID, list[tuple[str, str]]] = defaultdict(list)
  405. consts_targets: set[str] = set()
  406. def add_to_consts_map(obj_id, node_name, target_name):
  407. name_list = consts_map[obj_id]
  408. name_list.append((node_name, target_name))
  409. # track aliased/unused params, buffers
  410. # prefer using untyped_storage() over id() when it's available
  411. added_params_buffers: set[str] = set()
  412. for s in self.graph_signature.input_specs:
  413. if s.kind == InputKind.PARAMETER or (
  414. s.kind == InputKind.BUFFER and s.persistent
  415. ):
  416. assert hasattr(s.arg, "name")
  417. assert isinstance(s.target, str)
  418. add_to_consts_map(
  419. _id(export_module.state_dict[s.target]),
  420. s.arg.name,
  421. s.target,
  422. )
  423. consts_targets.add(s.target)
  424. added_params_buffers.add(s.target)
  425. elif (
  426. s.kind == InputKind.BUFFER
  427. and not s.persistent
  428. or s.kind == InputKind.CONSTANT_TENSOR
  429. or s.kind == InputKind.CUSTOM_OBJ
  430. ):
  431. assert hasattr(s.arg, "name")
  432. assert isinstance(s.target, str)
  433. add_to_consts_map(
  434. _id(export_module.constants[s.target]),
  435. s.arg.name,
  436. s.target,
  437. )
  438. consts_targets.add(s.target)
  439. # add constants that are aliased and don't appear in graph signature
  440. for const_name, const in export_module.constants.items():
  441. if const_name not in consts_targets:
  442. const_id = _id(const)
  443. assert const_id in consts_map
  444. ph_name, _ = consts_map[const_id][0]
  445. add_to_consts_map(const_id, ph_name, const_name)
  446. added_params_buffers.add(s.target)
  447. # add aliased/unused params and buffers that don't appear in graph signature
  448. for fqn, tensor in export_module.state_dict.items():
  449. if fqn not in added_params_buffers:
  450. tensor_id = _id(tensor)
  451. if tensor_id not in consts_map:
  452. # completely unused (no weight-sharing), ignore.
  453. # this weight doesn't appear in graph module,
  454. # so won't cause FQN assignment issues
  455. continue
  456. ph_name, _ = consts_map[tensor_id][0]
  457. add_to_consts_map(tensor_id, ph_name, fqn)
  458. # node name -> list of possible targets
  459. inputs_to_state: dict[str, list[str]] = {}
  460. for node_target in consts_map.values():
  461. targets = [t[1] for t in node_target]
  462. for n, _ in node_target:
  463. inputs_to_state[n] = targets
  464. _sink_params(self, inputs_to_state, [])
  465. redirected_call_indices = _deduplicate_modules(seen_modules.values())
  466. fqn_list = [fqn for fqn in fqn_list if fqn not in redirected_call_indices]
  467. self._dispatch_modules(redirected_call_indices, consts_targets)
  468. fqn_list = [fqn for fqn in fqn_list if "@" not in fqn]
  469. # Cache so we don't have to compute this every time.
  470. # NOTE: this needs to be kept in sync with the placeholders in
  471. # self.graph, but currently we have no way to guarantee that.
  472. self.input_placeholders = [
  473. node for node in self.graph.nodes if node.op == "placeholder"
  474. ]
  475. self.check_input_constraints = True
  476. # TODO(zhxchen17) We can register modules ahead of time instead of reorder later.
  477. fqn_order = {fqn: i for i, fqn in enumerate(fqn_list)}
  478. # In the case of legacy IR, we might be missing some modules from metadata.
  479. for name, _ in self.named_modules(remove_duplicate=False):
  480. if name not in fqn_order:
  481. fqn_order[name] = len(fqn_order)
  482. _reorder_submodules(self, fqn_order)
  483. self.graph.lint()
  484. self.finalize()
  485. def _print_graph(self):
  486. for fqn, mod in self.named_modules():
  487. print(fqn + ":")
  488. if hasattr(mod, "graph") and isinstance(mod.graph, torch.fx.Graph):
  489. print(mod.graph)
  490. def _adapt_flat_args(self, flat_args, in_spec, input):
  491. signature = self.module_call_graph[0].signature
  492. if in_spec == signature.in_spec:
  493. return flat_args
  494. if self.flat_args_adapter is None:
  495. raise TypeError(
  496. "There is no flat args adapter specified. "
  497. "Are you sure you are calling this with the right arguments? "
  498. )
  499. else:
  500. flat_args = self.flat_args_adapter.adapt(
  501. target_spec=signature.in_spec,
  502. input_spec=in_spec,
  503. input_args=flat_args,
  504. metadata=self.meta,
  505. obj=input,
  506. )
  507. if len(flat_args) != signature.in_spec.num_leaves:
  508. raise TypeError(
  509. f"Flat args adaption failed, number of args mismatch "
  510. f"Adatped: {len(flat_args)} \n"
  511. f"Exported module: {signature.in_spec.num_leaves}"
  512. )
  513. return flat_args
  514. def process_forward_inputs(self, *args, **kwargs):
  515. signature = self.module_call_graph[0].signature
  516. reordered_kwargs = kwargs
  517. if kwargs:
  518. reordered_kwargs = reorder_kwargs(kwargs, signature.in_spec)
  519. flat_args_with_path, in_spec = pytree.tree_flatten_with_path(
  520. (args, reordered_kwargs)
  521. )
  522. flat_args = [x[1] for x in flat_args_with_path]
  523. if is_fx_symbolic_tracing():
  524. return flat_args
  525. if in_spec != signature.in_spec:
  526. if not self.adapted:
  527. print(
  528. "Input treespec does not match with exported module's: \n"
  529. f"Input treespec: {in_spec}. ",
  530. f"Exported module treespec: {signature.in_spec}",
  531. )
  532. print("Adapting flat arg to match exported module's treespec")
  533. flat_args = self._adapt_flat_args(flat_args, in_spec, args)
  534. self.adapted = True
  535. if self.check_input_constraints:
  536. # Import here to avoid an unfortunate circular dependency.
  537. # TODO(suo): untangle this.
  538. from torch._export.utils import _check_input_constraints_for_graph
  539. if self.adapted is True:
  540. flat_arg_paths = (
  541. self.flat_args_adapter.get_flat_arg_paths()
  542. if self.flat_args_adapter
  543. else []
  544. )
  545. assert not flat_arg_paths or len(flat_arg_paths) == len(flat_args)
  546. new_flat_args_with_path = [ # type: ignore[var-annotated]
  547. (
  548. (
  549. SequenceKey(idx=idx),
  550. GetAttrKey(
  551. name=flat_arg_paths[idx]
  552. if flat_arg_paths
  553. else "<unknown location>"
  554. ),
  555. ),
  556. arg,
  557. )
  558. for idx, arg in enumerate(flat_args)
  559. ]
  560. else:
  561. new_flat_args_with_path = flat_args_with_path # type: ignore[assignment]
  562. _check_input_constraints_for_graph(
  563. self.input_placeholders, new_flat_args_with_path, self.range_constraints
  564. )
  565. return flat_args
  566. def forward(self, *args, **kwargs):
  567. flat_args = self.process_forward_inputs(*args, **kwargs)
  568. signature = self.module_call_graph[0].signature
  569. if is_fx_symbolic_tracing():
  570. return_val = torch.fx.Interpreter(self, graph=self.graph).run(
  571. *flat_args, enable_io_processing=False
  572. )
  573. # For scalar return value, fx.Graph wraps in a tuple
  574. if isinstance(return_val, tuple) and len(return_val) == 1:
  575. return return_val[0]
  576. return return_val
  577. if torch.compiler.is_dynamo_compiling() or not self._run_with_interpreter:
  578. tree_out = type(self.graph_module).forward(self, *flat_args) # type: ignore[union-attr]
  579. else:
  580. tree_out = torch.fx.Interpreter(self, graph=self.graph).run(
  581. *flat_args, enable_io_processing=False
  582. )
  583. return pytree.tree_unflatten(tree_out, signature.out_spec)
  584. def finalize(self):
  585. self.__dict__["graph_module"] = torch.fx.GraphModule(self, self.graph)
  586. self.graph.lint()
  587. def _dispatch_modules(self, redirected_call_indices, consts_targets):
  588. """For a module whose call signatures are preserved, replace
  589. multiple modules corresponding to multiple calls to that module
  590. with a single dispatcher module that tracks which module to call.
  591. """
  592. # for each fqn whose module call signature is preserved,
  593. # map that fqn to a list of called modules
  594. called_modules = defaultdict(list)
  595. for entry in self.module_call_graph:
  596. if entry.fqn and entry.signature:
  597. # some modules were removed and their fqns redirected to other
  598. # fqns during deduplication
  599. fqn = entry.fqn
  600. mod = _get_attr(self, redirected_call_indices.get(fqn, fqn))
  601. base, idx = fqn.split("@") if "@" in fqn else [fqn, "0"]
  602. called_modules[base].append((int(idx), mod))
  603. attrs_map = defaultdict(set)
  604. for target in consts_targets:
  605. if "." in target:
  606. orig_fqn, name = target.rsplit(".", 1)
  607. attrs_map[orig_fqn].add(name)
  608. else:
  609. attrs_map[""].add(target)
  610. # replace multiple call modules with a single dispatcher module
  611. for orig_fqn, indexed_call_modules in called_modules.items():
  612. call_modules = [mod for _, mod in sorted(indexed_call_modules)]
  613. if len(call_modules) > 1:
  614. for i in range(len(call_modules)):
  615. fqn = _call_name(orig_fqn, i + 1)
  616. if fqn not in redirected_call_indices:
  617. *prefix, name = fqn.split(".")
  618. _get_attr_via_attr_list(self, prefix)._modules.pop(name)
  619. self.set_submodule(
  620. orig_fqn,
  621. InterpreterModuleDispatcher(attrs_map[orig_fqn], call_modules),
  622. )
  623. # elide call indices in call modules because they are
  624. # tracked automatically inside the dispatcher module
  625. def elide_call_indices(prefix, graph):
  626. for node in graph.nodes:
  627. if node.op == "call_module":
  628. fqn = node.target.split("@")[0]
  629. path = f"{prefix}.{fqn}" if prefix else fqn
  630. if path in called_modules:
  631. node.target = fqn
  632. for fqn, mod in self.named_modules(remove_duplicate=False):
  633. if hasattr(mod, "graph"):
  634. elide_call_indices(fqn, mod.graph)
  635. elif hasattr(mod, "_call_modules"):
  636. for mod_ in mod._call_modules:
  637. assert hasattr(mod_, "graph")
  638. elide_call_indices(fqn, mod_.graph)
  639. def print_readable(
  640. self,
  641. print_output=True,
  642. include_stride=False,
  643. include_device=False,
  644. colored=False,
  645. ):
  646. return _print_readable(
  647. self,
  648. "UnflattenedModule",
  649. print_output,
  650. include_stride,
  651. include_device,
  652. colored,
  653. )
  654. def unflatten(
  655. module: ExportedProgram, flat_args_adapter: FlatArgsAdapter | None = None
  656. ) -> UnflattenedModule:
  657. """Unflatten an ExportedProgram, producing a module with the same module
  658. hierarchy as the original eager module. This can be useful if you are trying
  659. to use :mod:`torch.export` with another system that expects a module
  660. hierarchy instead of the flat graph that :mod:`torch.export` usually produces.
  661. .. note:: The args/kwargs of unflattened modules will not necessarily match
  662. the eager module, so doing a module swap (e.g. :code:`self.submod =
  663. new_mod`) will not necessarily work. If you need to swap a module out, you
  664. need to set the :code:`preserve_module_call_signature` parameter of
  665. :func:`torch.export.export`.
  666. Args:
  667. module (ExportedProgram): The ExportedProgram to unflatten.
  668. flat_args_adapter (Optional[FlatArgsAdapter]): Adapt flat args if input TreeSpec does not match with exported module's.
  669. Returns:
  670. An instance of :class:`UnflattenedModule`, which has the same module
  671. hierarchy as the original eager module pre-export.
  672. """
  673. module = _remove_effect_tokens(module)
  674. m = UnflattenedModule(module, flat_args_adapter)
  675. # Disable process_forward_inputs as the adapter has many
  676. # non-dynamo-traceable behavior.
  677. m.process_forward_inputs = torch._dynamo.disable( # type: ignore[method-assign]
  678. m.process_forward_inputs,
  679. reason="do not trace into preprocessing the inputs",
  680. recursive=True,
  681. )
  682. return m
  683. def _inplace_buffer_and_input_mutations(
  684. graph: torch.fx.Graph,
  685. graph_signature: ExportGraphSignature,
  686. ) -> None:
  687. """Transform buffer and input mutations from their functionalized form
  688. into copy_ nodes in the graph.
  689. Functionalization represents a buffer mutation by passing the buffer as
  690. an input and output. For example, consider the eager code:
  691. def forward(self, x):
  692. self.buffer += x
  693. return x * x
  694. This corresponds to a graph that looks like:
  695. def forward(self, buffer, x):
  696. mutated_buffer = aten.add(buffer, x)
  697. mul = aten.mul(x, x)
  698. return (mutated_buffer, mul)
  699. We want to inplace this into something that looks like the original
  700. eager code:
  701. def forward(self, buffer, x):
  702. mutated_buffer = aten.add(buffer, x)
  703. buffer.copy_(mutated_buffer)
  704. mul = aten.mul(x, x)
  705. return (mul,)
  706. Input mutations are handled similarly.
  707. """
  708. output_node = next(iter(reversed(graph.nodes)))
  709. assert output_node.op == "output" and len(output_node.args) == 1
  710. return_args = output_node.args[0]
  711. input_name_to_node = {
  712. node.name: node for node in graph.nodes if node.op == "placeholder"
  713. }
  714. mutation_name_to_input_name = {}
  715. # Collect mutated buffers.
  716. buffer_fqn_to_input_name = {
  717. buffer_fqn: k for k, buffer_fqn in graph_signature.inputs_to_buffers.items()
  718. }
  719. mutation_name_to_input_name = {
  720. k: buffer_fqn_to_input_name[buffer_fqn]
  721. for k, buffer_fqn in graph_signature.buffers_to_mutate.items()
  722. }
  723. # Collect mutated user inputs.
  724. mutation_name_to_input_name.update(graph_signature.user_inputs_to_mutate)
  725. num_mutations = len(mutation_name_to_input_name)
  726. for mutation in return_args[:num_mutations]:
  727. input_name = mutation_name_to_input_name[mutation.name]
  728. input_node = input_name_to_node[input_name]
  729. with graph.inserting_after(mutation):
  730. # Create a copy_ node that inplaces the mutation.
  731. new_node = graph.create_node(
  732. "call_function", torch.ops.aten.copy_.default, (input_node, mutation)
  733. )
  734. for k, v in mutation.meta.items():
  735. new_node.meta[k] = v
  736. # Replace all uses of the previously functional mutation with
  737. # our copy_ node.
  738. mutation.replace_all_uses_with(new_node, lambda x: x is not new_node)
  739. # Remove the mutated buffer / input from the graph outputs, since we don't
  740. # need to thread it through anymore.
  741. user_outputs = tuple(return_args[num_mutations:])
  742. output_node.args = ((user_outputs),)
  743. def _root_module_type(graph: torch.fx.Graph) -> str | None:
  744. for node in graph.nodes:
  745. if "nn_module_stack" not in node.meta:
  746. continue
  747. for path, ty in node.meta["nn_module_stack"].values():
  748. if not path:
  749. return ty
  750. return None
  751. def _fix_nn_module_stacks(graph):
  752. # For each nn module stack in the graph, check if the fqns in it represent a stack:
  753. # 1. Each fqn must be a prefix of the next fqn.
  754. # 2. If not, remove the entries starting from the next fqn, emitting a warning.
  755. for node in graph.nodes:
  756. if "nn_module_stack" not in node.meta:
  757. continue
  758. nn_module_stack = node.meta["nn_module_stack"]
  759. fqns = [
  760. fqn.split("@")[0] if "@" in fqn else fqn
  761. for fqn, _t in nn_module_stack.values()
  762. ]
  763. # Check if each FQN is a prefix of the next one
  764. prev_fqn, *next_fqns = fqns
  765. num_valid_indices = 1 # root FQN
  766. for curr_fqn in next_fqns:
  767. # Check if the previous FQN is a prefix of the current one
  768. if _is_prefix(prev_fqn, curr_fqn):
  769. num_valid_indices += 1
  770. prev_fqn = curr_fqn
  771. else:
  772. # Found a non-prefix FQN, stop here
  773. break
  774. # If we need to remove entries, create a new stack with only valid entries
  775. if num_valid_indices < len(nn_module_stack):
  776. log.warning(
  777. "nn_module_stack fqns %s at node %s do not form a stack! dropping last %d entries",
  778. fqns,
  779. node,
  780. len(nn_module_stack) - num_valid_indices,
  781. )
  782. node.meta["nn_module_stack"] = dict(
  783. list(nn_module_stack.items())[:num_valid_indices]
  784. )
  785. def _is_prefix(candidate, target):
  786. """Check whether `candidate` is a prefix of `target`."""
  787. return len(candidate) < len(target) and target[: len(candidate)] == candidate
  788. def _compute_accessor(parent_fqn: str, child_fqn: str) -> str:
  789. if parent_fqn == "":
  790. # Handle the root module correctly.
  791. return child_fqn
  792. parent_split = parent_fqn.split(".")
  793. child_split = child_fqn.split(".")
  794. # TODO: support skip connection by inlining the child module.
  795. if child_split[: len(parent_split)] != parent_split:
  796. raise RuntimeError(
  797. f"Child module '{child_fqn}' is not a descendant of parent module '{parent_fqn}'."
  798. "This is currently unsupported."
  799. "Please try to make child module attach to parent module directly."
  800. )
  801. return ".".join(child_split[len(parent_split) :])
  802. def _check_graph_equivalence(x: torch.nn.Module, y: torch.nn.Module):
  803. def graph_dump(graph: torch.fx.Graph) -> str:
  804. ret = []
  805. nodes_idx: dict[int, int] = {}
  806. def arg_dump(arg) -> str:
  807. if isinstance(arg, torch.fx.Node):
  808. return "%" + str(nodes_idx[id(arg)])
  809. return str(arg)
  810. for i, node in enumerate(graph.nodes):
  811. args_dump = [str(arg) for arg in pytree.tree_map(arg_dump, node.args)]
  812. args_dump += [
  813. f"{key}={value}"
  814. for key, value in pytree.tree_map(arg_dump, node.kwargs).items()
  815. ]
  816. target = node.target if node.op in ("call_function", "get_attr") else ""
  817. # pyrefly: ignore [bad-argument-type]
  818. ret.append(f"{i}: {node.op}[{target}]({', '.join(args_dump)})")
  819. nodes_idx[id(node)] = i
  820. return "\n".join(ret)
  821. assert isinstance(x.graph, torch.fx.Graph)
  822. assert isinstance(y.graph, torch.fx.Graph)
  823. return graph_dump(x.graph) == graph_dump(y.graph)
  824. def _add_spec(gm: torch.nn.Module, spec) -> str:
  825. i = 0
  826. while hasattr(gm, f"_spec_{i}"):
  827. i += 1
  828. name = f"_spec_{i}"
  829. setattr(gm, name, spec)
  830. return name
  831. def _generate_flatten(gm: torch.fx.GraphModule, node) -> torch.fx.Node:
  832. flatten = gm.graph.call_function(pytree.tree_flatten, (node,))
  833. getitem_0 = gm.graph.call_function(operator.getitem, (flatten, 0))
  834. return getitem_0
  835. def _generate_flatten_spec(
  836. gm: torch.fx.GraphModule | InterpreterModule | UnflattenedModule, node, spec
  837. ) -> torch.fx.Node:
  838. name = _add_spec(gm, spec)
  839. spec_node = gm.graph.get_attr(name)
  840. return gm.graph.call_function(fx_pytree.tree_flatten_spec, (node, spec_node))
  841. def _generate_unflatten(
  842. gm: torch.fx.GraphModule | InterpreterModule | UnflattenedModule, nodes, spec
  843. ) -> torch.fx.Node:
  844. name = _add_spec(gm, spec)
  845. spec_node = gm.graph.get_attr(name)
  846. return gm.graph.call_function(pytree.tree_unflatten, (nodes, spec_node))
  847. def _get_submodule(mod: torch.nn.Module, target: str):
  848. *prefix, field = target.split(".")
  849. for item in prefix:
  850. submod = getattr(mod, item, None)
  851. if submod is None:
  852. return None
  853. if not isinstance(submod, torch.nn.Module):
  854. return None
  855. mod = submod
  856. return getattr(mod, field, None)
  857. def _add_submodule(
  858. mod: torch.nn.Module,
  859. target: str,
  860. module_to_add: torch.nn.Module,
  861. create_module: Callable[[str], torch.nn.Module] | None = None,
  862. ):
  863. *prefix, field = target.split(".")
  864. for i, item in enumerate(prefix):
  865. submod = getattr(mod, item, None)
  866. if submod is None:
  867. if create_module is not None:
  868. submod = create_module(".".join(prefix[: i + 1]))
  869. else:
  870. submod = torch.nn.Module()
  871. setattr(mod, item, submod)
  872. if not isinstance(submod, torch.nn.Module):
  873. return False
  874. mod = submod
  875. mod.add_module(field, module_to_add)
  876. def _call_name(base: str, n: int) -> str:
  877. # Given n >= 0, generate call names to a submodule `base` of the form
  878. # `base`, `base@1`, `base@2`, etc.
  879. return base if n == 1 else f"{base}@{n - 1}"
  880. def _is_call_name(call_name: str, base: str) -> bool:
  881. # Recognize when call_name = _call_name(base, n) for some n >= 0.
  882. return re.match(re.escape(base) + r"(@\d+)?$", call_name) is not None
  883. class _ModuleFrame:
  884. def __init__(
  885. self,
  886. flat_graph: torch.fx.Graph,
  887. nodes: tuple[torch.fx.Node, ...],
  888. seen_nodes,
  889. seen_modules,
  890. seen_attrs,
  891. created_modules,
  892. parent,
  893. module_stack: list[tuple[str, str | None, int]],
  894. module_id,
  895. module_call_graph: dict[str, ModuleCallSignature],
  896. module: torch.fx.GraphModule | UnflattenedModule | None = None,
  897. ):
  898. self.flat_graph = flat_graph
  899. self.nodes = nodes
  900. self.seen_nodes = seen_nodes
  901. self.seen_modules = seen_modules
  902. self.seen_attrs = seen_attrs
  903. self.created_modules = created_modules
  904. self.parent = parent
  905. self.module_stack = module_stack
  906. self.module_id = module_id
  907. self.module_call_graph = module_call_graph
  908. self.verbose = False
  909. self.fqn, ty, num_calls = self.module_stack[-1]
  910. # generate call name for self.fqn
  911. self.child_fqn = _call_name(self.fqn, num_calls + 1)
  912. self.module: torch.fx.GraphModule | UnflattenedModule | InterpreterModule
  913. if module is not None:
  914. self.module = module
  915. self.ivals = module.ivals if hasattr(module, "ivals") else {} # type: ignore[var-annotated]
  916. else:
  917. self.module = self.created_modules.get(
  918. self.fqn,
  919. InterpreterModule(torch.fx.Graph(), ty=ty),
  920. )
  921. self.ivals = parent.ivals
  922. self.graph = self.module.graph
  923. # Mapping of nodes in the flat graph to nodes in this graph.
  924. self.node_map: dict[torch.fx.Node, torch.fx.Node] = {}
  925. self.node_to_placeholder = {}
  926. self.parent_call_module: torch.fx.Node | None = None
  927. if parent is not None:
  928. accessor = _compute_accessor(parent.fqn, self.child_fqn)
  929. def create_module(fqn):
  930. path = f"{parent.fqn}.{fqn}" if parent.fqn else fqn
  931. if path in self.created_modules:
  932. return self.created_modules[path]
  933. submod = InterpreterModule(torch.fx.Graph(), ty=ty)
  934. self.created_modules[path] = submod
  935. return submod
  936. _add_submodule(parent.module, accessor, self.module, create_module)
  937. self.parent_call_module = parent.graph.call_module(accessor)
  938. if self.seen_modules[self.module_id]:
  939. base_module_frame = self.seen_modules[self.module_id][0]
  940. self.module._modules = base_module_frame.module._modules
  941. self.seen_modules[self.module_id].append(
  942. _SubmoduleEntry(
  943. parent_fqn=self.parent.fqn,
  944. parent_module=self.parent.module,
  945. parent_call_module=self.parent_call_module,
  946. fqn=self.fqn,
  947. call_idx=num_calls + 1,
  948. module=self.module,
  949. )
  950. )
  951. signature = module_call_graph.get(self.child_fqn)
  952. if signature is not None and self.parent is not None:
  953. assert signature.in_spec.num_children == 2
  954. assert signature.in_spec.type is tuple
  955. args_spec, kwargs_spec = signature.in_spec.children()
  956. assert args_spec.type is tuple
  957. assert kwargs_spec.type is dict
  958. with self.graph.inserting_after(None):
  959. arg_nodes = [
  960. self.graph.placeholder(f"_positional_arg_{idx}")
  961. for idx in range(args_spec.num_children)
  962. ]
  963. kwarg_nodes = {}
  964. for name in kwargs_spec.context:
  965. kwarg_nodes[name] = self.graph.placeholder(name)
  966. flat_args = _generate_flatten_spec(
  967. self.module,
  968. (tuple(arg_nodes), kwarg_nodes),
  969. signature.in_spec,
  970. )
  971. for idx, arg in enumerate(signature.inputs):
  972. flat_arg_node = self.graph.create_node(
  973. op="call_function",
  974. target=operator.getitem,
  975. args=(flat_args, idx),
  976. name=(
  977. arg.name
  978. if not isinstance(arg, ConstantArgument)
  979. else f"_constant_{idx}"
  980. ),
  981. )
  982. if isinstance(arg, ConstantArgument):
  983. continue
  984. if arg.name in self.seen_nodes:
  985. flat_arg_node.meta = copy.copy(self.seen_nodes[arg.name].meta)
  986. self.node_to_placeholder[self.seen_nodes[arg.name]] = (
  987. flat_arg_node
  988. )
  989. with self.parent.graph.inserting_before(self.parent_call_module):
  990. input_nodes: list[torch.fx.Node | None] = []
  991. for input in signature.inputs:
  992. if isinstance(input, ConstantArgument):
  993. input_nodes.append(input.value) # type: ignore[arg-type]
  994. elif input.name not in self.seen_nodes:
  995. input_nodes.append(None)
  996. else:
  997. assert isinstance(
  998. input,
  999. (
  1000. TensorArgument,
  1001. SymIntArgument,
  1002. SymBoolArgument,
  1003. SymFloatArgument,
  1004. ),
  1005. )
  1006. input_nodes.append(
  1007. self.parent.remap_input(self.seen_nodes[input.name])
  1008. )
  1009. inputs_node = _generate_unflatten(
  1010. self.parent.module,
  1011. input_nodes,
  1012. signature.in_spec,
  1013. )
  1014. args_node = self.parent.graph.call_function(
  1015. operator.getitem, (inputs_node, 0)
  1016. )
  1017. kwargs_node = self.parent.graph.call_function(
  1018. operator.getitem, (inputs_node, 1)
  1019. )
  1020. arg_nodes = [
  1021. self.parent.graph.call_function(operator.getitem, (args_node, i))
  1022. for i in range(args_spec.num_children)
  1023. ]
  1024. kwarg_nodes = {
  1025. k: self.parent.graph.call_function(
  1026. operator.getitem, (kwargs_node, k)
  1027. )
  1028. for k in kwargs_spec.context
  1029. }
  1030. assert self.parent_call_module is not None
  1031. # pyrefly: ignore [bad-assignment]
  1032. self.parent_call_module.args = tuple(arg_nodes)
  1033. self.parent_call_module.kwargs = kwarg_nodes # type: ignore[assignment]
  1034. def add_placeholder(self, x):
  1035. assert self.fqn != "", f"Cannot add placeholder {x} to root module"
  1036. assert x.graph is self.flat_graph
  1037. # x is not in subgraph, create a new placeholder for subgraph
  1038. with self.graph.inserting_before(None):
  1039. placeholder_node = self.graph.placeholder(x.name, type_expr=x.type)
  1040. # copy all meta fields, even if some fields might be irrelevant for
  1041. # the placeholder node
  1042. placeholder_node.meta = copy.copy(x.meta)
  1043. self.node_to_placeholder[x] = placeholder_node
  1044. def copy_sym_call_function(self, x):
  1045. # This only exists because we deduplicate sym_size nodes in the flat export graph,
  1046. # and if preserve_module_call_signature is set, we may not be able to pass sym_size
  1047. # nodes, or their downstream users, as inputs to submodule calls.
  1048. # To avoid this we copy these call_function nodes with sym_type results.
  1049. # This should however only be done for sym_type nodes - call_function nodes on tensors
  1050. # should not be deduplicated in the first place.
  1051. args = pytree.tree_map_only(torch.fx.Node, self.remap_input, x.args)
  1052. kwargs = pytree.tree_map_only(torch.fx.Node, self.remap_input, x.kwargs)
  1053. node = self.graph.call_function(x.target, args, kwargs)
  1054. node.meta = copy.copy(x.meta)
  1055. self.node_map[x] = node
  1056. return node
  1057. def remap_input(self, x):
  1058. assert x.graph is self.flat_graph
  1059. if x in self.node_map:
  1060. return self.node_map[x]
  1061. self.print(f"remap_input({x})")
  1062. if x in self.node_to_placeholder:
  1063. return self.node_to_placeholder[x]
  1064. elif (
  1065. x.op == "placeholder" or self.module_call_graph.get(self.fqn) is None
  1066. # allow placeholder creation if we are not preserving module call signature
  1067. ):
  1068. self.add_placeholder(x)
  1069. if self.parent_call_module is not None:
  1070. # Important to *prepend* the output to match how we are
  1071. # inserting placeholder nodes.
  1072. with self.parent.graph.inserting_before(self.parent_call_module):
  1073. self.parent_call_module.insert_arg(0, self.parent.remap_input(x))
  1074. return self.node_to_placeholder[x]
  1075. elif x.op == "call_function" and (
  1076. x.target
  1077. in (
  1078. torch.ops.aten.sym_size.int,
  1079. torch.ops.aten.item.default,
  1080. torch.ops.aten.unbind.int,
  1081. torch.ops.aten.sum.dim_IntList,
  1082. torch.ops.aten.view.default,
  1083. torch.ops.aten.diff.default,
  1084. )
  1085. or (hasattr(x.target, "__module__") and x.target.__module__ == "_operator")
  1086. ):
  1087. # export deduplicates sym_size nodes, and may need to re-copy them
  1088. # if module call signature needs to be preserved
  1089. self.copy_sym_call_function(x)
  1090. return self.node_map[x]
  1091. elif self.module_call_graph.get(self.fqn) is not None:
  1092. # x is reading the intermediate value of a mutation, so record it;
  1093. # later we will find where it was created and perform the update
  1094. return self.ivals.read(self, x) # type: ignore[operator, union-attr]
  1095. else:
  1096. raise RuntimeError(
  1097. f"Could not run remap_input() on op type: {x.op} for node {x}"
  1098. )
  1099. def uplift_common_custom_metadata(self) -> None:
  1100. # Copy custom metadata if all nodes have same custom metadata
  1101. custom_meta = None
  1102. for node in self.node_map.values():
  1103. curr_meta = node.meta.get("custom", {})
  1104. if custom_meta is None:
  1105. # first node
  1106. custom_meta = curr_meta
  1107. continue
  1108. if curr_meta != custom_meta:
  1109. custom_meta = {}
  1110. break
  1111. if custom_meta:
  1112. # Lift common custom metadata to parent node and clear children node's custom metadata
  1113. assert self.parent_call_module is not None
  1114. self.parent_call_module.meta["custom"] = custom_meta
  1115. for node in self.node_map.values():
  1116. del node.meta["custom"]
  1117. def finalize_outputs(self):
  1118. self.created_modules.pop(self.fqn, None)
  1119. orig_outputs = []
  1120. signature = self.module_call_graph.get(self.child_fqn)
  1121. if signature is not None and self.parent is not None:
  1122. for output in signature.outputs:
  1123. if isinstance(
  1124. output,
  1125. (
  1126. TensorArgument,
  1127. SymIntArgument,
  1128. SymBoolArgument,
  1129. SymFloatArgument,
  1130. ConstantArgument,
  1131. ),
  1132. ):
  1133. if output.name in self.seen_nodes:
  1134. orig_outputs.append(self.seen_nodes[output.name])
  1135. else:
  1136. orig_outputs.append(None)
  1137. else:
  1138. raise RuntimeError(
  1139. f"Unsupported data type for output node: {output}"
  1140. )
  1141. def get_actual_output_node(output):
  1142. if output is None:
  1143. return None
  1144. seen_node = self.seen_nodes[output.name]
  1145. if seen_node in self.node_map:
  1146. return self.node_map[seen_node]
  1147. elif seen_node in self.node_to_placeholder:
  1148. return self.node_to_placeholder[seen_node]
  1149. else:
  1150. raise RuntimeError(
  1151. f"Could not find output node {output}. Graph: {self.graph}"
  1152. )
  1153. tree_out_node = _generate_unflatten(
  1154. self.module,
  1155. tuple(get_actual_output_node(output) for output in orig_outputs),
  1156. signature.out_spec,
  1157. )
  1158. parent_out: torch.fx.Node | None = _generate_flatten_spec(
  1159. self.parent.module, self.parent_call_module, signature.out_spec
  1160. )
  1161. graph_outputs: torch.fx.Node | list[torch.fx.Node] = tree_out_node
  1162. else:
  1163. graph_outputs = []
  1164. # Iterate through nodes we have copied into self.graph.
  1165. for orig_node in self.node_map:
  1166. for user_node in orig_node.users:
  1167. if user_node.name not in self.seen_nodes:
  1168. # external user node, need to expose as an output
  1169. orig_outputs.append(orig_node)
  1170. graph_outputs.append(self.node_map[orig_node])
  1171. break
  1172. parent_out = self.parent_call_module
  1173. if len(graph_outputs) == 1:
  1174. graph_outputs = graph_outputs[0]
  1175. assert isinstance(graph_outputs, (list, torch.fx.Node))
  1176. self.graph.output(graph_outputs)
  1177. # Rewrite outputs in parent module
  1178. if parent_out is None:
  1179. return
  1180. parent_out.meta["val"] = (
  1181. graph_outputs.meta.get("val")
  1182. if isinstance(graph_outputs, torch.fx.Node)
  1183. else [o.meta.get("val") for o in graph_outputs]
  1184. )
  1185. self.uplift_common_custom_metadata()
  1186. if len(orig_outputs) == 1 and signature is None:
  1187. self.parent.node_map[orig_outputs[0]] = parent_out
  1188. else:
  1189. for i, orig_output in enumerate(orig_outputs):
  1190. if orig_output is None:
  1191. continue
  1192. # Use Proxy to record getitem access.
  1193. proxy_out = torch.fx.Proxy(parent_out)[i].node # type: ignore[index]
  1194. proxy_out.meta["val"] = orig_output.meta.get("val")
  1195. self.parent.node_map[orig_output] = proxy_out
  1196. def copy_node(self, node):
  1197. self.print("copying", node.format_node())
  1198. self.node_map[node] = self.graph.node_copy(node, self.remap_input)
  1199. self.seen_nodes[node.name] = node
  1200. def run_outer(self):
  1201. for i, node in enumerate(self.flat_graph.nodes):
  1202. self.print(i, node.meta.get("nn_module_stack"), node.format_node())
  1203. # Copy all graph inputs
  1204. node_idx: int = 0
  1205. node = self.nodes[node_idx]
  1206. while node.op == "placeholder":
  1207. self.copy_node(node)
  1208. node_idx += 1
  1209. node = self.nodes[node_idx]
  1210. self.run_from(node_idx)
  1211. # Copy graph outputs
  1212. for node in self.flat_graph.nodes:
  1213. if node.op == "output":
  1214. self.copy_node(node)
  1215. def print(self, *args, **kwargs):
  1216. if self.verbose:
  1217. # pyrefly: ignore [not-iterable]
  1218. print(*args, **kwargs)
  1219. def run_from(self, node_idx):
  1220. module_idx = 0
  1221. # Walk through the graph, building up a new graph with the right submodules
  1222. while node_idx < len(self.nodes):
  1223. node = self.nodes[node_idx]
  1224. assert node.op != "placeholder"
  1225. self.print()
  1226. self.print("STEP", node_idx, node.format_node())
  1227. self.print(self.module_stack)
  1228. depth = len(self.module_stack)
  1229. if node.op == "output":
  1230. if depth == 1:
  1231. # We want the output node of the original graph to be handled
  1232. # specially by the outermost stack frame (in run_outer). So
  1233. # skip finalization here.
  1234. return node_idx
  1235. # We've reached the end of the graph. Wrap up all the existing stack frames.
  1236. self.finalize_outputs()
  1237. return node_idx
  1238. if len(node.meta.get("nn_module_stack", {})) == 0:
  1239. raise RuntimeError(f"Unable to find nn_module_stack for node {node}")
  1240. nn_module_stack = node.meta["nn_module_stack"]
  1241. from torch._export.passes._node_metadata_hook import (
  1242. _EMPTY_NN_MODULE_STACK_KEY,
  1243. )
  1244. if (
  1245. len(nn_module_stack) == 1
  1246. and _EMPTY_NN_MODULE_STACK_KEY in nn_module_stack
  1247. ):
  1248. # Empty case from the node_metadata_hook
  1249. node_module_stack = self.module_stack
  1250. else:
  1251. node_module_stack = [
  1252. (
  1253. path,
  1254. ty if path else None,
  1255. int(k.split("@")[-1]) if "@" in k else 0,
  1256. )
  1257. for k, (path, ty) in node.meta["nn_module_stack"].items()
  1258. ]
  1259. if node_module_stack[:depth] != self.module_stack:
  1260. # This means that the current module is done executing and the
  1261. # current node is the beginning of a new module.
  1262. #
  1263. # In this case, we should finalize this module and return without
  1264. # incrementing the node counter.
  1265. self.finalize_outputs()
  1266. self.print("outlining", self.fqn)
  1267. self.print(self.graph)
  1268. return node_idx
  1269. assert node_module_stack is not None
  1270. if _is_prefix(self.module_stack, node_module_stack):
  1271. # This means that the current node represents the execution of a new
  1272. # module.
  1273. next_module = node_module_stack[depth]
  1274. self.print("Creating new stack frame for", next_module)
  1275. # Run a nested version of module outliner from the current node
  1276. # counter. Once it is complete, continue from that point.
  1277. next_module_key = list(node.meta["nn_module_stack"].keys())[depth]
  1278. node_idx = _ModuleFrame(
  1279. self.flat_graph,
  1280. self.nodes,
  1281. self.seen_nodes,
  1282. self.seen_modules,
  1283. self.seen_attrs,
  1284. self.created_modules,
  1285. self,
  1286. self.module_stack + [next_module],
  1287. next_module_key.split("@")[0],
  1288. self.module_call_graph,
  1289. ).run_from(node_idx)
  1290. module_idx += 1
  1291. continue
  1292. # The only remaining possibility is that we are in the right stack
  1293. # frame. Copy the node into this frame's graph and increment the node counter.
  1294. assert node_module_stack == self.module_stack
  1295. if node.op == "get_attr":
  1296. # this must be a graph argument for a HOP
  1297. self.seen_attrs[self.child_fqn].add(node.target)
  1298. self.copy_node(node)
  1299. # pyrefly: ignore [unsupported-operation]
  1300. node_idx += 1
  1301. @dataclass
  1302. class _SubmoduleEntry:
  1303. parent_fqn: str
  1304. parent_module: torch.nn.Module
  1305. parent_call_module: torch.fx.Node
  1306. fqn: str
  1307. call_idx: int
  1308. module: torch.nn.Module
  1309. def _outline_submodules(orig_graph: torch.fx.Graph, root_module: UnflattenedModule):
  1310. seen_nodes: dict[str, torch.fx.Node] = {}
  1311. seen_modules: dict[int, list[_SubmoduleEntry]] = defaultdict(list)
  1312. seen_attrs: dict[str, set[str]] = defaultdict(set)
  1313. created_modules: dict[str, torch.nn.Module] = {}
  1314. _ModuleFrame(
  1315. orig_graph,
  1316. tuple(orig_graph.nodes),
  1317. seen_nodes,
  1318. seen_modules,
  1319. seen_attrs,
  1320. created_modules,
  1321. None,
  1322. [("", None, 0)],
  1323. "",
  1324. {
  1325. entry.fqn: entry.signature
  1326. for entry in root_module.module_call_graph
  1327. if entry.signature
  1328. },
  1329. module=root_module,
  1330. ).run_outer()
  1331. return seen_modules, seen_attrs
  1332. def _reorder_submodules(
  1333. parent: torch.nn.Module, fqn_order: dict[str, int], prefix: str = ""
  1334. ):
  1335. # TODO Can be optimized by adding submodules ahead of time.
  1336. if prefix == "":
  1337. for fqn in list(fqn_order.keys())[1:]:
  1338. if _get_submodule(parent, fqn) is None:
  1339. _add_submodule(parent, fqn, torch.nn.Module())
  1340. children = []
  1341. for name, child in list(parent._modules.items()):
  1342. if child is None:
  1343. continue
  1344. fqn = prefix + name
  1345. _reorder_submodules(child, fqn_order, prefix=fqn.split("@")[0] + ".")
  1346. delattr(parent, name)
  1347. children.append((fqn_order[fqn], name, child))
  1348. children.sort(key=operator.itemgetter(0))
  1349. for _, name, child in children:
  1350. parent.register_module(name, child)
  1351. class _IVals:
  1352. """
  1353. Collect the intermediate values of mutations in a graph.
  1354. Example: in the following graph, suppose that buf_in and buf_out
  1355. are the input and output values of a buffer.
  1356. buf_in = placeholder()
  1357. ...
  1358. ival1 = f0(buf_in, ...) # inside self.n0(...)
  1359. ...
  1360. ival2 = f1(ival1, ...) # inside self.n1(...)
  1361. ...
  1362. buf_out = f2(ival2, ...) # inside self.n2(...)
  1363. return buf_out, ...
  1364. Here ival1 and ival2 are intermediate values created inside
  1365. calls to n0 and n1 respectively, and used inside calls to
  1366. n1 and n2 respectively.
  1367. """
  1368. def __init__(self):
  1369. # for each fqn, set of node names corresponding to intermediate values
  1370. self.node_names_by_fqn = defaultdict(set)
  1371. def _is_mutable(self, target):
  1372. if isinstance(target, torch._ops.OpOverload):
  1373. return target._schema.is_mutable
  1374. return False
  1375. def read(self, mf, node):
  1376. """
  1377. Read state corresponding to a given intermediate value.
  1378. """
  1379. # we can assume that the node must be from a mutation
  1380. assert node.op == "call_function"
  1381. b = self._is_mutable(node.target)
  1382. print("Checking mutability", node.target, b)
  1383. if not b:
  1384. # so the mutation was functionalized;
  1385. # we will apply the original mutation later (see below)
  1386. fqn, _ = next(reversed(node.meta["nn_module_stack"].values()))
  1387. self.node_names_by_fqn[fqn].add(node.name)
  1388. return mf.remap_input(node.args[0])
  1389. def update(self, partitions):
  1390. """
  1391. Update states corresponding to intermediate values that were read.
  1392. """
  1393. for shared_submodules in partitions:
  1394. for entry in shared_submodules:
  1395. graph = entry.module.graph
  1396. node_names = self.node_names_by_fqn[entry.fqn]
  1397. nodes = [n for n in graph.nodes if n.name in node_names]
  1398. for node in nodes:
  1399. # so node must be from a functionalized mutation;
  1400. # we perform the original mutation now
  1401. with graph.inserting_after(node):
  1402. new_node = graph.create_node(
  1403. "call_function",
  1404. torch.ops.aten.copy_.default,
  1405. (node.args[0], node),
  1406. )
  1407. new_node.meta = copy.copy(node.meta)
  1408. def _copy_graph_attrs(
  1409. gm: torch.fx.GraphModule,
  1410. root_module: UnflattenedModule,
  1411. seen_attrs: dict[str, set[str]],
  1412. ):
  1413. for child_fqn, names in seen_attrs.items():
  1414. module = _get_attr(root_module, child_fqn) if child_fqn else root_module
  1415. for name in names:
  1416. val = getattr(gm, name)
  1417. setattr(module, name, val)
  1418. def _deduplicate_modules(partitions):
  1419. redirected_call_indices = {}
  1420. for shared_submodules in partitions:
  1421. for i, entry in enumerate(shared_submodules):
  1422. child_fqn = _call_name(entry.fqn, entry.call_idx)
  1423. target = _compute_accessor(entry.parent_fqn, child_fqn)
  1424. deduplicated = False
  1425. # Iterate over all previously seen modules, and deduplicate if possible
  1426. for seen in shared_submodules[:i]:
  1427. if _check_graph_equivalence(seen.module, entry.module):
  1428. parent = entry.parent_module
  1429. # Since graphs are equivalent, we can deduplicate.
  1430. # There are two cases.
  1431. if seen.fqn == entry.fqn:
  1432. # Case 1: The current module has the same fqn as the seen module.
  1433. # In this case we have generated a call name that can be optimized away.
  1434. # So we remove the current module from the hierarchy and replace
  1435. # the current call name with the seen call name in the parent graph.
  1436. *prefix, name = target.split(".")
  1437. _get_attr_via_attr_list(parent, prefix)._modules.pop(name)
  1438. seen_child_fqn = _call_name(seen.fqn, seen.call_idx)
  1439. seen_target = _compute_accessor(
  1440. entry.parent_fqn, seen_child_fqn
  1441. )
  1442. entry.parent_call_module.target = seen_target
  1443. redirected_call_indices[child_fqn] = seen_child_fqn
  1444. break
  1445. elif not deduplicated:
  1446. # Case 2: The current module has a different fqn than the seen module.
  1447. # In this case we replace the current module with the seen module.
  1448. # There should be nothing pointing to the current module any more,
  1449. # so it can be garbage collected.
  1450. # NOTE: We *do not* replace the current call name with the seen call name
  1451. # in the parent graph, because this will lose information on which fqn
  1452. # was actually called. However, it is possible that the current call name
  1453. # will be optimized away when we find another seen module with the same fqn,
  1454. # so we do not break out of the loop yet.
  1455. parent.set_submodule(target, seen.module)
  1456. deduplicated = True
  1457. return redirected_call_indices
  1458. def _sink_params(
  1459. module: torch.nn.Module,
  1460. inputs_to_state: dict[str, list[str]],
  1461. scope: list[str],
  1462. module_id_to_inputs_removed: dict[int, set[str]] | None = None,
  1463. ):
  1464. """Sink params, buffers, and constants from graph inputs into get_attr nodes.
  1465. Exported modules are purely functional, so they pass their parameters and
  1466. buffers in as inputs to the graph.
  1467. To replicate eager's semantics, we need to get them from the module state
  1468. via get_attr instead.
  1469. module: GraphModule, potentially containing nested submodules.
  1470. inputs_to_state: mapping graph input names to the corresponding key in the state_dict.
  1471. scope: tracks where we are in the module hierarchy, so that we can emit the
  1472. right `getattr(self, "foo.bar")` calls, etc.
  1473. module_id_to_inputs_removed: records inputs removed by child modules, mapping
  1474. the module object id to the list of placeholder node names in the child module
  1475. that were removed.
  1476. """
  1477. if module_id_to_inputs_removed is None:
  1478. module_id_to_inputs_removed = defaultdict(set)
  1479. if id(module) in module_id_to_inputs_removed:
  1480. return {id(module): module_id_to_inputs_removed[id(module)]}
  1481. # We need to use _modules here instead of named_children(), because we
  1482. # explicitly want duplicate modules to show up in the traversal.
  1483. for name, submodule in module._modules.items():
  1484. submod_id_to_inputs_removed = _sink_params(
  1485. cast("torch.nn.Module", submodule),
  1486. inputs_to_state,
  1487. scope + [name],
  1488. module_id_to_inputs_removed,
  1489. )
  1490. for k, v in submod_id_to_inputs_removed.items():
  1491. module_id_to_inputs_removed[k].update(v)
  1492. graph = getattr(module, "graph", None)
  1493. if graph is None or len(graph.nodes) == 0:
  1494. # Not all modules have graphs defined, if they are empty modules with no operations (like ParameterList)
  1495. return module_id_to_inputs_removed
  1496. assert isinstance(graph, torch.fx.Graph)
  1497. inputs = list(filter(lambda n: n.op == "placeholder", graph.nodes))
  1498. the_last_input = None if len(inputs) == 0 else inputs[-1]
  1499. # Also remove from call_module nodes
  1500. call_module_nodes = filter(lambda n: n.op == "call_module", graph.nodes)
  1501. for node in call_module_nodes:
  1502. submodule = _get_attr(module, node.target)
  1503. # remove placeholder from call_module node arguments, only if we've
  1504. # erased the placeholder node in the corresponding _sink_params() call
  1505. if submodule is not None and id(submodule) in module_id_to_inputs_removed:
  1506. node.args = tuple(
  1507. filter(
  1508. lambda n: n.name not in module_id_to_inputs_removed[id(submodule)],
  1509. node.args,
  1510. )
  1511. )
  1512. # Filter out inputs_to_state corresponding to current scope.
  1513. inputs_to_state_of_scope: dict[torch.fx.Node, list[str]] = {}
  1514. for node in inputs:
  1515. if node.name not in inputs_to_state:
  1516. continue
  1517. state_name = None
  1518. for sn in inputs_to_state[node.name]:
  1519. sn_split = sn.split(".")
  1520. if sn_split[: len(scope)] == [x.split("@")[0] for x in scope]:
  1521. state_name = sn_split
  1522. break
  1523. # If there's a mismatch between scope name and state name, then
  1524. # there must be multiple scopes pointing to the same state name,
  1525. # meaning some modules are shared. In such case, we can simply skip
  1526. # updating the current node because another later iteration will
  1527. # take care of this input node when the unique match between scope
  1528. # and state name occurs. To make sure this always happen, we should
  1529. # enforce the invariant that no placeholder node in the unflattened
  1530. # graph appears in inputs_to_state dict, which means all the extra
  1531. # input nodes have been handled.
  1532. if state_name is None:
  1533. continue
  1534. inputs_to_state_of_scope[node] = state_name
  1535. # Record name of remove inputs for return purpose.
  1536. inputs_removed: set[str] = set()
  1537. for node, state_name in inputs_to_state_of_scope.items():
  1538. if len(node.users) > 0:
  1539. attr_path = state_name[len(scope) :]
  1540. state_attr = _get_attr_via_attr_list(module, attr_path)
  1541. assert isinstance(state_attr, (torch.Tensor, torch.ScriptObject))
  1542. # Make sure the newly created get_attr node is placed after the last placeholder node
  1543. with graph.inserting_after(the_last_input):
  1544. new_node = graph.create_node("get_attr", ".".join(attr_path))
  1545. node.replace_all_uses_with(new_node, propagate_meta=True)
  1546. graph.erase_node(node)
  1547. inputs_removed.add(node.name)
  1548. if isinstance(module, InterpreterModule):
  1549. module.finalize()
  1550. return {id(module): inputs_removed}