validator.py 33 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862863864865866867868869
  1. # mypy: allow-untyped-defs
  2. import builtins
  3. import functools
  4. import logging
  5. import math
  6. import operator
  7. from dataclasses import dataclass
  8. from typing import Any, Callable, Optional, Union
  9. import sympy
  10. import torch
  11. import torch.fx
  12. import torch.fx.traceback as fx_traceback
  13. from torch._dynamo.exc import TorchDynamoException
  14. from torch._dynamo.utils import dynamo_timed
  15. from torch.fx.node import Argument, Target
  16. from torch.utils._sympy.interp import sympy_interp
  17. log = logging.getLogger(__name__)
  18. try:
  19. import z3 # type: ignore[import]
  20. # Translation Validation for Dynamo guards
  21. # ========================================
  22. #
  23. # Checks whether optimizations applied to the collected guards are
  24. # valid. In other words, whether the guard function we actually run
  25. # does not have false positives (unsound).
  26. #
  27. # In order to do so, we build the guards using 2 different pieces of information
  28. # attached to each 'SymNode':
  29. # 1. SymPy expressions
  30. # 2. FX nodes
  31. #
  32. # SymPy expressions have implicit optimizations baked within themselves,
  33. # which may have a few bugs. On the other hand, we build the FX graph
  34. # manually, with no optimizations enabled. This gives us access to
  35. # the "ground truth".
  36. #
  37. # We then convert into Z3 expressions both the SymPy expressions
  38. # (see [Note: SympyToZ3]) that reach 'ShapeEnv.produce_guards' function
  39. # and the FX nodes (see [Note: PopulateValidator]) that go through
  40. # 'ShapeEnv.evaluate_expr' function. Finally, we run the validation.
  41. # (see [Note: TranslationValidator])
  42. # Better Z3 to string implementation (for a small fraction of Z3).
  43. #
  44. # Here are the things we clean before showing the Z3 expression:
  45. # - Rename a few ops (e.g. "Distinct" ==> "!=")
  46. #
  47. # - Ignore ToInt and ToReal operations:
  48. # usually they don't really matter
  49. #
  50. # - Transform (ToInt (/ ...)) into (idiv ...):
  51. # this is the pattern for floor division
  52. #
  53. # - Collect a chain of the same operations into one
  54. def z3str(e: z3.ExprRef) -> str:
  55. assert z3.is_expr(e), f"unsupported expression type: {e}"
  56. def get_args_str(e: z3.ExprRef) -> list[str]:
  57. return [z3str(e.arg(i)) for i in range(e.num_args())]
  58. # First, we simplify the given expression.
  59. # This is done using rewriting rules, so shouldn't take long.
  60. e = z3.simplify(e)
  61. # Only support function applications.
  62. # Even Z3 "variables" are, in fact, function applications.
  63. if not z3.is_app(e):
  64. raise ValueError(f"can't print Z3 expression: {e}")
  65. if z3.is_int_value(e) or z3.is_rational_value(e):
  66. return e.as_string() # type: ignore[attr-defined]
  67. decl = e.decl()
  68. kind = decl.kind()
  69. op = str(decl)
  70. args = get_args_str(e)
  71. if kind == z3.Z3_OP_POWER:
  72. op = "pow"
  73. elif kind in (z3.Z3_OP_ADD, z3.Z3_OP_MUL):
  74. # Collect the arguments of chains of ADD and MUL.
  75. # This is safe, since they are associative.
  76. def collect_str_args(e):
  77. if not (z3.is_app(e) and e.decl().kind() == kind):
  78. return [z3str(e)]
  79. else:
  80. return [
  81. x
  82. for i in range(e.num_args())
  83. for x in collect_str_args(e.arg(i))
  84. ]
  85. args = collect_str_args(e)
  86. elif kind == z3.Z3_OP_NOT:
  87. # Revert some conversions that z3.simplify applies:
  88. # - a != b ==> (Not (== a b)) ==> (!= a b)
  89. # - a < b ==> (Not (<= b a)) ==> (> b a)
  90. # - a > b ==> (Not (<= a b)) ==> (> a b)
  91. assert e.num_args() == 1
  92. arg = e.arg(0)
  93. assert z3.is_app(arg)
  94. argkind = arg.decl().kind()
  95. logic_inverse = {
  96. z3.Z3_OP_EQ: "!=",
  97. z3.Z3_OP_LE: ">",
  98. z3.Z3_OP_GE: "<",
  99. }
  100. if argkind in logic_inverse:
  101. op = logic_inverse[argkind]
  102. args = get_args_str(arg)
  103. elif kind in (z3.Z3_OP_TO_INT, z3.Z3_OP_TO_REAL):
  104. assert e.num_args() == 1
  105. argstr = z3str(e.arg(0))
  106. # Check if it's the floor division pattern.
  107. if argstr.startswith("(/"):
  108. return "(idiv" + argstr[2:]
  109. # Otherwise, just ignore it.
  110. return argstr
  111. elif kind == z3.Z3_OP_UNINTERPRETED:
  112. assert e.num_args() == 0
  113. return str(decl)
  114. string = op + " " + " ".join(args)
  115. return f"({string.rstrip()})"
  116. # We need to convert to/from BitVec in order to use z3 bitwise ops.
  117. # We assume that integers are 64 bit.
  118. # If all args are boolean, then use the boolean bitwise op implementation instead, if provided.
  119. def _bitwise_op(bitwise_func, bool_func):
  120. @functools.wraps(bitwise_func)
  121. def wrapper(self, *args):
  122. if bool_func is not None and all(
  123. isinstance(arg, z3.BoolRef) for arg in args
  124. ):
  125. return bool_func(*args)
  126. wrapped_args = tuple(z3.Int2BV(a, 64) for a in args)
  127. return z3.BV2Int(bitwise_func(*wrapped_args))
  128. return wrapper
  129. # Implementation of Python semantics as Z3 expressions.
  130. #
  131. # Z3 Real-Int theory has operators with semantics that differ from that of
  132. # Python. Therefore, in order to get it right, we need to implement
  133. # the (Python) semantics we are relying on in Z3.
  134. @dataclass
  135. class _Z3Ops:
  136. # Validator used for adding assertions as needed.
  137. # e.g. div(a, b) requires b != 0.
  138. validator: "TranslationValidator"
  139. # The 2 functions below are used for conditionally casting between
  140. # integer and reals.
  141. #
  142. # Returns a real expression from 'x'.
  143. @staticmethod
  144. def to_real(x: z3.ArithRef) -> z3.ArithRef:
  145. return x if x.is_real() else z3.ToReal(x)
  146. # Returns an integer expression from 'x'.
  147. @staticmethod
  148. def to_int(x: z3.ArithRef) -> z3.ArithRef:
  149. return x if x.is_int() else z3.ToInt(x)
  150. def sym_sum(self, args: z3.ArithRef) -> z3.ArithRef:
  151. return sum(args)
  152. # Implements Python division semantics.
  153. def div(self, numerator: z3.ArithRef, denominator: z3.ArithRef) -> z3.ArithRef:
  154. self.validator.add_assertion(denominator != 0) # type: ignore[arg-type]
  155. return _Z3Ops.to_real(numerator) / _Z3Ops.to_real(denominator)
  156. def floor(self, number: z3.ArithRef) -> z3.ArithRef:
  157. # Z3 ToInt function rounds a real number towards negative infinity.
  158. return _Z3Ops.to_int(number)
  159. # Python semantics for 'FloorDiv' states that before applying the floor
  160. # function, the operands are converted to their common type.
  161. def floordiv(
  162. self, numerator: z3.ArithRef, denominator: z3.ArithRef
  163. ) -> z3.ArithRef:
  164. cast_result_to_real = numerator.is_real() or denominator.is_real()
  165. result = _Z3Ops.to_int(self.div(numerator, denominator))
  166. # Since the 'result' is already an integer, we just have to check
  167. # whether we should cast it to real.
  168. return _Z3Ops.to_real(result) if cast_result_to_real else result
  169. def ceil(self, number: z3.ArithRef) -> z3.ArithRef:
  170. return z3.If(self.floor(number) < number, self.floor(number + 1), number) # type: ignore[return-value]
  171. def trunc(self, number: z3.ArithRef) -> z3.ArithRef:
  172. return z3.If(number >= 0, self.floor(number), self.ceil(number)) # type: ignore[return-value]
  173. def max(self, a: z3.ArithRef, b: z3.ArithRef) -> z3.ArithRef:
  174. return z3.If(a > b, a, b) # type: ignore[return-value]
  175. def min(self, a: z3.ArithRef, b: z3.ArithRef) -> z3.ArithRef:
  176. return z3.If(a < b, a, b) # type: ignore[return-value]
  177. # Python semantics for 'Mod' is defined as: p % q = p - floordiv(p, q) * q
  178. # It should work with both integer and reals.
  179. def mod(self, p: z3.ArithRef, q: z3.ArithRef) -> z3.ArithRef:
  180. return p - self.floordiv(p, q) * q
  181. def pow(self, base: z3.ArithRef, exp: z3.ArithRef) -> z3.ArithRef:
  182. # Z3 can't handle complex numbers very well.
  183. self.validator.add_assertion(z3.Or(base != 0, exp > 0)) # type: ignore[arg-type]
  184. return base**exp
  185. def sqrt(self, number: z3.ArithRef) -> z3.ArithRef:
  186. # Square-root:
  187. # 1. Only work with reals
  188. number = _Z3Ops.to_real(number)
  189. # 2. The number should be positive or zero.
  190. # Otherwise, Z3 returns 'unknown'.
  191. self.validator.add_assertion(number >= 0)
  192. return number**0.5
  193. def abs(self, number: z3.ArithRef) -> z3.ArithRef:
  194. return z3.Abs(number)
  195. def round_to_int(self, number: z3.ArithRef) -> z3.ArithRef:
  196. # Pythons builtin 'round' implements the 'round half to even' strategy
  197. # See https://en.wikipedia.org/wiki/Rounding#Rounding_half_to_even
  198. # z3 has an equivalent z3.fpRoundToIntegral(z3.RoundNearestTiesToEven(), ...), but this only applies to
  199. # floating point numbers, which is different from real numbers that we are dealing with here.
  200. # Instead, we implement 'round half to even' in terms of 'round half up' (floor(x + 0.5)) and
  201. # 'round half down' (ceil(x - 0.5)).
  202. # Assuming 'round half up' is the default case, we need to correct ..., -3.5, -1.5, 0.5, 2.5, 4.5, ...
  203. # to round down, i.e. use the 'round half down' strategy
  204. return z3.If(
  205. self.mod(number, z3.IntVal(2)) == 0.5,
  206. self.ceil(number - 0.5),
  207. self.floor(number + 0.5),
  208. )
  209. bitwise_and = _bitwise_op(operator.and_, z3.And)
  210. bitwise_or = _bitwise_op(operator.or_, z3.Or)
  211. lshift = _bitwise_op(operator.lshift, None)
  212. rshift = _bitwise_op(operator.rshift, None)
  213. # Lifts a callable to be used in Z3.
  214. #
  215. # This function replaces the given 'op' by a function that:
  216. #
  217. # 1. Lifts the arguments into Z3 (i.e. make them inhabitants of Z3)
  218. #
  219. # 2. Calls an operation that corresponds to 'op', but works with Z3
  220. # inhabitants (left as is if it works as is)
  221. def z3op(op: Callable, validator: "TranslationValidator") -> Callable:
  222. # Operations that have booleans as their argument.
  223. # This is needed because the argument of some FX nodes were
  224. # literal integers, instead of booleans. So, whenever this flag
  225. # is set, we also convert ints to booleans.
  226. boolean_ops = {operator.not_}
  227. as_bool = op in boolean_ops
  228. # Lifts the function into 'z3.ExprRef' domain.
  229. def lift(func):
  230. def wrap(a) -> z3.ExprRef:
  231. if isinstance(a, (z3.ArithRef, z3.BoolRef)):
  232. return a
  233. # Convert it into a Z3 value, if it is some of the supported
  234. # types below.
  235. if isinstance(a, bool) or (as_bool and isinstance(a, int)):
  236. return z3.BoolVal(bool(a))
  237. if isinstance(a, (int, sympy.Integer)):
  238. return z3.IntVal(int(a))
  239. if isinstance(a, (float, sympy.Float)):
  240. return z3.RealVal(float(a))
  241. raise ValueError(f"can't lift type: {type(a)}")
  242. @functools.wraps(func)
  243. def wrapper(*args):
  244. # Lifts the arguments into a list of Z3 inhabitants.
  245. if len(args) == 1 and isinstance(args[0], (list, tuple)):
  246. wrapped_args = (tuple(wrap(a) for a in args[0]),)
  247. else:
  248. wrapped_args = tuple(wrap(a) for a in args)
  249. # Run the function on the Z3 expressions.
  250. return func(*wrapped_args)
  251. return wrapper
  252. ops = _Z3Ops(validator)
  253. replacement_map = {
  254. # Operator module.
  255. operator.not_: lift(z3.Not),
  256. operator.and_: lift(ops.bitwise_and),
  257. operator.or_: lift(ops.bitwise_or),
  258. operator.lshift: lift(ops.lshift),
  259. operator.rshift: lift(ops.rshift),
  260. operator.floordiv: lift(ops.floordiv),
  261. operator.truediv: lift(ops.div),
  262. operator.mod: lift(ops.mod),
  263. operator.abs: lift(ops.abs),
  264. builtins.round: lift(ops.round_to_int),
  265. # Math module.
  266. math.ceil: lift(ops.ceil),
  267. math.floor: lift(ops.floor),
  268. math.trunc: lift(ops.trunc),
  269. # Torch module.
  270. torch.sym_float: lift(ops.to_real),
  271. torch.sym_max: lift(ops.max),
  272. torch.sym_min: lift(ops.min),
  273. torch.sym_sum: lift(ops.sym_sum),
  274. torch.sym_ite: lift(lambda b, t, f: t if b else f),
  275. torch._sym_sqrt: lift(ops.sqrt), # type: ignore[attr-defined]
  276. # Not lifted because we only use this function as a
  277. # marker for adding the expression as validator input.
  278. torch._assert: torch._assert,
  279. }
  280. return replacement_map[op] if op in replacement_map else lift(op)
  281. # Processes an FX graph, populating the given validator.
  282. #
  283. # [Note: PopulateValidator]
  284. # This class walks through each node in the FX graph, translating
  285. # them into the Z3 world.
  286. #
  287. # Then, whenever it finds a 'torch._assert' call_function operation,
  288. # it adds the Z3 expression corresponding to the argument as validator
  289. # input.
  290. class PopulateValidator(torch.fx.Interpreter):
  291. def __init__(self, graph: torch.fx.Graph, validator: "TranslationValidator"):
  292. # Reference to the translation validator.
  293. self.validator = validator
  294. # Build the graph module and call `Interpreter` constructor.
  295. module = torch.fx.GraphModule(root={}, graph=graph)
  296. super().__init__(module, garbage_collect_values=True)
  297. def placeholder(
  298. self, target: Target, args: tuple[Argument, ...], kwargs: dict[str, Any]
  299. ) -> Any:
  300. symbol = fx_traceback.get_current_meta()["symbol"]
  301. return self.validator.z3var(symbol)
  302. def call_function(
  303. self, target: Target, args: tuple[Argument, ...], kwargs: dict[str, Any]
  304. ) -> Any:
  305. if target != torch._assert:
  306. # Lift and runs the node target function
  307. return super().call_function(z3op(target, self.validator), args, kwargs) # type: ignore[arg-type]
  308. # Adds the Z3 expression corresponding to the first argument
  309. # as a validator input.
  310. assert len(args) == 1, (
  311. f"expected 1 argument on assertion. Got: {len(args)} "
  312. )
  313. self.validator.add_source_expr(args[0]) # type: ignore[arg-type]
  314. # Translates SymPy expressions into Z3 expressions.
  315. #
  316. # [Note: SympyToZ3]
  317. # At the time of the translation, all free variables present in the
  318. # SymPy expression being translated must be already mapped to a Z3
  319. # integer variable.
  320. class SympyToZ3:
  321. OPERATOR_HANDLES = {"add", "mul", "eq", "ne", "lt", "gt", "le", "ge"}
  322. def __init__(
  323. self,
  324. validator: "TranslationValidator",
  325. ) -> None:
  326. self._validator = validator
  327. self._ops = _Z3Ops(self._validator)
  328. def constant(self, value: Any, dtype: torch.dtype) -> z3.ExprRef:
  329. # TODO: Probably OK to relax this and allow lower precision
  330. if dtype is torch.int64:
  331. return z3.IntVal(int(value))
  332. if dtype is torch.double:
  333. return z3.RealVal(float(value))
  334. if dtype is torch.bool:
  335. return z3.BoolVal(bool(value))
  336. raise ValueError(f"unsupported dtype (SympyToZ3): {dtype}")
  337. def to_dtype(self, x: z3.ArithRef, dtype: torch.dtype) -> z3.ArithRef:
  338. if dtype == torch.float64:
  339. return z3.ToReal(x)
  340. raise NotImplementedError(f"to_dtype {dtype} NYI")
  341. def trunc_to_int(self, x: z3.ArithRef, dtype: torch.dtype) -> z3.ArithRef:
  342. return z3.ToInt(x)
  343. def round_to_int(self, x: z3.ArithRef, dtype: torch.dtype) -> z3.ArithRef:
  344. return self._ops.round_to_int(x)
  345. def int_truediv(
  346. self, numerator: z3.ArithRef, denominator: z3.ArithRef
  347. ) -> z3.ArithRef:
  348. return self._ops.div(numerator, denominator)
  349. def truediv(
  350. self, numerator: z3.ArithRef, denominator: z3.ArithRef
  351. ) -> z3.ArithRef:
  352. return self._ops.div(numerator, denominator)
  353. def floordiv(
  354. self, numerator: z3.ArithRef, denominator: z3.ArithRef
  355. ) -> z3.ArithRef:
  356. return self._ops.floordiv(numerator, denominator)
  357. def div(self, numerator: z3.ArithRef, denominator: z3.ArithRef) -> z3.ArithRef:
  358. return self._ops.floordiv(numerator, denominator)
  359. def pow(self, base: z3.ArithRef, exp: z3.ArithRef) -> z3.ArithRef:
  360. return self._ops.pow(base, exp)
  361. def pow_by_natural(self, base: z3.ArithRef, exp: z3.ArithRef) -> z3.ArithRef:
  362. return self._ops.pow(base, exp)
  363. def mod(self, p: z3.ArithRef, q: z3.ArithRef) -> z3.ArithRef:
  364. return self._ops.mod(p, q)
  365. def ceil_to_int(self, x: z3.ArithRef, dtype: torch.dtype) -> z3.ArithRef:
  366. return self._ops.ceil(x)
  367. def floor_to_int(self, x: z3.ArithRef, dtype: torch.dtype) -> z3.ArithRef:
  368. return self._ops.floor(x)
  369. def __getattr__(self, name: str) -> Any:
  370. REPLACEMENT = {
  371. "and_": z3.And,
  372. "or_": z3.Or,
  373. "not_": z3.Not,
  374. "bitwise_and": self._ops.bitwise_and,
  375. "bitwise_or": self._ops.bitwise_or,
  376. "lshift": self._ops.lshift,
  377. "rshift": self._ops.rshift,
  378. "floor": self._ops.floor,
  379. "ceil": self._ops.ceil,
  380. "minimum": self._ops.min,
  381. "maximum": self._ops.max,
  382. }
  383. if name in REPLACEMENT:
  384. return REPLACEMENT[name]
  385. if name in self.OPERATOR_HANDLES:
  386. return getattr(operator, name)
  387. raise AttributeError(f"unhandled operator: {name}")
  388. def run(self, expr: sympy.Basic) -> z3.ExprRef:
  389. return sympy_interp(self, self._validator.symbols, expr) # type: ignore[arg-type]
  390. # Dynamo guards translation validator.
  391. #
  392. # [Note: TranslationValidator]
  393. # Verifies whether the guards issued by 'ShapeEnv.produce_guards' are sound.
  394. # That is: whether those (target) guards only yield TRUE whenever the original,
  395. # unoptimized, (source) guards yield TRUE.
  396. #
  397. # More concretely, given 'source' and 'target' guard expressions, we wish to
  398. # check whether the following expression holds:
  399. #
  400. # Not(And(source)) AND And(target)
  401. #
  402. # i.e. whether there is an assignment of the free variables where the opposite
  403. # happens: target is TRUE, but source is FALSE.
  404. class TranslationValidator:
  405. def __init__(self) -> None:
  406. log.debug("new instance")
  407. # Mapping of SymPy symbols to Z3 variables.
  408. self.symbols: dict[sympy.Symbol, z3.ExprRef] = {}
  409. # Set of source Z3 expressions.
  410. # They represent the generated guards without any kind of
  411. # simplification or transformation.
  412. self._source_exprs: set[z3.BoolRef] = set()
  413. # Set of target Z3 expressions.
  414. # They represent the actual checked guards at runtime. They might
  415. # be simplified or transformed versions of the source guards.
  416. self._target_exprs: set[z3.BoolRef] = set()
  417. # Set of Z3 expressions representing assertions over both the
  418. # source and target expressions.
  419. self._assertions: set[z3.BoolRef] = set()
  420. # Retrieves the corresponding Z3 variable.
  421. def z3var(self, symbol: sympy.Symbol) -> z3.ExprRef:
  422. assert symbol in self.symbols, f"Z3 variable not found for: {symbol}"
  423. return self.symbols[symbol]
  424. # Create a variable in Z3 of 'type' for 'symbol', if it doesn't already exists.
  425. def add_var(self, symbol: sympy.Symbol, type: type) -> z3.ExprRef:
  426. if symbol in self.symbols:
  427. return self.symbols[symbol]
  428. log.debug("new variable: %s (%s)", symbol.name, type.__name__)
  429. if type is int:
  430. var = z3.Int(symbol.name)
  431. # If 'symbol' is positive (SymPy assumption), we have to
  432. # convey it to Z3 as well.
  433. if symbol.is_positive: # type: ignore[attr-defined]
  434. self._target_exprs.add(var > 0)
  435. elif type is float:
  436. var = z3.Real(symbol.name)
  437. elif type is bool:
  438. var = z3.Bool(symbol.name)
  439. else:
  440. raise RuntimeError(f"unsupported type for Z3 variable: {type}")
  441. self.symbols[symbol] = var
  442. return var
  443. # Checks whether all symbols were already added.
  444. def _check_freesymbols(self, e: sympy.Basic) -> None:
  445. for s in e.free_symbols:
  446. assert isinstance(s, sympy.Symbol)
  447. # Call 'z3var' just to check whether there's already a
  448. # Z3 variable corresponding to 's'.
  449. self.z3var(s)
  450. def to_z3_boolean_expr(self, e: sympy.Basic) -> z3.BoolRef:
  451. z3expr = SympyToZ3(self).run(e)
  452. assert isinstance(z3expr, z3.BoolRef), (
  453. f"expected boolean expression. Got: {z3expr}"
  454. )
  455. return z3expr
  456. def add_source_expr(self, e: z3.BoolRef) -> None:
  457. if e not in self._source_exprs:
  458. log.debug("add source guard: %s", z3str(e))
  459. self._source_exprs.add(e)
  460. def add_target_expr(self, e: "sympy.logic.boolalg.Boolean") -> None:
  461. self._check_freesymbols(e)
  462. z3expr = self.to_z3_boolean_expr(e)
  463. if e not in self._target_exprs:
  464. log.debug("add target guard: %s", z3str(z3expr))
  465. self._target_exprs.add(z3expr)
  466. def add_assertion(self, e: Union[z3.BoolRef, sympy.Basic]) -> None:
  467. if isinstance(e, sympy.Basic):
  468. self._check_freesymbols(e)
  469. ref = self.to_z3_boolean_expr(e)
  470. else:
  471. ref = e
  472. assert isinstance(ref, z3.BoolRef)
  473. if ref not in self._assertions:
  474. log.debug("add assertion: %s", z3str(ref))
  475. self._assertions.add(ref)
  476. def validate(self) -> None:
  477. with dynamo_timed("TranslationValidator.validate"):
  478. return self._validate()
  479. def _validate(self) -> None:
  480. if len(self._source_exprs) == 0 or len(self._target_exprs) == 0:
  481. # If there are no source/target expressions, there's nothing we really
  482. # wish to prove. So, we just return.
  483. return None
  484. # Here, we use "QF_NRA" logic for the solver:
  485. # "Quantifier-free Non-linear Real Arithmetic".
  486. #
  487. # Most of the guards expressions have:
  488. # 1. arithmetic between integer and reals
  489. # 2. no quantifiers
  490. # 3. potentially non-linear.
  491. #
  492. # Although there's also "QF_NIRA" (mixed integer-real arithmetic),
  493. # "QF_NRA" seems to work better on 'dynamo/test_dynamic_shapes.py'.
  494. solver = z3.SolverFor("QF_NRA")
  495. # Set a timeout for finding a solution.
  496. solver.set(timeout=translation_validation_timeout())
  497. # Add all the assertions to the solver.
  498. for assertion in self._assertions:
  499. solver.add(assertion)
  500. # "Is there any case where it's TRUE for the target expressions,
  501. # but FALSE for the source expressions?"
  502. solver.add(z3.Not(z3.And(*self._source_exprs)))
  503. solver.add(*self._target_exprs)
  504. log.debug("translation validation: start")
  505. r = solver.check()
  506. if r == z3.sat:
  507. # Target expressions are unsound.
  508. # Log the found model and the source expressions that failed.
  509. model = solver.model()
  510. raise ValidationException(
  511. model,
  512. self._assertions,
  513. self._target_exprs,
  514. failed_source_exprs=[
  515. inp for inp in self._source_exprs if not model.evaluate(inp)
  516. ],
  517. )
  518. else:
  519. if r == z3.unknown:
  520. # Could not find a solution. It didn't fail, but it also
  521. # didn't succeed. Canceling the validation execution (keyboard
  522. # interrupt) also gets to this branch.
  523. log.warning(
  524. "translation validation: could not validate: got z3.unknown"
  525. )
  526. else:
  527. # Target expressions are sound.
  528. assert r == z3.unsat
  529. log.debug("translation validation: success")
  530. except ImportError:
  531. _HAS_Z3 = False
  532. __all__ = [
  533. "translation_validation_enabled",
  534. "translation_validation_timeout",
  535. "ValidationException",
  536. "BisectValidationException",
  537. ]
  538. else:
  539. _HAS_Z3 = True
  540. __all__ = [
  541. "z3str",
  542. "z3op",
  543. "PopulateValidator",
  544. "SympyToZ3",
  545. "TranslationValidator",
  546. "translation_validation_enabled",
  547. "translation_validation_timeout",
  548. "ValidationException",
  549. "BisectValidationException",
  550. ]
  551. from torch.fx.experimental import _config as config
  552. def translation_validation_enabled() -> bool:
  553. # Checks every time this function is called, in case the Dynamo
  554. # option is set, but Z3 is not installed.
  555. _assert_z3_installed_if_tv_set()
  556. return _HAS_Z3 and config.translation_validation
  557. def translation_validation_timeout() -> int:
  558. return config.translation_validation_timeout
  559. def _assert_z3_installed_if_tv_set():
  560. assert _HAS_Z3 or not config.translation_validation, (
  561. "translation validation requires Z3 package. Please, either install "
  562. "z3-solver or disable translation validation."
  563. )
  564. class ValidationException(TorchDynamoException):
  565. def __init__(self, model, assertions, target_exprs, failed_source_exprs):
  566. assert _HAS_Z3
  567. def symbolstr(sym) -> str:
  568. return f"{sym}: {model[sym]}"
  569. def joinlines(xs) -> str:
  570. return "\n".join(f" ==> {x}" for x in xs)
  571. model_str = joinlines(sorted(map(symbolstr, model)))
  572. assertions_str = joinlines(sorted(map(z3str, assertions)))
  573. target_exprs_str = joinlines(sorted(map(z3str, target_exprs)))
  574. failed_source_exprs_str = joinlines(sorted(map(z3str, failed_source_exprs)))
  575. self.msg = "translation validation failed."
  576. self.details = f"""\
  577. Model:
  578. {model_str}
  579. Assertions:
  580. {assertions_str}
  581. Target Expressions:
  582. {target_exprs_str}
  583. Failed Source Expressions:
  584. {failed_source_exprs_str}"""
  585. def __str__(self):
  586. return f"{self.msg}\n\n{self.details}"
  587. class BisectValidationException(TorchDynamoException):
  588. def __init__(self, validation_exc, expr, failed_action, traced_node):
  589. self.msg = f"translation validation failed when {failed_action}: {expr}"
  590. self.details = f"""\
  591. Failure occurred while running node:
  592. {traced_node.format_node()}
  593. {validation_exc.details}"""
  594. def __str__(self):
  595. return f"{self.msg}\n\n{self.details}"
  596. # Checks when this module is loaded.
  597. _assert_z3_installed_if_tv_set()
  598. # Translation validation bisection.
  599. #
  600. # Bisect into the torch._assert nodes recorded in the shape_env FX graph, and raise
  601. # the earliest ValidationException.
  602. #
  603. # As guards are added by ShapeEnv.evaluate_expr calls, some simplification errors
  604. # might be silently happening. This function tries to nail down exactly at which
  605. # point things went wrong from a validation perspective.
  606. def bisect(shape_env):
  607. from torch.fx.experimental.recording import (
  608. FakeTensorMeta,
  609. replay_shape_env_events,
  610. ShapeEnvEvent,
  611. )
  612. from torch.fx.experimental.symbolic_shapes import (
  613. CURRENT_NODE_KEY,
  614. ShapeEnv,
  615. SHAPEENV_EVENT_KEY,
  616. )
  617. events = shape_env.events
  618. # Retrieves the ShapeEnvEvent associated with node.
  619. def get_node_event(node: torch.fx.Node) -> ShapeEnvEvent:
  620. assert SHAPEENV_EVENT_KEY in node.meta
  621. return events[node.meta[SHAPEENV_EVENT_KEY]]
  622. # Creates a new instance of fake, but updating every symbolic value's ShapeEnv
  623. # reference to the one given as argument.
  624. #
  625. # This is needed so as not to simplify a symbolic expression using a ShapeEnv
  626. # "from the future", where it may have a different set of replacements.
  627. def new_with_shape_env(shape_env: ShapeEnv, fake) -> Any:
  628. if isinstance(fake, int):
  629. return fake
  630. if isinstance(fake, torch.SymInt):
  631. return torch.SymInt(fake.node.with_shape_env(shape_env))
  632. if isinstance(fake, torch.SymFloat):
  633. return torch.SymFloat(fake.node.with_shape_env(shape_env))
  634. assert isinstance(fake, FakeTensorMeta)
  635. return FakeTensorMeta(
  636. tuple(new_with_shape_env(shape_env, s) for s in fake.size()),
  637. tuple(new_with_shape_env(shape_env, s) for s in fake.stride()),
  638. new_with_shape_env(shape_env, fake.storage_offset()),
  639. fake.is_nested,
  640. )
  641. # Checks whether the given shape_env fails when produce_guards is called.
  642. def check_shapeenv_fails(
  643. shape_env: ShapeEnv, tracked_fakes: Optional[list[Any]]
  644. ) -> Optional[ValidationException]:
  645. assert tracked_fakes is not None
  646. try:
  647. # This produce_guards call is a best-effort replication, since we
  648. # don't populate EqualityConstraint list. Reason: we would also have
  649. # to save OutputGraph.tracked_fakes_id_to_source.
  650. shape_env.produce_guards(
  651. [new_with_shape_env(shape_env, a.fake) for a in tracked_fakes],
  652. [a.source for a in tracked_fakes],
  653. input_contexts=[a.symbolic_context for a in tracked_fakes],
  654. )
  655. return None
  656. except ValidationException as e:
  657. return e
  658. # Checks whether the ShapeEnv reconstructed by replaying the events until
  659. # node is created fails when produce_guards is called.
  660. def check_node_fails(node: torch.fx.Node) -> Optional[ValidationException]:
  661. number = node.meta[SHAPEENV_EVENT_KEY]
  662. # Reconstruct shape_env until the event at event_number.
  663. shape_env = replay_shape_env_events(events[: number + 1])
  664. shape_env.graph.lint()
  665. return check_shapeenv_fails(shape_env, events[number].tracked_fakes)
  666. last_exception = check_shapeenv_fails(
  667. shape_env, shape_env._snapshot_tracked_fakes()
  668. )
  669. if not last_exception:
  670. # We don't actually fail due to a produce_guards call.
  671. # Stop and don't bisect.
  672. log.info("translation validation succeeded: no errors found.")
  673. return
  674. if not shape_env.should_record_events or config.translation_validation_no_bisect:
  675. # Bisection is off.
  676. # Return the last ValidationException we got.
  677. raise last_exception
  678. # Cache the raised exception (if any) at each bisection point.
  679. exception = {}
  680. # Bisection happens on the assertion nodes of the recorded FX graph for
  681. # dynamic shapes.
  682. assert_nodes = [
  683. node for node in shape_env.graph.nodes if node.target == torch._assert
  684. ]
  685. # Preparing the indices for binary search.
  686. # The overall invariants are
  687. # - for all i < left, assert_node[i] doesn't fail
  688. # - for all i >= right, assert_node[i] fails
  689. # - `right in exception` always holds
  690. # - `left <= right` always holds
  691. left, mid, right = 0, 0, len(assert_nodes) - 1
  692. exception[right] = check_node_fails(assert_nodes[right])
  693. while left < right:
  694. mid = (left + right) // 2
  695. node = assert_nodes[mid]
  696. log.debug("bisecting at %s: %s", mid, get_node_event(node))
  697. # Check whether the new shape_env raises a ValidationException or not.
  698. exception[mid] = check_node_fails(node)
  699. if exception[mid]:
  700. right = mid
  701. else:
  702. left = mid + 1
  703. assert left in exception and isinstance(exception[left], ValidationException)
  704. node = assert_nodes[left]
  705. event = get_node_event(node)
  706. if event.is_evaluate_expr():
  707. failed_action = "evaluating"
  708. else:
  709. assert event.is_defer_runtime_assert(), f"unexpected event type: {event}"
  710. failed_action = "adding runtime assert"
  711. args = event.args
  712. assert args is not None
  713. assert len(args) >= 2, (
  714. f"bisecting expects {event.name} to have at least 2 positional arguments. "
  715. f"Got: {len(args)}"
  716. )
  717. assert isinstance(args[1], sympy.Basic), (
  718. f"bisecting expects {event.name} to have a SymPy expression as its second argument. "
  719. f"Got: {type(args[1])}"
  720. )
  721. raise BisectValidationException(
  722. exception[left],
  723. expr=args[1],
  724. failed_action=failed_action,
  725. traced_node=node.meta[CURRENT_NODE_KEY],
  726. )