dependencies.py 30 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862863864865866
  1. import abc
  2. import dataclasses
  3. import itertools
  4. import logging
  5. import re
  6. from collections.abc import Iterable, Sequence
  7. from typing import Any, Callable, Optional, TypeVar, Union
  8. from typing_extensions import Self
  9. from unittest.mock import patch
  10. import sympy
  11. import torch
  12. from torch._inductor.utils import get_free_symbols
  13. from torch.fx.experimental.symbolic_shapes import free_symbols, free_unbacked_symbols
  14. from torch.utils._ordered_set import OrderedSet
  15. from ..utils._sympy.symbol import make_symbol, SymT
  16. from .codegen.common import index_prevent_reordering
  17. from .ops_handler import DefaultHandler
  18. from .utils import (
  19. get_dtype_size,
  20. reduction_num_outputs,
  21. sympy_index_symbol,
  22. sympy_str,
  23. sympy_subs,
  24. VarRanges,
  25. )
  26. from .virtualized import ReductionType, V
  27. T = TypeVar("T")
  28. log = logging.getLogger(__name__)
  29. is_indirect = re.compile(r"indirect|tmp").search
  30. class Dep(abc.ABC):
  31. name: str
  32. index: sympy.Expr
  33. @abc.abstractmethod
  34. def get_free_symbol_uses(
  35. self, unbacked_only: bool = False
  36. ) -> OrderedSet[sympy.Symbol]:
  37. pass
  38. @abc.abstractmethod
  39. def rename(self, renames: dict[str, str]) -> Self:
  40. pass
  41. @abc.abstractmethod
  42. def get_numel(self) -> sympy.Expr:
  43. pass
  44. @abc.abstractmethod
  45. def numbytes_hint(self) -> int:
  46. pass
  47. @abc.abstractmethod
  48. def has_unbacked_symbols(self) -> bool:
  49. pass
  50. @abc.abstractmethod
  51. def is_contiguous(self) -> bool:
  52. pass
  53. def normalize_with_stride_order(self, prefix: str = "t") -> Self:
  54. return self
