_triton_ops.py 84 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156115711581159116011611162116311641165116611671168116911701171117211731174117511761177117811791180118111821183118411851186118711881189119011911192119311941195119611971198119912001201120212031204120512061207120812091210121112121213121412151216121712181219122012211222122312241225122612271228122912301231123212331234123512361237123812391240124112421243124412451246124712481249125012511252125312541255125612571258125912601261126212631264126512661267126812691270127112721273127412751276127
71278127912801281128212831284128512861287128812891290129112921293129412951296129712981299130013011302130313041305130613071308130913101311131213131314131513161317131813191320132113221323132413251326132713281329133013311332133313341335133613371338133913401341134213431344134513461347134813491350135113521353135413551356135713581359136013611362136313641365136613671368136913701371137213731374137513761377137813791380138113821383138413851386138713881389139013911392139313941395139613971398139914001401140214031404140514061407140814091410141114121413141414151416141714181419142014211422142314241425142614271428142914301431143214331434143514361437143814391440144114421443144414451446144714481449145014511452145314541455145614571458145914601461146214631464146514661467146814691470147114721473147414751476147714781479148014811482148314841485148614871488148914901491149214931494149514961497149814991500150115021503150415051506150715081509151015111512151315141515151615171518151915201521152215231524152515261527152815291530153115321533153415351536153715381539154015411542154315441545154615471548154915501551155215531554155515561557155815591560156115621563156415651566156715681569157015711572157315741575157615771578157915801581158215831584158515861587158815891590159115921593159415951596159715981599160016011602160316041605160616071608160916101611161216131614161516161617161816191620162116221623162416251626162716281629163016311632163316341635163616371638163916401641164216431644164516461647164816491650165116521653165416551656165716581659166016611662166316641665166616671668166916701671167216731674167516761677167816791680168116821683168416851686168716881689169016911692169316941695169616971698169917001701170217031704170517061707170817091710171117121713171417151716171717181719172017211722172317241725172617271728172917301731173217331734173517361737173817391740174117421743174417451746174717481749175017511752175317541755175617571758175917601761176217631764176517661767176817691770177117721773177417751776177
71778177917801781178217831784178517861787178817891790179117921793179417951796179717981799180018011802180318041805180618071808180918101811181218131814181518161817181818191820182118221823182418251826182718281829183018311832183318341835183618371838183918401841184218431844184518461847184818491850185118521853185418551856185718581859186018611862186318641865186618671868186918701871187218731874187518761877187818791880188118821883188418851886188718881889189018911892189318941895189618971898189919001901190219031904190519061907190819091910191119121913191419151916191719181919192019211922192319241925192619271928192919301931193219331934193519361937193819391940194119421943194419451946194719481949195019511952195319541955195619571958195919601961196219631964196519661967196819691970197119721973197419751976197719781979198019811982198319841985198619871988198919901991199219931994199519961997199819992000200120022003200420052006200720082009201020112012201320142015201620172018201920202021202220232024202520262027202820292030203120322033203420352036203720382039204020412042204320442045204620472048204920502051205220532054205520562057205820592060206120622063206420652066206720682069207020712072207320742075207620772078207920802081208220832084208520862087208820892090209120922093209420952096209720982099210021012102210321042105210621072108210921102111211221132114211521162117211821192120212121222123212421252126212721282129213021312132213321342135213621372138213921402141214221432144214521462147214821492150215121522153215421552156215721582159216021612162216321642165216621672168216921702171217221732174217521762177217821792180218121822183218421852186218721882189219021912192219321942195219621972198219922002201220222032204220522062207220822092210221122122213221422152216221722182219222022212222222322242225222622272228222922302231223222332234223522362237223822392240224122422243224422452246224722482249225022512252225322542255225622572258225922602261226222632264226522662267226822692270227122722273227422752276227
7227822792280228122822283228422852286228722882289229022912292229322942295229622972298229923002301230223032304230523062307230823092310231123122313231423152316231723182319232023212322232323242325232623272328232923302331233223332334233523362337233823392340234123422343234423452346234723482349235023512352235323542355235623572358235923602361236223632364236523662367236823692370237123722373237423752376237723782379238023812382238323842385238623872388238923902391239223932394239523962397239823992400240124022403240424052406240724082409241024112412241324142415241624172418241924202421242224232424242524262427242824292430243124322433243424352436243724382439244024412442244324442445244624472448244924502451245224532454245524562457245824592460246124622463246424652466246724682469247024712472247324742475247624772478247924802481248224832484248524862487248824892490249124922493249424952496249724982499250025012502250325042505250625072508250925102511251225132514251525162517251825192520252125222523252425252526252725282529
  1. # mypy: allow-untyped-decorators
  2. # mypy: allow-untyped-defs
  3. import math
  4. import os
  5. import weakref
  6. from functools import lru_cache
  7. from typing import Optional
  8. import torch
  9. from torch._dynamo.utils import warn_once
  10. from torch.utils._triton import has_triton
  11. from ._triton_ops_meta import get_meta
  12. TORCH_SPARSE_BSR_SCATTER_MM_LRU_CACHE_SIZE = int(
  13. os.getenv("TORCH_SPARSE_BSR_SCATTER_MM_LRU_CACHE_SIZE", 2)
  14. )
  15. def check(cond, msg):
  16. if not cond:
  17. raise ValueError(msg)
  18. def check_bsr_layout(f_name, t):
  19. check(
  20. t.layout == torch.sparse_bsr,
  21. f"{f_name}(): only BSR sparse format is supported for the sparse argument.",
  22. )
  23. def check_device(f_name, t, device):
  24. check(
  25. t.device == device and t.device.type == "cuda",
  26. f"{f_name}(): all inputs are expected to be on the same GPU device.",
  27. )
  28. def check_mm_compatible_shapes(f_name, lhs, rhs):
  29. check(
  30. lhs.dim() >= 2 and rhs.dim() >= 2,
  31. f"{f_name}(): all inputs involved in the matrix product are expected to be at least 2D, "
  32. f"but got lhs.dim() == {lhs.dim()} and rhs.dim() == {rhs.dim()}.",
  33. )
  34. _m, kl = lhs.shape[-2:]
  35. kr, _n = rhs.shape[-2:]
  36. check(
  37. kl == kr,
  38. f"{f_name}(): arguments' sizes involved in the matrix product are not compatible for matrix multiplication, "
  39. f"got lhs.shape[-1] == {kl} which is not equal to rhs.shape[-2] == {kr}.",
  40. )
  41. def check_dtype(f_name, t, dtype, *additional_dtypes):
  42. check(
  43. t.dtype == dtype
  44. and t.dtype
  45. in ((torch.half, torch.bfloat16, torch.float) + tuple(*additional_dtypes)),
  46. f"{f_name}(): all inputs are expected to be of the same dtype "
  47. f"and one of (half, bfloat16, float32) or {additional_dtypes}, "
  48. f"but got dtype == {t.dtype}.",
  49. )
  50. def check_blocksize(f_name, blocksize):
  51. assert len(blocksize) == 2
  52. def is_power_of_two(v):
  53. return not (v & (v - 1))
  54. def is_compatible_blocksize(b):
  55. res = True
  56. for blocksize in b:
  57. # Triton loads only blocks which are at least 16 and powers of 2.
  58. res = (blocksize >= 16 and is_power_of_two(blocksize)) and res
  59. return res
  60. check(
  61. is_compatible_blocksize(blocksize),
  62. f"{f_name}(): sparse inputs' blocksize ({blocksize[0]}, {blocksize[1]}) "
  63. "should be at least 16 and a power of 2 in each dimension.",
  64. )
  65. def make_triton_contiguous(t):
  66. """Return input as a triton-contiguous tensor.
  67. A triton-contiguous tensor is defined as a tensor that has strides
  68. with minimal value smaller than or equal to 1.
  69. While triton kernels support triton-non-contiguous tensors (all
  70. strides being greater than 1) arguments, a considerable slow-down
  71. occurs because tensor data is copied element-wise rather than
  72. chunk-wise. Zero strides is assumed to not have this defect.
  73. """
  74. if min(t.stride()) > 1:
  75. # TODO: investigate if contiguity along other axes than the
  76. # last one can be beneficial for performance
  77. return t.contiguous()
  78. else:
  79. return t
  80. def broadcast_batch_dims(f_name, *tensors):
  81. try:
  82. return torch.broadcast_shapes(*(t.shape[:-2] for t in tensors))
  83. except Exception:
  84. check(False, f"{f_name}(): inputs' batch dimensions are not broadcastable!")
  85. def slicer(dim, slice_range, *tensors):
  86. for t in tensors:
  87. slices = [slice(None)] * t.dim()
  88. slices[dim] = slice_range
  89. yield t[slices]
  90. def multidim_slicer(dims, slices, *tensors):
  91. for t in tensors:
  92. s = [slice(None)] * t.dim()
  93. for d, d_slice in zip(dims, slices):
  94. if d is not None:
  95. s[d] = d_slice
  96. yield t[tuple(s)]
  97. def ptr_stride_extractor(*tensors):
  98. for t in tensors:
  99. yield t
  100. yield from t.stride()
  101. def grid_partitioner(full_grid, grid_blocks, tensor_dims_map):
  102. assert 0 <= len(full_grid) <= 3
  103. assert 0 <= len(grid_blocks) <= 3
  104. import itertools
  105. def generate_grid_points():
  106. for fg, mg in zip(full_grid, grid_blocks):
  107. yield range(0, fg, mg)
  108. def generate_sliced_tensors(slices):
  109. for t, t_dims in tensor_dims_map.items():
  110. yield next(multidim_slicer(t_dims, slices, t))
  111. for grid_point in itertools.product(*generate_grid_points()):
  112. grid = [
  113. min(fg - gp, mg) for fg, gp, mg in zip(full_grid, grid_point, grid_blocks)
  114. ]
  115. slices = [slice(gp, gp + g) for gp, g in zip(grid_point, grid)]
  116. # grid_points are iterated in a "contiguous" order, i.e.
  117. # left dimensions traversed slower than right dimensions.
  118. # This order is reversed for CUDA grids.
  119. yield grid[::-1], *generate_sliced_tensors(slices)
  120. def launch_kernel(kernel, tensor_dims_map, full_grid, grid_blocks=None):
  121. # cuda_max_grid = (2 ** 31 - 1, 2 ** 16 - 1, 2 ** 16 - 1)
  122. cuda_max_grid = (2147483647, 65535, 65535)[::-1]
  123. if grid_blocks is None:
  124. grid_blocks = cuda_max_grid
  125. else:
  126. def valid_grid_dim(g, mg):
  127. if g is None:
  128. return mg
  129. else:
  130. # grid must be at least 1 and no greater than mg
  131. return max(1, min(g, mg))
  132. grid_blocks = tuple(
  133. valid_grid_dim(g, mg) for g, mg in zip(grid_blocks, cuda_max_grid)
  134. ) # type: ignore[assignment]
  135. for grid, *sliced_tensors in grid_partitioner(
  136. full_grid, grid_blocks, tensor_dims_map
  137. ):
  138. kernel(grid, *sliced_tensors)
  139. def prepare_inputs(bsr, *dense_tensors):
  140. # Introduce fake batch dimension if not present for convenience.
  141. crow_indices = bsr.crow_indices().unsqueeze(0)
  142. col_indices = bsr.col_indices().unsqueeze(0)
  143. values = make_triton_contiguous(bsr.values().unsqueeze(0))
  144. tensors = [make_triton_contiguous(t.unsqueeze(0)) for t in dense_tensors]
  145. # Compute broadcasted batch dimension
  146. batch_dims_broadcasted = torch.broadcast_shapes(
  147. values.shape[:-3], *(t.shape[:-2] for t in tensors)
  148. )
  149. # Broadcast batch dimensions and squash.
  150. # The result can be either a view or a copy.
  151. def batch_broadcast_and_squash(t, batch_dims, invariant_dims):
  152. return t.broadcast_to(batch_dims + invariant_dims).flatten(
  153. 0, len(batch_dims) - 1
  154. )
  155. crow_indices = batch_broadcast_and_squash(
  156. crow_indices, batch_dims_broadcasted, (-1,)
  157. )
  158. col_indices = batch_broadcast_and_squash(col_indices, batch_dims_broadcasted, (-1,))
  159. values = batch_broadcast_and_squash(
  160. values, batch_dims_broadcasted, values.shape[-3:]
  161. )
  162. tensors = [
  163. batch_broadcast_and_squash(t, batch_dims_broadcasted, t.shape[-2:])
  164. for t in tensors
  165. ]
  166. return crow_indices, col_indices, values, *tensors
  167. def broadcast_batch_dims_bsr(f_name, bsr, *tensors):
  168. batch_shape = broadcast_batch_dims(f_name, bsr, *tensors)
  169. crow_indices = bsr.crow_indices().broadcast_to(batch_shape + (-1,))
  170. col_indices = bsr.col_indices().broadcast_to(batch_shape + (-1,))
  171. values = bsr.values().broadcast_to(batch_shape + bsr.values().shape[-3:])
  172. size = batch_shape + bsr.shape[-2:]
  173. return torch.sparse_compressed_tensor(
  174. crow_indices, col_indices, values, size=size, layout=bsr.layout
  175. )
  176. # NOTE: this function will ALWAYS create a view
  177. def tile_to_blocksize(t, blocksize):
  178. *rest, m, n = t.shape
  179. new_shape = rest + [
  180. m // blocksize[0],
  181. blocksize[0],
  182. n // blocksize[1],
  183. blocksize[1],
  184. ]
  185. # using .view instead of .reshape to ensure that the result is
  186. # indeed a view:
  187. return t.view(new_shape).transpose(-3, -2)
  188. def as1Dbatch(tensor):
  189. """Return tensor as 3D tensor by either prepending new dimensions to
  190. the tensor shape (when ``tensor.ndim < 3``), or by collapsing
  191. starting dimensions into the first dimension (when ``tensor.ndim >
  192. 3``).
  193. """
  194. while tensor.ndim < 3:
  195. tensor = tensor.unsqueeze(0)
  196. if tensor.ndim > 3:
  197. tensor = tensor.flatten(0, tensor.ndim - 3)
  198. assert tensor.ndim == 3, tensor.shape
  199. return tensor
