| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693694695696697698699700701702
70370470570670770870971071171271371471571671771871972072172272372472572672772872973073173273373473573673773873974074174274374474574674774874975075175275375475575675775875976076176276376476576676776876977077177277377477577677777877978078178278378478578678778878979079179279379479579679779879980080180280380480580680780880981081181281381481581681781881982082182282382482582682782882983083183283383483583683783883984084184284384484584684784884985085185285385485585685785885986086186286386486586686786886987087187287387487587687787887988088188288388488588688788888989089189289389489589689789889990090190290390490590690790890991091191291391491591691791891992092192292392492592692792892993093193293393493593693793893994094194294394494594694794894995095195295395495595695795895996096196296396496596696796896997097197297397497597697797897998098198298398498598698798898999099199299399499599699799899910001001100210031004100510061007100810091010101110121013101410151016101710181019102010211022102310241025102610271028102910301031103210331034103510361037103810391040104110421043104410451046104710481049105010511052105310541055105610571058105910601061106210631064106510661067106810691070107110721073107410751076107710781079108010811082108310841085108610871088108910901091109210931094109510961097109810991100110111021103110411051106110711081109111011111112111311141115111611171118111911201121112211231124112511261127112811291130113111321133113411351136113711381139114011411142114311441145114611471148114911501151115211531154115511561157115811591160116111621163116411651166116711681169117011711172117311741175117611771178117911801181118211831184118511861187118811891190119111921193119411951196119711981199120012011202120312041205120612071208120912101211121212131214121512161217121812191220122112221223122412251226122712281229123012311232123312341235123612371238123912401241124212431244124512461247124812491250125112521253125412551256125712581259126012611262126312641265126612671268126912701271127212731274127512761
27712781279128012811282128312841285128612871288128912901291129212931294129512961297129812991300130113021303130413051306130713081309131013111312131313141315131613171318131913201321132213231324132513261327132813291330133113321333133413351336133713381339134013411342134313441345134613471348134913501351135213531354135513561357135813591360136113621363136413651366136713681369137013711372137313741375137613771378137913801381138213831384138513861387138813891390139113921393139413951396139713981399140014011402140314041405140614071408140914101411141214131414141514161417141814191420142114221423142414251426142714281429143014311432143314341435143614371438143914401441144214431444144514461447144814491450145114521453145414551456145714581459146014611462146314641465146614671468146914701471147214731474147514761477147814791480148114821483148414851486148714881489149014911492149314941495149614971498149915001501150215031504150515061507150815091510151115121513151415151516151715181519152015211522152315241525152615271528152915301531153215331534153515361537153815391540154115421543154415451546154715481549155015511552155315541555155615571558155915601561156215631564156515661567156815691570157115721573157415751576157715781579158015811582158315841585158615871588158915901591159215931594159515961597159815991600160116021603160416051606160716081609161016111612161316141615161616171618161916201621162216231624162516261627162816291630163116321633163416351636163716381639164016411642164316441645164616471648164916501651165216531654165516561657165816591660166116621663166416651666166716681669167016711672167316741675167616771678167916801681168216831684168516861687168816891690169116921693169416951696169716981699170017011702170317041705170617071708170917101711171217131714171517161717171817191720172117221723172417251726172717281729173017311732173317341735173617371738173917401741174217431744174517461747174817491750175117521753175417551756175717581759176017611762176317641765176617671768176917701771177217731774177517761
77717781779178017811782178317841785178617871788178917901791179217931794179517961797179817991800180118021803180418051806180718081809181018111812181318141815181618171818181918201821182218231824182518261827182818291830183118321833183418351836183718381839184018411842184318441845184618471848184918501851185218531854185518561857185818591860186118621863186418651866186718681869187018711872187318741875187618771878187918801881188218831884188518861887188818891890189118921893189418951896189718981899190019011902190319041905190619071908190919101911191219131914191519161917191819191920192119221923192419251926192719281929193019311932193319341935193619371938193919401941194219431944194519461947194819491950195119521953195419551956195719581959196019611962196319641965196619671968196919701971197219731974197519761977197819791980198119821983198419851986198719881989199019911992199319941995199619971998199920002001200220032004200520062007200820092010201120122013201420152016201720182019202020212022202320242025202620272028202920302031203220332034203520362037203820392040204120422043204420452046204720482049205020512052205320542055205620572058205920602061206220632064206520662067206820692070207120722073207420752076207720782079208020812082208320842085208620872088208920902091209220932094209520962097209820992100210121022103210421052106210721082109211021112112211321142115211621172118211921202121212221232124212521262127212821292130213121322133213421352136213721382139214021412142214321442145214621472148214921502151215221532154215521562157215821592160216121622163216421652166216721682169217021712172217321742175217621772178217921802181218221832184218521862187218821892190219121922193219421952196219721982199220022012202220322042205220622072208220922102211221222132214221522162217221822192220222122222223222422252226222722282229223022312232223322342235223622372238223922402241224222432244224522462247224822492250225122522253225422552256225722582259226022612262226322642265226622672268226922702271227222732274227522762
27722782279228022812282228322842285228622872288228922902291229222932294229522962297229822992300230123022303230423052306230723082309231023112312231323142315231623172318231923202321232223232324232523262327232823292330233123322333233423352336233723382339234023412342234323442345234623472348234923502351235223532354235523562357235823592360236123622363236423652366236723682369237023712372237323742375237623772378237923802381238223832384238523862387238823892390239123922393239423952396239723982399240024012402240324042405240624072408240924102411241224132414241524162417241824192420242124222423242424252426242724282429243024312432243324342435243624372438243924402441244224432444244524462447244824492450245124522453245424552456245724582459246024612462246324642465246624672468246924702471247224732474247524762477247824792480248124822483248424852486248724882489249024912492249324942495249624972498249925002501250225032504250525062507250825092510251125122513251425152516251725182519252025212522252325242525252625272528252925302531253225332534253525362537253825392540254125422543254425452546254725482549255025512552255325542555255625572558255925602561256225632564256525662567256825692570257125722573257425752576257725782579258025812582258325842585258625872588258925902591259225932594259525962597259825992600260126022603260426052606260726082609261026112612261326142615261626172618261926202621262226232624262526262627262826292630263126322633263426352636263726382639264026412642264326442645264626472648264926502651265226532654265526562657265826592660266126622663266426652666266726682669267026712672267326742675267626772678267926802681268226832684268526862687268826892690269126922693269426952696269726982699270027012702270327042705270627072708270927102711271227132714271527162717271827192720272127222723272427252726272727282729273027312732273327342735273627372738273927402741274227432744274527462747274827492750275127522753275427552756275727582759276027612762276327642765276627672768276927702771277227732774277527762
77727782779278027812782278327842785278627872788278927902791279227932794279527962797279827992800280128022803280428052806280728082809281028112812281328142815281628172818281928202821282228232824282528262827282828292830283128322833283428352836283728382839284028412842284328442845284628472848284928502851285228532854285528562857285828592860286128622863286428652866286728682869287028712872287328742875287628772878287928802881288228832884288528862887288828892890289128922893289428952896289728982899290029012902290329042905290629072908290929102911291229132914291529162917291829192920292129222923292429252926292729282929293029312932293329342935293629372938293929402941294229432944294529462947294829492950295129522953295429552956295729582959296029612962296329642965296629672968296929702971297229732974297529762977297829792980298129822983298429852986298729882989299029912992299329942995299629972998299930003001300230033004300530063007300830093010301130123013301430153016301730183019302030213022302330243025302630273028302930303031303230333034303530363037303830393040304130423043304430453046304730483049305030513052305330543055305630573058305930603061306230633064306530663067306830693070307130723073307430753076307730783079308030813082308330843085308630873088308930903091309230933094309530963097309830993100310131023103310431053106310731083109311031113112311331143115311631173118311931203121312231233124312531263127312831293130313131323133313431353136313731383139314031413142314331443145314631473148314931503151315231533154315531563157315831593160316131623163316431653166316731683169317031713172317331743175317631773178317931803181318231833184318531863187318831893190319131923193319431953196319731983199320032013202320332043205320632073208320932103211321232133214321532163217321832193220322132223223322432253226322732283229323032313232323332343235323632373238323932403241324232433244324532463247324832493250325132523253325432553256325732583259326032613262326332643265326632673268326932703271327232733274327532763
27732783279328032813282328332843285328632873288328932903291329232933294329532963297329832993300330133023303330433053306330733083309331033113312331333143315331633173318331933203321332233233324332533263327332833293330333133323333333433353336333733383339334033413342334333443345334633473348334933503351335233533354335533563357335833593360336133623363336433653366336733683369337033713372337333743375337633773378337933803381338233833384338533863387338833893390339133923393339433953396339733983399340034013402340334043405340634073408340934103411341234133414341534163417341834193420342134223423342434253426342734283429343034313432343334343435343634373438343934403441344234433444344534463447344834493450345134523453345434553456345734583459346034613462346334643465346634673468346934703471347234733474347534763477347834793480348134823483348434853486348734883489349034913492349334943495349634973498349935003501350235033504350535063507350835093510351135123513351435153516351735183519352035213522352335243525352635273528352935303531353235333534353535363537353835393540354135423543354435453546354735483549355035513552355335543555355635573558355935603561356235633564356535663567356835693570357135723573357435753576357735783579358035813582358335843585358635873588358935903591359235933594359535963597359835993600360136023603360436053606360736083609361036113612361336143615361636173618361936203621362236233624362536263627362836293630363136323633363436353636363736383639364036413642364336443645364636473648364936503651365236533654365536563657365836593660366136623663366436653666366736683669367036713672367336743675367636773678367936803681368236833684368536863687368836893690369136923693369436953696369736983699370037013702370337043705370637073708370937103711371237133714371537163717371837193720372137223723372437253726372737283729373037313732373337343735373637373738373937403741374237433744374537463747374837493750375137523753375437553756375737583759376037613762 |
- from __future__ import annotations
- import collections
- import contextlib
- import dataclasses
- import enum
- import functools
- import importlib
- import inspect
- import io
- import itertools
- import logging
- import math
- import operator
- import os
- import platform
- import re
- import shutil
- import statistics
- import sys
- import sysconfig
- import tempfile
- import textwrap
- import time
- import unittest
- from collections.abc import (
- Collection,
- Generator,
- Iterator,
- Mapping,
- MutableMapping,
- MutableSet,
- )
- from datetime import datetime
- from io import StringIO
- from typing import (
- Any,
- Callable,
- cast,
- Generic,
- Literal,
- NamedTuple,
- Optional,
- Protocol,
- TYPE_CHECKING,
- TypeVar,
- Union,
- )
- from typing_extensions import (
- Concatenate,
- dataclass_transform,
- ParamSpec,
- Self,
- TypeAlias,
- TypeGuard,
- )
- from unittest import mock
- import sympy
- import torch
- import torch.utils._pytree as pytree
- from torch._inductor.analysis.device_info import datasheet_tops
- from torch._inductor.runtime.hints import DeviceProperties
- from torch.utils._dtype_abbrs import dtype_abbrs
- from torch.utils._ordered_set import OrderedSet
- from torch.utils._pytree import tree_flatten, tree_map_only
# Names of post-grad passes listed as excluded for Optimus
# (the consumer of this list is not visible in this part of the file).
OPTIMUS_EXCLUDE_POST_GRAD = [
    "activation_quantization_aten_pass",
    "inductor_autotune_lookup_table",
]
- from torch.fx.experimental.symbolic_shapes import (
- free_symbols,
- free_unbacked_symbols,
- IterateExprs,
- ShapeEnv,
- )
- if TYPE_CHECKING:
- from collections.abc import Iterable, Sequence, ValuesView
- from torch import SymBool, SymFloat, SymInt
- from torch._prims_common import ELEMENTWISE_TYPE_PROMOTION_KIND
- from torch.fx import GraphModule
- from torch.fx.node import Node
- from .codegen.common import WorkspaceArg
- from .codegen.wrapper import PythonWrapperCodegen
- from .graph import GraphLowering
- from .ir import Buffer, ExternKernel, IRNode, Layout, Operation, ReinterpretView
- from .output_code import CompiledFxGraph
- from .scheduler import BaseSchedulerNode, SchedulerBuffer
# GPU backend names that get_gpu_type probes on the `torch` module, in order.
GPU_TYPES = ["cuda", "mps", "xpu", "mtia"]
# Generic type variable for helpers in this module.
T = TypeVar("T")
# Kept above the `torch._dynamo` imports below to avoid a circular import
# when get_gpu_type is imported from dynamo.
@functools.cache
def get_gpu_type() -> str:
    """
    Return the name of the available GPU backend among GPU_TYPES.

    Falls back to "cuda" when no backend reports availability. The result is
    memoized with functools.cache, so the probe runs once per process.
    """
    available = [name for name in GPU_TYPES if getattr(torch, name).is_available()]
    # At most one GPU backend is expected to be usable at a time.
    assert len(available) <= 1
    if not available:
        return "cuda"
    return available.pop()
- from torch._dynamo.device_interface import get_interface_for_device
- from torch._dynamo.utils import detect_fake_mode
- from torch.autograd import DeviceType
- from torch.autograd.profiler_util import EventList
- from torch.fx.passes.graph_transform_observer import GraphTransformObserver
- from torch.fx.passes.shape_prop import ShapeProp
- from torch.utils._sympy.functions import (
- CeilDiv,
- CleanDiv,
- FloorDiv,
- Identity,
- ModularIndexing,
- )
- from torch.utils._sympy.symbol import make_symbol, SymT
- from torch.utils._sympy.value_ranges import bound_sympy, ValueRanges
- from . import config
- from .runtime.runtime_utils import ceildiv as runtime_ceildiv
# Module-level loggers and host-platform flag.
log = logging.getLogger(__name__)
perf_hint_log = torch._logging.getArtifactLogger(__name__, "perf_hints")
_IS_WINDOWS = sys.platform == "win32"

# Typing helpers shared across this module.
_T = TypeVar("_T")
# Presumably maps index variables to their range bounds — confirm at use sites.
VarRanges = dict[sympy.Expr, sympy.Expr]
InputType = Optional[Union[torch.Tensor, int, torch.SymInt]]
# Kernel binary file extension per GPU backend.
GPU_KERNEL_BIN_EXTS = {"cuda": ".cubin", "xpu": ".spv"}

# Byte sizes / alignments (their consumers are not visible in this section).
GPU_ALIGN_BYTES = 16
ALIGNMENT = 16
TMA_ALIGNMENT = 16
TMA_DESCRIPTOR_SIZE = 128

# Rounding granularity used by _align / align.
ALIGN_BYTES = 64
# _align rounds with the mask `-ALIGN_BYTES`, which is only correct for a power
# of two; the message now states both conditions the check enforces.
assert (ALIGN_BYTES & (ALIGN_BYTES - 1)) == 0 and ALIGN_BYTES >= 8, (
    "ALIGN_BYTES must be a power of 2 and >= 8"
)
def _align(nbytes: int) -> int:
    """Round ``nbytes`` up to the next multiple of ALIGN_BYTES."""
    # ALIGN_BYTES is asserted to be a power of two, so after adding
    # (ALIGN_BYTES - 1) the low bits can be cleared with its negation.
    mask = -ALIGN_BYTES
    return (nbytes + (ALIGN_BYTES - 1)) & mask
def _is_aligned(v: sympy.Expr) -> bool:
    """Return True when ``v`` can be statically shown to be a multiple of ALIGN_BYTES."""
    # A sum or max of aligned terms is itself aligned: check each operand.
    if isinstance(v, (sympy.Add, sympy.Max)):
        return all(_is_aligned(arg) for arg in v.args)
    if isinstance(v, align):
        return True
    return sympy.gcd(v, ALIGN_BYTES) == ALIGN_BYTES
class align(sympy.Function):
    """Symbolically round up to the nearest multiple of ALIGN_BYTES.

    Integer arguments are rounded eagerly via ``_align``; arguments that can
    already be shown to be aligned are returned unchanged; anything else is
    left as an unevaluated ``align(...)`` call.
    """

    nargs = (1,)
    is_integer = True

    @classmethod
    def eval(cls, value: sympy.Expr) -> Optional[sympy.Expr]:
        is_concrete = isinstance(value, (int, sympy.Integer))
        if is_concrete:
            return _align(int(value))
        # Returning None keeps the call unevaluated (sympy Function.eval convention).
        return value if _is_aligned(value) else None
@dataclasses.dataclass(frozen=True)
class GraphPartitionMap:
    """
    Relates one graph partition's inputs, outputs and constants to those of
    the enclosing graph.
    """

    # Unique identifier of the partition.
    id: int
    # For every partition input/output, the index of the matching graph
    # input/output, or None when it has no graph-level counterpart.
    input_index_mapping: list[Optional[int]]
    output_index_mapping: list[Optional[int]]
    # Names of the constants the partition reads or writes.
    constant_names: list[str]