@dataclasses.dataclass(frozen=True)
class MemoryDep(Dep):
    """
    Dependency on a buffer accessed through a symbolic index expression.

    ``var_names`` and ``size`` are parallel tuples: ``size[i]`` is the range
    paired with ``var_names[i]`` (see the ``ranges`` property).
    """

    # Name of the buffer being accessed.
    name: str
    # Index expression of the access.
    index: sympy.Expr
    # Loop variables; paired positionally with ``size``.
    var_names: tuple[sympy.Symbol, ...]
    # Sizes paired positionally with ``var_names``.
    size: tuple[sympy.Expr, ...]
    # Optional mode; filled from the ``mode`` argument when a store is recorded.
    mode: Optional[str] = None

    def get_free_symbol_uses(
        self, unbacked_only: bool = False
    ) -> OrderedSet[sympy.Symbol]:
        """
        Return the union of the free symbols found in ``index``, ``size`` and
        ``var_names``; ``unbacked_only`` is forwarded to ``get_free_symbols``.
        """
        return (
            get_free_symbols(self.index, unbacked_only)
            | get_free_symbols(self.size, unbacked_only)
            | get_free_symbols(self.var_names, unbacked_only)
        )

    def __repr__(self) -> str:
        """Render name, index and ranges; the mode is appended only when set."""
        maybe_mode = ""
        if self.mode is not None:
            maybe_mode = f", {self.mode}"
        return f"MemoryDep({self.name!r}, {self.index}, {self.ranges}{maybe_mode})"

    @property
    def num_vars(self) -> int:
        """Number of loop variables (``len(var_names)``)."""
        return len(self.var_names)

    def decide_loop_order_to_match(self, other: "MemoryDep") -> Optional[list[int]]:
        """
        Compute a loop order relating ``self`` to ``other`` from their strides.

        On success returns ``order`` where ``order[i]`` is the position, in
        ``self``'s stride list, of ``other``'s i-th stride.

        Can return None if not able to decide loop orders.
        """
        assert self.num_vars == other.num_vars

        # ignore broadcast for now since broadcast causes extra 0 strides
        # which makes it hard to decide the correct loop orders.
        if self.num_vars != len(self.index.free_symbols):
            return None
        if other.num_vars != len(other.index.free_symbols):
            return None

        # bail out if any size is 0 or 1
        # For size == 0, it's an empty tensor, any strides for that dimension
        # are equivalent. Skip for simplicity and it may not matter that much.
        #
        # For size == 1, it can cause ties between strides of different dimensions.
        # Also when we first create LoopBody in ComputedBuffer.simplify_and_reorder
        # we can use dependencies.index_vars_squeeze, which should already squeeze
        # the size == 1 dimensions.
        if any(s == 0 or s == 1 for s in itertools.chain(self.size, other.size)):
            return None

        # Extract strides for both expression
        self_strides = V.graph.sizevars.stride_hints(self.index, self.var_names)
        other_strides = V.graph.sizevars.stride_hints(other.index, other.var_names)

        # Even if the shape contains no 0/1, some complex index expression may
        # still have duplicate stride values. Here is an example:
        # https://gist.github.com/shunting314/511a7e1ec88aa2e1a8ec85d8445ab129
        # We don't reorder the loop for these cases for now, but in theory
        # we could improve the algorithm to detect the correct loop orders.
        if len(OrderedSet(self_strides)) != len(self_strides) or len(
            OrderedSet(other_strides)
        ) != len(other_strides):
            log.debug(
                "unable to decide loop order. self_dep=%s v.s. other_dep=%s, self_strides=%s v.s. other_strides=%s",
                self,
                other,
                self_strides,
                other_strides,
            )
            return None

        # May happen if self and other are as follows
        # MemoryDep('addmm_6', 393216*d0 + 768*d1 + d2, {d0: 16, d1: 512, d2: 768}, None)
        # MemoryDep('addmm_6', 98304*d0 + d1 + 768*d2, {d0: 64, d1: 768, d2: 128}, None)
        if OrderedSet(self_strides) != OrderedSet(other_strides):
            return None

        # Strides are unique at this point, so each stride identifies one position.
        stride_to_index = {s: i for i, s in enumerate(self_strides)}
        order = [stride_to_index[s] for s in other_strides]

        # The result must be a permutation of all variable positions.
        assert OrderedSet(order) == OrderedSet(range(0, self.num_vars))
        return order

    def get_offset(self) -> sympy.Expr:
        """
        Return the offset by setting every variable to be 0.
        """
        return sympy_subs(self.index, dict.fromkeys(self.var_names, 0))

    def normalize(self) -> "MemoryDep":
        """
        Normalize by merging loops. The difference from normalize_with_stride_order
        is that this method does not reorder loops, while
        normalize_with_stride_order reorders loops based on stride order.
        """
        return MemoryDep(
            self.name,
            *_RecordLoadStoreInner._normalize(self.index, self.ranges),  # type: ignore[arg-type]
            self.mode,
        )

    def normalize_with_stride_order(self, prefix: str = "t") -> "MemoryDep":
        r"""
        Used to decide if two MemoryDep does not equal due to different loop orders.
        More specifically, when dep1 and dep2 are not equal, we can normalize
        both and check if they are equal after that. If yes, then the mismatch is
        caused by different loop orders.

        New loop variables are named with ``prefix``. NOTE: the returned dep is
        built without passing ``mode``, so it carries the default (``None``).
        """
        # import here to avoid circular import
        from torch._inductor import ir

        strides = V.graph.sizevars.stride_hints(self.index, self.var_names)

        # pick a loop order with stride ordered decreasingly
        order = sorted(range(len(strides)), key=strides.__getitem__, reverse=True)
        stride_reorder = ir.same_reorder(order)
        sizes = self.size
        var_names = self.var_names

        new_reordered_sizes = stride_reorder(sizes)
        new_reordered_var_names = stride_reorder(var_names)

        new_simplified_sizes, reindex, _prune = V.graph.sizevars._simplify_loops(
            new_reordered_var_names,
            new_reordered_sizes,
            index_prevent_reordering(
                [self.index], new_reordered_var_names, new_reordered_sizes
            ),
        )

        # now let's create new symbols with the passed in prefix
        # (add_var records each new symbol and its size into var_ranges)
        var_ranges, add_var = var_builder(prefix)
        replacement = dict(
            zip(
                new_reordered_var_names,
                reindex([add_var(x) for x in new_simplified_sizes]),
            )
        )
        new_index = sympy_subs(sympy.expand(self.index), replacement)  # type: ignore[arg-type] # next PR

        out = MemoryDep(
            self.name, new_index, tuple(var_ranges.keys()), tuple(var_ranges.values())
        )  # type: ignore[arg-type]
        return out

    @property
    def ranges(self) -> dict[sympy.Symbol, sympy.Expr]:
        """{c0: 128, c1: 512, ...} -- ``var_names`` zipped with ``size``."""
        return dict(zip(self.var_names, self.size))

    def simplify_with_ranges(self) -> "MemoryDep":
        """
        Return a copy whose index has been passed through
        ``V.graph.sizevars.simplify_with_ranges``; all other fields are kept.
        """
        return MemoryDep(
            name=self.name,
            index=V.graph.sizevars.simplify_with_ranges(self.index, self.ranges),
            var_names=self.var_names,
            size=self.size,
            mode=self.mode,
        )

    def get_numel(self) -> sympy.Expr:
        """
        Return the number of elements covered by this access.

        For an indirect access this is the numel of the whole buffer as
        reported by the graph. Otherwise it is the product of the sizes of
        those variables that actually appear in the index.
        """
        if self.is_indirect():
            numel = V.graph.get_numel(self.name)
        else:
            vars: OrderedSet[sympy.Basic] = OrderedSet(self.index.free_symbols)
            numel = sympy.S.One
            for var, size in zip(self.var_names, self.size):
                # variables absent from the index do not contribute
                if var in vars:
                    numel = numel * size
        return numel  # type: ignore[return-value]

    def rename(self, renames: dict[str, str]) -> "MemoryDep":
        """
        Return a copy with the buffer name mapped through ``renames``, or
        ``self`` unchanged when the name has no entry.
        """
        if self.name in renames:
            return MemoryDep(
                renames[self.name],
                self.index,
                var_names=self.var_names,
                size=self.size,
                mode=self.mode,
            )
        return self

    def numbytes_hint(self) -> int:
        """
        Return size-hint of ``get_numel()`` times the dtype size of the buffer,
        or 0 when a ``NotImplementedError`` is raised along the way.
        """
        try:
            return V.graph.sizevars.size_hint(self.get_numel()) * get_dtype_size(
                V.graph.get_dtype(self.name)
            )
        except NotImplementedError:  # NoneLayout
            return 0

    def has_unbacked_symbols(self) -> bool:
        """Return True when ``get_numel()`` contains any unbacked symbol."""
        return len(free_unbacked_symbols(self.get_numel())) > 0

    def is_contiguous(self) -> bool:
        """
        Return True when the index is a sympy integer, or is exactly one of
        the loop variables; every other index form returns False.
        """
        if isinstance(self.index, sympy.Integer):
            return True
        return isinstance(self.index, sympy.Symbol) and self.index in self.var_names

    def stride1_for_last_dim(self, result_for_complex_expression: bool = True) -> bool:
        """
        Whether the stride for the last dimension is 1.

        Returns ``result_for_complex_expression`` when no term of the index
        settles the question either way.
        """
        # python test/inductor/test_torchinductor_opinfo.py -k test_comprehensive_masked_scatter_cuda_float16
        # will exercise thru this corner case.
        if len(self.var_names) == 0:
            return True

        # Look at the top-level summands of the index (or the index itself).
        terms = self.index.args if isinstance(self.index, sympy.Add) else [self.index]

        last_sym = self.var_names[-1]
        for term in terms:
            # the last variable appears bare, i.e. with coefficient 1
            if term == last_sym:
                return True

            # Having a >1 stride for the last dimension is bad for perf
            # return False.
            if (
                isinstance(term, sympy.Mul)
                and len(term.args) == 2
                and term.args[1] == last_sym
                and isinstance(term.args[0], (int, sympy.Integer))
                and term.args[0] > 1
            ):
                return False

        return result_for_complex_expression

    def is_scalar(self) -> bool:
        """
        Return True for an integer index, or for a single-symbol index that is
        neither a loop variable nor an indirect symbol.
        """
        if isinstance(self.index, sympy.Symbol):
            return self.index not in self.var_names and not self.is_indirect()
        return isinstance(self.index, (int, sympy.Integer))

    def is_indirect(self) -> bool:
        """
        Return True when any free symbol of the index has a name matched by
        the module-level ``is_indirect`` pattern.
        """
        return any(is_indirect(v.name) for v in self.index.free_symbols)  # type: ignore[attr-defined]
  255. @dataclasses.dataclass(frozen=True)
  256. class StarDep(Dep):
  257. name: str
  258. mode: Optional[str] = None
  259. # depends on the entire buffer
  260. @property
  261. def index(self) -> sympy.Expr:
  262. raise NotImplementedError("StarDep does not have an index")
  263. def get_numel(self) -> sympy.Expr:
  264. return V.graph.get_numel(self.name) # type: ignore[return-value]
  265. def rename(self, renames: dict[str, str]) -> "StarDep":
  266. if self.name in renames:
  267. return StarDep(renames[self.name], self.mode)
  268. return self
  269. def get_free_symbol_uses(
  270. self, unbacked_only: bool = False
  271. ) -> OrderedSet[sympy.Symbol]:
  272. return OrderedSet()
  273. def numbytes_hint(self) -> int:
  274. try:
  275. return V.graph.sizevars.size_hint(self.get_numel()) * get_dtype_size(
  276. V.graph.get_dtype(self.name)
  277. )
  278. except NotImplementedError:
  279. return 0 # NoneLayout, MultiOutputLayout, etc
  280. def has_unbacked_symbols(self) -> bool:
  281. return len(free_unbacked_symbols(self.get_numel())) > 0
  282. def is_contiguous(self) -> bool:
  283. return False
  284. def is_scalar(self) -> bool:
  285. return False
  286. def is_indirect(self) -> bool:
  287. return False
  288. # Used for tracking mutation ordering
  289. # if A reads a buffer and B mutates it
  290. # B must be ordered after A
  291. #
  292. # This is useful for a variety of reasons.
  293. # For example, if A's read is never actually used, we can eliminate it.
  294. # Another case is if A's buffer ends up being fused away, we never need to
  295. # materialize that buffer
  296. @dataclasses.dataclass(frozen=True)
  297. class WeakDep(Dep):
  298. # Fake dependency on unused buffer
  299. name: str
  300. # Buffer that is doing the mutation
  301. mutating_buf: str
  302. # WeakDep's are also used to add dependencies to prevent some specific reordering,
  303. # E.g. collectives global ordering.
  304. # But if other pass guarantees proper ordering by its logic,
  305. # This additional "fake" deps will be holding optimizations.
  306. # This flag is used to identify those additional deps.
  307. is_fake: bool = False
  308. def get_free_symbol_uses(
  309. self, unbacked_only: bool = False
  310. ) -> OrderedSet[sympy.Symbol]:
  311. return OrderedSet()
  312. @property
  313. def index(self) -> sympy.Expr:
  314. raise NotImplementedError("WeakDep does not have an index")
  315. def get_numel(self) -> sympy.Expr:
  316. return sympy.S.One
  317. def rename(self, renames: dict[str, str]) -> "WeakDep":
  318. if self.name in renames:
  319. return WeakDep(renames[self.name], self.mutating_buf, self.is_fake)
  320. return self
  321. def numbytes_hint(self) -> int:
  322. return 1 # Purely inserted for ordering, not an actual dep
  323. def has_unbacked_symbols(self) -> bool:
  324. return False
  325. def is_contiguous(self) -> bool:
  326. return False
