sym_node.py 59 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156115711581159116011611162116311641165116611671168116911701171117211731174117511761177117811791180118111821183118411851186118711881189119011911192119311941195119611971198119912001201120212031204120512061207120812091210121112121213121412151216121712181219122012211222122312241225122612271228122912301231123212331234123512361237123812391240124112421243124412451246124712481249125012511252125312541255125612571258125912601261126212631264126512661267126812691270127112721273127412751276127
71278127912801281128212831284128512861287128812891290129112921293129412951296129712981299130013011302130313041305130613071308130913101311131213131314131513161317131813191320132113221323132413251326132713281329133013311332133313341335133613371338133913401341134213431344134513461347134813491350135113521353135413551356135713581359136013611362136313641365136613671368136913701371137213731374137513761377137813791380138113821383138413851386138713881389139013911392139313941395139613971398139914001401140214031404140514061407140814091410141114121413141414151416141714181419142014211422142314241425142614271428142914301431143214331434143514361437143814391440144114421443144414451446144714481449145014511452145314541455145614571458145914601461146214631464146514661467146814691470147114721473147414751476147714781479148014811482148314841485148614871488148914901491149214931494149514961497149814991500150115021503150415051506150715081509151015111512151315141515151615171518151915201521152215231524152515261527152815291530153115321533153415351536153715381539154015411542154315441545154615471548154915501551155215531554155515561557155815591560156115621563156415651566156715681569157015711572157315741575157615771578157915801581158215831584158515861587158815891590159115921593159415951596159715981599160016011602160316041605160616071608160916101611161216131614161516161617161816191620162116221623162416251626162716281629163016311632163316341635163616371638163916401641164216431644164516461647164816491650165116521653165416551656165716581659166016611662166316641665166616671668166916701671167216731674167516761677167816791680168116821683168416851686168716881689169016911692169316941695169616971698169917001701170217031704170517061707170817091710171117121713171417151716171717181719172017211722172317241725172617271728172917301731173217331734173517361737173817391740174117421743174417451746174717481749175017511752175317541755175617571758175917601761176217631764176517661767176817691770177117721773177417751776177
71778177917801781178217831784178517861787178817891790179117921793179417951796179717981799180018011802180318041805180618071808180918101811181218131814181518161817181818191820182118221823182418251826182718281829183018311832183318341835183618371838183918401841184218431844184518461847
  1. # mypy: allow-untyped-defs
  2. from __future__ import annotations
  3. """
  4. This file does three things:
  5. - Contains the definition of SymNode
  6. - Installs all the magic methods into SymBool, SymInt, SymFloat at import time
  7. - Does not depend on sympy at import time
  8. As this file is imported from within torch/__init__.py we do not want it to depend on SymPy
  9. to avoid having to load SymPy at import time, as doing so is *very* slow.
  10. """
  11. import builtins
  12. import functools
  13. import inspect
  14. import itertools
  15. import logging
  16. import math
  17. import operator
  18. import sys
  19. from functools import lru_cache, update_wrapper
  20. from typing import Optional, TYPE_CHECKING, Union
  21. import torch
  22. import torch._logging.structured as structured
  23. # NB: The sym_* functions are used via getattr() and must be imported here.
  24. from torch import ( # noqa: F401
  25. sym_float,
  26. sym_ite,
  27. sym_max,
  28. sym_min,
  29. sym_not,
  30. SymBool,
  31. SymFloat,
  32. SymInt,
  33. )
  34. from torch._logging import dtrace_structured
  35. if TYPE_CHECKING:
  36. from torch.fx.experimental.symbolic_shapes import ShapeEnv
  37. log = logging.getLogger(__name__)
  38. sym_node_log = torch._logging.getArtifactLogger(__name__, "sym_node")
  39. __all__ = ["SymNode", "method_to_operator", "magic_methods"]
  40. from torch.types import py_sym_types as SymTypes
  41. def _to_symtype(t):
  42. if t is bool:
  43. return SymBool
  44. if t is int:
  45. return SymInt
  46. if t is float:
  47. return SymFloat
  48. return t
  49. # TODO: An incomplete list
  50. # 1. Set variables to be equal when we do equality
  51. # 2. Specialize on 0/1 when we do subtraction
  52. class SymNode:
  53. """
  54. This is a type erased SymInt/SymFloat which we use to do actual operations.
  55. End users don't touch this. Magic methods are NOT defined on this object.
  56. """
  57. # Note [optimized_summation]: indicates that SymNode is an Add expression of the form
  58. # a + b + c + d... etc where all terms are unique symbols. This allows us to do some optimizations
  59. # for common patterns see _optimized_add.
  60. # The unfortunate reason we have this here is because sympy sets __slots__ = () for add expression,
  61. # so we cannot add the attribute directly to the sympy expression. Furthermore, we cannot use it as
  62. # a weak dictionary key either! So instead, we attach the attribute here to the SymNode.
  63. _optimized_summation: bool = False
  64. def __init__(
  65. self,
  66. expr,
  67. shape_env,
  68. pytype,
  69. hint: Optional[Union[int, float, bool]],
  70. constant=None,
  71. fx_node=None,
  72. optimized_summation=False,
  73. ):
  74. self._expr = expr
  75. self.shape_env = shape_env
  76. self.pytype = pytype
  77. self._optimized_summation = optimized_summation
  78. # What's the difference between hint and constant?
  79. #
  80. # - A constant is known to be invariant across invocations of the model;
  81. # it will always be this value. We only really know this when we
  82. # encounter an honest-to-goodness literal (when wrapping it into
  83. # a SymNode, we set constant.) Most of the time, constant is None
  84. #
  85. # - A hint is a *particular* value from the particular run we are
  86. # tracing, but it may vary the next time around. It's useful to
  87. # keep this around, as if we need a concrete value from a SymNode,
  88. # we will return the hint and guard on the expression that produced
  89. # it giving the same hint next time around. The hint is not
  90. # guaranteed to be set either: if you have an unbacked SymNode,
  91. # there won't be any hint; it was the result of some tensor-dependent
  92. # computation, but we don't know what it actually is because we
  93. # haven't actually run the tensor computation.
  94. #
  95. # If _hint is None, we will query maybe_evaluate_static(compute_hint=True)
  96. # in hopes that we've learned enough about the unbacked symints to
  97. # discharge the hint; otherwise, you're likely to just error out.
  98. #
  99. # (A previous version of this system had some optimizations to only
  100. # recompute when it was possible we had learned enough about the
  101. # unbacked symint that a hint was now possible, but as we added more
  102. # potential refinements to unbacked symints this got harder to keep
  103. # in sync, so we've deleted it for now.)
  104. def compute_hint():
  105. from torch.fx.experimental.symbolic_shapes import has_free_unbacked_symbols
  106. # This occasionally gets exercised by, e.g.,
  107. # convert_shape_to_symint. It's just a nicety so you don't HAVE
  108. # to have a correct hint on hand when making a SymNode.
  109. # Don't attempt to compute for unbacked, this can be quite
  110. # expensive.
  111. if has_free_unbacked_symbols(self.expr):
  112. return None
  113. hint = self.shape_env._maybe_evaluate_static(self.expr, compute_hint=True)
  114. if hint is not None:
  115. hint = self.pytype(hint) if not isinstance(hint, SymTypes) else hint
  116. return hint
  117. if hint is not None:
  118. assert type(hint) is pytype or type(hint) is _to_symtype(pytype), (
  119. "Cannot create SymNode of type "
  120. f"{pytype} with incompatible hint of type {type(hint)}"
  121. )
  122. if self.shape_env and self.shape_env._translation_validation_enabled:
  123. # This is technically not TV, but this assert is expensive so
  124. # let's only do it when we're already doing expensive things
  125. computed_hint = compute_hint()
  126. assert hint == computed_hint, (
  127. f"{hint} != {computed_hint} (for {self.expr})"
  128. )
  129. else:
  130. hint = compute_hint()
  131. self._hint = hint
  132. self.constant: Optional[Union[int, float, bool]] = constant
  133. # Record the FX node of the current node if we are doing translation
  134. # validation. They will be used for building the input assertions for
  135. # the translation validation problem.
  136. tx_validation_en = (
  137. self.shape_env and self.shape_env._translation_validation_enabled
  138. )
  139. self.fx_node = tx_validation_en and fx_node
  140. def with_shape_env(self, shape_env: ShapeEnv) -> SymNode:
  141. return SymNode(
  142. self._expr, shape_env, self.pytype, self._hint, self.constant, self.fx_node
  143. )
  144. def _value_eq(self, other: SymNode) -> bool:
  145. # Purposely don't include the shape_env in the eq.
  146. return (
  147. self._expr == other._expr
  148. and self.pytype == other.pytype
  149. and self._hint == other._hint
  150. and self.constant == other.constant
  151. and self.fx_node == other.fx_node
  152. )
  153. def _value_hash(self) -> int:
  154. # Purposely don't include the shape_env in the hash.
  155. return hash((self._expr, self.pytype, self._hint, self.constant, self.fx_node))
  156. @property
  157. def expr(self):
  158. return self.shape_env.replace(self._expr)
  159. @property
  160. def hint(self):
  161. return self._hint
  162. def has_hint(self):
  163. return self._hint is not None
  164. def require_hint(self, fallback=None):
  165. from torch.fx.experimental.symbolic_shapes import free_unbacked_symbols
  166. if self._hint is None:
  167. if fallback is not None:
  168. # Say we have some expr like 2*u0 + s0
  169. # The hint will be None, since the expr contains at least 1 unbacked.
  170. # We will:
  171. # - replace every backed free symbol with its corresponding hint
  172. # - replace every unbacked free symbol with the fallback
  173. # - regenerate the expression with those symbol replacements
  174. # Note: this is not really complete either, since right now
  175. # this logic does not take into account any value ranges
  176. # for the unbacked symints, we may need to beef it up at some point.
  177. unbacked_symbols = free_unbacked_symbols(self.expr)
  178. replacements = {
  179. s: 4096 if s in unbacked_symbols else self.shape_env.var_to_val[s]
  180. for s in self.expr.free_symbols
  181. }
  182. return self.expr.xreplace(replacements)
  183. # NB: we expect this to raise
  184. return self.shape_env.size_hint(self.expr)
  185. return self._hint
  186. def maybe_as_int(self):
  187. if self.expr.is_number:
  188. return int(self.expr)
  189. else:
  190. return None
  191. # NB: This does conversions, not sure if this is good or not
  192. def maybe_as_float(self):
  193. import sympy
  194. if isinstance(self.expr, sympy.Float):
  195. return float(self.expr)
  196. else:
  197. return None
  198. def maybe_as_bool(self):
  199. import sympy
  200. if self.expr is sympy.true:
  201. return True
  202. elif self.expr is sympy.false:
  203. return False
  204. else:
  205. return None
  206. def is_int(self):
  207. return self.pytype is int
  208. def is_float(self):
  209. return self.pytype is float
  210. def is_bool(self):
  211. return self.pytype is bool
  212. def is_nested_int(self):
  213. # Unbacked SymInts cannot be nested int today
  214. return (
  215. self._hint is not None
  216. and isinstance(self._hint, SymInt)
  217. and self._hint.node.is_nested_int()
  218. )
  219. def wrap_int(self, num):
  220. assert type(num) is int
  221. import sympy
  222. return SymNode(
  223. sympy.Integer(num), self.shape_env, int, num, constant=num, fx_node=num
  224. )
  225. def wrap_float(self, num):
  226. assert type(num) is float
  227. import sympy
  228. return SymNode(
  229. sympy.Float(num), self.shape_env, float, num, constant=num, fx_node=num
  230. )
  231. def wrap_bool(self, num):
  232. assert type(num) is bool
  233. import sympy
  234. return SymNode(
  235. sympy.true if num else sympy.false,
  236. self.shape_env,
  237. bool,
  238. num,
  239. constant=num,
  240. fx_node=num,
  241. )
  242. def clone(self):
  243. return self
  244. def str(self):
  245. return f"{self.expr}"
  246. def __str__(self):
  247. return self.str()
  248. def __repr__(self):
  249. rep = [
  250. f"SymNode({self._expr}, shape_env={self.shape_env}, pytype={self.pytype}",
  251. ]
  252. if self._hint is not None:
  253. rep.append(f"hint={self._hint}")
  254. if self.constant is not None:
  255. rep.append(f"constant={self.constant}")
  256. if self.fx_node is not None:
  257. rep.append(f"fx_node={self.fx_node}")
  258. return ", ".join(rep) + ")"
  259. def _graph_repr(self) -> builtins.str:
  260. # Representation used by GraphModule to create a pythonic version of a graph
  261. return self.str()
  262. # These methods call the metaprogrammed methods, they're hand written
  263. # here so we get good stack traces
  264. def abs(self) -> SymNode:
  265. return self._abs() # type: ignore[attr-defined]
  266. def pos(self) -> SymNode:
  267. return self._pos() # type: ignore[attr-defined]
  268. def round(self, ndigits=None) -> SymNode:
  269. return self._round(ndigits) # type: ignore[attr-defined]
  270. def trunc(self) -> SymNode:
  271. return self._trunc() # type: ignore[attr-defined]
  272. def add(self, other) -> SymNode:
  273. return self._add(other) # type: ignore[attr-defined]
  274. def sub(self, other) -> SymNode:
  275. return self._sub(other) # type: ignore[attr-defined]
  276. def mul(self, other) -> SymNode:
  277. return self._mul(other) # type: ignore[attr-defined]
  278. def mod(self, other) -> SymNode:
  279. return self._mod(other) # type: ignore[attr-defined]
  280. def float_pow(self, other) -> SymNode:
  281. return self._float_pow(other) # type: ignore[attr-defined]
  282. def pow_by_natural(self, other) -> SymNode:
  283. return self._pow_by_natural(other) # type: ignore[attr-defined]
  284. def and_(self, other) -> SymNode:
  285. return self._and_(other) # type: ignore[attr-defined]
  286. def or_(self, other) -> SymNode:
  287. return self._or_(other) # type: ignore[attr-defined]
  288. def float_truediv(self, other) -> SymNode:
  289. return self._float_truediv(other) # type: ignore[attr-defined]
  290. def int_truediv(self, other) -> SymNode:
  291. return self._int_truediv(other) # type: ignore[attr-defined]
  292. def int_floordiv(self, other) -> SymNode:
  293. return self._int_floordiv(other) # type: ignore[attr-defined]
  294. def lshift(self, other) -> SymNode:
  295. return self._lshift(other) # type: ignore[attr-defined]
  296. def rshift(self, other) -> SymNode:
  297. return self._rshift(other) # type: ignore[attr-defined]
  298. def sym_not(self) -> SymNode: # noqa: F811
  299. return self._sym_not() # type: ignore[attr-defined]
  300. def eq(self, other) -> SymNode:
  301. return self._eq(other) # type: ignore[attr-defined]
  302. def ne(self, other) -> SymNode:
  303. return self._ne(other) # type: ignore[attr-defined]
  304. def gt(self, other) -> SymNode:
  305. return self._gt(other) # type: ignore[attr-defined]
  306. def lt(self, other) -> SymNode:
  307. return self._lt(other) # type: ignore[attr-defined]
  308. def le(self, other) -> SymNode:
  309. return self._le(other) # type: ignore[attr-defined]
  310. def ge(self, other) -> SymNode:
  311. return self._ge(other) # type: ignore[attr-defined]
  312. def floor(self) -> SymNode:
  313. return self._floor() # type: ignore[attr-defined]
  314. def is_integer(self) -> SymNode:
  315. return self._is_integer() # type: ignore[attr-defined]
  316. def sym_float(self) -> SymNode: # noqa: F811
  317. return self._sym_float() # type: ignore[attr-defined]
  318. def sym_int(self) -> SymNode:
  319. return self._sym_int() # type: ignore[attr-defined]
  320. def ceil(self) -> SymNode:
  321. return self._ceil() # type: ignore[attr-defined]
  322. def neg(self) -> SymNode:
  323. return self._neg() # type: ignore[attr-defined]
  324. def sym_min(self, other) -> SymNode: # noqa: F811
  325. return self._sym_min(other) # type: ignore[attr-defined]
  326. def sym_max(self, other) -> SymNode: # noqa: F811
  327. return self._sym_max(other) # type: ignore[attr-defined]
  328. def sym_ite(self, then_val, else_val) -> SymNode:
  329. return self._sym_ite(then_val, else_val) # type: ignore[attr-defined]
  330. def is_contiguous(self, sizes, strides) -> SymNode:
  331. return self._is_contiguous(sizes, strides) # type: ignore[attr-defined]
  332. def is_channels_last_contiguous_2d(self, sizes, strides) -> SymNode:
  333. return self._is_channels_last_contiguous_2d(sizes, strides) # type: ignore[attr-defined]
  334. def is_channels_last_contiguous_3d(self, sizes, strides) -> SymNode:
  335. return self._is_channels_last_contiguous_3d(sizes, strides) # type: ignore[attr-defined]
  336. def is_channels_last_strides_2d(self, sizes, strides) -> SymNode:
  337. return self._is_channels_last_strides_2d(sizes, strides) # type: ignore[attr-defined]
  338. def is_channels_last_strides_3d(self, sizes, strides) -> SymNode:
  339. return self._is_channels_last_strides_3d(sizes, strides) # type: ignore[attr-defined]
  340. def is_non_overlapping_and_dense_indicator(self, sizes, strides) -> SymNode:
  341. return self._is_non_overlapping_and_dense_indicator(sizes, strides) # type: ignore[attr-defined]
  342. # Make C++ happy
  343. def sym_or(self, other):
  344. return self.or_(other)
  345. def sym_and(self, other):
  346. return self.and_(other)
  347. # Integer bitwise ops
  348. def bitwise_and(self, other):
  349. return self._bitwise_and(other) # type: ignore[attr-defined]
  350. def bitwise_or(self, other):
  351. return self._bitwise_or(other) # type: ignore[attr-defined]
  352. # There is no int_truediv available from C++
  353. def truediv(self, other):
  354. return self.float_truediv(other)
  355. def floordiv(self, other) -> SymNode:
  356. return self.int_floordiv(other)
  357. # We didn't bind integer pow in C++
  358. def pow(self, other):
  359. return self.float_pow(other)
  360. def is_non_overlapping_and_dense(self, sizes, strides):
  361. return self.is_non_overlapping_and_dense_indicator(sizes, strides).eq(
  362. to_node(self, 1)
  363. ) # type: ignore[attr-defined]
  364. def int_(self):
  365. return self.guard_int("", 0) # NB: uses Python backtrace
  366. # This one is currently done by hand, but if we add other variadic
  367. # functions consider factoring it out to be metaprogrammed too. Note that
  368. # some load bearing logic is directly in torch.sym_sum
  369. def sym_sum(self, args) -> SymNode:
  370. import sympy
  371. # Inner impl
  372. from torch.fx.experimental.proxy_tensor import (
  373. get_proxy_mode,
  374. handle_sym_dispatch,
  375. )
  376. if get_proxy_mode():
  377. return to_node(
  378. self,
  379. handle_sym_dispatch(
  380. torch.sym_sum,
  381. (tuple(wrap_node(a) for a in args),),
  382. {},
  383. ),
  384. )
  385. exprs = [a.expr for a in args]
  386. out = sympy.Add(*exprs)
  387. size_hints = []
  388. out_hint = None
  389. for a in args:
  390. if a.hint is None:
  391. break
  392. size_hints.append(a.hint)
  393. else:
  394. out_hint = sum(size_hints)
  395. fx_node, _ = self.shape_env._create_fx_call_function(
  396. torch.sym_sum, (tuple(a.fx_node for a in args),)
  397. )
  398. # NB: Only for integers!
  399. return SymNode(out, self.shape_env, int, out_hint, fx_node=fx_node)
  400. def evaluate(self, size_oblivious=False):
  401. return self.shape_env.evaluate_sym_node(self, size_oblivious)
  402. # You can manually trigger a guard with this function
  403. def guard_int(self, file, line):
  404. # TODO: use the file/line for some useful diagnostic on why a
  405. # guard occurred
  406. r = self.evaluate()
  407. try:
  408. return int(r)
  409. except Exception:
  410. log.warning("Failed to convert to int: %s", r)
  411. raise
  412. def guard_float(self, file, line):
  413. # TODO: use the file/line for some useful diagnostic on why a
  414. # guard occurred
  415. r = self.evaluate()
  416. try:
  417. return float(r)
  418. except Exception:
  419. log.warning("Failed to convert to float: %s", r)
  420. raise
  421. def guard_bool(self, file, line):
  422. # TODO: use the file/line for some useful diagnostic on why a
  423. # guard occurred
  424. r = self.evaluate()
  425. try:
  426. return bool(r)
  427. except Exception:
  428. log.warning("Failed to convert to bool: %s", r)
  429. raise
  430. def expect_true(self, file, line):
  431. from torch.fx.experimental.symbolic_shapes import free_unbacked_symbols
  432. if (
  433. self.has_hint()
  434. and not free_unbacked_symbols(self.expr)
  435. and not self.shape_env.prefer_deferred_runtime_asserts_over_guards
  436. ):
  437. # OK to generate guards
  438. return self.guard_bool(file, line)
  439. # Generate a deferred runtime assert (this might actually end up doing
  440. # a regular guard if we can!)
  441. # TODO: file/line here is very important, because the assert has been
  442. # deferred so you can't backtrace easily
  443. return self.shape_env.guard_or_defer_runtime_assert(
  444. self.expr, f"{file}:{line}", fx_node=self.fx_node
  445. )
  446. def expect_size(self, file, line):
  447. from torch.fx.experimental.symbolic_shapes import _advise_is_size
  448. b = self.ge(self.wrap_int(0))
  449. # Generate a deferred runtime assert
  450. r = b.expect_true(file, line)
  451. # Refine compile time range, but only if it's unbacked.
  452. # If you refine range for hinted variables, you can end up making
  453. # improper deductions since compile time reasoning may be
  454. # incompatible with runtime reasoning.
  455. if r and not self.has_hint():
  456. _advise_is_size(SymInt(self))
  457. return r
  458. def statically_known_true(self, file, line):
  459. from torch.fx.experimental.symbolic_shapes import statically_known_true
  460. assert self.is_bool()
  461. return statically_known_true(SymBool(self))
  462. def guard_size_oblivious(self, file, line):
  463. """
  464. Like guard_bool, but if we encounter unbacked symbols, if those symbols
  465. are size-like, we will treat them as >= 2 for the purposes of the analysis.
  466. This CHANGES the runtime semantics, but all size-oblivious sites have been
  467. audited to ensure that the runtime semantics don't change in a material way.
  468. Acceptable runtime semantic changes are, e.g., squeeze() no longer dropping
  469. an unbacked one size, or a tensor reporting as non-contiguous even if it's
  470. contiguous if it would have been reported contiguous due to being empty.
  471. """
  472. # TODO: use the file/line for some useful diagnostic on why a
  473. # guard occurred
  474. r = self.evaluate(size_oblivious=True)
  475. try:
  476. return bool(r)
  477. except Exception:
  478. log.warning("Failed to convert to bool: %s", r)
  479. raise
  480. def guard_or_false(self, file, line):
  481. from torch.fx.experimental.symbolic_shapes import guard_or_false
  482. assert self.is_bool()
  483. return guard_or_false(SymBool(self))
  484. def guard_or_true(self, file, line):
  485. from torch.fx.experimental.symbolic_shapes import guard_or_true
  486. assert self.is_bool()
  487. return guard_or_true(SymBool(self))
  488. def bool_(self):
  489. return self.guard_bool("", 0)
  490. def is_symbolic(self):
  491. return True
  492. def nested_int(self):
  493. return None
  494. def is_constant(self):
  495. return False
  496. # TODO: this probably needs the sizes-strides eval functions
  497. METHOD_TO_OPERATOR = {
  498. "pos": operator.pos,
  499. "abs": operator.abs,
  500. "add": operator.add,
  501. "and": operator.and_,
  502. "bitwise_and": operator.and_,
  503. "ceil": math.ceil,
  504. "eq": operator.eq,
  505. "floor": math.floor,
  506. "trunc": math.trunc,
  507. "int_floordiv": operator.floordiv,
  508. "ge": operator.ge,
  509. "gt": operator.gt,
  510. "is_integer": lambda x: x.is_integer(),
  511. "le": operator.le,
  512. "lshift": operator.lshift,
  513. "lt": operator.lt,
  514. "mod": operator.mod,
  515. "mul": operator.mul,
  516. "ne": operator.ne,
  517. "neg": operator.neg,
  518. "or": operator.or_,
  519. "bitwise_or": operator.or_,
  520. "float_pow": operator.pow,
  521. "pow_by_natural": operator.pow,
  522. "round": builtins.round,
  523. "rshift": operator.rshift,
  524. "sub": operator.sub,
  525. "sym_float": sym_float,
  526. "sym_ite": sym_ite,
  527. "sym_max": sym_max,
  528. "sym_min": sym_min,
  529. "sym_not": sym_not,
  530. "float_truediv": operator.truediv,
  531. "int_truediv": operator.truediv,
  532. }
  533. unary_magic_methods = {
  534. "abs",
  535. "sym_float",
  536. "sym_int",
  537. "ceil",
  538. "floor",
  539. "neg",
  540. "sym_not",
  541. "pos",
  542. "trunc",
  543. }
  544. # Adding math ops: sqrt, cos, sin, ...
  545. def _get_sym_node_fn(name):
  546. def fn(self):
  547. return getattr(self, f"_sym_{name}")()
  548. return fn
  549. math_op_names = (
  550. "sqrt",
  551. "cos",
  552. "cosh",
  553. "sin",
  554. "sinh",
  555. "tan",
  556. "tanh",
  557. "asin",
  558. "acos",
  559. "atan",
  560. "log2",
  561. )
  562. for name in math_op_names:
  563. sym_name = f"sym_{name}"
  564. priv_sym_name = f"_{sym_name}"
  565. setattr(SymNode, sym_name, _get_sym_node_fn(name))
  566. METHOD_TO_OPERATOR[sym_name] = getattr(torch, priv_sym_name)
  567. unary_magic_methods.add(sym_name)
  568. __all__.append(sym_name)
  569. # Unary methods that are not magic methods
  570. unary_nonmagic_methods = {
  571. "is_integer",
  572. }
  573. unary_methods = unary_magic_methods | unary_nonmagic_methods
  574. # Most methods are only registered on SymInt and SymFloat
  575. # Some methods are only be registered on SymBool
  576. only_bool_magic_methods = {"and", "or", "sym_not", "sym_ite"}
  577. # Methods that implicitly convert SymBool into SymInt
  578. bool_becomes_int_magic_methods = {"add", "sub", "mul"}
  579. # Methods that are also on SymBool, in addition to on SymInt and SymFloat
  580. also_bool_magic_methods = {"eq"}
  581. bool_magic_methods = only_bool_magic_methods | also_bool_magic_methods
  582. # Methods that are only for float
  583. only_float_magic_methods = {"is_integer", "round", "sym_int", "sym_log2"}
  584. magic_methods_on_operator_with_trailing_underscore = {"and", "or"}
  585. # remap necessary because an op name can have a bitwise and boolean implementation
  586. bitwise_ops = {
  587. "bitwise_and": "and",
  588. "bitwise_or": "or",
  589. }
  590. always_float_magic_methods = {"int_truediv", "float_truediv", "sym_float", "float_pow"}
  591. for name in math_op_names:
  592. sym_name = f"sym_{name}"
  593. always_float_magic_methods.add(sym_name)
  594. always_int_magic_methods = {"ceil", "floor", "trunc", "pow_by_natural"}
  595. always_bool_magic_methods = {
  596. "eq",
  597. "ne",
  598. "gt",
  599. "lt",
  600. "le",
  601. "ge",
  602. "and",
  603. "or",
  604. "sym_not",
  605. "is_non_overlapping_and_dense",
  606. "is_integer",
  607. }
  608. # Methods that have a `__foo__` as well as `__rfoo__`
  609. def _sympy_float_truediv(a, b):
  610. from torch.utils._sympy.functions import FloatTrueDiv
  611. return FloatTrueDiv(a, b)
  612. def _sympy_int_truediv(a, b):
  613. from torch.utils._sympy.functions import IntTrueDiv
  614. return IntTrueDiv(a, b)
  615. def _sympy_floordiv(a, b):
  616. from torch.utils._sympy.functions import FloorDiv
  617. return FloorDiv(a, b)
  618. def _sympy_mod(a, b):
  619. from torch.utils._sympy.functions import Mod, PythonMod
  620. if a.is_nonnegative and b.is_nonnegative:
  621. return Mod(a, b)
  622. else:
  623. return PythonMod(a, b)
  624. def _sympy_pow_by_natural(a, b):
  625. from torch.utils._sympy.functions import PowByNatural
  626. return PowByNatural(a, b)
  627. def _sympy_float_pow(a, b):
  628. from torch.utils._sympy.functions import FloatPow
  629. return FloatPow(a, b)
  630. def _sympy_and(a, b):
  631. import sympy
  632. return sympy.And(a, b)
  633. def _sympy_or(a, b):
  634. import sympy
  635. return sympy.Or(a, b)
  636. def _sympy_lshift(a, b):
  637. from torch.utils._sympy.functions import LShift
  638. return LShift(a, b)
  639. def _sympy_rshift(a, b):
  640. from torch.utils._sympy.functions import RShift
  641. return RShift(a, b)
  642. def _binary_search_insert_arg(ordered_args, new_arg):
  643. """
  644. If new_arg is found in ordered_args None is returned, else the new
  645. ordered_args with new_arg inserted
  646. """
  647. if len(ordered_args) == 0:
  648. return [new_arg]
  649. from sympy.core.basic import _args_sortkey as sort_key, Basic
  650. # Fast path when new_arg > ordered_args[-1].
  651. if sort_key(ordered_args[-1]) < sort_key(new_arg):
  652. return ordered_args + [new_arg]
  653. # Fast path when new_arg < ordered_args[0].
  654. if sort_key(ordered_args[0]) > sort_key(new_arg):
  655. return [new_arg] + ordered_args
  656. low, high = 0, len(ordered_args) - 1
  657. while low <= high:
  658. mid = (low + high) // 2
  659. compare_result = Basic.compare(ordered_args[mid], new_arg)
  660. if compare_result == 0:
  661. return None
  662. elif compare_result < 0:
  663. low = mid + 1
  664. else:
  665. high = mid - 1
  666. ordered_args.insert(low, new_arg)
  667. return ordered_args
  668. def _optimized_add(
  669. lhs, rhs, lhs_is_optimized_summation=False, rhs_is_optimized_summation=False
  670. ):
  671. """
  672. Custom optimization for Add used to optimize incremental binary summations of certain properties. The idea
  673. is when we know the expression is a summation of unique symbols all we need to know is the correct order of symbols,
  674. and no other optimizations are needed. We pass evaluate=false, with the correct order of args and save the following.
  675. 1. Avoid running other optimizations when the Add is constructed.
  676. 2. Manually figure out the order of the args for the new expression in log(n) comparisons instead of nLog(n)
  677. (comparing terms is expensive and shows in the profiles).
  678. The function returns a tuple of (1) a boolean that indicates whether the output is a summation of unique symbols,
  679. (2) the result sympy expression.
  680. """
  681. import sympy
  682. from sympy.core.basic import _args_sortkey as sortkey
  683. def make_optimized(ordered_args):
  684. assert ordered_args is not None
  685. result = sympy.Add(*ordered_args, evaluate=False)
  686. return (True, result)
  687. from torch.utils._sympy.functions import _is_symbols_binary_summation
  688. lhs_is_optimized_summation |= _is_symbols_binary_summation(lhs)
  689. rhs_is_optimized_summation |= _is_symbols_binary_summation(rhs)
  690. if lhs_is_optimized_summation and rhs_is_optimized_summation:
  691. # (a0+a1..) + (a2+a3..) => (a0+a1+a2+a3)
  692. if sortkey(lhs._args[-1]) < sortkey(rhs._args[0]):
  693. return make_optimized(lhs._args + rhs._args)
  694. # (a2+a3..) + (a0+a1..) => (a0+a1+a2+a3)
  695. if sortkey(lhs._args[0]) > sortkey(rhs._args[-1]):
  696. return make_optimized(rhs._args + lhs._args)
  697. # (a1+a3) + (a0+a2) => (a0+a1+a2+a3)
  698. if len(lhs._args) <= 2 and len(rhs._args) <= 2:
  699. new_args = list(lhs._args)
  700. for a in rhs._args:
  701. new_args = _binary_search_insert_arg(new_args, a)
  702. if new_args is None:
  703. break
  704. # None means an element already exists.
  705. if new_args is not None:
  706. return make_optimized(new_args)
  707. # (a0+a2) + a1 => (a0+a1+a2)
  708. if lhs_is_optimized_summation and rhs.is_symbol:
  709. new_args = _binary_search_insert_arg(list(lhs._args), rhs)
  710. # None means an element already exists.
  711. if new_args is not None:
  712. return make_optimized(new_args)
  713. # a1 + (a0+a2)=> (a0+a1+a2)
  714. if rhs_is_optimized_summation and lhs.is_symbol:
  715. new_args = _binary_search_insert_arg(list(rhs._args), lhs)
  716. # None means an element already exists.
  717. if new_args is not None:
  718. return make_optimized(new_args)
  719. result = sympy.Add(lhs, rhs)
  720. return (_is_symbols_binary_summation(result), result)
  721. def _bitwise_and(a, b):
  722. from torch.utils._sympy.functions import BitwiseFn_bitwise_and
  723. return BitwiseFn_bitwise_and(a, b)
  724. def _bitwise_or(a, b):
  725. from torch.utils._sympy.functions import BitwiseFn_bitwise_or
  726. return BitwiseFn_bitwise_or(a, b)
