utils.py 120 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156115711581159116011611162116311641165116611671168116911701171117211731174117511761177117811791180118111821183118411851186118711881189119011911192119311941195119611971198119912001201120212031204120512061207120812091210121112121213121412151216121712181219122012211222122312241225122612271228122912301231123212331234123512361237123812391240124112421243124412451246124712481249125012511252125312541255125612571258125912601261126212631264126512661267126812691270127112721273127412751276127
71278127912801281128212831284128512861287128812891290129112921293129412951296129712981299130013011302130313041305130613071308130913101311131213131314131513161317131813191320132113221323132413251326132713281329133013311332133313341335133613371338133913401341134213431344134513461347134813491350135113521353135413551356135713581359136013611362136313641365136613671368136913701371137213731374137513761377137813791380138113821383138413851386138713881389139013911392139313941395139613971398139914001401140214031404140514061407140814091410141114121413141414151416141714181419142014211422142314241425142614271428142914301431143214331434143514361437143814391440144114421443144414451446144714481449145014511452145314541455145614571458145914601461146214631464146514661467146814691470147114721473147414751476147714781479148014811482148314841485148614871488148914901491149214931494149514961497149814991500150115021503150415051506150715081509151015111512151315141515151615171518151915201521152215231524152515261527152815291530153115321533153415351536153715381539154015411542154315441545154615471548154915501551155215531554155515561557155815591560156115621563156415651566156715681569157015711572157315741575157615771578157915801581158215831584158515861587158815891590159115921593159415951596159715981599160016011602160316041605160616071608160916101611161216131614161516161617161816191620162116221623162416251626162716281629163016311632163316341635163616371638163916401641164216431644164516461647164816491650165116521653165416551656165716581659166016611662166316641665166616671668166916701671167216731674167516761677167816791680168116821683168416851686168716881689169016911692169316941695169616971698169917001701170217031704170517061707170817091710171117121713171417151716171717181719172017211722172317241725172617271728172917301731173217331734173517361737173817391740174117421743174417451746174717481749175017511752175317541755175617571758175917601761176217631764176517661767176817691770177117721773177417751776177
71778177917801781178217831784178517861787178817891790179117921793179417951796179717981799180018011802180318041805180618071808180918101811181218131814181518161817181818191820182118221823182418251826182718281829183018311832183318341835183618371838183918401841184218431844184518461847184818491850185118521853185418551856185718581859186018611862186318641865186618671868186918701871187218731874187518761877187818791880188118821883188418851886188718881889189018911892189318941895189618971898189919001901190219031904190519061907190819091910191119121913191419151916191719181919192019211922192319241925192619271928192919301931193219331934193519361937193819391940194119421943194419451946194719481949195019511952195319541955195619571958195919601961196219631964196519661967196819691970197119721973197419751976197719781979198019811982198319841985198619871988198919901991199219931994199519961997199819992000200120022003200420052006200720082009201020112012201320142015201620172018201920202021202220232024202520262027202820292030203120322033203420352036203720382039204020412042204320442045204620472048204920502051205220532054205520562057205820592060206120622063206420652066206720682069207020712072207320742075207620772078207920802081208220832084208520862087208820892090209120922093209420952096209720982099210021012102210321042105210621072108210921102111211221132114211521162117211821192120212121222123212421252126212721282129213021312132213321342135213621372138213921402141214221432144214521462147214821492150215121522153215421552156215721582159216021612162216321642165216621672168216921702171217221732174217521762177217821792180218121822183218421852186218721882189219021912192219321942195219621972198219922002201220222032204220522062207220822092210221122122213221422152216221722182219222022212222222322242225222622272228222922302231223222332234223522362237223822392240224122422243224422452246224722482249225022512252225322542255225622572258225922602261226222632264226522662267226822692270227122722273227422752276227
72278227922802281228222832284228522862287228822892290229122922293229422952296229722982299230023012302230323042305230623072308230923102311231223132314231523162317231823192320232123222323232423252326232723282329233023312332233323342335233623372338233923402341234223432344234523462347234823492350235123522353235423552356235723582359236023612362236323642365236623672368236923702371237223732374237523762377237823792380238123822383238423852386238723882389239023912392239323942395239623972398239924002401240224032404240524062407240824092410241124122413241424152416241724182419242024212422242324242425242624272428242924302431243224332434243524362437243824392440244124422443244424452446244724482449245024512452245324542455245624572458245924602461246224632464246524662467246824692470247124722473247424752476247724782479248024812482248324842485248624872488248924902491249224932494249524962497249824992500250125022503250425052506250725082509251025112512251325142515251625172518251925202521252225232524252525262527252825292530253125322533253425352536253725382539254025412542254325442545254625472548254925502551255225532554255525562557255825592560256125622563256425652566256725682569257025712572257325742575257625772578257925802581258225832584258525862587258825892590259125922593259425952596259725982599260026012602260326042605260626072608260926102611261226132614261526162617261826192620262126222623262426252626262726282629263026312632263326342635263626372638263926402641264226432644264526462647264826492650265126522653265426552656265726582659266026612662266326642665266626672668266926702671267226732674267526762677267826792680268126822683268426852686268726882689269026912692269326942695269626972698269927002701270227032704270527062707270827092710271127122713271427152716271727182719272027212722272327242725272627272728272927302731273227332734273527362737273827392740274127422743274427452746274727482749275027512752275327542755275627572758275927602761276227632764276527662767276827692770277127722773277427752776277
72778277927802781278227832784278527862787278827892790279127922793279427952796279727982799280028012802280328042805280628072808280928102811281228132814281528162817281828192820282128222823282428252826282728282829283028312832283328342835283628372838283928402841284228432844284528462847284828492850285128522853285428552856285728582859286028612862286328642865286628672868286928702871287228732874287528762877287828792880288128822883288428852886288728882889289028912892289328942895289628972898289929002901290229032904290529062907290829092910291129122913291429152916291729182919292029212922292329242925292629272928292929302931293229332934293529362937293829392940294129422943294429452946294729482949295029512952295329542955295629572958295929602961296229632964296529662967296829692970297129722973297429752976297729782979298029812982298329842985298629872988298929902991299229932994299529962997299829993000300130023003300430053006300730083009301030113012301330143015301630173018301930203021302230233024302530263027302830293030303130323033303430353036303730383039304030413042304330443045304630473048304930503051305230533054305530563057305830593060306130623063306430653066306730683069307030713072307330743075307630773078307930803081308230833084308530863087308830893090309130923093309430953096309730983099310031013102310331043105310631073108310931103111311231133114311531163117311831193120312131223123312431253126312731283129313031313132313331343135313631373138313931403141314231433144314531463147314831493150315131523153315431553156315731583159316031613162316331643165316631673168316931703171317231733174317531763177317831793180318131823183318431853186318731883189319031913192319331943195319631973198319932003201320232033204320532063207320832093210321132123213321432153216321732183219322032213222322332243225322632273228322932303231323232333234323532363237323832393240324132423243324432453246324732483249325032513252325332543255325632573258325932603261326232633264326532663267326832693270327132723273327432753276327
732783279328032813282328332843285328632873288328932903291329232933294329532963297329832993300330133023303330433053306330733083309331033113312331333143315331633173318331933203321332233233324332533263327332833293330333133323333333433353336333733383339334033413342334333443345334633473348334933503351335233533354335533563357335833593360336133623363336433653366336733683369337033713372337333743375337633773378337933803381338233833384338533863387338833893390339133923393339433953396339733983399340034013402340334043405340634073408340934103411341234133414341534163417341834193420342134223423342434253426342734283429343034313432343334343435343634373438343934403441344234433444344534463447344834493450345134523453345434553456345734583459346034613462346334643465346634673468346934703471347234733474347534763477347834793480348134823483348434853486348734883489349034913492349334943495349634973498349935003501350235033504350535063507350835093510351135123513351435153516351735183519352035213522352335243525352635273528352935303531353235333534353535363537353835393540354135423543354435453546354735483549355035513552355335543555355635573558355935603561356235633564356535663567356835693570357135723573357435753576357735783579358035813582358335843585358635873588358935903591359235933594359535963597359835993600360136023603360436053606360736083609361036113612361336143615361636173618361936203621362236233624362536263627362836293630363136323633363436353636363736383639364036413642364336443645364636473648364936503651365236533654365536563657365836593660366136623663366436653666366736683669367036713672367336743675367636773678367936803681368236833684368536863687368836893690369136923693369436953696369736983699370037013702370337043705370637073708370937103711371237133714371537163717371837193720372137223723372437253726372737283729373037313732373337343735373637373738373937403741374237433744374537463747374837493750375137523753375437553756375737583759376037613762
  1. from __future__ import annotations
  2. import collections
  3. import contextlib
  4. import dataclasses
  5. import enum
  6. import functools
  7. import importlib
  8. import inspect
  9. import io
  10. import itertools
  11. import logging
  12. import math
  13. import operator
  14. import os
  15. import platform
  16. import re
  17. import shutil
  18. import statistics
  19. import sys
  20. import sysconfig
  21. import tempfile
  22. import textwrap
  23. import time
  24. import unittest
  25. from collections.abc import (
  26. Collection,
  27. Generator,
  28. Iterator,
  29. Mapping,
  30. MutableMapping,
  31. MutableSet,
  32. )
  33. from datetime import datetime
  34. from io import StringIO
  35. from typing import (
  36. Any,
  37. Callable,
  38. cast,
  39. Generic,
  40. Literal,
  41. NamedTuple,
  42. Optional,
  43. Protocol,
  44. TYPE_CHECKING,
  45. TypeVar,
  46. Union,
  47. )
  48. from typing_extensions import (
  49. Concatenate,
  50. dataclass_transform,
  51. ParamSpec,
  52. Self,
  53. TypeAlias,
  54. TypeGuard,
  55. )
  56. from unittest import mock
  57. import sympy
  58. import torch
  59. import torch.utils._pytree as pytree
  60. from torch._inductor.analysis.device_info import datasheet_tops
  61. from torch._inductor.runtime.hints import DeviceProperties
  62. from torch.utils._dtype_abbrs import dtype_abbrs
  63. from torch.utils._ordered_set import OrderedSet
  64. from torch.utils._pytree import tree_flatten, tree_map_only
  65. OPTIMUS_EXCLUDE_POST_GRAD = [
  66. "activation_quantization_aten_pass",
  67. "inductor_autotune_lookup_table",
  68. ]
  69. from torch.fx.experimental.symbolic_shapes import (
  70. free_symbols,
  71. free_unbacked_symbols,
  72. IterateExprs,
  73. ShapeEnv,
  74. )
  75. if TYPE_CHECKING:
  76. from collections.abc import Iterable, Sequence, ValuesView
  77. from torch import SymBool, SymFloat, SymInt
  78. from torch._prims_common import ELEMENTWISE_TYPE_PROMOTION_KIND
  79. from torch.fx import GraphModule
  80. from torch.fx.node import Node
  81. from .codegen.common import WorkspaceArg
  82. from .codegen.wrapper import PythonWrapperCodegen
  83. from .graph import GraphLowering
  84. from .ir import Buffer, ExternKernel, IRNode, Layout, Operation, ReinterpretView
  85. from .output_code import CompiledFxGraph
  86. from .scheduler import BaseSchedulerNode, SchedulerBuffer
  87. GPU_TYPES = ["cuda", "mps", "xpu", "mtia"]
  88. T = TypeVar("T")
  89. # defines here before import torch._dynamo is for avoiding circular import
  90. # when get_gpu_type is imported from dynamo
  91. @functools.cache
  92. def get_gpu_type() -> str:
  93. avail_gpus = [x for x in GPU_TYPES if getattr(torch, x).is_available()]
  94. assert len(avail_gpus) <= 1
  95. gpu_type = "cuda" if len(avail_gpus) == 0 else avail_gpus.pop()
  96. return gpu_type
  97. from torch._dynamo.device_interface import get_interface_for_device
  98. from torch._dynamo.utils import detect_fake_mode
  99. from torch.autograd import DeviceType
  100. from torch.autograd.profiler_util import EventList
  101. from torch.fx.passes.graph_transform_observer import GraphTransformObserver
  102. from torch.fx.passes.shape_prop import ShapeProp
  103. from torch.utils._sympy.functions import (
  104. CeilDiv,
  105. CleanDiv,
  106. FloorDiv,
  107. Identity,
  108. ModularIndexing,
  109. )
  110. from torch.utils._sympy.symbol import make_symbol, SymT
  111. from torch.utils._sympy.value_ranges import bound_sympy, ValueRanges
  112. from . import config
  113. from .runtime.runtime_utils import ceildiv as runtime_ceildiv
  114. _IS_WINDOWS = sys.platform == "win32"
  115. log = logging.getLogger(__name__)
  116. perf_hint_log = torch._logging.getArtifactLogger(__name__, "perf_hints")
  117. _T = TypeVar("_T")
  118. VarRanges = dict[sympy.Expr, sympy.Expr]
  119. InputType = Optional[Union[torch.Tensor, int, torch.SymInt]]
  120. GPU_KERNEL_BIN_EXTS = {"cuda": ".cubin", "xpu": ".spv"}
  121. GPU_ALIGN_BYTES = 16
  122. ALIGNMENT = 16
  123. TMA_ALIGNMENT = 16
  124. TMA_DESCRIPTOR_SIZE = 128
  125. ALIGN_BYTES = 64
  126. assert (ALIGN_BYTES & (ALIGN_BYTES - 1)) == 0 and ALIGN_BYTES >= 8, "must be power of 2"
  127. def _align(nbytes: int) -> int:
  128. """Round up to the nearest multiple of ALIGN_BYTES"""
  129. return (nbytes + ALIGN_BYTES - 1) & -ALIGN_BYTES
  130. def _is_aligned(v: sympy.Expr) -> bool:
  131. """v can be statically proven to be a multiple of ALIGN_BYTES"""
  132. if isinstance(v, (sympy.Add, sympy.Max)):
  133. return all(map(_is_aligned, v.args))
  134. return isinstance(v, align) or sympy.gcd(v, ALIGN_BYTES) == ALIGN_BYTES
  135. class align(sympy.Function):
  136. """Symbolically round up to the nearest multiple of ALIGN_BYTES"""
  137. nargs = (1,)
  138. is_integer = True
  139. @classmethod
  140. def eval(cls, value: sympy.Expr) -> Optional[sympy.Expr]:
  141. if isinstance(value, (int, sympy.Integer)):
  142. return _align(int(value))
  143. if _is_aligned(value):
  144. return value
  145. @dataclasses.dataclass(frozen=True)
  146. class GraphPartitionMap:
  147. """
  148. Mapping from the partition info (e.g., input/output) to the graph info
  149. """
  150. # a unique id of graph partition
  151. id: int
  152. # map partition input/output indices to graph input/output indices. None indicates
  153. # a partition input/output is not a graph input/output.
  154. input_index_mapping: list[Optional[int]]
  155. output_index_mapping: list[Optional[int]]
  156. # name of constants read/written by the graph partition
  157. constant_names: list[str]
  158. def fp8_bench(fn: Callable[[], Any], warmup: int = 25, rep: int = 100) -> float:
  159. """
  160. Returns benchmark results by examining torch profiler events.
  161. This could be more accurate as it doesn't count CPU side overhead.
  162. However, this also requires manually excluding irrelevant event, e.g.
  163. vectorized_elementwise_kernel which is used to fill L2 cache,
  164. various CUDA events, etc, so could also be fragile.
  165. """
  166. fn()
  167. torch.cuda.synchronize()
  168. cache = torch.empty(int(256e6 // 4), dtype=torch.float16, device="cuda")
  169. # Estimate the runtime of the function
  170. start_event = torch.cuda.Event(enable_timing=True)
  171. end_event = torch.cuda.Event(enable_timing=True)
  172. start_event.record()
  173. for _ in range(5):
  174. cache.zero_()
  175. fn()
  176. end_event.record()
  177. torch.cuda.synchronize()
  178. estimate_ms = start_event.elapsed_time(end_event) / 5
  179. # compute number of warmup and repeat
  180. n_warmup = max(1, int(warmup / estimate_ms))
  181. n_repeat = max(1, int(rep / estimate_ms))
  182. # Warm-up
  183. for _ in range(n_warmup):
  184. fn()
  185. start_event = [torch.cuda.Event(enable_timing=True) for _ in range(n_repeat)]
  186. end_event = [torch.cuda.Event(enable_timing=True) for _ in range(n_repeat)]
  187. with torch.profiler.profile(
  188. activities=[
  189. torch.profiler.ProfilerActivity.CUDA,
  190. ]
  191. ) as p:
  192. torch.cuda.synchronize()
  193. for i in range(n_repeat):
  194. cache.zero_()
  195. start_event[i].record()
  196. with torch.cuda.nvtx.range("RunCudaModule"):
  197. fn()
  198. end_event[i].record()
  199. torch.cuda.synchronize()
  200. times = torch.tensor(
  201. [s.elapsed_time(e) for s, e in zip(start_event, end_event)]
  202. )
  203. res = torch.mean(times).item()
  204. log.debug("raw events")
  205. log.debug(p.key_averages().table(sort_by="self_device_time_total", row_limit=-1))
  206. filtered_events = EventList(
  207. [
  208. event
  209. for event in p.events()
  210. if (
  211. event.device_type == DeviceType.CUDA
  212. and re.match(r"fused_abs_max_\d", event.name) is not None
  213. )
  214. ]
  215. )
  216. if filtered_events:
  217. res -= (
  218. statistics.mean(event.device_time_total for event in filtered_events)
  219. / 1000.0
  220. )
  221. log.debug("profiling results: %s ms", res)
  222. return res
  223. def do_bench_using_profiling(
  224. fn: Callable[[], Any], warmup: int = 25, rep: int = 100
  225. ) -> float:
  226. """
  227. Returns benchmark results by examining torch profiler events.
  228. This could be more accurate as it doesn't count CPU side overhead.
  229. However, this also requires manually excluding irrelevant event, e.g.
  230. vectorized_elementwise_kernel which is used to fill L2 cache,
  231. various CUDA events, etc, so could also be fragile.
  232. """
  233. fn()
  234. torch.cuda.synchronize()
  235. cache = torch.empty(int(256e6 // 4), dtype=torch.int, device="cuda")
  236. # Estimate the runtime of the function
  237. start_event = torch.cuda.Event(enable_timing=True)
  238. end_event = torch.cuda.Event(enable_timing=True)
  239. start_event.record()
  240. for _ in range(5):
  241. cache.zero_()
  242. fn()
  243. end_event.record()
  244. torch.cuda.synchronize()
  245. estimate_ms = start_event.elapsed_time(end_event) / 5
  246. # compute number of warmup and repeat
  247. n_warmup = max(1, int(warmup / estimate_ms))
  248. n_repeat = max(1, int(rep / estimate_ms))
  249. # Warm-up
  250. for _ in range(n_warmup):
  251. fn()
  252. torch.cuda.synchronize()
  253. with torch.profiler.profile(
  254. activities=[
  255. torch.profiler.ProfilerActivity.CUDA,
  256. ]
  257. ) as p:
  258. # Benchmark
  259. for i in range(n_repeat):
  260. # we clear the L2 cache before each run
  261. cache.zero_()
  262. # record time of `fn`
  263. fn()
  264. # Record clocks
  265. torch.cuda.synchronize()
  266. log.debug("raw events")
  267. log.debug(p.key_averages().table(sort_by="self_device_time_total", row_limit=-1))
  268. filtered_events = EventList(
  269. [
  270. event
  271. for event in p.events()
  272. if event.device_type == DeviceType.CUDA and event.name != "Context Sync"
  273. ]
  274. )
  275. if len(filtered_events) % n_repeat != 0:
  276. raise RuntimeError(
  277. "Failed to divide all profiling events into #repeat groups. "
  278. "#CUDA events: %d, #repeats: %s",
  279. len(filtered_events),
  280. n_repeat,
  281. )
  282. num_event_per_group = len(filtered_events) / n_repeat
  283. actual_events = EventList(
  284. [
  285. event
  286. for i, event in enumerate(filtered_events)
  287. if i % num_event_per_group != 0
  288. ]
  289. )
  290. actual_events._build_tree()
  291. actual_events = actual_events.key_averages()
  292. log.debug("profiling time breakdown")
  293. log.debug(actual_events.table(row_limit=-1))
  294. res = sum(event.device_time_total for event in actual_events) / 1000.0 / n_repeat
  295. log.debug("profiling results: %s ms", res)
  296. return res
  297. @functools.cache
  298. def has_torchvision_roi_align() -> bool:
  299. try:
  300. from torchvision.ops import roi_align # noqa: F401
  301. torch._C._dispatch_has_kernel_for_dispatch_key("torchvision::nms", "Meta")
  302. return roi_align is not None and hasattr(
  303. getattr(torch.ops, "torchvision", None), "roi_align"
  304. )
  305. except ImportError:
  306. return False
  307. except RuntimeError as e:
  308. assert "torchvision::nms does not exist" in str(e)
  309. return False
  310. def decode_device(device: Union[Optional[torch.device], str]) -> torch.device:
  311. if device is None:
  312. return torch.tensor(0.0).device # default device
  313. if isinstance(device, str):
  314. device = torch.device(device)
  315. if device.type not in ("cpu", "meta") and device.index is None:
  316. device_interface = get_interface_for_device(device.type)
  317. return torch.device(device.type, index=device_interface.Worker.current_device())
  318. return device
  319. def sympy_product(it: Iterable[sympy.Expr]) -> sympy.Expr:
  320. return functools.reduce(operator.mul, it, sympy.S.One)
  321. def sympy_dot(seq1: Sequence[sympy.Expr], seq2: Sequence[sympy.Expr]) -> sympy.Expr:
  322. assert len(seq1) == len(seq2)
  323. return sympy.expand(sum(a * b for a, b in zip(seq1, seq2)))
  324. def unique(it: Iterable[_T]) -> ValuesView[_T]:
  325. return {id(x): x for x in it}.values()
  326. def ceildiv(
  327. number: Union[int, sympy.Expr], denom: Union[int, sympy.Expr]
  328. ) -> Union[int, sympy.Expr]:
  329. if isinstance(number, sympy.Expr) or isinstance(denom, sympy.Expr):
  330. return CeilDiv(sympy.sympify(number), sympy.sympify(denom))
  331. # TODO: There is a bug in a call to this function, to repro:
  332. # python benchmarks/dynamo/huggingface.py --inductor -d cuda --accuracy
  333. # --amp --only YituTechConvBert --dynamic-shapes
  334. assert isinstance(number, int) and isinstance(denom, int), (
  335. f"{number}: {type(number)}, {denom}: {type(denom)}"
  336. )
  337. return runtime_ceildiv(number, denom)
  338. def _type_of(key: Optional[torch.dtype]) -> str:
  339. # Use the function here to get rid of dependencies on the Triton during the codegen.
  340. # Refer to Triton implementation here:
  341. # https://github.com/triton-lang/triton/blob/98b5945d2aef679e00ebca8e07c35c3658ec76de/python/triton/runtime/jit.py#L238
  342. # `None` is nullptr. Implicitly convert to *i8.
  343. if key is None:
  344. return "*i8"
  345. dtype_str = str(key).split(".")[-1]
  346. tys = {
  347. "bool": "i1",
  348. "float8e4nv": "fp8e4nv",
  349. "float8e5": "fp8e5",
  350. "float8e4b15": "fp8e4b15",
  351. "float8e4b15x4": "fp8e4b15x4",
  352. "float8_e4m3fn": "fp8e4nv",
  353. "float8_e5m2": "fp8e5",
  354. # TODO: remove when support is added in triton
  355. # https://github.com/triton-lang/triton/issues/6054
  356. "float8_e8m0fnu": "u8",
  357. "float4_e2m1fn_x2": "u8",
  358. "float16": "fp16",
  359. "bfloat16": "bf16",
  360. "float32": "fp32",
  361. "float64": "fp64",
  362. "int8": "i8",
  363. "int16": "i16",
  364. "int32": "i32",
  365. "int64": "i64",
  366. "uint8": "u8",
  367. "uint16": "u16",
  368. "uint32": "u32",
  369. "uint64": "u64",
  370. }
  371. # reinterpret can create triton type
  372. tys.update({v: v for v in list(tys.values())})
  373. return key if isinstance(key, str) else f"*{tys[dtype_str]}"
  374. def convert_shape_to_inductor(
  375. lst: Iterable[Union[int, torch.SymInt]],
  376. ) -> list[sympy.Expr]:
  377. """
  378. Gets the shape and stride of a tensor. For non-symbolic tensors, this is
  379. trivial. But for symbolic tensors, we need to map from SymIntNode into
  380. sympy.Expr.
  381. """
  382. return [sympy.sympify(i) for i in lst]
  383. def convert_to_symint(i: Union[int, sympy.Expr]) -> Union[int, torch.SymInt]:
  384. """
  385. Like convert_shape_to_symint, but operates on a single expression.
  386. """
  387. from .virtualized import V
  388. return (
  389. i
  390. if isinstance(i, int)
  391. else (
  392. int(i)
  393. if isinstance(i, sympy.Integer)
  394. else V.graph.sizevars.shape_env.create_symintnode(i, hint=None)
  395. )
  396. )
  397. def convert_shape_to_symint(
  398. lst: Iterable[Union[int, sympy.Expr]],
  399. ) -> list[Union[int, torch.SymInt]]:
  400. """
  401. Takes a list of shapes from Inductor and converts them into symints (or just
  402. ints if all shapes are static).
  403. """
  404. return [convert_to_symint(i) for i in lst]
  405. def is_view(op: torch._ops.OpOverload) -> bool:
  406. """
  407. Does this op overload have aliasing
  408. """
  409. return any(a.alias_info is not None for a in op._schema.arguments)
  410. def is_pointwise_use(
  411. use: Node,
  412. is_pointwise_fn: Callable[[torch._ops.OpOverload], bool] = lambda _: False,
  413. ) -> bool:
  414. """
  415. Do all uses of this op have torch.Tag.pointwise or return True for optional `is_pointwise_fn`
  416. Uses in views ops will follow the views uses
  417. """
  418. if not use.op == "call_function":
  419. return False
  420. if not (
  421. isinstance(use.target, torch._ops.OpOverload) or use.target is operator.getitem
  422. ):
  423. return False
  424. target = cast(torch._ops.OpOverload, use.target)
  425. if target is operator.getitem or is_view(target):
  426. return all(is_pointwise_use(u, is_pointwise_fn) for u in use.users)
  427. return torch.Tag.pointwise in target.tags or is_pointwise_fn(target)
  428. def gen_gm_and_inputs(
  429. target: Any, args: list[Any], kwargs: dict[str, Any]
  430. ) -> tuple[GraphModule, list[torch.Tensor]]:
  431. g = torch.fx.Graph()
  432. graph_args: list[torch.Tensor] = []
  433. def add_tensor_arg(arg: torch.Tensor) -> Node:
  434. graph_args.append(arg)
  435. return g.placeholder(f"arg{len(graph_args)}")
  436. node = g.call_function(
  437. target, *tree_map_only(torch.Tensor, add_tensor_arg, (args, kwargs))
  438. )
  439. if (
  440. len(target._schema.returns) == 1
  441. and str(target._schema.returns[0].type) == "Tensor"
  442. ):
  443. node = (node,) # type: ignore[assignment]
  444. g.output(node)
  445. gm = torch.fx.GraphModule({}, g)
  446. return gm, graph_args
  447. def synchronize(device: str = "cuda") -> None:
  448. if device == "cpu":
  449. return
  450. device_interface = get_interface_for_device(device)
  451. if device_interface.is_available():
  452. device_interface.synchronize()
  453. def timed(
  454. model: Callable[..., Any],
  455. example_inputs: Sequence[Any],
  456. times: int = 1,
  457. device: str = "cuda",
  458. ) -> float:
  459. synchronize(device)
  460. torch.manual_seed(1337)
  461. t0 = time.perf_counter()
  462. for _ in range(times):
  463. result = model(*example_inputs)
  464. synchronize(device)
  465. t1 = time.perf_counter()
  466. # GC the result after timing
  467. assert result is not None # type: ignore[possibly-undefined]
  468. return t1 - t0
  469. def print_performance(
  470. model: Callable[..., Any],
  471. example_inputs: Sequence[Any] = (),
  472. times: int = 10,
  473. repeat: int = 10,
  474. baseline: float = 1.0,
  475. device: str = "cuda",
  476. ) -> float:
  477. timings = torch.tensor(
  478. [timed(model, example_inputs, times, device) for _ in range(repeat)]
  479. )
  480. took = torch.median(timings) / times
  481. print(f"{took / baseline:.6f}")
  482. return took.item()
  483. def precompute_method(obj: Any, method: str) -> None:
  484. """Replace obj.method() with a new method that returns a precomputed constant."""
  485. result = getattr(obj, method)()
  486. setattr(obj, method, lambda: result)
  487. def precompute_methods(obj: Any, methods: list[str]) -> None:
  488. """Replace methods with new methods that returns a precomputed constants."""
  489. for method in methods:
  490. precompute_method(obj, method)
  491. def cmp(a: int, b: int) -> int:
  492. return int(a > b) - int(a < b)
  493. def pad_listlike(x: Union[int, Sequence[int]], size: int) -> Sequence[int]:
  494. if isinstance(x, int):
  495. return [x] * size
  496. if len(x) == 1:
  497. return type(x)([x[0]]) * size # type: ignore[call-arg, operator, return-value]
  498. return x
  499. # Used to ensure that iterating over a set is deterministic
  500. def tuple_sorted(x: tuple[_T, ...]) -> list[_T]:
  501. if len(x) == 0:
  502. return []
  503. def sort_func(elem: _T) -> str:
  504. if isinstance(elem, str):
  505. return elem
  506. from .scheduler import BaseSchedulerNode
  507. assert isinstance(elem, BaseSchedulerNode)
  508. return elem.get_name()
  509. return sorted(x, key=sort_func)
  510. P = ParamSpec("P")
  511. RV = TypeVar("RV", covariant=True)
  512. FN_TYPE = Callable[Concatenate[Any, P], RV]
  513. class CachedMethod(Protocol, Generic[P, RV]):
  514. @staticmethod
  515. def clear_cache(cache: Any) -> None: ...
  516. def __call__(self, *args: P.args, **kwargs: P.kwargs) -> RV: ...
  517. # See https://github.com/python/mypy/issues/13222#issuecomment-1193073470 to understand the type signature
  518. def cache_on_self(fn: Callable[Concatenate[Any, P], RV]) -> CachedMethod[P, RV]:
  519. name = fn.__name__
  520. key = f"__{name}_cache"
  521. # wrapper is likely on the hot path, compile a specialized version of it
  522. ctx = {"fn": fn}
  523. exec(
  524. f"""\
  525. def {name}_cache_on_self(self):
  526. try:
  527. return self.{key}
  528. except AttributeError:
  529. pass
  530. rv = fn(self)
  531. object.__setattr__(self, "{key}", rv)
  532. return rv
  533. """.lstrip(),
  534. ctx,
  535. )
  536. wrapper = functools.wraps(fn)(ctx[f"{name}_cache_on_self"])
  537. def clear_cache(self: Any) -> None:
  538. if hasattr(self, key):
  539. delattr(self, key)
  540. wrapper.clear_cache = clear_cache # type: ignore[attr-defined]
  541. return wrapper # type: ignore[return-value]
  542. def cache_property_on_self(fn: Callable[P, RV]) -> CachedMethod[P, RV]:
  543. """
  544. Variant of cache_on_self for properties. The only difference is the type signature.
  545. """
  546. # pyrefly: ignore [bad-argument-type]
  547. return cache_on_self(fn)
  548. def cache_on_self_and_args(
  549. class_name: str,
  550. ) -> Callable[[FN_TYPE[P, RV]], FN_TYPE[P, RV]]:
  551. # include both class_name and fn_name in the key to support `super().fn(self, **args, **kwargs)` calls.
  552. def wrapper(
  553. fn: FN_TYPE[P, RV],
  554. ) -> FN_TYPE[P, RV]:
  555. key = f"__{class_name}_{fn.__name__}_cache"
  556. # wrapper is likely on the hot path, compile a specialized version of it
  557. ctx = {"fn": fn}
  558. exec(
  559. f"""\
  560. def inner(self: Any, *args: P.args, **kwargs: P.kwargs) -> RV:
  561. args_kwargs = (args, tuple(sorted(kwargs.items())))
  562. if not hasattr(self, "{key}"):
  563. object.__setattr__(self, "{key}", {{}})
  564. cache = self.{key}
  565. try:
  566. return cache[args_kwargs]
  567. except KeyError:
  568. pass
  569. rv = fn(self, *args, **kwargs)
  570. cache[args_kwargs] = rv
  571. return rv
  572. """.lstrip(),
  573. ctx,
  574. )
  575. inner = functools.wraps(fn)(ctx["inner"])
  576. def clear_cache(self: Any) -> None:
  577. if hasattr(self, key):
  578. delattr(self, key)
  579. inner.clear_cache = clear_cache # type: ignore[attr-defined]
  580. return inner
  581. return wrapper
  582. def aggregate_origins(
  583. node_schedule: Union[Sequence[BaseSchedulerNode], ExternKernel],
  584. ) -> OrderedSet[Node]:
  585. from . import ir
  586. if isinstance(node_schedule, list):
  587. return functools.reduce(
  588. operator.or_,
  589. [
  590. node.node.origins
  591. for node in node_schedule
  592. if hasattr(node, "node") and node.node
  593. ],
  594. OrderedSet(),
  595. )
  596. elif isinstance(node_schedule, ir.ExternKernel):
  597. return node_schedule.origins
  598. else:
  599. return OrderedSet()
  600. def get_fused_kernel_name(
  601. node_schedule: Sequence[BaseSchedulerNode],
  602. descriptive_names: Literal[True, "torch", "original_aten", "inductor_node"],
  603. ) -> str:
  604. all_origins = aggregate_origins(node_schedule)
  605. if descriptive_names == "original_aten":
  606. # Bases the kernel name off of the top-level aten operator (i.e. pre-decompositions)
  607. sources = [
  608. origin.meta["original_aten"]._overloadpacket.__name__
  609. for origin in all_origins
  610. if origin.op == "call_function"
  611. and "original_aten" in origin.meta
  612. and origin.meta["original_aten"] is not None
  613. ]
  614. sources = sorted(OrderedSet(sources))
  615. elif descriptive_names == "torch":
  616. # Bases the kernel name off of the top-level "torch" operator (i.e. post-dynamo graph)
  617. sources = []
  618. for origin in all_origins:
  619. if origin.op == "call_function" and "source_fn_stack" in origin.meta:
  620. source_fn = origin.meta["source_fn_stack"][-1]
  621. if isinstance(source_fn[1], str):
  622. sources.append(source_fn[1])
  623. else:
  624. sources.append(source_fn[1].__name__)
  625. sources = sorted(OrderedSet(sources))
  626. elif descriptive_names == "inductor_node":
  627. sources = [
  628. origin.name for origin in all_origins if origin.op == "call_function"
  629. ]
  630. else:
  631. raise NotImplementedError
  632. sources = sources
  633. return "_".join(["fused"] + sources)
def get_kernel_metadata(
    node_schedule: Union[Sequence[BaseSchedulerNode], ExternKernel],
    wrapper: PythonWrapperCodegen,
) -> tuple[str, str]:
    """
    Retrieves metadata information for a kernel.

    The metadata is built from the ``call_function`` FX nodes among the origins
    of ``node_schedule`` (as collected by ``aggregate_origins``). Every emitted
    line is prefixed with ``wrapper.comment``.

    Args:
        node_schedule (Union[Sequence[BaseSchedulerNode], ExternKernel]):
            Either a sequence of BaseSchedulerNode objects or an ExternKernel instance.
        wrapper (PythonWrapperCodegen):
            An instance of PythonWrapperCodegen, used to define the code comment format.

    Returns:
        tuple[str, str]:
            A tuple containing two strings:
            - The first string is a one-line summary listing the ``from_node``
              source node names and the original ATen overload packets. It is tagged
              "Topologically Sorted" when all nodes come from one FX graph and
              "Unsorted" otherwise.
            - The second string is the newline-joined detailed metadata. It
              contains a source-node to ATen-node mapping. For single-graph kernels
              it adds a "Graph fragment": a placeholder line for each distinct
              buffer read, each FX node formatted with tensor metadata, and a
              ``return`` line listing the buffers written. The read/write lines
              are only produced when ``node_schedule`` is not an ``ExternKernel``.
    """
    all_origins = aggregate_origins(node_schedule)
    inductor_nodes = [origin for origin in all_origins if origin.op == "call_function"]

    # source (from_node) name -> inductor node names; ATen packet -> inductor node names
    from_node_dict = collections.defaultdict(list)
    original_aten_dict = collections.defaultdict(list)

    # Attempt to sort `inductor_nodes` topologically. Note that the case
    # where `inductor_nodes` contains nodes from multiple graph instances
    # is not supported. An example of this is conditional statements.
    single_graph = None
    if len(inductor_nodes):
        unique_graphs = OrderedSet(n.graph for n in inductor_nodes)
        if len(unique_graphs) == 1:
            single_graph = inductor_nodes[0].graph
            # create a map of idx -> node and cache it
            # (the map is stored on the graph object and built only once)
            # NOTE(review): nodes added to the graph after the map was cached
            # would be missing from it and make the sort key raise KeyError;
            # confirm graphs are not mutated between calls.
            if not hasattr(single_graph, "_inductor_kernel_metadata_node_to_idx_map"):
                node_to_idx_map = {n: idx for idx, n in enumerate(single_graph.nodes)}
                single_graph._inductor_kernel_metadata_node_to_idx_map = node_to_idx_map  # type: ignore[attr-defined]
            inductor_nodes.sort(
                key=lambda n: single_graph._inductor_kernel_metadata_node_to_idx_map[n]  # type: ignore[attr-defined]
            )

    for node in inductor_nodes:
        if "original_aten" in node.meta and node.meta["original_aten"] is not None:
            key = str(node.meta["original_aten"]._overloadpacket)
            original_aten_dict[key].append(node.name)
        if "from_node" in node.meta:
            key = node.meta["from_node"][0].name
            from_node_dict[key].append(node.name)
    sort_str = "Topologically Sorted" if single_graph is not None else "Unsorted"
    metadata = (
        f"{wrapper.comment} {sort_str} Source Nodes: [{', '.join(from_node_dict.keys())}], "
        f"Original ATen: [{', '.join(original_aten_dict.keys())}]"
    )

    # trace back to original node here
    detailed_metadata = [f"{wrapper.comment} Source node to ATen node mapping:"]
    for original_node, nodes in sorted(from_node_dict.items()):
        detailed_metadata.append(
            f"{wrapper.comment} {original_node} => {', '.join(sorted(nodes))}"
        )

    # print the aot_autograd graph fragment
    if single_graph is not None:
        from . import ir

        detailed_metadata.append(f"{wrapper.comment} Graph fragment:")
        all_reads: OrderedSet[str] = OrderedSet()
        all_writes: list[str] = []
        if not isinstance(node_schedule, ir.ExternKernel):
            from .virtualized import V

            def get_buffer_info(
                buffer: Union[ir.TensorBox, ir.Buffer, ir.TorchBindObject], rw_name: str
            ) -> tuple[str, ir.Layout | None]:
                # Return (display name, layout) for a read/written buffer. The
                # display name is the origin FX node's name, falling back to the
                # read/write dependency name; layout is None when get_layout()
                # raises NotImplementedError.
                if isinstance(buffer, ir.TensorBox) and isinstance(
                    buffer.data, ir.StorageBox
                ):
                    origin_node = buffer.data.data.origin_node
                else:
                    origin_node = buffer.origin_node
                if origin_node is None:
                    # use the read/write name if no origin node is found
                    name = rw_name
                else:
                    name = origin_node.name
                try:
                    layout = buffer.get_layout()
                except NotImplementedError:
                    layout = None
                return name, layout

            def stringify_shape(shape: Iterable[int]) -> str:
                # e.g. (2, 3) -> "[2, 3]"
                return f"[{', '.join([str(x) for x in shape])}]"

            def stringfy_layout(layout: ir.Layout | None) -> str:
                # Quoted "<dtype abbrev><size><stride><device>", or "" if unknown.
                if layout is None:
                    return ""
                shape_annotation = f"{stringify_shape(layout.size)}"
                stride_annotation = f"{stringify_shape(layout.stride)}"
                device_annotation = f"{layout.device}"

                return (
                    f'"{dtype_abbrs[layout.dtype]}{shape_annotation}'
                    f'{stride_annotation}{device_annotation}"'
                )

            for n in node_schedule:
                if not hasattr(n, "read_writes") or n.read_writes is None:
                    continue
                if hasattr(n.read_writes, "reads") and n.read_writes.reads is not None:
                    for r in n.read_writes.reads:
                        # Skip inputs already emitted: several nodes in the
                        # schedule may read the same buffer.
                        if r.name in all_reads:
                            continue
                        all_reads.add(r.name)
                        buffer = V.graph.try_get_buffer(r.name)
                        if buffer is None:
                            continue
                        input_name, layout = get_buffer_info(buffer, r.name)
                        detailed_metadata.append(
                            f"{wrapper.comment} %{input_name} : Tensor "
                            f"{stringfy_layout(layout)} = PlaceHolder[target={input_name}]"
                        )

                if (
                    hasattr(n.read_writes, "writes")
                    and n.read_writes.writes is not None
                ):
                    # Writes are not de-duplicated; each becomes a "%name" entry
                    # of the trailing return line.
                    for w in n.read_writes.writes:
                        buffer = V.graph.try_get_buffer(w.name)
                        if buffer is None:
                            continue
                        output_name, _ = get_buffer_info(buffer, w.name)
                        all_writes.append("%" + output_name)

        for node in inductor_nodes:
            detailed_metadata.append(
                f"{wrapper.comment} {node.format_node(include_tensor_metadata=True)}"
            )

        detailed_metadata.append(f"{wrapper.comment} return {','.join(all_writes)}")

    return metadata, "\n".join(detailed_metadata)
  760. def dominated_nodes(
  761. initial_queue: Iterable[torch.fx.Node],
  762. skip_filter: Optional[Callable[[Any], bool]] = None,
  763. ) -> OrderedSet[torch.fx.Node]:
  764. """Returns the set of nodes whose values depend on those within initial_queue"""
  765. initial_queue = list(initial_queue)
  766. dominated_set = OrderedSet(initial_queue)
  767. while initial_queue:
  768. node = initial_queue.pop()
  769. for user in node.users:
  770. if skip_filter and skip_filter(user):
  771. continue
  772. if user not in dominated_set:
  773. dominated_set.add(user)
  774. initial_queue.append(user)
  775. return dominated_set
  776. def gather_origins(
  777. args: Sequence[IRNode], kwargs: dict[str, IRNode]
  778. ) -> OrderedSet[torch.fx.Node]:
  779. from . import ir
  780. def is_unrealized_node(n: IRNode) -> bool:
  781. if isinstance(n, ir.TensorBox):
  782. return is_unrealized_node(n.data)
  783. if isinstance(n, ir.StorageBox):
  784. return is_unrealized_node(n.data)
  785. return isinstance(n, ir.IRNode) and not isinstance(
  786. n,
  787. (
  788. ir.ComputedBuffer,
  789. ir.InputsKernel,
  790. ir.InputBuffer,
  791. ir.TemplateBuffer,
  792. ),
  793. )
  794. # kwargs and args may include a container of node, for example torch.cat([t1, t2])
  795. # flatten them before search the unrealized nodes
  796. kwargs_flatten, _ = tree_flatten(kwargs)
  797. kwargs_origins = [val.origins for val in kwargs_flatten if is_unrealized_node(val)]
  798. args_flatten, _ = tree_flatten(args)
  799. args_origins = [val.origins for val in args_flatten if is_unrealized_node(val)]
  800. return OrderedSet(itertools.chain(*args_origins, *kwargs_origins))
  801. def sympy_str(expr: sympy.Expr) -> str:
  802. """
  803. Normal sympy str is very slow, this is a lot faster. The result are
  804. somewhat worse, as it doesn't do as much simplification. So don't
  805. use this for final codegen.
  806. """
  807. def is_neg_lead(expr: sympy.Expr) -> bool:
  808. return (
  809. isinstance(expr, sympy.Mul) and len(expr.args) == 2 and expr.args[0] == -1
  810. )
  811. def sympy_str_add(expr: sympy.Expr) -> str:
  812. if isinstance(expr, sympy.Add):
  813. # Special case 'a - b'. Note that 'a - b - c' will still appear as
  814. # 'a + -1 * b + -1 * c'.
  815. if len(expr.args) == 2 and is_neg_lead(expr.args[1]):
  816. return f"{sympy_str_mul(expr.args[0])} - {sympy_str_mul(expr.args[1].args[1])}"
  817. else:
  818. return " + ".join(map(sympy_str_mul, expr.args))
  819. else:
  820. return sympy_str_mul(expr)
  821. def sympy_str_mul(expr: sympy.Expr) -> str:
  822. if isinstance(expr, sympy.Mul):
  823. if is_neg_lead(expr):
  824. # Special case '-a'. Note that 'a * -b' will still appear as
  825. # '-1 * a * b'.
  826. return f"-{sympy_str_atom(expr.args[1])}"
  827. else:
  828. return " * ".join(map(sympy_str_atom, expr.args))
  829. else:
  830. return sympy_str_atom(expr)
  831. def sympy_str_atom(expr: sympy.Expr) -> str:
  832. if isinstance(expr, sympy.Symbol):
  833. return expr.name
  834. elif isinstance(expr, (sympy.Add, sympy.Mul)):
  835. return f"({sympy_str_add(expr)})"
  836. elif isinstance(expr, (ModularIndexing, CleanDiv, FloorDiv, Identity)):
  837. return f"{expr.func.__name__}({', '.join(map(sympy_str, expr.args))})"
  838. else:
  839. return str(expr)
  840. return sympy_str_add(expr)
  841. def get_bounds_index_expr(index: sympy.Expr) -> ValueRanges[Any]:
  842. from .virtualized import V
  843. # If this expression does not come from an FX node, we compute its bounds
  844. if (
  845. config.compute_all_bounds
  846. and (fx_node := getattr(V.interpreter, "current_node", None))
  847. and fx_node.target != "index_expr"
  848. ):
  849. return bound_sympy(index)
  850. else:
  851. return ValueRanges.unknown()
  852. def prefix_is_reduction(prefix: str) -> bool:
  853. return prefix[0] == "r"
  854. def sympy_index_symbol_with_prefix(prefix: SymT, idx: int) -> sympy.Symbol:
  855. """
  856. Used to generate an integer-nonnegative symbol.
  857. """
  858. # This should never be used for creating shape/stride symbols, as those
  859. # should all be allocated before Inductor.
  860. assert prefix != SymT.SIZE
  861. # NOTE: shape symbols are positive (> 0), but index variables are only
  862. # non-negative (>= 0).
  863. return make_symbol(prefix, idx, integer=True, nonnegative=True)
  864. def generate_assert(check: bool) -> bool:
  865. return (check or config.debug_index_asserts) and config.assert_indirect_indexing
  866. def sympy_index_symbol(name: str) -> sympy.Symbol:
  867. """
  868. Used to generate an integer-nonnegative symbol.
  869. """
  870. # This should never be used for creating shape/stride symbols, as those
  871. # should all be allocated before Inductor.
  872. assert name[0] != "s"
  873. # NOTE: shape symbols are positive (> 0), but index variables are only
  874. # non-negative (>= 0).
  875. return sympy.Symbol(name, integer=True, nonnegative=True)
  876. def sympy_subs(expr: sympy.Expr, replacements: dict[sympy.Expr, Any]) -> sympy.Expr:
  877. """
  878. When the passed replacement symbol v is a string, it is converted to a symbol with name v that
  879. have the same replaced expression integer and nonnegative properties.
  880. """
  881. def to_symbol(
  882. replaced: sympy.Expr, replacement: Union[sympy.Expr, str]
  883. ) -> sympy.Symbol:
  884. assert isinstance(replaced, sympy.Expr)
  885. if isinstance(replacement, str):
  886. return sympy.Symbol(
  887. replacement,
  888. integer=replaced.is_integer, # type: ignore[attr-defined]
  889. nonnegative=replaced.is_nonnegative, # type: ignore[attr-defined]
  890. )
  891. else:
  892. return replacement
  893. # xreplace is faster than subs, but is way more picky
  894. return sympy.sympify(expr).xreplace(
  895. {k: to_symbol(k, v) for k, v in replacements.items()}
  896. )
  897. def is_symbolic(a: Any) -> TypeGuard[Union[torch.SymInt, torch.Tensor]]:
  898. return isinstance(a, torch.SymInt) or (
  899. isinstance(a, torch.Tensor)
  900. and any(is_symbolic(x) for x in itertools.chain(a.size(), a.stride()))
  901. )
  902. def any_is_symbolic(*args: Any) -> bool:
  903. return any(is_symbolic(a) for a in args)
  904. def get_first_incompatible_cudagraph_node(
  905. gm: torch.fx.GraphModule,
  906. ) -> Optional[torch.fx.Node]:
  907. from torch.fx.experimental.symbolic_shapes import free_unbacked_symbols
  908. forbidden_set = OrderedSet(
  909. [
  910. "aten._fused_moving_avg_obs_fq_helper.default",
  911. "aten._fused_moving_avg_obs_fq_helper_functional.default",
  912. "fbgemm.dense_to_jagged.default",
  913. "fbgemm.jagged_to_padded_dense.default",
  914. "run_and_save_rng_state",
  915. "run_with_rng_state",
  916. "aten._local_scalar_dense",
  917. # Technically, it's not necessary to ban this, because an
  918. # assert_scalar with constant arguments can be validly run
  919. # with CUDA graphs, but the operator is also pointless with
  920. # constant arguments, so might as well ban
  921. "aten._assert_scalar",
  922. ]
  923. )
  924. if torch.are_deterministic_algorithms_enabled():
  925. forbidden_set.update(
  926. (
  927. "aten._unsafe_index_put.default",
  928. "aten._unsafe_masked_index_put_accumulate.default",
  929. "aten.index_put.default",
  930. "aten.index_put_.default",
  931. "aten.scatter.src",
  932. "aten.scatter.reduce",
  933. "aten.scatter.value_reduce",
  934. "aten.scatter_add_",
  935. "aten.scatter_add.default",
  936. "aten.scatter_reduce.two",
  937. "aten.scatter_reduce_.two",
  938. "aten.scatter_reduce.two_out",
  939. )
  940. )
  941. for node in gm.graph.nodes:
  942. if str(node.target) in forbidden_set:
  943. return node
  944. if (
  945. not torch._inductor.config.graph_partition
  946. and isinstance(node.target, torch._ops.OpOverload)
  947. and torch._C.Tag.cudagraph_unsafe in node.target.tags # type: ignore[attr-defined]
  948. ):
  949. # skip cudagraph if a cudagraph_unsafe op is detected.
  950. # graph_partition helps by splitting on this cudagraph_unsafe
  951. # op and cudagraphifying the subgraphs.
  952. return node
  953. if (val := node.meta.get("val")) is not None and free_unbacked_symbols(val):
  954. return node
  955. return None
  956. def output_node(gm: torch.fx.GraphModule) -> Node:
  957. """Get the output node from an FX graph"""
  958. last_node = next(iter(reversed(gm.graph.nodes)))
  959. assert last_node.op == "output"
  960. return last_node
  961. def get_all_devices(gm: torch.fx.GraphModule) -> OrderedSet[torch.device]:
  962. placeholder_nodes = gm.graph.find_nodes(op="placeholder")
  963. input_devices: OrderedSet[torch.device] = OrderedSet(
  964. node.meta["val"].device
  965. for node in placeholder_nodes
  966. if isinstance(node.meta.get("val"), torch.Tensor)
  967. )
  968. out_arg = output_node(gm).args[0] # type: ignore[union-attr]
  969. out_args = out_arg if isinstance(out_arg, tuple) else (out_arg,)
  970. out_devices: OrderedSet[torch.device] = OrderedSet(
  971. arg.meta["val"].device
  972. for arg in out_args
  973. if isinstance(arg, torch.fx.Node)
  974. and isinstance(arg.meta.get("val"), torch.Tensor)
  975. )
  976. return input_devices | out_devices
  977. import gc
def unload_xpu_triton_pyds() -> None:
    """
    Best-effort teardown of Triton-generated extension modules.

    In this file it is only called from ``fresh_cache``, on Windows with an
    available XPU, right before the temporary cache directory is removed.
    Presumably this releases the loaded binaries so their files can be deleted
    (verify against the Windows XPU runtime).

    Steps:
      1. For every loaded ``torch._inductor.runtime.compile_tasks.*`` module, take
         each ``triton_*`` attribute that is a ``CachingAutotuner``. For every
         ``TritonCompileResult`` in its ``compile_results``, call ``__del__()`` on
         ``result.kernel.run.mod``. Then drop the module from ``sys.modules``.
      2. If ``triton.runtime.driver`` is loaded, delete the ``instance`` attribute
         from the class of ``driver.active.utils``, then ``driver.active.utils``
         itself.
      3. Run ``gc.collect()``.
    """
    # unload __triton_launcher.pyd
    # (iterate over a snapshot of the keys because entries are deleted below)
    for module_name in list(sys.modules.keys()):
        if not module_name.startswith("torch._inductor.runtime.compile_tasks."):
            continue
        m = sys.modules[module_name]
        for attr_name in m.__dict__.keys():
            if attr_name.startswith("triton_"):
                kernel = getattr(m, attr_name)
                if isinstance(
                    kernel, torch._inductor.runtime.triton_heuristics.CachingAutotuner
                ):
                    for result in kernel.compile_results:
                        if isinstance(
                            result,
                            torch._inductor.runtime.triton_heuristics.TritonCompileResult,
                        ):
                            # Explicit __del__ call on the launcher module object;
                            # NOTE(review): assumes its __del__ is safe to call
                            # before the object is actually collected.
                            result.kernel.run.mod.__del__()
        del sys.modules[module_name]

    # unload spirv_utils.pyd
    # NOTE(review): `instance` looks like a class-level singleton cache on the
    # utils type; deleting it presumably forces a fresh load later — confirm.
    if "triton.runtime.driver" in sys.modules:
        mod = sys.modules["triton.runtime.driver"]
        del type(mod.driver.active.utils).instance
        del mod.driver.active.utils

    gc.collect()
  1003. _registered_caches: list[Any] = []
  1004. def clear_on_fresh_cache(obj: Any) -> Any:
  1005. """
  1006. Use this decorator to register any caches that should be cache_clear'd
  1007. with fresh_cache().
  1008. """
  1009. if not hasattr(obj, "cache_clear") or not callable(obj.cache_clear):
  1010. raise AttributeError(f"{obj} does not have a cache_clear method")
  1011. _registered_caches.append(obj)
  1012. return obj
  1013. def clear_caches() -> None:
  1014. """
  1015. Clear all registered caches.
  1016. """
  1017. for obj in _registered_caches:
  1018. obj.cache_clear()
@contextlib.contextmanager
def fresh_cache(
    cache_entries: Optional[dict[str, Any]] = None,
    dir: Optional[str] = None,
    delete: bool = True,
) -> Iterator[None]:
    """
    Contextmanager that provides a clean tmp cachedir for pt2 caches.

    Inside the context, ``TORCHINDUCTOR_CACHE_DIR`` points to a fresh directory
    created with ``tempfile.mkdtemp(dir=dir)``, and ``TRITON_CACHE_DIR`` points to
    its ``triton`` subdirectory. All registered caches are cleared on entry and
    again on exit (``finally``).

    Args:
        cache_entries: if a dict (it must be empty), it is filled after the body
            completes with ``{filename: size_in_bytes}`` for the files in the
            Triton cache dir whose names do not contain ".lock".
        dir: parent directory for the temporary cache dir (None uses the
            ``tempfile`` default).
        delete: if True, remove the temporary dir after a successful body.

    If the body raises, a warning is logged, the directory is kept (not
    deleted), and the exception is re-raised.
    """
    clear_caches()

    from torch._inductor.cpp_builder import normalize_path_separator

    inductor_cache_dir = normalize_path_separator(tempfile.mkdtemp(dir=dir))
    try:
        with mock.patch.dict(
            os.environ, {"TORCHINDUCTOR_CACHE_DIR": inductor_cache_dir}
        ):
            log.debug("Using inductor cache dir %s", inductor_cache_dir)
            triton_cache_dir = normalize_path_separator(
                os.path.join(inductor_cache_dir, "triton")
            )
            with mock.patch.dict(os.environ, {"TRITON_CACHE_DIR": triton_cache_dir}):
                yield
                # Runs only if the body did not raise.
                if isinstance(cache_entries, dict):
                    assert len(cache_entries) == 0, "expected empty cache_entries dict"
                    if os.path.exists(triton_cache_dir):
                        files = os.listdir(triton_cache_dir)
                        cache_entries.update(
                            {
                                f: os.path.getsize(os.path.join(triton_cache_dir, f))
                                for f in files
                                if ".lock" not in f
                            }
                        )
        if delete:
            if is_windows() and torch.xpu.is_available():
                unload_xpu_triton_pyds()

            # NOTE(review): when ignore_errors is True (Windows) shutil ignores the
            # onerror handler entirely, so failures there are silent. Elsewhere,
            # onerror logs a warning and removal continues. `onerror` is
            # deprecated since Python 3.12 in favour of `onexc`.
            shutil.rmtree(
                inductor_cache_dir,
                # Let's not fail if we can't clean up the temp dir. Also note that for
                # Windows, we can't delete the loaded modules because the module binaries
                # are open.
                ignore_errors=is_windows(),
                onerror=lambda func, path, exc_info: log.warning(
                    "Failed to remove temporary cache dir at %s",
                    inductor_cache_dir,
                    exc_info=exc_info,
                ),
            )
    except Exception:
        log.warning("on error, temporary cache dir kept at %s", inductor_cache_dir)
        raise
    finally:
        clear_caches()
  1074. # Deprecated functions -- only keeping them for BC reasons
  1075. clear_on_fresh_inductor_cache = clear_on_fresh_cache
  1076. clear_inductor_caches = clear_caches
  1077. fresh_inductor_cache = fresh_cache
  1078. def argsort(seq: Sequence[Any]) -> list[int]:
  1079. # preserve original order for equal strides
  1080. getter = seq.__getitem__
  1081. a_r = range(len(seq))
  1082. return list(reversed(sorted(a_r, key=getter, reverse=True))) # noqa: C413
  1083. def argsort_sym(
  1084. shape_env: ShapeEnv, seq: Sequence[Union[int, torch.SymInt, sympy.Expr]]
  1085. ) -> list[int]:
  1086. def cmp(a: tuple[int, sympy.Expr], b: tuple[int, sympy.Expr]) -> int:
  1087. a_idx, a_val = a
  1088. b_idx, b_val = b
  1089. def evaluate(expr: Union[bool, torch.SymInt, sympy.Expr]) -> bool:
  1090. if isinstance(expr, bool):
  1091. return expr
  1092. return shape_env.evaluate_expr(expr, size_oblivious=True)
  1093. if evaluate(a_val < b_val):
  1094. return -1
  1095. if evaluate(a_val > b_val):
  1096. return 1
  1097. # If strides are the same, prefer the original order.
  1098. # (this matches argsort's algorithm).
  1099. # For strides = [2048, 2048, 16, 1], this is
  1100. # [3, 2, 1, 0].
  1101. if a_idx < b_idx:
  1102. return 1
  1103. if a_idx > b_idx:
  1104. return -1
  1105. return 0
  1106. # Strategy: convert all symints to sympy.Expr, then use a custom comparator
  1107. exprs = [
  1108. (idx, s.node.expr if isinstance(s, torch.SymInt) else s)
  1109. for idx, s in enumerate(seq)
  1110. ]
  1111. exprs = sorted(exprs, key=functools.cmp_to_key(cmp))
  1112. result = [idx for idx, _ in exprs]
  1113. return result
  1114. @functools.lru_cache(8)
  1115. def get_dtype_size(dtype: torch.dtype) -> int:
  1116. # TODO: Investigate why uint64 tensor creation causes overflow error:
  1117. # Workaround for RuntimeError in memory size calculation, but underlying cause unclear
  1118. if dtype == torch.uint64:
  1119. return 8
  1120. return torch.empty((), dtype=dtype).element_size()
  1121. class LineContext(NamedTuple):
  1122. context: Any
  1123. @dataclasses.dataclass
  1124. class ValueWithLineMap:
  1125. value: str
  1126. line_map: list[tuple[int, LineContext]]
  1127. class IndentedBuffer:
  1128. tabwidth = 4
  1129. def __init__(self, initial_indent: int = 0) -> None:
  1130. self._lines: list[Union[DeferredLineBase, LineContext, str]] = []
  1131. self._indent = initial_indent
  1132. @contextlib.contextmanager
  1133. def set_tabwidth(self, tabwidth: int) -> Iterator[None]:
  1134. prev = self.tabwidth
  1135. try:
  1136. self.tabwidth = tabwidth
  1137. yield
  1138. finally:
  1139. self.tabwidth = prev
  1140. def getvaluewithlinemap(self) -> ValueWithLineMap:
  1141. buf = StringIO()
  1142. p = 1
  1143. linemap: list[tuple[int, LineContext]] = []
  1144. for li in self._lines:
  1145. if isinstance(li, DeferredLineBase):
  1146. line = li()
  1147. if line is None:
  1148. continue
  1149. elif isinstance(li, LineContext):
  1150. linemap.append((p, li.context))
  1151. continue
  1152. else:
  1153. line = li
  1154. assert isinstance(line, str)
  1155. buf.write(line)
  1156. buf.write("\n")
  1157. p += 1 + line.count("\n")
  1158. return ValueWithLineMap(buf.getvalue(), linemap)
  1159. def getvalue(self) -> str:
  1160. return self.getvaluewithlinemap().value
  1161. def getrawvalue(self) -> str:
  1162. buf = StringIO()
  1163. for li in self._lines:
  1164. if isinstance(li, DeferredLineBase):
  1165. line = li()
  1166. if line is None:
  1167. continue
  1168. elif isinstance(li, LineContext):
  1169. continue
  1170. else:
  1171. line = li
  1172. assert isinstance(line, str)
  1173. # backslash implies line continuation
  1174. if line.endswith("\\"):
  1175. buf.write(line[:-1])
  1176. else:
  1177. buf.write(line)
  1178. buf.write("\n")
  1179. return buf.getvalue()
  1180. def clear(self) -> None:
  1181. self._lines.clear()
  1182. def __bool__(self) -> bool:
  1183. return bool(self._lines)
  1184. def prefix(self) -> str:
  1185. return " " * (self._indent * self.tabwidth)
  1186. def newline(self) -> None:
  1187. self.writeline("\n")
  1188. def writeline(self, line: Union[LineContext, DeferredLineBase, str]) -> None:
  1189. if isinstance(line, LineContext):
  1190. self._lines.append(line)
  1191. elif isinstance(line, DeferredLineBase):
  1192. self._lines.append(line.with_prefix(self.prefix()))
  1193. elif line.strip():
  1194. self._lines.append(f"{self.prefix()}{line}")
  1195. else:
  1196. self._lines.append("")
  1197. def writelines(
  1198. self, lines: Sequence[Union[LineContext, DeferredLineBase, str]]
  1199. ) -> None:
  1200. for line in lines:
  1201. self.writeline(line)
  1202. def indent(self, offset: int = 1) -> contextlib.AbstractContextManager[None]:
  1203. @contextlib.contextmanager
  1204. def ctx() -> Iterator[None]:
  1205. self._indent += offset
  1206. try:
  1207. yield
  1208. finally:
  1209. self._indent -= offset
  1210. return ctx()
  1211. def do_indent(self, offset: int = 1) -> None:
  1212. self._indent += offset
  1213. def do_unindent(self, offset: int = 1) -> None:
  1214. self._indent -= offset
  1215. def splice(
  1216. self, other_code: Union[IndentedBuffer, str], strip: bool = False
  1217. ) -> None:
  1218. if isinstance(other_code, IndentedBuffer):
  1219. dedent = float("inf")
  1220. for line in other_code._lines:
  1221. if not isinstance(line, LineContext) and line:
  1222. dedent = min(dedent, len(line) - len(line.lstrip()))
  1223. if math.isinf(dedent):
  1224. dedent = 0
  1225. for line in other_code._lines:
  1226. if isinstance(line, LineContext):
  1227. self._lines.append(line)
  1228. else:
  1229. IndentedBuffer.writeline(self, line[int(dedent) :])
  1230. else:
  1231. other_code = textwrap.dedent(other_code)
  1232. if strip:
  1233. other_code = other_code.lstrip()
  1234. if not other_code:
  1235. return
  1236. other_code = other_code.rstrip()
  1237. for s in other_code.split("\n"):
  1238. self.writeline(s)
  1239. def map(self, func: Callable[[Any], Any]) -> IndentedBuffer:
  1240. res = IndentedBuffer(initial_indent=self._indent)
  1241. res._lines = [func(line) for line in self._lines]
  1242. return res
  1243. def __repr__(self) -> str:
  1244. return f"{type(self)}({self.getvalue()})"
  1245. def __add__(self, other: Self) -> IndentedBuffer:
  1246. assert self._indent == other._indent
  1247. res = IndentedBuffer(initial_indent=self._indent)
  1248. # TODO(rec): or should this be self.__class__(initial_indent=self._indent)?
  1249. res.writelines(self._lines)
  1250. res.writelines(other._lines)
  1251. return res
  1252. def contains(self, new_line: Union[DeferredLineBase, LineContext, str]) -> bool:
  1253. return new_line in self._lines
  1254. class FakeIndentedBuffer(IndentedBuffer):
  1255. def __init__(self) -> None:
  1256. super().__init__()
  1257. def __getattribute__(self, name: str) -> Any:
  1258. if name == "__class__": # Allow access to the class attribute
  1259. return object.__getattribute__(self, name)
  1260. raise RuntimeError(
  1261. f"Tried to call self.{name} on FakeIndentedBuffer. This buffer"
  1262. "is currently used on TritonTemplateKernel to prevent actual"
  1263. "writes to the body without explicitly specifying the body with"
  1264. "`TritonTemplateKernel.set_subgraph_body(name)`"
  1265. )
  1266. @contextlib.contextmanager
  1267. def restore_stdout_stderr() -> Iterator[None]:
  1268. initial_stdout, initial_stderr = sys.stdout, sys.stderr
  1269. try:
  1270. yield
  1271. finally:
  1272. sys.stdout, sys.stderr = initial_stdout, initial_stderr
  1273. class DeferredLineBase:
  1274. """A line that can be 'unwritten' at a later time"""
  1275. def __init__(self, line: str):
  1276. if not line.strip():
  1277. line = ""
  1278. self.line = line
  1279. def __call__(self) -> Union[str, None]:
  1280. """Returns either self.line or None to indicate the line has been 'unwritten'"""
  1281. raise NotImplementedError
  1282. def _new_line(self, line: str) -> Self:
  1283. """Returns a new deferred line with the same condition"""
  1284. raise NotImplementedError
  1285. def with_prefix(self, prefix: str) -> Self:
  1286. return self._new_line(f"{prefix}{self.line}")
  1287. def lstrip(self) -> Self:
  1288. return self._new_line(self.line.lstrip())
  1289. def __getitem__(self, index: Union[int, slice]) -> Self:
  1290. return self._new_line(self.line[index])
  1291. def __bool__(self) -> bool:
  1292. return bool(self.line)
  1293. def __len__(self) -> int:
  1294. return len(self.line)
  1295. class DelayReplaceLine(DeferredLineBase):
  1296. """At end of codegen call `line.replace(key, value_fn())`"""
  1297. def __init__(self, key: str, value_fn: Callable[[], str], line: str):
  1298. super().__init__(line)
  1299. self.key = key
  1300. self.value_fn = value_fn
  1301. def __call__(self) -> str:
  1302. return self.line.replace(self.key, self.value_fn())
  1303. def _new_line(self, line: str) -> DelayReplaceLine:
  1304. return DelayReplaceLine(self.key, self.value_fn, line)
  1305. @functools.cache
  1306. def is_big_gpu(index_or_device: Union[int, torch.device] = 0) -> bool:
  1307. if isinstance(index_or_device, torch.device):
  1308. device = index_or_device
  1309. else:
  1310. device = torch.device(get_gpu_type(), index_or_device)
  1311. prop = DeviceProperties.create(device)
  1312. # SM logic is not relevant to ROCm gpus
  1313. # Arbitrarily skipping the older models
  1314. if torch.version.hip:
  1315. assert prop.major is not None
  1316. if prop.major < 9 or prop.major == 10:
  1317. log.warning("GPU arch does not support max_autotune_gemm mode usage")
  1318. return False
  1319. return True
  1320. min_sms = 16 if device.type == "xpu" else 68 # 3080
  1321. avail_sms = prop.multi_processor_count
  1322. if avail_sms < min_sms:
  1323. log.warning(
  1324. "Not enough SMs to use max_autotune_gemm mode",
  1325. extra={"min_sms": min_sms, "avail_sms": avail_sms},
  1326. )
  1327. return False
  1328. return True
  1329. @functools.lru_cache
  1330. def get_max_num_sms() -> int:
  1331. if torch.xpu.is_available():
  1332. return torch.xpu.get_device_properties().gpu_subslice_count
  1333. return torch.cuda.get_device_properties("cuda").multi_processor_count
  1334. @functools.lru_cache
  1335. def using_b200() -> bool:
  1336. """Returns true if the device is a NVIDIA B200, otherwise returns false."""
  1337. if not torch.cuda.is_available():
  1338. return False
  1339. # compute capability 10.0 or 10.0a is NVIDIA B200
  1340. device_properties = torch.cuda.get_device_properties(torch.cuda.current_device())
  1341. return device_properties.major == 10
  1342. def get_num_sms() -> int:
  1343. """Handle experimental carveout if set otherwise return hardware SM count"""
  1344. # TODO we need to properly guard on this global
  1345. if torch.xpu.is_available():
  1346. return get_max_num_sms()
  1347. carveout = torch._C._get_sm_carveout_experimental()
  1348. return get_max_num_sms() - (carveout if carveout is not None else 0)
  1349. def get_tma_workspace_arg(
  1350. num_tma_descriptors: int,
  1351. device: torch.device,
  1352. num_programs: Optional[int] = None,
  1353. ) -> WorkspaceArg:
  1354. """Builds and returns a WorkspaceArg for the device side TMA workspace buffer."""
  1355. from .codegen.common import WorkspaceArg, WorkspaceZeroMode
  1356. if num_programs is None:
  1357. num_programs = get_num_sms()
  1358. zero_mode = WorkspaceZeroMode.from_bool(False)
  1359. size = num_programs * num_tma_descriptors * TMA_DESCRIPTOR_SIZE
  1360. return WorkspaceArg(
  1361. count=size,
  1362. zero_mode=zero_mode,
  1363. device=device,
  1364. outer_name=WorkspaceArg.unique_name(),
  1365. )
  1366. def _use_template_for_gpu(
  1367. layout: Layout, allowed_layout_dtypes: list[torch.dtype]
  1368. ) -> bool:
  1369. if layout.dtype not in allowed_layout_dtypes:
  1370. log.debug(
  1371. "Not using template since dtype %s is not in allowed layout dtypes %s",
  1372. layout.dtype,
  1373. allowed_layout_dtypes,
  1374. )
  1375. return (
  1376. is_gpu(layout.device.type)
  1377. and layout.dtype in allowed_layout_dtypes
  1378. and is_big_gpu(layout.device)
  1379. )
  1380. def _use_autotune_backend(backend: str) -> bool:
  1381. return backend.upper() in [
  1382. x.strip() for x in config.max_autotune_gemm_backends.upper().split(",")
  1383. ]
  1384. def _use_conv_autotune_backend(backend: str) -> bool:
  1385. return backend.upper() in [
  1386. x.strip() for x in config.max_autotune_conv_backends.upper().split(",")
  1387. ]
  1388. def use_triton_template(
  1389. layout: Layout,
  1390. *,
  1391. enable_int32: bool = False,
  1392. enable_float8: bool = False,
  1393. check_max_autotune: bool = True,
  1394. ) -> bool:
  1395. from .codegen.common import BackendFeature, has_backend_feature
  1396. layout_dtypes = [torch.float16, torch.bfloat16, torch.float32]
  1397. if enable_int32:
  1398. layout_dtypes = [torch.float16, torch.bfloat16, torch.float32, torch.int32]
  1399. if enable_float8:
  1400. layout_dtypes.extend([torch.float8_e4m3fn, torch.float8_e5m2])
  1401. return (
  1402. (
  1403. (
  1404. is_gpu(layout.device.type)
  1405. and _use_template_for_gpu(layout, layout_dtypes)
  1406. )
  1407. or (layout.device.type == "cpu" and layout.dtype in layout_dtypes)
  1408. )
  1409. # some callers handle max-autotune checking externally
  1410. and (config.max_autotune or config.max_autotune_gemm or not check_max_autotune)
  1411. and _use_autotune_backend("TRITON")
  1412. and has_backend_feature(layout.device, BackendFeature.TRITON_TEMPLATES)
  1413. )
  1414. def can_use_tma(*matrices: IRNode, add_guards: bool = False) -> bool:
  1415. """
  1416. Return True iff *all* supplied tensors satisfy the CUDA-12.9 TMA constraints
  1417. that Triton relies on today.
  1418. * https://docs.nvidia.com/cuda/cuda-driver-api/group__CUDA__TENSOR__MEMORY.html
  1419. A tensor is accepted when:
  1420. * 2 ≤ rank ≤ 5
  1421. * dtype ∈ {FP16, BF16, FP8-E4M3FN}
  1422. * Every logical size ≥ 2
  1423. * Base pointer 16-byte aligned
  1424. * All "outer" dims have 16-byte aligned strides
  1425. * The “inner” dim has stride 1 (contiguous)
  1426. * For FP8 tensors, inner dim ≥ 32
  1427. """
  1428. from torch.utils._triton import has_triton_tma_device
  1429. from .virtualized import V
  1430. def _aligned(expr_bytes: Union[int, sympy.Expr]) -> bool:
  1431. return V.graph.sizevars.statically_known_multiple_of(expr_bytes, TMA_ALIGNMENT)
  1432. def _is_tma_compatible_default(x: IRNode) -> bool:
  1433. sizes = x.get_size()
  1434. strides = x.get_stride()
  1435. rank = len(sizes)
  1436. dtype = x.get_dtype()
  1437. itemsize = dtype.itemsize
  1438. # 2 ≤ rank ≤ 5
  1439. if rank < 2 or rank > 5:
  1440. return False
  1441. # dtype ∈ {FP16, BF16, FP8-E4M3FN}
  1442. if dtype not in (torch.float16, torch.bfloat16, torch.float8_e4m3fn):
  1443. return False
  1444. # Base pointer 16-byte aligned
  1445. if x.get_name() in V.graph.unaligned_buffers:
  1446. return False
  1447. if add_guards:
  1448. sizes_i = V.graph.sizevars.guard_int_seq(sizes)
  1449. strides_i = V.graph.sizevars.guard_int_seq(strides)
  1450. else:
  1451. sizes_i = [V.graph.sizevars.symbolic_hint(s) for s in sizes]
  1452. strides_i = [V.graph.sizevars.symbolic_hint(st) for st in strides]
  1453. # Every logical size ≥ 2
  1454. if any(not V.graph.sizevars.statically_known_geq(s, 2) for s in sizes_i):
  1455. return False
  1456. # Find the single contiguous (“inner”) dim
  1457. inner = [
  1458. i
  1459. for i, st in enumerate(strides_i)
  1460. if V.graph.sizevars.statically_known_equals(st, 1)
  1461. ]
  1462. if len(inner) != 1:
  1463. return False
  1464. inner_idx = inner[0]
  1465. # All "outer" dims must have 16-byte aligned strides
  1466. for i, st in enumerate(strides_i):
  1467. if i == inner_idx:
  1468. continue
  1469. if not _aligned(st * itemsize):
  1470. return False
  1471. # Inner dim byte width must still be a multiple of 16 B
  1472. inner_dim = sizes_i[inner_idx]
  1473. if not _aligned(inner_dim * itemsize):
  1474. return False
  1475. # FP8 special case: inner ≥ 32
  1476. if dtype == torch.float8_e4m3fn and not V.graph.sizevars.statically_known_geq(
  1477. inner_dim, 32
  1478. ):
  1479. return False
  1480. return True
  1481. def _is_tma_compatible_xpu(x: IRNode) -> bool:
  1482. strides = x.get_stride()
  1483. strides_i = [V.graph.sizevars.symbolic_hint(st) for st in strides]
  1484. # Find the single contiguous (“inner”) dim
  1485. inner = [
  1486. i
  1487. for i, st in enumerate(strides_i)
  1488. if V.graph.sizevars.statically_known_equals(st, 1)
  1489. ]
  1490. if len(inner) != 1:
  1491. return False
  1492. return True
  1493. return has_triton_tma_device() and all(
  1494. _is_tma_compatible_default(m)
  1495. if (m_device := m.get_device()) is None or m_device.type != "xpu"
  1496. else _is_tma_compatible_xpu(m)
  1497. for m in matrices
  1498. )
  1499. def use_triton_tma_template(*matrices: IRNode, add_guards: bool = False) -> bool:
  1500. return (
  1501. all(len(m.get_size()) == 2 for m in matrices)
  1502. and can_use_tma(*matrices, add_guards=add_guards)
  1503. and config.triton.enable_persistent_tma_matmul
  1504. )
  1505. def use_cutlass_template(layout: Layout, m: int, n: int, k: int) -> bool:
  1506. from .virtualized import V
  1507. gemm_size = V.graph.sizevars.size_hint(m * n * k, fallback=-1)
  1508. if gemm_size <= 0 or gemm_size < config.cuda.cutlass_backend_min_gemm_size:
  1509. return False
  1510. from .codegen.cuda.cutlass_utils import try_import_cutlass
  1511. # Do not use cutlass template on ROCm
  1512. if torch.version.hip:
  1513. return False
  1514. # output dtype
  1515. # FP32 not supported: https://github.com/pytorch/pytorch/issues/145952
  1516. layout_dtypes = [torch.float16, torch.bfloat16, torch.int32]
  1517. res = (
  1518. _use_template_for_gpu(layout, layout_dtypes)
  1519. and (config.max_autotune or config.max_autotune_gemm)
  1520. and _use_autotune_backend("CUTLASS")
  1521. )
  1522. if res:
  1523. if not try_import_cutlass():
  1524. log.warning(
  1525. "Failed to import CUTLASS lib. Please check whether "
  1526. "_inductor.config.cuda.cutlass_dir %s is set correctly. "
  1527. "Skipping CUTLASS backend for now.",
  1528. config.cuda.cutlass_dir,
  1529. )
  1530. return False
  1531. return res
  1532. def _use_cutlass_for_op(op_name: str) -> bool:
  1533. """Check if CUTLASS should be used for the given operation."""
  1534. enabled_ops = config.cuda.cutlass_enabled_ops.upper()
  1535. if enabled_ops == "ALL":
  1536. return True
  1537. return op_name.upper() in [x.strip() for x in enabled_ops.split(",")]
  1538. _IntLike: TypeAlias = Union[int, sympy.Expr]
  1539. @functools.cache
  1540. def use_decompose_k_choice(m: _IntLike, n: _IntLike, k: _IntLike) -> bool:
  1541. from torch._inductor.virtualized import V
  1542. decompose_k_threshold = config.triton.decompose_k_threshold
  1543. return (
  1544. not torch.version.hip
  1545. and V.graph.sizevars.statically_known_true(
  1546. sympy.And(
  1547. sympy.Ge(k, decompose_k_threshold * m),
  1548. sympy.Ge(k, decompose_k_threshold * n),
  1549. )
  1550. )
  1551. and not V.graph.aot_mode # TODO: Support AOTI for decomposeK
  1552. and not V.graph.cpp_wrapper
  1553. )
  1554. @functools.cache
  1555. def use_contiguous(m: _IntLike, n: _IntLike, k: _IntLike) -> bool:
  1556. """
  1557. Check if we should use the contiguous subgraph transform.
  1558. This transform makes the second matrix contiguous before the matmul.
  1559. """
  1560. contiguous_threshold = config.rocm.contiguous_threshold
  1561. # Similar conditions to decompose_k but for contiguous transform
  1562. from torch._inductor.virtualized import V
  1563. return (
  1564. bool(torch.version.hip) # Only relevant on AMD
  1565. and V.graph.sizevars.statically_known_true(
  1566. sympy.And(
  1567. sympy.Ge(k, contiguous_threshold * m),
  1568. sympy.Ge(k, contiguous_threshold * n),
  1569. )
  1570. )
  1571. and not V.graph.aot_mode
  1572. and not V.graph.cpp_wrapper
  1573. )
  1574. @functools.cache
  1575. def get_k_splits(m: _IntLike, n: _IntLike, k: _IntLike) -> list[int]:
  1576. # To limit compile time
  1577. k_splits_limit = config.triton.num_decompose_k_splits
  1578. # Hand-tuned
  1579. default_k_splits = [16, 32, 64, 128, 256]
  1580. # If k is a sympy expression, we can't do any splitting
  1581. if isinstance(k, sympy.Expr) and not k.is_number:
  1582. return default_k_splits
  1583. elif k_splits_limit == 0:
  1584. return []
  1585. if (isinstance(m, sympy.Expr) and not m.is_number) or (
  1586. isinstance(n, sympy.Expr) and not n.is_number
  1587. ):
  1588. max_k_split = 256
  1589. else:
  1590. max_k_split = min(k // m, k // n)
  1591. min_k_split = 2
  1592. # Get all divisors of k, k has to be divisible by kPart
  1593. divisors = sympy.divisors(k)
  1594. divisors = [
  1595. divisor
  1596. for divisor in divisors
  1597. if divisor <= max_k_split and divisor >= min_k_split
  1598. ]
  1599. pow_of_2_divisors, mul_of_32_divisors, rest_of_splits = [], [], []
  1600. for d in divisors:
  1601. kPart = k // d
  1602. # Smaller than 128 might not even fit in a single tile, BLOCK_K can be 128
  1603. if kPart < 128:
  1604. continue
  1605. # Power of 2 divisors are best performing, conform to hardware
  1606. if (kPart & kPart - 1) == 0 and kPart >= 128:
  1607. pow_of_2_divisors.append(d)
  1608. # Else check if creates a multiple of 32
  1609. elif kPart % 32 == 0:
  1610. mul_of_32_divisors.append(d)
  1611. # otherwise, take the smallest values
  1612. else:
  1613. rest_of_splits.append(d)
  1614. if config.max_autotune_gemm_search_space == "EXHAUSTIVE":
  1615. return pow_of_2_divisors + mul_of_32_divisors + rest_of_splits
  1616. best_splits = pow_of_2_divisors + mul_of_32_divisors + rest_of_splits
  1617. # Otherwise, conform results to k_splits_limit
  1618. return best_splits[:k_splits_limit]
  1619. @functools.cache
  1620. def _rocm_native_device_arch_name(device: str) -> str:
  1621. return torch.cuda.get_device_properties(device).gcnArchName
  1622. @functools.cache
  1623. def try_import_ck_lib() -> tuple[
  1624. Optional[str], Callable[[], list[Any]], Callable[[], list[Any]], type[Any]
  1625. ]:
  1626. try:
  1627. import ck4inductor # type: ignore[import]
  1628. from ck4inductor.universal_gemm.gen_instances import ( # type: ignore[import]
  1629. gen_ops_library,
  1630. gen_ops_preselected,
  1631. )
  1632. from ck4inductor.universal_gemm.op import ( # type: ignore[import]
  1633. CKGemmOperation,
  1634. )
  1635. package_dirname = os.path.dirname(ck4inductor.__file__)
  1636. except ImportError:
  1637. def gen_ops_library() -> list[Any]:
  1638. return []
  1639. def gen_ops_preselected() -> list[Any]:
  1640. return []
  1641. class CKGemmOperation: # type: ignore[no-redef]
  1642. pass
  1643. package_dirname = None
  1644. return package_dirname, gen_ops_library, gen_ops_preselected, CKGemmOperation
  1645. def use_ck_template(layout: Layout) -> bool:
  1646. # config knobs check 1
  1647. if not (config.max_autotune or config.max_autotune_gemm):
  1648. return False
  1649. # platform check
  1650. if not torch.version.hip:
  1651. return False
  1652. # tensors must be on GPU
  1653. if not layout.device.type == "cuda":
  1654. return False
  1655. # hardware check
  1656. # if config arch list is not specified, get the native arch from the device properties
  1657. native_arch = _rocm_native_device_arch_name(layout.device)
  1658. requested_archs = {k.split(":")[0]: k for k in config.rocm.arch} or {
  1659. native_arch.split(":")[0]: native_arch
  1660. }
  1661. requested_supported_archs = [
  1662. requested_archs[k]
  1663. for k in requested_archs.keys() & config.rocm.ck_supported_arch
  1664. ]
  1665. if not requested_supported_archs:
  1666. return False
  1667. # supported input dtypes
  1668. if layout.dtype not in [torch.float16, torch.bfloat16, torch.float32]:
  1669. return False
  1670. ck_package_dirname, _, _, _ = try_import_ck_lib()
  1671. if not ck_package_dirname:
  1672. log.warning("Please pip install Composable Kernel package")
  1673. return False
  1674. if config.is_fbcode():
  1675. config.rocm.ck_dir = ck_package_dirname
  1676. if not config.rocm.ck_dir:
  1677. log.warning("Please set TORCHINDUCTOR_CK_DIR env variable")
  1678. return False
  1679. if ck_package_dirname != config.rocm.ck_dir:
  1680. log.warning("Invalid path to CK library")
  1681. return False
  1682. return True
  1683. def use_ck_gemm_template(layout: Layout, m: int, n: int, k: int) -> bool:
  1684. from .virtualized import V
  1685. return (
  1686. _use_autotune_backend("CK")
  1687. and use_ck_template(layout)
  1688. and V.graph.sizevars.size_hint(m * n * k, fallback=-1) > 0
  1689. )
  1690. def use_ck_tile_gemm_template(layout: Layout, m: int, n: int, k: int) -> bool:
  1691. from .virtualized import V
  1692. return (
  1693. _use_autotune_backend("CKTILE")
  1694. and use_ck_template(layout)
  1695. and V.graph.sizevars.size_hint(m * n * k, fallback=-1) > 0
  1696. )
  1697. def use_ck_conv_template(layout: Layout) -> bool:
  1698. return _use_conv_autotune_backend("CK") and use_ck_template(layout)
  1699. def _use_template_for_cpu(layout: Layout) -> bool:
  1700. return (
  1701. config.max_autotune or config.max_autotune_gemm
  1702. ) and layout.device.type == "cpu"
  1703. def use_cpp_bmm_template(
  1704. layout: Layout, mat1: Union[ReinterpretView, Buffer], mat2: IRNode
  1705. ) -> bool:
  1706. from .ir import Layout
  1707. assert isinstance(mat1.layout, Layout)
  1708. return (
  1709. use_cpp_gemm_template(layout, mat1, mat2, require_constant_mat2=False)
  1710. and mat1.layout.is_contiguous()
  1711. )
  1712. def use_cpp_gemm_template(
  1713. layout: Layout,
  1714. mat1: IRNode,
  1715. mat2: IRNode,
  1716. mat2_transposed: bool = False,
  1717. require_constant_mat2: bool = True,
  1718. is_woq_int4: bool = False,
  1719. q_group_size: Optional[int] = None,
  1720. ) -> bool:
  1721. from . import ir
  1722. from .codegen.cpp_micro_gemm import create_micro_gemm
  1723. from .codegen.cpp_utils import get_gemm_template_output_and_compute_dtype
  1724. from .kernel.mm_common import mm_args
  1725. if not _use_template_for_cpu(layout) or not _use_autotune_backend("CPP"):
  1726. return False
  1727. if not config.cpp.weight_prepack:
  1728. return False
  1729. int8_gemm = mat1.get_dtype() in [torch.uint8, torch.int8]
  1730. layout_dtypes = [torch.float32, torch.bfloat16, torch.half, torch.uint8]
  1731. m, n, k, layout, mat1, mat2 = mm_args(
  1732. mat1,
  1733. mat2,
  1734. out_dtype=layout.dtype if int8_gemm else None,
  1735. mat2_transposed=mat2_transposed,
  1736. use_4x2_dim=is_woq_int4,
  1737. )
  1738. # TODO(jgong5): support dynamic shapes for n or k
  1739. if has_free_symbols((n, k)):
  1740. return False
  1741. if isinstance(mat2, ir.BaseView):
  1742. mat2 = mat2.unwrap_view()
  1743. output_dtype, _ = get_gemm_template_output_and_compute_dtype(mat1.get_dtype())
  1744. micro_gemm = create_micro_gemm(
  1745. "micro_gemm",
  1746. m,
  1747. n,
  1748. k,
  1749. input_dtype=mat1.get_dtype(),
  1750. input2_dtype=mat2.get_dtype(),
  1751. output_dtype=output_dtype,
  1752. num_threads=parallel_num_threads(),
  1753. use_ref=not is_woq_int4,
  1754. q_group_size=q_group_size,
  1755. )
  1756. def is_last_dim_stride1(x: IRNode) -> bool:
  1757. x.freeze_layout()
  1758. return x.get_stride()[-1] == 1
  1759. return (
  1760. layout.dtype in layout_dtypes
  1761. and micro_gemm is not None
  1762. and is_last_dim_stride1(mat1) # TODO(jgong5): support transposed input
  1763. and isinstance(mat2, ir.StorageBox)
  1764. and (mat2.is_module_buffer() or not require_constant_mat2)
  1765. )
  1766. def use_aten_gemm_kernels() -> bool:
  1767. return not (
  1768. config.max_autotune or config.max_autotune_gemm
  1769. ) or _use_autotune_backend("ATEN")
  1770. class DebugDirManager:
  1771. counter = itertools.count(0)
  1772. prev_debug_name: str
  1773. def __init__(self) -> None:
  1774. self.id = next(DebugDirManager.counter)
  1775. def __enter__(self) -> None:
  1776. self.prev_debug_name = torch._dynamo.config.debug_dir_root
  1777. self.new_name = f"{self.prev_debug_name}_tmp_{self.id}"
  1778. torch._dynamo.config.debug_dir_root = self.new_name
  1779. def __exit__(self, *args: Any) -> None:
  1780. shutil.rmtree(self.new_name)
  1781. torch._dynamo.config.debug_dir_root = self.prev_debug_name
  1782. def run_and_get_code(
  1783. fn: Callable[P, _T],
  1784. *args: P.args,
  1785. **kwargs: P.kwargs,
  1786. ) -> tuple[_T, list[str]]:
  1787. from .graph import GraphLowering
  1788. source_codes: list[str] = []
  1789. def save_output_code(code: str) -> None:
  1790. source_codes.append(code)
  1791. with mock.patch.object(GraphLowering, "save_output_code", save_output_code):
  1792. torch._dynamo.reset()
  1793. result = fn(*args, **kwargs)
  1794. return result, source_codes
  1795. def run_and_get_kernels(
  1796. fn: Callable[P, _T], *args: P.args, **kwargs: P.kwargs
  1797. ) -> tuple[_T, list[str]]:
  1798. result, source_codes = run_and_get_code(fn, *args, **kwargs)
  1799. kernels = []
  1800. for code in source_codes:
  1801. kernels.extend(re.findall(r"'''.*?'''", code, re.DOTALL))
  1802. return result, kernels
  1803. def run_fw_bw_and_get_code(fn: Callable[..., Any]) -> tuple[Any, list[str]]:
  1804. def run_with_backward() -> Any:
  1805. result = fn()
  1806. result.sum().backward()
  1807. return result
  1808. return run_and_get_code(run_with_backward)
  1809. def get_code(fn: Callable[P, _T], *args: P.args, **kwargs: P.kwargs) -> list[str]:
  1810. """Get the inductor-generated code, but skip any actual compilation or running."""
  1811. from .graph import GraphLowering
  1812. source_codes: list[str] = []
  1813. def save_output_code(code: str) -> None:
  1814. source_codes.append(code)
  1815. def patched_compile_to_module(self: GraphLowering) -> Any:
  1816. class DummyModule:
  1817. """This is empty to replace the generated triton module"""
  1818. def __init__(self) -> None:
  1819. pass
  1820. def call(self, *args: Any, **kwargs: Any) -> None:
  1821. # Don't do anything when called
  1822. pass
  1823. wrapper_code, kernel_code = (
  1824. self.codegen_with_cpp_wrapper() if self.cpp_wrapper else self.codegen()
  1825. )
  1826. # Skip all the actual compiling.
  1827. save_output_code(wrapper_code.value)
  1828. if kernel_code:
  1829. save_output_code(kernel_code.value)
  1830. return DummyModule()
  1831. with (
  1832. mock.patch.object(
  1833. GraphLowering, "compile_to_module", patched_compile_to_module
  1834. ),
  1835. mock.patch.object(GraphLowering, "save_output_code", save_output_code),
  1836. ):
  1837. torch._dynamo.reset()
  1838. # Note the return here is None
  1839. _ = fn(*args, **kwargs)
  1840. return source_codes
  1841. def get_triton_code(fn: Callable[P, _T], *args: P.args, **kwargs: P.kwargs) -> str:
  1842. source_codes = get_code(fn, *args, **kwargs)
  1843. # Can have two outputs if backwards was eagerly compiled
  1844. assert 1 <= len(source_codes) <= 2, (
  1845. f"expected one or two code outputs got {len(source_codes)}"
  1846. )
  1847. return source_codes[0]
  1848. def run_and_get_triton_code(
  1849. fn: Callable[P, _T], *args: P.args, **kwargs: P.kwargs
  1850. ) -> str:
  1851. _, source_codes = run_and_get_code(fn, *args, **kwargs)
  1852. # Can have two outputs if backwards was eagerly compiled
  1853. assert 1 <= len(source_codes) <= 2, (
  1854. f"expected one or two code outputs got {len(source_codes)}"
  1855. )
  1856. return source_codes[0]
  1857. def run_and_get_graph_lowering(
  1858. fn: Callable[P, _T], *args: P.args, **kwargs: P.kwargs
  1859. ) -> tuple[Any, list[GraphLowering]]:
  1860. from torch._inductor.graph import GraphLowering
  1861. from torch._inductor.output_code import CompiledFxGraph
  1862. real_init = CompiledFxGraph.__init__
  1863. graph_lowerings = []
  1864. def fake_init(*args: Any, **kwargs: Any) -> None:
  1865. real_init(*args, **kwargs)
  1866. graph = args[2]
  1867. assert isinstance(graph, GraphLowering)
  1868. graph_lowerings.append(graph)
  1869. with mock.patch.object(CompiledFxGraph, "__init__", fake_init):
  1870. result = fn(*args, **kwargs)
  1871. return result, graph_lowerings
  1872. @contextlib.contextmanager
  1873. def override_lowering(
  1874. aten_op: Callable[..., Any], override_fn: Callable[..., Any]
  1875. ) -> Iterator[None]:
  1876. """
  1877. Override the lowering of aten_op with override_fn.
  1878. The first argument of override_fn is the original lowering fn.
  1879. """
  1880. from torch._inductor import lowering
  1881. orig_fn = lowering.lowerings[aten_op]
  1882. try:
  1883. lowering.lowerings[aten_op] = functools.partial(override_fn, orig_fn)
  1884. yield
  1885. finally:
  1886. lowering.lowerings[aten_op] = orig_fn
  1887. def add_scheduler_init_hook(
  1888. pre_fn: Callable[..., Any], post_fn: Optional[Callable[..., Any]] = None
  1889. ) -> Any:
  1890. """
  1891. Add hook functions to be called at the beginning and end of Scheduler.__init__.
  1892. Used for unit tests.
  1893. """
  1894. from torch._inductor.scheduler import Scheduler
  1895. orig_fn = Scheduler.__init__
  1896. def wrapper(scheduler: Any, nodes: Any) -> Any:
  1897. pre_fn(scheduler, nodes)
  1898. out = orig_fn(scheduler, nodes)
  1899. if post_fn:
  1900. post_fn(scheduler, nodes)
  1901. return out
  1902. return unittest.mock.patch.object(Scheduler, "__init__", wrapper)
  1903. def developer_warning(msg: str) -> None:
  1904. """
  1905. Warnings that will be actionable for PyTorch developers, but not
  1906. end users. Allows us to easily disable them in stable releases but
  1907. keep them on for nightly builds.
  1908. """
  1909. if config.developer_warnings:
  1910. log.warning(msg)
  1911. else:
  1912. log.info(msg)
  1913. def get_benchmark_name() -> Optional[str]:
  1914. """
  1915. An experimental API used only when config.benchmark_kernel is true.
  1916. The benchmark name is only available at codegen time. So we can not
  1917. directly call it in benchmark_all_kernels which is run after codegen.
  1918. The function assumes the argument after --only is the benchmark name.
  1919. It works for torchbench.py/hugginface.py/timm_models.py. But for ad-hoc
  1920. scripts, this function may return None.
  1921. There are 2 flavors of --only argument we need handle:
  1922. 1. --only model_name
  1923. 2. --only=model_name
  1924. """
  1925. try:
  1926. idx = sys.argv.index("--only")
  1927. if (
  1928. idx + 1 < len(sys.argv)
  1929. and len(sys.argv[idx + 1]) > 0
  1930. and sys.argv[idx + 1][0] != "-"
  1931. ):
  1932. return sys.argv[idx + 1]
  1933. except ValueError:
  1934. pass
  1935. for arg in sys.argv:
  1936. if arg.startswith("--only="):
  1937. return arg[len("--only=") :]
  1938. return None
  1939. def is_ones(items: Sequence[Any]) -> bool:
  1940. return all(x == 1 for x in items)
  1941. def is_zeros(items: Sequence[Any]) -> bool:
  1942. return all(x == 0 for x in items)
  1943. def is_cpu_device(inputs: Sequence[torch.Tensor]) -> bool:
  1944. return all(
  1945. item.device == torch.device("cpu")
  1946. for item in inputs
  1947. if isinstance(item, torch.Tensor)
  1948. )
  1949. def get_sympy_Expr_dtype(val: sympy.Expr) -> torch.dtype:
  1950. assert isinstance(val, sympy.Expr), (
  1951. "only support sympy.Expr as input to get_sympy_Expr_dtype"
  1952. )
  1953. if val.is_integer: # type: ignore[attr-defined]
  1954. return torch.int64
  1955. else:
  1956. return torch.float64
  1957. @contextlib.contextmanager
  1958. def maybe_profile(should_profile: bool, *args: Any, **kwargs: Any) -> Iterator[Any]:
  1959. if should_profile:
  1960. with torch.profiler.profile(*args, **kwargs) as p:
  1961. yield p
  1962. else:
  1963. yield
  1964. def parallel_num_threads() -> int:
  1965. threads = config.cpp.threads
  1966. if threads < 1:
  1967. threads = torch.get_num_threads()
  1968. return threads
  1969. @functools.cache
  1970. def get_backend_num_stages() -> int:
  1971. from .runtime.triton_helpers import get_backend_options
  1972. options = get_backend_options()
  1973. return options.get("num_stages", 2 if torch.version.hip else 3)
  1974. @functools.cache
  1975. def get_device_tflops(dtype: torch.dtype) -> float:
  1976. """
  1977. We don't want to throw errors in this function. First check to see if the device is in device_info.py,
  1978. then fall back to the inaccurate triton estimation.
  1979. """
  1980. ds_tops = datasheet_tops(dtype, is_tf32=torch.backends.cuda.matmul.allow_tf32)
  1981. if ds_tops is not None:
  1982. return ds_tops
  1983. from triton.testing import get_max_simd_tflops, get_max_tensorcore_tflops
  1984. SM80OrLater = torch.cuda.is_available() and torch.cuda.get_device_capability() >= (
  1985. 8,
  1986. 0,
  1987. )
  1988. assert dtype in (torch.float16, torch.bfloat16, torch.float32)
  1989. if inspect.signature(get_max_simd_tflops).parameters.get("clock_rate"):
  1990. # Triton API change in https://github.com/triton-lang/triton/pull/2293
  1991. from torch._utils_internal import max_clock_rate
  1992. sm_clock = max_clock_rate()
  1993. if dtype in (torch.float16, torch.bfloat16) and SM80OrLater:
  1994. return get_max_tensorcore_tflops(dtype, sm_clock)
  1995. if torch.backends.cuda.matmul.allow_tf32:
  1996. return get_max_tensorcore_tflops(torch.float32, sm_clock)
  1997. else:
  1998. return get_max_simd_tflops(torch.float32, sm_clock)
  1999. else:
  2000. if dtype in (torch.float16, torch.bfloat16) and SM80OrLater:
  2001. return get_max_tensorcore_tflops(dtype)
  2002. if torch.backends.cuda.matmul.allow_tf32:
  2003. return get_max_tensorcore_tflops(torch.float32)
  2004. else:
  2005. return get_max_simd_tflops(torch.float32)
  2006. @functools.cache
  2007. def get_gpu_dram_gbps() -> int:
  2008. from triton.testing import get_dram_gbps
  2009. return get_dram_gbps()
  2010. def get_gpu_shared_memory() -> int:
  2011. from triton.runtime import driver
  2012. return driver.active.utils.get_device_properties(0).get("max_shared_mem", 0)
  2013. def is_welford_reduction(reduction_type: str) -> bool:
  2014. return reduction_type.startswith("welford")
  2015. def reduction_num_outputs(reduction_type: str) -> int:
  2016. if is_welford_reduction(reduction_type):
  2017. return 3
  2018. elif reduction_type == "online_softmax_reduce":
  2019. return 2
  2020. else:
  2021. return 1
  2022. def is_linux() -> bool:
  2023. return platform.system() == "Linux"
  2024. def is_windows() -> bool:
  2025. return sys.platform == "win32"
  2026. def has_free_symbols(itr: Iterable[Any]) -> bool:
  2027. return any(isinstance(x, sympy.Expr) and not x.is_number for x in itr)
  2028. def is_dynamic(*args: Any) -> bool:
  2029. from . import ir
  2030. for t in args:
  2031. if isinstance(
  2032. t, (ir.TensorBox, ir.StorageBox, ir.BaseView, ir.ComputedBuffer, ir.Buffer)
  2033. ):
  2034. if has_free_symbols(t.maybe_get_size() or ()) or has_free_symbols(
  2035. t.maybe_get_stride() or ()
  2036. ):
  2037. return True
  2038. elif not isinstance(t, ir.IRNode):
  2039. continue
  2040. else:
  2041. raise TypeError(f"unexpected type for is_dynamic {type(t)}")
  2042. return False
  2043. # Placeholder strings used in triton codegen.
  2044. class Placeholder(enum.Enum):
  2045. # The placeholder for the actual name of a triton kernel.
  2046. # e.g. for "def triton_" it would be "triton_"
  2047. KERNEL_NAME = "KERNEL_NAME"
  2048. # The descriptive name of the triton kernel; when unique_kernel_names = False, this
  2049. # placeholder will be replaced with a string with more information.
  2050. DESCRIPTIVE_NAME = "DESCRIPTIVE_NAME"
  2051. def pass_execution_and_save(
  2052. func: Callable[..., Any], gm: GraphModule, inp: Sequence[Any], msg: str
  2053. ) -> None:
  2054. from .pattern_matcher import stable_topological_sort
  2055. with tempfile.NamedTemporaryFile(
  2056. mode="w",
  2057. encoding="utf-8",
  2058. delete=False,
  2059. ) as f:
  2060. before_io = io.StringIO()
  2061. after_io = io.StringIO()
  2062. ShapeProp(gm=gm, fake_mode=detect_fake_mode(inp)).propagate(*inp)
  2063. print(f"Before:\n{gm.graph}", file=f)
  2064. print(gm.graph, file=before_io)
  2065. start_time = datetime.now()
  2066. with GraphTransformObserver(gm, msg):
  2067. func(gm.graph)
  2068. time_elapsed = datetime.now() - start_time
  2069. # recompile graph
  2070. stable_topological_sort(gm.graph)
  2071. gm.graph.lint()
  2072. gm.recompile()
  2073. print(f"After:\n{gm.graph}", file=f)
  2074. print(gm.graph, file=after_io)
  2075. t = before_io.getvalue() == after_io.getvalue()
  2076. log.info(
  2077. "%s, save before/after graph to %s, graph before/after are the same = %s, time elapsed = %s",
  2078. msg,
  2079. f.name,
  2080. t,
  2081. time_elapsed,
  2082. )
  2083. def is_multi_outputs_template(input_buf: Optional[Union[Buffer, Operation]]) -> bool:
  2084. """
  2085. Check if input buffer is a multi-outputs template buffer
  2086. """
  2087. from . import ir
  2088. return isinstance(input_buf, ir.CppTemplateBuffer) and isinstance(
  2089. input_buf.layout, ir.MultiOutputLayout
  2090. )
  2091. def is_output_of_multi_outputs_template(
  2092. input_buf: Optional[Union[Buffer, Operation]],
  2093. ) -> bool:
  2094. """
  2095. Check if input buffer is a output of multi-outputs template buffer
  2096. """
  2097. from . import ir
  2098. return (
  2099. isinstance(input_buf, ir.MultiOutput)
  2100. and len(input_buf.inputs) == 1
  2101. and is_multi_outputs_template(input_buf.inputs[0]) # type: ignore[arg-type]
  2102. )
  2103. def is_collective(
  2104. node: Optional[Union[Node, Operation]],
  2105. op: Optional[torch._ops.OperatorBase] = None,
  2106. ) -> bool:
  2107. if node is None:
  2108. return False
  2109. from . import ir
  2110. return (
  2111. isinstance(node, ir._CollectiveKernel)
  2112. and not isinstance(node, ir._WaitKernel)
  2113. and (op is None or node.op_overload is op)
  2114. ) or (
  2115. # TODO: this is a temporary solution to ensure that we can identify torchrec's
  2116. # communication ops. But in order to allow better communication and computation
  2117. # overlap, torchrec's communication ops should be not used.
  2118. type(node) == ir.FallbackKernel
  2119. and (
  2120. # NOTE: the `hasattr()` check is to bypass errors such as the following:
  2121. # AttributeError: '_OpNamespace' 'torchrec' object has no attribute 'all_to_all_single'
  2122. (
  2123. hasattr(torch.ops.torchrec, "all_to_all_single")
  2124. and node.op_overload == torch.ops.torchrec.all_to_all_single.default
  2125. )
  2126. or (
  2127. hasattr(torch.ops.torchrec, "all_gather_into_tensor")
  2128. and node.op_overload
  2129. == torch.ops.torchrec.all_gather_into_tensor.default
  2130. )
  2131. or (
  2132. hasattr(torch.ops.torchrec, "reduce_scatter_tensor")
  2133. and node.op_overload == torch.ops.torchrec.reduce_scatter_tensor.default
  2134. )
  2135. )
  2136. )
  2137. def is_wait(node: Optional[Union[IRNode, Operation]]) -> bool:
  2138. from . import ir
  2139. return type(node) == ir._WaitKernel
  2140. def contains_collective(snode: BaseSchedulerNode) -> bool:
  2141. from torch._inductor.scheduler import GroupedSchedulerNode
  2142. if isinstance(snode, GroupedSchedulerNode):
  2143. return any(contains_collective(x) for x in snode.snodes)
  2144. return is_collective(snode.node)
  2145. def contains_wait(snode: BaseSchedulerNode) -> bool:
  2146. from torch._inductor.scheduler import GroupedSchedulerNode
  2147. if isinstance(snode, GroupedSchedulerNode):
  2148. return any(contains_wait(x) for x in snode.snodes)
  2149. else:
  2150. return is_wait(snode.node)
  2151. def is_fallback_op(
  2152. node: Optional[Operation],
  2153. op: Union[torch._ops.OpOverload, Collection[torch._ops.OpOverload]],
  2154. ) -> bool:
  2155. from . import ir
  2156. if isinstance(op, torch._ops.OpOverload):
  2157. op = [op]
  2158. return isinstance(node, ir.FallbackKernel) and node.op_overload in op
  2159. def buf_name_to_fused_snode(
  2160. buf_name: str, name_to_buf: dict[str, Any], name_to_fused_node: dict[str, Any]
  2161. ) -> Any:
  2162. return name_to_fused_node[name_to_buf[buf_name].defining_op.get_name()]
  2163. def find_recursive_deps_of_node(
  2164. snode: BaseSchedulerNode,
  2165. collected_node_set: MutableSet[BaseSchedulerNode],
  2166. name_to_buf: dict[str, SchedulerBuffer],
  2167. name_to_fused_node: dict[str, BaseSchedulerNode],
  2168. criteria_cb: Callable[[Any], bool] = lambda snode: False,
  2169. ) -> None:
  2170. if criteria_cb(snode):
  2171. return
  2172. collected_node_set.add(snode)
  2173. for dep in snode.unmet_dependencies:
  2174. defining_op_for_dep = buf_name_to_fused_snode(
  2175. dep.name, name_to_buf, name_to_fused_node
  2176. )
  2177. if defining_op_for_dep in collected_node_set:
  2178. continue
  2179. find_recursive_deps_of_node(
  2180. defining_op_for_dep,
  2181. collected_node_set,
  2182. name_to_buf,
  2183. name_to_fused_node,
  2184. criteria_cb=criteria_cb,
  2185. )
  2186. def find_recursive_users_of_node(
  2187. snode: BaseSchedulerNode,
  2188. collected_node_set: MutableSet[BaseSchedulerNode],
  2189. name_to_buf: dict[str, SchedulerBuffer],
  2190. name_to_fused_node: dict[str, BaseSchedulerNode],
  2191. criteria_cb: Callable[[Any], bool] = lambda snode: False,
  2192. ) -> None:
  2193. if criteria_cb(snode):
  2194. return
  2195. collected_node_set.add(snode)
  2196. for o in snode.get_outputs():
  2197. for user in o.users:
  2198. assert user.node is not None
  2199. if user.node.get_name() == "OUTPUT":
  2200. continue
  2201. if user.node.get_name() not in name_to_fused_node:
  2202. continue
  2203. user_op = name_to_fused_node[user.node.get_name()]
  2204. if user_op in collected_node_set:
  2205. continue
  2206. find_recursive_users_of_node(
  2207. user_op,
  2208. collected_node_set,
  2209. name_to_buf,
  2210. name_to_fused_node,
  2211. criteria_cb=criteria_cb,
  2212. )
  2213. def num_fw_fixed_arguments(dynamo_gm_num_inputs: int, aot_fw_gm_num_inputs: int) -> int:
  2214. "Computes the number of inputs to the aot fw graph which have fixed addresses (params and buffers)"
  2215. num_rng_seed_offset_inputs = (
  2216. 2 if torch._functorch.config.functionalize_rng_ops else 0
  2217. )
  2218. # AOT won't lift any parameters if we're inlining NN Modules
  2219. # however desugaring subclasses will still add arguments
  2220. # resulted in extra fixed inputs https://github.com/pytorch/pytorch/issues/130502
  2221. return aot_fw_gm_num_inputs - dynamo_gm_num_inputs - num_rng_seed_offset_inputs
  2222. def count_tangents(fx_g: torch.fx.GraphModule) -> int:
  2223. """
  2224. Infers which inputs are static for a backwards graph
  2225. """
  2226. def is_saved_tensor(x: Node) -> bool:
  2227. return (
  2228. "tangents" not in x.name
  2229. and "bwd_seed" not in x.name
  2230. and "bwd_base_offset" not in x.name
  2231. and "bwd_rng_state" not in x.name
  2232. )
  2233. arg_count = 0
  2234. static_arg_idxs = []
  2235. for n in fx_g.graph.nodes:
  2236. if n.op == "placeholder":
  2237. if is_saved_tensor(n):
  2238. static_arg_idxs.append(arg_count)
  2239. arg_count += 1
  2240. assert static_arg_idxs == list(range(len(static_arg_idxs)))
  2241. return len(static_arg_idxs)
  2242. @dataclasses.dataclass
  2243. class BoxedBool:
  2244. value: bool
  2245. def __bool__(self) -> bool:
  2246. return self.value
  2247. @staticmethod
  2248. def disable(obj: Any) -> Union[BoxedBool, bool]:
  2249. if isinstance(obj, BoxedBool):
  2250. obj.value = False
  2251. return obj
  2252. return False
  2253. @contextlib.contextmanager
  2254. def collect_defined_kernels(kernel_list: list[str]) -> Iterator[None]:
  2255. from .codegen.wrapper import PythonWrapperCodegen
  2256. orig_define_kernel = PythonWrapperCodegen.define_kernel
  2257. def define_kernel(
  2258. self: PythonWrapperCodegen,
  2259. kernel_name: str,
  2260. kernel_code: str,
  2261. metadata: Optional[str] = None,
  2262. gpu: bool = True,
  2263. cpp_definition: Optional[str] = None,
  2264. ) -> Any:
  2265. kernel_list.append(kernel_code)
  2266. return orig_define_kernel(
  2267. self, kernel_name, kernel_code, metadata, gpu, cpp_definition
  2268. )
  2269. with mock.patch.object(PythonWrapperCodegen, "define_kernel", define_kernel):
  2270. yield
  2271. def get_cloned_parameter_buffer_name(name: str) -> str:
  2272. return name + "__original__"
  2273. def is_gpu(device: Optional[str]) -> bool:
  2274. return device in GPU_TYPES
  2275. def device_need_guard(device: str) -> bool:
  2276. return device != "mps" and is_gpu(device) # TODO: MPS does not expose streams now
  2277. def needs_fallback_due_to_atomic_add_limitations(dtype: torch.dtype) -> bool:
  2278. # tl.atomic add has bfloat16 support in fbcode
  2279. # but not in OSS https://github.com/pytorch/pytorch/issues/97016
  2280. # we will fallback until the code is upstreamed to OSS
  2281. if (
  2282. config.is_fbcode()
  2283. and dtype == torch.bfloat16
  2284. and torch.cuda.is_available()
  2285. and torch.cuda.get_device_capability() >= (9, 0)
  2286. and config.bfloat16_atomic_adds_enabled
  2287. ):
  2288. return False
  2289. else:
  2290. return dtype in OrderedSet([torch.int64, torch.bool, torch.bfloat16])
  2291. def use_scatter_fallback(
  2292. op_overload: torch._ops.OpOverload,
  2293. reduction_type: Optional[str],
  2294. self_dtype: torch.dtype,
  2295. src_dtype: torch.dtype,
  2296. src_device_type: str,
  2297. src_is_tensor: bool,
  2298. ) -> bool:
  2299. if (
  2300. op_overload.overloadpacket
  2301. in (torch.ops.aten.scatter_reduce_, torch.ops.aten.scatter_reduce)
  2302. and reduction_type is None
  2303. ):
  2304. return False
  2305. reduce_ty = (
  2306. "add" if op_overload.overloadpacket == torch.ops.aten.scatter_ else "sum"
  2307. )
  2308. return (
  2309. reduction_type not in (None, reduce_ty)
  2310. or (
  2311. src_is_tensor
  2312. and is_gpu(src_device_type)
  2313. and needs_fallback_due_to_atomic_add_limitations(src_dtype)
  2314. )
  2315. or (
  2316. op_overload.overloadpacket == torch.ops.aten.scatter_reduce_
  2317. and reduction_type == "sum"
  2318. and src_is_tensor
  2319. and src_device_type == "cpu"
  2320. and config.cpp.fallback_scatter_reduce_sum
  2321. and (config.cpp.dynamic_threads or parallel_num_threads() != 1)
  2322. )
  2323. or (reduction_type == reduce_ty and self_dtype in (torch.bool, torch.int64))
  2324. or torch.are_deterministic_algorithms_enabled()
  2325. )
  2326. def dump_node_schedule(node_schedule: Sequence[BaseSchedulerNode]) -> None:
  2327. """
  2328. An API that can be used in pdb to dump a node_schedule.
  2329. Right mainly dump the read/write dependencies but can add more as needed.
  2330. """
  2331. from torch._inductor.codegen.simd import DisableReduction, EnableReduction
  2332. from torch._inductor.scheduler import SchedulerNode
  2333. print(f"Node schedule with {len(node_schedule)} nodes")
  2334. for idx, node in enumerate(node_schedule):
  2335. print(f" {idx:3}:")
  2336. if node is EnableReduction:
  2337. print("enable reduction")
  2338. elif node is DisableReduction:
  2339. print("disable reduction")
  2340. elif isinstance(node, SchedulerNode):
  2341. is_red = node.is_reduction()
  2342. print(f"{'red' if is_red else 'pw'} scheduler node")
  2343. if is_red:
  2344. assert node.node is not None
  2345. print(f"original reduction hint {node.node.data.reduction_hint}") # type: ignore[attr-defined]
  2346. print("ReadDep:")
  2347. for dep in node.read_writes.reads:
  2348. print(dep)
  2349. print("WriteDep:")
  2350. for dep in node.read_writes.writes:
  2351. print(dep)
  2352. else:
  2353. raise RuntimeError(f"Unrecognized node type: {type(node)}")
  2354. def tensor_is_aligned(tensor: torch.Tensor) -> bool:
  2355. # See Note: [Input Alignment handling in Inductor]
  2356. # Right now, we don't try to guard on the alignment of the storage offset.
  2357. # When this comment was written, non-symbolic storage_offsets are not guarded on
  2358. # but symbolic storage_offsets are. For consistency, we suppress guard creation
  2359. # upon performing this check: that ensures that we don't add recompiles when we
  2360. # add this logic.
  2361. from torch.fx.experimental.symbolic_shapes import statically_known_true
  2362. return statically_known_true(
  2363. (tensor.storage_offset() * get_dtype_size(tensor.dtype)) % GPU_ALIGN_BYTES == 0
  2364. )
  2365. def should_assume_input_aligned(example_input: torch.Tensor) -> bool:
  2366. # See Note: [Input Alignment handling in Inductor]
  2367. # right now, we only care about alignment for cuda tensors.
  2368. if not is_gpu(example_input.device.type):
  2369. return False
  2370. return config.assume_aligned_inputs or tensor_is_aligned(example_input)
  2371. def maybe_get_suppress_shape_guards_ctx() -> contextlib.AbstractContextManager[None]:
  2372. # Try to get TracingContext.try_get().fake_mode.shape_env.suppress_guards()
  2373. # If it's not available, return a nullcontext.
  2374. # If we're dealing with cudagraphs, we might not have a tracing_context
  2375. tracing_context = torch._guards.TracingContext.try_get()
  2376. if not tracing_context:
  2377. return contextlib.nullcontext()
  2378. # In standalone inductor compile mode, we might not have a shape_env attached to the fake mode
  2379. if not tracing_context.fake_mode or not tracing_context.fake_mode.shape_env:
  2380. return contextlib.nullcontext()
  2381. shape_env = tracing_context.fake_mode.shape_env
  2382. return shape_env.suppress_guards()
  2383. def run_and_get_cpp_code(
  2384. fn: Callable[P, _T], *args: P.args, **kwargs: P.kwargs
  2385. ) -> tuple[_T, str]:
  2386. # We use the patch context manager instead of using it as a decorator.
  2387. # In this way, we can ensure that the attribute is patched and unpatched correctly
  2388. # even if this run_and_get_cpp_code function is called multiple times.
  2389. with unittest.mock.patch.object(config, "debug", True):
  2390. torch._dynamo.reset()
  2391. import io
  2392. import logging
  2393. log_capture_string = io.StringIO()
  2394. ch = logging.StreamHandler(log_capture_string)
  2395. from torch._inductor.codecache import output_code_log
  2396. output_code_log.addHandler(ch)
  2397. prev_level = output_code_log.level
  2398. output_code_log.setLevel(logging.DEBUG)
  2399. result = fn(*args, **kwargs)
  2400. s = log_capture_string.getvalue()
  2401. output_code_log.setLevel(prev_level)
  2402. output_code_log.removeHandler(ch)
  2403. return result, s
  2404. def shape_env_from_inputs(inputs: Sequence[InputType]) -> Optional[ShapeEnv]:
  2405. fake_mode = detect_fake_mode(inputs)
  2406. # TODO(voz): It would be nice to enable this assert, but there are lots of tests that
  2407. # pass in real inputs for now.
  2408. # if len(inputs) > 0:
  2409. # assert fake_mode is not None, breakpoint()
  2410. if fake_mode is not None:
  2411. return fake_mode.shape_env
  2412. # When there are no tensor inputs, get shape_env from the first SymInt.
  2413. for input in inputs:
  2414. if isinstance(input, torch.SymInt):
  2415. return input.node.shape_env
  2416. # TODO(voz): Should we always have one anyway?
  2417. return None
  2418. def align_inputs_from_check_idxs(
  2419. model: Callable[[list[InputType]], _T],
  2420. inputs_to_check: Sequence[int],
  2421. mutated_input_idxs: OrderedSet[int],
  2422. ) -> Callable[[list[InputType]], _T]:
  2423. if len(inputs_to_check) == 0:
  2424. return model
  2425. def run(new_inputs: list[InputType]) -> Any:
  2426. old_tensors, new_tensors = copy_misaligned_inputs(
  2427. new_inputs, inputs_to_check, mutated_input_idxs
  2428. )
  2429. out = model(new_inputs)
  2430. # If a mutated tensor was cloned to be aligned, we need to reflect back the mutation to the
  2431. # original tensor.
  2432. if len(old_tensors):
  2433. torch._foreach_copy_(old_tensors, new_tensors)
  2434. return out
  2435. return run
  2436. def clone_preserve_strides(x: torch.Tensor) -> torch.Tensor:
  2437. if 0 in x.size():
  2438. # Short-circuits if the shape has no elements
  2439. needed_size = 0
  2440. else:
  2441. needed_size = (
  2442. sum((shape - 1) * stride for shape, stride in zip(x.size(), x.stride())) + 1
  2443. )
  2444. buffer = torch.as_strided(x, (needed_size,), (1,)).clone()
  2445. return torch.as_strided(buffer, x.size(), x.stride())
  2446. def copy_misaligned_inputs(
  2447. new_inputs: list[InputType],
  2448. check_inputs_idxs: Sequence[int],
  2449. return_pair_idxs: Optional[OrderedSet[int]] = None,
  2450. ) -> tuple[list[torch.Tensor], list[torch.Tensor]]:
  2451. """
  2452. Clones misaligned tensors which we inferred were aligned. Returns a tuple of [old_tensors], [new_tensors] for every
  2453. cloned tensor which is in `return_pair_idxs`.
  2454. """
  2455. old_tensors: list[torch.Tensor] = []
  2456. new_tensors: list[torch.Tensor] = []
  2457. # hoist above loop because this is on the hot path
  2458. ret_pair_defined = return_pair_idxs is not None
  2459. for i in check_inputs_idxs:
  2460. _inp = new_inputs[i]
  2461. assert isinstance(_inp, torch.Tensor), (
  2462. f"Expected tensors only, but got: {type(_inp)}"
  2463. )
  2464. if _inp.data_ptr() % ALIGNMENT:
  2465. new_inputs[i] = clone_preserve_strides(_inp)
  2466. if ret_pair_defined and i in return_pair_idxs: # type: ignore[operator]
  2467. old_tensors.append(_inp)
  2468. new_tensors.append(new_inputs[i]) # type: ignore[arg-type]
  2469. return old_tensors, new_tensors
  2470. def remove_unaligned_input_idxs(
  2471. inputs: Sequence[InputType],
  2472. static_input_idxs: Sequence[int],
  2473. ) -> Sequence[int]:
  2474. """
  2475. We require all inputs to be aligned, so introduce a copy for any
  2476. that aren't.
  2477. """
  2478. aligned_static_input_idxs = []
  2479. for idx in static_input_idxs:
  2480. input = inputs[idx]
  2481. if isinstance(input, torch.Tensor) and (input.data_ptr() % ALIGNMENT) == 0:
  2482. aligned_static_input_idxs.append(idx)
  2483. if len(aligned_static_input_idxs) != len(static_input_idxs):
  2484. return aligned_static_input_idxs
  2485. return static_input_idxs
  2486. def expr_fits_within_32bit(e: sympy.Expr) -> bool:
  2487. from .virtualized import V
  2488. int_max = torch.iinfo(torch.int32).max
  2489. size_hint = V.graph.sizevars.size_hint
  2490. has_hint = V.graph.sizevars.shape_env.has_hint
  2491. # Allow for unhinted e as long as we can still statically prove
  2492. # (e.g., via ValueRanges) that it is still in bounds
  2493. if V.graph.sizevars.statically_known_true(e <= int_max):
  2494. return True
  2495. # AOTI doesn't guard on < 2**32, so checking hints isn't a viable option,
  2496. # in case the hinted value is < 2**32, but the allowed range is larger.
  2497. # However, to prevent possible perf regressions on pre-existing AOTI models
  2498. # which don't set an upper bound on the valid range, we'll skip the check.
  2499. # To recap:
  2500. # - If using AOTI:
  2501. # - If allowed range has no upper bound, then check the hint to determine
  2502. # whether this fits in int32
  2503. # - If allowed range does have an upper bound, then obey the upper bound
  2504. # (check whether upper bound < int32_max) without checking the hint.
  2505. if V.aot_compilation:
  2506. # check whether value has an upper bound (1e20 is > INT64_MAX, assume
  2507. # there is no upper bound if it can be larger than 1e20)
  2508. if V.graph.sizevars.statically_known_true(e < 1e20):
  2509. # if so, then assume int_max < upper bound < inf
  2510. # so this could potentially have int64 values
  2511. return False
  2512. # Otherwise, the hint MUST exist and be in range
  2513. return has_hint(e) and size_hint(e) <= int_max
  2514. def set_tracing_context_output_strides(
  2515. example_inputs: Sequence[Any], compiled_graph: CompiledFxGraph
  2516. ) -> None:
  2517. # Return the output strides to the caller via TracingContext
  2518. context = torch._guards.TracingContext.try_get()
  2519. if context is not None and context.output_strides is not None:
  2520. assert len(context.output_strides) == 0
  2521. shape_env = shape_env_from_inputs(example_inputs)
  2522. assert compiled_graph.output_strides is not None
  2523. for exprs in compiled_graph.output_strides:
  2524. if exprs is None:
  2525. context.output_strides.append(None)
  2526. else:
  2527. fakify_first_call = False
  2528. if ctx := torch._guards.TracingContext.try_get():
  2529. fakify_first_call = ctx.fakify_first_call
  2530. def map_expr(e: Any) -> Union[float, int, SymInt, SymFloat, SymBool]:
  2531. if shape_env is None:
  2532. return int(e)
  2533. if fakify_first_call:
  2534. return shape_env.deserialize_symexpr(e)
  2535. return shape_env.evaluate_symexpr(e)
  2536. context.output_strides.append(
  2537. tuple(map_expr(e) for e in exprs) # type: ignore[misc]
  2538. )
  2539. def should_use_remote_fx_graph_cache() -> bool:
  2540. if config.fx_graph_remote_cache is not None:
  2541. return config.fx_graph_remote_cache
  2542. if not config.is_fbcode():
  2543. return False
  2544. if torch._utils_internal.is_fb_unit_test():
  2545. return False
  2546. try:
  2547. from torch._inductor.fb.remote_cache import REMOTE_CACHE_VERSION
  2548. except ModuleNotFoundError:
  2549. return False
  2550. return REMOTE_CACHE_VERSION >= torch._utils_internal.justknobs_getval_int(
  2551. "pytorch/remote_cache:fx_graph_memcache_version"
  2552. )
  2553. def normalize_name(name: str) -> str:
  2554. return re.sub(r"[^a-zA-Z0-9_]", "_", name)
  2555. # correct cases where Triton types names don't match PyTorch
  2556. _triton_type_mapping = {
  2557. "tl.bool": "tl.int1",
  2558. "tl.float8_e4m3fn": "tl.float8e4nv",
  2559. "tl.float8_e5m2": "tl.float8e5",
  2560. "tl.float8_e4m3fnuz": "tl.float8e4b8",
  2561. "tl.float8_e5m2fnuz": "tl.float8e5b16",
  2562. # TODO: remove when support is added in triton
  2563. # https://github.com/triton-lang/triton/issues/6054
  2564. "tl.float8_e8m0fnu": "tl.uint8",
  2565. "tl.float4_e2m1fn_x2": "tl.uint8",
  2566. }
  2567. _torch_triton_mapping = {v: k for k, v in _triton_type_mapping.items()}
  2568. _triton_type_re = re.compile(r"^.*[.]")
  2569. def triton_type(dtype: torch.dtype) -> str:
  2570. """Convert torch.dtype to triton type"""
  2571. triton_type_name = _triton_type_re.sub("tl.", str(dtype))
  2572. return _triton_type_mapping.get(triton_type_name, triton_type_name)
  2573. def triton_type_to_torch(dtype: str) -> torch.dtype:
  2574. adjusted_type = _torch_triton_mapping.get(dtype, dtype)
  2575. type_name = adjusted_type.replace("tl.", "")
  2576. out_dtype = getattr(torch, type_name)
  2577. assert isinstance(out_dtype, torch.dtype)
  2578. return out_dtype
  2579. def is_same_tensor(data: torch.Tensor, value: torch.Tensor) -> bool:
  2580. return (
  2581. not data.is_mkldnn
  2582. and data.size() == value.size()
  2583. and data.stride() == value.stride()
  2584. and data.dtype == value.dtype
  2585. and data.device == value.device
  2586. and data.untyped_storage().data_ptr() == value.untyped_storage().data_ptr()
  2587. and data.storage_offset() == value.storage_offset()
  2588. )
  2589. def is_same_mkldnn_tensor(data: torch.Tensor, value: torch.Tensor) -> bool:
  2590. return (
  2591. data.is_mkldnn
  2592. and data.size() == value.size()
  2593. and data.dtype == value.dtype
  2594. and data.device == value.device
  2595. and torch.ops.mkldnn.data_ptr(data) == torch.ops.mkldnn.data_ptr(value)
  2596. )
  2597. @functools.cache
  2598. def boolean_ops() -> tuple[str, ...]:
  2599. return (
  2600. "isinf",
  2601. "isnan",
  2602. "logical_not",
  2603. "logical_and",
  2604. "signbit",
  2605. "and_",
  2606. "le",
  2607. "lt",
  2608. "ge",
  2609. "gt",
  2610. "eq",
  2611. "ne",
  2612. "or_", # TODO should remove this op
  2613. "xor",
  2614. )
  2615. @dataclasses.dataclass
  2616. class OpDtypeRule:
  2617. type_promotion_kind: ELEMENTWISE_TYPE_PROMOTION_KIND
  2618. override_return_dtype: Optional[torch.dtype]
  2619. op_dtype_propagation_rules: dict[str, OpDtypeRule] = {}
  2620. def register_op_dtype_propagation_rules(
  2621. name: str,
  2622. type_promotion_kind: ELEMENTWISE_TYPE_PROMOTION_KIND,
  2623. override_return_dtype: Optional[torch.dtype],
  2624. ) -> None:
  2625. op_dtype_propagation_rules[name] = OpDtypeRule(
  2626. type_promotion_kind, override_return_dtype
  2627. )
  2628. op_requires_libdevice_fp64: OrderedSet[str] = OrderedSet()
  2629. def register_op_requires_libdevice_fp64(name: str) -> None:
  2630. op_requires_libdevice_fp64.add(name)
  2631. def get_current_backend() -> str:
  2632. from torch._inductor.virtualized import V
  2633. device_str = V.graph.get_current_device_or_throw().type
  2634. if device_str == "cpu":
  2635. return config.cpu_backend
  2636. elif device_str == "mps":
  2637. return "mps"
  2638. else:
  2639. return config.cuda_backend
  2640. def upcast_compute_type(dtype: torch.dtype) -> torch.dtype:
  2641. """Maybe upcast [b]float16 to float32"""
  2642. if (
  2643. dtype in (torch.float16, torch.bfloat16)
  2644. and config.triton.codegen_upcast_to_fp32
  2645. and get_current_backend() == "triton"
  2646. ):
  2647. return torch.float32
  2648. return dtype
  2649. KeyType = TypeVar("KeyType")
  2650. ValType = TypeVar("ValType")
  2651. class ScopedDict(MutableMapping[KeyType, ValType]):
  2652. """
  2653. A dictionary-like object that allows for scoped updates. It maintains
  2654. an original dictionary and a set of new items that can override
  2655. the original items within the scope. The original dictionary is
  2656. unmodified.
  2657. """
  2658. def __init__(self, original_dict: Mapping[KeyType, ValType]):
  2659. self.original_dict = original_dict
  2660. self.new_items: dict[KeyType, ValType] = {}
  2661. def __getitem__(self, key: KeyType) -> ValType:
  2662. if key in self.new_items:
  2663. return self.new_items[key]
  2664. return self.original_dict[key]
  2665. def __setitem__(self, key: KeyType, value: ValType) -> None:
  2666. self.new_items[key] = value
  2667. def __contains__(self, key: object) -> bool:
  2668. return key in self.new_items or key in self.original_dict
  2669. def get(self, key: KeyType, default: Optional[ValType] = None) -> Optional[ValType]: # type: ignore[override]
  2670. if key in self.new_items:
  2671. return self.new_items[key]
  2672. return self.original_dict.get(key, default)
  2673. def __len__(self) -> int:
  2674. n = len(self.original_dict)
  2675. for k in self.new_items:
  2676. if k not in self.original_dict:
  2677. n += 1
  2678. return n
  2679. def __iter__(self) -> Iterator[KeyType]:
  2680. yield from self.original_dict
  2681. for k in self.new_items:
  2682. if k not in self.original_dict:
  2683. yield k
  2684. def __bool__(self) -> bool:
  2685. return bool(self.original_dict or self.new_items)
  2686. def __delitem__(self, key: KeyType) -> None:
  2687. raise NotImplementedError
  2688. @dataclass_transform(frozen_default=True)
  2689. def ir_dataclass(cls: Optional[type[Any]] = None, /, *, frozen: bool = True) -> Any:
  2690. def wrap(cls: _T) -> _T:
  2691. if sys.version_info >= (3, 10):
  2692. return dataclasses.dataclass(cls, kw_only=True, frozen=frozen) # type: ignore[call-overload]
  2693. else:
  2694. # Polyfill for python=3.9. kw_only simply introduces an extra check
  2695. # that only kwargs are used (and is not available on 3.9)
  2696. return dataclasses.dataclass(cls, frozen=frozen)
  2697. if cls is None:
  2698. return wrap
  2699. return wrap(cls)
  2700. def get_donated_idxs() -> Optional[list[int]]:
  2701. tracing_context = torch._guards.TracingContext.try_get()
  2702. if tracing_context is not None and tracing_context.fw_metadata:
  2703. return tracing_context.fw_metadata.bw_donated_idxs
  2704. return None
  2705. class TritonAttrsDescriptorVersion(enum.Enum):
  2706. V0_NO_TRITON = 0
  2707. V1_COMPILER = 1 # triton.compiler.compiler.AttrsDescriptor
  2708. V2_BACKENDS = 2 # triton.backends.compiler.AttrsDescriptor
  2709. V3_BACKENDS_TUPLE = (
  2710. 3 # triton.backends.compiler.AttrsDescriptor, but with tuple support
  2711. )
  2712. V4_DICT = 4 # a raw dict
  2713. @functools.cache
  2714. def get_triton_attrs_descriptor_version() -> TritonAttrsDescriptorVersion:
  2715. if importlib.util.find_spec("triton") is None:
  2716. return TritonAttrsDescriptorVersion.V0_NO_TRITON
  2717. import triton.backends.compiler
  2718. import triton.compiler.compiler
  2719. if hasattr(triton.backends.compiler, "AttrsDescriptor"):
  2720. # Triton 3.2.0
  2721. # AttrsDescriptor was moved from triton.compiler.compiler to triton.backends.compiler.
  2722. # AttrsDescriptor and its serialization format were also changed.
  2723. # TODO: implement V3_BACKENDS_TUPLE
  2724. # On Dec 9, 2024, tuple support (triton #5220) was implemented and breaks handling.
  2725. # We don't have a way to detect this (and haven't implemented this version)
  2726. return TritonAttrsDescriptorVersion.V2_BACKENDS
  2727. elif hasattr(triton.compiler.compiler, "AttrsDescriptor"):
  2728. # Triton 3.0.0
  2729. return TritonAttrsDescriptorVersion.V1_COMPILER
  2730. else:
  2731. # After Jan 1, 2025
  2732. # AttrsDescriptor was removed and replaced with a raw dict.
  2733. return TritonAttrsDescriptorVersion.V4_DICT
  2734. def triton_version_uses_attrs_dict() -> bool:
  2735. return get_triton_attrs_descriptor_version() == TritonAttrsDescriptorVersion.V4_DICT
  2736. def is_cudagraph_unsafe_op(node: Operation) -> bool:
  2737. """
  2738. Returns True if the node is an op that is not cudagraphable.
  2739. Usually only custom ops have this tag.
  2740. """
  2741. from . import ir
  2742. if not isinstance(node, ir.FallbackKernel):
  2743. return False
  2744. if (
  2745. isinstance(node.op_overload, torch._ops.OpOverload)
  2746. and torch._C.Tag.cudagraph_unsafe in node.op_overload.tags # type: ignore[attr-defined]
  2747. ):
  2748. return True
  2749. return False
  2750. def get_ld_library_path() -> str:
  2751. path = os.environ.get("LD_LIBRARY_PATH", "")
  2752. if config.is_fbcode():
  2753. from libfb.py.parutil import get_runtime_path
  2754. runtime_path = get_runtime_path()
  2755. if runtime_path:
  2756. lib_path = os.path.join(runtime_path, "runtime", "lib")
  2757. path = os.pathsep.join([lib_path, path]) if path else lib_path
  2758. return path
  2759. def is_codegen_graph_partition_subgraph(wrapper: PythonWrapperCodegen) -> bool:
  2760. from torch._inductor.codegen.wrapper import SubgraphPythonWrapperCodegen
  2761. return (
  2762. isinstance(wrapper, SubgraphPythonWrapperCodegen)
  2763. and wrapper.partition_signatures is not None
  2764. )
  2765. def is_using_cudagraph_partition() -> bool:
  2766. return (
  2767. torch._inductor.config.triton.cudagraphs
  2768. or _unstable_customized_partition_wrapper.wrapper is not None
  2769. ) and torch._inductor.config.graph_partition
  2770. def dtype_from_size(size: int) -> torch.dtype:
  2771. from .virtualized import V
  2772. if V.graph.sizevars.statically_known_lt(
  2773. size, 2**31
  2774. ) and V.graph.sizevars.statically_known_geq(size, -(2**31)):
  2775. return torch.int32
  2776. else:
  2777. return torch.int64
  2778. SUPPORTED_MKLDNN_DEVICES = ("cpu", "xpu")
  2779. def is_mkldnn_bf16_supported(device_type: str) -> bool:
  2780. """
  2781. Returns True if the device supports MKL-DNN BF16.
  2782. """
  2783. if device_type == "cpu":
  2784. return torch.ops.mkldnn._is_mkldnn_bf16_supported()
  2785. elif "xpu" in device_type:
  2786. # match "xpu", "xpu:0", "xpu:1", etc.
  2787. return True
  2788. return False
  2789. def is_mkldnn_fp16_supported(device_type: str) -> bool:
  2790. """
  2791. Returns True if the device supports MKL-DNN FP16.
  2792. """
  2793. if device_type == "cpu":
  2794. return torch.ops.mkldnn._is_mkldnn_fp16_supported()
  2795. elif "xpu" in device_type:
  2796. # match "xpu", "xpu:0", "xpu:1", etc.
  2797. return True
  2798. return False
  2799. def tabulate_2d(elements: Sequence[Sequence[T]], headers: Sequence[T]) -> str:
  2800. widths = [len(str(e)) for e in headers]
  2801. for row in elements:
  2802. assert len(row) == len(headers)
  2803. for i, e in enumerate(row):
  2804. widths[i] = max(widths[i], len(str(e)))
  2805. lines = []
  2806. lines.append("|".join(f" {h:{w}} " for h, w in zip(headers, widths)))
  2807. # widths whitespace horizontal separators
  2808. total_width = sum(widths) + (len(widths) * 2) + (len(widths) - 1)
  2809. lines.append("-" * total_width)
  2810. for row in elements:
  2811. lines.append("|".join(f" {e:{w}} " for e, w in zip(row, widths)))
  2812. return "\n".join(lines)
  2813. def zip_dicts(
  2814. dict1: Mapping[KeyType, ValType],
  2815. dict2: Mapping[KeyType, ValType],
  2816. d1_default: ValType | None = None,
  2817. d2_default: ValType | None = None,
  2818. ) -> Generator[tuple[KeyType, ValType | None, ValType | None], None, None]:
  2819. """
  2820. Zip two dictionaries together, replacing missing keys with default values.
  2821. Args:
  2822. dict1 (dict): The first dictionary.
  2823. dict2 (dict): The second dictionary.
  2824. d1_default (Any): the default value for the first dictionary
  2825. d2_default (Any): the default value for the second dictionary
  2826. Yields:
  2827. tuple: A tuple containing the key, the value from dict1 (or d1_default if missing),
  2828. and the value from dict2 (or d2_default if missing).
  2829. """
  2830. # Find the union of all keys
  2831. all_keys = OrderedSet(dict1.keys()) | OrderedSet(dict2.keys())
  2832. # Iterate over all keys
  2833. for key in all_keys:
  2834. # Get the values from both dictionaries, or default if missing
  2835. value1 = dict1.get(key)
  2836. value2 = dict2.get(key)
  2837. yield (
  2838. key,
  2839. value1 if value1 is not None else d1_default,
  2840. value2 if value2 is not None else d2_default,
  2841. )
  2842. def maybe_aoti_standalone_config(config_patches: dict[str, Any]) -> dict[str, Any]:
  2843. """
  2844. Ensures the configuration is internally consistent for standalone AOTInductor.
  2845. If `aot_inductor.compile_standalone` is set to True in the provided
  2846. `config_patches` (or falls back to the global config), this function ensures
  2847. that the following configs are also enabled:
  2848. - `aot_inductor.package_cpp_only`
  2849. Args:
  2850. config_patches (dict[str, Any]): A dictionary of user-provided config
  2851. overrides for AOTInductor compilation.
  2852. Returns:
  2853. dict[str, Any]: The possibly-updated `config_patches` dictionary.
  2854. """
  2855. def patch_config(
  2856. config_patches: dict[str, Any], config_name: str, config_value: Any
  2857. ) -> None:
  2858. value = config_patches.get(config_name, getattr(config, config_name))
  2859. if value is None:
  2860. config_patches[config_name] = config_value
  2861. elif not value and value != config_value:
  2862. raise RuntimeError(
  2863. f"Invalid config: {config_name}={config_value} when aot_inductor.compile_standalone is True."
  2864. )
  2865. compile_standalone = config_patches.get(
  2866. "aot_inductor.compile_standalone", config.aot_inductor.compile_standalone
  2867. )
  2868. # Make a copy of the config_patches to avoid modifying the original dictionary, needed for testing
  2869. config_patches = config_patches.copy()
  2870. if compile_standalone:
  2871. # Standlaone AOTInductor means only generate cpp project for building a standalone binary
  2872. patch_config(config_patches, "aot_inductor.package_cpp_only", True)
  2873. # Standlaone AOTInductor needs to embed the kernel code in the binary
  2874. patch_config(config_patches, "aot_inductor.embed_kernel_binary", True)
  2875. # Default to use multi-arch kernel codegen for non-rocm GPU
  2876. patch_config(
  2877. config_patches, "aot_inductor.emit_multi_arch_kernel", not torch.version.hip
  2878. )
  2879. patch_config(
  2880. config_patches, "aot_inductor.model_name_for_generated_files", "aoti_model"
  2881. )
  2882. return config_patches
  2883. def is_valid_aoti_model_name() -> bool:
  2884. """
  2885. Validates if a model name is suitable for use in code generation.
  2886. """
  2887. from torch._inductor import config
  2888. model_name = config.aot_inductor.model_name_for_generated_files
  2889. if model_name is None:
  2890. return True
  2891. if not isinstance(model_name, str):
  2892. raise ValueError("Invalid AOTI model name: Model name must be a string")
  2893. if model_name == "":
  2894. return True
  2895. # Can only contain alphanumeric characters and underscores
  2896. if not re.match(r"^[a-zA-Z_][a-zA-Z0-9_]*$", model_name):
  2897. raise ValueError(
  2898. "Invalid AOTI model name: Model name can only contain letters, numbers, and underscores"
  2899. )
  2900. return True
  2901. def get_free_symbols(x: IterateExprs, unbacked_only: bool) -> OrderedSet[sympy.Symbol]:
  2902. if unbacked_only:
  2903. return free_unbacked_symbols(x)
  2904. else:
  2905. return free_symbols(x)
  2906. def maybe_log_cudagraph_partition(
  2907. msg: str,
  2908. prefix: Optional[str] = "cudagraph partition due to ",
  2909. node: Optional[BaseSchedulerNode] = None,
  2910. ) -> None:
  2911. """
  2912. Cudagraph partition may lead to extra memory overhead so we
  2913. log partition reasons to help users understand the overhead.
  2914. """
  2915. if not config.triton.cudagraphs:
  2916. return
  2917. warning_msg = f"{prefix}{msg}"
  2918. if (
  2919. node
  2920. and (ir_node := node.node)
  2921. and (fx_node := ir_node.get_origin_node())
  2922. and (stack_trace := fx_node.meta.get("stack_trace", None))
  2923. ):
  2924. warning_msg = f"{warning_msg}. Found from : \n {stack_trace}"
  2925. perf_hint_log.warning(warning_msg)
  2926. def python_subprocess_env() -> dict[str, str]:
  2927. """
  2928. Get a base environment for running Python subprocesses.
  2929. """
  2930. env = {
  2931. # Inherit the environment of the current process.
  2932. **os.environ,
  2933. # Set the PYTHONPATH so the subprocess can find torch.
  2934. "PYTHONPATH": os.environ.get(
  2935. "TORCH_CUSTOM_PYTHONPATH", os.pathsep.join(sys.path)
  2936. ),
  2937. }
  2938. # Set PYTHONHOME for internal builds, to account for builds that bundle the
  2939. # runtime. Otherwise they will use the libraries and headers from the
  2940. # platform runtime instead.
  2941. #
  2942. # This can't be done for external builds. The process can be run from a
  2943. # venv and that won't include Python headers. The process needs to be able
  2944. # to search for and find the platform runtime.
  2945. if config.is_fbcode():
  2946. env["PYTHONHOME"] = sysconfig.get_path("data")
  2947. return env
@dataclasses.dataclass(frozen=True)
class CUDAGraphWrapperMetadata:
    """
    Metadata for Customized CUDAGraphWrapper.

    Passed, together with a partition function, to a user-supplied partition
    wrapper (see ``CUDAGraphWrapperType``). Instances are immutable
    (``frozen=True``).

    Currently assumes there is 1 dynamo graph and will extend to
    multiple graphs in the future.
    """

    # The number of partitions that are cudagraphable.
    num_partitions: int
    # Index of the current partition (presumably 0-based -- TODO confirm with the codegen caller).
    partition_index: int
  2959. PartitionFnType = Callable[..., Any]
  2960. CUDAGraphWrapperType = Callable[
  2961. [PartitionFnType, CUDAGraphWrapperMetadata], PartitionFnType
  2962. ]
class CUDAGraphWrapper:
    """
    Holder for an optional user-supplied partition wrapper.

    ``wrapper`` defaults to ``None``. The module-level instance
    ``_unstable_customized_partition_wrapper`` is what
    ``set_customized_partition_wrappers`` writes and what
    ``is_using_cudagraph_partition`` reads.
    """

    # The registered wrapper, or None when no user wrapper has been set.
    wrapper: Optional[CUDAGraphWrapperType] = None
# A customized partition wrapper supplied by users. Its interface should be:
#
# def wrapper(fn: PartitionFnType, metadata: CUDAGraphWrapperMetadata) -> PartitionFnType
#
# Inductor generates N wrapper functions for N partition functions, and mechanically wraps
# each partition fn with the generated wrapper function. Users need to handle all details
# such as static inputs, dynamic shapes, etc.
# Users can customize the wrapper based on the metadata. One example is special
# handling for the first and last wrapper functions.
#
# Set via set_customized_partition_wrappers(). While a wrapper is set and
# config.graph_partition is enabled, is_using_cudagraph_partition() returns True.
#
# Warning: This API is unstable and may change in the future.
_unstable_customized_partition_wrapper = CUDAGraphWrapper()
  2978. def set_customized_partition_wrappers(wrapper: CUDAGraphWrapperType) -> None:
  2979. _unstable_customized_partition_wrapper.wrapper = wrapper
  2980. def snode_args_kwargs(snode: BaseSchedulerNode) -> tuple[list[Any], dict[str, Any]]:
  2981. args = snode.node.inputs # type: ignore[union-attr]
  2982. args = snode.node.fill_non_provided_args( # type: ignore[union-attr]
  2983. [*args, *snode.node.constant_args], # type: ignore[union-attr]
  2984. snode.node.kwargs, # type: ignore[union-attr]
  2985. )
  2986. kwargs = snode.node.kwargs # type: ignore[union-attr]
  2987. flat_args, flat_args_pytree_spec = pytree.tree_flatten((args, kwargs))
  2988. def _is_tensor_ir(x) -> bool: # type: ignore[no-untyped-def]
  2989. return isinstance(x, torch._inductor.ir.IRNode) and not isinstance(
  2990. x, torch._inductor.ir.GeneratorState
  2991. )
  2992. flat_args = [
  2993. torch._inductor.ir.ir_node_to_tensor(a, guard_shape=False)
  2994. if _is_tensor_ir(a)
  2995. else a
  2996. for a in flat_args
  2997. ]
  2998. def _tensor(size, dtype, device) -> torch.Tensor: # type: ignore[no-untyped-def]
  2999. return torch.empty(size, dtype=dtype, device=device)
  3000. def to_real_tensor(e: Any) -> Any:
  3001. if not isinstance(e, torch.Tensor):
  3002. return e
  3003. out = _tensor(e.size(), e.dtype, e.device)
  3004. return out
  3005. flat_args = [to_real_tensor(a) for a in flat_args]
  3006. args, kwargs = pytree.tree_unflatten(flat_args, flat_args_pytree_spec)
  3007. return args, kwargs