@dataclasses.dataclass(frozen=True)
class IndexExprDep:
    """
    Immutable record of an index expression with its loop variables and sizes.

    Instances are created by ``_RecordLoadStoreInner.index_expr`` from the
    ``(index, var_names, sizes)`` triple returned by ``canonicalize``.
    """

    # The recorded index expression.
    index: sympy.Expr  # type: ignore[assignment]
    # Loop variables associated with the expression.
    var_names: tuple[sympy.Symbol, ...]
    # Sizes paired positionally with ``var_names``.
    size: tuple[sympy.Expr, ...]
  332. @dataclasses.dataclass
  333. class ReadWrites:
  334. reads: OrderedSet[Dep]
  335. writes: OrderedSet[Dep]
  336. index_exprs: OrderedSet[IndexExprDep]
  337. range_vars: Optional[list[sympy.Expr]] = None
  338. var_ranges: Optional[VarRanges] = None
  339. def rename(self, renames: dict[str, str]) -> "ReadWrites":
  340. return ReadWrites(
  341. OrderedSet(dep.rename(renames) for dep in self.reads),
  342. OrderedSet(dep.rename(renames) for dep in self.writes),
  343. self.index_exprs,
  344. self.range_vars,
  345. self.var_ranges,
  346. )
  347. def with_read(self, dep: Union[Dep, OrderedSet[Dep]]) -> "ReadWrites":
  348. assert isinstance(dep, (WeakDep, StarDep, OrderedSet))
  349. if not isinstance(dep, OrderedSet):
  350. dep = OrderedSet([dep])
  351. return ReadWrites(
  352. OrderedSet.union(self.reads, dep),
  353. self.writes,
  354. self.index_exprs,
  355. self.range_vars,
  356. self.var_ranges,
  357. )
  358. def merge(self, other: "ReadWrites") -> "ReadWrites":
  359. reads = OrderedSet.union(self.reads, other.reads)
  360. writes = OrderedSet.union(self.writes, other.writes)
  361. index_exprs = OrderedSet.union(self.index_exprs, other.index_exprs)
  362. return ReadWrites(reads - writes, writes, index_exprs)
  363. @staticmethod
  364. def merge_list(read_writes: list["ReadWrites"]) -> "ReadWrites":
  365. all_writes = OrderedSet.union(*[rw.writes for rw in read_writes])
  366. all_reads = OrderedSet.union(*[rw.reads for rw in read_writes]) - all_writes
  367. all_index_exprs = OrderedSet.union(*[rw.index_exprs for rw in read_writes])
  368. return ReadWrites(all_reads, all_writes, all_index_exprs)
  369. def remove_reads(self, rem_reads: OrderedSet[Dep]) -> "ReadWrites":
  370. return ReadWrites(
  371. self.reads - rem_reads,
  372. self.writes,
  373. self.index_exprs,
  374. self.range_vars,
  375. self.var_ranges,
  376. )
  377. def reads_and_writes(self) -> Iterable[Dep]:
  378. return itertools.chain(self.reads, self.writes)
  379. def buffer_names(self, ignore_integer_index: bool = True) -> OrderedSet[str]:
  380. """
  381. Integer index is used for load_seed.
  382. """
  383. names: OrderedSet[str] = OrderedSet()
  384. for dep in self.reads_and_writes():
  385. if not isinstance(dep, MemoryDep):
  386. continue
  387. if not ignore_integer_index or not isinstance(
  388. dep.index, (int, sympy.Integer)
  389. ):
  390. names.add(dep.name)
  391. return names
  392. def get_free_symbol_uses(
  393. self, unbacked_only: bool = False
  394. ) -> OrderedSet[sympy.Symbol]:
  395. result: OrderedSet[sympy.Symbol] = OrderedSet()
  396. for dep in self.reads_and_writes():
  397. result |= dep.get_free_symbol_uses(unbacked_only)
  398. return result