# Binary methods that get a reflected (``__rfoo__``) variant in addition to
# ``__foo__`` (see _make_user_magic).  Each name maps to the function that
# builds the sympy expression for the operation.
# NOTE(review): this dict is spread into ``magic_methods``, whose insertion
# order is the order in which methods are installed; "and"/"bitwise_and" and
# "or"/"bitwise_or" are remapped to the same dunder name via ``bitwise_ops``,
# so check the registration loops before reordering entries.
reflectable_magic_methods = {
    "add": _optimized_add,
    "sub": operator.sub,
    "mul": operator.mul,
    "mod": _sympy_mod,
    "pow_by_natural": _sympy_pow_by_natural,
    "float_pow": _sympy_float_pow,
    "and": _sympy_and,
    "bitwise_and": _bitwise_and,
    "or": _sympy_or,
    "bitwise_or": _bitwise_or,
    "float_truediv": _sympy_float_truediv,
    "int_truediv": _sympy_int_truediv,
    "int_floordiv": _sympy_floordiv,
    "lshift": _sympy_lshift,
    "rshift": _sympy_rshift,
}
  744. def _floor_ceil_helper(a, fn):
  745. import sympy
  746. if isinstance(a, sympy.Mul):
  747. aa = a.args
  748. if len(aa) == 2 and isinstance(aa[0], sympy.Float) and aa[1].is_integer:
  749. coef = sympy.Integer(aa[0])
  750. if aa[0] == coef: # structural equality test
  751. return coef * aa[1]
  752. if (
  753. isinstance(a, sympy.Float)
  754. and a == sympy.Integer(a)
  755. or isinstance(a, sympy.Integer)
  756. ):
  757. return sympy.Integer(a)
  758. return fn(a)
  759. def _sympy_floor(a):
  760. from torch.utils._sympy.functions import FloorToInt
  761. return FloorToInt(a)
  762. # NB: this is Python trunc semantics which returns an int. Do NOT use this to
  763. # represent torch.trunc (which is float to float)
  764. def _sympy_trunc(a):
  765. from torch.utils._sympy.functions import TruncToInt
  766. return TruncToInt(a)
  767. def _sympy_ceil(a):
  768. from torch.utils._sympy.functions import CeilToInt
  769. return CeilToInt(a)
  770. def _sympy_eq(a, b):
  771. import sympy
  772. return sympy.Eq(a, b)
  773. def _sympy_ne(a, b):
  774. import sympy
  775. return sympy.Ne(a, b)
  776. def _sympy_gt(a, b):
  777. import sympy
  778. return sympy.Gt(a, b)
  779. def _sympy_lt(a, b):
  780. import sympy
  781. return sympy.Lt(a, b)
  782. def _sympy_le(a, b):
  783. import sympy
  784. return sympy.Le(a, b)
  785. def _sympy_ge(a, b):
  786. import sympy
  787. return sympy.Ge(a, b)
  788. def _sympy_min(a, b):
  789. from torch.utils._sympy.functions import Min
  790. return Min(a, b)
  791. def _sympy_max(a, b):
  792. from torch.utils._sympy.functions import Max
  793. return Max(a, b)
  794. def _sympy_ite(a, t, f):
  795. import sympy
  796. return sympy.Piecewise((t, a), (f, True))
