bounds.py 9.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259
  1. import logging
  2. import operator
  3. from functools import partial
  4. from typing import Any, Callable, Optional, Union
  5. import sympy
  6. from sympy import Expr
  7. import torch
  8. from torch.utils._sympy.value_ranges import (
  9. bound_sympy,
  10. SymPyValueRangeAnalysis,
  11. ValueRanges,
  12. )
  13. from ..utils._sympy.functions import PowByNatural
  14. from ..utils._sympy.numbers import int_oo
  15. from .loop_body import InterpreterShim, LoopBody, LoopBodyBlock
  16. from .ops_handler import DefaultHandler, ReductionType, StoreMode
  17. from .utils import cache_on_self, dominated_nodes
  18. from .virtualized import V
  19. log = logging.getLogger(__name__)
  20. class BoundVars:
  21. """
  22. Performs Value Range Analysis on LoopBody's fx graph by calling BoundVars.run()
  23. It exposes the ranges of the nodes in the `bounds` variable
  24. Note. A current limitation of this analysis is that it just works on a per-loop basis.
  25. We should be able to propagate the bounds between across the whole graph. This may benefit
  26. the case a bounded variable is returned by a kernel and fed into another.
  27. """
  28. def __init__(self, loop_body: LoopBody) -> None:
  29. def upper_bound(v: Union[Expr, int]) -> int:
  30. return bound_sympy(v).upper if isinstance(v, Expr) else v
  31. self.loop_body = loop_body
  32. self.replacement_vals = {
  33. k: ValueRanges[Expr](0, upper_bound(v) - 1)
  34. for k, v in loop_body.var_ranges.items()
  35. }
  36. # avoid computing these values, pessimistically assume that they are unbounded
  37. self.unbounded_vars = dominated_nodes(
  38. node
  39. for node in self.loop_body.get_nodes()
  40. if node.target in ["load", "reduction", operator.getitem]
  41. or "masked_subblock" in node.target
  42. )
  43. # To access this variable call `get_bounds()`
  44. self._bounds: dict[torch.fx.Node, ValueRanges[Expr]] = {}
  45. def __repr__(self) -> str:
  46. return (
  47. f"{self.__class__.__name__}("
  48. f"loop_body={self.loop_body},\n "
  49. f"replacement_vals={self.replacement_vals}, \n"
  50. f"unbounded_vars={self.unbounded_vars}, \n"
  51. f"_bounds={self._bounds})"
  52. )
  53. @cache_on_self
  54. def get_bounds(self) -> dict[torch.fx.Node, ValueRanges[Expr]]:
  55. submodules = self.swap_submodules(self.loop_body.submodules)
  56. # Initialize the environment with the unbounded variables
  57. for node in self.unbounded_vars:
  58. # we need to evaluate masked_subblock to recurse, and we need to set indirect values
  59. if not isinstance(node.target, str) or (
  60. "masked_subblock" not in node.target
  61. and "set_indirect" not in node.target
  62. ):
  63. self._bounds[node] = ValueRanges[Expr].unknown()
  64. with V.set_ops_handler(ValueRangeAnalysis()):
  65. interpreter = InterpreterShim(self.loop_body.root_block.graph, submodules)
  66. log.debug("get_bounds:\n%s", self.loop_body.root_block.graph)
  67. interpreter.run(V.get_ops_handler(), initial_env=self._bounds)
  68. return self._bounds
  69. def swap_submodules(
  70. self, submodules: dict[str, Callable[..., Any]]
  71. ) -> dict[str, Callable[..., ValueRanges[Expr]]]:
  72. result: dict[str, Callable[..., ValueRanges[Expr]]] = {}
  73. for key in submodules.keys():
  74. if key == "get_index":
  75. result[key] = self.get_index
  76. elif "masked_subblock" in key:
  77. subblock = self.loop_body.subblocks[key]
  78. # The result within the lambda will reference to the final
  79. # set of modules at the end of the for-loop as it stores a reference to it
  80. # bind subblock in a function because python lambdas close over by reference
  81. # moving the lambda out of make_fn would close over the reference to subblock,
  82. # so all lambdas would have the same subblock reference that is the final
  83. # subblock in the loop
  84. def make_fn(
  85. subblock: LoopBodyBlock,
  86. ) -> Callable[[Any, Any], ValueRanges[Expr]]:
  87. return lambda mask, value: self.masked_subblock(
  88. subblock, self._bounds, mask, value, result
  89. )
  90. result[key] = make_fn(subblock)
  91. elif "set_indirect" in key:
  92. idx = int(key[len("set_indirect") :])
  93. var = self.loop_body.indirect_vars[idx]
  94. indirect = partial(self.set_indirect, var)
  95. result[key] = indirect
  96. else:
  97. assert "scan" in key
  98. result[key] = submodules[key]
  99. return result
  100. def masked_subblock(
  101. self,
  102. subblock: LoopBodyBlock,
  103. env: dict[torch.fx.Node, ValueRanges[Expr]],
  104. mask: Any,
  105. value: Any,
  106. submodules: dict[str, Callable[..., Any]],
  107. ) -> ValueRanges[Expr]:
  108. interp = InterpreterShim(subblock.graph, submodules)
  109. interp.run(V.get_ops_handler(), initial_env=env)
  110. output = [node for node in subblock.graph.nodes if node.target == "output"]
  111. assert len(output) == 1
  112. # dont bother unioning with value since the load from buffer will be
  113. # pessimistically assumed to be inf anyway
  114. return interp.env[output[0]]
  115. def set_indirect(self, old: Expr, new: ValueRanges[Expr]) -> ValueRanges[Expr]:
  116. assert isinstance(new, ValueRanges)
  117. self.replacement_vals[old] = new
  118. return new
  119. def get_index(self, name: str) -> ValueRanges[Expr]:
  120. expr = self.loop_body.indexing_exprs[name]
  121. bound = self.replacement_vals.get(expr)
  122. if bound is None:
  123. bound = bound_sympy(expr, self.replacement_vals)
  124. # The following assertion is true at the time of this writing
  125. # We don't assert is as to not execute bound_sympy when bound is not None
  126. # assert bound is None or bound == bound_sympy(expr, self.replacement_vals)
  127. self.replacement_vals[name] = bound
  128. return bound
  129. class ValueRangeAnalysis(SymPyValueRangeAnalysis, DefaultHandler):
  130. def __init__(self) -> None:
  131. self.name = "ValueRangeAnalysis"
  132. boolean_operators = (
  133. "xor",
  134. "logical_and",
  135. "logical_or",
  136. "logical_not",
  137. )
  138. for op in boolean_operators:
  139. setattr(self, op, self.bool_handler)
  140. @staticmethod
  141. def bool_handler(*args: Any, **kwargs: Any) -> ValueRanges[Any]:
  142. # just assuming bools can have both values
  143. return ValueRanges(sympy.false, sympy.true) # type: ignore[arg-type]
  144. def _default(self, name: str, args: tuple[Any, ...], kwargs: dict[str, Any]) -> Any:
  145. # many ops are unlikely to show up in optimizable indexing compute,
  146. # so we dont have full coverage
  147. return ValueRanges.unknown()
  148. def load(self, name: str, index: sympy.Expr) -> ValueRanges[Any]:
  149. return ValueRanges.unknown()
  150. def store(
  151. self, name: str, index: sympy.Expr, value: Any, mode: StoreMode = None
  152. ) -> None:
  153. return
  154. def reduction(
  155. self,
  156. dtype: torch.dtype,
  157. src_dtype: torch.dtype,
  158. reduction_type: ReductionType,
  159. value: Any,
  160. ) -> ValueRanges[Any]:
  161. return ValueRanges.unknown()
  162. @classmethod
  163. def index_expr(cls, index: Any, dtype: torch.dtype) -> ValueRanges[Any]:
  164. assert isinstance(index, ValueRanges)
  165. return cls.to_dtype(index, dtype)
  166. @staticmethod
  167. def to_dtype(
  168. x: Any,
  169. dtype: torch.dtype,
  170. src_dtype: Optional[torch.dtype] = None,
  171. use_compute_types: bool = True,
  172. ) -> ValueRanges[Any]:
  173. x = ValueRanges.wrap(x)
  174. if dtype == torch.bool:
  175. if x.is_singleton():
  176. return ValueRanges.wrap(x.lower != 0)
  177. elif x.is_bool:
  178. return x
  179. elif 0 not in x:
  180. return ValueRanges.wrap(sympy.true)
  181. else:
  182. return ValueRanges(sympy.false, sympy.true)
  183. def cast(x: Any, dtype: torch.dtype) -> sympy.Expr:
  184. # dtype is int or float
  185. if dtype.is_floating_point:
  186. return sympy.Float(x)
  187. else:
  188. if x in (int_oo, -int_oo):
  189. return x
  190. try:
  191. return sympy.Integer(x)
  192. except TypeError:
  193. # inf cannot be cast to Integer
  194. return x
  195. if x.is_bool:
  196. if x.is_singleton():
  197. val = 1 if x.lower else 0
  198. return ValueRanges.wrap(cast(val, dtype))
  199. else:
  200. return ValueRanges(cast(0, dtype), cast(1, dtype))
  201. else:
  202. # int to float or float to int
  203. return ValueRanges(cast(x.lower, dtype), cast(x.upper, dtype))
  204. @staticmethod
  205. def square(x: Any) -> ValueRanges[Any]:
  206. return ValueRanges.convex_min_zero_map(x, lambda y: PowByNatural(y, 2))
  207. @staticmethod
  208. def neg(x: Any) -> ValueRanges[Any]:
  209. return ValueRanges.decreasing_map(x, operator.neg)
  210. # TODO: this is slightly inaccurate because truncdiv operates at integer
  211. # precision, but we're going through float truediv which means we can
  212. # potentially lose precision on the bounds
  213. @classmethod
  214. def truncdiv(cls, a: Any, b: Any) -> ValueRanges[Any]:
  215. x = cls.truediv(a, b)
  216. if x == ValueRanges.unknown():
  217. return x
  218. return cls.trunc(x)
  219. @classmethod
  220. def sub(cls, a: Any, b: Any) -> ValueRanges[Any]:
  221. return cls.add(a, cls.neg(b))