_unlift.py 28 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693694695696697698699700701702703704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760
  1. # mypy: allow-untyped-defs
  2. import copy
  3. import inspect
  4. import math
  5. import warnings
  6. from collections.abc import Sequence
  7. from itertools import chain
  8. from typing import Any, Optional
  9. import sympy
  10. import torch
  11. import torch.utils._pytree as pytree
  12. from torch._export.non_strict_utils import (
  13. _enter_enable_graph_inputs_of_type_nn_module,
  14. _exit_enable_graph_inputs_of_type_nn_module,
  15. _get_graph_inputs_of_type_nn_module,
  16. )
  17. from torch._export.passes.add_runtime_assertions_for_constraints_pass import (
  18. _convert_range_to_int,
  19. )
  20. from torch._export.utils import _check_input_constraints_for_graph
  21. from torch.export.unflatten import _assign_attr, _AttrKind
  22. from torch.fx.experimental.proxy_tensor import _pytree_subclasses_that_lose_info
  23. from torch.fx.graph import _PyTreeCodeGen, _PyTreeInfo
  24. from torch.fx.traceback import NodeSource, NodeSourceAction
  25. from torch.utils._sympy.solve import try_solve
  26. from torch.utils._sympy.value_ranges import ValueRanges
  27. from ._remove_effect_tokens_pass import _remove_effect_tokens
  28. from ._tree_utils import reorder_kwargs
  29. from .exported_program import (
  30. ExportedProgram,
  31. ExportGraphSignature,
  32. InputKind,
  33. OutputKind,
  34. )
  35. def eq_spec(self: pytree.TreeSpec, other: pytree.TreeSpec) -> bool:
  36. """
  37. Refinement of TreeSpec.__eq__ where, e.g., torch.Size(...) matches tuple(...).
  38. See _pytree_subclasses_that_lose_info in proxy_tensor.py for more details.
  39. """
  40. def _normalize_type(t):
  41. return str(_pytree_subclasses_that_lose_info.get(t, t))
  42. def _match_normalized_structure(a, b):
  43. if a is b:
  44. return True
  45. if _normalize_type(a.type) != _normalize_type(b.type):
  46. return False
  47. if a.context != b.context:
  48. return False
  49. if len(a.children_specs) != len(b.children_specs):
  50. return False
  51. return all(
  52. _match_normalized_structure(a, b)
  53. for a, b in zip(a.children_specs, b.children_specs)
  54. )
  55. return _match_normalized_structure(self, other)
  56. def _check_inputs_match(args, kwargs, in_spec: pytree.TreeSpec) -> list:
  57. reordered_kwargs = reorder_kwargs(kwargs, in_spec)
  58. flat_args_with_path, received_spec = pytree.tree_flatten_with_path(
  59. (args, reordered_kwargs)
  60. )
  61. if not eq_spec(received_spec, in_spec):
  62. raise ValueError( # noqa: B904
  63. "Trying to flatten user inputs with exported input tree spec: \n"
  64. f"{in_spec}\n"
  65. "but actually got inputs with tree spec of: \n"
  66. f"{received_spec}.\n"
  67. "Please check that the inputs have the same number and type of "
  68. "args and kwargs as the ones you used when tracing."
  69. )
  70. return flat_args_with_path
  71. def _convert_guards_code_to_fn(
  72. guards_code: list[str],
  73. paths_of_placeholders: list[pytree.KeyPath],
  74. ):
  75. """
  76. Generates Python code given guards code and paths of placeholders.
  77. We assume that, based on source information,
  78. - the tracer generates the guards code
  79. - the input spec generates the paths of placeholders.
  80. Example:
  81. Suppose we are given the guards code "L['z']['k'].size()[1] == 3"
  82. and we are given that ['z']['k'] is the path of placeholder #2.
  83. Then we will generate:
  84. ```
  85. torch._assert(
  86. args[2].size()[0] == 3,
  87. "Guard failed: z['k'].size()[0] == 3",
  88. )
  89. ```
  90. FAQ: Why do we generate code based on (flattened) args instead of
  91. the original (unflattened) inputs? Because this would require
  92. inserting an additional pytree.unflatten call in our graph.
  93. FAQ: Why do we not emit RuntimeError on guard failure as we used to?
  94. Because it is inconvenient :/, get used to AssertionError instead.
  95. """
  96. import ast
  97. from torch.fx.experimental.symbolic_shapes import SYMPY_INTERP
  98. actual_guards_code = []
  99. shadow_guards_code = []
  100. for c in guards_code:
  101. a, s = c, c
  102. for idx, path in enumerate(paths_of_placeholders):
  103. # e.g., replace L['z']['k'] with args[2] for Python code (actual)
  104. a = a.replace("L" + pytree.keystr(path), f"args[{idx}]")
  105. # e.g., replace L['z']['k'] with z['k'] for error message (shadow)
  106. s = s.replace(
  107. "L" + pytree.keystr(path),
  108. path[0].key + pytree.keystr(path[1:]), # type: ignore[attr-defined]
  109. )
  110. actual_guards_code.append(a)
  111. shadow_guards_code.append(s.replace("\n", ""))
  112. # generate function code as str
  113. code_str = "\ndef _(*args):\n"
  114. for actual, shadow in zip(actual_guards_code, shadow_guards_code):
  115. # printing guards code may potentially introduce redundant parens;
  116. # we can normalize them out for readability by parsing/unparsing
  117. # NOTE: this is not necessary for correctness, just deemed desirable
  118. _shadow = ast.unparse(ast.parse(shadow, mode="eval"))
  119. # actual code and shadow error message
  120. code_str += f' torch._assert({actual}, "Guard failed: {_shadow}")\n'
  121. code_str += " return\n"
  122. # populate namespace with sympy globals, materialize function (named `_`)
  123. namespace = {**SYMPY_INTERP}
  124. exec(code_str, namespace)
  125. # create and return a module whose forward is the materialized function
  126. # NOTE: we want Dynamo to trace through this module, to repopulate guards:
  127. # otherwise we would lose them when retracing
  128. # NOTE: calling this module will be a side effect (no users): so it must
  129. # be marked impure to avoid being not cleaned up by DCE
  130. guards_fn = GuardsFn()
  131. guards_fn.forward = torch._dynamo.dont_skip_tracing(namespace["_"]) # type: ignore[call-overload, method-assign]
  132. guards_fn._is_impure = True # type: ignore[assignment]
  133. return guards_fn
  134. @torch._dynamo.disable
  135. def _check_input_constraints_for_module(self, args, kwargs):
  136. flat_args_with_path = _check_inputs_match(args, kwargs, self._in_spec)
  137. _check_input_constraints_for_graph(
  138. self.graph.find_nodes(op="placeholder"),
  139. flat_args_with_path,
  140. self.range_constraints,
  141. )
  142. def _check_input_constraints_pre_hook(self, args, kwargs):
  143. # preserve current behavior for clients that do not want any validation
  144. if not self.validate_inputs:
  145. return
  146. # when a guards function exists, assume that the graph does calls it!
  147. # so we do not need to check input constraints...but we still want
  148. # to check inputs match, otherwise we'd get obscure pytree errors
  149. if hasattr(self, "_guards_fn"):
  150. _check_inputs_match(args, kwargs, self._in_spec)
  151. return
  152. # NOTE: this call is Dynamo disabled, as it used to be
  153. _check_input_constraints_for_module(self, args, kwargs)
  154. def _unlift_inputs_as_getattr(
  155. gm: torch.fx.GraphModule,
  156. lifted_inputs: Sequence[Optional[str]],
  157. ) -> tuple[dict[str, torch.fx.Node], dict[str, torch.fx.Node]]:
  158. """
  159. Unlift inputs referring to params/buffers/constants as getattr nodes in the
  160. graph
  161. """
  162. unlifted_name_to_node = {}
  163. input_name_to_node = {}
  164. placeholder_nodes = [node for node in gm.graph.nodes if node.op == "placeholder"]
  165. assert len(lifted_inputs) == len(placeholder_nodes)
  166. for input_node, lifted_node in zip(placeholder_nodes, lifted_inputs):
  167. if lifted_node is None:
  168. input_name_to_node[input_node.name] = input_node
  169. else:
  170. with gm.graph.inserting_after(input_node):
  171. # It is fine to ignore this warning because
  172. # it is guaranteed that we will populate this
  173. # attr later.
  174. with warnings.catch_warnings():
  175. warnings.simplefilter("ignore")
  176. getattr_node = gm.graph.get_attr(lifted_node)
  177. input_node.replace_all_uses_with(getattr_node)
  178. metadata = input_node.meta
  179. gm.graph.erase_node(input_node)
  180. getattr_node.meta = metadata
  181. getattr_node.meta["from_node"] = [
  182. NodeSource(
  183. input_node,
  184. "ExportedProgram.module().unlift()",
  185. [NodeSourceAction.CREATE, NodeSourceAction.REPLACE],
  186. )
  187. ]
  188. unlifted_name_to_node[lifted_node] = getattr_node
  189. return unlifted_name_to_node, input_name_to_node
  190. def _insert_copy_for_mutations(
  191. gm: torch.fx.GraphModule,
  192. mutated_outputs: Sequence[Optional[str]],
  193. unlifted_name_to_node: dict[str, torch.fx.Node],
  194. input_name_to_node: dict[str, torch.fx.Node],
  195. ) -> None:
  196. """
  197. Find the all the buffers and inputs that were mutated and insert copy_
  198. operators to reflect mutations.
  199. """
  200. output_node = gm.graph.output_node()
  201. outputs = pytree.tree_flatten(output_node.args)[0]
  202. assert len(outputs) == len(mutated_outputs)
  203. user_output_nodes = []
  204. return_nodes_to_copy = {}
  205. for return_node, mutated_node_name in zip(outputs, mutated_outputs):
  206. if mutated_node_name is None:
  207. user_output_nodes.append(return_node)
  208. continue
  209. if mutated_node_name in unlifted_name_to_node:
  210. mutated_node = unlifted_name_to_node[mutated_node_name]
  211. elif mutated_node_name in input_name_to_node:
  212. mutated_node = input_name_to_node[mutated_node_name]
  213. else:
  214. raise RuntimeError(
  215. f"Could not find {mutated_node_name} in either buffer or input nodes"
  216. )
  217. with gm.graph.inserting_before(output_node):
  218. copy_node = gm.graph.call_function(
  219. torch.ops.aten.copy_.default, (mutated_node, return_node)
  220. )
  221. return_nodes_to_copy[return_node] = copy_node
  222. output_args = tuple(
  223. return_nodes_to_copy[node] if node in return_nodes_to_copy else node
  224. for node in user_output_nodes
  225. )
  226. with gm.graph.inserting_before(output_node):
  227. # Only return user outputs
  228. new_output = gm.graph.output(output_args)
  229. output_node.replace_all_uses_with(new_output)
  230. gm.graph.erase_node(output_node)
  231. new_output.name = output_node.name
  232. new_output.meta.update(output_node.meta)
  233. new_output.meta["from_node"] = [
  234. NodeSource(
  235. output_node,
  236. "ExportedProgram.module().unlift()",
  237. [NodeSourceAction.CREATE, NodeSourceAction.REPLACE],
  238. )
  239. ]
  240. def _get_codegen(
  241. in_spec: pytree.TreeSpec,
  242. out_spec: Optional[pytree.TreeSpec],
  243. forward_arg_names: Optional[list[str]] = None,
  244. ) -> _PyTreeCodeGen:
  245. """
  246. Create the codegen for the graph module based on the in/out specs
  247. """
  248. if forward_arg_names:
  249. names = forward_arg_names
  250. elif (
  251. in_spec.type == tuple
  252. and in_spec.num_children == 2
  253. and in_spec.children_specs[0].type == tuple
  254. and in_spec.children_specs[1].type == dict
  255. ):
  256. # if in_spec contains the args (tuple) and kwargs (dict)
  257. names = [f"arg_{i}" for i in range(in_spec.children_specs[0].num_children)]
  258. # add kwarg names
  259. names.extend(in_spec.children_specs[1].context)
  260. else:
  261. names = [f"arg_{i}" for i in range(in_spec.num_children)]
  262. return _PyTreeCodeGen(
  263. _PyTreeInfo(
  264. names,
  265. in_spec,
  266. out_spec,
  267. )
  268. )
  269. def _unlift(
  270. gm: torch.fx.GraphModule,
  271. lifted_inputs: Sequence[Optional[str]],
  272. mutated_outputs: Sequence[Optional[str]],
  273. in_spec: pytree.TreeSpec,
  274. out_spec: Optional[pytree.TreeSpec],
  275. forward_arg_names: Optional[list[str]] = None,
  276. ):
  277. """
  278. Args:
  279. lifted_inputs: A list matching the graph module's input nodes. For
  280. an input node that is referring to a lifted parameter/buffer, this
  281. list will contain the fqn the corresponding attribute. Otherwise, this
  282. list will contain None. This is used to unlift the lifted parameters as
  283. get_attr nodes.
  284. mutated_outputs: A list matching the graph module's output nodes. For
  285. an output node that is referring to a mutated buffer or user input, this
  286. list will contain the name of the corresponding buffer or user input
  287. that needs to be mutated. Otherwise, this list will contain None. This
  288. is used to re-insert an inplace copy_ operator to copy the mutated
  289. values back to the original node.
  290. """
  291. unlifted_name_to_node, input_name_to_node = _unlift_inputs_as_getattr(
  292. gm, lifted_inputs
  293. )
  294. _insert_copy_for_mutations(
  295. gm, mutated_outputs, unlifted_name_to_node, input_name_to_node
  296. )
  297. gm.graph._codegen = _get_codegen(in_spec, out_spec, forward_arg_names)
  298. gm.graph.lint()
  299. gm.recompile()
  300. return gm
  301. def _register_attrs_to_new_gm(
  302. new_gm: torch.fx.GraphModule,
  303. graph_signature: ExportGraphSignature,
  304. state_dict: dict[str, Any],
  305. constants: dict[str, Any],
  306. ) -> None:
  307. non_persistent_buffers = set(graph_signature.non_persistent_buffers)
  308. for name in graph_signature.buffers:
  309. if name in non_persistent_buffers:
  310. persistent = False
  311. value = constants[name]
  312. else:
  313. persistent = True
  314. value = state_dict[name]
  315. _assign_attr(
  316. value, new_gm, name, attr_kind=_AttrKind.BUFFER, persistent=persistent
  317. )
  318. for name in graph_signature.parameters:
  319. value = state_dict[name]
  320. _assign_attr(
  321. value,
  322. new_gm,
  323. name,
  324. attr_kind=_AttrKind.PARAMETER,
  325. )
  326. # Technically this doesn't account for the aliased multiple constants but
  327. # it is ok because we have a separate pass later in the stack that populates
  328. # the final gm.
  329. for name in chain(
  330. graph_signature.lifted_custom_objs, graph_signature.lifted_tensor_constants
  331. ):
  332. value = constants[name]
  333. _assign_attr(
  334. value,
  335. new_gm,
  336. name,
  337. attr_kind=_AttrKind.CONSTANT,
  338. )
  339. class _StatefulGraphModuleFactory(type):
  340. """
  341. Metaclass that ensures a private constructor for _StatefulGraphModule
  342. """
  343. def __call__(cls, *args, **kwargs):
  344. raise TypeError(
  345. f"{cls.__module__}.{cls.__qualname__} has no public constructor. "
  346. )
  347. def _create(cls, root, graph, range_constraints=None):
  348. return super().__call__(
  349. root,
  350. graph,
  351. range_constraints=range_constraints,
  352. )
  353. class _StatefulGraphModule(torch.fx.GraphModule, metaclass=_StatefulGraphModuleFactory):
  354. def __init__(self, root, graph, range_constraints=None):
  355. super().__init__(root, graph)
  356. # Need to fix up non-persistent buffers.
  357. self.range_constraints = range_constraints or []
  358. self.validate_inputs = True
  359. def _create_stateful_graph_module(
  360. plain_graph_module: torch.fx.GraphModule,
  361. range_constraints,
  362. ep: ExportedProgram,
  363. ) -> _StatefulGraphModule:
  364. stateful_gm = _StatefulGraphModule._create(
  365. plain_graph_module,
  366. plain_graph_module.graph,
  367. range_constraints=range_constraints,
  368. )
  369. module_types = _get_graph_inputs_of_type_nn_module(ep.example_inputs)
  370. stateful_gm.register_forward_pre_hook(
  371. lambda *args, **kwargs: _enter_enable_graph_inputs_of_type_nn_module(
  372. module_types
  373. )
  374. )
  375. stateful_gm.register_forward_pre_hook(
  376. _check_input_constraints_pre_hook, with_kwargs=True
  377. )
  378. stateful_gm.register_forward_hook(
  379. lambda *args, **kwargs: _exit_enable_graph_inputs_of_type_nn_module(
  380. module_types
  381. ),
  382. always_call=True,
  383. )
  384. # When we have a constant that has requires_grad=True, we need to detach it
  385. # when we unlift as the tensors that require gradients should be registered
  386. # via parameters. But this is problematic when we have aliasing two constants
  387. # because when we call detach, they will become different tensors. This dict
  388. # keeps track of this logic.
  389. original_tensor_to_detached_tensor = {}
  390. # Fix up lifted tensor constants.
  391. # fx.GraphModule() constructor silently turns a constant attribute of plain_graph_module
  392. # into a buffer in stateful_gm and creates an inconsistency with graph_signature.
  393. # We fix this by de-registering these buffers in lifted_tensor_constants
  394. # and call _assign_attr(attr_kind=CONSTANT) to register them as constants.
  395. for constant_fqn in ep.graph_signature.lifted_tensor_constants:
  396. # Sometimes, the constant can require gradient, this is probably a bug in user code,
  397. # e.g. `self.const = torch.randn(2, 2, requires_grad=True)`.
  398. # We call detach on the constant_val since they're tensor constants and we don't need to
  399. # compute their gradients anyway.
  400. # Users should properly register it as parameter if they want it to require gradient.
  401. buffer = stateful_gm.get_buffer(constant_fqn)
  402. if buffer.requires_grad:
  403. warnings.warn(
  404. f"A model attribute `{constant_fqn}` requires gradient. "
  405. f"but it's not properly registered as a parameter. "
  406. f"torch.export will detach it and treat it as a constant tensor "
  407. f"but please register it as parameter instead."
  408. )
  409. detached_buffer = buffer.detach()
  410. original_tensor_to_detached_tensor[buffer] = detached_buffer
  411. buffer = detached_buffer
  412. *prefix, field = constant_fqn.rsplit(".")
  413. submod = torch.fx.graph_module._get_attr_via_attr_list(stateful_gm, prefix)
  414. delattr(submod, field)
  415. _assign_attr(buffer, stateful_gm, constant_fqn, attr_kind=_AttrKind.CONSTANT)
  416. # Constants are not preserved well when we create a new GraphModule unlike param/buffers
  417. for const_name, value in ep.constants.items():
  418. if not torch.fx.graph_module._has_attr(stateful_gm, const_name):
  419. if isinstance(value, torch.Tensor):
  420. if value.requires_grad:
  421. warnings.warn(
  422. f"A model attribute `{const_name}` requires gradient "
  423. f"but it's not properly registered as a parameter. "
  424. f"torch.export will detach it and treat it as a constant tensor "
  425. f"but please register it as parameter instead."
  426. )
  427. if value in original_tensor_to_detached_tensor:
  428. value = original_tensor_to_detached_tensor[value]
  429. else:
  430. detached_value = value.detach()
  431. original_tensor_to_detached_tensor[value] = detached_value
  432. value = detached_value
  433. _assign_attr(
  434. value,
  435. stateful_gm,
  436. const_name,
  437. attr_kind=_AttrKind.CONSTANT,
  438. )
  439. # Fix up non-persistent buffers. torch.fx does not distinguish between
  440. # persistent and non-persistent buffers, so we must restore that distinction
  441. # here.
  442. for buffer in ep.graph_signature.non_persistent_buffers:
  443. _assign_attr(
  444. plain_graph_module.get_buffer(buffer),
  445. stateful_gm,
  446. buffer,
  447. attr_kind=_AttrKind.BUFFER,
  448. persistent=False,
  449. )
  450. return stateful_gm
  451. def _get_input_paths(example_inputs, signature):
  452. """
  453. Generate paths of placeholders, needed for generating the guards function.
  454. NOTE: Here we make use of the example inputs used for export as well as
  455. the signature of the unlifted graph module (not preserved by export).
  456. """
  457. args, kwargs = example_inputs
  458. ctx = signature.bind(*args, **kwargs).arguments
  459. flat_example_inputs_with_paths = pytree.tree_leaves_with_path(ctx)
  460. return [path for path, _ in flat_example_inputs_with_paths]
  461. def _get_input_guards_for_graph(
  462. placeholders: list[torch.fx.Node],
  463. range_constraints: dict[sympy.Symbol, ValueRanges],
  464. paths_for_placeholders: list[pytree.KeyPath],
  465. ):
  466. """
  467. Guards generated by the tracer include conditions observed in code, but
  468. but do not include some additional checks we typically do in export.
  469. For example, when dynamic shapes get specialized, are specified to be
  470. within a range, or are specified to be in some equational relation,
  471. corresponding input invalidation is done within a pre_hook, specifically,
  472. `_check_input_constraints_for_graph`.
  473. Here we generate guards corresponding to the checks that happen in
  474. `_check_input_constraints_for_graph`, and add them to the guards already
  475. generated by the tracer. In the future, it may be worthwhile to separate
  476. them so that we can allow clients to turn off one but not the other.
  477. (Looking at you, AOTI.)
  478. NOTE: We should eventually reconcile this logic with `build_guards` that
  479. is used by AOT Precompile.
  480. """
  481. deferred_expressions = []
  482. new_guards_code = []
  483. sources: dict[sympy.Expr, str] = {}
  484. def handle_symint(expr, src):
  485. if len(expr.free_symbols) == 1:
  486. # complex equations (e.g., involving derived dims) need to
  487. # handled later, since we may not have enough information
  488. # just as we are passing through the placeholders in order
  489. deferred_expressions.append((src, expr))
  490. if expr in sources:
  491. # expressions that appear in multiple sources should force
  492. # inputs corresponding to those sources to be equal
  493. # e.g., x.shape[0] == y.shape[1]
  494. orig_src = sources[expr]
  495. new_guards_code.append(f"{src} == {orig_src}")
  496. else:
  497. sources[expr] = src
  498. # process value ranges as elsewhere in export
  499. min_val, max_val = _convert_range_to_int(range_constraints[expr])
  500. if min_val > 2:
  501. new_guards_code.append(f"{src} >= {min_val}")
  502. if max_val < math.inf:
  503. new_guards_code.append(f"{src} <= {max_val}")
  504. for placeholder, path in zip(placeholders, paths_for_placeholders):
  505. src = "L" + pytree.keystr(path)
  506. meta = placeholder.meta["val"]
  507. # specializations
  508. if isinstance(meta, int):
  509. new_guards_code.append(f"{src} == {meta}")
  510. if isinstance(meta, float):
  511. if meta == math.inf:
  512. new_guards_code.append(f"{src} == math.inf")
  513. elif meta == -math.inf:
  514. new_guards_code.append(f"{src} == -math.inf")
  515. else:
  516. new_guards_code.append(f"{src} == {meta}")
  517. elif isinstance(meta, str):
  518. new_guards_code.append(f"{src} == '{meta}'")
  519. # range constraints and equalities
  520. elif isinstance(meta, torch.SymInt) and meta.node.expr in range_constraints:
  521. handle_symint(meta.node.expr, src)
  522. elif isinstance(meta, torch.Tensor):
  523. for i, dim in enumerate(meta.shape):
  524. src = "L" + pytree.keystr(path) + f".size()[{i}]"
  525. if isinstance(dim, int):
  526. # specializations
  527. new_guards_code.append(f"{src} == {dim}")
  528. elif (
  529. isinstance(dim, torch.SymInt) and dim.node.expr in range_constraints
  530. ):
  531. # range constraints and equalities
  532. handle_symint(dim.node.expr, src)
  533. unification_map: dict[sympy.Symbol, sympy.Expr] = {}
  534. py_printer = torch.utils._sympy.printers.PythonPrinter()
  535. # process complex equations (e.g., involving derived dims)
  536. for src, expr in deferred_expressions:
  537. # we know this is the only symbol in expr (see check above)
  538. symbol = next(iter(expr.free_symbols))
  539. if symbol in sources:
  540. # if s0 is already known to be directly sourced from inputs,
  541. # e.g., z.shape[2], we do not need to do anything further
  542. # (assume we have already processed constraints on s0 above)
  543. continue
  544. # otherwise s0 has some "hidden" source like 'dim'
  545. # example: src = y.shape[1], expr = s0 + 1
  546. if symbol in unification_map:
  547. # suppose that we already know that s0 = x.shape[0] * 2
  548. # so we can emit the guard: x.shape[0] * 2 + 1 = y.shape[1]
  549. substitution = expr.subs(unification_map)
  550. new_guards_code.append(
  551. py_printer.doprint(sympy.Eq(substitution, sympy.Symbol(src)))
  552. )
  553. else:
  554. # we do not yet know what s0 is, but given s0 + 1 = y.shape[1],
  555. # we can solve for s0...now knowing that s0 = y.shape[1] - 1
  556. solution = try_solve(sympy.Eq(expr, sympy.Symbol(src)), symbol)
  557. if solution is not None:
  558. definition = solution[1]
  559. unification_map[symbol] = definition
  560. return new_guards_code
  561. def _unlift_exported_program_lifted_states(
  562. ep: ExportedProgram, check_guards=True
  563. ) -> torch.fx.GraphModule:
  564. # force check_guards=False for executorch because
  565. # its pass infra has too many calls to .module()
  566. # and but does not like call modules in the graph
  567. # TODO: update executorch to check_guards=False
  568. frame = inspect.currentframe()
  569. while frame is not None:
  570. if "executorch" in frame.f_code.co_filename:
  571. check_guards = False
  572. break
  573. frame = frame.f_back
  574. # TODO T206340015
  575. if ep.verifiers[0].dialect != "TRAINING":
  576. ep = _remove_effect_tokens(ep)
  577. new_gm = torch.fx.GraphModule(ep.graph_module, copy.deepcopy(ep.graph))
  578. _register_attrs_to_new_gm(new_gm, ep.graph_signature, ep.state_dict, ep.constants)
  579. forward_arg_names = (
  580. sig.forward_arg_names if (sig := ep.module_call_graph[0].signature) else None
  581. )
  582. lifted_inputs: list[Optional[str]] = [
  583. (
  584. in_spec.target
  585. if in_spec.kind
  586. in (
  587. InputKind.BUFFER,
  588. InputKind.CONSTANT_TENSOR,
  589. InputKind.PARAMETER,
  590. InputKind.CUSTOM_OBJ,
  591. )
  592. else None
  593. )
  594. for in_spec in ep.graph_signature.input_specs
  595. ]
  596. mutated_outputs: list[Optional[str]] = [
  597. (
  598. out_spec.target
  599. if out_spec.kind
  600. in (
  601. OutputKind.BUFFER_MUTATION,
  602. OutputKind.USER_INPUT_MUTATION,
  603. OutputKind.PARAMETER_MUTATION,
  604. )
  605. else None
  606. )
  607. for out_spec in ep.graph_signature.output_specs
  608. ]
  609. source_node_dict = {
  610. node.name: node for node in ep.graph.nodes if node.op != "placeholder"
  611. }
  612. # placeholder node name might change after deepcopy
  613. placeholder_source_node_dict = {
  614. node.target: node for node in ep.graph.nodes if node.op == "placeholder"
  615. }
  616. for node in new_gm.graph.nodes:
  617. source_node = None
  618. if node.op == "placeholder":
  619. source_node = placeholder_source_node_dict.get(node.target)
  620. else:
  621. source_node = source_node_dict.get(node.name)
  622. node.meta["from_node"] = [
  623. NodeSource(
  624. source_node,
  625. "ExportedProgram.module()",
  626. NodeSourceAction.CREATE,
  627. )
  628. ]
  629. assert ep.call_spec.in_spec is not None
  630. new_gm = _unlift(
  631. new_gm,
  632. lifted_inputs,
  633. mutated_outputs,
  634. ep.call_spec.in_spec,
  635. ep.call_spec.out_spec,
  636. forward_arg_names=forward_arg_names,
  637. )
  638. unlift_gm = _create_stateful_graph_module(new_gm, ep.range_constraints, ep)
  639. unlift_gm.meta.update(ep.graph_module.meta)
  640. # create a _guards_fn submodule and insert a call to it after placeholders
  641. graph = unlift_gm.graph
  642. placeholders = graph.find_nodes(op="placeholder")
  643. if check_guards and placeholders and ep.example_inputs:
  644. input_paths = _get_input_paths(
  645. ep.example_inputs,
  646. inspect.signature(unlift_gm.forward),
  647. )
  648. guards_code = _get_input_guards_for_graph(
  649. placeholders, ep.range_constraints, input_paths
  650. )
  651. guards_code.extend(ep._guards_code)
  652. unlift_gm._guards_fn = _convert_guards_code_to_fn(guards_code, input_paths)
  653. root_nn_module_stack = torch.fx._utils.first_call_function_nn_module_stack(
  654. graph
  655. )
  656. with graph.inserting_after(placeholders[-1]):
  657. node = graph.call_module("_guards_fn", tuple(placeholders))
  658. node.meta["nn_module_stack"] = root_nn_module_stack
  659. unlift_gm.recompile()
  660. return unlift_gm
  661. class GuardsFn(torch.nn.Module):
  662. """
  663. Module class for guard functions.
  664. """
  665. def forward(self, *args):
  666. pass