def scatter_mm(blocks, others, indices_data, *, accumulators=None):
    r"""Scattered matrix multiplication of tensors.

    A scattered matrix multiplication is defined as a series of matrix
    multiplications applied to input tensors according to the input
    and output mappings specified by indices data.

    The following indices data formats are supported for defining a
    scattered matrix multiplication operation (:attr:`indices_data[0]`
    holds the name of the indices data format as specified below):

    - ``"scatter_mm"`` - matrix multiplications scattered in batches
      of tensors.

      If :attr:`blocks` is a :math:`(P \times M \times K)` tensor,
      :attr:`others` is a :math:`(Q \times K \times N)` tensor,
      :attr:`accumulators` is a :math:`(R \times M \times N)` tensor,
      and ``c_offsets, pq = indices_data[1:]``, where ``c_offsets`` is a
      1-D tensor of ``R + 1`` offsets into ``pq`` and every row of ``pq``
      holds a ``(p, q)`` index pair, then the operation is equivalent to
      the following code::

        c_offsets, pq = indices_data[1:]
        for r in range(len(c_offsets) - 1):
            for g in range(c_offsets[r], c_offsets[r + 1]):
                p, q = pq[g]
                accumulators[r] += blocks[p] @ others[q]

    - ``"bsr_strided_mm"`` - matrix multiplications scattered in
      batches of tensors and a tensor.

      If :attr:`blocks` is a :math:`(P \times Ms \times Ks)` tensor,
      :attr:`others` is a :math:`(* \times K \times N)` tensor,
      :attr:`accumulators` is a :math:`(* \times M \times N)` tensor, then
      the operation is equivalent to the following code::

        c_indices, r_offsets, p_offsets, q_offsets, meta = indices_data[1:]
        for b in range(nbatches):
            for i, r in enumerate(r_offsets):
                r0, r1 = divmod(r, N)
                acc = accumulators[b, r0 : r0 + Ms, r1 : r1 + Ns]
                for g in range(c_indices[i], c_indices[i + 1]):
                    p = p_offsets[g]
                    q0, q1 = divmod(q_offsets[g], N)
                    acc += blocks[p] @ others[b, q0 : q0 + Ks, q1 : q1 + Ns]

      where ``Ns = N // meta['SPLIT_N']``, and ``M`` and ``K`` are
      integer multiples of ``Ms`` and ``Ks``, respectively. The
      pure-PyTorch fallback path zeroes :attr:`accumulators` before
      accumulating.

    - ``"bsr_strided_mm_compressed"`` - matrix multiplications
      scattered in batches of tensors and a tensor. A memory and
      processor efficient version of ``"bsr_strided_mm"`` format. If
      :attr:`blocks` is a :math:`(P \times Ms \times Ks)` tensor,
      :attr:`others` is a :math:`(* \times K \times N)` tensor,
      :attr:`accumulators` is a :math:`(* \times M \times N)` tensor,
      then the operation is equivalent to the following code::

        c_indices, r_offsets, q_offsets, meta = indices_data[1:]
        for b in range(nbatches):
            for r in r_offsets:
                m = (r // N) // Ms
                n = (r % N) // Ns
                r0, r1 = divmod(r, N)
                c0, c1 = c_indices[m], c_indices[m + 1]
                acc = accumulators[b, r0 : r0 + Ms, r1 : r1 + Ns]
                for i, p in enumerate(range(c0, c1)):
                    q = q_offsets[n * c1 + (SPLIT_N - n) * c0 + i]
                    q0, q1 = divmod(q, N)
                    acc += blocks[p] @ others[b, q0 : q0 + Ks, q1 : q1 + Ns]

      where ``Ns = N // meta['SPLIT_N']``, and ``M`` and ``K`` are
      integer multiples of ``Ms`` and ``Ks``, respectively.

      Notice that the order of ``r_offsets`` items can be arbitrary;
      this property enables defining swizzle operators via
      rearrangements of ``r_offsets`` items.

    Auxiliary functions are provided for pre-computing
    :attr:`indices_data`. For example,
    :func:`bsr_scatter_mm_indices_data` is used to define indices data
    for matrix multiplication of BSR and strided tensors.

    Parameters
    ----------
    blocks (Tensor): a 3-D tensor of first matrices to be multiplied.

    others (Tensor): a tensor of second matrices to be multiplied. If
      ``indices_data[0]=="scatter_mm"``, the tensor is a 1-D batch
      tensor of second input matrices to be multiplied. Otherwise, the
      second input matrices are slices of the :attr:`others` tensor.

    indices_data (tuple): a format data that defines the inputs and
      outputs of scattered matrix multiplications.

    Keyword arguments
    -----------------

    accumulators (Tensor, optional): a tensor of matrix product
      accumulators. If ``indices_data[0]=="scatter_mm"``, the tensor
      is a 1-D batch tensor of output matrices. Otherwise, output
      matrices are slices of the :attr:`accumulators` tensor. When not
      given, a zero-initialized tensor is allocated.

    Returns
    -------
    Tensor: the accumulators. For ``"scatter_mm"`` this is the 3-D
      accumulators tensor; for the BSR formats it has the shape of the
      given :attr:`accumulators`, or ``(*others.shape[:-2], M, N)`` when
      allocated here.

    Raises
    ------
    NotImplementedError: if ``indices_data[0]`` is not one of the
      supported formats. Shape mismatches are reported via ``assert``.
    """
    indices_format = indices_data[0]

    assert blocks.ndim == 3
    _P, Ms, Ks = blocks.shape

    if indices_format == "scatter_mm":
        c_offsets, pq = indices_data[1:]

        assert others.ndim == 3
        _Q, Ks_, Ns = others.shape
        assert Ks == Ks_

        if accumulators is None:
            R = c_offsets.shape[0] - 1
            accumulators = torch.zeros(
                (R, Ms, Ns), dtype=blocks.dtype, device=blocks.device
            )
        else:
            R, Ms_, Ns_ = accumulators.shape
            assert Ms_ == Ms
            assert Ns_ == Ns

        # Fall back to a pure-PyTorch loop when any block size is not a
        # multiple of 16 or when _scatter_mm2 is None (presumably when the
        # Triton kernel is unavailable).
        if Ms % 16 or Ks % 16 or Ns % 16 or _scatter_mm2 is None:
            for r in range(c_offsets.shape[0] - 1):
                g0 = c_offsets[r]
                g1 = c_offsets[r + 1]
                for g in range(g0, g1):
                    p, q = pq[g]
                    accumulators[r] += blocks[p] @ others[q]
        else:
            _scatter_mm2(blocks, others, c_offsets, pq, accumulators)
        return accumulators

    elif indices_format == "bsr_strided_mm":
        # Remember the caller's shape; all batch dims are collapsed into one.
        others_shape = others.shape
        others = as1Dbatch(others)

        B, K, N = others.shape
        assert K % Ks == 0

        c_indices, r_offsets, p_offsets, q_offsets, meta = indices_data[1:]
        SPLIT_N = meta["SPLIT_N"]

        if accumulators is None:
            # Output rows span up to the largest row offset plus one block.
            M = Ms + (r_offsets.max().item() + 1) // N
            accumulators = torch.zeros(
                (*others_shape[:-2], M, N), dtype=blocks.dtype, device=blocks.device
            )
        else:
            M, N_ = accumulators.shape[-2:]
            assert N_ == N

        accumulators_shape = accumulators.shape
        # NOTE(review): as1Dbatch flattens leading dims; for a non-contiguous
        # >3-D accumulators tensor this may produce a copy, in which case the
        # caller's tensor is not updated in place -- confirm callers pass
        # contiguous accumulators.
        accumulators = as1Dbatch(accumulators)

        Ns = N // SPLIT_N

        if Ms % 16 or Ks % 16 or Ns % 16 or _scatter_mm6 is None:
            # NOTE(review): this fallback zeroes user-provided accumulators,
            # whereas the "bsr_strided_mm_compressed" fallback below adds into
            # them; confirm which behavior matches the _scatter_mm6 kernel.
            accumulators.zero_()
            for b in range(B):
                for r in range(r_offsets.shape[0]):
                    r_ = r_offsets[r].item()
                    g0 = c_indices[r].item()
                    g1 = c_indices[r + 1].item()
                    # Flat offset -> (row, column) of the output tile.
                    r0, r1 = divmod(r_, N)
                    acc = accumulators[b, r0 : r0 + Ms, r1 : r1 + Ns]
                    for g in range(g0, g1):
                        p, q = p_offsets[g], q_offsets[g]
                        q0, q1 = divmod(q.item(), N)
                        acc += blocks[p] @ others[b, q0 : q0 + Ks, q1 : q1 + Ns]
        else:
            _scatter_mm6(
                blocks,
                others,
                c_indices,
                r_offsets,
                p_offsets,
                q_offsets,
                meta,
                accumulators,
            )
        return accumulators.view(accumulators_shape)

    elif indices_format == "bsr_strided_mm_compressed":
        # Remember the caller's shape; all batch dims are collapsed into one.
        others_shape = others.shape
        others = as1Dbatch(others)

        B, K, N = others.shape
        assert K % Ks == 0

        c_indices, r_offsets, q_offsets, meta = indices_data[1:]
        SPLIT_N = meta["SPLIT_N"]

        if accumulators is None:
            # Output rows span up to the largest row offset plus one block.
            M = Ms + (r_offsets.max().item() + 1) // N
            accumulators = torch.zeros(
                (*others_shape[:-2], M, N), dtype=blocks.dtype, device=blocks.device
            )
        else:
            M, N_ = accumulators.shape[-2:]
            assert N_ == N

        accumulators_shape = accumulators.shape
        accumulators = as1Dbatch(accumulators)

        Ns = N // SPLIT_N

        if Ms % 16 or Ks % 16 or Ns % 16 or _scatter_mm6 is None:
            for b in range(B):
                for j in range(len(r_offsets)):
                    r0, r1 = divmod(r_offsets[j].item(), N)
                    # Block-row index and SPLIT_N column-chunk index.
                    m = r0 // Ms
                    n = r1 // Ns
                    c0 = c_indices[m].item()
                    c1 = c_indices[m + 1].item()
                    acc = accumulators[b, r0 : r0 + Ms, r1 : r1 + Ns]
                    for i, p in enumerate(range(c0, c1)):
                        q = q_offsets[n * c1 + (SPLIT_N - n) * c0 + i].item()
                        q0, q1 = divmod(q, N)
                        acc += blocks[p] @ others[b, q0 : q0 + Ks, q1 : q1 + Ns]
        else:
            # The compressed format carries no p_offsets; an empty tensor is
            # passed in their place (presumably the kernel then derives p
            # from c_indices -- confirm against _scatter_mm6).
            p_offsets = torch.empty(
                (0,), dtype=q_offsets.dtype, device=q_offsets.device
            )
            _scatter_mm6(
                blocks,
                others,
                c_indices,
                r_offsets,
                p_offsets,
                q_offsets,
                meta,
                accumulators,
            )
        return accumulators.view(accumulators_shape)

    else:
        raise NotImplementedError(indices_format)
  400. def scatter_mm_meta(
  401. M,
  402. K,
  403. N,
  404. Ms,
  405. Ks,
  406. GROUP_SIZE=None,
  407. TILE_M=None,
  408. TILE_N=None,
  409. SPLIT_N=None,
  410. num_warps=None,
  411. num_stages=None,
  412. **extra,
  413. ):
  414. if {TILE_M, TILE_N, SPLIT_N, num_warps, num_stages, GROUP_SIZE} == {None}:
  415. device_name = torch.cuda.get_device_name()
  416. meta = get_meta(
  417. "scatter_mm",
  418. (M, K, N, Ms, Ks),
  419. device_name,
  420. version=(0, torch.float16, 0.5),
  421. )
  422. if meta is not None:
  423. meta.update(**extra)
  424. return meta
  425. # The following parameters are optimized for the performance
  426. # equilibrium points of bsr-dense and dense-dense matrix
  427. # multiplications when using GPU card NVIDIA GeForce RTX 2060
  428. # SUPER. For points far from the performance equilibrium
  429. # points as well as for other GPU cards, the optimal
  430. # parameters are likely different from what specified below.
  431. if (M, K, N) == (256,) * 3:
  432. if (Ms, Ks) == (16, 16):
  433. SPLIT_N = 1
  434. TILE_M = 16
  435. TILE_N = 16
  436. GROUP_SIZE = 4
  437. num_stages = 1
  438. num_warps = 4 # noqa: E225,E231,E702
  439. elif (Ms, Ks) == (32, 32):
  440. SPLIT_N = 2
  441. TILE_M = 32
  442. TILE_N = 16
  443. GROUP_SIZE = 4
  444. num_stages = 1
  445. num_warps = 4 # noqa: E225,E231,E702
  446. elif (Ms, Ks) == (64, 64):
  447. SPLIT_N = 1
  448. TILE_M = 32
  449. TILE_N = 32
  450. GROUP_SIZE = 4
  451. num_stages = 1
  452. num_warps = 4 # noqa: E225,E231,E702
  453. elif (Ms, Ks) == (128, 128):
  454. SPLIT_N = 1
  455. TILE_M = 32
  456. TILE_N = 32
  457. GROUP_SIZE = 2
  458. num_stages = 1
  459. num_warps = 4 # noqa: E225,E231,E702
  460. elif (M, K, N) == (512,) * 3:
  461. if (Ms, Ks) == (16, 16):
  462. SPLIT_N = 8
  463. TILE_M = 16
  464. TILE_N = 64
  465. GROUP_SIZE = 2
  466. num_stages = 1
  467. num_warps = 2 # noqa: E225,E231,E702
  468. elif (Ms, Ks) == (32, 32):
  469. SPLIT_N = 8
  470. TILE_M = 32
  471. TILE_N = 64
  472. GROUP_SIZE = 4
  473. num_stages = 1
  474. num_warps = 2 # noqa: E225,E231,E702
  475. elif (Ms, Ks) == (64, 64):
  476. SPLIT_N = 4
  477. TILE_M = 32
  478. TILE_N = 128
  479. GROUP_SIZE = 4
  480. num_stages = 1
  481. num_warps = 4 # noqa: E225,E231,E702
  482. elif (Ms, Ks) == (128, 128):
  483. SPLIT_N = 8
  484. TILE_M = 64
  485. TILE_N = 64
  486. GROUP_SIZE = 4
  487. num_stages = 1
  488. num_warps = 4 # noqa: E225,E231,E702
  489. elif (M, K, N) == (1024,) * 3:
  490. if (Ms, Ks) == (16, 16):
  491. SPLIT_N = 4
  492. TILE_M = 16
  493. TILE_N = 128
  494. GROUP_SIZE = 2
  495. num_stages = 1
  496. num_warps = 1 # noqa: E225,E231,E702
  497. elif (Ms, Ks) == (32, 32):
  498. SPLIT_N = 8
  499. TILE_M = 32
  500. TILE_N = 64
  501. GROUP_SIZE = 2
  502. num_stages = 1
  503. num_warps = 1 # noqa: E225,E231,E702
  504. elif (Ms, Ks) == (64, 64):
  505. SPLIT_N = 16
  506. TILE_M = 64
  507. TILE_N = 64
  508. GROUP_SIZE = 4
  509. num_stages = 1
  510. num_warps = 2 # noqa: E225,E231,E702
  511. elif (Ms, Ks) == (128, 128):
  512. SPLIT_N = 16
  513. TILE_M = 64
  514. TILE_N = 64
  515. GROUP_SIZE = 4
  516. num_stages = 1
  517. num_warps = 4 # noqa: E225,E231,E702
  518. elif (Ms, Ks) == (256, 256):
  519. SPLIT_N = 16
  520. TILE_M = 64
  521. TILE_N = 64
  522. GROUP_SIZE = 2
  523. num_stages = 1
  524. num_warps = 4 # noqa: E225,E231,E702
  525. elif (M, K, N) == (2048,) * 3:
  526. if (Ms, Ks) == (16, 16):
  527. SPLIT_N = 4
  528. TILE_M = 16
  529. TILE_N = 128
  530. GROUP_SIZE = 8
  531. num_stages = 1
  532. num_warps = 1 # noqa: E225,E231,E702
  533. elif (Ms, Ks) == (32, 32):
  534. SPLIT_N = 4
  535. TILE_M = 32
  536. TILE_N = 64
  537. GROUP_SIZE = 4
  538. num_stages = 1
  539. num_warps = 1 # noqa: E225,E231,E702
  540. elif (Ms, Ks) == (64, 64):
  541. SPLIT_N = 4
  542. TILE_M = 64
  543. TILE_N = 128
  544. GROUP_SIZE = 4
  545. num_stages = 1
  546. num_warps = 4 # noqa: E225,E231,E702
  547. elif (Ms, Ks) == (128, 128):
  548. SPLIT_N = 8
  549. TILE_M = 64
  550. TILE_N = 64
  551. GROUP_SIZE = 4
  552. num_stages = 1
  553. num_warps = 4 # noqa: E225,E231,E702
  554. elif (Ms, Ks) == (256, 256):
  555. SPLIT_N = 4
  556. TILE_M = 64
  557. TILE_N = 64
  558. GROUP_SIZE = 2
  559. num_stages = 1
  560. num_warps = 4 # noqa: E225,E231,E702
  561. elif (M, K, N) == (4096,) * 3:
  562. if (Ms, Ks) == (16, 16):
  563. SPLIT_N = 2
  564. TILE_M = 16
  565. TILE_N = 256
  566. GROUP_SIZE = 2
  567. num_stages = 1
  568. num_warps = 2 # noqa: E225,E231,E702
  569. elif (Ms, Ks) == (32, 32):
  570. SPLIT_N = 2
  571. TILE_M = 32
  572. TILE_N = 64
  573. GROUP_SIZE = 2
  574. num_stages = 1
  575. num_warps = 1 # noqa: E225,E231,E702
  576. elif (Ms, Ks) == (64, 64):
  577. SPLIT_N = 2
  578. TILE_M = 64
  579. TILE_N = 128
  580. GROUP_SIZE = 2
  581. num_stages = 1
  582. num_warps = 4 # noqa: E225,E231,E702
  583. if SPLIT_N is None:
  584. # Assume NVIDIA GeForce RTX 2060 SUPER:
  585. # With the probality of 92% (99.9% when N > 512), the
  586. # performance will not be worse more than 2% from the
  587. # performance when using an optimal value. Otherwise, when N
  588. # <= 512, using the following heuristics may give upto 15%
  589. # lower performance.
  590. SPLIT_N = {
  591. 16: 1,
  592. 32: 2,
  593. 64: 4,
  594. 128: 8,
  595. 256: 16,
  596. 512: 8,
  597. 1024: 16,
  598. 4096: 32,
  599. 8192: 64,
  600. }.get(N, 16)
  601. if Ms >= 512 and N >= 2048:
  602. SPLIT_N = 1
  603. Ns = N // SPLIT_N
  604. if TILE_M is None:
  605. TILE_M = min(64 if Ns < 512 else 32, Ms)
  606. if TILE_N is None:
  607. TILE_N = min(64 if Ns < 512 else 32, Ns)
  608. num_stages = num_stages or 1
  609. if num_warps is None:
  610. if min(M, N) > 1024:
  611. num_warps = {16: 1, 32: 1, 64: 2}.get(Ms, 4)
  612. elif min(M, N) == 1024:
  613. num_warps = {16: 1, 32: 1, 64: 2}.get(Ms, 4)
  614. elif min(M, N) == 256:
  615. num_warps = {16: 1, 32: 4}.get(Ms, 4)
  616. else:
  617. num_warps = {16: 1, 32: 2}.get(Ms, 4)
  618. GROUP_SIZE = GROUP_SIZE or 4
  619. assert TILE_M <= Ms, dict(TILE_M=TILE_M, Ms=Ms)
  620. assert TILE_N <= Ns, dict(TILE_N=TILE_N, Ns=Ns)
  621. assert Ms <= M, dict(M=M, Ms=Ms)
  622. assert Ns <= N, dict(N=N, Ns=Ns)
  623. assert Ks <= K, dict(K=K, Ks=Ks)
  624. return dict(
  625. TILE_M=TILE_M,
  626. TILE_N=TILE_N,
  627. GROUP_SIZE=GROUP_SIZE,
  628. num_stages=num_stages,
  629. num_warps=num_warps,
  630. SPLIT_N=SPLIT_N,
  631. **extra,
  632. )
  633. def bsr_dense_addmm_meta(
  634. M,
  635. K,
  636. N,
  637. Ms,
  638. Ks,
  639. beta,
  640. alpha,
  641. SPLIT_N=None,
  642. GROUP_SIZE_ROW=None,
  643. num_warps=None,
  644. num_stages=None,
  645. sparsity=None,
  646. dtype=None,
  647. out_dtype=None,
  648. _version=0,
  649. **extra,
  650. ):
  651. # Specifying _version is useful for situations when one wants to
  652. # discard existing triton kernel tuning results, say, in testing
  653. # bsr_dense_addmm_meta functionality.
  654. if dtype is None:
  655. dtype = torch.float16
  656. if out_dtype is None:
  657. out_dtype = dtype
  658. if sparsity is None:
  659. sparsity = 0.5
  660. if {SPLIT_N, num_warps, num_stages, GROUP_SIZE_ROW} == {None}:
  661. device_name = torch.cuda.get_device_name()
  662. key = (M, K, N, Ms, Ks, beta == 0, beta == 1, alpha == 1)
  663. if dtype is out_dtype:
  664. version_dtype = dtype
  665. else:
  666. version_dtype = dtype, out_dtype
  667. meta = get_meta(
  668. "bsr_dense_addmm",
  669. key,
  670. device_name,
  671. version=(_version, version_dtype, sparsity),
  672. )
  673. if meta is None and sparsity != 0.5:
  674. meta = get_meta(
  675. "bsr_dense_addmm",
  676. key,
  677. device_name,
  678. version=(_version, version_dtype, 0.5),
  679. )
  680. if meta is None and dtype is not out_dtype:
  681. meta = get_meta(
  682. "bsr_dense_addmm", key, device_name, version=(_version, dtype, 0.5)
  683. )
  684. if meta is None:
  685. # find approximate meta such that N % SPLIT_N == 0.
  686. matching_meta = get_meta(
  687. "bsr_dense_addmm",
  688. (*key[:2], "*", *key[3:]),
  689. device_name,
  690. version=(_version, version_dtype, 0.5),
  691. )
  692. if matching_meta is None and dtype is not out_dtype:
  693. matching_meta = get_meta(
  694. "bsr_dense_addmm",
  695. (*key[:2], "*", *key[3:]),
  696. device_name,
  697. version=(_version, dtype, 0.5),
  698. )
  699. for mkey in sorted(matching_meta or {}):
  700. meta_ = matching_meta[mkey]
  701. n = mkey[2]
  702. split_n = meta_["SPLIT_N"]
  703. c = n // split_n
  704. if N % c == 0 and n <= N:
  705. meta = dict(meta_)
  706. meta["SPLIT_N"] = N // c
  707. if meta is not None:
  708. meta.update(**extra)
  709. return meta
  710. else:
  711. # see [Computing optimal kernel parameters] in
  712. # _triton_ops_meta.py for ways to avoid this warning
  713. # message
  714. warn_once(
  715. "bsr_dense_addmm uses non-optimal triton kernel parameters"
  716. f" for {M=} {K=} {N=} {Ms=}, {Ks=} {beta=} {alpha=} {dtype=} {out_dtype=}"
  717. )
  718. SPLIT_N = SPLIT_N or max(N // Ms, 1)
  719. GROUP_SIZE_ROW = GROUP_SIZE_ROW or 4
  720. num_stages = num_stages or 1
  721. num_warps = num_warps or 4
  722. return dict(
  723. SPLIT_N=SPLIT_N,
  724. GROUP_SIZE_ROW=GROUP_SIZE_ROW,
  725. num_stages=num_stages,
  726. num_warps=num_warps,
  727. **extra,
  728. )
  729. class TensorAsKey:
  730. """A light-weight wrapper of a tensor that enables storing tensors as
  731. keys with efficient memory reference based comparison as an
  732. approximation to data equality based keys.
  733. Motivation: the hash value of a torch tensor is tensor instance
  734. based that does not use data equality and makes the usage of
  735. tensors as keys less useful. For instance, the result of
  736. ``len({a.crow_indices(), a.crow_indices()})`` is `2`, although,
  737. the tensor results from `crow_indices` method call are equal, in
  738. fact, these share the same data storage.
  739. On the other hand, for efficient caching of tensors we want to
  740. avoid calling torch.equal that compares tensors item-wise.
  741. TensorAsKey offers a compromise in that it guarantees key equality
  742. of tensors that references data in the same storage in the same
  743. manner and without accessing underlying data. However, this
  744. approach does not always guarantee correctness. For instance, for
  745. a complex tensor ``x``, we have ``TensorAsKey(x) ==
  746. TensorAsKey(x.conj())`` while ``torch.equal(x, x.conj())`` would
  747. return False.
  748. """
  749. def __init__(self, obj):
  750. def get_tensor_key(obj):
  751. # Warning: TensorAsKey does not track negative nor
  752. # conjugate bits of its input object because in the use
  753. # case of wrapping compressed/plain indices of compressed
  754. # sparse tensors (that are always integer tensors with
  755. # non-negative items) these bits are never set. However,
  756. # when extending the use of TensorAsKey to float or
  757. # complex tensors, the values of these bits (see is_neg
  758. # and is_conj methods) must be included in the key as
  759. # well.
  760. assert not (obj.dtype.is_floating_point or obj.dtype.is_complex), obj.dtype
  761. return (
  762. obj.data_ptr(),
  763. obj.storage_offset(),
  764. obj.shape,
  765. obj.stride(),
  766. obj.dtype,
  767. )
  768. self._obj_ref = weakref.ref(obj)
  769. if obj.layout is torch.strided:
  770. self.key = get_tensor_key(obj)
  771. elif obj.layout in {torch.sparse_csr, torch.sparse_bsr}:
  772. self.key = (
  773. get_tensor_key(obj.crow_indices()),
  774. get_tensor_key(obj.col_indices()),
  775. )
  776. elif obj.layout in {torch.sparse_csc, torch.sparse_bsc}:
  777. self.key = (
  778. get_tensor_key(obj.ccol_indices()),
  779. get_tensor_key(obj.row_indices()),
  780. )
  781. else:
  782. raise NotImplementedError(obj.layout)
  783. self._hash = hash(self.key)
  784. def __hash__(self):
  785. return self._hash
  786. def __eq__(self, other):
  787. if not isinstance(other, TensorAsKey):
  788. return False
  789. if self.obj is None or other.obj is None:
  790. # dead objects always compare unequal unless these are
  791. # same objects
  792. return self is other
  793. return self.key == other.key
  794. @property
  795. def obj(self):
  796. """Return object if alive, otherwise None."""
  797. return self._obj_ref()
  798. @lru_cache(maxsize=TORCH_SPARSE_BSR_SCATTER_MM_LRU_CACHE_SIZE)
  799. def _bsr_scatter_mm_indices_data(
  800. indices_format, M, K, N, Ms, Ks, nbatches, SPLIT_N, compressed_sparse_tensor_as_key
  801. ):
  802. bsr = compressed_sparse_tensor_as_key.obj
  803. assert bsr is not None
  804. crow_indices, col_indices = bsr.crow_indices(), bsr.col_indices()
  805. device = crow_indices.device
  806. indices_dtype = torch.int32
  807. if indices_format == "bsr_strided_mm_compressed":
  808. Ns = N // SPLIT_N
  809. q_offsets_lst = []
  810. b = torch.arange(SPLIT_N, dtype=indices_dtype, device=device) * Ns
  811. for m in range(M // Ms):
  812. r0 = crow_indices[m].item()
  813. r1 = crow_indices[m + 1].item()
  814. if r1 == r0:
  815. continue
  816. q_offsets_lst.append(
  817. (col_indices[r0:r1] * (Ks * N)).repeat(SPLIT_N)
  818. + b.repeat_interleave(r1 - r0)
  819. )
  820. q_offsets = torch.cat(q_offsets_lst)
  821. crow_indices_diff = crow_indices.diff()
  822. non_zero_row_indices = crow_indices_diff.nonzero()
  823. a = non_zero_row_indices * (Ms * N)
  824. r_offsets = (a + b).view(-1)
  825. c_indices = crow_indices
  826. # swizzle operation: mm elements with longer sums are computed first:
  827. nnz_per_row = crow_indices_diff[non_zero_row_indices].repeat_interleave(SPLIT_N)
  828. nnz_per_row, indices = nnz_per_row.sort(descending=True, stable=True)
  829. r_offsets = r_offsets[indices]
  830. return (indices_format, c_indices, r_offsets, q_offsets)
  831. elif indices_format == "bsr_strided_mm":
  832. Ns = N // SPLIT_N
  833. p_offsets_lst = []
  834. q_offsets_lst = []
  835. b = torch.arange(SPLIT_N, dtype=indices_dtype, device=device) * Ns
  836. for m in range(M // Ms):
  837. r0 = crow_indices[m].item()
  838. r1 = crow_indices[m + 1].item()
  839. if r1 == r0:
  840. continue
  841. p_offsets_lst.append(
  842. torch.arange(r0, r1, dtype=indices_dtype, device=device).repeat(SPLIT_N)
  843. )
  844. q_offsets_lst.append(
  845. (col_indices[r0:r1] * (Ks * N)).repeat(SPLIT_N)
  846. + b.repeat_interleave(r1 - r0)
  847. )
  848. q_offsets = torch.cat(q_offsets_lst)
  849. crow_indices_diff = crow_indices.diff()
  850. non_zero_row_indices = crow_indices_diff.nonzero()
  851. a = non_zero_row_indices * (Ms * N)
  852. r_offsets = (a + b).view(-1)
  853. c_indices = torch.cat(
  854. (
  855. crow_indices[:1],
  856. torch.cumsum(
  857. crow_indices_diff[non_zero_row_indices].repeat_interleave(SPLIT_N),
  858. 0,
  859. ),
  860. )
  861. )
  862. p_offsets = torch.cat(p_offsets_lst)
  863. return (indices_format, c_indices, r_offsets, p_offsets, q_offsets)
  864. elif indices_format == "scatter_mm":
  865. Ns = Ms
  866. c_indices = [0]
  867. pq_offsets = []
  868. # todo: eliminate inner for-loops for efficiency
  869. for b in range(nbatches):
  870. for m in range(M // Ms):
  871. r0 = crow_indices[m].item()
  872. r1 = crow_indices[m + 1].item()
  873. for n in range(N // Ns):
  874. c_indices.append(c_indices[-1] + r1 - r0)
  875. for t in range(r1 - r0):
  876. p = r0 + t
  877. q = (col_indices[p].item() + b * (K // Ks)) * (N // Ns) + n
  878. pq_offsets.append([p, q])
  879. return (
  880. indices_format,
  881. torch.tensor(c_indices, dtype=indices_dtype, device=device),
  882. torch.tensor(pq_offsets, dtype=indices_dtype, device=device),
  883. )
  884. else:
  885. raise ValueError(
  886. f"Invalid {indices_format=}. Expected bsr_strided_mm_compressed|bsr_strided_mm|scatter_mm"
  887. )
  888. def bsr_scatter_mm_indices_data(
  889. bsr, other, indices_format="bsr_strided_mm_compressed", **meta_input
  890. ):
  891. """Computes indices data for :func:`scatter_mm` used in BSR and
  892. strided tensor matrix multiplication.
  893. """
  894. assert bsr.dense_dim() == 0
  895. assert bsr.ndim == 2 # no batch dims
  896. blocksize = bsr.values().shape[-2:]
  897. M, K = bsr.shape
  898. Ms, Ks = blocksize
  899. K_, N = other.shape[-2:]
  900. assert K_ == K
  901. nbatches = other.shape[:-2].numel()
  902. meta = scatter_mm_meta(M, K, N, Ms, Ks, **meta_input)
  903. if "allow_tf32" not in meta_input:
  904. meta.update(allow_tf32=bsr.dtype in {torch.float16, torch.bfloat16})
  905. SPLIT_N = meta["SPLIT_N"]
  906. indices_data = _bsr_scatter_mm_indices_data(
  907. indices_format, M, K, N, Ms, Ks, nbatches, SPLIT_N, TensorAsKey(bsr)
  908. )
  909. if indices_format == "bsr_strided_mm_compressed":
  910. meta.update(is_compressed=True)
  911. return indices_data + (meta,)
  912. elif indices_format == "bsr_strided_mm":
  913. meta.update(is_compressed=False)
  914. return indices_data + (meta,)
  915. else:
  916. return indices_data
  917. def bsr_scatter_mm(bsr, other, indices_data=None, out=None):
  918. """BSR @ strided -> strided"""
  919. assert bsr.ndim == 2
  920. assert other.ndim >= 2
  921. Ms, Ks, Ns = bsr.shape[-2], bsr.shape[-1], other.shape[-1]
  922. blocksize = bsr.values().shape[-2:]
  923. if indices_data is None:
  924. indices_data = bsr_scatter_mm_indices_data(
  925. bsr, other, indices_format="bsr_strided_mm_compressed"
  926. )
  927. indices_format = indices_data[0]
  928. if out is None:
  929. out = torch.empty(
  930. (*other.shape[:-2], Ms, Ns), dtype=bsr.dtype, device=bsr.device
  931. )
  932. out_shape = out.shape
  933. out = as1Dbatch(out)
  934. if bsr._nnz() == 0:
  935. out.zero_()
  936. elif indices_format in {"bsr_strided_mm_compressed", "bsr_strided_mm"}:
  937. out.zero_()
  938. scatter_mm(bsr.values(), other, indices_data, accumulators=out)
  939. elif indices_format == "scatter_mm":
  940. nbatches = other.shape[:-2].numel()
  941. accumulators = torch.zeros(
  942. (
  943. nbatches * Ms // blocksize[0] * Ns // blocksize[0],
  944. blocksize[0],
  945. blocksize[0],
  946. ),
  947. dtype=bsr.dtype,
  948. device=bsr.device,
  949. )
  950. others = (
  951. as1Dbatch(other)
  952. .transpose(-2, -1)
  953. .view(
  954. nbatches,
  955. Ns // blocksize[0],
  956. blocksize[0],
  957. Ks // blocksize[1],
  958. blocksize[1],
  959. )
  960. .movedim(
  961. (3, 1, 4, 2), (1, 2, 3, 4)
  962. ) # equivalent to .transpose(-3, -2).transpose(-2, -1).transpose(-4, -3)
  963. .flatten(0, 2)
  964. )
  965. scatter_mm(bsr.values(), others, indices_data, accumulators=accumulators)
  966. out.copy_(
  967. accumulators.unflatten(
  968. 0, (nbatches, Ms // blocksize[0], Ns // blocksize[0])
  969. )
  970. .movedim(
  971. (1, 2, 3, 4), (3, 1, 4, 2)
  972. ) # equivalent to .transpose(-4, -3).transpose(-2, -1).transpose(-3, -2)
  973. .reshape(nbatches, Ns, Ms)
  974. .transpose(-2, -1)
  975. )
  976. else:
  977. raise NotImplementedError(indices_format)
  978. return out.view(out_shape)
  979. def _int_bsr_dense_addmm(
  980. input: torch.Tensor,
  981. bsr: torch.Tensor,
  982. dense: torch.Tensor,
  983. *,
  984. beta=1,
  985. alpha=1,
  986. left_alpha: Optional[torch.Tensor] = None,
  987. right_alpha: Optional[torch.Tensor] = None,
  988. out: Optional[torch.Tensor] = None,
  989. skip_checks: bool = False,
  990. max_grid: Optional[tuple[Optional[int], Optional[int], Optional[int]]] = None,
  991. meta: Optional[dict] = None,
  992. ):
  993. if out is None and dense.dtype is torch.int8:
  994. f_name = "_int_bsr_dense_addmm"
  995. crow_indices = bsr.crow_indices()
  996. batch_ndim = crow_indices.dim() - 1
  997. M = bsr.shape[batch_ndim]
  998. N = dense.shape[-1]
  999. original_batch_dims_broadcasted = broadcast_batch_dims(f_name, bsr, dense)
  1000. out = torch.empty(
  1001. original_batch_dims_broadcasted + (M, N),
  1002. dtype=torch.int32,
  1003. device=dense.device,
  1004. )
  1005. return bsr_dense_addmm(
  1006. input,
  1007. bsr,
  1008. dense,
  1009. beta=beta,
  1010. alpha=alpha,
  1011. left_alpha=left_alpha,
  1012. right_alpha=right_alpha,
  1013. out=out,
  1014. skip_checks=skip_checks,
  1015. max_grid=max_grid,
  1016. meta=meta,
  1017. )
def bsr_dense_addmm(
    input: torch.Tensor,
    bsr: torch.Tensor,
    dense: torch.Tensor,
    *,
    beta=1,
    alpha=1,
    left_alpha: Optional[torch.Tensor] = None,
    right_alpha: Optional[torch.Tensor] = None,
    out: Optional[torch.Tensor] = None,
    skip_checks: bool = False,
    max_grid: Optional[tuple[Optional[int], Optional[int], Optional[int]]] = None,
    meta: Optional[dict] = None,
):
    """Compute

      out = beta * input + left_alpha.reshape(-1, 1) * (alpha * (bsr @ dense)) * right_alpha.reshape(1, -1)

    where left_alpha, right_alpha are (* + 1)-D tensors when
    specified, otherwise, these are treated as tensors filled with
    ones.

    Args:
      input: tensor scaled by ``beta``. On the short-circuit path it is
        not read when ``beta == 0``.
      bsr: BSR tensor, possibly batched. The number of batch dims is
        ``crow_indices().dim() - 1``.
      dense: strided right-hand operand; its last dimension is N.
      beta, alpha: scalar multipliers.
      left_alpha: optional per-row scale, viewed as ``(*batch, M, 1)``.
      right_alpha: optional per-column scale, viewed as ``(*batch, 1, N)``.
      out: optional output; allocated with ``dense.new_empty`` when None.
      skip_checks: currently not referenced by this function.
      max_grid: optional per-dimension limits on the launch grid. It is
        reversed (to match ``full_grid`` order) before being passed to
        ``launch_kernel``.
      meta: triton kernel parameters, forwarded as keyword arguments to
        the kernel. Computed with ``bsr_dense_addmm_meta`` when None.

    Returns:
      ``out`` (the caller-supplied tensor when one is given).
    """
    f_name = "bsr_dense_addmm"
    values = bsr.values()
    crow_indices = bsr.crow_indices()
    col_indices = bsr.col_indices()
    batch_ndim = crow_indices.dim() - 1
    M, K = bsr.shape[batch_ndim : batch_ndim + 2]
    blocksize = values.shape[batch_ndim + 1 : batch_ndim + 3]
    N = dense.shape[-1]

    # todo: implement checks

    original_batch_dims_broadcasted = broadcast_batch_dims(f_name, bsr, dense)
    if out is None:
        out = dense.new_empty(original_batch_dims_broadcasted + (M, N))

    # The product term vanishes: out is beta * input (or zeros when beta == 0).
    if bsr._nnz() == 0 or alpha == 0 or N == 0 or M == 0 or K == 0:
        if beta == 0:
            out.zero_()
        else:
            out.copy_(input)
            if beta != 1:
                out.mul_(beta)
        return out

    # Missing scale tensors are replaced by expanded placeholders of the
    # right shape. The *_is_one flags tell the kernel not to use them.
    left_alpha_is_one = False
    right_alpha_is_one = False
    if left_alpha is None:
        left_alpha_is_one = True
        left_alpha = dense.new_empty(()).expand(
            *original_batch_dims_broadcasted, M, N
        )  # not referenced
    else:
        left_alpha = left_alpha.view(*original_batch_dims_broadcasted, M, 1).expand(
            *original_batch_dims_broadcasted, M, N
        )

    if right_alpha is None:
        right_alpha_is_one = True
        right_alpha = dense.new_empty(()).expand(
            *original_batch_dims_broadcasted, M, N
        )  # not referenced
    else:
        right_alpha = right_alpha.view(*original_batch_dims_broadcasted, 1, N).expand(
            *original_batch_dims_broadcasted, M, N
        )
    # Both scale tensors are broadcast views: constant along columns
    # (left_alpha) or along rows (right_alpha).
    assert left_alpha.stride()[-1] == 0
    assert right_alpha.stride()[-2] == 0

    if meta is None:
        # Fraction of zero entries, rounded to 2 decimals; used as part of
        # the tuning-table lookup in bsr_dense_addmm_meta.
        sparsity = round(1 - bsr._nnz() * blocksize[0] * blocksize[1] / (M * K), 2)
        meta = bsr_dense_addmm_meta(
            M,
            K,
            N,
            blocksize[0],
            blocksize[1],
            beta,
            alpha,
            sparsity=sparsity,
            dtype=dense.dtype,
            out_dtype=out.dtype,
        )
    out_backup = out

    (
        crow_indices,
        col_indices,
        values,
        input,
        dense,
        left_alpha,
        right_alpha,
        out,
    ) = prepare_inputs(bsr, input, dense, left_alpha, right_alpha, out)

    # Output columns are split into SPLIT_N tiles of width BN.
    BM, BK = blocksize
    SPLIT_N = meta.get("SPLIT_N", N // BM)
    BN = N // SPLIT_N

    out_untiled = out
    out = tile_to_blocksize(out, (BM, BN))
    dense = tile_to_blocksize(dense, (BK, BN))
    input = tile_to_blocksize(input, (BM, BN))
    left_alpha = tile_to_blocksize(left_alpha, (BM, BN))
    right_alpha = tile_to_blocksize(right_alpha, (BM, BN))

    # tl.dot supports float16, float32, int32 as accumulator types.
    dot_out_dtype = {
        torch.float16: tl.float32,
        torch.bfloat16: tl.float32,
        torch.float32: tl.float64,
        torch.float64: tl.float64,
        torch.int8: tl.int32,
        torch.int32: tl.int32,
    }[out.dtype]

    n_batches = dense.size(0)
    n_block_rows = crow_indices.size(-1) - 1
    n_block_cols = dense.size(-3)

    full_grid = (n_batches, n_block_cols, n_block_rows)
    if max_grid is not None:
        grid_blocks = tuple(max_grid[:3][::-1]) + (None,) * (3 - len(max_grid[:3]))
    else:
        grid_blocks = None

    # Maps each tensor to the dims sliced per (batch, block col, block row)
    # grid dimension. NOTE(review): the kernel receives the sliced tensors
    # positionally, so this insertion order presumably has to match the
    # kernel's argument order.
    tensor_dims_map = {
        values: (0, None, None),
        crow_indices: (0, None, -1),
        col_indices: (0, None, None),
        input: (0, -3, -4),
        dense: (0, -3, None),
        left_alpha: (0, -3, -4),
        right_alpha: (0, -3, -4),
        out: (0, -3, -4),
    }

    assert alpha != 0

    def kernel(grid, *sliced_tensors):
        _bsr_strided_addmm_kernel[grid](
            *ptr_stride_extractor(*sliced_tensors),
            beta,
            alpha,
            beta_is_one=beta == 1,
            beta_is_nonzero=beta != 0,
            alpha_is_one=alpha == 1,
            left_alpha_is_one=left_alpha_is_one,
            right_alpha_is_one=right_alpha_is_one,
            BLOCKSIZE_ROW=BM,
            BLOCKSIZE_INNER=BK,
            BLOCKSIZE_COL=BN,
            allow_tf32=dot_out_dtype == tl.float32,
            acc_dtype=dot_out_dtype,
            **meta,
        )

    launch_kernel(kernel, tensor_dims_map, full_grid, grid_blocks)

    if out.data_ptr() != out_backup.data_ptr():
        # prepare_inputs has made a copy of out, copy its content back
        # to out_backup:
        out_backup.copy_(out_untiled.view(out_backup.shape))

    return out_backup
  1165. if has_triton():
  1166. import triton
  1167. import triton.language as tl
@triton.jit
def _sampled_addmm_kernel(
    alpha,
    beta,
    IS_BETA_ZERO: tl.constexpr,
    BLOCKSIZE_ROW: tl.constexpr,
    BLOCKSIZE_COL: tl.constexpr,
    k,
    TILE_K: tl.constexpr,
    values_ptr,
    values_batch_stride,
    values_nnz_stride,
    values_row_block_stride,
    values_col_block_stride,
    crow_indices_ptr,
    crow_indices_batch_stride,
    crow_indices_stride,
    col_indices_ptr,
    col_indices_batch_stride,
    col_indices_stride,
    mat1_ptr,
    mat1_batch_stride,
    mat1_tiled_row_stride,
    mat1_tiled_col_stride,  # not referenced below
    mat1_row_block_stride,
    mat1_col_block_stride,
    mat2_ptr,
    mat2_batch_stride,
    mat2_tiled_row_stride,  # not referenced below
    mat2_tiled_col_stride,
    mat2_row_block_stride,
    mat2_col_block_stride,
    acc_dtype: tl.constexpr,
    allow_tf32: tl.constexpr,
):
    # Triton kernel behind sampled_addmm. One program handles one block
    # row (program_id axis 0) of one batch (program_id axis 1). For every
    # specified block of that row it computes, in place in values:
    #   block = alpha * (mat1 row tile @ mat2 column tile) + beta * block
    # The beta term is dropped when IS_BETA_ZERO. The inner dimension k is
    # accumulated in TILE_K-wide chunks with masked loads.
    batch_pid = tl.program_id(axis=1)
    row_block_pid = tl.program_id(axis=0)

    crow_indices_offset_ptr = (
        crow_indices_ptr
        + crow_indices_batch_stride * batch_pid
        + crow_indices_stride * row_block_pid
    )
    nnz_offset = tl.load(crow_indices_offset_ptr)
    nnz_offset_next = tl.load(crow_indices_offset_ptr + crow_indices_stride)

    # Compute nnz for the row with number row_block_pid.
    # If it is zero, skip the row.
    row_nnz = nnz_offset_next - nnz_offset
    if row_nnz == 0:
        return

    row_block_arange = tl.arange(0, BLOCKSIZE_ROW)
    col_block_arange = tl.arange(0, BLOCKSIZE_COL)

    # Pointers are set to the first block of the current row.
    values_block_ptrs = (
        values_ptr
        + values_batch_stride * batch_pid
        + values_nnz_stride * nnz_offset
        + values_row_block_stride * row_block_arange[:, None]
        + values_col_block_stride * col_block_arange[None, :]
    )

    col_index_nnz_ptr = (
        col_indices_ptr
        + col_indices_batch_stride * batch_pid
        + col_indices_stride * nnz_offset
    )

    # Advance mat1 to the current tiled row, ignore columns.
    mat1_block_ptrs = (
        mat1_ptr
        + mat1_batch_stride * batch_pid
        + mat1_tiled_row_stride * row_block_pid
        + mat1_row_block_stride * row_block_arange[:, None]
    )

    # Advance mat2 in batch and block col dimension.
    mat2_block_ptrs = (
        mat2_ptr
        + mat2_batch_stride * batch_pid
        + mat2_col_block_stride * col_block_arange[None, :]
    )

    k_tile_arange = tl.arange(0, TILE_K)
    for _ in range(row_nnz):
        acc_block = tl.zeros((BLOCKSIZE_ROW, BLOCKSIZE_COL), dtype=acc_dtype)

        # find column block index
        col_block = tl.load(col_index_nnz_ptr)

        for k_tile in range(0, k, TILE_K):
            k_offsets = k_tile + k_tile_arange
            # Mask out the tail of the last chunk when k % TILE_K != 0.
            mask_k = k_offsets < k

            mat1_block = tl.load(
                mat1_block_ptrs + mat1_col_block_stride * k_offsets[None, :],
                mask=mask_k[None, :],
                other=0.0,
            )

            mat2_block = tl.load(
                mat2_block_ptrs
                + mat2_tiled_col_stride * col_block
                + mat2_row_block_stride * k_offsets[:, None],
                mask=mask_k[:, None],
                other=0.0,
            )

            acc_block += tl.dot(
                mat1_block, mat2_block, allow_tf32=allow_tf32, out_dtype=acc_dtype
            )

        if IS_BETA_ZERO:
            acc_block *= alpha
        else:
            acc_block = alpha * acc_block + beta * tl.load(values_block_ptrs)

        # write result
        tl.store(values_block_ptrs, acc_block.to(values_ptr.dtype.element_ty))

        # advance val/col_index ptrs to the next block in the row.
        values_block_ptrs += values_nnz_stride
        col_index_nnz_ptr += col_indices_stride
  1277. @triton.jit
  1278. def _bsr_strided_dense_rowspace_kernel(
  1279. # values prologue
  1280. values_ptr,
  1281. values_batch_stride,
  1282. values_nnz_stride,
  1283. values_row_block_stride,
  1284. values_col_block_stride,
  1285. # values epilogue
  1286. # crow_indices prologue
  1287. crow_indices_ptr,
  1288. crow_indices_batch_stride,
  1289. crow_indices_stride,
  1290. # crow_indices epilogue
  1291. # col_indices prologue
  1292. col_indices_ptr,
  1293. col_indices_batch_stride,
  1294. col_indices_stride,
  1295. # col_indices epilogue
  1296. # dense prologue
  1297. dense_ptr,
  1298. dense_batch_stride,
  1299. dense_tiled_row_stride,
  1300. dense_tiled_col_stride,
  1301. dense_row_block_stride,
  1302. dense_col_block_stride,
  1303. # dense epilogue
  1304. # output prologue
  1305. output_ptr,
  1306. output_batch_stride,
  1307. output_tiled_row_stride,
  1308. output_tiled_col_stride,
  1309. output_row_block_stride,
  1310. output_col_block_stride,
  1311. # output epilogue
  1312. #
  1313. # gh-113754: Always keep all constexpr arguments at the end of
  1314. # triton kernel arguments list because with triton 2.1 or
  1315. # earlier non-contiguous outputs will corrupt CUDA state due
  1316. # to a triton bug (fixed in openai/triton#2262).
  1317. BLOCKSIZE_ROW: tl.constexpr,
  1318. BLOCKSIZE_COL: tl.constexpr,
  1319. acc_dtype: tl.constexpr,
  1320. allow_tf32: tl.constexpr,
  1321. GROUP_SIZE_ROW: tl.constexpr,
  1322. ):
  1323. batch_pid = tl.program_id(axis=2)
  1324. row_block_pid = tl.program_id(axis=0)
  1325. col_block_pid = tl.program_id(axis=1)
  1326. n_block_rows = tl.num_programs(axis=0)
  1327. n_block_cols = tl.num_programs(axis=1)
  1328. row_block_pid, col_block_pid = tl.swizzle2d(
  1329. row_block_pid, col_block_pid, n_block_rows, n_block_cols, GROUP_SIZE_ROW
  1330. )
  1331. crow_indices_offset_ptr = (
  1332. crow_indices_ptr
  1333. + crow_indices_batch_stride * batch_pid
  1334. + crow_indices_stride * row_block_pid
  1335. )
  1336. nnz_offset = tl.load(crow_indices_offset_ptr)
  1337. nnz_offset_next = tl.load(crow_indices_offset_ptr + crow_indices_stride)
  1338. # Compute nnz for the row with number row_block_pid.
  1339. # If it is zero, skip the row.
  1340. row_nnz = nnz_offset_next - nnz_offset
  1341. if row_nnz == 0:
  1342. return
  1343. row_block_arange = tl.arange(0, BLOCKSIZE_ROW)
  1344. col_block_arange = tl.arange(0, BLOCKSIZE_COL)
  1345. # Pointers are set to the first block of the current row.
  1346. values_block_ptrs = (
  1347. values_ptr
  1348. + values_batch_stride * batch_pid
  1349. + values_nnz_stride * nnz_offset
  1350. + values_row_block_stride * row_block_arange[:, None]
  1351. + values_col_block_stride * col_block_arange[None, :]
  1352. )
  1353. # NOTE: dense is advanced into all dimensions but the tiled row one.
  1354. # That will be advanced in the loop according to values in col_indices.
  1355. dense_block_ptrs = (
  1356. dense_ptr
  1357. + dense_batch_stride * batch_pid
  1358. + dense_tiled_col_stride * col_block_pid
  1359. + dense_row_block_stride * col_block_arange[:, None]
  1360. + dense_col_block_stride * row_block_arange[None, :]
  1361. )
  1362. # Pointers are set to exact write-to locations
  1363. output_ptrs = (
  1364. output_ptr
  1365. + output_batch_stride * batch_pid
  1366. + output_tiled_row_stride * row_block_pid
  1367. + output_tiled_col_stride * col_block_pid
  1368. + output_row_block_stride * row_block_arange[:, None]
  1369. + output_col_block_stride * row_block_arange[None, :]
  1370. )
  1371. # Set pointer to the first nonzero element in the current row
  1372. col_index_nnz_ptr = (
  1373. col_indices_ptr
  1374. + col_indices_batch_stride * batch_pid
  1375. + col_indices_stride * nnz_offset
  1376. )
  1377. output_acc_block = tl.zeros((BLOCKSIZE_ROW, BLOCKSIZE_COL), dtype=acc_dtype)
  1378. for _ in range(row_nnz):
  1379. values_block = tl.load(values_block_ptrs)
  1380. # find which row of dense needs to get loaded
  1381. # for multiplication with values_block.
  1382. dense_row_idx = tl.load(col_index_nnz_ptr)
  1383. dense_block = tl.load(
  1384. dense_block_ptrs + dense_tiled_row_stride * dense_row_idx
  1385. )
  1386. # do block mm
  1387. output_acc_block += tl.dot(
  1388. values_block, dense_block, allow_tf32=allow_tf32, out_dtype=acc_dtype
  1389. )
  1390. # move val/col_index ptrs to the next block in the row
  1391. values_block_ptrs += values_nnz_stride
  1392. col_index_nnz_ptr += col_indices_stride
  1393. # write back the result
  1394. tl.store(output_ptrs, output_acc_block.to(output_ptr.dtype.element_ty))
  1395. def _run_sampled_addmm_kernel(
  1396. alpha,
  1397. beta,
  1398. is_beta_zero,
  1399. blocksize,
  1400. k,
  1401. tile_k,
  1402. values,
  1403. crow_indices,
  1404. col_indices,
  1405. mat1,
  1406. mat2,
  1407. max_grid,
  1408. ):
  1409. n_batches = values.size(0)
  1410. n_block_rows = crow_indices.size(-1) - 1
  1411. full_grid = (n_batches, n_block_rows)
  1412. if max_grid is not None:
  1413. grid_blocks = tuple(max_grid[:2][::-1]) + (None,) * (2 - len(max_grid[:2]))
  1414. else:
  1415. grid_blocks = None
  1416. tensor_dims_map = {
  1417. values: (0, None),
  1418. crow_indices: (0, -1),
  1419. col_indices: (0, None),
  1420. mat1: (0, -4),
  1421. mat2: (0, None),
  1422. }
  1423. if values.dtype in (torch.half, torch.bfloat16):
  1424. acc_dtype = tl.float32
  1425. allow_tf32 = True
  1426. else:
  1427. acc_dtype = tl.float64
  1428. allow_tf32 = False
  1429. def kernel(grid, *sliced_tensors):
  1430. _sampled_addmm_kernel[grid](
  1431. alpha,
  1432. beta,
  1433. is_beta_zero,
  1434. *blocksize,
  1435. k,
  1436. tile_k,
  1437. *ptr_stride_extractor(*sliced_tensors),
  1438. acc_dtype=acc_dtype,
  1439. allow_tf32=allow_tf32,
  1440. num_stages=1,
  1441. num_warps=4,
  1442. )
  1443. launch_kernel(kernel, tensor_dims_map, full_grid, grid_blocks)
  1444. def sampled_addmm(
  1445. input: torch.Tensor,
  1446. mat1: torch.Tensor,
  1447. mat2: torch.Tensor,
  1448. *,
  1449. beta=1.0,
  1450. alpha=1.0,
  1451. out: Optional[torch.Tensor] = None,
  1452. skip_checks: bool = False,
  1453. max_grid: Optional[tuple[Optional[int], Optional[int], Optional[int]]] = None,
  1454. ):
  1455. f_name = "sampled_addmm"
  1456. check_bsr_layout(f_name, input)
  1457. input_broadcasted = broadcast_batch_dims_bsr(f_name, input, mat1, mat2)
  1458. if not skip_checks:
  1459. check_device(f_name, mat1, input.device)
  1460. check_device(f_name, mat2, input.device)
  1461. if beta != 0.0 and input.dtype is torch.bool:
  1462. check(
  1463. False,
  1464. f"{f_name}(): having beta == {beta} not equal to 0.0 with boolean mask is not allowed.",
  1465. )
  1466. if input.dtype is not torch.bool:
  1467. check_dtype(f_name, mat1, input.dtype)
  1468. check_dtype(f_name, mat2, input.dtype)
  1469. else:
  1470. check_dtype(f_name, mat1, mat2.dtype)
  1471. check_mm_compatible_shapes(f_name, mat1, mat2)
  1472. if out is not None:
  1473. check_bsr_layout(f_name, out)
  1474. check_device(f_name, out, mat1.device)
  1475. check_dtype(f_name, out, input.dtype)
  1476. check(
  1477. out.shape == input_broadcasted.shape and out._nnz() == input._nnz(),
  1478. f"{f_name}(): Expects `out` to be of shape {input_broadcasted.shape} "
  1479. f"and with nnz equal to {input_broadcasted._nnz()} "
  1480. f"but got out.shape = {out.shape} and out.nnz = {out._nnz()}",
  1481. )
  1482. if out is None:
  1483. out = input_broadcasted.to(mat1.dtype, copy=True)
  1484. else:
  1485. out.copy_(input_broadcasted)
  1486. if out.numel() == 0 or out._nnz() == 0:
  1487. return out
  1488. blocksize = out.values().shape[-2:]
  1489. k = mat1.size(-1)
  1490. # NOTE: (m, 0) @ (0, n) == zeros(m, n)
  1491. if alpha == 0.0 or k == 0:
  1492. out.values().mul_(beta)
  1493. return out
  1494. # prepare inputs by reshaping them to be kernel-compatible
  1495. out_backup = out
  1496. crow_indices, col_indices, values, mat1, mat2 = prepare_inputs(out, mat1, mat2)
  1497. mat1 = tile_to_blocksize(mat1, (blocksize[0], k))
  1498. mat2 = tile_to_blocksize(mat2, (k, blocksize[1]))
  1499. tile_k = max(*blocksize)
  1500. _run_sampled_addmm_kernel(
  1501. alpha,
  1502. beta,
  1503. beta == 0.0,
  1504. blocksize,
  1505. k,
  1506. tile_k,
  1507. values,
  1508. crow_indices,
  1509. col_indices,
  1510. mat1,
  1511. mat2,
  1512. max_grid,
  1513. )
  1514. # If nnz x block strides are not the same in out_backup.values and values,
  1515. # it means that out_backup.values and values are not the views of each other,
  1516. # so we have to copy.
  1517. if out_backup.values().stride()[-3:] != values.stride()[-3:]:
  1518. out_backup.values().copy_(values.reshape(out_backup.values().shape))
  1519. return out_backup
  1520. def bsr_dense_mm(
  1521. bsr: torch.Tensor,
  1522. dense: torch.Tensor,
  1523. *,
  1524. out: Optional[torch.Tensor] = None,
  1525. skip_checks: bool = False,
  1526. max_grid: Optional[tuple[Optional[int], Optional[int], Optional[int]]] = None,
  1527. meta: Optional[dict] = None,
  1528. ):
  1529. f_name = "bsr_dense_mm"
  1530. m, _kl = bsr.shape[-2:]
  1531. if not skip_checks:
  1532. check_bsr_layout(f_name, bsr)
  1533. check_device(f_name, bsr, dense.device)
  1534. check_dtype(f_name, bsr, dense.dtype, (torch.int8,))
  1535. check_mm_compatible_shapes(f_name, bsr, dense)
  1536. n = dense.size(-1)
  1537. row_block, col_block = bsr.values().shape[-2:]
  1538. check_blocksize(f_name, (row_block, col_block))
  1539. check(
  1540. not n % 16,
  1541. f"{f_name}(): dense.size(-1) == {n} should be divisible by 16",
  1542. )
  1543. else:
  1544. _kr, n = dense.shape[-2:]
  1545. original_batch_dims_broadcasted = broadcast_batch_dims(f_name, bsr, dense)
  1546. if out is not None and not skip_checks:
  1547. expected_out_shape = original_batch_dims_broadcasted + (m, n)
  1548. check(
  1549. out.shape == expected_out_shape,
  1550. "bsr_dense_mm(): `out` argument has wrong shape, "
  1551. f"expected {expected_out_shape}, but got {out.shape}.",
  1552. )
  1553. check(
  1554. out.is_contiguous() or out.transpose(-2, -1).is_contiguous(),
  1555. "bsr_dense_mm(): only row-major/col-major `out` arguments are supported, "
  1556. "i.e. (out.is_contiguous() or out.transpose(-2, -1).is_contiguous()) "
  1557. "should be True.",
  1558. )
  1559. # Allocate out
  1560. if out is None:
  1561. out = dense.new_empty(original_batch_dims_broadcasted + (m, n))
  1562. # Short circuit if lhs is zero
  1563. if bsr._nnz() == 0:
  1564. return out.zero_()
  1565. # with beta==0, addmm ignores input content, so we can use out
  1566. # as a placeholder for input because their shapes match:
  1567. return bsr_dense_addmm(out, bsr, dense, alpha=1, beta=0, out=out)
@triton.jit
def _bsr_softmax_kernel(
    crow_indices_ptr,
    crow_indices_batch_stride,
    crow_indices_stride,
    values_ptr,
    values_batch_stride,
    values_row_block_stride,
    values_nnz_col_block_stride,
    row_block,
    col_block,
    MAX_ROW_NNZ: tl.constexpr,
    TILE: tl.constexpr,
):
    # In-place, numerically stable softmax over the stored elements of one
    # scalar row of a BSR tensor.
    # - Grid axis 2 selects the batch, axis 0 the block row, and axis 1 the
    #   row inside that block row.
    # - The row's stored elements are addressed as one unit-stride run of
    #   row_nnz * col_block entries starting at nnz_offset * col_block:
    #   `row_arange` is added directly to the pointer.
    # - This assumes the caller laid values out as
    #   (batch, row_block, nnz * col_block) with a contiguous last dim.
    # - NOTE(review): values_nnz_col_block_stride and row_block are not used
    #   in this body.
    # - The row is processed in TILE-sized chunks over three passes (max,
    #   denominator, write-back). MAX_ROW_NNZ bounds the number of chunks.
    batch_pid = tl.program_id(axis=2)
    row_block_offset_pid = tl.program_id(axis=1)
    row_block_pid = tl.program_id(axis=0)

    crow_indices_offset_ptr = (
        crow_indices_ptr
        + crow_indices_batch_stride * batch_pid
        + crow_indices_stride * row_block_pid
    )
    nnz_offset = tl.load(crow_indices_offset_ptr)
    nnz_offset_next = tl.load(crow_indices_offset_ptr + crow_indices_stride)

    # Compute nnz for the row with number row_block_pid.
    # If it is zero, skip the row.
    row_nnz = nnz_offset_next - nnz_offset
    if row_nnz == 0:
        return

    row_arange = tl.arange(0, TILE)
    mask = row_arange < row_nnz * col_block

    curr_row_values_ptrs = (
        values_ptr
        + values_batch_stride * batch_pid
        + values_row_block_stride * row_block_offset_pid
        + nnz_offset * col_block
    )

    # find max in the row
    # Masked lanes load -inf, so they never win the max and later contribute
    # exp(-inf) == 0 to the denominator.
    row_tile = tl.load(
        curr_row_values_ptrs + row_arange, mask=mask, other=-float("inf")
    ).to(tl.float32)
    max_row_value = tl.max(row_tile, axis=0)
    for _ in range(TILE, MAX_ROW_NNZ, TILE):
        row_arange += TILE
        mask = row_arange < row_nnz * col_block
        row_tile = tl.load(
            curr_row_values_ptrs + row_arange, mask=mask, other=-float("inf")
        ).to(tl.float32)
        curr_max_row_value = tl.max(row_tile, axis=0)
        max_row_value = tl.where(
            max_row_value > curr_max_row_value, max_row_value, curr_max_row_value
        )

    # find denominator for stable softmax
    # The last chunk from the max pass is still in `row_tile`, so it is
    # consumed first and the loop walks the remaining chunks backwards. That
    # leaves `row_arange`, `mask` and `num` pointing at the first chunk.
    num = tl.exp(row_tile - max_row_value)
    denom = tl.sum(num, axis=0)
    for _ in range(TILE, MAX_ROW_NNZ, TILE):
        row_arange -= TILE
        mask = row_arange < row_nnz * col_block
        row_tile = tl.load(
            curr_row_values_ptrs + row_arange, mask=mask, other=-float("inf")
        ).to(tl.float32)
        num = tl.exp(row_tile - max_row_value)
        denom += tl.sum(num, axis=0)

    # populate output
    # The first chunk's numerator is already computed; store it, then walk
    # forward over the remaining chunks, recomputing each numerator.
    tl.store(
        curr_row_values_ptrs + row_arange,
        (num / denom).to(values_ptr.dtype.element_ty),
        mask=mask,
    )
    for _ in range(TILE, MAX_ROW_NNZ, TILE):
        row_arange += TILE
        mask = row_arange < row_nnz * col_block
        row_tile = tl.load(
            curr_row_values_ptrs + row_arange, mask=mask, other=-float("inf")
        ).to(tl.float32)
        num = tl.exp(row_tile - max_row_value)
        tl.store(
            curr_row_values_ptrs + row_arange,
            (num / denom).to(values_ptr.dtype.element_ty),
            mask=mask,
        )
  1649. def bsr_softmax(input, max_row_nnz=None):
  1650. f_name = "bsr_softmax"
  1651. check_bsr_layout(f_name, input)
  1652. check_dtype(f_name, input, input.dtype)
  1653. if input._nnz() == 0 or input.numel() == 0:
  1654. return input.clone()
  1655. m, n = input.shape[-2:]
  1656. nnz = input._nnz()
  1657. row_block, col_block = input.values().shape[-2:]
  1658. if max_row_nnz is None:
  1659. max_row_nnz = triton.next_power_of_2(n)
  1660. else:
  1661. max_row_nnz = triton.next_power_of_2(max_row_nnz)
  1662. crow_indices = input.crow_indices().unsqueeze(0).flatten(0, -2)
  1663. # reshape values from
  1664. # (b1, ..., bn, nnz, row_block, col_block) to
  1665. # (b1 * ... * bn, row_block, nnz * col_block).
  1666. # This simplifies batch dim manipulation and unlocks
  1667. # the possibility to access all nnzs in any given row.
  1668. if input.values().transpose(-3, -2).is_contiguous():
  1669. # Need to clone to avoid `contiguous` returning a view.
  1670. values = input.values().clone()
  1671. else:
  1672. values = input.values()
  1673. values = (
  1674. values.transpose(-3, -2)
  1675. .contiguous()
  1676. .unsqueeze(0)
  1677. .flatten(0, -4)
  1678. .reshape(-1, row_block, nnz * col_block)
  1679. )
  1680. full_grid = (values.shape[0], row_block, m // row_block)
  1681. grid_blocks = None
  1682. tensor_dims_map = {
  1683. # We span nnz number of blocks, not nnz + 1,
  1684. # hence crow_indices[..., :-1]
  1685. crow_indices[..., :-1]: (0, None, -1),
  1686. values: (0, None, None),
  1687. }
  1688. def kernel(grid, *sliced_tensors):
  1689. _bsr_softmax_kernel[grid](
  1690. *ptr_stride_extractor(*sliced_tensors),
  1691. row_block,
  1692. col_block,
  1693. max_row_nnz,
  1694. # Triton's max numel is bounded by 2 ** 17.
  1695. min(2**17, max_row_nnz),
  1696. )
  1697. launch_kernel(kernel, tensor_dims_map, full_grid, grid_blocks)
  1698. values = (
  1699. values.reshape(-1, row_block, nnz, col_block)
  1700. .transpose(-3, -2)
  1701. .reshape(*input.values().shape)
  1702. )
  1703. return torch.sparse_compressed_tensor(
  1704. input.crow_indices().clone(),
  1705. input.col_indices().clone(),
  1706. values,
  1707. size=input.shape,
  1708. layout=input.layout,
  1709. )
  1710. def _scaled_dot_product_attention(
  1711. query: torch.Tensor,
  1712. key: torch.Tensor,
  1713. value: torch.Tensor,
  1714. attn_mask: Optional[torch.Tensor],
  1715. dropout_p: float = 0.0,
  1716. is_causal: bool = False,
  1717. scale: Optional[float] = None,
  1718. ):
  1719. f_name = "_scaled_dot_product_attention"
  1720. check(not is_causal, f"{f_name}(): is_causal == True is not supported.")
  1721. check(attn_mask is not None, f"{f_name}(): attn_mask == None is not supported.")
  1722. assert attn_mask is not None
  1723. check(
  1724. attn_mask.layout == torch.sparse_bsr,
  1725. f"{f_name}(): "
  1726. f"attn_mask.layout must be {torch.sparse_bsr}, but got "
  1727. f"attn_mask.layout == {attn_mask.layout}.",
  1728. )
  1729. check_device(f_name, key, query.device)
  1730. check_device(f_name, value, query.device)
  1731. check_device(f_name, attn_mask, query.device)
  1732. check_dtype(f_name, key, query.dtype)
  1733. check_dtype(f_name, value, query.dtype)
  1734. if attn_mask.dtype is not torch.bool:
  1735. check_dtype(f_name, attn_mask, query.dtype)
  1736. sdpa = sampled_addmm(
  1737. attn_mask, query, key.transpose(-2, -1), beta=0.0, skip_checks=False
  1738. )
  1739. if scale is None and query.size(-1) == 0 or scale == 0.0:
  1740. check(
  1741. False,
  1742. f"{f_name}(): current value of scale == {scale} "
  1743. "results in division by zero.",
  1744. )
  1745. scale_factor = 1 / math.sqrt(query.size(-1)) if scale is None else scale
  1746. sdpa.values().mul_(scale_factor)
  1747. sdpa = bsr_softmax(sdpa)
  1748. torch.nn.functional.dropout(sdpa.values(), p=dropout_p, inplace=True)
  1749. sdpa = bsr_dense_mm(sdpa, value)
  1750. return sdpa
  1751. @triton.jit
  1752. def _scatter_mm2_kernel(
  1753. M: tl.constexpr,
  1754. K: tl.constexpr,
  1755. N: tl.constexpr,
  1756. blocks_ptr,
  1757. blocks_stride_P,
  1758. blocks_stride_M,
  1759. blocks_stride_K,
  1760. others_ptr,
  1761. others_stride_Q,
  1762. others_stride_K,
  1763. others_stride_N,
  1764. accumulators_ptr,
  1765. accumulators_stride_R,
  1766. accumulators_stride_M,
  1767. accumulators_stride_N,
  1768. pq_offsets_ptr,
  1769. pq_offsets_stride,
  1770. pq_ptr,
  1771. pq_stride_T,
  1772. pq_stride_1,
  1773. dot_out_dtype: tl.constexpr,
  1774. TILE_M: tl.constexpr,
  1775. TILE_N: tl.constexpr,
  1776. allow_tf32: tl.constexpr,
  1777. ):
  1778. Ms = M // TILE_M
  1779. pid_t = tl.program_id(axis=0)
  1780. pid = tl.program_id(axis=1)
  1781. pid_m = pid // Ms
  1782. pid_n = pid % Ms
  1783. rm = pid_m * TILE_M + tl.arange(0, TILE_M)
  1784. rn = pid_n * TILE_N + tl.arange(0, TILE_N)
  1785. rk = tl.arange(0, K)
  1786. A_ptr = blocks_ptr + (
  1787. rm[:, None] * blocks_stride_M + rk[None, :] * blocks_stride_K
  1788. )
  1789. B_ptr = others_ptr + (
  1790. rk[:, None] * others_stride_K + rn[None, :] * others_stride_N
  1791. )
  1792. g0 = tl.load(pq_offsets_ptr + pid_t * pq_offsets_stride)
  1793. g1 = tl.load(pq_offsets_ptr + (pid_t + 1) * pq_offsets_stride)
  1794. if g0 == g1:
  1795. return
  1796. acc_block = tl.zeros((TILE_M, TILE_N), dtype=dot_out_dtype)
  1797. for i in range(g0, g1):
  1798. p = tl.load(pq_ptr + i * pq_stride_T)
  1799. q = tl.load(pq_ptr + i * pq_stride_T + pq_stride_1)
  1800. A = tl.load(A_ptr + p * blocks_stride_P)
  1801. B = tl.load(B_ptr + q * others_stride_Q)
  1802. acc_block += tl.dot(A, B, out_dtype=dot_out_dtype, allow_tf32=allow_tf32)
  1803. C_ptr = (
  1804. accumulators_ptr
  1805. + pid_t * accumulators_stride_R
  1806. + (
  1807. rm[:, None] * accumulators_stride_M
  1808. + rn[None, :] * accumulators_stride_N
  1809. )
  1810. )
  1811. tl.store(C_ptr, acc_block.to(accumulators_ptr.dtype.element_ty))
  1812. def _scatter_mm2(
  1813. blocks: torch.Tensor,
  1814. others: torch.Tensor,
  1815. pq_offsets: torch.Tensor,
  1816. pq_indices: torch.Tensor,
  1817. accumulators: torch.Tensor,
  1818. ):
  1819. _P, M, K = blocks.shape
  1820. _Q, _, N = others.shape
  1821. meta = dict(
  1822. TILE_M=max(16, M // 4), TILE_N=max(16, N // 4), num_stages=1, num_warps=2
  1823. )
  1824. def grid(META):
  1825. return (
  1826. pq_offsets.shape[0] - 1,
  1827. triton.cdiv(M, META["TILE_M"]) * triton.cdiv(N, META["TILE_N"]),
  1828. 1,
  1829. )
  1830. dot_out_dtype = {
  1831. torch.float16: tl.float32,
  1832. torch.bfloat16: tl.float32,
  1833. torch.float32: tl.float64,
  1834. torch.float64: tl.float64,
  1835. }[accumulators.dtype]
  1836. if "allow_tf32" not in meta:
  1837. meta.update(allow_tf32=dot_out_dtype == tl.float32)
  1838. _scatter_mm2_kernel[grid](
  1839. M,
  1840. K,
  1841. N,
  1842. blocks,
  1843. blocks.stride(0),
  1844. blocks.stride(1),
  1845. blocks.stride(2),
  1846. others,
  1847. others.stride(0),
  1848. others.stride(1),
  1849. others.stride(2),
  1850. accumulators,
  1851. accumulators.stride(0),
  1852. accumulators.stride(1),
  1853. accumulators.stride(2),
  1854. pq_offsets,
  1855. pq_offsets.stride(0),
  1856. pq_indices,
  1857. pq_indices.stride(0),
  1858. pq_indices.stride(1),
  1859. dot_out_dtype=dot_out_dtype,
  1860. **meta,
  1861. )
@triton.jit
def _scatter_mm6_kernel(
    nbatches,
    Ms,
    Ks: tl.constexpr,
    N,
    blocks_ptr,
    blocks_stride_P,
    blocks_stride_M,
    blocks_stride_K,
    others_ptr,
    others_stride_B,
    others_stride_K,
    others_stride_N,
    accumulators_ptr,
    accumulators_stride_B,
    accumulators_stride_M,
    accumulators_stride_N,
    c_indices_ptr,
    r_offsets_ptr,
    p_offsets_ptr,
    q_offsets_ptr,
    is_compressed: tl.constexpr,
    dot_out_dtype: tl.constexpr,
    SPLIT_N: tl.constexpr,
    TILE_M: tl.constexpr,
    TILE_N: tl.constexpr,
    GROUP_SIZE: tl.constexpr,
    allow_tf32: tl.constexpr,
):
    # Each program computes one (TILE_M, TILE_N) tile of an (Ms, N // SPLIT_N)
    # output sub-block of `accumulators` for one batch entry. It sums
    # tl.dot(A, B) over the block products listed by the index tensors and
    # stores (overwrites) the result.
    #
    # Grid axis 0 packs the sub-block index and the batch index:
    # pid_t_ = pid_t * nbatches + pid_b. Grid axis 1 enumerates the tiles
    # inside the sub-block.
    #
    # Loads/stores are unmasked. This assumes Ms % TILE_M == 0 and
    # Ns % TILE_N == 0 -- TODO(review) confirm callers guarantee this.
    Ns = N // SPLIT_N
    BLOCKS_M = Ms // TILE_M
    BLOCKS_N = Ns // TILE_N

    pid_t_ = tl.program_id(axis=0)
    pid = tl.program_id(axis=1)
    pid_b = pid_t_ % nbatches
    pid_t = pid_t_ // nbatches

    # Grouped tile ordering (same scheme as the Triton matmul tutorial):
    # consecutive pids cycle through up to GROUP_SIZE tile rows before
    # advancing to the next tile column.
    num_pid_in_group = GROUP_SIZE * BLOCKS_N
    group_id = pid // num_pid_in_group
    first_pid_m = group_id * GROUP_SIZE
    group_size_m = min(BLOCKS_M - first_pid_m, GROUP_SIZE)
    pid_m = first_pid_m + (pid % group_size_m)
    pid_n = (pid % num_pid_in_group) // group_size_m

    rm = pid_m * TILE_M + tl.arange(0, TILE_M)
    rn = pid_n * TILE_N + tl.arange(0, TILE_N)
    rk = tl.arange(0, Ks)
    A_ptr = blocks_ptr + (
        rm[:, None] * blocks_stride_M + rk[None, :] * blocks_stride_K
    )
    B_ptr = (
        others_ptr
        + pid_b * others_stride_B
        + (rk[:, None] * others_stride_K + rn[None, :] * others_stride_N)
    )

    # When is_compressed is True, r is the only variable that
    # depends on pid_t. This property allows sorting r values
    # before calling the kernel. The sorting of r is equivalent to
    # defining swizzle operator outside of the kernel.
    # r is the flat element offset of the sub-block's top-left entry within
    # one batch entry of accumulators (it is added directly to
    # accumulators_ptr below).
    r = tl.load(r_offsets_ptr + pid_t)

    if is_compressed:
        # Decode the block row m and the N-split index n from r. This
        # presumes rows of accumulators have element stride N (contiguous
        # rows) -- TODO confirm against the launcher's force_contiguous path.
        m = (r // N) // Ms
        n = (r % N) // Ns
        r0 = tl.load(c_indices_ptr + m)
        r1 = tl.load(c_indices_ptr + m + 1)
        # g0 == SPLIT_N * r0 + n * (r1 - r0): q_offsets holds SPLIT_N runs
        # of length nnz per block row, and this selects run n.
        g0 = n * r1 + (SPLIT_N - n) * r0
        nnz = r1 - r0
    else:
        g0 = tl.load(c_indices_ptr + pid_t)
        g1 = tl.load(c_indices_ptr + pid_t + 1)
        nnz = g1 - g0

    # q values are element offsets into `others` (added directly to B_ptr).
    q_ptr = q_offsets_ptr + g0
    acc_block = tl.zeros((TILE_M, TILE_N), dtype=dot_out_dtype)

    if is_compressed:
        # The blocks of block row m are consecutive in `blocks`, starting at
        # block r0, so A_ptr just steps by one block per iteration.
        A_ptr += r0 * blocks_stride_P  # type: ignore[possibly-undefined]
        for _ in range(nnz):
            q = tl.load(q_ptr)
            B = tl.load(B_ptr + q)
            A = tl.load(A_ptr)
            acc_block += tl.dot(
                A, B, out_dtype=dot_out_dtype, allow_tf32=allow_tf32
            )
            A_ptr += blocks_stride_P
            q_ptr += 1
    else:
        # Block indices into `blocks` come from p_offsets.
        p_ptr = p_offsets_ptr + g0
        for _ in range(nnz):
            q = tl.load(q_ptr)
            B = tl.load(B_ptr + q)
            p = tl.load(p_ptr)
            A = tl.load(A_ptr + p * blocks_stride_P)
            p_ptr += 1
            q_ptr += 1
            acc_block += tl.dot(
                A, B, out_dtype=dot_out_dtype, allow_tf32=allow_tf32
            )

    C_ptr = (
        accumulators_ptr
        + r
        + pid_b * accumulators_stride_B
        + (
            rm[:, None] * accumulators_stride_M
            + rn[None, :] * accumulators_stride_N
        )
    )
    tl.store(C_ptr, acc_block.to(accumulators_ptr.dtype.element_ty))
  1967. def _scatter_mm6(
  1968. blocks: torch.Tensor,
  1969. others: torch.Tensor,
  1970. c_indices: torch.Tensor,
  1971. r_offsets: torch.Tensor,
  1972. p_offsets: torch.Tensor,
  1973. q_offsets: torch.Tensor,
  1974. meta: dict,
  1975. accumulators: torch.Tensor,
  1976. force_contiguous: bool = True,
  1977. ):
  1978. SPLIT_N = meta["SPLIT_N"]
  1979. _P, Ms, Ks = blocks.shape
  1980. B, _K, N = others.shape
  1981. B_, _M, N_ = accumulators.shape
  1982. assert N_ == N
  1983. Ns = N // SPLIT_N
  1984. assert B_ == B
  1985. def grid(META):
  1986. return (
  1987. r_offsets.shape[0] * B,
  1988. triton.cdiv(Ms, META["TILE_M"]) * triton.cdiv(Ns, META["TILE_N"]),
  1989. )
  1990. dot_out_dtype = {
  1991. torch.float16: tl.float32,
  1992. torch.bfloat16: tl.float32,
  1993. torch.float32: tl.float64,
  1994. torch.float64: tl.float64,
  1995. }[accumulators.dtype]
  1996. if "allow_tf32" not in meta:
  1997. meta.update(allow_tf32=dot_out_dtype == tl.float32)
  1998. assert c_indices.stride(0) == 1
  1999. assert r_offsets.stride(0) == 1
  2000. assert p_offsets.stride(0) == 1
  2001. assert q_offsets.stride(0) == 1
  2002. # Re non-contiguous tensor arguments. Sometimes triton kernel
  2003. # launches may fail with
  2004. #
  2005. # RuntimeError: Triton Error [CUDA]: an illegal memory access was encountered
  2006. #
  2007. # that appears to be case when the size of a non-contiguous
  2008. # tensor argument is larger than a certain threshold. Could
  2009. # this be related to shared memory or L1 cache size of a GPU
  2010. # card? In anycase, ensuring that tensor arguments are
  2011. # contiguous seems to avoid the above exception. So, in the
  2012. # following we'll always convert tensor arguments to
  2013. # C-contiguous tensors.
  2014. if force_contiguous:
  2015. blocks = blocks.contiguous()
  2016. others = others.contiguous()
  2017. if not accumulators.is_contiguous():
  2018. accumulators_ = accumulators.contiguous()
  2019. else:
  2020. accumulators_ = accumulators
  2021. else:
  2022. accumulators_ = accumulators
  2023. _scatter_mm6_kernel[grid](
  2024. B,
  2025. Ms,
  2026. Ks,
  2027. N,
  2028. blocks,
  2029. blocks.stride(0),
  2030. blocks.stride(1),
  2031. blocks.stride(2),
  2032. others,
  2033. others.stride(0),
  2034. others.stride(1),
  2035. others.stride(2),
  2036. accumulators_,
  2037. accumulators_.stride(0),
  2038. accumulators_.stride(1),
  2039. accumulators_.stride(2),
  2040. c_indices,
  2041. r_offsets,
  2042. p_offsets,
  2043. q_offsets,
  2044. dot_out_dtype=dot_out_dtype,
  2045. **meta,
  2046. )
  2047. if force_contiguous and not accumulators.is_contiguous():
  2048. accumulators.copy_(accumulators_)
@triton.jit
def _bsr_strided_addmm_kernel(
    # values prologue
    values_ptr,
    values_batch_stride,
    values_nnz_stride,
    values_row_block_stride,
    values_col_block_stride,
    # values epilogue
    # crow_indices prologue
    crow_indices_ptr,
    crow_indices_batch_stride,
    crow_indices_stride,
    # crow_indices epilogue
    # col_indices prologue
    col_indices_ptr,
    col_indices_batch_stride,
    col_indices_stride,
    # col_indices epilogue
    # input prologue
    input_ptr,
    input_batch_stride,
    input_tiled_row_stride,
    input_tiled_col_stride,
    input_row_block_stride,
    input_col_block_stride,
    # input epilogue
    # dense prologue
    dense_ptr,
    dense_batch_stride,
    dense_tiled_row_stride,
    dense_tiled_col_stride,
    dense_row_block_stride,
    dense_col_block_stride,
    # dense epilogue
    # left_alpha prologue
    left_alpha_ptr,
    left_alpha_batch_stride,
    left_alpha_tiled_row_stride,
    left_alpha_tiled_col_stride: tl.constexpr,
    left_alpha_row_block_stride,
    left_alpha_col_block_stride: tl.constexpr,
    # left_alpha epilogue
    # right_alpha prologue
    right_alpha_ptr,
    right_alpha_batch_stride,
    right_alpha_tiled_row_stride: tl.constexpr,
    right_alpha_tiled_col_stride,
    right_alpha_row_block_stride: tl.constexpr,
    right_alpha_col_block_stride,
    # right_alpha epilogue
    # output prologue
    output_ptr,
    output_batch_stride,
    output_tiled_row_stride,
    output_tiled_col_stride,
    output_row_block_stride,
    output_col_block_stride,
    # output epilogue
    beta,
    alpha,
    beta_is_one: tl.constexpr,
    beta_is_nonzero: tl.constexpr,
    alpha_is_one: tl.constexpr,
    left_alpha_is_one: tl.constexpr,
    right_alpha_is_one: tl.constexpr,
    BLOCKSIZE_ROW: tl.constexpr,
    BLOCKSIZE_COL: tl.constexpr,
    BLOCKSIZE_INNER: tl.constexpr,
    acc_dtype: tl.constexpr,
    allow_tf32: tl.constexpr,
    GROUP_SIZE_ROW: tl.constexpr,
    SPLIT_N: tl.constexpr,
):
    # Computes one (BLOCKSIZE_ROW, BLOCKSIZE_COL) output tile per program:
    #   tile = alpha * left_alpha * right_alpha
    #          * sum_k values_block_k @ dense_tile(col_indices[k])
    #          + beta * input_tile
    # Here k runs over the stored blocks of the current block row. Values
    # blocks are (BLOCKSIZE_ROW, BLOCKSIZE_INNER) and dense tiles are
    # (BLOCKSIZE_INNER, BLOCKSIZE_COL). The constexpr flags
    # (alpha_is_one, left/right_alpha_is_one, beta_is_nonzero, beta_is_one)
    # drop the corresponding factor or term at compile time.
    # The result is stored (not accumulated) into `output`.
    # Loads/stores are unmasked, so tiles are assumed to be fully in bounds.
    # NOTE(review): SPLIT_N is accepted but not referenced in this body.

    # left/right_alpha tensors are originally (* + 1)-dimensional
    # left_alpha therefore varies only along output rows (zero column
    # strides), and right_alpha only along output columns (zero row strides).
    assert left_alpha_tiled_col_stride == 0
    assert left_alpha_col_block_stride == 0
    assert right_alpha_tiled_row_stride == 0
    assert right_alpha_row_block_stride == 0

    # Grid axis 0 -> block row, axis 1 -> output tile column, axis 2 -> batch.
    # The 2D (row, col) ids are remapped with tl.swizzle2d in groups of
    # GROUP_SIZE_ROW rows (presumably for better cache reuse).
    batch_pid = tl.program_id(axis=2)
    row_block_pid = tl.program_id(axis=0)
    col_block_pid = tl.program_id(axis=1)
    n_block_rows = tl.num_programs(axis=0)
    n_block_cols = tl.num_programs(axis=1)

    row_block_pid, col_block_pid = tl.swizzle2d(
        row_block_pid, col_block_pid, n_block_rows, n_block_cols, GROUP_SIZE_ROW
    )

    crow_indices_offset_ptr = (
        crow_indices_ptr
        + crow_indices_batch_stride * batch_pid
        + crow_indices_stride * row_block_pid
    )
    nnz_offset = tl.load(crow_indices_offset_ptr)
    nnz_offset_next = tl.load(crow_indices_offset_ptr + crow_indices_stride)

    # Compute nnz for the row with number row_block_pid.
    # A zero count leaves the accumulator at zero, so the tile reduces to
    # beta * input (or zeros when beta is zero).
    row_nnz = nnz_offset_next - nnz_offset

    row_block_arange = tl.arange(0, BLOCKSIZE_ROW)
    inner_block_arange = tl.arange(0, BLOCKSIZE_INNER)
    col_block_arange = tl.arange(0, BLOCKSIZE_COL)

    # Pointers are set to the first block of the current row.
    values_block_ptrs = (
        values_ptr
        + values_batch_stride * batch_pid
        + values_nnz_stride * nnz_offset
        + values_row_block_stride * row_block_arange[:, None]
        + values_col_block_stride * inner_block_arange[None, :]
    )

    # NOTE: dense is advanced into all dimensions but the tiled row one.
    # That will be advanced in the loop according to values in col_indices.
    dense_block_ptrs = (
        dense_ptr
        + dense_batch_stride * batch_pid
        + dense_tiled_col_stride * col_block_pid
        + dense_row_block_stride * inner_block_arange[:, None]
        + dense_col_block_stride * col_block_arange[None, :]
    )

    # Pointers are set to exact write-to locations
    output_ptrs = (
        output_ptr
        + output_batch_stride * batch_pid
        + output_tiled_row_stride * row_block_pid
        + output_tiled_col_stride * col_block_pid
        + output_row_block_stride * row_block_arange[:, None]
        + output_col_block_stride * col_block_arange[None, :]
    )

    # Set pointer to the first nonzero element in the current row
    col_index_nnz_ptr = (
        col_indices_ptr
        + col_indices_batch_stride * batch_pid
        + col_indices_stride * nnz_offset
    )

    output_acc_block = tl.zeros((BLOCKSIZE_ROW, BLOCKSIZE_COL), dtype=acc_dtype)
    for _ in range(row_nnz):
        values_block = tl.load(values_block_ptrs)

        # find which row of dense needs to get loaded
        # for multiplication with values_block.
        dense_row_idx = tl.load(col_index_nnz_ptr)
        dense_block = tl.load(
            dense_block_ptrs + dense_tiled_row_stride * dense_row_idx
        )

        # do block mm
        output_acc_block += tl.dot(
            values_block, dense_block, allow_tf32=allow_tf32, out_dtype=acc_dtype
        )

        # move val/col_index ptrs to the next block in the row
        values_block_ptrs += values_nnz_stride
        col_index_nnz_ptr += col_indices_stride

    if not alpha_is_one:
        output_acc_block *= alpha

    if not left_alpha_is_one:
        # Per-row scaling factors (column strides are zero, see asserts).
        left_alpha_ptrs = (
            left_alpha_ptr
            + left_alpha_batch_stride * batch_pid
            + left_alpha_tiled_row_stride * row_block_pid
            + left_alpha_tiled_col_stride * col_block_pid
            + left_alpha_row_block_stride * row_block_arange[:, None]
            + left_alpha_col_block_stride * col_block_arange[None, :]
        )
        output_acc_block *= tl.load(left_alpha_ptrs)

    if not right_alpha_is_one:
        # Per-column scaling factors (row strides are zero, see asserts).
        right_alpha_ptrs = (
            right_alpha_ptr
            + right_alpha_batch_stride * batch_pid
            + right_alpha_tiled_row_stride * row_block_pid
            + right_alpha_tiled_col_stride * col_block_pid
            + right_alpha_row_block_stride * row_block_arange[:, None]
            + right_alpha_col_block_stride * col_block_arange[None, :]
        )
        output_acc_block *= tl.load(right_alpha_ptrs)

    if beta_is_nonzero:
        # `input` is only read when its contribution can be nonzero.
        input_ptrs = (
            input_ptr
            + input_batch_stride * batch_pid
            + input_tiled_row_stride * row_block_pid
            + input_tiled_col_stride * col_block_pid
            + input_row_block_stride * row_block_arange[:, None]
            + input_col_block_stride * col_block_arange[None, :]
        )
        if beta_is_one:
            output_acc_block += tl.load(input_ptrs)
        else:
            output_acc_block += beta * tl.load(input_ptrs)

    # write back the result
    tl.store(output_ptrs, output_acc_block.to(output_ptr.dtype.element_ty))
else:
    # This branch runs when the enclosing `if` condition is false. That
    # condition is outside this view; presumably it is Triton availability
    # -- confirm. The Triton-backed entry points are bound to None so the
    # module-level names still exist, presumably so callers can test them
    # against None.
    bsr_softmax = None  # type: ignore[assignment]
    bsr_dense_mm = None  # type: ignore[assignment]
    sampled_addmm = None  # type: ignore[assignment]
    _scaled_dot_product_attention = None  # type: ignore[assignment]
    _scatter_mm2 = None  # type: ignore[assignment]
    _scatter_mm6 = None  # type: ignore[assignment]
    _bsr_strided_addmm_kernel = None  # type: ignore[assignment]