"""Implements "A Fast Linear-Arithmetic Solver for DPLL(T)"

The LRASolver class defined in this file can be used
in conjunction with a SAT solver to check the
satisfiability of formulas involving inequalities.

Here's an example of how that would work:

Suppose you want to check the satisfiability of
the following formula:

>>> from sympy.core.relational import Eq
>>> from sympy.abc import x, y
>>> f = ((x > 0) | (x < 0)) & (Eq(x, 0) | Eq(y, 1)) & (~Eq(y, 1) | Eq(1, 2))

First, a preprocessing step should be done on f. During preprocessing,
f should be checked for any predicates, such as `Q.prime`, that can't be
handled. Also, disequalities such as `~Eq(y, 1)` should be split into a
pair of strict inequalities.

Note that the paper says to split both equalities and disequalities,
but this implementation only requires that disequalities be split.

>>> f = ((x > 0) | (x < 0)) & (Eq(x, 0) | Eq(y, 1)) & ((y < 1) | (y > 1) | Eq(1, 2))

Then an LRASolver instance needs to be initialized with this formula.

>>> from sympy.assumptions.cnf import CNF, EncodedCNF
>>> from sympy.assumptions.ask import Q
>>> from sympy.logic.algorithms.lra_theory import LRASolver
>>> cnf = CNF.from_prop(f)
>>> enc = EncodedCNF()
>>> enc.add_from_cnf(cnf)
>>> lra, conflicts = LRASolver.from_encoded_cnf(enc)

Any immediate one-literal conflict clauses will be detected here.
In this example, `~Eq(1, 2)` is one such conflict clause. We'll
want to add it to `f` so that the SAT solver is forced to
assign Eq(1, 2) to False.

>>> f = f & ~Eq(1, 2)

Now that the one-literal conflict clauses have been added
and an lra object has been initialized, we can pass `f`
to a SAT solver. The SAT solver will give us a satisfying
assignment such as::

    (1 = 2): False
    (y = 1): True
    (y < 1): True
    (y > 1): True
    (x = 0): True
    (x < 0): True
    (x > 0): True

Next you would pass this assignment to the LRASolver,
which will be able to determine whether this particular
assignment is satisfiable.

Note that since EncodedCNF is inherently non-deterministic,
the integer that each predicate is encoded as is not guaranteed
to be consistent. As a result, the code below likely does not
reflect the assignment given above.

>>> lra.assert_lit(-1) #doctest: +SKIP
>>> lra.assert_lit(2) #doctest: +SKIP
>>> lra.assert_lit(3) #doctest: +SKIP
>>> lra.assert_lit(4) #doctest: +SKIP
>>> lra.assert_lit(5) #doctest: +SKIP
>>> lra.assert_lit(6) #doctest: +SKIP
>>> lra.assert_lit(7) #doctest: +SKIP
>>> is_sat, conflict_or_assignment = lra.check()

As the particular assignment suggested is not satisfiable,
the LRASolver will return unsat and a conflict clause when
given that assignment. The conflict clause will always be
minimal, but there can be multiple minimal conflict clauses.
One possible conflict clause could be `~(x < 0) | ~(x > 0)`.

We would then add whatever conflict clause is given to
`f` to prevent the SAT solver from coming up with an
assignment with the same conflicting literals. In this case,
the conflict clause `~(x < 0) | ~(x > 0)` would prevent
any assignment where both (x < 0) and (x > 0) were true.

The SAT solver would then find another assignment,
we would check that assignment with the LRASolver,
and so on. Eventually, either a satisfying assignment
that the SAT solver and LRASolver agree on would be found,
or enough conflict clauses would be added that the
boolean formula becomes unsatisfiable.

This implementation is based on [1]_, which includes a
detailed explanation of the algorithm and pseudocode
for the most important functions.

[1]_ also explains how backtracking and theory propagation
could be implemented to speed up the current implementation,
but these are not currently implemented.

TODO:

- Handle non-rational real numbers
- Handle positive and negative infinity
- Implement backtracking and theory propagation
- Simplify matrix by removing unused variables using Gaussian elimination

References
==========

.. [1] Dutertre, B., de Moura, L.:
       A Fast Linear-Arithmetic Solver for DPLL(T)
       https://link.springer.com/chapter/10.1007/11817963_11
"""

from sympy.solvers.solveset import linear_eq_to_matrix
from sympy.matrices.dense import eye
from sympy.assumptions import Predicate
from sympy.assumptions.assume import AppliedPredicate
from sympy.assumptions.ask import Q
from sympy.core import Dummy
from sympy.core.mul import Mul
from sympy.core.add import Add
from sympy.core.relational import Eq, Ne
from sympy.core.sympify import sympify
from sympy.core.singleton import S
from sympy.core.numbers import Rational, oo
from sympy.matrices.dense import Matrix
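

# Illustrative sketch (hypothetical helper, not part of this module's API and
# not used by LRASolver): the lazy DPLL(T) loop described in the module
# docstring. A brute-force search over all truth assignments stands in for a
# real SAT solver. Clauses are lists of nonzero ints in the EncodedCNF style.
# `theory_conflict` is an assumed callback: it takes a {var: bool} model and
# returns None when the theory accepts the model, or a conflict clause (a list
# of ints) that the model falsifies. The search is exponential in the number
# of variables, so this is only meant for tiny inputs.
def _sketch_lazy_dpll_t(clauses, theory_conflict):
    from itertools import product

    clauses = [list(c) for c in clauses]
    while True:
        variables = sorted({abs(lit) for c in clauses for lit in c})
        model = None
        for values in product([True, False], repeat=len(variables)):
            candidate = dict(zip(variables, values))
            if all(any(candidate[abs(lit)] == (lit > 0) for lit in c) for c in clauses):
                model = candidate
                break
        if model is None:
            # boolean abstraction plus learned conflicts is unsatisfiable
            return None
        conflict = theory_conflict(model)
        if conflict is None:
            # the SAT search and the theory agree on this model
            return model
        # block this model and every model sharing the conflicting literals
        clauses.append(list(conflict))
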


class UnhandledInput(Exception):
    """
    Raised while creating an LRASolver if non-linearity
    or non-rational numbers are present.
    """


# predicates that LRASolver understands and makes use of
ALLOWED_PRED = {Q.eq, Q.gt, Q.lt, Q.le, Q.ge}

# if True, ~Q.gt(x, y) implies Q.le(x, y)
HANDLE_NEGATION = True


class LRASolver:
    """
    Linear Arithmetic Solver for DPLL(T) implemented with an algorithm based on
    the Dual Simplex method. Uses Bland's pivoting rule to avoid cycling.

    References
    ==========

    .. [1] Dutertre, B., de Moura, L.:
           A Fast Linear-Arithmetic Solver for DPLL(T)
           https://link.springer.com/chapter/10.1007/11817963_11
    """

    def __init__(self, A, slack_variables, nonslack_variables, enc_to_boundary, s_subs, testing_mode):
        """
        Use the "from_encoded_cnf" method to create a new LRASolver.
        """
        self.run_checks = testing_mode
        self.s_subs = s_subs  # used only for test_lra_theory.test_random_problems

        if any(not isinstance(a, Rational) for a in A):
            raise UnhandledInput("Non-rational numbers are not handled")
        if any(not isinstance(b.bound, Rational) for b in enc_to_boundary.values()):
            raise UnhandledInput("Non-rational numbers are not handled")
        m, n = len(slack_variables), len(slack_variables)+len(nonslack_variables)
        if m != 0:
            assert A.shape == (m, n)
        if self.run_checks:
            assert A[:, n-m:] == -eye(m)

        self.enc_to_boundary = enc_to_boundary  # mapping of int to Boundary objects
        self.boundary_to_enc = {value: key for key, value in enc_to_boundary.items()}
        self.A = A
        self.slack = slack_variables
        self.nonslack = nonslack_variables
        self.all_var = nonslack_variables + slack_variables

        self.slack_set = set(slack_variables)

        self.is_sat = True  # While True, all constraints asserted so far are satisfiable
        self.result = None  # always one of: (True, assignment), (False, conflict clause), None
  143. self.all_var = nonslack_variables + slack_variables
  144. self.slack_set = set(slack_variables)
  145. self.is_sat = True # While True, all constraints asserted so far are satisfiable
  146. self.result = None # always one of: (True, assignment), (False, conflict clause), None
  147. @staticmethod
  148. def from_encoded_cnf(encoded_cnf, testing_mode=False):
  149. """
  150. Creates an LRASolver from an EncodedCNF object
  151. and a list of conflict clauses for propositions
  152. that can be simplified to True or False.
  153. Parameters
  154. ==========
  155. encoded_cnf : EncodedCNF
  156. testing_mode : bool
  157. Setting testing_mode to True enables some slow assert statements
  158. and sorting to reduce nonterministic behavior.
  159. Returns
  160. =======
  161. (lra, conflicts)
  162. lra : LRASolver
  163. conflicts : list
  164. Contains a one-literal conflict clause for each proposition
  165. that can be simplified to True or False.
  166. Example
  167. =======
  168. >>> from sympy.core.relational import Eq
  169. >>> from sympy.assumptions.cnf import CNF, EncodedCNF
  170. >>> from sympy.assumptions.ask import Q
  171. >>> from sympy.logic.algorithms.lra_theory import LRASolver
  172. >>> from sympy.abc import x, y, z
  173. >>> phi = (x >= 0) & ((x + y <= 2) | (x + 2 * y - z >= 6))
  174. >>> phi = phi & (Eq(x + y, 2) | (x + 2 * y - z > 4))
  175. >>> phi = phi & Q.gt(2, 1)
  176. >>> cnf = CNF.from_prop(phi)
  177. >>> enc = EncodedCNF()
  178. >>> enc.from_cnf(cnf)
  179. >>> lra, conflicts = LRASolver.from_encoded_cnf(enc, testing_mode=True)
  180. >>> lra #doctest: +SKIP
  181. <sympy.logic.algorithms.lra_theory.LRASolver object at 0x7fdcb0e15b70>
  182. >>> conflicts #doctest: +SKIP
  183. [[4]]
  184. """
  185. # This function has three main jobs:
  186. # - raise errors if the input formula is not handled
  187. # - preprocesses the formula into a matrix and single variable constraints
  188. # - create one-literal conflict clauses from predicates that are always True
  189. # or always False such as Q.gt(3, 2)
  190. #
  191. # See the preprocessing section of "A Fast Linear-Arithmetic Solver for DPLL(T)"
  192. # for an explanation of how the formula is converted into a matrix
  193. # and a set of single variable constraints.
  194. encoding = {} # maps int to boundary
  195. A = []
  196. basic = []
  197. s_count = 0
  198. s_subs = {}
  199. nonbasic = []
  200. if testing_mode:
  201. # sort to reduce nondeterminism
  202. encoded_cnf_items = sorted(encoded_cnf.encoding.items(), key=lambda x: str(x))
  203. else:
  204. encoded_cnf_items = encoded_cnf.encoding.items()
  205. empty_var = Dummy()
  206. var_to_lra_var = {}
  207. conflicts = []
  208. for prop, enc in encoded_cnf_items:
  209. if isinstance(prop, Predicate):
  210. prop = prop(empty_var)
  211. if not isinstance(prop, AppliedPredicate):
  212. if prop == True:
  213. conflicts.append([enc])
  214. continue
  215. if prop == False:
  216. conflicts.append([-enc])
  217. continue
  218. raise ValueError(f"Unhandled Predicate: {prop}")
  219. assert prop.function in ALLOWED_PRED
  220. if prop.lhs == S.NaN or prop.rhs == S.NaN:
  221. raise ValueError(f"{prop} contains nan")
  222. if prop.lhs.is_imaginary or prop.rhs.is_imaginary:
  223. raise UnhandledInput(f"{prop} contains an imaginary component")
  224. if prop.lhs == oo or prop.rhs == oo:
  225. raise UnhandledInput(f"{prop} contains infinity")
  226. prop = _eval_binrel(prop) # simplify variable-less quantities to True / False if possible
  227. if prop == True:
  228. conflicts.append([enc])
  229. continue
  230. elif prop == False:
  231. conflicts.append([-enc])
  232. continue
  233. elif prop is None:
  234. raise UnhandledInput(f"{prop} could not be simplified")
  235. expr = prop.lhs - prop.rhs
  236. if prop.function in [Q.ge, Q.gt]:
  237. expr = -expr
  238. # expr should be less than (or equal to) 0
  239. # otherwise prop is False
  240. if prop.function in [Q.le, Q.ge]:
  241. bool = (expr <= 0)
  242. elif prop.function in [Q.lt, Q.gt]:
  243. bool = (expr < 0)
  244. else:
  245. assert prop.function == Q.eq
  246. bool = Eq(expr, 0)
  247. if bool == True:
  248. conflicts.append([enc])
  249. continue
  250. elif bool == False:
  251. conflicts.append([-enc])
  252. continue
  253. vars, const = _sep_const_terms(expr) # example: (2x + 3y + 2) --> (2x + 3y), (2)
  254. vars, var_coeff = _sep_const_coeff(vars) # examples: (2x) --> (x, 2); (2x + 3y) --> (2x + 3y), (1)
  255. const = const / var_coeff
  256. terms = _list_terms(vars) # example: (2x + 3y) --> [2x, 3y]
  257. for term in terms:
  258. term, _ = _sep_const_coeff(term)
  259. assert len(term.free_symbols) > 0
  260. if term not in var_to_lra_var:
  261. var_to_lra_var[term] = LRAVariable(term)
  262. nonbasic.append(term)
  263. if len(terms) > 1:
  264. if vars not in s_subs:
  265. s_count += 1
  266. d = Dummy(f"s{s_count}")
  267. var_to_lra_var[d] = LRAVariable(d)
  268. basic.append(d)
  269. s_subs[vars] = d
  270. A.append(vars - d)
  271. var = s_subs[vars]
  272. else:
  273. var = terms[0]
  274. assert var_coeff != 0
  275. equality = prop.function == Q.eq
  276. upper = var_coeff > 0 if not equality else None
  277. strict = prop.function in [Q.gt, Q.lt]
  278. b = Boundary(var_to_lra_var[var], -const, upper, equality, strict)
  279. encoding[enc] = b
  280. fs = [v.free_symbols for v in nonbasic + basic]
  281. assert all(len(syms) > 0 for syms in fs)
  282. fs_count = sum(len(syms) for syms in fs)
  283. if len(fs) > 0 and len(set.union(*fs)) < fs_count:
  284. raise UnhandledInput("Nonlinearity is not handled")
  285. A, _ = linear_eq_to_matrix(A, nonbasic + basic)
  286. nonbasic = [var_to_lra_var[nb] for nb in nonbasic]
  287. basic = [var_to_lra_var[b] for b in basic]
  288. for idx, var in enumerate(nonbasic + basic):
  289. var.col_idx = idx
  290. return LRASolver(A, basic, nonbasic, encoding, s_subs, testing_mode), conflicts
  291. def reset_bounds(self):
  292. """
  293. Resets the state of the LRASolver to before
  294. anything was asserted.
  295. """
  296. self.result = None
  297. for var in self.all_var:
  298. var.lower = LRARational(-float("inf"), 0)
  299. var.lower_from_eq = False
  300. var.lower_from_neg = False
  301. var.upper = LRARational(float("inf"), 0)
  302. var.upper_from_eq= False
  303. var.lower_from_neg = False
  304. var.assign = LRARational(0, 0)
  305. def assert_lit(self, enc_constraint):
  306. """
  307. Assert a literal representing a constraint
  308. and update the internal state accordingly.
  309. Note that due to peculiarities of this implementation
  310. asserting ~(x > 0) will assert (x <= 0) but asserting
  311. ~Eq(x, 0) will not do anything.
  312. Parameters
  313. ==========
  314. enc_constraint : int
  315. A mapping of encodings to constraints
  316. can be found in `self.enc_to_boundary`.
  317. Returns
  318. =======
  319. None or (False, explanation)
  320. explanation : set of ints
  321. A conflict clause that "explains" why
  322. the literals asserted so far are unsatisfiable.
  323. """
  324. if abs(enc_constraint) not in self.enc_to_boundary:
  325. return None
  326. if not HANDLE_NEGATION and enc_constraint < 0:
  327. return None
  328. boundary = self.enc_to_boundary[abs(enc_constraint)]
  329. sym, c, negated = boundary.var, boundary.bound, enc_constraint < 0
  330. if boundary.equality and negated:
  331. return None # negated equality is not handled and should only appear in conflict clauses
  332. upper = boundary.upper != negated
  333. if boundary.strict != negated:
  334. delta = -1 if upper else 1
  335. c = LRARational(c, delta)
  336. else:
  337. c = LRARational(c, 0)
  338. if boundary.equality:
  339. res1 = self._assert_lower(sym, c, from_equality=True, from_neg=negated)
  340. if res1 and res1[0] == False:
  341. res = res1
  342. else:
  343. res2 = self._assert_upper(sym, c, from_equality=True, from_neg=negated)
  344. res = res2
  345. elif upper:
  346. res = self._assert_upper(sym, c, from_neg=negated)
  347. else:
  348. res = self._assert_lower(sym, c, from_neg=negated)
  349. if self.is_sat and sym not in self.slack_set:
  350. self.is_sat = res is None
  351. else:
  352. self.is_sat = False
  353. return res
  354. def _assert_upper(self, xi, ci, from_equality=False, from_neg=False):
  355. """
  356. Adjusts the upper bound on variable xi if the new upper bound is
  357. more limiting. The assignment of variable xi is adjusted to be
  358. within the new bound if needed.
  359. Also calls `self._update` to update the assignment for slack variables
  360. to keep all equalities satisfied.
  361. """
  362. if self.result:
  363. assert self.result[0] != False
  364. self.result = None
  365. if ci >= xi.upper:
  366. return None
  367. if ci < xi.lower:
  368. assert (xi.lower[1] >= 0) is True
  369. assert (ci[1] <= 0) is True
  370. lit1, neg1 = Boundary.from_lower(xi)
  371. lit2 = Boundary(var=xi, const=ci[0], strict=ci[1] != 0, upper=True, equality=from_equality)
  372. if from_neg:
  373. lit2 = lit2.get_negated()
  374. neg2 = -1 if from_neg else 1
  375. conflict = [-neg1*self.boundary_to_enc[lit1], -neg2*self.boundary_to_enc[lit2]]
  376. self.result = False, conflict
  377. return self.result
  378. xi.upper = ci
  379. xi.upper_from_eq = from_equality
  380. xi.upper_from_neg = from_neg
  381. if xi in self.nonslack and xi.assign > ci:
  382. self._update(xi, ci)
  383. if self.run_checks and all(v.assign[0] != float("inf") and v.assign[0] != -float("inf")
  384. for v in self.all_var):
  385. M = self.A
  386. X = Matrix([v.assign[0] for v in self.all_var])
  387. assert all(abs(val) < 10 ** (-10) for val in M * X)
  388. return None
  389. def _assert_lower(self, xi, ci, from_equality=False, from_neg=False):
  390. """
  391. Adjusts the lower bound on variable xi if the new lower bound is
  392. more limiting. The assignment of variable xi is adjusted to be
  393. within the new bound if needed.
  394. Also calls `self._update` to update the assignment for slack variables
  395. to keep all equalities satisfied.
  396. """
  397. if self.result:
  398. assert self.result[0] != False
  399. self.result = None
  400. if ci <= xi.lower:
  401. return None
  402. if ci > xi.upper:
  403. assert (xi.upper[1] <= 0) is True
  404. assert (ci[1] >= 0) is True
  405. lit1, neg1 = Boundary.from_upper(xi)
  406. lit2 = Boundary(var=xi, const=ci[0], strict=ci[1] != 0, upper=False, equality=from_equality)
  407. if from_neg:
  408. lit2 = lit2.get_negated()
  409. neg2 = -1 if from_neg else 1
  410. conflict = [-neg1*self.boundary_to_enc[lit1],-neg2*self.boundary_to_enc[lit2]]
  411. self.result = False, conflict
  412. return self.result
  413. xi.lower = ci
  414. xi.lower_from_eq = from_equality
  415. xi.lower_from_neg = from_neg
  416. if xi in self.nonslack and xi.assign < ci:
  417. self._update(xi, ci)
  418. if self.run_checks and all(v.assign[0] != float("inf") and v.assign[0] != -float("inf")
  419. for v in self.all_var):
  420. M = self.A
  421. X = Matrix([v.assign[0] for v in self.all_var])
  422. assert all(abs(val) < 10 ** (-10) for val in M * X)
  423. return None
  424. def _update(self, xi, v):
  425. """
  426. Updates all slack variables that have equations that contain
  427. variable xi so that they stay satisfied given xi is equal to v.
  428. """
  429. i = xi.col_idx
  430. for j, b in enumerate(self.slack):
  431. aji = self.A[j, i]
  432. b.assign = b.assign + (v - xi.assign)*aji
  433. xi.assign = v
  434. def check(self):
  435. """
  436. Searches for an assignment that satisfies all constraints
  437. or determines that no such assignment exists and gives
  438. a minimal conflict clause that "explains" why the
  439. constraints are unsatisfiable.
  440. Returns
  441. =======
  442. (True, assignment) or (False, explanation)
  443. assignment : dict of LRAVariables to values
  444. Assigned values are tuples that represent a rational number
  445. plus some infinatesimal delta.
  446. explanation : set of ints
  447. """
  448. if self.is_sat:
  449. return True, {var: var.assign for var in self.all_var}
  450. if self.result:
  451. return self.result
  452. from sympy.matrices.dense import Matrix
  453. M = self.A.copy()
  454. basic = {s: i for i, s in enumerate(self.slack)} # contains the row index associated with each basic variable
  455. nonbasic = set(self.nonslack)
  456. while True:
  457. if self.run_checks:
  458. # nonbasic variables must always be within bounds
  459. assert all(((nb.assign >= nb.lower) == True) and ((nb.assign <= nb.upper) == True) for nb in nonbasic)
  460. # assignments for x must always satisfy Ax = 0
  461. # probably have to turn this off when dealing with strict ineq
  462. if all(v.assign[0] != float("inf") and v.assign[0] != -float("inf")
  463. for v in self.all_var):
  464. X = Matrix([v.assign[0] for v in self.all_var])
  465. assert all(abs(val) < 10**(-10) for val in M*X)
  466. # check upper and lower match this format:
  467. # x <= rat + delta iff x < rat
  468. # x >= rat - delta iff x > rat
  469. # this wouldn't make sense:
  470. # x <= rat - delta
  471. # x >= rat + delta
  472. assert all(x.upper[1] <= 0 for x in self.all_var)
  473. assert all(x.lower[1] >= 0 for x in self.all_var)
  474. cand = [b for b in basic if b.assign < b.lower or b.assign > b.upper]
  475. if len(cand) == 0:
  476. return True, {var: var.assign for var in self.all_var}
  477. xi = min(cand, key=lambda v: v.col_idx) # Bland's rule
  478. i = basic[xi]
  479. if xi.assign < xi.lower:
  480. cand = [nb for nb in nonbasic
  481. if (M[i, nb.col_idx] > 0 and nb.assign < nb.upper)
  482. or (M[i, nb.col_idx] < 0 and nb.assign > nb.lower)]
  483. if len(cand) == 0:
  484. N_plus = [nb for nb in nonbasic if M[i, nb.col_idx] > 0]
  485. N_minus = [nb for nb in nonbasic if M[i, nb.col_idx] < 0]
  486. conflict = []
  487. conflict += [Boundary.from_upper(nb) for nb in N_plus]
  488. conflict += [Boundary.from_lower(nb) for nb in N_minus]
  489. conflict.append(Boundary.from_lower(xi))
  490. conflict = [-neg*self.boundary_to_enc[c] for c, neg in conflict]
  491. return False, conflict
  492. xj = min(cand, key=str)
  493. M = self._pivot_and_update(M, basic, nonbasic, xi, xj, xi.lower)
  494. if xi.assign > xi.upper:
  495. cand = [nb for nb in nonbasic
  496. if (M[i, nb.col_idx] < 0 and nb.assign < nb.upper)
  497. or (M[i, nb.col_idx] > 0 and nb.assign > nb.lower)]
  498. if len(cand) == 0:
  499. N_plus = [nb for nb in nonbasic if M[i, nb.col_idx] > 0]
  500. N_minus = [nb for nb in nonbasic if M[i, nb.col_idx] < 0]
  501. conflict = []
  502. conflict += [Boundary.from_upper(nb) for nb in N_minus]
  503. conflict += [Boundary.from_lower(nb) for nb in N_plus]
  504. conflict.append(Boundary.from_upper(xi))
  505. conflict = [-neg*self.boundary_to_enc[c] for c, neg in conflict]
  506. return False, conflict
  507. xj = min(cand, key=lambda v: v.col_idx)
  508. M = self._pivot_and_update(M, basic, nonbasic, xi, xj, xi.upper)
  509. def _pivot_and_update(self, M, basic, nonbasic, xi, xj, v):
  510. """
  511. Pivots basic variable xi with nonbasic variable xj,
  512. and sets value of xi to v and adjusts the values of all basic variables
  513. to keep equations satisfied.
  514. """
  515. i, j = basic[xi], xj.col_idx
  516. assert M[i, j] != 0
  517. theta = (v - xi.assign)*(1/M[i, j])
  518. xi.assign = v
  519. xj.assign = xj.assign + theta
  520. for xk in basic:
  521. if xk != xi:
  522. k = basic[xk]
  523. akj = M[k, j]
  524. xk.assign = xk.assign + theta*akj
  525. # pivot
  526. basic[xj] = basic[xi]
  527. del basic[xi]
  528. nonbasic.add(xi)
  529. nonbasic.remove(xj)
  530. return self._pivot(M, i, j)
  531. @staticmethod
  532. def _pivot(M, i, j):
  533. """
  534. Performs a pivot operation about entry i, j of M by performing
  535. a series of row operations on a copy of M and returning the result.
  536. The original M is left unmodified.
  537. Conceptually, M represents a system of equations and pivoting
  538. can be thought of as rearranging equation i to be in terms of
  539. variable j and then substituting in the rest of the equations
  540. to get rid of other occurances of variable j.
  541. Example
  542. =======
  543. >>> from sympy.matrices.dense import Matrix
  544. >>> from sympy.logic.algorithms.lra_theory import LRASolver
  545. >>> from sympy import var
  546. >>> Matrix(3, 3, var('a:i'))
  547. Matrix([
  548. [a, b, c],
  549. [d, e, f],
  550. [g, h, i]])
  551. This matrix is equivalent to:
  552. 0 = a*x + b*y + c*z
  553. 0 = d*x + e*y + f*z
  554. 0 = g*x + h*y + i*z
  555. >>> LRASolver._pivot(_, 1, 0)
  556. Matrix([
  557. [ 0, -a*e/d + b, -a*f/d + c],
  558. [-1, -e/d, -f/d],
  559. [ 0, h - e*g/d, i - f*g/d]])
  560. We rearrange equation 1 in terms of variable 0 (x)
  561. and substitute to remove x from the other equations.
  562. 0 = 0 + (-a*e/d + b)*y + (-a*f/d + c)*z
  563. 0 = -x + (-e/d)*y + (-f/d)*z
  564. 0 = 0 + (h - e*g/d)*y + (i - f*g/d)*z
  565. """
  566. _, _, Mij = M[i, :], M[:, j], M[i, j]
  567. if Mij == 0:
  568. raise ZeroDivisionError("Tried to pivot about zero-valued entry.")
  569. A = M.copy()
  570. A[i, :] = -A[i, :]/Mij
  571. for row in range(M.shape[0]):
  572. if row != i:
  573. A[row, :] = A[row, :] + A[row, j] * A[i, :]
  574. return A
  575. def _sep_const_coeff(expr):
  576. """
  577. Example
  578. =======
  579. >>> from sympy.logic.algorithms.lra_theory import _sep_const_coeff
  580. >>> from sympy.abc import x, y
  581. >>> _sep_const_coeff(2*x)
  582. (x, 2)
  583. >>> _sep_const_coeff(2*x + 3*y)
  584. (2*x + 3*y, 1)
  585. """
  586. if isinstance(expr, Add):
  587. return expr, sympify(1)
  588. if isinstance(expr, Mul):
  589. coeffs = expr.args
  590. else:
  591. coeffs = [expr]
  592. var, const = [], []
  593. for c in coeffs:
  594. c = sympify(c)
  595. if len(c.free_symbols)==0:
  596. const.append(c)
  597. else:
  598. var.append(c)
  599. return Mul(*var), Mul(*const)


def _list_terms(expr):
    if not isinstance(expr, Add):
        return [expr]

    return expr.args


def _sep_const_terms(expr):
    """
    Example
    =======

    >>> from sympy.logic.algorithms.lra_theory import _sep_const_terms
    >>> from sympy.abc import x, y
    >>> _sep_const_terms(2*x + 3*y + 2)
    (2*x + 3*y, 2)
    """
    if isinstance(expr, Add):
        terms = expr.args
    else:
        terms = [expr]

    var, const = [], []
    for t in terms:
        if len(t.free_symbols) == 0:
            const.append(t)
        else:
            var.append(t)
    return sum(var), sum(const)
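

# Illustrative sketch (hypothetical helper, not called anywhere): the
# preprocessing idea from from_encoded_cnf, applied to plain data instead of
# sympy expressions. Each constraint is (coeffs, op, const), read as
# sum(coeffs[v] * v) op const. A single-variable constraint becomes a bound
# on that variable (dividing by a negative coefficient flips the relation).
# A multi-variable left-hand side gets a slack variable, and identical
# left-hand sides share one slack, much like the s_subs dictionary in
# from_encoded_cnf. Returns (rows, bounds), where rows maps each slack name
# to the coefficients it stands for.
def _sketch_introduce_slacks(constraints):
    from fractions import Fraction

    flip = {"<": ">", "<=": ">=", ">": "<", ">=": "<=", "==": "=="}
    slack_of = {}  # frozenset of (var, coeff) pairs -> slack name
    bounds = []
    for coeffs, op, const in constraints:
        if len(coeffs) == 1:
            (var, a), = coeffs.items()
            bound = Fraction(const) / a
            bounds.append((var, flip[op] if a < 0 else op, bound))
        else:
            key = frozenset(coeffs.items())
            if key not in slack_of:
                slack_of[key] = f"s{len(slack_of) + 1}"
            bounds.append((slack_of[key], op, Fraction(const)))
    rows = {s: dict(key) for key, s in slack_of.items()}
    return rows, bounds
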


def _eval_binrel(binrel):
    """
    Simplify a binary relation to True / False if possible.

    A relation that contains free symbols is returned unchanged.
    If both sides are free of symbols but the relation cannot be
    decided, None is returned.
    """
    if not (len(binrel.lhs.free_symbols) == 0 and len(binrel.rhs.free_symbols) == 0):
        return binrel
    if binrel.function == Q.lt:
        res = binrel.lhs < binrel.rhs
    elif binrel.function == Q.gt:
        res = binrel.lhs > binrel.rhs
    elif binrel.function == Q.le:
        res = binrel.lhs <= binrel.rhs
    elif binrel.function == Q.ge:
        res = binrel.lhs >= binrel.rhs
    elif binrel.function == Q.eq:
        res = Eq(binrel.lhs, binrel.rhs)
    elif binrel.function == Q.ne:
        res = Ne(binrel.lhs, binrel.rhs)

    if res == True or res == False:
        return res
    else:
        return None


class Boundary:
    """
    Represents an upper or lower bound or an equality between a symbol
    and some constant.
    """
    def __init__(self, var, const, upper, equality, strict=None):
        assert equality in [True, False]

        self.var = var
        if isinstance(const, tuple):
            s = const[1] != 0
            if strict:
                assert s == strict
            self.bound = const[0]
            self.strict = s
        else:
            self.bound = const
            self.strict = strict
        self.upper = upper if not equality else None
        self.equality = equality
        assert self.strict is not None

    @staticmethod
    def from_upper(var):
        neg = -1 if var.upper_from_neg else 1
        b = Boundary(var, var.upper[0], True, var.upper_from_eq, var.upper[1] != 0)
        if neg < 0:
            b = b.get_negated()
        return b, neg

    @staticmethod
    def from_lower(var):
        neg = -1 if var.lower_from_neg else 1
        b = Boundary(var, var.lower[0], False, var.lower_from_eq, var.lower[1] != 0)
        if neg < 0:
            b = b.get_negated()
        return b, neg

    def get_negated(self):
        return Boundary(self.var, self.bound, not self.upper, self.equality, not self.strict)

    def get_inequality(self):
        if self.equality:
            return Eq(self.var.var, self.bound)
        elif self.upper and self.strict:
            return self.var.var < self.bound
        elif not self.upper and self.strict:
            return self.var.var > self.bound
        elif self.upper:
            return self.var.var <= self.bound
        else:
            return self.var.var >= self.bound

    def __repr__(self):
        return repr("Boundary(" + repr(self.get_inequality()) + ")")

    def __eq__(self, other):
        other = (other.var, other.bound, other.strict, other.upper, other.equality)
        return (self.var, self.bound, self.strict, self.upper, self.equality) == other

    def __hash__(self):
        return hash((self.var, self.bound, self.strict, self.upper, self.equality))
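

# Illustrative sketch (hypothetical helper, not used by the solver): how
# assert_lit turns a possibly negated, possibly strict bound on x into a
# non-strict bound on a (rational, delta) pair. Negation swaps upper and
# lower and toggles strictness, so ~(x < c) becomes x >= c. A strict upper
# bound x < c is stored as x <= (c, -1) and a strict lower bound x > c as
# x >= (c, 1), where the second component counts infinitesimal deltas.
# Returns (is_upper, (c, delta)).
def _sketch_bound_to_delta(c, upper, strict, negated=False):
    is_upper = upper != negated
    is_strict = strict != negated
    delta = (-1 if is_upper else 1) if is_strict else 0
    return is_upper, (c, delta)
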


class LRARational:
    """
    Represents a rational plus or minus some amount
    of arbitrarily small deltas, stored as a (rational, delta)
    pair and compared lexicographically.
    """
    def __init__(self, rational, delta):
        self.value = (rational, delta)

    def __lt__(self, other):
        return self.value < other.value

    def __le__(self, other):
        return self.value <= other.value

    def __eq__(self, other):
        return self.value == other.value

    def __add__(self, other):
        return LRARational(self.value[0] + other.value[0], self.value[1] + other.value[1])

    def __sub__(self, other):
        return LRARational(self.value[0] - other.value[0], self.value[1] - other.value[1])

    def __mul__(self, other):
        assert not isinstance(other, LRARational)
        return LRARational(self.value[0] * other, self.value[1] * other)

    def __getitem__(self, index):
        return self.value[index]

    def __repr__(self):
        return repr(self.value)
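

# Illustrative sketch (hypothetical helper, not used by the solver):
# LRARational orders (rational, delta) pairs lexicographically, which matches
# how values of the form c + k*delta order when delta is a positive
# infinitesimal. To read off a concrete rational assignment, one can pick a
# small positive delta. For each required relation (c1, k1) <= (c2, k2) with
# c1 < c2 and k1 > k2, choosing delta <= (c2 - c1) / (k1 - k2) keeps
# c1 + k1*delta <= c2 + k2*delta; every other pair holds for any positive
# delta. The paper cited in the module docstring describes a construction of
# this kind; this helper is only a sketch of it.
def _sketch_concrete_delta(pairs):
    from fractions import Fraction

    delta = Fraction(1)
    for (c1, k1), (c2, k2) in pairs:
        assert (c1, k1) <= (c2, k2)
        if c1 < c2 and k1 > k2:
            delta = min(delta, Fraction(c2 - c1) / (k1 - k2))
    return delta
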


class LRAVariable:
    """
    Object to keep track of upper and lower bounds
    on `self.var`.
    """

    def __init__(self, var):
        self.upper = LRARational(float("inf"), 0)
        self.upper_from_eq = False
        self.upper_from_neg = False
        self.lower = LRARational(-float("inf"), 0)
        self.lower_from_eq = False
        self.lower_from_neg = False
        self.assign = LRARational(0, 0)
        self.var = var
        self.col_idx = None

    def __repr__(self):
        return repr(self.var)

    def __eq__(self, other):
        if not isinstance(other, LRAVariable):
            return False
        return other.var == self.var

    def __hash__(self):
        return hash(self.var)
  745. return False
  746. return other.var == self.var
  747. def __hash__(self):
  748. return hash(self.var)