dynamic_shapes.py 52 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693694695696697698699700701702703704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862863864865866867868869870871872873874875876877878879880881882883884885886887888889890891892893894895896897898899900901902903904905906907908909910911912913914915916917918919920921922923924925926927928929930931932933934935936937938939940941942943944945946947948949950951952953954955956957958959960961962963964965966967968969970971972973974975976977978979980981982983984985986987988989990991992993994995996997998999100010011002100310041005100610071008100910101011101210131014101510161017101810191020102110221023102410251026102710281029103010311032103310341035103610371038103910401041104210431044104510461047104810491050105110521053105410551056105710581059106010611062106310641065106610671068106910701071107210731074107510761077107810791080108110821083108410851086108710881089109010911092109310941095109610971098109911001101110211031104110511061107110811091110111111121113111411151116111711181119112011211122112311241125112611271128112911301131113211331134113511361137113811391140114111421143114411451146114711481149115011511152115311541155115611571158115911601161116211631164116511661167116811691170117111721173117411751176117711781179118011811182118311841185118611871188118911901191119211931194119511961197119811991200120112021203120412051206120712081209121012111212121312141215121612171218121912201221122212231224122512261227122812291230123112321233123412351236123712381239124012411242124312441245124612471248124912501251125212531254125512561257125812591260126112621263126412651266126712681269127012711272127312741275127612771278127912801281128212831284128512861287128812891290129112921293129412951296129712981299130013011302130313041305130613071308130913101311131213131314131513161317131813191320132113221323132413251326132713281329133013311332133313341335133613371338133913401341134213431344134513461347134813491350135113521353135413551356135713581359
  1. # mypy: allow-untyped-defs
  2. import dataclasses
  3. import inspect
  4. import logging
  5. import sys
  6. from collections import defaultdict
  7. from enum import auto, Enum
  8. from typing import Any, Callable, Optional, TYPE_CHECKING, Union
  9. import torch
  10. from torch.utils._pytree import (
  11. _get_node_type,
  12. BUILTIN_TYPES,
  13. keystr,
  14. LeafSpec,
  15. MappingKey,
  16. SequenceKey,
  17. SUPPORTED_NODES,
  18. tree_flatten,
  19. tree_map,
  20. tree_map_with_path,
  21. )
  22. from .exported_program import ExportedProgram
  23. if TYPE_CHECKING:
  24. from sympy import Symbol
  25. from torch._guards import Source
  26. from torch.fx.experimental.symbolic_shapes import ShapeEnv, StrictMinMaxConstraint
  27. __all__ = [
  28. "Constraint",
  29. "Dim",
  30. "dims",
  31. "refine_dynamic_shapes_from_suggested_fixes",
  32. "AdditionalInputs",
  33. ]
  34. log = logging.getLogger(__name__)
  35. class _DimHintType(Enum):
  36. """
  37. Enum for dynamic shape hints.
  38. - AUTO means automatic inference of shape (static or dynamic).
  39. - STATIC means static shape (always specialized).
  40. - DYNAMIC means dynamic, will error out if specialized.
  41. """
  42. AUTO = auto()
  43. STATIC = auto()
  44. DYNAMIC = auto()
  45. @dataclasses.dataclass
  46. class _DimHint:
  47. type: _DimHintType
  48. min: Optional[int] = None
  49. max: Optional[int] = None
  50. _factory: Optional[bool] = True
  51. @staticmethod
  52. def AUTO():
  53. return _DimHint(_DimHintType.AUTO)
  54. @staticmethod
  55. def DYNAMIC():
  56. return _DimHint(_DimHintType.DYNAMIC)
  57. @staticmethod
  58. def STATIC():
  59. return _DimHint(_DimHintType.STATIC)
  60. def __call__(self, min=None, max=None) -> "_DimHint":
  61. if not self._factory:
  62. raise TypeError(f"'{type(self)}' object is not callable")
  63. assert min is None or min >= 0, "min must be non-negative"
  64. assert max is None or max >= 0, "max must be non-negative"
  65. assert min is None or max is None or min <= max, "min must be <= max"
  66. return _DimHint(self.type, min=min, max=max, _factory=False)
  67. class Dim:
  68. """
  69. The ``Dim`` class allows users to specify dynamism in their exported
  70. programs. By marking a dimension with a ``Dim``, the compiler associates the
  71. dimension with a symbolic integer containing a dynamic range.
  72. The API can be used in 2 ways: Dim hints (i.e. automatic dynamic shapes:
  73. ``Dim.AUTO``, ``Dim.DYNAMIC``, ``Dim.STATIC``), or named Dims (i.e.
  74. ``Dim("name", min=1, max=2)``).
  75. Dim hints provide the lowest barrier to exportability, with the user only
  76. needing to specify if a dimension if dynamic, static, or left for the
  77. compiler to decide (``Dim.AUTO``). The export process will automatically
  78. infer the remaining constraints on min/max ranges and relationships between
  79. dimensions.
  80. Example::
  81. class Foo(nn.Module):
  82. def forward(self, x, y):
  83. assert x.shape[0] == 4
  84. assert y.shape[0] >= 16
  85. return x @ y
  86. x = torch.randn(4, 8)
  87. y = torch.randn(8, 16)
  88. dynamic_shapes = {
  89. "x": {0: Dim.AUTO, 1: Dim.AUTO},
  90. "y": {0: Dim.AUTO, 1: Dim.AUTO},
  91. }
  92. ep = torch.export(Foo(), (x, y), dynamic_shapes=dynamic_shapes)
  93. Here, export would raise an exception if we replaced all uses of ``Dim.AUTO`` with ``Dim.DYNAMIC``,
  94. as ``x.shape[0]`` is constrained to be static by the model.
  95. More complex relations between dimensions may also be codegened as runtime assertion nodes by the compiler,
  96. e.g. ``(x.shape[0] + y.shape[1]) % 4 == 0``, to be raised if runtime inputs do not satisfy such constraints.
  97. You may also specify min-max bounds for Dim hints, e.g. ``Dim.AUTO(min=16, max=32)``, ``Dim.DYNAMIC(max=64)``,
  98. with the compiler inferring the remaining constraints within the ranges. An exception will be raised if
  99. the valid range is entirely outside the user-specified range.
  100. Named Dims provide a stricter way of specifying dynamism, where exceptions are raised if the compiler
  101. infers constraints that do not match the user specification. For example, exporting the previous
  102. model, the user would need the following ``dynamic_shapes`` argument::
  103. s0 = Dim("s0")
  104. s1 = Dim("s1", min=16)
  105. dynamic_shapes = {
  106. "x": {0: 4, 1: s0},
  107. "y": {0: s0, 1: s1},
  108. }
  109. ep = torch.export(Foo(), (x, y), dynamic_shapes=dynamic_shapes)
  110. Named Dims also allow specification of relationships between dimensions, up
  111. to univariate linear relations. For example, the following indicates one
  112. dimension is a multiple of another plus 4::
  113. s0 = Dim("s0")
  114. s1 = 3 * s0 + 4
  115. """
  116. AUTO = _DimHint.AUTO()
  117. DYNAMIC = _DimHint.DYNAMIC()
  118. STATIC = _DimHint.STATIC()
  119. def __init__(
  120. self, name: str, *, min: Optional[int] = None, max: Optional[int] = None
  121. ):
  122. from torch.utils._sympy.numbers import int_oo
  123. _min = 0 if min is None else min
  124. _max = int_oo if max is None else max
  125. assert _max > _min, f"Cannot create Dim with inconsistent min={min}, max={max}"
  126. assert name.isidentifier(), f"Dim name must be a valid identifier, got {name}"
  127. self.__name__ = name
  128. self.min = _min
  129. self.max = _max
  130. def __add__(self, other) -> "Dim":
  131. # e.g., dim + 1
  132. if type(other) is not int:
  133. raise NotImplementedError(
  134. f"Attempted to add {other} to {self.__name__}, where an integer was expected. "
  135. "(Only increasing linear operations with integer coefficients are supported.)"
  136. )
  137. return self._derive(lambda x: x + other)
  138. def __radd__(self, other) -> "Dim":
  139. return self + other
  140. def __sub__(self, other) -> "Dim":
  141. # e.g., dim - 1
  142. if type(other) is not int:
  143. raise NotImplementedError(
  144. f"Attempted to subtract {other} from {self.__name__}, where an integer was expected. "
  145. "(Only increasing linear operations with integer coefficients are supported.)"
  146. )
  147. return self._derive(lambda x: x - other)
  148. def __rsub__(self, other) -> "Dim":
  149. raise NotImplementedError(
  150. f"Attempted to negate {self.__name__}. "
  151. "(Only increasing linear operations with integer coefficients are supported.)"
  152. )
  153. def __mul__(self, other) -> "Dim":
  154. # e.g., dim * 2
  155. if type(other) is not int or other <= 0:
  156. raise NotImplementedError(
  157. f"Attempted to multiply {other} with {self.__name__}, where a positive integer was expected. "
  158. "(Only increasing linear operations with integer coefficients are supported.)"
  159. )
  160. return self._derive(lambda x: x * other)
  161. def __rmul__(self, other) -> "Dim":
  162. return self * other
  163. def _derived_name(self, fn) -> str:
  164. from sympy import sympify
  165. return str(fn(sympify(self.__name__)))
  166. def _derive(self, fn) -> "Dim":
  167. return _DerivedDim(self._derived_name(fn), self, fn)
  168. @staticmethod
  169. def _readable(name: str, min_: int, max_: int) -> str:
  170. from torch.utils._sympy.numbers import int_oo
  171. if min_ == 2:
  172. min_ = None # type: ignore[assignment]
  173. if max_ == int_oo:
  174. max_ = None # type: ignore[assignment]
  175. if min_ is None and max_ is None:
  176. return f"Dim('{name}')"
  177. if min_ is None:
  178. return f"Dim('{name}', max={max_})"
  179. if max_ is None:
  180. return f"Dim('{name}', min={min_})"
  181. return f"Dim('{name}', min={min_}, max={max_})"
  182. def __repr__(self):
  183. return Dim._readable(self.__name__, self.min, self.max)
  184. _Dim = Dim # TODO(pianpwk): remove after it's no longer internally breaking
  185. class _StaticDim(Dim):
  186. """
  187. Class for static :func:`Dim` types.
  188. This class is only for setting and checking static dim constraints,
  189. and the user should never interact with it.
  190. """
  191. def __init__(self, value: int):
  192. self.__name__ = str(value)
  193. self.value = value
  194. @property
  195. def min(self): # type: ignore[override]
  196. return self.value # type: ignore[attr-defined]
  197. @property
  198. def max(self): # type: ignore[override]
  199. return self.value # type: ignore[attr-defined]
  200. class _DerivedDim(Dim):
  201. """
  202. Class for derived :func:`Dim` types.
  203. Currently we only support increasing linear expressions with integer coefficients.
  204. In other words, a derived Dim can always be written in the form Ax + B, where
  205. x is a regular Dim (i.e., non-derived Dim), A and B are integers, and A is positive.
  206. (In particular, the latter ensures that x < y => Ax + B < Ay + B.)
  207. These restrictions on the form of derived Dims makes the metatheory simpler: e.g.,
  208. it simplifies computing ranges for derived Dims, solving for underlying regular Dims,
  209. deciding equalities between derived Dims, and so on.
  210. The function lambda x: Ax + B is expressed by `fn`, where x is a normal Dim, `root`.
  211. The range of a derived Dim is computed by mapping `fn` over the range of its `root`.
  212. """
  213. def __init__(self, name: str, root: Dim, fn: Callable):
  214. self.__name__ = name
  215. self.root = root
  216. self.fn = fn
  217. @property
  218. def min(self): # type: ignore[override]
  219. # assume that self.fn is an increasing function
  220. # TODO(avik): use sympy value range analysis instead?
  221. from sympy import Integer
  222. from torch.utils._sympy.numbers import int_oo
  223. if self.root.min is -int_oo: # type: ignore[attr-defined]
  224. return -int_oo # fn not needed cuz increasing
  225. _min_symint = self.fn(Integer(self.root.min)) # type: ignore[attr-defined]
  226. root = self.root # type: ignore[attr-defined]
  227. assert _min_symint >= 0, (
  228. f"Expected derived min value of {self.__name__} to be >= 0. "
  229. f"Please specify an appropriate min value for {root.__name__} "
  230. f"(currently {root.min})."
  231. )
  232. return int(_min_symint)
  233. @property
  234. def max(self): # type: ignore[override]
  235. # assume that self.fn is an increasing function
  236. # TODO(avik): use sympy value range analysis instead?
  237. from sympy import Integer
  238. from torch.utils._sympy.numbers import int_oo
  239. if self.root.max is int_oo: # type: ignore[attr-defined]
  240. return int_oo # fn not needed cuz increasing
  241. _max_symint = self.fn(Integer(self.root.max)) # type: ignore[attr-defined]
  242. root = self.root # type: ignore[attr-defined]
  243. assert _max_symint <= sys.maxsize - 1, (
  244. f"Expected derived max value of {self.__name__} to be <= {sys.maxsize - 1}. "
  245. f"Please specify an appropriate max value for {root.__name__} "
  246. f"(currently {root.max})."
  247. )
  248. return int(_max_symint)
  249. def _derive(self, fn):
  250. # We support nesting, e.g., 2*dim + 1.
  251. # This is implemented by composing operations on the same root.
  252. # As a consequence, roots are always regular Dims (i.e., not derived Dims).
  253. return _DerivedDim(
  254. self._derived_name(fn),
  255. self.root,
  256. lambda x: fn(self.fn(x)),
  257. )
  258. def __repr__(self):
  259. return self.__name__
  260. def dims(
  261. *names: str, min: Optional[int] = None, max: Optional[int] = None
  262. ) -> tuple[Dim, ...]:
  263. """
  264. Util to create multiple :func:`Dim` types.
  265. Returns:
  266. A tuple of :func:`Dim` types.
  267. """
  268. return tuple(Dim(name, min=min, max=max) for name in names) # type: ignore[misc]
  269. @dataclasses.dataclass
  270. class _ConstraintTarget:
  271. """
  272. This represents input tensor dimensions.
  273. """
  274. t_id: int
  275. dim: int
  276. @dataclasses.dataclass
  277. class _Constraint(_ConstraintTarget):
  278. """
  279. This represents a Dim describing a constraint target.
  280. `name` is the name of the Dim.
  281. `constraint_range` contains the min/max bounds of the Dim.
  282. """
  283. name: str
  284. constraint_range: "StrictMinMaxConstraint"
  285. def _clone_with_range(self, lower=0, upper=None):
  286. # Import sympy locally
  287. from torch.fx.experimental.symbolic_shapes import StrictMinMaxConstraint
  288. from torch.utils._sympy.numbers import int_oo
  289. from torch.utils._sympy.value_ranges import ValueRanges
  290. if upper is None:
  291. upper = int_oo
  292. constraint_range = StrictMinMaxConstraint(
  293. vr=self.constraint_range.vr & ValueRanges(lower=lower, upper=upper),
  294. warn_only=False,
  295. )
  296. return _Constraint(
  297. self.t_id,
  298. self.dim,
  299. self.name,
  300. constraint_range,
  301. )
  302. def __ge__(self, lower):
  303. return self._clone_with_range(lower=lower)
  304. def __gt__(self, lower):
  305. return self._clone_with_range(lower=lower + 1)
  306. def __le__(self, upper):
  307. return self._clone_with_range(upper=upper)
  308. def __lt__(self, upper):
  309. return self._clone_with_range(upper=upper - 1)
  310. def __bool__(self):
  311. # NOTE(avik): We do not support compound expressions like a <= x <= b.
  312. # This is because Python implicitly desugars them into bool(a <= x) and bool(x <= b),
  313. # and moreover, enforces that any overload of __bool__ must return True or False.
  314. # FWIW, sympy also raises TypeError in this case.
  315. raise TypeError(
  316. "Cannot determine truth value of _Constraint. "
  317. "If you are trying to combine _Constraint's with logical connectives, "
  318. "you can specify them separately instead."
  319. )
  320. @property
  321. def serializable_spec(self):
  322. # We need a serialization compatible format of the constraint so that it
  323. # can be savedin the graph module w/o breaking the module serialization.
  324. # The saved constraints will be used directly for the post-exporting pass
  325. # that converts constraints to runtime assertion. The saved constraints
  326. # will not be saved in the serialized module.
  327. # TODO: A better way is needed. Currently we use 't_id' to map the constraint,
  328. # which is not reliable
  329. return {
  330. "t_id": self.t_id,
  331. "dim": self.dim,
  332. "min": self.constraint_range.vr.lower,
  333. "max": self.constraint_range.vr.upper,
  334. }
  335. @dataclasses.dataclass
  336. class _PhantomRoot:
  337. """
  338. This represents the root of a derived Dim where the root does not directly
  339. specify the shape of any input dimension, but the derived Dim does.
  340. e.g., the input shapes 2*dim and dim + 1 are related via a "phantom" dim.
  341. The fields `name`, `constraint_range`, and `val` carried by a phantom root
  342. help create a symbol for it. Any derived dims with this phantom root are
  343. backed by expressions over this symbol.
  344. """
  345. name: str
  346. constraint_range: "StrictMinMaxConstraint"
  347. val: int
  348. @dataclasses.dataclass
  349. class _DerivedConstraint(_ConstraintTarget):
  350. """
  351. This represents a derived Dim, whose root is either a regular constraint target
  352. (which directly specifies the shape of some input dimension) or a phantom root
  353. (which does so indirectly).
  354. It can be thought of as a subclass of `_Constraint`, except that it does not
  355. support <, <=, >, >= operations.
  356. """
  357. name: str
  358. constraint_range: "StrictMinMaxConstraint"
  359. root: Union[_ConstraintTarget, _PhantomRoot]
  360. fn: Callable
  361. @property
  362. def serializable_spec(self):
  363. # same as _Constraint.serializable_spec
  364. return {
  365. "t_id": self.t_id,
  366. "dim": self.dim,
  367. "min": self.constraint_range.vr.lower,
  368. "max": self.constraint_range.vr.upper,
  369. }
  370. @dataclasses.dataclass
  371. class _RelaxedConstraint(_ConstraintTarget):
  372. """
  373. This represents a dim marked with Dim.AUTO/DYNAMIC (i.e. mark_dynamic() or maybe_mark_dynamic()),
  374. which leaves relations & min/max ranges for inference, instead of requiring explicit specification.
  375. The intention is for constraint violations to not be raised if produce_guards() finds equalities or
  376. relations between a _RelaxedConstraint and another type of _Constraint.
  377. """
  378. @property
  379. def serializable_spec(self):
  380. return {
  381. "t_id": self.t_id,
  382. "dim": self.dim,
  383. }
  384. Constraint = Union[_Constraint, _DerivedConstraint, _RelaxedConstraint]
  385. @dataclasses.dataclass
  386. class _IntWrapper:
  387. """
  388. Dummy wrapper class to wrap around integer inputs so that when we parse the
  389. dynamic_shapes structure, we can mark if any of the integers were marked as
  390. dynamic.
  391. """
  392. val: int
  393. # Disallow specifying dynamism
  394. dynamism: Optional[Union[_DimHint, int]] = dataclasses.field(
  395. init=False, default=None
  396. )
  397. def _process_equalities(
  398. constraint: Constraint,
  399. get_sources: Callable[[int, int], list["Source"]],
  400. shape_env: "ShapeEnv",
  401. names: dict[str, tuple[int, int]],
  402. source_pairs: list[tuple["Source", "Source"]],
  403. derived_equalities: list[tuple["Source", Union["Source", "Symbol"], Callable]],
  404. phantom_symbols: dict[str, "Symbol"],
  405. relaxed_sources: set["Source"],
  406. ):
  407. """
  408. Updates `source_pairs`, `derived_equalities`, and `phantom_symbols` (which become
  409. fields of `EqualityConstraint`) based on a given input `constraint`.
  410. """
  411. sources = get_sources(constraint.t_id, constraint.dim)
  412. if not sources: # empty sources due to unused shapes
  413. return
  414. source, *other_sources = sources
  415. # When t.size()[dim] maps to src0, src1, ..., srcN, we add
  416. # constraints that make src0 "equal" to src1, ..., srcN.
  417. source_pairs.extend((source, other_source) for other_source in other_sources)
  418. if isinstance(constraint, _Constraint):
  419. if constraint.name in names:
  420. shared_t_id, shared_dim = names[constraint.name]
  421. other_sources = get_sources(shared_t_id, shared_dim)
  422. source_pairs.extend(
  423. (source, other_source) for other_source in other_sources
  424. )
  425. else:
  426. names[constraint.name] = (constraint.t_id, constraint.dim)
  427. elif isinstance(constraint, _DerivedConstraint):
  428. # branch based on the root of the _DerivedConstraint
  429. if not isinstance(constraint.root, _PhantomRoot):
  430. # either root points to an input source
  431. root = get_sources(constraint.root.t_id, constraint.root.dim)[0]
  432. else:
  433. # or root points to a phantom symbol
  434. if constraint.root.name in phantom_symbols:
  435. root = phantom_symbols[constraint.root.name]
  436. else:
  437. # create a phantom symbol in the shape env based on the _PhantomRoot
  438. root = shape_env.create_symbol(
  439. val=constraint.root.val,
  440. source=torch._dynamo.source.ConstantSource(constraint.root.name),
  441. dynamic_dim=torch.fx.experimental.symbolic_shapes.DimDynamic.DYNAMIC,
  442. constraint_dim=constraint.root.constraint_range,
  443. )
  444. phantom_symbols[constraint.root.name] = root
  445. fn = constraint.fn
  446. # A derived equality (source, root, fn) informally corresponds to source = fn(root).
  447. # Here source describes an input and root might describe another input or a phantom symbol.
  448. derived_equalities.append((source, root, fn))
  449. elif isinstance(constraint, _RelaxedConstraint):
  450. relaxed_sources.add(source)
  451. def _tree_map_with_path(
  452. func: Callable[..., Any],
  453. tree: Any,
  454. *dynamic_shapes: Any,
  455. tree_name: Optional[str] = None,
  456. ) -> Any:
  457. """
  458. Customized tree_map for mapping pytrees to dynamic_shapes.
  459. For built-in types (e.g., standard collections) this behaves exactly like tree_map.
  460. OTOH for a user-defined class C registered with pytree, we cannot assume that a C
  461. containing tensors can be mapped to a C containing dynamic shapes (i.e., C may not
  462. be a polymorphic container). In that case we use the flattened form of C instead.
  463. Thus a C(**tensors) that flattens to (**tensors) will map to (**dynamic_shapes).
  464. Args:
  465. func: function to apply to each (int, float, str, bool, None, torch.Tensor)
  466. tree: input pytree
  467. dynamic_shapes: zero or more (typically one) dynamic_shapes to match
  468. Returns:
  469. output pytree mapping func to each (int, float, str, bool, None, torch.Tensor)
  470. """
  471. def is_leaf(t):
  472. # BUILTIN_TYPES is a subset of SUPPORTED_NODES, the latter being all types
  473. # registered with pytree. Types *not* in BUILTIN_TYPES include primitive types
  474. # (int, float, str, bool, None, torch.Tensor), which are not in SUPPORTED_NODES,
  475. # as well as user-defined classes registered with pytree, which are.
  476. return _get_node_type(t) not in BUILTIN_TYPES
  477. def f(path, t, *dynamic_shapes):
  478. typ = _get_node_type(t)
  479. # typ is not in BUILTIN_TYPES
  480. if typ in SUPPORTED_NODES:
  481. # thus typ is a user-defined class registered with pytree,
  482. # in which case flatten and recurse
  483. return tree_map_with_path(
  484. f,
  485. SUPPORTED_NODES[typ].flatten_fn(t)[0],
  486. *dynamic_shapes,
  487. is_leaf=is_leaf,
  488. )
  489. else:
  490. return func(path, t, *dynamic_shapes)
  491. try:
  492. return tree_map_with_path(f, tree, *dynamic_shapes, is_leaf=is_leaf)
  493. except ValueError as e:
  494. if "mismatch" in e.args[0]:
  495. # When PyTree finds a structural mismatch between tree and dynamic_shapes,
  496. # the error message is unfortunately quite horrible. Let's fix that.
  497. assert dynamic_shapes, "Cannot be a mismatch if there is no dynamic_shapes"
  498. assert tree_name, "Must provide a tree_name when there might be a mismatch"
  499. def _key(type_, context, i):
  500. # derive a PyTree key given the type, context, and child # of a TreeSpec
  501. if type_ is dict:
  502. return MappingKey(context[i])
  503. if type_ in (list, tuple):
  504. assert context is None
  505. return SequenceKey(i)
  506. raise AssertionError(f"Did not expect type {type_}")
  507. def raise_mismatch_error(msg):
  508. from torch._dynamo.exc import UserError, UserErrorType
  509. raise UserError(
  510. UserErrorType.INVALID_INPUT,
  511. f"Detected mismatch between the structure of `{tree_name}` and `dynamic_shapes`: {msg}",
  512. case_name="dynamic_shapes_validation",
  513. )
  514. def _compare(tree, dynamic_shapes, path):
  515. # raise an error at the point where tree and dynamic_shapes differ,
  516. # including the path to that point and the reason for the difference
  517. rendered_path = keystr(path)
  518. if isinstance(tree, LeafSpec):
  519. return
  520. if isinstance(dynamic_shapes, LeafSpec):
  521. raise_mismatch_error(
  522. f"`{tree_name}{rendered_path}` is a {tree.type}, "
  523. f"but `dynamic_shapes{rendered_path}` is not"
  524. )
  525. if tree.type != dynamic_shapes.type:
  526. raise_mismatch_error(
  527. f"`{tree_name}{rendered_path}` is a {tree.type}, "
  528. f"but `dynamic_shapes{rendered_path}` is a {dynamic_shapes.type}"
  529. )
  530. if len(tree.children_specs) != len(dynamic_shapes.children_specs):
  531. raise_mismatch_error(
  532. f"`{tree_name}{rendered_path}` has {len(tree.children_specs)} elements, "
  533. f"but `dynamic_shapes{rendered_path}` has {len(dynamic_shapes.children_specs)} elements"
  534. )
  535. if tree.type is dict:
  536. # context, children could be out of order
  537. if sorted(tree.context) != sorted(dynamic_shapes.context):
  538. raise_mismatch_error(
  539. f"`{tree_name}{rendered_path}` has keys {tree.context}, "
  540. f"but `dynamic_shapes{rendered_path}` has keys {dynamic_shapes.context}"
  541. )
  542. _remap = dict(
  543. zip(dynamic_shapes.context, dynamic_shapes.children_specs)
  544. )
  545. dynamic_shapes_children_specs = [_remap[k] for k in tree.context]
  546. else:
  547. dynamic_shapes_children_specs = dynamic_shapes.children_specs
  548. for i, (tree_, dynamic_shapes_) in enumerate(
  549. zip(tree.children_specs, dynamic_shapes_children_specs)
  550. ):
  551. _compare(
  552. tree_,
  553. dynamic_shapes_,
  554. path + [_key(tree.type, tree.context, i)],
  555. )
  556. _, tree_spec = tree_flatten(tree, is_leaf=is_leaf)
  557. for other_tree in dynamic_shapes:
  558. _, other_tree_spec = tree_flatten(other_tree, is_leaf)
  559. _compare(tree_spec, other_tree_spec, [])
  560. raise
  561. def _combine_args(f, args, kwargs) -> dict[str, Any]:
  562. # combine args and kwargs following the signature of f, as it happens
  563. # in the body of f when called with *args, **kwargs
  564. if isinstance(f, ExportedProgram):
  565. f = f.module()
  566. signature = (
  567. inspect.signature(f.forward)
  568. if isinstance(f, torch.nn.Module)
  569. else inspect.signature(f)
  570. )
  571. kwargs = kwargs if kwargs is not None else {}
  572. return signature.bind(*args, **kwargs).arguments
  573. class ShapesCollection:
  574. """
  575. Builder for dynamic_shapes.
  576. Used to assign dynamic shape specifications to tensors that appear in inputs.
  577. This is useful particularly when :func:`args` is a nested input structure, and it's
  578. easier to index the input tensors, than to replicate the structure of :func:`args` in
  579. the :func:`dynamic_shapes` specification.
  580. Example::
  581. args = {"x": tensor_x, "others": [tensor_y, tensor_z]}
  582. dim = torch.export.Dim(...)
  583. dynamic_shapes = torch.export.ShapesCollection()
  584. dynamic_shapes[tensor_x] = (dim, dim + 1, 8)
  585. dynamic_shapes[tensor_y] = {0: dim * 2}
  586. # This is equivalent to the following (now auto-generated):
  587. # dynamic_shapes = {"x": (dim, dim + 1, 8), "others": [{0: dim * 2}, None]}
  588. torch.export(..., args, dynamic_shapes=dynamic_shapes)
  589. To specify dynamism for integers, we need to first wrap the integers using
  590. _IntWrapper so that we have a "unique identification tag" for each integer.
  591. Example::
  592. args = {"x": tensor_x, "others": [int_x, int_y]}
  593. # Wrap all ints with _IntWrapper
  594. mapped_args = pytree.tree_map_only(int, lambda a: _IntWrapper(a), args)
  595. dynamic_shapes = torch.export.ShapesCollection()
  596. dynamic_shapes[tensor_x] = (dim, dim + 1, 8)
  597. dynamic_shapes[mapped_args["others"][0]] = Dim.DYNAMIC
  598. # This is equivalent to the following (now auto-generated):
  599. # dynamic_shapes = {"x": (dim, dim + 1, 8), "others": [Dim.DYNAMIC, None]}
  600. torch.export(..., args, dynamic_shapes=dynamic_shapes)
  601. """
  602. def __init__(self):
  603. self._shapes = {}
  604. def __setitem__(self, t, shape):
  605. assert isinstance(t, (torch.Tensor, _IntWrapper)), (
  606. f"Cannot assign shape to non-tensor or non-_IntWrapper type {type(t)}"
  607. )
  608. # TODO(avik): check that shape is indeed a Shape
  609. t_id = id(t)
  610. if t_id in self._shapes:
  611. _shape = self._shapes[t_id]
  612. assert shape == _shape, (
  613. f"Shapes assigned to input do not match: expected {_shape}, got {shape}"
  614. )
  615. else:
  616. self._shapes[id(t)] = shape
  617. def __getitem__(self, t):
  618. t_id = id(t)
  619. if t_id not in self._shapes:
  620. self._shapes[t_id] = {}
  621. return self._shapes[t_id]
  622. def __len__(self):
  623. return len(self._shapes)
  624. def dynamic_shapes(self, m, args, kwargs=None):
  625. """
  626. Generates the :func:`dynamic_shapes` pytree structure according to :func:`args` and :func:`kwargs`.
  627. """
  628. t_ids = set()
  629. def find_shape(path, t):
  630. t_id = id(t)
  631. if t_id in self._shapes:
  632. t_ids.add(t_id)
  633. return self._shapes[t_id]
  634. else:
  635. return None
  636. combined_args = _combine_args(m, args, kwargs)
  637. dynamic_shapes = _tree_map_with_path(find_shape, combined_args)
  638. if any(t_id not in t_ids for t_id in self._shapes):
  639. raise ValueError(
  640. "Some tensors that were assigned shapes were not found in args. "
  641. "Maybe such tensors were copied when passing them as args? "
  642. "Maybe such tensors are contained in classes that were not registered with pytree?"
  643. )
  644. return dynamic_shapes
  645. class AdditionalInputs:
  646. """
  647. Infers dynamic_shapes based on additional inputs.
  648. This is useful particularly for deployment engineers who, on the one hand, may
  649. have access to ample testing or profiling data that can provide a fair sense of
  650. representative inputs for a model, but on the other hand, may not know enough
  651. about the model to guess which input shapes should be dynamic.
  652. Input shapes that are different than the original are considered dynamic; conversely,
  653. those that are the same as the original are considered static. Moreover, we verify
  654. that the additional inputs are valid for the exported program. This guarantees that
  655. tracing with them instead of the original would have generated the same graph.
  656. Example::
  657. args0, kwargs0 = ... # example inputs for export
  658. # other representative inputs that the exported program will run on
  659. dynamic_shapes = torch.export.AdditionalInputs()
  660. dynamic_shapes.add(args1, kwargs1)
  661. ...
  662. dynamic_shapes.add(argsN, kwargsN)
  663. torch.export(..., args0, kwargs0, dynamic_shapes=dynamic_shapes)
  664. """
  665. def __init__(self):
  666. self._examples = []
  667. def add(self, args, kwargs=None):
  668. """
  669. Additional input :func:`args` and :func:`kwargs`.
  670. """
  671. assert type(args) is tuple, f"Representative args {args} must be a tuple"
  672. assert kwargs is None or type(kwargs) is dict, (
  673. f"Representative kwargs {kwargs} must be None or a dict"
  674. )
  675. self._examples.append((args, kwargs))
  676. def dynamic_shapes(self, m, args, kwargs=None):
  677. """
  678. Infers a :func:`dynamic_shapes` pytree structure by merging shapes of the
  679. original input :func:`args` and :func:`kwargs` and of each additional input
  680. args and kwargs.
  681. """
  682. dynamic_shapes, *other_dynamic_shapes = [
  683. _tree_map_with_path(
  684. lambda path, t: tuple(t.shape) if isinstance(t, torch.Tensor) else t,
  685. _combine_args(m, args, kwargs),
  686. )
  687. for args, kwargs in [(args, kwargs), *self._examples]
  688. ]
  689. def _mark_dynamism(v, *other_vs):
  690. if not all(type(v) == type(other) for other in other_vs):
  691. raise ValueError(
  692. "The following inputs were found to have differing types, "
  693. f"so they cannot be marked as dynamic: {(v,) + other_vs}."
  694. )
  695. if isinstance(v, int) and not isinstance(v, bool):
  696. if all(other_v == v for other_v in other_vs):
  697. return None
  698. else:
  699. return Dim.DYNAMIC
  700. else:
  701. if not all(other_v == v for other_v in other_vs):
  702. raise ValueError(
  703. "The following inputs were found to have differing values, "
  704. f"but they cannot be marked as dynamic: {(v,) + other_vs}."
  705. )
  706. return None
  707. return tree_map(
  708. _mark_dynamism,
  709. dynamic_shapes,
  710. *other_dynamic_shapes,
  711. is_leaf=lambda i: type(i) is int,
  712. )
  713. def verify(self, ep):
  714. """
  715. Verifies that an exported program is valid for each additional input.
  716. """
  717. epm = ep.module()
  718. for args, kwargs in self._examples:
  719. torch.export._unlift._check_input_constraints_for_module(
  720. epm, args, kwargs or {}
  721. )
  722. def _warn_on_None_dynamic_shape_dimension():
  723. msg = (
  724. "Using None as a dynamic shape dimension is deprecated. "
  725. "Please use Dim.STATIC instead"
  726. )
  727. # TODO(avik): raise an error in the future
  728. log.warning(msg)
  729. def _check_dynamic_shapes(
  730. combined_args: dict[str, Any],
  731. dynamic_shapes: Union[dict[str, Any], tuple[Any], list[Any], None],
  732. ):
  733. """
  734. Checks the dynamic_shapes specification for correctness,
  735. using combined args + kwargs as reference for inputs structure.
  736. """
  737. from torch._dynamo.exc import UserError, UserErrorType
  738. if dynamic_shapes is None or len(dynamic_shapes) == 0:
  739. return
  740. if isinstance(dynamic_shapes, (tuple, list)):
  741. combined_args = type(dynamic_shapes)(combined_args.values()) # type: ignore[assignment, misc]
  742. bounds: dict[str, tuple[int, int]] = {}
  743. def check_same_bounds(dim):
  744. if dim.__name__ in bounds:
  745. min_, max_ = bounds[dim.__name__]
  746. if dim.min != min_ or dim.max != max_:
  747. this_ = Dim._readable(dim.__name__, min_, max_)
  748. that_ = Dim._readable(dim.__name__, dim.min, dim.max)
  749. raise UserError(
  750. UserErrorType.INVALID_INPUT,
  751. f"Found different definitions {this_} and {that_} "
  752. f"for the same symbolic dimension {dim}!",
  753. )
  754. else:
  755. bounds[dim.__name__] = (dim.min, dim.max)
  756. def check_symbols(path, tensor, shape):
  757. if isinstance(shape, dict):
  758. for i, dim in shape.items():
  759. if isinstance(dim, Dim):
  760. check_same_bounds(dim)
  761. elif dim is None:
  762. _warn_on_None_dynamic_shape_dimension()
  763. elif not (isinstance(dim, (int, _DimHint))):
  764. raise UserError(
  765. UserErrorType.INVALID_INPUT,
  766. f"Unexpected dimension mapped to index {i} in input tensor shape {shape} "
  767. f"specified at `dynamic_shapes{keystr(path)}` "
  768. f"(expected None, an int, a Dim, Dim.AUTO, Dim.STATIC, or Dim.DYNAMIC, "
  769. f" but got {dim!r} instead)",
  770. case_name="dynamic_shapes_validation",
  771. )
  772. elif isinstance(shape, (tuple, list)):
  773. if len(shape) != len(tensor.shape):
  774. raise UserError(
  775. UserErrorType.INVALID_INPUT,
  776. f"Expected dynamic shape spec {shape} specified at `dynamic_shapes{keystr(path)}` "
  777. f"to have the same length as the actual tensor shape {tensor.shape} "
  778. f"(expected {len(tensor.shape)}, but got {len(shape)} instead)",
  779. case_name="dynamic_shapes_validation",
  780. )
  781. for i, dim in enumerate(shape):
  782. if isinstance(dim, Dim):
  783. check_same_bounds(dim)
  784. elif dim is None:
  785. _warn_on_None_dynamic_shape_dimension()
  786. elif not (isinstance(dim, (int, _DimHint))):
  787. raise UserError(
  788. UserErrorType.INVALID_INPUT,
  789. f"Unexpected dimension #{i} in input tensor shape {shape} "
  790. f"specified at `dynamic_shapes{keystr(path)}` "
  791. f"(expected None, an int, a Dim, Dim.AUTO, Dim.STATIC, or Dim.DYNAMIC, "
  792. f"but got {dim!r} instead)",
  793. case_name="dynamic_shapes_validation",
  794. )
  795. elif shape is not None:
  796. raise UserError(
  797. UserErrorType.INVALID_INPUT,
  798. f"Unexpected input tensor shape {shape} specified at `dynamic_shapes{keystr(path)}` "
  799. f"(expected either a list/tuple of dimensions, or a dict mapping indices to dimensions,"
  800. f" where each dimension is an int, a Dim, Dim.AUTO, Dim.STATIC, or Dim.DYNAMIC)",
  801. case_name="dynamic_shapes_validation",
  802. )
  803. assert isinstance(dynamic_shapes, (dict, tuple, list))
  804. if isinstance(dynamic_shapes, dict):
  805. got_keys = list(dynamic_shapes.keys())
  806. expected_arg_names = list(combined_args.keys())
  807. if sorted(got_keys) != sorted(expected_arg_names):
  808. msg = (
  809. f"When `dynamic_shapes` is specified as a dict, its top-level keys "
  810. f"must be the arg names {expected_arg_names} of `inputs`, but "
  811. f"here they are {got_keys}. "
  812. )
  813. if (
  814. len(combined_args) == 1
  815. and expected_arg_names[0] not in got_keys
  816. and isinstance(combined_args[expected_arg_names[0]], dict)
  817. ):
  818. msg += (
  819. "Since here `inputs` is a list/tuple enclosing a single dict, "
  820. "maybe you just forgot to enclose `dynamic_shapes` in a list/tuple?"
  821. )
  822. else:
  823. msg += (
  824. "Alternatively, you could also ignore arg names entirely "
  825. "and specify `dynamic_shapes` as a list/tuple matching `inputs`."
  826. )
  827. raise UserError(
  828. UserErrorType.INVALID_INPUT, msg, case_name="dynamic_shapes_validation"
  829. )
  830. def check_shape(path, t, dynamic_shape):
  831. if isinstance(t, torch.Tensor):
  832. check_symbols(path, t, dynamic_shape)
  833. elif isinstance(t, _IntWrapper):
  834. if isinstance(dynamic_shape, _Dim):
  835. raise ValueError(
  836. "Unable to specify input integers as dynamic through named "
  837. "Dims. Please use Dim.AUTO/DYNAMIC instead."
  838. )
  839. assert dynamic_shape is None or isinstance(dynamic_shape, (int, _DimHint))
  840. else:
  841. if dynamic_shape is not None:
  842. rendered_path = keystr(path)
  843. raise UserError(
  844. UserErrorType.INVALID_INPUT,
  845. f"Cannot associate shape {dynamic_shape} specified at `dynamic_shapes{rendered_path}` "
  846. f"to non-tensor type {type(t)} at `inputs{rendered_path}` (expected None)",
  847. case_name="dynamic_shapes_validation",
  848. )
  849. _tree_map_with_path(check_shape, combined_args, dynamic_shapes, tree_name="inputs")
  850. def _process_dynamic_shapes(
  851. combined_args: dict[str, Any],
  852. dynamic_shapes: Union[dict[str, Any], tuple[Any], list[Any], None],
  853. ) -> list[Constraint]:
  854. """
  855. Reads the dynamic_shapes specification and produces a list of constraints.
  856. """
  857. from torch._dynamo.exc import UserError, UserErrorType
  858. if dynamic_shapes is None or len(dynamic_shapes) == 0:
  859. # we run with dynamic by default, so no need to produce constraints
  860. return []
  861. if isinstance(dynamic_shapes, (tuple, list)):
  862. combined_args = type(dynamic_shapes)(combined_args.values()) # type: ignore[assignment, misc]
  863. # map of Dim names representing input shape dimensions to constraints on them
  864. symbols: dict[str, list[Constraint]] = defaultdict(list)
  865. # track roots that do not directly represent input shape dimensions
  866. phantom_roots: dict[str, _PhantomRoot] = {}
  867. derived_constraints_with_phantom_root: list[_DerivedConstraint] = []
  868. # list of constraints to return
  869. constraints: list[Constraint] = []
  870. def to_constraint(dim, tensor, i):
  871. import sympy
  872. from torch.fx.experimental.symbolic_shapes import StrictMinMaxConstraint
  873. from torch.utils._sympy.solve import try_solve
  874. from torch.utils._sympy.value_ranges import ValueRanges
  875. def root_value():
  876. # given tensor.shape[i] is the value of dim = fn(root),
  877. # find the value of root
  878. symbol = sympy.Symbol(dim.root.__name__, integer=True)
  879. expr = dim.fn(symbol)
  880. solution = try_solve(sympy.Eq(expr, tensor.shape[i]), symbol)
  881. if solution is not None:
  882. return int(solution[1])
  883. else:
  884. raise UserError( # noqa: B904
  885. UserErrorType.CONSTRAINT_VIOLATION,
  886. f"Expected shape[{i}] = {tensor.shape[i]} of input Tensor to be "
  887. f"of the form {expr}, where {symbol} is an integer",
  888. )
  889. if isinstance(dim, _DerivedDim):
  890. # generate a _DerivedConstraint where the root is:
  891. # - either a _ConstraintTarget (if dim.root directly describes an input shape)
  892. # - or a _PhantomRoot (otherwise)
  893. dim_root = dim.root # type: ignore[attr-defined]
  894. if dim_root.__name__ in symbols:
  895. # root represents an input shape dimension
  896. root_constraint = symbols[dim_root.__name__][0]
  897. root = _ConstraintTarget(
  898. root_constraint.t_id,
  899. root_constraint.dim,
  900. )
  901. elif dim_root.__name__ not in phantom_roots:
  902. # create a phantom root
  903. root = _PhantomRoot( # type: ignore[assignment]
  904. name=dim_root.__name__,
  905. constraint_range=StrictMinMaxConstraint(
  906. vr=ValueRanges(lower=dim_root.min, upper=dim_root.max),
  907. warn_only=False,
  908. ),
  909. val=root_value(),
  910. )
  911. phantom_roots[dim_root.__name__] = root # type: ignore[assignment]
  912. else:
  913. root = phantom_roots[dim_root.__name__] # type: ignore[assignment]
  914. constraint = _DerivedConstraint(
  915. id(tensor),
  916. i,
  917. dim.__name__,
  918. StrictMinMaxConstraint(
  919. vr=ValueRanges(lower=dim.min, upper=dim.max),
  920. warn_only=False,
  921. ),
  922. root,
  923. dim.fn, # type: ignore[attr-defined]
  924. )
  925. if isinstance(root, _PhantomRoot):
  926. # NOTE(avik): since we have not processed all inputs yet, we may replace this
  927. # with a root that does represent an input shape dimension later (see below)
  928. derived_constraints_with_phantom_root.append(constraint)
  929. elif isinstance(dim, _StaticDim):
  930. constraint = _Constraint( # type: ignore[assignment]
  931. id(tensor),
  932. i,
  933. dim.__name__,
  934. StrictMinMaxConstraint(
  935. vr=ValueRanges(lower=dim.value, upper=dim.value), # type: ignore[attr-defined]
  936. warn_only=False,
  937. ),
  938. )
  939. else:
  940. assert isinstance(dim, Dim)
  941. constraint = _Constraint( # type: ignore[assignment]
  942. id(tensor),
  943. i,
  944. dim.__name__,
  945. StrictMinMaxConstraint(
  946. vr=ValueRanges(lower=dim.min, upper=dim.max), # type: ignore[attr-defined]
  947. warn_only=False,
  948. ),
  949. )
  950. return constraint
  951. def _parse_tensor_dim(tensor, idx, dim) -> None:
  952. def _create_static_dim(tensor, i, value):
  953. return _StaticDim(value)
  954. if isinstance(dim, (int, Dim)):
  955. if isinstance(dim, int):
  956. dim = _create_static_dim(tensor, idx, dim)
  957. constraint = to_constraint(dim, tensor, idx)
  958. symbols[dim.__name__].append(constraint)
  959. elif isinstance(dim, _DimHint):
  960. if dim.type == _DimHintType.AUTO:
  961. torch._dynamo.maybe_mark_dynamic(tensor, idx)
  962. elif dim.type == _DimHintType.STATIC:
  963. torch._dynamo.mark_static(tensor, idx)
  964. elif dim.type == _DimHintType.DYNAMIC:
  965. torch._dynamo.mark_dynamic(tensor, idx)
  966. constraints.append(_RelaxedConstraint(id(tensor), idx))
  967. elif dim is None:
  968. torch._dynamo.mark_static(tensor, idx)
  969. def update_symbols(path, tensor, shape):
  970. # clean out decorators from user side, or previous export call
  971. # we also delete these attributes in non_strict_utils.py/make_constraints()
  972. tensor._dynamo_weak_dynamic_indices = set()
  973. tensor._dynamo_dynamic_indices = set()
  974. tensor._dynamo_dynamic_range = set()
  975. tensor._dynamo_static_indices = set()
  976. tensor._dynamo_unbacked_indices = set()
  977. if isinstance(shape, dict):
  978. for i, dim in shape.items():
  979. _parse_tensor_dim(tensor, i, dim)
  980. elif isinstance(shape, (tuple, list)):
  981. for i, dim in enumerate(shape):
  982. _parse_tensor_dim(tensor, i, dim)
  983. elif shape is None:
  984. for i in range(tensor.dim()):
  985. _parse_tensor_dim(tensor, i, None)
  986. def assoc_shape(path, t, dynamic_shape):
  987. if isinstance(t, torch.Tensor):
  988. update_symbols(path, t, dynamic_shape)
  989. elif isinstance(t, _IntWrapper):
  990. # If tensor dimensions are marked as dynamic, the tensors themselves
  991. # get marked using mark_dynamic. However since we can't mark
  992. # integers as dynamic, we first wrap integers in this class, and
  993. # then set the `dim` field of the class with the dynamic shapes dim
  994. # to mark the integer as dynamic.
  995. t.dynamism = dynamic_shape
  996. _tree_map_with_path(assoc_shape, combined_args, dynamic_shapes, tree_name="inputs")
  997. for derived_constraint_with_phantom_root in derived_constraints_with_phantom_root:
  998. phantom_root_name = derived_constraint_with_phantom_root.root.name # type: ignore[union-attr]
  999. if phantom_root_name in symbols:
  1000. # We found an input shape dimension corresponding to this name, so we
  1001. # do not need a phantom symbol for it after all.
  1002. # NOTE(avik): Overall we want to maintain the invariant that roots that
  1003. # are phantom symbols are really "phantom," i.e., they cannot be represented
  1004. # by any input source. This is important when we are deciding derived equalities,
  1005. # since we can focus our attention exclusively on input sources: deciding
  1006. # derived equalities involving phantom symbols are, in comparison, trivial.
  1007. derived_constraint_with_phantom_root.root = symbols[phantom_root_name][0]
  1008. for dynamic_dims in symbols.values():
  1009. constraints.extend(dynamic_dims)
  1010. return constraints
  1011. def _get_dim_name_mapping(
  1012. dynamic_shapes: Union[dict[str, Any], tuple[Any], list[Any], None],
  1013. ):
  1014. name_to_dim = {}
  1015. for dim in tree_flatten(
  1016. dynamic_shapes,
  1017. is_leaf=lambda x: isinstance(x, Dim),
  1018. )[0]:
  1019. if dim is None:
  1020. # NOTE: this must denote a non-Tensor or automatic at this point.
  1021. continue
  1022. if isinstance(dim, int):
  1023. continue
  1024. elif isinstance(dim, Dim):
  1025. name_to_dim[dim.__name__] = dim
  1026. if isinstance(dim, _DerivedDim):
  1027. name_to_dim[dim.root.__name__] = dim.root # type: ignore[attr-defined]
  1028. else:
  1029. assert isinstance(dim, _DimHint)
  1030. return name_to_dim
  1031. def refine_dynamic_shapes_from_suggested_fixes(
  1032. msg: str,
  1033. dynamic_shapes: Union[dict[str, Any], tuple[Any], list[Any]],
  1034. ) -> Union[dict[str, Any], tuple[Any], list[Any]]:
  1035. """
  1036. When exporting with :func:`dynamic_shapes`, export may fail with a ConstraintViolation error if the specification
  1037. doesn't match the constraints inferred from tracing the model. The error message may provide suggested fixes -
  1038. changes that can be made to :func:`dynamic_shapes` to export successfully.
  1039. Example ConstraintViolation error message::
  1040. Suggested fixes:
  1041. dim = Dim('dim', min=3, max=6) # this just refines the dim's range
  1042. dim = 4 # this specializes to a constant
  1043. dy = dx + 1 # dy was specified as an independent dim, but is actually tied to dx with this relation
  1044. This is a helper function that takes the ConstraintViolation error message and the original :func:`dynamic_shapes` spec,
  1045. and returns a new :func:`dynamic_shapes` spec that incorporates the suggested fixes.
  1046. Example usage::
  1047. try:
  1048. ep = export(mod, args, dynamic_shapes=dynamic_shapes)
  1049. except torch._dynamo.exc.UserError as exc:
  1050. new_shapes = refine_dynamic_shapes_from_suggested_fixes(
  1051. exc.msg, dynamic_shapes
  1052. )
  1053. ep = export(mod, args, dynamic_shapes=new_shapes)
  1054. """
  1055. import re
  1056. import sympy
  1057. from torch._dynamo.exc import UserError, UserErrorType
  1058. from torch.fx.experimental.symbolic_shapes import _is_supported_equivalence
  1059. try:
  1060. shape_fixes_msg = msg.split("Suggested fixes:")[1].strip()
  1061. except Exception as exc:
  1062. raise UserError(
  1063. UserErrorType.INVALID_INPUT,
  1064. "Suggested fixes not found in error message given to refine_dynamic_shapes_from_suggested_fixes()",
  1065. ) from exc
  1066. # build shape_fixes dictionary
  1067. shape_fixes = {}
  1068. for fix in shape_fixes_msg.split("\n"):
  1069. fix = fix.strip()
  1070. if match := re.match(r"(.*) = Dim\('(.*)'.*\)", fix):
  1071. name = match.group(1)
  1072. _min, _max = None, None
  1073. if match_min := re.match(r".* = Dim\('.*', min\=([0-9]+).*\)", fix):
  1074. _min = int(match_min.group(1))
  1075. if match_max := re.match(r".* = Dim\('.*'.*max\=([0-9]+)\)", fix):
  1076. _max = int(match_max.group(1))
  1077. shape_fixes[name] = Dim(name, min=_min, max=_max)
  1078. else:
  1079. name, expr = fix.split(" = ")
  1080. expr = sympy.sympify(expr)
  1081. if isinstance(expr, sympy.Number):
  1082. # static, integer
  1083. shape_fixes[name] = int(expr) # type: ignore[assignment]
  1084. else:
  1085. # relation or derived dim
  1086. shape_fixes[name] = expr
  1087. name_to_dim = _get_dim_name_mapping(dynamic_shapes)
  1088. # track derived dim roots
  1089. roots: set[str] = set()
  1090. for k, c in shape_fixes.items():
  1091. assert isinstance(c, (int, Dim, _DerivedDim, sympy.Expr))
  1092. if isinstance(c, sympy.Expr): # check dim/derived dim expression
  1093. assert _is_supported_equivalence(c)
  1094. shape_fixes[k] = c
  1095. roots.add(str(next(iter(c.free_symbols))))
  1096. if isinstance(c, _DerivedDim):
  1097. roots.add(c.root.__name__) # type: ignore[attr-defined]
  1098. # check keys are existing dims or new roots
  1099. for k, c in shape_fixes.items():
  1100. assert k in name_to_dim or k in roots
  1101. # cache so we don't produce multiple derived dim objects
  1102. derived_dim_cache: dict[str, _DerivedDim] = {}
  1103. def apply_fixes(path, dim, dummy):
  1104. if dim is None or isinstance(dim, int): # not dynamic
  1105. return dim
  1106. elif dim.__name__ in shape_fixes: # directly fix
  1107. fix = shape_fixes[dim.__name__]
  1108. if isinstance(fix, sympy.Expr): # now derived or related
  1109. if str(fix) in derived_dim_cache:
  1110. return derived_dim_cache[str(fix)]
  1111. else:
  1112. symbol = next(iter(fix.free_symbols))
  1113. # try to locate symbol
  1114. if symbol.name in shape_fixes:
  1115. root = shape_fixes[symbol.name]
  1116. else:
  1117. assert symbol.name in name_to_dim
  1118. root = name_to_dim[symbol.name]
  1119. # figure out value of fix
  1120. modulus, remainder = sympy.polys.polytools.div(fix, symbol)
  1121. dim = root
  1122. if modulus != 1:
  1123. dim = int(modulus) * dim
  1124. if remainder != 0:
  1125. dim = dim + int(remainder)
  1126. derived_dim_cache[str(fix)] = dim
  1127. return dim
  1128. else:
  1129. return fix
  1130. elif isinstance(dim, _DerivedDim) and dim.root.__name__ in shape_fixes: # type: ignore[attr-defined]
  1131. if dim.__name__ in derived_dim_cache:
  1132. return derived_dim_cache[dim.__name__]
  1133. else: # evaluate new derived value based on root
  1134. _dim = dim.fn(shape_fixes[dim.root.__name__]) # type: ignore[attr-defined]
  1135. derived_dim_cache[dim.__name__] = _dim
  1136. return _dim
  1137. return dim # unchanged dim
  1138. return _tree_map_with_path(apply_fixes, dynamic_shapes, dynamic_shapes)