# Handle to this module object; the generated ``_sympy_<math op>`` helpers
# below are attached to it and looked up from it.
current_module = sys.modules[__name__]
  798. def _get_sym_math_fn(name):
  799. def fn(a):
  800. import torch.utils._sympy.functions
  801. return getattr(torch.utils._sympy.functions, f"OpaqueUnaryFn_{name}")(a)
  802. return fn
# Generate one module-level ``_sympy_<name>`` function per math op and give
# it a matching ``__name__`` / ``__qualname__``.
for name in math_op_names:
    priv_sympy_name = f"_sympy_{name}"
    fn = _get_sym_math_fn(name)
    fn.__qualname__ = fn.__name__ = priv_sympy_name
    setattr(current_module, priv_sympy_name, fn)

# Remove the loop temporaries from the module namespace.
del fn, name, priv_sympy_name  # type: ignore[possibly-undefined]
  809. def _sympy_abs(a):
  810. import sympy
  811. return sympy.Abs(a)
  812. def _sympy_round(number, ndigits=None):
  813. from torch.utils._sympy.functions import RoundDecimal, RoundToInt
  814. if ndigits is None:
  815. return RoundToInt(number)
  816. else:
  817. return RoundDecimal(number, ndigits)
  818. def _sympy_sym_float(a):
  819. from torch.utils._sympy.functions import ToFloat
  820. # NB: Cannot use a * 1.0 here, because 0 * 1.0 is 0 which incorrectly
  821. # reports that it is an integer
  822. return ToFloat(a)
  823. def _sympy_is_integer(a):
  824. import sympy
  825. from torch.utils._sympy.functions import ToFloat
  826. return sympy.Eq(ToFloat(sympy.floor(a)), a)
