# solve.py
  1. import logging
  2. from typing import Optional
  3. import sympy
  4. from torch.utils._sympy.functions import FloorDiv
  5. log = logging.getLogger(__name__)
  6. _MIRROR_REL_OP: dict[type[sympy.Basic], type[sympy.Rel]] = {
  7. sympy.Eq: sympy.Eq,
  8. sympy.Ne: sympy.Ne,
  9. sympy.Ge: sympy.Le,
  10. sympy.Gt: sympy.Lt,
  11. sympy.Le: sympy.Ge,
  12. sympy.Lt: sympy.Gt,
  13. }
  14. INEQUALITY_TYPES = (sympy.Gt, sympy.Ge, sympy.Lt, sympy.Le)
  15. def mirror_rel_op(type: type) -> Optional[type[sympy.Rel]]:
  16. return _MIRROR_REL_OP.get(type, None)
  17. # Tries to simplify 'expr', so as to leave only 'thing' in the left-hand side.
  18. #
  19. # Returns a tuple of:
  20. # 1. The simplified expression
  21. # 2. The expression on the right-hand side
  22. #
  23. # Returns 'None' if it can't reach a state where the only thing in the left
  24. # hand side is 'thing'.
  25. #
  26. # 'trials': number of times 'try_solve' will try to isolate 'thing' to the
  27. # left-hand side.
  28. #
  29. # 'floordiv_inequality': flag to enable conversion of 'FloorDiv' into
  30. # inequalities.
  31. def try_solve(
  32. expr: sympy.Basic,
  33. thing: sympy.Basic,
  34. trials: int = 5,
  35. floordiv_inequality: bool = True,
  36. ) -> Optional[tuple[sympy.Rel, sympy.Expr]]:
  37. mirror = mirror_rel_op(type(expr))
  38. # Ignore unsupported expressions:
  39. # - Those that are not relational operations
  40. # - Those that don't have a mirror (just avoiding unexpected classes)
  41. if not isinstance(expr, sympy.Rel) or mirror is None:
  42. log.debug("expression with unsupported type: %s", type(expr))
  43. return None
  44. lhs_has_thing = expr.lhs.has(thing)
  45. rhs_has_thing = expr.rhs.has(thing)
  46. # Give up when 'thing' appears on both sides of the relational expression.
  47. # That is because, as is, we assume the thing we are trying to isolate is
  48. # only on the right-hand side.
  49. if lhs_has_thing and rhs_has_thing:
  50. log.debug("thing (%s) found in both sides of expression: %s", thing, expr)
  51. return None
  52. # Try considering both LHS and RHS by mirroring the original expression:
  53. # a < b ==> b > a
  54. expressions = []
  55. # Add each version of 'expr' if 'thing' is in its left-hand side.
  56. if lhs_has_thing:
  57. expressions.append(expr)
  58. if rhs_has_thing:
  59. expressions.append(mirror(expr.rhs, expr.lhs))
  60. for e in expressions:
  61. if e is None:
  62. continue
  63. assert isinstance(e, sympy.Rel)
  64. for _ in range(trials):
  65. trial = _try_isolate_lhs(e, thing, floordiv_inequality=floordiv_inequality)
  66. # Stop if there was no change in this trial.
  67. if trial == e:
  68. break
  69. e = trial # type: ignore[assignment]
  70. # Return if we were able to isolate 'thing' on the left-hand side.
  71. if isinstance(e, sympy.Rel) and e.lhs == thing:
  72. log.debug("solved: %s ---> %s", expr, e)
  73. return e, e.rhs
  74. return None
  75. def _try_isolate_lhs(
  76. e: sympy.Basic, thing: sympy.Basic, floordiv_inequality: bool
  77. ) -> sympy.Basic:
  78. op = type(e)
  79. if isinstance(e, sympy.Rel):
  80. # Move any constants in the left-hand side to the right-hand side.
  81. lhs_not_thing = (
  82. sum(a for a in e.lhs.args if not a.has(thing))
  83. if isinstance(e.lhs, sympy.Add)
  84. else 0
  85. )
  86. e = op(e.lhs - lhs_not_thing, e.rhs - lhs_not_thing) # type: ignore[attr-defined]
  87. # Divide both sides by the factors that don't contain thing.
  88. if isinstance(e, sympy.Rel) and isinstance(e.lhs, sympy.Mul):
  89. lhs, rhs = e.args
  90. other = sympy.Mul(*[a for a in lhs.args if not a.has(thing)])
  91. # If we can't tell whether 'other' is negative or positive, we do nothing.
  92. # That is because we don't know whether we have mirror the operation or not.
  93. # We also divide only when we know 'rhs' is not zero.
  94. if not (isinstance(e, INEQUALITY_TYPES) and other.is_negative is None) and not (
  95. not isinstance(e, INEQUALITY_TYPES) and rhs.is_zero
  96. ):
  97. # Divide both sides by 'other'.
  98. lhs = lhs / other
  99. rhs = rhs / other
  100. # If 'e' is an inequality and 'other' is negative, we have to
  101. # mirror the expression.
  102. if isinstance(e, INEQUALITY_TYPES) and other.is_negative:
  103. op = mirror_rel_op(op) # type: ignore[assignment]
  104. assert op is not None
  105. e = op(lhs, rhs)
  106. ################################################################################
  107. # left-hand side is FloorDiv
  108. ################################################################################
  109. #
  110. # Given the expression: a // b op c
  111. # where 'op' is a relational operation, these rules only work if:
  112. # - b > 0
  113. # - c is an integer
  114. if (
  115. floordiv_inequality
  116. and isinstance(e, sympy.Rel)
  117. and isinstance(e.lhs, FloorDiv)
  118. and e.lhs.divisor.is_positive
  119. and e.rhs.is_integer
  120. ):
  121. # a // b == expr
  122. # => a >= (b * expr) and a < (b * (expr + 1))
  123. if isinstance(e, sympy.Eq):
  124. numerator, denominator = e.lhs.args
  125. return sympy.And(
  126. sympy.Ge(numerator, (e.rhs * denominator)), # type: ignore[arg-type]
  127. sympy.Lt(numerator, ((e.rhs + 1) * denominator)), # type: ignore[arg-type]
  128. )
  129. # a // b != expr
  130. # => a < (b * expr) or a >= (b * (expr + 1))
  131. if isinstance(e, sympy.Ne):
  132. numerator, denominator = e.lhs.args
  133. return sympy.Or(
  134. sympy.Lt(numerator, (e.rhs * denominator)), # type: ignore[arg-type]
  135. sympy.Ge(numerator, ((e.rhs + 1) * denominator)), # type: ignore[arg-type]
  136. )
  137. # The transformations below only work if b is positive.
  138. # Note: we only have this information for constants.
  139. # a // b > expr => a >= b * (expr + 1)
  140. # a // b >= expr => a >= b * expr
  141. if isinstance(e, (sympy.Gt, sympy.Ge)):
  142. quotient = e.rhs if isinstance(e, sympy.Ge) else (e.rhs + 1) # type: ignore[arg-type]
  143. return sympy.Ge(e.lhs.args[0], (quotient * e.lhs.args[1])) # type: ignore[arg-type]
  144. # a // b < expr => a < b * expr
  145. # a // b <= expr => a < b * (expr + 1)
  146. if isinstance(e, (sympy.Lt, sympy.Le)):
  147. quotient = e.rhs if isinstance(e, sympy.Lt) else (e.rhs + 1) # type: ignore[arg-type]
  148. return sympy.Lt(e.lhs.args[0], (quotient * e.lhs.args[1])) # type: ignore[arg-type]
  149. return e