interp.py 7.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225
  1. # mypy: allow-untyped-defs
  2. """
  3. This is a simple interpreter for Sympy expressions that dispatches to
  4. classes following the torch._inductor.virtualized calling convention.
  5. For directness, the interpreter takes the handler directly rather than
  6. consulting the TLS. It does not use most of the methods on the full
  7. handler; only those with corresponding Sympy expressions. To see an example
  8. of a full handler, see torch.utils._sympy.value_ranges.ValueRangeAnalysis.
  9. """
  10. import functools
  11. import logging
  12. from typing import Any, Union
  13. import sympy
  14. from sympy.logic.boolalg import Boolean as SympyBoolean, BooleanAtom
  15. import torch
  16. from .functions import (
  17. BitwiseFn_bitwise_and,
  18. BitwiseFn_bitwise_or,
  19. CeilToInt,
  20. CleanDiv,
  21. FloatPow,
  22. FloatTrueDiv,
  23. FloorDiv,
  24. FloorToInt,
  25. Identity,
  26. IntTrueDiv,
  27. IsNonOverlappingAndDenseIndicator,
  28. Max,
  29. Min,
  30. Mod,
  31. ModularIndexing,
  32. OpaqueUnaryFn_log2,
  33. PowByNatural,
  34. PythonMod,
  35. RoundDecimal,
  36. RoundToInt,
  37. ToFloat,
  38. TruncToFloat,
  39. TruncToInt,
  40. Where,
  41. )
  42. log = logging.getLogger(__name__)
  43. # TODO: Dedupe this with SYMPY_INTERP
  44. @functools.cache
  45. def handlers():
  46. # TODO add CeilDiv (it doesn't appear in the index_expr)
  47. # TODO default to some decompositions if the interpreter doesn't have them
  48. # like decomposing ModularIndexing or implementing Le(a,b) as Ge(b, a)
  49. HANDLERS = {
  50. sympy.Or: "or_",
  51. sympy.And: "and_",
  52. sympy.Eq: "eq",
  53. sympy.Ne: "ne",
  54. sympy.Lt: "lt",
  55. sympy.Gt: "gt",
  56. sympy.Le: "le",
  57. sympy.Ge: "ge",
  58. sympy.Not: "not_",
  59. IntTrueDiv: "int_truediv",
  60. FloatTrueDiv: "truediv",
  61. FloorDiv: "floordiv",
  62. CleanDiv: "floordiv", # TODO: hmm?
  63. TruncToFloat: "trunc",
  64. Where: "where",
  65. sympy.Add: "add",
  66. sympy.Mul: "mul",
  67. FloatPow: "pow",
  68. PowByNatural: "pow_by_natural",
  69. # sympy simplifies x * x into Pow(x, 2), so we need to handle this.
  70. # Do NOT use builtin Pow for floats
  71. # TODO: There is a hazard here, if we have float * float it will
  72. # also get turned into Pow(float, 2) but we don't want this because
  73. # pow_by_natural is assumed to only be integers. Probably the fix is
  74. # to add a FloatMul to impede this optimization
  75. sympy.Pow: "pow_by_natural",
  76. Mod: "mod",
  77. PythonMod: "mod", # TODO: this is wrong
  78. # TODO: Inductor can generate these, but it's ill-specified which
  79. # semantics were intended here. Needs to be cleaned up along with
  80. # FloorDiv in a bigger cleanup
  81. sympy.Mod: "mod",
  82. sympy.Abs: "abs",
  83. sympy.log: "log",
  84. sympy.exp: "exp",
  85. sympy.Min: "minimum",
  86. sympy.Max: "maximum",
  87. Min: "minimum",
  88. Max: "maximum",
  89. ModularIndexing: "modular_indexing",
  90. sympy.functions.elementary.piecewise.ExprCondPair: "expr_cond_pair",
  91. sympy.Piecewise: "piecewise",
  92. Identity: "identity",
  93. IsNonOverlappingAndDenseIndicator: "is_non_overlapping_and_dense_indicator",
  94. RoundDecimal: "round_decimal",
  95. # TODO: do the rest of the opaque unary functions...
  96. OpaqueUnaryFn_log2: "log2",
  97. BitwiseFn_bitwise_and: "bitwise_and",
  98. BitwiseFn_bitwise_or: "bitwise_or",
  99. }
  100. # TODO: This is kind of pointless, we shouldn't be generating sympy.sin
  101. # for these functions, they should be Opaque instead
  102. for name in ["cos", "sin", "tan", "sinh", "cosh", "tanh", "asin", "acos", "atan"]:
  103. HANDLERS[getattr(sympy, name)] = name
  104. return HANDLERS
  105. ASSOCIATIVE_OPS = {"minimum", "maximum", "mul", "add", "and_", "or_"}
  106. def _run_sympy_handler(analysis, args, expr, index_dtype=torch.int64):
  107. # Special cases
  108. if isinstance(expr, sympy.Pow) and isinstance(
  109. expr.args[1], sympy.core.numbers.Half
  110. ):
  111. return analysis.sqrt(args[0])
  112. if isinstance(expr, ToFloat):
  113. return analysis.to_dtype(args[0], torch.float64)
  114. # These handlers are special because they take an extra dtype argument
  115. # specifying what they should convert to, and we need to appropriately set
  116. # this up when we convert from Sympy. A reasonable default when you
  117. # are translating is to conservatively do int64, and then narrow these
  118. # arguments later when you discover you can narrow the index range. But
  119. # if you already know that 32-bit indexing is OK, you can directly do the
  120. # sympy translation with index_dtype=torch.int32
  121. INDEX_DTYPE_HANDLERS = {
  122. TruncToInt: "trunc_to_int",
  123. sympy.floor: "floor_to_int",
  124. sympy.ceiling: "ceil_to_int",
  125. FloorToInt: "floor_to_int",
  126. CeilToInt: "ceil_to_int",
  127. RoundToInt: "round_to_int",
  128. }
  129. if (handler_name := INDEX_DTYPE_HANDLERS.get(expr.func)) is not None:
  130. return getattr(analysis, handler_name)(*args, index_dtype)
  131. # Fastpath for n-ary integral addition
  132. if expr.func is sympy.Add and expr.is_integer and hasattr(analysis, "sym_sum"):
  133. r = analysis.sym_sum(args)
  134. log.debug("sym_sum(%s) -> %s", args, r)
  135. return r
  136. if hasattr(expr.func, "_torch_handler_name"):
  137. handler_name = expr.func._torch_handler_name
  138. else:
  139. handler_name = handlers()[expr.func]
  140. handler = getattr(analysis, handler_name)
  141. try:
  142. if handler_name in ASSOCIATIVE_OPS:
  143. assert len(args) > 1
  144. acc = handler(args[0], args[1])
  145. for i in range(2, len(args)):
  146. acc = handler(acc, args[i])
  147. log.debug("%s(%s) -> %s", handler_name, args, acc)
  148. return acc
  149. else:
  150. r = handler(*args)
  151. log.debug("%s(%s) -> %s", handler_name, args, r)
  152. return r
  153. except NotImplementedError:
  154. raise
  155. except Exception:
  156. log.warning("failed while executing %s(%s)", handler_name, args)
  157. raise
  158. _nil = object()
  159. def sympy_interp(
  160. analysis,
  161. env: dict[sympy.Symbol, Any],
  162. expr: Union[sympy.Expr, SympyBoolean],
  163. *,
  164. index_dtype=torch.int64,
  165. missing_handler=None,
  166. ):
  167. # Handle base cases
  168. dtype = None
  169. if isinstance(expr, BooleanAtom):
  170. dtype = torch.bool
  171. elif isinstance(expr, sympy.Integer):
  172. dtype = torch.int64
  173. elif isinstance(expr, sympy.Number):
  174. dtype = torch.double
  175. if dtype is not None:
  176. return analysis.constant(expr, dtype)
  177. elif isinstance(expr, sympy.Symbol):
  178. if (r := env.get(expr, _nil)) is not _nil:
  179. return r
  180. elif missing_handler:
  181. return missing_handler(expr)
  182. else:
  183. raise KeyError(expr)
  184. # Recursive case
  185. return _run_sympy_handler(
  186. analysis,
  187. [
  188. sympy_interp(
  189. analysis,
  190. env,
  191. arg,
  192. index_dtype=index_dtype,
  193. missing_handler=missing_handler,
  194. )
  195. for arg in expr.args
  196. ], # type: ignore[arg-type]
  197. expr,
  198. index_dtype=index_dtype,
  199. ) # type: ignore[arg-type]