# Maps every supported method name to the function that builds the sympy
# expression for it.  The reflectable binary methods come first, followed by
# the remaining methods.
# NOTE(review): the registration loops iterate this dict in insertion order,
# so check them before reordering entries.
magic_methods = {
    **reflectable_magic_methods,
    "sym_not": operator.invert,
    "pos": operator.pos,
    "eq": _sympy_eq,
    "ne": _sympy_ne,
    "gt": _sympy_gt,
    "lt": _sympy_lt,
    "le": _sympy_le,
    "ge": _sympy_ge,
    "floor": _sympy_floor,
    "trunc": _sympy_trunc,
    "sym_float": _sympy_sym_float,
    "ceil": _sympy_ceil,
    "neg": operator.neg,
    "sym_min": _sympy_min,
    "sym_max": _sympy_max,
    "sym_ite": _sympy_ite,
    "abs": _sympy_abs,
    "round": _sympy_round,
    "is_integer": _sympy_is_integer,
}
# Register the generated ``_sympy_<name>`` helpers as ``sym_<name>`` methods.
for name in math_op_names:
    sym_name = f"sym_{name}"
    magic_methods[sym_name] = getattr(current_module, f"_sympy_{name}")

# Remove the temporaries from the module namespace; ``math_op_names`` and
# ``current_module`` are no longer available after this point.
del name, sym_name, math_op_names, current_module  # type: ignore[possibly-undefined]
  853. def sympy_is_contiguous(sizes, strides):
  854. dim = len(sizes)
  855. return sympy_is_contiguous_generic(sizes, strides, list(range(dim - 1, -1, -1)))
  856. def sympy_is_contiguous_generic(sizes, strides, dim_order):
  857. import sympy
  858. dim = len(sizes)
  859. if len(dim_order) != dim:
  860. return sympy.false
  861. is_contiguous = sympy.true
  862. z = sympy.S.One
  863. # Contiguous if the strides make sense (or the dim is size 1)
  864. for d in dim_order:
  865. is_contiguous &= sympy.Eq(sizes[d], sympy.S.One) | sympy.Eq(strides[d], z)
  866. z *= sizes[d]
  867. # OR if any size is zero
  868. for d in range(dim):
  869. is_contiguous |= sympy.Eq(sizes[d], sympy.S.Zero)
  870. return is_contiguous
  871. # NB: There is a TODO in C++ to allow omitting the batch dim. If that
  872. # happens you will need to refactor this
  873. def sympy_is_channels_last_contiguous_2d(sizes, strides):
  874. return sympy_is_contiguous_generic(sizes, strides, [1, 3, 2, 0])
  875. def sympy_is_channels_last_contiguous_3d(sizes, strides):
  876. return sympy_is_contiguous_generic(sizes, strides, [1, 4, 3, 2, 0])
  877. def sympy_is_channels_last_strides_generic(sizes, strides, dim_order):
  878. import sympy
  879. from torch.utils._sympy.functions import Max
  880. dim = len(sizes)
  881. if dim != len(dim_order):
  882. return sympy.false
  883. m = sympy.S.Zero
  884. r = sympy.true
  885. # special case for trivial C dimension. default to NCHW
  886. r &= sympy.Ne(strides[1], 0)
  887. for d in dim_order:
  888. r &= sympy.Ne(sizes[d], 0) & (strides[d] >= m)
  889. # Fallback to NCHW as default layout for ambiguous cases
  890. # This is the flaw of implicit memory_format from strides.
  891. # N111 tensor with identical strides for size 1 dimension;
  892. # Two cases could lead us here:
  893. # a. N111 contiguous Tensor ([N,1,1,1]@[1,1,1,1])
  894. # b. N11W contiguous Tensor sliced on the W-dimension.
  895. # ([N,1,1,1]@[W,W,W,W])
  896. if d == 0:
  897. r &= sympy.Ne(m, strides[1])
  898. # This is necessary to:
  899. # 1. distinguish the memory_format of N1H1;
  900. # [H, 1, 1, 1] channels_last stride
  901. # [H, H, 1, 1] contiguous stride
  902. # 2. permutation of 1C1W:
  903. # [1, C, 1, H]@[HC, H, H, 1] transpose(1, 3)
  904. # [1, H, 1, C]@[HC, 1, H, H] shouldn't be identified as
  905. # channels_last
  906. m = strides[d] * Max(sizes[d], 1)
  907. return r
  908. def sympy_is_channels_last_strides_2d(sizes, strides):
  909. return sympy_is_channels_last_strides_generic(sizes, strides, [1, 3, 2, 0])
  910. def sympy_is_channels_last_strides_3d(sizes, strides):
  911. return sympy_is_channels_last_strides_generic(sizes, strides, [1, 4, 3, 2, 0])
  912. def _sympy_is_non_overlapping_and_dense_indicator(sizes, strides):
  913. from torch.utils._sympy.functions import IsNonOverlappingAndDenseIndicator
  914. return IsNonOverlappingAndDenseIndicator(*sizes, *strides)
