# runtime_assert.py
# mypy: allow-untyped-defs
import functools
import logging
import operator
import sys
from typing import Any, Optional, TYPE_CHECKING


# Import sympy and ShapeEnv during TYPE_CHECKING since importing sympy is slow
if TYPE_CHECKING:
    import sympy

    from torch.fx.experimental.symbolic_shapes import ShapeEnv
else:
    # In the visible code ShapeEnv only appears in annotations, so at runtime
    # an `Any` placeholder stands in for the real class.
    ShapeEnv = Any

import torch
import torch.utils._pytree as pytree
from torch import fx
from torch._subclasses.meta_utils import is_sparse_any
from torch.fx._compatibility import compatibility
from torch.fx._utils import lazy_format_graph_code
from torch.fx.experimental.proxy_tensor import py_sym_types
from torch.fx.experimental.sym_node import SymNode
from torch.fx.graph_module import GraphModule


# The pass entry point is the only public name of this module.
__all__ = ["insert_deferred_runtime_asserts"]

log = logging.getLogger(__name__)
# Artifact logger ("graph_code_verbose") used to dump the graph before the pass runs.
graph_code_log = torch._logging.getArtifactLogger(__name__, "graph_code_verbose")
  25. def _get_example_value(node: fx.Node) -> Optional[str]:
  26. """
  27. Get the example value key for a node, since dynamo uses "example_value"
  28. while non-strict export uses "val.
  29. """
  30. if "example_value" in node.meta:
  31. return node.meta["example_value"]
  32. elif "val" in node.meta:
  33. return node.meta["val"]
  34. else:
  35. return None
  36. def _get_sym_val(node: fx.Node) -> Optional["sympy.Expr"]:
  37. val = _get_example_value(node)
  38. if isinstance(val, py_sym_types):
  39. return val.node.expr
  40. return None
  41. @compatibility(is_backward_compatible=True)
  42. def insert_deferred_runtime_asserts(
  43. gm: GraphModule,
  44. shape_env: ShapeEnv,
  45. name: str,
  46. export: bool = False,
  47. ) -> None:
  48. """
  49. During tracing, we may have discovered that some data-dependent values
  50. had runtime assert on them; e.g., torch.empty(x.item()) induces a runtime
  51. that x.item() >= 0. This asserts can happen unpredictably during fake
  52. tensor propagation, so we cannot conveniently insert them into the FX graph
  53. when they occur. Instead, we accumulate them in the ShapeEnv, and in this
  54. pass insert them into the graph as proper tests.
  55. This pass also deduplicates size-related computation, CSE-ing ops that produce
  56. symbolic values and/or are involved in runtime asserts. Additionally, shape calls
  57. (size/stride/storage_offset) are turned into compute on input sizes if possible,
  58. allowing intermediate tensors to be freed earlier. For example, here dynamo will
  59. DCE the cat and repeat calls:
  60. z = torch.cat([x, x], dim=0) # 2*s0
  61. w = z.repeat(y.shape[0]) # 2*s0*s1
  62. _w = w.shape[0]
  63. # something with _w, but not w ...
  64. # turns into ->
  65. _w0 = 2 * s0
  66. _w = _w0 * s1
  67. # where s0, s1 are either SymInt graph inputs, or the result of added size calls
  68. Redundant torch._check or torch.ops.aten._assert_scalar.default calls that assert
  69. the same expression, and redundant constrain_range calls are also deduplicated.
  70. Additionally, because single-symbol bound checks (e.g. u0 >= 0, u0 <= 5) accumulate
  71. information in the ShapeEnv, the ShapeEnv contains min/max bounds for each symbol,
  72. and we delete all previous calls, adding bound checks at the end of this pass.
  73. """
  74. # Import sympy locally
  75. import sympy
  76. from torch._export.passes._node_metadata_hook import _set_node_metadata_hook
  77. from torch.fx.experimental.symbolic_shapes import (
  78. _get_placeholder_expr,
  79. _has_uninterpretable_sympy_function,
  80. CallMethodKey,
  81. cast_symbool_to_symint_guardless,
  82. ConvertIntKey,
  83. DivideByKey,
  84. free_symbols,
  85. InnerTensorKey,
  86. resolve_unbacked_bindings,
  87. )
  88. from torch.utils._sympy.numbers import int_oo
  89. from torch.utils._sympy.reference import (
  90. OptimizedPythonReferenceAnalysis,
  91. PythonReferenceAnalysis,
  92. )
  93. from torch.utils._sympy.value_ranges import ValueRanges
  94. # TODO: Request simplification on runtime asserts before emitting them
  95. ras_by_symbol = shape_env.deferred_runtime_asserts.copy()
  96. graph = gm.graph
  97. tracer = fx.proxy.GraphAppendingTracer(graph)
  98. graph_code_log.debug(
  99. "%s",
  100. lazy_format_graph_code(
  101. f"pre insert_deferred_runtime_asserts {name}", gm, colored=True
  102. ),
  103. )
  104. # We are going to mutate the dict
  105. expr_to_proxy: dict[sympy.Expr, fx.Proxy] = {}
  106. placeholders = set()
  107. first_non_placeholder = None
  108. for node in graph.nodes:
  109. if node.op != "placeholder":
  110. first_non_placeholder = node
  111. break
  112. else:
  113. placeholders.add(node)
  114. def _is_intermediate_tensor_sym_call(node: fx.Node) -> bool:
  115. """
  116. If a size/stride/storage offset call on an intermediate tensor,
  117. we can try to compute the value from input shapes instead.
  118. """
  119. return (
  120. (val := _get_sym_val(node)) is not None
  121. and not isinstance(val, sympy.Number)
  122. # this holds back from reifying anything in torch.utils._sympy.functions.py that's unsupported
  123. and not _has_uninterpretable_sympy_function(val)
  124. and any(
  125. isinstance(arg, fx.Node)
  126. and isinstance(_get_example_value(arg), (torch.Tensor, torch.Size))
  127. and arg.op != "placeholder"
  128. for arg in node.args
  129. )
  130. )
  131. # Figure out what key to use, val or example_value
  132. val_key = "val"
  133. for node in graph.nodes:
  134. if "example_value" in node.meta:
  135. val_key = "example_value"
  136. break
  137. elif "val" in node.meta:
  138. break
  139. def _node_metadata_hook(
  140. node: torch.fx.Node,
  141. stack_trace: Optional[str] = None,
  142. nn_module_stack: Optional[dict[str, Any]] = None,
  143. ) -> None:
  144. fake_args = pytree.tree_map(
  145. lambda arg: (
  146. _get_example_value(arg) if isinstance(arg, torch.fx.Node) else arg
  147. ),
  148. node.args,
  149. )
  150. try:
  151. target = node.target
  152. if node.op == "call_method":
  153. assert isinstance(node.target, str)
  154. target = getattr(fake_args[0], node.target)
  155. fake_args = fake_args[1:]
  156. node.meta[val_key] = target(*fake_args) # type: ignore[operator]
  157. except NotImplementedError:
  158. # This can happen when attempting to reify a symbol with an unsupported call_function node,
  159. # e.g. with NestedTensors + sym_size.int via match_symbol().
  160. # This seems to be fine, as the node gets CSE'd and deleted later in favor of a SymInt graph input.
  161. pass
  162. if stack_trace is not None:
  163. node.meta["stack_trace"] = stack_trace
  164. if nn_module_stack is not None:
  165. node.meta["nn_module_stack"] = nn_module_stack
  166. # Track asserts/checks we've added
  167. added_asserts: set[sympy.Expr] = set()
  168. constrained_unbacked_symbols: set[sympy.Symbol] = set()
  169. Analysis = PythonReferenceAnalysis if export else OptimizedPythonReferenceAnalysis
  170. def _sympy_interp(expr_to_proxy, expr):
  171. # sympy_interp() with hash consing
  172. from sympy import Integer, Number, Symbol
  173. from sympy.logic.boolalg import BooleanAtom
  174. from torch.utils._sympy.interp import _run_sympy_handler, sympy_interp
  175. # hash cons
  176. if expr in expr_to_proxy:
  177. return expr_to_proxy[expr]
  178. # base cases, don't cache
  179. if isinstance(expr, (Integer, Number, Symbol, BooleanAtom)):
  180. return sympy_interp(Analysis, expr_to_proxy, expr)
  181. # hash cons on arguments, run expr handler
  182. expr_to_proxy[expr] = _run_sympy_handler(
  183. Analysis,
  184. [_sympy_interp(expr_to_proxy, arg) for arg in expr.args],
  185. expr,
  186. )
  187. return expr_to_proxy[expr]
  188. def _is_bound_expr_for_symbol(expr: "sympy.Expr") -> bool:
  189. # This is probably unnecessary, but since torch._check() calls for single-symbol bounds
  190. # like u0 >= 0, 10 >= u0 accumulate range info in the ShapeEnv, we designate these calls as redundant
  191. # and instead add 2 runtime asserts at the end of this pass, if the min/max bounds are non-trivial.
  192. if len(expr.args) != 2 or expr.func not in (sympy.LessThan, sympy.GreaterThan):
  193. return False
  194. lhs, rhs = expr.args
  195. return (isinstance(lhs, sympy.Symbol) and isinstance(rhs, sympy.Number)) or (
  196. isinstance(rhs, sympy.Symbol) and isinstance(lhs, sympy.Number)
  197. )
  198. def add_runtime_asserts(ras):
  199. for ra in ras:
  200. if (
  201. # redundant
  202. ra.expr in added_asserts
  203. # if we've already added a constrain_range call for this symbol,
  204. # then single-symbol bound asserts like u0 >= 0, u0 <= 5 are redundant.
  205. or (
  206. len(ra.expr.free_symbols) == 1
  207. and next(iter(ra.expr.free_symbols)) in constrained_unbacked_symbols
  208. and _is_bound_expr_for_symbol(ra.expr)
  209. )
  210. # don't try to reify sympy functions we can't turn into FX nodes
  211. or _has_uninterpretable_sympy_function(ra.expr)
  212. ):
  213. continue
  214. log.debug("inserting runtime assert %s", ra.expr)
  215. # Need to process ALL free symbols, not just unbacked ones
  216. fvs = free_symbols(ra.expr)
  217. missing = fvs - expr_to_proxy.keys()
  218. if missing:
  219. i1 = min(missing, key=str)
  220. # TODO: Remove relaxing assert on unbacked_symint https://github.com/pytorch/pytorch/issues/119689
  221. # assert shape_env.is_unbacked_symint(i1), i1
  222. ras_by_symbol.setdefault(i1, []).append(ra)
  223. else:
  224. # Convert the sympy expression into a sequence of FX
  225. # nodes
  226. with _set_node_metadata_hook(gm, _node_metadata_hook):
  227. res = _sympy_interp(expr_to_proxy, ra.expr).node
  228. graph.call_function(
  229. torch.ops.aten._assert_scalar.default,
  230. # TODO: use ra.msg here, but it's pretty
  231. # useless right now
  232. (
  233. res,
  234. f"Runtime assertion failed for expression {ra.expr} on node '{res}'",
  235. ),
  236. )
  237. added_asserts.add(ra.expr)
  238. nodes = list(graph.nodes)
  239. for i, node in enumerate(nodes[:-1]):
  240. # Placeholders can match symbols, but when we destructure them
  241. # with size we have to make sure we insert the nodes after all
  242. # the placeholders
  243. with graph.inserting_before(
  244. nodes[i + 1] if node not in placeholders else first_non_placeholder
  245. ):
  246. # Unfortunately, this logic still must remain because manual
  247. # make_fx calls may not explicitly bind all symbolic ints as
  248. # arguments to the function, so we must infer it from the other
  249. # arguments
  250. if (
  251. node in placeholders
  252. and (example_value := _get_example_value(node)) is not None
  253. ):
  254. def match_symbol(symint, cb):
  255. if (
  256. isinstance(symint, torch.SymInt)
  257. and isinstance(symint.node, SymNode)
  258. and isinstance(
  259. s := _get_placeholder_expr(symint.node), sympy.Symbol
  260. )
  261. and s not in expr_to_proxy
  262. ):
  263. with _set_node_metadata_hook(gm, _node_metadata_hook):
  264. expr_to_proxy[s] = fx.Proxy(cb(), tracer=tracer)
  265. log.debug("expr_to_proxy[%s] = %s", s, expr_to_proxy[s])
  266. match_symbol(example_value, lambda: node)
  267. if isinstance(t := example_value, torch.Tensor):
  268. for i, s in enumerate(t.size()):
  269. match_symbol(
  270. s,
  271. lambda: graph.call_function(
  272. torch.ops.aten.sym_size.int, (node, i)
  273. ),
  274. )
  275. if not is_sparse_any(t):
  276. for i, s in enumerate(t.stride()):
  277. match_symbol(
  278. s,
  279. lambda: graph.call_function(
  280. torch.ops.aten.sym_stride.int, (node, i)
  281. ),
  282. )
  283. match_symbol(
  284. t.storage_offset(),
  285. lambda: graph.call_function(
  286. torch.ops.aten.sym_storage_offset.default, (node,)
  287. ),
  288. )
  289. # Handle asserts that aren't associated with any symbol. This
  290. # doesn't really have to be in the loop as it will only run once,
  291. # it just needs to happen right after the placeholders.
  292. # insert this after placeholders & added sym nodes, and before non-placeholders.
  293. if node == first_non_placeholder:
  294. add_runtime_asserts(ras_by_symbol.pop(None, [])) # type: ignore[call-overload]
  295. # deduplicate asserts already present in graph, and remove trivial asserts
  296. if node.target in (
  297. torch._check,
  298. torch.ops.aten._assert_scalar.default,
  299. ):
  300. cond = node.args[0] if node.args else node.kwargs.get("cond")
  301. if (
  302. cond == True # noqa: E712
  303. or (assert_expr := _get_sym_val(cond)) in expr_to_proxy
  304. and assert_expr in added_asserts
  305. ):
  306. arg = cond
  307. gm.graph.erase_node(node)
  308. if isinstance(arg, fx.Node) and not arg.users:
  309. gm.graph.erase_node(arg)
  310. else:
  311. added_asserts.add(assert_expr) # type: ignore[arg-type]
  312. # hash cons, replace function calls that return torch.SymInts with direct references to
  313. # FX nodes built up to reify the sympy expression.
  314. if (
  315. node.op != "placeholder"
  316. and (sym_expr := _get_sym_val(node)) is not None
  317. ):
  318. # this guards against deleting calls like item() that produce new untracked symbols
  319. def has_new_untracked_symbols():
  320. for symbol in sym_expr.free_symbols:
  321. if symbol not in expr_to_proxy:
  322. return True
  323. return False
  324. # this guards against deleting calls that produce unbacked bindings we haven't yet seen.
  325. # in this case looking at sym_expr.free_symbols might not be enough, if the example value has a hint
  326. # (is backed), but produces an unbacked symbol. In this case keep the node alive.
  327. resolved_unbacked_bindings = resolve_unbacked_bindings(
  328. shape_env, node.meta.get("unbacked_bindings", {})
  329. )
  330. assert resolved_unbacked_bindings is not None
  331. def has_new_unbacked_bindings():
  332. for key in resolved_unbacked_bindings.keys():
  333. if key not in expr_to_proxy:
  334. return True
  335. return False
  336. # maybe re-reify expression, replace current node
  337. if (
  338. sym_expr in expr_to_proxy
  339. or ( # example value is redundant
  340. _is_intermediate_tensor_sym_call(node)
  341. # shape call on intermediate tensor, turn into computation on input shapes
  342. and not has_new_untracked_symbols()
  343. )
  344. ) and not has_new_unbacked_bindings():
  345. if _is_intermediate_tensor_sym_call(
  346. node
  347. ): # reify from input shapes
  348. with _set_node_metadata_hook(
  349. gm,
  350. functools.partial(
  351. _node_metadata_hook,
  352. stack_trace=node.meta.get("stack_trace"),
  353. nn_module_stack=node.meta.get("nn_module_stack"),
  354. ),
  355. ):
  356. expr_to_proxy[sym_expr] = _sympy_interp(
  357. expr_to_proxy, sym_expr
  358. ) # type: ignore[arg-type]
  359. # won't try DCE-ing tensor compute here
  360. hash_node = expr_to_proxy[sym_expr].node # type: ignore[arg-type]
  361. node.replace_all_uses_with(hash_node)
  362. gm.graph.erase_node(node)
  363. log.debug(
  364. "CSE node %s -> %s for expr %s", node, hash_node, sym_expr
  365. )
  366. # store node in hash cons, don't delete/replace
  367. elif sym_expr not in expr_to_proxy and not isinstance(
  368. sym_expr, (sympy.Number, sympy.logic.boolalg.BooleanAtom)
  369. ): # don't hash cons primitives
  370. expr_to_proxy[sym_expr] = fx.Proxy(node, tracer=tracer) # type: ignore[arg-type]
  371. # We add sym_constrain_range calls for symbols later in any case if they're size-like or range-constrained,
  372. # so calls before that are redundant.
  373. if node.target in (
  374. torch.ops.aten.sym_constrain_range.default,
  375. torch.ops.aten.sym_constrain_range_for_size.default,
  376. ):
  377. gm.graph.erase_node(node)
  378. defs = []
  379. # AOTAutograd will create new symbols as the unbacked_bindings keys, which PropagateSymInts will set as
  380. # equivalent, but the refinement calls we perform in this pass may struggle with associating the two.
  381. # More concretely, when re-exporting/tracing, constraining only the new symbol may not communicate enough
  382. # information about the old symbol when we re-export, raising errors on data-dependent guards.
  383. # Call resolve_unbacked_bindings() to get the original symbol if present, otherwise we take it as is.
  384. if unbacked_bindings := resolve_unbacked_bindings(
  385. shape_env, node.meta.get("unbacked_bindings")
  386. ):
  387. for s, keypath in unbacked_bindings.items():
  388. defs.append(s)
  389. # TODO: some CSE when generating these nodes can probably
  390. # help reduce graph size and improve compile time
  391. def go(node, keypath):
  392. if keypath == ():
  393. return node
  394. if (
  395. len(keypath) >= 2
  396. and isinstance(keypath[0], CallMethodKey)
  397. and isinstance(keypath[1], pytree.SequenceKey)
  398. ):
  399. if keypath[0].name == "size":
  400. return go(
  401. graph.call_function(
  402. torch.ops.aten.sym_size.int,
  403. (node, keypath[1].idx),
  404. ),
  405. keypath[2:],
  406. )
  407. if keypath[0].name == "stride":
  408. return go(
  409. graph.call_function(
  410. torch.ops.aten.sym_stride.int,
  411. (node, keypath[1].idx),
  412. ),
  413. keypath[2:],
  414. )
  415. return go(
  416. graph.call_method(
  417. keypath[0].name, (node, keypath[1].idx)
  418. ),
  419. keypath[2:],
  420. )
  421. elif isinstance(keypath[0], CallMethodKey):
  422. if keypath[0].name == "storage_offset":
  423. return go(
  424. graph.call_function(
  425. torch.ops.aten.sym_storage_offset.default,
  426. (node,),
  427. ),
  428. keypath[1:],
  429. )
  430. return go(
  431. graph.call_method(keypath[0].name, (node,)), keypath[1:]
  432. )
  433. elif isinstance(keypath[0], pytree.SequenceKey):
  434. return go(
  435. graph.call_function(
  436. operator.getitem, (node, keypath[0].idx)
  437. ),
  438. keypath[1:],
  439. )
  440. elif isinstance(keypath[0], ConvertIntKey):
  441. return go(
  442. graph.call_function(
  443. cast_symbool_to_symint_guardless, (node,)
  444. ),
  445. keypath[1:],
  446. )
  447. elif isinstance(keypath[0], DivideByKey):
  448. # TODO: need to assert divisibility
  449. return go(
  450. graph.call_function(
  451. operator.floordiv, (node, keypath[0].divisor)
  452. ),
  453. keypath[1:],
  454. )
  455. elif isinstance(keypath[0], InnerTensorKey):
  456. return go(
  457. graph.call_function(
  458. getattr, (node, keypath[0].inner_name)
  459. ),
  460. keypath[1:],
  461. )
  462. else:
  463. raise AssertionError(f"unrecognized keypath {keypath}")
  464. if s not in expr_to_proxy:
  465. with _set_node_metadata_hook(gm, _node_metadata_hook):
  466. expr_to_proxy[s] = fx.Proxy(
  467. go(node, keypath), tracer=tracer
  468. )
  469. log.debug("expr_to_proxy[%s] = %s", s, expr_to_proxy[s])
  470. for i0 in defs:
  471. ras = ras_by_symbol.pop(i0, [])
  472. # Before we perform any asserts, first apply range
  473. # refinement. This is important, because if we are going
  474. # to retrace the graph (and we typically are if we send
  475. # the graph to AOTAutograd), we need to make sure we apply
  476. # range refinement (ala _check_is_size) first, BEFORE we
  477. # run any of the asserts. Otherwise, we may decide to
  478. # perform substitutions based on the asserts which we then
  479. # can't back out, because value ranges can only be applied
  480. # to asserts.)
  481. #
  482. # A perhaps better long term plan is to avoid this order
  483. # dependence by making it possible to refine ranges on
  484. # arbitrary expressions, not just symbols. But it is not
  485. # so easy to make use of this information, see
  486. # https://twitter.com/ezyang/status/1745801370299482492
  487. # We actually made an attempt at this in
  488. # https://github.com/pytorch/pytorch/pull/119043
  489. # which didn't work.
  490. #
  491. # Another ideas for how to do this:
  492. # - Have bound_sympy be the source of truth of the ranges of any expression
  493. # - Cache intermediate results for every subexpression of bound_sympy
  494. # - This cache should be possible to edit to refine ranges
  495. #
  496. # One issue with this proposal is that if
  497. # we have a bound on 2x, we are not going to be able to
  498. # apply it for 4x. Similarly, we may have bounds for an
  499. # equivalent expression that we are not applying because
  500. # it's not a perfect match (e.g. x < y vs y > x)".
  501. #
  502. # The first issue we already have it and it's impossible
  503. # to solve in general, so any implementation on a best
  504. # effort basis should do.
  505. #
  506. # The second issue is a preexisting one. It can be mitigated
  507. # with a normalization algorithm. In general, it may also
  508. # be on a best effort basis, but since our grammar is not
  509. # terribly difficult, chances are we could even fully
  510. # normalize SymPy expressions... who knows.
  511. if i0 in constrained_unbacked_symbols:
  512. continue # constrain symbol just once
  513. if i0 in shape_env.size_like:
  514. if export:
  515. graph.call_function(
  516. torch.ops.aten.sym_constrain_range_for_size.default,
  517. (expr_to_proxy[i0].node,),
  518. )
  519. else:
  520. graph.call_function(
  521. torch._check_is_size, (expr_to_proxy[i0].node,)
  522. )
  523. vr = shape_env.var_to_range[i0]
  524. if vr.is_int and vr.upper == sys.maxsize - 1:
  525. # treat upper bound == sys.maxsize - 1 for int symbols as +oo
  526. # to avoid redundant runtime assert
  527. vr = ValueRanges(vr.lower, int_oo)
  528. if not shape_env._default_unspecified_value_range().issubset(vr):
  529. # The runtime range is constrained, so add a runtime
  530. # assert and also explicitly refine the range
  531. # (refinement should not be necessary once runtime
  532. # asserts cause refinement, but that's NYI)
  533. def convert(s):
  534. if s in (int_oo, -int_oo):
  535. return None
  536. try:
  537. return int(s)
  538. except TypeError:
  539. return None
  540. if (
  541. expr_to_proxy[i0].node.target
  542. != cast_symbool_to_symint_guardless
  543. ):
  544. # TODO(pianpwk): calling sym_constrain_range_for_size or adding bound asserts
  545. # raises AOTAutograd errors on cast_symbool_to_symint_guardless
  546. with _set_node_metadata_hook(
  547. gm,
  548. functools.partial(
  549. _node_metadata_hook,
  550. stack_trace=node.meta.get("stack_trace"),
  551. nn_module_stack=node.meta.get("nn_module_stack"),
  552. ),
  553. ):
  554. if (min_val := convert(vr.lower)) is not None:
  555. ge = _sympy_interp(expr_to_proxy, i0 >= min_val).node
  556. graph.call_function(
  557. torch.ops.aten._assert_scalar.default,
  558. (
  559. ge,
  560. f"Runtime assertion failed for expression {i0 >= min_val} on node '{ge}'",
  561. ),
  562. )
  563. added_asserts.add(i0 >= min_val)
  564. if (max_val := convert(vr.upper)) is not None:
  565. le = _sympy_interp(expr_to_proxy, i0 <= max_val).node
  566. graph.call_function(
  567. torch.ops.aten._assert_scalar.default,
  568. (
  569. le,
  570. f"Runtime assertion failed for expression {i0 <= max_val} on node '{le}'",
  571. ),
  572. )
  573. added_asserts.add(i0 <= max_val)
  574. constrained_unbacked_symbols.add(i0)
  575. add_runtime_asserts(ras)
  576. # delete unused reified symbols
  577. for expr, proxy in expr_to_proxy.items():
  578. if (
  579. isinstance(expr, sympy.Symbol)
  580. and proxy.node.op != "placeholder" # keep placeholders intact
  581. and not proxy.node.users
  582. ):
  583. log.debug("deleting unused reified symbol for %s", expr)
  584. gm.graph.erase_node(proxy.node)