def fp8_bench(fn: Callable[[], Any], warmup: int = 25, rep: int = 100) -> float:
    """
    Benchmark ``fn`` on CUDA and return its mean time per call in milliseconds.

    Each measured call is bracketed by its own pair of CUDA events; a scratch
    buffer is zeroed just before every call, outside the event window. The
    measurement loop also runs under the torch profiler, and the mean device
    time of kernels whose name matches ``fused_abs_max_<digit>`` is subtracted
    from the event-based mean.

    Args:
        fn: zero-argument callable to benchmark.
        warmup: approximate warm-up budget in milliseconds.
        rep: approximate measurement budget in milliseconds.

    Returns:
        Mean per-call time in ms, minus the mean ``fused_abs_max`` kernel
        time when such kernels were recorded.
    """
    fn()
    torch.cuda.synchronize()
    # Scratch buffer zeroed before each call, presumably to flush fn's data
    # out of the L2 cache between iterations.
    cache = torch.empty(int(256e6 // 4), dtype=torch.float16, device="cuda")

    # Estimate the per-call runtime from 5 back-to-back calls.
    est_start = torch.cuda.Event(enable_timing=True)
    est_end = torch.cuda.Event(enable_timing=True)
    est_start.record()
    for _ in range(5):
        cache.zero_()
        fn()
    est_end.record()
    torch.cuda.synchronize()
    estimate_ms = est_start.elapsed_time(est_end) / 5

    # Turn the millisecond budgets into iteration counts (at least one each).
    n_warmup = max(1, int(warmup / estimate_ms))
    n_repeat = max(1, int(rep / estimate_ms))

    # Warm-up
    for _ in range(n_warmup):
        fn()

    start_events = [torch.cuda.Event(enable_timing=True) for _ in range(n_repeat)]
    end_events = [torch.cuda.Event(enable_timing=True) for _ in range(n_repeat)]
    with torch.profiler.profile(
        activities=[
            torch.profiler.ProfilerActivity.CUDA,
        ]
    ) as p:
        torch.cuda.synchronize()
        for start, end in zip(start_events, end_events):
            # The cache clear happens before `start` is recorded, so it is not timed.
            cache.zero_()
            start.record()
            with torch.cuda.nvtx.range("RunCudaModule"):
                fn()
            end.record()
        torch.cuda.synchronize()
        times = torch.tensor(
            [s.elapsed_time(e) for s, e in zip(start_events, end_events)]
        )
    res = torch.mean(times).item()

    log.debug("raw events")
    log.debug(p.key_averages().table(sort_by="self_device_time_total", row_limit=-1))

    abs_max_events = EventList(
        [
            event
            for event in p.events()
            if (
                event.device_type == DeviceType.CUDA
                and re.match(r"fused_abs_max_\d", event.name) is not None
            )
        ]
    )
    if abs_max_events:
        # Presumably the abs-max kernel should not count toward fn's time — confirm.
        # Profiler device times are presumably microseconds; /1000 converts to ms.
        res -= (
            statistics.mean(event.device_time_total for event in abs_max_events)
            / 1000.0
        )
    log.debug("profiling results: %s ms", res)
    return res
def do_bench_using_profiling(
    fn: Callable[[], Any], warmup: int = 25, rep: int = 100
) -> float:
    """
    Returns benchmark results by examining torch profiler events.
    This could be more accurate as it doesn't count CPU side overhead.
    However, this also requires manually excluding irrelevant event, e.g.
    vectorized_elementwise_kernel which is used to fill L2 cache,
    various CUDA events, etc, so could also be fragile.

    Args:
        fn: zero-argument callable to benchmark on CUDA.
        warmup: approximate warm-up budget in milliseconds.
        rep: approximate measurement budget in milliseconds.

    Returns:
        Total device time of the retained profiler events per repeat, in ms.

    Raises:
        RuntimeError: if the filtered CUDA events cannot be split evenly into
            one group per repeat.
    """
    fn()
    torch.cuda.synchronize()
    # Buffer zeroed before each call to clear the L2 cache.
    cache = torch.empty(int(256e6 // 4), dtype=torch.int, device="cuda")

    # Estimate the runtime of the function from 5 back-to-back calls.
    start_event = torch.cuda.Event(enable_timing=True)
    end_event = torch.cuda.Event(enable_timing=True)
    start_event.record()
    for _ in range(5):
        cache.zero_()
        fn()
    end_event.record()
    torch.cuda.synchronize()
    estimate_ms = start_event.elapsed_time(end_event) / 5

    # Compute number of warmup and repeat iterations (at least one each).
    n_warmup = max(1, int(warmup / estimate_ms))
    n_repeat = max(1, int(rep / estimate_ms))

    # Warm-up
    for _ in range(n_warmup):
        fn()

    torch.cuda.synchronize()
    with torch.profiler.profile(
        activities=[
            torch.profiler.ProfilerActivity.CUDA,
        ]
    ) as p:
        # Benchmark: each iteration clears the L2 cache, then runs fn.
        for _ in range(n_repeat):
            cache.zero_()
            fn()
        torch.cuda.synchronize()

    log.debug("raw events")
    log.debug(p.key_averages().table(sort_by="self_device_time_total", row_limit=-1))

    # "Context Sync" events are excluded from the per-repeat grouping.
    filtered_events = EventList(
        [
            event
            for event in p.events()
            if event.device_type == DeviceType.CUDA and event.name != "Context Sync"
        ]
    )
    if len(filtered_events) % n_repeat != 0:
        # Build the message eagerly: RuntimeError does not apply %-formatting
        # to extra arguments the way logging calls do.
        raise RuntimeError(
            "Failed to divide all profiling events into #repeat groups. "
            f"#CUDA events: {len(filtered_events)}, #repeats: {n_repeat}"
        )
    # Exact integer group size (divisibility was checked above).
    num_event_per_group = len(filtered_events) // n_repeat
    # Drop the first event of every group: the cache.zero_() kernel launched
    # before fn() in each iteration (assumes events arrive in launch order).
    actual_events = EventList(
        [
            event
            for i, event in enumerate(filtered_events)
            if i % num_event_per_group != 0
        ]
    )
    actual_events._build_tree()
    actual_events = actual_events.key_averages()

    log.debug("profiling time breakdown")
    log.debug(actual_events.table(row_limit=-1))

    # Profiler device times are presumably microseconds; /1000 converts to ms.
    res = sum(event.device_time_total for event in actual_events) / 1000.0 / n_repeat
    log.debug("profiling results: %s ms", res)
    return res
@functools.cache
def has_torchvision_roi_align() -> bool:
    """
    Return True if torchvision's ``roi_align`` is importable and registered
    under ``torch.ops.torchvision``.

    Returns False when torchvision cannot be imported or when the dispatcher
    reports that ``torchvision::nms`` does not exist. Any other RuntimeError
    is re-raised unchanged instead of being converted into an AssertionError
    (or silently ignored under ``python -O``).
    """
    try:
        from torchvision.ops import roi_align  # noqa: F401

        # Called only for its error: judging by the handler below, it raises
        # RuntimeError("... torchvision::nms does not exist") when the
        # torchvision operators are not registered.
        torch._C._dispatch_has_kernel_for_dispatch_key("torchvision::nms", "Meta")
        return roi_align is not None and hasattr(
            getattr(torch.ops, "torchvision", None), "roi_align"
        )
    except ImportError:
        return False
    except RuntimeError as e:
        if "torchvision::nms does not exist" not in str(e):
            raise
        return False
def decode_device(device: Union[Optional[torch.device], str]) -> torch.device:
    """
    Normalize ``device`` into a ``torch.device``.

    ``None`` maps to the device a freshly created tensor lands on (the
    default device); strings are parsed. For device types other than
    "cpu" and "meta" that lack an index, the backend's current device index
    is filled in.
    """
    if device is None:
        return torch.tensor(0.0).device  # default device
    dev = torch.device(device) if isinstance(device, str) else device
    needs_index = dev.type not in ("cpu", "meta") and dev.index is None
    if not needs_index:
        return dev
    device_interface = get_interface_for_device(dev.type)
    return torch.device(dev.type, index=device_interface.Worker.current_device())
def sympy_product(it: Iterable[sympy.Expr]) -> sympy.Expr:
    """Multiply the expressions in ``it`` left to right; empty input yields ``sympy.S.One``."""
    product = sympy.S.One
    for factor in it:
        product = product * factor
    return product
def sympy_dot(seq1: Sequence[sympy.Expr], seq2: Sequence[sympy.Expr]) -> sympy.Expr:
    """Return the expanded dot product of two equal-length expression sequences."""
    assert len(seq1) == len(seq2)
    total = sum(lhs * rhs for lhs, rhs in zip(seq1, seq2))
    return sympy.expand(total)
def unique(it: Iterable[_T]) -> ValuesView[_T]:
    """
    Deduplicate ``it`` by object identity, keeping first-seen order.

    Returns a dict values view over the retained objects.
    """
    by_identity: dict[int, _T] = {}
    for item in it:
        by_identity[id(item)] = item
    return by_identity.values()
def ceildiv(
    number: Union[int, sympy.Expr], denom: Union[int, sympy.Expr]
) -> Union[int, sympy.Expr]:
    """
    Ceiling division of ``number`` by ``denom``.

    If either operand is a sympy expression, a symbolic ``CeilDiv`` is
    returned; otherwise both must be Python ints and the integer result of
    ``runtime_ceildiv`` is returned.
    """
    if not isinstance(number, sympy.Expr) and not isinstance(denom, sympy.Expr):
        # TODO: There is a bug in a call to this function, to repro:
        # python benchmarks/dynamo/huggingface.py --inductor -d cuda --accuracy
        # --amp --only YituTechConvBert --dynamic-shapes
        assert isinstance(number, int) and isinstance(denom, int), (
            f"{number}: {type(number)}, {denom}: {type(denom)}"
        )
        return runtime_ceildiv(number, denom)
    return CeilDiv(sympy.sympify(number), sympy.sympify(denom))
- def _type_of(key: Optional[torch.dtype]) -> str:
- # Use the function here to get rid of dependencies on the Triton during the codegen.
- # Refer to Triton implementation here:
- # https://github.com/triton-lang/triton/blob/98b5945d2aef679e00ebca8e07c35c3658ec76de/python/triton/runtime/jit.py#L238
- # `None` is nullptr. Implicitly convert to *i8.
- if key is None:
- return "*i8"
- dtype_str = str(key).split(".")[-1]
- tys = {
- "bool": "i1",
- "float8e4nv": "fp8e4nv",
- "float8e5": "fp8e5",
- "float8e4b15": "fp8e4b15",
- "float8e4b15x4": "fp8e4b15x4",
- "float8_e4m3fn": "fp8e4nv",
- "float8_e5m2": "fp8e5",
- # TODO: remove when support is added in triton
- # https://github.com/triton-lang/triton/issues/6054
- "float8_e8m0fnu": "u8",
- "float4_e2m1fn_x2": "u8",
- "float16": "fp16",
- "bfloat16": "bf16",
- "float32": "fp32",
- "float64": "fp64",
- "int8": "i8",
- "int16": "i16",
- "int32": "i32",
- "int64": "i64",
- "uint8": "u8",
- "uint16": "u16",
- "uint32": "u32",
- "uint64": "u64",
- }
- # reinterpret can create triton type
- tys.update({v: v for v in list(tys.values())})
- return key if isinstance(key, str) else f"*{tys[dtype_str]}"
def convert_shape_to_inductor(
    lst: Iterable[Union[int, torch.SymInt]],
) -> list[sympy.Expr]:
    """
    Convert a sequence of sizes or strides into sympy expressions.

    Plain ints become sympy integers; symbolic entries (SymInt) are mapped
    to their sympy expression via ``sympy.sympify``.
    """
    return list(map(sympy.sympify, lst))
def convert_to_symint(i: Union[int, sympy.Expr]) -> Union[int, torch.SymInt]:
    """
    Convert one Inductor size expression into an int or a SymInt.

    Python ints pass through unchanged, sympy Integers become ints, and any
    other expression becomes a SymInt created by the current graph's shape
    env (with ``hint=None``). Element-wise counterpart of
    convert_shape_to_symint.
    """
    from .virtualized import V

    if isinstance(i, int):
        return i
    if isinstance(i, sympy.Integer):
        return int(i)
    return V.graph.sizevars.shape_env.create_symintnode(i, hint=None)
def convert_shape_to_symint(
    lst: Iterable[Union[int, sympy.Expr]],
) -> list[Union[int, torch.SymInt]]:
    """
    Convert a sequence of Inductor sizes into ints/SymInts element by element
    via convert_to_symint; fully static shapes come back as plain ints.
    """
    return list(map(convert_to_symint, lst))
def is_view(op: torch._ops.OpOverload) -> bool:
    """
    Return True if any argument in the op's schema carries an alias
    annotation (e.g. ``Tensor(a)``).
    """
    for argument in op._schema.arguments:
        if argument.alias_info is not None:
            return True
    return False


def is_pointwise_use(
    use: Node,
    is_pointwise_fn: Callable[[torch._ops.OpOverload], bool] = lambda _: False,
) -> bool:
    """
    Return True if the FX node ``use`` counts as a pointwise use.

    The node must be a ``call_function`` on an OpOverload (or
    ``operator.getitem``). getitem and view ops are looked through: they are
    pointwise only if all of their own users are. Any other op qualifies if
    it has ``torch.Tag.pointwise`` or ``is_pointwise_fn`` accepts it.
    """
    if use.op != "call_function":
        return False
    target = use.target
    if not isinstance(target, torch._ops.OpOverload) and target is not operator.getitem:
        return False
    overload = cast(torch._ops.OpOverload, target)
    if overload is operator.getitem or is_view(overload):
        return all(is_pointwise_use(user, is_pointwise_fn) for user in use.users)
    return torch.Tag.pointwise in overload.tags or is_pointwise_fn(overload)
def gen_gm_and_inputs(
    target: Any, args: list[Any], kwargs: dict[str, Any]
) -> tuple[GraphModule, list[torch.Tensor]]:
    """
    Build a GraphModule that makes a single call to ``target``.

    Every tensor found in ``args``/``kwargs`` becomes a placeholder
    (``arg1``, ``arg2``, ...); non-tensor values are embedded in the call.
    If ``target``'s schema returns exactly one ``Tensor``, the graph output
    is wrapped in a 1-tuple.

    Args:
        target: an op overload exposing ``_schema``.
        args: positional arguments (a list or a tuple).
        kwargs: keyword arguments.

    Returns:
        ``(gm, graph_args)``: the module and the tensors to pass to it, in
        placeholder order.
    """
    g = torch.fx.Graph()
    graph_args: list[torch.Tensor] = []

    def add_tensor_arg(arg: torch.Tensor) -> Node:
        graph_args.append(arg)
        return g.placeholder(f"arg{len(graph_args)}")

    mapped_args, mapped_kwargs = tree_map_only(
        torch.Tensor, add_tensor_arg, (args, kwargs)
    )
    # torch.fx nodes take positional args as a tuple; callers may pass a list
    # (as the annotation suggests), so normalize here.
    node = g.call_function(target, tuple(mapped_args), mapped_kwargs)
    returns = target._schema.returns
    if len(returns) == 1 and str(returns[0].type) == "Tensor":
        node = (node,)  # type: ignore[assignment]
    g.output(node)
    gm = torch.fx.GraphModule({}, g)
    return gm, graph_args
def synchronize(device: str = "cuda") -> None:
    """Wait for pending work on ``device``; a no-op for "cpu" or an unavailable backend."""
    if device == "cpu":
        return
    device_interface = get_interface_for_device(device)
    if device_interface.is_available():
        device_interface.synchronize()


def timed(
    model: Callable[..., Any],
    example_inputs: Sequence[Any],
    times: int = 1,
    device: str = "cuda",
) -> float:
    """
    Call ``model(*example_inputs)`` ``times`` times; return elapsed wall-clock seconds.

    The device is synchronized before the clock starts and after the last
    call, and ``torch.manual_seed(1337)`` is applied first. Works for
    ``times == 0`` and for models that return ``None``.
    """
    synchronize(device)
    torch.manual_seed(1337)
    # Keep the last result referenced until after the clock stops so its
    # deallocation is not timed. (Replaces an assert that crashed when times == 0
    # or when the model returned None.)
    result: Any = None
    t0 = time.perf_counter()
    for _ in range(times):
        result = model(*example_inputs)
    synchronize(device)
    t1 = time.perf_counter()
    del result
    return t1 - t0
def print_performance(
    model: Callable[..., Any],
    example_inputs: Sequence[Any] = (),
    times: int = 10,
    repeat: int = 10,
    baseline: float = 1.0,
    device: str = "cuda",
) -> float:
    """
    Benchmark ``model`` and report its median per-call time.

    Runs ``timed`` ``repeat`` times (``times`` calls each), prints the
    torch.median per-call seconds divided by ``baseline`` to six decimals,
    and returns the median per-call seconds.
    """
    samples = [timed(model, example_inputs, times, device) for _ in range(repeat)]
    median_per_call = torch.median(torch.tensor(samples)) / times
    print(f"{median_per_call / baseline:.6f}")
    return median_per_call.item()
def precompute_method(obj: Any, method: str) -> None:
    """Call ``obj.<method>()`` once and shadow it on ``obj`` with a
    zero-argument function that returns that stored result."""
    precomputed = getattr(obj, method)()
    setattr(obj, method, lambda: precomputed)


def precompute_methods(obj: Any, methods: list[str]) -> None:
    """Apply precompute_method to ``obj`` for each name in ``methods``, in order."""
    for method_name in methods:
        precompute_method(obj, method_name)
def cmp(a: int, b: int) -> int:
    """Three-way comparison: 1 if a > b, -1 if a < b, otherwise 0."""
    greater = int(a > b)
    less = int(a < b)
    return greater - less
def pad_listlike(x: Union[int, Sequence[int]], size: int) -> Sequence[int]:
    """
    Broadcast ``x`` to ``size`` entries.

    An int becomes a list of ``size`` copies; a one-element sequence is
    repeated ``size`` times in a container of the same type; any other
    sequence is returned unchanged.
    """
    if isinstance(x, int):
        return [x] * size
    if len(x) != 1:
        return x
    container = type(x)
    return container([x[0]]) * size  # type: ignore[call-arg, operator, return-value]
# Used to ensure that iterating over a set is deterministic
def tuple_sorted(x: tuple[_T, ...]) -> list[_T]:
    """
    Return the elements of ``x`` as a sorted list.

    Strings sort by their own value; any other element must be a
    BaseSchedulerNode and sorts by ``get_name()``.
    """
    if len(x) == 0:
        return []

    def name_key(elem: _T) -> str:
        if not isinstance(elem, str):
            # Local import, presumably to avoid an import cycle with .scheduler.
            from .scheduler import BaseSchedulerNode

            assert isinstance(elem, BaseSchedulerNode)
            return elem.get_name()
        return elem

    return sorted(x, key=name_key)
- P = ParamSpec("P")
- RV = TypeVar("RV", covariant=True)
- FN_TYPE = Callable[Concatenate[Any, P], RV]
- class CachedMethod(Protocol, Generic[P, RV]):
- @staticmethod
- def clear_cache(cache: Any) -> None: ...
- def __call__(self, *args: P.args, **kwargs: P.kwargs) -> RV: ...
# See https://github.com/python/mypy/issues/13222#issuecomment-1193073470 to understand the type signature
def cache_on_self(fn: Callable[Concatenate[Any, P], RV]) -> CachedMethod[P, RV]:
    """
    Memoize a zero-argument method's result on the instance.

    The first call stores ``fn(self)`` in the instance attribute
    ``__<name>_cache`` (written with ``object.__setattr__`` so frozen
    dataclasses are supported); later calls return the stored value.
    ``<method>.clear_cache(self)`` removes it so the next call recomputes.
    """
    name = fn.__name__
    key = f"__{name}_cache"

    # wrapper is likely on the hot path, compile a specialized version of it
    ctx = {"fn": fn}
    exec(
        f"""\
def {name}_cache_on_self(self):
    try:
        return self.{key}
    except AttributeError:
        pass
    rv = fn(self)
    object.__setattr__(self, "{key}", rv)
    return rv
""".lstrip(),
        ctx,
    )
    wrapper = functools.wraps(fn)(ctx[f"{name}_cache_on_self"])

    def clear_cache(self: Any) -> None:
        # Mirror the object.__setattr__ used to store the value: a plain
        # delattr raises FrozenInstanceError on frozen dataclasses.
        if hasattr(self, key):
            object.__delattr__(self, key)

    wrapper.clear_cache = clear_cache  # type: ignore[attr-defined]
    return wrapper  # type: ignore[return-value]
def cache_property_on_self(fn: Callable[P, RV]) -> CachedMethod[P, RV]:
    """
    cache_on_self for functions used as properties; identical behavior, only
    the type signature differs.
    """
    # pyrefly: ignore [bad-argument-type]
    cached = cache_on_self(fn)
    return cached
def cache_on_self_and_args(
    class_name: str,
) -> Callable[[FN_TYPE[P, RV]], FN_TYPE[P, RV]]:
    """Decorator factory that memoizes a method per call arguments, on the instance.

    Results are kept in a dict stored on ``self`` under
    ``__<class_name>_<method name>_cache``. The dict is keyed by
    ``(args, tuple(sorted(kwargs.items())))``, so all arguments must be
    hashable. The decorated method exposes ``clear_cache(self)``, which drops
    the whole dict. Storage and removal go through ``object.__setattr__`` and
    ``object.__delattr__``, so frozen dataclasses are supported.

    Args:
        class_name: included in the attribute name so that an override and the
            ``super()`` implementation it calls keep separate caches.
    """
    # include both class_name and fn_name in the key to support `super().fn(self, **args, **kwargs)` calls.

    def wrapper(
        fn: FN_TYPE[P, RV],
    ) -> FN_TYPE[P, RV]:
        key = f"__{class_name}_{fn.__name__}_cache"

        # wrapper is likely on the hot path, compile a specialized version of it.
        # The generated source is deliberately unannotated: the exec globals only
        # contain ``fn``, and functools.wraps copies fn's annotations anyway.
        ctx = {"fn": fn}
        exec(
            f"""\
            def inner(self, *args, **kwargs):
                args_kwargs = (args, tuple(sorted(kwargs.items())))
                if not hasattr(self, "{key}"):
                    object.__setattr__(self, "{key}", {{}})
                cache = self.{key}
                try:
                    return cache[args_kwargs]
                except KeyError:
                    pass
                rv = fn(self, *args, **kwargs)
                cache[args_kwargs] = rv
                return rv
            """.lstrip(),
            ctx,
        )
        inner = functools.wraps(fn)(ctx["inner"])

        def clear_cache(self: Any) -> None:
            if hasattr(self, key):
                # Mirror the object.__setattr__ used for storing: a plain delattr
                # raises FrozenInstanceError on frozen dataclasses.
                object.__delattr__(self, key)

        inner.clear_cache = clear_cache  # type: ignore[attr-defined]

        return inner

    return wrapper
def aggregate_origins(
    node_schedule: Union[Sequence[BaseSchedulerNode], ExternKernel],
) -> OrderedSet[Node]:
    """Collect the FX origin nodes behind a kernel.

    Args:
        node_schedule: either a list/tuple of scheduler nodes or an
            ``ir.ExternKernel``.

    Returns:
        For a list/tuple, the ordered union of ``node.node.origins`` over the
        entries that have a truthy ``node`` attribute. For an ``ExternKernel``,
        its ``origins``. For anything else, an empty set.
    """
    from . import ir

    if isinstance(node_schedule, ir.ExternKernel):
        return node_schedule.origins
    # Accept tuples as well as lists, matching the ``Sequence`` annotation;
    # a tuple schedule would otherwise silently produce no origins.
    if isinstance(node_schedule, (list, tuple)):
        return functools.reduce(
            operator.or_,
            [
                node.node.origins
                for node in node_schedule
                if hasattr(node, "node") and node.node
            ],
            OrderedSet(),
        )
    return OrderedSet()
def get_fused_kernel_name(
    node_schedule: Sequence[BaseSchedulerNode],
    descriptive_names: Literal[True, "torch", "original_aten", "inductor_node"],
) -> str:
    """Build a ``fused_<a>_<b>...`` kernel name from the schedule's FX origins.

    Only ``call_function`` origins contribute. ``descriptive_names`` picks the
    naming source:

    * ``"original_aten"``: overload-packet name of ``meta["original_aten"]``
      (the top-level aten op before decompositions); deduplicated and sorted.
    * ``"torch"``: second element of the last ``meta["source_fn_stack"]``
      entry (a string, or a callable's ``__name__``); deduplicated and sorted.
    * ``"inductor_node"``: the origin node names, in origin order, not
      deduplicated.

    Raises:
        NotImplementedError: for any other value, including ``True``.
    """
    all_origins = aggregate_origins(node_schedule)
    if descriptive_names == "original_aten":
        # Bases the kernel name off of the top-level aten operator (i.e. pre-decompositions)
        sources = [
            origin.meta["original_aten"]._overloadpacket.__name__
            for origin in all_origins
            if origin.op == "call_function"
            and "original_aten" in origin.meta
            and origin.meta["original_aten"] is not None
        ]
        sources = sorted(OrderedSet(sources))
    elif descriptive_names == "torch":
        # Bases the kernel name off of the top-level "torch" operator (i.e. post-dynamo graph)
        sources = []
        for origin in all_origins:
            if origin.op == "call_function" and "source_fn_stack" in origin.meta:
                source_fn = origin.meta["source_fn_stack"][-1]
                if isinstance(source_fn[1], str):
                    sources.append(source_fn[1])
                else:
                    sources.append(source_fn[1].__name__)
        sources = sorted(OrderedSet(sources))
    elif descriptive_names == "inductor_node":
        sources = [
            origin.name for origin in all_origins if origin.op == "call_function"
        ]
    else:
        raise NotImplementedError(
            f"Unsupported descriptive_names value: {descriptive_names!r}"
        )
    return "_".join(["fused"] + sources)
def get_kernel_metadata(
    node_schedule: Union[Sequence[BaseSchedulerNode], ExternKernel],
    wrapper: PythonWrapperCodegen,
) -> tuple[str, str]:
    """
    Retrieves metadata information for a kernel.

    Every emitted line is prefixed with ``wrapper.comment``.

    Args:
        node_schedule (Union[Sequence[BaseSchedulerNode], ExternKernel]):
            Either a sequence of BaseSchedulerNode objects or an ExternKernel instance.
        wrapper (PythonWrapperCodegen):
            An instance of PythonWrapperCodegen, used to define the code comment format.

    Returns:
        tuple[str, str]:
            A tuple containing two strings:
                - The first string is a one-line summary: whether the origin
                  nodes were topologically sorted, the ``from_node`` source
                  names, and the ``original_aten`` op names.
                - The second string is the multi-line detailed metadata: a
                  source-node -> ATen-node mapping and, when all origins come
                  from a single FX graph, a rendered graph fragment
                  (placeholders, nodes, and a ``return`` of written buffers).
    """
    all_origins = aggregate_origins(node_schedule)
    inductor_nodes = [origin for origin in all_origins if origin.op == "call_function"]

    # Both map a key to the list of inductor node names that share it.
    from_node_dict = collections.defaultdict(list)
    original_aten_dict = collections.defaultdict(list)

    # Attempt to sort `inductor_nodes` topologically. Note that the case
    # where `inductor_nodes` contains nodes from multiple graph instances
    # is not supported. An example of this is conditional statements.
    single_graph = None
    if len(inductor_nodes):
        unique_graphs = OrderedSet(n.graph for n in inductor_nodes)
        if len(unique_graphs) == 1:
            single_graph = inductor_nodes[0].graph
            # create a map of node -> position in graph order and cache it on
            # the graph object so later calls on the same graph reuse it
            if not hasattr(single_graph, "_inductor_kernel_metadata_node_to_idx_map"):
                node_to_idx_map = {n: idx for idx, n in enumerate(single_graph.nodes)}
                single_graph._inductor_kernel_metadata_node_to_idx_map = node_to_idx_map  # type: ignore[attr-defined]
            inductor_nodes.sort(
                key=lambda n: single_graph._inductor_kernel_metadata_node_to_idx_map[n]  # type: ignore[attr-defined]
            )

    for node in inductor_nodes:
        if "original_aten" in node.meta and node.meta["original_aten"] is not None:
            key = str(node.meta["original_aten"]._overloadpacket)
            original_aten_dict[key].append(node.name)
        if "from_node" in node.meta:
            key = node.meta["from_node"][0].name
            from_node_dict[key].append(node.name)
    sort_str = "Topologically Sorted" if single_graph is not None else "Unsorted"
    metadata = (
        f"{wrapper.comment} {sort_str} Source Nodes: [{', '.join(from_node_dict.keys())}], "
        f"Original ATen: [{', '.join(original_aten_dict.keys())}]"
    )

    # trace back to original node here
    detailed_metadata = [f"{wrapper.comment} Source node to ATen node mapping:"]
    for original_node, nodes in sorted(from_node_dict.items()):
        detailed_metadata.append(
            f"{wrapper.comment} {original_node} => {', '.join(sorted(nodes))}"
        )

    # print the aot_autograd graph fragment
    if single_graph is not None:
        from . import ir

        detailed_metadata.append(f"{wrapper.comment} Graph fragment:")
        all_reads: OrderedSet[str] = OrderedSet()
        all_writes: list[str] = []
        # Placeholder/return lines are only produced for scheduler-node
        # schedules; an ExternKernel schedule contributes just the node lines.
        if not isinstance(node_schedule, ir.ExternKernel):
            from .virtualized import V

            def get_buffer_info(
                buffer: Union[ir.TensorBox, ir.Buffer, ir.TorchBindObject], rw_name: str
            ) -> tuple[str, ir.Layout | None]:
                # Returns (display name, layout or None if the buffer has no layout).
                if isinstance(buffer, ir.TensorBox) and isinstance(
                    buffer.data, ir.StorageBox
                ):
                    origin_node = buffer.data.data.origin_node
                else:
                    origin_node = buffer.origin_node
                if origin_node is None:
                    # use the read/write name if no origin node is found
                    name = rw_name
                else:
                    name = origin_node.name
                try:
                    layout = buffer.get_layout()
                except NotImplementedError:
                    layout = None
                return name, layout

            def stringify_shape(shape: Iterable[int]) -> str:
                return f"[{', '.join([str(x) for x in shape])}]"

            def stringfy_layout(layout: ir.Layout | None) -> str:
                # Renders as "<dtype abbr>[sizes][strides]<device>" in quotes.
                if layout is None:
                    return ""
                shape_annotation = f"{stringify_shape(layout.size)}"
                stride_annotation = f"{stringify_shape(layout.stride)}"
                device_annotation = f"{layout.device}"

                return (
                    f'"{dtype_abbrs[layout.dtype]}{shape_annotation}'
                    f'{stride_annotation}{device_annotation}"'
                )

            for n in node_schedule:
                if not hasattr(n, "read_writes") or n.read_writes is None:
                    continue
                if hasattr(n.read_writes, "reads") and n.read_writes.reads is not None:
                    for r in n.read_writes.reads:
                        # Remove the duplicated inputs
                        if r.name in all_reads:
                            continue
                        all_reads.add(r.name)
                        buffer = V.graph.try_get_buffer(r.name)
                        if buffer is None:
                            continue
                        input_name, layout = get_buffer_info(buffer, r.name)
                        detailed_metadata.append(
                            f"{wrapper.comment} %{input_name} : Tensor "
                            f"{stringfy_layout(layout)} = PlaceHolder[target={input_name}]"
                        )

                if (
                    hasattr(n.read_writes, "writes")
                    and n.read_writes.writes is not None
                ):
                    # Writes are not deduplicated, unlike reads.
                    for w in n.read_writes.writes:
                        buffer = V.graph.try_get_buffer(w.name)
                        if buffer is None:
                            continue
                        output_name, _ = get_buffer_info(buffer, w.name)
                        all_writes.append("%" + output_name)

        for node in inductor_nodes:
            detailed_metadata.append(
                f"{wrapper.comment} {node.format_node(include_tensor_metadata=True)}"
            )

        detailed_metadata.append(f"{wrapper.comment} return {','.join(all_writes)}")

    return metadata, "\n".join(detailed_metadata)
- def dominated_nodes(
- initial_queue: Iterable[torch.fx.Node],
- skip_filter: Optional[Callable[[Any], bool]] = None,
- ) -> OrderedSet[torch.fx.Node]:
- """Returns the set of nodes whose values depend on those within initial_queue"""
- initial_queue = list(initial_queue)
- dominated_set = OrderedSet(initial_queue)
- while initial_queue:
- node = initial_queue.pop()
- for user in node.users:
- if skip_filter and skip_filter(user):
- continue
- if user not in dominated_set:
- dominated_set.add(user)
- initial_queue.append(user)
- return dominated_set
def gather_origins(
    args: Sequence[IRNode], kwargs: dict[str, IRNode]
) -> OrderedSet[torch.fx.Node]:
    """Collect the ``origins`` of every unrealized IR node among the arguments.

    Nested containers in ``args``/``kwargs`` are flattened with ``tree_flatten``.
    A node counts as unrealized when, after looking through ``TensorBox``/
    ``StorageBox`` wrappers, it is an ``ir.IRNode`` that is not a
    ``ComputedBuffer``, ``InputsKernel``, ``InputBuffer`` or ``TemplateBuffer``.
    Origins from ``args`` come before those from ``kwargs``.
    """
    from . import ir

    realized_types = (
        ir.ComputedBuffer,
        ir.InputsKernel,
        ir.InputBuffer,
        ir.TemplateBuffer,
    )

    def is_unrealized_node(n: IRNode) -> bool:
        # Unwrap box layers down to the underlying node.
        while isinstance(n, (ir.TensorBox, ir.StorageBox)):
            n = n.data
        return isinstance(n, ir.IRNode) and not isinstance(n, realized_types)

    # kwargs and args may include a container of node, for example torch.cat([t1, t2])
    # flatten them before searching for the unrealized nodes
    kwargs_leaves, _ = tree_flatten(kwargs)
    kwargs_origins = [
        leaf.origins for leaf in kwargs_leaves if is_unrealized_node(leaf)
    ]
    args_leaves, _ = tree_flatten(args)
    args_origins = [leaf.origins for leaf in args_leaves if is_unrealized_node(leaf)]
    return OrderedSet(itertools.chain.from_iterable(args_origins + kwargs_origins))
def sympy_str(expr: sympy.Expr) -> str:
    """
    Normal sympy str is very slow, this is a lot faster. The results are
    somewhat worse, as it doesn't do as much simplification. So don't
    use this for final codegen.
    """

    def negated(term: sympy.Expr) -> bool:
        # A two-factor product whose first factor is -1, i.e. "-1 * t".
        return (
            isinstance(term, sympy.Mul) and len(term.args) == 2 and term.args[0] == -1
        )

    def render_sum(e: sympy.Expr) -> str:
        if not isinstance(e, sympy.Add):
            return render_product(e)
        # Special case 'a - b'. Note that 'a - b - c' will still appear as
        # 'a + -1 * b + -1 * c'.
        if len(e.args) == 2 and negated(e.args[1]):
            lhs, rhs = e.args
            return f"{render_product(lhs)} - {render_product(rhs.args[1])}"
        return " + ".join(render_product(term) for term in e.args)

    def render_product(e: sympy.Expr) -> str:
        if not isinstance(e, sympy.Mul):
            return render_atom(e)
        # Special case '-a'. Note that 'a * -b' will still appear as
        # '-1 * a * b'.
        if negated(e):
            return f"-{render_atom(e.args[1])}"
        return " * ".join(render_atom(factor) for factor in e.args)

    def render_atom(e: sympy.Expr) -> str:
        if isinstance(e, sympy.Symbol):
            return e.name
        if isinstance(e, (sympy.Add, sympy.Mul)):
            return f"({render_sum(e)})"
        if isinstance(e, (ModularIndexing, CleanDiv, FloorDiv, Identity)):
            return f"{e.func.__name__}({', '.join(map(sympy_str, e.args))})"
        return str(e)

    return render_sum(expr)
def get_bounds_index_expr(index: sympy.Expr) -> ValueRanges[Any]:
    """Value range of ``index``, or ``ValueRanges.unknown()``.

    Bounds are only computed (via ``bound_sympy``) when
    ``config.compute_all_bounds`` is set and the interpreter's current FX node
    exists and is not an ``index_expr`` node.
    """
    from .virtualized import V

    if not config.compute_all_bounds:
        return ValueRanges.unknown()
    fx_node = getattr(V.interpreter, "current_node", None)
    if fx_node and fx_node.target != "index_expr":
        return bound_sympy(index)
    return ValueRanges.unknown()
def prefix_is_reduction(prefix: str) -> bool:
    """Return True when ``prefix`` starts with ``"r"`` (a reduction prefix).

    An empty prefix returns False instead of raising ``IndexError``.
    """
    return prefix.startswith("r")
- def sympy_index_symbol_with_prefix(prefix: SymT, idx: int) -> sympy.Symbol:
- """
- Used to generate an integer-nonnegative symbol.
- """
- # This should never be used for creating shape/stride symbols, as those
- # should all be allocated before Inductor.
- assert prefix != SymT.SIZE
- # NOTE: shape symbols are positive (> 0), but index variables are only
- # non-negative (>= 0).
- return make_symbol(prefix, idx, integer=True, nonnegative=True)
def generate_assert(check: bool) -> bool:
    """Whether to emit an indirect-indexing assert.

    An assert is requested when ``check`` is set or
    ``config.debug_index_asserts`` is on. The request is honoured only when
    ``config.assert_indirect_indexing`` is enabled.
    """
    wants_assert = check or config.debug_index_asserts
    return wants_assert and config.assert_indirect_indexing
def sympy_index_symbol(name: str) -> sympy.Symbol:
    """
    Create an integer, non-negative symbol named ``name`` for an index variable.

    Names starting with ``"s"`` are rejected: shape/stride symbols are all
    allocated before Inductor and must never be created here. Index variables
    are only non-negative (>= 0), unlike positive (> 0) shape symbols.
    """
    assert name[0] != "s"
    return sympy.Symbol(name, integer=True, nonnegative=True)
def sympy_subs(expr: sympy.Expr, replacements: dict[sympy.Expr, Any]) -> sympy.Expr:
    """
    Substitute ``replacements`` into ``expr`` using ``xreplace``.

    When a replacement value is a string, it is converted to a symbol with that
    name carrying the same ``is_integer`` / ``is_nonnegative`` assumptions as
    the expression it replaces.
    """

    def to_symbol(
        replaced: sympy.Expr, replacement: Union[sympy.Expr, str]
    ) -> sympy.Symbol:
        assert isinstance(replaced, sympy.Expr)
        if not isinstance(replacement, str):
            return replacement
        return sympy.Symbol(
            replacement,
            integer=replaced.is_integer,  # type: ignore[attr-defined]
            nonnegative=replaced.is_nonnegative,  # type: ignore[attr-defined]
        )

    base = sympy.sympify(expr)
    mapping = {k: to_symbol(k, v) for k, v in replacements.items()}
    # xreplace is faster than subs, but is way more picky
    return base.xreplace(mapping)
- def is_symbolic(a: Any) -> TypeGuard[Union[torch.SymInt, torch.Tensor]]:
- return isinstance(a, torch.SymInt) or (
- isinstance(a, torch.Tensor)
- and any(is_symbolic(x) for x in itertools.chain(a.size(), a.stride()))
- )
def any_is_symbolic(*args: Any) -> bool:
    """True if at least one argument satisfies ``is_symbolic``."""
    return any(map(is_symbolic, args))
- def get_first_incompatible_cudagraph_node(
- gm: torch.fx.GraphModule,
- ) -> Optional[torch.fx.Node]:
- from torch.fx.experimental.symbolic_shapes import free_unbacked_symbols
- forbidden_set = OrderedSet(
- [
- "aten._fused_moving_avg_obs_fq_helper.default",
- "aten._fused_moving_avg_obs_fq_helper_functional.default",
- "fbgemm.dense_to_jagged.default",
- "fbgemm.jagged_to_padded_dense.default",
- "run_and_save_rng_state",
- "run_with_rng_state",
- "aten._local_scalar_dense",
- # Technically, it's not necessary to ban this, because an
- # assert_scalar with constant arguments can be validly run
- # with CUDA graphs, but the operator is also pointless with
- # constant arguments, so might as well ban
- "aten._assert_scalar",
- ]
- )
- if torch.are_deterministic_algorithms_enabled():
- forbidden_set.update(
- (
- "aten._unsafe_index_put.default",
- "aten._unsafe_masked_index_put_accumulate.default",
- "aten.index_put.default",
- "aten.index_put_.default",
- "aten.scatter.src",
- "aten.scatter.reduce",
- "aten.scatter.value_reduce",
- "aten.scatter_add_",
- "aten.scatter_add.default",
- "aten.scatter_reduce.two",
- "aten.scatter_reduce_.two",
- "aten.scatter_reduce.two_out",
- )
- )
- for node in gm.graph.nodes:
- if str(node.target) in forbidden_set:
- return node
- if (
- not torch._inductor.config.graph_partition
- and isinstance(node.target, torch._ops.OpOverload)
- and torch._C.Tag.cudagraph_unsafe in node.target.tags # type: ignore[attr-defined]
- ):
- # skip cudagraph if a cudagraph_unsafe op is detected.
- # graph_partition helps by splitting on this cudagraph_unsafe
- # op and cudagraphifying the subgraphs.
- return node
- if (val := node.meta.get("val")) is not None and free_unbacked_symbols(val):
- return node
- return None
def output_node(gm: torch.fx.GraphModule) -> Node:
    """Get the output node from an FX graph (asserts that the last node is ``output``)."""
    nodes_backwards = reversed(gm.graph.nodes)
    last = next(iter(nodes_backwards))
    assert last.op == "output"
    return last
def get_all_devices(gm: torch.fx.GraphModule) -> OrderedSet[torch.device]:
    """Devices of the tensor-valued placeholders and graph outputs of ``gm``.

    Devices are read from ``node.meta["val"]`` and only tensor values are
    considered. Input devices come first in the returned ordered set.
    """
    input_devices: OrderedSet[torch.device] = OrderedSet(
        node.meta["val"].device
        for node in gm.graph.find_nodes(op="placeholder")
        if isinstance(node.meta.get("val"), torch.Tensor)
    )

    outputs = output_node(gm).args[0]  # type: ignore[union-attr]
    if not isinstance(outputs, tuple):
        outputs = (outputs,)
    output_devices: OrderedSet[torch.device] = OrderedSet(
        arg.meta["val"].device
        for arg in outputs
        if isinstance(arg, torch.fx.Node)
        and isinstance(arg.meta.get("val"), torch.Tensor)
    )
    return input_devices | output_devices
- import gc
def unload_xpu_triton_pyds() -> None:
    """Unload native modules loaded for compiled XPU Triton kernels.

    For every ``torch._inductor.runtime.compile_tasks.*`` module, call
    ``__del__`` on the launcher module of each ``TritonCompileResult`` held by
    its ``triton_*`` ``CachingAutotuner`` attributes, then drop the module from
    ``sys.modules``. Also delete the cached ``spirv_utils`` driver helper and
    run ``gc.collect()``.
    """
    compile_task_prefix = "torch._inductor.runtime.compile_tasks."
    # unload __triton_launcher.pyd
    task_modules = [name for name in sys.modules if name.startswith(compile_task_prefix)]
    for module_name in task_modules:
        module = sys.modules[module_name]
        for attr_name in module.__dict__.keys():
            if not attr_name.startswith("triton_"):
                continue
            kernel = getattr(module, attr_name)
            if not isinstance(
                kernel, torch._inductor.runtime.triton_heuristics.CachingAutotuner
            ):
                continue
            for result in kernel.compile_results:
                if isinstance(
                    result,
                    torch._inductor.runtime.triton_heuristics.TritonCompileResult,
                ):
                    result.kernel.run.mod.__del__()
        del sys.modules[module_name]

    # unload spirv_utils.pyd
    if "triton.runtime.driver" in sys.modules:
        driver_module = sys.modules["triton.runtime.driver"]
        del type(driver_module.driver.active.utils).instance
        del driver_module.driver.active.utils
    gc.collect()
# Objects whose cache_clear() is called by clear_caches(), in registration order.
_registered_caches: list[Any] = []


def clear_on_fresh_cache(obj: Any) -> Any:
    """
    Use this decorator to register any caches that should be cache_clear'd
    with fresh_cache(). Returns ``obj`` unchanged.

    Raises:
        AttributeError: if ``obj`` has no callable ``cache_clear`` attribute.
    """
    if not (hasattr(obj, "cache_clear") and callable(obj.cache_clear)):
        raise AttributeError(f"{obj} does not have a cache_clear method")
    _registered_caches.append(obj)
    return obj


def clear_caches() -> None:
    """
    Clear all registered caches by calling each one's ``cache_clear()``.
    """
    for registered in _registered_caches:
        registered.cache_clear()
@contextlib.contextmanager
def fresh_cache(
    cache_entries: Optional[dict[str, Any]] = None,
    dir: Optional[str] = None,
    delete: bool = True,
) -> Iterator[None]:
    """
    Contextmanager that provides a clean tmp cachedir for pt2 caches.

    A new directory is created with ``tempfile.mkdtemp(dir=dir)``. While the
    ``with`` body runs, ``TORCHINDUCTOR_CACHE_DIR`` points at it and
    ``TRITON_CACHE_DIR`` points at its ``triton`` subdirectory. Caches
    registered through ``clear_on_fresh_cache`` are cleared on entry and
    again on exit.

    Args:
        cache_entries: optional dict that must be empty on entry. After a body
            that did not raise, it is filled with ``{filename: size_in_bytes}``
            for the files in the triton cache dir (names containing ".lock" are
            skipped).
        dir: parent directory for the temporary cache dir.
        delete: if True, remove the temporary dir after a body that did not
            raise. If the body raises, the dir is kept and its path is logged.
    """
    clear_caches()

    from torch._inductor.cpp_builder import normalize_path_separator

    inductor_cache_dir = normalize_path_separator(tempfile.mkdtemp(dir=dir))
    try:
        with mock.patch.dict(
            os.environ, {"TORCHINDUCTOR_CACHE_DIR": inductor_cache_dir}
        ):
            log.debug("Using inductor cache dir %s", inductor_cache_dir)
            triton_cache_dir = normalize_path_separator(
                os.path.join(inductor_cache_dir, "triton")
            )
            with mock.patch.dict(os.environ, {"TRITON_CACHE_DIR": triton_cache_dir}):
                yield
                # Only reached when the body did not raise; both environment
                # variables are still patched here.
                if isinstance(cache_entries, dict):
                    assert len(cache_entries) == 0, "expected empty cache_entries dict"
                    if os.path.exists(triton_cache_dir):
                        files = os.listdir(triton_cache_dir)
                        cache_entries.update(
                            {
                                f: os.path.getsize(os.path.join(triton_cache_dir, f))
                                for f in files
                                if ".lock" not in f
                            }
                        )
        if delete:
            # NOTE(review): presumably the XPU Triton native modules must be
            # unloaded so their files can be removed on Windows — confirm.
            if is_windows() and torch.xpu.is_available():
                unload_xpu_triton_pyds()

            shutil.rmtree(
                inductor_cache_dir,
                # Let's not fail if we can't clean up the temp dir. Also note that for
                # Windows, we can't delete the loaded modules because the module binaries
                # are open.
                ignore_errors=is_windows(),
                # Per the shutil docs the handler is only consulted when
                # ignore_errors is False; it logs a warning instead of raising.
                # NOTE(review): `onerror` is deprecated since Python 3.12 in
                # favour of `onexc` (which receives the exception instance).
                onerror=lambda func, path, exc_info: log.warning(
                    "Failed to remove temporary cache dir at %s",
                    inductor_cache_dir,
                    exc_info=exc_info,
                ),
            )
    except Exception:
        # The body (or cleanup) failed: keep the directory for inspection.
        log.warning("on error, temporary cache dir kept at %s", inductor_cache_dir)
        raise
    finally:
        clear_caches()
# Deprecated aliases -- kept only for backward compatibility (BC) with the
# older "inductor"-prefixed names.
fresh_inductor_cache = fresh_cache
clear_inductor_caches = clear_caches
clear_on_fresh_inductor_cache = clear_on_fresh_cache
def argsort(seq: Sequence[Any]) -> list[int]:
    """Return the indices that order ``seq`` ascending.

    Ties come out in *reverse* index order: a stable descending sort is
    reversed, so for strides ``[2048, 2048, 16, 1]`` the result is
    ``[3, 2, 1, 0]``. ``argsort_sym`` follows the same convention.
    """
    descending = sorted(range(len(seq)), key=seq.__getitem__, reverse=True)
    return descending[::-1]
def argsort_sym(
    shape_env: ShapeEnv, seq: Sequence[Union[int, torch.SymInt, sympy.Expr]]
) -> list[int]:
    """Symbolic-aware ``argsort``: indices that order ``seq`` ascending.

    ``SymInt`` entries are replaced by their sympy expressions. Comparisons
    that do not reduce to a Python ``bool`` are decided by
    ``shape_env.evaluate_expr(..., size_oblivious=True)``. Ties come out with
    the larger original index first, matching ``argsort``.
    """

    def evaluate(expr: Union[bool, torch.SymInt, sympy.Expr]) -> bool:
        if isinstance(expr, bool):
            return expr
        return shape_env.evaluate_expr(expr, size_oblivious=True)

    def compare(a: tuple[int, sympy.Expr], b: tuple[int, sympy.Expr]) -> int:
        a_idx, a_val = a
        b_idx, b_val = b
        if evaluate(a_val < b_val):
            return -1
        if evaluate(a_val > b_val):
            return 1
        # If strides are the same, order by descending original index
        # (this matches argsort's algorithm).
        # For strides = [2048, 2048, 16, 1], this is [3, 2, 1, 0].
        if a_idx == b_idx:
            return 0
        return 1 if a_idx < b_idx else -1

    # Strategy: convert all symints to sympy.Expr, then use a custom comparator
    keyed = [
        (idx, s.node.expr if isinstance(s, torch.SymInt) else s)
        for idx, s in enumerate(seq)
    ]
    keyed.sort(key=functools.cmp_to_key(compare))
    return [idx for idx, _ in keyed]
- @functools.lru_cache(8)
- def get_dtype_size(dtype: torch.dtype) -> int:
- # TODO: Investigate why uint64 tensor creation causes overflow error:
- # Workaround for RuntimeError in memory size calculation, but underlying cause unclear
- if dtype == torch.uint64:
- return 8
- return torch.empty((), dtype=dtype).element_size()
class LineContext(NamedTuple):
    """Marker entry for an IndentedBuffer that carries an opaque ``context`` value."""

    context: Any
@dataclasses.dataclass
class ValueWithLineMap:
    """Rendered buffer text together with its line map."""

    # the rendered text
    value: str
    # (1-based output line number, context) pairs
    line_map: list[tuple[int, LineContext]]
class IndentedBuffer:
    """Line buffer for generated source code that tracks the current indent level.

    ``_lines`` can hold three kinds of entries:

    * plain strings, already carrying the indent prefix active when written;
    * ``DeferredLineBase`` objects, resolved to text (or dropped) at render time;
    * ``LineContext`` markers, which emit no text but are recorded by
      ``getvaluewithlinemap``.
    """

    tabwidth = 4

    def __init__(self, initial_indent: int = 0) -> None:
        self._lines: list[Union[DeferredLineBase, LineContext, str]] = []
        self._indent = initial_indent

    @contextlib.contextmanager
    def set_tabwidth(self, tabwidth: int) -> Iterator[None]:
        """Temporarily use ``tabwidth`` spaces per indent level."""
        saved_tabwidth = self.tabwidth
        try:
            self.tabwidth = tabwidth
            yield
        finally:
            self.tabwidth = saved_tabwidth

    def getvaluewithlinemap(self) -> ValueWithLineMap:
        """Render the buffer and collect ``(line_number, context)`` pairs.

        Each pair records the 1-based output line at which a ``LineContext``
        marker appears. Newlines embedded in entries are counted.
        """
        out = StringIO()
        lineno = 1
        line_map: list[tuple[int, LineContext]] = []
        for entry in self._lines:
            if isinstance(entry, LineContext):
                line_map.append((lineno, entry.context))
                continue
            if isinstance(entry, DeferredLineBase):
                text = entry()
                if text is None:  # the deferred line was "unwritten"
                    continue
            else:
                text = entry
            assert isinstance(text, str)
            out.write(text)
            out.write("\n")
            lineno += text.count("\n") + 1
        return ValueWithLineMap(out.getvalue(), line_map)

    def getvalue(self) -> str:
        """Render the buffer as text."""
        return self.getvaluewithlinemap().value

    def getrawvalue(self) -> str:
        """Render the buffer, joining a line that ends in a backslash with the next one."""
        out = StringIO()
        for entry in self._lines:
            if isinstance(entry, LineContext):
                continue
            if isinstance(entry, DeferredLineBase):
                text = entry()
                if text is None:
                    continue
            else:
                text = entry
            assert isinstance(text, str)
            if text.endswith("\\"):
                # backslash implies line continuation: drop it and the newline
                out.write(text[:-1])
            else:
                out.write(text)
                out.write("\n")
        return out.getvalue()

    def clear(self) -> None:
        self._lines.clear()

    def __bool__(self) -> bool:
        return bool(self._lines)

    def prefix(self) -> str:
        """Whitespace for the current indent level."""
        return " " * (self._indent * self.tabwidth)

    def newline(self) -> None:
        self.writeline("\n")

    def writeline(self, line: Union[LineContext, DeferredLineBase, str]) -> None:
        """Append one entry, prefixing text with the current indentation.

        Markers are stored as-is. Whitespace-only strings are stored as "".
        """
        if isinstance(line, LineContext):
            entry: Union[LineContext, DeferredLineBase, str] = line
        elif isinstance(line, DeferredLineBase):
            entry = line.with_prefix(self.prefix())
        elif line.strip():
            entry = f"{self.prefix()}{line}"
        else:
            entry = ""
        self._lines.append(entry)

    def writelines(
        self, lines: Sequence[Union[LineContext, DeferredLineBase, str]]
    ) -> None:
        for entry in lines:
            self.writeline(entry)

    def indent(self, offset: int = 1) -> contextlib.AbstractContextManager[None]:
        """Return a context manager that indents by ``offset`` levels while active."""

        @contextlib.contextmanager
        def _indented() -> Iterator[None]:
            self._indent += offset
            try:
                yield
            finally:
                self._indent -= offset

        return _indented()

    def do_indent(self, offset: int = 1) -> None:
        self._indent += offset

    def do_unindent(self, offset: int = 1) -> None:
        self._indent -= offset

    def splice(
        self, other_code: Union[IndentedBuffer, str], strip: bool = False
    ) -> None:
        """Append ``other_code`` re-indented to this buffer's current level.

        * ``IndentedBuffer``: entries are shifted left by their smallest
          leading-whitespace width (over non-empty, non-marker entries) and
          appended. Markers are copied verbatim and ``strip`` is ignored.
        * ``str``: dedented with ``textwrap.dedent``, with leading whitespace
          also removed when ``strip`` is set. Empty input is a no-op. Trailing
          whitespace is dropped and each line goes through ``writeline``.
        """
        if isinstance(other_code, IndentedBuffer):
            margins = [
                len(entry) - len(entry.lstrip())
                for entry in other_code._lines
                if not isinstance(entry, LineContext) and entry
            ]
            shift = min(margins) if margins else 0
            for entry in other_code._lines:
                if isinstance(entry, LineContext):
                    self._lines.append(entry)
                else:
                    # Deliberately use the base-class writeline, bypassing any
                    # subclass override.
                    IndentedBuffer.writeline(self, entry[shift:])
            return

        text = textwrap.dedent(other_code)
        if strip:
            text = text.lstrip()
        if not text:
            return
        for piece in text.rstrip().split("\n"):
            self.writeline(piece)

    def map(self, func: Callable[[Any], Any]) -> IndentedBuffer:
        """New buffer (same indent) whose entries are ``func`` applied to ours."""
        mapped = IndentedBuffer(initial_indent=self._indent)
        mapped._lines = [func(entry) for entry in self._lines]
        return mapped

    def __repr__(self) -> str:
        return f"{type(self)}({self.getvalue()})"

    def __add__(self, other: Self) -> IndentedBuffer:
        """Concatenate two buffers that have the same indent level.

        NOTE(review): entries are re-written through ``writelines``, so stored
        string lines receive the current indent prefix again when ``_indent``
        is non-zero — confirm this is intended.
        """
        assert self._indent == other._indent
        combined = IndentedBuffer(initial_indent=self._indent)
        # TODO(rec): or should this be self.__class__(initial_indent=self._indent)?
        for source in (self._lines, other._lines):
            combined.writelines(source)
        return combined

    def contains(self, new_line: Union[DeferredLineBase, LineContext, str]) -> bool:
        return new_line in self._lines
class FakeIndentedBuffer(IndentedBuffer):
    """IndentedBuffer stand-in whose every attribute access raises RuntimeError.

    Only ``__class__`` is readable, so ``isinstance``/``type`` checks still
    work. Any other attribute lookup reports which attribute was requested.
    """

    def __init__(self) -> None:
        super().__init__()

    def __getattribute__(self, name: str) -> Any:
        if name == "__class__":  # Allow access to the class attribute
            return object.__getattribute__(self, name)
        # The adjacent literals are concatenated, so each piece needs its own
        # separating space to keep the message readable.
        raise RuntimeError(
            f"Tried to call self.{name} on FakeIndentedBuffer. This buffer "
            "is currently used on TritonTemplateKernel to prevent actual "
            "writes to the body without explicitly specifying the body with "
            "`TritonTemplateKernel.set_subgraph_body(name)`"
        )
@contextlib.contextmanager
def restore_stdout_stderr() -> Iterator[None]:
    """Rebind ``sys.stdout``/``sys.stderr`` to their entry values on exit.

    The original streams are restored even if the body rebinds them or raises.
    """
    saved_streams = (sys.stdout, sys.stderr)
    try:
        yield
    finally:
        sys.stdout, sys.stderr = saved_streams
class DeferredLineBase:
    """A line that can be 'unwritten' at a later time.

    Subclasses implement ``__call__``, which returns the text or None, and
    ``_new_line``, which copies the object with the same condition. Slicing,
    prefixing and ``lstrip`` all go through ``_new_line``.
    """

    def __init__(self, line: str):
        # Whitespace-only input is normalised to the empty string.
        self.line = line if line.strip() else ""

    def __call__(self) -> Union[str, None]:
        """Returns either self.line or None to indicate the line has been 'unwritten'"""
        raise NotImplementedError

    def _new_line(self, line: str) -> Self:
        """Returns a new deferred line with the same condition"""
        raise NotImplementedError

    def with_prefix(self, prefix: str) -> Self:
        prefixed = f"{prefix}{self.line}"
        return self._new_line(prefixed)

    def lstrip(self) -> Self:
        stripped = self.line.lstrip()
        return self._new_line(stripped)

    def __getitem__(self, index: Union[int, slice]) -> Self:
        piece = self.line[index]
        return self._new_line(piece)

    def __bool__(self) -> bool:
        return bool(self.line)

    def __len__(self) -> int:
        return len(self.line)
class DelayReplaceLine(DeferredLineBase):
    """At end of codegen call `line.replace(key, value_fn())`"""

    def __init__(self, key: str, value_fn: Callable[[], str], line: str):
        super().__init__(line)
        self.key = key
        # Called lazily, each time the line is rendered.
        self.value_fn = value_fn

    def __call__(self) -> str:
        replacement = self.value_fn()
        return self.line.replace(self.key, replacement)

    def _new_line(self, line: str) -> DelayReplaceLine:
        # Copies share the same key and value_fn.
        return DelayReplaceLine(self.key, self.value_fn, line)
@functools.cache
def is_big_gpu(index_or_device: Union[int, torch.device] = 0) -> bool:
    """Return whether the GPU is considered big enough for max-autotune GEMM.

    Args:
        index_or_device: a ``torch.device``, or a device index on the type
            returned by ``get_gpu_type()``. Results are memoized per argument.

    Rules:
        * ROCm (``torch.version.hip`` set): reject arch major versions below 9
          and equal to 10; accept everything else.
        * Otherwise: require at least 16 multiprocessors on XPU and 68 (the
          SM count of an RTX 3080) elsewhere.

    A warning is logged whenever False is returned.
    """
    if isinstance(index_or_device, torch.device):
        device = index_or_device
    else:
        device = torch.device(get_gpu_type(), index_or_device)

    prop = DeviceProperties.create(device)

    # SM logic is not relevant to ROCm gpus
    # Arbitrarily skipping the older models
    if torch.version.hip:
        assert prop.major is not None
        if prop.major < 9 or prop.major == 10:
            log.warning("GPU arch does not support max_autotune_gemm mode usage")
            return False
        return True

    min_sms = 16 if device.type == "xpu" else 68  # 3080
    avail_sms = prop.multi_processor_count
    if avail_sms < min_sms:
        # Include the counts in the message itself: fields passed only through
        # `extra` are not rendered by the default logging formatters.
        log.warning(
            "Not enough SMs to use max_autotune_gemm mode (available: %s, required: %s)",
            avail_sms,
            min_sms,
            extra={"min_sms": min_sms, "avail_sms": avail_sms},
        )
        return False
    return True
@functools.lru_cache
def get_max_num_sms() -> int:
    """Hardware multiprocessor count (memoized).

    On XPU this is ``gpu_subslice_count``; otherwise it is CUDA's
    ``multi_processor_count``.
    """
    if not torch.xpu.is_available():
        return torch.cuda.get_device_properties("cuda").multi_processor_count
    return torch.xpu.get_device_properties().gpu_subslice_count
@functools.lru_cache
def using_b200() -> bool:
    """Returns true if the device is a NVIDIA B200, otherwise returns false.

    Checks the current CUDA device's compute-capability major version (10).
    The result is memoized.
    """
    if torch.cuda.is_available():
        # compute capability 10.0 or 10.0a is NVIDIA B200
        props = torch.cuda.get_device_properties(torch.cuda.current_device())
        return props.major == 10
    return False
def get_num_sms() -> int:
    """Handle experimental carveout if set otherwise return hardware SM count.

    On XPU the carveout is not applied.
    """
    # TODO we need to properly guard on this global
    if torch.xpu.is_available():
        return get_max_num_sms()
    carveout = torch._C._get_sm_carveout_experimental() or 0
    return get_max_num_sms() - carveout
def get_tma_workspace_arg(
    num_tma_descriptors: int,
    device: torch.device,
    num_programs: Optional[int] = None,
) -> WorkspaceArg:
    """Builds and returns a WorkspaceArg for the device side TMA workspace buffer.

    The size is ``num_programs * num_tma_descriptors * TMA_DESCRIPTOR_SIZE``.
    ``num_programs`` defaults to ``get_num_sms()``. The workspace is not
    zero-initialised and gets a fresh unique outer name.
    """
    from .codegen.common import WorkspaceArg, WorkspaceZeroMode

    programs = get_num_sms() if num_programs is None else num_programs
    no_zeroing = WorkspaceZeroMode.from_bool(False)
    return WorkspaceArg(
        count=programs * num_tma_descriptors * TMA_DESCRIPTOR_SIZE,
        zero_mode=no_zeroing,
        device=device,
        outer_name=WorkspaceArg.unique_name(),
    )
def _use_template_for_gpu(
    layout: Layout, allowed_layout_dtypes: list[torch.dtype]
) -> bool:
    """True when the layout is on a GPU, has an allowed dtype, and the GPU is "big".

    A debug message is logged whenever the dtype is not allowed.
    """
    dtype_allowed = layout.dtype in allowed_layout_dtypes
    if not dtype_allowed:
        log.debug(
            "Not using template since dtype %s is not in allowed layout dtypes %s",
            layout.dtype,
            allowed_layout_dtypes,
        )
    return is_gpu(layout.device.type) and dtype_allowed and is_big_gpu(layout.device)
def _backend_listed(backend: str, backends: str) -> bool:
    """Case-insensitively test whether ``backend`` is in the comma-separated ``backends``."""
    wanted = backend.upper()
    return wanted in [entry.strip() for entry in backends.upper().split(",")]


def _use_autotune_backend(backend: str) -> bool:
    """Whether ``backend`` is enabled in ``config.max_autotune_gemm_backends``."""
    return _backend_listed(backend, config.max_autotune_gemm_backends)


def _use_conv_autotune_backend(backend: str) -> bool:
    """Whether ``backend`` is enabled in ``config.max_autotune_conv_backends``."""
    return _backend_listed(backend, config.max_autotune_conv_backends)
def use_triton_template(
    layout: Layout,
    *,
    enable_int32: bool = False,
    enable_float8: bool = False,
    check_max_autotune: bool = True,
) -> bool:
    """Whether a Triton template may be used for an output with ``layout``.

    Allowed dtypes are fp16/bf16/fp32, plus int32 when ``enable_int32`` and
    fp8 (e4m3fn, e5m2) when ``enable_float8``. The device must be a GPU that
    passes ``_use_template_for_gpu``, or a CPU with an allowed dtype.
    Max-autotune must be on (unless ``check_max_autotune`` is False because
    the caller checks it), the TRITON backend must be enabled, and the device
    backend must support Triton templates.
    """
    from .codegen.common import BackendFeature, has_backend_feature

    layout_dtypes = [torch.float16, torch.bfloat16, torch.float32]
    if enable_int32:
        layout_dtypes.append(torch.int32)
    if enable_float8:
        layout_dtypes.extend([torch.float8_e4m3fn, torch.float8_e5m2])

    device_supported = (
        is_gpu(layout.device.type) and _use_template_for_gpu(layout, layout_dtypes)
    ) or (layout.device.type == "cpu" and layout.dtype in layout_dtypes)
    return (
        device_supported
        # some callers handle max-autotune checking externally
        and (config.max_autotune or config.max_autotune_gemm or not check_max_autotune)
        and _use_autotune_backend("TRITON")
        and has_backend_feature(layout.device, BackendFeature.TRITON_TEMPLATES)
    )
def can_use_tma(*matrices: IRNode, add_guards: bool = False) -> bool:
    """
    Return True iff *all* supplied tensors satisfy the CUDA-12.9 TMA constraints
    that Triton relies on today.
    * https://docs.nvidia.com/cuda/cuda-driver-api/group__CUDA__TENSOR__MEMORY.html

    Also requires ``has_triton_tma_device()``.

    For non-XPU devices (or tensors without a device) a tensor is accepted when:
      * 2 ≤ rank ≤ 5
      * dtype ∈ {FP16, BF16, FP8-E4M3FN}
      * Every logical size ≥ 2
      * Base pointer 16-byte aligned (not listed in ``V.graph.unaligned_buffers``)
      * All "outer" dims have 16-byte aligned strides
      * The "inner" dim has stride 1 (contiguous), and its byte width is
        16-byte aligned
      * For FP8 tensors, inner dim ≥ 32

    For XPU tensors the only check is that exactly one dim has stride 1.

    Args:
        add_guards: if True, sizes/strides are specialized with
            ``guard_int_seq`` (which records guards). Otherwise symbolic hints
            are used and no guards are added. Applies to the non-XPU path only.
    """
    from torch.utils._triton import has_triton_tma_device

    from .virtualized import V

    def _aligned(expr_bytes: Union[int, sympy.Expr]) -> bool:
        # True only if the byte count is provably a multiple of TMA_ALIGNMENT.
        return V.graph.sizevars.statically_known_multiple_of(expr_bytes, TMA_ALIGNMENT)

    def _is_tma_compatible_default(x: IRNode) -> bool:
        sizes = x.get_size()
        strides = x.get_stride()
        rank = len(sizes)
        dtype = x.get_dtype()
        itemsize = dtype.itemsize

        # 2 ≤ rank ≤ 5
        if rank < 2 or rank > 5:
            return False

        # dtype ∈ {FP16, BF16, FP8-E4M3FN}
        if dtype not in (torch.float16, torch.bfloat16, torch.float8_e4m3fn):
            return False

        # Base pointer 16-byte aligned
        if x.get_name() in V.graph.unaligned_buffers:
            return False

        # The cheap rank/dtype/alignment checks above run before any guards
        # are installed here.
        if add_guards:
            sizes_i = V.graph.sizevars.guard_int_seq(sizes)
            strides_i = V.graph.sizevars.guard_int_seq(strides)
        else:
            sizes_i = [V.graph.sizevars.symbolic_hint(s) for s in sizes]
            strides_i = [V.graph.sizevars.symbolic_hint(st) for st in strides]

        # Every logical size ≥ 2
        if any(not V.graph.sizevars.statically_known_geq(s, 2) for s in sizes_i):
            return False

        # Find the single contiguous ("inner") dim
        inner = [
            i
            for i, st in enumerate(strides_i)
            if V.graph.sizevars.statically_known_equals(st, 1)
        ]
        if len(inner) != 1:
            return False
        inner_idx = inner[0]

        # All "outer" dims must have 16-byte aligned strides
        for i, st in enumerate(strides_i):
            if i == inner_idx:
                continue
            if not _aligned(st * itemsize):
                return False

        # Inner dim byte width must still be a multiple of 16 B
        inner_dim = sizes_i[inner_idx]
        if not _aligned(inner_dim * itemsize):
            return False

        # FP8 special case: inner ≥ 32
        if dtype == torch.float8_e4m3fn and not V.graph.sizevars.statically_known_geq(
            inner_dim, 32
        ):
            return False

        return True

    def _is_tma_compatible_xpu(x: IRNode) -> bool:
        # XPU path: only the single-contiguous-dim requirement is enforced,
        # always using symbolic hints (add_guards is not consulted).
        strides = x.get_stride()
        strides_i = [V.graph.sizevars.symbolic_hint(st) for st in strides]
        # Find the single contiguous ("inner") dim
        inner = [
            i
            for i, st in enumerate(strides_i)
            if V.graph.sizevars.statically_known_equals(st, 1)
        ]
        if len(inner) != 1:
            return False
        return True

    # Tensors whose device is None fall through to the default (non-XPU) checks.
    return has_triton_tma_device() and all(
        _is_tma_compatible_default(m)
        if (m_device := m.get_device()) is None or m_device.type != "xpu"
        else _is_tma_compatible_xpu(m)
        for m in matrices
    )
def use_triton_tma_template(*matrices: IRNode, add_guards: bool = False) -> bool:
    """
    Return True if the persistent-TMA Triton matmul template may be used.

    Three conditions must all hold:
    * ``config.triton.enable_persistent_tma_matmul`` is set;
    * every matrix is 2-D;
    * every matrix passes ``can_use_tma``.

    The config knob and the rank check are evaluated first because they are
    cheap and side-effect free. ``can_use_tma(..., add_guards=True)``
    concretizes sizes/strides via guards, and there is no point specializing
    shapes when the template is disabled or inapplicable anyway.
    """
    return (
        config.triton.enable_persistent_tma_matmul
        and all(len(m.get_size()) == 2 for m in matrices)
        and can_use_tma(*matrices, add_guards=add_guards)
    )
def use_cutlass_template(layout: Layout, m: int, n: int, k: int) -> bool:
    """
    Decide whether CUTLASS GEMM templates should be considered for an
    ``m x n x k`` GEMM producing ``layout``.

    Rejected when the hinted problem size is unknown/non-positive or below
    ``config.cuda.cutlass_backend_min_gemm_size``, on ROCm builds, when the
    GPU layout/dtype is unsupported, when max-autotune is off, when "CUTLASS"
    is not an enabled autotune backend, or when the CUTLASS library cannot be
    imported (a warning is logged in that last case).
    """
    from .virtualized import V

    # size_hint falls back to -1 when the size cannot be hinted.
    problem_size = V.graph.sizevars.size_hint(m * n * k, fallback=-1)
    too_small = (
        problem_size <= 0
        or problem_size < config.cuda.cutlass_backend_min_gemm_size
    )
    if too_small:
        return False

    from .codegen.cuda.cutlass_utils import try_import_cutlass

    # CUTLASS is never used on ROCm.
    if torch.version.hip:
        return False

    # Supported output dtypes.
    # FP32 not supported: https://github.com/pytorch/pytorch/issues/145952
    supported_dtypes = [torch.float16, torch.bfloat16, torch.int32]
    eligible = (
        _use_template_for_gpu(layout, supported_dtypes)
        and (config.max_autotune or config.max_autotune_gemm)
        and _use_autotune_backend("CUTLASS")
    )
    if not eligible:
        return eligible
    if try_import_cutlass():
        return eligible

    log.warning(
        "Failed to import CUTLASS lib. Please check whether "
        "_inductor.config.cuda.cutlass_dir %s is set correctly. "
        "Skipping CUTLASS backend for now.",
        config.cuda.cutlass_dir,
    )
    return False
def _use_cutlass_for_op(op_name: str) -> bool:
    """
    Check if CUTLASS should be used for the given operation.

    ``config.cuda.cutlass_enabled_ops`` is a comma-separated, case-insensitive
    list of op names. The special entry ``"ALL"`` enables every op.

    Whitespace around each entry is ignored consistently, including for the
    ``"ALL"`` entry. Previously only an exact ``"ALL"`` string worked, so
    ``" all "`` or ``"mm, ALL"`` silently enabled nothing extra.
    """
    enabled = {
        entry.strip() for entry in config.cuda.cutlass_enabled_ops.upper().split(",")
    }
    return "ALL" in enabled or op_name.upper() in enabled
# Integer-like size value: a concrete Python int or a sympy expression (which
# may contain free symbols when shapes are dynamic).
_IntLike: TypeAlias = Union[int, sympy.Expr]
@functools.cache
def use_decompose_k_choice(m: _IntLike, n: _IntLike, k: _IntLike) -> bool:
    """
    Return True if the decompose-K GEMM choice should be offered.

    The choice is offered only when all of these hold:
    * this is not a ROCm build;
    * K is provably at least ``config.triton.decompose_k_threshold`` times
      both M and N;
    * neither AOT mode nor the C++ wrapper is in use.
    """
    from torch._inductor.virtualized import V

    threshold = config.triton.decompose_k_threshold
    if torch.version.hip:
        return False

    k_dominates = sympy.And(
        sympy.Ge(k, threshold * m),
        sympy.Ge(k, threshold * n),
    )
    if not V.graph.sizevars.statically_known_true(k_dominates):
        return False

    # TODO: Support AOTI for decomposeK
    return not V.graph.aot_mode and not V.graph.cpp_wrapper
@functools.cache
def use_contiguous(m: _IntLike, n: _IntLike, k: _IntLike) -> bool:
    """
    Check if we should use the contiguous subgraph transform.

    This transform makes the second matrix contiguous before the matmul.
    It uses the same shape test as decompose-K, with
    ``config.rocm.contiguous_threshold`` as the threshold: K must provably be
    at least that many times both M and N.

    It is only considered on ROCm (AMD) builds, and never in AOT mode or with
    the C++ wrapper.
    """
    threshold = config.rocm.contiguous_threshold

    from torch._inductor.virtualized import V

    # Only relevant on AMD.
    if not torch.version.hip:
        return False

    k_dominates = sympy.And(
        sympy.Ge(k, threshold * m),
        sympy.Ge(k, threshold * n),
    )
    if not V.graph.sizevars.statically_known_true(k_dominates):
        return False

    return not V.graph.aot_mode and not V.graph.cpp_wrapper
@functools.cache
def get_k_splits(m: _IntLike, n: _IntLike, k: _IntLike) -> list[int]:
    """
    Return candidate split counts ``d`` for the decompose-K GEMM transform.

    Rules for the returned list:
    * Each ``d`` divides ``k``, lies in ``[2, max_k_split]``, and leaves a
      per-split extent ``k // d`` (kPart) of at least 128.
    * ``max_k_split`` is ``min(k // m, k // n)`` when M and N are concrete,
      else 256.
    * Candidates are grouped by kPart: powers of two first, then multiples of
      32, then the rest. Each group is in ascending ``d``.
    * Unless ``config.max_autotune_gemm_search_space == "EXHAUSTIVE"``, the
      list is truncated to ``config.triton.num_decompose_k_splits`` entries.

    Special cases:
    * A symbolic ``k`` returns a fixed hand-tuned list.
    * A limit of 0 (with concrete ``k``) returns ``[]``.
    * A concrete zero ``m`` or ``n`` returns ``[]``, since the GEMM is empty
      and ``k // 0`` would raise ZeroDivisionError.
    """
    # To limit compile time
    k_splits_limit = config.triton.num_decompose_k_splits

    # Hand-tuned
    default_k_splits = [16, 32, 64, 128, 256]
    # If k is a sympy expression, we can't do any splitting
    if isinstance(k, sympy.Expr) and not k.is_number:
        return default_k_splits
    elif k_splits_limit == 0:
        return []

    if (isinstance(m, sympy.Expr) and not m.is_number) or (
        isinstance(n, sympy.Expr) and not n.is_number
    ):
        # Without concrete M/N we cannot bound the split count by K/M, K/N.
        max_k_split = 256
    elif m == 0 or n == 0:
        # Empty GEMM: nothing to split (and k // 0 would raise).
        return []
    else:
        # Keep each K piece (k // d) roughly at least as large as M and N.
        max_k_split = min(k // m, k // n)

    min_k_split = 2
    # Get all divisors of k, k has to be divisible by kPart.
    # sympy.divisors returns them in ascending order.
    divisors = [d for d in sympy.divisors(k) if min_k_split <= d <= max_k_split]

    pow_of_2_divisors, mul_of_32_divisors, rest_of_splits = [], [], []
    for d in divisors:
        k_part = k // d

        # Smaller than 128 might not even fit in a single tile, BLOCK_K can be 128
        if k_part < 128:
            continue

        # Power of 2 divisors are best performing, conform to hardware
        if (k_part & (k_part - 1)) == 0:
            pow_of_2_divisors.append(d)
        # Else check if creates a multiple of 32
        elif k_part % 32 == 0:
            mul_of_32_divisors.append(d)
        # otherwise, take the smallest values
        else:
            rest_of_splits.append(d)

    best_splits = pow_of_2_divisors + mul_of_32_divisors + rest_of_splits
    if config.max_autotune_gemm_search_space == "EXHAUSTIVE":
        return best_splits

    # Otherwise, conform results to k_splits_limit
    return best_splits[:k_splits_limit]
@functools.cache
def _rocm_native_device_arch_name(device: Union[str, torch.device]) -> str:
    """
    Return the ``gcnArchName`` reported by
    ``torch.cuda.get_device_properties`` for ``device``.

    The result is cached per device. ``use_ck_template`` passes a
    ``torch.device`` (``layout.device``) here, so the annotation accepts that
    as well as a string.
    """
    return torch.cuda.get_device_properties(device).gcnArchName
@functools.cache
def try_import_ck_lib() -> tuple[
    Optional[str], Callable[[], list[Any]], Callable[[], list[Any]], type[Any]
]:
    """
    Try to import the optional Composable Kernel package ``ck4inductor``.

    Returns:
        ``(package_dir, gen_ops_library, gen_ops_preselected, CKGemmOperation)``.
        On ImportError, ``package_dir`` is None, both generator functions
        return empty lists, and ``CKGemmOperation`` is an empty placeholder
        class, so callers can unpack the result unconditionally. The result is
        cached.
    """
    try:
        import ck4inductor  # type: ignore[import]
        from ck4inductor.universal_gemm.gen_instances import (  # type: ignore[import]
            gen_ops_library,
            gen_ops_preselected,
        )
        from ck4inductor.universal_gemm.op import (  # type: ignore[import]
            CKGemmOperation,
        )
    except ImportError:
        # Stand-ins with the same names and call signatures as the real ones.
        def gen_ops_library() -> list[Any]:
            return []

        def gen_ops_preselected() -> list[Any]:
            return []

        class CKGemmOperation:  # type: ignore[no-redef]
            pass

        return None, gen_ops_library, gen_ops_preselected, CKGemmOperation

    return (
        os.path.dirname(ck4inductor.__file__),
        gen_ops_library,
        gen_ops_preselected,
        CKGemmOperation,
    )
def use_ck_template(layout: Layout) -> bool:
    """
    Return True if Composable Kernel (CK) templates can be used for ``layout``.

    Checks, in order:
    * max-autotune (or max-autotune-gemm) is on;
    * this is a ROCm build;
    * the layout is on a "cuda" device;
    * a requested arch is in ``config.rocm.ck_supported_arch``. Requested
      archs come from ``config.rocm.arch``, or from the device's native arch
      when that list is empty.
    * the dtype is fp16/bf16/fp32;
    * the ``ck4inductor`` package is importable, ``config.rocm.ck_dir`` is
      set, and the two locations match.

    In fbcode, ``config.rocm.ck_dir`` is first overwritten with the installed
    package location.
    """
    # config knobs check 1
    if not (config.max_autotune or config.max_autotune_gemm):
        return False
    # platform check
    if not torch.version.hip:
        return False
    # tensors must be on GPU
    if layout.device.type != "cuda":
        return False

    # Hardware check. Archs are compared by base name (the text before the
    # first ":"). If the config arch list is not specified, use the native
    # arch from the device properties.
    native_arch = _rocm_native_device_arch_name(layout.device)
    requested = {arch.split(":")[0]: arch for arch in config.rocm.arch}
    if not requested:
        requested = {native_arch.split(":")[0]: native_arch}
    if not requested.keys() & config.rocm.ck_supported_arch:
        return False

    # supported input dtypes
    if layout.dtype not in (torch.float16, torch.bfloat16, torch.float32):
        return False

    ck_package_dirname = try_import_ck_lib()[0]
    if not ck_package_dirname:
        log.warning("Please pip install Composable Kernel package")
        return False

    if config.is_fbcode():
        config.rocm.ck_dir = ck_package_dirname

    if not config.rocm.ck_dir:
        log.warning("Please set TORCHINDUCTOR_CK_DIR env variable")
        return False

    if ck_package_dirname != config.rocm.ck_dir:
        log.warning("Invalid path to CK library")
        return False

    return True
def use_ck_gemm_template(layout: Layout, m: int, n: int, k: int) -> bool:
    """
    Return True if the CK GEMM template should be tried.

    Requires "CK" to be an enabled autotune backend, CK templates to be usable
    for ``layout``, and a positive hinted problem size ``m * n * k``.
    """
    from .virtualized import V

    if not _use_autotune_backend("CK"):
        return False
    if not use_ck_template(layout):
        return False
    # size_hint falls back to -1 when the size cannot be hinted.
    return V.graph.sizevars.size_hint(m * n * k, fallback=-1) > 0
def use_ck_tile_gemm_template(layout: Layout, m: int, n: int, k: int) -> bool:
    """
    Return True if the CK-tile GEMM template should be tried.

    Requires "CKTILE" to be an enabled autotune backend, CK templates to be
    usable for ``layout``, and a positive hinted problem size ``m * n * k``.
    """
    from .virtualized import V

    if not _use_autotune_backend("CKTILE"):
        return False
    if not use_ck_template(layout):
        return False
    # size_hint falls back to -1 when the size cannot be hinted.
    return V.graph.sizevars.size_hint(m * n * k, fallback=-1) > 0
def use_ck_conv_template(layout: Layout) -> bool:
    """Return True if "CK" is an enabled conv autotune backend and CK templates fit ``layout``."""
    if not _use_conv_autotune_backend("CK"):
        return False
    return use_ck_template(layout)
def _use_template_for_cpu(layout: Layout) -> bool:
    """Return True when max-autotune (or max-autotune-gemm) is on and ``layout`` is on CPU."""
    autotune_on = config.max_autotune or config.max_autotune_gemm
    return autotune_on and layout.device.type == "cpu"
def use_cpp_bmm_template(
    layout: Layout, mat1: Union[ReinterpretView, Buffer], mat2: IRNode
) -> bool:
    """
    Return True if the C++ BMM template can be used.

    Same requirements as ``use_cpp_gemm_template`` except that ``mat2`` need
    not be a constant (module buffer). Additionally, ``mat1``'s layout must be
    contiguous.
    """
    from .ir import Layout

    assert isinstance(mat1.layout, Layout)
    if not use_cpp_gemm_template(layout, mat1, mat2, require_constant_mat2=False):
        return False
    return mat1.layout.is_contiguous()
def use_cpp_gemm_template(
    layout: Layout,
    mat1: IRNode,
    mat2: IRNode,
    mat2_transposed: bool = False,
    require_constant_mat2: bool = True,
    is_woq_int4: bool = False,
    q_group_size: Optional[int] = None,
) -> bool:
    """
    Return True if the C++ (CPU) GEMM template can be used for ``mat1 @ mat2``.

    Requirements checked here:
    * max-autotune is on, the layout is on CPU, and "CPP" is an enabled
      autotune backend;
    * ``config.cpp.weight_prepack`` is enabled;
    * N and K are static (no free symbols); M may be dynamic;
    * ``create_micro_gemm`` returns a kernel for these dtypes/shapes;
    * the (normalized) output layout dtype is fp32, bf16, fp16 or uint8;
    * ``mat1``'s last dim has stride 1 (this freezes ``mat1``'s layout);
    * ``mat2`` (after unwrapping a view) is an ``ir.StorageBox`` and, when
      ``require_constant_mat2`` is True, ``is_module_buffer()``.

    Args:
        layout: requested output layout. For int8/uint8 ``mat1`` its dtype is
            passed to ``mm_args`` as ``out_dtype``.
        mat1, mat2: GEMM operands.
        mat2_transposed: forwarded to ``mm_args``.
        require_constant_mat2: require ``mat2`` to be a module buffer.
        is_woq_int4: forwarded to ``mm_args`` as ``use_4x2_dim``. Also selects
            the non-reference micro-GEMM (``use_ref=not is_woq_int4``).
        q_group_size: forwarded to ``create_micro_gemm``.
    """
    from . import ir
    from .codegen.cpp_micro_gemm import create_micro_gemm
    from .codegen.cpp_utils import get_gemm_template_output_and_compute_dtype
    from .kernel.mm_common import mm_args

    if not _use_template_for_cpu(layout) or not _use_autotune_backend("CPP"):
        return False

    if not config.cpp.weight_prepack:
        return False

    int8_gemm = mat1.get_dtype() in [torch.uint8, torch.int8]
    # Output-layout dtypes accepted by the final check below.
    layout_dtypes = [torch.float32, torch.bfloat16, torch.half, torch.uint8]
    # Normalize operands. Note that layout, mat1 and mat2 are rebound here.
    # For non-int8 inputs out_dtype=None is passed.
    m, n, k, layout, mat1, mat2 = mm_args(
        mat1,
        mat2,
        out_dtype=layout.dtype if int8_gemm else None,
        mat2_transposed=mat2_transposed,
        use_4x2_dim=is_woq_int4,
    )

    # TODO(jgong5): support dynamic shapes for n or k
    if has_free_symbols((n, k)):
        return False

    # Look through a view so the StorageBox / module-buffer checks below see
    # the underlying storage.
    if isinstance(mat2, ir.BaseView):
        mat2 = mat2.unwrap_view()

    output_dtype, _ = get_gemm_template_output_and_compute_dtype(mat1.get_dtype())
    # May return None (checked below), presumably when no micro-kernel supports
    # this configuration.
    micro_gemm = create_micro_gemm(
        "micro_gemm",
        m,
        n,
        k,
        input_dtype=mat1.get_dtype(),
        input2_dtype=mat2.get_dtype(),
        output_dtype=output_dtype,
        num_threads=parallel_num_threads(),
        use_ref=not is_woq_int4,
        q_group_size=q_group_size,
    )

    def is_last_dim_stride1(x: IRNode) -> bool:
        # Side effect: freezes x's layout (presumably so the stride checked
        # here cannot change later).
        x.freeze_layout()
        return x.get_stride()[-1] == 1

    # Order matters: is_last_dim_stride1 (and its freeze_layout side effect)
    # only runs when the dtype and micro-gemm checks pass.
    return (
        layout.dtype in layout_dtypes
        and micro_gemm is not None
        and is_last_dim_stride1(mat1)  # TODO(jgong5): support transposed input
        and isinstance(mat2, ir.StorageBox)
        and (mat2.is_module_buffer() or not require_constant_mat2)
    )
def use_aten_gemm_kernels() -> bool:
    """
    Return True if ATen GEMM kernels should be used or considered.

    They are always used when max-autotune is off. Otherwise they are
    considered only if "ATEN" is an enabled autotune backend.
    """
    if config.max_autotune or config.max_autotune_gemm:
        return _use_autotune_backend("ATEN")
    return True
class DebugDirManager:
    """
    Context manager that temporarily redirects
    ``torch._dynamo.config.debug_dir_root`` to a uniquely suffixed path
    (``<previous>_tmp_<id>``).

    On exit the original setting is always restored, and the temporary
    directory is removed if anything created it.
    """

    # Process-wide counter so each instance gets a distinct suffix.
    counter = itertools.count(0)
    prev_debug_name: str

    def __init__(self) -> None:
        self.id = next(DebugDirManager.counter)

    def __enter__(self) -> None:
        self.prev_debug_name = torch._dynamo.config.debug_dir_root
        self.new_name = f"{self.prev_debug_name}_tmp_{self.id}"
        torch._dynamo.config.debug_dir_root = self.new_name

    def __exit__(self, *args: Any) -> None:
        try:
            # The directory only exists if something wrote debug output into
            # it. An unconditional rmtree used to raise FileNotFoundError here,
            # which also skipped the config restore below.
            with contextlib.suppress(FileNotFoundError):
                shutil.rmtree(self.new_name)
        finally:
            torch._dynamo.config.debug_dir_root = self.prev_debug_name
def run_and_get_code(
    fn: Callable[P, _T],
    *args: P.args,
    **kwargs: P.kwargs,
) -> tuple[_T, list[str]]:
    """
    Reset dynamo, run ``fn(*args, **kwargs)``, and return its result plus
    every source string passed to ``GraphLowering.save_output_code`` while it
    ran.
    """
    from .graph import GraphLowering

    captured: list[str] = []

    def _capture(code: str) -> None:
        captured.append(code)

    with mock.patch.object(GraphLowering, "save_output_code", _capture):
        torch._dynamo.reset()
        result = fn(*args, **kwargs)
    return result, captured
def run_and_get_kernels(
    fn: Callable[P, _T], *args: P.args, **kwargs: P.kwargs
) -> tuple[_T, list[str]]:
    """
    Like ``run_and_get_code``, but return every ``'''...'''`` string literal
    found in the captured output code instead of the whole sources.
    """
    result, source_codes = run_and_get_code(fn, *args, **kwargs)
    # Non-greedy match across newlines, so each literal is returned separately.
    kernels = [
        literal
        for code in source_codes
        for literal in re.findall(r"'''.*?'''", code, re.DOTALL)
    ]
    return result, kernels
def run_fw_bw_and_get_code(fn: Callable[..., Any]) -> tuple[Any, list[str]]:
    """
    Run ``fn()``, backpropagate through ``result.sum()``, and return the
    forward result together with all captured output code.
    """

    def _forward_then_backward() -> Any:
        out = fn()
        out.sum().backward()
        return out

    return run_and_get_code(_forward_then_backward)
def get_code(fn: Callable[P, _T], *args: P.args, **kwargs: P.kwargs) -> list[str]:
    """
    Get the inductor-generated code, but skip any actual compilation or running.

    ``GraphLowering.compile_to_module`` is patched to only run codegen, record
    the wrapper (and kernel, if any) source, and return a module whose
    ``call`` does nothing.
    """
    from .graph import GraphLowering

    collected: list[str] = []

    def _record(code: str) -> None:
        collected.append(code)

    def _codegen_only(self: GraphLowering) -> Any:
        class DummyModule:
            """This is empty to replace the generated triton module"""

            def __init__(self) -> None:
                pass

            def call(self, *args: Any, **kwargs: Any) -> None:
                # Don't do anything when called
                pass

        if self.cpp_wrapper:
            wrapper_code, kernel_code = self.codegen_with_cpp_wrapper()
        else:
            wrapper_code, kernel_code = self.codegen()
        # Skip all the actual compiling; just keep the generated source.
        _record(wrapper_code.value)
        if kernel_code:
            _record(kernel_code.value)
        return DummyModule()

    with (
        mock.patch.object(GraphLowering, "compile_to_module", _codegen_only),
        mock.patch.object(GraphLowering, "save_output_code", _record),
    ):
        torch._dynamo.reset()
        # DummyModule.call returns None, so fn's result is meaningless here.
        fn(*args, **kwargs)
    return collected
def get_triton_code(fn: Callable[P, _T], *args: P.args, **kwargs: P.kwargs) -> str:
    """Return the first code output generated for ``fn`` without compiling or running it."""
    source_codes = get_code(fn, *args, **kwargs)
    num_outputs = len(source_codes)
    # Can have two outputs if backwards was eagerly compiled
    assert 1 <= num_outputs <= 2, (
        f"expected one or two code outputs got {num_outputs}"
    )
    return source_codes[0]
def run_and_get_triton_code(
    fn: Callable[P, _T], *args: P.args, **kwargs: P.kwargs
) -> str:
    """Run ``fn`` and return the first code output inductor produced for it."""
    source_codes = run_and_get_code(fn, *args, **kwargs)[1]
    num_outputs = len(source_codes)
    # Can have two outputs if backwards was eagerly compiled
    assert 1 <= num_outputs <= 2, (
        f"expected one or two code outputs got {num_outputs}"
    )
    return source_codes[0]
- def run_and_get_graph_lowering(
- fn: Callable[P, _T], *args: P.args, **kwargs: P.kwargs
- ) -> tuple[Any, list[GraphLowering]]:
- from torch._inductor.graph import GraphLowering
- from torch._inductor.output_code import CompiledFxGraph
- real_init = CompiledFxGraph.__init__
- graph_lowerings = []
- def fake_init(*args: Any, **kwargs: Any) -> None:
- real_init(*args, **kwargs)
- graph = args[2]
- assert isinstance(graph, GraphLowering)
- graph_lowerings.append(graph)
- with mock.patch.object(CompiledFxGraph, "__init__", fake_init):
- result = fn(*args, **kwargs)
- return result, graph_lowerings
@contextlib.contextmanager
def override_lowering(
    aten_op: Callable[..., Any], override_fn: Callable[..., Any]
) -> Iterator[None]:
    """
    Temporarily replace the inductor lowering registered for ``aten_op``.

    While the context is active the lowering is
    ``functools.partial(override_fn, <original lowering>)``, i.e. the original
    lowering fn is passed as the first argument of ``override_fn``. The
    original entry is restored on exit, even on error.
    """
    from torch._inductor import lowering

    saved_lowering = lowering.lowerings[aten_op]
    try:
        lowering.lowerings[aten_op] = functools.partial(override_fn, saved_lowering)
        yield
    finally:
        lowering.lowerings[aten_op] = saved_lowering
def add_scheduler_init_hook(
    pre_fn: Callable[..., Any], post_fn: Optional[Callable[..., Any]] = None
) -> Any:
    """
    Build an (unstarted) ``unittest.mock`` patch that wraps ``Scheduler.__init__``.

    ``pre_fn(scheduler, nodes)`` runs before the original ``__init__`` and
    ``post_fn(scheduler, nodes)``, if given, after it. The original
    ``__init__`` is captured when this function is called. Used for unit
    tests.
    """
    from torch._inductor.scheduler import Scheduler

    real_init = Scheduler.__init__

    def hooked_init(scheduler: Any, nodes: Any) -> Any:
        pre_fn(scheduler, nodes)
        ret = real_init(scheduler, nodes)
        if post_fn:
            post_fn(scheduler, nodes)
        return ret

    return unittest.mock.patch.object(Scheduler, "__init__", hooked_init)
def developer_warning(msg: str) -> None:
    """
    Log a message that is actionable for PyTorch developers but not end users.

    It is emitted at WARNING level when ``config.developer_warnings`` is set
    and at INFO level otherwise. This lets the warnings stay on for nightly
    builds but be quiet in stable releases.
    """
    emit = log.warning if config.developer_warnings else log.info
    emit(msg)
def get_benchmark_name() -> Optional[str]:
    """
    Return the benchmark name given with ``--only`` on the command line.

    An experimental API used only when config.benchmark_kernel is true.
    The benchmark name is only available at codegen time, so it cannot be
    looked up directly from benchmark_all_kernels, which runs after codegen.

    The argument after ``--only`` is assumed to be the benchmark name. This
    works for torchbench.py/huggingface.py/timm_models.py; for ad-hoc scripts
    this function may return None.

    There are 2 flavors of the --only argument we need to handle:
    1. --only model_name
    2. --only=model_name
    """
    argv = sys.argv

    # Flavor 1: "--only model_name"; the next token must be a non-empty,
    # non-option string.
    if "--only" in argv:
        next_pos = argv.index("--only") + 1
        if next_pos < len(argv):
            candidate = argv[next_pos]
            if candidate and not candidate.startswith("-"):
                return candidate

    # Flavor 2: "--only=model_name"
    prefix = "--only="
    for arg in argv:
        if arg.startswith(prefix):
            return arg[len(prefix) :]
    return None
def is_ones(items: Sequence[Any]) -> bool:
    """Return True if every element compares equal to 1 (True for an empty sequence)."""
    for item in items:
        if not item == 1:
            return False
    return True
def is_zeros(items: Sequence[Any]) -> bool:
    """Return True if every element compares equal to 0 (True for an empty sequence)."""
    for item in items:
        if not item == 0:
            return False
    return True
def is_cpu_device(inputs: Sequence[torch.Tensor]) -> bool:
    """Return True if every ``torch.Tensor`` in ``inputs`` is on the CPU; non-tensors are ignored."""
    cpu = torch.device("cpu")
    tensors = (item for item in inputs if isinstance(item, torch.Tensor))
    return all(t.device == cpu for t in tensors)
def get_sympy_Expr_dtype(val: sympy.Expr) -> torch.dtype:
    """
    Map a sympy expression to a torch dtype.

    Returns ``torch.int64`` when sympy knows the expression is an integer,
    otherwise ``torch.float64``.
    """
    assert isinstance(val, sympy.Expr), (
        "only support sympy.Expr as input to get_sympy_Expr_dtype"
    )
    known_integer = val.is_integer  # type: ignore[attr-defined]
    return torch.int64 if known_integer else torch.float64
@contextlib.contextmanager
def maybe_profile(should_profile: bool, *args: Any, **kwargs: Any) -> Iterator[Any]:
    """
    Optionally wrap the block in ``torch.profiler.profile(*args, **kwargs)``.

    Yields the profiler when ``should_profile`` is truthy, otherwise yields
    None without profiling.
    """
    if not should_profile:
        yield
        return
    with torch.profiler.profile(*args, **kwargs) as prof:
        yield prof
def parallel_num_threads() -> int:
    """Return ``config.cpp.threads`` if it is at least 1, else ``torch.get_num_threads()``."""
    configured = config.cpp.threads
    return configured if configured >= 1 else torch.get_num_threads()
@functools.cache
def get_backend_num_stages() -> int:
    """
    Return the backend's ``num_stages`` option (cached).

    If the backend options do not provide it, default to 2 on ROCm and 3
    elsewhere.
    """
    from .runtime.triton_helpers import get_backend_options

    backend_options = get_backend_options()
    fallback_stages = 2 if torch.version.hip else 3
    return backend_options.get("num_stages", fallback_stages)
@functools.cache
def get_device_tflops(dtype: torch.dtype) -> float:
    """
    Return peak TFLOPS for ``dtype`` on the current device (cached per dtype).

    We don't want to throw errors in this function. First check to see if the
    device is in device_info.py (``datasheet_tops``), then fall back to the
    inaccurate triton estimation:
    * fp16/bf16 on SM80+ use the tensor-core estimate for that dtype;
    * otherwise the fp32 tensor-core estimate if TF32 matmul is allowed, else
      the fp32 SIMD estimate.
    """
    ds_tops = datasheet_tops(dtype, is_tf32=torch.backends.cuda.matmul.allow_tf32)
    if ds_tops is not None:
        return ds_tops

    from triton.testing import get_max_simd_tflops, get_max_tensorcore_tflops

    SM80OrLater = torch.cuda.is_available() and torch.cuda.get_device_capability() >= (
        8,
        0,
    )

    assert dtype in (torch.float16, torch.bfloat16, torch.float32)

    # Triton API change in https://github.com/triton-lang/triton/pull/2293:
    # newer versions take an explicit clock rate argument.
    if inspect.signature(get_max_simd_tflops).parameters.get("clock_rate"):
        from torch._utils_internal import max_clock_rate

        clock_args: tuple[Any, ...] = (max_clock_rate(),)
    else:
        clock_args = ()

    if dtype in (torch.float16, torch.bfloat16) and SM80OrLater:
        return get_max_tensorcore_tflops(dtype, *clock_args)
    if torch.backends.cuda.matmul.allow_tf32:
        return get_max_tensorcore_tflops(torch.float32, *clock_args)
    return get_max_simd_tflops(torch.float32, *clock_args)
@functools.cache
def get_gpu_dram_gbps() -> int:
    """Return Triton's ``get_dram_gbps()`` estimate (presumably GB/s), cached."""
    from triton.testing import get_dram_gbps as _triton_dram_gbps

    return _triton_dram_gbps()
def get_gpu_shared_memory() -> int:
    """Return ``max_shared_mem`` for device 0 from Triton's active driver, or 0 if absent."""
    from triton.runtime import driver

    device_props = driver.active.utils.get_device_properties(0)
    return device_props.get("max_shared_mem", 0)
def is_welford_reduction(reduction_type: str) -> bool:
    """Return True for any reduction type whose name begins with ``"welford"``."""
    prefix = "welford"
    return reduction_type[: len(prefix)] == prefix
def reduction_num_outputs(reduction_type: str) -> int:
    """
    Return how many outputs a reduction of ``reduction_type`` produces.

    Welford reductions produce 3, ``online_softmax_reduce`` produces 2, and
    every other reduction produces 1.
    """
    if is_welford_reduction(reduction_type):
        return 3
    return 2 if reduction_type == "online_softmax_reduce" else 1
def is_linux() -> bool:
    """Return True when ``platform.system()`` reports Linux."""
    system_name = platform.system()
    return system_name == "Linux"
def is_windows() -> bool:
    """Return True when ``sys.platform`` is ``"win32"``."""
    current_platform = sys.platform
    return current_platform == "win32"
def has_free_symbols(itr: Iterable[Any]) -> bool:
    """Return True if any element is a sympy expression that is not a plain number."""
    for item in itr:
        if isinstance(item, sympy.Expr) and not item.is_number:
            return True
    return False
def is_dynamic(*args: Any) -> bool:
    """
    Return True if any buffer-like IR argument has a symbolic size or stride.

    How each argument is treated:
    * TensorBox/StorageBox/BaseView/ComputedBuffer/Buffer: size and stride
      are checked for free symbols.
    * Arguments that are not IR nodes: ignored.
    * Any other IR node type: raises TypeError.
    """
    from . import ir

    shaped_types = (
        ir.TensorBox,
        ir.StorageBox,
        ir.BaseView,
        ir.ComputedBuffer,
        ir.Buffer,
    )
    for arg in args:
        if isinstance(arg, shaped_types):
            if has_free_symbols(arg.maybe_get_size() or ()):
                return True
            if has_free_symbols(arg.maybe_get_stride() or ()):
                return True
        elif isinstance(arg, ir.IRNode):
            raise TypeError(f"unexpected type for is_dynamic {type(arg)}")
    return False
class Placeholder(enum.Enum):
    """Placeholder strings used in triton codegen."""

    # Stands for the actual name of a triton kernel,
    # e.g. "triton_" in "def triton_".
    KERNEL_NAME = "KERNEL_NAME"

    # The descriptive name of the triton kernel. When unique_kernel_names is
    # False, this placeholder is replaced with a more informative string.
    DESCRIPTIVE_NAME = "DESCRIPTIVE_NAME"
def pass_execution_and_save(
    func: Callable[..., Any], gm: GraphModule, inp: Sequence[Any], msg: str
) -> None:
    """
    Run graph pass ``func`` on ``gm.graph`` and dump the before/after graphs
    to a temporary file for debugging.

    Steps:
    1. ``ShapeProp`` is run on ``gm`` with ``inp``, under the fake mode
       detected from ``inp``.
    2. ``func(gm.graph)`` runs inside a ``GraphTransformObserver``.
    3. The graph is stable-topologically sorted, linted and recompiled.

    The dump file is created with ``delete=False``, so it remains after
    return. Its path, whether the printed graph text is unchanged, and the
    time spent in ``func`` alone are logged at INFO level.

    Args:
        func: pass to run; its return value is ignored, so it is expected to
            mutate the graph in place.
        gm: graph module being transformed (modified in place).
        inp: example inputs used for shape propagation.
        msg: label for the GraphTransformObserver and the log line.
    """
    from .pattern_matcher import stable_topological_sort

    # delete=False keeps the dump on disk after this function returns.
    with tempfile.NamedTemporaryFile(
        mode="w",
        encoding="utf-8",
        delete=False,
    ) as f:
        before_io = io.StringIO()
        after_io = io.StringIO()
        # Presumably records tensor metadata on the graph's nodes before the pass.
        ShapeProp(gm=gm, fake_mode=detect_fake_mode(inp)).propagate(*inp)
        print(f"Before:\n{gm.graph}", file=f)
        print(gm.graph, file=before_io)
        # Only the pass itself is timed (not the sort/lint/recompile below).
        start_time = datetime.now()
        with GraphTransformObserver(gm, msg):
            func(gm.graph)
        time_elapsed = datetime.now() - start_time
        # recompile graph
        stable_topological_sort(gm.graph)
        gm.graph.lint()
        gm.recompile()

        print(f"After:\n{gm.graph}", file=f)
        print(gm.graph, file=after_io)
        # True when the printed graph text is identical before and after.
        t = before_io.getvalue() == after_io.getvalue()
        log.info(
            "%s, save before/after graph to %s, graph before/after are the same = %s, time elapsed = %s",
            msg,
            f.name,
            t,
            time_elapsed,
        )
def is_multi_outputs_template(input_buf: Optional[Union[Buffer, Operation]]) -> bool:
    """
    Check if input buffer is a multi-outputs template buffer, i.e. an
    ``ir.CppTemplateBuffer`` whose layout is an ``ir.MultiOutputLayout``.
    """
    from . import ir

    if not isinstance(input_buf, ir.CppTemplateBuffer):
        return False
    return isinstance(input_buf.layout, ir.MultiOutputLayout)
def is_output_of_multi_outputs_template(
    input_buf: Optional[Union[Buffer, Operation]],
) -> bool:
    """
    Check if input buffer is an output of a multi-outputs template buffer.

    That is, it is an ``ir.MultiOutput`` with exactly one input, and that
    input satisfies ``is_multi_outputs_template``.
    """
    from . import ir

    if not isinstance(input_buf, ir.MultiOutput):
        return False
    if len(input_buf.inputs) != 1:
        return False
    return is_multi_outputs_template(input_buf.inputs[0])  # type: ignore[arg-type]
def is_collective(
    node: Optional[Union[Node, Operation]],
    op: Optional[torch._ops.OperatorBase] = None,
) -> bool:
    """
    Return True if ``node`` is a collective communication op.

    ``node`` matches in either of two ways:
    * It is an ``ir._CollectiveKernel`` that is not an ``ir._WaitKernel``.
      When ``op`` is given, its ``op_overload`` must additionally be ``op``
      (identity comparison).
    * Its exact type is ``ir.FallbackKernel`` (subclasses excluded) and its
      ``op_overload`` is the default overload of torchrec's
      ``all_to_all_single``, ``all_gather_into_tensor`` or
      ``reduce_scatter_tensor``. ``op`` is not consulted for this branch.

    Returns False for ``None``.
    """
    if node is None:
        return False

    from . import ir

    return (
        isinstance(node, ir._CollectiveKernel)
        and not isinstance(node, ir._WaitKernel)
        and (op is None or node.op_overload is op)
    ) or (
        # TODO: this is a temporary solution to ensure that we can identify torchrec's
        # communication ops. But in order to allow better communication and computation
        # overlap, torchrec's communication ops should be not used.
        type(node) == ir.FallbackKernel
        and (
            # NOTE: the `hasattr()` check is to bypass errors such as the following:
            # AttributeError: '_OpNamespace' 'torchrec' object has no attribute 'all_to_all_single'
            (
                hasattr(torch.ops.torchrec, "all_to_all_single")
                and node.op_overload == torch.ops.torchrec.all_to_all_single.default
            )
            or (
                hasattr(torch.ops.torchrec, "all_gather_into_tensor")
                and node.op_overload
                == torch.ops.torchrec.all_gather_into_tensor.default
            )
            or (
                hasattr(torch.ops.torchrec, "reduce_scatter_tensor")
                and node.op_overload == torch.ops.torchrec.reduce_scatter_tensor.default
            )
        )
    )
def is_wait(node: Optional[Union[IRNode, Operation]]) -> bool:
    """Return True if ``node``'s exact type is ``ir._WaitKernel`` (subclasses do not count)."""
    from . import ir

    return type(node) is ir._WaitKernel
def contains_collective(snode: BaseSchedulerNode) -> bool:
    """
    Return True if ``snode`` is, or (for a ``GroupedSchedulerNode``)
    recursively contains, a collective op.
    """
    from torch._inductor.scheduler import GroupedSchedulerNode

    if not isinstance(snode, GroupedSchedulerNode):
        return is_collective(snode.node)
    return any(contains_collective(child) for child in snode.snodes)
def contains_wait(snode: BaseSchedulerNode) -> bool:
    """
    Return True if ``snode`` is, or (for a ``GroupedSchedulerNode``)
    recursively contains, a wait kernel.
    """
    from torch._inductor.scheduler import GroupedSchedulerNode

    if not isinstance(snode, GroupedSchedulerNode):
        return is_wait(snode.node)
    return any(contains_wait(child) for child in snode.snodes)
def is_fallback_op(
    node: Optional[Operation],
    op: Union[torch._ops.OpOverload, Collection[torch._ops.OpOverload]],
) -> bool:
    """
    Return True if ``node`` is an ``ir.FallbackKernel`` whose ``op_overload``
    is ``op`` (a single overload) or is contained in ``op`` (a collection).
    """
    from . import ir

    candidates = [op] if isinstance(op, torch._ops.OpOverload) else op
    return isinstance(node, ir.FallbackKernel) and node.op_overload in candidates
def buf_name_to_fused_snode(
    buf_name: str, name_to_buf: dict[str, Any], name_to_fused_node: dict[str, Any]
) -> Any:
    """
    Return the (possibly fused) scheduler node containing the op that
    defines buffer ``buf_name``. Raises KeyError for unknown names.
    """
    defining_op_name = name_to_buf[buf_name].defining_op.get_name()
    return name_to_fused_node[defining_op_name]
def find_recursive_deps_of_node(
    snode: BaseSchedulerNode,
    collected_node_set: MutableSet[BaseSchedulerNode],
    name_to_buf: dict[str, SchedulerBuffer],
    name_to_fused_node: dict[str, BaseSchedulerNode],
    criteria_cb: Callable[[Any], bool] = lambda snode: False,
) -> None:
    """
    Depth-first: add ``snode`` and, transitively, the fused nodes producing
    its unmet dependencies to ``collected_node_set`` (mutated in place).

    A node for which ``criteria_cb`` returns True is neither collected nor
    traversed through.
    """
    if criteria_cb(snode):
        return
    collected_node_set.add(snode)
    for dep in snode.unmet_dependencies:
        producer = buf_name_to_fused_snode(dep.name, name_to_buf, name_to_fused_node)
        if producer not in collected_node_set:
            find_recursive_deps_of_node(
                producer,
                collected_node_set,
                name_to_buf,
                name_to_fused_node,
                criteria_cb=criteria_cb,
            )
def find_recursive_users_of_node(
    snode: BaseSchedulerNode,
    collected_node_set: MutableSet[BaseSchedulerNode],
    name_to_buf: dict[str, SchedulerBuffer],
    name_to_fused_node: dict[str, BaseSchedulerNode],
    criteria_cb: Callable[[Any], bool] = lambda snode: False,
) -> None:
    """
    Depth-first: add ``snode`` and, transitively, the fused nodes that use
    its outputs to ``collected_node_set`` (mutated in place).

    Users named "OUTPUT" and users missing from ``name_to_fused_node`` are
    skipped. A node for which ``criteria_cb`` returns True is neither
    collected nor traversed through.
    """
    if criteria_cb(snode):
        return
    collected_node_set.add(snode)
    for output in snode.get_outputs():
        for user in output.users:
            assert user.node is not None
            user_name = user.node.get_name()
            if user_name == "OUTPUT" or user_name not in name_to_fused_node:
                continue
            user_snode = name_to_fused_node[user_name]
            if user_snode in collected_node_set:
                continue
            find_recursive_users_of_node(
                user_snode,
                collected_node_set,
                name_to_buf,
                name_to_fused_node,
                criteria_cb=criteria_cb,
            )
- def num_fw_fixed_arguments(dynamo_gm_num_inputs: int, aot_fw_gm_num_inputs: int) -> int:
- "Computes the number of inputs to the aot fw graph which have fixed addresses (params and buffers)"
- num_rng_seed_offset_inputs = (
- 2 if torch._functorch.config.functionalize_rng_ops else 0
- )
- # AOT won't lift any parameters if we're inlining NN Modules
- # however desugaring subclasses will still add arguments
- # resulted in extra fixed inputs https://github.com/pytorch/pytorch/issues/130502
- return aot_fw_gm_num_inputs - dynamo_gm_num_inputs - num_rng_seed_offset_inputs
def count_tangents(fx_g: torch.fx.GraphModule) -> int:
    """
    Return how many leading placeholders of a backward graph are static
    (saved-tensor) inputs.

    A placeholder counts as static unless its name contains "tangents",
    "bwd_seed", "bwd_base_offset" or "bwd_rng_state". Asserts that all static
    placeholders come before every other placeholder.
    """
    non_static_markers = ("tangents", "bwd_seed", "bwd_base_offset", "bwd_rng_state")

    def is_saved_tensor(x: Node) -> bool:
        return not any(marker in x.name for marker in non_static_markers)

    placeholders = [n for n in fx_g.graph.nodes if n.op == "placeholder"]
    static_arg_idxs = [
        idx for idx, node in enumerate(placeholders) if is_saved_tensor(node)
    ]
    assert static_arg_idxs == list(range(len(static_arg_idxs)))
    return len(static_arg_idxs)
@dataclasses.dataclass
class BoxedBool:
    """A mutable boolean box whose truthiness is ``value``, so it can be flipped by reference."""

    value: bool

    def __bool__(self) -> bool:
        return self.value

    @staticmethod
    def disable(obj: Any) -> Union[BoxedBool, bool]:
        """Set a BoxedBool to False in place and return it; for anything else return False."""
        if not isinstance(obj, BoxedBool):
            return False
        obj.value = False
        return obj
@contextlib.contextmanager
def collect_defined_kernels(kernel_list: list[str]) -> Iterator[None]:
    """
    While active, append the ``kernel_code`` of every
    ``PythonWrapperCodegen.define_kernel`` call to ``kernel_list``. The
    original method is still called.
    """
    from .codegen.wrapper import PythonWrapperCodegen

    real_define_kernel = PythonWrapperCodegen.define_kernel

    def define_kernel(
        self: PythonWrapperCodegen,
        kernel_name: str,
        kernel_code: str,
        metadata: Optional[str] = None,
        gpu: bool = True,
        cpp_definition: Optional[str] = None,
    ) -> Any:
        kernel_list.append(kernel_code)
        return real_define_kernel(
            self, kernel_name, kernel_code, metadata, gpu, cpp_definition
        )

    with mock.patch.object(PythonWrapperCodegen, "define_kernel", define_kernel):
        yield
def get_cloned_parameter_buffer_name(name: str) -> str:
    """Return the buffer name for the cloned original of parameter ``name`` (``name + "__original__"``)."""
    suffix = "__original__"
    return name + suffix
def is_gpu(device: Optional[str]) -> bool:
    """Return True if the device-type string ``device`` is a member of ``GPU_TYPES``."""
    gpu_device_types = GPU_TYPES
    return device in gpu_device_types
def device_need_guard(device: str) -> bool:
    """Return True for GPU device types other than "mps"."""
    # TODO: MPS does not expose streams now
    if device == "mps":
        return False
    return is_gpu(device)
def needs_fallback_due_to_atomic_add_limitations(dtype: torch.dtype) -> bool:
    """
    Return True if atomic adds of ``dtype`` require falling back.

    int64, bool and bfloat16 fall back. The exception is bfloat16 in fbcode
    on SM90+ CUDA devices when ``config.bfloat16_atomic_adds_enabled`` is set.
    """
    # tl.atomic_add has bfloat16 support in fbcode
    # but not in OSS https://github.com/pytorch/pytorch/issues/97016
    # we will fallback until the code is upstreamed to OSS
    fbcode_bf16_atomics_ok = (
        config.is_fbcode()
        and dtype == torch.bfloat16
        and torch.cuda.is_available()
        and torch.cuda.get_device_capability() >= (9, 0)
        and config.bfloat16_atomic_adds_enabled
    )
    if fbcode_bf16_atomics_ok:
        return False
    return dtype in OrderedSet([torch.int64, torch.bool, torch.bfloat16])
def use_scatter_fallback(
    op_overload: torch._ops.OpOverload,
    reduction_type: Optional[str],
    self_dtype: torch.dtype,
    src_dtype: torch.dtype,
    src_device_type: str,
    src_is_tensor: bool,
) -> bool:
    """
    Return True if a scatter op should use the fallback instead of the
    inductor lowering.

    ``scatter_reduce`` / ``scatter_reduce_`` with ``reduction_type=None``
    never fall back. Otherwise the fallback is used when any of these holds:
    * the reduction is neither None nor ``reduce_ty``, which is "add" for
      ``scatter_`` and "sum" for every other op;
    * the source is a tensor on a GPU device type and
      ``needs_fallback_due_to_atomic_add_limitations(src_dtype)`` is true;
    * it is a CPU ``scatter_reduce_`` with "sum" on a tensor source,
      ``config.cpp.fallback_scatter_reduce_sum`` is set, and either dynamic
      threads are enabled or ``parallel_num_threads()`` is not 1;
    * ``reduce_ty`` is requested on a bool or int64 ``self``;
    * deterministic algorithms are enabled.
    """
    if (
        op_overload.overloadpacket
        in (torch.ops.aten.scatter_reduce_, torch.ops.aten.scatter_reduce)
        and reduction_type is None
    ):
        return False

    # The only non-None reduction that does not force a fallback by itself.
    reduce_ty = (
        "add" if op_overload.overloadpacket == torch.ops.aten.scatter_ else "sum"
    )

    return (
        reduction_type not in (None, reduce_ty)
        # GPU atomic-add dtype limitations.
        or (
            src_is_tensor
            and is_gpu(src_device_type)
            and needs_fallback_due_to_atomic_add_limitations(src_dtype)
        )
        # Configurable CPU fallback for in-place scatter_reduce_ "sum".
        or (
            op_overload.overloadpacket == torch.ops.aten.scatter_reduce_
            and reduction_type == "sum"
            and src_is_tensor
            and src_device_type == "cpu"
            and config.cpp.fallback_scatter_reduce_sum
            and (config.cpp.dynamic_threads or parallel_num_threads() != 1)
        )
        or (reduction_type == reduce_ty and self_dtype in (torch.bool, torch.int64))
        or torch.are_deterministic_algorithms_enabled()
    )
- def dump_node_schedule(node_schedule: Sequence[BaseSchedulerNode]) -> None:
- """
- An API that can be used in pdb to dump a node_schedule.
- Right mainly dump the read/write dependencies but can add more as needed.
- """
- from torch._inductor.codegen.simd import DisableReduction, EnableReduction
- from torch._inductor.scheduler import SchedulerNode
- print(f"Node schedule with {len(node_schedule)} nodes")
- for idx, node in enumerate(node_schedule):
- print(f" {idx:3}:")
- if node is EnableReduction:
- print("enable reduction")
- elif node is DisableReduction:
- print("disable reduction")
- elif isinstance(node, SchedulerNode):
- is_red = node.is_reduction()
- print(f"{'red' if is_red else 'pw'} scheduler node")
- if is_red:
- assert node.node is not None
- print(f"original reduction hint {node.node.data.reduction_hint}") # type: ignore[attr-defined]
- print("ReadDep:")
- for dep in node.read_writes.reads:
- print(dep)
- print("WriteDep:")
- for dep in node.read_writes.writes:
- print(dep)
- else:
- raise RuntimeError(f"Unrecognized node type: {type(node)}")
def tensor_is_aligned(tensor: torch.Tensor) -> bool:
    """
    Return True if the tensor's storage offset, in bytes, is statically known
    to be a multiple of ``GPU_ALIGN_BYTES``.
    """
    # See Note: [Input Alignment handling in Inductor]
    # Right now, we don't try to guard on the alignment of the storage offset.
    # When this comment was written, non-symbolic storage_offsets are not guarded on
    # but symbolic storage_offsets are. For consistency, we suppress guard creation
    # upon performing this check: that ensures that we don't add recompiles when we
    # add this logic.
    from torch.fx.experimental.symbolic_shapes import statically_known_true

    offset_bytes = tensor.storage_offset() * get_dtype_size(tensor.dtype)
    return statically_known_true(offset_bytes % GPU_ALIGN_BYTES == 0)
def should_assume_input_aligned(example_input: torch.Tensor) -> bool:
    """
    Decide whether Inductor may treat ``example_input`` as aligned.

    See Note: [Input Alignment handling in Inductor]. Only GPU tensors are
    considered. For them the answer is True when ``config.assume_aligned_inputs``
    is set or ``tensor_is_aligned`` proves the storage offset is aligned.
    """
    if is_gpu(example_input.device.type):
        return config.assume_aligned_inputs or tensor_is_aligned(example_input)
    # Alignment is currently only tracked for GPU tensors.
    return False
- def maybe_get_suppress_shape_guards_ctx() -> contextlib.AbstractContextManager[None]:
- # Try to get TracingContext.try_get().fake_mode.shape_env.suppress_guards()
- # If it's not available, return a nullcontext.
- # If we're dealing with cudagraphs, we might not have a tracing_context
- tracing_context = torch._guards.TracingContext.try_get()
- if not tracing_context:
- return contextlib.nullcontext()
- # In standalone inductor compile mode, we might not have a shape_env attached to the fake mode
- if not tracing_context.fake_mode or not tracing_context.fake_mode.shape_env:
- return contextlib.nullcontext()
- shape_env = tracing_context.fake_mode.shape_env
- return shape_env.suppress_guards()
def run_and_get_cpp_code(
    fn: Callable[P, _T], *args: P.args, **kwargs: P.kwargs
) -> tuple[_T, str]:
    """
    Run ``fn(*args, **kwargs)`` with ``config.debug`` enabled and capture the
    text emitted on ``output_code_log`` while it runs.

    Dynamo state is reset before ``fn`` is called. ``output_code_log`` is
    temporarily lowered to DEBUG and given a handler that writes into an
    in-memory buffer.

    Returns:
        ``(result, captured_text)``: the return value of ``fn`` and everything
        the temporary handler received.

    The temporary handler and the logger's previous level are restored even if
    ``fn`` raises. Otherwise a failing call would leave ``output_code_log`` at
    DEBUG with a stale handler attached for the rest of the process.
    """
    # We use the patch context manager instead of using it as a decorator, so
    # the attribute is patched and unpatched correctly even if this function
    # is called multiple times.
    with unittest.mock.patch.object(config, "debug", True):
        torch._dynamo.reset()
        import io
        import logging

        log_capture_string = io.StringIO()
        ch = logging.StreamHandler(log_capture_string)
        from torch._inductor.codecache import output_code_log

        output_code_log.addHandler(ch)
        prev_level = output_code_log.level
        output_code_log.setLevel(logging.DEBUG)
        try:
            result = fn(*args, **kwargs)
            s = log_capture_string.getvalue()
        finally:
            output_code_log.setLevel(prev_level)
            output_code_log.removeHandler(ch)
    return result, s
def shape_env_from_inputs(inputs: Sequence[InputType]) -> Optional[ShapeEnv]:
    """
    Find the ShapeEnv associated with ``inputs``, if any.

    A fake mode detected from the inputs takes precedence. Otherwise (e.g. when
    there are no tensor inputs) the ShapeEnv of the first ``torch.SymInt``
    input is used. Returns None when neither is available.
    """
    fake_mode = detect_fake_mode(inputs)
    # TODO(voz): It would be nice to assert that fake_mode is not None whenever
    # there are inputs, but lots of tests still pass in real inputs.
    if fake_mode is not None:
        return fake_mode.shape_env

    first_symint = next((x for x in inputs if isinstance(x, torch.SymInt)), None)
    if first_symint is not None:
        return first_symint.node.shape_env
    # TODO(voz): Should we always have one anyway?
    return None
def align_inputs_from_check_idxs(
    model: Callable[[list[InputType]], _T],
    inputs_to_check: Sequence[int],
    mutated_input_idxs: OrderedSet[int],
) -> Callable[[list[InputType]], _T]:
    """
    Wrap ``model`` so that misaligned inputs are realigned before each call.

    At call time, inputs at ``inputs_to_check`` whose data pointer is
    misaligned are replaced (in the argument list) by aligned clones via
    ``copy_misaligned_inputs``. For cloned inputs listed in
    ``mutated_input_idxs``, the clone's contents are copied back into the
    original tensor after ``model`` returns, so in-place mutations remain
    visible to the caller.

    When ``inputs_to_check`` is empty, ``model`` itself is returned.
    """
    if len(inputs_to_check) == 0:
        return model

    def run(new_inputs: list[InputType]) -> Any:
        originals, aligned_copies = copy_misaligned_inputs(
            new_inputs, inputs_to_check, mutated_input_idxs
        )
        out = model(new_inputs)
        if originals:
            # Reflect mutations made on the aligned clones back onto the
            # caller's original tensors.
            torch._foreach_copy_(originals, aligned_copies)
        return out

    return run
def clone_preserve_strides(x: torch.Tensor) -> torch.Tensor:
    """
    Clone ``x`` into fresh memory while keeping its exact size and strides.

    Only the span of storage actually reachable through ``x``'s strides is
    copied. That span is ``1 + sum((size_i - 1) * stride_i)`` elements, or
    zero when ``x`` has no elements. The copy is then viewed with the
    original size and strides.
    """
    size, stride = x.size(), x.stride()
    if 0 in size:
        # An empty tensor reaches no storage at all.
        span = 0
    else:
        span = 1 + sum((extent - 1) * step for extent, step in zip(size, stride))
    flat_copy = torch.as_strided(x, (span,), (1,)).clone()
    return torch.as_strided(flat_copy, size, stride)
def copy_misaligned_inputs(
    new_inputs: list[InputType],
    check_inputs_idxs: Sequence[int],
    return_pair_idxs: Optional[OrderedSet[int]] = None,
) -> tuple[list[torch.Tensor], list[torch.Tensor]]:
    """
    Clone misaligned tensors that we inferred were aligned.

    Each index in ``check_inputs_idxs`` must hold a tensor. Any such tensor
    whose data pointer is not a multiple of ``ALIGNMENT`` is replaced in place
    in ``new_inputs`` by a ``clone_preserve_strides`` copy.

    Returns:
        ``(old_tensors, new_tensors)``. For every cloned index that is also in
        ``return_pair_idxs``, these hold the original tensor and its clone, in
        the order the indices were checked. Both lists stay empty when
        ``return_pair_idxs`` is None.
    """
    old_tensors: list[torch.Tensor] = []
    new_tensors: list[torch.Tensor] = []
    # Hoisted out of the loop because this function is on the hot path.
    track_pairs = return_pair_idxs is not None
    for idx in check_inputs_idxs:
        original = new_inputs[idx]
        assert isinstance(original, torch.Tensor), (
            f"Expected tensors only, but got: {type(original)}"
        )
        if original.data_ptr() % ALIGNMENT == 0:
            continue  # already aligned; leave it untouched
        aligned = clone_preserve_strides(original)
        new_inputs[idx] = aligned
        if track_pairs and idx in return_pair_idxs:  # type: ignore[operator]
            old_tensors.append(original)
            new_tensors.append(aligned)
    return old_tensors, new_tensors
def remove_unaligned_input_idxs(
    inputs: Sequence[InputType],
    static_input_idxs: Sequence[int],
) -> Sequence[int]:
    """
    Drop the indices of ``static_input_idxs`` whose input is not an aligned tensor.

    We require all static inputs to be aligned. An index is kept only if its
    input is a tensor whose data pointer is a multiple of ``ALIGNMENT``. When
    every index survives, the original ``static_input_idxs`` object is returned
    unchanged; otherwise a new list of the surviving indices is returned.
    """

    def is_aligned_tensor(value: InputType) -> bool:
        return isinstance(value, torch.Tensor) and value.data_ptr() % ALIGNMENT == 0

    kept = [idx for idx in static_input_idxs if is_aligned_tensor(inputs[idx])]
    return static_input_idxs if len(kept) == len(static_input_idxs) else kept
def expr_fits_within_32bit(e: sympy.Expr) -> bool:
    """
    Return True if ``e`` may be treated as fitting in a signed 32-bit integer.

    Decision order:
    1. If ``e <= INT32_MAX`` can be proven statically (e.g. from value ranges),
       return True, even when ``e`` has no hint.
    2. Under AOT compilation, if ``e`` provably has a finite upper bound
       (``e < 1e20``, which exceeds INT64_MAX), return False without looking
       at the hint.
    3. Otherwise, return True only if ``e`` has a hint and that hint is
       ``<= INT32_MAX``.
    """
    from .virtualized import V

    sizevars = V.graph.sizevars
    int32_max = torch.iinfo(torch.int32).max
    size_hint = sizevars.size_hint
    has_hint = sizevars.shape_env.has_hint

    if sizevars.statically_known_true(e <= int32_max):
        return True

    if V.aot_compilation:
        # AOTI does not guard on the value staying within int32, so a hint that
        # happens to fit says nothing about the full allowed range. If the
        # range has an upper bound, obey it: step 1 already failed to prove it
        # is within int32, so this may see int64 values. If the range is
        # unbounded, fall through and trust the hint, which avoids perf
        # regressions on pre-existing AOTI models that never set an upper bound.
        if sizevars.statically_known_true(e < 1e20):
            return False

    # Otherwise, the hint MUST exist and be in range.
    return has_hint(e) and size_hint(e) <= int32_max
- def set_tracing_context_output_strides(
- example_inputs: Sequence[Any], compiled_graph: CompiledFxGraph
- ) -> None:
- # Return the output strides to the caller via TracingContext
- context = torch._guards.TracingContext.try_get()
- if context is not None and context.output_strides is not None:
- assert len(context.output_strides) == 0
- shape_env = shape_env_from_inputs(example_inputs)
- assert compiled_graph.output_strides is not None
- for exprs in compiled_graph.output_strides:
- if exprs is None:
- context.output_strides.append(None)
- else:
- fakify_first_call = False
- if ctx := torch._guards.TracingContext.try_get():
- fakify_first_call = ctx.fakify_first_call
- def map_expr(e: Any) -> Union[float, int, SymInt, SymFloat, SymBool]:
- if shape_env is None:
- return int(e)
- if fakify_first_call:
- return shape_env.deserialize_symexpr(e)
- return shape_env.evaluate_symexpr(e)
- context.output_strides.append(
- tuple(map_expr(e) for e in exprs) # type: ignore[misc]
- )
def should_use_remote_fx_graph_cache() -> bool:
    """
    Decide whether the remote FX graph cache should be used.

    An explicit ``config.fx_graph_remote_cache`` setting always wins. Otherwise
    the cache is only used in fbcode, outside fb unit tests, when the internal
    remote-cache module is importable and its ``REMOTE_CACHE_VERSION`` is at
    least the version configured in the
    "pytorch/remote_cache:fx_graph_memcache_version" JustKnob.
    """
    explicit_setting = config.fx_graph_remote_cache
    if explicit_setting is not None:
        return explicit_setting
    if not config.is_fbcode() or torch._utils_internal.is_fb_unit_test():
        return False

    try:
        from torch._inductor.fb.remote_cache import REMOTE_CACHE_VERSION
    except ModuleNotFoundError:
        return False

    required_version = torch._utils_internal.justknobs_getval_int(
        "pytorch/remote_cache:fx_graph_memcache_version"
    )
    return REMOTE_CACHE_VERSION >= required_version
def normalize_name(name: str) -> str:
    """Replace every character other than ASCII letters, digits and ``_`` with ``_``."""
    disallowed = r"[^a-zA-Z0-9_]"
    return re.sub(disallowed, "_", name)
# Correct cases where Triton type names don't match PyTorch: maps the
# ``tl.<torch dtype name>`` spelling to the name Triton actually uses.
_triton_type_mapping = {
    "tl.bool": "tl.int1",
    "tl.float8_e4m3fn": "tl.float8e4nv",
    "tl.float8_e5m2": "tl.float8e5",
    "tl.float8_e4m3fnuz": "tl.float8e4b8",
    "tl.float8_e5m2fnuz": "tl.float8e5b16",
    # TODO: remove when support is added in triton
    # https://github.com/triton-lang/triton/issues/6054
    # These dtypes have no Triton counterpart yet and are represented as tl.uint8.
    "tl.float8_e8m0fnu": "tl.uint8",
    "tl.float4_e2m1fn_x2": "tl.uint8",
}


def _invert_unambiguous(mapping: dict[str, str]) -> dict[str, str]:
    """
    Invert ``mapping``, dropping any target that more than one key maps to.

    Several PyTorch dtypes share the placeholder Triton type ``tl.uint8``.
    Inverting them naively would make the genuine ``tl.uint8`` resolve to
    whichever of them came last (``tl.float4_e2m1fn_x2``), so ambiguous
    targets are omitted and fall through to the identity mapping instead.
    """
    inverted: dict[str, str] = {}
    ambiguous: set[str] = set()
    for torch_name, triton_name in mapping.items():
        if triton_name in inverted:
            ambiguous.add(triton_name)
        inverted[triton_name] = torch_name
    for triton_name in ambiguous:
        del inverted[triton_name]
    return inverted


# Reverse lookup (Triton name -> ``tl.<torch dtype name>``).
_torch_triton_mapping = _invert_unambiguous(_triton_type_mapping)
# Strips the module prefix, e.g. "torch.float16" -> "tl.float16" when used with sub("tl.", ...).
_triton_type_re = re.compile(r"^.*[.]")
def triton_type(dtype: torch.dtype) -> str:
    """
    Return the Triton type name for a torch dtype.

    ``torch.<name>`` is first rewritten to ``tl.<name>``. Names that Triton
    spells differently are then remapped via ``_triton_type_mapping``.
    """
    torch_spelling = _triton_type_re.sub("tl.", str(dtype))
    try:
        return _triton_type_mapping[torch_spelling]
    except KeyError:
        return torch_spelling
def triton_type_to_torch(dtype: str) -> torch.dtype:
    """
    Map a Triton type name such as ``tl.float32`` back to a torch dtype.

    Names listed in ``_torch_triton_mapping`` are first translated to their
    PyTorch spelling. Every ``tl.`` is then stripped and the remainder is
    looked up as an attribute of ``torch``.

    Raises:
        AttributeError: if ``torch`` has no attribute of that name.
        AssertionError: if the attribute exists but is not a ``torch.dtype``.
    """
    torch_spelling = _torch_triton_mapping.get(dtype, dtype)
    result = getattr(torch, torch_spelling.replace("tl.", ""))
    assert isinstance(result, torch.dtype)
    return result
def is_same_tensor(data: torch.Tensor, value: torch.Tensor) -> bool:
    """
    Return True if two tensors view the same memory with an identical layout.

    ``data`` must not be an MKLDNN tensor. Both tensors must agree on size,
    stride, dtype, device, untyped-storage data pointer and storage offset.
    """
    if data.is_mkldnn:
        return False
    if data.size() != value.size() or data.stride() != value.stride():
        return False
    if data.dtype != value.dtype or data.device != value.device:
        return False
    same_storage = (
        data.untyped_storage().data_ptr() == value.untyped_storage().data_ptr()
    )
    return same_storage and data.storage_offset() == value.storage_offset()
def is_same_mkldnn_tensor(data: torch.Tensor, value: torch.Tensor) -> bool:
    """
    Return True if ``data`` is an MKLDNN tensor sharing ``value``'s MKLDNN
    buffer, with the same size, dtype and device.

    Non-MKLDNN ``data`` always yields False.
    """
    if not data.is_mkldnn:
        return False
    if data.size() != value.size():
        return False
    if data.dtype != value.dtype or data.device != value.device:
        return False
    return torch.ops.mkldnn.data_ptr(data) == torch.ops.mkldnn.data_ptr(value)
@functools.cache
def boolean_ops() -> tuple[str, ...]:
    """
    Return the names of the ops classified as boolean ops, in a fixed order.

    The result is cached, so every call returns the same tuple object.
    """
    head = ("isinf", "isnan", "logical_not", "logical_and", "signbit", "and_")
    comparisons = ("le", "lt", "ge", "gt", "eq", "ne")
    # TODO should remove "or_" from this list.
    tail = ("or_", "xor")
    return head + comparisons + tail
@dataclasses.dataclass
class OpDtypeRule:
    """Dtype propagation rule registered for a single op name."""

    # Promotion kind used for this op's inputs.
    type_promotion_kind: ELEMENTWISE_TYPE_PROMOTION_KIND
    # Explicit result dtype, or None. Presumably this overrides promotion when
    # set; the consumers live outside this module, so confirm there.
    override_return_dtype: Optional[torch.dtype]


# Registry of dtype propagation rules, keyed by op name.
op_dtype_propagation_rules: dict[str, OpDtypeRule] = {}


def register_op_dtype_propagation_rules(
    name: str,
    type_promotion_kind: ELEMENTWISE_TYPE_PROMOTION_KIND,
    override_return_dtype: Optional[torch.dtype],
) -> None:
    """Record (replacing any previous entry) the dtype rule for op ``name``."""
    rule = OpDtypeRule(
        type_promotion_kind=type_promotion_kind,
        override_return_dtype=override_return_dtype,
    )
    op_dtype_propagation_rules[name] = rule
# Names of ops registered via register_op_requires_libdevice_fp64, in
# registration order (duplicates collapse since this is a set).
op_requires_libdevice_fp64: OrderedSet[str] = OrderedSet()


def register_op_requires_libdevice_fp64(name: str) -> None:
    """Add ``name`` to ``op_requires_libdevice_fp64``."""
    registry = op_requires_libdevice_fp64
    registry.add(name)
def get_current_backend() -> str:
    """
    Return the codegen backend name for the graph's current device.

    "cpu" uses ``config.cpu_backend`` and "mps" uses "mps". Every other
    device type uses ``config.cuda_backend``.
    """
    from torch._inductor.virtualized import V

    device_type = V.graph.get_current_device_or_throw().type
    if device_type == "mps":
        return "mps"
    if device_type == "cpu":
        return config.cpu_backend
    return config.cuda_backend
def upcast_compute_type(dtype: torch.dtype) -> torch.dtype:
    """
    Maybe upcast [b]float16 to float32.

    float16/bfloat16 become float32 only when
    ``config.triton.codegen_upcast_to_fp32`` is set and the current backend is
    "triton". Every other dtype is returned unchanged.
    """
    if dtype not in (torch.float16, torch.bfloat16):
        return dtype
    if config.triton.codegen_upcast_to_fp32 and get_current_backend() == "triton":
        return torch.float32
    return dtype
KeyType = TypeVar("KeyType")
ValType = TypeVar("ValType")


class ScopedDict(MutableMapping[KeyType, ValType]):
    """
    A mutable mapping layered over a read-only base mapping.

    Writes go to a private overlay dict (``new_items``). Reads consult the
    overlay first and fall back to ``original_dict``, which is never modified.
    Iteration yields the base keys first, then overlay-only keys. Deletion is
    not supported.
    """

    def __init__(self, original_dict: Mapping[KeyType, ValType]):
        self.original_dict = original_dict
        self.new_items: dict[KeyType, ValType] = {}

    def __getitem__(self, key: KeyType) -> ValType:
        overlay = self.new_items
        return overlay[key] if key in overlay else self.original_dict[key]

    def __setitem__(self, key: KeyType, value: ValType) -> None:
        self.new_items[key] = value

    def __contains__(self, key: object) -> bool:
        return key in self.new_items or key in self.original_dict

    def get(self, key: KeyType, default: Optional[ValType] = None) -> Optional[ValType]:  # type: ignore[override]
        overlay = self.new_items
        if key in overlay:
            return overlay[key]
        return self.original_dict.get(key, default)

    def __len__(self) -> int:
        base = self.original_dict
        overlay_only = sum(1 for key in self.new_items if key not in base)
        return len(base) + overlay_only

    def __iter__(self) -> Iterator[KeyType]:
        base = self.original_dict
        yield from base
        yield from (key for key in self.new_items if key not in base)

    def __bool__(self) -> bool:
        return bool(self.original_dict) or bool(self.new_items)

    def __delitem__(self, key: KeyType) -> None:
        raise NotImplementedError
- @dataclass_transform(frozen_default=True)
- def ir_dataclass(cls: Optional[type[Any]] = None, /, *, frozen: bool = True) -> Any:
- def wrap(cls: _T) -> _T:
- if sys.version_info >= (3, 10):
- return dataclasses.dataclass(cls, kw_only=True, frozen=frozen) # type: ignore[call-overload]
- else:
- # Polyfill for python=3.9. kw_only simply introduces an extra check
- # that only kwargs are used (and is not available on 3.9)
- return dataclasses.dataclass(cls, frozen=frozen)
- if cls is None:
- return wrap
- return wrap(cls)
- def get_donated_idxs() -> Optional[list[int]]:
- tracing_context = torch._guards.TracingContext.try_get()
- if tracing_context is not None and tracing_context.fw_metadata:
- return tracing_context.fw_metadata.bw_donated_idxs
- return None
class TritonAttrsDescriptorVersion(enum.Enum):
    """Which form of Triton's kernel attribute descriptor is available."""

    V0_NO_TRITON = 0  # Triton is not installed
    V1_COMPILER = 1  # triton.compiler.compiler.AttrsDescriptor
    V2_BACKENDS = 2  # triton.backends.compiler.AttrsDescriptor
    V3_BACKENDS_TUPLE = 3  # triton.backends.compiler.AttrsDescriptor, but with tuple support
    V4_DICT = 4  # a raw dict
@functools.cache
def get_triton_attrs_descriptor_version() -> TritonAttrsDescriptorVersion:
    """
    Detect (once, then cached) which attribute-descriptor form the installed
    Triton uses, by probing where ``AttrsDescriptor`` is defined.
    """
    if importlib.util.find_spec("triton") is None:
        return TritonAttrsDescriptorVersion.V0_NO_TRITON

    import triton.backends.compiler
    import triton.compiler.compiler

    if hasattr(triton.backends.compiler, "AttrsDescriptor"):
        # Triton 3.2.0: AttrsDescriptor moved from triton.compiler.compiler to
        # triton.backends.compiler, and its serialization format changed.
        # TODO: implement V3_BACKENDS_TUPLE. Tuple support (triton #5220),
        # landed on Dec 9, 2024, breaks handling; we have no way to detect it
        # and have not implemented that version.
        return TritonAttrsDescriptorVersion.V2_BACKENDS
    if hasattr(triton.compiler.compiler, "AttrsDescriptor"):
        # Triton 3.0.0
        return TritonAttrsDescriptorVersion.V1_COMPILER
    # After Jan 1, 2025 AttrsDescriptor was removed and replaced by a raw dict.
    return TritonAttrsDescriptorVersion.V4_DICT
def triton_version_uses_attrs_dict() -> bool:
    """True when the detected Triton attribute descriptor is a raw dict (V4_DICT)."""
    detected = get_triton_attrs_descriptor_version()
    return detected == TritonAttrsDescriptorVersion.V4_DICT
def is_cudagraph_unsafe_op(node: Operation) -> bool:
    """
    Returns True if the node is an op that is not cudagraphable.

    Only ``ir.FallbackKernel`` nodes whose ``op_overload`` is a
    ``torch._ops.OpOverload`` tagged ``torch._C.Tag.cudagraph_unsafe`` qualify.
    Usually only custom ops carry this tag.
    """
    from . import ir

    if not isinstance(node, ir.FallbackKernel):
        return False
    overload = node.op_overload
    return (
        isinstance(overload, torch._ops.OpOverload)
        and torch._C.Tag.cudagraph_unsafe in overload.tags  # type: ignore[attr-defined]
    )
def get_ld_library_path() -> str:
    """
    Return the ``LD_LIBRARY_PATH`` to use for subprocesses.

    This is the current value (or ""). In fbcode, when a runtime path is
    available, ``<runtime_path>/runtime/lib`` is prepended.
    """
    search_path = os.environ.get("LD_LIBRARY_PATH", "")
    if not config.is_fbcode():
        return search_path

    from libfb.py.parutil import get_runtime_path

    runtime_path = get_runtime_path()
    if not runtime_path:
        return search_path
    runtime_lib = os.path.join(runtime_path, "runtime", "lib")
    return os.pathsep.join([runtime_lib, search_path]) if search_path else runtime_lib
def is_codegen_graph_partition_subgraph(wrapper: PythonWrapperCodegen) -> bool:
    """True if ``wrapper`` is a subgraph wrapper that carries partition signatures."""
    from torch._inductor.codegen.wrapper import SubgraphPythonWrapperCodegen

    if not isinstance(wrapper, SubgraphPythonWrapperCodegen):
        return False
    return wrapper.partition_signatures is not None
def is_using_cudagraph_partition() -> bool:
    """
    True when graph partitioning is enabled and partitions will be wrapped,
    either by cudagraphs (``config.triton.cudagraphs``) or by a user-registered
    customized partition wrapper.
    """
    inductor_config = torch._inductor.config
    wants_wrapping = (
        inductor_config.triton.cudagraphs
        or _unstable_customized_partition_wrapper.wrapper is not None
    )
    return wants_wrapping and inductor_config.graph_partition
def dtype_from_size(size: int) -> torch.dtype:
    """
    Return ``torch.int32`` if ``size`` is provably in ``[-2**31, 2**31)``,
    otherwise ``torch.int64``.
    """
    from .virtualized import V

    sizevars = V.graph.sizevars
    fits_in_int32 = sizevars.statically_known_lt(
        size, 2**31
    ) and sizevars.statically_known_geq(size, -(2**31))
    return torch.int32 if fits_in_int32 else torch.int64
SUPPORTED_MKLDNN_DEVICES = ("cpu", "xpu")


def is_mkldnn_bf16_supported(device_type: str) -> bool:
    """
    Returns True if the device supports MKL-DNN BF16.

    "cpu" defers to ``torch.ops.mkldnn._is_mkldnn_bf16_supported()``. Any
    device string containing "xpu" is supported, and everything else is not.
    """
    if device_type == "cpu":
        return torch.ops.mkldnn._is_mkldnn_bf16_supported()
    # Substring match so "xpu", "xpu:0", "xpu:1", ... all qualify.
    return "xpu" in device_type
def is_mkldnn_fp16_supported(device_type: str) -> bool:
    """
    Returns True if the device supports MKL-DNN FP16.

    "cpu" defers to ``torch.ops.mkldnn._is_mkldnn_fp16_supported()``. Any
    device string containing "xpu" is supported, and everything else is not.
    """
    if device_type == "cpu":
        return torch.ops.mkldnn._is_mkldnn_fp16_supported()
    # Substring match so "xpu", "xpu:0", "xpu:1", ... all qualify.
    return "xpu" in device_type
def tabulate_2d(elements: Sequence[Sequence[T]], headers: Sequence[T]) -> str:
    """
    Render ``headers`` and the rows of ``elements`` as a plain-text table.

    Each column is as wide as its longest ``str()`` value and padded with one
    space on each side. Columns are separated by "|", and a dashed rule sits
    under the header. Every row must have as many cells as ``headers``.
    """
    col_widths = [len(str(h)) for h in headers]
    for row in elements:
        assert len(row) == len(headers)
        for col, cell in enumerate(row):
            col_widths[col] = max(col_widths[col], len(str(cell)))

    def render(cells: Sequence[T]) -> str:
        return "|".join(f" {cell:{width}} " for cell, width in zip(cells, col_widths))

    # Per column: its width plus two padding spaces; plus one "|" between columns.
    rule = "-" * (sum(col_widths) + 3 * len(col_widths) - 1)
    return "\n".join([render(headers), rule, *(render(row) for row in elements)])
def zip_dicts(
    dict1: Mapping[KeyType, ValType],
    dict2: Mapping[KeyType, ValType],
    d1_default: ValType | None = None,
    d2_default: ValType | None = None,
) -> Generator[tuple[KeyType, ValType | None, ValType | None], None, None]:
    """
    Zip two dictionaries together, replacing missing keys with default values.

    Keys are visited in first-seen order: every key of ``dict1``, then the keys
    of ``dict2`` that ``dict1`` lacks. Only *absent* keys get the default; a
    key that is present with the value ``None`` yields ``None``.

    Args:
        dict1 (dict): The first dictionary.
        dict2 (dict): The second dictionary.
        d1_default (Any): the default value for keys missing from the first dictionary
        d2_default (Any): the default value for keys missing from the second dictionary

    Yields:
        tuple: A tuple containing the key, the value from dict1 (or d1_default if missing),
        and the value from dict2 (or d2_default if missing).
    """
    # dict.fromkeys gives an insertion-ordered union of both key sets.
    all_keys = dict.fromkeys([*dict1, *dict2])
    for key in all_keys:
        yield key, dict1.get(key, d1_default), dict2.get(key, d2_default)
def maybe_aoti_standalone_config(config_patches: dict[str, Any]) -> dict[str, Any]:
    """
    Ensures the configuration is internally consistent for standalone AOTInductor.

    If ``aot_inductor.compile_standalone`` is True in ``config_patches`` (or,
    failing that, in the global config), each of the following settings is
    filled in when it is unset (None):

    - ``aot_inductor.package_cpp_only`` -> True
    - ``aot_inductor.embed_kernel_binary`` -> True
    - ``aot_inductor.emit_multi_arch_kernel`` -> ``not torch.version.hip``
      (True unless building for ROCm)
    - ``aot_inductor.model_name_for_generated_files`` -> "aoti_model"

    Truthy user-provided values are left as-is.

    Args:
        config_patches (dict[str, Any]): A dictionary of user-provided config
            overrides for AOTInductor compilation. It is not modified.

    Returns:
        dict[str, Any]: A possibly-updated copy of ``config_patches``.

    Raises:
        RuntimeError: if one of the settings above is explicitly set to a
            falsy value that differs from the required one.
    """

    def patch_config(
        config_patches: dict[str, Any], config_name: str, config_value: Any
    ) -> None:
        # An explicit patch takes precedence over the global config value.
        value = config_patches.get(config_name, getattr(config, config_name))
        if value is None:
            config_patches[config_name] = config_value
        elif not value and value != config_value:
            # Report the offending value the user set, not the required one.
            raise RuntimeError(
                f"Invalid config: {config_name}={value} when aot_inductor.compile_standalone "
                f"is True (expected {config_value})."
            )

    compile_standalone = config_patches.get(
        "aot_inductor.compile_standalone", config.aot_inductor.compile_standalone
    )
    # Make a copy of config_patches to avoid modifying the caller's dictionary
    # (needed for testing).
    config_patches = config_patches.copy()
    if compile_standalone:
        # Standalone AOTInductor means only generating a cpp project for
        # building a standalone binary.
        patch_config(config_patches, "aot_inductor.package_cpp_only", True)
        # Standalone AOTInductor needs to embed the kernel code in the binary.
        patch_config(config_patches, "aot_inductor.embed_kernel_binary", True)
        # Default to multi-arch kernel codegen for non-ROCm GPUs.
        patch_config(
            config_patches, "aot_inductor.emit_multi_arch_kernel", not torch.version.hip
        )
        patch_config(
            config_patches, "aot_inductor.model_name_for_generated_files", "aoti_model"
        )
    return config_patches
- def is_valid_aoti_model_name() -> bool:
- """
- Validates if a model name is suitable for use in code generation.
- """
- from torch._inductor import config
- model_name = config.aot_inductor.model_name_for_generated_files
- if model_name is None:
- return True
- if not isinstance(model_name, str):
- raise ValueError("Invalid AOTI model name: Model name must be a string")
- if model_name == "":
- return True
- # Can only contain alphanumeric characters and underscores
- if not re.match(r"^[a-zA-Z_][a-zA-Z0-9_]*$", model_name):
- raise ValueError(
- "Invalid AOTI model name: Model name can only contain letters, numbers, and underscores"
- )
- return True
def get_free_symbols(x: IterateExprs, unbacked_only: bool) -> OrderedSet[sympy.Symbol]:
    """Free symbols of ``x``, restricted to unbacked ones when ``unbacked_only`` is True."""
    collect = free_unbacked_symbols if unbacked_only else free_symbols
    return collect(x)
def maybe_log_cudagraph_partition(
    msg: str,
    prefix: Optional[str] = "cudagraph partition due to ",
    node: Optional[BaseSchedulerNode] = None,
) -> None:
    """
    Log why a cudagraph partition was introduced, as a perf hint.

    Cudagraph partitions may add memory overhead, so the reason is logged to
    help users understand it. Nothing is logged unless
    ``config.triton.cudagraphs`` is enabled. When ``node`` leads (via its IR
    node's origin FX node) to a recorded "stack_trace", that trace is appended.
    """
    if not config.triton.cudagraphs:
        return

    warning_msg = f"{prefix}{msg}"
    stack_trace = None
    if node:
        ir_node = node.node
        if ir_node:
            fx_node = ir_node.get_origin_node()
            if fx_node:
                stack_trace = fx_node.meta.get("stack_trace", None)
    if stack_trace:
        warning_msg = f"{warning_msg}. Found from : \n {stack_trace}"
    perf_hint_log.warning(warning_msg)
def python_subprocess_env() -> dict[str, str]:
    """
    Get a base environment for running Python subprocesses.

    This is a copy of the current process environment with ``PYTHONPATH`` set
    so the subprocess can find torch, and (fbcode only) ``PYTHONHOME`` set.
    """
    # Inherit the environment of the current process.
    env = dict(os.environ)
    # Prefer TORCH_CUSTOM_PYTHONPATH; otherwise reuse this interpreter's sys.path.
    env["PYTHONPATH"] = os.environ.get(
        "TORCH_CUSTOM_PYTHONPATH", os.pathsep.join(sys.path)
    )
    # Internal builds may bundle their own runtime; point PYTHONHOME at it so the
    # subprocess does not pick up the platform runtime's libraries and headers.
    # This cannot be done for external builds: they may run from a venv that
    # lacks Python headers, so the subprocess must be able to find the
    # platform runtime itself.
    if config.is_fbcode():
        env["PYTHONHOME"] = sysconfig.get_path("data")
    return env
@dataclasses.dataclass(frozen=True)
class CUDAGraphWrapperMetadata:
    """
    Metadata handed to a customized CUDAGraph partition wrapper.

    Currently assumes there is a single dynamo graph; this will be extended to
    multiple graphs in the future.
    """

    num_partitions: int  # number of partitions that are cudagraphable
    partition_index: int  # index of the partition being wrapped


# A partition function; its signature is opaque here.
PartitionFnType = Callable[..., Any]
# A wrapper receives a partition function plus its metadata and returns the
# function to use in its place.
CUDAGraphWrapperType = Callable[
    [PartitionFnType, CUDAGraphWrapperMetadata], PartitionFnType
]
class CUDAGraphWrapper:
    """
    Mutable holder for a customized partition wrapper supplied by users.

    ``wrapper`` is None until one is registered. Its interface should be::

        def wrapper(fn: PartitionFnType, metadata: CUDAGraphWrapperMetadata) -> PartitionFnType

    Inductor generates N wrapper functions for N partition functions and
    mechanically wraps each partition fn with the generated wrapper function.
    Users need to handle all details such as static inputs and dynamic shapes,
    and may customize behavior based on the metadata (for example, special
    handling of the first and last wrapper functions).

    Warning: This API is unstable and may change in the future.
    """

    wrapper: Optional[CUDAGraphWrapperType] = None


# Module-level holder; read by is_using_cudagraph_partition and updated by
# set_customized_partition_wrappers.
_unstable_customized_partition_wrapper = CUDAGraphWrapper()
def set_customized_partition_wrappers(wrapper: CUDAGraphWrapperType) -> None:
    """Install ``wrapper`` as the (unstable) customized partition wrapper, replacing any previous one."""
    holder = _unstable_customized_partition_wrapper
    holder.wrapper = wrapper
def snode_args_kwargs(snode: BaseSchedulerNode) -> tuple[list[Any], dict[str, Any]]:
    """
    Rebuild concrete ``(args, kwargs)`` for the IR node wrapped by ``snode``.

    Positional args are ``inputs`` followed by ``constant_args``, completed with
    ``fill_non_provided_args``. Then, over the flattened (args, kwargs):

    1. every IR node except ``GeneratorState`` is converted with
       ``ir_node_to_tensor(..., guard_shape=False)``;
    2. every resulting tensor is replaced by a fresh ``torch.empty`` tensor of
       the same size, dtype and device (contents uninitialized).

    Non-tensor values pass through unchanged. ``snode.node`` presumably is an
    extern/fallback kernel exposing ``inputs``, ``constant_args``, ``kwargs``
    and ``fill_non_provided_args``; TODO confirm against callers.
    """
    ir_node = snode.node
    provided = [*ir_node.inputs, *ir_node.constant_args]  # type: ignore[union-attr]
    args = ir_node.fill_non_provided_args(provided, ir_node.kwargs)  # type: ignore[union-attr]
    kwargs = ir_node.kwargs  # type: ignore[union-attr]
    flat_args, spec = pytree.tree_flatten((args, kwargs))

    def _is_tensor_ir(x: Any) -> bool:
        ir_mod = torch._inductor.ir
        return isinstance(x, ir_mod.IRNode) and not isinstance(x, ir_mod.GeneratorState)

    # Pass 1: IR nodes -> tensors.
    flat_args = [
        torch._inductor.ir.ir_node_to_tensor(a, guard_shape=False)
        if _is_tensor_ir(a)
        else a
        for a in flat_args
    ]
    # Pass 2: tensors -> freshly allocated real tensors with matching metadata.
    flat_args = [
        torch.empty(a.size(), dtype=a.dtype, device=a.device)
        if isinstance(a, torch.Tensor)
        else a
        for a in flat_args
    ]
    args, kwargs = pytree.tree_unflatten(flat_args, spec)
    return args, kwargs
|