# Maps each sizes/strides query name to the function that builds its sympy
# expression from a list of size expressions and a list of stride expressions.
sizes_strides_methods = {
    # TODO: These could also be done with indicators, maybe it is better
    # for reasoning to do it that way
    "is_contiguous": sympy_is_contiguous,
    "is_channels_last_contiguous_2d": sympy_is_channels_last_contiguous_2d,
    "is_channels_last_contiguous_3d": sympy_is_channels_last_contiguous_3d,
    "is_channels_last_strides_2d": sympy_is_channels_last_strides_2d,
    "is_channels_last_strides_3d": sympy_is_channels_last_strides_3d,
    "is_non_overlapping_and_dense_indicator": _sympy_is_non_overlapping_and_dense_indicator,
}
  925. def to_node(self, num):
  926. if isinstance(num, SymTypes):
  927. return num.node
  928. elif type(num) is bool:
  929. return self.wrap_bool(num)
  930. elif type(num) is int:
  931. return self.wrap_int(num)
  932. elif type(num) is float:
  933. return self.wrap_float(num)
  934. else:
  935. # NotImplemented is important so that Python tries the
  936. # other magic method
  937. return NotImplemented
  938. def wrap_node(x):
  939. # TODO: let C++ also take advantage of this
  940. if isinstance(x, SymNode) and x.constant is not None:
  941. return x.constant
  942. if x.is_int():
  943. return SymInt(x)
  944. elif x.is_float():
  945. return SymFloat(x)
  946. elif x.is_bool():
  947. return SymBool(x)
  948. else:
  949. raise AssertionError(f"unrecognized return type {x}")