class _RecordLoadStoreInner(V.MockHandler):  # type: ignore[name-defined]
    """
    Ops handler that records memory accesses instead of computing anything.

    Loads and bucketize inputs are collected into ``_reads``, stores into
    ``_writes`` and index expressions into ``_index_exprs``. Each recording
    method returns a descriptive string for the op.
    """

    def __init__(self, var_ranges: VarRanges, normalize: bool) -> None:
        """
        Args:
            var_ranges: mapping from loop variable to its size.
            normalize: when True, ``canonicalize`` runs the full
                ``_normalize`` pass; otherwise a lighter cleanup is applied.
        """
        super().__init__()
        self._reads: OrderedSet[Dep] = OrderedSet()
        self._writes: OrderedSet[MemoryDep] = OrderedSet()
        self._index_exprs: OrderedSet[IndexExprDep] = OrderedSet()
        self._var_ranges: VarRanges = var_ranges
        self._should_normalize: bool = normalize

    @staticmethod
    def drop_unused_symbols(
        index: Union[int, sympy.Expr],
        var_names: list[sympy.Expr],
        sizes: list[sympy.Expr],
    ) -> None:
        """
        Reduction has last (reduced) dim in its sizes, but
        downstream users won't.  Normalize this away.

        Mutates ``var_names`` and ``sizes`` in place: trailing entries are
        popped for as long as the last variable does not occur in ``index``.
        """
        if not isinstance(index, sympy.Expr):
            # index can be an int
            return
        free_symbols = index.free_symbols
        while var_names and var_names[-1] not in free_symbols:
            var_names.pop()
            sizes.pop()

    @classmethod
    def _normalize(
        cls, index: sympy.Expr, var_ranges: VarRanges
    ) -> tuple[sympy.Expr, tuple[sympy.Symbol, ...], tuple[sympy.Expr, ...]]:
        """
        Simplify the loops of ``index`` and rename its variables canonically.

        Returns ``(new_index, new_vars, new_sizes)`` where the variables are
        freshly numbered with ``canonicalization_prefix()`` and trailing
        variables that the new index does not use have been dropped.
        """
        # Try to further simplify the indexes even if simplify_loops didn't
        # convert it to the simplest form because of the interference from
        # different indexing formulas.
        index_vars = [*var_ranges.keys()]
        sizes = tuple(var_ranges.values())  # type: ignore[assignment]
        new_sizes, reindex, _prune = V.graph.sizevars._simplify_loops(
            index_vars,
            sizes,
            index_prevent_reordering([index], index_vars, sizes),
        )

        # assign new variables each dimension to deal with numbering mismatches
        # d0, d1, d2 could become d0, d2 -- which won't match d0, d1
        new_vars, add_var = var_builder(canonicalization_prefix())
        # add_var registers each new symbol in new_vars as a side effect
        replacement = dict(zip(index_vars, reindex([add_var(x) for x in new_sizes])))
        index = sympy_subs(sympy.expand(index), replacement)

        # convert to lists so drop_unused_symbols can pop from them
        new_vars = [*new_vars.keys()]
        new_sizes = [*new_sizes]
        cls.drop_unused_symbols(index, new_vars, new_sizes)
        return index, tuple(new_vars), tuple(new_sizes)  # type: ignore[arg-type]

    def canonicalize(
        self, index: sympy.Expr
    ) -> tuple[sympy.Expr, tuple[sympy.Symbol, ...], tuple[sympy.Expr, ...]]:
        """
        Return ``(index, var_names, sizes)`` for ``index``.

        Without normalization the index is returned as is; sizes are
        simplified, variables whose size equals 1 are dropped, and unused
        trailing variables are removed. With normalization the sizes are
        simplified and the work is delegated to ``_normalize``.
        """
        if not self._should_normalize:
            sizes = [V.graph.sizevars.simplify(x) for x in self._var_ranges.values()]
            var_names = [k for k, v in zip(self._var_ranges.keys(), sizes) if v != 1]
            sizes = [v for v in sizes if v != 1]

            self.drop_unused_symbols(index, var_names, sizes)

            return index, tuple(var_names), tuple(sizes)  # type: ignore[return-value, arg-type]
        var_ranges = {
            k: V.graph.sizevars.simplify(v)
            for k, v in self._var_ranges.items()
            # TODO(jansel): explore this further normalization
            # if k in free_symbols
        }
        return self._normalize(index, var_ranges)

    def load(self, name: str, index: sympy.Expr) -> str:
        """Record a read of buffer ``name`` at ``index``."""
        self._reads.add(MemoryDep(name, *self.canonicalize(index)))
        return f"load({name}, {sympy_str(index)})"

    def load_seed(self, name: str, index: int) -> str:
        """Record a seed load; ``index`` must be an ``int`` and is recorded as a sympy integer."""
        assert isinstance(index, int)
        return self.load(name, sympy.Integer(index))

    def store(
        self, name: str, index: sympy.Expr, value: str, mode: Optional[str] = None
    ) -> str:
        """Record a write to buffer ``name`` at ``index`` with the given ``mode``."""
        self._writes.add(MemoryDep(name, *self.canonicalize(index), mode=mode))
        return f"store({name}, {sympy_str(index)}, {value}, {mode})"

    def store_reduction(self, name: str, index: sympy.Expr, value: str) -> str:
        """Record a reduction store as an ordinary store with the default mode."""
        return self.store(name, index, f"store_reduction({value})")

    def index_expr(self, index: sympy.Expr, dtype: Optional[torch.dtype]) -> str:
        """Record ``index`` as an ``IndexExprDep``; ``dtype`` only appears in the returned string."""
        self._index_exprs.add(IndexExprDep(*self.canonicalize(index)))
        return f"index_expr({sympy_str(index)}, {dtype})"

    def bucketize(
        self,
        values: T,
        boundaries: tuple[str, sympy.Expr, sympy.Expr, sympy.Expr],
        boundary_indices: T,
        indexing_dtype: torch.dtype,
        right: bool,
        sorter: Optional[tuple[str, sympy.Expr]] = None,
        sorter_indices: Optional[T] = None,
    ) -> None:
        """Records the names of the buffers that bucketize will read from."""
        # Only the buffer names (first tuple element) are used; each is
        # recorded as a whole-buffer StarDep.
        self._reads.add(StarDep(boundaries[0]))
        if sorter is not None:
            self._reads.add(StarDep(sorter[0]))
  493. class RecordLoadStore(V.KernelFormatterHandler): # type: ignore[name-defined]
  494. def __init__(self, var_ranges: VarRanges, normalize: bool) -> None:
  495. parent_handler = _RecordLoadStoreInner(
  496. var_ranges=var_ranges, normalize=normalize
  497. )
  498. super().__init__(parent_handler=parent_handler)
  499. # TODO: check call sites
  500. def var_builder(prefix: str) -> tuple[VarRanges, Callable[[sympy.Expr], sympy.Symbol]]:
  501. cnt = itertools.count()
  502. var_ranges: VarRanges = {}
  503. def add_var(length: sympy.Expr) -> sympy.Symbol:
  504. v = sympy_index_symbol(f"{prefix}{next(cnt)}")
  505. var_ranges[v] = length
  506. return v
  507. return var_ranges, add_var
  508. def index_vars_no_squeeze(
  509. *argsizes: Sequence[sympy.Expr], prefix: str
  510. ) -> tuple[list[list[sympy.Symbol]], VarRanges]:
  511. var_ranges, add_var = var_builder(prefix)
  512. args: list[list[sympy.Symbol]] = [list(map(add_var, size)) for size in argsizes]
  513. return args, var_ranges
  514. def index_vars_squeeze(
  515. *argsizes: Sequence[sympy.Expr], prefix: str = "d"
  516. ) -> tuple[list[Sequence[sympy.Expr]], VarRanges]:
  517. from .ir import SqueezeView
  518. var_ranges, add_var = var_builder(prefix)
  519. args: list[Sequence[sympy.Expr]] = []
  520. new_sizes: list[Sequence[sympy.Expr]] = []
  521. for size in argsizes:
  522. new_size, reindex = SqueezeView.squeezer(size)
  523. new_sizes.append(new_size)
  524. args.append(reindex(list(map(add_var, new_size))))
  525. return args, var_ranges
  526. def extract_read_writes(
  527. fn: Callable[..., Any],
  528. *argsizes: Sequence[sympy.Expr],
  529. normalize: bool = False,
  530. prefix: str = "d",
  531. hidden_args: Sequence[list[sympy.Expr]] = (),
  532. ) -> ReadWrites:
  533. args, var_ranges = index_vars_squeeze(*argsizes, prefix=prefix)
  534. from .loop_body import LoopBody
  535. if isinstance(fn, LoopBody):
  536. inner = extract_loop_body_with_args(
  537. fn,
  538. [*args, *hidden_args], # type: ignore[list-item]
  539. var_ranges,
  540. normalize,
  541. )
  542. else:
  543. # Slow path tracing the function
  544. rw = RecordLoadStore(var_ranges, normalize=normalize)
  545. with V.set_ops_handler(rw):
  546. fn(*args, *hidden_args)
  547. inner = rw.parent_handler
  548. if normalize:
  549. range_vars = [] # Number of vars could differ due to normalization
  550. else:
  551. range_vars = [*itertools.chain.from_iterable(args)]
  552. return ReadWrites(
  553. OrderedSet(inner._reads),
  554. OrderedSet(inner._writes),
  555. inner._index_exprs,
  556. range_vars,
  557. var_ranges,
  558. )
  559. def extract_loop_body_with_args(
  560. fn: Any,
  561. args: list[list[sympy.Expr]],
  562. var_ranges: VarRanges,
  563. normalize: bool = False,
  564. ) -> _RecordLoadStoreInner:
  565. from .loop_body import MemoryUsageType
  566. # Fast path to avoid tracing when we already have a LoopBody
  567. inner = _RecordLoadStoreInner(var_ranges=var_ranges, normalize=normalize)
  568. name_to_index = fn.indexing_from_args(args)
  569. if fn.indirect_vars:
  570. # mimic the `tmpX` naming tracing gives us
  571. repl = {v: make_symbol(SymT.TMP, i) for i, v in enumerate(fn.indirect_vars)}
  572. name_to_index = {k: sympy_subs(v, repl) for k, v in name_to_index.items()} # type: ignore[arg-type]
  573. for entry in fn.memory_usage[MemoryUsageType.LOAD]:
  574. inner.load(entry.buffer_name, name_to_index[entry.index_name]) # type: ignore[arg-type]
  575. for entry in fn.memory_usage[MemoryUsageType.LOAD_SEED]:
  576. inner.load_seed(entry.buffer_name, int(name_to_index[entry.index_name])) # type: ignore[arg-type]
  577. for entry in fn.memory_usage[MemoryUsageType.STORE]:
  578. inner.store(
  579. entry.buffer_name,
  580. name_to_index[entry.index_name],
  581. None, # type: ignore[arg-type]
  582. entry.mode,
  583. )
  584. for entry in fn.memory_usage[MemoryUsageType.STORE_REDUCTION]:
  585. inner.store_reduction(
  586. entry.buffer_name,
  587. name_to_index[entry.index_name],
  588. None, # type: ignore[arg-type]
  589. )
  590. for entry in fn.memory_usage[MemoryUsageType.INDEX_EXPR]:
  591. inner.index_expr(name_to_index[entry.index_name], None)
  592. for entry in fn.memory_usage[MemoryUsageType.BUCKETIZE]:
  593. # All that matters is that we record the buffer name, so place it in the
  594. # "boundaries" name position to ensure that it's recorded.
  595. inner.bucketize(
  596. None,
  597. (entry.buffer_name, None, None, None),
  598. None,
  599. None, # type: ignore[arg-type]
  600. None, # type: ignore[arg-type]
  601. )
  602. # fn.memory_usage[MemoryUsageType.CHECK_BOUNDS] intentionally skipped
  603. return inner
  604. def extract_input_node_reduction_ranges(
  605. input_node: "torch._inductor.ir.IRNode",
  606. ) -> tuple[Optional[list[sympy.Expr]], Optional[list[sympy.Expr]]]:
  607. """
  608. Returns the size and reduction size of all inputs, if the sizes and reduction_sizes (if exist) are all the same.
  609. It's possible that a node has multiple inputs, some are Reduction nodes and others are Pointwise nodes.
  610. In this case, reduction_sizes of the Reduction nodes need to be the same.
  611. Otherwise returns (None, None).
  612. """
  613. from .ir import ComputedBuffer, ExternKernel, Loops
  614. size: Optional[list[sympy.Expr]]
  615. reduction_size: Optional[list[sympy.Expr]]
  616. if isinstance(input_node.get_defining_op(), ComputedBuffer):
  617. # Input node has already been realized. Return its size and reduction_size.
  618. size = [*input_node.get_size()]
  619. reduction_size = [*input_node.get_reduction_size()]
  620. if len(reduction_size) > 0:
  621. return (size, reduction_size)
  622. else:
  623. return (None, None)
  624. if not isinstance(input_node.data.data, Loops): # type: ignore[attr-defined]
  625. # Other IRNodes do not have reduction_ranges.
  626. return (None, None)
  627. # There is one issue: what if there are views / permutations between the input node and its dependent realized nodes?
  628. # The current method still uses reduction ranges from the dependent realized node, which is not ideal.
  629. # Is there a way to check whether there are permutations in between?
  630. reads = input_node.get_reads()
  631. reduction_size: Optional[list[sympy.Expr]] = None
  632. size: Optional[list[sympy.Expr]] = None
  633. while reduction_size is None and len(reads) > 0:
  634. seen: OrderedSet[str] = OrderedSet()
  635. new_reads: list[Dep] = []
  636. for read in reads:
  637. if not isinstance(read, MemoryDep):
  638. continue
  639. if read.name in seen:
  640. continue
  641. seen.add(read.name)
  642. buffer = V.graph.try_get_buffer(read.name)
  643. if buffer is None:
  644. continue
  645. op = buffer.get_defining_op()
  646. if op is None or isinstance(op, ExternKernel):
  647. continue
  648. if isinstance(op, ComputedBuffer) and len(op.get_reduction_size()) > 0:
  649. if reduction_size is None:
  650. reduction_size = [*op.get_reduction_size()]
  651. size = [*op.get_size()]
  652. elif reduction_size != [*op.get_reduction_size()] or size != [
  653. *op.get_size()
  654. ]:
  655. return (None, None)
  656. else:
  657. new_reads.extend(op.get_reads())
  658. if reads == new_reads:
  659. return (size, reduction_size)
  660. else:
  661. reads = OrderedSet(new_reads)
  662. return (size, reduction_size)
  663. def canonicalization_prefix() -> str:
  664. return "c"
  665. # ops handler which computes all the free symbols for an IR
  666. class FreeSymbolsOpsHandler(DefaultHandler):
  667. symbols: OrderedSet[sympy.Symbol]
  668. def __init__(self, unbacked_only: bool = True) -> None:
  669. self.symbols = OrderedSet()
  670. self.get_symbols = free_unbacked_symbols if unbacked_only else free_symbols
  671. def _default(self, name: str, args: tuple[Any, ...], kwargs: dict[str, Any]) -> Any:
  672. for a in itertools.chain(args, kwargs.values()):
  673. if isinstance(a, (sympy.Expr, sympy.logic.boolalg.Boolean)):
  674. self.symbols |= self.get_symbols(a)
  675. def indirect_indexing(
  676. self,
  677. index_var: Any,
  678. size: Union[int, sympy.Expr],
  679. check: bool = True,
  680. wrap_neg: bool = True,
  681. ) -> sympy.Symbol:
  682. assert not isinstance(index_var, (sympy.Expr, sympy.logic.boolalg.Boolean))
  683. self.symbols |= self.get_symbols(size)
  684. return sympy_index_symbol(f"({str(index_var)})")
  685. def frexp(self, x: Any) -> tuple[None, ...]:
  686. return (None,) * 2
  687. def scan(
  688. self, dtypes: Any, combine_fn: Any, values: Sequence[Any]
  689. ) -> tuple[None, ...]:
  690. return (None,) * len(values)
  691. def sort(
  692. self, dtypes: Any, values: Sequence[Any], stable: Any, descending: Any
  693. ) -> tuple[None, ...]:
  694. return (None,) * len(values)
  695. def reduction(
  696. self,
  697. dtype: torch.dtype,
  698. src_dtype: torch.dtype,
  699. reduction_type: ReductionType,
  700. value: Union[None, tuple[None, ...]],
  701. ) -> Union[None, tuple[None, ...]]:
  702. num_values = reduction_num_outputs(reduction_type)
  703. return (None,) * num_values if num_values > 1 else None
  704. def masked(self, mask: Any, body: Callable[..., Any], other: Any) -> None:
  705. assert callable(body), "masked body must always be callable."
  706. # The body can make additional calls, for e.g. ops.indirect_indexing
  707. body()
  708. def extract_free_symbols(
  709. fn: Callable[..., Any],
  710. index: Sequence[sympy.Expr],
  711. rindex: Optional[Sequence[sympy.Expr]] = None,
  712. unbacked_only: bool = True,
  713. ) -> OrderedSet[sympy.Symbol]:
  714. from .ir import FlexibleLayout
  715. args = [index, rindex] if rindex is not None else [index]
  716. handler = FreeSymbolsOpsHandler(unbacked_only)
  717. # NB: I cargo culted the allow_indexing patch here, I don't understand why
  718. # people do this all over
  719. with (
  720. V.set_ops_handler(handler),
  721. patch.object(FlexibleLayout, "allow_indexing", True),
  722. ):
  723. fn(*args)
  724. return handler.symbols