def method_to_operator(method):
    """Return the ``METHOD_TO_OPERATOR`` entry for ``method`` (``KeyError`` if absent)."""
    return METHOD_TO_OPERATOR[method]
  952. def _make_node_magic(method, func):
  953. func = lru_cache(256)(func)
  954. if method in magic_methods_on_operator_with_trailing_underscore:
  955. method_attr = f"{method}_"
  956. else:
  957. method_attr = method
  958. def uninteresting_files() -> set[str]:
  959. import torch
  960. mods = [
  961. torch._dynamo.eval_frame,
  962. torch._dynamo.utils,
  963. torch.fx.experimental.sym_node,
  964. torch,
  965. ]
  966. import torch._dynamo.guards
  967. return (
  968. {inspect.getfile(m) for m in mods}
  969. | torch._dynamo.guards.uninteresting_files()
  970. | {"<string>"}
  971. )
  972. def capture_provenance(fn):
  973. @functools.wraps(fn)
  974. def wrapper(self, other=None):
  975. if other is None:
  976. result = fn(self)
  977. else:
  978. result = fn(self, other)
  979. if torch._logging._internal.GET_DTRACE_STRUCTURED:
  980. if other is not None:
  981. arguments = [self, other]
  982. else:
  983. arguments = [self]
  984. def get_id(sym_node) -> Optional[int]:
  985. # We don't want to return an ID if the input is a constant
  986. import sympy
  987. if sym_node.constant is not None:
  988. return None
  989. elif id(sym_node) == id(result):
  990. return None
  991. elif isinstance(sym_node.expr, (sympy.Integer, sympy.Float)):
  992. return None
  993. elif sym_node.expr in (sympy.true, sympy.false):
  994. return None
  995. return id(sym_node)
  996. dtrace_structured(
  997. "expression_created",
  998. metadata_fn=lambda: {
  999. "method": method,
  1000. "result": str(result),
  1001. "result_id": id(result),
  1002. "arguments": [str(a) for a in arguments],
  1003. "argument_ids": [
  1004. get_id(i) for i in arguments if get_id(i) is not None
  1005. ],
  1006. "user_stack": structured.get_user_stack(3),
  1007. "stack": structured.get_framework_stack(3),
  1008. },
  1009. )
  1010. return result
  1011. return wrapper
  1012. @capture_provenance
  1013. def binary_magic_impl(self, other):
  1014. from torch.fx.experimental.proxy_tensor import (
  1015. get_proxy_mode,
  1016. handle_sym_dispatch,
  1017. )
  1018. op = method_to_operator(method)
  1019. out_hint = None
  1020. if self.hint is not None and other.hint is not None:
  1021. out_hint = op(self.hint, other.hint)
  1022. if get_proxy_mode():
  1023. return to_node(
  1024. self, handle_sym_dispatch(op, (wrap_node(self), wrap_node(other)), {})
  1025. )
  1026. assert isinstance(other, SymNode)
  1027. optimized_summation = False
  1028. try:
  1029. if method == "mod":
  1030. from torch.utils._sympy.functions import Mod, PythonMod
  1031. # Special handling for mod that requires access to the value
  1032. # ranges
  1033. shape_env = self.shape_env
  1034. if (
  1035. self.expr.is_nonnegative
  1036. or shape_env.bound_sympy(self.expr).lower >= 0
  1037. ) and (
  1038. other.expr.is_nonnegative
  1039. or shape_env.bound_sympy(other.expr).lower >= 0
  1040. ):
  1041. out = Mod(self.expr, other.expr)
  1042. else:
  1043. out = PythonMod(self.expr, other.expr)
  1044. elif method == "add":
  1045. # see Note [optimized_summation]
  1046. (optimized_summation, out) = func(
  1047. self.expr,
  1048. other.expr,
  1049. self._optimized_summation,
  1050. other._optimized_summation,
  1051. )
  1052. else:
  1053. # TODO: consider constant prop here
  1054. out = func(self.expr, other.expr)
  1055. except Exception:
  1056. log.warning("failed to eval %s(%s, %s)", method, self.expr, other.expr)
  1057. raise
  1058. sym_node_log.debug("%s %s %s -> %s", method, self.expr, other.expr, out)
  1059. pytype: type
  1060. # This is not strictly correct. In Python, a**b may return complex when
  1061. # a < 0 and b is a float: (-1)**2.1. Same for sympy.sqrt(-3.14). This
  1062. # returns a float while both arguments are ints: 2**(-1). Also, max and
  1063. # min do not type promote. To avoid having data-dependent control flow
  1064. # here, we just set the type to float if one of the args is a float. In
  1065. # case of a type mismatch, we assume that it will be detected during
  1066. # evaluation.
  1067. if method in always_float_magic_methods:
  1068. pytype = float
  1069. elif method in always_bool_magic_methods:
  1070. pytype = bool
  1071. elif self.pytype is float or other.pytype is float:
  1072. pytype = float
  1073. else:
  1074. pytype = self.pytype
  1075. if (
  1076. pytype is not None
  1077. and out_hint is not None
  1078. and not isinstance(out_hint, SymTypes)
  1079. ):
  1080. out_hint = pytype(out_hint)
  1081. # Create a FX node that corresponds to the operation being applied to
  1082. # this node.
  1083. fx_node, _ = self.shape_env._create_fx_call_function(
  1084. op, (self.fx_node, other.fx_node)
  1085. )
  1086. result = SymNode(
  1087. out,
  1088. self.shape_env,
  1089. pytype,
  1090. out_hint, # type: ignore[arg-type]
  1091. fx_node=fx_node,
  1092. optimized_summation=optimized_summation, # see Note [optimized_summation]
  1093. )
  1094. return result
  1095. @capture_provenance
  1096. def unary_magic_impl(self):
  1097. from torch.fx.experimental.proxy_tensor import (
  1098. get_proxy_mode,
  1099. handle_sym_dispatch,
  1100. )
  1101. op = method_to_operator(method)
  1102. if get_proxy_mode():
  1103. return to_node(self, handle_sym_dispatch(op, (wrap_node(self),), {}))
  1104. # TODO: consider constant prop here
  1105. expr = self.expr
  1106. if method == "floor" or method == "ceiling":
  1107. expr = self.shape_env._simplify_floor_div(expr)
  1108. try:
  1109. out = func(expr)
  1110. except Exception:
  1111. log.warning("failed to eval %s(%s)", method, expr)
  1112. raise
  1113. sym_node_log.debug("%s %s -> %s", func, expr, out)
  1114. out_hint = None
  1115. if self.hint is not None:
  1116. out_hint = op(self.hint)
  1117. pytype: type
  1118. if method in always_int_magic_methods:
  1119. pytype = int
  1120. elif method in always_bool_magic_methods:
  1121. pytype = bool
  1122. elif method in always_float_magic_methods:
  1123. pytype = float
  1124. else:
  1125. pytype = self.pytype
  1126. fx_node, _ = self.shape_env._create_fx_call_function(op, (self.fx_node,))
  1127. return SymNode(out, self.shape_env, pytype, out_hint, fx_node=fx_node)
  1128. if method in unary_methods:
  1129. setattr(SymNode, f"_{method_attr}", unary_magic_impl)
  1130. elif method == "sym_ite":
  1131. def sym_ite_impl(pred_node, then_node, else_node):
  1132. from torch.fx.experimental.proxy_tensor import (
  1133. get_proxy_mode,
  1134. handle_sym_dispatch,
  1135. )
  1136. out_hint = then_node.hint if pred_node.hint else else_node.hint
  1137. if get_proxy_mode():
  1138. return to_node(
  1139. pred_node,
  1140. handle_sym_dispatch(
  1141. sym_ite,
  1142. (
  1143. wrap_node(pred_node),
  1144. wrap_node(then_node),
  1145. wrap_node(else_node),
  1146. ),
  1147. {},
  1148. ),
  1149. )
  1150. try:
  1151. out = func(pred_node.expr, then_node.expr, else_node.expr)
  1152. except Exception:
  1153. log.warning(
  1154. "failed to eval %s(%s, %s, %s)",
  1155. method,
  1156. pred_node.expr,
  1157. then_node.expr,
  1158. else_node.expr,
  1159. )
  1160. raise
  1161. fx_node, _ = pred_node.shape_env._create_fx_call_function(
  1162. sym_ite, (pred_node.fx_node, then_node.fx_node, else_node.fx_node)
  1163. )
  1164. return SymNode(
  1165. out, pred_node.shape_env, then_node.pytype, out_hint, fx_node=fx_node
  1166. )
  1167. setattr(SymNode, f"_{method_attr}", sym_ite_impl)
  1168. elif method == "round":
  1169. def round_impl(self, ndigits=None):
  1170. from torch.fx.experimental.proxy_tensor import (
  1171. get_proxy_mode,
  1172. handle_sym_dispatch,
  1173. )
  1174. op = builtins.round
  1175. if get_proxy_mode():
  1176. return to_node(
  1177. self, handle_sym_dispatch(op, (wrap_node(self), ndigits), {})
  1178. )
  1179. expr = self.expr
  1180. try:
  1181. out = func(expr, ndigits)
  1182. except Exception:
  1183. log.warning("failed to eval %s(%s, ndigits=%s)", method, expr, ndigits)
  1184. raise
  1185. if ndigits is None:
  1186. pytype = int
  1187. else:
  1188. pytype = self.pytype
  1189. out_hint = None
  1190. if self.hint is not None:
  1191. out_hint = op(self.hint, ndigits)
  1192. # Internally, None is used as sentinel to indicate that a something is not a node on an FX graph. At the
  1193. # same time, there is no way to wrap a plain None into an FX node. Thus, there is no way to pass None here
  1194. # without triggering some asserts that check whether we are mixing FX nodes with untracked arguments. The
  1195. # hack down below works, because all round function down the line all take ndigits=None as default in their
  1196. # signature.
  1197. # TODO: Remove the args construction below if a different sentinel is used by FX.
  1198. # ezyang(May 2024): LOL
  1199. args = [self.fx_node]
  1200. if ndigits is not None:
  1201. args.append(ndigits)
  1202. fx_node, _ = self.shape_env._create_fx_call_function(op, tuple(args))
  1203. return SymNode(out, self.shape_env, pytype, out_hint, fx_node=fx_node)
  1204. setattr(SymNode, f"_{method_attr}", round_impl)
  1205. else:
  1206. setattr(SymNode, f"_{method_attr}", binary_magic_impl)
  1207. def _make_node_sizes_strides(method, func):
  1208. # NB: don't LRU cache, lots of arguments
  1209. def sizes_strides_impl(self, sizes, strides):
  1210. from torch.fx.experimental.proxy_tensor import (
  1211. get_proxy_mode,
  1212. handle_sym_dispatch,
  1213. )
  1214. op = getattr(sys.modules[__name__], method)
  1215. if get_proxy_mode():
  1216. return to_node(
  1217. self,
  1218. handle_sym_dispatch(
  1219. op,
  1220. ([wrap_node(s) for s in sizes], [wrap_node(s) for s in strides]),
  1221. {},
  1222. ),
  1223. )
  1224. size_exprs = [s.expr for s in sizes]
  1225. stride_exprs = [s.expr for s in strides]
  1226. try:
  1227. out = func(size_exprs, stride_exprs)
  1228. except Exception:
  1229. log.warning("failed to eval %s(%s, %s)", method, size_exprs, stride_exprs)
  1230. raise
  1231. # bool is never expandable
  1232. size_hints = []
  1233. out_hint = None
  1234. for s in sizes:
  1235. if s.hint is None:
  1236. break
  1237. size_hints.append(s.hint)
  1238. else:
  1239. stride_hints = []
  1240. for s in strides:
  1241. if s.hint is None:
  1242. break
  1243. stride_hints.append(s.hint)
  1244. else:
  1245. out_hint = op(size_hints, stride_hints)
  1246. # NB: This is the indicator function, not the actual bool!
  1247. pytype: type
  1248. if method.endswith("_indicator"):
  1249. pytype = int
  1250. else:
  1251. pytype = bool
  1252. return SymNode(out, self.shape_env, pytype, out_hint)
  1253. setattr(SymNode, f"_{method}", sizes_strides_impl)
  1254. # TODO: This is technically hotpath, but in the ideal end state
  1255. # guards on this will resolve at a higher level so you never
  1256. # spend time in this code
  1257. def sizes_strides_user(sizes, strides):
  1258. import sympy
  1259. from torch.fx.experimental.symbolic_shapes import (
  1260. eval_is_non_overlapping_and_dense,
  1261. )
  1262. for a in itertools.chain(sizes, strides):
  1263. if isinstance(a, SymInt):
  1264. return wrap_node(
  1265. getattr(a.node, method)(
  1266. [to_node(a.node, b) for b in sizes],
  1267. [to_node(a.node, b) for b in strides],
  1268. )
  1269. )
  1270. if method == "is_non_overlapping_and_dense_indicator":
  1271. return eval_is_non_overlapping_and_dense(sizes, strides)
  1272. else:
  1273. # TODO: this is an awful implementation
  1274. return bool(
  1275. func(
  1276. [sympy.sympify(a) for a in sizes],
  1277. [sympy.sympify(a) for a in strides],
  1278. )
  1279. )
  1280. # Skip for is_non_overlapping_and_dense_indicator
  1281. if not hasattr(sys.modules[__name__], method):
  1282. setattr(sys.modules[__name__], method, sizes_strides_user)
# Install the node-level implementation of every magic method on SymNode.
for method, func in magic_methods.items():
    _make_node_magic(method, func)

# Install the node-level and module-level sizes/strides queries.
for method, func in sizes_strides_methods.items():
    _make_node_sizes_strides(method, func)
  1287. def _make_user_magic(method, user_type):
  1288. # User magic takes care of wrapping the other operand into a node,
  1289. # so that our internal logic can assume everything is nodes
  1290. if method in magic_methods_on_operator_with_trailing_underscore:
  1291. method_attr = f"sym_{method}"
  1292. else:
  1293. method_attr = method
  1294. def get_constant(x: Union[SymInt, int, SymFloat, float, SymBool, bool]):
  1295. if isinstance(x, (int, float, bool)):
  1296. return x
  1297. if isinstance(x, SymBool):
  1298. return x.node.guard_bool("", 0)
  1299. raise AssertionError("expect to be called with constant SymBools")
  1300. def is_constant(x):
  1301. if isinstance(x, (int, float, bool)):
  1302. return True
  1303. if isinstance(x, (SymInt, SymFloat, SymBool)):
  1304. return x.node.is_constant()
  1305. return False
  1306. # Promotion rules for binary operations. NB: we preserve PYTHON semantics
  1307. # - if args are same type, do nothing
  1308. # - if one arg is float, promote other arg to float
  1309. # - nb: this applies to floordiv, even though output is integral
  1310. # (it's still float)
  1311. # - pow is funny business
  1312. # - if both ints
  1313. # - trigger a guard on exponent >= 0
  1314. # - if non-negative, output is int
  1315. # - otherwise, output is float
  1316. # - otherwise, promote other arg to float
  1317. # - nb: complex is impossible to handle correctly lol, with
  1318. # negative base and integral float need to diverge semantics and
  1319. # just always return complex. Neener neener pretend this problem
  1320. # doesn't exist
  1321. # - equality is pain: Python does the fancy thing where it unpacks the
  1322. # mantissa from the float and then compares that against the int.
  1323. # Which means it is able to tell that
  1324. # 9007199254740993 != 9007199254740992. (rather than if the LHS was
  1325. # promoted to float, in which case it would have truncated to the RHS
  1326. # and subsequently been equal). We'll model this exactly by having
  1327. # special mixed type equality operations. Unfortunately, we need to
  1328. # do this for all comparison operations (maybe I'll only implement
  1329. # compare)
  1330. # - sym_ite mumble mumble really shouldn't allow mixed but whatever
  1331. if method in bool_becomes_int_magic_methods:
  1332. def promote(x):
  1333. """Implements True+True=2, which works in python but not sympy"""
  1334. if isinstance(x, SymBool):
  1335. return SymInt(x.node.wrap_int(int(x)))
  1336. return x
  1337. else:
  1338. def promote(x):
  1339. return x
  1340. def promote2(self, other):
  1341. # TODO: Remove eq and other relations from this list.
  1342. # CPython has fancy implementations for these to get as much precision
  1343. # as possible instead of just promoting to float64 and praying, so we
  1344. # need to handle them specially too.
  1345. # Also, note that int_truediv doesn't go through this path: both
  1346. # arguments are "int" so there isn't any promotion
  1347. if method not in [
  1348. "add",
  1349. "sub",
  1350. "mul",
  1351. "mod",
  1352. "float_pow",
  1353. "float_truediv",
  1354. "int_floordiv",
  1355. "sym_min",
  1356. "sym_max",
  1357. # TODO: remove these
  1358. "eq",
  1359. "ne",
  1360. "gt",
  1361. "lt",
  1362. "le",
  1363. "ge",
  1364. ]:
  1365. return self, other
  1366. f_self = isinstance(self, (float, torch.SymFloat))
  1367. f_other = isinstance(other, (float, torch.SymFloat))
  1368. if f_self or f_other:
  1369. if not f_self:
  1370. self = torch.sym_float(self)
  1371. if not f_other:
  1372. other = torch.sym_float(other)
  1373. return self, other
  1374. # Before and after performing the operation, check if any operands are constant.
  1375. # If so, extract out the constant values first. If `self` itself is a
  1376. # constant, then "redispatch" by calling back into the operator. Sometimes
  1377. # this means that operations involving SymBool return plain bools.
  1378. # Alternatively, we could also rewrap into constant Symbool (i.e. by
  1379. # implementing wrap_bool in ConstantSymNodeImpl), but we're not doing that
  1380. # today for no particular reason.
  1381. def unary_magic_impl(self):
  1382. self = promote(self)
  1383. if is_constant(self):
  1384. return (method_to_operator(method))(get_constant(self))
  1385. return wrap_node(getattr(self.node, method_attr)())
  1386. def binary_magic_impl(self, other):
  1387. if not isinstance(other, (int, float, bool, SymInt, SymFloat, SymBool)):
  1388. return NotImplemented
  1389. sym_node_log.debug("MAGIC %s %s %s", method, self, other)
  1390. self = promote(self)
  1391. other = promote(other)
  1392. self, other = promote2(self, other)
  1393. if is_constant(self):
  1394. return (method_to_operator(method))(get_constant(self), other)
  1395. if is_constant(other):
  1396. other = get_constant(other)
  1397. other_node = to_node(self.node, other)
  1398. if other_node is NotImplemented:
  1399. return NotImplemented
  1400. ret = wrap_node(getattr(self.node, method_attr)(other_node))
  1401. return get_constant(ret) if is_constant(ret) else ret
  1402. def rbinary_magic_impl(self, other):
  1403. if not isinstance(other, (int, float, bool, SymInt, SymFloat, SymBool)):
  1404. return NotImplemented
  1405. self = promote(self)
  1406. other = promote(other)
  1407. self, other = promote2(self, other)
  1408. if is_constant(self):
  1409. return (method_to_operator(method))(get_constant(self), other)
  1410. if is_constant(other):
  1411. other = get_constant(other)
  1412. other_node = to_node(self.node, other)
  1413. if other_node is NotImplemented:
  1414. return NotImplemented
  1415. ret = wrap_node(getattr(other_node, method_attr)(self.node))
  1416. return get_constant(ret) if is_constant(ret) else ret
  1417. if method in unary_magic_methods:
  1418. setattr(user_type, f"__{method}__", unary_magic_impl)
  1419. elif method in unary_nonmagic_methods:
  1420. orig = getattr(user_type, method)
  1421. setattr(user_type, method, update_wrapper(unary_magic_impl, orig))
  1422. elif method == "sym_ite":
  1423. def sym_ite_magic_impl(pred, then_val, else_val):
  1424. pred_node = pred.node
  1425. then_node = to_node(pred_node, then_val)
  1426. else_node = to_node(pred_node, else_val)
  1427. if then_node is NotImplemented or else_node is NotImplemented:
  1428. return NotImplemented
  1429. assert (
  1430. isinstance(then_node, SymNode)
  1431. and isinstance(else_node, SymNode)
  1432. and then_node.pytype == else_node.pytype
  1433. )
  1434. ret = wrap_node(getattr(pred.node, method_attr)(then_node, else_node))
  1435. return get_constant(ret) if ret.node.is_constant() else ret
  1436. setattr(user_type, f"__{method}__", sym_ite_magic_impl)
  1437. elif method == "round":
  1438. def round_magic_impl(self, ndigits=None):
  1439. if is_constant(self):
  1440. return builtins.round(get_constant(self), ndigits)
  1441. return wrap_node(getattr(self.node, method)(ndigits))
  1442. setattr(user_type, f"__{method}__", round_magic_impl)
  1443. else:
  1444. method_name = method
  1445. if method in bitwise_ops:
  1446. method_name = bitwise_ops[method]
  1447. setattr(user_type, f"__{method_name}__", binary_magic_impl)
  1448. if method in reflectable_magic_methods:
  1449. setattr(user_type, f"__r{method_name}__", rbinary_magic_impl)
# Install the user-facing magic methods on SymBool / SymInt / SymFloat.
for method, func in magic_methods.items():  # type: ignore[assignment]
    # Bool-only methods are installed on SymBool alone.
    if method in only_bool_magic_methods:
        _make_user_magic(method, SymBool)
        continue
    # Float-only methods are installed on SymFloat alone.
    if method in only_float_magic_methods:
        _make_user_magic(method, SymFloat)
        continue
    if method in also_bool_magic_methods or method in bool_becomes_int_magic_methods:
        _make_user_magic(method, SymBool)
    _make_user_magic(method, SymInt)
    # Bitwise ops are not installed on SymFloat.
    if method not in bitwise_ops:
        _make_user_magic(method, SymFloat)

# Remove the loop variables from the module namespace.
del method
del func